import gradio as gr
import torch
import numpy as np
from PIL import Image
try:
    from spaces import GPU
except ImportError:
    # Define a no-op decorator if running locally
    def GPU(func):
        return func
import os
import argparse
from inference import GenerativeInferenceModel, get_inference_configs
# Create model directories if they don't exist
os.makedirs("models", exist_ok=True)
os.makedirs("stimuli", exist_ok=True)
# Pick the default port: on Hugging Face Spaces (detected via the SPACE_ID
# environment variable) use the provided PORT, otherwise use a local default.
if "SPACE_ID" in os.environ:
    default_port = int(os.environ.get("PORT", 7860))
else:
    default_port = 8861  # Local default port
# Parse command line arguments
parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
parser.add_argument('--port', type=int, default=default_port, help='Port to run the server on')
args = parser.parse_args()
# Initialize model
model = GenerativeInferenceModel()
@GPU
def run_inference(image, model_type, inference_type, eps_value, num_iterations,
                  step_size, initial_noise=0.05, step_noise=0.01, model_layer="all"):
    # Convert eps to float
    eps = float(eps_value)
    # Load inference configuration based on the selected type
    config = get_inference_configs(inference_type=inference_type, eps=eps, n_itr=int(num_iterations), step_size=float(step_size))
    # Handle ReverseDiffusion-specific parameters
    if inference_type == "ReverseDiffusion":
        config['initial_inference_noise_ratio'] = float(initial_noise)
        config['diffusion_noise_ratio'] = float(step_noise)
        config['top_layer'] = model_layer
    # Run generative inference
    result = model.inference(image, model_type, config)
    # Extract results based on return type
    if isinstance(result, tuple):
        # Old format returning (output_image, all_steps)
        output_image, all_steps = result
    else:
        # New format returning dictionary
        output_image = result['final_image']
        all_steps = result['steps']
    # Create animation frames
    frames = []
    for step_image in all_steps:
        # Convert tensor to PIL image
        step_pil = Image.fromarray((step_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
        frames.append(step_pil)
    # Convert the final output image to PIL
    final_image = Image.fromarray((output_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
    # Return the final inferred image and the animation frames directly
    return final_image, frames
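# Hypothetical standalone usage (outside the Gradio UI), assuming the bundled
# example stimuli are available locally:
#   img = Image.open(os.path.join("stimuli", "Kanizsa_square.jpg"))
#   final_img, step_frames = run_inference(img, "robust_resnet50", "IncreaseConfidence",
#                                          eps_value=0.5, num_iterations=50, step_size=1.0)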
# Define the interface
with gr.Blocks(title="Generative Inference Demo") as demo:
gr.Markdown("# Generative Inference Demo")
gr.Markdown("This demo showcases how neural networks can perceive visual illusions through generative inference.")
with gr.Row():
with gr.Column(scale=1):
# Inputs
image_input = gr.Image(label="Upload Image or Select an Illusion", type="pil")
with gr.Row():
model_choice = gr.Dropdown(
choices=["robust_resnet50", "standard_resnet50"],
value="robust_resnet50",
label="Model"
)
inference_type = gr.Dropdown(
choices=["IncreaseConfidence", "ReverseDiffusion"],
value="IncreaseConfidence",
label="Inference Method"
)
with gr.Row():
eps_slider = gr.Slider(minimum=0.0, maximum=50.0, value=0.5, step=0.1, label="Epsilon (Perturbation Size)")
iterations_slider = gr.Slider(minimum=1, maximum=500, value=50, step=1, label="Number of Iterations")
step_size_slider = gr.Slider(minimum=0.0, maximum=10.0, value=1.0, step=0.1, label="Step Size")
# Additional parameters for ReverseDiffusion that appear conditionally
with gr.Row(visible=False) as diffusion_params:
initial_noise_slider = gr.Slider(minimum=0.0, maximum=0.5, value=0.05, step=0.01,
label="Initial Noise Ratio")
step_noise_slider = gr.Slider(minimum=0.0, maximum=0.2, value=0.01, step=0.01,
label="Per-Step Noise Ratio")
with gr.Row(visible=False) as layer_params:
layer_choice = gr.Dropdown(
choices=["all", "conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4", "avgpool"],
value="all",
label="Model Layer"
)
# Show/hide parameters based on inference type
def toggle_params(inference):
if inference == "ReverseDiffusion":
return gr.update(visible=True), gr.update(visible=True)
else:
return gr.update(visible=False), gr.update(visible=False)
inference_type.change(toggle_params, [inference_type], [diffusion_params, layer_params])
run_button = gr.Button("Run Inference")
with gr.Column(scale=2):
# Outputs
output_image = gr.Image(label="Final Inferred Image")
output_frames = gr.Gallery(label="Inference Steps", columns=4, rows=2)
    # Set up example images with default parameters for all inputs
    examples = [
        # IncreaseConfidence examples
        [os.path.join("stimuli", "Kanizsa_square.jpg"), "robust_resnet50", "IncreaseConfidence",
         0.5, 50, 1.0, 0.05, 0.01, "all"],
        [os.path.join("stimuli", "face_vase.png"), "robust_resnet50", "IncreaseConfidence",
         0.5, 50, 1.0, 0.05, 0.01, "all"],
        [os.path.join("stimuli", "figure_ground.png"), "robust_resnet50", "IncreaseConfidence",
         0.7, 100, 1.0, 0.05, 0.01, "all"],
        # ReverseDiffusion examples with different layers and noise values
        [os.path.join("stimuli", "Neon_Color_Circle.jpg"), "robust_resnet50", "ReverseDiffusion",
         0.3, 80, 0.8, 0.05, 0.01, "all"],
        [os.path.join("stimuli", "Kanizsa_square.jpg"), "robust_resnet50", "ReverseDiffusion",
         0.5, 50, 0.8, 0.1, 0.02, "layer4"],  # Using layer4 (high-level features)
        [os.path.join("stimuli", "face_vase.png"), "robust_resnet50", "ReverseDiffusion",
         0.4, 60, 0.8, 0.15, 0.03, "layer1"]  # Using layer1 (lower-level features)
    ]
    gr.Examples(examples=examples, inputs=[
        image_input, model_choice, inference_type,
        eps_slider, iterations_slider, step_size_slider,
        initial_noise_slider, step_noise_slider, layer_choice
    ])
    # Set up event handler
    run_button.click(
        fn=run_inference,
        inputs=[
            image_input, model_choice, inference_type,
            eps_slider, iterations_slider, step_size_slider,
            initial_noise_slider, step_noise_slider, layer_choice
        ],
        outputs=[output_image, output_frames]
    )
    # Include a description of the technique
    gr.Markdown("""
## About Generative Inference
Generative inference is a technique that reveals how neural networks perceive visual stimuli. This demo offers two methods:
### 1. IncreaseConfidence
Optimizes the input to increase the network's confidence in its least confident predictions. This reveals how the
network perceives contours, figure-ground separation, and other visual phenomena similar to human perception.
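The rough shape of one optimization step is sketched below. This is a simplified illustration only, not the code in `inference.py`; names such as `increase_confidence_step` and `target_classes` are placeholders.

```python
# Illustrative only: one confidence-raising step on image tensor `x`, nudging the
# model's confidence in `target_classes` upward while keeping the perturbation
# inside an epsilon ball around the original image `x_orig`.
import torch

def increase_confidence_step(model, x, x_orig, target_classes, eps, step_size):
    x = x.clone().detach().requires_grad_(True)
    probs = torch.softmax(model(x), dim=1)
    confidence = probs[:, target_classes].sum()     # confidence in the chosen classes
    confidence.backward()
    with torch.no_grad():
        x = x + step_size * x.grad.sign()           # gradient ascent on confidence
        x = x_orig + (x - x_orig).clamp(-eps, eps)  # project back into the eps ball
        return x.clamp(0, 1)                        # keep pixel values in [0, 1]
```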
### 2. ReverseDiffusion
Starts with a noisy version of the image and guides the optimization to match features of the noisy image.
This approach can reveal different aspects of visual processing and is inspired by diffusion models.
When using ReverseDiffusion, additional parameters become available:
- **Initial Noise Ratio**: Controls the amount of noise added to the image at the beginning
- **Per-Step Noise Ratio**: Controls the amount of noise added at each optimization step
- **Model Layer**: Select a specific layer of the ResNet50 model to extract features from:
  - `all`: Use the full model (default)
  - `conv1`: First convolutional layer
  - `bn1`: First batch normalization layer
  - `relu`: First ReLU activation
  - `maxpool`: Max pooling layer
  - `layer1`: First residual block
  - `layer2`: Second residual block
  - `layer3`: Third residual block
  - `layer4`: Fourth residual block
  - `avgpool`: Average pooling layer
Different layers capture different levels of abstraction - earlier layers represent low-level features
like edges and textures, while later layers represent higher-level features and object parts.
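As a rough illustration (again simplified, and not the code in `inference.py`; `reverse_diffusion_step` and `feature_fn` are placeholders), one step of this procedure might look like:

```python
# Illustrative only: one step that pulls the evolving image `x` toward the
# features a chosen layer computed for a noised copy of the input, then adds
# a small amount of fresh noise, loosely mirroring a reverse-diffusion step.
import torch
import torch.nn.functional as F

def reverse_diffusion_step(feature_fn, x, x_noisy, step_size, step_noise):
    # `feature_fn` is assumed to return activations from the selected layer.
    x = x.clone().detach().requires_grad_(True)
    loss = F.mse_loss(feature_fn(x), feature_fn(x_noisy).detach())
    loss.backward()
    with torch.no_grad():
        x = x - step_size * x.grad.sign()         # move toward the target features
        x = x + step_noise * torch.randn_like(x)  # per-step noise injection
        return x.clamp(0, 1)
```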
This demo allows you to:
1. Upload your own images or select from example images
2. Choose between inference methods (IncreaseConfidence or ReverseDiffusion)
3. Select between robust or standard ResNet50 models
4. Adjust parameters like perturbation size (epsilon) and number of iterations
5. For ReverseDiffusion, fine-tune noise levels and select specific model layers
6. Visualize how the perception emerges over time
""")
# Launch the demo with specific settings
if __name__ == "__main__":
print(f"Starting server on port {args.port}")
# Simplified launch parameters
demo.launch(
server_name="0.0.0.0", # Listen on all interfaces
server_port=args.port, # Use the port from command line arguments
share=False,
debug=True
)