"""Gradio app that overlays Mask R-CNN conjunctiva segmentations on an image."""

import logging
import os

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from model_utils import get_model, predict

logger = logging.getLogger(__name__)

# --- Config ---
# Passed to predict() as class_names; its length is also used as num_classes.
# NOTE(review): "Conjuctiva" is misspelled but is a runtime label that may be
# matched elsewhere -- confirm before correcting it.
CLASS_NAMES = ["background", "Pale Conjunctiva", "Normal Conjuctiva"]

# Private repo + file in your HF model
REPO_ID = "IFMedTech/Pallor_Mask_RCNN_Model"
FILENAME = "mask_rcnn_conjunctiva.pth"

# Determine device once at startup
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Drawing parameters shared by the contour/label overlay.
_CONTOUR_COLOR = (0, 255, 0)
_LABEL_COLOR = (255, 0, 0)
_LABEL_FONT = cv2.FONT_HERSHEY_SIMPLEX
_LABEL_SCALE = 0.7
_LINE_THICKNESS = 2
_LABEL_OFFSET = 10  # pixels between the contour's first point and the label

_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def get_weights_path():
    """Download the .pth weights from the private HF repo.

    The access token is read from the ``HUGGINGFACE_TOKEN`` environment
    variable (stored in the Space secrets).

    Returns:
        The path returned by ``hf_hub_download`` for the weights file.

    Raises:
        ValueError: If ``HUGGINGFACE_TOKEN`` is unset or empty.
    """
    token = os.environ.get("HUGGINGFACE_TOKEN")
    if not token:
        raise ValueError(
            "Please set HUGGINGFACE_TOKEN in the Space secrets for private model access."
        )

    return hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        token=token,
    )


# Cache the model so it loads once per process (recommended for Gradio).
_MODEL = None


def get_cached_model():
    """Return the process-wide model, loading it on first use.

    Returns:
        The model moved to ``DEVICE`` and switched to eval mode.

    Raises:
        RuntimeError: If downloading or initialising the model fails; the
            original exception is attached as ``__cause__``.
    """
    global _MODEL
    if _MODEL is None:
        try:
            weights_path = get_weights_path()
            model = get_model(
                num_classes=len(CLASS_NAMES), weights_path=weights_path
            )
            model.to(DEVICE)
            model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {str(e)}") from e
        # Publish only a fully initialised model, so a failure in .to()/.eval()
        # cannot leave a half-configured model cached for later calls.
        _MODEL = model
    return _MODEL


def _draw_detection(image_np, mask, label):
    """Draw one detection's mask outline and label onto ``image_np`` in place."""
    mask = np.asarray(mask)
    # NOTE(review): predict()'s mask shape is not visible here. A leading
    # singleton channel axis (1, H, W) is reduced to (H, W) because
    # cv2.findContours needs a 2-D image; 2-D masks pass through unchanged.
    if mask.ndim == 3 and mask.shape[0] == 1:
        mask = mask[0]

    binary_mask = (mask > 0.5).astype(np.uint8) * 255
    contours, _ = cv2.findContours(
        binary_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
    )
    if not contours:
        return

    cv2.drawContours(image_np, contours, -1, _CONTOUR_COLOR, _LINE_THICKNESS)

    text = str(label)
    if not text:
        return

    # Anchor the label just above the first point of the first contour.
    x, y = contours[0][0][0]
    (text_w, text_h), _baseline = cv2.getTextSize(
        text, _LABEL_FONT, _LABEL_SCALE, _LINE_THICKNESS
    )
    img_w = image_np.shape[1]
    # putText's origin is the text's bottom-left corner; clamp it so the label
    # stays on the canvas when the contour touches the top or right edge.
    text_x = max(0, min(int(x), img_w - text_w))
    text_y = max(int(y) - _LABEL_OFFSET, text_h)
    cv2.putText(
        image_np,
        text,
        (text_x, text_y),
        _LABEL_FONT,
        _LABEL_SCALE,
        _LABEL_COLOR,
        _LINE_THICKNESS,
    )


def segment_image(pil_img):
    """Run segmentation on an image and return it with contours and labels drawn.

    ``pil_img`` comes from ``gr.Image(type="pil")``, so it is already a
    ``PIL.Image`` (or ``None`` when nothing was uploaded).

    Returns:
        ``None`` if ``pil_img`` is ``None``; otherwise a numpy RGB image for
        the ``gr.Image`` output. If prediction or drawing fails, the error is
        logged and the unannotated RGB image is returned instead.
    """
    if pil_img is None:
        return None

    # Convert once, outside the try, so the error fallback can reuse it.
    image = pil_img.convert("RGB")
    try:
        model = get_cached_model()
        results = predict(model, image, device=DEVICE, class_names=CLASS_NAMES)

        # Overlay masks/contours on a writable copy of the original image.
        image_np = np.array(image)  # RGB uint8
        for res in results:
            _draw_detection(image_np, res["mask"], res.get("label", ""))
        return image_np
    except Exception:
        # Top-level UI boundary: log with traceback and show the original
        # image rather than surfacing an error to the user.
        logger.exception("Error during segmentation")
        return np.array(image)


def get_sample_images(dataset_dir="Eye_Dataset", limit=10):
    """Return paths of sample images found directly inside ``dataset_dir``.

    Args:
        dataset_dir: Directory to scan (not recursive).
        limit: Maximum number of paths to return.

    Returns:
        Up to ``limit`` paths of .png/.jpg/.jpeg files, sorted by filename.
        Empty if ``dataset_dir`` is not an existing directory.
    """
    if not os.path.isdir(dataset_dir):
        return []

    sample_images = [
        os.path.join(dataset_dir, filename)
        for filename in sorted(os.listdir(dataset_dir))
        if filename.lower().endswith(_IMAGE_EXTENSIONS)
    ]
    return sample_images[:limit]


with gr.Blocks(title="Conjunctiva Segmentation") as demo:
    gr.Markdown("# Conjunctiva Segmentation - Mask R-CNN")
    gr.Markdown(f"Running on: **{DEVICE}**")

    with gr.Row():
        inp = gr.Image(type="pil", label="Upload Image")
        out = gr.Image(type="numpy", label="Segmented Output")

    submit = gr.Button("Submit", variant="primary")
    submit.click(fn=segment_image, inputs=inp, outputs=out)

    # Add examples from Eye_Dataset folder
    examples_list = get_sample_images()
    if examples_list:
        gr.Examples(
            examples=examples_list,
            inputs=inp,
            outputs=out,
            fn=segment_image,
            cache_examples=False,
            label="Sample Images",
        )

if __name__ == "__main__":
    demo.launch()