"""Inference, sampling and post-processing helpers for seismic fault/facies prediction."""

import random

import huggingface_hub
import numpy as np
import timm
import torch
from huggingface_hub import hf_hub_download
from matplotlib import cm
from PIL import Image
from PIL import ImageFilter

import models_Facies
import models_Fault
from util.datasets import ThebeSet, P3DFaciesSet
from util.pos_embed import interpolate_pos_embed

# Checkpoint file names inside the Hugging Face repository.
HFACE_FAULTS = "checkpoint-24.pth"
HFACE_FACIES = "checkpoint-49.pth"

# Default dataset locations used by random_sample() when no data_path is given.
# NOTE(review): these are machine-specific absolute paths.
FAULT_DATA_PATH = "C:\\Users\\abhalekar\\Desktop\\DATASETS\\Thebe_DATASET\\crossline_combined_data"
FACIES_DATA_PATH = "C:\\Users\\abhalekar\\Desktop\\DATASETS\\P3D_Vol_DATASET"


def predict(seismic: torch.Tensor, task='Fault', model_type='vit_large_patch16',
            device='cpu', hface=True, thresh=0.5):
    """Run the fine-tuned model for ``task`` on one seismic section.

    Args:
        seismic: Input tensor without a batch dimension; a batch axis is added
            here with ``unsqueeze(0)``. Presumably shaped (1, H, W) to match
            ``in_chans=1`` -- TODO confirm against the dataset classes.
        task: ``'Fault'`` or ``'Facies'``.
        model_type: Name of the model constructor looked up in
            ``models_Fault`` / ``models_Facies``.
        device: Device the model and the input are moved to.
        hface: Currently unused; kept for interface compatibility.
        thresh: Currently unused; kept for interface compatibility.

    Returns:
        ``(output_image, output)`` where ``output`` is a 2-D numpy array
        (sigmoid values for faults, arg-max class indices for facies) and
        ``output_image`` is ``output`` scaled by its maximum, converted to an
        8-bit RGB PIL image.

    Raises:
        ValueError: If ``task`` is not one of the configured tasks.
    """
    if task == 'Fault':
        model = models_Fault.__dict__[model_type](
            img_size=768,
            num_classes=1,
            drop_path_rate=0.1,
            in_chans=1,
        )
        checkpoint_path = hf_hub_download(
            repo_id="Ani24/SFM_Finetuned",
            filename=HFACE_FAULTS,
            subfolder="ckpts-Tversky-Neut",
        )
    elif task == 'Facies':
        model = models_Facies.__dict__[model_type](
            img_size=128,
            num_classes=6,
            drop_path_rate=0.1,
            in_chans=1,
        )
        checkpoint_path = hf_hub_download(
            repo_id="Ani24/SFM_Finetuned",
            filename=HFACE_FACIES,
            subfolder="ckpts-RSVSFacies-P3D",
        )
    else:
        raise ValueError(f"Task not configured yet: {task}")

    model.to(device)

    # SECURITY NOTE: weights_only=False unpickles arbitrary objects from the
    # downloaded file; only acceptable while the repository is trusted.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    checkpoint_model = checkpoint['model']
    state_dict = model.state_dict()

    # Drop classifier-head weights whose shape does not match this model so
    # the non-strict load below does not fail on them.
    for k in ['head.weight', 'head.bias']:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]

    interpolate_pos_embed(model, checkpoint_model)
    msg = model.load_state_dict(checkpoint_model, strict=False)
    print(msg)
    print("Seismic data shape:", seismic.shape)

    # Modules start in training mode; switch to eval so stochastic layers
    # (the model is built with drop_path_rate=0.1) are disabled at inference.
    model.eval()
    with torch.no_grad():
        # The input must live on the same device as the model.
        output = model(seismic.unsqueeze(0).to(device))

    output = output.squeeze(0)
    if task == 'Fault':
        output = torch.sigmoid(output)
        output = output.detach().cpu().numpy()[0, :, :]
    else:  # 'Facies' -- any other task was rejected above
        output = output.argmax(dim=0)
        output = output.detach().cpu().numpy()

    # Scale by the maximum for display. An all-zero prediction (e.g. every
    # pixel assigned class 0) would otherwise divide by zero and yield NaNs.
    peak = output.max()
    if peak > 0:
        output_image = output / peak
    else:
        output_image = np.zeros(output.shape, dtype=np.float64)

    # 2-D array -> 8-bit greyscale -> RGB PIL image.
    output_image = Image.fromarray((output_image * 255).astype(np.uint8)).convert("RGB")
    return output_image, output


