RyanHangZhou committed on
Commit
4d9ba5b
·
verified ·
1 Parent(s): 8ece528

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -18
app.py CHANGED
@@ -1,20 +1,51 @@
1
  import gradio as gr
2
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  def pics_pairwise_inference(background, img_a, mask_a, img_b, mask_b):
5
  """
6
- Standard PICS Inference with 5 Explicit Inputs:
7
- 1. Background Scene
8
- 2. Object A Image
9
- 3. Object A Mask
10
- 4. Object B Image
11
- 5. Object B Mask
12
  """
13
- # In a real scenario, we would preprocess these 5 inputs here
14
- # e.g., result = model.forward(background, img_a, mask_a, img_b, mask_b)
15
 
16
- # Verification: Returning background to confirm the 5-input pipeline is live
17
- return background
 
 
 
 
 
18
 
19
  with gr.Blocks(title="PICS: Pairwise Spatial Compositing") as demo:
20
  gr.Markdown("# 🚀 PICS: Pairwise Image Compositing (5-Input Framework)")
@@ -22,17 +53,14 @@ with gr.Blocks(title="PICS: Pairwise Spatial Compositing") as demo:
22
 
23
  with gr.Row():
24
  with gr.Column(scale=2):
25
- # 1. Background Input
26
  bg_input = gr.Image(label="1. Scene Background", type="pil")
27
 
28
  with gr.Row():
29
- # 2 & 3. Object A Pair
30
  with gr.Column():
31
  gr.Markdown("### Object A")
32
  obj_a_img = gr.Image(label="Image A", type="pil")
33
  obj_a_mask = gr.Image(label="Mask A", type="pil")
34
 
35
- # 4 & 5. Object B Pair
36
  with gr.Column():
37
  gr.Markdown("### Object B")
38
  obj_b_img = gr.Image(label="Image B", type="pil")
@@ -41,16 +69,14 @@ with gr.Blocks(title="PICS: Pairwise Spatial Compositing") as demo:
41
  run_btn = gr.Button("Execute PICS Inference ✨", variant="primary")
42
 
43
  with gr.Column(scale=1):
44
- # Result Section
45
  output_img = gr.Image(label="PICS Composite Result")
46
  gr.Markdown("""
47
  ### 🔬 Technical Requirements
48
- * **Pairwise Reasoning**: The model takes 5 distinct inputs to compute depth, occlusion, and lighting.
49
- * **Mask Alignment**: Ensure Mask A/B perfectly align with Image A/B.
50
- * **InfiniKin Powered**: Trained on high-fidelity synthetic pairs generated by the InfiniKin engine.
51
  """)
52
 
53
- # --- Linking all 5 inputs to the inference function ---
54
  run_btn.click(
55
  fn=pics_pairwise_inference,
56
  inputs=[bg_input, obj_a_img, obj_a_mask, obj_b_img, obj_b_mask],
 
1
# --- Standard library ---
import os
import sys

# --- Third-party ---
import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf

# Pull the model assets (code + checkpoints) from the PICS model repo so the
# Space mirrors the training environment. snapshot_download returns a local
# cache directory whose exact path varies per machine.
REPO_DIR = snapshot_download(repo_id="Hang2991/PICS")
sys.path.append(REPO_DIR)

# These modules live inside the downloaded repo, so they can only be imported
# after REPO_DIR is on sys.path. Keep the paths in sync with the repo layout.
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler

# Prefer GPU when available; fall back to CPU so the demo still runs anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Resolve checkpoint and model-config paths relative to the snapshot cache.
config = OmegaConf.load(os.path.join(REPO_DIR, "configs/inference.yaml"))
model_ckpt = os.path.join(REPO_DIR, config.pretrained_model)
model_config = os.path.join(REPO_DIR, config.config_file)

# Build the network from its config, load weights onto the chosen device,
# and switch to eval mode (disables dropout / uses running BN stats).
model = create_model(model_config).to(device)
model.load_state_dict(load_state_dict(model_ckpt, location=device))
model.eval()
ddim_sampler = DDIMSampler(model)
 
34
def pics_pairwise_inference(background, img_a, mask_a, img_b, mask_b):
    """
    Run the PICS pairwise compositing model on the five UI inputs.

    Bridges the Gradio UI with the core algorithm: the five ``gr.Image``
    components are passed through to ``model.inference``.

    Parameters
    ----------
    background : PIL.Image.Image or None
        Scene background uploaded in the UI (``None`` if left empty).
    img_a, mask_a : PIL.Image.Image or None
        Object A image and its mask.
    img_b, mask_b : PIL.Image.Image or None
        Object B image and its mask.

    Returns
    -------
    The composite produced by ``model.inference`` (rendered by the
    output ``gr.Image`` component).

    Raises
    ------
    ValueError
        If any of the five inputs is missing, so the user sees a clear
        message instead of a low-level traceback from the model.
    """
    # Fail fast with a readable message when the user forgot an upload;
    # Gradio passes None for empty Image components.
    required = {
        "Scene Background": background,
        "Image A": img_a,
        "Mask A": mask_a,
        "Image B": img_b,
        "Mask B": mask_b,
    }
    missing = [label for label, value in required.items() if value is None]
    if missing:
        raise ValueError(f"Missing required input(s): {', '.join(missing)}")

    # NOTE(review): preprocessing (resizing to the model's expected
    # resolution, tensor conversion, normalization) may be required here
    # depending on what model.inference expects — confirm against the
    # PICS repo before relying on raw PIL inputs.

    # Inference only: disable autograd to cut memory use and latency.
    with torch.no_grad():
        # `model` is the module-level PICS model initialized at startup.
        result = model.inference(background, img_a, mask_a, img_b, mask_b)

    return result
49
 
50
  with gr.Blocks(title="PICS: Pairwise Spatial Compositing") as demo:
51
  gr.Markdown("# 🚀 PICS: Pairwise Image Compositing (5-Input Framework)")
 
53
 
54
  with gr.Row():
55
  with gr.Column(scale=2):
 
56
  bg_input = gr.Image(label="1. Scene Background", type="pil")
57
 
58
  with gr.Row():
 
59
  with gr.Column():
60
  gr.Markdown("### Object A")
61
  obj_a_img = gr.Image(label="Image A", type="pil")
62
  obj_a_mask = gr.Image(label="Mask A", type="pil")
63
 
 
64
  with gr.Column():
65
  gr.Markdown("### Object B")
66
  obj_b_img = gr.Image(label="Image B", type="pil")
 
69
  run_btn = gr.Button("Execute PICS Inference ✨", variant="primary")
70
 
71
  with gr.Column(scale=1):
 
72
  output_img = gr.Image(label="PICS Composite Result")
73
  gr.Markdown("""
74
  ### 🔬 Technical Requirements
75
+ * **Pairwise Reasoning**: Computing occlusion and interaction for A & B.
76
+ * **Mask Alignment**: Masks must perfectly match the objects.
 
77
  """)
78
 
79
+ # Linking the 5 inputs to the inference function
80
  run_btn.click(
81
  fn=pics_pairwise_inference,
82
  inputs=[bg_input, obj_a_img, obj_a_mask, obj_b_img, obj_b_mask],