"""
MASH Stage 2: Style-injection Supervised Fine-Tuning
Trains StyleBART on a dual-task objective:
- L_trans: (ai_text, s_human_*) → human_text (style transfer)
- L_recon: (ai_text, s_ai_*) → ai_text (reconstruction)
- L_SFT = λ·L_recon + (1-λ)·L_trans
Multi-style: separate embeddings for PS and Supp genres.
"""
import os
import sys
import json
import time
import argparse
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
# Add parent dir to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model import StyleBART
from dataset import MASHSFTDataset, collate_fn
def train_epoch(model, dataloader, optimizer, scheduler, device, lambda_recon=0.3):
"""Train one epoch with dual-task loss."""
model.train()
total_loss = 0
total_trans_loss = 0
total_recon_loss = 0
n_batches = 0
for batch in dataloader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
style_keys = batch['style_keys']
tasks = batch['tasks']
# Forward pass
outputs = model(input_ids, attention_mask, labels, style_keys)
        # Per-sample loss: token-level cross-entropy averaged over non-padding tokens.
        # BART-style seq2seq models shift decoder inputs internally, so the logits already
        # align with the labels and no additional shift is needed here.
        loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
        vocab_size = outputs.logits.size(-1)
        token_loss = loss_fn(
            outputs.logits.reshape(-1, vocab_size), labels.reshape(-1)
        ).view(labels.size(0), -1)
        valid_tokens = (labels != -100).float()
        loss_per_sample = token_loss.sum(dim=1) / valid_tokens.sum(dim=1).clamp(min=1)
# Weighted combination
trans_mask = torch.tensor([1.0 if t == 'transfer' else 0.0 for t in tasks], device=device)
recon_mask = torch.tensor([1.0 if t == 'reconstruction' else 0.0 for t in tasks], device=device)
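        # clamp(min=1) guards against division by zero when a batch contains only one task type.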
n_trans = trans_mask.sum().clamp(min=1)
n_recon = recon_mask.sum().clamp(min=1)
trans_loss = (loss_per_sample * trans_mask).sum() / n_trans
recon_loss = (loss_per_sample * recon_mask).sum() / n_recon
# Combined loss: L_SFT = λ·L_recon + (1-λ)·L_trans
if recon_mask.sum() > 0:
loss = lambda_recon * recon_loss + (1 - lambda_recon) * trans_loss
else:
loss = trans_loss
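        # e.g. lambda_recon=0.3 with trans_loss=2.0, recon_loss=1.0 -> loss = 0.3*1.0 + 0.7*2.0 = 1.7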
# Backward
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()
total_loss += loss.item()
total_trans_loss += trans_loss.item()
total_recon_loss += recon_loss.item() if recon_mask.sum() > 0 else 0
n_batches += 1
return {
'loss': total_loss / n_batches,
'trans_loss': total_trans_loss / n_batches,
'recon_loss': total_recon_loss / n_batches,
}
@torch.no_grad()
def evaluate(model, dataloader, device):
"""Evaluate on validation set."""
model.eval()
total_loss = 0
n_batches = 0
for batch in dataloader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
style_keys = batch['style_keys']
outputs = model(input_ids, attention_mask, labels, style_keys)
total_loss += outputs.loss.item()
n_batches += 1
return {'val_loss': total_loss / n_batches}
@torch.no_grad()
def generate_samples(model, dataloader, device, n_samples=5):
"""Generate sample outputs for qualitative evaluation."""
model.eval()
samples = []
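    # Use only the first batch of the dataloader for quick qualitative samples.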
batch = next(iter(dataloader))
input_ids = batch['input_ids'][:n_samples].to(device)
attention_mask = batch['attention_mask'][:n_samples].to(device)
style_keys = batch['style_keys'][:n_samples]
    # Map AI style keys to their human counterparts so generation always targets human style
gen_style_keys = []
for sk in style_keys:
if sk.startswith('ai_'):
gen_style_keys.append(sk.replace('ai_', 'human_'))
else:
gen_style_keys.append(sk)
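    # Generate with beam search (num_beams=4), conditioned on the human style keys.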
generated = model.generate_text(
input_ids, attention_mask, gen_style_keys,
max_length=512, num_beams=4,
)
    for i in range(input_ids.size(0)):  # the batch may hold fewer than n_samples examples
input_text = model.tokenizer.decode(input_ids[i], skip_special_tokens=True)
output_text = model.tokenizer.decode(generated[i], skip_special_tokens=True)
samples.append({
'input': input_text[:200] + '...',
'output': output_text[:200] + '...',
'style': gen_style_keys[i],
})
return samples
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--train_data', default='data/train.jsonl')
parser.add_argument('--val_data', default='data/val.jsonl')
parser.add_argument('--output_dir', default='checkpoints/sft')
parser.add_argument('--model_name', default='facebook/bart-base')
parser.add_argument('--style_dim', type=int, default=64)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--epochs', type=int, default=5)
parser.add_argument('--lr', type=float, default=3e-5)
parser.add_argument('--lambda_recon', type=float, default=0.3)
parser.add_argument('--recon_ratio', type=float, default=0.3)
parser.add_argument('--max_input_len', type=int, default=512)
parser.add_argument('--max_target_len', type=int, default=512)
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--log_every', type=int, default=50)
args = parser.parse_args()
# Setup
torch.manual_seed(args.seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# Model
print(f"Loading model: {args.model_name}")
model = StyleBART(args.model_name, style_dim=args.style_dim)
model = model.to(device)
param_count = sum(p.numel() for p in model.parameters())
trainable_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {param_count:,}")
print(f"Trainable params: {trainable_count:,}")
# Datasets
print("Loading datasets...")
train_dataset = MASHSFTDataset(
args.train_data, model.tokenizer,
max_input_len=args.max_input_len,
max_target_len=args.max_target_len,
include_reconstruction=True,
reconstruction_ratio=args.recon_ratio,
)
val_dataset = MASHSFTDataset(
args.val_data, model.tokenizer,
max_input_len=args.max_input_len,
max_target_len=args.max_target_len,
include_reconstruction=True,
reconstruction_ratio=args.recon_ratio,
)
train_loader = DataLoader(
train_dataset, batch_size=args.batch_size,
shuffle=True, collate_fn=collate_fn, num_workers=2,
)
val_loader = DataLoader(
val_dataset, batch_size=args.batch_size,
shuffle=False, collate_fn=collate_fn, num_workers=2,
)
print(f"Train examples: {len(train_dataset)} ({len(train_loader)} batches)")
print(f"Val examples: {len(val_dataset)} ({len(val_loader)} batches)")
# Optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
total_steps = len(train_loader) * args.epochs
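    # The scheduler is stepped once per batch inside train_epoch, so T_max spans all optimizer steps.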
scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-6)
# Training loop
os.makedirs(args.output_dir, exist_ok=True)
best_val_loss = float('inf')
history = []
print(f"\n{'='*60}")
print(f"Starting Style-SFT Training")
print(f" Epochs: {args.epochs}")
print(f" Batch size: {args.batch_size}")
print(f" Learning rate: {args.lr}")
print(f" Lambda (recon weight): {args.lambda_recon}")
print(f" Reconstruction ratio: {args.recon_ratio}")
print(f" Style dim: {args.style_dim}")
print(f"{'='*60}\n")
for epoch in range(1, args.epochs + 1):
t0 = time.time()
# Train
train_metrics = train_epoch(
model, train_loader, optimizer, scheduler,
device, lambda_recon=args.lambda_recon,
)
# Evaluate
val_metrics = evaluate(model, val_loader, device)
elapsed = time.time() - t0
metrics = {
'epoch': epoch,
'train_loss': train_metrics['loss'],
'train_trans_loss': train_metrics['trans_loss'],
'train_recon_loss': train_metrics['recon_loss'],
'val_loss': val_metrics['val_loss'],
'lr': optimizer.param_groups[0]['lr'],
'time': elapsed,
}
history.append(metrics)
print(f"Epoch {epoch}/{args.epochs} ({elapsed:.0f}s)")
print(f" Train loss: {metrics['train_loss']:.4f} "
f"(trans: {metrics['train_trans_loss']:.4f}, "
f"recon: {metrics['train_recon_loss']:.4f})")
print(f" Val loss: {metrics['val_loss']:.4f}")
print(f" LR: {metrics['lr']:.2e}")
# Save best model
if val_metrics['val_loss'] < best_val_loss:
best_val_loss = val_metrics['val_loss']
model.save_pretrained(os.path.join(args.output_dir, 'best'))
print(f" ★ New best model saved (val_loss={best_val_loss:.4f})")
# Generate samples every epoch
samples = generate_samples(model, val_loader, device, n_samples=3)
print(" Sample outputs:")
for s in samples:
print(f" [{s['style']}] {s['output'][:100]}...")
print()
# Save final model
model.save_pretrained(os.path.join(args.output_dir, 'final'))
# Save training history
with open(os.path.join(args.output_dir, 'history.json'), 'w') as f:
json.dump(history, f, indent=2)
print(f"\nTraining complete!")
print(f"Best val loss: {best_val_loss:.4f}")
print(f"Models saved to: {args.output_dir}/")
if __name__ == '__main__':
main()