File size: 6,335 Bytes
c807fbd
 
 
 
 
 
 
 
 
 
 
 
7a8bfa8
c807fbd
 
7a8bfa8
 
 
c807fbd
7a8bfa8
c807fbd
 
7a8bfa8
c807fbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a607ab6
c807fbd
 
 
 
 
 
 
 
7a8bfa8
c807fbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a607ab6
 
 
 
 
 
c807fbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a607ab6
c807fbd
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
from collections import defaultdict
import numpy as np
import json
import os


def get_aggregated_stats(data_list):
    """Aggregate per-byte-position compression rates across multiple runs.

    Args:
        data_list: list of ``(rate, count)`` pairs of equal-length 1-D numpy
            arrays, where ``rate[i]`` is the compression rate at byte
            position ``i`` and ``count[i]`` is how many samples contributed
            to it. Pairs may have different lengths from each other.

    Returns:
        ``(avg_rate, total_counts)``: the count-weighted mean rate per
        position and the summed counts, both of length ``max`` over the
        inputs. Positions with zero total count come out as NaN (0/0).
        Returns ``(None, None)`` for empty input.

    Raises:
        ValueError: if any pair has mismatched rate/count lengths.
    """
    if not data_list:
        return None, None

    # The longest series determines the output length; shorter runs simply
    # stop contributing past their own end.
    max_len = 0
    for rate, count in data_list:
        # Explicit check instead of `assert`: assertions are stripped
        # under `python -O`, silently hiding malformed input.
        if len(rate) != len(count):
            raise ValueError("rate and count arrays must have equal length")
        max_len = max(max_len, len(count))

    # Count-weighted average per byte position.
    numerator = np.zeros(max_len)
    denominator = np.zeros(max_len)
    for rate, count in data_list:
        l = len(rate)
        numerator[:l] += rate * count
        denominator[:l] += count
    # Positions never observed divide 0/0 -> NaN; suppress the warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        avg_rate = numerator / denominator

    return avg_rate, denominator


def process_files(file_paths):
    """Load byte-wise compression statistics from a set of result JSON files.

    Each file is expected to carry ``byte_wise_compression_rate`` and
    ``byte_wise_counts`` lists; files missing either are skipped.

    Args:
        file_paths: iterable of paths to result JSON files.

    Returns:
        ``(model_raw_data, dataset_names)``: a ``defaultdict(list)`` mapping
        model name to a list of ``(rate_array, counts_array)`` numpy pairs,
        and a set of dataset basenames seen across the files.
    """
    grouped = defaultdict(list)
    datasets = set()

    for path in file_paths:
        with open(path, "r", encoding="utf-8") as fh:
            payload = json.load(fh)

        # Fall back to the file's own basename when the model path is absent.
        fallback = os.path.basename(path).replace(".json", "")
        model = payload.get("model_name_or_path", fallback).split("/")[-1].replace(".pth", "")

        source = payload.get("data_path", "")
        if source:
            datasets.add(source.split("/")[-1])

        rates = payload.get("byte_wise_compression_rate", [])
        counts = payload.get("byte_wise_counts", [])
        if rates and counts:
            grouped[model].append((np.array(rates), np.array(counts)))

    return grouped, datasets


def clean_label_for_path(l):
    """Shorten a result-file label for display in the plot legend.

    Labels ending in ``.json`` get the ``UncheatableEval-`` marker and the
    ``.json`` suffix removed, then their last 20 characters dropped
    (presumably a timestamp suffix — TODO confirm against the file naming).
    Any other label is returned unchanged.
    """
    if not l.endswith(".json"):
        return l
    stripped = l.replace("UncheatableEval-", "").replace(".json", "")
    return stripped[:-20]


def _build_curve(file_paths, cutoff_ratio, tail_drop_ratio):
    """Load, aggregate, and truncate one label's curve; None if no usable data."""
    raw_data_list = []
    for fp in file_paths:
        with open(fp, "r", encoding="utf-8") as f:
            d = json.load(f)
        rate = d.get("byte_wise_compression_rate", [])
        counts = d.get("byte_wise_counts", [])
        if rate and counts:
            raw_data_list.append((np.array(rate), np.array(counts)))

    avg_rate, total_counts = get_aggregated_stats(raw_data_list)
    if avg_rate is None:
        return None

    # Truncate at the first position whose sample count drops below
    # cutoff_ratio of the count at position 0 — the tail is too sparse.
    if len(total_counts) > 0:
        threshold = total_counts[0] * cutoff_ratio
        cut_idx = np.where(total_counts < threshold)[0]
        if len(cut_idx) > 0:
            avg_rate = avg_rate[: cut_idx[0]]

    # Additionally drop a trailing fraction of points (noisy tail).
    if tail_drop_ratio > 0 and len(avg_rate) > 0:
        drop_count = int(len(avg_rate) * tail_drop_ratio)
        if drop_count > 0:
            avg_rate = avg_rate[:-drop_count]

    return avg_rate


