| | from omegaconf import DictConfig |
| | import torch |
| | import torch.nn as nn |
| | from typing import List |
| |
|
| | from barista.data.metadata import Metadata |
| | from barista.models.tokenizer import Tokenizer |
| | from barista.models.transformer import Transformer |
| |
|
| |
|
class Barista(nn.Module):
    """Tokenizer + Transformer backbone with an optional downstream head.

    The downstream head (``channel_weights`` followed by ``binary_classifier``)
    is not built in ``__init__``; call :meth:`create_downstream_head` before
    :meth:`forward` or :meth:`get_task_params`.
    """

    def __init__(self, model_config: DictConfig, metadata: Metadata, *args, **kwargs):
        """Build the tokenizer and the transformer backbone.

        Args:
            model_config: config with ``tokenizer`` and ``backbone`` sections.
                ``backbone`` is expanded as keyword arguments to ``Transformer``
                and must also provide ``d_hidden``.
            metadata: metadata object forwarded to the tokenizer.
            *args, **kwargs: forwarded to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.metadata = metadata

        self.tokenizer = Tokenizer(
            config=model_config.tokenizer,
            metadata=self.metadata,
        )

        self.backbone = Transformer(
            **model_config.backbone,
        )

        # Backbone hidden size; used as the input size of the downstream classifier.
        self.d_hidden = model_config.backbone.d_hidden

        # Not read anywhere in this class; kept for external code that may inspect it.
        self.head = None

    def create_downstream_head(self, n_chans, output_dim):
        """(Re)build the downstream head.

        Args:
            n_chans: number of channels per sample. Multiplied by
                ``self.tokenizer.num_subsegments`` it gives the number of
                tokens per sample that :meth:`forward` pools over.
            output_dim: size of the classifier output.
        """
        # Learned, bias-free weighted sum over the tokens of one sample.
        self.channel_weights = nn.Linear(
            n_chans * self.tokenizer.num_subsegments,
            1,
            bias=False,
        )
        # Maps the pooled d_hidden vector to output_dim values. The name is kept
        # for backward compatibility; output_dim is not restricted to binary.
        self.binary_classifier = nn.Linear(
            self.d_hidden, output_dim
        )

    def _require_downstream_head(self):
        """Raise a clear error if create_downstream_head() has not been called."""
        if not (hasattr(self, "channel_weights") and hasattr(self, "binary_classifier")):
            raise RuntimeError(
                "Downstream head is not initialized; call create_downstream_head() first."
            )

    def get_latent_embeddings(self, x: torch.Tensor, subject_sessions: List):
        """Tokenize ``x`` and run the tokens through the transformer backbone.

        Args:
            x: model input passed straight to the tokenizer.
                NOTE(review): ``forward`` previously indexed ``x[0]``, which
                suggests ``x`` may be a sequence of tensors; confirm the type.
            subject_sessions: per-input subject/session identifiers passed to
                the tokenizer.

        Returns:
            The backbone output, unchanged.
        """
        tokenized_x = self.tokenizer(x, subject_sessions, output_as_list=False)

        latents = self.backbone(
            x=tokenized_x.tokens,
            seq_lens=tokenized_x.seq_lens,
            position_ids=tokenized_x.position_ids,
        )

        return latents

    def forward(self, x: torch.Tensor, subject_sessions: List):
        """Encode ``x`` and apply the downstream head.

        Args:
            x: model input, see :meth:`get_latent_embeddings`.
            subject_sessions: see :meth:`get_latent_embeddings`.

        Returns:
            Tensor of shape ``(batch, output_dim)``.

        Raises:
            RuntimeError: if :meth:`create_downstream_head` has not been called.
        """
        self._require_downstream_head()

        latents = self.get_latent_embeddings(x, subject_sessions)

        # Group latents as (batch, tokens_per_sample, d_hidden). The number of
        # tokens per sample is fixed by the head, so the batch size is inferred
        # instead of being read from ``x[0].shape[0]``. That expression only
        # equals the batch size when ``x`` is a sequence of batch-first tensors;
        # for a single tensor, ``x[0].shape[0]`` is the second dimension.
        tokens_per_sample = self.channel_weights.in_features
        latents_reshaped = latents.reshape(-1, tokens_per_sample, latents.shape[-1])

        # Weighted pooling over tokens: (B, D, T) -> (B, D, 1) -> (B, D).
        pooled = self.channel_weights(latents_reshaped.permute(0, 2, 1)).squeeze(dim=-1)
        return self.binary_classifier(pooled)

    def get_task_params(self):
        """Return ``(name, parameter)`` pairs of the downstream head.

        Raises:
            RuntimeError: if :meth:`create_downstream_head` has not been called.
        """
        self._require_downstream_head()
        return [
            *self.channel_weights.named_parameters(),
            *self.binary_classifier.named_parameters(),
        ]

    def get_upstream_params(self):
        """Return ``(name, parameter)`` pairs of the tokenizer and backbone."""
        return [
            *self.tokenizer.named_parameters(),
            *self.backbone.named_parameters(),
        ]
| |
|