File size: 13,497 Bytes
fe5b715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
# --- Filename: app.py ---

import gradio as gr
import openai
import torch
from diffusers import StableDiffusionPipeline, LCMScheduler
import os
from PIL import Image
import io # Required for handling audio file object for OpenAI API
import time # To estimate generation time

# --- Configuration ---
# Credentials are read from the process environment (Hugging Face Secrets are
# exposed the same way). OPENAI_API_KEY powers both prompt enhancement and
# voice transcription; HF_TOKEN is optional and may be needed for model download.
openai.api_key = os.environ.get("OPENAI_API_KEY")
hf_token = os.environ.get("HF_TOKEN")

if openai.api_key:
    print("OpenAI API Key found.")
else:
    # Warn loudly but keep starting up: the OpenAI-backed helpers re-check the
    # key themselves and return "Error: ..." strings when it is missing.
    # (To make the key mandatory, raise ValueError("OpenAI API Key not found!") here.)
    print("\n".join([
        "\n" + "=" * 40,
        "ERROR: OPENAI_API_KEY environment variable not found.",
        "Please set the OPENAI_API_KEY secret/variable.",
        "OpenAI features (prompt enhancement, voice input) WILL FAIL.",
        "=" * 40 + "\n",
    ]))

# Model IDs
llm_model = "gpt-3.5-turbo"  # chat model used for prompt enhancement
sd_model_id = "runwayml/stable-diffusion-v1-5"  # base text-to-image checkpoint
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"  # LCM LoRA for faster (few-step) inference

# Prefer the GPU when one is present. Half precision is only used on CUDA;
# on CPU stay in float32 for stability/compatibility.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

print(f"Selected Device: {device.upper()}")
print(f"Selected PyTorch Dtype: {torch_dtype}")

# --- Model Loading ---
def _load_sd_pipeline():
    """Build the SD v1.5 pipeline with the LCM scheduler and fused LCM-LoRA.

    Runs a 1-step warm-up inference before returning so runtime problems show
    up at startup rather than on the first user request.

    Returns:
        The fully configured pipeline, already moved to ``device``.

    Raises:
        Whatever diffusers / torch raise while downloading, configuring or
        running the pipeline. Callers must treat any exception as "no model".
    """
    print(f"Loading Stable Diffusion model... (This might take a while on {device.upper()})")
    candidate = StableDiffusionPipeline.from_pretrained(
        sd_model_id,
        torch_dtype=torch_dtype,
        # token=hf_token,  # uncomment if the model download requires authentication
    )
    print("Base model loaded. Loading LCM Scheduler and LoRA...")
    # LCM scheduler + LCM-LoRA make good images possible in very few steps.
    candidate.scheduler = LCMScheduler.from_config(candidate.scheduler.config)
    candidate.load_lora_weights(lcm_lora_id)
    candidate.fuse_lora()  # bake the LoRA into the weights for slightly faster inference
    candidate.to(device)
    print(f"Stable Diffusion model loaded successfully with LCM-LoRA on {device.upper()}.")

    # A tiny warm-up run surfaces runtime errors before the app starts serving.
    print("Performing a quick warm-up inference...")
    with torch.inference_mode():
        candidate(prompt="warmup", num_inference_steps=1, guidance_scale=1.0, output_type="pil")
    print("Warm-up successful.")
    return candidate


# Publish the pipeline only once the *whole* load + warm-up succeeded. Assigning
# it earlier would leave a half-configured pipeline (e.g. LCM scheduler without
# the LCM-LoRA) in `pipe` after a failure. generate_image_lcm relies on
# `pipe is None` meaning "not available".
pipe = None
try:
    pipe = _load_sd_pipeline()
except Exception as e:  # top-level boundary: report and continue without a model
    print(f"\n{'='*40}\nERROR loading Stable Diffusion model: {e}\n{'='*40}\n")


# --- Core Functions ---

