import json
import random
from typing import List, Dict, Any

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
class DiplomacyRewardModel(nn.Module):
    """
    DistilBERT-based reward regressor for Diplomacy self-play states.

    Input: state_text (encoded by distilbert-base-uncased)
    Output: raw scalar (regression target), no sigmoid.
    """

    def __init__(self, base_model: str = "distilbert-base-uncased"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model)
        hidden_size = self.encoder.config.hidden_size
        self.head = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # Ignore token_type_ids and other unused fields if present.
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0, :]  # CLS token
        return self.head(pooled).squeeze(-1)

    def score(
        self,
        state_text: str,
        action_text: str,
        tokenizer: AutoTokenizer,
        device: torch.device,
    ) -> float:
        """
        Convenience helper: encode a (state, action) text pair and return a scalar reward.
        """
        self.eval()
        combined = f"STATE: {state_text}\nACTION: {action_text}"
        with torch.no_grad():
            enc = tokenizer(
                combined,
                truncation=True,
                max_length=128,
                padding="max_length",
                return_tensors="pt",
            )
            enc = {k: v.to(device) for k, v in enc.items()}
            pred = self(**enc)
        return float(pred.item())
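

# A minimal sketch of calling DiplomacyRewardModel.score() on a freshly
# initialized model. Not invoked anywhere in this script; the state/action
# strings are illustrative assumptions (the real state_text format comes from
# selfplay_states.json), and an untrained model's output is meaningless, so
# this only demonstrates the call pattern.
def example_score_pair() -> None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    model = DiplomacyRewardModel().to(device)
    reward = model.score(
        state_text="FRANCE: A PAR, A MAR, F BRE",
        action_text="A PAR - BUR",
        tokenizer=tokenizer,
        device=device,
    )
    print(f"Predicted reward: {reward:.4f}")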


class StatesDataset(Dataset):
    """Simple dataset wrapping tokenized state_text and scalar rewards."""

    def __init__(self, texts: List[str], rewards: List[float], tokenizer: AutoTokenizer, max_length: int = 128):
        self.texts = texts
        self.rewards = rewards
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        text = self.texts[idx]
        reward = self.rewards[idx]
        enc = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        item = {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "reward": torch.tensor(reward, dtype=torch.float32),
        }
        return item
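

# A minimal sketch of StatesDataset with toy inputs, showing the batch shapes
# the training loop expects. Not invoked anywhere in this script; the toy
# strings are illustrative only (real state_text comes from selfplay_states.json).
def example_dataset_batch() -> None:
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    ds = StatesDataset(
        texts=["toy state A", "toy state B"],
        rewards=[1.0, -0.5],
        tokenizer=tokenizer,
        max_length=128,
    )
    loader = DataLoader(ds, batch_size=2)
    batch = next(iter(loader))
    print(batch["input_ids"].shape)       # torch.Size([2, 128])
    print(batch["attention_mask"].shape)  # torch.Size([2, 128])
    print(batch["reward"])                # tensor([ 1.0000, -0.5000])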


def score_state(model: DiplomacyRewardModel, tokenizer: AutoTokenizer, state_text: str, device: torch.device) -> float:
    """Helper to score a single state_text with the trained model."""
    model.eval()
    with torch.no_grad():
        enc = tokenizer(
            state_text,
            truncation=True,
            max_length=128,
            padding="max_length",
            return_tensors="pt",
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        pred = model(**enc)
    return float(pred.item())


def train(
    data_path: str = "selfplay_states.json",
    output_model_path: str = "reward_model.pt",
    loss_plot_path: str = "reward_model_loss.png",
    epochs: int = 2,
    batch_size: int = 128,
    lr: float = 2e-5,
) -> None:
    # Device setup
    use_cuda = torch.cuda.is_available()
    print("torch.cuda.is_available():", use_cuda)
    device = torch.device("cuda" if use_cuda else "cpu")
    if use_cuda:
        print("Using GPU:", torch.cuda.get_device_name(0))

    # Load data
    print(f"Loading self-play states from {data_path}...")
    with open(data_path, "r") as f:
        data = json.load(f)

    # Limit to 50k random examples for faster training.
    random.shuffle(data)
    data = data[:50000]
    texts: List[str] = [ex.get("state_text", "") for ex in data]
    rewards: List[float] = [float(ex.get("reward", 0.0)) for ex in data]
    print(f"Total examples: {len(texts)}")

    # Tokenizer and dataset
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    dataset = StatesDataset(texts, rewards, tokenizer, max_length=128)

    # 90/10 train/val split
    n_total = len(dataset)
    n_train = int(0.9 * n_total)
    n_val = n_total - n_train
    train_ds, val_ds = random_split(dataset, [n_train, n_val])
    print(f"Train examples: {n_train} | Val examples: {n_val}")
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    # Model, optimizer, loss
    model = DiplomacyRewardModel().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    best_val_loss = float("inf")
    train_losses: List[float] = []
    val_losses: List[float] = []

    for epoch in range(1, epochs + 1):
        # Train epoch
        model.train()
        running_train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} - train"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            targets = batch["reward"].to(device)

            optimizer.zero_grad()
            preds = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = criterion(preds, targets)
            loss.backward()
            optimizer.step()
            running_train_loss += loss.item() * input_ids.size(0)
        avg_train_loss = running_train_loss / n_train

        # Validation epoch
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch}/{epochs} - val"):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                targets = batch["reward"].to(device)
                preds = model(input_ids=input_ids, attention_mask=attention_mask)
                loss = criterion(preds, targets)
                running_val_loss += loss.item() * input_ids.size(0)
        avg_val_loss = running_val_loss / n_val

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        print(
            f"Epoch {epoch}/{epochs} | "
            f"Train Loss: {avg_train_loss:.6f} | "
            f"Val Loss: {avg_val_loss:.6f}"
        )

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), output_model_path)
            print(f"  -> New best val loss. Model saved to {output_model_path}")

    # Plot loss curves
    epochs_axis = np.arange(1, epochs + 1)
    plt.figure()
    plt.plot(epochs_axis, train_losses, label="Train Loss")
    plt.plot(epochs_axis, val_losses, label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("MSE Loss")
    plt.title("Reward Model Training Loss")
    plt.legend()
    plt.tight_layout()
    plt.savefig(loss_plot_path)
    plt.close()
    print(f"Loss curves saved to {loss_plot_path}")
| if __name__ == "__main__": | |
| train() | |