|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
import os |
|
|
import json |
|
|
import imageio |
|
|
import argparse |
|
|
from PIL import Image |
|
|
from diffsynth import WanVideoReCamMasterPipeline, ModelManager |
|
|
from torchvision.transforms import v2 |
|
|
from einops import rearrange |
|
|
from scipy.spatial.transform import Rotation as R |
|
|
|
|
|
def compute_relative_pose_matrix(pose1, pose2):
    """Return the relative pose between two consecutive frames as a 3x4 ``[R_rel | t_rel]``.

    Args:
        pose1: pose of frame i, a length-7 array ``[tx, ty, tz, qx, qy, qz, qw]``
            (translation followed by a scalar-last quaternion).
        pose2: pose of frame i+1, same layout as ``pose1``.

    Returns:
        ``np.ndarray`` of shape (3, 4) whose first three columns are
        ``R2 @ R1.T`` and whose last column is ``R1.T @ (t2 - t1)``.

    NOTE(review): the rotation is composed as ``R2 · R1⁻¹`` while the translation
    is expressed in frame 1's axes via ``R1ᵀ``; a fully frame-1-relative
    transform would use ``R1ᵀ · R2``. Kept unchanged because it presumably
    matches the convention used at training time. Confirm before changing it.
    """
    first_rot = R.from_quat(pose1[3:])
    second_rot = R.from_quat(pose2[3:])

    rotation_rel = (second_rot * first_rot.inv()).as_matrix()
    translation_rel = first_rot.as_matrix().T @ (pose2[:3] - pose1[:3])

    # Append the translation as a fourth column -> (3, 4).
    return np.concatenate([rotation_rel, translation_rel[:, None]], axis=1)
