rokmr committed on
Commit
894f265
·
verified ·
1 Parent(s): 0b524ac

Updating with lazy loading

Browse files
Files changed (1) hide show
  1. app.py +38 -24
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import spaces
2
  import torch
3
  from diffusers import Flux2Pipeline
4
  from huggingface_hub import get_token
@@ -10,11 +10,33 @@ import os
10
 
11
  # Configuration
12
  repo_id = "diffusers/FLUX.2-dev-bnb-4bit"
13
- device = "cuda" if torch.cuda.is_available() else "cpu"
14
- torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
15
 
16
- print(f"Using device: {device}")
17
- print(f"Using dtype: {torch_dtype}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def remote_text_encoder(prompts):
20
  """Encode prompts using remote text encoder API."""
@@ -34,26 +56,12 @@ def remote_text_encoder(prompts):
34
  )
35
  response.raise_for_status()
36
  prompt_embeds = torch.load(io.BytesIO(response.content))
 
 
37
  return prompt_embeds.to(device)
38
  except Exception as e:
39
  raise Exception(f"Failed to encode prompt: {str(e)}")
40
 
41
- # Load the pipeline
42
- print("Loading Flux2 pipeline...")
43
- try:
44
- pipe = Flux2Pipeline.from_pretrained(
45
- repo_id,
46
- text_encoder=None,
47
- torch_dtype=torch_dtype,
48
- device_map="cuda"
49
- )
50
- if not torch.cuda.is_available():
51
- pipe = pipe.to(device)
52
- print("Pipeline loaded successfully!")
53
- except Exception as e:
54
- print(f"Error loading pipeline: {e}")
55
- raise
56
-
57
  def get_duration(num_inference_steps: int, input_image: Image.Image = None):
58
  """Calculate dynamic GPU duration based on inference steps and input image."""
59
  num_images = 0 if input_image is None else 1
@@ -82,9 +90,14 @@ def generate_image(
82
  if not prompt or prompt.strip() == "":
83
  raise gr.Error("Please enter a prompt!")
84
 
85
- progress(0, desc="Encoding prompt...")
86
 
87
  try:
 
 
 
 
 
88
  # Get prompt embeddings from remote encoder
89
  prompt_embeds = remote_text_encoder(prompt)
90
 
@@ -113,7 +126,7 @@ def generate_image(
113
 
114
  # Generate image
115
  with torch.inference_mode():
116
- image = pipe(**pipe_kwargs).images[0]
117
 
118
  progress(1.0, desc="Done!")
119
 
@@ -129,6 +142,7 @@ def generate_image(
129
  # Create Gradio interface
130
  with gr.Blocks(
131
  title="Flux2 Image Generator",
 
132
  ) as demo:
133
  gr.Markdown(
134
  """
@@ -191,7 +205,6 @@ with gr.Blocks(
191
  "🚀 Generate Image",
192
  variant="primary",
193
  size="lg",
194
- elem_classes="generate-btn"
195
  )
196
 
197
  gr.Markdown(
@@ -267,4 +280,5 @@ with gr.Blocks(
267
  )
268
 
269
  if __name__ == "__main__":
 
270
  demo.queue(max_size=20).launch()
 
1
+ import spaces # Import spaces FIRST, before any CUDA-related packages
2
  import torch
3
  from diffusers import Flux2Pipeline
4
  from huggingface_hub import get_token
 
10
 
11
  # Configuration
12
  repo_id = "diffusers/FLUX.2-dev-bnb-4bit"
13
+ torch_dtype = torch.bfloat16
 
14
 
15
+ print("Starting Flux2 Image Generator...")
16
+
17
+ # Global variable to hold the pipeline
18
+ pipe = None
19
+
20
+ def load_pipeline():
21
+ """Lazy load the pipeline when needed."""
22
+ global pipe
23
+ if pipe is None:
24
+ print("Loading Flux2 pipeline...")
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+ print(f"Using device: {device}")
27
+
28
+ try:
29
+ pipe = Flux2Pipeline.from_pretrained(
30
+ repo_id,
31
+ text_encoder=None,
32
+ torch_dtype=torch_dtype,
33
+ device_map="auto"
34
+ )
35
+ print("Pipeline loaded successfully!")
36
+ except Exception as e:
37
+ print(f"Error loading pipeline: {e}")
38
+ raise
39
+ return pipe
40
 
41
  def remote_text_encoder(prompts):
42
  """Encode prompts using remote text encoder API."""
 
56
  )
57
  response.raise_for_status()
58
  prompt_embeds = torch.load(io.BytesIO(response.content))
59
+
60
+ device = "cuda" if torch.cuda.is_available() else "cpu"
61
  return prompt_embeds.to(device)
62
  except Exception as e:
63
  raise Exception(f"Failed to encode prompt: {str(e)}")
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def get_duration(num_inference_steps: int, input_image: Image.Image = None):
66
  """Calculate dynamic GPU duration based on inference steps and input image."""
67
  num_images = 0 if input_image is None else 1
 
90
  if not prompt or prompt.strip() == "":
91
  raise gr.Error("Please enter a prompt!")
92
 
93
+ progress(0, desc="Loading model...")
94
 
95
  try:
96
+ # Load pipeline (lazy loading)
97
+ pipeline = load_pipeline()
98
+
99
+ progress(0.1, desc="Encoding prompt...")
100
+
101
  # Get prompt embeddings from remote encoder
102
  prompt_embeds = remote_text_encoder(prompt)
103
 
 
126
 
127
  # Generate image
128
  with torch.inference_mode():
129
+ image = pipeline(**pipe_kwargs).images[0]
130
 
131
  progress(1.0, desc="Done!")
132
 
 
142
  # Create Gradio interface
143
  with gr.Blocks(
144
  title="Flux2 Image Generator",
145
+ theme=gr.themes.Soft(),
146
  ) as demo:
147
  gr.Markdown(
148
  """
 
205
  "🚀 Generate Image",
206
  variant="primary",
207
  size="lg",
 
208
  )
209
 
210
  gr.Markdown(
 
280
  )
281
 
282
  if __name__ == "__main__":
283
+ print("Launching Gradio interface...")
284
  demo.queue(max_size=20).launch()