"""
CodeSheriff Inference Space
Minimal FastAPI server that loads the fine-tuned CodeBERT classifier
and exposes a POST /predict endpoint. Called remotely by the Render backend.
"""
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Identifier handed to ``from_pretrained`` for both the tokenizer and the model.
MODEL_ID = "jayansh21/codesheriff-bug-classifier"

# Token budget per snippet; longer inputs are truncated by the tokenizer call.
MAX_LENGTH = 512

# Class index -> human-readable label. The position in the tuple IS the class
# index, so the order must not be changed.
LABEL_NAMES = dict(
    enumerate(
        (
            "Clean",
            "Null Reference Risk",
            "Type Mismatch",
            "Security Vulnerability",
            "Logic Flaw",
        )
    )
)

# Size of the classification head; derived so it cannot drift from the table.
NUM_LABELS = len(LABEL_NAMES)
app = FastAPI(title="CodeSheriff Inference")

# The tokenizer and model are loaded once, at import time, and then shared by
# every request handled by this process.
print("Loading CodeSheriff classifier …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# ``eval()`` returns the module itself, so it can be chained onto the load;
# it switches the network to inference mode.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID,
    num_labels=NUM_LABELS,
).eval()
print("Model loaded ✅")
@app.post("/predict")
def predict(data: dict):
    """Classify a code snippet and return label, confidence, label_id.

    Args:
        data: Parsed JSON request body. The snippet is read from its
            ``"code"`` key. A missing, null, empty or whitespace-only value
            is treated as "nothing to classify".

    Returns:
        A dict with:
          * ``label`` - human-readable class name from ``LABEL_NAMES``
            (``"Unknown(<id>)"`` if the index is not in the table),
          * ``confidence`` - softmax probability of the winning class,
            rounded to 4 decimal places,
          * ``label_id`` - index of the winning class,
          * ``all_probs`` - class index (as a string) -> rounded probability.
        For empty input the model is not run: the result is ``"Clean"`` with
        confidence ``0.0`` and an empty ``all_probs``.

    Raises:
        HTTPException: 422 when ``"code"`` holds a non-empty value that is
            not a string.
    """
    # Every falsy value (missing key, null, "", 0, []) collapses to "" so
    # such requests keep getting the "Clean" default.
    code = data.get("code") or ""
    if not isinstance(code, str):
        # Without this check a value such as 123 or ["x"] would fail on
        # ``.strip()`` below and surface as an opaque 500.
        raise HTTPException(status_code=422, detail="'code' must be a string")
    if not code.strip():
        # Same keys as the normal response so clients can read ``all_probs``
        # unconditionally.
        return {
            "label": "Clean",
            "confidence": 0.0,
            "label_id": 0,
            "all_probs": {},
        }

    # NOTE(review): padding="max_length" pads every request to MAX_LENGTH
    # tokens. For a single sequence, dropping it would likely cut latency
    # without changing the prediction - confirm before changing.
    encoding = tokenizer(
        code,
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        outputs = model(**encoding)

    # One snippet in, so squeeze(0) drops the leading (batch) dimension.
    probs = torch.softmax(outputs.logits, dim=-1).squeeze(0)
    label_id = int(torch.argmax(probs).item())

    # Convert to plain Python floats once rather than per element.
    prob_values = probs.tolist()
    all_probs = {str(idx): round(p, 4) for idx, p in enumerate(prob_values)}

    return {
        "label": LABEL_NAMES.get(label_id, f"Unknown({label_id})"),
        "confidence": round(prob_values[label_id], 4),
        "label_id": label_id,
        "all_probs": all_probs,
    }
@app.get("/health")
def health():
    """Report that the server is up; always answers with a fixed status."""
    status_payload = {"status": "ok"}
    return status_payload
if __name__ == "__main__":
    # Direct execution: serve the app on all interfaces, port 7860.
    # NOTE(review): 7860 is presumably the port the hosting Space expects -
    # confirm before changing.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )
|