|
|
|
|
|
|
|
|
import gradio as gr |
|
|
import openai |
|
|
import torch |
|
|
from diffusers import StableDiffusionPipeline, LCMScheduler |
|
|
import os |
|
|
from PIL import Image |
|
|
import io |
|
|
import time |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- API credentials -------------------------------------------------------
# OpenAI key drives prompt enhancement and Whisper transcription; the HF
# token is read from the environment (only needed for gated model repos).
openai.api_key = os.environ.get("OPENAI_API_KEY")
hf_token = os.environ.get("HF_TOKEN")

if openai.api_key:
    print("OpenAI API Key found.")
else:
    # Warn loudly at startup instead of failing later mid-request.
    banner = "=" * 40
    print("\n" + banner)
    print("ERROR: OPENAI_API_KEY environment variable not found.")
    print("Please set the OPENAI_API_KEY secret/variable.")
    print("OpenAI features (prompt enhancement, voice input) WILL FAIL.")
    print(banner + "\n")
|
|
|
|
|
|
|
|
|
|
|
# --- Model & hardware configuration ---------------------------------------
llm_model = "gpt-3.5-turbo"                         # prompt-enhancement LLM
sd_model_id = "runwayml/stable-diffusion-v1-5"      # base text-to-image model
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"  # LCM LoRA for few-step sampling

# Half precision is only safe/useful on GPU; CPUs need float32.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

print(f"Selected Device: {device.upper()}")
print(f"Selected PyTorch Dtype: {torch_dtype}")
|
|
|
|
|
|
|
|
# --- Stable Diffusion pipeline --------------------------------------------
# Loaded once at startup. `pipe` stays None on failure so the UI can degrade
# gracefully (generate_image_lcm checks for None) instead of crashing.
pipe = None
try:
    # Fix: report the actual device instead of unconditionally saying CPU.
    print(f"Loading Stable Diffusion model... (This might take a while on {device.upper()})")
    pipe = StableDiffusionPipeline.from_pretrained(
        sd_model_id,
        torch_dtype=torch_dtype,
    )
    print("Base model loaded. Loading LCM Scheduler and LoRA...")

    # LCM requires its own scheduler; the LoRA weights are fused into the
    # base UNet so inference needs no separate adapter pass.
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights(lcm_lora_id)
    pipe.fuse_lora()
    pipe.to(device)
    print(f"Stable Diffusion model loaded successfully with LCM-LoRA on {device.upper()}.")

    # One-step warm-up pays first-call caching/compilation cost now rather
    # than on the user's first request.
    print("Performing a quick warm-up inference...")
    _ = pipe(prompt="warmup", num_inference_steps=1, guidance_scale=1.0, output_type="pil").images[0]
    print("Warm-up successful.")

except Exception as e:
    print(f"\n{'='*40}\nERROR loading Stable Diffusion model: {e}\n{'='*40}\n")
    # Fix: a failure after from_pretrained would otherwise leave a
    # half-configured pipeline bound to `pipe`; reset so downstream code
    # takes the "model unavailable" path.
    pipe = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def enhance_prompt_openai(short_prompt, add_style_keywords):
    """Uses OpenAI LLM to enhance the short prompt.

    Expands *short_prompt* into a detailed Stable Diffusion prompt via the
    chat completions API. Returns the enhanced prompt on success, or an
    "Error: ..." string on failure (missing key or API error).
    """
    if not openai.api_key:
        return "Error: OpenAI API Key not configured."

    system_message = """You are an expert prompt engineer for text-to-image models like Stable Diffusion.
Expand the user's short idea into a detailed, vivid, and structured prompt optimized for Stable Diffusion v1.5.
Include details about the subject, scene, style (e.g., photorealistic, oil painting, cinematic),
lighting (e.g., soft light, dramatic lighting), composition (e.g., wide shot, close-up),
and mood. Add high-quality keywords like 'highly detailed', 'sharp focus', 'masterpiece'.
Keep the prompt concise and effective, ideally under 100 words."""

    user_message = f"Short idea: \"{short_prompt}\""
    if add_style_keywords:
        user_message += "\nPlease specifically add artistic and quality keywords like 'cinematic lighting', 'photorealistic', '8k', 'masterpiece', 'professional photography'."

    chat_messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]

    try:
        response = openai.chat.completions.create(
            model=llm_model,
            messages=chat_messages,
            temperature=0.7,
            max_tokens=150,
        )
        # Trim whitespace and strip a boilerplate lead-in the model
        # sometimes prepends.
        text = response.choices[0].message.content.strip()
        text = text.replace("Here's a detailed prompt:", "").strip()
        return text
    except Exception as e:
        print(f"Error calling OpenAI API for prompt enhancement: {e}")
        return f"Error: Could not enhance prompt using OpenAI. ({e})"
