import torch
from transformers import PreTrainedModel

from .proto import ProtoModule
from .configuration_proto import ProtoConfig
class ProtoForMultiLabelClassification(PreTrainedModel):
    """Hugging Face wrapper that exposes ProtoModule as a PreTrainedModel
    for multi-label classification, so it can be saved, loaded, and shared
    through the standard transformers API."""

    config_class = ProtoConfig

    def __init__(self, config: ProtoConfig):
        super().__init__(config)
        # Forward every hyperparameter from the config to the underlying
        # prototype module.
        self.proto_module = ProtoModule(
            pretrained_model=config.pretrained_model_name_or_path,
            num_classes=config.num_classes,
            label_order_path=config.label_order_path,
            use_sigmoid=config.use_sigmoid,
            use_cuda=config.use_cuda,
            lr_prototypes=config.lr_prototypes,
            lr_features=config.lr_features,
            lr_others=config.lr_others,
            num_training_steps=config.num_training_steps,
            num_warmup_steps=config.num_warmup_steps,
            loss=config.loss,
            save_dir=config.save_dir,
            use_attention=config.use_attention,
            dot_product=config.dot_product,
            normalize=config.normalize,
            final_layer=config.final_layer,
            reduce_hidden_size=config.reduce_hidden_size,
            use_prototype_loss=config.use_prototype_loss,
            prototype_vector_path=config.prototype_vector_path,
            attention_vector_path=config.attention_vector_path,
            eval_buckets=config.eval_buckets,
            seed=config.seed,
        )
        # Standard transformers weight-initialization hook.
        self.init_weights()
    def forward(self, input_ids, attention_mask, token_type_ids, **kwargs):
        # ProtoModule consumes a batch dict; note that it expects the plural
        # key "attention_masks" rather than the usual "attention_mask".
        batch = {
            "input_ids": input_ids,
            "attention_masks": attention_mask,
            "token_type_ids": token_type_ids,
        }
        logits, metadata = self.proto_module(batch)
        return {"logits": logits, "metadata": metadata}