| from transformers import RobertaTokenizer |
| from torch.utils.data import Dataset, DataLoader |
| import torch |
| import json |
| from pathlib import Path |
|
|
|
|
class VulnerabilityDataset(Dataset):
    """PyTorch dataset for vulnerability detection.

    Reads a JSON Lines file (one JSON object per non-blank line) fully into
    memory at construction time. Each record must provide a ``"func"`` entry
    (the text handed to the tokenizer) and a ``"target"`` entry (the label).
    Tokenization is done lazily, per sample, in ``__getitem__``.
    """

    def __init__(self, data_path, tokenizer, max_length=512):
        """Load every record from *data_path*.

        Args:
            data_path: Location of the JSON Lines file (``str`` or ``Path``).
            tokenizer: Callable invoked in ``__getitem__`` with the sample
                text plus ``truncation``, ``padding``, ``max_length`` and
                ``return_tensors`` keyword arguments.
            max_length: Value forwarded to the tokenizer as ``max_length``.

        Raises:
            FileNotFoundError: If *data_path* does not exist.
            json.JSONDecodeError: If a non-blank line is not valid JSON. The
                message names the file and the 1-based line number.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.data = []
        data_path = Path(data_path)

        if not data_path.exists():
            raise FileNotFoundError(f"Dataset file not found: {data_path}")

        with open(data_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    # Blank lines are tolerated and skipped.
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as err:
                    # Each line is parsed on its own, so the decoder always
                    # reports "line 1"; add the real file position while
                    # keeping the same exception type for callers.
                    raise json.JSONDecodeError(
                        f"{data_path.name} line {line_no}: {err.msg}",
                        err.doc,
                        err.pos,
                    ) from err
                self.data.append(record)

        print(f"{data_path.name}: {len(self.data)} samples")

    def __len__(self):
        """Return the number of records loaded from the file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize the record at *idx* and return it as a dict of tensors.

        Returns:
            A dict with ``"input_ids"`` and ``"attention_mask"`` taken from
            the tokenizer output, and ``"labels"`` holding the record's
            ``"target"`` value as a ``torch.long`` tensor.

        Raises:
            KeyError: If the record lacks ``"func"`` or ``"target"``.
        """
        sample = self.data[idx]

        code = sample["func"]
        label = sample["target"]

        encoding = self.tokenizer(
            code,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # The tokenizer is expected to return tensors with a leading batch
        # dimension of size 1; squeeze(0) drops it so the DataLoader can
        # stack samples into a batch itself.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
        }
|
|
|
|
def load_tokenizer(model_name="Salesforce/codet5-base"):
    """Announce and load the tokenizer identified by *model_name*.

    Prints the model name, then returns the result of
    ``RobertaTokenizer.from_pretrained(model_name)``.
    """
    print(f"Tokenizer: {model_name}")
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    return tokenizer
|
|
|
|
def create_dataloader(
    train_path,
    valid_path,
    test_path,
    tokenizer,
    batch_size=8,
    max_length=512,
    num_workers=2,
):
    """Build the train, validation and test ``DataLoader`` objects.

    Args:
        train_path: JSON Lines file with the training samples.
        valid_path: JSON Lines file with the validation samples.
        test_path: JSON Lines file with the test samples.
        tokenizer: Tokenizer passed through to each ``VulnerabilityDataset``.
        batch_size: Batch size used by all three loaders.
        max_length: Maximum sequence length passed to each dataset.
        num_workers: Worker process count for all three loaders; ``0``
            loads data in the main process.

    Returns:
        A ``(train_loader, valid_loader, test_loader)`` tuple. Only the
        training loader shuffles.

    Raises:
        RuntimeError: If the training dataset contains no samples.
    """
    train_dataset = VulnerabilityDataset(train_path, tokenizer, max_length)
    valid_dataset = VulnerabilityDataset(valid_path, tokenizer, max_length)
    test_dataset = VulnerabilityDataset(test_path, tokenizer, max_length)

    if len(train_dataset) == 0:
        raise RuntimeError(f"No samples found in {train_path}")

    # Options shared by all three loaders; only `shuffle` differs.
    loader_options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": True,
        # DataLoader raises ValueError when persistent_workers=True is
        # combined with num_workers=0, so enable it only when worker
        # processes actually exist.
        "persistent_workers": num_workers > 0,
    }

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

    return train_loader, valid_loader, test_loader
|
|