# image_generator/app.py
import gradio as gr
import torch
from groq import Groq
from cryptography.fernet import Fernet
from huggingface_hub import login
import os
os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
import numpy as np
import random
import spaces
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, T5Tokenizer
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
from io import BytesIO
import base64
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
import speech_recognition as sr
import time
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
def get_hf_token(encrypted_token):
    """Decrypt the Hugging Face token using the Fernet key from the environment."""
    # The error message below documents the intent: the key must come from the
    # DECRYPTION_KEY environment variable, never be hardcoded next to the token.
    key = os.environ.get("DECRYPTION_KEY")
    if not key:
        raise ValueError("Missing decryption key! Set the DECRYPTION_KEY environment variable.")
    if isinstance(key, str):
        key = key.encode()
    f = Fernet(key)
    decrypted_token = f.decrypt(encrypted_token).decode()
    return decrypted_token
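# A minimal sketch (run once, offline) of how the key/token pair is produced;
# the generated key goes into the Space's DECRYPTION_KEY secret:
#
#   from cryptography.fernet import Fernet
#   key = Fernet.generate_key()                    # store as DECRYPTION_KEY
#   encrypted = Fernet(key).encrypt(b"hf_...")     # pass to get_hf_token()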
# Read the Groq API key from the environment instead of hardcoding it; the
# Groq SDK also falls back to GROQ_API_KEY automatically when api_key is omitted.
groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
login(token=decrypted_token)
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
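# taef1 is a tiny, fast VAE typically used by live-preview helpers to decode
# intermediate latents cheaply; the full AutoencoderKL decodes the final image.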
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype,
    vae=good_vae,  # use the AutoencoderKL instance for decoding
).to(device)
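# Release any VRAM left over from checkpoint loading before serving requests.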
torch.cuda.empty_cache()
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
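# `function.__get__(pipe)` returns the helper bound to this pipeline instance,
# so the call below behaves like a regular method with `pipe` as `self`.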
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
# ----- HISTORY FUNCTIONS & GRADIO INTERFACE -----
def append_to_history(image, prompt, seed, width, height, guidance_scale, steps, history):
    if image is None:
        return history
    from PIL import Image
    # gr.Image components return NumPy arrays by default; normalize to PIL.
    if isinstance(image, np.ndarray):
        if image.dtype == np.uint8:
            image = Image.fromarray(image)
        else:
            # Float arrays are assumed to be in [0, 1]; scale up to 8-bit.
            image = Image.fromarray((image * 255).astype(np.uint8))
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_bytes = buffered.getvalue()
    return history + [{
        "image": img_bytes,
        "prompt": prompt,
        "seed": seed,
        "width": width,
        "height": height,
        "guidance_scale": guidance_scale,
        "steps": steps,
    }]
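# Each history entry is a plain dict holding the raw PNG bytes plus the
# generation parameters, e.g. (illustrative values):
#   {"image": b"\x89PNG...", "prompt": "a cat", "seed": 42, "width": 1024,
#    "height": 1024, "guidance_scale": 3.5, "steps": 28}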
def create_history_html(history):
    if not history:
        return "<p style='margin: 20px;'>No generations yet</p>"
    html = "<div style='display: flex; flex-direction: column; gap: 20px; margin: 20px;'>"
    for i, entry in enumerate(reversed(history)):
        img_str = base64.b64encode(entry["image"]).decode()
        html += f"""
        <div style='display: flex; gap: 20px; padding: 20px; background: #f5f5f5; border-radius: 10px;'>
            <img src="data:image/png;base64,{img_str}" style="width: 150px; height: 150px; object-fit: cover; border-radius: 5px;"/>
            <div style='flex: 1;'>
                <h3 style='margin: 0;'>Generation #{len(history)-i}</h3>
                <p><strong>Prompt:</strong> {entry["prompt"]}</p>
                <p><strong>Seed:</strong> {entry["seed"]}</p>
                <p><strong>Size:</strong> {entry["width"]}x{entry["height"]}</p>
                <p><strong>Guidance:</strong> {entry["guidance_scale"]}</p>
                <p><strong>Steps:</strong> {entry["steps"]}</p>
            </div>
        </div>
        """
    return html + "</div>"
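# Inlining images as base64 data URIs keeps the history panel self-contained:
# no temporary files are written and entries survive component re-renders.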
def transcribe_speech(audio_file):
    r = sr.Recognizer()
    with sr.AudioFile(audio_file) as source:
        audio = r.record(source)
    try:
        return r.recognize_google(audio)
    except sr.UnknownValueError:
        return "Could not understand the audio."
    except sr.RequestError as e:
        return f"Speech recognition request failed: {e}"
def set_font_size(choice):
    # Unused: font sizing is handled entirely client-side by js_callback below.
    return gr.skip()
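# `infer` is a generator: Gradio treats each `yield` as a streaming update, so
# the Image component shows intermediate denoising previews, and the final
# yield leaves the finished image in place.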
@spaces.GPU(duration=75)
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        output_type="pil",
        good_vae=good_vae,
    ):
        yield img, seed
def enhance_prompt(user_prompt):
    try:
        chat_completion = groq_client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": (
                        """Enhance user input into prompts that paint a clear picture for image generation. Be precise, detailed, and direct: describe not only the content of the image but also details such as tone, style, color palette, and point of view. For photorealistic images, use precise, visual descriptions rather than metaphorical concepts.
Keep prompts keyword-focused, yet precise and awe-inspiring.
Medium: consider what form of art this image should be simulating.
Viewing angle: aerial view, dutch angle, straight-on, extreme closeup, etc.
Background: how does the setting complement the subject?
Environment: indoor, outdoor, abstract, etc.
Colors: how do they contrast or harmonize with the subject?
Lighting: time of day, intensity, direction (e.g., backlighting).
"""
                    ),
                },
                {"role": "user", "content": user_prompt},
            ],
            model="llama-3.3-70b-versatile",
            temperature=0.5,
            max_completion_tokens=1024,
            top_p=1,
            stop=None,
            stream=False,
        )
        enhanced = chat_completion.choices[0].message.content
    except Exception as e:
        enhanced = f"Error enhancing prompt: {str(e)}"
    return enhanced
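# Illustrative behavior: enhance_prompt("a cat") returns an expanded,
# keyword-rich description; on failure the error string is surfaced in the
# editable prompt box rather than raising.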
def fake_typing():
    # Unused helper retained from an earlier iteration; `time` is imported above.
    for _ in range(3):
        yield {"visible": False}, {"visible": True}
        time.sleep(0.5)
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
js_callback = """
(fontSize) => {
    const sizeMap = {"Small": "12px", "Medium": "16px", "Large": "20px"};
    const newSize = sizeMap[fontSize];
    if (newSize) {
        let styleEl = document.getElementById("global-font-style");
        if (!styleEl) {
            styleEl = document.createElement("style");
            styleEl.id = "global-font-style";
            document.head.appendChild(styleEl);
        }
        styleEl.innerHTML = `
            body, .gradio-container, .gradio-container * {
                font-size: ${newSize} !important;
            }
        `;
    }
    return "";
}
"""
# --- Gradio Interface ---
with gr.Blocks(css="""
    .user-msg { background: #e3f2fd; border-radius: 15px; padding: 10px; margin: 5px; }
    .bot-msg { background: #f5f5f5; border-radius: 15px; padding: 10px; margin: 5px; }
    #col-container { margin: 0 auto; max-width: 520px; }
""") as demo:
    # Global controls: font size adjuster and voice input
    font_choice = gr.Dropdown(
        choices=["Small", "Medium", "Large"],
        value="Medium",
        label="Font Size"
    )
    gr.Markdown("### This text (and all other text) will change size dynamically.")
    sample_text = gr.Textbox(label="Try typing here:", value="The quick brown fox jumps over the lazy dog.")
    # When font_choice changes, js_callback restyles the page in the browser;
    # no Python function runs (fn=None) and no component values change.
    font_choice.change(
        fn=None,
        inputs=font_choice,
        outputs=[],       # no output updated by Python
        js=js_callback    # JS callback updates the document's font size
    )
    gr.Markdown("## Voice Input: Record your prompt")
    # In Gradio 4.x the parameter is `sources`, a list of input sources.
    audio_input = gr.Audio(sources=["microphone"], type="filepath", label="Record your voice")
    transcribe_button = gr.Button("Transcribe Voice")
    transcribed_text = gr.Textbox(label="Transcribed Text", lines=2)
    transcribe_button.click(fn=transcribe_speech, inputs=audio_input, outputs=transcribed_text)
    copy_transcribed = gr.Button("Copy Transcribed Text to Prompt")
    # Main interface for image generation
    history_state = gr.State([])
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# FLUX.1 [dev] with History Tracking")
        gr.Markdown("### Step 1: Enhance Your Prompt")
        original_prompt = gr.Textbox(label="Original Prompt", lines=2)
        enhance_button = gr.Button("Enhance Prompt")
        enhanced_prompt = gr.Textbox(label="Enhanced Prompt (Editable)", lines=2)
        copy_transcribed.click(fn=lambda txt: txt, inputs=transcribed_text, outputs=original_prompt)
        enhance_button.click(enhance_prompt, original_prompt, enhanced_prompt)
        gr.Markdown("### Step 2: Generate Image")
        with gr.Row():
            run_button = gr.Button("Generate Image", variant="primary")
            result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings"):
            seed = gr.Slider(0, MAX_SEED, value=0, label="Seed")
            randomize_seed = gr.Checkbox(True, label="Randomize seed")
            with gr.Row():
                width = gr.Slider(256, MAX_IMAGE_SIZE, 1024, step=32, label="Width")
                height = gr.Slider(256, MAX_IMAGE_SIZE, 1024, step=32, label="Height")
            with gr.Row():
                guidance_scale = gr.Slider(1, 15, 3.5, step=0.1, label="Guidance Scale")
                num_inference_steps = gr.Slider(1, 50, 28, step=1, label="Inference Steps")
        with gr.Accordion("Generation History", open=False):
            history_display = gr.HTML("<p style='margin: 20px;'>No generations yet</p>")
        gr.Examples(
            examples=[
                "a tiny astronaut hatching from an egg on the moon",
                "a cat holding a sign that says hello world",
                "an anime illustration of a wiener schnitzel",
            ],
            inputs=enhanced_prompt,
            outputs=[result, seed],
            fn=infer,
            cache_examples="lazy"
        )
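        # `.then` steps run sequentially after the generation event resolves,
        # so history is appended and re-rendered only once the image is done.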
        generation_event = run_button.click(
            fn=infer,
            inputs=[enhanced_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
            outputs=[result, seed]
        )
        generation_event.then(
            fn=append_to_history,
            inputs=[result, enhanced_prompt, seed, width, height, guidance_scale, num_inference_steps, history_state],
            outputs=history_state
        ).then(
            fn=create_history_html,
            inputs=history_state,
            outputs=history_display
        )
        enhanced_prompt.submit(
            fn=infer,
            inputs=[enhanced_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
            outputs=[result, seed]
        ).then(
            fn=append_to_history,
            inputs=[result, enhanced_prompt, seed, width, height, guidance_scale, num_inference_steps, history_state],
            outputs=history_state
        ).then(
            fn=create_history_html,
            inputs=history_state,
            outputs=history_display
        )
demo.launch(share=True)