Commit
·
90101b2
1
Parent(s):
226c7c9
print latent shape
Browse files
app.py
CHANGED
|
@@ -28,8 +28,8 @@ except Exception as e:
|
|
| 28 |
# download checkpoints
|
| 29 |
from download_checkpoints import main as download_checkpoints
|
| 30 |
|
| 31 |
-
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
|
| 32 |
-
download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
|
| 33 |
|
| 34 |
|
| 35 |
from test_environment import main as check_environment
|
|
|
|
| 28 |
# download checkpoints
|
| 29 |
from download_checkpoints import main as download_checkpoints
|
| 30 |
|
| 31 |
+
# os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
|
| 32 |
+
# download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
|
| 33 |
|
| 34 |
|
| 35 |
from test_environment import main as check_environment
|
cosmos_transfer1/diffusion/inference/world_generation_pipeline.py
CHANGED
|
@@ -553,21 +553,28 @@ class DiffusionControl2WorldGenerationPipeline(BaseWorldGenerationPipeline):
|
|
| 553 |
end_frame = num_new_generated_frames * (i_clip + 1) + self.num_input_frames
|
| 554 |
|
| 555 |
# Prepare x_sigma_max
|
|
|
|
| 556 |
if input_video is not None:
|
| 557 |
if is_upscale_case:
|
| 558 |
x_sigma_max = []
|
| 559 |
for b in range(B):
|
| 560 |
input_frames = input_video[b : b + 1, :, start_frame:end_frame].cuda()
|
| 561 |
x0 = self.model.encode(input_frames).contiguous()
|
|
|
|
| 562 |
x_sigma_max.append(self.model.get_x_from_clean(x0, self.sigma_max, seed=(self.seed + i_clip)))
|
|
|
|
| 563 |
x_sigma_max = torch.cat(x_sigma_max)
|
| 564 |
else:
|
| 565 |
input_frames = input_video[:, :, start_frame:end_frame].cuda()
|
| 566 |
x0 = self.model.encode(input_frames).contiguous()
|
|
|
|
| 567 |
x_sigma_max = self.model.get_x_from_clean(x0, self.sigma_max, seed=(self.seed + i_clip))
|
|
|
|
| 568 |
|
| 569 |
else:
|
| 570 |
x_sigma_max = None
|
|
|
|
|
|
|
| 571 |
|
| 572 |
data_batch_i[hint_key] = control_input[:, :, start_frame:end_frame].cuda()
|
| 573 |
latent_hint = []
|
|
|
|
| 553 |
end_frame = num_new_generated_frames * (i_clip + 1) + self.num_input_frames
|
| 554 |
|
| 555 |
# Prepare x_sigma_max
|
| 556 |
+
print("==============================================================")
|
| 557 |
if input_video is not None:
|
| 558 |
if is_upscale_case:
|
| 559 |
x_sigma_max = []
|
| 560 |
for b in range(B):
|
| 561 |
input_frames = input_video[b : b + 1, :, start_frame:end_frame].cuda()
|
| 562 |
x0 = self.model.encode(input_frames).contiguous()
|
| 563 |
+
print("x0 shape ->", x0.shape)
|
| 564 |
x_sigma_max.append(self.model.get_x_from_clean(x0, self.sigma_max, seed=(self.seed + i_clip)))
|
| 565 |
+
print("x_sigma_max shape ->", x_sigma_max.shape)
|
| 566 |
x_sigma_max = torch.cat(x_sigma_max)
|
| 567 |
else:
|
| 568 |
input_frames = input_video[:, :, start_frame:end_frame].cuda()
|
| 569 |
x0 = self.model.encode(input_frames).contiguous()
|
| 570 |
+
print("x0 shape ->", x0.shape)
|
| 571 |
x_sigma_max = self.model.get_x_from_clean(x0, self.sigma_max, seed=(self.seed + i_clip))
|
| 572 |
+
print("x_sigma_max shape ->", x_sigma_max.shape)
|
| 573 |
|
| 574 |
else:
|
| 575 |
x_sigma_max = None
|
| 576 |
+
print("final ->", x_sigma_max.shape)
|
| 577 |
+
print("==============================================================")
|
| 578 |
|
| 579 |
data_batch_i[hint_key] = control_input[:, :, start_frame:end_frame].cuda()
|
| 580 |
latent_hint = []
|
cosmos_transfer1/diffusion/model/model_t2w.py
CHANGED
|
@@ -177,6 +177,9 @@ class DiffusionT2WModel(torch.nn.Module):
|
|
| 177 |
noise prediction (eps_pred) and optional confidence (logvar).
|
| 178 |
"""
|
| 179 |
|
|
|
|
|
|
|
|
|
|
| 180 |
xt = xt.to(**self.tensor_kwargs)
|
| 181 |
sigma = sigma.to(**self.tensor_kwargs)
|
| 182 |
# get precondition for the network
|
|
|
|
| 177 |
noise prediction (eps_pred) and optional confidence (logvar).
|
| 178 |
"""
|
| 179 |
|
| 180 |
+
print("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
|
| 181 |
+
print(xt.shape)
|
| 182 |
+
print("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
|
| 183 |
xt = xt.to(**self.tensor_kwargs)
|
| 184 |
sigma = sigma.to(**self.tensor_kwargs)
|
| 185 |
# get precondition for the network
|