fosters committed on
Commit
5266751
·
verified ·
1 Parent(s): fefad81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -11
app.py CHANGED
@@ -16,24 +16,31 @@ else:
16
  print("Could not determine the number of CPU cores. Using default settings.")
17
 
18
  # Initialize the audio classification pipeline with the MIT model
 
19
  pipe = pipeline("audio-classification", model="MIT/ast-finetuned-audioset-10-10-0.4593")
20
 
21
  # Define the function to classify an audio file and return the top 3 results
22
- def classify_audio(audio):
23
- result = pipe(audio)
24
- return {label['label']: label['score'] for label in result}
 
 
 
 
 
 
 
25
 
26
  # Set up the Gradio interface
27
- # We removed `num_top_classes=3` from `gr.Label` and instead handle the
28
- # top-3 logic inside the `classify_audio` function. This avoids the bug.
29
  app = gr.Interface(
30
  fn=classify_audio, # Function to classify audio
31
- inputs=gr.Audio(type="filepath"), # Input for uploading an audio file
32
- outputs=gr.Label(num_top_classes=3), # Output Label will display the dictionary from the function
33
- title="Audio Classification", # App title
34
- description="Upload an audio file to classify it using MIT's fine-tuned AudioSet model."
 
35
  )
36
 
37
- # Launch the app
38
  if __name__ == "__main__":
39
- app.launch()
 
16
  print("Could not determine the number of CPU cores. Using default settings.")
17
 
18
  # Initialize the audio classification pipeline with the MIT model
19
+ # The pipeline will run on the CPU by default
20
  pipe = pipeline("audio-classification", model="MIT/ast-finetuned-audioset-10-10-0.4593")
21
 
22
  # Define the function to classify an audio file and return the top 3 results
23
+ def classify_audio(audio_filepath):
24
+ """
25
+ Classifies the audio file and returns a dictionary of the top 3 predictions.
26
+ """
27
+ preds = pipe(audio_filepath)
28
+ # The pipeline returns a sorted list of predictions. We take the top 3.
29
+ top_3_preds = preds[:3]
30
+ # Format the output as a dictionary of {label: score} for the gr.Label component
31
+ output_labels = {p["label"]: p["score"] for p in top_3_preds}
32
+ return output_labels
33
 
34
  # Set up the Gradio interface
 
 
35
  app = gr.Interface(
36
  fn=classify_audio, # Function to classify audio
37
+ inputs=gr.Audio(type="filepath", label="Upload Audio File"), # Input for uploading an audio file
38
+ outputs=gr.Label(label="Top 3 Predictions"), # Output Label will display the dictionary from the function
39
+ title="Audio Classification with MIT/AST",
40
+ description="Upload an audio file to classify it. The model will identify the top 3 most likely sound categories.",
41
+ ]
42
  )
43
 
44
+ # Launch the app; share=True requests a public share link (not required on Hugging Face Spaces, which already hosts the app)
45
  if __name__ == "__main__":
46
+ app.launch(share=True)