# Copyright (c) Meta Platforms, Inc. and affiliates.

import time
from dataclasses import dataclass
from pathlib import Path

import torch
from omegaconf import OmegaConf
from torch import nn

from lingua.args import dataclass_from_dict
from lingua.checkpoint import CONSOLIDATE_NAME
from lingua.tokenizer import Tokenizer, build_tokenizer

from apps.fastRNN.hawk.core_hawk import RGLRU
from apps.fastRNN.hawk.hawk import LMHawk, LMHawkArgs
from apps.fastRNN.minGRU.core_gru import GRU
from apps.fastRNN.minGRU.mingru import LMMinGRU, LMMinGRUArgs
from apps.fastRNN.minLSTM.core_lstm import LSTM
from apps.fastRNN.minLSTM.minlstm import LMMinLSTM, LMMinLSTMArgs
from apps.main.generate import (
    PackedCausalTransformerGenerator,
    PackedCausalTransformerGeneratorArgs,
)


def load_consolidated_model_and_tokenizer(consolidated_path):
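    """Load a consolidated fastRNN checkpoint and its tokenizer.

    Reads ``params.json`` from the checkpoint directory, instantiates the
    model class matching ``model_type`` (minGRU, minLSTM, or Hawk), loads the
    consolidated weights, and returns the CUDA model (cast to the configured
    dtype) together with the tokenizer and the full config.
    """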
    ckpt_path = Path(consolidated_path)
    config = OmegaConf.load(ckpt_path / "params.json")

    model_type = config.model_type.lower()
    if model_type == "mingru":
        model_cls = LMMinGRU
        model_args_cls = LMMinGRUArgs
    elif model_type == "minlstm":
        model_cls = LMMinLSTM
        model_args_cls = LMMinLSTMArgs
    elif model_type == "hawk":
        model_cls = LMHawk
        model_args_cls = LMHawkArgs
    else:
        raise ValueError(f"Unknown model type: {config.model_type}")

    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
        config.distributed.model_dtype
    ]
    model_args = dataclass_from_dict(model_args_cls, config.model, strict=False)
    tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path)

    model = model_cls(model_args)
    st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
    model.load_state_dict(st_dict["model"], strict=False)
    model = model.cuda().eval()
    for param in model.parameters():
        param.data = param.data.to(dtype=param_dtype)

    return model, tokenizer, config


class StateCache(nn.Module):
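    """Per-layer cache for recurrent generation.

    Holds the recurrent state (shaped ``(n_heads, head_dim, bsz)``) and, when
    ``conv_size`` is set, a rolling convolution cache shaped
    ``(bsz, conv_dim, conv_size)``. Both buffers are non-persistent and
    zero-initialized.
    """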
    def __init__(self, bsz, n_heads, head_dim, conv_size, conv_dim, dtype, device):
        super().__init__()
        state_shape = (n_heads, head_dim, bsz)
        if conv_size is None:
            conv_shape = (0,)
        else:
            conv_shape = (bsz, conv_dim, conv_size)
        self.register_buffer(
            "conv_cache",
            torch.zeros(conv_shape, dtype=dtype, device=device),
            persistent=False,
        )
        self.register_buffer(
            "state_cache",
            torch.zeros(state_shape, dtype=dtype, device=device),
            persistent=False,
        )

    def reset(self):
        self.conv_cache.zero_()
        self.state_cache.zero_()


@dataclass
class PackedRNNGeneratorArgs(PackedCausalTransformerGeneratorArgs):
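    """Generation arguments for the RNN models; currently identical to the
    transformer generator's arguments."""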


class PackedRNNGenerator(PackedCausalTransformerGenerator):
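    """Packed generator for recurrent models (minGRU, minLSTM, Hawk).

    Reuses the transformer generation loop but replaces the KV cache with
    per-layer ``StateCache`` buffers: prompts are prefilled with the parallel
    implementation, then tokens are decoded one step at a time with the
    sequential implementation.
    """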
    def __init__(
        self,
        cfg: PackedRNNGeneratorArgs,
        model: nn.Module,
        tokenizer: Tokenizer,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.temperature = cfg.temperature
        self.top_p = cfg.top_p
        self.top_k = cfg.top_k
        self.max_gen_len = cfg.max_gen_len
        self.max_tokens = cfg.max_tokens
        self.max_prompt_len = cfg.max_prompt_len
        self.until = cfg.until
        self.max_until_size = max(len(e) for e in self.until) if self.until else 1
        self.device = cfg.device

        # Compile if requested
        self.prefill = torch.compile(self.prefill, disable=not cfg.compile_prefilling)
        self.generate_next_token = torch.compile(
            self.generate_next_token,
            mode="reduce-overhead",
            disable=not cfg.reduce_generation_overhead,
        )

        self.show_progress = cfg.show_progress
        self.dtype = dict(fp32=torch.float32, bf16=torch.bfloat16)[cfg.dtype]

        self.cu_seqlens = None
        self.tok_idx = None

    def clear_cache(self, lengths: torch.Tensor):
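        """Attach a fresh zero-initialized ``StateCache`` to every recurrent
        layer, sized for the current batch."""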
        for module in self.model.modules():
            if isinstance(module, (GRU, LSTM, RGLRU)):
                module.cache = StateCache(
                    lengths.size(0),
                    module.n_heads,
                    module.head_dim,
                    module.conv_size,
                    module.conv_dim,
                    self.dtype,
                    self.device,
                )

    @torch.compiler.disable
    def setup_prefilling(self, lengths: torch.Tensor):
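        """Reset the recurrent caches and build the packed-sequence metadata:
        ``cu_seqlens`` holds cumulative sequence offsets and ``tok_idx`` maps
        each packed token to its sequence. Kept out of ``torch.compile`` via
        the decorator."""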
        self.clear_cache(lengths)
        self.cu_seqlens = lengths.cumsum(0)
        self.cu_seqlens = torch.cat(
            [torch.tensor([0], device=self.device), self.cu_seqlens]
        ).int()
        self.tok_idx = (
            torch.repeat_interleave(lengths).int().unsqueeze(0).to(self.device)
        )

    @torch.compiler.disable
    def setup_generation(self, lengths):
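        """No extra setup is needed before decoding: the recurrent state
        already lives in the per-layer caches."""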

    def prefill(self, tokens: torch.Tensor, lengths: torch.Tensor):
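        """Run the prompt through the model with the parallel implementation,
        populating the recurrent caches along the way."""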
        self.setup_prefilling(lengths=lengths)
        prefill_out = self.model.forward(
            tokens,
            tok_idx=self.tok_idx,
            cu_seqlens=self.cu_seqlens,
            impl="parallel",
        )
        return prefill_out

    def generate_next_token(self, current_token):
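        """Advance the recurrence by one step with the sequential
        implementation and return the output for the current token."""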
        out = self.model.forward(
            current_token,
            cu_seqlens=None,
            impl="sequential",
        )
        return out

    def generate(self, prompts):
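        """Delegate to the transformer generation loop; only the cache
        handling above differs."""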
        return super().generate(prompts)


def main():
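    # Example invocation (module and checkpoint paths are illustrative):
    #   python -m apps.fastRNN.generate ckpt=checkpoints/step_1000/consolidated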
    # Load CLI arguments (overrides) and combine with a YAML config
    cfg = OmegaConf.from_cli()
    gen_cfg = dataclass_from_dict(PackedRNNGeneratorArgs, cfg, strict=False)
    print(cfg)

    model, tokenizer, _ = load_consolidated_model_and_tokenizer(cfg.ckpt)
    generator = PackedRNNGenerator(gen_cfg, model, tokenizer)

    # Allow multiple prompts
    prompts = []
    while True:
        prompt = input("Enter a prompt (or press enter to finish): ")
        if not prompt:
            break
        prompts.append(prompt)

    # Start generation
    start_time = time.time()
    generation, loglikelihood, greedy = generator.generate(prompts)
    end_time = time.time()

    # Calculate tokens per second
    total_tokens = sum(len(tokenizer.encode(gen, False, False)) for gen in generation)
    tokens_per_second = total_tokens / (end_time - start_time)

    # Display the results
    for i, gen in enumerate(generation):
        print(f"\nPrompt {i + 1}: {prompts[i]}")
        print(f"Generated Text: {gen}")

    print(f"\nTokens per second: {tokens_per_second:.2f}")


if __name__ == "__main__":
    main()