Gemini899 commited on
Commit
c566cc0
·
verified ·
1 Parent(s): efc55ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -45
app.py CHANGED
@@ -6,7 +6,9 @@ import random
6
  import spaces
7
  import torch
8
  from diffusers import Flux2Pipeline, Flux2Transformer2DModel
 
9
  from PIL import Image
 
10
 
11
  dtype = torch.bfloat16
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -14,19 +16,45 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
14
  MAX_SEED = np.iinfo(np.int32).max
15
  MAX_IMAGE_SIZE = 1024
16
 
17
- # Load the full pipeline WITH text encoder
18
- print("Loading FLUX.2 pipeline with text encoder...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  repo_id = "black-forest-labs/FLUX.2-dev"
20
 
21
- pipe = Flux2Pipeline.from_pretrained(
22
  repo_id,
 
23
  torch_dtype=torch.bfloat16
24
  )
25
 
26
- # Enable memory optimizations for ZeroGPU
27
- pipe.enable_model_cpu_offload()
 
 
 
 
 
28
 
29
- print("Pipeline loaded successfully!")
 
30
 
31
  def update_dimensions_from_image(image_list):
32
  """Update width/height sliders based on uploaded image aspect ratio."""
@@ -53,25 +81,20 @@ def update_dimensions_from_image(image_list):
53
 
54
  return new_width, new_height
55
 
56
- @spaces.GPU(duration=60)
57
- def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):
58
-
59
- if randomize_seed:
60
- seed = random.randint(0, MAX_SEED)
61
-
62
- image_list = None
63
- if input_images is not None and len(input_images) > 0:
64
- image_list = []
65
- for item in input_images:
66
- image_list.append(item[0])
67
 
68
- generator = torch.Generator(device="cuda").manual_seed(seed)
 
 
69
 
70
- progress(0.1, desc="Encoding prompt & generating image...")
71
 
72
- # Build pipeline arguments
73
  pipe_kwargs = {
74
- "prompt": prompt,
 
75
  "num_inference_steps": num_inference_steps,
76
  "guidance_scale": guidance_scale,
77
  "generator": generator,
@@ -79,32 +102,59 @@ def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024,
79
  "height": height,
80
  }
81
 
82
- # Add images if provided
83
- if image_list is not None:
84
- pipe_kwargs["image"] = image_list
85
-
86
- # Run the pipeline - text encoding happens automatically inside
87
  image = pipe(**pipe_kwargs).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  return image, seed
90
 
91
  examples = [
92
  ["Create a vase on a table in living room, the color of the vase is a gradient of color, starting with #02eb3c color and finishing with #edfa3c. The flowers inside the vase have the color #ff0088"],
93
- ["Photorealistic infographic showing the complete Berlin TV Tower (Fernsehturm) from ground base to antenna tip, full vertical view with entire structure visible including concrete shaft, metallic sphere, and antenna spire."],
94
  ["Soaking wet capybara taking shelter under a banana leaf in the rainy jungle, close up photo"],
95
- ["A kawaii die-cut sticker of a chubby orange cat, featuring big sparkly eyes and a happy smile with paws raised in greeting and a heart-shaped pink nose."],
96
  ]
97
 
98
  examples_images = [
99
  ["The person from image 1 is petting the cat from image 2, the bird from image 3 is next to them", ["woman1.webp", "cat_window.webp", "bird.webp"]]
100
  ]
101
 
102
- css = """
103
  #col-container {
104
  margin: 0 auto;
105
  max-width: 1200px;
106
  }
107
- .gallery-container img {
108
  object-fit: contain;
109
  }
110
  """
