fosters committed on
Commit
39ec782
·
verified ·
1 Parent(s): 684b36b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -28
app.py CHANGED
@@ -4,42 +4,58 @@ import os
4
  import torch
5
 
6
  # --- Performance Improvement ---
7
- # 1. Determine the number of available CPU cores.
8
- num_cpu_cores = os.cpu_count()
9
-
10
- # 2. Configure PyTorch to use all available CPU cores for its operations.
11
- # This is crucial for speeding up model inference on a CPU.
12
- if num_cpu_cores is not None:
13
- torch.set_num_threads(num_cpu_cores)
14
- print(f"✅ PyTorch is configured to use {num_cpu_cores} CPU cores.")
15
- else:
16
- print("Could not determine the number of CPU cores. Using default settings.")
17
-
18
- # Initialize the audio classification pipeline with the MIT model
19
- # The pipeline will run on the CPU by default
20
- pipe = pipeline("audio-classification", model="MIT/ast-finetuned-audioset-10-10-0.4593")
21
-
22
- # Define the function to classify an audio file and return the top 3 results
23
  def classify_audio(audio_filepath):
24
  """
25
- Classifies the audio file and returns a dictionary of the top 3 predictions.
 
26
  """
 
 
 
27
  preds = pipe(audio_filepath)
28
- # The pipeline returns a sorted list of predictions. We take the top 3.
29
- top_3_preds = preds[:3]
30
- # Format the output as a dictionary of {label: score} for the gr.Label component
31
- output_labels = {p["label"]: p["score"] for p in top_3_preds}
32
- return output_labels
33
 
34
- # Set up the Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
35
  app = gr.Interface(
36
- fn=classify_audio, # Function to classify audio
37
- inputs=gr.Audio(type="filepath", label="Upload Audio File"), # Input for uploading an audio file
38
- outputs=gr.Label(label="Top 3 Predictions"), # Output Label will display the dictionary from the function
39
  title="Audio Classification with MIT/AST",
40
- description="Upload an audio file to classify it. The model will identify the top 3 most likely sound categories."
 
 
 
 
41
  )
42
 
43
- # Launch the app with a shareable link, required for Hugging Face Spaces
 
44
  if __name__ == "__main__":
45
  app.launch(share=True)
 
4
  import torch
5
 
6
  # --- Performance Improvement ---
7
+ # Configure PyTorch for CPU performance
8
+ num_cpu_cores = os.cpu_count() or 1 # Default to 1 if os.cpu_count() is None
9
+ torch.set_num_threads(num_cpu_cores)
10
+ print(f"✅ PyTorch is configured to use {num_cpu_cores} CPU cores.")
11
+
12
+
13
+ # --- Model and Pipeline ---
14
+ # Initialize the pipeline. It will default to the CPU.
15
+ # NOTE: no revision is pinned in this call; pass `revision=` to pipeline() for reproducibility
16
+ pipe = pipeline(
17
+ "audio-classification",
18
+ model="MIT/ast-finetuned-audioset-10-10-0.4593"
19
+ )
20
+
21
+
22
+ # --- Core Logic Function ---
23
  def classify_audio(audio_filepath):
24
  """
25
+ Classifies the audio, takes the top 3 predictions,
26
+ and formats them into a single, human-readable string.
27
  """
28
+ if audio_filepath is None:
29
+ return "Please upload an audio file first."
30
+
31
  preds = pipe(audio_filepath)
 
 
 
 
 
32
 
33
+ # Format the output as a string instead of a dictionary
34
+ # This is the key change to fix the TypeError
35
+ output_str = ""
36
+ for i, pred in enumerate(preds[:3]):
37
+ label = pred["label"]
38
+ score = pred["score"]
39
+ output_str += f"{i+1}. {label}: {score:.4f}\n"
40
+
41
+ return output_str.strip()
42
+
43
+
44
+ # --- Gradio Interface ---
45
+ # Create the Gradio app interface
46
  app = gr.Interface(
47
+ fn=classify_audio,
48
+ inputs=gr.Audio(type="filepath", label="Upload Audio File"),
49
+ outputs=gr.Label(label="Top 3 Predictions"), # This will now receive a simple string
50
  title="Audio Classification with MIT/AST",
51
+ description=(
52
+ "Upload an audio file to classify it. The model will identify the top 3 most likely sound categories. "
53
+ "This version is corrected to avoid common Gradio backend errors."
54
+ ),
55
+ cache_examples=False,
56
  )
57
 
58
+ # --- App Launch ---
59
+ # Launch the app with sharing enabled for Hugging Face Spaces
60
  if __name__ == "__main__":
61
  app.launch(share=True)