Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import requests | |
| import base64 | |
| import io | |
| import os | |
| import time | |
| from PIL import Image as PILImage | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| import base64 | |
| from io import BytesIO | |
| from PIL import Image as PILImage | |
def image_to_base64(image_file) -> str:
    """Encode an uploaded image as a ``data:image/...;base64,...`` URI.

    Args:
        image_file: Object exposing ``getvalue()`` that returns the raw image
            bytes (the script passes Streamlit uploads).

    Returns:
        A data URI whose subtype is the lower-cased format Pillow detects,
        or ``jpeg`` when Pillow reports no format.
    """
    raw_bytes = image_file.getvalue()
    # Pillow is used only to sniff the format; the bytes are sent unmodified.
    detected_format = PILImage.open(BytesIO(raw_bytes)).format
    if detected_format:
        subtype: str = detected_format.lower()
    else:
        subtype = 'jpeg'
    encoded = base64.b64encode(raw_bytes).decode('utf-8')
    return 'data:image/' + subtype + ';base64,' + encoded
def base64_to_image(base64_str: str):
    """Decode a base64 string (optionally a data URI) into a Pillow image.

    Args:
        base64_str: Base64 text. When it contains a comma, the segment after
            the first comma is taken as the payload, which drops a
            ``data:...;base64,`` prefix.

    Returns:
        The image opened by Pillow from the decoded bytes.
    """
    payload = base64_str.split(',')[1] if ',' in base64_str else base64_str
    decoded = base64.b64decode(payload)
    return PILImage.open(BytesIO(decoded))
# Fashn.ai API key. Read with .get() so that a missing variable is reported by
# the in-app check further down instead of crashing the script with KeyError.
api_key: str = os.environ.get('FASHN_API_KEY', '')

FASHN_RUN_URL = 'https://api.fashn.ai/v1/run'
FASHN_STATUS_URL = 'https://api.fashn.ai/v1/status/'
# Applied to every HTTP call so a stalled connection cannot hang the app.
REQUEST_TIMEOUT_SECONDS = 60
# Polling budget: MAX_POLL_ATTEMPTS status checks, POLL_INTERVAL_SECONDS apart.
POLL_INTERVAL_SECONDS = 2
MAX_POLL_ATTEMPTS = 60


def _show_api_response(api_payload) -> None:
    """Render a raw API payload inside a collapsible expander."""
    with st.expander("View API Response"):
        st.json(api_payload)


def _render_tryon_result(api_result: dict) -> None:
    """Show the first generated image from ``api_result`` with a download button.

    Reports an error (plus the raw payload) when ``output`` is missing or empty.
    """
    st.success("✅ Virtual try-on generated successfully!")
    st.subheader("Result")

    outputs = api_result.get("output") or []
    if not outputs:
        st.error("No output image found in the response")
        _show_api_response(api_result)
        return

    generated_image = base64_to_image(outputs[0])
    st.image(generated_image, caption="Virtual Try-On Result", use_container_width=True)

    # Re-encode as PNG so the download always matches the advertised MIME type.
    png_buffer = io.BytesIO()
    generated_image.save(png_buffer, format='PNG')
    st.download_button(
        label="📥 Download Result",
        data=png_buffer.getvalue(),
        file_name="virtual_tryon_result.png",
        mime="image/png",
    )


