# UncheatableEval / longctx_utils.py
# Author: Jellyfish042
# 2026-01 update (commit a607ab6)
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
from collections import defaultdict
import numpy as np
import json
import os
def get_aggregated_stats(data_list):
    """Aggregate per-byte-position compression rates across multiple runs.

    Parameters
    ----------
    data_list : list of (np.ndarray, np.ndarray)
        Each entry is a (rate, count) pair of equal-length 1-D arrays:
        the per-position compression rate and the number of samples that
        contributed at each position.

    Returns
    -------
    (np.ndarray, np.ndarray) or (None, None)
        The count-weighted average rate per position and the total counts,
        both of length ``max(len(c) for _, c in data_list)``.  Positions
        covered by no run have a zero total count and yield NaN in the
        average (division is deliberately left unguarded under errstate).

    Raises
    ------
    ValueError
        If any (rate, count) pair has mismatched lengths.
    """
    if not data_list:
        return None, None
    # Validate pairing and find the longest series; shorter series only
    # contribute to the prefix of positions they cover.
    # (Explicit raise instead of `assert`, which is stripped under -O.)
    max_len = 0
    for rate, count in data_list:
        if len(rate) != len(count):
            raise ValueError("rate and count arrays must have equal length")
        max_len = max(max_len, len(count))
    # Count-weighted average: sum(rate * count) / sum(count) per position.
    numerator = np.zeros(max_len)
    denominator = np.zeros(max_len)
    for rate, count in data_list:
        l = len(rate)
        numerator[:l] += rate * count
        denominator[:l] += count
    # Positions never covered divide 0/0 -> NaN; silence the warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        avg_rate = numerator / denominator
    return avg_rate, denominator
