# tyon / src/streamlit_app.py
# Last commit 5bc7631 by ayushsinghal1510: "Update src/streamlit_app.py" (verified)
import streamlit as st
import requests
import base64
import io
import os
import time
from PIL import Image as PILImage
from dotenv import load_dotenv
load_dotenv()
import base64
from io import BytesIO
from PIL import Image as PILImage
def image_to_base64(image_file) -> str :
    """Encode an uploaded image file as a base64 ``data:`` URI.

    Args:
        image_file: object exposing ``getvalue()`` that returns the raw image
            bytes (in this app, the file returned by ``st.file_uploader``).

    Returns:
        ``'data:image/<format>;base64,<payload>'``. ``<format>`` is the
        lower-cased format Pillow detects, or ``jpeg`` if Pillow reports none.
        The payload is the original bytes, unmodified.

    Raises:
        Whatever ``PILImage.open`` raises when the bytes are not a readable
        image.
    """
    image_bytes = image_file.getvalue()
    # Pillow is only used to sniff the container format. The bytes that get
    # encoded are the untouched originals. ``with`` releases the image handle.
    with PILImage.open(BytesIO(image_bytes)) as img:
        img_format : str = img.format.lower() if img.format else 'jpeg'
    # Pillow reports camera JPEGs that carry a multi-picture extension as
    # "MPO". The data is still a regular JPEG, and "image/mpo" is not a
    # standard MIME type, so label it as JPEG.
    if img_format == 'mpo':
        img_format = 'jpeg'
    base64_str = base64.b64encode(image_bytes).decode('utf-8')
    return f'data:image/{img_format};base64,{base64_str}'
def base64_to_image(base64_str : str) :
    """Turn a base64-encoded image string back into a Pillow image.

    Accepts either a bare base64 payload or a ``data:`` URI. When the input
    contains a comma, only the text after the first comma (up to any second
    comma) is decoded.
    """
    payload = base64_str.split(',')[1] if ',' in base64_str else base64_str
    decoded = base64.b64decode(payload)
    return PILImage.open(BytesIO(decoded))
# FASHN API key, read from the environment (``load_dotenv()`` earlier in the
# file also loads a local ``.env``). ``.get`` with an empty default avoids
# crashing at startup with a bare KeyError when the variable is missing. The
# Generate handler checks for an empty key and shows an error message instead.
api_key : str = os.environ.get('FASHN_API_KEY', '')
def _upload_panel(column, heading, prompt, widget_key, caption):
    """Render one upload column: a heading, a file picker and a preview.

    Returns the uploaded file object, or ``None`` while nothing is uploaded.
    """
    with column :
        st.subheader(heading)
        uploaded = st.file_uploader(
            prompt ,
            type = ['png' , 'jpg' , 'jpeg'] ,
            key = widget_key
        )
        if uploaded is not None :
            # Preview the upload at full column width.
            st.image(PILImage.open(uploaded) , caption = caption , use_container_width = True)
    return uploaded


# Two side-by-side upload areas: the person on the left, the garment on the right.
col1 , col2 = st.columns(2)
model_file = _upload_panel(
    col1 ,
    "Model Image" ,
    'Upload model/person image' ,
    'model_uploader' ,
    'Uploaded Model Image'
)
garment_file = _upload_panel(
    col2 ,
    'Garment Image' ,
    'Upload garment/clothing image' ,
    'garment_uploader' ,
    'Uploaded Garment Image'
)
st.markdown('---')
import traceback  # used by the error handler at the bottom of this section

# FASHN API endpoints and polling settings.
_FASHN_RUN_URL = 'https://api.fashn.ai/v1/run'
_FASHN_STATUS_URL = 'https://api.fashn.ai/v1/status/{job_id}'
_MAX_POLL_ATTEMPTS = 60
_POLL_INTERVAL_SECONDS = 2
# (connect, read) timeout in seconds for every HTTP call. requests has no
# default timeout, so without this a stalled connection hangs the app forever.
_REQUEST_TIMEOUT = (10, 60)
# Job states that mean "not finished yet, keep polling". 'queued' and
# 'processing' come from the original check. NOTE(review): 'starting' and
# 'in_queue' are the pre-processing states listed in the FASHN status API
# docs; confirm against the current API reference.
_PENDING_STATUSES = frozenset({'starting', 'in_queue', 'queued', 'processing'})


def _show_response(api_payload):
    """Show a raw API response inside a collapsible expander for debugging."""
    with st.expander("View API Response"):
        st.json(api_payload)


def _show_result(generated_image):
    """Display the generated try-on image and offer it as a PNG download."""
    st.image(generated_image, caption="Virtual Try-On Result", use_container_width=True)
    png_buffer = io.BytesIO()
    generated_image.save(png_buffer, format='PNG')
    st.download_button(
        label="📥 Download Result",
        data=png_buffer.getvalue(),
        file_name="virtual_tryon_result.png",
        mime="image/png"
    )


