Spaces:
Runtime error
Runtime error
| from bokeh.io import curdoc | |
| from bokeh.layouts import column, row | |
| from bokeh.models import Slider, Select, ColumnDataSource, Span, Div, Button, LogColorMapper, ColorBar, LogTicker | |
| from bokeh.models.tools import CrosshairTool | |
| from bokeh.plotting import figure | |
| from bokeh.events import Tap | |
| from bokeh.transform import log_cmap | |
| import pandas as pd | |
| from scipy.spatial import ConvexHull | |
| from scipy.optimize import curve_fit | |
| from time import sleep | |
| from utils import * | |
| from conversions import * | |
########################################################################################################################
# Basic dimensions
########################################################################################################################
# Size of the two main interactive figures.
plot_width, plot_height = 1200, 400
# Width shared by the sidebar widgets and result texts.
sidebar_width = 400
# Size of the smaller figures embedded in the write-up.
in_text_plot_width, in_text_plot_height = 800, 300
########################################################################################################################
# Set up data
########################################################################################################################
# Every CSV column whose name contains "loss" is treated as one training run, paired with "global_step".
df = pd.read_csv("optimal_training/static/loss_vs_compute.csv")
loss_keys = [key for key in df.keys() if "loss" in key]
# clean_run and param_count are not defined in this file (presumably provided by the star imports).
losses_per_run = {key: np.array(clean_run(list(zip(df["global_step"], df[key])))) for key in loss_keys}
# Keep only the runs that have more than 5 cleaned points.
losses_per_run = {k: v for k, v in losses_per_run.items() if len(v) > 5}
# [smallest, largest] first-column value reached by each run.
bounds_per_run = {key: [min(value[:, 0]), max(value[:, 0])] for key, value in losses_per_run.items()}
params_per_run = {key: param_count(run) for key, run in losses_per_run.items()}
# From here on the three per-run containers are parallel lists sorted by increasing parameter count.
ordered_keys = sorted(losses_per_run, key=lambda x: params_per_run[x])
losses_per_run = [losses_per_run[key] for key in ordered_keys]
bounds_per_run = [bounds_per_run[key] for key in ordered_keys]
params_per_run = [params_per_run[key] for key in ordered_keys]
palette = "Viridis256"
color_mapper = LogColorMapper(palette=palette, low=min(params_per_run), high=max(params_per_run))
# Window used for the hull: lower bound of the third run up to the upper bound of the second-to-last run.
general_bounds = bounds_per_run[2][0], bounds_per_run[-2][1]
print("{:.4e}, {:.4e}".format(general_bounds[0] * day_ratio, general_bounds[1] * day_ratio))
# One black entry per run (only the commented-out per-run plotting code indexes this list).
# The previous `["#000000" in params_per_run]` was a membership test and produced a one-element list of bool.
color_list = ["#000000" for _ in params_per_run]
# there's a bogus point of small coordinates at position 0 to get the ConvexHull facing the origin
# hacky, but it's the syntax here, qhull_options=QG0 means the ConvexHull facing point 0
# Rows are (x, y, run index); the bogus point carries run index -1.
bounded_points = np.array([(10e8, 3, -1)] + [(a, b, i) for i, run in enumerate(losses_per_run) for a, b in run if
                                             general_bounds[0] < a < general_bounds[1]])
