shuraimi commited on
Commit
02318d0
Β·
verified Β·
1 Parent(s): 695a8aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -21
app.py CHANGED
@@ -3,14 +3,56 @@ from pathlib import Path
3
  import streamlit as st
4
  from fastai.vision.all import load_learner, PILImage
5
 
 
 
 
 
 
 
 
6
 
7
- # βœ… Correct absolute path for Hugging Face Spaces
8
- MODEL_PATH = Path("models/pokemon_gen9_classifier_resnet101_after_cleaning.pkl")
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
 
 
11
  @st.cache_resource
12
  def load_model():
13
- """Load and cache the FastAI learner. Returns None if model missing or incompatible."""
14
  if not MODEL_PATH.exists():
15
  st.error(f"❌ Model not found at {MODEL_PATH}")
16
  return None
@@ -29,33 +71,58 @@ def predict(learner, img_bytes: bytes):
29
  return pred, probs
30
 
31
 
 
32
  def main():
33
  st.title("🎯 FastAI Image Classifier")
34
- st.write("Upload an image and the model will predict its class.")
35
 
36
  learner = load_model()
37
  if learner is None:
38
- st.warning(
39
- "Model not loaded. Please ensure the `.pkl` file is correctly placed under `models/` and committed with Git LFS."
40
- )
41
  st.stop()
42
 
43
- uploaded_file = st.file_uploader("πŸ“€ Choose an image...", type=["png", "jpg", "jpeg"])
44
- if uploaded_file is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- with st.spinner("Predicting..."):
48
- try:
49
- pred, probs = predict(learner, uploaded_file.read())
50
- st.success(f"βœ… Predicted: **{pred}**")
51
- # Show top-5 predictions
52
- vocab = learner.dls.vocab
53
- probs_list = sorted(zip(vocab, probs), key=lambda x: x[1], reverse=True)
54
- st.write("### Top Predictions:")
55
- for label, p in probs_list[:5]:
56
- st.write(f"- {label}: {p:.4f}")
57
- except Exception as e:
58
- st.error(f"Error during prediction: {e}")
59
 
60
 
61
  if __name__ == "__main__":
 
3
  import streamlit as st
4
  from fastai.vision.all import load_learner, PILImage
5
 
6
+ # === CONFIG ===
7
+ st.set_page_config(
8
+ page_title="FastAI Image Classifier",
9
+ page_icon="🎯",
10
+ layout="centered",
11
+ initial_sidebar_state="collapsed",
12
+ )
13
 
14
+ MODEL_PATH = Path("models/pokemon_gen9_classifier_resnet101_after_cleaning.pkl")
15
+ EXAMPLES_DIR = Path("examples")
16
 
17
+ # === STYLES ===
18
+ st.markdown("""
19
+ <style>
20
+ .block-container {
21
+ max-width: 750px;
22
+ margin: auto;
23
+ padding-top: 2rem;
24
+ }
25
+ .stButton button {
26
+ border-radius: 10px;
27
+ background-color: #4CAF50 !important;
28
+ color: white !important;
29
+ font-weight: 600;
30
+ padding: 0.6em 1.2em;
31
+ transition: all 0.2s ease-in-out;
32
+ }
33
+ .stButton button:hover {
34
+ background-color: #45a049 !important;
35
+ transform: scale(1.05);
36
+ }
37
+ h1, h2, h3 {
38
+ text-align: center;
39
+ color: #222;
40
+ }
41
+ .prediction-box {
42
+ background-color: #f7f7f7;
43
+ border-radius: 12px;
44
+ padding: 1rem;
45
+ margin-top: 1rem;
46
+ box-shadow: 0 2px 6px rgba(0,0,0,0.1);
47
+ }
48
+ </style>
49
+ """, unsafe_allow_html=True)
50
 
51
+
52
+ # === MODEL LOADING ===
53
  @st.cache_resource
54
  def load_model():
55
+ """Load and cache the FastAI learner."""
56
  if not MODEL_PATH.exists():
57
  st.error(f"❌ Model not found at {MODEL_PATH}")
58
  return None
 
71
  return pred, probs
72
 
73
 
74
+ # === MAIN APP ===
75
  def main():
76
  st.title("🎯 FastAI Image Classifier")
77
+ st.write("Upload an image or try one of the examples below!")
78
 
79
  learner = load_model()
80
  if learner is None:
81
+ st.warning("Please ensure your `.pkl` model is correctly placed under `models/` and committed with Git LFS.")
 
 
82
  st.stop()
83
 
84
+ # Example images
85
+ example_images = list(EXAMPLES_DIR.glob("*"))
86
+ selected_example = None
87
+
88
+ if example_images:
89
+ st.subheader("✨ Try Example Images")
90
+ cols = st.columns(len(example_images))
91
+ for i, img_path in enumerate(example_images):
92
+ with cols[i]:
93
+ if st.button(img_path.stem.capitalize()):
94
+ selected_example = img_path
95
+
96
+ uploaded_file = st.file_uploader("πŸ“€ Upload your own image", type=["png", "jpg", "jpeg"])
97
+
98
+ if selected_example:
99
+ img_bytes = open(selected_example, "rb").read()
100
+ st.image(selected_example, caption=f"Example: {selected_example.stem}", use_column_width=True)
101
+ elif uploaded_file:
102
+ img_bytes = uploaded_file.read()
103
  st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
104
+ else:
105
+ st.stop()
106
+
107
+ st.markdown("---")
108
+
109
+ with st.spinner("πŸ” Analyzing the image..."):
110
+ try:
111
+ pred, probs = predict(learner, img_bytes)
112
+ vocab = learner.dls.vocab
113
+ probs_list = sorted(zip(vocab, probs), key=lambda x: x[1], reverse=True)
114
+
115
+ # Display prediction
116
+ st.markdown(f"<div class='prediction-box'><h3>βœ… Prediction: <span style='color:#4CAF50'>{pred}</span></h3></div>", unsafe_allow_html=True)
117
+ st.progress(float(probs.max()))
118
+
119
+ # Top 5 predictions
120
+ st.subheader("Top 5 Predictions")
121
+ for label, p in probs_list[:5]:
122
+ st.write(f"β€’ **{label}** β€” {p:.2%}")
123
 
124
+ except Exception as e:
125
+ st.error(f"❌ Error during prediction:\n\n{e}")
 
 
 
 
 
 
 
 
 
 
126
 
127
 
128
  if __name__ == "__main__":