Update app.py
Browse files
app.py
CHANGED
|
@@ -10,78 +10,173 @@ from cldm.model import create_model, load_state_dict
|
|
| 10 |
from cldm.ddim_hacked import DDIMSampler
|
| 11 |
from huggingface_hub import hf_hub_download
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def load_model(checkpoint_path):
|
| 16 |
-
model = create_model(
|
| 17 |
-
model.add_new_layers()
|
| 18 |
model.concat = False
|
| 19 |
-
|
|
|
|
| 20 |
model.parameterization = "v"
|
| 21 |
-
|
|
|
|
| 22 |
|
| 23 |
-
|
|
|
|
| 24 |
resume_path = hf_hub_download(repo_id="xyxingx/LumiNet", filename="LumiNet.ckpt")
|
| 25 |
model = load_model(resume_path)
|
| 26 |
ddim_sampler = DDIMSampler(model)
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
@spaces.GPU
|
| 29 |
-
def process_images(input_image, reference_image, ddim_steps=50):
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
uc_cat = c_cat
|
|
|
|
| 50 |
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
|
| 51 |
-
cond
|
| 52 |
-
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
x_samples = (x_samples.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
| 62 |
-
output_images.append(Image.fromarray(x_samples))
|
| 63 |
-
|
| 64 |
-
return output_images
|
| 65 |
|
|
|
|
|
|
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
with gr.Blocks() as gram:
|
| 68 |
gr.Markdown("# LumiNet: Latent Intrinsics Meets Diffusion Models for Indoor Scene Relighting")
|
| 69 |
gr.Markdown("A demo for [paper](https://luminet-relight.github.io/)")
|
| 70 |
-
gr.Markdown("Upload your own image and reference
|
| 71 |
-
gr.Markdown("Note: No post-processing is used
|
| 72 |
|
| 73 |
with gr.Row():
|
| 74 |
-
input_img = gr.Image(type="pil", label="Input Image", sources=["upload"]
|
| 75 |
-
ref_img
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
| 78 |
btn = gr.Button("Generate")
|
| 79 |
-
|
| 80 |
with gr.Row():
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
if __name__ == "__main__":
|
| 86 |
gram.launch()
|
| 87 |
-
|
|
|
|
| 10 |
from cldm.ddim_hacked import DDIMSampler
|
| 11 |
from huggingface_hub import hf_hub_download
|
| 12 |
|
| 13 |
+
# -------------------------
# Global settings & helpers
# -------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_N = 1
INF_SIZE = 512  # inference resolution (square)

# Module-level flag/path so the bypass-decoder checkpoint is fetched at most once
_NEW_DECODER_LOADED = False
_NEW_DECODER_PATH = None


def _ensure_new_decoder_loaded(model):
    """Download and attach the new/bypass decoder weights on first use only.

    Subsequent calls are no-ops; the module-level flag records that the
    weights were already applied to *model*.
    """
    global _NEW_DECODER_LOADED, _NEW_DECODER_PATH
    if _NEW_DECODER_LOADED:
        return  # already applied — avoid re-downloading / re-patching
    _NEW_DECODER_PATH = hf_hub_download(repo_id="xyxingx/LumiNet", filename="new_decoder.ckpt")
    model.change_first_stage(_NEW_DECODER_PATH)
    _NEW_DECODER_LOADED = True
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# -------------------------
|
| 34 |
+
# Model loading
|
| 35 |
+
# -------------------------
|
| 36 |
def load_model(checkpoint_path):
    """Build the LumiNet model, load *checkpoint_path*, and return it on DEVICE in eval mode."""
    net = create_model("./models/cldm_v21_LumiNet.yaml").cpu()
    # New decoder layers must exist before the state dict is applied
    net.add_new_layers()
    net.concat = False
    net.load_state_dict(load_state_dict(checkpoint_path, location=DEVICE))
    net.parameterization = "v"
    return net.to(DEVICE).eval()
| 45 |
|
| 46 |
+
|
| 47 |
+
# Download main checkpoint & build sampler.
# NOTE(review): these statements run at import time — the checkpoint download
# and model load happen once, when this module is first imported.
resume_path = hf_hub_download(repo_id="xyxingx/LumiNet", filename="LumiNet.ckpt")
model = load_model(resume_path)
ddim_sampler = DDIMSampler(model)
| 51 |
|
| 52 |
+
|
| 53 |
+
# -------------------------
|
| 54 |
+
# Inference
|
| 55 |
+
# -------------------------
|
| 56 |
+
def _preprocess_to_np_rgb(img_pil):
|
| 57 |
+
"""PIL -> float32 numpy [H,W,3] in [0,1], RGB."""
|
| 58 |
+
return (np.array(img_pil.convert("RGB"), dtype=np.uint8).astype(np.float32) / 255.0)
|
| 59 |
+
|
| 60 |
+
def _resize_to_square_512(img_np):
    """Resample an image array to the fixed INF_SIZE x INF_SIZE inference resolution."""
    target = (INF_SIZE, INF_SIZE)
    return cv2.resize(img_np, target, interpolation=cv2.INTER_LANCZOS4)
| 62 |
+
|
| 63 |
+
def _tensor_from_np(img_np):
    """Convert an HWC float array in [0, 1] to a BCHW float32 tensor on DEVICE."""
    # .copy() guards against negative-stride / non-writable numpy views
    chw = torch.from_numpy(img_np.copy()).float().permute(2, 0, 1)
    return chw.unsqueeze(0).to(DEVICE)
| 68 |
+
|
| 69 |
@spaces.GPU
def process_images(input_image, reference_image, ddim_steps=50, use_new_decoder=False):
    """Relight *input_image* using the lighting of *reference_image*.

    Parameters
    ----------
    input_image, reference_image : PIL.Image
        User-supplied photos; both are required.
    ddim_steps : int
        Number of DDIM sampling steps (default 50).
    use_new_decoder : bool
        When True, decode through the bypass ("new") decoder, which conditions
        on the input image's autoencoder features for better identity
        preservation. Weights are lazily downloaded on first use.

    Returns
    -------
    list[PIL.Image]
        Three relit images (one per random seed), resized back to the input
        image's original resolution.
    """
    # NOTE(review): assert is stripped under `python -O`; an explicit raise
    # would be more robust for user-input validation — confirm intent.
    assert input_image is not None and reference_image is not None, "Please upload both input and reference images."

    # Prepare originals (for aspect-ratio restoration)
    input_np_full = _preprocess_to_np_rgb(input_image)  # [H,W,3] 0..1
    ref_np_full = _preprocess_to_np_rgb(reference_image)  # [H,W,3] 0..1
    orig_h, orig_w = input_np_full.shape[:2]

    # Inference inputs @ 512×512 (INF_SIZE)
    input_np_512 = _resize_to_square_512(input_np_full)
    ref_np_512 = _resize_to_square_512(ref_np_full)

    # Control feature: concat input & reference along channels -> [H,W,6]
    control_feat = np.concatenate((input_np_512, ref_np_512), axis=2).astype(np.float32)
    control = _tensor_from_np(control_feat)  # [1,6,512,512]

    # Also keep the input tensor for new-decoder decoding path (needs input AE features)
    input_tensor = _tensor_from_np(input_np_512)  # [1,3,512,512]

    # Conditioning — all sampling/decoding under no_grad (inference only)
    with torch.no_grad():
        c_cat = control
        # Cross-attention uses unconditional embeddings because there is no text prompt
        c = model.get_unconditional_conditioning(BATCH_N)
        uc_cross = model.get_unconditional_conditioning(BATCH_N)
        # Control signal is kept for the unconditional branch too, so CFG only
        # contrasts the cross-attention conditioning
        uc_cat = c_cat

        uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
        cond = {"c_concat": [c_cat], "c_crossattn": [c]}

        # Latent shape for 512×512 with downsampling factor 8: (C=4, 64, 64)
        shape = (4, INF_SIZE // 8, INF_SIZE // 8)

        # Make 3 different seeds so the three outputs differ
        seeds = [random.randint(0, 999_999) for _ in range(3)]
        outputs = []

        # Ensure new/bypass decoder weights are loaded if requested
        if use_new_decoder:
            _ensure_new_decoder_loaded(model)

        for seed in seeds:
            torch.manual_seed(seed)

            samples, _ = ddim_sampler.sample(
                S=ddim_steps,
                batch_size=BATCH_N,
                shape=shape,
                conditioning=cond,
                verbose=False,
                eta=0.0,  # deterministic DDIM given the manual seed
                unconditional_guidance_scale=9.0,
                unconditional_conditioning=uc_full
            )

            # Decode latent samples back to pixel space
            if use_new_decoder:
                # encode_first_stage expects [-1,1] range
                ae_hs = model.encode_first_stage(input_tensor * 2.0 - 1.0)[1]
                x = model.decode_new_first_stage(samples, ae_hs)
            else:
                x = model.decode_first_stage(samples)

            # To image in [0,255], HWC: decoder output is in [-1,1]
            x = (x.squeeze(0) + 1.0) / 2.0
            x = x.clamp(0, 1)
            x = (einops.rearrange(x, "c h w -> h w c").detach().cpu().numpy() * 255.0).astype(np.uint8)

            # Resize back to original aspect ratio/size
            x = cv2.resize(x, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)

            outputs.append(Image.fromarray(x))

    return outputs
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# -------------------------
# UI
# -------------------------
with gr.Blocks() as gram:
    gr.Markdown("# LumiNet: Latent Intrinsics Meets Diffusion Models for Indoor Scene Relighting")
    gr.Markdown("A demo for [paper](https://luminet-relight.github.io/)")
    gr.Markdown("Upload your own image and a reference. The demo outputs 3 relit images with different random seeds.")
    gr.Markdown("**Note:** Inference runs at 512×512. Results are resized back to your input image’s original aspect ratio. No post-processing is used.")

    # Input row: the image to relight and the lighting reference
    with gr.Row():
        src_image = gr.Image(type="pil", label="Input Image", sources=["upload"])
        light_ref = gr.Image(type="pil", label="Reference Image", sources=["upload"])

    # Options row
    with gr.Row():
        steps_slider = gr.Slider(minimum=10, maximum=1000, step=1, label="DDIM Steps", value=50)
        bypass_decoder = gr.Checkbox(label="Use bypass (new) decoder for better identity preservation", value=False)

    generate_btn = gr.Button("Generate")

    with gr.Row():
        # No fixed width/height so images keep their native aspect ratio in the layout
        result_views = [gr.Image(type="pil", label=f"Generated Image {i}") for i in (1, 2, 3)]

    generate_btn.click(
        fn=process_images,
        inputs=[src_image, light_ref, steps_slider, bypass_decoder],
        outputs=result_views,
    )
| 180 |
|
| 181 |
if __name__ == "__main__":
    # Launch the Gradio app only when run as a script (not on import).
    gram.launch()
|
|