# llama3-2_classifier / modeling_lm_classifier.py
# (Hugging Face Hub header: uploaded by busycalibrating, revision e96e9e8, verified)
import json
import logging
from dataclasses import dataclass
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from transformers import (
AutoModelForCausalLM,
AutoConfig,
PretrainedConfig,
PreTrainedModel,
EvalPrediction
)
from transformers.modeling_outputs import CausalLMOutputWithPast
def get_response_positions(
    y,
    response_token_ids: List[int]
) -> Tensor:
    """Locate the first token index *after* each response-template match.

    Scans every sequence in ``y`` for full occurrences of
    ``response_token_ids`` (e.g. the tokenized ``<|assistant|>`` marker) and
    records the position immediately following each complete match.

    Args:
        y: Batch of token-id sequences; each row must support numpy-style
            equality comparison, slicing, and ``.tolist()`` (numpy array or
            torch tensor).
        response_token_ids: Token ids that mark the start of a response.

    Returns:
        1-D ``Tensor`` of positions, flattened across the batch — callers
        must not assume exactly one match per row.
    """
    template_len = len(response_token_ids)
    positions: List[int] = []
    for row in y:
        # Candidate starts: every index where the first template token occurs.
        for start in np.where(row == response_token_ids[0])[0]:
            start = int(start)
            # Confirm the full template matches at this candidate position.
            if row[start : start + template_len].tolist() == response_token_ids:
                positions.append(start + template_len)
    return torch.tensor(positions)
# Registry key that ties this architecture into transformers' Auto* classes
# (used as LMWithHeadConfig.model_type and in the register() calls below).
MODEL_TYPE = "lm_with_head"
@dataclass
class LMWithHeadOutputWithPast(CausalLMOutputWithPast):
    """Causal-LM output extended with per-token classification results."""
    # (B, L, num_labels) logits from the classifier head; None if not computed.
    classification_logits: Optional[torch.FloatTensor] = None
    # Scalar classification loss; None when no class labels were provided.
    classification_loss: Optional[torch.FloatTensor] = None
@dataclass
class LMWithHeadGenerationOutput:
    """Bundle returned by LMWithHead.generate_with_classification."""
    # Generated token ids, (B, L_total) including the prompt.
    sequences: torch.LongTensor
    # (B, L_total, num_labels) classifier logits for every position.
    classification_logits: torch.FloatTensor
    # Last-layer hidden states from the re-run forward pass, (B, L_total, H).
    hidden_states: torch.FloatTensor
    # NOTE(review): annotated as LMWithHeadOutputWithPast, but
    # generate_with_classification stores the raw generate() output here —
    # confirm the intended type.
    base_output: LMWithHeadOutputWithPast
class LMWithHeadConfig(PretrainedConfig):
    """Configuration for LMWithHead.

    Stores the backbone model id plus the classifier-head hyperparameters;
    ``model_type`` ties it into the transformers Auto* registries.
    """

    model_type = MODEL_TYPE

    def __init__(
        self,
        base_model_id: Optional[str] = None,
        num_labels: int = 2,
        classifier_dropout: float = 0.1,
        freeze_base: bool = True,
        # Configurable head parameters
        classifier_hidden_layers: Optional[List[int]] = None,  # hidden dims; None/empty -> single linear layer
        classifier_activation: str = "relu",  # activation function name
        **kwargs,
    ):
        """
        Args:
            base_model_id: HF model id or local path of the causal-LM backbone.
            num_labels: Number of classification labels (1 selects binary/BCE).
            classifier_dropout: Dropout rate inside the classifier head.
            freeze_base: If True, backbone parameters are frozen at load time.
            classifier_hidden_layers: Hidden-layer sizes for the head.
            classifier_activation: Name of the head's activation function.
        """
        super().__init__(**kwargs)
        self.base_model_id = base_model_id
        self.num_labels = num_labels
        self.classifier_dropout = classifier_dropout
        self.freeze_base = freeze_base
        # Default to an empty list (single-layer classifier) when unset.
        self.classifier_hidden_layers = classifier_hidden_layers or []
        self.classifier_activation = classifier_activation
