Spaces:
Running
Running
| import torch.nn as nn | |
| from transformers import ( | |
| AutoModel, | |
| AutoTokenizer, | |
| PreTrainedTokenizerFast, | |
| PreTrainedModel, | |
| ) | |
| import torch | |
| from omegaconf import DictConfig | |
| from typing import Dict | |
| import transformers | |
| transformers.logging.set_verbosity_error() | |
| class BaseDocEncoder(nn.Module): | |
| def __init__(self, config: DictConfig): | |
| super(BaseDocEncoder, self).__init__() | |
| self.config = config | |
| gradient_checkpointing = False | |
| if config.finetune: | |
| gradient_checkpointing = True | |
| model_str: str = config.transformer.model_str | |
| self.lm_encoder: PreTrainedModel = AutoModel.from_pretrained( | |
| pretrained_model_name_or_path=model_str, | |
| output_hidden_states=False, | |
| add_pooling_layer=False, ## Comment it out for LLAMA | |
| ) | |
| if gradient_checkpointing: | |
| self.lm_encoder.gradient_checkpointing_enable() ####### This is the line that we need to put in the code when we enable. | |
| self.tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained( | |
| pretrained_model_name_or_path=model_str, | |
| use_fast=True, | |
| clean_up_tokenization_spaces=True, | |
| ) | |
| if config.add_speaker_tokens: | |
| self.tokenizer.add_special_tokens( | |
| { | |
| "additional_special_tokens": [ | |
| config.speaker_start, | |
| config.speaker_end, | |
| ] | |
| } | |
| ) | |
| self.lm_encoder.resize_token_embeddings(len(self.tokenizer)) | |
| if not config.finetune: | |
| for param in self.lm_encoder.parameters(): | |
| # Don't update encoder params | |
| param.requires_grad = False | |
| self.hidden_size: int = self.lm_encoder.config.hidden_size | |
| def device(self) -> torch.device: | |
| return next(self.lm_encoder.parameters()).device | |
| def get_tokenizer(self) -> PreTrainedTokenizerFast: | |
| return self.tokenizer | |
| def to_add_speaker_tokens(self) -> bool: | |
| return self.add_speaker_tokens | |
| def forward(self, document: Dict): | |
| raise NotImplementedError | |