Tahereh Toosi committed on
Commit
f45a5a8
·
1 Parent(s): a653fde

Added saving; ready to be deployed

Browse files
Files changed (1) hide show
  1. app.py +109 -10
app.py CHANGED
@@ -10,7 +10,10 @@ except ImportError:
10
  return func
11
 
12
  import os
 
 
13
  import argparse
 
14
  from inference import GenerativeInferenceModel, get_inference_configs, get_imagenet_labels
15
 
16
  # Parse command line arguments
@@ -21,6 +24,8 @@ args = parser.parse_args()
21
  # Create model directories if they don't exist
22
  os.makedirs("models", exist_ok=True)
23
  os.makedirs("stimuli", exist_ok=True)
 
 
24
 
25
  # Load ImageNet labels for biased-inference dropdown (1000 classes)
26
  IMAGENET_LABELS = get_imagenet_labels()
@@ -361,16 +366,68 @@ examples = [
361
  }
362
  ]
363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
  @GPU
365
  def run_inference(image, model_type, inference_type, eps_value, num_iterations,
366
  initial_noise=0.05, diffusion_noise=0.3, step_size=0.8, model_layer="layer3",
367
  use_adaptive_eps=False, use_adaptive_step=False,
368
  mask_center_x=0.0, mask_center_y=0.0, mask_radius=0.3, mask_sigma=0.2,
369
  eps_max_mult=4.0, eps_min_mult=1.0, step_max_mult=4.0, step_min_mult=1.0,
370
- use_biased_inference=False, biased_class_name=""):
 
371
  # Check if image is provided
372
  if image is None:
373
- return None, "Please upload an image before running inference."
374
 
375
  # Convert eps to float
376
  eps = float(eps_value)
@@ -453,8 +510,36 @@ def run_inference(image, model_type, inference_type, eps_value, num_iterations,
453
  # Convert the final output image to PIL
454
  final_image = Image.fromarray((output_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
455
 
456
- # Return the final inferred image and the animation frames directly
457
- return final_image, frames
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
 
459
  def _image_to_pil(img):
460
  """Convert Gradio image value (PIL, numpy, path, or dict) to PIL Image; return None if invalid."""
@@ -546,6 +631,7 @@ def apply_example(example):
546
  example.get("step_min_mult", 1.0),
547
  example.get("use_biased_inference", False),
548
  example.get("biased_class_name", ""),
 
549
  mask_img,
550
  gr.Group(visible=True),
551
  ]
@@ -570,11 +656,14 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
570
  2. **Click "Run Generative Inference"** to predict what hallucination humans may perceive
571
  3. **View the prediction**: Watch as the model reveals the perceptual structures it expects—matching what humans typically hallucinate
572
  4. **You can upload your own images**
 
573
  """)
574
  with gr.Row():
575
  with gr.Column(scale=1):
576
- # Inputs
577
- image_input = gr.Image(label="Input Image (click to set mask center)", type="pil", value=os.path.join("stimuli", "urbanoffice1.jpg"))
 
 
578
  mask_preview = gr.Image(
579
  label="Mask center preview (click to set center — circle shows mask)",
580
  type="pil",
@@ -644,15 +733,16 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
644
  biased_class_dropdown = gr.Dropdown(
645
  choices=[("— No bias —", "")] + [(label, label) for label in sorted(IMAGENET_LABELS)],
646
  value="",
647
- label="Target class",
648
  allow_custom_value=False,
649
  filterable=True,
650
  )
651
-
652
  with gr.Column(scale=2):
653
  # Outputs
654
  output_image = gr.Image(label="Predicted Hallucination")
655
  output_frames = gr.Gallery(label="Hallucination Prediction Process", columns=5, rows=2)
 
 
656
 
657
  # Examples section with integrated explanations
658
  gr.Markdown("## Examples")
@@ -681,6 +771,7 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
681
  eps_max_mult_slider, eps_min_mult_slider,
682
  step_max_mult_slider, step_min_mult_slider,
683
  use_biased_inference_check, biased_class_dropdown,
 
684
  mask_preview,
685
  params_section,
686
  ],
@@ -689,7 +780,8 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
689
  # Right column for the explanation
690
  with gr.Column(scale=2):
691
  gr.Markdown(f"### {ex['name']}")
692
- gr.Markdown(f"[Read more on Wikipedia]({ex['wiki']})")
 
693
 
694
  # Show instructions if they exist
695
  if "instructions" in ex:
@@ -713,8 +805,9 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
713
  eps_max_mult_slider, eps_min_mult_slider,
714
  step_max_mult_slider, step_min_mult_slider,
715
  use_biased_inference_check, biased_class_dropdown,
 
716
  ],
717
- outputs=[output_image, output_frames]
718
  )
719
 
720
  # Toggle parameters visibility
@@ -744,6 +837,12 @@ with gr.Blocks(title="Human Hallucination Prediction", css="""
744
  inputs=_mask_preview_inputs(),
745
  outputs=[mask_preview],
746
  )
 
 
 
 
 
 
747
  mask_center_x_slider.change(
748
  fn=draw_mask_overlay,
749
  inputs=_mask_preview_inputs(),
 
10
  return func
11
 
12
  import os
13
+ import re
14
+ import json
15
  import argparse
16
+ from datetime import datetime
17
  from inference import GenerativeInferenceModel, get_inference_configs, get_imagenet_labels
18
 
19
  # Parse command line arguments
 
24
  # Create model directories if they don't exist
25
  os.makedirs("models", exist_ok=True)
26
  os.makedirs("stimuli", exist_ok=True)
27
+ SAVED_RUNS_DIR = "saved_runs"
28
+ os.makedirs(SAVED_RUNS_DIR, exist_ok=True)
29
 
30
  # Load ImageNet labels for biased-inference dropdown (1000 classes)
31
  IMAGENET_LABELS = get_imagenet_labels()
 
366
  }
367
  ]
