JQ66 committed on
Commit
fe5b715
·
verified ·
1 Parent(s): 4a2e843

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +313 -0
app.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --- Filename: app.py ---
2
+
3
+ import gradio as gr
4
+ import openai
5
+ import torch
6
+ from diffusers import StableDiffusionPipeline, LCMScheduler
7
+ import os
8
+ from PIL import Image
9
+ import io # Required for handling audio file object for OpenAI API
10
+ import time # To estimate generation time
11
+
12
# --- Configuration ---
# Credentials are read from Hugging Face Secrets / environment variables.
# IMPORTANT: the secret/variable named OPENAI_API_KEY must be set, otherwise
# the OpenAI-backed features (prompt enhancement, voice input) cannot work.
openai.api_key = os.environ.get("OPENAI_API_KEY")
hf_token = os.environ.get("HF_TOKEN")  # May be needed for model download

if not openai.api_key:
    # Warn loudly but keep going: the app checks the key again before each
    # OpenAI call, so it can still start without it.
    print("\n" + "="*40)
    print("ERROR: OPENAI_API_KEY environment variable not found.")
    print("Please set the OPENAI_API_KEY secret/variable.")
    print("OpenAI features (prompt enhancement, voice input) WILL FAIL.")
    print("="*40 + "\n")
    # To make the key mandatory, raise here instead, e.g.
    # ValueError("OpenAI API Key not found!").
else:
    print("OpenAI API Key found.")


# Model IDs
llm_model = "gpt-3.5-turbo"
sd_model_id = "runwayml/stable-diffusion-v1-5"
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"  # LCM LoRA for faster inference

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 on GPU; float32 on CPU for stability/compatibility.
# Derived from `device` so the two settings can never disagree.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

print(f"Selected Device: {device.upper()}")
print(f"Selected PyTorch Dtype: {torch_dtype}")
42
+
43
# --- Model Loading ---
# `pipe` is the module-level pipeline used by generate_image_lcm. It is None
# whenever loading did not complete, which that function checks for.
pipe = None
try:
    print("Loading Stable Diffusion model... (This might take a while on CPU)")
    pipe = StableDiffusionPipeline.from_pretrained(
        sd_model_id,
        torch_dtype=torch_dtype,
        # use_auth_token=hf_token  # Uncomment if you face download issues
    )
    print("Base model loaded. Loading LCM Scheduler and LoRA...")
    # Swap in the LCM scheduler and the LCM LoRA for few-step generation.
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights(lcm_lora_id)
    pipe.fuse_lora()  # Fuse LoRA for slightly faster inference after loading
    pipe.to(device)  # Move the pipeline to the selected device
    print(f"Stable Diffusion model loaded successfully with LCM-LoRA on {device.upper()}.")
    # One-step dummy run to warm up and surface configuration errors early.
    print("Performing a quick warm-up inference...")
    _ = pipe(prompt="warmup", num_inference_steps=1, guidance_scale=1.0, output_type="pil").images[0]
    print("Warm-up successful.")

except Exception as e:
    print(f"\n{'='*40}\nERROR loading Stable Diffusion model: {e}\n{'='*40}\n")
    # A failure after from_pretrained() (LoRA load, fuse, device move or
    # warm-up) would otherwise leave a half-configured pipeline bound to
    # `pipe`. Reset it so generation fails gracefully later.
    pipe = None
67
+
68
+
69
+ # --- Core Functions ---
70
+
71
def enhance_prompt_openai(short_prompt, add_style_keywords):
    """Expand a short idea into a detailed image prompt using the OpenAI LLM.

    Args:
        short_prompt: The user's brief description of the desired image.
        add_style_keywords: When truthy, the LLM is additionally asked to add
            artistic and quality keywords.

    Returns:
        The enhanced prompt text. On any failure (missing API key, API error,
        empty completion) a string starting with "Error:" is returned instead
        of raising; callers detect failure through that prefix.
    """
    if not openai.api_key:
        # Should not happen if checked at start, but good practice.
        return "Error: OpenAI API Key not configured."

    # Built line by line so the text sent to the model does not depend on how
    # this source file happens to be indented.
    system_message = "\n".join([
        "You are an expert prompt engineer for text-to-image models like Stable Diffusion.",
        "Expand the user's short idea into a detailed, vivid, and structured prompt optimized for Stable Diffusion v1.5.",
        "Include details about the subject, scene, style (e.g., photorealistic, oil painting, cinematic),",
        "lighting (e.g., soft light, dramatic lighting), composition (e.g., wide shot, close-up),",
        "and mood. Add high-quality keywords like 'highly detailed', 'sharp focus', 'masterpiece'.",
        "Keep the prompt concise and effective, ideally under 100 words.",
    ])

    user_message = f"Short idea: \"{short_prompt}\""
    if add_style_keywords:
        user_message += "\nPlease specifically add artistic and quality keywords like 'cinematic lighting', 'photorealistic', '8k', 'masterpiece', 'professional photography'."

    try:
        response = openai.chat.completions.create(
            model=llm_model,
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message},
            ],
            temperature=0.7,
            max_tokens=150,
        )
        # The message content can be None; treat that like an empty reply.
        raw_text = response.choices[0].message.content or ""
        # Basic cleanup of a boilerplate lead-in the model sometimes adds.
        enhanced_prompt = raw_text.strip().replace("Here's a detailed prompt:", "").strip()
    except Exception as e:
        # Boundary with an external service: report the failure, don't crash the UI.
        print(f"Error calling OpenAI API for prompt enhancement: {e}")
        return f"Error: Could not enhance prompt using OpenAI. ({e})"

    if not enhanced_prompt:
        # Without this guard an empty string would reach the image model as a
        # blank prompt, because it does not start with "Error:".
        print("Error: OpenAI returned an empty prompt.")
        return "Error: OpenAI returned an empty prompt."
    return enhanced_prompt
