IFMedTechdemo committed on
Commit
56cc62f
·
verified ·
1 Parent(s): ef99afc

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. app.py +136 -78
  3. examples/sample_surgical.png +3 -0
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  models/sam/notebooks/images/groceries.jpg filter=lfs diff=lfs merge=lfs -text
37
  models/sam/notebooks/images/truck.jpg filter=lfs diff=lfs merge=lfs -text
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  models/sam/notebooks/images/groceries.jpg filter=lfs diff=lfs merge=lfs -text
37
  models/sam/notebooks/images/truck.jpg filter=lfs diff=lfs merge=lfs -text
38
+ examples/sample_surgical.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  Surgical-DeSAM Gradio App for Hugging Face Spaces
3
- Uses ZeroGPU for inference
4
  """
5
  import os
6
  import spaces
@@ -10,8 +10,9 @@ import numpy as np
10
  import cv2
11
  from PIL import Image
12
  from huggingface_hub import hf_hub_download
 
13
 
14
- # Model imports (will be copied to hf_space)
15
  from models.detr_seg import DETR, SAMModel
16
  from models.backbone import build_backbone
17
  from models.transformer import build_transformer
@@ -44,7 +45,6 @@ def download_weights():
44
  weights_dir = "weights"
45
  os.makedirs(weights_dir, exist_ok=True)
46
 
47
- # Download DeSAM weights
48
  desam_path = hf_hub_download(
49
  repo_id=MODEL_REPO,
50
  filename="surgical_desam_1024.pth",
@@ -52,7 +52,6 @@ def download_weights():
52
  local_dir=weights_dir
53
  )
54
 
55
- # Download SAM weights
56
  sam_path = hf_hub_download(
57
  repo_id=MODEL_REPO,
58
  filename="sam_vit_b_01ec64.pth",
@@ -60,10 +59,9 @@ def download_weights():
60
  local_dir=weights_dir
61
  )
62
 
63
- # Download Swin backbone
64
  swin_dir = "swin_backbone"
65
  os.makedirs(swin_dir, exist_ok=True)
66
- swin_path = hf_hub_download(
67
  repo_id=MODEL_REPO,
68
  filename="swin_base_patch4_window7_224_22kto1k.pth",
69
  token=HF_TOKEN,
@@ -99,11 +97,8 @@ def load_models():
99
  global model, seg_model, device
100
 
101
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
102
-
103
- # Download weights
104
  desam_path, sam_path = download_weights()
105
 
106
- # Build model
107
  args = Args()
108
  args.device = str(device)
109
 
@@ -113,18 +108,16 @@ def load_models():
113
  model = DETR(
114
  backbone,
115
  transformer,
116
- num_classes=9, # 8 classes + background
117
  num_queries=args.num_queries,
118
  aux_loss=args.aux_loss,
119
  )
120
 
121
- # Load weights
122
  checkpoint = torch.load(desam_path, map_location='cpu')
123
  model.load_state_dict(checkpoint['model'], strict=False)
124
  model.to(device)
125
  model.eval()
126
 
127
- # Load SAM model
128
  seg_model = SAMModel(device=device, ckpt_path=sam_path)
129
  if 'seg_model' in checkpoint:
130
  seg_model.load_state_dict(checkpoint['seg_model'])
@@ -134,18 +127,13 @@ def load_models():
134
  print("Models loaded successfully!")
135
 
136
 
137
- def preprocess_image(image):
138
- """Preprocess image for model input"""
139
- # Resize to 1024x1024
140
- img = cv2.resize(np.array(image), (1024, 1024))
141
  img = img.astype(np.float32) / 255.0
142
-
143
- # Normalize
144
  mean = np.array([0.485, 0.456, 0.406])
145
  std = np.array([0.229, 0.224, 0.225])
146
  img = (img - mean) / std
147
-
148
- # Convert to tensor
149
  img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float()
150
  return img_tensor
151
 
@@ -158,38 +146,66 @@ def box_cxcywh_to_xyxy(x):
158
  return torch.stack(b, dim=-1)
159
 
160
 
161
- def create_visualization(image, boxes, labels, masks, scores):
162
- """Create visualization with boxes and masks"""
163
- img = np.array(image).copy()
164
- h, w = img.shape[:2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
- for i, (box, label, mask, score) in enumerate(zip(boxes, labels, masks, scores)):
 
 
167
  if score < 0.3:
168
  continue
169
-
170
  color = COLORS[label % len(COLORS)]
171
 
172
  # Draw mask
173
- mask_resized = cv2.resize(mask, (w, h))
174
  mask_bool = mask_resized > 0.5
175
- overlay = img.copy()
176
  overlay[mask_bool] = color
177
- img = cv2.addWeighted(img, 0.6, overlay, 0.4, 0)
178
 
179
  # Draw box
180
  x1, y1, x2, y2 = box.astype(int)
181
- cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
182
 
183
  # Draw label
184
  label_text = f"{INSTRUMENT_CLASSES[label]}: {score:.2f}"
185
- cv2.putText(img, label_text, (x1, y1 - 10),
186
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
187
 
188
- return Image.fromarray(img)
189
 
190
 
191
  @spaces.GPU
192
- def predict(image):
193
  """Run inference on input image"""
194
  global model, seg_model, device
195
 
@@ -199,66 +215,108 @@ def predict(image):
199
  if image is None:
200
  return None
201
 
202
- # Preprocess
203
- img_tensor = preprocess_image(image).unsqueeze(0).to(device)
204
 
205
- # Create nested tensor
206
- mask = torch.zeros((1, 1024, 1024), dtype=torch.bool, device=device)
207
- samples = NestedTensor(img_tensor, mask)
208
 
209
- # Run detection
210
- with torch.no_grad():
211
- outputs, image_embeddings = model(samples)
212
-
213
- # Get predictions
214
- probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
215
- keep = probas.max(-1).values > 0.3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
- if not keep.any():
218
- return image # No detections
219
 
220
- # Get boxes
221
- boxes = outputs['pred_boxes'][0, keep]
222
- scores = probas[keep].max(-1).values.cpu().numpy()
223
- labels = probas[keep].argmax(-1).cpu().numpy()
224
 
225
- # Scale boxes to image size
226
- h, w = image.size[1], image.size[0]
227
- boxes_scaled = box_cxcywh_to_xyxy(boxes) * torch.tensor([w, h, w, h], device=device)
228
- boxes_np = boxes_scaled.cpu().numpy()
229
 
230
- # Run segmentation
231
- low_res_masks, pred_masks, _ = seg_model(
232
- img_tensor, boxes, image_embeddings,
233
- sizes=(1024, 1024), add_noise=False
234
- )
235
- masks_np = pred_masks.cpu().numpy()
236
 
237
- # Create visualization
238
- result = create_visualization(image, boxes_np, labels, masks_np, scores)
239
 
240
- return result
241
 
242
 
243
  # Create Gradio interface
244
- with gr.Blocks(title="Surgical-DeSAM") as demo:
245
  gr.Markdown("# 🔬 Surgical-DeSAM")
246
- gr.Markdown("Upload a surgical image to segment instruments.")
247
 
248
- with gr.Row():
249
- with gr.Column():
250
- input_image = gr.Image(type="pil", label="Input Image")
251
- submit_btn = gr.Button("Segment", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
- with gr.Column():
254
- output_image = gr.Image(type="pil", label="Segmentation Result")
255
-
256
- submit_btn.click(fn=predict, inputs=input_image, outputs=output_image)
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
- gr.Examples(
259
- examples=[], # Add example images if available
260
- inputs=input_image
261
- )
 
262
 
263
  if __name__ == "__main__":
264
  demo.launch()
 
1
  """
