"""Streamlit app that classifies an uploaded image into one of the CIFAR-10 classes.

Run with ``streamlit run <this file>``. Expects a ResNet-50 state dict whose
final layer has one output per entry in ``CLASS_NAMES``, stored at
``MODEL_PATH`` relative to the working directory.
"""

import streamlit as st
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50

# CIFAR-10 class labels. The predicted index is used directly into this list,
# so the order must match the label order used in training (this is the
# standard torchvision CIFAR-10 order).
CLASS_NAMES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# Trained weights, saved as a state dict (loaded via load_state_dict below).
MODEL_PATH = "best_model.pth"

# Preprocessing pipeline, built once at import time instead of on every call.
# NOTE(review): the resize target and normalisation constants must match the
# training-time transforms — confirm against the training script.
_TRANSFORM = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


@st.cache_resource
def load_model():
    """Build a ResNet-50 with a CIFAR-10 head and load the trained weights.

    Wrapped in ``st.cache_resource`` so the weights are read once and reused
    across script reruns instead of being reloaded on every interaction.

    Returns:
        The model on CPU, switched to eval mode.
    """
    # weights=None gives a randomly initialised backbone (the trained weights
    # are loaded below); it replaces the deprecated ``pretrained=False``.
    model = resnet50(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, len(CLASS_NAMES))
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state dict needs, and refuses arbitrary pickled objects.
    state_dict = torch.load(
        MODEL_PATH, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


model = load_model()


def preprocess_image(image):
    """Turn a PIL image (expected RGB) into a normalised batch of one.

    Args:
        image: PIL image; the caller converts uploads to RGB first.

    Returns:
        A float tensor of shape (1, C, 32, 32), where C is 3 for RGB input.
    """
    return _TRANSFORM(image).unsqueeze(0)


# Streamlit UI
st.title("CIFAR-10 Image Classifier")
uploaded_file = st.file_uploader(
    "Upload an Image (JPG/PNG)", type=["jpg", "jpeg", "png"]
)

if uploaded_file is not None:
    try:
        # Image.open is lazy; convert() forces the decode, so unreadable or
        # truncated data surfaces here. UnidentifiedImageError subclasses OSError.
        image = Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        st.error("Could not read the uploaded file as an image.")
    else:
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width; kept here for older versions.
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Preprocess the image and make prediction
        st.write("Classifying...")
        input_tensor = preprocess_image(image)
        with torch.no_grad():
            outputs = model(input_tensor)
        predicted = outputs.argmax(dim=1)
        label = CLASS_NAMES[predicted.item()]
        st.write(f"Prediction: **{label}**")