"""
StableVITON Inference Wrapper (Remote API Version)

Abstraction layer for virtual try-on inference using Gradio API
"""

import io
import logging
import os
import tempfile
from typing import Optional

from gradio_client import Client, handle_file
from PIL import Image

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class StableVITONInference:
    """
    Wrapper for Virtual Try-On inference via Fashn AI API.

    Handles remote connection, preprocessing, and result retrieval.
    """

    def __init__(
        self,
        model_path: str = "merve/fashn-vton-1.5",
        hf_token: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the remote API client.

        Args:
            model_path: Hugging Face Space ID (default: merve/fashn-vton-1.5).
            hf_token: Optional Hugging Face token for private/pro spaces.
                When omitted, the ``HF_TOKEN`` environment variable is used.
            **kwargs: Accepted for call-site compatibility; not used.

        Raises:
            Exception: Any error raised while constructing the Gradio client
                is logged with its traceback and re-raised unchanged.
        """
        self.model_path = model_path
        self.hf_token = hf_token or os.getenv("HF_TOKEN")

        logger.info("Connecting to Gradio API: %s", self.model_path)
        try:
            # Add timeout settings via httpx_kwargs for production reliability
            httpx_kwargs = {
                "timeout": 120.0,  # 2 minute timeout for AI inference
            }
            self.client = Client(
                self.model_path,
                token=self.hf_token,
                httpx_kwargs=httpx_kwargs,
            )
            logger.info("Gradio API connected successfully")
        except Exception:
            logger.exception(
                "Failed to connect to Gradio API at %s", self.model_path
            )
            raise

    @staticmethod
    def _write_temp_png(image: Image.Image) -> str:
        """Save *image* to a new temporary ``.png`` file and return its path.

        The file is removed again if saving fails, so a failed call never
        leaves a stray file behind. On success the caller owns the file and
        is responsible for deleting it.
        """
        fd, path = tempfile.mkstemp(suffix=".png")
        # Close the descriptor immediately: the image is written by path, and
        # holding a second open handle is unnecessary.
        os.close(fd)
        try:
            image.save(path, format="PNG")
        except Exception:
            StableVITONInference._remove_quietly(path)
            raise
        return path

    @staticmethod
    def _remove_quietly(path: str) -> None:
        """Delete *path*, logging (not raising) if the removal fails."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError:
            # A cleanup failure must not mask the inference result or the
            # original error, so it is only reported.
            logger.warning("Could not remove temporary file %s", path)

    @staticmethod
    def _extract_result_path(result) -> Optional[str]:
        """Return the image path carried by an API result, or ``None``.

        Handles the shapes the wrapper knows about: a plain string, a dict
        with a ``'path'`` key, or a non-empty list/tuple whose first item is
        one of those two.
        """
        candidate = result
        if isinstance(result, (list, tuple)):
            if not result:
                return None
            # If it's a list, take the first item (often the path)
            candidate = result[0]

        if isinstance(candidate, str):
            return candidate
        if isinstance(candidate, dict) and 'path' in candidate:
            return candidate['path']
        return None

    def tryon(
        self,
        person_image: Image.Image,
        garment_image: Image.Image,
        category: str = "tops",
        garment_photo_type: str = "model",
        num_timesteps: int = 50,
        guidance_scale: float = 1.5,
        seed: int = 42,
        segmentation_free: bool = True,
        **kwargs
    ) -> Image.Image:
        """
        Perform virtual try-on inference via remote API.

        Both images are written to temporary PNG files, sent to the
        ``/try_on`` endpoint of the connected Space, and the returned image
        file is opened and fully loaded.

        Args:
            person_image: Image of the person.
            garment_image: Image of the garment.
            category: Forwarded unchanged to the endpoint.
            garment_photo_type: Forwarded unchanged to the endpoint.
            num_timesteps: Forwarded after conversion with ``int()``.
            guidance_scale: Forwarded after conversion with ``float()``.
            seed: Forwarded after conversion with ``int()``.
            segmentation_free: Boolean flag. A string is also accepted
                (e.g. from FormData) and is true only when it equals
                ``"true"`` case-insensitively.
            **kwargs: Accepted for call-site compatibility; not used.

        Returns:
            The result image, with its pixel data already loaded.

        Raises:
            ValueError: If no image path can be extracted from the API result.
            Exception: Any other failure is logged with its traceback and
                re-raised unchanged.
        """
        try:
            logger.info(
                "Starting remote try-on inference (category: %s)", category
            )

            # Ensure boolean conversion for segmentation_free
            # (handles string "true"/"false" from FormData)
            if isinstance(segmentation_free, str):
                seg_free = segmentation_free.lower() == "true"
            else:
                seg_free = bool(segmentation_free)

            # Every temp file is registered as soon as it exists so that the
            # finally-clause removes it no matter where a failure occurs.
            temp_paths = []
            try:
                logger.info(
                    "Step 4a: Preparing temporary files for API transfer"
                )
                person_path = self._write_temp_png(person_image)
                temp_paths.append(person_path)
                garment_path = self._write_temp_png(garment_image)
                temp_paths.append(garment_path)

                result = self.client.predict(
                    person_image=handle_file(person_path),
                    garment_image=handle_file(garment_path),
                    category=category,
                    garment_photo_type=garment_photo_type,
                    num_timesteps=int(num_timesteps),
                    guidance_scale=float(guidance_scale),
                    seed=int(seed),
                    segmentation_free=seg_free,
                    api_name="/try_on"
                )

                # The result can be a string (path), a dict with 'path',
                # or a list/tuple
                logger.info(
                    "Step 4b: API response received, processing result path"
                )
                logger.debug("API Result Type: %s", type(result))
                logger.debug("API Result: %s", result)

                result_image_path = self._extract_result_path(result)
                if not result_image_path:
                    raise ValueError(
                        f"Could not extract image path from API result: {result}"
                    )

                result_image = Image.open(result_image_path)
                # Image.open is lazy; load now so the file handle is released
                # and a corrupt result fails here rather than in the caller.
                result_image.load()
                logger.info("Step 4c: Result image opened successfully")
                return result_image
            finally:
                # Cleanup local temp input files
                for temp_path in temp_paths:
                    self._remove_quietly(temp_path)

        except Exception:
            logger.exception("Remote inference failed with traceback:")
            raise

    def cleanup(self):
        """No local memory cleanup needed for API version"""
        pass

    def __del__(self):
        # Nothing to release: the API version holds no local model resources.
        pass


# Example usage
if __name__ == "__main__":
    # Test the API wrapper
    print("Testing Fashn-AI Gradio API Wrapper")

    # Create dummy images for testing
    person_img = Image.new("RGB", (512, 768), color=(200, 200, 200))
    garment_img = Image.new("RGB", (512, 512), color=(100, 150, 200))

    # Initialize wrapper
    try:
        wrapper = StableVITONInference()
        # Note: Actual prediction might fail without real images or token
        print("Initialized successfully")
    except Exception as e:
        print(f"Error during initialization: {e}")