import os
import time
import logging
import tempfile
from io import BytesIO

import requests
from PIL import Image
import gradio as gr

# ----------------------------
# Logging Configuration
# ----------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ----------------------------
# Constants
# ----------------------------
HF_API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"

DEFAULT_STYLES = [
    "Realistic",
    "Cinematic",
    "Cyberpunk",
    "Studio Lighting",
    "Highly Detailed",
    "4K",
]


# ----------------------------
# Utility Functions
# ----------------------------
def get_hf_token():
    """Return the Hugging Face API token from the ``HF_TOKEN`` environment variable.

    Raises:
        EnvironmentError: If ``HF_TOKEN`` is unset or empty.
    """
    token = os.getenv("HF_TOKEN")
    if not token:
        raise EnvironmentError("HF_TOKEN not found in environment variables.")
    return token


def style_prompt(user_input: str, style: str = None) -> str:
    """Append style and quality keywords to a user prompt.

    Args:
        user_input: The raw prompt text entered by the user.
        style: A style label (e.g. one of ``DEFAULT_STYLES``). ``None``, an
            empty string, or the literal ``"None"`` means no style selected.

    Returns:
        The prompt with style/quality keywords appended.

    Raises:
        ValueError: If the prompt is missing or contains only whitespace.
    """
    # Check for None as well as blank text so callers always get ValueError
    # rather than an AttributeError from calling .strip() on None.
    if not user_input or not user_input.strip():
        raise ValueError("Prompt cannot be empty.")

    if style and style != "None":
        return f"{user_input}, {style}, ultra quality, sharp focus"
    return f"{user_input}, high quality"


def query_hf_api(prompt, retries=3, timeout=60, seed=None):
    """POST a prompt to the Hugging Face Inference API and return the raw response body.

    The following transient failures are retried:
    - HTTP 503 (model loading)
    - HTTP 429 (rate limited)
    - request timeouts
    - connection errors

    Any other non-200 status fails immediately. Up to ``retries`` attempts are
    made in total, and the back-off sleep is skipped after the final attempt.

    Args:
        prompt: The full text prompt to send.
        retries: Maximum number of attempts.
        timeout: Per-request timeout in seconds, passed to ``requests.post``.
        seed: Optional seed, sent as ``parameters.seed`` (coerced to ``int``).

    Returns:
        ``response.content`` (bytes) of the first HTTP 200 response.

    Raises:
        EnvironmentError: If ``HF_TOKEN`` is not set.
        RuntimeError: On a non-retryable API error, or when every attempt failed.
    """
    headers = {
        "Authorization": f"Bearer {get_hf_token()}",
        "Content-Type": "application/json",
    }
    payload = {
        "inputs": prompt,
        "options": {"wait_for_model": True},
    }
    if seed is not None:
        # Numeric UI inputs can arrive as floats; send a JSON integer.
        payload["parameters"] = {"seed": int(seed)}

    for attempt in range(1, retries + 1):
        # Keep the try body to the network call only; HTTP status handling
        # happens in the else branch below.
        try:
            response = requests.post(
                HF_API_URL, headers=headers, json=payload, timeout=timeout
            )
        except requests.exceptions.Timeout:
            logger.warning("Timeout occurred, retrying...")
            delay = 5
        except requests.exceptions.ConnectionError:
            logger.warning("Connection error, retrying...")
            delay = 5
        else:
            if response.status_code == 200:
                return response.content
            if response.status_code == 503:
                logger.warning("Model loading, retrying...")
                delay = 5
            elif response.status_code == 429:
                logger.warning("Rate limit hit, retrying...")
                delay = 10
            else:
                logger.error("API Error: %s", response.text)
                raise RuntimeError(f"API Error: {response.text}")

        # Sleeping after the last attempt would only delay the failure.
        if attempt < retries:
            time.sleep(delay)

    raise RuntimeError("Failed after multiple retries.")


def generate_image(prompt, style, seed):
    """Generate an image from a prompt, style and optional seed.

    Returns:
        An RGB ``PIL.Image.Image`` on success, or a string of the form
        ``"Error: <message>"`` on any failure. Callers detect the failure case
        with ``isinstance(result, str)``.
    """
    try:
        styled_prompt = style_prompt(prompt, style)
        image_bytes = query_hf_api(styled_prompt, seed=seed)
        return Image.open(BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        # UI boundary: record the full traceback, report a readable message.
        logger.exception("Image generation failed")
        return f"Error: {e}"


# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks() as app:
    gr.Markdown("# 🎨 AI Image Generator (FLUX.1-schnell)")
    gr.Markdown("Generate high-quality images from text prompts using Hugging Face.")

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter your prompt",
            placeholder="e.g., A futuristic city at sunset",
        )

    with gr.Row():
        style_dropdown = gr.Dropdown(
            ["None"] + DEFAULT_STYLES,
            label="Select Style",
            value="None",
        )
        seed_input = gr.Number(
            label="Seed (optional)",
            value=None,
            precision=0,
        )

    generate_btn = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")
    download_btn = gr.File(label="Download Image")

    examples = gr.Examples(
        examples=[
            ["A dragon flying over mountains", "Cinematic", 42],
            ["Cyberpunk city at night", "Cyberpunk", 123],
            ["Portrait of a warrior", "Realistic", 7],
        ],
        inputs=[prompt_input, style_dropdown, seed_input],
    )

    def generate_and_download(prompt, style, seed):
        """Generate an image and write a PNG copy for the download component.

        Returns:
            ``(image, file_path)`` for the image and file output components.

        Raises:
            gr.Error: When generation fails. Gradio shows this message to the
                user instead of silently clearing the outputs.
        """
        result = generate_image(prompt, style, seed)
        if isinstance(result, str):
            raise gr.Error(result)

        # Use a unique file per request. A fixed "output.png" would be
        # overwritten by concurrent users, who could then download someone
        # else's image.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            result.save(tmp, format="PNG")
            file_path = tmp.name
        return result, file_path

    generate_btn.click(
        fn=generate_and_download,
        inputs=[prompt_input, style_dropdown, seed_input],
        outputs=[output_image, download_btn],
    )


if __name__ == "__main__":
    app.launch()