How to get register token output values?
Are the first 5 tokens [CLS] + 4x [REG] tokens?
Yes, num_prefix_tokens is 1 cls + 4 x reg for these models: https://github.com/huggingface/pytorch-image-models/blob/a6fe31b09670289dbc8e99a0cfae23de355534c9/timm/models/vision_transformer.py#L497-L498
The easiest way to get them is to call forward_features() and take [1:5] along the token dimension of the flattened output, or you can use forward_intermediates() to get the prefix tokens for all blocks:
>>> oo = mm.forward_intermediates(torch.randn(2,3,518,518), return_prefix_tokens=True)
>>> oo[1][-1][1].shape
torch.Size([2, 5, 768])
The output there is a tuple of the final features and the block output features. Each block output is itself a tuple of (spatial features, prefix tokens) when return_prefix_tokens is set to True.
Hi! I just would like to ask if what I did below is correct (please see screenshot)
So you said that the easiest way to get the prefix token embeddings is using forward_features() and take the first 5 in the sequence. I did that (top) and compared it to using forward_intermediates()... However, their outputs are different. Is there something that I have missed? Would appreciate your help! Thank you so much :)
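EDIT: I was able to show that they're the same... I forgot to add the norm=True argument in forward_intermediates(). forward_features() applies the final norm layer, while forward_intermediates() only does so for the block outputs when norm=True. Hope this helps!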
Hi! It's me again. Just one more question:
Screenshot below is taken from (https://github.com/facebookresearch/dinov2/blob/main/MODEL_CARD.md)
As I've understood, there's a total of 261 tokens (1 class + 4 register + 256 patch tokens). Now, going back to the timm version, the output shape is (1, 1374, 768). Is the 1374 semantically equivalent to the 261, i.e., is 1374 the length of the token sequence? How does it come up with this versus the 261? Thank you :-)
The dinov2 models, I think, are 518x518 by default, so 37*37 spatial patches + 1 cls token + 4 reg tokens = 1374. It would be 261 (16*16 patches + 5) if you resized and used 224x224 images.
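The arithmetic can be checked with a few lines of plain Python. The helper name num_tokens is made up for illustration, and the sketch assumes square images, a patch size of 14, and an image size that divides evenly by it.

```python
def num_tokens(img_size: int, patch_size: int = 14,
               num_cls: int = 1, num_reg: int = 4) -> int:
    """Sequence length for a square image: patch grid plus prefix tokens."""
    grid = img_size // patch_size      # patches per side
    return grid * grid + num_cls + num_reg

print(num_tokens(518))  # 37*37 + 5 = 1374
print(num_tokens(224))  # 16*16 + 5 = 261
```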
Your snippets above are correct if you want both cls + reg tokens together. If you want just the regs, slice [1:5] to get the 4 reg tokens.
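To make the slicing concrete, here is a toy sketch that uses a plain list of labels in place of the real token tensor (261 entries, as for a 224x224 input). The labels are purely illustrative. On the actual output the same slices apply along the token dimension, e.g. out[:, 1:5].

```python
# Stand-in for the token dimension: 1 cls, 4 reg, then 16*16 patch tokens.
tokens = (
    ["CLS"]
    + [f"REG{i}" for i in range(1, 5)]
    + [f"PATCH{i}" for i in range(16 * 16)]
)

prefix = tokens[:5]    # cls + reg tokens together
regs = tokens[1:5]     # just the 4 register tokens
patches = tokens[5:]   # spatial patch tokens only

print(prefix)
print(regs)
print(len(patches))
```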