Spaces:
Running
Running
File size: 3,236 Bytes
557ee65 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 | import json
import re
import logging
from typing import Type, Any, Union, Optional, List, Dict
from pydantic import BaseModel
import os
from .errors import StructuredOutputError
logger = logging.getLogger(__name__)

# Wording that tends to make a model wrap its answer in prose rather than
# emitting raw JSON. Matched as lowercase substrings by schema_guard().
MISALIGNMENT_PHRASES = [
    "explain",
    "describe",
    "why",
    "step by step",
    "formatting",
    "reasoning",
    "thought process",
]
def schema_guard(prompt: str, instruction: Optional[str] = None) -> None:
    """
    Flag prompt wording that may conflict with strict JSON generation.

    The prompt and the optional instruction are joined with a space,
    lowercased, and searched for every entry of MISALIGNMENT_PHRASES
    (plain substring match).

    When at least one phrase is present:
      * if the environment variable LLM_SCHEMA_GUARD_STRICT equals "true"
        (case-insensitive), the hit is logged at ERROR level and a
        ValueError is raised;
      * otherwise the hit is logged at WARNING level and the call returns.

    Raises:
        ValueError: a problematic phrase was found and strict mode is on.
    """
    haystack = " ".join((prompt, instruction or "")).lower()

    hits = []
    for phrase in MISALIGNMENT_PHRASES:
        if phrase in haystack:
            hits.append(phrase)

    if not hits:
        return

    message = f"Schema misalignment guard hit: problematic phrases found: {hits}"

    # Strict mode is opt-in through the environment; anything but "true" is lenient.
    strict_flag = os.getenv("LLM_SCHEMA_GUARD_STRICT", "false")
    if strict_flag.lower() != "true":
        logger.warning(message)
        return

    logger.error(f"STRICT MODE: {message}")
    raise ValueError(message)
def get_json_instruction(schema: Type[BaseModel], current_instruction: Optional[str] = None) -> str:
    """
    Build the strict "JSON only" instruction for a schema.

    The result is, in order: the caller's existing instruction followed by a
    blank line (only when one was supplied and is non-empty), the fixed JSON
    requirements sentence, and a "Schema: ..." line holding the JSON dump of
    ``schema.model_json_schema()``.
    """
    pieces = []

    # Keep whatever the caller already asked for ahead of the JSON rules.
    if current_instruction:
        pieces.append(f"{current_instruction}\n\n")

    pieces.append("Return ONLY valid JSON. No prose, no preamble. ")
    pieces.append("Must conform exactly to this schema. No extra keys.")
    pieces.append("\nSchema: ")
    pieces.append(json.dumps(schema.model_json_schema()))

    return "".join(pieces)
def extract_json(text: str) -> str:
    """
    Extract a JSON object embedded in free-form text.

    Strategy:
      1. Take the span from the first '{' to the last '}'. If that span
         parses as JSON it is returned as-is.
      2. Otherwise (stray braces in surrounding prose, trailing junk, ...)
         try each '{' from left to right and return the first substring
         that decodes as a complete JSON value.
      3. If nothing parses, return the span from step 1 unchanged so the
         caller's own json.loads reports the failure.

    When the text has no '{' followed later by a '}', the stripped text is
    returned.
    """
    first = text.find("{")
    last = text.rfind("}")
    if first == -1 or last <= first:
        return text.strip()

    candidate = text[first : last + 1]
    try:
        json.loads(candidate)
    except ValueError:
        # Fall through to the scan below; JSONDecodeError is a ValueError.
        pass
    else:
        return candidate

    # The outer span is not valid JSON. Look for an embedded object by
    # decoding from each opening brace; raw_decode stops at the end of the
    # first complete value and ignores whatever follows it.
    decoder = json.JSONDecoder()
    start = first
    while start != -1:
        try:
            _, end = decoder.raw_decode(text, start)
        except ValueError:
            start = text.find("{", start + 1)
        else:
            return text[start:end]

    return candidate
def validate_structured_output(
    text: str, schema: Type[BaseModel], provider: str, model: str, prompt_id: str
) -> Union[Dict[str, Any], BaseModel]:
    """
    Parse LLM output as JSON and validate it against a schema.

    The text is first reduced with extract_json(), decoded with json.loads,
    and the resulting object is passed to ``schema(**data)``.

    Args:
        text: Raw model output.
        schema: Model class instantiated with the decoded keys.
        provider, model, prompt_id: Passed through to StructuredOutputError
            so a failure can be attributed.

    Returns:
        The instance produced by ``schema(**data)``.

    Raises:
        StructuredOutputError: with reason "JSON Parse Failure" when the
            text cannot be decoded, or "Schema Validation Failure" when the
            decoded value is not a JSON object or the schema rejects it.
            The underlying exception, when there is one, is chained as
            ``__cause__``.
    """
    clean_text = extract_json(text)

    try:
        data = json.loads(clean_text)
    except json.JSONDecodeError as e:
        logger.error("JSON Parse Failure. Raw text between braces: %s", clean_text)
        raise StructuredOutputError(
            provider=provider,
            model=model,
            prompt_id=prompt_id,
            raw_output=text,
            reason="JSON Parse Failure",
            details=str(e),
        ) from e

    # A bare list/number/string is valid JSON but cannot be splatted into the
    # schema; report that directly instead of surfacing an opaque TypeError.
    if not isinstance(data, dict):
        logger.error("Schema Validation Failure. Data: %s", data)
        raise StructuredOutputError(
            provider=provider,
            model=model,
            prompt_id=prompt_id,
            raw_output=text,
            reason="Schema Validation Failure",
            details=f"Expected a JSON object at the top level, got {type(data).__name__}",
        )

    try:
        return schema(**data)
    except Exception as e:
        # Boundary catch: any failure while building the model is converted
        # into the domain error, keeping the original exception as the cause.
        logger.error("Schema Validation Failure. Data: %s", data)
        raise StructuredOutputError(
            provider=provider,
            model=model,
            prompt_id=prompt_id,
            raw_output=text,
            reason="Schema Validation Failure",
            details=str(e),
        ) from e
|