# Vidostar / src/streamlit_app.py
# (Repository page metadata captured with the file: "shaheerawan3's picture",
# "Update src/streamlit_app.py", commit 79b8cf7, verified.)
import streamlit as st
import json
import requests
from PIL import Image
import io
import base64
from datetime import datetime
import time
# Page configuration: browser-tab title/icon, wide layout, sidebar open
_PAGE_SETTINGS = {
    "page_title": "AdForge AI - AI Advertisement Generator",
    "page_icon": "🎨",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)
# Custom CSS for professional styling.
# The stylesheet is kept in a named constant and injected once as raw HTML;
# it styles the header classes, Streamlit buttons, and the status/info boxes
# used by the markup emitted elsewhere in this script.
_CUSTOM_CSS = """
<style>
.main-header {
font-size: 3rem;
font-weight: bold;
background: linear-gradient(120deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
text-align: center;
padding: 1rem 0;
}
.sub-header {
text-align: center;
color: #666;
font-size: 1.2rem;
margin-bottom: 2rem;
}
.stButton>button {
width: 100%;
background: linear-gradient(120deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
padding: 0.75rem 1.5rem;
font-size: 1.1rem;
font-weight: bold;
border-radius: 10px;
transition: all 0.3s;
}
.stButton>button:hover {
transform: translateY(-2px);
box-shadow: 0 10px 20px rgba(0,0,0,0.2);
}
.success-box {
padding: 1rem;
border-radius: 10px;
background: #d4edda;
border: 1px solid #c3e6cb;
color: #155724;
margin: 1rem 0;
}
.info-box {
padding: 1rem;
border-radius: 10px;
background: #d1ecf1;
border: 1px solid #bee5eb;
color: #0c5460;
margin: 1rem 0;
}
.feature-card {
padding: 1.5rem;
border-radius: 10px;
background: white;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
margin: 1rem 0;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Initialize session state
# Each key is seeded only when absent, so values set on earlier reruns are
# kept. The factory is called per key so the two lists are distinct objects.
for _state_key, _make_default in (
    ("generated_images", list),
    ("generation_history", list),
    ("api_token", str),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Helper Functions
def validate_json(json_string):
    """Validate and parse the JSON advertisement concept.

    Args:
        json_string: Raw text to parse (normally the text-area contents).

    Returns:
        A ``(data, error)`` tuple. On success ``data`` is the parsed dict and
        ``error`` is ``None``. On failure ``data`` is ``None`` and ``error``
        is a human-readable message starting with "Invalid JSON format:".
    """
    try:
        data = json.loads(json_string)
    except json.JSONDecodeError as e:
        return None, f"Invalid JSON format: {str(e)}"
    except TypeError as e:
        # json.loads rejects anything that is not str/bytes/bytearray
        return None, f"Invalid JSON format: {e}"
    # Syntactically valid JSON can still be a list, string, number or null.
    # The prompt builders call .get() on the result, so only objects are
    # acceptable here.
    if not isinstance(data, dict):
        return None, (
            "Invalid JSON format: expected a JSON object "
            f"(got {type(data).__name__})"
        )
    return data, None
def generate_advanced_prompt(ad_data):
    """Generate the positive advertising prompt from the concept dict.

    Args:
        ad_data: Parsed concept. Keys read: ``product_name``,
            ``description``, ``style``, ``mood``, ``target_audience``,
            ``colors``, ``setting``, ``composition``, ``lighting``. Missing,
            null, or blank values fall back to the built-in defaults.

    Returns:
        A single comma-separated prompt string: the subject and visual
        direction followed by a fixed list of quality tags.
    """
    def _text(key, default):
        # dict.get's default only covers a *missing* key; a present but
        # empty/null value would otherwise yield "photography of " or
        # "None style" in the prompt.
        value = ad_data.get(key)
        if value is None:
            return default
        return str(value).strip() or default

    product = _text('product_name', 'product')
    description = _text('description', '')
    style = _text('style', 'professional')
    mood = _text('mood', 'energetic')
    target_audience = _text('target_audience', 'general audience')
    setting = _text('setting', 'studio')
    composition = _text('composition', 'centered')
    lighting = _text('lighting', 'soft studio lighting')

    # Accept a single colour given as a bare string (joining a str would
    # split it into characters) and tolerate non-string / empty entries.
    colors = ad_data.get('colors') or []
    if not isinstance(colors, (list, tuple)):
        colors = [colors]
    color_names = [
        str(color).strip()
        for color in colors
        if color is not None and str(color).strip()
    ]

    # NOTE(review): "key_features" appears in the sample templates and the
    # schema reference but is not used when building the prompt — confirm
    # whether it should be.

    # Main subject
    prompt_parts = [f"Professional advertising photography of {product}"]
    if description:
        prompt_parts.append(description)
    # Style and mood
    prompt_parts.append(f"{style} style, {mood} mood")
    # Target audience context (the generic default adds no information)
    if target_audience != 'general audience':
        prompt_parts.append(f"appealing to {target_audience}")
    # Visual elements
    if color_names:
        prompt_parts.append(f"color palette: {', '.join(color_names)}")
    prompt_parts.append(f"{setting} setting")
    prompt_parts.append(f"{composition} composition")
    prompt_parts.append(lighting)
    # Quality tags
    quality_tags = [
        "high-end commercial photography",
        "8k resolution",
        "sharp focus",
        "professional color grading",
        "dramatic lighting",
        "photorealistic",
        "award-winning advertisement",
    ]
    return ", ".join(prompt_parts + quality_tags)
def generate_negative_prompt(ad_data):
    """Generate the negative prompt listing elements to avoid.

    Args:
        ad_data: Parsed concept. The optional ``avoid`` key may hold a list
            of terms or a single string; empty and null entries are ignored.

    Returns:
        A comma-separated string: the built-in exclusions followed by any
        custom ``avoid`` terms.
    """
    base_negative = [
        "low quality", "blurry", "distorted", "ugly", "amateur",
        "poor composition", "bad lighting", "oversaturated",
        "watermark", "text", "signature", "logo overlay",
        "cropped", "out of frame", "deformed", "disfigured",
        "extra limbs", "poorly rendered", "cartoon", "anime"
    ]
    # Add custom negative prompts if specified
    avoid = ad_data.get('avoid')
    if avoid:
        # A bare string must be wrapped: list.extend() on a str would add it
        # one character at a time.
        if not isinstance(avoid, (list, tuple)):
            avoid = [avoid]
        # Coerce to str so numeric/odd entries cannot break the join below
        base_negative.extend(
            str(term).strip()
            for term in avoid
            if term is not None and str(term).strip()
        )
    return ", ".join(base_negative)
def call_huggingface_api(prompt, negative_prompt, api_token, width=1024, height=1024):
    """Call the Hugging Face Inference API for FLUX.1-schnell.

    Args:
        prompt: Positive text prompt sent as the request ``inputs``.
        negative_prompt: Text sent as the ``negative_prompt`` parameter.
        api_token: Hugging Face token, sent as a Bearer authorization header.
        width: Requested image width in pixels (default 1024).
        height: Requested image height in pixels (default 1024).

    Returns:
        ``(image, None)`` with a PIL image on success, or ``(None, message)``
        with a user-facing message on failure. Request and decoding errors
        are reported through the message rather than raised.
    """
    API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
    headers = {
        "Authorization": f"Bearer {api_token}",
        "Content-Type": "application/json"
    }
    payload = {
        "inputs": prompt,
        "parameters": {
            "negative_prompt": negative_prompt,
            "num_inference_steps": 4,  # Schnell optimized for 1-4 steps
            "guidance_scale": 0.0,  # Schnell doesn't use guidance
            "width": width,
            "height": height
        }
    }
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
        if response.status_code == 200:
            image = Image.open(io.BytesIO(response.content))
            # Image.open only reads the header; decode now so a corrupt or
            # truncated body is reported here instead of at display time.
            image.load()
            return image, None
        elif response.status_code == 503:
            return None, "Model is loading. Please wait 20 seconds and try again."
        elif response.status_code == 401:
            return None, "Invalid API token. Please check your Hugging Face token."
        elif response.status_code == 402:
            return None, "You've exceeded your free tier limit. Consider upgrading to PRO or wait for next month."
        else:
            # Error bodies are not guaranteed to be a JSON object (e.g. an
            # HTML gateway page); fall back to the raw text so the status
            # code is still reported.
            try:
                error_msg = response.json().get('error', 'Unknown error')
            except (ValueError, AttributeError):
                error_msg = response.text.strip()[:200] or 'Unknown error'
            return None, f"API Error ({response.status_code}): {error_msg}"
    except requests.Timeout:
        return None, "Request timed out. The model might be loading. Please try again."
    except Exception as e:
        # UI boundary: surface any other failure as a message
        return None, f"Error: {str(e)}"
def image_to_base64(image):
    """Serialize *image* as PNG and return the bytes base64-encoded as text."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()
