concauu committed on
Commit
fc18b7e
·
verified ·
1 Parent(s): 90bbc81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -66
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from groq import Groq # Import the Groq library
4
  from cryptography.fernet import Fernet
5
  from huggingface_hub import login
6
  import os
@@ -8,9 +8,14 @@ os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
8
  import numpy as np
9
  import random
10
  import spaces
11
- from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
12
- from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
13
  from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
 
 
 
 
 
14
  os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
15
  dtype = torch.bfloat16
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -28,9 +33,14 @@ def get_hf_token(encrypted_token):
28
  # Decrypt and decode the token
29
  decrypted_token = f.decrypt(encrypted_token).decode()
30
  return decrypted_token
 
31
  groq_client = Groq(api_key="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")
32
  decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
33
  login(token=decrypted_token)
 
 
 
 
34
 
35
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
36
  good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
@@ -42,9 +52,45 @@ MAX_IMAGE_SIZE = 2048
42
 
43
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
44
 
45
- decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
46
- login(token=decrypted_token)
47
- groq_client = Groq(api_key="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")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  @spaces.GPU(duration=75)
50
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
@@ -121,88 +167,79 @@ css = """
121
  """
122
 
123
  with gr.Blocks(css=css) as demo:
 
 
124
  with gr.Column(elem_id="col-container"):
125
- gr.Markdown("# FLUX.1 [dev] with Prompt Enhancement")
 
 
126
  gr.Markdown("### Step 1: Enhance Your Prompt")
127
- # Original prompt input and enhancement button
128
- original_prompt = gr.Textbox(
129
- label="Original Prompt",
130
- placeholder="Enter your idea here...",
131
- lines=2
132
- )
133
  enhance_button = gr.Button("Enhance Prompt")
134
- # Editable textbox that will hold the enhanced prompt
135
- enhanced_prompt = gr.Textbox(
136
- label="Enhanced Prompt (Editable)",
137
- placeholder="The enhanced prompt will appear here...",
138
- lines=2
139
- )
140
- # When clicked, this button calls the enhance_prompt function.
141
- enhance_button.click(fn=enhance_prompt, inputs=original_prompt, outputs=enhanced_prompt)
142
 
143
- gr.Markdown("### Step 2: Generate Image Using Enhanced Prompt")
 
144
  with gr.Row():
145
- run_button = gr.Button("Generate Image", scale=0)
146
  result = gr.Image(label="Result", show_label=False)
147
 
148
- with gr.Accordion("Advanced Settings", open=False):
149
- seed = gr.Slider(
150
- label="Seed",
151
- minimum=0,
152
- maximum=MAX_SEED,
153
- step=1,
154
- value=0,
155
- )
156
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
157
  with gr.Row():
158
- width = gr.Slider(
159
- label="Width",
160
- minimum=256,
161
- maximum=MAX_IMAGE_SIZE,
162
- step=32,
163
- value=1024,
164
- )
165
- height = gr.Slider(
166
- label="Height",
167
- minimum=256,
168
- maximum=MAX_IMAGE_SIZE,
169
- step=32,
170
- value=1024,
171
- )
172
  with gr.Row():
173
- guidance_scale = gr.Slider(
174
- label="Guidance Scale",
175
- minimum=1,
176
- maximum=15,
177
- step=0.1,
178
- value=3.5,
179
- )
180
- num_inference_steps = gr.Slider(
181
- label="Number of inference steps",
182
- minimum=1,
183
- maximum=50,
184
- step=1,
185
- value=28,
186
- )
187
 
 
188
  gr.Examples(
189
  examples=[
190
  "a tiny astronaut hatching from an egg on the moon",
191
  "a cat holding a sign that says hello world",
192
  "an anime illustration of a wiener schnitzel",
193
  ],
194
- fn=infer,
195
- inputs=enhanced_prompt, # Uses the enhanced prompt
196
  outputs=[result, seed],
 
197
  cache_examples="lazy"
198
  )
