Spaces:
Running
Running
| import matplotlib.pyplot as plt | |
| from matplotlib.ticker import ScalarFormatter | |
| from collections import defaultdict | |
| import numpy as np | |
| import json | |
| import os | |
def get_aggregated_stats(data_list):
    """Aggregate per-byte-position compression rates across multiple runs.

    Each element of ``data_list`` is a ``(rate, count)`` pair of 1-D numpy
    arrays of equal length: ``rate[i]`` is the compression rate at byte
    position ``i`` and ``count[i]`` the number of samples contributing to it.
    Runs may have different lengths; a shorter run simply stops contributing
    past its end.

    Returns:
        ``(avg_rate, total_counts)`` — the count-weighted mean rate per
        position and the summed counts per position, both of length equal to
        the longest run. Positions with zero total count come out as
        ``nan``/``inf`` (division warnings are suppressed via ``np.errstate``).
        Returns ``(None, None)`` for an empty input list.

    Raises:
        ValueError: if any pair's rate and count arrays differ in length.
    """
    if not data_list:
        return None, None
    # Validate pairs and find the longest run: all runs are aligned from
    # position 0, so the accumulation buffers must cover the longest one.
    # (Was a bare `assert`, which is stripped under `python -O`.)
    for rate, count in data_list:
        if len(rate) != len(count):
            raise ValueError("rate and count arrays must have equal length")
    max_len = max(len(count) for _, count in data_list)
    # Count-weighted average: accumulate rate*count and count separately,
    # then divide once at the end.
    numerator = np.zeros(max_len)
    denominator = np.zeros(max_len)
    for rate, count in data_list:
        n = len(rate)
        numerator[:n] += rate * count
        denominator[:n] += count
    with np.errstate(divide="ignore", invalid="ignore"):
        avg_rate = numerator / denominator
    return avg_rate, denominator
def process_files(file_paths):
    """Load evaluation JSON result files and group raw curves by model.

    Returns:
        ``(model_raw_data, dataset_names)`` where ``model_raw_data`` maps a
        cleaned model name to a list of ``(rates, counts)`` numpy-array pairs
        and ``dataset_names`` is the set of dataset basenames seen in the
        files' ``data_path`` fields.
    """
    model_raw_data = defaultdict(list)
    dataset_names = set()
    for path in file_paths:
        with open(path, "r", encoding="utf-8") as fh:
            payload = json.load(fh)
        # Fall back to the file's basename when the model name is absent.
        fallback = os.path.basename(path).replace(".json", "")
        model_name = payload.get("model_name_or_path", fallback).split("/")[-1].replace(".pth", "")
        data_path = payload.get("data_path", "")
        if data_path:
            dataset_names.add(data_path.split("/")[-1])
        rates = payload.get("byte_wise_compression_rate", [])
        counts = payload.get("byte_wise_counts", [])
        # Skip files with missing or empty curve data.
        if rates and counts:
            model_raw_data[model_name].append((np.array(rates), np.array(counts)))
    return model_raw_data, dataset_names
def clean_label_for_path(l):
    """Shorten a result-file label for display.

    For ``*.json`` labels, strips the ``UncheatableEval-`` prefix, the
    ``.json`` suffix, and the final 20 characters (a trailing tag such as a
    date/hash). Any other label is returned unchanged.
    """
    if not l.endswith(".json"):
        return l
    trimmed = l.replace("UncheatableEval-", "").replace(".json", "")
    return trimmed[:-20]
