| | import json |
| | import logging |
| | from dataclasses import dataclass |
| | import os |
| | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from torch import Tensor |
| | from transformers import ( |
| | AutoModelForCausalLM, |
| | AutoConfig, |
| | PretrainedConfig, |
| | PreTrainedModel, |
| | EvalPrediction |
| | ) |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| |
|
| |
|
def get_response_positions(
    y,
    response_token_ids: List[int]
) -> Tensor:
    """Find where each response starts in a batch of token-id rows.

    For every row of ``y``, every occurrence of the full ``response_token_ids``
    template is located, and the position immediately *after* the template
    (i.e. where the response content begins) is collected.

    Args:
        y: Batch of token-id rows (e.g. a 2-D numpy array); each row must
            support elementwise ``==`` against a scalar and ``.tolist()``.
        response_token_ids: Token-id sequence that marks a response start.

    Returns:
        1-D ``torch.Tensor`` of end-of-template positions, in row-major
        match order.
    """
    template_len = len(response_token_ids)
    first_token = response_token_ids[0]
    positions = []
    for row in y:
        # Candidate starts: every position matching the template's first token.
        for start in np.where(row == first_token)[0]:
            start = int(start)
            if row[start : start + template_len].tolist() == response_token_ids:
                positions.append(start + template_len)
    return torch.tensor(positions)
| |
|
| |
|
| | MODEL_TYPE = "lm_with_head" |
| |
|
| |
|
@dataclass
class LMWithHeadOutputWithPast(CausalLMOutputWithPast):
    """Causal-LM output extended with the classification head's results."""

    # Per-position logits from the classifier head applied to the last hidden
    # layer (presumably shape (batch, seq, num_labels) — follows from
    # classifier(hidden_states) in LMWithHead.forward).
    classification_logits: Optional[torch.FloatTensor] = None
    # Classification loss term; None when no class labels were provided.
    classification_loss: Optional[torch.FloatTensor] = None
| |
|
| |
|
@dataclass
class LMWithHeadGenerationOutput:
    """Bundle returned by ``LMWithHead.generate_with_classification``."""

    # Generated token ids (prompt + continuation) from base.generate.
    sequences: torch.LongTensor
    # Classifier logits computed over the full generated sequence.
    classification_logits: torch.FloatTensor
    # Last-layer hidden states the classifier was applied to.
    hidden_states: torch.FloatTensor
    # NOTE(review): annotated as LMWithHeadOutputWithPast, but
    # generate_with_classification stores the raw base.generate(...) output
    # here — confirm and fix the annotation upstream.
    base_output: LMWithHeadOutputWithPast
| |
|
class LMWithHeadConfig(PretrainedConfig):
    """Configuration for :class:`LMWithHead`.

    Describes which base causal LM to wrap and how to build the
    classification head placed on top of its last hidden layer.
    """

    model_type = MODEL_TYPE

    def __init__(
        self,
        base_model_id: Optional[str] = None,  # fixed: was mistyped `str = None`
        num_labels: int = 2,
        classifier_dropout: float = 0.1,
        freeze_base: bool = True,
        classifier_hidden_layers: Optional[List[int]] = None,  # fixed: was `List[int] = None`
        classifier_activation: str = "relu",
        **kwargs,
    ):
        """
        Args:
            base_model_id: Hub id or local path of the causal LM to wrap.
                Must be set before instantiating :class:`LMWithHead`.
            num_labels: Number of classification labels (1 selects the
                single-logit BCE path in the model's forward).
            classifier_dropout: Dropout rate inside the classifier head.
            freeze_base: If True, all base-LM parameters are frozen.
            classifier_hidden_layers: Hidden-layer widths for the classifier
                MLP; None/empty yields a single dropout + linear projection.
            classifier_activation: Activation name for the classifier MLP.
            **kwargs: Forwarded to ``PretrainedConfig``.
        """
        super().__init__(**kwargs)

        self.base_model_id = base_model_id
        self.num_labels = num_labels
        self.classifier_dropout = classifier_dropout
        self.freeze_base = freeze_base
        # Normalize None -> [] so downstream code can iterate unconditionally.
        self.classifier_hidden_layers = classifier_hidden_layers or []
        self.classifier_activation = classifier_activation
