Tahereh Toosi commited on
Commit
a653fde
·
1 Parent(s): 89de4e8

biased inferneceimplemented, next up: saving ourput as a gif and the config json

Browse files
Files changed (2) hide show
  1. app.py +29 -2
  2. inference.py +41 -5
app.py CHANGED
@@ -11,7 +11,7 @@ except ImportError:
11
 
12
  import os
13
  import argparse
14
- from inference import GenerativeInferenceModel, get_inference_configs
15
 
16
  # Parse command line arguments
17
  parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
@@ -22,6 +22,9 @@ args = parser.parse_args()
22
  os.makedirs("models", exist_ok=True)
23
  os.makedirs("stimuli", exist_ok=True)
24
 
 
 
 
25
  # Check if running on Hugging Face Spaces
26
  if "SPACE_ID" in os.environ:
27
  default_port = int(os.environ.get("PORT", 7860))
@@ -363,7 +366,8 @@ def run_inference(image, model_type, inference_type, eps_value, num_iterations,
363
  initial_noise=0.05, diffusion_noise=0.3, step_size=0.8, model_layer="layer3",
364
  use_adaptive_eps=False, use_adaptive_step=False,
365
  mask_center_x=0.0, mask_center_y=0.0, mask_radius=0.3, mask_sigma=0.2,
366
- eps_max_mult=4.0, eps_min_mult=1.0, step_max_mult=4.0, step_min_mult=1.0):
 
367
  # Check if image is provided
368
  if image is None:
369
  return None, "Please upload an image before running inference."
@@ -419,6 +423,14 @@ def run_inference(image, model_type, inference_type, eps_value, num_iterations,
419
  else:
420
  config['adaptive_step_size'] = None
421
 
 
 
 
 
 
 
 
 
422
  # Run generative inference
423
  result = model.inference(image, model_type, config)
424
 
@@ -532,6 +544,8 @@ def apply_example(example):
532
  example.get("eps_min_mult", 1.0),
533
  example.get("step_max_mult", 4.0),
534
  example.get("step_min_mult", 1.0),
 
 
535
  mask_img,
536
  gr.Group(visible=True),
537
  ]
@@ -623,6 +637,17 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
623
  with gr.Row():
624
  step_max_mult_slider = gr.Slider(minimum=0.1, maximum=150.0, value=50.0, step=0.1, label="Step size: multiplier at center")
625
  step_min_mult_slider = gr.Slider(minimum=0.1, maximum=10.0, value=0.2, step=0.1, label="Step size: multiplier at periphery")
 
 
 
 
 
 
 
 
 
 
 
626
 
627
  with gr.Column(scale=2):
628
  # Outputs
@@ -655,6 +680,7 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
655
  mask_radius_slider, mask_sigma_slider,
656
  eps_max_mult_slider, eps_min_mult_slider,
657
  step_max_mult_slider, step_min_mult_slider,
 
658
  mask_preview,
659
  params_section,
660
  ],
@@ -686,6 +712,7 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
686
  mask_radius_slider, mask_sigma_slider,
687
  eps_max_mult_slider, eps_min_mult_slider,
688
  step_max_mult_slider, step_min_mult_slider,
 
689
  ],
690
  outputs=[output_image, output_frames]
691
  )
 
11
 
12
  import os
13
  import argparse
14
+ from inference import GenerativeInferenceModel, get_inference_configs, get_imagenet_labels
15
 
16
  # Parse command line arguments
17
  parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
 
22
  os.makedirs("models", exist_ok=True)
23
  os.makedirs("stimuli", exist_ok=True)
24
 
25
+ # Load ImageNet labels for biased-inference dropdown (1000 classes)
26
+ IMAGENET_LABELS = get_imagenet_labels()
27
+
28
  # Check if running on Hugging Face Spaces
29
  if "SPACE_ID" in os.environ:
30
  default_port = int(os.environ.get("PORT", 7860))
 
366
  initial_noise=0.05, diffusion_noise=0.3, step_size=0.8, model_layer="layer3",
367
  use_adaptive_eps=False, use_adaptive_step=False,
