Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoModel | |
class TextEncoder(nn.Module):
    """Map tokenized text to a fixed-size embedding via a partially fine-tuned transformer.

    A pretrained Hugging Face backbone is loaded with ``AutoModel``. All of its
    parameters are frozen except the last ``unfreeze_n_blocks`` encoder layers
    (and the pooler, if the backbone has one). A per-sequence summary of the
    backbone's ``last_hidden_state`` is then projected to ``output_dim`` by a
    trainable linear head (``self.fc``).

    Args:
        output_dim: Size of the embedding produced by ``self.fc``.
        lang_model: Model name or path passed to ``AutoModel.from_pretrained``.
        unfreeze_n_blocks: Number of trailing encoder layers left trainable.
            ``0`` keeps every encoder layer frozen. Values larger than the
            number of layers make all of them trainable.
        pooling: How each sequence is summarized before ``self.fc``.
            ``"cls"`` (default, original behavior) takes the first token.
            ``"mean"`` averages tokens, weighted by ``attention_mask`` when one
            is given.

    Raises:
        ValueError: If ``unfreeze_n_blocks`` is negative or ``pooling`` is not
            ``"cls"`` or ``"mean"``.
    """

    def __init__(
        self,
        output_dim: int = 64,
        lang_model: str = "sentence-transformers/all-MiniLM-L6-v2",
        unfreeze_n_blocks: int = 4,
        pooling: str = "cls",
    ):
        # Validate before loading any weights so bad arguments fail fast.
        if unfreeze_n_blocks < 0:
            raise ValueError(
                f"unfreeze_n_blocks must be >= 0, got {unfreeze_n_blocks}"
            )
        if pooling not in ("cls", "mean"):
            raise ValueError(f"pooling must be 'cls' or 'mean', got {pooling!r}")

        super().__init__()
        self.lang_model = lang_model
        self.pooling = pooling
        self.encoder = AutoModel.from_pretrained(lang_model)

        # Freeze the whole backbone first, then selectively re-enable parts.
        for param in self.encoder.parameters():
            param.requires_grad = False

        # Unfreeze the last few encoder layers. Assumes a BERT-style backbone
        # exposing ``.encoder.layer`` (TODO confirm for non-default lang_model).
        # The explicit ``> 0`` guard matters: ``layer[-0:]`` is ``layer[0:]``,
        # which would unfreeze every layer instead of none.
        if unfreeze_n_blocks > 0:
            for layer in self.encoder.encoder.layer[-unfreeze_n_blocks:]:
                for param in layer.parameters():
                    param.requires_grad = True

        # Unfreeze the pooler only if it exists; depending on the architecture
        # it may be absent or set to None. NOTE(review): forward() never reads
        # pooler_output, so these parameters receive no gradient. They stay
        # trainable only to preserve the original set of requires_grad params.
        pooler = getattr(self.encoder, "pooler", None)
        if pooler is not None:
            for param in pooler.parameters():
                param.requires_grad = True

        self.fc = nn.Linear(self.encoder.config.hidden_size, output_dim)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Encode a batch of token ids into ``output_dim``-sized embeddings.

        Args:
            input_ids: Token ids for the backbone. Presumably shaped
                ``(batch, seq_len)``; TODO confirm against the tokenizer call site.
            attention_mask: Optional mask forwarded to the backbone. With
                ``pooling="mean"`` it also excludes padded tokens from the average.

        Returns:
            ``self.fc`` applied to the pooled hidden state of each sequence.
        """
        hidden = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

        # getattr default keeps instances pickled from the pre-``pooling``
        # version of this class working with the original first-token pooling.
        if getattr(self, "pooling", "cls") == "mean":
            if attention_mask is None:
                pooled = hidden.mean(dim=1)
            else:
                weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
                # clamp avoids division by zero for fully padded rows.
                pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(
                    min=1e-9
                )
        else:
            # First token of each sequence (the [CLS] position if the
            # tokenizer prepends one).
            pooled = hidden[:, 0]
        return self.fc(pooled)