def process_files(file_paths):
    """Load evaluation JSON result files and group arrays by model name.

    Parameters
    ----------
    file_paths : iterable of str
        Paths to JSON result files.  Each file may contain
        ``model_name_or_path``, ``data_path``,
        ``byte_wise_compression_rate`` and ``byte_wise_counts``.

    Returns
    -------
    (defaultdict, set)
        ``model_raw_data`` maps a cleaned model name (basename with any
        ``.pth`` suffix removed) to a list of (rate, count) ndarray pairs;
        ``dataset_names`` collects the basename of every ``data_path`` seen.
    """
    grouped = defaultdict(list)
    datasets = set()
    for path in file_paths:
        with open(path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        # Fall back to the file's own name when the model name is absent.
        fallback = os.path.basename(path).replace(".json", "")
        raw_name = payload.get("model_name_or_path", fallback)
        model_key = raw_name.split("/")[-1].replace(".pth", "")
        source = payload.get("data_path", "")
        if source:
            datasets.add(source.split("/")[-1])
        rates = payload.get("byte_wise_compression_rate", [])
        counts = payload.get("byte_wise_counts", [])
        # Only keep files that actually carry both series.
        if rates and counts:
            grouped[model_key].append((np.array(rates), np.array(counts)))
    return grouped, datasets
def clean_label_for_path(l):
    """Shorten a file-derived label for display in a plot legend.

    Labels ending in ``.json`` lose the ``UncheatableEval-`` prefix and the
    ``.json`` suffix, then their last 20 characters are dropped (presumably
    a timestamp-like tail — TODO confirm against the filename scheme).
    All other labels pass through unchanged.
    """
    if not l.endswith(".json"):
        return l
    trimmed = l.replace("UncheatableEval-", "").replace(".json", "")
    return trimmed[:-20]
def draw_long_context_plot(
    viz_mode,
    data_source_map,
    baseline_key,
    cutoff_ratio,
    smooth_window,
    start_offset,
    y_range,
    tail_drop_ratio=0.02,
):
    """Plot byte-wise compression rate (or its delta vs a baseline) against byte position.

    Parameters:
        viz_mode: mode string; any value containing "Relative" plots deltas
            against the baseline curve, otherwise absolute curves are drawn.
        data_source_map: mapping of display label -> list of JSON result
            file paths; each label's files are aggregated into one curve.
        baseline_key: label in data_source_map used as the baseline in
            relative mode (required there; unused in absolute mode).
        cutoff_ratio: fraction of the position-0 sample count below which
            the curve is cut off (drops sparsely-sampled tails).
        smooth_window: moving-average window; values > 1 overlay a bold
            smoothed curve on a faint raw curve.
        start_offset: number of leading byte positions to skip.
        y_range: optional (ymin, ymax) pair; either element may be None.
        tail_drop_ratio: fraction of trailing points dropped from every
            aggregated curve (default 0.02).

    Returns:
        A matplotlib Figure, or None when there is nothing to plot (empty
        input, missing baseline in relative mode, or all curves too short).
    """
    if not data_source_map:
        return None
    final_curves = {}  # label -> 1-D array of aggregated rates
    dataset_info = set()  # data_path basenames seen (collected, not plotted)
    for label, file_paths in data_source_map.items():
        print(file_paths)  # NOTE(review): debug output; consider logging instead
        raw_data_list = []
        for fp in file_paths:
            with open(fp, "r", encoding="utf-8") as f:
                d = json.load(f)
            rate = d.get("byte_wise_compression_rate", [])
            counts = d.get("byte_wise_counts", [])
            d_path = d.get("data_path", "")
            if d_path:
                dataset_info.add(d_path.split("/")[-1])
            if rate and counts:
                raw_data_list.append((np.array(rate), np.array(counts)))
        # Count-weighted average across all files sharing this label.
        avg_rate, total_counts = get_aggregated_stats(raw_data_list)
        if avg_rate is not None:
            if len(total_counts) > 0:
                # Cut at the first position whose sample count falls below
                # cutoff_ratio of the count at position 0.
                initial_count = total_counts[0]
                threshold = initial_count * cutoff_ratio
                cut_idx = np.where(total_counts < threshold)[0]
                if len(cut_idx) > 0:
                    avg_rate = avg_rate[: cut_idx[0]]
            # Drop a fixed fraction of the (noisy) tail of the curve.
            if tail_drop_ratio > 0 and len(avg_rate) > 0:
                drop_count = int(len(avg_rate) * tail_drop_ratio)
                if drop_count > 0:
                    avg_rate = avg_rate[:-drop_count]
            final_curves[label] = avg_rate
    plot_items = []
    is_relative = "Relative" in viz_mode
    if is_relative:
        # Relative mode: each curve becomes (curve - baseline) over their
        # overlapping prefix; the baseline itself is a dashed zero line.
        if baseline_key not in final_curves:
            return None
        baseline_curve = final_curves[baseline_key]
        for label, curve in final_curves.items():
            if label == baseline_key:
                continue
            min_len = min(len(baseline_curve), len(curve))
            if min_len <= start_offset:
                continue  # nothing left to show past the offset
            delta = curve[:min_len] - baseline_curve[:min_len]
            plot_items.append({"label": clean_label_for_path(label), "values": delta[start_offset:], "start_x": start_offset + 1})
    else:  # Absolute
        for label, curve in final_curves.items():
            if len(curve) <= start_offset:
                continue
            plot_items.append({"label": clean_label_for_path(label), "values": curve[start_offset:], "start_x": start_offset + 1})
    if not plot_items:
        return None
    fig, ax = plt.subplots(figsize=(12, 6), dpi=300, layout="constrained")
    if is_relative:
        ax.axhline(0, color="black", linestyle="--", linewidth=1, alpha=0.5, label=f"Baseline: {clean_label_for_path(baseline_key)}")
    colors = plt.cm.tab10(np.linspace(0, 1, max(1, len(plot_items))))
    for idx, item in enumerate(plot_items):
        y_val = item["values"]
        # X axis is 1-based: first plotted point sits at start_offset + 1.
        x_val = np.arange(item["start_x"], item["start_x"] + len(y_val))
        color = colors[idx % 10]  # tab10 has 10 colors; wrap around beyond that
        if smooth_window > 1 and len(y_val) > smooth_window:
            # Moving-average smoothing: faint raw curve under a bold smoothed one.
            kernel = np.ones(smooth_window) / smooth_window
            y_smooth = np.convolve(y_val, kernel, mode="valid")
            x_smooth = x_val[smooth_window - 1 :]
            ax.plot(x_val, y_val, color=color, alpha=0.15, linewidth=0.5)
            ax.plot(x_smooth, y_smooth, color=color, linewidth=2, label=item["label"])
        else:
            ax.plot(x_val, y_val, color=color, linewidth=1.5, label=item["label"])
    ax.set_xscale("log")
    ax.xaxis.set_major_formatter(ScalarFormatter())  # plain numbers (not 10^k) on the log axis
    if y_range:
        # Apply whichever limits were supplied, keeping the auto limit on
        # the unspecified side.
        ymin, ymax = y_range
        if ymin is not None and ymax is not None:
            ax.set_ylim(ymin, ymax)
        elif ymin is not None:
            current_ymax = ax.get_ylim()[1]
            ax.set_ylim(bottom=ymin, top=current_ymax)
        elif ymax is not None:
            current_ymin = ax.get_ylim()[0]
            ax.set_ylim(bottom=current_ymin, top=ymax)
    mode_title = "Compression Rate vs Byte Position" if not is_relative else "Compression Rate Delta vs Byte Position"
    ax.set_title(f"{mode_title}", fontsize=12, pad=8)
    # Legend is placed outside the axes, to the right.
    ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), ncol=1, fontsize="small", frameon=False, borderaxespad=0.0)
    ax.set_xlabel("Byte Position")
    ax.set_ylabel("Compression Rate (%)" if not is_relative else "Compression Rate Delta (%)")
    ax.grid(True, which="major", ls="-", alpha=0.8)
    ax.grid(True, which="minor", ls="--", alpha=0.3)
    return fig