```python
from sentence_transformers import models
import torch
import torch.nn as nn


class CustTrans(models.Transformer):
    """Transformer module that projects token embeddings onto a
    task-specific hyperplane selected by ``task_type``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.curr_task_type = None
        self._rebuild_taskembedding(['sts', 'quora'])

    def forward(self, inputs, task_type=None):
        # Run the underlying Hugging Face model on the tokenized inputs.
        enc = self.auto_model(**inputs).last_hidden_state

        # Fall back to the task type set via _set_curr_task_type.
        if task_type is None:
            task_type = self.curr_task_type

        if task_type in self.task_types:
            # Look up the normal vector of this task's hyperplane and
            # project the token embeddings onto it.
            idx = torch.tensor(
                self.task_types.index(task_type),
                device=self.TaskEmbedding.weight.device,
            )
            hyp = self.TaskEmbedding(idx)
            inputs['token_embeddings'] = self._project(enc, hyp)
        else:
            # Unknown task type: pass the embeddings through unchanged.
            inputs['token_embeddings'] = enc

        return inputs

    def _set_curr_task_type(self, task_type):
        self.curr_task_type = task_type

    def _set_taskembedding_grad(self, value):
        self.TaskEmbedding.weight.requires_grad = value

    def _set_transformer_grad(self, value):
        for param in self.auto_model.parameters():
            param.requires_grad = value

    def _rebuild_taskembedding(self, task_types):
        self.task_types = task_types
        # One mask row per task: ones everywhere except a zero on the
        # task's own axis, i.e. the normal of an axis-aligned hyperplane.
        self.task_emb = 1 - torch.eye(len(self.task_types), 768)
        # from_pretrained is a classmethod and freezes the weights by
        # default; call _set_taskembedding_grad(True) to train them.
        self.TaskEmbedding = nn.Embedding.from_pretrained(self.task_emb)

    def _project(self, v, normal_hyper):
        # Multiplying by the 0/1 mask zeroes the coordinate along the task
        # axis, projecting v onto the corresponding hyperplane.
        return v * normal_hyper
```
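For context, here is a minimal usage sketch of how this module could be wired into a `SentenceTransformer` pipeline. The base checkpoint name, the example sentence, and the pooling setup are illustrative assumptions, not part of the class above:

```python
from sentence_transformers import SentenceTransformer, models

# Hypothetical base checkpoint; any Hugging Face encoder would work here.
word_emb = CustTrans('distilbert-base-uncased')
pooling = models.Pooling(word_emb.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_emb, pooling])

# Choose which task hyperplane to project onto before encoding.
word_emb._set_curr_task_type('sts')
sts_vecs = model.encode(['An example sentence for the STS task.'])

# Assumed training setup: freeze the transformer and unfreeze only the
# task embeddings, e.g. to tune the hyperplanes with the encoder fixed.
word_emb._set_transformer_grad(False)
word_emb._set_taskembedding_grad(True)
```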