from __future__ import annotations

from copy import deepcopy
from typing import Iterable, get_args

from .models import CategoryName, Verdict

# Every category name declared by ``CategoryName``, in declaration order. This
# order is the canonical one used when the caller passes an unordered set.
_ALL_CATEGORY_NAMES: tuple[CategoryName, ...] = get_args(CategoryName)

# Unmodified JSON schema of ``Verdict``; always deep-copied before editing.
_BASE_VERDICT_SCHEMA = Verdict.model_json_schema()

# Memoized request schemas, keyed by the normalized tuple of requested names.
_REQUEST_SCHEMA_CACHE: dict[tuple[CategoryName, ...], dict] = {}

# Hard cap that lm-format-enforcer applies during decoding so the model cannot
# loop indefinitely inside the reason string. Sized for the 95th-percentile
# reason length in the v2 SFT data with headroom; raise if real reasons get
# truncated mid-sentence in practice.
REASON_MAX_LENGTH = 400

# Cap written into the schema as ``maxLength`` for the ``Match.text`` string.
MATCH_TEXT_MAX_LENGTH = 200


def build_request_schema(requested: Iterable[CategoryName]) -> dict:
    """Return the ``Verdict`` JSON schema narrowed to the requested categories.

    The returned schema differs from the base ``Verdict`` schema as follows:

    * ``reason`` and ``Match.text`` get ``maxLength`` caps.
    * ``$defs/Category.name`` is restricted to the requested names.
    * ``categories`` must contain exactly ``len(requested)`` entries, and the
      entry at position *i* must be named after the *i*-th requested category
      (enforced through ``prefixItems``).

    Args:
        requested: Category names to allow. Sequence inputs keep their order;
            ``set``/``frozenset`` inputs are put in ``CategoryName``
            declaration order.

    Returns:
        A fresh schema dict. It is a deep copy, so callers may mutate it
        without affecting the module-level cache.

    Raises:
        ValueError: If ``requested`` is empty, contains duplicates, or
            contains names that are not part of ``CategoryName``.
    """
    ordered = _normalize_requested(requested)

    cached = _REQUEST_SCHEMA_CACHE.get(ordered)
    if cached is not None:
        # Hand out a copy so the caller cannot corrupt the cached entry.
        return deepcopy(cached)

    schema = deepcopy(_BASE_VERDICT_SCHEMA)
    schema["properties"]["reason"]["maxLength"] = REASON_MAX_LENGTH

    match_def = schema["$defs"]["Match"]
    match_def["properties"]["text"]["maxLength"] = MATCH_TEXT_MAX_LENGTH

    category_def = schema["$defs"]["Category"]
    category_def["properties"]["name"]["enum"] = list(ordered)

    # One inlined Category schema per position, each pinned to a single name,
    # so the output lists the categories in exactly the requested order.
    prefix_items = []
    for name in ordered:
        item_schema = deepcopy(category_def)
        item_schema["properties"]["name"]["enum"] = [name]
        prefix_items.append(item_schema)

    categories_schema = schema["properties"]["categories"]
    categories_schema["prefixItems"] = prefix_items
    categories_schema["minItems"] = len(ordered)
    categories_schema["maxItems"] = len(ordered)
    # With maxItems == len(prefixItems) no element can reach ``items``.
    # NOTE(review): presumably kept for consumers that ignore ``prefixItems``
    # -- confirm before removing.
    categories_schema["items"] = deepcopy(category_def)

    # Store a private copy; the caller receives (and may mutate) ``schema``.
    _REQUEST_SCHEMA_CACHE[ordered] = deepcopy(schema)
    return schema


def _normalize_requested(
    requested: Iterable[CategoryName],
) -> tuple[CategoryName, ...]:
    """Validate ``requested`` and return it as a hashable, ordered tuple.

    Ordered inputs (lists, tuples, generators, ...) keep the caller's order.
    Unordered inputs (``set``/``frozenset``) are rearranged into
    ``CategoryName`` declaration order so that equal sets always map to the
    same cache key and the same schema.

    Validation runs on the input as given, before any reordering, so unknown
    names in a set are reported rather than silently discarded.

    Raises:
        ValueError: If ``requested`` is empty, contains duplicates, or
            contains names that are not part of ``CategoryName``.
    """
    is_unordered = isinstance(requested, (set, frozenset))
    names = tuple(requested)

    if not names:
        raise ValueError("requested categories must not be empty")

    if len(set(names)) != len(names):
        raise ValueError(
            "requested categories must be unique; got duplicates in "
            f"{list(names)}"
        )

    unknown = [name for name in names if name not in _ALL_CATEGORY_NAMES]
    if unknown:
        if is_unordered:
            # Set iteration order is arbitrary; keep the message stable.
            unknown.sort(key=repr)
        raise ValueError(f"unknown requested categories: {unknown}")

    if is_unordered:
        return tuple(name for name in _ALL_CATEGORY_NAMES if name in requested)
    return names