# vton-backend / inference_wrapper.py
# StableVITON Deployer
# Migration: Switch to Gradio API (merve/fashn-vton-1.5) and updated token
# bb682cf
# (Hugging Face file-view residue: Raw | History | Blame | Contribute | Delete | 6.22 kB)
"""
StableVITON Inference Wrapper (Remote API Version)
Abstraction layer for virtual try-on inference using Gradio API
"""
import os
import io
import logging
from PIL import Image
from typing import Optional
from gradio_client import Client, handle_file
import tempfile
# Module-scoped logger; its records propagate to the root logger.
logger = logging.getLogger(__name__)
# Give the root logger a default handler at INFO level so the progress
# messages emitted by this module are visible. basicConfig does nothing if
# the root logger already has handlers configured.
logging.basicConfig(level=logging.INFO)
class StableVITONInference:
    """
    Virtual try-on inference backed by a remote FASHN VTON Gradio Space.

    The class name is kept for backward compatibility, but no model runs
    in-process: every call is forwarded to the Space through
    ``gradio_client``. The wrapper owns the client connection, uploads the
    input images via temporary PNG files, and opens the result image that
    the Space returns.
    """

    def __init__(
        self,
        model_path: str = "merve/fashn-vton-1.5",
        hf_token: Optional[str] = None,
        timeout: float = 120.0,
        **kwargs,
    ):
        """
        Connect to the remote Gradio app.

        Args:
            model_path: Hugging Face Space ID to connect to
                (default: "merve/fashn-vton-1.5").
            hf_token: Hugging Face token for private/pro Spaces. Falls back to
                the ``HF_TOKEN`` environment variable when not given.
            timeout: HTTP timeout in seconds for the underlying httpx client.
                Remote inference is slow, so the default is 2 minutes.
            **kwargs: Accepted for signature compatibility; ignored.

        Raises:
            Exception: Any error raised by ``gradio_client.Client`` while
                connecting is logged with its traceback and re-raised.
        """
        self.model_path = model_path
        self.hf_token = hf_token or os.getenv("HF_TOKEN")
        logger.info("Connecting to Gradio API: %s", self.model_path)
        try:
            self.client = Client(
                self.model_path,
                token=self.hf_token,
                # Generous per-request timeout for production reliability.
                httpx_kwargs={"timeout": timeout},
            )
        except Exception:
            logger.exception("Failed to connect to Gradio API at %s", self.model_path)
            raise
        logger.info("Gradio API connected successfully")

    def tryon(
        self,
        person_image: Image.Image,
        garment_image: Image.Image,
        category: str = "tops",
        garment_photo_type: str = "model",
        num_timesteps: int = 50,
        guidance_scale: float = 1.5,
        seed: int = 42,
        segmentation_free: bool = True,
        **kwargs,
    ) -> Image.Image:
        """
        Run virtual try-on on the remote Space and return the result image.

        Args:
            person_image: Photo of the person to dress.
            garment_image: Photo of the garment to apply.
            category: Garment category forwarded to the Space (default "tops").
            garment_photo_type: Garment photo type forwarded to the Space
                (default "model").
            num_timesteps: Number of sampling steps; coerced with ``int()``.
            guidance_scale: Guidance scale; coerced with ``float()``.
            seed: Random seed; coerced with ``int()``.
            segmentation_free: Forwarded as a bool. Strings (e.g. from form
                data) count as True when they are "true", "1", "yes" or "on"
                (case-insensitive, surrounding whitespace ignored); any other
                string is False.
            **kwargs: Accepted for signature compatibility; ignored.

        Returns:
            The result image opened from the local file path found in the API
            response, fully decoded into memory.

        Raises:
            ValueError: If no image path can be found in the API response.
            Exception: Any error from preparing the inputs, the API call or
                decoding the result is logged with its traceback and re-raised.
        """
        person_path: Optional[str] = None
        garment_path: Optional[str] = None
        try:
            logger.info("Starting remote try-on inference (category: %s)", category)

            logger.info("Step 4a: Preparing temporary files for API transfer")
            # Files are created before the API call and removed in `finally`
            # even if only one of them was written successfully.
            person_path = self._write_temp_png(person_image)
            garment_path = self._write_temp_png(garment_image)

            result = self.client.predict(
                person_image=handle_file(person_path),
                garment_image=handle_file(garment_path),
                category=category,
                garment_photo_type=garment_photo_type,
                num_timesteps=int(num_timesteps),
                guidance_scale=float(guidance_scale),
                seed=int(seed),
                segmentation_free=self._coerce_bool(segmentation_free),
                api_name="/try_on",
            )

            logger.info("Step 4b: API response received, processing result path")
            logger.debug("API Result Type: %s", type(result))
            logger.debug("API Result: %s", result)

            result_image_path = self._extract_result_path(result)
            if not result_image_path:
                raise ValueError(f"Could not extract image path from API result: {result}")

            result_image = Image.open(result_image_path)
            # Image.open is lazy; decode now so corrupt downloads fail inside
            # this logged block and the file handle is released for
            # single-frame images instead of lingering in the caller.
            result_image.load()
            logger.info("Step 4c: Result image opened successfully")
            return result_image
        except Exception:
            logger.exception("Remote inference failed with traceback:")
            raise
        finally:
            # Remove the local input copies whether or not inference succeeded.
            for temp_path in (person_path, garment_path):
                self._remove_quietly(temp_path)

    @staticmethod
    def _write_temp_png(image: Image.Image) -> str:
        """Save *image* to a new temporary PNG file and return its path.

        The caller owns the returned file and must delete it. If saving fails,
        the partially written file is removed before the error propagates.
        """
        handle = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        try:
            # Write through the already-open handle instead of re-opening the
            # file by name, which is not allowed on Windows while it is open.
            # Leaving the `with` closes (and therefore flushes) the file.
            with handle:
                image.save(handle, format="PNG")
        except Exception:
            StableVITONInference._remove_quietly(handle.name)
            raise
        return handle.name

    @staticmethod
    def _remove_quietly(path: Optional[str]) -> None:
        """Delete *path* if given, logging (not raising) on failure.

        Used from cleanup paths, where raising would mask the original error.
        """
        if path is None:
            return
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError:
            logger.warning("Could not delete temporary file %s", path, exc_info=True)

    @staticmethod
    def _coerce_bool(value) -> bool:
        """Interpret *value* as a boolean, accepting form-data string spellings.

        Strings are compared case-insensitively, after trimming whitespace,
        against "true", "1", "yes" and "on" (the default value an HTML
        checkbox submits); every other string is False. Non-strings use
        ordinary truthiness.
        """
        if isinstance(value, str):
            return value.strip().lower() in ("true", "1", "yes", "on")
        return bool(value)

    @staticmethod
    def _extract_result_path(result) -> Optional[str]:
        """Return the image path from a Gradio ``predict`` result, or None.

        Accepts a bare path (str or os.PathLike), a dict with a ``"path"`` key,
        or a non-empty list/tuple whose first item is one of those.
        """
        if isinstance(result, (list, tuple)):
            if not result:
                return None
            # Multi-output endpoints return a sequence; the image comes first.
            result = result[0]
        if isinstance(result, (str, os.PathLike)):
            return os.fspath(result)
        if isinstance(result, dict):
            return result.get("path")
        return None

    def cleanup(self):
        """No local memory cleanup needed for API version"""
        pass

    def __del__(self):
        # Nothing to release: the wrapper holds no local model or GPU state.
        pass
# Example usage
if __name__ == "__main__":
    # Smoke test: only checks that the Gradio client can connect to the
    # Space. A real try-on call needs meaningful person/garment photos (and
    # possibly an HF token), so no prediction is attempted here.
    print("Testing Fashn-AI Gradio API Wrapper")
    try:
        StableVITONInference()
    except Exception as e:
        # Exit non-zero so scripted or CI runs notice the failure; a string
        # passed to SystemExit is printed to stderr with exit status 1.
        raise SystemExit(f"Error during initialization: {e}") from e
    print("Initialized successfully")