def enhance_prompt_openai(short_prompt, add_style_keywords):
    """Expand a short idea into a detailed Stable Diffusion prompt via OpenAI chat.

    Args:
        short_prompt: The user's short description of the desired image.
        add_style_keywords: If truthy, additionally ask the model to add
            artistic/quality keywords ('cinematic lighting', '8k', ...).

    Returns:
        The enhanced prompt on success. On any failure (missing API key, API
        error, empty model reply) a string starting with "Error:" is returned
        instead; callers detect failure through that prefix.
    """
    if not openai.api_key:
        # Should not happen if checked at start, but good practice
        return "Error: OpenAI API Key not configured."

    # Built from plain literals so the source file's indentation is not sent
    # to the model as part of the instructions.
    system_message = (
        "You are an expert prompt engineer for text-to-image models like Stable Diffusion.\n"
        "Expand the user's short idea into a detailed, vivid, and structured prompt "
        "optimized for Stable Diffusion v1.5.\n"
        "Include details about the subject, scene, style (e.g., photorealistic, oil painting, cinematic),\n"
        "lighting (e.g., soft light, dramatic lighting), composition (e.g., wide shot, close-up),\n"
        "and mood. Add high-quality keywords like 'highly detailed', 'sharp focus', 'masterpiece'.\n"
        "Keep the prompt concise and effective, ideally under 100 words.\n"
        "Respond with the prompt text only, without any preamble or surrounding quotes."
    )

    user_message = f"Short idea: \"{short_prompt}\""
    if add_style_keywords:
        user_message += "\nPlease specifically add artistic and quality keywords like 'cinematic lighting', 'photorealistic', '8k', 'masterpiece', 'professional photography'."

    try:
        response = openai.chat.completions.create(
            model=llm_model,
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message},
            ],
            temperature=0.7,
            max_tokens=150,
        )
        # message.content may be None (no text returned); treat that as empty
        # instead of failing with an opaque AttributeError on .strip().
        raw_prompt = response.choices[0].message.content or ""
    except Exception as e:  # API/network boundary: report to the UI instead of crashing
        print(f"Error calling OpenAI API for prompt enhancement: {e}")
        return f"Error: Could not enhance prompt using OpenAI. ({e})"

    # Basic cleanup of a common conversational preamble.
    enhanced_prompt = raw_prompt.strip().replace("Here's a detailed prompt:", "").strip()
    # Models often wrap the whole answer in quotes; those would leak into the SD prompt.
    if len(enhanced_prompt) >= 2 and enhanced_prompt[0] == enhanced_prompt[-1] == '"':
        enhanced_prompt = enhanced_prompt[1:-1].strip()

    # An empty prompt must not be reported as success: the caller would go on
    # to generate an image from nothing.
    if not enhanced_prompt:
        return "Error: OpenAI returned an empty prompt."
    return enhanced_prompt

def transcribe_audio_openai(audio_path):
    """Turn a recorded audio file into text using OpenAI's Whisper API.

    Args:
        audio_path: Filesystem path of the recording (Gradio ``type="filepath"``).

    Returns:
        None when no path is given; the transcript text on success; otherwise
        a string beginning with "Error:" (missing key or failed API call).
    """
    if not audio_path:
        return None

    if openai.api_key:
        try:
            with open(audio_path, "rb") as recording:
                result = openai.audio.transcriptions.create(model="whisper-1", file=recording)
            return result.text
        except Exception as err:  # API/IO boundary: surface as an "Error:" string
            print(f"Error calling OpenAI Whisper API: {err}")
            return f"Error: Could not transcribe audio using OpenAI. ({err})"

    print("Warning: OpenAI API Key not configured. Cannot transcribe audio.")
    return "Error: OpenAI API Key needed for transcription."

