"""Streamlit front-end for the Fashn.ai virtual try-on API.

Upload a model (person) photo and a garment photo, submit both to the
Fashn.ai ``/v1/run`` endpoint, poll ``/v1/status/{id}`` until the job
finishes, then display the generated image with a download button.

The API key defaults to the ``FASHN_API_KEY`` environment variable (a
``.env`` file is loaded via python-dotenv) and can be overridden in the
sidebar.
"""
import base64
import io
import os
import time
import traceback

import requests
import streamlit as st
from dotenv import load_dotenv
from PIL import Image as PILImage

load_dotenv()

API_BASE_URL = 'https://api.fashn.ai/v1'
# Per-request timeout in seconds; without one, a stalled connection would
# block the Streamlit script indefinitely.
REQUEST_TIMEOUT = 60
POLL_INTERVAL_SECONDS = 2
MAX_POLL_ATTEMPTS = 60
# Job states that mean "not finished yet, keep polling".
# NOTE(review): 'starting' and 'in_queue' are assumed to be Fashn status
# values in addition to the originally handled 'queued'/'processing' —
# confirm against the Fashn API docs.
PENDING_STATUSES = frozenset({'starting', 'in_queue', 'queued', 'processing'})


def image_to_base64(image_file) -> str:
    """Encode an uploaded image file as a ``data:`` URI.

    Args:
        image_file: Object with a ``getvalue()`` method returning the raw
            image bytes (e.g. a Streamlit ``UploadedFile``).

    Returns:
        ``'data:image/<fmt>;base64,<payload>'`` where ``<fmt>`` is the
        lower-cased format detected by PIL, falling back to ``'jpeg'``.
    """
    image_bytes = image_file.getvalue()
    detected_format = PILImage.open(io.BytesIO(image_bytes)).format
    img_format = detected_format.lower() if detected_format else 'jpeg'
    encoded = base64.b64encode(image_bytes).decode('utf-8')
    return f'data:image/{img_format};base64,{encoded}'


def base64_to_image(base64_str: str) -> PILImage.Image:
    """Decode a base64 string, optionally a ``data:`` URI, into a PIL image."""
    # Drop a leading "data:image/...;base64," header if one is present.
    if ',' in base64_str:
        base64_str = base64_str.split(',')[1]
    return PILImage.open(io.BytesIO(base64.b64decode(base64_str)))


def _auth_headers(api_key: str) -> dict:
    """Return the bearer-auth header dict used for every Fashn.ai request."""
    return {'Authorization': f'Bearer {api_key}'}


def _show_result(result: dict) -> None:
    """Render the first image in ``result['output']`` plus a PNG download button.

    If the response has no (or an empty) ``output`` list, show an error and
    the raw response instead.
    """
    output = result.get('output') or []
    if not output:
        st.error('No output image found in the response')
        with st.expander('View API Response'):
            st.json(result)
        return

    generated_image = base64_to_image(output[0])
    st.success('✅ Virtual try-on generated successfully!')
    st.subheader('Result')
    st.image(generated_image, caption='Virtual Try-On Result', use_container_width=True)

    # Re-encode as PNG so the download matches the advertised file type.
    png_buffer = io.BytesIO()
    generated_image.save(png_buffer, format='PNG')
    st.download_button(
        label='📥 Download Result',
        data=png_buffer.getvalue(),
        file_name='virtual_tryon_result.png',
        mime='image/png',
    )


