Arulkumar03 commited on
Commit
0999205
·
1 Parent(s): f9d01f3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2,os
2
+ import numpy as np
3
+ import streamlit as st
4
+ from detectron2 import utils
5
+ from detectron2.engine import DefaultTrainer
6
+ from detectron2.config import get_cfg
7
+ from detectron2.utils import comm
8
+ from detectron2.utils.logger import setup_logger
9
+ # import some common libraries
10
+ import numpy as np
11
+ import os, json, cv2, random
12
+ #from google.colab.patches import cv2_imshow
13
+ import warnings
14
+ warnings.filterwarnings('ignore')
15
+ # import some common detectron2 utilities
16
+ from detectron2 import model_zoo
17
+ from detectron2.engine import DefaultPredictor
18
+ from detectron2.config import get_cfg
19
+ from detectron2.utils.visualizer import Visualizer
20
+ from detectron2.data import MetadataCatalog, DatasetCatalog
21
+ from detectron2.structures import BoxMode
22
+ from detectron2.utils.visualizer import ColorMode
23
+ import matplotlib.pyplot as plt
24
+
25
+
26
+ @st.cache(persist=True)
27
+ def initialization():
28
+ """Loads configuration and model for the prediction.
29
+
30
+ Returns:
31
+ cfg (detectron2.config.config.CfgNode): Configuration for the model.
32
+ predictor (detectron2.engine.defaults.DefaultPredicto): Model to use.
33
+ by the model.
34
+
35
+ """
36
+ for d in ["train", "test"]:
37
+ #DatasetCatalog.register("Animals_" + d, lambda d=d: get_wheat_dicts("Animal_Detection/" + d))
38
+ MetadataCatalog.get("wheat_" + d).set(thing_classes=["wheat"])
39
+
40
+ wheat_metadata = MetadataCatalog.get("wheat_train")
41
+ cfg = get_cfg()
42
+
43
+ cfg.MODEL.DEVICE = "cpu"
44
+ cfg.DATALOADER.NUM_WORKERS = 0
45
+ cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x.yaml") # Let training initialize from model zoo
46
+ cfg.SOLVER.IMS_PER_BATCH = 2
47
+ cfg.SOLVER.BASE_LR = 0.00025 # pick a good LR
48
+ #cfg.SOLVER.MAX_ITER =3000 # 300 iterations seems good enough for this toy dataset; you will need to train longer for a practical dataset
49
+ cfg.SOLVER.STEPS = [] # do not decay learning rate
50
+ cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128 # faster, and good enough for this toy dataset (default: 512)
51
+ cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2 # only has one class (wheat).
52
+
53
+ # Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
54
+ cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth") # path to the model we just trained
55
+
56
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.95 # set a custom testing threshold
57
+
58
+ # Initialize prediction model
59
+ predictor = DefaultPredictor(cfg)
60
+
61
+ return cfg, predictor
62
+
63
+
64
+ @st.cache
65
+ def inference(predictor, img):
66
+ return predictor(img)
67
+
68
+
69
+ @st.cache
70
+ def output_image(cfg, img, outputs):
71
+
72
+ wheat_metadata = MetadataCatalog.get("wheat_train")
73
+ v = Visualizer(img[:, :, ::-1],
74
+ metadata=wheat_metadata,
75
+ scale=1.5,
76
+ instance_mode=ColorMode.SEGMENTATION)
77
+ out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
78
+ processed_img = cv2.cvtColor((out.get_image()[:, :, ::-1]), cv2.COLOR_BGR2RGB)
79
+
80
+ return processed_img
81
+
82
+
83
+ def main():
84
+ # Initialization
85
+ cfg, predictor = initialization()
86
+
87
+ # Retrieve image
88
+ uploaded_img = st.file_uploader("Choose an image...", type=['jpg', 'jpeg', 'png'])
89
+ if uploaded_img is not None:
90
+ file_bytes = np.asarray(bytearray(uploaded_img.read()), dtype=np.uint8)
91
+ img = cv2.imdecode(file_bytes, 1)
92
+ # Detection code
93
+ outputs = inference(predictor, img)
94
+ out_image = output_image(cfg, img, outputs)
95
+ st.image(out_image, caption='Processed Image', use_column_width=True)
96
+
97
+
98
+ if __name__ == '__main__':
99
+ main()