Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| import os | |
| from models.cnn_model import ImprovedCNN | |
# CIFAR-10 label names. A label's position in this list is the class index
# used to look it up, so the order must stay aligned with the indices the
# model produces (assumed to match the training label order -- confirm).
CLASSES = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]
def load_model(model_path):
    """
    Load a pre-trained ImprovedCNN from a saved state dictionary.

    Args:
        model_path (str): Path to the saved model file

    Returns:
        torch.nn.Module: The loaded model in evaluation mode

    Raises:
        FileNotFoundError: If nothing exists at ``model_path``.
    """
    # Validate the path before building the network so that a bad path fails
    # fast with a clear error instead of after constructing the model.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    # Fresh, untrained instance whose parameters are overwritten below.
    model = ImprovedCNN()

    # map_location remaps every stored tensor onto the CPU, so the file can be
    # loaded regardless of the device it was saved from.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources (consider weights_only=True if the installed torch
    # version supports it).
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)

    # Switch the module to evaluation mode before handing it to callers.
    model.eval()
    return model
def predict_image(model, image_tensor):
    """
    Classify one preprocessed image and report the winning label.

    Args:
        model (torch.nn.Module): The trained model
        image_tensor (torch.Tensor): Processed image tensor; the result is
            read from row 0 of the model output and converted with
            ``.item()``, so a single-image batch is expected.

    Returns:
        dict: ``"prediction"`` holds the class name from ``CLASSES`` and
        ``"confidence"`` holds the softmax probability of that class as a
        percentage rounded to two decimals.
    """
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(image_tensor)
        # Second element of torch.max(..., 1) is the index of the largest
        # score along the class dimension.
        top_index = torch.max(logits, 1)[1]
        # Normalise the raw scores into per-class probabilities.
        probabilities = F.softmax(logits, dim=1)

    label_index = top_index.item()
    score = probabilities[0][label_index].item()

    result = {"prediction": CLASSES[label_index]}
    result["confidence"] = round(score * 100, 2)
    return result