Spaces:
Runtime error
Runtime error
Commit ·
17fa97d
1
Parent(s): b02b8a0
Hyperparameter settings
Browse files- app.py +29 -9
- gradio_scripts/upload_ui.py +15 -1
- inference.py +4 -4
- main.py +2 -5
- visualizer.py +1 -1
app.py
CHANGED
|
@@ -27,16 +27,24 @@ result = {}
|
|
| 27 |
|
| 28 |
|
| 29 |
# Called when an Aris file is uploaded for inference
|
| 30 |
-
def on_aris_input(file_list, model_id):
|
| 31 |
-
|
| 32 |
-
print(model_id)
|
| 33 |
-
print(models[model_id] if model_id in models else models['master'])
|
| 34 |
|
| 35 |
# Reset Result
|
| 36 |
reset_state(result, state)
|
| 37 |
state['files'] = file_list
|
| 38 |
state['total'] = len(file_list)
|
| 39 |
-
state['
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
# Update loading_space to start inference on first file
|
| 42 |
return {
|
|
@@ -130,8 +138,10 @@ def infer_next(_, progress=gr.Progress()):
|
|
| 130 |
file_name = file_info[0].split("/")[-1]
|
| 131 |
bytes = file_info[1]
|
| 132 |
valid, file_path, dir_name = save_data(bytes, file_name)
|
| 133 |
-
|
| 134 |
-
print(
|
|
|
|
|
|
|
| 135 |
|
| 136 |
# Check that the file was valid
|
| 137 |
if not valid:
|
|
@@ -143,8 +153,18 @@ def infer_next(_, progress=gr.Progress()):
|
|
| 143 |
# Send uploaded file to AWS
|
| 144 |
upload_file(file_path, "fishcounting", "webapp_uploads/" + file_name)
|
| 145 |
|
|
|
|
|
|
|
| 146 |
# Do inference
|
| 147 |
-
json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
# Store result for that file
|
| 150 |
result['json_result'].append(json_result)
|
|
@@ -370,7 +390,7 @@ with demo:
|
|
| 370 |
inference_comps = [inference_handler, master_tabs, components['cancelBtn'], components['skipBtn']]
|
| 371 |
|
| 372 |
# When a file is uploaded to the input, tell the inference_handler to start inference
|
| 373 |
-
input.upload(on_aris_input, [input
|
| 374 |
|
| 375 |
# When inference handler updates, tell result_handler to show the new result
|
| 376 |
# Also, add inference_handler as the output in order to have it display the progress
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
# Called when an Aris file is uploaded for inference
|
| 30 |
+
def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age):
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
# Reset Result
|
| 33 |
reset_state(result, state)
|
| 34 |
state['files'] = file_list
|
| 35 |
state['total'] = len(file_list)
|
| 36 |
+
state['hyperparams'] = {
|
| 37 |
+
'model': models[model_id] if model_id in models else models['master'],
|
| 38 |
+
'conf_thresh': conf_thresh,
|
| 39 |
+
'iou_thresh': iou_thresh,
|
| 40 |
+
'min_hits': min_hits,
|
| 41 |
+
'max_age': max_age,
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
print(" ")
|
| 45 |
+
print("Running with:")
|
| 46 |
+
print(state['hyperparams'])
|
| 47 |
+
print(" ")
|
| 48 |
|
| 49 |
# Update loading_space to start inference on first file
|
| 50 |
return {
|
|
|
|
| 138 |
file_name = file_info[0].split("/")[-1]
|
| 139 |
bytes = file_info[1]
|
| 140 |
valid, file_path, dir_name = save_data(bytes, file_name)
|
| 141 |
+
|
| 142 |
+
print("Directory: ", dir_name)
|
| 143 |
+
print("Aris input: ", file_path)
|
| 144 |
+
print(" ")
|
| 145 |
|
| 146 |
# Check that the file was valid
|
| 147 |
if not valid:
|
|
|
|
| 153 |
# Send uploaded file to AWS
|
| 154 |
upload_file(file_path, "fishcounting", "webapp_uploads/" + file_name)
|
| 155 |
|
| 156 |
+
hyperparams = state['hyperparams']
|
| 157 |
+
|
| 158 |
# Do inference
|
| 159 |
+
json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(
|
| 160 |
+
file_path,
|
| 161 |
+
weights = hyperparams['model'],
|
| 162 |
+
conf_thresh = hyperparams['conf_thresh'],
|
| 163 |
+
iou_thresh = hyperparams['iou_thresh'],
|
| 164 |
+
min_hits = hyperparams['min_hits'],
|
| 165 |
+
max_age = hyperparams['max_age'],
|
| 166 |
+
gradio_progress=set_progress
|
| 167 |
+
)
|
| 168 |
|
| 169 |
# Store result for that file
|
| 170 |
result['json_result'].append(json_result)
|
|
|
|
| 390 |
inference_comps = [inference_handler, master_tabs, components['cancelBtn'], components['skipBtn']]
|
| 391 |
|
| 392 |
# When a file is uploaded to the input, tell the inference_handler to start inference
|
| 393 |
+
input.upload(on_aris_input, [input] + components['hyperparams'], inference_comps)
|
| 394 |
|
| 395 |
# When inference handler updates, tell result_handler to show the new result
|
| 396 |
# Also, add inference_handler as the output in order to have it display the progress
|
gradio_scripts/upload_ui.py
CHANGED
|
@@ -15,7 +15,21 @@ def Upload_Gradio(gradio_components):
|
|
| 15 |
|
| 16 |
gr.HTML("<p align='center' style='font-size: large;font-style: italic;'>Submit an .aris file to analyze result.</p>")
|
| 17 |
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
#Input field for aris submission
|
| 21 |
gradio_components['input'] = File(file_types=[".aris", ".ddf"], type="binary", label="ARIS Input", file_count="multiple")
|
|
|
|
| 15 |
|
| 16 |
gr.HTML("<p align='center' style='font-size: large;font-style: italic;'>Submit an .aris file to analyze result.</p>")
|
| 17 |
|
| 18 |
+
with gr.Accordion("Advanced Settings", open=False):
|
| 19 |
+
settings = []
|
| 20 |
+
settings.append(gr.Dropdown(label="Model", value="master", choices=list(models.keys())))
|
| 21 |
+
|
| 22 |
+
gr.Markdown("Detection Parameters")
|
| 23 |
+
with gr.Row():
|
| 24 |
+
settings.append(gr.Slider(0, 1, value=0.05, label="Confidence Threshold", info="Confidence cutoff for detection boxes"))
|
| 25 |
+
settings.append(gr.Slider(0, 1, value=0.2, label="NMS IoU", info="IoU threshold for non-max suppression"))
|
| 26 |
+
|
| 27 |
+
gr.Markdown("Tracking Parameters")
|
| 28 |
+
with gr.Row():
|
| 29 |
+
settings.append(gr.Slider(0, 100, value=16, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
|
| 30 |
+
settings.append(gr.Slider(0, 100, value=14, label="Max Age", info="Max age of occlusion before track is split"))
|
| 31 |
+
|
| 32 |
+
gradio_components['hyperparams'] = settings
|
| 33 |
|
| 34 |
#Input field for aris submission
|
| 35 |
gradio_components['input'] = File(file_types=[".aris", ".ddf"], type="binary", label="ARIS Input", file_count="multiple")
|
inference.py
CHANGED
|
@@ -48,7 +48,7 @@ def norm(bbox, w, h):
|
|
| 48 |
bb[3] /= h
|
| 49 |
return bb
|
| 50 |
|
| 51 |
-
def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, weights=WEIGHTS):
|
| 52 |
|
| 53 |
model, device = setup_model(weights)
|
| 54 |
|
|
@@ -78,15 +78,15 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
| 78 |
return
|
| 79 |
|
| 80 |
|
| 81 |
-
outputs = do_suppression(inference, gp=gp)
|
| 82 |
|
| 83 |
#do_confidence_boost(inference, outputs, gp=gp)
|
| 84 |
|
| 85 |
-
#new_outputs = do_suppression(inference, gp=gp)
|
| 86 |
|
| 87 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
| 88 |
|
| 89 |
-
results = do_tracking(all_preds, image_meter_width, image_meter_height, gp=gp)
|
| 90 |
|
| 91 |
return results
|
| 92 |
|
|
|
|
| 48 |
bb[3] /= h
|
| 49 |
return bb
|
| 50 |
|
| 51 |
+
def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU, min_hits=MIN_HITS, max_age=MAX_AGE):
|
| 52 |
|
| 53 |
model, device = setup_model(weights)
|
| 54 |
|
|
|
|
| 78 |
return
|
| 79 |
|
| 80 |
|
| 81 |
+
outputs = do_suppression(inference, conf_thres=conf_thresh, iou_thres=nms_iou, gp=gp)
|
| 82 |
|
| 83 |
#do_confidence_boost(inference, outputs, gp=gp)
|
| 84 |
|
| 85 |
+
#new_outputs = do_suppression(inference, conf_thres=conf_thresh, iou_thres=nms_iou, gp=gp)
|
| 86 |
|
| 87 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
| 88 |
|
| 89 |
+
results = do_tracking(all_preds, image_meter_width, image_meter_height, min_hits=min_hits, max_age=max_age, gp=gp)
|
| 90 |
|
| 91 |
return results
|
| 92 |
|
main.py
CHANGED
|
@@ -7,9 +7,7 @@ from dataloader import create_dataloader_aris
|
|
| 7 |
from inference import do_full_inference, json_dump_round_float
|
| 8 |
from visualizer import generate_video_batches
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def predict_task(filepath, weights=WEIGHTS, gradio_progress=None):
|
| 13 |
"""
|
| 14 |
Main processing task to be run in gradio
|
| 15 |
- Writes aris frames to dirname(filepath)/frames/{i}.jpg
|
|
@@ -25,7 +23,6 @@ def predict_task(filepath, weights=WEIGHTS, gradio_progress=None):
|
|
| 25 |
if (gradio_progress): gradio_progress(0, "In task...")
|
| 26 |
print("Cuda available in task?", torch.cuda.is_available())
|
| 27 |
|
| 28 |
-
print(filepath)
|
| 29 |
dirname = os.path.dirname(filepath)
|
| 30 |
filename = os.path.basename(filepath).replace(".aris","").replace(".ddf","")
|
| 31 |
results_filepath = os.path.join(dirname, f"{filename}_results.json")
|
|
@@ -48,7 +45,7 @@ def predict_task(filepath, weights=WEIGHTS, gradio_progress=None):
|
|
| 48 |
frame_rate = dataset.didson.info['framerate']
|
| 49 |
|
| 50 |
# run detection + tracking
|
| 51 |
-
results = do_full_inference(dataloader, image_meter_width, image_meter_height, gp=gradio_progress, weights=weights)
|
| 52 |
|
| 53 |
# re-index results if desired - this should be done before writing the file
|
| 54 |
results = prep_for_mm(results)
|
|
|
|
| 7 |
from inference import do_full_inference, json_dump_round_float
|
| 8 |
from visualizer import generate_video_batches
|
| 9 |
|
| 10 |
+
def predict_task(filepath, weights, conf_thresh, iou_thresh, min_hits, max_age, gradio_progress=None):
|
|
|
|
|
|
|
| 11 |
"""
|
| 12 |
Main processing task to be run in gradio
|
| 13 |
- Writes aris frames to dirname(filepath)/frames/{i}.jpg
|
|
|
|
| 23 |
if (gradio_progress): gradio_progress(0, "In task...")
|
| 24 |
print("Cuda available in task?", torch.cuda.is_available())
|
| 25 |
|
|
|
|
| 26 |
dirname = os.path.dirname(filepath)
|
| 27 |
filename = os.path.basename(filepath).replace(".aris","").replace(".ddf","")
|
| 28 |
results_filepath = os.path.join(dirname, f"{filename}_results.json")
|
|
|
|
| 45 |
frame_rate = dataset.didson.info['framerate']
|
| 46 |
|
| 47 |
# run detection + tracking
|
| 48 |
+
results = do_full_inference(dataloader, image_meter_width, image_meter_height, gp=gradio_progress, weights=weights, conf_thresh=conf_thresh, nms_iou=iou_thresh, min_hits=min_hits, max_age=max_age)
|
| 49 |
|
| 50 |
# re-index results if desired - this should be done before writing the file
|
| 51 |
results = prep_for_mm(results)
|
visualizer.py
CHANGED
|
@@ -110,7 +110,7 @@ def get_video_frames(frames, preds, frame_rate, image_meter_width=None, image_me
|
|
| 110 |
cv2.putText(frame, f'Left count: {clip_pr_counts[FONT_THICKNESS]}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*2), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
| 111 |
cv2.putText(frame, f'Other fish: {clip_pr_counts[2]}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*1), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
| 112 |
# cv2.putText(frame, f'Upstream: {preds["upstream_direction"]}', (0, h-1-LINE_HEIGHT*1), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
| 113 |
-
cv2.putText(frame, f'Frame: {i}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*0), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
| 114 |
|
| 115 |
vid_frames.append(frame)
|
| 116 |
|
|
|
|
| 110 |
cv2.putText(frame, f'Left count: {clip_pr_counts[FONT_THICKNESS]}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*2), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
| 111 |
cv2.putText(frame, f'Other fish: {clip_pr_counts[2]}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*1), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
| 112 |
# cv2.putText(frame, f'Upstream: {preds["upstream_direction"]}', (0, h-1-LINE_HEIGHT*1), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
| 113 |
+
cv2.putText(frame, f'Frame: {start_frame+i}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*0), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
| 114 |
|
| 115 |
vid_frames.append(frame)
|
| 116 |
|