Spaces:
Sleeping
Sleeping
| import os | |
| from typing import Tuple | |
| import numpy as np | |
| import gradio as gr | |
| from glob import glob | |
| from functools import partial | |
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms as TF | |
| from transformers import SegformerForSemanticSegmentation | |
| """ | |
| Medical Image Segmentation Web Interface | |
| This module provides a Gradio-based web interface for performing semantic | |
| segmentation on medical images using a pre-trained SegFormer model. | |
| Features: | |
| - Real-time image segmentation | |
| - Confidence score visualization | |
| - Color-coded organ detection (stomach, large bowel, small bowel) | |
| - Interactive web interface | |
| - CPU/GPU automatic detection | |
| Author: Medical Image Segmentation Project | |
| License: MIT | |
| """ | |
| class Configs: | |
| NUM_CLASSES: int = 4 # bao gồm background | |
| CLASSES: Tuple[str, ...] = ("Ruột già", "Ruột non", "Dạ dày") | |
| IMAGE_SIZE: Tuple[int, int] = (288, 288) # W, H | |
| MEAN: Tuple[float, ...] = (0.485, 0.456, 0.406) | |
| STD: Tuple[float, ...] = (0.229, 0.224, 0.225) | |
| # Sử dụng model mới huấn luyện, fallback sang model cũ nếu không tìm thấy | |
| MODEL_PATH: str = os.path.join(os.getcwd(), "models", "best_model") if os.path.exists(os.path.join(os.getcwd(), "models", "best_model")) else os.path.join(os.getcwd(), "segformer_trained_weights") | |
| def get_model(*, model_path, num_classes): | |
| """ | |
| Load pre-trained SegFormer model. | |
| """ | |
| model = SegformerForSemanticSegmentation.from_pretrained( | |
| model_path, | |
| num_labels=num_classes, | |
| ignore_mismatched_sizes=True | |
| ) | |
| return model | |
| def predict(input_image, model=None, preprocess_fn=None, device="cpu"): | |
| """ | |
| Perform semantic segmentation on input medical image. | |
| """ | |
| if input_image is None: | |
| return None, [] | |
| shape_H_W = input_image.size[::-1] | |
| input_tensor = preprocess_fn(input_image) | |
| input_tensor = input_tensor.unsqueeze(0).to(device) | |
| # Generate predictions | |
| outputs = model(pixel_values=input_tensor.to(device), return_dict=True) | |
| predictions = F.interpolate(outputs["logits"], size=shape_H_W, mode="bilinear", align_corners=False) | |
| # Get predicted class and confidence | |
| probs = torch.softmax(predictions, dim=1) | |
| preds_argmax = predictions.argmax(dim=1).cpu().squeeze().numpy() | |
| confidence_map = probs.max(dim=1)[0].cpu().squeeze().numpy() | |
| # Create segmentation info with confidence | |
| seg_info = [] | |
| # Classes: 1=Ruột già, 2=Ruột non, 3=Dạ dày | |
| for idx, class_name in enumerate(Configs.CLASSES, 1): | |
| mask = preds_argmax == idx | |
| if mask.sum() > 0: | |
| # Chỉ tính confidence nếu có pixel được dự đoán | |
| conf_score = confidence_map[mask].mean() | |
| label = f"{class_name} ({conf_score:.1%})" | |
| seg_info.append((mask, label)) | |
| else: | |
| # Không hiển thị label nếu không phát hiện được | |
| pass | |
| return (input_image, seg_info) | |
| if __name__ == "__main__": | |
| """ | |
| Main application entry point. | |
| """ | |
| # Mapping màu sắc cho hiển thị | |
| class2hexcolor = { | |
| "Dạ dày": "#007fff", # Xanh dương | |
| "Ruột non": "#009A17", # Xanh lá | |
| "Ruột già": "#FF0000" # Đỏ | |
| } | |
| DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") | |
| # Load model locally | |
| try: | |
| model_dir = Configs.MODEL_PATH | |
| if not os.path.exists(model_dir): | |
| print(f"Model path not found: {model_dir}") | |
| model_dir = "./segformer_trained_weights" | |
| except Exception as e: | |
| print(f"Error loading model: {e}") | |
| model_dir = "./segformer_trained_weights" | |
| # Load model | |
| print(f"Loading model from: {model_dir}") | |
| model = get_model(model_path=model_dir, num_classes=Configs.NUM_CLASSES) | |
| model.to(DEVICE) | |
| model.eval() | |
| # Warmup | |
| try: | |
| _ = model(torch.randn(1, 3, *Configs.IMAGE_SIZE[::-1], device=DEVICE)) | |
| except Exception as e: | |
| print(f"Warmup warning: {e}") | |
| preprocess = TF.Compose( | |
| [ | |
| TF.Resize(size=Configs.IMAGE_SIZE[::-1]), | |
| TF.ToTensor(), | |
| TF.Normalize(Configs.MEAN, Configs.STD, inplace=True), | |
| ] | |
| ) | |
| with gr.Blocks(title="Phân Đoạn Ảnh Y Tế") as demo: | |
| gr.Markdown(""" | |
| <h1><center>🏥 Phân Đoạn Ảnh Y Tế - Tập Dữ Liệu UW-Madison GI Tract</center></h1> | |
| <p><center>Hệ thống tự động phát hiện và phân đoạn các cơ quan tiêu hóa: Dạ dày, Ruột non, Ruột già.</center></p> | |
| """) | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown("### 📥 Ảnh Đầu Vào") | |
| img_input = gr.Image(type="pil", height=360, width=360, label="Tải ảnh lên") | |
| with gr.Column(): | |
| gr.Markdown("### 📊 Kết Quả Dự Đoán") | |
| # AnnotatedImage hiển thị ảnh gốc + các lớp mask | |
| img_output = gr.AnnotatedImage( | |
| label="Kết quả phân đoạn", | |
| height=360, | |
| width=360, | |
| color_map=class2hexcolor | |
| ) | |
| section_btn = gr.Button("🎯 Chạy Phân Đoạn", size="lg", variant="primary") | |
| section_btn.click(partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), img_input, img_output) | |
| gr.Markdown("---") | |
| gr.Markdown("### 📸 Ảnh Mẫu (Click để thử)") | |
| images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png") | |
| if len(images_dir) > 0: | |
| examples = [i for i in np.random.choice(images_dir, size=min(10, len(images_dir)), replace=False)] | |
| gr.Examples( | |
| examples=examples, | |
| inputs=img_input, | |
| outputs=img_output, | |
| fn=partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), | |
| cache_examples=False, | |
| label="Thư viện ảnh mẫu" | |
| ) | |
| gr.Markdown(""" | |
| --- | |
| ### 🎨 Chú Thích Màu Sắc | |
| - 🔵 **Xanh Dương**: Dạ dày | |
| - 🟢 **Xanh Lá**: Ruột non | |
| - 🔴 **Đỏ**: Ruột già | |
| ### ℹ️ Thông Tin Hệ Thống | |
| - **Mô hình**: SegFormer (mit-b0) | |
| - **Kích thước đầu vào**: 288 × 288 pixels | |
| - **Framework**: PyTorch + Gradio | |
| """) | |
| demo.launch() | |