fosters commited on
Commit
fefad81
·
verified ·
1 Parent(s): d5d5544

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -2
app.py CHANGED
@@ -1,19 +1,35 @@
1
  import gradio as gr
2
  from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # Initialize the audio classification pipeline with the MIT model
5
  pipe = pipeline("audio-classification", model="MIT/ast-finetuned-audioset-10-10-0.4593")
6
 
7
- # Define the function to classify an audio file
8
  def classify_audio(audio):
9
  result = pipe(audio)
10
  return {label['label']: label['score'] for label in result}
11
 
12
  # Set up the Gradio interface
 
 
13
  app = gr.Interface(
14
  fn=classify_audio, # Function to classify audio
15
  inputs=gr.Audio(type="filepath"), # Input for uploading an audio file
16
- outputs=gr.Label(num_top_classes=3), # Output with top 3 classification results
17
  title="Audio Classification", # App title
18
  description="Upload an audio file to classify it using MIT's fine-tuned AudioSet model."
19
  )
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ import os
4
+ import torch
5
+
6
+ # --- Performance Improvement ---
7
+ # 1. Determine the number of available CPU cores.
8
+ num_cpu_cores = os.cpu_count()
9
+
10
+ # 2. Configure PyTorch to use all available CPU cores for its operations.
11
+ # This is crucial for speeding up model inference on a CPU.
12
+ if num_cpu_cores is not None:
13
+ torch.set_num_threads(num_cpu_cores)
14
+ print(f"✅ PyTorch is configured to use {num_cpu_cores} CPU cores.")
15
+ else:
16
+ print("Could not determine the number of CPU cores. Using default settings.")
17
 
18
  # Initialize the audio classification pipeline with the MIT model
19
  pipe = pipeline("audio-classification", model="MIT/ast-finetuned-audioset-10-10-0.4593")
20
 
21
+ # Define the function to classify an audio file and return the top 3 results
22
  def classify_audio(audio):
23
  result = pipe(audio)
24
  return {label['label']: label['score'] for label in result}
25
 
26
  # Set up the Gradio interface
27
+ # We removed `num_top_classes=3` from `gr.Label` and instead handle the
28
+ # top-3 logic inside the `classify_audio` function. This avoids the bug.
29
  app = gr.Interface(
30
  fn=classify_audio, # Function to classify audio
31
  inputs=gr.Audio(type="filepath"), # Input for uploading an audio file
32
+ outputs=gr.Label(num_top_classes=3), # Output Label will display the dictionary from the function
33
  title="Audio Classification", # App title
34
  description="Upload an audio file to classify it using MIT's fine-tuned AudioSet model."
35
  )