File size: 4,959 Bytes
e032893
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/env python3
"""Hugging Face compatible StudentPRM model.

Provides StudentPRMConfig + StudentPRM so that after training you can
push to the hub and later load with:

    from transformers import AutoTokenizer, AutoModel
    tok = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

The 2-logit PRM head is included. The model pools the final hidden state
at the last occurrence of the configured pool token (e.g. </think>)."""
from typing import Optional
import torch
import torch.nn as nn
from transformers import (
    AutoModel,
    PreTrainedModel,
    PretrainedConfig,
)
from transformers.modeling_outputs import SequenceClassifierOutput


def last_token_index(input_ids: torch.Tensor, token_id: int) -> torch.Tensor:
    """Return, for every row, the index of the last position equal to ``token_id``.

    The equality mask is reversed along dim 1 so that ``argmax`` lands on the
    match nearest the end of the row; that offset-from-the-end is then mapped
    back to a left-to-right index.

    Args:
        input_ids: 2-D tensor of token ids (indexed as ``[row, position]``).
        token_id: The id to search for.

    Returns:
        A 1-D integer tensor with one index per row. For a row with no match
        the mask is all zeros, so ``argmax`` reports offset 0 and the result
        is the final position of the row.
    """
    seq_len = input_ids.shape[1]
    hits = input_ids.eq(token_id).to(torch.int32)
    # Offset of the last match, counted backwards from the end of each row.
    offset_from_end = hits.flip(1).argmax(dim=1)
    return seq_len - 1 - offset_from_end


class StudentPRMConfig(PretrainedConfig):
    """Configuration object for ``StudentPRM``.

    Records which base model to wrap, the token whose hidden state is pooled,
    the sizes needed to build the classification head and to resize the
    embeddings, and the ``auto_map`` / ``architectures`` entries that point
    the ``Auto*`` classes at ``modeling_student_prm``.
    """

    model_type = "student_prm"

    def __init__(
        self,
        base_model_name: str,
        pool_token: str,
        hidden_size: int,
        vocab_size: int,
        num_labels: int = 2,
        **kwargs,
    ):
        """Store the model-specific fields on top of the generic config kwargs.

        Args:
            base_model_name: Identifier of the wrapped base model.
            pool_token: Token string whose last occurrence is pooled.
            hidden_size: Width of the hidden state fed to the head.
            vocab_size: Expected size of the input embedding table.
            num_labels: Number of output logits (defaults to 2).
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)

        model_fields = (
            ("base_model_name", base_model_name),
            ("pool_token", pool_token),
            ("hidden_size", hidden_size),
            ("vocab_size", vocab_size),
            ("num_labels", num_labels),
        )
        for field_name, field_value in model_fields:
            setattr(self, field_name, field_value)

        # Remote-code wiring used when loading through the Auto* classes.
        self.auto_map = {
            "AutoModel": "modeling_student_prm.StudentPRM",
            "AutoConfig": "modeling_student_prm.StudentPRMConfig",
        }
        self.architectures = ["StudentPRM"]

        # This class has required constructor arguments, so the default-diff
        # logic must not try to build a bare instance of it.
        self.has_no_defaults_at_init = True

    def _get_non_default_generation_parameters(self):
        """Report no generation parameters.

        This is a classification model, and returning an empty dict keeps
        transformers from instantiating the class without its required args.
        """
        return dict()

    def to_diff_dict(self):
        """Return the full ``to_dict()`` output instead of a diff against defaults.

        Overridden to bypass the default-instantiation logic.
        """
        full_dump = self.to_dict()
        return full_dump


class StudentPRM(PreTrainedModel):
    """Base model plus a linear head that scores the pooled pool-token state.

    ``forward`` runs the wrapped base model, takes the final hidden state at
    the position returned by ``last_token_index`` for the pool token id, and
    feeds it to ``head`` to produce ``config.num_labels`` logits.

    The pool token id is resolved from a tokenizer when one is supplied and is
    then mirrored to ``config.pool_token_id``. When no tokenizer is attached
    (e.g. the model was built from a config alone), that stored id is used
    instead, so the model does not depend on a tokenizer at inference time.
    """

    config_class = StudentPRMConfig

    def __init__(self, config: StudentPRMConfig, base: Optional[PreTrainedModel] = None, tokenizer=None):
        """Build the wrapper.

        Args:
            config: Model configuration.
            base: Already-instantiated base model; loaded from
                ``config.base_model_name`` when omitted.
            tokenizer: Optional tokenizer, used only to resolve the pool
                token id.
        """
        super().__init__(config)
        if base is None:
            # NOTE(review): trust_remote_code=True runs code shipped with the
            # base checkpoint; base_model_name should only name trusted repos.
            base = AutoModel.from_pretrained(config.base_model_name, trust_remote_code=True)

        # Resize embeddings if vocab changed due to added special tokens.
        # Still best-effort (some base models may not expose resizable input
        # embeddings), but a failure is reported instead of being swallowed.
        try:
            current_vocab = base.get_input_embeddings().weight.shape[0]
            if current_vocab != config.vocab_size:
                base.resize_token_embeddings(config.vocab_size)
        except Exception as err:
            import warnings

            warnings.warn(
                "StudentPRM: could not align base embeddings with "
                f"config.vocab_size={config.vocab_size}: {err!r}",
                RuntimeWarning,
            )

        self.base = base
        self.head = nn.Linear(config.hidden_size, config.num_labels)
        self.tokenizer = tokenizer  # optional, only needed for pool id resolution when provided
        # May still be None here (tokenizer given but the special token not
        # added yet); _resolve_pool_id retries on the first forward call.
        self.pool_id: Optional[int] = self._lookup_pool_id()
        self.post_init()

    def _lookup_pool_id(self) -> Optional[int]:
        """Return the pool token id, or ``None`` if it cannot be determined yet.

        With a tokenizer attached, the token must be present in its vocab;
        this avoids silently accepting an unknown-token id from
        ``convert_tokens_to_ids``. A successful lookup is mirrored to
        ``config.pool_token_id`` so it is saved alongside the other config
        fields (relies on PretrainedConfig round-tripping extra attributes).
        Without a tokenizer, the id previously stored on the config is used.
        """
        if self.tokenizer is None:
            return getattr(self.config, "pool_token_id", None)
        if self.config.pool_token not in self.tokenizer.get_vocab():
            return None
        pool_id = int(self.tokenizer.convert_tokens_to_ids(self.config.pool_token))
        self.config.pool_token_id = pool_id
        return pool_id

    def _resolve_pool_id(self, input_ids: torch.Tensor):
        """Return the pool token id, resolving it lazily if needed.

        Args:
            input_ids: Unused; kept so existing call sites keep working.

        Raises:
            ValueError: If the id is unknown and either no tokenizer is
                attached or the attached tokenizer lacks the pool token.
        """
        if self.pool_id is None:
            self.pool_id = self._lookup_pool_id()
        if self.pool_id is None:
            if self.tokenizer is not None:
                raise ValueError(
                    f"pool token {self.config.pool_token!r} is not in the tokenizer vocabulary."
                )
            raise ValueError("pool_id not set and tokenizer unavailable to resolve it.")
        return self.pool_id

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """Compute PRM logits for a batch.

        Args:
            input_ids: Token ids, indexed as ``[row, position]``.
            attention_mask: Optional mask forwarded to the base model.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            ``SequenceClassifierOutput`` carrying only ``logits``.

        Raises:
            ValueError: If the pool token id cannot be resolved.
        """
        pool_id = self._resolve_pool_id(input_ids)
        out = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )
        hidden = out.hidden_states[-1]
        # NOTE(review): rows that do not contain the pool token appear to fall
        # back to the final sequence position (see last_token_index), which
        # may be padding; confirm inputs always include the pool token.
        pos = last_token_index(input_ids, pool_id)
        batch_idx = torch.arange(input_ids.shape[0], device=input_ids.device)
        pooled = hidden[batch_idx, pos]
        # Match head weight dtype for safety (bfloat16 training etc.)
        if pooled.dtype != self.head.weight.dtype:
            pooled = pooled.to(self.head.weight.dtype)
        logits = self.head(pooled)
        return SequenceClassifierOutput(logits=logits)

    def save_pretrained(self, save_directory: str, *args, **kwargs):
        """Save the model, first ensuring the config carries the Auto* wiring.

        ``auto_map`` and ``architectures`` are filled in only when missing or
        empty; everything else is delegated to the parent implementation.
        """
        if not getattr(self.config, "auto_map", None):
            self.config.auto_map = {"AutoModel": "modeling_student_prm.StudentPRM", "AutoConfig": "modeling_student_prm.StudentPRMConfig"}
        if not getattr(self.config, "architectures", None):
            self.config.architectures = ["StudentPRM"]
        super().save_pretrained(save_directory, *args, **kwargs)