| import os |
| import time |
| import json |
| from argparse import ArgumentParser |
| from typing import List |
| from torch import nn |
|
|
| import torch |
| import torch.distributed as dist |
| from transformers import AutoTokenizer |
| from safetensors.torch import load_file, load_model |
|
|
| from model import Transformer, ModelArgs, Block |
| from concurrent.futures import ThreadPoolExecutor |
| from kernel import softmax_q21, softmax_q19 |
|
|
# When True, generate() dumps per-position input tensors under zkDataDir so a
# zero-knowledge proof of the inference can be produced later.
snark = True
# Output directory for the zk witness dumps (relative to the working dir).
zkDataDir = '../zkdata'


# Filled in by main() with the Transformer instance; generate() streams one
# Block at a time through it to bound GPU memory.
model = None
# Per-layer KV / positional-encoding caches kept on the host between steps.
# NOTE(review): `[tensor] * 61` makes all 61 slots alias ONE zero tensor.
# This appears safe only because generate() calls `.to('cuda')` (which
# copies) and then reassigns each slot with the layer's own cache; if any
# code mutated a slot in place before that, every layer would see the
# change. Confirm before reusing this pattern.
kv_caches = [ torch.zeros(1, 4096 * 4, 512, dtype=torch.int64) ] * 61
pe_caches = [ torch.zeros(1, 4096 * 4, 64, dtype=torch.int64) ] * 61
# Per-layer state dicts, loaded from layer-{i}.safetensors in main().
state_dicts = [None] * 61
|
|
def getF32PrintStr(ele):
    """Render a scalar float32 tensor as a sign/mantissa/exponent string.

    Decodes the IEEE-754 single-precision bit pattern of ``ele`` and returns
    e.g. ``'(1+4194304/8388608)*2^0'`` for 1.5, with a leading ``-`` for
    negative values. Subnormals/Inf/NaN are not special-cased: the stored
    fields are printed as if the value were a normal number.

    Args:
        ele: A scalar (0-d or 1-element) float32 tensor, on any device.

    Returns:
        str: Human-readable decomposition of the float's bit pattern.
    """
    # view(torch.int32) + mask instead of view(torch.uint32): unsigned views
    # require torch >= 2.3, while the signed view works on older releases
    # and yields the same 32 raw bits after masking.
    bits = int(ele.cpu().view(torch.int32).item()) & 0xFFFFFFFF
    exponent = (bits >> 23 & 0xFF) - 127          # unbiased exponent
    mantissa = bits & 0x7FFFFF                    # 23-bit fraction field
    sign = '-' if bits & 0x80000000 else ''
    return f'{sign}(1+{mantissa}/8388608)*2^{exponent}'
|
|
def getBF16PrintStr(ele):
    """Render a scalar bfloat16 tensor as a sign/mantissa/exponent string.

    Decodes the bfloat16 bit pattern (1 sign, 8 exponent, 7 mantissa bits)
    of ``ele`` and returns e.g. ``'(1+64/128)*2^0'`` for 1.5, with a leading
    ``-`` for negative values. Subnormals/Inf/NaN are not special-cased.

    Args:
        ele: A scalar (0-d or 1-element) bfloat16 tensor, on any device.

    Returns:
        str: Human-readable decomposition of the bfloat16 bit pattern.

    Fixes vs. the previous version: drops the unused local ``rraw``, unbiases
    the exponent once instead of at each use, and uses a signed int16 view
    plus mask (torch.uint16 views require torch >= 2.3).
    """
    v = int(ele.cpu().view(torch.int16).item()) & 0xFFFF
    ex = str((v >> 7 & 0xFF) - 127)        # unbiased exponent, formatted once
    r = '(1+' + str(v & 0x7F) + '/128)'    # 7-bit fraction field
    if v & 0x8000:
        vstr = '-' + r + '*2^' + ex
    else:
        vstr = r + '*2^' + ex
    return vstr
|
|
def mem(i):
    """Print a one-line CUDA memory report (in MiB) tagged with label *i*."""
    mib = 1024 ** 2
    allocated = torch.cuda.memory_allocated() / mib
    reserved = torch.cuda.memory_reserved() / mib
    peak = torch.cuda.max_memory_allocated() / mib
    print(f"{i} allocated={allocated:.1f}MB, reserved={reserved:.1f}MB, max={peak:.1f}MB", flush=True)
