Spaces:
Build error
Build error
File size: 5,770 Bytes
f129f93 1fc8c87 b7ab39c f129f93 1fc8c87 f129f93 1fc8c87 f129f93 1fc8c87 f129f93 1fc8c87 f129f93 1fc8c87 f129f93 1fc8c87 f129f93 1fc8c87 f129f93 a81555c 1fc8c87 a81555c f129f93 1fc8c87 f129f93 1fc8c87 f129f93 1fc8c87 f129f93 1fc8c87 f129f93 a81555c b7ab39c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import os
import random

import huggingface_hub
import numpy as np
import timm
import torch
from huggingface_hub import hf_hub_download
from matplotlib import cm
from PIL import Image
from PIL import ImageFilter

import models_Facies, models_Fault
from util.datasets import ThebeSet, P3DFaciesSet
from util.pos_embed import interpolate_pos_embed
# Checkpoint filenames inside the "Ani24/SFM_Finetuned" Hugging Face repo
# (fetched by predict()).
HFACE_FAULTS = "checkpoint-24.pth"
HFACE_FACIES = "checkpoint-49.pth"
# Local dataset roots used by random_sample(). The defaults are the original
# author's Windows paths; set SFM_FAULT_DATA_PATH / SFM_FACIES_DATA_PATH to
# point at the datasets on any other machine without editing this file.
FAULT_DATA_PATH = os.environ.get(
    "SFM_FAULT_DATA_PATH",
    "C:\\Users\\abhalekar\\Desktop\\DATASETS\\Thebe_DATASET\\crossline_combined_data",
)
FACIES_DATA_PATH = os.environ.get(
    "SFM_FACIES_DATA_PATH",
    "C:\\Users\\abhalekar\\Desktop\\DATASETS\\P3D_Vol_DATASET",
)
def predict(seismic: torch.Tensor, task='Fault', model_type='vit_large_patch16', device = 'cpu', hface = True, thresh = 0.5):
    """Run a fine-tuned SFM model on a single seismic section.

    Builds the task-specific model, downloads its checkpoint from the
    ``Ani24/SFM_Finetuned`` Hugging Face repo, loads the weights and runs one
    forward pass in inference mode.

    Args:
        seismic: Section to segment. A batch dimension is prepended here, so
            this is presumably one (C, H, W) sample with C == 1 (the models
            are built with ``in_chans=1``) -- TODO confirm against the
            dataset classes.
        task: ``'Fault'`` (img_size 768, one output channel passed through a
            sigmoid) or ``'Facies'`` (img_size 128, six classes reduced with
            argmax).
        model_type: Name of a model constructor in ``models_Fault`` /
            ``models_Facies``.
        device: Torch device the model and the input are moved to.
        hface: Unused; the checkpoint is always fetched from the Hub.
        thresh: Unused; the fault output is returned as raw probabilities.

    Returns:
        ``(output_image, output)``: ``output`` is the 2-D numpy prediction
        (fault probabilities or facies class indices) and ``output_image`` is
        an RGB PIL image of it scaled so its maximum maps to 255 (all black if
        the maximum is 0).

    Raises:
        ValueError: If ``task`` is not ``'Fault'`` or ``'Facies'``.
    """
    if task == 'Fault':
        model = models_Fault.__dict__[model_type](
            img_size=768,
            num_classes=1,
            drop_path_rate=0.1,
            in_chans=1,
        )
        checkpoint_path = hf_hub_download(repo_id="Ani24/SFM_Finetuned", filename=HFACE_FAULTS, subfolder="ckpts-Tversky-Neut")
    elif task == 'Facies':
        model = models_Facies.__dict__[model_type](
            img_size=128,
            num_classes=6,
            drop_path_rate=0.1,
            in_chans=1,
        )
        checkpoint_path = hf_hub_download(repo_id="Ani24/SFM_Finetuned", filename=HFACE_FACIES, subfolder="ckpts-RSVSFacies-P3D")
    else:
        raise ValueError(f"Task not configured yet: {task}")

    model.to(device)
    # NOTE(review): weights_only=False unpickles arbitrary objects from the
    # downloaded file; only acceptable because the checkpoint repo is trusted.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    checkpoint_model = checkpoint['model']
    state_dict = model.state_dict()
    # Drop classifier-head tensors whose shape does not match this model.
    for k in ['head.weight', 'head.bias']:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]
    interpolate_pos_embed(model, checkpoint_model)
    msg = model.load_state_dict(checkpoint_model, strict=False)
    print(msg)
    # Switch to inference mode: the model is built with drop_path_rate=0.1,
    # and stochastic depth / dropout would otherwise stay active and make
    # every prediction random.
    model.eval()

    print("Seismic data shape:", seismic.shape)
    with torch.no_grad():
        # Move the input next to the model so non-CPU devices work.
        output = model(seismic.to(device).unsqueeze(0))
    output = output.squeeze(0)
    if task == 'Fault':
        output = torch.sigmoid(output)
        output = output.cpu().numpy()[0, :, :]
    else:  # task == 'Facies'
        output = output.argmax(dim=0)
        output = output.cpu().numpy()

    # Scale to [0, 1] for display. A facies map that is entirely class 0 has
    # max() == 0, which would otherwise divide by zero and yield NaNs.
    peak = output.max()
    if peak > 0:
        output_image = output / peak
    else:
        output_image = np.zeros_like(output, dtype=float)
    # output is a 2-D numpy array - convert to a PIL RGB image
    output_image = Image.fromarray((output_image * 255).astype(np.uint8)).convert("RGB")
    return output_image, output
