fosters committed on
Commit
d5d5544
·
verified ·
1 Parent(s): 22b2824

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -23
app.py CHANGED
@@ -1,39 +1,19 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
- import os
4
- import torch
5
-
6
- # --- Performance Improvement ---
7
- # 1. Determine the number of available CPU cores.
8
- num_cpu_cores = os.cpu_count()
9
-
10
- # 2. Configure PyTorch to use all available CPU cores for its operations.
11
- # This is crucial for speeding up model inference on a CPU.
12
- if num_cpu_cores is not None:
13
- torch.set_num_threads(num_cpu_cores)
14
- print(f"✅ PyTorch is configured to use {num_cpu_cores} CPU cores.")
15
- else:
16
- print("Could not determine the number of CPU cores. Using default settings.")
17
 
18
  # Initialize the audio classification pipeline with the MIT model
19
  pipe = pipeline("audio-classification", model="MIT/ast-finetuned-audioset-10-10-0.4593")
20
 
21
- # Define the function to classify an audio file and return the top 3 results
22
  def classify_audio(audio):
23
  result = pipe(audio)
24
- # The pipeline returns a list of dicts sorted by score.
25
- # We select the top 3 results here.
26
- top_3_results = result[:3]
27
- # Convert the list of dicts to a single dictionary for the Label component
28
- return {label['label']: label['score'] for label in top_3_results}
29
 
30
  # Set up the Gradio interface
31
- # We removed `num_top_classes=3` from `gr.Label` and instead handle the
32
- # top-3 logic inside the `classify_audio` function. This avoids the bug.
33
  app = gr.Interface(
34
  fn=classify_audio, # Function to classify audio
35
  inputs=gr.Audio(type="filepath"), # Input for uploading an audio file
36
- outputs=gr.Label(), # Output Label will display the dictionary from the function
37
  title="Audio Classification", # App title
38
  description="Upload an audio file to classify it using MIT's fine-tuned AudioSet model."
39
  )
 
1
  import gradio as gr
2
  from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
# Pretrained checkpoint used for audio classification.
MODEL_ID = "MIT/ast-finetuned-audioset-10-10-0.4593"

# Build the Hugging Face audio-classification pipeline once at import time
# so every request reuses the same loaded model.
pipe = pipeline("audio-classification", model=MODEL_ID)
6
 
7
# Classify an audio file with the module-level pipeline.
def classify_audio(audio):
    """Run the audio-classification pipeline on *audio* (a file path)
    and map each predicted label to its confidence score."""
    predictions = pipe(audio)
    return {prediction['label']: prediction['score'] for prediction in predictions}
 
 
 
 
11
 
12
# Wire the classifier into a Gradio UI: an audio-file upload on the input
# side, and a Label component that displays the three highest-scoring
# predictions from the dictionary returned by `classify_audio`.
app = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(type="filepath"),
    outputs=gr.Label(num_top_classes=3),
    title="Audio Classification",
    description="Upload an audio file to classify it using MIT's fine-tuned AudioSet model.",
)