IFMedTechdemo committed on
Commit
2f9599d
·
verified ·
1 Parent(s): e536850

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -70
app.py CHANGED
@@ -1,70 +1,102 @@
1
- import gradio as gr
2
- from PIL import Image
3
- import numpy as np
4
- import cv2
5
-
6
- from model_utils import get_model, predict
7
-
8
- # ---- Config ----
9
- CLASS_NAMES = ["background", "Normal_Conjuctiva", "conjunctiva_pallor"]
10
- WEIGHTS_PATH = r"skin_pallor_segment/Saved_model/mask_rcnn_conjunctiva.pth"
11
-
12
- # Optional: cache model so it loads once (recommended for Gradio)
13
- _MODEL = None
14
-
15
- def get_cached_model():
16
- global _MODEL
17
- if _MODEL is None:
18
- _MODEL = get_model(num_classes=3, weights_path=WEIGHTS_PATH)
19
- return _MODEL
20
-
21
- def segment_image(pil_img):
22
- """
23
- pil_img comes from gr.Image(type="pil") => already a PIL.Image (or None).
24
- Returns a numpy RGB image for gr.Image output.
25
- """
26
- if pil_img is None:
27
- return None
28
-
29
- image = pil_img.convert("RGB")
30
-
31
- model = get_cached_model()
32
- results = predict(model, image, device="cpu", class_names=CLASS_NAMES)
33
-
34
- # Overlay masks/contours on the original image
35
- image_np = np.array(image) # RGB uint8
36
-
37
- for res in results:
38
- mask = res["mask"] # expected float/0..1
39
- label = res.get("label", "")
40
-
41
- colored_mask = (mask > 0.5).astype(np.uint8) * 255
42
- contours, _ = cv2.findContours(colored_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
43
-
44
- cv2.drawContours(image_np, contours, -1, (0, 255, 0), 2)
45
-
46
- if len(contours) > 0 and len(contours[0]) > 0:
47
- x, y = contours[0][0][0]
48
- cv2.putText(
49
- image_np,
50
- str(label),
51
- (int(x), int(y) - 10),
52
- cv2.FONT_HERSHEY_SIMPLEX,
53
- 0.7,
54
- (255, 0, 0),
55
- 2,
56
- )
57
-
58
- return image_np # RGB numpy array works with gr.Image
59
-
60
- with gr.Blocks() as demo:
61
- gr.Markdown("# Conjunctiva Segmentation - Mask R-CNN")
62
-
63
- with gr.Row():
64
- inp = gr.Image(type="pil", label="Upload Image (Preview)")
65
- out = gr.Image(type="numpy", label="Segmented Output")
66
-
67
- submit = gr.Button("Submit")
68
- submit.click(fn=segment_image, inputs=inp, outputs=out) # button triggers inference [web:14][web:19]
69
-
70
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import numpy as np
4
+ import cv2
5
+ import os
6
+ import torch
7
+
8
+ from huggingface_hub import hf_hub_download
9
+ from model_utils import get_model, predict
10
+
11
# --- Config ---
# Class names in the order the model's integer class indices map to.
# Index 0 is the detector background class; spelling "Conjuctiva" is kept
# as-is because it must match the labels the trained model was built with.
CLASS_NAMES = ["background", "Normal_Conjuctiva", "conjunctiva_pallor"]

# Private repo + file in your HF model
# REPO_ID: Hugging Face model repository id ("owner/name").
REPO_ID = "IFMedTech/Pallor_Mask_RCNN_Model"  # e.g. "IFMedTech/Eye-Anaemia-Model"
# FILENAME: path of the weights file inside that repository.
FILENAME = "mask_rcnn_conjunctiva.pth"  # e.g. "models/mask_rcnn_conjunctiva.pth"
17
+
18
def get_weights_path():
    """
    Download the Mask R-CNN weights from the private HF repo.

    Uses an access token stored in the Space secrets. Accepts either the
    custom ``HUGGINGFACE_TOKEN`` secret name or the conventional
    ``HF_TOKEN`` environment variable.

    Returns:
        str: Local filesystem path of the cached ``.pth`` file.

    Raises:
        ValueError: If no Hugging Face access token is configured.
    """
    # Fall back to HF_TOKEN, the name the Hugging Face tooling reads by default.
    token = os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HF_TOKEN")
    if not token:
        raise ValueError("Please set HUGGINGFACE_TOKEN in the Space secrets for private model access.")

    # hf_hub_download caches the file locally and returns its path,
    # so repeated calls do not re-download.
    model_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        token=token,
    )
    return model_path
32
+
33
# Cache model at module level so it loads once per process (recommended for Gradio).
_MODEL = None

def get_cached_model():
    """
    Return the segmentation model, building it lazily on first use.

    Downloads the weights (if not already cached) and constructs the model
    exactly once; every later call returns the same cached instance.

    Returns:
        The model object produced by ``get_model``.
    """
    global _MODEL
    if _MODEL is None:
        weights_path = get_weights_path()
        _MODEL = get_model(num_classes=3, weights_path=weights_path)
    return _MODEL
51
+
52
+
53
def segment_image(pil_img):
    """
    Run conjunctiva segmentation and draw the detections on the image.

    Args:
        pil_img: PIL.Image from ``gr.Image(type="pil")``, or None when the
            user has not uploaded anything.

    Returns:
        numpy.ndarray: RGB uint8 image with green contours and red class
        labels drawn on it, or None when the input is None.
    """
    if pil_img is None:
        return None

    image = pil_img.convert("RGB")

    model = get_cached_model()
    results = predict(model, image, device="cpu", class_names=CLASS_NAMES)

    # Overlay masks/contours on a copy of the original image
    # (np.array copies, so drawing does not touch the PIL image).
    image_np = np.array(image)  # RGB uint8

    for res in results:
        mask = res["mask"]  # assumed float mask in [0, 1] — per predict()'s contract
        label = res.get("label", "")

        # Threshold the soft mask into a binary 0/255 image for OpenCV.
        binary_mask = (mask > 0.5).astype(np.uint8) * 255
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

        cv2.drawContours(image_np, contours, -1, (0, 255, 0), 2)

        # Label this detection at the first point of its first contour.
        if len(contours) > 0 and len(contours[0]) > 0:
            x, y = contours[0][0][0]
            cv2.putText(
                image_np,
                str(label),
                # Clamp the baseline so labels near the top edge stay visible
                # (y - 10 would otherwise place the text off-image).
                (int(x), max(int(y) - 10, 15)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (255, 0, 0),
                2,
            )

    return image_np  # RGB numpy array works with gr.Image
91
+
92
# Gradio UI: image in, segmented image out, triggered by a Submit button.
with gr.Blocks() as demo:
    gr.Markdown("# Conjunctiva Segmentation - Mask R-CNN")

    with gr.Row():
        # Input arrives as a PIL.Image (or None), matching segment_image's contract.
        inp = gr.Image(type="pil", label="Upload Image (Preview)")
        # Output accepts the RGB numpy array segment_image returns.
        out = gr.Image(type="numpy", label="Segmented Output")

    submit = gr.Button("Submit")
    # Button click runs inference; model loads lazily on the first click.
    submit.click(fn=segment_image, inputs=inp, outputs=out)

demo.launch()