File size: 6,354 Bytes
ed99557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Utilities for the command line client, in particular for handling interactions with the terminal.
"""

from dataclasses import dataclass
import sys


def colorize(text, color):
    """Wrap ``text`` in ANSI escape sequences so terminals display it in ``color``.

    Args:
        text: the string to colorize.
        color: the numeric part of the escape sequence, e.g. ``"31"`` or ``"1;34"``.

    Returns:
        ``text`` preceded by the color escape sequence and followed by the
        reset sequence ``"\\033[0m"``.
    """
    start = f"\033[{color}m"
    reset = "\033[0m"
    return start + text + reset


def make_log(level: str, msg: str) -> str:
    if level == "warning":
        prefix = colorize("[Warn]", "1;31")
    elif level == "info":
        prefix = colorize("[Info]", "1;34")
    elif level == "error":
        prefix = colorize("[Err ]", "1;31")
    else:
        raise ValueError(f"Unknown level {level}")
    return prefix + " " + msg


class RawPrinter:
    def __init__(self, stream=sys.stdout, err_stream=sys.stderr):
        self.stream = stream
        self.err_stream = err_stream

    def print_header(self):
        pass

    def print_token(self, token: str):
        self.stream.write(token)
        self.stream.flush()

    def log(self, level: str, msg: str):
        print(f"{level.capitalize()}: {msg}", file=self.err_stream)

    def print_lag(self):
        self.err_stream.write(colorize(" [LAG]", "31"))
        self.err_stream.flush()

    def print_pending(self):
        pass


@dataclass
class LineEntry:
    msg: str
    color: str | None = None

    def render(self):
        if self.color is None:
            return self.msg
        else:
            return colorize(self.msg, self.color)

    def __len__(self):
        return len(self.msg)


class Line:
    """The line currently being written on a stream.

    The entries written so far are remembered so that the line can be redrawn
    from its start (after a carriage return) when trailing entries are erased.
    The longest length reached is tracked as well, so that leftover characters
    can be covered with spaces.
    """

    def __init__(self, stream):
        self.stream = stream
        self._line: list[LineEntry] = []
        # True when `flush` wrote trailing spaces that the next write must undo.
        self._has_padding: bool = False
        # Longest length reached by this line since the last `newline`.
        self._max_line_length = 0

    def __bool__(self):
        """Return True when the line holds at least one entry."""
        return len(self._line) > 0

    def __len__(self):
        """Return the total length of all entries, color codes excluded."""
        return sum(map(len, self._line))

    def add(self, msg: str, color: str | None = None) -> int:
        """Append ``msg`` (optionally colored) to the line; return its length."""
        return self._add(LineEntry(msg, color))

    def _add(self, entry: LineEntry) -> int:
        """Write ``entry`` at the end of the line and record it."""
        if self._has_padding:
            # Redraw the line so that the cursor sits before the padding again.
            self.erase(count=0)
        self._line.append(entry)
        self.stream.write(entry.render())
        current_length = len(self)
        if current_length > self._max_line_length:
            self._max_line_length = current_length
        return len(entry)

    def erase(self, count: int = 1):
        """Drop the last ``count`` entries and redraw the remaining ones.

        With ``count=0`` nothing is dropped and the line is simply redrawn.
        Characters of the dropped entries are not cleared here; `flush` and
        `newline` overwrite them with spaces.
        """
        kept = self._line[:-count] if count else self._line[:]
        self._line.clear()
        self.stream.write("\r")
        for entry in kept:
            self._line.append(entry)
            self.stream.write(entry.render())

        self._has_padding = False

    def newline(self):
        """Pad the line up to its longest length, end it and reset the state."""
        padding = self._max_line_length - len(self)
        if padding > 0:
            self.stream.write(" " * padding)
        self.stream.write("\n")
        self._line.clear()
        self._max_line_length = 0
        self._has_padding = False

    def flush(self):
        """Cover leftover characters with spaces, then flush the stream."""
        padding = self._max_line_length - len(self)
        if padding > 0:
            self.stream.write(" " * padding)
            self._has_padding = True
        self.stream.flush()


class Printer:
    """Printer that renders tokens inside a fixed-width box drawn with ``|``.

    Each row is written as ``"| "``, the content, padding up to ``max_cols``
    characters, then ``" |"``. Tokens that do not fit on the current row are
    moved to the next one, together with the beginning of the word they belong
    to when that word started on the current row.

    Args:
        max_cols: width of a row, counted from the start of the ``"| "`` prefix
            up to (and excluding) the closing ``" |"``.
        stream: stream receiving the tokens.
        err_stream: stream receiving the log messages.
    """

    def __init__(self, max_cols: int = 80, stream=sys.stdout, err_stream=sys.stderr):
        self.max_cols = max_cols
        self.line = Line(stream)
        self.stream = stream
        self.err_stream = err_stream
        # Number of calls to `print_pending`, used to animate the spinner.
        self._pending_count = 0
        # True when the last entry of the current row is the spinner character.
        self._pending_printed = False

    def print_header(self):
        """Print the top border of the box and open the first row."""
        self.line.add(" " + "-" * self.max_cols + " ")
        self.line.newline()
        self.line.flush()
        self.line.add("| ")

    def _remove_pending(self) -> bool:
        """Erase the spinner if it is displayed; return whether it was."""
        if self._pending_printed:
            self._pending_printed = False
            self.line.erase(1)
            return True
        return False

    def _new_row(self):
        """Pad and close the current row with `` |``, then open a new one."""
        remaining = self.max_cols - len(self.line)
        # A non-positive `remaining` (over-full row) simply gives no padding.
        self.line.add(" " * remaining + " |")
        self.line.newline()
        self.line.add("| ")

    def _pop_current_word(self) -> str | None:
        """Erase the pieces of the unfinished trailing word and return them.

        Entries are scanned from the end of the row up to the piece starting
        with a space (the beginning of the word) or up to a colored entry.
        Returns None, erasing nothing, when neither is found, i.e. when the
        word started at the very beginning of the row.
        """
        pieces: list[str] = []
        for idx, entry in enumerate(reversed(self.line._line)):
            if entry.color:
                # probably a LAG message: keep it, carry only what follows it.
                erase_count = idx
                break
            # Every piece scanned belongs to the word and must be carried over,
            # not only the one starting with a space.
            pieces.append(entry.msg)
            if entry.msg.startswith(" "):
                erase_count = idx + 1
                break
        else:
            return None
        if erase_count > 0:
            self.line.erase(erase_count)
        return "".join(reversed(pieces))

    def _add_split(self, token: str, color: str | None = None):
        """Add ``token``, cutting it across as many rows as it needs."""
        while True:
            # Clamp so that an over-full row never produces a negative slice.
            remaining = max(self.max_cols - len(self.line), 0)
            if len(token) <= remaining:
                self.line.add(token, color)
                return
            head, token = token[:remaining], token[remaining:]
            if head:
                self.line.add(head, color)
            self._new_row()
            if self.max_cols - len(self.line) <= 0:
                # Rows too narrow to hold anything: write the rest unsplit
                # rather than looping forever.
                self.line.add(token, color)
                return

    def print_token(self, token: str, color: str | None = None):
        """Print ``token`` on the current row, wrapping to a new row if needed.

        Args:
            token: text to print; a leading space marks the start of a word.
            color: optional color code for the token.
        """
        self._remove_pending()
        remaining = self.max_cols - len(self.line)
        if len(token) <= remaining:
            self.line.add(token, color)
        elif token.startswith(" "):
            # A new word: start it on the next row, without its leading space.
            self._new_row()
            self._add_split(token.lstrip(), color)
        elif color is not None:
            # Colored text is never merged with the preceding word, so that
            # the color applies to it only; move it whole to the next row.
            self._new_row()
            self._add_split(token, color)
        else:
            # Continuation of a word: try to move the whole word to the next row.
            carried = self._pop_current_word()
            if carried is None:
                # The word fills the row from its start: cut it in place.
                self._add_split(token)
            else:
                self._new_row()
                self._add_split(carried.lstrip() + token)
        self.line.flush()

    def log(self, level: str, msg: str):
        """Print a colorized log message on the error stream.

        The spinner is removed and the current row, if not empty, is ended
        first so that the message starts on its own line.

        Raises:
            ValueError: if ``level`` is not supported by ``make_log``.
        """
        msg = make_log(level, msg)
        self._remove_pending()
        if self.line:
            self.line.newline()
        self.line.flush()
        print(msg, file=self.err_stream)
        self.err_stream.flush()

    def print_lag(self):
        """Print a `` [LAG]`` marker with color code ``31`` among the tokens."""
        self.print_token(" [LAG]", "31")

    def print_pending(self):
        """Display or update a spinner character at the end of the current row.

        The spinner moves to its next character and color every 5 calls.
        """
        chars = ["|", "/", "-", "\\"]
        colors = ["32", "33", "31"]
        count = self._pending_count // 5
        char = chars[count % len(chars)]
        self._remove_pending()
        self.line.add(char, colors[count % len(colors)])
        self._pending_printed = True
        self._pending_count += 1


# Either printer implementation: both define `print_header`, `print_token`,
# `log`, `print_lag` and `print_pending`.
AnyPrinter = Printer | RawPrinter