# IFMedTechdemo's picture
# Update app.py
# 65a0fd1 verified
import os
import numpy as np
from PIL import Image
import torch
import torchvision
import torchvision.transforms as T
import cv2
import gradio as gr
from huggingface_hub import hf_hub_download
import shutil
# -------------- Config --------------
REPO_ID = "IFMedTech/dental_segmentation"  # private HF repo holding the weights
MODEL_FILENAME = "mask_rcnn_Dental.pth"    # checkpoint file inside that repo
# Index 0 is the implicit background class required by Mask R-CNN; the rest
# are the tooth classes the model was trained on.
# (Fixed typo: "CLass_C" -> "Class_C".)
CLASS_NAMES = ["background", "Class_A", "Class_B", "Class_C", "Class_D"]
SCORE_THRESH_DEFAULT = 0.5  # minimum detection score to draw a mask
ALPHA_DEFAULT = 0.45        # overlay opacity (0 = invisible, 1 = opaque)
# Per-class overlay colors in BGR order (overlays are drawn with OpenCV);
# one entry per CLASS_NAMES index.
COLOR_MAP = [
    (0, 0, 0),      # background (never drawn)
    (0, 255, 0),    # Class_A — green
    (255, 0, 0),    # Class_B — blue (BGR!)
    (0, 165, 255),  # Class_C — orange
    (255, 0, 255),  # Class_D — magenta
]
# -------------- Cache Management --------------
def clean_unnecessary_cache(cache_paths=None):
    """Delete cached model/torch files to free disk space.

    Args:
        cache_paths: Optional list of directories to remove. Defaults to the
            Hugging Face hub cache and the torch hub/kernel caches.

    Notes:
        Best-effort: failures are logged, never raised. Previously a broken
        symlink (or a file vanishing mid-walk) made the size computation
        raise inside the same try-block as the deletion, so the cache was
        reported as "could not clear" and never removed. Size accounting is
        now per-file and cannot abort the cleanup.
    """
    if cache_paths is None:
        cache_paths = [
            os.path.expanduser("~/.cache/huggingface/hub"),  # model cache
            os.path.expanduser("~/.cache/torch/hub"),        # torch hub cache
            os.path.expanduser("~/.cache/torch/kernels"),    # compiled kernels
        ]
    for cache_path in cache_paths:
        if not os.path.exists(cache_path):
            continue
        # Report size for logging only; tolerate files disappearing mid-walk.
        size_bytes = 0
        for dirpath, _dirnames, filenames in os.walk(cache_path):
            for filename in filenames:
                try:
                    size_bytes += os.path.getsize(os.path.join(dirpath, filename))
                except OSError:
                    pass  # broken symlink or file removed concurrently
        size_gb = size_bytes / (1024 ** 3)
        print(f"Clearing {cache_path} ({size_gb:.2f} GB)...")
        try:
            shutil.rmtree(cache_path)
            print(f"Successfully cleared {cache_path}")
        except Exception as e:
            print(f"Warning: Could not clear {cache_path}: {str(e)}")
# Clean cache only at startup before downloading model
# NOTE: runs at import time, freeing disk before hf_hub_download pulls the
# (large) checkpoint into a fresh cache directory.
print("Cleaning cache to free up space...")
clean_unnecessary_cache()
# -------------- Download Model from Private Repo --------------
def download_model_from_hub():
    """Download the model checkpoint from the private Hugging Face repository.

    Reads the access token from the HUGGINGFACE_TOKEN (or HF_TOKEN)
    environment variable.

    Returns:
        Local filesystem path of the downloaded weights file.

    Raises:
        ValueError: if no access token is configured.
        RuntimeError: if the download itself fails.
    """
    # Bug fix: the code read HUGGINGFACE_TOKEN while the error message told
    # users to set HF_TOKEN. Accept both names and mention both in the error.
    token = os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HF_TOKEN")
    if not token:
        raise ValueError(
            "HUGGINGFACE_TOKEN (or HF_TOKEN) environment variable is required "
            "for private repo access. "
            "Please set it in your Space settings under 'Repository secrets'."
        )
    try:
        model_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=MODEL_FILENAME,
            token=token
        )
        return model_path
    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise RuntimeError(f"Failed to download model from {REPO_ID}: {str(e)}") from e
# -------------- Utils --------------
def apply_mask_bgr(image_bgr, mask_bool, color_bgr, alpha=0.45):
    """Blend a solid color into the image pixels selected by a boolean mask.

    Args:
        image_bgr: HxWx3 uint8 image, BGR channel order.
        mask_bool: HxW boolean array; True pixels receive the overlay.
        color_bgr: 3-tuple overlay color in BGR order.
        alpha: Overlay opacity in [0, 1]; 0 leaves the image unchanged.

    Returns:
        A new HxWx3 uint8 image with the translucent overlay applied.
    """
    blended = image_bgr.astype(np.float32, copy=True)
    tint = np.asarray(color_bgr, dtype=np.float32)
    # Weighted average on the masked pixels only; the rest are untouched.
    blended[mask_bool] = blended[mask_bool] * (1.0 - alpha) + tint * alpha
    return blended.astype(np.uint8)
def pil_to_tensor(img_pil):
    """Convert a PIL image to a CxHxW float tensor scaled to [0, 1]."""
    to_tensor = T.ToTensor()
    return to_tensor(img_pil)
