AIOmarRehan committed on
Commit
afb665f
·
verified ·
1 Parent(s): 358bfd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -25
app.py CHANGED
@@ -1,19 +1,22 @@
1
  import gradio as gr
2
  import numpy as np
3
  from PIL import Image
4
- import random
5
  import io
 
 
6
  from collections import Counter, defaultdict
7
  from datasets import load_dataset
8
  from app.model import predict
9
  from app.preprocess import preprocess_audio
 
10
 
11
- # Load Hugging Face datasets directly
12
- audio_ds = load_dataset("AIOmarRehan/General_Audio_Dataset")
13
- image_ds = load_dataset("AIOmarRehan/Mel_Spectrogram_Images_for_Audio_Classification")
14
 
15
- # Helper function to safely load images
16
  def safe_load_image(img):
 
17
  if img is None:
18
  return None
19
  if isinstance(img, np.ndarray):
@@ -21,13 +24,11 @@ def safe_load_image(img):
21
  img = img.convert("RGBA")
22
  return img
23
 
24
- # Process spectrogram image
25
  def process_image_input(img):
26
  img = safe_load_image(img)
27
  label, confidence, probs = predict(img)
28
  return label, round(confidence, 3), probs
29
 
30
- # Process raw audio
31
  def process_audio_input(audio_path):
32
  imgs = preprocess_audio(audio_path)
33
  all_preds, all_confs, all_probs = [], [], []
@@ -55,28 +56,44 @@ def process_audio_input(audio_path):
55
  final_conf = float(np.mean([all_confs[i] for i, lbl in enumerate(all_preds) if lbl == final_label]))
56
  return final_label, round(final_conf, 3), all_preds, [round(c, 3) for c in all_confs]
57
 
58
- # Main classifier
59
  def classify(audio_path, image, random_audio=False, random_image=False):
60
- # Pick random audio from HF dataset
61
  if random_audio and len(audio_ds) > 0:
62
- sample = random.choice(audio_ds)
63
- # If dataset stores audio as file path or array
64
- if isinstance(sample["audio"], dict) and "path" in sample["audio"]:
65
- audio_path = sample["audio"]["path"]
66
- elif isinstance(sample["audio"], dict) and "array" in sample["audio"]:
67
- # Save array temporarily
68
- import soundfile as sf
69
- audio_path = "/tmp/random_audio.wav"
70
- sf.write(audio_path, sample["audio"]["array"], sample["audio"]["sampling_rate"])
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- # Pick random image from HF dataset
73
  if random_image and len(image_ds) > 0:
74
- sample = random.choice(image_ds)
75
- # Handle image bytes
76
- img_bytes = sample["image"] if isinstance(sample["image"], bytes) else sample["image"].tobytes()
77
- image = Image.open(io.BytesIO(img_bytes)).convert("RGBA")
 
 
 
 
 
78
 
79
- # If spectrogram image
80
  if image is not None:
81
  label, conf, probs = process_image_input(image)
82
  return {
@@ -85,7 +102,7 @@ def classify(audio_path, image, random_audio=False, random_image=False):
85
  "Details": probs
86
  }, label
87
 
88
- # If raw audio
89
  if audio_path is not None:
90
  label, conf, all_preds, all_confs = process_audio_input(audio_path)
91
  return {
 
1
  import gradio as gr
2
  import numpy as np
3
  from PIL import Image
 
4
  import io
5
+ import random
6
+ import tempfile
7
  from collections import Counter, defaultdict
8
  from datasets import load_dataset
9
  from app.model import predict
10
  from app.preprocess import preprocess_audio
11
+ import soundfile as sf
12
 
13
+ # Load Hugging Face datasets
14
+ audio_ds = load_dataset("AIOmarRehan/General_Audio_Dataset", split="train")
15
+ image_ds = load_dataset("AIOmarRehan/Mel_Spectrogram_Images_for_Audio_Classification", split="train")
16
 
17
+ # Helper functions
18
  def safe_load_image(img):
19
+ """Ensure input is PIL RGBA image"""
20
  if img is None:
21
  return None
22
  if isinstance(img, np.ndarray):
 
24
  img = img.convert("RGBA")
25
  return img
26
 
 
27
  def process_image_input(img):
28
  img = safe_load_image(img)
29
  label, confidence, probs = predict(img)
30
  return label, round(confidence, 3), probs
31
 
 
32
  def process_audio_input(audio_path):
33
  imgs = preprocess_audio(audio_path)
34
  all_preds, all_confs, all_probs = [], [], []
 
56
  final_conf = float(np.mean([all_confs[i] for i, lbl in enumerate(all_preds) if lbl == final_label]))
57
  return final_label, round(final_conf, 3), all_preds, [round(c, 3) for c in all_confs]
58
 
59
+ # Main classifier function
60
  def classify(audio_path, image, random_audio=False, random_image=False):
61
+ # Random audio selection
62
  if random_audio and len(audio_ds) > 0:
63
+ try:
64
+ sample = random.choice(audio_ds)
65
+ # Dataset may store audio as path or array
66
+ audio_obj = sample["audio"]
67
+ if isinstance(audio_obj, dict) and "path" in audio_obj:
68
+ audio_path = audio_obj["path"]
69
+ elif isinstance(audio_obj, dict) and "array" in audio_obj:
70
+ # Save temporarily
71
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
72
+ audio_path = tmpfile.name
73
+ sf.write(audio_path, audio_obj["array"], audio_obj["sampling_rate"])
74
+ else:
75
+ # fallback: datasets.Audio object
76
+ audio_array, sr = audio_obj["array"], audio_obj["sampling_rate"]
77
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
78
+ audio_path = tmpfile.name
79
+ sf.write(audio_path, audio_array, sr)
80
+ except Exception as e:
81
+ print("Error loading random audio:", e)
82
+ audio_path = None
83
 
84
+ # Random image selection
85
  if random_image and len(image_ds) > 0:
86
+ try:
87
+ sample = random.choice(image_ds)
88
+ img_obj = sample["image"]
89
+ if not isinstance(img_obj, Image.Image):
90
+ img_obj = Image.fromarray(img_obj) # convert ndarray to PIL
91
+ image = img_obj.convert("RGBA")
92
+ except Exception as e:
93
+ print("Error loading random image:", e)
94
+ image = None
95
 
96
+ # Process spectrogram image
97
  if image is not None:
98
  label, conf, probs = process_image_input(image)
99
  return {
 
102
  "Details": probs
103
  }, label
104
 
105
+ # Process raw audio
106
  if audio_path is not None:
107
  label, conf, all_preds, all_confs = process_audio_input(audio_path)
108
  return {