fevot commited on
Commit
6ee41b9
·
verified ·
1 Parent(s): d13a9f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -6
app.py CHANGED
@@ -85,11 +85,11 @@ model.load_state_dict(torch.load('model_weights.pth', map_location=device))
85
  model.eval()
86
 
87
  # Prediction function for Gradio
88
- def predict_bird(audio_file):
89
- if audio_file is None:
90
  return "Please upload an MP3 file."
91
 
92
- predictions = infer_birdcall(model, audio_file, segment_length=500, device=str(device))
93
 
94
  # Format the predictions with numbering
95
  if not predictions:
@@ -98,13 +98,28 @@ def predict_bird(audio_file):
98
  numbered_predictions = [f"{i+1}. {bird}" for i, bird in enumerate(predictions)]
99
  return "\n".join(numbered_predictions)
100
 
 
 
 
 
 
 
101
  # Create Gradio Blocks for more complex layout
102
  with gr.Blocks() as demo:
103
  gr.Markdown("# Bird Call Identification")
104
 
105
  with gr.Row():
106
- with gr.Column():
107
- audio_input = gr.Audio(type="filepath", label="Upload Bird Call Audio")
 
 
 
 
 
 
 
 
 
108
 
109
  with gr.Row():
110
  submit_btn = gr.Button("Identify Birds")
@@ -130,7 +145,7 @@ with gr.Blocks() as demo:
130
  # Set up the prediction event
131
  submit_btn.click(
132
  fn=predict_bird,
133
- inputs=audio_input,
134
  outputs=output_text
135
  )
136
 
 
85
  model.eval()
86
 
87
  # Prediction function for Gradio
88
+ def predict_bird(file_path):
89
+ if file_path is None:
90
  return "Please upload an MP3 file."
91
 
92
+ predictions = infer_birdcall(model, file_path, segment_length=500, device=str(device))
93
 
94
  # Format the predictions with numbering
95
  if not predictions:
 
98
  numbered_predictions = [f"{i+1}. {bird}" for i, bird in enumerate(predictions)]
99
  return "\n".join(numbered_predictions)
100
 
101
+ # Function to handle file upload and return path for audio player
102
+ def process_upload(file_obj):
103
+ if file_obj is None:
104
+ return None
105
+ return file_obj
106
+
107
  # Create Gradio Blocks for more complex layout
108
  with gr.Blocks() as demo:
109
  gr.Markdown("# Bird Call Identification")
110
 
111
  with gr.Row():
112
+ upload_file = gr.File(label="Upload MP3 file", file_types=[".mp3"])
113
+
114
+ with gr.Row():
115
+ audio_player = gr.Audio(label="Listen to bird call", type="filepath", interactive=False)
116
+
117
+ # Connect file upload to audio player
118
+ upload_file.change(
119
+ fn=process_upload,
120
+ inputs=upload_file,
121
+ outputs=audio_player
122
+ )
123
 
124
  with gr.Row():
125
  submit_btn = gr.Button("Identify Birds")
 
145
  # Set up the prediction event
146
  submit_btn.click(
147
  fn=predict_bird,
148
+ inputs=upload_file,
149
  outputs=output_text
150
  )
151