Spaces:
Sleeping
Sleeping
File size: 4,111 Bytes
2f9599d 777d4a6 2f9599d 3e85369 2f9599d 3e85369 2f9599d 3e85369 2f9599d 3e85369 2f9599d 3e85369 2f9599d 3e85369 2f9599d 3e85369 2f9599d 3e85369 2f9599d 3e85369 39bb474 3e85369 39bb474 2f9599d 3e85369 2f9599d 3e85369 2f9599d 3e85369 2f9599d 3e85369 2c1204e 3e85369 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import gradio as gr
from PIL import Image
import numpy as np
import cv2
import os
import torch
from huggingface_hub import hf_hub_download
from model_utils import get_model, predict
# --- Config ---
# Display names passed to predict(); presumably indexed by the model's class id
# (get_model is built with num_classes=3, matching the three entries here).
# NOTE(review): "Conjuctiva" is misspelled; kept byte-for-byte because the exact
# string may be matched elsewhere (e.g. inside model_utils) — confirm before fixing.
CLASS_NAMES = [
    "background",
    "Pale Conjunctiva",
    "Normal Conjuctiva",
]

# Private Hugging Face model repository and the checkpoint file stored in it.
REPO_ID = "IFMedTech/Pallor_Mask_RCNN_Model"
FILENAME = "mask_rcnn_conjunctiva.pth"

# Compute device, chosen once at import time: CUDA if available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def get_weights_path():
    """Fetch the model checkpoint from the private Hugging Face repo.

    Authenticates with the ``HUGGINGFACE_TOKEN`` environment variable (stored
    as a Space secret) and downloads ``FILENAME`` from ``REPO_ID`` via
    ``hf_hub_download``.

    Returns:
        The path returned by ``hf_hub_download`` for the downloaded file.

    Raises:
        ValueError: if ``HUGGINGFACE_TOKEN`` is unset or empty.
    """
    hf_token = os.environ.get("HUGGINGFACE_TOKEN")
    if hf_token:
        return hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=hf_token)
    raise ValueError(
        "Please set HUGGINGFACE_TOKEN in the Space secrets for private model access."
    )
# Module-level cache so the model is downloaded and built only once per process
# (recommended for Gradio, which calls the handler for every request).
_MODEL = None


