prithivMLmods commited on
Commit
57943d6
·
verified ·
1 Parent(s): c53c756

update app

Browse files
Files changed (1) hide show
  1. app.py +43 -44
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import gc
3
  import cv2
4
  import tempfile
5
  import spaces
@@ -17,6 +16,7 @@ from transformers import (
17
  Sam3VideoModel, Sam3VideoProcessor
18
  )
19
 
 
20
  colors.steel_blue = colors.Color(
21
  name="steel_blue",
22
  c50="#EBF3F8",
@@ -79,45 +79,35 @@ class CustomBlueTheme(Soft):
79
 
80
  app_theme = CustomBlueTheme()
81
 
82
- MODEL_CACHE = {}
83
  device = "cuda" if torch.cuda.is_available() else "cpu"
84
- print(f"Using compute device: {device}")
85
 
86
- def clear_vram():
87
- """Forces RAM/VRAM cleanup."""
88
- if MODEL_CACHE:
89
- print("🧹 Cleaning up memory...")
90
- MODEL_CACHE.clear()
91
- gc.collect()
92
- torch.cuda.empty_cache()
93
 
94
- def load_segmentation_model(model_key):
95
- """Lazy loads the specific SAM3 model required."""
96
- if model_key in MODEL_CACHE:
97
- return MODEL_CACHE[model_key]
 
98
 
99
- clear_vram()
100
- print(f"⏳ Loading {model_key}...")
 
 
 
101
 
102
- try:
103
- if model_key == "img_seg_model":
104
- seg_model = Sam3Model.from_pretrained("facebook/sam3").to(device)
105
- seg_processor = Sam3Processor.from_pretrained("facebook/sam3")
106
- MODEL_CACHE[model_key] = (seg_model, seg_processor)
107
-
108
- elif model_key == "vid_seg_model":
109
- vid_model = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
110
- vid_processor = Sam3VideoProcessor.from_pretrained("facebook/sam3")
111
- MODEL_CACHE[model_key] = (vid_model, vid_processor)
112
-
113
- print(f"✅ {model_key} loaded.")
114
- return MODEL_CACHE[model_key]
115
-
116
- except Exception as e:
117
- print(f"❌ Error loading model: {e}")
118
- clear_vram()
119
- raise e
120
 
 
 
121
  def apply_mask_overlay(base_image, mask_data, opacity=0.5):
122
  """Draws segmentation masks on top of an image."""
123
  if isinstance(base_image, np.ndarray):
@@ -162,21 +152,27 @@ def apply_mask_overlay(base_image, mask_data, opacity=0.5):
162
 
163
  return Image.alpha_composite(base_image, composite_layer).convert("RGB")
164
 
 
 
 
165
  @spaces.GPU
166
  def run_image_segmentation(source_img, text_query, conf_thresh=0.5):
 
 
 
167
  if source_img is None or not text_query:
168
  raise gr.Error("Please provide an image and a text prompt.")
169
 
170
  try:
171
- active_model, active_processor = load_segmentation_model("img_seg_model")
172
  pil_image = source_img.convert("RGB")
173
 
174
- model_inputs = active_processor(images=pil_image, text=text_query, return_tensors="pt").to(device)
 
175
 
176
  with torch.no_grad():
177
- inference_output = active_model(**model_inputs)
178
 
179
- processed_results = active_processor.post_process_instance_segmentation(
180
  inference_output,
181
  threshold=conf_thresh,
182
  mask_threshold=0.5,
@@ -202,12 +198,13 @@ def calc_timeout_duration(vid_file, *args):
202
 
203
  @spaces.GPU(duration=calc_timeout_duration)
204
  def run_video_segmentation(source_vid, text_query, frame_limit, time_limit):
 
 
 
205
  if not source_vid or not text_query:
206
  raise gr.Error("Missing video or prompt.")
207
 
208
  try:
209
- active_model, active_processor = load_segmentation_model("vid_seg_model")
210
-
211
  video_cap = cv2.VideoCapture(source_vid)
212
  vid_fps = video_cap.get(cv2.CAP_PROP_FPS)
213
  vid_w = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -222,14 +219,15 @@ def run_video_segmentation(source_vid, text_query, frame_limit, time_limit):
222
  counter += 1
223
  video_cap.release()
224
 
225
- session = active_processor.init_video_session(video=video_frames, inference_device=device, dtype=torch.bfloat16)
226
- session = active_processor.add_text_prompt(inference_session=session, text=text_query)
 
227
 
228
  temp_out_path = tempfile.mktemp(suffix=".mp4")
229
  video_writer = cv2.VideoWriter(temp_out_path, cv2.VideoWriter_fourcc(*'mp4v'), vid_fps, (vid_w, vid_h))
230
 
231
- for model_out in active_model.propagate_in_video_iterator(inference_session=session, max_frame_num_to_track=len(video_frames)):
232
- post_processed = active_processor.postprocess_outputs(session, model_out)
233
  f_idx = model_out.frame_idx
234
  original_pil = Image.fromarray(video_frames[f_idx])
235
 
@@ -248,6 +246,7 @@ def run_video_segmentation(source_vid, text_query, frame_limit, time_limit):
248
  except Exception as e:
249
  return None, f"Error during video processing: {str(e)}"
250
 
 
251
  custom_css="""
252
  #col-container { margin: 0 auto; max-width: 1100px; }
253
  #main-title h1 { font-size: 2.1em !important; }
 
1
  import os
 
2
  import cv2
3
  import tempfile
4
  import spaces
 
16
  Sam3VideoModel, Sam3VideoProcessor
17
  )
18
 
19
+ # --- THEME CONFIGURATION ---
20
  colors.steel_blue = colors.Color(
21
  name="steel_blue",
22
  c50="#EBF3F8",
 
79
 
80
  app_theme = CustomBlueTheme()
81
 
82
+ # --- GLOBAL MODEL LOADING ---
83
  device = "cuda" if torch.cuda.is_available() else "cpu"
84
+ print(f"🖥️ Using compute device: {device}")
85
 
86
+ print("⏳ Loading SAM3 Models permanently into memory...")
 
 
 
 
 
 
87
 
88
+ try:
89
+ # 1. Load Image Segmentation Model
90
+ print(" ... Loading Image Model")
91
+ IMG_MODEL = Sam3Model.from_pretrained("facebook/sam3").to(device)
92
+ IMG_PROCESSOR = Sam3Processor.from_pretrained("facebook/sam3")
93
 
94
+ # 2. Load Video Segmentation Model
95
+ # Using bfloat16 for video to optimize VRAM usage while keeping speed
96
+ print(" ... Loading Video Model")
97
+ VID_MODEL = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
98
+ VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("facebook/sam3")
99
 
100
+ print("✅ All Models loaded successfully!")
101
+
102
+ except Exception as e:
103
+ print(f"❌ CRITICAL ERROR LOADING MODELS: {e}")
104
+ IMG_MODEL = None
105
+ VID_MODEL = None
106
+ IMG_PROCESSOR = None
107
+ VID_PROCESSOR = None
 
 
 
 
 
 
 
 
 
 
108
 
109
+
110
+ # --- UTILS ---
111
  def apply_mask_overlay(base_image, mask_data, opacity=0.5):
112
  """Draws segmentation masks on top of an image."""
113
  if isinstance(base_image, np.ndarray):
 
152
 
153
  return Image.alpha_composite(base_image, composite_layer).convert("RGB")
154
 
155
+
156
+ # --- GPU INFERENCE FUNCTIONS ---
157
+
158
  @spaces.GPU
159
  def run_image_segmentation(source_img, text_query, conf_thresh=0.5):
160
+ if IMG_MODEL is None or IMG_PROCESSOR is None:
161
+ raise gr.Error("Models failed to load on startup. Check logs.")
162
+
163
  if source_img is None or not text_query:
164
  raise gr.Error("Please provide an image and a text prompt.")
165
 
166
  try:
 
167
  pil_image = source_img.convert("RGB")
168
 
169
+ # Models are already on device, just move inputs
170
+ model_inputs = IMG_PROCESSOR(images=pil_image, text=text_query, return_tensors="pt").to(device)
171
 
172
  with torch.no_grad():
173
+ inference_output = IMG_MODEL(**model_inputs)
174
 
175
+ processed_results = IMG_PROCESSOR.post_process_instance_segmentation(
176
  inference_output,
177
  threshold=conf_thresh,
178
  mask_threshold=0.5,
 
198
 
199
  @spaces.GPU(duration=calc_timeout_duration)
200
  def run_video_segmentation(source_vid, text_query, frame_limit, time_limit):
201
+ if VID_MODEL is None or VID_PROCESSOR is None:
202
+ raise gr.Error("Video Models failed to load on startup.")
203
+
204
  if not source_vid or not text_query:
205
  raise gr.Error("Missing video or prompt.")
206
 
207
  try:
 
 
208
  video_cap = cv2.VideoCapture(source_vid)
209
  vid_fps = video_cap.get(cv2.CAP_PROP_FPS)
210
  vid_w = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
 
219
  counter += 1
220
  video_cap.release()
221
 
222
+ # VID_MODEL is already on device in bfloat16
223
+ session = VID_PROCESSOR.init_video_session(video=video_frames, inference_device=device, dtype=torch.bfloat16)
224
+ session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=text_query)
225
 
226
  temp_out_path = tempfile.mktemp(suffix=".mp4")
227
  video_writer = cv2.VideoWriter(temp_out_path, cv2.VideoWriter_fourcc(*'mp4v'), vid_fps, (vid_w, vid_h))
228
 
229
+ for model_out in VID_MODEL.propagate_in_video_iterator(inference_session=session, max_frame_num_to_track=len(video_frames)):
230
+ post_processed = VID_PROCESSOR.postprocess_outputs(session, model_out)
231
  f_idx = model_out.frame_idx
232
  original_pil = Image.fromarray(video_frames[f_idx])
233
 
 
246
  except Exception as e:
247
  return None, f"Error during video processing: {str(e)}"
248
 
249
+ # --- GUI ---
250
  custom_css="""
251
  #col-container { margin: 0 auto; max-width: 1100px; }
252
  #main-title h1 { font-size: 2.1em !important; }