Exception handling
Browse files
app.py
CHANGED
|
@@ -39,39 +39,52 @@ def load_model():
|
|
| 39 |
logger.info("Model loaded successfully...")
|
| 40 |
except Exception as e:
|
| 41 |
logger.error(f"Model loading failed: {e}")
|
| 42 |
-
|
| 43 |
-
raise gr.Error(
|
| 44 |
-
"Failed to load model.........."
|
| 45 |
-
)
|
| 46 |
return feature_extractor, model
|
| 47 |
|
| 48 |
|
| 49 |
-
|
| 50 |
def predict_audio(audio_path):
|
| 51 |
logger.info(f"inside predict_audio : {audio_path}")
|
| 52 |
|
| 53 |
feature_extractor, model = load_model()
|
| 54 |
id2label = model.config.id2label
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
waveform.
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
return waveform.numpy(), probs, id2label
|
| 76 |
|
| 77 |
|
|
@@ -99,7 +112,7 @@ with gr.Blocks(title="AST Model") as demo:
|
|
| 99 |
img = librosa.display.specshow(
|
| 100 |
mel_db, sr = AST_SR,
|
| 101 |
x_axis = 'time', y_axis = 'mel',
|
| 102 |
-
ax = ax[1]
|
| 103 |
)
|
| 104 |
|
| 105 |
ax[1].set_title("Mel Spectrogram")
|
|
|
|
| 39 |
logger.info("Model loaded successfully...")
|
| 40 |
except Exception as e:
|
| 41 |
logger.error(f"Model loading failed: {e}")
|
| 42 |
+
raise gr.Error( "Failed to load model...")
|
|
|
|
|
|
|
|
|
|
| 43 |
return feature_extractor, model
|
| 44 |
|
| 45 |
|
|
|
|
| 46 |
def predict_audio(audio_path):
|
| 47 |
logger.info(f"inside predict_audio : {audio_path}")
|
| 48 |
|
| 49 |
feature_extractor, model = load_model()
|
| 50 |
id2label = model.config.id2label
|
| 51 |
|
| 52 |
+
try:
|
| 53 |
+
waveform, sr = librosa.load(audio_path, sr=AST_SR, mono=True)
|
| 54 |
+
except Exception as e:
|
| 55 |
+
logger.error(f"Audio loading failed : {e}")
|
| 56 |
+
raise.gr.Error("Failed to load the audio ...")
|
| 57 |
+
|
| 58 |
+
try:
|
| 59 |
+
waveform = torch.tensor(waveform)
|
| 60 |
+
max_val = waveform.abs().max()
|
| 61 |
+
if max_val > 0:
|
| 62 |
+
waveform = waveform / max_val
|
| 63 |
+
except Exception as e:
|
| 64 |
+
logger.error(f"Preprocessing failed: {e}")
|
| 65 |
+
raise.gr.Error("Preprocessing failed ...")
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
inputs = feature_extractor(
|
| 69 |
+
waveform.numpy(),
|
| 70 |
+
sampling_rate=sr,
|
| 71 |
+
return_tensors="pt"
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 75 |
+
except Exception as e:
|
| 76 |
+
logger.error(f"Feature extraction failed : {e}")
|
| 77 |
+
raise.gr.Error("Feature extraction failed ...")
|
| 78 |
+
|
| 79 |
+
try:
|
| 80 |
+
with torch.no_grad():
|
| 81 |
+
logits = model(**inputs).logits.squeeze(0)
|
| 82 |
+
|
| 83 |
+
probs = torch.softmax(logits, dim=0).cpu().numpy()
|
| 84 |
+
except Exception as e:
|
| 85 |
+
logger.error(f"Inference failed : {e}")
|
| 86 |
+
raise.gr.Error("Inference failed ...")
|
| 87 |
+
|
| 88 |
return waveform.numpy(), probs, id2label
|
| 89 |
|
| 90 |
|
|
|
|
| 112 |
img = librosa.display.specshow(
|
| 113 |
mel_db, sr = AST_SR,
|
| 114 |
x_axis = 'time', y_axis = 'mel',
|
| 115 |
+
cmap = 'viridis', ax = ax[1]
|
| 116 |
)
|
| 117 |
|
| 118 |
ax[1].set_title("Mel Spectrogram")
|