File size: 3,746 Bytes
4e2a0b2
 
 
 
 
 
 
 
 
 
4a6d61a
4e2a0b2
 
 
 
 
 
4a6d61a
4e2a0b2
4a6d61a
4e2a0b2
 
 
 
 
 
 
 
 
 
4a6d61a
4e2a0b2
 
 
 
4a6d61a
4e2a0b2
 
 
 
 
 
4a6d61a
4e2a0b2
 
 
 
 
 
 
 
 
 
 
 
 
 
4a6d61a
4e2a0b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a6d61a
4e2a0b2
 
 
4a6d61a
4e2a0b2
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import torch
import pandas as pd
import librosa
import numpy as np
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification, AutoConfig
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm import tqdm

def evaluate_model(model_path="models/wav2vec2-finetuned", metadata_path="data/processed/metadata.csv"):
    """Evaluate a fine-tuned Wav2Vec2 emotion classifier on the 'test' split.

    Loads the config, feature extractor and model weights from *model_path*,
    samples up to 100 rows marked split == 'test' in the metadata CSV, runs
    CPU inference on each audio clip, and prints overall accuracy, a per-class
    classification report, and the confusion matrix.

    Args:
        model_path: Directory holding the saved HF config/extractor/weights.
        metadata_path: CSV with at least 'split', 'emotion', 'path' and
            'filename' columns.

    Returns:
        None. All results are printed; the function returns early (after a
        printed message) on any loading failure.
    """
    print(f"Evaluating model at: {model_path}")

    # 1. Load Config to get the TRUE label mapping the model was trained with.
    try:
        config = AutoConfig.from_pretrained(model_path)
        id2label = config.id2label
        label2id = config.label2id
        print(f"Loaded Label Map from Config: {label2id}")
    except Exception as e:
        print(f"Error loading config: {e}")
        return

    # 2. Load Model & Extractor (CPU-only so the script runs anywhere).
    device = torch.device("cpu")
    try:
        extractor = AutoFeatureExtractor.from_pretrained(model_path)
        model = Wav2Vec2ForSequenceClassification.from_pretrained(model_path)
        model.to(device)
        model.eval()
    except Exception as e:
        print(f"Error loading weights: {e}")
        return

    # 3. Load Metadata
    if not os.path.exists(metadata_path):
        print(f"Metadata not found at {metadata_path}")
        return

    df = pd.read_csv(metadata_path)
    # Filter the test split once, then sample up to 100 rows (the original
    # recomputed the same boolean filter three times in one expression).
    test_pool = df[df['split'] == 'test']
    test_df = test_pool.sample(min(100, len(test_pool)), random_state=42)

    print(f"Testing on {len(test_df)} samples...")

    y_true = []
    y_pred = []

    # 4. Inference Loop
    for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Inference"):
        # Ensure we match the exact label string expected by the model.
        true_label_str = row['emotion']

        # Handle the "suprised" misspelling if the model expects "surprised".
        if "surprised" in label2id and true_label_str == "suprised":
            true_label_str = "surprised"

        if true_label_str not in label2id:
            print(f"Warning: Label '{true_label_str}' not in model config. Skipping.")
            continue

        target_id = label2id[true_label_str]

        # Path handling: fall back to the raw-archive layout when the stored
        # path is stale; skip the sample if neither location exists.
        audio_path = row['path']
        if not os.path.exists(audio_path):
            audio_path = os.path.join("C:/dev/archive/Emotions", row['emotion'].capitalize(), row['filename'])
            if not os.path.exists(audio_path):
                continue

        try:
            # Load and resample to the 16 kHz rate the extractor expects.
            speech, sr = librosa.load(audio_path, sr=16000)
            # Crop to 5s to match training logic.
            if len(speech) > 16000 * 5:
                speech = speech[:16000 * 5]

            inputs = extractor(speech, sampling_rate=16000, return_tensors="pt", padding=True)

            # Predict. Unpack the full extractor output so attention_mask
            # (when the extractor produces one) is forwarded to the model —
            # the original passed only input_values and dropped it.
            with torch.no_grad():
                logits = model(**inputs).logits
                pred_id = torch.argmax(logits, dim=-1).item()

            y_true.append(target_id)
            y_pred.append(pred_id)

        except Exception as e:
            # Fix: the original silently discarded every failure; at least
            # report which file failed and why before moving on.
            print(f"Warning: skipping '{audio_path}': {e}")
            continue

    # 5. Results
    if not y_true:
        print("No files processed.")
        return

    acc = accuracy_score(y_true, y_pred)
    print(f"\nFINAL ACCURACY: {acc:.2%}")

    # Map IDs back to names for the report, in consistent sorted-ID order.
    target_names = [id2label[i] for i in sorted(id2label.keys())]

    print("\nDetailed Report:")
    print(classification_report(y_true, y_pred, target_names=target_names, labels=sorted(id2label.keys())))

    # Print the raw confusion matrix (rows = true label IDs, cols = predicted).
    print("\nConfusion Matrix (True vs Pred):")
    print(confusion_matrix(y_true, y_pred))

if __name__ == "__main__":
    # Run the evaluation with the default model/metadata paths.
    evaluate_model()