File size: 4,761 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""Multi-LLM ensemble — robust trajectory + dx via voting.

Calls 2-3 LLMs in parallel (DeepSeek, Gemini, Cerebras) for the same task,
then aggregates. Robust against any single model hallucinating or
misformatting JSON.

Use cases:
  - Trajectory: median risk score per horizon across models
  - Diagnosis: union differential, weighted by model confidence
  - Extraction: majority-vote on entity status (present/absent/family)
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import statistics
from typing import Optional

logger = logging.getLogger("gemeo.ensemble")


def _build_llms() -> list:
    """Return up to 3 LLM clients drawn from configured backends."""
    llms = []
    # DeepSeek
    if os.environ.get("DEEPSEEK_API_KEY"):
        try:
            from langchain_openai import ChatOpenAI
            llms.append(("deepseek-chat", ChatOpenAI(
                model=os.environ.get("DEEPSEEK_MODEL", "deepseek-chat"),
                openai_api_key=os.environ["DEEPSEEK_API_KEY"],
                openai_api_base="https://api.deepseek.com/v1",
                temperature=0.1, max_retries=2, timeout=60,
            )))
        except Exception as e:
            logger.debug(f"deepseek load failed: {e}")
    # Gemini
    if os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY"):
        try:
            from langchain_google_genai import ChatGoogleGenerativeAI
            llms.append(("gemini-2.5-flash", ChatGoogleGenerativeAI(
                model="gemini-2.5-flash",
                google_api_key=os.environ.get("GEMINI_API_KEY") or os.environ["GOOGLE_API_KEY"],
                temperature=0.1,
            )))
        except Exception as e:
            logger.debug(f"gemini load failed: {e}")
    # Cerebras
    if os.environ.get("CEREBRAS_API_KEY"):
        try:
            from langchain_openai import ChatOpenAI
            llms.append(("cerebras-qwen", ChatOpenAI(
                model=os.environ.get("CEREBRAS_MODEL", "qwen-3-32b"),
                openai_api_key=os.environ["CEREBRAS_API_KEY"],
                openai_api_base="https://api.cerebras.ai/v1",
                temperature=0.1, max_retries=2, timeout=60,
            )))
        except Exception as e:
            logger.debug(f"cerebras load failed: {e}")
    return llms


async def _call_one(label, llm, system, user, parser=None):
    """Send one system+user prompt to a single LLM and package the outcome.

    Args:
        label: Name used to identify this model in results and log lines.
        llm: Client object exposing an awaitable ``ainvoke(messages)``.
        system: Text of the system message.
        user: Text of the human message.
        parser: Optional callable applied to the reply text.

    Returns:
        A ``(label, value, raw_text)`` tuple where ``value`` is:
          * the reply text when no parser is given,
          * ``parser(reply_text)`` when the parser succeeds,
          * ``None`` when the parser raises (``raw_text`` still holds the reply),
          * ``None`` with ``raw_text == ""`` when the call itself raises.
    """
    from langchain_core.messages import SystemMessage, HumanMessage
    try:
        messages = [SystemMessage(content=system), HumanMessage(content=user)]
        reply = await llm.ainvoke(messages)
        # Use the reply's `content` attribute; if it is missing or falsy,
        # fall back to the string form of the whole reply object.
        raw = getattr(reply, "content", None) or str(reply)
        if not parser:
            return label, raw, raw
        try:
            parsed = parser(raw)
        except Exception as e:
            logger.debug(f"{label} parse failed: {e}")
            return label, None, raw
        else:
            return label, parsed, raw
    except Exception as e:
        logger.debug(f"{label} call failed: {e}")
        return label, None, ""


async def call_ensemble(system: str, user: str, *, parser=None,
                         n_models: int = 3) -> list[tuple]:
    """Fan the same system+user prompt out to several LLMs concurrently.

    The configured backends are truncated to the first ``n_models`` entries
    and queried in parallel. With no backend available the result is ``[]``.

    Args:
        system: Text of the system message.
        user: Text of the human message.
        parser: Optional callable applied to each reply text.
        n_models: Upper bound on how many configured backends are queried.

    Returns:
        ``(label, parsed_or_text, raw_text)`` tuples for the calls that
        succeeded; calls that raised, or whose value is ``None``, are dropped.
    """
    selected = _build_llms()[:n_models]
    if len(selected) == 0:
        return []

    outcomes = await asyncio.gather(
        *(_call_one(name, client, system, user, parser) for name, client in selected),
        return_exceptions=True,
    )

    successes = []
    for outcome in outcomes:
        # return_exceptions=True hands back raised exceptions as values.
        if isinstance(outcome, Exception):
            continue
        # A None payload marks a failed call or a failed parse.
        if outcome[1] is None:
            continue
        successes.append(outcome)
    return successes


def median_risk_per_horizon(traj_predictions: list[dict]) -> list[dict]:
    """Aggregate per-horizon risk scores across several model predictions.

    Args:
        traj_predictions: One entry per model, each shaped like
            ``{"months": [12, 36, 72], "risks": [0.3, 0.5, 0.7]}``.
            ``months`` and ``risks`` are paired positionally (extra items in
            the longer list are ignored). Malformed input is skipped rather
            than raised on: entries that are empty or not mappings, entries
            whose ``months``/``risks`` are missing, ``None`` or not iterable,
            and individual pairs whose month is not convertible to ``int`` or
            whose risk is not a finite number (NaN/inf would otherwise corrupt
            the ordering and the median).

    Returns:
        A list with one dict per horizon, sorted by ascending horizon::

            {"months": 12, "p50": 0.4, "p05": 0.3, "p95": 0.5,
             "n_models": 2, "values": [0.3, 0.5]}

        ``p50`` is the median; ``p05``/``p95`` are picked from the sorted
        values at index ``int(q * n)`` (clamped to the last element), so for
        small ``n`` they equal the minimum and maximum. ``n_models`` is the
        number of usable values for that horizon. Returns ``[]`` when no
        usable value is found.
    """
    # Function-scope imports keep this block self-contained.
    import math
    from collections.abc import Mapping

    by_horizon: dict[int, list[float]] = {}
    for pred in traj_predictions or []:
        # A parser may hand back None, a string, a list...; only mappings
        # carry the "months"/"risks" keys.
        if not isinstance(pred, Mapping):
            continue
        months = pred.get("months")
        risks = pred.get("risks")
        if months is None or risks is None:
            continue
        try:
            pairs = list(zip(months, risks))
        except TypeError:
            # One of the two fields is not iterable (e.g. a bare number).
            continue
        for month, risk in pairs:
            # Convert both values before touching the accumulator so that a
            # bad risk never leaves an empty horizon bucket behind.
            try:
                horizon = int(month)
                value = float(risk)
            except (TypeError, ValueError, OverflowError):
                continue
            if not math.isfinite(value):
                continue
            by_horizon.setdefault(horizon, []).append(value)

    summary = []
    for horizon in sorted(by_horizon):
        values = sorted(by_horizon[horizon])
        n = len(values)
        summary.append({
            "months": horizon,
            "p50": statistics.median(values),
            "p05": values[int(0.05 * n)],
            "p95": values[min(n - 1, int(0.95 * n))],
            "n_models": n,
            "values": values,
        })
    return summary