|
|
|
|
|
import collections |
|
|
import os.path |
|
|
import sys |
|
|
import time |
|
|
import webbrowser |
|
|
from datetime import datetime |
|
|
from typing import Dict, List, Tuple, Type |
|
|
|
|
|
import gradio as gr |
|
|
import json |
|
|
import matplotlib.pyplot as plt |
|
|
import psutil |
|
|
from packaging import version |
|
|
from transformers import is_tensorboard_available |
|
|
|
|
|
from swift.ui.base import BaseUI |
|
|
from swift.ui.llm_train.utils import close_loop, run_command_in_subprocess |
|
|
from swift.utils import TB_COLOR, TB_COLOR_SMOOTH, get_logger, read_tensorboard_file, tensorboard_smoothing |
|
|
from swift.utils.utils import format_time |
|
|
|
|
|
# Module-level logger shared by every handler in this file.
logger = get_logger()
|
|
|
|
|
|
|
|
class Runtime(BaseUI):
    """Runtime panel of the LLM training tab.

    The panel shows the launched command, streams ``run.log``, renders
    TensorBoard scalars as matplotlib figures, starts/stops a TensorBoard
    subprocess, and lists/kills running ``swift pt|sft|rlhf`` processes.
    """

    # Normalised logging_dir -> (tensorboard subprocess handler, tensorboard URL).
    handlers: Dict[str, Tuple[List, Tuple]] = {}

    group = 'llm_train'

    # The ``gr.Plot`` components, filled in by ``do_build_ui``.
    all_plots = None

    # logging_dir -> True when the log-streaming loop in ``wait`` should stop.
    log_event = {}

    # Scalars plotted per training type. ``smooth`` is the smoothing weight
    # passed to ``tensorboard_smoothing``; ``None`` plots the raw series only.
    # Every list has the same length because the UI creates one plot slot per
    # entry of ``sft_plot``.
    sft_plot = [
        {
            'name': 'train/loss',
            'smooth': 0.9,
        },
        {
            'name': 'train/acc',
            'smooth': None,
        },
        {
            'name': 'train/learning_rate',
            'smooth': None,
        },
        {
            'name': 'eval/loss',
            'smooth': 0.9,
        },
        {
            'name': 'eval/acc',
            'smooth': None,
        },
    ]

    dpo_plot = [
        {
            'name': 'train/loss',
            'smooth': 0.9,
        },
        {
            'name': 'train/rewards/accuracies',
            'smooth': None,
        },
        {
            'name': 'train/rewards/margins',
            'smooth': 0.9,
        },
        {
            'name': 'train/logps/chosen',
            'smooth': 0.9,
        },
        {
            'name': 'train/logps/rejected',
            'smooth': 0.9,
        },
    ]

    kto_plot = [
        {
            'name': 'kl',
            'smooth': None,
        },
        {
            'name': 'rewards/chosen_sum',
            'smooth': 0.9,
        },
        {
            'name': 'logps/chosen_sum',
            'smooth': 0.9,
        },
        {
            'name': 'rewards/rejected_sum',
            'smooth': 0.9,
        },
        {
            'name': 'logps/rejected_sum',
            'smooth': 0.9,
        },
    ]

    orpo_plot = [
        {
            'name': 'train/loss',
            'smooth': 0.9,
        },
        {
            'name': 'train/rewards/accuracies',
            'smooth': None,
        },
        {
            'name': 'train/rewards/margins',
            'smooth': 0.9,
        },
        {
            'name': 'train/rewards/chosen',
            'smooth': 0.9,
        },
        {
            'name': 'train/log_odds_ratio',
            'smooth': 0.9,
        },
    ]

    locale_dict = {
        'runtime_tab': {
            'label': {
                'zh': '运行时',
                'en': 'Runtime'
            },
        },
        'tb_not_found': {
            'value': {
                'zh': 'tensorboard未安装,使用pip install tensorboard进行安装',
                'en': 'tensorboard not found, install it by pip install tensorboard',
            }
        },
        'running_cmd': {
            'label': {
                'zh': '运行命令',
                'en': 'Command line'
            },
            'info': {
                'zh': '执行的实际命令',
                'en': 'The actual command'
            }
        },
        'show_log': {
            'value': {
                'zh': '展示运行状态',
                'en': 'Show running status'
            },
        },
        'stop_show_log': {
            'value': {
                'zh': '停止展示运行状态',
                'en': 'Stop showing running status'
            },
        },
        'logging_dir': {
            'label': {
                'zh': '日志路径',
                'en': 'Logging dir'
            },
            'info': {
                'zh': '支持手动传入文件路径',
                'en': 'Support fill custom path in'
            }
        },
        'log': {
            'label': {
                'zh': '日志输出',
                'en': 'Logging content'
            },
            'info': {
                'zh': '如果日志无更新请再次点击"展示日志内容"',
                'en': 'Please press "Show log" if the log content is not updating'
            }
        },
        'running_tasks': {
            'label': {
                'zh': '运行中任务',
                'en': 'Running Tasks'
            },
            'info': {
                'zh': '运行中的任务(所有的swift sft命令)',
                'en': 'All running tasks(started by swift sft)'
            }
        },
        'refresh_tasks': {
            'value': {
                'zh': '找回运行时任务',
                'en': 'Find running tasks'
            },
        },
        'kill_task': {
            'value': {
                'zh': '杀死任务',
                'en': 'Kill running task'
            },
        },
        'tb_url': {
            'label': {
                'zh': 'Tensorboard链接',
                'en': 'Tensorboard URL'
            },
            'info': {
                'zh': '仅展示,不可编辑',
                'en': 'Not editable'
            }
        },
        'start_tb': {
            'value': {
                'zh': '打开TensorBoard',
                'en': 'Start TensorBoard'
            },
        },
        'close_tb': {
            'value': {
                'zh': '关闭TensorBoard',
                'en': 'Close TensorBoard'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the runtime widgets and wire their click handlers.

        Args:
            base_tab: The UI class whose ``element`` lookup provides the
                components used as event inputs and outputs.
        """
        # NOTE(review): the nesting below was reconstructed from a copy of the
        # file that had lost its indentation - confirm against the repository.
        with gr.Accordion(elem_id='runtime_tab', open=False, visible=True):
            with gr.Blocks():
                with gr.Row():
                    gr.Textbox(elem_id='running_cmd', lines=1, scale=20, interactive=False, max_lines=1)
                    gr.Textbox(elem_id='logging_dir', lines=1, scale=20, max_lines=1)
                    gr.Button(elem_id='show_log', scale=2, variant='primary')
                    gr.Button(elem_id='stop_show_log', scale=2)
                    gr.Textbox(elem_id='tb_url', lines=1, scale=10, interactive=False, max_lines=1)
                    gr.Button(elem_id='start_tb', scale=2, variant='primary')
                    gr.Button(elem_id='close_tb', scale=2)
                with gr.Row():
                    gr.Textbox(elem_id='log', lines=6, visible=False)
                with gr.Row():
                    gr.Dropdown(elem_id='running_tasks', scale=10)
                    gr.Button(elem_id='refresh_tasks', scale=1)
                    gr.Button(elem_id='kill_task', scale=1)

                with gr.Row():
                    # One plot slot per SFT metric; the other layouts reuse
                    # these slots and only change the labels.
                    cls.all_plots = [
                        gr.Plot(elem_id=str(idx), label=metric['name']) for idx, metric in enumerate(Runtime.sft_plot)
                    ]

                # ``concurrency_limit`` is only passed on gradio >= 4.0.0.
                concurrency_limit = {}
                if version.parse(gr.__version__) >= version.parse('4.0.0'):
                    concurrency_limit = {'concurrency_limit': 5}

                log_outputs = [cls.element('log')] + cls.all_plots
                base_tab.element('show_log').click(
                    Runtime.update_log, [base_tab.element('running_tasks')], log_outputs).then(
                        Runtime.wait,
                        [base_tab.element('logging_dir'),
                         base_tab.element('running_tasks')], log_outputs, **concurrency_limit)

                base_tab.element('stop_show_log').click(cls.break_log_event, [cls.element('running_tasks')], [])

                base_tab.element('start_tb').click(
                    Runtime.start_tb,
                    [base_tab.element('logging_dir')],
                    [base_tab.element('tb_url')],
                )

                base_tab.element('close_tb').click(
                    Runtime.close_tb,
                    [base_tab.element('logging_dir')],
                    [],
                )

                base_tab.element('refresh_tasks').click(
                    Runtime.refresh_tasks,
                    [base_tab.element('running_tasks')],
                    [base_tab.element('running_tasks')],
                )

    @classmethod
    def get_plot(cls, task):
        """Return the metric layout (list of ``name``/``smooth`` dicts) for ``task``.

        An empty task and ``swift sft`` / ``swift pt`` tasks use the SFT layout.
        Other tasks are dispatched on their ``rlhf_type`` argument; a missing or
        unrecognised type falls back to the DPO layout so callers always
        receive a list.
        """
        if not task or 'swift sft' in task or 'swift pt' in task:
            return cls.sft_plot

        args: dict = cls.parse_info_from_cmdline(task)[1]
        train_type = args.get('rlhf_type', 'dpo')
        if train_type == 'kto':
            return cls.kto_plot
        if train_type == 'orpo':
            return cls.orpo_plot
        # 'dpo', 'cpo', 'simpo' and anything unknown share the DPO layout.
        return cls.dpo_plot

    @classmethod
    def update_log(cls, task):
        """Return updates that reveal the log box and relabel every plot slot for ``task``."""
        ret = [gr.update(visible=True)]
        ret.extend(gr.update(visible=True, label=metric['name']) for metric in Runtime.get_plot(task))
        return ret

    @classmethod
    def get_initial(cls, line):
        """Return the progress-bar prefix ``line`` starts with, or ``None`` if there is none."""
        for prefix in ('Train:', 'Map:', 'Val:', 'Filter:'):
            if line.startswith(prefix):
                return prefix
        return None

    @classmethod
    def wait(cls, logging_dir, task):
        """Stream ``<logging_dir>/run.log`` together with refreshed plots.

        This is a generator. Each yielded value is a list holding the text of
        the last ``MAX_LOG_LINES`` (environment variable, default 50) log lines
        followed by the figures from ``plot(task)``. Streaming ends when
        ``break_log_event`` flags ``logging_dir``, or after more than 50 polls
        in total that found nothing to read.

        Args:
            logging_dir: Directory containing ``run.log``.
            task: Running-task string as built by ``construct_running_task``.
        """
        if not logging_dir:
            # A ``return <value>`` inside a generator never reaches the
            # consumer, so the empty state has to be yielded.
            yield [None] + Runtime.plot(task)
            return

        log_file = os.path.join(logging_dir, 'run.log')
        cls.log_event[logging_dir] = False
        pending = ''  # text read so far that is not yet terminated by '\n'
        lines = collections.deque(maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
        try:
            with open(log_file, 'r', encoding='utf-8') as log_stream:
                fail_cnt = 0
                while True:
                    try:
                        chunk = log_stream.read()
                    except UnicodeDecodeError:
                        # NOTE(review): presumably the writer was in the middle
                        # of a multi-byte character; retry on the next pass.
                        continue
                    pending += chunk
                    if not pending:
                        time.sleep(0.5)
                        fail_cnt += 1
                        if fail_cnt > 50:
                            break
                    elif not chunk:
                        # Only an unfinished line is buffered and nothing new
                        # arrived: wait instead of spinning on the file.
                        time.sleep(0.5)

                    if cls.log_event.get(logging_dir, False):
                        cls.log_event[logging_dir] = False
                        break

                    if '\n' not in pending:
                        continue
                    # Everything before the last newline is complete; the rest
                    # (possibly empty) stays buffered for the next read.
                    *latest_lines, pending = pending.split('\n')
                    lines.extend(latest_lines)
                    start = cls.get_initial(lines[-1])
                    if start:
                        # Keep only the newest line of a run of progress-bar
                        # lines that share the same prefix.
                        i = len(lines) - 2
                        while i >= 0 and lines[i].startswith(start):
                            del lines[i]
                            i -= 1
                    yield ['\n'.join(lines)] + Runtime.plot(task)
        except IOError as e:
            # Best effort: a missing or unreadable log just ends the stream.
            logger.warning(f'Cannot read log file {log_file}: {e}')

    @classmethod
    def break_log_event(cls, task):
        """Ask the ``wait`` loop that streams ``task``'s logging directory to stop."""
        if not task:
            return
        _, all_args = Runtime.parse_info_from_cmdline(task)
        logging_dir = all_args.get('logging_dir')
        if logging_dir:
            cls.log_event[logging_dir] = True

    @classmethod
    def show_log(cls, logging_dir):
        """Open ``<logging_dir>/run.log`` in a new browser tab."""
        webbrowser.open('file://' + os.path.join(logging_dir, 'run.log'), new=2)

    @staticmethod
    def _normalize_logging_dir(logging_dir):
        """Strip whitespace and one trailing path separator so handler keys are stable."""
        logging_dir = (logging_dir or '').strip()
        return logging_dir[:-1] if logging_dir.endswith(os.sep) else logging_dir

    @classmethod
    def start_tb(cls, logging_dir):
        """Start TensorBoard for ``logging_dir``, or reuse a running one, and return its URL.

        Raises:
            gr.Error: If tensorboard is not installed.
        """
        if not is_tensorboard_available():
            # gr.Error is an exception; it is only shown to the user when raised.
            raise gr.Error(cls.locale('tb_not_found', cls.lang)['value'])

        logging_dir = cls._normalize_logging_dir(logging_dir)
        if logging_dir in cls.handlers:
            return cls.handlers[logging_dir][1]

        handler, lines = run_command_in_subprocess('tensorboard', '--logdir', logging_dir, timeout=2)
        marker = 'http://localhost:'
        localhost_addr = ''
        for line in lines:
            if marker in line:
                # The URL runs up to the next space, or to the end of the line.
                localhost_addr = line[line.index(marker):].split(' ', 1)[0]
        cls.handlers[logging_dir] = (handler, localhost_addr)
        logger.info('===========Tensorboard Log============')
        logger.info('\n'.join(lines))
        if localhost_addr:
            webbrowser.open(localhost_addr, new=2)
        return localhost_addr

    @staticmethod
    def close_tb(logging_dir):
        """Stop the TensorBoard subprocess started for ``logging_dir``, if any."""
        # Use the same key normalisation as ``start_tb`` so the lookup matches.
        logging_dir = Runtime._normalize_logging_dir(logging_dir)
        if logging_dir in Runtime.handlers:
            close_loop(Runtime.handlers[logging_dir][0])
            Runtime.handlers.pop(logging_dir)

    @staticmethod
    def refresh_tasks(running_task=None):
        """Scan local processes for ``swift pt|sft|rlhf`` runs and update the dropdown.

        Args:
            running_task: Either a task string containing ``pid:`` (it is kept
                selected when still running) or a plain value that is matched
                against the processes' command-line arguments.

        Returns:
            A ``gr.update`` carrying the task choices and the selected value.
        """
        output_dir = running_task if not running_task or 'pid:' not in running_task else None
        process_name = 'swift'
        negative_name = 'swift.exe'
        cmd_name = ['pt', 'sft', 'rlhf']
        gone = (psutil.ZombieProcess, psutil.AccessDenied, psutil.NoSuchProcess)
        process = []
        selected = None
        for proc in psutil.process_iter():
            try:
                cmdlines = proc.cmdline()
            except gone:
                cmdlines = []
            is_swift = any(process_name in cmdline for cmdline in cmdlines)
            is_excluded = any(negative_name in cmdline for cmdline in cmdlines)
            is_train_cmd = any(cmdline in cmd_name for cmdline in cmdlines)
            if not is_swift or is_excluded or not is_train_cmd:
                continue
            try:
                # The process may exit between the scan and this second query.
                running = Runtime.construct_running_task(proc)
            except gone:
                continue
            process.append(running)
            if output_dir is not None and any(output_dir == cmdline for cmdline in cmdlines):
                selected = running
        if not selected and running_task and running_task in process:
            selected = running_task
        if not selected and process:
            selected = process[0]
        return gr.update(choices=process, value=selected)

    @staticmethod
    def construct_running_task(proc):
        """Build the ``pid:…/create:…/running:…/cmd:…`` label for a ``psutil`` process."""
        pid = proc.pid
        ts = time.time()
        create_time = proc.create_time()
        create_time_formatted = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d, %H:%M')

        return f'pid:{pid}/create:{create_time_formatted}' \
               f'/running:{format_time(ts-create_time)}/cmd:{" ".join(proc.cmdline())}'

    @staticmethod
    def parse_info_from_cmdline(task):
        """Split a running-task string into its pid and command-line arguments.

        Args:
            task: String built by ``construct_running_task`` (or a bare command
                line containing ``swift sft|pt|rlhf``).

        Returns:
            ``(pid, all_args)``: ``pid`` is a string, or ``None`` when ``task``
            has no ``/cmd:`` part; ``all_args`` maps argument names (without
            ``--``) to values. Flags without a value map to ``''``. When
            ``<output_dir>/args.json`` exists, values found there replace the
            command-line ones; lists are joined with spaces.

        Raises:
            ValueError: If no supported swift sub-command is found in ``task``.
        """
        pid = None
        if '/cmd:' in task:
            # Drop the 'pid:…/create:…/running:…/' prefix.
            parts = task.split('/', 3)
            if len(parts) == 4:
                pid = parts[0].split(':')[1]
                task = parts[3]

        for sub_command in ('swift sft', 'swift pt', 'swift rlhf'):
            if sub_command in task:
                args = task.split(sub_command, 1)[1]
                break
        else:
            raise ValueError(f'Cannot parse cmd line: {task}')

        all_args = {}
        for arg in args.split('--'):
            arg = arg.strip()
            if not arg:
                continue
            # ``partition`` keeps the full name of a value-less flag, whereas
            # ``find(' ') == -1`` slicing would cut off its last character.
            key, _, value = arg.partition(' ')
            all_args[key] = value.strip()

        output_dir = all_args.get('output_dir')
        args_file = os.path.join(output_dir, 'args.json') if output_dir else None
        if args_file and os.path.exists(args_file):
            with open(args_file, 'r', encoding='utf-8') as f:
                _json = json.load(f)
            if isinstance(_json, dict):
                for key in all_args:
                    if key not in _json:
                        # Keep the command-line value rather than replacing it with None.
                        continue
                    value = _json[key]
                    if isinstance(value, list):
                        items = [str(item) for item in value]
                        if any(' ' in item for item in items):
                            items = [f'"{item}"' for item in items]
                        value = ' '.join(items)
                    all_args[key] = value
        return pid, all_args

    @staticmethod
    def kill_task(task):
        """Kill the selected task, stop its log stream and reset the panel.

        On Windows the process tree of the task's pid is killed; elsewhere
        every process whose command line matches the task's ``output_dir`` is
        killed with ``pkill -9 -f``.

        Returns:
            The refreshed task dropdown update followed by one clearing update
            for the log box and for each plot.
        """
        import subprocess

        if task:
            pid, all_args = Runtime.parse_info_from_cmdline(task)
            output_dir = all_args.get('output_dir')
            # Argument lists (no shell) keep values taken from a process
            # command line from being interpreted by a shell.
            if sys.platform == 'win32':
                if pid:
                    subprocess.run(['taskkill', '/f', '/t', '/pid', str(pid)], check=False)
            elif output_dir:
                # An empty pattern would match every process, hence the guard.
                subprocess.run(['pkill', '-9', '-f', '--', output_dir], check=False)
            time.sleep(1)
            Runtime.break_log_event(task)
        return [Runtime.refresh_tasks()] + [gr.update(value=None)] * (len(Runtime.get_plot(task)) + 1)

    @staticmethod
    def reset():
        """Return the values used to reset two components: ``None`` and ``'output'``."""
        return None, 'output'

    @staticmethod
    def task_changed(task, base_tab):
        """Fill the training form from the selected task's arguments.

        Returns:
            One update per valid element of ``base_tab`` (a value for every
            element whose ``elem_id`` is among the task's arguments), followed
            by clearing updates for the log box and each plot.
        """
        if task:
            _, all_args = Runtime.parse_info_from_cmdline(task)
        else:
            all_args = {}
        ret = []
        for e in base_tab.valid_elements().values():
            if e.elem_id not in all_args:
                ret.append(gr.update())
                continue
            arg = all_args[e.elem_id]
            # Multi-select dropdowns take a list; only strings can be split.
            if isinstance(e, gr.Dropdown) and e.multiselect and isinstance(arg, str):
                arg = arg.split(' ')
            ret.append(gr.update(value=arg))
        Runtime.break_log_event(task)
        return ret + [gr.update(value=None)] * (len(Runtime.get_plot(task)) + 1)

    @staticmethod
    def plot(task):
        """Render the task's TensorBoard scalars as matplotlib figures.

        Returns:
            A list with exactly one entry per metric of ``get_plot(task)``: a
            figure, or ``None`` when the metric (or the event file) is missing
            or has no data points.
        """
        plot = Runtime.get_plot(task)
        empty = [None] * len(plot)
        if not task:
            return empty
        _, all_args = Runtime.parse_info_from_cmdline(task)
        tb_dir = all_args.get('logging_dir')
        if not tb_dir or not os.path.isdir(tb_dir):
            return empty
        event_files = [
            fname for fname in os.listdir(tb_dir)
            if os.path.isfile(os.path.join(tb_dir, fname)) and fname.startswith('events.out')
        ]
        if not event_files:
            return empty
        data = read_tensorboard_file(os.path.join(tb_dir, event_files[0]))

        plots = []
        for k in plot:
            name = k['name']
            smooth = k['smooth']
            # The generic accuracy slots show token/sequence accuracy when
            # logged; sequence accuracy wins if both are present.
            if name == 'train/acc':
                if 'train/token_acc' in data:
                    name = 'train/token_acc'
                if 'train/seq_acc' in data:
                    name = 'train/seq_acc'
            if name == 'eval/acc':
                if 'eval/token_acc' in data:
                    name = 'eval/token_acc'
                if 'eval/seq_acc' in data:
                    name = 'eval/seq_acc'
            if name not in data:
                plots.append(None)
                continue
            _data = data[name]
            steps = [d['step'] for d in _data]
            values = [d['value'] for d in _data]
            if not values:
                # Still occupy the slot so figures stay aligned with the outputs.
                plots.append(None)
                continue

            # Drop figures still registered with pyplot so they do not pile up
            # across repeated polling calls.
            plt.close('all')
            fig = plt.figure()
            ax = fig.add_subplot()

            ax.set_title(name)
            if len(values) == 1:
                ax.scatter(steps, values, color=TB_COLOR_SMOOTH)
            elif smooth is not None:
                ax.plot(steps, values, color=TB_COLOR)
                values_s = tensorboard_smoothing(values, smooth)
                ax.plot(steps, values_s, color=TB_COLOR_SMOOTH)
            else:
                ax.plot(steps, values, color=TB_COLOR_SMOOTH)
            plots.append(fig)
        return plots
|
|
|