from transformers import PreTrainedModel

from sproto.model.multi_proto import MultiProtoModule

from .configuration_sproto import SprotoConfig


class SprotoModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around :class:`MultiProtoModule`.

    Every constructor option of the underlying module is mirrored one-to-one
    by an attribute on :class:`SprotoConfig`, so construction is pure
    config-to-kwarg plumbing.
    """

    config_class = SprotoConfig
    base_model_prefix = "sproto"

    # Config attributes forwarded verbatim to MultiProtoModule(**kwargs).
    # Each name here is both the config attribute and the module kwarg.
    _MODULE_FIELDS = (
        "pretrained_model",
        "num_classes",
        "label_order_path",
        "use_sigmoid",
        "use_cuda",
        "lr_prototypes",
        "lr_features",
        "lr_others",
        "num_training_steps",
        "num_warmup_steps",
        "loss",
        "save_dir",
        "use_attention",
        "use_global_attention",
        "dot_product",
        "normalize",
        "final_layer",
        "reduce_hidden_size",
        "use_prototype_loss",
        "prototype_vector_path",
        "attention_vector_path",
        "eval_buckets",
        "seed",
        "num_prototypes_per_class",
        "batch_size",
    )

    def __init__(self, config: SprotoConfig):
        super().__init__(config)
        # Pull every mirrored option off the config in one pass; an
        # AttributeError here surfaces exactly as it would with explicit
        # `config.<name>` accesses.
        module_kwargs = {name: getattr(config, name) for name in self._MODULE_FIELDS}
        self.module = MultiProtoModule(**module_kwargs)
        # Standard HF hook: initialize weights and apply final processing.
        self.post_init()

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, MultiProtoModule):
            # MultiProtoModule handles its own initialization or is loaded
            # from a checkpoint, so there is nothing to do for it here.
            return
        # No standard layers are defined directly on SprotoModel, so no
        # other initialization is required.

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        targets=None,
        tokens=None,
        sample_ids=None,
        **kwargs,
    ):
        """Repackage the HF-style arguments into the batch dict the module
        expects, run it, and return the results as a plain dict.

        Note: the module consumes the key ``"attention_masks"`` (plural)
        even though the HF-facing parameter is ``attention_mask``.
        """
        module_batch = {
            "input_ids": input_ids,
            "attention_masks": attention_mask,
            "token_type_ids": token_type_ids,
            "targets": targets,
            "tokens": tokens,
            "sample_ids": sample_ids,
        }
        logits, max_indices, metadata = self.module(module_batch)
        return {
            "logits": logits,
            "max_indices": max_indices,
            "metadata": metadata,
        }