SyedNaseem commited on
Commit
f9a0a96
·
verified ·
1 Parent(s): 4c72cdb

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +127 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import streamlit as st
3
+ import numpy as np
4
+ import cv2
5
+ import warnings
6
+ import os
7
+
8
+ # Suppress warnings
9
+ warnings.filterwarnings("ignore", category=FutureWarning)
10
+ warnings.filterwarnings("ignore", category=UserWarning)
11
+
12
+ # Try importing TensorFlow
13
+ try:
14
+ from tensorflow.keras.models import load_model
15
+ from tensorflow.keras.preprocessing import image
16
+ except ImportError:
17
+ st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
18
+
19
+ # Try importing PyTorch and Detectron2
20
+ try:
21
+ import torch
22
+ import detectron2
23
+ except ImportError:
24
+ with st.spinner("Installing PyTorch and Detectron2..."):
25
+ os.system("pip install torch torchvision")
26
+ os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
27
+
28
+ import torch
29
+ import detectron2
30
+
31
+
32
+ import streamlit as st
33
+ import numpy as np
34
+ import cv2
35
+ import torch
36
+ import os
37
+ from PIL import Image
38
+ from tensorflow.keras.models import load_model
39
+ from tensorflow.keras.preprocessing import image
40
+ from detectron2.engine import DefaultPredictor
41
+ from detectron2.config import get_cfg
42
+ from detectron2.utils.visualizer import Visualizer
43
+ from detectron2.data import MetadataCatalog
44
+
45
+ # Suppress warnings
46
+ import warnings
47
+ import tensorflow as tf
48
+ warnings.filterwarnings("ignore")
49
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
50
+
51
+ @st.cache_resource
52
+ def load_models():
53
+ model_name = load_model('name_model_inception.h5')
54
+ model_quality = load_model('type_model_inception.h5')
55
+ return model_name, model_quality
56
+
57
+ model_name, model_quality = load_models()
58
+
59
+ # Detectron2 setup
60
+ @st.cache_resource
61
+ def load_detectron_model(fruit_name):
62
+ cfg = get_cfg()
63
+ config_path = os.path.join(f"{fruit_name.lower()}_config.yaml")
64
+ cfg.merge_from_file(config_path)
65
+ model_path = os.path.join(f"{fruit_name}_model.pth")
66
+ cfg.MODEL.WEIGHTS = model_path
67
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
68
+ cfg.MODEL.DEVICE = 'cpu'
69
+ predictor = DefaultPredictor(cfg)
70
+ return predictor, cfg
71
+
72
+ # Labels
73
+ label_map_name = {
74
+ 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
75
+ 5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
76
+ 10: "tomato"
77
+ }
78
+ label_map_quality = {0: "Good", 1: "Mild", 2: "Rotten"}
79
+
80
+ def predict_fruit(img):
81
+ # Preprocess image
82
+ img = Image.fromarray(img.astype('uint8'), 'RGB')
83
+ img = img.resize((224, 224))
84
+ x = image.img_to_array(img)
85
+ x = np.expand_dims(x, axis=0)
86
+ x = x / 255.0
87
+
88
+ # Predict
89
+ pred_name = model_name.predict(x)
90
+ pred_quality = model_quality.predict(x)
91
+
92
+ predicted_name = label_map_name[np.argmax(pred_name, axis=1)[0]]
93
+ predicted_quality = label_map_quality[np.argmax(pred_quality, axis=1)[0]]
94
+
95
+ return predicted_name, predicted_quality, img
96
+
97
+ def main():
98
+ st.title("Automated Fruits Monitoring System")
99
+ st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
100
+
101
+ uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
102
+
103
+ if uploaded_file is not None:
104
+ image = Image.open(uploaded_file)
105
+ st.image(image, caption="Uploaded Image", use_column_width=True)
106
+
107
+ if st.button("Analyze"):
108
+ predicted_name, predicted_quality, img = predict_fruit(np.array(image))
109
+
110
+ st.write(f"Fruits Type Detection: {predicted_name}")
111
+ st.write(f"Fruits Quality Classification: {predicted_quality}")
112
+
113
+ if predicted_name.lower() in ["kaki", "tomato", "strawberry", "peeper", "pear", "peach", "papaya", "watermelon", "grape", "banana", "cucumber"] and predicted_quality in ["Mild", "Rotten"]:
114
+ st.write("Segmentation of Defective Region:")
115
+ try:
116
+ predictor, cfg = load_detectron_model(predicted_name)
117
+ outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
118
+ v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
119
+ out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
120
+ st.image(out.get_image(), caption="Damage Detection Result", use_column_width=True)
121
+ except Exception as e:
122
+ st.error(f"Error in damage detection: {str(e)}")
123
+ else:
124
+ st.write("No damage detection performed for this fruit or quality level.")
125
+
126
+ if __name__ == "__main__":
127
+ main()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ tensorflow
3
+ numpy
4
+ opencv-python
5
+ opencv-python
6
+ torch
7
+ torchvision
8
+ matplotlib