def _poll_job(job_id: str, api_key: str) -> None:
    """Poll the status endpoint for ``job_id`` until it completes, fails or times out.

    Shows a progress bar while waiting; it is always cleared on exit.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()
    # Single slot for "unknown status" notices so they overwrite each other
    # instead of piling up one per poll.
    notice = st.empty()
    try:
        for attempt in range(1, MAX_POLL_ATTEMPTS + 1):
            progress_bar.progress(min(attempt / MAX_POLL_ATTEMPTS, 0.99))
            status_text.text(f"Polling for results... Attempt {attempt}/{MAX_POLL_ATTEMPTS}")

            status_response = requests.get(
                f"{API_BASE_URL}/status/{job_id}",
                headers=_auth_headers(api_key),
                timeout=REQUEST_TIMEOUT,
            )
            if status_response.status_code != 200:
                st.error(f"Failed to get status: {status_response.status_code}")
                st.error(status_response.text)
                return

            status_result = status_response.json()
            status = status_result.get('status')

            if status == 'completed':
                notice.empty()
                _show_result(status_result)
                return
            if status == 'failed':
                notice.empty()
                st.error(f"❌ Job failed: {status_result.get('error', 'Unknown error')}")
                with st.expander('View API Response'):
                    st.json(status_result)
                return
            if status not in PENDING_STATUSES:
                with notice.container():
                    st.warning(f"Unknown status: {status}")
                    with st.expander('View API Response'):
                        st.json(status_result)

            # No point sleeping after the final attempt.
            if attempt < MAX_POLL_ATTEMPTS:
                time.sleep(POLL_INTERVAL_SECONDS)
        else:
            # Only reached when every attempt ran without completing/failing,
            # so a job that finishes on the last attempt is not misreported.
            st.error('⏱️ Timeout: Job took too long to complete. Please try again.')
    finally:
        progress_bar.empty()
        status_text.empty()


def _run_tryon(api_key: str, model_file, garment_file) -> None:
    """Submit a try-on job and render its result (sync or via polling)."""
    payload = {
        'model_name': 'tryon-v1.6',
        'inputs': {
            'model_image': image_to_base64(model_file),
            'garment_image': image_to_base64(garment_file),
            'category': 'auto',
            'segmentation_free': True,
            'moderation_level': 'conservative',
            'garment_photo_type': 'auto',
            'mode': 'balanced',
            'seed': 42,
            'num_samples': 1,
            'output_format': 'png',
            'return_base64': True,
        },
    }
    with st.spinner('Making this person wear something else'):
        response = requests.post(
            f'{API_BASE_URL}/run',
            headers={**_auth_headers(api_key), 'Content-Type': 'application/json'},
            json=payload,
            timeout=REQUEST_TIMEOUT,
        )
        if response.status_code != 200:
            st.error(f"❌ API Error: {response.status_code}")
            st.error(f"Message: {response.text}")
            return

        result = response.json()
        if 'id' in result:
            # Asynchronous job: poll until it finishes.
            _poll_job(result['id'], api_key)
        elif 'output' in result:
            # Synchronous response that already carries the output.
            _show_result(result)
        else:
            st.error('Unexpected response format')
            with st.expander('View API Response'):
                st.json(result)


# --- Page layout -----------------------------------------------------------

# Fall back to an empty string instead of raising KeyError at startup, so the
# "missing API key" message below is reachable and the key can be typed in.
api_key: str = st.sidebar.text_input(
    'Fashn.ai API key',
    value=os.environ.get('FASHN_API_KEY', ''),
    type='password',
)

col1, col2 = st.columns(2)

with col1:
    st.subheader('Model Image')
    model_file = st.file_uploader(
        'Upload model/person image', type=['png', 'jpg', 'jpeg'], key='model_uploader'
    )
    if model_file is not None:
        st.image(PILImage.open(model_file), caption='Uploaded Model Image', use_container_width=True)

with col2:
    st.subheader('Garment Image')
    garment_file = st.file_uploader(
        'Upload garment/clothing image', type=['png', 'jpg', 'jpeg'], key='garment_uploader'
    )
    if garment_file is not None:
        st.image(PILImage.open(garment_file), caption='Uploaded Garment Image', use_container_width=True)

st.markdown('---')

if st.button('🎨 Generate Virtual Try-On', type='primary', use_container_width=True):
    if not api_key:
        st.error('⚠️ Please enter your Fashn.ai API key in the sidebar!')
    elif model_file is None or garment_file is None:
        st.error('⚠️ Please upload both model and garment images!')
    else:
        try:
            _run_tryon(api_key, model_file, garment_file)
        except Exception as e:  # top-level UI boundary: surface, don't crash
            st.error(f"❌ An error occurred: {str(e)}")
            with st.expander('View Error Details'):
                st.code(traceback.format_exc())