|
|
import math |
|
|
|
|
|
import torch |
|
|
|
|
|
from src.transformer_video import WanDiscreteVideoTransformer |
|
|
|
|
|
|
|
|
def _available_device(): |
|
|
return "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
|
|
|
def test_wan_discrete_video_transformer_forward_and_shapes():
    """
    Smoke test for WanDiscreteVideoTransformer.

    Builds a deliberately tiny model, runs a single no-grad forward pass on
    random pseudo-video tokens plus random text embeddings, then checks:
    - the logits shape (batch, codebook, frames, patched height, patched width)
    - that the parameter count is positive and finite
    - on CUDA only, that the memory counters are consistent
    """
    device = _available_device()

    # Tiny geometry keeps the test fast; vocab reserves one extra id
    # beyond the codebook (presumably a mask/special token — see model).
    codebook_size = 128
    vocab_size = codebook_size + 1
    num_frames, height, width = 2, 16, 16

    model = WanDiscreteVideoTransformer(
        codebook_size=codebook_size,
        vocab_size=vocab_size,
        num_frames=num_frames,
        height=height,
        width=width,
        in_dim=32,
        dim=64,
        ffn_dim=128,
        freq_dim=32,
        text_dim=64,
        out_dim=32,
        num_heads=4,
        num_layers=2,
    ).to(device)
    model.eval()

    batch_size = 2

    # Random discrete video tokens in [0, codebook_size).
    tokens = torch.randint(
        0,
        codebook_size,
        (batch_size, num_frames, height, width),
        dtype=torch.long,
        device=device,
    )

    # Random text conditioning, sized from the backbone's own text_dim.
    text_seq_len = 8
    encoder_hidden_states = torch.randn(
        batch_size, text_seq_len, model.backbone.text_dim, device=device
    )

    # One diffusion timestep per batch element.
    timesteps = torch.randint(
        0, 1000, (batch_size,), dtype=torch.long, device=device
    )

    # Snapshot allocator state before the forward pass (CUDA only).
    if device == "cuda":
        torch.cuda.reset_peak_memory_stats()
        mem_before = torch.cuda.memory_allocated()
    else:
        mem_before = 0

    with torch.no_grad():
        logits = model(
            tokens=tokens,
            timesteps=timesteps,
            encoder_hidden_states=encoder_hidden_states,
            y=None,
        )

    if device == "cuda":
        mem_after = torch.cuda.memory_allocated()
        peak_mem = torch.cuda.max_memory_allocated()
    else:
        # No allocator instrumentation on CPU; keep the invariants trivially true.
        mem_after = mem_before
        peak_mem = mem_before

    # Leading dims: (batch, codebook, frames, ...).
    assert logits.shape[0] == batch_size
    assert logits.shape[1] == codebook_size
    assert logits.shape[2] == num_frames

    # Spatial dims shrink by the backbone's patch size (index 0 is temporal).
    patch_size = model.backbone.patch_size
    h_out = height // patch_size[1]
    w_out = width // patch_size[2]
    assert logits.shape[3] == h_out
    assert logits.shape[4] == w_out

    # Sanity on the parameter count.
    num_params = sum(p.numel() for p in model.parameters())
    assert num_params > 0
    assert math.isfinite(float(num_params))

    # Allocator bookkeeping must be monotone across the forward pass.
    if device == "cuda":
        assert peak_mem >= mem_after >= mem_before
|
|
|
|
|
|
|
|
|