gargaman07 committed on
Commit
75f5a09
·
verified ·
1 Parent(s): 0de5893

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +500 -0
app.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import csv
import json
import logging
import os
import re
import shutil
import tempfile
from pathlib import Path

import librosa
import numpy as np
import requests
import tensorflow as tf
import tensorflow_hub as hub
from acrcloud.recognizer import ACRCloudRecognizer
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from pydub import AudioSegment
from tensorflow.keras.models import load_model
21
+
22
# Application object; every route in this module is registered on it.
app = FastAPI()

# CORS: any origin, method and header is accepted.
# NOTE(review): FastAPI's CORS documentation discourages combining wildcard
# origins with allow_credentials=True -- confirm whether credentialed
# cross-origin requests are actually required.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Templates (index.html) are looked up in the working directory.
templates = Jinja2Templates(directory=".")

# Keras model loaded once at import time.
model = load_model("./models/neural_networks.h5")
35
# ACRCloud configuration for the SDK recognizer.
# NOTE(security): the literal key/secret below are credentials committed to
# source control. They remain only as fallbacks so existing deployments keep
# starting; set ACRCLOUD_ACCESS_KEY / ACRCLOUD_ACCESS_SECRET (and optionally
# ACRCLOUD_HOST) in the environment, then rotate and delete the literals.
ACRCLOUD_CONFIG = {
    'host': os.environ.get('ACRCLOUD_HOST', 'identify-ap-southeast-1.acrcloud.com'),
    'access_key': os.environ.get('ACRCLOUD_ACCESS_KEY', 'c529996b7457352ca72e2ccb1fcbc4dd'),
    'access_secret': os.environ.get('ACRCLOUD_ACCESS_SECRET', 'MQitmw327GTfkoLhCzk90Uwcf2dL0DGhUvQvQwS0'),
    'timeout': 1,  # seconds
}
acr_recognizer = ACRCloudRecognizer(ACRCLOUD_CONFIG)
43
+
44
# Load YAMNet model and labels
yamnet_model_handle = 'https://tfhub.dev/google/yamnet/1'
yamnet_model = hub.load(yamnet_model_handle)

# Parse the class map with the csv module: a plain split(",") cuts any quoted
# display name that contains a comma at its first comma and leaves the opening
# quote character in the label.
with open("yamnet_class_map.csv", "r", newline="", encoding="utf-8") as f:
    class_map_rows = csv.reader(f)
    next(class_map_rows, None)  # skip the header row
    # The third column is used as the label; blank lines are ignored.
    yamnet_classes = [row[2] for row in class_map_rows if row]
50
+
51
+ # # Set up ffmpeg path
52
+ # FFMPEG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ffmpeg-master-latest-win64-gpl", "bin")
53
+ # if os.path.exists(FFMPEG_PATH):
54
+ # os.environ["PATH"] = FFMPEG_PATH + os.pathsep + os.environ["PATH"]
55
+ # AudioSegment.converter = os.path.join(FFMPEG_PATH, "ffmpeg.exe")
56
+ # AudioSegment.ffmpeg = os.path.join(FFMPEG_PATH, "ffmpeg.exe")
57
+ # AudioSegment.ffprobe = os.path.join(FFMPEG_PATH, "ffprobe.exe")
58
+
59
+ # Comment out or remove the Windows-specific FFMPEG_PATH setup
60
+ # In Docker, ffmpeg will be installed via apt-get and should be in the PATH
61
+ # pydub should find it automatically.
62
+ # If issues arise, one might need to set AudioSegment.converter explicitly,
63
+ # but without the Windows-specific path.
64
+ # For example:
65
+ # AudioSegment.converter = "/usr/bin/ffmpeg" # or wherever ffmpeg is installed
66
+ # AudioSegment.ffmpeg = "/usr/bin/ffmpeg"
67
+ # AudioSegment.ffprobe = "/usr/bin/ffprobe"
68
+ # However, this is often not needed if ffmpeg is in the system PATH.
69
+
70
def extract_features(audio_path, max_length=100):
    """Return one fixed-width MFCC matrix per non-silent segment of a file.

    The audio is loaded at its native sample rate. A normalized copy is used
    only to locate the non-silent intervals (``top_db=20``); 13 MFCCs are then
    computed from the un-normalized samples of each interval. Every matrix is
    truncated or zero-padded along its second axis to ``max_length`` columns.

    Args:
        audio_path: Path of the audio file to analyse.
        max_length: Number of columns each returned matrix has.

    Returns:
        A list with one array per detected interval (empty if none is found).
    """
    signal, sample_rate = librosa.load(audio_path, sr=None)
    intervals = librosa.effects.split(librosa.util.normalize(signal), top_db=20)

    def _fit_width(matrix):
        # Cut when too wide, otherwise pad the right-hand side with zeros.
        missing = max_length - matrix.shape[1]
        if missing < 0:
            return matrix[:, :max_length]
        return np.pad(matrix, pad_width=((0, 0), (0, missing)), mode='constant')

    return [
        _fit_width(
            librosa.feature.mfcc(y=signal[begin:finish], sr=sample_rate, n_mfcc=13)
        )
        for begin, finish in intervals
    ]
