WonwoongCho committed on
Commit
0e2a4e5
·
1 Parent(s): 310a7f1

add torchvision to requirement

Browse files
Files changed (1) hide show
  1. app.py +19 -19
app.py CHANGED
@@ -11,29 +11,37 @@ from src.attention_processor import FluxBlendedAttnProcessor2_0
11
  from src.utils_sample import set_seed, resize_and_add_margin
12
  import os
13
 
14
- pipe = None
 
 
 
 
 
 
 
 
 
 
15
 
16
  @spaces.GPU
17
- def process_image_and_text(image, scale, seed, text, pipe):
18
  set_seed(seed)
19
- print("execution_device 1", pipe._execution_device)
20
  blended_attn_procs = {}
21
  for name, _ in pipe.transformer.attn_processors.items():
22
  if "single" in name:
23
  blended_attn_procs[name] = FluxBlendedAttnProcessor2_0(3072, ba_scale=scale, num_ref=1)
24
  else:
25
  blended_attn_procs[name] = pipe.transformer.attn_processors[name]
26
-
27
  pipe.transformer.set_attn_processor(blended_attn_procs)
28
- pipe = pipe.to(dtype)
29
- pipe = pipe.to("cuda")
30
- print("execution_device 2", pipe._execution_device)
31
 
32
  model_path = hf_hub_download(
33
  repo_id="WonwoongCho/IT-Blender",
34
  filename="FLUX/it-blender.bin",
35
  token=token
36
  )
 
37
  pretrained_blended_attn_weights = torch.load(model_path, map_location=pipe._execution_device)
38
 
39
  key_changed_blended_attn_weights = {}
@@ -45,6 +53,8 @@ def process_image_and_text(image, scale, seed, text, pipe):
45
 
46
  missing_keys, unexpected_keys = pipe.transformer.load_state_dict(key_changed_blended_attn_weights, strict=False)
47
 
 
 
48
  # image = Image.open(img_path).convert('RGB')
49
  image = resize_and_add_margin(image, target_size=512)
50
 
@@ -137,16 +147,6 @@ header = """
137
 
138
  def create_app():
139
 
140
- dtype = torch.bfloat16
141
- token = os.environ.get("HF_TOKEN")
142
-
143
- pipe = FluxPipeline.from_pretrained(
144
- "black-forest-labs/FLUX.1-dev",
145
- torch_dtype=dtype,
146
- use_auth_token=token
147
- )
148
- pipe = pipe.to("cuda")
149
-
150
  with gr.Blocks() as app:
151
  gr.Markdown(header, elem_id="header")
152
  with gr.Row(equal_height=False):
@@ -175,10 +175,10 @@ def create_app():
175
  inputs=[original_image, scale, seed, text],
176
  label="Examples",
177
  )
178
-
179
  submit_btn.click(
180
  fn=process_image_and_text,
181
- inputs=[original_image, scale, seed, text, pipe],
182
  outputs=output_image,
183
  )
184
 
 
11
  from src.utils_sample import set_seed, resize_and_add_margin
12
  import os
13
 
14
+
15
+ dtype = torch.bfloat16
16
+ token = os.environ.get("HF_TOKEN")
17
+
18
+ pipe = FluxPipeline.from_pretrained(
19
+ "black-forest-labs/FLUX.1-dev",
20
+ torch_dtype=dtype,
21
+ use_auth_token=token
22
+ )
23
+ pipe = pipe.to("cuda")
24
+
25
 
26
  @spaces.GPU
27
+ def process_image_and_text(image, scale, seed, text):
28
  set_seed(seed)
 
29
  blended_attn_procs = {}
30
  for name, _ in pipe.transformer.attn_processors.items():
31
  if "single" in name:
32
  blended_attn_procs[name] = FluxBlendedAttnProcessor2_0(3072, ba_scale=scale, num_ref=1)
33
  else:
34
  blended_attn_procs[name] = pipe.transformer.attn_processors[name]
35
+
36
  pipe.transformer.set_attn_processor(blended_attn_procs)
37
+ pipe.to(dtype)
 
 
38
 
39
  model_path = hf_hub_download(
40
  repo_id="WonwoongCho/IT-Blender",
41
  filename="FLUX/it-blender.bin",
42
  token=token
43
  )
44
+
45
  pretrained_blended_attn_weights = torch.load(model_path, map_location=pipe._execution_device)
46
 
47
  key_changed_blended_attn_weights = {}
 
53
 
54
  missing_keys, unexpected_keys = pipe.transformer.load_state_dict(key_changed_blended_attn_weights, strict=False)
55
 
56
+ pipe = pipe.to("cuda")
57
+
58
  # image = Image.open(img_path).convert('RGB')
59
  image = resize_and_add_margin(image, target_size=512)
60
 
 
147
 
148
  def create_app():
149
 
 
 
 
 
 
 
 
 
 
 
150
  with gr.Blocks() as app:
151
  gr.Markdown(header, elem_id="header")
152
  with gr.Row(equal_height=False):
 
175
  inputs=[original_image, scale, seed, text],
176
  label="Examples",
177
  )
178
+
179
  submit_btn.click(
180
  fn=process_image_and_text,
181
+ inputs=[original_image, scale, seed, text],
182
  outputs=output_image,
183
  )
184