File size: 4,674 Bytes
c0fff99
 
c8dea05
 
c0fff99
 
 
 
 
c8dea05
c0fff99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8dea05
 
 
 
 
 
 
 
 
 
c0fff99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
JSON output parser for LLM responses.
Uses json_repair for malformed JSON, and llm_output_parser as fallback
to extract JSON from mixed text/markdown LLM output.
"""

import logging
from collections.abc import Mapping
from json_repair import repair_json
from llm_output_parser import parse_json as extract_json

logger = logging.getLogger(__name__)

# Root-level keys of each agent's response schema, grouped by agent.
# _coerce_list_root checks the objects of a parsed top-level array against this
# set: an object carrying any of these keys is treated as a (partial) root
# object to merge, rather than as an element of a list.
_TOP_LEVEL_KEYS = set().union(
    # Diagnostician
    ("findings", "differential_diagnoses"),
    # Bias detector
    (
        "discrepancy_summary",
        "identified_biases",
        "missed_findings",
        "agreement_points",
    ),
    # Devil's advocate
    ("challenges", "must_not_miss", "recommended_workup"),
    # Consultant
    (
        "consultation_note",
        "alternative_diagnoses",
        "immediate_actions",
        "confidence_note",
    ),
)


def parse_json_response(text: str) -> dict:
    """
    Extract and repair JSON from an LLM response.
    Handles: raw JSON, ```json blocks, missing commas, truncated output, etc.
    Returns parsed dict. Raises ValueError if repair fails completely.
    """
    result = repair_json(text, return_objects=True)

    # Typical (desired) case: top-level object.
    if isinstance(result, Mapping):
        return dict(result)

    # Some model outputs come back as a top-level array. Coerce to a dict so
    # downstream code can continue, while preserving the payload for callers to
    # interpret (via 'items') when schema keys are missing.
    if isinstance(result, list):
        return _coerce_list_root(result)

    # Fallback: json_repair returned a plain string (model output natural language).
    # Use llm_output_parser to extract JSON from mixed text/markdown.
    if isinstance(result, str):
        logger.warning("json_repair returned str, trying llm_output_parser extraction")
        extracted = extract_json(text, allow_incomplete=True, strict=False)
        if isinstance(extracted, Mapping):
            return dict(extracted)
        if isinstance(extracted, list):
            return _coerce_list_root(extracted)

    raise ValueError(
        f"Could not parse JSON from LLM output (got {type(result).__name__}, length={len(text)})"
    )


def _coerce_list_root(items: list) -> dict:
    """
    Coerce a top-level JSON array into a dict.

    Decision order:
      * empty list -> ``{"items": []}``;
      * no objects in the list -> ``{"items": items}`` (payload kept as-is);
      * any object carries a key from ``_TOP_LEVEL_KEYS`` -> all objects are
        shallow-merged (later keys win) and non-object entries are dropped;
      * every entry is an object, and either there is only one or the key sets
        of the first 10 objects overlap by at least 35% (Jaccard) -> the objects
        are wrapped under ``_infer_list_container_key``'s key, or ``"items"``;
      * every entry is an object but key sets barely overlap -> shallow merge;
      * mixed objects / non-objects -> ``{"items": [...]}`` with objects
        converted to plain dicts.
    """
    if not items:
        return {"items": []}

    # Convert each mapping exactly once; every branch below reuses these copies.
    dicts = [dict(x) for x in items if isinstance(x, Mapping)]
    if not dicts:
        return {"items": items}

    merged: dict = {}
    for d in dicts:
        merged.update(d)

    # If the extracted objects already contain known top-level schema keys, it's
    # likely a wrapped/duplicated object (or multiple partial objects). Merge.
    if any(_TOP_LEVEL_KEYS.intersection(d) for d in dicts):
        return merged

    if len(dicts) == len(items):
        # Distinguish between (a) a true list of repeated schema items, vs (b)
        # multiple standalone JSON objects extracted from a noisy response, by
        # the Jaccard overlap of the key sets of the first 10 objects.
        key_sets = [set(d) for d in dicts[:10]]
        union = set().union(*key_sets)
        intersection = key_sets[0].intersection(*key_sets[1:])
        # An empty union means every sampled object is empty, i.e. the key sets
        # are identical: count that as full overlap so e.g. [{}, {}] is handled
        # like [{}] instead of collapsing to {}.
        overlap_ratio = len(intersection) / len(union) if union else 1.0

        if len(items) == 1 or overlap_ratio >= 0.35:
            return {_infer_list_container_key(dicts) or "items": dicts}

        # Low overlap between objects: treat as multiple extracted JSON objects.
        return merged

    # Mixed list: preserve non-mapping items, but coerce mappings to dict.
    return {"items": [dict(x) if isinstance(x, Mapping) else x for x in items]}


def _infer_list_container_key(items: list[Mapping]) -> str | None:
    keys: set[str] = set()
    for item in items[:5]:
        keys.update(str(k) for k in item.keys())

    # Diagnostician
    if {"finding", "description"} & keys:
        return "findings"
    if "reasoning" in keys:
        return "differential_diagnoses"

    # Bias detector
    if {"type", "severity"} <= keys or ("type" in keys and "severity" in keys):
        return "identified_biases"

    # Devil's advocate
    if "claim" in keys or "counter_evidence" in keys:
        return "challenges"
    if {"why_dangerous", "rule_out_test", "supporting_signs"} & keys:
        return "must_not_miss"

    # Consultant
    if {"urgency", "next_step"} & keys:
        return "alternative_diagnoses"

    return None