Notebook 4 v2: Complete bulletproof fixer training with all error fixes integrated
Browse files
notebook4_fixer_training_v2_FIXED.py
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# ============================================================
|
| 3 |
+
# NOTEBOOK 4/4: Fixer Model Training - BULLETPROOF VERSION
|
| 4 |
+
# ============================================================
|
| 5 |
+
# Run on Kaggle with T4 GPU
|
| 6 |
+
# This notebook is SELF-CONTAINED - can restart kernel and run all cells
|
| 7 |
+
# Estimated time: ~3-4 hours on T4
|
| 8 |
+
# Saves model to HF Hub: ayshajavd/codet5p-vuln-fixer
|
| 9 |
+
# ============================================================
|
| 10 |
+
# ALL PREVIOUS ERRORS FIXED:
|
| 11 |
+
# 1. Tokenizer: RobertaTokenizer.from_pretrained(..., use_fast=False)
|
| 12 |
+
# 2. NaN loss: fp16 DISABLED (avoid unscale_fp16 error), LR=5e-5, max_grad_norm=1.0
|
| 13 |
+
# 3. OOM: batch_size=2, gradient_accumulation=16, single GPU forced
|
| 14 |
+
# 4. DataParallel: model.to('cuda:0') + trainer.args._n_gpu=1
|
| 15 |
+
# 5. Deprecation: no no_cuda param, use warmup_ratio instead of warmup_steps
|
| 16 |
+
# 6. Padding warning: padding='max_length' in tokenizer
|
| 17 |
+
# 7. CodeBLEU: tree-sitter-c parser pre-installed check
|
| 18 |
+
# ============================================================
|
| 19 |
+
|
| 20 |
+
# %% [CELL 1] Install + Login
|
| 21 |
+
import subprocess
|
| 22 |
+
subprocess.run(["pip", "install", "-q", "transformers", "datasets", "scikit-learn",
|
| 23 |
+
"accelerate", "huggingface_hub", "evaluate", "sentencepiece",
|
| 24 |
+
"sacrebleu", "rouge_score", "codebleu", "tree-sitter-c", "problog"], capture_output=True)
|
| 25 |
+
|
| 26 |
+
from huggingface_hub import login
|
| 27 |
+
import os
|
| 28 |
+
try:
|
| 29 |
+
from kaggle_secrets import UserSecretsClient
|
| 30 |
+
token = UserSecretsClient().get_secret("HF_TOKEN")
|
| 31 |
+
except:
|
| 32 |
+
token = os.environ.get("HF_TOKEN", None)
|
| 33 |
+
if token:
|
| 34 |
+
login(token=token)
|
| 35 |
+
print("β
Logged in to HF Hub")
|
| 36 |
+
else:
|
| 37 |
+
print("β οΈ No HF token found. Set HF_TOKEN environment variable or Kaggle secret.")
|
| 38 |
+
|
| 39 |
+
# %% [CELL 2] Imports + Config
|
| 40 |
+
import json, numpy as np, torch
|
| 41 |
+
from datasets import load_dataset
|
| 42 |
+
from transformers import (
|
| 43 |
+
AutoModelForSeq2SeqLM, RobertaTokenizer,
|
| 44 |
+
Seq2SeqTrainingArguments, Seq2SeqTrainer,
|
| 45 |
+
DataCollatorForSeq2Seq, EarlyStoppingCallback,
|
| 46 |
+
)
|
| 47 |
+
import evaluate
|
| 48 |
+
from huggingface_hub import HfApi
|
| 49 |
+
|
| 50 |
+
MODEL_NAME = "Salesforce/codet5p-220m"
|
| 51 |
+
HUB_MODEL_ID = "ayshajavd/codet5p-vuln-fixer"
|
| 52 |
+
DATASET_ID = "ayshajavd/code-security-vulnerability-dataset"
|
| 53 |
+
MAX_SOURCE_LENGTH = 512
|
| 54 |
+
MAX_TARGET_LENGTH = 512
|
| 55 |
+
SEED = 42
|
| 56 |
+
|
| 57 |
+
CWE_NAMES = {
|
| 58 |
+
"safe":"Safe Code","CWE-20":"Improper Input Validation","CWE-22":"Path Traversal",
|
| 59 |
+
"CWE-78":"OS Command Injection","CWE-79":"Cross-Site Scripting",
|
| 60 |
+
"CWE-89":"SQL Injection","CWE-94":"Code Injection","CWE-119":"Buffer Overflow",
|
| 61 |
+
"CWE-125":"Out-of-bounds Read","CWE-190":"Integer Overflow",
|
| 62 |
+
"CWE-200":"Information Exposure","CWE-264":"Permissions Issues",
|
| 63 |
+
"CWE-269":"Privilege Management","CWE-276":"Incorrect Permissions",
|
| 64 |
+
"CWE-284":"Access Control","CWE-287":"Authentication",
|
| 65 |
+
"CWE-310":"Cryptographic Issues","CWE-327":"Broken Crypto",
|
| 66 |
+
"CWE-330":"Insufficient Randomness","CWE-352":"CSRF",
|
| 67 |
+
"CWE-362":"Race Condition","CWE-399":"Resource Management",
|
| 68 |
+
"CWE-401":"Memory Leak","CWE-416":"Use After Free",
|
| 69 |
+
"CWE-434":"File Upload","CWE-476":"NULL Pointer Dereference",
|
| 70 |
+
"CWE-502":"Insecure Deserialization","CWE-601":"Open Redirect",
|
| 71 |
+
"CWE-787":"Out-of-bounds Write","CWE-798":"Hardcoded Credentials","CWE-918":"SSRF",
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
if torch.cuda.is_available():
|
| 75 |
+
print(f"β
GPU: {torch.cuda.get_device_name(0)} ({torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB)")
|
| 76 |
+
else:
|
| 77 |
+
print("β οΈ No GPU!")
|
| 78 |
+
|
| 79 |
+
# %% [CELL 3] Load Model + Tokenizer - BULLETPROOF
|
| 80 |
+
print("=" * 60)
|
| 81 |
+
print("π Loading CodeT5+ 220M Tokenizer + Model")
|
| 82 |
+
print("=" * 60)
|
| 83 |
+
|
| 84 |
+
# BULLETPROOF: Use RobertaTokenizer slow path (use_fast=False)
|
| 85 |
+
# This avoids the 'extra_special_tokens' TypeError in newer transformers
|
| 86 |
+
tokenizer = RobertaTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
|
| 87 |
+
|
| 88 |
+
# CRITICAL: Verify token IDs match model's config.json
|
| 89 |
+
print(f"\nπ Tokenizer Verification:")
|
| 90 |
+
print(f" pad_token_id: {tokenizer.pad_token_id} (expected: 0)")
|
| 91 |
+
print(f" bos_token_id: {tokenizer.bos_token_id} (expected: 1)")
|
| 92 |
+
print(f" eos_token_id: {tokenizer.eos_token_id} (expected: 2)")
|
| 93 |
+
print(f" unk_token_id: {tokenizer.unk_token_id} (expected: 3)")
|
| 94 |
+
print(f" mask_token_id: {tokenizer.mask_token_id} (expected: 4)")
|
| 95 |
+
print(f" vocab_size: {len(tokenizer)} (expected: 32100)")
|
| 96 |
+
|
| 97 |
+
# Verify sentinel tokens exist
|
| 98 |
+
extra_id_0 = tokenizer.convert_tokens_to_ids("<extra_id_0>")
|
| 99 |
+
extra_id_99 = tokenizer.convert_tokens_to_ids("<extra_id_99>")
|
| 100 |
+
print(f" <extra_id_0> id: {extra_id_0} (should NOT be {tokenizer.unk_token_id})")
|
| 101 |
+
print(f" <extra_id_99> id: {extra_id_99} (should NOT be {tokenizer.unk_token_id})")
|
| 102 |
+
|
| 103 |
+
assert tokenizer.pad_token_id == 0, f"FATAL: pad_token_id={tokenizer.pad_token_id}, expected 0"
|
| 104 |
+
assert len(tokenizer) == 32100, f"FATAL: vocab_size={len(tokenizer)}, expected 32100"
|
| 105 |
+
assert extra_id_0 != tokenizer.unk_token_id, "FATAL: <extra_id_0> mapped to <unk>"
|
| 106 |
+
print("β
All tokenizer verifications PASSED")
|
| 107 |
+
|
| 108 |
+
# Load model - use float32 (NOT fp16) to avoid unscale_fp16 error
|
| 109 |
+
# T5ForConditionalGeneration works fine with default dtype
|
| 110 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
| 111 |
+
|
| 112 |
+
# Force single GPU to prevent DataParallel OOM issues
|
| 113 |
+
if torch.cuda.is_available():
|
| 114 |
+
model = model.to('cuda:0')
|
| 115 |
+
torch.cuda.set_device(0)
|
| 116 |
+
print(f"β
Model moved to cuda:0")
|
| 117 |
+
|
| 118 |
+
print(f"β
Model loaded: {sum(p.numel() for p in model.parameters()):,} parameters")
|
| 119 |
+
|
| 120 |
+
# %% [CELL 4] Load + Filter Dataset
|
| 121 |
+
print("\n" + "=" * 60)
|
| 122 |
+
print("π Loading Dataset")
|
| 123 |
+
print("=" * 60)
|
| 124 |
+
|
| 125 |
+
ds = load_dataset(DATASET_ID)
|
| 126 |
+
|
| 127 |
+
def filter_has_fix(example):
|
| 128 |
+
"""Only keep vulnerable samples that have a fix"""
|
| 129 |
+
return (example['is_vulnerable'] == True and
|
| 130 |
+
example['code_fixed'] is not None and
|
| 131 |
+
len(example['code_fixed'].strip()) > 10)
|
| 132 |
+
|
| 133 |
+
print("Filtering to samples with fixes...")
|
| 134 |
+
ds_fixer = {}
|
| 135 |
+
for split in ['train', 'validation', 'test']:
|
| 136 |
+
ds_fixer[split] = ds[split].filter(filter_has_fix, num_proc=2)
|
| 137 |
+
print(f" {split}: {len(ds_fixer[split]):,} samples with fixes")
|
| 138 |
+
|
| 139 |
+
# %% [CELL 5] Tokenize with CWE-aware Input
|
| 140 |
+
print("\n" + "=" * 60)
|
| 141 |
+
print("π€ Tokenizing with CWE-aware input format")
|
| 142 |
+
print("=" * 60)
|
| 143 |
+
|
| 144 |
+
def tokenize_fn(examples):
|
| 145 |
+
"""
|
| 146 |
+
Input: "fix <CWE-NAME> vulnerability in <language>: <code>"
|
| 147 |
+
Target: fixed code
|
| 148 |
+
"""
|
| 149 |
+
inputs = []
|
| 150 |
+
for code, cwe, lang in zip(examples['code'], examples['cwe_id'], examples['language']):
|
| 151 |
+
cwe_name = CWE_NAMES.get(cwe, cwe)
|
| 152 |
+
prefix = f"fix {cwe_name} vulnerability in {lang.lower()}: "
|
| 153 |
+
inputs.append(prefix + code)
|
| 154 |
+
|
| 155 |
+
# FIXED: use padding='max_length' to avoid the warning
|
| 156 |
+
model_inputs = tokenizer(
|
| 157 |
+
inputs,
|
| 158 |
+
max_length=MAX_SOURCE_LENGTH,
|
| 159 |
+
truncation=True,
|
| 160 |
+
padding='max_length',
|
| 161 |
+
)
|
| 162 |
+
labels = tokenizer(
|
| 163 |
+
examples['code_fixed'],
|
| 164 |
+
max_length=MAX_TARGET_LENGTH,
|
| 165 |
+
truncation=True,
|
| 166 |
+
padding='max_length',
|
| 167 |
+
)
|
| 168 |
+
model_inputs['labels'] = labels['input_ids']
|
| 169 |
+
return model_inputs
|
| 170 |
+
|
| 171 |
+
print("Tokenizing...")
|
| 172 |
+
tokenized = {}
|
| 173 |
+
for split in ['train', 'validation', 'test']:
|
| 174 |
+
tokenized[split] = ds_fixer[split].map(
|
| 175 |
+
tokenize_fn, batched=True, batch_size=500, num_proc=2,
|
| 176 |
+
remove_columns=ds_fixer[split].column_names,
|
| 177 |
+
)
|
| 178 |
+
print(f" {split}: {len(tokenized[split]):,} tokenized")
|
| 179 |
+
|
| 180 |
+
# Verify a sample
|
| 181 |
+
sample_input_ids = tokenized['train'][0]['input_ids']
|
| 182 |
+
sample_label_ids = tokenized['train'][0]['labels']
|
| 183 |
+
print(f"\nπ Sample verification:")
|
| 184 |
+
print(f" input_ids length: {len(sample_input_ids)}")
|
| 185 |
+
print(f" labels length: {len(sample_label_ids)}")
|
| 186 |
+
print(f" input_ids[:10]: {sample_input_ids[:10]}")
|
| 187 |
+
print(f" labels[:10]: {sample_label_ids[:10]}")
|
| 188 |
+
|
| 189 |
+
# Check label masking: pad tokens should be in input_ids but not in labels
|
| 190 |
+
# (DataCollatorForSeq2Seq handles -100 replacement automatically)
|
| 191 |
+
print(f" pad_token_id in input: {tokenizer.pad_token_id in sample_input_ids}")
|
| 192 |
+
print(f" pad_token_id in labels (raw): {tokenizer.pad_token_id in sample_label_ids}")
|
| 193 |
+
|
| 194 |
+
# %% [CELL 6] Metrics
|
| 195 |
+
print("\n" + "=" * 60)
|
| 196 |
+
print("π Loading Evaluation Metrics")
|
| 197 |
+
print("=" * 60)
|
| 198 |
+
|
| 199 |
+
bleu_metric = evaluate.load("sacrebleu")
|
| 200 |
+
rouge_metric = evaluate.load("rouge")
|
| 201 |
+
|
| 202 |
+
def compute_metrics(eval_preds):
|
| 203 |
+
preds, labels = eval_preds
|
| 204 |
+
# Replace -100 with pad_token_id for decoding
|
| 205 |
+
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
| 206 |
+
preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
|
| 207 |
+
|
| 208 |
+
decoded_preds = [p.strip() for p in tokenizer.batch_decode(preds, skip_special_tokens=True)]
|
| 209 |
+
decoded_labels = [l.strip() for l in tokenizer.batch_decode(labels, skip_special_tokens=True)]
|
| 210 |
+
|
| 211 |
+
# BLEU
|
| 212 |
+
bleu_result = bleu_metric.compute(
|
| 213 |
+
predictions=decoded_preds,
|
| 214 |
+
references=[[l] for l in decoded_labels],
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
# ROUGE
|
| 218 |
+
rouge_result = rouge_metric.compute(
|
| 219 |
+
predictions=decoded_preds, references=decoded_labels,
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
# Exact match
|
| 223 |
+
exact_matches = sum(1 for p, l in zip(decoded_preds, decoded_labels) if p == l)
|
| 224 |
+
exact_match_rate = exact_matches / max(len(decoded_preds), 1)
|
| 225 |
+
|
| 226 |
+
# CodeBLEU (subset for speed, may fail if tree-sitter unavailable)
|
| 227 |
+
codebleu_score = 0.0
|
| 228 |
+
try:
|
| 229 |
+
from codebleu import calc_codebleu
|
| 230 |
+
n_eval = min(200, len(decoded_preds))
|
| 231 |
+
cb_result = calc_codebleu(
|
| 232 |
+
references=[[l] for l in decoded_labels[:n_eval]],
|
| 233 |
+
predictions=decoded_preds[:n_eval],
|
| 234 |
+
lang="c",
|
| 235 |
+
weights=(0.25, 0.25, 0.25, 0.25),
|
| 236 |
+
)
|
| 237 |
+
codebleu_score = cb_result['codebleu']
|
| 238 |
+
except Exception as e:
|
| 239 |
+
print(f"CodeBLEU failed (non-critical): {e}")
|
| 240 |
+
|
| 241 |
+
return {
|
| 242 |
+
"bleu": bleu_result["score"],
|
| 243 |
+
"rouge1": rouge_result["rouge1"],
|
| 244 |
+
"rouge2": rouge_result["rouge2"],
|
| 245 |
+
"rougeL": rouge_result["rougeL"],
|
| 246 |
+
"codebleu": codebleu_score,
|
| 247 |
+
"exact_match": exact_match_rate,
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
print("β
Metrics loaded")
|
| 251 |
+
|
| 252 |
+
# %% [CELL 7] TRAINING - BULLETPROOF
|
| 253 |
+
print("\n" + "=" * 60)
|
| 254 |
+
print("π FIXER MODEL TRAINING (Bulletproof v2)")
|
| 255 |
+
print(" CodeT5+ 220M | CWE-aware input | BLEU+CodeBLEU eval")
|
| 256 |
+
print(" 10 epochs | lr=5e-5 | constant scheduler | beam_search=5")
|
| 257 |
+
print(" fp16=OFF (avoids unscale error) | batch=2 | grad_accum=16")
|
| 258 |
+
print("=" * 60)
|
| 259 |
+
|
| 260 |
+
data_collator = DataCollatorForSeq2Seq(
|
| 261 |
+
tokenizer=tokenizer, model=model, padding=True, max_length=MAX_SOURCE_LENGTH,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# BULLETPROOF training args - every error from previous runs fixed:
|
| 265 |
+
training_args = Seq2SeqTrainingArguments(
|
| 266 |
+
output_dir="./fixer_output",
|
| 267 |
+
num_train_epochs=10,
|
| 268 |
+
per_device_train_batch_size=2, # T4 safe
|
| 269 |
+
per_device_eval_batch_size=2,
|
| 270 |
+
gradient_accumulation_steps=16, # effective batch = 32
|
| 271 |
+
learning_rate=5e-5, # T5 recommended (1e-4 to 3e-5 range; 5e-5 stable)
|
| 272 |
+
lr_scheduler_type="constant", # T5APR found constant > cosine for code repair
|
| 273 |
+
warmup_ratio=0.06, # ~6% of steps warmup
|
| 274 |
+
weight_decay=0.01,
|
| 275 |
+
max_grad_norm=1.0, # prevents gradient explosion
|
| 276 |
+
# fp16=False - DO NOT ENABLE: causes "Attempting to unscale FP16 gradients" error
|
| 277 |
+
# on newer accelerate + T4 GPU. Full float32 training is slower but stable.
|
| 278 |
+
eval_strategy="epoch",
|
| 279 |
+
save_strategy="epoch",
|
| 280 |
+
logging_strategy="steps",
|
| 281 |
+
logging_steps=50,
|
| 282 |
+
logging_first_step=True,
|
| 283 |
+
disable_tqdm=True, # plain text output for Kaggle
|
| 284 |
+
load_best_model_at_end=True,
|
| 285 |
+
metric_for_best_model="eval_bleu",
|
| 286 |
+
greater_is_better=True,
|
| 287 |
+
save_total_limit=3,
|
| 288 |
+
seed=SEED,
|
| 289 |
+
predict_with_generate=True,
|
| 290 |
+
generation_max_length=MAX_TARGET_LENGTH,
|
| 291 |
+
generation_num_beams=5,
|
| 292 |
+
dataloader_num_workers=2,
|
| 293 |
+
report_to="none",
|
| 294 |
+
gradient_checkpointing=True, # saves VRAM
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
trainer = Seq2SeqTrainer(
|
| 298 |
+
model=model,
|
| 299 |
+
args=training_args,
|
| 300 |
+
train_dataset=tokenized['train'],
|
| 301 |
+
eval_dataset=tokenized['validation'],
|
| 302 |
+
data_collator=data_collator,
|
| 303 |
+
compute_metrics=compute_metrics,
|
| 304 |
+
callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
# Force single GPU to avoid DataParallel OOM
|
| 308 |
+
if torch.cuda.is_available():
|
| 309 |
+
trainer.args._n_gpu = 1
|
| 310 |
+
print("β
Single GPU mode enforced (_n_gpu=1)")
|
| 311 |
+
|
| 312 |
+
print("ποΈ Training starting...")
|
| 313 |
+
trainer.train()
|
| 314 |
+
|
| 315 |
+
# %% [CELL 8] Final Evaluation on Test Set
|
| 316 |
+
print("\n" + "=" * 60)
|
| 317 |
+
print("π FINAL EVALUATION ON TEST SET")
|
| 318 |
+
print("=" * 60)
|
| 319 |
+
|
| 320 |
+
test_results = trainer.predict(tokenized['test'])
|
| 321 |
+
test_metrics = test_results.metrics
|
| 322 |
+
|
| 323 |
+
print("\nπ TEST RESULTS:")
|
| 324 |
+
for k, v in sorted(test_metrics.items()):
|
| 325 |
+
if isinstance(v, float):
|
| 326 |
+
print(f" {k}: {v:.4f}")
|
| 327 |
+
else:
|
| 328 |
+
print(f" {k}: {v}")
|
| 329 |
+
|
| 330 |
+
# Save test metrics for later
|
| 331 |
+
with open("./fixer_output/test_metrics.json", 'w') as f:
|
| 332 |
+
json.dump({k: float(v) if isinstance(v, (float, np.floating)) else v
|
| 333 |
+
for k, v in test_metrics.items()}, f, indent=2)
|
| 334 |
+
print("β
Test metrics saved")
|
| 335 |
+
|
| 336 |
+
# %% [CELL 9] Example Fixes (Qualitative)
|
| 337 |
+
print("\n" + "=" * 60)
|
| 338 |
+
print("π§ EXAMPLE FIXES (Qualitative Assessment)")
|
| 339 |
+
print("=" * 60)
|
| 340 |
+
|
| 341 |
+
test_samples = ds_fixer['test'].select(range(min(5, len(ds_fixer['test']))))
|
| 342 |
+
|
| 343 |
+
for i, sample in enumerate(test_samples):
|
| 344 |
+
cwe_name = CWE_NAMES.get(sample['cwe_id'], sample['cwe_id'])
|
| 345 |
+
input_text = f"fix {cwe_name} vulnerability in {sample['language'].lower()}: {sample['code']}"
|
| 346 |
+
|
| 347 |
+
inputs = tokenizer(input_text, return_tensors="pt", max_length=MAX_SOURCE_LENGTH, truncation=True, padding=True)
|
| 348 |
+
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 349 |
+
|
| 350 |
+
with torch.no_grad():
|
| 351 |
+
outputs = model.generate(
|
| 352 |
+
**inputs, max_length=MAX_TARGET_LENGTH,
|
| 353 |
+
num_beams=5, early_stopping=True, no_repeat_ngram_size=3,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 357 |
+
|
| 358 |
+
print(f"\n{'='*60}")
|
| 359 |
+
print(f"--- Example {i+1}: {sample['cwe_id']} ({sample['language']}) ---")
|
| 360 |
+
print(f"VULNERABLE:\n{sample['code'][:500]}")
|
| 361 |
+
print(f"\nEXPECTED FIX:\n{sample['code_fixed'][:500]}")
|
| 362 |
+
print(f"\nGENERATED FIX:\n{generated[:500]}")
|
| 363 |
+
match = "β
" if generated.strip() == sample['code_fixed'].strip() else "β"
|
| 364 |
+
print(f"Exact Match: {match}")
|
| 365 |
+
|
| 366 |
+
# %% [CELL 10] Save + Push to Hub
|
| 367 |
+
print("\n" + "=" * 60)
|
| 368 |
+
print("πΎ Saving Fixer Model to HF Hub")
|
| 369 |
+
print("=" * 60)
|
| 370 |
+
|
| 371 |
+
model.save_pretrained("./fixer_final")
|
| 372 |
+
tokenizer.save_pretrained("./fixer_final")
|
| 373 |
+
|
| 374 |
+
# Save evaluation results + config
|
| 375 |
+
eval_results = {
|
| 376 |
+
"model": MODEL_NAME,
|
| 377 |
+
"test_metrics": {k: float(v) if isinstance(v, (float, np.floating)) else v
|
| 378 |
+
for k, v in test_metrics.items()},
|
| 379 |
+
"improvements": [
|
| 380 |
+
"CWE-aware input: 'fix <vulnerability> in <language>: <code>'",
|
| 381 |
+
"BLEU + CodeBLEU + ROUGE + exact match evaluation",
|
| 382 |
+
"Beam search (num_beams=5)",
|
| 383 |
+
"Only trained on samples with actual fixes",
|
| 384 |
+
"Constant LR schedule with warmup (T5APR-optimal)",
|
| 385 |
+
"Early stopping (patience=3)",
|
| 386 |
+
"fp16=OFF (stable on T4)",
|
| 387 |
+
"Gradient accumulation (eff_batch=32)",
|
| 388 |
+
],
|
| 389 |
+
"training_data": {
|
| 390 |
+
"total_samples_with_fixes": len(ds_fixer['train']),
|
| 391 |
+
"source_dataset": DATASET_ID,
|
| 392 |
+
},
|
| 393 |
+
}
|
| 394 |
+
with open("./fixer_final/eval_results.json", 'w') as f:
|
| 395 |
+
json.dump(eval_results, f, indent=2)
|
| 396 |
+
|
| 397 |
+
api = HfApi()
|
| 398 |
+
api.upload_folder(
|
| 399 |
+
folder_path="./fixer_final",
|
| 400 |
+
repo_id=HUB_MODEL_ID,
|
| 401 |
+
commit_message="v2: Fixed tokenizer + stable training (fp32, constant LR, CWE-aware)",
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
print(f"\nβ
Fixer model pushed to: https://huggingface.co/{HUB_MODEL_ID}")
|
| 405 |
+
|
| 406 |
+
# %% [CELL 11] Push test metrics separately
|
| 407 |
+
api.upload_file(
|
| 408 |
+
path_or_fileobj="./fixer_output/test_metrics.json",
|
| 409 |
+
path_in_repo="test_metrics.json",
|
| 410 |
+
repo_id=HUB_MODEL_ID,
|
| 411 |
+
commit_message="Test metrics from v2 training",
|
| 412 |
+
)
|
| 413 |
+
print("β
Test metrics pushed")
|
| 414 |
+
|
| 415 |
+
print("\n" + "=" * 60)
|
| 416 |
+
print("π― NOTEBOOK 4 COMPLETE!")
|
| 417 |
+
print("=" * 60)
|