Spaces:
Runtime error
Runtime error
| import spaces | |
| import os | |
| import gradio as gr | |
| import random | |
| import torch | |
| import logging | |
| import numpy as np | |
| from typing import Dict, Any, List | |
| from diffusers import DiffusionPipeline | |
| from api import PromptEnhancementSystem | |
# ---------------------------------------------------------------------------
# Runtime configuration
# ---------------------------------------------------------------------------
# Largest seed offered by the UI / drawn at random (max signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound, in pixels, for the width/height sliders.
MAX_IMAGE_SIZE = 2048
# Run on the GPU whenever torch reports one, otherwise fall back to the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
MODEL_ID = "black-forest-labs/FLUX.1-schnell"
# bfloat16 on GPU, float32 on CPU.
DTYPE = torch.float32 if not torch.cuda.is_available() else torch.bfloat16
print(f"Using device: {DEVICE}")
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Model initialisation
# ---------------------------------------------------------------------------
# The pipeline is created once, at import time, and shared by every request.
# A failure here is reported on stdout and in the log, then re-raised so the
# application does not start without a model.
try:
    print("Loading model...")
    pipe = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE).to(DEVICE)
    print("Model loaded successfully")
    logger.info("Model loaded successfully")
except Exception as e:
    print(f"Failed to load model: {str(e)}")
    logger.error(f"Failed to load model: {str(e)}")
    raise
def generate_multiple_images_batch(
    improvement_axes,
    current_gallery,
    seed=42,
    randomize_seed=False,
    width=512,
    height=512,
    num_inference_steps=4,
    current_prompt="",
    initial_prompt="",
    progress=gr.Progress(track_tqdm=True)
):
    """Generate one image per enhanced prompt and append them to the gallery.

    Prompts are taken from the ``"enhanced_prompt"`` entry of each item in
    ``improvement_axes``. When no axis provides one, ``current_prompt`` is used
    (or ``initial_prompt`` if ``current_prompt`` is blank). If there is still
    nothing to generate, the inputs are echoed back unchanged.

    Args:
        improvement_axes: List of dicts that may carry ``"enhanced_prompt"``.
            ``None``, non-list values and non-dict entries are ignored.
        current_gallery: Existing gallery entries as ``(image, caption)``
            pairs, or ``None``/empty.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` when true.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of denoising steps.
        current_prompt: Prompt currently shown in the UI; may be ``None``.
        initial_prompt: Fallback prompt when ``current_prompt`` is blank.
        progress: Gradio progress tracker; not used directly in the body.

    Returns:
        A list of six items: four image slots (padded with ``None``), the
        updated gallery list, and the seed that was actually used.

    Raises:
        Exception: Any error raised during generation is logged and re-raised.
    """
    try:
        # Text inputs can arrive as None; normalise before calling str methods.
        current_prompt = current_prompt or ""
        initial_prompt = initial_prompt or ""
        # Use current_prompt if not empty, otherwise fall back to initial_prompt
        input_prompt = current_prompt if current_prompt.strip() else initial_prompt

        # The axes value comes from a JSON component and may be None or a dict
        # (e.g. after an error), so only a real list of dicts is trusted.
        axes = improvement_axes if isinstance(improvement_axes, list) else []
        prompts = [
            axis["enhanced_prompt"]
            for axis in axes
            if isinstance(axis, dict) and axis.get("enhanced_prompt")
        ]
        if not prompts and input_prompt:
            prompts = [input_prompt]
        if not prompts:
            return [None] * 4 + [current_gallery] + [seed]

        # Slider values may be delivered as floats; the generator and the
        # pipeline both require integers.
        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

        print(f"Generating images with prompt: {input_prompt}")
        print(f"Using seed: {current_seed}")

        generator = torch.Generator().manual_seed(current_seed)
        result = pipe(
            prompt=prompts,
            width=int(width),
            height=int(height),
            num_inference_steps=int(num_inference_steps),
            generator=generator,
            max_sequence_length=256,  # Maximum allowed for schnell
            guidance_scale=0.0
        )
        # Copy so that padding below never mutates the pipeline's own output.
        images = list(result.images)

        # Pad with None if we have fewer than 4 images
        images.extend([None] * (4 - len(images)))

        # Update gallery with new images
        new_gallery = list(current_gallery or []) + [
            (img, f"Prompt: {prompt}")
            for img, prompt in zip(images, prompts)
            if img is not None
        ]
        print("All images generated successfully")
        return images[:4] + [new_gallery] + [current_seed]
    except Exception as e:
        print(f"Image generation error: {str(e)}")
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Image generation error: {str(e)}")
        raise
