WonwoongCho commited on
Commit
d5ee1b3
·
1 Parent(s): 36cd713

debugging app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -14
app.py CHANGED
@@ -11,6 +11,15 @@ from src.attention_processor import FluxBlendedAttnProcessor2_0
11
  from src.utils_sample import set_seed, resize_and_add_margin
12
  import os
13
 
 
 
 
 
 
 
 
 
 
14
 
15
  @spaces.GPU
16
  def process_image_and_text(image, text, seed, pipeline):
@@ -105,18 +114,7 @@ header = """
105
  """
106
 
107
 
108
- def create_app(
109
- device: str = "cuda",
110
- ):
111
- dtype = torch.bfloat16
112
- token = os.environ.get("HF_TOKEN")
113
-
114
- pipe = FluxPipeline.from_pretrained(
115
- "black-forest-labs/FLUX.1-dev",
116
- torch_dtype=dtype,
117
- token=token
118
- )
119
- pipe = pipe.to(device)
120
 
121
  with gr.Blocks() as app:
122
  gr.Markdown(header, elem_id="header")
@@ -167,7 +165,7 @@ def create_app(
167
  key_changed_blended_attn_weights[changed_key] = value.to(dtype)
168
 
169
  missing_keys, unexpected_keys = pipe.transformer.load_state_dict(key_changed_blended_attn_weights, strict=False)
170
- pipe = pipe.to(device)
171
 
172
 
173
  with gr.Row():
@@ -188,5 +186,5 @@ def create_app(
188
 
189
 
190
  if __name__ == "__main__":
191
- demo = create_app(device="cuda")
192
  demo.launch(debug=True, ssr_mode=False)
 
11
  from src.utils_sample import set_seed, resize_and_add_margin
12
  import os
13
 
14
+ dtype = torch.bfloat16
15
+ token = os.environ.get("HF_TOKEN")
16
+
17
+ pipe = FluxPipeline.from_pretrained(
18
+ "black-forest-labs/FLUX.1-dev",
19
+ torch_dtype=dtype,
20
+ token=token
21
+ )
22
+ pipe = pipe.to("cuda")
23
 
24
  @spaces.GPU
25
  def process_image_and_text(image, text, seed, pipeline):
 
114
  """
115
 
116
 
117
+ def create_app():
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  with gr.Blocks() as app:
120
  gr.Markdown(header, elem_id="header")
 
165
  key_changed_blended_attn_weights[changed_key] = value.to(dtype)
166
 
167
  missing_keys, unexpected_keys = pipe.transformer.load_state_dict(key_changed_blended_attn_weights, strict=False)
168
+ pipe = pipe.to("cuda")
169
 
170
 
171
  with gr.Row():
 
186
 
187
 
188
  if __name__ == "__main__":
189
+ demo = create_app()
190
  demo.launch(debug=True, ssr_mode=False)