concauu committed on
Commit
cffbc7f
·
verified ·
1 Parent(s): 5a1ddda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -94
app.py CHANGED
@@ -1,30 +1,24 @@
1
  import gradio as gr
2
  import torch
3
- from groq import Groq
4
- from cryptography.fernet import Fernet
5
- from huggingface_hub import login
6
  import os
7
- os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
8
  import numpy as np
9
  import random
10
- import spaces
11
- from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, UNet2DConditionModel
12
- from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, T5Tokenizer, T5EncoderModel
13
- from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
14
- from io import BytesIO
15
  import base64
16
- from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
17
- import torch.nn.functional as F
18
- import torch.nn as nn
 
 
 
 
 
19
  os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
20
  dtype = torch.bfloat16
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
-
24
  def get_hf_token(encrypted_token):
25
  key = "K4FlQbffvTcDxT2FIhrOPV1eue6ia45FFR3kqp2hHbM="
26
- if not key:
27
- raise ValueError("Missing decryption key! Set the DECRYPTION_KEY environment variable.")
28
  if isinstance(key, str):
29
  key = key.encode()
30
  f = Fernet(key)
@@ -34,39 +28,31 @@ def get_hf_token(encrypted_token):
34
  groq_client = Groq(api_key="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")
35
  decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
36
  login(token=decrypted_token)
37
- decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
38
- login(token=decrypted_token)
39
- groq_client = Groq(api_key="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")
40
-
41
-
42
- dtype = torch.bfloat16
43
- device = "cuda" if torch.cuda.is_available() else "cpu"
44
 
 
45
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
46
  good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
47
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
 
48
  torch.cuda.empty_cache()
 
49
 
 
50
  MAX_SEED = np.iinfo(np.int32).max
51
  MAX_IMAGE_SIZE = 2048
52
 
53
- pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
54
-
55
- # ----- HISTORY FUNCTIONS & GRADIO INTERFACE -----
56
  def append_to_history(image, prompt, seed, width, height, guidance_scale, steps, history):
57
  if image is None:
58
  return history
59
  from PIL import Image
60
- import numpy as np
61
  if isinstance(image, np.ndarray):
62
- if image.dtype == np.uint8:
63
- image = Image.fromarray(image)
64
- else:
65
- image = Image.fromarray((image * 255).astype(np.uint8))
66
  buffered = BytesIO()
67
  image.save(buffered, format="PNG")
68
  img_bytes = buffered.getvalue()
69
- return history + [{
70
  "image": img_bytes,
71
  "prompt": prompt,
72
  "seed": seed,
@@ -74,9 +60,12 @@ def append_to_history(image, prompt, seed, width, height, guidance_scale, steps,
74
  "height": height,
75
  "guidance_scale": guidance_scale,
76
  "steps": steps,
77
- }]
 
78
 
79
  def create_history_html(history):
 
 
80
  html = "<div style='display: flex; flex-direction: column; gap: 20px; margin: 20px;'>"
81
  for i, entry in enumerate(reversed(history)):
82
  img_str = base64.b64encode(entry["image"]).decode()
@@ -93,14 +82,13 @@ def create_history_html(history):
93
  </div>
94
  </div>
95
  """
96
- return html + "</div>" if history else "<p style='margin: 20px;'>No generations yet</p>"
97
 
98
- @spaces.GPU(duration=75)
99
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
100
  if randomize_seed:
101
  seed = random.randint(0, MAX_SEED)
102
  generator = torch.Generator().manual_seed(seed)
103
-
104
  for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
105
  prompt=prompt,
106
  guidance_scale=guidance_scale,
@@ -120,16 +108,10 @@ def enhance_prompt(user_prompt):
120
  {
121
  "role": "system",
122
  "content": (
123
- """Enhance user input into prompts that paint a clear picture for image generation. Be precise, detailed and direct, describe not only the content of the image but also such details as tone, style, color palette, and point of view, for photorealistic images, include the name of the device used (e.g., “shot on iPhone 16”), aperture, lens, and shot type. Use precise, visual descriptions (rather than metaphorical concepts).
124
- Try to keep prompts to contain only keywords, yet precise, and awe-inspiring.
125
- Medium:
126
- Consider what form of art this image should be simulating.
127
-
128
- -Viewing Angle: Aerial view, dutch angle, straight-on, extreme closeup, etc
129
-
130
- Background:
131
- How does the setting complement the subject?
132
-
133
  Environment: Indoor, outdoor, abstract, etc.
134
  Colors: How do they contrast or harmonize with the subject?
135
  Lighting: Time of day, intensity, direction (e.g., backlighting).
@@ -142,56 +124,55 @@ Lighting: Time of day, intensity, direction (e.g., backlighting).
142
  temperature=0.5,
143
  max_completion_tokens=1024,
144
  top_p=1,
145
- stop=None,
146
- stream=False,
147
  )
148
  enhanced = chat_completion.choices[0].message.content
149
  except Exception as e:
150
  enhanced = f"Error enhancing prompt: {str(e)}"
151
  return enhanced
152
 
