| """ |
| MASH Stage 2: Style-injection Supervised Fine-Tuning |
| |
| Trains StyleBART on dual-task objective: |
| - L_trans: (ai_text, s_human_*) → human_text (style transfer) |
| - L_recon: (ai_text, s_ai_*) → ai_text (reconstruction) |
| - L_SFT = λ·L_recon + (1-λ)·L_trans |
| |
| Multi-style: separate embeddings for PS and Supp genres. |
| """ |

import os
import sys
import json
import time
import argparse

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Make the sibling model/dataset modules importable when running this file directly.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model import StyleBART
from dataset import MASHSFTDataset, collate_fn


def train_epoch(model, dataloader, optimizer, scheduler, device, lambda_recon=0.3):
    """Train one epoch with the dual-task loss λ·L_recon + (1-λ)·L_trans."""
    model.train()
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
    total_loss = 0.0
    total_trans_loss = 0.0
    total_recon_loss = 0.0
    n_batches = 0

    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        style_keys = batch['style_keys']
        tasks = batch['tasks']

        outputs = model(input_ids, attention_mask, labels, style_keys)

        # Per-sample loss. Assuming StyleBART follows BartForConditionalGeneration,
        # the decoder inputs are the labels shifted right internally, so logits at
        # position t already predict labels[t]; shifting logits/labels again here
        # (causal-LM style) would misalign the targets by one token.
        loss_per_sample = torch.zeros(len(tasks), device=device)
        for i in range(len(tasks)):
            sample_logits = outputs.logits[i]
            sample_labels = labels[i]
            loss_per_sample[i] = loss_fn(
                sample_logits.view(-1, sample_logits.size(-1)),
                sample_labels.view(-1),
            )

        # Average each task's losses separately; clamp avoids division by zero
        # when a batch happens to contain only one task.
        trans_mask = torch.tensor([1.0 if t == 'transfer' else 0.0 for t in tasks], device=device)
        recon_mask = torch.tensor([1.0 if t == 'reconstruction' else 0.0 for t in tasks], device=device)

        n_trans = trans_mask.sum()
        n_recon = recon_mask.sum()

        trans_loss = (loss_per_sample * trans_mask).sum() / n_trans.clamp(min=1)
        recon_loss = (loss_per_sample * recon_mask).sum() / n_recon.clamp(min=1)

        # Mix the two objectives; fall back to whichever task is present
        # when the batch contains only one of them.
        if n_trans > 0 and n_recon > 0:
            loss = lambda_recon * recon_loss + (1 - lambda_recon) * trans_loss
        elif n_recon > 0:
            loss = recon_loss
        else:
            loss = trans_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        total_trans_loss += trans_loss.item() if n_trans > 0 else 0.0
        total_recon_loss += recon_loss.item() if n_recon > 0 else 0.0
        n_batches += 1

    return {
        'loss': total_loss / n_batches,
        'trans_loss': total_trans_loss / n_batches,
        'recon_loss': total_recon_loss / n_batches,
    }
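

# Worked example of the loss mixing with lambda_recon = 0.3 (illustrative values):
#   trans_loss = 2.0, recon_loss = 1.0
#   L_SFT = 0.3 * 1.0 + 0.7 * 2.0 = 1.7
# The transfer objective dominates and reconstruction contributes a smaller
# auxiliary term, matching the λ weighting described in the module docstring.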


@torch.no_grad()
def evaluate(model, dataloader, device):
    """Evaluate on the validation set.

    Uses the model's built-in (token-averaged) loss, so validation mixes both
    tasks uniformly rather than with the λ weighting used during training.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0

    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        style_keys = batch['style_keys']

        outputs = model(input_ids, attention_mask, labels, style_keys)
        total_loss += outputs.loss.item()
        n_batches += 1

    return {'val_loss': total_loss / n_batches}


@torch.no_grad()
def generate_samples(model, dataloader, device, n_samples=5):
    """Generate sample outputs for qualitative evaluation."""
    model.eval()
    samples = []

    batch = next(iter(dataloader))
    # Guard against batches smaller than the requested number of samples.
    n_samples = min(n_samples, batch['input_ids'].size(0))
    input_ids = batch['input_ids'][:n_samples].to(device)
    attention_mask = batch['attention_mask'][:n_samples].to(device)
    style_keys = batch['style_keys'][:n_samples]

    # Flip AI-style keys to the matching human style so generation exercises the
    # transfer direction; swap only the 'ai_' prefix ('ai_*' → 'human_*') rather
    # than replacing every occurrence in the key.
    gen_style_keys = []
    for sk in style_keys:
        if sk.startswith('ai_'):
            gen_style_keys.append('human_' + sk[len('ai_'):])
        else:
            gen_style_keys.append(sk)

    generated = model.generate_text(
        input_ids, attention_mask, gen_style_keys,
        max_length=512, num_beams=4,
    )

    for i in range(n_samples):
        input_text = model.tokenizer.decode(input_ids[i], skip_special_tokens=True)
        output_text = model.tokenizer.decode(generated[i], skip_special_tokens=True)
        samples.append({
            'input': input_text[:200] + '...',
            'output': output_text[:200] + '...',
            'style': gen_style_keys[i],
        })

    return samples


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_data', default='data/train.jsonl')
    parser.add_argument('--val_data', default='data/val.jsonl')
    parser.add_argument('--output_dir', default='checkpoints/sft')
    parser.add_argument('--model_name', default='facebook/bart-base')
    parser.add_argument('--style_dim', type=int, default=64)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--lr', type=float, default=3e-5)
    parser.add_argument('--lambda_recon', type=float, default=0.3)
    parser.add_argument('--recon_ratio', type=float, default=0.3)
    parser.add_argument('--max_input_len', type=int, default=512)
    parser.add_argument('--max_target_len', type=int, default=512)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--log_every', type=int, default=50)  # declared but currently unused
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    print(f"Loading model: {args.model_name}")
    model = StyleBART(args.model_name, style_dim=args.style_dim)
    model = model.to(device)

    param_count = sum(p.numel() for p in model.parameters())
    trainable_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params: {param_count:,}")
    print(f"Trainable params: {trainable_count:,}")

    print("Loading datasets...")
    train_dataset = MASHSFTDataset(
        args.train_data, model.tokenizer,
        max_input_len=args.max_input_len,
        max_target_len=args.max_target_len,
        include_reconstruction=True,
        reconstruction_ratio=args.recon_ratio,
    )
    val_dataset = MASHSFTDataset(
        args.val_data, model.tokenizer,
        max_input_len=args.max_input_len,
        max_target_len=args.max_target_len,
        include_reconstruction=True,
        reconstruction_ratio=args.recon_ratio,
    )

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=True, collate_fn=collate_fn, num_workers=2,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size,
        shuffle=False, collate_fn=collate_fn, num_workers=2,
    )

    print(f"Train examples: {len(train_dataset)} ({len(train_loader)} batches)")
    print(f"Val examples: {len(val_dataset)} ({len(val_loader)} batches)")

    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
    total_steps = len(train_loader) * args.epochs
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-6)
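    # The scheduler is stepped once per batch inside train_epoch, so T_max is
    # the total number of optimizer steps and the LR anneals from args.lr down
    # to eta_min over the full run.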

    # Checkpointing setup
    os.makedirs(args.output_dir, exist_ok=True)
    best_val_loss = float('inf')
    history = []

    print(f"\n{'='*60}")
    print("Starting Style-SFT Training")
    print(f"  Epochs: {args.epochs}")
    print(f"  Batch size: {args.batch_size}")
    print(f"  Learning rate: {args.lr}")
    print(f"  Lambda (recon weight): {args.lambda_recon}")
    print(f"  Reconstruction ratio: {args.recon_ratio}")
    print(f"  Style dim: {args.style_dim}")
    print(f"{'='*60}\n")

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()

        train_metrics = train_epoch(
            model, train_loader, optimizer, scheduler,
            device, lambda_recon=args.lambda_recon,
        )

        val_metrics = evaluate(model, val_loader, device)

        elapsed = time.time() - t0

        metrics = {
            'epoch': epoch,
            'train_loss': train_metrics['loss'],
            'train_trans_loss': train_metrics['trans_loss'],
            'train_recon_loss': train_metrics['recon_loss'],
            'val_loss': val_metrics['val_loss'],
            'lr': optimizer.param_groups[0]['lr'],
            'time': elapsed,
        }
        history.append(metrics)

        print(f"Epoch {epoch}/{args.epochs} ({elapsed:.0f}s)")
        print(f"  Train loss: {metrics['train_loss']:.4f} "
              f"(trans: {metrics['train_trans_loss']:.4f}, "
              f"recon: {metrics['train_recon_loss']:.4f})")
        print(f"  Val loss: {metrics['val_loss']:.4f}")
        print(f"  LR: {metrics['lr']:.2e}")

        # Keep the checkpoint with the lowest validation loss.
        if val_metrics['val_loss'] < best_val_loss:
            best_val_loss = val_metrics['val_loss']
            model.save_pretrained(os.path.join(args.output_dir, 'best'))
            print(f"  ★ New best model saved (val_loss={best_val_loss:.4f})")

        # Qualitative check: print a few transferred samples each epoch.
        samples = generate_samples(model, val_loader, device, n_samples=3)
        print("  Sample outputs:")
        for s in samples:
            print(f"    [{s['style']}] {s['output'][:100]}...")
        print()

    # Save the final model and the training history.
    model.save_pretrained(os.path.join(args.output_dir, 'final'))

    with open(os.path.join(args.output_dir, 'history.json'), 'w') as f:
        json.dump(history, f, indent=2)

    print("\nTraining complete!")
    print(f"Best val loss: {best_val_loss:.4f}")
    print(f"Models saved to: {args.output_dir}/")


if __name__ == '__main__':
    main()