# Copyright (c) Meta Platforms, Inc. and affiliates.
import time
from dataclasses import dataclass
from omegaconf import OmegaConf
import torch
from torch import nn
from lingua.args import dataclass_from_dict
from lingua.tokenizer import Tokenizer
from apps.main.generate import (
PackedCausalTransformerGenerator,
PackedCausalTransformerGeneratorArgs,
load_consolidated_model_and_tokenizer,
)
from apps.mamba.core_mamba import SSM
from apps.mamba.mamba import LMMambaArgs, LMMamba, StateCache
@dataclass
class PackedCausalMambaGeneratorArgs(PackedCausalTransformerGeneratorArgs):
    """Generation arguments for the packed-causal Mamba generator.

    Currently inherits every field from
    ``PackedCausalTransformerGeneratorArgs`` unchanged; this subclass
    exists as the hook point for Mamba-specific options.
    """
class PackedCausalMambaGenerator(PackedCausalTransformerGenerator):
    """Packed-causal generator specialized for Mamba (SSM) language models.

    Reuses the transformer generator's sampling loop (`generate`) but
    replaces the KV-cache machinery with per-layer SSM state caches.
    """

    def __init__(
        self,
        cfg: PackedCausalMambaGeneratorArgs,
        model: nn.Module,
        tokenizer: Tokenizer,
    ):
        # NOTE(review): deliberately does not call super().__init__() —
        # presumably the parent allocates transformer KV caches that do
        # not apply here; confirm against the parent class.
        self.model = model
        self.tokenizer = tokenizer

        # Sampling controls.
        self.temperature = cfg.temperature
        self.top_p = cfg.top_p
        self.top_k = cfg.top_k

        # Length budgets.
        self.max_gen_len = cfg.max_gen_len
        self.max_tokens = cfg.max_tokens
        self.max_prompt_len = cfg.max_prompt_len

        # Stop strings; remember the longest so the stop-scan knows how
        # large a suffix window it needs.
        self.until = cfg.until
        self.max_until_size = max(len(s) for s in self.until) if self.until else 1

        self.device = cfg.device

        # Optionally compile the two hot paths; `disable=True` makes
        # torch.compile a no-op wrapper.
        self.prefill = torch.compile(self.prefill, disable=not cfg.compile_prefilling)
        self.generate_next_token = torch.compile(
            self.generate_next_token,
            mode="reduce-overhead",
            disable=not cfg.reduce_generation_overhead,
        )

        self.show_progress = cfg.show_progress
        self.dtype = {"fp32": torch.float32, "bf16": torch.bfloat16}[cfg.dtype]

        # Populated by setup_prefilling() before each batch.
        self.prefill_tok_id = None
        self.cu_seqlens = None

    def clear_cache(self, lengths: torch.Tensor):
        """Install a fresh state cache in every SSM layer of the model.

        `lengths` is a 1-D tensor of per-sequence prompt lengths; its
        size(0) is the batch size the new caches are sized for.
        """
        batch_size = lengths.size(0)
        for module in self.model.modules():
            if not isinstance(module, SSM):
                continue
            module.cache = StateCache(
                batch_size,
                module.n_heads,
                module.head_dim,
                module.state_dim,
                module.conv_size,
                module.conv_dim,
                self.dtype,
                self.device,
            )

    @torch.compiler.disable
    def setup_prefilling(self, lengths: torch.Tensor):
        """Reset caches and build packed-batch index tensors for prefill."""
        self.clear_cache(lengths)
        # repeat_interleave(lengths) repeats index j lengths[j] times,
        # mapping every packed token to its sequence index.
        self.prefill_tok_id = torch.repeat_interleave(lengths).unsqueeze(0).int()
        # Cumulative sequence boundaries with a leading 0:
        # [0, l0, l0+l1, ...].
        running = lengths.cumsum(0)
        zero = torch.tensor([0], device=self.device)
        self.cu_seqlens = torch.cat([zero, running]).int()

    @torch.compiler.disable
    def setup_generation(self, lengths):
        # No per-step setup needed: the SSM caches carry all decode state.
        pass

    def prefill(self, tokens: torch.Tensor, lengths: torch.Tensor):
        """Run the model over all packed prompt tokens, filling SSM caches."""
        self.setup_prefilling(lengths=lengths)
        return self.model.forward(
            tokens,
            tok_idx=self.prefill_tok_id,
            cu_seqlens=self.cu_seqlens,
            ssm_impl="ssm",
        )

    def generate_next_token(self, current_token):
        """Single decode step advancing the cached SSM state."""
        return self.model.forward(
            current_token,
            tok_idx=None,
            cu_seqlens=None,
            ssm_impl="ssm_update",
        )

    def generate(self, prompts):
        """Delegate to the parent's packed sampling loop."""
        return super().generate(prompts)
def main():
    """Interactive CLI: load a consolidated Mamba checkpoint and generate.

    Reads config overrides from the command line via OmegaConf (must
    include `ckpt`), collects prompts from stdin until an empty line,
    then prints each generation and an overall tokens/second figure.
    """
    # CLI overrides; gen_cfg extracts only the generator-relevant fields.
    cfg = OmegaConf.from_cli()
    gen_cfg = dataclass_from_dict(PackedCausalMambaGeneratorArgs, cfg, strict=False)
    print(cfg)

    model, tokenizer, _ = load_consolidated_model_and_tokenizer(
        cfg.ckpt, model_cls=LMMamba, model_args_cls=LMMambaArgs
    )
    generator = PackedCausalMambaGenerator(gen_cfg, model, tokenizer)

    # Collect prompts until the user enters an empty line.
    prompts = []
    while True:
        prompt = input("Enter a prompt (or press enter to finish): ")
        if not prompt:
            break
        prompts.append(prompt)

    # Nothing to generate: bail out instead of calling generate([]).
    if not prompts:
        print("No prompts given; exiting.")
        return

    # Time the full generation pass; the loglikelihood and greedy-token
    # outputs are not used by this CLI.
    start_time = time.time()
    generation, _, _ = generator.generate(prompts)
    end_time = time.time()

    # Throughput counted over generated text only (prompts excluded).
    total_tokens = sum(len(tokenizer.encode(gen, False, False)) for gen in generation)
    tokens_per_second = total_tokens / (end_time - start_time)

    for i, gen in enumerate(generation):
        print(f"\nPrompt {i+1}: {prompts[i]}")
        print(f"Generated Text: {gen}")

    print(f"\nTokens per second: {tokens_per_second:.2f}")


if __name__ == "__main__":
    main()
|