concauu committed on
Commit
5e76468
·
verified ·
1 Parent(s): 233e97e

Update app.py

Browse files

Follow example from flux-dev

Files changed (1) hide show
  1. app.py +129 -63
app.py CHANGED
@@ -1,19 +1,26 @@
1
- from PIL import Image
2
  import gradio as gr
3
  import torch
4
- from diffusers import FluxPipeline
5
  from groq import Groq # Import the Groq library
6
  from cryptography.fernet import Fernet
7
  from huggingface_hub import login
8
- # from transformers import CLIPTokenizer
9
  import os
10
  os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
 
 
 
 
 
 
 
 
 
11
  def get_hf_token(encrypted_token):
12
  # Retrieve the decryption key from an environment variable
13
  key = "K4FlQbffvTcDxT2FIhrOPV1eue6ia45FFR3kqp2hHbM="
14
  if not key:
15
  raise ValueError("Missing decryption key! Set the DECRYPTION_KEY environment variable.")
16
-
 
17
  if isinstance(key, str):
18
  key = key.encode()
19
 
@@ -25,10 +32,19 @@ groq_client = Groq(api_key="gsk_0Rj7v0ZeHyFEpdwUMBuWWGdyb3FYGUesOkfhi7Gqba9rDXwI
25
  decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
26
  login(token=decrypted_token)
27
 
 
 
 
 
28
 
29
- pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float32)
30
- pipe.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power
31
 
 
 
 
 
 
32
 
33
  def enhance_prompt(user_prompt):
34
  """Enhances the given prompt using Groq and returns the refined prompt."""
@@ -63,67 +79,117 @@ def enhance_prompt(user_prompt):
63
  except Exception as e:
64
  enhanced = f"Error enhancing prompt: {str(e)}"
65
  return enhanced
66
- # tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
67
 
68
- # def truncate_prompt(prompt, max_length=77):
69
- # # Tokenize without adding special tokens
70
- # tokens = tokenizer(prompt, add_special_tokens=False)["input_ids"]
71
- # if len(tokens) > max_length:
72
- # tokens = tokens[:max_length]
73
- # # Decode tokens back to a string (skipping special tokens)
74
- # prompt = tokenizer.decode(tokens, skip_special_tokens=True)
75
- # return prompt
76
 
77
- def generate_image(prompt):
78
- """Generates an image using the refined prompt."""
79
- # truncated_prompt = truncate_prompt(prompt)
80
- try:
81
- image = pipe(
82
- prompt,
83
- height=1024,
84
- width=1024,
85
- guidance_scale=3.5,
86
- num_inference_steps=50,
87
- max_sequence_length=512,
88
- generator=torch.Generator("cpu").manual_seed(0)
89
- ).images[0]
90
- except Exception as e:
91
- # Optionally, handle errors (you can also return a default error image)
92
- image = None
93
- print(e)
94
- return image
95
 
96
- with gr.Blocks(css=".gradio-container {background-color: #f9f9f9; padding: 20px;}") as demo:
97
- gr.Markdown("# 2-Step Image Generator")
98
- gr.Markdown(
99
- "### Step 1: Prompt Enhancement\n"
100
- "Enter your original prompt below and click **Enhance Prompt**. "
101
- "The system will generate a detailed version of your prompt. You can modify the enhanced prompt before generating the image."
102
- )
 
103
 
104
- with gr.Row():
105
- original_prompt = gr.Textbox(
106
- label="Your Original Prompt",
107
- placeholder="Describe your idea here...",
108
- lines=3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  )
110
- enhance_button = gr.Button("Enhance Prompt")
111
-
112
- enhanced_prompt_box = gr.Textbox(
113
- label="Enhanced Prompt (Editable)",
114
- placeholder="The enhanced prompt will appear here...",
115
- lines=3
116
  )
117
-
118
- enhance_button.click(fn=enhance_prompt, inputs=original_prompt, outputs=enhanced_prompt_box)
119
-
120
- gr.Markdown("### Step 2: Image Generation\n"
121
- "Review and modify the enhanced prompt if necessary. Once you're ready, click **Generate Image** to create your image.")
122
-
123
- generate_button = gr.Button("Generate Image")
124
- image_output = gr.Image(label="Generated Image")
125
-
126
- generate_button.click(fn=generate_image, inputs=enhanced_prompt_box, outputs=image_output)
127
 
128
- if __name__ == "__main__":
129
- demo.launch(share=True)
 
 
1
  import gradio as gr
2
  import torch
 
3
  from groq import Groq # Import the Groq library
4
  from cryptography.fernet import Fernet
5
  from huggingface_hub import login
 
6
  import os
7
  os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
8
+ import numpy as np
9
+ import random
10
+ import spaces
11
+ from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
12
+ from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
13
+ from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
14
+ os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
15
+ dtype = torch.bfloat16
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
  def get_hf_token(encrypted_token):
