File size: 1,856 Bytes
403ea04 0999205 403ea04 0999205 403ea04 0999205 403ea04 0999205 403ea04 0999205 403ea04 0999205 403ea04 0999205 403ea04 0999205 403ea04 0999205 403ea04 0999205 403ea04 0999205 403ea04 0999205 403ea04 0999205 403ea04 0999205 403ea04 3e53ece |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import cv2
import numpy as np
import gradio as gr
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
def initialize_model():
    """Build a CPU Detectron2 predictor for the fine-tuned wheat model.

    Registers ``thing_classes=["wheat"]`` metadata under the ``wheat_train``
    and ``wheat_test`` catalog names, then configures a ``DefaultPredictor``
    that loads its weights from ``output/model_final.pth``.

    Returns:
        DefaultPredictor: predictor running on the CPU with a test-time score
        threshold of 0.95.
    """
    # Only metadata is registered here; no dataset is registered, since
    # inference just needs the class names for drawing labels.
    for split in ("train", "test"):
        MetadataCatalog.get("wheat_" + split).set(thing_classes=["wheat"])

    cfg = get_cfg()
    # NOTE(review): no model-zoo config file is merged into ``cfg`` (e.g. via
    # ``cfg.merge_from_file(model_zoo.get_config_file(...))``), so the
    # architecture is Detectron2's default config. Confirm that this matches
    # how ``output/model_final.pth`` was trained.
    cfg.MODEL.DEVICE = "cpu"
    cfg.DATALOADER.NUM_WORKERS = 0

    # Solver / ROI sampling settings carried over unchanged from the original
    # configuration.
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.00025
    cfg.SOLVER.STEPS = []
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    # NOTE(review): two classes are configured although only one class name
    # ("wheat") is registered above; presumably this mirrors the training
    # setup. Verify before changing.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2

    # The fine-tuned checkpoint is the only weight source. The earlier
    # assignment of the COCO model-zoo checkpoint URL was always overwritten
    # by this one, so it has been dropped.
    cfg.MODEL.WEIGHTS = "output/model_final.pth"
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.95

    return DefaultPredictor(cfg)
def process_image(predictor, img):
    """Run detection on ``img`` and return the image with predictions drawn.

    Args:
        predictor: Callable invoked as ``predictor(img)`` whose result has an
            ``"instances"`` entry, such as the object returned by
            ``initialize_model``.
        img: Image array indexed as ``[row, column, channel]``. Presumably in
            BGR channel order: the channels are reversed before the array is
            handed to the ``Visualizer`` -- TODO confirm against callers.

    Returns:
        The rendered image produced by the ``Visualizer``, in the channel
        order the ``Visualizer`` was given (the reverse of ``img``'s order),
        drawn at 1.5x scale.
    """
    # Imported locally so this function does not depend on a change to the
    # file-level import list.
    from detectron2.utils.visualizer import ColorMode

    outputs = predictor(img)
    wheat_metadata = MetadataCatalog.get("wheat_train")

    visualizer = Visualizer(
        img[:, :, ::-1],
        metadata=wheat_metadata,
        scale=1.5,
        # The Visualizer compares this value against ColorMode members, so
        # the bare string "segmentation" was never recognised as that mode.
        instance_mode=ColorMode.SEGMENTATION,
    )
    rendered = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))

    # The original reversed the channels and then applied COLOR_BGR2RGB,
    # i.e. swapped them twice; the net result is the image as rendered.
    return rendered.get_image()
# Predictor shared by all requests; created lazily on the first call.
_PREDICTOR = None


def _get_predictor():
    """Return the shared predictor, building it the first time it is needed."""
    global _PREDICTOR
    if _PREDICTOR is None:
        _PREDICTOR = initialize_model()
    return _PREDICTOR


def main(img):
    """Gradio callback: detect wheat heads in ``img`` and return the overlay.

    Args:
        img: Image supplied by the Gradio ``"image"`` input, which Gradio
            documents as an RGB NumPy array; ``None`` when no image was
            submitted.

    Returns:
        The annotated image from ``process_image``, or ``None`` when ``img``
        is ``None``.
    """
    if img is None:
        # Nothing was uploaded; leave the output empty instead of failing
        # inside the predictor.
        return None

    # Gradio delivers RGB, while process_image reverses the channels before
    # drawing and so expects the opposite order. Convert once here so both
    # the model input and the returned overlay have the right colours.
    bgr_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    # Reuse one predictor across requests instead of rebuilding the model
    # (and reloading its weights) for every image.
    return process_image(_get_predictor(), bgr_img)
# Build the Gradio UI around ``main`` (one image in, one image out) and start
# serving it immediately with a public share link. ``iface`` is bound to
# whatever ``launch`` returns, as in the original code.
# NOTE(review): ``port`` is not one of the documented ``gr.Interface``
# arguments; the documented way to choose a port is
# ``launch(server_port=...)``. It is kept here unchanged because moving it
# would change which port the app binds to -- confirm the intended port.
iface = gr.Interface(
    fn=main,
    inputs="image",
    outputs="image",
    title="Wheat head Detector & Counting Wheat heads",
    cache_examples=False,
    port=7861,
).launch(share=True)