HAL1993 committed on
Commit
1176392
·
verified ·
1 Parent(s): 8d4aa9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -137
app.py CHANGED
@@ -31,7 +31,6 @@ ASPECT_RATIOS = {
31
  str(1920 / 512): (1920, 512),
32
  }
33
 
34
- # download the config and model
35
  MODEL_PATH = hf_hub_download("jasperai/LBM_relighting", "model.safetensors", token=huggingface_token)
36
  CONFIG_PATH = hf_hub_download("jasperai/LBM_relighting", "config.yaml", token=huggingface_token)
37
 
@@ -41,20 +40,11 @@ model = get_model_from_config(**config)
41
  sd = load_file(MODEL_PATH)
42
  model.load_state_dict(sd, strict=True)
43
  model.to("cuda").to(torch.bfloat16)
44
- birefnet = AutoModelForImageSegmentation.from_pretrained(
45
- "ZhengPeng7/BiRefNet", trust_remote_code=True
46
- ).cuda()
47
  image_size = (1024, 1024)
48
 
49
-
50
  @spaces.GPU
51
- def evaluate(
52
- fg_image: PIL.Image.Image,
53
- bg_image: PIL.Image.Image,
54
- num_sampling_steps: int = 1,
55
- ):
56
- gr.Info("Relighting Image...", duration=3)
57
-
58
  ori_h_bg, ori_w_bg = fg_image.size
59
  ar_bg = ori_h_bg / ori_w_bg
60
  closest_ar_bg = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - ar_bg))
