AI Agent commited on
Commit
273261c
·
1 Parent(s): 599b438

Switch to model.half() (float16) for native T4 acceleration with correct Meta checkpoint loading

Browse files
Files changed (1) hide show
  1. app.py +6 -8
app.py CHANGED
@@ -86,17 +86,15 @@ processor = None
86
  if model_installed:
87
  print(f"Loading SAM 3 onto {device}...")
88
  model = build_sam3_image_model(checkpoint_path=ckpt_path)
89
- # Let model stay in its NATIVE dtype (bfloat16 from checkpoint).
90
- # Our F.linear/F.conv2d interceptors handle dtype mismatches dynamically.
91
- # DO NOT cast to float32 — it was causing zero mask outputs!
92
 
93
- # Diagnostic: print parameter dtypes to verify checkpoint loaded correctly
94
- param_dtypes = set()
95
- for name, p in model.named_parameters():
96
- param_dtypes.add(str(p.dtype))
97
- print(f"Model parameter dtypes: {param_dtypes}", flush=True)
98
  total_params = sum(p.numel() for p in model.parameters())
99
  print(f"Total parameters: {total_params:,}", flush=True)
 
 
100
 
101
  processor = Sam3Processor(model)
102
  if not torch.cuda.is_available():
 
86
  if model_installed:
87
  print(f"Loading SAM 3 onto {device}...")
88
  model = build_sam3_image_model(checkpoint_path=ckpt_path)
89
+ # Cast to float16 T4 has native float16 Tensor Core acceleration.
90
+ # bfloat16 hangs (software emulated on Turing), float32 produced zero masks.
91
+ model.half()
92
 
93
+ # Diagnostic: verify checkpoint loaded correctly
 
 
 
 
94
  total_params = sum(p.numel() for p in model.parameters())
95
  print(f"Total parameters: {total_params:,}", flush=True)
96
+ sample_dtype = next(model.parameters()).dtype
97
+ print(f"Model dtype: {sample_dtype}", flush=True)
98
 
99
  processor = Sam3Processor(model)
100
  if not torch.cuda.is_available():