HAL1993 committed on
Commit
e2cb293
·
verified ·
1 Parent(s): 19054cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -22
app.py CHANGED
@@ -5,14 +5,15 @@ from copy import deepcopy
5
  import gradio as gr
6
  import numpy as np
7
  import PIL
 
8
  import spaces
9
  import torch
10
  import yaml
11
  from huggingface_hub import hf_hub_download
12
- from PIL import Image
13
  from safetensors.torch import load_file
14
  from torchvision.transforms import ToPILImage, ToTensor
15
  from transformers import AutoModelForImageSegmentation
 
16
  from utils import extract_object, get_model_from_config, resize_and_center_crop
17
 
18
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
@@ -31,20 +32,52 @@ ASPECT_RATIOS = {
31
  str(1920 / 512): (1920, 512),
32
  }
33
 
 
34
  MODEL_PATH = hf_hub_download("jasperai/LBM_relighting", "model.safetensors", token=huggingface_token)
35
  CONFIG_PATH = hf_hub_download("jasperai/LBM_relighting", "config.yaml", token=huggingface_token)
36
-
37
  with open(CONFIG_PATH, "r") as f:
38
  config = yaml.safe_load(f)
39
  model = get_model_from_config(**config)
40
  sd = load_file(MODEL_PATH)
41
  model.load_state_dict(sd, strict=True)
42
  model.to("cuda").to(torch.bfloat16)
 
 
43
  birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True).cuda()
44
- image_size = (1024, 1024)
 
 
 
 
 
 
 
 
 
45
 
46
  @spaces.GPU
47
- def evaluate(fg_image: PIL.Image.Image, bg_image: PIL.Image.Image, num_sampling_steps: int = 4):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  ori_h_bg, ori_w_bg = fg_image.size
49
  ar_bg = ori_h_bg / ori_w_bg
50
  closest_ar_bg = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - ar_bg))
@@ -82,41 +115,30 @@ button[aria-label="Download"] {
82
  margin: 0 !important;
83
  padding: 6px !important;
84
  }
85
- button[aria-label="Share"] {
86
- display: none;
87
- }
88
- button[aria-label="Copy link"] {
89
- display: none;
90
- }
91
- button[aria-label="Open in new tab"] {
92
  display: none;
93
  }