87
+
88
def predict_vehicle_class(audio_path):
    """Classify an audio file with the Keras model and return the class index.

    Features from ``extract_features`` are standardized, the model is run on
    every segment, and the per-segment predictions are averaged before taking
    the arg-max.

    Args:
        audio_path: Path of the audio file to classify.

    Returns:
        The predicted class index as a plain Python ``int``.

    Raises:
        ValueError: If no non-silent segment could be extracted, so there is
            nothing to feed to the model.
    """
    segment_features = extract_features(audio_path)
    if not segment_features:
        # Without this guard the mean/std below are NaN and the model call
        # fails with an error unrelated to the real cause.
        raise ValueError("No non-silent audio segments found to classify.")

    features = np.array(segment_features)

    # Standardize with the statistics of THIS clip (not the training set).
    # Consider saving the training statistics if accuracy matters.
    spread = np.std(features)
    if spread == 0:
        spread = 1.0  # constant features: avoid dividing by zero
    features = (features - np.mean(features)) / spread

    # Average predictions across all segments
    predictions = model.predict(features)
    averaged_prediction = np.mean(predictions, axis=0)

    # Convert numpy.int64 to a Python int so the value is JSON-serializable.
    return int(np.argmax(averaged_prediction))
101
+
102
def convert_audio_to_wav(src_path: str, dst_path: str) -> bool:
    """Convert an audio file to a 16 kHz mono 16-bit PCM WAV file via pydub.

    The input format handed to pydub is the lower-cased extension of
    ``src_path`` without its leading dot.

    Args:
        src_path: Path of the file to read.
        dst_path: Path the WAV file is written to.

    Returns:
        True when the conversion succeeded; False when any step raised (the
        error is logged).
    """
    mono_16k = ("-ar", "16000", "-ac", "1")  # 16 kHz sample rate, one channel
    try:
        _, dotted_ext = os.path.splitext(src_path)
        source_format = dotted_ext.lower().lstrip('.')

        decoded = AudioSegment.from_file(
            src_path,
            format=source_format,
            parameters=[*mono_16k],
        )
        decoded.export(
            dst_path,
            format="wav",
            parameters=[*mono_16k, "-acodec", "pcm_s16le"],
        )
    except Exception as e:
        logging.error(f"Error converting audio file: {str(e)}")
        return False
    return True
125
+
126
def classify_audio_with_yamnet(file_path):
    """Classify an audio file with YAMNet and return its five best labels.

    The input is first converted to a temporary WAV file, loaded at 16 kHz and
    passed to the YAMNet model; the scores are averaged over time.

    Args:
        file_path: Path of the audio file to classify.

    Returns:
        ``{"success": True, "top_classes": [(label, score), ...]}`` on success,
        otherwise ``{"success": False, "message": ...}``.
    """
    temp_wav_path = None
    try:
        # Reserve a temporary WAV path; the file itself is written by the converter.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_wav:
            temp_wav_path = temp_wav.name

        # Convert the input file to WAV if needed
        if not convert_audio_to_wav(file_path, temp_wav_path):
            return {
                "success": False,
                "message": "Failed to convert audio file to WAV format"
            }

        # Load and process the audio (YAMNet expects 16 kHz)
        waveform, _ = librosa.load(temp_wav_path, sr=16000)
        scores, _embeddings, _spectrogram = yamnet_model(waveform)
        scores_np = scores.numpy().mean(axis=0)  # average over time

        top5_i = np.argsort(scores_np)[::-1][:5]
        # Convert scores to Python floats so they are JSON-serializable.
        top_labels = [(yamnet_classes[i], float(scores_np[i])) for i in top5_i]

        return {
            "success": True,
            "top_classes": top_labels
        }
    except Exception as e:
        logging.exception("YAMNet classification failed:")
        return {
            "success": False,
            "message": f"Audio classification failed: {str(e)}"
        }
    finally:
        # Runs on every exit path, including the failed-conversion return, so
        # the temporary file is never left behind.
        if temp_wav_path and os.path.exists(temp_wav_path):
            try:
                os.unlink(temp_wav_path)
            except OSError:
                logging.warning("Could not remove temporary file %s", temp_wav_path)
