AB739 commited on
Commit
2e96e8a
·
verified ·
1 Parent(s): 14660ec

Create audio_openvino.py

Browse files
Files changed (1) hide show
  1. tasks/audio_openvino.py +118 -0
tasks/audio_openvino.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Standard library
import os
from datetime import datetime

# Third-party
import numpy as np  # added: np.argmax is used in the inference loop below
import torch
from datasets import load_dataset
from dotenv import load_dotenv
from fastapi import APIRouter
from openvino.runtime import Core
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, TensorDataset
from torchaudio import transforms
# from torchvision import models
# import onnxruntime as ort  # Add ONNX Runtime

# Local
from .utils.evaluation import AudioEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info

load_dotenv()
# Router the main FastAPI app mounts; the /audio endpoint below registers on it.
router = APIRouter()

# Short human-readable model description, echoed back in the results payload.
DESCRIPTION = "Tiny_DNN"
# URL path under which the audio evaluation endpoint is exposed.
ROUTE = "/audio"

# Cap torch's intra-op / inter-op CPU thread pools so spectrogram extraction
# does not oversubscribe the (typically small) CPU allocation of the host.
# NOTE(review): set_num_interop_threads must run before any parallel torch
# work or it raises — keep these at import time.
torch.set_num_threads(4)
torch.set_num_interop_threads(2)
26
@router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
async def evaluate_audio(request: AudioEvaluationRequest):
    """Evaluate a binary audio classifier (chainsaw vs. environment) with OpenVINO.

    Pipeline: load the requested dataset and reproduce the organizer's
    train/test split, convert every test clip into a fixed-size log-mel
    spectrogram, run batched CPU inference on the compiled OpenVINO model
    while tracking emissions, then report accuracy plus energy data.

    Args:
        request: carries ``dataset_name``, ``test_size`` and ``test_seed``.

    Returns:
        dict: submission payload (username, accuracy, emissions data,
        dataset configuration, ...).
    """
    # Identify the submitting user / Space for the results payload.
    username, space_url = get_space_info()

    # Reproduce the reference train/test split so scores are comparable.
    dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
    train_test = dataset["train"].train_test_split(
        test_size=request.test_size, seed=request.test_seed
    )
    test_dataset = train_test["test"]
    # NOTE(review): labels are used as-is, assuming the dataset already
    # stores integer ids (0 = chainsaw, 1 = environment) — confirm upstream.
    true_labels = test_dataset["label"]

    # Feature extraction: resample 12 kHz -> 16 kHz, then 64-bin log-mel.
    resampler = transforms.Resample(orig_freq=12000, new_freq=16000)
    mel_transform = transforms.MelSpectrogram(sample_rate=16000, n_mels=64)
    amplitude_to_db = transforms.AmplitudeToDB()

    def _resize_audio(waveform, target_length):
        # Stretch/compress a clip to exactly `target_length` samples by
        # abusing Resample as a length normalizer, so every spectrogram
        # ends up with the same time dimension (the model needs it fixed).
        num_frames = waveform.shape[-1]
        if num_frames != target_length:
            waveform = transforms.Resample(
                orig_freq=num_frames, new_freq=target_length
            )(waveform)
        return waveform

    # Normalize every clip to 72000 samples before resampling/featurizing.
    resized_waveforms = [
        _resize_audio(
            torch.tensor(sample["audio"]["array"], dtype=torch.float32).unsqueeze(0),
            target_length=72000,
        )
        for sample in test_dataset
    ]

    # Log-mel spectrograms, in the same order as true_labels.
    waveforms = torch.stack(
        [amplitude_to_db(mel_transform(resampler(w))) for w in resized_waveforms]
    )
    labels = torch.tensor(true_labels)

    test_loader = DataLoader(
        TensorDataset(waveforms, labels),
        batch_size=128,
        shuffle=False,  # keep batch order aligned with true_labels
        pin_memory=True,
        num_workers=4,
    )

    # Compile the exported OpenVINO IR model for CPU inference.
    # (Removed the unused `input_layer` handle the original fetched.)
    core = Core()
    model_path = "./openvino_model/model.xml"
    compiled_model = core.compile_model(model=model_path, device_name="CPU")
    output_layer = compiled_model.output(0)

    # Track emissions around the inference loop only.
    tracker.start()
    tracker.start_task("inference")

    # OpenVINO inference: per-batch argmax over the class scores.
    predictions = []
    for data, _labels in test_loader:
        # Model expects (N, 1, 64, 481): 1 channel, 64 mel bins, 481 frames.
        inputs = data.numpy().reshape((-1, 1, 64, 481))
        output = compiled_model([inputs])[output_layer]
        # Fix: `np` was referenced here without ever being imported,
        # which raised NameError on the first batch.
        predicted = np.argmax(output, axis=1)
        predictions.extend(predicted.tolist())

    # Stop tracking emissions
    emissions_data = tracker.stop_task()

    # Calculate accuracy
    accuracy = accuracy_score(true_labels, predictions)

    # Prepare results dictionary
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        # Tracker reports kWh / kg; convert to Wh / g for the leaderboard.
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }

    return results