"""Skill router — surface relevant skills (from /skills/, MCP servers, local tools)
for a given Gemeo twin.

The repo has 535 skills (curated bio-* + auto-generated agents) and 8 MCP
servers. Most are silent until manually invoked. This router scans the twin's
diagnoses + phenotypes + planned treatments and returns the top-N skills that
would be useful to call NEXT, ranked by relevance.

Used by:
  - frontend (show "Suggested skills" panel in /space/[id])
  - LLM agents (auto-binding the most relevant 5 tools instead of all 461)
  - MCP server (gemeo://twin/{id}/suggested-skills)
"""
from __future__ import annotations

import logging
import os
import re
from dataclasses import dataclass, field

logger = logging.getLogger("gemeo.skill_router")

SKILLS_DIR = os.path.join(os.path.dirname(__file__), "..", "skills")


@dataclass
class SkillCandidate:
    """One ranked skill suggestion produced by :func:`suggest`."""

    name: str
    path: str
    description: str
    domain: str = ""
    score: float = 0.0
    matched_terms: list = field(default_factory=list)


# ─── lightweight SKILL.md indexer ─────────────────────────────────────────

_SKILL_INDEX: list[dict] = []  # cached on first call

# Only the head of each SKILL.md is read; the frontmatter is expected there.
_FRONTMATTER_READ_CHARS = 4000

# Compiled once; these run for every skill directory during indexing/scoring.
_FRONTMATTER_RE = re.compile(r"---\s*\n(.*?)\n---", re.DOTALL)
_NAME_RE = re.compile(r"^name:\s*(.+)$", re.MULTILINE)
_DESCRIPTION_RE = re.compile(r"^description:\s*(.+)$", re.MULTILINE)
_DOMAIN_RE = re.compile(r"domain:\s*([\w\-]+)")
_TRIGGERS_RE = re.compile(r"trigger_keywords:\s*\n((?:\s+-\s*.+\n?)+)")
_TOKEN_SPLIT_RE = re.compile(r"\W+")


def _parse_skill_md(entry: str, path: str, content: str) -> dict:
    """Build one index record from the head of a SKILL.md file.

    Args:
        entry: Directory name of the skill; used as ``slug`` and as the
            fallback ``name`` when the frontmatter has no ``name:`` line.
        path: Path of the SKILL.md file, stored verbatim in the record.
        content: Leading text of the file (frontmatter is searched in it).

    Returns:
        Dict with keys ``name``, ``slug``, ``description`` (at most 300
        characters), ``domain``, ``triggers`` and ``path``.
    """
    fm_match = _FRONTMATTER_RE.search(content)
    frontmatter = fm_match.group(1) if fm_match else ""

    name_m = _NAME_RE.search(frontmatter)
    desc_m = _DESCRIPTION_RE.search(frontmatter)
    domain_m = _DOMAIN_RE.search(frontmatter)
    triggers_m = _TRIGGERS_RE.search(frontmatter)

    triggers: list[str] = []
    if triggers_m:
        for line in triggers_m.group(1).splitlines():
            trigger = line.strip().lstrip("- ").strip().strip('"').strip("'")
            # Blank lines and bare "-" items collapse to ""; an empty trigger
            # is a substring of every string and would match every twin.
            if trigger:
                triggers.append(trigger)

    return {
        "name": name_m.group(1).strip() if name_m else entry,
        "slug": entry,
        "description": (desc_m.group(1).strip() if desc_m else "")[:300],
        "domain": domain_m.group(1).strip() if domain_m else "",
        "triggers": triggers,
        "path": path,
    }


def _load_skill_index() -> list[dict]:
    """Read all SKILL.md frontmatter once.

    Scans every sub-directory of ``SKILLS_DIR`` that contains a ``SKILL.md``
    and caches the parsed records in the module-level ``_SKILL_INDEX``.
    Unreadable files are skipped (logged at debug level).

    Returns:
        The list of index records; ``[]`` when ``SKILLS_DIR`` is missing.
        An empty result is not cached, so the directory is re-scanned on the
        next call.
    """
    global _SKILL_INDEX
    if _SKILL_INDEX:
        return _SKILL_INDEX
    if not os.path.isdir(SKILLS_DIR):
        return []

    out = []
    # os.listdir() order is arbitrary; sort so the index is reproducible.
    for entry in sorted(os.listdir(SKILLS_DIR)):
        sk_md = os.path.join(SKILLS_DIR, entry, "SKILL.md")
        if not os.path.isfile(sk_md):
            continue
        try:
            with open(sk_md, encoding="utf-8") as f:
                content = f.read(_FRONTMATTER_READ_CHARS)
        except (OSError, UnicodeError) as e:
            # Best effort: one unreadable skill must not break the index.
            logger.debug("skill %s parse failed: %s", entry, e)
            continue
        out.append(_parse_skill_md(entry, sk_md, content))

    _SKILL_INDEX = out
    logger.info("skill router indexed %d skills", len(out))
    return out


# ─── relevance scoring ────────────────────────────────────────────────────

def _add_term(terms: set[str], value) -> None:
    """Add ``value`` to ``terms`` as lower-cased text, skipping missing values.

    Skipping ``None``/empty values matters: ``str(None).lower()`` would
    otherwise inject the bogus search term ``"none"``.
    """
    if not value:
        return
    text = str(value).strip().lower()
    if text:
        terms.add(text)


