"""Train a DistilBERT binary classifier on bluff labels.

Default data: training/data/poker/bluff_labels.json
Model: distilbert-base-uncased + linear 768→2
80/20 train/val stratified, 3 epochs, lr 2e-5, batch 32 (overridable via
--epochs / --lr / --batch-size; --seed controls the split and torch RNG).
Saves: training/checkpoints/bluff_classifier.pt, bluff_classifier_tokenizer/

Use --data to point at negotiation_bluff_labels.json and --output to choose an
alternative checkpoint path. The tokenizer is saved next to the checkpoint as
``<checkpoint stem>_tokenizer/``; for the default output this is
training/checkpoints/bluff_classifier_tokenizer/.
"""
import argparse
import json
import os
import sys
from pathlib import Path

import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel

SCRIPT_DIR = Path(__file__).resolve().parent
DEFAULT_DATA_PATH = SCRIPT_DIR / "data" / "poker" / "bluff_labels.json"
DEFAULT_CHECKPOINT_DIR = SCRIPT_DIR / "checkpoints"
DEFAULT_MODEL_PT = DEFAULT_CHECKPOINT_DIR / "bluff_classifier.pt"
# Tokenizer directory for the default checkpoint (kept for importers).
# main() derives the directory from --output so custom checkpoints keep
# their tokenizer alongside them.
TOKENIZER_DIR = DEFAULT_CHECKPOINT_DIR / "bluff_classifier_tokenizer"
BASE_MODEL = "distilbert-base-uncased"
MAX_LENGTH = 128
EPOCHS = 3
LR = 2e-5
BATCH_SIZE = 32
SEED = 42
VAL_FRACTION = 0.2
TARGET_ACCURACY = 0.65


class BluffClassifier(nn.Module):
    """DistilBERT + linear head 768 → 2 (binary: not_bluff, bluff)."""

    def __init__(self, base_model: str = BASE_MODEL):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model)
        hidden_size = self.encoder.config.hidden_size
        self.head = nn.Linear(hidden_size, 2)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return (batch, 2) logits from the first-token hidden state.

        Extra keyword arguments (e.g. ``labels``) are accepted and ignored so
        a whole batch dict can be splatted in.
        """
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = out.last_hidden_state[:, 0, :]
        return self.head(pooled)


class BluffDataset(Dataset):
    """Tokenizes one text per item, padded/truncated to a fixed length."""

    def __init__(self, texts, labels, tokenizer, max_length: int = MAX_LENGTH):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        enc = self.tokenizer(
            self.texts[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dim of 1; drop it so the
        # DataLoader can stack items itself.
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }


def _positive_int(text):
    """argparse type: an integer >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {text!r}"
        )
    return value


def _parse_args(argv=None):
    """Parse command-line options (``argv=None`` reads ``sys.argv``)."""
    parser = argparse.ArgumentParser(description="Train bluff classifier.")
    parser.add_argument(
        "--data",
        type=str,
        default=str(DEFAULT_DATA_PATH),
        help=(
            "Path to JSON bluff label file "
            "(default: training/data/poker/bluff_labels.json)"
        ),
    )
    parser.add_argument(
        "--output",
        type=str,
        default=str(DEFAULT_MODEL_PT),
        help=(
            "Path to save model checkpoint "
            "(default: training/checkpoints/bluff_classifier.pt)"
        ),
    )
    parser.add_argument(
        "--epochs", type=_positive_int, default=EPOCHS,
        help=f"Training epochs (default: {EPOCHS})",
    )
    parser.add_argument(
        "--lr", type=float, default=LR,
        help=f"AdamW learning rate (default: {LR})",
    )
    parser.add_argument(
        "--batch-size", type=_positive_int, default=BATCH_SIZE,
        help=f"Batch size (default: {BATCH_SIZE})",
    )
    parser.add_argument(
        "--seed", type=int, default=SEED,
        help=f"Seed for the train/val split and torch RNG (default: {SEED})",
    )
    return parser.parse_args(argv)


def _tokenizer_dir_for(model_pt):
    """Directory for the tokenizer saved alongside checkpoint *model_pt*."""
    return model_pt.with_name(f"{model_pt.stem}_tokenizer")


def _load_examples(path):
    """Read a JSON list of {"text", "is_bluff"} records into (texts, labels)."""
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with path.open(encoding="utf-8") as f:
        data = json.load(f)
    texts = [x["text"] for x in data]
    labels = [1 if x["is_bluff"] else 0 for x in data]
    return texts, labels


def _binary_f1(preds, targets, positive=1):
    """F1 score of the *positive* class; 0.0 when undefined."""
    tp = fp = fn = 0
    for p, t in zip(preds, targets):
        if p == positive and t == positive:
            tp += 1
        elif p == positive:
            fp += 1
        elif t == positive:
            fn += 1
    # Equivalent to 2*P*R / (P + R) without intermediate divisions.
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def _evaluate(model, loader, device):
    """Return (accuracy, bluff-class F1) of *model* over *loader*."""
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for batch in loader:
            logits = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            preds.extend(logits.argmax(dim=1).cpu().tolist())
            targets.extend(batch["labels"].tolist())
    correct = sum(1 for p, t in zip(preds, targets) if p == t)
    acc = correct / len(targets) if targets else 0.0
    return acc, _binary_f1(preds, targets)


def main(argv=None):
    """Train the classifier, report per-epoch val metrics, save artifacts.

    Exits with status 1 (message on stderr) if the data file is missing or
    empty.
    """
    args = _parse_args(argv)
    data_path = Path(args.data)
    model_pt = Path(args.output)
    checkpoint_dir = model_pt.parent
    tokenizer_dir = _tokenizer_dir_for(model_pt)

    if not data_path.exists():
        # Non-zero exit so scripts/CI notice the failure.
        sys.exit(f"ERROR: {data_path} not found.")

    texts, labels = _load_examples(data_path)
    if not texts:
        sys.exit(f"ERROR: {data_path} contains no examples.")

    X_train, X_val, y_train, y_val = train_test_split(
        texts, labels, test_size=VAL_FRACTION, stratify=labels,
        random_state=args.seed,
    )
    # Seed head init, dropout and DataLoader shuffling (best effort; some
    # CUDA kernels remain nondeterministic).
    torch.manual_seed(args.seed)

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    train_ds = BluffDataset(X_train, y_train, tokenizer)
    val_ds = BluffDataset(X_val, y_val, tokenizer)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BluffClassifier().to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    os.makedirs(checkpoint_dir, exist_ok=True)

    acc = 0.0
    for epoch in range(args.epochs):
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            logits = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            loss = criterion(logits, batch["labels"].to(device))
            loss.backward()
            opt.step()

        acc, f1 = _evaluate(model, val_loader, device)
        print(
            f"Epoch {epoch + 1}/{args.epochs} "
            f"Val accuracy: {acc:.4f} Val F1: {f1:.4f}"
        )

    if acc < TARGET_ACCURACY:
        print(
            f"WARNING: Val accuracy {acc:.4f} < {TARGET_ACCURACY} (target). "
            "Consider more data or epochs."
        )

    torch.save(model.state_dict(), model_pt)
    tokenizer.save_pretrained(tokenizer_dir)
    print(f"Saved model to {model_pt}, tokenizer to {tokenizer_dir}")


if __name__ == "__main__":
    main()