File size: 3,495 Bytes
72c0672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import json
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd  # <--- 必须引入 pandas

def _merge_partial_results(output_dir):
    """Merge every partial_result_*.json file in *output_dir*.

    Returns a tuple ``(merged, files_found)`` where *merged* maps each
    algorithm name to the concatenated list of its values and *files_found*
    is the number of partial files encountered (including unreadable ones).
    """
    merged = {"Gzip": [], "Tokenizer": [], "AC_M1": []}
    files_found = 0
    for filename in os.listdir(output_dir):
        if filename.startswith("partial_result_") and filename.endswith(".json"):
            files_found += 1
            file_path = os.path.join(output_dir, filename)
            try:
                with open(file_path, 'r') as f:
                    data = json.load(f)
                    for k in merged:
                        if k in data:
                            merged[k].extend(data[k])
            except Exception as e:
                # BUG FIX: the original printed the literal "(unknown)"
                # instead of interpolating the offending filename.
                print(f"⚠️ Error reading {filename}: {e}")
    return merged, files_found


def _prepare_plot_data(final_results):
    """Build the plotting records and the per-algorithm statistics.

    NOTE: statistics (mean/median/count) are computed over the RAW values,
    while the plot records are filtered to v < 2.0 to drop rare outliers —
    this asymmetry is intentional (stats stay honest, plot stays readable).
    """
    plot_records = []
    stats_summary = {}

    for algo, vals in final_results.items():
        if not vals:
            continue

        # Filter outliers (values > 2.0 are typically a handful of extreme points).
        cleaned = [v for v in vals if v < 2.0]

        # Record statistics on the unfiltered data.
        stats_summary[algo] = {
            "mean": float(np.mean(vals)),
            "median": float(np.median(vals)),
            "count": len(vals)
        }

        # Rows for the DataFrame used by seaborn.
        for v in cleaned:
            plot_records.append({"Algorithm": algo, "Normalized Edit Distance": v})

    return plot_records, stats_summary


def _plot_distribution(df, output_dir):
    """Render the per-algorithm KDE of normalized edit distances and save it as PNG."""
    plt.figure(figsize=(12, 7))
    sns.set_style("whitegrid")

    # common_norm=False normalizes each algorithm's density independently,
    # so groups with more samples do not dominate the plot.
    sns.kdeplot(
        data=df,
        x="Normalized Edit Distance",
        hue="Algorithm",
        fill=True,
        common_norm=False,
        palette="tab10",
        alpha=0.5,
        linewidth=2
    )

    plt.title("Compression Stability Analysis (Impact of 10% Perturbation)")
    plt.xlabel("Normalized Levenshtein Distance (Lower = More Stable)")
    plt.ylabel("Density")
    plt.xlim(0, 1.2)  # Focus on the 0–1.2 range.

    output_img = os.path.join(output_dir, "stability_parallel_fixed.png")
    plt.savefig(output_img, dpi=300)
    plt.close()  # Release the figure to avoid leaking memory in batch runs.
    print(f"🖼️ Plot saved to: {output_img}")


def _save_stats(stats_summary, output_dir):
    """Write the statistics summary to JSON and echo a short report to stdout."""
    stats_file = os.path.join(output_dir, "final_stats_summary.json")
    with open(stats_file, 'w') as f:
        json.dump(stats_summary, f, indent=2)
    print(f"📄 Stats saved to: {stats_file}")

    print("\n=== Summary Stats ===")
    for algo, stat in stats_summary.items():
        print(f"{algo}: Mean={stat['mean']:.4f}, Count={stat['count']}")


def main():
    """Merge partial analysis results, plot their distribution, and save summary stats.

    Reads ``partial_result_*.json`` files from ``--output_dir``, concatenates
    the per-algorithm value lists, draws a KDE plot of normalized edit
    distances, and writes both the plot and a JSON stats summary back into
    the same directory.
    """
    parser = argparse.ArgumentParser()
    # Default path matches the output directory of the parallel analysis run.
    parser.add_argument("--output_dir", type=str, default="analysis_output_parallel", help="Directory with partial .json results")
    args = parser.parse_args()

    print(f"📂 Reading results from: {args.output_dir}")

    if not os.path.exists(args.output_dir):
        print(f"❌ Error: Directory {args.output_dir} does not exist.")
        return

    # 1. Merge partial results.
    final_results, files_found = _merge_partial_results(args.output_dir)
    print(f"✅ Merged data from {files_found} files.")

    # 2. Prepare plotting data and statistics.
    plot_records, stats_summary = _prepare_plot_data(final_results)

    if not plot_records:
        print("❌ No valid data collected to plot.")
        return

    # Convert to a pandas DataFrame for seaborn's long-form API.
    df = pd.DataFrame(plot_records)
    print(f"📊 Plotting {len(df)} data points...")

    # 3. Plot.
    _plot_distribution(df, args.output_dir)

    # 4. Save and print statistics.
    _save_stats(stats_summary, args.output_dir)

# Entry point guard: run the analysis only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    main()