# Narada: Teaching an LLM to Diagnose Rare Disease by Navigating a Knowledge Graph
**KrishVenky · Meta × PyTorch OpenEnv Hackathon 2026**
---
## The Problem
A rare disease patient waits, on average, 4–7 years to receive a correct diagnosis. In that window they cycle through specialists, accumulate misdiagnoses, and often receive treatments for conditions they don't have. The bottleneck is not compassion — it is the sheer cognitive load of cross-referencing thousands of variants, phenotypes, and disease mechanisms simultaneously.
We asked: can reinforcement learning teach a small language model to reason through a clinical knowledge graph the way a senior geneticist does — following causal chains, resisting high-salience decoys, and building evidence across multiple steps?
**Narada** is our answer.
---
## The Environment
Narada is a live OpenEnv-compliant RL environment backed by a **55,000-node knowledge graph** built from ClinVar and HPO (Human Phenotype Ontology). The agent navigates the graph via WebSocket, making one decision per step.
**Graph node types:** phenotype → disease → gene → variant → pathway
**Three task tiers:**
| Task | Difficulty | What it tests |
|---|---|---|
| `monogenic` | Easy | Follow phenotype→disease→gene→variant chain. Single causal variant, 3–4 HPO terms, 4–8 hops. |
| `oligogenic` | Medium | Two contributing variants across different genes. 5–7 phenotype terms across two organ systems. Must find both within budget. |
| `phenotype_mismatch` | Hard | A high-pathogenicity BRCA1/BRCA2/TP53 frameshift is planted as a decoy. Patient phenotypes are cardiac or neurological. The correct variant is lower-pathogenicity but phenotypically matched. Most untrained LLMs flag the decoy. |
**Available actions:** `hop(node_id)` · `flag_causal(variant_id)` · `backtrack()` · `request_lab(test)` · `summarise_trail()`
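As a rough illustration of how one of these actions might be recovered from a model completion, here is a minimal parsing sketch. The action names come from the list above; the JSON field names (`action`, `arg`), the `parse_action` helper, and the handling of Qwen3-style `<think>` tags are assumptions made for this example, not the environment's documented wire format.

```python
import json
import re

# Action names are from the post. The JSON field names ("action", "arg")
# are assumed for this sketch and may differ from the real schema.
ACTIONS_WITH_ARG = {"hop", "flag_causal", "request_lab"}
ACTIONS_NO_ARG = {"backtrack", "summarise_trail"}


def parse_action(completion: str):
    """Return (action, arg) from a model completion, or None if invalid."""
    # Drop any closed <think>...</think> block so braces inside it are ignored.
    visible = re.sub(r"<think>.*?</think>", "", completion, flags=re.DOTALL)
    # Collect flat {...} spans and keep the last one as the intended action.
    candidates = re.findall(r"\{[^{}]*\}", visible)
    if not candidates:
        return None
    try:
        payload = json.loads(candidates[-1])
    except json.JSONDecodeError:
        return None
    name = payload.get("action")
    arg = payload.get("arg")
    if name in ACTIONS_NO_ARG:
        return (name, None)
    if name in ACTIONS_WITH_ARG and isinstance(arg, str) and arg:
        return (name, arg)
    return None


print(parse_action('<think>cardiac terms first</think>{"action": "hop", "arg": "HP:0001657"}'))
print(parse_action("<think>ran out of tokens before the action"))
```

A completion that is cut off before it emits an action object parses to `None`, which a reward function could treat as a wasted step.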
The environment is **live at** `https://krishvenky-narada-env.hf.space` and fully reproducible — the full server source is in `src/envs/narada/`.
---
## Training Approach: GRPO with Curriculum RL
We trained **Qwen3-1.7B** with 4-bit quantization (Unsloth + bitsandbytes) using **Group Relative Policy Optimisation (GRPO)** via TRL.
**Key design choices:**
- **G=8 completions per prompt, temperature=1.1** — enough diversity to get meaningful reward variance without collapsing
- **Async-parallel reward evaluation** — all 8 completions evaluated concurrently via separate WebSocket sessions, cutting collection from ~5 min to ~30 sec per batch
- **Curriculum order:** monogenic (80 steps) → oligogenic (60 steps) → phenotype_mismatch (60 steps) — teach basic navigation before introducing multi-objective and adversarial tasks
- **max_completion_length=800** — critical; at 300 tokens Qwen3's thinking blocks consume the entire budget leaving no room for the JSON action
- **17.4M trainable parameters** via LoRA rank 16
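To make the role of reward variance concrete, the sketch below shows the group-relative advantage idea behind GRPO: each completion's reward is standardised against the other completions for the same prompt. It is a simplified illustration, not TRL's implementation; the `group_relative_advantages` helper, the `eps` value, the use of population standard deviation, and the reward numbers are all assumptions.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-4):
    """Standardise rewards within one prompt's group of G completions.

    eps is an assumed small constant that avoids division by zero.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; a library may use sample std
    return [(r - mu) / (sigma + eps) for r in rewards]


# Made-up rewards for G=8 completions of one prompt (illustration only).
mixed = [0.2, 0.4, 0.4, 0.5, 0.5, 0.6, 0.7, 0.7]
collapsed = [0.5] * 8

print(group_relative_advantages(mixed))      # above-mean completions get positive values
print(group_relative_advantages(collapsed))  # every advantage is 0.0: no learning signal
```

When all eight completions earn the same reward, every advantage is zero and the update carries no signal, which is why sampling diversity within a group matters.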
---
## Results
### Baseline (Zero-Shot, Qwen3-1.7B)
Before any gradient updates, we benchmarked the base model across 5 evaluation seeds per task:
| Task | Zero-Shot Score |
|---|---|
| monogenic | 0.4955 |
| oligogenic | 0.4955 |
| phenotype_mismatch | 0.4955 |
All three tiers sit at the same near-chance score, as expected. The model has the clinical vocabulary but no learned navigation strategy: it hops randomly, ignores the phenotype→disease→gene chain, and flags the BRCA1 decoy in phenotype_mismatch almost every time.
### After GRPO Training
Full curriculum completed: 80 steps monogenic, 60 steps oligogenic, 60 steps phenotype_mismatch (200 total, ~2 hours on an A100).
| Task | Baseline | After GRPO | Relative gain |
|---|---|---|---|
| monogenic | 0.4955 | **0.572** | +15.4% |
| oligogenic | 0.4955 | **0.561** | +13.2% |
| phenotype_mismatch | 0.4955 | **0.552** | +11.4% |
| **Average** | 0.4955 | **0.562** | **+13.3%** |
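The percentages in the last column of the table above are improvements relative to the 0.4955 baseline, not absolute score points. A short check of that arithmetic, using only the scores from the table (the `relative_gain_pct` helper name is ours):

```python
BASELINE = 0.4955
trained = {"monogenic": 0.572, "oligogenic": 0.561, "phenotype_mismatch": 0.552}


def relative_gain_pct(before, after):
    """Percentage improvement of `after` over `before`."""
    return 100.0 * (after - before) / before


gains = {task: round(relative_gain_pct(BASELINE, score), 1)
         for task, score in trained.items()}
average_score = sum(trained.values()) / len(trained)
mean_gain = sum(gains.values()) / len(gains)

print(gains)                    # per-task relative gains in percent
print(round(average_score, 3))  # average trained score
print(round(mean_gain, 1))      # mean of the three per-task gains
```

The average row matches the mean of the three per-task gains (13.3%). Computing the gain from the averaged score instead gives about 13.35%, so the two readings differ only at the rounding boundary. In absolute terms the lift is roughly 0.06 to 0.08 score points per tier.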
![Before/After GRPO](results/before_after.png)
*Zero-shot baseline vs. GRPO-trained Qwen3-1.7B across all three task tiers*
![Reward Curve](results/reward_curve.png)
*Mean reward across 200 training steps. Shaded band = reward_std. Dotted line = zero-shot baseline. Phase boundaries mark curriculum transitions.*
![Loss Curve](results/loss_curve.png)
*Policy loss across curriculum — negative early (clean gradient direction), rises during mid-training exploration, recovers as policy stabilises.*
![Reward Std](results/reward_std.png)
*reward_std > 0 throughout all 200 steps: GRPO never hit zero-gradient collapse. The training signal was real.*
### Training Dynamics
Key observations from the step logs:
- **reward_std non-zero throughout** — GRPO received real gradient signal across all 200 steps; the model was never in a collapsed uniform-reward state
- **Completion length dropped ~70%**: from ~430 tokens (monogenic early) to ~125 tokens (phenotype_mismatch final). The model learned to be decisive rather than exploratory
- **KL stable** — peaked at 0.044 during oligogenic, never diverged; the policy stayed close to base while improving
- **Strongest early reward spike at step 20 monogenic (0.712)**: the model rapidly discovered the basic phenotype→disease→gene→variant path before reward variance forced more precise reasoning
The phenotype_mismatch improvement (+11.4% relative, from 0.4955 to 0.552) is the most significant result. This is the task designed to fool untrained LLMs with a high-pathogenicity cancer gene decoy. An 11% relative lift after 60 steps on this tier, with RL-only training and no instruction tuning or clinical examples, suggests the reward signal successfully penalised phenotypically irrelevant flags.
---
## Reproducing the Full Run
Everything is wired up. If you have an L4 or A10G and 4 hours:
```bash
git clone https://github.com/KrishVenky/ClinDetect
cd ClinDetect
# Set env vars
export HF_TOKEN=your_token
export HF_PUSH_REPO=your_org/narada-detective-lora
export ENV_URL=https://krishvenky-narada-env.hf.space
pip install -r training/requirements.txt
python training/train.py
```
The script will:
1. Collect episodes in parallel (~30 sec)
2. Benchmark the base model (zero-shot baseline)
3. Run the three-phase curriculum
4. Benchmark the trained model
5. Generate `training_curve.png` and `before_after.png`
6. Push the LoRA adapter and plots to your HF repo
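Step 1 depends on scoring several completions at once rather than one after another. The sketch below shows that concurrency pattern with `asyncio.gather`. The `score_completion` coroutine is a hypothetical stand-in that sleeps instead of opening a WebSocket session, and its reward rule is invented purely so the example runs offline.

```python
import asyncio


async def score_completion(completion: str) -> float:
    """Hypothetical stand-in for one episode against the environment."""
    await asyncio.sleep(0.05)  # simulates network and environment latency
    # Invented toy rule so the sketch is runnable; not the real reward.
    return 1.0 if "flag_causal" in completion else 0.0


async def score_group(completions):
    # One coroutine per completion; gather runs them concurrently
    # and returns results in the same order as the inputs.
    return await asyncio.gather(*(score_completion(c) for c in completions))


def evaluate(completions):
    """Synchronous entry point that a training loop could call."""
    return asyncio.run(score_group(completions))


print(evaluate(["hop(n1)", "flag_causal(v7)", "backtrack()"]))
```

Because the waits overlap, total wall time stays close to that of the slowest single episode instead of the sum over all of them.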
The environment server is live and will remain so. **You have a week — run it.**
---
## What the Results Mean
The `phenotype_mismatch` result is the one worth paying attention to. It operationalises a failure mode that exists in real clinical practice: over-weighting variant pathogenicity score at the expense of phenotypic fit. A BRCA1 frameshift is objectively "more pathogenic" than an SCN5A missense, but if your patient has long QT syndrome and not breast cancer, BRCA1 is irrelevant.
Training a 1.7B model to resist that signal, via RL reward shaping rather than instruction following, is the core thesis. The baseline shows the model failing this task at near-chance level. After 60 steps of GRPO on this tier with a reward function that specifically penalises decoy flags (`cardiac_flag * 1.0 - decoy_flag * 0.5`), the score climbs 11.4% relative to baseline (0.4955 to 0.552). We read this as the model learning causal discipline rather than pattern matching.
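The decoy penalty term quoted above can be written out as a small function. This is only a sketch of that one expression: treating both flags as 0/1 indicators for an episode is an assumption, the `flag_reward` name is ours, and the full reward used in training may include other terms not shown here.

```python
def flag_reward(cardiac_flag: bool, decoy_flag: bool) -> float:
    """Evaluate the quoted term: cardiac_flag * 1.0 - decoy_flag * 0.5.

    Both arguments are assumed to be per-episode indicators: whether the
    phenotype-matched variant was flagged, and whether the decoy was flagged.
    """
    return float(cardiac_flag) * 1.0 - float(decoy_flag) * 0.5


for cardiac, decoy in [(True, False), (True, True), (False, False), (False, True)]:
    print(cardiac, decoy, flag_reward(cardiac, decoy))
```

Under this term, flagging only the matched variant scores highest, flagging both loses half the credit, and flagging only the decoy is the one outcome that scores below doing nothing.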
At 1.7B parameters with a 17.4M-parameter trainable LoRA, this is a proof of concept. The trajectory across all three tiers is consistent: RL fine-tuning moves a general-purpose LLM toward principled clinical reasoning on a task it had no prior exposure to.
---
## Artifacts
| Artifact | Location |
|---|---|
| Environment server | `src/envs/narada/` |
| OpenEnv spec | `openenv.yaml` |
| Training script | `training/train.py` |
| Training notebook | `training/narada_grpo.ipynb` |
| Inference benchmark | `inference.py` |
| LoRA adapter (monogenic-trained) | `KrishVenky/narada-detective-lora` on HF |
| Live environment | `https://krishvenky-narada-env.hf.space` |
---
*Built in 24 hours. Exams start tomorrow.*