| """ | |
| © Battelle Memorial Institute 2023 | |
| Made available under the GNU General Public License v 2.0 | |
| BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY | |
| FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN | |
| OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES | |
| PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED | |
| OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF | |
| MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS | |
| TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE | |
| PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, | |
| REPAIR OR CORRECTION. | |
| """ | |
| import torch | |
| from transformers import PreTrainedModel | |
| from .fup_bert_config import FupBERTConfig | |
| from .fup_bert_model import FupBERTModel | |
class FupBERT(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``FupBERTModel``.

    The wrapper reads every hyper-parameter of the inner ``FupBERTModel`` from
    a ``FupBERTConfig`` instance. It registers that class as ``config_class``
    so the Hugging Face ``PreTrainedModel`` machinery knows which config type
    belongs to this model.
    """

    # Tells the Hugging Face base class which config class this model uses.
    config_class = FupBERTConfig

    def __init__(self, config):
        """Build the inner ``FupBERTModel`` from ``config``.

        Args:
            config: A ``FupBERTConfig``. The following attributes are read
                and forwarded unchanged to ``FupBERTModel``: ``ntoken``,
                ``ninp``, ``nhead``, ``nhid``, ``nlayers``,
                ``token_reduction``, ``padding_idx``, ``cls_idx``,
                ``edge_idx``, ``num_out`` and ``dropout``.
        """
        super().__init__(config)
        self.model = FupBERTModel(
            ntoken=config.ntoken,
            ninp=config.ninp,
            nhead=config.nhead,
            nhid=config.nhid,
            nlayers=config.nlayers,
            token_reduction=config.token_reduction,
            padding_idx=config.padding_idx,
            cls_idx=config.cls_idx,
            edge_idx=config.edge_idx,
            num_out=config.num_out,
            dropout=config.dropout,
        )

    def forward(self, src):
        """Run the wrapped ``FupBERTModel`` on ``src``.

        Args:
            src: Input passed straight through to ``FupBERTModel``.
                NOTE(review): presumably a tensor of token ids indexed
                against ``config.ntoken``; confirm against FupBERTModel.

        Returns:
            Whatever ``FupBERTModel.forward`` returns, unchanged.
        """
        return self.model(src)

    def load_params(self, pt_file, map_location=None, strict=True):
        """Load a raw ``FupBERTModel`` state dict from ``pt_file``.

        The weights are loaded into the inner ``self.model``. The checkpoint
        keys must therefore match ``FupBERTModel``'s own parameter names,
        without the ``model.`` prefix that this wrapper's attribute adds.

        Args:
            pt_file: Path or file-like object accepted by ``torch.load``.
            map_location: Forwarded to ``torch.load``. Pass, for example,
                ``"cpu"`` to load a checkpoint saved on GPU on a CPU-only
                machine. ``None`` (the default) keeps ``torch.load``'s
                standard behaviour, matching the previous implementation.
            strict: Forwarded to ``load_state_dict``. When ``True`` (the
                default), missing or unexpected keys raise ``RuntimeError``.

        Returns:
            The ``(missing_keys, unexpected_keys)`` named tuple returned by
            ``torch.nn.Module.load_state_dict``.
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True where the installed torch
        # version supports it.
        state_dict = torch.load(pt_file, map_location=map_location)
        return self.model.load_state_dict(state_dict, strict=strict)