163
+
164
def is_vehicle_sound(yamnet_classes):
    """Check whether any of the given classifications is vehicle-related.

    Keywords are matched case-insensitively as whole words. Plain substring
    matching would also fire on unrelated labels, e.g. 'car' inside
    "Scary music" or 'bus' inside "Busy signal".

    Args:
        yamnet_classes: Iterable of ``(class_name, score)`` pairs, checked in
            the order given.

    Returns:
        ``(True, class_name, score)`` for the first pair whose name contains a
        vehicle keyword, otherwise ``(False, None, 0.0)``.
    """
    vehicle_keywords = [
        # General vehicle terms
        'vehicle', 'automobile', 'motor vehicle',
        # Specific vehicle types
        'car', 'truck', 'bus', 'van', 'motorcycle', 'scooter',
        # Vehicle components
        'engine', 'motor', 'horn', 'siren', 'tire', 'wheel',
        # Vehicle sounds
        'revving', 'acceleration', 'braking', 'idling',
        # Transportation
        'transport', 'traffic', 'road'
    ]
    keyword_pattern = re.compile(
        r"\b(?:" + "|".join(re.escape(keyword) for keyword in vehicle_keywords) + r")\b",
        re.IGNORECASE,
    )

    # Log the top classifications for debugging
    logging.info("Top YAMNet classifications:")
    for class_name, score in yamnet_classes:
        logging.info(f"- {class_name}: {score:.2f}")

    # Return the first classification that mentions a vehicle keyword
    for class_name, score in yamnet_classes:
        if keyword_pattern.search(class_name):
            logging.info(f"Vehicle sound detected: '{class_name}' (score: {score:.2f})")
            return True, class_name, score

    logging.info("No vehicle sounds detected in the audio")
    return False, None, 0.0
197
+
198
@app.post("/classify/")
async def classify_audio(file: UploadFile = File(...)):
    """Classify an uploaded audio file as music, a vehicle sound or another sound.

    ACRCloud music recognition is tried first. If it fails, the file is
    classified with YAMNet; when one of the top labels is vehicle-related the
    neural network decides between "Car" (class 0) and "Truck" (any other
    class).

    Args:
        file: The uploaded audio file.

    Returns:
        A dict with ``success`` and, on success, ``type`` plus one of
        ``music_result`` / ``vehicle_result`` / ``sound_result``; on failure a
        ``message`` describing the problem.
    """
    file_content = await file.read()

    # Only the extension of the client-supplied name is reused (the converter
    # derives the input format from it). Building the path from the full name
    # would allow path traversal, collide between concurrent uploads of the
    # same name, and break when no filename is sent.
    suffix = Path(file.filename or "").suffix
    temp_filename = None

    try:
        with tempfile.NamedTemporaryFile(
            prefix="temp_classify_", suffix=suffix, delete=False
        ) as f:
            temp_filename = f.name
            f.write(file_content)

        # First try music recognition
        result_json_str = acr_recognizer.recognize_by_file(temp_filename, 0)
        music_result = format_acrcloud_response(result_json_str)

        if music_result["success"]:
            return {
                "success": True,
                "type": "music",
                "music_result": music_result
            }

        # Music recognition failed: fall back to YAMNet classification
        yamnet_result = classify_audio_with_yamnet(temp_filename)
        if not yamnet_result["success"]:
            return {
                "success": False,
                "message": "No music, vehicle, or sound patterns recognized."
            }

        is_vehicle, _matched_label, vehicle_score = is_vehicle_sound(yamnet_result["top_classes"])
        if is_vehicle:
            # Vehicle sound: use the neural network for the specific class
            vehicle_class = predict_vehicle_class(temp_filename)
            vehicle_type = "Car" if vehicle_class == 0 else "Truck"

            return {
                "success": True,
                "type": "vehicle",
                "vehicle_result": {
                    "vehicle_type": vehicle_type,
                    "detected_sound": vehicle_class,
                    # YAMNet score of the matched label, as a percentage
                    "confidence": float(vehicle_score) * 100
                }
            }

        # Not a vehicle sound: return the YAMNet classification
        return {
            "success": True,
            "type": "sound",
            "sound_result": yamnet_result
        }

    except Exception as e:
        logging.exception("Error during classification:")
        return {"success": False, "message": str(e)}

    finally:
        if temp_filename and os.path.exists(temp_filename):
            os.remove(temp_filename)
