HAL1993 committed on
Commit
522960f
·
verified ·
1 Parent(s): a673b7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -121
app.py CHANGED
@@ -13,11 +13,9 @@ from diffusers.models.attention_processor import AttnProcessor2_0
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
  from briarmbg import BriaRMBG
15
  from enum import Enum
16
- # from torch.hub import download_url_to_file
17
 
18
-
19
- # 'stablediffusionapi/realistic-vision-v51'
20
- # 'runwayml/stable-diffusion-v1-5'
21
  sd15_name = 'stablediffusionapi/realistic-vision-v51'
22
  tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
23
  text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
@@ -25,8 +23,7 @@ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
25
  unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
26
  rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
27
 
28
- # Change UNet
29
-
30
  with torch.no_grad():
31
  new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
32
  new_conv_in.weight.zero_()
@@ -36,7 +33,6 @@ with torch.no_grad():
36
 
37
  unet_original_forward = unet.forward
38
 
39
-
40
  def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
41
  c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
42
  c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
@@ -44,13 +40,10 @@ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
44
  kwargs['cross_attention_kwargs'] = {}
45
  return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
46
 
47
-
48
  unet.forward = hooked_unet_forward
49
 
50
- # Load
51
-
52
  model_path = './models/iclight_sd15_fc.safetensors'
53
- # download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', dst=model_path)
54
  sd_offset = sf.load_file(model_path)
55
  sd_origin = unet.state_dict()
56
  keys = sd_origin.keys()
@@ -58,21 +51,18 @@ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
58
  unet.load_state_dict(sd_merged, strict=True)
59
  del sd_offset, sd_origin, sd_merged, keys
60
 
61
- # Device
62
-
63
  device = torch.device('cuda')
64
  text_encoder = text_encoder.to(device=device, dtype=torch.float16)
65
  vae = vae.to(device=device, dtype=torch.bfloat16)
66
  unet = unet.to(device=device, dtype=torch.float16)
67
  rmbg = rmbg.to(device=device, dtype=torch.float32)
68
 
69
- # SDP
70
-
71
  unet.set_attn_processor(AttnProcessor2_0())
72
  vae.set_attn_processor(AttnProcessor2_0())
73
 
74
- # Samplers
75
-
76
  ddim_scheduler = DDIMScheduler(
77
  num_train_timesteps=1000,
78
  beta_start=0.00085,
@@ -99,8 +89,7 @@ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
99
  steps_offset=1
100
  )
101
 
