Bobby committed on
Commit
52c8e07
·
1 Parent(s): b8caffe
Files changed (1) hide show
  1. app.py +22 -38
app.py CHANGED
@@ -15,7 +15,6 @@ import spaces
15
  import gc
16
  import torch
17
  from PIL import Image
18
- import hashlib
19
  from diffusers import (
20
  ControlNetModel,
21
  DPMSolverMultistepScheduler,
@@ -240,9 +239,7 @@ def apply_style(style_name):
240
  p = styles.get(style_name, "boho chic")
241
  return p
242
 
243
- def image_hash(image):
244
- return hashlib.md5(image.tobytes()).hexdigest()
245
-
246
  css = """
247
  h1, h2, h3 {
248
  text-align: center;
@@ -360,40 +357,35 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
360
  guidance_scale,
361
  seed,
362
  ]
363
-
364
  with gr.Row():
365
  helper_text = gr.Markdown("## Tap and hold (on mobile) to save the image.", visible=True)
366
 
367
- previous_image_hash = None
368
-
369
  @gr.on(triggers=[image.upload, prompt.submit, run_button.click], inputs=config, outputs=result, show_progress="minimal")
370
  def auto_process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
371
- global previous_image_hash
372
- result, new_hash = process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, previous_image_hash)
373
- previous_image_hash = new_hash
374
- return result
375
 
 
376
  @gr.on(triggers=[use_ai_button.click], inputs=[result] + config, outputs=[image, result], show_progress="minimal")
377
  def submit(previous_result, image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
378
- global previous_image_hash
379
  yield previous_result, gr.update()
380
- new_result, new_hash = process_image(previous_result, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, previous_image_hash)
381
- previous_image_hash = new_hash
 
382
  yield previous_result, new_result
383
 
 
384
  @gr.on(triggers=[image.upload, use_ai_button.click, run_button.click], inputs=None, outputs=[run_button, use_ai_button], show_progress="hidden")
385
  def turn_buttons_off():
386
  return gr.update(visible=False), gr.update(visible=False)
387
 
 
388
  @gr.on(triggers=[result.change], inputs=None, outputs=[use_ai_button, run_button], show_progress="hidden")
389
  def turn_buttons_on():
390
  return gr.update(visible=True), gr.update(visible=True)
391
 
392
- @gr.on(triggers=[image.upload], inputs=None, outputs=None)
393
- def clear_image_hash():
394
- global previous_image_hash
395
- previous_image_hash = None
396
-
397
  @spaces.GPU(duration=12)
398
  @torch.inference_mode()
399
  def process_image(
@@ -408,34 +400,25 @@ def process_image(
408
  num_steps,
409
  guidance_scale,
410
  seed,
411
- previous_image_hash=None
412
  ):
413
- current_image_hash = image_hash(image)
414
-
415
  preprocess_start = time.time()
416
  print("processing image")
417
 
418
  seed = random.randint(0, MAX_SEED)
419
  generator = torch.cuda.manual_seed(seed)
420
-
421
- if previous_image_hash != current_image_hash:
422
- print("not the same image")
423
- preprocessor.load("NormalBae")
424
- control_image = preprocessor(
425
- image=image,
426
- image_resolution=image_resolution,
427
- detect_resolution=preprocess_resolution,
428
- )
429
- else:
430
- print("Image unchanged, skipping preprocessing")
431
- control_image = image
432
-
433
  preprocess_time = time.time() - preprocess_start
434
  if style_selection is not None or style_selection != "None":
435
  prompt = "Photo from Pinterest of " + apply_style(style_selection) + " " + prompt + " " + a_prompt
436
  else:
437
- prompt = str(get_prompt(prompt, a_prompt))
438
- negative_prompt = str(n_prompt)
439
  print(prompt)
440
  print(f"\n-------------------------Preprocess done in: {preprocess_time:.2f} seconds-------------------------")
441
  start = time.time()
@@ -449,8 +432,9 @@ def process_image(
449
  image=control_image,
450
  ).images[0]
451
  print(f"\n-------------------------Inference done in: {time.time() - start:.2f} seconds-------------------------")
 
452
  torch.cuda.empty_cache()
453
- return results, current_image_hash
454
 
455
  if prod:
456
  demo.queue(max_size=20).launch(server_name="localhost", server_port=port)
 
15
  import gc
16
  import torch
17
  from PIL import Image
 
18
  from diffusers import (
19
  ControlNetModel,
20
  DPMSolverMultistepScheduler,
 
239
  p = styles.get(style_name, "boho chic")
240
  return p
241
 
242
+
 
 
243
  css = """
244
  h1, h2, h3 {
245
  text-align: center;
 
357
  guidance_scale,
358
  seed,
359
  ]
360
+
361
  with gr.Row():
362
  helper_text = gr.Markdown("## Tap and hold (on mobile) to save the image.", visible=True)
363
 
364
+ # image processing
 
365
  @gr.on(triggers=[image.upload, prompt.submit, run_button.click], inputs=config, outputs=result, show_progress="minimal")
366
  def auto_process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
367
+ return process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
 
 
 
368
 
369
+ # AI image processing
370
  @gr.on(triggers=[use_ai_button.click], inputs=[result] + config, outputs=[image, result], show_progress="minimal")
371
  def submit(previous_result, image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
372
+ # First, yield the previous result to update the input image immediately
373
  yield previous_result, gr.update()
374
+ # Then, process the new input image
375
+ new_result = process_image(previous_result, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
376
+ # Finally, yield the new result
377
  yield previous_result, new_result
378
 
379
+ # Turn off buttons when processing
380
  @gr.on(triggers=[image.upload, use_ai_button.click, run_button.click], inputs=None, outputs=[run_button, use_ai_button], show_progress="hidden")
381
  def turn_buttons_off():
382
  return gr.update(visible=False), gr.update(visible=False)
383
 
384
+ # Turn on buttons when processing is complete
385
  @gr.on(triggers=[result.change], inputs=None, outputs=[use_ai_button, run_button], show_progress="hidden")
386
  def turn_buttons_on():
387
  return gr.update(visible=True), gr.update(visible=True)
388
 
 
 
 
 
 
389
  @spaces.GPU(duration=12)
390
  @torch.inference_mode()
391
  def process_image(
 
400
  num_steps,
401
  guidance_scale,
402
  seed,
 
403
  ):
404
+ # torch.cuda.synchronize()
 
405
  preprocess_start = time.time()
406
  print("processing image")
407
 
408
  seed = random.randint(0, MAX_SEED)
409
  generator = torch.cuda.manual_seed(seed)
410
+ preprocessor.load("NormalBae")
411
+ control_image = preprocessor(
412
+ image=image,
413
+ image_resolution=image_resolution,
414
+ detect_resolution=preprocess_resolution,
415
+ )
 
 
 
 
 
 
 
416
  preprocess_time = time.time() - preprocess_start
417
  if style_selection is not None or style_selection != "None":
418
  prompt = "Photo from Pinterest of " + apply_style(style_selection) + " " + prompt + " " + a_prompt
419
  else:
420
+ prompt=str(get_prompt(prompt, a_prompt))
421
+ negative_prompt=str(n_prompt)
422
  print(prompt)
423
  print(f"\n-------------------------Preprocess done in: {preprocess_time:.2f} seconds-------------------------")
424
  start = time.time()
 
432
  image=control_image,
433
  ).images[0]
434
  print(f"\n-------------------------Inference done in: {time.time() - start:.2f} seconds-------------------------")
435
+ # torch.cuda.synchronize()
436
  torch.cuda.empty_cache()
437
+ return results
438
 
439
  if prod:
440
  demo.queue(max_size=20).launch(server_name="localhost", server_port=port)