import base64
import io
from collections.abc import Mapping
from typing import Dict, List, Any

import torch
from PIL import Image

from tryon_core import TryOnEngine
from api_utils import prepare_image_for_processing, image_to_base64


def _decode_image(payload, field_name):
    """Decode a base64-encoded image payload into a PIL image.

    Args:
        payload: Base64 text (or bytes) of an encoded image. A leading
            ``data:<mime>;base64,`` prefix on a string payload is stripped.
        field_name: Name of the request field, used only in error messages.

    Returns:
        The image returned by ``PIL.Image.open`` over the decoded bytes.

    Raises:
        ValueError: If ``payload`` is missing/empty, or is not valid base64
            (``binascii.Error`` is a ``ValueError`` subclass).
    """
    if not payload:
        raise ValueError(
            f"Missing required input '{field_name}' "
            "(expected a base64-encoded image)."
        )
    if isinstance(payload, str) and payload.startswith("data:"):
        # Non-validating base64 decoding silently drops ':' ';' ',' and would
        # fold the data-URI header into the image bytes, so cut it off first.
        payload = payload.partition(",")[2]
    return Image.open(io.BytesIO(base64.b64decode(payload)))


class EndpointHandler:
    """Request handler that runs virtual try-on through a ``TryOnEngine``.

    The engine is built and its models are loaded once in ``__init__``; each
    call then decodes the two request images and runs a single generation.
    """

    # Fixed generation settings used for every request.
    DENOISE_STEPS = 30
    SEED = 42
    NUM_IMAGES = 1

    def __init__(self, path=""):
        """Create the engine and load its models.

        Args:
            path: Path to the model files on the HF container. Currently
                unused: the engine keeps its own ``model_id``.
        """
        print("Initializing IDM-VTON Handler...")
        self.engine = TryOnEngine(
            load_mode="4bit",
            enable_cpu_offload=False,
            fixed_vae=True,
        )
        # NOTE: to load from the local path instead of the engine's default
        # model_id, set `self.engine.model_id = path` before loading.
        self.engine.load_models()
        self.engine.load_processing_models()
        print("Handler Initialized!")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one try-on generation for a request.

        Args:
            data: Request payload. The fields are read from ``data["inputs"]``
                when that key exists, otherwise from ``data`` itself:

                * ``human_image`` (required): base64-encoded person image.
                * ``garment_image`` (required): base64-encoded garment image.
                * ``garment_description`` (optional): defaults to
                  ``"a photo of a garment"``.
                * ``category`` (optional): defaults to ``"upper_body"``.

        Returns:
            A one-element list holding a dict with ``generated_image`` (the
            first generated image) and ``masked_image``, each passed through
            ``image_to_base64``.

        Raises:
            TypeError: If the inputs are not a mapping.
            ValueError: If a required image is missing or is not valid base64.
        """
        # 1. Extract inputs. `get` (not `pop`) leaves the caller's dict intact.
        inputs = data.get("inputs", data)
        if not isinstance(inputs, Mapping):
            raise TypeError(
                "Expected 'inputs' to be a mapping of request fields, "
                f"got {type(inputs).__name__}."
            )
        description = inputs.get("garment_description", "a photo of a garment")
        category = inputs.get("category", "upper_body")

        # 2. Decode images (fails fast with a clear message when one is absent).
        human_img = _decode_image(inputs.get("human_image"), "human_image")
        garment_img = _decode_image(inputs.get("garment_image"), "garment_image")

        # 3. Process
        human_img = prepare_image_for_processing(human_img)
        garment_img = prepare_image_for_processing(garment_img)

        # 4. Generate
        generated_images, masked_image = self.engine.generate(
            human_img=human_img,
            garment_img=garment_img,
            garment_description=description,
            category=category,
            use_auto_mask=True,
            use_auto_crop=True,
            denoise_steps=self.DENOISE_STEPS,
            seed=self.SEED,
            num_images=self.NUM_IMAGES,
        )

        # 5. Return result
        return [{
            "generated_image": image_to_base64(generated_images[0]),
            "masked_image": image_to_base64(masked_image),
        }]