368
 
369
def _input_image_stem(image):
    """Return a filesystem-safe filename stem for the input image.

    Args:
        image: A path string, a dict carrying a ``"path"`` entry, or any other
            value (``None``, an in-memory image, ...) that carries no path.

    Returns:
        The base name of the path without its extension, with every character
        other than word characters and hyphens replaced by ``_``, stripped of
        leading/trailing underscores and capped at 80 characters. Returns
        ``"user_img"`` when no existing path can be determined or when the
        sanitized name is empty.
    """
    fallback = "user_img"

    path = None
    if isinstance(image, str):
        path = image
    elif isinstance(image, dict):
        # Only trust string paths: os.path.exists() also accepts integers
        # (file descriptors), which os.path.basename() cannot handle.
        candidate = image.get("path")
        if isinstance(candidate, str):
            path = candidate

    if not path or not os.path.exists(path):
        return fallback

    name = os.path.splitext(os.path.basename(path))[0]
    # \w is Unicode-aware for str patterns, so non-ASCII letters/digits are
    # kept; everything else except "-" collapses to "_".
    safe = re.sub(r"[^\w\-]", "_", name).strip("_") or fallback
    # Slicing is a no-op for short names, so no explicit length check is needed.
    return safe[:80]
384
+
385
+
386
def _get_image_path_for_stem(img):
    """Return the existing file path carried by an image value, else ``""``.

    Args:
        img: A path string, a dict with a ``"path"`` entry, or anything else.

    Returns:
        The path when ``img`` is (or holds under ``"path"``) a string naming
        an existing filesystem entry; otherwise ``""``. Values that carry no
        path string (``None``, in-memory images, arrays) always yield ``""``.
    """
    candidate = img.get("path") if isinstance(img, dict) else img
    # os.path.exists() already covers regular files, so a separate
    # os.path.isfile() check adds nothing.
    if isinstance(candidate, str) and candidate and os.path.exists(candidate):
        return candidate
    return ""
