File size: 3,793 Bytes
c14b474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b302157
c14b474
 
 
 
 
63c41aa
c14b474
 
 
 
 
 
 
7743aec
c14b474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63c41aa
 
 
c14b474
63c41aa
 
c14b474
63c41aa
 
 
c14b474
63c41aa
 
 
c14b474
63c41aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import streamlit as st
from PIL import Image
import torch
import torchvision.transforms as transforms
import json
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# Now import ResNet50 from model/model.py (its parent directory was just added to sys.path)
from model.model import ResNet50

# Load the class-index-to-label mapping
@st.cache_data
def load_class_names():
    """Read the class-index-to-label mapping from ``imagenet_classes.json``.

    C0 control characters other than tab, newline and carriage return are
    deleted from the file text before parsing (``json.loads`` rejects raw
    control characters inside string values by default).

    Returns:
        The parsed JSON content (used by the app as a dict keyed by the
        stringified class index), or ``{}`` if reading or parsing fails; in
        that case the error is shown in the Streamlit UI.
    """
    # Translation table: every code point 0-31 except \n, \r, \t -> delete.
    drop_controls = {
        code: None for code in range(32) if chr(code) not in '\n\r\t'
    }
    try:
        with open("imagenet_classes.json", 'r', encoding='utf-8') as handle:
            raw_text = handle.read()
        return json.loads(raw_text.translate(drop_controls))
    except Exception as err:
        st.error(f"Error loading class names: {str(err)}")
        return {}

# Load model
@st.cache_resource
def load_model():
    """Build a 1000-class ResNet50 and load its weights from a checkpoint.

    The checkpoint is read from ``./checkpoints/model_best.pth`` onto the CPU.
    Its weights may be stored under ``"model_state_dict"`` or ``"state_dict"``,
    or the file may itself be a bare state dict (parameter name -> tensor).

    Returns:
        The model in eval mode, or ``None`` if the checkpoint could not be
        loaded; in that case an error message is shown in the Streamlit UI.
    """
    try:
        model = ResNet50(num_classes=1000)
        # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
        checkpoint = torch.load("./checkpoints/model_best.pth", map_location=torch.device("cpu"))
        state_dict = _extract_state_dict(checkpoint)
        if state_dict is None:
            st.error("Invalid model checkpoint format")
            return None
        model.load_state_dict(state_dict)
        model.eval()
        return model
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None


def _extract_state_dict(checkpoint):
    """Return the weights stored in *checkpoint*, or ``None`` if the layout is unrecognised."""
    # Non-dict checkpoints (e.g. a pickled module) are not supported; checking
    # here avoids an opaque TypeError from the ``in`` tests below.
    if not isinstance(checkpoint, dict):
        return None
    # Training-script checkpoints usually wrap the weights with metadata.
    for key in ("model_state_dict", "state_dict"):
        if key in checkpoint:
            return checkpoint[key]
    # Otherwise accept a bare state dict: every value is a tensor.
    if checkpoint and all(isinstance(value, torch.Tensor) for value in checkpoint.values()):
        return checkpoint
    return None

# Preprocess image
def preprocess_image(image):
    """Convert a PIL image into a normalised single-image batch tensor.

    The image is resized to exactly 224x224 (aspect ratio is not preserved),
    converted to a float tensor, and normalised per channel with the
    mean/std constants below. A leading batch dimension of size 1 is added,
    so an RGB input yields a tensor of shape (1, 3, 224, 224).
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
    )
    single = pipeline(image)
    return single.unsqueeze(0)

# Streamlit app
st.title("Image Classification with ResNet50")
class_names = load_class_names()
model = load_model()

# Without a model there is nothing to predict with; stop rendering here.
if model is None:
    st.error("Failed to load the model. Please check the model file.")
    st.stop()

# Number of predictions shown in the results table.
TOP_K = 5

# Initialize session state.
# Streamlit keeps a widget's value across reruns while its key is unchanged,
# so the "New Image" button bumps ``uploader_key`` to get a fresh, empty
# file uploader. ``show_upload`` is kept as before.
if 'show_upload' not in st.session_state:
    st.session_state.show_upload = True
if 'uploader_key' not in st.session_state:
    st.session_state.uploader_key = 0

# Main content container
main_container = st.empty()

with main_container.container():
    if st.session_state.show_upload:
        uploaded_file = st.file_uploader(
            "Upload an image",
            type=["jpg", "jpeg", "png"],
            key=f"uploader_{st.session_state.uploader_key}",
        )

        if uploaded_file:
            # The extension filter does not validate the file contents; PIL
            # raises OSError (incl. UnidentifiedImageError) for bad data.
            try:
                image = Image.open(uploaded_file).convert("RGB")
            except OSError as e:
                st.error(f"Could not read the uploaded image: {e}")
                st.stop()

            col1, col2 = st.columns(2)

            with col1:
                st.markdown("### Uploaded Image")
                st.image(image, use_container_width=True)

            with col2:
                st.markdown("### Predictions")
                # Process image and get predictions
                input_tensor = preprocess_image(image)
                with torch.no_grad():
                    outputs = model(input_tensor)
                    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
                    top_prob, top_idx = torch.topk(
                        probabilities, min(TOP_K, probabilities.numel())
                    )

                    results = [
                        {
                            "Rank": rank,
                            # Fall back to the raw index when no label is known,
                            # e.g. when load_class_names() returned {} on error.
                            "Class": class_names.get(str(class_id), f"Class {class_id}"),
                            "Confidence": f"{prob * 100:.2f}%",
                        }
                        for rank, (prob, class_id) in enumerate(
                            zip(top_prob.tolist(), top_idx.tolist()), start=1
                        )
                    ]
                    st.table(results)

            # Add the New Image button, centred in the middle column.
            st.markdown("<br>", unsafe_allow_html=True)
            _, button_col, _ = st.columns([2, 1, 2])
            with button_col:
                if st.button("↻ New Image"):
                    main_container.empty()
                    # A new key makes Streamlit build a new, empty uploader.
                    st.session_state.uploader_key += 1
                    st.session_state.show_upload = True
                    st.rerun()