# PRONGS-CHIRAG — reduced inference steps (commit 92b7960)
import streamlit as st
import requests
import io
import os
from PIL import Image
from datetime import datetime
# Set up Hugging Face API key
# Read the token from the environment; None if unset. The API-key form below
# writes a user-supplied key into os.environ and triggers st.rerun(), and each
# Streamlit rerun re-executes this line, so the value is refreshed per run.
HUGGINGFACE_API_KEY = os.getenv('HUGGINGFACE_API_KEY')
# Dictionary of available models and their API endpoints
# Maps the display name shown in the model selectbox to the full
# Hugging Face Inference API endpoint URL for that model.
MODELS = {
    "Stable Diffusion XL": "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0",
    "Flux 1": "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev",
    "Flux Midjourney": "https://api-inference.huggingface.co/models/strangerzonehf/Flux-Midjourney-Mix2-LoRA",
    "Playground v2": "https://api-inference.huggingface.co/models/playgroundai/playground-v2-1024px-aesthetic"
}
def generate_image(prompt, model_url, timeout=120):
    """Generate an image from a text prompt via the Hugging Face Inference API.

    Args:
        prompt: Text description of the desired image.
        model_url: Full Inference API endpoint URL for the chosen model.
        timeout: Seconds to wait for the API response before giving up
            (default 120; cold model starts can be slow).

    Returns:
        Raw image bytes on success, or None on any failure — errors are
        reported through the Streamlit UI (st.error) rather than raised.
    """
    headers = {
        "Authorization": f"Bearer {HUGGINGFACE_API_KEY}",
        "Content-Type": "application/json"
    }
    data = {
        "inputs": prompt,
        "parameters": {
            "num_inference_steps": 20,  # reduced step count for faster generation
            "guidance_scale": 7.5
        }
    }
    try:
        # A timeout is essential here: without one, a cold-started or
        # overloaded model endpoint can hang the Streamlit session forever.
        response = requests.post(
            model_url,
            headers=headers,
            json=data,
            timeout=timeout
        )
    except requests.exceptions.RequestException as exc:
        # Network failures (DNS, connection reset, timeout) previously
        # propagated as uncaught exceptions; route them through the same
        # error path as HTTP-level failures.
        st.error(f"Request failed: {exc}")
        return None
    if response.status_code == 200:
        return response.content
    st.error(f"Error: {response.text}")
    return None
def save_image(image_data, model_name):
    """Write raw image bytes to a timestamped PNG file and return its filename.

    Args:
        image_data: Raw encoded image bytes (as returned by the API).
        model_name: Slug used in the output filename.

    Returns:
        The name of the PNG file written to the current directory.
    """
    # A per-call timestamp keeps successive generations from overwriting
    # each other.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_name = f"generated_{model_name}_{stamp}.png"
    with io.BytesIO(image_data) as buffer:
        Image.open(buffer).save(out_name)
    return out_name
# Streamlit UI
# Page title plus a short blurb describing each selectable model backend.
st.title("Open Source Text-to-Image Generation")
st.markdown("""
This app uses state-of-the-art open source text-to-image models via Hugging Face's API.
- **SDXL**: Latest Stable Diffusion model with exceptional quality
- **Flux 1**: Optimized for speed and quality
- **Flux Midjourney**: Compact but powerful model
- **Playground v2**: Specialized in aesthetic generations
""")
# API key input
# Track whether a usable API key has been provided for this session.
if 'api_key_set' not in st.session_state:
    st.session_state.api_key_set = False

if not st.session_state.api_key_set:
    with st.form("api_key_form"):
        huggingface_key = st.text_input("Enter Hugging Face API Key:", type="password",
                                        help="Get your free API key from huggingface.co")
        submit = st.form_submit_button("Save API Key")
        if submit:
            # Reject blank submissions: previously an empty key still set
            # api_key_set = True, leaving the session stuck with a key that
            # fails every subsequent API call.
            if huggingface_key.strip():
                # Store the key in the environment so the module-level
                # os.getenv() call picks it up on the rerun.
                os.environ['HUGGINGFACE_API_KEY'] = huggingface_key
                st.session_state.api_key_set = True
                st.rerun()
            else:
                st.warning("Please enter a non-empty API key.")
# Main app UI — only shown once an API key has been provided.
if st.session_state.api_key_set:
    # Model selection
    model = st.selectbox(
        "Choose Model:",
        list(MODELS.keys())
    )
    # Advanced options
    # NOTE(review): `seed` is collected here but never forwarded to the
    # generate_image() request, so it currently has no effect on output.
    with st.expander("Advanced Options"):
        seed = st.number_input("Random Seed (optional)", min_value=0, max_value=999999999, value=0)
        if seed > 0:
            st.info("Using the same seed will generate similar images for the same prompt")
    # Prompt input
    prompt = st.text_area(
        "Enter your prompt:",
        height=100,
        help="Describe what you want to see in the image. Be specific!"
    )
    # Generate image button
    if st.button("Generate Image"):
        if prompt:
            with st.spinner(f"Generating image using {model}..."):
                # generate_image returns raw PNG bytes, or None on failure
                # (it reports the error in the UI itself).
                image_data = generate_image(prompt, MODELS[model])
                if image_data:
                    # Save the image to disk so it can be offered for download.
                    filename = save_image(image_data, model.lower().replace(" ", "_"))
                    # Display the image
                    st.image(Image.open(io.BytesIO(image_data)))
                    # Display prompt used
                    st.text_area("Prompt used:", prompt, height=100, disabled=True)
                    # Download button
                    with open(filename, "rb") as file:
                        btn = st.download_button(
                            label="Download Image",
                            data=file,
                            file_name=filename,
                            mime="image/png"
                        )
        else:
            st.warning("Please enter a prompt.")
    # Display model information
    # Blurb for the currently selected model; keys must match MODELS above.
    st.markdown("### About the Selected Model")
    model_info = {
        "Stable Diffusion XL": "The latest version of Stable Diffusion, known for its high-quality outputs and improved understanding of prompts.",
        "Flux 1": "A highly optimized model that balances speed and quality, great for quick iterations.",
        "Flux Midjourney": "A compact 1B parameter model that produces impressive results with faster inference times.",
        "Playground v2": "Specialized in creating highly aesthetic images with strong composition and artistic quality."
    }
    st.info(model_info[model])