|
|
|
|
|
def transcribe_audio_openai(audio_path):
    """Transcribes audio using OpenAI Whisper API.

    Returns None when no audio path is given, the transcript text on
    success, or an "Error: ..." string when the key is missing or the
    API call fails.
    """
    if not audio_path:
        return None
    if not openai.api_key:
        print("Warning: OpenAI API Key not configured. Cannot transcribe audio.")
        return "Error: OpenAI API Key needed for transcription."

    try:
        # Whisper needs the raw file handle, opened in binary mode.
        with open(audio_path, "rb") as audio_file:
            result = openai.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
            )
            return result.text
    except Exception as e:
        print(f"Error calling OpenAI Whisper API: {e}")
        return f"Error: Could not transcribe audio using OpenAI. ({e})"
|
|
|
|
|
def generate_image_lcm(prompt, guidance_scale, num_inference_steps=8):
    """Generates an image using the loaded SD + LCM-LoRA pipeline.

    Args:
        prompt: Text prompt (ideally LLM-enhanced) to render.
        guidance_scale: Requested CFG scale; clamped to [1.0, 3.0], the
            range where LCM sampling behaves well.
        num_inference_steps: Diffusion steps; LCM needs very few (4-8).

    Returns:
        (PIL.Image, status message). On failure a solid placeholder image
        is returned with an error message so the UI always has an image.
    """
    if pipe is None:
        print("Error: Stable Diffusion pipeline is not available.")
        img = Image.new('RGB', (512, 512), color=(128, 128, 128))  # gray placeholder
        return img, "Error: Image generation model failed to load."

    # Fix: report the actual device instead of unconditionally claiming CPU.
    dev_label = device.upper()
    print(f"Starting image generation on {dev_label} with prompt: '{prompt}'")
    print(f"Guidance Scale: {guidance_scale}, Steps: {num_inference_steps}.")
    if device == "cpu":
        print("BE PATIENT, THIS WILL BE SLOW on CPU.")

    # LCM schedulers degrade outside roughly 1-3 CFG; clamp and report.
    effective_guidance = max(1.0, min(guidance_scale, 3.0))
    if effective_guidance != guidance_scale:
        print(f"Adjusted guidance scale to {effective_guidance} (optimal range for LCM).")

    negative_prompt = "blurry, low quality, deformed, ugly, text, words, writing, signature, watermark"

    start_time = time.time()
    try:
        # inference_mode disables autograd bookkeeping for speed/memory.
        with torch.inference_mode():
            image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=effective_guidance,
                num_inference_steps=num_inference_steps
            ).images[0]
        duration = time.time() - start_time
        print(f"Image generation successful on {dev_label} in {duration:.2f} seconds.")
        return image, f"Image generated in {duration:.2f}s ({dev_label})."
    except Exception as e:
        duration = time.time() - start_time
        print(f"Error during image generation after {duration:.2f} seconds: {e}")
        img = Image.new('RGB', (512, 512), color=(255, 100, 100))  # reddish error placeholder
        return img, f"Error generating image: {e}"