94
  """, title="LBM Object Relighting") as demo:
95
- gr.Markdown("Rindriçim i Objektit me Sfondin e Zgjedhur")
 
96
 
97
  with gr.Row():
98
  with gr.Column():
99
  with gr.Row():
100
  fg_image = gr.Image(type="pil", label="Imazhi Kryesor", image_mode="RGB", height=360)
101
  bg_image = gr.Image(type="pil", label="Sfondi i Ri", image_mode="RGB", height=360)
 
102
 
103
  with gr.Row():
104
  submit_button = gr.Button("Rindriço", variant="primary")
105
- with gr.Row():
106
- num_inference_steps = gr.Slider(minimum=1, maximum=4, value=4, step=1, visible=False)
107
 
108
- bg_gallery = gr.Gallery(object_fit="contain", visible=False)
109
 
110
  with gr.Column():
111
  output_slider = gr.ImageSlider(label="Para / Pas", type="numpy")
112
- output_slider.upload(fn=evaluate, inputs=[fg_image, bg_image, num_inference_steps], outputs=[output_slider])
113
-
114
- submit_button.click(evaluate, inputs=[fg_image, bg_image, num_inference_steps], outputs=[output_slider], show_progress="full", show_api=False)
115
-
116
- def bg_gallery_selected(gal, evt: gr.SelectData):
117
- return gal[evt.index][0]
118
 
119
- bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=bg_image)
120
 
121
  if __name__ == "__main__":
122
  demo.queue().launch(show_api=False)
 
5
  import gradio as gr
6
  import numpy as np
7
  import PIL
8
+ from PIL import Image, ImageFilter
9
  import spaces
10
  import torch
11
  import yaml
12
  from huggingface_hub import hf_hub_download
 
13
  from safetensors.torch import load_file
14
  from torchvision.transforms import ToPILImage, ToTensor
15
  from transformers import AutoModelForImageSegmentation
16
+ from diffusers import StableDiffusionPipeline
17
  from utils import extract_object, get_model_from_config, resize_and_center_crop
18
 
19
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
 
32
  str(1920 / 512): (1920, 512),
33
  }
34
 
35
+ # Load relighting model
36
  MODEL_PATH = hf_hub_download("jasperai/LBM_relighting", "model.safetensors", token=huggingface_token)
37
  CONFIG_PATH = hf_hub_download("jasperai/LBM_relighting", "config.yaml", token=huggingface_token)
 
38
  with open(CONFIG_PATH, "r") as f:
39
  config = yaml.safe_load(f)
40
  model = get_model_from_config(**config)
41
  sd = load_file(MODEL_PATH)
42
  model.load_state_dict(sd, strict=True)
43
  model.to("cuda").to(torch.bfloat16)
44
+
45
+ # Load segmentation model
46
  birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True).cuda()
47
+
48
+ # Load Stable Diffusion pipeline for background generation
49
+ sd_pipe = StableDiffusionPipeline.from_pretrained(
50
+ "runwayml/stable-diffusion-v1-5",
51
+ torch_dtype=torch.float16,
52
+ revision="fp16",
53
+ use_auth_token=huggingface_token,
54
+ )
55
+ sd_pipe.to("cuda")
56
+ sd_pipe.enable_attention_slicing()
57
 
58
@spaces.GPU
def generate_background_image(bg_prompt: str, blur_radius: float = 5):
    """Generate a 1024x1024 background image from a text prompt.

    Runs the module-level Stable Diffusion pipeline (``sd_pipe``) and applies
    a Gaussian blur so the generated background stays soft behind the relit
    foreground object.

    Args:
        bg_prompt: Text description of the desired background. If empty or
            whitespace-only, nothing is generated.
        blur_radius: Gaussian blur radius applied to the generated image.
            Defaults to 5, matching the previously hard-coded value.

    Returns:
        A blurred ``PIL.Image.Image``, or ``None`` when the prompt is blank.
    """
    # Guard clause: a missing or whitespace-only prompt yields no image.
    if not bg_prompt or not bg_prompt.strip():
        return None
    # inference_mode disables autograd bookkeeping for a faster, lighter pass.
    with torch.inference_mode():
        bg_img = sd_pipe(prompt=bg_prompt, height=1024, width=1024, num_inference_steps=20).images[0]
    # Soften the background so it doesn't visually compete with the subject.
    bg_img = bg_img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    return bg_img
67
+
68
+ @spaces.GPU
69
+ def evaluate(
70
+ fg_image: PIL.Image.Image,
71
+ bg_image: PIL.Image.Image,
72
+ bg_prompt: str,
73
+ num_sampling_steps: int = 4,
74
+ ):
75
+ # Generate background if prompt is given
76
+ if bg_prompt and bg_prompt.strip() != "":
77
+ generated_bg = generate_background_image(bg_prompt)
78
+ if generated_bg is not None:
79
+ bg_image = generated_bg
80
+
81
  ori_h_bg, ori_w_bg = fg_image.size
82
  ar_bg = ori_h_bg / ori_w_bg
83
  closest_ar_bg = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - ar_bg))
 
115
  margin: 0 !important;
116
  padding: 6px !important;
117
  }
118
+ button[aria-label="Share"], button[aria-label="Copy link"], button[aria-label="Open in new tab"] {
 
 
 
 
 
 
119
  display: none;
120
  }
121
  """, title="LBM Object Relighting") as demo:
122
+
123
+ gr.Markdown("# Rindriçim i Objektit me Sfondin e Zgjedhur")
124
 
125
  with gr.Row():
126
  with gr.Column():
127
  with gr.Row():
128
  fg_image = gr.Image(type="pil", label="Imazhi Kryesor", image_mode="RGB", height=360)
129
  bg_image = gr.Image(type="pil", label="Sfondi i Ri", image_mode="RGB", height=360)
130
+ bg_prompt = gr.Textbox(label="Sfondi (p.sh. 'në Milano')", placeholder="Shkruani një përshkrim për sfondin", lines=1)
131
 
132
  with gr.Row():
133
  submit_button = gr.Button("Rindriço", variant="primary")
 
 
134
 
135
+ num_inference_steps = gr.Slider(minimum=1, maximum=4, value=4, step=1, visible=False)
136
 
137
  with gr.Column():
138
  output_slider = gr.ImageSlider(label="Para / Pas", type="numpy")
139
+ output_slider.upload(fn=evaluate, inputs=[fg_image, bg_image, bg_prompt, num_inference_steps], outputs=[output_slider])
 
 
 
 
 
140
 
141
+ submit_button.click(evaluate, inputs=[fg_image, bg_image, bg_prompt, num_inference_steps], outputs=[output_slider], show_progress="full", show_api=False)
142
 
143
  if __name__ == "__main__":
144
  demo.queue().launch(show_api=False)