# Smoke-test script: load a JepaEncoder, embed two random inputs, print the
# results, and write a checkpoint to disk.
import numpy
import torch

from vjepa_encoder.vision_encoder import JepaEncoder

# Build the encoder from the YAML parameter file.
encoder = JepaEncoder.load_model("logs/params-encoder.yaml")

# Two random inputs handed to embed_image below: a NumPy array and a torch
# tensor.
# NOTE(review): presumably (H, W, C) for the array and (batch, C, H, W) for
# the tensor -- confirm against JepaEncoder.embed_image.
img = numpy.random.random(size=(360, 480, 3))
x = torch.rand((32, 3, 256, 900))

print("Input Img:", img.shape)

# Embed each input in turn and show both the embedding and its shape.
# `embedding` ends up holding the result for the tensor input `x`.
for sample in (img, x):
    embedding = encoder.embed_image(sample)
    print(embedding)
    print(embedding.shape)

# Persist the encoder so the checkpoint round-trip can be inspected afterwards.
encoder.save_checkpoint("./test_jepa_model.tar")