import json
import logging
import os
import re
from typing import Any, Dict, List, Optional, Type, Union

from pydantic import BaseModel

from .errors import StructuredOutputError

logger = logging.getLogger(__name__)

# Problematic phrases that might cause models to add prose instead of raw JSON
MISALIGNMENT_PHRASES = [
    "explain",
    "describe",
    "why",
    "step by step",
    "formatting",
    "reasoning",
    "thought process",
]


def schema_guard(prompt: str, instruction: Optional[str] = None) -> None:
    """
    Scan the prompt and instruction for phrases that may conflict with strict JSON generation.

    Matching is a case-insensitive substring test against ``MISALIGNMENT_PHRASES``.

    Args:
        prompt: The prompt text to check.
        instruction: Optional extra instruction, checked together with ``prompt``.

    Raises:
        ValueError: If any phrase is found and the ``LLM_SCHEMA_GUARD_STRICT``
            environment variable equals ``"true"`` (case-insensitive).
            Otherwise a warning is logged and the function returns normally.
    """
    combined = (prompt + " " + (instruction or "")).lower()
    found = [phrase for phrase in MISALIGNMENT_PHRASES if phrase in combined]
    if not found:
        return

    warning_msg = f"Schema misalignment guard hit: problematic phrases found: {found}"
    # Strict mode is read on every call, so it can be toggled without a restart.
    strict_mode = os.getenv("LLM_SCHEMA_GUARD_STRICT", "false").lower() == "true"
    if strict_mode:
        logger.error("STRICT MODE: %s", warning_msg)
        raise ValueError(warning_msg)
    logger.warning("%s", warning_msg)


def get_json_instruction(
    schema: Type[BaseModel], current_instruction: Optional[str] = None
) -> str:
    """
    Build a concise but strict JSON-only instruction for ``schema``.

    Args:
        schema: Pydantic model class whose JSON schema is embedded in the instruction.
        current_instruction: Existing instruction text. It is kept, followed by
            a blank line, ahead of the JSON requirements. Empty or ``None`` means
            nothing is prepended.

    Returns:
        The combined instruction string, ending with ``Schema: <json schema>``.
    """
    json_requirements = (
        "Return ONLY valid JSON. No prose, no preamble. "
        "Must conform exactly to this schema. No extra keys."
    )
    schema_json = json.dumps(schema.model_json_schema())
    base = f"{current_instruction}\n\n" if current_instruction else ""
    return f"{base}{json_requirements}\nSchema: {schema_json}"


def extract_json(text: str) -> str:
    """
    Extract the most likely JSON object from raw model output.

    Strategy:
      1. Take the span from the first ``{`` to the last ``}``. If it parses as
         JSON, return it unchanged. This covers "prose + one object + prose"
         and nested objects.
      2. Otherwise, for example when stray braces appear in the surrounding
         prose or there are several separate objects, try to decode a complete
         object at each ``{`` and return the largest one found.
      3. If no complete object decodes, return the step-1 span, so the
         caller's own JSON parse reports the failure.

    Args:
        text: Raw model output.

    Returns:
        The extracted substring. If ``text`` contains no ``{ ... }`` span at
        all, ``text.strip()`` is returned.
    """
    first = text.find("{")
    last = text.rfind("}")
    if first == -1 or last <= first:
        return text.strip()

    candidate = text[first : last + 1]
    try:
        json.loads(candidate)
        return candidate
    except json.JSONDecodeError:
        pass

    # Fallback: decode complete objects starting at each '{'.
    # raw_decode stops at the end of the first complete value, so trailing
    # prose or extra braces do not matter.
    decoder = json.JSONDecoder()
    best: Optional[str] = None
    start = first
    while start != -1:
        try:
            _, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        span = text[start:end]
        if best is None or len(span) > len(best):
            best = span
        # Braces inside a decoded object belong to it; resume after it.
        start = text.find("{", end)

    return best if best is not None else candidate


def validate_structured_output(
    text: str, schema: Type[BaseModel], provider: str, model: str, prompt_id: str
) -> Union[Dict[str, Any], BaseModel]:
    """
    Parse the LLM output and validate it against ``schema``.

    Args:
        text: Raw model output. JSON is located in it via ``extract_json``.
        schema: Pydantic model class to instantiate from the parsed JSON object.
        provider: Provider name, recorded on any raised error.
        model: Model name, recorded on any raised error.
        prompt_id: Prompt identifier, recorded on any raised error.

    Returns:
        An instance of ``schema`` built from the parsed JSON object.

    Raises:
        StructuredOutputError: With reason ``"JSON Parse Failure"`` if the
            extracted text is not valid JSON. With reason
            ``"Schema Validation Failure"`` if the JSON is not an object or
            ``schema`` rejects it. The underlying exception, when there is
            one, is chained as ``__cause__``.
    """
    clean_text = extract_json(text)

    try:
        data = json.loads(clean_text)
    except json.JSONDecodeError as e:
        logger.error("JSON Parse Failure. Raw text between braces: %s", clean_text)
        raise StructuredOutputError(
            provider=provider,
            model=model,
            prompt_id=prompt_id,
            raw_output=text,
            reason="JSON Parse Failure",
            details=str(e),
        ) from e

    # Keyword expansion below needs a mapping. Reject arrays and scalars with
    # a clear message instead of an opaque TypeError.
    if not isinstance(data, dict):
        logger.error("Schema Validation Failure. Data: %s", data)
        raise StructuredOutputError(
            provider=provider,
            model=model,
            prompt_id=prompt_id,
            raw_output=text,
            reason="Schema Validation Failure",
            details=f"Expected a JSON object, got {type(data).__name__}",
        )

    try:
        return schema(**data)
    except Exception as e:
        # Broad on purpose: any model construction failure (pydantic
        # ValidationError, or errors from custom validators / __init__) is
        # reported as one domain error.
        logger.error("Schema Validation Failure. Data: %s", data)
        raise StructuredOutputError(
            provider=provider,
            model=model,
            prompt_id=prompt_id,
            raw_output=text,
            reason="Schema Validation Failure",
            details=str(e),
        ) from e