|
|
import streamlit as st |
|
|
from PIL import Image |
|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
from torchvision.models import resnet50 |
|
|
|
|
|
|
|
|
# Human-readable label for each output index of the classifier: the
# prediction code looks up CLASS_NAMES[i] for the arg-max logit i, so this
# order presumably must match the label order used when the checkpoint was
# trained (TODO confirm against the training script).
CLASS_NAMES = "airplane automobile bird cat deer dog frog horse ship truck".split()
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model(checkpoint_path="best_model.pth"):
    """Build a ResNet-50 with a CLASS_NAMES-sized head and load trained weights.

    Decorated with ``st.cache_resource`` so Streamlit builds the model once
    per distinct ``checkpoint_path`` and reuses it across reruns.

    Args:
        checkpoint_path: Path to a ``state_dict`` saved with ``torch.save``.
            Defaults to ``"best_model.pth"``, resolved relative to the
            working directory.

    Returns:
        The model on CPU, switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        RuntimeError: If the checkpoint's keys or tensor shapes do not match
            the architecture built here (``load_state_dict`` is strict).
    """
    # weights=None: random initialisation; strict load_state_dict below
    # overwrites every parameter and buffer. Replaces the deprecated
    # `pretrained=False` spelling.
    model = resnet50(weights=None)
    # Size the classifier head from CLASS_NAMES so the two cannot drift apart.
    model.fc = torch.nn.Linear(model.fc.in_features, len(CLASS_NAMES))
    # weights_only=True limits unpickling to tensors and plain containers, so
    # a tampered checkpoint cannot run arbitrary code while being loaded.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    # Inference mode: fixes batch-norm statistics and disables dropout.
    model.eval()
    return model
|
|
|
|
|
# Fetch the classifier (built once, then served from Streamlit's resource
# cache). A missing checkpoint is shown in the UI and stops this run instead
# of surfacing as a raw traceback.
try:
    model = load_model()
except FileNotFoundError as exc:
    st.error(f"Could not load model weights: {exc}")
    st.stop()
|
|
|
|
|
|
|
|
def preprocess_image(image, size=32):
    """Turn a PIL image into a normalized, batched tensor for the model.

    Args:
        image: A PIL image. It is converted to RGB first so that grayscale,
            palette or RGBA inputs still produce the three channels the
            normalization step (and the model) expect.
        size: Side length in pixels of the square the image is resized to.
            Defaults to 32. The aspect ratio is not preserved.

    Returns:
        A float tensor of shape (1, 3, size, size) with values in [-1, 1].
    """
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        # PIL -> (C, H, W) float tensor scaled to [0, 1].
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1]. Presumably this matches
        # the training-time normalization (TODO confirm).
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    # unsqueeze(0) adds the batch dimension expected by the model.
    return transform(image.convert("RGB")).unsqueeze(0)
|
|
|
|
|
|
|
|
st.title("CIFAR-10 Image Classifier")

# "jpeg" is listed explicitly next to "jpg" so .jpeg files are accepted even
# on Streamlit versions that do not alias the two extensions.
uploaded_file = st.file_uploader(
    "Upload an Image (JPG/PNG)", type=["jpg", "jpeg", "png"]
)

if uploaded_file is not None:
    try:
        # Image.open is lazy; convert() forces a full decode, so corrupt or
        # truncated uploads fail inside this try. Pillow's
        # UnidentifiedImageError subclasses OSError. Oversized
        # "decompression bomb" images raise DecompressionBombError instead.
        image = Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        st.error(f"Could not read the uploaded file as an image: {exc}")
        st.stop()

    # NOTE(review): newer Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width`; switch once the minimum supported
    # Streamlit version allows it.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    st.write("Classifying...")
    input_tensor = preprocess_image(image)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(input_tensor)

    # Index of the highest logit for the single image in the batch.
    predicted = outputs.argmax(dim=1)
    label = CLASS_NAMES[predicted.item()]

    st.write(f"Prediction: **{label}**")
|
|
|