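# Header text from the web page this file was exported from, kept as comments
# so the module parses (presumably: author | commit message | commit hash |
# file size; the "Raw / History / Blame / Contribute / Delete" line was page
# toolbar text):
# Pranav Pc | Final Deploy | 4b82ab5 | 2.45 kB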
"""CodeT5 vulnerability detection model.

Binary classification: Safe (0) vs. Vulnerable (1)."""
import torch
import torch.nn as nn
from transformers import T5ForConditionalGeneration, RobertaTokenizer
class VulnerabilityCodeT5(nn.Module):
    """CodeT5 encoder with an MLP classification head for vulnerability detection.

    The full ``T5ForConditionalGeneration`` checkpoint is loaded, but only its
    encoder is run in :meth:`forward`; the pooled encoder state is fed to a
    two-layer MLP that produces ``num_labels`` logits.

    NOTE(review): the decoder / LM head are loaded but never used by
    ``forward``. Swapping in an encoder-only model would save memory but would
    change the ``encoder_decoder.*`` state_dict keys of saved checkpoints, so
    it is intentionally left as-is.

    Args:
        model_name: Model id or local path passed to
            ``T5ForConditionalGeneration.from_pretrained``.
        num_labels: Number of output classes (2 for safe / vulnerable).
    """

    def __init__(self, model_name="Salesforce/codet5-base", num_labels=2):
        super().__init__()
        self.encoder_decoder = T5ForConditionalGeneration.from_pretrained(model_name)
        # Width of the encoder hidden states (768 for codet5-base).
        hidden_size = self.encoder_decoder.config.d_model
        # Classification head applied to the pooled encoder state.
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_labels),
        )
        self.num_labels = num_labels

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode a batch of token ids and classify each sequence.

        Args:
            input_ids: Tokenized code, shape ``[batch_size, seq_len]``.
            attention_mask: Attention mask, shape ``[batch_size, seq_len]``.
            labels: Optional integer class targets, shape ``[batch_size]``.
                A trailing singleton dimension (``[batch_size, 1]``) is
                flattened before computing the loss.

        Returns:
            dict with keys:
                ``'loss'``: cross-entropy loss, or ``None`` if ``labels`` is None.
                ``'logits'``: ``[batch_size, num_labels]``.
                ``'hidden_states'``: the encoder's last hidden state,
                ``[batch_size, seq_len, hidden]``.
        """
        encoder_outputs = self.encoder_decoder.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_state = encoder_outputs.last_hidden_state  # [batch, seq_len, hidden]
        # Pool by taking the first position. T5 has no dedicated [CLS] token;
        # with CodeT5's tokenizer this is presumably the leading <s> token —
        # confirm against the tokenization used at training time.
        pooled_output = hidden_state[:, 0, :]  # [batch, hidden]
        logits = self.classifier(pooled_output)  # [batch, num_labels]

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            # reshape(-1) is a no-op for [batch] targets and accepts [batch, 1].
            loss = loss_fn(logits, labels.reshape(-1))
        return {
            'loss': loss,
            'logits': logits,
            'hidden_states': hidden_state,
        }

    def predict(self, input_ids, attention_mask):
        """Return ``(predictions, probs)`` for a batch without tracking gradients.

        The module is switched to eval mode (dropout off) for the pass and
        then restored to whatever train/eval mode it was in before, so calling
        this inside a training loop does not leave dropout disabled.

        Returns:
            predictions: Predicted class index per row, shape ``[batch_size]``.
            probs: Softmax class probabilities, shape ``[batch_size, num_labels]``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                outputs = self.forward(input_ids, attention_mask)
                probs = torch.softmax(outputs["logits"], dim=1)
                predictions = torch.argmax(probs, dim=1)
        finally:
            # Restore the caller's mode even if the forward pass raised.
            self.train(was_training)
        return predictions, probs
def count_parameters(model):
    """Return how many scalar values live in *model*'s trainable parameters.

    Parameters with ``requires_grad`` set to False (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total