| | |
| |
|
| |
|
class ConfigurableClassifierHead(nn.Module):
    """Configurable classifier head with variable number of hidden layers and activations.

    Builds ``[Linear -> activation -> Dropout] * len(hidden_dims)`` followed by
    a final ``Linear`` projection to ``output_dim``. With no hidden layers it
    degenerates to ``Dropout -> Linear``.
    """

    # Supported activation names mapped to their module *classes*.
    # Fix: the original built one shared activation instance and reused the
    # same module object in every hidden layer of the Sequential; we now
    # instantiate a fresh activation module per layer (the idiomatic pattern).
    _ACTIVATION_CLASSES = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "tanh": nn.Tanh,
        "leaky_relu": nn.LeakyReLU,
        "elu": nn.ELU,
    }

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        dropout_rate: float = 0.1,
        activation: str = "relu"
    ):
        """
        Args:
            input_dim: Size of the incoming feature dimension.
            hidden_dims: Widths of the hidden layers (may be empty/None).
            output_dim: Number of output logits.
            dropout_rate: Dropout applied after each activation (and before
                the projection when there are no hidden layers).
            activation: One of the keys of ``_ACTIVATION_CLASSES``.

        Raises:
            ValueError: If ``activation`` is not a supported name.
        """
        super().__init__()

        if activation not in self._ACTIVATION_CLASSES:
            raise ValueError(f"Unsupported activation: {activation}. "
                             f"Choose from: {list(self._ACTIVATION_CLASSES.keys())}")
        activation_cls = self._ACTIVATION_CLASSES[activation]

        layers: List[nn.Module] = []
        current_dim = input_dim

        if hidden_dims:
            for hidden_dim in hidden_dims:
                layers.append(nn.Linear(current_dim, hidden_dim))
                layers.append(activation_cls())  # fresh instance per layer
                layers.append(nn.Dropout(dropout_rate))
                current_dim = hidden_dim
        else:
            # No hidden layers: still regularize before the projection.
            layers.append(nn.Dropout(dropout_rate))

        # Final projection to the label space.
        layers.append(nn.Linear(current_dim, output_dim))

        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP; projects the last dimension to ``output_dim``."""
        return self.classifier(x)
| |
|
| |
|
class LMWithHead(PreTrainedModel):
    """A causal LM (optionally frozen) with an auxiliary token-classification head.

    Wraps an ``AutoModelForCausalLM`` base model and applies a
    :class:`ConfigurableClassifierHead` to the last-layer hidden states,
    producing per-position classification logits alongside the LM output.
    """

    config_class = LMWithHeadConfig

    def __init__(self, config: LMWithHeadConfig):
        super().__init__(config)

        if config.base_model_id is None:
            raise ValueError("base_model_id must be specified in the config.")
        self.base = AutoModelForCausalLM.from_pretrained(config.base_model_id)

        # Optionally freeze the base LM so only the classifier head trains.
        if config.freeze_base:
            for p in self.base.parameters():
                p.requires_grad_(False)

        hid = self.base.config.hidden_size

        self.classifier = ConfigurableClassifierHead(
            input_dim=hid,
            hidden_dims=config.classifier_hidden_layers,
            output_dim=config.num_labels,
            dropout_rate=config.classifier_dropout,
            activation=config.classifier_activation,
        )

        self.post_init()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        class_labels=None,
        class_labels_mask=None,
        output_hidden_states=False,
        **kwargs
    ):
        """Run the base LM plus the classification head.

        Args:
            input_ids: Input token ids.
            attention_mask: Attention mask forwarded to the base LM.
            labels: LM labels; when given, the base model computes the LM loss.
            class_labels: Per-position classification targets.
            class_labels_mask: Boolean mask selecting positions that carry a
                classification label. If omitted, every position is treated
                as labeled.
            output_hidden_states: Whether to include hidden states in the output.
            **kwargs: Forwarded to the base model (e.g. ``output_attentions``).

        Returns:
            LMWithHeadOutputWithPast whose ``loss`` is LM loss + classification
            loss (whichever are present; 0 if neither).
        """
        out = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,  # always needed: classifier reads hidden states
            **kwargs,
        )

        hidden_states = out.hidden_states[-1]
        logits_cls = self.classifier(hidden_states)

        loss_cls = None
        if class_labels is not None:
            mask = class_labels_mask
            if mask is None:
                # Fix: the original crashed on `mask.any()` when class_labels
                # were given without a mask; default to "all positions labeled".
                mask = torch.ones_like(class_labels, dtype=torch.bool)

            if mask.any():
                if self.config.num_labels == 1:
                    # Binary head: single logit + BCE-with-logits.
                    preds = logits_cls[mask].squeeze(-1)
                    target = class_labels[mask].float()
                    loss_fct = nn.BCEWithLogitsLoss()
                    loss_cls = loss_fct(preds, target)
                else:
                    preds = logits_cls[mask]
                    target = class_labels[mask]
                    loss_fct = nn.CrossEntropyLoss()
                    loss_cls = loss_fct(preds, target)
            else:
                # No labeled position in the batch: zero loss that still
                # participates in the graph.
                loss_cls = torch.tensor(0.0, device=logits_cls.device, requires_grad=True)

        # Sum whichever losses are present (stays int 0 when neither is,
        # matching the original contract).
        total_loss = 0
        if out.loss is not None:
            total_loss += out.loss
        if loss_cls is not None:
            total_loss += loss_cls

        return LMWithHeadOutputWithPast(
            loss=total_loss,
            logits=out.logits,
            past_key_values=out.past_key_values,
            hidden_states=out.hidden_states if output_hidden_states else None,
            attentions=out.attentions if kwargs.get("output_attentions") else None,
            classification_logits=logits_cls,
            classification_loss=loss_cls,
        )

    def save_pretrained(self, save_dir, head_only=True, **kwargs):
        """Save the config plus either just the head (default) or the full model.

        With ``head_only=True`` only ``classifier.pt`` and a ``base.json``
        pointer to the base model id are written; the (typically frozen) base
        is re-fetched from that id at load time. Otherwise falls back to
        standard transformers saving.
        """
        os.makedirs(save_dir, exist_ok=True)
        self.config.save_pretrained(save_dir)

        if head_only:
            torch.save(self.classifier.state_dict(), os.path.join(save_dir, "classifier.pt"))
            with open(os.path.join(save_dir, "base.json"), "w") as f:
                json.dump({"base_model_id": self.config.base_model_id}, f)
        else:
            super().save_pretrained(save_dir, **kwargs)

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        """Load a model saved by :meth:`save_pretrained`.

        Tries the head-only layout first (``base.json`` + ``classifier.pt``,
        either a local directory or a hub repo), then falls back to standard
        transformers loading; raises ValueError if both fail.
        """
        config = kwargs.get("config", None)
        if config is None:
            config = LMWithHeadConfig.from_pretrained(path, **kwargs)

        is_local = os.path.isdir(path)

        try:
            if is_local:
                base_json_path = os.path.join(path, "base.json")
                classifier_path = os.path.join(path, "classifier.pt")
            else:
                from huggingface_hub import hf_hub_download
                base_json_path = hf_hub_download(repo_id=path, filename="base.json")
                classifier_path = hf_hub_download(repo_id=path, filename="classifier.pt")

            with open(base_json_path) as f:
                base_id = json.load(f)["base_model_id"]

            config.base_model_id = base_id

            model = cls(config)

            # weights_only guards against arbitrary-code pickles; the head
            # checkpoint is a plain tensor state dict, so this is safe.
            head_sd = torch.load(classifier_path, map_location="cpu", weights_only=True)
            model.classifier.load_state_dict(head_sd, strict=True)

            return model

        # Fix: original caught (FileNotFoundError, OSError, Exception) —
        # the first two are subsumed by Exception.
        except Exception as e:
            try:
                return super().from_pretrained(path, **kwargs)
            except Exception as inner_e:
                raise ValueError(
                    f"Could not load model from {path}. "
                    f"Custom loading failed with: {str(e)}. "
                    f"Standard loading failed with: {str(inner_e)}. "
                    f"Make sure the repository contains either 'base.json' and 'classifier.pt' files, "
                    f"or standard model weights files."
                )

    def generate_with_classification(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        mask_token_id: int = 128009,
        **generate_kwargs,
    ) -> "LMWithHeadGenerationOutput":
        """Generate with the base LM, then classify every generated position.

        Args:
            input_ids: Prompt token ids.
            attention_mask: Attention mask for the prompt.
            mask_token_id: Token id treated as padding when building the
                attention mask for the classification pass. Defaults to 128009
                (previously hard-coded; presumably Llama-3 ``<|eot_id|>`` —
                pass the correct id for other tokenizers).
            **generate_kwargs: Forwarded to ``base.generate``.

        Returns:
            LMWithHeadGenerationOutput bundling sequences, classifier logits,
            last-layer hidden states, and the raw generate output.
        """
        gen_output = self.base.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict_in_generate=True,
            output_hidden_states=True,
            **generate_kwargs,
        )

        # Re-run a full forward pass over the generated sequences to obtain
        # position-aligned last-layer hidden states for the classifier.
        with torch.no_grad():
            outputs = self.base(
                input_ids=gen_output.sequences,
                attention_mask=(gen_output.sequences != mask_token_id),
                output_hidden_states=True,
            )
            last_hidden = outputs.hidden_states[-1]
            classification_logits = self.classifier(last_hidden)

        return LMWithHeadGenerationOutput(
            sequences=gen_output.sequences,
            classification_logits=classification_logits,
            hidden_states=last_hidden,
            base_output=gen_output,
        )
| |
|
def mask_range(
    tensor,
    fill_value: float,
    start_pos,
    end_pos,
):
    """Keep values inside a per-row position range; fill everything else.

    Args:
        tensor: 2-D tensor of shape (batch, seq).
        fill_value: Value written at positions outside the kept range.
        start_pos: Per-row start positions, shape (batch,).
        end_pos: Per-row *inclusive* end positions, shape (batch,), or None
            to keep only the single position at ``start_pos``.

    Returns:
        Tensor of the same shape with out-of-range entries set to
        ``fill_value``.
    """
    # Column indices broadcast against per-row bounds: (1, seq) vs (batch, 1).
    positions = torch.arange(tensor.shape[1], device=tensor.device).unsqueeze(0)
    starts = start_pos.unsqueeze(1)
    if end_pos is None:
        keep = positions == starts
    else:
        keep = (positions >= starts) & (positions <= end_pos.unsqueeze(1))
    return torch.where(keep, tensor, fill_value)
| |
|
| |
|
| | class LMWithHeadComputeMetrics: |
| | def __init__(self, response_idx: int | List[int]): |
| | """ |
| | Args: |
| | response_idx (int | List[int]): The index of the response token(s) in the vocabulary, |
| | i.e. <|assistant|> |
| | """ |
| | if isinstance(response_idx, int): |
| | response_idx = [response_idx] |
| | self.response_idx = response_idx |
| |
|
| | def __call__(self, p: EvalPrediction) -> Dict: |
| | metrics = {} |
| | response_start_idx = get_response_positions(p.inputs, self.response_idx) |
| |
|
| | label_mask = p.label_ids[1] & (p.label_ids[0] != -100) |
| | |
| | |
| |
|
| | |
| |
|
| | |
| | logits = torch.tensor(p.predictions[1]) |
| | probs = torch.softmax(logits, dim=-1) |
| | preds = probs.argmax(dim=-1) |
| |
|
| | |
| | pct_harmful_all = preds[label_mask].to(float).mean().item() |
| |
|
| | |
| | pct_correct = (preds == p.label_ids[0])[label_mask].to(float).mean().item() |
| |
|
| | |
| | _any_harmful = (preds * label_mask).any(-1) |
| | pct_any_harmful = _any_harmful.to(float).mean().item() |
| | pct_any_correct = (_any_harmful == (p.label_ids[0] * label_mask).any(-1)).to(float).mean().item() |
| |
|
| | metrics["pct_harmful"] = pct_harmful_all |
| | metrics["pct_correct"] = pct_correct |
| | metrics["pct_any_in_seq_harmful"] = pct_any_harmful |
| | metrics["pct_any_in_seq_correct"] = pct_any_correct |
| |
|
| | return metrics |
| |
|
| |
|
| | |
# Register the custom config/model pair so the transformers Auto* factories
# can resolve model_type == MODEL_TYPE when loading from disk or the hub.
AutoConfig.register(MODEL_TYPE, LMWithHeadConfig)
AutoModelForCausalLM.register(LMWithHeadConfig, LMWithHead)
| |
|
| |
|
| |
|