Spaces:
Running on Zero
Running on Zero
| """Model loading and GPU-aware inference wrappers for the GigaCheck demo.""" | |
| from __future__ import annotations | |
| import torch | |
| from loguru import logger | |
| from transformers import AutoModel | |
| from bootstrap import ensure_gigacheck | |
| from rendering import AiInterval | |
| from validation import CONFIG | |
# The remote model code fetched by AutoModel.from_pretrained(trust_remote_code=True)
# does `from gigacheck.model... import ...`, so the gigacheck package must be
# importable before any model is loaded. Keep this call ahead of the
# from_pretrained() calls further down in this module.
ensure_gigacheck()
# Prefer the real ZeroGPU decorator when the `spaces` package is importable;
# otherwise fall back to a pass-through with the same calling conventions.
try:
    import spaces

    gpu = spaces.GPU
    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

    def gpu(*args, **kwargs):
        """Pass-through replacement for :func:`spaces.GPU` used when ``spaces`` is missing.

        Handles both decorator forms without altering the decorated function:
        bare ``@gpu`` hands the function in directly and gets it back, while
        ``@gpu(...)`` (for example ``@gpu(duration=...)``) yields an identity
        decorator. Keyword arguments are accepted and ignored.
        """
        target = args[0] if args else None
        if callable(target):
            return target
        return lambda fn: fn
def select_device() -> str:
    """Choose the torch device name used to place both models.

    CUDA is reported when it is visible right now, or when the ``spaces``
    package imported successfully. On ZeroGPU the GPU only materializes inside
    decorated calls, so importability of ``spaces`` is treated as a promise
    that one will be provided.

    Returns:
        ``"cuda"`` if a GPU is (or is expected to be) available, else ``"cpu"``.
    """
    # Keep the CUDA probe first so the same short-circuit order is preserved.
    return "cuda" if (torch.cuda.is_available() or HAS_SPACES) else "cpu"
def label_index(id2label: dict[int, str], target: str) -> int:
    """Return the class index whose label matches ``target`` case-insensitively.

    Resolving probabilities by label name (rather than assuming a fixed column
    order) keeps the demo correct regardless of how ``id2label`` is ordered.
    Both the stored labels and ``target`` are lowercased before comparing, so
    ``"ai"``, ``"AI"`` and ``"Ai"`` all locate the same class.

    Args:
        id2label: Mapping of class index to label string. Keys may be any
            value accepted by ``int()``; labels are converted with ``str()``.
        target: Label to locate (e.g. ``"ai"`` or ``"human"``), any case.

    Returns:
        The first matching class index, converted to ``int``.

    Raises:
        ValueError: If no label matches ``target``.
    """
    wanted = target.lower()
    for idx, name in id2label.items():
        if str(name).lower() == wanted:
            return int(idx)
    raise ValueError(f"Label {target!r} not found in id2label={id2label}")
DEVICE = select_device()
logger.info("GigaCheck demo device: {} (spaces={})", DEVICE, HAS_SPACES)


def _load_remote_model(model_id: str, dtype: torch.dtype):
    """Load ``model_id`` with ``AutoModel.from_pretrained`` onto ``DEVICE``.

    Remote model code is trusted (``trust_remote_code=True``) and weights are
    loaded with the requested ``dtype``.
    """
    return AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        device_map=DEVICE,
        torch_dtype=dtype,
    )


logger.info("Loading classifier: {}", CONFIG.classifier_model_id)
classifier_model = _load_remote_model(CONFIG.classifier_model_id, torch.bfloat16)

# Normalize id2label to int keys and str labels so integer class indices
# (such as an argmax result) can be used directly as lookup keys.
CLASSIFIER_ID2LABEL = {
    int(class_idx): str(label_name)
    for class_idx, label_name in classifier_model.config.id2label.items()
}
CLASSIFIER_AI_IDX = label_index(CLASSIFIER_ID2LABEL, "ai")
CLASSIFIER_HUMAN_IDX = label_index(CLASSIFIER_ID2LABEL, "human")

logger.info("Loading detector: {}", CONFIG.detector_model_id)
# The detector is kept in float32; the classifier above is loaded in bfloat16.
detector_model = _load_remote_model(CONFIG.detector_model_id, torch.float32)
def classify_text(text: str) -> tuple[str, float, float]:
    """Classify a text as human-written or AI-generated.

    Args:
        text: Input text (English or Russian). Each ``"\\n"`` is replaced with
            a space before the text is passed to the classifier.

    Returns:
        A tuple ``(label, p_human, p_ai)`` where ``label`` is the
        ``id2label`` entry of the highest-probability class and the
        probabilities are floats in ``[0, 1]``.
    """
    with torch.inference_mode():
        output = classifier_model([text.replace("\n", " ")])
    # Only one text was submitted, so take row 0. Copy it to the host as
    # float32 in a single transfer; reading scalars with float()/int()
    # straight from a device tensor would force a sync for every read.
    probs = output.classification_head_probs[0].to(device="cpu", dtype=torch.float32)
    # Resolve probabilities by label name rather than assuming column order.
    p_ai = float(probs[CLASSIFIER_AI_IDX])
    p_human = float(probs[CLASSIFIER_HUMAN_IDX])
    label = CLASSIFIER_ID2LABEL[int(probs.argmax())]
    return label, p_human, p_ai
def detect_intervals(text: str, conf_threshold: float) -> list[AiInterval]:
    """Find character spans that the detector attributes to an AI model.

    Args:
        text: Input text (English or Russian), passed to the detector unchanged.
        conf_threshold: Forwarded to the detector as ``conf_interval_thresh``;
            used to decide which spans are kept.

    Returns:
        One ``(start_char, end_char, score)`` tuple per detected span, with
        integer offsets and a float score.
    """
    with torch.inference_mode():
        prediction = detector_model([text], conf_interval_thresh=conf_threshold)
    # A single text was submitted, so its spans are entry 0; bring them to the
    # host before converting individual values to Python scalars.
    spans = prediction.ai_intervals[0].cpu()
    intervals = []
    for span in spans:
        start_char, end_char = int(span[0]), int(span[1])
        intervals.append((start_char, end_char, float(span[2])))
    return intervals