Spaces:
Sleeping
Sleeping
StableVITON Deployer
Migration: Switch to Gradio API (merve/fashn-vton-1.5) and updated token
bb682cf | """ | |
| StableVITON Inference Wrapper (Remote API Version) | |
| Abstraction layer for virtual try-on inference using Gradio API | |
| """ | |
| import os | |
| import io | |
| import logging | |
| from PIL import Image | |
| from typing import Optional | |
| from gradio_client import Client, handle_file | |
| import tempfile | |
# Configure the root logger at INFO level as a side effect of importing this
# module (logging.basicConfig does nothing if the root logger already has
# handlers configured).
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by everything below.
logger = logging.getLogger(__name__)
class StableVITONInference:
    """
    Wrapper for Virtual Try-On inference via Fashn AI API.
    Handles remote connection, preprocessing, and result retrieval.

    A ``gradio_client.Client`` is created at construction time. ``tryon``
    writes the two input images to temporary PNG files, sends them to the
    Space's ``/try_on`` endpoint and returns the resulting image.
    """

    # Strings (compared case-insensitively, surrounding whitespace ignored)
    # that count as True when a boolean flag arrives as text, e.g. from
    # multipart FormData ("on" is what an HTML checkbox submits by default).
    _TRUE_STRINGS = frozenset({"true", "1", "yes", "on"})

    def __init__(
        self,
        model_path: str = "merve/fashn-vton-1.5",
        hf_token: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the remote API client.

        Args:
            model_path: Hugging Face Space ID (default: merve/fashn-vton-1.5)
            hf_token: Optional Hugging Face token for private/pro spaces.
                Falls back to the ``HF_TOKEN`` environment variable when
                not given (or empty).
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            Exception: Whatever ``Client(...)`` raises while connecting is
                logged with its traceback and re-raised unchanged.
        """
        self.model_path = model_path
        self.hf_token = hf_token or os.getenv("HF_TOKEN")
        logger.info("Connecting to Gradio API: %s", self.model_path)
        try:
            self.client = Client(
                self.model_path,
                token=self.hf_token,
                # 2 minute HTTP timeout: AI inference requests are slow.
                httpx_kwargs={"timeout": 120.0},
            )
            logger.info("Gradio API connected successfully")
        except Exception:
            logger.exception(
                "Failed to connect to Gradio API at %s", self.model_path
            )
            raise

    @classmethod
    def _coerce_bool(cls, value) -> bool:
        """Turn a flag that may arrive as text (e.g. "true") into a bool."""
        if isinstance(value, str):
            return value.strip().lower() in cls._TRUE_STRINGS
        return bool(value)

    @staticmethod
    def _extract_result_path(result) -> Optional[str]:
        """Return the image path held in an API result, or None if absent.

        Handles a plain string path, a dict with a ``'path'`` key, or a
        non-empty list/tuple whose first item is one of those two.
        """
        if isinstance(result, (list, tuple)) and len(result) > 0:
            candidate = result[0]
        else:
            candidate = result
        if isinstance(candidate, str):
            return candidate
        if isinstance(candidate, dict) and 'path' in candidate:
            return candidate['path']
        return None

    @staticmethod
    def _remove_temp_file(path: str) -> None:
        """Best-effort removal of a temporary input file."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass  # already gone, nothing to clean up
        except OSError:
            # A failed cleanup must not replace the inference result (or
            # the original error) with an unrelated exception.
            logger.warning("Could not remove temporary file %s", path)

    def tryon(
        self,
        person_image: Image.Image,
        garment_image: Image.Image,
        category: str = "tops",
        garment_photo_type: str = "model",
        num_timesteps: int = 50,
        guidance_scale: float = 1.5,
        seed: int = 42,
        segmentation_free: bool = True,
        **kwargs
    ) -> Image.Image:
        """
        Perform virtual try-on inference via remote API.

        Args:
            person_image: Image of the person; saved as a temporary PNG.
            garment_image: Image of the garment; saved as a temporary PNG.
            category: Forwarded unchanged to the ``/try_on`` endpoint.
            garment_photo_type: Forwarded unchanged to the endpoint.
            num_timesteps: Forwarded after conversion with ``int()``.
            guidance_scale: Forwarded after conversion with ``float()``.
            seed: Forwarded after conversion with ``int()``.
            segmentation_free: Bool, or a string such as "true"/"false"
                (as sent by FormData), converted to a bool before sending.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            The result image, fully loaded into memory so it no longer
            depends on the downloaded result file.

        Raises:
            ValueError: If no image path can be found in the API result.
            Exception: Any other failure is logged with its traceback and
                re-raised unchanged.
        """
        temp_paths = []
        try:
            logger.info(
                "Starting remote try-on inference (category: %s)", category
            )
            seg_free = self._coerce_bool(segmentation_free)

            logger.info("Step 4a: Preparing temporary files for API transfer")
            for image in (person_image, garment_image):
                # delete=False: the file must outlive this `with` so that
                # handle_file() can read it by path during the API call.
                with tempfile.NamedTemporaryFile(
                    suffix=".png", delete=False
                ) as tmp:
                    # Record the path before writing so the file is removed
                    # in `finally` even if saving the image fails.
                    temp_paths.append(tmp.name)
                    image.save(tmp, format="PNG")
            person_path, garment_path = temp_paths

            result = self.client.predict(
                person_image=handle_file(person_path),
                garment_image=handle_file(garment_path),
                category=category,
                garment_photo_type=garment_photo_type,
                num_timesteps=int(num_timesteps),
                guidance_scale=float(guidance_scale),
                seed=int(seed),
                segmentation_free=seg_free,
                api_name="/try_on"
            )

            # The result can be a string (path), a dict with 'path', or a
            # list/tuple whose first item is one of those.
            logger.info("Step 4b: API response received, processing result path")
            logger.info("API Result Type: %s", type(result))
            logger.info("API Result: %s", result)
            result_image_path = self._extract_result_path(result)
            if not result_image_path:
                raise ValueError(
                    f"Could not extract image path from API result: {result}"
                )

            result_image = Image.open(result_image_path)
            # Image.open is lazy: force the pixel data into memory now so
            # the returned image does not rely on the file staying around.
            result_image.load()
            logger.info("Step 4c: Result image opened successfully")
            return result_image
        except Exception:
            logger.exception("Remote inference failed with traceback:")
            raise
        finally:
            # Cleanup local temp input files
            for path in temp_paths:
                self._remove_temp_file(path)

    def cleanup(self):
        """No local memory cleanup needed for API version"""
        pass

    def __del__(self):
        """Nothing to release: no local model or resources are held."""
        pass
# Example usage
if __name__ == "__main__":
    # Smoke test: only checks that the API client can be constructed.
    print("Testing Fashn-AI Gradio API Wrapper")

    # Solid-colour placeholder images; they are built here but not sent
    # to the API by this script.
    person_size, garment_size = (512, 768), (512, 512)
    person_img = Image.new(mode="RGB", size=person_size, color=(200, 200, 200))
    garment_img = Image.new(mode="RGB", size=garment_size, color=(100, 150, 200))

    # Construct the wrapper and report the outcome instead of crashing.
    try:
        wrapper = StableVITONInference()
        # Note: an actual prediction might fail without real images or token
        print("Initialized successfully")
    except Exception as exc:
        print(f"Error during initialization: {exc}")