Jellyfish042 commited on
Commit
1be7b2f
·
1 Parent(s): c1b1328
Files changed (1) hide show
  1. app.py +123 -85
app.py CHANGED
@@ -40,15 +40,18 @@ metric_to_sheet = {
40
  "Bits Per Character (BPC)": "bpc",
41
  "Bits Per Byte (BPB)": "bpb",
42
  }
43
- model_size_to_file_name = {
44
- ">20B": "20b+",
45
  "~14B": "14b",
46
  # "~9B": "9b",
47
  "~7B": "7b",
48
  "~3B": "3b",
49
  "~1.5B": "1b5",
50
- "Other": "other",
51
- }
 
 
 
52
 
53
 
54
  def read_about_md():
@@ -264,27 +267,33 @@ def filter_pareto_frontier(x_values, y_values, names):
264
  return [], [], []
265
 
266
 
267
- def fit_power_law_with_offset(x_values, y_values):
268
- """
269
- 使用带偏置的幂律拟合原始数据
270
- 返回: (params, raw_rmse, log_rmse, fit_x, fit_y)
271
- """
272
  x_arr = np.array(x_values)
273
  y_arr = np.array(y_values)
274
 
275
- # 初始参数估计
276
- # 使用简单的幂律拟合作为初始值
277
- log_x = np.log10(x_arr)
278
- log_y = np.log10(y_arr)
279
- slope, intercept = np.polyfit(log_x, log_y, 1)
280
-
281
- a_init = 10**intercept
282
- b_init = slope
283
- c_init = 0 # 偏置初始值设为0
284
-
285
- try:
286
- # 使用curve_fit进行非线性拟合
287
- params, _ = curve_fit(
 
 
 
 
 
 
288
  power_law_with_offset,
289
  x_arr,
290
  y_arr,
@@ -304,14 +313,12 @@ def fit_power_law_with_offset(x_values, y_values):
304
  log_y_pred = np.log10(y_pred)
305
  log_rmse = np.sqrt(np.mean((log_y_actual - log_y_pred) ** 2))
306
 
307
- # 生成拟合曲线的点
308
- x_min, x_max = min(x_values), max(x_values)
309
- fit_x = np.linspace(x_min * 0.8, x_max * 1.2, 100)
310
- fit_y = power_law_with_offset(fit_x, a, b, c)
311
-
312
- return params, raw_rmse, log_rmse, fit_x, fit_y
313
- except Exception as e:
314
- print(f"Fitting failed: {e}")
315
  # 如果拟合失败,返回简单幂律拟合结果
316
  a = a_init
317
  b = b_init
@@ -328,11 +335,9 @@ def fit_power_law_with_offset(x_values, y_values):
328
  log_y_pred = np.log10(y_pred)
329
  log_rmse = np.sqrt(np.mean((log_y_actual - log_y_pred) ** 2))
330
 
331
- x_min, x_max = min(x_values), max(x_values)
332
- fit_x = np.linspace(x_min * 0.8, x_max * 1.2, 100)
333
- fit_y = a * np.power(fit_x, b)
334
-
335
- return params, raw_rmse, log_rmse, fit_x, fit_y
336
 
337
 
338
  def create_scaling_plot(data_manager: DataManager, period: str, use_pareto: bool = False):
@@ -373,16 +378,32 @@ def create_scaling_plot(data_manager: DataManager, period: str, use_pareto: bool
373
  else:
374
  fit_x_values, fit_y_values, fit_names = x_values, y_values, names
375
 
376
- x_min, x_max = np.log10(min(x_values)), np.log10(max(x_values))
377
- y_min, y_max = np.log10(min(y_values)), np.log10(max(y_values))
378
- x_dtick = (x_max - x_min) / 4
379
- y_dtick = (y_max - y_min) / 4
380
-
381
- # 使用筛选后的数据进行拟合
382
- params, raw_rmse, log_rmse, fit_x, fit_y = fit_power_law_with_offset(fit_x_values, fit_y_values)
383
- a, b, c = params
384
-
385
- fig = go.Figure()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
  # 添加所有数据点
388
  fig.add_trace(
@@ -423,16 +444,16 @@ def create_scaling_plot(data_manager: DataManager, period: str, use_pareto: bool
423
  fit_label = f"{fit_type}: y = {a:.2f} × x^{b:.3f}<br>Raw RMSE: {raw_rmse:.2f}, Log-RMSE: {log_rmse:.3f}"
424
  else:
425
  fit_label = f"{fit_type}: y = {a:.2f} × x^{b:.3f} + {c:.2f}<br>Raw RMSE: {raw_rmse:.2f}, Log-RMSE: {log_rmse:.3f}"
426
- fig.add_trace(
427
- go.Scatter(
428
- x=fit_x.tolist(),
429
- y=fit_y.tolist(),
430
- mode="lines",
431
- name=fit_label,
432
- line=dict(color="#FF6B6B", width=2, dash="dash"),
433
- hoverinfo="skip",
434
- )
435
- )
436
 
437
  title_suffix = " (Pareto Frontier)" if use_pareto else ""
438
  fig.update_layout(
@@ -550,8 +571,15 @@ def create_category_scaling_plot(data_manager: DataManager, period: str, selecte
550
  fit_x_vals, fit_y_vals, fit_name_vals = x_vals, y_vals, name_vals
551
 
552
  # 使用筛选后的数据进行拟合
553
- params, raw_rmse, log_rmse, fit_x, fit_y = fit_power_law_with_offset(fit_x_vals, fit_y_vals)
554
- a, b, c = params
 
 
 
 
 
 
 
555
 
556
  # 构建数据集名称列表(用于hover显示)
557
  datasets_label = f"Average of {len(selected_datasets)} datasets"
@@ -595,16 +623,16 @@ def create_category_scaling_plot(data_manager: DataManager, period: str, selecte
595
  fit_label = f"{fit_type}: y = {a:.2f} × x^{b:.3f}<br>Raw RMSE: {raw_rmse:.2f}, Log-RMSE: {log_rmse:.3f}"
596
  else:
597
  fit_label = f"{fit_type}: y = {a:.2f} × x^{b:.3f} + {c:.2f}<br>Raw RMSE: {raw_rmse:.2f}, Log-RMSE: {log_rmse:.3f}"
598
- fig.add_trace(
599
- go.Scatter(
600
- x=fit_x.tolist(),
601
- y=fit_y.tolist(),
602
- mode="lines",
603
- name=fit_label,
604
- line=dict(color="#FF6B6B", width=2, dash="dash"),
605
- hoverinfo="skip",
606
- )
607
- )
608
  else:
609
  # 单独显示模式:为每个数据集创建散点图和拟合线
610
  for idx, dataset in enumerate(selected_datasets):
@@ -639,8 +667,15 @@ def create_category_scaling_plot(data_manager: DataManager, period: str, selecte
639
  fit_x_vals, fit_y_vals, fit_name_vals = x_vals, y_vals, name_vals
640
 
641
  # 使用筛选后的数据进行拟合
642
- params, raw_rmse, log_rmse, fit_x, fit_y = fit_power_law_with_offset(fit_x_vals, fit_y_vals)
643
- a, b, c = params
 
 
 
 
 
 
 
644
 
645
  # 添加所有数据点
646
  fig.add_trace(
@@ -683,18 +718,18 @@ def create_category_scaling_plot(data_manager: DataManager, period: str, selecte
683
  fit_label = f"{dataset} {fit_type}: y = {a:.2f} × x^{b:.3f}<br>Raw RMSE: {raw_rmse:.2f}, Log-RMSE: {log_rmse:.3f}"
684
  else:
685
  fit_label = f"{dataset} {fit_type}: y = {a:.2f} × x^{b:.3f} + {c:.2f}<br>Raw RMSE: {raw_rmse:.2f}, Log-RMSE: {log_rmse:.3f}"
686
- fig.add_trace(
687
- go.Scatter(
688
- x=fit_x.tolist(),
689
- y=fit_y.tolist(),
690
- mode="lines",
691
- name=fit_label,
692
- line=dict(color=color, width=2, dash="dash"),
693
- hoverinfo="skip",
694
- legendgroup=dataset,
695
- showlegend=True,
696
- )
697
- )
698
 
699
  if not all_x_values or not all_y_values:
700
  fig = go.Figure()
@@ -702,10 +737,13 @@ def create_category_scaling_plot(data_manager: DataManager, period: str, selecte
702
  return fig
703
 
704
  # 计算全局坐标范围
705
- x_min, x_max = np.log10(min(all_x_values)), np.log10(max(all_x_values))
706
- y_min, y_max = np.log10(min(all_y_values)), np.log10(max(all_y_values))
707
- x_dtick = (x_max - x_min) / 4
708
- y_dtick = (y_max - y_min) / 4
 
 
 
709
 
710
  fig.update_layout(
711
  title={"text": "Scaling Law by Dataset", "x": 0.5, "xanchor": "center", "yanchor": "top"},
 
40
  "Bits Per Character (BPC)": "bpc",
41
  "Bits Per Byte (BPB)": "bpb",
42
  }
43
+ model_size_to_file_name = {
44
+ ">20B": "20b+",
45
  "~14B": "14b",
46
  # "~9B": "9b",
47
  "~7B": "7b",
48
  "~3B": "3b",
49
  "~1.5B": "1b5",
50
+ "Other": "other",
51
+ }
52
+ SCALING_EXTRAPOLATE_MAX_B = 10000
53
+ SCALING_FIT_POINTS = 200
54
+ FIT_LINE_HOVER_TEMPLATE = "Params: %{x:.2f}B<br>Predicted CR: %{y:.2f}%<extra></extra>"
55
 
56
 
57
  def read_about_md():
 
267
  return [], [], []
268
 
269
 
270
+ def fit_power_law_with_offset(x_values, y_values, extrapolate_max_b=None, num_points=SCALING_FIT_POINTS):
271
+ """
272
+ 使用带偏置的幂律拟合原始数据
273
+ 返回: (params, raw_rmse, log_rmse, fit_x, fit_y)
274
+ """
275
  x_arr = np.array(x_values)
276
  y_arr = np.array(y_values)
277
 
278
+ # 初始参数估计
279
+ # 使用简单的幂律拟合作为初始值
280
+ log_x = np.log10(x_arr)
281
+ log_y = np.log10(y_arr)
282
+ slope, intercept = np.polyfit(log_x, log_y, 1)
283
+
284
+ a_init = 10**intercept
285
+ b_init = slope
286
+ c_init = 0 # 偏置初始值设为0
287
+ x_min, x_max = x_arr.min(), x_arr.max()
288
+ x_start = max(x_min * 0.8, np.finfo(float).tiny)
289
+ x_end = x_max * 1.2
290
+ if extrapolate_max_b is not None:
291
+ x_end = max(x_end, extrapolate_max_b)
292
+ fit_x = np.logspace(np.log10(x_start), np.log10(x_end), num_points)
293
+
294
+ try:
295
+ # 使用curve_fit进行非线性拟合
296
+ params, _ = curve_fit(
297
  power_law_with_offset,
298
  x_arr,
299
  y_arr,
 
313
  log_y_pred = np.log10(y_pred)
314
  log_rmse = np.sqrt(np.mean((log_y_actual - log_y_pred) ** 2))
315
 
316
+ # 生成拟合曲线的点
317
+ fit_y = power_law_with_offset(fit_x, a, b, c)
318
+
319
+ return params, raw_rmse, log_rmse, fit_x, fit_y
320
+ except Exception as e:
321
+ print(f"Fitting failed: {e}")
 
 
322
  # 如果拟合失败,返回简单幂律拟合结果
323
  a = a_init
324
  b = b_init
 
335
  log_y_pred = np.log10(y_pred)
336
  log_rmse = np.sqrt(np.mean((log_y_actual - log_y_pred) ** 2))
337
 
338
+ fit_y = a * np.power(fit_x, b)
339
+
340
+ return params, raw_rmse, log_rmse, fit_x, fit_y
 
 
341
 
342
 
343
  def create_scaling_plot(data_manager: DataManager, period: str, use_pareto: bool = False):
 
378
  else:
379
  fit_x_values, fit_y_values, fit_names = x_values, y_values, names
380
 
381
+ x_min_val = min(x_values)
382
+ x_max_val = max(x_values)
383
+ x_axis_max = x_max_val
384
+
385
+ # 使用筛选后的数据进行拟合
386
+ params, raw_rmse, log_rmse, fit_x, fit_y = fit_power_law_with_offset(
387
+ fit_x_values,
388
+ fit_y_values,
389
+ extrapolate_max_b=SCALING_EXTRAPOLATE_MAX_B,
390
+ )
391
+ a, b, c = params
392
+ y_min_val = min(y_values)
393
+ y_max_val = max(y_values)
394
+ positive_fit_y = fit_y[fit_y > 0]
395
+ if positive_fit_y.size:
396
+ y_min_val = min(y_min_val, float(positive_fit_y.min()))
397
+ y_max_val = max(y_max_val, float(positive_fit_y.max()))
398
+
399
+ x_min = np.log10(x_min_val)
400
+ x_max = np.log10(x_axis_max)
401
+ y_min = np.log10(y_min_val)
402
+ y_max = np.log10(y_max_val)
403
+ x_dtick = (x_max - x_min) / 4
404
+ y_dtick = (y_max - y_min) / 4
405
+
406
+ fig = go.Figure()
407
 
408
  # 添加所有数据点
409
  fig.add_trace(
 
444
  fit_label = f"{fit_type}: y = {a:.2f} × x^{b:.3f}<br>Raw RMSE: {raw_rmse:.2f}, Log-RMSE: {log_rmse:.3f}"
445
  else:
446
  fit_label = f"{fit_type}: y = {a:.2f} × x^{b:.3f} + {c:.2f}<br>Raw RMSE: {raw_rmse:.2f}, Log-RMSE: {log_rmse:.3f}"
447
+ fig.add_trace(
448
+ go.Scatter(
449
+ x=fit_x.tolist(),
450
+ y=fit_y.tolist(),
451
+ mode="lines",
452
+ name=fit_label,
453
+ line=dict(color="#FF6B6B", width=2, dash="dash"),
454
+ hovertemplate=FIT_LINE_HOVER_TEMPLATE,
455
+ )
456
+ )
457
 
458
  title_suffix = " (Pareto Frontier)" if use_pareto else ""
459
  fig.update_layout(
 
571
  fit_x_vals, fit_y_vals, fit_name_vals = x_vals, y_vals, name_vals
572
 
573
  # 使用筛选后的数据进行拟合
574
+ params, raw_rmse, log_rmse, fit_x, fit_y = fit_power_law_with_offset(
575
+ fit_x_vals,
576
+ fit_y_vals,
577
+ extrapolate_max_b=SCALING_EXTRAPOLATE_MAX_B,
578
+ )
579
+ a, b, c = params
580
+ positive_fit_y = fit_y[fit_y > 0]
581
+ if positive_fit_y.size:
582
+ all_y_values.extend(positive_fit_y.tolist())
583
 
584
  # 构建数据集名称列表(用于hover显示)
585
  datasets_label = f"Average of {len(selected_datasets)} datasets"
 
623
  fit_label = f"{fit_type}: y = {a:.2f} × x^{b:.3f}<br>Raw RMSE: {raw_rmse:.2f}, Log-RMSE: {log_rmse:.3f}"
624
  else:
625
  fit_label = f"{fit_type}: y = {a:.2f} × x^{b:.3f} + {c:.2f}<br>Raw RMSE: {raw_rmse:.2f}, Log-RMSE: {log_rmse:.3f}"
626
+ fig.add_trace(
627
+ go.Scatter(
628
+ x=fit_x.tolist(),
629
+ y=fit_y.tolist(),
630
+ mode="lines",
631
+ name=fit_label,
632
+ line=dict(color="#FF6B6B", width=2, dash="dash"),
633
+ hovertemplate=FIT_LINE_HOVER_TEMPLATE,
634
+ )
635
+ )
636
  else:
637
  # 单独显示模式:为每个数据集创建散点图和拟合线
638
  for idx, dataset in enumerate(selected_datasets):
 
667
  fit_x_vals, fit_y_vals, fit_name_vals = x_vals, y_vals, name_vals
668
 
669
  # 使用筛选后的数据进行拟合
670
+ params, raw_rmse, log_rmse, fit_x, fit_y = fit_power_law_with_offset(
671
+ fit_x_vals,
672
+ fit_y_vals,
673
+ extrapolate_max_b=SCALING_EXTRAPOLATE_MAX_B,
674
+ )
675
+ a, b, c = params
676
+ positive_fit_y = fit_y[fit_y > 0]
677
+ if positive_fit_y.size:
678
+ all_y_values.extend(positive_fit_y.tolist())
679
 
680
  # 添加所有数据点
681
  fig.add_trace(
 
718
  fit_label = f"{dataset} {fit_type}: y = {a:.2f} × x^{b:.3f}<br>Raw RMSE: {raw_rmse:.2f}, Log-RMSE: {log_rmse:.3f}"
719
  else:
720
  fit_label = f"{dataset} {fit_type}: y = {a:.2f} × x^{b:.3f} + {c:.2f}<br>Raw RMSE: {raw_rmse:.2f}, Log-RMSE: {log_rmse:.3f}"
721
+ fig.add_trace(
722
+ go.Scatter(
723
+ x=fit_x.tolist(),
724
+ y=fit_y.tolist(),
725
+ mode="lines",
726
+ name=fit_label,
727
+ line=dict(color=color, width=2, dash="dash"),
728
+ hovertemplate=FIT_LINE_HOVER_TEMPLATE,
729
+ legendgroup=dataset,
730
+ showlegend=True,
731
+ )
732
+ )
733
 
734
  if not all_x_values or not all_y_values:
735
  fig = go.Figure()
 
737
  return fig
738
 
739
  # 计算全局坐标范围
740
+ x_min_val = min(all_x_values)
741
+ x_max_val = max(all_x_values)
742
+ x_axis_max = x_max_val
743
+ x_min, x_max = np.log10(x_min_val), np.log10(x_axis_max)
744
+ y_min, y_max = np.log10(min(all_y_values)), np.log10(max(all_y_values))
745
+ x_dtick = (x_max - x_min) / 4
746
+ y_dtick = (y_max - y_min) / 4
747
 
748
  fig.update_layout(
749
  title={"text": "Scaling Law by Dataset", "x": 0.5, "xanchor": "center", "yanchor": "top"},