|
|
from transformers import PreTrainedModel
|
|
|
from sproto.model.multi_proto import MultiProtoModule
|
|
|
from .configuration_sproto import SprotoConfig
|
|
|
|
|
|
class SprotoModel(PreTrainedModel):
    """HuggingFace-compatible wrapper around :class:`MultiProtoModule`.

    Every hyperparameter is read from a :class:`SprotoConfig` and forwarded
    one-to-one to the underlying prototype module; this class adds no logic
    of its own beyond the HF ``PreTrainedModel`` plumbing.
    """

    config_class = SprotoConfig
    base_model_prefix = "sproto"

    # Config attributes forwarded verbatim (as same-named keyword arguments)
    # to the MultiProtoModule constructor.
    _MODULE_KWARGS = (
        "pretrained_model",
        "num_classes",
        "label_order_path",
        "use_sigmoid",
        "use_cuda",
        "lr_prototypes",
        "lr_features",
        "lr_others",
        "num_training_steps",
        "num_warmup_steps",
        "loss",
        "save_dir",
        "use_attention",
        "use_global_attention",
        "dot_product",
        "normalize",
        "final_layer",
        "reduce_hidden_size",
        "use_prototype_loss",
        "prototype_vector_path",
        "attention_vector_path",
        "eval_buckets",
        "seed",
        "num_prototypes_per_class",
        "batch_size",
    )

    def __init__(self, config: SprotoConfig):
        """Build the wrapped prototype module from *config* and run HF post-init.

        Args:
            config: SprotoConfig carrying every hyperparameter listed in
                ``_MODULE_KWARGS``.
        """
        super().__init__(config)
        module_kwargs = {name: getattr(config, name) for name in self._MODULE_KWARGS}
        self.module = MultiProtoModule(**module_kwargs)
        self.post_init()

    def _init_weights(self, module):
        """Initialize the weights.

        Deliberate no-op: MultiProtoModule performs its own initialization,
        and no other submodule type needs handling here.
        """
        if isinstance(module, MultiProtoModule):
            return

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        targets=None,
        tokens=None,
        sample_ids=None,
        **kwargs,
    ):
        """Pack the inputs into a batch dict, run the wrapped module, and
        return its three outputs under named keys.

        NOTE(review): the inner module is fed the key ``"attention_masks"``
        (plural) although the HF-conventional parameter is ``attention_mask``
        — presumably what MultiProtoModule expects; verify against its
        implementation.

        Returns:
            dict with keys ``"logits"``, ``"max_indices"`` and ``"metadata"``,
            holding the tuple produced by ``self.module(batch)``.
        """
        batch = dict(
            input_ids=input_ids,
            attention_masks=attention_mask,
            token_type_ids=token_type_ids,
            targets=targets,
            tokens=tokens,
            sample_ids=sample_ids,
        )

        logits, max_indices, metadata = self.module(batch)

        return {
            "logits": logits,
            "max_indices": max_indices,
            "metadata": metadata,
        }