# causvid/tests/sdxl/test_text_encoder.py
# Author: lyttt. Added with the upload-large-folder tool (commit 8a70e8e, verified).
from causvid.models.sdxl.sdxl_wrapper import SDXLTextEncoder
import time
# Smoke test and micro-benchmark for SDXLTextEncoder: times the tokenizer path
# (`_encode_prompt`) and checks the shapes produced by both the tokenizer path
# and the full encoder call.
#
# NOTE(review): everything below runs at module import time, and the file
# defines no ``test_*`` functions. Under pytest it therefore executes during
# collection, and any failure is reported as a collection error. Consider
# moving it into test functions or behind ``if __name__ == "__main__":``.

BATCH_SIZE = 20  # number of prompts encoded per call
NUM_TIMING_RUNS = 20  # how many times the tokenizer path is timed
MAX_TOKEN_LENGTH = 77  # sequence length the token-id tensors are expected to have

model = SDXLTextEncoder()
# A 300-character prompt. It is presumably long enough to exceed the 77-token
# limit, so that truncation is exercised. TODO confirm.
prompt_list = ["a" * 300] * BATCH_SIZE

print("Test Text Tokenizer")
for _ in range(NUM_TIMING_RUNS):
    # perf_counter is monotonic and high-resolution; time.time() is a wall
    # clock and can jump. Only the encode call is timed, not the checks.
    start = time.perf_counter()
    output = model._encode_prompt(prompt_list)
    elapsed = time.perf_counter() - start

    # Both tokenizer outputs must be present, with leading dims
    # (BATCH_SIZE, MAX_TOKEN_LENGTH).
    for key in ("text_input_ids_one", "text_input_ids_two"):
        assert key in output, f"missing {key!r} in _encode_prompt output"
        ids_shape = tuple(output[key].shape)
        assert ids_shape[:2] == (BATCH_SIZE, MAX_TOKEN_LENGTH), (
            f"{key} has unexpected shape {ids_shape}"
        )
    print(f"Time taken: {elapsed}")

print("Test Text Encoder")
encoded_dict = model(prompt_list)

# Expected widths: 2048 for the per-token embeddings and 1280 for the pooled
# embedding. These are presumably SDXL's concatenated hidden size and the
# second encoder's pooled size. Confirm against SDXLTextEncoder.
prompt_embeds_shape = tuple(encoded_dict["prompt_embeds"].shape)
assert prompt_embeds_shape[:3] == (BATCH_SIZE, MAX_TOKEN_LENGTH, 2048), (
    f"prompt_embeds has unexpected shape {prompt_embeds_shape}"
)
pooled_shape = tuple(encoded_dict["pooled_prompt_embeds"].shape)
assert pooled_shape[:2] == (BATCH_SIZE, 1280), (
    f"pooled_prompt_embeds has unexpected shape {pooled_shape}"
)