import json
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter


def get_aggregated_stats(data_list):
    """Compute a count-weighted positional average over several rate series.

    Args:
        data_list: list of ``(rate, count)`` pairs; within each pair the two
            1-D arrays have equal length, but lengths may differ across pairs.

    Returns:
        ``(avg_rate, total_counts)`` — arrays whose length is the longest
        input series; positions past a shorter series contribute nothing.
        Returns ``(None, None)`` for an empty ``data_list``. Positions whose
        total count is zero come out as NaN (the division warning is
        deliberately suppressed).
    """
    if not data_list:
        return None, None

    # Longest series length — shorter series are implicitly zero-padded.
    max_len = 0
    for rate, count in data_list:
        assert len(rate) == len(count)
        max_len = max(max_len, len(count))

    # Weighted average per position: sum(rate*count) / sum(count).
    numerator = np.zeros(max_len)
    denominator = np.zeros(max_len)
    for rate, count in data_list:
        n = len(rate)
        numerator[:n] += rate * count
        denominator[:n] += count

    # Zero-count positions yield NaN instead of raising.
    with np.errstate(divide="ignore", invalid="ignore"):
        avg_rate = numerator / denominator
    return avg_rate, denominator


def process_files(file_paths):
    """Load evaluation JSON files and group raw (rate, count) arrays by model.

    Each file may carry ``model_name_or_path`` (falls back to the file's
    basename), ``data_path``, ``byte_wise_compression_rate`` and
    ``byte_wise_counts``.

    Returns:
        ``(model_raw_data, dataset_names)`` — a dict mapping model name to a
        list of ``(rate, count)`` numpy-array pairs, and the set of dataset
        basenames encountered.
    """
    model_raw_data = defaultdict(list)
    dataset_names = set()
    for file_path in file_paths:
        with open(file_path, "r", encoding="utf-8") as file:
            data = json.load(file)

        default_name = os.path.basename(file_path).replace(".json", "")
        model_name = (
            data.get("model_name_or_path", default_name)
            .split("/")[-1]
            .replace(".pth", "")
        )

        data_path = data.get("data_path", "")
        if data_path:
            dataset_names.add(data_path.split("/")[-1])

        rate_list = data.get("byte_wise_compression_rate", [])
        counts_list = data.get("byte_wise_counts", [])
        if rate_list and counts_list:
            model_raw_data[model_name].append(
                (np.array(rate_list), np.array(counts_list))
            )
    return model_raw_data, dataset_names


def clean_label_for_path(l):
    """Shorten a ``.json`` label for display; other labels pass through.

    Strips the ``UncheatableEval-`` prefix and the extension, then drops the
    last 20 characters.
    """
    if l.endswith(".json"):
        # NOTE(review): [:-20] presumably removes a fixed-width date/hash
        # suffix from the filename — confirm against the actual naming scheme.
        l = l.replace("UncheatableEval-", "").replace(".json", "")[:-20]
    return l


