File size: 902 Bytes
8a70e8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from causvid.models.sdxl.sdxl_wrapper import SDXLTextEncoder
import time

# Text encoder under test; built once and shared by the tokenizer and
# encoder checks that follow.
model = SDXLTextEncoder()

# A batch of 20 identical, deliberately long prompts.
# NOTE(review): presumably long enough to exceed the tokenizers' max length so
# padding/truncation to a fixed width is exercised — confirm; the actual token
# count of "a" * 300 depends on the BPE vocabulary.
prompt_list = 20 * ["a" * 300]

print("Test Text Tokenizer")

# Expected leading dims of each tokenizer's input ids. The batch size is
# derived from the prompts rather than hard-coded, so the check stays right
# if prompt_list changes.
tokenizer_batch_size = len(prompt_list)
# NOTE(review): 77 is presumably the CLIP tokenizers' max sequence length
# (prompts padded/truncated to it) — confirm against SDXLTextEncoder.
tokenizer_max_length = 77
tokenizer_num_runs = 20

for _ in range(tokenizer_num_runs):
    # perf_counter is monotonic and high-resolution; time.time() can jump if
    # the system clock is adjusted mid-measurement.
    start = time.perf_counter()
    output = model._encode_prompt(prompt_list)
    # Stop the clock before validation so only _encode_prompt is timed.
    elapsed = time.perf_counter() - start

    # Both tokenizers' ids must be present with (batch, max_length) leading dims.
    for key in ("text_input_ids_one", "text_input_ids_two"):
        assert key in output, f"missing {key!r} in _encode_prompt output"
        leading_dims = tuple(output[key].shape[:2])
        expected_dims = (tokenizer_batch_size, tokenizer_max_length)
        assert leading_dims == expected_dims, (
            f"{key}: expected leading dims {expected_dims}, got {leading_dims}"
        )

    print(f"Time taken: {elapsed}")

print("Test Text Encoder")

encoded_dict = model(prompt_list)

# Every output's batch dimension must match the number of prompts. The
# original check skipped this for prompt_embeds.
encoder_batch_size = len(prompt_list)

# NOTE(review): 77 is presumably the tokenizer max length. 2048 presumably is
# the two SDXL text encoders' hidden states concatenated along the feature
# axis, and 1280 the pooled embedding width. Confirm against SDXLTextEncoder.
embeds_dims = tuple(encoded_dict["prompt_embeds"].shape[:3])
expected_embeds_dims = (encoder_batch_size, 77, 2048)
assert embeds_dims == expected_embeds_dims, (
    f"prompt_embeds: expected leading dims {expected_embeds_dims}, got {embeds_dims}"
)

pooled_dims = tuple(encoded_dict["pooled_prompt_embeds"].shape[:2])
expected_pooled_dims = (encoder_batch_size, 1280)
assert pooled_dims == expected_pooled_dims, (
    f"pooled_prompt_embeds: expected leading dims {expected_pooled_dims}, got {pooled_dims}"
)