|
|
import gradio as gr |
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import yaml |
|
|
import os |
|
|
from models.unet import UNet |
|
|
|
|
|
|
|
|
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory containing this script. Resource paths are anchored here rather
# than at the process working directory, so the demo finds its configs and
# weights no matter where it is launched from (the `models` package import
# is already resolved relative to the script's directory).
_APP_DIR = os.path.dirname(os.path.abspath(__file__))

# Per-dataset YAML configs; load_model reads model.params.{in,out}_channels.
CONFIG_PATHS = {
    'isic': os.path.join(_APP_DIR, 'configs', 'isic', 'isic2018_unet.yaml'),
    'segpc': os.path.join(_APP_DIR, 'configs', 'segpc', 'segpc2021_unet.yaml'),
}

# Per-dataset trained weights (state dicts saved with torch.save).
MODEL_PATHS = {
    'isic': os.path.join(
        _APP_DIR, 'saved_models', 'isic2018_unet', 'best_model_state_dict.pt'
    ),
    'segpc': os.path.join(
        _APP_DIR, 'saved_models', 'segpc2021_unet', 'best_model_state_dict.pt'
    ),
}
|
|
|
|
|
def load_config(config_path):
    """Parse the YAML file at *config_path* and return the loaded document.

    Args:
        config_path: Filesystem path to a YAML configuration file.

    Returns:
        The object produced by ``yaml.safe_load`` (callers in this app index
        it like a nested ``dict``).
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding (e.g. cp1252 on Windows).
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
|
|
|
|
|
def load_model(dataset_name):
    """Construct the UNet for *dataset_name*, load saved weights if available.

    The channel counts come from ``model.params`` in the dataset's YAML config.
    When the weights file is missing, a warning is printed and the freshly
    constructed (untrained) network is returned instead.

    Args:
        dataset_name: Key into ``CONFIG_PATHS`` / ``MODEL_PATHS``.

    Returns:
        The network, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        KeyError: if *dataset_name* is unknown or the config lacks a key.
    """
    params = load_config(CONFIG_PATHS[dataset_name])['model']['params']
    net = UNet(
        in_channels=params['in_channels'],
        out_channels=params['out_channels'],
    )

    weights_file = MODEL_PATHS[dataset_name]
    if not os.path.exists(weights_file):
        print(f"Warning: Model weights not found for {dataset_name} at {weights_file}")
    else:
        # NOTE(review): torch.load unpickles the file — only point this at
        # trusted checkpoints.
        net.load_state_dict(torch.load(weights_file, map_location=DEVICE))
        print(f"Loaded model for {dataset_name} from {weights_file}")

    # Module.to() and Module.eval() both return the module itself.
    return net.to(DEVICE).eval()
|
|
|
|
|
|
|
|
# Registry of ready-to-use networks, keyed by dataset name. A model that fails
# to load is reported and simply left out, so the app still starts.
models = {}
for dataset_key in ('isic', 'segpc'):
    try:
        loaded_model = load_model(dataset_key)
    except Exception as err:
        print(f"Error loading model {dataset_key}: {err}")
    else:
        models[dataset_key] = loaded_model
|
|
|
|
|
def _prepare_input(img_rgb, dataset_choice):
    """Turn a resized RGB PIL image into the (1, C, H, W) float32 tensor fed to a model."""
    # (H, W, 3) float32 array scaled to [0, 1].
    img_np = np.asarray(img_rgb, dtype=np.float32) / 255.0
    if dataset_choice != 'isic':
        # The non-ISIC (SegPC) model is fed an extra fourth plane; the UI only
        # supplies RGB, so append an all-zero channel of matching height/width.
        # NOTE(review): assumes the SegPC config sets in_channels == 4 — confirm.
        padding = np.zeros(img_np.shape[:2] + (1,), dtype=np.float32)
        img_np = np.concatenate([img_np, padding], axis=-1)
    # HWC -> CHW, then add a batch dimension.
    return torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0)


def _blend_mask(base_img, mask, color=(0, 255, 0), alpha=0.4):
    """Tint pixels where *mask* is positive toward *color*; leave the rest unchanged."""
    blended = base_img.copy()
    selected = mask > 0
    tinted = (1.0 - alpha) * base_img[selected] + alpha * np.asarray(color, dtype=np.float64)
    # Round (rather than truncate) and clamp before narrowing back to uint8.
    blended[selected] = np.clip(np.rint(tinted), 0, 255).astype(np.uint8)
    return blended


def predict(image, dataset_choice):
    """Segment *image* with the selected model and return a green overlay.

    Args:
        image: Uploaded picture as a PIL image (the Gradio input uses
            ``type="pil"``), or ``None`` when nothing was uploaded.
        dataset_choice: Key into ``models`` selecting which network to run.

    Returns:
        A (224, 224, 3) ``uint8`` array: the resized input with pixels whose
        predicted probability exceeds 0.5 tinted green, or ``None`` when there
        is no image or no model was loaded for *dataset_choice*.
    """
    if image is None:
        return None
    if dataset_choice not in models:
        return None

    model = models[dataset_choice]

    # Normalise every PIL mode (L, P, LA, RGBA, ...) to 3-channel RGB first.
    # Grayscale input previously produced a 2-D array that broke the channel
    # permute, and RGBA input crashed when assigning the 3-value overlay
    # colour into 4-channel pixels.
    img_resized = image.convert('RGB').resize((224, 224))
    img_tensor = _prepare_input(img_resized, dataset_choice).to(DEVICE)

    with torch.no_grad():
        output = model(img_tensor)
        probs = torch.sigmoid(output)
        # First batch item, first output channel.
        # NOTE(review): assumes a single-channel binary-logit head — confirm
        # against out_channels in the configs.
        pred_mask = (probs > 0.5).float().cpu().numpy()[0, 0]

    base_img = np.asarray(img_resized, dtype=np.uint8)
    return _blend_mask(base_img, pred_mask)
|
|
|
|
|
|
|
|
# UI components, created in the same order they are handed to the Interface:
# an image upload and a dataset selector in, an overlay image out.
_input_components = [
    gr.Image(type="pil", label="Input Image"),
    gr.Radio(["isic", "segpc"], label="Dataset Model", value="isic"),
]
_output_component = gr.Image(type="numpy", label="Prediction Overlay")

iface = gr.Interface(
    fn=predict,
    inputs=_input_components,
    outputs=_output_component,
    title="Medical Image Segmentation (Awesome-U-Net)",
    description="Upload an image to segment skin lesions (ISIC) or cells (SegPC).",
    examples=[],  # no bundled example images yet
)


if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    iface.launch()
|
|
|