258
+
259
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render the ``index.html`` template for the site root."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
262
+
263
@app.post("/recognize/")
async def recognize_song_acr(file: UploadFile = File(...)):
    """Recognize an uploaded audio file with the ACRCloud SDK.

    Args:
        file: The uploaded audio file.

    Returns:
        The dict produced by ``format_acrcloud_response``, or
        ``{"success": False, "message": ...}`` if recognition raised.
    """
    file_content = await file.read()

    # Reuse only the extension of the client-supplied name: a path built from
    # the full name allows path traversal, collides between concurrent uploads
    # of the same name, and breaks when no filename is sent.
    suffix = Path(file.filename or "").suffix
    temp_filename = None

    try:
        with tempfile.NamedTemporaryFile(
            prefix="temp_recognize_", suffix=suffix, delete=False
        ) as buffer:
            temp_filename = buffer.name
            buffer.write(file_content)

        result_json_str = acr_recognizer.recognize_by_file(temp_filename, 0)

        return format_acrcloud_response(result_json_str)
    except Exception as e:
        logging.exception("Error during SDK ACRCloud recognition:")
        return {"success": False, "message": f"Recognition failed: {str(e)}"}
    finally:
        # Always remove the temporary copy of the upload
        if temp_filename and os.path.exists(temp_filename):
            os.remove(temp_filename)
282
+
283
@app.post("/upload/")
async def upload_song_acr(file: UploadFile = File(...), song_name: str = Form(None)):
    """Recognize an uploaded audio file and annotate the result with its name.

    Args:
        file: The uploaded audio file.
        song_name: Optional name supplied by the uploader. When given, it is
            added to the response as ``message_context`` on success, or woven
            into ``message`` on failure.

    Returns:
        The dict produced by ``format_acrcloud_response`` (possibly annotated),
        or ``{"success": False, "message": ...}`` if recognition raised.
    """
    file_content = await file.read()

    # Reuse only the extension of the client-supplied name: a path built from
    # the full name allows path traversal, collides between concurrent uploads
    # of the same name, and breaks when no filename is sent.
    suffix = Path(file.filename or "").suffix
    temp_filename = None

    try:
        with tempfile.NamedTemporaryFile(
            prefix="temp_upload_", suffix=suffix, delete=False
        ) as buffer:
            temp_filename = buffer.name
            buffer.write(file_content)

        result_json_str = acr_recognizer.recognize_by_file(temp_filename, 0)

        response_data = format_acrcloud_response(result_json_str)
        if song_name:
            if response_data.get("success"):
                response_data["message_context"] = f"Recognition for (originally uploaded as '{song_name}')"
            else:
                response_data["message"] = f"Recognition for (originally uploaded as '{song_name}') failed: {response_data.get('message')}"

        return response_data
    except Exception as e:
        logging.exception("Error during SDK ACRCloud upload/recognition:")
        return {"success": False, "message": f"Upload/Recognition failed: {str(e)}"}
    finally:
        if temp_filename and os.path.exists(temp_filename):
            os.remove(temp_filename)
