import logging
import time
from pathlib import Path
from typing import Callable, Tuple

import numpy as np
import torch
from omegaconf import OmegaConf

from apps.main.generate import load_consolidated_model_and_tokenizer
from apps.main.transformer import LMTransformer, LMTransformerArgs
from lingua.args import dataclass_from_dict
from lingua.checkpoint import (
    CONSOLIDATE_FOLDER,
    CONSOLIDATE_NAME,
    consolidate_checkpoints,
)
from m1_compression import arithmetic_coder, utils

logger = logging.getLogger()

# The model predicts a distribution over raw bytes.
ALPHABET_SIZE = 256
# Base 2 means that the coder writes bits.
ARITHMETIC_CODER_BASE = 2
# Precision 16 implies 16 bit arithmetic, in the original paper it is 32.
ARITHMETIC_CODER_PRECISION = 32
# Window size in bits for windowed compression; tune as needed.
WINDOW_SIZE = 32


def _resolve_consolidate_path(consolidated_path: str) -> str:
    """Locate (or build) the consolidated checkpoint directory.

    Shared by both model loaders so the resolution logic stays in one place.

    Args:
        consolidated_path: Path to a checkpoint directory. It may already be
            consolidated (contains ``params.json`` and a ``*.pth`` file), or
            contain a consolidated subfolder; otherwise consolidation is run.

    Returns:
        The consolidated checkpoint directory as a string.
    """
    ckpt_dir = Path(consolidated_path)
    if (
        ckpt_dir.exists()
        and (ckpt_dir / "params.json").exists()
        and next(ckpt_dir.glob("*.pth"), None) is not None
    ):
        consolidate_path = ckpt_dir
    else:
        consolidate_path = ckpt_dir / CONSOLIDATE_FOLDER
        if not consolidate_path.exists():
            consolidate_path = consolidate_checkpoints(ckpt_dir)
    return str(consolidate_path)


def load_m1_model_and_tokenizer(consolidated_path: str):
    """Load the consolidated model and tokenizer, and build a predict function.

    Args:
        consolidated_path: Path to the model checkpoint.

    Returns:
        Tuple of ``(model, tokenizer, predict_fn)`` with the model in eval mode.
    """
    consolidate_path = _resolve_consolidate_path(consolidated_path)
    logger.info("Loading model")
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        consolidate_path,
        model_cls=LMTransformer,
        model_args_cls=LMTransformerArgs,
    )
    logger.info("Model loaded")
    model.eval()
    predict_fn = get_predict_fn(model, tokenizer)
    return model, tokenizer, predict_fn


def load_m1_model_cpu(consolidated_path: str):
    """Load the model on CPU from a consolidated checkpoint.

    Args:
        consolidated_path: Path to the model checkpoint.

    Returns:
        The model in eval mode, with parameters cast to the dtype declared in
        the checkpoint's ``params.json`` (fp32/fp16/bf16).
    """
    consolidate_path = _resolve_consolidate_path(consolidated_path)
    logger.info("Loading model")
    ckpt_path = Path(consolidate_path)
    config = OmegaConf.load(ckpt_path / "params.json")
    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
        config.distributed.model_dtype
    ]
    model_args = dataclass_from_dict(LMTransformerArgs, config.model, strict=False)
    model = LMTransformer(model_args)
    st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
    model.load_state_dict(st_dict["model"])
    model = model.eval()
    # Cast weights in place to the configured dtype.
    for param in model.parameters():
        param.data = param.data.to(dtype=param_dtype)
    logger.info("Model loaded")
    model.eval()
    return model


def get_predict_fn(model, tokenizer):
    """Build a closure that maps a token sequence to next-token distributions.

    Args:
        model: The model used for prediction (runs on CUDA).
        tokenizer: The tokenizer used for encoding/decoding (unused here but
            kept for interface symmetry with the loaders).

    Returns:
        ``predict_fn(input_sequence) -> np.ndarray`` of shape
        ``(batch_size, seq_len, ALPHABET_SIZE)``.
    """

    def predict_fn(input_sequence: np.ndarray) -> np.ndarray:
        """Return per-position probability distributions for *input_sequence*.

        Args:
            input_sequence: Token IDs of shape ``(batch_size, seq_len)`` or
                ``(seq_len,)``; an empty array yields a uniform distribution.

        Returns:
            Probabilities of shape ``(batch_size, seq_len, ALPHABET_SIZE)``.
        """
        # Empty input: fall back to a uniform distribution over the alphabet.
        # Both the compressor and decompressor rely on this initial prior.
        if input_sequence.size == 0:
            return np.ones((1, 1, ALPHABET_SIZE), dtype=np.float32) / ALPHABET_SIZE

        input_tensor = torch.tensor(input_sequence, dtype=torch.long).cuda()
        if input_tensor.dim() == 1:
            input_tensor = input_tensor.unsqueeze(0)

        with torch.no_grad():
            logits = model(input_tensor)
            # Keep only the byte alphabet slice of the vocabulary.
            logits = logits[..., :ALPHABET_SIZE]
            logits = logits.float()
            assert torch.isfinite(logits).all(), "Logits contain NaN or Inf values."
            probs = torch.softmax(logits, dim=-1)
            probs = probs.float().cpu().numpy()
        return probs

    return predict_fn


