"""FastAPI service that scores text as human- or AI-written.

Three fine-tuned ModernBERT sequence classifiers are loaded once at import
time. ``POST /classify`` averages their softmax outputs and reports the share
of probability assigned to the human class versus all other classes.
"""

import re

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from tokenizers import Regex
from tokenizers.normalizers import Replace, Sequence, Strip
from transformers import AutoModelForSequenceClassification, AutoTokenizer

app = FastAPI(title="AI Text Detector API")
device = torch.device("cpu")

BASE_MODEL = "answerdotai/ModernBERT-base"
# Output width of the fine-tuned classification heads; must match the checkpoints.
NUM_LABELS = 41
# Index of the "human" class; every other index counts towards the AI share.
# NOTE(review): value carried over from the original code — confirm against
# the checkpoints' label mapping.
HUMAN_LABEL_INDEX = 24
# Inputs longer than this many tokens are truncated before scoring.
MAX_LENGTH = 512

_WHITESPACE_RE = re.compile(r"\s+")

_SEED12_URL = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
_SEED22_URL = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"

# Tokenizer: after the base tokenizer's own normalization (if it has one),
# collapse newline runs plus surrounding whitespace into one space, then trim.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
_base_normalizer = tokenizer.backend_tokenizer.normalizer
_normalizers = [] if _base_normalizer is None else [_base_normalizer]
_normalizers += [Replace(Regex(r'\s*\n\s*'), " "), Strip()]
tokenizer.backend_tokenizer.normalizer = Sequence(_normalizers)


def _build_classifier(state_dict):
    """Create a ModernBERT classifier, load *state_dict* into it, and put it in eval mode."""
    model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL, num_labels=NUM_LABELS
    )
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


# SECURITY NOTE(review): torch.load and torch.hub.load_state_dict_from_url
# unpickle the checkpoint, which can execute arbitrary code. Only use trusted
# sources; on torch versions that support it, consider weights_only=True.
model_1 = _build_classifier(torch.load("modernbert.bin", map_location="cpu"))
model_2 = _build_classifier(
    torch.hub.load_state_dict_from_url(_SEED12_URL, map_location="cpu")
)
model_3 = _build_classifier(
    torch.hub.load_state_dict_from_url(_SEED22_URL, map_location="cpu")
)
_ENSEMBLE = (model_1, model_2, model_3)


class TextInput(BaseModel):
    """Request body for ``POST /classify``."""

    text: str


@app.get("/")
def health():
    """Liveness probe; always returns ``{"status": "ok"}``."""
    return {"status": "ok"}


@app.post("/classify")
def classify(payload: TextInput):
    """Score ``payload.text`` with the model ensemble.

    Whitespace runs are collapsed to single spaces and the text is truncated
    to ``MAX_LENGTH`` tokens. The per-model softmax distributions are averaged.

    Returns:
        ``{"human_percentage": float, "ai_percentage": float}``, each rounded
        to two decimals. Otherwise ``{"error": "Empty text"}`` when the text is
        blank after whitespace normalization; that response is kept with status
        200 for existing clients.
    """
    text = _WHITESPACE_RE.sub(" ", payload.text).strip()
    if not text:
        return {"error": "Empty text"}

    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=MAX_LENGTH
    ).to(device)

    with torch.no_grad():
        per_model = [torch.softmax(model(**inputs).logits, dim=1) for model in _ENSEMBLE]
        probs = torch.stack(per_model).mean(dim=0)

    row = probs[0]  # single input -> batch of one
    human = row[HUMAN_LABEL_INDEX].item()
    # A softmax row sums to ~1. Normalizing by its actual sum keeps the two
    # percentages adding up to 100 despite floating-point drift.
    total = row.sum().item()
    ai = total - human

    return {
        "human_percentage": round(human / total * 100, 2),
        "ai_percentage": round(ai / total * 100, 2),
    }