File size: 3,571 Bytes
aa04f76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6400203
 
aa04f76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
import torch
import numpy as np
from PIL import Image
import yaml
import os
from models.unet import UNet

# Configuration
# Prefer the GPU whenever torch can see one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Per-dataset file locations, relative to the working directory:
# the YAML config describing each model, and its saved state_dict.
CONFIG_PATHS = dict(
    isic='./configs/isic/isic2018_unet.yaml',
    segpc='./configs/segpc/segpc2021_unet.yaml',
)
MODEL_PATHS = dict(
    isic='./saved_models/isic2018_unet/best_model_state_dict.pt',
    segpc='./saved_models/segpc2021_unet/best_model_state_dict.pt',
)

def load_config(config_path):
    """Read the YAML file at *config_path* and return the parsed document.

    Parsing uses ``yaml.safe_load``, so only plain YAML types are built.
    The file is decoded as UTF-8 explicitly instead of with the platform's
    locale encoding, so a config containing non-ASCII text (e.g. in
    comments) loads identically on every OS.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

def load_model(dataset_name):
    """Construct the UNet for *dataset_name* and return it ready for inference.

    Input/output channel counts are read from the dataset's YAML config
    under ``model.params``.  If the saved state_dict exists it is loaded
    (mapped onto ``DEVICE``); otherwise a warning is printed and the model
    keeps whatever weights ``UNet`` was constructed with.  The model is then
    moved to ``DEVICE`` and put in eval mode.

    Raises:
        KeyError: if *dataset_name* has no entry in ``CONFIG_PATHS`` /
            ``MODEL_PATHS``, or the config lacks the expected keys.
    """
    params = load_config(CONFIG_PATHS[dataset_name])['model']['params']
    net = UNet(
        in_channels=params['in_channels'],
        out_channels=params['out_channels'],
    )

    model_path = MODEL_PATHS[dataset_name]
    if not os.path.exists(model_path):
        print(f"Warning: Model weights not found for {dataset_name} at {model_path}")
    else:
        net.load_state_dict(torch.load(model_path, map_location=DEVICE))
        print(f"Loaded model for {dataset_name} from {model_path}")

    net.to(DEVICE)
    net.eval()
    return net

# Load models once (cache them)
# Build one model per dataset at import time so every predict() call reuses it
# instead of reloading config and weights per request. A dataset whose model
# fails to build is reported and simply left out of `models`; predict() returns
# None for a dataset_choice that is not a key here.
models = {}
for ds in ['isic', 'segpc']:
    try:
        models[ds] = load_model(ds)
    except Exception as e:
        # Best-effort: one broken config/checkpoint should not stop the app
        # from serving the other dataset.
        print(f"Error loading model {ds}: {e}")

# Model input resolution (width, height). Intended to match the image size in
# the training configs -- confirm against configs/*.yaml if those change.
_INPUT_SIZE = (224, 224)


def _to_model_input(img_resized, dataset_choice):
    """Turn a resized PIL image into a (1, C, H, W) float32 tensor in [0, 1].

    'isic' models get 3 RGB channels. Every other dataset (SegPC here) gets
    4 channels: an RGBA image is passed through unchanged, and any other mode
    is converted to RGB and padded with an all-zero 4th channel.
    """
    if dataset_choice != 'isic' and img_resized.mode == 'RGBA':
        arr = np.asarray(img_resized, dtype=np.float32) / 255.0
    else:
        # convert('RGB') normalizes grayscale / palette / LA / CMYK inputs,
        # whose raw arrays have the wrong channel count (or none at all) for
        # the HWC -> CHW permute below. For RGBA it just drops alpha.
        arr = np.asarray(img_resized.convert('RGB'), dtype=np.float32) / 255.0
        if dataset_choice != 'isic':
            # NOTE(review): the zero-filled 4th channel is a placeholder for
            # whatever extra channel SegPC was trained with -- confirm.
            pad = np.zeros(arr.shape[:2] + (1,), dtype=np.float32)
            arr = np.concatenate([arr, pad], axis=-1)
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)


def _foreground_mask(model, img_tensor):
    """Run *model* on *img_tensor* and return a boolean (H, W) foreground mask."""
    with torch.no_grad():
        output = model(img_tensor)  # treated as (batch, channel, H, W)
    if output.shape[1] == 1:
        # Single channel: a foreground logit, thresholded at probability 0.5.
        fg = torch.sigmoid(output[:, 0]) > 0.5
    else:
        # NOTE(review): assumes multi-channel output holds per-class scores
        # with class 0 = background; thresholding channel 0 alone would not
        # yield a foreground mask then. Confirm against the training code.
        fg = output.argmax(dim=1) > 0
    return fg[0].cpu().numpy()


def _blend_overlay(base_rgb, mask):
    """Paint *mask* pixels green on a copy of *base_rgb* and blend 60/40.

    *base_rgb* is an (H, W, 3) uint8 array; the result has the same shape.
    """
    overlay = base_rgb.copy()
    overlay[mask] = [0, 255, 0]
    return (0.6 * base_rgb + 0.4 * overlay).astype(np.uint8)


def predict(image, dataset_choice):
    """Segment *image* with the cached model for *dataset_choice*.

    Args:
        image: PIL image from the Gradio input, or None if nothing was given.
        dataset_choice: key into ``models`` ('isic' or 'segpc' in this app).

    Returns:
        A (224, 224, 3) uint8 array -- the resized input with predicted
        foreground pixels tinted green -- or None when there is no image or
        no loaded model for *dataset_choice*. The overlay is produced at the
        model's input resolution, not the upload's original size.
    """
    if image is None or dataset_choice not in models:
        return None

    model = models[dataset_choice]
    img_resized = image.resize(_INPUT_SIZE)

    img_tensor = _to_model_input(img_resized, dataset_choice).to(DEVICE)
    mask = _foreground_mask(model, img_tensor)

    # Always visualize on 3-channel RGB: assigning an RGB colour into an RGBA
    # or grayscale array would fail with a shape mismatch.
    base_rgb = np.array(img_resized.convert('RGB'))
    return _blend_overlay(base_rgb, mask)

# Interface
# Gradio UI: the uploaded image and the selected dataset are passed to
# predict() in the order listed in `inputs`; the numpy array it returns is
# shown as the output image.
iface = gr.Interface(
    fn=predict,
    inputs=[
        # Passed to predict() as a PIL image (type="pil").
        gr.Image(type="pil", label="Input Image"),
        # The selected value is used by predict() as a key into `models`.
        gr.Radio(["isic", "segpc"], label="Dataset Model", value="isic")
    ],
    outputs=gr.Image(type="numpy", label="Prediction Overlay"),
    title="Medical Image Segmentation (Awesome-U-Net)",
    description="Upload an image to segment skin lesions (ISIC) or cells (SegPC).",
    examples=[
        # Add example rows here if sample files are available, one
        # [image_path, dataset] pair per row matching `inputs`, e.g.:
        # ["dataset_examples/isic_sample.jpg", "isic"]
    ]
)

# Launch the Gradio app only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()