Spaces:
Runtime error
Runtime error
Commit
·
c376f3c
1
Parent(s):
a63e231
Autoload parameters
Browse files- InferenceConfig.py +13 -1
- gradio_scripts/upload_ui.py +14 -12
InferenceConfig.py
CHANGED
|
@@ -5,6 +5,11 @@ class TrackerType(Enum):
|
|
| 5 |
CONF_BOOST = 1
|
| 6 |
BYTETRACK = 2
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
### Configuration options
|
| 9 |
WEIGHTS = 'models/v5m_896_300best.pt'
|
| 10 |
# will need to configure these based on GPU hardware
|
|
@@ -16,7 +21,7 @@ MAX_AGE = 20 # time until missing fish get's new id
|
|
| 16 |
MIN_HITS = 11 # minimum number of frames with a specific fish for it to count
|
| 17 |
MIN_LENGTH = 0.3 # minimum fish length, in meters
|
| 18 |
IOU_THRES = 0.01 # IOU threshold for tracking
|
| 19 |
-
MIN_TRAVEL =
|
| 20 |
DEFAULT_TRACKER = TrackerType.BYTETRACK
|
| 21 |
|
| 22 |
class InferenceConfig:
|
|
@@ -50,6 +55,13 @@ class InferenceConfig:
|
|
| 50 |
self.byte_low_conf = low
|
| 51 |
self.byte_high_conf = high
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
def to_dict(self):
|
| 54 |
dict = {
|
| 55 |
'weights': self.weights,
|
|
|
|
| 5 |
CONF_BOOST = 1
|
| 6 |
BYTETRACK = 2
|
| 7 |
|
| 8 |
+
def toString(val):
|
| 9 |
+
if val == TrackerType.NONE: return "None"
|
| 10 |
+
if val == TrackerType.CONF_BOOST: return "Confidence Boost"
|
| 11 |
+
if val == TrackerType.BYTETRACK: return "ByteTrack"
|
| 12 |
+
|
| 13 |
### Configuration options
|
| 14 |
WEIGHTS = 'models/v5m_896_300best.pt'
|
| 15 |
# will need to configure these based on GPU hardware
|
|
|
|
| 21 |
MIN_HITS = 11 # minimum number of frames with a specific fish for it to count
|
| 22 |
MIN_LENGTH = 0.3 # minimum fish length, in meters
|
| 23 |
IOU_THRES = 0.01 # IOU threshold for tracking
|
| 24 |
+
MIN_TRAVEL = 0 # Minimum distance a track has to travel
|
| 25 |
DEFAULT_TRACKER = TrackerType.BYTETRACK
|
| 26 |
|
| 27 |
class InferenceConfig:
|
|
|
|
| 55 |
self.byte_low_conf = low
|
| 56 |
self.byte_high_conf = high
|
| 57 |
|
| 58 |
+
def find_model(self, model_list):
|
| 59 |
+
for model_name, model_path in enumerate(model_list):
|
| 60 |
+
if model_path == self.weights:
|
| 61 |
+
return model_name
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
|
| 65 |
def to_dict(self):
|
| 66 |
dict = {
|
| 67 |
'weights': self.weights,
|
gradio_scripts/upload_ui.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from gradio_scripts.file_reader import File
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
models = {
|
|
@@ -17,35 +18,36 @@ def Upload_Gradio(gradio_components):
|
|
| 17 |
|
| 18 |
gr.HTML("<p align='center' style='font-size: large;font-style: italic;'>Submit an .aris file to analyze result.</p>")
|
| 19 |
|
|
|
|
| 20 |
settings = []
|
| 21 |
with gr.Accordion("Advanced Settings", open=False):
|
| 22 |
-
settings.append(gr.Dropdown(label="Model", value=
|
| 23 |
|
| 24 |
gr.Markdown("Detection Parameters")
|
| 25 |
with gr.Row():
|
| 26 |
-
settings.append(gr.Slider(0, 1, value=
|
| 27 |
-
settings.append(gr.Slider(0, 1, value=
|
| 28 |
|
| 29 |
gr.Markdown("Tracking Parameters")
|
| 30 |
with gr.Row():
|
| 31 |
-
settings.append(gr.Slider(0, 100, value=
|
| 32 |
-
settings.append(gr.Slider(0, 100, value=
|
| 33 |
|
| 34 |
-
tracker = gr.Dropdown(["None", "Confidence Boost", "ByteTrack"], label="Associative Tracking"
|
| 35 |
settings.append(tracker)
|
| 36 |
with gr.Row(visible=False) as track_row:
|
| 37 |
-
settings.append(gr.Slider(0, 5, value=
|
| 38 |
-
settings.append(gr.Slider(0, 1, value=
|
| 39 |
tracker.change(lambda x: gr.update(visible=(x=="Confidence Boost")), tracker, track_row)
|
| 40 |
with gr.Row(visible=False) as track_row:
|
| 41 |
-
settings.append(gr.Slider(0, 1, value=
|
| 42 |
-
settings.append(gr.Slider(0, 1, value=
|
| 43 |
tracker.change(lambda x: gr.update(visible=(x=="ByteTrack")), tracker, track_row)
|
| 44 |
|
| 45 |
gr.Markdown("Other")
|
| 46 |
with gr.Row():
|
| 47 |
-
settings.append(gr.Slider(0, 3, value=
|
| 48 |
-
settings.append(gr.Slider(0, 5, value=
|
| 49 |
|
| 50 |
gradio_components['hyperparams'] = settings
|
| 51 |
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from gradio_scripts.file_reader import File
|
| 3 |
+
from InferenceConfig import InferenceConfig, TrackerType
|
| 4 |
|
| 5 |
|
| 6 |
models = {
|
|
|
|
| 18 |
|
| 19 |
gr.HTML("<p align='center' style='font-size: large;font-style: italic;'>Submit an .aris file to analyze result.</p>")
|
| 20 |
|
| 21 |
+
default_settings = InferenceConfig()
|
| 22 |
settings = []
|
| 23 |
with gr.Accordion("Advanced Settings", open=False):
|
| 24 |
+
settings.append(gr.Dropdown(label="Model", value=default_settings.find_model(models), choices=list(models.keys())))
|
| 25 |
|
| 26 |
gr.Markdown("Detection Parameters")
|
| 27 |
with gr.Row():
|
| 28 |
+
settings.append(gr.Slider(0, 1, value=default_settings.conf_thresh, label="Confidence Threshold", info="Confidence cutoff for detection boxes"))
|
| 29 |
+
settings.append(gr.Slider(0, 1, value=default_settings.nms_iou, label="NMS IoU", info="IoU threshold for non-max suppression"))
|
| 30 |
|
| 31 |
gr.Markdown("Tracking Parameters")
|
| 32 |
with gr.Row():
|
| 33 |
+
settings.append(gr.Slider(0, 100, value=default_settings.min_hits, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
|
| 34 |
+
settings.append(gr.Slider(0, 100, value=default_settings.max_age, label="Max Age", info="Max age of occlusion before track is split"))
|
| 35 |
|
| 36 |
+
tracker = gr.Dropdown(["None", "Confidence Boost", "ByteTrack"], value=TrackerType.toString(default_settings.associative_tracker), label="Associative Tracking")
|
| 37 |
settings.append(tracker)
|
| 38 |
with gr.Row(visible=False) as track_row:
|
| 39 |
+
settings.append(gr.Slider(0, 5, value=default_settings.boost_power, label="Boost Power", info=""))
|
| 40 |
+
settings.append(gr.Slider(0, 1, value=default_settings.boost_decay, label="Boost Decay", info=""))
|
| 41 |
tracker.change(lambda x: gr.update(visible=(x=="Confidence Boost")), tracker, track_row)
|
| 42 |
with gr.Row(visible=False) as track_row:
|
| 43 |
+
settings.append(gr.Slider(0, 1, value=default_settings.byte_low_conf, label="Low Conf Threshold", info=""))
|
| 44 |
+
settings.append(gr.Slider(0, 1, value=default_settings.byte_high_conf, label="High Conf Threshold", info=""))
|
| 45 |
tracker.change(lambda x: gr.update(visible=(x=="ByteTrack")), tracker, track_row)
|
| 46 |
|
| 47 |
gr.Markdown("Other")
|
| 48 |
with gr.Row():
|
| 49 |
+
settings.append(gr.Slider(0, 3, value=default_settings.min_length, label="Min Length", info="Minimum length of fish (meters) in order for it to count"))
|
| 50 |
+
settings.append(gr.Slider(0, 5, value=default_settings.min_travel, label="Min Travel", info="Minimum travel distance of track (meters) in order for it to count"))
|
| 51 |
|
| 52 |
gradio_components['hyperparams'] = settings
|
| 53 |
|