def handle_image_select(evt: gr.SelectData, improvement_axes_data):
    """Return the enhanced prompt of the improvement axis addressed by ``evt``.

    ``evt.index`` may be a plain integer or a two-element tuple/list, in which
    case its second element is used as the position.

    NOTE(review): for an image component the select index appears to be the
    clicked pixel's coordinates rather than an item position, so callers
    should pass an event whose index is the image slot number - confirm
    against the Gradio version in use.

    Args:
        evt: Selection event; only its ``index`` attribute is read.
        improvement_axes_data: List of dicts that may hold ``"enhanced_prompt"``.

    Returns:
        The selected enhanced prompt, or ``""`` when the data is missing, the
        index is not a valid non-negative position, or the entry has no prompt.
    """
    try:
        if not improvement_axes_data or not isinstance(improvement_axes_data, list):
            return ""

        selected_index = getattr(evt, "index", None)
        # Coordinates arrive as a tuple or, after JSON decoding, as a list.
        if isinstance(selected_index, (tuple, list)):
            if len(selected_index) < 2:
                return ""
            selected_index = selected_index[1]

        if not isinstance(selected_index, int):
            return ""
        # Reject negatives explicitly: Python would otherwise index from the end.
        if not 0 <= selected_index < len(improvement_axes_data):
            return ""

        axis = improvement_axes_data[selected_index]
        if not isinstance(axis, dict):
            return ""
        return axis.get("enhanced_prompt", "") or ""
    except Exception as e:
        # UI callback boundary: report and fall back to an empty prompt.
        print(f"Error in handle_image_select: {str(e)}")
        return ""
def handle_gallery_select(evt: gr.SelectData, gallery_data):
    """Return the prompt stored in the caption of the selected gallery entry.

    Gallery entries are expected to be ``(image, caption)`` pairs whose caption
    may start with the ``"Prompt: "`` label; only that leading label is
    stripped, so the same text occurring later in the prompt is preserved.

    Args:
        evt: Selection event; only its ``index`` attribute is read.
        gallery_data: Current gallery entries, or ``None``/empty.

    Returns:
        ``({"prompt": prompt}, prompt)`` for a valid selection, otherwise
        ``(None, "")``.
    """
    try:
        index = getattr(evt, "index", None)
        if not gallery_data or not isinstance(index, int):
            return None, ""
        # Negative positions would silently wrap around to the end of the list.
        if not 0 <= index < len(gallery_data):
            return None, ""

        entry = gallery_data[index]
        if not isinstance(entry, (tuple, list)) or len(entry) != 2:
            return None, ""

        caption = entry[1]
        prompt = caption if isinstance(caption, str) else ""
        # Remove "Prompt: " prefix if it exists
        label = "Prompt: "
        if prompt.startswith(label):
            prompt = prompt[len(label):]
        return {"prompt": prompt}, prompt
    except Exception as e:
        # UI callback boundary: report and fall back to "nothing selected".
        print(f"Error in handle_gallery_select: {str(e)}")
        return None, ""
def clear_gallery():
    """Produce the reset values for the gallery and the four preview images.

    Returns:
        A 5-tuple: an empty list (the cleared gallery) followed by one
        ``None`` for each of the four image slots.
    """
    cleared_gallery = []
    cleared_images = (None,) * 4
    return (cleared_gallery, *cleared_images)
def _load_gallery_image(img_data):
    """Return ``img_data`` as a PIL image, or ``None`` for unsupported types.

    Accepts a PIL image, a numpy array (converted via ``uint8``) or a
    filesystem path to an image file.
    """
    import os
    import numpy as np
    from PIL import Image

    if isinstance(img_data, Image.Image):
        return img_data
    if isinstance(img_data, np.ndarray):
        return Image.fromarray(np.uint8(img_data))
    if isinstance(img_data, (str, os.PathLike)):
        # Copy while the file is open so the handle is released immediately.
        with Image.open(img_data) as opened:
            return opened.copy()
    return None


