# causvid/tests/wan/test_text_encoder.py
# Commit 8a70e8e by lyttt: "Add files using upload-large-folder tool"
from causvid.models.wan.wan_wrapper import WanTextEncoder
import torch
# Smoke test for WanTextEncoder: encode a batch of prompts of two different
# lengths and check the shape of the returned 'prompt_embeds' tensor.
#
# Gradients are disabled with a scoped torch.no_grad() block rather than the
# process-wide torch.set_grad_enabled(False). This file is named test_*.py, so
# pytest imports (and therefore executes) it during collection. A global
# grad-mode switch would leak into every test that runs later in the same
# process.
EXPECTED_SEQ_LEN = 512  # expected size of dim 1 of prompt_embeds
EXPECTED_EMBED_DIM = 4096  # expected size of dim 2 of prompt_embeds

with torch.no_grad():
    model = WanTextEncoder().to(device="cuda:0", dtype=torch.bfloat16)

    # Two groups of prompts with different word counts in one batch. The
    # shape check below therefore also confirms they come out with a common
    # sequence length.
    prompt_list = ["a " * 50] * 10 + ["b " * 25] * 10

    print("Test Text Encoder")
    encoded_dict = model(prompt_list)

prompt_embeds = encoded_dict['prompt_embeds']
# Compare the full shape tuple. This also checks the rank, which the previous
# per-dimension `and` chain did not, and reports the actual shape on failure.
expected_shape = (len(prompt_list), EXPECTED_SEQ_LEN, EXPECTED_EMBED_DIM)
actual_shape = tuple(prompt_embeds.shape)
assert actual_shape == expected_shape, (
    f"unexpected prompt_embeds shape {actual_shape}, expected {expected_shape}"
)