# Author: Jayal04
# Last commit: "Update app.py" (e28e0cc, verified)
import os
import time
from datetime import date
from io import BytesIO

import face_alignment
import numpy as np
import streamlit as st
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image, ImageOps
# --- Configuration ---
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Stable Diffusion base model and the age-slider LoRA applied on top of it.
# The original "runwayml/stable-diffusion-v1-5" repository was removed from the
# Hugging Face Hub; this is the community-maintained mirror of the same v1.5
# weights, so `from_pretrained` keeps working on a fresh machine.
BASE_MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"
LORA_MODEL_ID = "navmesh/Lora"
LORA_WEIGHT_NAME = "age_slider-LECO-v1.safetensors"
# --- Streamlit App Setup ---
# Page-level settings: wide layout so the original/transformed columns sit side by
# side, a custom title and hourglass icon, and the sidebar open on first load.
_PAGE_SETTINGS = {
    "layout": "wide",
    "page_title": "AI Age Progression Studio",
    "page_icon": "⏳",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)

# Cosmetic CSS injected as raw HTML (purely visual; safe to drop).
# NOTE(review): selectors such as `.reportview-container` and
# `.sidebar .sidebar-content` look like they come from older Streamlit releases
# and may no longer match any element — verify against the current DOM.
_CUSTOM_CSS = """
<style>
.reportview-container {
background: #f0f2f6; /* Light gray background */
}
.sidebar .sidebar-content {
background: #ffffff; /* White sidebar */
}
h1, h2, h3, h4, h5, h6 {
color: #262730; /* Darker text for headings */
}
.stButton>button {
background-color: #4CAF50; /* Green button */
color: white;
border-radius: 8px;
border: none;
padding: 10px 20px;
font-size: 16px;
cursor: pointer;
transition: background-color 0.3s ease;
}
.stButton>button:hover {
background-color: #45a049;
}
.stRadio > label > div {
padding: 5px 0;
}
.stSlider > div > div > div > div {
background-color: #4CAF50; /* Slider track color */
}
.stSlider [data-baseweb="slider"] {
background-color: #e0e0e0; /* Slider background */
}
.stSlider [data-baseweb="slider"] > div:first-child {
background-color: #4CAF50; /* Slider thumb color */
}
.stFileUploader label {
font-weight: bold;
}
.stAlert {
border-radius: 8px;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# Headline and one-line call to action.
st.title("⏳ AI Age Progression Studio")
_INTRO_MARKDOWN = """
**Upload a photo and guide the AI to transform its age!**
"""
st.markdown(_INTRO_MARKDOWN)
# --- Helper Functions ---
# Helper functions to align face
def compute_square_crop_box(landmarks, image_size, padding_factor):
    """Return a square ``(x1, y1, x2, y2)`` crop box around facial landmarks.

    The box is centred on the landmarks' bounding box and enlarged by
    ``padding_factor``. It is then *shifted* (not clipped) so it lies fully
    inside the image, which keeps it square. A square box means the later
    resize to a square target does not stretch the face. If the padded box is
    larger than the image's shorter side, it is shrunk to that side.

    Args:
        landmarks: Array-like whose rows start with ``x, y`` pixel coordinates.
        image_size: ``(width, height)`` of the source image in pixels.
        padding_factor: Extra margin relative to the face size (0.6 -> +60%).

    Returns:
        Integer pixel coordinates ``(x1, y1, x2, y2)`` with ``x2 - x1 == y2 - y1``.
    """
    points = np.asarray(landmarks, dtype=float)
    x_min, y_min = points[:, :2].min(axis=0)
    x_max, y_max = points[:, :2].max(axis=0)
    img_width, img_height = image_size

    center_x = (x_min + x_max) / 2
    center_y = (y_min + y_max) / 2
    padded_side = max(x_max - x_min, y_max - y_min) * (1 + padding_factor)
    # A square window can be no larger than the image's shorter side.
    side = max(1, min(int(padded_side), img_width, img_height))

    # Centre the window on the face, then slide it back inside the image bounds.
    x1 = min(max(int(center_x - side / 2), 0), img_width - side)
    y1 = min(max(int(center_y - side / 2), 0), img_height - side)
    return x1, y1, x1 + side, y1 + side


def align_face_for_input(image_pil, fa_model, target_size=(512, 512), padding_factor=0.6):
    """Crop a square region around the first detected face and resize it.

    Args:
        image_pil: Source PIL image (the app passes an RGB image).
        fa_model: Landmark detector exposing ``get_landmarks(ndarray)``, e.g. a
            ``face_alignment.FaceAlignment`` instance.
        target_size: ``(width, height)`` of the returned image.
        padding_factor: Extra margin around the face, relative to its size.

    Returns:
        ``(aligned_image, (x1, y1, x2, y2))`` on success, otherwise
        ``(None, None)``. On failure, a Streamlit warning (no face) or error
        (exception) is shown.
    """
    try:
        preds = fa_model.get_landmarks(np.array(image_pil))
        if preds is None or len(preds) == 0:
            st.warning("⚠️ No face detected by face_alignment model. Please ensure the image clearly shows a face.")
            return None, None

        # Only the first detected face is used when several are present.
        box = compute_square_crop_box(preds[0], image_pil.size, padding_factor)
        aligned_image = image_pil.crop(box).resize(target_size, Image.Resampling.LANCZOS)
        return aligned_image, box
    except Exception as e:  # UI boundary: report the problem instead of crashing the app
        st.error(f"An error occurred during face alignment: {e}")
        return None, None
# Core Model Loading Logic
@st.cache_resource
def load_models():
    """Create the img2img diffusion pipeline and the face-landmark model.

    Wrapped in ``st.cache_resource`` so the heavy models are built once and
    reused on later script reruns instead of being reloaded each time.

    Returns:
        ``(pipe, fa)``: the Stable Diffusion img2img pipeline (moved to
        ``DEVICE``, with the age-slider LoRA loaded) and a 2-D
        ``face_alignment.FaceAlignment`` model on ``DEVICE``.
    """
    loading_message = "⏳ Loading AI models... This might take a moment (especially the first time)."
    # Half precision only when a GPU is present; the CPU path keeps float32.
    weights_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    with st.spinner(loading_message):
        diffusion = StableDiffusionImg2ImgPipeline.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=weights_dtype,
        ).to(DEVICE)
        diffusion.load_lora_weights(LORA_MODEL_ID, weight_name=LORA_WEIGHT_NAME)
        st.success("✅ Main Diffusion model and LoRA loaded.")

        landmark_model = face_alignment.FaceAlignment(
            face_alignment.LandmarksType.TWO_D,
            flip_input=False,
            device=DEVICE,
        )
        st.success("✅ Face Alignment model loaded.")

    return diffusion, landmark_model
# Negative prompt used for every generation unless the caller overrides it.
# NOTE(review): a plain diffusers pipeline hands this text straight to the
# tokenizer, so "(gender opposite to prompt)" is literal text, not an instruction —
# consider replacing it with the concrete opposite gender.
DEFAULT_NEGATIVE_PROMPT = "cartoon, illustration, 3d render, anime, painting, blurry, distorted, low quality, bad anatomy, deformed, strange eyes, uncanny valley, watermark, text, signature, unnatural, extra fingers, disfigured, (gender opposite to prompt), cropped, close-up, disproportioned, multiple people, low resolution, bad quality"


def perform_age_transformation_diffusion(
    input_image_pil,
    final_prompt,
    strength_value,
    *,
    diffusion_pipe=None,
    face_model=None,
    num_inference_steps=50,
    guidance_scale=6.5,
    seed=42,
    negative_prompt=DEFAULT_NEGATIVE_PROMPT,
):
    """Age-transform the face in an image with the img2img diffusion pipeline.

    The face is first cropped and resized to 512x512 by ``align_face_for_input``.
    That crop is the init image for the pipeline.

    Args:
        input_image_pil: The uploaded photo as a PIL image.
        final_prompt: Text prompt describing the target age/appearance.
        strength_value: img2img strength; higher values allow larger changes.
        diffusion_pipe: Pipeline to call; defaults to the module-level ``pipe``.
        face_model: Landmark model; defaults to the module-level ``fa_model``.
        num_inference_steps: Forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline.
        seed: Seed for the torch generator, so identical inputs reproduce.
        negative_prompt: Forwarded to the pipeline.

    Returns:
        The generated PIL image, or ``None`` if alignment or generation failed.
        The reason is reported in the Streamlit UI.
    """
    # Resolved at call time: the module-level `pipe`/`fa_model` are assigned
    # after this function is defined.
    diffusion_pipe = pipe if diffusion_pipe is None else diffusion_pipe
    face_model = fa_model if face_model is None else face_model

    aligned_input_pil, _ = align_face_for_input(input_image_pil, face_model, target_size=(512, 512))
    if aligned_input_pil is None:
        st.error("❌ Could not align face for diffusion. Please ensure the image clearly shows a single face.")
        return None

    try:
        # Seeded generator -> same output for the same inputs and settings.
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
        result = diffusion_pipe(
            prompt=final_prompt,
            image=aligned_input_pil,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            strength=strength_value,
            generator=generator,
            negative_prompt=negative_prompt,
        )
        return result.images[0]
    except Exception as e:  # UI boundary: surface any pipeline failure to the user
        st.error(f"An error occurred during image generation: {e}")
        st.info("💡 Tip: Try a different image or adjust the 'Transformation Strength' and prompt/settings.")
        return None
# --- Main Streamlit App Logic ---
def build_age_prompt(age, gender):
    """Build the diffusion prompt for a target age and the sidebar gender choice.

    Age bands: baby (< 2), young child (2-5), child (6-12), teenager (13-19),
    exact age (20-69), elderly (70-90), very old (> 90). Each band yields one
    complete noun phrase, e.g. "an elderly woman" or "a 30 year old man".

    Args:
        age: Target age in whole years (the UI inputs guarantee >= 1).
        gender: "Male", "Female", or anything else (treated as unspecified).

    Returns:
        The full prompt string passed to the diffusion pipeline.
    """
    nouns_by_gender = {"Male": ("boy", "man"), "Female": ("girl", "woman")}
    gendered = gender in nouns_by_gender
    young_noun, adult_noun = nouns_by_gender.get(gender, ("child", "person"))

    # Build the article + adjective + noun together, so a gender word is never
    # appended to a phrase that already ends in a noun ("a teenager man").
    if age < 2:
        subject = f"a baby {young_noun}" if gendered else "a baby"
    elif age <= 5:
        subject = f"a young {young_noun}"
    elif age <= 12:
        subject = f"a {young_noun}"
    elif age <= 19:
        subject = f"a teenage {young_noun}" if gendered else "a teenager"
    elif age < 70:
        subject = f"a {age} year old {adult_noun}"
    elif age <= 90:
        subject = f"an elderly {adult_noun}"
    else:
        subject = f"a very old {adult_noun}"
    return (
        f"a photo of {subject}, detailed, realistic, high quality, "
        "professional photography, studio lighting, clear face, natural skin texture"
    )


def full_years_between(earlier, later):
    """Return the number of complete years from date ``earlier`` to date ``later``.

    A year counts only once its anniversary has been reached. For example, a
    photo dated 2023-12-31 is 0 years old on 2024-01-01.
    """
    anniversary_pending = (later.month, later.day) < (earlier.month, earlier.day)
    return later.year - earlier.year - int(anniversary_pending)


# Load models (cached across reruns by load_models' st.cache_resource decorator)
pipe, fa_model = load_models()

# Sidebar for inputs
with st.sidebar:
    st.header("📸 Upload Your Photo")
    uploaded_file = st.file_uploader(
        "Upload a clear photo of a face:",
        type=["jpg", "jpeg", "png"],
        help="For best results, use a well-lit photo with the face clearly visible.",
    )
    st.markdown("---")

    st.header("⚙️ Transformation Settings")
    transformation_mode = st.radio(
        "**Select Transformation Method:**",
        ["Transform to Specific Age", "Transform with Custom Prompt", "Transform to Current Age (Dynamic)"],
        index=0,  # Default to "Transform to Specific Age"
        help="Choose how you want to define the target age for the AI.",
    )

    st.subheader("Common Adjustments")
    gender_of_person = st.radio(
        "Gender of person in photo:",
        ["Unspecified", "Male", "Female"],
        index=0,
        help="Helps the AI maintain gender identity, especially during significant age shifts. This applies to 'Specific Age' and 'Dynamic Age' modes.",
    )
    strength_value = st.slider(
        "Transformation Strength:",
        # Not 0.0: img2img runs int(steps * strength) denoising steps, so a
        # strength of 0 leaves nothing to run and generation fails.
        min_value=0.05,
        max_value=1.0,
        value=0.6,
        step=0.05,
        help="Higher values allow more dramatic changes but may alter identity. Lower values preserve identity but make subtle changes.",
    )

    # --- Conditional Inputs based on Mode ---
    target_age_display = "N/A"  # For display in caption later
    if transformation_mode == "Transform to Specific Age":
        st.subheader("🔢 Specific Age Settings")
        target_age_slider = st.slider(
            "Select Target Age:",
            min_value=1,
            max_value=90,
            value=30,
            step=1,
            help="Choose the exact age you'd like the person in the photo to appear as.",
        )
        target_age_display = f"{target_age_slider} years"
        final_prompt_for_generation = build_age_prompt(target_age_slider, gender_of_person)

    elif transformation_mode == "Transform with Custom Prompt":
        st.subheader("📝 Custom Prompt Settings")
        custom_prompt_input = st.text_area(
            "Enter your age transformation prompt:",
            value="a photo of a 70 year old man, detailed, realistic, high quality, natural skin texture",
            height=100,
            help="Describe the desired age and characteristics, e.g., 'a photo of an 80 year old woman with wrinkles and grey hair'. Gender and age from this prompt will override common settings.",
        )
        final_prompt_for_generation = custom_prompt_input
        target_age_display = "Custom Prompt"

    elif transformation_mode == "Transform to Current Age (Dynamic)":
        st.subheader("⏳ Dynamic Age Settings")
        image_taken_date = st.date_input(
            "Date when the photo was taken:",
            value=date(2015, 1, 1),
            # Streamlit otherwise limits the range to 10 years before `value`,
            # which would block older photos.
            min_value=date(1900, 1, 1),
            max_value=date.today(),
            help="Select the approximate date the original photo was captured.",
        )
        age_in_photo = st.number_input(
            "Age of person in this photo (on that date):",
            min_value=1,
            max_value=90,
            value=20,
            step=1,
            help="Enter the age of the person depicted in the uploaded photo on the selected date.",
        )
        # Add only the whole years that have elapsed since the photo was taken.
        calculated_target_age = age_in_photo + full_years_between(image_taken_date, date.today())
        st.info(f"Calculated target age: **{calculated_target_age} years old**")
        target_age_display = f"{calculated_target_age} years (calculated)"
        final_prompt_for_generation = build_age_prompt(calculated_target_age, gender_of_person)

    st.markdown("---")
    generate_button = st.button("✨ Generate Transformed Image ✨", use_container_width=True)
    st.markdown("---")
    st.markdown("Developed with ❤️")
# Main content area
if uploaded_file is not None:
    try:
        # Apply the EXIF orientation tag (common on phone photos) so the image is
        # upright for display and face detection, then force 3-channel RGB.
        input_image_pil = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
    except OSError as err:  # PIL.UnidentifiedImageError is a subclass of OSError
        st.error(f"❌ Could not read the uploaded image: {err}")
        st.stop()

    # Create two columns for side-by-side display
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Original Photo 🖼️")
        st.image(input_image_pil, caption="Your uploaded image", use_container_width=True)

    with col2:
        if generate_button:
            # Blocking validation: any message stored here skips generation.
            validation_error = None
            if transformation_mode == "Transform with Custom Prompt" and not final_prompt_for_generation.strip():
                validation_error = "❌ Please enter a custom prompt for age transformation!"
            elif transformation_mode == "Transform to Current Age (Dynamic)":
                if calculated_target_age <= 0:
                    validation_error = (
                        "❌ Calculated age is 0 or less. Please ensure the 'Date when photo was taken' "
                        "and 'Age of person in this photo' are correct."
                    )
                elif calculated_target_age > 90:
                    # Non-blocking; threshold matches the "above ~90" wording below.
                    st.warning(f"⚠️ Calculated age ({calculated_target_age}) is very high. AI models may not produce realistic results for ages above ~90. Proceeding with generation, but results may be less accurate.")

            if validation_error is not None:
                st.error(validation_error)
            else:
                with st.spinner("⏳ Generating transformed image..."):
                    start_time = time.perf_counter()
                    aged_image = perform_age_transformation_diffusion(
                        input_image_pil,
                        final_prompt_for_generation,
                        strength_value,
                    )
                    elapsed_seconds = time.perf_counter() - start_time

                if aged_image is not None:
                    # Only report success when an image was actually produced.
                    st.success(f"✅ Image generated in {elapsed_seconds:.2f} seconds!")
                    if transformation_mode == "Transform with Custom Prompt":
                        caption_text = f"Transformed image (Prompt: '{final_prompt_for_generation}')"
                    else:
                        caption_text = f"Transformed image (Target: {target_age_display})"
                    st.subheader("Transformed Photo! ✨")
                    st.image(aged_image, caption=caption_text, use_container_width=True)

                    # Download button: encode the result as PNG in memory.
                    buf_image = BytesIO()
                    aged_image.save(buf_image, format="PNG")
                    file_label = target_age_display.replace(" ", "_").replace("(", "").replace(")", "")
                    st.download_button(
                        label="⬇️ Download Transformed Image (512x512)",
                        data=buf_image.getvalue(),
                        file_name=f"transformed_image_{file_label}.png",
                        mime="image/png",
                        use_container_width=True,
                    )
                else:
                    st.error("❌ Failed to generate transformed image. Please check your inputs and try again.")
        else:
            st.subheader("Transformed Photo will appear here 🪄")
            st.info("Click 'Generate Transformed Image' to see the magic!")

    # --- Warnings about model limitations in a dropdown ---
    st.markdown("---")
    with st.expander("❓ Important Notes & Limitations"):
        st.warning(
            "\n"
            "* **Age Range:** The AI model may struggle with very young (1-5 years) or very old (80-90+ years) transformations.\n"
            "* **Gender Consistency:** Especially with large age gaps, the model might subtly or overtly change perceived gender.\n"
            "* **Identity Preservation:** Significant transformations (high 'Strength') can alter facial features, making the generated person look less like the original.\n"
            "* **Prompt Specificity (Custom Mode):** For 'Custom Prompt' mode, the quality of the output heavily depends on how well you describe the desired age and features.\n"
            "* **Output Zoom:** If the output still appears too zoomed, try using an input image where the face is not already extremely close-up.\n"
        )
        st.info("💡 **Tip:** Experiment with the settings! For best results, use clear, well-lit photos with a single, prominent face.")
else:
    st.info("⬆️ Upload a photo and choose a transformation method from the sidebar to get started!")