harsh99 committed on
Commit ffb11e9 · 1 Parent(s): 39c0f9d
Files changed (3)
  1. CatVTON_model.py +138 -0
  2. app.py +28 -156
  3. mask_free_inference.ipynb +20 -177
CatVTON_model.py ADDED
@@ -0,0 +1,138 @@
+ import inspect
+ import os
+ from typing import Union
+
+ import PIL
+ import numpy as np
+ import torch
+ from diffusers.utils.torch_utils import randn_tensor
+
+ from utils import (check_inputs_maskfree, get_time_embedding, numpy_to_pil, prepare_image, compute_vae_encodings)
+ from ddpm import DDPMSampler
+ from tqdm import tqdm
+
+
+ class CatVTONPix2PixPipeline:
+     def __init__(
+         self,
+         weight_dtype=torch.float32,
+         device='cuda',
+         compile=False,
+         skip_safety_check=True,
+         use_tf32=True,
+         models={},
+     ):
+         self.device = device
+         self.weight_dtype = weight_dtype
+         self.skip_safety_check = skip_safety_check
+         self.models = models
+
+         self.generator = torch.Generator(device=device)
+         self.noise_scheduler = DDPMSampler(generator=self.generator)
+         # self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device, dtype=weight_dtype)
+         self.encoder = models.get('encoder', None)
+         self.decoder = models.get('decoder', None)
+         self.unet = models.get('diffusion', None)
+
+         # Enable TF32 for faster inference on Ampere GPUs (A100 and RTX 30 series).
+         if use_tf32:
+             torch.set_float32_matmul_precision("high")
+             torch.backends.cuda.matmul.allow_tf32 = True
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         image: Union[PIL.Image.Image, torch.Tensor],
+         condition_image: Union[PIL.Image.Image, torch.Tensor],
+         num_inference_steps: int = 50,
+         guidance_scale: float = 2.5,
+         height: int = 1024,
+         width: int = 768,
+         generator=None,
+         eta=1.0,
+         **kwargs
+     ):
+         concat_dim = -1  # FIXME: y-axis concat
+         # Prepare inputs as tensors
+         image, condition_image = check_inputs_maskfree(image, condition_image, width, height)
+         image = prepare_image(image).to(self.device, dtype=self.weight_dtype)
+         condition_image = prepare_image(condition_image).to(self.device, dtype=self.weight_dtype)
+
+         # Encode the images into VAE latents
+         image_latent = compute_vae_encodings(image, self.encoder)
+         condition_latent = compute_vae_encodings(condition_image, self.encoder)
+         del image, condition_image
+
+         # Concatenate latents
+         condition_latent_concat = torch.cat([image_latent, condition_latent], dim=concat_dim)
+
+         # Prepare noise
+         latents = randn_tensor(
+             condition_latent_concat.shape,
+             generator=generator,
+             device=condition_latent_concat.device,
+             dtype=self.weight_dtype,
+         )
+
+         # Prepare timesteps
+         self.noise_scheduler.set_inference_timesteps(num_inference_steps)
+         timesteps = self.noise_scheduler.timesteps
+         # latents = latents * self.noise_scheduler.init_noise_sigma
+         latents = self.noise_scheduler.add_noise(latents, timesteps[0])
+
+         # Classifier-free guidance
+         if do_classifier_free_guidance := (guidance_scale > 1.0):
+             condition_latent_concat = torch.cat(
+                 [
+                     torch.cat([image_latent, torch.zeros_like(condition_latent)], dim=concat_dim),
+                     condition_latent_concat,
+                 ]
+             )
+
+         num_warmup_steps = 0  # For simple DDPM, no warmup is needed
+         with tqdm(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 # Expand the latents if we are doing classifier-free guidance
+                 latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)
+
+                 # Prepare the input for the inpainting model
+                 p2p_latent_model_input = torch.cat([latent_model_input, condition_latent_concat], dim=1)
+
+                 # Predict the noise residual
+                 timestep = t.repeat(p2p_latent_model_input.shape[0])
+                 time_embedding = get_time_embedding(timestep).to(self.device, dtype=self.weight_dtype)
+                 noise_pred = self.unet(
+                     p2p_latent_model_input,
+                     time_embedding
+                 )
+
+                 # Perform guidance
+                 if do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + guidance_scale * (
+                         noise_pred_text - noise_pred_uncond
+                     )
+
+                 # Compute the previous noisy sample x_t -> x_{t-1}
+                 latents = self.noise_scheduler.step(
+                     t, latents, noise_pred
+                 )
+
+                 if i == len(timesteps) - 1 or (i + 1) > num_warmup_steps:
+                     progress_bar.update()
+
+         # Decode the final latents, keeping only the person half of the concatenated latent
+         latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]
+         # latents = 1 / self.vae.config.scaling_factor * latents
+         # image = self.vae.decode(latents.to(self.device, dtype=self.weight_dtype)).sample
+         image = self.decoder(latents.to(self.device, dtype=self.weight_dtype))
+         image = (image / 2 + 0.5).clamp(0, 1)
+         # Always cast to float32: this adds no significant overhead and is compatible with bfloat16
+         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+         image = numpy_to_pil(image)
+
+         return image
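For reference, a minimal usage sketch of the extracted pipeline (untested, not part of the commit). It assumes both weight files are already present locally and that preload_models_from_standard_weights accepts the checkpoint path as its first positional argument, which the call in app.py below only partially shows; person.jpg and garment.jpg are placeholder inputs. Note that the guidance step above computes noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond), so only guidance_scale > 1.0 activates the classifier-free-guidance branch.

import torch
from PIL import Image

from load_model import preload_models_from_standard_weights
from CatVTON_model import CatVTONPix2PixPipeline

# Build the pipeline from the same weights app.py loads.
models = preload_models_from_standard_weights(
    "instruct-pix2pix-00-22000.ckpt",  # assumed positional checkpoint argument
    device="cuda",
    finetune_weights_path="maskfree_finetuned_weights.safetensors",
)
pipeline = CatVTONPix2PixPipeline(weight_dtype=torch.float32, device="cuda", models=models)

person = Image.open("person.jpg").convert("RGB")    # placeholder inputs
garment = Image.open("garment.jpg").convert("RGB")
results = pipeline(
    image=person,
    condition_image=garment,
    num_inference_steps=50,
    guidance_scale=2.5,
    generator=torch.Generator(device="cuda").manual_seed(42),
)
results[0].save("tryon_result.png")  # numpy_to_pil returns a list of PIL images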
app.py CHANGED
@@ -8,146 +8,32 @@ from typing import Optional
  # Import your custom modules
  from load_model import preload_models_from_standard_weights
  from utils import to_pil_image
+ from CatVTON_model import CatVTONPix2PixPipeline

- import inspect
  import os
- from typing import Union
-
- import PIL
- import numpy as np
  import torch
- import tqdm
- from diffusers.utils.torch_utils import randn_tensor
-
- from utils import (check_inputs_maskfree, get_time_embedding, numpy_to_pil, prepare_image, compute_vae_encodings)
- from ddpm import DDPMSampler
-
-
- class CatVTONPix2PixPipeline:
-     [... the inlined pipeline class, identical to the CatVTON_model.py file added above, removed ...]
-
- def load_models():
-     try:
-         print("🚀 Starting model loading process...")
-
+ import urllib.request
+
+ def load_models():
+     try:
+         print("🚀 Starting model loading process...")
+
+         # Check and download model files if missing
+         ckpt_path = "instruct-pix2pix-00-22000.ckpt"
+         finetune_path = "maskfree_finetuned_weights.safetensors"
+
+         if not os.path.exists(ckpt_path):
+             print(f"⬇️ Downloading {ckpt_path}...")
+             url = "https://huggingface.co/timbrooks/instruct-pix2pix/resolve/main/instruct-pix2pix-00-22000.ckpt"
+             urllib.request.urlretrieve(url, ckpt_path)
+             print("✅ Download complete.")
+         else:
+             print("✅ Checkpoint already exists.")
+
+         if not os.path.exists(finetune_path):
+             print(f"❌ Finetune weights file not found: {finetune_path}")
+             return None, None
+
          # Check CUDA availability
          cuda_available = torch.cuda.is_available()
          print(f"CUDA available: {cuda_available}")
@@ -155,21 +41,9 @@ def load_models():
          print(f"CUDA device: {torch.cuda.get_device_name()}")
          free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)
          print(f"Available CUDA memory: {free_memory / 1e9:.2f} GB")

          device = "cuda" if cuda_available else "cpu"

-         # Check if model files exist
-         ckpt_path = "instruct-pix2pix-00-22000.ckpt"
-         finetune_path = "maskfree_finetuned_weights.safetensors"
-
-         if not os.path.exists(ckpt_path):
-             print(f"❌ Checkpoint file not found: {ckpt_path}")
-             return None, None
-
-         if not os.path.exists(finetune_path):
-             print(f"❌ Finetune weights file not found: {finetune_path}")
-             return None, None
-
          print("📦 Loading models from weights...")

          models = preload_models_from_standard_weights(
@@ -177,16 +51,14 @@
              device=device,
              finetune_weights_path=finetune_path
          )

          if not models:
              print("❌ Failed to load models")
              return None, None

-         # Convert all models to consistent dtype to avoid mixed precision issues
-         weight_dtype = torch.float32  # Use float32 to avoid dtype mismatch
+         weight_dtype = torch.float32
          print(f"Converting models to {weight_dtype}...")

-         # Ensure all models use the same dtype
          for model_name, model in models.items():
              if model is not None:
                  try:
@@ -195,19 +67,19 @@
                      print(f"✅ {model_name} converted to {weight_dtype}")
                  except Exception as e:
                      print(f"⚠️ Could not convert {model_name} to {weight_dtype}: {e}")

          print("🔧 Initializing pipeline...")

          pipeline = CatVTONPix2PixPipeline(
              weight_dtype=weight_dtype,
              device=device,
              skip_safety_check=True,
              models=models,
          )

          print("✅ Models and pipeline loaded successfully!")
          return models, pipeline

      except Exception as e:
          print(f"❌ Error in load_models: {e}")
          import traceback
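A note on the new download step: urllib.request.urlretrieve re-fetches the large checkpoint from scratch with no caching and no resume on failure. A sketch of an alternative, assuming the huggingface_hub package is available in the environment (it is not used by this commit):

from huggingface_hub import hf_hub_download

# Cached download of the same checkpoint into the working directory,
# so the hard-coded relative path above still resolves.
ckpt_path = hf_hub_download(
    repo_id="timbrooks/instruct-pix2pix",
    filename="instruct-pix2pix-00-22000.ckpt",
    local_dir=".",
)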
mask_free_inference.ipynb CHANGED
@@ -95,197 +95,39 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 3,
+    "execution_count": null,
     "id": "78e3d8b9",
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stderr",
-      "output_type": "stream",
-      "text": [
-       "/home/mahesh/miniconda3/envs/harsh/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
-       " from .autonotebook import tqdm as notebook_tqdm\n"
-      ]
-     }
-    ],
-    "source": [
-     "import inspect\n",
-     [... the rest of the inlined CatVTONPix2PixPipeline cell source, identical to the CatVTON_model.py file added above, removed ...]
-    ]
+    "outputs": [],
+    "source": []
    },
    {
     "cell_type": "code",
-    "execution_count": 4,
+    "execution_count": 3,
     "id": "5627b2d2",
     "metadata": {},
     "outputs": [
     {
-     "name": "stdout",
+     "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Dataset vitonhd loaded, total 20 pairs.\n"
+      "/home/mahesh/miniconda3/envs/harsh/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      " from .autonotebook import tqdm as notebook_tqdm\n"
      ]
     },
     {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "100%|██████████| 50/50 [00:07<00:00, 7.12it/s]\n",
-      [... 18 more per-image 50/50 progress lines removed ...]
-      "100%|██████████| 50/50 [00:07<00:00, 6.92it/s]\n",
-      "100%|██████████| 20/20 [02:26<00:00, 7.35s/it]\n"
-     ]
+     "ename": "AssertionError",
+     "evalue": "File ./sample_dataset/samples_pairs.txt does not exist.",
+     "output_type": "error",
+     "traceback": [
+      [... ANSI traceback elided: main() -> VITONHDTestDataset(args) -> VITON_Dataset.py:14 load_data() -> VITON_Dataset.py:36, where the assertion on samples_pairs.txt fails ...]
+     ]
     }
    ],
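The re-run now fails before any inference because the sample dataset is missing the pairs file the loader asserts on. A hypothetical setup sketch, assuming the usual VITON-HD convention of one "<person_image> <garment_image>" pair per line (the file names below are placeholders, not from this repo):

import os

# Recreate the pairs file that VITON_Dataset.py checks (line 36 in the traceback above).
os.makedirs("sample_dataset", exist_ok=True)
with open("sample_dataset/samples_pairs.txt", "w") as f:
    f.write("00001_person.jpg 00001_garment.jpg\n")  # placeholder pair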
 
@@ -297,6 +139,7 @@
    "from VITON_Dataset import VITONHDTestDataset\n",
    "from tqdm import tqdm\n",
    "from PIL import Image\n",
+   "from CatVTON_model import CatVTONPix2PixPipeline\n",
    "\n",
    "from utils import to_pil_image\n",
    "\n",