Arulkumar03's picture
Upload app.py
0999205
raw
history blame
3.63 kB
import cv2,os
import numpy as np
import streamlit as st
from detectron2 import utils
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
from detectron2.utils import comm
from detectron2.utils.logger import setup_logger
# import some common libraries
import numpy as np
import os, json, cv2, random
#from google.colab.patches import cv2_imshow
import warnings
warnings.filterwarnings('ignore')
# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.structures import BoxMode
from detectron2.utils.visualizer import ColorMode
import matplotlib.pyplot as plt
@st.cache(persist=True)
def initialization():
    """Build the inference configuration and the predictor.

    Sets ``thing_classes=["wheat"]`` on the "wheat_train" and "wheat_test"
    metadata entries, creates a CPU detectron2 config whose weights are read
    from ``model_final.pth`` inside ``cfg.OUTPUT_DIR``, and wraps that config
    in a ``DefaultPredictor``.

    Returns:
        tuple: ``(cfg, predictor)`` -- the ``CfgNode`` used to build the model
        and the ``DefaultPredictor`` created from it.
    """
    # Class names are looked up through this metadata when drawing results.
    for split in ("train", "test"):
        MetadataCatalog.get("wheat_" + split).set(thing_classes=["wheat"])

    # NOTE(review): no model-zoo config file is merged here, so the network
    # architecture is whatever get_cfg() defaults to -- confirm this matches
    # the setup model_final.pth was trained with.
    cfg = get_cfg()
    cfg.MODEL.DEVICE = "cpu"
    cfg.DATALOADER.NUM_WORKERS = 0

    # Hyperparameters carried over from the training script; kept so the
    # returned cfg has the same contents as before.
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.00025
    cfg.SOLVER.STEPS = []  # do not decay learning rate
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # default: 512

    # NOTE(review): two classes are configured although only one class name
    # ("wheat") is registered above. The value has to match the checkpoint,
    # so it is left as is -- confirm against the training run.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2

    # Weights of the fine-tuned model, expected under cfg.OUTPUT_DIR.
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
    # Only detections scoring at least this value are kept.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.95

    predictor = DefaultPredictor(cfg)
    return cfg, predictor
@st.cache
def inference(predictor, img):
    """Run the predictor on a single image.

    Args:
        predictor: Callable model, invoked as ``predictor(img)``.
        img: Image handed to the predictor unchanged.

    Returns:
        Whatever ``predictor(img)`` returns.
    """
    outputs = predictor(img)
    return outputs
@st.cache
def output_image(cfg, img, outputs):
    """Draw the predicted instances on top of the input image.

    Args:
        cfg: Not used by this function; kept so existing callers still work.
        img: Image array; its last (channel) axis is reversed before drawing
            (presumably BGR -> RGB for images decoded with OpenCV -- confirm
            against the caller).
        outputs: Prediction dict; ``outputs["instances"]`` is moved to the
            CPU and drawn.

    Returns:
        The array returned by the visualizer's ``get_image()`` for the
        annotated picture, drawn at 1.5x scale.
    """
    wheat_metadata = MetadataCatalog.get("wheat_train")
    visualizer = Visualizer(
        img[:, :, ::-1],
        metadata=wheat_metadata,
        scale=1.5,
        instance_mode=ColorMode.SEGMENTATION,
    )
    rendered = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))
    # Reversing the channel axis and then applying cv2.COLOR_BGR2RGB would
    # swap the first and last channels twice, i.e. give back the same pixel
    # values, so the rendered array is returned directly.
    return rendered.get_image()
def main():
    """Streamlit entry point: upload an image, run detection, show the result.

    Nothing is displayed until a file has been uploaded. An upload that is
    empty or cannot be decoded as an image is reported with ``st.error``
    instead of being passed on to the model.
    """
    # Initialization
    cfg, predictor = initialization()

    # Retrieve image
    uploaded_img = st.file_uploader("Choose an image...", type=['jpg', 'jpeg', 'png'])
    if uploaded_img is None:
        return

    file_bytes = np.frombuffer(uploaded_img.read(), dtype=np.uint8)
    # cv2.imdecode gives None when the buffer is not a readable image, and an
    # empty buffer is rejected outright, so both cases are handled up front.
    img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) if file_bytes.size else None
    if img is None:
        st.error("Could not decode the uploaded file as an image.")
        return

    # Detection code
    outputs = inference(predictor, img)
    out_image = output_image(cfg, img, outputs)
    st.image(out_image, caption='Processed Image', use_column_width=True)
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    main()