HAL1993 committed on
Commit
9a7039a
·
verified ·
1 Parent(s): 5c10790

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -87
app.py CHANGED
@@ -15,7 +15,7 @@ from briarmbg import BriaRMBG
15
  from enum import Enum
16
  import requests
17
 
18
- # Model setup (unchanged)
19
  sd15_name = 'stablediffusionapi/realistic-vision-v51'
20
  tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
21
  text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
@@ -23,7 +23,7 @@ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
23
  unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
24
  rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
25
 
26
- # Change UNet (unchanged)
27
  with torch.no_grad():
28
  new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
29
  new_conv_in.weight.zero_()
@@ -42,27 +42,26 @@ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
42
 
43
  unet.forward = hooked_unet_forward
44
 
45
- # Load model (unchanged)
46
  model_path = './models/iclight_sd15_fc.safetensors'
47
  sd_offset = sf.load_file(model_path)
48
  sd_origin = unet.state_dict()
49
- keys = sd_origin.keys()
50
  sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
51
  unet.load_state_dict(sd_merged, strict=True)
52
- del sd_offset, sd_origin, sd_merged, keys
53
 
54
- # Device setup (unchanged)
55
  device = torch.device('cuda')
56
  text_encoder = text_encoder.to(device=device, dtype=torch.float16)
57
  vae = vae.to(device=device, dtype=torch.bfloat16)
58
  unet = unet.to(device=device, dtype=torch.float16)
59
  rmbg = rmbg.to(device=device, dtype=torch.float32)
60
 
61
- # SDP (unchanged)
62
  unet.set_attn_processor(AttnProcessor2_0())
63
  vae.set_attn_processor(AttnProcessor2_0())
64
 
65
- # Samplers (unchanged)
66
  ddim_scheduler = DDIMScheduler(
67
  num_train_timesteps=1000,
68
  beta_start=0.00085,
@@ -89,7 +88,7 @@ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
89
  steps_offset=1
90
  )
91
 
