File size: 2,085 Bytes
aab7641
 
3208c00
aab7641
 
 
b160f6c
9b0cb23
b160f6c
 
 
aab7641
 
 
b160f6c
aab7641
 
 
 
08cb881
b160f6c
08cb881
aab7641
b160f6c
aab7641
b160f6c
aab7641
 
 
 
b160f6c
aab7641
b160f6c
 
aab7641
 
 
b160f6c
 
 
 
aab7641
08cb881
 
b160f6c
aab7641
b160f6c
08cb881
 
b160f6c
 
 
 
 
 
 
08cb881
b160f6c
aab7641
3208c00
aab7641
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import io
from pathlib import Path
import streamlit as st
from fastai.vision.all import load_learner, PILImage


# Location of the exported FastAI learner. This is a *relative* path, resolved
# against the current working directory (presumably the repository root when
# deployed on Hugging Face Spaces — confirm if the app is launched elsewhere).
MODEL_PATH = Path("models/pokemon_gen9_classifier_resnet101_after_cleaning.pkl")


@st.cache_resource
def _load_learner_cached():
    """Load the FastAI learner from MODEL_PATH and cache it as a shared resource.

    Kept separate from ``load_model`` so that only *successful* loads are
    cached: if ``load_learner`` raises, nothing is stored, and the next rerun
    tries again instead of serving a cached failure.
    """
    # NOTE(review): load_learner unpickles the file, so only load model files
    # from a trusted source (here, the app's own models/ directory).
    return load_learner(MODEL_PATH)


def load_model():
    """Return the cached FastAI learner, or None if it cannot be loaded.

    On failure (file missing at MODEL_PATH, or any error raised while loading),
    an ``st.error`` message is rendered and None is returned; callers must
    check for None. Failures are not cached, so fixing or adding the model
    file takes effect on the next rerun without clearing the cache.
    """
    if not MODEL_PATH.exists():
        st.error(f"❌ Model not found at {MODEL_PATH}")
        return None
    try:
        return _load_learner_cached()
    except Exception as e:  # UI boundary: report any load/unpickle failure
        st.error(f"⚠️ Error loading model:\n\n{e}")
        return None


def predict(learner, img_bytes: bytes):
    """Classify an image given as raw bytes.

    Args:
        learner: the loaded FastAI learner used for inference.
        img_bytes: encoded image data (e.g. the contents of an uploaded file).

    Returns:
        A ``(label, probabilities)`` pair as produced by ``learner.predict``;
        the predicted class index is discarded.
    """
    buffer = io.BytesIO(img_bytes)
    image = PILImage.create(buffer)
    label, _, probabilities = learner.predict(image)
    return label, probabilities


def main():
    """Render the classifier UI: upload an image, predict it, show top-5 scores.

    If the model cannot be loaded, a warning is shown and the script run is
    halted with ``st.stop()``. Prediction errors are reported in the UI
    instead of crashing the app.
    """
    st.title("🎯 FastAI Image Classifier")
    st.write("Upload an image and the model will predict its class.")

    learner = load_model()
    if learner is None:
        st.warning(
            "Model not loaded. Please ensure the `.pkl` file is correctly placed under `models/` and committed with Git LFS."
        )
        st.stop()

    uploaded_file = st.file_uploader("📤 Choose an image...", type=["png", "jpg", "jpeg"])
    if uploaded_file is None:
        return

    # Take the upload's bytes once with getvalue(): unlike read(), it returns
    # the full contents regardless of the current stream position, so the
    # prediction never depends on whether displaying the image moved it.
    img_bytes = uploaded_file.getvalue()
    # NOTE(review): `use_column_width` is deprecated in recent Streamlit
    # releases (superseded by `use_container_width` / `width`); kept as-is for
    # compatibility with older versions — update once the pinned version allows.
    st.image(img_bytes, caption="Uploaded Image", use_column_width=True)

    with st.spinner("Predicting..."):
        try:
            pred, probs = predict(learner, img_bytes)
            st.success(f"✅ Predicted: **{pred}**")
            # Show top-5 predictions. float() turns each per-class score into a
            # plain number so sorting and the :.4f format work on any backend.
            vocab = learner.dls.vocab
            ranked = sorted(
                zip(vocab, (float(p) for p in probs)),
                key=lambda item: item[1],
                reverse=True,
            )
            st.write("### Top Predictions:")
            for label, p in ranked[:5]:
                st.write(f"- {label}: {p:.4f}")
        except Exception as e:  # UI boundary: show the error, keep the app alive
            st.error(f"Error during prediction: {e}")


# Entry point when this file is executed as a script (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()