WonwoongCho commited on
Commit
72db70e
·
1 Parent(s): 71c9c42

debugging app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -42
app.py CHANGED
@@ -11,47 +11,9 @@ from src.attention_processor import FluxBlendedAttnProcessor2_0
11
  from src.utils_sample import set_seed, resize_and_add_margin
12
  import os
13
 
14
- pipe = None
15
- dtype = torch.bfloat16
16
- token = os.environ.get("HF_TOKEN")
17
-
18
- pipe = FluxPipeline.from_pretrained(
19
- "black-forest-labs/FLUX.1-dev",
20
- torch_dtype=dtype,
21
- use_auth_token=token
22
- )
23
- pipe = pipe.to("cuda")
24
 
25
  @spaces.GPU
26
- def process_image_and_text(image, scale, seed, text):
27
- set_seed(seed)
28
- blended_attn_procs = {}
29
- for name, _ in pipe.transformer.attn_processors.items():
30
- if "single" in name:
31
- blended_attn_procs[name] = FluxBlendedAttnProcessor2_0(3072, ba_scale=scale, num_ref=1)
32
- else:
33
- blended_attn_procs[name] = pipe.transformer.attn_processors[name]
34
-
35
- pipe.transformer.set_attn_processor(blended_attn_procs)
36
- pipe.to(dtype)
37
-
38
- model_path = hf_hub_download(
39
- repo_id="WonwoongCho/IT-Blender",
40
- filename="FLUX/it-blender.bin",
41
- token=token
42
- )
43
-
44
- pretrained_blended_attn_weights = torch.load(model_path, map_location=pipe._execution_device)
45
-
46
- key_changed_blended_attn_weights = {}
47
- for key, value in pretrained_blended_attn_weights.items():
48
- block_idx = int(key.split(".")[0]) - 21
49
- k_or_v = key.split("_")[2]
50
- changed_key = f'single_transformer_blocks.{block_idx}.attn.processor.blended_attention_{k_or_v}_proj.weight'
51
- key_changed_blended_attn_weights[changed_key] = value.to(dtype)
52
-
53
- missing_keys, unexpected_keys = pipe.transformer.load_state_dict(key_changed_blended_attn_weights, strict=False)
54
- pipe = pipe.to("cuda")
55
 
56
  # image = Image.open(img_path).convert('RGB')
57
  image = resize_and_add_margin(image, target_size=512)
@@ -146,7 +108,16 @@ header = """
146
  def create_app(
147
  device: str = "cuda",
148
  ):
149
-
 
 
 
 
 
 
 
 
 
150
  with gr.Blocks() as app:
151
  gr.Markdown(header, elem_id="header")
152
  with gr.Row(equal_height=False):
@@ -169,16 +140,47 @@ def create_app(
169
  with gr.Column(variant="panel", elem_classes="outputPanel"):
170
  output_image = gr.Image(type="pil", elem_id="output")
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  with gr.Row():
173
  examples = gr.Examples(
174
  examples=get_samples(),
175
- inputs=[original_image, scale, seed, text],
176
  label="Examples",
177
  )
178
 
179
  submit_btn.click(
180
  fn=process_image_and_text,
181
- inputs=[original_image, scale, seed, text],
182
  outputs=output_image,
183
  )
184
 
 
11
  from src.utils_sample import set_seed, resize_and_add_margin
12
  import os
13
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  @spaces.GPU
16
+ def process_image_and_text(image, text, pipe):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # image = Image.open(img_path).convert('RGB')
19
  image = resize_and_add_margin(image, target_size=512)
 
108
  def create_app(
109
  device: str = "cuda",
110
  ):
111
+ dtype = torch.bfloat16
112
+ token = os.environ.get("HF_TOKEN")
113
+
114
+ pipe = FluxPipeline.from_pretrained(
115
+ "black-forest-labs/FLUX.1-dev",
116
+ torch_dtype=dtype,
117
+ token=token
118
+ )
119
+ pipe = pipe.to(device)
120
+
121
  with gr.Blocks() as app:
122
  gr.Markdown(header, elem_id="header")
123
  with gr.Row(equal_height=False):
 
140
  with gr.Column(variant="panel", elem_classes="outputPanel"):
141
  output_image = gr.Image(type="pil", elem_id="output")
142
 
143
+
144
+ set_seed(seed)
145
+ blended_attn_procs = {}
146
+ for name, _ in pipe.transformer.attn_processors.items():
147
+ if "single" in name:
148
+ blended_attn_procs[name] = FluxBlendedAttnProcessor2_0(3072, ba_scale=scale, num_ref=1)
149
+ else:
150
+ blended_attn_procs[name] = pipe.transformer.attn_processors[name]
151
+
152
+ pipe.transformer.set_attn_processor(blended_attn_procs)
153
+ pipe.to(dtype)
154
+
155
+ model_path = hf_hub_download(
156
+ repo_id="WonwoongCho/IT-Blender",
157
+ filename="FLUX/it-blender.bin",
158
+ token=token
159
+ )
160
+
161
+ pretrained_blended_attn_weights = torch.load(model_path, map_location=pipe._execution_device)
162
+
163
+ key_changed_blended_attn_weights = {}
164
+ for key, value in pretrained_blended_attn_weights.items():
165
+ block_idx = int(key.split(".")[0]) - 21
166
+ k_or_v = key.split("_")[2]
167
+ changed_key = f'single_transformer_blocks.{block_idx}.attn.processor.blended_attention_{k_or_v}_proj.weight'
168
+ key_changed_blended_attn_weights[changed_key] = value.to(dtype)
169
+
170
+ missing_keys, unexpected_keys = pipe.transformer.load_state_dict(key_changed_blended_attn_weights, strict=False)
171
+ pipe = pipe.to(device)
172
+
173
+
174
  with gr.Row():
175
  examples = gr.Examples(
176
  examples=get_samples(),
177
+ inputs=[original_image, text],
178
  label="Examples",
179
  )
180
 
181
  submit_btn.click(
182
  fn=process_image_and_text,
183
+ inputs=[original_image, text, pipe],
184
  outputs=output_image,
185
  )
186