File size: 444 Bytes
8a70e8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from causvid.models.wan.wan_wrapper import WanTextEncoder
import torch

# Smoke test for WanTextEncoder: encode a batch of prompts on the GPU and
# check the shape of the returned prompt embeddings.

# Inference only: disable autograd globally so no graph is recorded.
torch.set_grad_enabled(False)

model = WanTextEncoder().to(device="cuda:0", dtype=torch.bfloat16)

# Two groups of prompts with different string lengths (50 vs. 25 repeated
# tokens), so the batch mixes inputs of unequal length.
prompt_list = ["a " * 50] * 10 + ["b " * 25] * 10


print("Test Text Encoder")

encoded_dict = model(prompt_list)

# Compare the whole shape at once: derive the batch size from the input list
# instead of hard-coding it, and fail with a readable message (rather than an
# IndexError) if the tensor has the wrong rank.
# NOTE(review): 512 and 4096 are presumed to be the encoder's fixed padded
# sequence length and embedding width; confirm against WanTextEncoder.
expected_shape = (len(prompt_list), 512, 4096)
actual_shape = tuple(encoded_dict['prompt_embeds'].shape)
assert actual_shape == expected_shape, (
    f"unexpected prompt_embeds shape: got {actual_shape}, "
    f"expected {expected_shape}"
)