# -------------- Model --------------
def build_maskrcnn(num_classes):
    """Create a COCO-pretrained Mask R-CNN whose heads predict num_classes.

    Args:
        num_classes: Total number of classes, including background.

    Returns:
        torchvision Mask R-CNN (ResNet50-FPN) with freshly-initialized
        box and mask predictor heads sized for num_classes.
    """
    net = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # Replace the classification / box-regression head with one sized to
    # our class count.
    box_in_features = net.roi_heads.box_predictor.cls_score.in_features
    net.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
        box_in_features, num_classes
    )

    # Replace the mask head likewise; 256 hidden channels matches the
    # torchvision reference implementation.
    mask_in_channels = net.roi_heads.mask_predictor.conv5_mask.in_channels
    net.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
        mask_in_channels, 256, num_classes
    )
    return net
def load_model(weights_path, num_classes):
    """Load trained Mask R-CNN weights and move the model to the best device.

    Args:
        weights_path: Path to a state-dict checkpoint (.pth).
        num_classes: Number of classes, including background; must match the
            checkpoint's head sizes.

    Returns:
        (model, device) tuple; the model is in eval mode on `device`.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_maskrcnn(num_classes)
    # weights_only=True restricts unpickling to tensors/containers — all a
    # state dict needs — avoiding arbitrary-code-execution via pickle on a
    # downloaded file. (Requires torch >= 1.13.)
    state = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model, device
def infer_and_overlay(model, device, img_pil, score_thresh=0.5, alpha=0.45):
    """Run Mask R-CNN on a PIL image and return it with translucent masks.

    Args:
        model: Mask R-CNN in eval mode.
        device: torch device the model lives on.
        img_pil: Input PIL image (RGB).
        score_thresh: Minimum detection score for a mask to be drawn.
        alpha: Overlay opacity passed through to apply_mask_bgr.

    Returns:
        A new PIL RGB image with colored masks blended onto detections.
    """
    with torch.no_grad():
        batch = pil_to_tensor(img_pil).unsqueeze(0).to(device)
        pred = model(batch)[0]

    # OpenCV works in BGR, so convert once and do all drawing in that space.
    canvas = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
    height, width = canvas.shape[:2]

    # Missing keys fall back to empty predictions rather than raising.
    scores = pred.get("scores", torch.tensor([])).detach().cpu().numpy()
    labels = pred.get("labels", torch.tensor([])).detach().cpu().numpy().astype(int)
    masks = pred.get("masks", torch.zeros((0, 1, height, width))).detach().cpu().numpy()

    for det_idx, score in enumerate(scores):
        if score < score_thresh:
            continue
        mask_bool = masks[det_idx, 0] > 0.5  # binarize the soft mask
        label = labels[det_idx]
        # Labels beyond the palette fall back to yellow.
        color = COLOR_MAP[label] if label < len(COLOR_MAP) else (0, 255, 255)
        canvas = apply_mask_bgr(canvas, mask_bool, color, alpha=alpha)

    return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
def run_inference(image):
    """Gradio callback: segment `image`; returns None on no input or error.

    Args:
        image: PIL image from the Gradio input component, or None.

    Returns:
        Segmented PIL image, or None when there is nothing to process or
        inference fails (the error is logged, keeping the UI responsive).
    """
    if image is None:
        return None
    result = None
    try:
        result = infer_and_overlay(
            model,
            device,
            image,
            score_thresh=SCORE_THRESH_DEFAULT,
            alpha=ALPHA_DEFAULT,
        )
    except Exception as err:
        print(f"Inference error: {str(err)}")
    return result
# -------------- Initialize Model --------------
# Runs at import time so the model is ready before the Gradio UI launches.
print("Downloading model from Hugging Face Hub...")
model_path = download_model_from_hub()
print(f"Model downloaded to: {model_path}")
print("Loading model...")
# num_classes is derived from CLASS_NAMES (background + 4 tooth classes) and
# must match the checkpoint's head sizes.
model, device = load_model(model_path, len(CLASS_NAMES))
print(f"Model loaded successfully on device: {device}")
# -------------- Gradio UI --------------
with gr.Blocks(title="Teeth Segmentation — Mask R-CNN") as demo:
    gr.Markdown("## 🦷 Teeth Segmentation — Mask R-CNN (Translucent Masks)")
    gr.Markdown(
        "Upload a dental image to segment different tooth classes using Mask R-CNN. "
        "The model will overlay colored masks on detected teeth."
    )
    # Single-column layout: input, trigger button, output.
    with gr.Column():
        input_image = gr.Image(label="Input Image", type="pil")
        analyze_btn = gr.Button("Analyze Image", variant="primary")
        output_image = gr.Image(label="Segmented Output", type="pil")
    # Sample images expected to ship alongside this app (.jfif = JPEG).
    gr.Examples(
        examples=[
            ["example_image1.jfif"],
            ["example_image2.jfif"],
            ["example_image3.jfif"],
        ],
        inputs=input_image,
    )
    # Wire the button to the inference callback defined above.
    analyze_btn.click(
        fn=run_inference,
        inputs=[input_image],
        outputs=output_image,
        show_progress=True
    )
if __name__ == "__main__":
    demo.launch()