Gemini899 commited on
Commit
0a0ac19
·
verified ·
1 Parent(s): ff37b4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -84
app.py CHANGED
@@ -6,9 +6,7 @@ import random
6
  import spaces
7
  import torch
8
  from diffusers import Flux2Pipeline, Flux2Transformer2DModel
9
- import requests
10
  from PIL import Image
11
- import base64
12
 
13
  dtype = torch.bfloat16
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -16,45 +14,19 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
16
  MAX_SEED = np.iinfo(np.int32).max
17
  MAX_IMAGE_SIZE = 1024
18
 
19
- def remote_text_encoder(prompts, max_retries=3):
20
- from gradio_client import Client
21
- import time
22
-
23
- for attempt in range(max_retries):
24
- try:
25
- client = Client("multimodalart/mistral-text-encoder")
26
- result = client.predict(
27
- prompt=prompts,
28
- api_name="/encode_text"
29
- )
30
- prompt_embeds = torch.load(result[0])
31
- return prompt_embeds
32
- except Exception as e:
33
- print(f"Text encoder attempt {attempt + 1}/{max_retries} failed: {e}")
34
- if attempt < max_retries - 1:
35
- time.sleep(2)
36
- else:
37
- raise Exception(f"Text encoder failed after {max_retries} attempts: {e}")
38
-
39
- # Load model
40
  repo_id = "black-forest-labs/FLUX.2-dev"
41
 
42
- dit = Flux2Transformer2DModel.from_pretrained(
43
- repo_id,
44
- subfolder="transformer",
45
- torch_dtype=torch.bfloat16
46
- )
47
-
48
  pipe = Flux2Pipeline.from_pretrained(
49
  repo_id,
50
- text_encoder=None,
51
- transformer=dit,
52
  torch_dtype=torch.bfloat16
53
  )
54
- pipe.to(device)
55
 
56
- # AOTI blocks temporarily disabled - HuggingFace needs to recompile for new ZeroGPU environment
57
- # spaces.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/FLUX.2", variant="fa3")
 
 
58
 
59
  def update_dimensions_from_image(image_list):
60
  """Update width/height sliders based on uploaded image aspect ratio."""
@@ -81,33 +53,7 @@ def update_dimensions_from_image(image_list):
81
 
82
  return new_width, new_height
83
 
84
- def get_duration(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
85
- num_images = 0 if image_list is None else len(image_list)
86
- step_duration = 1 + 0.8 * num_images
87
- return max(45, num_inference_steps * step_duration + 10)
88
-
89
- @spaces.GPU(duration=get_duration)
90
- def generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
91
- prompt_embeds = prompt_embeds.to(device)
92
-
93
- generator = torch.Generator(device=device).manual_seed(seed)
94
-
95
- pipe_kwargs = {
96
- "prompt_embeds": prompt_embeds,
97
- "image": image_list,
98
- "num_inference_steps": num_inference_steps,
99
- "guidance_scale": guidance_scale,
100
- "generator": generator,
101
- "width": width,
102
- "height": height,
103
- }
104
-
105
- if progress:
106
- progress(0, desc="Starting generation...")
107
-
108
- image = pipe(**pipe_kwargs).images[0]
109
- return image
110
-
111
  def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):
112
 
113
  if randomize_seed:
@@ -119,30 +65,34 @@ def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024,
119
  for item in input_images:
120
  image_list.append(item[0])
121
 
122
- # Text Encoding
123
- progress(0.1, desc="Encoding prompt...")
124
- prompt_embeds = remote_text_encoder(prompt)
125
 
126
- # Image Generation
127
- progress(0.3, desc="Waiting for GPU...")
128
- image = generate_image(
129
- prompt_embeds,
130
- image_list,
131
- width,
132
- height,
133
- num_inference_steps,
134
- guidance_scale,
135
- seed,
136
- progress
137
- )
 
 
 
 
 
 
138
 
139
  return image, seed
140
 
141
  examples = [
142
  ["Create a vase on a table in living room, the color of the vase is a gradient of color, starting with #02eb3c color and finishing with #edfa3c. The flowers inside the vase have the color #ff0088"],
143
- ["Photorealistic infographic showing the complete Berlin TV Tower (Fernsehturm) from ground base to antenna tip, full vertical view with entire structure visible including concrete shaft, metallic sphere, and antenna spire. Slight upward perspective angle looking up toward the iconic sphere, perfectly centered on clean white background. Left side labels with thin horizontal connector lines: the text '368m' in extra large bold dark grey numerals (#2D3748) positioned at exactly the antenna tip with 'TOTAL HEIGHT' in small caps below. The text '207m' in extra large bold with 'TELECAFÉ' in small caps below, with connector line touching the sphere precisely at the window level. Right side label with horizontal connector line touching the sphere's equator: the text '32m' in extra large bold dark grey numerals with 'SPHERE DIAMETER' in small caps below. Bottom section arranged in three balanced columns: Left - Large text '986' in extra bold dark grey with 'STEPS' in caps below. Center - 'BERLIN TV TOWER' in bold caps with 'FERNSEHTURM' in lighter weight below. Right - 'INAUGURATED' in bold caps with 'OCTOBER 3, 1969' below. All typography in modern sans-serif font (such as Inter or Helvetica), color #2D3748, clean minimal technical diagram style. Horizontal connector lines are thin, precise, and clearly visible, touching the tower structure at exact corresponding measurement points. Professional architectural elevation drawing aesthetic with dynamic low angle perspective creating sense of height and grandeur, poster-ready infographic design with perfect visual hierarchy."],
144
  ["Soaking wet capybara taking shelter under a banana leaf in the rainy jungle, close up photo"],
145
- ["A kawaii die-cut sticker of a chubby orange cat, featuring big sparkly eyes and a happy smile with paws raised in greeting and a heart-shaped pink nose. The design should have smooth rounded lines with black outlines and soft gradient shading with pink cheeks."],
146
  ]