# Same layout without the bogus point and without the window restriction.
all_points = np.array([(a, b, i) for i, run in enumerate(losses_per_run) for a, b in run])
all_hull = ConvexHull(bounded_points[:, :2], qhull_options='QG0')
# Second hull computed on (log x, y).
log_points = np.array([(np.log(a), b) for a, b, i in bounded_points])
log_hull = ConvexHull(log_points, qhull_options='QG0')
# Per-run (x, y) arrays used by the plotting code.
indexed_runs = [np.array([(a, b) for a, b in run]) for run in losses_per_run]
########################################################################################################################
# Set up loss_plot
########################################################################################################################
color_bar = ColorBar(
    color_mapper=color_mapper,
    ticker=LogTicker(),
    label_standoff=12,
    border_line_color=None,
    location=(0, 0),
    title="Num of params",
)
loss_plot = figure(
    plot_height=plot_height,
    plot_width=plot_width,
    title="Validation loss during training for an array of models of different sizes",
    tools="pan,reset,save,wheel_zoom,tap",
    active_scroll="wheel_zoom",
    x_range=[min(all_points[:, 0]) * day_ratio, max(all_points[:, 0]) * day_ratio],
    y_range=[min(all_points[:, 1]), max(all_points[:, 1])],
    x_axis_type="log",
    y_axis_type="log",
    x_axis_label="Floating-point operations (excluding embeddings & softmax)",
    y_axis_label="Validation loss on Wikitext-103",
    output_backend="webgl",
)
loss_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
loss_plot.add_layout(color_bar, "left")
# Earlier per-run plotting, kept for reference:
# for i, run in indexed_runs.items():
# source = ColumnDataSource(data=dict(x=run[:, 0] * day_ratio, y=run[:, 1]))
# loss_plot.line('x', 'y', source=source, line_width=1, line_alpha=0.6, color=color_list[i])
# loss_plot.scatter('x', 'y', source=source, line_width=1, line_alpha=0.6, color=color_list[i])

# One polyline per run, coloured by that run's parameter count.
source = ColumnDataSource(data={
    "xs": [run[:, 0] * day_ratio for run in indexed_runs],
    "ys": [run[:, 1] for run in indexed_runs],
    "params": params_per_run,
})
loss_plot.multi_line(
    'xs', 'ys', source=source,
    color=log_cmap('params', palette, min(params_per_run), max(params_per_run)),
)

# The same points flattened into one scatter; each point repeats the parameter count of its run.
source = ColumnDataSource(data={
    "x": [compute for run in indexed_runs for compute in run[:, 0] * day_ratio],
    "y": [loss for run in indexed_runs for loss in run[:, 1]],
    "params": [params for params, run in zip(params_per_run, indexed_runs) for _ in range(len(run))],
})
loss_plot.scatter(
    'x', 'y', source=source,
    color=log_cmap('params', palette, min(params_per_run), max(params_per_run)),
    size=3,
)

# Distinct vertex indices of the hull's "good" simplices, ordered by the x coordinate of the point they refer to.
hull_indices = sorted(
    {index for pair in all_hull.simplices[all_hull.good] for index in pair},
    key=lambda index: bounded_points[index, 0],
)
########################################################################################################################
# Fit frontier
########################################################################################################################
# Rows of bounded_points selected by hull_indices, in that order.
hull_points = np.array([bounded_points[index] for index in hull_indices])
# Fit loss_fit (defined outside this file) to the hull points: x is column 0, y is column 1.
loss_popt, loss_pcov = curve_fit(loss_fit, hull_points[:, 0], hull_points[:, 1])
a, b, c = loss_popt
print(a, b, c)
# Abscissae for drawing the fitted curve: every observed x plus a 25% margin on both sides.
display_abscisses = np.array(
    [min(all_points[:, 0]) / 1.25, *sorted(all_points[:, 0]), max(all_points[:, 0]) * 1.25]
)
source = ColumnDataSource(data={
    "x": sorted(display_abscisses * day_ratio),
    "y": loss_fit(sorted(display_abscisses), *loss_popt),
})
loss_plot.line('x', 'y', source=source, line_width=1, line_alpha=0.8, color="red")
########################################################################################################################
# Set up param_plot
########################################################################################################################
param_plot = figure(plot_height=plot_height, plot_width=plot_width,
                    title="Optimal number of non-embedding parameters per floating-point operations budget",
                    tools="pan,reset,save,wheel_zoom,tap", active_scroll="wheel_zoom",
                    x_range=loss_plot.x_range,  # shared range object with loss_plot
                    y_range=[min(params_per_run), max(params_per_run)],
                    x_axis_type="log", y_axis_type="log",
                    x_axis_label="Floating-point operations (excluding embeddings & softmax)",
                    y_axis_label="Optimal number of non-embedding parameters", output_backend="webgl")