def generate_image_lcm(prompt, guidance_scale, num_inference_steps=8):
    """Generate one image from *prompt* with the module-level SD + LCM-LoRA pipeline.

    Args:
        prompt: Text prompt for the pipeline.
        guidance_scale: Requested guidance scale; clamped to [1.0, 3.0] because
            LCM performs best with low guidance.
        num_inference_steps: Denoising steps. LCM needs only a few; the default
            of 8 trades a little speed for quality. Values below 1 are raised to 1.

    Returns:
        ``(image, status_message)``. On failure a 512x512 solid-colour
        placeholder is returned (grey = model not loaded, red-ish = generation
        error) with an "Error..." message, so the UI always has an image.
    """
    if pipe is None:
        print("Error: Stable Diffusion pipeline is not available.")
        img = Image.new("RGB", (512, 512), color=(128, 128, 128))  # grey placeholder
        return img, "Error: Image generation model failed to load."

    # Report the device actually in use instead of assuming CPU.
    device_label = device.upper()
    steps = max(1, int(num_inference_steps))
    slow_note = " BE PATIENT, THIS WILL BE SLOW." if device == "cpu" else ""

    print(f"Starting image generation on {device_label} with prompt: '{prompt}'")
    print(f"Guidance Scale: {guidance_scale}, Steps: {steps}.{slow_note}")

    # LCM performs best with low guidance scale
    effective_guidance = max(1.0, min(float(guidance_scale), 3.0))
    if effective_guidance != guidance_scale:
        print(f"Adjusted guidance scale to {effective_guidance} (optimal range for LCM).")

    negative_prompt = "blurry, low quality, deformed, ugly, text, words, writing, signature, watermark"

    # perf_counter is monotonic, unlike time.time(), so durations can't go negative.
    start = time.perf_counter()
    try:
        with torch.inference_mode():  # no autograd bookkeeping needed for inference
            image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=effective_guidance,
                num_inference_steps=steps,
            ).images[0]
    except Exception as e:  # keep the UI alive: return a placeholder plus the error
        duration = time.perf_counter() - start
        print(f"Error during image generation after {duration:.2f} seconds: {e}")
        img = Image.new("RGB", (512, 512), color=(255, 100, 100))  # red-ish placeholder
        return img, f"Error generating image: {e}"

    duration = time.perf_counter() - start
    print(f"Image generation successful on {device_label} in {duration:.2f} seconds.")
    return image, f"Image generated in {duration:.2f}s ({device_label})."

# --- Main Processing Function ---

def process_input(text_input, audio_input, add_style_keywords, guidance_scale):
    """Gradio click handler: pick the input, enhance it into a prompt, render an image.

    Text in the textbox takes precedence over recorded audio. Audio is only
    transcribed when the textbox is empty or whitespace. Prompt enhancement
    requires an OpenAI key; without one the raw text is used as the prompt.

    Returns:
        ``(status_log, enhanced_prompt, image)`` for the three output
        components. Whenever processing stops early, the prompt is "" and the
        image is None.
    """
    status_updates = []

    def _stop():
        # Early-exit result: show the log so far, but no prompt and no image.
        return "\n".join(status_updates), "", None

    # 1. Determine input source
    if text_input and text_input.strip():
        final_text_input = text_input.strip()
        status_updates.append("Using provided text input.")
    elif audio_input:
        status_updates.append("Processing audio input...")
        # Strip so a whitespace-only transcript counts as a failed transcription
        # rather than being enhanced/rendered as an empty idea.
        transcribed_text = (transcribe_audio_openai(audio_input) or "").strip()
        if not transcribed_text or transcribed_text.startswith("Error:"):
            status_updates.append(transcribed_text or "Error: Transcription failed.")
            return _stop()
        final_text_input = transcribed_text
        preview = (
            transcribed_text if len(transcribed_text) <= 100 else f"{transcribed_text[:100]}..."
        )
        status_updates.append(f"Transcribed Audio: \"{preview}\"")
    else:
        status_updates.append("Error: Please provide a text description or record audio.")
        return _stop()

    # 2. Enhance Prompt (only announce OpenAI when it will actually be called)
    if openai.api_key:
        status_updates.append("Enhancing prompt using OpenAI...")
        enhanced_prompt = enhance_prompt_openai(final_text_input, add_style_keywords)
        if enhanced_prompt.startswith("Error:"):
            # Stop rather than silently falling back to the un-enhanced text.
            status_updates.append(enhanced_prompt)
            return _stop()
        status_updates.append("Prompt enhanced successfully.")
    else:
        status_updates.append("Warning: OpenAI API Key missing. Using original text as prompt.")
        enhanced_prompt = final_text_input

    # 3. Generate Image (report the real device instead of assuming CPU)
    slow_note = " **THIS WILL BE SLOW - PLEASE WAIT**" if device == "cpu" else ""
    status_updates.append(f"Generating image on {device.upper()}...{slow_note}")
    # NOTE: this is a plain function, not a generator, so the UI only shows the
    # log once generation has finished; live updates would require `yield`.
    generated_image, img_status_msg = generate_image_lcm(enhanced_prompt, guidance_scale)
    if img_status_msg:
        status_updates.append(img_status_msg)

    # 4. Return results
    return "\n".join(status_updates), enhanced_prompt, generated_image


