gyung committed on
Commit 24cac6a · verified · 1 Parent(s): 9345300

Add KoHRM CPU quantized runtime pack

README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ license: apache-2.0
+ base_model: LLM-OS-Models/KoHRM-Text-1.4B
+ base_model_relation: quantized
+ library_name: pytorch
+ tags:
+ - kohrm
+ - hrm-text
+ - cpu
+ - int8
+ - int4
+ - korean
+ - terminal
+ ---
+
+ # KoHRM-Text-1.4B CPU Runtime
+
+ This repository contains a CPU-oriented inference runtime for
+ `LLM-OS-Models/KoHRM-Text-1.4B`.
+
+ It does not duplicate the original model weights. The runtime downloads the
+ base model from Hugging Face and applies CPU quantization at load time.
+
+ # KoHRM-Text CPU Runtime Pack
+
+ Written: `2026-06-09`
+
+ ## Conclusion
+
+ `LLM-OS-Models/KoHRM-Text-1.4B` cannot currently be converted directly to GGUF.
+
+ The reason is that the model is not a standard Llama/Qwen/Gemma-family architecture but the custom architecture below.
+
+ ```text
+ model_type: hrm_text
+ architectures: HrmTextForCausalLM
+ H_cycles: 2
+ L_cycles: 3
+ prefix_lm: true
+ ```
+
+ Running the llama.cpp converter on it directly fails at the following point.
+
+ ```text
+ ERROR:hf-to-gguf:Model HrmTextForCausalLM is not supported
+ ```
+
+ So the practical CPU path for now is a dedicated PyTorch runtime rather than GGUF.
+
+ ## Added files
+
+ ```text
+ HRM-Text/inference/kohrm_cpu_runtime.py
+ HRM-Text/inference/requirements-cpu.txt
+ HRM-Text/scripts/upload_kohrm_cpu_runtime_pack.py
+ ```
+
+ This runtime reuses the direct safetensors loading path from the existing `HRM-Text/notebooks/kohrm_colab_generate.py` and adds CPU quantization and an H/L cycle override.
+
+ ## Usage
+
+ The recommended default is `dynamic-int8`.
+
+ ```bash
+ cd /home/work/.projects/LLM-OS-Models/Terminal
+
+ CUDA_VISIBLE_DEVICES= OMP_NUM_THREADS=8 \
+ python HRM-Text/inference/kohrm_cpu_runtime.py \
+   --model LLM-OS-Models/KoHRM-Text-1.4B \
+   --quant dynamic-int8 \
+   --prompt "리눅스에서 현재 디렉토리 파일 목록을 보는 명령어는?" \
+   --max-new-tokens 128 \
+   --max-seq-len 768 \
+   --temperature 0
+ ```
+
+ On a machine with 16GB of CPU RAM, try the modes in the following order.
+
+ ```text
+ 1st choice: dynamic-int8
+ 2nd choice: none
+ 3rd choice: weight-int4
+ ```
+
+ `dynamic-int8` uses PyTorch CPU dynamic quantization. It generally gives the best balance of memory and speed.
+
+ `weight-int4` is a hand-written portable 4-bit weight-only fallback. It reduces memory, but every forward pass unpacks and dequantizes the weights, so it is very slow. Use it only when the model absolutely has to run in a small memory footprint.
+
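The nibble packing behind `weight-int4` can be illustrated in plain Python. This is a minimal sketch, not the runtime's code: it mirrors the arithmetic of `WeightOnlyInt4Linear` (per-group scale of max|w| / 7, round and clamp to [-8, 7], two values per byte with the low nibble first) on a flat list of floats instead of a tensor. The function names `quantize_group` and `dequantize_group` are hypothetical.

```python
def quantize_group(values):
    """Quantize one group of floats to packed signed 4-bit values."""
    if len(values) % 2:
        raise ValueError("group length must be even to pack two nibbles per byte")
    # Symmetric scale: the largest magnitude in the group maps to 7.
    scale = max(max(abs(v) for v in values), 1e-8) / 7.0
    q = [max(-8, min(7, round(v / scale))) for v in values]
    # Store as two's-complement nibbles: -1 -> 15, -8 -> 8.
    nibbles = [(x + 16) % 16 for x in q]
    packed = bytes(lo | (hi << 4) for lo, hi in zip(nibbles[0::2], nibbles[1::2]))
    return packed, scale


def dequantize_group(packed, scale):
    """Unpack nibbles, restore the sign, and rescale to floats."""
    out = []
    for byte in packed:
        for nib in (byte & 0x0F, (byte >> 4) & 0x0F):
            signed = nib - 16 if nib >= 8 else nib
            out.append(signed * scale)
    return out


weights = [0.7, -0.3, 0.1, -0.7]
packed, scale = quantize_group(weights)
restored = dequantize_group(packed, scale)
```

Four floats become two bytes plus one scale. The unpack and rescale in `dequantize_group` is the step the runtime repeats on every forward pass, which is why this mode trades speed for memory.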
+ ## H/L cycle override
+
+ KoHRM applies the same H and L modules repeatedly. The default is `H=2`, `L=3`.
+
+ On CPU, you can reduce the number of iterations as shown below to increase speed.
+
+ ```bash
+ CUDA_VISIBLE_DEVICES= OMP_NUM_THREADS=8 \
+ python HRM-Text/inference/kohrm_cpu_runtime.py \
+   --model LLM-OS-Models/KoHRM-Text-1.4B \
+   --quant dynamic-int8 \
+   --h-cycles 1 \
+   --l-cycles 1 \
+   --prompt "리눅스에서 현재 디렉토리 파일 목록을 보는 명령어는?" \
+   --max-new-tokens 128 \
+   --max-seq-len 768 \
+   --temperature 0
+ ```
+
+ The trade-off is straightforward.
+
+ - `H=2,L=3`: the original quality path.
+ - `H=1,L=1`: the CPU speed-first path.
+ - Reducing the cycles can lower output quality.
+
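How much work a cycle override saves follows from the loop in `notebooks/kohrm_colab_generate.py`: each H cycle runs the L module `L_cycles` times and then the H module once. A small sketch of that count, assuming this loop is the only place the two modules are invoked (the helper name `module_passes` is hypothetical):

```python
def module_passes(h_cycles: int, l_cycles: int) -> dict[str, int]:
    """Count H and L module invocations for one forward call."""
    if h_cycles < 1 or l_cycles < 1:
        raise ValueError("cycle counts must be positive")
    l_passes = h_cycles * l_cycles  # L module runs l_cycles times per H cycle
    h_passes = h_cycles             # H module runs once per H cycle
    return {"L": l_passes, "H": h_passes, "total": l_passes + h_passes}


default = module_passes(2, 3)
fast = module_passes(1, 1)
ratio = default["total"] / fast["total"]
```

Under this count the default runs 8 module passes per forward call and `H=1,L=1` runs 2. The embedding lookup and `lm_head` are unaffected by the override, so the end-to-end speedup is likely smaller than that ratio suggests.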
+ ## Smoke test results
+
+ All runs use the same short prompt with `max_new_tokens=4`, `max_seq_len=128`, and `OMP_NUM_THREADS=8`.
+
+ ```text
+ none:
+   elapsed: 1.48s
+   speed: 2.69 tok/s
+   cycles: H=2, L=3
+
+ dynamic-int8:
+   elapsed: 0.53s
+   speed: 7.59 tok/s
+   cycles: H=2, L=3
+
+ dynamic-int8 + H=1,L=1:
+   elapsed: 0.24s
+   speed: 8.18 tok/s
+   cycles: H=1, L=1
+
+ weight-int4:
+   elapsed: 23.25s
+   speed: 0.17 tok/s
+   cycles: H=2, L=3
+ ```
+
+ This is a short smoke test, so the absolute numbers are for reference only. The direction is clear, though.
+
+ ```text
+ Everyday use: dynamic-int8
+ Forced memory savings: weight-int4
+ Preserve quality: H=2,L=3
+ Prioritize speed: H=1,L=1
+ ```
+
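The reported speeds can be cross-checked against the elapsed times, since tokens per second is generated tokens divided by elapsed seconds. The sketch below only does that arithmetic on the four smoke test rows. The implied token counts are derived values, not measurements from the source.

```python
# (elapsed seconds, tokens per second) copied from the smoke test results.
runs = {
    "none": (1.48, 2.69),
    "dynamic-int8": (0.53, 7.59),
    "dynamic-int8 + H=1,L=1": (0.24, 8.18),
    "weight-int4": (23.25, 0.17),
}


def implied_tokens(elapsed_s: float, tokens_per_s: float) -> float:
    """Invert tok/s = tokens / elapsed to recover the token count."""
    return elapsed_s * tokens_per_s


implied = {name: round(implied_tokens(*row)) for name, row in runs.items()}
speedup_int8 = runs["dynamic-int8"][1] / runs["none"][1]
```

Three rows imply about 4 generated tokens, while the `H=1,L=1` row implies about 2, so that run appears to have stopped early. For that reason the rows are better compared by tok/s than by elapsed time.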
+ ## Why GGUF is hard
+
+ A GGUF file is not just a container for weights. llama.cpp also has to know the forward pass of the architecture.
+
+ KoHRM is not a model that stacks ordinary Transformer blocks and runs each one once.
+
+ - It has an H module and an L module.
+ - The modules are applied recurrently, `H_cycles` and `L_cycles` times.
+ - PrefixLM formatting and stop token handling are different.
+ - The KV cache layout also differs from an ordinary chat causal LM.
+
+ Producing a proper GGUF therefore requires the following work.
+
+ ```text
+ 1. Add HRM_TEXT to llama.cpp MODEL_ARCH
+ 2. Implement the H/L recurrent forward
+ 3. Implement gqkv gated attention
+ 4. Handle the PrefixLM prompt/token boundary
+ 5. Register the tokenizer pre-tokenizer hash
+ 6. Write the quantized tensor name mapping
+ 7. llama-cli generation smoke test
+ ```
+
+ This is not something a simple converter patch can solve.
+
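The KV cache difference can be made concrete from `init_cache` in the helper, which allocates one K/V pair per layer for every module pass rather than one pair per layer. A rough sketch of the resulting size follows. The function name `kv_cache_bytes`, the float32 element size, and every dimension in `dims` are assumptions for illustration. The real layer count, head count, and head dimension come from the checkpoint and `config.json`.

```python
def kv_cache_bytes(
    h_cycles: int,
    l_cycles: int,
    num_layers: int,
    max_seq_len: int,
    heads: int,
    head_dim: int,
    batch_size: int = 1,
    bytes_per_elem: int = 4,  # assumption: float32 cache on CPU
) -> int:
    """Bytes held by K and V buffers across all H and L module passes."""
    passes = h_cycles + h_cycles * l_cycles  # H passes plus L passes
    buffers = passes * num_layers * 2        # one K and one V per layer per pass
    elems_per_buffer = batch_size * max_seq_len * heads * head_dim
    return buffers * elems_per_buffer * bytes_per_elem


# Hypothetical dimensions, chosen only to make the arithmetic concrete.
dims = dict(num_layers=4, max_seq_len=768, heads=16, head_dim=64)
default_bytes = kv_cache_bytes(2, 3, **dims)
fast_bytes = kv_cache_bytes(1, 1, **dims)
```

With these made-up dimensions the default cycles hold 192 MiB of cache and `H=1,L=1` holds 48 MiB. A conventional causal LM with the same layer count would need only one pass worth of buffers, which is part of what a llama.cpp port would have to model.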
+ ## HF CPU pack
+
+ The CPU runtime pack is uploaded to HF separately, without uploading the weights a second time.
+
+ Target repo:
+
+ ```text
+ LLM-OS-Models/KoHRM-Text-1.4B-CPU-Runtime
+ ```
+
+ This repo contains only the following.
+
+ ```text
+ README.md
+ inference/kohrm_cpu_runtime.py
+ inference/requirements-cpu.txt
+ notebooks/kohrm_colab_generate.py
+ ```
+
+ The weights are fetched from the original repo at run time.
+
+ ```text
+ LLM-OS-Models/KoHRM-Text-1.4B
+ ```
+
+ With shared machines in mind, the HF token is read from `.env` but never printed.
inference/kohrm_cpu_runtime.py ADDED
@@ -0,0 +1,378 @@
+ """CPU-oriented KoHRM-Text inference runtime.
+
+ KoHRM-Text uses the custom ``hrm_text`` / ``HrmTextForCausalLM`` architecture,
+ so it cannot currently be served by llama.cpp/GGUF or ordinary vLLM paths.
+ This runtime wraps the existing safetensors loader and adds CPU-friendly
+ quantization and cycle overrides.
+
+ Recommended mode for normal CPU use:
+
+     python HRM-Text/inference/kohrm_cpu_runtime.py \
+         --model LLM-OS-Models/KoHRM-Text-1.4B \
+         --quant dynamic-int8 \
+         --prompt "리눅스에서 현재 디렉토리 파일 목록을 보는 명령어는?" \
+         --max-new-tokens 64
+
+ Experimental memory-first mode:
+
+     python HRM-Text/inference/kohrm_cpu_runtime.py --quant weight-int4 ...
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import gc
+ import importlib.util
+ import json
+ import math
+ import os
+ import shutil
+ import sys
+ import time
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from huggingface_hub import snapshot_download
+
+
+ REPO_ROOT = Path(__file__).resolve().parents[1]
+ HELPER_PATH = REPO_ROOT / "notebooks" / "kohrm_colab_generate.py"
+ DEFAULT_REPO_ID = "LLM-OS-Models/KoHRM-Text-1.4B"
+
+
+ def _load_helper():
+     if not HELPER_PATH.exists():
+         raise FileNotFoundError(f"missing KoHRM helper: {HELPER_PATH}")
+     spec = importlib.util.spec_from_file_location("kohrm_colab_generate", HELPER_PATH)
+     if spec is None or spec.loader is None:
+         raise RuntimeError(f"cannot import helper from {HELPER_PATH}")
+     module = importlib.util.module_from_spec(spec)
+     sys.modules.setdefault("kohrm_colab_generate", module)
+     spec.loader.exec_module(module)
+     return module
+
+
+ def _read_dotenv_token() -> str | None:
+     """Read a local HF token without printing it or exporting it to shell logs."""
+     candidates = [
+         Path.cwd() / ".env",
+         REPO_ROOT.parent / ".env",
+         REPO_ROOT / ".env",
+         Path.home() / ".cache" / "huggingface" / "token",
+     ]
+     for path in candidates:
+         if not path.exists():
+             continue
+         if path.name == "token":
+             token = path.read_text(encoding="utf-8").strip()
+             return token or None
+         for raw in path.read_text(encoding="utf-8", errors="ignore").splitlines():
+             line = raw.strip()
+             if not line or line.startswith("#") or "=" not in line:
+                 continue
+             key, value = line.split("=", 1)
+             key = key.strip()
+             if key.startswith("export "):
+                 key = key.split(None, 1)[1]
+             if key in {"HF_TOKEN", "HUGGINGFACE_TOKEN", "HUGGING_FACE_HUB_TOKEN"}:
+                 token = value.strip().strip('"').strip("'")
+                 return token or None
+     return os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
+
+
+ def resolve_model_dir(model: str, revision: str | None = None) -> Path:
+     path = Path(model).expanduser()
+     if path.exists():
+         return path
+     token = _read_dotenv_token()
+     return Path(
+         snapshot_download(
+             repo_id=model,
+             revision=revision,
+             allow_patterns=["config.json", "tokenizer.json", "tokenizer_config.json", "model.safetensors", "README.md"],
+             token=token,
+         )
+     )
+
+
+ @dataclass
+ class RuntimeStats:
+     prompt_tokens: int
+     generated_tokens: int
+     elapsed_s: float
+     tokens_per_s: float
+     quantization: str
+     h_cycles: int
+     l_cycles: int
+     dtype: str
+
+
+ class WeightOnlyInt8Linear(nn.Module):
+     """Simple symmetric per-group int8 weight-only Linear.
+
+     This is a portability fallback, not an optimized kernel. It reduces resident
+     weight memory after conversion, but dequantizes on forward. For speed, prefer
+     PyTorch dynamic int8.
+     """
+
+     def __init__(self, qweight: torch.Tensor, scales: torch.Tensor, in_features: int, out_features: int, group_size: int) -> None:
+         super().__init__()
+         self.in_features = int(in_features)
+         self.out_features = int(out_features)
+         self.group_size = int(group_size)
+         self.register_buffer("qweight", qweight.contiguous())
+         self.register_buffer("scales", scales.contiguous())
+
+     @classmethod
+     def from_linear(cls, linear: nn.Linear, group_size: int = 128) -> "WeightOnlyInt8Linear":
+         weight = linear.weight.detach().to(dtype=torch.float32, device="cpu")
+         out_features, in_features = weight.shape
+         pad = (-in_features) % group_size
+         if pad:
+             weight = F.pad(weight, (0, pad))
+         grouped = weight.view(out_features, -1, group_size)
+         scales = grouped.abs().amax(dim=-1).clamp_min(1e-8) / 127.0
+         qweight = torch.round(grouped / scales.unsqueeze(-1)).clamp(-127, 127).to(torch.int8)
+         return cls(qweight=qweight, scales=scales.to(torch.float16), in_features=in_features, out_features=out_features, group_size=group_size)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         weight = (self.qweight.to(torch.float32) * self.scales.to(torch.float32).unsqueeze(-1)).view(self.out_features, -1)
+         weight = weight[:, : self.in_features].to(dtype=x.dtype)
+         return F.linear(x, weight)
+
+
+ class WeightOnlyInt4Linear(nn.Module):
+     """Portable symmetric per-group int4 weight-only Linear.
+
+     Values are stored as packed signed nibbles. Forward unpacks and dequantizes
+     on CPU, so this is memory-first rather than speed-first.
+     """
+
+     def __init__(self, packed: torch.Tensor, scales: torch.Tensor, in_features: int, out_features: int, padded_features: int, group_size: int) -> None:
+         super().__init__()
+         self.in_features = int(in_features)
+         self.out_features = int(out_features)
+         self.padded_features = int(padded_features)
+         self.group_size = int(group_size)
+         self.register_buffer("packed", packed.contiguous())
+         self.register_buffer("scales", scales.contiguous())
+
+     @classmethod
+     def from_linear(cls, linear: nn.Linear, group_size: int = 128) -> "WeightOnlyInt4Linear":
+         weight = linear.weight.detach().to(dtype=torch.float32, device="cpu")
+         out_features, in_features = weight.shape
+         pad_group = (-in_features) % group_size
+         if pad_group:
+             weight = F.pad(weight, (0, pad_group))
+         if weight.shape[1] % 2:
+             weight = F.pad(weight, (0, 1))
+         padded_features = weight.shape[1]
+         grouped = weight.view(out_features, -1, group_size)
+         scales = grouped.abs().amax(dim=-1).clamp_min(1e-8) / 7.0
+         q = torch.round(grouped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int16)
+         q = (q + 16).remainder(16).to(torch.uint8).view(out_features, padded_features)
+         low = q[:, 0::2]
+         high = q[:, 1::2] << 4
+         packed = low | high
+         return cls(
+             packed=packed,
+             scales=scales.to(torch.float16),
+             in_features=in_features,
+             out_features=out_features,
+             padded_features=padded_features,
+             group_size=group_size,
+         )
+
+     def _unpack(self) -> torch.Tensor:
+         low = self.packed & 0x0F
+         high = (self.packed >> 4) & 0x0F
+         q = torch.empty((self.out_features, self.packed.shape[1] * 2), dtype=torch.int16, device=self.packed.device)
+         q[:, 0::2] = low.to(torch.int16)
+         q[:, 1::2] = high.to(torch.int16)
+         q = torch.where(q >= 8, q - 16, q)
+         return q[:, : self.padded_features]
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         q = self._unpack().to(torch.float32)
+         weight = (q.view(self.out_features, -1, self.group_size) * self.scales.to(torch.float32).unsqueeze(-1)).view(self.out_features, -1)
+         weight = weight[:, : self.in_features].to(dtype=x.dtype)
+         return F.linear(x, weight)
+
+
+ def _replace_linear_modules(module: nn.Module, *, quant: str, group_size: int, quantize_lm_head: bool, prefix: str = "") -> int:
+     replaced = 0
+     for name, child in list(module.named_children()):
+         child_prefix = f"{prefix}.{name}" if prefix else name
+         if isinstance(child, nn.Linear):
+             if child_prefix == "lm_head" and not quantize_lm_head:
+                 continue
+             if child.bias is not None:
+                 raise ValueError(f"bias is not supported by portable weight-only quantization: {child_prefix}")
+             if quant == "weight-int8":
+                 new_child = WeightOnlyInt8Linear.from_linear(child, group_size=group_size)
+             elif quant == "weight-int4":
+                 new_child = WeightOnlyInt4Linear.from_linear(child, group_size=group_size)
+             else:
+                 raise ValueError(f"unsupported weight-only quantization: {quant}")
+             setattr(module, name, new_child)
+             replaced += 1
+         else:
+             replaced += _replace_linear_modules(child, quant=quant, group_size=group_size, quantize_lm_head=quantize_lm_head, prefix=child_prefix)
+     return replaced
+
+
+ def apply_quantization(
+     model: nn.Module,
+     quant: str,
+     *,
+     group_size: int = 128,
+     quantize_lm_head: bool = False,
+ ) -> nn.Module:
+     if quant == "none":
+         return model
+     if quant == "dynamic-int8":
+         torch.backends.quantized.engine = "fbgemm"
+         return torch.ao.quantization.quantize_dynamic(model.cpu(), {nn.Linear}, dtype=torch.qint8, inplace=False)
+     if quant in {"weight-int8", "weight-int4"}:
+         replaced = _replace_linear_modules(model, quant=quant, group_size=group_size, quantize_lm_head=quantize_lm_head)
+         if replaced == 0:
+             raise RuntimeError("no Linear modules were replaced")
+         gc.collect()
+         return model.cpu().eval()
+     raise ValueError(f"unknown quantization mode: {quant}")
+
+
+ def load_runtime(
+     model_dir: Path,
+     *,
+     quant: str,
+     h_cycles: int | None,
+     l_cycles: int | None,
+     group_size: int,
+     quantize_lm_head: bool,
+ ):
+     helper = _load_helper()
+     model, tokenizer, cfg = helper.load_kohrm(model_dir, device="cpu")
+     if h_cycles is not None:
+         cfg["H_cycles"] = int(h_cycles)
+         model.cfg["H_cycles"] = int(h_cycles)
+         model.model.cfg["H_cycles"] = int(h_cycles)
+     if l_cycles is not None:
+         cfg["L_cycles"] = int(l_cycles)
+         model.cfg["L_cycles"] = int(l_cycles)
+         model.model.cfg["L_cycles"] = int(l_cycles)
+     model = apply_quantization(model, quant, group_size=group_size, quantize_lm_head=quantize_lm_head)
+     return helper, model.eval(), tokenizer, cfg
+
+
+ def generate(
+     model: nn.Module,
+     tokenizer: Any,
+     cfg: dict[str, Any],
+     helper: Any,
+     prompt: str,
+     *,
+     max_new_tokens: int,
+     min_new_tokens: int,
+     max_seq_len: int,
+     temperature: float,
+     top_p: float,
+     repetition_penalty: float,
+     no_repeat_ngram_size: int,
+     condition: str,
+ ) -> tuple[str, RuntimeStats]:
+     wrapped = helper.format_kohrm_prompt(prompt, condition=condition)
+     prompt_tokens = len(tokenizer.encode(wrapped, add_special_tokens=False).ids)
+     start = time.perf_counter()
+     output = helper.generate_from_loaded(
+         model,
+         tokenizer,
+         cfg,
+         prompt,
+         max_new_tokens=max_new_tokens,
+         min_new_tokens=min_new_tokens,
+         max_seq_len=max_seq_len,
+         temperature=temperature,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         no_repeat_ngram_size=no_repeat_ngram_size,
+         condition=condition,
+     )
+     elapsed = time.perf_counter() - start
+     out_tokens = len(tokenizer.encode(output, add_special_tokens=False).ids) if output else 0
+     stats = RuntimeStats(
+         prompt_tokens=prompt_tokens,
+         generated_tokens=out_tokens,
+         elapsed_s=elapsed,
+         tokens_per_s=(out_tokens / elapsed if elapsed > 0 else math.nan),
+         quantization="",
+         h_cycles=int(cfg.get("H_cycles", 0)),
+         l_cycles=int(cfg.get("L_cycles", 0)),
+         dtype=str(next(model.parameters()).dtype) if any(True for _ in model.parameters()) else "unknown",
+     )
+     return output, stats
+
+
+ def build_arg_parser() -> argparse.ArgumentParser:
+     ap = argparse.ArgumentParser(description="Run KoHRM-Text on CPU with optional quantization.")
+     ap.add_argument("--model", default=DEFAULT_REPO_ID, help="HF repo id or local directory containing KoHRM HF export files.")
+     ap.add_argument("--revision", default=None)
+     ap.add_argument("--prompt", required=True)
+     ap.add_argument("--quant", choices=["none", "dynamic-int8", "weight-int8", "weight-int4"], default="dynamic-int8")
+     ap.add_argument("--group-size", type=int, default=128)
+     ap.add_argument("--quantize-lm-head", action="store_true", help="Also quantize lm_head in portable weight-only modes. Saves memory but slows generation.")
+     ap.add_argument("--h-cycles", type=int, default=None, help="Override H_cycles. Lower values trade quality for CPU speed.")
+     ap.add_argument("--l-cycles", type=int, default=None, help="Override L_cycles. Lower values trade quality for CPU speed.")
+     ap.add_argument("--max-new-tokens", type=int, default=128)
+     ap.add_argument("--min-new-tokens", type=int, default=0)
+     ap.add_argument("--max-seq-len", type=int, default=768)
+     ap.add_argument("--temperature", type=float, default=0.0)
+     ap.add_argument("--top-p", type=float, default=0.9)
+     ap.add_argument("--repetition-penalty", type=float, default=1.05)
+     ap.add_argument("--no-repeat-ngram-size", type=int, default=0)
+     ap.add_argument("--condition", default="direct", choices=["direct", "cot", "noisy", "synth"])
+     ap.add_argument("--json-stats", action="store_true")
+     return ap
+
+
+ def main() -> None:
+     args = build_arg_parser().parse_args()
+     # Keep CPU execution predictable on shared machines.
+     if "OMP_NUM_THREADS" not in os.environ:
+         os.environ["OMP_NUM_THREADS"] = str(min(8, os.cpu_count() or 8))
+     model_dir = resolve_model_dir(args.model, revision=args.revision)
+     helper, model, tokenizer, cfg = load_runtime(
+         model_dir,
+         quant=args.quant,
+         h_cycles=args.h_cycles,
+         l_cycles=args.l_cycles,
+         group_size=args.group_size,
+         quantize_lm_head=args.quantize_lm_head,
+     )
+     output, stats = generate(
+         model,
+         tokenizer,
+         cfg,
+         helper,
+         args.prompt,
+         max_new_tokens=args.max_new_tokens,
+         min_new_tokens=args.min_new_tokens,
+         max_seq_len=args.max_seq_len,
+         temperature=args.temperature,
+         top_p=args.top_p,
+         repetition_penalty=args.repetition_penalty,
+         no_repeat_ngram_size=args.no_repeat_ngram_size,
+         condition=args.condition,
+     )
+     stats.quantization = args.quant
+     print(output)
+     if args.json_stats:
+         print(json.dumps(stats.__dict__, ensure_ascii=False), file=sys.stderr)
+
+
+ if __name__ == "__main__":
+     main()
inference/requirements-cpu.txt ADDED
@@ -0,0 +1,4 @@
+ torch>=2.6
+ safetensors>=0.4.5
+ tokenizers>=0.20
+ huggingface_hub>=0.28
notebooks/kohrm_colab_generate.py ADDED
@@ -0,0 +1,474 @@
+ """Minimal KoHRM-Text generation runtime for Colab.
+
+ This file intentionally avoids `transformers` and FlashAttention. It loads the
+ public `model.safetensors` export and runs HRM-Text generation with PyTorch
+ scaled-dot-product attention. It is built for long pretraining-checkpoint
+ knowledge probes on Colab T4 and small CUDA machines.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import math
+ import argparse
+ from pathlib import Path
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from safetensors.torch import load_file
+ from tokenizers import Tokenizer
+
+
+ DEFAULT_CONDITION_TOKENS = {
+     "direct": "<|object_ref_start|>",
+     "cot": "<|object_ref_end|>",
+     "noisy": "<|quad_start|>",
+     "synth": "<|quad_end|>",
+ }
+
+
+ def _rms_norm(x: torch.Tensor, eps: float) -> torch.Tensor:
+     return F.rms_norm(x, (x.shape[-1],), eps=eps)
+
+
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
+     x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def _rope_cos_sin(position_ids: torch.Tensor, head_dim: int, theta: float, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
+     inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=position_ids.device, dtype=torch.float32) / head_dim))
+     freqs = torch.einsum("bt,d->btd", position_ids.to(torch.float32), inv_freq)
+     emb = torch.cat((freqs, freqs), dim=-1)
+     return emb.cos().to(dtype), emb.sin().to(dtype)
+
+
+ def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+     return ((x * cos.unsqueeze(-2)) + (_rotate_half(x) * sin.unsqueeze(-2))).to(x.dtype)
+
+
+ class KoHRMAttention(nn.Module):
+     def __init__(self, hidden_size: int, num_heads: int, head_dim: int, device: str = "meta") -> None:
+         super().__init__()
+         self.num_heads = num_heads
+         self.head_dim = head_dim
+         self.gqkv_proj = nn.Linear(hidden_size, (4 * num_heads) * head_dim, bias=False, device=device)
+         self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False, device=device)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cos: torch.Tensor,
+         sin: torch.Tensor,
+         cache: dict[str, torch.Tensor] | None,
+         cache_pos: int,
+     ) -> torch.Tensor:
+         bsz, seqlen, _ = x.shape
+         gqkv = self.gqkv_proj(x).view(bsz, seqlen, 4 * self.num_heads, self.head_dim)
+         gate, q, k, v = gqkv.split((self.num_heads, self.num_heads, self.num_heads, self.num_heads), dim=-2)
+         q = _apply_rope(q, cos, sin)
+         k = _apply_rope(k, cos, sin)
+
+         if cache is not None:
+             end = cache_pos + seqlen
+             cache["k"][:, cache_pos:end].copy_(k)
+             cache["v"][:, cache_pos:end].copy_(v)
+             k = cache["k"][:, :end]
+             v = cache["v"][:, :end]
+
+         q = q.transpose(1, 2)
+         k = k.transpose(1, 2)
+         v = v.transpose(1, 2)
+         y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
+         y = y.transpose(1, 2)
+         y = (torch.sigmoid(gate) * y).reshape(bsz, seqlen, self.num_heads * self.head_dim)
+         return self.o_proj(y)
+
+
+ class KoHRMMLP(nn.Module):
+     def __init__(self, hidden_size: int, intermediate_size: int, device: str = "meta") -> None:
+         super().__init__()
+         self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False, device=device)
+         self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False, device=device)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
+         return self.down_proj(F.silu(gate) * up)
+
+
+ class KoHRMBlock(nn.Module):
+     def __init__(self, cfg: dict[str, Any], device: str = "meta") -> None:
+         super().__init__()
+         self.eps = float(cfg["rms_norm_eps"])
+         self.attn = KoHRMAttention(cfg["hidden_size"], cfg["num_attention_heads"], cfg["head_dim"], device=device)
+         self.mlp = KoHRMMLP(cfg["hidden_size"], cfg["intermediate_size"], device=device)
+
+     def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, cache: dict[str, torch.Tensor] | None, cache_pos: int) -> torch.Tensor:
+         x = x + self.attn(_rms_norm(x, self.eps), cos, sin, cache, cache_pos)
+         x = x + self.mlp(_rms_norm(x, self.eps))
+         return x
+
+
+ class KoHRMModule(nn.Module):
+     def __init__(self, cfg: dict[str, Any], num_layers: int, device: str = "meta") -> None:
+         super().__init__()
+         self.eps = float(cfg["rms_norm_eps"])
+         self.layers = nn.ModuleList([KoHRMBlock(cfg, device=device) for _ in range(num_layers)])
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         input_injection: torch.Tensor,
+         cos: torch.Tensor,
+         sin: torch.Tensor,
+         caches: list[dict[str, torch.Tensor]] | None,
+         cache_pos: int,
+     ) -> torch.Tensor:
+         x = hidden_states + input_injection
+         for idx, layer in enumerate(self.layers):
+             x = layer(x, cos, sin, None if caches is None else caches[idx], cache_pos)
+         return _rms_norm(x, self.eps)
+
+
+ class KoHRMCore(nn.Module):
+     def __init__(self, cfg: dict[str, Any], num_layers: int, device: str = "meta") -> None:
+         super().__init__()
+         self.cfg = cfg
+         self.embedding_scale = float(cfg.get("embedding_scale", 1.0))
+         self.embed_tokens = nn.Embedding(cfg["vocab_size"], cfg["hidden_size"], device=device)
+         self.register_buffer("z_L_init", torch.empty(cfg["hidden_size"], device=device), persistent=True)
+         self.H_module = KoHRMModule(cfg, num_layers, device=device)
+         self.L_module = KoHRMModule(cfg, num_layers, device=device)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         position_ids: torch.Tensor,
+         caches: dict[str, list[list[dict[str, torch.Tensor]]]] | None,
+         cache_pos: int,
+     ) -> torch.Tensor:
+         x = self.embedding_scale * self.embed_tokens(input_ids)
+         cos, sin = _rope_cos_sin(position_ids, self.cfg["head_dim"], float(self.cfg["rope_theta"]), x.dtype)
+         z_h = x
+         z_l = self.z_L_init.to(dtype=x.dtype).view(1, 1, -1).expand_as(x)
+
+         h_cycles, l_cycles = int(self.cfg["H_cycles"]), int(self.cfg["L_cycles"])
+         for h_idx in range(h_cycles):
+             for l_idx in range(l_cycles):
+                 pass_idx = h_idx * l_cycles + l_idx
+                 z_l = self.L_module(z_l, z_h, cos, sin, None if caches is None else caches["L"][pass_idx], cache_pos)
+             z_h = self.H_module(z_h, z_l, cos, sin, None if caches is None else caches["H"][h_idx], cache_pos)
+         return z_h
+
+
+ class KoHRMTextForGeneration(nn.Module):
+     def __init__(self, cfg: dict[str, Any], num_layers: int, device: str = "meta") -> None:
+         super().__init__()
+         self.cfg = cfg
+         self.num_layers = num_layers
+         self.model = KoHRMCore(cfg, num_layers, device=device)
+         self.lm_head = nn.Linear(cfg["hidden_size"], cfg["vocab_size"], bias=False, device=device)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         position_ids: torch.Tensor,
+         caches: dict[str, list[list[dict[str, torch.Tensor]]]] | None = None,
+         cache_pos: int = 0,
+     ) -> torch.Tensor:
+         hidden = self.model(input_ids, position_ids, caches, cache_pos)
+         return self.lm_head(hidden)
+
+     def init_cache(self, batch_size: int, max_seq_len: int, device: torch.device, dtype: torch.dtype) -> dict[str, list[list[dict[str, torch.Tensor]]]]:
+         heads, head_dim = int(self.cfg["num_attention_heads"]), int(self.cfg["head_dim"])
+
+         def one_layer() -> dict[str, torch.Tensor]:
+             shape = (batch_size, max_seq_len, heads, head_dim)
+             return {
+                 "k": torch.empty(shape, device=device, dtype=dtype),
+                 "v": torch.empty(shape, device=device, dtype=dtype),
+             }
+
+         def one_pass() -> list[dict[str, torch.Tensor]]:
+             return [one_layer() for _ in range(self.num_layers)]
+
+         return {
+             "H": [one_pass() for _ in range(int(self.cfg["H_cycles"]))],
+             "L": [one_pass() for _ in range(int(self.cfg["H_cycles"]) * int(self.cfg["L_cycles"]))],
+         }
+
202
+
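+ def _module_layer_count(state: dict[str, torch.Tensor], prefix: str) -> int:
+     """Infer the layer count from the highest `<prefix>.layers.<i>` index in the state dict."""
+     layers = set()
+     marker = f"{prefix}.layers."
+     for key in state:
+         if key.startswith(marker):
+             layers.add(int(key[len(marker):].split(".", 1)[0]))
+     if not layers:
+         # max() on an empty set would raise an opaque ValueError; fail with a clear message instead.
+         raise ValueError(f"No parameters found under {marker!r}")
+     return max(layers) + 1
+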
+
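Note on `_module_layer_count`: the layer count is not read from `config.json`; it is recovered from the parameter names in the checkpoint. A minimal pure-Python sketch of that parsing follows. The key names are made up for illustration (the real checkpoint's names may differ), and unlike the runtime helper this sketch returns 0 when nothing matches.

```python
def count_layers(keys, prefix):
    """Return 1 + the largest index seen in '<prefix>.layers.<index>.' keys."""
    marker = f"{prefix}.layers."
    indices = {
        int(key[len(marker):].split(".", 1)[0])
        for key in keys
        if key.startswith(marker)
    }
    return max(indices) + 1 if indices else 0


# Hypothetical parameter names, for illustration only.
keys = [
    "model.H_module.layers.0.attn.qkv.weight",
    "model.H_module.layers.1.attn.qkv.weight",
    "model.L_module.layers.0.mlp.up.weight",
    "lm_head.weight",
]
print(count_layers(keys, "model.H_module"))
```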
+ def load_kohrm(repo_dir: str | Path, device: str | None = None, max_gpu_memory_gib: float | None = None) -> tuple[KoHRMTextForGeneration, Tokenizer, dict[str, Any]]:
+     repo_dir = Path(repo_dir)
+     cfg = json.loads((repo_dir / "config.json").read_text(encoding="utf-8"))
+     tokenizer = Tokenizer.from_file(str(repo_dir / "tokenizer.json"))
+
+     state = load_file(str(repo_dir / "model.safetensors"), device="cpu")
+     num_layers = _module_layer_count(state, "model.H_module")
+     # Build the module on the meta device, then adopt the loaded tensors directly (assign=True).
+     model = KoHRMTextForGeneration(cfg, num_layers=num_layers, device="meta")
+     model.load_state_dict(state, strict=True, assign=True)
+     del state
+
+     if device is None:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+     target = torch.device(device)
+     dtype = torch.float16 if target.type == "cuda" else torch.float32
+     model = model.to(device=target, dtype=dtype).eval()
+     if target.type == "cuda":
+         torch.set_float32_matmul_precision("high")
+     # max_gpu_memory_gib is not enforced as a limit; it only enables this diagnostic print.
+     if target.type == "cuda" and max_gpu_memory_gib is not None:
+         free, total = torch.cuda.mem_get_info()
+         print(f"GPU memory free/total GiB: {free / 2**30:.2f}/{total / 2**30:.2f}")
+     return model, tokenizer, cfg
+
+
+ def condition_to_tokens(condition: str = "direct", mapping: dict[str, str] | None = None) -> str:
+     """Map upstream HRM-Text condition names to tokenizer control tokens."""
+     mapping = mapping or DEFAULT_CONDITION_TOKENS
+     pieces: list[str] = []
+     for raw_name in condition.split(","):
+         name = raw_name.strip()
+         if not name:
+             continue
+         if name not in mapping:
+             valid = ", ".join(sorted(mapping))
+             raise ValueError(f"Unknown condition {name!r}; expected one of: {valid}")
+         pieces.append(mapping[name])
+     if not pieces:
+         pieces.append(mapping["direct"])
+     return "".join(pieces)
+
+
+ def format_kohrm_prompt(
+     prompt: str,
+     condition: str = "direct",
+     condition_token: str | None = None,
+ ) -> str:
+     """Format prompts like upstream InferenceCheckpoint.tokenize_prompt().
+
+     Upstream wraps prompts as:
+     `<boq><condition_tokens><instruction><eoq>`.
+
+     For answer-only generation use condition="direct", which maps to
+     `<|object_ref_start|>` in the KoHRM tokenizer. `condition_token` is kept
+     for backward compatibility and overrides `condition` when supplied.
+     """
+     if condition_token is None:
+         condition_token = condition_to_tokens(condition)
+     return f"<|im_start|>{condition_token}{prompt}<|im_end|>"
+
+
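Note on the prompt format: `condition_to_tokens` and `format_kohrm_prompt` together turn a comma-separated condition list into control tokens and wrap the instruction between `<|im_start|>` and `<|im_end|>`. The sketch below re-implements that flow without the runtime so the resulting string is easy to inspect. Only the `direct` mapping comes from the docstring; the `cot` token is a made-up placeholder, and `wrap_prompt` is a hypothetical name.

```python
# "direct" is taken from the docstring; "cot" is a placeholder, not the real control token.
CONDITION_TOKENS = {
    "direct": "<|object_ref_start|>",
    "cot": "<|hypothetical_cot|>",
}


def wrap_prompt(prompt, condition="direct"):
    """Join the condition tokens and wrap the prompt like format_kohrm_prompt()."""
    names = [name.strip() for name in condition.split(",") if name.strip()]
    tokens = "".join(CONDITION_TOKENS[name] for name in names)
    if not tokens:  # an empty condition string falls back to "direct"
        tokens = CONDITION_TOKENS["direct"]
    return f"<|im_start|>{tokens}{prompt}<|im_end|>"


print(wrap_prompt("hello"))
print(wrap_prompt("hello", "direct, cot"))
```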
+ def _apply_repetition_penalty(logits: torch.Tensor, seen_ids: list[int], penalty: float) -> torch.Tensor:
+     """Make already-seen tokens less likely. Modifies `logits` in place and returns it."""
+     if penalty <= 1.0 or not seen_ids:
+         return logits
+     for token_id in set(seen_ids):
+         value = logits[..., token_id]
+         # Negative logits are multiplied and positive logits divided, so both move downward.
+         logits[..., token_id] = torch.where(value < 0, value * penalty, value / penalty)
+     return logits
+
+
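Note on `_apply_repetition_penalty`: the sign check matters because dividing a negative logit by a penalty above 1 would raise it. A small list-based sketch of the same rule, with made-up logit values and a hypothetical function name `penalize`:

```python
def penalize(logits, seen_ids, penalty):
    """Return a copy of logits with seen token ids pushed toward lower scores."""
    out = list(logits)
    if penalty <= 1.0:
        return out
    for token_id in set(seen_ids):
        value = out[token_id]
        # Multiply negative scores and divide positive ones: both become less likely.
        out[token_id] = value * penalty if value < 0 else value / penalty
    return out


# Tokens 0 and 1 were already generated; token 2 was not.
print(penalize([2.0, -1.0, 0.5], seen_ids=[0, 1, 1], penalty=2.0))
```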
+ def _apply_no_repeat_ngram(logits: torch.Tensor, seen_ids: list[int], ngram_size: int) -> torch.Tensor:
+     """Block every token that would complete an n-gram already present in `seen_ids`."""
+     if ngram_size <= 0 or len(seen_ids) < ngram_size - 1:
+         return logits
+     # Use an explicit start index: seen_ids[-0:] would return the whole list when ngram_size == 1.
+     prefix = tuple(seen_ids[len(seen_ids) - (ngram_size - 1):])
+     blocked: set[int] = set()
+     for idx in range(len(seen_ids) - ngram_size + 1):
+         if tuple(seen_ids[idx:idx + ngram_size - 1]) == prefix:
+             blocked.add(seen_ids[idx + ngram_size - 1])
+     if blocked:
+         logits[..., list(blocked)] = -torch.inf
+     return logits
+
+
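Note on `_apply_no_repeat_ngram`: the function takes the last `ngram_size - 1` tokens as a prefix, finds every earlier occurrence of that prefix, and blocks whichever token followed it. A pure-Python sketch of that rule is below; `blocked_tokens` is a hypothetical name, the token ids are arbitrary, and the prefix is sliced with an explicit start index so that `ngram_size=1` yields an empty prefix.

```python
def blocked_tokens(seen_ids, ngram_size):
    """Return the token ids that would repeat an existing n-gram if emitted next."""
    if ngram_size <= 0 or len(seen_ids) < ngram_size - 1:
        return set()
    start = len(seen_ids) - (ngram_size - 1)
    prefix = tuple(seen_ids[start:])  # the last n-1 tokens
    blocked = set()
    for idx in range(len(seen_ids) - ngram_size + 1):
        if tuple(seen_ids[idx:idx + ngram_size - 1]) == prefix:
            blocked.add(seen_ids[idx + ngram_size - 1])
    return blocked


# The sequence ends in (1, 2); earlier, (1, 2) was followed by 3, so 3 is blocked.
print(blocked_tokens([1, 2, 3, 1, 2], ngram_size=3))
```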
+ def _sample_next(
+     logits: torch.Tensor,
+     temperature: float,
+     top_p: float,
+     seen_ids: list[int] | None = None,
+     repetition_penalty: float = 1.0,
+     no_repeat_ngram_size: int = 0,
+     blocked_ids: set[int] | None = None,
+ ) -> int:
+     logits = logits.float()
+     seen_ids = seen_ids or []
+     logits = _apply_repetition_penalty(logits, seen_ids, repetition_penalty)
+     logits = _apply_no_repeat_ngram(logits, seen_ids, no_repeat_ngram_size)
+     if blocked_ids:
+         logits[..., list(blocked_ids)] = -torch.inf
+     # temperature <= 0 selects greedy decoding.
+     if temperature <= 0:
+         return int(torch.argmax(logits, dim=-1).item())
+     probs = torch.softmax(logits / temperature, dim=-1)
+     if top_p < 1.0:
+         sorted_probs, sorted_idx = torch.sort(probs, descending=True)
+         # Keep tokens while the cumulative mass stays within top_p; the token that
+         # crosses the threshold is dropped, and the top token is always kept.
+         keep = torch.cumsum(sorted_probs, dim=-1) <= top_p
+         keep[..., 0] = True
+         sorted_probs = sorted_probs.masked_fill(~keep, 0)
+         sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
+         next_sorted = torch.multinomial(sorted_probs, num_samples=1)
+         return int(sorted_idx.gather(-1, next_sorted).item())
+     return int(torch.multinomial(probs, num_samples=1).item())
+
+
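Note on the top-p filter in `_sample_next`: the mask is `cumsum <= top_p`, so the token whose probability pushes the running total past `top_p` is excluded, which is slightly stricter than the common "smallest set reaching `top_p`" definition. The sketch below reproduces only that selection step on plain lists; `nucleus_indices` is a hypothetical name and the probabilities are illustrative.

```python
def nucleus_indices(probs, top_p):
    """Return token indices kept by a `cumsum <= top_p` mask, best token first."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for rank, idx in enumerate(order):
        cumulative += probs[idx]
        if rank == 0 or cumulative <= top_p:  # the top token is always kept
            kept.append(idx)
        else:
            break  # the running total only grows, so no later token can qualify
    return kept


probs = [0.2, 0.5, 0.3]
print(nucleus_indices(probs, top_p=0.9))
print(nucleus_indices(probs, top_p=0.7))
```

With these values, `top_p=0.7` keeps only the most likely token, because adding the second one would bring the total to 0.8.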
+ @torch.inference_mode()
+ def generate_from_loaded(
+     model: KoHRMTextForGeneration,
+     tokenizer: Tokenizer,
+     cfg: dict[str, Any],
+     prompt: str,
+     *,
+     max_new_tokens: int = 64,
+     min_new_tokens: int = 0,
+     max_seq_len: int = 512,
+     temperature: float = 0.0,
+     top_p: float = 0.9,
+     repetition_penalty: float = 1.18,
+     no_repeat_ngram_size: int = 4,
+     condition: str = "direct",
+     condition_token: str | None = None,
+ ) -> str:
+     dev = next(model.parameters()).device
+     dtype = next(model.parameters()).dtype
+     wrapped = format_kohrm_prompt(prompt, condition=condition, condition_token=condition_token)
+     input_ids = tokenizer.encode(wrapped, add_special_tokens=False).ids
+     if len(input_ids) + max_new_tokens + 1 > max_seq_len:
+         raise ValueError(f"Prompt plus generation exceeds max_seq_len={max_seq_len}: prompt_tokens={len(input_ids)}")
+
+     # Prefill: run the whole prompt once and keep only the last position's logits.
+     caches = model.init_cache(1, max_seq_len, dev, dtype)
+     ids = torch.tensor([input_ids], device=dev, dtype=torch.long)
+     pos = torch.arange(ids.shape[1], device=dev, dtype=torch.long).unsqueeze(0)
+     logits = model(ids, pos, caches=caches, cache_pos=0)[:, -1, :]
+     cache_pos = ids.shape[1]
+
+     # Collect stop tokens, skipping any that the config or tokenizer does not define.
+     # (Wrapping a missing id in int() would raise TypeError.)
+     stop_ids = {
+         cfg.get("eos_token_id"),
+         tokenizer.token_to_id("<|im_end|>"),
+         tokenizer.token_to_id("<|box_end|>"),
+     }
+     stop_ids = {int(x) for x in stop_ids if x is not None}
+     out_ids: list[int] = []
+     seen_ids = list(input_ids)
+     next_id = _sample_next(
+         logits,
+         temperature,
+         top_p,
+         seen_ids,
+         repetition_penalty,
+         no_repeat_ngram_size,
+         blocked_ids=stop_ids if min_new_tokens > 0 else None,
+     )
+     # Decode one token at a time, reusing the K/V caches filled during prefill.
+     for _ in range(max_new_tokens):
+         if next_id in stop_ids and len(out_ids) >= min_new_tokens:
+             break
+         out_ids.append(next_id)
+         seen_ids.append(next_id)
+         token = torch.tensor([[next_id]], device=dev, dtype=torch.long)
+         pos = torch.tensor([[cache_pos]], device=dev, dtype=torch.long)
+         logits = model(token, pos, caches=caches, cache_pos=cache_pos)[:, -1, :]
+         cache_pos += 1
+         next_id = _sample_next(
+             logits,
+             temperature,
+             top_p,
+             seen_ids,
+             repetition_penalty,
+             no_repeat_ngram_size,
+             blocked_ids=stop_ids if len(out_ids) < min_new_tokens else None,
+         )
+
+     return tokenizer.decode(out_ids, skip_special_tokens=True).strip()
+
+
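Note on the decode loop in `generate_from_loaded`: stop tokens are masked out of the sampler until `min_new_tokens` tokens have been produced, and after that the first sampled stop token ends generation without being emitted. The toy loop below isolates that control flow. It replaces the model with a scripted list of ranked candidates per step, so `toy_decode`, the candidate lists, and the token ids are all illustrative assumptions.

```python
def toy_decode(ranked_steps, stop_ids, max_new_tokens, min_new_tokens=0):
    """Pick the best non-blocked candidate per step; stop on an unblocked stop token."""
    out = []
    for ranked in ranked_steps:
        if len(out) >= max_new_tokens:
            break
        # Stop tokens are unavailable until the minimum length is reached.
        blocked = stop_ids if len(out) < min_new_tokens else set()
        next_id = next(tok for tok in ranked if tok not in blocked)
        if next_id in stop_ids:
            break  # the stop token itself is never appended
        out.append(next_id)
    return out


# Token 9 plays the role of <|im_end|>; the "model" prefers it from the first step.
steps = [[9, 5], [9, 6], [7], [9]]
print(toy_decode(steps, stop_ids={9}, max_new_tokens=10, min_new_tokens=2))
print(toy_decode(steps, stop_ids={9}, max_new_tokens=10, min_new_tokens=0))
```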
+ @torch.inference_mode()
+ def generate_text(
+     repo_dir: str | Path,
+     prompt: str,
+     *,
+     max_new_tokens: int = 64,
+     min_new_tokens: int = 0,
+     max_seq_len: int = 512,
+     temperature: float = 0.0,
+     top_p: float = 0.9,
+     repetition_penalty: float = 1.18,
+     no_repeat_ngram_size: int = 4,
+     condition: str = "direct",
+     condition_token: str | None = None,
+     device: str | None = None,
+ ) -> str:
+     model, tokenizer, cfg = load_kohrm(repo_dir, device=device, max_gpu_memory_gib=14.0)
+     return generate_from_loaded(
+         model,
+         tokenizer,
+         cfg,
+         prompt,
+         max_new_tokens=max_new_tokens,
+         min_new_tokens=min_new_tokens,
+         max_seq_len=max_seq_len,
+         temperature=temperature,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         no_repeat_ngram_size=no_repeat_ngram_size,
+         condition=condition,
+         condition_token=condition_token,
+     )
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(description="Run a KoHRM-Text long generation probe without transformers.")
+     parser.add_argument("repo_dir", type=Path, help="Directory containing config.json, tokenizer.json, and model.safetensors")
+     # Default prompt, in English: "The following is part of the original text of a Korean
+     # Wikipedia article. Learn the encyclopedic Korean, proper nouns, dates, and
+     # technical/social/cultural knowledge as is. [Document title] Hunminjeongeum [Part] 1/1"
+     parser.add_argument(
+         "--prompt",
+         default=(
+             "다음은 한국어 위키백과 문서 원문 일부입니다. 백과사전식 한국어, "
+             "고유명사, 날짜, 기술/사회/문화 지식을 그대로 학습하십시오.\n\n"
+             "[문서명]\n훈민정음\n\n[부분]\n1/1"
+         ),
+     )
+     parser.add_argument("--max-new-tokens", type=int, default=384)
+     parser.add_argument("--min-new-tokens", type=int, default=160)
+     parser.add_argument("--max-seq-len", type=int, default=1536)
+     parser.add_argument("--temperature", type=float, default=0.65)
+     parser.add_argument("--top-p", type=float, default=0.92)
+     parser.add_argument("--repetition-penalty", type=float, default=1.05)
+     parser.add_argument("--no-repeat-ngram-size", type=int, default=0)
+     parser.add_argument(
+         "--condition",
+         default="direct",
+         help="Comma-separated HRM-Text condition names: direct, cot, noisy, synth. Use direct for answer-only outputs.",
+     )
+     parser.add_argument(
+         "--condition-token",
+         default=None,
+         help="Optional raw condition token override. Normally use --condition direct instead.",
+     )
+     parser.add_argument("--device", default=None)
+     args = parser.parse_args()
+     print(generate_text(
+         args.repo_dir,
+         args.prompt,
+         max_new_tokens=args.max_new_tokens,
+         min_new_tokens=args.min_new_tokens,
+         max_seq_len=args.max_seq_len,
+         temperature=args.temperature,
+         top_p=args.top_p,
+         repetition_penalty=args.repetition_penalty,
+         no_repeat_ngram_size=args.no_repeat_ngram_size,
+         condition=args.condition,
+         condition_token=args.condition_token,
+         device=args.device,
+     ))
+
+
+ if __name__ == "__main__":
+     main()