"""Causal LM wrapper with a configurable per-token classification head.

Defines the config/model pair (``LMWithHeadConfig`` / ``LMWithHead``), a few
tensor helpers, and a ``compute_metrics`` callable, and registers the model
type with the Hugging Face ``Auto*`` factories.
"""

import json
import logging
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from transformers import (
    AutoModelForCausalLM,
    AutoConfig,
    PretrainedConfig,
    PreTrainedModel,
    EvalPrediction,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

logger = logging.getLogger(__name__)

MODEL_TYPE = "lm_with_head"

# Token id the original implementation hard-coded as "padding" when rebuilding
# the attention mask after generation (the original code noted it is
# Llama-specific). Only used when no pad token id can be resolved elsewhere.
_FALLBACK_PAD_TOKEN_ID = 128009


def get_response_positions(y, response_token_ids: List[int]) -> Tensor:
    """Locate the position right after every occurrence of a token sequence.

    Args:
        y: Iterable of 1-D token-id rows (anything ``np.asarray`` accepts).
        response_token_ids: The token-id sequence to search for.

    Returns:
        A 1-D long tensor holding, for every match in every row (in row
        order), ``match_start + len(response_token_ids)``. Matches from all
        rows are concatenated, so the result is not aligned with the batch
        dimension when a row has zero or several matches.
    """
    marker = list(response_token_ids)
    marker_len = len(marker)
    positions: List[int] = []
    for row in y:
        row = np.asarray(row)
        # Cheap pre-filter on the first marker token, then verify the rest.
        for start in np.where(row == marker[0])[0]:
            start = int(start)
            if row[start : start + marker_len].tolist() == marker:
                positions.append(start + marker_len)
    # Explicit dtype so an empty result is still an integer tensor.
    return torch.tensor(positions, dtype=torch.long)


@dataclass
class LMWithHeadOutputWithPast(CausalLMOutputWithPast):
    """Causal-LM output extended with the classification head's results."""

    classification_logits: Optional[torch.FloatTensor] = None
    classification_loss: Optional[torch.FloatTensor] = None


@dataclass
class LMWithHeadGenerationOutput:
    """Result of ``LMWithHead.generate_with_classification``."""

    sequences: torch.LongTensor
    classification_logits: torch.FloatTensor
    hidden_states: torch.FloatTensor
    # Whatever ``self.base.generate(..., return_dict_in_generate=True)`` returned.
    base_output: Any


class LMWithHeadConfig(PretrainedConfig):
    """Configuration for ``LMWithHead``.

    Args:
        base_model_id: Identifier passed to
            ``AutoModelForCausalLM.from_pretrained`` to load the backbone.
        num_labels: Output dimension of the classification head.
        classifier_dropout: Dropout rate used inside the head.
        freeze_base: If true, backbone parameters get ``requires_grad=False``.
        classifier_hidden_layers: Hidden layer widths of the head; ``None`` or
            empty means a single linear layer.
        classifier_activation: Name of the activation used between hidden
            layers.
    """

    model_type = MODEL_TYPE

    def __init__(
        self,
        base_model_id: Optional[str] = None,
        num_labels: int = 2,
        classifier_dropout: float = 0.1,
        freeze_base: bool = True,
        classifier_hidden_layers: Optional[List[int]] = None,
        classifier_activation: str = "relu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_id = base_model_id
        self.num_labels = num_labels
        self.classifier_dropout = classifier_dropout
        self.freeze_base = freeze_base
        # None -> empty list (single-layer classifier).
        self.classifier_hidden_layers = classifier_hidden_layers or []
        self.classifier_activation = classifier_activation


class ConfigurableClassifierHead(nn.Module):
    """Configurable classifier head with variable number of hidden layers and activations."""

    # Activation name -> module class; instantiated per layer on demand.
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "tanh": nn.Tanh,
        "leaky_relu": nn.LeakyReLU,
        "elu": nn.ELU,
    }

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        dropout_rate: float = 0.1,
        activation: str = "relu",
    ):
        """Build the head.

        Args:
            input_dim: Size of the incoming feature dimension.
            hidden_dims: Widths of the hidden layers (may be empty/None).
            output_dim: Size of the output dimension.
            dropout_rate: Dropout probability.
            activation: Key of the activation to use between hidden layers.

        Raises:
            ValueError: If ``activation`` is not a supported name.
        """
        super().__init__()

        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation: {activation}. "
                f"Choose from: {list(self._ACTIVATIONS.keys())}"
            )
        activation_cls = self._ACTIVATIONS[activation]

        layers: List[nn.Module] = []
        current_dim = input_dim

        if hidden_dims:
            for hidden_dim in hidden_dims:
                layers.append(nn.Linear(current_dim, hidden_dim))
                # Fresh module per layer instead of one shared instance; the
                # activations hold no parameters, so state-dict keys are the same.
                layers.append(activation_cls())
                layers.append(nn.Dropout(dropout_rate))
                current_dim = hidden_dim
        else:
            # No hidden layers: still apply dropout before the output layer.
            layers.append(nn.Dropout(dropout_rate))

        layers.append(nn.Linear(current_dim, output_dim))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the head to ``x`` and return its output."""
        return self.classifier(x)


class LMWithHead(PreTrainedModel):
    """A causal LM backbone plus a classification head on its last hidden states."""

    config_class = LMWithHeadConfig

    def __init__(self, config: LMWithHeadConfig):
        """Load the backbone named in ``config`` and attach the head.

        Raises:
            ValueError: If ``config.base_model_id`` is ``None``.
        """
        super().__init__(config)

        if config.base_model_id is None:
            raise ValueError("base_model_id must be specified in the config.")

        # Load the backbone straight from HF (or local cache).
        self.base = AutoModelForCausalLM.from_pretrained(config.base_model_id)
        if config.freeze_base:
            for param in self.base.parameters():
                param.requires_grad_(False)

        self.classifier = ConfigurableClassifierHead(
            input_dim=self.base.config.hidden_size,
            hidden_dims=config.classifier_hidden_layers,
            output_dim=config.num_labels,
            dropout_rate=config.classifier_dropout,
            activation=config.classifier_activation,
        )
        self.post_init()  # initialize the new head

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        class_labels=None,
        class_labels_mask=None,
        output_hidden_states=False,
        **kwargs,
    ):
        """Run the backbone and classify every position.

        Args:
            input_ids: Token ids for the backbone.
            attention_mask: Optional attention mask for the backbone.
            labels: Optional LM labels, forwarded to the backbone.
            class_labels: Optional per-token classification targets.
            class_labels_mask: Optional mask selecting the positions that
                contribute to the classification loss. When omitted, positions
                whose ``class_labels`` equal -100 are excluded.
            output_hidden_states: Whether to return the backbone hidden states
                (they are always computed internally for the head).
            **kwargs: Forwarded to the backbone.

        Returns:
            ``LMWithHeadOutputWithPast``. ``loss`` is the sum of the available
            LM and classification losses, or ``None`` if neither was computed.
        """
        out = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,  # the head needs the last layer
            **kwargs,
        )
        hidden_states = out.hidden_states[-1]  # (B, L, H)
        logits_cls = self.classifier(hidden_states)  # (B, L, C)

        loss_cls = None
        if class_labels is not None:
            if class_labels_mask is None:
                # No explicit mask: fall back to the -100 "ignore" convention.
                mask = class_labels != -100
            else:
                # Force bool so indexing selects positions; an integer 0/1
                # mask would otherwise be treated as gather indices.
                mask = class_labels_mask.bool()

            if mask.any():
                if self.config.num_labels == 1:  # binary (BCE)
                    preds = logits_cls[mask].squeeze(-1)  # (N,)
                    target = class_labels[mask].float()  # (N,)
                    loss_cls = nn.BCEWithLogitsLoss()(preds, target)
                else:  # multi-class (CE)
                    preds = logits_cls[mask]  # (N, C)
                    target = class_labels[mask]  # (N,)
                    loss_cls = nn.CrossEntropyLoss()(preds, target)
            else:
                # No valid tokens: zero loss that stays attached to the graph,
                # so backward() still reaches the head's parameters.
                loss_cls = (logits_cls * 0.0).sum()

        # Combine whichever losses exist; keep None (not 0) when there are
        # none, so the output does not carry a spurious non-None loss field.
        total_loss = None
        for part in (out.loss, loss_cls):
            if part is not None:
                total_loss = part if total_loss is None else total_loss + part

        return LMWithHeadOutputWithPast(
            loss=total_loss,
            logits=out.logits,
            past_key_values=out.past_key_values,
            hidden_states=out.hidden_states if output_hidden_states else None,
            attentions=out.attentions if kwargs.get("output_attentions") else None,
            classification_logits=logits_cls,
            classification_loss=loss_cls,
        )

    def save_pretrained(self, save_dir, head_only=True, **kwargs):
        """Save the model to ``save_dir``.

        Args:
            save_dir: Target directory (created if missing).
            head_only: If true, write only the config, ``classifier.pt`` (the
                head's state dict) and ``base.json`` (the backbone id).
                Otherwise defer to ``PreTrainedModel.save_pretrained``.
            **kwargs: Forwarded to the parent implementation for a full save;
                ignored for a head-only save.
        """
        os.makedirs(save_dir, exist_ok=True)
        self.config.save_pretrained(save_dir)

        if not head_only:
            # Normal full save.
            super().save_pretrained(save_dir, **kwargs)
            return

        # Just the delta: the head weights ...
        torch.save(
            self.classifier.state_dict(), os.path.join(save_dir, "classifier.pt")
        )
        # ... plus a tiny helper to remember which backbone to reload.
        with open(os.path.join(save_dir, "base.json"), "w") as f:
            json.dump({"base_model_id": self.config.base_model_id}, f)

    @classmethod
    def _load_head_only(cls, path, config):
        """Build a model from a head-only checkpoint (``base.json`` + ``classifier.pt``).

        ``path`` is treated as a local directory if it is one, otherwise as a
        Hub repo id. Any failure propagates to the caller.
        """
        if os.path.isdir(path):
            base_json_path = os.path.join(path, "base.json")
            classifier_path = os.path.join(path, "classifier.pt")
        else:
            from huggingface_hub import hf_hub_download

            base_json_path = hf_hub_download(repo_id=path, filename="base.json")
            classifier_path = hf_hub_download(repo_id=path, filename="classifier.pt")

        with open(base_json_path) as f:
            config.base_model_id = json.load(f)["base_model_id"]

        model = cls(config)

        # The file may come from a remote repo; weights_only=True restricts
        # unpickling to tensors/containers, which is all a state dict needs.
        head_sd = torch.load(classifier_path, map_location="cpu", weights_only=True)
        model.classifier.load_state_dict(head_sd, strict=True)
        # Match PreTrainedModel.from_pretrained, which returns models in eval mode.
        model.eval()
        return model

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        """Load a model saved by ``save_pretrained``.

        First tries the head-only layout; if that fails for any reason, falls
        back to the standard ``PreTrainedModel.from_pretrained``.

        Args:
            path: Local directory or Hub repo id.
            **kwargs: May contain ``config``; otherwise forwarded to
                ``LMWithHeadConfig.from_pretrained`` and to the fallback loader.

        Raises:
            ValueError: If both loading strategies fail.
        """
        config = kwargs.get("config")
        if config is None:
            config = LMWithHeadConfig.from_pretrained(path, **kwargs)

        try:
            return cls._load_head_only(path, config)
        except Exception as e:  # deliberate best-effort: fall back below
            logger.debug("Head-only loading from %s failed: %s", path, e)
            # This will likely fail unless standard weight files are present.
            try:
                return super().from_pretrained(path, **kwargs)
            except Exception as inner_e:
                raise ValueError(
                    f"Could not load model from {path}. "
                    f"Custom loading failed with: {str(e)}. "
                    f"Standard loading failed with: {str(inner_e)}. "
                    f"Make sure the repository contains either 'base.json' and 'classifier.pt' files, "
                    f"or standard model weights files."
                ) from inner_e

    def generate_with_classification(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        **generate_kwargs,
    ) -> LMWithHeadGenerationOutput:
        """Generate with the backbone, then classify every position of the result.

        Args:
            input_ids: Prompt token ids.
            attention_mask: Optional attention mask for the prompt.
            **generate_kwargs: Forwarded to ``self.base.generate``.

        Returns:
            ``LMWithHeadGenerationOutput`` with the generated sequences, the
            classification logits and last hidden states for the full
            sequences, and the raw ``generate`` output.
        """
        # Step 1: generate tokens with the base model.
        gen_output = self.base.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict_in_generate=True,
            output_hidden_states=True,
            **generate_kwargs,
        )

        # Resolve the padding id instead of hard-coding one: explicit generate
        # kwarg, then generation config, then model config, then the legacy
        # constant so existing setups keep their previous behaviour.
        pad_token_id = generate_kwargs.get("pad_token_id")
        if pad_token_id is None:
            generation_config = getattr(self.base, "generation_config", None)
            pad_token_id = getattr(generation_config, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = getattr(self.base.config, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = _FALLBACK_PAD_TOKEN_ID

        # Step 2: re-run a forward pass over the full sequences to get hidden
        # states for classification.
        # NOTE(review): every occurrence of the pad id is masked, including any
        # inside the prompt — confirm this is intended.
        with torch.no_grad():
            outputs = self.base(
                input_ids=gen_output.sequences,
                attention_mask=(gen_output.sequences != pad_token_id),
                output_hidden_states=True,
            )
            last_hidden = outputs.hidden_states[-1]  # (B, L, H)
            classification_logits = self.classifier(last_hidden)  # (B, L, C)

        return LMWithHeadGenerationOutput(
            sequences=gen_output.sequences,
            classification_logits=classification_logits,
            hidden_states=last_hidden,
            base_output=gen_output,
        )


def mask_range(
    tensor,
    fill_value: float,
    start_pos,
    end_pos,
):
    """Keep ``tensor`` inside a per-row column range and fill the rest.

    Args:
        tensor: Tensor whose dimension 1 is indexed by the positions.
        fill_value: Value written wherever the mask is false.
        start_pos: Per-row start positions (a tensor; ``unsqueeze(1)`` is
            applied to it).
        end_pos: Per-row inclusive end positions, or ``None`` to keep only the
            single column ``start_pos`` in each row.

    Returns:
        ``torch.where(mask, tensor, fill_value)``.
    """
    columns = torch.arange(tensor.shape[1], device=tensor.device).unsqueeze(0)
    starts = start_pos.unsqueeze(1)
    if end_pos is not None:
        mask = (columns >= starts) & (columns <= end_pos.unsqueeze(1))
    else:
        mask = columns == starts
    return torch.where(mask, tensor, fill_value)


class LMWithHeadComputeMetrics:
    """``compute_metrics`` callable reporting token- and sequence-level rates."""

    def __init__(self, response_idx: Union[int, List[int]]):
        """
        Args:
            response_idx (int | List[int]): The index of the response token(s)
                in the vocabulary, i.e. <|assistant|>
        """
        if isinstance(response_idx, int):
            response_idx = [response_idx]
        # Stored for callers; the metrics below do not currently use it.
        self.response_idx = response_idx

    def __call__(self, p: EvalPrediction) -> Dict:
        """Compute classification metrics.

        Reads ``p.label_ids[0]`` (class labels), ``p.label_ids[1]`` (label
        mask) and ``p.predictions[1]`` (classification logits).

        Returns:
            Dict with ``pct_harmful``, ``pct_correct``,
            ``pct_any_in_seq_harmful`` and ``pct_any_in_seq_correct``. The two
            token-level values are NaN when no position is selected by the mask.
        """
        # Work in torch throughout, with a real boolean mask, so indexing
        # never degrades into integer gather on a 0/1 mask.
        labels = torch.as_tensor(np.asarray(p.label_ids[0]))
        label_mask = torch.as_tensor(np.asarray(p.label_ids[1])).bool() & (
            labels != -100
        )

        # TODO: get standard perplexity loss

        logits = torch.as_tensor(np.asarray(p.predictions[1]))
        if logits.shape[-1] == 1:
            # Single-logit (BCE) head: argmax over one class is always 0, so
            # threshold the logit at 0 (sigmoid 0.5) instead.
            preds = (logits.squeeze(-1) > 0).long()
        else:
            # argmax of the logits equals argmax of their softmax.
            preds = logits.argmax(dim=-1)

        # Share of masked tokens predicted as a non-zero class / predicted correctly.
        pct_harmful_all = preds[label_mask].double().mean().item()
        pct_correct = (preds == labels)[label_mask].double().mean().item()

        # Per sequence: was any masked token predicted / labelled non-zero?
        any_pred = ((preds != 0) & label_mask).any(-1)
        any_label = ((labels != 0) & label_mask).any(-1)
        pct_any_harmful = any_pred.double().mean().item()
        pct_any_correct = (any_pred == any_label).double().mean().item()

        return {
            "pct_harmful": pct_harmful_all,
            "pct_correct": pct_correct,
            "pct_any_in_seq_harmful": pct_any_harmful,
            "pct_any_in_seq_correct": pct_any_correct,
        }


# Registration so you can call AutoModelForCausalLM.from_pretrained(...)
AutoConfig.register(MODEL_TYPE, LMWithHeadConfig)
AutoModelForCausalLM.register(LMWithHeadConfig, LMWithHead)