def m1_arithmetic_compress(
    data: bytes,
    predict_fn: Callable,
    return_num_padded_bits: bool = True,
    use_slow_lossless_compression: bool = True,
) -> bytes | tuple[bytes, int]:
    """Compress *data* with a language model driving an arithmetic coder.

    Args:
        data: Raw bytes to compress.
        predict_fn: Callable returning next-token distributions (see
            :func:`get_predict_fn`).
        return_num_padded_bits: When True, also return the number of zero bits
            appended to round the bitstream up to whole bytes.
        use_slow_lossless_compression: When True, run the model once per
            symbol (prefix by prefix); when False, run it once over the whole
            sequence and shift the predictions by one position.

    Returns:
        The compressed bytes, optionally paired with the padding bit count.
    """
    sequence_array = np.frombuffer(data, dtype=np.uint8)

    if sequence_array.size == 0:
        # Empty input: nothing to encode, the coder only terminates.
        # (Previously both branches produced a mis-shaped probs array here
        # and tripped the shape assertions below.)
        probs = np.empty((0, ALPHABET_SIZE), dtype=np.float32)
    elif use_slow_lossless_compression:
        step_probs = []
        for k in range(len(sequence_array)):
            # Feed the first k tokens (empty when k == 0) to predict token k+1.
            input_seq = sequence_array[:k]
            if input_seq.size == 0:
                # Uniform prior for the very first symbol.
                current_probs = (
                    np.ones(ALPHABET_SIZE, dtype=np.float32) / ALPHABET_SIZE
                )
            else:
                model_probs = predict_fn(input_seq[None])  # (1, k, ALPHABET_SIZE)
                # Last position predicts the next (k+1-th) token.
                current_probs = model_probs[0, -1]
            step_probs.append(current_probs)
        probs = np.array(step_probs)
    else:
        full_probs = predict_fn(sequence_array[None])[0, ...]
        # Shift predictions right by one position; position 0 gets the
        # uniform prior, matching the slow path exactly.
        probs = np.concatenate(
            [
                np.ones(ALPHABET_SIZE, dtype=np.float32)[None] / ALPHABET_SIZE,
                full_probs[:-1, ...],
            ],
            axis=0,
        )

    if probs.size:
        probs /= probs.sum(axis=-1, keepdims=True)
        assert np.isclose(probs[0].sum(), 1, atol=1e-6), "The probs is not normalized."
    assert probs.shape[1] == ALPHABET_SIZE, "The shape of probs is not correct."
    assert probs.shape[0] == len(sequence_array), "The shape of probs is not correct."

    output = []
    encoder = arithmetic_coder.Encoder(
        base=ARITHMETIC_CODER_BASE,
        precision=ARITHMETIC_CODER_PRECISION,
        output_fn=output.append,
    )
    for pdf, symbol in zip(probs, sequence_array):
        encoder.encode(utils.normalize_pdf_for_arithmetic_coding(pdf), symbol)
    encoder.terminate()

    compressed_bits = "".join(map(str, output))
    # Zero-pad the bitstream up to a whole number of bytes.
    compressed_bytes, num_padded_bits = utils.bits_to_bytes(compressed_bits)
    if return_num_padded_bits:
        return compressed_bytes, num_padded_bits
    return compressed_bytes


