# savaw's picture
# Upload folder using huggingface_hub
# a35137b verified
from omegaconf import DictConfig
import torch
import torch.nn as nn
from typing import List
from barista.data.metadata import Metadata
from barista.models.tokenizer import Tokenizer
from barista.models.transformer import Transformer
class Barista(nn.Module):
    """Tokenizer + transformer backbone with an optional downstream task head.

    The tokenizer converts raw input into token sequences, the transformer
    backbone maps them to latent embeddings, and a lazily-created head
    (``create_downstream_head``) pools over channels/subsegments and projects
    to task logits.
    """

    def __init__(self, model_config: DictConfig, metadata: Metadata, *args, **kwargs):
        """Build the tokenizer and transformer backbone from config.

        Args:
            model_config: Config with ``tokenizer`` and ``backbone`` sections;
                ``backbone.d_hidden`` is the latent width consumed by the head.
            metadata: Dataset metadata forwarded to the tokenizer.
        """
        super().__init__(*args, **kwargs)
        self.metadata = metadata
        self.tokenizer = Tokenizer(
            config=model_config.tokenizer,
            metadata=self.metadata,
        )
        self.backbone = Transformer(
            **model_config.backbone,
        )
        self.d_hidden = model_config.backbone.d_hidden
        # Placeholder only: create_downstream_head() attaches the head modules
        # as separate attributes and never reassigns this. Kept for backward
        # compatibility with callers that may check `model.head is None`.
        self.head = None

    def create_downstream_head(self, n_chans: int, output_dim: int) -> None:
        """Attach a task head: learned channel pooling + linear classifier.

        Args:
            n_chans: Number of input channels per example.
            output_dim: Number of output logits.
        """
        # Learned weighted combination over the n_chans * num_subsegments
        # token positions (bias-free so it is a pure weighting).
        self.channel_weights = nn.Linear(
            n_chans * self.tokenizer.num_subsegments,
            1,
            bias=False,
        )
        self.binary_classifier = nn.Linear(
            self.d_hidden, output_dim
        )

    def get_latent_embeddings(self, x: torch.Tensor, subject_sessions: List) -> torch.Tensor:
        """Tokenize the input and run it through the transformer backbone.

        Returns:
            Latent embeddings produced by the backbone.
        """
        # Get tokens
        tokenized_x = self.tokenizer(x, subject_sessions, output_as_list=False)
        # Pass through transformer
        latents = self.backbone(
            x=tokenized_x.tokens,
            seq_lens=tokenized_x.seq_lens,
            position_ids=tokenized_x.position_ids,
        )
        return latents

    def forward(self, x: torch.Tensor, subject_sessions: List) -> torch.Tensor:
        """Compute task logits of shape (batch, output_dim).

        Raises:
            RuntimeError: If ``create_downstream_head`` has not been called.
        """
        # Fail with a clear message instead of an opaque AttributeError when
        # the head modules have not been created yet.
        if not hasattr(self, "channel_weights"):
            raise RuntimeError(
                "Downstream head not initialized; call create_downstream_head() "
                "before forward()."
            )
        latents = self.get_latent_embeddings(x, subject_sessions)
        # Pass through Task head
        # NOTE(review): x is indexed as x[0] here, suggesting a list/sequence
        # of tensors rather than the annotated torch.Tensor — confirm with callers.
        batch_size = x[0].shape[0]
        # (B, tokens, D) -> weight over the token axis -> (B, D) -> (B, output_dim)
        latents_reshaped = latents.reshape(batch_size, -1, latents.shape[-1])
        x = self.channel_weights(latents_reshaped.permute(0, 2, 1)).squeeze(dim=-1)
        x = self.binary_classifier(x)
        return x

    def get_task_params(self):
        """Return (name, parameter) pairs of the downstream head modules."""
        return [
            *self.channel_weights.named_parameters(),
            *self.binary_classifier.named_parameters(),
        ]

    def get_upstream_params(self):
        """Return (name, parameter) pairs of the tokenizer and backbone."""
        return [*self.tokenizer.named_parameters(), *self.backbone.named_parameters()]