cifar-10-fastapi / utils /model_utils.py
avidaldo's picture
Autocommit
6b2d154
Raw
History Blame Contribute Delete
1.76 kB
import torch
import torch.nn.functional as F
import os
from models.cnn_model import ImprovedCNN
# CIFAR-10 class labels. The position of each label is the model output
# index that predict_image translates into that label.
CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
def load_model(model_path):
    """
    Load a pre-trained PyTorch model.

    Args:
        model_path (str | os.PathLike): Path to a saved ``state_dict`` for
            ``ImprovedCNN``.

    Returns:
        torch.nn.Module: A new ``ImprovedCNN`` populated with the saved
        weights (loaded via the CPU) and switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``model_path`` is not an existing regular file.
    """
    # Validate the path before building the network, so a bad path fails fast
    # without paying for model construction. isfile() also rejects directories,
    # which would otherwise surface later from torch.load as IsADirectoryError.
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    model = ImprovedCNN()

    # Only a state_dict is needed, so weights_only=True loses nothing and
    # prevents unpickling arbitrary Python objects from the file.
    # map_location places every tensor on the CPU regardless of the device
    # it was saved from.
    state_dict = torch.load(
        model_path, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)

    # eval() switches any dropout / batch-norm layers to inference behaviour.
    model.eval()
    return model
def predict_image(model, image_tensor, *, class_names=None):
    """
    Make a prediction on a processed image tensor.

    Args:
        model (torch.nn.Module): The trained model. It is called as
            ``model(x)`` and must return raw class scores (logits) with the
            class axis at dim 1.
        image_tensor (torch.Tensor): A single processed image, either with a
            leading batch dimension of size 1 or as an unbatched 3-D tensor.
            In the unbatched case, a batch dimension is added before the
            model is called.
        class_names (Sequence[str] | None): Labels indexed by the model's
            output index. Defaults to the module-level ``CLASSES``.

    Returns:
        dict: ``{"prediction": label, "confidence": percent}``, where
        ``percent`` is the softmax probability of the top class times 100,
        rounded to two decimals.

    Raises:
        ValueError: If ``image_tensor`` contains more than one image.
    """
    if class_names is None:
        class_names = CLASSES

    # An unbatched 3-D image gets a leading batch dimension of size 1.
    if image_tensor.dim() == 3:
        image_tensor = image_tensor.unsqueeze(0)
    # Only one result is reported, so reject real batches up front instead of
    # failing later with an obscure error from .item().
    if image_tensor.size(0) != 1:
        raise ValueError(
            "predict_image expects a single image, got a batch of "
            f"{image_tensor.size(0)}"
        )

    with torch.no_grad():
        logits = model(image_tensor)
        # Softmax is monotonic, so the most probable class is also the
        # highest-logit class; a single max() yields both the probability
        # and the index.
        probs = F.softmax(logits, dim=1)
        confidence, class_idx = probs.max(dim=1)

    return {
        "prediction": class_names[class_idx.item()],
        "confidence": round(confidence.item() * 100, 2),
    }