File size: 3,236 Bytes
557ee65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import json
import re
import logging
from typing import Type, Any, Union, Optional, List, Dict
from pydantic import BaseModel
import os

from .errors import StructuredOutputError

# Module-level logger used by the guard and validation helpers in this module.
logger = logging.getLogger(__name__)

# Phrases that might push a model to add prose (explanations, reasoning)
# instead of returning raw JSON. The guard lower-cases the text it scans and
# tests each entry as a plain substring, so entries must stay lower-case.
# Substring matching also hits inflected forms such as "explained" or
# "describes".
MISALIGNMENT_PHRASES = [
    "explain",
    "describe",
    "why",
    "step by step",
    "formatting",
    "reasoning",
    "thought process",
]


def schema_guard(prompt: str, instruction: Optional[str] = None) -> None:
    """
    Scan the prompt and instruction for phrases that might conflict with strict JSON generation.

    ``prompt`` and ``instruction`` are scanned separately (case-insensitively,
    as plain substrings). Scanning them separately means a phrase cannot be
    "found" where the end of one text meets the start of the other. For
    example, a prompt ending in "step" followed by an instruction starting
    with "by step" no longer triggers a false match.

    Args:
        prompt: Prompt text to check; ``None`` is treated as empty.
        instruction: Optional extra instruction text; ``None`` is treated as empty.

    Raises:
        ValueError: If any phrase is found and the ``LLM_SCHEMA_GUARD_STRICT``
            environment variable is "true", "1", "yes" or "on". The check is
            case-insensitive and ignores surrounding whitespace. Outside strict
            mode a warning is logged instead.
    """
    texts = [(prompt or "").lower(), (instruction or "").lower()]
    found = [
        phrase
        for phrase in MISALIGNMENT_PHRASES
        if any(phrase in text for text in texts)
    ]
    if not found:
        return

    # Strict mode is opt-in via environment variable; accept common truthy spellings.
    strict_flag = os.getenv("LLM_SCHEMA_GUARD_STRICT", "false").strip().lower()
    strict_mode = strict_flag in {"true", "1", "yes", "on"}
    warning_msg = f"Schema misalignment guard hit: problematic phrases found: {found}"

    if strict_mode:
        logger.error("STRICT MODE: %s", warning_msg)
        raise ValueError(warning_msg)
    logger.warning("%s", warning_msg)


def get_json_instruction(schema: Type[BaseModel], current_instruction: Optional[str] = None) -> str:
    """
    Build a strict "JSON only" instruction for ``schema``, keeping any prior instruction.

    Args:
        schema: Model class whose ``model_json_schema()`` output is embedded as JSON.
        current_instruction: Existing instruction text. When truthy, it comes
            first, separated from the JSON requirements by a blank line. ``None``
            and ``""`` are omitted.

    Returns:
        The (optional) existing instruction, the JSON requirements sentence, and
        a final line starting with "Schema: " followed by the serialized schema.
    """
    requirement_text = (
        "Return ONLY valid JSON. No prose, no preamble. "
        "Must conform exactly to this schema. No extra keys."
    )
    serialized_schema = json.dumps(schema.model_json_schema())

    pieces = []
    if current_instruction:
        pieces.append(f"{current_instruction}\n\n")
    pieces.append(requirement_text)
    pieces.append("\nSchema: ")
    pieces.append(serialized_schema)
    return "".join(pieces)


def extract_json(text: str) -> str:
    """
    Robustly extract the largest JSON object embedded in ``text``.

    Model replies often wrap JSON in prose or Markdown fences, and that prose
    can itself contain braces (``"Use {name}: {...}"`` or a trailing
    ``"{placeholder}"``). Slicing from the first ``{`` to the last ``}`` breaks
    in those cases. Instead, every ``{`` is tried as the start of a JSON value
    with ``json.JSONDecoder.raw_decode``, and the longest span that decodes to
    a complete object is returned (the earliest wins on ties).

    If no ``{`` starts a decodable object, the previous heuristic is used as a
    fallback:
    - the slice from the first ``{`` to the last ``}`` when both exist in that
      order;
    - otherwise ``text`` stripped of surrounding whitespace.

    Whenever that old slice was itself valid JSON, the result is unchanged.

    Args:
        text: Raw model output.

    Returns:
        The extracted JSON substring, or the fallback text described above.
    """
    decoder = json.JSONDecoder()
    best_start, best_end = -1, -1

    start = text.find("{")
    while start != -1:
        try:
            _, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        if end - start > best_end - best_start:
            best_start, best_end = start, end
        # Resume scanning after this value rather than inside it: this keeps the
        # scan roughly linear and avoids returning nested fragments.
        start = text.find("{", end)

    if best_start != -1:
        return text[best_start:best_end]

    # Fallback: nothing decodes; return the widest brace-delimited span so the
    # caller's parse error reports the most relevant text.
    first = text.find("{")
    last = text.rfind("}")
    if first != -1 and last != -1 and last > first:
        return text[first : last + 1]

    return text.strip()


def validate_structured_output(
    text: str, schema: Type[BaseModel], provider: str, model: str, prompt_id: str
) -> Union[Dict[str, Any], BaseModel]:
    """
    Parse the LLM output as JSON and validate it against ``schema``.

    Args:
        text: Raw model output; the JSON part is located with ``extract_json``.
        schema: Model class, instantiated as ``schema(**data)`` with the parsed object.
        provider: Copied onto any ``StructuredOutputError`` raised.
        model: Copied onto any ``StructuredOutputError`` raised.
        prompt_id: Copied onto any ``StructuredOutputError`` raised.

    Returns:
        The ``schema`` instance built from the parsed JSON. Although the return
        annotation allows a ``Dict``, this function only ever returns the instance.

    Raises:
        StructuredOutputError: One of two reasons:
            - ``reason="JSON Parse Failure"`` when the extracted text is not valid JSON;
            - ``reason="Schema Validation Failure"`` when the JSON is not an
              object or ``schema`` rejects it.
            The triggering exception, if any, is chained as ``__cause__``.
    """
    clean_text = extract_json(text)

    try:
        data = json.loads(clean_text)
    except json.JSONDecodeError as e:
        logger.error("JSON Parse Failure. Raw text between braces: %s", clean_text)
        raise StructuredOutputError(
            provider=provider,
            model=model,
            prompt_id=prompt_id,
            raw_output=text,
            reason="JSON Parse Failure",
            details=str(e),
        ) from e

    # `schema(**data)` requires a mapping. Report arrays/scalars with a clear
    # message instead of an opaque "argument after ** must be a mapping" TypeError.
    if not isinstance(data, dict):
        logger.error("Schema Validation Failure. Data: %s", data)
        raise StructuredOutputError(
            provider=provider,
            model=model,
            prompt_id=prompt_id,
            raw_output=text,
            reason="Schema Validation Failure",
            details=f"Expected a JSON object, got {type(data).__name__}",
        )

    try:
        return schema(**data)
    except Exception as e:
        # Broad on purpose: whatever `schema` raises while being constructed is
        # reported to the caller as a schema validation failure.
        logger.error("Schema Validation Failure. Data: %s", data)
        raise StructuredOutputError(
            provider=provider,
            model=model,
            prompt_id=prompt_id,
            raw_output=text,
            reason="Schema Validation Failure",
            details=str(e),
        ) from e