"""Smoke test: run WanTextEncoder on a batch of prompts and check the output shape."""
from causvid.models.wan.wan_wrapper import WanTextEncoder
import torch

# Inference-only check; disable autograd globally for this script.
torch.set_grad_enabled(False)

model = WanTextEncoder().to(device="cuda:0", dtype=torch.bfloat16)

# Two groups of prompts with different lengths (50 vs. 25 repetitions), so the
# batch mixes inputs of unequal size.
prompt_list = ["a " * 50] * 10 + ["b " * 25] * 10

# Expected leading dimensions of `prompt_embeds`: one row per prompt, then
# 512 and 4096.
# NOTE(review): 512 / 4096 are presumably the padded sequence length and the
# embedding width of the encoder -- confirm against WanTextEncoder.
expected_shape = (len(prompt_list), 512, 4096)

print("Test Text Encoder")

encoded_dict = model(prompt_list)
prompt_embeds = encoded_dict['prompt_embeds']

# Compare the first three dimensions in one go and report the actual shape on
# failure instead of a bare AssertionError.
actual_shape = tuple(prompt_embeds.shape[:3])
assert actual_shape == expected_shape, (
    f"unexpected prompt_embeds shape: got {tuple(prompt_embeds.shape)}, "
    f"expected leading dims {expected_shape}"
)