File size: 6,613 Bytes
72c0672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
# Copyright (c) Meta Platforms, Inc. and affiliates.

from pathlib import Path
import time
from dataclasses import dataclass

from omegaconf import OmegaConf

import torch
from torch import nn

from lingua.args import dataclass_from_dict
from lingua.checkpoint import CONSOLIDATE_NAME
from lingua.tokenizer import Tokenizer, build_tokenizer

from apps.main.generate import (
    PackedCausalTransformerGenerator,
    PackedCausalTransformerGeneratorArgs,
)

from apps.fastRNN.minGRU.core_gru import GRU
from apps.fastRNN.minLSTM.core_lstm import LSTM
from apps.fastRNN.hawk.core_hawk import RGLRU

from apps.fastRNN.minGRU.mingru import LMMinGRU, LMMinGRUArgs
from apps.fastRNN.minLSTM.minlstm import LMMinLSTM, LMMinLSTMArgs
from apps.fastRNN.hawk.hawk import LMHawk, LMHawkArgs


def load_consolidated_model_and_tokenizer(consolidated_path):
    """Load a consolidated fastRNN checkpoint for inference.

    Reads ``params.json`` from *consolidated_path* to determine the model
    type (minGRU / minLSTM / Hawk), builds the model and tokenizer, loads
    the consolidated weights, moves the model to CUDA in eval mode, and
    casts parameters to the dtype recorded in the training config.

    Returns:
        (model, tokenizer, config) tuple.

    Raises:
        ValueError: if the config's ``model_type`` is not recognized.
    """
    ckpt_path = Path(consolidated_path)
    config = OmegaConf.load(ckpt_path / "params.json")

    # Dispatch table from model_type to (model class, args dataclass).
    registry = {
        "mingru": (LMMinGRU, LMMinGRUArgs),
        "minlstm": (LMMinLSTM, LMMinLSTMArgs),
        "hawk": (LMHawk, LMHawkArgs),
    }
    model_type = config.model_type.lower()
    if model_type not in registry:
        raise ValueError(f"Unknown model type: {config.model_type}")
    model_cls, model_args_cls = registry[model_type]

    dtypes = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
    param_dtype = dtypes[config.distributed.model_dtype]

    model_args = dataclass_from_dict(model_args_cls, config.model, strict=False)
    tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path)

    model = model_cls(model_args)
    st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
    # strict=False mirrors the original: extra/missing keys are tolerated.
    model.load_state_dict(st_dict["model"], strict=False)
    model = model.cuda().eval()
    for param in model.parameters():
        param.data = param.data.to(dtype=param_dtype)
    return model, tokenizer, config


class StateCache(nn.Module):
    """Per-layer recurrent state used during sequential decoding.

    Holds two non-persistent (excluded from state_dict) zero-initialized
    buffers: ``state_cache`` of shape ``(n_heads, head_dim, bsz)`` for the
    recurrent hidden state, and ``conv_cache`` of shape
    ``(bsz, conv_dim, conv_size)`` — or an empty ``(0,)`` tensor when the
    layer has no short convolution (``conv_size is None``).
    """

    def __init__(self, bsz, n_heads, head_dim, conv_size, conv_dim, dtype, device):
        super().__init__()
        # Hidden state is laid out (heads, head_dim, batch).
        hidden_shape = (n_heads, head_dim, bsz)
        # Degenerate empty buffer stands in when there is no convolution.
        conv_shape = (bsz, conv_dim, conv_size) if conv_size is not None else (0,)

        self.register_buffer(
            "state_cache",
            torch.zeros(hidden_shape, dtype=dtype, device=device),
            persistent=False,
        )
        self.register_buffer(
            "conv_cache",
            torch.zeros(conv_shape, dtype=dtype, device=device),
            persistent=False,
        )

    def reset(self):
        """Zero both caches in place so the module can start a new batch."""
        self.state_cache.zero_()
        self.conv_cache.zero_()


@dataclass
class PackedRNNGeneratorArgs(PackedCausalTransformerGeneratorArgs):
    """Generation arguments for the RNN generator.

    Currently identical to the packed causal transformer generator's
    arguments; defined as a distinct type so RNN-specific options can be
    added later without touching the transformer path.
    """

    pass


