| """ |
| MASH v6 DDP Training Script — launched via torchrun |
| |
| Supports: |
| - Multi-GPU DDP (8x A100) |
| - SFT Stage 2 + DPO Stage 3 |
| - Label smoothing, sampling generation, fp32 logprobs |
| - Auto checkpoint upload to Hub |
| - GPTZero evaluation |
| """ |
|
|
| import os |
| import sys |
| import time |
| import gc |
| import math |
| import traceback |
| import requests |
|
|
| |
| SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) |
| sys.path.insert(0, os.path.join(SCRIPT_DIR, 'src')) |
|
|
| import torch |
| import torch.distributed as dist |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.utils.data import DataLoader |
| from torch.utils.data.distributed import DistributedSampler |
|
|
| |
| HUB_REPO = 'catninja123/mash-v6-flan-t5-xl' |
| HF_TOKEN = os.environ.get('HF_TOKEN', '') |
GPTZERO_API_KEY = os.environ.get('GPTZERO_API_KEY', '')  # read from the environment; never hardcode API keys
|
|
| |
| if os.path.isdir('/data'): |
| SFT_DIR = '/data/checkpoints/sft_v6' |
| DPO_DIR = '/data/checkpoints/dpo_v6' |
| else: |
| SFT_DIR = os.path.join(SCRIPT_DIR, 'checkpoints', 'sft_v6') |
| DPO_DIR = os.path.join(SCRIPT_DIR, 'checkpoints', 'dpo_v6') |
| TRAIN_DATA = os.path.join(SCRIPT_DIR, 'data', 'train_supp.jsonl') |
| VAL_DATA = os.path.join(SCRIPT_DIR, 'data', 'val_supp.jsonl') |
|
|
| |
| SFT_EPOCHS = 2 |
| SFT_BATCH_SIZE = 4 |
| SFT_GRAD_ACCUM = 2 |
| SFT_LR = 5e-5 |
| SFT_LABEL_SMOOTHING = 0.1 |
| SFT_WARMUP_RATIO = 0.06 |
| SFT_MAX_INPUT_LEN = 768 |
| SFT_MAX_TARGET_LEN = 512 |
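# With 8 GPUs: effective SFT batch = 4 (per GPU) x 8 (GPUs) x 2 (grad accum) = 64 sequences.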
|
|
| |
| DPO_EPOCHS = 2 |
| DPO_BATCH_SIZE = 2 |
| DPO_GRAD_ACCUM = 4 |
| DPO_LR = 5e-6 |
| DPO_BETA = 0.1 |
| DPO_WARMUP_RATIO = 0.1 |
| DPO_MAX_INPUT_LEN = 768 |
| DPO_MAX_TARGET_LEN = 512 |
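# With 8 GPUs: effective DPO batch = 2 x 8 x 4 = 64 preference pairs per optimizer step.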
|
|
| |
| GEN_MAX_TOKENS = 512 |
| GEN_TEMPERATURE = 0.9 |
| GEN_TOP_P = 0.92 |
| GEN_TOP_K = 50 |
| GEN_NO_REPEAT_NGRAM = 3 |
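# Sampling decode: temperature/top-p/top-k jointly shape the candidate distribution,
# and no_repeat_ngram_size suppresses verbatim 3-gram loops in long generations.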
|
|
| |
| def setup_ddp(): |
| dist.init_process_group(backend='nccl') |
| rank = dist.get_rank() |
| world_size = dist.get_world_size() |
| local_rank = int(os.environ.get('LOCAL_RANK', 0)) |
| torch.cuda.set_device(local_rank) |
| return rank, world_size, local_rank |
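# Typical launch (single node; the script filename here is illustrative):
#   torchrun --standalone --nproc_per_node=8 train_mash_v6_ddp.py
# torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK, which init_process_group()
# and the LOCAL_RANK lookup above rely on.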
|
|
|
|
| def cleanup_ddp(): |
| dist.destroy_process_group() |
|
|
|
|
def is_main():
    """True on the global rank-0 process."""
    return dist.get_rank() == 0
|
|
|
|
| def log(msg): |
| """Only rank 0 prints.""" |
| if is_main(): |
| print(f"[RANK0] {msg}", flush=True) |
|
|
|
|
| |
def get_lr(step, total_steps, warmup_steps, base_lr):
    """Linear warmup to base_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(warmup_steps, 1)
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))
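# Illustrative shape of the schedule: with base_lr=5e-5, total_steps=1000, and
# warmup_steps=60, the LR ramps linearly 0 -> 5e-5 over steps 0..60, then follows
# 0.5 * (1 + cos(pi * progress)) down to ~0 at step 1000.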
|
|
|
|
| |
| def compute_seq2seq_logprobs(model, input_ids, attention_mask, labels): |
| """Compute per-sequence log probabilities with fp32 precision.""" |
| with torch.amp.autocast('cuda', dtype=torch.bfloat16): |
| outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) |
|
|
| |
    # Cast to fp32 before log_softmax for numerical stability.
    logits = outputs.logits.float()
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
|
|
| |
    # Replace the ignore index (-100) with a valid token id so gather() succeeds;
    # those positions are masked out below.
    labels_for_gather = labels.clone()
    labels_for_gather[labels_for_gather == -100] = 0
    per_token_logps = log_probs.gather(-1, labels_for_gather.unsqueeze(-1)).squeeze(-1)
|
|
| |
    # Mean over valid target tokens (length normalization).
    mask = (labels != -100).float()
    seq_logps = (per_token_logps * mask).sum(-1) / mask.sum(-1).clamp(min=1)
|
|
| return seq_logps |
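# Note: these logprobs are length-normalized (mean over valid target tokens), whereas
# canonical DPO sums them. Normalization keeps the implied rewards on a comparable
# scale across completion lengths, at the cost of rescaling the effective beta.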
|
|
|
|
| def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1): |
| """DPO loss with metrics.""" |
| chosen_rewards = beta * (policy_chosen - ref_chosen) |
| rejected_rewards = beta * (policy_rejected - ref_rejected) |
|
|
| logits = chosen_rewards - rejected_rewards |
| loss = -torch.nn.functional.logsigmoid(logits).mean() |
|
|
    with torch.no_grad():
        accuracy = (logits > 0).float().mean().item()
        chosen_reward = chosen_rewards.mean().item()
        rejected_reward = rejected_rewards.mean().item()
        margin = logits.mean().item()  # mean(chosen_rewards - rejected_rewards)
|
|
| return loss, { |
| 'accuracy': accuracy, |
| 'chosen_reward': chosen_reward, |
| 'rejected_reward': rejected_reward, |
| 'reward_margin': margin, |
| } |
|
|
|
|
| |
| def evaluate_with_gptzero(model, tokenizer, val_loader, device, n_samples=5): |
| """Generate samples and evaluate with GPTZero API.""" |
    if not is_main():
        return
    if not GPTZERO_API_KEY:
        log("GPTZero evaluation skipped: GPTZERO_API_KEY is not set")
        return

    log("\n--- GPTZero Evaluation ---")
    # Unwrap the DDP wrapper (if any) for generation.
    gen_model = model.module if hasattr(model, 'module') else model
    gen_model.eval()
|
|
| results = [] |
| for i, batch in enumerate(val_loader): |
| if i >= n_samples: |
| break |
| s_ids = batch['input_ids'][:1].to(device) |
| s_mask = batch['attention_mask'][:1].to(device) |
|
|
| with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16): |
| gen = gen_model.generate( |
| input_ids=s_ids, attention_mask=s_mask, |
| max_new_tokens=GEN_MAX_TOKENS, |
| do_sample=True, |
| temperature=GEN_TEMPERATURE, |
| top_p=GEN_TOP_P, |
| top_k=GEN_TOP_K, |
| no_repeat_ngram_size=GEN_NO_REPEAT_NGRAM, |
| ) |
| text = tokenizer.decode(gen[0], skip_special_tokens=True) |
| word_count = len(text.split()) |
|
|
| if word_count < 50: |
| log(f" Sample {i}: too short ({word_count}w), skipping GPTZero") |
| continue |
|
|
| |
| try: |
| resp = requests.post( |
| "https://api.gptzero.me/v2/predict/text", |
| headers={"x-api-key": GPTZERO_API_KEY, "Content-Type": "application/json"}, |
| json={"document": text}, |
| timeout=30, |
| ) |
            if resp.status_code == 200:
                data = resp.json()
                # Guard against a missing or empty documents list.
                doc = (data.get('documents') or [{}])[0]
                human_prob = doc.get('class_probabilities', {}).get('human', 0.0)
                verdict = doc.get('predicted_class', 'unknown')
| results.append(human_prob) |
| log(f" Sample {i}: {word_count}w, GPTZero={verdict} (human={human_prob:.1%})") |
| log(f" Text: {text[:200]}...") |
| else: |
| log(f" Sample {i}: GPTZero API error {resp.status_code}") |
| except Exception as e: |
| log(f" Sample {i}: GPTZero error: {e}") |
|
|
        time.sleep(1)  # light pacing between GPTZero API calls
|
|
| if results: |
| avg = sum(results) / len(results) |
| log(f"\n GPTZero avg human_prob: {avg:.1%} ({len(results)} samples)") |
| log("--- End GPTZero Evaluation ---\n") |
|
|
|
|
| |
| def run_sft(rank, world_size, local_rank): |
| from transformers import T5ForConditionalGeneration, T5Tokenizer |
| from dataset_v4 import InstructionDataset, collate_fn |
|
|
| device = torch.device(f'cuda:{local_rank}') |
| sft_best = os.path.join(SFT_DIR, 'best') |
| os.makedirs(sft_best, exist_ok=True) |
|
|
| log("=" * 60) |
| log("STAGE 2: SFT Training (Supp only)") |
| log(f" GPUs: {world_size}, batch/gpu: {SFT_BATCH_SIZE}, grad_accum: {SFT_GRAD_ACCUM}") |
| log(f" Effective batch size: {SFT_BATCH_SIZE * world_size * SFT_GRAD_ACCUM}") |
| log("=" * 60) |
|
|
| |
| log("Loading Flan-T5-XL...") |
| tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-xl') |
| model = T5ForConditionalGeneration.from_pretrained( |
| 'google/flan-t5-xl', torch_dtype=torch.bfloat16 |
| ).to(device) |
| model.gradient_checkpointing_enable() |
|
|
| |
| model = DDP(model, device_ids=[local_rank], output_device=local_rank, |
| find_unused_parameters=False) |
|
|
| if is_main() and torch.cuda.is_available(): |
| log(f"GPU: {torch.cuda.get_device_name()}") |
| log(f"GPU memory per card: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB") |
| log(f"Model memory: {torch.cuda.memory_allocated(device)/1e9:.1f}GB") |
|
|
| |
| train_ds = InstructionDataset(TRAIN_DATA, tokenizer, SFT_MAX_INPUT_LEN, SFT_MAX_TARGET_LEN) |
| val_ds = InstructionDataset(VAL_DATA, tokenizer, SFT_MAX_INPUT_LEN, SFT_MAX_TARGET_LEN) |
|
|
| train_sampler = DistributedSampler(train_ds, num_replicas=world_size, rank=rank, shuffle=True) |
| val_sampler = DistributedSampler(val_ds, num_replicas=world_size, rank=rank, shuffle=False) |
|
|
| train_loader = DataLoader(train_ds, batch_size=SFT_BATCH_SIZE, sampler=train_sampler, |
| collate_fn=collate_fn, num_workers=4, pin_memory=True) |
| val_loader = DataLoader(val_ds, batch_size=SFT_BATCH_SIZE, sampler=val_sampler, |
| collate_fn=collate_fn, num_workers=4, pin_memory=True) |
|
|
| log(f"Dataset: {len(train_ds)} train, {len(val_ds)} val") |
| log(f"Steps/epoch: {len(train_loader)}, total optimizer steps: {len(train_loader) * SFT_EPOCHS // SFT_GRAD_ACCUM}") |
|
|
| |
| optimizer = torch.optim.AdamW(model.parameters(), lr=SFT_LR, weight_decay=0.01) |
| loss_fn = torch.nn.CrossEntropyLoss( |
| ignore_index=-100, |
| label_smoothing=SFT_LABEL_SMOOTHING, |
| ) |
|
|
| total_steps = len(train_loader) * SFT_EPOCHS // SFT_GRAD_ACCUM |
| warmup_steps = int(total_steps * SFT_WARMUP_RATIO) |
| log(f"LR schedule: warmup={warmup_steps}, total={total_steps}") |
|
|
| best_val_loss = float('inf') |
| global_step = 0 |
|
|
| for epoch in range(1, SFT_EPOCHS + 1): |
| model.train() |
| train_sampler.set_epoch(epoch) |
| epoch_loss = 0 |
| epoch_n = 0 |
| t0 = time.time() |
| optimizer.zero_grad() |
| lr = 0 |
|
|
| for step, batch in enumerate(train_loader): |
| input_ids = batch['input_ids'].to(device) |
| attention_mask = batch['attention_mask'].to(device) |
| labels = batch['labels'].to(device) |
|
|
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            logits = outputs.logits
            # Recompute the loss so label smoothing is applied
            # (outputs.loss from the HF model uses plain cross-entropy).
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
| scaled_loss = loss / SFT_GRAD_ACCUM |
|
|
| scaled_loss.backward() |
| epoch_loss += loss.item() |
| epoch_n += 1 |
|
|
| if (step + 1) % SFT_GRAD_ACCUM == 0: |
| torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) |
| global_step += 1 |
| lr = get_lr(global_step, total_steps, warmup_steps, SFT_LR) |
| for pg in optimizer.param_groups: |
| pg['lr'] = lr |
| optimizer.step() |
| optimizer.zero_grad() |
|
|
| if is_main() and (step + 1) % 50 == 0: |
| avg = epoch_loss / epoch_n |
| eta = (time.time() - t0) / (step + 1) * (len(train_loader) - step - 1) |
| log(f" SFT E{epoch} step {step+1}/{len(train_loader)}: " |
| f"loss={avg:.4f}, lr={lr:.2e}, ETA={eta:.0f}s") |
|
|
| |
| model.eval() |
| val_loss_total = 0 |
| val_n = 0 |
| with torch.no_grad(): |
| for batch in val_loader: |
| input_ids = batch['input_ids'].to(device) |
| attention_mask = batch['attention_mask'].to(device) |
| labels = batch['labels'].to(device) |
| with torch.amp.autocast('cuda', dtype=torch.bfloat16): |
| outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) |
| logits = outputs.logits |
| loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1)) |
| val_loss_total += loss.item() |
| val_n += 1 |
|
|
| |
| val_loss_tensor = torch.tensor([val_loss_total, val_n], device=device) |
| dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.SUM) |
| val_loss = val_loss_tensor[0].item() / max(val_loss_tensor[1].item(), 1) |
|
|
| |
| train_loss_tensor = torch.tensor([epoch_loss, epoch_n], device=device, dtype=torch.float) |
| dist.all_reduce(train_loss_tensor, op=dist.ReduceOp.SUM) |
| train_loss = train_loss_tensor[0].item() / max(train_loss_tensor[1].item(), 1) |
|
|
| elapsed = time.time() - t0 |
| log(f"\n SFT Epoch {epoch}/{SFT_EPOCHS}: train={train_loss:.4f}, val={val_loss:.4f} ({elapsed:.0f}s)") |
|
|
| if is_main() and val_loss < best_val_loss: |
| best_val_loss = val_loss |
| model.module.save_pretrained(sft_best) |
| tokenizer.save_pretrained(sft_best) |
| log(f" >>> Best SFT model saved (val_loss={best_val_loss:.4f})") |
|
|
| |
| if is_main(): |
| try: |
| sample_batch = next(iter(val_loader)) |
| s_ids = sample_batch['input_ids'][:2].to(device) |
| s_mask = sample_batch['attention_mask'][:2].to(device) |
| gen_model = model.module |
| with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16): |
| gen = gen_model.generate( |
| input_ids=s_ids, attention_mask=s_mask, |
| max_new_tokens=256, |
| do_sample=True, |
| temperature=GEN_TEMPERATURE, |
| top_p=GEN_TOP_P, |
| top_k=GEN_TOP_K, |
| no_repeat_ngram_size=GEN_NO_REPEAT_NGRAM, |
| ) |
| log(" SFT Samples:") |
| for i in range(min(2, len(gen))): |
| out = tokenizer.decode(gen[i], skip_special_tokens=True) |
| log(f" [{i}] {out[:300]}...") |
| except Exception as e: |
| log(f" Sample error: {e}") |
|
|
| dist.barrier() |
| torch.cuda.empty_cache() |
| gc.collect() |
|
|
| log(f"\nSFT Complete. Best val_loss={best_val_loss:.4f}") |
|
|
| |
| if is_main(): |
| try: |
| eval_loader = DataLoader(val_ds, batch_size=1, shuffle=False, |
| collate_fn=collate_fn, num_workers=0) |
| evaluate_with_gptzero(model, tokenizer, eval_loader, device, n_samples=5) |
| except Exception as e: |
| log(f"GPTZero eval error: {e}") |
|
|
| |
| if is_main(): |
| try: |
| if HF_TOKEN: |
| sft_hub = HUB_REPO + '-sft' |
| log(f"Uploading SFT checkpoint to Hub: {sft_hub}") |
| model.module.push_to_hub(sft_hub, token=HF_TOKEN, private=True) |
| tokenizer.push_to_hub(sft_hub, token=HF_TOKEN, private=True) |
| log(" SFT upload complete!") |
| except Exception as e: |
| log(f" SFT Hub upload error: {e}") |
|
|
| |
| del model, optimizer |
| torch.cuda.empty_cache() |
| gc.collect() |
| dist.barrier() |
|
|
| return sft_best, tokenizer |
|
|
|
|
| |
| def run_dpo(sft_ckpt, rank, world_size, local_rank): |
| from transformers import T5ForConditionalGeneration, T5Tokenizer |
| from dataset_dpo import DPODataset, dpo_collate_fn |
| from dataset_v4 import InstructionDataset, collate_fn as sft_collate_fn |
|
|
| device = torch.device(f'cuda:{local_rank}') |
| dpo_best = os.path.join(DPO_DIR, 'best') |
| os.makedirs(dpo_best, exist_ok=True) |
|
|
| log("\n" + "=" * 60) |
| log("STAGE 3: DPO Training (Supp only)") |
| log(f" GPUs: {world_size}, batch/gpu: {DPO_BATCH_SIZE}, grad_accum: {DPO_GRAD_ACCUM}") |
| log(f" Effective batch size: {DPO_BATCH_SIZE * world_size * DPO_GRAD_ACCUM}") |
| log("=" * 60) |
|
|
| tokenizer = T5Tokenizer.from_pretrained(sft_ckpt) |
|
|
| |
| log("Loading policy model...") |
| policy = T5ForConditionalGeneration.from_pretrained( |
| sft_ckpt, torch_dtype=torch.bfloat16 |
| ).to(device) |
| policy.gradient_checkpointing_enable() |
| policy = DDP(policy, device_ids=[local_rank], output_device=local_rank, |
| find_unused_parameters=False) |
|
|
| |
| log("Loading reference model...") |
| ref = T5ForConditionalGeneration.from_pretrained( |
| sft_ckpt, torch_dtype=torch.bfloat16 |
| ).to(device) |
| ref.eval() |
| for p in ref.parameters(): |
| p.requires_grad = False |
|
|
| if is_main() and torch.cuda.is_available(): |
| log(f"GPU memory (2 models): {torch.cuda.memory_allocated(device)/1e9:.1f}GB") |
|
|
| |
| train_ds = DPODataset(TRAIN_DATA, tokenizer, DPO_MAX_INPUT_LEN, DPO_MAX_TARGET_LEN) |
| val_ds = DPODataset(VAL_DATA, tokenizer, DPO_MAX_INPUT_LEN, DPO_MAX_TARGET_LEN) |
|
|
| train_sampler = DistributedSampler(train_ds, num_replicas=world_size, rank=rank, shuffle=True) |
| val_sampler = DistributedSampler(val_ds, num_replicas=world_size, rank=rank, shuffle=False) |
|
|
| train_loader = DataLoader(train_ds, batch_size=DPO_BATCH_SIZE, sampler=train_sampler, |
| collate_fn=dpo_collate_fn, num_workers=4, pin_memory=True) |
| val_loader = DataLoader(val_ds, batch_size=DPO_BATCH_SIZE, sampler=val_sampler, |
| collate_fn=dpo_collate_fn, num_workers=4, pin_memory=True) |
|
|
| |
| sft_val_ds = InstructionDataset(VAL_DATA, tokenizer, DPO_MAX_INPUT_LEN, DPO_MAX_TARGET_LEN) |
| sft_val_loader = DataLoader(sft_val_ds, batch_size=1, shuffle=False, |
| collate_fn=sft_collate_fn, num_workers=0) |
|
|
| log(f"DPO Dataset: {len(train_ds)} train, {len(val_ds)} val") |
|
|
| |
| optimizer = torch.optim.AdamW(policy.parameters(), lr=DPO_LR, weight_decay=0.01) |
| total_steps = len(train_loader) * DPO_EPOCHS // DPO_GRAD_ACCUM |
| warmup_steps = int(total_steps * DPO_WARMUP_RATIO) |
|
|
| log(f"Config: beta={DPO_BETA}, lr={DPO_LR}") |
| log(f"Steps: {total_steps} total, {warmup_steps} warmup") |
|
|
| best_val_loss = float('inf') |
| global_step = 0 |
|
|
| for epoch in range(1, DPO_EPOCHS + 1): |
| policy.train() |
| train_sampler.set_epoch(epoch) |
| epoch_loss = 0 |
| epoch_metrics = {'chosen_reward': 0, 'rejected_reward': 0, 'reward_margin': 0, 'accuracy': 0} |
| epoch_n = 0 |
| t0 = time.time() |
| optimizer.zero_grad() |
| lr = 0 |
|
|
| for step, batch in enumerate(train_loader): |
| input_ids = batch['input_ids'].to(device) |
| attention_mask = batch['attention_mask'].to(device) |
| chosen_labels = batch['chosen_labels'].to(device) |
| rejected_labels = batch['rejected_labels'].to(device) |
|
|
| |
| policy_chosen_logps = compute_seq2seq_logprobs( |
| policy, input_ids, attention_mask, chosen_labels) |
| policy_rejected_logps = compute_seq2seq_logprobs( |
| policy, input_ids, attention_mask, rejected_labels) |
|
|
| |
| with torch.no_grad(): |
| ref_chosen_logps = compute_seq2seq_logprobs( |
| ref, input_ids, attention_mask, chosen_labels) |
| ref_rejected_logps = compute_seq2seq_logprobs( |
| ref, input_ids, attention_mask, rejected_labels) |
|
|
| loss, metrics = dpo_loss( |
| policy_chosen_logps, policy_rejected_logps, |
| ref_chosen_logps, ref_rejected_logps, |
| beta=DPO_BETA, |
| ) |
|
|
| scaled_loss = loss / DPO_GRAD_ACCUM |
| scaled_loss.backward() |
|
|
| epoch_loss += loss.item() |
| for k, v in metrics.items(): |
| epoch_metrics[k] += v |
| epoch_n += 1 |
|
|
| if (step + 1) % DPO_GRAD_ACCUM == 0: |
| torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0) |
| global_step += 1 |
| lr = get_lr(global_step, total_steps, warmup_steps, DPO_LR) |
| for pg in optimizer.param_groups: |
| pg['lr'] = lr |
| optimizer.step() |
| optimizer.zero_grad() |
|
|
| if is_main() and global_step % 10 == 0: |
| avg_loss = epoch_loss / epoch_n |
| avg_acc = epoch_metrics['accuracy'] / epoch_n |
| avg_margin = epoch_metrics['reward_margin'] / epoch_n |
| eta = (time.time() - t0) / (step + 1) * (len(train_loader) - step - 1) |
| log(f" DPO E{epoch} step {global_step}/{total_steps}: " |
| f"loss={avg_loss:.4f}, acc={avg_acc:.2%}, " |
| f"margin={avg_margin:.3f}, lr={lr:.2e}, ETA={eta:.0f}s") |
|
|
| |
| policy.eval() |
| val_loss_total = 0 |
| val_metrics_total = {'accuracy': 0} |
| val_n = 0 |
|
|
| with torch.no_grad(): |
| for batch in val_loader: |
| input_ids = batch['input_ids'].to(device) |
| attention_mask = batch['attention_mask'].to(device) |
| chosen_labels = batch['chosen_labels'].to(device) |
| rejected_labels = batch['rejected_labels'].to(device) |
|
|
| p_c = compute_seq2seq_logprobs(policy, input_ids, attention_mask, chosen_labels) |
| p_r = compute_seq2seq_logprobs(policy, input_ids, attention_mask, rejected_labels) |
| r_c = compute_seq2seq_logprobs(ref, input_ids, attention_mask, chosen_labels) |
| r_r = compute_seq2seq_logprobs(ref, input_ids, attention_mask, rejected_labels) |
|
|
| loss, metrics = dpo_loss(p_c, p_r, r_c, r_r, beta=DPO_BETA) |
| val_loss_total += loss.item() |
| val_metrics_total['accuracy'] += metrics['accuracy'] |
| val_n += 1 |
|
|
| |
| val_tensor = torch.tensor([val_loss_total, val_n, val_metrics_total['accuracy']], device=device) |
| dist.all_reduce(val_tensor, op=dist.ReduceOp.SUM) |
| val_loss = val_tensor[0].item() / max(val_tensor[1].item(), 1) |
| val_acc = val_tensor[2].item() / max(val_tensor[1].item(), 1) |
|
|
| train_tensor = torch.tensor([epoch_loss, epoch_n], device=device, dtype=torch.float) |
| dist.all_reduce(train_tensor, op=dist.ReduceOp.SUM) |
| train_loss = train_tensor[0].item() / max(train_tensor[1].item(), 1) |
| train_acc = epoch_metrics['accuracy'] / max(epoch_n, 1) |
|
|
| elapsed = time.time() - t0 |
| log(f"\n DPO Epoch {epoch}/{DPO_EPOCHS} ({elapsed:.0f}s)") |
| log(f" Train: loss={train_loss:.4f}, acc={train_acc:.2%}") |
| log(f" Val: loss={val_loss:.4f}, acc={val_acc:.2%}") |
|
|
| if is_main() and torch.cuda.is_available(): |
| log(f" GPU: {torch.cuda.max_memory_allocated(device)/1e9:.1f}GB peak") |
|
|
| if is_main() and val_loss < best_val_loss: |
| best_val_loss = val_loss |
| policy.module.save_pretrained(dpo_best) |
| tokenizer.save_pretrained(dpo_best) |
| log(f" >>> Best DPO model saved (val_loss={val_loss:.4f})") |
|
|
| |
| if is_main(): |
| try: |
| sample_batch = next(iter(sft_val_loader)) |
| s_ids = sample_batch['input_ids'][:1].to(device) |
| s_mask = sample_batch['attention_mask'][:1].to(device) |
| with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16): |
| gen = policy.module.generate( |
| input_ids=s_ids, attention_mask=s_mask, |
| max_new_tokens=GEN_MAX_TOKENS, |
| do_sample=True, |
| temperature=GEN_TEMPERATURE, |
| top_p=GEN_TOP_P, |
| top_k=GEN_TOP_K, |
| no_repeat_ngram_size=GEN_NO_REPEAT_NGRAM, |
| ) |
| out = tokenizer.decode(gen[0], skip_special_tokens=True) |
| log(f" DPO Sample: {out[:300]}...") |
| except Exception as e: |
| log(f" Sample error: {e}") |
|
|
| dist.barrier() |
| torch.cuda.empty_cache() |
| gc.collect() |
|
|
| |
| if is_main(): |
| final_dir = os.path.join(DPO_DIR, 'final') |
| os.makedirs(final_dir, exist_ok=True) |
| policy.module.save_pretrained(final_dir) |
| tokenizer.save_pretrained(final_dir) |
| log(f"\nDPO Complete. Best val_loss={best_val_loss:.4f}") |
|
|
| |
| if is_main(): |
| try: |
| evaluate_with_gptzero(policy, tokenizer, sft_val_loader, device, n_samples=10) |
| except Exception as e: |
| log(f"GPTZero eval error: {e}") |
|
|
| |
| if is_main(): |
| try: |
| if HF_TOKEN: |
| log(f"\nUploading DPO model to Hub: {HUB_REPO}") |
| policy.module.push_to_hub(HUB_REPO, token=HF_TOKEN, private=True) |
| tokenizer.push_to_hub(HUB_REPO, token=HF_TOKEN, private=True) |
| log(" DPO upload complete!") |
| except Exception as e: |
| log(f" Hub upload error: {e}") |
|
|
| dist.barrier() |
| return dpo_best |
|
|
|
|
| |
| def main(): |
| try: |
| rank, world_size, local_rank = setup_ddp() |
| except Exception as e: |
| print(f"[ERROR] DDP setup failed: {e}", flush=True) |
| traceback.print_exc() |
| sys.exit(1) |
|
|
| if is_main(): |
| print(f"[RANK0] Starting MASH v6 DDP training", flush=True) |
| print(f"[RANK0] World size: {world_size}", flush=True) |
| print(f"[RANK0] SCRIPT_DIR: {SCRIPT_DIR}", flush=True) |
| print(f"[RANK0] SFT_DIR: {SFT_DIR}", flush=True) |
| print(f"[RANK0] TRAIN_DATA: {TRAIN_DATA} (exists={os.path.exists(TRAIN_DATA)})", flush=True) |
| print(f"[RANK0] VAL_DATA: {VAL_DATA} (exists={os.path.exists(VAL_DATA)})", flush=True) |
        # Enumerate local GPUs; world_size may exceed the local device count on multi-node runs.
        for i in range(torch.cuda.device_count()):
            print(f"[RANK0] GPU {i}: {torch.cuda.get_device_name(i)}", flush=True)
| print(f"[RANK0] Python: {sys.version}", flush=True) |
| print(f"[RANK0] PyTorch: {torch.__version__}", flush=True) |
| print(f"[RANK0] sys.path: {sys.path[:5]}", flush=True) |
|
|
| |
| sft_ckpt = os.path.join(SFT_DIR, 'best') |
| need_sft = not os.path.exists(os.path.join(sft_ckpt, 'config.json')) |
|
|
| if need_sft and is_main(): |
| |
| sft_hub = HUB_REPO + '-sft' |
| try: |
| from transformers import T5ForConditionalGeneration, T5Tokenizer |
| print(f"[RANK0] Trying Hub: {sft_hub}", flush=True) |
| os.makedirs(sft_ckpt, exist_ok=True) |
| m = T5ForConditionalGeneration.from_pretrained(sft_hub, token=HF_TOKEN) |
| t = T5Tokenizer.from_pretrained(sft_hub, token=HF_TOKEN) |
| m.save_pretrained(sft_ckpt) |
| t.save_pretrained(sft_ckpt) |
| del m, t |
| need_sft = False |
| print(f"[RANK0] Downloaded SFT from Hub", flush=True) |
| except Exception as e: |
| print(f"[RANK0] No Hub checkpoint: {e}", flush=True) |
|
|
| |
| need_sft_tensor = torch.tensor([1 if need_sft else 0], device=f'cuda:{local_rank}') |
| dist.broadcast(need_sft_tensor, src=0) |
| need_sft = need_sft_tensor.item() == 1 |
|
|
| if need_sft: |
| sft_ckpt, _ = run_sft(rank, world_size, local_rank) |
| else: |
| if is_main(): |
| print(f"[RANK0] Using existing SFT checkpoint", flush=True) |
|
|
| dist.barrier() |
|
|
| |
| run_dpo(sft_ckpt, rank, world_size, local_rank) |
|
|
| if is_main(): |
| print("[RANK0] " + "=" * 60, flush=True) |
| print("[RANK0] ALL TRAINING COMPLETE!", flush=True) |
| print("[RANK0] " + "=" * 60, flush=True) |
|
|
| cleanup_ddp() |
|
|
|
|
| if __name__ == '__main__': |
| try: |
| main() |
| except Exception as e: |
| print(f"[FATAL] Unhandled exception: {e}", flush=True) |
| traceback.print_exc() |
| sys.exit(1) |
|
|