Upload 2 files
Browse files- handler.py +62 -0
- requirements.txt +21 -0
handler.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Any
|
| 2 |
+
import torch
|
| 3 |
+
import base64
|
| 4 |
+
import io
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from tryon_core import TryOnEngine
|
| 7 |
+
from api_utils import prepare_image_for_processing, image_to_base64
|
| 8 |
+
|
| 9 |
+
class EndpointHandler:
    """Hugging Face Inference Endpoints handler wrapping the IDM-VTON try-on engine.

    Expects a request payload of the form::

        {"inputs": {"human_image": <b64>, "garment_image": <b64>,
                    "garment_description": str, "category": str,
                    "denoise_steps": int, "seed": int, "num_images": int}}

    and returns ``[{"generated_image": <b64>, "masked_image": <b64>}]``.
    """

    def __init__(self, path=""):
        # `path` is the path to the model files on the HF container; the engine
        # currently resolves its own model_id, so `path` is unused for now.
        print("Initializing IDM-VTON Handler...")
        self.engine = TryOnEngine(load_mode="4bit", enable_cpu_offload=False, fixed_vae=True)

        # Override model_id to load from local path if needed,
        # or let it download from Hub if path is just a directory
        # self.engine.model_id = path

        self.engine.load_models()
        self.engine.load_processing_models()
        print("Handler Initialized!")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one virtual try-on inference.

        Args:
            data: Request payload; image fields are base64-encoded bytes.
                Optional keys ``denoise_steps`` (default 30), ``seed``
                (default 42) and ``num_images`` (default 1) tune generation.

        Returns:
            A single-element list with base64-encoded result images.

        Raises:
            ValueError: if either required image is missing or cannot be
                decoded from base64.
        """
        # 1. Extract inputs (accept both {"inputs": {...}} and a flat payload)
        inputs = data.pop("inputs", data)
        human_img_b64 = inputs.get("human_image")
        garment_img_b64 = inputs.get("garment_image")
        if not human_img_b64 or not garment_img_b64:
            # Fail fast with a clear message instead of letting b64decode(None)
            # raise an opaque TypeError.
            raise ValueError("Both 'human_image' and 'garment_image' (base64) are required.")
        description = inputs.get("garment_description", "a photo of a garment")
        category = inputs.get("category", "upper_body")
        # Generation knobs are overridable per request; defaults preserve the
        # previous hard-coded behavior.
        denoise_steps = int(inputs.get("denoise_steps", 30))
        seed = int(inputs.get("seed", 42))
        num_images = int(inputs.get("num_images", 1))

        # 2. Decode images
        try:
            human_img = Image.open(io.BytesIO(base64.b64decode(human_img_b64)))
            garment_img = Image.open(io.BytesIO(base64.b64decode(garment_img_b64)))
        except Exception as err:
            raise ValueError("Could not decode input images from base64.") from err

        # 3. Process (resize/normalize for the pipeline)
        human_img = prepare_image_for_processing(human_img)
        garment_img = prepare_image_for_processing(garment_img)

        # 4. Generate
        generated_images, masked_image = self.engine.generate(
            human_img=human_img,
            garment_img=garment_img,
            garment_description=description,
            category=category,
            use_auto_mask=True,
            use_auto_crop=True,
            denoise_steps=denoise_steps,
            seed=seed,
            num_images=num_images,
        )

        # 5. Return result
        return [{
            "generated_image": image_to_base64(generated_images[0]),
            "masked_image": image_to_base64(masked_image),
        }]
|
requirements.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core Dependencies
|
| 2 |
+
torch==2.1.2
|
| 3 |
+
torchvision==0.16.2
|
| 4 |
+
torchaudio==2.1.2
|
| 5 |
+
diffusers==0.26.3
|
| 6 |
+
transformers==4.38.2
|
| 7 |
+
accelerate==0.27.2
|
| 8 |
+
bitsandbytes==0.41.1
|
| 9 |
+
scipy==1.12.0
|
| 10 |
+
numpy==1.26.4
|
| 11 |
+
Pillow==10.2.0
|
| 12 |
+
opencv-python==4.9.0.80
|
| 13 |
+
einops==0.7.0
|
| 14 |
+
onnxruntime-gpu
|
| 15 |
+
insightface==0.7.3
|
| 16 |
+
|
| 17 |
+
# API Dependencies
|
| 18 |
+
fastapi==0.104.1
|
| 19 |
+
uvicorn[standard]==0.24.0
|
| 20 |
+
python-multipart==0.0.6
|
| 21 |
+
pydantic==2.5.0
|