File size: 2,285 Bytes
80a6f14
169f5b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80a6f14
 
 
169f5b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import re

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from tokenizers import Regex
from tokenizers.normalizers import Replace, Sequence, Strip
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# FastAPI application serving the ensemble AI-text detector.
app = FastAPI(title="AI Text Detector API")

# Inference target; all state dicts below are loaded with map_location="cpu".
# NOTE(review): `device` is never referenced in this chunk — confirm it is
# used elsewhere or remove it.
device = torch.device("cpu")

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

# Extend the stock normalizer: additionally fold every newline (and any
# surrounding whitespace) into a single space, then strip the ends. This
# mirrors the re.sub(r"\s+", " ", ...) cleanup applied to request text.
tokenizer.backend_tokenizer.normalizer = Sequence([
    tokenizer.backend_tokenizer.normalizer,
    Replace(Regex(r'\s*\n\s*'), " "),
    Strip()
])

# Models
# Three ModernBERT-base sequence classifiers sharing one architecture
# (41 labels); /classify soft-votes by averaging their softmax outputs.
# NOTE(review): num_labels=41 while the remote checkpoint names say
# "3class" — confirm the classification head size matches the saved
# weights, otherwise load_state_dict will fail at startup.
model_1 = AutoModelForSequenceClassification.from_pretrained(
    "answerdotai/ModernBERT-base", num_labels=41
)
# Model 1: weights from a local file shipped alongside the app.
model_1.load_state_dict(torch.load("modernbert.bin", map_location="cpu"))
model_1.eval()

model_2 = AutoModelForSequenceClassification.from_pretrained(
    "answerdotai/ModernBERT-base", num_labels=41
)
# Model 2: weights fetched (and cached by torch.hub) from Hugging Face.
model_2.load_state_dict(
    torch.hub.load_state_dict_from_url(
        "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12",
        map_location="cpu"
    )
)
model_2.eval()

model_3 = AutoModelForSequenceClassification.from_pretrained(
    "answerdotai/ModernBERT-base", num_labels=41
)
# Model 3: same source as model 2, different training seed (seed22).
model_3.load_state_dict(
    torch.hub.load_state_dict_from_url(
        "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22",
        map_location="cpu"
    )
)
model_3.eval()


class TextInput(BaseModel):
    """Request body for POST /classify."""

    # Raw text to analyze; whitespace is normalized server-side.
    text: str


@app.get("/")
def health():
    """Liveness probe: report that the service is up."""
    payload = {"status": "ok"}
    return payload


@app.post("/classify")
def classify(payload: TextInput):
    """Classify text as human- or AI-written.

    Runs a soft-voting ensemble: each of the three fine-tuned ModernBERT
    models produces a softmax distribution over 41 labels; the three
    distributions are averaged, then split into a "human" share (label
    index 24, per the original mapping) and an "AI" share (all other
    labels combined).

    Returns:
        dict with ``human_percentage`` and ``ai_percentage`` (rounded to
        two decimals; they sum to ~100).

    Raises:
        HTTPException: 400 when the supplied text is empty or whitespace.
    """
    # Collapse whitespace runs to single spaces; mirrors the tokenizer's
    # extended normalizer so both paths see the same text.
    text = re.sub(r"\s+", " ", payload.text).strip()

    if not text:
        # Fix: previously returned HTTP 200 with an ad-hoc {"error": ...}
        # body; an invalid request must surface as a 4xx status.
        raise HTTPException(status_code=400, detail="Empty text")

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    # Soft voting: average the per-model softmax distributions.
    with torch.no_grad():
        probs = sum(
            torch.softmax(model(**inputs).logits, dim=1)
            for model in (model_1, model_2, model_3)
        ) / 3

    # Index 24 is treated as the human-written class; everything else
    # counts toward AI. `total` is ~1.0 but we renormalize explicitly to
    # guard against floating-point drift in the averaged distribution.
    human = probs[0][24].item()
    ai = probs[0].sum().item() - human
    total = human + ai

    return {
        "human_percentage": round((human / total) * 100, 2),
        "ai_percentage": round((ai / total) * 100, 2),
    }