"""
MASH v6 DDP Training Script — launched via torchrun
Supports:
- Multi-GPU DDP (8x A100)
- SFT Stage 2 + DPO Stage 3
- Label smoothing, sampling generation, fp32 logprobs
- Auto checkpoint upload to Hub
- GPTZero evaluation
"""
import os
import sys
import json
import time
import gc
import math
import traceback
import requests
# Add src to path BEFORE any imports
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(SCRIPT_DIR, 'src'))
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
# ============ Config ============
HUB_REPO = 'catninja123/mash-v6-flan-t5-xl'
HF_TOKEN = os.environ.get('HF_TOKEN', '')
# Read the GPTZero key from the environment; never hardcode API keys in source.
GPTZERO_API_KEY = os.environ.get('GPTZERO_API_KEY', '')
# Paths — try /data for persistent storage, fallback to local
if os.path.isdir('/data'):
    SFT_DIR = '/data/checkpoints/sft_v6'
    DPO_DIR = '/data/checkpoints/dpo_v6'
else:
    SFT_DIR = os.path.join(SCRIPT_DIR, 'checkpoints', 'sft_v6')
    DPO_DIR = os.path.join(SCRIPT_DIR, 'checkpoints', 'dpo_v6')
TRAIN_DATA = os.path.join(SCRIPT_DIR, 'data', 'train_supp.jsonl')
VAL_DATA = os.path.join(SCRIPT_DIR, 'data', 'val_supp.jsonl')
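# Note: the JSONL schema is defined by dataset_v4 / dataset_dpo (not shown here).
# Judging from the fields consumed below, each record presumably holds a prompt
# plus a target response (SFT) or a chosen/rejected response pair (DPO); a
# hypothetical DPO record might look like:
#   {"prompt": "...", "chosen": "...", "rejected": "..."}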
# SFT Config
SFT_EPOCHS = 2
SFT_BATCH_SIZE = 4 # per GPU
SFT_GRAD_ACCUM = 2 # effective batch = 4 * 8 * 2 = 64
SFT_LR = 5e-5
SFT_LABEL_SMOOTHING = 0.1
SFT_WARMUP_RATIO = 0.06
SFT_MAX_INPUT_LEN = 768
SFT_MAX_TARGET_LEN = 512
# DPO Config
DPO_EPOCHS = 2
DPO_BATCH_SIZE = 2 # per GPU (2 models loaded)
DPO_GRAD_ACCUM = 4 # effective batch = 2 * 8 * 4 = 64
DPO_LR = 5e-6
DPO_BETA = 0.1
DPO_WARMUP_RATIO = 0.1
DPO_MAX_INPUT_LEN = 768
DPO_MAX_TARGET_LEN = 512
# Generation Config
GEN_MAX_TOKENS = 512
GEN_TEMPERATURE = 0.9
GEN_TOP_P = 0.92
GEN_TOP_K = 50
GEN_NO_REPEAT_NGRAM = 3
# ============ DDP Setup ============
def setup_ddp():
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank

def cleanup_ddp():
    dist.destroy_process_group()

def is_main():
    return dist.get_rank() == 0

def log(msg):
    """Only rank 0 prints."""
    if is_main():
        print(f"[RANK0] {msg}", flush=True)
# ============ LR Schedule ============
def get_lr(step, total_steps, warmup_steps, base_lr):
    """Linear warmup followed by cosine decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(warmup_steps, 1)
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))
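# Quick sanity check of the schedule (illustrative numbers: base_lr=5e-5,
# total_steps=100, warmup_steps=6):
#   get_lr(3, 100, 6, 5e-5)   -> 2.5e-5  (halfway through warmup)
#   get_lr(6, 100, 6, 5e-5)   -> 5.0e-5  (peak; cosine decay begins)
#   get_lr(53, 100, 6, 5e-5)  -> 2.5e-5  (midpoint of the cosine)
#   get_lr(100, 100, 6, 5e-5) -> 0.0     (fully decayed)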
# ============ DPO Loss ============
def compute_seq2seq_logprobs(model, input_ids, attention_mask, labels):
    """Compute per-sequence log probabilities with fp32 precision."""
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    # Cast to fp32 for numerical stability
    logits = outputs.logits.float()
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    # Gather log probs for actual tokens
    labels_for_gather = labels.clone()
    labels_for_gather[labels_for_gather == -100] = 0
    per_token_logps = log_probs.gather(-1, labels_for_gather.unsqueeze(-1)).squeeze(-1)
    # Mask padding
    mask = (labels != -100).float()
    # Length-normalized log probability
    seq_logps = (per_token_logps * mask).sum(-1) / mask.sum(-1).clamp(min=1)
    return seq_logps
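# Note: seq_logps is the *mean* per-token log-probability, i.e. length-normalized.
# The original DPO formulation uses the sum of token log-probs; normalizing keeps
# the implicit reward comparable across responses of different lengths, at the
# cost of deviating from the standard objective (a 10-token response averaging
# -0.5/token scores -0.5 here instead of -5.0).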
def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO loss with metrics."""
    chosen_rewards = beta * (policy_chosen - ref_chosen)
    rejected_rewards = beta * (policy_rejected - ref_rejected)
    logits = chosen_rewards - rejected_rewards
    loss = -torch.nn.functional.logsigmoid(logits).mean()
    with torch.no_grad():
        accuracy = (logits > 0).float().mean().item()
        chosen_reward = chosen_rewards.mean().item()
        rejected_reward = rejected_rewards.mean().item()
        margin = (chosen_rewards - rejected_rewards).mean().item()
    return loss, {
        'accuracy': accuracy,
        'chosen_reward': chosen_reward,
        'rejected_reward': rejected_reward,
        'reward_margin': margin,
    }
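# Worked example of the objective above (illustrative numbers, beta=0.1):
# for one pair, say the policy scores the chosen/rejected responses at
# -1.0 / -2.0 and the reference scores both at -1.5. Then:
#   chosen_rewards   = 0.1 * (-1.0 - (-1.5)) =  0.05
#   rejected_rewards = 0.1 * (-2.0 - (-1.5)) = -0.05
#   logits = 0.05 - (-0.05) = 0.10
#   loss   = -log(sigmoid(0.10)) ~= 0.644
# Widening the chosen-vs-rejected margin relative to the reference drives the
# loss below log(2) ~= 0.693 and pushes the accuracy metric above 50%.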
# ============ GPTZero Evaluation ============
def evaluate_with_gptzero(model, tokenizer, val_loader, device, n_samples=5):
    """Generate samples and evaluate with GPTZero API."""
    if not is_main():
        return
    if not GPTZERO_API_KEY:
        log("GPTZero: no API key set, skipping evaluation")
        return
    log("\n--- GPTZero Evaluation ---")
    # Use the unwrapped model for generation
    gen_model = model.module if hasattr(model, 'module') else model
    gen_model.eval()
    results = []
    for i, batch in enumerate(val_loader):
        if i >= n_samples:
            break
        s_ids = batch['input_ids'][:1].to(device)
        s_mask = batch['attention_mask'][:1].to(device)
        with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
            gen = gen_model.generate(
                input_ids=s_ids, attention_mask=s_mask,
                max_new_tokens=GEN_MAX_TOKENS,
                do_sample=True,
                temperature=GEN_TEMPERATURE,
                top_p=GEN_TOP_P,
                top_k=GEN_TOP_K,
                no_repeat_ngram_size=GEN_NO_REPEAT_NGRAM,
            )
        text = tokenizer.decode(gen[0], skip_special_tokens=True)
        word_count = len(text.split())
        if word_count < 50:
            log(f" Sample {i}: too short ({word_count}w), skipping GPTZero")
            continue
        # Call GPTZero
        try:
            resp = requests.post(
                "https://api.gptzero.me/v2/predict/text",
                headers={"x-api-key": GPTZERO_API_KEY, "Content-Type": "application/json"},
                json={"document": text},
                timeout=30,
            )
            if resp.status_code == 200:
                data = resp.json()
                doc = data.get('documents', [{}])[0]
                human_prob = doc.get('class_probabilities', {}).get('human', 0)
                verdict = doc.get('predicted_class', 'unknown')
                results.append(human_prob)
                log(f" Sample {i}: {word_count}w, GPTZero={verdict} (human={human_prob:.1%})")
                log(f" Text: {text[:200]}...")
            else:
                log(f" Sample {i}: GPTZero API error {resp.status_code}")
        except Exception as e:
            log(f" Sample {i}: GPTZero error: {e}")
        time.sleep(1)  # Rate limit
    if results:
        avg = sum(results) / len(results)
        log(f"\n GPTZero avg human_prob: {avg:.1%} ({len(results)} samples)")
    log("--- End GPTZero Evaluation ---\n")
# ============ SFT Training ============
def run_sft(rank, world_size, local_rank):
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    from dataset_v4 import InstructionDataset, collate_fn
    device = torch.device(f'cuda:{local_rank}')
    sft_best = os.path.join(SFT_DIR, 'best')
    os.makedirs(sft_best, exist_ok=True)
    log("=" * 60)
    log("STAGE 2: SFT Training (Supp only)")
    log(f" GPUs: {world_size}, batch/gpu: {SFT_BATCH_SIZE}, grad_accum: {SFT_GRAD_ACCUM}")
    log(f" Effective batch size: {SFT_BATCH_SIZE * world_size * SFT_GRAD_ACCUM}")
    log("=" * 60)
    # Load model
    log("Loading Flan-T5-XL...")
    tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-xl')
    model = T5ForConditionalGeneration.from_pretrained(
        'google/flan-t5-xl', torch_dtype=torch.bfloat16
    ).to(device)
    model.gradient_checkpointing_enable()
    # Wrap in DDP
    model = DDP(model, device_ids=[local_rank], output_device=local_rank,
                find_unused_parameters=False)
    if is_main() and torch.cuda.is_available():
        log(f"GPU: {torch.cuda.get_device_name()}")
        log(f"GPU memory per card: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")
        log(f"Model memory: {torch.cuda.memory_allocated(device)/1e9:.1f}GB")
    # Dataset
    train_ds = InstructionDataset(TRAIN_DATA, tokenizer, SFT_MAX_INPUT_LEN, SFT_MAX_TARGET_LEN)
    val_ds = InstructionDataset(VAL_DATA, tokenizer, SFT_MAX_INPUT_LEN, SFT_MAX_TARGET_LEN)
    train_sampler = DistributedSampler(train_ds, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(val_ds, num_replicas=world_size, rank=rank, shuffle=False)
    train_loader = DataLoader(train_ds, batch_size=SFT_BATCH_SIZE, sampler=train_sampler,
                              collate_fn=collate_fn, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=SFT_BATCH_SIZE, sampler=val_sampler,
                            collate_fn=collate_fn, num_workers=4, pin_memory=True)
    log(f"Dataset: {len(train_ds)} train, {len(val_ds)} val")
    log(f"Steps/epoch: {len(train_loader)}, total optimizer steps: {len(train_loader) * SFT_EPOCHS // SFT_GRAD_ACCUM}")
    # Optimizer & Loss
    optimizer = torch.optim.AdamW(model.parameters(), lr=SFT_LR, weight_decay=0.01)
    loss_fn = torch.nn.CrossEntropyLoss(
        ignore_index=-100,
        label_smoothing=SFT_LABEL_SMOOTHING,
    )
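    # The label-smoothed loss is computed manually from the logits rather than
    # taken from outputs.loss, because the model's built-in loss is plain
    # cross-entropy with no label smoothing.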
    total_steps = len(train_loader) * SFT_EPOCHS // SFT_GRAD_ACCUM
    warmup_steps = int(total_steps * SFT_WARMUP_RATIO)
    log(f"LR schedule: warmup={warmup_steps}, total={total_steps}")
    best_val_loss = float('inf')
    global_step = 0
    for epoch in range(1, SFT_EPOCHS + 1):
        model.train()
        train_sampler.set_epoch(epoch)
        epoch_loss = 0
        epoch_n = 0
        t0 = time.time()
        optimizer.zero_grad()
        lr = 0
        for step, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                logits = outputs.logits
                loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
            scaled_loss = loss / SFT_GRAD_ACCUM
            scaled_loss.backward()
            epoch_loss += loss.item()
            epoch_n += 1
            if (step + 1) % SFT_GRAD_ACCUM == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                global_step += 1
                lr = get_lr(global_step, total_steps, warmup_steps, SFT_LR)
                for pg in optimizer.param_groups:
                    pg['lr'] = lr
                optimizer.step()
                optimizer.zero_grad()
            if is_main() and (step + 1) % 50 == 0:
                avg = epoch_loss / epoch_n
                eta = (time.time() - t0) / (step + 1) * (len(train_loader) - step - 1)
                log(f" SFT E{epoch} step {step+1}/{len(train_loader)}: "
                    f"loss={avg:.4f}, lr={lr:.2e}, ETA={eta:.0f}s")
        # Validation
        model.eval()
        val_loss_total = 0
        val_n = 0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                    logits = outputs.logits
                    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
                val_loss_total += loss.item()
                val_n += 1
        # Aggregate val loss across ranks
        val_loss_tensor = torch.tensor([val_loss_total, val_n], device=device)
        dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.SUM)
        val_loss = val_loss_tensor[0].item() / max(val_loss_tensor[1].item(), 1)
        # Aggregate train loss
        train_loss_tensor = torch.tensor([epoch_loss, epoch_n], device=device, dtype=torch.float)
        dist.all_reduce(train_loss_tensor, op=dist.ReduceOp.SUM)
        train_loss = train_loss_tensor[0].item() / max(train_loss_tensor[1].item(), 1)
        elapsed = time.time() - t0
        log(f"\n SFT Epoch {epoch}/{SFT_EPOCHS}: train={train_loss:.4f}, val={val_loss:.4f} ({elapsed:.0f}s)")
        if is_main() and val_loss < best_val_loss:
            best_val_loss = val_loss
            model.module.save_pretrained(sft_best)
            tokenizer.save_pretrained(sft_best)
            log(f" >>> Best SFT model saved (val_loss={best_val_loss:.4f})")
        # Generate samples (rank 0 only)
        if is_main():
            try:
                sample_batch = next(iter(val_loader))
                s_ids = sample_batch['input_ids'][:2].to(device)
                s_mask = sample_batch['attention_mask'][:2].to(device)
                gen_model = model.module
                with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
                    gen = gen_model.generate(
                        input_ids=s_ids, attention_mask=s_mask,
                        max_new_tokens=256,
                        do_sample=True,
                        temperature=GEN_TEMPERATURE,
                        top_p=GEN_TOP_P,
                        top_k=GEN_TOP_K,
                        no_repeat_ngram_size=GEN_NO_REPEAT_NGRAM,
                    )
                log(" SFT Samples:")
                for i in range(min(2, len(gen))):
                    out = tokenizer.decode(gen[i], skip_special_tokens=True)
                    log(f" [{i}] {out[:300]}...")
            except Exception as e:
                log(f" Sample error: {e}")
        dist.barrier()
        torch.cuda.empty_cache()
        gc.collect()
log(f"\nSFT Complete. Best val_loss={best_val_loss:.4f}")
# GPTZero eval (rank 0)
if is_main():
try:
eval_loader = DataLoader(val_ds, batch_size=1, shuffle=False,
collate_fn=collate_fn, num_workers=0)
evaluate_with_gptzero(model, tokenizer, eval_loader, device, n_samples=5)
except Exception as e:
log(f"GPTZero eval error: {e}")
# Upload SFT to Hub (rank 0)
if is_main():
try:
if HF_TOKEN:
sft_hub = HUB_REPO + '-sft'
log(f"Uploading SFT checkpoint to Hub: {sft_hub}")
model.module.push_to_hub(sft_hub, token=HF_TOKEN, private=True)
tokenizer.push_to_hub(sft_hub, token=HF_TOKEN, private=True)
log(" SFT upload complete!")
except Exception as e:
log(f" SFT Hub upload error: {e}")
# Cleanup
del model, optimizer
torch.cuda.empty_cache()
gc.collect()
dist.barrier()
return sft_best, tokenizer
# ============ DPO Training ============
def run_dpo(sft_ckpt, rank, world_size, local_rank):
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    from dataset_dpo import DPODataset, dpo_collate_fn
    from dataset_v4 import InstructionDataset, collate_fn as sft_collate_fn
    device = torch.device(f'cuda:{local_rank}')
    dpo_best = os.path.join(DPO_DIR, 'best')
    os.makedirs(dpo_best, exist_ok=True)
    log("\n" + "=" * 60)
    log("STAGE 3: DPO Training (Supp only)")
    log(f" GPUs: {world_size}, batch/gpu: {DPO_BATCH_SIZE}, grad_accum: {DPO_GRAD_ACCUM}")
    log(f" Effective batch size: {DPO_BATCH_SIZE * world_size * DPO_GRAD_ACCUM}")
    log("=" * 60)
    tokenizer = T5Tokenizer.from_pretrained(sft_ckpt)
    # Policy model (trainable)
    log("Loading policy model...")
    policy = T5ForConditionalGeneration.from_pretrained(
        sft_ckpt, torch_dtype=torch.bfloat16
    ).to(device)
    policy.gradient_checkpointing_enable()
    policy = DDP(policy, device_ids=[local_rank], output_device=local_rank,
                 find_unused_parameters=False)
    # Reference model (frozen, no DDP needed)
    log("Loading reference model...")
    ref = T5ForConditionalGeneration.from_pretrained(
        sft_ckpt, torch_dtype=torch.bfloat16
    ).to(device)
    ref.eval()
    for p in ref.parameters():
        p.requires_grad = False
    if is_main() and torch.cuda.is_available():
        log(f"GPU memory (2 models): {torch.cuda.memory_allocated(device)/1e9:.1f}GB")
    # DPO Dataset
    train_ds = DPODataset(TRAIN_DATA, tokenizer, DPO_MAX_INPUT_LEN, DPO_MAX_TARGET_LEN)
    val_ds = DPODataset(VAL_DATA, tokenizer, DPO_MAX_INPUT_LEN, DPO_MAX_TARGET_LEN)
    train_sampler = DistributedSampler(train_ds, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(val_ds, num_replicas=world_size, rank=rank, shuffle=False)
    train_loader = DataLoader(train_ds, batch_size=DPO_BATCH_SIZE, sampler=train_sampler,
                              collate_fn=dpo_collate_fn, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=DPO_BATCH_SIZE, sampler=val_sampler,
                            collate_fn=dpo_collate_fn, num_workers=4, pin_memory=True)
    # SFT val loader for generation eval
    sft_val_ds = InstructionDataset(VAL_DATA, tokenizer, DPO_MAX_INPUT_LEN, DPO_MAX_TARGET_LEN)
    sft_val_loader = DataLoader(sft_val_ds, batch_size=1, shuffle=False,
                                collate_fn=sft_collate_fn, num_workers=0)
    log(f"DPO Dataset: {len(train_ds)} train, {len(val_ds)} val")
    # Optimizer
    optimizer = torch.optim.AdamW(policy.parameters(), lr=DPO_LR, weight_decay=0.01)
    total_steps = len(train_loader) * DPO_EPOCHS // DPO_GRAD_ACCUM
    warmup_steps = int(total_steps * DPO_WARMUP_RATIO)
    log(f"Config: beta={DPO_BETA}, lr={DPO_LR}")
    log(f"Steps: {total_steps} total, {warmup_steps} warmup")
    best_val_loss = float('inf')
    global_step = 0
    for epoch in range(1, DPO_EPOCHS + 1):
        policy.train()
        train_sampler.set_epoch(epoch)
        epoch_loss = 0
        epoch_metrics = {'chosen_reward': 0, 'rejected_reward': 0, 'reward_margin': 0, 'accuracy': 0}
        epoch_n = 0
        t0 = time.time()
        optimizer.zero_grad()
        lr = 0
        for step, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            chosen_labels = batch['chosen_labels'].to(device)
            rejected_labels = batch['rejected_labels'].to(device)
            # Policy logprobs (fp32 inside)
            policy_chosen_logps = compute_seq2seq_logprobs(
                policy, input_ids, attention_mask, chosen_labels)
            policy_rejected_logps = compute_seq2seq_logprobs(
                policy, input_ids, attention_mask, rejected_labels)
            # Reference logprobs (no grad)
            with torch.no_grad():
                ref_chosen_logps = compute_seq2seq_logprobs(
                    ref, input_ids, attention_mask, chosen_labels)
                ref_rejected_logps = compute_seq2seq_logprobs(
                    ref, input_ids, attention_mask, rejected_labels)
            loss, metrics = dpo_loss(
                policy_chosen_logps, policy_rejected_logps,
                ref_chosen_logps, ref_rejected_logps,
                beta=DPO_BETA,
            )
            scaled_loss = loss / DPO_GRAD_ACCUM
            scaled_loss.backward()
            epoch_loss += loss.item()
            for k, v in metrics.items():
                epoch_metrics[k] += v
            epoch_n += 1
            if (step + 1) % DPO_GRAD_ACCUM == 0:
                torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
                global_step += 1
                lr = get_lr(global_step, total_steps, warmup_steps, DPO_LR)
                for pg in optimizer.param_groups:
                    pg['lr'] = lr
                optimizer.step()
                optimizer.zero_grad()
                if is_main() and global_step % 10 == 0:
                    avg_loss = epoch_loss / epoch_n
                    avg_acc = epoch_metrics['accuracy'] / epoch_n
                    avg_margin = epoch_metrics['reward_margin'] / epoch_n
                    eta = (time.time() - t0) / (step + 1) * (len(train_loader) - step - 1)
                    log(f" DPO E{epoch} step {global_step}/{total_steps}: "
                        f"loss={avg_loss:.4f}, acc={avg_acc:.2%}, "
                        f"margin={avg_margin:.3f}, lr={lr:.2e}, ETA={eta:.0f}s")
        # Validation
        policy.eval()
        val_loss_total = 0
        val_metrics_total = {'accuracy': 0}
        val_n = 0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                chosen_labels = batch['chosen_labels'].to(device)
                rejected_labels = batch['rejected_labels'].to(device)
                p_c = compute_seq2seq_logprobs(policy, input_ids, attention_mask, chosen_labels)
                p_r = compute_seq2seq_logprobs(policy, input_ids, attention_mask, rejected_labels)
                r_c = compute_seq2seq_logprobs(ref, input_ids, attention_mask, chosen_labels)
                r_r = compute_seq2seq_logprobs(ref, input_ids, attention_mask, rejected_labels)
                loss, metrics = dpo_loss(p_c, p_r, r_c, r_r, beta=DPO_BETA)
                val_loss_total += loss.item()
                val_metrics_total['accuracy'] += metrics['accuracy']
                val_n += 1
        # Aggregate across ranks
        val_tensor = torch.tensor([val_loss_total, val_n, val_metrics_total['accuracy']], device=device)
        dist.all_reduce(val_tensor, op=dist.ReduceOp.SUM)
        val_loss = val_tensor[0].item() / max(val_tensor[1].item(), 1)
        val_acc = val_tensor[2].item() / max(val_tensor[1].item(), 1)
        train_tensor = torch.tensor([epoch_loss, epoch_n], device=device, dtype=torch.float)
        dist.all_reduce(train_tensor, op=dist.ReduceOp.SUM)
        train_loss = train_tensor[0].item() / max(train_tensor[1].item(), 1)
        train_acc = epoch_metrics['accuracy'] / max(epoch_n, 1)
        elapsed = time.time() - t0
        log(f"\n DPO Epoch {epoch}/{DPO_EPOCHS} ({elapsed:.0f}s)")
        log(f" Train: loss={train_loss:.4f}, acc={train_acc:.2%}")
        log(f" Val: loss={val_loss:.4f}, acc={val_acc:.2%}")
        if is_main() and torch.cuda.is_available():
            log(f" GPU: {torch.cuda.max_memory_allocated(device)/1e9:.1f}GB peak")
        if is_main() and val_loss < best_val_loss:
            best_val_loss = val_loss
            policy.module.save_pretrained(dpo_best)
            tokenizer.save_pretrained(dpo_best)
            log(f" >>> Best DPO model saved (val_loss={val_loss:.4f})")
        # Generate samples (rank 0)
        if is_main():
            try:
                sample_batch = next(iter(sft_val_loader))
                s_ids = sample_batch['input_ids'][:1].to(device)
                s_mask = sample_batch['attention_mask'][:1].to(device)
                with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
                    gen = policy.module.generate(
                        input_ids=s_ids, attention_mask=s_mask,
                        max_new_tokens=GEN_MAX_TOKENS,
                        do_sample=True,
                        temperature=GEN_TEMPERATURE,
                        top_p=GEN_TOP_P,
                        top_k=GEN_TOP_K,
                        no_repeat_ngram_size=GEN_NO_REPEAT_NGRAM,
                    )
                out = tokenizer.decode(gen[0], skip_special_tokens=True)
                log(f" DPO Sample: {out[:300]}...")
            except Exception as e:
                log(f" Sample error: {e}")
        dist.barrier()
        torch.cuda.empty_cache()
        gc.collect()
    # Save final
    if is_main():
        final_dir = os.path.join(DPO_DIR, 'final')
        os.makedirs(final_dir, exist_ok=True)
        policy.module.save_pretrained(final_dir)
        tokenizer.save_pretrained(final_dir)
    log(f"\nDPO Complete. Best val_loss={best_val_loss:.4f}")
    # GPTZero eval (rank 0)
    if is_main():
        try:
            evaluate_with_gptzero(policy, tokenizer, sft_val_loader, device, n_samples=10)
        except Exception as e:
            log(f"GPTZero eval error: {e}")
    # Upload to Hub (rank 0)
    if is_main():
        try:
            if HF_TOKEN:
                log(f"\nUploading DPO model to Hub: {HUB_REPO}")
                policy.module.push_to_hub(HUB_REPO, token=HF_TOKEN, private=True)
                tokenizer.push_to_hub(HUB_REPO, token=HF_TOKEN, private=True)
                log(" DPO upload complete!")
        except Exception as e:
            log(f" Hub upload error: {e}")
    dist.barrier()
    return dpo_best
# ============ Main ============
def main():
    try:
        rank, world_size, local_rank = setup_ddp()
    except Exception as e:
        print(f"[ERROR] DDP setup failed: {e}", flush=True)
        traceback.print_exc()
        sys.exit(1)
    if is_main():
        print("[RANK0] Starting MASH v6 DDP training", flush=True)
        print(f"[RANK0] World size: {world_size}", flush=True)
        print(f"[RANK0] SCRIPT_DIR: {SCRIPT_DIR}", flush=True)
        print(f"[RANK0] SFT_DIR: {SFT_DIR}", flush=True)
        print(f"[RANK0] TRAIN_DATA: {TRAIN_DATA} (exists={os.path.exists(TRAIN_DATA)})", flush=True)
        print(f"[RANK0] VAL_DATA: {VAL_DATA} (exists={os.path.exists(VAL_DATA)})", flush=True)
        # Enumerate local GPUs (world_size can exceed the local device count
        # on multi-node runs, so iterate over visible devices instead)
        for i in range(torch.cuda.device_count()):
            print(f"[RANK0] GPU {i}: {torch.cuda.get_device_name(i)}", flush=True)
        print(f"[RANK0] Python: {sys.version}", flush=True)
        print(f"[RANK0] PyTorch: {torch.__version__}", flush=True)
        print(f"[RANK0] sys.path: {sys.path[:5]}", flush=True)
    # Check for existing SFT checkpoint
    sft_ckpt = os.path.join(SFT_DIR, 'best')
    need_sft = not os.path.exists(os.path.join(sft_ckpt, 'config.json'))
    if need_sft and is_main():
        # Try Hub
        sft_hub = HUB_REPO + '-sft'
        try:
            from transformers import T5ForConditionalGeneration, T5Tokenizer
            print(f"[RANK0] Trying Hub: {sft_hub}", flush=True)
            os.makedirs(sft_ckpt, exist_ok=True)
            m = T5ForConditionalGeneration.from_pretrained(sft_hub, token=HF_TOKEN)
            t = T5Tokenizer.from_pretrained(sft_hub, token=HF_TOKEN)
            m.save_pretrained(sft_ckpt)
            t.save_pretrained(sft_ckpt)
            del m, t
            need_sft = False
            print("[RANK0] Downloaded SFT from Hub", flush=True)
        except Exception as e:
            print(f"[RANK0] No Hub checkpoint: {e}", flush=True)
    # Broadcast need_sft to all ranks (rank 0 decides after checking disk/Hub)
    need_sft_tensor = torch.tensor([1 if need_sft else 0], device=f'cuda:{local_rank}')
    dist.broadcast(need_sft_tensor, src=0)
    need_sft = need_sft_tensor.item() == 1
    if need_sft:
        sft_ckpt, _ = run_sft(rank, world_size, local_rank)
    else:
        if is_main():
            print("[RANK0] Using existing SFT checkpoint", flush=True)
        dist.barrier()
    # Run DPO
    run_dpo(sft_ckpt, rank, world_size, local_rank)
    if is_main():
        print("[RANK0] " + "=" * 60, flush=True)
        print("[RANK0] ALL TRAINING COMPLETE!", flush=True)
        print("[RANK0] " + "=" * 60, flush=True)
    cleanup_ddp()
if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        print(f"[FATAL] Unhandled exception: {e}", flush=True)
        traceback.print_exc()
        sys.exit(1)