gemeo-twin-stack / src /gemeo /feedback.py
timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""Feedback loop β€” capture user corrections to retrain.
Each piece of feedback is appended to a JSONL ledger:
./gemeo/artifacts/feedback.jsonl
Schema (one JSON per line):
{
"ts": "...",
"twin_id": "gemeo_...",
"case_id": "...",
"kind": "diagnosis|trajectory|drug|trial|next_question|cohort",
"target": {...}, # what the model said
"user_correction": {...}, # what the user said is right
"user_id": "...",
"comment": "..."
}
The training pipelines in `gemeo/train/` consume this ledger to:
- HGT: re-rank patient embeddings via supervised contrastive (positive=
confirmed-similar, negative=user-rejected)
- TxGNN: hard-negative mining for drug recs the user marked wrong
- TGNN: outcome supervision when user provides actual trajectory
This is the closed-loop piece — the feature that turns Gemeo from a
static SOTA model into a *learning* digital twin.
"""
from __future__ import annotations
import os
import json
import logging
from datetime import datetime, timezone
# Module-level logger; the name is fixed rather than derived from __name__.
logger = logging.getLogger("gemeo.feedback")

# Ledger location: the GEMEO_FEEDBACK_LEDGER environment variable wins when
# set; otherwise fall back to artifacts/feedback.jsonl next to this module.
LEDGER_PATH = os.getenv(
    "GEMEO_FEEDBACK_LEDGER",
    os.path.join(os.path.dirname(__file__), "artifacts", "feedback.jsonl"),
)
def _ensure_dir():
    """Create the directory that holds the feedback ledger if it is missing.

    A ledger path without a directory component (a bare filename, e.g. one
    supplied through ``GEMEO_FEEDBACK_LEDGER``) resolves against the current
    working directory, so there is nothing to create in that case.

    Raises:
        OSError: if the directory cannot be created.
    """
    parent = os.path.dirname(LEDGER_PATH)
    # os.makedirs("") raises FileNotFoundError, so only call it when the
    # ledger path actually names a parent directory.
    if parent:
        os.makedirs(parent, exist_ok=True)
def record(
    *,
    twin_id: str,
    kind: str,
    target: dict,
    user_correction: dict,
    case_id: str | None = None,
    user_id: str | None = None,
    comment: str | None = None,
) -> dict:
    """Append one feedback record to the JSONL ledger and return it.

    Persistence is best-effort: any failure while preparing the ledger
    directory, serializing the record, or writing the line is logged and
    swallowed, and the record dict is still returned to the caller.

    Args:
        twin_id: Identifier of the twin the feedback refers to.
        kind: Feedback category. Documented values are "diagnosis",
            "trajectory", "drug", "trial", "next_question", "cohort" and
            "subgraph"; the value is stored as given and not validated here.
        target: What the model said.
        user_correction: What the user said is right.
        case_id: Optional case identifier.
        user_id: Optional identifier of the user giving the feedback.
        comment: Optional free-text comment.

    Returns:
        The record dict, including a UTC ISO-8601 ``ts`` timestamp.
    """
    rec = {
        "ts": datetime.now(timezone.utc).isoformat(),
        "twin_id": twin_id,
        "case_id": case_id,
        "kind": kind,
        "target": target,
        "user_correction": user_correction,
        "user_id": user_id,
        "comment": comment,
    }
    try:
        # Serialize before touching the filesystem; default=str stringifies
        # values json cannot encode natively.
        line = json.dumps(rec, default=str) + "\n"
        # Directory creation lives inside the try so that it is covered by
        # the same log-and-continue handling as the write itself.
        _ensure_dir()
        with open(LEDGER_PATH, "a", encoding="utf-8") as f:
            f.write(line)
    except Exception as e:
        # Deliberately broad: recording feedback must never break the caller.
        logger.error("failed to write feedback: %s", e)
    return rec
def stats() -> dict:
    """Summarize the ledger as a total and per-kind counts.

    Intended for health reporting (e.g. /api/gemeo/health). Lines that fail
    to parse or cannot be counted are skipped; a read failure is logged and
    whatever was tallied up to that point is returned.

    Returns:
        A dict with keys ``total``, ``by_kind`` and ``ledger`` (the path).
    """
    by_kind = {}
    total = 0
    if os.path.exists(LEDGER_PATH):
        try:
            with open(LEDGER_PATH) as ledger:
                for raw in ledger:
                    try:
                        label = json.loads(raw).get("kind", "unknown")
                        by_kind[label] = by_kind.get(label, 0) + 1
                    except Exception:
                        # Unparseable or malformed line: leave it uncounted.
                        continue
                    total += 1
        except Exception as e:
            logger.error(f"failed to read ledger: {e}")
    return {"total": total, "by_kind": by_kind, "ledger": LEDGER_PATH}
def iter_records(kind: str | None = None):
    """Yield feedback records from the ledger, optionally filtered by kind.

    Used by the training pipelines. Yields nothing when the ledger file does
    not exist. Lines that are not valid JSON, and JSON values that are not
    objects, are skipped so that one corrupt line cannot abort iteration.

    Args:
        kind: When truthy, only records whose ``"kind"`` equals this value
            are yielded; ``None`` (or an empty string) disables filtering.

    Yields:
        One dict per well-formed ledger line.
    """
    if not os.path.exists(LEDGER_PATH):
        return
    with open(LEDGER_PATH, encoding="utf-8") as f:
        for line in f:
            try:
                rec = json.loads(line)
            except (ValueError, RecursionError):
                # json.JSONDecodeError is a ValueError; RecursionError covers
                # pathologically nested input.
                continue
            # A bare JSON scalar or array is not a record; calling .get on it
            # would raise AttributeError.
            if not isinstance(rec, dict):
                continue
            if kind and rec.get("kind") != kind:
                continue
            yield rec