Spaces:
Runtime error
Runtime error
| import spaces | |
| import os | |
| import gradio as gr | |
| import random | |
| import torch | |
| import logging | |
| import numpy as np | |
| from typing import Dict, Any, List | |
| from diffusers import DiffusionPipeline | |
| from api import PromptEnhancementSystem | |
# Constants
MAX_SEED = np.iinfo(np.int32).max  # upper bound of a signed 32-bit integer
MAX_IMAGE_SIZE = 2048
MODEL_ID = "black-forest-labs/FLUX.1-schnell"

# Pick the device and the matching tensor dtype in one place: bfloat16 when
# CUDA is available, float32 otherwise.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.bfloat16
else:
    DEVICE, DTYPE = "cpu", torch.float32
print(f"Using device: {DEVICE}")

# Module-level logger shared by the functions in this file.
logger = logging.getLogger(__name__)
# Initialize model
# Load the diffusion pipeline once at import time and move it to DEVICE.
# Any failure is reported on stdout and through the logger, then re-raised so
# the module does not continue without a usable `pipe`.
try:
    print("Loading model...")
    pipe = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE)
    pipe = pipe.to(DEVICE)
    for emit in (print, logger.info):
        emit("Model loaded successfully")
except Exception as load_error:
    failure_message = f"Failed to load model: {load_error}"
    print(failure_message)
    logger.error(failure_message)
    raise
