Update README.md
README.md (CHANGED)
Removed in this update (excerpt):

```diff
@@ -5,156 +5,170 @@ datasets:
 tags:
-- medical

-| **Train/Val split** | 95 / 5 |
-| **Samples used** | 100k abstracts |
-| **Seq length / stride** | 1,024 / 1,024 |
-| **Cleaning** | `use_clean=False` (raw abstracts) |

-| **Learning rate** | 3 × 10⁻⁴ (linear + 100-step warmup) |
-| **Weight decay** | 0.1 |
-| **Batch size** | 32 (× 4 grad acc → 128 effective) |
-| **Grad clip** | 1.0 |
-| **Total steps** | 100k |
-| **Eval** | every 500 steps × 100 iters |
-| **Checkpoint save** | every 1k steps |
-| **Seed** | 7,979,797 |
-| **Gradient checkpointing** | ✅ Enabled |
-| **WandB** | `kunjcr2-dreamable/MedAssist-GPT-Pretraining` (`medassist-401M-test`) |
-| **HF repo** | `kunjcr2/MedAssist-GPT-401M` |

-ids = torch.tensor([enc.encode(
-    "A patient was admitted with severe headache. Initial assessment revealed"
-)], dtype=torch.long)
-for _ in range(100):
-    logits = model(ids)[:, -1, :]
-    next_id = torch.multinomial(torch.softmax(logits / 0.6, dim=-1), 1)
-    ids = torch.cat([ids, next_id], dim=1)
-print(enc.decode(ids[0].tolist()))

-* fine-tuning for medical text understanding.
-* **Reinforcement Learning (PPO) for alignment**
```

The updated README follows.
language:
- en
tags:
- research
- llm-pretraining
- transformer
- gqa
- rope
- swiglu
- rmsnorm
- medical-text
---

# 🧠 MedAssistGPT – Pretraining Checkpoints (303M & 401M)

**Experimental medical-domain LLM pretraining project.**
⚠️ **Research-only. Not for clinical, diagnostic, or production use.**

---

## Overview

This repository contains **multiple pretraining checkpoints** of the **MedAssistGPT architecture**, released in **two parameter scales**:

- **MedAssistGPT-303M**
- **MedAssistGPT-401M**

Both variants:
- share the **same architecture design**
- use the **same tokenizer**
- are trained on the **same dataset**
- differ only in **model width / attention configuration** and **training progress**

The purpose of this repository is to document **architecture choices, data pipelines, and large-scale training behavior**, rather than to present a fully converged or production-ready medical language model.

---

## 🧩 Architecture (Shared Design)

All models are **decoder-only Transformers** implemented from scratch in PyTorch.

### Core components
- **RoPE (Rotary Positional Embeddings)**
- **Grouped Query Attention (GQA)**
- **SwiGLU feed-forward layers**
- **RMSNorm (pre-norm)**
- **Weight tying** (token embeddings ↔ LM head)
- **Dropout:** 0.0 (pretraining configuration)
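
For orientation, here is a minimal PyTorch sketch of the pre-norm RMSNorm + SwiGLU pattern listed above. Class names and the feed-forward hidden width are illustrative placeholders, not the repository's actual modules:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """RMS layer norm: scale by 1/rms(x); no mean subtraction, no bias."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class SwiGLU(nn.Module):
    """SwiGLU feed-forward: down(silu(gate(x)) * up(x)), all bias-free."""
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

x = torch.randn(2, 8, 1024)
y = SwiGLU(1024, 4 * 1024)(RMSNorm(1024)(x))  # pre-norm, then feed-forward
print(y.shape)  # torch.Size([2, 8, 1024])
```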

### Tokenization
- **Tokenizer:** `tiktoken` `p50k_base`
- **Vocabulary size:** ≈ 50,281
- **Context length:** 1,024 tokens
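
Since this is the stock `tiktoken` encoding, the vocabulary size above can be checked directly:

```python
import tiktoken

enc = tiktoken.get_encoding("p50k_base")
print(enc.n_vocab)  # 50281

ids = enc.encode("A patient was admitted with severe headache.")
print(len(ids), enc.decode(ids))
```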

---

## Model Variants

| Variant | Parameters | d_model | Heads | GQA (KV heads) | Blocks |
|---------|------------|---------|-------|----------------|--------|
| **303M** | ~303M | 1024 | 16 | 4 | 24 |
| **401M** | ~401M | 1024 | 32 | 4 | 24 |

> Both variants use the **same architectural template**; the 401M model increases attention width while preserving GQA.
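
As a shape-level illustration of the GQA column (using the 303M numbers; this is not the repository's attention implementation):

```python
import torch
import torch.nn.functional as F

B, T = 2, 16
n_heads, n_kv_heads, head_dim = 16, 4, 1024 // 16  # 303M row: 16 query heads, 4 KV heads

q = torch.randn(B, n_heads, T, head_dim)
k = torch.randn(B, n_kv_heads, T, head_dim)
v = torch.randn(B, n_kv_heads, T, head_dim)

# Each group of n_heads // n_kv_heads query heads shares one KV head.
k = k.repeat_interleave(n_heads // n_kv_heads, dim=1)  # -> (B, 16, T, head_dim)
v = v.repeat_interleave(n_heads // n_kv_heads, dim=1)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 16, 16, 64])
```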

---

## Data

| Item | Value |
|------|-------|
| Dataset | `japhba/pubmed_simple` |
| Text field | `abstract` |
| Domain | Biomedical / medical research |
| Cleaning | Minimal (raw abstracts) |
| Sequence length | 1,024 |
| Sliding window stride | 512 |
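
A minimal sketch of the windowing this table describes, 1,024-token windows with a 512-token stride over raw abstracts (the `train` split name and the demo slice are assumptions):

```python
import tiktoken
from datasets import load_dataset

enc = tiktoken.get_encoding("p50k_base")
ds = load_dataset("japhba/pubmed_simple", split="train")  # split name assumed

seq_len, stride = 1024, 512
chunks = []
for record in ds.select(range(100)):  # small demo slice
    ids = enc.encode(record["abstract"])
    # One window per stride; abstracts shorter than seq_len yield a single short chunk.
    for start in range(0, max(len(ids) - seq_len, 0) + 1, stride):
        chunks.append(ids[start:start + seq_len])

print(f"{len(chunks)} chunks from 100 abstracts")
```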

---

## ⚙️ Training Setup (Common)

| Item | Value |
|------|-------|
| Objective | Causal language modeling (next-token prediction) |
| Optimizer | AdamW |
| Betas | (0.9, 0.95) |
| Precision | bf16 |
| Gradient accumulation | Enabled |
| Gradient clipping | 1.0 |
| Effective batch size | 128 |
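
The table maps onto a standard PyTorch loop along these lines. This is a hedged sketch: the toy model, dataloader, and accumulation count are stand-ins, since the per-device batch size is not restated here:

```python
import torch
import torch.nn.functional as F

# Toy stand-ins so the sketch runs; the real model and dataloader live in the training code.
vocab = 50281
model = torch.nn.Sequential(torch.nn.Embedding(vocab, 64), torch.nn.Linear(64, vocab))
loader = [(torch.randint(vocab, (2, 128)), torch.randint(vocab, (2, 128))) for _ in range(8)]

optimizer = torch.optim.AdamW(model.parameters(), betas=(0.9, 0.95))
accum_steps = 4  # assumed: effective batch 128 = per-device batch x accum_steps
device_type = "cuda" if torch.cuda.is_available() else "cpu"

optimizer.zero_grad()
for step, (x, y) in enumerate(loader):
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):  # bf16 as in the table
        logits = model(x)  # (batch, seq, vocab)
        loss = F.cross_entropy(logits.flatten(0, 1), y.flatten())
    (loss / accum_steps).backward()
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping at 1.0
        optimizer.step()
        optimizer.zero_grad()
```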

---

## 📦 Checkpoints

The `checkpoints/` directory contains **multiple snapshots of the same model variants at different training stages**.

Examples:
- `checkpoint_step_25000.pt` (303M) – ~2.5B tokens seen
- Additional checkpoints may exist for the 401M variant

> ⚠️ **Important:**
> All released checkpoints are **early-stage pretraining snapshots**.
> At ~2.5B tokens (~8× tokens/parameter for 303M), the models are **undertrained** and should **not** be treated as finished base models.

They are provided to:
- study training dynamics,
- resume or extend pretraining,
- experiment with fine-tuning,
- inspect architectural behavior at scale.
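
Since these are raw training snapshots rather than packaged models, a `.pt` file is typically opened with plain `torch.load`. The snapshot layout is not documented here, so inspect the object before assuming any key names:

```python
import torch

# Path from the example above; contents may be a bare state_dict or a nested dict.
ckpt = torch.load("checkpoints/checkpoint_step_25000.pt", map_location="cpu")

if isinstance(ckpt, dict):
    print(list(ckpt.keys()))  # look for e.g. "model" / "optimizer" (key names are guesses)
```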

---

## Training Status

- Training and validation loss were **still improving** at the time of the last checkpoints.
- Training runs were **interrupted due to infrastructure preemption** and were not resumed.
- No claims are made about benchmark or downstream task performance.

---

## Loading the Model

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

repo_id = "kunjcr2/MedAssistGPT-303M"  # or the 401M repo

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

prompt = "A patient was admitted with severe headache. Initial assessment revealed"
inputs = tokenizer(prompt, return_tensors="pt")

outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=True,   # required for temperature to take effect
    temperature=0.7,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

---

## 🧪 Intended Use

This repository is intended for:

* architecture exploration,
* large-scale pretraining experiments,
* medical-domain language modeling research,
* educational purposes.

🚫 **Not intended for clinical or production medical use.**

---

## 🔮 Possible Next Steps (Not Included)

* Continued pretraining with larger token budgets
* Supervised fine-tuning (SFT) on medical QA datasets
* Evaluation on biomedical NLP benchmarks

---

## 🪪 License

Apache 2.0