|
|
| from transformers import PretrainedConfig, AutoConfig |
| from typing import List |
|
|
|
|
class BioNextTaggerConfig(PretrainedConfig):
    """Configuration object for the BioNext tagger (``model_type`` "crf-tagger").

    The constructor records the tagger-specific settings as instance
    attributes and forwards every remaining keyword argument to
    ``PretrainedConfig``.

    NOTE(review): how each setting is interpreted is decided by the code that
    consumes this config, which is not visible here. The parameter notes
    below are inferred from the names only -- confirm against the model code.
    """

    model_type = "crf-tagger"

    def __init__(
        self,
        augmentation="unk",
        context_size=64,
        percentage_tags=0.2,
        p_augmentation=0.5,
        crf_reduction="mean",
        version="0.1.2",
        **kwargs,
    ):
        """Store the tagger settings, then initialise the base config.

        Args:
            augmentation: Stored as-is; presumably the name of an
                augmentation strategy (default ``"unk"``).
            context_size: Stored as-is; presumably a context window length.
            percentage_tags: Stored as-is; presumably a fraction of tags.
            p_augmentation: Stored as-is; presumably a probability.
            crf_reduction: Stored as-is; presumably the CRF loss reduction.
            version: Version string stored on the config.
            **kwargs: Passed unchanged to ``PretrainedConfig.__init__``.
        """
        # The tagger's own attributes are assigned before the base-class
        # initialiser runs, in this fixed order.
        tagger_settings = {
            "version": version,
            "augmentation": augmentation,
            "context_size": context_size,
            "percentage_tags": percentage_tags,
            "p_augmentation": p_augmentation,
            "crf_reduction": crf_reduction,
        }
        for attr_name, attr_value in tagger_settings.items():
            setattr(self, attr_name, attr_value)

        super().__init__(**kwargs)

    def get_backbonemodel_config(self):
        """Load the backbone model's config and overlay this config's values.

        The config is loaded with ``AutoConfig.from_pretrained`` from
        ``self._name_or_path``. Every key reported by the loaded config's
        ``to_dict()`` that is also an attribute of this object is then
        overwritten on the loaded config with this object's value.

        Returns:
            The loaded backbone config after the overlay.
        """
        print(f"model_path {self._name_or_path}", flush=True)

        # NOTE(review): trust_remote_code=True lets the referenced model
        # repository supply its own config code -- confirm the source is trusted.
        backbone_config = AutoConfig.from_pretrained(
            self._name_or_path, trust_remote_code=True
        )

        # NOTE(review): the overlay is not limited to tagger settings; any key
        # present on both configs is copied (possibly generic ones such as
        # `model_type`, if the backbone reports it) -- confirm this is intended.
        shared_keys = [
            key for key in backbone_config.to_dict() if hasattr(self, key)
        ]
        for key in shared_keys:
            setattr(backbone_config, key, getattr(self, key))

        return backbone_config
| |