concauu committed on
Commit
f344ac1
·
verified ·
1 Parent(s): cffbc7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -38
app.py CHANGED
@@ -1,24 +1,30 @@
1
  import gradio as gr
2
  import torch
 
 
 
3
  import os
 
4
  import numpy as np
5
  import random
6
- import base64
 
 
 
7
  from io import BytesIO
8
- from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
9
- from cryptography.fernet import Fernet
10
- from huggingface_hub import login
11
- from groq import Groq
12
- from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
13
-
14
- # Environment setup and device configuration
15
  os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
16
  dtype = torch.bfloat16
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
- # --- Token & API Setup ---
20
  def get_hf_token(encrypted_token):
21
  key = "K4FlQbffvTcDxT2FIhrOPV1eue6ia45FFR3kqp2hHbM="
 
 
22
  if isinstance(key, str):
23
  key = key.encode()
24
  f = Fernet(key)
@@ -28,31 +34,39 @@ def get_hf_token(encrypted_token):
28
  groq_client = Groq(api_key="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")
29
  decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
30
  login(token=decrypted_token)
31
- # (Repeat login and groq_client setup if needed)
 
 
 
 
 
 
32
 
33
- # --- Load Models ---
34
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
35
  good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
36
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
37
- pipe.to(device)
38
  torch.cuda.empty_cache()
39
- pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
40
 
41
- # --- Constants ---
42
  MAX_SEED = np.iinfo(np.int32).max
43
  MAX_IMAGE_SIZE = 2048
44
 
45
- # --- History Functions ---
 
 
46
  def append_to_history(image, prompt, seed, width, height, guidance_scale, steps, history):
47
  if image is None:
48
  return history
49
  from PIL import Image
 
50
  if isinstance(image, np.ndarray):
51
- image = Image.fromarray(image.astype("uint8")) if image.dtype != np.uint8 else Image.fromarray(image)
 
 
 
52
  buffered = BytesIO()
53
  image.save(buffered, format="PNG")
54
  img_bytes = buffered.getvalue()
55
- history.append({
56
  "image": img_bytes,
57
  "prompt": prompt,
58
  "seed": seed,
@@ -60,12 +74,9 @@ def append_to_history(image, prompt, seed, width, height, guidance_scale, steps,
60
  "height": height,
61
  "guidance_scale": guidance_scale,
62
  "steps": steps,
63
- })
64
- return history
65
 
66
  def create_history_html(history):
67
- if not history:
68
- return "<p style='margin: 20px;'>No generations yet</p>"
69
  html = "<div style='display: flex; flex-direction: column; gap: 20px; margin: 20px;'>"
70
  for i, entry in enumerate(reversed(history)):
71
  img_str = base64.b64encode(entry["image"]).decode()
@@ -82,13 +93,14 @@ def create_history_html(history):
82
  </div>
83
  </div>
84
  """
85
- return html + "</div>"
86
 
87
- # --- Inference & Prompt Enhancement Functions ---
88
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
89
  if randomize_seed:
90
  seed = random.randint(0, MAX_SEED)
91
  generator = torch.Generator().manual_seed(seed)
 
92
  for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
93
  prompt=prompt,
94
  guidance_scale=guidance_scale,
@@ -108,10 +120,16 @@ def enhance_prompt(user_prompt):
108
  {
109
  "role": "system",
110
  "content": (
111
- """Enhance user input into prompts that paint a clear picture for image generation. Be precise, detailed and direct, describing content, tone, style, color palette, and point of view. Use precise, visual descriptions with keywords for photorealistic images.
112
-
113
- Viewing Angle: Aerial view, dutch angle, straight-on, extreme closeup, etc.
114
- Background: How does the setting complement the subject?
 
 
 
 
 
 
115
  Environment: Indoor, outdoor, abstract, etc.
116
  Colors: How do they contrast or harmonize with the subject?
117
  Lighting: Time of day, intensity, direction (e.g., backlighting).
@@ -124,13 +142,25 @@ Lighting: Time of day, intensity, direction (e.g., backlighting).
124
  temperature=0.5,
125
  max_completion_tokens=1024,
126
  top_p=1,
 
 
127
  )
128
  enhanced = chat_completion.choices[0].message.content
129
  except Exception as e:
130
  enhanced = f"Error enhancing prompt: {str(e)}"
131
  return enhanced
 
 
 
 
 
 
 
 
 
 
132
 
