File size: 5,770 Bytes
f129f93
 
 
 
 
 
1fc8c87
 
 
 
 
b7ab39c
f129f93
1fc8c87
 
 
 
 
 
 
f129f93
 
 
 
 
 
 
1fc8c87
f129f93
 
1fc8c87
f129f93
 
 
 
1fc8c87
f129f93
 
 
 
1fc8c87
 
f129f93
 
 
 
 
 
 
 
 
 
 
 
 
1fc8c87
f129f93
 
1fc8c87
f129f93
 
 
 
 
 
 
 
 
 
a81555c
1fc8c87
a81555c
 
 
f129f93
 
1fc8c87
f129f93
1fc8c87
f129f93
 
1fc8c87
f129f93
 
 
 
 
 
1fc8c87
 
 
 
f129f93
a81555c
 
 
 
 
 
 
 
 
 
 
 
 
 
b7ab39c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch 
import models_Facies, models_Fault
import timm
from util.datasets import ThebeSet, P3DFaciesSet
from util.pos_embed import interpolate_pos_embed
import random
import huggingface_hub
from huggingface_hub import hf_hub_download
from PIL import Image
import numpy as np
from matplotlib import cm 
from PIL import ImageFilter

# Checkpoint filenames downloaded by predict() from the "Ani24/SFM_Finetuned"
# Hugging Face repo; the fault checkpoint is fetched from the
# "ckpts-Tversky-Neut" subfolder.
HFACE_FAULTS = "checkpoint-24.pth"
# Facies checkpoint, fetched from the "ckpts-RSVSFacies-P3D" subfolder.
HFACE_FACIES = "checkpoint-49.pth"

# Local dataset roots read by random_sample() (Thebe for faults, P3D for facies).
# NOTE(review): these are machine-specific absolute Windows paths and will only
# resolve on the original author's workstation.
FAULT_DATA_PATH = "C:\\Users\\abhalekar\\Desktop\\DATASETS\\Thebe_DATASET\\crossline_combined_data"
FACIES_DATA_PATH = "C:\\Users\\abhalekar\\Desktop\\DATASETS\\P3D_Vol_DATASET"

def predict(seismic: torch.Tensor, task='Fault', model_type='vit_large_patch16', device='cpu', hface=True, thresh=0.5):
    """Run a fine-tuned SFM model on one seismic section and return its prediction.

    The architecture is built from ``models_Fault`` / ``models_Facies`` and the
    weights are downloaded from the ``Ani24/SFM_Finetuned`` Hugging Face repo on
    every call.

    Args:
        seismic: Input section without a batch axis (one is added here with
            ``unsqueeze(0)``). Presumably shaped (1, H, W) to match
            ``in_chans=1`` -- TODO confirm against the dataset classes.
        task: ``'Fault'`` (1-class fault model, img_size 768) or ``'Facies'``
            (6-class facies model, img_size 128).
        model_type: Name of the model constructor looked up in the task's module.
        device: Torch device the model and the input are placed on.
        hface: Currently unused; weights are always fetched from Hugging Face.
        thresh: Currently unused; kept for interface compatibility.

    Returns:
        ``(output_image, output)``: ``output`` is the raw 2-D numpy prediction
        (sigmoid probabilities for faults, argmax class indices for facies);
        ``output_image`` is an RGB PIL image of ``output`` scaled to 0-255.

    Raises:
        ValueError: If ``task`` is not ``'Fault'`` or ``'Facies'``.
    """
    if task == 'Fault':
        model = models_Fault.__dict__[model_type](
            img_size=768,
            num_classes=1,
            drop_path_rate=0.1,
            in_chans=1,
        )
        checkpoint_path = hf_hub_download(repo_id="Ani24/SFM_Finetuned", filename=HFACE_FAULTS, subfolder="ckpts-Tversky-Neut")
    elif task == 'Facies':
        model = models_Facies.__dict__[model_type](
            img_size=128,
            num_classes=6,
            drop_path_rate=0.1,
            in_chans=1,
        )
        checkpoint_path = hf_hub_download(repo_id="Ani24/SFM_Finetuned", filename=HFACE_FACIES, subfolder="ckpts-RSVSFacies-P3D")
    else:
        raise ValueError(f"Task not configured yet: {task}")

    model.to(device)
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    checkpoint_model = checkpoint['model']
    state_dict = model.state_dict()

    # Drop head weights whose shape does not fit this model's head so the
    # non-strict load below does not fail on a size mismatch.
    for k in ['head.weight', 'head.bias']:
        if k in checkpoint_model and k in state_dict and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]

    interpolate_pos_embed(model, checkpoint_model)
    msg = model.load_state_dict(checkpoint_model, strict=False)
    print(msg)

    # Switch to inference mode: train-mode layers (e.g. the drop-path enabled
    # via drop_path_rate=0.1 above) would otherwise make predictions random.
    model.eval()

    print("Seismic data shape:", seismic.shape)

    with torch.no_grad():
        # The input has to live on the same device as the model.
        output = model(seismic.unsqueeze(0).to(device))
    output = output.squeeze(0)

    if task == 'Fault':
        output = torch.sigmoid(output)
        output = output.detach().cpu().numpy()[0, :, :]
    else:  # 'Facies' -- any other task was rejected above
        output = output.argmax(dim=0)
        output = output.detach().cpu().numpy()

    # Normalize to [0, 1] for display. An all-zero map (e.g. every pixel
    # predicted as facies class 0) would otherwise divide by zero -> NaNs.
    peak = output.max()
    if peak > 0:
        output_image = output / peak
    else:
        output_image = np.zeros(output.shape, dtype=np.float64)
    # output is a 2-D numpy array -- convert to an RGB PIL image
    output_image = Image.fromarray((output_image * 255).astype(np.uint8)).convert("RGB")

    return output_image, output


