NagashreePai commited on
Commit
df0132b
·
verified ·
1 Parent(s): 6f06748

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -21
app.py CHANGED
@@ -1,21 +1,54 @@
1
- import gradio as gr
2
- from utils import load_all_models, predict_image
3
-
4
- # Load all models and class mappings once at startup
5
- models, offsets, idx_to_class = load_all_models()
6
-
7
- # Inference function wrapper
8
- def classify(image):
9
- return predict_image(image, models, offsets, idx_to_class)
10
-
11
- # Launch Gradio interface
12
- demo = gr.Interface(
13
- fn=classify,
14
- inputs=gr.Image(type="pil", label="Upload Weed Image"),
15
- outputs=gr.Label(num_top_classes=3, label="Top Predicted Classes"),
16
- title="🌿 Weed Species Classifier",
17
- description="Upload a weed image to classify it into one of 25 species using confidence-based routing across 3 Swin Transformer models.",
18
- live=True # ✅ Shows a spinner while running
19
- )
20
-
21
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from utils import load_model, preprocess_image, predict_final_class
3
+
4
+ st.set_page_config(page_title="Weed Classifier", layout="centered")
5
+
6
+ st.title("🌿 Weed Species Classifier")
7
+ st.write("Upload an image of a weed, and the model will classify it.")
8
+
9
+ # Define the models and class mappings
10
+ model_defs = [
11
+ {
12
+ "name": "Model 1",
13
+ "path": "MMIM_best1.pth",
14
+ "class_names": ["class10", "class11", "class12", "class13"]
15
+ },
16
+ {
17
+ "name": "Model 2",
18
+ "path": "MMIM_best2.pth",
19
+ "class_names": ["class14", "class15", "class16", "class17", "class18", "class19"]
20
+ },
21
+ {
22
+ "name": "Model 3",
23
+ "path": "MMIM_best3.pth",
24
+ "class_names": ["class20", "class21", "class22", "class23", "class24", "class25"]
25
+ }
26
+ ]
27
+
28
+ # Load models once (on app startup)
29
+ @st.cache_resource
30
+ def load_all_models():
31
+ for model_def in model_defs:
32
+ model_def["model"] = load_model(model_def["path"], len(model_def["class_names"]))
33
+ return model_defs
34
+
35
+ model_defs = load_all_models()
36
+
37
+ # Upload section
38
+ uploaded_image = st.file_uploader("📤 Upload Weed Image", type=["jpg", "jpeg", "png"])
39
+
40
+ # Prediction
41
+ if uploaded_image:
42
+ st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
43
+
44
+ # Preprocess and predict
45
+ image_tensor = preprocess_image(uploaded_image)
46
+ predicted_class = predict_final_class(image_tensor, model_defs)
47
+
48
+ # Display result
49
+ st.markdown("## 🔍 Predicted Class")
50
+ st.success(f"**{predicted_class}**")
51
+
52
+ st.markdown("---")
53
+ if st.button("Clear"):
54
+ st.experimental_rerun()