def random_sample(task = 'Fault', data_path = None, batch_size=1, num_workers=0):
    """Pick one random section from the task's dataset for display.

    Args:
        task: ``'Fault'`` (ThebeSet, ``'test'`` split, 768x768) or
            ``'Facies'`` (P3DFaciesSet, ``mode='train'``).
        data_path: Dataset root. When None, ``FAULT_DATA_PATH`` or
            ``FACIES_DATA_PATH`` is used according to ``task``.
        batch_size: Unused; kept for interface compatibility.
        num_workers: Unused; kept for interface compatibility.

    Returns:
        ``(seis_image, seis)``: a PIL image of the section min-max normalised
        and coloured with matplotlib's ``seismic`` colormap, and the raw
        section exactly as returned by the dataset.

    Raises:
        ValueError: If ``task`` is unknown or the dataset is empty.
    """
    # Honour an explicit data_path; fall back to the module defaults.
    if task == 'Fault':
        root = FAULT_DATA_PATH if data_path is None else data_path
        dataset = ThebeSet(root, [768, 768], 'test')
    elif task == 'Facies':
        root = FACIES_DATA_PATH if data_path is None else data_path
        dataset = P3DFaciesSet(root, mode = 'train')
    else:
        raise ValueError(f"Task not configured yet: {task}")

    if len(dataset) == 0:
        raise ValueError(f"No samples found for task {task} under {root}")
    index = random.randrange(len(dataset))
    seis, _label = dataset[index]

    seis_image = seis.detach().cpu().numpy().squeeze(0)
    lo, hi = seis_image.min(), seis_image.max()
    # Normalize to [0, 1]; a constant section has zero range and would
    # otherwise produce NaNs from 0 / 0.
    if hi > lo:
        seis_image = (seis_image - lo) / (hi - lo)
    else:
        seis_image = np.zeros_like(seis_image, dtype=float)
    seis_image = Image.fromarray(np.uint8(cm.seismic(seis_image) * 255))  # Convert to PIL Image
    return seis_image, seis
def overlay_images(seismic_image: Image, prediction_image: Image, alpha = 0.5) -> Image:
    """Blend a facies/fault prediction over the seismic section it came from.

    Both inputs are round-tripped through a uint8 numpy array and converted
    to RGBA. The prediction's alpha channel is then set uniformly to
    ``int(255 * alpha)`` and it is alpha-composited on top of the seismic.
    """
    def _to_rgba(img):
        # Coerce any incoming image/array-like into an 8-bit RGBA PIL image.
        return Image.fromarray(np.array(img).astype(np.uint8)).convert("RGBA")

    foreground = _to_rgba(prediction_image)
    background = _to_rgba(seismic_image)
    opacity = int(255 * alpha)
    foreground.putalpha(opacity)
    return Image.alpha_composite(background, foreground)