param_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
param_plot.add_layout(color_bar, "left")
# convert_to_logspace is defined outside this file; it is applied with the fitted loss parameters.
logspace_points = convert_to_logspace(bounded_points, *loss_popt)
logspace_losses_per_run = [convert_to_logspace(run, *loss_popt) for run in losses_per_run]
# For each run, the index of the point where (column 1 - column 0) of its converted curve is largest.
# (The unused `current_point` / `passed` bookkeeping of the former loop has been dropped.)
passing_points = [np.argmax(log_run[:, 1] - log_run[:, 0]) for log_run in logspace_losses_per_run]
# (x at the passing point, parameter count) for every run.
compute_at_passing_points = np.array([(losses_per_run[i][passing_point, 0], params_per_run[i])
                                      for i, passing_point in enumerate(passing_points)])
# Run indices (third column of hull_points) that own at least one hull point.
# Built once instead of once per comprehension iteration.
hull_run_indices = set(hull_points[:, 2])
# Same pairs, restricted to the runs that touch the hull.
compute_at_hull = np.array([(losses_per_run[i][passing_point, 0], params_per_run[i])
                            for i, passing_point in enumerate(passing_points) if i in hull_run_indices])
run_indices_at_hull = [i for i, passing_point in enumerate(passing_points) if i in hull_run_indices]
# param_fit (defined outside this file) is fitted against the log of the parameter counts.
param_popt, param_pcov = curve_fit(param_fit, compute_at_hull[:, 0], np.log(compute_at_hull[:, 1]))
d, e, f = param_popt
source = ColumnDataSource(data=dict(x=compute_at_hull[:, 0] * day_ratio,
                                    y=compute_at_hull[:, 1],
                                    params=[params for i, params in enumerate(params_per_run) if
                                            i in hull_run_indices]))
param_plot.scatter('x', 'y', source=source,
                   color=log_cmap('params', palette, min(params_per_run), max(params_per_run)))
# Abscissae for the fitted curve: the hull abscissae plus a 25% margin on both sides.
display_abscisses = np.array([min(compute_at_hull[:, 0]) / 1.25] + sorted(list(compute_at_hull[:, 0])) +
                             [max(compute_at_hull[:, 0]) * 1.25])
source = ColumnDataSource(data=dict(x=display_abscisses * day_ratio,
                                    y=safe_flo_to_param(display_abscisses, d, e, f)))
param_plot.line('x', 'y', source=source, line_width=1, line_alpha=0.8, color="orange")
########################################################################################################################
# Set up widgets
########################################################################################################################
# Upper end and initial value of the wall-time slider.
hours_end = 24
hours_initial = 3.23
gpu_dropdown = Select(title="GPU",
                      options=["V100", "P100", "P4", "K80", ],
                      value="V100", width=sidebar_width, sizing_mode="stretch_width")
amp_mode_dropdown = Select(title="AMP mode", options=["O0", "O1", "O2"], value="O0", width=sidebar_width,
                           sizing_mode="stretch_width")
# update_tip fills `tip` in place; its "hours" and "width" entries are read by the sliders and the callbacks.
tipping_width = tipping_point(gpu_dropdown.value, amp_mode_dropdown.value, param_popt)
tip = {}
update_tip(tip, tipping_width, gpu_dropdown.value, amp_mode_dropdown.value, loss_popt, param_popt)
hours_slider = Slider(title="Wall time (hours)", value=hours_initial, start=tip["hours"], end=hours_end, step=1 / 100,
                      width=sidebar_width, sizing_mode="stretch_width")
# The lower bound is the tipping-point duration converted to dollars. The previous code called
# dollars_to_hours on an hours value here, unlike gpu_update/amp_update which use hours_to_dollars.
dollars_slider = Slider(title="Budget (dollars)", value=hours_to_dollars(hours_initial, gpu_dropdown.value),
                        start=hours_to_dollars(tip["hours"], gpu_dropdown.value),
                        end=hours_to_dollars(hours_end, gpu_dropdown.value),
                        step=1 / 100, width=sidebar_width, sizing_mode="stretch_width")
# Empty Div used as vertical spacing in the sidebar.
input_buffer = Div(text="", width=sidebar_width, height=10,
                   style={"display": "block", "margin": "0 auto", "width": f"{sidebar_width}px",
                          "text-align": 'center'})
top_sidebar_div_style = {"display": "block", "margin": "0 auto", 'font-size': "125%",
                         "width": f"{sidebar_width}px", "text-align": 'center'}