199
-
200
- # Trigger the original image generation code using the enhanced prompt.
201
- gr.on(
202
- triggers=[run_button.click, enhanced_prompt.submit],
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  fn=infer,
204
  inputs=[enhanced_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
205
  outputs=[result, seed]
 
 
 
 
 
 
 
 
206
  )
207
 
208
  demo.launch(share=True)
 
1
  import gradio as gr
2
  import torch
3
+ from groq import Groq
4
  from cryptography.fernet import Fernet
5
  from huggingface_hub import login
6
  import os
 
8
  import numpy as np
9
  import random
10
  import spaces
11
+ from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
12
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
13
  from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
14
+ from io import BytesIO
15
+ import base64
16
+
17
+
18
+
19
  os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
20
  dtype = torch.bfloat16
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
33
  # Decrypt and decode the token
34
  decrypted_token = f.decrypt(encrypted_token).decode()
35
  return decrypted_token
36
+
37
# SECURITY(review): these credentials are hard-coded and committed to the
# repository — the Groq API key and the encrypted HF token are effectively
# public. Rotate both and load them from environment variables / Space
# secrets instead of literals.
groq_client = Groq(api_key="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")

# Decrypt the HF token once and authenticate. The commit duplicated these
# three statements (decrypt, login, client creation ran twice for no
# effect); deduplicated here — behavior is unchanged.
decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
login(token=decrypted_token)
43
+
44
 
45
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
46
  good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
 
52
 
53
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
54
 
55
+
56
# History helpers for the generation log shown in the UI.
def append_to_history(image, prompt, seed, width, height, guidance_scale, steps, history):
    """Return a new history list with one entry appended for *image*.

    The image is serialized to PNG bytes (assumes a PIL-compatible object
    exposing ``.save`` — TODO confirm against infer's return type) so each
    entry is self-contained. When *image* is None, *history* is returned
    unchanged. The input list is never mutated; a new list is returned.
    """
    if image is None:
        return history

    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")

    entry = {
        "image": png_buffer.getvalue(),
        "prompt": prompt,
        "seed": seed,
        "width": width,
        "height": height,
        "guidance_scale": guidance_scale,
        "steps": steps,
    }
    return history + [entry]
74
+
75
def create_history_html(history):
    """Render the generation history as a column of HTML cards, newest first.

    Each entry is a dict from append_to_history (PNG bytes plus the
    generation parameters). Returns a placeholder paragraph when *history*
    is empty.

    Fix: the prompt is user-supplied text and was interpolated verbatim
    into the page, allowing HTML/script injection — it is now escaped.
    """
    from html import escape  # local import keeps this fix self-contained

    if not history:
        return "<p style='margin: 20px;'>No generations yet</p>"

    html = "<div style='display: flex; flex-direction: column; gap: 20px; margin: 20px;'>"
    for i, entry in enumerate(reversed(history)):
        img_str = base64.b64encode(entry["image"]).decode()
        html += f"""
        <div style='display: flex; gap: 20px; padding: 20px; background: #f5f5f5; border-radius: 10px;'>
            <img src="data:image/png;base64,{img_str}" style="width: 150px; height: 150px; object-fit: cover; border-radius: 5px;"/>
            <div style='flex: 1;'>
                <h3 style='margin: 0;'>Generation #{len(history)-i}</h3>
                <p><strong>Prompt:</strong> {escape(entry["prompt"])}</p>
                <p><strong>Seed:</strong> {entry["seed"]}</p>
                <p><strong>Size:</strong> {entry["width"]}x{entry["height"]}</p>
                <p><strong>Guidance:</strong> {entry["guidance_scale"]}</p>
                <p><strong>Steps:</strong> {entry["steps"]}</p>
            </div>
        </div>
        """
    return html + "</div>"
93
+
94
 
95
  @spaces.GPU(duration=75)
96
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
 
167
  """
168
 
169
with gr.Blocks(css=css) as demo:
    # Per-session generation history: list of dicts produced by
    # append_to_history (PNG bytes + the parameters used).
    history_state = gr.State([])

    with gr.Column(elem_id="col-container"):
        gr.Markdown("# FLUX.1 [dev] with History Tracking")

        # --- Step 1: prompt enhancement ---
        gr.Markdown("### Step 1: Enhance Your Prompt")
        original_prompt = gr.Textbox(label="Original Prompt", lines=2)
        enhance_button = gr.Button("Enhance Prompt")
        enhanced_prompt = gr.Textbox(label="Enhanced Prompt (Editable)", lines=2)
        enhance_button.click(enhance_prompt, original_prompt, enhanced_prompt)

        # --- Step 2: image generation ---
        gr.Markdown("### Step 2: Generate Image")
        with gr.Row():
            run_button = gr.Button("Generate Image", variant="primary")
            result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings"):
            seed = gr.Slider(0, MAX_SEED, value=0, label="Seed")
            randomize_seed = gr.Checkbox(True, label="Randomize seed")
            with gr.Row():
                width = gr.Slider(256, MAX_IMAGE_SIZE, 1024, step=32, label="Width")
                height = gr.Slider(256, MAX_IMAGE_SIZE, 1024, step=32, label="Height")
            with gr.Row():
                guidance_scale = gr.Slider(1, 15, 3.5, step=0.1, label="Guidance Scale")
                num_inference_steps = gr.Slider(1, 50, 28, step=1, label="Inference Steps")

        # --- Generation history (rendered by create_history_html) ---
        with gr.Accordion("Generation History", open=False):
            history_display = gr.HTML("<p style='margin: 20px;'>No generations yet</p>")

        gr.Examples(
            examples=[
                "a tiny astronaut hatching from an egg on the moon",
                "a cat holding a sign that says hello world",
                "an anime illustration of a wiener schnitzel",
            ],
            inputs=enhanced_prompt,
            outputs=[result, seed],
            fn=infer,
            cache_examples="lazy"
        )

    # Wire both triggers (button click and textbox submit) through ONE
    # chain: infer -> append_to_history -> create_history_html. The commit
    # duplicated the identical three-step chain for each trigger;
    # gr.on(triggers=[...]) expresses it once with the same behavior.
    gr.on(
        triggers=[run_button.click, enhanced_prompt.submit],
        fn=infer,
        inputs=[enhanced_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed]
    ).then(
        fn=append_to_history,
        inputs=[result, enhanced_prompt, seed, width, height, guidance_scale, num_inference_steps, history_state],
        outputs=history_state
    ).then(
        fn=create_history_html,
        inputs=history_state,
        outputs=history_display
    )

demo.launch(share=True)