File size: 4,666 Bytes
969b92f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d8c381
 
5673b4c
2d8c381
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import streamlit as st
from PIL import Image
import torch
import torchvision.transforms as transforms
import json
from load import load_model  # Your existing load.py

# ---------------------------
# Page configuration — Streamlit requires this before any other st.* call.
# ---------------------------
st.set_page_config(layout="wide", page_title="Flower Identifier")

# ---------------------------
# Custom CSS
# ---------------------------
# Injected once as raw HTML. The stylesheet paints the whole app (.stApp)
# white, defines a card-style .main-container class, centres and colours
# <h1> red, and hides Streamlit's default menu (#MainMenu) and footer.
_CUSTOM_CSS = """
    <style>
    /* Force the entire app to have a white background */
    .stApp {
        background-color: #ffffff;
    }
    /* Create a main container for content */
    .main-container {
        background-color: #ffffff;
        border-radius: 15px;
        padding: 40px;
        margin: 20px auto;
        max-width: 900px;
        box-shadow: 0px 4px 20px rgba(0,0,0,0.1);
    }
    /* Set the title to be red */
    h1 {
        text-align: center;
        color: #ff0000 !important;
    }
    /* Hide the default Streamlit menu and footer */
    #MainMenu {visibility: hidden;}
    footer {visibility: hidden;}
    </style>
    """
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# ---------------------------
# Helper Functions
# ---------------------------
@st.cache_data
def load_flower_info(filename="flower.json"):
    """Load flower metadata from a JSON file and index it by flower id.

    Args:
        filename: Path to a JSON file whose top level is a list of objects,
            each carrying an ``"id"`` key.

    Returns:
        dict mapping each entry's ``"id"`` value to the entry itself. If two
        entries share an id, the later one wins.
    """
    # Read as UTF-8 explicitly: the platform default encoding (e.g. cp1252 on
    # Windows) can fail on non-ASCII text such as accented names.
    with open(filename, "r", encoding="utf-8") as f:
        data = json.load(f)
    # NOTE(review): classify_image looks entries up with the integer index
    # produced by argmax, so "id" values presumably are ints matching the
    # model's class indices — confirm against flower.json.
    return {flower["id"]: flower for flower in data}


flower_info = load_flower_info("flower.json")

@st.cache_resource
def get_model():
    """Load the fine-tuned classifier and return it ready for inference.

    ``st.cache_resource`` caches the returned object, so the weights file is
    not reloaded on every script rerun.
    """
    loaded = load_model("fine_tuned_resnet50.pth")
    # Inference needs dropout disabled and batch-norm layers using their stored
    # running statistics. What load_model does is not visible here, so switch
    # to eval mode explicitly; this is a no-op if it already did.
    if isinstance(loaded, torch.nn.Module):
        loaded.eval()
    return loaded


model = get_model()

# Preprocessing applied to every uploaded image before inference: a fixed
# 224x224 resize (aspect ratio is not preserved), conversion to a tensor, then
# per-channel normalisation with the standard ImageNet mean/std.
# NOTE(review): presumably this matches the preprocessing used when the model
# was fine-tuned — confirm.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

def classify_image(image):
    """Run the model on *image* and look up the predicted flower.

    Args:
        image: A PIL image in any mode; it is converted to RGB first.

    Returns:
        A ``(predicted_class, info)`` tuple: the argmax class index as a Python
        int, and the matching ``flower_info`` entry, or ``None`` when that
        index has no entry.
    """
    rgb_image = image.convert("RGB")
    # Add a leading batch dimension of size 1 before calling the model.
    batch = transform(rgb_image).unsqueeze(0)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        scores = model(batch)
    class_index = scores.argmax(dim=1).item()
    return class_index, flower_info.get(class_index)

def log_feedback(predicted_class, user_feedback, path="feedback.log"):
    """Append one user correction to a plain-text feedback log.

    Each call writes exactly one line of the form
    ``Predicted: <predicted_class>, Correction: <user_feedback>``.

    Args:
        predicted_class: Class index the model predicted.
        user_feedback: Free-text correction entered by the user.
        path: Log file to append to (created if missing). Defaults to
            ``"feedback.log"`` relative to the current working directory.
    """
    # Collapse line breaks so a single submission can never span (or forge)
    # multiple log records.
    correction = str(user_feedback).replace("\r", " ").replace("\n", " ")
    # UTF-8 explicitly: corrections are free text and may contain characters
    # the platform default encoding cannot represent.
    with open(path, "a", encoding="utf-8") as f:
        f.write(f"Predicted: {predicted_class}, Correction: {correction}\n")

# ---------------------------
# Layout: Banner, Title, and Main Container
# ---------------------------


def _open_uploaded_image(upload):
    """Return the decoded PIL image for *upload*, or None if it is unreadable."""
    try:
        img = Image.open(upload)
        # Image.open is lazy; decode now so a corrupt or mislabeled file is
        # reported to the user instead of crashing later in classify_image.
        img.load()
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        return None
    return img


def _render_prediction(info):
    """Show the metadata for the predicted flower, or a not-found message."""
    st.header("Prediction")
    if info is None:
        st.markdown("**Prediction:** This flower is not in our database.")
        return
    st.markdown(f"**Flower Name:** {info['name'].title()}")
    st.markdown(f"**Scientific Name:** {info['scientific_name']}")
    st.markdown(f"**Genus:** {info['genus']}")
    st.markdown(f"**Fun Fact:** {info['fun_fact']}")
    st.markdown(f"**Where Found:** {info['where_found']}")


def _render_feedback(predicted_class):
    """Ask whether the prediction was right and log a correction if it was not."""
    st.markdown("---")
    st.subheader("Is this prediction correct?")
    # Widgets get real (visually hidden) labels: empty labels trigger
    # Streamlit's accessibility warning, and the subheader already asks.
    feedback = st.radio(
        "Is this prediction correct?",
        ("Yes", "No"),
        key="feedback_radio",
        label_visibility="collapsed",
    )
    if feedback != "No":
        return

    st.write("Please enter the correct flower name:")
    user_correction = st.text_input(
        "Correct flower name", key="user_correction", label_visibility="collapsed"
    )
    if st.button("Submit Correction"):
        correction = user_correction.strip()
        if correction:
            log_feedback(predicted_class, correction)
            st.success("Thank you for your feedback! We'll use it to improve our model.")
        else:
            st.error("Please enter a valid correction.")


# Display a banner image at the top. `use_container_width` replaces the
# deprecated `use_column_width` and matches the uploaded-image call below.
banner_url = "flowers-identifier.webp"
st.image(banner_url, use_container_width=True)

# Wrap our main content in a custom container div.
# NOTE(review): Streamlit renders each st.markdown call as a separate element,
# so this opening tag and the closing tag at the end likely do not enclose the
# widgets in between — verify that the .main-container styling actually applies.
st.markdown("<div class='main-container'>", unsafe_allow_html=True)

st.title("Flower Classification")
st.write("Upload a flower image to identify it.")

# File uploader widget
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = _open_uploaded_image(uploaded_file)
    if image is None:
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG file.")
    else:
        # Display the uploaded image and prediction side by side
        col1, col2 = st.columns(2)

        with col1:
            st.image(image, caption="Uploaded Image", use_container_width=True)

        with col2:
            predicted_class, info = classify_image(image)
            _render_prediction(info)
            _render_feedback(predicted_class)

st.markdown("</div>", unsafe_allow_html=True)