File size: 5,420 Bytes
c4ee4aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158

# Accuracy: exact-match accuracy over integer answers (active metric).
import json
import re

def compute_accuracy_from_file(filepath, *, verbose=False):
    """Compute exact-match accuracy over integer answers in a JSONL file.

    Every non-blank line is parsed as a JSON object. The expected value is
    taken from its ``reference_answer`` field and the predicted value from
    the first ``<ANSWER>...</ANSWER>`` tag (matched case-insensitively) in
    its ``generated_answer`` field. Both values are compared as integers.

    A record is skipped (counted neither as correct nor as incorrect) when
    the reference is missing, when ``generated_answer`` is not a string or
    contains no answer tag, or when either value cannot be parsed as an
    integer.

    Args:
        filepath: Path to a UTF-8 encoded JSONL file.
        verbose: When true, print each parsed record, its reference value
            and the regex match object for debugging.

    Returns:
        A ``(accuracy, total)`` tuple, where ``total`` is the number of
        records that were actually scored and ``accuracy`` is the fraction
        of those that matched. Returns ``(0.0, 0)`` if nothing was scored.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    total, correct = 0, 0
    # ``\s*`` around the group lets the answer sit on its own line between
    # the tags; ``.`` itself still does not cross newlines.
    pattern = re.compile(r"<ANSWER>\s*(.*?)\s*</ANSWER>", re.IGNORECASE)

    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            data = json.loads(line)
            reference = data.get("reference_answer")
            generated = data.get("generated_answer")
            # A null / non-string generation cannot contain an answer tag.
            pred_match = (
                pattern.search(generated) if isinstance(generated, str) else None
            )
            if verbose:
                print(data)
                print(reference)
                print(pred_match)
            if reference is None or pred_match is None:
                continue

            try:
                # str() lets numeric JSON references (e.g. 0 or 1) be scored
                # the same way as their string forms.
                tgt_val = int(str(reference).strip())
                pred_val = int(pred_match.group(1).strip())
            except ValueError:
                continue  # skip records whose values are not integers

            total += 1
            if pred_val == tgt_val:
                correct += 1

    if total == 0:
        return 0.0, 0
    return correct / total, total

if __name__ == "__main__":
    # Local import: only the script entry point needs sys.
    import sys

    # An optional first command-line argument overrides the default results
    # file, so other JSONL files can be scored without editing this script.
    filepath = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/nas/shared/kilab/wangyujia/BIO/ablation/temperature_stability.jsonl"
    )
    acc, count = compute_accuracy_from_file(filepath)
    print(f"Checked {count} items. Accuracy: {acc*100:.3f}%")


# Spearman correlation (alternative metric, currently disabled; uncomment this section to use it).
# import json
# import re
# from scipy.stats import spearmanr

# # 文件路径(替换为你的 JSONL 文件路径)
# file_path = '/nas/shared/kilab/wangyujia/BIO/ablation/temperature_stability.jsonl'

# # 正则:提取 <answer>...</answer> 中的数值
# pattern = re.compile(r"<answer>(.*?)</answer>")

# y_true = []  # reference_answer 中的真实值
# y_pred = []  # generated_answer 预测值

# with open(file_path, 'r', encoding='utf-8') as f:
#     for line in f:
#         line = line.strip()
#         print(line)
#         if not line:
#             continue
#         try:
#             data = json.loads(line)  # 每行都是一个 JSON 对象
#             reference_str = data.get("reference_answer", "").strip()
#             generated_str = data.get("generated_answer", "").strip()
#             print(reference_str)
#             print(generated_str)
#             # 提取 generated_answer 中的数值
#             pred_match = pattern.search(generated_str)
#             print(pred_match.group(1))
#             if pred_match:
#                 pred_value = float(pred_match.group(1))  # 提取 <answer> 中的值
#                 true_value = float(reference_str)        # reference_answer 本身就是数值字符串
#                 y_true.append(true_value)
#                 y_pred.append(pred_value)
#             else:
#                 print(f"未找到 <answer> 标签,跳过:{generated_str}")
#         except Exception as e:
#             print(f"处理行时出错:{line}")
#             print(f"错误:{e}")
#             continue

# # 计算 Spearman 相关系数
# if len(y_true) > 1:
#     print(len(y_true))
#     print(len(y_pred))
#     rho, p_value = spearmanr(y_pred,y_true)
#     print(f"有效样本数:{len(y_true)}")
#     print(f"Spearman correlation coefficient: {rho:.5f}, p-value: {p_value:.4e}")
# else:
#     print("有效数据不足,无法计算 Spearman 相关系数。")


# F1 score (alternative metric, currently disabled; uncomment this section to use it).
# import re
# import json

# def extract_numbers_from_generated_answer(s):
#     """从<answer>...</answer>中提取纯数字"""
#     match = re.search(r'<answer>(.*?)</answer>', s)
#     if match:
#         numbers_str = match.group(1)
#         numbers = re.findall(r'\d+', numbers_str)
#         return set(map(int, numbers))
#     else:
#         return set()

# def extract_numbers_from_reference_answer(s):
#     """从参考答案字符串中提取数字,忽略非数字字符"""
#     numbers = re.findall(r'\d+', s)
#     return set(map(int, numbers))

# def calculate_f1(pred_set, target_set):
#     if not pred_set and not target_set:
#         return 1.0
#     if not pred_set or not target_set:
#         return 0.0
#     tp = len(pred_set & target_set)
#     fp = len(pred_set - target_set)
#     fn = len(target_set - pred_set)
#     precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
#     recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
#     if precision + recall == 0:
#         return 0.0
#     return 2 * precision * recall / (precision + recall)

# # 文件路径
# file_path = '/nas/shared/kilab/wangyujia/BIO/ablation/enzyme_commission_number.jsonl'

# all_f1_scores = []

# with open(file_path, 'r', encoding='utf-8') as f:
#     for line in f:
#         try:
#             data = json.loads(line)
#             reference = data.get('reference_answer', '')
#             generated = data.get('generated_answer', '')
#             target_set = extract_numbers_from_reference_answer(reference)
#             pred_set = extract_numbers_from_generated_answer(generated)
#             f1 = calculate_f1(pred_set, target_set)
#             all_f1_scores.append(f1)
#         except Exception as e:
#             print(f"处理行时出错:{line.strip()}")
#             print(f"错误:{e}")

# # 输出结果
# if all_f1_scores:
#     avg_f1 = sum(all_f1_scores) / len(all_f1_scores)
#     print(f"总样本数:{len(all_f1_scores)}")
#     print(f"平均 F1 分数:{avg_f1:.5f}")
# else:
#     print("没有有效的F1分数可供计算")