@@ -112,7 +162,7 @@ css = """
112
  with gr.Blocks() as demo:
113
 
114
  with gr.Column(elem_id="col-container"):
115
- gr.Markdown("""# FLUX.2 [dev]
116
  FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and combining images based on text instructions model [[model](https://huggingface.co/black-forest-labs/FLUX.2-dev)], [[blog](https://bfl.ai/blog/flux-2)]
117
  """)
118
  with gr.Row():
@@ -185,19 +235,23 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
185
  with gr.Column():
186
  result = gr.Image(label="Result", show_label=False)
187
 
188
- gr.Examples(
189
- examples=examples,
190
- fn=infer,
191
- inputs=[prompt],
192
- outputs=[result, seed]
193
- )
194
-
195
- gr.Examples(
196
- examples=examples_images,
197
- fn=infer,
198
- inputs=[prompt, input_images],
199
- outputs=[result, seed]
200
- )
 
 
 
 
201
 
202
  input_images.upload(
203
  fn=update_dimensions_from_image,
@@ -212,4 +266,3 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
212
  outputs=[result, seed]
213
  )
214
 
215
- demo.launch(css=css)
 
6
  import spaces
7
  import torch
8
  from diffusers import Flux2Pipeline, Flux2Transformer2DModel
9
+ import requests
10
  from PIL import Image
11
+ import base64
12
 
13
  dtype = torch.bfloat16
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
16
  MAX_SEED = np.iinfo(np.int32).max
17
  MAX_IMAGE_SIZE = 1024
18
 
19
+ def remote_text_encoder(prompts, max_retries=3):
20
+ from gradio_client import Client
21
+ import time
22
+
23
+ for attempt in range(max_retries):
24
+ try:
25
+ client = Client("multimodalart/mistral-text-encoder")
26
+ result = client.predict(
27
+ prompt=prompts,
28
+ api_name="/encode_text"
29
+ )
30
+ prompt_embeds = torch.load(result[0])
31
+ return prompt_embeds
32
+ except Exception as e:
33
+ print(f"Text encoder attempt {attempt + 1}/{max_retries} failed: {e}")
34
+ if attempt < max_retries - 1:
35
+ time.sleep(2)
36
+ else:
37
+ raise Exception(f"Text encoder failed after {max_retries} attempts: {e}")
38
+
39
+ # Load model
40
  repo_id = "black-forest-labs/FLUX.2-dev"
41
 
42
+ dit = Flux2Transformer2DModel.from_pretrained(
43
  repo_id,
44
+ subfolder="transformer",
45
  torch_dtype=torch.bfloat16
46
  )
47
 
48
+ pipe = Flux2Pipeline.from_pretrained(
49
+ repo_id,
50
+ text_encoder=None,
51
+ transformer=dit,
52
+ torch_dtype=torch.bfloat16
53
+ )
54
+ pipe.to(device)
55
 
56
+ # AOTI blocks temporarily disabled - HuggingFace needs to recompile for new ZeroGPU environment
57
+ # spaces.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/FLUX.2", variant="fa3")
58
 
59
  def update_dimensions_from_image(image_list):
60
  """Update width/height sliders based on uploaded image aspect ratio."""
 
81
 
82
  return new_width, new_height
83
 
84
+ def get_duration(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
85
+ num_images = 0 if image_list is None else len(image_list)
86
+ step_duration = 1 + 0.8 * num_images
87
+ return max(45, num_inference_steps * step_duration + 10)
 
 
 
 
 
 
 
88
 
89
+ @spaces.GPU(duration=get_duration)
90
+ def generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
91
+ prompt_embeds = prompt_embeds.to(device)
92
 
93
+ generator = torch.Generator(device=device).manual_seed(seed)
94
 
 
95
  pipe_kwargs = {
96
+ "prompt_embeds": prompt_embeds,
97
+ "image": image_list,
98
  "num_inference_steps": num_inference_steps,
99
  "guidance_scale": guidance_scale,
100
  "generator": generator,
 
102
  "height": height,
103
  }
104
 
105
+ if progress:
106
+ progress(0, desc="Starting generation...")
107
+
 
 
108
  image = pipe(**pipe_kwargs).images[0]
109
+ return image
110
+
111
+ def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):
112
+
113
+ if randomize_seed:
114
+ seed = random.randint(0, MAX_SEED)
115
+
116
+ image_list = None
117
+ if input_images is not None and len(input_images) > 0:
118
+ image_list = []
119
+ for item in input_images:
120
+ image_list.append(item[0])
121
+
122
+ # Text Encoding
123
+ progress(0.1, desc="Encoding prompt...")
124
+ prompt_embeds = remote_text_encoder(prompt)
125
+
126
+ # Image Generation
127
+ progress(0.3, desc="Waiting for GPU...")
128
+ image = generate_image(
129
+ prompt_embeds,
130
+ image_list,
131
+ width,
132
+ height,
133
+ num_inference_steps,
134
+ guidance_scale,
135
+ seed,
136
+ progress
137
+ )
138
 
139
  return image, seed
140
 
141
  examples = [
142
  ["Create a vase on a table in living room, the color of the vase is a gradient of color, starting with #02eb3c color and finishing with #edfa3c. The flowers inside the vase have the color #ff0088"],
143
+ ["Photorealistic infographic showing the complete Berlin TV Tower (Fernsehturm) from ground base to antenna tip, full vertical view with entire structure visible including concrete shaft, metallic sphere, and antenna spire. Slight upward perspective angle looking up toward the iconic sphere, perfectly centered on clean white background. Left side labels with thin horizontal connector lines: the text '368m' in extra large bold dark grey numerals (#2D3748) positioned at exactly the antenna tip with 'TOTAL HEIGHT' in small caps below. The text '207m' in extra large bold with 'TELECAFÉ' in small caps below, with connector line touching the sphere precisely at the window level. Right side label with horizontal connector line touching the sphere's equator: the text '32m' in extra large bold dark grey numerals with 'SPHERE DIAMETER' in small caps below. Bottom section arranged in three balanced columns: Left - Large text '986' in extra bold dark grey with 'STEPS' in caps below. Center - 'BERLIN TV TOWER' in bold caps with 'FERNSEHTURM' in lighter weight below. Right - 'INAUGURATED' in bold caps with 'OCTOBER 3, 1969' below. All typography in modern sans-serif font (such as Inter or Helvetica), color #2D3748, clean minimal technical diagram style. Horizontal connector lines are thin, precise, and clearly visible, touching the tower structure at exact corresponding measurement points. Professional architectural elevation drawing aesthetic with dynamic low angle perspective creating sense of height and grandeur, poster-ready infographic design with perfect visual hierarchy."],
144
  ["Soaking wet capybara taking shelter under a banana leaf in the rainy jungle, close up photo"],
145
+ ["A kawaii die-cut sticker of a chubby orange cat, featuring big sparkly eyes and a happy smile with paws raised in greeting and a heart-shaped pink nose. The design should have smooth rounded lines with black outlines and soft gradient shading with pink cheeks."],
146
  ]
147
 
148
  examples_images = [
149
  ["The person from image 1 is petting the cat from image 2, the bird from image 3 is next to them", ["woman1.webp", "cat_window.webp", "bird.webp"]]
150
  ]
151
 
152
+ css="""
153
  #col-container {
154
  margin: 0 auto;
155
  max-width: 1200px;
156
  }
157
+ .gallery-container img{
158
  object-fit: contain;
159
  }
160
  """
 
162
  with gr.Blocks() as demo:
163
 
164
  with gr.Column(elem_id="col-container"):
165
+ gr.Markdown(f"""# FLUX.2 [dev]
166
  FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and combining images based on text instructions model [[model](https://huggingface.co/black-forest-labs/FLUX.2-dev)], [[blog](https://bfl.ai/blog/flux-2)]
167
  """)
168
  with gr.Row():
 
235
  with gr.Column():
236
  result = gr.Image(label="Result", show_label=False)
237
 
238
+ gr.Examples(
239
+ examples=examples,
240
+ fn=infer,
241
+ inputs=[prompt],
242
+ outputs=[result, seed],
243
+ cache_examples=True,
244
+ cache_mode="lazy"
245
+ )
246
+
247
+ gr.Examples(
248
+ examples=examples_images,
249
+ fn=infer,
250
+ inputs=[prompt, input_images],
251
+ outputs=[result, seed],
252
+ cache_examples=True,
253
+ cache_mode="lazy"
254
+ )
255
 
256
  input_images.upload(
257
  fn=update_dimensions_from_image,
 
266
  outputs=[result, seed]
267
  )
268