# modernbert-nl-sql / modeling_reward.py
# Uploaded by DarianNLP (commit 6b77899, "Upload 11 files", verified)
"""
Utility module for loading and running the finetuned ModernBERT reward model.
The model mirrors the architecture defined in `mosaic_bert_training.py`:
- base encoder: answerdotai/ModernBERT-base (8k context support)
- pooling: attention-mask-weighted mean pooling
- head: single linear layer + sigmoid to output a score in [0, 1]
"""
import os
from typing import Optional
import torch
from transformers import AutoModel
class BERTRewardModel(torch.nn.Module):
    """ModernBERT encoder with a sigmoid regression head.

    The encoder's final hidden states are mean-pooled over the positions
    selected by ``attention_mask``, projected to a single logit by a linear
    layer, and passed through a sigmoid to give one score per sequence.

    The attribute names ``bert`` and ``classifier`` determine the state-dict
    keys, so renaming them would break loading of saved checkpoints.
    """

    def __init__(self, model_name: str = "answerdotai/ModernBERT-base"):
        """Build the encoder and the regression head.

        Args:
            model_name: Model id or local path passed to
                ``AutoModel.from_pretrained``.
        """
        super().__init__()
        self.bert = AutoModel.from_pretrained(
            model_name,
            reference_compile=False,
            attn_implementation="eager",
        )
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_ids, attention_mask, labels: Optional[torch.Tensor] = None):
        """Score a batch of sequences.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Mask forwarded to the encoder; nonzero entries
                mark the positions included in the mean pooling.
            labels: Optional regression targets compared against ``scores``.

        Returns:
            dict with keys:
              ``"loss"``   -- ``None`` when ``labels`` is ``None``, otherwise
                              100 * MSE(scores, labels);
              ``"scores"`` -- sigmoid of the logits with the trailing size-1
                              dimension squeezed away;
              ``"logits"`` -- raw linear-head output (before the sigmoid).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state

        # Attention-mask-weighted mean pooling. The mask is cast to float32,
        # so the products and sums are computed in (at least) float32.
        mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_hidden = torch.sum(last_hidden_state * mask, dim=1)
        # Clamp prevents division by zero for a row with no unmasked tokens.
        sum_mask = torch.clamp(mask.sum(dim=1), min=1e-9)
        pooled_output = sum_hidden / sum_mask

        # Match the head's dtype so a model cast to half/bfloat16 does not hit
        # a dtype mismatch in the linear layer (no-op for float32 models).
        pooled_output = pooled_output.to(self.classifier.weight.dtype)
        logits = self.classifier(pooled_output)
        scores = self.sigmoid(logits).squeeze(-1)

        loss = None
        if labels is not None:
            # Align target dtype/device with the scores; e.g. float64 labels
            # can otherwise make the backward pass fail with a dtype mismatch.
            labels = labels.to(device=scores.device, dtype=scores.dtype)
            loss_fct = torch.nn.MSELoss()
            # The x100 scale is preserved as-is; presumably it matches the
            # training objective -- NOTE(review): confirm against training code.
            loss = loss_fct(scores, labels) * 100
        return {"loss": loss, "scores": scores, "logits": logits}
def load_finetuned_model(model_dir: str, device: Optional[str] = None) -> BERTRewardModel:
    """
    Load the finetuned ModernBERT reward model from `model_dir`.

    A fresh ``BERTRewardModel`` is constructed (its constructor calls
    ``AutoModel.from_pretrained`` for the base encoder), moved to ``device``,
    and then overwritten with the finetuned weights via a strict
    ``load_state_dict``.

    Args:
        model_dir: Path containing model.safetensors (preferred) or pytorch_model.bin.
        device: Optional torch device string. Defaults to CUDA if available else CPU.

    Returns:
        The loaded model, on the requested device and in eval mode.

    Raises:
        FileNotFoundError: If no loadable weights file is found in ``model_dir``,
            including the case where only model.safetensors exists but the
            ``safetensors`` package cannot be imported.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_device = torch.device(device)

    model = BERTRewardModel()
    model.to(torch_device)

    safetensors_path = os.path.join(model_dir, "model.safetensors")
    bin_path = os.path.join(model_dir, "pytorch_model.bin")

    state_dict = None
    # Remembered so a final "not found" error can explain why the
    # safetensors file that does exist was skipped.
    safetensors_error = None
    if os.path.exists(safetensors_path):
        try:
            from safetensors.torch import load_file
        except ImportError as exc:
            safetensors_error = exc
            print(f"⚠️ safetensors not available ({exc}); falling back to pytorch_model.bin if present.")
        else:
            state_dict = load_file(safetensors_path)

    if state_dict is None and os.path.exists(bin_path):
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot run arbitrary code.
        # Loading to CPU avoids a second full copy on the accelerator;
        # load_state_dict copies each tensor into the already-placed params.
        state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)

    if state_dict is None:
        message = (
            f"Could not find model weights in {model_dir}. "
            "Expected either model.safetensors or pytorch_model.bin."
        )
        if safetensors_error is not None:
            message += (
                " model.safetensors exists but the safetensors package"
                " could not be imported."
            )
            raise FileNotFoundError(message) from safetensors_error
        raise FileNotFoundError(message)

    model.load_state_dict(state_dict)
    model.eval()
    return model