kunjcr2 committed on
Commit 36d5852 · verified · 1 Parent(s): 4d5e222

Update README.md

  1. README.md +74 -1
README.md CHANGED
@@ -4,4 +4,77 @@ datasets:
  - Hack90/europe_pmc_articles_part_2
  language:
  - en
- ---
+ tags:
+ - v0_pretrain_medassist
+ ---
+ # MedAssist-GPT
+
+ Tiny medical-domain LLM pretraining project.
+ **NOT for clinical use.**
+
+ ## TL;DR
+
+ * **Arch:** Transformer with **RoPE** + **GQA**, **SwiGLU** MLP, **RMSNorm**, causal LM head (tied embeddings).
+ * **Tokenizer:** `tiktoken` **p50k_base** (vocab ≈ 50,281).
+ * **Context:** 1,024 tokens (default).
+ * **Size (default config):** ~125M params (d_model=512, n_heads=16, layers=16, d_ff=2048).
+ * **Trained on:** about 2.2B tokens of medical-domain text.
+
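To make two of the building blocks named in the TL;DR concrete, here is a minimal PyTorch sketch of RMSNorm and a SwiGLU MLP. It is an illustration under stated assumptions, not the repository's code: the class and layer names, `eps=1e-6`, and `bias=False` are hypothetical choices; only `d_model=512` and `d_ff=2048` come from the README.

```python
import torch
from torch import nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Scale each vector by the reciprocal of its root-mean-square (no mean centering)."""

    def __init__(self, dim: int, eps: float = 1e-6):  # eps is an assumed value
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable per-channel gain

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight


class SwiGLU(nn.Module):
    """Gated MLP: down(SiLU(gate(x)) * up(x)). Bias-free layers are an assumption."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.gate = nn.Linear(d_model, d_ff, bias=False)
        self.up = nn.Linear(d_model, d_ff, bias=False)
        self.down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.silu(self.gate(x)) * self.up(x))


torch.manual_seed(0)
x = torch.randn(2, 8, 512)           # (batch, sequence, d_model)
normed = RMSNorm(512)(x)             # same shape as x, unit RMS per position
out = SwiGLU(512, 2048)(normed)      # same shape as x
print(out.shape)
```

RoPE and GQA live inside the attention layer and are not sketched here, because the README does not state the number of key/value heads.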
+ ## Data (example)
+
+ * Source: `Hack90/europe_pmc_articles_part_2` (`full_text`).
+ * XML → plain text via `clean()`; sliding windows (`max_length=1024`, `stride=1024`).
+
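The windowing step above can be sketched in plain Python. This is a hypothetical helper, not the project's implementation: the name `sliding_windows` and the choice to keep a short final window are assumptions. With `stride` equal to `max_length`, as in the README, consecutive windows do not overlap.

```python
from typing import Iterator, List


def sliding_windows(
    token_ids: List[int], max_length: int = 1024, stride: int = 1024
) -> Iterator[List[int]]:
    """Yield consecutive slices of token_ids, each at most max_length long.

    A stride smaller than max_length produces overlapping windows.
    The last window may be shorter than max_length (assumed policy).
    """
    if max_length <= 0 or stride <= 0:
        raise ValueError("max_length and stride must be positive")
    for start in range(0, len(token_ids), stride):
        yield token_ids[start:start + max_length]


# Toy example with small numbers so the result is easy to inspect.
windows = list(sliding_windows(list(range(10)), max_length=4, stride=4))
print(windows)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

For causal language modeling, each window is typically split into inputs and next-token targets shifted by one position; how the training script handles that shift is not stated in the README.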
+ ## Training (script)
+
+ * AdamW + OneCycleLR, bf16 AMP, grad accumulation, checkpoints, optional HF upload, wandb logging.
+
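As a rough picture of how these pieces fit together, the following toy loop combines AdamW, OneCycleLR, bf16 autocast, and gradient accumulation. It is a sketch under assumptions, not the training script: the stand-in `nn.Linear` model, random data, learning rate, weight decay, and step counts are all hypothetical, and it runs on CPU so it can be tried anywhere. Checkpointing, HF upload, and wandb logging are omitted.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Linear(8, 8)   # stand-in for the real model (hypothetical)
ACCUM_STEPS = 2           # micro-batches per optimizer step (assumed)
OPT_STEPS = 20            # optimizer steps in this toy run (assumed)
MAX_LR = 1e-3             # peak learning rate (assumed)

optimizer = torch.optim.AdamW(model.parameters(), lr=MAX_LR, weight_decay=0.1)
# OneCycleLR is stepped once per optimizer step, so total_steps counts optimizer steps.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=MAX_LR, total_steps=OPT_STEPS
)

losses = []
for step in range(OPT_STEPS):
    optimizer.zero_grad(set_to_none=True)
    for _ in range(ACCUM_STEPS):
        x = torch.randn(16, 8)
        # Forward pass in bf16; parameters stay in fp32.
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            pred = model(x)
        # Compute the loss in fp32 and scale it so accumulated grads average over micro-batches.
        loss = nn.functional.mse_loss(pred.float(), x)
        (loss / ACCUM_STEPS).backward()
    optimizer.step()
    scheduler.step()
    losses.append(loss.item())  # loss of the last micro-batch in this step

print(f"final lr: {scheduler.get_last_lr()[0]:.2e}, last loss: {losses[-1]:.4f}")
```

On a GPU the same structure applies with `device_type="cuda"`; bf16 autocast does not need a gradient scaler, unlike fp16.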
+ ## Loss
+
+ ![train_loss](https://cdn-uploads.huggingface.co/production/uploads/67c358189919777813863c48/bQGVqgx4GoqXZTcMh8KhM.png)
+ ![val_loss](https://cdn-uploads.huggingface.co/production/uploads/67c358189919777813863c48/jhNnS_Wvhj4-fzNoO2dRN.png)
+
+ ## Try it (minimal)
+
+ ```python
+ # pip install torch tiktoken huggingface_hub safetensors
+ import torch
+ import tiktoken
+ from safetensors.torch import load_file
+ from huggingface_hub import hf_hub_download
+
+ REPO_ID = "kunjcr2/MedAssistGPT"  # change if needed
+ WEIGHTS = hf_hub_download(REPO_ID, "model.safetensors")
+ state = load_file(WEIGHTS, device="cpu")
+
+ # Import your MedAssistGPT class from the script/notebook
+ from MedAssistGPT import MedAssistGPT, MODEL_CONFIG  # ensure paths match your repo
+
+ model = MedAssistGPT(MODEL_CONFIG)
+ # load_state_dict() returns a missing/unexpected-keys report, not the model,
+ # so eval() must be called on the model itself.
+ model.load_state_dict(state, strict=True)
+ model.eval()
+
+ enc = tiktoken.get_encoding("p50k_base")
+ ids = torch.tensor([enc.encode("To live a good life")], dtype=torch.long)
+ with torch.no_grad():
+     for _ in range(100):  # generate at most 100 new tokens
+         logits = model(ids)[:, -1, :]  # logits for the last position only
+         # Sample the next token at temperature 0.7
+         next_id = torch.multinomial(torch.softmax(logits / 0.7, dim=-1), 1)
+         ids = torch.cat([ids, next_id], dim=1)
+         if next_id.item() == enc.eot_token:  # stop at end-of-text
+             break
+
+ print(enc.decode(ids[0].tolist()))
+ ```
+
+ ## Intended use & limitations
+
+ Research, experimentation, and downstream fine-tuning of the pretrained checkpoint.
+ Do **NOT** use for medical decisions.
+
+ ## Files
+
+ * `model.safetensors` (weights)
+ * `config.json`, `tokenizer_config.json`
+ * Script/notebook defining `MedAssistGPT` class
+
+ ## License
+ Apache-2.0