# 11111 / app.py
# JQ66's picture
# Create app.py
# fe5b715 verified
# --- Filename: app.py ---
import gradio as gr
import openai
import torch
from diffusers import StableDiffusionPipeline, LCMScheduler
import os
from PIL import Image
import io # Required for handling audio file object for OpenAI API
import time # To estimate generation time
# --- Configuration ---
# Credentials are read from Hugging Face Secrets or plain environment variables.
# IMPORTANT: the OPENAI_API_KEY secret/variable must be set for the OpenAI features to work.
openai.api_key = os.environ.get("OPENAI_API_KEY")
hf_token = os.environ.get("HF_TOKEN")  # Optional credential for the model download

if openai.api_key:
    print("OpenAI API Key found.")
else:
    # Loud banner so a missing key stands out in the start-up logs.
    # To make the key mandatory instead, raise ValueError("OpenAI API Key not found!") here.
    print(
        "\n".join(
            [
                "\n" + "=" * 40,
                "ERROR: OPENAI_API_KEY environment variable not found.",
                "Please set the OPENAI_API_KEY secret/variable.",
                "OpenAI features (prompt enhancement, voice input) WILL FAIL.",
                "=" * 40 + "\n",
            ]
        )
    )
# Model identifiers
llm_model = "gpt-3.5-turbo"  # OpenAI chat model used to enhance prompts
sd_model_id = "runwayml/stable-diffusion-v1-5"  # Base text-to-image checkpoint
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"  # LCM LoRA adapter for few-step inference

# Hardware: prefer CUDA when present; otherwise run on the CPU in float32,
# which is the stable/compatible choice there.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

print(f"Selected Device: {device.upper()}")
print(f"Selected PyTorch Dtype: {torch_dtype}")
# --- Model Loading ---
# `pipe` is published only after every setup step has succeeded: download,
# LCM scheduler, LoRA fusion, device move and warm-up. Any failure leaves it
# None, and generate_image_lcm then returns a placeholder and an error message.
# This avoids serving a half-configured pipeline (e.g. the LCM scheduler
# without the LCM LoRA).
pipe = None
_loading_pipe = None  # Temporary handle while the pipeline is being configured
try:
    print("Loading Stable Diffusion model... (This might take a while on CPU)")
    _loading_pipe = StableDiffusionPipeline.from_pretrained(
        sd_model_id,
        torch_dtype=torch_dtype,
        # token=hf_token,  # Uncomment if the download needs authentication (successor of use_auth_token)
    )
    print("Base model loaded. Loading LCM Scheduler and LoRA...")
    # LCM scheduler + LCM-LoRA make sampling with only a few steps usable.
    _loading_pipe.scheduler = LCMScheduler.from_config(_loading_pipe.scheduler.config)
    _loading_pipe.load_lora_weights(lcm_lora_id)
    _loading_pipe.fuse_lora()  # Merge the LoRA into the weights for slightly faster inference
    _loading_pipe.to(device)
    print(f"Stable Diffusion model loaded successfully with LCM-LoRA on {device.upper()}.")
    # A 1-step dummy run surfaces runtime errors at start-up rather than on the first request.
    print("Performing a quick warm-up inference...")
    with torch.inference_mode():
        _loading_pipe(prompt="warmup", num_inference_steps=1, guidance_scale=1.0, output_type="pil").images[0]
    print("Warm-up successful.")
    pipe = _loading_pipe
except Exception as e:
    print(f"\n{'='*40}\nERROR loading Stable Diffusion model: {e}\n{'='*40}\n")
    # pipe stays None, so generation fails gracefully later.
finally:
    # Drop the temporary handle so a failed, partially loaded pipeline can be garbage-collected.
    del _loading_pipe
# --- Core Functions ---
def _clean_llm_prompt(raw_text):
    """Normalise raw LLM output into a bare prompt string ("" if nothing usable remains)."""
    text = (raw_text or "").replace("Here's a detailed prompt:", "").strip()
    # Chat models often wrap the whole answer in quotes. Drop one matching pair,
    # but only when it encloses the entire text (no inner quote of the same kind).
    if len(text) >= 2 and text[0] == text[-1] and text[0] in "\"'" and text[0] not in text[1:-1]:
        text = text[1:-1].strip()
    return text