106
+
107
def transcribe_audio_openai(audio_path):
    """Transcribe a recorded audio file with the OpenAI Whisper API.

    Args:
        audio_path: Filesystem path of the recording, or a falsy value when
            nothing was recorded.

    Returns:
        None when no path was given; otherwise the transcript with surrounding
        whitespace removed (possibly an empty string), or a string starting
        with "Error:" when the API key is missing or the request fails.
    """
    if not audio_path:
        return None
    if not openai.api_key:
        print("Warning: OpenAI API Key not configured. Cannot transcribe audio.")
        return "Error: OpenAI API Key needed for transcription."

    try:
        with open(audio_path, "rb") as audio_file:
            transcript = openai.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
            )
        # Strip so a whitespace-only transcript becomes "" and is treated as
        # "nothing recognised" by the caller rather than as usable text.
        return (transcript.text or "").strip()
    except Exception as e:
        # Covers unreadable files as well as API failures.
        print(f"Error calling OpenAI Whisper API: {e}")
        return f"Error: Could not transcribe audio using OpenAI. ({e})"
125
+
126
def generate_image_lcm(prompt, guidance_scale, num_inference_steps=8):
    """Generate one image with the module-level SD + LCM-LoRA pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        guidance_scale: Requested guidance scale; clamped to [1.0, 3.0]
            before use.
        num_inference_steps: Number of denoising steps (default 8); coerced
            to an integer of at least 1.

    Returns:
        A ``(image, status_message)`` tuple. When the pipeline is unavailable
        or generation raises, ``image`` is a solid-colour 512x512 placeholder
        and ``status_message`` starts with "Error".
    """
    if pipe is None:
        print("Error: Stable Diffusion pipeline is not available.")
        placeholder = Image.new('RGB', (512, 512), color=(128, 128, 128))  # Grey placeholder
        return placeholder, "Error: Image generation model failed to load."

    # Report the device actually selected at start-up instead of assuming CPU.
    device_label = device.upper()
    steps = max(1, int(num_inference_steps))

    print(f"Starting image generation on {device_label} with prompt: '{prompt}'")
    print(f"Guidance Scale: {guidance_scale}, Steps: {steps}. BE PATIENT, THIS WILL BE SLOW.")

    # LCM performs best with low guidance scale.
    effective_guidance = max(1.0, min(guidance_scale, 3.0))
    if effective_guidance != guidance_scale:
        print(f"Adjusted guidance scale to {effective_guidance} (optimal range for LCM).")

    negative_prompt = "blurry, low quality, deformed, ugly, text, words, writing, signature, watermark"

    start_time = time.time()
    try:
        # inference_mode disables autograd bookkeeping during generation.
        with torch.inference_mode():
            image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=effective_guidance,
                num_inference_steps=steps,  # LCM needs few steps
            ).images[0]
    except Exception as e:
        duration = time.time() - start_time
        print(f"Error during image generation after {duration:.2f} seconds: {e}")
        failure_image = Image.new('RGB', (512, 512), color=(255, 100, 100))  # Red-ish placeholder
        return failure_image, f"Error generating image: {e}"

    duration = time.time() - start_time
    print(f"Image generation successful on {device_label} in {duration:.2f} seconds.")
    return image, f"Image generated in {duration:.2f}s ({device_label})."
