# Gradio Space: conjunctiva pallor segmentation demo (Mask R-CNN).
import os

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from model_utils import get_model, predict

# --- Configuration ---
# Index 0 is background; indices 1-2 are the two conjunctiva states.
CLASS_NAMES = ["background", "Pale Conjunctiva", "Normal Conjuctiva"]

# Private Hugging Face model repo and checkpoint filename.
REPO_ID = "IFMedTech/Pallor_Mask_RCNN_Model"
FILENAME = "mask_rcnn_conjunctiva.pth"

# Choose the compute device once at startup.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_weights_path():
    """
    Download the .pth checkpoint from the private HF model repo.

    The access token is read from the HUGGINGFACE_TOKEN environment
    variable, falling back to HF_TOKEN (the default secret name on
    Hugging Face Spaces).

    Returns:
        Local filesystem path of the downloaded checkpoint.

    Raises:
        ValueError: if neither environment variable is set.
    """
    token = os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HF_TOKEN")
    if not token:
        raise ValueError("Please set HUGGINGFACE_TOKEN in the Space secrets for private model access.")
    # hf_hub_download caches the file locally and returns its path.
    return hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        token=token,
    )
# Module-level cache so the model is loaded at most once per process.
_MODEL = None


def get_cached_model():
    """
    Lazily load, cache, and return the Mask R-CNN model.

    The first call downloads the weights and builds the model; later
    calls return the cached instance, moved to DEVICE and in eval mode.

    Raises:
        RuntimeError: if downloading or loading the weights fails.
    """
    global _MODEL
    if _MODEL is None:
        try:
            weights_path = get_weights_path()
            # 3 classes: background plus the two conjunctiva classes.
            _MODEL = get_model(num_classes=3, weights_path=weights_path)
            _MODEL.to(DEVICE)
            _MODEL.eval()
        except Exception as e:
            # Chain the original exception so the real cause is preserved
            # in the traceback instead of being flattened into a string.
            raise RuntimeError(f"Failed to load model: {str(e)}") from e
    return _MODEL
def segment_image(pil_img):
    """
    Run Mask R-CNN segmentation on an uploaded image and draw the results.

    Args:
        pil_img: PIL.Image from gr.Image(type="pil"), or None when the
            input is cleared.

    Returns:
        A numpy RGB uint8 array with green contours and red labels drawn
        on it; the unmodified input image on failure; None when pil_img
        is None.
    """
    if pil_img is None:
        return None
    try:
        image = pil_img.convert("RGB")
        model = get_cached_model()
        # predict() is expected to yield dicts with a "mask" (float array,
        # values in 0..1) and optionally a "label" — TODO confirm against
        # model_utils.predict.
        results = predict(model, image, device=DEVICE, class_names=CLASS_NAMES)
        image_np = np.array(image)  # RGB uint8 copy we can draw on
        for res in results:
            # np.squeeze tolerates (1, H, W) masks, a common Mask R-CNN
            # output shape; a plain (H, W) mask passes through unchanged.
            mask = np.squeeze(np.asarray(res["mask"]))
            label = res.get("label", "")
            binary_mask = (mask > 0.5).astype(np.uint8) * 255
            contours, _ = cv2.findContours(binary_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(image_np, contours, -1, (0, 255, 0), 2)
            # Anchor the label just above the first point of the first contour.
            if len(contours) > 0 and len(contours[0]) > 0:
                x, y = contours[0][0][0]
                cv2.putText(
                    image_np,
                    str(label),
                    (int(x), int(y) - 10),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.7,
                    (255, 0, 0),
                    2,
                )
        return image_np
    except Exception as e:
        # Best-effort UI: log and fall back to the original image rather
        # than surfacing a stack trace to the user.
        print(f"Error during segmentation: {str(e)}")
        return np.array(pil_img.convert("RGB"))  # Return original image on error
def get_sample_images(dataset_dir="Eye_Dataset", max_images=10):
    """
    Collect example image paths from a local dataset folder.

    Args:
        dataset_dir: Folder to scan for images (default "Eye_Dataset").
        max_images: Maximum number of paths to return (default 10).

    Returns:
        A lexicographically sorted list of up to max_images paths to
        .png/.jpg/.jpeg files; empty if the folder does not exist.
    """
    if not os.path.exists(dataset_dir):
        return []
    image_extensions = ('.png', '.jpg', '.jpeg')
    sample_images = [
        os.path.join(dataset_dir, filename)
        for filename in sorted(os.listdir(dataset_dir))
        if filename.lower().endswith(image_extensions)
    ]
    return sample_images[:max_images]
# --- UI definition ---
with gr.Blocks(title="Conjunctiva Segmentation") as demo:
    gr.Markdown("# Conjunctiva Segmentation - Mask R-CNN")
    gr.Markdown(f"Running on: **{DEVICE}**")

    with gr.Row():
        inp = gr.Image(type="pil", label="Upload Image")
        out = gr.Image(type="numpy", label="Segmented Output")

    submit = gr.Button("Submit", variant="primary")
    submit.click(fn=segment_image, inputs=inp, outputs=out)

    # Offer clickable samples from the local Eye_Dataset folder, if any.
    examples_list = get_sample_images()
    if examples_list:
        gr.Examples(
            examples=examples_list,
            inputs=inp,
            outputs=out,
            fn=segment_image,
            cache_examples=False,
            label="Sample Images",
        )

if __name__ == "__main__":
    demo.launch()