# Hugging Face Space: dental segmentation demo (scraped Space status header removed)
# Standard library
import os
import shutil

# Third-party
import cv2
import gradio as gr
import numpy as np
import torch
import torchvision
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from PIL import Image
# -------------- Config --------------
# Private Hugging Face Hub repo and filename holding the trained weights.
REPO_ID = "IFMedTech/dental_segmentation"
MODEL_FILENAME = "mask_rcnn_Dental.pth"
# Index 0 is background; indices 1-4 are the four dental classes.
# NOTE(review): "CLass_C" capitalization looks like a typo in the label text —
# confirm against the training label set before changing it.
CLASS_NAMES = ["background", "Class_A", "Class_B", "CLass_C", "Class_D"]
SCORE_THRESH_DEFAULT = 0.5  # minimum detection score for a mask to be drawn
ALPHA_DEFAULT = 0.45        # overlay opacity used when blending masks
# Per-class overlay colors in BGR order (the image is converted to BGR with
# cv2 before masks are applied); index matches CLASS_NAMES.
COLOR_MAP = [
    (0, 0, 0),      # background (score filtering means it is never drawn)
    (0, 255, 0),    # Class_A — green
    (255, 0, 0),    # Class_B — blue (BGR)
    (0, 165, 255),  # CLass_C — orange
    (255, 0, 255),  # Class_D — magenta
]
# -------------- Cache Management --------------
def clean_unnecessary_cache():
    """Remove cached Hugging Face / Torch files to free disk space.

    Deletes the HF Hub model cache, the Torch hub cache and the compiled
    kernel cache. Each path is handled independently: a failure on one
    path is logged and does not stop cleanup of the others.
    """
    cache_paths = [
        os.path.expanduser("~/.cache/huggingface/hub"),  # model cache
        os.path.expanduser("~/.cache/torch/hub"),        # torch hub cache
        os.path.expanduser("~/.cache/torch/kernels"),    # compiled kernels
    ]
    for cache_path in cache_paths:
        if not os.path.exists(cache_path):
            continue
        # Size reporting is best-effort only. In the original code a single
        # os.path.getsize() failure (dangling symlink, file deleted while
        # walking) raised out of the try block and skipped the rmtree too,
        # so the cache was never actually cleared for that path.
        size_bytes = 0
        for dirpath, _dirnames, filenames in os.walk(cache_path):
            for filename in filenames:
                try:
                    size_bytes += os.path.getsize(os.path.join(dirpath, filename))
                except OSError:
                    pass  # broken symlink / vanished file — ignore for sizing
        print(f"Clearing {cache_path} ({size_bytes / (1024 ** 3):.2f} GB)...")
        try:
            shutil.rmtree(cache_path)
            print(f"Successfully cleared {cache_path}")
        except Exception as e:
            print(f"Warning: Could not clear {cache_path}: {str(e)}")
# Clean cache only at startup before downloading model
# (module-level side effect: runs once at import time so disk space is
# freed before the weights download below).
print("Cleaning cache to free up space...")
clean_unnecessary_cache()
# -------------- Download Model from Private Repo --------------
def download_model_from_hub():
    """Download the model weights from the (private) Hugging Face repo.

    Reads the access token from the HUGGINGFACE_TOKEN environment variable,
    falling back to the conventional HF_TOKEN name.

    Returns:
        str: local filesystem path of the downloaded weights file.

    Raises:
        ValueError: if no token is configured.
        RuntimeError: if the download itself fails.
    """
    # Bug fix: the original read only HUGGINGFACE_TOKEN while the error
    # message told users to set HF_TOKEN — accept both so either works.
    token = os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HF_TOKEN")
    if not token:
        raise ValueError(
            "HUGGINGFACE_TOKEN (or HF_TOKEN) environment variable is required "
            "for private repo access. Please set it in your Space settings "
            "under 'Repository secrets'."
        )
    try:
        model_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=MODEL_FILENAME,
            token=token,
        )
        return model_path
    except Exception as e:
        # Chain the cause so the original hub error stays in the traceback.
        raise RuntimeError(f"Failed to download model from {REPO_ID}: {str(e)}") from e
# -------------- Utils --------------
def apply_mask_bgr(image_bgr, mask_bool, color_bgr, alpha=0.45):
    """Blend a solid color into the masked pixels of a BGR image.

    Pixels where ``mask_bool`` is True become
    ``(1 - alpha) * pixel + alpha * color_bgr``; all other pixels are left
    untouched. Returns a new uint8 array; the input is not modified.
    """
    blended = image_bgr.astype(np.float32).copy()
    tint = np.asarray(color_bgr, dtype=np.float32)
    selected = blended[mask_bool]
    blended[mask_bool] = selected * (1.0 - alpha) + tint * alpha
    return blended.astype(np.uint8)