102
- # Pipelines
103
-
104
  t2i_pipe = StableDiffusionPipeline(
105
  vae=vae,
106
  text_encoder=text_encoder,
@@ -125,7 +114,28 @@ i2i_pipe = StableDiffusionImg2ImgPipeline(
125
  image_encoder=None
126
  )
127
 
128
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  @torch.inference_mode()
130
  def encode_prompt_inner(txt: str):
131
  max_length = tokenizer.model_max_length
@@ -146,7 +156,6 @@ def encode_prompt_inner(txt: str):
146
 
147
  return conds
148
 
149
-
150
  @torch.inference_mode()
151
  def encode_prompt_pair(positive_prompt, negative_prompt):
152
  c = encode_prompt_inner(positive_prompt)
@@ -167,31 +176,26 @@ def encode_prompt_pair(positive_prompt, negative_prompt):
167
 
168
  return c, uc
169
 
170
-
171
  @torch.inference_mode()
172
  def pytorch2numpy(imgs, quant=True):
173
  results = []
174
  for x in imgs:
175
  y = x.movedim(0, -1)
176
-
177
  if quant:
178
  y = y * 127.5 + 127.5
179
  y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
180
  else:
181
  y = y * 0.5 + 0.5
182
  y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
183
-
184
  results.append(y)
185
  return results
186
 
187
-
188
  @torch.inference_mode()
189
  def numpy2pytorch(imgs):
190
- h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
191
  h = h.movedim(-1, 1)
192
  return h
193
 
194
-
195
  def resize_and_center_crop(image, target_width, target_height):
196
  pil_image = Image.fromarray(image)
197
  original_width, original_height = pil_image.size
@@ -206,13 +210,11 @@ def resize_and_center_crop(image, target_width, target_height):
206
  cropped_image = resized_image.crop((left, top, right, bottom))
207
  return np.array(cropped_image)
208
 
209
-
210
  def resize_without_crop(image, target_width, target_height):
211
  pil_image = Image.fromarray(image)
212
  resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
213
  return np.array(resized_image)
214
 
215
-
216
  @torch.inference_mode()
217
  def run_rmbg(img, sigma=0.0):
218
  H, W, C = img.shape
@@ -227,7 +229,6 @@ def run_rmbg(img, sigma=0.0):
227
  result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
228
  return result.clip(0, 255).astype(np.uint8), alpha
229
 
230
-
231
  @torch.inference_mode()
232
  def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
233
  bg_source = BGSource(bg_source)
@@ -332,102 +333,71 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
332
 
333
  return pytorch2numpy(pixels)
334
 
335
-
336
  @spaces.GPU
337
  @torch.inference_mode()
338
  def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
 
 
 
 
 
339
  input_fg, matting = run_rmbg(input_fg)
340
- results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
341
  return input_fg, results
342
 
343
-
344
- quick_prompts = [
345
- 'sunshine from window',
346
- 'neon light, city',
347
- 'sunset over sea',
348
- 'golden time',
349
- 'sci-fi RGB glowing, cyberpunk',
350
- 'natural lighting',
351
- 'warm atmosphere, at home, bedroom',
352
- 'magic lit',
353
- 'evil, gothic, Yharnam',
354
- 'light and shadow',
355
- 'shadow from window',
356
- 'soft studio lighting',
357
- 'home atmosphere, cozy bedroom illumination',
358
- 'neon, Wong Kar-wai, warm'
359
- ]
360
- quick_prompts = [[x] for x in quick_prompts]
361
-
362
-
363
- quick_subjects = [
364
- 'beautiful woman, detailed face',
365
- 'handsome man, detailed face',
366
- ]
367
- quick_subjects = [[x] for x in quick_subjects]
368
-
369
-
370
  class BGSource(Enum):
371
- NONE = "None"
372
- LEFT = "Left Light"
373
- RIGHT = "Right Light"
374
- TOP = "Top Light"
375
- BOTTOM = "Bottom Light"
376
-
377
-
378
- block = gr.Blocks().queue()
379
- with block:
380
- with gr.Row():
381
- gr.Markdown("## IC-Light (Relighting with Foreground Condition)")
382
- with gr.Row():
383
- gr.Markdown("See also https://github.com/lllyasviel/IC-Light for background-conditioned model and normal estimation")
384
- with gr.Row():
385
- with gr.Column():
386
- with gr.Row():
387
- input_fg = gr.Image(sources='upload', type="numpy", label="Image", height=480)
388
- output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
389
- prompt = gr.Textbox(label="Prompt")
390
- bg_source = gr.Radio(choices=[e.value for e in BGSource],
391
- value=BGSource.NONE.value,
392
- label="Lighting Preference (Initial Latent)", type='value')
393
- example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
394
- example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
395
- relight_button = gr.Button(value="Relight")
396
-
397
- with gr.Group():
398
- with gr.Row():
399
- num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
400
- seed = gr.Number(label="Seed", value=12345, precision=0)
401
-
402
- with gr.Row():
403
- image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
404
- image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
405
-
406
- with gr.Accordion("Advanced options", open=False):
407
- steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1)
408
- cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01)
409
- lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
410
- highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
411
- highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01)
412
- a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
413
- n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
414
  with gr.Column():
415
- result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
416
- with gr.Row():
417
- dummy_image_for_outputs = gr.Image(visible=False, label='Result')
418
- gr.Examples(
419
- fn=lambda *args: [[args[-1]], "imgs/dummy.png"],
420
- examples=db_examples.foreground_conditioned_examples,
421
- inputs=[
422
- input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
423
- ],
424
- outputs=[result_gallery, output_bg],
425
- run_on_click=True, examples_per_page=1024
426
- )
427
- ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
428
- relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
429
- example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
430
- example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
431
-
432
-
433
- block.launch(server_name='0.0.0.0')
 
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
  from briarmbg import BriaRMBG
15
  from enum import Enum
16
+ import requests
17
 
18
+ # Model setup (unchanged)
 
 
19
  sd15_name = 'stablediffusionapi/realistic-vision-v51'
20
  tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
21
  text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
 
23
  unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
24
  rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
25
 
26
+ # Change UNet (unchanged)
 
27
  with torch.no_grad():
28
  new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
29
  new_conv_in.weight.zero_()
 
33
 
34
  unet_original_forward = unet.forward
35
 
 
36
  def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
37
  c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
38
  c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
 
40
  kwargs['cross_attention_kwargs'] = {}
