# Space: textToimage / app.py
# Author: ShahbazAhmad-Lab
# Commit: a1d9a53 "Create app.py" (verified)
import logging
import os
import tempfile
import time
from io import BytesIO
from typing import Optional

import gradio as gr
import requests
from PIL import Image
# ----------------------------
# Logging Configuration
# ----------------------------
# Configure the root logger at INFO so the retry warnings and API errors
# emitted below are printed (basicConfig attaches a stderr handler).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ----------------------------
# Constants
# ----------------------------
# Hosted Inference API endpoint for the FLUX.1-schnell text-to-image model.
HF_API_URL = (
    "https://api-inference.huggingface.co/models/"
    "black-forest-labs/FLUX.1-schnell"
)

# Style keywords offered in the UI dropdown; the chosen one is appended to
# the user's prompt by style_prompt().
DEFAULT_STYLES = [
    "Realistic",
    "Cinematic",
    "Cyberpunk",
    "Studio Lighting",
    "Highly Detailed",
    "4K",
]
# ----------------------------
# Utility Functions
# ----------------------------
def get_hf_token():
    """Return the Hugging Face API token from the ``HF_TOKEN`` environment variable.

    Raises:
        EnvironmentError: if ``HF_TOKEN`` is unset or set to an empty string.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        return hf_token
    raise EnvironmentError("HF_TOKEN not found in environment variables.")
def style_prompt(user_input: str, style: Optional[str] = None) -> str:
    """Build the final generation prompt from the user's text and a style.

    Args:
        user_input: Raw prompt text from the user. Surrounding whitespace is
            stripped before the text is embedded in the prompt.
        style: Style keyword to append. ``None``, an empty string, or the
            literal ``"None"`` (the dropdown's "no style" entry) selects the
            plain ``"high quality"`` suffix instead.

    Returns:
        The enhanced prompt string.

    Raises:
        ValueError: if ``user_input`` is ``None`` or blank.
    """
    # Validate and embed the *same* stripped text; previously the check used
    # the stripped value but the unstripped input leaked into the prompt.
    text = (user_input or "").strip()
    if not text:
        raise ValueError("Prompt cannot be empty.")
    if style and style != "None":
        return f"{text}, {style}, ultra quality, sharp focus"
    return f"{text}, high quality"
def query_hf_api(prompt, retries=3, timeout=60, seed=None):
    """Send a prompt to the Hugging Face Inference API, retrying transient failures.

    Retryable outcomes are HTTP 503 (model still loading), HTTP 429 (rate
    limited), request timeouts and connection errors. Any other non-200
    status aborts immediately. No sleep happens after the final attempt, so
    a failing call raises as soon as its attempts are used up.

    Args:
        prompt: Text sent as the request's ``inputs`` field.
        retries: Maximum number of attempts (values below 1 send no request).
        timeout: Per-request timeout in seconds, passed to ``requests.post``.
        seed: Optional generation seed, sent as ``parameters.seed``. Coerced
            to ``int`` since a numeric UI widget may hand over a float.

    Returns:
        The raw body (``bytes``) of the first HTTP 200 response.

    Raises:
        EnvironmentError: if ``HF_TOKEN`` is not set (via ``get_hf_token``).
        RuntimeError: on a non-retryable API error, or once every attempt
            has failed; the latter message includes the last failure reason.
    """
    headers = {
        "Authorization": f"Bearer {get_hf_token()}",
        "Content-Type": "application/json",
    }
    payload = {
        "inputs": prompt,
        "options": {"wait_for_model": True},
    }
    if seed is not None:
        payload["parameters"] = {"seed": int(seed)}

    last_error = "no request was attempted"
    for attempt in range(retries):
        try:
            response = requests.post(
                HF_API_URL,
                headers=headers,
                json=payload,
                timeout=timeout,
            )
        except requests.exceptions.Timeout:
            # Must precede ConnectionError: ConnectTimeout subclasses both.
            last_error = "request timed out"
            notice, delay = "Timeout occurred, retrying...", 5
        except requests.exceptions.ConnectionError as exc:
            last_error = f"connection error: {exc}"
            notice, delay = "Connection error, retrying...", 5
        else:
            if response.status_code == 200:
                return response.content
            if response.status_code == 503:
                last_error = "model loading (HTTP 503)"
                notice, delay = "Model loading, retrying...", 5
            elif response.status_code == 429:
                last_error = "rate limit hit (HTTP 429)"
                notice, delay = "Rate limit hit, retrying...", 10
            else:
                logger.error("API Error: %s", response.text)
                raise RuntimeError(f"API Error: {response.text}")

        if attempt == retries - 1:
            break  # Out of attempts: don't sleep only to raise afterwards.
        logger.warning(notice)
        time.sleep(delay)

    raise RuntimeError(f"Failed after multiple retries. Last error: {last_error}")
def generate_image(prompt, style, seed):
    """Generate an image for the Gradio UI.

    Styles the prompt with ``style_prompt``, sends it to the Inference API via
    ``query_hf_api``, and decodes the returned bytes into an RGB PIL image.

    Args:
        prompt: User prompt text.
        style: Selected style name (``"None"`` for no style).
        seed: Optional seed forwarded to the API.

    Returns:
        The decoded ``PIL.Image.Image`` on success, or on any failure a string
        of the form ``"Error: <message>"``. Callers tell the two apart with
        ``isinstance(result, str)``.
    """
    try:
        styled_prompt = style_prompt(prompt, style)
        image_bytes = query_hf_api(styled_prompt, seed=seed)
        return Image.open(BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        # UI boundary: keep the traceback in the logs (logger.error dropped it)
        # and hand the message back as text for the callback to surface.
        logger.exception("%s", e)
        return f"Error: {e}"
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks() as app:
    gr.Markdown("# 🎨 AI Image Generator (FLUX.1-schnell)")
    gr.Markdown("Generate high-quality images from text prompts using Hugging Face.")

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter your prompt",
            placeholder="e.g., A futuristic city at sunset",
        )

    with gr.Row():
        style_dropdown = gr.Dropdown(
            ["None"] + DEFAULT_STYLES,
            label="Select Style",
            value="None",
        )
        seed_input = gr.Number(
            label="Seed (optional)",
            value=None,
            precision=0,
        )

    generate_btn = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")
    download_btn = gr.File(label="Download Image")

    examples = gr.Examples(
        examples=[
            ["A dragon flying over mountains", "Cinematic", 42],
            ["Cyberpunk city at night", "Cyberpunk", 123],
            ["Portrait of a warrior", "Realistic", 7],
        ],
        inputs=[prompt_input, style_dropdown, seed_input],
    )

    def generate_and_download(prompt, style, seed):
        """Generate an image and also save it as a downloadable PNG.

        Returns:
            ``(image, file_path)`` for the image preview and file outputs.

        Raises:
            gr.Error: when generation fails, so the message is shown to the
                user instead of the outputs silently going blank.
        """
        image = generate_image(prompt, style, seed)
        if isinstance(image, str):
            # generate_image reports failures as an "Error: ..." string.
            raise gr.Error(image)
        # A unique file per request: the former shared "output.png" let
        # concurrent sessions overwrite each other's download.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            image.save(tmp, format="PNG")
            file_path = tmp.name
        return image, file_path

    generate_btn.click(
        fn=generate_and_download,
        inputs=[prompt_input, style_dropdown, seed_input],
        outputs=[output_image, download_btn],
    )

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    app.launch()