---
license: apache-2.0
datasets:
- japhba/pubmed_simple
language:
- en
tags:
- research
- llm-pretraining
- transformer
- gqa
- rope
- swiglu
- rmsnorm
- medical-text
---
# 🧠 MedAssistGPT - Pretraining Checkpoints (303M & 401M)
**Experimental medical-domain LLM pretraining project.**
⚠️ **Research-only. Not for clinical, diagnostic, or production use.**
---
## 📌 Overview
This repository contains **multiple pretraining checkpoints** of the **MedAssistGPT architecture**, released in **two parameter scales**:
- **MedAssistGPT-303M**
- **MedAssistGPT-401M**
Both variants:
- share the **same architecture design**
- use the **same tokenizer**
- are trained on the **same dataset**
- differ only in **model width / attention configuration** and **training progress**
The purpose of this repository is to document **architecture choices, data pipelines, and large-scale training behavior**, rather than to present a fully converged or production-ready medical language model.
---
## 🧩 Architecture (Shared Design)
All models are **decoder-only Transformers** implemented from scratch in PyTorch.
### Core components
- **RoPE (Rotary Positional Embeddings)**
- **Grouped Query Attention (GQA)**
- **SwiGLU feed-forward layers**
- **RMSNorm (pre-norm)**
- **Weight tying** (token embeddings ↔ LM head)
- **Dropout:** 0.0 (pretraining configuration)
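For reference, a minimal PyTorch sketch of the RMSNorm and SwiGLU components listed above (illustrative only; module and parameter names are not taken from this repository's implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Root-mean-square normalization: no mean centering, no bias term."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale each token vector by the inverse of its RMS, then apply a learned gain.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)

class SwiGLU(nn.Module):
    """Gated feed-forward block: SiLU(x W_gate) * (x W_up), projected back down."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.w_up = nn.Linear(dim, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
```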
### Tokenization
- **Tokenizer:** `tiktoken` `p50k_base`
- **Vocabulary size:** ≈ 50,281
- **Context length:** 1,024 tokens
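The tokenizer can be reproduced directly from `tiktoken`; a quick sanity check of the vocabulary size (the example sentence is arbitrary):

```python
import tiktoken

enc = tiktoken.get_encoding("p50k_base")
print(enc.n_vocab)  # 50281

ids = enc.encode("Aspirin is widely used for secondary prevention of ischemic stroke.")
print(len(ids), enc.decode(ids))
```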
---
## ๐Ÿ“ Model Variants
| Variant | Parameters | d_model | Heads | GQA (KV heads) | Blocks |
|------|-----------|--------|-------|---------------|--------|
| **303M** | ~303M | 1024 | 16 | 4 | 24 |
| **401M** | ~401M | 1024 | 32 | 4 | 24 |
> Both variants use the **same architectural template**; the 401M model increases attention width while preserving GQA.
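A hypothetical configuration sketch of the table above (field names are illustrative, not the repository's actual config class):

```python
from dataclasses import dataclass

@dataclass
class MedAssistGPTConfig:
    vocab_size: int = 50_281      # tiktoken p50k_base
    context_length: int = 1_024
    d_model: int = 1_024
    n_layers: int = 24
    n_heads: int = 16             # query heads
    n_kv_heads: int = 4           # GQA: query heads share 4 key/value heads
    dropout: float = 0.0
    tie_embeddings: bool = True   # token embeddings <-> LM head

cfg_303m = MedAssistGPTConfig(n_heads=16)   # ~303M parameters
cfg_401m = MedAssistGPTConfig(n_heads=32)   # ~401M parameters
```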
---
## 📚 Data
| Item | Value |
|----|----|
| Dataset | `japhba/pubmed_simple` |
| Text field | `abstract` |
| Domain | Biomedical / medical research |
| Cleaning | Minimal (raw abstracts) |
| Sequence length | 1,024 |
| Sliding window stride | 512 |
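A minimal sketch of this pipeline, assuming the Hugging Face `datasets` library, a `train` split, and that abstracts are joined into one token stream before windowing (the original preprocessing may differ):

```python
import tiktoken
from datasets import load_dataset

SEQ_LEN, STRIDE = 1024, 512

enc = tiktoken.get_encoding("p50k_base")
ds = load_dataset("japhba/pubmed_simple", split="train")  # split name assumed

# Tokenize the `abstract` field into one long stream (end-of-text separator assumed).
stream = []
for text in ds["abstract"]:
    stream.extend(enc.encode(text) + [enc.eot_token])

# Cut overlapping training windows: 1,024 tokens with a 512-token stride.
windows = [stream[i:i + SEQ_LEN] for i in range(0, len(stream) - SEQ_LEN, STRIDE)]
```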
---
## โš™๏ธ Training Setup (Common)
| Item | Value |
|----|----|
| Objective | Causal language modeling (next-token prediction) |
| Optimizer | AdamW |
| Betas | (0.9, 0.95) |
| Precision | bf16 |
| Gradient accumulation | Enabled |
| Gradient clipping | 1.0 |
| Effective batch size | 128 |
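A condensed sketch of how these settings combine in a single training step (the learning rate, micro-batch size, and accumulation count are placeholders, and `model` and `loader` are assumed to be defined; this is not the repository's training script):

```python
import torch
import torch.nn.functional as F

ACCUM_STEPS = 16  # placeholder: micro-batch 8 x 16 accumulation steps = effective batch 128
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95))

for step, batch in enumerate(loader):
    input_ids = batch["input_ids"].cuda()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):  # bf16 mixed precision
        logits = model(input_ids)
        # Next-token prediction: shift logits left, targets right.
        loss = F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)),
            input_ids[:, 1:].reshape(-1),
        )
    (loss / ACCUM_STEPS).backward()  # gradient accumulation

    if (step + 1) % ACCUM_STEPS == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```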
---
## 📦 Checkpoints
The `checkpoints/` directory contains **multiple snapshots of the same model variants at different training stages**.
Examples:
- `checkpoint_step_25000.pt` (303M) → ~2.5B tokens seen
- Additional checkpoints may exist for the 401M variant
> โš ๏ธ **Important:**
> All released checkpoints are **early-stage pretraining snapshots**.
> At ~2.5B tokens (~8ร— tokens/parameter for 303M), the models are **undertrained** and should **not** be treated as finished base models.
They are provided to:
- study training dynamics,
- resume or extend pretraining (see the sketch after this list),
- experiment with fine-tuning,
- inspect architectural behavior at scale.
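Since the checkpoints are raw PyTorch `.pt` snapshots rather than `transformers`-native weight shards, resuming looks roughly like the following (the dictionary layout and keys are assumptions; inspect `ckpt.keys()` for the actual contents):

```python
import torch

ckpt = torch.load("checkpoints/checkpoint_step_25000.pt", map_location="cpu")
print(ckpt.keys())  # inspect what the snapshot actually stores

model.load_state_dict(ckpt["model"])          # assumed key; `model` is your MedAssistGPT module
optimizer.load_state_dict(ckpt["optimizer"])  # assumed key; only needed when resuming training
start_step = ckpt.get("step", 0)
```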
---
## 📈 Training Status
- Training and validation loss were **still improving** at the time of the last checkpoints.
- Training runs were **interrupted due to infrastructure preemption** and were not resumed.
- No claims are made about benchmark or downstream task performance.
---
## 🚀 Loading the Model
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

repo_id = "kunjcr2/MedAssistGPT-303M"  # or the 401M repo

# trust_remote_code=True loads the custom model/tokenizer code shipped with the repo
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

prompt = "A patient was admitted with severe headache. Initial assessment revealed"
inputs = tokenizer(prompt, return_tensors="pt")

outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=True,    # temperature only takes effect when sampling is enabled
    temperature=0.7,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
---
## 🧪 Intended Use
This repository is intended for:
* architecture exploration,
* large-scale pretraining experiments,
* medical-domain language modeling research,
* educational purposes.
🚫 **Not intended for clinical or production medical use.**
---
## 🔮 Possible Next Steps (Not Included)
* Continued pretraining with larger token budgets
* Supervised fine-tuning (SFT) on medical QA datasets
* Evaluation on biomedical NLP benchmarks
---
## 🪪 License
Apache 2.0