---
license: apache-2.0
datasets:
  - japhba/pubmed_simple
language:
  - en
tags:
  - research
  - llm-pretraining
  - transformer
  - gqa
  - rope
  - swiglu
  - rmsnorm
  - medical-text
---

# 🧠 MedAssistGPT – Pretraining Checkpoints (303M & 401M)

**Experimental medical-domain LLM pretraining project.**

⚠️ **Research-only. Not for clinical, diagnostic, or production use.**

---

## 📌 Overview

This repository contains **multiple pretraining checkpoints** of the **MedAssistGPT architecture**, released at **two parameter scales**:

- **MedAssistGPT-303M**
- **MedAssistGPT-401M**

Both variants:

- share the **same architecture design**
- use the **same tokenizer**
- are trained on the **same dataset**
- differ only in **attention width / configuration** and **training progress**

The purpose of this repository is to document **architecture choices, data pipelines, and large-scale training behavior**, rather than to present a fully converged or production-ready medical language model.

---

## 🧩 Architecture (Shared Design)

All models are **decoder-only Transformers** implemented from scratch in PyTorch.

### Core components

- **RoPE (Rotary Positional Embeddings)**
- **Grouped Query Attention (GQA)**
- **SwiGLU feed-forward layers**
- **RMSNorm (pre-norm)**
- **Weight tying** (token embeddings ↔ LM head)
- **Dropout:** 0.0 (pretraining configuration)

### Tokenization

- **Tokenizer:** `tiktoken` `p50k_base`
- **Vocabulary size:** ≈ 50,281
- **Context length:** 1,024 tokens

---

## 📐 Model Variants

| Variant | Parameters | d_model | Heads | GQA (KV heads) | Blocks |
|---------|------------|---------|-------|----------------|--------|
| **303M** | ~303M | 1024 | 16 | 4 | 24 |
| **401M** | ~401M | 1024 | 32 | 4 | 24 |

> Both variants use the **same architectural template**; the 401M model widens the attention projections (more query heads) while keeping 4 KV heads (GQA).

---

## 📚 Data

| Item | Value |
|------|-------|
| Dataset | `japhba/pubmed_simple` |
| Text field | `abstract` |
| Domain | Biomedical / medical research |
| Cleaning | Minimal (raw abstracts) |
| Sequence length | 1,024 tokens |
| Sliding window stride | 512 tokens |

---

## ⚙️ Training Setup (Common)

| Item | Value |
|------|-------|
| Objective | Causal language modeling (next-token prediction) |
| Optimizer | AdamW |
| Betas | (0.9, 0.95) |
| Precision | bf16 |
| Gradient accumulation | Enabled |
| Gradient clipping | 1.0 |
| Effective batch size | 128 |

---

## 📦 Checkpoints

The `checkpoints/` directory contains **multiple snapshots of the same model variants at different training stages**.

Examples:

- `checkpoint_step_25000.pt` (303M) → ~2.5B tokens seen
- Additional checkpoints may exist for the 401M variant

> ⚠️ **Important:**
> All released checkpoints are **early-stage pretraining snapshots**.
> At ~2.5B tokens (roughly 8 tokens per parameter for the 303M model), the models are **undertrained** and should **not** be treated as finished base models.

They are provided to:

- study training dynamics,
- resume or extend pretraining,
- experiment with fine-tuning,
- inspect architectural behavior at scale.

---

## 📈 Training Status

- Training and validation loss were **still improving** at the time of the last checkpoints.
- Training runs were **interrupted due to infrastructure preemption** and were not resumed.
- No claims are made about benchmark or downstream task performance.
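
For anyone who wants to resume or extend pretraining from these checkpoints, the sketch below illustrates the data chunking and optimizer settings documented above. It is a minimal illustration, **not** the repository's training script: the helper names and the learning rate are assumptions, and only the tokenizer, sequence length, stride, and optimizer hyperparameters are taken from this card.

```python
# Minimal sketch of the data pipeline and optimization settings described above.
# Helper names and the learning rate are illustrative assumptions; only the
# tokenizer, sequence length, stride, and optimizer hyperparameters come from this card.
import tiktoken
import torch
import torch.nn.functional as F

enc = tiktoken.get_encoding("p50k_base")  # vocab size 50,281, as listed above
SEQ_LEN, STRIDE = 1024, 512               # sliding-window chunking of abstracts


def chunk_abstract(text):
    """Cut one abstract into overlapping (input, target) windows of up to 1,024 tokens."""
    ids = enc.encode(text)
    for start in range(0, max(len(ids) - 1, 1), STRIDE):
        window = ids[start:start + SEQ_LEN + 1]  # +1 token for the shifted target
        if len(window) < 2:
            break
        x = torch.tensor(window[:-1], dtype=torch.long)
        y = torch.tensor(window[1:], dtype=torch.long)
        yield x, y


def make_optimizer(model: torch.nn.Module) -> torch.optim.AdamW:
    """AdamW with betas (0.9, 0.95) as in the table above; lr is a placeholder,
    the card does not state the learning rate."""
    return torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95))


def micro_step(model, x, y, accum_steps):
    """One bf16 forward/backward pass; gradients accumulate across micro-batches."""
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(x)                                   # (batch, seq, vocab)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
    (loss / accum_steps).backward()
    return loss.detach()


def optimizer_step(model, optimizer):
    """After `accum_steps` micro-batches: clip gradients to norm 1.0 and update."""
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```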
---

## 🚀 Loading the Model

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

repo_id = "kunjcr2/MedAssistGPT-303M"  # or the 401M repository

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

prompt = "A patient was admitted with severe headache. Initial assessment revealed"
inputs = tokenizer(prompt, return_tensors="pt")

outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=True,   # enable sampling so temperature takes effect
    temperature=0.7,
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

---

## 🧪 Intended Use

This repository is intended for:

* architecture exploration,
* large-scale pretraining experiments,
* medical-domain language modeling research,
* educational purposes.

🚫 **Not intended for clinical or production medical use.**

---

## 🔮 Possible Next Steps (Not Included)

* Continued pretraining with larger token budgets
* Supervised fine-tuning (SFT) on medical QA datasets
* Evaluation on biomedical NLP benchmarks

---

## 🪪 License

Apache 2.0