368
  mask_center_x=0.0, mask_center_y=0.0, mask_radius=0.3, mask_sigma=0.2,
369
+ eps_max_mult=4.0, eps_min_mult=1.0, step_max_mult=4.0, step_min_mult=1.0,
370
+ use_biased_inference=False, biased_class_name=""):
371
  # Check if image is provided
372
  if image is None:
373
  return None, "Please upload an image before running inference."
 
423
  else:
424
  config['adaptive_step_size'] = None
425
 
426
+ # Biased inference: bias perception toward a target ImageNet class
427
+ use_biased_inference = bool(use_biased_inference) if use_biased_inference is not None else False
428
+ biased_class_name = (biased_class_name or "").strip() if biased_class_name else ""
429
+ if use_biased_inference and biased_class_name:
430
+ config['biased_inference'] = {'enable': True, 'class': biased_class_name}
431
+ else:
432
+ config['biased_inference'] = config.get('biased_inference') or {'enable': False, 'class': None}
433
+
434
  # Run generative inference
435
  result = model.inference(image, model_type, config)
436
 
 
544
  example.get("eps_min_mult", 1.0),
545
  example.get("step_max_mult", 4.0),
546
  example.get("step_min_mult", 1.0),
547
+ example.get("use_biased_inference", False),
548
+ example.get("biased_class_name", ""),
549
  mask_img,
550
  gr.Group(visible=True),
551
  ]
 
637
  with gr.Row():
638
  step_max_mult_slider = gr.Slider(minimum=0.1, maximum=150.0, value=50.0, step=0.1, label="Step size: multiplier at center")
639
  step_min_mult_slider = gr.Slider(minimum=0.1, maximum=10.0, value=0.2, step=0.1, label="Step size: multiplier at periphery")
640
+ gr.Markdown("### 🎯 Biased inference")
641
+ gr.Markdown("Bias the prediction toward a specific ImageNet category (1000 classes).")
642
+ with gr.Row():
643
+ use_biased_inference_check = gr.Checkbox(value=False, label="Use biased inference (bias toward a target class)")
644
+ biased_class_dropdown = gr.Dropdown(
645
+ choices=[("— No bias —", "")] + [(label, label) for label in sorted(IMAGENET_LABELS)],
646
+ value="",
647
+ label="Target class",
648
+ allow_custom_value=False,
649
+ filterable=True,
650
+ )
651
 
652
  with gr.Column(scale=2):
653
  # Outputs
 
680
  mask_radius_slider, mask_sigma_slider,
681
  eps_max_mult_slider, eps_min_mult_slider,
682
  step_max_mult_slider, step_min_mult_slider,
683
+ use_biased_inference_check, biased_class_dropdown,
684
  mask_preview,
685
  params_section,
686
  ],
 
712
  mask_radius_slider, mask_sigma_slider,
713
  eps_max_mult_slider, eps_min_mult_slider,
714
  step_max_mult_slider, step_min_mult_slider,
715
+ use_biased_inference_check, biased_class_dropdown,
716
  ],
717
  outputs=[output_image, output_frames]
718
  )
inference.py CHANGED
@@ -280,12 +280,16 @@ class InferStep:
280
  self.step_size = step_size
281
 
282
  def project(self, x):
283
- # Per-pixel deviation from original; clamp by eps (or eps_map). No add/subtract of mask.
284
- diff = x - self.orig_image
 
 
285
  if self.use_adaptive_eps:
286
- diff = torch.clamp(diff, -self.eps_map, self.eps_map)
 
287
  else:
288
- diff = torch.clamp(diff, -self.eps, self.eps)
 
289
  return torch.clamp(self.orig_image + diff, 0, 1)
290
 
291
  def step(self, x, grad):
@@ -334,7 +338,8 @@ def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50
334
  'inference_normalization': 'off', # 'on' or 'off' (match psychiatry implementation)
335
  'recognition_normalization': 'off',
336
  'iterations_to_show': get_iterations_to_show(n_itr), # Dynamic iterations to visualize
337
- 'misc_info': {'keep_grads': False} # Additional configuration
 
