"""Gradio app: tooth segmentation with a fine-tuned Mask R-CNN.

Downloads trained weights from a (private) Hugging Face Hub repo, runs
instance segmentation on an uploaded dental image, and overlays translucent
per-class masks.
"""

import os
import shutil

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from PIL import Image

# -------------- Config --------------
REPO_ID = "IFMedTech/dental_segmentation"
MODEL_FILENAME = "mask_rcnn_Dental.pth"
# Index 0 is the implicit background class required by torchvision detection heads.
CLASS_NAMES = ["background", "Class_A", "Class_B", "Class_C", "Class_D"]
SCORE_THRESH_DEFAULT = 0.5
ALPHA_DEFAULT = 0.45
# Overlay colors in BGR order (OpenCV convention); index matches CLASS_NAMES.
COLOR_MAP = [
    (0, 0, 0),      # background (never drawn)
    (0, 255, 0),    # Class_A: green
    (255, 0, 0),    # Class_B: blue
    (0, 165, 255),  # Class_C: orange
    (255, 0, 255),  # Class_D: magenta
]


# -------------- Cache Management --------------
def clean_unnecessary_cache():
    """Delete cached HF Hub / torch hub / kernel files to free disk space.

    Best-effort: failures are logged and ignored so a locked or missing
    cache directory never prevents startup.
    """
    cache_paths = [
        os.path.expanduser("~/.cache/huggingface/hub"),  # Model cache
        os.path.expanduser("~/.cache/torch/hub"),        # Torch hub cache
        os.path.expanduser("~/.cache/torch/kernels"),    # Compiled kernels
    ]
    for cache_path in cache_paths:
        if not os.path.exists(cache_path):
            continue
        try:
            # Report how much space we are about to reclaim.
            size_gb = sum(
                os.path.getsize(os.path.join(dirpath, filename))
                for dirpath, _, filenames in os.walk(cache_path)
                for filename in filenames
            ) / (1024**3)
            print(f"Clearing {cache_path} ({size_gb:.2f} GB)...")
            shutil.rmtree(cache_path)
            print(f"Successfully cleared {cache_path}")
        except Exception as e:
            # Deliberately broad: cache cleanup must never abort startup.
            print(f"Warning: Could not clear {cache_path}: {str(e)}")


# Clean cache only at startup, before downloading the model.
print("Cleaning cache to free up space...")
clean_unnecessary_cache()


# -------------- Download Model from Private Repo --------------
def download_model_from_hub():
    """Download the model weights file from the Hugging Face Hub.

    Reads the access token from the HF_TOKEN or HUGGINGFACE_TOKEN
    environment variable (either works; HF_TOKEN is checked first).

    Returns:
        Local filesystem path of the downloaded weights file.

    Raises:
        ValueError: if no access token is configured.
        RuntimeError: if the download itself fails.
    """
    # Accept both common variable names; the original code read only
    # HUGGINGFACE_TOKEN while the error message asked for HF_TOKEN.
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
    if not token:
        raise ValueError(
            "HF_TOKEN (or HUGGINGFACE_TOKEN) environment variable is required "
            "for private repo access. Please set it in your Space settings "
            "under 'Repository secrets'."
        )
    try:
        model_path = hf_hub_download(
            repo_id=REPO_ID, filename=MODEL_FILENAME, token=token
        )
        return model_path
    except Exception as e:
        raise RuntimeError(f"Failed to download model from {REPO_ID}: {str(e)}")


# -------------- Utils --------------
def apply_mask_bgr(image_bgr, mask_bool, color_bgr, alpha=0.45):
    """Alpha-blend a solid color onto *image_bgr* where *mask_bool* is True.

    Args:
        image_bgr: uint8 HxWx3 image in BGR order.
        mask_bool: HxW boolean mask selecting the pixels to tint.
        color_bgr: (B, G, R) tuple for the overlay color.
        alpha: overlay opacity in [0, 1].

    Returns:
        New uint8 image; the input is not modified.
    """
    overlay = image_bgr.copy().astype(np.float32)
    color_vec = np.array(color_bgr, dtype=np.float32)
    # Blend in float to avoid uint8 overflow, then cast back.
    overlay[mask_bool] = (1 - alpha) * overlay[mask_bool] + alpha * color_vec
    return overlay.astype(np.uint8)


def pil_to_tensor(img_pil):
    """Convert a PIL image to a CHW float tensor scaled to [0, 1]."""
    return T.ToTensor()(img_pil)


