Jellyfish042 committed
Commit eb04823 · Parent(s): 88a3875

Fix launch args and add decompression progress

Files changed (2):
  1. app.py (+23 -8)
  2. llm_compressor.py (+14 -2)
app.py CHANGED
@@ -1,4 +1,5 @@
 import base64
+import inspect
 import os
 import shutil
 import tempfile
@@ -262,14 +263,21 @@ def compress_ui(text, context_window, progress=gr.Progress()):
     return b64, stats_text, file_path


-def decompress_ui(b64_data, file_data, context_window):
+def decompress_ui(b64_data, file_data, context_window, progress=gr.Progress()):
     raw = _get_compressed_bytes(b64_data, file_data)
     model_path = _resolve_default_model_path()
     tokenizer_path = _resolve_default_tokenizer_path()
     requested_strategy = os.getenv("RWKV_STRATEGY", "cpu fp32")
     effective_strategy = _resolve_strategy()
     model, tokenizer = _load_model_and_tokenizer(model_path, tokenizer_path, effective_strategy)
-    text, stats = decompress_bytes(raw, model, tokenizer, context_window=context_window)
+    text, stats = decompress_bytes(
+        raw,
+        model,
+        tokenizer,
+        context_window=context_window,
+        progress=progress,
+        progress_desc="Decompressing",
+    )
     stats_text = _format_decompress_stats(stats, char_count=len(text))
     if effective_strategy != requested_strategy:
         stats_text += "\n- Strategy: cpu fp32 (forced, CUDA unavailable)"
@@ -349,9 +357,16 @@ def build_ui():


 if __name__ == "__main__":
-    build_ui().queue(max_size=16).launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        share=False,
-        show_api=False,
-    )
+    launch_kwargs = {
+        "server_name": "0.0.0.0",
+        "server_port": 7860,
+        "share": False,
+    }
+    try:
+        launch_params = inspect.signature(gr.Blocks.launch).parameters
+        if "show_api" in launch_params:
+            launch_kwargs["show_api"] = False
+    except (TypeError, ValueError):
+        pass
+
+    build_ui().queue(max_size=16).launch(**launch_kwargs)
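Note: the new __main__ block feature-detects the show_api keyword instead of passing it unconditionally, since only some Gradio releases accept it on launch(). A minimal sketch of that version probe, extracted for reuse; the supports_kwarg helper name is hypothetical, not part of this repo:

import inspect

def supports_kwarg(func, name):
    # True when `func` explicitly declares `name` as a parameter.
    # A callable that only takes **kwargs is reported as not
    # supporting `name`, matching the probe in __main__ above.
    try:
        return name in inspect.signature(func).parameters
    except (TypeError, ValueError):
        # Builtins and some C-implemented callables expose no signature.
        return False

# Usage mirroring the __main__ block:
# if supports_kwarg(gr.Blocks.launch, "show_api"):
#     launch_kwargs["show_api"] = False

Probing the signature at runtime keeps the Space launchable across Gradio versions without pinning an exact release.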
llm_compressor.py CHANGED
@@ -289,7 +289,14 @@ def compress_text(text, model, tokenizer, context_window=2048):
     return compress_tokens(tokens, model, context_window=context_window, original_bytes=original_bytes)


-def decompress_bytes(data, model, tokenizer, context_window=2048):
+def decompress_bytes(
+    data,
+    model,
+    tokenizer,
+    context_window=2048,
+    progress=None,
+    progress_desc="Decompressing",
+):
     if context_window <= 0:
         raise ValueError("context_window must be positive.")
     if not data or len(data) < 4:
@@ -306,9 +313,12 @@ def decompress_bytes(data, model, tokenizer, context_window=2048):
     context_tokens = []
     state = None
     start_time = time.time()
+    if progress is not None:
+        progress((0, total_tokens), desc=progress_desc, unit="token")
+    progress_step = max(1, total_tokens // 100)

     with torch.inference_mode():
-        for _ in range(total_tokens):
+        for idx in range(total_tokens):
             if len(context_tokens) >= context_window:
                 context_tokens = []
                 state = None
@@ -334,6 +344,8 @@ def decompress_bytes(data, model, tokenizer, context_window=2048):
             low_val = int(cdf[target_token_id - 1].item()) if target_token_id > 0 else 0
             high_val = int(cdf[target_token_id].item())
             decoder.update_range(low_val, high_val, total_count)
+            if progress is not None and (idx + 1 == total_tokens or (idx + 1) % progress_step == 0):
+                progress((idx + 1, total_tokens), desc=progress_desc, unit="token")

     text = decode_tokens(tokenizer, decoded_tokens)
     duration = time.time() - start_time
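Note: decompress_bytes now takes an optional callback with the gr.Progress call shape, progress((done, total), desc=..., unit=...), and throttles updates to roughly one per percent via progress_step = max(1, total_tokens // 100), so the callback fires on the order of 100 times regardless of stream length; the idx + 1 == total_tokens branch guarantees a final 100% tick. Because progress defaults to None, the function stays usable outside Gradio. A hypothetical console stand-in with the same call shape might look like:

def console_progress(value, desc="Working", unit="steps"):
    # Hypothetical stand-in for gr.Progress when running from a script:
    # accepts the same ((done, total), desc=..., unit=...) calls that
    # decompress_bytes makes above.
    done, total = value
    pct = 100.0 * done / max(total, 1)
    print(f"\r{desc}: {done}/{total} {unit} ({pct:.0f}%)", end="", flush=True)

# text, stats = decompress_bytes(
#     raw, model, tokenizer, context_window=2048, progress=console_progress
# )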