---
license: apache-2.0
datasets:
  - japhba/pubmed_simple
language:
  - en
tags:
  - research
  - llm-pretraining
  - transformer
  - gqa
  - rope
  - swiglu
  - rmsnorm
  - medical-text
---

# 🧠 MedAssistGPT – Pretraining Checkpoints (303M & 401M)

Experimental medical-domain LLM pretraining project.
โš ๏ธ Research-only. Not for clinical, diagnostic, or production use.


## 📌 Overview

This repository contains multiple pretraining checkpoints of the MedAssistGPT architecture, released in two parameter scales:

- MedAssistGPT-303M
- MedAssistGPT-401M

Both variants:

- share the same architecture design
- use the same tokenizer
- are trained on the same dataset
- differ only in model width / attention configuration and training progress

The purpose of this repository is to document architecture choices, data pipelines, and large-scale training behavior, rather than to present a fully converged or production-ready medical language model.


## 🧩 Architecture (Shared Design)

All models are decoder-only Transformers implemented from scratch in PyTorch.

### Core components

- RoPE (Rotary Positional Embeddings)
- Grouped Query Attention (GQA)
- SwiGLU feed-forward layers
- RMSNorm (pre-norm)
- Weight tying (token embeddings ↔ LM head)
- Dropout: 0.0 (pretraining configuration)
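
For concreteness, here is a minimal PyTorch sketch of two of the components listed above, RMSNorm and a SwiGLU feed-forward layer. Layer names and shapes are illustrative assumptions, not the repository's actual implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Root-mean-square normalization: no mean-centering, no bias."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Scale by the inverse RMS of the last dimension, then apply a learned gain.
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * (x * rms)

class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: W_down( SiLU(x W_gate) * (x W_up) )."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.w_up = nn.Linear(dim, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
```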

### Tokenization

- Tokenizer: tiktoken `p50k_base`
- Vocabulary size: ≈ 50,281
- Context length: 1,024 tokens
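
Because the tokenizer is a stock tiktoken encoding, it can be loaded independently of the model. This small check (an illustrative snippet, not part of the repository) confirms the vocabulary size quoted above:

```python
import tiktoken

enc = tiktoken.get_encoding("p50k_base")  # same encoding used for pretraining
print(enc.n_vocab)  # 50281

ids = enc.encode("Aspirin reduces the risk of myocardial infarction.")
print(len(ids), enc.decode(ids))
```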

๐Ÿ“ Model Variants

| Variant | Parameters | d_model | Heads | GQA (KV heads) | Blocks |
|---------|------------|---------|-------|----------------|--------|
| 303M    | ~303M      | 1024    | 16    | 4              | 24     |
| 401M    | ~401M      | 1024    | 32    | 4              | 24     |

Both variants use the same architectural template; the 401M model increases attention width while preserving GQA.
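
As a rough illustration of the GQA setting in the table (16 or 32 query heads sharing 4 KV heads), here is a minimal grouped-query attention sketch. It assumes `head_dim = d_model // n_heads` and omits RoPE; the released models may instead fix the head dimension independently of `d_model`, which would account for the 401M variant's larger parameter count.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """Causal attention with fewer key/value heads than query heads (GQA)."""
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        assert n_heads % n_kv_heads == 0
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

    def forward(self, x):
        b, t, _ = x.shape
        q = self.q_proj(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(b, t, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(b, t, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Each KV head is shared by n_heads // n_kv_heads query heads.
        groups = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(groups, dim=1)
        v = v.repeat_interleave(groups, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(b, t, -1))
```

With 16 query heads and 4 KV heads, each KV head serves 4 query heads, shrinking the KV cache fourfold relative to full multi-head attention.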


## 📚 Data

| Item | Value |
|------|-------|
| Dataset | japhba/pubmed_simple |
| Text field | `abstract` |
| Domain | Biomedical / medical research |
| Cleaning | Minimal (raw abstracts) |
| Sequence length | 1,024 |
| Sliding-window stride | 512 |
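
A sketch of how the table above translates into training windows: tokenize the `abstract` field, concatenate into one token stream, and cut 1,024-token windows every 512 tokens. The split name, streaming mode, and EOT separators are assumptions about the actual pipeline.

```python
import tiktoken
from datasets import load_dataset

enc = tiktoken.get_encoding("p50k_base")
SEQ_LEN, STRIDE = 1024, 512

# Stream abstracts and pack them into one id stream ("train" split assumed).
ds = load_dataset("japhba/pubmed_simple", split="train", streaming=True)

buffer, windows = [], []
for example in ds:
    buffer.extend(enc.encode(example["abstract"]) + [enc.eot_token])
    # Overlapping 1,024-token windows with a 512-token stride (50% overlap).
    while len(buffer) >= SEQ_LEN:
        windows.append(buffer[:SEQ_LEN])
        buffer = buffer[STRIDE:]
    if len(windows) >= 8:  # demo only: stop after a few windows
        break

print(len(windows), len(windows[0]))  # e.g. 8 1024
```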

โš™๏ธ Training Setup (Common)

| Item | Value |
|------|-------|
| Objective | Causal language modeling (next-token prediction) |
| Optimizer | AdamW |
| Betas | (0.9, 0.95) |
| Precision | bf16 |
| Gradient accumulation | Enabled |
| Gradient clipping | 1.0 |
| Effective batch size | 128 |
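
Put together, the settings above correspond roughly to the following training-loop skeleton. This is a sketch assuming a HF-style model whose forward returns `.loss`; the micro-batch size, accumulation factor, and learning rate are placeholders, and `model` / `train_loader` are assumed to exist.

```python
import torch

ACCUM_STEPS = 16  # assumption: micro-batch 8 x 16 accumulation steps = effective batch 128
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95))  # lr is a placeholder

model.train()
optimizer.zero_grad(set_to_none=True)
for step, (input_ids, labels) in enumerate(train_loader):
    # Forward/backward in bf16 autocast, scaling the loss for accumulation.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(input_ids, labels=labels).loss / ACCUM_STEPS
    loss.backward()

    if (step + 1) % ACCUM_STEPS == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping at 1.0
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```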

## 📦 Checkpoints

The `checkpoints/` directory contains snapshots of both model variants at different training stages.

Examples:

- `checkpoint_step_25000.pt` (303M) → ~2.5B tokens seen
- Additional checkpoints may exist for the 401M variant

โš ๏ธ Important:
All released checkpoints are early-stage pretraining snapshots.
At 2.5B tokens (8ร— tokens/parameter for 303M), the models are undertrained and should not be treated as finished base models.

They are provided to:

- study training dynamics,
- resume or extend pretraining (see the sketch after this list),
- experiment with fine-tuning,
- inspect architectural behavior at scale.
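
To resume or extend pretraining from a raw snapshot, the `.pt` file can be downloaded and inspected directly. In the sketch below, the file path inside the repo and the checkpoint keys are assumptions to verify against the actual files.

```python
import torch
from huggingface_hub import hf_hub_download

# Path inside the repo is assumed from the directory layout described above.
ckpt_path = hf_hub_download(
    repo_id="kunjcr2/MedAssistGPT-303M",
    filename="checkpoints/checkpoint_step_25000.pt",
)

state = torch.load(ckpt_path, map_location="cpu")
print(list(state.keys()))  # inspect what the snapshot actually stores

# Assumed keys; adjust to whatever the printout shows:
# model.load_state_dict(state["model"])          # model weights
# optimizer.load_state_dict(state["optimizer"])  # optimizer state for resuming
# start_step = state.get("step", 0)              # resume step counter
```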

## 📈 Training Status

- Training and validation loss were still improving at the time of the last checkpoints.
- Training runs were interrupted due to infrastructure preemption and were not resumed.
- No claims are made about benchmark or downstream-task performance.

## 🚀 Loading the Model

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

repo_id = "kunjcr2/MedAssistGPT-303M"  # or the 401M repo

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

prompt = "A patient was admitted with severe headache. Initial assessment revealed"
inputs = tokenizer(prompt, return_tensors="pt")

outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=True,   # temperature only takes effect when sampling is enabled
    temperature=0.7,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

## 🧪 Intended Use

This repository is intended for:

- architecture exploration,
- large-scale pretraining experiments,
- medical-domain language modeling research,
- educational purposes.

🚫 Not intended for clinical or production medical use.


## 🔮 Possible Next Steps (Not Included)

- Continued pretraining with larger token budgets
- Supervised fine-tuning (SFT) on medical QA datasets
- Evaluation on biomedical NLP benchmarks

## 🪪 License

Apache 2.0