# -------------- Model --------------
def build_maskrcnn(num_classes):
    """Build a Mask R-CNN (ResNet-50 FPN) with custom prediction heads.

    Args:
        num_classes: number of output classes, including background.

    Returns:
        The torchvision model with box and mask predictors replaced to
        emit *num_classes* classes.
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # Replace the box predictor head to match our class count.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = (
        torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
            in_features, num_classes
        )
    )

    # Replace the mask predictor head likewise.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = (
        torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
            in_features_mask, hidden_layer, num_classes
        )
    )
    return model


def load_model(weights_path, num_classes):
    """Load trained weights into a freshly built model and set eval mode.

    Args:
        weights_path: path to a state_dict checkpoint file.
        num_classes: number of classes the checkpoint was trained with.

    Returns:
        (model, device) tuple; device is CUDA when available, else CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_maskrcnn(num_classes)
    state = torch.load(weights_path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model, device


def infer_and_overlay(model, device, img_pil, score_thresh=0.5, alpha=0.45):
    """Run inference and return the input image with masks overlaid.

    Args:
        model: Mask R-CNN in eval mode.
        device: torch device the model lives on.
        img_pil: input PIL image (RGB).
        score_thresh: minimum detection confidence to draw.
        alpha: overlay opacity.

    Returns:
        PIL image (RGB) with translucent per-class masks applied.
    """
    with torch.no_grad():
        img_t = pil_to_tensor(img_pil).unsqueeze(0).to(device)
        out = model(img_t)[0]

    # Work in BGR because the blending/colors follow OpenCV conventions.
    img_bgr = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

    # Extract predictions; .get with empty defaults guards against a
    # model output that omits a key (e.g. zero detections).
    scores = out.get("scores", torch.tensor([])).detach().cpu().numpy()
    labels = out.get("labels", torch.tensor([])).detach().cpu().numpy().astype(int)
    masks = out.get(
        "masks", torch.zeros((0, 1, img_bgr.shape[0], img_bgr.shape[1]))
    ).detach().cpu().numpy()

    # Filter detections by confidence.
    keep_idx = [i for i, s in enumerate(scores) if s >= score_thresh]

    for i in keep_idx:
        # Masks come out as soft (0..1) probability maps; binarize at 0.5.
        mask_bool = masks[i, 0] > 0.5
        lab = labels[i]
        # Fall back to yellow for any label beyond the configured palette.
        color = COLOR_MAP[lab] if lab < len(COLOR_MAP) else (0, 255, 255)
        img_bgr = apply_mask_bgr(img_bgr, mask_bool, color, alpha=alpha)

    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(img_rgb)


def run_inference(image):
    """Gradio callback: segment *image*, returning None on no input or error."""
    if image is None:
        return None
    try:
        return infer_and_overlay(
            model,
            device,
            image,
            score_thresh=SCORE_THRESH_DEFAULT,
            alpha=ALPHA_DEFAULT,
        )
    except Exception as e:
        # Swallow errors so the UI shows an empty output instead of crashing.
        print(f"Inference error: {str(e)}")
        return None


# -------------- Initialize Model --------------
print("Downloading model from Hugging Face Hub...")
model_path = download_model_from_hub()
print(f"Model downloaded to: {model_path}")

print("Loading model...")
model, device = load_model(model_path, len(CLASS_NAMES))
print(f"Model loaded successfully on device: {device}")

# -------------- Gradio UI --------------
with gr.Blocks(title="Teeth Segmentation — Mask R-CNN") as demo:
    gr.Markdown("## 🦷 Teeth Segmentation — Mask R-CNN (Translucent Masks)")
    gr.Markdown(
        "Upload a dental image to segment different tooth classes using Mask R-CNN. "
        "The model will overlay colored masks on detected teeth."
    )

    with gr.Column():
        input_image = gr.Image(label="Input Image", type="pil")
        analyze_btn = gr.Button("Analyze Image", variant="primary")
        output_image = gr.Image(label="Segmented Output", type="pil")

    gr.Examples(
        examples=[
            ["example_image1.jfif"],
            ["example_image2.jfif"],
            ["example_image3.jfif"],
        ],
        inputs=input_image,
    )

    analyze_btn.click(
        fn=run_inference,
        inputs=[input_image],
        outputs=output_image,
        show_progress=True,
    )

if __name__ == "__main__":
    demo.launch()