def _terms_for_twin(twin) -> set[str]:
    """Extract relevance terms from a Gemeo twin.

    Reads ``twin.diagnoses`` (first 5 entries; keys ``name``, ``disease``,
    ``orpha``), ``twin.subgraph.nodes`` (first 30 nodes; names of those
    labelled ``Phenotype`` or ``Gene``), ``twin.drugs.candidates`` (first 10;
    key ``name``) and ``twin.family.inheritance_mode``.

    Returns:
        Lower-cased terms longer than two characters; empty set when ``twin``
        is ``None``.
    """
    terms: set[str] = set()
    if twin is None:
        return terms

    # Top diagnoses (name + disease + ORPHA code)
    for diagnosis in (twin.diagnoses or [])[:5]:
        for key in ("name", "disease", "orpha"):
            _add_term(terms, diagnosis.get(key))

    # Phenotype and gene node names from the twin's subgraph
    if twin.subgraph:
        for node in (twin.subgraph.nodes or [])[:30]:
            if node.label in ("Phenotype", "Gene"):
                _add_term(terms, node.name)

    # Drugs in regimen
    if twin.drugs:
        for drug in (twin.drugs.candidates or [])[:10]:
            _add_term(terms, drug.get("name"))

    # Family / inheritance
    if twin.family and twin.family.inheritance_mode:
        _add_term(terms, twin.family.inheritance_mode)

    # Very short terms are dropped; they would substring-match almost anything.
    return {t for t in terms if len(t) > 2}


def _score_skill(skill: dict, terms: set[str]) -> tuple[float, list[str]]:
    """Score a skill against extracted terms. Higher = more relevant.

    Scoring:
      * +1.0 for each term found verbatim in the skill's description/triggers
      * +0.3 for each term that only matches through one of its tokens
        (tokens longer than 3 characters)
      * +0.5 for each trigger keyword contained in one of the terms

    Args:
        skill: Index record; ``description`` and ``triggers`` are read.
        terms: Lower-cased terms, as produced by ``_terms_for_twin``.

    Returns:
        ``(score, matched)`` where ``matched`` lists the matching terms and
        ``"trigger:<keyword>"`` entries without duplicates, in a stable order
        (sorted terms first, then triggers in declaration order).
    """
    # Empty triggers are ignored: "" is a substring of everything.
    triggers = [t for t in (skill.get("triggers") or []) if t]
    haystack = ((skill.get("description") or "") + " " + " ".join(triggers)).lower()

    # Iterate in sorted order so the score and match list do not depend on
    # set iteration order, which varies between interpreter runs.
    ordered_terms = sorted(terms)

    matched: list[str] = []
    score = 0.0
    for term in ordered_terms:
        if term in haystack:
            score += 1.0
            matched.append(term)
        elif any(
            len(tok) > 3 and tok in haystack
            for tok in _TOKEN_SPLIT_RE.split(term)
        ):
            # token-level partial match
            score += 0.3
            matched.append(term)

    # Boost if explicit trigger match. Each term is checked on its own:
    # matching against all terms joined into one string lets a multi-word
    # trigger match across unrelated terms depending on join order.
    for trig in triggers:
        needle = trig.lower()
        if any(needle in term for term in ordered_terms):
            score += 0.5
            matched.append(f"trigger:{trig}")

    return score, list(dict.fromkeys(matched))


def suggest(twin, *, top_n: int = 8, min_score: float = 0.5) -> list[SkillCandidate]:
    """Return top-N skill suggestions for this twin.

    Args:
        twin: Gemeo twin to analyse; ``None`` yields no suggestions.
        top_n: Maximum number of candidates returned (values below 0 are
            treated as 0).
        min_score: Skills scoring below this threshold are discarded.

    Returns:
        Candidates ordered by descending score; ties are ordered by name so
        the result is reproducible. Each candidate carries at most five
        matched terms.
    """
    skills = _load_skill_index()
    if not skills:
        return []
    terms = _terms_for_twin(twin)
    if not terms:
        return []

    scored = []
    for sk in skills:
        score, matched = _score_skill(sk, terms)
        if score < min_score:
            continue
        scored.append(SkillCandidate(
            name=sk["name"],
            path=sk["path"],
            description=sk["description"],
            domain=sk["domain"],
            score=round(score, 2),
            matched_terms=matched[:5],
        ))

    scored.sort(key=lambda c: (-c.score, c.name))
    return scored[:max(top_n, 0)]


# ─── MCP server suggestions ───────────────────────────────────────────────

MCP_SUGGEST_RULES = {
    # case has Drug candidate → suggest biomcp + clinicaltrials
    "drugs": ["biomcp", "clinicaltrials", "pubmed"],
    # case has variant → suggest biomcp (MyVariant.info + ClinVar)
    "variant": ["biomcp", "neo4j"],
    # case has rare disease → suggest gemeo (already loaded), pubmed, exa
    "diagnoses": ["gemeo", "pubmed", "exa"],
    # always-on
    "always": ["gemeo", "neo4j"],
}


def suggest_mcp_servers(twin) -> list[dict]:
    """Suggest MCP servers to enable for this twin.

    Starts from the ``always`` rule and adds the ``drugs`` rule when the twin
    has drug candidates, the ``variant`` rule when its subgraph contains a
    ``Gene`` node, and the ``diagnoses`` rule when it has diagnoses.

    Returns:
        ``[{"name": <server>}, ...]`` sorted by server name, including when
        ``twin`` is ``None`` (only the always-on servers are returned then).
    """
    out_names = set(MCP_SUGGEST_RULES["always"])
    if twin is not None:
        if twin.drugs and twin.drugs.candidates:
            out_names.update(MCP_SUGGEST_RULES["drugs"])
        if twin.subgraph and any(
            n.label == "Gene" for n in (twin.subgraph.nodes or [])
        ):
            out_names.update(MCP_SUGGEST_RULES["variant"])
        if twin.diagnoses:
            out_names.update(MCP_SUGGEST_RULES["diagnoses"])
    # Sorted on every path so callers get a stable order.
    return [{"name": n} for n in sorted(out_names)]