gemeo-twin-stack / src /gemeo /skill.py
timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""Gemeo skill β€” callable tool the LLM can invoke during a conversation.
Two tool functions exposed:
- `gemeo_lookup(query, mode="local")` β€” GraphRAG over the patient gemeo
+ (optionally) cohort + literature. Returns Markdown-ready evidence.
- `gemeo_state(section?)` β€” returns the live twin state, optionally
restricted to one section (cohort, ddi, family, ...).
Both auto-resolve the active `case_id` from the per-call context (set
by the orchestrator before the LLM turn). The agent never sees or
guesses case_ids; the platform binds them.
"""
from __future__ import annotations
import contextvars
import logging
from typing import Optional
logger = logging.getLogger("gemeo.skill")

# Context-local slot holding the active case id (None when nothing is bound).
# Bound by `with active_case(case_id): await llm.ainvoke(...)`.
_ACTIVE_CASE: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "gemeo_active_case", default=None,
)


class active_case:
    """Context manager binding a case_id to the current context.

    On entry the case id is stored in the module's context variable; on exit
    the previous value is restored. The same instance may be entered again
    while it is still active (nested `with` blocks) or reused after exit:
    every `__enter__` pushes its own reset token and every `__exit__` pops
    exactly one.

    Args:
        case_id: Identifier to expose through `get_active_case()`; `None`
            explicitly binds "no case" for the duration of the block.
    """

    def __init__(self, case_id: Optional[str]):
        self._case_id = case_id
        # One token per un-exited __enter__. A ContextVar token can be used
        # for reset() only once, so a single slot would break under nesting.
        self._tokens: list[contextvars.Token] = []

    def __enter__(self):
        self._tokens.append(_ACTIVE_CASE.set(self._case_id))
        return self

    def __exit__(self, *exc):
        # Returns None, so exceptions raised inside the block propagate.
        if self._tokens:
            _ACTIVE_CASE.reset(self._tokens.pop())


def get_active_case() -> Optional[str]:
    """Return the case id bound in the current context, or None if unbound."""
    return _ACTIVE_CASE.get()
# ─── tool implementations ─────────────────────────────────────────────────
async def gemeo_lookup(query: str, mode: str = "local") -> str:
    """Run a GraphRAG retrieval against the gemeo of the active case.

    The case id is never passed in; it is read from the per-call context via
    `get_active_case()`.

    Args:
        query: free-text question grounded in the case.
        mode: "local" (subgraph only) or "global" (also cohort + literature);
            forwarded unchanged to `graphrag.retrieve`.

    Returns:
        The retrieval result formatted by `graphrag.format_for_llm`, or a
        short italic notice when no case is bound.
    """
    bound_case = get_active_case()
    if bound_case:
        # Imported lazily, only once a case is actually bound.
        from . import graphrag

        evidence = await graphrag.retrieve(bound_case, query, mode=mode)
        return graphrag.format_for_llm(evidence)
    return "_(gemeo_lookup: no active case bound; nothing retrieved)_"
# Maximum number of JSON characters returned for a single section.
_SECTION_JSON_LIMIT = 3000


def _section_json_default(obj):
    """`json.dumps` fallback: expand dataclass instances, stringify the rest."""
    from dataclasses import asdict, is_dataclass

    # is_dataclass() is also true for dataclass *types*; only instances expand.
    if is_dataclass(obj) and not isinstance(obj, type):
        try:
            return asdict(obj)
        except Exception:
            # Best effort: asdict() deep-copies field values and can fail on
            # un-copyable members. Fall back to str(), but leave a trace.
            logger.warning(
                "gemeo_state: asdict() failed for %s; falling back to str()",
                type(obj).__name__,
                exc_info=True,
            )
    return str(obj)


def _render_section(val) -> str:
    """Render one section value as a fenced, length-capped JSON block."""
    import json

    # Top-level conversion mirrors the `__dataclass_fields__` check so that a
    # failed expansion still degrades to a plain JSON string.
    payload = _section_json_default(val) if hasattr(val, "__dataclass_fields__") else val
    # The default hook also expands dataclass items nested in lists/dicts,
    # which would otherwise be emitted as opaque repr strings.
    text = json.dumps(payload, default=_section_json_default, indent=2)
    if len(text) > _SECTION_JSON_LIMIT:
        omitted = len(text) - _SECTION_JSON_LIMIT
        # Make the cut explicit so the reader knows the JSON is incomplete.
        text = f"{text[:_SECTION_JSON_LIMIT]}\n... [truncated: {omitted} more characters]"
    return f"```json\n{text}\n```"


async def gemeo_state(section: Optional[str] = None) -> str:
    """Return the live twin state (or one section) for the active case.

    section ∈ {"diagnoses", "risk", "trajectory", "drugs", "ddi",
               "pharmacogen", "family", "reverse_pheno",
               "protocol_compliance", "next_questions", "sus_check"}

    Args:
        section: Attribute of the twin to return. When falsy, the whole twin
            is serialized with `llm_context.serialize_twin_for_llm`.

    Returns:
        A short italic notice when no case is bound, no twin exists, or the
        section is missing/None; otherwise the serialized twin, or the
        section as a fenced JSON block capped at `_SECTION_JSON_LIMIT`
        characters with an explicit marker when truncated.
    """
    case_id = get_active_case()
    if not case_id:
        return "_(gemeo_state: no active case bound)_"

    from . import core as gcore, llm_context

    # Try the already-available twin first; otherwise query for it.
    twin = gcore.get_gemeo(case_id) or await gcore.query_gemeo(case_id)
    if twin is None:
        return f"_(gemeo_state: no twin for {case_id})_"
    if not section:
        return llm_context.serialize_twin_for_llm(twin)

    # `section` is model-supplied: never resolve non-string, dunder or
    # private names, and never treat methods as data sections.
    if isinstance(section, str) and not section.startswith("_"):
        val = getattr(twin, section, None)
        if callable(val):
            val = None
    else:
        val = None
    if val is None:
        return f"_(gemeo_state: section `{section}` empty)_"
    return _render_section(val)
# ─── registration helpers ─────────────────────────────────────────────────
def get_tool_specs() -> list[dict]:
    """Describe both gemeo tools in OpenAI-tool / LangChain-tool compatible shape.

    Returns:
        A freshly built list with the `gemeo_lookup` spec followed by the
        `gemeo_state` spec, usable for binding to the LLM
        (e.g. `llm.bind_tools(...)`).
    """

    def as_tool(name, description, properties, required=None):
        # Wrap one function description in the {"type": "function", ...} envelope.
        parameters = {"type": "object", "properties": properties}
        if required is not None:
            parameters["required"] = required
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": parameters,
            },
        }

    lookup_description = (
        "GraphRAG retrieval over the patient gemeo (digital twin). "
        "Returns relevant subgraph triples + KG communities + "
        "(global mode) similar cohort cases + PubMed literature. "
        "Use whenever you need grounded evidence for a clinical claim."
    )
    lookup_properties = {
        "query": {
            "type": "string",
            "description": "Free-text query grounded in the patient case.",
        },
        "mode": {
            "type": "string",
            "enum": ["local", "global"],
            "default": "local",
        },
    }

    state_description = (
        "Return the live digital twin state for the active case. "
        "Pass an optional `section` to restrict to one capability."
    )
    state_sections = [
        "diagnoses", "risk", "trajectory", "drugs", "ddi",
        "pharmacogen", "family", "reverse_pheno",
        "protocol_compliance", "next_questions", "sus_check",
    ]
    state_properties = {
        "section": {"type": "string", "enum": state_sections},
    }

    lookup_spec = as_tool(
        "gemeo_lookup", lookup_description, lookup_properties, required=["query"],
    )
    # gemeo_state takes no mandatory argument, so it carries no "required" key.
    state_spec = as_tool("gemeo_state", state_description, state_properties)
    return [lookup_spec, state_spec]
# Tool name -> coroutine function. Registry consumed by tools.py if it
# exposes a register_tool() API.
TOOL_FUNCTIONS = dict(
    gemeo_lookup=gemeo_lookup,
    gemeo_state=gemeo_state,
)