# Narada: Teaching an LLM to Diagnose Rare Disease by Navigating a Knowledge Graph

**KrishVenky · Meta × PyTorch OpenEnv Hackathon 2026**

---

## The Problem

A rare disease patient waits, on average, 4–7 years to receive a correct diagnosis. In that window they cycle through specialists, accumulate misdiagnoses, and often receive treatments for conditions they don't have. The bottleneck is not compassion — it is the sheer cognitive load of cross-referencing thousands of variants, phenotypes, and disease mechanisms simultaneously.

We asked: can reinforcement learning teach a small language model to reason through a clinical knowledge graph the way a senior geneticist does — following causal chains, resisting high-salience decoys, and building evidence across multiple steps?

**Narada** is our answer.

---

## The Environment

Narada is a live OpenEnv-compliant RL environment backed by a **55,000-node knowledge graph** built from ClinVar and HPO (Human Phenotype Ontology). The agent navigates the graph via WebSocket, making one decision per step.

**Graph node types:** phenotype → disease → gene → variant → pathway

**Three task tiers:**

| Task | Difficulty | What it tests |
|---|---|---|
| `monogenic` | Easy | Follow phenotype→disease→gene→variant chain. Single causal variant, 3–4 HPO terms, 4–8 hops. |
| `oligogenic` | Medium | Two contributing variants across different genes. 5–7 phenotype terms across two organ systems. Must find both within budget. |
| `phenotype_mismatch` | Hard | A high-pathogenicity BRCA1/BRCA2/TP53 frameshift is planted as a decoy. Patient phenotypes are cardiac or neurological. The correct variant is lower-pathogenicity but phenotypically matched. Most untrained LLMs flag the decoy. |

**Available actions:** `hop(node_id)` · `flag_causal(variant_id)` · `backtrack()` · `request_lab(test)` · `summarise_trail()`
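
As a rough illustration of how these five actions might be validated before being sent to the environment, here is a minimal sketch. The action names come from the list above; the JSON field names (`action`, `arg`), the `Action` class, and `parse_action` are assumptions made for this example, not the environment's documented wire format.

```python
import json
from dataclasses import dataclass
from typing import Optional

# Action names are taken from the list above. Which ones take an argument is
# inferred from their signatures there. The JSON keys are hypothetical.
ACTIONS_WITH_ARG = {"hop", "flag_causal", "request_lab"}
ACTIONS_WITHOUT_ARG = {"backtrack", "summarise_trail"}


@dataclass(frozen=True)
class Action:
    name: str
    arg: Optional[str] = None


def parse_action(raw: str) -> Action:
    """Parse one JSON-encoded action and reject anything malformed."""
    payload = json.loads(raw)
    if not isinstance(payload, dict):
        raise ValueError("action must be a JSON object")
    name = payload.get("action")
    arg = payload.get("arg")
    if name in ACTIONS_WITH_ARG:
        if not isinstance(arg, str) or not arg:
            raise ValueError(f"{name} requires a non-empty string argument")
        return Action(name, arg)
    if name in ACTIONS_WITHOUT_ARG:
        if arg is not None:
            raise ValueError(f"{name} takes no argument")
        return Action(name)
    raise ValueError(f"unknown action: {name!r}")
```

Rejecting malformed actions on the client side keeps a bad completion from consuming one of the episode's steps, assuming the server counts every received message as a step.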

The environment is **live at** `https://krishvenky-narada-env.hf.space` and fully reproducible — the full server source is in `src/envs/narada/`.

---

## Training Approach: GRPO with Curriculum RL

We trained **Qwen3-1.7B** with 4-bit quantization (Unsloth + bitsandbytes) using **Group Relative Policy Optimisation (GRPO)** via TRL.

**Key design choices:**

- **G=8 completions per prompt, temperature=1.1** — enough diversity to get meaningful reward variance without collapsing
- **Async-parallel reward evaluation** — all 8 completions evaluated concurrently via separate WebSocket sessions, cutting collection from ~5 min to ~30 sec per batch
- **Curriculum order:** monogenic (80 steps) → oligogenic (60 steps) → phenotype_mismatch (60 steps) — teach basic navigation before introducing multi-objective and adversarial tasks
- **max_completion_length=800** — critical; at 300 tokens Qwen3's thinking blocks consume the entire budget leaving no room for the JSON action
- **17.4M trainable parameters** via LoRA rank 16
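
The async-parallel reward evaluation described above can be sketched with `asyncio.gather`. This is only an outline of the pattern: `score_completion` is a hypothetical stand-in that sleeps instead of opening a WebSocket session, and its placeholder reward has nothing to do with the real environment's scoring.

```python
import asyncio


async def score_completion(completion: str, delay: float = 0.1) -> float:
    """Hypothetical stand-in for one environment rollout.

    A real version would open its own WebSocket session, replay the
    completion's actions, and return the episode reward. This one just
    sleeps and returns a deterministic placeholder value.
    """
    await asyncio.sleep(delay)
    return (len(completion) % 5) / 4.0


async def score_group(completions: list[str]) -> list[float]:
    """Score every completion for one prompt concurrently.

    asyncio.gather preserves input order, so rewards[i] belongs to
    completions[i] even though the rollouts overlap in time.
    """
    return await asyncio.gather(*(score_completion(c) for c in completions))
```

Calling `asyncio.run(score_group(group))` on a group of 8 completions takes roughly the time of the slowest single rollout rather than the sum of all 8, which is the effect the speedup above relies on.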

---

## Results

### Baseline (Zero-Shot, Qwen3-1.7B)

Before any gradient updates, we benchmarked the base model across 5 evaluation seeds per task:

| Task | Zero-Shot Score |
|---|---|
| monogenic | 0.4955 |
| oligogenic | 0.4955 |
| phenotype_mismatch | 0.4955 |

Near-chance performance across all three tiers, as expected. The model has the clinical vocabulary but no learned navigation strategy. It hops randomly, ignores the phenotype→disease→gene chain, and flags the BRCA1 decoy in phenotype_mismatch almost every time.

### After GRPO Training

Full curriculum completed: 80 steps monogenic, 60 steps oligogenic, 60 steps phenotype_mismatch (200 total, ~2 hours on an A100).
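
The phase boundaries implied by that schedule can be worked out with a few lines of Python. This is a sketch only: the step counts come from the text, while `CURRICULUM`, `PHASE_ENDS`, `task_for_step`, and the 0-indexed step convention are assumptions made for illustration.

```python
from itertools import accumulate

# (task, steps) pairs in curriculum order, as stated in the text.
CURRICULUM = [("monogenic", 80), ("oligogenic", 60), ("phenotype_mismatch", 60)]

# Cumulative step count at which each phase ends: [80, 140, 200].
PHASE_ENDS = list(accumulate(steps for _, steps in CURRICULUM))


def task_for_step(step: int) -> str:
    """Return the task trained at a global step (0-indexed by assumption)."""
    if step < 0 or step >= PHASE_ENDS[-1]:
        raise ValueError(f"step {step} is outside 0..{PHASE_ENDS[-1] - 1}")
    for (task, _), end in zip(CURRICULUM, PHASE_ENDS):
        if step < end:
            return task
    raise AssertionError("unreachable")
```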

| Task | Baseline | After GRPO | Relative gain |
|---|---|---|---|
| monogenic | 0.4955 | **0.572** | +15.4% |
| oligogenic | 0.4955 | **0.561** | +13.2% |
| phenotype_mismatch | 0.4955 | **0.552** | +11.4% |
| **Average** | 0.4955 | **0.562** | **+13.3%** |
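
The per-task gain figures are relative improvements over the baseline score, not absolute point differences. A short check, using only the rounded values printed in the table (the variable and function names are illustrative):

```python
BASELINE = 0.4955
AFTER_GRPO = {
    "monogenic": 0.572,
    "oligogenic": 0.561,
    "phenotype_mismatch": 0.552,
}


def relative_gain(after: float, baseline: float = BASELINE) -> float:
    """Percentage improvement of `after` over `baseline`."""
    return (after - baseline) / baseline * 100.0


# Per-task gains, rounded to one decimal place as in the table.
gains = {task: round(relative_gain(score), 1) for task, score in AFTER_GRPO.items()}
print(gains)
```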

![Before/After GRPO](results/before_after.png)
*Zero-shot baseline vs. GRPO-trained Qwen3-1.7B across all three task tiers*

![Reward Curve](results/reward_curve.png)
*Mean reward across 200 training steps. Shaded band = reward_std. Dotted line = zero-shot baseline. Phase boundaries mark curriculum transitions.*

![Loss Curve](results/loss_curve.png)
*Policy loss across curriculum — negative early (clean gradient direction), rises during mid-training exploration, recovers as policy stabilises.*

![Reward Std](results/reward_std.png)
*reward_std > 0 throughout all 200 steps: GRPO never hit zero-gradient collapse. The training signal was real.*

### Training Dynamics

Key observations from the step logs:

- **reward_std non-zero throughout** — GRPO received real gradient signal across all 200 steps; the model was never in a collapsed uniform-reward state
- **Completion length dropped ~65–70%** — from ~430 tokens (monogenic early) to ~125 tokens (phenotype_mismatch final). The model learned to be decisive rather than exploratory
- **KL stable** — peaked at 0.044 during oligogenic, never diverged; the policy stayed close to base while improving
- **Strongest early reward spike at step 20 monogenic (0.712)** — the model rapidly discovered the basic phenotype→gene→variant path before reward variance forced more precise reasoning
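
Why a non-zero reward_std matters can be seen from the group-relative advantage that GRPO computes for each completion. The sketch below uses the usual normalisation (reward minus group mean, divided by group standard deviation). The epsilon value and the choice of population standard deviation are assumptions, not settings read from the training script.

```python
import statistics


def group_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalise rewards within one group of completions for the same prompt."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; an assumption here
    return [(r - mean) / (std + eps) for r in rewards]


# All G=8 completions score the same: every advantage is zero, so this
# prompt contributes no policy gradient.
collapsed = group_advantages([0.5] * 8)

# Mixed rewards: completions above the group mean get positive advantage,
# those below get negative advantage.
mixed = group_advantages([0.2, 0.8, 0.5, 0.5])
```

When every completion in a group earns the same reward, the numerator is zero for all of them and the update carries no learning signal. That is the collapsed state the first bullet says the run avoided.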

The phenotype_mismatch improvement (+11.4%) is the smallest of the three gains but the most significant result. This is the task designed to fool untrained LLMs with a high-pathogenicity cancer gene decoy. An 11% relative lift after 60 steps of RL-only training (no instruction tuning, no clinical examples) suggests the reward signal successfully penalised phenotypically irrelevant flags.

---

## Reproducing the Full Run

Everything is wired up. If you have an L4 or A10G and 4 hours:

```bash
git clone https://github.com/KrishVenky/ClinDetect
cd ClinDetect

# Set env vars
export HF_TOKEN=your_token
export HF_PUSH_REPO=your_org/narada-detective-lora
export ENV_URL=https://krishvenky-narada-env.hf.space

pip install -r training/requirements.txt
python training/train.py
```

The script will:
1. Collect episodes in parallel (~30 sec)
2. Benchmark the base model (zero-shot baseline)
3. Run the three-phase curriculum
4. Benchmark the trained model
5. Generate `training_curve.png` and `before_after.png`
6. Push the LoRA adapter and plots to your HF repo

The environment server is live and will remain so. **You have a week — run it.**

---

## What the Results Mean

The `phenotype_mismatch` result is the one worth paying attention to. It operationalises a failure mode that exists in real clinical practice: over-weighting variant pathogenicity score at the expense of phenotypic fit. A BRCA1 frameshift is objectively "more pathogenic" than an SCN5A missense, but if your patient has long QT syndrome and not breast cancer, BRCA1 is irrelevant.

Training a 1.7B model to resist that signal, via RL reward shaping rather than instruction following, is the core thesis. The baseline shows the untrained model performing at near-chance on this task. After 60 steps of GRPO with a reward function that specifically penalises decoy flags (`cardiac_flag * 1.0 - decoy_flag * 0.5`), the score climbs 11.4% relative to baseline. That is consistent with the model learning causal discipline rather than surface pattern matching.
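
To make the decoy penalty concrete, here is a minimal sketch of the term `cardiac_flag * 1.0 - decoy_flag * 0.5`. Only the two coefficients come from the text. Treating both flags as binary indicators, the function name, and the variant IDs are assumptions for illustration, and the real reward may include other terms.

```python
def mismatch_reward(flagged: set[str], matched_variant: str, decoy_variant: str) -> float:
    """Reward term for phenotype_mismatch: credit the matched variant, penalise the decoy."""
    cardiac_flag = 1.0 if matched_variant in flagged else 0.0  # binary by assumption
    decoy_flag = 1.0 if decoy_variant in flagged else 0.0      # binary by assumption
    return cardiac_flag * 1.0 - decoy_flag * 0.5


# Hypothetical variant IDs, echoing the SCN5A vs. BRCA1 example in the text.
MATCHED = "SCN5A_missense"
DECOY = "BRCA1_frameshift"

print(mismatch_reward({MATCHED}, MATCHED, DECOY))         # only the matched variant flagged
print(mismatch_reward({DECOY}, MATCHED, DECOY))           # only the decoy flagged
print(mismatch_reward({MATCHED, DECOY}, MATCHED, DECOY))  # both flagged
```

Under these assumptions, flagging only the decoy is strictly worse than flagging nothing, and flagging both variants is worth half as much as flagging only the matched one, so hedging across both is discouraged.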

With a 1.7B-parameter base model and 17.4M trainable LoRA parameters, this is a proof of concept. The trajectory across all three tiers is consistent: RL fine-tuning moves a general-purpose LLM toward principled clinical reasoning on a task it had no prior exposure to.

---

## Artifacts

| Artifact | Location |
|---|---|
| Environment server | `src/envs/narada/` |
| OpenEnv spec | `openenv.yaml` |
| Training script | `training/train.py` |
| Training notebook | `training/narada_grpo.ipynb` |
| Inference benchmark | `inference.py` |
| LoRA adapter (monogenic-trained) | `KrishVenky/narada-detective-lora` on HF |
| Live environment | `https://krishvenky-narada-env.hf.space` |

---

*Built in 24 hours. Exams start tomorrow.*