147
 
148
  examples_images = [
@@ -159,7 +109,7 @@ css="""
159
  }
160
  """
161
 
162
- with gr.Blocks() as demo:
163
 
164
  with gr.Column(elem_id="col-container"):
165
  gr.Markdown(f"""# FLUX.2 [dev]
@@ -240,8 +190,7 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
240
  fn=infer,
241
  inputs=[prompt],
242
  outputs=[result, seed],
243
- cache_examples=True,
244
- cache_mode="lazy"
245
  )
246
 
247
  gr.Examples(
@@ -249,8 +198,7 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
249
  fn=infer,
250
  inputs=[prompt, input_images],
251
  outputs=[result, seed],
252
- cache_examples=True,
253
- cache_mode="lazy"
254
  )
255
 
256
  input_images.upload(
@@ -266,4 +214,4 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
266
  outputs=[result, seed]
267
  )
268
 
269
- demo.launch(css=css)
 
6
  import spaces
7
  import torch
8
  from diffusers import Flux2Pipeline, Flux2Transformer2DModel
 
9
  from PIL import Image
 
10
 
11
  dtype = torch.bfloat16
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
14
  MAX_SEED = np.iinfo(np.int32).max
15
  MAX_IMAGE_SIZE = 1024
16
 
17
+ # Load the full pipeline WITH text encoder
18
+ print("Loading FLUX.2 pipeline with text encoder...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  repo_id = "black-forest-labs/FLUX.2-dev"
20
 
 
 
 
 
 
 
21
  pipe = Flux2Pipeline.from_pretrained(
22
  repo_id,
 
 
23
  torch_dtype=torch.bfloat16
24
  )
 
25
 
26
+ # Enable memory optimizations for ZeroGPU
27
+ pipe.enable_model_cpu_offload() # This offloads models to CPU when not in use, saving VRAM
28
+
29
+ print("Pipeline loaded successfully!")
30
 
31
  def update_dimensions_from_image(image_list):
32
  """Update width/height sliders based on uploaded image aspect ratio."""
 
53
 
54
  return new_width, new_height
55
 
56
+ @spaces.GPU(duration=120) # Increased duration since we're doing text encoding + generation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):
58
 
59
  if randomize_seed:
 
65
  for item in input_images:
66
  image_list.append(item[0])
67
 
68
+ generator = torch.Generator(device="cuda").manual_seed(seed)
 
 
69
 
70
+ progress(0.1, desc="Encoding prompt & generating image...")
71
+
72
+ # Build pipeline arguments
73
+ pipe_kwargs = {
74
+ "prompt": prompt,
75
+ "num_inference_steps": num_inference_steps,
76
+ "guidance_scale": guidance_scale,
77
+ "generator": generator,
78
+ "width": width,
79
+ "height": height,
80
+ }
81
+
82
+ # Add images if provided
83
+ if image_list is not None:
84
+ pipe_kwargs["image"] = image_list
85
+
86
+ # Run the pipeline - text encoding happens automatically inside
87
+ image = pipe(**pipe_kwargs).images[0]
88
 
89
  return image, seed
90
 
91
  examples = [
92
  ["Create a vase on a table in living room, the color of the vase is a gradient of color, starting with #02eb3c color and finishing with #edfa3c. The flowers inside the vase have the color #ff0088"],
93
+ ["Photorealistic infographic showing the complete Berlin TV Tower (Fernsehturm) from ground base to antenna tip, full vertical view with entire structure visible including concrete shaft, metallic sphere, and antenna spire."],
94
  ["Soaking wet capybara taking shelter under a banana leaf in the rainy jungle, close up photo"],
95
+ ["A kawaii die-cut sticker of a chubby orange cat, featuring big sparkly eyes and a happy smile with paws raised in greeting and a heart-shaped pink nose."],
96
  ]
97
 
98
  examples_images = [
 
109
  }
110
  """
111
 
112
+ with gr.Blocks(css=css) as demo:
113
 
114
  with gr.Column(elem_id="col-container"):
115
  gr.Markdown(f"""# FLUX.2 [dev]
 
190
  fn=infer,
191
  inputs=[prompt],
192
  outputs=[result, seed],
193
+ cache_examples="lazy"
 
194
  )
195
 
196
  gr.Examples(
 
198
  fn=infer,
199
  inputs=[prompt, input_images],
200
  outputs=[result, seed],
201
+ cache_examples="lazy"
 
202
  )
203
 
204
  input_images.upload(
 
214
  outputs=[result, seed]
215
  )
216
 
217
+ demo.launch()