def draw_long_context_plot(
    viz_mode,
    data_source_map,
    baseline_key,
    cutoff_ratio,
    smooth_window,
    start_offset,
    y_range,
    tail_drop_ratio=0.02,
):
    """Plot byte-position compression-rate curves (absolute or vs a baseline).

    Args:
        viz_mode: mode string; containing ``"Relative"`` switches to
            delta-vs-baseline plotting.
        data_source_map: mapping of curve label -> list of JSON file paths.
        baseline_key: label used as the baseline in relative mode.
        cutoff_ratio: curves are cut where the sample count falls below
            ``cutoff_ratio`` times the count at position 0.
        smooth_window: moving-average window; > 1 enables smoothing (raw
            curve is drawn faintly underneath).
        start_offset: number of leading byte positions to skip.
        y_range: optional ``(ymin, ymax)``; either bound may be None.
        tail_drop_ratio: fraction of the (already cut) curve tail to discard.

    Returns:
        The matplotlib Figure, or None when there is nothing to plot (empty
        input, missing baseline in relative mode, or no plottable items).
    """
    if not data_source_map:
        return None

    final_curves = {}
    dataset_info = set()

    # Aggregate each label's files into a single averaged curve.
    for label, file_paths in data_source_map.items():
        raw_data_list = []
        for fp in file_paths:
            with open(fp, "r", encoding="utf-8") as f:
                d = json.load(f)
            rate = d.get("byte_wise_compression_rate", [])
            counts = d.get("byte_wise_counts", [])
            d_path = d.get("data_path", "")
            if d_path:
                dataset_info.add(d_path.split("/")[-1])
            if rate and counts:
                raw_data_list.append((np.array(rate), np.array(counts)))

        avg_rate, total_counts = get_aggregated_stats(raw_data_list)
        if avg_rate is not None:
            # Cut the curve where the sample count drops below the threshold
            # relative to the count at byte position 0.
            if len(total_counts) > 0:
                initial_count = total_counts[0]
                threshold = initial_count * cutoff_ratio
                cut_idx = np.where(total_counts < threshold)[0]
                if len(cut_idx) > 0:
                    avg_rate = avg_rate[: cut_idx[0]]
            # Apply the tail drop ratio (discard the noisiest trailing part).
            if tail_drop_ratio > 0 and len(avg_rate) > 0:
                drop_count = int(len(avg_rate) * tail_drop_ratio)
                if drop_count > 0:
                    avg_rate = avg_rate[:-drop_count]
            final_curves[label] = avg_rate

    plot_items = []
    is_relative = "Relative" in viz_mode
    if is_relative:
        if baseline_key not in final_curves:
            return None
        baseline_curve = final_curves[baseline_key]
        for label, curve in final_curves.items():
            if label == baseline_key:
                continue
            min_len = min(len(baseline_curve), len(curve))
            if min_len <= start_offset:
                continue
            delta = curve[:min_len] - baseline_curve[:min_len]
            plot_items.append(
                {
                    "label": clean_label_for_path(label),
                    "values": delta[start_offset:],
                    "start_x": start_offset + 1,
                }
            )
    else:  # Absolute
        for label, curve in final_curves.items():
            if len(curve) <= start_offset:
                continue
            plot_items.append(
                {
                    "label": clean_label_for_path(label),
                    "values": curve[start_offset:],
                    "start_x": start_offset + 1,
                }
            )

    if not plot_items:
        return None

    fig, ax = plt.subplots(figsize=(12, 6), dpi=300, layout="constrained")

    if is_relative:
        ax.axhline(
            0,
            color="black",
            linestyle="--",
            linewidth=1,
            alpha=0.5,
            label=f"Baseline: {clean_label_for_path(baseline_key)}",
        )

    colors = plt.cm.tab10(np.linspace(0, 1, max(1, len(plot_items))))
    for idx, item in enumerate(plot_items):
        y_val = item["values"]
        x_val = np.arange(item["start_x"], item["start_x"] + len(y_val))
        # BUGFIX: the palette array is sized len(plot_items), not 10, so the
        # modulus must match its actual length.
        color = colors[idx % len(colors)]
        if smooth_window > 1 and len(y_val) > smooth_window:
            # Moving average; 'valid' mode shortens the series, so the x axis
            # starts smooth_window-1 positions later.
            kernel = np.ones(smooth_window) / smooth_window
            y_smooth = np.convolve(y_val, kernel, mode="valid")
            x_smooth = x_val[smooth_window - 1 :]
            ax.plot(x_val, y_val, color=color, alpha=0.15, linewidth=0.5)
            ax.plot(x_smooth, y_smooth, color=color, linewidth=2, label=item["label"])
        else:
            ax.plot(x_val, y_val, color=color, linewidth=1.5, label=item["label"])

    ax.set_xscale("log")
    ax.xaxis.set_major_formatter(ScalarFormatter())

    if y_range:
        ymin, ymax = y_range
        if ymin is not None and ymax is not None:
            ax.set_ylim(ymin, ymax)
        elif ymin is not None:
            current_ymax = ax.get_ylim()[1]
            ax.set_ylim(bottom=ymin, top=current_ymax)
        elif ymax is not None:
            current_ymin = ax.get_ylim()[0]
            ax.set_ylim(bottom=current_ymin, top=ymax)

    mode_title = (
        "Compression Rate vs Byte Position"
        if not is_relative
        else "Compression Rate Delta vs Byte Position"
    )
    ax.set_title(mode_title, fontsize=12, pad=8)
    ax.legend(
        loc="upper left",
        bbox_to_anchor=(1.02, 1),
        ncol=1,
        fontsize="small",
        frameon=False,
        borderaxespad=0.0,
    )
    ax.set_xlabel("Byte Position")
    ax.set_ylabel(
        "Compression Rate (%)" if not is_relative else "Compression Rate Delta (%)"
    )
    ax.grid(True, which="major", ls="-", alpha=0.8)
    ax.grid(True, which="minor", ls="--", alpha=0.3)
    return fig