92
- # Pipelines (unchanged)
93
  t2i_pipe = StableDiffusionPipeline(
94
  vae=vae,
95
  text_encoder=text_encoder,
@@ -114,7 +113,7 @@ i2i_pipe = StableDiffusionImg2ImgPipeline(
114
  image_encoder=None
115
  )
116
 
117
- # Translation function (unchanged)
118
  @spaces.GPU
119
  def translate_albanian_to_english(text):
120
  if not text.strip():
@@ -132,10 +131,10 @@ def translate_albanian_to_english(text):
132
  return translated
133
  except Exception as e:
134
  if attempt == 1:
135
- return f"Përkthimi dështoi: {str(e)}"
136
- return f"Përkthimi dështoi"
137
 
138
- # Core processing functions (unchanged)
139
  @torch.inference_mode()
140
  def encode_prompt_inner(txt: str):
141
  max_length = tokenizer.model_max_length
@@ -153,7 +152,6 @@ def encode_prompt_inner(txt: str):
153
 
154
  token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
155
  conds = text_encoder(token_ids).last_hidden_state
156
-
157
  return conds
158
 
159
  @torch.inference_mode()
@@ -173,7 +171,6 @@ def encode_prompt_pair(positive_prompt, negative_prompt):
173
 
174
  c = torch.cat([p[None, ...] for p in c], dim=1)
175
  uc = torch.cat([p[None, ...] for p in uc], dim=1)
176
-
177
  return c, uc
178
 
179
  @torch.inference_mode()
@@ -231,6 +228,9 @@ def run_rmbg(img, sigma=0.0):
231
 
232
  @torch.inference_mode()
233
  def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
 
 
 
234
  bg_source = BGSource(bg_source)
235
  input_bg = None
236
 
@@ -253,42 +253,75 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
253
  image = np.tile(gradient, (1, image_width))
254
  input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
255
  else:
256
- raise 'Wrong initial latent!'
257
 
258
  rng = torch.Generator(device=device).manual_seed(int(seed))
259
 
260
- fg = resize_and_center_crop(input_fg, image_width, image_height)
261
-
262
- concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
263
- concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
264
-
265
- conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
267
- if input_bg is None:
268
- latents = t2i_pipe(
269
- prompt_embeds=conds,
270
- negative_prompt_embeds=unconds,
271
- width=image_width,
272
- height=image_height,
273
- num_inference_steps=steps,
274
- num_images_per_prompt=num_samples,
275
- generator=rng,
276
- output_type='latent',
277
- guidance_scale=cfg,
278
- cross_attention_kwargs={'concat_conds': concat_conds},
279
- ).images.to(vae.dtype) / vae.config.scaling_factor
280
- else:
281
- bg = resize_and_center_crop(input_bg, image_width, image_height)
282
- bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
283
- bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
284
  latents = i2i_pipe(
285
- image=bg_latent,
286
- strength=lowres_denoise,
287
  prompt_embeds=conds,
288
  negative_prompt_embeds=unconds,
289
  width=image_width,
290
  height=image_height,
291
- num_inference_steps=int(round(steps / lowres_denoise)),
292
  num_images_per_prompt=num_samples,
293
  generator=rng,
294
  output_type='latent',
@@ -296,54 +329,28 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
296
  cross_attention_kwargs={'concat_conds': concat_conds},
297
  ).images.to(vae.dtype) / vae.config.scaling_factor
298
 
299
- pixels = vae.decode(latents).sample
300
- pixels = pytorch2numpy(pixels)
301
- pixels = [resize_without_crop(
302
- image=p,
303
- target_width=int(round(image_width * highres_scale / 64.0) * 64),
304
- target_height=int(round(image_height * highres_scale / 64.0) * 64))
305
- for p in pixels]
306
-
307
- pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
308
- latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
309
- latents = latents.to(device=unet.device, dtype=unet.dtype)
310
-
311
- image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
312
-
313
- fg = resize_and_center_crop(input_fg, image_width, image_height)
314
- concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
315
- concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
316
-
317
- latents = i2i_pipe(
318
- image=latents,
319
- strength=highres_denoise,
320
- prompt_embeds=conds,
321
- negative_prompt_embeds=unconds,
322
- width=image_width,
323
- height=image_height,
324
- num_inference_steps=int(round(steps / highres_denoise)),
325
- num_images_per_prompt=num_samples,
326
- generator=rng,
327
- output_type='latent',
328
- guidance_scale=cfg,
329
- cross_attention_kwargs={'concat_conds': concat_conds},
330
- ).images.to(vae.dtype) / vae.config.scaling_factor
331
-
332
- pixels = vae.decode(latents).sample
333
-
334
- return pytorch2numpy(pixels)
335
 
336
  @spaces.GPU
337
  @torch.inference_mode()
338
  def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
 
 
 
339
  # Translate Albanian prompt to English
340
- prompt_english = translate_albanian_to_english(prompt)
341
- if prompt_english.startswith("Përkthimi dështoi"):
342
- return None, None
343
 
 
344
  input_fg, matting = run_rmbg(input_fg)
345
- results = process(input_fg, prompt_english, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
346
- return input_fg, results
 
 
 
347
 
348
  # Enum for background source (translated to Albanian)
349
  class BGSource(Enum):
@@ -402,7 +409,7 @@ def create_demo():
402
  """)
403
 
404
  gr.Markdown("# Rindriço Imazhin")
405
- gr.Markdown("Rindriço imazhin duke ndryshuar sfondin bazuar në përshkrimin e dhënë")
406
 
407
  with gr.Row():
408
  with gr.Column(elem_classes="constrained-container"):
@@ -412,7 +419,7 @@ def create_demo():
412
  aspect_ratio = gr.Radio(choices=["9:16", "1:1", "16:9"], value="1:1", label="Raporti i Aspektit")
413
  relight_button = gr.Button(value="Rindriço")
414
  result_image = gr.Image(label="Rezultati", type="numpy", height=480, width=480, elem_classes="constrained-container")
415
- # Hidden components for other parameters and output_bg
416
  image_width = gr.Slider(label="Gjerësia e Imazhit", minimum=256, maximum=1024, value=640, step=64, visible=False)
417
  image_height = gr.Slider(label="Lartësia e Imazhit", minimum=256, maximum=1024, value=640, step=64, visible=False)
418
  num_samples = gr.Slider(label="Numri i Imazheve", minimum=1, maximum=12, value=1, step=1, visible=False)
@@ -424,7 +431,6 @@ def create_demo():
424
  highres_scale = gr.Slider(label="Shkalla e Rezolutës së Lartë", minimum=1.0, maximum=3.0, value=2, step=0.01, visible=False)
425
  highres_denoise = gr.Slider(label="Denoise i Rezolutës së Lartë", minimum=0.1, maximum=1.0, value=0.5, step=0.01, visible=False)
426
  lowres_denoise = gr.Slider(label="Denoise i Rezolutës së Ulët", minimum=0.1, maximum=1.0, value=0.9, step=0.01, visible=False)
427
- output_bg = gr.Image(type="numpy", label="Parapërpunimi i Planit të Parë", visible=False)
428
 
429
  # Update hidden sliders based on aspect ratio
430
  aspect_ratio.change(
@@ -438,7 +444,7 @@ def create_demo():
438
  input_fg, prompt, image_width, image_height, num_samples, seed, steps,
439
  a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source
440
  ]
441
- relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_image])
442
 
443
  return block
444
 
 
15
  from enum import Enum
16
  import requests
17
 
18
+ # Model setup
19
  sd15_name = 'stablediffusionapi/realistic-vision-v51'
20
  tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
21
  text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
 
23
  unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
24
  rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
25
 
26
+ # Change UNet
27
  with torch.no_grad():
28
  new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
29
  new_conv_in.weight.zero_()
 
42
 
43
  unet.forward = hooked_unet_forward
44
 
45
+ # Load model
46
  model_path = './models/iclight_sd15_fc.safetensors'
47
  sd_offset = sf.load_file(model_path)
48
  sd_origin = unet.state_dict()
 
49
  sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
50
  unet.load_state_dict(sd_merged, strict=True)
51
+ del sd_offset, sd_origin, sd_merged
52
 
53
+ # Device setup
54
  device = torch.device('cuda')
55
  text_encoder = text_encoder.to(device=device, dtype=torch.float16)
56
  vae = vae.to(device=device, dtype=torch.bfloat16)
57
  unet = unet.to(device=device, dtype=torch.float16)
58
  rmbg = rmbg.to(device=device, dtype=torch.float32)
59
 
60
+ # SDP
61
  unet.set_attn_processor(AttnProcessor2_0())
62
  vae.set_attn_processor(AttnProcessor2_0())
63
 
64
+ # Samplers
65
  ddim_scheduler = DDIMScheduler(
66
  num_train_timesteps=1000,
67
  beta_start=0.00085,
 
88
  steps_offset=1
89
  )
90
 
91
+ # Pipelines
92
  t2i_pipe = StableDiffusionPipeline(
93
  vae=vae,
94
  text_encoder=text_encoder,
 
113
  image_encoder=None
114
  )
115
 
116
+ # Translation function
117
  @spaces.GPU
118
  def translate_albanian_to_english(text):
119
  if not text.strip():
 
131
  return translated
132
  except Exception as e:
133
  if attempt == 1:
134
+ raise gr.Error(f"Përkthimi dështoi: {str(e)}")
135
+ raise gr.Error("Përkthimi dështoi. Ju lutem provoni përsëri.")
136
 
137
+ # Core processing functions
138
  @torch.inference_mode()
139
  def encode_prompt_inner(txt: str):
140
  max_length = tokenizer.model_max_length
 
152
 
153
  token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
154
  conds = text_encoder(token_ids).last_hidden_state
 
155
  return conds
156
 
157
  @torch.inference_mode()
 
171
 
172
  c = torch.cat([p[None, ...] for p in c], dim=1)
173
  uc = torch.cat([p[None, ...] for p in uc], dim=1)
 
174
  return c, uc
175
 
176
  @torch.inference_mode()
 
228
 
229
  @torch.inference_mode()
230
  def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
231
+ if input_fg is None:
232
+ raise gr.Error("Ju lutem ngarkoni një imazh.")
233
+
234
  bg_source = BGSource(bg_source)
235
  input_bg = None
236
 
 
253
  image = np.tile(gradient, (1, image_width))
254
  input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
255
  else:
256
+ raise gr.Error("Preferenca e ndriçimit është e pavlefshme!")
257
 
258
  rng = torch.Generator(device=device).manual_seed(int(seed))
259
 
260
+ try:
261
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
262
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
263
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
264
+
265
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
266
+
267
+ if input_bg is None:
268
+ latents = t2i_pipe(
269
+ prompt_embeds=conds,
270
+ negative_prompt_embeds=unconds,
271
+ width=image_width,
272
+ height=image_height,
273
+ num_inference_steps=steps,
274
+ num_images_per_prompt=num_samples,
275
+ generator=rng,
276
+ output_type='latent',
277
+ guidance_scale=cfg,
278
+ cross_attention_kwargs={'concat_conds': concat_conds},
279
+ ).images.to(vae.dtype) / vae.config.scaling_factor
280
+ else:
281
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
282
+ bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
283
+ bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
284
+ latents = i2i_pipe(
285
+ image=bg_latent,
286
+ strength=lowres_denoise,
287
+ prompt_embeds=conds,
288
+ negative_prompt_embeds=unconds,
289
+ width=image_width,
290
+ height=image_height,
291
+ num_inference_steps=int(round(steps / lowres_denoise)),
292
+ num_images_per_prompt=num_samples,
293
+ generator=rng,
294
+ output_type='latent',
295
+ guidance_scale=cfg,
296
+ cross_attention_kwargs={'concat_conds': concat_conds},
297
+ ).images.to(vae.dtype) / vae.config.scaling_factor
298
+
299
+ pixels = vae.decode(latents).sample
300
+ pixels = pytorch2numpy(pixels)
301
+ pixels = [resize_without_crop(
302
+ image=p,
303
+ target_width=int(round(image_width * highres_scale / 64.0) * 64),
304
+ target_height=int(round(image_height * highres_scale / 64.0) * 64))
305
+ for p in pixels]
306
+
307
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
308
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
309
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
310
+
311
+ image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
312
+
313
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
314
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
315
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  latents = i2i_pipe(
318
+ image=latents,
319
+ strength=highres_denoise,
320
  prompt_embeds=conds,
321
  negative_prompt_embeds=unconds,
322
  width=image_width,
323
  height=image_height,
324
+ num_inference_steps=int(round(steps / highres_denoise)),
325
  num_images_per_prompt=num_samples,
326
  generator=rng,
327
  output_type='latent',
 
329
  cross_attention_kwargs={'concat_conds': concat_conds},
330
  ).images.to(vae.dtype) / vae.config.scaling_factor
331
 
332
+ pixels = vae.decode(latents).sample
333
+ results = pytorch2numpy(pixels)
334
+ return results[0] # Return single image since num_samples=1
335
+ except Exception as e:
336
+ raise gr.Error(f"Gabim gjatë përpunimit imazhit: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
  @spaces.GPU
339
  @torch.inference_mode()
340
  def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
341
+ if input_fg is None:
342
+ raise gr.Error("Ju lutem ngarkoni një imazh.")
343
+
344
  # Translate Albanian prompt to English
345
+ prompt_english = translate_albanian_to_english(prompt.strip()) if prompt.strip() else ""
 
 
346
 
347
+ # Run background removal
348
  input_fg, matting = run_rmbg(input_fg)
349
+
350
+ # Process the image
351
+ result = process(input_fg, prompt_english, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
352
+
353
+ return result
354
 
355
  # Enum for background source (translated to Albanian)
356
  class BGSource(Enum):
 
409
  """)
410
 
411
  gr.Markdown("# Rindriço Imazhin")
412
+ gr.Markdown("Rindriço imazhin duke ndryshuar ndriçimin e sfondit bazuar në përshkrimin e dhënë")
413
 
414
  with gr.Row():
415
  with gr.Column(elem_classes="constrained-container"):
 
419
  aspect_ratio = gr.Radio(choices=["9:16", "1:1", "16:9"], value="1:1", label="Raporti i Aspektit")
420
  relight_button = gr.Button(value="Rindriço")
421
  result_image = gr.Image(label="Rezultati", type="numpy", height=480, width=480, elem_classes="constrained-container")
422
+ # Hidden components for other parameters
423
  image_width = gr.Slider(label="Gjerësia e Imazhit", minimum=256, maximum=1024, value=640, step=64, visible=False)
424
  image_height = gr.Slider(label="Lartësia e Imazhit", minimum=256, maximum=1024, value=640, step=64, visible=False)
425
  num_samples = gr.Slider(label="Numri i Imazheve", minimum=1, maximum=12, value=1, step=1, visible=False)
 
431
  highres_scale = gr.Slider(label="Shkalla e Rezolutës së Lartë", minimum=1.0, maximum=3.0, value=2, step=0.01, visible=False)
432
  highres_denoise = gr.Slider(label="Denoise i Rezolutës së Lartë", minimum=0.1, maximum=1.0, value=0.5, step=0.01, visible=False)
433
  lowres_denoise = gr.Slider(label="Denoise i Rezolutës së Ulët", minimum=0.1, maximum=1.0, value=0.9, step=0.01, visible=False)
 
434
 
435
  # Update hidden sliders based on aspect ratio
436
  aspect_ratio.change(
 
444
  input_fg, prompt, image_width, image_height, num_samples, seed, steps,
445
  a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source
446
  ]
447
+ relight_button.click(fn=process_relight, inputs=ips, outputs=result_image)
448
 
449
  return block
450