"""
Test script for the Wav2Vec2 audio sentiment analysis model.
"""

import os

import torch
import numpy as np
import librosa
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

def test_audio_model():
    """Test the audio model loading and inference."""
    print("Testing Wav2Vec2 Audio Sentiment Model")
    print("=" * 50)

    model_path = "models/wav2vec2_model.pth"
    if not os.path.exists(model_path):
        print(f"Audio model file not found at: {model_path}")
        return False

    print(f"Found model file: {model_path}")

    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

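        # The .pth file is expected to hold a plain state_dict (parameter name
        # -> tensor), not a full training checkpoint, since keys such as
        # "classifier.weight" are read from it directly below.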
        checkpoint = torch.load(model_path, map_location=device)
        print(f"Checkpoint keys: {list(checkpoint.keys())}")

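        # The classification head weight has shape (num_labels, projection_dim),
        # so its first dimension gives the number of sentiment classes the model
        # was fine-tuned with; fall back to 3 classes if the key is missing.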
        if "classifier.weight" in checkpoint:
            num_classes = checkpoint["classifier.weight"].shape[0]
            print(f"Model has {num_classes} output classes")
        else:
            print("Could not determine number of classes from checkpoint")
            num_classes = 3

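        # Build a fresh facebook/wav2vec2-base classifier with a randomly
        # initialized head of num_classes outputs; the fine-tuned weights are
        # loaded over it in the next step.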
        print("Initializing Wav2Vec2 model...")
        model_checkpoint = "facebook/wav2vec2-base"
        model = AutoModelForAudioClassification.from_pretrained(
            model_checkpoint, num_labels=num_classes
        )

        print("Loading trained weights...")
        model.load_state_dict(checkpoint)
        model.to(device)
        model.eval()

        print("Model loaded successfully!")

        print("Testing inference with dummy audio...")

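        # One second of random noise at 16 kHz stands in for a real recording.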
        dummy_audio = np.random.randn(16000).astype(np.float32)

        feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

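        # max_length=80000 samples at 16 kHz is 5 seconds: shorter clips are
        # zero-padded to that length, longer clips are truncated.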
        inputs = feature_extractor(
            dummy_audio,
            sampling_rate=16000,
            max_length=80000,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        input_values = inputs.input_values.to(device)

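        # Softmax turns the logits into class probabilities; the argmax gives
        # the predicted class, and its probability is reported as confidence.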
        with torch.no_grad():
            outputs = model(input_values)
            probabilities = torch.softmax(outputs.logits, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        print(f"Model output shape: {outputs.logits.shape}")
        print(f"Predicted class: {predicted.item()}")
        print(f"Confidence: {confidence.item():.3f}")
        print(f"All probabilities: {probabilities.squeeze().cpu().numpy()}")

        sentiment_map = {0: "Negative", 1: "Neutral", 2: "Positive"}
        predicted_sentiment = sentiment_map.get(
            predicted.item(), f"Class_{predicted.item()}"
        )
        print(f"Predicted sentiment: {predicted_sentiment}")

        print("Audio model test completed successfully!")
        return True

    except Exception as e:
        print(f"Error testing audio model: {str(e)}")
        import traceback

        traceback.print_exc()
        return False

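
# Optional helper: a minimal sketch of how a real recording could be run
# through the same pipeline. The function name and `audio_path` argument are
# illustrative only; it assumes the model, feature extractor, and device were
# set up exactly as in test_audio_model() above, and uses librosa to load the
# file as 16 kHz mono before feature extraction.
def run_on_file(model, feature_extractor, device, audio_path):
    """Classify a single audio file with the loaded Wav2Vec2 sentiment model."""
    # Resample to the 16 kHz mono input the feature extractor expects.
    audio, _ = librosa.load(audio_path, sr=16000, mono=True)
    inputs = feature_extractor(
        audio,
        sampling_rate=16000,
        max_length=80000,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(inputs.input_values.to(device)).logits
    # Return the index of the highest-probability class (0/1/2 as above).
    return int(torch.argmax(logits, dim=1).item())
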

def check_audio_model_file():
    """Check the audio model file details."""
    print("\nAudio Model File Analysis")
    print("=" * 30)

    model_path = "models/wav2vec2_model.pth"
    if not os.path.exists(model_path):
        print(f"Model file not found: {model_path}")
        return

    file_size = os.path.getsize(model_path) / (1024 * 1024)
    print(f"File size: {file_size:.1f} MB")

    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        checkpoint = torch.load(model_path, map_location=device)

        print(f"Checkpoint keys ({len(checkpoint)} total):")
        for key, value in checkpoint.items():
            if isinstance(value, torch.Tensor):
                print(f" - {key}: {value.shape} ({value.dtype})")
            else:
                print(f" - {key}: {type(value)}")

        if "classifier.weight" in checkpoint:
            num_classes = checkpoint["classifier.weight"].shape[0]
            print(f"\nClassifier output classes: {num_classes}")
            print(
                f"Classifier weight shape: {checkpoint['classifier.weight'].shape}"
            )
            if "classifier.bias" in checkpoint:
                print(
                    f"Classifier bias shape: {checkpoint['classifier.bias'].shape}"
                )

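        # If the full fine-tuned backbone was saved (not just the classification
        # head), the state_dict also contains the convolutional feature-encoder
        # weights; the first conv layer's key is checked as a marker.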
        if "wav2vec2.feature_extractor.conv_layers.0.conv.weight" in checkpoint:
            print("Wav2Vec2 base model: Present")

    except Exception as e:
        print(f"Error analyzing checkpoint: {str(e)}")


if __name__ == "__main__":
    print("Starting Wav2Vec2 Audio Model Tests")
    print("=" * 60)

    check_audio_model_file()

    print("\n" + "=" * 60)

    success = test_audio_model()

    if success:
        print("\nAll audio model tests passed!")
    else:
        print("\nAudio model tests failed!")