"""ArbitrAgent / reward_model.py

Play-gent: Diplomacy-trained negotiation agent.

DistilBERT-based scalar reward model, trained by regression on
Diplomacy self-play states.
"""
import json
import random
from typing import Any, Dict, List

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerBase
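# Third-party dependencies (any reasonably recent versions should work):
#     pip install torch transformers numpy matplotlib tqdm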

class DiplomacyRewardModel(nn.Module):
    """
    DistilBERT-based reward regressor for Diplomacy self-play states.

    Input: state text encoded with the distilbert-base-uncased tokenizer.
    Output: a raw scalar regression value (no sigmoid applied).
    """

    def __init__(self, base_model: str = "distilbert-base-uncased"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model)
        hidden_size = self.encoder.config.hidden_size
        # Small MLP regression head on top of the pooled [CLS] embedding.
        self.head = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # **kwargs absorbs token_type_ids and other unused fields if present.
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # [CLS] pooling: take the first token's final hidden state.
        pooled = outputs.last_hidden_state[:, 0, :]
        return self.head(pooled).squeeze(-1)
    def score(
        self,
        state_text: str,
        action_text: str,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> float:
        """
        Convenience helper: encode a (state, action) text pair and return a
        scalar reward. Note that train() below fits the model on state text
        alone; this helper folds the action into the same input string format.
        """
        self.eval()
        combined = f"STATE: {state_text}\nACTION: {action_text}"
        with torch.no_grad():
            enc = tokenizer(
                combined,
                truncation=True,
                max_length=128,
                padding="max_length",
                return_tensors="pt",
            )
            enc = {k: v.to(device) for k, v in enc.items()}
            pred = self(**enc)
        return float(pred.item())
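
# Usage sketch (paths and texts are illustrative, not part of this module):
# load weights produced by train() below, then score one (state, action) pair.
#
#     tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = DiplomacyRewardModel().to(device)
#     model.load_state_dict(torch.load("reward_model.pt", map_location=device))
#     reward = model.score("Spring 1901, AUSTRIA: A VIE, A BUD, F TRI",
#                          "A VIE - GAL", tokenizer, device)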

class StatesDataset(Dataset):
    """Simple dataset wrapping state_text strings and scalar rewards."""

    def __init__(
        self,
        texts: List[str],
        rewards: List[float],
        tokenizer: PreTrainedTokenizerBase,
        max_length: int = 128,
    ):
        self.texts = texts
        self.rewards = rewards
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        text = self.texts[idx]
        reward = self.rewards[idx]
        # Tokenize on the fly; padding to max_length keeps batch shapes fixed.
        enc = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "reward": torch.tensor(reward, dtype=torch.float32),
        }
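
# Usage sketch (values illustrative): wrap parallel lists of texts and rewards,
# then batch with a DataLoader, as train() does below.
#
#     ds = StatesDataset(texts=["STATE: ..."], rewards=[0.5], tokenizer=tokenizer)
#     loader = DataLoader(ds, batch_size=128, shuffle=True)
#     batch = next(iter(loader))  # dict of input_ids, attention_mask, reward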

def score_state(
    model: DiplomacyRewardModel,
    tokenizer: PreTrainedTokenizerBase,
    state_text: str,
    device: torch.device,
) -> float:
    """Score a single state_text with a trained model."""
    model.eval()
    with torch.no_grad():
        enc = tokenizer(
            state_text,
            truncation=True,
            max_length=128,
            padding="max_length",
            return_tensors="pt",
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        pred = model(**enc)
    return float(pred.item())
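
# Sketch: ranking candidate states with score_state, with model/tokenizer/device
# set up as in the sketch above (the state strings are hypothetical examples):
#
#     candidates = ["AUSTRIA: A VIE H, A BUD H", "AUSTRIA: A VIE - TRI, A BUD H"]
#     scores = [score_state(model, tokenizer, s, device) for s in candidates]
#     best = candidates[int(np.argmax(scores))]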

def train(
    data_path: str = "selfplay_states.json",
    output_model_path: str = "reward_model.pt",
    loss_plot_path: str = "reward_model_loss.png",
    epochs: int = 2,
    batch_size: int = 128,
    lr: float = 2e-5,
) -> None:
    # Device setup
    use_cuda = torch.cuda.is_available()
    print("torch.cuda.is_available():", use_cuda)
    device = torch.device("cuda" if use_cuda else "cpu")
    if use_cuda:
        print("Using GPU:", torch.cuda.get_device_name(0))

    # Load data
    print(f"Loading self-play states from {data_path}...")
    with open(data_path, "r") as f:
        data = json.load(f)
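    # Expected record schema, inferred from the field accesses below
    # (values illustrative):
    #     [{"state_text": "AUSTRIA: A VIE H, ...", "reward": 0.8}, ...]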
    # Limit to 50k random examples for faster training. The shuffle is
    # unseeded, so the subsample differs across runs.
    random.shuffle(data)
    data = data[:50000]
    texts: List[str] = [ex.get("state_text", "") for ex in data]
    rewards: List[float] = [float(ex.get("reward", 0.0)) for ex in data]
    print(f"Total examples: {len(texts)}")

    # Tokenizer and dataset
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    dataset = StatesDataset(texts, rewards, tokenizer, max_length=128)

    # 90/10 train/val split
    n_total = len(dataset)
    n_train = int(0.9 * n_total)
    n_val = n_total - n_train
    train_ds, val_ds = random_split(dataset, [n_train, n_val])
    print(f"Train examples: {n_train} | Val examples: {n_val}")

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    # Model, optimizer, loss
    model = DiplomacyRewardModel().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    best_val_loss = float("inf")
    train_losses: List[float] = []
    val_losses: List[float] = []
    for epoch in range(1, epochs + 1):
        # Train epoch
        model.train()
        running_train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} - train"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            targets = batch["reward"].to(device)

            optimizer.zero_grad()
            preds = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = criterion(preds, targets)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average is exact even when
            # the final batch is smaller than batch_size.
            running_train_loss += loss.item() * input_ids.size(0)
        avg_train_loss = running_train_loss / n_train

        # Validation epoch
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch}/{epochs} - val"):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                targets = batch["reward"].to(device)
                preds = model(input_ids=input_ids, attention_mask=attention_mask)
                loss = criterion(preds, targets)
                running_val_loss += loss.item() * input_ids.size(0)
        avg_val_loss = running_val_loss / n_val

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        print(
            f"Epoch {epoch}/{epochs} | "
            f"Train Loss: {avg_train_loss:.6f} | "
            f"Val Loss: {avg_val_loss:.6f}"
        )

        # Checkpoint whenever validation loss improves; the saved file always
        # holds the best model seen so far.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), output_model_path)
            print(f" -> New best val loss. Model saved to {output_model_path}")
    # Plot loss curves
    epochs_axis = np.arange(1, epochs + 1)
    plt.figure()
    plt.plot(epochs_axis, train_losses, label="Train Loss")
    plt.plot(epochs_axis, val_losses, label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("MSE Loss")
    plt.title("Reward Model Training Loss")
    plt.legend()
    plt.tight_layout()
    plt.savefig(loss_plot_path)
    plt.close()
    print(f"Loss curves saved to {loss_plot_path}")
if __name__ == "__main__":
    train()