BilalSardar committed on
Commit
a65973e
·
1 Parent(s): 0e58a35

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -1017
app.py CHANGED
@@ -1,1041 +1,115 @@
1
- import concurrent.futures
2
- import os
3
- import sys
4
- from multiprocessing import freeze_support
5
- from pathlib import Path
6
-
7
  import gradio as gr
8
  import librosa
9
- import webview
10
 
11
- import analyze
12
  import config as cfg
13
  import segments
14
  import species
15
  import utils
16
  from train import trainModel
17
 
18
- _WINDOW: webview.Window
19
- OUTPUT_TYPE_MAP = {"Raven selection table": "table", "Audacity": "audacity", "R": "r", "CSV": "csv"}
20
- ORIGINAL_MODEL_PATH = cfg.MODEL_PATH
21
- ORIGINAL_MDATA_MODEL_PATH = cfg.MDATA_MODEL_PATH
22
- ORIGINAL_LABELS_FILE = cfg.LABELS_FILE
23
- ORIGINAL_TRANSLATED_LABELS_PATH = cfg.TRANSLATED_LABELS_PATH
24
-
25
-
26
- def analyzeFile_wrapper(entry):
27
- return (entry[0], analyze.analyzeFile(entry))
28
-
29
-
30
- def extractSegments_wrapper(entry):
31
- return (entry[0][0], segments.extractSegments(entry))
32
-
33
-
34
- def validate(value, msg):
35
- """Checks if the value ist not falsy.
36
-
37
- If the value is falsy, an error will be raised.
38
-
39
- Args:
40
- value: Value to be tested.
41
- msg: Message in case of an error.
42
- """
43
- if not value:
44
- raise gr.Error(msg)
45
-
46
-
47
- def runSingleFileAnalysis(
48
- input_path,
49
- confidence,
50
- sensitivity,
51
- overlap,
52
- species_list_choice,
53
- species_list_file,
54
- lat,
55
- lon,
56
- week,
57
- use_yearlong,
58
- sf_thresh,
59
- custom_classifier_file,
60
- locale,
61
- ):
62
- validate(input_path, "Please select a file.")
63
-
64
- return runAnalysis(
65
- input_path,
66
- None,
67
- confidence,
68
- sensitivity,
69
- overlap,
70
- species_list_choice,
71
- species_list_file,
72
- lat,
73
- lon,
74
- week,
75
- use_yearlong,
76
- sf_thresh,
77
- custom_classifier_file,
78
- "csv",
79
- "en" if not locale else locale,
80
- 1,
81
- 4,
82
- None,
83
- progress=None,
84
- )
85
-
86
-
87
- def runBatchAnalysis(
88
- output_path,
89
- confidence,
90
- sensitivity,
91
- overlap,
92
- species_list_choice,
93
- species_list_file,
94
- lat,
95
- lon,
96
- week,
97
- use_yearlong,
98
- sf_thresh,
99
- custom_classifier_file,
100
- output_type,
101
- locale,
102
- batch_size,
103
- threads,
104
- input_dir,
105
- progress=gr.Progress(),
106
- ):
107
- validate(input_dir, "Please select a directory.")
108
- batch_size = int(batch_size)
109
- threads = int(threads)
110
-
111
- if species_list_choice == _CUSTOM_SPECIES:
112
- validate(species_list_file, "Please select a species list.")
113
-
114
- return runAnalysis(
115
- None,
116
- output_path,
117
- confidence,
118
- sensitivity,
119
- overlap,
120
- species_list_choice,
121
- species_list_file,
122
- lat,
123
- lon,
124
- week,
125
- use_yearlong,
126
- sf_thresh,
127
- custom_classifier_file,
128
- output_type,
129
- "en" if not locale else locale,
130
- batch_size if batch_size and batch_size > 0 else 1,
131
- threads if threads and threads > 0 else 4,
132
- input_dir,
133
- progress,
134
- )
135
-
136
-
137
- def runAnalysis(
138
- input_path: str,
139
- output_path: str | None,
140
- confidence: float,
141
- sensitivity: float,
142
- overlap: float,
143
- species_list_choice: str,
144
- species_list_file,
145
- lat: float,
146
- lon: float,
147
- week: int,
148
- use_yearlong: bool,
149
- sf_thresh: float,
150
- custom_classifier_file,
151
- output_type: str,
152
- locale: str,
153
- batch_size: int,
154
- threads: int,
155
- input_dir: str,
156
- progress: gr.Progress | None,
157
- ):
158
- """Starts the analysis.
159
-
160
- Args:
161
- input_path: Either a file or directory.
162
- output_path: The output path for the result, if None the input_path is used
163
- confidence: The selected minimum confidence.
164
- sensitivity: The selected sensitivity.
165
- overlap: The selected segment overlap.
166
- species_list_choice: The choice for the species list.
167
- species_list_file: The selected custom species list file.
168
- lat: The selected latitude.
169
- lon: The selected longitude.
170
- week: The selected week of the year.
171
- use_yearlong: Use yearlong instead of week.
172
- sf_thresh: The threshold for the predicted species list.
173
- custom_classifier_file: Custom classifier to be used.
174
- output_type: The type of result to be generated.
175
- locale: The translation to be used.
176
- batch_size: The number of samples in a batch.
177
- threads: The number of threads to be used.
178
- input_dir: The input directory.
179
- progress: The gradio progress bar.
180
- """
181
- if progress is not None:
182
- progress(0, desc="Preparing ...")
183
-
184
- locale = locale.lower()
185
- # Load eBird codes, labels
186
- cfg.CODES = analyze.loadCodes()
187
- cfg.LABELS = utils.readLines(ORIGINAL_LABELS_FILE)
188
- cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK = lat, lon, -1 if use_yearlong else week
189
- cfg.LOCATION_FILTER_THRESHOLD = sf_thresh
190
-
191
- if species_list_choice == _CUSTOM_SPECIES:
192
- if not species_list_file or not species_list_file.name:
193
- cfg.SPECIES_LIST_FILE = None
194
- else:
195
- cfg.SPECIES_LIST_FILE = os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), species_list_file.name)
196
-
197
- if os.path.isdir(cfg.SPECIES_LIST_FILE):
198
- cfg.SPECIES_LIST_FILE = os.path.join(cfg.SPECIES_LIST_FILE, "species_list.txt")
199
-
200
- cfg.SPECIES_LIST = utils.readLines(cfg.SPECIES_LIST_FILE)
201
- cfg.CUSTOM_CLASSIFIER = None
202
- elif species_list_choice == _PREDICT_SPECIES:
203
- cfg.SPECIES_LIST_FILE = None
204
- cfg.CUSTOM_CLASSIFIER = None
205
- cfg.SPECIES_LIST = species.getSpeciesList(cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK, cfg.LOCATION_FILTER_THRESHOLD)
206
- elif species_list_choice == _CUSTOM_CLASSIFIER:
207
- if custom_classifier_file is None:
208
- raise gr.Error("No custom classifier selected.")
209
-
210
- # Set custom classifier?
211
- cfg.CUSTOM_CLASSIFIER = custom_classifier_file # we treat this as absolute path, so no need to join with dirname
212
- cfg.LABELS_FILE = custom_classifier_file.replace(".tflite", "_Labels.txt") # same for labels file
213
- cfg.LABELS = utils.readLines(cfg.LABELS_FILE)
214
- cfg.LATITUDE = -1
215
- cfg.LONGITUDE = -1
216
- cfg.SPECIES_LIST_FILE = None
217
- cfg.SPECIES_LIST = []
218
- locale = "en"
219
- else:
220
- cfg.SPECIES_LIST_FILE = None
221
- cfg.SPECIES_LIST = []
222
- cfg.CUSTOM_CLASSIFIER = None
223
-
224
- # Load translated labels
225
- lfile = os.path.join(cfg.TRANSLATED_LABELS_PATH, os.path.basename(cfg.LABELS_FILE).replace(".txt", f"_{locale}.txt"))
226
- if not locale in ["en"] and os.path.isfile(lfile):
227
- cfg.TRANSLATED_LABELS = utils.readLines(lfile)
228
- else:
229
- cfg.TRANSLATED_LABELS = cfg.LABELS
230
-
231
- if len(cfg.SPECIES_LIST) == 0:
232
- print(f"Species list contains {len(cfg.LABELS)} species")
233
- else:
234
- print(f"Species list contains {len(cfg.SPECIES_LIST)} species")
235
-
236
- # Set input and output path
237
- cfg.INPUT_PATH = input_path
238
-
239
- if input_dir:
240
- cfg.OUTPUT_PATH = output_path if output_path else input_dir
241
- else:
242
- cfg.OUTPUT_PATH = output_path if output_path else input_path.split(".", 1)[0] + ".csv"
243
-
244
- # Parse input files
245
- if input_dir:
246
- cfg.FILE_LIST = utils.collect_audio_files(input_dir)
247
- cfg.INPUT_PATH = input_dir
248
- elif os.path.isdir(cfg.INPUT_PATH):
249
- cfg.FILE_LIST = utils.collect_audio_files(cfg.INPUT_PATH)
250
- else:
251
- cfg.FILE_LIST = [cfg.INPUT_PATH]
252
-
253
- validate(cfg.FILE_LIST, "No audio files found.")
254
-
255
- # Set confidence threshold
256
- cfg.MIN_CONFIDENCE = confidence
257
-
258
- # Set sensitivity
259
- cfg.SIGMOID_SENSITIVITY = sensitivity
260
-
261
- # Set overlap
262
- cfg.SIG_OVERLAP = overlap
263
-
264
- # Set result type
265
- cfg.RESULT_TYPE = OUTPUT_TYPE_MAP[output_type] if output_type in OUTPUT_TYPE_MAP else output_type.lower()
266
-
267
- if not cfg.RESULT_TYPE in ["table", "audacity", "r", "csv"]:
268
- cfg.RESULT_TYPE = "table"
269
-
270
- # Set number of threads
271
- if input_dir:
272
- cfg.CPU_THREADS = max(1, int(threads))
273
- cfg.TFLITE_THREADS = 1
274
- else:
275
- cfg.CPU_THREADS = 1
276
- cfg.TFLITE_THREADS = max(1, int(threads))
277
-
278
- # Set batch size
279
- cfg.BATCH_SIZE = max(1, int(batch_size))
280
-
281
- flist = []
282
-
283
- for f in cfg.FILE_LIST:
284
- flist.append((f, cfg.getConfig()))
285
-
286
- result_list = []
287
-
288
- if progress is not None:
289
- progress(0, desc="Starting ...")
290
-
291
- # Analyze files
292
- if cfg.CPU_THREADS < 2:
293
- for entry in flist:
294
- result = analyzeFile_wrapper(entry)
295
-
296
- result_list.append(result)
297
- else:
298
- with concurrent.futures.ProcessPoolExecutor(max_workers=cfg.CPU_THREADS) as executor:
299
- futures = (executor.submit(analyzeFile_wrapper, arg) for arg in flist)
300
- for i, f in enumerate(concurrent.futures.as_completed(futures), start=1):
301
- if progress is not None:
302
- progress((i, len(flist)), total=len(flist), unit="files")
303
- result = f.result()
304
-
305
- result_list.append(result)
306
-
307
- return [[os.path.relpath(r[0], input_dir), r[1]] for r in result_list] if input_dir else cfg.OUTPUT_PATH
308
-
309
-
310
- _CUSTOM_SPECIES = "Custom species list"
311
- _PREDICT_SPECIES = "Species by location"
312
- _CUSTOM_CLASSIFIER = "Custom classifier"
313
- _ALL_SPECIES = "all species"
314
-
315
-
316
- def show_species_choice(choice: str):
317
- """Sets the visibility of the species list choices.
318
-
319
- Args:
320
- choice: The label of the currently active choice.
321
-
322
- Returns:
323
- A list of [
324
- Row update,
325
- File update,
326
- Column update,
327
- Column update,
328
- ]
329
- """
330
- if choice == _CUSTOM_SPECIES:
331
- return [
332
- gr.Row.update(visible=False),
333
- gr.File.update(visible=True),
334
- gr.Column.update(visible=False),
335
- gr.Column.update(visible=False),
336
- ]
337
- elif choice == _PREDICT_SPECIES:
338
- return [
339
- gr.Row.update(visible=True),
340
- gr.File.update(visible=False),
341
- gr.Column.update(visible=False),
342
- gr.Column.update(visible=False),
343
- ]
344
- elif choice == _CUSTOM_CLASSIFIER:
345
- return [
346
- gr.Row.update(visible=False),
347
- gr.File.update(visible=False),
348
- gr.Column.update(visible=True),
349
- gr.Column.update(visible=False),
350
- ]
351
-
352
- return [
353
- gr.Row.update(visible=False),
354
- gr.File.update(visible=False),
355
- gr.Column.update(visible=False),
356
- gr.Column.update(visible=True),
357
- ]
358
-
359
-
360
- def select_subdirectories():
361
- """Creates a directory selection dialog.
362
-
363
- Returns:
364
- A tuple of (directory, list of subdirectories) or (None, None) if the dialog was canceled.
365
- """
366
- dir_name = _WINDOW.create_file_dialog(webview.FOLDER_DIALOG)
367
-
368
- if dir_name:
369
- subdirs = utils.list_subdirectories(dir_name[0])
370
-
371
- return dir_name[0], [[d] for d in subdirs]
372
-
373
- return None, None
374
-
375
-
376
- def select_file(filetypes=()):
377
- """Creates a file selection dialog.
378
-
379
- Args:
380
- filetypes: List of filetypes to be filtered in the dialog.
381
-
382
- Returns:
383
- The selected file or None if the dialog was canceled.
384
- """
385
- files = _WINDOW.create_file_dialog(webview.OPEN_DIALOG, file_types=filetypes)
386
- return files[0] if files else None
387
-
388
-
389
- def format_seconds(secs: float):
390
- """Formats a number of seconds into a string.
391
-
392
- Formats the seconds into the format "h:mm:ss.ms"
393
-
394
- Args:
395
- secs: Number of seconds.
396
-
397
- Returns:
398
- A string with the formatted seconds.
399
- """
400
- hours, secs = divmod(secs, 3600)
401
- minutes, secs = divmod(secs, 60)
402
-
403
- return "{:2.0f}:{:02.0f}:{:06.3f}".format(hours, minutes, secs)
404
-
405
-
406
- def select_directory(collect_files=True):
407
- """Shows a directory selection system dialog.
408
-
409
- Uses the pywebview to create a system dialog.
410
-
411
- Args:
412
- collect_files: If True, also lists the files inside the directory.
413
-
414
- Returns:
415
- If collect_files==True, returns (directory path, list of (relative file path, audio length))
416
- else just the directory path.
417
- All values will be None if the dialog is cancelled.
418
- """
419
- dir_name = _WINDOW.create_file_dialog(webview.FOLDER_DIALOG)
420
-
421
- if collect_files:
422
- if not dir_name:
423
- return None, None
424
-
425
- files = utils.collect_audio_files(dir_name[0])
426
-
427
- return dir_name[0], [
428
- [os.path.relpath(file, dir_name[0]), format_seconds(librosa.get_duration(filename=file))] for file in files
429
- ]
430
-
431
- return dir_name[0] if dir_name else None
432
-
433
-
434
- def start_training(
435
- data_dir,
436
- crop_mode,
437
- crop_overlap,
438
- output_dir,
439
- classifier_name,
440
- epochs,
441
- batch_size,
442
- learning_rate,
443
- hidden_units,
444
- use_mixup,
445
- upsampling_ratio,
446
- upsampling_mode,
447
- model_format,
448
- progress=gr.Progress(),
449
- ):
450
- """Starts the training of a custom classifier.
451
-
452
- Args:
453
- data_dir: Directory containing the training data.
454
- output_dir: Directory for the new classifier.
455
- classifier_name: File name of the classifier.
456
- epochs: Number of epochs to train for.
457
- batch_size: Number of samples in one batch.
458
- learning_rate: Learning rate for training.
459
- hidden_units: If > 0 the classifier contains a further hidden layer.
460
- progress: The gradio progress bar.
461
-
462
- Returns:
463
- Returns a matplotlib.pyplot figure.
464
- """
465
- validate(data_dir, "Please select your Training data.")
466
- validate(output_dir, "Please select a directory for the classifier.")
467
- validate(classifier_name, "Please enter a valid name for the classifier.")
468
-
469
- if not epochs or epochs < 0:
470
- raise gr.Error("Please enter a valid number of epochs.")
471
-
472
- if not batch_size or batch_size < 0:
473
- raise gr.Error("Please enter a valid batch size.")
474
-
475
- if not learning_rate or learning_rate < 0:
476
- raise gr.Error("Please enter a valid learning rate.")
477
-
478
- if not hidden_units or hidden_units < 0:
479
- hidden_units = 0
480
-
481
- if progress is not None:
482
- progress((0, epochs), desc="Loading data & building classifier", unit="epoch")
483
-
484
- cfg.TRAIN_DATA_PATH = data_dir
485
- cfg.SAMPLE_CROP_MODE = crop_mode
486
- cfg.SIG_OVERLAP = crop_overlap
487
- cfg.CUSTOM_CLASSIFIER = str(Path(output_dir) / classifier_name)
488
- cfg.TRAIN_EPOCHS = int(epochs)
489
- cfg.TRAIN_BATCH_SIZE = int(batch_size)
490
- cfg.TRAIN_LEARNING_RATE = learning_rate
491
- cfg.TRAIN_HIDDEN_UNITS = int(hidden_units)
492
- cfg.TRAIN_WITH_MIXUP = use_mixup
493
- cfg.UPSAMPLING_RATIO = min(max(0, upsampling_ratio), 1)
494
- cfg.UPSAMPLING_MODE = upsampling_mode
495
- cfg.TRAINED_MODEL_OUTPUT_FORMAT = model_format
496
-
497
- def progression(epoch, logs=None):
498
- if progress is not None:
499
- if epoch + 1 == epochs:
500
- progress((epoch + 1, epochs), total=epochs, unit="epoch", desc=f"Saving at {cfg.CUSTOM_CLASSIFIER}")
501
- else:
502
- progress((epoch + 1, epochs), total=epochs, unit="epoch")
503
-
504
- history = trainModel(on_epoch_end=progression)
505
-
506
- if len(history.epoch) < epochs:
507
- gr.Info("Stopped early - validation metric not improving.")
508
-
509
- auprc = history.history["val_AUPRC"]
510
-
511
- import matplotlib.pyplot as plt
512
-
513
- fig = plt.figure()
514
- plt.plot(auprc)
515
- plt.ylabel("Area under precision-recall curve")
516
- plt.xlabel("Epoch")
517
-
518
- return fig
519
-
520
-
521
- def extract_segments(audio_dir, result_dir, output_dir, min_conf, num_seq, seq_length, threads, progress=gr.Progress()):
522
- validate(audio_dir, "No audio directory selected")
523
-
524
- if not result_dir:
525
- result_dir = audio_dir
526
-
527
- if not output_dir:
528
- output_dir = audio_dir
529
-
530
- if progress is not None:
531
- progress(0, desc="Searching files ...")
532
-
533
- # Parse audio and result folders
534
- cfg.FILE_LIST = segments.parseFolders(audio_dir, result_dir)
535
-
536
- # Set output folder
537
- cfg.OUTPUT_PATH = output_dir
538
-
539
- # Set number of threads
540
- cfg.CPU_THREADS = int(threads)
541
-
542
- # Set confidence threshold
543
- cfg.MIN_CONFIDENCE = max(0.01, min(0.99, min_conf))
544
-
545
- # Parse file list and make list of segments
546
- cfg.FILE_LIST = segments.parseFiles(cfg.FILE_LIST, max(1, int(num_seq)))
547
-
548
- # Add config items to each file list entry.
549
- # We have to do this for Windows which does not
550
- # support fork() and thus each process has to
551
- # have its own config. USE LINUX!
552
- flist = [(entry, max(cfg.SIG_LENGTH, float(seq_length)), cfg.getConfig()) for entry in cfg.FILE_LIST]
553
-
554
- result_list = []
555
-
556
- # Extract segments
557
- if cfg.CPU_THREADS < 2:
558
- for i, entry in enumerate(flist):
559
- result = extractSegments_wrapper(entry)
560
- result_list.append(result)
561
-
562
- if progress is not None:
563
- progress((i, len(flist)), total=len(flist), unit="files")
564
- else:
565
- with concurrent.futures.ProcessPoolExecutor(max_workers=cfg.CPU_THREADS) as executor:
566
- futures = (executor.submit(extractSegments_wrapper, arg) for arg in flist)
567
- for i, f in enumerate(concurrent.futures.as_completed(futures), start=1):
568
- if progress is not None:
569
- progress((i, len(flist)), total=len(flist), unit="files")
570
- result = f.result()
571
-
572
- result_list.append(result)
573
-
574
- return [[os.path.relpath(r[0], audio_dir), r[1]] for r in result_list]
575
-
576
-
577
- def sample_sliders(opened=True):
578
- """Creates the gradio accordion for the inference settings.
579
-
580
- Args:
581
- opened: If True the accordion is open on init.
582
-
583
- Returns:
584
- A tuple with the created elements:
585
- (Slider (min confidence), Slider (sensitivity), Slider (overlap))
586
- """
587
- with gr.Accordion("Inference settings", open=opened):
588
- with gr.Row():
589
- confidence_slider = gr.Slider(
590
- minimum=0, maximum=1, value=0.5, step=0.01, label="Minimum Confidence", info="Minimum confidence threshold."
591
- )
592
- sensitivity_slider = gr.Slider(
593
- minimum=0.5,
594
- maximum=1.5,
595
- value=1,
596
- step=0.01,
597
- label="Sensitivity",
598
- info="Detection sensitivity; Higher values result in higher sensitivity.",
599
- )
600
- overlap_slider = gr.Slider(
601
- minimum=0, maximum=2.99, value=0, step=0.01, label="Overlap", info="Overlap of prediction segments."
602
- )
603
-
604
- return confidence_slider, sensitivity_slider, overlap_slider
605
-
606
-
607
- def locale():
608
- """Creates the gradio elements for locale selection
609
-
610
- Reads the translated labels inside the checkpoints directory.
611
-
612
- Returns:
613
- The dropdown element.
614
- """
615
- label_files = os.listdir(os.path.join(os.path.dirname(sys.argv[0]), ORIGINAL_TRANSLATED_LABELS_PATH))
616
- options = ["EN"] + [label_file.rsplit("_", 1)[-1].split(".")[0].upper() for label_file in label_files]
617
-
618
- return gr.Dropdown(options, value="EN", label="Locale", info="Locale for the translated species common names.")
619
-
620
-
621
- def species_lists(opened=True):
622
- """Creates the gradio accordion for species selection.
623
-
624
- Args:
625
- opened: If True the accordion is open on init.
626
-
627
- Returns:
628
- A tuple with the created elements:
629
- (Radio (choice), File (custom species list), Slider (lat), Slider (lon), Slider (week), Slider (threshold), Checkbox (yearlong?), State (custom classifier))
630
- """
631
- with gr.Accordion("Species selection", open=opened):
632
- with gr.Row():
633
- species_list_radio = gr.Radio(
634
- [_CUSTOM_SPECIES, _PREDICT_SPECIES, _CUSTOM_CLASSIFIER, _ALL_SPECIES],
635
- value=_ALL_SPECIES,
636
- label="Species list",
637
- info="List of all possible species",
638
- elem_classes="d-block",
639
- )
640
-
641
- with gr.Column(visible=False) as position_row:
642
- lat_number = gr.Slider(
643
- minimum=-90, maximum=90, value=0, step=1, label="Latitude", info="Recording location latitude."
644
- )
645
- lon_number = gr.Slider(
646
- minimum=-180, maximum=180, value=0, step=1, label="Longitude", info="Recording location longitude."
647
- )
648
- with gr.Row():
649
- yearlong_checkbox = gr.Checkbox(True, label="Year-round")
650
- week_number = gr.Slider(
651
- minimum=1,
652
- maximum=48,
653
- value=1,
654
- step=1,
655
- interactive=False,
656
- label="Week",
657
- info="Week of the year when the recording was made. Values in [1, 48] (4 weeks per month).",
658
- )
659
-
660
- def onChange(use_yearlong):
661
- return gr.Slider.update(interactive=(not use_yearlong))
662
-
663
- yearlong_checkbox.change(onChange, inputs=yearlong_checkbox, outputs=week_number, show_progress=False)
664
- sf_thresh_number = gr.Slider(
665
- minimum=0.01,
666
- maximum=0.99,
667
- value=0.03,
668
- step=0.01,
669
- label="Location filter threshold",
670
- info="Minimum species occurrence frequency threshold for location filter.",
671
- )
672
-
673
- species_file_input = gr.File(file_types=[".txt"], info="Path to species list file or folder.", visible=False)
674
- empty_col = gr.Column()
675
-
676
- with gr.Column(visible=False) as custom_classifier_selector:
677
- classifier_selection_button = gr.Button("Select classifier")
678
- classifier_file_input = gr.Files(
679
- file_types=[".tflite"], info="Path to the custom classifier.", visible=False, interactive=False
680
- )
681
- selected_classifier_state = gr.State()
682
-
683
- def on_custom_classifier_selection_click():
684
- file = select_file(("TFLite classifier (*.tflite)",))
685
-
686
- if file:
687
- labels = os.path.splitext(file)[0] + "_Labels.txt"
688
-
689
- return file, gr.File.update(value=[file, labels], visible=True)
690
-
691
- return None
692
-
693
- classifier_selection_button.click(
694
- on_custom_classifier_selection_click,
695
- outputs=[selected_classifier_state, classifier_file_input],
696
- show_progress=False,
697
- )
698
-
699
- species_list_radio.change(
700
- show_species_choice,
701
- inputs=[species_list_radio],
702
- outputs=[position_row, species_file_input, custom_classifier_selector, empty_col],
703
- show_progress=False,
704
- )
705
-
706
- return (
707
- species_list_radio,
708
- species_file_input,
709
- lat_number,
710
- lon_number,
711
- week_number,
712
- sf_thresh_number,
713
- yearlong_checkbox,
714
- selected_classifier_state,
715
- )
716
-
717
 
