Pranav Pc
Final Deploy
4b82ab5
Raw
History Blame Contribute Delete
8.27 kB
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.amp import autocast, GradScaler
from transformers import get_linear_schedule_with_warmup
from pathlib import Path
from tqdm import tqdm
import argparse
import json
import gc
import sys
sys.path.append(str(Path(__file__).parent.parent))
from src.v2.data_processor import load_tokenizer, create_dataloader
from src.v2.model import VulnerabilityCodeT5, count_parameters
class Trainer:
def __init__(
self,
model,
train_loader,
valid_loader,
device,
learning_rate=2e-5,
num_epochs=5,
gradient_accumulation_steps=4,
):
self.model = model.to(device)
self.train_loader = train_loader
self.valid_loader = valid_loader
self.device = device
self.num_epochs = num_epochs
self.gradient_accumulation_steps = gradient_accumulation_steps
self.use_amp = device.type == "cuda"
self.scaler = GradScaler(enabled=self.use_amp)
self.optimizer = AdamW(
self.model.parameters(), lr=learning_rate, weight_decay=0.01
)
total_steps = (
len(self.train_loader) * num_epochs
) // gradient_accumulation_steps
self.scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps=max(1, total_steps // 10),
num_training_steps=total_steps,
)
self.best_val_acc = 0.0
self.history = {
"train_loss": [],
"train_acc": [],
"val_loss": [],
"val_acc": [],
}
def clear_memory(self):
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
def train_epoch(self):
self.model.train()
total_loss = 0.0
correct = 0
total = 0
self.optimizer.zero_grad(set_to_none=True)
pbar = tqdm(self.train_loader, desc="Training")
for step, batch in enumerate(pbar):
input_ids = batch["input_ids"].to(self.device, non_blocking=True)
attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
labels = batch["labels"].to(self.device, non_blocking=True)
with autocast(device_type="cuda", enabled=self.use_amp):
outputs = self.model(input_ids, attention_mask, labels)
loss = outputs["loss"] / self.gradient_accumulation_steps
self.scaler.scale(loss).backward()
if (step + 1) % self.gradient_accumulation_steps == 0:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.scaler.step(self.optimizer)
self.scaler.update()
self.scheduler.step()
self.optimizer.zero_grad(set_to_none=True)
with torch.no_grad():
preds = torch.argmax(outputs["logits"], dim=1)
correct += (preds == labels).sum().item()
total += labels.size(0)
total_loss += loss.item() * self.gradient_accumulation_steps
gpu_mem = (
torch.cuda.memory_allocated() / 1024 ** 3
if torch.cuda.is_available()
else 0
)
pbar.set_postfix(
{
"loss": f"{loss.item() * self.gradient_accumulation_steps:.4f}",
"acc": f"{100 * correct / max(1, total):.2f}%",
"gpu": f"{gpu_mem:.2f}GB",
}
)
del input_ids, attention_mask, labels, outputs, loss
self.clear_memory()
return total_loss / len(self.train_loader), 100 * correct / total
def validate(self):
self.model.eval()
total_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for batch in tqdm(self.valid_loader, desc="Validating"):
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
labels = batch["labels"].to(self.device)
with autocast(device_type="cuda", enabled=self.use_amp):
outputs = self.model(input_ids, attention_mask, labels)
loss = outputs["loss"]
preds = torch.argmax(outputs["logits"], dim=1)
correct += (preds == labels).sum().item()
total += labels.size(0)
total_loss += loss.item()
self.clear_memory()
return total_loss / len(self.valid_loader), 100 * correct / total
def train(self, save_dir="models/v2"):
print(f"Training samples: {len(self.train_loader.dataset)}")
print(f"Validation samples: {len(self.valid_loader.dataset)}")
if torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name(0)}")
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
for epoch in range(self.num_epochs):
print(f"\n{'=' * 60}")
print(f"Epoch {epoch + 1}/{self.num_epochs}")
print(f"{'=' * 60}")
train_loss, train_acc = self.train_epoch()
val_loss, val_acc = self.validate()
print(
f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%"
)
print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
self.history["train_loss"].append(train_loss)
self.history["train_acc"].append(train_acc)
self.history["val_loss"].append(val_loss)
self.history["val_acc"].append(val_acc)
if val_acc > self.best_val_acc:
self.best_val_acc = val_acc
torch.save(
{
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"val_acc": val_acc,
},
save_dir / "best_model.pt",
)
print("Saved best model")
torch.save(
{
"model_state_dict": self.model.state_dict(),
"history": self.history,
},
save_dir / "final_model.pt",
)
with open(save_dir / "training_history.json", "w") as f:
json.dump(self.history, f, indent=2)
print(f"\nTraining complete. Best Val Acc: {self.best_val_acc:.2f}%")
def main(args):
    """Build the tokenizer, data loaders and model from CLI ``args``, then train.

    Args:
        args: parsed command-line namespace providing ``model_name``,
            ``batch_size``, ``max_length``, ``learning_rate``, ``epochs``,
            ``gradient_accumulation``, ``output_dir`` and ``use_sample``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.use_sample:
        data_root = Path("data/processed/sample")
    else:
        data_root = Path("data/processed")
    # Positional order expected by create_dataloader: train, valid, test.
    split_files = [
        data_root / f"{split}.jsonl" for split in ("train", "valid", "test")
    ]

    tokenizer = load_tokenizer(args.model_name)
    # The test loader is built alongside the others but not used here.
    train_loader, valid_loader, _test_loader = create_dataloader(
        *split_files,
        tokenizer,
        batch_size=args.batch_size,
        max_length=args.max_length,
        num_workers=2,
    )

    classifier = VulnerabilityCodeT5(model_name=args.model_name, num_labels=2)
    print(f"Trainable parameters: {count_parameters(classifier):,}")
    Trainer(
        classifier,
        train_loader,
        valid_loader,
        device,
        learning_rate=args.learning_rate,
        num_epochs=args.epochs,
        gradient_accumulation_steps=args.gradient_accumulation,
    ).train(args.output_dir)
if __name__ == "__main__":
    # Command-line options, registered in this order so --help output is stable.
    cli = argparse.ArgumentParser()
    for flag, options in (
        ("--model_name", {"default": "Salesforce/codet5-base"}),
        ("--batch_size", {"type": int, "default": 4}),
        ("--max_length", {"type": int, "default": 256}),
        ("--learning_rate", {"type": float, "default": 2e-5}),
        ("--epochs", {"type": int, "default": 3}),
        ("--gradient_accumulation", {"type": int, "default": 4}),
        ("--output_dir", {"default": "models/v2"}),
        ("--use_sample", {"action": "store_true"}),
    ):
        cli.add_argument(flag, **options)
    main(cli.parse_args())