# Author: honeybansal23 — commit 169f5b8 ("added model")
from fastapi import FastAPI
from pydantic import BaseModel
import torch
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tokenizers.normalizers import Sequence, Replace, Strip
from tokenizers import Regex
app = FastAPI(title="AI Text Detector API")
# NOTE(review): `device` is not referenced by the code shown here; the models
# below are loaded with map_location="cpu" and never moved. Kept so any
# external importer of this name keeps working.
device = torch.device("cpu")

# Tokenizer: keep ModernBERT's own normalizer (if it has one) and append two
# steps: collapse any whitespace run containing a newline into a single space,
# then strip leading/trailing whitespace.
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
_normalizer_steps = [
    Replace(Regex(r'\s*\n\s*'), " "),
    Strip(),
]
_base_normalizer = tokenizer.backend_tokenizer.normalizer
# The backend normalizer is optional and may be None; normalizers.Sequence
# cannot wrap None, so only prepend it when the tokenizer actually has one.
if _base_normalizer is not None:
    _normalizer_steps.insert(0, _base_normalizer)
tokenizer.backend_tokenizer.normalizer = Sequence(_normalizer_steps)
# Models
# Three ModernBERT-base sequence classifiers with a 41-way head. Each one gets
# its fine-tuned weights from a different state dict and is put in eval mode.
_BASE_CHECKPOINT = "answerdotai/ModernBERT-base"
_NUM_LABELS = 41


def _build_detector(state_dict):
    """Return a ModernBERT classifier with *state_dict* loaded, in eval mode.

    ``load_state_dict`` is strict by default, so the checkpoint must match the
    41-label head exactly; any missing/unexpected key or shape mismatch raises.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        _BASE_CHECKPOINT, num_labels=_NUM_LABELS
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


# SECURITY NOTE(review): torch.load and torch.hub.load_state_dict_from_url
# unpickle the checkpoint. On torch versions where weights_only defaults to
# False, a malicious file can execute arbitrary code. Only point these at
# trusted sources, or pass weights_only=True once the installed torch version
# is confirmed to support it for both calls.
model_1 = _build_detector(torch.load("modernbert.bin", map_location="cpu"))
model_2 = _build_detector(
    torch.hub.load_state_dict_from_url(
        "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12",
        map_location="cpu",
    )
)
model_3 = _build_detector(
    torch.hub.load_state_dict_from_url(
        "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22",
        map_location="cpu",
    )
)
class TextInput(BaseModel):
    """Request body accepted by ``POST /classify``."""

    # Raw input text to be scored by the detector ensemble.
    text: str
@app.get("/")
def health():
    """Liveness probe: answer with a fixed status payload."""
    status_payload = {"status": "ok"}
    return status_payload
@app.post("/classify")
def classify(payload: TextInput):
    """Score ``payload.text`` and split the ensemble probability into human vs AI.

    Steps:
      1. Collapse every whitespace run to a single space and strip the ends.
      2. Tokenize to PyTorch tensors, truncating to 512 tokens.
      3. Run model_1, model_2 and model_3, softmax each output over dim 1,
         and average the three probability tensors.
      4. Read class index 24 of the first row as the "human" probability;
         all remaining mass in that row counts as "ai".

    Returns:
        ``{"human_percentage": ..., "ai_percentage": ...}`` as percentages of
        the row total rounded to 2 decimals, or ``{"error": "Empty text"}``
        when nothing remains after whitespace normalisation.
    """
    cleaned = re.sub(r"\s+", " ", payload.text).strip()
    if not cleaned:
        # Reported in the response body rather than via an HTTP error status.
        return {"error": "Empty text"}

    encoded = tokenizer(
        cleaned,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    ensemble = (model_1, model_2, model_3)
    with torch.no_grad():
        per_model = [
            torch.softmax(member(**encoded).logits, dim=1) for member in ensemble
        ]
        averaged = sum(per_model) / 3

    row = averaged[0]
    human_score = row[24].item()  # index 24 is treated as the "human" label
    # Everything except the human class; the softmax row sums to ~1, so the
    # division below just renormalises away floating-point drift.
    ai_score = row.sum().item() - human_score
    denom = human_score + ai_score

    human_pct = round(human_score / denom * 100, 2)
    ai_pct = round(ai_score / denom * 100, 2)
    return {
        "human_percentage": human_pct,
        "ai_percentage": ai_pct,
    }