718
  if __name__ == "__main__":
719
- freeze_support()
720
-
721
- def build_single_analysis_tab():
722
- with gr.Tab("Single file"):
723
- audio_input = gr.Audio(type="filepath", label="file", elem_id="single_file_audio")
724
-
725
- confidence_slider, sensitivity_slider, overlap_slider = sample_sliders(False)
726
- (
727
- species_list_radio,
728
- species_file_input,
729
- lat_number,
730
- lon_number,
731
- week_number,
732
- sf_thresh_number,
733
- yearlong_checkbox,
734
- selected_classifier_state,
735
- ) = species_lists(False)
736
- locale_radio = locale()
737
-
738
- inputs = [
739
- audio_input,
740
- confidence_slider,
741
- sensitivity_slider,
742
- overlap_slider,
743
- species_list_radio,
744
- species_file_input,
745
- lat_number,
746
- lon_number,
747
- week_number,
748
- yearlong_checkbox,
749
- sf_thresh_number,
750
- selected_classifier_state,
751
- locale_radio,
752
- ]
753
-
754
- output_dataframe = gr.Dataframe(
755
- type="pandas",
756
- headers=["Start (s)", "End (s)", "Scientific name", "Common name", "Confidence"],
757
- elem_classes="mh-200",
758
- )
759
-
760
- single_file_analyze = gr.Button("Analyze")
761
-
762
- single_file_analyze.click(runSingleFileAnalysis, inputs=inputs, outputs=output_dataframe)
763
-
764
- def build_multi_analysis_tab():
765
- with gr.Tab("Multiple files"):
766
- input_directory_state = gr.State()
767
- output_directory_predict_state = gr.State()
768
- with gr.Row():
769
- with gr.Column():
770
- select_directory_btn = gr.Button("Select directory (recursive)")
771
- directory_input = gr.Matrix(interactive=False, elem_classes="mh-200", headers=["Subpath", "Length"])
772
-
773
- def select_directory_on_empty():
774
- res = select_directory()
775
-
776
- return res if res[1] else [res[0], [["No files found"]]]
777
-
778
- select_directory_btn.click(
779
- select_directory_on_empty, outputs=[input_directory_state, directory_input], show_progress=True
780
- )
781
-
782
- with gr.Column():
783
- select_out_directory_btn = gr.Button("Select output directory.")
784
- selected_out_textbox = gr.Textbox(
785
- label="Output directory",
786
- interactive=False,
787
- placeholder="If not selected, the input directory will be used.",
788
- )
789
-
790
- def select_directory_wrapper():
791
- return (select_directory(collect_files=False),) * 2
792
-
793
- select_out_directory_btn.click(
794
- select_directory_wrapper,
795
- outputs=[output_directory_predict_state, selected_out_textbox],
796
- show_progress=False,
797
- )
798
-
799
- confidence_slider, sensitivity_slider, overlap_slider = sample_sliders()
800
-
801
- (
802
- species_list_radio,
803
- species_file_input,
804
- lat_number,
805
- lon_number,
806
- week_number,
807
- sf_thresh_number,
808
- yearlong_checkbox,
809
- selected_classifier_state,
810
- ) = species_lists()
811
-
812
- output_type_radio = gr.Radio(
813
- list(OUTPUT_TYPE_MAP.keys()),
814
- value="Raven selection table",
815
- label="Result type",
816
- info="Specifies output format.",
817
- )
818
-
819
- with gr.Row():
820
- batch_size_number = gr.Number(
821
- precision=1, label="Batch size", value=1, info="Number of samples to process at the same time."
822
- )
823
- threads_number = gr.Number(precision=1, label="Threads", value=4, info="Number of CPU threads.")
824
-
825
- locale_radio = locale()
826
-
827
- start_batch_analysis_btn = gr.Button("Analyze")
828
-
829
- result_grid = gr.Matrix(headers=["File", "Execution"], elem_classes="mh-200")
830
-
831
- inputs = [
832
- output_directory_predict_state,
833
- confidence_slider,
834
- sensitivity_slider,
835
- overlap_slider,
836
- species_list_radio,
837
- species_file_input,
838
- lat_number,
839
- lon_number,
840
- week_number,
841
- yearlong_checkbox,
842
- sf_thresh_number,
843
- selected_classifier_state,
844
- output_type_radio,
845
- locale_radio,
846
- batch_size_number,
847
- threads_number,
848
- input_directory_state,
849
- ]
850
-
851
- start_batch_analysis_btn.click(runBatchAnalysis, inputs=inputs, outputs=result_grid)
852
-
853
- def build_train_tab():
854
- with gr.Tab("Train"):
855
- input_directory_state = gr.State()
856
- output_directory_state = gr.State()
857
-
858
- with gr.Row():
859
- with gr.Column():
860
- select_directory_btn = gr.Button("Training data")
861
- directory_input = gr.List(headers=["Classes"], interactive=False, elem_classes="mh-200")
862
- select_directory_btn.click(
863
- select_subdirectories, outputs=[input_directory_state, directory_input], show_progress=False
864
- )
865
-
866
- with gr.Column():
867
- select_directory_btn = gr.Button("Classifier output")
868
-
869
- with gr.Column():
870
- classifier_name = gr.Textbox(
871
- "CustomClassifier",
872
- visible=False,
873
- info="The name of the new classifier.",
874
- )
875
- output_format = gr.Radio(
876
- ["tflite", "raven", "both"],
877
- value="tflite",
878
- label="Model output format",
879
- info="Format for the trained classifier.",
880
- visible=False,
881
- )
882
-
883
- def select_directory_and_update_tb():
884
- dir_name = _WINDOW.create_file_dialog(webview.FOLDER_DIALOG)
885
-
886
- if dir_name:
887
- return (
888
- dir_name[0],
889
- gr.Textbox.update(label=dir_name[0] + "\\", visible=True),
890
- gr.Radio.update(visible=True, interactive=True),
891
- )
892
-
893
- return None, None
894
-
895
- select_directory_btn.click(
896
- select_directory_and_update_tb,
897
- outputs=[output_directory_state, classifier_name, output_format],
898
- show_progress=False,
899
- )
900
-
901
- with gr.Row():
902
- epoch_number = gr.Number(100, label="Epochs", info="Number of training epochs.")
903
- batch_size_number = gr.Number(32, label="Batch size", info="Batch size.")
904
- learning_rate_number = gr.Number(0.01, label="Learning rate", info="Learning rate.")
905
-
906
- with gr.Row():
907
- crop_mode = gr.Radio(
908
- ["center", "first", "segments"],
909
- value="center",
910
- label="Crop mode",
911
- info="Crop mode for training data.",
912
- )
913
- crop_overlap = gr.Number(0.0, label="Crop overlap", info="Overlap of training data segments", visible=False)
914
-
915
- def on_crop_select(new_crop_mode):
916
- return gr.Number.update(visible=new_crop_mode == "segments", interactive=new_crop_mode == "segments")
917
-
918
- crop_mode.change(on_crop_select, inputs=crop_mode, outputs=crop_overlap)
919
-
920
- with gr.Row():
921
- upsampling_mode = gr.Radio(
922
- ["repeat", "mean", "smote"],
923
- value="repeat",
924
- label="Upsampling mode",
925
- info="Balance data through upsampling.",
926
- )
927
- upsampling_ratio = gr.Slider(
928
- 0.0, 1.0, 0.0, step=0.01, label="Upsampling ratio", info="Balance train data and upsample minority classes."
929
- )
930
-
931
- with gr.Row():
932
- hidden_units_number = gr.Number(
933
- 0, label="Hidden units", info="Number of hidden units. If set to >0, a two-layer classifier is used."
934
- )
935
- use_mixup = gr.Checkbox(False, label="Use mixup", info="Whether to use mixup for training.", show_label=True)
936
-
937
- train_history_plot = gr.Plot()
938
-
939
- start_training_button = gr.Button("Start training")
940
 
