krislette commited on
Commit
344eaf8
·
1 Parent(s): 87d96e9

Auto-deploy from GitHub: c09552b1b4b0829d2a857bdf91e7e48ba2bb8eb3

Browse files
Files changed (2) hide show
  1. scripts/predict.py +10 -3
  2. src/models/mlp.py +1 -1
scripts/predict.py CHANGED
@@ -61,9 +61,16 @@ def predict_pipeline(audio_file, lyrics):
61
  classifier.model.eval()
62
 
63
  # 8.) Run prediction
64
- probability, prediction, label = classifier.predict_single(results.flatten())
65
-
66
- return {"probability": probability, "prediction": prediction, "label": label}
 
 
 
 
 
 
 
67
 
68
 
69
  if __name__ == "__main__":
 
61
  classifier.model.eval()
62
 
63
  # 8.) Run prediction
64
+ confidence, prediction, label, probability = classifier.predict_single(
65
+ results.flatten()
66
+ )
67
+
68
+ return {
69
+ "confidence": confidence,
70
+ "prediction": prediction,
71
+ "label": label,
72
+ "probability": probability,
73
+ }
74
 
75
 
76
  if __name__ == "__main__":
src/models/mlp.py CHANGED
@@ -496,7 +496,7 @@ class MLPClassifier:
496
 
497
  confidence = probability * 100 if prediction == 1 else (1 - probability) * 100
498
 
499
- return confidence, prediction, label
500
 
501
  def predict_batch(self, features: np.ndarray, return_details: bool = False) -> Dict:
502
  """
 
496
 
497
  confidence = probability * 100 if prediction == 1 else (1 - probability) * 100
498
 
499
+ return confidence, prediction, label, probability
500
 
501
  def predict_batch(self, features: np.ndarray, return_details: bool = False) -> Dict:
502
  """