import os
from warnings import warn

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Qwen3ForCausalLM


# Define a custom model that wraps a causal LM and adds a regression head
class CausalLMForRegression(nn.Module):
    config_class = Qwen3ForCausalLM.config_class
    base_model_prefix = "model"

    def __init__(self, base_model_name):
        super().__init__()
        # Load the causal LM with hidden states enabled
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            output_hidden_states=True,
        )
        self.base_model = base_model_name
        # Project the mean-pooled final hidden state to a single scalar score
        self.regression_head = nn.Linear(self.model.config.hidden_size, 1)
        print(f"Initializing difficulty scorer from scratch using {base_model_name} as a base!")
        self._keys_to_ignore_on_save = []

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Flatten extra leading dimensions if present,
        # e.g. from (accum_steps, batch_size, seq_length) to (accum_steps * batch_size, seq_length)
        if input_ids.dim() == 3:
            input_ids = input_ids.view(-1, input_ids.size(-1))
        if attention_mask is not None and attention_mask.dim() == 3:
            attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        outputs = self.model(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.hidden_states[-1]  # shape: (batch, seq_length, hidden_size)
        # Mean-pool over non-padding tokens
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).expand_as(hidden_states).to(hidden_states.dtype)
            hidden_sum = torch.sum(hidden_states * mask, dim=1)
            # Clamp to avoid division by zero on all-padding rows
            lengths = mask.sum(dim=1).clamp(min=1.0)
            pooled = hidden_sum / lengths
        else:
            pooled = hidden_states.mean(dim=1)
        logits = self.regression_head(pooled).squeeze(-1)
        loss = None
        if labels is not None:
            loss_fn = nn.HuberLoss()  # less outlier-sensitive than nn.MSELoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

    def get_input_embeddings(self):
        # Delegate to the underlying causal LM's get_input_embeddings method.
        return self.model.get_input_embeddings()

    def save_pretrained(self, output_dir, safe_serialization=False):
        os.makedirs(output_dir, exist_ok=True)
        # Sanity-check the state dict before saving
        model_state_dict = self.model.state_dict()
        for key, value in model_state_dict.items():
            if value.numel() == 0:
                print(f"Warning: Tensor {key} has shape {value.shape}, which may be problematic.")
        # Save the backbone with proper weight-tie handling; the regression head is stored separately
        self.model.save_pretrained(output_dir, safe_serialization=safe_serialization)
        torch.save(self.regression_head.state_dict(), os.path.join(output_dir, "regression_head.bin"))

    def get_tokenizer(self):
        # Prefer the tokenizer stored with the checkpoint; fall back to the base model's
        try:
            tokenizer = AutoTokenizer.from_pretrained(self.model.name_or_path)
            print(f"Loaded tokenizer from {self.model.name_or_path}")
        except Exception:
            tokenizer = AutoTokenizer.from_pretrained(self.base_model)
            print(f"Loaded tokenizer from {self.base_model}")
        return tokenizer

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        warn("`from_pretrained` is currently only implemented for models with a Qwen3 backbone.")
        cfg = kwargs.pop("config", None)
        if cfg is None:
            cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        cfg.output_hidden_states = True
        if "trust_remote_code" in kwargs:
            _ = kwargs.pop("trust_remote_code")
        backbone = Qwen3ForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=cfg,
            trust_remote_code=False,
            **kwargs,
        )
        # Locate the regression head: next to the checkpoint on disk, or on the Hub
        if os.path.isdir(pretrained_model_name_or_path):
            head_path = os.path.join(pretrained_model_name_or_path,
                                     "regression_head.bin")
        else:
            head_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="regression_head.bin",
                repo_type="model",
            )
        # Build the instance without calling __init__, which would re-download the backbone
        inst = cls.__new__(cls)
        nn.Module.__init__(inst)
        inst.model = backbone
        inst.regression_head = nn.Linear(cfg.hidden_size, 1)
        inst._keys_to_ignore_on_save = []
        inst.base_model = "Qwen/Qwen3-8B"  # hardcoded tokenizer fallback for the Qwen3 backbone
        if os.path.exists(head_path):
            inst.regression_head.load_state_dict(
                torch.load(head_path, map_location="cpu")
            )
        else:
            print("'regression_head.bin' not found – initialising randomly.")
        return inst

    @torch.no_grad()
    def generate(self, *args, **kwargs):
        """
        Forward all arguments to the underlying causal LM so that GenerationMixin-based helpers
        (sampling, beam search, prepare_inputs_for_generation, etc.) keep working.
        """
        return self.model.generate(*args, **kwargs)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """
        Forwarded for the same reason: it lets the model be loaded and driven via AutoModelForCausalLM.
        """
        return self.model.prepare_inputs_for_generation(*args, **kwargs)
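

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the model definition: load a trained
# scorer and predict a difficulty score for one prompt. The checkpoint path
# "./difficulty_scorer" is a hypothetical placeholder, not a real artifact.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    scorer = CausalLMForRegression.from_pretrained("./difficulty_scorer")  # placeholder path
    scorer.eval()
    tokenizer = scorer.get_tokenizer()

    batch = tokenizer(
        ["Prove that the square root of 2 is irrational."],
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        out = scorer(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    # The regression head returns one scalar per input sequence
    print(f"Predicted difficulty: {out['logits'].item():.3f}")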