Spaces:
Configuration error
Configuration error
File size: 4,580 Bytes
a1d9a53 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 | import os
import time
import logging
import requests
from io import BytesIO
from PIL import Image
import gradio as gr
# ----------------------------
# Logging Configuration
# ----------------------------
# INFO level so the retry warnings and API errors logged by this module are
# actually emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ----------------------------
# Constants
# ----------------------------
# Inference API endpoint for the FLUX.1-schnell text-to-image model.
# NOTE(review): Hugging Face has been moving serverless inference behind
# router.huggingface.co; confirm this legacy host still serves the model.
HF_API_URL = (
    "https://api-inference.huggingface.co/models/"
    "black-forest-labs/FLUX.1-schnell"
)

# Style keywords offered in the UI dropdown; the chosen one is appended to
# the user's prompt by style_prompt().
DEFAULT_STYLES = [
    "Realistic",
    "Cinematic",
    "Cyberpunk",
    "Studio Lighting",
    "Highly Detailed",
    "4K",
]
# ----------------------------
# Utility Functions
# ----------------------------
def get_hf_token():
    """Return the Hugging Face access token stored in ``HF_TOKEN``.

    Raises:
        EnvironmentError: if the variable is unset or set to an empty string.
    """
    token = os.environ.get("HF_TOKEN")
    if token:
        return token
    raise EnvironmentError("HF_TOKEN not found in environment variables.")
def style_prompt(user_input: str, style: str = None) -> str:
    """Build the final text-to-image prompt from the user's text and a style.

    Args:
        user_input: Raw prompt text from the UI. Surrounding whitespace is
            removed before it is combined with the style keywords. ``None``
            is treated like an empty prompt.
        style: A dropdown style name, or ``None`` / the literal string
            ``"None"`` for no specific style.

    Returns:
        The enhanced prompt string.

    Raises:
        ValueError: if ``user_input`` is ``None``, empty, or whitespace only.
    """
    # Normalise a missing value to "" so callers get the documented
    # ValueError rather than an AttributeError from ``.strip()``.
    text = (user_input or "").strip()
    if not text:
        raise ValueError("Prompt cannot be empty.")
    # The dropdown's "None" option arrives as a string, not the None object.
    if style and style != "None":
        return f"{text}, {style}, ultra quality, sharp focus"
    return f"{text}, high quality"
def query_hf_api(prompt, retries=3, timeout=60, seed=None):
    """Send a text-to-image request to the Hugging Face Inference API.

    Transient failures are retried up to ``retries`` attempts in total:
    HTTP 503, HTTP 429, request timeouts and connection errors. Any other
    non-200 status aborts immediately.

    Args:
        prompt: Text prompt sent as the ``inputs`` field.
        retries: Maximum number of attempts (not additional retries).
        timeout: Per-request timeout in seconds passed to ``requests.post``.
        seed: Optional seed forwarded as ``parameters.seed``.

    Returns:
        The raw response body of the first HTTP 200 response.

    Raises:
        EnvironmentError: if ``HF_TOKEN`` is not set (raised by get_hf_token).
        RuntimeError: on a non-retryable API error, or once every attempt
            has failed; the message includes the last failure reason.
    """
    headers = {
        "Authorization": f"Bearer {get_hf_token()}",
        "Content-Type": "application/json",
    }
    payload = {
        "inputs": prompt,
        "options": {"wait_for_model": True},
    }
    if seed is not None:
        payload["parameters"] = {"seed": seed}

    last_error = "no attempts were made"
    for attempt in range(1, retries + 1):
        try:
            response = requests.post(
                HF_API_URL,
                headers=headers,
                json=payload,
                timeout=timeout,
            )
        except requests.exceptions.Timeout:
            last_error = "request timed out"
            delay = 5
        except requests.exceptions.ConnectionError as exc:
            # Transient network failures are retried just like timeouts.
            last_error = f"connection error: {exc}"
            delay = 5
        else:
            status = response.status_code
            if status == 200:
                return response.content
            if status == 503:
                last_error = "model is loading (HTTP 503)"
                delay = 5
            elif status == 429:
                last_error = "rate limit hit (HTTP 429)"
                # Honour a numeric Retry-After header, capped so one request
                # cannot stall the UI for minutes; otherwise wait 10 s.
                retry_after = response.headers.get("Retry-After", "").strip()
                delay = min(int(retry_after), 60) if retry_after.isdigit() else 10
            else:
                logger.error("API Error (HTTP %s): %s", status, response.text)
                raise RuntimeError(f"API Error (HTTP {status}): {response.text}")

        logger.warning("Attempt %d/%d failed: %s", attempt, retries, last_error)
        # Back off only when another attempt will follow; sleeping after the
        # final failure would just delay the error reported to the user.
        if attempt < retries:
            time.sleep(delay)

    raise RuntimeError(f"Failed after multiple retries. Last error: {last_error}")
