kunjcr2 committed
Commit 9e1cffa · verified · 1 Parent(s): 4e3fbed

Update README.md

Files changed (1)
  1. README.md +113 -99
README.md CHANGED
@@ -5,156 +5,170 @@ datasets:
 language:
 - en
 tags:
- - v2_pretrain_medassist
 - gqa
 - rope
 - swiglu
 - rmsnorm
- - medical
 ---

- # 🧠 MedAssist-GPT-401M
-
- **Mid-sized medical-domain LLM pretraining project.**
- ⚠️ *Strictly for research. Not for clinical or diagnostic use.*

 ---

- ## 🧩 TL;DR
-
- * **Architecture:** Transformer with **RoPE**, **GQA**, **SwiGLU** MLP, and **RMSNorm**
- * **Tokenizer:** `tiktoken` `p50k_base` (vocab ≈ **50,281**)
- * **Context length:** 1,024 tokens
- * **Parameters:** ≈ **401 M** (`d_model=1024`, `n_heads=32`, `blocks=24`, `d_ff=2048`)
- * **GQA groups:** 8 → 4 KV heads per 32 query heads
- * **Dropout:** 0.0 (pretraining)
- * **Precision:** **bf16** mixed precision
- * **Training objective:** Next-token prediction
- * **Effective batch:** 32 × 4 = 128

  ---

- ## 📚 Data
-
- | Field                   | Value                             |
- | ----------------------- | --------------------------------- |
- | **Dataset**             | `japhba/pubmed_simple`            |
- | **Text column**         | `abstract`                        |
- | **Train/Val split**     | 95 / 5                            |
- | **Samples used**        | 100 k abstracts                   |
- | **Seq length / stride** | 1,024 / 1,024                     |
- | **Cleaning**            | `use_clean=False` (raw abstracts) |

  ---

- ## ⚙️ Training
-
- | Item                       | Value                                                                 |
- | -------------------------- | --------------------------------------------------------------------- |
- | **Framework**              | PyTorch                                                               |
- | **Precision**              | bf16                                                                  |
- | **Objective**              | Causal LM (next-token prediction)                                     |
- | **Optimizer**              | AdamW (`β₁ = 0.9`, `β₂ = 0.95`, `eps = 1e-8`)                         |
- | **Learning rate**          | 3 × 10⁻⁴ (linear + 100-step warmup)                                   |
- | **Weight decay**           | 0.1                                                                   |
- | **Batch size**             | 32 (× 4 grad acc → 128 effective)                                     |
- | **Grad clip**              | 1.0                                                                   |
- | **Total steps**            | 100 k                                                                 |
- | **Eval**                   | every 500 steps × 100 iters                                           |
- | **Checkpoint save**        | every 1 k steps                                                       |
- | **Seed**                   | 7979797                                                               |
- | **Gradient checkpointing** | ✅ Enabled                                                             |
- | **WandB**                  | `kunjcr2-dreamable/MedAssist-GPT-Pretraining` (`medassist-401M-test`) |
- | **HF repo**                | `kunjcr2/MedAssist-GPT-401M`                                          |

  ---

- ## 🧮 Training Environment
-
- | Item                | Value                  |
- | ------------------- | ---------------------- |
- | **Hardware**        | 1× NVIDIA A100 (80 GB) |
- | **Precision dtype** | bf16                   |
- | **Runtime**         | ~15 hours              |
- | **Scheduler**       | Linear LR decay        |
- | **Mixed precision** | Native AMP (bf16)      |

  ---

- ## 📈 Loss Curves
-
- *(Placeholder: will update post-training)*
- ![train_loss](https://cdn-uploads.huggingface.co/production/uploads/67c358189919777813863c48/bQGVqgx4GoqXZTcMh8KhM.png)
- ![val_loss](https://cdn-uploads.huggingface.co/production/uploads/67c358189919777813863c48/jhNnS_Wvhj4-fzNoO2dRN.png)

  ---

- ## 🚀 Minimal Inference
-
- ```python
- # pip install torch tiktoken huggingface_hub safetensors
- import torch, tiktoken
- from safetensors.torch import load_file
- from huggingface_hub import hf_hub_download
- from MedAssistGPT import MedAssistGPT, MODEL_CONFIG
-
- REPO_ID = "kunjcr2/MedAssist-GPT-401M"
- weights = hf_hub_download(REPO_ID, "model.safetensors")
- state = load_file(weights, device="cpu")
-
- model = MedAssistGPT(MODEL_CONFIG)
- model.load_state_dict(state, strict=True)
- model.eval()
-
- enc = tiktoken.get_encoding("p50k_base")
- ids = torch.tensor([enc.encode(
-     "A patient was admitted with severe headache. Initial assessment revealed"
- )], dtype=torch.long)
-
- with torch.no_grad():  # greedy-free sampling loop, temperature 0.6
-     for _ in range(100):
-         logits = model(ids)[:, -1, :]
-         next_id = torch.multinomial(torch.softmax(logits / 0.6, dim=-1), 1)
-         ids = torch.cat([ids, next_id], dim=1)
- print(enc.decode(ids[0].tolist()))
- ```

  ---

- ## 💾 Checkpoints
-
- * Main run: `medassist-401M-test`
- * Checkpoint: `/checkpoints/checkpoint_step_44500.pt`

  ---

- ## 🧪 Intended Use
-
- For research and experimentation only, e.g.,
-
- * domain-adapted pretraining,
- * architecture exploration,
- * fine-tuning for medical text understanding.
-
- 🚫 **Not intended for clinical or production medical use.**

  ---

- ## 🔮 Future Work
-
- The next update will include:
-
- * **Supervised fine-tuning (SFT)**
- * **Reinforcement learning (PPO) for alignment**

  ---

- ## 📁 Files
-
- * `checkpoints/`
- * `config.json`, `tokenizer_config.json`
- * Training script / notebook defining `MedAssistGPT`

  ---

 ## 🪪 License

- Apache 2.0

 language:
 - en
 tags:
+ - research
+ - llm-pretraining
+ - transformer
 - gqa
 - rope
 - swiglu
 - rmsnorm
+ - medical-text
 ---

+ # 🧠 MedAssistGPT: Pretraining Checkpoints (303M & 401M)
+
+ **Experimental medical-domain LLM pretraining project.**
+ ⚠️ **Research-only. Not for clinical, diagnostic, or production use.**

 ---

+ ## 📌 Overview
+
+ This repository contains **multiple pretraining checkpoints** of the **MedAssistGPT architecture**, released at **two parameter scales**:
+
+ - **MedAssistGPT-303M**
+ - **MedAssistGPT-401M**
+
+ Both variants:
+ - share the **same architecture design**
+ - use the **same tokenizer**
+ - are trained on the **same dataset**
+ - differ only in **model width / attention configuration** and **training progress**
+
+ The purpose of this repository is to document **architecture choices, data pipelines, and large-scale training behavior**, rather than to present a fully converged or production-ready medical language model.

  ---

+ ## 🧩 Architecture (Shared Design)
+
+ All models are **decoder-only Transformers** implemented from scratch in PyTorch; a minimal sketch of the normalization and feed-forward layers follows the list below.
+
+ ### Core components
+ - **RoPE (Rotary Positional Embeddings)**
+ - **Grouped Query Attention (GQA)**
+ - **SwiGLU feed-forward layers**
+ - **RMSNorm (pre-norm)**
+ - **Weight tying** (token embeddings ↔ LM head)
+ - **Dropout:** 0.0 (pretraining configuration)
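+
+ The exact modules live in the training script; as a rough, hypothetical sketch (class names are illustrative, with `d_model=1024` and `d_ff=2048` taken from the previous revision's config), the two custom layers could look like:
+
+ ```python
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class RMSNorm(nn.Module):
+     """RMS normalization: rescale by the root-mean-square; no mean-centering, no bias."""
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
+         return x * rms * self.weight
+
+ class SwiGLU(nn.Module):
+     """SwiGLU feed-forward: down( silu(gate(x)) * up(x) )."""
+     def __init__(self, d_model: int = 1024, d_ff: int = 2048):
+         super().__init__()
+         self.w_gate = nn.Linear(d_model, d_ff, bias=False)
+         self.w_up = nn.Linear(d_model, d_ff, bias=False)
+         self.w_down = nn.Linear(d_ff, d_model, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
+ ```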
+
+ ### Tokenization
+ - **Tokenizer:** `tiktoken` `p50k_base`
+ - **Vocabulary size:** ≈ 50,281
+ - **Context length:** 1,024 tokens
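+
+ A quick, runnable check of the tokenizer configuration above, using only the public `tiktoken` API:
+
+ ```python
+ import tiktoken
+
+ enc = tiktoken.get_encoding("p50k_base")
+ print(enc.n_vocab)  # 50281
+
+ text = "Aspirin reduced infarct size in the treatment arm."
+ ids = enc.encode(text)
+ print(len(ids), enc.decode(ids) == text)  # encoding round-trips losslessly
+ ```
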
  ---

+ ## 📏 Model Variants
+
+ | Variant  | Parameters | d_model | Query heads | KV heads (GQA) | Blocks |
+ |----------|------------|---------|-------------|----------------|--------|
+ | **303M** | ~303M      | 1024    | 16          | 4              | 24     |
+ | **401M** | ~401M      | 1024    | 32          | 4              | 24     |
+
+ > Both variants use the **same architectural template**; the 401M model doubles the number of query heads while keeping 4 KV heads. A shape-level sketch of the KV-head sharing follows below.
+
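+ To make the table concrete, here is a hedged, shape-level illustration of GQA (a hypothetical helper, not the repo's attention module): each KV head is shared by a group of query heads, so the 401M config runs 32 query heads against 4 KV heads.
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def gqa_attention(q, k, v):
+     """q: (B, n_heads, T, hd); k, v: (B, n_kv_heads, T, hd)."""
+     group = q.size(1) // k.size(1)         # query heads per KV head (32 // 4 = 8 for 401M)
+     k = k.repeat_interleave(group, dim=1)  # broadcast each KV head to its query group
+     v = v.repeat_interleave(group, dim=1)
+     return F.scaled_dot_product_attention(q, k, v, is_causal=True)
+
+ # 401M shapes: 32 query heads, 4 KV heads, head_dim = 1024 // 32 = 32
+ B, T, hd = 2, 8, 32
+ out = gqa_attention(torch.randn(B, 32, T, hd),
+                     torch.randn(B, 4, T, hd),
+                     torch.randn(B, 4, T, hd))
+ print(out.shape)  # torch.Size([2, 32, 8, 32])
+ ```
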
  ---
71
 
72
+ ## ๐Ÿ“š Data
73
 
74
+ | Item | Value |
75
+ |----|----|
76
+ | Dataset | `japhba/pubmed_simple` |
77
+ | Text field | `abstract` |
78
+ | Domain | Biomedical / medical research |
79
+ | Cleaning | Minimal (raw abstracts) |
80
+ | Sequence length | 1,024 |
81
+ | Sliding window stride | 512 |
82
 
83
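+ A minimal sketch of the sliding-window chunking described in the table (a hypothetical helper; `token_ids` stands in for the flat token stream of the concatenated abstracts):
+
+ ```python
+ def make_windows(token_ids, seq_len=1024, stride=512):
+     """Slice a flat token stream into overlapping fixed-length training windows."""
+     return [
+         token_ids[start : start + seq_len]
+         for start in range(0, len(token_ids) - seq_len + 1, stride)
+     ]
+
+ # e.g. 4096 tokens -> 7 windows starting at 0, 512, ..., 3072
+ print(len(make_windows(list(range(4096)))))  # 7
+ ```
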
  ---

+ ## ⚙️ Training Setup (Common)
+
+ | Item                  | Value                                            |
+ |-----------------------|--------------------------------------------------|
+ | Objective             | Causal language modeling (next-token prediction) |
+ | Optimizer             | AdamW                                            |
+ | Betas                 | (0.9, 0.95)                                      |
+ | Precision             | bf16                                             |
+ | Gradient accumulation | Enabled                                          |
+ | Gradient clipping     | 1.0 (max grad norm)                              |
+ | Effective batch size  | 128 sequences                                    |
+
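+ A condensed sketch of one optimizer update under this setup. This is hedged: the tiny `model` and synthetic `batches` below are stand-ins so the snippet runs, and the learning rate / weight decay values are taken from the previous revision of this README, not from this table.
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ # stand-ins so the sketch is runnable; the real run uses MedAssistGPT and PubMed windows
+ model = torch.nn.Linear(16, 50281)
+ batches = [(torch.randn(32, 16), torch.randint(0, 50281, (32,))) for _ in range(8)]
+
+ accum_steps = 4  # 32-sequence micro-batches x 4 -> effective batch 128
+ optimizer = torch.optim.AdamW(
+     model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1
+ )
+
+ device_type = "cuda" if torch.cuda.is_available() else "cpu"
+ optimizer.zero_grad(set_to_none=True)
+ for i, (inputs, targets) in enumerate(batches):
+     with torch.autocast(device_type=device_type, dtype=torch.bfloat16):  # bf16 mixed precision
+         loss = F.cross_entropy(model(inputs), targets)
+     (loss / accum_steps).backward()  # average gradients across the accumulation window
+     if (i + 1) % accum_steps == 0:
+         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping at 1.0
+         optimizer.step()
+         optimizer.zero_grad(set_to_none=True)
+ ```
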
  ---
98
 
99
+ ## ๐Ÿ“ฆ Checkpoints
100
 
101
+ The `checkpoints/` directory contains **multiple snapshots of the same model variants at different training stages**.
102
+
103
+ Examples:
104
+ - `checkpoint_step_25000.pt` (303M) โ†’ ~2.5B tokens seen
105
+ - Additional checkpoints may exist for the 401M variant
106
+
107
+ > โš ๏ธ **Important:**
108
+ > All released checkpoints are **early-stage pretraining snapshots**.
109
+ > At ~2.5B tokens (~8ร— tokens/parameter for 303M), the models are **undertrained** and should **not** be treated as finished base models.
110
+
111
+ They are provided to:
112
+ - study training dynamics,
113
+ - resume or extend pretraining,
114
+ - experiment with fine-tuning,
115
+ - inspect architectural behavior at scale.
 
 
 
 
 
 
 
 
 
 
116
 
117
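+ A hedged sketch of reloading one of these `.pt` snapshots. The key layout inside the file depends on how the training script saved it, and `MedAssistGPT` / `MODEL_CONFIG` come from the training code listed in this repo:
+
+ ```python
+ import torch
+ from MedAssistGPT import MedAssistGPT, MODEL_CONFIG  # defined in the training script
+
+ ckpt = torch.load("checkpoints/checkpoint_step_25000.pt", map_location="cpu")
+
+ model = MedAssistGPT(MODEL_CONFIG)
+ # assuming the snapshot stores the weights under a "model" key next to optimizer state
+ state = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
+ model.load_state_dict(state)
+ model.eval()
+ ```
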
  ---

+ ## 📈 Training Status
+
+ - Training and validation loss were **still improving** at the time of the last checkpoints.
+ - Training runs were **interrupted due to infrastructure preemption** and were not resumed.
+ - No claims are made about benchmark or downstream task performance.

  ---

+ ## 🚀 Loading the Model
+
+ ```python
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ repo_id = "kunjcr2/MedAssistGPT-303M"  # or the 401M repo
+
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
+ model.eval()
+
+ prompt = "A patient was admitted with severe headache. Initial assessment revealed"
+ inputs = tokenizer(prompt, return_tensors="pt")
+
+ outputs = model.generate(
+     **inputs,
+     max_new_tokens=100,
+     do_sample=True,   # temperature only takes effect when sampling
+     temperature=0.7,
+ )
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```

  ---

+ ## 🧪 Intended Use
+
+ This repository is intended for:
+
+ * architecture exploration,
+ * large-scale pretraining experiments,
+ * medical-domain language modeling research,
+ * educational purposes.
+
+ 🚫 **Not intended for clinical or production medical use.**

  ---

+ ## 🔮 Possible Next Steps (Not Included)
+
+ * Continued pretraining with larger token budgets
+ * Supervised fine-tuning (SFT) on medical QA datasets
+ * Evaluation on biomedical NLP benchmarks

  ---

 ## 🪪 License

+ Apache 2.0
+