lucweber's picture
Update model.py
ce118fe verified
import os
from typing import Optional
from transformers import AutoModelForCausalLM, Qwen3ForCausalLM, AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
from warnings import warn
# Define a custom model that wraps a causal LM and adds a regression head
class CausalLMForRegression(nn.Module):
    """Causal language model backbone with a scalar regression head.

    The backbone's last-layer hidden states are mean-pooled over the
    non-padding tokens (as given by ``attention_mask``) and projected to a
    single value per sequence by ``regression_head``. Training uses a Huber
    loss.
    """

    config_class = Qwen3ForCausalLM.config_class
    base_model_prefix = "model"

    def __init__(self, base_model_name):
        """Build a fresh model: pretrained backbone plus a randomly initialised head.

        Args:
            base_model_name: Hub id or local path passed to
                ``AutoModelForCausalLM.from_pretrained``. It is also kept as the
                fallback source used by ``get_tokenizer``.
        """
        super().__init__()
        # Hidden states are required: forward() pools outputs.hidden_states[-1].
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            output_hidden_states=True
        )
        self.base_model = base_model_name
        # Projects the pooled hidden state to a single scalar.
        self.regression_head = nn.Linear(self.model.config.hidden_size, 1)
        print(f"Initializing difficulty scorer from scratch using {base_model_name} as a base!")
        self._keys_to_ignore_on_save = []

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Score each sequence with a single scalar.

        Args:
            input_ids: Token ids of shape ``(batch, seq_len)``, or
                ``(accum_steps, batch, seq_len)``, which is flattened to
                ``(accum_steps * batch, seq_len)``.
            attention_mask: Optional mask with the same shape as ``input_ids``
                (non-zero = real token). Without it, every position is averaged.
            labels: Optional regression targets, one per sequence. If they hold
                exactly one value per (flattened) sequence but have a different
                shape, e.g. ``(accum_steps, batch)`` or ``(batch, 1)``, they are
                reshaped to match ``logits``.

        Returns:
            dict with ``"logits"`` (one score per sequence) and ``"loss"``
            (Huber loss against ``labels``, or ``None`` without labels).
        """
        # Flatten a leading extra dimension if present, e.g.
        # (accum_steps, batch_size, seq_length) -> (accum_steps * batch_size, seq_length).
        # reshape (rather than view) also copes with non-contiguous inputs.
        if input_ids.dim() == 3:
            input_ids = input_ids.reshape(-1, input_ids.size(-1))
        if attention_mask is not None and attention_mask.dim() == 3:
            attention_mask = attention_mask.reshape(-1, attention_mask.size(-1))

        outputs = self.model(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.hidden_states[-1]  # (batch, seq_length, hidden_size)

        if attention_mask is not None:
            # Mean-pool over non-padding tokens only; the (batch, seq_length, 1)
            # mask broadcasts across the hidden dimension.
            mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
            hidden_sum = (hidden_states * mask).sum(dim=1)
            # Guard fully padded rows: dividing by 0 would yield NaN and poison
            # the loss of the whole batch.
            lengths = mask.sum(dim=1).clamp_min(1)
            pooled = hidden_sum / lengths
        else:
            pooled = hidden_states.mean(dim=1)

        # The head may be in a different dtype than the backbone (e.g. a bf16
        # backbone with the float32 head created in from_pretrained);
        # nn.Linear needs matching dtypes outside autocast.
        pooled = pooled.to(self.regression_head.weight.dtype)
        logits = self.regression_head(pooled).squeeze(-1)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if labels.shape != logits.shape and labels.numel() == logits.numel():
                # Without this, labels shaped e.g. (batch, 1) would silently
                # broadcast to (batch, batch), and (accum_steps, batch) labels
                # would not line up with the flattened logits.
                labels = labels.reshape(logits.shape)
            # Compute the loss in the wider of the two dtypes; mixed dtypes can
            # make the loss or its backward pass fail.
            loss_dtype = torch.promote_types(logits.dtype, labels.dtype)
            loss_fn = nn.HuberLoss()  # nn.MSELoss() is the alternative that was considered
            loss = loss_fn(logits.to(loss_dtype), labels.to(loss_dtype))
        return {"loss": loss, "logits": logits}

    def get_input_embeddings(self):
        """Return the input embedding module of the wrapped causal LM."""
        return self.model.get_input_embeddings()

    def save_pretrained(self, output_dir, safe_serialization=False):
        """Write the backbone and the regression head to ``output_dir``.

        The backbone is stored via its own ``save_pretrained``. The head's
        state dict goes to ``regression_head.bin`` in the same directory,
        which is the file ``from_pretrained`` loads. No tokenizer is saved.

        Args:
            output_dir: Target directory; created if it does not exist.
            safe_serialization: Forwarded to the backbone's ``save_pretrained``.
                The head is always written with ``torch.save``.
        """
        os.makedirs(output_dir, exist_ok=True)
        # Flag empty tensors before saving; they may indicate an incomplete
        # (e.g. not-yet-gathered) state dict. numel() also works for 0-dim
        # tensors, where value.shape[0] would raise IndexError.
        for key, value in self.model.state_dict().items():
            if value.numel() == 0:
                print(f"Warning: Tensor {key} has shape {value.shape}, which may be problematic.")
        self.model.save_pretrained(output_dir, safe_serialization=safe_serialization)
        torch.save(self.regression_head.state_dict(), os.path.join(output_dir, "regression_head.bin"))

    def get_tokenizer(self):
        """Return a tokenizer for this model.

        First tries the backbone's ``name_or_path``. If that fails (for
        example, a directory written by ``save_pretrained``, which stores no
        tokenizer files), it falls back to ``self.base_model``.
        """
        try:
            tokenizer = AutoTokenizer.from_pretrained(self.model.name_or_path)
            print(f"Loaded tokenizer from {self.model.name_or_path}")
        except Exception:
            # Best-effort fallback. Unlike a bare ``except:``, this does not
            # swallow KeyboardInterrupt / SystemExit.
            tokenizer = AutoTokenizer.from_pretrained(self.base_model)
            print(f"Loaded tokenizer from {self.base_model}")
        return tokenizer

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a saved model: a Qwen3 backbone plus ``regression_head.bin``.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id that
                holds the backbone checkpoint and ``regression_head.bin``.
            *model_args: Forwarded to ``Qwen3ForCausalLM.from_pretrained``.
            **kwargs: ``config`` (optional) replaces the config that would
                otherwise be loaded with ``AutoConfig.from_pretrained(path,
                **kwargs)``. The remaining kwargs, minus ``trust_remote_code``,
                go to ``Qwen3ForCausalLM.from_pretrained``. Hub options
                (``revision``, ``cache_dir``, ``token``, ``force_download``,
                ``local_files_only``, ``proxies``) are also used when downloading
                the head, so it matches the backbone.

        Returns:
            A ``CausalLMForRegression``. If a local directory has no
            ``regression_head.bin``, the head stays randomly initialised. The
            head is created on CPU in float32.
        """
        warn("The `from_pretrained` method is currently only implemented for models with Qwen3-base.", stacklevel=2)
        cfg = kwargs.pop("config", None)
        if cfg is None:
            cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        # forward() pools the last hidden state, so the backbone must return them.
        cfg.output_hidden_states = True
        # The backbone class ships with transformers, so remote code is never needed.
        kwargs.pop("trust_remote_code", None)
        backbone = Qwen3ForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=cfg,
            trust_remote_code=False,
            **kwargs
        )

        if os.path.isdir(pretrained_model_name_or_path):
            head_path = os.path.join(pretrained_model_name_or_path, "regression_head.bin")
        else:
            # Fetch the head from the same revision/cache/auth as the backbone.
            hub_kwargs = {
                key: kwargs[key]
                for key in ("revision", "cache_dir", "token", "force_download", "local_files_only", "proxies")
                if key in kwargs
            }
            head_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="regression_head.bin",
                repo_type="model",
                **hub_kwargs,
            )

        # Bypass __init__, which would load a second backbone from scratch.
        inst = cls.__new__(cls)
        nn.Module.__init__(inst)
        inst.model = backbone
        inst.regression_head = nn.Linear(cfg.hidden_size, 1)
        inst._keys_to_ignore_on_save = []
        # Tokenizer fallback for get_tokenizer(); hard-coded since only Qwen3 is supported.
        inst.base_model = "Qwen/Qwen3-8B"
        if os.path.exists(head_path):
            # The file holds only tensors; weights_only=True refuses arbitrary
            # pickled objects from a possibly untrusted repository.
            inst.regression_head.load_state_dict(
                torch.load(head_path, map_location="cpu", weights_only=True)
            )
        else:
            print("'regression_head.bin' not found – initialising randomly.")
        return inst

    @torch.no_grad()
    def generate(self, *args, **kwargs):
        """Delegate text generation to the wrapped causal LM.

        Forwarding all arguments keeps GenerationMixin-based helpers (sampling,
        beam search, prepare_inputs_for_generation, etc.) working on this wrapper.
        """
        return self.model.generate(*args, **kwargs)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Delegate to the wrapped causal LM.

        Forwarding this method is needed to be able to load the model with
        AutoModelForCausalLM.
        """
        return self.model.prepare_inputs_for_generation(*args, **kwargs)