| import cv2,os |
| import numpy as np |
| import streamlit as st |
| from detectron2 import utils |
| from detectron2.engine import DefaultTrainer |
| from detectron2.config import get_cfg |
| from detectron2.utils import comm |
| from detectron2.utils.logger import setup_logger |
| |
| import numpy as np |
| import os, json, cv2, random |
| |
| import warnings |
| warnings.filterwarnings('ignore') |
| |
| from detectron2 import model_zoo |
| from detectron2.engine import DefaultPredictor |
| from detectron2.config import get_cfg |
| from detectron2.utils.visualizer import Visualizer |
| from detectron2.data import MetadataCatalog, DatasetCatalog |
| from detectron2.structures import BoxMode |
| from detectron2.utils.visualizer import ColorMode |
| import matplotlib.pyplot as plt |
|
|
|
|
@st.cache(persist=True)
def initialization():
    """Build the configuration and the predictor used for inference.

    Registers ``thing_classes=["wheat"]`` on the ``wheat_train`` and
    ``wheat_test`` metadata entries, fills a detectron2 config that runs on
    the CPU, and creates a ``DefaultPredictor`` whose weights are read from
    ``model_final.pth`` inside ``cfg.OUTPUT_DIR``.

    Returns:
        cfg (detectron2.config.config.CfgNode): Configuration for the model.
        predictor (detectron2.engine.defaults.DefaultPredictor): Model to use.

    """
    for split in ("train", "test"):
        MetadataCatalog.get("wheat_" + split).set(thing_classes=["wheat"])

    cfg = get_cfg()
    # NOTE(review): no model-zoo config file is merged into cfg, so every
    # architecture option not set below keeps its get_cfg() default.
    # Confirm this matches the setup that produced model_final.pth.

    cfg.MODEL.DEVICE = "cpu"
    cfg.DATALOADER.NUM_WORKERS = 0

    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.00025
    cfg.SOLVER.STEPS = []

    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    # NOTE(review): two classes are configured here while only one thing
    # class name ("wheat") is registered above -- confirm against training.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2

    # Weights are set exactly once: the checkpoint in the output directory.
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")

    # Detections scoring below this threshold are dropped at test time.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.95

    predictor = DefaultPredictor(cfg)

    return cfg, predictor
|
|
|
|
@st.cache
def inference(predictor, img):
    """Run the predictor on a single image.

    Args:
        predictor: Callable model wrapper; it is invoked as ``predictor(img)``.
        img: Image handed to the predictor unchanged.

    Returns:
        Whatever ``predictor(img)`` returns.
    """
    outputs = predictor(img)
    return outputs
|
|
|
|
@st.cache
def output_image(cfg, img, outputs):
    """Draw the predicted instances on top of the input image.

    Args:
        cfg: Not used by this function; kept so existing callers keep working.
        img: Image array indexed as ``[row, column, channel]``. Its channel
            axis is reversed before it is handed to the ``Visualizer``
            (NOTE(review): presumably BGR as decoded by OpenCV -- confirm).
        outputs: Mapping with an ``"instances"`` entry that supports
            ``.to("cpu")``.

    Returns:
        The rendered image returned by ``get_image()`` on the drawn result,
        which detectron2 documents as an RGB ``uint8`` array.
    """
    wheat_metadata = MetadataCatalog.get("wheat_train")
    visualizer = Visualizer(
        img[:, :, ::-1],
        metadata=wheat_metadata,
        scale=1.5,
        instance_mode=ColorMode.SEGMENTATION,
    )
    drawn = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))

    # get_image() already yields RGB, so reversing the channels and then
    # converting BGR->RGB again would only reproduce the same pixels.
    processed_img = drawn.get_image()

    return processed_img
|
|
|
|
def main():
    """Streamlit entry point: upload an image, run the model, show the result.

    Loads the configuration and predictor, waits for an uploaded jpg/jpeg/png
    file, decodes it with OpenCV, runs inference and displays the annotated
    image. An upload that cannot be decoded is reported with ``st.error``
    instead of being passed on to the predictor.
    """
    cfg, predictor = initialization()

    uploaded_img = st.file_uploader("Choose an image...", type=['jpg', 'jpeg', 'png'])
    if uploaded_img is None:
        # Nothing uploaded yet on this run of the script.
        return

    file_bytes = np.frombuffer(uploaded_img.read(), dtype=np.uint8)
    # imdecode returns None for data it cannot decode; an empty buffer is
    # treated the same way rather than being handed to OpenCV.
    img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) if file_bytes.size else None
    if img is None:
        st.error("Could not decode the uploaded file as an image.")
        return

    outputs = inference(predictor, img)
    out_image = output_image(cfg, img, outputs)
    st.image(out_image, caption='Processed Image', use_column_width=True)
|
|
|
|
# Run the app only when this file is executed as the main module.
if __name__ == "__main__":
    main()