def random_sample(task='Fault', data_path=None, batch_size=1, num_workers=0):
    """Draw one random sample from the task's dataset and render a preview.

    Args:
        task: ``'Fault'`` (``ThebeSet`` 'test' split at 768x768) or
            ``'Facies'`` (``P3DFaciesSet`` 'train' split).
        data_path: Dataset root. When ``None`` (default) the module-level
            ``FAULT_DATA_PATH`` / ``FACIES_DATA_PATH`` is used.
        batch_size: Currently unused; kept for interface compatibility.
        num_workers: Currently unused; kept for interface compatibility.

    Returns:
        ``(seis_image, seis)``: a PIL preview of the sample rendered with
        matplotlib's ``seismic`` colormap, and the raw sample tensor as
        returned by the dataset.

    Raises:
        ValueError: If ``task`` is unknown or the dataset is empty.
    """
    if task == 'Fault':
        root = FAULT_DATA_PATH if data_path is None else data_path
        dataset = ThebeSet(root, [768, 768], 'test')
    elif task == 'Facies':
        root = FACIES_DATA_PATH if data_path is None else data_path
        dataset = P3DFaciesSet(root, mode='train')
    else:
        raise ValueError(f"Task not configured yet: {task}")

    if len(dataset) == 0:
        raise ValueError(f"Dataset for task {task} is empty: {root}")

    index = random.randint(0, len(dataset) - 1)
    seis, _label = dataset[index]

    # Presumably seis is (1, H, W); drop the channel axis for rendering.
    seis_image = seis.detach().cpu().numpy().squeeze(0)
    # Normalize to [0, 1]; a constant section would otherwise divide by zero.
    lo, hi = seis_image.min(), seis_image.max()
    span = hi - lo
    if span > 0:
        seis_image = (seis_image - lo) / span
    else:
        seis_image = np.zeros(seis_image.shape, dtype=np.float64)
    # The colormap returns float RGBA values in [0, 1]; scale to uint8.
    seis_image = Image.fromarray(np.uint8(cm.seismic(seis_image) * 255))

    return seis_image, seis


