File size: 902 Bytes
8a70e8e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 | from causvid.models.sdxl.sdxl_wrapper import SDXLTextEncoder
import time
# Smoke test for SDXLTextEncoder: checks the tokenizer output
# (`_encode_prompt`) and the full encoder call (`model(...)`) against the
# expected tensor shapes, and prints a rough per-call tokenization time.

BATCH_SIZE = 20          # number of prompts sent through in one call
NUM_TIMING_RUNS = 20     # how many times the tokenizer call is repeated/timed
EXPECTED_SEQ_LEN = 77    # sequence length asserted for both tokenizers
EXPECTED_EMBED_DIM = 2048   # last dim asserted for `prompt_embeds`
EXPECTED_POOLED_DIM = 1280  # last dim asserted for `pooled_prompt_embeds`

model = SDXLTextEncoder()

# Identical 300-character prompts. NOTE(review): presumably long on purpose so
# that the fixed output length asserted below is exercised — confirm.
prompt_list = ["a" * 300] * BATCH_SIZE

print("Test Text Tokenizer")
for _ in range(NUM_TIMING_RUNS):
    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # and only the tokenizer call itself sits inside the timed interval.
    start = time.perf_counter()
    output = model._encode_prompt(prompt_list)
    elapsed = time.perf_counter() - start

    for key in ("text_input_ids_one", "text_input_ids_two"):
        assert key in output, f"missing key {key!r} in tokenizer output"
        ids_shape = tuple(output[key].shape[:2])
        assert ids_shape == (BATCH_SIZE, EXPECTED_SEQ_LEN), (
            f"{key}: expected leading shape "
            f"{(BATCH_SIZE, EXPECTED_SEQ_LEN)}, got {ids_shape}"
        )

    print(f"Time taken: {elapsed}")

print("Test Text Encoder")
encoded_dict = model(prompt_list)

# NOTE(review): only dims 1 and 2 of `prompt_embeds` are checked, as in the
# original script; its batch dimension is presumably BATCH_SIZE as well, but
# that is not asserted here — confirm against SDXLTextEncoder before adding.
embeds_shape = tuple(encoded_dict['prompt_embeds'].shape[1:3])
assert embeds_shape == (EXPECTED_SEQ_LEN, EXPECTED_EMBED_DIM), (
    f"prompt_embeds: expected dims 1-2 to be "
    f"{(EXPECTED_SEQ_LEN, EXPECTED_EMBED_DIM)}, got {embeds_shape}"
)

pooled_shape = tuple(encoded_dict['pooled_prompt_embeds'].shape[:2])
assert pooled_shape == (BATCH_SIZE, EXPECTED_POOLED_DIM), (
    f"pooled_prompt_embeds: expected leading shape "
    f"{(BATCH_SIZE, EXPECTED_POOLED_DIM)}, got {pooled_shape}"
)
|