HAL1993 commited on
Commit
130940f
·
verified ·
1 Parent(s): 54bd72e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -42
app.py CHANGED
@@ -5,15 +5,14 @@ from copy import deepcopy
5
  import gradio as gr
6
  import numpy as np
7
  import PIL
8
- from PIL import Image, ImageFilter
9
  import spaces
10
  import torch
11
  import yaml
12
  from huggingface_hub import hf_hub_download
 
13
  from safetensors.torch import load_file
14
  from torchvision.transforms import ToPILImage, ToTensor
15
  from transformers import AutoModelForImageSegmentation
16
- from diffusers import StableDiffusionPipeline
17
  from utils import extract_object, get_model_from_config, resize_and_center_crop
18
 
19
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
@@ -32,51 +31,20 @@ ASPECT_RATIOS = {
32
  str(1920 / 512): (1920, 512),
33
  }
34
 
35
- # Load relighting model
36
  MODEL_PATH = hf_hub_download("jasperai/LBM_relighting", "model.safetensors", token=huggingface_token)
37
  CONFIG_PATH = hf_hub_download("jasperai/LBM_relighting", "config.yaml", token=huggingface_token)
 
38
  with open(CONFIG_PATH, "r") as f:
39
  config = yaml.safe_load(f)
40
  model = get_model_from_config(**config)
41
  sd = load_file(MODEL_PATH)
42
  model.load_state_dict(sd, strict=True)
43
  model.to("cuda").to(torch.bfloat16)
44
-
45
- # Load segmentation model
46
  birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True).cuda()
47
-
48
- # Load Stable Diffusion pipeline for background generation
49
- sd_pipe = StableDiffusionPipeline.from_pretrained(
50
- "runwayml/stable-diffusion-v1-5",
51
- torch_dtype=torch.float16,
52
- use_auth_token=huggingface_token,
53
- )
54
- sd_pipe.to("cuda")
55
- sd_pipe.enable_attention_slicing()
56
 
57
  @spaces.GPU
58
- def generate_background_image(bg_prompt: str):
59
- if not bg_prompt or bg_prompt.strip() == "":
60
- return None
61
- with torch.inference_mode():
62
- bg_img = sd_pipe(prompt=bg_prompt, height=1024, width=1024, num_inference_steps=20).images[0]
63
- # Optional blur radius — tweak as you like or expose as a parameter
64
- bg_img = bg_img.filter(ImageFilter.GaussianBlur(radius=5))
65
- return bg_img
66
-
67
- @spaces.GPU
68
- def evaluate(
69
- fg_image: PIL.Image.Image,
70
- bg_image: PIL.Image.Image,
71
- bg_prompt: str,
72
- num_sampling_steps: int = 4,
73
- ):
74
- # Generate background if prompt is given
75
- if bg_prompt and bg_prompt.strip() != "":
76
- generated_bg = generate_background_image(bg_prompt)
77
- if generated_bg is not None:
78
- bg_image = generated_bg
79
-
80
  ori_h_bg, ori_w_bg = fg_image.size
81
  ar_bg = ori_h_bg / ori_w_bg
82
  closest_ar_bg = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - ar_bg))
@@ -114,11 +82,16 @@ button[aria-label="Download"] {
114
  margin: 0 !important;
115
  padding: 6px !important;
116
  }
117
- button[aria-label="Share"], button[aria-label="Copy link"], button[aria-label="Open in new tab"] {
 
 
 
 
 
 
118
  display: none;
119
  }