# Header: gradient title followed by the tagline, both rendered as raw HTML
for _header_html in (
    '<h1 class="main-header">🎨 AdForge AI</h1>',
    '<p class="sub-header">Transform your concept into stunning advertisements with AI</p>',
):
    st.markdown(_header_html, unsafe_allow_html=True)
# Sidebar Configuration: API token entry, generation settings, and help text
with st.sidebar:
    st.header("⚙️ Configuration")
    # API Token Input
    st.subheader("🔑 Hugging Face API Token")
    api_token_input = st.text_input(
        "Enter your HF token",
        type="password",
        value=st.session_state.api_token,
        help="Get your free token from https://huggingface.co/settings/tokens"
    )
    # Always mirror the field into session state. Updating only on non-empty
    # input left the previous token active after the user cleared the field,
    # so the warning below was shown while generation stayed enabled.
    # Stripping removes whitespace that tends to come along with a paste.
    st.session_state.api_token = api_token_input.strip()
    if st.session_state.api_token:
        st.success("✅ Token configured!")
    else:
        st.warning("⚠️ Please enter your Hugging Face API token to generate images")
    st.markdown("---")
    # Advanced Settings
    st.subheader("🎛️ Generation Settings")
    num_variations = st.slider(
        "Number of Variations",
        min_value=1,
        max_value=4,
        value=1,
        help="Generate multiple variations of your ad"
    )
    # NOTE(review): image_size is not read anywhere else in the visible code,
    # and the API payload uses a fixed 1024x1024 — confirm intended wiring.
    image_size = st.selectbox(
        "Image Resolution",
        options=["1024x1024 (Square)", "1024x768 (Landscape)", "768x1024 (Portrait)"],
        index=0
    )
    st.markdown("---")
    # Model Info
    st.subheader("ℹ️ About")
    st.info("""
**Model:** FLUX.1-schnell
**Speed:** 2-5 seconds
**Quality:** Professional-grade
**License:** Apache 2.0
FLUX.1-schnell is optimized for rapid, high-quality image generation perfect for advertising.
""")
    # Quick Guide
    with st.expander("📖 Quick Guide"):
        st.markdown("""
**Steps to Generate:**
1. Enter your HF API token above
2. Fill in the JSON with your ad concept
3. Click "Generate Advertisement"
4. Download your results!
**Get API Token:**
1. Visit [Hugging Face](https://huggingface.co)
2. Sign up/Login
3. Go to Settings → Access Tokens
4. Create new token with "Make calls to Inference API" permission
""")
# Main Content Area: left column collects the concept and runs generation,
# right column (col2) is filled in further down the script.
col1, col2 = st.columns([1, 1], gap="large")
with col1:
    st.subheader("📝 Advertisement Concept (JSON)")
    # Sample JSON templates offered as starting points for the editor
    sample_templates = {
        "Tech Product": {
            "product_name": "AirPods Pro Max",
            "description": "Premium wireless headphones with spatial audio and active noise cancellation",
            "style": "minimalist and modern",
            "mood": "sophisticated and premium",
            "target_audience": "tech enthusiasts and professionals",
            "colors": ["space gray", "silver", "white"],
            "setting": "modern studio with gradient background",
            "composition": "centered with dramatic angle",
            "lighting": "soft key light with rim lighting",
            "key_features": ["noise cancellation", "spatial audio", "premium build"]
        },
        "Food & Beverage": {
            "product_name": "Artisan Cold Brew Coffee",
            "description": "Smooth cold brew coffee with hints of chocolate and caramel",
            "style": "rustic and organic",
            "mood": "warm and inviting",
            "target_audience": "coffee lovers and millennials",
            "colors": ["rich brown", "cream", "amber"],
            "setting": "wooden table with natural elements",
            "composition": "rule of thirds with foreground elements",
            "lighting": "natural morning light with soft shadows",
            "key_features": ["organic", "small batch", "sustainably sourced"]
        },
        "Fashion": {
            "product_name": "Summer Collection Dress",
            "description": "Elegant flowing summer dress with floral patterns",
            "style": "editorial fashion",
            "mood": "elegant and breezy",
            "target_audience": "fashion-forward women 25-40",
            "colors": ["pastel pink", "soft white", "gold accents"],
            "setting": "outdoor garden at golden hour",
            "composition": "full body shot with movement",
            "lighting": "golden hour backlight with fill",
            "key_features": ["sustainable fabric", "handcrafted", "limited edition"]
        },
        "Automotive": {
            "product_name": "Electric Sports Car",
            "description": "Sleek electric vehicle with cutting-edge design and performance",
            "style": "dynamic and futuristic",
            "mood": "powerful and innovative",
            "target_audience": "luxury car enthusiasts",
            "colors": ["metallic blue", "chrome", "carbon black"],
            "setting": "modern urban environment at night",
            "composition": "low angle hero shot",
            "lighting": "dramatic city lights with reflections",
            "key_features": ["0-60 in 2.5s", "400 mile range", "autopilot"]
        }
    }
    template_choice = st.selectbox(
        "📋 Choose a template or start from scratch",
        options=["Custom"] + list(sample_templates.keys())
    )
    if template_choice != "Custom":
        default_json = json.dumps(sample_templates[template_choice], indent=2)
    else:
        # Blank skeleton carrying the same defaults the prompt builder uses
        default_json = json.dumps({
            "product_name": "",
            "description": "",
            "style": "professional",
            "mood": "energetic",
            "target_audience": "",
            "colors": [],
            "setting": "studio",
            "composition": "centered",
            "lighting": "soft studio lighting"
        }, indent=2)
    json_input = st.text_area(
        "Paste or edit your advertisement concept:",
        value=default_json,
        height=400,
        help="Provide details about your product and desired ad style"
    )
    # Generate Button (disabled until a token has been entered)
    generate_btn = st.button(
        "🚀 Generate Advertisement",
        type="primary",
        disabled=not st.session_state.api_token,
        use_container_width=True
    )
    if generate_btn:
        if not st.session_state.api_token:
            st.error("❌ Please enter your Hugging Face API token in the sidebar first!")
        else:
            # Validate JSON
            ad_data, error = validate_json(json_input)
            # Valid JSON that is not an object (list, string, number) cannot
            # be used by the prompt builders, which call .get() on it.
            if not error and not isinstance(ad_data, dict):
                error = "Invalid JSON format: the advertisement concept must be a JSON object"
            if error:
                st.error(f"❌ {error}")
            else:
                st.markdown('<div class="success-box">✅ JSON validated successfully!</div>', unsafe_allow_html=True)
                # Generate prompts
                main_prompt = generate_advanced_prompt(ad_data)
                negative_prompt = generate_negative_prompt(ad_data)
                # Display prompts
                with st.expander("🔍 View Generated Prompts"):
                    st.write("**Positive Prompt:**")
                    st.code(main_prompt, language=None)
                    st.write("**Negative Prompt:**")
                    st.code(negative_prompt, language=None)
                # Generate images
                st.markdown('<div class="info-box">⏳ Generating your advertisement... This takes 5-10 seconds.</div>', unsafe_allow_html=True)
                progress_bar = st.progress(0)
                generated_images = []
                for i in range(num_variations):
                    with st.spinner(f"Creating variation {i+1}/{num_variations}..."):
                        image, error = call_huggingface_api(
                            main_prompt,
                            negative_prompt,
                            st.session_state.api_token
                        )
                    if error:
                        st.error(f"❌ Variation {i+1} failed: {error}")
                        if "loading" in error.lower():
                            st.info("💡 Tip: The model is warming up. Wait 20 seconds and try again.")
                        # Stop at the first failure; the bar stays at the
                        # fraction that actually completed.
                        break
                    generated_images.append({
                        'image': image,
                        'prompt': main_prompt,
                        'timestamp': datetime.now(),
                        'ad_data': ad_data
                    })
                    # Advance only after the variation has been produced, so
                    # the bar never reports work that has not happened yet.
                    progress_bar.progress((i + 1) / num_variations)
                    if i + 1 < num_variations:
                        time.sleep(1)  # Small delay between requests
                if generated_images:
                    # Replace the displayed results and record the run for
                    # the footer statistics
                    st.session_state.generated_images = generated_images
                    st.session_state.generation_history.append({
                        'timestamp': datetime.now(),
                        'count': len(generated_images),
                        'product': ad_data.get('product_name', 'Unknown')
                    })
                    st.success(f"✨ Successfully generated {len(generated_images)} advertisement(s)!")
                    st.balloons()
