interactSpeech / .ipynb_checkpoints /compare_scores-checkpoint.py
Student0809's picture
Add files using upload-large-folder tool
ee3af03 verified
import json
import re
from collections import defaultdict
# JSONL file that main() reads for model responses (one JSON object per line).
infer_result_path = "/root/autodl-tmp/output_7B_GRPO/v28-20250722-002940/checkpoint-870/infer_result/53_HH.jsonl"
# JSONL file that main() reads for the ground-truth `solution` of each sample.
test_path = "/root/autodl-tmp/ms-swift/all_audio_test_50.jsonl"
# Destination of the per-sample comparison written by main(), relative to the CWD.
output_path = "inference_comparison_result.json"
# Compiled once at import time. Whitespace around the digits is tolerated so
# that "<overall score> 2 </overall score>" parses the same as the tight form.
_OVERALL_SCORE_RE = re.compile(r'<overall score>\s*(\d+)\s*</overall score>')


def extract_overall_score(response_text):
    """Return the integer inside the first ``<overall score>`` tag, or None.

    Args:
        response_text: Raw model response to search. Anything that is not a
            ``str`` (for example ``None``) is treated as having no score.

    Returns:
        The score as an ``int``, or ``None`` when no
        ``<overall score>DIGITS</overall score>`` tag is present.
    """
    # Guard first: re.search would raise TypeError on None or other non-str.
    if not isinstance(response_text, str):
        return None
    match = _OVERALL_SCORE_RE.search(response_text)
    return int(match.group(1)) if match else None
def _iter_jsonl(path):
    """Yield one decoded JSON object per non-blank line of the file at *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            # A trailing or stray blank line would make json.loads raise.
            if line.strip():
                yield json.loads(line)


def main():
    """Compare predicted overall scores against ground-truth solutions.

    Reads the inference results from ``infer_result_path`` and the labelled
    test set from ``test_path``, joins them on the tuple of ``audios``
    entries, prints overall and per-class accuracy plus every mispredicted
    sample whose ground truth is 1, and writes the full per-sample
    comparison to ``output_path`` as JSON.

    A test sample with no inference entry, or whose response contains no
    parseable score, is compared using a predicted score of ``None``.

    Raises:
        OSError: If an input file cannot be opened or the output cannot be
            written.
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
        KeyError: If an inference record lacks ``response`` or a test record
            lacks ``solution``.
    """
    # Read the inference results: audios tuple -> predicted score + raw text.
    infer_audio2score = {}
    for data in _iter_jsonl(infer_result_path):
        score = extract_overall_score(data['response'])
        # Lists are unhashable, so the audios list is frozen into a tuple key.
        audios = tuple(data.get('audios', []))
        infer_audio2score[audios] = {
            'score': score,
            'raw_response': data['response'],
        }

    # Read the test set: audios tuple -> ground-truth solution.
    test_audio2solution = {}
    for data in _iter_jsonl(test_path):
        solution = data['solution']
        audios = tuple(data.get('audios', []))
        test_audio2solution[audios] = solution

    # Tally results, collect error samples, and keep every comparison.
    stats_per_class = defaultdict(lambda: {'correct': 0, 'incorrect': 0})
    incorrect_samples_solution1 = []
    all_results = []
    total = 0
    correct = 0
    for audios, solution in test_audio2solution.items():
        infer_entry = infer_audio2score.get(audios)
        infer_score = infer_entry['score'] if infer_entry else None
        raw_response = infer_entry['raw_response'] if infer_entry else None
        is_match = infer_score == solution

        # Record every sample, matched or not.
        all_results.append({
            'audios': audios,
            'gt_solution': solution,
            'predicted_score': infer_score,
            'match': is_match,
            'response': raw_response,
        })

        if is_match:
            correct += 1
            stats_per_class[solution]['correct'] += 1
        else:
            stats_per_class[solution]['incorrect'] += 1
            if solution == 1:
                incorrect_samples_solution1.append({
                    'audios': audios,
                    'gt_solution': solution,
                    'predicted_score': infer_score,
                    'response': raw_response,
                })
        total += 1

    # Overall accuracy; an empty test set reports 0% instead of dividing by zero.
    overall_accuracy = correct / total if total > 0 else 0.0
    print(f'\nOverall Accuracy: {correct}/{total} = {overall_accuracy:.2%}\n')

    # Per-class accuracy.
    print("Per-Class Accuracy:")
    for solution, stats in sorted(stats_per_class.items()):
        total_class = stats['correct'] + stats['incorrect']
        accuracy = stats['correct'] / total_class if total_class > 0 else 0.0
        print(f'Class {solution}: Correct={stats["correct"]}, Incorrect={stats["incorrect"]}, Accuracy={accuracy:.2%}')

    # List the samples with solution == 1 that were predicted incorrectly.
    print("\nIncorrect Samples for solution = 1:")
    for sample in incorrect_samples_solution1:
        print(json.dumps(sample, indent=2, ensure_ascii=False))

    # Write every comparison result to the JSON output file.
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)
    print(f"\nAll inference comparison results saved to: {output_path}")
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()