"""Compare model-predicted overall scores against the ground-truth solutions and report accuracy."""
import json
import re
from collections import defaultdict

# Paths: per-sample inference results (JSONL), the ground-truth test set (JSONL),
# and the JSON comparison report this script writes.
infer_result_path = '/root/autodl-tmp/output_7B_GRPO/v28-20250722-002940/checkpoint-870/infer_result/53_HH.jsonl'
test_path = '/root/autodl-tmp/ms-swift/all_audio_test_50.jsonl'
output_path = 'inference_comparison_result.json'


def extract_overall_score(response_text):
    """Return the integer inside an '<overall score>N</overall score>' tag, or None if the tag is missing."""
    match = re.search(r'<overall score>(\d+)</overall score>', response_text)
    if match:
        return int(match.group(1))
    return None
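# For example, extract_overall_score('... <overall score>4</overall score> ...') returns 4;
# a response without the tag returns None.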


def main():
    # Index inference results by their audio-path tuple so they can be matched
    # against the ground-truth entries below.
    infer_audio2score = {}
    with open(infer_result_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            score = extract_overall_score(data['response'])
            audios = tuple(data.get('audios', []))
            infer_audio2score[audios] = {
                'score': score,
                'raw_response': data['response']
            }

    # Ground-truth scores, keyed by the same audio tuple.
    test_audio2solution = {}
    with open(test_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            solution = data['solution']
            audios = tuple(data.get('audios', []))
            test_audio2solution[audios] = solution
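    # Expected JSONL fields, inferred from the accesses above:
    #   inference line: {"response": "... <overall score>3</overall score> ...", "audios": [...]}
    #   test line:      {"solution": 3, "audios": [...]}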

    # Per-class tallies, the full comparison dump, and the misclassified samples
    # whose ground-truth score is 1.
    stats_per_class = defaultdict(lambda: {'correct': 0, 'incorrect': 0})
    incorrect_samples_solution1 = []
    all_results = []

    total = 0
    correct = 0
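    # Join ground truth with predictions on the audio tuple; test samples without a
    # matching inference entry get predicted_score None and therefore count as incorrect.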

    for audios, solution in test_audio2solution.items():
        infer_entry = infer_audio2score.get(audios, None)
        infer_score = infer_entry['score'] if infer_entry else None
        raw_response = infer_entry['raw_response'] if infer_entry else None
        match = infer_score == solution

        all_results.append({
            'audios': audios,
            'gt_solution': solution,
            'predicted_score': infer_score,
            'match': match,
            'response': raw_response
        })

        if match:
            correct += 1
            stats_per_class[solution]['correct'] += 1
        else:
            stats_per_class[solution]['incorrect'] += 1
            if solution == 1:
                incorrect_samples_solution1.append({
                    'audios': audios,
                    'gt_solution': solution,
                    'predicted_score': infer_score,
                    'response': raw_response
                })

        total += 1

    print(f'\nOverall Accuracy: {correct}/{total} = {correct/total:.2%}\n')

    print("Per-Class Accuracy:")
    for solution, stats in sorted(stats_per_class.items()):
        total_class = stats['correct'] + stats['incorrect']
        accuracy = stats['correct'] / total_class if total_class > 0 else 0.0
        print(f'Class {solution}: Correct={stats["correct"]}, Incorrect={stats["incorrect"]}, Accuracy={accuracy:.2%}')
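    # Print the misclassified samples whose ground-truth score is 1, for manual inspection.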

    print("\nIncorrect Samples for solution = 1:")
    for sample in incorrect_samples_solution1:
        print(json.dumps(sample, indent=2, ensure_ascii=False))

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)
    print(f"\nAll inference comparison results saved to: {output_path}")


if __name__ == '__main__':
    main()