xyxingx committed on
Commit
8b04f5b
·
verified ·
1 Parent(s): c01603e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -48
app.py CHANGED
@@ -10,78 +10,173 @@ from cldm.model import create_model, load_state_dict
10
  from cldm.ddim_hacked import DDIMSampler
11
  from huggingface_hub import hf_hub_download
12
 
 
 
 
 
 
 
13
 
 
 
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
15
def load_model(checkpoint_path):
    """Create the LumiNet model, load `checkpoint_path`, and move it to the GPU."""
    net = create_model('./models/cldm_v21_LumiNet.yaml').cpu()
    net.add_new_layers()
    net.concat = False
    net.load_state_dict(load_state_dict(checkpoint_path, location='cuda'))
    net.parameterization = "v"
    return net.cuda()


# Fetch the checkpoint from the Hub, then build the model and DDIM sampler.
resume_path = hf_hub_download(repo_id="xyxingx/LumiNet", filename="LumiNet.ckpt")
model = load_model(resume_path)
ddim_sampler = DDIMSampler(model)
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
@spaces.GPU
def process_images(input_image, reference_image, ddim_steps=50):
    """Relight `input_image` using `reference_image`; returns 3 images (one per random seed)."""
    results = []

    for seed in [random.randint(0, 100000) for _ in range(3)]:
        torch.manual_seed(seed)

        # Scale both images to [0,1] and resize to the model's 512x512 working size.
        inp = cv2.resize(np.array(input_image) / 255, (512, 512))
        ref = cv2.resize(np.array(reference_image) / 255, (512, 512))

        # Channel-concatenated control signal: [512,512,6] -> [1,6,512,512] on GPU.
        control = torch.from_numpy(np.concatenate((inp, ref), axis=2).copy()).float().cuda()
        control = einops.rearrange(control, 'h w c -> 1 c h w').clone()

        c_cat = control.cuda()
        # No text prompt: cross-attention uses the unconditional embedding on both branches.
        c = model.get_unconditional_conditioning(1)
        uc_cross = model.get_unconditional_conditioning(1)
        uc_cat = c_cat
        uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
        cond = {"c_concat": [c_cat], "c_crossattn": [c]}
        shape = (4, 64, 64)  # latent resolution for 512x512 inputs (factor 8)

        samples, _ = ddim_sampler.sample(ddim_steps, 1, shape, cond, verbose=False, eta=0.0,
                                         unconditional_guidance_scale=9.0,
                                         unconditional_conditioning=uc_full)

        # Decode latents and map [-1,1] CHW -> uint8 HWC.
        decoded = model.decode_first_stage(samples)
        decoded = ((decoded.squeeze(0) + 1.0) / 2.0).clamp(0, 1)
        img = (decoded.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        results.append(Image.fromarray(img))

    return results
65
 
 
 
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
with gr.Blocks() as gram:
    gr.Markdown("# LumiNet: Latent Intrinsics Meets Diffusion Models for Indoor Scene Relighting")
    gr.Markdown("A demo for [paper](https://luminet-relight.github.io/)")
    gr.Markdown("Upload your own image and reference, our demo will output 3 relit images, with different seeds.")
    gr.Markdown("Note: No post-processing is used in this demo.")

    # Input / reference thumbnails side by side.
    with gr.Row():
        input_img = gr.Image(type="pil", label="Input Image", sources=["upload"], width=256, height=256)
        ref_img = gr.Image(type="pil", label="Reference Image", sources=["upload"], width=256, height=256)

    ddim_slider = gr.Slider(minimum=10, maximum=1000, step=1, label="DDIM Steps", value=50)
    btn = gr.Button("Generate")

    # Three output slots, one per random seed.
    with gr.Row():
        output_imgs = []
        for i in (1, 2, 3):
            output_imgs.append(gr.Image(label=f"Generated Image {i}", width=256, height=256))

    btn.click(process_images, inputs=[input_img, ref_img, ddim_slider], outputs=output_imgs)


if __name__ == "__main__":
    gram.launch()
87
-
 
10
  from cldm.ddim_hacked import DDIMSampler
11
  from huggingface_hub import hf_hub_download
12
 
13
# -------------------------
# Global settings & helpers
# -------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when present
BATCH_N = 1     # samples per DDIM call
INF_SIZE = 512  # inference resolution (square)

# One-shot state so the bypass ("new") decoder weights are fetched at most once.
_NEW_DECODER_LOADED = False
_NEW_DECODER_PATH = None
23
 
24
def _ensure_new_decoder_loaded(model):
    """Download and install the bypass ("new") decoder weights exactly once."""
    global _NEW_DECODER_LOADED, _NEW_DECODER_PATH
    if _NEW_DECODER_LOADED:
        return
    _NEW_DECODER_PATH = hf_hub_download(repo_id="xyxingx/LumiNet", filename="new_decoder.ckpt")
    model.change_first_stage(_NEW_DECODER_PATH)
    _NEW_DECODER_LOADED = True
31
+
32
+
33
# -------------------------
# Model loading
# -------------------------
def load_model(checkpoint_path):
    """Build the LumiNet model, load `checkpoint_path`, and return it on DEVICE in eval mode."""
    net = create_model("./models/cldm_v21_LumiNet.yaml").cpu()
    net.add_new_layers()  # the extra decoder layers must exist before loading weights
    net.concat = False
    net.load_state_dict(load_state_dict(checkpoint_path, location=DEVICE))
    net.parameterization = "v"
    return net.to(DEVICE).eval()
45
 
46
+
47
# Fetch the main checkpoint once at startup, then build the model and DDIM sampler.
resume_path = hf_hub_download(repo_id="xyxingx/LumiNet", filename="LumiNet.ckpt")
model = load_model(resume_path)
ddim_sampler = DDIMSampler(model)
51
 
52
+
53
# -------------------------
# Inference
# -------------------------
def _preprocess_to_np_rgb(img_pil):
    """Convert a PIL image to an RGB float32 array, shape [H,W,3], values in [0,1]."""
    rgb = img_pil.convert("RGB")  # 8-bit RGB regardless of the source mode
    return np.asarray(rgb, dtype=np.float32) / 255.0
59
+
60
def _resize_to_square_512(img_np):
    """Resize an HWC image array to INF_SIZE x INF_SIZE using Lanczos resampling."""
    target = (INF_SIZE, INF_SIZE)
    return cv2.resize(img_np, target, interpolation=cv2.INTER_LANCZOS4)
62
+
63
def _tensor_from_np(img_np):
    """HWC float array in [0,1] -> [1,C,H,W] float32 tensor on DEVICE."""
    chw = torch.from_numpy(img_np.copy()).float().permute(2, 0, 1)  # HWC -> CHW
    return chw.unsqueeze(0).to(DEVICE)  # add batch dim
68
+
69
@spaces.GPU
def process_images(input_image, reference_image, ddim_steps=50, use_new_decoder=False):
    """Relight `input_image` using the lighting of `reference_image`.

    Args:
        input_image: PIL image to be relit.
        reference_image: PIL image supplying the target lighting.
        ddim_steps: number of DDIM sampling steps.
        use_new_decoder: decode with the bypass ("new") decoder, which
            conditions on the input's autoencoder features for better
            identity preservation.

    Returns:
        Three PIL images (one per random seed), resized back to the input
        image's original resolution.

    Raises:
        ValueError: if either image is missing.
    """
    # `assert` is stripped under `python -O`, so validate inputs explicitly.
    if input_image is None or reference_image is None:
        raise ValueError("Please upload both input and reference images.")

    # Full-resolution copies, kept only to restore the original size at the end.
    input_np_full = _preprocess_to_np_rgb(input_image)    # [H,W,3] in [0,1]
    ref_np_full = _preprocess_to_np_rgb(reference_image)  # [H,W,3] in [0,1]
    orig_h, orig_w = input_np_full.shape[:2]

    # The diffusion model runs at a fixed square resolution.
    input_np_512 = _resize_to_square_512(input_np_full)
    ref_np_512 = _resize_to_square_512(ref_np_full)

    # Control signal: input and reference stacked along channels -> [H,W,6].
    control_feat = np.concatenate((input_np_512, ref_np_512), axis=2).astype(np.float32)
    control = _tensor_from_np(control_feat)      # [1,6,512,512]

    # Input tensor kept for the bypass-decoder path (needs input AE features).
    input_tensor = _tensor_from_np(input_np_512)  # [1,3,512,512]

    seeds = [random.randint(0, 999_999) for _ in range(3)]
    outputs = []

    # Load the bypass decoder weights lazily, only when requested.
    if use_new_decoder:
        _ensure_new_decoder_loaded(model)

    # Inference only: keep conditioning, sampling AND decoding inside no_grad
    # so no autograd graph is built and GPU memory stays bounded.
    with torch.no_grad():
        c_cat = control
        # No text prompt, so cross-attention uses the unconditional embedding.
        c = model.get_unconditional_conditioning(BATCH_N)
        uc_cross = model.get_unconditional_conditioning(BATCH_N)
        uc_cat = c_cat

        uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
        cond = {"c_concat": [c_cat], "c_crossattn": [c]}

        # Latent shape for INF_SIZE with the VAE's 8x downsampling factor.
        shape = (4, INF_SIZE // 8, INF_SIZE // 8)

        for seed in seeds:
            torch.manual_seed(seed)

            samples, _ = ddim_sampler.sample(
                S=ddim_steps,
                batch_size=BATCH_N,
                shape=shape,
                conditioning=cond,
                verbose=False,
                eta=0.0,
                unconditional_guidance_scale=9.0,
                unconditional_conditioning=uc_full,
            )

            # Decode latents with the requested decoder.
            if use_new_decoder:
                # encode_first_stage expects inputs in [-1, 1].
                ae_hs = model.encode_first_stage(input_tensor * 2.0 - 1.0)[1]
                x = model.decode_new_first_stage(samples, ae_hs)
            else:
                x = model.decode_first_stage(samples)

            # [-1,1] CHW tensor -> uint8 HWC image (no detach needed under no_grad).
            x = ((x.squeeze(0) + 1.0) / 2.0).clamp(0, 1)
            x = (einops.rearrange(x, "c h w -> h w c").cpu().numpy() * 255.0).astype(np.uint8)

            # Restore the caller's original resolution / aspect ratio.
            x = cv2.resize(x, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
            outputs.append(Image.fromarray(x))

    return outputs
148
+
149
+
150
# -------------------------
# UI
# -------------------------
with gr.Blocks() as gram:
    gr.Markdown("# LumiNet: Latent Intrinsics Meets Diffusion Models for Indoor Scene Relighting")
    gr.Markdown("A demo for [paper](https://luminet-relight.github.io/)")
    gr.Markdown("Upload your own image and a reference. The demo outputs 3 relit images with different random seeds.")
    gr.Markdown("**Note:** Inference runs at 512×512. Results are resized back to your input image’s original aspect ratio. No post-processing is used.")

    # Input / reference uploads side by side.
    with gr.Row():
        input_img = gr.Image(type="pil", label="Input Image", sources=["upload"])
        ref_img = gr.Image(type="pil", label="Reference Image", sources=["upload"])

    # Sampling controls.
    with gr.Row():
        ddim_slider = gr.Slider(minimum=10, maximum=1000, step=1, label="DDIM Steps", value=50)
        use_new_dec = gr.Checkbox(label="Use bypass (new) decoder for better identity preservation", value=False)

    btn = gr.Button("Generate")

    # Three output slots; no fixed width/height so images keep their native aspect ratio.
    with gr.Row():
        output_imgs = [gr.Image(type="pil", label=f"Generated Image {i}") for i in (1, 2, 3)]

    btn.click(
        fn=process_images,
        inputs=[input_img, ref_img, ddim_slider, use_new_dec],
        outputs=output_imgs,
    )

if __name__ == "__main__":
    gram.launch()