File size: 1,856 Bytes
403ea04
0999205
403ea04
0999205
 
403ea04
0999205
403ea04
0999205
403ea04
0999205
403ea04
0999205
 
403ea04
0999205
 
 
403ea04
0999205
403ea04
 
 
 
 
 
0999205
403ea04
0999205
403ea04
 
0999205
 
 
 
403ea04
0999205
403ea04
0999205
 
403ea04
 
 
 
0999205
 
403ea04
 
 
 
 
3e53ece
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import cv2
import numpy as np
import gradio as gr
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog

def initialize_model(
    weights_path="output/model_final.pth",
    score_thresh=0.95,
    device="cpu",
):
    """Build a detectron2 ``DefaultPredictor`` for the fine-tuned wheat-head model.

    Registers ``thing_classes=["wheat"]`` metadata under ``wheat_train`` and
    ``wheat_test`` (the Visualizer in this app reads ``wheat_train`` for
    labels), then builds an inference-only config and loads the checkpoint.

    Args:
        weights_path: Path to the fine-tuned checkpoint. Defaults to the
            previously hard-coded ``output/model_final.pth``.
        score_thresh: Minimum detection score kept at test time
            (``ROI_HEADS.SCORE_THRESH_TEST``).
        device: Device string used for inference (``MODEL.DEVICE``).

    Returns:
        A ``DefaultPredictor`` built from the resulting config.
    """
    for split in ("train", "test"):
        MetadataCatalog.get("wheat_" + split).set(thing_classes=["wheat"])

    cfg = get_cfg()
    # The original code pointed at this model-zoo config's checkpoint as the
    # base weights, so the matching architecture is loaded from the same
    # config file; get_cfg() alone describes a different default model.
    # NOTE(review): assumes the checkpoint at weights_path was trained with
    # this config (as in the detectron2 fine-tuning tutorial) — confirm.
    base_config = "COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x.yaml"
    cfg.merge_from_file(model_zoo.get_config_file(base_config))

    cfg.MODEL.DEVICE = device
    cfg.DATALOADER.NUM_WORKERS = 0
    # NOTE(review): two classes here but only one name ("wheat") is registered
    # in the metadata. This must match the checkpoint's head; if class index 1
    # is ever predicted, labelling it in the Visualizer may fail — confirm
    # against the training setup.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    # Fine-tuned weights; training-only SOLVER / sampling settings are not
    # needed for inference and were dropped.
    cfg.MODEL.WEIGHTS = weights_path
    return DefaultPredictor(cfg)

def process_image(predictor, img):
    """Run ``predictor`` on ``img`` and draw its instance predictions.

    Args:
        predictor: Callable (e.g. the ``DefaultPredictor`` from
            ``initialize_model``) returning a dict with an ``"instances"`` entry.
        img: Image array in BGR channel order; its channel-reversed view is
            what the Visualizer draws on.

    Returns:
        The annotated image from ``Visualizer.get_image()``, drawn at 1.5x scale.
    """
    outputs = predictor(img)
    visualizer = Visualizer(
        img[:, :, ::-1],  # BGR -> RGB for drawing
        metadata=MetadataCatalog.get("wheat_train"),
        scale=1.5,
        # NOTE(review): detectron2 documents instance_mode as a ColorMode
        # enum; confirm the plain string is interpreted as intended.
        instance_mode="segmentation",
    )
    drawn = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))
    # The previous code reversed the channels and then applied
    # cvtColor(BGR2RGB): two reversals that cancel out. Returning the image
    # directly yields the same pixels without the extra copies.
    return drawn.get_image()

# Predictor built lazily on the first request and reused afterwards.
_PREDICTOR = None


def main(img):
    """Gradio handler: detect wheat heads in ``img`` and return the annotated image.

    The predictor is built once, on the first call, instead of reloading the
    checkpoint for every request. Two concurrent first calls may each build
    one; the last assignment wins, which is harmless.

    Args:
        img: RGB image array as delivered by Gradio's numpy image input, or
            ``None`` when the user submits without an image.

    Returns:
        The annotated image, or ``None`` if no image was provided.
    """
    global _PREDICTOR
    if img is None:
        return None
    if _PREDICTOR is None:
        _PREDICTOR = initialize_model()
    # Gradio supplies RGB; process_image (and the predictor) expect BGR.
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return process_image(_PREDICTOR, bgr)


iface = gr.Interface(
    fn=main,
    inputs="image",
    outputs="image",
    title="Wheat head Detector & Counting Wheat heads",
    cache_examples=False,
)
# gr.Interface has no `port` parameter; the server port is chosen by
# launch(server_port=...). share=True requests a public Gradio share link.
iface.launch(share=True, server_port=7861)