|
|
def load_model2(ckpt_path):
    """Load the non-layer weights (embedding, final norm, output head) from
    their ``*_int.safetensors`` files in *ckpt_path* into the global model."""
    global model

    with torch.device("cuda"):
        for part in ("embed", "norm", "head"):
            load_model(getattr(model, part), os.path.join(ckpt_path, f"{part}_int.safetensors"))
| load_model(model.head, os.path.join(ckpt_path, f"head_int.safetensors")) |
|
|
| |
def sample(logits, temperature: float = 1.0):
    """
    Select next-token id(s) from quantized integer logits.

    With the hard-coded ``sample_open`` switch False (the current state),
    this is plain greedy decoding: argmax over the last dimension. When
    enabled, the experimental path performs fixed-point temperature scaling,
    a custom-kernel softmax (``softmax_q21``), and exponential-noise division
    (Gumbel-trick-style sampling) — all in integer arithmetic.

    Args:
        logits (torch.Tensor): Integer-domain logits from the quantized model.
            NOTE(review): the debug path indexes ``probs[0][0][tid]``, which
            implies a 3-D layout after ``unsqueeze`` — confirm the exact
            shape against ``finish_inference``'s output.
        temperature (float, optional): Temperature for scaling logits.
            Only consulted on the disabled sampling path. Defaults to 1.0.

    Returns:
        torch.Tensor: Result of ``argmax(dim=-1)`` over the (possibly
        noise-divided) probabilities.
    """

    # Debug/experiment switch: False = deterministic argmax decoding.
    sample_open = False

    if sample_open:
        maxx = logits.abs().max()
        typ = logits.type()
        print(f'sample logits type: {typ}, shape: {logits.shape}, abs max: {maxx}')
        if temperature > 1e-5:
            # NOTE(review): int(temperature) truncates toward zero, so any
            # temperature in (0, 1) — including the CLI default 0.2 — gives
            # temp_int == 0 and the floor division below divides by zero.
            # Latent bug; currently unreachable while sample_open is False.
            temp_int = int(temperature)

            logits = logits // temp_int
            print(f'temp_int: {temp_int}', flush=True)
        else:
            # Effectively zero temperature: scale logits up so the maximum
            # dominates after quantized softmax.
            logits = logits * (10 ** 5)

        # Add a trailing axis; softmax_q21 apparently expects this layout
        # (custom fixed-point CUDA kernel — semantics not visible here).
        logits = logits.unsqueeze(2)

        max0 = logits.abs().max()
        print(f'sample 2233 logits shape: {logits.shape}, abs max0: {max0}')

        # Fixed-point softmax written into a preallocated int64 buffer.
        probs = torch.empty_like(logits, dtype=torch.int64, device='cuda')
        softmax_q21(logits.contiguous(), probs)

        probs = probs.squeeze(2)

        typ2 = probs.type()
        max1 = probs.abs().max()
        print(f'sample 33 probs type: {typ2}, shape: {probs.shape}, probs abs max: {max1}', flush=True)

        # Exponential(1) noise: argmax of probs / noise mimics categorical
        # sampling (Gumbel-trick style), done here in fixed point.
        rand = torch.empty_like(probs, dtype=torch.float32, device='cuda').exponential_(1)
        rand_abs = rand.abs()
        rmin = getF32PrintStr(rand_abs.min())
        rmax = getF32PrintStr(rand_abs.max())
        print(f'sample 333 rand abs min: {rmin}, max: {rmax}', flush=True)

        # Quantize the noise at scale 2^10 and offset by 2^4 — presumably to
        # keep every divisor strictly positive after rounding (confirm).
        rand = (rand * (2 ** 10)).round().to(torch.int64) + (2 ** 4)
        max2 = rand.abs().max()
        min2 = rand.abs().min()
        print(f'sample 55 rand abs min: {min2}, max: {max2}', flush=True)

        # Fixed-point division of probabilities by the quantized noise.
        probs = (probs * (2 ** 10)) // rand

        max3 = probs.abs().max()
        print(f'sample 66 probs abs max: {max3}', flush=True)

        res = probs.argmax(dim=-1)
        tid = res[0][0].item()
        tv = probs[0][0][tid]
        randv = rand[0][0][tid]

        print(f'sample 44 res: {res}, tid: {tid}, tv: {tv}, randv: {randv}')
    else:
        # Greedy path: argmax over the unsqueezed logits.
        probs = logits.unsqueeze(2)
        max3 = probs.abs().max()
        print(f'sample 66 probs abs max: {max3}', flush=True)

        res = probs.argmax(dim=-1)
    return res
