| """Multi-LLM ensemble — robust trajectory + dx via voting. |
| |
| Calls 2-3 LLMs in parallel (DeepSeek, Gemini, Cerebras) for the same task, |
| then aggregates. Robust against any single model hallucinating or |
| misformatting JSON. |
| |
| Use cases: |
| - Trajectory: median risk score per horizon across models |
| - Diagnosis: union differential, weighted by model confidence |
| - Extraction: majority-vote on entity status (present/absent/family) |
| """ |
| from __future__ import annotations |
| import asyncio |
| import json |
| import logging |
| import os |
| import statistics |
| from typing import Optional |
|
|
# Module logger shared by every helper in this file.
logger = logging.getLogger(name="gemeo.ensemble")
|
|
|
|
| def _build_llms() -> list: |
| """Return up to 3 LLM clients drawn from configured backends.""" |
| llms = [] |
| |
| if os.environ.get("DEEPSEEK_API_KEY"): |
| try: |
| from langchain_openai import ChatOpenAI |
| llms.append(("deepseek-chat", ChatOpenAI( |
| model=os.environ.get("DEEPSEEK_MODEL", "deepseek-chat"), |
| openai_api_key=os.environ["DEEPSEEK_API_KEY"], |
| openai_api_base="https://api.deepseek.com/v1", |
| temperature=0.1, max_retries=2, timeout=60, |
| ))) |
| except Exception as e: |
| logger.debug(f"deepseek load failed: {e}") |
| |
| if os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY"): |
| try: |
| from langchain_google_genai import ChatGoogleGenerativeAI |
| llms.append(("gemini-2.5-flash", ChatGoogleGenerativeAI( |
| model="gemini-2.5-flash", |
| google_api_key=os.environ.get("GEMINI_API_KEY") or os.environ["GOOGLE_API_KEY"], |
| temperature=0.1, |
| ))) |
| except Exception as e: |
| logger.debug(f"gemini load failed: {e}") |
| |
| if os.environ.get("CEREBRAS_API_KEY"): |
| try: |
| from langchain_openai import ChatOpenAI |
| llms.append(("cerebras-qwen", ChatOpenAI( |
| model=os.environ.get("CEREBRAS_MODEL", "qwen-3-32b"), |
| openai_api_key=os.environ["CEREBRAS_API_KEY"], |
| openai_api_base="https://api.cerebras.ai/v1", |
| temperature=0.1, max_retries=2, timeout=60, |
| ))) |
| except Exception as e: |
| logger.debug(f"cerebras load failed: {e}") |
| return llms |
|
|
|
|
| async def _call_one(label, llm, system, user, parser=None): |
| """Single LLM call returning (label, parsed_or_text, raw_text).""" |
| from langchain_core.messages import SystemMessage, HumanMessage |
| try: |
| resp = await llm.ainvoke([SystemMessage(content=system), HumanMessage(content=user)]) |
| text = getattr(resp, "content", None) or str(resp) |
| if parser: |
| try: |
| return label, parser(text), text |
| except Exception as e: |
| logger.debug(f"{label} parse failed: {e}") |
| return label, None, text |
| return label, text, text |
| except Exception as e: |
| logger.debug(f"{label} call failed: {e}") |
| return label, None, "" |
|
|
|
|
| async def call_ensemble(system: str, user: str, *, parser=None, |
| n_models: int = 3) -> list[tuple]: |
| """Run system+user against up to n_models LLMs in parallel. |
| |
| Returns list of (label, parsed_or_text, raw_text) tuples (only successes). |
| """ |
| llms = _build_llms()[:n_models] |
| if not llms: |
| return [] |
| tasks = [_call_one(label, llm, system, user, parser) for label, llm in llms] |
| results = await asyncio.gather(*tasks, return_exceptions=True) |
| return [r for r in results if not isinstance(r, Exception) and r[1] is not None] |
|
|
|
|
def median_risk_per_horizon(traj_predictions: list[dict]) -> list[dict]:
    """Aggregate trajectory risk scores per horizon across multiple predictions.

    Args:
        traj_predictions: one entry per model, each shaped like
            ``{"months": [12, 36, 72], "risks": [0.3, 0.5, 0.7]}``. The two
            lists are paired positionally. The following are skipped:
            - entries that are falsy or not dicts;
            - entries whose ``months``/``risks`` are missing, null or not
              iterable;
            - individual pairs whose month is not int-convertible or whose
              risk is not a finite float.

    Returns:
        A list sorted by horizon, with one dict per horizon that has at least
        one value, e.g.::

            {"months": 12, "p50": 0.4, "p05": 0.3, "p95": 0.5,
             "n_models": 3, "values": [0.3, 0.4, 0.5]}

        - ``p50`` is the median.
        - ``p05``/``p95`` are nearest-rank picks from the sorted values; with
          only a few models they equal the min and max.
        - ``n_models`` is the number of values collected for that horizon.
        - ``values`` is the sorted list of those values.
    """
    from math import isfinite

    by_h: dict[int, list[float]] = {}
    for tp in traj_predictions:
        # LLM output can be malformed; ignore anything that isn't a usable dict.
        if not isinstance(tp, dict) or not tp:
            continue
        try:
            pairs = list(zip(tp.get("months") or [], tp.get("risks") or []))
        except TypeError:  # months/risks present but not iterable
            continue
        for m, r in pairs:
            try:
                month, risk = int(m), float(r)
            except (TypeError, ValueError, OverflowError):
                continue
            if not isfinite(risk):
                # NaN breaks sorting/median ordering; inf is not a usable risk.
                continue
            by_h.setdefault(month, []).append(risk)

    out = []
    for month in sorted(by_h):
        vals = sorted(by_h[month])  # never empty: only filled on success
        n = len(vals)
        out.append({
            "months": month,
            "p50": statistics.median(vals),
            "p05": vals[max(0, int(0.05 * n))],
            "p95": vals[min(n - 1, int(0.95 * n))],
            "n_models": n,
            "values": vals,
        })
    return out
|
|