# GigaCheck / models.py
# Captured from the Hugging Face Hub file view
# (author: iitolstykh, commit 486c3f6 "update app", 4.24 kB).
"""Model loading and GPU-aware inference wrappers for the GigaCheck demo."""
from __future__ import annotations
import torch
from loguru import logger
from transformers import AutoModel
from bootstrap import ensure_gigacheck
from rendering import AiInterval
from validation import CONFIG
# The ``gigacheck`` package must be importable before any from_pretrained()
# call below loads the checkpoints' remote model code, because that code does
# `from gigacheck.model... import ...`. Keep this call ahead of model loading.
ensure_gigacheck()
try:
    import spaces
    gpu = spaces.GPU
except ImportError:
    HAS_SPACES = False

    def gpu(*args, **kwargs):
        """Pass-through replacement for :func:`spaces.GPU` outside ZeroGPU.

        Works as a bare decorator (``@gpu``, where the function arrives as the
        first positional argument) and as a decorator factory
        (``@gpu(duration=...)``, where all arguments are ignored).
        """
        def _identity(fn):
            return fn

        if not args or not callable(args[0]):
            # Factory form: hand back a decorator that leaves fn untouched.
            return _identity
        return _identity(args[0])
else:
    HAS_SPACES = True
def select_device() -> str:
    """Choose the device both models are loaded onto.

    On ZeroGPU, CUDA is not visible while the module is being imported; it only
    appears inside ``@spaces.GPU``-decorated calls. A successful ``spaces``
    import is therefore taken as a promise that a GPU will be available.

    NOTE(review): an environment with ``spaces`` installed but no GPU will
    still get ``"cuda"`` here.

    Returns:
        ``"cuda"`` if CUDA is visible now or ``spaces`` was imported,
        otherwise ``"cpu"``.
    """
    gpu_reachable = torch.cuda.is_available() or HAS_SPACES
    return "cuda" if gpu_reachable else "cpu"
def label_index(id2label: dict[int, str], target: str) -> int:
    """Return the class index whose label matches ``target`` case-insensitively.

    Resolving probabilities by label name (rather than assuming a fixed column
    order) keeps the demo correct regardless of how ``id2label`` is ordered.
    Both the stored labels and ``target`` are lower-cased before comparison,
    so ``"AI"``, ``"Ai"`` and ``"ai"`` all resolve to the same entry.

    Args:
        id2label: Mapping of class index to label string. Keys are passed
            through ``int()``, so numeric-string keys are accepted as well.
        target: Label to locate (e.g. ``"ai"`` or ``"human"``), in any case.

    Returns:
        The first matching class index in ``id2label`` iteration order.

    Raises:
        ValueError: If no label matches ``target``.
    """
    # Normalize the target once; previously only the stored label was
    # lower-cased, so a mixed-case target could never match.
    wanted = target.lower()
    for idx, name in id2label.items():
        if str(name).lower() == wanted:
            return int(idx)
    raise ValueError(f"Label {target!r} not found in id2label={id2label}")
# Resolve the device once at import time; both models below are placed on it.
DEVICE = select_device()
logger.info("GigaCheck demo device: {} (spaces={})", DEVICE, HAS_SPACES)

# Whole-text classifier. trust_remote_code is required because the checkpoint
# ships its own modeling code (which imports ``gigacheck``).
logger.info("Loading classifier: {}", CONFIG.classifier_model_id)
classifier_model = AutoModel.from_pretrained(
    CONFIG.classifier_model_id,
    trust_remote_code=True,
    device_map=DEVICE,
    torch_dtype=torch.bfloat16,
)
# Normalize id2label to int keys and str values so lookups by an int class
# index (e.g. an argmax result) work however the config stores them.
CLASSIFIER_ID2LABEL = {
    int(i): str(name) for i, name in classifier_model.config.id2label.items()
}
# Resolve the "ai" / "human" probability columns by name once. label_index
# raises ValueError here, at import time, if either label is missing.
CLASSIFIER_AI_IDX = label_index(CLASSIFIER_ID2LABEL, "ai")
CLASSIFIER_HUMAN_IDX = label_index(CLASSIFIER_ID2LABEL, "human")

# Span detector; unlike the classifier (bfloat16) it is loaded in float32.
# NOTE(review): the reason for the differing dtypes is not visible here — confirm.
logger.info("Loading detector: {}", CONFIG.detector_model_id)
detector_model = AutoModel.from_pretrained(
    CONFIG.detector_model_id,
    trust_remote_code=True,
    device_map=DEVICE,
    torch_dtype=torch.float32,
)
@gpu(duration=180)
def classify_text(text: str) -> tuple[str, float, float]:
    """Score a whole text as human-written vs. AI-generated.

    Newlines are replaced with spaces before the text reaches the classifier.

    Args:
        text: Input text (English or Russian).

    Returns:
        ``(label, p_human, p_ai)``. ``label`` is the raw ``id2label`` string of
        the highest-probability class. The two probabilities are Python floats
        in ``[0, 1]``.
    """
    single_line = text.replace("\n", " ")
    with torch.inference_mode():
        result = classifier_model([single_line])
        # Upcast to float32 (the classifier is loaded in bfloat16) before
        # extracting Python scalars.
        class_probs = result.classification_head_probs[0].to(torch.float32)
        # Columns are looked up by label name, not by assumed position.
        ai_score = float(class_probs[CLASSIFIER_AI_IDX])
        human_score = float(class_probs[CLASSIFIER_HUMAN_IDX])
        top_class = int(class_probs.argmax())
    return CLASSIFIER_ID2LABEL[top_class], human_score, ai_score
@gpu(duration=180)
def detect_intervals(text: str, conf_threshold: float) -> list[AiInterval]:
    """Locate character ranges in ``text`` that the detector flags as AI-written.

    Args:
        text: Input text (English or Russian).
        conf_threshold: Forwarded to the detector as ``conf_interval_thresh``.

    Returns:
        One ``(start_char, end_char, score)`` tuple per detected span. Offsets
        are converted with ``int()`` and the score with ``float()``.
    """
    with torch.inference_mode():
        prediction = detector_model([text], conf_interval_thresh=conf_threshold)
        # A single text was submitted, so take batch entry 0 and move it to
        # the CPU before converting elements to Python scalars.
        spans = prediction.ai_intervals[0].cpu()
        intervals = []
        for row in spans:
            # NOTE(review): int() truncates, so fractional offsets (if the
            # detector ever emits them) lose their fractional part.
            intervals.append((int(row[0]), int(row[1]), float(row[2])))
    return intervals