File size: 6,876 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""Skill router — surface relevant skills (from /skills/, MCP servers,
local tools) for a given Gemeo twin.

The repo has 535 skills (curated bio-* + auto-generated agents) and 8
MCP servers. Most are silent until manually invoked. This router scans
the twin's diagnoses + phenotypes + planned treatments and returns the
top-N skills that would be useful to call NEXT, ranked by relevance.

Used by:
  - frontend (show "Suggested skills" panel in /space/[id])
  - LLM agents (auto-binding the most relevant 5 tools instead of all 461)
  - MCP server (gemeo://twin/{id}/suggested-skills)
"""
from __future__ import annotations
import logging
import os
import re
from dataclasses import dataclass, field

# Module-level logger, registered under the fixed name "gemeo.skill_router".
logger = logging.getLogger("gemeo.skill_router")

# Directory scanned for skills: `../skills` relative to this module's directory.
SKILLS_DIR = os.path.join(os.path.dirname(__file__), "..", "skills")


@dataclass
class SkillCandidate:
    """One ranked skill suggestion returned by `suggest`."""

    name: str  # skill name from the SKILL.md frontmatter (or the directory name)
    path: str  # path to the skill's SKILL.md file
    description: str  # frontmatter description, truncated by the indexer
    domain: str = ""  # frontmatter `domain:` value; empty when absent
    score: float = 0.0  # relevance score; `suggest` rounds it to 2 decimals
    matched_terms: list[str] = field(default_factory=list)  # terms/triggers that hit


# ─── lightweight SKILL.md indexer ─────────────────────────────────────────

_SKILL_INDEX: list[dict] = []  # cached after the first non-empty scan


def _load_skill_index() -> list[dict]:
    """Return the skill index, building it from SKILL.md frontmatter on first use.

    Every sub-directory of ``SKILLS_DIR`` that contains a ``SKILL.md`` file
    contributes one entry. Only the first 4000 characters of each file are
    read, and the frontmatter is parsed with regexes rather than a YAML parser.

    Returns:
        A list of dicts with the keys ``name``, ``slug``, ``description``
        (at most 300 characters), ``domain``, ``triggers`` and ``path``.
        Returns ``[]`` when ``SKILLS_DIR`` is not a directory. A non-empty
        result is cached in ``_SKILL_INDEX``; an empty one is rebuilt on the
        next call.
    """
    global _SKILL_INDEX
    if _SKILL_INDEX:
        return _SKILL_INDEX
    if not os.path.isdir(SKILLS_DIR):
        return []
    out = []
    # os.listdir() returns entries in arbitrary order; sort so the index (and
    # the order of equally scored skills downstream) is reproducible.
    for entry in sorted(os.listdir(SKILLS_DIR)):
        sk_md = os.path.join(SKILLS_DIR, entry, "SKILL.md")
        if not os.path.isfile(sk_md):
            continue
        try:
            with open(sk_md, encoding="utf-8") as f:
                content = f.read(4000)
        except (OSError, UnicodeError) as e:
            # Best effort: an unreadable or badly encoded skill is skipped.
            logger.debug("skill %s parse failed: %s", entry, e)
            continue

        # Parse minimal frontmatter
        m = re.search(r"---\s*\n(.*?)\n---", content, re.DOTALL)
        fm = m.group(1) if m else ""
        name_m = re.search(r"^name:\s*(.+)$", fm, re.MULTILINE)
        desc_m = re.search(r"^description:\s*(.+)$", fm, re.MULTILINE)
        domain_m = re.search(r"domain:\s*([\w\-]+)", fm)
        triggers_m = re.search(r"trigger_keywords:\s*\n((?:\s+-\s*.+\n?)+)", fm)

        triggers = []
        if triggers_m:
            for line in triggers_m.group(1).splitlines():
                keyword = line.strip().lstrip("- ").strip().strip('"').strip("'")
                # The pattern tolerates blank lines inside the list; drop the
                # empty keywords they produce, since "" is a substring of
                # every string and would match any twin.
                if keyword:
                    triggers.append(keyword)

        out.append({
            "name": (name_m.group(1).strip() if name_m else entry),
            "slug": entry,
            "description": (desc_m.group(1).strip() if desc_m else "")[:300],
            "domain": (domain_m.group(1).strip() if domain_m else ""),
            "triggers": triggers,
            "path": sk_md,
        })
    _SKILL_INDEX = out
    logger.info("skill router indexed %d skills", len(out))
    return out


# ─── relevance scoring ────────────────────────────────────────────────────

def _terms_for_twin(twin) -> set[str]:
    """Extract lowercase relevance terms from a Gemeo twin.

    Terms are collected from:
      * the ``name``, ``disease`` and ``orpha`` values of the first 5 diagnoses;
      * the names of ``Phenotype`` and ``Gene`` nodes among the first 30
        subgraph nodes;
      * the ``name`` of the first 10 drug candidates;
      * the family inheritance mode.

    Missing or empty values are skipped, so a ``None`` name is never
    stringified into the literal term ``"none"``.

    Args:
        twin: Twin object exposing ``diagnoses``, ``subgraph``, ``drugs`` and
            ``family`` attributes, or ``None``.

    Returns:
        The set of lowercase terms longer than 2 characters; empty when
        ``twin`` is ``None``.
    """
    terms: set[str] = set()
    if twin is None:
        return terms

    def _add(value) -> None:
        # Skip None/empty before str(): str(None).lower() would yield "none".
        if value:
            terms.add(str(value).lower())

    # Top diagnoses (name + disease + ORPHA code)
    for d in (twin.diagnoses or [])[:5]:
        for k in ("name", "disease", "orpha"):
            _add(d.get(k))
    # Phenotype and gene names from the leading subgraph nodes
    if twin.subgraph:
        for n in (twin.subgraph.nodes or [])[:30]:
            if n.label in ("Phenotype", "Gene"):
                _add(n.name)
    # Drugs in regimen
    if twin.drugs:
        for d in (twin.drugs.candidates or [])[:10]:
            _add(d.get("name"))
    # Family / inheritance
    if twin.family:
        _add(twin.family.inheritance_mode)
    # Very short terms are dropped: they substring-match almost anything.
    return {t for t in terms if len(t) > 2}


def _score_skill(skill: dict, terms: set[str]) -> tuple[float, list[str]]:
    """Score a skill against extracted terms. Higher = more relevant.

    Scoring rules:
      * +1.0 for each term that appears verbatim (as a substring) in the
        skill's description or trigger keywords;
      * otherwise +0.3 if any token of the term longer than 3 characters
        appears there;
      * +0.5 for each non-empty trigger keyword contained in one of the terms.

    Args:
        skill: Index entry; only ``description`` and ``triggers`` are read.
        terms: Lowercase terms extracted from the twin.

    Returns:
        ``(score, matched)``. ``matched`` lists the matching terms in sorted
        order, followed by ``"trigger:<keyword>"`` entries, without duplicates.
    """
    # Empty triggers are dropped: "" is a substring of every string and would
    # hand every skill a free trigger boost.
    triggers = [t for t in skill.get("triggers", []) if t]
    haystack = (skill.get("description", "") + " " + " ".join(triggers)).lower()

    # Set iteration order depends on the hash seed; sort for a stable result.
    ordered_terms = sorted(terms)

    matched: list[str] = []
    score = 0.0
    for term in ordered_terms:
        if term in haystack:
            score += 1.0
            matched.append(term)
        elif any(len(tok) > 3 and tok in haystack for tok in re.split(r"\W+", term)):
            # token-level partial match
            score += 0.3
            matched.append(term)

    # Boost on explicit trigger match. Each term is checked on its own rather
    # than against one joined string, so a trigger cannot match text that
    # merely straddles two unrelated terms.
    for trig in triggers:
        needle = trig.lower()
        if any(needle in term for term in ordered_terms):
            score += 0.5
            matched.append(f"trigger:{trig}")

    # dict.fromkeys de-duplicates while keeping first-seen order.
    return score, list(dict.fromkeys(matched))


def suggest(twin, *, top_n: int = 8, min_score: float = 0.5) -> list[SkillCandidate]:
    """Rank the indexed skills against a twin and return the best matches.

    Args:
        twin: Gemeo twin whose extracted terms drive the ranking.
        top_n: Maximum number of candidates to return.
        min_score: Skills scoring below this threshold are dropped.

    Returns:
        Candidates ordered by descending score. Empty when no skills are
        indexed or the twin yields no terms.
    """
    index = _load_skill_index()
    if not index:
        return []
    twin_terms = _terms_for_twin(twin)
    if not twin_terms:
        return []

    # Pair every index entry with its (score, matched) result, lazily.
    evaluated = ((entry, _score_skill(entry, twin_terms)) for entry in index)
    candidates = [
        SkillCandidate(
            name=entry["name"],
            path=entry["path"],
            description=entry["description"],
            domain=entry["domain"],
            score=round(relevance, 2),
            matched_terms=hits[:5],
        )
        for entry, (relevance, hits) in evaluated
        if not (relevance < min_score)  # keep anything not below the threshold
    ]
    # sorted() is stable, so equally scored skills keep their index order.
    ranked = sorted(candidates, key=lambda cand: cand.score, reverse=True)
    return ranked[:top_n]


# ─── MCP server suggestions ───────────────────────────────────────────────

MCP_SUGGEST_RULES = {
    # case has Drug candidate → suggest biomcp + clinicaltrials
    "drugs": ["biomcp", "clinicaltrials", "pubmed"],
    # case has variant → suggest biomcp (MyVariant.info + ClinVar)
    "variant": ["biomcp", "neo4j"],
    # case has rare disease → suggest gemeo (already loaded), pubmed, exa
    "diagnoses": ["gemeo", "pubmed", "exa"],
    # always-on
    "always": ["gemeo", "neo4j"],
}


def suggest_mcp_servers(twin) -> list[dict]:
    """Suggest MCP servers to enable for this twin.

    The ``always`` servers are always included. The ``drugs`` rule fires when
    the twin has drug candidates, the ``variant`` rule when its subgraph has
    at least one ``Gene`` node, and the ``diagnoses`` rule when it has any
    diagnoses.

    Args:
        twin: Twin object exposing ``drugs``, ``subgraph`` and ``diagnoses``
            attributes, or ``None``.

    Returns:
        A list of ``{"name": <server>}`` dicts sorted by server name, with no
        duplicates. For ``twin is None`` only the always-on servers are
        returned, in the same sorted order.
    """
    out_names = set(MCP_SUGGEST_RULES["always"])
    if twin is not None:
        if twin.drugs and twin.drugs.candidates:
            out_names.update(MCP_SUGGEST_RULES["drugs"])
        if twin.subgraph and any(n.label == "Gene" for n in (twin.subgraph.nodes or [])):
            out_names.update(MCP_SUGGEST_RULES["variant"])
        if twin.diagnoses:
            out_names.update(MCP_SUGGEST_RULES["diagnoses"])
    # Single exit point: every path returns the names in sorted order, so the
    # result never depends on set iteration order.
    return [{"name": n} for n in sorted(out_names)]