IFMedTechdemo committed on
Commit
d7946ce
·
verified ·
1 Parent(s): 85029f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +326 -327
app.py CHANGED
@@ -1,327 +1,326 @@
1
- """
2
- Surgical-DeSAM Gradio App for Hugging Face Spaces
3
- Supports both Image and Video segmentation with ZeroGPU
4
- """
5
- import os
6
- import spaces
7
- import gradio as gr
8
- import torch
9
- import numpy as np
10
- import cv2
11
- from PIL import Image
12
- from huggingface_hub import hf_hub_download
13
- import tempfile
14
-
15
- # Model imports
16
- from models.detr_seg import DETR, SAMModel
17
- from models.backbone import build_backbone
18
- from models.transformer import build_transformer
19
- from util.misc import NestedTensor
20
-
21
- # Configuration
22
- MODEL_REPO = os.environ.get("MODEL_REPO", "IFMedTech/surgical-desam-weights")
23
- HF_TOKEN = os.environ.get("HF_TOKEN")
24
-
25
- INSTRUMENT_CLASSES = (
26
- 'bipolar_forceps', 'prograsp_forceps', 'large_needle_driver',
27
- 'monopolar_curved_scissors', 'ultrasound_probe', 'suction',
28
- 'clip_applier', 'stapler'
29
- )
30
-
31
- COLORS = [
32
- [0, 114, 189], [217, 83, 25], [237, 177, 32],
33
- [126, 47, 142], [119, 172, 48], [77, 190, 238],
34
- [162, 20, 47], [76, 76, 76]
35
- ]
36
-
37
- # Global model variables
38
- model = None
39
- seg_model = None
40
- device = None
41
-
42
-
43
- def download_weights():
44
- """Download model weights from private HF repo"""
45
- weights_dir = "weights"
46
- os.makedirs(weights_dir, exist_ok=True)
47
-
48
- desam_path = hf_hub_download(
49
- repo_id=MODEL_REPO,
50
- filename="surgical_desam_1024.pth",
51
- token=HF_TOKEN,
52
- local_dir=weights_dir
53
- )
54
-
55
- sam_path = hf_hub_download(
56
- repo_id=MODEL_REPO,
57
- filename="sam_vit_b_01ec64.pth",
58
- token=HF_TOKEN,
59
- local_dir=weights_dir
60
- )
61
-
62
- swin_dir = "swin_backbone"
63
- os.makedirs(swin_dir, exist_ok=True)
64
- hf_hub_download(
65
- repo_id=MODEL_REPO,
66
- filename="swin_base_patch4_window7_224_22kto1k.pth",
67
- token=HF_TOKEN,
68
- local_dir=swin_dir
69
- )
70
-
71
- return desam_path, sam_path
72
-
73
-
74
- class Args:
75
- """Mock args for model building"""
76
- backbone = 'swin_B_224_22k'
77
- dilation = False
78
- position_embedding = 'sine'
79
- hidden_dim = 256
80
- dropout = 0.1
81
- nheads = 8
82
- dim_feedforward = 2048
83
- enc_layers = 6
84
- dec_layers = 6
85
- pre_norm = False
86
- num_queries = 100
87
- aux_loss = False
88
- lr_backbone = 1e-5
89
- masks = False
90
- dataset_file = 'endovis18'
91
- device = 'cuda'
92
- backbone_dir = './swin_backbone'
93
-
94
-
95
- def load_models():
96
- """Load DETR and SAM models"""
97
- global model, seg_model, device
98
-
99
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
100
- desam_path, sam_path = download_weights()
101
-
102
- args = Args()
103
- args.device = str(device)
104
-
105
- backbone = build_backbone(args)
106
- transformer = build_transformer(args)
107
-
108
- model = DETR(
109
- backbone,
110
- transformer,
111
- num_classes=9,
112
- num_queries=args.num_queries,
113
- aux_loss=args.aux_loss,
114
- )
115
-
116
- checkpoint = torch.load(desam_path, map_location='cpu', weights_only=False)
117
- model.load_state_dict(checkpoint['model'], strict=False)
118
- model.to(device)
119
- model.eval()
120
-
121
- seg_model = SAMModel(device=device, ckpt_path=sam_path)
122
- if 'seg_model' in checkpoint:
123
- seg_model.load_state_dict(checkpoint['seg_model'])
124
- seg_model.to(device)
125
- seg_model.eval()
126
-
127
- print("Models loaded successfully!")
128
-
129
-
130
- def preprocess_frame(frame):
131
- """Preprocess frame for model input"""
132
- img = cv2.resize(frame, (1024, 1024))
133
- img = img.astype(np.float32) / 255.0
134
- mean = np.array([0.485, 0.456, 0.406])
135
- std = np.array([0.229, 0.224, 0.225])
136
- img = (img - mean) / std
137
- img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float()
138
- return img_tensor
139
-
140
-
141
- def box_cxcywh_to_xyxy(x):
142
- """Convert boxes from center format to corner format"""
143
- x_c, y_c, w, h = x.unbind(-1)
144
- b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
145
- (x_c + 0.5 * w), (y_c + 0.5 * h)]
146
- return torch.stack(b, dim=-1)
147
-
148
-
149
- def process_single_frame(frame_rgb, h, w):
150
- """Process a single frame and return segmented result"""
151
- global model, seg_model, device
152
-
153
- img_tensor = preprocess_frame(frame_rgb).unsqueeze(0).to(device)
154
-
155
- mask = torch.zeros((1, 1024, 1024), dtype=torch.bool, device=device)
156
- samples = NestedTensor(img_tensor, mask)
157
-
158
- with torch.no_grad():
159
- outputs, image_embeddings = model(samples)
160
-
161
- probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
162
- keep = probas.max(-1).values > 0.3
163
-
164
- if not keep.any():
165
- return frame_rgb # No detections
166
-
167
- boxes = outputs['pred_boxes'][0, keep]
168
- scores = probas[keep].max(-1).values.cpu().numpy()
169
- labels = probas[keep].argmax(-1).cpu().numpy()
170
-
171
- boxes_scaled = box_cxcywh_to_xyxy(boxes) * torch.tensor([w, h, w, h], device=device)
172
- boxes_np = boxes_scaled.cpu().numpy()
173
-
174
- low_res_masks, pred_masks, _ = seg_model(
175
- img_tensor, boxes, image_embeddings,
176
- sizes=(1024, 1024), add_noise=False
177
- )
178
- masks_np = pred_masks.cpu().numpy()
179
-
180
- # Draw on frame
181
- result = frame_rgb.copy()
182
- for i, (box, label, mask_pred, score) in enumerate(zip(boxes_np, labels, masks_np, scores)):
183
- if score < 0.3:
184
- continue
185
-
186
- color = COLORS[label % len(COLORS)]
187
-
188
- # Draw mask
189
- mask_resized = cv2.resize(mask_pred, (w, h))
190
- mask_bool = mask_resized > 0.5
191
- overlay = result.copy()
192
- overlay[mask_bool] = color
193
- result = cv2.addWeighted(result, 0.6, overlay, 0.4, 0)
194
-
195
- # Draw box
196
- x1, y1, x2, y2 = box.astype(int)
197
- cv2.rectangle(result, (x1, y1), (x2, y2), color, 2)
198
-
199
- # Draw label
200
- label_text = f"{INSTRUMENT_CLASSES[label]}: {score:.2f}"
201
- cv2.putText(result, label_text, (x1, y1 - 10),
202
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
203
-
204
- return result
205
-
206
-
207
- @spaces.GPU
208
- def predict_image(image):
209
- """Run inference on input image"""
210
- global model, seg_model, device
211
-
212
- if model is None:
213
- load_models()
214
-
215
- if image is None:
216
- return None
217
-
218
- frame_rgb = np.array(image)
219
- h, w = frame_rgb.shape[:2]
220
-
221
- result = process_single_frame(frame_rgb, h, w)
222
-
223
- return Image.fromarray(result)
224
-
225
-
226
- @spaces.GPU(duration=300)
227
- def predict_video(video_path, progress=gr.Progress()):
228
- """Process video and return segmented video"""
229
- global model, seg_model, device
230
-
231
- if model is None:
232
- progress(0, desc="Loading models...")
233
- load_models()
234
-
235
- if video_path is None:
236
- return None
237
-
238
- # Open video
239
- cap = cv2.VideoCapture(video_path)
240
- fps = int(cap.get(cv2.CAP_PROP_FPS))
241
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
242
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
243
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
244
-
245
- # Output video
246
- output_path = tempfile.mktemp(suffix=".mp4")
247
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
248
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
249
-
250
- frame_count = 0
251
- while True:
252
- ret, frame = cap.read()
253
- if not ret:
254
- break
255
-
256
- # BGR to RGB
257
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
258
-
259
- # Process frame
260
- result_rgb = process_single_frame(frame_rgb, height, width)
261
-
262
- # RGB to BGR for output
263
- result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
264
- out.write(result_bgr)
265
-
266
- frame_count += 1
267
- progress(frame_count / total_frames, desc=f"Processing frame {frame_count}/{total_frames}")
268
-
269
- cap.release()
270
- out.release()
271
-
272
- return output_path
273
-
274
-
275
- # Create Gradio interface
276
- with gr.Blocks(title="Surgical-DeSAM", theme=gr.themes.Soft()) as demo:
277
- gr.Markdown("# 🔬 Surgical-DeSAM")
278
- gr.Markdown("Segment surgical instruments in images or videos using DeSAM architecture.")
279
-
280
- with gr.Tabs():
281
- # Image Tab
282
- with gr.TabItem("🖼️ Image Segmentation"):
283
- with gr.Row():
284
- with gr.Column():
285
- input_image = gr.Image(type="pil", label="Input Image")
286
- image_btn = gr.Button("Segment Image", variant="primary")
287
- with gr.Column():
288
- output_image = gr.Image(type="pil", label="Segmentation Result")
289
-
290
- image_btn.click(fn=predict_image, inputs=input_image, outputs=output_image)
291
-
292
- gr.Examples(
293
- examples=[
294
- "examples/example_1.png",
295
- "examples/example_2.png",
296
- "examples/example_3.png",
297
- "examples/example_4.png",
298
- ],
299
- inputs=input_image,
300
- label="Example Surgical Images"
301
- )
302
-
303
- # Video Tab
304
- with gr.TabItem("🎬 Video Segmentation"):
305
- with gr.Row():
306
- with gr.Column():
307
- input_video = gr.Video(label="Input Video")
308
- video_btn = gr.Button("Segment Video", variant="primary")
309
- with gr.Column():
310
- output_video = gr.Video(label="Segmentation Result")
311
-
312
- video_btn.click(fn=predict_video, inputs=input_video, outputs=output_video)
313
-
314
- gr.Examples(
315
- examples=["examples/surgical_demo.mp4"],
316
- inputs=input_video,
317
- label="Example Surgical Video"
318
- )
319
-
320
- gr.Markdown("""
321
- ## Detected Classes
322
- Bipolar Forceps | Prograsp Forceps | Large Needle Driver | Monopolar Curved Scissors |
323
- Ultrasound Probe | Suction | Clip Applier | Stapler
324
- """)
325
-
326
- if __name__ == "__main__":
327
- demo.launch()
 
1
+ """
2
+ Surgical-DeSAM Gradio App for Hugging Face Spaces
3
+ Supports both Image and Video segmentation with ZeroGPU
4
+ """
5
+ import os
6
+ import spaces
7
+ import gradio as gr
8
+ import torch
9
+ import numpy as np
10
+ import cv2
11
+ from PIL import Image
12
+ from huggingface_hub import hf_hub_download
13
+ import tempfile
14
+
15
+ # Model imports
16
+ from models.detr_seg import DETR, SAMModel
17
+ from models.backbone import build_backbone
18
+ from models.transformer import build_transformer
19
+ from util.misc import NestedTensor
20
+
21
# Configuration (overridable via environment variables for local testing)
MODEL_REPO = os.environ.get("MODEL_REPO", "IFMedTech/surgical-desam-weights")
HF_TOKEN = os.environ.get("HF_TOKEN")  # required when the weights repo is private

# Instrument classes, index-aligned with the model's predicted label ids.
INSTRUMENT_CLASSES = (
    'bipolar_forceps', 'prograsp_forceps', 'large_needle_driver',
    'monopolar_curved_scissors', 'ultrasound_probe', 'suction',
    'clip_applier', 'stapler'
)

# One RGB color per class, used for drawing masks, boxes and labels.
COLORS = [
    [0, 114, 189], [217, 83, 25], [237, 177, 32],
    [126, 47, 142], [119, 172, 48], [77, 190, 238],
    [162, 20, 47], [76, 76, 76]
]

# Global model variables, populated lazily by load_models() on first request.
model = None
seg_model = None
device = None
41
+
42
+
43
def download_weights():
    """Fetch all model checkpoints from the (possibly private) HF weights repo.

    Returns:
        Tuple of local paths: (DeSAM checkpoint, SAM ViT-B checkpoint). The
        Swin backbone weights are downloaded for their side effect only —
        the backbone builder reads them from ``./swin_backbone``.
    """
    weights_dir = "weights"
    os.makedirs(weights_dir, exist_ok=True)

    def _fetch(filename, target_dir):
        # Every download shares the same repo and auth token.
        return hf_hub_download(
            repo_id=MODEL_REPO,
            filename=filename,
            token=HF_TOKEN,
            local_dir=target_dir,
        )

    desam_path = _fetch("surgical_desam_1024.pth", weights_dir)
    sam_path = _fetch("sam_vit_b_01ec64.pth", weights_dir)

    # Swin backbone weights live in their own directory (see Args.backbone_dir).
    swin_dir = "swin_backbone"
    os.makedirs(swin_dir, exist_ok=True)
    _fetch("swin_base_patch4_window7_224_22kto1k.pth", swin_dir)

    return desam_path, sam_path
72
+
73
+
74
class Args:
    """Namespace standing in for DETR's argparse options during model building."""
    # --- backbone ---
    backbone = 'swin_B_224_22k'
    backbone_dir = './swin_backbone'
    dilation = False
    lr_backbone = 1e-5
    # --- transformer ---
    position_embedding = 'sine'
    hidden_dim = 256
    dropout = 0.1
    nheads = 8
    dim_feedforward = 2048
    enc_layers = 6
    dec_layers = 6
    pre_norm = False
    # --- detection head ---
    num_queries = 100
    aux_loss = False
    masks = False
    # --- misc ---
    dataset_file = 'endovis18'
    device = 'cuda'
93
+
94
+
95
def load_models():
    """Build the DETR detector and SAM segmenter, load checkpoints, and store
    them in the module-level ``model``/``seg_model``/``device`` globals."""
    global model, seg_model, device

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    desam_path, sam_path = download_weights()

    args = Args()
    args.device = str(device)  # override the class default 'cuda' with the real device

    backbone = build_backbone(args)
    transformer = build_transformer(args)

    model = DETR(
        backbone,
        transformer,
        num_classes=9,  # presumably 8 instrument classes + no-object — TODO confirm
        num_queries=args.num_queries,
        aux_loss=args.aux_loss,
    )

    # NOTE(review): weights_only=False unpickles arbitrary objects; acceptable
    # only because the checkpoint comes from our own weights repo.
    checkpoint = torch.load(desam_path, map_location='cpu', weights_only=False)
    # strict=False silently ignores missing/unexpected keys in the state dict.
    model.load_state_dict(checkpoint['model'], strict=False)
    model.to(device)
    model.eval()

    seg_model = SAMModel(device=device, ckpt_path=sam_path)
    # Fine-tuned SAM weights are optional — older checkpoints may lack them.
    if 'seg_model' in checkpoint:
        seg_model.load_state_dict(checkpoint['seg_model'])
    seg_model.to(device)
    seg_model.eval()

    print("Models loaded successfully!")
128
+
129
+
130
def preprocess_frame(frame):
    """Resize an RGB frame to 1024x1024, apply ImageNet normalization, and
    return a CHW float32 tensor ready for the backbone."""
    scaled = cv2.resize(frame, (1024, 1024)).astype(np.float32) / 255.0
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    normalized = (scaled - imagenet_mean) / imagenet_std
    # HWC -> CHW layout expected by torch models.
    return torch.from_numpy(normalized.transpose(2, 0, 1)).float()
139
+
140
+
141
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2) along the last dim."""
    cx, cy, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
    return torch.stack(corners, dim=-1)
147
+
148
+
149
def process_single_frame(frame_rgb, h, w):
    """Detect and segment instruments in one RGB frame.

    Args:
        frame_rgb: RGB frame as a numpy array (HxWx3 — TODO confirm uint8).
        h: original frame height, used to scale boxes/masks back.
        w: original frame width, used to scale boxes/masks back.

    Returns:
        The frame with masks, boxes and class labels drawn on it, or the
        input frame unchanged when nothing scores above the 0.3 threshold.
    """
    global model, seg_model, device

    img_tensor = preprocess_frame(frame_rgb).unsqueeze(0).to(device)

    # DETR expects a NestedTensor; an all-False mask means "no padding".
    mask = torch.zeros((1, 1024, 1024), dtype=torch.bool, device=device)
    samples = NestedTensor(img_tensor, mask)

    with torch.no_grad():
        outputs, image_embeddings = model(samples)

    # Drop the last logit (no-object class) before taking per-query confidence.
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.3

    if not keep.any():
        return frame_rgb  # No detections

    boxes = outputs['pred_boxes'][0, keep]
    scores = probas[keep].max(-1).values.cpu().numpy()
    labels = probas[keep].argmax(-1).cpu().numpy()

    # Scale normalized cxcywh boxes to pixel xyxy for drawing.
    boxes_scaled = box_cxcywh_to_xyxy(boxes) * torch.tensor([w, h, w, h], device=device)
    boxes_np = boxes_scaled.cpu().numpy()

    # NOTE(review): SAM is given the *normalized* boxes, not boxes_scaled —
    # presumably SAMModel rescales internally via sizes; verify in models.detr_seg.
    low_res_masks, pred_masks, _ = seg_model(
        img_tensor, boxes, image_embeddings,
        sizes=(1024, 1024), add_noise=False
    )
    masks_np = pred_masks.cpu().numpy()

    # Draw on frame
    result = frame_rgb.copy()
    for i, (box, label, mask_pred, score) in enumerate(zip(boxes_np, labels, masks_np, scores)):
        if score < 0.3:
            continue

        color = COLORS[label % len(COLORS)]

        # Draw mask: alpha-blend the class color over thresholded mask pixels.
        mask_resized = cv2.resize(mask_pred, (w, h))
        mask_bool = mask_resized > 0.5
        overlay = result.copy()
        overlay[mask_bool] = color
        result = cv2.addWeighted(result, 0.6, overlay, 0.4, 0)

        # Draw box
        x1, y1, x2, y2 = box.astype(int)
        cv2.rectangle(result, (x1, y1), (x2, y2), color, 2)

        # Draw label text just above the box's top-left corner.
        label_text = f"{INSTRUMENT_CLASSES[label]}: {score:.2f}"
        cv2.putText(result, label_text, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    return result
205
+
206
+
207
@spaces.GPU
def predict_image(image):
    """Run instrument segmentation on a single image.

    Args:
        image: input PIL image (RGB), or None when nothing was uploaded.

    Returns:
        PIL image with masks/boxes/labels drawn, or None for no input.
    """
    global model, seg_model, device

    # Check for missing input *before* the expensive lazy model load —
    # the original loaded all checkpoints even when there was no image.
    if image is None:
        return None

    if model is None:
        load_models()

    frame_rgb = np.array(image)
    h, w = frame_rgb.shape[:2]

    result = process_single_frame(frame_rgb, h, w)

    return Image.fromarray(result)
224
+
225
+
226
@spaces.GPU(duration=300)
def predict_video(video_path, progress=gr.Progress()):
    """Segment every frame of a video and return the rendered output path.

    Args:
        video_path: path to the uploaded video file, or None.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        str path to the rendered .mp4, or None when no video is given.
    """
    global model, seg_model, device

    # Check for missing input before paying for the lazy model load.
    if video_path is None:
        return None

    if model is None:
        progress(0, desc="Loading models...")
        load_models()

    # Open video
    cap = cv2.VideoCapture(video_path)
    # Fall back to 25 fps when the container carries no FPS metadata (reads 0).
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # mkstemp instead of the deprecated, race-prone tempfile.mktemp; we only
    # need the path, so close the inherited descriptor immediately.
    fd, output_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    frame_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # BGR to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Process frame
            result_rgb = process_single_frame(frame_rgb, height, width)

            # RGB to BGR for output
            result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
            out.write(result_bgr)

            frame_count += 1
            # max(total_frames, 1) guards ZeroDivisionError when the container
            # reports a frame count of 0.
            progress(frame_count / max(total_frames, 1),
                     desc=f"Processing frame {frame_count}/{total_frames}")
    finally:
        # Release handles even if a frame fails mid-stream.
        cap.release()
        out.release()

    return output_path
273
+
274
+
275
# Create Gradio interface: two tabs (image / video), each wiring one input
# component through its predict_* handler to one output component.
with gr.Blocks(title="Surgical-DeSAM", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔬 Surgical-DeSAM")
    gr.Markdown("Segment surgical instruments in images or videos using DeSAM architecture.")

    with gr.Tabs():
        # Image Tab
        with gr.TabItem("🖼️ Image Segmentation"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type="pil", label="Input Image")
                    image_btn = gr.Button("Segment Image", variant="primary")
                with gr.Column():
                    output_image = gr.Image(type="pil", label="Segmentation Result")

            image_btn.click(fn=predict_image, inputs=input_image, outputs=output_image)

            # Example paths are relative to the Space's working directory.
            gr.Examples(
                examples=[
                    "examples/example_2.png",
                    "examples/example_3.png",
                    "examples/example_4.png",
                ],
                inputs=input_image,
                label="Example Surgical Images"
            )

        # Video Tab
        with gr.TabItem("🎬 Video Segmentation"):
            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(label="Input Video")
                    video_btn = gr.Button("Segment Video", variant="primary")
                with gr.Column():
                    output_video = gr.Video(label="Segmentation Result")

            video_btn.click(fn=predict_video, inputs=input_video, outputs=output_video)

            gr.Examples(
                examples=["examples/surgical_demo.mp4"],
                inputs=input_video,
                label="Example Surgical Video"
            )

    gr.Markdown("""
    ## Detected Classes
    Bipolar Forceps | Prograsp Forceps | Large Needle Driver | Monopolar Curved Scissors |
    Ultrasound Probe | Suction | Clip Applier | Stapler
    """)

# Launch only when run as a script (Spaces imports and runs this module directly).
if __name__ == "__main__":
    demo.launch()