|
|
def saveTensor(fileName, t):
    """Write tensor *t* to *fileName* as raw C-order (row-major) bytes.

    The tensor is detached from autograd, moved to CPU if necessary, and
    made contiguous so the byte dump matches its logical layout.

    Args:
        fileName (str): Destination path; the file is overwritten.
        t (torch.Tensor): Tensor to serialize (must have a numpy-compatible
            dtype, e.g. the int64 witness tensors used by generate()).

    Fix: the previous version opened the file twice — first in text mode
    ("w", truncating it and never writing), then again in "wb" with the
    inner handle shadowing the outer one. A single binary open suffices.
    """
    t = t.detach()
    if t.device.type != "cpu":
        t = t.cpu()
    t = t.contiguous()
    with open(fileName, "wb") as f:
        f.write(t.numpy().tobytes(order="C"))
|
|
| |
| |
| |
| |
| |
| |
|
|
@torch.inference_mode()
def generate(

    ckpt_path: str,
    args: ModelArgs,
    tokenizer: AutoTokenizer,
    prompt_tokens: List[List[int]],
    max_new_tokens: int,
    eos_id: int,
    temperature: float = 1.0
) -> List[List[int]]:
    """
    Generate new tokens for each prompt, streaming one layer at a time.

    Uses the module-level globals ``model``, ``kv_caches``, ``pe_caches`` and
    ``state_dicts``: for every position and every layer, a fresh ``Block`` is
    constructed on CUDA, its weights loaded from the preloaded state dict,
    its caches attached, run, and then replaced by an empty ``nn.Module`` —
    keeping only one layer's weights resident on the GPU at a time. When the
    module-level ``snark`` flag is set, the input token slice for each step
    is also dumped under ``zkDataDir`` for zero-knowledge proof generation.

    Args:
        ckpt_path (str): Checkpoint directory (passed to each Block).
        args (ModelArgs): Model hyperparameters (n_layers, max_seq_len, ...).
        tokenizer (AutoTokenizer): Used only for debug decoding/printing.
        prompt_tokens (List[List[int]]): Prompt token ids per sequence.
        max_new_tokens (int): Maximum number of new tokens to generate.
        eos_id (int): End-of-sequence token id.
        temperature (float, optional): >0 routes through sample(); otherwise
            greedy argmax. Defaults to 1.0.

    Returns:
        List[List[int]]: Generated (non-prompt) tokens per sequence,
        truncated at the first eos_id.
    """

    # NOTE(review): `layers` is declared global but never read or assigned
    # in this function — likely leftover from an earlier revision.
    global model, layers
    global kv_caches, pe_caches

    prompt_lens = [len(t) for t in prompt_tokens]
    assert max(prompt_lens) <= args.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={args.max_seq_len})"
    total_len = min(args.max_seq_len, max_new_tokens + max(prompt_lens))

    # -1 marks positions not yet filled; doubles as the prompt mask below.
    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")

    for i, t in enumerate(prompt_tokens):
        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")

    beginstr = tokenizer.decode(tokens[0][0:prompt_lens[0]], skip_special_tokens=True)

    print(' ++++++ token:', beginstr, flush=True)

    prev_pos = 0

    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")

    # True wherever the position was supplied by the prompt (never -1).
    prompt_mask = tokens != -1

    # Decode step by step from the shortest prompt onward; each iteration
    # feeds the slice [prev_pos:cur_pos] (prefill on the first pass, a
    # single token afterwards).
    for cur_pos in range(min(prompt_lens), total_len):
        print(f'prev_pos: {prev_pos}, cur_pos: {cur_pos}, total_len: {total_len}', flush=True)
        # NOTE(review): `t` is decoded for debugging but never used.
        t = tokenizer.decode(tokens[0][prev_pos:cur_pos], skip_special_tokens=True)
        print(str(cur_pos) + ' ---------- token list: ' + str(tokens[0][prev_pos:cur_pos].tolist()), flush=True)

        # Dump this step's input tokens as a zk witness.
        if snark:
            os.makedirs(f'{zkDataDir}/pos_{prev_pos}', exist_ok=True)
            saveTensor(f'{zkDataDir}/pos_{prev_pos}/tokens.bin', tokens[0][prev_pos:cur_pos].cpu())

        # Embed the input slice; h flows through the layers below.
        h, start_pos, seqlen = model.prep_inference(tokens[:, prev_pos:cur_pos], prev_pos)
        print('h 1 shape: ' + str(h.shape), flush=True)

        for i in range(args.n_layers):
            print(f'begin layer {i} -----------------', flush=True)
            with torch.device("cuda"):

                # Drop the stale norm weight (if present) before replacing
                # the layer, so the new state dict loads cleanly.
                with torch.no_grad():
                    if hasattr(model.layers[i], 'attn_norm'):
                        del model.layers[i].attn_norm.weight

                # Rebuild layer i on CUDA and attach its persisted caches.
                # load_state_dict(..., False): strict=False.
                model.layers[i] = Block(i, args, ckpt_path)
                model.layers[i].load_state_dict(state_dicts[i], False)
                model.layers[i].attn.kv_cache = kv_caches[i].to('cuda')
                model.layers[i].attn.pe_cache = pe_caches[i].to('cuda')

            h = model.layer_inference(i, h, start_pos, seqlen)

            # Persist the updated caches, then free the layer's weights by
            # swapping in an empty module.
            kv_caches[i] = model.layers[i].attn.kv_cache
            pe_caches[i] = model.layers[i].attn.pe_cache
            model.layers[i] = nn.Module()

        # Debug-only preview of the head output for the last position.
        tmph = model.norm(h)[0][:, -1]

        tmph_abs = tmph.abs()
        tmph_min = tmph_abs.min()
        tmph_max = tmph_abs.max()
        print(f'tmph_abs min: {tmph_min}, max: {tmph_max}', flush=True)

        tmplogits = model.head(tmph[None, :])

        tmp_next_token = tmplogits.argmax(dim=-1)
        tid = tmp_next_token[0][0].item()
        tmp_logit = tmplogits[0][0][tid]
        tmp_completion = tokenizer.decode([tmp_next_token[0][0]], skip_special_tokens=True)
        print(f'position {cur_pos} tid: {tid}, tmp_logit:{tmp_logit}, candidate: {tmp_completion}', flush=True)

        # Actual logits used for decoding.
        logits = model.finish_inference(h)

        if temperature > 0:
            next_token = sample(logits, temperature)
        else:
            next_token = logits.argmax(dim=-1)
        # Prompt positions keep their original token (teacher forcing).
        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)

        tokens[:, cur_pos] = next_token

        # A sequence finishes once it emits eos outside its prompt.
        finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token.view(-1) == eos_id)
        prev_pos = cur_pos

        completion = tokenizer.decode(tokens[0][0:cur_pos+1], skip_special_tokens=True)
        print(f'---------- Result: position {cur_pos}, token: {completion}', flush=True)

        if finished.all():
            break
    # Strip prompts, cap at max_new_tokens, and cut at the first eos.
    completion_tokens = []
    for i, toks in enumerate(tokens.tolist()):
        toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
        if eos_id in toks:
            toks = toks[:toks.index(eos_id)]
        completion_tokens.append(toks)
    return completion_tokens
