# Spaces: Sleeping
"""Intelligent fruit monitoring Streamlit app.

Classifies an uploaded fruit image by type and quality with two Keras models
and, for "Mild"/"Rotten" fruit, segments the defective region with a
per-fruit Detectron2 model.
"""

import importlib
import os
import subprocess
import sys
import warnings

import cv2
import numpy as np
import streamlit as st
from PIL import Image

# Silence library warnings before the heavy ML imports below can emit them.
warnings.filterwarnings("ignore")

# TensorFlow / Keras: the fruit-type and fruit-quality classifiers.
try:
    import tensorflow as tf
    from tensorflow.keras.models import load_model
    from tensorflow.keras.preprocessing import image
except ImportError:
    st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
    # Nothing below works without TensorFlow, so end this script run here.
    # NOTE(review): st.stop() appears to return without raising when not inside
    # a Streamlit script run; re-raise so the failure is never silent.
    st.stop()
    raise

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


def _pip_install(*requirements):
    """Run ``pip install`` for *requirements* using the running interpreter.

    ``sys.executable -m pip`` installs into the environment that runs this app
    (a bare ``pip`` on PATH may belong to a different Python), and the argument
    list avoids going through a shell. Raises ``subprocess.CalledProcessError``
    if pip exits with a non-zero status.
    """
    subprocess.run([sys.executable, "-m", "pip", "install", *requirements], check=True)


# PyTorch + Detectron2: damage segmentation. Installed at runtime if missing.
try:
    import torch
    import detectron2
except ImportError:
    with st.spinner("Installing PyTorch and Detectron2..."):
        _pip_install("torch", "torchvision")
        _pip_install("git+https://github.com/facebookresearch/detectron2.git")
    # Packages installed after interpreter start-up can be missed by the import
    # system's cached directory listings until the caches are invalidated.
    importlib.invalidate_caches()
    import torch
    import detectron2

from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
@st.cache_resource
def load_models(name_model_path="name_model_inception.h5",
                quality_model_path="type_model_inception.h5"):
    """Load the fruit-type and fruit-quality Keras classifiers.

    Streamlit re-executes this whole script on every widget interaction
    (including the "Analyze" button), so the loaded models are cached with
    ``st.cache_resource`` and reused instead of being re-read from disk on
    every rerun. Requires Streamlit >= 1.18.

    Args:
        name_model_path: path to the fruit-type model file.
        quality_model_path: path to the fruit-quality model file.

    Returns:
        ``(model_name, model_quality)`` as returned by Keras ``load_model``.
    """
    model_name = load_model(name_model_path)
    model_quality = load_model(quality_model_path)
    return model_name, model_quality


model_name, model_quality = load_models()
# Detectron2 setup
@st.cache_resource
def load_detectron_model(fruit_name, score_thresh=0.5, device="cpu"):
    """Build (and cache) a Detectron2 damage-segmentation predictor for one fruit.

    Files are resolved relative to the current working directory:
    ``<fruit_name lower-cased>_config.yaml`` and ``<fruit_name>_model.pth``.
    The predictor is cached per argument combination so repeated analyses of
    the same fruit do not rebuild the model on every Streamlit rerun.

    Args:
        fruit_name: predicted fruit label, e.g. a value of ``label_map_name``.
        score_thresh: minimum ROI-head score for a detection to be kept.
        device: Detectron2 device string.

    Returns:
        ``(predictor, cfg)`` -- the ``DefaultPredictor`` and the config it was
        built from (callers read ``cfg.DATASETS.TRAIN``).
    """
    cfg = get_cfg()
    cfg.merge_from_file(f"{fruit_name.lower()}_config.yaml")
    # NOTE(review): the config file name is lower-cased but the weights file
    # name is not (e.g. "Banana" -> banana_config.yaml + Banana_model.pth).
    # Kept as-is since the files on disk presumably match; confirm before
    # unifying the casing.
    cfg.MODEL.WEIGHTS = f"{fruit_name}_model.pth"
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.DEVICE = device
    predictor = DefaultPredictor(cfg)
    return predictor, cfg