def overlay_images(seismic_image: Image.Image, prediction_image: Image.Image, alpha=0.5) -> Image.Image:
    """Blend a prediction image (facies/faults) on top of a seismic image.

    Both inputs are converted to RGBA. The prediction layer gets a uniform
    opacity of ``alpha`` and is alpha-composited over the seismic image.

    Args:
        seismic_image: Background image.
        prediction_image: Foreground image. If its size differs from
            ``seismic_image`` it is resized to match with nearest-neighbour
            sampling, so label values are not blended. ``Image.alpha_composite``
            requires both images to have the same size.
        alpha: Opacity of the prediction layer; clamped to [0, 1].

    Returns:
        The composited RGBA image.
    """
    background = Image.fromarray(np.array(seismic_image).astype(np.uint8)).convert("RGBA")
    foreground = Image.fromarray(np.array(prediction_image).astype(np.uint8)).convert("RGBA")

    if foreground.size != background.size:
        foreground = foreground.resize(background.size, resample=Image.NEAREST)

    opacity = min(max(float(alpha), 0.0), 1.0)
    foreground.putalpha(int(255 * opacity))

    return Image.alpha_composite(background, foreground)


def post_process(processed_prediction_image: Image.Image, prediction_image: Image.Image, method: str = 'None', value=None) -> Image.Image:
    """Apply one named post-processing operation to a prediction image.

    Args:
        processed_prediction_image: Image the operation is applied to.
        prediction_image: Currently unused; kept for interface compatibility.
        method: One of ``'None'``, ``'Thresholding'``, ``'Closing'``,
            ``'Opening'``, ``'Canny Edge'``, ``'Gaussian Smoothing'``,
            ``'Hysteresis'``.
        value: Parameter forwarded to the selected operation (threshold level,
            filter window size or blur radius, depending on ``method``).

    Returns:
        The processed image, or ``processed_prediction_image`` itself when
        ``method`` is ``'None'``.

    Raises:
        ValueError: If ``method`` is not recognised.
    """
    if method == 'None':
        return processed_prediction_image

    # Every operation shares the (image, value) signature, so ``value`` is
    # forwarded uniformly -- including to thresholding, which requires it.
    operations = {
        'Thresholding': apply_thresholding,
        'Closing': apply_closing,
        'Opening': apply_opening,
        'Canny Edge': apply_canny_edge,
        'Gaussian Smoothing': apply_gaussian_smoothing,
        'Hysteresis': apply_hysteresis,
    }
    operation = operations.get(method)
    if operation is None:
        raise ValueError(f"Unknown post-processing method: {method}")
    return operation(processed_prediction_image, value)


def apply_thresholding(image: Image, value: int) -> Image:
    """Binarize ``image``: levels above ``value`` map to 255, all others to 0.

    ``Image.point`` builds a lookup table from the mapping and applies it to
    every band.
    """
    def _binarize(level):
        return 255 if level > value else 0

    return image.point(_binarize)
def apply_closing(image: Image, value: int) -> Image:
    """Morphological closing: grey-level dilation (max filter), then erosion (min filter).

    ``value`` is the square window size passed to both filters; Pillow's rank
    filters reject even sizes.
    """
    dilated = image.filter(ImageFilter.MaxFilter(size=value))
    closed = dilated.filter(ImageFilter.MinFilter(size=value))
    return closed
def apply_opening(image: Image, value: int) -> Image:
    """Morphological opening: grey-level erosion (min filter), then dilation (max filter).

    ``value`` is the square window size passed to both filters; Pillow's rank
    filters reject even sizes.
    """
    eroded = image.filter(ImageFilter.MinFilter(size=value))
    opened = eroded.filter(ImageFilter.MaxFilter(size=value))
    return opened
def apply_canny_edge(image: Image, value: int) -> Image:
    """Highlight edges in ``image`` using Pillow's built-in ``FIND_EDGES`` filter.

    NOTE(review): despite the name this is not a Canny detector. It is a single
    fixed convolution kernel, and ``value`` is ignored.
    """
    edge_filter = ImageFilter.FIND_EDGES
    return image.filter(edge_filter)
def apply_gaussian_smoothing(image: Image, value: int) -> Image:
    """Blur ``image`` with a Gaussian kernel whose radius is ``value``."""
    blur = ImageFilter.GaussianBlur(radius=value)
    return image.filter(blur)
def apply_hysteresis(image: Image, value: int) -> Image:
    """Simplified stand-in for hysteresis thresholding.

    Only a single threshold is applied: levels above ``value`` become 255 and
    the rest 0. There is no weak/strong edge linking, so the result is the same
    as plain thresholding.
    """
    return image.point(lambda level: 0 if level <= value else 255)