| import torch |
| import torch.nn as nn |
| from torch.optim import AdamW |
| from torch.amp import autocast, GradScaler |
| from transformers import get_linear_schedule_with_warmup |
| from pathlib import Path |
| from tqdm import tqdm |
| import argparse |
| import json |
| import gc |
| import sys |
|
|
| sys.path.append(str(Path(__file__).parent.parent)) |
|
|
| from src.v2.data_processor import load_tokenizer, create_dataloader |
| from src.v2.model import VulnerabilityCodeT5, count_parameters |
|
|
|
|
| class Trainer: |
| def __init__( |
| self, |
| model, |
| train_loader, |
| valid_loader, |
| device, |
| learning_rate=2e-5, |
| num_epochs=5, |
| gradient_accumulation_steps=4, |
| ): |
| self.model = model.to(device) |
| self.train_loader = train_loader |
| self.valid_loader = valid_loader |
| self.device = device |
| self.num_epochs = num_epochs |
| self.gradient_accumulation_steps = gradient_accumulation_steps |
|
|
| self.use_amp = device.type == "cuda" |
| self.scaler = GradScaler(enabled=self.use_amp) |
|
|
| self.optimizer = AdamW( |
| self.model.parameters(), lr=learning_rate, weight_decay=0.01 |
| ) |
|
|
| total_steps = ( |
| len(self.train_loader) * num_epochs |
| ) // gradient_accumulation_steps |
|
|
| self.scheduler = get_linear_schedule_with_warmup( |
| self.optimizer, |
| num_warmup_steps=max(1, total_steps // 10), |
| num_training_steps=total_steps, |
| ) |
|
|
| self.best_val_acc = 0.0 |
| self.history = { |
| "train_loss": [], |
| "train_acc": [], |
| "val_loss": [], |
| "val_acc": [], |
| } |
|
|
| def clear_memory(self): |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| gc.collect() |
|
|
| def train_epoch(self): |
| self.model.train() |
| total_loss = 0.0 |
| correct = 0 |
| total = 0 |
|
|
| self.optimizer.zero_grad(set_to_none=True) |
|
|
| pbar = tqdm(self.train_loader, desc="Training") |
|
|
| for step, batch in enumerate(pbar): |
| input_ids = batch["input_ids"].to(self.device, non_blocking=True) |
| attention_mask = batch["attention_mask"].to(self.device, non_blocking=True) |
| labels = batch["labels"].to(self.device, non_blocking=True) |
|
|
| with autocast(device_type="cuda", enabled=self.use_amp): |
| outputs = self.model(input_ids, attention_mask, labels) |
| loss = outputs["loss"] / self.gradient_accumulation_steps |
|
|
| self.scaler.scale(loss).backward() |
|
|
| if (step + 1) % self.gradient_accumulation_steps == 0: |
| self.scaler.unscale_(self.optimizer) |
| torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) |
|
|
| self.scaler.step(self.optimizer) |
| self.scaler.update() |
| self.scheduler.step() |
| self.optimizer.zero_grad(set_to_none=True) |
|
|
| with torch.no_grad(): |
| preds = torch.argmax(outputs["logits"], dim=1) |
| correct += (preds == labels).sum().item() |
| total += labels.size(0) |
|
|
| total_loss += loss.item() * self.gradient_accumulation_steps |
|
|
| gpu_mem = ( |
| torch.cuda.memory_allocated() / 1024 ** 3 |
| if torch.cuda.is_available() |
| else 0 |
| ) |
|
|
| pbar.set_postfix( |
| { |
| "loss": f"{loss.item() * self.gradient_accumulation_steps:.4f}", |
| "acc": f"{100 * correct / max(1, total):.2f}%", |
| "gpu": f"{gpu_mem:.2f}GB", |
| } |
| ) |
|
|
| del input_ids, attention_mask, labels, outputs, loss |
|
|
| self.clear_memory() |
|
|
| return total_loss / len(self.train_loader), 100 * correct / total |
|
|
| def validate(self): |
| self.model.eval() |
| total_loss = 0.0 |
| correct = 0 |
| total = 0 |
|
|
| with torch.no_grad(): |
| for batch in tqdm(self.valid_loader, desc="Validating"): |
| input_ids = batch["input_ids"].to(self.device) |
| attention_mask = batch["attention_mask"].to(self.device) |
| labels = batch["labels"].to(self.device) |
|
|
| with autocast(device_type="cuda", enabled=self.use_amp): |
| outputs = self.model(input_ids, attention_mask, labels) |
| loss = outputs["loss"] |
|
|
| preds = torch.argmax(outputs["logits"], dim=1) |
| correct += (preds == labels).sum().item() |
| total += labels.size(0) |
| total_loss += loss.item() |
|
|
| self.clear_memory() |
| return total_loss / len(self.valid_loader), 100 * correct / total |
|
|
| def train(self, save_dir="models/v2"): |
| print(f"Training samples: {len(self.train_loader.dataset)}") |
| print(f"Validation samples: {len(self.valid_loader.dataset)}") |
| if torch.cuda.is_available(): |
| print(f"GPU: {torch.cuda.get_device_name(0)}") |
|
|
| save_dir = Path(save_dir) |
| save_dir.mkdir(parents=True, exist_ok=True) |
|
|
| for epoch in range(self.num_epochs): |
| print(f"\n{'=' * 60}") |
| print(f"Epoch {epoch + 1}/{self.num_epochs}") |
| print(f"{'=' * 60}") |
|
|
| train_loss, train_acc = self.train_epoch() |
| val_loss, val_acc = self.validate() |
|
|
| print( |
| f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%" |
| ) |
| print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%") |
|
|
| self.history["train_loss"].append(train_loss) |
| self.history["train_acc"].append(train_acc) |
| self.history["val_loss"].append(val_loss) |
| self.history["val_acc"].append(val_acc) |
|
|
| if val_acc > self.best_val_acc: |
| self.best_val_acc = val_acc |
| torch.save( |
| { |
| "model_state_dict": self.model.state_dict(), |
| "optimizer_state_dict": self.optimizer.state_dict(), |
| "val_acc": val_acc, |
| }, |
| save_dir / "best_model.pt", |
| ) |
| print("Saved best model") |
|
|
| torch.save( |
| { |
| "model_state_dict": self.model.state_dict(), |
| "history": self.history, |
| }, |
| save_dir / "final_model.pt", |
| ) |
|
|
| with open(save_dir / "training_history.json", "w") as f: |
| json.dump(self.history, f, indent=2) |
|
|
| print(f"\nTraining complete. Best Val Acc: {self.best_val_acc:.2f}%") |
|
|
|
|
def main(args):
    """Train VulnerabilityCodeT5 from parsed command-line *args*.

    Picks CUDA when available, otherwise CPU. Reads the train/valid/test JSONL
    splits from ``data/processed`` (or ``data/processed/sample`` when
    ``args.use_sample`` is set) and builds a two-label model. It then runs
    ``Trainer.train``, which saves its outputs under ``args.output_dir``.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    if args.use_sample:
        data_dir = Path("data/processed/sample")
    else:
        data_dir = Path("data/processed")

    splits = (
        data_dir / "train.jsonl",
        data_dir / "valid.jsonl",
        data_dir / "test.jsonl",
    )

    tokenizer = load_tokenizer(args.model_name)

    # The test loader is built alongside the others but is not used for training.
    train_loader, valid_loader, _test_loader = create_dataloader(
        *splits,
        tokenizer,
        batch_size=args.batch_size,
        max_length=args.max_length,
        num_workers=2,
    )

    model = VulnerabilityCodeT5(model_name=args.model_name, num_labels=2)
    print(f"Trainable parameters: {count_parameters(model):,}")

    trainer_options = {
        "learning_rate": args.learning_rate,
        "num_epochs": args.epochs,
        "gradient_accumulation_steps": args.gradient_accumulation,
    }
    trainer = Trainer(model, train_loader, valid_loader, device, **trainer_options)

    trainer.train(args.output_dir)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point. Every option has a default, so the script can
    # run with no arguments; each option carries help text for `--help`.
    parser = argparse.ArgumentParser(
        description="Train the VulnerabilityCodeT5 classifier."
    )
    parser.add_argument(
        "--model_name",
        default="Salesforce/codet5-base",
        help="Pretrained model name passed to load_tokenizer and VulnerabilityCodeT5.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Batch size for the train/valid/test data loaders.",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=256,
        help="Maximum sequence length passed to create_dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-5,
        help="Peak AdamW learning rate (with linear warmup schedule).",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="Number of training epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation",
        type=int,
        default=4,
        help="Number of batches accumulated per optimizer step.",
    )
    parser.add_argument(
        "--output_dir",
        default="models/v2",
        help="Directory for checkpoints and training history.",
    )
    parser.add_argument(
        "--use_sample",
        action="store_true",
        help="Use data/processed/sample instead of data/processed.",
    )

    main(parser.parse_args())
|
|