Spaces:
Runtime error
Runtime error
Commit ·
e8f4d7e
1
Parent(s): 2482ba4
Inference config
Browse files- inference.py +9 -6
- scripts/inferEval.py +34 -0
- scripts/infer_frames.py +22 -10
inference.py
CHANGED
|
@@ -24,9 +24,12 @@ WEIGHTS = 'models/v5m_896_300best.pt'
|
|
| 24 |
# will need to configure these based on GPU hardware
|
| 25 |
BATCH_SIZE = 32
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
| 30 |
###
|
| 31 |
|
| 32 |
def norm(bbox, w, h):
|
|
@@ -131,7 +134,7 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE):
|
|
| 131 |
|
| 132 |
return inference, width, height
|
| 133 |
|
| 134 |
-
def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BATCH_SIZE):
|
| 135 |
"""
|
| 136 |
Args:
|
| 137 |
frames_dir: a directory containing frames to be evaluated
|
|
@@ -177,7 +180,7 @@ def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BAT
|
|
| 177 |
|
| 178 |
return all_preds, real_width, real_height
|
| 179 |
|
| 180 |
-
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None):
|
| 181 |
|
| 182 |
if (gp): gp(0, "Tracking...")
|
| 183 |
|
|
@@ -188,7 +191,7 @@ def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None):
|
|
| 188 |
'image_meter_width': image_meter_width,
|
| 189 |
'image_meter_height': image_meter_height
|
| 190 |
}
|
| 191 |
-
tracker = Tracker(clip_info, args={ 'max_age':
|
| 192 |
|
| 193 |
# Run tracking
|
| 194 |
with tqdm(total=len(all_preds), desc="Running tracking", ncols=0) as pbar:
|
|
|
|
| 24 |
# will need to configure these based on GPU hardware
|
| 25 |
BATCH_SIZE = 32
|
| 26 |
|
| 27 |
+
CONF_THRES = 0.3 # detection
|
| 28 |
+
NMS_IOU = 0.3 # NMS IOU
|
| 29 |
+
MIN_LENGTH = 0.3 # minimum fish length, in meters
|
| 30 |
+
MAX_AGE = 20 # time until missing fish get's new id
|
| 31 |
+
IOU_THRES = 0.01 # IOU threshold for tracking
|
| 32 |
+
MIN_HITS = 11 # minimum number of frames with a specific fish for it to count
|
| 33 |
###
|
| 34 |
|
| 35 |
def norm(bbox, w, h):
|
|
|
|
| 134 |
|
| 135 |
return inference, width, height
|
| 136 |
|
| 137 |
+
def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU):
|
| 138 |
"""
|
| 139 |
Args:
|
| 140 |
frames_dir: a directory containing frames to be evaluated
|
|
|
|
| 180 |
|
| 181 |
return all_preds, real_width, real_height
|
| 182 |
|
| 183 |
+
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH):
|
| 184 |
|
| 185 |
if (gp): gp(0, "Tracking...")
|
| 186 |
|
|
|
|
| 191 |
'image_meter_width': image_meter_width,
|
| 192 |
'image_meter_height': image_meter_height
|
| 193 |
}
|
| 194 |
+
tracker = Tracker(clip_info, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
|
| 195 |
|
| 196 |
# Run tracking
|
| 197 |
with tqdm(total=len(all_preds), desc="Running tracking", ncols=0) as pbar:
|
scripts/inferEval.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import project_path
|
| 2 |
+
import argparse
|
| 3 |
+
from infer_frames import main as infer
|
| 4 |
+
import sys
|
| 5 |
+
sys.path.append('..')
|
| 6 |
+
sys.path.append('../caltech-fish-counting')
|
| 7 |
+
|
| 8 |
+
from evaluate import evaluate
|
| 9 |
+
|
| 10 |
+
class Object(object):
|
| 11 |
+
pass
|
| 12 |
+
|
| 13 |
+
def main(args):
|
| 14 |
+
|
| 15 |
+
infer_args = Object()
|
| 16 |
+
infer_args.metadata = "../caltech-fish-counting/data/metadata"
|
| 17 |
+
infer_args.frames = "../caltech-fish-counting/data/images"
|
| 18 |
+
infer_args.output = "../caltech-fish-counting/data/result"
|
| 19 |
+
infer_args.weights = "models/v5m_896_300best.pt"
|
| 20 |
+
infer_args.config = args.config
|
| 21 |
+
|
| 22 |
+
infer(infer_args)
|
| 23 |
+
|
| 24 |
+
evaluate("../frames/result_testing", "../frames/MOT", "../frames/metadata", "tracker", True)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def argument_parser():
|
| 28 |
+
parser = argparse.ArgumentParser()
|
| 29 |
+
parser.add_argument("--config", required=True, help="Config object. Required.")
|
| 30 |
+
return parser
|
| 31 |
+
|
| 32 |
+
if __name__ == "__main__":
|
| 33 |
+
args = argument_parser().parse_args()
|
| 34 |
+
main(args)
|
scripts/infer_frames.py
CHANGED
|
@@ -26,9 +26,20 @@ def main(args):
|
|
| 26 |
print("In task...")
|
| 27 |
print("Cuda available in task?", torch.cuda.is_available())
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
dirname = args.frames
|
| 30 |
|
| 31 |
-
locations = ["
|
| 32 |
for loc in locations:
|
| 33 |
|
| 34 |
in_loc_dir = os.path.join(dirname, loc)
|
|
@@ -39,6 +50,9 @@ def main(args):
|
|
| 39 |
print(out_dir)
|
| 40 |
print(metadata_path)
|
| 41 |
|
|
|
|
|
|
|
|
|
|
| 42 |
seq_list = os.listdir(in_loc_dir)
|
| 43 |
idx = 1
|
| 44 |
for seq in seq_list:
|
|
@@ -47,11 +61,11 @@ def main(args):
|
|
| 47 |
print(" ")
|
| 48 |
idx += 1
|
| 49 |
in_seq_dir = os.path.join(in_loc_dir, seq)
|
| 50 |
-
infer_seq(in_seq_dir, out_dir, seq,
|
| 51 |
|
| 52 |
-
def infer_seq(in_dir, out_dir, seq_name,
|
| 53 |
|
| 54 |
-
|
| 55 |
|
| 56 |
image_meter_width = -1
|
| 57 |
image_meter_height = -1
|
|
@@ -68,21 +82,18 @@ def infer_seq(in_dir, out_dir, seq_name, weights, metadata_path):
|
|
| 68 |
|
| 69 |
# create dataloader
|
| 70 |
dataloader = create_dataloader_frames_only(in_dir)
|
| 71 |
-
|
| 72 |
-
# run detection + tracking
|
| 73 |
-
model, device = setup_model(weights)
|
| 74 |
|
| 75 |
try:
|
| 76 |
-
inference, width, height = do_detection(dataloader, model, device
|
| 77 |
except:
|
| 78 |
print("Error in " + seq_name)
|
| 79 |
with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), 'w') as f:
|
| 80 |
f.write("ERROR")
|
| 81 |
return
|
| 82 |
|
| 83 |
-
all_preds, real_width, real_height = do_suppression(dataloader, inference, width, height,
|
| 84 |
|
| 85 |
-
results = do_tracking(all_preds, image_meter_width, image_meter_height,
|
| 86 |
|
| 87 |
mot_rows = []
|
| 88 |
for frame in results['frames']:
|
|
@@ -118,6 +129,7 @@ def argument_parser():
|
|
| 118 |
parser.add_argument("--frames", required=True, help="Path to frame directory. Required.")
|
| 119 |
parser.add_argument("--metadata", required=True, help="Path to metadata directory. Required.")
|
| 120 |
parser.add_argument("--output", required=True, help="Path to output directory. Required.")
|
|
|
|
| 121 |
parser.add_argument("--weights", default='models/v5m_896_300best.pt', help="Path to saved YOLOv5 weights. Default: ../models/v5m_896_300best.pt")
|
| 122 |
return parser
|
| 123 |
|
|
|
|
| 26 |
print("In task...")
|
| 27 |
print("Cuda available in task?", torch.cuda.is_available())
|
| 28 |
|
| 29 |
+
# setup config
|
| 30 |
+
config = json.loads(args.config)
|
| 31 |
+
if "conf_threshold" not in config: config['conf_threshold'] = 0.3
|
| 32 |
+
if "nms_iou" not in config: config['nms_iou'] = 0.3
|
| 33 |
+
if "min_length" not in config: config['min_length'] = 0.3
|
| 34 |
+
if "max_age" not in config: config['max_age'] = 20
|
| 35 |
+
if "iou_threshold" not in config: config['iou_threshold'] = 0.01
|
| 36 |
+
if "min_hits" not in config: config['min_hits'] = 11
|
| 37 |
+
|
| 38 |
+
print(config)
|
| 39 |
+
|
| 40 |
dirname = args.frames
|
| 41 |
|
| 42 |
+
locations = ["kenai-val"]
|
| 43 |
for loc in locations:
|
| 44 |
|
| 45 |
in_loc_dir = os.path.join(dirname, loc)
|
|
|
|
| 50 |
print(out_dir)
|
| 51 |
print(metadata_path)
|
| 52 |
|
| 53 |
+
# run detection + tracking
|
| 54 |
+
model, device = setup_model(args.weights)
|
| 55 |
+
|
| 56 |
seq_list = os.listdir(in_loc_dir)
|
| 57 |
idx = 1
|
| 58 |
for seq in seq_list:
|
|
|
|
| 61 |
print(" ")
|
| 62 |
idx += 1
|
| 63 |
in_seq_dir = os.path.join(in_loc_dir, seq)
|
| 64 |
+
infer_seq(in_seq_dir, out_dir, config, seq, model, device, metadata_path)
|
| 65 |
|
| 66 |
+
def infer_seq(in_dir, out_dir, config, seq_name, model, device, metadata_path):
|
| 67 |
|
| 68 |
+
#progress_log = lambda p, m: 0
|
| 69 |
|
| 70 |
image_meter_width = -1
|
| 71 |
image_meter_height = -1
|
|
|
|
| 82 |
|
| 83 |
# create dataloader
|
| 84 |
dataloader = create_dataloader_frames_only(in_dir)
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
try:
|
| 87 |
+
inference, width, height = do_detection(dataloader, model, device)
|
| 88 |
except:
|
| 89 |
print("Error in " + seq_name)
|
| 90 |
with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), 'w') as f:
|
| 91 |
f.write("ERROR")
|
| 92 |
return
|
| 93 |
|
| 94 |
+
all_preds, real_width, real_height = do_suppression(dataloader, inference, width, height, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'])
|
| 95 |
|
| 96 |
+
results = do_tracking(all_preds, image_meter_width, image_meter_height, min_length=config['min_length'], max_age=config['max_age'], iou_thres=config['iou_threshold'], min_hits=config['min_hits'])
|
| 97 |
|
| 98 |
mot_rows = []
|
| 99 |
for frame in results['frames']:
|
|
|
|
| 129 |
parser.add_argument("--frames", required=True, help="Path to frame directory. Required.")
|
| 130 |
parser.add_argument("--metadata", required=True, help="Path to metadata directory. Required.")
|
| 131 |
parser.add_argument("--output", required=True, help="Path to output directory. Required.")
|
| 132 |
+
parser.add_argument("--config", default="{}", help="Config object. Required.")
|
| 133 |
parser.add_argument("--weights", default='models/v5m_896_300best.pt', help="Path to saved YOLOv5 weights. Default: ../models/v5m_896_300best.pt")
|
| 134 |
return parser
|
| 135 |
|