# app.py — Streamlit text-to-image generator (Stable Diffusion v1-5).
# Originally committed as "Create app.py" (commit 1777c49) by Murtaza249.
import streamlit as st
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import io
import base64
# Page configuration — must be the first Streamlit call in the script.
_PAGE_CONFIG = {
    "page_title": "Text to Image Generator",
    "page_icon": "🎨",
    "layout": "wide",
}
st.set_page_config(**_PAGE_CONFIG)
# Custom CSS for better UI
st.markdown("""
<style>
.main {
background-color: #f5f7f9;
}
.stApp {
max-width: 1200px;
margin: 0 auto;
}
.css-18e3th9 {
padding-top: 2rem;
}
.css-1d391kg {
padding: 3rem 1rem;
}
.stButton>button {
background-color: #4CAF50;
color: white;
padding: 10px 24px;
border: none;
border-radius: 4px;
cursor: pointer;
font-size: 16px;
transition: all 0.3s;
}
.stButton>button:hover {
background-color: #45a049;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.download-link {
display: inline-block;
margin-top: 10px;
padding: 10px 15px;
background-color: #3498db;
color: white;
text-decoration: none;
border-radius: 4px;
text-align: center;
}
h1 {
color: #2C3E50;
margin-bottom: 30px;
}
.stSlider {
padding-top: 1rem;
padding-bottom: 1rem;
}
</style>
""", unsafe_allow_html=True)
# Header — page title and a one-line tagline shown above both columns.
st.title("✨ Text to Image Generator")
st.markdown("Turn your imagination into images using AI! 🖼️")
@st.cache_resource
def load_model():
    """Load the Stable Diffusion v1-5 pipeline and cache it for the session.

    ``st.cache_resource`` ensures the weights are downloaded and loaded only
    once per server process, not on every Streamlit rerun.

    Returns:
        A ``StableDiffusionPipeline`` ready for inference, moved to CUDA
        when a GPU is available (with float16 weights to halve memory);
        otherwise left on CPU with float32 weights.
    """
    use_cuda = torch.cuda.is_available()
    # Half precision is only reliable on GPU; CPUs need full float32.
    dtype = torch.float16 if use_cuda else torch.float32
    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=dtype,
    )
    return pipeline.to("cuda") if use_cuda else pipeline
# Create columns for the layout: left (col1) holds the input controls,
# right (col2) displays the generated image.
col1, col2 = st.columns([1, 1])

with col1:
    # Text input for the prompt describing the desired image.
    prompt = st.text_area(
        "Enter your prompt",
        placeholder="A beautiful sunset over the ocean with palm trees in the foreground",
        height=100
    )

    # Advanced options in an expander (collapsed by default).
    with st.expander("Advanced Options"):
        # Number of denoising steps: more steps = more detail, slower run.
        num_inference_steps = st.slider(
            "Quality (Steps)",
            min_value=1,
            max_value=100,
            value=50,
            help="Higher values produce more detailed images but take longer"
        )
        # Classifier-free guidance scale: higher = closer prompt adherence.
        guidance_scale = st.slider(
            "Creativity",
            min_value=1.0,
            max_value=20.0,
            value=7.5,
            step=0.5,
            help="Higher values make the image more closely follow the prompt"
        )
        # -1 is the sentinel for "no fixed seed" (random result each run).
        seed = st.number_input(
            "Random Seed",
            min_value=-1,
            max_value=2147483647,
            value=-1,
            help="Set a specific seed for reproducible results, or -1 for random"
        )
        # Output dimensions restricted to multiples SD v1-5 handles well.
        image_width = st.select_slider(
            "Image Width",
            options=[512, 576, 640, 704, 768],
            value=512
        )
        image_height = st.select_slider(
            "Image Height",
            options=[512, 576, 640, 704, 768],
            value=512
        )

    # Generate button — True only on the rerun triggered by the click.
    generate_button = st.button("Generate Image")
# Function to get a download link for a generated image.
def get_image_download_link(img, filename="generated_image.png", text="Download Image"):
    """Return an HTML anchor that downloads *img* as a PNG via a data: URI.

    Args:
        img: Image object exposing ``save(buffer, format=...)`` (e.g. PIL).
        filename: Name the browser saves the file under.
        text: Visible link text.

    Returns:
        An ``<a>`` tag string for ``st.markdown(..., unsafe_allow_html=True)``.
    """
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    # BUG FIX: the download attribute was a hard-coded garbage literal,
    # silently ignoring the `filename` parameter — interpolate it instead.
    href = (
        f'<a href="data:image/png;base64,{img_str}" '
        f'download="{filename}" class="download-link">{text}</a>'
    )
    return href
with col2:
    # Right-hand column: render the generated image and run parameters.
    st.markdown("### Generated Image")
    result_container = st.empty()

    if generate_button and prompt:
        try:
            with st.spinner("Generating your image..."):
                # Load (or fetch from cache) the diffusion pipeline.
                pipe = load_model()

                # Seed the RNG for reproducibility.
                # BUG FIX: torch.Generator() with no device is a CPU
                # generator; diffusers raises a device-mismatch error when
                # the pipeline runs on CUDA. Build the generator on the
                # same device the pipeline uses.
                generator = None
                if seed != -1:
                    device = "cuda" if torch.cuda.is_available() else "cpu"
                    generator = torch.Generator(device=device).manual_seed(seed)

                # Run the diffusion pipeline with the user's settings.
                image = pipe(
                    prompt=prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                    width=image_width,
                    height=image_height
                ).images[0]

                # Display the image in the reserved placeholder.
                result_container.image(image, caption=prompt, use_column_width=True)

                # Provide a download link for the result.
                st.markdown(get_image_download_link(image), unsafe_allow_html=True)

                # Display the parameters used for this generation.
                st.markdown("### Parameters Used")
                st.text(f"Steps: {num_inference_steps}")
                st.text(f"Guidance Scale: {guidance_scale}")
                st.text(f"Seed: {seed}")
                st.text(f"Dimensions: {image_width}x{image_height}")
        except Exception as e:
            # Surface any failure (OOM, download error, bad prompt) in the
            # UI instead of crashing the Streamlit app.
            result_container.error(f"An error occurred: {str(e)}")
    else:
        result_container.markdown("Your image will appear here after generation.")
# Add footer — horizontal rule plus attribution line at the page bottom.
st.markdown("---")
st.markdown("Powered by Hugging Face Diffusers and Streamlit")