import torch
from resampler import Resampler
from transformers import CLIPVisionModel


BATCH_SIZE = 2
OUTPUT_DIM = 1280  # width of each projected ip_token
NUM_QUERIES = 8  # learned latent queries in the Resampler
NUM_LATENTS_MEAN_POOLED = 4  # extra latents derived from the mean-pooled input sequence
APPLY_POS_EMB = True  # add positional embeddings (up to max_seq_len) to the inputs
IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"


def main():
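    """Sanity-check the Resampler by projecting dummy CLIP image features."""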
    image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
    embedding_dim = image_encoder.config.hidden_size  # input width for the Resampler
    print(f"image_encoder hidden size: {embedding_dim}")
    image_proj_model = Resampler(
        dim=1024,
        depth=2,
        dim_head=64,
        heads=16,
        num_queries=NUM_QUERIES,
        embedding_dim=embedding_dim,
        output_dim=OUTPUT_DIM,
        ff_mult=2,
        max_seq_len=257,
        apply_pos_emb=APPLY_POS_EMB,
        num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,
    )
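
    # Encode a dummy batch; the penultimate hidden layer is used, as is common
    # for IP-Adapter-style resamplers. For ViT-H/14 at 224x224 this gives
    # (BATCH_SIZE, 257, embedding_dim): 256 patch tokens plus one class token.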
    dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224)
    with torch.no_grad():
        image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]
    print("image_embeds shape:", image_embeds.shape)
    with torch.no_grad():
        ip_tokens = image_proj_model(image_embeds)
    print("ip_tokens shape:", ip_tokens.shape)
    assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)


if __name__ == "__main__":
    main()