941
- start_training_button.click(
942
- start_training,
943
- inputs=[
944
- input_directory_state,
945
- crop_mode,
946
- crop_overlap,
947
- output_directory_state,
948
- classifier_name,
949
- epoch_number,
950
- batch_size_number,
951
- learning_rate_number,
952
- hidden_units_number,
953
- use_mixup,
954
- upsampling_ratio,
955
- upsampling_mode,
956
- output_format,
957
- ],
958
- outputs=[train_history_plot],
959
- )
960
 
961
- def build_segments_tab():
962
- with gr.Tab("Segments"):
963
- audio_directory_state = gr.State()
964
- result_directory_state = gr.State()
965
- output_directory_state = gr.State()
 
966
 
967
- def select_directory_to_state_and_tb():
968
- return (select_directory(collect_files=False),) * 2
969
 
970
- with gr.Row():
971
- select_audio_directory_btn = gr.Button("Select audio directory (recursive)")
972
- selected_audio_directory_tb = gr.Textbox(show_label=False, interactive=False)
973
- select_audio_directory_btn.click(
974
- select_directory_to_state_and_tb,
975
- outputs=[selected_audio_directory_tb, audio_directory_state],
976
- show_progress=False,
977
- )
978
 
979
- with gr.Row():
980
- select_result_directory_btn = gr.Button("Select result directory")
981
- selected_result_directory_tb = gr.Textbox(
982
- show_label=False, interactive=False, placeholder="Same as audio directory if not selected"
983
- )
984
- select_result_directory_btn.click(
985
- select_directory_to_state_and_tb,
986
- outputs=[result_directory_state, selected_result_directory_tb],
987
- show_progress=False,
988
- )
989
 
