shuraimi commited on
Commit
08cb881
·
verified ·
1 Parent(s): 3f4834e

Update src/app.py

Browse files
Files changed (1) hide show
  1. src/app.py +49 -41
src/app.py CHANGED
@@ -1,78 +1,86 @@
1
  import io
2
  from pathlib import Path
3
-
4
  import streamlit as st
5
  from fastai.vision.all import load_learner, PILImage
6
 
7
-
 
 
8
  MODEL_PATH = Path("models/pokemon_gen9_classifier_resnet101_after_cleaning.pkl")
9
 
10
-
11
- @st.cache_resource
 
 
12
  def load_model():
13
  """Load and cache the FastAI learner. Returns None if model missing or incompatible."""
14
  if not MODEL_PATH.exists():
 
15
  return None
16
  try:
17
  learner = load_learner(MODEL_PATH)
18
  return learner
19
- except RuntimeError as e:
20
- if "deprecated in `fastai>=2.8.0`" in str(e):
21
- st.error(
22
- f"⚠️ Model incompatibility detected!\n\n"
23
- f"Your model was exported with FastAI ≥2.8.0, but this app uses FastAI 2.7.12.\n\n"
24
- f"**To fix this:**\n"
25
- f"1. Re-export your model using FastAI 2.7.12:\n"
26
- f" - Downgrade: `pip install fastai==2.7.12 fastcore==1.7.9`\n"
27
- f" - Run: `learn.export('{MODEL_PATH}')`\n"
28
- f"2. Place the new export at: `{MODEL_PATH}`\n"
29
- f"3. Refresh this page.\n\n"
30
- f"**Error details:** {str(e)}"
31
- )
32
- return None
33
- raise
34
-
35
 
 
 
 
36
  def predict(learner, img_bytes: bytes):
 
37
  img = PILImage.create(io.BytesIO(img_bytes))
38
  pred, pred_idx, probs = learner.predict(img)
39
  return pred, probs
40
 
41
-
 
 
42
  def main():
43
- st.title("FastAI Image Classifier")
 
44
  st.write("Upload an image and the model will predict the class.")
45
 
46
  learner = load_model()
47
  if learner is None:
48
- st.warning(
49
- f"No model found at `{MODEL_PATH}`. See README.md for how to export and place your model."
50
- )
51
 
52
- uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
53
- if uploaded_file is not None:
54
- st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
55
 
56
- if learner is None:
57
- st.error("Can't predict because the model is missing. Follow README.md to export your learner.")
58
- return
59
 
60
- with st.spinner("Predicting..."):
61
- pred, probs = predict(learner, uploaded_file.read())
 
 
 
 
62
 
63
- st.success(f"Predicted: {pred}")
64
 
65
- # Try to show top probabilities if available
66
  try:
67
  vocab = learner.dls.vocab
68
- probs_list = list(zip(map(str, vocab), map(float, probs)))
69
- probs_list.sort(key=lambda x: x[1], reverse=True)
70
- st.write("Top predictions:")
 
 
 
71
  for label, p in probs_list[:5]:
72
- st.write(f"- {label}: {p:.4f}")
73
  except Exception:
74
- st.write("Probabilities unavailable or not applicable for this model.")
75
-
76
 
 
77
  if __name__ == "__main__":
78
  main()
 
1
  import io
2
  from pathlib import Path
 
3
  import streamlit as st
4
  from fastai.vision.all import load_learner, PILImage
5
 
6
+ # ---------------------------------------------------------------------
7
+ # Configuration
8
+ # ---------------------------------------------------------------------
9
  MODEL_PATH = Path("models/pokemon_gen9_classifier_resnet101_after_cleaning.pkl")
10
 
11
+ # ---------------------------------------------------------------------
12
+ # Load model (cached so it doesn't reload on every rerun)
13
+ # ---------------------------------------------------------------------
14
+ @st.cache_resource(show_spinner=False)
15
  def load_model():
16
  """Load and cache the FastAI learner. Returns None if model missing or incompatible."""
17
  if not MODEL_PATH.exists():
18
+ st.error(f"❌ Model file not found at: `{MODEL_PATH}`")
19
  return None
20
  try:
21
  learner = load_learner(MODEL_PATH)
22
  return learner
23
+ except Exception as e:
24
+ st.error(
25
+ f"⚠️ Failed to load model at `{MODEL_PATH}`.\n\n"
26
+ f"**Error:** {e}\n\n"
27
+ f"If this model was exported using a newer FastAI version, "
28
+ f"re-export it using FastAI 2.7.12:\n"
29
+ f"`pip install fastai==2.7.12 fastcore==1.5.29`\n"
30
+ f"`learn.export('{MODEL_PATH}')`"
31
+ )
32
+ return None
 
 
 
 
 
 
33
 
34
+ # ---------------------------------------------------------------------
35
+ # Prediction function
36
+ # ---------------------------------------------------------------------
37
  def predict(learner, img_bytes: bytes):
38
+ """Run inference on uploaded image and return predictions."""
39
  img = PILImage.create(io.BytesIO(img_bytes))
40
  pred, pred_idx, probs = learner.predict(img)
41
  return pred, probs
42
 
43
+ # ---------------------------------------------------------------------
44
+ # Streamlit UI
45
+ # ---------------------------------------------------------------------
46
  def main():
47
+ st.set_page_config(page_title="FastAI Image Classifier", layout="centered")
48
+ st.title("🧠 FastAI Image Classifier")
49
  st.write("Upload an image and the model will predict the class.")
50
 
51
  learner = load_model()
52
  if learner is None:
53
+ st.warning(f"No model loaded. Make sure `{MODEL_PATH}` exists inside the container.")
54
+ return
 
55
 
56
+ uploaded_file = st.file_uploader("📤 Choose an image...", type=["png", "jpg", "jpeg"])
 
 
57
 
58
+ if uploaded_file is not None:
59
+ st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)
 
60
 
61
+ with st.spinner("🔍 Predicting..."):
62
+ try:
63
+ pred, probs = predict(learner, uploaded_file.read())
64
+ except Exception as e:
65
+ st.error(f"Prediction failed: {e}")
66
+ return
67
 
68
+ st.success(f"✅ **Predicted:** {pred}")
69
 
70
+ # Show top 5 predictions (if available)
71
  try:
72
  vocab = learner.dls.vocab
73
+ probs_list = sorted(
74
+ zip(map(str, vocab), map(float, probs)),
75
+ key=lambda x: x[1],
76
+ reverse=True
77
+ )
78
+ st.subheader("Top Predictions")
79
  for label, p in probs_list[:5]:
80
+ st.write(f"- **{label}**: {p*100:.2f}%")
81
  except Exception:
82
+ st.info("Probabilities unavailable or not applicable for this model.")
 
83
 
84
+ # ---------------------------------------------------------------------
85
  if __name__ == "__main__":
86
  main()