File size: 19,397 Bytes
7feac49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
# Copyright (c) Alibaba, Inc. and its affiliates.
import collections
import os.path
import sys
import time
import webbrowser
from datetime import datetime
from typing import Dict, List, Tuple, Type

import gradio as gr
import json
import matplotlib.pyplot as plt
import psutil
from packaging import version
from transformers import is_tensorboard_available

from swift.ui.base import BaseUI
from swift.ui.llm_train.utils import close_loop, run_command_in_subprocess
from swift.utils import TB_COLOR, TB_COLOR_SMOOTH, get_logger, read_tensorboard_file, tensorboard_smoothing
from swift.utils.utils import format_time

# Module-level logger obtained from swift's ``get_logger`` helper.
logger = get_logger()


class Runtime(BaseUI):
    """Runtime panel of the training UI.

    Shows the tail of a task's ``run.log``, draws metric plots read from the
    task's TensorBoard event file, starts/stops TensorBoard processes and
    lists/kills running ``swift`` training processes.
    """

    # TensorBoard processes started by ``start_tb``, keyed by the normalised
    # logging dir; each value is ``(handler, localhost_addr)``.
    handlers: Dict[str, Tuple[List, Tuple]] = {}

    # NOTE(review): presumably consumed by BaseUI to group this component's
    # elements/locale entries -- confirm against BaseUI.
    group = 'llm_train'

    # List of ``gr.Plot`` components; filled in by ``do_build_ui``.
    all_plots = None

    # Stop flags for the ``wait`` log loop, keyed by logging dir. ``wait``
    # resets its entry to False and exits once it sees True.
    log_event = {}

    # Plot specifications. Each entry names a scalar tag looked up in the
    # TensorBoard data ('name') and the factor handed to
    # ``tensorboard_smoothing`` in ``plot`` ('smooth'; None draws the raw
    # curve only). Every list has five entries, matching the five ``gr.Plot``
    # components that ``do_build_ui`` creates from ``sft_plot``.

    # Used for empty tasks and for ``swift sft`` / ``swift pt`` commands.
    sft_plot = [
        {
            'name': 'train/loss',
            'smooth': 0.9,
        },
        {
            'name': 'train/acc',
            'smooth': None,
        },
        {
            'name': 'train/learning_rate',
            'smooth': None,
        },
        {
            'name': 'eval/loss',
            'smooth': 0.9,
        },
        {
            'name': 'eval/acc',
            'smooth': None,
        },
    ]

    # Used when rlhf_type is 'dpo', 'cpo' or 'simpo' (see ``get_plot``).
    dpo_plot = [
        {
            'name': 'train/loss',
            'smooth': 0.9,
        },
        {
            'name': 'train/rewards/accuracies',
            'smooth': None,
        },
        {
            'name': 'train/rewards/margins',
            'smooth': 0.9,
        },
        {
            'name': 'train/logps/chosen',
            'smooth': 0.9,
        },
        {
            'name': 'train/logps/rejected',
            'smooth': 0.9,
        },
    ]

    # Used when rlhf_type is 'kto'.
    # NOTE(review): unlike the other lists these tags carry no 'train/'
    # prefix -- confirm they match the tags actually written to the event file.
    kto_plot = [
        {
            'name': 'kl',
            'smooth': None,
        },
        {
            'name': 'rewards/chosen_sum',
            'smooth': 0.9,
        },
        {
            'name': 'logps/chosen_sum',
            'smooth': 0.9,
        },
        {
            'name': 'rewards/rejected_sum',
            'smooth': 0.9,
        },
        {
            'name': 'logps/rejected_sum',
            'smooth': 0.9,
        },
    ]

    # Used when rlhf_type is 'orpo'.
    orpo_plot = [
        {
            'name': 'train/loss',
            'smooth': 0.9,
        },
        {
            'name': 'train/rewards/accuracies',
            'smooth': None,
        },
        {
            'name': 'train/rewards/margins',
            'smooth': 0.9,
        },
        {
            'name': 'train/rewards/chosen',
            'smooth': 0.9,
        },
        {
            'name': 'train/log_odds_ratio',
            'smooth': 0.9,
        },
    ]

    # Localised UI texts. Top-level keys are element ids (or message ids such
    # as 'tb_not_found'); each maps a property name ('label', 'info' or
    # 'value') to its Chinese ('zh') and English ('en') text. ``start_tb``
    # reads an entry through ``cls.locale(<key>, cls.lang)``.
    # NOTE(review): the 'log' info text refers to a "Show log" button, while
    # the 'show_log' button below is labelled "Show running status" -- confirm
    # the intended wording.
    locale_dict = {
        'runtime_tab': {
            'label': {
                'zh': '运行时',
                'en': 'Runtime'
            },
        },
        'tb_not_found': {
            'value': {
                'zh': 'tensorboard未安装,使用pip install tensorboard进行安装',
                'en': 'tensorboard not found, install it by pip install tensorboard',
            }
        },
        'running_cmd': {
            'label': {
                'zh': '运行命令',
                'en': 'Command line'
            },
            'info': {
                'zh': '执行的实际命令',
                'en': 'The actual command'
            }
        },
        'show_log': {
            'value': {
                'zh': '展示运行状态',
                'en': 'Show running status'
            },
        },
        'stop_show_log': {
            'value': {
                'zh': '停止展示运行状态',
                'en': 'Stop showing running status'
            },
        },
        'logging_dir': {
            'label': {
                'zh': '日志路径',
                'en': 'Logging dir'
            },
            'info': {
                'zh': '支持手动传入文件路径',
                'en': 'Support fill custom path in'
            }
        },
        'log': {
            'label': {
                'zh': '日志输出',
                'en': 'Logging content'
            },
            'info': {
                'zh': '如果日志无更新请再次点击"展示日志内容"',
                'en': 'Please press "Show log" if the log content is not updating'
            }
        },
        'running_tasks': {
            'label': {
                'zh': '运行中任务',
                'en': 'Running Tasks'
            },
            'info': {
                'zh': '运行中的任务(所有的swift sft命令)',
                'en': 'All running tasks(started by swift sft)'
            }
        },
        'refresh_tasks': {
            'value': {
                'zh': '找回运行时任务',
                'en': 'Find running tasks'
            },
        },
        'kill_task': {
            'value': {
                'zh': '杀死任务',
                'en': 'Kill running task'
            },
        },
        'tb_url': {
            'label': {
                'zh': 'Tensorboard链接',
                'en': 'Tensorboard URL'
            },
            'info': {
                'zh': '仅展示,不可编辑',
                'en': 'Not editable'
            }
        },
        'start_tb': {
            'value': {
                'zh': '打开TensorBoard',
                'en': 'Start TensorBoard'
            },
        },
        'close_tb': {
            'value': {
                'zh': '关闭TensorBoard',
                'en': 'Close TensorBoard'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Lay out the runtime accordion and attach its click handlers.

        Creates the command / logging-dir / TensorBoard row, the hidden log
        textbox, the running-task row and one ``gr.Plot`` per ``sft_plot``
        entry (stored in ``cls.all_plots``). Handlers are attached for the
        show-log, stop-show-log, start/close TensorBoard and refresh buttons;
        the ``kill_task`` button is created here but not wired in this method.

        Args:
            base_tab: UI class whose ``element`` lookup resolves the widgets
                used as event triggers, inputs and outputs.
        """
        with gr.Accordion(elem_id='runtime_tab', open=False, visible=True):
            with gr.Blocks():
                with gr.Row():
                    gr.Textbox(elem_id='running_cmd', lines=1, scale=20, interactive=False, max_lines=1)
                    gr.Textbox(elem_id='logging_dir', lines=1, scale=20, max_lines=1)
                    gr.Button(elem_id='show_log', scale=2, variant='primary')
                    gr.Button(elem_id='stop_show_log', scale=2)
                    gr.Textbox(elem_id='tb_url', lines=1, scale=10, interactive=False, max_lines=1)
                    gr.Button(elem_id='start_tb', scale=2, variant='primary')
                    gr.Button(elem_id='close_tb', scale=2)
                with gr.Row():
                    gr.Textbox(elem_id='log', lines=6, visible=False)
                with gr.Row():
                    gr.Dropdown(elem_id='running_tasks', scale=10)
                    gr.Button(elem_id='refresh_tasks', scale=1)
                    gr.Button(elem_id='kill_task', scale=1)

                with gr.Row():
                    # One plot panel per SFT metric; elem_id is the position.
                    cls.all_plots = [
                        gr.Plot(elem_id=str(position), label=spec['name'])
                        for position, spec in enumerate(Runtime.sft_plot)
                    ]

                # The concurrency limit is only passed when gradio >= 4.0.0.
                is_gradio_4 = version.parse(gr.__version__) >= version.parse('4.0.0')
                concurrency_limit = {'concurrency_limit': 5} if is_gradio_4 else {}

                # First reveal the log box and plots, then stream the log.
                reveal_event = base_tab.element('show_log').click(
                    Runtime.update_log,
                    [base_tab.element('running_tasks')],
                    [cls.element('log')] + cls.all_plots,
                )
                reveal_event.then(
                    Runtime.wait,
                    [base_tab.element('logging_dir'), base_tab.element('running_tasks')],
                    [cls.element('log')] + cls.all_plots,
                    **concurrency_limit,
                )

                stop_button = base_tab.element('stop_show_log')
                stop_button.click(cls.break_log_event, [cls.element('running_tasks')], [])

                start_tb_button = base_tab.element('start_tb')
                start_tb_button.click(
                    Runtime.start_tb,
                    [base_tab.element('logging_dir')],
                    [base_tab.element('tb_url')],
                )

                close_tb_button = base_tab.element('close_tb')
                close_tb_button.click(
                    Runtime.close_tb,
                    [base_tab.element('logging_dir')],
                    [],
                )

                refresh_button = base_tab.element('refresh_tasks')
                refresh_button.click(
                    Runtime.refresh_tasks,
                    [base_tab.element('running_tasks')],
                    [base_tab.element('running_tasks')],
                )

    @classmethod
    def get_plot(cls, task):
        if not task or 'swift sft' in task or 'swift pt' in task:
            return cls.sft_plot

        args: dict = cls.parse_info_from_cmdline(task)[1]
        train_type = args.get('rlhf_type', 'dpo')
        if train_type in ('dpo', 'cpo', 'simpo'):
            return cls.dpo_plot
        elif train_type == 'kto':
            return cls.kto_plot
        elif train_type == 'orpo':
            return cls.orpo_plot

    @classmethod
    def update_log(cls, task):
        """Build the updates that reveal the log box and the plot panels.

        Args:
            task: Task description string used to pick the plot list.

        Returns:
            A list whose first item makes the log textbox visible, followed by
            one update per plot that makes it visible and sets its label to
            the metric name.
        """
        log_update = gr.update(visible=True)
        plot_updates = [gr.update(visible=True, label=spec['name']) for spec in Runtime.get_plot(task)]
        return [log_update] + plot_updates

    @classmethod
    def get_initial(cls, line):
        tqdm_starts = ['Train:', 'Map:', 'Val:', 'Filter:']
        for start in tqdm_starts:
            if line.startswith(start):
                return start
        return None

    @classmethod
    def wait(cls, logging_dir, task):
        if not logging_dir:
            return [None] + Runtime.plot(task)
        log_file = os.path.join(logging_dir, 'run.log')
        cls.log_event[logging_dir] = False
        offset = 0
        latest_data = ''
        lines = collections.deque(maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
        try:
            with open(log_file, 'r', encoding='utf-8') as input:
                input.seek(offset)
                fail_cnt = 0
                while True:
                    try:
                        latest_data += input.read()
                    except UnicodeDecodeError:
                        continue
                    if not latest_data:
                        time.sleep(0.5)
                        fail_cnt += 1
                        if fail_cnt > 50:
                            break

                    if cls.log_event.get(logging_dir, False):
                        cls.log_event[logging_dir] = False
                        break

                    if '\n' not in latest_data:
                        continue
                    latest_lines = latest_data.split('\n')
                    if latest_data[-1] != '\n':
                        latest_data = latest_lines[-1]
                        latest_lines = latest_lines[:-1]
                    else:
                        latest_data = ''
                    lines.extend(latest_lines)
                    start = cls.get_initial(lines[-1])
                    if start:
                        i = len(lines) - 2
                        while i >= 0:
                            if lines[i].startswith(start):
                                del lines[i]
                                i -= 1
                            else:
                                break
                    yield ['\n'.join(lines)] + Runtime.plot(task)
        except IOError:
            pass

    @classmethod
    def break_log_event(cls, task):
        if not task:
            return
        pid, all_args = Runtime.parse_info_from_cmdline(task)
        cls.log_event[all_args['logging_dir']] = True

    @classmethod
    def show_log(cls, logging_dir):
        """Open ``run.log`` from ``logging_dir`` through ``webbrowser.open``.

        Args:
            logging_dir: Directory that contains ``run.log``.
        """
        log_path = os.path.join(logging_dir, 'run.log')
        log_url = 'file://' + log_path
        webbrowser.open(log_url, new=2)

    @classmethod
    def start_tb(cls, logging_dir):
        """Start TensorBoard for ``logging_dir`` and return its local URL.

        The directory is stripped of surrounding whitespace and of one
        trailing path separator; the result is the key under which the
        process handler and URL are remembered in ``cls.handlers``. A
        directory that already has an entry returns the stored URL without
        starting another process.

        Args:
            logging_dir: Directory passed to ``tensorboard --logdir``.

        Returns:
            The ``http://localhost:...`` address found in TensorBoard's
            output, or ``''`` when none was reported.

        Raises:
            gr.Error: When tensorboard is not installed.
        """
        if not is_tensorboard_available():
            # gr.Error only reaches the user when it is raised; constructing
            # it without ``raise`` has no visible effect.
            raise gr.Error(cls.locale('tb_not_found', cls.lang)['value'])

        logging_dir = (logging_dir or '').strip()
        logging_dir = logging_dir if not logging_dir.endswith(os.sep) else logging_dir[:-1]
        if logging_dir in cls.handlers:
            return cls.handlers[logging_dir][1]

        handler, lines = run_command_in_subprocess('tensorboard', '--logdir', logging_dir, timeout=2)
        localhost_addr = ''
        marker = 'http://localhost:'
        for line in lines:
            if marker in line:
                # Keep the URL up to the next space; tolerate a URL that ends
                # the line (``str.index(' ')`` would raise there).
                localhost_addr = line[line.index(marker):].split(' ', 1)[0]
        cls.handlers[logging_dir] = (handler, localhost_addr)
        logger.info('===========Tensorboard Log============')
        logger.info('\n'.join(lines))
        if localhost_addr:
            # Do not open a browser tab when no address was reported.
            webbrowser.open(localhost_addr, new=2)
        return localhost_addr

    @staticmethod
    def close_tb(logging_dir):
        """Stop the TensorBoard process registered for ``logging_dir``.

        The directory is normalised the same way ``start_tb`` builds its key
        (whitespace stripped, one trailing path separator removed), so that
        e.g. ``'output/'`` finds the entry stored under ``'output'``. Nothing
        happens when no process is registered for the directory.

        Args:
            logging_dir: Directory the TensorBoard process was started for.
        """
        key = (logging_dir or '').strip()
        if key.endswith(os.sep):
            key = key[:-1]
        if key in Runtime.handlers:
            close_loop(Runtime.handlers[key][0])
            Runtime.handlers.pop(key)

    @staticmethod
    def refresh_tasks(running_task=None):
        """Scan running processes for swift training tasks.

        A process counts as a task when some command-line argument contains
        ``'swift'``, none contains ``'swift.exe'`` and one argument is exactly
        ``'pt'``, ``'sft'`` or ``'rlhf'``.

        Args:
            running_task: Either a task description string (contains
                ``'pid:'``) that should stay selected if it is still listed,
                or another value that is compared against each process's
                command-line arguments to pick the selection.

        Returns:
            A ``gr.update`` carrying the task descriptions as ``choices`` and
            the selected one as ``value``. The selection falls back to the
            first task found, or None when there is none.
        """
        output_dir = running_task if not running_task or 'pid:' not in running_task else None
        process_name = 'swift'
        negative_name = 'swift.exe'
        cmd_name = ['pt', 'sft', 'rlhf']
        process = []
        selected = None
        for proc in psutil.process_iter():
            try:
                cmdlines = proc.cmdline()
                is_task = (
                    any(process_name in cmdline for cmdline in cmdlines)
                    and not any(negative_name in cmdline for cmdline in cmdlines)
                    and any(cmdline in cmd_name for cmdline in cmdlines))
                if not is_task:
                    continue
                # Built once, inside the try: the process may exit while it is
                # being described, and a second call could produce a different
                # 'running:' text than the entry stored in ``process``.
                description = Runtime.construct_running_task(proc)
            except (psutil.ZombieProcess, psutil.AccessDenied, psutil.NoSuchProcess):
                continue
            process.append(description)
            if output_dir is not None and output_dir in cmdlines:
                selected = description
        if not selected and running_task and running_task in process:
            selected = running_task
        if not selected and process:
            selected = process[0]
        return gr.update(choices=process, value=selected)

    @staticmethod
    def construct_running_task(proc):
        """Describe a process as a single task string.

        Args:
            proc: Process object providing ``pid``, ``create_time()`` and
                ``cmdline()``.

        Returns:
            ``'pid:<pid>/create:<YYYY-MM-DD, HH:MM>/running:<elapsed>/cmd:<command>'``
            where ``<elapsed>`` is ``format_time`` of the seconds since the
            process was created and ``<command>`` is the command line joined
            with spaces.
        """
        process_id = proc.pid
        now = time.time()
        started_at = proc.create_time()
        started_label = datetime.fromtimestamp(started_at).strftime('%Y-%m-%d, %H:%M')
        elapsed = format_time(now - started_at)
        command = ' '.join(proc.cmdline())
        return f'pid:{process_id}/create:{started_label}/running:{elapsed}/cmd:{command}'

    @staticmethod
    def parse_info_from_cmdline(task):
        pid = None
        if '/cmd:' in task:
            for i in range(3):
                slash = task.find('/')
                if i == 0:
                    pid = task[:slash].split(':')[1]
                task = task[slash + 1:]
        if 'swift sft' in task:
            args = task.split('swift sft')[1]
        elif 'swift pt' in task:
            args = task.split('swift pt')[1]
        elif 'swift rlhf' in task:
            args = task.split('swift rlhf')[1]
        else:
            raise ValueError(f'Cannot parse cmd line: {task}')
        args = [arg.strip() for arg in args.split('--') if arg.strip()]
        all_args = {}
        for i in range(len(args)):
            space = args[i].find(' ')
            splits = args[i][:space], args[i][space + 1:]
            all_args[splits[0]] = splits[1]

        output_dir = all_args['output_dir']
        if os.path.exists(os.path.join(output_dir, 'args.json')):
            with open(os.path.join(output_dir, 'args.json'), 'r', encoding='utf-8') as f:
                _json = json.load(f)
            for key in all_args.keys():
                all_args[key] = _json.get(key)
                if isinstance(all_args[key], list):
                    if any([' ' in value for value in all_args[key]]):
                        all_args[key] = [f'"{value}"' for value in all_args[key]]
                    all_args[key] = ' '.join(all_args[key])
        return pid, all_args

    @staticmethod
    def kill_task(task):
        """Kill the selected training task and reset the task list and plots.

        On Windows the process tree of the task's pid is terminated with
        ``taskkill``; elsewhere every process whose command line matches the
        task's ``output_dir`` is killed with ``pkill -9 -f``. Afterwards the
        task's log loop is asked to stop.

        Args:
            task: Task description string, or a falsy value to only refresh.

        Returns:
            A list starting with the refreshed task dropdown update, followed
            by ``len(get_plot(task)) + 1`` updates that clear their values.
        """
        if task:
            # Function-scope import: subprocess is not imported at file level.
            import subprocess

            pid, all_args = Runtime.parse_info_from_cmdline(task)
            output_dir = all_args.get('output_dir')
            # Argument lists avoid passing the values through a shell.
            command = None
            if sys.platform == 'win32':
                if pid:
                    command = ['taskkill', '/f', '/t', '/pid', str(pid)]
            elif output_dir:
                # An empty pattern must never reach ``pkill -f``.
                command = ['pkill', '-9', '-f', str(output_dir)]
            if command is not None:
                try:
                    subprocess.run(command)
                except OSError as e:
                    # e.g. the kill utility is not installed.
                    logger.warning(f'Failed to run {command[0]}: {e}')
            time.sleep(1)
            Runtime.break_log_event(task)
        return [Runtime.refresh_tasks()] + [gr.update(value=None)] * (len(Runtime.get_plot(task)) + 1)

    @staticmethod
    def reset():
        return None, 'output'

    @staticmethod
    def task_changed(task, base_tab):
        """Refill the UI from the command line of the newly selected task.

        Args:
            task: Task description string, or a falsy value.
            base_tab: UI class whose ``valid_elements()`` mapping lists the
                components to update.

        Returns:
            One update per element of ``base_tab.valid_elements()``: the
            parsed argument value when the element's ``elem_id`` is an
            argument name, otherwise a no-op update. Multi-select dropdowns
            receive the value split on spaces. These are followed by
            ``len(get_plot(task)) + 1`` updates that clear their values.
        """
        if task:
            _, all_args = Runtime.parse_info_from_cmdline(task)
        else:
            all_args = {}
        ret = []
        for e in base_tab.valid_elements().values():
            if e.elem_id not in all_args:
                ret.append(gr.update())
                continue
            arg = all_args[e.elem_id]
            # Parsed values are not always strings (e.g. None or values taken
            # from args.json), so only strings are split; an empty string
            # means "nothing selected" rather than [''].
            if isinstance(e, gr.Dropdown) and e.multiselect and isinstance(arg, str):
                arg = arg.split(' ') if arg else []
            ret.append(gr.update(value=arg))
        Runtime.break_log_event(task)
        return ret + [gr.update(value=None)] * (len(Runtime.get_plot(task)) + 1)

    @staticmethod
    def plot(task):
        plot = Runtime.get_plot(task)
        if not task:
            return [None] * len(plot)
        _, all_args = Runtime.parse_info_from_cmdline(task)
        tb_dir = all_args['logging_dir']
        if not os.path.exists(tb_dir):
            return [None] * len(plot)
        fname = [
            fname for fname in os.listdir(tb_dir)
            if os.path.isfile(os.path.join(tb_dir, fname)) and fname.startswith('events.out')
        ]
        if fname:
            fname = fname[0]
        else:
            return [None] * len(plot)
        tb_path = os.path.join(tb_dir, fname)
        data = read_tensorboard_file(tb_path)

        plots = []
        for k in plot:
            name = k['name']
            smooth = k['smooth']
            if name == 'train/acc':
                if 'train/token_acc' in data:
                    name = 'train/token_acc'
                if 'train/seq_acc' in data:
                    name = 'train/seq_acc'
            if name == 'eval/acc':
                if 'eval/token_acc' in data:
                    name = 'eval/token_acc'
                if 'eval/seq_acc' in data:
                    name = 'eval/seq_acc'
            if name not in data:
                plots.append(None)
                continue
            _data = data[name]
            steps = [d['step'] for d in _data]
            values = [d['value'] for d in _data]
            if len(values) == 0:
                continue

            plt.close('all')
            fig = plt.figure()
            ax = fig.add_subplot()
            # _, ax = plt.subplots(1, 1, squeeze=True, figsize=(8, 5), dpi=100)
            ax.set_title(name)
            if len(values) == 1:
                ax.scatter(steps, values, color=TB_COLOR_SMOOTH)
            elif smooth is not None:
                ax.plot(steps, values, color=TB_COLOR)
                values_s = tensorboard_smoothing(values, smooth)
                ax.plot(steps, values_s, color=TB_COLOR_SMOOTH)
            else:
                ax.plot(steps, values, color=TB_COLOR_SMOOTH)
            plots.append(fig)
        return plots