990
- with gr.Row():
991
- select_output_directory_btn = gr.Button("Select output directory")
992
- selected_output_directory_tb = gr.Textbox(
993
- show_label=False, interactive=False, placeholder="Same as audio directory if not selected"
994
- )
995
- select_output_directory_btn.click(
996
- select_directory_to_state_and_tb,
997
- outputs=[selected_output_directory_tb, output_directory_state],
998
- show_progress=False,
999
- )
1000
 
1001
- min_conf_slider = gr.Slider(
1002
- minimum=0.1, maximum=0.99, step=0.01, label="Minimum confidence", info="Minimum confidence threshold."
1003
- )
1004
- num_seq_number = gr.Number(
1005
- 100, label="Max number of segments", info="Maximum number of randomly extracted segments per species."
1006
- )
1007
- seq_length_number = gr.Number(3.0, label="Sequence length", info="Length of extracted segments in seconds.")
1008
- threads_number = gr.Number(4, label="Threads", info="Number of CPU threads.")
1009
 
1010
- extract_segments_btn = gr.Button("Extract segments")
 
 
1011
 
1012
- result_grid = gr.Matrix(headers=["File", "Execution"], elem_classes="mh-200")
 
1013
 
1014
- extract_segments_btn.click(
1015
- extract_segments,
1016
- inputs=[
1017
- audio_directory_state,
1018
- result_directory_state,
1019
- output_directory_state,
1020
- min_conf_slider,
1021
- num_seq_number,
1022
- seq_length_number,
1023
- threads_number,
1024
- ],
1025
- outputs=result_grid,
1026
- )
1027
 
