ttoosi committed on
Commit
f56511f
·
verified ·
1 Parent(s): 64114f8

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +4 -1
inference.py CHANGED
@@ -302,7 +302,8 @@ class GenerativeInferenceModel:
302
  try:
303
  print(f"\n=== Running model integrity check for {model_type} ===")
304
  # Create a deterministic test input directly on the correct device
305
- test_input = torch.zeros(1, 3, 224, 224, device=device)
 
306
  test_input[0, 0, 100:124, 100:124] = 0.5 # Red square
307
 
308
  # Run forward pass
@@ -705,6 +706,8 @@ class GenerativeInferenceModel:
705
  norm_mean=pre["mean"],
706
  norm_std=pre["std"]
707
  )
 
 
708
 
709
  # Special handling for GradModulation as in original
710
  if config['loss_infer'] == 'GradModulation' and 'misc_info' in config and 'grad_modulation' in config['misc_info']:
 
302
  try:
303
  print(f"\n=== Running model integrity check for {model_type} ===")
304
  # Create a deterministic test input directly on the correct device
305
+ H = W = MODEL_PREPROC.get(model_type, {"size": 224})["size"]
306
+ test_input = torch.zeros(1, 3, H, W, device=device)
307
  test_input[0, 0, 100:124, 100:124] = 0.5 # Red square
308
 
309
  # Run forward pass
 
706
  norm_mean=pre["mean"],
707
  norm_std=pre["std"]
708
  )
709
+
710
+ print(f"[PREPROC] {model_type}: size={pre['size']} mean={pre['mean']} std={pre['std']} (transform normalize=False; model has internal normalizer)")
711
 
712
  # Special handling for GradModulation as in original
713
  if config['loss_infer'] == 'GradModulation' and 'misc_info' in config and 'grad_modulation' in config['misc_info']: