import json
import logging
from dataclasses import dataclass
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from transformers import (
AutoModelForCausalLM,
AutoConfig,
PretrainedConfig,
PreTrainedModel,
EvalPrediction
)
from transformers.modeling_outputs import CausalLMOutputWithPast
def get_response_positions(
    y,
    response_token_ids: List[int]
) -> Tensor:
    """Locate the position just past every occurrence of a token sub-sequence.

    Each row of ``y`` is scanned for ``response_token_ids``. For every full
    match, the index of the first token *after* the match is recorded.

    Args:
        y: Indexable collection of token-id rows. Rows must support
            element-wise ``==`` against an int, slicing and ``.tolist()``
            (e.g. a 2-D numpy array).
        response_token_ids: The token ids marking the start of a response.

    Returns:
        A 1-D tensor holding all recorded positions, in row order. A row
        contributes zero, one or several entries, so the length is not
        necessarily the number of rows.
    """
    marker_len = len(response_token_ids)
    found = []
    for row in y:
        # Cheap pre-filter: only positions holding the marker's first token
        # can start a match.
        candidate_starts = np.where(row == response_token_ids[0])[0]
        found.extend(
            int(start) + marker_len
            for start in candidate_starts
            if row[int(start) : int(start) + marker_len].tolist() == response_token_ids
        )
    return torch.tensor(found)
# Identifier used as LMWithHeadConfig.model_type and as the key passed to
# AutoConfig.register at the bottom of this module.
MODEL_TYPE = "lm_with_head"
@dataclass
class LMWithHeadOutputWithPast(CausalLMOutputWithPast):
    """Causal-LM output extended with the classification head's results.

    Returned by ``LMWithHead.forward``; all inherited fields keep their
    ``CausalLMOutputWithPast`` meaning.
    """

    # Output of the classifier head applied to the last hidden states.
    classification_logits: Optional[torch.FloatTensor] = None
    # Loss of the classifier head; left as None when no class labels were given.
    classification_loss: Optional[torch.FloatTensor] = None
@dataclass
class LMWithHeadGenerationOutput:
    """Result bundle of ``LMWithHead.generate_with_classification``."""

    # Token ids returned by the backbone's ``generate`` call.
    sequences: torch.LongTensor
    # Classifier-head logits computed over ``sequences``.
    classification_logits: torch.FloatTensor
    # Last-layer hidden states of the backbone for ``sequences``.
    hidden_states: torch.FloatTensor
    # NOTE(review): generate_with_classification stores the raw ``generate()``
    # output here, not an LMWithHeadOutputWithPast — the annotation looks
    # inaccurate; confirm before relying on it.
    base_output: LMWithHeadOutputWithPast
class LMWithHeadConfig(PretrainedConfig):
    """Configuration for ``LMWithHead``.

    Args:
        base_model_id: Identifier of the backbone causal LM to load.
        num_labels: Output dimension of the classification head.
        classifier_dropout: Dropout probability used inside the head.
        freeze_base: Whether the backbone's parameters are frozen.
        classifier_hidden_layers: Hidden-layer widths of the head; ``None``
            or an empty list means a single linear layer.
        classifier_activation: Name of the activation used between the
            head's hidden layers.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = MODEL_TYPE

    def __init__(
        self,
        base_model_id: str = None,
        num_labels: int = 2,
        classifier_dropout: float = 0.1,
        freeze_base: bool = True,
        classifier_hidden_layers: List[int] = None,
        classifier_activation: str = "relu",
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Backbone settings.
        self.base_model_id = base_model_id
        self.freeze_base = freeze_base

        # Classification-head settings.
        self.num_labels = num_labels
        self.classifier_dropout = classifier_dropout
        self.classifier_activation = classifier_activation
        # A falsy value (None / empty list) selects the single-layer head.
        if classifier_hidden_layers:
            self.classifier_hidden_layers = classifier_hidden_layers
        else:
            self.classifier_hidden_layers = []
class ConfigurableClassifierHead(nn.Module):
    """Configurable classifier head with variable number of hidden layers and activations.

    Layout with hidden layers: ``[Linear -> activation -> Dropout] * n -> Linear``.
    Layout without hidden layers: ``Dropout -> Linear``.

    Args:
        input_dim: Size of the incoming feature dimension.
        hidden_dims: Widths of the hidden layers; ``None`` or empty for a
            single linear layer.
        output_dim: Size of the output (logit) dimension.
        dropout_rate: Dropout probability.
        activation: One of the keys of ``_ACTIVATIONS``.

    Raises:
        ValueError: If ``activation`` is not a supported name.
    """

    # Name -> module class. Classes (not instances) are stored so that every
    # hidden layer gets its own activation module and nothing is built for
    # the activations that are not selected.
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "tanh": nn.Tanh,
        "leaky_relu": nn.LeakyReLU,
        "elu": nn.ELU,
    }

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        dropout_rate: float = 0.1,
        activation: str = "relu"
    ):
        super().__init__()

        if activation not in self._ACTIVATIONS:
            raise ValueError(f"Unsupported activation: {activation}. "
                             f"Choose from: {list(self._ACTIVATIONS.keys())}")
        activation_cls = self._ACTIVATIONS[activation]

        layers = []
        current_dim = input_dim

        if hidden_dims:
            for hidden_dim in hidden_dims:
                layers.append(nn.Linear(current_dim, hidden_dim))
                # Fresh instance per layer: sharing one module object across
                # positions makes hooks / module traversal treat all hidden
                # layers' activations as a single module.
                layers.append(activation_cls())
                layers.append(nn.Dropout(dropout_rate))
                current_dim = hidden_dim
        else:
            # No hidden layers: still regularise the input to the output layer.
            layers.append(nn.Dropout(dropout_rate))

        # Output layer
        layers.append(nn.Linear(current_dim, output_dim))

        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the head to ``x`` (last dimension must be ``input_dim``)."""
        return self.classifier(x)
class LMWithHead(PreTrainedModel):
    """Causal-LM backbone plus a per-position classification head.

    The backbone named by ``config.base_model_id`` is loaded through
    ``AutoModelForCausalLM`` and optionally frozen. A
    ``ConfigurableClassifierHead`` maps the backbone's last hidden state at
    every position to ``config.num_labels`` logits.
    """

    config_class = LMWithHeadConfig

    # Pad id that generate_with_classification used to hard-code (Llama);
    # kept only as the last-resort fallback when no pad id is configured.
    _LEGACY_PAD_TOKEN_ID = 128009

    def __init__(self, config: LMWithHeadConfig):
        """Load the backbone and build the classification head.

        Raises:
            ValueError: If ``config.base_model_id`` is ``None``.
        """
        super().__init__(config)

        # Load the backbone straight from HF (or local cache)
        if config.base_model_id is None:
            raise ValueError("base_model_id must be specified in the config.")
        self.base = AutoModelForCausalLM.from_pretrained(config.base_model_id)

        if config.freeze_base:
            for p in self.base.parameters():
                p.requires_grad_(False)

        # The head consumes the backbone's hidden states.
        hid = self.base.config.hidden_size

        # With no hidden layers configured this is a single-layer classifier.
        self.classifier = ConfigurableClassifierHead(
            input_dim=hid,
            hidden_dims=config.classifier_hidden_layers,
            output_dim=config.num_labels,
            dropout_rate=config.classifier_dropout,
            activation=config.classifier_activation
        )

        self.post_init()  # initialize the new head

    def _classification_loss(self, logits_cls, class_labels, class_labels_mask):
        """Return the head's loss over the positions selected by the mask."""
        if class_labels_mask is None:
            # No explicit mask: fall back to the -100 ignore-index convention
            # instead of failing on ``None.any()``.
            mask = class_labels != -100
        else:
            # Force boolean dtype; an integer 0/1 mask would otherwise be
            # interpreted as *indices* by the advanced indexing below.
            mask = class_labels_mask.bool()

        if not mask.any():
            # No valid position in this batch: return a zero that is still
            # attached to the graph so backward() reaches the head.
            return logits_cls.sum() * 0.0

        if self.config.num_labels == 1:  # binary (BCE)
            preds = logits_cls[mask].squeeze(-1)  # (N,)
            target = class_labels[mask].float()  # (N,)
            return nn.BCEWithLogitsLoss()(preds, target)

        # multi-class (CE): preds (N, C), target (N,)
        return nn.CrossEntropyLoss()(logits_cls[mask], class_labels[mask])

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        class_labels=None,
        class_labels_mask=None,
        output_hidden_states=False,
        **kwargs
    ):
        """Run the backbone and the classification head.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Attention mask passed to the backbone.
            labels: Language-modelling labels passed to the backbone.
            class_labels: Per-position targets for the head, or ``None`` to
                skip the classification loss.
            class_labels_mask: Boolean mask of positions that contribute to
                the classification loss. When ``None``, positions whose
                label is not ``-100`` are used.
            output_hidden_states: Whether to include the backbone's hidden
                states in the returned object (they are always computed).
            **kwargs: Forwarded to the backbone.

        Returns:
            ``LMWithHeadOutputWithPast`` whose ``loss`` is the sum of the
            available LM and classification losses, or ``None`` when neither
            was computed.
        """
        out = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,  # the head always needs them
            **kwargs,
        )
        hidden_states = out.hidden_states[-1]  # (B, L, H)
        logits_cls = self.classifier(hidden_states)  # (B, L, C)

        loss_cls = None
        if class_labels is not None:
            loss_cls = self._classification_loss(
                logits_cls, class_labels, class_labels_mask
            )

        # Sum whichever losses exist. Keep ``None`` (not the int 0) when
        # there is none, so the output follows the usual "loss is None
        # without labels" convention.
        total_loss = None
        for part in (out.loss, loss_cls):
            if part is not None:
                total_loss = part if total_loss is None else total_loss + part

        return LMWithHeadOutputWithPast(
            loss=total_loss,
            logits=out.logits,
            past_key_values=out.past_key_values,
            hidden_states=out.hidden_states if output_hidden_states else None,
            attentions=out.attentions if kwargs.get("output_attentions") else None,
            classification_logits=logits_cls,
            classification_loss=loss_cls,
        )

    def save_pretrained(self, save_dir, head_only=True, **kwargs):
        """Save the model to ``save_dir``.

        Args:
            save_dir: Target directory (created if missing).
            head_only: When true, write only the config, the head's weights
                (``classifier.pt``) and ``base.json`` naming the backbone;
                ``kwargs`` are not used in this mode. Otherwise defer to the
                standard ``PreTrainedModel.save_pretrained``.
            **kwargs: Forwarded to the standard save when ``head_only`` is
                false.
        """
        os.makedirs(save_dir, exist_ok=True)
        self.config.save_pretrained(save_dir)

        if head_only:  # just the delta
            torch.save(self.classifier.state_dict(), os.path.join(save_dir, "classifier.pt"))
            # tiny helper to remember which backbone to reload
            with open(os.path.join(save_dir, "base.json"), "w", encoding="utf-8") as f:
                json.dump({"base_model_id": self.config.base_model_id}, f)
        else:  # normal full save
            super().save_pretrained(save_dir, **kwargs)

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        """Load a model saved by ``save_pretrained``.

        First tries the head-only layout (``base.json`` + ``classifier.pt``)
        from a local directory or a Hub repo; if that fails for any reason,
        falls back to the standard ``PreTrainedModel.from_pretrained``.

        Args:
            path: Local directory or Hub repo id.
            **kwargs: May contain ``config``; all of it is forwarded to the
                config loader (when no config is given) and to the standard
                loader in the fallback path.

        Raises:
            ValueError: If both loading strategies fail. The message carries
                both error texts and the exception is chained to the
                standard loader's error.
        """
        config = kwargs.get("config", None)
        if config is None:
            config = LMWithHeadConfig.from_pretrained(path, **kwargs)

        # Local checkpoint directory vs. Hub repo id.
        is_local = os.path.isdir(path)

        try:
            if is_local:
                base_json_path = os.path.join(path, "base.json")
                classifier_path = os.path.join(path, "classifier.pt")
            else:
                # Imported lazily so local-only use does not need the import.
                from huggingface_hub import hf_hub_download
                base_json_path = hf_hub_download(repo_id=path, filename="base.json")
                classifier_path = hf_hub_download(repo_id=path, filename="classifier.pt")

            with open(base_json_path, encoding="utf-8") as f:
                base_id = json.load(f)["base_model_id"]

            config.base_model_id = base_id
            model = cls(config)

            # SECURITY: the checkpoint may come from an arbitrary Hub repo.
            # weights_only=True restricts unpickling to tensors / plain
            # containers, which is all a state dict needs.
            head_sd = torch.load(classifier_path, map_location="cpu", weights_only=True)
            model.classifier.load_state_dict(head_sd, strict=True)

            return model

        except Exception as e:
            # Custom layout unavailable or broken: try the standard loader.
            # This will likely fail unless standard weight files are present.
            try:
                return super().from_pretrained(path, **kwargs)
            except Exception as inner_e:
                raise ValueError(
                    f"Could not load model from {path}. "
                    f"Custom loading failed with: {str(e)}. "
                    f"Standard loading failed with: {str(inner_e)}. "
                    f"Make sure the repository contains either 'base.json' and 'classifier.pt' files, "
                    f"or standard model weights files."
                ) from inner_e

    def generate_with_classification(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        **generate_kwargs,
    ) -> LMWithHeadGenerationOutput:
        """Generate with the backbone, then classify every position.

        Args:
            input_ids: Prompt token ids.
            attention_mask: Attention mask for the prompt.
            **generate_kwargs: Forwarded to the backbone's ``generate``.

        Returns:
            ``LMWithHeadGenerationOutput`` with the generated sequences, the
            head's logits and the last-layer hidden states for them, and the
            raw ``generate`` output.
        """
        # Step 1: generate tokens with base model
        gen_output = self.base.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict_in_generate=True,
            output_hidden_states=True,
            **generate_kwargs,
        )

        # Pad id for the re-encoding mask: explicit generate kwarg first,
        # then the backbone's generation/model config, then the legacy
        # hard-coded value so existing setups behave as before.
        pad_token_id = generate_kwargs.get("pad_token_id")
        if pad_token_id is None:
            generation_config = getattr(self.base, "generation_config", None)
            pad_token_id = getattr(generation_config, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = getattr(self.base.config, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = self._LEGACY_PAD_TOKEN_ID

        # Step 2: re-run a forward pass over the full sequences to get
        # position-aligned hidden states for the head. This is an
        # inference-only path, so the head also runs without gradient
        # tracking.
        with torch.no_grad():
            outputs = self.base(
                input_ids=gen_output.sequences,
                attention_mask=(gen_output.sequences != pad_token_id),
                output_hidden_states=True,
            )
            last_hidden = outputs.hidden_states[-1]  # (B, L, H)
            classification_logits = self.classifier(last_hidden)  # (B, L, C)

        return LMWithHeadGenerationOutput(
            sequences=gen_output.sequences,
            classification_logits=classification_logits,
            hidden_states=last_hidden,
            base_output=gen_output,
        )
def mask_range(
    tensor,
    fill_value: float,
    start_pos,
    end_pos,
):
    """Keep a per-row column range of ``tensor`` and fill everything else.

    Args:
        tensor: Tensor whose second dimension is the column axis.
        fill_value: Value written to every position outside the range.
        start_pos: 1-D tensor with one start column per row.
        end_pos: 1-D tensor with one (inclusive) end column per row, or
            ``None`` to keep only the single column at ``start_pos``.

    Returns:
        A tensor equal to ``tensor`` inside the selected columns and to
        ``fill_value`` elsewhere.
    """
    columns = torch.arange(tensor.shape[1], device=tensor.device).unsqueeze(0)
    first = start_pos.unsqueeze(1)

    if end_pos is None:
        keep = columns == first
    else:
        last = end_pos.unsqueeze(1)
        keep = (columns >= first) & (columns <= last)

    return torch.where(keep, tensor, fill_value)
class LMWithHeadComputeMetrics:
    """Callable computing token- and sequence-level classification metrics.

    Expects an ``EvalPrediction``-like object where
    ``label_ids == (class_labels, class_labels_mask)`` and
    ``predictions[1]`` holds the classification logits
    (presumably the order the model/trainer emits them — verify against the
    trainer's ``label_names`` configuration).
    """

    def __init__(self, response_idx: Union[int, List[int]]):
        """
        Args:
            response_idx (int | List[int]): The index of the response token(s) in the vocabulary,
                i.e. <|assistant|>
        """
        if isinstance(response_idx, int):
            response_idx = [response_idx]
        self.response_idx = response_idx

    def __call__(self, p: "EvalPrediction") -> Dict:
        """Return the metric dict for one evaluation run.

        Keys:
            pct_harmful: Mean predicted class over valid positions.
            pct_correct: Fraction of valid positions predicted correctly.
            pct_any_in_seq_harmful: Fraction of sequences with at least one
                non-zero prediction at a valid position.
            pct_any_in_seq_correct: Fraction of sequences where that flag
                matches the same flag computed from the labels.

        Token-level values are NaN when no position is valid.
        """
        # Convert everything to torch once; mixing numpy arrays and tensors
        # in indexing/arithmetic is fragile across versions.
        labels = torch.as_tensor(p.label_ids[0])
        # A position counts only if the mask selects it and its label is not
        # the ignore index. ``.bool()`` guards against 0/1 integer masks.
        label_mask = torch.as_tensor(p.label_ids[1]).bool() & (labels != -100)

        # The response-start positions were previously derived from
        # ``p.inputs`` but never used; dropping that lookup also removes the
        # hard requirement that inputs are included in the evaluation.
        # TODO: get standard perplexity loss

        logits = torch.as_tensor(p.predictions[1])
        if logits.shape[-1] == 1:
            # Single-logit (BCE) head: argmax over one class is always 0,
            # so threshold the logit at 0 (probability 0.5) instead.
            preds = (logits.squeeze(-1) > 0).long()
        else:
            # argmax of the logits equals argmax of their softmax.
            preds = logits.argmax(dim=-1)

        # Token-level metrics over valid positions.
        pct_harmful_all = preds[label_mask].to(torch.float64).mean().item()
        pct_correct = (preds == labels)[label_mask].to(torch.float64).mean().item()

        # Sequence-level: is any valid position flagged / labelled non-zero?
        any_pred_harmful = ((preds != 0) & label_mask).any(-1)
        any_true_harmful = ((labels != 0) & label_mask).any(-1)
        pct_any_harmful = any_pred_harmful.to(torch.float64).mean().item()
        pct_any_correct = (
            (any_pred_harmful == any_true_harmful).to(torch.float64).mean().item()
        )

        return {
            "pct_harmful": pct_harmful_all,
            "pct_correct": pct_correct,
            "pct_any_in_seq_harmful": pct_any_harmful,
            "pct_any_in_seq_correct": pct_any_correct,
        }
# Register the custom config and model with the Auto* factories under
# MODEL_TYPE, so that AutoModelForCausalLM.from_pretrained(...) can be called
# on checkpoints of this model type. Runs as an import-time side effect.
AutoConfig.register(MODEL_TYPE, LMWithHeadConfig)
AutoModelForCausalLM.register(LMWithHeadConfig, LMWithHead)
|