1028
- with gr.Blocks(
1029
- css=r".d-block .wrap {display: block !important;} .mh-200 {max-height: 300px; overflow-y: auto !important;} footer {display: none !important;} #single_file_audio, #single_file_audio * {max-height: 81.6px; min-height: 0;}",
1030
- theme=gr.themes.Default(),
1031
- analytics_enabled=False,
1032
- ) as demo:
1033
- build_single_analysis_tab()
1034
- build_multi_analysis_tab()
1035
- build_train_tab()
1036
- build_segments_tab()
1037
 
1038
- url = demo.queue(api_open=False).launch(prevent_thread_lock=True, quiet=True)[1]
1039
- _WINDOW = webview.create_window("BirdNET-Analyzer", url.rstrip("/") + "?__theme=light", min_size=(1024, 768))
1040
 
1041
- webview.start(private_mode=False)
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import librosa
3
+ import os
4
 
5
+ import analyze
6
  import config as cfg
7
  import segments
8
  import species
9
  import utils
10
  from train import trainModel
11
 
12
def runSingleFileAnalysis(audio_file, confidence, sensitivity, overlap, species_list_choice, species_list_file, lat, lon, week, use_yearlong, sf_thresh, custom_classifier_file):
    """Configure the shared ``cfg`` module and analyze one audio file.

    Args:
        audio_file: Path of the recording to analyze.
        confidence: Stored in ``cfg.MIN_CONFIDENCE``.
        sensitivity: Stored in ``cfg.SIGMOID_SENSITIVITY``.
        overlap: Stored in ``cfg.SIG_OVERLAP``.
        species_list_choice: ``"Custom"`` reads the species list from
            ``species_list_file``; ``"Predict"`` asks ``species.getSpeciesList``;
            any other value leaves the species list empty.
        species_list_file: Path of the species list, or an upload object
            exposing the path as ``.name``. Only used for ``"Custom"``.
        lat: Stored in ``cfg.LATITUDE``.
        lon: Stored in ``cfg.LONGITUDE``.
        week: Stored in ``cfg.WEEK`` unless ``use_yearlong`` is set.
        use_yearlong: When truthy, ``week`` is replaced by ``-1``.
        sf_thresh: Stored in ``cfg.LOCATION_FILTER_THRESHOLD``.
        custom_classifier_file: Accepted so the signature matches the UI
            wiring; this function does not apply it yet.

    Returns:
        Whatever ``analyze.analyzeFile`` returns.

    Raises:
        gr.Error: If no audio file was given, or if ``"Custom"`` was chosen
            without a species list file.
    """
    # Fail with a readable UI message instead of a traceback deep in the analyzer.
    if not audio_file:
        raise gr.Error("Please select a file.")

    # Load labels, codes etc
    cfg.CODES = analyze.loadCodes()
    cfg.LABELS = utils.readLines(cfg.LABELS_FILE)

    # NOTE(review): -1 is the value the BirdNET-Analyzer docs give for a
    # year-round species list - confirm species.getSpeciesList honours it.
    week = -1 if use_yearlong else week

    # Set species list
    if species_list_choice == "Custom":
        # Depending on the Gradio version, a File input yields either a plain
        # path or a temp-file wrapper whose path is in `.name`.
        species_path = getattr(species_list_file, "name", species_list_file)

        if not species_path:
            raise gr.Error("Please select a species list file.")

        cfg.SPECIES_LIST_FILE = species_path
        cfg.SPECIES_LIST = utils.readLines(cfg.SPECIES_LIST_FILE)
    elif species_list_choice == "Predict":
        cfg.SPECIES_LIST = species.getSpeciesList(lat, lon, week, sf_thresh)
    else:
        cfg.SPECIES_LIST = []

    # Set other params
    cfg.LATITUDE = lat
    cfg.LONGITUDE = lon
    cfg.WEEK = week
    cfg.LOCATION_FILTER_THRESHOLD = sf_thresh
    cfg.INPUT_PATH = audio_file
    cfg.MIN_CONFIDENCE = confidence
    cfg.SIGMOID_SENSITIVITY = sensitivity
    cfg.SIG_OVERLAP = overlap

    # Analyze
    # NOTE(review): the revision this replaces called analyze.analyzeFile with
    # a single entry argument - confirm this two-argument call against analyze.py.
    return analyze.analyzeFile(cfg.INPUT_PATH, cfg)