153
- css = """
 
154
  #col-container {
155
  margin: 0 auto;
156
- max-width: 520px;
 
157
  }
 
 
158
  """
159
 
160
- with gr.Blocks(css=css) as demo:
161
- history_state = gr.State([])
162
- with gr.Column(elem_id="col-container"):
163
- gr.Markdown("# FLUX.1 [dev] with History Tracking")
164
- gr.Markdown("### Step 1: Enhance Your Prompt")
165
- original_prompt = gr.Textbox(label="Original Prompt", lines=2)
166
- enhance_button = gr.Button("Enhance Prompt")
167
- enhanced_prompt = gr.Textbox(label="Enhanced Prompt (Editable)", lines=2)
168
- enhance_button.click(enhance_prompt, original_prompt, enhanced_prompt)
169
- gr.Markdown("### Step 2: Generate Image")
170
- with gr.Row():
171
- run_button = gr.Button("Generate Image", variant="primary")
172
- result = gr.Image(label="Result", show_label=False)
173
- with gr.Accordion("Advanced Settings"):
174
- seed = gr.Slider(0, MAX_SEED, value=0, label="Seed")
175
- randomize_seed = gr.Checkbox(True, label="Randomize seed")
176
- with gr.Row():
177
- width = gr.Slider(256, MAX_IMAGE_SIZE, 1024, step=32, label="Width")
178
- height = gr.Slider(256, MAX_IMAGE_SIZE, 1024, step=32, label="Height")
179
- with gr.Row():
180
- guidance_scale = gr.Slider(1, 15, 3.5, step=0.1, label="Guidance Scale")
181
- num_inference_steps = gr.Slider(1, 50, 28, step=1, label="Inference Steps")
182
- with gr.Accordion("Generation History", open=False):
183
- history_display = gr.HTML("<p style='margin: 20px;'>No generations yet</p>")
184
- gr.Examples(
185
- examples=[
186
- "a tiny astronaut hatching from an egg on the moon",
187
- "a cat holding a sign that says hello world",
188
- "an anime illustration of a wiener schnitzel",
189
- ],
190
- inputs=enhanced_prompt,
191
- outputs=[result, seed],
192
- fn=infer,
193
- cache_examples="lazy"
194
- )
195
  generation_event = run_button.click(
196
  fn=infer,
197
  inputs=[enhanced_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
@@ -206,17 +187,11 @@ with gr.Blocks(css=css) as demo:
206
  inputs=history_state,
207
  outputs=history_display
208
  )
209
- enhanced_prompt.submit(
210
- fn=infer,
211
- inputs=[enhanced_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
212
- outputs=[result, seed]
213
- ).then(
214
- fn=append_to_history,
215
- inputs=[result, enhanced_prompt, seed, width, height, guidance_scale, num_inference_steps, history_state],
216
- outputs=history_state
217
- ).then(
218
- fn=create_history_html,
219
  inputs=history_state,
220
  outputs=history_display
221
  )
 
222
  demo.launch(share=True)
 
1
  import gradio as gr
2
  import torch
 
 
 
3
  import os
 
4
  import numpy as np
5
  import random
 
 
 
 
 
6
  import base64
7
+ from io import BytesIO
8
+ from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
9
+ from cryptography.fernet import Fernet
10
+ from huggingface_hub import login
11
+ from groq import Groq
12
+ from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
13
+
14
+ # Environment setup and device configuration
15
  os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
16
  dtype = torch.bfloat16
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
+ # --- Token & API Setup ---
20
  def get_hf_token(encrypted_token):
21
  key = "K4FlQbffvTcDxT2FIhrOPV1eue6ia45FFR3kqp2hHbM="
 
 
22
  if isinstance(key, str):
23
  key = key.encode()
24
  f = Fernet(key)
 
28
  groq_client = Groq(api_key="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")
29
  decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
30
  login(token=decrypted_token)
31
+ # (Repeat login and groq_client setup if needed)
 
 
 
 
 
 
32
 
33
+ # --- Load Models ---
34
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
35
  good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
36
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
37
+ pipe.to(device)
38
  torch.cuda.empty_cache()
39
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
40
 
41
+ # --- Constants ---
42
  MAX_SEED = np.iinfo(np.int32).max
43
  MAX_IMAGE_SIZE = 2048
44
 
45
+ # --- History Functions ---
 
 
46
  def append_to_history(image, prompt, seed, width, height, guidance_scale, steps, history):
47
  if image is None:
48
  return history
49
  from PIL import Image
 
50
  if isinstance(image, np.ndarray):
51
+ image = Image.fromarray(image.astype("uint8")) if image.dtype != np.uint8 else Image.fromarray(image)
 
 
 
52
  buffered = BytesIO()
53
  image.save(buffered, format="PNG")
54
  img_bytes = buffered.getvalue()
55
+ history.append({
56
  "image": img_bytes,
57
  "prompt": prompt,
58
  "seed": seed,
 
60
  "height": height,
61
  "guidance_scale": guidance_scale,
62
  "steps": steps,
63
+ })
64
+ return history
65
 
66
  def create_history_html(history):
67
+ if not history:
68
+ return "<p style='margin: 20px;'>No generations yet</p>"
69
  html = "<div style='display: flex; flex-direction: column; gap: 20px; margin: 20px;'>"
70
  for i, entry in enumerate(reversed(history)):
71
  img_str = base64.b64encode(entry["image"]).decode()
 
82
  </div>
83
  </div>
84
  """
85
+ return html + "</div>"
86
 
87
+ # --- Inference & Prompt Enhancement Functions ---
88
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
89
  if randomize_seed:
90
  seed = random.randint(0, MAX_SEED)
91
  generator = torch.Generator().manual_seed(seed)
 
92
  for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
93
  prompt=prompt,
94
  guidance_scale=guidance_scale,
 
108
  {
109
  "role": "system",
110
  "content": (
111
+ """Enhance user input into prompts that paint a clear picture for image generation. Be precise, detailed and direct, describing content, tone, style, color palette, and point of view. Use precise, visual descriptions with keywords for photorealistic images.
112
+
113
+ Viewing Angle: Aerial view, dutch angle, straight-on, extreme closeup, etc.
114
+ Background: How does the setting complement the subject?
 
 
 
 
 
 
115
  Environment: Indoor, outdoor, abstract, etc.
116
  Colors: How do they contrast or harmonize with the subject?
117
  Lighting: Time of day, intensity, direction (e.g., backlighting).
 
124
  temperature=0.5,
125
  max_completion_tokens=1024,
126
  top_p=1,
 
 
127
  )
128
  enhanced = chat_completion.choices[0].message.content
129
  except Exception as e:
130
  enhanced = f"Error enhancing prompt: {str(e)}"
131
  return enhanced
132
 
133
+ # --- Gradio Interface with Enhanced UI ---
134
+ custom_css = """
135
  #col-container {
136
  margin: 0 auto;
137
+ max-width: 600px;
138
+ padding: 20px;
139
  }
140
+ .user-msg { background: #e3f2fd; border-radius: 15px; padding: 10px; margin: 5px; }
141
+ .bot-msg { background: #f5f5f5; border-radius: 15px; padding: 10px; margin: 5px; }
142
  """
143
 
144
+ with gr.Blocks(css=custom_css, title="FLUX.1 [dev] Enhanced UI") as demo:
145
+ gr.Markdown("# FLUX.1 [dev] with Enhanced UI")
146
+ # Using Tabs to separate prompt enhancement and image generation
147
+ with gr.Tabs():
148
+ with gr.Tab("Prompt Enhancement"):
149
+ gr.Markdown("### Step 1: Enhance Your Prompt")
150
+ original_prompt = gr.Textbox(label="Original Prompt", placeholder="Enter your creative idea here...", lines=3)
151
+ enhance_button = gr.Button("Enhance Prompt", variant="secondary")
152
+ enhanced_prompt = gr.Textbox(label="Enhanced Prompt (Editable)", placeholder="Enhanced prompt appears here...", lines=3)
153
+ enhance_button.click(enhance_prompt, original_prompt, enhanced_prompt)
154
+
155
+ with gr.Tab("Generate Image"):
156
+ gr.Markdown("### Step 2: Generate Image")
157
+ with gr.Row():
158
+ run_button = gr.Button("Generate Image", variant="primary")
159
+ clear_history_button = gr.Button("Clear History", variant="secondary")
160
+ result = gr.Image(label="Generated Image", show_label=False)
161
+ with gr.Accordion("Advanced Settings", open=False):
162
+ seed = gr.Slider(0, MAX_SEED, value=0, label="Seed", info="Seed for reproducibility")
163
+ randomize_seed = gr.Checkbox(True, label="Randomize Seed")
164
+ with gr.Row():
165
+ width = gr.Slider(256, MAX_IMAGE_SIZE, 1024, step=32, label="Width")
166
+ height = gr.Slider(256, MAX_IMAGE_SIZE, 1024, step=32, label="Height")
167
+ with gr.Row():
168
+ guidance_scale = gr.Slider(1, 15, 3.5, step=0.1, label="Guidance Scale")
169
+ num_inference_steps = gr.Slider(1, 50, 28, step=1, label="Inference Steps")
170
+ with gr.Accordion("Generation History", open=False):
171
+ history_display = gr.HTML("<p style='margin: 20px;'>No generations yet</p>")
172
+ # State to track generation history
173
+ history_state = gr.State([])
174
+
175
+ # --- Define interactions ---
 
 
 
176
  generation_event = run_button.click(
177
  fn=infer,
178
  inputs=[enhanced_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
 
187
  inputs=history_state,
188
  outputs=history_display
189
  )
190
+ # Clear history action
191
+ clear_history_button.click(fn=lambda: [], inputs=[], outputs=history_state).then(
192
+ fn=lambda hist: "<p style='margin: 20px;'>No generations yet</p>",
 
 
 
 
 
 
 
193
  inputs=history_state,
194
  outputs=history_display
195
  )
196
+
197
  demo.launch(share=True)