Kim Mạnh Hưng
Fix config loading bug
6400203
import gradio as gr
import torch
import numpy as np
from PIL import Image
import yaml
import os
from models.unet import UNet
# Configuration: use the CUDA device when torch reports one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Per-dataset resources, keyed by dataset name ('isic', 'segpc').
# These paths are relative, so they resolve against the current working
# directory of the process that launches the app.
CONFIG_PATHS = dict(
    isic='./configs/isic/isic2018_unet.yaml',
    segpc='./configs/segpc/segpc2021_unet.yaml',
)
MODEL_PATHS = dict(
    isic='./saved_models/isic2018_unet/best_model_state_dict.pt',
    segpc='./saved_models/segpc2021_unet/best_model_state_dict.pt',
)
def load_config(config_path):
    """Read a YAML config file and return its parsed contents.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The parsed YAML document. Callers in this file index it as a nested
        mapping (``config['model']['params'][...]``).

    Raises:
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
        ValueError: if the file is empty. ``yaml.safe_load`` returns ``None``
            in that case, which would otherwise surface later as a confusing
            ``TypeError`` when the result is subscripted.
    """
    # Explicit encoding: the default text encoding is platform-dependent.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    if config is None:
        raise ValueError(f"Config file is empty: {config_path}")
    return config
def load_model(dataset_name):
    """Build the UNet for ``dataset_name`` and load its trained weights if present.

    The architecture's channel counts come from the dataset's YAML config
    (``model.params.in_channels`` / ``model.params.out_channels``).

    Args:
        dataset_name: Key into ``CONFIG_PATHS`` and ``MODEL_PATHS``
            (e.g. ``'isic'`` or ``'segpc'``).

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode. If the
        weights file does not exist, a warning is printed and the freshly
        constructed (untrained) model is returned, so its predictions are
        not meaningful.

    Raises:
        KeyError: if ``dataset_name`` is unknown or the config lacks the
            expected ``model.params`` keys.
        Exception: errors from ``load_config``, ``torch.load`` or
            ``load_state_dict`` propagate unchanged.
    """
    config = load_config(CONFIG_PATHS[dataset_name])
    params = config['model']['params']
    model = UNet(
        in_channels=params['in_channels'],
        out_channels=params['out_channels'],
    )
    model_path = MODEL_PATHS[dataset_name]
    if os.path.exists(model_path):
        # A state dict only contains tensors in plain containers, so restrict
        # unpickling to that. This matches the default of recent torch
        # releases and silences the weights_only FutureWarning on older ones.
        state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        print(f"Loaded model for {dataset_name} from {model_path}")
    else:
        print(f"Warning: Model weights not found for {dataset_name} at {model_path}")
    model.to(DEVICE)
    model.eval()
    return model
# Load every configured model once at import time and cache it by dataset name.
# Iterating CONFIG_PATHS keeps this in sync with the configured datasets
# instead of repeating their names here.
models = {}
for ds in CONFIG_PATHS:
    try:
        models[ds] = load_model(ds)
    except Exception as e:
        # Top-level boundary: report and skip a dataset that fails to load so
        # the app can still start and serve the others.
        print(f"Error loading model {ds}: {e}")
def _to_model_input(img, dataset_choice):
    """Convert a resized PIL image into the float (1, C, H, W) tensor in [0, 1] the chosen model takes."""
    if dataset_choice == 'isic' or img.mode not in ('RGB', 'RGBA'):
        # ISIC: the model takes 3 channels, so always convert to RGB (for RGBA
        # this just drops alpha). SegPC: only unusual modes (L, LA, P, ...) are
        # converted. Either way the array below is (H, W, 3) or (H, W, 4),
        # never a 2-D array that permute(2, 0, 1) would reject.
        img = img.convert('RGB')
    arr = np.asarray(img, dtype=np.float32) / 255.0
    if dataset_choice != 'isic' and arr.shape[-1] == 3:
        # SegPC: the model takes 4 channels; pad RGB input with an all-zero
        # 4th channel.
        padding = np.zeros(arr.shape[:2] + (1,), dtype=np.float32)
        arr = np.concatenate([arr, padding], axis=-1)
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).float()


def _logits_to_mask(output):
    """Turn raw model output of shape (1, C, H, W) into a boolean (H, W) foreground mask."""
    if output.shape[1] == 1:
        # Single-channel head: sigmoid probability thresholded at 0.5.
        mask = torch.sigmoid(output[0, 0]) > 0.5
    else:
        # Multi-channel head: per-pixel argmax, where any class other than 0
        # counts as foreground.
        # NOTE(review): assumes channel 0 is background; confirm against the
        # training code.
        mask = output[0].argmax(dim=0) > 0
    return mask.cpu().numpy()


def _overlay_mask(base_img, mask):
    """Blend pure green into ``base_img`` (H, W, 3 uint8) wherever ``mask`` is True."""
    overlay = base_img.copy()
    overlay[mask] = [0, 255, 0]
    return (0.6 * base_img + 0.4 * overlay).astype(np.uint8)


def predict(image, dataset_choice):
    """Segment ``image`` with the model for ``dataset_choice`` and return an overlay.

    Args:
        image: PIL image from the Gradio input (``type="pil"``), or ``None``
            when nothing has been uploaded.
        dataset_choice: Key into the module-level ``models`` cache
            (``'isic'`` or ``'segpc'``).

    Returns:
        A (224, 224, 3) ``uint8`` RGB array: the resized input with predicted
        foreground pixels tinted green. Returns ``None`` when ``image`` is
        ``None``.

    Raises:
        gr.Error: if no model is loaded for ``dataset_choice``, so the user
            sees why no prediction was produced.
    """
    if image is None:
        return None
    if dataset_choice not in models:
        # Loading failed at startup (the reason was printed then). Tell the
        # user instead of silently returning an empty output.
        raise gr.Error(f"Model '{dataset_choice}' is not available; see the server log.")
    model = models[dataset_choice]

    # Run the models on a fixed 224x224 input.
    img_resized = image.resize((224, 224))
    img_tensor = _to_model_input(img_resized, dataset_choice).to(DEVICE)
    with torch.no_grad():
        output = model(img_tensor)
    mask = _logits_to_mask(output)

    # Draw on an RGB copy. The 3-value green fill would not broadcast into an
    # RGBA (4-channel) or grayscale (2-D) array.
    base_img = np.array(img_resized.convert('RGB'))
    return _overlay_mask(base_img, mask)
# Gradio UI: an image and a dataset selector go in; the prediction overlay
# comes out.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        # Offer exactly the datasets that have a config entry, in their
        # declared order.
        gr.Radio(list(CONFIG_PATHS), label="Dataset Model", value="isic"),
    ],
    outputs=gr.Image(type="numpy", label="Prediction Overlay"),
    title="Medical Image Segmentation (Awesome-U-Net)",
    description="Upload an image to segment skin lesions (ISIC) or cells (SegPC).",
    # No bundled examples yet. To add some, pass rows of
    # [image_path, dataset_name], e.g.
    # [["dataset_examples/isic_sample.jpg", "isic"]].
    examples=None,
)

if __name__ == "__main__":
    iface.launch()