def _build_plot_items(final_curves, is_relative, baseline_key, start_offset):
    """Turn curves into plot items; None if the baseline is missing in relative mode."""
    plot_items = []

    if is_relative:
        if baseline_key not in final_curves:
            return None
        baseline_curve = final_curves[baseline_key]
        for label, curve in final_curves.items():
            if label == baseline_key:
                continue
            # Deltas only make sense over the overlap of both curves.
            min_len = min(len(baseline_curve), len(curve))
            if min_len <= start_offset:
                continue
            delta = curve[:min_len] - baseline_curve[:min_len]
            plot_items.append({"label": clean_label_for_path(label), "values": delta[start_offset:], "start_x": start_offset + 1})
    else:  # Absolute
        for label, curve in final_curves.items():
            if len(curve) <= start_offset:
                continue
            plot_items.append({"label": clean_label_for_path(label), "values": curve[start_offset:], "start_x": start_offset + 1})

    return plot_items


def _render_figure(plot_items, is_relative, baseline_key, smooth_window, y_range):
    """Draw the prepared curves onto a new constrained-layout figure."""
    fig, ax = plt.subplots(figsize=(12, 6), dpi=300, layout="constrained")

    if is_relative:
        ax.axhline(0, color="black", linestyle="--", linewidth=1, alpha=0.5, label=f"Baseline: {clean_label_for_path(baseline_key)}")

    colors = plt.cm.tab10(np.linspace(0, 1, max(1, len(plot_items))))

    for idx, item in enumerate(plot_items):
        y_val = item["values"]
        x_val = np.arange(item["start_x"], item["start_x"] + len(y_val))
        color = colors[idx % 10]

        if smooth_window > 1 and len(y_val) > smooth_window:
            # Faint raw curve underneath its moving average.
            kernel = np.ones(smooth_window) / smooth_window
            y_smooth = np.convolve(y_val, kernel, mode="valid")
            x_smooth = x_val[smooth_window - 1 :]
            ax.plot(x_val, y_val, color=color, alpha=0.15, linewidth=0.5)
            ax.plot(x_smooth, y_smooth, color=color, linewidth=2, label=item["label"])
        else:
            ax.plot(x_val, y_val, color=color, linewidth=1.5, label=item["label"])

    ax.set_xscale("log")
    ax.xaxis.set_major_formatter(ScalarFormatter())

    if y_range:
        ymin, ymax = y_range
        if ymin is not None and ymax is not None:
            ax.set_ylim(ymin, ymax)
        elif ymin is not None:
            ax.set_ylim(bottom=ymin, top=ax.get_ylim()[1])
        elif ymax is not None:
            ax.set_ylim(bottom=ax.get_ylim()[0], top=ymax)

    mode_title = "Compression Rate Delta vs Byte Position" if is_relative else "Compression Rate vs Byte Position"
    ax.set_title(f"{mode_title}", fontsize=12, pad=8)

    ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), ncol=1, fontsize="small", frameon=False, borderaxespad=0.0)
    ax.set_xlabel("Byte Position")
    ax.set_ylabel("Compression Rate Delta (%)" if is_relative else "Compression Rate (%)")
    ax.grid(True, which="major", ls="-", alpha=0.8)
    ax.grid(True, which="minor", ls="--", alpha=0.3)

    return fig


def draw_long_context_plot(
    viz_mode,
    data_source_map,
    baseline_key,
    cutoff_ratio,
    smooth_window,
    start_offset,
    y_range,
    tail_drop_ratio=0.02,
):
    """Plot byte-position compression-rate curves, absolute or baseline-relative.

    Args:
        viz_mode: mode string; a mode containing ``"Relative"`` plots deltas
            against the baseline curve, anything else plots absolute rates.
        data_source_map: mapping of curve label -> list of result JSON paths.
        baseline_key: label whose curve is the baseline in relative mode.
        cutoff_ratio: truncate each curve at the first position whose total
            count falls below ``cutoff_ratio`` times the count at position 0.
        smooth_window: moving-average window size; values > 1 enable smoothing.
        start_offset: number of initial byte positions to skip when plotting.
        y_range: optional ``(ymin, ymax)``; either bound may be None.
        tail_drop_ratio: fraction of trailing (noisy) points to drop per curve.

    Returns:
        The matplotlib Figure, or None when there is nothing to plot (empty
        input, baseline missing in relative mode, or no usable curves).
    """
    if not data_source_map:
        return None

    # Build one aggregated curve per label. (The debug print and the unused
    # dataset_info set from the earlier version are gone.)
    final_curves = {}
    for label, file_paths in data_source_map.items():
        curve = _build_curve(file_paths, cutoff_ratio, tail_drop_ratio)
        if curve is not None:
            final_curves[label] = curve

    is_relative = "Relative" in viz_mode
    plot_items = _build_plot_items(final_curves, is_relative, baseline_key, start_offset)
    if not plot_items:  # covers both None (missing baseline) and empty list
        return None

    return _render_figure(plot_items, is_relative, baseline_key, smooth_window, y_range)