| | import cv2 |
| | import numpy as np |
| | import gradio as gr |
| | from detectron2 import model_zoo |
| | from detectron2.config import get_cfg |
| | from detectron2.engine import DefaultPredictor |
| | from detectron2.utils.visualizer import Visualizer |
| | from detectron2.data import MetadataCatalog |
| |
|
def initialize_model(
    weights_path="output/model_final.pth",
    score_thresh=0.95,
    device="cpu",
):
    """Register the wheat dataset metadata and build an inference predictor.

    Args:
        weights_path: Path of the fine-tuned checkpoint to load.
        score_thresh: Minimum detection score kept at test time.
        device: Torch device string the model runs on.

    Returns:
        A detectron2 ``DefaultPredictor``; it expects images in
        ``cfg.INPUT.FORMAT`` channel order (BGR unless the config changes it).
    """
    # Register the single "wheat" class name for both splits so the
    # Visualizer can label detections.
    for split in ("train", "test"):
        MetadataCatalog.get("wheat_" + split).set(thing_classes=["wheat"])

    cfg = get_cfg()
    # The architecture must match the checkpoint. Without this merge,
    # get_cfg() defaults would build a different network than the
    # mask_rcnn_R_101_C4 model whose checkpoint the setup referenced.
    # NOTE(review): assumes the fine-tune was run from this model-zoo
    # config - confirm against the training script.
    cfg.merge_from_file(
        model_zoo.get_config_file(
            "COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x.yaml"
        )
    )
    cfg.MODEL.DEVICE = device
    # NOTE(review): must match the checkpoint's head size. Only one
    # thing_class name is registered above, so a class-1 prediction would
    # have no label - confirm.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    return DefaultPredictor(cfg)
| |
|
def process_image(predictor, img, scale=1.5):
    """Run *predictor* on a BGR image and return an RGB visualisation.

    Args:
        predictor: Callable returning a dict with an ``"instances"`` entry,
            e.g. the ``DefaultPredictor`` built by ``initialize_model``.
        img: Image array of shape (H, W, 3) in BGR channel order.
        scale: Resize factor the Visualizer applies to the drawn output.

    Returns:
        The annotated image as an RGB array.
    """
    # Local import keeps this block self-contained. The Visualizer compares
    # instance_mode against the ColorMode enum, so a plain string such as
    # "segmentation" never matches and is silently ignored.
    from detectron2.utils.visualizer import ColorMode

    outputs = predictor(img)
    visualizer = Visualizer(
        img[:, :, ::-1],  # Visualizer wants RGB; flip the BGR input.
        metadata=MetadataCatalog.get("wheat_train"),
        scale=scale,
        instance_mode=ColorMode.SEGMENTATION,
    )
    drawn = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))
    # get_image() already returns RGB. The old [:, :, ::-1] + BGR2RGB pair
    # cancelled out and only cost an extra copy.
    return drawn.get_image()
| |
|
# Predictor shared across requests. Building one loads the checkpoint from
# disk, so it is created lazily once instead of on every image.
_PREDICTOR = None


def _get_predictor():
    """Return the shared predictor, building it on first use."""
    global _PREDICTOR
    if _PREDICTOR is None:
        _PREDICTOR = initialize_model()
    return _PREDICTOR


def main(img):
    """Gradio handler: detect wheat heads in *img* and return the annotation.

    Args:
        img: RGB image array from Gradio's ``"image"`` input, or ``None``
            when the form is submitted without an image.

    Returns:
        The annotated RGB image, or ``None`` when no image was given.
    """
    if img is None:
        return None
    # Gradio hands over RGB, while process_image (and the predictor's
    # default INPUT.FORMAT) treat their input as BGR.
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return process_image(_get_predictor(), bgr)
| |
|
| |
|
# Web UI: one image in, the annotated image out.
iface = gr.Interface(
    fn=main,
    inputs="image",
    outputs="image",
    title="Wheat head Detector & Counting Wheat heads",
    cache_examples=False,
)

# Start the server only when run as a script, so importing this module
# (tests, Gradio reload mode) does not block on launch().
if __name__ == "__main__":
    iface.launch()