def enhance_prompt_openai(short_prompt, add_style_keywords):
    """Expand a short idea into a detailed Stable Diffusion prompt with the OpenAI chat API.

    Args:
        short_prompt: The user's brief description (typed text or an audio transcript).
        add_style_keywords: If truthy, ask the LLM to add extra artistic/quality keywords.

    Returns:
        The enhanced prompt. On failure, a string starting with "Error:" is returned
        instead. Callers check that prefix; this function does not raise.
    """
    if not openai.api_key:
        # Normally reported at start-up, but guard anyway so the API is never called without a key.
        return "Error: OpenAI API Key not configured."

    system_message = """You are an expert prompt engineer for text-to-image models like Stable Diffusion.
Expand the user's short idea into a detailed, vivid, and structured prompt optimized for Stable Diffusion v1.5.
Include details about the subject, scene, style (e.g., photorealistic, oil painting, cinematic),
lighting (e.g., soft light, dramatic lighting), composition (e.g., wide shot, close-up),
and mood. Add high-quality keywords like 'highly detailed', 'sharp focus', 'masterpiece'.
Keep the prompt concise and effective, ideally under 60 words, because Stable Diffusion v1.5
truncates prompts beyond its 77-token text-encoder limit. Reply with the prompt only, without preamble or quotes."""
    user_message = f"Short idea: \"{short_prompt}\""
    if add_style_keywords:
        user_message += "\nPlease specifically add artistic and quality keywords like 'cinematic lighting', 'photorealistic', '8k', 'masterpiece', 'professional photography'."

    try:
        response = openai.chat.completions.create(
            model=llm_model,
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message},
            ],
            temperature=0.7,
            max_tokens=150,  # Hard cap on reply length; the word limit above keeps replies inside it
        )
        raw_text = response.choices[0].message.content
    except Exception as e:
        print(f"Error calling OpenAI API for prompt enhancement: {e}")
        # User-facing message; the "Error:" prefix is what callers test for.
        return f"Error: Could not enhance prompt using OpenAI. ({e})"

    enhanced_prompt = _clean_llm_prompt(raw_text)
    if not enhanced_prompt:
        # Otherwise an empty string would be sent straight to the image model as the prompt.
        return "Error: OpenAI returned an empty prompt."
    return enhanced_prompt
def transcribe_audio_openai(audio_path):
    """Turn a recorded audio file into text with OpenAI's hosted Whisper model ("whisper-1").

    Args:
        audio_path: Path of the recording on disk, or a falsy value when nothing was recorded.

    Returns:
        None when no audio was supplied, or the transcript text on success.
        Otherwise, a string starting with "Error:" that describes the failure.
    """
    # Nothing recorded, so there is nothing to transcribe.
    if not audio_path:
        return None

    if openai.api_key:
        try:
            with open(audio_path, "rb") as recording:
                result = openai.audio.transcriptions.create(
                    model="whisper-1",
                    file=recording,
                )
                return result.text
        except Exception as err:
            print(f"Error calling OpenAI Whisper API: {err}")
            return f"Error: Could not transcribe audio using OpenAI. ({err})"

    print("Warning: OpenAI API Key not configured. Cannot transcribe audio.")
    return "Error: OpenAI API Key needed for transcription."