307
+
308
@app.post("/recognize-live-chunk/")
async def recognize_live_chunk(file: UploadFile = File(...)):
    """Recognize a live-recorded audio chunk as music, vehicle or other sound.

    The raw bytes are sent to ACRCloud first. Without a song name in the
    result, the chunk is written to a temporary ``.webm`` file, converted to
    WAV and classified with YAMNet; vehicle-related sounds are further
    classified by the neural network ("Car" for class 0, "Truck" otherwise).

    Args:
        file: The uploaded audio chunk.

    Returns:
        A dict with ``success`` and, on success, ``type`` plus one of
        ``music_result`` / ``vehicle_result`` / ``sound_result``; on failure a
        ``message`` describing the problem.
    """
    file_content = await file.read()

    if not file_content:
        return {"success": False, "message": "Empty audio chunk received."}

    try:
        logging.info(f"Received live chunk, size: {len(file_content)} bytes, filename: {file.filename}")

        # First try music recognition
        result_json_str = acr_recognizer.recognize_by_filebuffer(file_content, 0)
        music_result = format_acrcloud_response(result_json_str)

        # Only a result that carries a song name counts as music
        if music_result["success"] and music_result.get("song_name"):
            return {
                "success": True,
                "type": "music",
                "music_result": music_result
            }

        # No valid music result: try YAMNet classification
        temp_filename = None
        wav_filename = None
        try:
            with tempfile.NamedTemporaryFile(suffix='.webm', delete=False) as temp_file:
                temp_filename = temp_file.name
                temp_file.write(file_content)

            # Swap only the trailing extension; str.replace would rewrite every
            # '.webm' occurring anywhere in the path.
            wav_filename = os.path.splitext(temp_filename)[0] + '.wav'

            # Convert to WAV first
            if convert_audio_to_wav(temp_filename, wav_filename):
                yamnet_result = classify_audio_with_yamnet(wav_filename)

                if yamnet_result["success"]:
                    is_vehicle, _matched_label, vehicle_score = is_vehicle_sound(yamnet_result["top_classes"])
                    if is_vehicle:
                        # Vehicle sound: use the neural network for the specific class
                        vehicle_class = predict_vehicle_class(wav_filename)
                        vehicle_type = "Car" if vehicle_class == 0 else "Truck"

                        return {
                            "success": True,
                            "type": "vehicle",
                            "vehicle_result": {
                                "vehicle_type": vehicle_type,
                                "detected_sound": str(vehicle_class),
                                # YAMNet score of the matched label, as a percentage
                                "confidence": float(vehicle_score) * 100
                            }
                        }

                    # Not a vehicle sound: return the YAMNet classification
                    return {
                        "success": True,
                        "type": "sound",
                        "sound_result": {
                            "top_classes": [(str(label), float(score)) for label, score in yamnet_result["top_classes"]]
                        }
                    }

            # If we get here, all recognition attempts failed
            return {
                "success": False,
                "message": "No music, vehicle, or sound patterns recognized."
            }
        finally:
            # Clean up temporary files, including when writing the chunk failed
            for leftover in (temp_filename, wav_filename):
                if leftover and os.path.exists(leftover):
                    os.remove(leftover)

    except Exception as e:
        logging.exception("Error during audio processing:")
        return {"success": False, "message": f"Processing failed: {str(e)}"}
