feylur commited on
Commit
cb2d5c5
·
verified ·
1 Parent(s): 7f0ab34

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +59 -1
utils.py CHANGED
@@ -81,4 +81,62 @@ def get_trainable_module(unet, trainable_module_name):
81
  raise ValueError(f"Unknown trainable_module_name: {trainable_module_name}")
82
 
83
 
84
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  raise ValueError(f"Unknown trainable_module_name: {trainable_module_name}")
82
 
83
 
84
+ import torch
85
+ import numpy as np
86
+ from PIL import Image
87
+
88
+ # =====================================================
89
+ # Image and VAE utility functions used by CatVTONPipeline
90
+ # =====================================================
91
+
92
+ def compute_vae_encodings(image, vae):
93
+ """Encode an image tensor using the model's VAE encoder."""
94
+ if isinstance(image, list):
95
+ image = torch.cat(image, dim=0)
96
+ latents = vae.encode(image).latent_dist.sample()
97
+ latents = latents * vae.config.scaling_factor
98
+ return latents
99
+
100
+
101
+ def numpy_to_pil(images):
102
+ """Convert numpy arrays to PIL Images."""
103
+ if images.ndim == 3:
104
+ images = images[None, ...]
105
+ images = (images * 255).round().astype("uint8")
106
+ return [Image.fromarray(image) for image in images]
107
+
108
+
109
+ def prepare_image(image):
110
+ """Convert PIL image to normalized torch tensor."""
111
+ if isinstance(image, Image.Image):
112
+ image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
113
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
114
+ return image
115
+
116
+
117
+ def prepare_mask_image(mask_image):
118
+ """Convert PIL mask to tensor in [0,1] range."""
119
+ if isinstance(mask_image, Image.Image):
120
+ mask_image = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
121
+ mask_image = torch.from_numpy(mask_image).unsqueeze(0).unsqueeze(0)
122
+ return mask_image
123
+
124
+
125
+ def resize_and_crop(image, size):
126
+ """Resize image keeping aspect ratio then center crop."""
127
+ if isinstance(image, Image.Image):
128
+ image = image.resize(size, Image.BICUBIC)
129
+ return image
130
+
131
+
132
+ def resize_and_padding(image, size):
133
+ """Resize and pad to match target size."""
134
+ if isinstance(image, Image.Image):
135
+ image.thumbnail(size, Image.BICUBIC)
136
+ new_image = Image.new("RGB", size)
137
+ left = (size[0] - image.size[0]) // 2
138
+ top = (size[1] - image.size[1]) // 2
139
+ new_image.paste(image, (left, top))
140
+ image = new_image
141
+ return image
142
+