def generate_image_lcm(prompt, guidance_scale, num_inference_steps=8):  # 8 steps: slightly above LCM's minimum, for quality
    """Render one image from *prompt* with the module-level SD + LCM-LoRA pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        guidance_scale: Requested guidance scale. It is clamped to [1.0, 3.0]
            because LCM performs best with low guidance.
        num_inference_steps: Number of denoising steps (LCM needs only a few).

    Returns:
        (image, status): a PIL image and a human-readable status string. On failure
        the image is a solid 512x512 placeholder and the status begins with "Error".
        The placeholder is grey if the pipeline never loaded and red-ish if
        generation raised.
    """
    if pipe is None:
        print("Error: Stable Diffusion pipeline is not available.")
        img = Image.new('RGB', (512, 512), color=(128, 128, 128))  # Grey placeholder
        return img, "Error: Image generation model failed to load."

    hardware = device.upper()  # Used in log/status text so it matches where the pipeline actually runs
    print(f"Starting image generation on {hardware} with prompt: '{prompt}'")
    print(f"Guidance Scale: {guidance_scale}, Steps: {num_inference_steps}. BE PATIENT, THIS WILL BE SLOW.")

    # LCM performs best with low guidance; keep it within [1.0, 3.0].
    effective_guidance = max(1.0, min(guidance_scale, 3.0))
    if effective_guidance != guidance_scale:
        print(f"Adjusted guidance scale to {effective_guidance} (optimal range for LCM).")

    # NOTE: diffusers only applies the negative prompt when classifier-free
    # guidance is active (guidance_scale > 1); at exactly 1.0 it has no effect.
    negative_prompt = "blurry, low quality, deformed, ugly, text, words, writing, signature, watermark"

    start_time = time.perf_counter()  # Monotonic clock: durations are unaffected by wall-clock changes
    try:
        with torch.inference_mode():
            image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=effective_guidance,
                num_inference_steps=num_inference_steps,  # LCM needs few steps
            ).images[0]
    except Exception as e:
        duration = time.perf_counter() - start_time
        print(f"Error during image generation after {duration:.2f} seconds: {e}")
        img = Image.new('RGB', (512, 512), color=(255, 100, 100))  # Red-ish placeholder
        return img, f"Error generating image: {e}"

    duration = time.perf_counter() - start_time
    print(f"Image generation successful on {hardware} in {duration:.2f} seconds.")
    return image, f"Image generated in {duration:.2f}s ({hardware})."
# --- Main Processing Function ---
def process_input(text_input, audio_input, add_style_keywords, guidance_scale):
    """Gradio click handler: choose the input, enhance it into a prompt, and render an image.

    Non-blank typed text takes precedence over recorded audio; otherwise the audio is
    transcribed with Whisper. The resulting text is expanded by the LLM, or used as-is
    with a warning when no OpenAI key is configured, and is then passed to the image
    pipeline.

    Args:
        text_input: Contents of the "Short Description" textbox (may be empty or None).
        audio_input: File path of the recorded audio, or None.
        add_style_keywords: Whether the LLM should add extra style/quality keywords.
        guidance_scale: Guidance scale from the UI slider.

    Returns:
        (status_log, prompt_used, image). When processing stops early, prompt_used is ""
        and image is None, and status_log explains why.
    """
    status_updates = []

    # 1. Determine input source
    if text_input and text_input.strip():
        final_text_input = text_input.strip()
        status_updates.append("Using provided text input.")
    elif audio_input:
        status_updates.append("Processing audio input...")
        transcribed_text = (transcribe_audio_openai(audio_input) or "").strip()
        if not transcribed_text or transcribed_text.startswith("Error:"):
            # Show the transcription error (or a generic one for an empty/blank result) and stop.
            status_updates.append(transcribed_text or "Error: Transcription failed.")
            return "\n".join(status_updates), "", None
        final_text_input = transcribed_text
        preview = final_text_input if len(final_text_input) <= 100 else f"{final_text_input[:100]}..."
        status_updates.append(f"Transcribed Audio: \"{preview}\"")
    else:
        status_updates.append("Error: Please provide a text description or record audio.")
        return "\n".join(status_updates), "", None

    # 2. Enhance prompt (falls back to the original text when no OpenAI key is configured)
    if openai.api_key:
        status_updates.append("Enhancing prompt using OpenAI...")
        enhanced_prompt = enhance_prompt_openai(final_text_input, add_style_keywords)
        if enhanced_prompt.startswith("Error:"):
            # Stop instead of silently generating from the unenhanced text.
            status_updates.append(enhanced_prompt)
            return "\n".join(status_updates), "", None
        status_updates.append("Prompt enhanced successfully.")
    else:
        status_updates.append("Warning: OpenAI API Key missing. Using original text as prompt.")
        enhanced_prompt = final_text_input

    # 3. Generate image.
    # NOTE: this handler returns once, at the end. The UI therefore shows the whole log
    # only after generation finishes; streaming progress would require a generator.
    if device == "cpu":
        status_updates.append(f"Generating image on CPU ({device})... **THIS WILL BE SLOW - PLEASE WAIT**")
    else:
        status_updates.append(f"Generating image on {device.upper()}...")
    generated_image, img_status_msg = generate_image_lcm(enhanced_prompt, guidance_scale)
    if img_status_msg:
        status_updates.append(img_status_msg)

    # 4. Return results
    return "\n".join(status_updates), enhanced_prompt, generated_image
