# Eye_Anaemia / app.py - Hugging Face Space by IFMedTechdemo
# Last commit: 777d4a6 "Update app.py" (verified)
import gradio as gr
from PIL import Image
import numpy as np
import cv2
import os
import torch
from huggingface_hub import hf_hub_download
from model_utils import get_model, predict
# --- Config ---
# Display names for the model's classes; index 0 is the background class.
# The list is passed to `predict`, which presumably maps integer label ids to
# these names, so the order must match the label ids the checkpoint was
# trained with (TODO confirm against model_utils / training config).
CLASS_NAMES = ["background", "Pale Conjunctiva", "Normal Conjunctiva"]

# Private Hugging Face model repository and the checkpoint file inside it
# that holds the Mask R-CNN weights.
REPO_ID = "IFMedTech/Pallor_Mask_RCNN_Model"
FILENAME = "mask_rcnn_conjunctiva.pth"

# Pick the torch device once at import time: CUDA when available, else CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_weights_path():
    """Fetch the Mask R-CNN checkpoint from the private Hugging Face repo.

    The access token is read from the ``HUGGINGFACE_TOKEN`` environment
    variable, which is expected to be configured as a Space secret.

    Returns:
        The local path of the downloaded ``FILENAME`` from ``REPO_ID``, as
        returned by ``hf_hub_download``.

    Raises:
        ValueError: if ``HUGGINGFACE_TOKEN`` is unset or empty.
    """
    hf_token = os.environ.get("HUGGINGFACE_TOKEN")
    if hf_token:
        return hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=hf_token)
    raise ValueError("Please set HUGGINGFACE_TOKEN in the Space secrets for private model access.")
# Module-level model cache so the weights are downloaded and loaded once per
# process instead of on every Gradio request.
_MODEL = None


def get_cached_model():
    """Return the segmentation model, loading and caching it on first use.

    On the first successful call, the weights are downloaded via
    ``get_weights_path``, the model is built with one output per entry in
    ``CLASS_NAMES``, moved to ``DEVICE`` and switched to eval mode. The cache
    is only populated after every step has succeeded. A failed load is
    therefore retried on the next call instead of leaving a half-initialised
    model cached.

    Returns:
        The cached model object produced by ``get_model``.

    Raises:
        RuntimeError: if downloading or building the model fails; the
            underlying exception is chained as ``__cause__``.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL
    try:
        weights_path = get_weights_path()
        model = get_model(num_classes=len(CLASS_NAMES), weights_path=weights_path)
        model.to(DEVICE)
        model.eval()
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {e}") from e
    # Publish only the fully prepared model.
    _MODEL = model
    return _MODEL
def _draw_instance(image_np, mask, label):
    """Outline one predicted instance and draw its label onto ``image_np``.

    Drawing happens in place. OpenCV writes raw channel values, so on an RGB
    array (0, 255, 0) renders green and (255, 0, 0) renders red.

    Args:
        image_np: writable uint8 RGB array to draw on.
        mask: per-instance mask; values above 0.5 count as foreground. Both
            (H, W) and (1, H, W) layouts are accepted.
        label: text drawn next to the outline.
    """
    mask = np.asarray(mask)
    # Per-instance masks may carry a leading axis of size 1;
    # cv2.findContours needs a 2-D single-channel image.
    if mask.ndim == 3 and mask.shape[0] == 1:
        mask = mask[0]
    binary = (mask > 0.5).astype(np.uint8) * 255
    contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return
    cv2.drawContours(image_np, contours, -1, (0, 255, 0), 2)

    text = str(label)
    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2
    (_, text_h), _ = cv2.getTextSize(text, font, scale, thickness)
    x, y = contours[0][0][0]
    # putText anchors at the text's bottom-left corner. Place the label just
    # above the first contour point, but never higher than the text height,
    # otherwise it would be clipped off the top edge of the image.
    y_text = max(int(y) - 10, text_h)
    cv2.putText(image_np, text, (int(x), y_text), font, scale, (255, 0, 0), thickness)


def segment_image(pil_img):
    """Run conjunctiva segmentation and overlay the results on the image.

    Args:
        pil_img: the value from ``gr.Image(type="pil")``, i.e. a PIL image,
            or None when no image was provided.

    Returns:
        None when ``pil_img`` is None. Otherwise an RGB uint8 numpy array in
        which each result returned by ``predict`` is outlined in green and
        labelled in red. This is best-effort: on any error (model loading,
        inference, drawing), the error is printed and the unmodified RGB
        image is returned so the UI still shows something.
    """
    if pil_img is None:
        return None
    try:
        image = pil_img.convert("RGB")
        model = get_cached_model()
        results = predict(model, image, device=DEVICE, class_names=CLASS_NAMES)
        # np.array copies the pixels into a writable buffer OpenCV can draw on.
        image_np = np.array(image)
        for res in results:
            _draw_instance(image_np, res["mask"], res.get("label", ""))
        return image_np
    except Exception as e:
        print(f"Error during segmentation: {str(e)}")
        return np.array(pil_img.convert("RGB"))  # Return original image on error
def get_sample_images(dataset_dir="Eye_Dataset", max_images=10):
    """Return up to ``max_images`` sample image paths from ``dataset_dir``.

    Entries are taken in sorted filename order. They are matched
    case-insensitively on the ``.png``, ``.jpg`` and ``.jpeg`` extensions.
    Only regular files are returned, so a sub-directory named like an image
    is skipped.

    Args:
        dataset_dir: folder to scan; a relative path is resolved against the
            current working directory.
        max_images: maximum number of paths to return.

    Returns:
        A list of ``os.path.join(dataset_dir, filename)`` paths. The list is
        empty when ``dataset_dir`` does not exist or is not a directory.
    """
    # isdir (not exists): os.listdir would raise on a regular file.
    if not os.path.isdir(dataset_dir):
        return []
    image_exts = (".png", ".jpg", ".jpeg")
    candidates = (
        os.path.join(dataset_dir, name)
        for name in sorted(os.listdir(dataset_dir))
        if name.lower().endswith(image_exts)
    )
    return [path for path in candidates if os.path.isfile(path)][:max_images]
# Discover example images up front. This only reads the filesystem and creates
# no Gradio components, so it has no effect on the layout built below.
examples_list = get_sample_images()

with gr.Blocks(title="Conjunctiva Segmentation") as demo:
    # Header: app title plus the torch device selected at startup.
    gr.Markdown("# Conjunctiva Segmentation - Mask R-CNN")
    gr.Markdown(f"Running on: **{DEVICE}**")

    # Input and segmented output shown side by side.
    with gr.Row():
        inp = gr.Image(type="pil", label="Upload Image")
        out = gr.Image(type="numpy", label="Segmented Output")

    submit = gr.Button("Submit", variant="primary")
    submit.click(fn=segment_image, inputs=inp, outputs=out)

    # Sample gallery from the Eye_Dataset folder, shown only when it has images.
    # cache_examples=False: outputs are not pre-computed when the app starts.
    if examples_list:
        gr.Examples(
            examples=examples_list,
            inputs=inp,
            outputs=out,
            fn=segment_image,
            cache_examples=False,
            label="Sample Images",
        )

if __name__ == "__main__":
    demo.launch()