# Copyright (c) Together
# This software is distributed under the terms of the Apache License, Version 2.0
# Author: Michael Poli

# Barebones generation class for standalone inference.

import torch

from stripedhyena.sample import sample
from stripedhyena.tokenizer import CharLevelTokenizer
from stripedhyena.utils import print_rank_0


class Generator:
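    """Barebones sampling loop for standalone StripedHyena inference.

    Wraps a model and tokenizer and exposes `generate`, which performs
    temperature / top-k / top-p sampling, optionally reusing cached inference
    state (the "mha" and "hyena" entries of the inference params dict) between steps.
    """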
    def __init__(self, model, tokenizer, top_k=50, top_p=0.7, temperature=1):
        self.model = model
        self.tokenizer = tokenizer
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.untils = ["\n\n"]

    def generate(
        self,
        device,
        input_string=None,
        input_ids=None,
        num_tokens=32,
        cached_generation=False,
        print_generation=True,
        verbose=False,
        skip_special_tokens=False,
        stop_at_eos=True,
        max_seqlen=None,
    ):
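        """Autoregressively sample `num_tokens` tokens from the model.

        The prompt is given either as `input_string` (tokenized here) or as
        pre-tokenized `input_ids`. Returns a tuple ``(generation, scores)``
        holding the sampled token ids and the per-step logits.
        """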
        if isinstance(self.tokenizer.eos, int):
            eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device)
        else:
            # eos is a string: tokenize it into a tensor of token ids
            eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)

        if input_ids is None:
            input = self.tokenizer.tokenize(input_string)
            if isinstance(input, list):
                input = torch.LongTensor(input).unsqueeze(0).to(device)
            else:
                # the tokenizer already returned a tensor; only add a batch dimension
                input = input.unsqueeze(0).to(device)
        else:
            input = input_ids
        x = input

        if max_seqlen is not None:
            x = x[:, -max_seqlen:]

        prompt_len = x.shape[-1]

        num_tokens = int(num_tokens)
        tot_length = prompt_len + num_tokens
        batch_size = x.shape[0]

        generation = torch.empty(
            x.shape[0],
            num_tokens,
            dtype=torch.long,
            device=x.device,
        )

        scores = torch.empty(
            x.shape[0],
            num_tokens,
            self.tokenizer.vocab_size,
            dtype=torch.float,
            device=x.device,
        )

        if cached_generation:
            # pre-allocate per-layer inference state (attention KV cache and
            # hyena filter state), sized to the current batch
            inference_params_dict_out = self.model.initialize_inference_params()
            inference_params_dict_out["mha"].max_batch_size = batch_size
            inference_params_dict_out["hyena"].max_batch_size = batch_size
        else:
            inference_params_dict_out = None

        if verbose:
            mem_after_tok = torch.cuda.memory_allocated(device=x.device) / 1e9
            print_rank_0(f"Memory after tokenization: {mem_after_tok} GB")
            print_rank_0("Starting generation...")
            if input_string is not None:
                print_rank_0("Prompt: " + input_string)
            else:
                print_rank_0(f"Prompt ids: {input_ids} {input_ids.shape}")

        for i in range(num_tokens):
            post_prefill = cached_generation and i > 0
            # prefill then process only the last token
            if post_prefill:
                x = x[:, -1:]
                seqlen_offset = inference_params_dict_out["mha"].seqlen_offset

                if seqlen_offset == 0:
                    seqlen_offset = input.shape[-1]
                    inference_params_dict_out["hyena"].seqlen_offset = seqlen_offset
                    inference_params_dict_out["mha"].seqlen_offset = seqlen_offset
                else:
                    inference_params_dict_out["mha"].seqlen_offset += 1
                    inference_params_dict_out["hyena"].seqlen_offset += 1

            # do forward pass with no gradient
            with torch.no_grad():
                logits, inference_params_dict_out = self.model(
                    x,
                    inference_params_dict=inference_params_dict_out,
                )

            last_logits = logits[:, -1]

            new_idx = sample(
                last_logits,
                top_k=self.top_k,
                top_p=self.top_p,
                temperature=self.temperature,
            )

            scores[:, i] = last_logits
            generation[:, i] = new_idx

            if print_generation and verbose and batch_size == 1:
                print_rank_0(
                    f"{self.tokenizer.detokenize([new_idx.item()])}",
                    end=" ",
                )

            if post_prefill:
                x = new_idx[:, None]
            else:
                x = torch.cat([x, new_idx[:, None]], dim=-1)

            # stop once the most recently generated tokens match the EOS sequence
            num_eos = eos_token_ids.shape[-1]
            if (
                stop_at_eos
                and i + 1 >= num_eos
                and (generation[0, i + 1 - num_eos : i + 1] == eos_token_ids).all()
            ):
                print_rank_0("Stopping generation at EOS")
                break

        if verbose:
            kwargs = {}
            if not isinstance(self.tokenizer, CharLevelTokenizer):
                kwargs["skip_special_tokens"] = skip_special_tokens
            y = self.tokenizer.detokenize_batch(generation[:, : i + 1], **kwargs)

            for until in self.untils:
                if until in y:
                    y = y.split(until)[0]
                    break

            print_rank_0(f"\nInput: {input_string}, Output: {y}")

            mem_end = torch.cuda.memory_allocated(device=x.device) / 1e9
            print_rank_0(f"Memory after generation: {mem_end} GB")

        return generation[:, : i + 1], scores[:, : i + 1]
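

# Example usage, as a minimal sketch: this assumes `model` is an already-built
# StripedHyena model exposing `initialize_inference_params`, and `tokenizer` is
# a compatible tokenizer (e.g. CharLevelTokenizer). How those objects are
# constructed depends on the rest of the repository and is not shown here.
#
#     generator = Generator(model, tokenizer, top_k=50, top_p=0.95, temperature=0.8)
#     tokens, logits = generator.generate(
#         device="cuda",
#         input_string="The quick brown fox",
#         num_tokens=64,
#         cached_generation=True,
#         verbose=True,
#     )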