def m1_arithmetic_decompress(
    compressed: bytes,
    predict_fn: Callable,
    num_padded_bits: int,
    length: int,
) -> bytes:
    """Invert :func:`m1_arithmetic_compress`.

    Args:
        compressed: Compressed byte stream.
        predict_fn: The same prediction function used for compression.
        num_padded_bits: Number of padding bits appended during compression.
        length: Number of symbols (bytes) to decode.

    Returns:
        The decompressed data as bytes.
    """
    bits = utils.bytes_to_bits(compressed, num_padded_bits=num_padded_bits)
    data_iter = iter(bits)

    def _input_fn() -> int | None:
        # Feed the decoder one bit at a time; None signals end of stream.
        try:
            return int(next(data_iter))
        except StopIteration:
            return None

    decoder = arithmetic_coder.Decoder(
        base=ARITHMETIC_CODER_BASE,
        precision=ARITHMETIC_CODER_PRECISION,
        input_fn=_input_fn,
    )

    sequence_array = np.empty((0,), dtype=np.uint8)
    for k in range(length):
        # Predict token k+1 from the k tokens decoded so far.
        if k == 0:
            # First token: uniform prior, mirroring the compressor.
            current_probs = np.ones(ALPHABET_SIZE, dtype=np.float32) / ALPHABET_SIZE
        else:
            model_probs = predict_fn(sequence_array[None])  # (1, k, ALPHABET_SIZE)
            current_probs = model_probs[0, -1]
        # Renormalize in case the model output is not exactly normalized.
        current_probs /= current_probs.sum()
        token = decoder.decode(utils.normalize_pdf_for_arithmetic_coding(current_probs))
        # Concatenate with an explicit uint8 array: np.append would silently
        # widen the dtype and corrupt the byte stream with leading zeros.
        sequence_array = np.concatenate(
            [sequence_array, np.array([token], dtype=np.uint8)]
        )
    return sequence_array.tobytes()


def m1_arithmetic_compress_with_windows(
    data: bytes,
    predict_fn: Callable,
    window_bit_size=WINDOW_SIZE,
    return_num_padded_bits: bool = True,
    use_slow_lossless_compression: bool = True,
) -> bytes | Tuple[bytes, list[int], list[int]]:
    """Compress *data* in windows whose compressed size fits *window_bit_size* bits.

    Bytes are accumulated greedily into a window; as soon as the window's
    compressed bitstream would exceed ``window_bit_size``, the window is sealed
    at its previous contents and a new window starts with the current byte.

    Args:
        data: Raw bytes to compress.
        predict_fn: Prediction function (see :func:`get_predict_fn`).
        window_bit_size: Bit budget per compressed window.
        return_num_padded_bits: When True, also return per-window padding bit
            counts and per-window original lengths.
        use_slow_lossless_compression: Forwarded to :func:`m1_arithmetic_compress`.

    Returns:
        Concatenated compressed bytes, optionally with the two per-window lists.
    """
    compressed_windows = []
    num_padded_bits_per_window = []
    original_lengths_per_window = []
    current_window = []

    for byte in data:
        # Tentatively add this byte to the window.
        current_window.append(byte)
        # Compress and measure the resulting bitstream.
        compressed, num_padded = m1_arithmetic_compress(
            bytes(current_window),
            predict_fn=predict_fn,
            return_num_padded_bits=True,
            use_slow_lossless_compression=use_slow_lossless_compression,
        )
        compressed_bits = "".join(map(str, utils.bytes_to_bits(compressed, num_padded)))
        # Over budget: seal the window at its previous contents.
        if len(compressed_bits) > window_bit_size:
            current_window.pop()
            # Re-compress the sealed window and record it.
            compressed, num_padded = m1_arithmetic_compress(
                bytes(current_window),
                predict_fn=predict_fn,
                return_num_padded_bits=True,
                use_slow_lossless_compression=use_slow_lossless_compression,
            )
            compressed_bits = "".join(
                map(str, utils.bytes_to_bits(compressed, num_padded))
            )
            compressed_windows.append(compressed)
            num_padded_bits_per_window.append(num_padded)
            original_lengths_per_window.append(len(current_window))
            # Start the next window with the byte that overflowed.
            current_window = [byte]

    # Flush the trailing partial window.
    if current_window:
        compressed, num_padded = m1_arithmetic_compress(
            bytes(current_window),
            predict_fn=predict_fn,
            return_num_padded_bits=True,
            use_slow_lossless_compression=use_slow_lossless_compression,
        )
        compressed_windows.append(compressed)
        num_padded_bits_per_window.append(num_padded)
        original_lengths_per_window.append(len(current_window))

    all_compressed_bytes = b"".join(compressed_windows)
    if return_num_padded_bits:
        return all_compressed_bytes, num_padded_bits_per_window, original_lengths_per_window
    return all_compressed_bytes


