vijaym's picture
Upload folder using huggingface_hub
434e2be verified
from __future__ import annotations
from copy import deepcopy
from typing import Iterable, get_args
from .models import CategoryName, Verdict
# Every value carried by the ``CategoryName`` type, in the order
# ``typing.get_args`` reports them (presumably a ``Literal[...]`` alias --
# confirm in ``.models``). Used both as the validation whitelist and as the
# canonical ordering for frozenset requests.
_ALL_CATEGORY_NAMES: tuple[CategoryName, ...] = get_args(CategoryName)
# JSON schema of ``Verdict``, generated once at import time. It is never
# mutated directly: per-request schemas are built from deep copies of it.
_BASE_VERDICT_SCHEMA = Verdict.model_json_schema()
# Memo of finished per-request schemas, keyed by the normalized tuple of
# requested category names. Entries are deep-copied on the way in and out.
_REQUEST_SCHEMA_CACHE: dict[tuple[CategoryName, ...], dict] = {}
# Hard cap that lm-format-enforcer applies during decoding so the model cannot
# loop indefinitely inside the reason string. Sized for the 95th-percentile
# reason length in the v2 SFT data with headroom; raise it if real reasons get
# truncated mid-sentence in practice.
REASON_MAX_LENGTH = 400
# ``maxLength`` written onto the ``text`` property of the ``Match`` definition
# in every request schema.
MATCH_TEXT_MAX_LENGTH = 200
def build_request_schema(requested: Iterable[CategoryName]) -> dict:
    """Return the ``Verdict`` JSON schema narrowed to *requested* categories.

    The base schema is copied and then tightened:

    * ``reason`` and ``Match.text`` receive ``maxLength`` caps;
    * ``Category.name`` is restricted to the requested names;
    * ``categories`` must contain exactly one entry per requested name, in
      the normalized order, with position *i* pinned to name *i* through
      ``prefixItems``.

    Results are memoized per normalized request. Both the stored entry and
    every returned value are independent deep copies, so callers may mutate
    what they receive without affecting later calls.

    Raises:
        ValueError: propagated from ``_normalize_requested`` when the
            request is rejected.
    """
    key = _normalize_requested(requested)

    hit = _REQUEST_SCHEMA_CACHE.get(key)
    if hit is not None:
        # Never hand out the cached object itself.
        return deepcopy(hit)

    result = deepcopy(_BASE_VERDICT_SCHEMA)

    # Length caps on the free-text fields.
    result["properties"]["reason"]["maxLength"] = REASON_MAX_LENGTH
    result["$defs"]["Match"]["properties"]["text"]["maxLength"] = MATCH_TEXT_MAX_LENGTH

    # Narrow the shared Category definition to the requested names.
    category = result["$defs"]["Category"]
    category["properties"]["name"]["enum"] = list(key)

    def pinned_to(name):
        """Copy the Category schema with ``name`` fixed to a single value."""
        entry = deepcopy(category)
        entry["properties"]["name"]["enum"] = [name]
        return entry

    # Exactly one Category per requested name, in order.
    count = len(key)
    result["properties"]["categories"].update(
        {
            "prefixItems": [pinned_to(name) for name in key],
            "minItems": count,
            "maxItems": count,
            "items": deepcopy(category),
        }
    )

    _REQUEST_SCHEMA_CACHE[key] = deepcopy(result)
    return result
def _normalize_requested(requested: Iterable[CategoryName]) -> tuple[CategoryName, ...]:
    """Validate *requested* and return it as a tuple usable as a cache key.

    Unordered inputs (``set`` / ``frozenset``) are arranged in the order of
    ``_ALL_CATEGORY_NAMES`` so that equal sets always normalize to the same
    tuple regardless of hash iteration order. Any other iterable keeps the
    order supplied by the caller.

    Raises:
        ValueError: if the request is empty, contains duplicates, or names a
            category that is not in ``_ALL_CATEGORY_NAMES``.
    """
    unordered = isinstance(requested, (set, frozenset))
    names = tuple(requested)

    if not names:
        raise ValueError("requested categories must not be empty")
    if len(set(names)) != len(names):
        raise ValueError(
            "requested categories must be unique; got duplicates in "
            f"{list(names)}"
        )

    # Validate before canonicalizing: filtering ``_ALL_CATEGORY_NAMES`` by
    # membership would otherwise silently discard unknown names.
    unknown = [name for name in names if name not in _ALL_CATEGORY_NAMES]
    if unknown:
        if unordered:
            # Set iteration order is arbitrary; keep the message stable.
            unknown.sort(key=repr)
        raise ValueError(f"unknown requested categories: {unknown}")

    if unordered:
        return tuple(name for name in _ALL_CATEGORY_NAMES if name in requested)
    return names