gemeo-twin-stack / src /gemeo /skill_router.py
timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""Skill router — surface relevant skills (from /skills/, MCP servers,
local tools) for a given Gemeo twin.
The repo has 535 skills (curated bio-* + auto-generated agents) and 8
MCP servers. Most are silent until manually invoked. This router scans
the twin's diagnoses + phenotypes + planned treatments and returns the
top-N skills that would be useful to call NEXT, ranked by relevance.
Used by:
- frontend (show "Suggested skills" panel in /space/[id])
- LLM agents (auto-binding the most relevant 5 tools instead of all 461)
- MCP server (gemeo://twin/{id}/suggested-skills)
"""
from __future__ import annotations
import logging
import os
import re
from dataclasses import dataclass, field
# Module-level logger; used to report the index size and per-skill read failures.
logger = logging.getLogger("gemeo.skill_router")
# Root of the skill tree: the "skills" directory one level above this module's
# own directory. Each immediate subdirectory holding a SKILL.md is one skill.
SKILLS_DIR = os.path.join(os.path.dirname(__file__), "..", "skills")
@dataclass
class SkillCandidate:
    """A skill suggested for a twin, together with its relevance score.

    Instances are built by ``suggest`` from entries of the SKILL.md index.
    """

    # Display name: the frontmatter ``name`` value, or the skill directory name.
    name: str
    # Filesystem path of the skill's SKILL.md file.
    path: str
    # Frontmatter ``description`` value (the indexer keeps at most 300 characters).
    description: str
    # Frontmatter ``domain`` value; empty when the skill declares none.
    domain: str = ""
    # Relevance score; ``suggest`` stores it rounded to two decimals.
    score: float = 0.0
    # Twin terms and "trigger:<keyword>" markers that contributed to the score
    # (``suggest`` keeps at most five).
    matched_terms: list[str] = field(default_factory=list)
# ─── lightweight SKILL.md indexer ─────────────────────────────────────────
# Process-wide cache filled by _load_skill_index(). It is only reused while
# non-empty, so an empty scan result is recomputed on the next call.
_SKILL_INDEX: list[dict] = []  # cached on first call
def _load_skill_index() -> list[dict]:
    """Return the skill index, building and caching it on first use.

    Scans every immediate subdirectory of ``SKILLS_DIR`` for a ``SKILL.md``,
    reads at most its first 4000 characters and extracts the frontmatter
    fields used for routing. Files that cannot be read or decoded are skipped
    and reported at debug level, so one broken skill never hides the others.

    Returns:
        A list of dicts with the keys ``name``, ``slug``, ``description``,
        ``domain``, ``triggers`` and ``path``. The list is empty when
        ``SKILLS_DIR`` does not exist or contains no skills.
    """
    global _SKILL_INDEX
    if _SKILL_INDEX:
        return _SKILL_INDEX
    if not os.path.isdir(SKILLS_DIR):
        return []
    index = []
    # os.listdir() returns entries in arbitrary order; sort so the index (and
    # any ranking tie that falls back to index order) is reproducible.
    for entry in sorted(os.listdir(SKILLS_DIR)):
        sk_md = os.path.join(SKILLS_DIR, entry, "SKILL.md")
        if not os.path.isfile(sk_md):
            continue
        try:
            with open(sk_md, encoding="utf-8") as f:
                content = f.read(4000)
        except (OSError, UnicodeError) as e:
            # Best effort: an unreadable or mis-encoded file is skipped.
            logger.debug(f"skill {entry} parse failed: {e}")
            continue
        index.append(_parse_skill_frontmatter(content, entry, sk_md))
    _SKILL_INDEX = index
    logger.info(f"skill router indexed {len(index)} skills")
    return index


def _strip_quotes(value: str) -> str:
    """Trim whitespace and one pair of matching surrounding quotes."""
    value = value.strip()
    if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
        return value[1:-1].strip()
    return value


def _parse_skill_frontmatter(content: str, slug: str, path: str) -> dict:
    """Build one index entry from the leading text of a SKILL.md file."""
    m = re.search(r"---\s*\n(.*?)\n---", content, re.DOTALL)
    fm = m.group(1) if m else ""
    name_m = re.search(r"^name:\s*(.+)$", fm, re.MULTILINE)
    desc_m = re.search(r"^description:\s*(.+)$", fm, re.MULTILINE)
    # Anchored to the start of a (possibly indented) line and kept on that
    # line, so "domain:" inside another value or an empty "domain:" key
    # followed by a different key is not picked up.
    domain_m = re.search(
        r"^[ \t]*domain:[ \t]*[\"']?([\w\-]+)", fm, re.MULTILINE
    )
    triggers_m = re.search(r"trigger_keywords:\s*\n((?:\s+-\s*.+\n?)+)", fm)

    triggers = []
    if triggers_m:
        for line in triggers_m.group(1).splitlines():
            item = line.strip()
            if item.startswith("-"):
                item = item[1:]  # drop the YAML list marker only
            item = _strip_quotes(item)
            # Blank lines and `- ""` items would otherwise become empty
            # triggers, which match every twin during scoring.
            if item:
                triggers.append(item)

    name = _strip_quotes(name_m.group(1)) if name_m else ""
    description = _strip_quotes(desc_m.group(1)) if desc_m else ""
    return {
        "name": name or slug,
        "slug": slug,
        "description": description[:300],
        "domain": (domain_m.group(1).strip() if domain_m else ""),
        "triggers": triggers,
        "path": path,
    }
# ─── relevance scoring ────────────────────────────────────────────────────
def _terms_for_twin(twin) -> set[str]:
"""Extract relevance terms from a Gemeo twin."""
terms: set[str] = set()
if twin is None:
return terms
# Top diagnoses (name + ORPHA + CID)
for d in (twin.diagnoses or [])[:5]:
for k in ("name", "disease", "orpha"):
v = d.get(k)
if v: terms.add(str(v).lower())
# Phenotypes (HPO + names)
if twin.subgraph:
for n in (twin.subgraph.nodes or [])[:30]:
if n.label == "Phenotype":
terms.add(str(n.name).lower())
if n.label == "Gene":
terms.add(str(n.name).lower())
# Drugs in regimen
if twin.drugs:
for d in (twin.drugs.candidates or [])[:10]:
terms.add(str(d.get("name", "")).lower())
# Family / inheritance
if twin.family and twin.family.inheritance_mode:
terms.add(twin.family.inheritance_mode.lower())
return {t for t in terms if t and len(t) > 2}
def _score_skill(skill: dict, terms: set[str]) -> tuple[float, list[str]]:
"""Score a skill against extracted terms. Higher = more relevant."""
desc = (skill.get("description", "") + " " + " ".join(skill.get("triggers", []))).lower()
matched = []
score = 0.0
for t in terms:
if t in desc:
score += 1.0
matched.append(t)
else:
# token-level
for tok in re.split(r"\W+", t):
if len(tok) > 3 and tok in desc:
score += 0.3
matched.append(t)
break
# Boost if explicit trigger match
for trig in skill.get("triggers", []):
if trig.lower() in " ".join(terms):
score += 0.5
matched.append(f"trigger:{trig}")
return score, list(set(matched))
def suggest(twin, *, top_n: int = 8, min_score: float = 0.5) -> list[SkillCandidate]:
    """Return the top-N skill suggestions for this twin.

    Args:
        twin: The twin to suggest skills for (``None`` yields no suggestions).
        top_n: Maximum number of candidates to return; values <= 0 yield an
            empty list.
        min_score: Skills scoring below this threshold are discarded.

    Returns:
        Candidates ordered by descending score, with ties broken by skill
        name so the ranking is reproducible. Empty when no skills are indexed
        or no terms can be extracted from the twin.
    """
    skills = _load_skill_index()
    if not skills:
        return []
    terms = _terms_for_twin(twin)
    if not terms:
        return []

    scored = []
    for sk in skills:
        score, matched = _score_skill(sk, terms)
        if score < min_score:
            continue
        scored.append(SkillCandidate(
            name=sk["name"],
            path=sk["path"],
            description=sk["description"],
            domain=sk["domain"],
            score=round(score, 2),
            matched_terms=matched[:5],
        ))

    # Highest score first; equal scores fall back to the name instead of the
    # order in which the index happened to be built.
    scored.sort(key=lambda c: (-c.score, c.name))
    # Clamp so a negative top_n returns nothing rather than slicing items off
    # the tail of the ranking.
    return scored[:max(top_n, 0)]
# ─── MCP server suggestions ───────────────────────────────────────────────
MCP_SUGGEST_RULES = {
    # case has Drug candidate → suggest biomcp + clinicaltrials + pubmed
    "drugs": ["biomcp", "clinicaltrials", "pubmed"],
    # case has variant (detected through Gene nodes in the subgraph)
    # → suggest biomcp (MyVariant.info + ClinVar) + neo4j
    "variant": ["biomcp", "neo4j"],
    # case has rare disease → suggest gemeo (already loaded), pubmed, exa
    "diagnoses": ["gemeo", "pubmed", "exa"],
    # always-on
    "always": ["gemeo", "neo4j"],
}


def suggest_mcp_servers(twin) -> list[dict]:
    """Suggest MCP servers to enable for this twin.

    The always-on servers are always included. Further servers are added when
    the twin has drug candidates, when its subgraph contains a ``Gene`` node,
    and when it has diagnoses, following ``MCP_SUGGEST_RULES``.

    Args:
        twin: The twin object, or ``None`` for the always-on set only.

    Returns:
        One ``{"name": <server>}`` dict per suggested server, without
        duplicates and sorted by name on every path.
    """
    names = set(MCP_SUGGEST_RULES["always"])
    if twin is not None:
        if twin.drugs and twin.drugs.candidates:
            names.update(MCP_SUGGEST_RULES["drugs"])
        if twin.subgraph and any(
            node.label == "Gene" for node in (twin.subgraph.nodes or [])
        ):
            names.update(MCP_SUGGEST_RULES["variant"])
        if twin.diagnoses:
            names.update(MCP_SUGGEST_RULES["diagnoses"])
    # Sorted on every path, including twin=None, so callers get a stable
    # order instead of the per-process iteration order of a set of strings.
    return [{"name": n} for n in sorted(names)]