File size: 444 Bytes
8a70e8e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | from causvid.models.wan.wan_wrapper import WanTextEncoder
import torch
# Smoke test for WanTextEncoder: encode a batch of prompts of two different
# lengths and check the shape of the returned 'prompt_embeds' tensor.
# NOTE(review): needs a CUDA device at "cuda:0" plus whatever weights
# WanTextEncoder() loads on construction — TODO confirm weight location.

# Inference only: turn off autograd globally for the rest of this script.
torch.set_grad_enabled(False)

model = WanTextEncoder().to(device="cuda:0", dtype=torch.bfloat16)

# 10 prompts of 50 repeated "a " plus 10 prompts of 25 repeated "b ", so the
# batch mixes two prompt lengths (presumably to exercise padding — confirm).
prompt_list = ["a " * 50] * 10 + ["b " * 25] * 10

print("Test Text Encoder")
encoded_dict = model(prompt_list)

prompt_embeds = encoded_dict['prompt_embeds']
# Dim 0 must equal the number of prompts; 512 and 4096 are the fixed sizes
# this test has always expected for dims 1 and 2 (presumably the padded
# text length and the embedding width — TODO confirm against WanTextEncoder).
# Derive the batch size from prompt_list so editing the list keeps the check valid.
expected_shape = (len(prompt_list), 512, 4096)
actual_shape = tuple(prompt_embeds.shape)
assert actual_shape == expected_shape, (
    f"unexpected prompt_embeds shape {actual_shape}, expected {expected_shape}"
)
|