# Emotion evaluation script (emotion2vec+ large): embedding cosine similarity,
# accuracy, and per-class recall between predicted and ground-truth audio.
# (Removed non-code extraction residue: "Spaces: / Runtime error".)
import torch
import argparse
import os
import torch.nn.functional as F
import numpy as np
import json
from funasr import AutoModel

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Inference')
    parser.add_argument('--gt', type=str, default="../test.jsonl")
    parser.add_argument('--pred', type=str)
    parser.add_argument('--audio_subdir', type=str, default='pred_audio/default_tone',
                        help='Subdirectory for audio files relative to the parent directory.')
    args = parser.parse_args()

    pred_dir = os.path.join(args.pred, args.audio_subdir)
    output_path = os.path.join(args.pred, "emo1.log")

    # emotion2vec+ large: utterance-level emotion classifier that also returns
    # an utterance embedding when extract_embedding=True.
    model = AutoModel(model="iic/emotion2vec_plus_large")

    simis = []                # cosine similarities between pred/GT embeddings
    correct_predictions = 0
    total_predictions = 0

    # Full label set emitted by the model, and the subset we actually score on.
    all_labels = ['angry', 'disgusted', 'fearful', 'happy', 'neutral', 'other', 'sad', 'surprised', 'unk']
    selected_labels = ['angry', 'happy', 'neutral', 'other', 'sad']
    selected_indices = [all_labels.index(label) for label in selected_labels]  # [0, 3, 4, 5, 6]

    # Per-class recall bookkeeping: counted only for pairs where both
    # inferences succeed (see the two try/except blocks below).
    recall_stats = {label: {'correct': 0, 'total': 0} for label in selected_labels}

    with torch.no_grad():
        with open(args.gt, "r") as rf, open(output_path, "w") as f:
            for line in rf:
                data = json.loads(line.strip())
                utt_id = data["key"]  # renamed from `id` to avoid shadowing the builtin
                gt_path = data["target_wav"]
                pred_path = os.path.join(pred_dir, utt_id + '.wav')
                tgt_emo = data["emotion"]
                # Fold any label outside the scored subset into "other".
                if tgt_emo not in selected_labels:
                    tgt_emo = "other"
                if not os.path.exists(pred_path):
                    print(pred_path)
                    continue
                try:
                    pred_result = model.generate(pred_path, granularity="utterance", extract_embedding=True)
                    pred_emb = pred_result[0]["feats"]
                    # Restrict scores to the selected labels before taking argmax.
                    pred_scores = pred_result[0]['scores']
                    pred_scores_filtered = [pred_scores[i] for i in selected_indices]
                    pred_emo = selected_labels[pred_scores_filtered.index(max(pred_scores_filtered))]
                except Exception as e:
                    print(f"Error processing {pred_path}: {e}")
                    continue
                try:
                    tgt_result = model.generate(gt_path, granularity="utterance", extract_embedding=True)
                    tgt_emb = tgt_result[0]["feats"]
                except Exception as e:
                    print(f"Error processing {gt_path}: {e}")
                    continue
                # Both inferences succeeded: update accuracy and recall counts.
                total_predictions += 1
                if pred_emo == tgt_emo:
                    correct_predictions += 1
                    recall_stats[tgt_emo]['correct'] += 1
                recall_stats[tgt_emo]['total'] += 1
                # Wrap 1-D embeddings as (1, D) so cosine_similarity works batch-wise.
                simi = float(F.cosine_similarity(torch.FloatTensor([pred_emb]), torch.FloatTensor([tgt_emb])).item())
                simis.append(simi)
                print("%s %s %f" % (pred_path, gt_path, simi), file=f)

            print("------------------------------------------", file=f)
            print("len:", len(simis), file=f)
            # Guard the empty case: np.mean([]) returns nan with a RuntimeWarning.
            print("emo2vec large:", np.mean(simis) if simis else 0.0, file=f)
            overall_accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
            print("------------------------------------------", file=f)
            print(f"Total predictions: {total_predictions}", file=f)
            print(f"Correct predictions: {correct_predictions}", file=f)
            print(f"Overall Accuracy: {overall_accuracy:.3f}", file=f)
            # Per-class recall, then the unweighted average over selected labels.
            recalls = []
            for label in selected_labels:
                stats = recall_stats[label]
                recall = (stats['correct'] / stats['total']) if stats['total'] > 0 else 0
                recalls.append(recall)
                print(f"Recall for {label}: {recall:.3f}", file=f)
            average_recall = np.mean(recalls)
            print(f"Average Recall: {average_recall:.3f}", file=f)