Spaces:
Sleeping
Sleeping
File size: 1,617 Bytes
c1596ac | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 | import torch
import torch.nn as nn
from torchvision import models
class EncoderViTB16(nn.Module):
    """Frozen torchvision ViT-B/16 backbone with a classification head and a
    per-patch projection head for captioning.

    The backbone is built with ``ViT_B_16_Weights.DEFAULT`` and all of its
    parameters are frozen; only ``classifier`` and ``projector`` are trainable.

    Args:
        num_classes: Number of outputs of ``classifier``.
        embed_size: Output width of ``projector`` (per-patch feature size).
    """

    def __init__(self, num_classes=50, embed_size=512):
        super().__init__()
        model = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
        self.backbone = model
        # Freeze the pretrained backbone; only the new heads below are trained.
        for param in self.backbone.parameters():
            param.requires_grad = False
        # Read the hidden width before the original head is discarded.
        in_features = model.heads.head.in_features
        # Replace the original head so calling the backbone yields
        # in_features-wide hidden features instead of the original logits
        # (in torchvision this is presumably the class-token embedding --
        # confirm against the installed version).
        self.backbone.heads = nn.Identity()
        self.classifier = nn.Linear(in_features, num_classes)
        self.projector = nn.Linear(in_features, embed_size)

    def forward(self, images, return_features=False):
        """Classify images, or return projected per-patch features.

        Only the requested output is computed, so each call runs the
        transformer encoder once (previously both paths always ran and one
        result was discarded).

        Args:
            images: Image batch for the ViT backbone; presumably
                (B, 3, 224, 224) -- TODO confirm. The positional-embedding
                slice in the feature path only matches the backbone's native
                number of patches.
            return_features: If False (default), return classification logits
                of shape (B, num_classes). If True, return per-patch features
                of shape (B, num_patches, embed_size) for captioning.

        Returns:
            torch.Tensor: logits or per-patch features, as described above.
        """
        if return_features:
            return self._patch_features(images)
        return self._classify(images)

    def _classify(self, images):
        """Return classification logits computed from the backbone output."""
        features = self.backbone(images)
        # Defensive: tolerate backbones that return (features, ...) tuples.
        if isinstance(features, tuple):
            features = features[0]
        # Flatten to (B, in_features); reshape (unlike view) also accepts
        # non-contiguous tensors and yields the same values.
        features = features.reshape(features.size(0), -1)
        return self.classifier(features)

    def _patch_features(self, images):
        """Encode patch tokens with the frozen encoder and project each one."""
        encoder = self.backbone.encoder
        # Patch embeddings, e.g. (B, 196, 768) for a 224x224 input.
        tokens = self.backbone._process_input(images)
        # Add positional embeddings for the patch positions only; slot 0 is
        # skipped (presumably the class-token position).
        tokens = tokens + encoder.pos_embedding[:, 1:, :]
        # NOTE(review): unlike the backbone's own forward, no class token is
        # prepended and encoder.dropout is not applied. Kept as-is because a
        # downstream decoder may have been trained on exactly these features.
        for layer in encoder.layers:
            tokens = layer(tokens)
        tokens = encoder.ln(tokens)  # final LayerNorm
        return self.projector(tokens)  # (B, num_patches, embed_size)