|
|
|
|
def main(
    ckpt_path: str,
    config: str,
    input_file: str = "",
    interactive: bool = True,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
) -> None:
    """
    Load the model and run interactive or batch text generation.

    Reads WORLD_SIZE/RANK from the environment for multi-process operation;
    in that mode rank 0 reads the prompt and broadcasts it to the others.

    Args:
        ckpt_path (str): Path to the model checkpoint directory (also used
            as the tokenizer source).
        config (str): Path to the ModelArgs JSON configuration file.
        input_file (str, optional): File of prompts for batch mode (one per
            line). Defaults to "".
        interactive (bool, optional): Run a read-eval loop instead of batch
            mode. Defaults to True.
        max_new_tokens (int, optional): Maximum new tokens per generation.
            Defaults to 100.
        temperature (float, optional): Sampling temperature. Defaults to 1.0.
    """
    global model

    world_size = int(os.getenv("WORLD_SIZE", "1"))

    rank = int(os.getenv("RANK", "0"))

    print('WORLD_SIZE: ' + str(world_size) + ', rank: ' + str(rank))

    # NOTE(review): dist.broadcast_object_list / destroy_process_group are
    # used below, but no dist.init_process_group call is visible in this
    # file — presumably removed or done elsewhere; confirm before running
    # with WORLD_SIZE > 1.

    torch.cuda.set_device(0)

    # Default dtype must be set before the model is constructed below.
    torch.set_default_dtype(torch.bfloat16)

    torch.set_num_threads(8)

    # Fixed seed for reproducible runs (affects sample()'s noise).
    torch.manual_seed(965)
    with open(config) as f:
        args = ModelArgs(**json.load(f))
    print(args)

    # Preload every layer's weights on CPU; generate() moves them to the
    # GPU one layer at a time.
    for i in range(args.n_layers):
        modelPath = os.path.join(ckpt_path, f"layer-{i}.safetensors")
        state_dicts[i] = load_file(modelPath, device="cpu")

    with torch.device("cuda"):
        model = Transformer(args)

    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    # Load embedding / final-norm / head weights into the global model.
    load_model2(ckpt_path)

    # Debug: report embedding weight magnitude range (rank 0 only).
    if rank == 0:

        embed_abs = model.embed.weight.detach().cpu().abs()
        abs_min = torch.min(embed_abs)
        abs_max = torch.max(embed_abs)
        print('embed abs_min: ' + str(abs_min), flush=True)
        print('embed abs_max: ' + str(abs_max), flush=True)
    else:
        pass

    if interactive:
        messages = []
        while True:
            # Single process reads directly; otherwise rank 0 reads and
            # broadcasts the prompt so all ranks stay in lockstep.
            if world_size == 1:
                prompt = input(">>> ")
            elif rank == 0:
                prompt = input(">>> ")
                objects = [prompt]
                dist.broadcast_object_list(objects, 0)
            else:
                objects = [None]
                dist.broadcast_object_list(objects, 0)
                prompt = objects[0]
            if prompt == "/exit":
                break
            elif prompt == "/clear":
                # Reset the chat history.
                messages.clear()
                continue

            messages.append({"role": "user", "content": prompt})

            prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

            completion_tokens = generate(ckpt_path, args, tokenizer, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)

            completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
            print(completion)
            messages.append({"role": "assistant", "content": completion})
    else:
        # Batch mode: one single-turn conversation per input line.
        with open(input_file) as f:
            prompts = [line.strip() for line in f.readlines()]
        assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
        prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]

        completion_tokens = generate(ckpt_path, args, tokenizer, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
        completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
        for prompt, completion in zip(prompts, completions):
            print("Prompt:", prompt)
            print("Completion:", completion)
            print()

    if world_size > 1:
        dist.destroy_process_group()
|
|
|
|
| if __name__ == "__main__": |
| """ |
| Command-line interface for distributed text generation. |
| |
| Arguments: |
| --ckpt-path (str): Path to the model checkpoint directory. 模型参数存放的路径。 |
| --config (str): Path to the model configuration file. 模型的超参配置文件的路径。 |
| --input-file (str, optional): File containing prompts for batch processing. 假设我们是批量输入prompt,则该参数是批量输入prompt的文件的路径。 |
| --interactive (bool, optional): Enable interactive mode for generating text. 是否是问答交互式?这里相当于开启模型的“问答”模式。bool变量。 |
| --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200. 限制要求生成的tokens的数量。 |
| --temperature (float, optional): Temperature for sampling. Defaults to 0.2. 采样温度。 |
| |
| Raises: |
| AssertionError: If neither input-file nor interactive mode is specified. |
| """ |
| parser = ArgumentParser() |
| parser.add_argument("--ckpt-path", type=str, required=True) |
| parser.add_argument("--config", type=str, required=True) |
| parser.add_argument("--input-file", type=str, default="") |
| parser.add_argument("--interactive", action="store_true") |
| parser.add_argument("--max-new-tokens", type=int, default=200) |
| parser.add_argument("--temperature", type=float, default=0.2) |
| args = parser.parse_args() |
| assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified" |
| main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature) |
|
|