# Copyright (c) Meta Platforms, Inc. and affiliates.

import time
from dataclasses import dataclass

from omegaconf import OmegaConf

import torch
from torch import nn

from lingua.args import dataclass_from_dict
from lingua.tokenizer import Tokenizer

from apps.main.generate import (
    PackedCausalTransformerGenerator,
    PackedCausalTransformerGeneratorArgs,
    load_consolidated_model_and_tokenizer,
)
from apps.mamba.core_mamba import SSM
from apps.mamba.mamba import LMMambaArgs, LMMamba, StateCache


# Generation arguments are shared with the transformer generator; this subclass
# currently adds nothing Mamba-specific.
@dataclass
class PackedCausalMambaGeneratorArgs(PackedCausalTransformerGeneratorArgs):
    pass


class PackedCausalMambaGenerator(PackedCausalTransformerGenerator):
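    """Causal generator for packed Mamba prompts.

    The sampling loop is inherited from PackedCausalTransformerGenerator; only the
    cache setup, prefill, and single-token decode steps below are Mamba-specific.
    """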
    def __init__(
        self,
        cfg: PackedCausalMambaGeneratorArgs,
        model: nn.Module,
        tokenizer: Tokenizer,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.temperature = cfg.temperature
        self.top_p = cfg.top_p
        self.top_k = cfg.top_k

        self.max_gen_len = cfg.max_gen_len
        self.max_tokens = cfg.max_tokens
        self.max_prompt_len = cfg.max_prompt_len
        self.until = cfg.until
        self.max_until_size = max([len(e) for e in self.until]) if self.until else 1
        self.device = cfg.device

        # Compile if necessary
        self.prefill = torch.compile(self.prefill, disable=not cfg.compile_prefilling)
        self.generate_next_token = torch.compile(
            self.generate_next_token,
            mode="reduce-overhead",
            disable=not cfg.reduce_generation_overhead,
        )

        self.show_progress = cfg.show_progress
        self.dtype = dict(fp32=torch.float32, bf16=torch.bfloat16)[cfg.dtype]

        self.prefill_tok_id = None
        self.cu_seqlens = None

    def clear_cache(self, lengths: torch.Tensor):
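        # Give every SSM layer a fresh recurrent state cache sized for the new
        # batch of lengths.size(0) packed prompts.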
        for module in self.model.modules():
            if isinstance(module, SSM):
                module.cache = StateCache(
                    lengths.size(0),
                    module.n_heads,
                    module.head_dim,
                    module.state_dim,
                    module.conv_size,
                    module.conv_dim,
                    self.dtype,
                    self.device,
                )

    @torch.compiler.disable
    def setup_prefilling(self, lengths: torch.Tensor):
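        # Build the packed-batch metadata for prefill: prefill_tok_id maps each
        # packed token to the index of the prompt it belongs to, and cu_seqlens
        # marks the cumulative offsets where each prompt starts.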
        self.clear_cache(lengths)

        self.prefill_tok_id = torch.repeat_interleave(lengths).unsqueeze(0).int()
        self.cu_seqlens = lengths.cumsum(0)
        self.cu_seqlens = torch.cat(
            [torch.tensor([0], device=self.device), self.cu_seqlens]
        ).int()

    @torch.compiler.disable
    def setup_generation(self, lengths):
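        # Nothing to rearrange between prefill and decoding: the SSM state caches
        # created in clear_cache are carried over in-place.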
        pass

    def prefill(self, tokens: torch.Tensor, lengths: torch.Tensor):
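        # Run the packed prompts through the model in one forward pass
        # (ssm_impl="ssm") so the per-layer state caches reflect the full prompts
        # before token-by-token decoding starts.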
        self.setup_prefilling(lengths=lengths)
        prefill_out = self.model.forward(
            tokens,
            tok_idx=self.prefill_tok_id,
            cu_seqlens=self.cu_seqlens,
            ssm_impl="ssm",
        )

        return prefill_out

    def generate_next_token(self, current_token):
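        # Decode one token per sequence with the recurrent update path
        # (ssm_impl="ssm_update"), advancing each cached state by a single step.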
        out = self.model.forward(
            current_token,
            tok_idx=None,
            cu_seqlens=None,
            ssm_impl="ssm_update",
        )
        return out

    def generate(self, prompts):
        return super().generate(prompts)


def main():
    # Load CLI arguments (key=value overrides) into the generation config
    cfg = OmegaConf.from_cli()
    gen_cfg = dataclass_from_dict(PackedCausalMambaGeneratorArgs, cfg, strict=False)
    print(cfg)

    model, tokenizer, _ = load_consolidated_model_and_tokenizer(
        cfg.ckpt, model_cls=LMMamba, model_args_cls=LMMambaArgs
    )

    generator = PackedCausalMambaGenerator(gen_cfg, model, tokenizer)

    # Allow multiple prompts
    prompts = []
    while True:
        prompt = input("Enter a prompt (or press enter to finish): ")
        if not prompt:
            break
        prompts.append(prompt)

    # Start generation
    start_time = time.time()
    generation, loglikelihood, greedy = generator.generate(prompts)
    end_time = time.time()

    # Calculate tokens per second
    total_tokens = sum(len(tokenizer.encode(gen, False, False)) for gen in generation)
    tokens_per_second = total_tokens / (end_time - start_time)

    # Display the results
    for i, gen in enumerate(generation):
        print(f"\nPrompt {i+1}: {prompts[i]}")
        print(f"Generated Text: {gen}")

    print(f"\nTokens per second: {tokens_per_second:.2f}")


if __name__ == "__main__":
    main()
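
# Example invocation (module path and checkpoint path are illustrative); every
# key=value pair is picked up by OmegaConf.from_cli in main():
#   python -m apps.mamba.generate ckpt=/path/to/consolidated_ckpt temperature=0.7 top_p=0.95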