def random_sample(task='Fault', data_path=None, batch_size=1, num_workers=0):
    """Pick one random sample from the dataset configured for ``task``.

    Args:
        task: ``'Fault'`` or ``'Facies'``.
        data_path: Dataset root. When ``None`` the module-level default for
            the task is used.
        batch_size: Currently unused; kept for interface compatibility.
        num_workers: Currently unused; kept for interface compatibility.

    Returns:
        ``(seis_image, seis)``: the sample rendered with the ``seismic``
        colormap as a PIL image, and the raw sample as returned by the dataset.

    Raises:
        ValueError: If ``task`` is not configured or the dataset is empty.
    """
    if task == 'Fault':
        if data_path is None:
            data_path = FAULT_DATA_PATH
        dataset = ThebeSet(data_path, [768, 768], 'test')
    elif task == 'Facies':
        if data_path is None:
            data_path = FACIES_DATA_PATH
        dataset = P3DFaciesSet(data_path, mode='train')
    else:
        raise ValueError(f"Task not configured yet: {task}")

    if len(dataset) == 0:
        raise ValueError(f"Dataset for task {task} is empty: {data_path}")

    index = random.randint(0, len(dataset) - 1)
    seis, label = dataset[index]  # the label is not needed for display

    seis_image = seis.detach().cpu().numpy().squeeze(0)

    # Min-max normalize to [0, 1]; a constant section has zero range, so map
    # it to the colormap midpoint instead of dividing by zero.
    low = seis_image.min()
    span = seis_image.max() - low
    if span > 0:
        seis_image = (seis_image - low) / span
    else:
        seis_image = np.full(seis_image.shape, 0.5)

    seis_image = Image.fromarray(np.uint8(cm.seismic(seis_image) * 255))  # Convert to PIL Image
    return seis_image, seis


def overlay_images(seismic_image: Image.Image, prediction_image: Image.Image,
                   alpha=0.5) -> Image.Image:
    """Composite the prediction over the seismic image with uniform opacity.

    Args:
        seismic_image: Background image.
        prediction_image: Foreground image (predicted facies/faults).
        alpha: Opacity of the prediction layer; clamped to [0, 1].

    Returns:
        The RGBA composite. ``Image.alpha_composite`` requires both images to
        have the same size.
    """
    alpha = min(max(alpha, 0.0), 1.0)
    prediction_image = Image.fromarray(np.array(prediction_image).astype(np.uint8)).convert("RGBA")
    seismic_image = Image.fromarray(np.array(seismic_image).astype(np.uint8)).convert("RGBA")
    prediction_image.putalpha(int(255 * alpha))  # Set alpha for overlay
    return Image.alpha_composite(seismic_image, prediction_image)


def post_process(processed_prediction_image: Image.Image, prediction_image: Image.Image,
                 method: str = 'None', value=None) -> Image.Image:
    """Apply the named post-processing step to ``processed_prediction_image``.

    Args:
        processed_prediction_image: Image to process.
        prediction_image: Currently unused; kept for interface compatibility.
        method: One of ``'None'``, ``'Thresholding'``, ``'Closing'``,
            ``'Opening'``, ``'Canny Edge'``, ``'Gaussian Smoothing'``,
            ``'Hysteresis'``.
        value: Parameter forwarded to the selected step (threshold, filter
            size or blur radius). Ignored by ``'None'`` and ``'Canny Edge'``.

    Returns:
        The processed image (the input itself for ``'None'``).

    Raises:
        ValueError: If ``method`` is not recognised.
    """
    if method == 'None':
        return processed_prediction_image

    handlers = {
        'Thresholding': apply_thresholding,
        'Closing': apply_closing,
        'Opening': apply_opening,
        'Canny Edge': apply_canny_edge,
        'Gaussian Smoothing': apply_gaussian_smoothing,
        'Hysteresis': apply_hysteresis,
    }
    handler = handlers.get(method)
    if handler is None:
        raise ValueError(f"Unknown post-processing method: {method}")
    # Every handler takes (image, value); thresholding was previously called
    # without value and therefore always raised TypeError.
    return handler(processed_prediction_image, value)


def apply_thresholding(image: Image.Image, value: int) -> Image.Image:
    """Binarise: pixel values above ``value`` become 255, all others 0."""
    return image.point(lambda p: 255 if p > value else 0)


def apply_closing(image: Image.Image, value: int) -> Image.Image:
    """Apply closing (dilation followed by erosion) with a ``value``-sized window."""
    dilated = image.filter(ImageFilter.MaxFilter(size=value))
    return dilated.filter(ImageFilter.MinFilter(size=value))


def apply_opening(image: Image.Image, value: int) -> Image.Image:
    """Apply opening (erosion followed by dilation) with a ``value``-sized window."""
    eroded = image.filter(ImageFilter.MinFilter(size=value))
    return eroded.filter(ImageFilter.MaxFilter(size=value))


def apply_canny_edge(image: Image.Image, value: int) -> Image.Image:
    """Return an edge map using PIL's ``FIND_EDGES`` filter.

    This is not a Canny detector, and ``value`` is ignored; the parameter
    exists so all post-processing steps share one signature.
    """
    return image.filter(ImageFilter.FIND_EDGES)


def apply_gaussian_smoothing(image: Image.Image, value: int) -> Image.Image:
    """Blur the image with a Gaussian kernel of radius ``value``."""
    return image.filter(ImageFilter.GaussianBlur(radius=value))


def apply_hysteresis(image: Image.Image, value: int) -> Image.Image:
    """Simple thresholding used in place of hysteresis.

    Behaves exactly like ``apply_thresholding``; no two-level hysteresis is
    performed.
    """
    return apply_thresholding(image, value)