|
|
|
|
|
|
|
|
|
|
|
def process_input(text_input, audio_input, add_style_keywords, guidance_scale):
    """
    Main function triggered by the Gradio interface.

    Resolves text/audio input (typed text wins over audio), enhances the
    prompt via OpenAI (falling back to the raw text without a key), then
    generates an image.

    Returns:
        (status log string, enhanced prompt, PIL image or None) matching
        the three output components wired up in the UI.
    """
    status_updates = []
    final_text_input = ""
    enhanced_prompt = ""
    generated_image = None

    # 1) Resolve the user's idea: typed text takes precedence over audio.
    if text_input and text_input.strip():
        final_text_input = text_input.strip()
        status_updates.append("Using provided text input.")
    elif audio_input:
        status_updates.append("Processing audio input...")
        transcribed_text = transcribe_audio_openai(audio_input)
        if transcribed_text and not transcribed_text.startswith("Error:"):
            final_text_input = transcribed_text
            # Truncate long transcripts in the status log for readability.
            shown = final_text_input[:100] + "..." if len(final_text_input) > 100 else final_text_input
            status_updates.append(f"Transcribed Audio: \"{shown}\"")
        else:
            status_updates.append(transcribed_text or "Error: Transcription failed.")
            final_text_input = ""
    else:
        status_updates.append("Error: Please provide a text description or record audio.")
        return "\n".join(status_updates), "", None

    # Transcription may have failed above; nothing left to do.
    if not final_text_input:
        return "\n".join(status_updates), "", None

    # 2) Enhance the prompt (best effort: fall back to raw text sans key).
    status_updates.append("Enhancing prompt using OpenAI...")
    if openai.api_key:
        enhanced_prompt = enhance_prompt_openai(final_text_input, add_style_keywords)
        if enhanced_prompt.startswith("Error:"):
            status_updates.append(enhanced_prompt)
            return "\n".join(status_updates), "", None
        status_updates.append("Prompt enhanced successfully.")
    else:
        status_updates.append("Warning: OpenAI API Key missing. Using original text as prompt.")
        enhanced_prompt = final_text_input

    # 3) Generate the image. Fix: the message previously said "on CPU"
    # while interpolating `device`, which was contradictory on CUDA.
    slow_note = " **THIS WILL BE SLOW - PLEASE WAIT**" if device == "cpu" else ""
    status_updates.append(f"Generating image on {device.upper()}...{slow_note}")

    generated_image, img_status_msg = generate_image_lcm(enhanced_prompt, guidance_scale)
    if img_status_msg:
        status_updates.append(img_status_msg)

    return "\n".join(status_updates), enhanced_prompt, generated_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
# Declarative layout: inputs (left column) feed process_input on click;
# outputs (right column) receive the status log, enhanced prompt, and image.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Prompt Enhancer & Image Generator 🪄🖼️ (CPU Version)")
    gr.Markdown(
        f"**WARNING:** Running on **CPU ({device.upper()})**. Image generation will be **VERY SLOW** (potentially several minutes). Please be patient after clicking Generate."
        f"\nEnter a short description or record audio. It will be enhanced by `{llm_model}` and an image generated using `{sd_model_id}` + LCM acceleration."
    )

    with gr.Row():
        # Left column: inputs and generation options.
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Short Description",
                placeholder="e.g., 'magical treehouse in the sky'",
                lines=2
            )
            # Alternative to typing: microphone recording saved to a temp
            # file path, later transcribed via Whisper.
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="Or Record Audio Input"
            )
            gr.Markdown("---")
            gr.Markdown("**Generation Options**")
            add_style_keywords = gr.Checkbox(
                label="Add Extra Style Keywords (via LLM)?",
                value=True,
                info="Asks the LLM to add 'photorealistic', '8k', 'cinematic' etc."
            )
            # Slider range mirrors the clamp applied in generate_image_lcm.
            guidance_scale = gr.Slider(
                minimum=1.0,
                maximum=3.0,
                step=0.1,
                value=1.5,
                label="Guidance Scale",
                info="How closely the image follows the prompt (1-2 recommended for LCM)."
            )
            submit_button = gr.Button("Generate ✨ (Will be slow!)", variant="primary")

        # Right column: read-only outputs.
        with gr.Column(scale=2):
            status_output = gr.Textbox(
                label="Status Log",
                interactive=False,
                lines=4
            )
            enhanced_prompt_output = gr.Textbox(
                label="✨ Enhanced Prompt (from LLM)",
                interactive=False,
                lines=4
            )
            image_output = gr.Image(
                label="🖼️ Generated Image (CPU)",
                type="pil",
                interactive=False,
                height=512,
            )

    # Main pipeline: inputs/outputs lists must stay aligned with
    # process_input's signature and 3-tuple return.
    submit_button.click(
        fn=process_input,
        inputs=[
            text_input,
            audio_input,
            add_style_keywords,
            guidance_scale
        ],
        outputs=[
            status_output,
            enhanced_prompt_output,
            image_output
        ]
    )

    # NOTE(review): this second handler clears both inputs on the same
    # click. Input values are captured when the click fires, so the main
    # handler still receives them — but the user's text disappears while
    # the (slow) generation is still running. Confirm this is intended UX.
    submit_button.click(lambda: ("", None), inputs=[], outputs=[text_input, audio_input])
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: queue() serializes requests so long-running generations
    # don't pile up concurrently on a single worker.
    print("\nLaunching Gradio App...")
    app = demo.queue()
    app.launch(debug=False, share=False)