energy_text = Div(text=energy_fill(hours_to_kWh(hours_slider.value, gpu_dropdown.value),
                                   hours_to_co2(hours_slider.value, gpu_dropdown.value)),
                  width=sidebar_width, height=45,
                  style=top_sidebar_div_style)
# Per-quantity change counters; the slider callbacks compare them to decide which slider started an update chain.
slider_moves = {"hours": 0, "dollars": 0, "kWh": 0, "co2": 0}
n_sliders = len(slider_moves)
# Initial operating point derived from the default slider and dropdown values.
width = hours_to_width(hours_slider.value, gpu_dropdown.value, amp_mode_dropdown.value, param_popt)
flo = width_to_flo(width, *param_popt)
optimal_params = safe_flo_to_param(flo / 24 / 3600, *param_popt)
final_loss = loss_fit(flo / 24 / 3600, *loss_popt)

# Shapes shown in the sidebar and reused by the code-generation buttons.
example_shape = {}
_depth, _width = optimal_model_shape(width, optimal_params)
example_shape.update(example_depth=_depth, example_width=_width)
_depth, _width = alternate_model_shape(width, optimal_params)
example_shape.update(alternate_depth=_depth, alternate_width=_width)

# Dashed guide lines marking the current operating point on the plots.
_guide_style = dict(line_alpha=0.7, line_dash='dashed', line_width=1)
flo_line = Span(location=flo, dimension='height', line_color='purple', **_guide_style)
loss_line = Span(location=final_loss, dimension='width', line_color='red', **_guide_style)
param_line = Span(location=optimal_params, dimension='width', line_color='orange', **_guide_style)
# flo_line is attached to both plots; each plot also gets its own horizontal line.
loss_plot.add_layout(flo_line)
loss_plot.add_layout(loss_line)
param_plot.add_layout(flo_line)
param_plot.add_layout(param_line)
# Styles for the sidebar captions (normal) and the values below them (large, bold).
sidebar_div_style = {"display": "block", "margin": "0 auto", "width": f"{sidebar_width}px", "text-align": 'center'}
big_sidebar_div_style = {**sidebar_div_style, 'font-size': "200%", 'font-weight': "bold"}

# Each result is a caption Div followed by a value Div that update_width rewrites.
static_loss_text = Div(
    text="Expected wt-103 validation loss:",
    width=sidebar_width, height=10, style=sidebar_div_style,
)
optimal_loss_text = Div(
    text=f"{final_loss:.2f}",
    width=sidebar_width, height=45,
    style={"display": "block", "margin": "0 auto", 'font-size': "200%",
           'font-weight': "bold", "width": f"{sidebar_width}px", "text-align": 'center'},
)
static_param_text = Div(
    text="Optimal number of non-embedding parameters:",
    width=sidebar_width, height=10, style=sidebar_div_style,
)
optimal_param_text = Div(
    text=f"{optimal_params:.2e}",
    width=sidebar_width, height=45, style=big_sidebar_div_style,
)
static_shape_text = Div(
    text="For example, this could be a model of",
    width=sidebar_width, height=10, style=sidebar_div_style,
)
optimal_shape_text = Div(
    text=f"{example_shape['example_depth']} layers of {example_shape['example_width']} dimensions",
    width=sidebar_width, height=30, style=big_sidebar_div_style,
)
static_altshape_text = Div(
    text="Or a model of",
    width=sidebar_width, height=10, style=sidebar_div_style,
)
optimal_altshape_text = Div(
    text=f"{example_shape['alternate_depth']} layers of {example_shape['alternate_width']} dimensions",
    width=sidebar_width, height=30, style=big_sidebar_div_style,
)
def compare_and_update(width):
    """Display ``width``, first raising it to the tipping-point width if it is below it.

    While ``width`` is below ``tip["width"]`` it is increased in steps of 5 (capped at
    ``tip["width"]``) and ``update_width`` is called after every step. Once the width is
    at or above the tipping point, the display is refreshed one last time and
    ``hours_slider.value`` is set to the matching duration.

    This is an iterative form of the former recursive implementation. It performs the
    same sequence of ``update_width`` calls but cannot hit the interpreter's recursion
    limit when the starting width is far below the tipping point.
    """
    # `not (a >= b)` rather than `a < b`: takes the same branch as the original even for NaN.
    while not width >= tip["width"]:
        width = min(tip["width"], width + 5)
        update_width(width)
    update_width(width)
    hours = width_to_hours(width, gpu_dropdown.value, amp_mode_dropdown.value, param_popt)
    hours_slider.value = hours