397
+
398
+
399
def _update_tracked_image_path(img):
    """Return the image's path only when it lies inside a ``stimuli`` directory.

    Args:
        img: The image value to inspect; its path is resolved through
            ``_get_image_path_for_stem``.

    Returns:
        The resolved path when one of its parent directories is named exactly
        ``stimuli``; otherwise ``""``.
    """
    # NOTE(review): the image component is declared with type="pil"; if this
    # handler receives a PIL image, no path is available and the result is
    # always "" — confirm that clearing the tracked path is intended.
    path = _get_image_path_for_stem(img)
    if not path:
        return ""
    # Match "stimuli" as a whole directory component; a plain substring test
    # would also accept unrelated uploads such as "my_stimuli_photo.png".
    parent_parts = os.path.normpath(os.path.dirname(path)).split(os.sep)
    return path if "stimuli" in parent_parts else ""
405
+
406
+
407
def _config_to_json_serializable(c):
    """Return a copy of ``c`` that ``json.dump`` can serialize.

    Args:
        c: Any value; typically a (possibly nested) config dict.

    Returns:
        * ``bool``/``int``/``float``/``str``/``None`` values unchanged;
        * dicts rebuilt with converted values, and with any key that JSON
          cannot encode replaced by ``str(key)``;
        * lists and tuples as lists of converted items;
        * objects exposing ``item()`` (e.g. numpy scalars) as the converted
          result of ``item()``; if that fails (e.g. a multi-element array),
          the converted result of ``tolist()`` when available;
        * everything else as ``str(c)``.
    """
    json_scalars = (bool, int, float, str, type(None))

    if isinstance(c, json_scalars):
        return c
    if isinstance(c, dict):
        # json.dump only accepts scalar keys; stringify anything else so that
        # a stray tuple/object key cannot abort the whole save.
        return {
            (key if isinstance(key, json_scalars) else str(key)):
                _config_to_json_serializable(value)
            for key, value in c.items()
        }
    if isinstance(c, (list, tuple)):
        return [_config_to_json_serializable(entry) for entry in c]

    # Array-like scalars first via item(); multi-element containers reject
    # item(), so fall back to tolist() before giving up and stringifying.
    for method_name in ("item", "tolist"):
        method = getattr(c, method_name, None)
        if not callable(method):
            continue
        try:
            converted = method()
        except (TypeError, ValueError, RuntimeError):
            continue
        if converted is c:
            # Guard against objects that return themselves (would recurse forever).
            continue
        return _config_to_json_serializable(converted)
    return str(c)
418
+
419
+
420
  @GPU
421
  def run_inference(image, model_type, inference_type, eps_value, num_iterations,
422
  initial_noise=0.05, diffusion_noise=0.3, step_size=0.8, model_layer="layer3",
423
  use_adaptive_eps=False, use_adaptive_step=False,
424
  mask_center_x=0.0, mask_center_y=0.0, mask_radius=0.3, mask_sigma=0.2,
425
  eps_max_mult=4.0, eps_min_mult=1.0, step_max_mult=4.0, step_min_mult=1.0,
426
+ use_biased_inference=False, biased_class_name="",
427
+ current_image_path=""):
428
  # Check if image is provided
429
  if image is None:
430
+ return None, [], "Please upload an image before running inference.", None
431
 
432
  # Convert eps to float
433
  eps = float(eps_value)
 
510
  # Convert the final output image to PIL
