import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch import Tensor
from torch.nn.functional import silu, softplus
from einops import rearrange, repeat, einsum
from transformers import AutoTokenizer, AutoModel


class Embedding:
    """Wraps a pretrained Hugging Face encoder and returns token or sentence embeddings."""

    def __init__(self, model_name='jina', pooling=None):
        self.model_name = model_name
        self.pooling = pooling
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if model_name == 'jina':
            self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3", code_revision='da863dd04a4e5dce6814c6625adfba87b83838aa', trust_remote_code=True)
            self.model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", code_revision='da863dd04a4e5dce6814c6625adfba87b83838aa', trust_remote_code=True).to(self.device)
        elif model_name == 'xlm-roberta-base':
            self.tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
            self.model = AutoModel.from_pretrained('xlm-roberta-base').to(self.device)
        elif model_name == 'canine-c':
            self.tokenizer = AutoTokenizer.from_pretrained('google/canine-c')
            self.model = AutoModel.from_pretrained('google/canine-c').to(self.device)
        else:
            raise ValueError('Unknown name of Embedding')

    def _mean_pooling(self, X):
        # Average token embeddings weighted by the attention mask, then L2-normalize.
        def mean_pooling(model_output, attention_mask):
            token_embeddings = model_output[0]
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        encoded_input = self.tokenizer(X, padding=True, truncation=True, return_tensors='pt').to(self.device)
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.unsqueeze(1)

    def get_embeddings(self, X):
        if self.pooling is None:
            # Without pooling, return per-token embeddings padded/truncated to a fixed length
            # (character-level CANINE sequences are longer than subword ones).
            max_len = 329 if self.model_name == 'canine-c' else 95
            encoded_input = self.tokenizer(X, padding=True, truncation=True, return_tensors='pt').to(self.device)
            with torch.no_grad():
                features = self.model(**encoded_input)[0].detach().cpu().float().numpy()
            res = np.pad(features[:, :max_len, :], ((0, 0), (0, max(0, max_len - features.shape[1])), (0, 0)), "constant")
            return torch.tensor(res)
        elif self.pooling == 'mean':
            return self._mean_pooling(X)
        else:
            raise ValueError('Unknown type of pooling')

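# Example usage (a minimal sketch; the sentences below are illustrative inputs, not from the
# original code). With pooling='mean' the call returns a (batch, 1, hidden) sentence embedding,
# with pooling=None a padded (batch, max_len, hidden) token-embedding tensor.
# emb = Embedding(model_name='xlm-roberta-base', pooling='mean')
# vectors = emb.get_embeddings(["I am so happy today!", "This is terrible."])
# print(vectors.shape)  # e.g. torch.Size([2, 1, 768]) for xlm-roberta-base

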
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: Tensor) -> Tensor:
        # Normalize by the root mean square over the last dimension, then rescale.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class Mamba(nn.Module):
    def __init__(self, num_layers, d_input, d_model, d_state=16, d_discr=None, ker_size=4, num_classes=7, model_name='jina', pooling=None):
        super().__init__()
        mamba_par = {
            'd_input': d_input,
            'd_model': d_model,
            'd_state': d_state,
            'd_discr': d_discr,
            'ker_size': ker_size,
        }
        self.model_name = model_name
        embed = Embedding(model_name, pooling)
        self.embedding = embed.get_embeddings
        # Each layer is a pre-norm residual Mamba block.
        self.layers = nn.ModuleList([nn.ModuleList([MambaBlock(**mamba_par), RMSNorm(d_input)]) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_input, num_classes)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.softmax = nn.Softmax(dim=1)

    def forward(self, seq, cache=None):
        # get_embeddings already returns a tensor, so avoid re-wrapping it with torch.tensor().
        seq = self.embedding(seq).to(self.device)
        for mamba, norm in self.layers:
            out, cache = mamba(norm(seq), cache)
            seq = out + seq
        # Mean-pool over the sequence dimension before the classification head.
        return self.fc_out(seq.mean(dim=1))

    def predict(self, x):
        label_to_emotion = {
            0: 'anger',
            1: 'disgust',
            2: 'fear',
            3: 'joy/happiness',
            4: 'neutral',
            5: 'sadness',
            6: 'surprise/enthusiasm',
        }
        with torch.no_grad():
            output = self.forward(x)
        _, predictions = torch.max(output, dim=1)
        return [label_to_emotion[int(i)] for i in predictions]

    def predict_proba(self, x):
        with torch.no_grad():
            output = self.forward(x)
        return self.softmax(output)


class MambaBlock(nn.Module):
    def __init__(self, d_input, d_model, d_state=16, d_discr=None, ker_size=4):
        super().__init__()
        d_discr = d_discr if d_discr is not None else d_model // 16
        self.in_proj = nn.Linear(d_input, 2 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_input, bias=False)
        # Input-dependent projections for the SSM parameters B, C and the step size.
        self.s_B = nn.Linear(d_model, d_state, bias=False)
        self.s_C = nn.Linear(d_model, d_state, bias=False)
        self.s_D = nn.Sequential(
            nn.Linear(d_model, d_discr, bias=False),
            nn.Linear(d_discr, d_model, bias=False),
        )
        # Depthwise causal convolution over the sequence dimension.
        self.conv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=ker_size,
            padding=ker_size - 1,
            groups=d_model,
            bias=True,
        )
        self.A = nn.Parameter(torch.arange(1, d_state + 1, dtype=torch.float).repeat(d_model, 1))
        self.D = nn.Parameter(torch.ones(d_model, dtype=torch.float))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, seq, cache=None):
        b, l, d = seq.shape
        (prev_hid, prev_inp) = cache if cache is not None else (None, None)
        # Split the input projection into the SSM branch (a) and the gating branch (b).
        a, b = self.in_proj(seq).chunk(2, dim=-1)
        x = rearrange(a, 'b l d -> b d l')
        x = x if prev_inp is None else torch.cat((prev_inp, x), dim=-1)
        a = self.conv(x)[..., :l]  # causal conv: keep only the first l positions
        a = rearrange(a, 'b d l -> b l d')
        a = silu(a)
        a, hid = self.ssm(a, prev_hid=prev_hid)
        b = silu(b)
        out = a * b  # gated output
        out = self.out_proj(out)
        if cache:
            cache = (hid.squeeze(), x[..., 1:])
        return out, cache

    def ssm(self, seq, prev_hid):
        A = -self.A
        D = +self.D
        B = self.s_B(seq)
        C = self.s_C(seq)
        # Input-dependent step size, shared across the state dimension.
        s = softplus(D + self.s_D(seq))
        # Discretized, input-dependent state-space parameters.
        A_bar = einsum(torch.exp(A), s, 'd s, b l d -> b l d s')
        B_bar = einsum(B, s, 'b l s, b l d -> b l d s')
        X_bar = einsum(B_bar, seq, 'b l d s, b l d -> b l d s')
        hid = self._hid_states(A_bar, X_bar, prev_hid=prev_hid)
        out = einsum(hid, C, 'b l d s, b l s -> b l d')
        out = out + D * seq  # skip connection through D
        return out, hid

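    # For reference (editorial note, not from the original code): _hid_states below unrolls
    # the selective-scan recurrence over the sequence, per channel d and state s:
    #     h_t = A_bar_t * h_{t-1} + X_bar_t,  with X_bar_t = B_bar_t * x_t,
    # and ssm() then reads it out as y_t = C_t . h_t + D * x_t.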
    def _hid_states(self, A, X, prev_hid=None):
        b, l, d, s = A.shape
        A = rearrange(A, 'b l d s -> l b d s')
        X = rearrange(X, 'b l d s -> l b d s')
        if prev_hid is not None:
            # Single-step update when a cached hidden state is provided.
            return rearrange(A * prev_hid + X, 'l b d s -> b l d s')
        h = torch.zeros(b, d, s, device=self.device)
        # Sequential scan: h := A_t * h + X_t, stacked back along the sequence dimension.
        return torch.stack([h := A_t * h + X_t for A_t, X_t in zip(A, X)], dim=1)
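

# Example usage (a minimal sketch; the hyperparameters and input sentence are illustrative
# assumptions, not values from the original code). d_input must match the hidden size of the
# chosen embedding model (e.g. 768 for xlm-roberta-base), and num_classes defaults to the
# seven emotion labels used in predict().
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = Mamba(num_layers=2, d_input=768, d_model=256, model_name='xlm-roberta-base').to(device)
# model.eval()
# print(model.predict(["I can't believe this actually worked!"]))
# print(model.predict_proba(["I can't believe this actually worked!"]))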