def update_width(width):
    """Move the guide lines and rewrite the result texts for a model of the given width."""
    compute = width_to_flo(width, *param_popt)
    flo_line.location = compute

    params = safe_flo_to_param(compute / 24 / 3600, *param_popt)
    loss = loss_fit(compute / 24 / 3600, *loss_popt)
    loss_line.location = loss
    param_line.location = params

    # Keep example_shape current: the "Initialize" buttons read it when clicked.
    depth, model_width = optimal_model_shape(width, params)
    example_shape.update(example_depth=depth, example_width=model_width)
    alt_depth, alt_width = alternate_model_shape(width, params)
    example_shape.update(alternate_depth=alt_depth, alternate_width=alt_width)

    optimal_shape_text.text = f"{depth} layers of {model_width} dimensions"
    optimal_altshape_text.text = f"{alt_depth} layers of {alt_width} dimensions"
    optimal_param_text.text = f"{params:.2e}"
    optimal_loss_text.text = f"{loss:.2f}"
def hours_update(attrname, old, new):
    """on_change handler for the wall-time slider's ``value``."""
    slider_moves["hours"] += 1
    total_moves = sum(slider_moves.values())
    # True when this slider, not another one, started the current chain of updates.
    hours_moved_first = total_moves <= n_sliders * slider_moves["hours"] - n_sliders + 1
    # NOTE(review): indentation was lost in this copy of the file; the dollars and energy updates are
    # assumed to sit inside the guard and the width update outside it - confirm against the original.
    if hours_moved_first:
        dollars_slider.value = hours_to_dollars(hours_slider.value, gpu_dropdown.value)
        kwh = hours_to_kWh(hours_slider.value, gpu_dropdown.value)
        co2 = hours_to_co2(hours_slider.value, gpu_dropdown.value)
        energy_text.text = energy_fill(kwh, co2)
    update_width(hours_to_width(hours_slider.value, gpu_dropdown.value, amp_mode_dropdown.value, param_popt))
def dollars_update(attrname, old, new):
    """on_change handler for the budget slider's ``value``."""
    slider_moves["dollars"] += 1
    total_moves = sum(slider_moves.values())
    # True when the dollars slider started the current chain of updates.
    dollars_moved_first = total_moves <= n_sliders * slider_moves["dollars"] - n_sliders + 1
    # NOTE(review): indentation was lost in this copy of the file; both updates below are assumed to sit
    # inside the guard - confirm against the original.
    if dollars_moved_first:
        hours_slider.value = dollars_to_hours(dollars_slider.value, gpu_dropdown.value)
        kwh = hours_to_kWh(hours_slider.value, gpu_dropdown.value)
        co2 = hours_to_co2(hours_slider.value, gpu_dropdown.value)
        energy_text.text = energy_fill(kwh, co2)
def gpu_update(attrname, old, new):
    """on_change handler for the GPU dropdown's ``value``."""
    # Recompute the tipping point for the selected GPU and move both slider lower bounds to it.
    new_tipping_width = tipping_point(gpu_dropdown.value, amp_mode_dropdown.value, param_popt)
    update_tip(tip, new_tipping_width, gpu_dropdown.value, amp_mode_dropdown.value, loss_popt, param_popt)
    hours_slider.start = tip["hours"]
    dollars_slider.start = hours_to_dollars(tip["hours"], gpu_dropdown.value)

    if dollars_to_hours(dollars_slider.value, gpu_dropdown.value) != hours_slider.value:
        # The budget no longer matches the duration: rescale the budget range and recompute the
        # duration from the unchanged budget.
        dollars_slider.end = hours_to_dollars(hours_end, new)
        hours_slider.value = dollars_to_hours(dollars_slider.value, gpu_dropdown.value)
    else:
        compare_and_update(
            hours_to_width(hours_slider.value, gpu_dropdown.value, amp_mode_dropdown.value, param_popt)
        )

    # NOTE(review): indentation was lost in this copy of the file; the energy refresh is assumed to run
    # after both branches - confirm against the original.
    kwh = hours_to_kWh(hours_slider.value, gpu_dropdown.value)
    co2 = hours_to_co2(hours_slider.value, gpu_dropdown.value)
    energy_text.text = energy_fill(kwh, co2)
def amp_update(attrname, old, new):
    """on_change handler for the AMP-mode dropdown's ``value``."""
    new_tipping_width = tipping_point(gpu_dropdown.value, amp_mode_dropdown.value, param_popt)
    update_tip(tip, new_tipping_width, gpu_dropdown.value, amp_mode_dropdown.value, loss_popt, param_popt)

    # The width is computed before the slider bounds are moved, as in the original statement order.
    current_width = hours_to_width(hours_slider.value, gpu_dropdown.value, amp_mode_dropdown.value, param_popt)
    hours_slider.start = tip["hours"]
    dollars_slider.start = hours_to_dollars(tip["hours"], gpu_dropdown.value)
    compare_and_update(current_width)

    kwh = hours_to_kWh(hours_slider.value, gpu_dropdown.value)
    co2 = hours_to_co2(hours_slider.value, gpu_dropdown.value)
    energy_text.text = energy_fill(kwh, co2)
def loss_tap(event):
    """Tap handler for loss_plot: select the model whose fitted loss equals the tapped y value."""
    _, tapped_loss = event.x, event.y
    # loss -> compute -> parameter count -> width, using the two fitted curves.
    tapped_flo = loss_to_flo(tapped_loss, *loss_popt)
    tapped_params = safe_flo_to_param(tapped_flo, *param_popt)
    compare_and_update(param_to_width(tapped_params))


loss_plot.on_event(Tap, loss_tap)
def param_tap(event):
    """Tap handler for param_plot: select the model whose parameter count is the tapped y value.

    The width is routed through ``compare_and_update``, as ``loss_tap`` does, so a tap below the
    tipping point is raised to ``tip["width"]`` instead of pushing ``hours_slider.value`` under the
    slider's ``start``. ``compare_and_update`` sets ``hours_slider.value`` from the final width.
    """
    width = param_to_width(event.y)
    compare_and_update(width)


param_plot.on_event(Tap, param_tap)
# Wire each input widget's "value" property to its handler.
for _widget, _handler in (
        (hours_slider, hours_update),
        (dollars_slider, dollars_update),
        (gpu_dropdown, gpu_update),
        (amp_mode_dropdown, amp_update),
):
    _widget.on_change("value", _handler)
| ######################################################################################################################## | |
| # Buttons | |
| ######################################################################################################################## | |
def on_optimal_click():
    """Show the generated code for the currently displayed example shape in the code box."""
    model_width = example_shape['example_width']
    model_depth = example_shape['example_depth']
    code_box.text = hf_code(model_width, model_depth)
def on_alternate_click():
    """Show the generated code for the currently displayed alternate shape in the code box."""
    model_width = example_shape['alternate_width']
    model_depth = example_shape['alternate_depth']
    code_box.text = hf_code(model_width, model_depth)
# Heading shown above the input widgets.
input_text = Div(
    text="Choose a GPU, AMP mode, and budget:",
    width=sidebar_width,
    height=30,
    style={"display": "block", "margin": "0 auto", 'font-size': "125%",
           'font-weight': "bold", "width": f"{sidebar_width}px", "text-align": 'center'},
)

# One button per proposed shape; both write into code_box.
initialize_optimal = Button(width=175, label="Initialize in 🤗transformers!")
initialize_optimal.align = "center"
initialize_optimal.on_click(on_optimal_click)

# Empty Div used as spacing between the two result groups.
results_buffer = Div(text="", width=sidebar_width, height=5, style=sidebar_div_style)

initialize_alternate = Button(width=175, label="Initialize in 🤗transformers!")
initialize_alternate.align = "center"
initialize_alternate.on_click(on_alternate_click)

