"""Model loading and GPU-aware inference wrappers for the GigaCheck demo.""" from __future__ import annotations import torch from loguru import logger from transformers import AutoModel from bootstrap import ensure_gigacheck from rendering import AiInterval from validation import CONFIG # gigacheck must be importable before any from_pretrained() loads the remote # model code, which does `from gigacheck.model... import ...`. ensure_gigacheck() try: import spaces gpu = spaces.GPU HAS_SPACES = True except ImportError: HAS_SPACES = False def gpu(*args, **kwargs): """No-op stand-in for :func:`spaces.GPU` when ``spaces`` is unavailable. Supports both bare ``@gpu`` and parameterized ``@gpu(duration=...)`` usage. """ if args and callable(args[0]): return args[0] def wrap(fn): return fn return wrap def select_device() -> str: """Pick the inference device, preferring GPU when it is reachable. On ZeroGPU, CUDA is not visible at import time but is materialized inside the decorated function, so the presence of the ``spaces`` package implies a GPU. Returns: ``"cuda"`` when a GPU is (or will be) available, otherwise ``"cpu"``. """ if torch.cuda.is_available() or HAS_SPACES: return "cuda" return "cpu" def label_index(id2label: dict[int, str], target: str) -> int: """Return the class index whose label matches ``target`` case-insensitively. Resolving probabilities by label name (rather than assuming a fixed column order) keeps the demo correct regardless of how ``id2label`` is ordered. Args: id2label: Mapping of class index to label string. target: Lowercase label to locate (e.g. ``"ai"`` or ``"human"``). Returns: The matching class index. Raises: ValueError: If no label matches ``target``. """ for idx, name in id2label.items(): if str(name).lower() == target: return int(idx) raise ValueError(f"Label {target!r} not found in id2label={id2label}") DEVICE = select_device() logger.info("GigaCheck demo device: {} (spaces={})", DEVICE, HAS_SPACES) logger.info("Loading classifier: {}", CONFIG.classifier_model_id) classifier_model = AutoModel.from_pretrained( CONFIG.classifier_model_id, trust_remote_code=True, device_map=DEVICE, torch_dtype=torch.bfloat16, ) CLASSIFIER_ID2LABEL = { int(i): str(name) for i, name in classifier_model.config.id2label.items() } CLASSIFIER_AI_IDX = label_index(CLASSIFIER_ID2LABEL, "ai") CLASSIFIER_HUMAN_IDX = label_index(CLASSIFIER_ID2LABEL, "human") logger.info("Loading detector: {}", CONFIG.detector_model_id) detector_model = AutoModel.from_pretrained( CONFIG.detector_model_id, trust_remote_code=True, device_map=DEVICE, torch_dtype=torch.float32, ) @gpu(duration=180) def classify_text(text: str) -> tuple[str, float, float]: """Classify a text as human-written or AI-generated. Args: text: Input text (English or Russian). Returns: A tuple ``(label, p_human, p_ai)`` where ``label`` is the raw model label and the probabilities are floats in ``[0, 1]``. """ with torch.inference_mode(): output = classifier_model([text.replace("\n", " ")]) # Resolve probabilities by label name rather than assuming column order. probs = output.classification_head_probs[0].to(torch.float32) p_ai = float(probs[CLASSIFIER_AI_IDX]) p_human = float(probs[CLASSIFIER_HUMAN_IDX]) label = CLASSIFIER_ID2LABEL[int(probs.argmax())] return label, p_human, p_ai @gpu(duration=180) def detect_intervals(text: str, conf_threshold: float) -> list[AiInterval]: """Detect character spans likely written by an AI model. Args: text: Input text (English or Russian). conf_threshold: Confidence threshold for keeping a span. Returns: A list of ``(start_char, end_char, score)`` tuples. """ with torch.inference_mode(): output = detector_model([text], conf_interval_thresh=conf_threshold) raw = output.ai_intervals[0].cpu() return [(int(span[0]), int(span[1]), float(span[2])) for span in raw]