File size: 5,628 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""Gemeo skill — callable tool the LLM can invoke during a conversation.

Two tool functions exposed:

  - `gemeo_lookup(query, mode="local")` — GraphRAG over the patient gemeo
    + (optionally) cohort + literature. Returns Markdown-ready evidence.
  - `gemeo_state(section?)` — returns the live twin state, optionally
    restricted to one section (cohort, ddi, family, ...).

Both auto-resolve the active `case_id` from the per-call context (set
by the orchestrator before the LLM turn). The agent never sees or
guesses case_ids; the platform binds them.
"""
from __future__ import annotations
import contextvars
import logging
from typing import Optional

# Logger for this module, registered under the fixed name "gemeo.skill".
logger = logging.getLogger("gemeo.skill")


# Context variable carrying the case id the current call is bound to.
# Defaults to None (no case bound); populated through
# `with active_case(case_id): await llm.ainvoke(...)`.
_ACTIVE_CASE: contextvars.ContextVar[Optional[str]]
_ACTIVE_CASE = contextvars.ContextVar("gemeo_active_case", default=None)


class active_case:
    """Context manager binding a case_id to the current context.

    On entry the given ``case_id`` is stored in the module-level
    ``_ACTIVE_CASE`` context variable; on exit the variable is restored to
    the value it held before entry. Exceptions raised inside the ``with``
    body are never suppressed.

    Args:
        case_id: identifier to bind, or ``None`` to bind "no case" for the
            duration of the block.
    """

    def __init__(self, case_id: Optional[str]):
        self._case_id = case_id
        # Token handed back by ContextVar.set(); None while not entered.
        self._token = None

    def __enter__(self):
        self._token = _ACTIVE_CASE.set(self._case_id)
        return self

    def __exit__(self, *exc):
        # contextvars tokens are single-use: calling reset() again with a
        # spent token raises RuntimeError. Detach the token from the instance
        # *before* resetting so a repeated __exit__ is a harmless no-op and
        # the manager can be entered again later.
        token, self._token = self._token, None
        if token is not None:
            _ACTIVE_CASE.reset(token)


def get_active_case() -> Optional[str]:
    """Return the case id bound to the current context, or None if unbound."""
    current_case = _ACTIVE_CASE.get()
    return current_case


# ─── tool implementations ─────────────────────────────────────────────────

async def gemeo_lookup(query: str, mode: str = "local") -> str:
    """Run a GraphRAG retrieval for the active case and format it for the LLM.

    The case id is read from the active-case context rather than passed in.

    Args:
        query: free-text question grounded in the case.
        mode: "local" (subgraph only) or "global" (also cohort + literature);
            forwarded unchanged to `graphrag.retrieve`.
    Returns:
        Markdown-formatted evidence block, or an italic placeholder string
        when no case is bound to the current context.
    """
    bound_case = get_active_case()
    if bound_case:
        # Imported lazily, only once a case is known to be bound.
        from . import graphrag
        evidence = await graphrag.retrieve(bound_case, query, mode=mode)
        return graphrag.format_for_llm(evidence)
    return "_(gemeo_lookup: no active case bound; nothing retrieved)_"


async def gemeo_state(section: Optional[str] = None) -> str:
    """Return the live twin state (or one section) for the active case.

    The case id is read from the active-case context rather than passed in.

    Args:
        section: optional attribute name on the twin to restrict the output
            to. Advertised values: "diagnoses", "risk", "trajectory",
            "drugs", "ddi", "pharmacogen", "family", "reverse_pheno",
            "protocol_compliance", "next_questions", "sus_check".
            Falsy values (None, "") select the whole twin.

    Returns:
        Markdown text: the output of `llm_context.serialize_twin_for_llm`
        when no section is given; a fenced JSON block (JSON text capped at
        3000 characters) for a single section; or an italic
        `_(gemeo_state: ...)_` placeholder when no case is bound, no twin
        exists, or the section is missing / None.
    """
    case_id = get_active_case()
    if not case_id:
        return "_(gemeo_state: no active case bound)_"
    from . import core as gcore, llm_context
    twin = gcore.get_gemeo(case_id) or await gcore.query_gemeo(case_id)
    if twin is None:
        return f"_(gemeo_state: no twin for {case_id})_"
    if not section:
        return llm_context.serialize_twin_for_llm(twin)
    # `section` is chosen by the model, so it is untrusted input to getattr():
    # never resolve private/dunder attributes (e.g. `__dict__`, `__class__`)
    # and never let a non-string name raise TypeError out of the tool call.
    if isinstance(section, str) and not section.startswith("_"):
        val = getattr(twin, section, None)
    else:
        val = None
    if val is None:
        return f"_(gemeo_state: section `{section}` empty)_"
    return _render_section_json(val)


def _render_section_json(val, limit: int = 3000) -> str:
    """Render one twin section as a fenced JSON Markdown block (best effort)."""
    import json
    from dataclasses import asdict, is_dataclass

    def _is_dataclass_instance(obj) -> bool:
        return is_dataclass(obj) and not isinstance(obj, type)

    def _to_jsonable(obj):
        # Hook for values json cannot encode natively. Dataclass instances
        # nested inside lists/dicts are expanded to dicts instead of being
        # flattened to their repr; anything else degrades to text.
        if _is_dataclass_instance(obj):
            try:
                return asdict(obj)
            except Exception:
                # Deliberate best effort: an un-copyable field must not
                # break the tool call.
                return str(obj)
        return str(obj)

    d = _to_jsonable(val) if _is_dataclass_instance(val) else val
    try:
        payload = json.dumps(d, default=_to_jsonable, indent=2)
    except (TypeError, ValueError, RecursionError):
        # `default=` does not cover non-encodable dict keys or circular
        # references; emit the textual form as a JSON string instead.
        payload = json.dumps(str(val))
    # The cap keeps the tool output bounded; a truncated payload may not be
    # complete JSON.
    return f"```json\n{payload[:limit]}\n```"


# ─── registration helpers ─────────────────────────────────────────────────

def get_tool_specs() -> list[dict]:
    """Build the tool specs in OpenAI-tool / LangChain-tool compatible shape.

    Returns:
        A freshly built two-element list (`gemeo_lookup`, then `gemeo_state`)
        usable for binding to the LLM (e.g. `llm.bind_tools(...)`).
    """
    lookup_description = (
        "GraphRAG retrieval over the patient gemeo (digital twin). "
        "Returns relevant subgraph triples + KG communities + "
        "(global mode) similar cohort cases + PubMed literature. "
        "Use whenever you need grounded evidence for a clinical claim."
    )
    lookup_parameters = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Free-text query grounded in the patient case.",
            },
            "mode": {
                "type": "string",
                "enum": ["local", "global"],
                "default": "local",
            },
        },
        "required": ["query"],
    }

    state_description = (
        "Return the live digital twin state for the active case. "
        "Pass an optional `section` to restrict to one capability."
    )
    section_names = [
        "diagnoses", "risk", "trajectory", "drugs", "ddi",
        "pharmacogen", "family", "reverse_pheno",
        "protocol_compliance", "next_questions", "sus_check",
    ]
    state_parameters = {
        "type": "object",
        "properties": {
            "section": {"type": "string", "enum": section_names},
        },
    }

    def as_function_tool(name, description, parameters):
        # Wrap one function definition in the {"type": "function", ...} envelope.
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": parameters,
            },
        }

    return [
        as_function_tool("gemeo_lookup", lookup_description, lookup_parameters),
        as_function_tool("gemeo_state", state_description, state_parameters),
    ]


# Name -> coroutine function registry, keyed by the tool names used in
# get_tool_specs(). Consumed by tools.py if it exposes a register_tool() API.
TOOL_FUNCTIONS = dict(
    gemeo_lookup=gemeo_lookup,
    gemeo_state=gemeo_state,
)