import os from typing import Tuple import numpy as np import gradio as gr from glob import glob from functools import partial from dataclasses import dataclass import torch import torch.nn.functional as F import torchvision.transforms as TF from transformers import SegformerForSemanticSegmentation """ Medical Image Segmentation Web Interface This module provides a Gradio-based web interface for performing semantic segmentation on medical images using a pre-trained SegFormer model. Features: - Real-time image segmentation - Confidence score visualization - Color-coded organ detection (stomach, large bowel, small bowel) - Interactive web interface - CPU/GPU automatic detection Author: Medical Image Segmentation Project License: MIT """ @dataclass class Configs: NUM_CLASSES: int = 4 # bao gồm background CLASSES: Tuple[str, ...] = ("Ruột già", "Ruột non", "Dạ dày") IMAGE_SIZE: Tuple[int, int] = (288, 288) # W, H MEAN: Tuple[float, ...] = (0.485, 0.456, 0.406) STD: Tuple[float, ...] = (0.229, 0.224, 0.225) # Sử dụng model mới huấn luyện, fallback sang model cũ nếu không tìm thấy MODEL_PATH: str = os.path.join(os.getcwd(), "models", "best_model") if os.path.exists(os.path.join(os.getcwd(), "models", "best_model")) else os.path.join(os.getcwd(), "segformer_trained_weights") def get_model(*, model_path, num_classes): """ Load pre-trained SegFormer model. """ model = SegformerForSemanticSegmentation.from_pretrained( model_path, num_labels=num_classes, ignore_mismatched_sizes=True ) return model @torch.inference_mode() def predict(input_image, model=None, preprocess_fn=None, device="cpu"): """ Perform semantic segmentation on input medical image. """ if input_image is None: return None, [] shape_H_W = input_image.size[::-1] input_tensor = preprocess_fn(input_image) input_tensor = input_tensor.unsqueeze(0).to(device) # Generate predictions outputs = model(pixel_values=input_tensor.to(device), return_dict=True) predictions = F.interpolate(outputs["logits"], size=shape_H_W, mode="bilinear", align_corners=False) # Get predicted class and confidence probs = torch.softmax(predictions, dim=1) preds_argmax = predictions.argmax(dim=1).cpu().squeeze().numpy() confidence_map = probs.max(dim=1)[0].cpu().squeeze().numpy() # Create segmentation info with confidence seg_info = [] # Classes: 1=Ruột già, 2=Ruột non, 3=Dạ dày for idx, class_name in enumerate(Configs.CLASSES, 1): mask = preds_argmax == idx if mask.sum() > 0: # Chỉ tính confidence nếu có pixel được dự đoán conf_score = confidence_map[mask].mean() label = f"{class_name} ({conf_score:.1%})" seg_info.append((mask, label)) else: # Không hiển thị label nếu không phát hiện được pass return (input_image, seg_info) if __name__ == "__main__": """ Main application entry point. """ # Mapping màu sắc cho hiển thị class2hexcolor = { "Dạ dày": "#007fff", # Xanh dương "Ruột non": "#009A17", # Xanh lá "Ruột già": "#FF0000" # Đỏ } DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") # Load model locally try: model_dir = Configs.MODEL_PATH if not os.path.exists(model_dir): print(f"Model path not found: {model_dir}") model_dir = "./segformer_trained_weights" except Exception as e: print(f"Error loading model: {e}") model_dir = "./segformer_trained_weights" # Load model print(f"Loading model from: {model_dir}") model = get_model(model_path=model_dir, num_classes=Configs.NUM_CLASSES) model.to(DEVICE) model.eval() # Warmup try: _ = model(torch.randn(1, 3, *Configs.IMAGE_SIZE[::-1], device=DEVICE)) except Exception as e: print(f"Warmup warning: {e}") preprocess = TF.Compose( [ TF.Resize(size=Configs.IMAGE_SIZE[::-1]), TF.ToTensor(), TF.Normalize(Configs.MEAN, Configs.STD, inplace=True), ] ) with gr.Blocks(title="Phân Đoạn Ảnh Y Tế") as demo: gr.Markdown("""