Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| JSON output parser for LLM responses. | |
| Uses json_repair for malformed JSON, and llm_output_parser as fallback | |
| to extract JSON from mixed text/markdown LLM output. | |
| """ | |
| import logging | |
| from collections.abc import Mapping | |
| from json_repair import repair_json | |
| from llm_output_parser import parse_json as extract_json | |
logger = logging.getLogger(__name__)

# Root-level keys of each agent's response schema, grouped by agent.
# Used to recognise array elements that are really (fragments of) the root
# object rather than entries of a list.
_TOP_LEVEL_KEYS = set().union(
    # Diagnostician
    ("findings", "differential_diagnoses"),
    # Bias detector
    ("discrepancy_summary", "identified_biases", "missed_findings", "agreement_points"),
    # Devil's advocate
    ("challenges", "must_not_miss", "recommended_workup"),
    # Consultant
    ("consultation_note", "alternative_diagnoses", "immediate_actions", "confidence_note"),
)
def parse_json_response(text: str) -> dict:
    """
    Extract and repair JSON from an LLM response.

    Handles: raw JSON, ```json blocks, missing commas, truncated output, etc.

    Strategy:
      1. Run ``json_repair.repair_json`` with ``return_objects=True``.
      2. A top-level object is returned as a plain ``dict``.
      3. A top-level array is coerced to a dict via ``_coerce_list_root``.
      4. If json_repair only produced a plain string, retry with
         ``llm_output_parser`` to pull JSON out of mixed text/markdown and
         apply the same object/array handling to its result.

    Args:
        text: Raw LLM output.

    Returns:
        The parsed payload as a dict.

    Raises:
        ValueError: If no object or array could be recovered. An exception
            raised by the fallback extractor is chained onto this ValueError
            (``__cause__``) instead of escaping with a different type.
    """
    result = repair_json(text, return_objects=True)

    # Typical (desired) case: top-level object.
    if isinstance(result, Mapping):
        return dict(result)

    # Some model outputs come back as a top-level array. Coerce to a dict so
    # downstream code can continue, while preserving the payload for callers to
    # interpret (via 'items') when schema keys are missing.
    if isinstance(result, list):
        return _coerce_list_root(result)

    cause: Exception | None = None

    # Fallback: json_repair returned a plain string (model output natural language).
    # Use llm_output_parser to extract JSON from mixed text/markdown.
    if isinstance(result, str):
        logger.warning("json_repair returned str, trying llm_output_parser extraction")
        try:
            extracted = extract_json(text, allow_incomplete=True, strict=False)
        except Exception as err:
            # Third-party parser: normalise whatever it raises into the
            # documented ValueError below, keeping the original as the cause.
            logger.warning("llm_output_parser extraction failed: %s", err)
            cause = err
        else:
            if isinstance(extracted, Mapping):
                return dict(extracted)
            if isinstance(extracted, list):
                return _coerce_list_root(extracted)

    raise ValueError(
        f"Could not parse JSON from LLM output (got {type(result).__name__}, length={len(text)})"
    ) from cause
def _coerce_list_root(items: list) -> dict:
    """
    Turn a top-level JSON array into a dict that downstream code can consume.

    Rules, applied in order:
      * empty list -> ``{"items": []}``
      * no mapping elements -> ``{"items": items}`` (the input list itself)
      * any mapping carries a key from ``_TOP_LEVEL_KEYS`` -> all mappings are
        shallow-merged left to right (later keys win); non-mapping elements
        are dropped
      * mixed list (some non-mappings) -> ``{"items": [...]}`` with mappings
        converted to ``dict`` and other elements kept as-is
      * all mappings -> compare the key sets of the first 10 elements; with a
        single element, or a Jaccard overlap (|intersection| / |union|) of at
        least 0.35, wrap them under the key guessed by
        ``_infer_list_container_key`` (or ``"items"``); otherwise merge them
    """
    if not items:
        return {"items": []}

    objects = [entry for entry in items if isinstance(entry, Mapping)]
    if not objects:
        return {"items": items}

    as_dicts = [dict(obj) for obj in objects]

    flattened: dict = {}
    for d in as_dicts:
        flattened.update(d)

    # Known schema keys suggest a wrapped/duplicated root object, or several
    # partial copies of it: collapse them into one dict.
    if any(not _TOP_LEVEL_KEYS.isdisjoint(d) for d in as_dicts):
        return flattened

    if len(objects) != len(items):
        # Mixed list: keep non-mapping elements untouched, normalise mappings.
        return {
            "items": [
                dict(entry) if isinstance(entry, Mapping) else entry
                for entry in items
            ]
        }

    # Tell a real list of repeated schema items apart from several standalone
    # objects pulled out of a noisy response.
    sampled = [set(d) for d in as_dicts[:10]]
    shared = set.intersection(*sampled)
    every = set().union(*sampled)
    similarity = len(shared) / len(every) if every else 0.0

    if len(items) > 1 and similarity < 0.35:
        # Low overlap between objects: treat as multiple extracted JSON objects.
        return flattened

    container = _infer_list_container_key(objects) or "items"
    return {container: as_dicts}
| def _infer_list_container_key(items: list[Mapping]) -> str | None: | |
| keys: set[str] = set() | |
| for item in items[:5]: | |
| keys.update(str(k) for k in item.keys()) | |
| # Diagnostician | |
| if {"finding", "description"} & keys: | |
| return "findings" | |
| if "reasoning" in keys: | |
| return "differential_diagnoses" | |
| # Bias detector | |
| if {"type", "severity"} <= keys or ("type" in keys and "severity" in keys): | |
| return "identified_biases" | |
| # Devil's advocate | |
| if "claim" in keys or "counter_evidence" in keys: | |
| return "challenges" | |
| if {"why_dangerous", "rule_out_test", "supporting_signs"} & keys: | |
| return "must_not_miss" | |
| # Consultant | |
| if {"urgency", "next_step"} & keys: | |
| return "alternative_diagnoses" | |
| return None | |