def draw_long_context_plot(
    viz_mode,
    data_source_map,
    baseline_key,
    cutoff_ratio,
    smooth_window,
    start_offset,
    y_range,
    tail_drop_ratio=0.02,
):
    """Plot byte-position-wise compression curves, absolute or vs a baseline.

    Args:
        viz_mode: mode label; if it contains ``"Relative"``, every curve is
            drawn as a delta against the ``baseline_key`` curve.
        data_source_map: mapping of display label -> list of JSON result-file
            paths aggregated (count-weighted) into one curve per label.
        baseline_key: label whose curve is the zero line in relative mode.
        cutoff_ratio: a curve is truncated at the first position whose total
            sample count falls below ``cutoff_ratio * count_at_position_0``.
        smooth_window: moving-average window size; smoothing applies when > 1.
        start_offset: number of leading byte positions to skip when plotting.
        y_range: optional ``(ymin, ymax)``; either bound may be ``None``.
        tail_drop_ratio: fraction of trailing points to discard (noisy tail).

    Returns:
        A matplotlib ``Figure``, or ``None`` when there is nothing to plot
        (empty input, missing baseline in relative mode, or every curve
        shorter than ``start_offset``).
    """
    if not data_source_map:
        return None
    # --- Build one aggregated curve per label -----------------------------
    final_curves = {}
    for label, file_paths in data_source_map.items():
        raw_data_list = []
        for fp in file_paths:
            with open(fp, "r", encoding="utf-8") as f:
                d = json.load(f)
            rate = d.get("byte_wise_compression_rate", [])
            counts = d.get("byte_wise_counts", [])
            if rate and counts:
                raw_data_list.append((np.array(rate), np.array(counts)))
        avg_rate, total_counts = get_aggregated_stats(raw_data_list)
        if avg_rate is None:
            continue
        if len(total_counts) > 0:
            # Truncate where the sample count drops below a fraction of the
            # count at position 0: too few samples make the estimate noisy.
            threshold = total_counts[0] * cutoff_ratio
            low = np.where(total_counts < threshold)[0]
            if len(low) > 0:
                avg_rate = avg_rate[: low[0]]
        # Additionally drop a fixed fraction of the tail (typically noisy).
        if tail_drop_ratio > 0 and len(avg_rate) > 0:
            drop_count = int(len(avg_rate) * tail_drop_ratio)
            if drop_count > 0:
                avg_rate = avg_rate[:-drop_count]
        final_curves[label] = avg_rate
    # --- Turn curves into plottable items ---------------------------------
    plot_items = []
    is_relative = "Relative" in viz_mode
    if is_relative:
        if baseline_key not in final_curves:
            return None
        baseline_curve = final_curves[baseline_key]
        for label, curve in final_curves.items():
            if label == baseline_key:
                continue
            # Compare only over the overlap of the two curves.
            min_len = min(len(baseline_curve), len(curve))
            if min_len <= start_offset:
                continue
            delta = curve[:min_len] - baseline_curve[:min_len]
            plot_items.append({"label": clean_label_for_path(label), "values": delta[start_offset:], "start_x": start_offset + 1})
    else:  # Absolute
        for label, curve in final_curves.items():
            if len(curve) <= start_offset:
                continue
            plot_items.append({"label": clean_label_for_path(label), "values": curve[start_offset:], "start_x": start_offset + 1})
    if not plot_items:
        return None
    # --- Render -----------------------------------------------------------
    fig, ax = plt.subplots(figsize=(12, 6), dpi=300, layout="constrained")
    if is_relative:
        ax.axhline(0, color="black", linestyle="--", linewidth=1, alpha=0.5, label=f"Baseline: {clean_label_for_path(baseline_key)}")
    colors = plt.cm.tab10(np.linspace(0, 1, max(1, len(plot_items))))
    for idx, item in enumerate(plot_items):
        y_val = item["values"]
        x_val = np.arange(item["start_x"], item["start_x"] + len(y_val))
        color = colors[idx % 10]
        if smooth_window > 1 and len(y_val) > smooth_window:
            # Moving average; the raw curve stays visible as a faint backdrop.
            kernel = np.ones(smooth_window) / smooth_window
            y_smooth = np.convolve(y_val, kernel, mode="valid")
            x_smooth = x_val[smooth_window - 1 :]
            ax.plot(x_val, y_val, color=color, alpha=0.15, linewidth=0.5)
            ax.plot(x_smooth, y_smooth, color=color, linewidth=2, label=item["label"])
        else:
            ax.plot(x_val, y_val, color=color, linewidth=1.5, label=item["label"])
    ax.set_xscale("log")
    # Plain numbers on the log x-axis instead of 10^k notation.
    ax.xaxis.set_major_formatter(ScalarFormatter())
    if y_range:
        ymin, ymax = y_range
        if ymin is not None and ymax is not None:
            ax.set_ylim(ymin, ymax)
        elif ymin is not None:
            ax.set_ylim(bottom=ymin, top=ax.get_ylim()[1])
        elif ymax is not None:
            ax.set_ylim(bottom=ax.get_ylim()[0], top=ymax)
    mode_title = "Compression Rate Delta vs Byte Position" if is_relative else "Compression Rate vs Byte Position"
    ax.set_title(mode_title, fontsize=12, pad=8)
    # Legend outside the axes on the right; constrained layout makes room.
    ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), ncol=1, fontsize="small", frameon=False, borderaxespad=0.0)
    ax.set_xlabel("Byte Position")
    ax.set_ylabel("Compression Rate Delta (%)" if is_relative else "Compression Rate (%)")
    ax.grid(True, which="major", ls="-", alpha=0.8)
    ax.grid(True, which="minor", ls="--", alpha=0.3)
    return fig