""" Inference script for UnixCoder-512 ===================================== Usage: Simply run this script with your code samples """ import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoTokenizer, AutoModel, AutoConfig, AutoModelForSequenceClassification from safetensors.torch import load_file import numpy as np DEVICE = "cuda" if torch.cuda.is_available() else "cpu" CLASS_NAMES = ["Human", "AI-Generated", "Hybrid", "Adversarial"] class UnixCoderModel(nn.Module): def __init__(self, config): super().__init__() from transformers import RobertaModel self.encoder = RobertaModel(config) self.classifier = nn.Linear(config.hidden_size, 4) def forward(self, input_ids, attention_mask): return self.classifier(self.encoder(input_ids, attention_mask=attention_mask)[0][:, 0, :]) def load_model(): """Load the model and tokenizer""" from transformers import RobertaConfig from huggingface_hub import hf_hub_download repo = "YoungDSMLKZ/UnixCoder-512" config = RobertaConfig.from_pretrained(repo) tokenizer = AutoTokenizer.from_pretrained(repo) model = UnixCoderModel(config) weights_path = hf_hub_download(repo_id=repo, filename="model.safetensors") weights = load_file(weights_path) model.load_state_dict({k.replace("unixcoder.", "encoder."): v for k, v in weights.items()}) model.to(DEVICE).eval() return model, tokenizer def predict(code: str, model, tokenizer) -> dict: """Predict class for a single code sample""" inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512, padding=True).to(DEVICE) with torch.no_grad(): logits = model(inputs["input_ids"], inputs["attention_mask"]) probs = F.softmax(logits, dim=-1)[0] pred = torch.argmax(probs).item() return {"class": CLASS_NAMES[pred], "confidence": probs[pred].item()} if __name__ == "__main__": print("Loading model...") model, tokenizer = load_model() # Example usage test_code = """ def hello_world(): print("Hello, World!") """ result = predict(test_code, model, tokenizer) print(f"Predicted: {result['class']} (confidence: {result['confidence']:.2%})")