|
|
--- |
|
|
license: apache-2.0 |
|
|
datasets: |
|
|
- japhba/pubmed_simple |
|
|
language: |
|
|
- en |
|
|
tags: |
|
|
- research |
|
|
- llm-pretraining |
|
|
- transformer |
|
|
- gqa |
|
|
- rope |
|
|
- swiglu |
|
|
- rmsnorm |
|
|
- medical-text |
|
|
--- |
|
|
|
|
|
# MedAssistGPT – Pretraining Checkpoints (303M & 401M)
|
|
|
|
|
**Experimental medical-domain LLM pretraining project.** |
|
|
**Research-only. Not for clinical, diagnostic, or production use.**
|
|
|
|
|
--- |
|
|
|
|
|
## Overview
|
|
|
|
|
This repository contains **multiple pretraining checkpoints** of the **MedAssistGPT architecture**, released in **two parameter scales**: |
|
|
|
|
|
- **MedAssistGPT-303M** |
|
|
- **MedAssistGPT-401M** |
|
|
|
|
|
Both variants: |
|
|
- share the **same architecture design** |
|
|
- use the **same tokenizer** |
|
|
- are trained on the **same dataset** |
|
|
- differ only in **model width / attention configuration** and **training progress** |
|
|
|
|
|
The purpose of this repository is to document **architecture choices, data pipelines, and large-scale training behavior**, rather than to present a fully converged or production-ready medical language model. |
|
|
|
|
|
--- |
|
|
|
|
|
## Architecture (Shared Design)
|
|
|
|
|
All models are **decoder-only Transformers** implemented from scratch in PyTorch. |
|
|
|
|
|
### Core components |
|
|
- **RoPE (Rotary Positional Embeddings)** |
|
|
- **Grouped Query Attention (GQA)** |
|
|
- **SwiGLU feed-forward layers** |
|
|
- **RMSNorm (pre-norm)** |
|
|
- **Weight tying** (token embeddings shared with the LM head)
|
|
- **Dropout:** 0.0 (pretraining configuration) |
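To make these components concrete, here is a minimal PyTorch sketch of the RMSNorm and SwiGLU pieces as they are commonly defined; the module names and the hidden width `d_ff` are illustrative assumptions, not taken from this repository's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Pre-norm RMSNorm: rescale by the root-mean-square of the activations."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

class SwiGLU(nn.Module):
    """SwiGLU feed-forward: SiLU-gated up-projection followed by a down-projection."""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
```

Weight tying then amounts to assigning the token-embedding matrix to the LM head (`lm_head.weight = tok_emb.weight`), so no separate output projection is learned.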
|
|
|
|
|
### Tokenization |
|
|
- **Tokenizer:** `tiktoken` `p50k_base` |
|
|
- **Vocabulary size:** ≈ 50,281
|
|
- **Context length:** 1,024 tokens |
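Since the tokenizer is a stock `tiktoken` encoding, it can be reproduced directly; the snippet below is a small sketch (any special-token handling used during training is not documented here).

```python
import tiktoken

# p50k_base is the stock encoding named above; no custom merges are assumed.
enc = tiktoken.get_encoding("p50k_base")
print(enc.n_vocab)  # 50281

ids = enc.encode("A patient was admitted with severe headache.")
print(len(ids), enc.decode(ids))
```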
|
|
|
|
|
--- |
|
|
|
|
|
## Model Variants
|
|
|
|
|
| Variant | Parameters | d_model | Heads | GQA (KV heads) | Blocks | |
|
|
|------|-----------|--------|-------|---------------|--------| |
|
|
| **303M** | ~303M | 1024 | 16 | 4 | 24 | |
|
|
| **401M** | ~401M | 1024 | 32 | 4 | 24 | |
|
|
|
|
|
> Both variants use the **same architectural template**; the 401M model increases attention width while preserving GQA. |
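Concretely, the GQA column means each key/value head is shared by a group of query heads (16 → 4 gives groups of 4 for the 303M; 32 → 4 gives groups of 8 for the 401M). A minimal sketch of that grouping, with shapes and names as illustrative assumptions rather than this codebase's API:

```python
import torch
import torch.nn.functional as F

def gqa_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Grouped Query Attention: each KV head serves a group of query heads.

    q: (batch, n_heads, seq, head_dim); k, v: (batch, n_kv_heads, seq, head_dim)
    """
    group = q.shape[1] // k.shape[1]       # e.g. 16 // 4 = 4 query heads per KV head
    k = k.repeat_interleave(group, dim=1)  # replicate KV heads across their groups
    v = v.repeat_interleave(group, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

q = torch.randn(1, 16, 8, 64)  # 16 query heads, as in the 303M config
k = torch.randn(1, 4, 8, 64)   # 4 KV heads
v = torch.randn(1, 4, 8, 64)
print(gqa_attention(q, k, v).shape)  # torch.Size([1, 16, 8, 64])
```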
|
|
|
|
|
--- |
|
|
|
|
|
## Data
|
|
|
|
|
| Item | Value | |
|
|
|----|----| |
|
|
| Dataset | `japhba/pubmed_simple` | |
|
|
| Text field | `abstract` | |
|
|
| Domain | Biomedical / medical research | |
|
|
| Cleaning | Minimal (raw abstracts) | |
|
|
| Sequence length | 1,024 | |
|
|
| Sliding window stride | 512 | |
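With a 1,024-token window and a 512-token stride, consecutive training sequences overlap by half a window. A minimal sketch of that chunking; the function name and the handling of short documents are assumptions:

```python
def sliding_windows(token_ids: list[int], seq_len: int = 1024, stride: int = 512):
    """Yield overlapping fixed-length training sequences from one tokenized abstract."""
    if len(token_ids) < seq_len:
        return  # short abstracts would need padding or packing, not shown here
    for start in range(0, len(token_ids) - seq_len + 1, stride):
        yield token_ids[start:start + seq_len]
```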
|
|
|
|
|
--- |
|
|
|
|
|
## Training Setup (Common)
|
|
|
|
|
| Item | Value | |
|
|
|----|----| |
|
|
| Objective | Causal language modeling (next-token prediction) | |
|
|
| Optimizer | AdamW | |
|
|
| Betas | (0.9, 0.95) | |
|
|
| Precision | bf16 | |
|
|
| Gradient accumulation | Enabled | |
|
|
| Gradient clipping | 1.0 | |
|
|
| Effective batch size | 128 | |
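A hedged sketch of how these settings typically fit together in one update; the micro-batch size, learning rate, and loss computation are assumptions, and only the effective batch of 128 comes from the table above:

```python
import torch
import torch.nn.functional as F

def train(model, loader, lr=3e-4, micro_batch=8, effective_batch=128):
    """Causal-LM loop sketch: bf16 autocast, gradient accumulation, clipping at 1.0."""
    accum_steps = effective_batch // micro_batch  # 128 / 8 = 16 micro-steps per update
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95))
    for step, (input_ids, targets) in enumerate(loader):
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(input_ids)  # (micro_batch, seq, vocab)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        (loss / accum_steps).backward()  # accumulate scaled gradients
        if (step + 1) % accum_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
```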
|
|
|
|
|
--- |
|
|
|
|
|
## Checkpoints
|
|
|
|
|
The `checkpoints/` directory contains **multiple snapshots of the same model variants at different training stages**. |
|
|
|
|
|
Examples: |
|
|
- `checkpoint_step_25000.pt` (303M): ~2.5B tokens seen
|
|
- Additional checkpoints may exist for the 401M variant |
|
|
|
|
|
> **Important:**
|
|
> All released checkpoints are **early-stage pretraining snapshots**. |
|
|
> At ~2.5B tokens (~8× tokens per parameter for the 303M, well below the ~20 tokens per parameter often cited as compute-optimal), the models are **undertrained** and should **not** be treated as finished base models.
|
|
|
|
|
They are provided to: |
|
|
- study training dynamics, |
|
|
- resume or extend pretraining, |
|
|
- experiment with fine-tuning, |
|
|
- inspect architectural behavior at scale. |
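To resume from or inspect a raw snapshot directly, something like the following should work; the checkpoint's key names are assumptions, so inspect the file before relying on them:

```python
import torch

ckpt = torch.load("checkpoints/checkpoint_step_25000.pt", map_location="cpu")
print(ckpt.keys())  # check the actual layout first

# Hypothetical keys; `model` must be built to match the 303M configuration.
model.load_state_dict(ckpt["model"])
start_step = ckpt.get("step", 0)
```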
|
|
|
|
|
--- |
|
|
|
|
|
## Training Status
|
|
|
|
|
- Training and validation loss were **still improving** at the time of the last checkpoints. |
|
|
- Training runs were **interrupted due to infrastructure preemption** and were not resumed. |
|
|
- No claims are made about benchmark or downstream task performance. |
|
|
|
|
|
--- |
|
|
|
|
|
## Loading the Model
|
|
|
|
|
```python |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
repo_id = "kunjcr2/MedAssistGPT-303M" # or 401M repo |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True) |
|
|
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True) |
|
|
|
|
|
prompt = "A patient was admitted with severe headache. Initial assessment revealed" |
|
|
inputs = tokenizer(prompt, return_tensors="pt") |
|
|
|
|
|
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=True,  # sampling must be enabled for temperature to take effect
    temperature=0.7,
)
|
|
print(tokenizer.decode(outputs[0], skip_special_tokens=True)) |
|
|
```
|
|
|
|
|
--- |
|
|
|
|
|
## Intended Use
|
|
|
|
|
This repository is intended for: |
|
|
|
|
|
* architecture exploration, |
|
|
* large-scale pretraining experiments, |
|
|
* medical-domain language modeling research, |
|
|
* educational purposes. |
|
|
|
|
|
**Not intended for clinical or production medical use.**
|
|
|
|
|
--- |
|
|
|
|
|
## Possible Next Steps (Not Included)
|
|
|
|
|
* Continued pretraining with larger token budgets |
|
|
* Supervised fine-tuning (SFT) on medical QA datasets |
|
|
* Evaluation on biomedical NLP benchmarks |
|
|
|
|
|
--- |
|
|
|
|
|
## License
|
|
|
|
|
Apache 2.0 |
|
|
|
|
|
|