|
|
|
|
|
""" |
|
|
Test script for the vision sentiment analysis model. |
|
|
This script verifies that the ResNet-50 model can be loaded and run inference. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import transforms, models |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
def get_sentiment_mapping(num_classes: int) -> dict:
    """Return the class-index -> label mapping for a model with ``num_classes`` outputs.

    Known layouts:

    * 3 classes: Negative / Neutral / Positive
    * 4 classes: Angry / Sad / Happy / Neutral
    * 7 classes: Angry / Disgust / Fear / Happy / Sad / Surprise / Neutral

    Any other count falls back to generic ``Class_<i>`` labels, one per index
    in ``range(num_classes)``.

    Args:
        num_classes: Number of output classes of the model.

    Returns:
        A new dict mapping each class index to its label.
    """
    known_labels = {
        3: ("Negative", "Neutral", "Positive"),
        4: ("Angry", "Sad", "Happy", "Neutral"),
        7: ("Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"),
    }

    labels = known_labels.get(num_classes)
    if labels is None:
        # Unknown layout: generate placeholder names so callers can still index
        # the mapping with any predicted class id.
        return {i: f"Class_{i}" for i in range(num_classes)}

    # Build a fresh dict on every call so callers may mutate the result safely.
    return dict(enumerate(labels))
|
|
|
|
|
|
|
|
def test_vision_model(): |
|
|
"""Test the vision sentiment analysis model""" |
|
|
|
|
|
print("π§ Testing Vision Sentiment Analysis Model") |
|
|
print("=" * 50) |
|
|
|
|
|
|
|
|
model_path = "models/resnet50_model.pth" |
|
|
if not os.path.exists(model_path): |
|
|
print(f"β Model file not found: {model_path}") |
|
|
print("Please ensure the model file exists in the models/ directory") |
|
|
return False |
|
|
|
|
|
print(f"β
Model file found: {model_path}") |
|
|
|
|
|
try: |
|
|
|
|
|
print("π₯ Loading model checkpoint...") |
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
checkpoint = torch.load(model_path, map_location=device) |
|
|
|
|
|
|
|
|
if 'fc.weight' in checkpoint: |
|
|
num_classes = checkpoint['fc.weight'].shape[0] |
|
|
print(f"π Model checkpoint has {num_classes} output classes") |
|
|
else: |
|
|
|
|
|
num_classes = 3 |
|
|
print("β οΈ Could not determine number of classes from checkpoint, assuming 3") |
|
|
|
|
|
|
|
|
print("π§ Initializing ResNet-50 model...") |
|
|
model = models.resnet50(weights=None) |
|
|
num_ftrs = model.fc.in_features |
|
|
model.fc = nn.Linear(num_ftrs, num_classes) |
|
|
|
|
|
print(f"π₯ Loading trained weights for {num_classes} classes...") |
|
|
model.load_state_dict(checkpoint) |
|
|
model.to(device) |
|
|
model.eval() |
|
|
|
|
|
print(f"β
Model loaded successfully with {num_classes} classes!") |
|
|
print(f"π₯οΈ Using device: {device}") |
|
|
|
|
|
|
|
|
print("π§ͺ Testing inference with dummy image...") |
|
|
|
|
|
|
|
|
dummy_image = Image.fromarray( |
|
|
np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) |
|
|
) |
|
|
|
|
|
|
|
|
transform = transforms.Compose( |
|
|
[ |
|
|
transforms.Resize(224), |
|
|
transforms.CenterCrop(224), |
|
|
transforms.ToTensor(), |
|
|
transforms.Normalize( |
|
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] |
|
|
), |
|
|
] |
|
|
) |
|
|
|
|
|
image_tensor = transform(dummy_image).unsqueeze(0).to(device) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = model(image_tensor) |
|
|
print(f"π Model output shape: {outputs.shape}") |
|
|
|
|
|
probabilities = torch.nn.functional.softmax(outputs, dim=1) |
|
|
confidence, predicted = torch.max(probabilities, 1) |
|
|
|
|
|
|
|
|
sentiment_map = get_sentiment_mapping(num_classes) |
|
|
sentiment = sentiment_map[predicted.item()] |
|
|
confidence_score = confidence.item() |
|
|
|
|
|
print(f"π― Test prediction: {sentiment} (confidence: {confidence_score:.3f})") |
|
|
print(f"π Available classes: {list(sentiment_map.values())}") |
|
|
print("β
Inference test passed!") |
|
|
|
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
print(f"β Error testing model: {str(e)}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
return False |
|
|
|
|
|
|
|
|
def main():
    """Run the vision model smoke test and report the overall result.

    Prints a success message when ``test_vision_model()`` returns a truthy
    value; otherwise prints a failure message and exits the process with
    status code 1.
    """
    success = test_vision_model()

    if success:
        print("\n🎉 All tests passed! The vision model is ready to use.")
        print("You can now run the Streamlit app with: streamlit run app.py")
    else:
        print("\n💥 Tests failed. Please check the error messages above.")
        # Non-zero exit status lets shells and CI detect the failure.
        sys.exit(1)
|
|
|
|
|
|
|
|
# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|