def generate_multiple_images_batch(
    improvement_axes,
    seed=42,
    randomize_seed=False,
    width=512,
    height=512,
    num_inference_steps=4,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate up to four images, one per enhanced prompt, in a single batch.

    Args:
        improvement_axes: List of dicts; each dict's ``"enhanced_prompt"``
            entry (when present and non-empty) is used as one prompt. Any
            other value (``None``, a dict, ...) is treated as "no prompts".
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output image width passed to the pipeline.
        height: Output image height passed to the pipeline.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        progress: Gradio progress tracker (tracks tqdm output).

    Returns:
        A list of exactly five items: four image slots (padded with ``None``
        when fewer images were produced) followed by the seed that was used.
        When there is no usable prompt the four slots are ``None`` and the
        incoming ``seed`` is returned unchanged.

    Raises:
        Exception: Any error raised during generation is logged and re-raised.
    """
    max_images = 4  # callers receive exactly four image slots plus the seed
    try:
        # The JSON input can be empty (None) or a non-list before a session
        # has been started; treat that as "nothing to generate" rather than
        # failing while iterating it.
        axes = improvement_axes if isinstance(improvement_axes, list) else []

        # Extract prompts from improvement axes. Only the first `max_images`
        # are kept so no inference time is spent on images that would be
        # discarded afterwards.
        prompts = [
            axis["enhanced_prompt"]
            for axis in axes
            if isinstance(axis, dict) and axis.get("enhanced_prompt")
        ][:max_images]
        if not prompts:
            return [None] * max_images + [seed]

        # UI sliders may deliver numbers as floats; the generator and the
        # pipeline expect integers.
        if randomize_seed:
            current_seed = random.randint(0, MAX_SEED)
        else:
            current_seed = int(seed)

        print(f"Generating images with {len(prompts)} prompts")
        print(f"Using seed: {current_seed}")

        # Generate all images in a single batch
        generator = torch.Generator().manual_seed(current_seed)
        images = list(
            pipe(
                prompt=prompts,  # Pass list of prompts directly
                width=int(width),
                height=int(height),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
                guidance_scale=0.0
            ).images
        )

        # Pad with None if we have fewer than four images
        images += [None] * (max_images - len(images))

        print("All images generated successfully")
        return images[:max_images] + [current_seed]
    except Exception as e:
        print(f"Image generation error: {str(e)}")
        logger.error(f"Image generation error: {str(e)}")
        raise
def handle_image_select(evt: gr.SelectData, improvement_axes_data):
    """Return the enhanced prompt that belongs to the image that was selected.

    For an image component, Gradio reports ``evt.index`` as the pixel
    coordinates of the click, which says nothing about *which* image was
    clicked. The image slot is therefore taken from the ``elem_id`` of the
    component that fired the event when it has the form ``image_<n>``.
    If no such id is available, the previous interpretation of ``evt.index``
    is used as a fallback.

    Args:
        evt: Gradio selection event data.
        improvement_axes_data: List of dicts, each optionally carrying an
            ``"enhanced_prompt"`` entry.

    Returns:
        The selected axis' enhanced prompt, or ``""`` when the selection
        cannot be resolved to an entry of ``improvement_axes_data`` or an
        error occurs.
    """
    try:
        if not improvement_axes_data or not isinstance(improvement_axes_data, list):
            return ""

        selected_index = None

        # Preferred: identify the image slot from the triggering component.
        elem_id = getattr(getattr(evt, "target", None), "elem_id", None)
        if isinstance(elem_id, str) and elem_id.startswith("image_"):
            suffix = elem_id[len("image_"):]
            if suffix.isdigit():
                selected_index = int(suffix)

        # Fallback: original interpretation of the event index.
        if selected_index is None:
            raw_index = evt.index
            selected_index = raw_index[1] if isinstance(raw_index, tuple) else raw_index

        # Coordinates arriving as a list, or any out-of-range value, cannot be
        # mapped to an axis.
        if not isinstance(selected_index, int):
            return ""
        if not 0 <= selected_index < len(improvement_axes_data):
            return ""

        return improvement_axes_data[selected_index].get("enhanced_prompt", "")
    except Exception as e:
        print(f"Error in handle_image_select: {str(e)}")
        return ""
def create_interface():
    """Build and return the Gradio Blocks UI for prompt enhancement.

    Reads ``GROQ_API_KEY`` (required) and ``API_BASE_URL`` from the
    environment, creates a ``PromptEnhancementSystem`` and wires the UI
    events to it and to the image generation handlers.

    Returns:
        The constructed ``gr.Blocks`` interface.

    Raises:
        ValueError: If ``GROQ_API_KEY`` is not set.
    """
    print("Creating interface...")
    api_key = os.getenv("GROQ_API_KEY")
    base_url = os.getenv("API_BASE_URL")
    if not api_key:
        print("GROQ_API_KEY not found in environment variables")
        raise ValueError("GROQ_API_KEY not found in environment variables")
    system = PromptEnhancementSystem(api_key, base_url)
    print("PromptEnhancementSystem initialized")

    # Keys of the initial analysis, in the order of the six JSON panels.
    analysis_keys = (
        "subject_analysis",
        "style_evaluation",
        "technical_assessment",
        "composition_review",
        "context_evaluation",
        "mood_assessment",
    )
    num_options = 4  # number of option buttons / image slots in the UI

    def update_interface(prompt):
        """Start an enhancement session and return values for all outputs.

        Returns 19 values: the two prompts, six analysis panels, improvement
        axes, technical recommendations, four cleared image slots, the full
        state, and four option-button updates.
        """
        try:
            print(f"\n=== Processing prompt: {prompt}")
            state = system.start_session(prompt)
            improvement_axes = state.get("improvement_axes", [])
            initial_analysis = state.get("initial_analysis", {})

            enhanced_prompt = ""
            if improvement_axes:
                enhanced_prompt = improvement_axes[0].get("enhanced_prompt", prompt)

            # One button per improvement axis; unused buttons stay hidden.
            button_updates = []
            for i in range(num_options):
                if i < len(improvement_axes):
                    focus_area = improvement_axes[i].get("focus_area", f"Option {i+1}")
                    button_updates.append(gr.update(visible=True, value=focus_area))
                else:
                    button_updates.append(gr.update(visible=False))

            return (
                [prompt, enhanced_prompt]
                + [initial_analysis.get(key, {}) for key in analysis_keys]
                + [
                    improvement_axes,
                    state.get("technical_recommendations", {}),
                    None, None, None, None,  # Four None values for the four image outputs
                    state
                ]
                + button_updates
            )
        except Exception as e:
            print(f"Error in update_interface: {str(e)}")
            logger.error(f"Error in update_interface: {str(e)}")
            empty_analysis = {"score": 0, "strengths": [], "weaknesses": ["Error occurred"]}
            return [prompt, prompt] + [empty_analysis] * 6 + [{}, {}, None, None, None, None, {}] + [gr.update(visible=False)] * 4

    def handle_error(input_text="", current_text=""):
        """Return the 11 fallback output values used when an option click fails.

        The two prompt values default to ``""``; callers can pass the current
        textbox contents so a failure does not erase what the user typed.
        """
        empty_analysis = {"score": 0, "strengths": [], "weaknesses": ["Error occurred"]}
        return [input_text, current_text] + [empty_analysis] * 6 + [[], {}, {}]

    def handle_option_click(option_num, input_prompt, current_text):
        """Switch the current prompt to the enhanced prompt of one option.

        Args:
            option_num: Zero-based index of the clicked option button.
            input_prompt: Current content of the initial-prompt textbox.
            current_text: Current content of the current-prompt textbox.

        Returns:
            11 values matching the outputs wired to the option buttons; the
            fallback values from ``handle_error`` when the option cannot be
            resolved.
        """
        try:
            print(f"\n=== Processing option {option_num}")
            option_index = int(option_num)  # list indexing requires an int
            state = system.current_state
            if state and "improvement_axes" in state:
                improvement_axes = state["improvement_axes"]
                if 0 <= option_index < len(improvement_axes):
                    selected_prompt = improvement_axes[option_index]["enhanced_prompt"]
                    initial_analysis = state.get("initial_analysis", {})
                    return (
                        [input_prompt, selected_prompt]
                        + [initial_analysis.get(key, {}) for key in analysis_keys]
                        + [
                            improvement_axes,
                            state.get("technical_recommendations", {}),
                            state
                        ]
                    )
            return handle_error(input_prompt, current_text)
        except Exception as e:
            print(f"Error in handle_option_click: {str(e)}")
            logger.error(f"Error in handle_option_click: {str(e)}")
            return handle_error(input_prompt, current_text)

    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from the statement order — confirm it matches the
    # intended layout.
    with gr.Blocks(
        title="AI Prompt Enhancement System",
        theme=gr.themes.Soft(),
        css="footer {visibility: hidden}"
    ) as interface:
        gr.Markdown("# 🎨 AI Prompt Enhancement & Image Generation System")
        with gr.Row():
            input_prompt = gr.Textbox(
                label="Initial Prompt",
                placeholder="Enter your prompt here...",
                lines=3,
                scale=1
            )
            current_prompt = gr.Textbox(
                label="Current Prompt",
                lines=3,
                scale=1,
                interactive=True
            )
        with gr.Row():
            start_btn = gr.Button("Start Enhancement", variant="primary")
        with gr.Row():
            option_buttons = [gr.Button("", visible=False) for _ in range(num_options)]
        with gr.Tabs():
            with gr.TabItem("Initial Analysis"):
                with gr.Row():
                    with gr.Column():
                        subject_analysis = gr.JSON(label="Subject Analysis")
                    with gr.Column():
                        style_evaluation = gr.JSON(label="Style Evaluation")
                    with gr.Column():
                        technical_assessment = gr.JSON(label="Technical Assessment")
                with gr.Row():
                    with gr.Column():
                        composition_review = gr.JSON(label="Composition Review")
                    with gr.Column():
                        context_evaluation = gr.JSON(label="Context Evaluation")
                    with gr.Column():
                        mood_assessment = gr.JSON(label="Mood Assessment")
            with gr.TabItem("Generated Images"):
                with gr.Row():
                    generated_images = [
                        gr.Image(
                            label=f"Image {i+1}",
                            type="pil",
                            show_label=True,
                            height=256,
                            width=256,
                            interactive=True,
                            elem_id=f"image_{i}"
                        ) for i in range(num_options)
                    ]
                with gr.Row():
                    finalize_btn = gr.Button("Generate All Images", variant="primary")
                with gr.Accordion("Image Generation Settings", open=False):
                    with gr.Row():
                        # The generation handler writes the seed it used back
                        # into this slider, and randomized seeds range up to
                        # MAX_SEED, so the slider must cover that range.
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42
                        )
                        randomize_seed = gr.Checkbox(
                            label="Randomize seed",
                            value=True
                        )
                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=256,
                            value=512
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=256,
                            value=512
                        )
                        num_inference_steps = gr.Slider(
                            label="Steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=4
                        )
        with gr.Accordion("Additional Information", open=False):
            improvement_axes = gr.JSON(label="Improvement Axes")
            technical_recommendations = gr.JSON(label="Technical Recommendations")
            full_llm_response = gr.JSON(label="Full LLM Response")

        # Add select events for each image
        for img in generated_images:
            img.select(
                fn=handle_image_select,
                inputs=[improvement_axes],
                outputs=[input_prompt]
            )

        start_btn.click(
            update_interface,
            inputs=[input_prompt],
            outputs=[
                input_prompt,
                current_prompt,
                subject_analysis,
                style_evaluation,
                technical_assessment,
                composition_review,
                context_evaluation,
                mood_assessment,
                improvement_axes,
                technical_recommendations
            ] + generated_images + [full_llm_response] + option_buttons
        )

        for i, btn in enumerate(option_buttons):
            btn.click(
                handle_option_click,
                inputs=[
                    # Constant per-button index, carried as non-visual state.
                    gr.State(i),
                    input_prompt,
                    current_prompt
                ],
                outputs=[
                    input_prompt,
                    current_prompt,
                    subject_analysis,
                    style_evaluation,
                    technical_assessment,
                    composition_review,
                    context_evaluation,
                    mood_assessment,
                    improvement_axes,
                    technical_recommendations,
                    full_llm_response
                ]
            )

        finalize_btn.click(
            generate_multiple_images_batch,
            inputs=[
                improvement_axes,
                seed,
                randomize_seed,
                width,
                height,
                num_inference_steps
            ],
            outputs=generated_images + [seed]
        )

    print("Interface setup complete")
    return interface