shuraimi commited on
Commit
3de6fac
·
verified ·
1 Parent(s): 9c40242

Upload 2 files

Browse files
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from pathlib import Path
3
+
4
+ import streamlit as st
5
+ from fastai.vision.all import load_learner, PILImage
6
+
7
+
8
+ MODEL_PATH = Path("models/pokemon_gen9_classifier_resnet101_after_cleaning.pkl")
9
+
10
+
11
+ @st.cache_resource
12
+ def load_model():
13
+ """Load and cache the FastAI learner. Returns None if model missing or incompatible."""
14
+ if not MODEL_PATH.exists():
15
+ return None
16
+ try:
17
+ learner = load_learner(MODEL_PATH)
18
+ return learner
19
+ except RuntimeError as e:
20
+ if "deprecated in `fastai>=2.8.0`" in str(e):
21
+ st.error(
22
+ f"⚠️ Model incompatibility detected!\n\n"
23
+ f"Your model was exported with FastAI ≥2.8.0, but this app uses FastAI 2.7.12.\n\n"
24
+ f"**To fix this:**\n"
25
+ f"1. Re-export your model using FastAI 2.7.12:\n"
26
+ f" - Downgrade: `pip install fastai==2.7.12 fastcore==1.7.9`\n"
27
+ f" - Run: `learn.export('{MODEL_PATH}')`\n"
28
+ f"2. Place the new export at: `{MODEL_PATH}`\n"
29
+ f"3. Refresh this page.\n\n"
30
+ f"**Error details:** {str(e)}"
31
+ )
32
+ return None
33
+ raise
34
+
35
+
36
+ def predict(learner, img_bytes: bytes):
37
+ img = PILImage.create(io.BytesIO(img_bytes))
38
+ pred, pred_idx, probs = learner.predict(img)
39
+ return pred, probs
40
+
41
+
42
+ def main():
43
+ st.title("FastAI Image Classifier")
44
+ st.write("Upload an image and the model will predict the class.")
45
+
46
+ learner = load_model()
47
+ if learner is None:
48
+ st.warning(
49
+ f"No model found at `{MODEL_PATH}`. See README.md for how to export and place your model."
50
+ )
51
+
52
+ uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
53
+ if uploaded_file is not None:
54
+ st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
55
+
56
+ if learner is None:
57
+ st.error("Can't predict because the model is missing. Follow README.md to export your learner.")
58
+ return
59
+
60
+ with st.spinner("Predicting..."):
61
+ pred, probs = predict(learner, uploaded_file.read())
62
+
63
+ st.success(f"Predicted: {pred}")
64
+
65
+ # Try to show top probabilities if available
66
+ try:
67
+ vocab = learner.dls.vocab
68
+ probs_list = list(zip(map(str, vocab), map(float, probs)))
69
+ probs_list.sort(key=lambda x: x[1], reverse=True)
70
+ st.write("Top predictions:")
71
+ for label, p in probs_list[:5]:
72
+ st.write(f"- {label}: {p:.4f}")
73
+ except Exception:
74
+ st.write("Probabilities unavailable or not applicable for this model.")
75
+
76
+
77
+ if __name__ == "__main__":
78
+ main()
pokemon_gen9_classifier_resnet101_after_cleaning.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb8271ec116835a40bb86af8ac5fbee061e761c2f08a65018bd2340872ee643f
3
+ size 179712690