# Code-snippet look for the box that displays the generated code.
code_box_style = {
    "display": "block",
    "margin": "0 auto",
    "width": f"{sidebar_width + plot_width}px",
    "text-align": 'center',
    "white-space": "pre-wrap",
    "background": "#f4f4f4",
    "border": "1px solid #ddd",
    "border-left": "3px solid #f36d33",
    "color": "#666",
    "page-break-inside": "avoid",
    "font-family": "monospace",
    "font-size": "15px",
    "line-height": "1.6",
    "max-width": "100%",
    "overflow": "hidden",
    "min-height": "30px",
    "word-wrap": "break-word",
}
code_box = Div(
    text="Find the right model for you with the curves and sliders then click the buttons "
         "to display the corresponding 🤗transformers code here!",
    width=sidebar_width + plot_width,
    style=code_box_style,
    sizing_mode="scale_width",
)
code_box.align = "center"
########################################################################################################################
# Add write-up text
########################################################################################################################
text_width = "800px"
main_text_style = {
    "min-height": "100px",
    "overflow": "hidden",
    "display": "block",
    "margin": "auto",
    "width": text_width,
    "font-size": "18px",
}
formula_img_style_1 = {
    "min-height": "25px",
    "display": "block",
    "margin": "0 auto",
    "width": text_width,
    "height": "auto",
    "max-width": "100%",
    "max-height": "100%",
}
# Same as formula_img_style_1 except for a taller minimum height.
formula_img_style_2 = {**formula_img_style_1, "min-height": "50px"}

# md1..md4 are not defined in this file (presumably provided by the star imports).
text_1, text_2, text_3, text_4 = (Div(text=markup, style=main_text_style) for markup in (md1, md2, md3, md4))
########################################################################################################################
# Loss plot in write-up
########################################################################################################################
# Smaller copy of loss_plot for the write-up: no crosshair tool and no fitted curve until the button is clicked.
in_text_loss_plot = figure(
    plot_height=in_text_plot_height,
    plot_width=in_text_plot_width,
    title="Validation loss during training for an array of models of different sizes",
    tools="pan,reset,save,wheel_zoom,tap",
    active_scroll="wheel_zoom",
    x_range=[min(all_points[:, 0]) * day_ratio, max(all_points[:, 0]) * day_ratio],
    y_range=[min(all_points[:, 1]), max(all_points[:, 1])],
    x_axis_type="log",
    y_axis_type="log",
    x_axis_label="Floating-point operations (excluding embeddings & softmax)",
    y_axis_label="Validation loss on Wikitext-103",
    output_backend="webgl",
)
in_text_loss_plot.add_layout(color_bar, "left")
in_text_loss_plot.align = "center"

# One polyline per run, coloured by that run's parameter count.
source = ColumnDataSource(data={
    "xs": [run[:, 0] * day_ratio for run in indexed_runs],
    "ys": [run[:, 1] for run in indexed_runs],
    "params": params_per_run,
})
in_text_loss_plot.multi_line(
    'xs', 'ys', source=source,
    color=log_cmap('params', palette, min(params_per_run), max(params_per_run)),
)

# The same points as a flat scatter; each point repeats the parameter count of its run.
source = ColumnDataSource(data={
    "x": [compute for run in indexed_runs for compute in run[:, 0] * day_ratio],
    "y": [loss for run in indexed_runs for loss in run[:, 1]],
    "params": [params for params, run in zip(params_per_run, indexed_runs) for _ in range(len(run))],
})
in_text_loss_plot.scatter(
    'x', 'y', source=source,
    color=log_cmap('params', palette, min(params_per_run), max(params_per_run)),
    size=3,
)
# Earlier per-run plotting, kept for reference:
# for i, run in indexed_runs.items():
# source = ColumnDataSource(data=dict(x=run[:, 0] * day_ratio, y=run[:, 1]))
# in_text_loss_plot.line('x', 'y', source=source, line_width=1, line_alpha=0.6, color=color_list[i])
# in_text_loss_plot.scatter('x', 'y', source=source, line_width=1, line_alpha=0.6, color=color_list[i])
# Smaller copy of param_plot for the write-up; shares its x range with in_text_loss_plot.
in_text_param_plot = figure(
    plot_height=in_text_plot_height,
    plot_width=in_text_plot_width,
    title="Optimal number of non-embedding parameters per floating-point operations budget",
    tools="pan,reset,save,wheel_zoom,tap",
    active_scroll="wheel_zoom",
    x_range=in_text_loss_plot.x_range,
    y_range=[min(params_per_run), max(params_per_run)],
    x_axis_type="log",
    y_axis_type="log",
    x_axis_label="Floating-point operations (excluding embeddings & softmax)",
    y_axis_label="Optimal number of non-embedding parameters",
    output_backend="webgl",
)
in_text_param_plot.add_layout(color_bar, "left")
in_text_param_plot.align = "center"
# Earlier per-run plotting, kept for reference:
# for i, run_apex in enumerate(compute_at_hull):
# source = ColumnDataSource(data=dict(x=[compute_at_hull[i, 0] * day_ratio], y=[compute_at_hull[i, 1]]))
# in_text_param_plot.scatter('x', 'y', source=source, color=color_list[run_indices_at_hull[i]])