39
+
40
+
41
def runBatchAnalysis(input_dir, output_dir, confidence, sensitivity, overlap, species_list_choice, species_list_file, lat, lon, week, use_yearlong, sf_thresh, batch_size, threads):
    """Configure the shared ``cfg`` module and analyze every audio file in a directory.

    Args:
        input_dir: Directory handed to ``utils.collect_audio_files``.
        output_dir: Stored in ``cfg.OUTPUT_PATH``.
        confidence: Stored in ``cfg.MIN_CONFIDENCE``.
        sensitivity: Stored in ``cfg.SIGMOID_SENSITIVITY``.
        overlap: Stored in ``cfg.SIG_OVERLAP``.
        species_list_choice: ``"Custom"`` reads the species list from
            ``species_list_file``; ``"Predict"`` asks ``species.getSpeciesList``;
            any other value leaves the species list empty.
        species_list_file: Path of the species list, or an upload object
            exposing the path as ``.name``. Only used for ``"Custom"``.
        lat: Stored in ``cfg.LATITUDE``.
        lon: Stored in ``cfg.LONGITUDE``.
        week: Stored in ``cfg.WEEK`` unless ``use_yearlong`` is set.
        use_yearlong: When truthy, ``week`` is replaced by ``-1``.
        sf_thresh: Stored in ``cfg.LOCATION_FILTER_THRESHOLD``.
        batch_size: Coerced to an int of at least 1 for ``cfg.BATCH_SIZE``.
        threads: Coerced to an int of at least 1 for ``cfg.CPU_THREADS``.

    Returns:
        Whatever ``analyze.batchAnalyze`` returns.

    Raises:
        gr.Error: If no input directory was given, or if ``"Custom"`` was
            chosen without a species list file.
    """
    # Fail with a readable UI message instead of a traceback while collecting files.
    if not input_dir:
        raise gr.Error("Please select a directory.")

    # Load labels, codes etc (same setup as the single-file path)
    cfg.CODES = analyze.loadCodes()
    cfg.LABELS = utils.readLines(cfg.LABELS_FILE)

    # NOTE(review): -1 is the value the BirdNET-Analyzer docs give for a
    # year-round species list - confirm species.getSpeciesList honours it.
    week = -1 if use_yearlong else week

    # Set params
    cfg.MIN_CONFIDENCE = confidence
    cfg.SIGMOID_SENSITIVITY = sensitivity
    cfg.SIG_OVERLAP = overlap
    cfg.LATITUDE = lat
    cfg.LONGITUDE = lon
    cfg.WEEK = week
    cfg.LOCATION_FILTER_THRESHOLD = sf_thresh
    cfg.INPUT_PATH = input_dir
    cfg.OUTPUT_PATH = output_dir
    cfg.FILE_LIST = utils.collect_audio_files(input_dir)

    # Number inputs deliver floats (or None when left empty); both settings
    # are counts, so coerce to positive integers.
    cfg.BATCH_SIZE = max(1, int(batch_size or 1))
    cfg.CPU_THREADS = max(1, int(threads or 1))

    # Set species list
    if species_list_choice == "Custom":
        # Depending on the Gradio version, a File input yields either a plain
        # path or a temp-file wrapper whose path is in `.name`.
        species_path = getattr(species_list_file, "name", species_list_file)

        if not species_path:
            raise gr.Error("Please select a species list file.")

        cfg.SPECIES_LIST_FILE = species_path
        cfg.SPECIES_LIST = utils.readLines(cfg.SPECIES_LIST_FILE)
    elif species_list_choice == "Predict":
        cfg.SPECIES_LIST = species.getSpeciesList(lat, lon, week, sf_thresh)
    else:
        cfg.SPECIES_LIST = []

    # Analyze
    # NOTE(review): confirm analyze.batchAnalyze exists in analyze.py; it is
    # not referenced by the revision this replaces.
    return analyze.batchAnalyze()