165
+
166
+ # --- Main Processing Function ---
167
+
168
def process_input(text_input, audio_input, add_style_keywords, guidance_scale):
    """Run the full pipeline for one click of the Generate button.

    Picks the input source (typed text takes precedence over recorded audio),
    enhances the prompt through OpenAI when a key is configured, then
    generates the image.

    Args:
        text_input: Text typed by the user; may be empty or None.
        audio_input: Path of the recorded audio file, or a falsy value.
        add_style_keywords: Forwarded to the prompt enhancer.
        guidance_scale: Forwarded to the image generator.

    Returns:
        A ``(status_log, enhanced_prompt, image)`` tuple. When processing
        stops early, ``enhanced_prompt`` is "" and ``image`` is None.
    """
    status_updates = []

    # 1. Determine input source.
    if text_input and text_input.strip():
        final_text_input = text_input.strip()
        status_updates.append("Using provided text input.")
    elif audio_input:
        status_updates.append("Processing audio input...")
        transcribed_text = transcribe_audio_openai(audio_input)
        # Strip first: a whitespace-only transcript must count as a failure,
        # otherwise a blank prompt would be sent on to the later stages.
        cleaned_text = transcribed_text.strip() if transcribed_text else ""
        if cleaned_text and not cleaned_text.startswith("Error:"):
            final_text_input = cleaned_text
            if len(final_text_input) > 100:
                preview = final_text_input[:100] + "..."
            else:
                preview = final_text_input
            status_updates.append(f"Transcribed Audio: \"{preview}\"")
        else:
            # Show the transcriber's own error message when there is one.
            status_updates.append(cleaned_text or "Error: Transcription failed.")
            final_text_input = ""
    else:
        status_updates.append("Error: Please provide a text description or record audio.")
        return "\n".join(status_updates), "", None

    # No usable text after checking both sources: stop here.
    if not final_text_input:
        return "\n".join(status_updates), "", None

    # 2. Enhance prompt (only announced when it can actually be attempted).
    if openai.api_key:
        status_updates.append("Enhancing prompt using OpenAI...")
        enhanced_prompt = enhance_prompt_openai(final_text_input, add_style_keywords)
        if enhanced_prompt.startswith("Error:"):
            # Deliberate choice: stop rather than fall back to the raw text.
            status_updates.append(enhanced_prompt)
            return "\n".join(status_updates), "", None
        status_updates.append("Prompt enhanced successfully.")
    else:
        status_updates.append("Warning: OpenAI API Key missing. Using original text as prompt.")
        enhanced_prompt = final_text_input

    # 3. Generate image. This call blocks until generation finishes, so the
    # status line below only becomes visible together with the result.
    status_updates.append(f"Generating image on CPU ({device})... **THIS WILL BE SLOW - PLEASE WAIT**")
    generated_image, img_status_msg = generate_image_lcm(enhanced_prompt, guidance_scale)
    if img_status_msg:
        status_updates.append(img_status_msg)

    # 4. Return results.
    return "\n".join(status_updates), enhanced_prompt, generated_image
227
+
228
+
229
+ # --- Gradio UI ---
230
+
231
# Layout: a narrow input column (text, microphone, options, button) beside a
# wider output column (status log, enhanced prompt, generated image).
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Prompt Enhancer & Image Generator 🪄🖼️ (CPU Version)")
    # A blank line separates the warning from the instructions: a single
    # newline is not a paragraph break in Markdown.
    gr.Markdown(
        f"**WARNING:** Running on **CPU ({device.upper()})**. Image generation will be **VERY SLOW** (potentially several minutes). Please be patient after clicking Generate."
        f"\n\nEnter a short description or record audio. It will be enhanced by `{llm_model}` and an image generated using `{sd_model_id}` + LCM acceleration."
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Input controls
            text_input = gr.Textbox(
                label="Short Description",
                placeholder="e.g., 'magical treehouse in the sky'",
                lines=2,
            )
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",  # Get file path for OpenAI API
                label="Or Record Audio Input",
            )
            gr.Markdown("---")
            gr.Markdown("**Generation Options**")
            add_style_keywords = gr.Checkbox(
                label="Add Extra Style Keywords (via LLM)?",
                value=True,
                info="Asks the LLM to add 'photorealistic', '8k', 'cinematic' etc.",
            )
            guidance_scale = gr.Slider(
                minimum=1.0,
                maximum=3.0,  # Keep low for LCM
                step=0.1,
                value=1.5,  # Good default for LCM
                label="Guidance Scale",
                info="How closely the image follows the prompt (1-2 recommended for LCM).",
            )
            submit_button = gr.Button("Generate ✨ (Will be slow!)", variant="primary")

        with gr.Column(scale=2):
            # Output area
            status_output = gr.Textbox(
                label="Status Log",
                interactive=False,
                lines=4,  # More lines for verbose status
            )
            enhanced_prompt_output = gr.Textbox(
                label="✨ Enhanced Prompt (from LLM)",
                interactive=False,
                lines=4,
            )
            image_output = gr.Image(
                label="🖼️ Generated Image (CPU)",
                type="pil",
                interactive=False,
                height=512,  # Fixed display height
            )

    # Connect UI elements: one click runs the whole pipeline.
    submit_button.click(
        fn=process_input,
        inputs=[
            text_input,
            audio_input,
            add_style_keywords,
            guidance_scale,
        ],
        outputs=[
            status_output,
            enhanced_prompt_output,
            image_output,
        ],
    )

    # Second listener on the same click: clear the inputs for better UX.
    # NOTE(review): assumes Gradio captures the input values for the listener
    # above before this one resets them - confirm.
    submit_button.click(lambda: ("", None), inputs=[], outputs=[text_input, audio_input])
306
+
307
+
308
+ # --- Launch the App ---
309
if __name__ == "__main__":
    print("\nLaunching Gradio App...")
    # The queue is enabled for better handling of slow generation requests.
    # share=True would create a public link when running locally (use with
    # caution), so it stays off.
    demo.queue().launch(debug=False, share=False)