@@ -67,159 +57,69 @@ def evaluate(
67
  bg_image = resize_and_center_crop(bg_image, dimensions_bg[0], dimensions_bg[1])
68
 
69
  img_pasted = Image.composite(fg_image, bg_image, fg_mask)
70
-
71
  img_pasted_tensor = ToTensor()(img_pasted).unsqueeze(0) * 2 - 1
72
- batch = {
73
- "source_image": img_pasted_tensor.cuda().to(torch.bfloat16),
74
- }
75
 
76
  z_source = model.vae.encode(batch[model.source_key])
77
-
78
- output_image = model.sample(
79
- z=z_source,
80
- num_steps=num_sampling_steps,
81
- conditioner_inputs=batch,
82
- max_samples=1,
83
- ).clamp(-1, 1)
84
-
85
  output_image = (output_image[0].float().cpu() + 1) / 2
86
  output_image = ToPILImage()(output_image)
87
-
88
- # paste the output image on the background image
89
  output_image = Image.composite(output_image, bg_image, fg_mask)
90
-
91
  output_image.resize((ori_h_bg, ori_w_bg))
92
- print(output_image.size, img_pasted.size)
93
 
94
  return (np.array(img_pasted), np.array(output_image))
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- with gr.Blocks(title="LBM Object Relighting") as demo:
98
- gr.Markdown(
99
- f"""
100
- # Object Relighting with Latent Bridge Matching
101
- This is an interactive demo of [LBM: Latent Bridge Matching for Fast Image-to-Image Translation](https://arxiv.org/abs/2503.07535) *by Jasper Research*. This demo is based on the [LBM relighting checkpoint](https://huggingface.co/jasperai/LBM_relighting).
102
- """
103
- )
104
- gr.Markdown(
105
- """
106
- If you enjoy the space, please also promote *open-source* by giving a ⭐ to the <a href='https://github.com/gojasper/LBM' target='_blank'>Github Repo</a>.
107
- """
108
- )
109
- gr.Markdown(
110
- "💡 *Hint:* To better appreciate the low latency of our method, run the demo locally !"
111
- )
112
  with gr.Row():
113
  with gr.Column():
114
  with gr.Row():
115
- fg_image = gr.Image(
116
- type="pil",
117
- label="Input Image",
118
- image_mode="RGB",
119
- height=360,
120
- # width=360,
121
- )
122
- bg_image = gr.Image(
123
- type="pil",
124
- label="Target Background",
125
- image_mode="RGB",
126
- height=360,
127
- # width=360,
128
- )
129
 
130
  with gr.Row():
131
- submit_button = gr.Button("Relight", variant="primary")
132
  with gr.Row():
133
- num_inference_steps = gr.Slider(
134
- minimum=1,
135
- maximum=4,
136
- value=1,
137
- step=1,
138
- label="Number of Inference Steps",
139
- )
140
 
141
  bg_gallery = gr.Gallery(
142
- # height=450,
143
  object_fit="contain",
144
- label="Background List",
145
- value=[path for path in glob.glob("examples/backgrounds/*.jpg")],
146
- columns=5,
147
- allow_preview=False,
148
- )
149
 
150
  with gr.Column():
151
- output_slider = gr.ImageSlider(label="Composite vs LBM", type="numpy")
152
- output_slider.upload(
153
- fn=evaluate,
154
- inputs=[fg_image, bg_image, num_inference_steps],
155
- outputs=[output_slider],
156
- )
157
-
158
- submit_button.click(
159
- evaluate,
160
- inputs=[fg_image, bg_image, num_inference_steps],
161
- outputs=[output_slider],
162
- show_progress="full",
163
- show_api=False,
164
- )
165
 
166
- with gr.Row():
167
- gr.Examples(
168
- fn=evaluate,
169
- examples=[
170
- [
171
- "examples/foregrounds/2.jpg",
172
- "examples/backgrounds/14.jpg",
173
- 1,
174
- ],
175
- [
176
- "examples/foregrounds/10.jpg",
177
- "examples/backgrounds/4.jpg",
178
- 1,
179
- ],
180
- [
181
- "examples/foregrounds/11.jpg",
182
- "examples/backgrounds/24.jpg",
183
- 1,
184
- ],
185
- [
186
- "examples/foregrounds/19.jpg",
187
- "examples/backgrounds/3.jpg",
188
- 1,
189
- ],
190
- [
191
- "examples/foregrounds/4.jpg",
192
- "examples/backgrounds/6.jpg",
193
- 1,
194
- ],
195
- [
196
- "examples/foregrounds/14.jpg",
197
- "examples/backgrounds/22.jpg",
198
- 1,
199
- ],
200
- [
201
- "examples/foregrounds/12.jpg",
202
- "examples/backgrounds/1.jpg",
203
- 1,
204
- ],
205
- ],
206
- inputs=[fg_image, bg_image, num_inference_steps],
207
- outputs=[output_slider],
208
- run_on_click=True,
209
- )
210
-
211
- gr.Markdown("**Disclaimer:**")
212
- gr.Markdown(
213
- "This demo is only for research purpose. Jasper cannot be held responsible for the generation of NSFW (Not Safe For Work) content through the use of this demo. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards. Jasper provides the tools, but the responsibility for their use lies with the individual user."
214
- )
215
- gr.Markdown("**Note:** Some backgrounds example are taken from [IC-Light repo](https://github.com/lllyasviel/IC-Light)")
216
 
217
  def bg_gallery_selected(gal, evt: gr.SelectData):
218
- print(gal, evt.index)
219
  return gal[evt.index][0]
220
 
221
  bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=bg_image)
222
 
223
  if __name__ == "__main__":
224
-
225
  demo.queue().launch(show_api=False)
 
31
  str(1920 / 512): (1920, 512),
32
  }
33
 
 
34
  MODEL_PATH = hf_hub_download("jasperai/LBM_relighting", "model.safetensors", token=huggingface_token)
35
  CONFIG_PATH = hf_hub_download("jasperai/LBM_relighting", "config.yaml", token=huggingface_token)
36
 
 
40
  sd = load_file(MODEL_PATH)
41
  model.load_state_dict(sd, strict=True)
42
  model.to("cuda").to(torch.bfloat16)
43
+ birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True).cuda()
 
 
44
  image_size = (1024, 1024)
45
 
 
46
  @spaces.GPU
