import os
from transformers import AutoModelForCausalLM, Qwen3ForCausalLM, AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
import torch
import torch.nn as nn
from warnings import warn


# Define a custom model that wraps a causal LM and adds a regression head
class CausalLMForRegression(nn.Module):
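    # Note: these two attributes mirror PreTrainedModel conventions
    # (config_class / base_model_prefix) so that code which introspects them
    # treats this wrapper like a regular transformers model.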
    config_class = Qwen3ForCausalLM.config_class
    base_model_prefix = "model"
    
    def __init__(self, base_model_name):
        super().__init__()
        # Load the causal LM with hidden states enabled
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name, 
            output_hidden_states=True
        )
        self.base_model = base_model_name 
        # Map the mean-pooled final hidden state to a single scalar score
        self.regression_head = nn.Linear(self.model.config.hidden_size, 1)

        print(f"Initializing difficulty scorer from scratch using {base_model_name} as a base!")
        self._keys_to_ignore_on_save = []

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Flatten extra dimensions if present
        if input_ids.dim() == 3:
            # e.g. from (accum_steps, batch_size, seq_length) to (accum_steps * batch_size, seq_length)
            input_ids = input_ids.view(-1, input_ids.size(-1))
        if attention_mask is not None and attention_mask.dim() == 3:
            attention_mask = attention_mask.view(-1, attention_mask.size(-1))

        outputs = self.model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]  # shape: (batch, seq_length, hidden_size)
        
        # Mean-pooling over non-padding tokens
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).expand_as(hidden_states).to(hidden_states.dtype)
            hidden_sum = torch.sum(hidden_states * mask, dim=1)
            # Clamp avoids division by zero for sequences that are entirely padding
            lengths = mask.sum(dim=1).clamp(min=1e-9)
            pooled = hidden_sum / lengths
        else:
            pooled = hidden_states.mean(dim=1)
        
        logits = self.regression_head(pooled).squeeze(-1)
        
        loss = None
        if labels is not None:
            loss_fn = nn.HuberLoss()  # less sensitive to label outliers than nn.MSELoss()
            loss = loss_fn(logits, labels)
        
        return {"loss": loss, "logits": logits}

    def get_input_embeddings(self):
        # Delegate to the underlying causal LM's get_input_embeddings method.
        return self.model.get_input_embeddings()

    def save_pretrained(self, output_dir, safe_serialization=False):
        os.makedirs(output_dir, exist_ok=True)

        # Sanity-check the state dict for empty tensors before saving
        # (numel() is safe for zero-dim tensors, unlike indexing shape[0])
        model_state_dict = self.model.state_dict()
        for key, value in model_state_dict.items():
            if value.numel() == 0:
                print(f"Warning: Tensor {key} has shape {tuple(value.shape)}, which may be problematic.")

        # Save the backbone, honouring the caller's serialization choice
        # (default is False because tied weights can complicate safetensors saving),
        # then save the regression head separately
        self.model.save_pretrained(output_dir, safe_serialization=safe_serialization)
        torch.save(self.regression_head.state_dict(), os.path.join(output_dir, "regression_head.bin"))


    def get_tokenizer(self):
        try:
            tokenizer = AutoTokenizer.from_pretrained(self.model.name_or_path)
            print(f"Loaded tokenizer from {self.model.name_or_path}")
        except Exception:
            # Fall back to the configured base model if the checkpoint ships no tokenizer files
            tokenizer = AutoTokenizer.from_pretrained(self.base_model)
            print(f"Loaded tokenizer from {self.base_model}")
        return tokenizer

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        warn(f"The `from_pretrained` method is currently only implemented for models with Qwen3-base.")
        cfg = kwargs.pop("config", None)
        if cfg is None:
            cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        cfg.output_hidden_states = True

        if "trust_remote_code" in kwargs:
            _ = kwargs.pop("trust_remote_code")

        backbone = Qwen3ForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=cfg,
            trust_remote_code=False,
            **kwargs
        )

        if os.path.isdir(pretrained_model_name_or_path):
            head_path = os.path.join(pretrained_model_name_or_path,
                                        "regression_head.bin")
        else:
            try:
                head_path = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename="regression_head.bin",
                    repo_type="model"
                )
            except EntryNotFoundError:
                # The hub repo ships no head weights; fall back to random init below
                head_path = None

        # Build the instance without running __init__ so the backbone is not loaded twice
        inst = cls.__new__(cls)
        nn.Module.__init__(inst)
        inst.model = backbone
        inst.regression_head = nn.Linear(cfg.hidden_size, 1)
        inst._keys_to_ignore_on_save = []
        inst.base_model = "Qwen/Qwen3-8B"  # tokenizer fallback; see get_tokenizer

        if head_path is not None and os.path.exists(head_path):
            inst.regression_head.load_state_dict(
                torch.load(head_path, map_location="cpu")
            )
        else:
            print("'regression_head.bin' not found – initialising randomly.")

        return inst

    @torch.no_grad()
    def generate(self, *args, **kwargs):
        """
        Wrapper that forwards all arguments to the underlying causal LM so that GenerationMixin-based helpers
        (sampling, beam search, prepare_inputs_for_generation, etc.) keep working.
        """
        return self.model.generate(*args, **kwargs)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """
        Forwarded for the same reason: it keeps the wrapper usable by code that loads it via AutoModelForCausalLM.
        """
        return self.model.prepare_inputs_for_generation(*args, **kwargs)
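

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming local or hub access to the Qwen/Qwen3-8B
# weights referenced above; the prompts and the checkpoint directory name
# ("difficulty_scorer_ckpt") are hypothetical. Shows scoring a batch of
# prompts with the regression head and round-tripping the checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    scorer = CausalLMForRegression("Qwen/Qwen3-8B")
    tokenizer = scorer.get_tokenizer()
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    prompts = ["Prove that sqrt(2) is irrational.", "What is 2 + 2?"]
    batch = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)

    scorer.eval()
    with torch.no_grad():
        out = scorer(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    print(out["logits"])  # one scalar difficulty score per prompt

    # Round-trip: save the backbone plus head, then reload via from_pretrained
    scorer.save_pretrained("difficulty_scorer_ckpt")
    reloaded = CausalLMForRegression.from_pretrained("difficulty_scorer_ckpt")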