def zip_gallery_images(gallery):
    """Pack every gallery image into an in-memory ZIP archive of PNG files.

    Each gallery entry is normally an ``(image, caption)`` pair; a bare entry
    is treated as an image without caption. Entries that cannot be processed
    are reported on stdout and skipped, so one bad entry never aborts the
    whole archive.

    Args:
        gallery: Iterable of gallery entries, or ``None``/empty.

    Returns:
        ``{"name": <zip file name>, "data": <zip bytes>}``, or ``None`` when
        the gallery is empty or the archive could not be created.
    """
    try:
        if not gallery:
            return None

        import io
        import zipfile
        from datetime import datetime

        # Create zip file in memory
        zip_buffer = io.BytesIO()
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"gallery_images_{timestamp}.zip"

        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
            for i, entry in enumerate(gallery):
                try:
                    # Unpack inside the per-image guard: a malformed entry
                    # must only skip itself.
                    if isinstance(entry, (tuple, list)) and len(entry) == 2:
                        img_data, prompt = entry
                    else:
                        img_data, prompt = entry, ""

                    if img_data is None:
                        continue

                    img = _load_gallery_image(img_data)
                    if img is None:
                        print(f"Skipping image {i}: invalid type {type(img_data)}")
                        continue

                    # Save image to bytes
                    img_buffer = io.BytesIO()
                    img.save(img_buffer, format='PNG')

                    # Create filename with prompt (caption may be None)
                    caption = str(prompt or "")
                    safe_prompt = "".join(
                        c for c in caption[:30] if c.isalnum() or c in (' ', '-', '_')
                    ).strip()
                    img_filename = f"image_{i+1}_{safe_prompt}.png"

                    # Add to zip
                    zip_file.writestr(img_filename, img_buffer.getvalue())
                except Exception as img_error:
                    print(f"Error processing image {i}: {str(img_error)}")
                    continue

        # Return the file data and name
        return {
            "name": filename,
            "data": zip_buffer.getvalue()
        }
    except Exception as e:
        print(f"Error creating zip: {str(e)}")
        return None
