File size: 1,022 Bytes
d860e14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class Tokens:
    def __init__(self, tokens):
        self.tokens = tokens

#    def clone(self):
#        return Tokens(self.tokens.clone())

    def __str__(self):
        return str(self.tokens.tolist())

    def __repr__(self):
        return self.__str__()

class BeamTokens(Tokens):
    def __init__(self, tokens, beam_size):
        self.tokens = tokens
        self.beam_size = beam_size

    def clone(self):
        return BeamTokens(self.tokens.clone())

    def __str__(self):
        return f"BeamTokens({self.tokens.tolist()}, beam_size={self.beam_size})"

    def __repr__(self):
        return self.__str__()

    def as_text(self, tokenizer):
        return tokenizer.decode(self.tokens)

class Logits(Tokens):
    def __init__(self, logits):
        super().__init__(logits)

#    def clone(self):
#        return Logits(self.tokens.clone(), self.beam_size)

    def __str__(self):
#        return "abc"
        return f"Logits({self.tokens.shape})"

    def __repr__(self):
        return self.__str__()