class PackedRNNGenerator(PackedCausalTransformerGenerator):
    """Generator for recurrent (minGRU / minLSTM / Hawk) language models.

    Reuses the packed causal transformer generation loop but replaces the
    KV-cache machinery with per-layer recurrent ``StateCache`` objects:
    prompts are prefilled with the model's ``impl="parallel"`` path, then
    tokens are decoded one at a time with ``impl="sequential"``.
    """

    def __init__(
        self,
        cfg: PackedRNNGeneratorArgs,
        model: nn.Module,
        tokenizer: Tokenizer,
    ):
        # NOTE(review): parent __init__ is not invoked; attributes are
        # assigned directly below — confirm the parent class defines no
        # additional state that the generation loop relies on.
        self.model = model
        self.tokenizer = tokenizer
        # Sampling parameters.
        self.temperature = cfg.temperature
        self.top_p = cfg.top_p
        self.top_k = cfg.top_k

        # Length limits and stop-sequence handling.
        self.max_gen_len = cfg.max_gen_len
        self.max_tokens = cfg.max_tokens
        self.max_prompt_len = cfg.max_prompt_len
        self.until = cfg.until
        self.max_until_size = max([len(e) for e in self.until]) if self.until else 1
        self.device = cfg.device

        # Compile if necessary
        self.prefill = torch.compile(self.prefill, disable=not cfg.compile_prefilling)
        self.generate_next_token = torch.compile(
            self.generate_next_token,
            mode="reduce-overhead",
            disable=not cfg.reduce_generation_overhead,
        )

        self.show_progress = cfg.show_progress
        # NOTE(review): only fp32/bf16 accepted here (no fp16), unlike the
        # checkpoint loader's dtype table — confirm this is intentional.
        self.dtype = dict(fp32=torch.float32, bf16=torch.bfloat16)[cfg.dtype]

        # Filled in by setup_prefilling(); consumed by prefill().
        self.cu_seqlens = None
        self.tok_idx = None

    def clear_cache(self, lengths: torch.Tensor):
        """Give every recurrent layer a fresh zeroed ``StateCache`` sized
        for a batch of ``lengths.size(0)`` sequences."""
        for module in self.model.modules():
            if isinstance(module, (GRU, LSTM, RGLRU)):
                module.cache = StateCache(
                    lengths.size(0),
                    module.n_heads,
                    module.head_dim,
                    module.conv_size,
                    module.conv_dim,
                    self.dtype,
                    self.device,
                )

    @torch.compiler.disable
    def setup_prefilling(self, lengths: torch.Tensor):
        # Kept out of any compiled graph: rebuilds caches and index
        # tensors whose shapes change with every batch.
        self.clear_cache(lengths)

        # Cumulative sequence boundaries for the packed (concatenated)
        # token buffer, with a leading zero: [0, l0, l0+l1, ...].
        self.cu_seqlens = lengths.cumsum(0)
        self.cu_seqlens = torch.cat(
            [torch.tensor([0], device=self.device), self.cu_seqlens]
        ).int()

        # repeat_interleave(lengths) yields the sequence index of every
        # packed token position: [0,...,0, 1,...,1, ...].
        self.tok_idx = torch.repeat_interleave(lengths).int().unsqueeze(0).to(self.device)


    @torch.compiler.disable
    def setup_generation(self, lengths):
        # No per-step setup: the recurrent state left by prefill carries
        # over directly into sequential decoding.
        pass

    def prefill(self, tokens: torch.Tensor, lengths: torch.Tensor):
        """Run the packed prompt through the model's parallel path,
        populating each layer's recurrent state, and return the output."""
        self.setup_prefilling(lengths=lengths)
        prefill_out = self.model.forward(
            tokens,
            tok_idx=self.tok_idx,
            cu_seqlens=self.cu_seqlens,
            impl="parallel",
        )

        return prefill_out

    def generate_next_token(self, current_token):
        """Advance the recurrence one step (sequential path) and return
        the model output for sampling the next token."""
        out = self.model.forward(
            current_token,
            cu_seqlens=None,
            impl="sequential",
        )
        return out

    def generate(self, prompts):
        # Delegates to the transformer generation loop, which calls the
        # overridden prefill / generate_next_token / setup_* hooks above.
        return super().generate(prompts)


def main():
    """Interactive generation entry point.

    Reads config overrides from the CLI (must include ``ckpt``, the path
    to a consolidated checkpoint directory), collects prompts from stdin
    until an empty line, generates completions, and reports throughput.
    """
    # Load CLI arguments (overrides) and combine with a YAML config
    cfg = OmegaConf.from_cli()
    gen_cfg = dataclass_from_dict(PackedRNNGeneratorArgs, cfg, strict=False)
    print(cfg)

    model, tokenizer, _ = load_consolidated_model_and_tokenizer(cfg.ckpt)

    generator = PackedRNNGenerator(gen_cfg, model, tokenizer)

    # Allow multiple prompts
    prompts = []
    while True:
        prompt = input("Enter a prompt (or press enter to finish): ")
        if not prompt:
            break
        prompts.append(prompt)

    # Guard: don't run generation on an empty batch.
    if not prompts:
        print("No prompts given; exiting.")
        return

    # Start generation
    start_time = time.time()
    generation, loglikelihood, greedy = generator.generate(prompts)
    end_time = time.time()

    # Calculate tokens per second
    total_tokens = sum(len(tokenizer.encode(gen, False, False)) for gen in generation)
    tokens_per_second = total_tokens / (end_time - start_time)

    # Display the results; zip pairs each generation with its prompt
    # instead of indexing prompts by position.
    for i, (prompt, gen) in enumerate(zip(prompts, generation), start=1):
        print(f"\nPrompt {i}: {prompt}")
        print(f"Generated Text: {gen}")

    print(f"\nTokens per second: {tokens_per_second:.2f}")


if __name__ == "__main__":
    main()