File size: 4,111 Bytes
2f9599d
 
 
 
 
 
 
 
 
 
 
777d4a6
2f9599d
 
3e85369
 
 
 
 
2f9599d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e85369
 
 
 
 
 
 
2f9599d
 
 
 
 
 
 
 
 
 
 
3e85369
 
 
 
 
2f9599d
3e85369
 
2f9599d
3e85369
 
 
2f9599d
3e85369
 
2f9599d
3e85369
2f9599d
3e85369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f9599d
3e85369
 
39bb474
 
 
 
 
 
 
 
 
3e85369
39bb474
 
 
 
 
 
2f9599d
3e85369
2f9599d
3e85369
2f9599d
 
3e85369
2f9599d
3e85369
 
 
 
2c1204e
 
 
3e85369
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import gradio as gr
from PIL import Image
import numpy as np
import cv2
import os
import torch

from huggingface_hub import hf_hub_download
from model_utils import get_model, predict

# --- Config ---
# Class labels, indexed to match the model's output class ids.
# NOTE(review): "Conjuctiva" looks like a typo for "Conjunctiva", but it
# presumably must match the training-time label set exactly — confirm before
# renaming, so it is left unchanged here.
CLASS_NAMES = ["background", "Pale Conjunctiva", "Normal Conjuctiva"]

# Private repo + file in your HF model
REPO_ID = "IFMedTech/Pallor_Mask_RCNN_Model"
FILENAME = "mask_rcnn_conjunctiva.pth"

# Determine device once at startup
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_weights_path():
    """
    Download the .pth weights from the private HF repo using a token from the
    environment.

    Returns:
        str: Local filesystem path of the downloaded (and cached) weights file.

    Raises:
        ValueError: If no Hugging Face access token is configured.
    """
    # Accept the Space's custom secret name first, then fall back to HF_TOKEN,
    # the variable name huggingface_hub itself documents.
    token = os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HF_TOKEN")
    if not token:
        raise ValueError("Please set HUGGINGFACE_TOKEN in the Space secrets for private model access.")

    # hf_hub_download caches the file locally, so repeated calls are cheap.
    return hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        token=token,
    )

# Optional: cache model so it loads once (recommended for Gradio)
_MODEL = None

def get_cached_model():
    """
    Lazily build the Mask R-CNN model once and reuse it for every request.

    Returns:
        The loaded model, moved to DEVICE and set to eval mode.

    Raises:
        RuntimeError: If downloading the weights or constructing the model
            fails; the original exception is chained as the cause.
    """
    global _MODEL
    if _MODEL is None:
        try:
            weights_path = get_weights_path()
            _MODEL = get_model(num_classes=3, weights_path=weights_path)
            _MODEL.to(DEVICE)
            _MODEL.eval()
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(f"Failed to load model: {e}") from e
    return _MODEL


def segment_image(pil_img):
    """
    Run segmentation on an uploaded image and draw the predictions onto it.

    pil_img comes from gr.Image(type="pil") => already a PIL.Image (or None).

    Returns:
        A numpy RGB uint8 image with contours and labels drawn, the unmodified
        original image if inference fails, or None when no image was given.
    """
    if pil_img is None:
        return None

    try:
        image = pil_img.convert("RGB")

        model = get_cached_model()
        results = predict(model, image, device=DEVICE, class_names=CLASS_NAMES)

        # Overlay masks/contours on the original image (RGB uint8).
        return _draw_results(np.array(image), results)
    except Exception as e:
        print(f"Error during segmentation: {str(e)}")
        return np.array(pil_img.convert("RGB"))  # Return original image on error


def _draw_results(image_np, results):
    """Draw contours and a text label for each prediction; returns image_np."""
    for res in results:
        mask = res["mask"]  # expected float/0..1
        label = res.get("label", "")

        # np.squeeze tolerates masks shaped (1, H, W) as well as (H, W);
        # cv2.findContours requires a 2-D single-channel uint8 image.
        binary_mask = (np.squeeze(np.asarray(mask)) > 0.5).astype(np.uint8) * 255
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

        cv2.drawContours(image_np, contours, -1, (0, 255, 0), 2)

        # Anchor the label just above the first vertex of the first contour.
        if len(contours) > 0 and len(contours[0]) > 0:
            x, y = contours[0][0][0]
            cv2.putText(
                image_np,
                str(label),
                (int(x), int(y) - 10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (255, 0, 0),
                2,
            )
    return image_np


def get_sample_images(dataset_dir="Eye_Dataset", max_images=10):
    """
    Collect example image paths from a local dataset folder.

    Args:
        dataset_dir: Directory to scan (defaults to the previous hard-coded
            "Eye_Dataset" folder, so existing callers are unaffected).
        max_images: Maximum number of paths to return (previously fixed at 10).

    Returns:
        A sorted list of up to max_images image file paths; an empty list when
        the directory does not exist.
    """
    # isdir (not exists) so that a stray *file* with this name cannot make
    # os.listdir below raise.
    if not os.path.isdir(dataset_dir):
        return []

    sample_images = [
        os.path.join(dataset_dir, filename)
        for filename in sorted(os.listdir(dataset_dir))
        if filename.lower().endswith((".png", ".jpg", ".jpeg"))
    ]
    return sample_images[:max_images]


# --- UI definition ---
with gr.Blocks(title="Conjunctiva Segmentation") as demo:
    gr.Markdown("# Conjunctiva Segmentation - Mask R-CNN")
    gr.Markdown(f"Running on: **{DEVICE}**")

    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload Image")
        output_image = gr.Image(type="numpy", label="Segmented Output")

    run_button = gr.Button("Submit", variant="primary")
    run_button.click(fn=segment_image, inputs=input_image, outputs=output_image)

    # Show example thumbnails only when the local Eye_Dataset folder exists.
    samples = get_sample_images()
    if samples:
        gr.Examples(
            examples=samples,
            inputs=input_image,
            outputs=output_image,
            fn=segment_image,
            cache_examples=False,
            label="Sample Images",
        )

if __name__ == "__main__":
    demo.launch()