HAL1993 committed on
Commit
a254cd6
·
verified ·
1 Parent(s): bcf2cd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -110
app.py CHANGED
@@ -4,8 +4,7 @@ import gradio as gr
4
  import numpy as np
5
  import torch
6
  import safetensors.torch as sf
7
- import db_examples
8
-
9
  from PIL import Image
10
  from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
11
  from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
@@ -13,11 +12,7 @@ from diffusers.models.attention_processor import AttnProcessor2_0
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
  from briarmbg import BriaRMBG
15
  from enum import Enum
16
- # from torch.hub import download_url_to_file
17
-
18
 
19
- # 'stablediffusionapi/realistic-vision-v51'
20
- # 'runwayml/stable-diffusion-v1-5'
21
  sd15_name = 'stablediffusionapi/realistic-vision-v51'
22
  tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
23
  text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
@@ -25,8 +20,6 @@ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
25
  unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
26
  rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
27
 
28
- # Change UNet
29
-
30
  with torch.no_grad():
31
  new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
32
  new_conv_in.weight.zero_()
@@ -36,7 +29,6 @@ with torch.no_grad():
36
 
37
  unet_original_forward = unet.forward
38
 
39
-
40
  def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
41
  c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
42
  c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
@@ -44,13 +36,9 @@ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
44
  kwargs['cross_attention_kwargs'] = {}
45
  return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
46
 
47
-
48
  unet.forward = hooked_unet_forward
49
 
50
- # Load
51
-
52
  model_path = './models/iclight_sd15_fc.safetensors'
53
- # download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', dst=model_path)
54
  sd_offset = sf.load_file(model_path)
55
  sd_origin = unet.state_dict()
56
  keys = sd_origin.keys()
@@ -58,21 +46,15 @@ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
58
  unet.load_state_dict(sd_merged, strict=True)
59
  del sd_offset, sd_origin, sd_merged, keys
60
 
61
- # Device
62
-
63
  device = torch.device('cuda')
64
  text_encoder = text_encoder.to(device=device, dtype=torch.float16)
65
  vae = vae.to(device=device, dtype=torch.bfloat16)
66
  unet = unet.to(device=device, dtype=torch.float16)
67
  rmbg = rmbg.to(device=device, dtype=torch.float32)
68
 
69
- # SDP
70
-
71
  unet.set_attn_processor(AttnProcessor2_0())
72
  vae.set_attn_processor(AttnProcessor2_0())
73
 
74
- # Samplers
75
-
76
  ddim_scheduler = DDIMScheduler(
77
  num_train_timesteps=1000,
78
  beta_start=0.00085,
@@ -99,8 +81,6 @@ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
99
  steps_offset=1
100
  )
101
 
102
- # Pipelines
103
-
104
  t2i_pipe = StableDiffusionPipeline(
105
  vae=vae,
106
  text_encoder=text_encoder,
@@ -125,7 +105,6 @@ i2i_pipe = StableDiffusionImg2ImgPipeline(
125
  image_encoder=None
126
  )
127
 
128
-
129
  @torch.inference_mode()
130
  def encode_prompt_inner(txt: str):
131
  max_length = tokenizer.model_max_length
@@ -146,7 +125,6 @@ def encode_prompt_inner(txt: str):
146
 
147
  return conds
148
 
149
-
150
  @torch.inference_mode()
151
  def encode_prompt_pair(positive_prompt, negative_prompt):
152
  c = encode_prompt_inner(positive_prompt)
@@ -167,7 +145,6 @@ def encode_prompt_pair(positive_prompt, negative_prompt):
167
 
168
  return c, uc
169
 
170
-
171
  @torch.inference_mode()
172
  def pytorch2numpy(imgs, quant=True):
173
  results = []
@@ -184,14 +161,12 @@ def pytorch2numpy(imgs, quant=True):
184
  results.append(y)
185
  return results
186
 
187
-
188
  @torch.inference_mode()
189
  def numpy2pytorch(imgs):
190
- h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
191
  h = h.movedim(-1, 1)
192
  return h
193
 
194
-
195
  def resize_and_center_crop(image, target_width, target_height):
196
  pil_image = Image.fromarray(image)
197
  original_width, original_height = pil_image.size
@@ -206,13 +181,11 @@ def resize_and_center_crop(image, target_width, target_height):
206
  cropped_image = resized_image.crop((left, top, right, bottom))
207
  return np.array(cropped_image)
208
 
209
-
210
  def resize_without_crop(image, target_width, target_height):
211
  pil_image = Image.fromarray(image)
212
  resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
213
  return np.array(resized_image)
214
 
215
-
216
  @torch.inference_mode()
217
  def run_rmbg(img, sigma=0.0):
218
  H, W, C = img.shape
@@ -227,9 +200,42 @@ def run_rmbg(img, sigma=0.0):
227
  result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
228
  return result.clip(0, 255).astype(np.uint8), alpha
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
 
231
  @torch.inference_mode()
232
- def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
 
 
 
 
 
 
 
 
 
 
233
  bg_source = BGSource(bg_source)
234
  input_bg = None
235
 
@@ -330,104 +336,61 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
330
 
331
  pixels = vae.decode(latents).sample
332
 
333
- return pytorch2numpy(pixels)
334
-
335
 
336
  @spaces.GPU
337
  @torch.inference_mode()
338
- def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
339
- input_fg, matting = run_rmbg(input_fg)
340
- results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
341
- return input_fg, results
342
-
343
-
344
- quick_prompts = [
345
- 'sunshine from window',
346
- 'neon light, city',
347
- 'sunset over sea',
348
- 'golden time',
349
- 'sci-fi RGB glowing, cyberpunk',
350
- 'natural lighting',
351
- 'warm atmosphere, at home, bedroom',
352
- 'magic lit',
353
- 'evil, gothic, Yharnam',
354
- 'light and shadow',
355
- 'shadow from window',
356
- 'soft studio lighting',
357
- 'home atmosphere, cozy bedroom illumination',
358
- 'neon, Wong Kar-wai, warm'
359
- ]
360
- quick_prompts = [[x] for x in quick_prompts]
361
-
362
-
363
- quick_subjects = [
364
- 'beautiful woman, detailed face',
365
- 'handsome man, detailed face',
366
- ]
367
- quick_subjects = [[x] for x in quick_subjects]
368
 
 
 
 
369
 
370
  class BGSource(Enum):
371
- NONE = "None"
372
- LEFT = "Left Light"
373
- RIGHT = "Right Light"
374
- TOP = "Top Light"
375
- BOTTOM = "Bottom Light"
376
-
377
 
378
  block = gr.Blocks().queue()
379
  with block:
380
- with gr.Row():
381
- gr.Markdown("## IC-Light (Relighting with Foreground Condition)")
382
- with gr.Row():
383
- gr.Markdown("See also https://github.com/lllyasviel/IC-Light for background-conditioned model and normal estimation")
384
  with gr.Row():
385
  with gr.Column():
386
  with gr.Row():
387
- input_fg = gr.Image(sources='upload', type="numpy", label="Image", height=480)
388
- output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
389
- prompt = gr.Textbox(label="Prompt")
390
  bg_source = gr.Radio(choices=[e.value for e in BGSource],
391
  value=BGSource.NONE.value,
392
- label="Lighting Preference (Initial Latent)", type='value')
393
- example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
394
- example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
395
- relight_button = gr.Button(value="Relight")
396
 
397
  with gr.Group():
398
  with gr.Row():
399
- num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
400
- seed = gr.Number(label="Seed", value=12345, precision=0)
401
 
402
  with gr.Row():
403
- image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
404
- image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
405
-
406
- with gr.Accordion("Advanced options", open=False):
407
- steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1)
408
- cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01)
409
- lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
410
- highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
411
- highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01)
412
- a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
413
- n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
414
  with gr.Column():
415
- result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
 
416
  with gr.Row():
417
- dummy_image_for_outputs = gr.Image(visible=False, label='Result')
418
- gr.Examples(
419
- fn=lambda *args: [[args[-1]], "imgs/dummy.png"],
420
- examples=db_examples.foreground_conditioned_examples,
421
- inputs=[
422
- input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
423
- ],
424
- outputs=[result_gallery, output_bg],
425
- run_on_click=True, examples_per_page=1024
426
- )
427
- ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
428
- relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
429
- example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
430
- example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
431
-
432
-
433
- block.launch(server_name='0.0.0.0')
 
4
  import numpy as np
5
  import torch
6
  import safetensors.torch as sf
7
+ import requests
 
8
  from PIL import Image
9
  from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
10
  from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
 
12
  from transformers import CLIPTextModel, CLIPTokenizer
13
  from briarmbg import BriaRMBG
14
  from enum import Enum
 
 
15
 
 
 
16
  sd15_name = 'stablediffusionapi/realistic-vision-v51'
17
  tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
18
  text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
 
20
  unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
21
  rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
22
 
 
 
23
  with torch.no_grad():
24
  new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
25
  new_conv_in.weight.zero_()
 
29
 
30
  unet_original_forward = unet.forward
31
 
 
32
  def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
33
  c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
34
  c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
 
36
  kwargs['cross_attention_kwargs'] = {}
37
  return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
38
 
 
39
  unet.forward = hooked_unet_forward
40
 
 
 
41
  model_path = './models/iclight_sd15_fc.safetensors'
 
42
  sd_offset = sf.load_file(model_path)
43
  sd_origin = unet.state_dict()
44
  keys = sd_origin.keys()
 
46
  unet.load_state_dict(sd_merged, strict=True)
47
  del sd_offset, sd_origin, sd_merged, keys
48
 
 
 
49
  device = torch.device('cuda')
50
  text_encoder = text_encoder.to(device=device, dtype=torch.float16)
51
  vae = vae.to(device=device, dtype=torch.bfloat16)
52
  unet = unet.to(device=device, dtype=torch.float16)
53
  rmbg = rmbg.to(device=device, dtype=torch.float32)
54
 
 
 
55
  unet.set_attn_processor(AttnProcessor2_0())
56
  vae.set_attn_processor(AttnProcessor2_0())
57
 
 
 
58
  ddim_scheduler = DDIMScheduler(
59
  num_train_timesteps=1000,
60
  beta_start=0.00085,
 
81
  steps_offset=1
82
  )
83
 
 
 
84
  t2i_pipe = StableDiffusionPipeline(
85
  vae=vae,
86
  text_encoder=text_encoder,
 
105
  image_encoder=None
106
  )
107
 
 
108
  @torch.inference_mode()
109
  def encode_prompt_inner(txt: str):
110
  max_length = tokenizer.model_max_length
 
125
 
126
  return conds
127
 
 
128
  @torch.inference_mode()
129
  def encode_prompt_pair(positive_prompt, negative_prompt):
130
  c = encode_prompt_inner(positive_prompt)
 
145
 
146
  return c, uc
147
 
 
148
  @torch.inference_mode()
149
  def pytorch2numpy(imgs, quant=True):
150
  results = []
 
161
  results.append(y)
162
  return results
163
 
 
164
@torch.inference_mode()
def numpy2pytorch(imgs):
    """Stack HWC uint8 images into one BCHW float tensor.

    Scaling is chosen so that pixel value 127 maps exactly to 0.0
    (range roughly [-1.0, 1.0]).
    """
    batch = np.stack(imgs, axis=0)
    scaled = torch.from_numpy(batch).float() / 127.0 - 1.0
    # Move the channel axis from last to second: (B, H, W, C) -> (B, C, H, W).
    return scaled.movedim(-1, 1)
169
 
 
170
  def resize_and_center_crop(image, target_width, target_height):
171
  pil_image = Image.fromarray(image)
172
  original_width, original_height = pil_image.size
 
181
  cropped_image = resized_image.crop((left, top, right, bottom))
182
  return np.array(cropped_image)
183
 
 
184
def resize_without_crop(image, target_width, target_height):
    """Resize *image* (HWC numpy array) to exactly target_width x target_height.

    Uses Lanczos resampling; the aspect ratio is NOT preserved (no cropping).
    Returns a numpy array.
    """
    stretched = Image.fromarray(image).resize(
        (target_width, target_height), Image.LANCZOS
    )
    return np.array(stretched)
188
 
 
189
  @torch.inference_mode()
190
  def run_rmbg(img, sigma=0.0):
191
  H, W, C = img.shape
 
200
  result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
201
  return result.clip(0, 255).astype(np.uint8), alpha
202
 
203
@spaces.GPU
def translate_albanian_to_english(text):
    """Translate Albanian text to English via the sepioo-facebook-translation API.

    Returns "" for blank input, the translated string on success, or an
    Albanian error message prefixed with "Përkthimi dështoi" after two
    failed attempts. Callers detect failure by checking for that prefix.
    """
    if not text.strip():
        return ""
    last_error = None
    # Retry once: transient network errors are common on Space cold starts.
    for attempt in range(2):
        try:
            response = requests.post(
                "https://hal1993-mdftranslation1234567890abcdef1234567890-fc073a6.hf.space/v1/translate",
                json={"from_language": "sq", "to_language": "en", "input_text": text},
                headers={"accept": "application/json", "Content-Type": "application/json"},
                timeout=5
            )
            response.raise_for_status()
            translated = response.json().get("translate", "")
            print(f"Translation response: {translated}")
            return translated
        except Exception as e:  # broad on purpose: any failure falls back to the error string
            last_error = e
            print(f"Translation error (attempt {attempt + 1}): {e}")
    # Both attempts failed; report the last error in Albanian for the UI.
    # (The original also had an unreachable bare "Përkthimi dështoi" return —
    # removed as dead code.)
    return f"Përkthimi dështoi: {str(last_error)}"
225
 
226
+ @spaces.GPU
227
  @torch.inference_mode()
228
+ def process(input_fg, prompt_albanian, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
229
+ if not input_fg.any():
230
+ return None, None, "Gabim: Nuk është dhënë asnjë imazh."
231
+
232
+ if not prompt_albanian.strip():
233
+ prompt = ""
234
+ else:
235
+ prompt = translate_albanian_to_english(prompt_albanian)
236
+ if prompt.startswith("Përkthimi dështoi"):
237
+ return None, None, prompt
238
+
239
  bg_source = BGSource(bg_source)
240
  input_bg = None
241
 
 
336
 
337
  pixels = vae.decode(latents).sample
338
 
339
+ return pytorch2numpy(pixels), "Imazhi u gjenerua me sukses."
 
340
 
341
@spaces.GPU
@torch.inference_mode()
def process_relight(input_fg, prompt_albanian, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
    """Remove the background from the input image, then run the relighting pipeline.

    Returns (preprocessed_foreground, result_images, status_message); on any
    error, (None, None, albanian_error_message).
    """
    # gr.Image(type="numpy") yields None when nothing was uploaded; the
    # original called .any() directly, which raises AttributeError on None.
    if input_fg is None or not input_fg.any():
        return None, None, "Gabim: Nuk është dhënë asnjë imazh."

    input_fg, matting = run_rmbg(input_fg)
    ret = process(input_fg, prompt_albanian, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
    # process() returns (images, status) on success but a (None, None, message)
    # triple on its input/translation error paths; unpacking that triple into
    # two names crashed here. Propagate the error triple unchanged instead.
    if len(ret) == 3:
        return ret
    results, status = ret
    return input_fg, results, status
350
 
351
class BGSource(Enum):
    """Lighting-direction choices for the initial latent, labelled in Albanian.

    Member values are the user-facing strings offered by the gr.Radio control;
    member declaration order determines the order of those choices, so do not
    reorder the members.
    """
    NONE = "Asnjë"              # "None" — no directional light
    LEFT = "Dritë nga Majtas"   # "Light from the Left"
    RIGHT = "Dritë nga Djathtas"  # "Light from the Right"
    TOP = "Dritë nga Sipër"     # "Light from Above"
    BOTTOM = "Dritë nga Poshtë"  # "Light from Below"
 
357
 
358
# --- Gradio UI (all labels in Albanian) ---
# NOTE(review): this rendering lost indentation; nesting below is reconstructed
# from the pre-change layout of the same UI — verify against the deployed app.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        with gr.Column():
            with gr.Row():
                # Uploaded image and the background-removed preview shown back to the user.
                input_fg = gr.Image(sources='upload', type="numpy", label="Imazhi i Hyrjes", height=480)
                output_bg = gr.Image(type="numpy", label="Sfondi i Përpunuar", height=480)
            # Prompt is entered in Albanian; translation happens server-side.
            prompt_albanian = gr.Textbox(label="Përshkrimi")
            # Choices come from BGSource member values (Albanian strings).
            bg_source = gr.Radio(choices=[e.value for e in BGSource],
                                 value=BGSource.NONE.value,
                                 label="Preferenca e Ndriçimit (Latenti Fillestar)", type='value')
            relight_button = gr.Button(value="Gjenero")

            with gr.Group():
                with gr.Row():
                    # Hidden controls: sample count fixed at 1, seed -1.
                    num_samples = gr.Slider(label="Numri i Imazheve", minimum=1, maximum=12, value=1, step=1, visible=False)
                    seed = gr.Number(label="Farë", value=-1, precision=0, visible=False)

                with gr.Row():
                    image_width = gr.Slider(label="Gjerësia e Imazhit", minimum=256, maximum=1024, value=512, step=64)
                    image_height = gr.Slider(label="Lartësia e Imazhit", minimum=256, maximum=1024, value=640, step=64)

            # Advanced options kept in the layout but hidden from end users.
            with gr.Accordion("Opsionet e Avancuara", open=False, visible=False):
                steps = gr.Slider(label="Hapat", minimum=1, maximum=100, value=50, step=1)
                cfg = gr.Slider(label="Shkalla CFG", minimum=1.0, maximum=32.0, value=2, step=0.01)
                lowres_denoise = gr.Slider(label="Denoise për Rezolutë të Ulët (për latent fillestar)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
                highres_scale = gr.Slider(label="Shkalla e Rezolutës së Lartë", minimum=1.0, maximum=3.0, value=2, step=0.01)
                highres_denoise = gr.Slider(label="Denoise për Rezolutë të Lartë", minimum=0.1, maximum=1.0, value=1, step=0.01)
                a_prompt = gr.Textbox(label="Përshkrim Shtesë", value='cilësi më e mirë')
                n_prompt = gr.Textbox(label="Përshkrim Negativ", value='rezolutë e ulët, anatomi e dobët, duar të dobëta, prerje, cilësi më e keqe')
        with gr.Column():
            result_gallery = gr.Gallery(height=832, object_fit='contain', label='Rezultatet')
            # Status line surfaces the Albanian success/error messages from process_relight.
            status = gr.Textbox(label="Statusi", interactive=False)
            with gr.Row():
                dummy_image_for_outputs = gr.Image(visible=False, label='Rezultati')
    # Input order must match process_relight's parameter order exactly.
    ips = [input_fg, prompt_albanian, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
    relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery, status])

block.launch(server_name='0.0.0.0')