# my_datasets/simple_cell_acc.py
# Uploaded via huggingface_hub (commit 077b816).
import json
import matplotlib.pyplot as plt
import os
import numpy as np
from collections import defaultdict
def calculate_accuracy(label, predict):
    """Return the fraction of positions where label and predict agree.

    Both strings are normalized first: "[PAD]" tokens are stripped and
    "[EOS]" tokens are replaced with "0" so they count as one regular
    cell. Comparison is character-by-character, zipped over the shorter
    of the two normalized strings but divided by the label's length.

    Args:
        label: Ground-truth string (may contain "[PAD]"/"[EOS]" markers).
        predict: Predicted string using the same marker conventions.

    Returns:
        float in [0, 1]. An empty normalized label yields 1.0 (nothing
        to get wrong); the original code raised ZeroDivisionError here.
    """
    label = label.replace("[PAD]", "").replace("[EOS]", "0")
    predict = predict.replace("[PAD]", "").replace("[EOS]", "0")
    total_chars = len(label)
    if total_chars == 0:
        # FIX: guard against division by zero on an all-padding label.
        return 1.0
    correct_chars = sum(1 for l, p in zip(label, predict) if l == p)
    return correct_chars / total_chars
def evaluate_jsonl(file_path):
    """Average per-sample cell accuracy over a JSONL predictions file.

    Each line must be a JSON object with "label" and "predict" string
    fields. Samples whose raw lengths differ cannot be compared
    position-by-position; they are reported and skipped.

    Args:
        file_path: Path to a generated_predictions.jsonl file.

    Returns:
        float: mean of calculate_accuracy over the usable samples,
        or 0.0 when no sample could be scored.
    """
    total_accuracy = 0.0
    sample_count = 0
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            label = data['label']
            predict = data['predict']
            if len(label) != len(predict):
                print(f"警告: 第{sample_count+1}行длина不一致 (label:{len(label)} vs predict:{len(predict)})") if False else print(f"警告: 第{sample_count+1}行长度不一致 (label:{len(label)} vs predict:{len(predict)})")
                continue
            total_accuracy += calculate_accuracy(label, predict)
            sample_count += 1
    # FIX: the original returned the tuple (0, []) here but a bare float
    # below; callers store the result straight into all_results.json, so
    # the return type must be a single number in both cases.
    if sample_count == 0:
        return 0.0
    return total_accuracy / sample_count
def load_experiment_data(exp_path, dataset_name, T_list, acc_key="predict_acc"):
    """Collect accuracy values for one dataset across several T settings.

    For each t in T_list this reads
    <exp_path>/<dataset_name>_T<t>/all_results.json. When *acc_key* is
    absent, the metric is recomputed from the run's
    generated_predictions.jsonl and cached back into all_results.json.
    Percent strings such as "85%" are converted to floats in [0, 1].

    Args:
        exp_path: Root directory of the experiment.
        dataset_name: Dataset prefix of the per-T result directories.
        T_list: Iterable of T values to look up.
        acc_key: Key to read from all_results.json.

    Returns:
        (results, missing_files): dict mapping t -> accuracy, and the
        list of all_results.json paths that did not exist.
    """
    missing_files = []
    results = {}
    for t in T_list:
        run_dir = os.path.join(exp_path, f"{dataset_name}_T{t}")
        filename = os.path.join(run_dir, "all_results.json")
        if not os.path.exists(filename):
            missing_files.append(filename)
            continue
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if acc_key not in data:
                # FIX: the messages printed a literal "(unknown)" instead
                # of the offending path.
                print(f"警告: {filename} 中未找到键 '{acc_key}'")
                src_filename = os.path.join(run_dir, "generated_predictions.jsonl")
                data[acc_key] = evaluate_jsonl(src_filename)
                # Cache the recomputed metric so future runs skip this step.
                with open(filename, 'w', encoding='utf-8') as f:
                    json.dump(data, f, ensure_ascii=False, indent=4)
            acc = data[acc_key]
            if isinstance(acc, str) and acc.endswith('%'):
                acc = float(acc.strip('%')) / 100.0
            results[t] = acc
        except Exception as e:
            print(f"处理文件 {filename} 时出错: {str(e)}")
    return results, missing_files
# experiments = [
# "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250702-123750",
# "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250703-021344"
# ]
# dataset_names = ["hard_test", "sudoku_test"]
# --- Script configuration -------------------------------------------------
# Experiment output directories to post-process.
experiments = [
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250703-073900",
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250703-075910",
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep300-20250618-082232"
]
# NOTE(review): the same dataset name appears three times, so each
# experiment is processed three times; after the first pass the key
# exists and later passes are no-ops. Confirm distinct datasets were
# not intended here.
dataset_names = ["sudoku_test", "sudoku_test", "sudoku_test"]
acc_key = "cell_acc"

# For every (experiment, dataset) pair, make sure all_results.json
# contains the cell-accuracy metric, recomputing it from the raw
# predictions file and caching it back when it is missing.
for experiment in experiments:
    for dataset_name in dataset_names:
        file_path = os.path.join(experiment, dataset_name)
        filename = os.path.join(file_path, "all_results.json")
        # FIX: skip absent result files instead of crashing with
        # FileNotFoundError (mirrors load_experiment_data's handling).
        if not os.path.exists(filename):
            print(f"警告: {filename} 不存在, 跳过")
            continue
        with open(filename, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if acc_key not in data:
            # FIX: report the actual path rather than "(unknown)".
            print(f"警告: {filename} 中未找到键 '{acc_key}'")
            src_filename = os.path.join(file_path, "generated_predictions.jsonl")
            data[acc_key] = evaluate_jsonl(src_filename)
            with open(filename, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=4)