384
+
385
def format_acrcloud_response(result_json_str: str):
    """Parse the JSON string returned by ACRCloud into a flat result dict.

    Fields that are present but ``null`` (``status``, ``metadata``, ``album``,
    ``artists``, the offset fields) are treated like missing fields, so a
    response with a valid title is not reported as an unexpected error.

    Args:
        result_json_str: Raw JSON text returned by the recognizer.

    Returns:
        On success: ``success=True`` with ``song_name``, ``artists``
        (comma-separated), ``album``, ``confidence``, ``offset_seconds`` and
        ``raw_acr_response``. Otherwise ``{"success": False, "message": ...}``.
        The function never raises; parsing and processing errors are logged
        and reported through ``message``.
    """
    try:
        result = json.loads(result_json_str)
        logging.info(f"ACRCloud raw response: {result}")

        status = result.get("status") or {}
        metadata = result.get("metadata") or {}

        # A valid music result needs status code 0 and a "music" entry
        if status.get("code") != 0 or "music" not in metadata:
            return {"success": False, "message": status.get("msg", "Song not recognized or error in response.")}

        # Ensure the 'music' list is not empty
        if not metadata["music"]:
            return {"success": False, "message": "No music metadata found in response."}

        music_info = metadata["music"][0]
        title = music_info.get("title")

        # Without a title it is not a usable music result
        if not title:
            return {"success": False, "message": "No song title found in response."}

        artists_list = music_info.get("artists") or []
        artists = ", ".join(artist["name"] for artist in artists_list if "name" in artist)
        album = (music_info.get("album") or {}).get("name")

        # Prefer the play offset; fall back to the sample begin offset
        offset_seconds = (music_info.get("play_offset_ms") or 0) / 1000.0
        if offset_seconds == 0 and "sample_begin_time_offset_ms" in music_info:
            offset_seconds = (music_info.get("sample_begin_time_offset_ms") or 0) / 1000.0

        # When no score is reported, derive a value from result_type
        confidence = music_info.get("score", 0)
        if confidence == 0 and "result_type" in result:
            confidence = result.get("result_type", 0) * 25

        return {
            "success": True,
            "song_name": title,
            "artists": artists,
            "album": album,
            "confidence": confidence,
            "offset_seconds": offset_seconds,
            "raw_acr_response": result
        }
    except json.JSONDecodeError:
        logging.error(f"Failed to decode ACRCloud JSON response: {result_json_str}")
        return {"success": False, "message": "Error parsing recognition server response."}
    except Exception as e:
        logging.error(f"Error processing ACRCloud response: {e} -- Response was: {result_json_str}")
        return {"success": False, "message": f"An unexpected error occurred: {str(e)}"}
435
+
436
@app.post("/predict/")
async def predict_audio(file: UploadFile = File(...)):
    """Run the vehicle classifier on an uploaded file.

    The upload is copied to a temporary ``.mp3``-suffixed file, classified
    with ``predict_vehicle_class`` and the temporary file is removed again,
    whether or not prediction succeeded.

    Args:
        file: The uploaded audio file.

    Returns:
        ``{"filename": <uploaded name>, "predicted_class": <int>}``.
    """
    # Persist the upload so the classifier can read it from a path
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as upload_copy:
        saved_path = upload_copy.name
        shutil.copyfileobj(file.file, upload_copy)

    try:
        class_index = predict_vehicle_class(saved_path)
        response = {"filename": file.filename, "predicted_class": int(class_index)}
    finally:
        os.remove(saved_path)
    return response
449
+
450
# Mistral AI configuration
# NOTE(security): the literal below is an API key committed to source control.
# It remains only as a fallback so existing deployments keep working; set
# MISTRAL_API_KEY in the environment, then rotate and delete the literal.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY", "SDV5ynlJBEs0n15l2PDvO9eor1ki4dTI")
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
453
+
454
class ChatRequest(BaseModel):
    """Request body accepted by the ``/chat-with-mistral/`` endpoint."""

    # Sent to the model as the "system" message.
    system_prompt: str
    # Sent to the model as the "user" message.
    user_message: str
457
+
458
@app.post("/chat-with-mistral/")
async def chat_with_mistral(request: ChatRequest):
    """Forward a system prompt and a user message to the Mistral chat API.

    Args:
        request: The prompts to send.

    Returns:
        ``{"success": True, "response": <assistant text>}``.

    Raises:
        HTTPException: Status 500 when the API answers with a non-200 status,
            the call fails or times out, or the reply has an unexpected shape.
    """
    headers = {
        "Authorization": f"Bearer {MISTRAL_API_KEY}",
        "Content-Type": "application/json"
    }
    data = {
        "model": "mistral-small",
        "messages": [
            {"role": "system", "content": request.system_prompt},
            {"role": "user", "content": request.user_message},
        ]
    }

    try:
        # Without a timeout a stalled upstream connection would hang the
        # request indefinitely.
        response = requests.post(MISTRAL_API_URL, headers=headers, json=data, timeout=60)

        if response.status_code != 200:
            raise HTTPException(
                status_code=500,
                detail=f"Error from Mistral API: {response.status_code} - {response.text}"
            )

        ai_response = response.json()["choices"][0]["message"]["content"]
        return {
            "success": True,
            "response": ai_response
        }
    except HTTPException:
        # Pass our own error through unchanged instead of re-wrapping it
        # (which replaced its detail with the exception's string form).
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=str(e)
        ) from e
499
+
500
# Configure the root logger so INFO-level (and higher) records are emitted.
logging.basicConfig(level=logging.INFO)