Spaces:
Running on Zero
Running on Zero
Tahereh Toosi committed on
Commit ·
a653fde
1
Parent(s): 89de4e8
biased inference implemented, next up: saving output as a gif and the config json
Browse files
- app.py +29 -2
- inference.py +41 -5
app.py
CHANGED
|
@@ -11,7 +11,7 @@ except ImportError:
|
|
| 11 |
|
| 12 |
import os
|
| 13 |
import argparse
|
| 14 |
-
from inference import GenerativeInferenceModel, get_inference_configs
|
| 15 |
|
| 16 |
# Parse command line arguments
|
| 17 |
parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
|
|
@@ -22,6 +22,9 @@ args = parser.parse_args()
|
|
| 22 |
os.makedirs("models", exist_ok=True)
|
| 23 |
os.makedirs("stimuli", exist_ok=True)
|
| 24 |
|
|
|
|
|
|
|
|
|
|
| 25 |
# Check if running on Hugging Face Spaces
|
| 26 |
if "SPACE_ID" in os.environ:
|
| 27 |
default_port = int(os.environ.get("PORT", 7860))
|
|
@@ -363,7 +366,8 @@ def run_inference(image, model_type, inference_type, eps_value, num_iterations,
|
|
| 363 |
initial_noise=0.05, diffusion_noise=0.3, step_size=0.8, model_layer="layer3",
|
| 364 |
use_adaptive_eps=False, use_adaptive_step=False,
|
| 365 |
mask_center_x=0.0, mask_center_y=0.0, mask_radius=0.3, mask_sigma=0.2,
|
| 366 |
-
eps_max_mult=4.0, eps_min_mult=1.0, step_max_mult=4.0, step_min_mult=1.0
|
|
|
|
| 367 |
# Check if image is provided
|
| 368 |
if image is None:
|
| 369 |
return None, "Please upload an image before running inference."
|
|
@@ -419,6 +423,14 @@ def run_inference(image, model_type, inference_type, eps_value, num_iterations,
|
|
| 419 |
else:
|
| 420 |
config['adaptive_step_size'] = None
|
| 421 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
# Run generative inference
|
| 423 |
result = model.inference(image, model_type, config)
|
| 424 |
|
|
@@ -532,6 +544,8 @@ def apply_example(example):
|
|
| 532 |
example.get("eps_min_mult", 1.0),
|
| 533 |
example.get("step_max_mult", 4.0),
|
| 534 |
example.get("step_min_mult", 1.0),
|
|
|
|
|
|
|
| 535 |
mask_img,
|
| 536 |
gr.Group(visible=True),
|
| 537 |
]
|
|
@@ -623,6 +637,17 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
|
|
| 623 |
with gr.Row():
|
| 624 |
step_max_mult_slider = gr.Slider(minimum=0.1, maximum=150.0, value=50.0, step=0.1, label="Step size: multiplier at center")
|
| 625 |
step_min_mult_slider = gr.Slider(minimum=0.1, maximum=10.0, value=0.2, step=0.1, label="Step size: multiplier at periphery")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 626 |
|
| 627 |
with gr.Column(scale=2):
|
| 628 |
# Outputs
|
|
@@ -655,6 +680,7 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
|
|
| 655 |
mask_radius_slider, mask_sigma_slider,
|
| 656 |
eps_max_mult_slider, eps_min_mult_slider,
|
| 657 |
step_max_mult_slider, step_min_mult_slider,
|
|
|
|
| 658 |
mask_preview,
|
| 659 |
params_section,
|
| 660 |
],
|
|
@@ -686,6 +712,7 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
|
|
| 686 |
mask_radius_slider, mask_sigma_slider,
|
| 687 |
eps_max_mult_slider, eps_min_mult_slider,
|
| 688 |
step_max_mult_slider, step_min_mult_slider,
|
|
|
|
| 689 |
],
|
| 690 |
outputs=[output_image, output_frames]
|
| 691 |
)
|
|
|
|
| 11 |
|
| 12 |
import os
|
| 13 |
import argparse
|
| 14 |
+
from inference import GenerativeInferenceModel, get_inference_configs, get_imagenet_labels
|
| 15 |
|
| 16 |
# Parse command line arguments
|
| 17 |
parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
|
|
|
|
| 22 |
os.makedirs("models", exist_ok=True)
|
| 23 |
os.makedirs("stimuli", exist_ok=True)
|
| 24 |
|
| 25 |
+
# Load ImageNet labels for biased-inference dropdown (1000 classes)
|
| 26 |
+
IMAGENET_LABELS = get_imagenet_labels()
|
| 27 |
+
|
| 28 |
# Check if running on Hugging Face Spaces
|
| 29 |
if "SPACE_ID" in os.environ:
|
| 30 |
default_port = int(os.environ.get("PORT", 7860))
|
|
|
|
| 366 |
initial_noise=0.05, diffusion_noise=0.3, step_size=0.8, model_layer="layer3",
|
| 367 |
use_adaptive_eps=False, use_adaptive_step=False,
|
| 368 |
mask_center_x=0.0, mask_center_y=0.0, mask_radius=0.3, mask_sigma=0.2,
|
| 369 |
+
eps_max_mult=4.0, eps_min_mult=1.0, step_max_mult=4.0, step_min_mult=1.0,
|
| 370 |
+
use_biased_inference=False, biased_class_name=""):
|
| 371 |
# Check if image is provided
|
| 372 |
if image is None:
|
| 373 |
return None, "Please upload an image before running inference."
|
|
|
|
| 423 |
else:
|
| 424 |
config['adaptive_step_size'] = None
|
| 425 |
|
| 426 |
+
# Biased inference: bias perception toward a target ImageNet class
|
| 427 |
+
use_biased_inference = bool(use_biased_inference) if use_biased_inference is not None else False
|
| 428 |
+
biased_class_name = (biased_class_name or "").strip() if biased_class_name else ""
|
| 429 |
+
if use_biased_inference and biased_class_name:
|
| 430 |
+
config['biased_inference'] = {'enable': True, 'class': biased_class_name}
|
| 431 |
+
else:
|
| 432 |
+
config['biased_inference'] = config.get('biased_inference') or {'enable': False, 'class': None}
|
| 433 |
+
|
| 434 |
# Run generative inference
|
| 435 |
result = model.inference(image, model_type, config)
|
| 436 |
|
|
|
|
| 544 |
example.get("eps_min_mult", 1.0),
|
| 545 |
example.get("step_max_mult", 4.0),
|
| 546 |
example.get("step_min_mult", 1.0),
|
| 547 |
+
example.get("use_biased_inference", False),
|
| 548 |
+
example.get("biased_class_name", ""),
|
| 549 |
mask_img,
|
| 550 |
gr.Group(visible=True),
|
| 551 |
]
|
|
|
|
| 637 |
with gr.Row():
|
| 638 |
step_max_mult_slider = gr.Slider(minimum=0.1, maximum=150.0, value=50.0, step=0.1, label="Step size: multiplier at center")
|
| 639 |
step_min_mult_slider = gr.Slider(minimum=0.1, maximum=10.0, value=0.2, step=0.1, label="Step size: multiplier at periphery")
|
| 640 |
+
gr.Markdown("### 🎯 Biased inference")
|
| 641 |
+
gr.Markdown("Bias the prediction toward a specific ImageNet category (1000 classes).")
|
| 642 |
+
with gr.Row():
|
| 643 |
+
use_biased_inference_check = gr.Checkbox(value=False, label="Use biased inference (bias toward a target class)")
|
| 644 |
+
biased_class_dropdown = gr.Dropdown(
|
| 645 |
+
choices=[("— No bias —", "")] + [(label, label) for label in sorted(IMAGENET_LABELS)],
|
| 646 |
+
value="",
|
| 647 |
+
label="Target class",
|
| 648 |
+
allow_custom_value=False,
|
| 649 |
+
filterable=True,
|
| 650 |
+
)
|
| 651 |
|
| 652 |
with gr.Column(scale=2):
|
| 653 |
# Outputs
|
|
|
|
| 680 |
mask_radius_slider, mask_sigma_slider,
|
| 681 |
eps_max_mult_slider, eps_min_mult_slider,
|
| 682 |
step_max_mult_slider, step_min_mult_slider,
|
| 683 |
+
use_biased_inference_check, biased_class_dropdown,
|
| 684 |
mask_preview,
|
| 685 |
params_section,
|
| 686 |
],
|
|
|
|
| 712 |
mask_radius_slider, mask_sigma_slider,
|
| 713 |
eps_max_mult_slider, eps_min_mult_slider,
|
| 714 |
step_max_mult_slider, step_min_mult_slider,
|
| 715 |
+
use_biased_inference_check, biased_class_dropdown,
|
| 716 |
],
|
| 717 |
outputs=[output_image, output_frames]
|
| 718 |
)
|
inference.py
CHANGED
|
@@ -280,12 +280,16 @@ class InferStep:
|
|
| 280 |
self.step_size = step_size
|
| 281 |
|
| 282 |
def project(self, x):
|
| 283 |
-
# Per-pixel
|
| 284 |
-
|
|
|
|
|
|
|
| 285 |
if self.use_adaptive_eps:
|
| 286 |
-
|
|
|
|
| 287 |
else:
|
| 288 |
-
|
|
|
|
| 289 |
return torch.clamp(self.orig_image + diff, 0, 1)
|
| 290 |
|
| 291 |
def step(self, x, grad):
|
|
@@ -334,7 +338,8 @@ def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50
|
|
| 334 |
'inference_normalization': 'off', # 'on' or 'off' (match psychiatry implementation)
|
| 335 |
'recognition_normalization': 'off',
|
| 336 |
'iterations_to_show': get_iterations_to_show(n_itr), # Dynamic iterations to visualize
|
| 337 |
-
'misc_info': {'keep_grads': False} # Additional configuration
|
|
|
|
| 338 |
}
|
| 339 |
|
| 340 |
# Customize based on inference type
|
|
@@ -829,6 +834,30 @@ class GenerativeInferenceModel:
|
|
| 829 |
adaptive_step_config=adaptive_step,
|
| 830 |
)
|
| 831 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 832 |
# Storage for inference steps
|
| 833 |
# Create a new tensor that requires gradients
|
| 834 |
x = image_tensor.clone().detach().requires_grad_(True)
|
|
@@ -934,6 +963,13 @@ class GenerativeInferenceModel:
|
|
| 934 |
|
| 935 |
grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
|
| 936 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 937 |
if grad is None:
|
| 938 |
print("Warning: Direct gradient calculation failed")
|
| 939 |
# Fall back to random perturbation
|
|
|
|
| 280 |
self.step_size = step_size
|
| 281 |
|
| 282 |
def project(self, x):
|
| 283 |
+
# Per-pixel L2 (color) constraint: scale diff so color_norm <= eps (or eps_map).
|
| 284 |
+
# Constraint is per-pixel L2 on RGB (spherical), not L∞ per channel.
|
| 285 |
+
diff = x - self.orig_image # (B, C, H, W)
|
| 286 |
+
color_norm = torch.norm(diff, dim=1, keepdim=True) # (B, 1, H, W)
|
| 287 |
if self.use_adaptive_eps:
|
| 288 |
+
# eps_map has shape (1, 1, H, W), broadcasts to (B, 1, H, W)
|
| 289 |
+
scale = torch.clamp(self.eps_map / (color_norm + 1e-10), max=1.0)
|
| 290 |
else:
|
| 291 |
+
scale = torch.clamp(self.eps / (color_norm + 1e-10), max=1.0)
|
| 292 |
+
diff = diff * scale
|
| 293 |
return torch.clamp(self.orig_image + diff, 0, 1)
|
| 294 |
|
| 295 |
def step(self, x, grad):
|
|
|
|
| 338 |
'inference_normalization': 'off', # 'on' or 'off' (match psychiatry implementation)
|
| 339 |
'recognition_normalization': 'off',
|
| 340 |
'iterations_to_show': get_iterations_to_show(n_itr), # Dynamic iterations to visualize
|
| 341 |
+
'misc_info': {'keep_grads': False}, # Additional configuration
|
| 342 |
+
'biased_inference': {'enable': False, 'class': None} # Bias perception toward a target class (ImageNet label)
|
| 343 |
}
|
| 344 |
|
| 345 |
# Customize based on inference type
|
|
|
|
| 834 |
adaptive_step_config=adaptive_step,
|
| 835 |
)
|
| 836 |
|
| 837 |
+
# Biased inference: resolve class name to index (ImageNet simple labels, case-insensitive)
|
| 838 |
+
biased_inference_config = config.get('biased_inference', {'enable': False, 'class': None})
|
| 839 |
+
biased_class_index = None
|
| 840 |
+
biased_class_tensor = None
|
| 841 |
+
if biased_inference_config.get('enable', False):
|
| 842 |
+
class_name = biased_inference_config.get('class') or biased_inference_config.get('class_name')
|
| 843 |
+
if class_name:
|
| 844 |
+
try:
|
| 845 |
+
biased_class_index = next(
|
| 846 |
+
i for i, label in enumerate(self.labels)
|
| 847 |
+
if label.lower() == class_name.lower()
|
| 848 |
+
)
|
| 849 |
+
biased_class_tensor = torch.tensor(
|
| 850 |
+
[biased_class_index], device=device, dtype=torch.long
|
| 851 |
+
)
|
| 852 |
+
print(f"Biased inference: biasing toward class '{self.labels[biased_class_index]}' (index {biased_class_index})")
|
| 853 |
+
except StopIteration:
|
| 854 |
+
raise ValueError(
|
| 855 |
+
f"biased_inference class '{class_name}' not found in ImageNet simple labels. "
|
| 856 |
+
"Use a label from imagenet-simple-labels (e.g. 'goldfish', 'tabby cat')."
|
| 857 |
+
)
|
| 858 |
+
else:
|
| 859 |
+
print("Biased inference enabled but no class specified; ignoring.")
|
| 860 |
+
|
| 861 |
# Storage for inference steps
|
| 862 |
# Create a new tensor that requires gradients
|
| 863 |
x = image_tensor.clone().detach().requires_grad_(True)
|
|
|
|
| 963 |
|
| 964 |
grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
|
| 965 |
|
| 966 |
+
# Biased inference: subtract gradient that pushes toward target class
|
| 967 |
+
if biased_inference_config.get('enable', False) and biased_class_tensor is not None:
|
| 968 |
+
output_full = model(x)
|
| 969 |
+
loss_biased = F.cross_entropy(output_full, biased_class_tensor)
|
| 970 |
+
grad_biased = torch.autograd.grad(loss_biased, x)[0]
|
| 971 |
+
grad = grad - grad_biased
|
| 972 |
+
|
| 973 |
if grad is None:
|
| 974 |
print("Warning: Direct gradient calculation failed")
|
| 975 |
# Fall back to random perturbation
|