| | import torch |
| | from resampler import Resampler |
| | from transformers import CLIPVisionModel |
| |
|
# Configuration for the Resampler smoke test below.
BATCH_SIZE = 2
OUTPUT_DIM = 1280            # Resampler output width (passed as output_dim)
NUM_QUERIES = 8              # learned latent queries (passed as num_queries)
NUM_LATENTS_MEAN_POOLED = 4  # extra latents — presumably from mean-pooled features; see Resampler
APPLY_POS_EMB = True         # passed as apply_pos_emb to the Resampler
IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
| |
|
| |
|
def main():
    """Smoke-test the Resampler against CLIP vision features.

    Loads a CLIP vision encoder, builds a Resampler whose input width matches
    the encoder's hidden size, pushes a random image batch through both, and
    verifies the projected token shape.

    Raises:
        AssertionError: if the projected tokens do not have the expected shape.
    """
    image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
    embedding_dim = image_encoder.config.hidden_size
    # BUG FIX: was an f-string with no placeholder plus a positional arg,
    # which printed a stray separator; interpolate the value properly.
    print(f"image_encoder hidden size: {embedding_dim}")

    image_proj_model = Resampler(
        dim=1024,
        depth=2,
        dim_head=64,
        heads=16,
        num_queries=NUM_QUERIES,
        embedding_dim=embedding_dim,
        output_dim=OUTPUT_DIM,
        ff_mult=2,
        max_seq_len=257,  # presumably 256 patches + CLS for ViT-H/14 @ 224px — confirm for other encoders
        apply_pos_emb=APPLY_POS_EMB,
        num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,
    )

    dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224)
    # Single no_grad context for both the encoder and the projector —
    # the two adjacent contexts in the original were redundant.
    with torch.no_grad():
        # Penultimate hidden state, as selected by hidden_states[-2].
        image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]
        print("image_embeds shape: ", image_embeds.shape)  # fixed "image_embds" typo

        ip_tokens = image_proj_model(image_embeds)
        print("ip_tokens shape:", ip_tokens.shape)

    # Mean-pooled latents are concatenated with the learned queries, hence the sum.
    # Use an explicit check instead of `assert`, which is stripped under -O and
    # would silently neuter this smoke test.
    expected = (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)
    if tuple(ip_tokens.shape) != expected:
        raise AssertionError(
            f"ip_tokens shape {tuple(ip_tokens.shape)} != expected {expected}"
        )
| |
|
| |
|
# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    main()
| |
|