Justin331 committed on
Commit
af5bb79
·
verified ·
1 Parent(s): 4a5b926

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +144 -16
handler.py CHANGED
@@ -68,15 +68,46 @@ class EndpointHandler:
68
  logger.info(f"CUDA Version: {torch.version.cuda}")
69
  logger.info(f"Total GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  # Build SAM3 video predictor
72
  try:
73
  logger.info("Building SAM3 video predictor...")
74
  start_time = time.time()
75
-
76
  # Ensure BPE tokenizer file exists
77
  bpe_path = self._ensure_bpe_file()
78
  logger.info(f"BPE tokenizer path: {bpe_path}")
79
-
80
  # Build predictor with explicit bpe_path
81
  self.predictor = build_sam3_video_predictor(
82
  gpus_to_use=[0],
@@ -87,28 +118,125 @@ class EndpointHandler:
87
  # This fixes: "Input type (c10::BFloat16) and bias type (float) should be the same"
88
  logger.info("Converting model to float32 to avoid dtype mismatch...")
89
 
90
- dtype_conversion_count = 0
 
 
91
 
92
- # SAM3 predictor has a 'model' attribute that contains the actual model
93
- if hasattr(self.predictor, 'model') and self.predictor.model is not None:
94
- # Convert model to float32
95
- self.predictor.model = self.predictor.model.float()
96
 
97
- # Ensure all parameters are float32
98
- for name, param in self.predictor.model.named_parameters():
99
  if param.dtype != torch.float32:
100
  param.data = param.data.float()
101
- dtype_conversion_count += 1
 
102
 
103
- # Convert buffers to float32 (important for batch norm, etc.)
104
- for buffer_name, buffer in self.predictor.model.named_buffers():
105
  if buffer.dtype != torch.float32 and buffer.dtype in [torch.float16, torch.bfloat16]:
106
- self.predictor.model.register_buffer(buffer_name, buffer.float())
107
- dtype_conversion_count += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
- logger.info(f"✓ Model converted to float32 ({dtype_conversion_count} tensors converted)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  else:
111
- logger.warning("⚠ Could not find model attribute in predictor - dtype fix may not have been applied")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  elapsed = time.time() - start_time
114
  logger.info(f"✓ SAM3 video predictor loaded successfully in {elapsed:.2f}s")
 
68
  logger.info(f"CUDA Version: {torch.version.cuda}")
69
  logger.info(f"Total GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
70
 
71
+ # CRITICAL FIX: Patch torch.autocast BEFORE building the predictor
72
+ # SAM3 has @torch.autocast decorators hardcoded to use BFloat16
73
+ # We need to override the autocast context manager to be a no-op
74
+ logger.info("Patching torch.autocast to disable BFloat16 (before model loading)...")
75
+
76
+ # Store the original autocast
77
+ self._original_autocast = torch.autocast
78
+
79
+ # Create a no-op autocast that always disables mixed precision
80
+ class Float32Autocast:
81
+ def __init__(self, device_type, dtype=None, enabled=True):
82
+ # Completely disable autocast
83
+ self.device_type = device_type
84
+ self.dtype = torch.float32
85
+ self.enabled = False
86
+
87
+ def __enter__(self):
88
+ return self
89
+
90
+ def __exit__(self, *args):
91
+ pass
92
+
93
+ # Monkey-patch torch.autocast globally BEFORE importing/building
94
+ torch.autocast = Float32Autocast
95
+ if hasattr(torch.cuda.amp, 'autocast'):
96
+ torch.cuda.amp.autocast = Float32Autocast
97
+ if hasattr(torch.amp, 'autocast'):
98
+ torch.amp.autocast = Float32Autocast
99
+
100
+ logger.info("✓ Patched torch.autocast to be a no-op (forces float32)")
101
+
102
  # Build SAM3 video predictor
103
  try:
104
  logger.info("Building SAM3 video predictor...")
105
  start_time = time.time()
106
+
107
  # Ensure BPE tokenizer file exists
108
  bpe_path = self._ensure_bpe_file()
109
  logger.info(f"BPE tokenizer path: {bpe_path}")
110
+
111
  # Build predictor with explicit bpe_path
112
  self.predictor = build_sam3_video_predictor(
113
  gpus_to_use=[0],
 
118
  # This fixes: "Input type (c10::BFloat16) and bias type (float) should be the same"
119
  logger.info("Converting model to float32 to avoid dtype mismatch...")
120
 
121
+ def convert_model_to_float32(model):
122
+ """Recursively convert all model components to float32."""
123
+ conversion_count = 0
124
 
125
+ # Convert the model itself
126
+ model.float()
 
 
127
 
128
+ # Convert all parameters
129
+ for name, param in model.named_parameters():
130
  if param.dtype != torch.float32:
131
  param.data = param.data.float()
132
+ conversion_count += 1
133
+ logger.debug(f" Converted parameter: {name}")
134
 
135
+ # Convert all buffers (batch norm running stats, etc.)
136
+ for buffer_name, buffer in model.named_buffers():
137
  if buffer.dtype != torch.float32 and buffer.dtype in [torch.float16, torch.bfloat16]:
138
+ model.register_buffer(buffer_name, buffer.float())
139
+ conversion_count += 1
140
+ logger.debug(f" Converted buffer: {buffer_name}")
141
+
142
+ # Also convert submodules explicitly
143
+ for name, module in model.named_modules():
144
+ if module is not model: # Skip the root module
145
+ try:
146
+ module.float()
147
+ except Exception:
148
+ pass # Some modules may not support .float()
149
+
150
+ return conversion_count
151
+
152
+ total_conversions = 0
153
 
154
+ # Convert the main model
155
+ if hasattr(self.predictor, 'model') and self.predictor.model is not None:
156
+ logger.info(" Converting main model...")
157
+ total_conversions += convert_model_to_float32(self.predictor.model)
158
+
159
+ # SAM3 may have additional models (detector, tracker, etc.)
160
+ # Check for other potential model attributes
161
+ for attr_name in ['detector', 'tracker', 'image_encoder', 'text_encoder']:
162
+ if hasattr(self.predictor, attr_name):
163
+ attr = getattr(self.predictor, attr_name)
164
+ if attr is not None and hasattr(attr, 'float'):
165
+ logger.info(f" Converting {attr_name}...")
166
+ try:
167
+ total_conversions += convert_model_to_float32(attr)
168
+ except Exception as e:
169
+ logger.warning(f" Could not convert {attr_name}: {e}")
170
+
171
+ # Check if model has nested models
172
+ if hasattr(self.predictor, 'model') and self.predictor.model is not None:
173
+ model = self.predictor.model
174
+ for attr_name in dir(model):
175
+ if not attr_name.startswith('_'):
176
+ try:
177
+ attr = getattr(model, attr_name)
178
+ if hasattr(attr, 'parameters') and hasattr(attr, 'float'):
179
+ # This looks like a submodel
180
+ if attr_name not in ['model', 'detector', 'tracker']:
181
+ logger.debug(f" Found submodel: {attr_name}")
182
+ try:
183
+ convert_model_to_float32(attr)
184
+ except Exception:
185
+ pass
186
+ except Exception:
187
+ pass
188
+
189
+ if total_conversions > 0:
190
+ logger.info(f"✓ Model converted to float32 ({total_conversions} tensors converted)")
191
  else:
192
+ logger.warning("⚠ No tensors were converted - dtype fix may not have been applied correctly")
193
+
194
+ # Additional safety: Wrap handle_request to ensure inputs are float32
195
+ original_handle_request = self.predictor.handle_request
196
+
197
+ def float32_handle_request(request):
198
+ """Wrapper to ensure all tensor inputs are float32."""
199
+ # Recursively convert any tensors in the request to float32
200
+ def ensure_float32(obj):
201
+ if isinstance(obj, torch.Tensor):
202
+ if obj.dtype in [torch.float16, torch.bfloat16]:
203
+ return obj.float()
204
+ return obj
205
+ elif isinstance(obj, dict):
206
+ return {k: ensure_float32(v) for k, v in obj.items()}
207
+ elif isinstance(obj, (list, tuple)):
208
+ return type(obj)(ensure_float32(item) for item in obj)
209
+ return obj
210
+
211
+ request = ensure_float32(request)
212
+ return original_handle_request(request)
213
+
214
+ self.predictor.handle_request = float32_handle_request
215
+
216
+ # Also wrap handle_stream_request if it exists
217
+ if hasattr(self.predictor, 'handle_stream_request'):
218
+ original_handle_stream_request = self.predictor.handle_stream_request
219
+
220
+ def float32_handle_stream_request(request):
221
+ """Wrapper to ensure all tensor inputs are float32."""
222
+ def ensure_float32(obj):
223
+ if isinstance(obj, torch.Tensor):
224
+ if obj.dtype in [torch.float16, torch.bfloat16]:
225
+ return obj.float()
226
+ return obj
227
+ elif isinstance(obj, dict):
228
+ return {k: ensure_float32(v) for k, v in obj.items()}
229
+ elif isinstance(obj, (list, tuple)):
230
+ return type(obj)(ensure_float32(item) for item in obj)
231
+ return obj
232
+
233
+ request = ensure_float32(request)
234
+ for response in original_handle_stream_request(request):
235
+ yield response
236
+
237
+ self.predictor.handle_stream_request = float32_handle_stream_request
238
+
239
+ logger.info("✓ Added float32 enforcement wrappers to predictor methods")
240
 
241
  elapsed = time.time() - start_time
242
  logger.info(f"✓ SAM3 video predictor loaded successfully in {elapsed:.2f}s")