| """ |
| © Battelle Memorial Institute 2023 |
| Made available under the GNU General Public License v 2.0 |
| |
| BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY |
| FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN |
| OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES |
| PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED |
| OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF |
| MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS |
| TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE |
| PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, |
| REPAIR OR CORRECTION. |
| """ |
|
|
| import torch |
| from transformers import PreTrainedModel |
|
|
| from .fup_bert_config import FupBERTConfig |
| from .fup_bert_model import FupBERTModel |
|
|
|
|
class FupBERT(PreTrainedModel):
    """Hugging Face wrapper around ``FupBERTModel``.

    The wrapper builds a ``FupBERTModel`` from the fields of a
    ``FupBERTConfig`` and delegates the forward pass to it. Declaring
    ``config_class`` ties this model to ``FupBERTConfig`` for the
    ``PreTrainedModel`` machinery.
    """

    config_class = FupBERTConfig

    def __init__(self, config):
        """Construct the wrapped model from ``config``.

        Args:
            config: Configuration object exposing the attributes ``ntoken``,
                ``ninp``, ``nhead``, ``nhid``, ``nlayers``,
                ``token_reduction``, ``padding_idx``, ``cls_idx``,
                ``edge_idx``, ``num_out`` and ``dropout``. Each one is
                forwarded unchanged to the ``FupBERTModel`` constructor
                under the same keyword name.
        """
        super().__init__(config)
        self.model = FupBERTModel(
            ntoken=config.ntoken,
            ninp=config.ninp,
            nhead=config.nhead,
            nhid=config.nhid,
            nlayers=config.nlayers,
            token_reduction=config.token_reduction,
            padding_idx=config.padding_idx,
            cls_idx=config.cls_idx,
            edge_idx=config.edge_idx,
            num_out=config.num_out,
            dropout=config.dropout,
        )

    def forward(self, src):
        """Run the wrapped model on ``src`` and return its output unchanged.

        Args:
            src: Input passed straight to ``FupBERTModel``; presumably a
                tensor of token indices -- TODO confirm against
                ``FupBERTModel.forward``.

        Returns:
            Whatever ``self.model(src)`` returns.
        """
        return self.model(src)

    def load_params(self, pt_file):
        """Load weights for the wrapped model from a saved state dict.

        Args:
            pt_file: Path or file-like object accepted by ``torch.load``
                that contains a state dict for ``FupBERTModel``.

        Raises:
            RuntimeError: Propagated from ``load_state_dict`` when the
                checkpoint does not match the model.
        """
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        #
        # Deserialize onto the CPU so that a checkpoint saved from CUDA
        # tensors can still be read on a machine without a GPU.
        # load_state_dict copies the values into the existing parameters,
        # so the device the model currently lives on is left untouched.
        state_dict = torch.load(pt_file, map_location="cpu")
        self.model.load_state_dict(state_dict)
|
|