# Page header captured with this file (not part of the program):
#   Pranav Pc
#   Final Deploy
#   4b82ab5
#   Raw
#   History Blame Contribute Delete
#   2.75 kB
from transformers import RobertaTokenizer
from torch.utils.data import Dataset, DataLoader
import torch
import json
from pathlib import Path
class VulnerabilityDataset(Dataset):
    """PyTorch dataset for vulnerability detection.

    The whole JSON Lines file is read into memory at construction time
    (one JSON object per non-blank line); tokenization happens lazily, one
    sample per ``__getitem__`` call.

    Each record is expected to provide:
        * ``"func"``   -- the text handed to the tokenizer.
        * ``"target"`` -- the label, converted to a ``torch.long`` tensor.

    Args:
        data_path: Path (``str`` or ``Path``) of the JSONL file.
        tokenizer: Callable invoked as ``tokenizer(code, truncation=True,
            padding="max_length", max_length=..., return_tensors="pt")``
            whose result is indexed by ``"input_ids"`` and
            ``"attention_mask"``.
        max_length: Value forwarded to the tokenizer as ``max_length``.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message names the file and the 1-based line number of the
            offending record.
    """

    def __init__(self, data_path, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []

        data_path = Path(data_path)
        if not data_path.exists():
            raise FileNotFoundError(f"Dataset file not found: {data_path}")

        with open(data_path, "r", encoding="utf-8") as f:
            for line_number, raw_line in enumerate(f, start=1):
                record = raw_line.strip()
                if not record:
                    # Blank lines are tolerated and skipped.
                    continue
                try:
                    sample = json.loads(record)
                except json.JSONDecodeError as exc:
                    # Every record is parsed in isolation, so the stock
                    # message always says "line 1". Re-raise the same
                    # exception type with the real location in the file;
                    # the trailing line/column that JSONDecodeError appends
                    # is the offset inside this single record.
                    raise json.JSONDecodeError(
                        f"{data_path}:{line_number}: {exc.msg}; "
                        "offset within record",
                        exc.doc,
                        exc.pos,
                    ) from exc
                self.data.append(sample)

        print(f"{data_path.name}: {len(self.data)} samples")

    def __len__(self):
        """Return the number of records loaded from the file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record ``idx`` and return it as a dict of tensors.

        Returns:
            A dict with keys ``"input_ids"``, ``"attention_mask"`` and
            ``"labels"``.
        """
        sample = self.data[idx]
        code = sample["func"]  # confirmed correct
        label = sample["target"]  # confirmed correct (0/1)

        encoding = self.tokenizer(
            code,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {
            # squeeze(0) drops the leading dimension of the tokenizer output
            # (expected to be a batch of one) so the DataLoader can stack
            # samples itself.
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
        }
def load_tokenizer(model_name="Salesforce/codet5-base"):
    """Print the chosen model name and load its pretrained tokenizer.

    Args:
        model_name: Identifier passed to
            ``RobertaTokenizer.from_pretrained``.

    Returns:
        The object returned by ``RobertaTokenizer.from_pretrained``.
    """
    print(f"Tokenizer: {model_name}")
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    return tokenizer
def create_dataloader(
    train_path,
    valid_path,
    test_path,
    tokenizer,
    batch_size=8,
    max_length=512,
    num_workers=2,
):
    """Build the train, validation and test ``DataLoader`` objects.

    Args:
        train_path: JSONL file backing the training split.
        valid_path: JSONL file backing the validation split.
        test_path: JSONL file backing the test split.
        tokenizer: Tokenizer handed to every ``VulnerabilityDataset``.
        batch_size: Batch size used by all three loaders.
        max_length: Tokenizer ``max_length`` forwarded to the datasets.
        num_workers: Worker process count for all three loaders; ``0``
            loads data in the main process.

    Returns:
        A ``(train_loader, valid_loader, test_loader)`` tuple. Only the
        training loader shuffles.

    Raises:
        RuntimeError: If the training split contains no samples.
    """
    train_dataset = VulnerabilityDataset(train_path, tokenizer, max_length)
    valid_dataset = VulnerabilityDataset(valid_path, tokenizer, max_length)
    test_dataset = VulnerabilityDataset(test_path, tokenizer, max_length)

    if len(train_dataset) == 0:
        raise RuntimeError(f"No samples found in {train_path}")

    # Options shared by all three loaders.
    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": True,
        # DataLoader raises ValueError when persistent_workers=True is
        # combined with num_workers=0, so enable it only when worker
        # processes actually exist.
        "persistent_workers": num_workers > 0,
    }

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, valid_loader, test_loader