# Meissonic / tests / test.py
# Uploaded via huggingface_hub by BryanW (commit 3d1c0e1, verified).
import math
import torch
from src.transformer_video import WanDiscreteVideoTransformer
def _available_device():
return "cuda" if torch.cuda.is_available() else "cpu"
def test_wan_discrete_video_transformer_forward_and_shapes():
    """
    Smoke test for WanDiscreteVideoTransformer:

    - instantiate a deliberately tiny model so the test stays fast
    - run one forward pass on random pseudo-video tokens plus random text
    - verify the logits shape, a finite positive parameter count, and
      (when CUDA is present) that the forward pass allocated memory
    """
    device = _available_device()

    # Minimal configuration to keep the model small.
    codebook_size = 128
    vocab_size = codebook_size + 1  # one extra slot reserved for a mask token
    num_frames, height, width = 2, 16, 16

    model = WanDiscreteVideoTransformer(
        codebook_size=codebook_size,
        vocab_size=vocab_size,
        num_frames=num_frames,
        height=height,
        width=width,
        # Shrunken Wan backbone dimensions for the unit test.
        in_dim=32,
        dim=64,
        ffn_dim=128,
        freq_dim=32,
        text_dim=64,
        out_dim=32,
        num_heads=4,
        num_layers=2,
    ).to(device)
    model.eval()

    batch = 2

    # Pseudo-video tokens as produced by a per-frame 2D VQ-VAE: [B, F, H, W].
    tokens = torch.randint(
        low=0,
        high=codebook_size,
        size=(batch, num_frames, height, width),
        dtype=torch.long,
        device=device,
    )

    # Random text conditioning: [B, L, C_text].
    text_len = 8
    encoder_hidden_states = torch.randn(
        batch, text_len, model.backbone.text_dim, device=device
    )

    # Diffusion timesteps, one per sample: [B].
    timesteps = torch.randint(
        low=0, high=1000, size=(batch,), dtype=torch.long, device=device
    )

    # Snapshot allocator state before the forward pass (CUDA only).
    if device == "cuda":
        torch.cuda.reset_peak_memory_stats()
        mem_before = torch.cuda.memory_allocated()
    else:
        mem_before = 0

    with torch.no_grad():
        logits = model(
            tokens=tokens,
            timesteps=timesteps,
            encoder_hidden_states=encoder_hidden_states,
            y=None,
        )

    if device == "cuda":
        mem_after = torch.cuda.memory_allocated()
        peak_mem = torch.cuda.max_memory_allocated()
    else:
        mem_after = mem_before
        peak_mem = mem_before

    # Expected logits layout: [B, codebook_size, F, H_out, W_out], where the
    # spatial dims shrink by the Wan patch size (default (1, 2, 2)).
    expected_h = height // model.backbone.patch_size[1]
    expected_w = width // model.backbone.patch_size[2]
    assert tuple(logits.shape) == (
        batch,
        codebook_size,
        num_frames,
        expected_h,
        expected_w,
    )

    # Parameter-count sanity: positive and finite.
    total_params = sum(p.numel() for p in model.parameters())
    assert total_params > 0
    assert math.isfinite(float(total_params))

    # On CUDA the forward pass should have consumed memory, and the peak
    # should dominate both snapshots.
    if device == "cuda":
        assert peak_mem >= mem_after >= mem_before