def get_cached_model():
    """Return the segmentation model, loading it on first use.

    On the first successful call this downloads the weights
    (``get_weights_path``), builds the network with ``get_model``, moves it
    to ``DEVICE`` and switches it to eval mode. The prepared model is stored
    in the module-level ``_MODEL`` and returned as-is by later calls.

    Returns:
        The cached model object.

    Raises:
        RuntimeError: if any loading step fails. The original exception is
            chained as ``__cause__``. A failed load leaves ``_MODEL`` as
            ``None``, so the next call retries from scratch.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL
    try:
        weights_path = get_weights_path()
        # num_classes must stay in sync with len(CLASS_NAMES).
        model = get_model(num_classes=3, weights_path=weights_path)
        model.to(DEVICE)
        model.eval()
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {e}") from e
    # Publish only a fully prepared model. Assigning the global before
    # .to()/.eval() would leave a half-initialised model cached forever if
    # either of those steps raised.
    _MODEL = model
    return _MODEL
def _draw_instance(image_np, mask, label):
    """Outline one predicted instance and write its label onto ``image_np`` in place.

    ``mask`` is thresholded at 0.5. The contours of the resulting region are
    drawn with colour ``(0, 255, 0)`` and thickness 2. The label is drawn with
    colour ``(255, 0, 0)``, which is red when ``image_np`` is RGB, as in
    ``segment_image``.
    """
    mask = np.asarray(mask)
    # NOTE(review): instance masks are sometimes shaped (1, H, W); drop a
    # leading singleton axis so OpenCV gets a 2-D image. Confirm against predict().
    if mask.ndim == 3 and mask.shape[0] == 1:
        mask = mask[0]
    binary = (mask > 0.5).astype(np.uint8) * 255
    # findContours returns (contours, hierarchy) in OpenCV 4.x but
    # (image, contours, hierarchy) in 3.x; index -2 is the contours in both.
    contours = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return
    cv2.drawContours(image_np, contours, -1, (0, 255, 0), 2)
    # Anchor the label on the largest contour rather than whichever one OpenCV
    # listed first (with RETR_TREE that may be a hole or a tiny speck).
    anchor = max(contours, key=cv2.contourArea)
    x, y = anchor[0][0]
    # Text origin is the baseline; clamp so the label stays visible when the
    # contour touches the top edge of the image.
    text_origin = (int(x), max(int(y) - 10, 15))
    cv2.putText(
        image_np,
        str(label),
        text_origin,
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (255, 0, 0),
        2,
    )


def segment_image(pil_img):
    """Run the cached model on one image and draw the predicted instances.

    Args:
        pil_img: value from ``gr.Image(type="pil")``, i.e. a ``PIL.Image``
            or ``None`` when nothing was uploaded.

    Returns:
        ``None`` if no image was given. Otherwise, a numpy RGB image of the
        input with each result from ``predict`` outlined and labelled. On any
        error the message and traceback are printed and the unannotated RGB
        image is returned instead, so the UI still shows something.
    """
    if pil_img is None:
        return None
    try:
        image = pil_img.convert("RGB")
        model = get_cached_model()
        results = predict(model, image, device=DEVICE, class_names=CLASS_NAMES)
        # np.array copies the pixels (RGB uint8), giving a writable buffer for
        # OpenCV's in-place drawing.
        image_np = np.array(image)
        for res in results:
            _draw_instance(image_np, res["mask"], res.get("label", ""))
        return image_np
    except Exception as e:
        # Top-level UI boundary: report the failure and fall back to the input.
        import traceback

        print(f"Error during segmentation: {str(e)}")
        traceback.print_exc()
        return np.array(pil_img.convert("RGB"))  # Return original image on error
def get_sample_images(dataset_dir="Eye_Dataset", max_images=10):
    """Return up to ``max_images`` sample image paths found in ``dataset_dir``.

    Only regular files whose name ends in ``.png``, ``.jpg`` or ``.jpeg``
    (case-insensitive) are kept. They are returned in sorted filename order,
    each joined onto ``dataset_dir``.

    Args:
        dataset_dir: folder to scan. The default is relative to the current
            working directory.
        max_images: maximum number of paths to return; ``None`` means no limit.

    Returns:
        A list of file paths. It is empty if ``dataset_dir`` is missing or is
        not a directory, so the UI simply shows no examples.
    """
    # isdir rather than exists: a regular file at this path would make
    # os.listdir raise instead of returning no examples.
    if not os.path.isdir(dataset_dir):
        return []
    image_exts = (".png", ".jpg", ".jpeg")
    names = sorted(
        name
        for name in os.listdir(dataset_dir)
        if name.lower().endswith(image_exts)
        and os.path.isfile(os.path.join(dataset_dir, name))
    )
    return [os.path.join(dataset_dir, name) for name in names[:max_images]]
# Gradio UI: title and device banner, an input/output image pair side by side,
# a Submit button wired to segment_image, and (if any are found) sample images
# from Eye_Dataset.
with gr.Blocks(title="Conjunctiva Segmentation") as demo:
    gr.Markdown("# Conjunctiva Segmentation - Mask R-CNN")
    gr.Markdown(f"Running on: **{DEVICE}**")

    with gr.Row():
        inp = gr.Image(label="Upload Image", type="pil")
        out = gr.Image(label="Segmented Output", type="numpy")

    submit = gr.Button("Submit", variant="primary")
    submit.click(segment_image, inputs=[inp], outputs=[out])

    # Sample gallery; omitted entirely when the dataset folder yields nothing.
    examples_list = get_sample_images()
    if examples_list:
        gr.Examples(
            label="Sample Images",
            examples=examples_list,
            inputs=inp,
            outputs=out,
            fn=segment_image,
            cache_examples=False,
        )


if __name__ == "__main__":
    demo.launch()
|