Add KoHRM CPU quantized runtime pack
Browse files- README.md +199 -0
- inference/kohrm_cpu_runtime.py +378 -0
- inference/requirements-cpu.txt +4 -0
- notebooks/kohrm_colab_generate.py +474 -0
README.md
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
base_model: LLM-OS-Models/KoHRM-Text-1.4B
|
| 4 |
+
base_model_relation: quantized
|
| 5 |
+
library_name: pytorch
|
| 6 |
+
tags:
|
| 7 |
+
- kohrm
|
| 8 |
+
- hrm-text
|
| 9 |
+
- cpu
|
| 10 |
+
- int8
|
| 11 |
+
- int4
|
| 12 |
+
- korean
|
| 13 |
+
- terminal
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# KoHRM-Text-1.4B CPU Runtime
|
| 17 |
+
|
| 18 |
+
This repository contains a CPU-oriented inference runtime for
|
| 19 |
+
`LLM-OS-Models/KoHRM-Text-1.4B`.
|
| 20 |
+
|
| 21 |
+
It does not duplicate the original model weights. The runtime downloads the
|
| 22 |
+
base model from Hugging Face and applies CPU quantization at load time.
|
| 23 |
+
|
| 24 |
+
# KoHRM-Text CPU Runtime Pack
|
| 25 |
+
|
| 26 |
+
작성일: `2026-06-09`
|
| 27 |
+
|
| 28 |
+
## 결론
|
| 29 |
+
|
| 30 |
+
`LLM-OS-Models/KoHRM-Text-1.4B`๋ ํ์ฌ GGUF๋ก ๋ฐ๋ก ๋ง๋ค ์ ์๋ค.
|
| 31 |
+
|
| 32 |
+
์ด์ ๋ ๋ชจ๋ธ ๊ตฌ์กฐ๊ฐ ์ผ๋ฐ Llama/Qwen/Gemma ๊ณ์ด์ด ์๋๋ผ ์๋ ์ ์ฉ ๊ตฌ์กฐ์ด๊ธฐ ๋๋ฌธ์ด๋ค.
|
| 33 |
+
|
| 34 |
+
```text
|
| 35 |
+
model_type: hrm_text
|
| 36 |
+
architectures: HrmTextForCausalLM
|
| 37 |
+
H_cycles: 2
|
| 38 |
+
L_cycles: 3
|
| 39 |
+
prefix_lm: true
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
llama.cpp ๋ณํ๊ธฐ๋ก ์ง์ ์๋ํ๋ฉด ๋ค์ ์ง์ ์์ ๋งํ๋ค.
|
| 43 |
+
|
| 44 |
+
```text
|
| 45 |
+
ERROR:hf-to-gguf:Model HrmTextForCausalLM is not supported
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
๋ฐ๋ผ์ ์ง๊ธ ํ์ค์ ์ธ CPU ๊ฒฝ๋ก๋ GGUF๊ฐ ์๋๋ผ PyTorch ์ ์ฉ runtime์ด๋ค.
|
| 49 |
+
|
| 50 |
+
## 추가한 파일
|
| 51 |
+
|
| 52 |
+
```text
|
| 53 |
+
HRM-Text/inference/kohrm_cpu_runtime.py
|
| 54 |
+
HRM-Text/inference/requirements-cpu.txt
|
| 55 |
+
HRM-Text/scripts/upload_kohrm_cpu_runtime_pack.py
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
์ด runtime์ ๊ธฐ์กด `HRM-Text/notebooks/kohrm_colab_generate.py`์ safetensors ์ง์ ๋ก๋ฉ ๊ฒฝ๋ก๋ฅผ ์ฌ์ฌ์ฉํ๊ณ , CPU์ฉ ์์ํ์ H/L cycle override๋ฅผ ์ถ๊ฐํ๋ค.
|
| 59 |
+
|
| 60 |
+
## 사용법
|
| 61 |
+
|
| 62 |
+
기본 권장값은 `dynamic-int8`이다.
|
| 63 |
+
|
| 64 |
+
```bash
|
| 65 |
+
cd /home/work/.projects/LLM-OS-Models/Terminal
|
| 66 |
+
|
| 67 |
+
CUDA_VISIBLE_DEVICES= OMP_NUM_THREADS=8 \
|
| 68 |
+
python HRM-Text/inference/kohrm_cpu_runtime.py \
|
| 69 |
+
--model LLM-OS-Models/KoHRM-Text-1.4B \
|
| 70 |
+
--quant dynamic-int8 \
|
| 71 |
+
--prompt "리눅스에서 현재 디렉토리 파일 목록을 보는 명령어는?" \
|
| 72 |
+
--max-new-tokens 128 \
|
| 73 |
+
--max-seq-len 768 \
|
| 74 |
+
--temperature 0
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
16GB CPU RAM ํ๊ฒฝ์์๋ ์๋ ์์๋ก ์ฐ๋ฉด ๋๋ค.
|
| 78 |
+
|
| 79 |
+
```text
|
| 80 |
+
1์์: dynamic-int8
|
| 81 |
+
2์์: none
|
| 82 |
+
3์์: weight-int4
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
`dynamic-int8`์ PyTorch CPU dynamic quantization์ ์ฌ์ฉํ๋ค. ์ผ๋ฐ์ ์ผ๋ก ๋ฉ๋ชจ๋ฆฌ์ ์๋ ๊ท ํ์ด ๊ฐ์ฅ ๋ซ๋ค.
|
| 86 |
+
|
| 87 |
+
`weight-int4`๋ ์ง์ ๊ตฌํํ portable 4bit weight-only fallback์ด๋ค. ๋ฉ๋ชจ๋ฆฌ๋ ์ค์ง๋ง ๋งค forward๋ง๋ค unpack/dequantize๊ฐ ๋ค์ด๊ฐ์ ๋งค์ฐ ๋๋ฆฌ๋ค. โ๋ฐ๋์ ์์ ๋ฉ๋ชจ๋ฆฌ๋ก ๋์๊ฐ์ผ ํ๋คโ๋ ๊ฒฝ์ฐ์๋ง ์ด๋ค.
|
| 88 |
+
|
| 89 |
+
## H/L cycle override
|
| 90 |
+
|
| 91 |
+
KoHRM์ ๊ฐ์ H/L module์ ๋ฐ๋ณต ์ ์ฉํ๋ค. ๊ธฐ๋ณธ์ `H=2`, `L=3`์ด๋ค.
|
| 92 |
+
|
| 93 |
+
CPU์์๋ ์๋์ฒ๋ผ ๋ฐ๋ณต ํ์๋ฅผ ์ค์ฌ ์๋๋ฅผ ์ฌ๋ฆด ์ ์๋ค.
|
| 94 |
+
|
| 95 |
+
```bash
|
| 96 |
+
CUDA_VISIBLE_DEVICES= OMP_NUM_THREADS=8 \
|
| 97 |
+
python HRM-Text/inference/kohrm_cpu_runtime.py \
|
| 98 |
+
--model LLM-OS-Models/KoHRM-Text-1.4B \
|
| 99 |
+
--quant dynamic-int8 \
|
| 100 |
+
--h-cycles 1 \
|
| 101 |
+
--l-cycles 1 \
|
| 102 |
+
--prompt "리눅스에서 현재 디렉토리 파일 목록을 보는 명령어는?" \
|
| 103 |
+
--max-new-tokens 128 \
|
| 104 |
+
--max-seq-len 768 \
|
| 105 |
+
--temperature 0
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
์ฃผ์ํ ์ ์ ๋ช
ํํ๋ค.
|
| 109 |
+
|
| 110 |
+
- `H=2,L=3`: ์๋ ํ์ง ๊ฒฝ๋ก.
|
| 111 |
+
- `H=1,L=1`: CPU ์๋ ์ฐ์ ๊ฒฝ๋ก.
|
| 112 |
+
- cycle์ ์ค์ด๋ฉด ํ์ง์ด ๋จ์ด์ง ์ ์๋ค.
|
| 113 |
+
|
| 114 |
+
## Smoke test ๊ฒฐ๊ณผ
|
| 115 |
+
|
| 116 |
+
๊ฐ์ ์งง์ prompt, `max_new_tokens=4`, `max_seq_len=128`, `OMP_NUM_THREADS=8` ๊ธฐ์ค์ด๋ค.
|
| 117 |
+
|
| 118 |
+
```text
|
| 119 |
+
none:
|
| 120 |
+
elapsed: 1.48s
|
| 121 |
+
speed: 2.69 tok/s
|
| 122 |
+
cycles: H=2, L=3
|
| 123 |
+
|
| 124 |
+
dynamic-int8:
|
| 125 |
+
elapsed: 0.53s
|
| 126 |
+
speed: 7.59 tok/s
|
| 127 |
+
cycles: H=2, L=3
|
| 128 |
+
|
| 129 |
+
dynamic-int8 + H=1,L=1:
|
| 130 |
+
elapsed: 0.24s
|
| 131 |
+
speed: 8.18 tok/s
|
| 132 |
+
cycles: H=1, L=1
|
| 133 |
+
|
| 134 |
+
weight-int4:
|
| 135 |
+
elapsed: 23.25s
|
| 136 |
+
speed: 0.17 tok/s
|
| 137 |
+
cycles: H=2, L=3
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
์งง์ smoke test๋ผ ์ ๋ ์ฑ๋ฅ ์ซ์๋ ์ฐธ๊ณ ์ฉ์ด๋ค. ํ์ง๋ง ๋ฐฉํฅ์ ๋ถ๋ช
ํ๋ค.
|
| 141 |
+
|
| 142 |
+
```text
|
| 143 |
+
์ค์ฌ์ฉ: dynamic-int8
|
| 144 |
+
๋ฉ๋ชจ๋ฆฌ ๊ฐ์ ์ ์ฝ: weight-int4
|
| 145 |
+
ํ์ง ์ ์ง: H=2,L=3
|
| 146 |
+
์๋ ์ฐ์ : H=1,L=1
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
## ์ GGUF๊ฐ ์ด๋ ค์ด๊ฐ
|
| 150 |
+
|
| 151 |
+
GGUF ํ์ผ์ ๋จ์ํ weight๋ฅผ ๋ด๋ ํฌ๋งท์ด ์๋๋ค. llama.cpp๊ฐ ํด๋น architecture์ forward pass๋ฅผ ์์์ผ ํ๋ค.
|
| 152 |
+
|
| 153 |
+
KoHRM์ ์ผ๋ฐ Transformer block์ ํ ๋ฒ์ฉ ์๋ ๋ชจ๋ธ์ด ์๋๋ค.
|
| 154 |
+
|
| 155 |
+
- H module๊ณผ L module์ด ์๋ค.
|
| 156 |
+
- `H_cycles`, `L_cycles`๋งํผ recurrentํ๊ฒ ๋ฐ๋ณตํ๋ค.
|
| 157 |
+
- PrefixLM formatting๊ณผ stop token ์ฒ๋ฆฌ๊ฐ ๋ค๋ฅด๋ค.
|
| 158 |
+
- KV cache ๊ตฌ์กฐ๋ ์ผ๋ฐ chat causal LM๊ณผ ๋ค๋ฅด๋ค.
|
| 159 |
+
|
| 160 |
+
๋ฐ๋ผ์ GGUF๋ฅผ ์ ๋๋ก ๋ง๋ค๋ ค๋ฉด ๋ค์ ์์
์ด ํ์ํ๋ค.
|
| 161 |
+
|
| 162 |
+
```text
|
| 163 |
+
1. llama.cpp MODEL_ARCH์ HRM_TEXT ์ถ๊ฐ
|
| 164 |
+
2. H/L recurrent forward ๊ตฌํ
|
| 165 |
+
3. gqkv gated attention ๊ตฌํ
|
| 166 |
+
4. PrefixLM prompt/token boundary ์ฒ๋ฆฌ
|
| 167 |
+
5. tokenizer pre-tokenizer hash ๋ฑ๋ก
|
| 168 |
+
6. quantized tensor name mapping ์์ฑ
|
| 169 |
+
7. llama-cli generation smoke test
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
๋จ์ converter patch๋ก ๋๋๋ ๋ฌธ์ ๊ฐ ์๋๋ค.
|
| 173 |
+
|
| 174 |
+
## HF CPU pack
|
| 175 |
+
|
| 176 |
+
HF์๋ ๊ฐ์ค์น๋ฅผ ์ค๋ณต ์
๋ก๋ํ์ง ์๊ณ CPU runtime pack์ ๋ฐ๋ก ์ฌ๋ฆฐ๋ค.
|
| 177 |
+
|
| 178 |
+
๋์ repo:
|
| 179 |
+
|
| 180 |
+
```text
|
| 181 |
+
LLM-OS-Models/KoHRM-Text-1.4B-CPU-Runtime
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
+
์ด repo์๋ ๋ค์๋ง ๋ค์ด๊ฐ๋ค.
|
| 185 |
+
|
| 186 |
+
```text
|
| 187 |
+
README.md
|
| 188 |
+
inference/kohrm_cpu_runtime.py
|
| 189 |
+
inference/requirements-cpu.txt
|
| 190 |
+
notebooks/kohrm_colab_generate.py
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
๊ฐ์ค์น๋ ๏ฟฝ๏ฟฝ๏ฟฝํ ์ ์๋ณธ repo์์ ๋ฐ๋๋ค.
|
| 194 |
+
|
| 195 |
+
```text
|
| 196 |
+
LLM-OS-Models/KoHRM-Text-1.4B
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
๊ณต์ฉ์ปด ๊ธฐ์ค์ผ๋ก HF token์ `.env`์์ ์ฝ๋ ์ถ๋ ฅํ์ง ์๋๋ค.
|
inference/kohrm_cpu_runtime.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CPU-oriented KoHRM-Text inference runtime.
|
| 2 |
+
|
| 3 |
+
KoHRM-Text uses the custom ``hrm_text`` / ``HrmTextForCausalLM`` architecture,
|
| 4 |
+
so it cannot currently be served by llama.cpp/GGUF or ordinary vLLM paths.
|
| 5 |
+
This runtime wraps the existing safetensors loader and adds CPU-friendly
|
| 6 |
+
quantization and cycle overrides.
|
| 7 |
+
|
| 8 |
+
Recommended mode for normal CPU use:
|
| 9 |
+
|
| 10 |
+
python HRM-Text/inference/kohrm_cpu_runtime.py \
|
| 11 |
+
--model LLM-OS-Models/KoHRM-Text-1.4B \
|
| 12 |
+
--quant dynamic-int8 \
|
| 13 |
+
--prompt "리눅스에서 현재 디렉토리 파일 목록을 보는 명령어는?" \
|
| 14 |
+
--max-new-tokens 64
|
| 15 |
+
|
| 16 |
+
Experimental memory-first mode:
|
| 17 |
+
|
| 18 |
+
python HRM-Text/inference/kohrm_cpu_runtime.py --quant weight-int4 ...
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import gc
|
| 25 |
+
import importlib.util
|
| 26 |
+
import json
|
| 27 |
+
import math
|
| 28 |
+
import os
|
| 29 |
+
import shutil
|
| 30 |
+
import sys
|
| 31 |
+
import time
|
| 32 |
+
from dataclasses import dataclass
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
from typing import Any
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
import torch.nn as nn
|
| 38 |
+
import torch.nn.functional as F
|
| 39 |
+
from huggingface_hub import snapshot_download
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# This file sits one directory below the pack root (inference/), so the root
# is the grandparent of the resolved file path.
REPO_ROOT = Path(__file__).resolve().parent.parent
# Generation helper that the runtime imports by file path.
HELPER_PATH = REPO_ROOT.joinpath("notebooks", "kohrm_colab_generate.py")
# Hub repo id used as the default for --model.
DEFAULT_REPO_ID = "LLM-OS-Models/KoHRM-Text-1.4B"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _load_helper():
    """Import the KoHRM generation helper by file path and return the module.

    The helper is loaded from ``HELPER_PATH`` and registered in ``sys.modules``
    as ``kohrm_colab_generate``. A module already registered under that name is
    reused, so repeated calls share a single module object instead of executing
    the helper again next to a stale ``sys.modules`` entry.

    Raises:
        FileNotFoundError: ``HELPER_PATH`` does not exist.
        RuntimeError: importlib could not produce a spec/loader for the file.
    """
    module_name = "kohrm_colab_generate"
    cached = sys.modules.get(module_name)
    if cached is not None:
        return cached
    if not HELPER_PATH.exists():
        raise FileNotFoundError(f"missing KoHRM helper: {HELPER_PATH}")
    spec = importlib.util.spec_from_file_location(module_name, HELPER_PATH)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"cannot import helper from {HELPER_PATH}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so code inside the helper that looks itself up
    # through sys.modules resolves to this module object.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind for later imports.
        sys.modules.pop(module_name, None)
        raise
    return module
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _read_dotenv_token() -> str | None:
    """Read a local HF token without printing it or exporting it to shell logs.

    Search order: ``.env`` in the current directory, ``.env`` next to the
    repository root, ``.env`` inside the repository root, the Hugging Face
    cached token file, and finally the process environment. The first
    non-empty token wins; empty entries are skipped so a blank ``HF_TOKEN=``
    line does not hide a usable token further down the list.

    Returns:
        The token string, or ``None`` when no source provides one.
    """
    token_keys = ("HF_TOKEN", "HUGGINGFACE_TOKEN", "HUGGING_FACE_HUB_TOKEN")
    candidates = [
        Path.cwd() / ".env",
        REPO_ROOT.parent / ".env",
        REPO_ROOT / ".env",
        Path.home() / ".cache" / "huggingface" / "token",
    ]
    for path in candidates:
        # is_file() rather than exists(): ".env" is also a common virtualenv
        # directory name, and reading a directory would raise.
        if not path.is_file():
            continue
        if path.name == "token":
            token = path.read_text(encoding="utf-8").strip()
            if token:
                return token
            continue
        for raw in path.read_text(encoding="utf-8", errors="ignore").splitlines():
            line = raw.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", 1)
            # Accept shell-style "export KEY=value" lines as well.
            key = key.strip().removeprefix("export ").strip()
            if key in token_keys:
                token = value.strip().strip('"').strip("'")
                if token:
                    return token
    # Environment fallback honours the same key names as the .env parser.
    for key in token_keys:
        token = os.environ.get(key)
        if token:
            return token
    return None
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def resolve_model_dir(model: str, revision: str | None = None) -> Path:
|
| 88 |
+
path = Path(model).expanduser()
|
| 89 |
+
if path.exists():
|
| 90 |
+
return path
|
| 91 |
+
token = _read_dotenv_token()
|
| 92 |
+
return Path(
|
| 93 |
+
snapshot_download(
|
| 94 |
+
repo_id=model,
|
| 95 |
+
revision=revision,
|
| 96 |
+
allow_patterns=["config.json", "tokenizer.json", "tokenizer_config.json", "model.safetensors", "README.md"],
|
| 97 |
+
token=token,
|
| 98 |
+
)
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@dataclass
class RuntimeStats:
    """Summary of a single generation call: token counts, timing and settings."""

    # Number of tokens in the formatted prompt.
    prompt_tokens: int
    # Number of tokens counted for the generated text.
    generated_tokens: int
    # Wall-clock seconds spent generating.
    elapsed_s: float
    # generated_tokens divided by elapsed_s.
    tokens_per_s: float
    # Name of the quantization mode in use.
    quantization: str
    # H_cycles value in effect for the run.
    h_cycles: int
    # L_cycles value in effect for the run.
    l_cycles: int
    # String form of the model parameter dtype.
    dtype: str
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class WeightOnlyInt8Linear(nn.Module):
    """Simple symmetric per-group int8 weight-only Linear (no bias).

    This is a portability fallback, not an optimized kernel. Weights are kept
    as int8 with one float16 scale per ``group_size`` input columns and are
    dequantized on every forward call. For speed, prefer PyTorch dynamic int8.
    """

    def __init__(self, qweight: torch.Tensor, scales: torch.Tensor, in_features: int, out_features: int, group_size: int) -> None:
        """Store pre-quantized weights.

        Args:
            qweight: int8 tensor shaped ``(out_features, groups, group_size)``.
            scales: per-group scales shaped ``(out_features, groups)``.
            in_features: input width before padding to a group multiple.
            out_features: output width.
            group_size: number of input columns sharing one scale.
        """
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.group_size = int(group_size)
        self.register_buffer("qweight", qweight.contiguous())
        self.register_buffer("scales", scales.contiguous())

    @classmethod
    def from_linear(cls, linear: nn.Linear, group_size: int = 128) -> "WeightOnlyInt8Linear":
        """Quantize a bias-free ``nn.Linear`` into this layer.

        Raises:
            ValueError: ``group_size`` is not positive, or ``linear`` has a
                bias (this layer has nowhere to keep it, so converting would
                silently change the layer's output).
        """
        if group_size <= 0:
            raise ValueError(f"group_size must be positive, got {group_size}")
        if linear.bias is not None:
            raise ValueError("bias is not supported by portable weight-only quantization")
        weight = linear.weight.detach().to(dtype=torch.float32, device="cpu")
        out_features, in_features = weight.shape
        # Zero-pad the input dimension up to a whole number of groups.
        pad = (-in_features) % group_size
        if pad:
            weight = F.pad(weight, (0, pad))
        grouped = weight.reshape(out_features, -1, group_size)
        # Symmetric scale per group; the clamp keeps all-zero groups finite.
        scales = grouped.abs().amax(dim=-1).clamp_min(1e-8) / 127.0
        qweight = torch.round(grouped / scales.unsqueeze(-1)).clamp(-127, 127).to(torch.int8)
        return cls(qweight=qweight, scales=scales.to(torch.float16), in_features=in_features, out_features=out_features, group_size=group_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Dequantize the weight to ``x.dtype`` and apply a bias-free linear map."""
        weight = (self.qweight.to(torch.float32) * self.scales.to(torch.float32).unsqueeze(-1)).reshape(self.out_features, -1)
        # Drop the zero padding added during quantization.
        weight = weight[:, : self.in_features].to(dtype=x.dtype)
        return F.linear(x, weight)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class WeightOnlyInt4Linear(nn.Module):
    """Portable symmetric per-group int4 weight-only Linear (no bias).

    Values are stored as packed signed nibbles: even input columns in the low
    nibble and odd input columns in the high nibble of each byte. Forward
    unpacks and dequantizes on CPU, so this is memory-first rather than
    speed-first.
    """

    def __init__(self, packed: torch.Tensor, scales: torch.Tensor, in_features: int, out_features: int, padded_features: int, group_size: int) -> None:
        """Store pre-quantized weights.

        Args:
            packed: uint8 tensor shaped ``(out_features, padded_features // 2)``.
            scales: per-group scales shaped ``(out_features, groups)``.
            in_features: input width before padding.
            out_features: output width.
            padded_features: input width after padding (even, multiple of
                ``group_size``).
            group_size: number of input columns sharing one scale.
        """
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.padded_features = int(padded_features)
        self.group_size = int(group_size)
        self.register_buffer("packed", packed.contiguous())
        self.register_buffer("scales", scales.contiguous())

    @classmethod
    def from_linear(cls, linear: nn.Linear, group_size: int = 128) -> "WeightOnlyInt4Linear":
        """Quantize a bias-free ``nn.Linear`` into packed int4 form.

        Raises:
            ValueError: ``group_size`` is not positive, or ``linear`` has a
                bias (this layer has nowhere to keep it).
        """
        if group_size <= 0:
            raise ValueError(f"group_size must be positive, got {group_size}")
        if linear.bias is not None:
            raise ValueError("bias is not supported by portable weight-only quantization")
        weight = linear.weight.detach().to(dtype=torch.float32, device="cpu")
        out_features, in_features = weight.shape
        pad = (-in_features) % group_size
        if (in_features + pad) % 2:
            # An odd padded width only happens with an odd group_size. Adding
            # one more whole group makes the width even for nibble packing
            # while keeping it a multiple of group_size (a single extra column
            # would break the per-group reshape below).
            pad += group_size
        if pad:
            weight = F.pad(weight, (0, pad))
        padded_features = weight.shape[1]
        grouped = weight.reshape(out_features, -1, group_size)
        # Symmetric scale per group; the clamp keeps all-zero groups finite.
        scales = grouped.abs().amax(dim=-1).clamp_min(1e-8) / 7.0
        q = torch.round(grouped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int16)
        # Map signed values in [-8, 7] to their unsigned 4-bit two's-complement form.
        nibbles = (q + 16).remainder(16).to(torch.uint8).reshape(out_features, padded_features)
        low = nibbles[:, 0::2]
        high = nibbles[:, 1::2] << 4
        packed = low | high
        return cls(
            packed=packed,
            scales=scales.to(torch.float16),
            in_features=in_features,
            out_features=out_features,
            padded_features=padded_features,
            group_size=group_size,
        )

    def _unpack(self) -> torch.Tensor:
        """Return the signed int16 weight codes shaped ``(out_features, padded_features)``."""
        low = self.packed & 0x0F
        high = (self.packed >> 4) & 0x0F
        q = torch.empty((self.out_features, self.packed.shape[1] * 2), dtype=torch.int16, device=self.packed.device)
        q[:, 0::2] = low.to(torch.int16)
        q[:, 1::2] = high.to(torch.int16)
        # Nibbles 8..15 encode the negative values -8..-1.
        q = torch.where(q >= 8, q - 16, q)
        return q[:, : self.padded_features]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Unpack, dequantize to ``x.dtype`` and apply a bias-free linear map."""
        q = self._unpack().to(torch.float32)
        weight = (q.reshape(self.out_features, -1, self.group_size) * self.scales.to(torch.float32).unsqueeze(-1)).reshape(self.out_features, -1)
        # Drop the zero padding added during quantization.
        weight = weight[:, : self.in_features].to(dtype=x.dtype)
        return F.linear(x, weight)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def _replace_linear_modules(module: nn.Module, *, quant: str, group_size: int, quantize_lm_head: bool, prefix: str = "") -> int:
|
| 207 |
+
replaced = 0
|
| 208 |
+
for name, child in list(module.named_children()):
|
| 209 |
+
child_prefix = f"{prefix}.{name}" if prefix else name
|
| 210 |
+
if isinstance(child, nn.Linear):
|
| 211 |
+
if child_prefix == "lm_head" and not quantize_lm_head:
|
| 212 |
+
continue
|
| 213 |
+
if child.bias is not None:
|
| 214 |
+
raise ValueError(f"bias is not supported by portable weight-only quantization: {child_prefix}")
|
| 215 |
+
if quant == "weight-int8":
|
| 216 |
+
new_child = WeightOnlyInt8Linear.from_linear(child, group_size=group_size)
|
| 217 |
+
elif quant == "weight-int4":
|
| 218 |
+
new_child = WeightOnlyInt4Linear.from_linear(child, group_size=group_size)
|
| 219 |
+
else:
|
| 220 |
+
raise ValueError(f"unsupported weight-only quantization: {quant}")
|
| 221 |
+
setattr(module, name, new_child)
|
| 222 |
+
replaced += 1
|
| 223 |
+
else:
|
| 224 |
+
replaced += _replace_linear_modules(child, quant=quant, group_size=group_size, quantize_lm_head=quantize_lm_head, prefix=child_prefix)
|
| 225 |
+
return replaced
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def apply_quantization(
    model: nn.Module,
    quant: str,
    *,
    group_size: int = 128,
    quantize_lm_head: bool = False,
) -> nn.Module:
    """Return ``model`` quantized according to ``quant``.

    Modes:
        ``none``: the model is returned untouched.
        ``dynamic-int8``: PyTorch dynamic quantization of every ``nn.Linear``;
            returns a new model (``inplace=False``).
        ``weight-int8`` / ``weight-int4``: Linear layers are replaced in place
            by the portable weight-only layers; ``group_size`` and
            ``quantize_lm_head`` apply only to these modes.

    Raises:
        RuntimeError: no quantized engine is available for ``dynamic-int8``,
            or a weight-only mode found no Linear layer to replace.
        ValueError: ``quant`` is not a known mode.
    """
    if quant == "none":
        return model
    if quant == "dynamic-int8":
        # Do not hard-code fbgemm: it is not built on every CPU (e.g. ARM).
        # Prefer it when present, otherwise use another supported engine.
        supported = list(torch.backends.quantized.supported_engines)
        preferred = ("fbgemm", "x86", "qnnpack", "onednn")
        engine = next((name for name in preferred if name in supported), None)
        if engine is None:
            engine = next((name for name in supported if name != "none"), None)
        if engine is None:
            raise RuntimeError("no quantized engine is available for dynamic-int8 on this PyTorch build")
        torch.backends.quantized.engine = engine
        return torch.ao.quantization.quantize_dynamic(model.cpu(), {nn.Linear}, dtype=torch.qint8, inplace=False)
    if quant in {"weight-int8", "weight-int4"}:
        replaced = _replace_linear_modules(model, quant=quant, group_size=group_size, quantize_lm_head=quantize_lm_head)
        if replaced == 0:
            raise RuntimeError("no Linear modules were replaced")
        # Release the float weights that were just swapped out.
        gc.collect()
        return model.cpu().eval()
    raise ValueError(f"unknown quantization mode: {quant}")
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def load_runtime(
    model_dir: Path,
    *,
    quant: str,
    h_cycles: int | None,
    l_cycles: int | None,
    group_size: int,
    quantize_lm_head: bool,
):
    """Load KoHRM on CPU, apply cycle overrides, then quantize.

    Cycle overrides that are not ``None`` are written to the returned ``cfg``
    as well as to ``model.cfg`` and ``model.model.cfg`` before quantization.

    Returns:
        ``(helper, model, tokenizer, cfg)`` with the model in eval mode.
    """
    helper = _load_helper()
    model, tokenizer, cfg = helper.load_kohrm(model_dir, device="cpu")
    for key, override in (("H_cycles", h_cycles), ("L_cycles", l_cycles)):
        if override is None:
            continue
        value = int(override)
        # Write the override to all three config dicts.
        cfg[key] = value
        model.cfg[key] = value
        model.model.cfg[key] = value
    quantized = apply_quantization(model, quant, group_size=group_size, quantize_lm_head=quantize_lm_head)
    return helper, quantized.eval(), tokenizer, cfg
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def generate(
    model: nn.Module,
    tokenizer: Any,
    cfg: dict[str, Any],
    helper: Any,
    prompt: str,
    *,
    max_new_tokens: int,
    min_new_tokens: int,
    max_seq_len: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    condition: str,
) -> tuple[str, RuntimeStats]:
    """Generate text through the helper and measure the call.

    The prompt token count comes from tokenizing the helper-formatted prompt;
    the generated token count comes from re-tokenizing the returned text. Only
    the helper's ``generate_from_loaded`` call is timed. ``quantization`` in
    the returned stats is left empty for the caller to fill in.

    Returns:
        ``(generated_text, stats)``.
    """
    formatted = helper.format_kohrm_prompt(prompt, condition=condition)
    n_prompt = len(tokenizer.encode(formatted, add_special_tokens=False).ids)

    sampling = dict(
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        max_seq_len=max_seq_len,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        condition=condition,
    )
    t0 = time.perf_counter()
    text = helper.generate_from_loaded(model, tokenizer, cfg, prompt, **sampling)
    elapsed = time.perf_counter() - t0

    if text:
        n_out = len(tokenizer.encode(text, add_special_tokens=False).ids)
    else:
        n_out = 0

    if elapsed > 0:
        rate = n_out / elapsed
    else:
        rate = math.nan

    # Report the dtype of the first parameter, if the model has any.
    first_param = next(iter(model.parameters()), None)
    dtype_name = "unknown" if first_param is None else str(first_param.dtype)

    stats = RuntimeStats(
        prompt_tokens=n_prompt,
        generated_tokens=n_out,
        elapsed_s=elapsed,
        tokens_per_s=rate,
        quantization="",
        h_cycles=int(cfg.get("H_cycles", 0)),
        l_cycles=int(cfg.get("L_cycles", 0)),
        dtype=dtype_name,
    )
    return text, stats
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the CPU runtime."""
    parser = argparse.ArgumentParser(description="Run KoHRM-Text on CPU with optional quantization.")
    add = parser.add_argument

    # Model selection and input.
    add("--model", default=DEFAULT_REPO_ID, help="HF repo id or local directory containing KoHRM HF export files.")
    add("--revision", default=None)
    add("--prompt", required=True)

    # Quantization.
    add("--quant", choices=["none", "dynamic-int8", "weight-int8", "weight-int4"], default="dynamic-int8")
    add("--group-size", type=int, default=128)
    add("--quantize-lm-head", action="store_true", help="Also quantize lm_head in portable weight-only modes. Saves memory but slows generation.")

    # Recurrence overrides.
    add("--h-cycles", type=int, default=None, help="Override H_cycles. Lower values trade quality for CPU speed.")
    add("--l-cycles", type=int, default=None, help="Override L_cycles. Lower values trade quality for CPU speed.")

    # Generation and sampling.
    add("--max-new-tokens", type=int, default=128)
    add("--min-new-tokens", type=int, default=0)
    add("--max-seq-len", type=int, default=768)
    add("--temperature", type=float, default=0.0)
    add("--top-p", type=float, default=0.9)
    add("--repetition-penalty", type=float, default=1.05)
    add("--no-repeat-ngram-size", type=int, default=0)
    add("--condition", default="direct", choices=["direct", "cot", "noisy", "synth"])

    # Reporting.
    add("--json-stats", action="store_true")
    return parser
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def main() -> None:
    """CLI entry point: load the model, generate once, print text and stats.

    The generated text goes to stdout. With ``--json-stats`` the run
    statistics are printed as one JSON object on stderr.
    """
    args = build_arg_parser().parse_args()
    # Keep CPU execution predictable on shared machines.
    if "OMP_NUM_THREADS" not in os.environ:
        default_threads = min(8, os.cpu_count() or 8)
        os.environ["OMP_NUM_THREADS"] = str(default_threads)
        # torch is imported before this point, so the environment variable on
        # its own may be set too late to take effect; apply the cap through
        # the torch API as well.
        torch.set_num_threads(default_threads)
    model_dir = resolve_model_dir(args.model, revision=args.revision)
    helper, model, tokenizer, cfg = load_runtime(
        model_dir,
        quant=args.quant,
        h_cycles=args.h_cycles,
        l_cycles=args.l_cycles,
        group_size=args.group_size,
        quantize_lm_head=args.quantize_lm_head,
    )
    output, stats = generate(
        model,
        tokenizer,
        cfg,
        helper,
        args.prompt,
        max_new_tokens=args.max_new_tokens,
        min_new_tokens=args.min_new_tokens,
        max_seq_len=args.max_seq_len,
        temperature=args.temperature,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
        condition=args.condition,
    )
    # generate() leaves the quantization label empty; fill it in here.
    stats.quantization = args.quant
    print(output)
    if args.json_stats:
        print(json.dumps(stats.__dict__, ensure_ascii=False), file=sys.stderr)


if __name__ == "__main__":
    main()
|
inference/requirements-cpu.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.6
|
| 2 |
+
safetensors>=0.4.5
|
| 3 |
+
tokenizers>=0.20
|
| 4 |
+
huggingface_hub>=0.28
|
notebooks/kohrm_colab_generate.py
ADDED
|
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal KoHRM-Text generation runtime for Colab.
|
| 2 |
+
|
| 3 |
+
This file intentionally avoids `transformers` and FlashAttention. It loads the
|
| 4 |
+
public `model.safetensors` export and runs HRM-Text generation with PyTorch
|
| 5 |
+
scaled-dot-product attention. It is built for long pretraining-checkpoint
|
| 6 |
+
knowledge probes on Colab T4 and small CUDA machines.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import math
|
| 13 |
+
import argparse
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Any
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from safetensors.torch import load_file
|
| 21 |
+
from tokenizers import Tokenizer
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Special-token string associated with each generation condition name.
# NOTE(review): how these tokens are inserted into the prompt is defined
# further down the file - confirm against the prompt formatter.
DEFAULT_CONDITION_TOKENS = dict(
    direct="<|object_ref_start|>",
    cot="<|object_ref_end|>",
    noisy="<|quad_start|>",
    synth="<|quad_end|>",
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _rms_norm(x: torch.Tensor, eps: float) -> torch.Tensor:
|
| 33 |
+
return F.rms_norm(x, (x.shape[-1],), eps=eps)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 37 |
+
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
| 38 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _rope_cos_sin(position_ids: torch.Tensor, head_dim: int, theta: float, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
    """Build rotary-embedding cosine and sine tables for the given positions.

    Args:
        position_ids: 2-D tensor of positions, indexed as (batch, time).
        head_dim: Per-head feature size; the tables have this many columns.
        theta: Base of the inverse-frequency schedule.
        dtype: Output dtype; the angles themselves are computed in float32.

    Returns:
        ``(cos, sin)``, each of shape (batch, time, head_dim) for even ``head_dim``.
    """
    exponents = torch.arange(0, head_dim, 2, device=position_ids.device, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    # Outer product of positions and inverse frequencies -> (batch, time, head_dim // 2).
    angles = position_ids.to(torch.float32)[:, :, None] * inv_freq[None, None, :]
    # Duplicate so both halves of each head share the same angle.
    angles = torch.cat((angles, angles), dim=-1)
    cos_table = angles.cos().to(dtype)
    sin_table = angles.sin().to(dtype)
    return cos_table, sin_table
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate ``x`` using the rotary-embedding tables ``cos`` and ``sin``.

    Both tables get a singleton axis inserted at position -2 so they broadcast
    across that axis of ``x``. The result is cast back to ``x.dtype``.
    """
    cos_b = cos.unsqueeze(-2)
    sin_b = sin.unsqueeze(-2)
    rotated = (x * cos_b) + (_rotate_half(x) * sin_b)
    return rotated.to(x.dtype)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class KoHRMAttention(nn.Module):
    """Gated multi-head self-attention with one fused gate/query/key/value projection."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, device: str = "meta") -> None:
        """Create the fused input projection and the output projection (both bias-free)."""
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        fused_width = (4 * num_heads) * head_dim  # gate + q + k + v
        attn_width = num_heads * head_dim
        self.gqkv_proj = nn.Linear(hidden_size, fused_width, bias=False, device=device)
        self.o_proj = nn.Linear(attn_width, hidden_size, bias=False, device=device)

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cache: dict[str, torch.Tensor] | None,
        cache_pos: int,
    ) -> torch.Tensor:
        """Attend over ``x`` (and any cached keys/values) and project back to hidden size.

        Args:
            x: 3-D input, unpacked as (batch, sequence, hidden).
            cos: Rotary cosine table applied to queries and keys.
            sin: Rotary sine table applied to queries and keys.
            cache: Optional dict with preallocated ``"k"`` and ``"v"`` tensors;
                when given, the new keys/values are written at
                ``cache_pos : cache_pos + sequence`` and attention runs over
                everything up to that end position.
            cache_pos: Write offset into the cache along its second axis.

        Returns:
            Tensor of shape (batch, sequence, hidden_size).
        """
        batch, steps, _ = x.shape
        heads = self.num_heads
        fused = self.gqkv_proj(x).view(batch, steps, 4 * heads, self.head_dim)
        # Four equal chunks of `heads` along the head axis: gate, query, key, value.
        gate, query, key, value = torch.split(fused, heads, dim=-2)
        query = _apply_rope(query, cos, sin)
        key = _apply_rope(key, cos, sin)

        if cache is not None:
            stop = cache_pos + steps
            cache["k"][:, cache_pos:stop].copy_(key)
            cache["v"][:, cache_pos:stop].copy_(value)
            key = cache["k"][:, :stop]
            value = cache["v"][:, :stop]

        # NOTE(review): no mask is passed and is_causal=False, so every query
        # position attends to all supplied keys, including later positions of
        # a multi-token chunk. Confirm this matches how the model was trained.
        attended = F.scaled_dot_product_attention(
            query.transpose(1, 2),
            key.transpose(1, 2),
            value.transpose(1, 2),
            dropout_p=0.0,
            is_causal=False,
        ).transpose(1, 2)
        gated = torch.sigmoid(gate) * attended
        return self.o_proj(gated.reshape(batch, steps, heads * self.head_dim))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class KoHRMMLP(nn.Module):
    """Feed-forward layer: fused gate/up projection, SiLU gating, down projection."""

    def __init__(self, hidden_size: int, intermediate_size: int, device: str = "meta") -> None:
        """Create the bias-free fused gate/up projection and the down projection."""
        super().__init__()
        fused_width = 2 * intermediate_size  # gate half followed by up half
        self.gate_up_proj = nn.Linear(hidden_size, fused_width, bias=False, device=device)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(silu(gate) * up)`` where gate/up are halves of the fused projection."""
        fused = self.gate_up_proj(x)
        gate_part, up_part = torch.chunk(fused, 2, dim=-1)
        activated = F.silu(gate_part) * up_part
        return self.down_proj(activated)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class KoHRMBlock(nn.Module):
    """Residual block: RMS-normed attention followed by an RMS-normed MLP."""

    def __init__(self, cfg: dict[str, Any], device: str = "meta") -> None:
        """Build the attention and MLP sub-layers from the model config dict."""
        super().__init__()
        self.eps = float(cfg["rms_norm_eps"])
        hidden = cfg["hidden_size"]
        self.attn = KoHRMAttention(hidden, cfg["num_attention_heads"], cfg["head_dim"], device=device)
        self.mlp = KoHRMMLP(hidden, cfg["intermediate_size"], device=device)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, cache: dict[str, torch.Tensor] | None, cache_pos: int) -> torch.Tensor:
        """Apply both sub-layers, normalizing each sub-layer's input and adding its output back."""
        attn_out = self.attn(_rms_norm(x, self.eps), cos, sin, cache, cache_pos)
        after_attn = x + attn_out
        mlp_out = self.mlp(_rms_norm(after_attn, self.eps))
        return after_attn + mlp_out
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class KoHRMModule(nn.Module):
    """A stack of ``KoHRMBlock`` layers with an additive input injection and a final RMS norm."""

    def __init__(self, cfg: dict[str, Any], num_layers: int, device: str = "meta") -> None:
        """Create ``num_layers`` blocks from the model config dict."""
        super().__init__()
        self.eps = float(cfg["rms_norm_eps"])
        blocks = [KoHRMBlock(cfg, device=device) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_injection: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        caches: list[dict[str, torch.Tensor]] | None,
        cache_pos: int,
    ) -> torch.Tensor:
        """Run the stack on ``hidden_states + input_injection``.

        Args:
            hidden_states: Current state of this module.
            input_injection: Tensor added to ``hidden_states`` before the first layer.
            cos: Rotary cosine table passed to every layer.
            sin: Rotary sine table passed to every layer.
            caches: Optional per-layer key/value caches, indexed by layer position.
            cache_pos: Cache write offset passed to every layer.

        Returns:
            The RMS-normalized output of the last layer.
        """
        state = hidden_states + input_injection
        for layer_idx, block in enumerate(self.layers):
            layer_cache = caches[layer_idx] if caches is not None else None
            state = block(state, cos, sin, layer_cache, cache_pos)
        return _rms_norm(state, self.eps)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class KoHRMCore(nn.Module):
    """Token embedding plus the nested H/L recurrence that produces final hidden states."""

    def __init__(self, cfg: dict[str, Any], num_layers: int, device: str = "meta") -> None:
        """Create the embedding, the ``z_L_init`` buffer and the H and L modules.

        ``z_L_init`` is allocated with ``torch.empty``, so its values are
        expected to come from a loaded state dict.
        """
        super().__init__()
        self.cfg = cfg
        self.embedding_scale = float(cfg.get("embedding_scale", 1.0))
        hidden = cfg["hidden_size"]
        self.embed_tokens = nn.Embedding(cfg["vocab_size"], hidden, device=device)
        self.register_buffer("z_L_init", torch.empty(hidden, device=device), persistent=True)
        self.H_module = KoHRMModule(cfg, num_layers, device=device)
        self.L_module = KoHRMModule(cfg, num_layers, device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        caches: dict[str, list[list[dict[str, torch.Tensor]]]] | None,
        cache_pos: int,
    ) -> torch.Tensor:
        """Embed the tokens and run ``H_cycles`` outer x ``L_cycles`` inner passes.

        Each outer cycle runs the L module ``L_cycles`` times (injecting the
        current H state), then runs the H module once (injecting the L state).
        Every pass uses its own cache slot: ``caches["L"]`` is indexed by
        ``h_idx * L_cycles + l_idx`` and ``caches["H"]`` by ``h_idx``.

        Returns:
            The H state after the last outer cycle.
        """
        embedded = self.embedding_scale * self.embed_tokens(input_ids)
        cos, sin = _rope_cos_sin(position_ids, self.cfg["head_dim"], float(self.cfg["rope_theta"]), embedded.dtype)
        h_state = embedded
        # Broadcast the initial L vector to the shape of the embeddings.
        l_state = self.z_L_init.to(dtype=embedded.dtype).view(1, 1, -1).expand_as(embedded)

        h_cycles = int(self.cfg["H_cycles"])
        l_cycles = int(self.cfg["L_cycles"])
        use_cache = caches is not None
        for h_idx in range(h_cycles):
            for l_idx in range(l_cycles):
                l_cache = caches["L"][h_idx * l_cycles + l_idx] if use_cache else None
                l_state = self.L_module(l_state, h_state, cos, sin, l_cache, cache_pos)
            h_cache = caches["H"][h_idx] if use_cache else None
            h_state = self.H_module(h_state, l_state, cos, sin, h_cache, cache_pos)
        return h_state
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class KoHRMTextForGeneration(nn.Module):
    """``KoHRMCore`` followed by a bias-free language-model head."""

    def __init__(self, cfg: dict[str, Any], num_layers: int, device: str = "meta") -> None:
        """Build the core model and the vocabulary projection from the config dict."""
        super().__init__()
        self.cfg = cfg
        self.num_layers = num_layers
        self.model = KoHRMCore(cfg, num_layers, device=device)
        self.lm_head = nn.Linear(cfg["hidden_size"], cfg["vocab_size"], bias=False, device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        caches: dict[str, list[list[dict[str, torch.Tensor]]]] | None = None,
        cache_pos: int = 0,
    ) -> torch.Tensor:
        """Return vocabulary logits for every input position."""
        final_hidden = self.model(input_ids, position_ids, caches, cache_pos)
        return self.lm_head(final_hidden)

    def init_cache(self, batch_size: int, max_seq_len: int, device: torch.device, dtype: torch.dtype) -> dict[str, list[list[dict[str, torch.Tensor]]]]:
        """Allocate uninitialized key/value buffers for every pass and layer.

        Returns:
            ``{"H": [...], "L": [...]}`` with ``H_cycles`` entries under ``"H"``
            and ``H_cycles * L_cycles`` entries under ``"L"``. Each entry is a
            list of ``num_layers`` dicts holding ``"k"`` and ``"v"`` tensors of
            shape (batch_size, max_seq_len, num_attention_heads, head_dim).
        """
        heads = int(self.cfg["num_attention_heads"])
        head_dim = int(self.cfg["head_dim"])
        kv_shape = (batch_size, max_seq_len, heads, head_dim)

        def new_pass() -> list[dict[str, torch.Tensor]]:
            # One independent k/v pair per layer of a module pass.
            return [
                {name: torch.empty(kv_shape, device=device, dtype=dtype) for name in ("k", "v")}
                for _ in range(self.num_layers)
            ]

        h_passes = int(self.cfg["H_cycles"])
        l_passes = int(self.cfg["H_cycles"]) * int(self.cfg["L_cycles"])
        return {
            "H": [new_pass() for _ in range(h_passes)],
            "L": [new_pass() for _ in range(l_passes)],
        }
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def _module_layer_count(state: dict[str, torch.Tensor], prefix: str) -> int:
    """Infer how many layers a module has from its state-dict key names.

    Keys of the form ``"<prefix>.layers.<index>.<rest>"`` are scanned and the
    count is the highest index plus one.

    Args:
        state: State dict (only the keys are inspected).
        prefix: Module path that precedes ``.layers.`` in the keys.

    Returns:
        The highest layer index found, plus one.

    Raises:
        ValueError: If no key starts with ``"<prefix>.layers."`` or a layer
            index is not an integer.
    """
    marker = f"{prefix}.layers."
    indices = {
        int(key[len(marker):].split(".", 1)[0])
        for key in state
        if key.startswith(marker)
    }
    if not indices:
        # max() on an empty set raises an opaque "empty sequence" error; name the cause.
        raise ValueError(f"No state-dict keys start with {marker!r}; cannot infer the layer count")
    return max(indices) + 1
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def load_kohrm(repo_dir: str | Path, device: str | None = None, max_gpu_memory_gib: float | None = None) -> tuple[KoHRMTextForGeneration, Tokenizer, dict[str, Any]]:
    """Load the config, tokenizer and weights from a local model directory.

    Args:
        repo_dir: Directory containing ``config.json``, ``tokenizer.json`` and
            ``model.safetensors``.
        device: Target device string. When ``None``, CUDA is used if available,
            otherwise CPU.
        max_gpu_memory_gib: When not ``None`` and the target is a CUDA device,
            the free/total GPU memory is printed. The value itself is not
            used as a limit.

    Returns:
        ``(model, tokenizer, cfg)`` with the model in eval mode on the target
        device (float16 on CUDA, float32 otherwise).
    """
    repo_dir = Path(repo_dir)
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    cfg = json.loads((repo_dir / "config.json").read_text(encoding="utf-8"))
    tokenizer = Tokenizer.from_file(str(repo_dir / "tokenizer.json"))

    state = load_file(str(repo_dir / "model.safetensors"), device="cpu")
    num_layers = _module_layer_count(state, "model.H_module")
    # Build on the meta device, then adopt the loaded tensors (assign=True).
    model = KoHRMTextForGeneration(cfg, num_layers=num_layers, device="meta")
    model.load_state_dict(state, strict=True, assign=True)
    del state

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    target = torch.device(device)
    dtype = torch.float16 if target.type == "cuda" else torch.float32
    model = model.to(device=target, dtype=dtype).eval()
    if target.type == "cuda":
        torch.set_float32_matmul_precision("high")
        if max_gpu_memory_gib is not None:
            # Query the device the model was moved to, not whichever is current.
            free, total = torch.cuda.mem_get_info(target)
            print(f"GPU memory free/total GiB: {free / 2**30:.2f}/{total / 2**30:.2f}")
    return model, tokenizer, cfg
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def condition_to_tokens(condition: str = "direct", mapping: dict[str, str] | None = None) -> str:
    """Translate comma-separated HRM-Text condition names into control tokens.

    Args:
        condition: Comma-separated condition names; surrounding whitespace and
            empty entries are ignored.
        mapping: Name-to-token table. A missing or empty mapping falls back to
            ``DEFAULT_CONDITION_TOKENS``.

    Returns:
        The mapped tokens concatenated in input order, or the ``"direct"``
        token when no name was given.

    Raises:
        ValueError: If a name is not present in the mapping.
    """
    table = mapping or DEFAULT_CONDITION_TOKENS
    tokens: list[str] = []
    for name in (part.strip() for part in condition.split(",")):
        if name == "":
            continue
        if name not in table:
            valid = ", ".join(sorted(table))
            raise ValueError(f"Unknown condition {name!r}; expected one of: {valid}")
        tokens.append(table[name])
    # No usable name at all: default to the "direct" condition.
    return "".join(tokens or [table["direct"]])
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def format_kohrm_prompt(
    prompt: str,
    condition: str = "direct",
    condition_token: str | None = None,
) -> str:
    """Wrap a prompt the way upstream InferenceCheckpoint.tokenize_prompt() does.

    Upstream lays prompts out as
    `<boq><condition_tokens><instruction><eoq>`; here the wrapper tokens are
    `<|im_start|>` and `<|im_end|>`.

    For answer-only generation pass condition="direct", which maps to
    `<|object_ref_start|>` in the KoHRM tokenizer. `condition_token` exists
    for backward compatibility; when it is not None it is used verbatim and
    `condition` is ignored.
    """
    if condition_token is not None:
        prefix = condition_token
    else:
        prefix = condition_to_tokens(condition)
    return f"<|im_start|>{prefix}{prompt}<|im_end|>"
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def _apply_repetition_penalty(logits: torch.Tensor, seen_ids: list[int], penalty: float) -> torch.Tensor:
    """Penalize the logits of tokens that already appeared, in place.

    For every distinct id in ``seen_ids``, a negative logit is multiplied by
    ``penalty`` and a non-negative logit is divided by it.

    Args:
        logits: Tensor whose last dimension is indexed by token id; modified in place.
        seen_ids: Token ids seen so far (duplicates are ignored).
        penalty: Penalty factor; values ``<= 1.0`` disable the penalty.

    Returns:
        The same ``logits`` tensor.
    """
    if penalty <= 1.0 or not seen_ids:
        return logits
    # One gather/scatter for all distinct ids instead of a Python loop that
    # issues a separate indexed read and write per token on every decode step.
    token_idx = torch.tensor(sorted(set(seen_ids)), device=logits.device, dtype=torch.long)
    seen_logits = logits[..., token_idx]
    logits[..., token_idx] = torch.where(seen_logits < 0, seen_logits * penalty, seen_logits / penalty)
    return logits
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def _apply_no_repeat_ngram(logits: torch.Tensor, seen_ids: list[int], ngram_size: int) -> torch.Tensor:
    """Forbid tokens that would complete an n-gram already present in ``seen_ids``.

    The last ``ngram_size - 1`` seen ids form the current prefix; every token
    that followed an earlier occurrence of that prefix gets a logit of ``-inf``.
    With ``ngram_size == 1`` the prefix is empty, so every seen token is blocked.

    Args:
        logits: Tensor whose last dimension is indexed by token id; modified in place.
        seen_ids: Token ids seen so far, in order.
        ngram_size: N-gram length; values ``<= 0`` disable the filter.

    Returns:
        The same ``logits`` tensor.
    """
    prefix_len = ngram_size - 1
    if ngram_size <= 0 or len(seen_ids) < prefix_len:
        return logits
    # Slice from an explicit start index: seen_ids[-0:] would be the whole
    # history rather than the empty prefix that ngram_size == 1 requires.
    prefix = tuple(seen_ids[len(seen_ids) - prefix_len:])
    blocked = {
        seen_ids[start + prefix_len]
        for start in range(len(seen_ids) - ngram_size + 1)
        if tuple(seen_ids[start:start + prefix_len]) == prefix
    }
    if blocked:
        logits[..., sorted(blocked)] = -torch.inf
    return logits
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def _sample_next(
    logits: torch.Tensor,
    temperature: float,
    top_p: float,
    seen_ids: list[int] | None = None,
    repetition_penalty: float = 1.0,
    no_repeat_ngram_size: int = 0,
    blocked_ids: set[int] | None = None,
) -> int:
    """Pick the next token id from a single row of logits.

    The logits are converted to float32, then the repetition penalty, the
    no-repeat-n-gram filter and the explicit block list are applied in that
    order. ``temperature <= 0`` selects the arg-max; otherwise a token is
    sampled from the temperature-scaled softmax, optionally restricted by
    ``top_p``.

    Args:
        logits: Logits for one position; the result is read with ``.item()``,
            so exactly one row is expected.
        temperature: Softmax temperature; ``<= 0`` means greedy decoding.
        top_p: Cumulative-probability cutoff; values ``>= 1.0`` disable it.
        seen_ids: Token history used by the repetition filters.
        repetition_penalty: Forwarded to ``_apply_repetition_penalty``.
        no_repeat_ngram_size: Forwarded to ``_apply_no_repeat_ngram``.
        blocked_ids: Token ids whose logits are forced to ``-inf``.

    Returns:
        The chosen token id.
    """
    scores = logits.float()
    history = seen_ids or []
    scores = _apply_repetition_penalty(scores, history, repetition_penalty)
    scores = _apply_no_repeat_ngram(scores, history, no_repeat_ngram_size)
    if blocked_ids:
        scores[..., list(blocked_ids)] = -torch.inf

    if temperature <= 0:
        # Greedy decoding.
        greedy = torch.argmax(scores, dim=-1)
        return int(greedy.item())

    probs = torch.softmax(scores / temperature, dim=-1)
    if top_p < 1.0:
        ranked_probs, ranked_ids = torch.sort(probs, descending=True)
        # Keep tokens whose running total stays within top_p; the most likely
        # token is always kept so the distribution is never empty.
        inside = torch.cumsum(ranked_probs, dim=-1) <= top_p
        inside[..., 0] = True
        ranked_probs = ranked_probs.masked_fill(~inside, 0)
        ranked_probs = ranked_probs / ranked_probs.sum(dim=-1, keepdim=True)
        choice = torch.multinomial(ranked_probs, num_samples=1)
        # Map the position in the sorted order back to a vocabulary id.
        return int(ranked_ids.gather(-1, choice).item())
    return int(torch.multinomial(probs, num_samples=1).item())
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
@torch.inference_mode()
def generate_from_loaded(
    model: KoHRMTextForGeneration,
    tokenizer: Tokenizer,
    cfg: dict[str, Any],
    prompt: str,
    *,
    max_new_tokens: int = 64,
    min_new_tokens: int = 0,
    max_seq_len: int = 512,
    temperature: float = 0.0,
    top_p: float = 0.9,
    repetition_penalty: float = 1.18,
    no_repeat_ngram_size: int = 4,
    condition: str = "direct",
    condition_token: str | None = None,
) -> str:
    """Generate a completion for ``prompt`` with an already loaded model.

    The prompt is wrapped with ``format_kohrm_prompt``, run through the model
    once to fill the key/value caches, and then decoded one token at a time.

    Args:
        model: Loaded generation model; its first parameter decides the device
            and dtype used for inputs and caches.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        cfg: Model config dict; ``eos_token_id`` (an int or a list of ints) is
            read from it when present.
        prompt: Raw prompt text.
        max_new_tokens: Upper bound on generated tokens.
        min_new_tokens: Stop tokens are blocked until this many tokens exist.
        max_seq_len: Cache length; prompt plus generation must fit within it.
        temperature: Sampling temperature; ``<= 0`` means greedy decoding.
        top_p: Cumulative-probability cutoff for sampling.
        repetition_penalty: Penalty factor for already seen tokens.
        no_repeat_ngram_size: N-gram length that must not repeat (0 disables).
        condition: Condition names passed to ``format_kohrm_prompt``.
        condition_token: Raw condition token override.

    Returns:
        The decoded text with special tokens skipped and whitespace stripped.

    Raises:
        ValueError: If the prompt plus ``max_new_tokens`` exceeds ``max_seq_len``.
    """
    first_param = next(model.parameters())
    dev, dtype = first_param.device, first_param.dtype
    wrapped = format_kohrm_prompt(prompt, condition=condition, condition_token=condition_token)
    input_ids = tokenizer.encode(wrapped, add_special_tokens=False).ids
    if len(input_ids) + max_new_tokens + 1 > max_seq_len:
        raise ValueError(f"Prompt plus generation exceeds max_seq_len={max_seq_len}: prompt_tokens={len(input_ids)}")

    # Prefill: run the whole prompt once and keep the logits of its last position.
    caches = model.init_cache(1, max_seq_len, dev, dtype)
    ids = torch.tensor([input_ids], device=dev, dtype=torch.long)
    pos = torch.arange(ids.shape[1], device=dev, dtype=torch.long).unsqueeze(0)
    logits = model(ids, pos, caches=caches, cache_pos=0)[:, -1, :]
    cache_pos = ids.shape[1]

    # Collect stop ids with explicit None checks: a truthiness test would drop
    # a valid id of 0, and int(None) would crash when an id is absent.
    eos_cfg = cfg.get("eos_token_id")
    stop_candidates = list(eos_cfg) if isinstance(eos_cfg, (list, tuple)) else [eos_cfg]
    stop_candidates.append(tokenizer.token_to_id("<|im_end|>"))
    stop_candidates.append(tokenizer.token_to_id("<|box_end|>"))
    stop_ids = {int(token_id) for token_id in stop_candidates if token_id is not None}

    out_ids: list[int] = []
    seen_ids = list(input_ids)
    for step in range(max_new_tokens):
        must_continue = len(out_ids) < min_new_tokens
        next_id = _sample_next(
            logits,
            temperature,
            top_p,
            seen_ids,
            repetition_penalty,
            no_repeat_ngram_size,
            blocked_ids=stop_ids if must_continue else None,
        )
        if next_id in stop_ids and not must_continue:
            break
        out_ids.append(next_id)
        seen_ids.append(next_id)
        if step + 1 == max_new_tokens:
            # Budget exhausted: skip a forward pass whose logits would be discarded.
            break
        token = torch.tensor([[next_id]], device=dev, dtype=torch.long)
        pos = torch.tensor([[cache_pos]], device=dev, dtype=torch.long)
        logits = model(token, pos, caches=caches, cache_pos=cache_pos)[:, -1, :]
        cache_pos += 1

    return tokenizer.decode(out_ids, skip_special_tokens=True).strip()
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
@torch.inference_mode()
def generate_text(
    repo_dir: str | Path,
    prompt: str,
    *,
    max_new_tokens: int = 64,
    min_new_tokens: int = 0,
    max_seq_len: int = 512,
    temperature: float = 0.0,
    top_p: float = 0.9,
    repetition_penalty: float = 1.18,
    no_repeat_ngram_size: int = 4,
    condition: str = "direct",
    condition_token: str | None = None,
    device: str | None = None,
) -> str:
    """Load the model from ``repo_dir`` and generate a completion for ``prompt``.

    The model is loaded on every call via ``load_kohrm`` (with
    ``max_gpu_memory_gib=14.0``); all generation options are forwarded
    unchanged to ``generate_from_loaded``.

    Returns:
        The generated text.
    """
    model, tokenizer, cfg = load_kohrm(repo_dir, device=device, max_gpu_memory_gib=14.0)
    generation_options = dict(
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        max_seq_len=max_seq_len,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        condition=condition,
        condition_token=condition_token,
    )
    return generate_from_loaded(model, tokenizer, cfg, prompt, **generation_options)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def main() -> None:
    """Parse command-line options, run one generation and print the result."""
    parser = argparse.ArgumentParser(description="Run a KoHRM-Text long generation probe without transformers.")
    parser.add_argument("repo_dir", type=Path, help="Directory containing config.json, tokenizer.json, and model.safetensors")
    # NOTE(review): the default prompt below was restored from mis-decoded
    # (mojibake) Korean text; verify the wording against the upstream file.
    parser.add_argument(
        "--prompt",
        default=(
            "다음은 한국어 위키백과 문서 원문 일부입니다. 백과사전식 한국어, "
            "고유명사, 날짜, 기술/사회/문화 지식을 그대로 학습하십시오.\n\n"
            "[문서명]\n훈민정음\n\n[부분]\n1/1"
        ),
    )
    parser.add_argument("--max-new-tokens", type=int, default=384)
    parser.add_argument("--min-new-tokens", type=int, default=160)
    parser.add_argument("--max-seq-len", type=int, default=1536)
    parser.add_argument("--temperature", type=float, default=0.65)
    parser.add_argument("--top-p", type=float, default=0.92)
    parser.add_argument("--repetition-penalty", type=float, default=1.05)
    parser.add_argument("--no-repeat-ngram-size", type=int, default=0)
    parser.add_argument(
        "--condition",
        default="direct",
        help="Comma-separated HRM-Text condition names: direct, cot, noisy, synth. Use direct for answer-only outputs.",
    )
    parser.add_argument(
        "--condition-token",
        default=None,
        help="Optional raw condition token override. Normally use --condition direct instead.",
    )
    parser.add_argument("--device", default=None)
    args = parser.parse_args()
    print(generate_text(
        args.repo_dir,
        args.prompt,
        max_new_tokens=args.max_new_tokens,
        min_new_tokens=args.min_new_tokens,
        max_seq_len=args.max_seq_len,
        temperature=args.temperature,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
        condition=args.condition,
        condition_token=args.condition_token,
        device=args.device,
    ))
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
# Run the command-line probe only when this file is executed as a script.
if __name__ == "__main__":
    main()
|