# --- Gradio UI ---
# Hardware-dependent wording: the slow-generation warnings only apply on CPU.
_on_cpu = device == "cpu"
if _on_cpu:
    _hardware_notice = (
        f"**WARNING:** Running on **CPU ({device.upper()})**. Image generation will be **VERY SLOW** (potentially several minutes). Please be patient after clicking Generate."
    )
else:
    _hardware_notice = f"Running on **{device.upper()}**."

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# Prompt Enhancer & Image Generator 🪄🖼️ ({device.upper()} Version)")
    # "\n\n" (blank line) so Markdown renders the notice and the instructions as separate
    # paragraphs; a single "\n" only creates a soft line break.
    gr.Markdown(
        f"{_hardware_notice}\n\n"
        f"Enter a short description or record audio. It will be enhanced by `{llm_model}` and an image generated using `{sd_model_id}` + LCM acceleration."
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Input Controls
            text_input = gr.Textbox(
                label="Short Description",
                placeholder="e.g., 'magical treehouse in the sky'",
                lines=2,
            )
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",  # File path, which transcribe_audio_openai opens and uploads
                label="Or Record Audio Input",
            )
            gr.Markdown("---")
            gr.Markdown("**Generation Options**")
            add_style_keywords = gr.Checkbox(
                label="Add Extra Style Keywords (via LLM)?",
                value=True,
                info="Asks the LLM to add 'photorealistic', '8k', 'cinematic' etc.",
            )
            guidance_scale = gr.Slider(
                minimum=1.0,
                maximum=3.0,  # Keep low for LCM (same range generate_image_lcm clamps to)
                step=0.1,
                value=1.5,  # Good default for LCM
                label="Guidance Scale",
                info="How closely the image follows the prompt (1-2 recommended for LCM).",
            )
            submit_button = gr.Button(
                "Generate ✨ (Will be slow!)" if _on_cpu else "Generate ✨",
                variant="primary",
            )
        with gr.Column(scale=2):
            # Output Area
            status_output = gr.Textbox(
                label="Status Log",
                interactive=False,
                lines=4,  # Room for the multi-line status log
            )
            enhanced_prompt_output = gr.Textbox(
                label="✨ Enhanced Prompt (from LLM)",
                interactive=False,
                lines=4,
            )
            image_output = gr.Image(
                label=f"🖼️ Generated Image ({device.upper()})",
                type="pil",
                interactive=False,
                height=512,  # Fixed display height
            )

    # Connect UI elements
    submit_button.click(
        fn=process_input,
        inputs=[
            text_input,
            audio_input,
            add_style_keywords,
            guidance_scale,
        ],
        outputs=[
            status_output,
            enhanced_prompt_output,
            image_output,
        ],
    )
    # Clear inputs upon submission for better UX.
    # NOTE(review): this is a second, independent click handler. It relies on Gradio
    # capturing process_input's input values when the click fires; confirm there is no race.
    submit_button.click(lambda: ("", None), inputs=[], outputs=[text_input, audio_input])
# --- Launch the App ---
if __name__ == "__main__":
    print("\nLaunching Gradio App...")
    # The request queue handles slow generations better. share stays off because
    # share=True would create a public link (use with caution when running locally).
    queued_demo = demo.queue()
    queued_demo.launch(share=False, debug=False)