File size: 1,994 Bytes
325e5a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
Model loaders for the AI detection pipeline.

Uses `desklib/ai-text-detector-v1.01` — a DeBERTa-v3-large classifier that
currently tops the RAID benchmark for modern LLM detection (ChatGPT, Claude,
Gemini, Llama, Grok, etc). The model ships a custom head, so we load it via
the `DesklibAIDetectionModel` wrapper defined in `utils.desklib_model`.
"""
import logging
from functools import lru_cache

import torch
from transformers import AutoTokenizer

from utils.desklib_model import DesklibAIDetectionModel

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Model identifier handed to `from_pretrained` when loading both the
# tokenizer and the detector model.
DETECTOR_MODEL_ID = "desklib/ai-text-detector-v1.01"


@lru_cache(maxsize=1)
def load_detector_model():
    """Load the desklib AI detector (DeBERTa-v3-large + custom head).

    Picks the compute device in order of preference: CUDA, then Apple MPS,
    then CPU. The model is moved to that device and switched to eval mode.

    Returns:
        A ``(model, tokenizer, device)`` tuple. First call downloads
        ~1.75 GB and caches it under `~/.cache/huggingface`. Because of
        the ``lru_cache`` decorator, subsequent calls return the same
        in-process tuple without reloading.
    """
    # `torch.backends.mps` does not exist on older PyTorch builds; look it
    # up defensively so those builds fall through to CPU instead of
    # raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    logger.info("Loading detector %s on %s", DETECTOR_MODEL_ID, device)
    tokenizer = AutoTokenizer.from_pretrained(DETECTOR_MODEL_ID)
    model = DesklibAIDetectionModel.from_pretrained(DETECTOR_MODEL_ID)
    model.to(device)
    model.eval()  # inference only
    logger.info("Detector ready")
    return model, tokenizer, device


def predict_ai_probability(text, model, tokenizer, device, max_len=768):
    """Score `text` with the detector and return the AI-generated probability.

    Args:
        text: Input passed to `tokenizer` as its first positional argument.
        model: Callable accepting `input_ids` and `attention_mask` keyword
            arguments and returning a mapping with a ``"logits"`` entry.
        tokenizer: Callable producing a mapping with ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        device: Device the encoded tensors are moved to before the
            forward pass.
        max_len: Length the input is padded/truncated to (default 768).

    Returns:
        The sigmoid of the model's logits as a Python float in 0..1.
        `.item()` is used, so the logits must hold exactly one element.
    """
    tokenizer_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": max_len,
        "return_tensors": "pt",
    }

    # Gradient tracking is disabled for the whole scoring pass.
    with torch.no_grad():
        batch = tokenizer(text, **tokenizer_options)
        ids, mask = (
            batch[key].to(device) for key in ("input_ids", "attention_mask")
        )
        logits = model(input_ids=ids, attention_mask=mask)["logits"]
        probability = torch.sigmoid(logits)

    return probability.item()