41
  return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
42
 
 
43
  unet.forward = hooked_unet_forward
44
 
45
+ # Load model (unchanged)
 
46
  model_path = './models/iclight_sd15_fc.safetensors'
 
47
  sd_offset = sf.load_file(model_path)
48
  sd_origin = unet.state_dict()
49
  keys = sd_origin.keys()
 
51
  unet.load_state_dict(sd_merged, strict=True)
52
  del sd_offset, sd_origin, sd_merged, keys
53
 
54
+ # Device setup (unchanged)
 
55
  device = torch.device('cuda')
56
  text_encoder = text_encoder.to(device=device, dtype=torch.float16)
57
  vae = vae.to(device=device, dtype=torch.bfloat16)
58
  unet = unet.to(device=device, dtype=torch.float16)
59
  rmbg = rmbg.to(device=device, dtype=torch.float32)
60
 
61
+ # SDP (unchanged)
 
62
  unet.set_attn_processor(AttnProcessor2_0())
63
  vae.set_attn_processor(AttnProcessor2_0())
64
 
65
+ # Samplers (unchanged)
 
66
  ddim_scheduler = DDIMScheduler(
67
  num_train_timesteps=1000,
68
  beta_start=0.00085,
 
89
  steps_offset=1
90
  )
91
 
92
+ # Pipelines (unchanged)
 
93
  t2i_pipe = StableDiffusionPipeline(
94
  vae=vae,
95
  text_encoder=text_encoder,
 
114
  image_encoder=None
115
  )
116
 
117
+ # Translation function (adapted from example)
118
@spaces.GPU
def translate_albanian_to_english(text):
    """Translate Albanian *text* to English via the remote MDF translation service.

    Returns "" for blank input. The request is attempted twice; if both
    attempts fail, returns a message starting with "Përkthimi dështoi"
    (callers detect failure by checking that prefix).

    NOTE(review): this is a pure network call; the @spaces.GPU decorator is
    kept for interface compatibility but presumably unnecessary — confirm.
    """
    if not text.strip():
        return ""
    last_error = None
    for attempt in range(2):
        try:
            response = requests.post(
                "https://hal1993-mdftranslation1234567890abcdef1234567890-fc073a6.hf.space/v1/translate",
                json={"from_language": "sq", "to_language": "en", "input_text": text},
                headers={"accept": "application/json", "Content-Type": "application/json"},
                timeout=5,
            )
            response.raise_for_status()
            return response.json().get("translate", "")
        except Exception as e:  # network / HTTP / JSON decoding errors
            # BUG FIX: the original returned "Përkthimi dështoi" immediately on
            # the first failure, so the second attempt in range(2) never ran.
            # Remember the error and let the loop retry instead.
            last_error = e
    return f"Përkthimi dështoi: {last_error}"
137
+
138
+ # Core processing functions (unchanged)
139
  @torch.inference_mode()
140
  def encode_prompt_inner(txt: str):
141
  max_length = tokenizer.model_max_length
 
156
 
157
  return conds
158
 
 
159
  @torch.inference_mode()
160
  def encode_prompt_pair(positive_prompt, negative_prompt):
161
  c = encode_prompt_inner(positive_prompt)
 
176
 
177
  return c, uc
178
 
 
179
@torch.inference_mode()
def pytorch2numpy(imgs, quant=True):
    """Convert CHW torch tensors in [-1, 1] to HWC numpy images.

    With quant=True the output is uint8 in [0, 255]; otherwise float32
    in [0, 1].
    """
    converted = []
    for img in imgs:
        channels_last = img.movedim(0, -1)
        if quant:
            scaled = channels_last * 127.5 + 127.5
            arr = scaled.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
        else:
            scaled = channels_last * 0.5 + 0.5
            arr = scaled.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
        converted.append(arr)
    return converted
192
 
 
193
@torch.inference_mode()
def numpy2pytorch(imgs):
    """Stack HWC uint8 images into a float NCHW tensor.

    Division by 127.0 maps pixel value 127 to exactly 0.0 (so 255 maps
    slightly above 1.0), matching the model's expected normalization.
    """
    batch = np.stack(imgs, axis=0)
    tensor = torch.from_numpy(batch).float() / 127.0 - 1.0
    return tensor.movedim(-1, 1)
198
 
 
199
  def resize_and_center_crop(image, target_width, target_height):
200
  pil_image = Image.fromarray(image)
201
  original_width, original_height = pil_image.size
 
