test_voice / examples /tts /utils /eval_emo_acc.py
tonyshark's picture
Upload 239 files
34b0b92 verified
import torch
import argparse
import os
import torch.nn.functional as F
import numpy as np
import json
from funasr import AutoModel
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Inference')
parser.add_argument('--gt', type=str, default="../test.jsonl")
parser.add_argument('--pred', type=str)
parser.add_argument('--audio_subdir', type=str, default='pred_audio/default_tone', help='Subdirectory for audio files relative to the parent directory.')
args = parser.parse_args()
pred_dir = os.path.join(args.pred, args.audio_subdir)
output_path = os.path.join(args.pred, "emo1.log")
model = AutoModel(model="iic/emotion2vec_plus_large")
simis = []
correct_predictions = 0
total_predictions = 0
# Define all labels and the selected subset for processing
all_labels = ['angry', 'disgusted', 'fearful', 'happy', 'neutral', 'other', 'sad', 'surprised', 'unk']
selected_labels = ['angry', 'happy', 'neutral', 'other', 'sad']
selected_indices = [all_labels.index(label) for label in selected_labels] #[0, 3, 4, 5, 6]
# Initialize dictionaries to track recall information
recall_stats = {label: {'correct': 0, 'total': 0} for label in selected_labels}
with torch.no_grad():
with open(args.gt, "r") as rf, open(output_path, "w") as f:
for line in rf:
data = json.loads(line.strip())
id =data["key"]
gt_path=data["target_wav"]
pred_path=pred_dir+'/'+id+'.wav'
tgt_emo=data["emotion"]
if tgt_emo not in selected_labels:
tgt_emo = "other"
if not os.path.exists(pred_path):
print(pred_path)
continue
try:
pred_result = model.generate(pred_path, granularity="utterance", extract_embedding=True)
pred_emb = pred_result[0]["feats"]
# Filter scores and labels for selected labels only
pred_scores = pred_result[0]['scores']
pred_scores_filtered = [pred_scores[i] for i in selected_indices]
pred_emo = selected_labels[pred_scores_filtered.index(max(pred_scores_filtered))]
except Exception as e:
print(f"Error processing {pred_path}: {e}")
continue
try:
tgt_result = model.generate(gt_path, granularity="utterance", extract_embedding=True)
tgt_emb = tgt_result[0]["feats"]
except Exception as e:
print(f"Error processing {gt_path}: {e}")
continue
# Update total and correct predictions
total_predictions += 1
if pred_emo == tgt_emo:
correct_predictions += 1
recall_stats[tgt_emo]['correct'] += 1
recall_stats[tgt_emo]['total'] += 1
simi = float(F.cosine_similarity(torch.FloatTensor([pred_emb]), torch.FloatTensor([tgt_emb])).item())
simis.append(simi)
print("%s %s %f"%(pred_path, gt_path, simi), file=f)
print("------------------------------------------", file=f)
print("len:", len(simis),file=f)
print("emo2vec large:", np.mean(simis), file=f)
overall_accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
print("------------------------------------------", file=f)
print(f"Total predictions: {total_predictions}", file=f)
print(f"Correct predictions: {correct_predictions}", file=f)
print(f"Overall Accuracy: {overall_accuracy:.3f}", file=f)
# Calculate recall for each emotion
recalls = []
for label in selected_labels:
recall = (recall_stats[label]['correct'] / recall_stats[label]['total']) if recall_stats[label]['total'] > 0 else 0
recalls.append(recall)
print(f"Recall for {label}: {recall:.3f}", file=f)
# Calculate and print the average recall
average_recall = np.mean(recalls)
print(f"Average Recall: {average_recall:.3f}", file=f)