|
|
|
|
|
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a ``.pth`` file and slice a frame window.

    Args:
        pth_path: path to a ``torch.save`` file holding a dict with at least a
            ``'latents'`` entry (indexed below as a 4-D tensor, frames on dim 1).
        start_frame: index along dim 1 of the first frame to keep.
        num_frames: number of frames to keep.

    Returns:
        ``(condition_latents, encoded_data)``: the slice
        ``latents[:, start_frame:start_frame + num_frames]`` and the full loaded dict.

    Raises:
        ValueError: if ``start_frame`` or ``num_frames`` is negative, or the
            requested window runs past the stored frames.
    """
    # A negative start would pass the upper-bound check below and silently yield
    # a wrong (usually empty) slice, so reject it explicitly.
    if start_frame < 0 or num_frames < 0:
        raise ValueError(
            f"start_frame and num_frames must be non-negative, got start_frame={start_frame}, num_frames={num_frames}"
        )

    print(f"Loading encoded video from {pth_path}")

    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load files from trusted sources.
    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    available = full_latents.shape[1]
    if start_frame + num_frames > available:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {available}")

    condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
|
|
|
|
|
def replace_dit_model_in_manager():
    """Swap the ``wan_video_dit`` loader class for ``WanModelFuture`` before models are loaded.

    Every entry of ``diffsynth.configs.model_config.model_loader_configs``
    whose model-name list contains ``'wan_video_dit'`` is rebuilt so that the
    class paired with that name becomes ``WanModelFuture``. All other names
    keep their original classes. The list is mutated in place, so loaders that
    read it afterwards (presumably ``ModelManager``) see the new class.
    """
    from diffsynth.models.wan_video_dit_recam_future import WanModelFuture
    from diffsynth.configs.model_config import model_loader_configs

    target_name = 'wan_video_dit'
    for position, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry
        if target_name not in names:
            continue

        paired = list(zip(names, classes))
        kept_names = [name for name, _ in paired]
        swapped_classes = []
        for name, cls in paired:
            if name == target_name:
                swapped_classes.append(WanModelFuture)
                print(f"✅ 替换了模型类: {name} -> WanModelFuture")
            else:
                swapped_classes.append(cls)

        # Replace the whole tuple entry with the rebuilt name/class lists.
        model_loader_configs[position] = (keys_hash, keys_hash_with_shape, kept_names, swapped_classes, resource)
|
|
|
|
|
def add_framepack_components(dit_model):
    """Attach a FramePack ``clean_x_embedder`` to ``dit_model`` unless it already has one.

    The embedder holds three ``Conv3d`` patchifiers from 16 input channels to
    the model's hidden width, which is read from the first block's
    ``self_attn.q`` weight:

    - ``"1x"``: kernel/stride (1, 2, 2)
    - ``"2x"``: kernel/stride (2, 4, 4)
    - ``"4x"``: kernel/stride (4, 8, 8)

    Any other scale raises ``ValueError``. After attaching, the embedder is cast
    to the dtype of the model's first parameter. If the model already has
    ``clean_x_embedder``, nothing happens.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    hidden_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            for name, layer in (("1x", self.proj), ("2x", self.proj_2x), ("4x", self.proj_4x)):
                if scale == name:
                    return layer(x)
            raise ValueError(f"Unsupported scale: {scale}")

    embedder = CleanXEmbedder(hidden_dim)
    dit_model.clean_x_embedder = embedder
    # Read the dtype after registration, from the model's first parameter.
    target_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = embedder.to(dtype=target_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
|
|
|
|
|
def generate_spatialvid_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Build per-frame camera embeddings for SpatialVid (sliding-window variant).

    Each output row holds the 12 flattened entries of a 3x4 relative pose
    followed by one condition-mask value. The result is returned in bfloat16.

    The number of rows is
    ``max(start_frame + current_history_length + new_frames, 20 + new_frames, 30)``,
    where ``20 = 1 + 16 + 2 + 1`` is the FramePack clean-latent layout.
    Rows ``[start_frame, start_frame + current_history_length)`` get mask 1;
    all other rows get 0.

    Args:
        cam_data: mapping that may contain ``'extrinsic'``. When present and
            ``use_real_poses`` is true, consecutive entries are passed to
            ``compute_relative_pose_matrix``. Otherwise a synthetic
            yaw-wobble/forward trajectory is generated.
        start_frame: first row marked as condition.
        current_history_length: number of condition rows.
        new_frames: frames to be generated (enlarges the row count).
        total_generated: unused; kept for call-site compatibility.
        use_real_poses: when False, always use the synthetic trajectory.

    Returns:
        ``torch.Tensor`` of shape ``(num_rows, 13)`` in bfloat16.
    """
    # FramePack layout: 1 start + 16 (4x) + 2 (2x) + 1 (1x) clean slots, then the new frames.
    framepack_rows = 1 + 16 + 2 + 1 + new_frames
    base_rows = start_frame + current_history_length + new_frames
    num_rows = max(base_rows, framepack_rows, 30)

    def _finalize(pose_list):
        # Flatten each 3x4 pose to 12 values and append the condition-mask column.
        flat = rearrange(torch.stack(pose_list, dim=0), 'b c d -> b (c d)')
        cond_mask = torch.zeros(num_rows, 1, dtype=torch.float32)
        cond_mask[start_frame:min(start_frame + current_history_length, num_rows)] = 1.0
        return torch.cat([flat, cond_mask], dim=1)

    use_real = use_real_poses and cam_data is not None and 'extrinsic' in cam_data
    if use_real:
        print("🔧 使用真实SpatialVid camera数据")
        extrinsics = cam_data['extrinsic']

        print(f"🔧 计算SpatialVid camera序列长度:")
        print(f" - 基础需求: {base_rows}")
        print(f" - FramePack需求: {framepack_rows}")
        print(f" - 最终生成: {num_rows}")

        poses = []
        for idx in range(num_rows):
            if idx + 1 < len(extrinsics):
                rel = compute_relative_pose_matrix(extrinsics[idx], extrinsics[idx + 1])
                poses.append(torch.as_tensor(rel[:3, :]))
            else:
                # NOTE(review): an all-zero 3x4 block is not an identity ("no motion")
                # pose. It presumably matches the zero padding used elsewhere. Confirm.
                print(f"⚠️ 帧{idx}超出camera数据范围,使用零运动")
                poses.append(torch.zeros(3, 4))

        embedding = _finalize(poses)
        print(f"🔧 SpatialVid真实camera embedding shape: {embedding.shape}")
        return embedding.to(torch.bfloat16)

    print("🔧 使用SpatialVid合成camera数据")
    print(f"🔧 生成SpatialVid合成camera帧数: {num_rows}")

    poses = []
    for idx in range(num_rows):
        # Small sinusoidal yaw about the y axis plus a constant step along -z
        # and a small sinusoidal offset along y.
        yaw = 0.03 * np.sin(idx * 0.1)
        cos_y, sin_y = np.cos(yaw), np.sin(yaw)

        pose = np.eye(4, dtype=np.float32)
        pose[0, 0], pose[0, 2] = cos_y, sin_y
        pose[2, 0], pose[2, 2] = -sin_y, cos_y
        pose[2, 3] = -0.008
        pose[1, 3] = 0.002 * np.sin(idx * 0.15)

        poses.append(torch.as_tensor(pose[:3, :]))

    embedding = _finalize(poses)
    print(f"🔧 SpatialVid合成camera embedding shape: {embedding.shape}")
    return embedding.to(torch.bfloat16)
|
|
|
|
|
def prepare_framepack_sliding_window_with_camera(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, max_history_frames=49):
    """Build the FramePack inputs and camera conditioning for one sliding-window step (SpatialVid).

    The index layout has length ``20 + target_frames_to_generate``::

        0        start latent (first history frame)
        1..16    4x clean slots
        17..18   2x clean slots
        19       1x clean slot
        20..     frames to generate

    The last (up to) 19 history frames fill a 19-slot buffer from the right.
    Buffer slots 0-15 become ``clean_latents_4x`` and slots 16-17 become
    ``clean_latents_2x``. Slot 18 is concatenated after the start latent to
    form ``clean_latents``. Unfilled slots, and the start latent when the
    history is empty, are zeros.

    The camera embedding is the first ``20 + target_frames_to_generate`` rows of
    ``camera_embedding_full``, zero-padded if shorter. Its last column is
    rewritten as a condition mask.

    Args:
        history_latents: tensor unpacked as ``(C, T, H, W)``.
        target_frames_to_generate: number of new frames in this step.
        camera_embedding_full: 2-D ``(rows, features)`` tensor; its last column
            is overwritten (on a copy) with the mask.
        start_frame: accepted for interface compatibility; not used here.
        max_history_frames: accepted for interface compatibility; not used here.

    Returns:
        A dict containing:

        - the index tensors;
        - the clean latents at each scale;
        - the masked camera embedding;
        - ``current_length`` (T) and ``next_length`` (T + new frames).

    NOTE(review): the mask marks rows ``[19 - n, 19)`` with ``n = min(T, 19)``.
    However, the history frames occupy index positions ``[20 - n, 20)``
    (buffer slot k corresponds to index k + 1), and the start latent sits at
    index 0. This looks off by one, but it may intentionally mirror the
    training code. Confirm before changing.

    NOTE(review): camera rows always come from the start of
    ``camera_embedding_full``, independent of how much has been generated.
    Confirm this matches training.
    """
    C, T, H, W = history_latents.shape

    num_clean = 1 + 16 + 2 + 1
    num_rows = num_clean + target_frames_to_generate

    start_idx, idx_4x, idx_2x, idx_1x, target_idx = torch.arange(0, num_rows).split(
        [1, 16, 2, 1, target_frames_to_generate], dim=0
    )
    clean_idx = torch.cat([start_idx, idx_1x], dim=0)

    # Zero-pad the camera sequence if it is shorter than the index layout.
    shortfall = num_rows - camera_embedding_full.shape[0]
    if shortfall > 0:
        zero_rows = torch.zeros(shortfall, camera_embedding_full.shape[1],
                                dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, zero_rows], dim=0)

    camera = camera_embedding_full[:num_rows, :].clone()
    camera[:, -1] = 0.0

    n_hist = min(T, 19) if T > 0 else 0
    first_slot = 19 - n_hist
    if T > 0:
        camera[first_slot:19, -1] = 1.0

    print(f"🔧 SpatialVid Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {n_hist}")

    # Right-aligned buffer of the most recent history frames.
    buffer = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        buffer[:, first_slot:, :, :] = history_latents[:, -n_hist:, :, :]
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    latents_4x = buffer[:, 0:16, :, :]
    latents_2x = buffer[:, 16:18, :, :]
    latents_1x = buffer[:, 18:19, :, :]

    return {
        'latent_indices': target_idx,
        'clean_latents': torch.cat([start_latent, latents_1x], dim=1),
        'clean_latents_2x': latents_2x,
        'clean_latents_4x': latents_4x,
        'clean_latent_indices': clean_idx,
        'clean_latent_2x_indices': idx_2x,
        'clean_latent_4x_indices': idx_4x,
        'camera_embedding': camera,
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
|
|
|
|
|
def _center_crop_latents(latents, target_height, target_width):
    """Center-crop the last two (spatial) dims of ``latents`` to at most the target size.

    Each spatial dimension is cropped independently, and only when it is larger
    than its target. A dimension that already fits is left untouched, so no
    negative crop offset can occur.
    """
    height, width = latents.shape[-2], latents.shape[-1]
    if height > target_height:
        top = (height - target_height) // 2
        latents = latents[..., top:top + target_height, :]
    if width > target_width:
        left = (width - target_width) // 2
        latents = latents[..., left:left + target_width]
    return latents


def _dit_forward(dit, latents, timestep_tensor, cam_emb, prompt_emb, framepack_kwargs, extra_input):
    """Run one DiT forward pass with the given camera/prompt plus the shared FramePack inputs."""
    return dit(
        latents,
        timestep=timestep_tensor,
        cam_emb=cam_emb,
        **framepack_kwargs,
        **prompt_emb,
        **extra_input,
    )


def inference_spatialvid_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="spatialvid_results/output_spatialvid_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A man walking through indoor spaces with a first-person view",
    use_real_poses=True,
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0
):
    """Generate a SpatialVid video autoregressively with a FramePack sliding window.

    Setup:

    1. Load the Wan2.1-T2V-1.3B weights and swap in ``WanModelFuture`` as the DiT.
    2. Add per-block camera layers and the FramePack ``clean_x_embedder``.
    3. Load ``dit_path`` into the DiT.

    Generation: starting from ``initial_condition_frames`` pre-encoded latent
    frames, read from ``condition_pth_path`` beginning at ``start_frame``, it
    repeats these steps:

    - denoise up to ``frames_per_generation`` new latent frames, conditioned on
      the history (FramePack clean latents) and the camera embedding;
    - append them to the history, which is capped at ``max_history_frames``
      (first frame plus the most recent ones).

    Finally it decodes the initial and generated latents and writes the frames
    to ``output_path`` at 20 fps.

    Args:
        condition_pth_path: ``.pth`` with pre-encoded ``'latents'`` and optionally
            ``'cam_emb'``.
        dit_path: state dict for the modified DiT (loaded with ``strict=True``).
        output_path: output video file; parent directories are created if needed.
        start_frame: first latent frame used as condition.
        initial_condition_frames: number of condition latent frames.
        frames_per_generation: latent frames produced per sliding-window step (>= 1).
        total_frames_to_generate: total latent frames to produce (>= 1).
        max_history_frames: maximum history length kept between steps (>= 1).
        device: device the pipeline and all tensors are placed on.
        prompt: text prompt.
        use_real_poses: use ``cam_emb['extrinsic']`` instead of a synthetic trajectory.
        use_camera_cfg: enable camera guidance against an all-zero camera embedding.
        camera_guidance_scale: camera guidance; applied only when > 1.
        text_guidance_scale: text guidance against the empty prompt; applied only when > 1.

    Raises:
        ValueError: if ``frames_per_generation``, ``total_frames_to_generate`` or
            ``max_history_frames`` is smaller than 1.
    """
    # Validate before any expensive model loading. frames_per_generation < 1
    # would never advance the loop, and an empty result cannot be concatenated.
    if frames_per_generation < 1:
        raise ValueError(f"frames_per_generation must be >= 1, got {frames_per_generation}")
    if total_frames_to_generate < 1:
        raise ValueError(f"total_frames_to_generate must be >= 1, got {total_frames_to_generate}")
    if max_history_frames < 1:
        raise ValueError(f"max_history_frames must be >= 1, got {max_history_frames}")

    # os.makedirs("") raises, so only create a directory when the path has one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"🔧 SpatialVid FramePack滑动窗口生成开始...")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")

    # Must run before load_models so the manager instantiates WanModelFuture.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    # Use the requested device rather than a hard-coded "cuda".
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device=device)

    # Per-block camera layers. cam_encoder starts at zero and projector at identity;
    # both are overwritten by the checkpoint below.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    add_framepack_components(pipe.dit)

    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop to the 60x104 latent resolution used here. Each dimension is
    # cropped independently, only when it is too large.
    initial_latents = _center_crop_latents(initial_latents, 60, 104)
    C, _, H, W = initial_latents.shape

    history_latents = initial_latents.to(device, dtype=model_dtype)
    print(f"初始history_latents shape: {history_latents.shape}")

    prompt_emb_pos = pipe.encode_prompt(prompt)
    if text_guidance_scale > 1.0:
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_neg = None
        print("不使用Text CFG")

    # NOTE(review): the camera sequence is always built from frame 0, independent
    # of start_frame. Confirm this matches how the model was trained.
    camera_embedding_full = generate_spatialvid_camera_embeddings_sliding(
        encoded_data.get('cam_emb', None),
        0,
        max_history_frames,
        0,
        0,
        use_real_poses=use_real_poses
    ).to(device, dtype=model_dtype)
    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    if use_camera_cfg:
        print(f"创建无条件camera embedding用于CFG")

    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        framepack_data = prepare_framepack_sliding_window_with_camera(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            max_history_frames
        )

        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)
        # The unconditional camera input is all zeros. Building it from the
        # conditional tensor guarantees identical shape even when the window
        # is longer than camera_embedding_full.
        camera_embedding_uncond_batch = torch.zeros_like(camera_embedding) if use_camera_cfg else None

        # FramePack conditioning shared by every DiT call in this step
        # (index tensors are passed on CPU, as in the original call sites).
        framepack_kwargs = {
            'latent_indices': framepack_data['latent_indices'].unsqueeze(0).cpu(),
            'clean_latents': framepack_data['clean_latents'].unsqueeze(0),
            'clean_latent_indices': framepack_data['clean_latent_indices'].unsqueeze(0).cpu(),
            'clean_latents_2x': framepack_data['clean_latents_2x'].unsqueeze(0),
            'clean_latent_2x_indices': framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu(),
            'clean_latents_4x': framepack_data['clean_latents_4x'].unsqueeze(0),
            'clean_latent_4x_indices': framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu(),
        }

        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )
        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        timesteps = pipe.scheduler.timesteps
        with torch.no_grad():
            for i, timestep in enumerate(timesteps):
                if i % 10 == 0:
                    print(f" 去噪步骤 {i}/{len(timesteps)}")

                timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

                noise_pred = _dit_forward(pipe.dit, new_latents, timestep_tensor, camera_embedding,
                                          prompt_emb_pos, framepack_kwargs, extra_input)

                # Camera CFG: push away from the zero-camera prediction.
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    noise_pred_uncond = _dit_forward(pipe.dit, new_latents, timestep_tensor,
                                                     camera_embedding_uncond_batch, prompt_emb_pos,
                                                     framepack_kwargs, extra_input)
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred - noise_pred_uncond)

                # Text CFG on top of the (possibly camera-guided) prediction.
                if prompt_emb_neg is not None and text_guidance_scale > 1.0:
                    noise_pred_neg = _dit_forward(pipe.dit, new_latents, timestep_tensor, camera_embedding,
                                                  prompt_emb_neg, framepack_kwargs, extra_input)
                    noise_pred = noise_pred_neg + text_guidance_scale * (noise_pred - noise_pred_neg)

                new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Keep the first frame plus the most recent (max_history_frames - 1) frames.
        # Slice with an explicit start: `[-0:]` would keep everything when
        # max_history_frames == 1.
        if history_latents.shape[1] > max_history_frames:
            keep_recent = max_history_frames - 1
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, history_latents.shape[1] - keep_recent:, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation
        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    # Match the generated frames' device and dtype so the decoder receives one
    # consistent tensor; the history already used model_dtype for these frames.
    final_video = torch.cat(
        [initial_latents.to(all_generated.device, dtype=all_generated.dtype), all_generated], dim=1
    ).unsqueeze(0)
    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # Decoded values are mapped from [-1, 1] to [0, 255] uint8, channels last.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 SpatialVid FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
|
|
|
|
|
def main():
    """Parse command-line arguments and run SpatialVid FramePack sliding-window generation.

    ``--use_real_poses`` and ``--use_camera_cfg`` default to on. As
    ``store_true`` flags with ``default=True`` they could never be switched off
    from the command line. Companion ``--no_use_real_poses`` /
    ``--no_use_camera_cfg`` flags (same ``dest``) now allow that, and the
    original flags are still accepted.
    """
    parser = argparse.ArgumentParser(description="SpatialVid FramePack滑动窗口视频生成")

    # Input data, sliding-window sizes and pose source.
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth",
                        help="输入编码视频路径")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=16)
    parser.add_argument("--max_history_frames", type=int, default=100)
    parser.add_argument("--use_real_poses", action="store_true", default=True)
    parser.add_argument("--no_use_real_poses", dest="use_real_poses", action="store_false",
                        help="Use the synthetic camera trajectory instead of real poses")

    # Model, output and runtime settings.
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/spatialvid/spatialvid_framepack_random/step50.ckpt",
                        help="训练好的模型权重路径")
    parser.add_argument("--output_path", type=str,
                        default='spatialvid_results/output_spatialvid_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A man walking through indoor spaces with a first-person view")
    parser.add_argument("--device", type=str, default="cuda")

    # Classifier-free guidance settings.
    parser.add_argument("--use_camera_cfg", action="store_true", default=True,
                        help="使用Camera CFG")
    parser.add_argument("--no_use_camera_cfg", dest="use_camera_cfg", action="store_false",
                        help="Disable camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    args = parser.parse_args()

    print(f"🔧 SpatialVid FramePack CFG生成设置:")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"SpatialVid特有特性: camera间隔为1帧")

    inference_spatialvid_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        use_real_poses=args.use_real_poses,
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale
    )
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()