# Right column: shows the images stored in session state, each with a
# download button and the concept it was generated from.
with col2:
    st.subheader("🖼️ Generated Advertisements")
    if st.session_state.generated_images:
        for idx, item in enumerate(st.session_state.generated_images):
            st.image(
                item['image'],
                caption=f"Advertisement {idx + 1} - {item['ad_data'].get('product_name', 'Product')}",
                use_container_width=True
            )
            # Download and info buttons
            col_a, col_b = st.columns(2)
            with col_a:
                # Convert image to PNG bytes for download
                img_buffer = io.BytesIO()
                item['image'].save(img_buffer, format='PNG')
                img_buffer.seek(0)
                st.download_button(
                    label=f"⬇️ Download Ad {idx + 1}",
                    data=img_buffer,
                    file_name=f"adforge_ad_{idx+1}_{item['timestamp'].strftime('%Y%m%d_%H%M%S')}.png",
                    mime="image/png",
                    key=f"download_{idx}",  # unique key per image
                    use_container_width=True
                )
            with col_b:
                with st.expander("ℹ️ Details"):
                    st.json(item['ad_data'])
            st.markdown("---")
        # Clear button: empties the displayed images (the history used by the
        # footer statistics is left untouched) and reruns to refresh the view
        if st.button("🗑️ Clear All", use_container_width=True):
            st.session_state.generated_images = []
            st.rerun()
    else:
        st.info("👈 Configure your advertisement concept and click 'Generate Advertisement' to see results here")
        # Show example
        st.markdown("### 🎯 Example Output")
        st.markdown("""
Your generated advertisements will appear here with:
- ✨ High-resolution professional images
- 📥 One-click download buttons
- 📊 Detailed generation information
- 🎨 Multiple variations (if selected)
""")
# JSON Schema Documentation: collapsible reference for the concept fields.
# The fenced example is annotated with // comments and | alternatives, so it
# is a reference rather than input that json.loads would accept as-is.
with st.expander("📚 JSON Schema Reference"):
    st.markdown("""
### Complete JSON Schema for Advertisement Generation
```json
{
"product_name": "Your Product Name", // REQUIRED: Product identifier
"description": "Detailed product description", // REQUIRED: What the product is
"style": "minimalist | professional | vibrant | editorial | rustic",
"mood": "energetic | calm | luxurious | playful | sophisticated",
"target_audience": "Description of target demographic",
"colors": ["color1", "color2", "color3"], // Preferred color palette
"setting": "studio | outdoor | urban | natural",
"composition": "centered | rule of thirds | dynamic",
"lighting": "soft | dramatic | natural | studio",
"key_features": ["feature1", "feature2"], // Product highlights
"avoid": ["element1", "element2"] // Optional: Things to exclude
}
```
### Field Descriptions
**Required Fields:**
- `product_name`: The name of your product
- `description`: Brief but descriptive product overview
**Style Options:**
- `minimalist`: Clean, simple, modern aesthetics
- `professional`: Corporate, polished look
- `vibrant`: Colorful, energetic, eye-catching
- `editorial`: Magazine-quality, artistic
- `rustic`: Natural, organic, handcrafted feel
**Mood Options:**
- `energetic`: Dynamic, active, exciting
- `calm`: Peaceful, serene, relaxing
- `luxurious`: Premium, high-end, exclusive
- `playful`: Fun, lighthearted, creative
- `sophisticated`: Refined, elegant, classy
**Pro Tips:**
- Be specific with colors (e.g., "ocean blue" instead of just "blue")
- Combine multiple style elements for unique results
- Use the `avoid` field to exclude unwanted elements
- Experiment with different composition and lighting combinations
""")
# Footer with statistics derived from this session's generation history
st.markdown("---")
col_foot1, col_foot2, col_foot3 = st.columns(3)
with col_foot1:
    # One history entry is appended per successful "Generate" run
    st.metric("Ads Generated This Session", len(st.session_state.generation_history))
with col_foot2:
    # Each entry stores how many images that run produced
    total_images = sum(item['count'] for item in st.session_state.generation_history)
    st.metric("Total Images Created", total_images)
with col_foot3:
    # Only shown once at least one run has succeeded
    if st.session_state.generation_history:
        last_gen = st.session_state.generation_history[-1]['product']
        st.metric("Last Product", last_gen)
st.markdown("""
<div style='text-align: center; color: gray; padding: 2rem 0;'>
<p><strong>AdForge AI</strong> - Powered by FLUX.1-schnell & Hugging Face</p>
<p>🌟 Create professional advertisements in seconds | 🚀 Free to use | 💎 Commercial license</p>
</div>
""", unsafe_allow_html=True)