syeedalireza's picture
Upload folder using huggingface_hub
35a2e1c verified
"""
Code quality classifier: encode with CodeBERT (or similar), then linear head.
"""
import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
class CodeQualityClassifier(nn.Module):
    """Pretrained transformer encoder followed by dropout and a linear head.

    The encoder's hidden state at sequence position 0 is used as the pooled
    representation of each input; it is passed through dropout and projected
    to ``num_labels`` outputs by a single linear layer.

    Args:
        model_name: Identifier handed to ``AutoConfig.from_pretrained`` and
            ``AutoModel.from_pretrained``.
        num_labels: Number of output features of the linear head.
        dropout: Dropout probability applied to the pooled vector.
    """

    def __init__(
        self,
        model_name: str = "microsoft/codebert-base",
        num_labels: int = 2,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # The head's input width follows the encoder's configured hidden size.
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: "torch.Tensor | None" = None,
        token_type_ids: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Encode the inputs and return the linear head's output.

        Args:
            input_ids: Token ids for the encoder (presumably shaped
                ``(batch, seq_len)`` -- confirm against the tokenizer output).
            attention_mask: Optional mask forwarded to the encoder.
            token_type_ids: Optional segment ids. Forwarded to the encoder
                only when provided.

        Returns:
            The output of the linear head applied to the position-0 hidden
            state, with ``num_labels`` features in the last dimension.
        """
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        # Some encoders loadable through AutoModel (e.g. DistilBERT) do not
        # declare a ``token_type_ids`` parameter, so passing it -- even as
        # None -- raises TypeError. Only forward it when the caller set it.
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids

        out = self.encoder(**encoder_kwargs)
        # Take the first position of every sequence as the pooled vector.
        pooled = out.last_hidden_state[:, 0, :]
        pooled = self.dropout(pooled)
        return self.classifier(pooled)