|
|
import streamlit as st |
|
|
from PIL import Image |
|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
import json |
|
|
import sys |
|
|
import os |
|
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
|
|
|
|
|
from model.model import ResNet50 |
|
|
|
|
|
|
|
|
@st.cache_data
def load_class_names():
    """Load the class-name lookup from ``imagenet_classes.json``.

    The file is read as UTF-8 from the current working directory. Before
    parsing, every character with a code point below 32 is dropped, except
    newline, carriage return and tab, so stray control characters in the
    file do not break JSON decoding.

    Returns:
        The decoded JSON value, or an empty dict if reading or parsing
        fails for any reason (the failure is reported with ``st.error``).
    """
    kept_controls = '\n\r\t'
    try:
        with open("imagenet_classes.json", 'r', encoding='utf-8') as handle:
            raw_text = handle.read()
        # Keep printable characters plus the whitespace controls JSON allows.
        cleaned_text = ''.join(
            ch for ch in raw_text if ch in kept_controls or ord(ch) >= 32
        )
        return json.loads(cleaned_text)
    except Exception as e:
        st.error(f"Error loading class names: {str(e)}")
        return {}
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """Build the ResNet50 network and load its weights from the checkpoint.

    The checkpoint at ``./checkpoints/model_best.pth`` is loaded onto the CPU
    and must contain a ``"model_state_dict"`` entry. On success the model is
    switched to evaluation mode before being returned.

    Returns:
        The ready-to-use model, or ``None`` if the checkpoint lacks the
        expected key or any exception occurs (the problem is reported with
        ``st.error`` in both cases).
    """
    try:
        net = ResNet50(num_classes=1000)
        state = torch.load(
            "./checkpoints/model_best.pth",
            map_location=torch.device("cpu"),
        )

        # Guard clause: reject checkpoints without the expected weights key.
        if "model_state_dict" not in state:
            st.error("Invalid model checkpoint format")
            return None

        net.load_state_dict(state["model_state_dict"])
        net.eval()
        return net
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None
|
|
|
|
|
|
|
|
def preprocess_image(image):
    """Turn an image into a normalized input tensor for the model.

    The image is resized to 224x224, converted to a tensor, and normalized
    per channel with mean ``[0.485, 0.456, 0.406]`` and std
    ``[0.229, 0.224, 0.225]`` (the commonly used ImageNet statistics).

    Args:
        image: The image to convert; the caller in this file passes a PIL
            image already converted to RGB.

    Returns:
        The transformed tensor with an extra dimension inserted at index 0.
    """
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)
    tensor = pipeline(image)
    return tensor.unsqueeze(0)
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Streamlit page: upload an image, run it through the model, and show the
# five most probable classes next to the image.
# ---------------------------------------------------------------------------
st.title("Image Classification with ResNet50")

class_names = load_class_names()
model = load_model()

# load_model() returns None on failure; nothing below can work without it.
if model is None:
    st.error("Failed to load the model. Please check the model file.")
    st.stop()

if 'show_upload' not in st.session_state:
    st.session_state.show_upload = True

# The uploader widget keeps its file across reruns for as long as its key is
# unchanged. "New Image" bumps this counter so the next run renders a fresh,
# empty uploader instead of re-showing the previous image.
if 'uploader_key' not in st.session_state:
    st.session_state.uploader_key = 0

main_container = st.empty()

with main_container.container():
    if st.session_state.show_upload:
        uploaded_file = st.file_uploader(
            "Upload an image",
            type=["jpg", "jpeg", "png"],
            key=f"image_uploader_{st.session_state.uploader_key}",
        )

        if uploaded_file:
            image = Image.open(uploaded_file).convert("RGB")

            col1, col2 = st.columns(2)

            with col1:
                st.markdown("### Uploaded Image")
                st.image(image, use_container_width=True)

            with col2:
                st.markdown("### Predictions")

                input_tensor = preprocess_image(image)
                # Inference only: no gradients are needed.
                with torch.no_grad():
                    outputs = model(input_tensor)
                    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
                    top5_prob, top5_idx = torch.topk(probabilities, 5)

                results = []
                ranked = zip(top5_prob.tolist(), top5_idx.tolist())
                for rank, (prob, class_id) in enumerate(ranked, start=1):
                    # load_class_names() falls back to {} on error, so a
                    # missing label must not crash the page; show the raw
                    # index instead.
                    class_name = class_names.get(
                        str(class_id), f"Unknown class {class_id}"
                    )
                    results.append({
                        "Rank": rank,
                        "Class": class_name,
                        "Confidence": f"{prob * 100:.2f}%",
                    })
                st.table(results)

            st.markdown("<br>", unsafe_allow_html=True)
            col1, col2, col3 = st.columns([2, 1, 2])
            with col2:
                if st.button("↻ New Image"):
                    main_container.empty()
                    st.session_state.show_upload = True
                    # Rotate the uploader's key so the old file is dropped.
                    st.session_state.uploader_key += 1
                    st.rerun()