| | import torch |
| | import torch.nn as nn |
| | import clip |
| | from model import ST_GCN_18 |
| |
|
class ContrastiveModule(nn.Module):
    """CLIP-style contrastive model pairing an ST-GCN IMU encoder with
    CLIP's text transformer.

    CLIP's visual tower is discarded and replaced by an ``ST_GCN_18``
    backbone (attached as ``model.acc``) that embeds IMU sequences into
    the same 512-d space as CLIP's text features, enabling contrastive
    pretraining and, optionally, a linear-probe finetuning head.
    """

    def __init__(self, args):
        """Build the model.

        Args:
            args: namespace with at least ``gyro`` (bool), ``stft`` (bool),
                ``stage`` (str; ``'finetune'`` adds a classification head)
                and, when finetuning, ``num_class`` (int).
        """
        super().__init__()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load pretrained CLIP; the returned image preprocessor is unused.
        model, _ = clip.load("ViT-B/32", device=device)
        # Drop CLIP's visual tower -- the IMU encoder replaces it.
        del model.visual
        self.model = model

        # Input channels: 3 (accelerometer), doubled if gyro data is
        # appended and doubled again if STFT features are appended.
        base_channel = 3
        base_channel = base_channel * 2 if args.gyro else base_channel
        base_channel = base_channel * 2 if args.stft else base_channel
        self.model.acc = ST_GCN_18(in_channels=base_channel)

        # CLIP loads fp16 weights on CUDA; cast everything to fp32.
        self.model = self.model.float()

        if args.stage == 'finetune':
            # Linear probe over the 512-d IMU embedding.
            self.fc = nn.Linear(512, args.num_class)

    def encode_image(self, image):
        """Embed an IMU sequence via the ST-GCN branch.

        Returns the backbone output with trailing singleton spatial dims
        removed -- presumably (batch, 512); TODO confirm against ST_GCN_18.
        """
        return self.model.acc(image.float()).squeeze(-1).squeeze(-1)

    def encode_text(self, text):
        """Embed tokenized text with CLIP's text transformer.

        Mirrors ``clip.model.CLIP.encode_text`` but keeps fp32 throughout
        (the model was cast to float in ``__init__``).
        """
        x = self.model.token_embedding(text).float()
        x = x + self.model.positional_embedding.float()
        x = x.permute(1, 0, 2)  # NLD -> LND for the transformer
        x = self.model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x).float()

        # Take features at the EOT token (highest token id in each row)
        # and project into the joint embedding space.
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection

        return x

    def classifier(self, image):
        """Classification head for the finetune stage: IMU input -> class logits."""
        # Reuse encode_image so feature extraction lives in one place.
        imu_features = self.encode_image(image)
        out = self.fc(imu_features)
        return out

    def forward(self, inputs_imu, inputs_text):
        """Return CLIP-style similarity logits for a batch of IMU/text pairs.

        Returns:
            (logits_per_image, logits_per_text): each (batch, batch);
            the second is the transpose of the first.
        """
        imu_features = self.encode_image(inputs_imu)
        text_features = self.encode_text(inputs_text)

        # L2-normalize so dot products below are cosine similarities.
        imu_features = imu_features / imu_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Temperature-scaled cosine-similarity logits, as in CLIP.
        logit_scale = self.model.logit_scale.exp()
        logits_per_image = logit_scale * imu_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text