# Labels: class index -> label for the two Keras classifiers. The index order
# must match the models' output units (presumably the training class order --
# verify against the training code).
# These strings are used verbatim downstream: main() lower-cases them for a
# membership test and load_detectron_model() builds file names from them. So
# spellings such as "Peeper" and the lower-case "tomato" must not be changed
# here on their own.
label_map_name = dict(enumerate((
    "Banana", "Cucumber", "Grape", "Kaki", "Papaya", "Peach",
    "Pear", "Peeper", "Strawberry", "Watermelon", "tomato",
)))
label_map_quality = dict(enumerate(("Good", "Mild", "Rotten")))
def predict_fruit(img):
    """Classify a fruit image by type and by quality.

    Args:
        img: image pixels as an array-like accepted by ``Image.fromarray``
            (e.g. ``np.array`` of a PIL image). 2-D grayscale and RGBA inputs
            are converted to RGB.

    Returns:
        ``(predicted_name, predicted_quality, pil_img)``. The first two are
        labels from ``label_map_name`` / ``label_map_quality``, and ``pil_img``
        is the 224x224 RGB PIL image that was fed to both classifiers.
    """
    # Let Pillow infer the mode from the array shape (L / LA / RGB / RGBA),
    # then normalise to RGB. Forcing mode 'RGB' reads a 4-channel (alpha)
    # array's bytes as packed RGB, which silently scrambles the pixels, and
    # fails outright on 2-D grayscale arrays.
    pil_img = Image.fromarray(np.asarray(img).astype("uint8")).convert("RGB")
    pil_img = pil_img.resize((224, 224))

    # Batch of one, scaled to [0, 1]: shape (1, 224, 224, 3).
    x = image.img_to_array(pil_img)
    x = np.expand_dims(x, axis=0) / 255.0

    pred_name = model_name.predict(x)
    pred_quality = model_quality.predict(x)

    predicted_name = label_map_name[int(np.argmax(pred_name, axis=1)[0])]
    predicted_quality = label_map_quality[int(np.argmax(pred_quality, axis=1)[0])]
    return predicted_name, predicted_quality, pil_img
# Lower-cased fruit labels for which damage segmentation is attempted.
_SEGMENTABLE_FRUITS = frozenset({
    "kaki", "tomato", "strawberry", "peeper", "pear", "peach",
    "papaya", "watermelon", "grape", "banana", "cucumber",
})
# Quality labels that trigger damage segmentation.
_DEFECTIVE_QUALITIES = ("Mild", "Rotten")


def _show_damage_segmentation(img, fruit_name):
    """Run the fruit's Detectron2 model on *img* and display the overlay.

    Args:
        img: the RGB PIL image returned by ``predict_fruit``.
        fruit_name: predicted label; selects which Detectron2 model to load.

    Any failure (missing config/weights, inference error) is reported in the
    UI with ``st.error`` rather than crashing the app.
    """
    try:
        predictor, cfg = load_detectron_model(fruit_name)
        rgb = np.array(img)
        # The predictor is fed BGR (converted here); the Visualizer gets RGB.
        outputs = predictor(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
        v = Visualizer(rgb, MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
        out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
        st.image(out.get_image(), caption="Damage Detection Result", use_column_width=True)
    except Exception as e:  # UI boundary: surface any failure to the user.
        st.error(f"Error in damage detection: {e}")


def main():
    """Streamlit entry point: upload a fruit image, classify it, show damage.

    Damage segmentation runs only for fruits in ``_SEGMENTABLE_FRUITS`` whose
    predicted quality is "Mild" or "Rotten".
    """
    st.title("An Intelligent Fruits Monitoring System")
    st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
    uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # Named to avoid shadowing the module-level Keras `image` module.
    uploaded_image = Image.open(uploaded_file)
    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
    if not st.button("Analyze"):
        return

    # Normalise to 3-channel RGB before handing pixels to the models. PNG
    # uploads are often RGBA, grayscale or palette ("P") images, and
    # np.array() of a "P" image yields palette indices rather than colours.
    rgb_pixels = np.array(uploaded_image.convert("RGB"))
    predicted_name, predicted_quality, img = predict_fruit(rgb_pixels)
    st.write(f"Fruits Type Detection: {predicted_name}")
    st.write(f"Fruits Quality Classification: {predicted_quality}")

    if (predicted_name.lower() in _SEGMENTABLE_FRUITS
            and predicted_quality in _DEFECTIVE_QUALITIES):
        st.write("Segmentation of Defective Region of Fruit.")
        _show_damage_segmentation(img, predicted_name)
    else:
        st.write("No damage detection performed for this fruit or quality level.")


if __name__ == "__main__":
    main()