File size: 1,901 Bytes
a45382c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a2e08c
a45382c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import cv2
import gradio as gr
import numpy as np
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer

def initialize_model(weights_path="output/model_final.pth", score_thresh=0.95, device="cpu"):
    """Register the fox/sheep metadata and build a detectron2 ``DefaultPredictor``.

    The network architecture is taken from the model-zoo config
    ``COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x.yaml``. The previous
    version pointed ``MODEL.WEIGHTS`` at this config's COCO checkpoint before
    overriding it, which indicates the fine-tuned weights were trained on this
    architecture. NOTE(review): confirm against the training script.

    Args:
        weights_path: Path of the fine-tuned checkpoint to load.
        score_thresh: Minimum confidence for a detection to be kept at test time.
        device: Device string the model runs on (``"cpu"`` by default).

    Returns:
        A ``DefaultPredictor`` built from the configured ``cfg``.
    """
    config_name = "COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x.yaml"
    class_names = ["fox", "sheep"]

    # The Visualizer reads class names from this metadata; both splits share
    # the same label set.
    for split in ("train", "test"):
        MetadataCatalog.get("Animals_" + split).set(thing_classes=class_names)

    cfg = get_cfg()
    # Without merging the architecture config, get_cfg() returns detectron2's
    # generic defaults, which do not describe the R101-C4 mask model.
    # Overrides must come after this merge.
    cfg.merge_from_file(model_zoo.get_config_file(config_name))
    cfg.MODEL.DEVICE = device
    # Keep the head size tied to the label list above.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(class_names)
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    return DefaultPredictor(cfg)

def process_image(predictor, img):
    """Run *predictor* on one image and return the annotated visualization.

    Args:
        predictor: Callable returning a dict with an ``"instances"`` entry
            (e.g. the ``DefaultPredictor`` from ``initialize_model``).
        img: ``(H, W, 3)`` image array in BGR channel order. It is fed to the
            predictor as-is and flipped to RGB for the Visualizer.

    Returns:
        The annotated image as an RGB array, drawn at 1.5x scale.
    """
    outputs = predictor(img)
    metadata = MetadataCatalog.get("Animals_train")
    visualizer = Visualizer(
        img[:, :, ::-1],  # Visualizer expects RGB; the input is BGR.
        metadata=metadata,
        scale=1.5,
        # Visualizer compares instance_mode against the ColorMode enum, so the
        # former string "segmentation" never matched and was silently ignored.
        instance_mode=ColorMode.SEGMENTATION,
    )
    drawn = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))
    # get_image() already returns RGB. The former [:, :, ::-1] followed by
    # COLOR_BGR2RGB undid itself, so it has been removed.
    return drawn.get_image()

# Predictor shared across requests. Building it constructs a DefaultPredictor
# from the checkpoint, so this should happen once per process, not per image.
_PREDICTOR = None


def _get_predictor():
    """Return the shared predictor, building it on first use."""
    global _PREDICTOR
    if _PREDICTOR is None:
        _PREDICTOR = initialize_model()
    return _PREDICTOR


def main(img):
    """Gradio callback: detect foxes/sheep in *img* and return the annotated image.

    Args:
        img: Image from the Gradio ``"image"`` input: an ``(H, W, 3)`` RGB
            array, or ``None`` when the user submits without an image.

    Returns:
        The annotated RGB image, or ``None`` if no image was provided.
    """
    if img is None:
        return None
    # Gradio supplies RGB, but process_image treats its input as BGR: it feeds
    # it straight to the predictor and flips it before visualizing.
    bgr = np.ascontiguousarray(img[:, :, ::-1])
    return process_image(_get_predictor(), bgr)


# Web UI: one image in, one annotated image out.
# input_size/output_size are not gr.Interface parameters. Depending on the
# Gradio version they are ignored with a warning or rejected outright, so they
# are no longer passed.
iface = gr.Interface(
    fn=main,
    inputs="image",
    outputs="image",
    title="Fox & Sheep Computer Vision detector",
    cache_examples=False,
)

# Launch only when run as a script, so importing this module has no server
# side effect.
if __name__ == "__main__":
    iface.launch()