justkitting commited on
Commit
6a5f2d9
·
verified ·
1 Parent(s): dabf1e7

Update usage section with proper inference example

Browse files
Files changed (1) hide show
  1. README.md +21 -12
README.md CHANGED
@@ -26,6 +26,27 @@ python sample.py
26
 
27
  The checkpoint downloads automatically on first run.
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  ## Results
30
 
31
  | Metric | Value |
@@ -92,18 +113,6 @@ The `.pth` file contains:
92
  - `config` — model config dict
93
  - `chunk_idx` — data chunk index
94
 
95
- ## Usage
96
-
97
- ```python
98
- import torch
99
- from model import GPT, GPTConfig
100
-
101
- ckpt = torch.load("pg19_step_50000_final.pth", map_location="cpu")
102
- config = GPTConfig(**ckpt["config"])
103
- model = GPT(config)
104
- model.load_state_dict(ckpt["ema_model"]) # use EMA weights for best quality
105
- ```
106
-
107
  ## License
108
 
109
  MIT
 
26
 
27
  The checkpoint downloads automatically on first run.
28
 
29
+ ## Usage
30
+
31
+ ```python
32
+ from huggingface_hub import hf_hub_download
33
+ from sample import load_model, sample
34
+
35
+ # Download checkpoint (~4GB, cached after first run)
36
+ ckpt_path = hf_hub_download(
37
+ repo_id="justkitting/gpt-diffusion-nanogptsettings",
38
+ filename="pg19_step_50000_final.pth",
39
+ )
40
+
41
+ # Load model (uses EMA weights automatically)
42
+ device = "cuda"
43
+ model, config = load_model(ckpt_path, device)
44
+
45
+ # Generate text
46
+ text = sample(model, config, steps=128, length=512, device=device)
47
+ print(text)
48
+ ```
49
+
50
  ## Results
51
 
52
  | Metric | Value |
 
113
  - `config` — model config dict
114
  - `chunk_idx` — data chunk index
115
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  ## License
117
 
118
  MIT