#!/usr/bin/env python3
"""
Test script for the Wav2Vec2 audio sentiment analysis model
"""
import os
import torch
import numpy as np
import librosa
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification


def test_audio_model():
    """Test the audio model loading and inference"""
    print("🔊 Testing Wav2Vec2 Audio Sentiment Model")
    print("=" * 50)

    # Check if model file exists
    model_path = "models/wav2vec2_model.pth"
    if not os.path.exists(model_path):
        print(f"❌ Audio model file not found at: {model_path}")
        return False
    print(f"✅ Found model file: {model_path}")

    try:
        # Set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🖥️ Using device: {device}")

        # Load the model checkpoint to check architecture
        checkpoint = torch.load(model_path, map_location=device)
        print(f"📋 Checkpoint keys: {list(checkpoint.keys())}")

        # Check for classifier weights
        if "classifier.weight" in checkpoint:
            num_classes = checkpoint["classifier.weight"].shape[0]
            print(f"📊 Model has {num_classes} output classes")
        else:
            print("⚠️ Could not determine number of classes from checkpoint")
            num_classes = 3  # Default assumption

        # Initialize model
        print("🔄 Initializing Wav2Vec2 model...")
        model_checkpoint = "facebook/wav2vec2-base"
        model = AutoModelForAudioClassification.from_pretrained(
            model_checkpoint, num_labels=num_classes
        )
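        # Note: from_pretrained creates a fresh, randomly initialized
        # classification head sized by num_labels; load_state_dict below
        # overwrites it (and the base encoder) with the trained weights.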

        # Load trained weights
        print("📂 Loading trained weights...")
        model.load_state_dict(checkpoint)
        model.to(device)
        model.eval()
        print("✅ Model loaded successfully!")

        # Test with dummy audio
        print("🧪 Testing inference with dummy audio...")
        # Create dummy audio (1 second of random noise at 16kHz)
        dummy_audio = np.random.randn(16000).astype(np.float32)

        # Load feature extractor
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

        # Preprocess audio
        inputs = feature_extractor(
            dummy_audio,
            sampling_rate=16000,
            max_length=80000,  # 5 seconds * 16000 Hz
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
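        # Padding/truncating every clip to a fixed 5 s window keeps input
        # shapes constant; this presumably mirrors the training-time setup.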

        # Move to device
        input_values = inputs.input_values.to(device)

        # Run inference
        with torch.no_grad():
            outputs = model(input_values)
            probabilities = torch.softmax(outputs.logits, dim=1)
            confidence, predicted = torch.max(probabilities, 1)
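        # softmax converts raw logits into class probabilities; max returns
        # the top probability (used as a confidence score) and its class index.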
print(f"π Model output shape: {outputs.logits.shape}")
print(f"π― Predicted class: {predicted.item()}")
print(f"π Confidence: {confidence.item():.3f}")
print(f"π All probabilities: {probabilities.squeeze().cpu().numpy()}")
# Sentiment mapping
sentiment_map = {0: "Negative", 1: "Neutral", 2: "Positive"}
predicted_sentiment = sentiment_map.get(
predicted.item(), f"Class_{predicted.item()}"
)
print(f"π Predicted sentiment: {predicted_sentiment}")
print("β
Audio model test completed successfully!")
return True
except Exception as e:
print(f"β Error testing audio model: {str(e)}")
import traceback
traceback.print_exc()
return False
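

def test_with_real_audio(audio_path, model, feature_extractor, device):
    """Sketch of the same inference on a real audio file (optional helper).

    `audio_path` is a hypothetical example argument; pass any file librosa
    can read. Assumes the model expects 16 kHz mono input, as above.
    """
    # librosa resamples to 16 kHz and returns mono float32 samples
    audio, _ = librosa.load(audio_path, sr=16000)
    inputs = feature_extractor(
        audio,
        sampling_rate=16000,
        max_length=80000,  # same 5-second cap as the dummy-audio test
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(inputs.input_values.to(device)).logits
    # per-class probabilities, e.g. [negative, neutral, positive]
    return torch.softmax(logits, dim=1)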


def check_audio_model_file():
    """Check the audio model file details"""
    print("\n🔍 Audio Model File Analysis")
    print("=" * 30)

    model_path = "models/wav2vec2_model.pth"
    if not os.path.exists(model_path):
        print(f"❌ Model file not found: {model_path}")
        return

    # File size
    file_size = os.path.getsize(model_path) / (1024 * 1024)  # MB
    print(f"📁 File size: {file_size:.1f} MB")

    try:
        # Load checkpoint
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        checkpoint = torch.load(model_path, map_location=device)

        print(f"📋 Checkpoint keys ({len(checkpoint)} total):")
        for key, value in checkpoint.items():
            if isinstance(value, torch.Tensor):
                print(f"  - {key}: {value.shape} ({value.dtype})")
            else:
                print(f"  - {key}: {type(value)}")

        # Check classifier
        if "classifier.weight" in checkpoint:
            num_classes = checkpoint["classifier.weight"].shape[0]
            print(f"\n🎯 Classifier output classes: {num_classes}")
            print(
                f"📊 Classifier weight shape: {checkpoint['classifier.weight'].shape}"
            )
        if "classifier.bias" in checkpoint:
            print(
                f"📊 Classifier bias shape: {checkpoint['classifier.bias'].shape}"
            )

        # Check wav2vec2 base model
        if "wav2vec2.feature_extractor.conv_layers.0.conv.weight" in checkpoint:
            print("🔍 Wav2Vec2 base model: Present")
    except Exception as e:
        print(f"❌ Error analyzing checkpoint: {str(e)}")


if __name__ == "__main__":
    print("🚀 Starting Wav2Vec2 Audio Model Tests")
    print("=" * 60)

    # Check model file
    check_audio_model_file()

    print("\n" + "=" * 60)

    # Test model loading and inference
    success = test_audio_model()

    if success:
        print("\n🎉 All audio model tests passed!")
    else:
        print("\n💥 Audio model tests failed!")