120
  """, title="LBM Object Relighting") as demo:
121
-
122
  gr.Markdown("# Rindriçim i Objektit me Sfondin e Zgjedhur")
123
 
124
  with gr.Row():
@@ -126,18 +99,24 @@ button[aria-label="Share"], button[aria-label="Copy link"], button[aria-label="O
126
  with gr.Row():
127
  fg_image = gr.Image(type="pil", label="Imazhi Kryesor", image_mode="RGB", height=360)
128
  bg_image = gr.Image(type="pil", label="Sfondi i Ri", image_mode="RGB", height=360)
129
- bg_prompt = gr.Textbox(label="Sfondi (p.sh. 'në Milano')", placeholder="Shkruani një përshkrim për sfondin", lines=1)
130
 
131
  with gr.Row():
132
  submit_button = gr.Button("Rindriço", variant="primary")
 
 
133
 
134
- num_inference_steps = gr.Slider(minimum=1, maximum=4, value=4, step=1, visible=False)
135
 
136
  with gr.Column():
137
  output_slider = gr.ImageSlider(label="Para / Pas", type="numpy")
138
- output_slider.upload(fn=evaluate, inputs=[fg_image, bg_image, bg_prompt, num_inference_steps], outputs=[output_slider])
 
 
 
 
 
139
 
140
- submit_button.click(evaluate, inputs=[fg_image, bg_image, bg_prompt, num_inference_steps], outputs=[output_slider], show_progress="full", show_api=False)
141
 
142
  if __name__ == "__main__":
143
  demo.queue().launch(show_api=False)
 
5
  import gradio as gr
6
  import numpy as np
7
  import PIL
 
8
  import spaces
9
  import torch
10
  import yaml
11
  from huggingface_hub import hf_hub_download
12
+ from PIL import Image
13
  from safetensors.torch import load_file
14
  from torchvision.transforms import ToPILImage, ToTensor
15
  from transformers import AutoModelForImageSegmentation
 
16
  from utils import extract_object, get_model_from_config, resize_and_center_crop
17
 
18
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
 
31
  str(1920 / 512): (1920, 512),
32
  }
33
 
 
34
  MODEL_PATH = hf_hub_download("jasperai/LBM_relighting", "model.safetensors", token=huggingface_token)
35
  CONFIG_PATH = hf_hub_download("jasperai/LBM_relighting", "config.yaml", token=huggingface_token)
36
+
37
  with open(CONFIG_PATH, "r") as f:
38
  config = yaml.safe_load(f)
39
  model = get_model_from_config(**config)
40
  sd = load_file(MODEL_PATH)
41
  model.load_state_dict(sd, strict=True)
42
  model.to("cuda").to(torch.bfloat16)
 
 
43
  birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True).cuda()
44
+ image_size = (1024, 1024)
 
 
 
 
 
 
 
 
45
 
46
  @spaces.GPU
47
+ def evaluate(fg_image: PIL.Image.Image, bg_image: PIL.Image.Image, num_sampling_steps: int = 4):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  ori_h_bg, ori_w_bg = fg_image.size
49
  ar_bg = ori_h_bg / ori_w_bg
50
  closest_ar_bg = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - ar_bg))
 
82
  margin: 0 !important;
83
  padding: 6px !important;
84
  }
85
+ button[aria-label="Share"] {
86
+ display: none;
87
+ }
88
+ button[aria-label="Copy link"] {
89
+ display: none;
90
+ }
91
+ button[aria-label="Open in new tab"] {
92
  display: none;
93
  }
94
  """, title="LBM Object Relighting") as demo:
 
95
  gr.Markdown("# Rindriçim i Objektit me Sfondin e Zgjedhur")
96
 
97
  with gr.Row():
 
99
  with gr.Row():
100
  fg_image = gr.Image(type="pil", label="Imazhi Kryesor", image_mode="RGB", height=360)
101
  bg_image = gr.Image(type="pil", label="Sfondi i Ri", image_mode="RGB", height=360)
 
102
 
103
  with gr.Row():
104
  submit_button = gr.Button("Rindriço", variant="primary")
105
+ with gr.Row():
106
+ num_inference_steps = gr.Slider(minimum=1, maximum=4, value=4, step=1, visible=False)
107
 
108
+ bg_gallery = gr.Gallery(object_fit="contain", visible=False)
109
 
110
  with gr.Column():
111
  output_slider = gr.ImageSlider(label="Para / Pas", type="numpy")
112
+ output_slider.upload(fn=evaluate, inputs=[fg_image, bg_image, num_inference_steps], outputs=[output_slider])
113
+
114
+ submit_button.click(evaluate, inputs=[fg_image, bg_image, num_inference_steps], outputs=[output_slider], show_progress="full", show_api=False)
115
+
116
+ def bg_gallery_selected(gal, evt: gr.SelectData):
117
+ return gal[evt.index][0]
118
 
119
+ bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=bg_image)
120
 
121
  if __name__ == "__main__":
122
  demo.queue().launch(show_api=False)