2
  Surgical-DeSAM Gradio App for Hugging Face Spaces
3
+ Supports both Image and Video segmentation with ZeroGPU
4
  """
5
  import os
6
  import spaces
 
10
  import cv2
11
  from PIL import Image
12
  from huggingface_hub import hf_hub_download
13
+ import tempfile
14
 
15
+ # Model imports
16
  from models.detr_seg import DETR, SAMModel
17
  from models.backbone import build_backbone
18
  from models.transformer import build_transformer
 
45
  weights_dir = "weights"
46
  os.makedirs(weights_dir, exist_ok=True)
47
 
 
48
  desam_path = hf_hub_download(
49
  repo_id=MODEL_REPO,
50
  filename="surgical_desam_1024.pth",
 
52
  local_dir=weights_dir
53
  )
54
 
 
55
  sam_path = hf_hub_download(
56
  repo_id=MODEL_REPO,
57
  filename="sam_vit_b_01ec64.pth",
 
59
  local_dir=weights_dir
60
  )
61
 
 
62
  swin_dir = "swin_backbone"
63
  os.makedirs(swin_dir, exist_ok=True)
64
+ hf_hub_download(
65
  repo_id=MODEL_REPO,
66
  filename="swin_base_patch4_window7_224_22kto1k.pth",
67
  token=HF_TOKEN,
 
97
  global model, seg_model, device
98
 
99
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
100
  desam_path, sam_path = download_weights()
101
 
 
102
  args = Args()
103
  args.device = str(device)
104
 
 
108
  model = DETR(
109
  backbone,
110
  transformer,
111
+ num_classes=9,
112
  num_queries=args.num_queries,
113
  aux_loss=args.aux_loss,
114
  )
115
 
 
116
  checkpoint = torch.load(desam_path, map_location='cpu')
117
  model.load_state_dict(checkpoint['model'], strict=False)
118
  model.to(device)
119
  model.eval()
120
 
 
121
  seg_model = SAMModel(device=device, ckpt_path=sam_path)
122
  if 'seg_model' in checkpoint:
123
  seg_model.load_state_dict(checkpoint['seg_model'])
 
127
  print("Models loaded successfully!")
128
 
129
 
130
def preprocess_frame(frame):
    """Resize, normalize, and convert one RGB frame to a CHW float tensor.

    Args:
        frame: HxWxC image array (uint8 expected). Grayscale (HxW) and
            RGBA (HxWx4) inputs are coerced to 3-channel RGB first so the
            ImageNet normalization below broadcasts correctly instead of
            raising a shape error.

    Returns:
        torch.FloatTensor of shape (3, 1024, 1024), ImageNet-normalized.
    """
    arr = np.asarray(frame)
    # Robustness: uploads may arrive as grayscale or carry an alpha channel.
    if arr.ndim == 2:
        arr = np.stack([arr] * 3, axis=-1)
    elif arr.shape[2] == 4:
        arr = arr[:, :, :3]

    # Model operates at a fixed 1024x1024 resolution.
    img = cv2.resize(arr, (1024, 1024))
    img = img.astype(np.float32) / 255.0

    # ImageNet channel statistics (standard for the pretrained backbone).
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # HWC -> CHW layout for PyTorch.
    return torch.from_numpy(img.transpose(2, 0, 1)).float()
139
 
 
146
  return torch.stack(b, dim=-1)
147
 
148
 
149
def process_single_frame(frame_rgb, h, w):
    """Detect and segment surgical instruments in one RGB frame.

    Args:
        frame_rgb: HxWx3 uint8 RGB image array.
        h, w: original frame height/width, used to scale boxes and masks
            back from model space to frame space.

    Returns:
        HxWx3 uint8 RGB array with mask overlays, boxes, and class labels
        drawn; the input frame unchanged when nothing scores above threshold.
    """
    global model, seg_model, device

    score_threshold = 0.3

    img_tensor = preprocess_frame(frame_rgb).unsqueeze(0).to(device)

    # DETR consumes a NestedTensor; an all-False mask means "no padding".
    padding_mask = torch.zeros((1, 1024, 1024), dtype=torch.bool, device=device)
    samples = NestedTensor(img_tensor, padding_mask)

    # Keep the whole inference pass (detection AND segmentation) under
    # no_grad — the original only wrapped the detector, so the SAM call
    # needlessly built an autograd graph.
    with torch.no_grad():
        outputs, image_embeddings = model(samples)

        # Drop the trailing "no object" logit; keep confident queries only.
        probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
        keep = probas.max(-1).values > score_threshold

        if not keep.any():
            return frame_rgb  # No detections: return the frame untouched.

        boxes = outputs['pred_boxes'][0, keep]
        scores = probas[keep].max(-1).values.cpu().numpy()
        labels = probas[keep].argmax(-1).cpu().numpy()

        # Normalized cxcywh -> pixel-space xyxy.
        boxes_scaled = box_cxcywh_to_xyxy(boxes) * torch.tensor([w, h, w, h], device=device)
        boxes_np = boxes_scaled.cpu().numpy()

        # SAM segmentation prompted by the detected boxes.
        # (Low-res masks are returned but unused here.)
        _, pred_masks, _ = seg_model(
            img_tensor, boxes, image_embeddings,
            sizes=(1024, 1024), add_noise=False
        )
        masks_np = pred_masks.cpu().numpy()

    # Draw masks, boxes, and labels on a copy of the frame.
    # NOTE: the per-item `score < 0.3` re-check from the original is gone —
    # `keep` already guarantees every surviving score exceeds the threshold.
    result = frame_rgb.copy()
    for box, label, mask_pred, score in zip(boxes_np, labels, masks_np, scores):
        color = COLORS[label % len(COLORS)]

        # Semi-transparent mask overlay.
        mask_bool = cv2.resize(mask_pred, (w, h)) > 0.5
        overlay = result.copy()
        overlay[mask_bool] = color
        result = cv2.addWeighted(result, 0.6, overlay, 0.4, 0)

        # Bounding box.
        x1, y1, x2, y2 = box.astype(int)
        cv2.rectangle(result, (x1, y1), (x2, y2), color, 2)

        # Class name + confidence above the box.
        label_text = f"{INSTRUMENT_CLASSES[label]}: {score:.2f}"
        cv2.putText(result, label_text, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    return result
205
 
206
 
207
@spaces.GPU
def predict_image(image):
    """Run inference on input image.

    Args:
        image: PIL.Image from the Gradio image widget, or None.

    Returns:
        PIL.Image with detections drawn, or None when no image was given.
    """
    global model, seg_model, device

    # Lazy first-use model load (same pattern as predict_video).
    # NOTE(review): these guard lines were hidden diff context — confirm
    # against the full file.
    if model is None:
        load_models()

    if image is None:
        return None

    # Force 3-channel RGB: Gradio/PIL can deliver RGBA or grayscale uploads,
    # which would break the HxWx3 normalization downstream.
    frame_rgb = np.array(image.convert("RGB"))
    h, w = frame_rgb.shape[:2]

    result = process_single_frame(frame_rgb, h, w)

    return Image.fromarray(result)
224
+
225
+
226
@spaces.GPU(duration=300)
def predict_video(video_path, progress=gr.Progress()):
    """Process video and return segmented video.

    Args:
        video_path: path to the uploaded video file, or None.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Path to the annotated .mp4, or None when no video was given or the
        file could not be opened.
    """
    global model, seg_model, device

    if model is None:
        progress(0, desc="Loading models...")
        load_models()

    if video_path is None:
        return None

    # Open video
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None

    # Guard against containers that report bogus metadata: fps==0 would
    # produce an unplayable file, total_frames<=0 would divide by zero below.
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 1)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # mkstemp instead of the deprecated, race-prone tempfile.mktemp.
    fd, output_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    frame_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes BGR; the model pipeline works in RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Process frame
            result_rgb = process_single_frame(frame_rgb, height, width)

            # Back to BGR for the VideoWriter.
            out.write(cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR))

            frame_count += 1
            progress(frame_count / total_frames,
                     desc=f"Processing frame {frame_count}/{total_frames}")
    finally:
        # Always release handles, even if a frame fails mid-stream
        # (the original leaked both on any exception).
        cap.release()
        out.release()

    return output_path
273
 
274
 
275
# Create Gradio interface
with gr.Blocks(title="Surgical-DeSAM", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔬 Surgical-DeSAM")
    gr.Markdown("Segment surgical instruments in images or videos using DeSAM architecture.")

    # Wire up example widgets only for sample files actually shipped
    # with the Space.
    _image_examples = (
        ["examples/sample_surgical.png"]
        if os.path.exists("examples/sample_surgical.png")
        else []
    )
    _video_examples = (
        ["examples/demo_surgical.mp4"]
        if os.path.exists("examples/demo_surgical.mp4")
        else []
    )

    with gr.Tabs():
        # Image Tab
        with gr.TabItem("🖼️ Image Segmentation"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type="pil", label="Input Image")
                    image_btn = gr.Button("Segment Image", variant="primary")
                with gr.Column():
                    output_image = gr.Image(type="pil", label="Segmentation Result")

            image_btn.click(fn=predict_image, inputs=input_image, outputs=output_image)

            gr.Examples(
                examples=_image_examples,
                inputs=input_image,
                label="Example Images"
            )

        # Video Tab
        with gr.TabItem("🎬 Video Segmentation"):
            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(label="Input Video")
                    video_btn = gr.Button("Segment Video", variant="primary")
                with gr.Column():
                    output_video = gr.Video(label="Segmentation Result")

            video_btn.click(fn=predict_video, inputs=input_video, outputs=output_video)

            gr.Examples(
                examples=_video_examples,
                inputs=input_video,
                label="Example Videos"
            )

    gr.Markdown("""
    ## Detected Classes
    Bipolar Forceps | Prograsp Forceps | Large Needle Driver | Monopolar Curved Scissors |
    Ultrasound Probe | Suction | Clip Applier | Stapler
    """)

if __name__ == "__main__":
    demo.launch()
examples/sample_surgical.png ADDED

Git LFS Details

  • SHA256: ad9e495ef59e6c67e2dd60aa8fa7d758714e709f4ac5d71dacae90b8f79f275f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB