HAL1993 committed on
Commit
a673b7c
·
verified ·
1 Parent(s): 5ba8d58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -107
app.py CHANGED
@@ -4,7 +4,8 @@ import gradio as gr
4
  import numpy as np
5
  import torch
6
  import safetensors.torch as sf
7
- import requests
 
8
  from PIL import Image
9
  from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
10
  from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
@@ -12,7 +13,11 @@ from diffusers.models.attention_processor import AttnProcessor2_0
12
  from transformers import CLIPTextModel, CLIPTokenizer
13
  from briarmbg import BriaRMBG
14
  from enum import Enum
 
 
15
 
 
 
16
  sd15_name = 'stablediffusionapi/realistic-vision-v51'
17
  tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
18
  text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
@@ -20,6 +25,8 @@ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
20
  unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
21
  rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
22
 
 
 
23
  with torch.no_grad():
24
  new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
25
  new_conv_in.weight.zero_()
@@ -29,6 +36,7 @@ with torch.no_grad():
29
 
30
  unet_original_forward = unet.forward
31
 
 
32
  def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
33
  c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
34
  c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
@@ -36,9 +44,13 @@ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
36
  kwargs['cross_attention_kwargs'] = {}
37
  return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
38
 
 
39
  unet.forward = hooked_unet_forward
40
 
 
 
41
  model_path = './models/iclight_sd15_fc.safetensors'
 
42
  sd_offset = sf.load_file(model_path)
43
  sd_origin = unet.state_dict()
44
  keys = sd_origin.keys()
@@ -46,15 +58,21 @@ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
46
  unet.load_state_dict(sd_merged, strict=True)
47
  del sd_offset, sd_origin, sd_merged, keys
48
 
 
 
49
  device = torch.device('cuda')
50
  text_encoder = text_encoder.to(device=device, dtype=torch.float16)
51
  vae = vae.to(device=device, dtype=torch.bfloat16)
52
  unet = unet.to(device=device, dtype=torch.float16)
53
  rmbg = rmbg.to(device=device, dtype=torch.float32)
54
 
 
 
55
  unet.set_attn_processor(AttnProcessor2_0())
56
  vae.set_attn_processor(AttnProcessor2_0())
57
 
 
 
58
  ddim_scheduler = DDIMScheduler(
59
  num_train_timesteps=1000,
60
  beta_start=0.00085,
@@ -81,6 +99,8 @@ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
81
  steps_offset=1
82
  )
83
 
 
 
84
  t2i_pipe = StableDiffusionPipeline(
85
  vae=vae,
86
  text_encoder=text_encoder,
@@ -105,6 +125,7 @@ i2i_pipe = StableDiffusionImg2ImgPipeline(
105
  image_encoder=None
106
  )
107
 
 
108
  @torch.inference_mode()
109
  def encode_prompt_inner(txt: str):
110
  max_length = tokenizer.model_max_length
@@ -125,6 +146,7 @@ def encode_prompt_inner(txt: str):
125
 
126
  return conds
127
 
 
128
  @torch.inference_mode()
129
  def encode_prompt_pair(positive_prompt, negative_prompt):
130
  c = encode_prompt_inner(positive_prompt)
@@ -145,6 +167,7 @@ def encode_prompt_pair(positive_prompt, negative_prompt):
145
 
146
  return c, uc
147
 
 
148
  @torch.inference_mode()
149
  def pytorch2numpy(imgs, quant=True):
150
  results = []
@@ -161,12 +184,14 @@ def pytorch2numpy(imgs, quant=True):
161
  results.append(y)
162
  return results
163
 
 
164
  @torch.inference_mode()
165
  def numpy2pytorch(imgs):
166
- h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0
167
  h = h.movedim(-1, 1)
168
  return h
169
 
 
170
  def resize_and_center_crop(image, target_width, target_height):
171
  pil_image = Image.fromarray(image)
172
  original_width, original_height = pil_image.size
@@ -181,11 +206,13 @@ def resize_and_center_crop(image, target_width, target_height):
181
  cropped_image = resized_image.crop((left, top, right, bottom))
182
  return np.array(cropped_image)
183
 
 
184
  def resize_without_crop(image, target_width, target_height):
185
  pil_image = Image.fromarray(image)
186
  resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
187
  return np.array(resized_image)
188
 
 
189
  @torch.inference_mode()
190
  def run_rmbg(img, sigma=0.0):
191
  H, W, C = img.shape
@@ -200,42 +227,9 @@ def run_rmbg(img, sigma=0.0):
200
  result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
201
  return result.clip(0, 255).astype(np.uint8), alpha
202
 
203
- @spaces.GPU
204
- def translate_albanian_to_english(text):
205
- """Translate Albanian to English using sepioo-facebook-translation API."""
206
- if not text.strip():
207
- return ""
208
- for attempt in range(2):
209
- try:
210
- response = requests.post(
211
- "https://hal1993-mdftranslation1234567890abcdef1234567890-fc073a6.hf.space/v1/translate",
212
- json={"from_language": "sq", "to_language": "en", "input_text": text},
213
- headers={"accept": "application/json", "Content-Type": "application/json"},
214
- timeout=5
215
- )
216
- response.raise_for_status()
217
- translated = response.json().get("translate", "")
218
- print(f"Translation response: {translated}")
219
- return translated
220
- except Exception as e:
221
- print(f"Translation error (attempt {attempt + 1}): {e}")
222
- if attempt == 1:
223
- return f"PΓ«rkthimi dΓ«shtoi: {str(e)}"
224
- return f"PΓ«rkthimi dΓ«shtoi"
225
 
226
- @spaces.GPU
227
  @torch.inference_mode()
228
- def process(input_fg, prompt_albanian, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
229
- if not input_fg.any():
230
- return None, None
231
-
232
- if not prompt_albanian.strip():
233
- prompt = ""
234
- else:
235
- prompt = translate_albanian_to_english(prompt_albanian)
236
- if prompt.startswith("PΓ«rkthimi dΓ«shtoi"):
237
- return None, None
238
-
239
  bg_source = BGSource(bg_source)
240
  input_bg = None
241
 
@@ -338,96 +332,102 @@ def process(input_fg, prompt_albanian, image_width, image_height, num_samples, s
338
 
339
  return pytorch2numpy(pixels)
340
 
 
341
  @spaces.GPU
342
  @torch.inference_mode()
343
- def process_relight(input_fg, prompt_albanian, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
344
- if not input_fg.any():
345
- return None, None
346
-
347
  input_fg, matting = run_rmbg(input_fg)
348
- results = process(input_fg, prompt_albanian, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
349
  return input_fg, results
350
 
351
- def update_resolution(aspect_ratio):
352
- if aspect_ratio == "9:16":
353
- return 512, 910
354
- elif aspect_ratio == "1:1":
355
- return 640, 640
356
- elif aspect_ratio == "16:9":
357
- return 910, 512
358
- return 512, 910 # Default to 9:16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
 
360
  class BGSource(Enum):
361
- NONE = "AsnjΓ«"
362
- LEFT = "DritΓ« nga Majtas"
363
- RIGHT = "DritΓ« nga Djathtas"
364
- TOP = "DritΓ« nga SipΓ«r"
365
- BOTTOM = "DritΓ« nga PoshtΓ«"
366
-
367
- css = """
368
- body::before {
369
- content: "";
370
- display: block;
371
- height: 320px;
372
- background-color: var(--body-background-fill);
373
- }
374
- button[aria-label="Fullscreen"], button[aria-label="Fullscreen"]:hover {
375
- display: none !important;
376
- visibility: hidden !important;
377
- opacity: 0 !important;
378
- pointer-events: none !important;
379
- }
380
- button[aria-label="Share"], button[aria-label="Share"]:hover {
381
- display: none !important;
382
- }
383
- button[aria-label="Download"] {
384
- transform: scale(3);
385
- transform-origin: top right;
386
- margin: 0 !important;
387
- padding: 6px !important;
388
- }
389
- """
390
-
391
- block = gr.Blocks(css=css).queue()
392
  with block:
 
 
 
 
393
  with gr.Row():
394
  with gr.Column():
395
  with gr.Row():
396
- input_fg = gr.Image(sources='upload', type="numpy", label="Imazhi i Hyrjes", height=480)
397
- output_bg = gr.Image(type="numpy", label="Sfondi i PΓ«rpunuar", height=480)
398
- prompt_albanian = gr.Textbox(label="PΓ«rshkrimi")
399
  bg_source = gr.Radio(choices=[e.value for e in BGSource],
400
  value=BGSource.NONE.value,
401
- label="Preferenca e NdriΓ§imit (Latenti Fillestar)", type='value')
402
- aspect_ratio = gr.Radio(choices=["9:16", "1:1", "16:9"], value="9:16", label="Raporti i Imazhit")
403
- relight_button = gr.Button(value="Gjenero")
 
404
 
405
  with gr.Group():
406
  with gr.Row():
407
- num_samples = gr.Slider(label="Numri i Imazheve", minimum=1, maximum=12, value=1, step=1, visible=False)
408
- seed = gr.Number(label="FarΓ«", value=-1, precision=0, visible=False)
409
- image_width = gr.State(value=512)
410
- image_height = gr.State(value=910)
411
-
412
- with gr.Accordion("Opsionet e Avancuara", open=False, visible=False):
413
- steps = gr.Slider(label="Hapat", minimum=1, maximum=100, value=50, step=1)
414
- cfg = gr.Slider(label="Shkalla CFG", minimum=1.0, maximum=32.0, value=2, step=0.01)
415
- lowres_denoise = gr.Slider(label="Denoise pΓ«r RezolutΓ« tΓ« UlΓ«t (pΓ«r latent fillestar)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
416
- highres_scale = gr.Slider(label="Shkalla e RezolutΓ«s sΓ« LartΓ«", minimum=1.0, maximum=3.0, value=2, step=0.01)
417
- highres_denoise = gr.Slider(label="Denoise pΓ«r RezolutΓ« tΓ« LartΓ«", minimum=0.1, maximum=1.0, value=1, step=0.01)
 
 
418
  a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
419
  n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
420
  with gr.Column():
421
- result_gallery = gr.Gallery(height=832, object_fit='contain', label='Rezultatet')
422
  with gr.Row():
423
- dummy_image_for_outputs = gr.Image(visible=False, label='Rezultati')
424
- ips = [input_fg, prompt_albanian, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
425
- aspect_ratio.change(
426
- fn=update_resolution,
427
- inputs=[aspect_ratio],
428
- outputs=[image_width, image_height],
429
- queue=False
430
- )
 
 
 
431
  relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
 
 
 
432
 
433
- block.launch(server_name='0.0.0.0')
 
4
  import numpy as np
5
  import torch
6
  import safetensors.torch as sf
7
+ import db_examples
8
+
9
  from PIL import Image
10
  from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
11
  from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
 
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
  from briarmbg import BriaRMBG
15
  from enum import Enum
16
+ # from torch.hub import download_url_to_file
17
+
18
 
19
+ # 'stablediffusionapi/realistic-vision-v51'
20
+ # 'runwayml/stable-diffusion-v1-5'
21
  sd15_name = 'stablediffusionapi/realistic-vision-v51'
22
  tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
23
  text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
 
25
  unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
26
  rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
27
 
28
+ # Change UNet
29
+
30
  with torch.no_grad():
31
  new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
32
  new_conv_in.weight.zero_()
 
36
 
37
  unet_original_forward = unet.forward
38
 
39
+
40
  def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
41
  c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
42
  c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
 
44
  kwargs['cross_attention_kwargs'] = {}
45
  return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
46
 
47
+
48
  unet.forward = hooked_unet_forward
49
 
50
+ # Load
51
+
52
  model_path = './models/iclight_sd15_fc.safetensors'
53
+ # download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', dst=model_path)
54
  sd_offset = sf.load_file(model_path)
55
  sd_origin = unet.state_dict()
56
  keys = sd_origin.keys()
 
58
  unet.load_state_dict(sd_merged, strict=True)
59
  del sd_offset, sd_origin, sd_merged, keys
60
 
61
+ # Device
62
+
63
  device = torch.device('cuda')
64
  text_encoder = text_encoder.to(device=device, dtype=torch.float16)
65
  vae = vae.to(device=device, dtype=torch.bfloat16)
66
  unet = unet.to(device=device, dtype=torch.float16)
67
  rmbg = rmbg.to(device=device, dtype=torch.float32)
68
 
69
+ # SDP
70
+
71
  unet.set_attn_processor(AttnProcessor2_0())
72
  vae.set_attn_processor(AttnProcessor2_0())
73
 
74
+ # Samplers
75
+
76
  ddim_scheduler = DDIMScheduler(
77
  num_train_timesteps=1000,
78
  beta_start=0.00085,
 
99
  steps_offset=1
100
  )
101
 
102
+ # Pipelines
103
+
104
  t2i_pipe = StableDiffusionPipeline(
105
  vae=vae,
106
  text_encoder=text_encoder,
 
125
  image_encoder=None
126
  )
127
 
128
+
129
  @torch.inference_mode()
130
  def encode_prompt_inner(txt: str):
131
  max_length = tokenizer.model_max_length
 
146
 
147
  return conds
148
 
149
+
150
  @torch.inference_mode()
151
  def encode_prompt_pair(positive_prompt, negative_prompt):
152
  c = encode_prompt_inner(positive_prompt)
 
167
 
168
  return c, uc
169
 
170
+
171
  @torch.inference_mode()
172
  def pytorch2numpy(imgs, quant=True):
173
  results = []
 
184
  results.append(y)
185
  return results
186
 
187
+
188
  @torch.inference_mode()
189
  def numpy2pytorch(imgs):
190
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
191
  h = h.movedim(-1, 1)
192
  return h
193
 
194
+
195
  def resize_and_center_crop(image, target_width, target_height):
196
  pil_image = Image.fromarray(image)
197
  original_width, original_height = pil_image.size
 
206
  cropped_image = resized_image.crop((left, top, right, bottom))
207
  return np.array(cropped_image)
208
 
209
+
210
  def resize_without_crop(image, target_width, target_height):
211
  pil_image = Image.fromarray(image)
212
  resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
213
  return np.array(resized_image)
214
 
215
+
216
  @torch.inference_mode()
217
  def run_rmbg(img, sigma=0.0):
218
  H, W, C = img.shape
 
227
  result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
228
  return result.clip(0, 255).astype(np.uint8), alpha
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
 
231
  @torch.inference_mode()
232
+ def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
 
 
 
 
 
 
 
 
 
 
233
  bg_source = BGSource(bg_source)
234
  input_bg = None
235
 
 
332
 
333
  return pytorch2numpy(pixels)
334
 
335
+
336
  @spaces.GPU
337
  @torch.inference_mode()
338
+ def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
 
 
 
339
  input_fg, matting = run_rmbg(input_fg)
340
+ results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
341
  return input_fg, results
342
 
343
+
344
+ quick_prompts = [
345
+ 'sunshine from window',
346
+ 'neon light, city',
347
+ 'sunset over sea',
348
+ 'golden time',
349
+ 'sci-fi RGB glowing, cyberpunk',
350
+ 'natural lighting',
351
+ 'warm atmosphere, at home, bedroom',
352
+ 'magic lit',
353
+ 'evil, gothic, Yharnam',
354
+ 'light and shadow',
355
+ 'shadow from window',
356
+ 'soft studio lighting',
357
+ 'home atmosphere, cozy bedroom illumination',
358
+ 'neon, Wong Kar-wai, warm'
359
+ ]
360
+ quick_prompts = [[x] for x in quick_prompts]
361
+
362
+
363
+ quick_subjects = [
364
+ 'beautiful woman, detailed face',
365
+ 'handsome man, detailed face',
366
+ ]
367
+ quick_subjects = [[x] for x in quick_subjects]
368
+
369
 
370
  class BGSource(Enum):
371
+ NONE = "None"
372
+ LEFT = "Left Light"
373
+ RIGHT = "Right Light"
374
+ TOP = "Top Light"
375
+ BOTTOM = "Bottom Light"
376
+
377
+
378
+ block = gr.Blocks().queue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
  with block:
380
+ with gr.Row():
381
+ gr.Markdown("## IC-Light (Relighting with Foreground Condition)")
382
+ with gr.Row():
383
+ gr.Markdown("See also https://github.com/lllyasviel/IC-Light for background-conditioned model and normal estimation")
384
  with gr.Row():
385
  with gr.Column():
386
  with gr.Row():
387
+ input_fg = gr.Image(sources='upload', type="numpy", label="Image", height=480)
388
+ output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
389
+ prompt = gr.Textbox(label="Prompt")
390
  bg_source = gr.Radio(choices=[e.value for e in BGSource],
391
  value=BGSource.NONE.value,
392
+ label="Lighting Preference (Initial Latent)", type='value')
393
+ example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
394
+ example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
395
+ relight_button = gr.Button(value="Relight")
396
 
397
  with gr.Group():
398
  with gr.Row():
399
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
400
+ seed = gr.Number(label="Seed", value=12345, precision=0)
401
+
402
+ with gr.Row():
403
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
404
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
405
+
406
+ with gr.Accordion("Advanced options", open=False):
407
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1)
408
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01)
409
+ lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
410
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
411
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01)
412
  a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
413
  n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
414
  with gr.Column():
415
+ result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
416
  with gr.Row():
417
+ dummy_image_for_outputs = gr.Image(visible=False, label='Result')
418
+ gr.Examples(
419
+ fn=lambda *args: [[args[-1]], "imgs/dummy.png"],
420
+ examples=db_examples.foreground_conditioned_examples,
421
+ inputs=[
422
+ input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
423
+ ],
424
+ outputs=[result_gallery, output_bg],
425
+ run_on_click=True, examples_per_page=1024
426
+ )
427
+ ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
428
  relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
429
+ example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
430
+ example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
431
+
432
 
433
+ block.launch(server_name='0.0.0.0')