Jellyfish042 committed on
Commit
8d6299f
·
1 Parent(s): 6e818da

Init RWKV compressor Space demo

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ models/*.pth
2
+ models/.cache/
3
+ __pycache__/
4
+ *.pyc
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: LLM Compressor
3
  emoji: 🐨
4
  colorFrom: gray
5
  colorTo: pink
@@ -9,4 +9,26 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: RWKV LLM Text Compressor
3
  emoji: 🐨
4
  colorFrom: gray
5
  colorTo: pink
 
9
  pinned: false
10
  ---
11
 
12
+ # RWKV LLM Text Compressor
13
+
14
+ This Space demonstrates LLM-based arithmetic coding using RWKV. It is a proof of
15
+ concept and is intentionally slow. The compressed output is only valid when the
16
+ same model, tokenizer, and context window are used for decompression.
17
+
18
+ ## Configuration
19
+
20
+ - `RWKV_MODEL_PATH`: Path to a local RWKV `.pth` file (or name without extension).
21
+ - `RWKV_TOKENIZER`: Path to `rwkv_vocab_v20230424.txt`. Default: `support/rwkv_vocab_v20230424.txt`.
22
+ - `RWKV_STRATEGY`: RWKV strategy string (example: `cpu fp32`, `cuda fp16`).
23
+
24
+ ## Notes
25
+
26
+ - CPU-only Spaces should keep `RWKV_STRATEGY=cpu fp32`. The app forces CPU when CUDA
27
+ is unavailable.
28
+ - The vocab file is not bundled; place `rwkv_vocab_v20230424.txt` in `support/` or
29
+ set `RWKV_TOKENIZER` to its path.
30
+ - The app auto-detects a `.pth` model under `models/` if `RWKV_MODEL_PATH` is not set.
31
+ - If no model is found, the app downloads `rwkv7-g1a-0.1b-20250728-ctx4096.pth` into `models/`.
32
+ - Input text is limited to 8192 characters.
33
+ - Compression and decompression are slow and not suitable for production use.
34
+ - Output is not portable across different models or tokenizers.
app.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import os
3
+ import shutil
4
+ import tempfile
5
+ import urllib.request
6
+ from pathlib import Path
7
+
8
+ import gradio as gr
9
+ import torch
10
+
11
+ from llm_compressor import compress_tokens, decompress_bytes, load_rwkv_model, tokenize_text
12
+
13
+ MAX_INPUT_CHARS = 8192
14
+ SCRIPT_DIR = Path(__file__).parent.absolute()
15
+ SUPPORT_DIR = SCRIPT_DIR / "support"
16
+ MODELS_DIR = SCRIPT_DIR / "models"
17
+ DEFAULT_MODEL_FILENAME = "rwkv7-g1a-0.1b-20250728-ctx4096.pth"
18
+ DEFAULT_MODEL_PATH = str(MODELS_DIR / DEFAULT_MODEL_FILENAME)
19
+ DEFAULT_MODEL_URL = "https://huggingface.co/BlinkDL/rwkv7-g1/resolve/main/" "rwkv7-g1a-0.1b-20250728-ctx4096.pth?download=true"
20
+ DEFAULT_TOKENIZER_PATH = str(SUPPORT_DIR / "rwkv_vocab_v20230424.txt")
21
+
22
+
23
def _patch_gradio_client_schema():
    """Monkey-patch gradio_client's JSON-schema helpers to tolerate boolean schemas.

    NOTE(review): presumably works around gradio_client versions that crash when a
    JSON schema is a bare bool (``True``/``False`` are valid JSON Schemas) — confirm
    against the pinned gradio version. Best-effort: silently skipped when
    gradio_client cannot be imported, and idempotent via the ``_rwkv_patch`` sentinel.
    """
    try:
        from gradio_client import utils as gr_client_utils
    except Exception:
        # gradio_client unavailable; nothing to patch.
        return

    # Already patched (e.g. after a hot-reload) — avoid wrapping the wrappers.
    if getattr(gr_client_utils, "_rwkv_patch", False):
        return

    # Keep references to the originals so non-bool schemas still go through them.
    original_get_type = gr_client_utils.get_type
    original_json_schema = gr_client_utils._json_schema_to_python_type

    def _patched_get_type(schema):
        if isinstance(schema, bool):
            return "Any"
        return original_get_type(schema)

    gr_client_utils.get_type = _patched_get_type
    gr_client_utils._json_schema_to_python_type = lambda schema, defs=None: "Any" if isinstance(schema, bool) else original_json_schema(schema, defs)
    gr_client_utils._rwkv_patch = True
43
+
44
+
45
+ _patch_gradio_client_schema()
46
+
47
+
48
+ def _write_temp_file(data, suffix=".llmc"):
49
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
50
+ tmp.write(data)
51
+ tmp.flush()
52
+ tmp.close()
53
+ return tmp.name
54
+
55
+
56
def _resolve_default_model_path():
    """Pick the model path: env override, bundled default, download, or any local .pth.

    Returns "" when no checkpoint can be located so the caller can raise a
    user-facing error.
    """
    override = os.getenv("RWKV_MODEL_PATH")
    if override:
        return override

    bundled = Path(DEFAULT_MODEL_PATH)
    if bundled.is_file():
        return str(bundled)

    # Try fetching the default checkpoint when a download URL is configured.
    if DEFAULT_MODEL_URL:
        fetched = _download_default_model()
        if fetched:
            return fetched

    # Last resort: any checkpoint already present under models/.
    if MODELS_DIR.is_dir():
        found = sorted(MODELS_DIR.glob("*.pth"))
        if found:
            return str(found[0])

    return ""
72
+
73
+
74
def _resolve_default_tokenizer_path():
    """Pick the vocab path: env override, bundled default, or any vocab in support/.

    Unlike the model resolver, this always returns a string — possibly the
    (missing) default path — so the caller can report a precise error.
    """
    override = os.getenv("RWKV_TOKENIZER")
    if override:
        return override

    vocab = Path(DEFAULT_TOKENIZER_PATH)
    if vocab.is_file():
        return str(vocab)

    # Accept any rwkv_vocab_v*.txt dropped into support/ as a substitute.
    if SUPPORT_DIR.is_dir():
        found = sorted(SUPPORT_DIR.glob("rwkv_vocab_v*.txt"))
        if found:
            return str(found[0])

    return str(vocab)
86
+
87
+
88
def _download_default_model():
    """Download the default RWKV checkpoint into models/ unless already present.

    Streams to a ``.tmp`` sibling first and atomically renames on success, so a
    partial download never masquerades as a complete checkpoint.

    Returns:
        The checkpoint path as a string, or "" on failure. Best-effort by
        design: errors are printed, not raised, so callers can fall back.
    """
    if not DEFAULT_MODEL_URL:
        return ""
    dest_path = Path(DEFAULT_MODEL_PATH)
    if dest_path.is_file():
        return str(dest_path)
    dest_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = dest_path.with_suffix(dest_path.suffix + ".tmp")
    try:
        print(f"Downloading RWKV model to {dest_path}...")
        # Socket-level timeout so a stalled connection cannot hang the app
        # forever; it bounds individual blocking socket operations, not the
        # total transfer time, so large downloads still work.
        with urllib.request.urlopen(DEFAULT_MODEL_URL, timeout=60) as response, open(tmp_path, "wb") as f:
            shutil.copyfileobj(response, f)
        tmp_path.replace(dest_path)
        return str(dest_path)
    except Exception as exc:
        # Best-effort cleanup of the partial file; a secondary unlink failure
        # must not mask the original download error.
        try:
            if tmp_path.exists():
                tmp_path.unlink()
        except OSError:
            pass
        print(f"Failed to download RWKV model: {exc}")
        return ""
107
+
108
+
109
+ def _resolve_model_path(value):
110
+ if not value:
111
+ return None
112
+ path = Path(value).expanduser()
113
+ candidates = [path]
114
+ if path.suffix != ".pth":
115
+ candidates.append(path.with_suffix(".pth"))
116
+ if not path.is_absolute():
117
+ candidates.append(MODELS_DIR / path)
118
+ if path.suffix != ".pth":
119
+ candidates.append((MODELS_DIR / path).with_suffix(".pth"))
120
+ for candidate in candidates:
121
+ if candidate.is_file():
122
+ return candidate
123
+ return None
124
+
125
+
126
+ def _resolve_tokenizer_path(value):
127
+ if not value:
128
+ return None
129
+ path = Path(value).expanduser()
130
+ candidates = [path]
131
+ if not path.is_absolute():
132
+ candidates.append(SUPPORT_DIR / path)
133
+ for candidate in candidates:
134
+ if candidate.is_file():
135
+ return candidate
136
+ return None
137
+
138
+
139
def _resolve_strategy():
    """Effective RWKV strategy: RWKV_STRATEGY env var (default "cpu fp32"), CPU-forced when CUDA is absent."""
    requested = os.getenv("RWKV_STRATEGY", "cpu fp32")
    return _normalize_strategy(requested)
141
+
142
+
143
+ def _extract_file_bytes(file_data):
144
+ if file_data is None:
145
+ return None
146
+ if isinstance(file_data, (bytes, bytearray)):
147
+ return bytes(file_data)
148
+ if isinstance(file_data, dict) and "data" in file_data:
149
+ return file_data["data"]
150
+ if isinstance(file_data, str):
151
+ with open(file_data, "rb") as f:
152
+ return f.read()
153
+ if hasattr(file_data, "read"):
154
+ return file_data.read()
155
+ raise gr.Error("Unsupported uploaded file format.")
156
+
157
+
158
def _get_compressed_bytes(b64_data, file_data):
    """Return the compressed payload, preferring an uploaded file over the textbox.

    Raises gr.Error when neither source yields data or the base64 is invalid.
    """
    uploaded = _extract_file_bytes(file_data)
    if uploaded:
        return uploaded
    if not b64_data or not b64_data.strip():
        raise gr.Error("Compressed base64 data is empty.")
    try:
        # validate=True rejects non-alphabet characters instead of ignoring them.
        decoded = base64.b64decode(b64_data.encode("ascii"), validate=True)
    except Exception as exc:
        raise gr.Error(f"Invalid base64 data: {exc}") from exc
    return decoded
168
+
169
+
170
def _load_model_and_tokenizer(model_path, tokenizer_name, strategy):
    """Resolve both paths and load the (cached) model, mapping failures to gr.Error.

    Returns the ``(model, tokenizer)`` pair from load_rwkv_model. All failure
    modes surface as gr.Error so Gradio shows them to the user.
    """
    resolved_model = _resolve_model_path(model_path)
    if not resolved_model:
        raise gr.Error(f"RWKV model file not found: {model_path}. Put a .pth in {MODELS_DIR} or set RWKV_MODEL_PATH.")
    resolved_tokenizer = _resolve_tokenizer_path(tokenizer_name)
    if not resolved_tokenizer:
        raise gr.Error(f"Tokenizer vocab file not found: {tokenizer_name}. Put rwkv_vocab_v20230424.txt in {SUPPORT_DIR} " "or set RWKV_TOKENIZER.")
    try:
        # Stringify Paths so load_rwkv_model's lru_cache keys stay stable strings.
        return load_rwkv_model(str(resolved_model), str(resolved_tokenizer), strategy)
    except Exception as exc:
        raise gr.Error(f"Failed to load RWKV model: {exc}") from exc
181
+
182
+
183
+ def _format_compress_stats(stats):
184
+ return "\n".join(
185
+ [
186
+ f"- Tokens: {stats['tokens']}",
187
+ f"- Original bytes: {stats['original_bytes']}",
188
+ f"- Compressed bytes: {stats['compressed_bytes']}",
189
+ f"- Compression ratio: {stats['ratio'] * 100:.2f}%",
190
+ f"- Theoretical ratio: {stats['theoretical_ratio'] * 100:.2f}%",
191
+ f"- Time: {stats['duration_s']:.2f}s",
192
+ f"- Speed: {stats['speed_toks_per_s']:.2f} tokens/s",
193
+ ]
194
+ )
195
+
196
+
197
+ def _format_decompress_stats(stats):
198
+ return "\n".join(
199
+ [
200
+ f"- Tokens: {stats['tokens']}",
201
+ f"- Time: {stats['duration_s']:.2f}s",
202
+ ]
203
+ )
204
+
205
+
206
+ def _normalize_strategy(strategy):
207
+ if "cuda" in strategy and not torch.cuda.is_available():
208
+ return "cpu fp32"
209
+ return strategy
210
+
211
+
212
def compress_ui(text, context_window, progress=gr.Progress()):
    """Gradio handler: compress *text* with the RWKV arithmetic coder.

    Returns (base64 payload, stats Markdown, temp-file path for download).
    Raises gr.Error on empty or oversized input. Model/tokenizer/strategy are
    re-resolved per call; the heavy model load itself is cached upstream.
    """
    if not text or not text.strip():
        raise gr.Error("Input text is empty.")
    if len(text) > MAX_INPUT_CHARS:
        raise gr.Error(f"Input is too long ({len(text)} chars). Max is {MAX_INPUT_CHARS}.")

    model_path = _resolve_default_model_path()
    tokenizer_path = _resolve_default_tokenizer_path()
    # Track both requested and effective strategy so a CPU fallback is reported.
    requested_strategy = os.getenv("RWKV_STRATEGY", "cpu fp32")
    effective_strategy = _resolve_strategy()
    model, tokenizer = _load_model_and_tokenizer(model_path, tokenizer_path, effective_strategy)

    tokens = tokenize_text(tokenizer, text)
    if not tokens:
        raise gr.Error("Tokenized input is empty.")

    # UTF-8 byte length of the input is the baseline for the compression ratio.
    original_bytes = len(text.encode("utf-8"))
    data, stats = compress_tokens(
        tokens,
        model,
        context_window=context_window,
        original_bytes=original_bytes,
        progress=progress,
        progress_desc="Compressing",
    )

    b64 = base64.b64encode(data).decode("ascii")
    file_path = _write_temp_file(data)
    stats_text = _format_compress_stats(stats)
    # Surface when the CUDA strategy was silently downgraded to CPU.
    if effective_strategy != requested_strategy:
        stats_text += "\n- Strategy: cpu fp32 (forced, CUDA unavailable)"
    else:
        stats_text += f"\n- Strategy: {effective_strategy}"
    return b64, stats_text, file_path
246
+
247
+
248
def decompress_ui(b64_data, file_data, context_window):
    """Gradio handler: decode compressed input (file upload or base64 textbox).

    Returns (decompressed text, stats Markdown). A correct round-trip requires
    the same model, tokenizer, and context window that produced the payload.
    """
    raw = _get_compressed_bytes(b64_data, file_data)
    model_path = _resolve_default_model_path()
    tokenizer_path = _resolve_default_tokenizer_path()
    # Track both requested and effective strategy so a CPU fallback is reported.
    requested_strategy = os.getenv("RWKV_STRATEGY", "cpu fp32")
    effective_strategy = _resolve_strategy()
    model, tokenizer = _load_model_and_tokenizer(model_path, tokenizer_path, effective_strategy)
    text, stats = decompress_bytes(raw, model, tokenizer, context_window=context_window)
    stats_text = _format_decompress_stats(stats)
    # Surface when the CUDA strategy was silently downgraded to CPU.
    if effective_strategy != requested_strategy:
        stats_text += "\n- Strategy: cpu fp32 (forced, CUDA unavailable)"
    else:
        stats_text += f"\n- Strategy: {effective_strategy}"
    return text, stats_text
262
+
263
+
264
def build_ui():
    """Assemble and return the Gradio Blocks app with Compress/Decompress tabs."""
    with gr.Blocks() as demo:
        gr.Markdown("# RWKV LLM Text Compressor")
        gr.Markdown(
            "This is a proof-of-concept demo. Compression and decompression are slow, "
            "and the output is not portable across different models or tokenizers."
        )

        # Shared by both tabs: the value must match between compress and decompress.
        context_window = gr.Slider(
            label="Context window",
            minimum=128,
            maximum=4096,
            step=128,
            value=2048,
        )

        gr.Markdown(f"Max input size: {MAX_INPUT_CHARS} characters.")

        with gr.Tabs():
            with gr.Tab("Compress"):
                input_text = gr.Textbox(label="Input text", lines=10)
                compress_button = gr.Button("Compress")
                output_b64 = gr.Textbox(label="Compressed data (base64)", lines=6)
                compress_stats = gr.Markdown()
                output_file = gr.File(label="Download compressed file")

                compress_button.click(
                    compress_ui,
                    inputs=[input_text, context_window],
                    outputs=[output_b64, compress_stats, output_file],
                )

            with gr.Tab("Decompress"):
                input_b64 = gr.Textbox(label="Compressed data (base64)", lines=6)
                # type="binary" hands the handler raw bytes rather than a temp path.
                input_file = gr.File(label="Or upload compressed file", type="binary")
                decompress_button = gr.Button("Decompress")
                output_text = gr.Textbox(label="Decompressed text", lines=10)
                decompress_stats = gr.Markdown()

                decompress_button.click(
                    decompress_ui,
                    inputs=[input_b64, input_file, context_window],
                    outputs=[output_text, decompress_stats],
                )

    return demo
310
+
311
+
312
if __name__ == "__main__":
    # queue() bounds concurrent jobs; bind on all interfaces for the Space runtime.
    build_ui().queue(max_size=16).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=False,
    )
llm_compressor.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import math
3
+ import os
4
+ import struct
5
+ import threading
6
+ import time
7
+ from functools import lru_cache
8
+
9
+ import torch
10
+
11
+ PROB_SCALE = 1 << 48
12
+ ARITHMETIC_PRECISION = 64
13
+
14
+
15
class BitOutputStream:
    """MSB-first bit writer that flushes to *file_obj* one byte at a time."""

    def __init__(self, file_obj):
        self.file_obj = file_obj
        self.byte = 0        # bits accumulated so far, most recent in the LSB
        self.bit_count = 0   # number of valid bits currently held in self.byte

    def write_bit(self, bit):
        """Append one bit (0 or 1); emits a byte once eight bits are buffered."""
        self.byte = (self.byte << 1) | bit
        self.bit_count += 1
        if self.bit_count < 8:
            return
        self.file_obj.write(bytes((self.byte,)))
        self.byte = 0
        self.bit_count = 0

    def close(self):
        """Flush any partial final byte, zero-padding it on the right."""
        if self.bit_count:
            padded = self.byte << (8 - self.bit_count)
            self.byte = padded
            self.file_obj.write(bytes((padded,)))
33
+
34
+
35
class BitInputStream:
    """MSB-first bit reader over *file_obj*; returns -1 at end of stream."""

    def __init__(self, file_obj):
        self.file_obj = file_obj
        self.byte = 0        # byte currently being consumed
        self.bit_count = 0   # bits of self.byte not yet handed out

    def read_bit(self):
        """Return the next bit (0/1), or -1 once the stream is exhausted."""
        if not self.bit_count:
            chunk = self.file_obj.read(1)
            if not chunk:
                return -1
            self.byte = chunk[0]
            self.bit_count = 8
        # Consume from the most significant remaining bit downwards.
        self.bit_count -= 1
        return (self.byte >> self.bit_count) & 1
52
+
53
+
54
class ArithmeticEncoder:
    """Integer arithmetic coder (encoder half) with underflow handling.

    Classic renormalizing implementation: [low, high] is an integer interval
    of `precision` bits; each symbol narrows it proportionally to its
    cumulative-frequency slice, and determined leading bits are emitted as
    soon as the interval settles into one half of the range.
    """

    def __init__(self, bit_output, precision=ARITHMETIC_PRECISION):
        self.bit_output = bit_output
        self.precision = precision
        self.max_val = (1 << precision) - 1          # top of the full range
        self.quarter_val = 1 << (precision - 2)      # 1/4 boundary
        self.half_val = 1 << (precision - 1)         # midpoint
        self.three_quarter_val = self.quarter_val * 3
        self.low = 0
        self.high = self.max_val
        self.pending_bits = 0                        # deferred underflow bits

    def encode_symbol(self, low_count, high_count, total_count):
        """Encode one symbol covering counts [low_count, high_count) of total_count."""
        # Narrow [low, high] to the symbol's proportional sub-interval.
        range_val = self.high - self.low + 1
        self.high = self.low + (range_val * high_count) // total_count - 1
        self.low = self.low + (range_val * low_count) // total_count

        # Renormalize: emit decided bits / record underflow, then rescale x2.
        while True:
            if self.high < self.half_val:
                # Entirely in the lower half -> next output bit is 0.
                self._write_bit(0)
            elif self.low >= self.half_val:
                # Entirely in the upper half -> next output bit is 1.
                self._write_bit(1)
                self.low -= self.half_val
                self.high -= self.half_val
            elif self.low >= self.quarter_val and self.high < self.three_quarter_val:
                # Straddling the midpoint inside the middle half: bit unknown
                # yet; defer it and expand around the midpoint.
                self.pending_bits += 1
                self.low -= self.quarter_val
                self.high -= self.quarter_val
            else:
                break

            self.low <<= 1
            self.high = (self.high << 1) | 1

    def _write_bit(self, bit):
        # Each decided bit is followed by all deferred (opposite) underflow bits.
        self.bit_output.write_bit(bit)
        while self.pending_bits > 0:
            self.bit_output.write_bit(1 - bit)
            self.pending_bits -= 1

    def finish(self):
        """Emit final bits pinning the decoder inside the last interval."""
        self.pending_bits += 1
        if self.low < self.quarter_val:
            self._write_bit(0)
        else:
            self._write_bit(1)
100
+
101
+
102
class ArithmeticDecoder:
    """Integer arithmetic coder (decoder half), mirroring ArithmeticEncoder.

    Tracks the same [low, high] interval the encoder maintained, plus `value`
    — the next `precision` bits of the stream — so the current stream
    position can be mapped back to a cumulative count.
    """

    def __init__(self, bit_input, precision=ARITHMETIC_PRECISION):
        self.bit_input = bit_input
        self.precision = precision
        self.max_val = (1 << precision) - 1
        self.quarter_val = 1 << (precision - 2)
        self.half_val = 1 << (precision - 1)
        self.three_quarter_val = self.quarter_val * 3
        self.low = 0
        self.high = self.max_val
        self.value = 0

        # Prime `value` with the first `precision` bits (zero-padded at EOF).
        for _ in range(precision):
            read_val = self.bit_input.read_bit()
            if read_val == -1:
                read_val = 0
            self.value = (self.value << 1) | read_val

    def decode_symbol_find_count(self, total_count):
        """Return the cumulative count the current stream position falls on."""
        range_val = self.high - self.low + 1
        count = ((self.value - self.low + 1) * total_count - 1) // range_val
        return count

    def update_range(self, low_count, high_count, total_count):
        """Narrow the interval for the decoded symbol; mirrors encode_symbol."""
        range_val = self.high - self.low + 1
        self.high = self.low + (range_val * high_count) // total_count - 1
        self.low = self.low + (range_val * low_count) // total_count

        while True:
            if self.high < self.half_val:
                pass  # lower half: no offset needed before rescaling
            elif self.low >= self.half_val:
                # Upper half: shift interval and value down together.
                self.value -= self.half_val
                self.low -= self.half_val
                self.high -= self.half_val
            elif self.low >= self.quarter_val and self.high < self.three_quarter_val:
                # Middle-half (underflow) case, matching the encoder.
                self.value -= self.quarter_val
                self.low -= self.quarter_val
                self.high -= self.quarter_val
            else:
                break

            self.low <<= 1
            self.high = (self.high << 1) | 1

            # Shift the next stream bit into `value` (zero-padded at EOF).
            bit = self.bit_input.read_bit()
            if bit == -1:
                bit = 0
            self.value = (self.value << 1) | bit
151
+
152
+
153
+ def _strip_pth(model_path):
154
+ return model_path[:-4] if model_path.endswith(".pth") else model_path
155
+
156
+
157
+ def _prepare_logits(logits):
158
+ if not isinstance(logits, torch.Tensor):
159
+ logits = torch.tensor(logits)
160
+ if logits.ndim > 1:
161
+ logits = logits[-1]
162
+ return logits.float()
163
+
164
+
165
def tokenize_text(tokenizer, text):
    """Encode *text* and return a plain list of int token ids.

    Works with tokenizers that return a raw id list as well as ones returning
    an Encoding-like object exposing ``.ids``.
    """
    encoded = tokenizer.encode(text)
    ids = encoded.ids if hasattr(encoded, "ids") else encoded
    return list(map(int, ids))
170
+
171
+
172
def decode_tokens(tokenizer, tokens):
    """Decode a token-id sequence back into text via the tokenizer."""
    text = tokenizer.decode(tokens)
    return text
174
+
175
+
176
# Serializes RWKV model construction — the loader below mutates process-global
# state (os.environ flags) before importing/instantiating the model.
_MODEL_LOCK = threading.Lock()
177
+
178
+
179
@lru_cache(maxsize=2)
def load_rwkv_model(model_path, tokenizer_name, strategy):
    """Load and cache an RWKV model plus its TRIE tokenizer.

    lru_cache(maxsize=2) keeps at most two (model, tokenizer) pairs alive,
    keyed on the exact argument strings — note the key is the *requested*
    strategy, so two strings that both normalize to "cpu fp32" would load
    the model twice.

    Raises ValueError for missing arguments; rwkv import/load errors propagate.
    """
    if not model_path:
        raise ValueError("RWKV model path is required.")
    if not tokenizer_name:
        raise ValueError("RWKV tokenizer name or path is required.")

    # Fall back to CPU when a CUDA strategy is requested on a CUDA-less host.
    if "cuda" in strategy and not torch.cuda.is_available():
        strategy = "cpu fp32"

    # The rwkv package reads these env vars when constructing the model.
    os.environ["RWKV_JIT_ON"] = "1"
    os.environ["RWKV_V7_ON"] = "1"
    os.environ["RWKV_CUDA_ON"] = "1" if "cuda" in strategy else "0"

    with _MODEL_LOCK:
        # Imported under the lock so the env-var setup above is not raced.
        from rwkv.model import RWKV
        from rwkv.rwkv_tokenizer import TRIE_TOKENIZER

        # RWKV() expects the checkpoint path without the .pth extension.
        model = RWKV(model=_strip_pth(model_path), strategy=strategy)
        tokenizer = TRIE_TOKENIZER(tokenizer_name)
        return model, tokenizer
200
+
201
+
202
def compress_tokens(
    tokens,
    model,
    context_window=2048,
    original_bytes=None,
    progress=None,
    progress_desc="Compressing",
):
    """Arithmetic-code *tokens* using the model's next-token distribution.

    Output layout: 4-byte big-endian token count, then the coded bitstream.
    Decompression must replay the identical model/context-window schedule
    (see decompress_bytes), since the probability tables are re-derived
    step by step on both sides.

    Args:
        tokens: Iterable of token ids to encode.
        model: RWKV model exposing ``forward([token], state)``.
        context_window: Tokens fed before the context/state is hard-reset.
        original_bytes: Input size in bytes, used only for ratio stats.
        progress: Optional Gradio-style callable for progress updates.
        progress_desc: Label passed to the progress callable.

    Returns:
        (data, stats) — the compressed bytes and a metrics dict.

    Raises:
        ValueError: on a non-positive context window or empty token list.
    """
    if context_window <= 0:
        raise ValueError("context_window must be positive.")

    token_ids = [int(token_id) for token_id in tokens]
    if not token_ids:
        raise ValueError("No tokens to compress.")

    # Header: token count, so the decoder knows when to stop.
    output = io.BytesIO()
    output.write(struct.pack(">I", len(token_ids)))
    bit_output = BitOutputStream(output)
    encoder = ArithmeticEncoder(bit_output, precision=ARITHMETIC_PRECISION)

    context_tokens = []
    state = None  # RWKV recurrent state; reset together with the context
    total_nll = 0.0  # accumulated -log p, for the theoretical-ratio stat
    start_time = time.time()
    total_tokens = len(token_ids)
    if progress is not None:
        progress((0, total_tokens), desc=progress_desc, unit="token")

    with torch.inference_mode():
        for idx, token_id in enumerate(token_ids):
            # Hard context reset; the decoder applies the identical rule.
            if len(context_tokens) >= context_window:
                context_tokens = []
                state = None

            # Feed the previous token (0 at a fresh context) to predict this one.
            input_token = context_tokens[-1] if context_tokens else 0
            logits, state = model.forward([input_token], state)
            next_logits = _prepare_logits(logits)

            # Integer frequency table: scaled probabilities clamped to >= 1 so
            # every token stays encodable even with near-zero probability.
            probs = torch.softmax(next_logits, dim=-1)
            counts = (probs * PROB_SCALE).to(torch.long)
            counts = torch.clamp(counts, min=1)

            cdf = torch.cumsum(counts, dim=-1)
            total_count = int(cdf[-1].item())

            prob_val = probs[token_id]
            total_nll += float((-torch.log(prob_val)).item())

            # The token's slice of the cumulative table drives the coder.
            low_val = int(cdf[token_id - 1].item()) if token_id > 0 else 0
            high_val = int(cdf[token_id].item())
            encoder.encode_symbol(low_val, high_val, total_count)

            context_tokens.append(token_id)
            if progress is not None:
                progress((idx + 1, total_tokens), desc=progress_desc, unit="token")

    encoder.finish()
    bit_output.close()
    data = output.getvalue()
    end_time = time.time()

    original_bytes = int(original_bytes or 0)
    compressed_bytes = len(data)
    ratio = compressed_bytes / original_bytes if original_bytes > 0 else 0.0

    # Shannon bound implied by the model's assigned probabilities.
    theoretical_bits = total_nll / math.log(2)
    theoretical_bytes = theoretical_bits / 8
    theoretical_ratio = theoretical_bytes / original_bytes if original_bytes > 0 else 0.0

    duration = end_time - start_time
    speed = len(token_ids) / duration if duration > 0 else 0.0

    stats = {
        "tokens": len(token_ids),
        "original_bytes": original_bytes,
        "compressed_bytes": compressed_bytes,
        "ratio": ratio,
        "theoretical_ratio": theoretical_ratio,
        "duration_s": duration,
        "speed_toks_per_s": speed,
    }
    return data, stats
284
+
285
+
286
def compress_text(text, model, tokenizer, context_window=2048):
    """Convenience wrapper: tokenize *text* and arithmetic-code the result."""
    byte_length = len(text.encode("utf-8"))
    token_ids = tokenize_text(tokenizer, text)
    return compress_tokens(token_ids, model, context_window=context_window, original_bytes=byte_length)
290
+
291
+
292
def decompress_bytes(data, model, tokenizer, context_window=2048):
    """Invert compress_tokens: arithmetic-decode *data* back into text.

    Expects the layout produced by compress_tokens (4-byte big-endian token
    count, then the bitstream) and must run with the same model, tokenizer,
    and context_window, since the probability tables are re-derived
    identically step by step.

    Returns:
        (text, stats) — the decoded string and a dict with "tokens" and
        "duration_s".

    Raises:
        ValueError: on a non-positive context window or truncated input.
    """
    if context_window <= 0:
        raise ValueError("context_window must be positive.")
    if not data or len(data) < 4:
        raise ValueError("Compressed data is empty or invalid.")

    buffer = io.BytesIO(data)
    total_tokens_bytes = buffer.read(4)
    total_tokens = struct.unpack(">I", total_tokens_bytes)[0]

    bit_input = BitInputStream(buffer)
    decoder = ArithmeticDecoder(bit_input, precision=ARITHMETIC_PRECISION)

    decoded_tokens = []
    context_tokens = []
    state = None  # RWKV recurrent state; reset together with the context
    start_time = time.time()

    with torch.inference_mode():
        for _ in range(total_tokens):
            # Hard context reset, mirroring the encoder's schedule exactly.
            if len(context_tokens) >= context_window:
                context_tokens = []
                state = None

            # Feed the previous token (0 at a fresh context) to predict the next.
            input_token = context_tokens[-1] if context_tokens else 0
            logits, state = model.forward([input_token], state)
            next_logits = _prepare_logits(logits)

            # Rebuild the exact integer frequency table the encoder used.
            probs = torch.softmax(next_logits, dim=-1)
            counts = (probs * PROB_SCALE).to(torch.long)
            counts = torch.clamp(counts, min=1)

            cdf = torch.cumsum(counts, dim=-1)
            total_count = int(cdf[-1].item())

            # Map the stream position to the token whose cumulative slice it hits.
            count_val = decoder.decode_symbol_find_count(total_count)
            count_val_tensor = torch.tensor(count_val, device=cdf.device)
            target_token_id = int(torch.searchsorted(cdf, count_val_tensor, right=True).item())

            decoded_tokens.append(target_token_id)
            context_tokens.append(target_token_id)

            # Consume the decoded symbol's interval, as the encoder did.
            low_val = int(cdf[target_token_id - 1].item()) if target_token_id > 0 else 0
            high_val = int(cdf[target_token_id].item())
            decoder.update_range(low_val, high_val, total_count)

    text = decode_tokens(tokenizer, decoded_tokens)
    duration = time.time() - start_time

    stats = {
        "tokens": total_tokens,
        "duration_s": duration,
    }
    return text, stats
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio>=4.0.0
2
+ rwkv
3
+ torch
support/README.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Place the RWKV vocab file here:
2
+ - rwkv_vocab_v20230424.txt
3
+
4
+ You can also set RWKV_TOKENIZER to point to a different vocab path.
support/rwkv_vocab_v20230424.txt ADDED
The diff for this file is too large to render. See raw diff