def generate_image(prompt, style, seed):
    """Generate one image for the Gradio handler.

    Args:
        prompt: User prompt text.
        style: Selected style name (``"None"`` means no extra style).
        seed: Optional seed from the UI number input; ``None`` when empty.

    Returns:
        A PIL image converted to RGB on success. On any failure, a string
        starting with ``"Error: "`` describing the problem. Callers tell the
        two apart with ``isinstance(result, str)``.
    """
    try:
        styled_prompt = style_prompt(prompt, style)
        # Coerce defensively in case the UI hands back a float such as 42.0;
        # the seed is serialized into the JSON payload as-is.
        if seed is not None:
            seed = int(seed)
        image_bytes = query_hf_api(styled_prompt, seed=seed)
        # convert() returns a new, fully loaded image, so the source handle
        # can be closed as soon as conversion finishes.
        with Image.open(BytesIO(image_bytes)) as raw:
            return raw.convert("RGB")
    except Exception as e:
        # UI boundary: keep the full traceback in the logs, but return a
        # message instead of letting the exception escape.
        logger.exception("Image generation failed")
        return f"Error: {str(e)}"
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks() as app:
    gr.Markdown("# 🎨 AI Image Generator (FLUX.1-schnell)")
    gr.Markdown("Generate high-quality images from text prompts using Hugging Face.")

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter your prompt",
            placeholder="e.g., A futuristic city at sunset",
        )

    with gr.Row():
        style_dropdown = gr.Dropdown(
            ["None"] + DEFAULT_STYLES,
            label="Select Style",
            value="None",
        )
        seed_input = gr.Number(
            label="Seed (optional)",
            value=None,
            precision=0,
        )

    generate_btn = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")
    download_btn = gr.File(label="Download Image")

    # Clicking an example only fills the three inputs; no fn is attached.
    examples = gr.Examples(
        examples=[
            ["A dragon flying over mountains", "Cinematic", 42],
            ["Cyberpunk city at night", "Cyberpunk", 123],
            ["Portrait of a warrior", "Realistic", 7],
        ],
        inputs=[prompt_input, style_dropdown, seed_input],
    )

    def generate_and_download(prompt, style, seed):
        """Click handler: return the generated image and a PNG path for download.

        Raises:
            gr.Error: when generation fails. The message is shown in the UI
                instead of silently clearing both outputs.
        """
        import tempfile

        image = generate_image(prompt, style, seed)
        if isinstance(image, str):
            # generate_image signals failure with an "Error: ..." string.
            raise gr.Error(image)
        # Use a unique file per request. A fixed "output.png" is shared by
        # every session, so concurrent users would overwrite each other's
        # download.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            file_path = tmp.name
        image.save(file_path)
        return image, file_path

    generate_btn.click(
        fn=generate_and_download,
        inputs=[prompt_input, style_dropdown, seed_input],
        outputs=[output_image, download_btn],
    )

app.launch()