338
  }
339
 
340
  # Customize based on inference type
@@ -829,6 +834,30 @@ class GenerativeInferenceModel:
829
  adaptive_step_config=adaptive_step,
830
  )
831
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
832
  # Storage for inference steps
833
  # Create a new tensor that requires gradients
834
  x = image_tensor.clone().detach().requires_grad_(True)
@@ -934,6 +963,13 @@ class GenerativeInferenceModel:
934
 
935
  grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
936
 
 
 
 
 
 
 
 
937
  if grad is None:
938
  print("Warning: Direct gradient calculation failed")
939
  # Fall back to random perturbation
 
280
  self.step_size = step_size
281
 
282
  def project(self, x):
283
+ # Per-pixel L2 (color) constraint: scale diff so color_norm <= eps (or eps_map).
284
+ # Constraint is per-pixel L2 on RGB (spherical), not L∞ per channel.
285
+ diff = x - self.orig_image # (B, C, H, W)
286
+ color_norm = torch.norm(diff, dim=1, keepdim=True) # (B, 1, H, W)
287
  if self.use_adaptive_eps:
288
+ # eps_map has shape (1, 1, H, W), broadcasts to (B, 1, H, W)
289
+ scale = torch.clamp(self.eps_map / (color_norm + 1e-10), max=1.0)
290
  else:
291
+ scale = torch.clamp(self.eps / (color_norm + 1e-10), max=1.0)
292
+ diff = diff * scale
293
  return torch.clamp(self.orig_image + diff, 0, 1)
294
 
295
  def step(self, x, grad):
 
338
  'inference_normalization': 'off', # 'on' or 'off' (match psychiatry implementation)
339
  'recognition_normalization': 'off',
340
  'iterations_to_show': get_iterations_to_show(n_itr), # Dynamic iterations to visualize
341
+ 'misc_info': {'keep_grads': False}, # Additional configuration
342
+ 'biased_inference': {'enable': False, 'class': None} # Bias perception toward a target class (ImageNet label)
343
  }
344
 
345
  # Customize based on inference type
 
834
  adaptive_step_config=adaptive_step,
835
  )
836
 
837
+ # Biased inference: resolve class name to index (ImageNet simple labels, case-insensitive)
838
+ biased_inference_config = config.get('biased_inference', {'enable': False, 'class': None})
839
+ biased_class_index = None
840
+ biased_class_tensor = None
841
+ if biased_inference_config.get('enable', False):
842
+ class_name = biased_inference_config.get('class') or biased_inference_config.get('class_name')
843
+ if class_name:
844
+ try:
845
+ biased_class_index = next(
846
+ i for i, label in enumerate(self.labels)
847
+ if label.lower() == class_name.lower()
848
+ )
849
+ biased_class_tensor = torch.tensor(
850
+ [biased_class_index], device=device, dtype=torch.long
851
+ )
852
+ print(f"Biased inference: biasing toward class '{self.labels[biased_class_index]}' (index {biased_class_index})")
853
+ except StopIteration:
854
+ raise ValueError(
855
+ f"biased_inference class '{class_name}' not found in ImageNet simple labels. "
856
+ "Use a label from imagenet-simple-labels (e.g. 'goldfish', 'tabby cat')."
857
+ )
858
+ else:
859
+ print("Biased inference enabled but no class specified; ignoring.")
860
+
861
  # Storage for inference steps
862
  # Create a new tensor that requires gradients
863
  x = image_tensor.clone().detach().requires_grad_(True)
 
963
 
964
  grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
965
 
966
+ # Biased inference: subtract gradient that pushes toward target class
967
+ if biased_inference_config.get('enable', False) and biased_class_tensor is not None:
968
+ output_full = model(x)
969
+ loss_biased = F.cross_entropy(output_full, biased_class_tensor)
970
+ grad_biased = torch.autograd.grad(loss_biased, x)[0]
971
+ grad = grad - grad_biased
972
+
973
  if grad is None:
974
  print("Warning: Direct gradient calculation failed")
975
  # Fall back to random perturbation