senku02 commited on
Commit
47ac06f
·
verified ·
1 Parent(s): a93eb7a

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +62 -0
  2. requirements.txt +21 -0
handler.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ import base64
4
+ import io
5
+ from PIL import Image
6
+ from tryon_core import TryOnEngine
7
+ from api_utils import prepare_image_for_processing, image_to_base64
8
+
9
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for the IDM-VTON try-on engine.

    Loads the heavyweight ``TryOnEngine`` once at container startup and then
    serves virtual try-on requests: a base64-encoded person image plus a
    base64-encoded garment image in, a generated composite image (and the
    intermediate mask visualization) out.
    """

    def __init__(self, path: str = ""):
        # `path` is the model directory on the HF container. The engine
        # currently resolves its own model id, so `path` is unused here —
        # kept because the EndpointHandler contract requires it.
        print("Initializing IDM-VTON Handler...")
        self.engine = TryOnEngine(load_mode="4bit", enable_cpu_offload=False, fixed_vae=True)

        # Override model_id to load from local path if needed,
        # or let it download from Hub if path is just a directory
        # self.engine.model_id = path

        self.engine.load_models()
        self.engine.load_processing_models()
        print("Handler Initialized!")

    @staticmethod
    def _decode_image(b64_data, field_name: str):
        """Decode a base64 string (optionally a ``data:`` URI) into a PIL image.

        Args:
            b64_data: base64-encoded image bytes; a browser-style
                ``data:image/...;base64,<payload>`` prefix is tolerated.
            field_name: request field name, used only for error messages.

        Returns:
            PIL.Image.Image: the decoded image.

        Raises:
            ValueError: if the field is missing/empty or is not valid
                base64-encoded image data.
        """
        if not b64_data:
            # Fail fast with a clear message instead of the opaque
            # TypeError that base64.b64decode(None) would raise.
            raise ValueError(f"Missing required input field: '{field_name}'")
        stripped = b64_data.strip()
        if stripped.lower().startswith("data:") and "," in stripped:
            # Keep only the payload after the data-URI header.
            b64_data = stripped.split(",", 1)[1]
        try:
            return Image.open(io.BytesIO(base64.b64decode(b64_data)))
        except Exception as err:
            raise ValueError(
                f"Could not decode '{field_name}' as a base64 image"
            ) from err

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one try-on inference.

        Args:
            data: request payload; inputs are read from ``data["inputs"]``
                when present, otherwise from ``data`` itself. Expected keys:
                ``human_image`` and ``garment_image`` (base64, required),
                ``garment_description`` and ``category`` (optional).

        Returns:
            A single-element list with ``generated_image`` and
            ``masked_image``, both base64-encoded.

        Raises:
            ValueError: when a required image field is missing or undecodable.
        """
        # 1. Extract inputs
        inputs = data.pop("inputs", data)
        human_img_b64 = inputs.get("human_image")
        garment_img_b64 = inputs.get("garment_image")
        description = inputs.get("garment_description", "a photo of a garment")
        category = inputs.get("category", "upper_body")

        # 2. Decode images (validated; raises ValueError on bad input)
        human_img = self._decode_image(human_img_b64, "human_image")
        garment_img = self._decode_image(garment_img_b64, "garment_image")

        # 3. Process — resize/normalize for the pipeline
        human_img = prepare_image_for_processing(human_img)
        garment_img = prepare_image_for_processing(garment_img)

        # 4. Generate. Fixed seed keeps the endpoint deterministic for
        # identical requests.
        generated_images, masked_image = self.engine.generate(
            human_img=human_img,
            garment_img=garment_img,
            garment_description=description,
            category=category,
            use_auto_mask=True,
            use_auto_crop=True,
            denoise_steps=30,
            seed=42,
            num_images=1
        )

        # 5. Return result
        return [{
            "generated_image": image_to_base64(generated_images[0]),
            "masked_image": image_to_base64(masked_image)
        }]
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core Dependencies
2
+ torch==2.1.2
3
+ torchvision==0.16.2
4
+ torchaudio==2.1.2
5
+ diffusers==0.26.3
6
+ transformers==4.38.2
7
+ accelerate==0.27.2
8
+ bitsandbytes==0.41.1
9
+ scipy==1.12.0
10
+ numpy==1.26.4
11
+ Pillow==10.2.0
12
+ opencv-python==4.9.0.80
13
+ einops==0.7.0
14
+ onnxruntime-gpu
15
+ insightface==0.7.3
16
+
17
+ # API Dependencies
18
+ fastapi==0.104.1
19
+ uvicorn[standard]==0.24.0
20
+ python-multipart==0.0.6
21
+ pydantic==2.5.0