shuraimi commited on
Commit
695a8aa
·
verified ·
1 Parent(s): 4a9cb4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -39
app.py CHANGED
@@ -1,77 +1,61 @@
1
  import io
2
  from pathlib import Path
3
-
4
  import streamlit as st
5
  from fastai.vision.all import load_learner, PILImage
6
 
7
 
8
- MODEL_PATH = Path("pokemon_gen9_classifier_resnet101_after_cleaning.pkl")
 
9
 
10
 
11
  @st.cache_resource
12
  def load_model():
13
  """Load and cache the FastAI learner. Returns None if model missing or incompatible."""
14
  if not MODEL_PATH.exists():
 
15
  return None
16
  try:
17
  learner = load_learner(MODEL_PATH)
18
  return learner
19
- except RuntimeError as e:
20
- if "deprecated in `fastai>=2.8.0`" in str(e):
21
- st.error(
22
- f"⚠️ Model incompatibility detected!\n\n"
23
- f"Your model was exported with FastAI ≥2.8.0, but this app uses FastAI 2.7.12.\n\n"
24
- f"**To fix this:**\n"
25
- f"1. Re-export your model using FastAI 2.7.12:\n"
26
- f" - Downgrade: `pip install fastai==2.7.12 fastcore==1.7.9`\n"
27
- f" - Run: `learn.export('{MODEL_PATH}')`\n"
28
- f"2. Place the new export at: `{MODEL_PATH}`\n"
29
- f"3. Refresh this page.\n\n"
30
- f"**Error details:** {str(e)}"
31
- )
32
- return None
33
- raise
34
 
35
 
36
  def predict(learner, img_bytes: bytes):
 
37
  img = PILImage.create(io.BytesIO(img_bytes))
38
  pred, pred_idx, probs = learner.predict(img)
39
  return pred, probs
40
 
41
 
42
  def main():
43
- st.title("FastAI Image Classifier")
44
- st.write("Upload an image and the model will predict the class.")
45
 
46
  learner = load_model()
47
  if learner is None:
48
  st.warning(
49
- f"No model found at `{MODEL_PATH}`. See README.md for how to export and place your model."
50
  )
 
51
 
52
- uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
53
  if uploaded_file is not None:
54
  st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
55
 
56
- if learner is None:
57
- st.error("Can't predict because the model is missing. Follow README.md to export your learner.")
58
- return
59
-
60
  with st.spinner("Predicting..."):
61
- pred, probs = predict(learner, uploaded_file.read())
62
-
63
- st.success(f"Predicted: {pred}")
64
-
65
- # Try to show top probabilities if available
66
- try:
67
- vocab = learner.dls.vocab
68
- probs_list = list(zip(map(str, vocab), map(float, probs)))
69
- probs_list.sort(key=lambda x: x[1], reverse=True)
70
- st.write("Top predictions:")
71
- for label, p in probs_list[:5]:
72
- st.write(f"- {label}: {p:.4f}")
73
- except Exception:
74
- st.write("Probabilities unavailable or not applicable for this model.")
75
 
76
 
77
  if __name__ == "__main__":
 
1
  import io
2
  from pathlib import Path
 
3
  import streamlit as st
4
  from fastai.vision.all import load_learner, PILImage
5
 
6
 
7
+ # Correct absolute path for Hugging Face Spaces
8
+ MODEL_PATH = Path("models/pokemon_gen9_classifier_resnet101_after_cleaning.pkl")
9
 
10
 
11
  @st.cache_resource
12
  def load_model():
13
  """Load and cache the FastAI learner. Returns None if model missing or incompatible."""
14
  if not MODEL_PATH.exists():
15
+ st.error(f"❌ Model not found at {MODEL_PATH}")
16
  return None
17
  try:
18
  learner = load_learner(MODEL_PATH)
19
  return learner
20
+ except Exception as e:
21
+ st.error(f"⚠️ Error loading model:\n\n{e}")
22
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  def predict(learner, img_bytes: bytes):
26
+ """Make a prediction on uploaded image bytes."""
27
  img = PILImage.create(io.BytesIO(img_bytes))
28
  pred, pred_idx, probs = learner.predict(img)
29
  return pred, probs
30
 
31
 
32
  def main():
33
+ st.title("🎯 FastAI Image Classifier")
34
+ st.write("Upload an image and the model will predict its class.")
35
 
36
  learner = load_model()
37
  if learner is None:
38
  st.warning(
39
+ "Model not loaded. Please ensure the `.pkl` file is correctly placed under `models/` and committed with Git LFS."
40
  )
41
+ st.stop()
42
 
43
+ uploaded_file = st.file_uploader("📤 Choose an image...", type=["png", "jpg", "jpeg"])
44
  if uploaded_file is not None:
45
  st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
46
 
 
 
 
 
47
  with st.spinner("Predicting..."):
48
+ try:
49
+ pred, probs = predict(learner, uploaded_file.read())
50
+ st.success(f"Predicted: **{pred}**")
51
+ # Show top-5 predictions
52
+ vocab = learner.dls.vocab
53
+ probs_list = sorted(zip(vocab, probs), key=lambda x: x[1], reverse=True)
54
+ st.write("### Top Predictions:")
55
+ for label, p in probs_list[:5]:
56
+ st.write(f"- {label}: {p:.4f}")
57
+ except Exception as e:
58
+ st.error(f"Error during prediction: {e}")
 
 
 
59
 
60
 
61
  if __name__ == "__main__":