"""Gradio demo that overlays U-Net segmentation masks on uploaded images.

Two pretrained models are supported, selected by dataset key: ``isic``
(3-channel input) and ``segpc`` (4-channel input). Models are loaded once at
import time and cached in the module-level ``models`` dict.
"""

import os
from typing import Optional

import gradio as gr
import numpy as np
import torch
import yaml
from PIL import Image

from models.unet import UNet

# Configuration
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Size every input is resized to before inference. PIL expects (width, height).
INPUT_SIZE = (224, 224)

# Map dataset names to config/model paths
CONFIG_PATHS = {
    'isic': './configs/isic/isic2018_unet.yaml',
    'segpc': './configs/segpc/segpc2021_unet.yaml',
}

MODEL_PATHS = {
    'isic': './saved_models/isic2018_unet/best_model_state_dict.pt',
    'segpc': './saved_models/segpc2021_unet/best_model_state_dict.pt',
}


def load_config(config_path: str):
    """Parse the YAML file at *config_path* and return its content."""
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


def load_model(dataset_name: str) -> UNet:
    """Build the U-Net for *dataset_name* and load its weights if present.

    The channel counts come from ``config['model']['params']``. When the
    weights file is missing, a warning is printed and the model is returned
    with whatever parameters ``UNet`` initialised it with.

    Raises:
        KeyError: if *dataset_name* is not a key of ``CONFIG_PATHS`` /
            ``MODEL_PATHS`` or the config lacks the expected entries.
    """
    params = load_config(CONFIG_PATHS[dataset_name])['model']['params']
    model = UNet(
        in_channels=params['in_channels'],
        out_channels=params['out_channels'],
    )

    model_path = MODEL_PATHS[dataset_name]
    if os.path.exists(model_path):
        # NOTE(security): torch.load unpickles the file; only point
        # MODEL_PATHS at checkpoints from a trusted source.
        state_dict = torch.load(model_path, map_location=DEVICE)
        model.load_state_dict(state_dict)
        print(f"Loaded model for {dataset_name} from {model_path}")
    else:
        print(f"Warning: Model weights not found for {dataset_name} at {model_path}")

    model.to(DEVICE)
    model.eval()
    return model


# Load models once (cache them). A failure for one dataset is reported and
# skipped so the app can still serve the other one.
models = {}
for ds in CONFIG_PATHS:
    try:
        models[ds] = load_model(ds)
    except Exception as e:
        print(f"Error loading model {ds}: {e}")


def predict(image: Optional[Image.Image], dataset_choice: str) -> Optional[np.ndarray]:
    """Segment *image* with the selected model and return a green overlay.

    Args:
        image: PIL image to segment, or ``None`` when nothing was uploaded.
        dataset_choice: key into ``models`` (``'isic'`` or ``'segpc'``).

    Returns:
        A ``uint8`` RGB array of the resized image blended with a green mask,
        or ``None`` if *image* is ``None`` or no model is loaded for
        *dataset_choice*.
    """
    if image is None or dataset_choice not in models:
        return None

    model = models[dataset_choice]

    # Preprocess
    # Grayscale, palette, CMYK, ... inputs have no usable 3/4-channel layout;
    # normalise them to RGB. RGBA is kept so SegPC can use its 4th channel.
    if image.mode not in ('RGB', 'RGBA'):
        image = image.convert('RGB')
    img_resized = image.resize(INPUT_SIZE)
    # RGB view used for the ISIC input and for the visual overlay
    # (for an RGBA image this simply drops the alpha channel).
    rgb_resized = img_resized.convert('RGB')

    # Handle channels
    if dataset_choice == 'isic':
        # ISIC: 3 channels (RGB)
        img_np = np.array(rgb_resized).astype(np.float32) / 255.0
    else:
        # SegPC: 4 channels; a 3-channel input gets an all-zero 4th channel.
        img_np = np.array(img_resized).astype(np.float32) / 255.0
        if img_np.shape[-1] == 3:
            padding = np.zeros(img_np.shape[:2] + (1,), dtype=np.float32)
            img_np = np.concatenate([img_np, padding], axis=-1)

    # HWC -> CHW, then add the batch dimension.
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float()
    img_tensor = img_tensor.to(DEVICE)

    with torch.no_grad():
        output = model(img_tensor)
        probs = torch.sigmoid(output)
        # First batch item, first output channel, thresholded at 0.5.
        pred_mask = (probs > 0.5).float().cpu().numpy()[0, 0]

    # Post-process for visualization
    # The overlay is always built on the RGB copy so the 3-value green colour
    # can be assigned regardless of the input's channel count.
    base_img = np.array(rgb_resized)
    overlay = base_img.copy()
    mask_bool = pred_mask > 0
    overlay[mask_bool] = [0, 255, 0]  # Make Green

    # Blend
    final_img = (0.6 * base_img + 0.4 * overlay).astype(np.uint8)
    return final_img


# Interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(["isic", "segpc"], label="Dataset Model", value="isic"),
    ],
    outputs=gr.Image(type="numpy", label="Prediction Overlay"),
    title="Medical Image Segmentation (Awesome-U-Net)",
    description="Upload an image to segment skin lesions (ISIC) or cells (SegPC).",
    examples=[
        # Add example paths if available
        # ["dataset_examples/isic_sample.jpg", "isic"]
    ],
)

if __name__ == "__main__":
    iface.launch()