File size: 2,196 Bytes
a35137b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from omegaconf import DictConfig
import torch
import torch.nn as nn
from typing import List

from barista.data.metadata import Metadata
from barista.models.tokenizer import Tokenizer
from barista.models.transformer import Transformer


class Barista(nn.Module):
    """Tokenizer + Transformer backbone with an optional downstream task head.

    Inputs are embedded by a :class:`Tokenizer`, contextualized by a
    :class:`Transformer`, and — once :meth:`create_downstream_head` has been
    called — pooled across token slots and classified by a linear head.
    """

    def __init__(self, model_config: DictConfig, metadata: Metadata, *args, **kwargs):
        """Build the tokenizer and backbone from ``model_config``.

        Args:
            model_config: Config with ``tokenizer`` and ``backbone`` sections;
                ``backbone.d_hidden`` is the latent width consumed by the head.
            metadata: Dataset metadata forwarded to the tokenizer.
        """
        super().__init__(*args, **kwargs)
        self.metadata = metadata

        self.tokenizer = Tokenizer(
            config=model_config.tokenizer,
            metadata=self.metadata,
        )

        self.backbone = Transformer(
            **model_config.backbone,
        )

        self.d_hidden = model_config.backbone.d_hidden

        # Task head is attached lazily via create_downstream_head(); until
        # then, forward() is unusable.  NOTE(review): this attribute is never
        # reassigned — the head layers live in channel_weights /
        # binary_classifier instead.
        self.head = None

    def create_downstream_head(self, n_chans: int, output_dim: int) -> None:
        """Attach the downstream head: channel pooling + linear classifier.

        Args:
            n_chans: Number of input channels.
            output_dim: Output dimension of the classifier.
        """
        # Learned weighted average over the n_chans * num_subsegments token
        # slots produced per sample.
        self.channel_weights = nn.Linear(
            n_chans * self.tokenizer.num_subsegments,
            1,
            bias=False,
        )
        self.binary_classifier = nn.Linear(self.d_hidden, output_dim)

    def get_latent_embeddings(self, x: torch.Tensor, subject_sessions: List) -> torch.Tensor:
        """Tokenize ``x`` and return the backbone's latent embeddings."""
        # Single packed tensor of tokens rather than a per-sample list.
        tokenized_x = self.tokenizer(x, subject_sessions, output_as_list=False)

        # Contextualize tokens with the transformer backbone.
        latents = self.backbone(
            x=tokenized_x.tokens,
            seq_lens=tokenized_x.seq_lens,
            position_ids=tokenized_x.position_ids,
        )

        return latents

    def forward(self, x: torch.Tensor, subject_sessions: List) -> torch.Tensor:
        """Full pass: latent embeddings -> channel pooling -> classifier.

        Raises:
            RuntimeError: If :meth:`create_downstream_head` was not called
                before the first forward pass.
        """
        # Fail fast with a clear message instead of a cryptic AttributeError
        # on self.channel_weights below.
        if not hasattr(self, "channel_weights"):
            raise RuntimeError(
                "Task head not initialized; call create_downstream_head() "
                "before forward()."
            )

        latents = self.get_latent_embeddings(x, subject_sessions)

        # NOTE(review): x is annotated as a Tensor but indexed as x[0].shape[0]
        # — this assumes x is a sequence of per-channel tensors; confirm
        # against the tokenizer's expected input.
        batch_size = x[0].shape[0]
        # Regroup latents per sample: (batch, n_slots, d_hidden).
        latents_reshaped = latents.reshape(batch_size, -1, latents.shape[-1])
        # Weighted pooling over slots -> (batch, d_hidden), then classify.
        pooled = self.channel_weights(latents_reshaped.permute(0, 2, 1)).squeeze(dim=-1)
        return self.binary_classifier(pooled)

    def get_task_params(self):
        """Named parameters of the downstream head (task-specific optimizer group)."""
        return [
            *self.channel_weights.named_parameters(),
            *self.binary_classifier.named_parameters(),
        ]

    def get_upstream_params(self):
        """Named parameters of the tokenizer and backbone (pretrained/upstream group)."""
        return [*self.tokenizer.named_parameters(), *self.backbone.named_parameters()]