def m1_arithmetic_decompress_with_windows(
    compressed: bytes,
    predict_fn: Callable,
    window_bit_size,
    num_padded_bits_per_window: list[int],
    original_lengths_per_window: list[int],
) -> bytes:
    """Invert :func:`m1_arithmetic_compress_with_windows`.

    Args:
        compressed: Concatenated per-window compressed bytes.
        predict_fn: The same prediction function used for compression.
        window_bit_size: Bit budget per window used at compression time.
        num_padded_bits_per_window: Per-window padding bit counts.
        original_lengths_per_window: Per-window original byte lengths.

    Returns:
        The decompressed byte stream.
    """
    decoded_bytes = b""
    start = 0
    bitstream = utils.bytes_to_bits(compressed)
    # NOTE(review): slicing at a fixed stride of window_bit_size assumes every
    # stored window (padding included) occupies exactly window_bit_size bits;
    # a window whose padded bitstream is shorter would desynchronize the
    # stream — confirm against the compressor's output format.
    for num_padded, length in zip(
        num_padded_bits_per_window, original_lengths_per_window
    ):
        # Carve this window's bits out of the concatenated stream.
        window_bitstream = bitstream[start : start + window_bit_size]
        # Convert the bits back to bytes for the scalar decompressor.
        window_compressed, _ = utils.bits_to_bytes(window_bitstream)
        print(f"当前窗口压缩数据: {window_compressed}")
        decoded_window = m1_arithmetic_decompress(
            window_compressed,
            predict_fn,
            num_padded,
            length,
        )
        print(f"解压缩窗口: {decoded_window}")
        print(f"解压缩窗口的长度: {len(decoded_window)}")
        decoded_bytes += decoded_window
        start += window_bit_size
    return decoded_bytes


def test_equal_window_compression(sequence: str):
    """Round-trip *sequence* through windowed compression and report stats."""
    print(f"测试序列: {sequence}")
    model, tokenizer, predict_fn = load_m1_model_and_tokenizer(
        consolidated_path="/mnt/bn/tiktok-mm-5/aiic/users/linzheng/artifacts/m1_checkpoints/m1_1M_steps10k_bs32_seqlen2048/checkpoints/0000010000"
    )
    original_bytes = tokenizer.encode(sequence)
    print(f"token 字节: {original_bytes}")
    original_bytes = bytes(original_bytes)
    print(f"转为字节流: {original_bytes}")

    compressed_data = m1_arithmetic_compress_with_windows(
        original_bytes,
        predict_fn=predict_fn,
        window_bit_size=WINDOW_SIZE,
        return_num_padded_bits=True,
        use_slow_lossless_compression=False,
    )
    compressed_bytes, num_padded_bits_per_window, original_lengths_per_window = (
        compressed_data
    )
    print(f"压缩后的字节: {compressed_bytes}")
    print(f"窗口填充位数数组: {num_padded_bits_per_window}")
    print(f"窗口原始长度数组: {original_lengths_per_window}")

    try:
        decoded_bytes = m1_arithmetic_decompress_with_windows(
            compressed_bytes,
            predict_fn,
            WINDOW_SIZE,
            num_padded_bits_per_window,
            original_lengths_per_window,
        )
        decoded_sequence = decoded_bytes.decode("utf-8")
        assert (
            decoded_sequence == sequence
        ), f"解码失败:原始={sequence}, 解码={decoded_sequence}"
    # Was a bare `except:`, which also swallows KeyboardInterrupt/SystemExit.
    except Exception:
        print(f"解码失败")

    compression_ratio = len(original_bytes) / len(compressed_bytes)
    print(f"原始大小: {len(original_bytes)} bytes")
    print(f"压缩后大小: {len(compressed_bytes)} bytes")
    print(f"各窗口填充位数: {num_padded_bits_per_window}")
    print(f"各窗口原始长度: {original_lengths_per_window}")
    print(f"压缩率: {compression_ratio:.2f}x")


if __name__ == "__main__":
    test_sequences = [
        "if month % 2 : return 30",
        "def __init__(self, name): self.name = name",
        "import pandas as pd\n import matplotlib.pyplot as plt",
        "import torch",
        # A comma was missing here, so the two strings below were silently
        # concatenated into one test case by implicit string concatenation.
        "import torch\nimport torch.nn as nn\nimport torch.nn.functional as nn",
        "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F" * 100,
    ]
    for seq in test_sequences:
        print(f"=== m1语言模型 Equal_Window 算术压缩 ===")
        start_time = time.time()
        test_equal_window_compression(seq)
        end_time = time.time()
        print(f"用时: {end_time - start_time} 秒")