def _poll_for_result(job_id: str, key: str) -> None:
    """Poll the status endpoint for ``job_id`` and render the final outcome.

    Stops on ``completed``, ``failed`` or a non-200 status response. The
    timeout message is shown only when every attempt was used without reaching
    one of those outcomes.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()

    def clear_progress() -> None:
        progress_bar.empty()
        status_text.empty()

    for attempt in range(1, MAX_POLL_ATTEMPTS + 1):
        progress_bar.progress(min(attempt / MAX_POLL_ATTEMPTS, 0.99))
        status_text.text(f"Polling for results... Attempt {attempt}/{MAX_POLL_ATTEMPTS}")

        status_response = requests.get(
            f"{FASHN_STATUS_URL}{job_id}",
            headers={"Authorization": f"Bearer {key}"},
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
        if status_response.status_code != 200:
            clear_progress()
            st.error(f"Failed to get status: {status_response.status_code}")
            st.error(status_response.text)
            return

        status_result = status_response.json()
        status = status_result.get("status")

        if status == "completed":
            progress_bar.progress(1.0)
            clear_progress()
            _render_tryon_result(status_result)
            return

        if status == "failed":
            clear_progress()
            st.error(f"❌ Job failed: {status_result.get('error', 'Unknown error')}")
            _show_api_response(status_result)
            return

        if status not in ("processing", "queued"):
            # Unknown status: surface it, but keep polling.
            st.warning(f"Unknown status: {status}")
            _show_api_response(status_result)

        time.sleep(POLL_INTERVAL_SECONDS)

    # Reached only when no attempt returned early, i.e. a genuine timeout.
    clear_progress()
    st.error("⏱️ Timeout: Job took too long to complete. Please try again.")


# --- Upload widgets: model on the left, garment on the right -----------------
col1, col2 = st.columns(2)

with col1:
    st.subheader("Model Image")
    model_file = st.file_uploader(
        'Upload model/person image',
        type=['png', 'jpg', 'jpeg'],
        key='model_uploader',
    )
    if model_file is not None:
        model_image = PILImage.open(model_file)
        st.image(model_image, caption='Uploaded Model Image', use_container_width=True)

with col2:
    st.subheader('Garment Image')
    garment_file = st.file_uploader(
        'Upload garment/clothing image',
        type=['png', 'jpg', 'jpeg'],
        key='garment_uploader',
    )
    if garment_file is not None:
        garment_image = PILImage.open(garment_file)
        st.image(garment_image, caption='Uploaded Garment Image', use_container_width=True)

st.markdown('---')

# --- Try-on request ----------------------------------------------------------
if st.button('🎨 Generate Virtual Try-On', type='primary', use_container_width=True):
    if not api_key:
        st.error("⚠️ Fashn.ai API key not found. Please set the FASHN_API_KEY environment variable!")
    elif model_file is None or garment_file is None:
        st.error("⚠️ Please upload both model and garment images!")
    else:
        # Top-level UI boundary: any failure is shown to the user with details
        # rather than surfacing as an unhandled Streamlit exception.
        try:
            model_base64: str = image_to_base64(model_file)
            garment_base64: str = image_to_base64(garment_file)

            with st.spinner('Making this person wear something else'):
                headers = {
                    'Authorization': f'Bearer {api_key}',
                    'Content-Type': 'application/json',
                }
                payload = {
                    'model_name': 'tryon-v1.6',
                    'inputs': {
                        'model_image': model_base64,
                        'garment_image': garment_base64,
                        'category': 'auto',
                        'segmentation_free': True,
                        'moderation_level': 'conservative',
                        'garment_photo_type': 'auto',
                        'mode': 'balanced',
                        'seed': 42,
                        'num_samples': 1,
                        'output_format': 'png',
                        'return_base64': True,
                    },
                }
                response = requests.post(
                    FASHN_RUN_URL,
                    headers=headers,
                    json=payload,
                    timeout=REQUEST_TIMEOUT_SECONDS,
                )

                if response.status_code != 200:
                    st.error(f"❌ API Error: {response.status_code}")
                    st.error(f"Message: {response.text}")
                else:
                    result = response.json()
                    if 'id' in result:
                        # Asynchronous job: poll until it finishes.
                        _poll_for_result(result['id'], api_key)
                    elif "output" in result:
                        # Direct response with output (synchronous).
                        _render_tryon_result(result)
                    else:
                        st.error("Unexpected response format")
                        _show_api_response(result)
        except Exception as e:
            st.error(f"❌ An error occurred: {str(e)}")
            import traceback
            with st.expander("View Error Details"):
                st.code(traceback.format_exc())