133
- # --- Gradio Interface with Enhanced UI ---
134
  custom_css = """
135
  #col-container {
136
  margin: 0 auto;
@@ -141,9 +171,9 @@ custom_css = """
141
  .bot-msg { background: #f5f5f5; border-radius: 15px; padding: 10px; margin: 5px; }
142
  """
143
 
144
- with gr.Blocks(css=custom_css, title="FLUX.1 [dev] Enhanced UI") as demo:
145
- gr.Markdown("# FLUX.1 [dev] with Enhanced UI")
146
- # Using Tabs to separate prompt enhancement and image generation
147
  with gr.Tabs():
148
  with gr.Tab("Prompt Enhancement"):
149
  gr.Markdown("### Step 1: Enhance Your Prompt")
@@ -152,6 +182,13 @@ with gr.Blocks(css=custom_css, title="FLUX.1 [dev] Enhanced UI") as demo:
152
  enhanced_prompt = gr.Textbox(label="Enhanced Prompt (Editable)", placeholder="Enhanced prompt appears here...", lines=3)
153
  enhance_button.click(enhance_prompt, original_prompt, enhanced_prompt)
154
 
 
 
 
 
 
 
 
155
  with gr.Tab("Generate Image"):
156
  gr.Markdown("### Step 2: Generate Image")
157
  with gr.Row():
@@ -171,8 +208,6 @@ with gr.Blocks(css=custom_css, title="FLUX.1 [dev] Enhanced UI") as demo:
171
  history_display = gr.HTML("<p style='margin: 20px;'>No generations yet</p>")
172
  # State to track generation history
173
  history_state = gr.State([])
174
-
175
- # --- Define interactions ---
176
  generation_event = run_button.click(
177
  fn=infer,
178
  inputs=[enhanced_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
@@ -187,11 +222,17 @@ with gr.Blocks(css=custom_css, title="FLUX.1 [dev] Enhanced UI") as demo:
187
  inputs=history_state,
188
  outputs=history_display
189
  )
190
- # Clear history action
191
- clear_history_button.click(fn=lambda: [], inputs=[], outputs=history_state).then(
192
- fn=lambda hist: "<p style='margin: 20px;'>No generations yet</p>",
 
 
 
 
 
 
 
193
  inputs=history_state,
194
  outputs=history_display
195
  )
196
-
197
- demo.launch(share=True)
 
1
  import gradio as gr
2
  import torch
3
+ from groq import Groq
4
+ from cryptography.fernet import Fernet
5
+ from huggingface_hub import login
6
  import os
7
+ os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
8
  import numpy as np
9
  import random
10
+ import spaces
11
+ from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, UNet2DConditionModel
12
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, T5Tokenizer, T5EncoderModel
13
+ from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
14
  from io import BytesIO
15
+ import base64
16
+ from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
17
+ # For voice transcription
18
+ import speech_recognition as sr
 
 
 
19
  os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
20
  dtype = torch.bfloat16
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
+
24
  def get_hf_token(encrypted_token):
25
  key = "K4FlQbffvTcDxT2FIhrOPV1eue6ia45FFR3kqp2hHbM="
26
+ if not key:
27
+ raise ValueError("Missing decryption key! Set the DECRYPTION_KEY environment variable.")
28
  if isinstance(key, str):
29
  key = key.encode()
30
  f = Fernet(key)
 
34
# --- API clients & Hugging Face authentication ---
# NOTE(review): the Groq API key and the Fernet-encrypted HF token are
# hardcoded in source — rotate them and move them into environment
# variables / Spaces secrets before publishing this repo.
# Fix: the original executed get_hf_token/login and Groq(...) twice and
# re-assigned dtype/device (already set above); each is done once here.
groq_client = Groq(api_key="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")
decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
login(token=decrypted_token)
44
 
 
45
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
46
  good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
47
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
 
48
  torch.cuda.empty_cache()
 
49
 
 
50
  MAX_SEED = np.iinfo(np.int32).max
51
  MAX_IMAGE_SIZE = 2048
52
 
53
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
54
+
55
+ # ----- HISTORY FUNCTIONS & GRADIO INTERFACE -----
56
  def append_to_history(image, prompt, seed, width, height, guidance_scale, steps, history):
57
  if image is None:
58
  return history
59
  from PIL import Image
60
+ import numpy as np
61
  if isinstance(image, np.ndarray):
62
+ if image.dtype == np.uint8:
63
+ image = Image.fromarray(image)
64
+ else:
65
+ image = Image.fromarray((image * 255).astype(np.uint8))
66
  buffered = BytesIO()
67
  image.save(buffered, format="PNG")
68
  img_bytes = buffered.getvalue()
69
+ return history + [{
70
  "image": img_bytes,
71
  "prompt": prompt,
72
  "seed": seed,
 
74
  "height": height,
75
  "guidance_scale": guidance_scale,
76
  "steps": steps,
77
+ }]
 
78
 
79
  def create_history_html(history):
 
 
80
  html = "<div style='display: flex; flex-direction: column; gap: 20px; margin: 20px;'>"
81
  for i, entry in enumerate(reversed(history)):
82
  img_str = base64.b64encode(entry["image"]).decode()
 
93
  </div>
94
  </div>
95
  """
96
+ return html + "</div>" if history else "<p style='margin: 20px;'>No generations yet</p>"
97
 
98
+ @spaces.GPU(duration=75)
99
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
100
  if randomize_seed:
101
  seed = random.randint(0, MAX_SEED)
102
  generator = torch.Generator().manual_seed(seed)
103
+
104
  for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
105
  prompt=prompt,
106
  guidance_scale=guidance_scale,
 
120
  {
121
  "role": "system",
122
  "content": (
123
+ """Enhance user input into prompts that paint a clear picture for image generation. Be precise, detailed and direct, describe not only the content of the image but also such details as tone, style, color palette, and point of view, for photorealistic images Use precise, visual descriptions (rather than metaphorical concepts).
124
+ Try to keep prompts to contain only keywords, yet precise, and awe-inspiring.
125
+ Medium:
126
+ Consider what form of art this image should be simulating.
127
+
128
+ -Viewing Angle: Aerial view, dutch angle, straight-on, extreme closeup, etc
129
+
130
+ Background:
131
+ How does the setting complement the subject?
132
+
133
  Environment: Indoor, outdoor, abstract, etc.
134
  Colors: How do they contrast or harmonize with the subject?
135
  Lighting: Time of day, intensity, direction (e.g., backlighting).
 
142
  temperature=0.5,
143
  max_completion_tokens=1024,
144
  top_p=1,
145
+ stop=None,
146
+ stream=False,
147
  )
148
  enhanced = chat_completion.choices[0].message.content
149
  except Exception as e:
150
  enhanced = f"Error enhancing prompt: {str(e)}"
151
  return enhanced
152
+ # --- Voice Transcription Function ---
153
def transcribe_audio(audio_file):
    """Transcribe a recorded audio file to text.

    Uses the SpeechRecognition library's free Google Web Speech endpoint.
    On any failure (unreadable file, network error, unrecognizable speech)
    the error message itself is returned so it shows up in the textbox
    instead of crashing the UI.
    """
    recognizer = sr.Recognizer()
    try:
        with sr.AudioFile(audio_file) as source:
            result = recognizer.recognize_google(recognizer.record(source))
    except Exception as exc:  # best-effort: surface the error to the user
        result = f"Transcription error: {str(exc)}"
    return result
162
 
163
+ # --- Gradio Interface with Enhanced UI and Voice Recognition ---
164
  custom_css = """
165
  #col-container {
166
  margin: 0 auto;
 
171
  .bot-msg { background: #f5f5f5; border-radius: 15px; padding: 10px; margin: 5px; }
172
  """
173
 
174
+ with gr.Blocks(css=custom_css, title="FLUX.1 [dev] Enhanced UI with Voice Recognition") as demo:
175
+ gr.Markdown("# FLUX.1 [dev] with Enhanced UI and Voice Recognition")
176
+ # Using Tabs to separate functionalities
177
  with gr.Tabs():
178
  with gr.Tab("Prompt Enhancement"):
179
  gr.Markdown("### Step 1: Enhance Your Prompt")
 
182
  enhanced_prompt = gr.Textbox(label="Enhanced Prompt (Editable)", placeholder="Enhanced prompt appears here...", lines=3)
183
  enhance_button.click(enhance_prompt, original_prompt, enhanced_prompt)
184
 
185
+ with gr.Tab("Voice Recognition"):
186
+ gr.Markdown("### Step 1A: Record Your Prompt")
187
+ audio_input = gr.Audio(source="microphone", type="filepath", label="Record your prompt")
188
+ transcribe_button = gr.Button("Transcribe Audio", variant="secondary")
189
+ voice_text = gr.Textbox(label="Transcribed Prompt", placeholder="Your spoken prompt will appear here...", lines=3)
190
+ transcribe_button.click(transcribe_audio, audio_input, voice_text)
191
+
192
  with gr.Tab("Generate Image"):
193
  gr.Markdown("### Step 2: Generate Image")
194
  with gr.Row():
 
208
  history_display = gr.HTML("<p style='margin: 20px;'>No generations yet</p>")
209
  # State to track generation history
210
  history_state = gr.State([])
 
 
211
  generation_event = run_button.click(
212
  fn=infer,
213
  inputs=[enhanced_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
 
222
  inputs=history_state,
223
  outputs=history_display
224
  )
225
+ enhanced_prompt.submit(
226
+ fn=infer,
227
+ inputs=[enhanced_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
228
+ outputs=[result, seed]
229
+ ).then(
230
+ fn=append_to_history,
231
+ inputs=[result, enhanced_prompt, seed, width, height, guidance_scale, num_inference_steps, history_state],
232
+ outputs=history_state
233
+ ).then(
234
+ fn=create_history_html,
235
  inputs=history_state,
236
  outputs=history_display
237
  )
238
+ demo.launch(share=True)