AI Agent commited on
Commit
599b438
·
1 Parent(s): a760beb

DIAGNOSTIC: Remove float32 cast, let model run in native bfloat16 with interceptors. Print param dtypes

Browse files
Files changed (1) hide show
  1. app.py +11 -3
app.py CHANGED
@@ -86,9 +86,17 @@ processor = None
86
  if model_installed:
87
  print(f"Loading SAM 3 onto {device}...")
88
  model = build_sam3_image_model(checkpoint_path=ckpt_path)
89
- # NOTE: build_sam3_image_model() internally loads weights, remaps keys,
90
- # and calls model.cuda().eval(). We just need to cast to float32 for T4.
91
- model.to(torch.float32)
 
 
 
 
 
 
 
 
92
 
93
  processor = Sam3Processor(model)
94
  if not torch.cuda.is_available():
 
86
  if model_installed:
87
  print(f"Loading SAM 3 onto {device}...")
88
  model = build_sam3_image_model(checkpoint_path=ckpt_path)
89
+ # Let model stay in its NATIVE dtype (bfloat16 from checkpoint).
90
+ # Our F.linear/F.conv2d interceptors handle dtype mismatches dynamically.
91
+ # DO NOT cast to float32 — it was causing zero mask outputs!
92
+
93
+ # Diagnostic: print parameter dtypes to verify checkpoint loaded correctly
94
+ param_dtypes = set()
95
+ for name, p in model.named_parameters():
96
+ param_dtypes.add(str(p.dtype))
97
+ print(f"Model parameter dtypes: {param_dtypes}", flush=True)
98
+ total_params = sum(p.numel() for p in model.parameters())
99
+ print(f"Total parameters: {total_params:,}", flush=True)
100
 
101
  processor = Sam3Processor(model)
102
  if not torch.cuda.is_available():