47
+ def evaluate(fg_image: PIL.Image.Image, bg_image: PIL.Image.Image, num_sampling_steps: int = 1):
 
 
 
 
 
 
48
  ori_h_bg, ori_w_bg = fg_image.size
49
  ar_bg = ori_h_bg / ori_w_bg
50
  closest_ar_bg = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - ar_bg))
 
57
  bg_image = resize_and_center_crop(bg_image, dimensions_bg[0], dimensions_bg[1])
58
 
59
  img_pasted = Image.composite(fg_image, bg_image, fg_mask)
 
60
  img_pasted_tensor = ToTensor()(img_pasted).unsqueeze(0) * 2 - 1
61
+ batch = {"source_image": img_pasted_tensor.cuda().to(torch.bfloat16)}
 
 
62
 
63
  z_source = model.vae.encode(batch[model.source_key])
64
+ output_image = model.sample(z=z_source, num_steps=num_sampling_steps, conditioner_inputs=batch, max_samples=1).clamp(-1, 1)
 
 
 
 
 
 
 
65
  output_image = (output_image[0].float().cpu() + 1) / 2
66
  output_image = ToPILImage()(output_image)
 
 
67
  output_image = Image.composite(output_image, bg_image, fg_mask)
 
68
  output_image.resize((ori_h_bg, ori_w_bg))
 
69
 
70
  return (np.array(img_pasted), np.array(output_image))
71
 
72
+ with gr.Blocks(css="""
73
+ body::before {
74
+ content: "";
75
+ display: block;
76
+ height: 640px;
77
+ background-color: #0f1117;
78
+ }
79
+ button[aria-label="Download"] {
80
+ transform: scale(2);
81
+ transform-origin: top right;
82
+ margin: 0 !important;
83
+ padding: 6px !important;
84
+ }
85
+ button[aria-label="Share"] {
86
+ display: none;
87
+ }
88
+ button[aria-label="Copy link"] {
89
+ display: none;
90
+ }
91
+ button[aria-label="Open in new tab"] {
92
+ display: none;
93
+ }
94
+ """, title="LBM Object Relighting") as demo:
95
+ gr.Markdown("# 🌄 Rindriçim i Objektit me Sfondin e Zgjedhur")
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  with gr.Row():
98
  with gr.Column():
99
  with gr.Row():
100
+ fg_image = gr.Image(type="pil", label="Imazhi Kryesor", image_mode="RGB", height=360)
101
+ bg_image = gr.Image(type="pil", label="Sfondi i Ri", image_mode="RGB", height=360)
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  with gr.Row():
104
+ submit_button = gr.Button("Rindriço", variant="primary")
105
  with gr.Row():
106
+ num_inference_steps = gr.Slider(minimum=1, maximum=4, value=1, step=1, label="Hapat e Inferencës")
 
 
 
 
 
 
107
 
108
  bg_gallery = gr.Gallery(
 
109
  object_fit="contain",
110
+ label="Sfondet", value=[path for path in glob.glob("examples/backgrounds/*.jpg")],
111
+ columns=5, allow_preview=False)
 
 
 
112
 
113
  with gr.Column():
114
+ output_slider = gr.ImageSlider(label="Para / Pas", type="numpy", show_share_button=False)
115
+ output_slider.upload(fn=evaluate, inputs=[fg_image, bg_image, num_inference_steps], outputs=[output_slider])
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
+ submit_button.click(evaluate, inputs=[fg_image, bg_image, num_inference_steps], outputs=[output_slider], show_progress="full", show_api=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
def bg_gallery_selected(gal, evt: gr.SelectData):
    """Gallery .select() callback: return the file path of the clicked item.

    Each gallery entry is an (path, caption) pair; element 0 is the path
    that gets forwarded into the background image component.
    """
    clicked_entry = gal[evt.index]
    return clicked_entry[0]
121
 
122
  bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=bg_image)
123
 
124
if __name__ == "__main__":
    # Enable request queuing (serializes concurrent users onto the GPU worker)
    # and start the web UI without exposing the auto-generated API page.
    app = demo.queue()
    app.launch(show_api=False)