# Parameter counts of the runs whose index appears in the third column of hull_points.
_runs_on_hull = set(hull_points[:, 2])
source = ColumnDataSource(data={
    "x": compute_at_hull[:, 0] * day_ratio,
    "y": compute_at_hull[:, 1],
    "params": [params for i, params in enumerate(params_per_run) if i in _runs_on_hull],
})
in_text_param_plot.scatter(
    'x', 'y', source=source,
    color=log_cmap('params', palette, min(params_per_run), max(params_per_run)),
)

# Buttons that draw the fitted curves on the two write-up figures (both are labelled "Fit!").
training_button = Button(width=175, label="Fit!")
training_button.align = "center"
fit_button = Button(width=175, label="Fit!")
fit_button.align = "center"
def on_train_click():
    """Draw the fitted loss curve on the write-up loss plot."""
    observed = all_points[:, 0]
    # Every observed abscissa plus a 25% margin on both sides.
    abscissae = np.array([min(observed) / 1.25, *sorted(observed), max(observed) * 1.25])
    curve = ColumnDataSource(data={
        "x": sorted(abscissae * day_ratio),
        "y": loss_fit(sorted(abscissae), *loss_popt),
    })
    in_text_loss_plot.line('x', 'y', source=curve, line_width=1, line_alpha=1, color="red")
def on_fit_click():
    """Draw the fitted parameter-count curve on the write-up parameter plot."""
    hull_compute = compute_at_hull[:, 0]
    # Hull abscissae plus a 25% margin on both sides.
    abscissae = np.array([min(hull_compute) / 1.25, *sorted(hull_compute), max(hull_compute) * 1.25])
    curve = ColumnDataSource(data={
        "x": abscissae * day_ratio,
        "y": safe_flo_to_param(abscissae, d, e, f),
    })
    in_text_param_plot.line('x', 'y', source=curve, line_width=1, line_alpha=0.8, color="orange")
training_button.on_click(on_train_click)
fit_button.on_click(on_fit_click)

# Write-up shown above and below the interactive app.
before_text = column(text_1, training_button, in_text_loss_plot, text_2, fit_button, in_text_param_plot, text_3)
after_text = column(text_4)
########################################################################################################################
# Set up layouts and add to document
########################################################################################################################
inputs = column(
    input_text,
    gpu_dropdown,
    amp_mode_dropdown,
    hours_slider,
    dollars_slider,
    input_buffer,
    energy_text,
    sizing_mode="scale_width", width=sidebar_width, height=plot_height,
)
results = column(
    static_loss_text, optimal_loss_text,
    static_param_text, optimal_param_text,
    static_shape_text, optimal_shape_text, initialize_optimal,
    results_buffer,
    static_altshape_text, optimal_altshape_text, initialize_alternate,
    sizing_mode="scale_width", width=sidebar_width, height=plot_height,
)
# Earlier layout, kept for reference:
# app = column(row(inputs, loss_plot, sizing_mode="scale_width"), row(results, param_plot, sizing_mode="scale_width"),
# code_box, sizing_mode="scale_width")

# Sidebar (inputs above results) next to the two stacked plots, with the code box underneath.
app = column(
    row(
        column(inputs, results, sizing_mode="fixed"),
        column(loss_plot, param_plot, sizing_mode="stretch_width"),
    ),
    code_box,
    sizing_mode="scale_width",
)
for _section in (before_text, app, after_text):
    _section.align = "center"

main_body = column(before_text, app, after_text, sizing_mode="scale_width")
curdoc().add_root(main_body)
curdoc().title = "How big should my language model be ?"