class ConfigurableClassifierHead(nn.Module):
    """Configurable classifier head with variable hidden layers and activations.

    Builds a ``Linear -> activation -> Dropout`` stack per entry of
    ``hidden_dims``, followed by a final ``Linear`` projection. With no hidden
    dims it reduces to ``Dropout -> Linear``.
    """

    # Activation name -> module class. Classes (not instances) so each layer
    # gets its own freshly-instantiated submodule instead of the original
    # code's single shared instance reused across every layer.
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "tanh": nn.Tanh,
        "leaky_relu": nn.LeakyReLU,
        "elu": nn.ELU,
    }

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        dropout_rate: float = 0.1,
        activation: str = "relu"
    ):
        """
        Args:
            input_dim: Size of the incoming feature dimension.
            hidden_dims: Hidden-layer widths; empty -> single linear layer.
            output_dim: Number of output logits.
            dropout_rate: Dropout applied after each activation (or before the
                output layer when there are no hidden layers).
            activation: Key into the supported activation map.

        Raises:
            ValueError: If ``activation`` is not a supported name.
        """
        super().__init__()
        if activation not in self._ACTIVATIONS:
            raise ValueError(f"Unsupported activation: {activation}. "
                             f"Choose from: {list(self._ACTIVATIONS.keys())}")
        activation_cls = self._ACTIVATIONS[activation]

        layers: List[nn.Module] = []
        current_dim = input_dim
        if hidden_dims:
            for hidden_dim in hidden_dims:
                layers.append(nn.Linear(current_dim, hidden_dim))
                layers.append(activation_cls())  # fresh instance per layer
                layers.append(nn.Dropout(dropout_rate))
                current_dim = hidden_dim
        else:
            # No hidden dims: still regularize the input to the output layer.
            layers.append(nn.Dropout(dropout_rate))
        # Output projection.
        layers.append(nn.Linear(current_dim, output_dim))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Map features (..., input_dim) to logits (..., output_dim)."""
        return self.classifier(x)
class LMWithHead(PreTrainedModel):
    """Causal LM with a per-token classification head on the final hidden states.

    Wraps a pretrained backbone (loaded from ``config.base_model_id``) and adds
    a ``ConfigurableClassifierHead`` that maps each position's last-layer
    hidden state to ``config.num_labels`` logits. The backbone is optionally
    frozen so only the head trains.
    """

    config_class = LMWithHeadConfig

    def __init__(self, config: LMWithHeadConfig):
        """Build the model from its config.

        Raises:
            ValueError: If ``config.base_model_id`` is unset.
        """
        super().__init__(config)
        # Load the backbone straight from HF (or local cache).
        if config.base_model_id is None:
            raise ValueError("base_model_id must be specified in the config.")
        self.base = AutoModelForCausalLM.from_pretrained(config.base_model_id)
        if config.freeze_base:
            # Freeze backbone weights; only the classifier head stays trainable.
            for p in self.base.parameters():
                p.requires_grad_(False)
        # Hidden size of the backbone drives the head's input dimension.
        hid = self.base.config.hidden_size
        # With no hidden layers configured this is a single linear projection.
        self.classifier = ConfigurableClassifierHead(
            input_dim=hid,
            hidden_dims=config.classifier_hidden_layers,
            output_dim=config.num_labels,
            dropout_rate=config.classifier_dropout,
            activation=config.classifier_activation,
        )
        self.post_init()  # initialize the new head

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        class_labels=None,
        class_labels_mask=None,
        output_hidden_states=False,
        **kwargs
    ):
        """Run the backbone plus classifier head.

        Args:
            input_ids: (B, L) token ids.
            attention_mask: Optional (B, L) attention mask for the backbone.
            labels: Optional LM labels; enables the backbone's LM loss.
            class_labels: Optional (B, L) per-token classification targets.
            class_labels_mask: Boolean (B, L) mask selecting which positions
                contribute to the classification loss. Required whenever
                ``class_labels`` is given.
            output_hidden_states: Whether to return backbone hidden states.

        Returns:
            LMWithHeadOutputWithPast whose ``loss`` is the sum of the LM loss
            (if any) and the classification loss (if any).

        Raises:
            ValueError: If ``class_labels`` is given without
                ``class_labels_mask``.
        """
        out = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,  # always needed for the classifier head
            **kwargs,
        )
        hidden_states = out.hidden_states[-1]        # (B, L, H)
        logits_cls = self.classifier(hidden_states)  # (B, L, C)

        loss_cls = None
        if class_labels is not None:
            if class_labels_mask is None:
                # Fail loudly instead of the opaque AttributeError the previous
                # code raised on `None.any()`.
                raise ValueError(
                    "class_labels_mask must be provided when class_labels is given."
                )
            mask = class_labels_mask  # boolean mask of shape (B, L)
            if mask.any():  # skip batches with no valid tokens
                if self.config.num_labels == 1:
                    # Binary classification: single logit per token (BCE).
                    preds = logits_cls[mask].squeeze(-1)  # (N,)
                    target = class_labels[mask].float()   # (N,)
                    loss_fct = nn.BCEWithLogitsLoss()
                    loss_cls = loss_fct(preds, target)
                else:
                    # Multi-class classification (CE).
                    preds = logits_cls[mask]              # (N, C)
                    target = class_labels[mask]           # (N,)
                    loss_fct = nn.CrossEntropyLoss()
                    loss_cls = loss_fct(preds, target)
            else:
                # Zero loss that still participates in the graph so
                # back-propagation does not break on empty batches.
                loss_cls = torch.tensor(0.0, device=logits_cls.device, requires_grad=True)

        # Combine the LM and classification losses; either may be absent.
        total_loss = 0
        if out.loss is not None:
            total_loss += out.loss
        if loss_cls is not None:
            total_loss += loss_cls

        return LMWithHeadOutputWithPast(
            loss=total_loss,
            logits=out.logits,
            past_key_values=out.past_key_values,
            hidden_states=out.hidden_states if output_hidden_states else None,
            attentions=out.attentions if kwargs.get("output_attentions") else None,
            classification_logits=logits_cls,
            classification_loss=loss_cls,
        )

    def save_pretrained(self, save_dir, head_only=True, **kwargs):
        """Save the model.

        With ``head_only=True`` only the classifier weights (``classifier.pt``)
        and a pointer to the backbone (``base.json``) are written — the frozen
        backbone is reloadable from the Hub. Otherwise performs the standard
        full transformers save.
        """
        os.makedirs(save_dir, exist_ok=True)
        self.config.save_pretrained(save_dir)
        if head_only:  # just the delta
            torch.save(self.classifier.state_dict(), os.path.join(save_dir, "classifier.pt"))
            # Tiny helper to remember which backbone to reload.
            with open(os.path.join(save_dir, "base.json"), "w") as f:
                json.dump({"base_model_id": self.config.base_model_id}, f)
        else:  # normal full save
            super().save_pretrained(save_dir, **kwargs)

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        """Load from the custom head-only layout, falling back to standard loading.

        Looks for ``base.json`` + ``classifier.pt`` (locally or on the Hub);
        if that fails, delegates to the normal transformers loader and raises
        a combined error message when both approaches fail.
        """
        # Get config first.
        config = kwargs.get("config", None)
        if config is None:
            config = LMWithHeadConfig.from_pretrained(path, **kwargs)
        # Check if we're loading from a local path or a Hub repo.
        is_local = os.path.isdir(path)
        # Try to load the custom checkpoint structure.
        try:
            if is_local:
                base_json_path = os.path.join(path, "base.json")
                classifier_path = os.path.join(path, "classifier.pt")
            else:
                # Hub repo: fetch the two custom files.
                from huggingface_hub import hf_hub_download
                base_json_path = hf_hub_download(repo_id=path, filename="base.json")
                classifier_path = hf_hub_download(repo_id=path, filename="classifier.pt")
            # Load base model ID from base.json and propagate it to the config.
            with open(base_json_path) as f:
                base_id = json.load(f)["base_model_id"]
            config.base_model_id = base_id
            model = cls(config)
            # weights_only=True: never execute arbitrary pickle code from a
            # (possibly hub-downloaded) checkpoint file.
            head_sd = torch.load(classifier_path, map_location="cpu", weights_only=True)
            model.classifier.load_state_dict(head_sd, strict=True)
            return model
        except Exception as e:
            # The original `(FileNotFoundError, OSError, Exception)` tuple was
            # redundant — Exception already covers the other two.
            # If custom loading fails, try the standard approach; this will
            # likely fail unless standard weight files are present.
            try:
                return super().from_pretrained(path, **kwargs)
            except Exception as inner_e:
                # Both approaches failed: give a helpful combined message.
                raise ValueError(
                    f"Could not load model from {path}. "
                    f"Custom loading failed with: {str(e)}. "
                    f"Standard loading failed with: {str(inner_e)}. "
                    f"Make sure the repository contains either 'base.json' and 'classifier.pt' files, "
                    f"or standard model weights files."
                ) from inner_e

    def generate_with_classification(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        **generate_kwargs,
    ) -> "LMWithHeadGenerationOutput":
        """Generate with the base model, then classify every token.

        ``generate()`` does not return hidden states for all positions, so the
        full sequences are re-run through the backbone to obtain last-layer
        states for the classifier head.

        Returns:
            LMWithHeadGenerationOutput (the previous ``Dict[str, torch.Tensor]``
            annotation was incorrect — a dataclass has always been returned).
        """
        # Step 1: generate tokens with the base model.
        gen_output = self.base.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict_in_generate=True,
            output_hidden_states=True,  # ensure we can get states later
            **generate_kwargs,
        )
        # Step 2: re-run a forward pass to get hidden states for classification.
        # Mask out padding using the backbone's declared pad token; fall back to
        # 128009 (the Llama-3 <|eot_id|> value previously hardcoded) when the
        # config declares none, preserving the old behavior for Llama.
        pad_token_id = getattr(self.base.config, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = 128009
        with torch.no_grad():
            outputs = self.base(
                input_ids=gen_output.sequences,
                attention_mask=(gen_output.sequences != pad_token_id),
                output_hidden_states=True,
            )
        last_hidden = outputs.hidden_states[-1]               # (B, L, H)
        classification_logits = self.classifier(last_hidden)  # (B, L, C)
        return LMWithHeadGenerationOutput(
            sequences=gen_output.sequences,
            classification_logits=classification_logits,
            hidden_states=last_hidden,
            base_output=gen_output,
        )
def mask_range(
    tensor,
    fill_value: float,
    start_pos,
    end_pos,
):
    """Keep entries whose column index lies in [start_pos, end_pos]; fill the rest.

    When ``end_pos`` is None, only the single column at ``start_pos`` is kept
    in each row.

    Args:
        tensor: (B, L) tensor.
        fill_value: Scalar written at every position outside the range.
        start_pos: (B,) per-row start indices (inclusive).
        end_pos: (B,) per-row end indices (inclusive), or None.

    Returns:
        Tensor of the same shape with out-of-range entries replaced.
    """
    # (1, L) column indices, broadcast against the (B, 1) position tensors.
    col_idx = torch.arange(tensor.shape[1], device=tensor.device).unsqueeze(0)
    if end_pos is None:
        keep = col_idx == start_pos.unsqueeze(1)
    else:
        keep = (col_idx >= start_pos.unsqueeze(1)) & (col_idx <= end_pos.unsqueeze(1))
    return torch.where(keep, tensor, fill_value)
class LMWithHeadComputeMetrics:
    """``compute_metrics`` callable for a HF Trainer evaluating the classification head."""
    def __init__(self, response_idx: int | List[int]):
        """
        Args:
            response_idx (int | List[int]): The index of the response token(s) in the vocabulary,
            i.e. <|assistant|>
        """
        if isinstance(response_idx, int):
            response_idx = [response_idx]
        self.response_idx = response_idx
    def __call__(self, p: EvalPrediction) -> Dict:
        """Compute harmfulness-classification metrics from an EvalPrediction.

        NOTE(review) — assumed layout, confirm against the Trainer setup:
        ``p.inputs`` are input token ids; ``p.label_ids[0]`` are per-token
        class labels (-100 = ignore); ``p.label_ids[1]`` is a per-token
        validity mask; ``p.predictions[1]`` are classification logits (B, L, C).
        """
        metrics = {}
        # Positions immediately after each response-template match. Currently
        # only used by the commented-out sanity check below.
        response_start_idx = get_response_positions(p.inputs, self.response_idx)
        # Valid positions: masked-in AND not the ignore label.
        label_mask = p.label_ids[1] & (p.label_ids[0] != -100)
        # if not all(label_mask[torch.arange(len(label_mask)), response_start_idx+1]):
        #     logging.warning("Label mask does not match response start index, may have included an offset. Loss metrics may be incorrect")
        # TODO: get standard perplexity loss
        # getting probs of classification on harmfulness
        logits = torch.tensor(p.predictions[1])
        probs = torch.softmax(logits, dim=-1)
        # Predicted class per token; class > 0 is treated as "harmful" below.
        preds = probs.argmax(dim=-1)
        # pct tokens harmful
        pct_harmful_all = preds[label_mask].to(float).mean().item()
        # pct correct classified
        pct_correct = (preds == p.label_ids[0])[label_mask].to(float).mean().item()
        # pct strings correctly classified anywhere (any valid position flagged)
        _any_harmful = (preds * label_mask).any(-1)
        pct_any_harmful = _any_harmful.to(float).mean().item()
        pct_any_correct = (_any_harmful == (p.label_ids[0] * label_mask).any(-1)).to(float).mean().item()
        metrics["pct_harmful"] = pct_harmful_all
        metrics["pct_correct"] = pct_correct
        metrics["pct_any_in_seq_harmful"] = pct_any_harmful
        metrics["pct_any_in_seq_correct"] = pct_any_correct
        return metrics
# registration so you can call AutoModelForCausalLM.from_pretrained(...)
# Maps the MODEL_TYPE string to these classes in the Auto* registries so
# checkpoints whose config declares model_type == "lm_with_head" resolve here.
AutoConfig.register(MODEL_TYPE, LMWithHeadConfig)
AutoModelForCausalLM.register(LMWithHeadConfig, LMWithHead)