210
  cropped_image = resized_image.crop((left, top, right, bottom))
211
  return np.array(cropped_image)
212
 
 
213
def resize_without_crop(image, target_width, target_height):
    """Resize a numpy image to exactly (target_width, target_height).

    Aspect ratio is not preserved; uses Lanczos resampling.
    """
    pil = Image.fromarray(image)
    resized = pil.resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized)
217
 
 
218
  @torch.inference_mode()
219
  def run_rmbg(img, sigma=0.0):
220
  H, W, C = img.shape
 
229
  result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
230
  return result.clip(0, 255).astype(np.uint8), alpha
231
 
 
232
  @torch.inference_mode()
233
  def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
234
  bg_source = BGSource(bg_source)
 
333
 
334
  return pytorch2numpy(pixels)
335
 
 
336
@spaces.GPU
@torch.inference_mode()
def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
    """Strip the background, translate the Albanian prompt, and relight.

    Returns (preprocessed foreground, rendered results), or (None, None)
    when translation fails.
    """
    # The diffusion model expects English; the UI prompt is Albanian.
    english_prompt = translate_albanian_to_english(prompt)
    if english_prompt.startswith("Përkthimi dështoi"):
        # Translation failure marker — bail out rather than render garbage.
        return None, None
    fg, _matting = run_rmbg(input_fg)
    rendered = process(fg, english_prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
    return fg, rendered
347
 
348
+ # Enum for background source (translated to Albanian)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
class BGSource(Enum):
    """Lighting-direction presets for the initial latent.

    Values are the Albanian labels shown in the UI radio group; `process`
    reconstructs a member from the selected label via `BGSource(bg_source)`.
    """
    NONE = "Asnjë"
    LEFT = "Dritë nga e Majta"
    RIGHT = "Dritë nga e Djathta"
    TOP = "Dritë nga Sipër"
    BOTTOM = "Dritë nga Poshtë"
355
+
356
+ # UI Layout
357
def create_demo():
    """Build and return the Gradio Blocks UI (Albanian labels).

    Layout: custom CSS (top gap, hidden fullscreen/share buttons, enlarged
    download button), an upload column, and a relight button wired to
    `process_relight` with fixed generation settings.
    """
    with gr.Blocks() as block:
        # CSS for 320px gap and download button scaling
        gr.HTML("""
        <style>
        body::before {
            content: "";
            display: block;
            height: 320px;
            background-color: var(--body-background-fill);
        }
        button[aria-label="Fullscreen"], button[aria-label="Fullscreen"]:hover {
            display: none !important;
            visibility: hidden !important;
            opacity: 0 !important;
            pointer-events: none !important;
        }
        button[aria-label="Share"], button[aria-label="Share"]:hover {
            display: none !important;
        }
        button[aria-label="Download"] {
            transform: scale(3);
            transform-origin: top right;
            margin: 0 !important;
            padding: 6px !important;
        }
        </style>
        """)

        with gr.Column():
            input_fg = gr.Image(sources='upload', type="numpy", label="Imazhi i Ngarkuar", height=480)
            prompt = gr.Textbox(label="Përshkrimi", placeholder="Shkruani përshkrimin këtu")
            bg_source = gr.Radio(choices=[e.value for e in BGSource], value=BGSource.NONE.value, label="Preferenca e Ndriçimit", type='value')
            relight_button = gr.Button(value="Rindriço")
            result_gallery = gr.Gallery(label='Rezultatet', visible=False)  # hidden output
            output_bg = gr.Image(type="numpy", label="Parapërpunimi i Planit të Parë", visible=False)  # hidden output

        # BUG FIX: Gradio `inputs` must be components, not raw Python literals;
        # the fixed generation settings are wrapped in gr.State so the click
        # binding is valid. Order matches process_relight's signature:
        # (fg, prompt, width, height, samples, seed, steps, a_prompt,
        #  n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg).
        ips = [
            input_fg, prompt,
            gr.State(512), gr.State(640), gr.State(1), gr.State(12345), gr.State(25),
            gr.State('best quality'),
            gr.State('lowres, bad anatomy, bad hands, cropped, worst quality'),
            gr.State(2), gr.State(1.5), gr.State(0.5), gr.State(0.9),
            bg_source,
        ]
        relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])

    return block
399
+
400
+ if __name__ == "__main__":
401
+ print(f"Gradio version: {gr.__version__}")
402
+ app = create_demo()
403
+ app.launch(server_name='0.0.0.0')