# File size: 2,750 Bytes
# d728ce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from transformers import PreTrainedModel
from sproto.model.multi_proto import MultiProtoModule
from .configuration_sproto import SprotoConfig

class SprotoModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around :class:`MultiProtoModule`.

    The wrapper exists so the prototype model plugs into the ``transformers``
    ecosystem (config-driven construction, ``post_init`` weight handling)
    while all of the actual modelling lives in ``self.module``.
    """

    config_class = SprotoConfig
    base_model_prefix = "sproto"

    # Config attributes forwarded one-to-one as keyword arguments to
    # MultiProtoModule. Kept as a single tuple so the passthrough stays
    # easy to audit against the module's constructor signature.
    _MODULE_CONFIG_FIELDS = (
        "pretrained_model",
        "num_classes",
        "label_order_path",
        "use_sigmoid",
        "use_cuda",
        "lr_prototypes",
        "lr_features",
        "lr_others",
        "num_training_steps",
        "num_warmup_steps",
        "loss",
        "save_dir",
        "use_attention",
        "use_global_attention",
        "dot_product",
        "normalize",
        "final_layer",
        "reduce_hidden_size",
        "use_prototype_loss",
        "prototype_vector_path",
        "attention_vector_path",
        "eval_buckets",
        "seed",
        "num_prototypes_per_class",
        "batch_size",
    )

    def __init__(self, config: SprotoConfig):
        """Build the wrapped MultiProtoModule from *config* fields.

        Args:
            config: SprotoConfig carrying every MultiProtoModule hyperparameter.
        """
        super().__init__(config)

        kwargs = {name: getattr(config, name) for name in self._MODULE_CONFIG_FIELDS}
        self.module = MultiProtoModule(**kwargs)

        # Standard transformers hook: initialize weights / final processing.
        self.post_init()

    def _init_weights(self, module):
        """Weight-initialization hook required by PreTrainedModel.

        MultiProtoModule handles its own initialization (or is loaded from a
        checkpoint), and SprotoModel defines no bare standard layers of its
        own, so this is intentionally a no-op for every module type.
        """
        if isinstance(module, MultiProtoModule):
            return
        # Add initializations here if standard layers are ever added directly
        # to SprotoModel.

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        targets=None,
        tokens=None,
        sample_ids=None,
        **kwargs,
    ):
        """Repackage HF-style keyword inputs into the batch dict the wrapped
        module expects, run it, and return its outputs as a dict.

        Returns:
            dict with keys ``logits``, ``max_indices``, ``metadata`` — the
            three values produced by ``self.module``.
        """
        batch = dict(
            input_ids=input_ids,
            # NOTE: the module expects the plural key "attention_masks".
            attention_masks=attention_mask,
            token_type_ids=token_type_ids,
            targets=targets,
            tokens=tokens,
            sample_ids=sample_ids,
        )

        logits, max_indices, metadata = self.module(batch)

        return dict(
            logits=logits,
            max_indices=max_indices,
            metadata=metadata,
        )