# --- Gradio UI ---

# Device-dependent UI wording, so a GPU run is not labelled "CPU".
_device_label = device.upper()
_on_cpu = device == "cpu"
if _on_cpu:
    _speed_note = (
        "**WARNING:** Running on **CPU**. Image generation will be **VERY SLOW** "
        "(potentially several minutes). Please be patient after clicking Generate."
    )
else:
    _speed_note = f"Running on **{_device_label}** with LCM acceleration."

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# Prompt Enhancer & Image Generator 🪄🖼️ ({_device_label} Version)")
    # With Gradio's default Markdown settings a single "\n" does not start a new
    # paragraph, so a blank line ("\n\n") separates the two sentences.
    gr.Markdown(
        _speed_note
        + "\n\n"
        + f"Enter a short description or record audio. It will be enhanced by `{llm_model}` "
        f"and an image generated using `{sd_model_id}` + LCM acceleration."
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Input Controls
            text_input = gr.Textbox(
                label="Short Description",
                placeholder="e.g., 'magical treehouse in the sky'",
                lines=2,
            )
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",  # transcribe_audio_openai opens the file by path
                label="Or Record Audio Input",
            )
            gr.Markdown("---")
            gr.Markdown("**Generation Options**")
            add_style_keywords = gr.Checkbox(
                label="Add Extra Style Keywords (via LLM)?",
                value=True,
                info="Asks the LLM to add 'photorealistic', '8k', 'cinematic' etc.",
            )
            guidance_scale = gr.Slider(
                minimum=1.0,
                maximum=3.0,  # LCM works best with low guidance
                step=0.1,
                value=1.5,
                label="Guidance Scale",
                info="How closely the image follows the prompt (1-2 recommended for LCM).",
            )
            submit_button = gr.Button(
                "Generate ✨ (Will be slow!)" if _on_cpu else "Generate ✨",
                variant="primary",
            )

        with gr.Column(scale=2):
            # Output Area
            status_output = gr.Textbox(
                label="Status Log",
                interactive=False,
                lines=4,
            )
            enhanced_prompt_output = gr.Textbox(
                label="✨ Enhanced Prompt (from LLM)",
                interactive=False,
                lines=4,
            )
            image_output = gr.Image(
                label=f"🖼️ Generated Image ({_device_label})",
                type="pil",
                interactive=False,
                height=512,
            )

    # Connect UI elements
    submit_button.click(
        fn=process_input,
        inputs=[text_input, audio_input, add_style_keywords, guidance_scale],
        outputs=[status_output, enhanced_prompt_output, image_output],
    )

    # Clear inputs upon submission for better UX. This is a separate listener on
    # the same click.
    submit_button.click(lambda: ("", None), inputs=[], outputs=[text_input, audio_input])


# --- Launch the App ---
if __name__ == "__main__":
    # Script entry point only (not on import). The request queue is enabled
    # because image generation can take a long time. share=False keeps the app
    # local; share=True would create a public link, so use it with caution.
    print("\nLaunching Gradio App...")
    queued_app = demo.queue()
    queued_app.launch(share=False, debug=False)