64
+
65
+
66
+ # Rest of the code
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
if __name__ == "__main__":

    # Builds a two-section Gradio page (single file / batch) and serves it.
    # Each click handler's `inputs` list must match the callback's positional
    # parameters one-to-one, in order.
    with gr.Blocks() as demo:

        gr.Markdown("### Single File Analysis")
        with gr.Column():
            audio_input = gr.Audio(type="filepath", label="Audio file")
            confidence_slider = gr.Slider(0, 1, 0.5, step=0.01, label="Minimum confidence")
            sensitivity_slider = gr.Slider(0.5, 1.5, 1, step=0.01, label="Sensitivity")
            overlap_slider = gr.Slider(0, 3, 0, step=0.01, label="Overlap")

            species_list_radio = gr.Radio(["All", "Custom", "Predict"], value="All", label="Species List")
            species_file_input = gr.File(label="Species list file")

            lat_number = gr.Number(label="Latitude")
            lon_number = gr.Number(label="Longitude")
            week_number = gr.Number(label="Week")
            # Feeds the callback's `use_yearlong` parameter, which sits between
            # `week` and `sf_thresh` in its signature.
            yearlong_checkbox = gr.Checkbox(label="Year-round")
            sf_thresh_number = gr.Number(label="Location filter threshold")
            # Feeds the callback's trailing `custom_classifier_file` parameter.
            classifier_file_input = gr.File(label="Custom classifier")

            output_df = gr.Dataframe(headers=["Start", "End", "Species", "Confidence"])

            analyze_button = gr.Button("Analyze")
            analyze_button.click(
                runSingleFileAnalysis,
                inputs=[
                    audio_input,
                    confidence_slider,
                    sensitivity_slider,
                    overlap_slider,
                    species_list_radio,
                    species_file_input,
                    lat_number,
                    lon_number,
                    week_number,
                    yearlong_checkbox,
                    sf_thresh_number,
                    classifier_file_input,
                ],
                outputs=output_df,
            )

        gr.Markdown("### Batch Analysis")
        with gr.Column():
            # The batch callback expects directory paths, so both directories
            # are entered as text.
            batch_input_dir = gr.Textbox(label="Input directory")
            batch_output_dir = gr.Textbox(label="Output directory")

            batch_confidence_slider = gr.Slider(0, 1, 0.5, step=0.01, label="Minimum confidence")
            batch_sensitivity_slider = gr.Slider(0.5, 1.5, 1, step=0.01, label="Sensitivity")
            batch_overlap_slider = gr.Slider(0, 3, 0, step=0.01, label="Overlap")

            batch_species_list_radio = gr.Radio(["All", "Custom", "Predict"], value="All", label="Species List")
            batch_species_file_input = gr.File(label="Species list file")

            batch_lat_number = gr.Number(label="Latitude")
            batch_lon_number = gr.Number(label="Longitude")
            batch_week_number = gr.Number(label="Week")
            batch_yearlong_checkbox = gr.Checkbox(label="Year-round")
            batch_sf_thresh_number = gr.Number(label="Location filter threshold")

            # Defaults mirror the values the previous single-file path passed
            # for batch size and threads (1 and 4).
            batch_size_number = gr.Number(1, label="Batch size", precision=0)
            threads_number = gr.Number(4, label="Threads", precision=0)

            batch_analyze_button = gr.Button("Analyze")
            batch_analyze_button.click(
                runBatchAnalysis,
                inputs=[
                    batch_input_dir,
                    batch_output_dir,
                    batch_confidence_slider,
                    batch_sensitivity_slider,
                    batch_overlap_slider,
                    batch_species_list_radio,
                    batch_species_file_input,
                    batch_lat_number,
                    batch_lon_number,
                    batch_week_number,
                    batch_yearlong_checkbox,
                    batch_sf_thresh_number,
                    batch_size_number,
                    threads_number,
                ],
            )

    demo.launch()