"""Multi-LLM ensemble — robust trajectory + diagnosis (dx) via voting.
Calls 2-3 LLMs in parallel (DeepSeek, Gemini, Cerebras) for the same task,
then aggregates. Robust against any single model hallucinating or
misformatting JSON.
Use cases:
- Trajectory: median risk score per horizon across models
- Diagnosis: union differential, weighted by model confidence
- Extraction: majority-vote on entity status (present/absent/family)
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import statistics
from typing import Optional
# Module-level logger registered under the "gemeo.ensemble" name; the helpers
# below use it for debug-level reports of backend load, call and parse failures.
logger = logging.getLogger("gemeo.ensemble")
# Request settings shared by every backend so that no single client can stall
# or out-retry the others when the ensemble awaits all of them together.
_SHARED_LLM_KWARGS = {"temperature": 0.1, "max_retries": 2, "timeout": 60}


def _openai_compatible_llm(label, backend, key_var, model_var, default_model,
                           base_url):
    """Build one ``(label, ChatOpenAI)`` pair for an OpenAI-compatible API.

    Args:
        label: Name reported alongside this backend's results.
        backend: Short backend name used in the failure log message.
        key_var: Environment variable holding the API key.
        model_var: Environment variable that may override the model name.
        default_model: Model name used when ``model_var`` is unset.
        base_url: Base URL of the OpenAI-compatible endpoint.

    Returns:
        The ``(label, client)`` tuple, or ``None`` when the key is unset/empty
        or the client could not be imported or constructed.
    """
    api_key = os.environ.get(key_var)
    if not api_key:
        return None
    try:
        from langchain_openai import ChatOpenAI
        client = ChatOpenAI(
            model=os.environ.get(model_var, default_model),
            openai_api_key=api_key,
            openai_api_base=base_url,
            **_SHARED_LLM_KWARGS,
        )
    except Exception as e:
        # Deliberately broad: a missing package or a rejected argument should
        # only remove this backend from the ensemble, never break the caller.
        logger.debug("%s load failed: %s", backend, e)
        return None
    return label, client


def _build_llms() -> list:
    """Return up to 3 ``(label, client)`` pairs drawn from configured backends.

    A backend is included only when its API key is present in the environment
    (``DEEPSEEK_API_KEY``, ``GEMINI_API_KEY``/``GOOGLE_API_KEY``,
    ``CEREBRAS_API_KEY``) and its client can be imported and constructed.
    Order is fixed: DeepSeek, Gemini, Cerebras. All three clients share the
    same temperature, retry count and timeout; the Gemini client previously
    had no explicit timeout or retry limit.

    Returns:
        A list of ``(label, client)`` tuples; empty when nothing is configured.
    """
    llms = []

    deepseek = _openai_compatible_llm(
        "deepseek-chat", "deepseek", "DEEPSEEK_API_KEY", "DEEPSEEK_MODEL",
        "deepseek-chat", "https://api.deepseek.com/v1",
    )
    if deepseek is not None:
        llms.append(deepseek)

    gemini_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if gemini_key:
        try:
            from langchain_google_genai import ChatGoogleGenerativeAI
            llms.append(("gemini-2.5-flash", ChatGoogleGenerativeAI(
                model="gemini-2.5-flash",
                google_api_key=gemini_key,
                **_SHARED_LLM_KWARGS,
            )))
        except Exception as e:
            # Same best-effort policy as the OpenAI-compatible backends.
            logger.debug("gemini load failed: %s", e)

    cerebras = _openai_compatible_llm(
        "cerebras-qwen", "cerebras", "CEREBRAS_API_KEY", "CEREBRAS_MODEL",
        "qwen-3-32b", "https://api.cerebras.ai/v1",
    )
    if cerebras is not None:
        llms.append(cerebras)

    return llms
def _response_text(resp) -> str:
    """Flatten an LLM response object into plain text.

    Falls back to ``str(resp)`` only when the response has no ``content``
    attribute at all. Empty content stays an empty string instead of turning
    into the response object's repr, which a lenient parser could mistake for
    model output.
    """
    content = getattr(resp, "content", None)
    if content is None:
        return str(resp)
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # NOTE(review): assumes list content is made of plain strings and/or
        # dict blocks carrying their text under a "text" key — confirm against
        # the LangChain message content format used by the configured clients.
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return "".join(pieces)
    return str(content)


async def _call_one(label, llm, system, user, parser=None):
    """Run a single LLM call and return ``(label, parsed_or_text, raw_text)``.

    Args:
        label: Name identifying the model in the result tuple and in logs.
        llm: Client exposing an async ``ainvoke(messages)`` method.
        system: System prompt text.
        user: User prompt text.
        parser: Optional callable applied to the response text.

    Returns:
        ``(label, parser(text), text)`` when a parser is given and succeeds,
        ``(label, None, text)`` when the parser raises,
        ``(label, text, text)`` when no parser is given, and
        ``(label, None, "")`` when the call itself raises.
    """
    from langchain_core.messages import SystemMessage, HumanMessage
    try:
        resp = await llm.ainvoke([SystemMessage(content=system), HumanMessage(content=user)])
        text = _response_text(resp)
    except Exception as e:
        logger.debug("%s call failed: %s", label, e)
        return label, None, ""
    if not parser:
        return label, text, text
    try:
        return label, parser(text), text
    except Exception as e:
        # Parsers are caller-supplied, so any failure just marks this model's
        # answer as unusable while keeping the raw text for inspection.
        logger.debug("%s parse failed: %s", label, e)
        return label, None, text
async def call_ensemble(system: str, user: str, *, parser=None,
                        n_models: int = 3) -> list[tuple]:
    """Run system+user against up to n_models LLMs in parallel.

    Args:
        system: System prompt sent to every model.
        user: User prompt sent to every model.
        parser: Optional callable applied to each model's response text.
        n_models: Maximum number of configured backends to query; values
            of zero or below select no backend.

    Returns:
        List of ``(label, parsed_or_text, raw_text)`` tuples, in backend
        order, containing only the calls whose second element is not None.
        Empty when no backend is configured or every call failed.
    """
    # Clamp so a negative value cannot slice from the end of the list.
    llms = _build_llms()[:max(n_models, 0)]
    if not llms:
        return []
    outcomes = await asyncio.gather(
        *(_call_one(label, llm, system, user, parser) for label, llm in llms),
        return_exceptions=True,
    )
    successes = []
    # gather() keeps input order, so outcomes line up with llms.
    for (label, _), outcome in zip(llms, outcomes):
        # BaseException, not Exception: a cancelled child task comes back as
        # asyncio.CancelledError, which is not an Exception subclass and
        # would otherwise be indexed like a result tuple.
        if isinstance(outcome, BaseException):
            logger.debug("%s task failed: %r", label, outcome)
            continue
        if outcome[1] is not None:
            successes.append(outcome)
    return successes
def median_risk_per_horizon(traj_predictions: list[dict]) -> list[dict]:
    """Aggregate trajectory risk scores per horizon across several predictions.

    Malformed input is skipped rather than raised on, so one model's bad
    output cannot break the aggregate: non-dict predictions, ``months`` or
    ``risks`` that are missing, null, scalar or a bare string, pairs that do
    not convert to ``int``/``float``, and non-finite risks (NaN, +/-inf).

    Args:
        traj_predictions: Predictions shaped like
            ``{"months": [12, 36, 72], "risks": [0.3, 0.5, 0.7]}``. ``months``
            and ``risks`` are paired positionally; extra items on the longer
            side are ignored.

    Returns:
        One dict per horizon, sorted by ``months``, with keys:
        ``months`` (int horizon), ``p50`` (median), ``p05`` / ``p95`` (the
        sorted values at index ``int(0.05 * n)`` / ``int(0.95 * n)``, clamped
        to the list bounds), ``n_models`` (number of values collected for that
        horizon) and ``values`` (the sorted values). Empty list when nothing
        usable was supplied.
    """
    import math  # function-scope import, like the other local imports in this file

    by_horizon: dict[int, list[float]] = {}
    for prediction in traj_predictions or []:
        if not isinstance(prediction, dict):
            continue
        months = prediction.get("months")
        risks = prediction.get("risks")
        # A bare string is iterable but would be paired character by character.
        if isinstance(months, (str, bytes)) or isinstance(risks, (str, bytes)):
            continue
        try:
            pairs = list(zip(months, risks))
        except TypeError:
            # months or risks is None or a scalar, not a sequence.
            continue
        for month, risk in pairs:
            try:
                horizon = int(month)
                value = float(risk)
            except (TypeError, ValueError, OverflowError):
                continue
            if not math.isfinite(value):
                # NaN breaks sorting and the median; infinities are not risks.
                continue
            by_horizon.setdefault(horizon, []).append(value)

    aggregated = []
    for horizon in sorted(by_horizon):
        values = sorted(by_horizon[horizon])
        n = len(values)
        aggregated.append({
            "months": horizon,
            "p50": statistics.median(values),
            "p05": values[max(0, int(0.05 * n))],
            "p95": values[min(n - 1, int(0.95 * n))],
            "n_models": n,
            "values": values,
        })
    return aggregated
|