| """ | |
| Utility module for loading and running the finetuned ModernBERT reward model. | |
| The model mirrors the architecture defined in `mosaic_bert_training.py`: | |
| - base encoder: answerdotai/ModernBERT-base (8k context support) | |
| - pooling: attention-mask-weighted mean pooling | |
| - head: single linear layer + sigmoid to output a score in [0, 1] | |
| """ | |
import os
from typing import Optional

import torch
from transformers import AutoModel
class BERTRewardModel(torch.nn.Module):
    """ModernBERT encoder topped with a sigmoid regression head.

    Token ids are encoded by ModernBERT, the final hidden states are
    mean-pooled with attention-mask weighting, and the pooled vector is
    mapped through one linear layer plus a sigmoid to a score in [0, 1].
    """

    def __init__(self, model_name: str = "answerdotai/ModernBERT-base"):
        super().__init__()
        # reference_compile=False / eager attention keep the checkpoint
        # loadable on any backend without torch.compile support.
        self.bert = AutoModel.from_pretrained(
            model_name,
            reference_compile=False,
            attn_implementation="eager",
        )
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_ids, attention_mask, labels: Optional[torch.Tensor] = None):
        """Score a batch of tokenized sequences.

        Args:
            input_ids: Token-id tensor of shape (batch, seq).
            attention_mask: Same leading shape as the hidden states;
                1 marks real tokens, 0 marks padding.
            labels: Optional target scores (same shape as the returned
                scores); when provided, a scaled MSE loss is computed.

        Returns:
            Dict with "loss" (None when no labels are given), "scores"
            in [0, 1] of shape (batch,), and raw "logits" of shape (batch, 1).
        """
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = encoder_out.last_hidden_state

        # Attention-mask-weighted mean pooling: zero out padding positions,
        # then divide by the (clamped, to avoid /0) count of real tokens.
        weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
        pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)

        logits = self.classifier(pooled)
        scores = self.sigmoid(logits).squeeze(-1)

        if labels is None:
            loss = None
        else:
            # MSE scaled by 100 — mirrors the objective in mosaic_bert_training.py.
            loss = torch.nn.MSELoss()(scores, labels) * 100

        return {"loss": loss, "scores": scores, "logits": logits}
def load_finetuned_model(model_dir: str, device: Optional[str] = None) -> BERTRewardModel:
    """
    Load the finetuned ModernBERT reward model from `model_dir`.

    Args:
        model_dir: Path containing model.safetensors (preferred) or pytorch_model.bin.
        device: Optional torch device string. Defaults to CUDA if available else CPU.

    Returns:
        The model on the requested device, in eval mode.

    Raises:
        FileNotFoundError: If neither model.safetensors nor pytorch_model.bin
            exists in `model_dir`.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_device = torch.device(device)

    # Build the architecture first (downloads/loads the base encoder), then
    # overwrite its weights with the finetuned checkpoint below.
    model = BERTRewardModel()
    model.to(torch_device)

    state_dict = None
    safetensors_path = os.path.join(model_dir, "model.safetensors")
    bin_path = os.path.join(model_dir, "pytorch_model.bin")

    if os.path.exists(safetensors_path):
        try:
            from safetensors.torch import load_file

            state_dict = load_file(safetensors_path)
        except ImportError as exc:
            # Best-effort: keep going and try the .bin checkpoint instead.
            print(f"⚠️ safetensors not available ({exc}); falling back to pytorch_model.bin if present.")

    if state_dict is None and os.path.exists(bin_path):
        # weights_only=True restricts unpickling to tensor data, blocking the
        # arbitrary-code-execution payloads a full pickle load would run.
        state_dict = torch.load(bin_path, map_location=torch_device, weights_only=True)

    if state_dict is None:
        raise FileNotFoundError(
            f"Could not find model weights in {model_dir}. "
            "Expected either model.safetensors or pytorch_model.bin."
        )

    model.load_state_dict(state_dict)
    model.eval()
    return model