def post_process(processed_prediction_image: "Image.Image", prediction_image: "Image.Image", method: str = 'None', value = None) -> "Image.Image":
    """Apply one optional post-processing filter to a prediction image.

    Args:
        processed_prediction_image: Image the selected filter is applied to.
        prediction_image: Unused; accepted for interface compatibility.
        method: One of ``'None'``, ``'Thresholding'``, ``'Closing'``,
            ``'Opening'``, ``'Canny Edge'``, ``'Gaussian Smoothing'`` or
            ``'Hysteresis'``.
        value: Parameter forwarded to the filter (threshold, kernel size or
            blur radius). Required by every method except ``'None'`` and
            ``'Canny Edge'``.

    Returns:
        The filtered image, or ``processed_prediction_image`` itself for
        ``'None'``.

    Raises:
        ValueError: If ``method`` is unknown, or ``value`` is None for a
            method that needs it.
    """
    if method == 'None':
        return processed_prediction_image
    if method == 'Canny Edge':
        # The edge filter ignores ``value``, so None is acceptable here.
        return apply_canny_edge(processed_prediction_image, value)

    valued_methods = ('Thresholding', 'Closing', 'Opening', 'Gaussian Smoothing', 'Hysteresis')
    if method not in valued_methods:
        raise ValueError(f"Unknown post-processing method: {method}")
    if value is None:
        # Fail with a clear message instead of a TypeError deep inside Pillow.
        raise ValueError(f"Post-processing method {method!r} requires a value")

    if method == 'Thresholding':
        # apply_thresholding takes the threshold as a required argument.
        return apply_thresholding(processed_prediction_image, value)
    if method == 'Closing':
        return apply_closing(processed_prediction_image, value)
    if method == 'Opening':
        return apply_opening(processed_prediction_image, value)
    if method == 'Gaussian Smoothing':
        return apply_gaussian_smoothing(processed_prediction_image, value)
    return apply_hysteresis(processed_prediction_image, value)
def apply_thresholding(image: "Image.Image", value: int) -> "Image.Image":
    """Binarise ``image``: pixel values strictly above ``value`` become 255, all others 0."""
    # Explicit conditional instead of ``p > value and 255``, which evaluates
    # to the bool False (not 0) for pixels at or below the threshold.
    return image.point(lambda p: 255 if p > value else 0)
def apply_closing(image: Image, value: int) -> Image:
# Apply closing (dilation followed by erosion)
return image.filter(ImageFilter.MaxFilter(size=value)).filter(ImageFilter.MinFilter(size=value))
def apply_opening(image: Image, value: int) -> Image:
# Apply opening (erosion followed by dilation)
return image.filter(ImageFilter.MinFilter(size=value)).filter(ImageFilter.MaxFilter(size=value))
def apply_canny_edge(image: Image, value: int) -> Image:
    """Highlight edges with Pillow's built-in FIND_EDGES kernel filter.

    NOTE(review): despite the name, this is a single FIND_EDGES convolution,
    not a Canny detector, and ``value`` is ignored.
    """
    edge_filter = ImageFilter.FIND_EDGES
    return image.filter(edge_filter)
def apply_gaussian_smoothing(image: Image, value: int) -> Image:
    """Blur ``image`` with a Gaussian whose radius is ``value``."""
    blur = ImageFilter.GaussianBlur(radius=value)
    return image.filter(blur)
def apply_hysteresis(image: "Image.Image", value: int) -> "Image.Image":
    """Binarise ``image`` at ``value`` (placeholder for hysteresis thresholding).

    NOTE(review): this is a single-threshold binarisation, identical to
    apply_thresholding -- there is no low/high threshold pair and no
    connectivity step, so it is not true hysteresis.
    """
    # Pixels strictly above ``value`` -> 255, everything else -> 0.
    return image.point(lambda p: 255 if p > value else 0)