def _poll_job(job_id):
    """Poll the job status endpoint until it completes, fails or times out.

    Shows a progress bar while polling, then either the result image or an
    error message. The timeout message appears only if every attempt still
    reported a pending (or unknown) status.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()
    # Warn only once per unexpected status instead of on every poll.
    warned_statuses = set()

    def _clear_progress():
        progress_bar.empty()
        status_text.empty()

    for attempt in range(1, _MAX_POLL_ATTEMPTS + 1):
        progress_bar.progress(min(attempt / _MAX_POLL_ATTEMPTS, 0.99))
        status_text.text(f"Polling for results... Attempt {attempt}/{_MAX_POLL_ATTEMPTS}")
        status_response = requests.get(
            _FASHN_STATUS_URL.format(job_id=job_id),
            headers={"Authorization": f"Bearer {api_key}"},
            timeout=_REQUEST_TIMEOUT,
        )
        if status_response.status_code != 200:
            _clear_progress()
            st.error(f"Failed to get status: {status_response.status_code}")
            st.error(status_response.text)
            return

        status_result = status_response.json()
        status = status_result.get("status")

        if status == "completed":
            progress_bar.progress(1.0)
            _clear_progress()
            st.success("✅ Virtual try-on generated successfully!")
            st.subheader("Result")
            # `or []` also covers an explicit null "output".
            outputs = status_result.get("output") or []
            if outputs:
                _show_result(base64_to_image(outputs[0]))
            else:
                st.error("No output image found in the response")
                _show_response(status_result)
            return

        if status == "failed":
            _clear_progress()
            st.error(f"❌ Job failed: {status_result.get('error', 'Unknown error')}")
            _show_response(status_result)
            return

        if status not in _PENDING_STATUSES and status not in warned_statuses:
            warned_statuses.add(status)
            st.warning(f"Unknown status: {status}")
            _show_response(status_result)

        # No point sleeping after the final attempt.
        if attempt < _MAX_POLL_ATTEMPTS:
            time.sleep(_POLL_INTERVAL_SECONDS)

    # Only reached when no attempt returned completed/failed/HTTP error.
    _clear_progress()
    st.error("⏱️ Timeout: Job took too long to complete. Please try again.")


# Main action: validate inputs, submit the try-on job, then render the result
# (directly for a synchronous response, or by polling when a job id comes back).
if st.button('🎨 Generate Virtual Try-On', type = 'primary' , use_container_width = True) :
    if not api_key :
        # The key is read from the environment, not from any UI widget.
        st.error("⚠️ FASHN_API_KEY is not set. Add it to your environment or .env file!")
    elif model_file is None or garment_file is None :
        st.error("⚠️ Please upload both model and garment images!")
    else :
        try :
            model_base64 : str = image_to_base64(model_file)
            garment_base64 : str = image_to_base64(garment_file)
            with st.spinner('Making this person wear something else') :
                headers = {
                    'Authorization' : f'Bearer {api_key}' ,
                    'Content-Type' : 'application/json'
                }
                payload = {
                    'model_name' : 'tryon-v1.6' ,
                    'inputs' : {
                        'model_image' : model_base64 ,
                        'garment_image' : garment_base64 ,
                        'category' : 'auto' ,
                        'segmentation_free' : True ,
                        'moderation_level' : 'conservative' ,
                        'garment_photo_type' : 'auto' ,
                        'mode' : 'balanced' ,
                        'seed' : 42 ,
                        'num_samples' : 1 ,
                        'output_format' : 'png' ,
                        'return_base64' : True
                    }
                }
                response = requests.post(
                    _FASHN_RUN_URL ,
                    headers = headers ,
                    json = payload ,
                    timeout = _REQUEST_TIMEOUT
                )
                if response.status_code != 200 :
                    st.error(f"❌ API Error: {response.status_code}")
                    st.error(f"Message: {response.text}")
                else :
                    result = response.json()
                    if 'id' in result :
                        # Asynchronous job: poll the status endpoint.
                        _poll_job(result['id'])
                    elif result.get("output") :
                        # Synchronous response that already contains the image.
                        st.success("✅ Virtual try-on generated successfully!")
                        st.subheader("Result")
                        _show_result(base64_to_image(result["output"][0]))
                    else :
                        st.error("Unexpected response format")
                        _show_response(result)
        except Exception as e:
            # Top-level UI boundary: surface any failure (network, decoding,
            # timeouts) to the user instead of crashing the script.
            st.error(f"❌ An error occurred: {str(e)}")
            with st.expander("View Error Details"):
                st.code(traceback.format_exc())