Spaces:
Sleeping
Sleeping
Commit
·
fa20419
1
Parent(s):
b0af609
Fix angry bias: add temperature scaling, bias mitigation, less aggressive noise reduction
Browse files
app.py
CHANGED
|
@@ -113,7 +113,7 @@ ENABLE_VAD = os.environ.get("ENABLE_VAD", "true").lower() == "true"
|
|
| 113 |
ENABLE_DENOISE = os.environ.get("ENABLE_DENOISE", "true").lower() == "true"
|
| 114 |
ENABLE_HIGHPASS = os.environ.get("ENABLE_HIGHPASS", "true").lower() == "true"
|
| 115 |
ENABLE_SILENCE_TRIM = os.environ.get("ENABLE_SILENCE_TRIM", "true").lower() == "true"
|
| 116 |
-
CONFIDENCE_THRESHOLD = float(os.environ.get("CONFIDENCE_THRESHOLD", "0.
|
| 117 |
MIN_VOICED_MS = int(os.environ.get("MIN_VOICED_MS", "500"))
|
| 118 |
MIN_AUDIO_DURATION_MS = int(os.environ.get("MIN_AUDIO_DURATION_MS", "300"))
|
| 119 |
MAX_AUDIO_DURATION_MS = int(os.environ.get("MAX_AUDIO_DURATION_MS", "10000"))
|
|
@@ -362,18 +362,19 @@ def preprocess_audio(audio_bytes: bytes) -> np.ndarray:
|
|
| 362 |
except Exception as e:
|
| 363 |
logger.warning(f"High-pass filter failed: {e}")
|
| 364 |
|
| 365 |
-
# Optional noise reduction (spectral gating) -
|
| 366 |
if ENABLE_DENOISE and nr is not None:
|
| 367 |
try:
|
| 368 |
-
# Use stationary noise reduction
|
|
|
|
| 369 |
audio_array = nr.reduce_noise(
|
| 370 |
y=audio_array,
|
| 371 |
sr=sample_rate,
|
| 372 |
-
prop_decrease=0.8,
|
| 373 |
stationary=True, # Better for voice
|
| 374 |
-
n_std_thresh_stationary=
|
| 375 |
)
|
| 376 |
-
logger.info("Applied noise reduction")
|
| 377 |
except Exception as e:
|
| 378 |
logger.warning(f"Noise reduction failed: {e}")
|
| 379 |
|
|
@@ -403,6 +404,7 @@ def preprocess_audio(audio_bytes: bytes) -> np.ndarray:
|
|
| 403 |
def predict_emotion(audio_array: np.ndarray) -> dict:
|
| 404 |
"""
|
| 405 |
Predict emotion from audio array using Wav2Vec2 model.
|
|
|
|
| 406 |
|
| 407 |
Args:
|
| 408 |
audio_array: Preprocessed audio array (float32, 16kHz, mono)
|
|
@@ -434,14 +436,19 @@ def predict_emotion(audio_array: np.ndarray) -> dict:
|
|
| 434 |
with torch.no_grad():
|
| 435 |
outputs = model(**inputs)
|
| 436 |
|
| 437 |
-
# Get
|
| 438 |
logits = outputs.logits
|
| 439 |
-
predicted_class = torch.argmax(logits, dim=-1).item()
|
| 440 |
|
| 441 |
-
#
|
| 442 |
-
|
|
|
|
|
|
|
| 443 |
|
| 444 |
-
# Get
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
confidence = float(probabilities[predicted_class])
|
| 446 |
|
| 447 |
# Map class index to emotion label
|
|
@@ -453,16 +460,39 @@ def predict_emotion(audio_array: np.ndarray) -> dict:
|
|
| 453 |
for i, prob in enumerate(probabilities)
|
| 454 |
}
|
| 455 |
|
| 456 |
-
|
| 457 |
-
logger.info(f"📊 Probability distribution: {emotion_probs}")
|
| 458 |
-
|
| 459 |
-
# Improved confidence handling: use top-2 emotions for better accuracy
|
| 460 |
sorted_probs = sorted(emotion_probs.items(), key=lambda x: x[1], reverse=True)
|
| 461 |
top_emotion, top_conf = sorted_probs[0]
|
| 462 |
second_emotion, second_conf = sorted_probs[1] if len(sorted_probs) > 1 else (None, 0.0)
|
|
|
|
| 463 |
|
| 464 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
confidence_diff = top_conf - second_conf
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
|
| 467 |
# Confidence gating with improved logic
|
| 468 |
if confidence < CONFIDENCE_THRESHOLD:
|
|
@@ -470,17 +500,24 @@ def predict_emotion(audio_array: np.ndarray) -> dict:
|
|
| 470 |
"emotion": "uncertain",
|
| 471 |
"confidence": confidence,
|
| 472 |
"probabilities": emotion_probs,
|
| 473 |
-
"top_emotions": {
|
| 474 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
}
|
| 476 |
elif confidence_diff < 0.15 and top_conf < 0.6:
|
| 477 |
-
# Ambiguous case: top
|
| 478 |
return {
|
| 479 |
-
"emotion":
|
| 480 |
"confidence": confidence,
|
| 481 |
"probabilities": emotion_probs,
|
| 482 |
-
"top_emotions": {
|
| 483 |
-
|
|
|
|
|
|
|
|
|
|
| 484 |
}
|
| 485 |
else:
|
| 486 |
return {
|
|
|
|
| 113 |
ENABLE_DENOISE = os.environ.get("ENABLE_DENOISE", "true").lower() == "true"
|
| 114 |
ENABLE_HIGHPASS = os.environ.get("ENABLE_HIGHPASS", "true").lower() == "true"
|
| 115 |
ENABLE_SILENCE_TRIM = os.environ.get("ENABLE_SILENCE_TRIM", "true").lower() == "true"
|
| 116 |
+
CONFIDENCE_THRESHOLD = float(os.environ.get("CONFIDENCE_THRESHOLD", "0.4"))
|
| 117 |
MIN_VOICED_MS = int(os.environ.get("MIN_VOICED_MS", "500"))
|
| 118 |
MIN_AUDIO_DURATION_MS = int(os.environ.get("MIN_AUDIO_DURATION_MS", "300"))
|
| 119 |
MAX_AUDIO_DURATION_MS = int(os.environ.get("MAX_AUDIO_DURATION_MS", "10000"))
|
|
|
|
| 362 |
except Exception as e:
|
| 363 |
logger.warning(f"High-pass filter failed: {e}")
|
| 364 |
|
| 365 |
+
# Optional noise reduction (spectral gating) - less aggressive to preserve emotion cues
|
| 366 |
if ENABLE_DENOISE and nr is not None:
|
| 367 |
try:
|
| 368 |
+
# Use stationary noise reduction with less aggressive settings
|
| 369 |
+
# Less aggressive = preserves more emotion-relevant features
|
| 370 |
audio_array = nr.reduce_noise(
|
| 371 |
y=audio_array,
|
| 372 |
sr=sample_rate,
|
| 373 |
+
prop_decrease=0.6, # Less aggressive (was 0.8) to preserve emotion features
|
| 374 |
stationary=True, # Better for voice
|
| 375 |
+
n_std_thresh_stationary=2.0 # More conservative threshold
|
| 376 |
)
|
| 377 |
+
logger.info("Applied noise reduction (conservative)")
|
| 378 |
except Exception as e:
|
| 379 |
logger.warning(f"Noise reduction failed: {e}")
|
| 380 |
|
|
|
|
| 404 |
def predict_emotion(audio_array: np.ndarray) -> dict:
|
| 405 |
"""
|
| 406 |
Predict emotion from audio array using Wav2Vec2 model.
|
| 407 |
+
Includes bias mitigation and calibration to prevent over-prediction of certain emotions.
|
| 408 |
|
| 409 |
Args:
|
| 410 |
audio_array: Preprocessed audio array (float32, 16kHz, mono)
|
|
|
|
| 436 |
with torch.no_grad():
|
| 437 |
outputs = model(**inputs)
|
| 438 |
|
| 439 |
+
# Get logits (raw model outputs before softmax)
|
| 440 |
logits = outputs.logits
|
|
|
|
| 441 |
|
| 442 |
+
# Apply temperature scaling to reduce overconfidence and bias
|
| 443 |
+
# Higher temperature (1.5) makes the distribution more uniform, reducing bias
|
| 444 |
+
temperature = 1.5
|
| 445 |
+
scaled_logits = logits / temperature
|
| 446 |
|
| 447 |
+
# Get probabilities for all emotions using softmax on scaled logits
|
| 448 |
+
probabilities = torch.nn.functional.softmax(scaled_logits, dim=-1).cpu().numpy()[0]
|
| 449 |
+
|
| 450 |
+
# Get predicted class (emotion label index) from scaled probabilities
|
| 451 |
+
predicted_class = np.argmax(probabilities)
|
| 452 |
confidence = float(probabilities[predicted_class])
|
| 453 |
|
| 454 |
# Map class index to emotion label
|
|
|
|
| 460 |
for i, prob in enumerate(probabilities)
|
| 461 |
}
|
| 462 |
|
| 463 |
+
# Sort probabilities for analysis
|
|
|
|
|
|
|
|
|
|
| 464 |
sorted_probs = sorted(emotion_probs.items(), key=lambda x: x[1], reverse=True)
|
| 465 |
top_emotion, top_conf = sorted_probs[0]
|
| 466 |
second_emotion, second_conf = sorted_probs[1] if len(sorted_probs) > 1 else (None, 0.0)
|
| 467 |
+
third_emotion, third_conf = sorted_probs[2] if len(sorted_probs) > 2 else (None, 0.0)
|
| 468 |
|
| 469 |
+
logger.info(f"🎭 Raw prediction: {emotion_label} (confidence: {confidence:.2%})")
|
| 470 |
+
logger.info(f"📊 Top 3: {top_emotion} ({top_conf:.2%}), {second_emotion} ({second_conf:.2%}), {third_emotion} ({third_conf:.2%})")
|
| 471 |
+
logger.info(f"📊 Full distribution: {emotion_probs}")
|
| 472 |
+
|
| 473 |
+
# Bias mitigation: If "angry" is predicted but confidence is not significantly higher,
|
| 474 |
+
# and other emotions are close, consider the second-best emotion
|
| 475 |
confidence_diff = top_conf - second_conf
|
| 476 |
+
confidence_diff_2 = top_conf - third_conf if third_emotion else top_conf
|
| 477 |
+
|
| 478 |
+
# If "angry" is top but margin is small, prefer second emotion if it's more reasonable
|
| 479 |
+
if top_emotion == "angry" and confidence_diff < 0.2 and top_conf < 0.65:
|
| 480 |
+
# Check if second emotion has reasonable confidence
|
| 481 |
+
if second_conf > 0.25 and second_emotion != "angry":
|
| 482 |
+
logger.info(f"⚠️ Bias mitigation: 'angry' predicted but margin small. Using {second_emotion} instead.")
|
| 483 |
+
emotion_label = second_emotion
|
| 484 |
+
confidence = second_conf
|
| 485 |
+
top_emotion = second_emotion
|
| 486 |
+
top_conf = second_conf
|
| 487 |
+
|
| 488 |
+
# Additional check: If top emotion has very low confidence, use second if it's reasonable
|
| 489 |
+
if top_conf < 0.4 and second_conf > 0.25:
|
| 490 |
+
logger.info(f"⚠️ Low confidence on top emotion. Considering {second_emotion}.")
|
| 491 |
+
if second_conf > top_conf * 0.8: # Second is at least 80% of top
|
| 492 |
+
emotion_label = second_emotion
|
| 493 |
+
confidence = second_conf
|
| 494 |
+
top_emotion = second_emotion
|
| 495 |
+
top_conf = second_conf
|
| 496 |
|
| 497 |
# Confidence gating with improved logic
|
| 498 |
if confidence < CONFIDENCE_THRESHOLD:
|
|
|
|
| 500 |
"emotion": "uncertain",
|
| 501 |
"confidence": confidence,
|
| 502 |
"probabilities": emotion_probs,
|
| 503 |
+
"top_emotions": {
|
| 504 |
+
"first": {top_emotion: top_conf},
|
| 505 |
+
"second": {second_emotion: second_conf} if second_emotion else None,
|
| 506 |
+
"third": {third_emotion: third_conf} if third_emotion else None
|
| 507 |
+
},
|
| 508 |
+
"note": f"Low confidence ({confidence:.2%} < {CONFIDENCE_THRESHOLD:.2%}). Top: {top_emotion}."
|
| 509 |
}
|
| 510 |
elif confidence_diff < 0.15 and top_conf < 0.6:
|
| 511 |
+
# Ambiguous case: top emotions are close
|
| 512 |
return {
|
| 513 |
+
"emotion": emotion_label,
|
| 514 |
"confidence": confidence,
|
| 515 |
"probabilities": emotion_probs,
|
| 516 |
+
"top_emotions": {
|
| 517 |
+
"first": {top_emotion: top_conf},
|
| 518 |
+
"second": {second_emotion: second_conf} if second_emotion else None
|
| 519 |
+
},
|
| 520 |
+
"note": f"Ambiguous: {top_emotion} ({top_conf:.2%}) vs {second_emotion} ({second_conf:.2%})"
|
| 521 |
}
|
| 522 |
else:
|
| 523 |
return {
|