Spaces:
Running on Zero
Running on Zero
Tahereh Toosi commited on
Commit ·
f45a5a8
1
Parent(s): a653fde
added saving rady to be deployed
Browse files
app.py
CHANGED
|
@@ -10,7 +10,10 @@ except ImportError:
|
|
| 10 |
return func
|
| 11 |
|
| 12 |
import os
|
|
|
|
|
|
|
| 13 |
import argparse
|
|
|
|
| 14 |
from inference import GenerativeInferenceModel, get_inference_configs, get_imagenet_labels
|
| 15 |
|
| 16 |
# Parse command line arguments
|
|
@@ -21,6 +24,8 @@ args = parser.parse_args()
|
|
| 21 |
# Create model directories if they don't exist
|
| 22 |
os.makedirs("models", exist_ok=True)
|
| 23 |
os.makedirs("stimuli", exist_ok=True)
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# Load ImageNet labels for biased-inference dropdown (1000 classes)
|
| 26 |
IMAGENET_LABELS = get_imagenet_labels()
|
|
@@ -361,16 +366,68 @@ examples = [
|
|
| 361 |
}
|
| 362 |
]
|
| 363 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
@GPU
|
| 365 |
def run_inference(image, model_type, inference_type, eps_value, num_iterations,
|
| 366 |
initial_noise=0.05, diffusion_noise=0.3, step_size=0.8, model_layer="layer3",
|
| 367 |
use_adaptive_eps=False, use_adaptive_step=False,
|
| 368 |
mask_center_x=0.0, mask_center_y=0.0, mask_radius=0.3, mask_sigma=0.2,
|
| 369 |
eps_max_mult=4.0, eps_min_mult=1.0, step_max_mult=4.0, step_min_mult=1.0,
|
| 370 |
-
use_biased_inference=False, biased_class_name=""
|
|
|
|
| 371 |
# Check if image is provided
|
| 372 |
if image is None:
|
| 373 |
-
return None, "Please upload an image before running inference."
|
| 374 |
|
| 375 |
# Convert eps to float
|
| 376 |
eps = float(eps_value)
|
|
@@ -453,8 +510,36 @@ def run_inference(image, model_type, inference_type, eps_value, num_iterations,
|
|
| 453 |
# Convert the final output image to PIL
|
| 454 |
final_image = Image.fromarray((output_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
|
| 455 |
|
| 456 |
-
#
|
| 457 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
|
| 459 |
def _image_to_pil(img):
|
| 460 |
"""Convert Gradio image value (PIL, numpy, path, or dict) to PIL Image; return None if invalid."""
|
|
@@ -546,6 +631,7 @@ def apply_example(example):
|
|
| 546 |
example.get("step_min_mult", 1.0),
|
| 547 |
example.get("use_biased_inference", False),
|
| 548 |
example.get("biased_class_name", ""),
|
|
|
|
| 549 |
mask_img,
|
| 550 |
gr.Group(visible=True),
|
| 551 |
]
|
|
@@ -570,11 +656,14 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
|
|
| 570 |
2. **Click "Run Generative Inference"** to predict what hallucination humans may perceive
|
| 571 |
3. **View the prediction**: Watch as the model reveals the perceptual structures it expects—matching what humans typically hallucinate
|
| 572 |
4. **You can upload your own images**
|
|
|
|
| 573 |
""")
|
| 574 |
with gr.Row():
|
| 575 |
with gr.Column(scale=1):
|
| 576 |
-
# Inputs
|
| 577 |
-
|
|
|
|
|
|
|
| 578 |
mask_preview = gr.Image(
|
| 579 |
label="Mask center preview (click to set center — circle shows mask)",
|
| 580 |
type="pil",
|
|
@@ -644,15 +733,16 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
|
|
| 644 |
biased_class_dropdown = gr.Dropdown(
|
| 645 |
choices=[("— No bias —", "")] + [(label, label) for label in sorted(IMAGENET_LABELS)],
|
| 646 |
value="",
|
| 647 |
-
label="
|
| 648 |
allow_custom_value=False,
|
| 649 |
filterable=True,
|
| 650 |
)
|
| 651 |
-
|
| 652 |
with gr.Column(scale=2):
|
| 653 |
# Outputs
|
| 654 |
output_image = gr.Image(label="Predicted Hallucination")
|
| 655 |
output_frames = gr.Gallery(label="Hallucination Prediction Process", columns=5, rows=2)
|
|
|
|
|
|
|
| 656 |
|
| 657 |
# Examples section with integrated explanations
|
| 658 |
gr.Markdown("## Examples")
|
|
@@ -681,6 +771,7 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
|
|
| 681 |
eps_max_mult_slider, eps_min_mult_slider,
|
| 682 |
step_max_mult_slider, step_min_mult_slider,
|
| 683 |
use_biased_inference_check, biased_class_dropdown,
|
|
|
|
| 684 |
mask_preview,
|
| 685 |
params_section,
|
| 686 |
],
|
|
@@ -689,7 +780,8 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
|
|
| 689 |
# Right column for the explanation
|
| 690 |
with gr.Column(scale=2):
|
| 691 |
gr.Markdown(f"### {ex['name']}")
|
| 692 |
-
|
|
|
|
| 693 |
|
| 694 |
# Show instructions if they exist
|
| 695 |
if "instructions" in ex:
|
|
@@ -713,8 +805,9 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
|
|
| 713 |
eps_max_mult_slider, eps_min_mult_slider,
|
| 714 |
step_max_mult_slider, step_min_mult_slider,
|
| 715 |
use_biased_inference_check, biased_class_dropdown,
|
|
|
|
| 716 |
],
|
| 717 |
-
outputs=[output_image, output_frames]
|
| 718 |
)
|
| 719 |
|
| 720 |
# Toggle parameters visibility
|
|
@@ -744,6 +837,12 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
|
|
| 744 |
inputs=_mask_preview_inputs(),
|
| 745 |
outputs=[mask_preview],
|
| 746 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 747 |
mask_center_x_slider.change(
|
| 748 |
fn=draw_mask_overlay,
|
| 749 |
inputs=_mask_preview_inputs(),
|
|
|
|
| 10 |
return func
|
| 11 |
|
| 12 |
import os
|
| 13 |
+
import re
|
| 14 |
+
import json
|
| 15 |
import argparse
|
| 16 |
+
from datetime import datetime
|
| 17 |
from inference import GenerativeInferenceModel, get_inference_configs, get_imagenet_labels
|
| 18 |
|
| 19 |
# Parse command line arguments
|
|
|
|
| 24 |
# Create model directories if they don't exist
|
| 25 |
os.makedirs("models", exist_ok=True)
|
| 26 |
os.makedirs("stimuli", exist_ok=True)
|
| 27 |
+
SAVED_RUNS_DIR = "saved_runs"
|
| 28 |
+
os.makedirs(SAVED_RUNS_DIR, exist_ok=True)
|
| 29 |
|
| 30 |
# Load ImageNet labels for biased-inference dropdown (1000 classes)
|
| 31 |
IMAGENET_LABELS = get_imagenet_labels()
|
|
|
|
| 366 |
}
|
| 367 |
]
|
| 368 |
|
| 369 |
+
def _input_image_stem(image):
|
| 370 |
+
"""Return a safe filename stem from the input image: known name or 'user_img'."""
|
| 371 |
+
if image is None:
|
| 372 |
+
return "user_img"
|
| 373 |
+
path = None
|
| 374 |
+
if isinstance(image, str) and (os.path.isfile(image) or os.path.exists(image)):
|
| 375 |
+
path = image
|
| 376 |
+
if isinstance(image, dict) and image.get("path") and os.path.exists(image.get("path", "")):
|
| 377 |
+
path = image["path"]
|
| 378 |
+
if path:
|
| 379 |
+
name = os.path.splitext(os.path.basename(path))[0]
|
| 380 |
+
# Safe for filenames: alphanumeric, underscore, hyphen only; max length
|
| 381 |
+
safe = re.sub(r"[^\w\-]", "_", name).strip("_") or "user_img"
|
| 382 |
+
return safe[:80] if len(safe) > 80 else safe
|
| 383 |
+
return "user_img"
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def _get_image_path_for_stem(img):
|
| 387 |
+
"""Extract file path from Gradio image value (path string, dict with path, or PIL) for stem tracking."""
|
| 388 |
+
if img is None:
|
| 389 |
+
return ""
|
| 390 |
+
if isinstance(img, str) and (os.path.isfile(img) or os.path.exists(img)):
|
| 391 |
+
return img
|
| 392 |
+
if isinstance(img, dict) and img.get("path"):
|
| 393 |
+
p = img["path"]
|
| 394 |
+
if isinstance(p, str) and os.path.exists(p):
|
| 395 |
+
return p
|
| 396 |
+
return ""
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def _update_tracked_image_path(img):
|
| 400 |
+
"""Keep path only when it's a known stimulus (e.g. from stimuli/); else '' so stem is 'user_img'."""
|
| 401 |
+
path = _get_image_path_for_stem(img)
|
| 402 |
+
if path and "stimuli" in path:
|
| 403 |
+
return path
|
| 404 |
+
return ""
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def _config_to_json_serializable(c):
|
| 408 |
+
"""Return a copy of config with only JSON-serializable values."""
|
| 409 |
+
if isinstance(c, dict):
|
| 410 |
+
return {k: _config_to_json_serializable(v) for k, v in c.items()}
|
| 411 |
+
if isinstance(c, (list, tuple)):
|
| 412 |
+
return [_config_to_json_serializable(x) for x in c]
|
| 413 |
+
if isinstance(c, (bool, int, float, str, type(None))):
|
| 414 |
+
return c
|
| 415 |
+
if hasattr(c, "item"): # e.g. numpy scalar
|
| 416 |
+
return c.item()
|
| 417 |
+
return str(c)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
@GPU
|
| 421 |
def run_inference(image, model_type, inference_type, eps_value, num_iterations,
|
| 422 |
initial_noise=0.05, diffusion_noise=0.3, step_size=0.8, model_layer="layer3",
|
| 423 |
use_adaptive_eps=False, use_adaptive_step=False,
|
| 424 |
mask_center_x=0.0, mask_center_y=0.0, mask_radius=0.3, mask_sigma=0.2,
|
| 425 |
eps_max_mult=4.0, eps_min_mult=1.0, step_max_mult=4.0, step_min_mult=1.0,
|
| 426 |
+
use_biased_inference=False, biased_class_name="",
|
| 427 |
+
current_image_path=""):
|
| 428 |
# Check if image is provided
|
| 429 |
if image is None:
|
| 430 |
+
return None, [], "Please upload an image before running inference.", None
|
| 431 |
|
| 432 |
# Convert eps to float
|
| 433 |
eps = float(eps_value)
|
|
|
|
| 510 |
# Convert the final output image to PIL
|
| 511 |
final_image = Image.fromarray((output_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
|
| 512 |
|
| 513 |
+
# Always save GIF and config and offer as downloads (browser will ask where to save)
|
| 514 |
+
save_status = ""
|
| 515 |
+
files_for_download = None
|
| 516 |
+
if frames:
|
| 517 |
+
# Use tracked path when available (e.g. from Load Parameters); else derive from image (PIL loses path)
|
| 518 |
+
stem = _input_image_stem(current_image_path if (current_image_path and current_image_path.strip()) else image)
|
| 519 |
+
unique_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{stem}"
|
| 520 |
+
gif_path = os.path.join(SAVED_RUNS_DIR, f"{unique_id}.gif")
|
| 521 |
+
config_path = os.path.join(SAVED_RUNS_DIR, f"{unique_id}_config.json")
|
| 522 |
+
try:
|
| 523 |
+
frames[0].save(
|
| 524 |
+
gif_path,
|
| 525 |
+
save_all=True,
|
| 526 |
+
append_images=frames[1:],
|
| 527 |
+
loop=0,
|
| 528 |
+
duration=200,
|
| 529 |
+
)
|
| 530 |
+
save_config = {
|
| 531 |
+
"model_type": model_type,
|
| 532 |
+
"input_image_name": stem,
|
| 533 |
+
**_config_to_json_serializable(config),
|
| 534 |
+
}
|
| 535 |
+
with open(config_path, "w") as f:
|
| 536 |
+
json.dump(save_config, f, indent=2)
|
| 537 |
+
files_for_download = [gif_path, config_path]
|
| 538 |
+
save_status = "**Download results** — Use the links below to save the GIF and config to your device (your browser may ask where to save)."
|
| 539 |
+
except Exception as e:
|
| 540 |
+
save_status = f"Save failed: {e}"
|
| 541 |
+
|
| 542 |
+
return final_image, frames, save_status, files_for_download
|
| 543 |
|
| 544 |
def _image_to_pil(img):
|
| 545 |
"""Convert Gradio image value (PIL, numpy, path, or dict) to PIL Image; return None if invalid."""
|
|
|
|
| 631 |
example.get("step_min_mult", 1.0),
|
| 632 |
example.get("use_biased_inference", False),
|
| 633 |
example.get("biased_class_name", ""),
|
| 634 |
+
example["image"], # keep path for save filename (e.g. UrbanOffice1 -> urbanoffice1)
|
| 635 |
mask_img,
|
| 636 |
gr.Group(visible=True),
|
| 637 |
]
|
|
|
|
| 656 |
2. **Click "Run Generative Inference"** to predict what hallucination humans may perceive
|
| 657 |
3. **View the prediction**: Watch as the model reveals the perceptual structures it expects—matching what humans typically hallucinate
|
| 658 |
4. **You can upload your own images**
|
| 659 |
+
5. **You can download the results** as a .gif file together with the configs.json
|
| 660 |
""")
|
| 661 |
with gr.Row():
|
| 662 |
with gr.Column(scale=1):
|
| 663 |
+
# Inputs (track path so save filenames use stimulus name when from example)
|
| 664 |
+
default_image_path = os.path.join("stimuli", "urbanoffice1.jpg")
|
| 665 |
+
image_input = gr.Image(label="Input Image (click to set mask center)", type="pil", value=default_image_path)
|
| 666 |
+
current_image_path_state = gr.State(value=default_image_path)
|
| 667 |
mask_preview = gr.Image(
|
| 668 |
label="Mask center preview (click to set center — circle shows mask)",
|
| 669 |
type="pil",
|
|
|
|
| 733 |
biased_class_dropdown = gr.Dropdown(
|
| 734 |
choices=[("— No bias —", "")] + [(label, label) for label in sorted(IMAGENET_LABELS)],
|
| 735 |
value="",
|
| 736 |
+
label="Biased toward category",
|
| 737 |
allow_custom_value=False,
|
| 738 |
filterable=True,
|
| 739 |
)
|
|
|
|
| 740 |
with gr.Column(scale=2):
|
| 741 |
# Outputs
|
| 742 |
output_image = gr.Image(label="Predicted Hallucination")
|
| 743 |
output_frames = gr.Gallery(label="Hallucination Prediction Process", columns=5, rows=2)
|
| 744 |
+
save_status_md = gr.Markdown(value="")
|
| 745 |
+
download_files = gr.File(label="Download results (GIF + config)", file_count="multiple")
|
| 746 |
|
| 747 |
# Examples section with integrated explanations
|
| 748 |
gr.Markdown("## Examples")
|
|
|
|
| 771 |
eps_max_mult_slider, eps_min_mult_slider,
|
| 772 |
step_max_mult_slider, step_min_mult_slider,
|
| 773 |
use_biased_inference_check, biased_class_dropdown,
|
| 774 |
+
current_image_path_state,
|
| 775 |
mask_preview,
|
| 776 |
params_section,
|
| 777 |
],
|
|
|
|
| 780 |
# Right column for the explanation
|
| 781 |
with gr.Column(scale=2):
|
| 782 |
gr.Markdown(f"### {ex['name']}")
|
| 783 |
+
if ex["name"] not in ("farm1", "ArtGallery1", "UrbanOffice1"):
|
| 784 |
+
gr.Markdown(f"[Read more on Wikipedia]({ex['wiki']})")
|
| 785 |
|
| 786 |
# Show instructions if they exist
|
| 787 |
if "instructions" in ex:
|
|
|
|
| 805 |
eps_max_mult_slider, eps_min_mult_slider,
|
| 806 |
step_max_mult_slider, step_min_mult_slider,
|
| 807 |
use_biased_inference_check, biased_class_dropdown,
|
| 808 |
+
current_image_path_state,
|
| 809 |
],
|
| 810 |
+
outputs=[output_image, output_frames, save_status_md, download_files]
|
| 811 |
)
|
| 812 |
|
| 813 |
# Toggle parameters visibility
|
|
|
|
| 837 |
inputs=_mask_preview_inputs(),
|
| 838 |
outputs=[mask_preview],
|
| 839 |
)
|
| 840 |
+
# Keep tracked path for save filename: known stimulus name or clear so stem becomes 'user_img'
|
| 841 |
+
image_input.change(
|
| 842 |
+
fn=_update_tracked_image_path,
|
| 843 |
+
inputs=[image_input],
|
| 844 |
+
outputs=[current_image_path_state],
|
| 845 |
+
)
|
| 846 |
mask_center_x_slider.change(
|
| 847 |
fn=draw_mask_overlay,
|
| 848 |
inputs=_mask_preview_inputs(),
|