18
  # Retrieve the decryption key from an environment variable
19
  key = "K4FlQbffvTcDxT2FIhrOPV1eue6ia45FFR3kqp2hHbM="
20
  if not key:
21
  raise ValueError("Missing decryption key! Set the DECRYPTION_KEY environment variable.")
22
+
23
+ # Convert key from string to bytes if necessary
24
  if isinstance(key, str):
25
  key = key.encode()
26
 
 
32
  decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
33
  login(token=decrypted_token)
34
 
35
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
36
+ good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
37
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
38
+ torch.cuda.empty_cache()
39
 
40
+ MAX_SEED = np.iinfo(np.int32).max
41
+ MAX_IMAGE_SIZE = 2048
42
 
43
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
44
+
45
+ decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
46
+ login(token=decrypted_token)
47
+ groq_client = Groq(api_key="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")
48
 
49
  def enhance_prompt(user_prompt):
50
  """Enhances the given prompt using Groq and returns the refined prompt."""
 
79
  except Exception as e:
80
  enhanced = f"Error enhancing prompt: {str(e)}"
81
  return enhanced
82
+ # ---------- End Prompt Enhancement Setup ----------
83
 
84
+ @spaces.GPU(duration=75)
85
+ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
86
+ # First, refine the prompt using the enhancer
87
+ enhanced_prompt = enhance_prompt(prompt)
88
+ print("Enhanced prompt:", enhanced_prompt) # You might remove or modify this logging
 
 
 
89
 
90
+ if randomize_seed:
91
+ seed = random.randint(0, MAX_SEED)
92
+ generator = torch.Generator().manual_seed(seed)
93
+
94
+ # Use the enhanced prompt for image generation
95
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
96
+ prompt=enhanced_prompt,
97
+ guidance_scale=guidance_scale,
98
+ num_inference_steps=num_inference_steps,
99
+ width=width,
100
+ height=height,
101
+ generator=generator,
102
+ output_type="pil",
103
+ good_vae=good_vae,
104
+ ):
105
+ yield img, seed
 
 
106
 
107
+ css = """
108
+ #col-container {
109
+ margin: 0 auto;
110
+ max-width: 520px;
111
+ }
112
+ """
113
+
114
+ with gr.Blocks(css=css) as demo:
115
 
116
+ with gr.Column(elem_id="col-container"):
117
+ gr.Markdown(f"""# FLUX.1 [dev] with Prompt Enhancement
118
+ Use the enhanced prompt (generated by Groq) for a refined image generation experience.
119
+ """)
120
+
121
+ with gr.Row():
122
+ prompt = gr.Text(
123
+ label="Prompt",
124
+ show_label=False,
125
+ max_lines=1,
126
+ placeholder="Enter your prompt",
127
+ container=False,
128
+ )
129
+ run_button = gr.Button("Run", scale=0)
130
+
131
+ result = gr.Image(label="Result", show_label=False)
132
+
133
+ with gr.Accordion("Advanced Settings", open=False):
134
+ seed = gr.Slider(
135
+ label="Seed",
136
+ minimum=0,
137
+ maximum=MAX_SEED,
138
+ step=1,
139
+ value=0,
140
+ )
141
+
142
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
143
+
144
+ with gr.Row():
145
+ width = gr.Slider(
146
+ label="Width",
147
+ minimum=256,
148
+ maximum=MAX_IMAGE_SIZE,
149
+ step=32,
150
+ value=1024,
151
+ )
152
+ height = gr.Slider(
153
+ label="Height",
154
+ minimum=256,
155
+ maximum=MAX_IMAGE_SIZE,
156
+ step=32,
157
+ value=1024,
158
+ )
159
+
160
+ with gr.Row():
161
+ guidance_scale = gr.Slider(
162
+ label="Guidance Scale",
163
+ minimum=1,
164
+ maximum=15,
165
+ step=0.1,
166
+ value=3.5,
167
+ )
168
+ num_inference_steps = gr.Slider(
169
+ label="Number of inference steps",
170
+ minimum=1,
171
+ maximum=50,
172
+ step=1,
173
+ value=28,
174
+ )
175
+
176
+ gr.Examples(
177
+ examples=[
178
+ "a tiny astronaut hatching from an egg on the moon",
179
+ "a cat holding a sign that says hello world",
180
+ "an anime illustration of a wiener schnitzel",
181
+ ],
182
+ fn=infer,
183
+ inputs=[prompt],
184
+ outputs=[result, seed],
185
+ cache_examples="lazy"
186
  )
187
+
188
+ gr.on(
189
+ triggers=[run_button.click, prompt.submit],
190
+ fn=infer,
191
+ inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
192
+ outputs=[result, seed]
193
  )
 
 
 
 
 
 
 
 
 
 
194
 
195
+ demo.launch(share=True)