---
license: apache-2.0
datasets:
- japhba/pubmed_simple
language:
- en
tags:
- research
- llm-pretraining
- transformer
- gqa
- rope
- swiglu
- rmsnorm
- medical-text
---
# MedAssistGPT: Pretraining Checkpoints (303M & 401M)

Experimental medical-domain LLM pretraining project.

⚠️ Research only. Not for clinical, diagnostic, or production use.
## Overview
This repository contains multiple pretraining checkpoints of the MedAssistGPT architecture, released in two parameter scales:
- MedAssistGPT-303M
- MedAssistGPT-401M
Both variants:
- share the same architecture design
- use the same tokenizer
- are trained on the same dataset
- differ only in attention configuration (number of query heads) and training progress
The purpose of this repository is to document architecture choices, data pipelines, and large-scale training behavior, rather than to present a fully converged or production-ready medical language model.
## Architecture (Shared Design)
All models are decoder-only Transformers implemented from scratch in PyTorch.
### Core components
- RoPE (Rotary Positional Embeddings)
- Grouped Query Attention (GQA)
- SwiGLU feed-forward layers
- RMSNorm (pre-norm)
- Weight tying (token embeddings shared with the LM head)
- Dropout: 0.0 (pretraining configuration)
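
For orientation, here is a minimal sketch of two of these components, RMSNorm and SwiGLU, in plain PyTorch. Module and parameter names are illustrative and may not match the repository's actual implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Pre-norm RMSNorm: rescale activations by their root mean square."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * rms)

class SwiGLU(nn.Module):
    """SwiGLU feed-forward: SiLU-gated up-projection, then projection back to d_model."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.w_up = nn.Linear(dim, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
```
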
### Tokenization

- Tokenizer: tiktoken `p50k_base`
- Vocabulary size: ≈ 50,281
- Context length: 1,024 tokens
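
The same encoding can be loaded directly with the `tiktoken` package to confirm the vocabulary size (a small sketch; assumes `tiktoken` is installed):

```python
import tiktoken

# p50k_base is the encoding named above; n_vocab reports 50,281 entries.
enc = tiktoken.get_encoding("p50k_base")
print(enc.n_vocab)  # 50281

ids = enc.encode("Aspirin reduces the risk of myocardial infarction.")
print(len(ids), ids[:5])
```
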
## Model Variants
| Variant | Parameters | d_model | Heads | GQA (KV heads) | Blocks |
|---|---|---|---|---|---|
| 303M | ~303M | 1024 | 16 | 4 | 24 |
| 401M | ~401M | 1024 | 32 | 4 | 24 |
Both variants use the same architectural template; the 401M model widens attention (32 query heads vs. 16) while keeping the same 4-KV-head GQA layout.
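
In the GQA layout, each of the 4 KV heads is shared by a group of query heads (4 per group for the 303M variant, 8 for the 401M). The sketch below expands the shared K/V heads to match the query heads before attention; the tiny sizes and the head dimension of 64 are illustrative assumptions, not values taken from the repository's code:

```python
import torch
import torch.nn.functional as F

batch, seq, head_dim = 2, 16, 64       # tiny sizes; head_dim of 64 is an assumption
n_q_heads, n_kv_heads = 16, 4          # 303M layout; the 401M variant uses 32 query heads
group = n_q_heads // n_kv_heads        # query heads served by each KV head

q = torch.randn(batch, n_q_heads, seq, head_dim)
k = torch.randn(batch, n_kv_heads, seq, head_dim)
v = torch.randn(batch, n_kv_heads, seq, head_dim)

# Expand each shared KV head across its group of query heads so shapes match.
k = k.repeat_interleave(group, dim=1)  # -> (batch, n_q_heads, seq, head_dim)
v = v.repeat_interleave(group, dim=1)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 16, 16, 64])
```
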
## Data
| Item | Value |
|---|---|
| Dataset | japhba/pubmed_simple |
| Text field | abstract |
| Domain | Biomedical / medical research |
| Cleaning | Minimal (raw abstracts) |
| Sequence length | 1,024 |
| Sliding window stride | 512 |
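
A rough sketch of the chunking implied by this table: abstracts are tokenized and split into 1,024-token windows with a 512-token stride. The dataset name and text field come from the table; the split name and helper function are illustrative assumptions:

```python
import tiktoken
from datasets import load_dataset

enc = tiktoken.get_encoding("p50k_base")
SEQ_LEN, STRIDE = 1024, 512

# Split name is an assumption; dataset and text field are as listed above.
ds = load_dataset("japhba/pubmed_simple", split="train")

def sliding_windows(token_ids, seq_len=SEQ_LEN, stride=STRIDE):
    """Yield overlapping windows of at most seq_len tokens, advancing by stride."""
    for start in range(0, max(len(token_ids) - seq_len, 0) + 1, stride):
        yield token_ids[start:start + seq_len]

chunks = list(sliding_windows(enc.encode(ds[0]["abstract"])))
print(len(chunks), len(chunks[0]))
```
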
## Training Setup (Common)
| Item | Value |
|---|---|
| Objective | Causal language modeling (next-token prediction) |
| Optimizer | AdamW |
| Betas | (0.9, 0.95) |
| Precision | bf16 |
| Gradient accumulation | Enabled |
| Gradient clipping | 1.0 |
| Effective batch size | 128 |
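
Taken together, these settings correspond to a fairly standard bf16 training loop with gradient accumulation and clipping. The sketch below uses a tiny stand-in model and random token batches just to make the step structure concrete; the learning rate and micro-batch split are illustrative, not the project's actual values:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Tiny stand-in LM and random batches so the sketch runs end to end;
# the real run uses MedAssistGPT and the PubMed windows described above.
VOCAB, D_MODEL, SEQ_LEN = 50_281, 64, 32
model = nn.Sequential(nn.Embedding(VOCAB, D_MODEL), nn.Linear(D_MODEL, VOCAB))

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95))  # lr is illustrative
ACCUM_STEPS = 8  # e.g. 8 micro-batches of 16 sequences -> effective batch size 128

optimizer.zero_grad(set_to_none=True)
for micro_step in range(ACCUM_STEPS):
    tokens = torch.randint(0, VOCAB, (16, SEQ_LEN))      # placeholder batch
    inputs, targets = tokens[:, :-1], tokens[:, 1:]      # causal LM: predict the next token
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):  # bf16 (use "cuda" on a GPU)
        logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, VOCAB), targets.reshape(-1)) / ACCUM_STEPS
    loss.backward()                                      # gradients accumulate across micro-batches

torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping at 1.0
optimizer.step()                                         # one optimizer step per effective batch
optimizer.zero_grad(set_to_none=True)
```
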
## Checkpoints
The `checkpoints/` directory contains multiple snapshots of the model variants at different training stages.

Examples:

- `checkpoint_step_25000.pt` (303M): ~2.5B tokens seen
- Additional checkpoints may exist for the 401M variant
⚠️ Important:
All released checkpoints are early-stage pretraining snapshots.
At ~2.5B tokens (about 8 tokens per parameter for the 303M model), the models are undertrained and should not be treated as finished base models.
They are provided to:
- study training dynamics,
- resume or extend pretraining,
- experiment with fine-tuning,
- inspect architectural behavior at scale.
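
For the resume/extend-pretraining use case, a snapshot can be inspected with plain `torch.load`. The key names below are assumptions about what the `.pt` files contain; adjust them to the actual checkpoint layout:

```python
import torch

# File name comes from the example above; the key layout is an assumption.
ckpt = torch.load("checkpoints/checkpoint_step_25000.pt", map_location="cpu")

# Snapshots may be a bare state_dict or a dict bundling optimizer/step metadata.
state_dict = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt

print(len(state_dict), "entries; first key:", next(iter(state_dict)))
# model.load_state_dict(state_dict)  # with a MedAssistGPT module built from the matching config
```
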
## Training Status
- Training and validation loss were still improving at the time of the last checkpoints.
- Training runs were interrupted due to infrastructure preemption and were not resumed.
- No claims are made about benchmark or downstream task performance.
## Loading the Model
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

repo_id = "kunjcr2/MedAssistGPT-303M"  # or the 401M repo

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

prompt = "A patient was admitted with severe headache. Initial assessment revealed"
inputs = tokenizer(prompt, return_tensors="pt")

outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=True,   # enable sampling so temperature takes effect
    temperature=0.7,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
## Intended Use
This repository is intended for:
- architecture exploration,
- large-scale pretraining experiments,
- medical-domain language modeling research,
- educational purposes.
**Not intended for clinical or production medical use.**
## Possible Next Steps (Not Included)
- Continued pretraining with larger token budgets
- Supervised fine-tuning (SFT) on medical QA datasets
- Evaluation on biomedical NLP benchmarks
## License
Apache 2.0