511
  final_image = Image.fromarray((output_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
512
 
513
+ # Always save GIF and config and offer as downloads (browser will ask where to save)
514
+ save_status = ""
515
+ files_for_download = None
516
+ if frames:
517
+ # Use tracked path when available (e.g. from Load Parameters); else derive from image (PIL loses path)
518
+ stem = _input_image_stem(current_image_path if (current_image_path and current_image_path.strip()) else image)
519
+ unique_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{stem}"
520
+ gif_path = os.path.join(SAVED_RUNS_DIR, f"{unique_id}.gif")
521
+ config_path = os.path.join(SAVED_RUNS_DIR, f"{unique_id}_config.json")
522
+ try:
523
+ frames[0].save(
524
+ gif_path,
525
+ save_all=True,
526
+ append_images=frames[1:],
527
+ loop=0,
528
+ duration=200,
529
+ )
530
+ save_config = {
531
+ "model_type": model_type,
532
+ "input_image_name": stem,
533
+ **_config_to_json_serializable(config),
534
+ }
535
+ with open(config_path, "w") as f:
536
+ json.dump(save_config, f, indent=2)
537
+ files_for_download = [gif_path, config_path]
538
+ save_status = "**Download results** — Use the links below to save the GIF and config to your device (your browser may ask where to save)."
539
+ except Exception as e:
540
+ save_status = f"Save failed: {e}"
541
+
542
+ return final_image, frames, save_status, files_for_download
543
 
544
  def _image_to_pil(img):
545
  """Convert Gradio image value (PIL, numpy, path, or dict) to PIL Image; return None if invalid."""
 
631
  example.get("step_min_mult", 1.0),
632
  example.get("use_biased_inference", False),
633
  example.get("biased_class_name", ""),
634
+ example["image"], # keep path for save filename (e.g. UrbanOffice1 -> urbanoffice1)
635
  mask_img,
636
  gr.Group(visible=True),
637
  ]
 
656
  2. **Click "Run Generative Inference"** to predict what hallucination humans may perceive
657
  3. **View the prediction**: Watch as the model reveals the perceptual structures it expects—matching what humans typically hallucinate
658
  4. **You can upload your own images**
659
+ 5. **You can download the results** as a .gif file together with the configs.json
660
  """)
661
  with gr.Row():
662
  with gr.Column(scale=1):
663
+ # Inputs (track path so save filenames use stimulus name when from example)
664
+ default_image_path = os.path.join("stimuli", "urbanoffice1.jpg")
665
+ image_input = gr.Image(label="Input Image (click to set mask center)", type="pil", value=default_image_path)
666
+ current_image_path_state = gr.State(value=default_image_path)
667
  mask_preview = gr.Image(
668
  label="Mask center preview (click to set center — circle shows mask)",
669
  type="pil",
 
733
  biased_class_dropdown = gr.Dropdown(
734
  choices=[("— No bias —", "")] + [(label, label) for label in sorted(IMAGENET_LABELS)],
735
  value="",
736
+ label="Biased toward category",
737
  allow_custom_value=False,
738
  filterable=True,
739
  )
 
740
  with gr.Column(scale=2):
741
  # Outputs
742
  output_image = gr.Image(label="Predicted Hallucination")
743
  output_frames = gr.Gallery(label="Hallucination Prediction Process", columns=5, rows=2)
744
+ save_status_md = gr.Markdown(value="")
745
+ download_files = gr.File(label="Download results (GIF + config)", file_count="multiple")
746
 
747
  # Examples section with integrated explanations
748
  gr.Markdown("## Examples")
 
771
  eps_max_mult_slider, eps_min_mult_slider,
772
  step_max_mult_slider, step_min_mult_slider,
773
  use_biased_inference_check, biased_class_dropdown,
774
+ current_image_path_state,
775
  mask_preview,
776
  params_section,
777
  ],
 
780
  # Right column for the explanation
781
  with gr.Column(scale=2):
782
  gr.Markdown(f"### {ex['name']}")
783
+ if ex["name"] not in ("farm1", "ArtGallery1", "UrbanOffice1"):
784
+ gr.Markdown(f"[Read more on Wikipedia]({ex['wiki']})")
785
 
786
  # Show instructions if they exist
787
  if "instructions" in ex:
 
805
  eps_max_mult_slider, eps_min_mult_slider,
806
  step_max_mult_slider, step_min_mult_slider,
807
  use_biased_inference_check, biased_class_dropdown,
808
+ current_image_path_state,
809
  ],
810
+ outputs=[output_image, output_frames, save_status_md, download_files]
811
  )
812
 
813
  # Toggle parameters visibility
 
837
  inputs=_mask_preview_inputs(),
838
  outputs=[mask_preview],
839
  )
840
+ # Keep tracked path for save filename: known stimulus name or clear so stem becomes 'user_img'
841
+ image_input.change(
842
+ fn=_update_tracked_image_path,
843
+ inputs=[image_input],
844
+ outputs=[current_image_path_state],
845
+ )
846
  mask_center_x_slider.change(
847
  fn=draw_mask_overlay,
848
  inputs=_mask_preview_inputs(),