def pil_to_tensor(img_pil):
    """Convert a PIL image to a float CHW tensor via torchvision's ToTensor."""
    to_tensor = T.ToTensor()
    return to_tensor(img_pil)
# -------------- Model --------------
def build_maskrcnn(num_classes):
    """Build a pretrained Mask R-CNN and re-head it for ``num_classes``.

    Loads torchvision's ResNet50-FPN Mask R-CNN with default weights, then
    swaps both the box classifier and the mask predictor so their output
    dimension matches ``num_classes`` (background included).
    """
    detection = torchvision.models.detection
    model = detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # Swap the box classification head for one sized to our classes.
    box_in = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(
        box_in, num_classes
    )

    # Swap the mask head, keeping a 256-channel hidden layer.
    mask_in = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = detection.mask_rcnn.MaskRCNNPredictor(
        mask_in, 256, num_classes
    )

    return model
def load_model(weights_path, num_classes):
    """Instantiate the network, load weights, and move it to the best device.

    Returns:
        tuple: ``(model, device)`` with the model in eval mode.

    NOTE(review): torch.load unpickles arbitrary objects — only load weights
    files from a trusted source (here: our own private Hub repo).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = build_maskrcnn(num_classes)
    checkpoint = torch.load(weights_path, map_location=device)
    net.load_state_dict(checkpoint)
    net.to(device)
    net.eval()
    return net, device
def infer_and_overlay(model, device, img_pil, score_thresh=0.5, alpha=0.45):
    """Run the detector on a PIL image and return it with masks overlaid.

    Detections scoring below ``score_thresh`` are discarded; the remaining
    masks are blended onto the image with per-class colors at opacity
    ``alpha``. Returns a new PIL image (RGB).
    """
    with torch.no_grad():
        batch = pil_to_tensor(img_pil).unsqueeze(0).to(device)
        prediction = model(batch)[0]

    # Work in BGR because the overlay helper and COLOR_MAP use OpenCV order.
    canvas = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
    height, width = canvas.shape[:2]

    scores = prediction.get("scores", torch.tensor([])).detach().cpu().numpy()
    labels = prediction.get("labels", torch.tensor([])).detach().cpu().numpy().astype(int)
    masks = prediction.get("masks", torch.zeros((0, 1, height, width))).detach().cpu().numpy()

    for idx, score in enumerate(scores):
        if score < score_thresh:
            continue
        label = labels[idx]
        # Fall back to yellow (BGR) for labels beyond the color table.
        color = COLOR_MAP[label] if label < len(COLOR_MAP) else (0, 255, 255)
        # Binarize the soft mask at 0.5 before blending.
        canvas = apply_mask_bgr(canvas, masks[idx, 0] > 0.5, color, alpha=alpha)

    return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
def run_inference(image):
    """Gradio callback: segment the uploaded image, or return None on failure.

    Returns None both for a missing input and for any inference error (the
    error is printed to the Space logs rather than surfaced to the user).
    """
    if image is None:
        return None
    try:
        result = infer_and_overlay(
            model,
            device,
            image,
            score_thresh=SCORE_THRESH_DEFAULT,
            alpha=ALPHA_DEFAULT,
        )
    except Exception as e:
        print(f"Inference error: {str(e)}")
        return None
    return result
# -------------- Initialize Model --------------
# Module-level side effects: download the weights and build the model once
# at import time so the Gradio callback below can reuse them on every call.
print("Downloading model from Hugging Face Hub...")
model_path = download_model_from_hub()
print(f"Model downloaded to: {model_path}")
print("Loading model...")
model, device = load_model(model_path, len(CLASS_NAMES))
print(f"Model loaded successfully on device: {device}")
# -------------- Gradio UI --------------
with gr.Blocks(title="Teeth Segmentation — Mask R-CNN") as demo:
    gr.Markdown("## 🦷 Teeth Segmentation — Mask R-CNN (Translucent Masks)")
    gr.Markdown(
        "Upload a dental image to segment different tooth classes using Mask R-CNN. "
        "The model will overlay colored masks on detected teeth."
    )
    with gr.Column():
        # type="pil" matches run_inference, which passes the image to the model.
        input_image = gr.Image(label="Input Image", type="pil")
        analyze_btn = gr.Button("Analyze Image", variant="primary")
        output_image = gr.Image(label="Segmented Output", type="pil")
    # Clickable example inputs.
    # NOTE(review): assumes the .jfif files exist next to this script — confirm.
    gr.Examples(
        examples=[
            ["example_image1.jfif"],
            ["example_image2.jfif"],
            ["example_image3.jfif"],
        ],
        inputs=input_image,
    )
    # Wire the button to the inference callback.
    analyze_btn.click(
        fn=run_inference,
        inputs=[input_image],
        outputs=output_image,
        show_progress=True
    )
# Script entry point: launch the Gradio app when run directly.
if __name__ == "__main__":
    demo.launch()