def create_interface():
    """Build and return the Gradio Blocks application.

    Reads ``GROQ_API_KEY`` and ``API_BASE_URL`` from the environment, creates
    the ``PromptEnhancementSystem`` used by the callbacks, lays out the tabs
    (generation, gallery, settings, analysis) and wires all event handlers.

    NOTE(review): this function was restored from a copy whose indentation had
    been lost; the nesting of rows, columns and tabs is a best-effort
    reconstruction - confirm against the intended layout.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).

    Raises:
        ValueError: If ``GROQ_API_KEY`` is not set.
    """
    # Function-scope import: used to hand a fixed slot index to
    # handle_image_select without touching the module's import block.
    from types import SimpleNamespace

    print("Creating interface...")
    api_key = os.getenv("GROQ_API_KEY")
    base_url = os.getenv("API_BASE_URL")
    if not api_key:
        print("GROQ_API_KEY not found in environment variables")
        raise ValueError("GROQ_API_KEY not found in environment variables")
    system = PromptEnhancementSystem(api_key, base_url)
    print("PromptEnhancementSystem initialized")

    # Keys of the initial analysis, in the order of the analysis JSON panels.
    analysis_keys = (
        "subject_analysis",
        "style_evaluation",
        "technical_assessment",
        "composition_review",
        "context_evaluation",
        "mood_assessment",
    )

    def handle_error(input_prompt="", current_text=""):
        """Return fallback values for the 11 outputs of an option click.

        The prompts passed in are echoed back so a failed click does not wipe
        what the user typed.
        """
        empty_analysis = {"score": 0, "strengths": [], "weaknesses": ["Error occurred"]}
        return [input_prompt, current_text] + [empty_analysis] * 6 + [[], {}, {}]

    def update_interface(prompt, user_directive):
        """Start an enhancement session and return values for all outputs.

        Output order: input prompt, current prompt, the six analysis panels,
        improvement axes, technical recommendations, full state, then one
        update per option button.
        """
        try:
            print(f"\n=== Processing prompt: {prompt}")
            print(f"User directive: {user_directive}")
            state = system.start_session(prompt, user_directive)
            improvement_axes = state.get("improvement_axes", [])
            initial_analysis = state.get("initial_analysis", {})

            # Fall back to the user's prompt (as the error path does) instead
            # of blanking the current prompt when no axis is available.
            enhanced_prompt = prompt
            if improvement_axes:
                enhanced_prompt = improvement_axes[0].get("enhanced_prompt", prompt)

            # Show one button per available axis, hide the rest.
            button_updates = []
            for i in range(4):
                if i < len(improvement_axes):
                    focus_area = improvement_axes[i].get("focus_area", f"Option {i+1}")
                    button_updates.append(gr.update(visible=True, value=focus_area))
                else:
                    button_updates.append(gr.update(visible=False))

            analysis_values = [initial_analysis.get(key, {}) for key in analysis_keys]
            return (
                [prompt, enhanced_prompt]
                + analysis_values
                + [improvement_axes, state.get("technical_recommendations", {}), state]
                + button_updates
            )
        except Exception as e:
            print(f"Error in update_interface: {str(e)}")
            logger.exception(f"Error in update_interface: {str(e)}")
            empty_analysis = {"score": 0, "strengths": [], "weaknesses": ["Error occurred"]}
            return (
                [prompt, prompt]
                + [empty_analysis] * 6
                + [[], {}, {}]
                + [gr.update(visible=False)] * 4
            )

    def handle_option_click(option_num, input_prompt, current_text, user_directive):
        """Put the enhanced prompt of axis ``option_num`` into the current prompt.

        NOTE(review): reads ``system.current_state``, which looks shared by
        all users of the app - confirm this is acceptable for concurrent use.
        """
        try:
            print(f"\n=== Processing option {option_num}")
            state = system.current_state
            if state and "improvement_axes" in state:
                improvement_axes = state["improvement_axes"]
                if 0 <= option_num < len(improvement_axes):
                    selected_prompt = improvement_axes[option_num]["enhanced_prompt"]
                    initial_analysis = state.get("initial_analysis", {})
                    return (
                        [input_prompt, selected_prompt]
                        + [initial_analysis.get(key, {}) for key in analysis_keys]
                        + [improvement_axes, state.get("technical_recommendations", {}), state]
                    )
            return handle_error(input_prompt, current_text)
        except Exception as e:
            print(f"Error in handle_option_click: {str(e)}")
            logger.exception(f"Error in handle_option_click: {str(e)}")
            return handle_error(input_prompt, current_text)

    def make_option_handler(option_index):
        """Bind an option button to its axis index without a hidden component."""
        def on_option_click(prompt_text, current_text, directive):
            return handle_option_click(option_index, prompt_text, current_text, directive)
        return on_option_click

    def make_image_select_handler(image_index):
        """Bind a preview image to its slot number.

        The select event of an image reports the clicked pixel, not which
        image was clicked, so the slot number is supplied explicitly.
        """
        def on_image_select(improvement_axes_data):
            return handle_image_select(
                SimpleNamespace(index=image_index), improvement_axes_data
            )
        return on_image_select

    with gr.Blocks(
        title="AI Prompt Enhancement System",
        theme=gr.themes.Soft(),
        css="footer {visibility: hidden}"
    ) as interface:
        gr.Markdown("# 🎨 AI Prompt Enhancement & Image Generation System")

        with gr.TabItem("Images Generation"):
            with gr.Row():
                input_prompt = gr.Textbox(
                    label="Initial Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    scale=1
                )
            with gr.Row():
                user_directive = gr.Textbox(
                    label="User Directive",
                    placeholder="Enter specific requirements...",
                    lines=2,
                    scale=1
                )
            with gr.Row():
                start_btn = gr.Button("Start Enhancement", variant="primary")
            with gr.Row():
                current_prompt = gr.Textbox(
                    label="Current Prompt",
                    lines=3,
                    scale=1,
                    interactive=True
                )
            with gr.Row():
                option_buttons = [gr.Button("", visible=False) for _ in range(4)]
            with gr.Row():
                finalize_btn = gr.Button("Generate Images", variant="primary")
            with gr.Row():
                generated_images = [
                    gr.Image(
                        label=f"Image {i+1}",
                        type="pil",
                        show_label=False,
                        height=256,
                        width=256,
                        interactive=False,
                        show_download_button=False,
                        elem_id=f"image_{i}"
                    ) for i in range(4)
                ]

        with gr.TabItem("Images Gallery"):
            with gr.Row():
                image_gallery = gr.Gallery(
                    label="Generated Images History",
                    show_label=False,
                    columns=4,
                    rows=None,
                    height=800,
                    object_fit="contain"
                )
            with gr.Row():
                clear_gallery_btn = gr.Button("Clear Gallery", variant="secondary")
            with gr.Row():
                selected_image_data = gr.JSON(label="Selected Image Data", visible=True)
                copy_to_prompt_btn = gr.Button("Copy Prompt to Current", visible=True)

        with gr.TabItem("Image Generation Settings"):
            with gr.Row():
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42
                )
                randomize_seed = gr.Checkbox(
                    label="Randomize seed",
                    value=True
                )
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=256,
                    value=512
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=256,
                    value=512
                )
                num_inference_steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=4
                )

        with gr.TabItem("Initial Analysis"):
            with gr.Row():
                with gr.Column():
                    subject_analysis = gr.JSON(label="Subject Analysis")
                with gr.Column():
                    style_evaluation = gr.JSON(label="Style Evaluation")
                with gr.Column():
                    technical_assessment = gr.JSON(label="Technical Assessment")
            with gr.Row():
                with gr.Column():
                    composition_review = gr.JSON(label="Composition Review")
                with gr.Column():
                    context_evaluation = gr.JSON(label="Context Evaluation")
                with gr.Column():
                    mood_assessment = gr.JSON(label="Mood Assessment")
            with gr.Accordion("Additional Information", open=False):
                improvement_axes = gr.JSON(label="Improvement Axes")
                technical_recommendations = gr.JSON(label="Technical Recommendations")
                full_llm_response = gr.JSON(label="Full LLM Response")

        # Outputs shared by the "start" and option-button callbacks.
        analysis_outputs = [
            input_prompt,
            current_prompt,
            subject_analysis,
            style_evaluation,
            technical_assessment,
            composition_review,
            context_evaluation,
            mood_assessment,
            improvement_axes,
            technical_recommendations,
            full_llm_response
        ]

        # Add event handlers
        for i, img in enumerate(generated_images):
            img.select(
                fn=make_image_select_handler(i),
                inputs=[improvement_axes],
                outputs=[current_prompt],
                show_progress=False
            )

        start_btn.click(
            update_interface,
            inputs=[input_prompt, user_directive],
            outputs=analysis_outputs + option_buttons
        )

        for i, btn in enumerate(option_buttons):
            btn.click(
                make_option_handler(i),
                inputs=[input_prompt, current_prompt, user_directive],
                outputs=analysis_outputs
            )

        finalize_btn.click(
            generate_multiple_images_batch,
            inputs=[
                improvement_axes,
                image_gallery,
                seed,
                randomize_seed,
                width,
                height,
                num_inference_steps,
                current_prompt,
                input_prompt
            ],
            outputs=generated_images + [image_gallery] + [seed]
        )

        clear_gallery_btn.click(
            clear_gallery,
            inputs=[],
            outputs=[image_gallery] + generated_images
        )

        # Add gallery selection handler
        image_gallery.select(
            fn=handle_gallery_select,
            inputs=[image_gallery],
            outputs=[selected_image_data, current_prompt]
        )

        # Copy the prompt of the selected gallery entry; tolerate "nothing selected".
        copy_to_prompt_btn.click(
            lambda x: x["prompt"] if x and isinstance(x, dict) and "prompt" in x else "",
            inputs=[selected_image_data],
            outputs=[current_prompt]
        )

    print("Interface setup complete")
    return interface
# Script entry point: build the UI and start the Gradio server when this file
# is executed directly (not when it is imported).
if __name__ == "__main__":
    interface = create_interface()
    interface.launch()