monishaaura committed on
Commit
fa20419
·
1 Parent(s): b0af609

Fix angry bias: add temperature scaling, bias mitigation, less aggressive noise reduction

Browse files
Files changed (1) hide show
  1. app.py +59 -22
app.py CHANGED
@@ -113,7 +113,7 @@ ENABLE_VAD = os.environ.get("ENABLE_VAD", "true").lower() == "true"
113
  ENABLE_DENOISE = os.environ.get("ENABLE_DENOISE", "true").lower() == "true"
114
  ENABLE_HIGHPASS = os.environ.get("ENABLE_HIGHPASS", "true").lower() == "true"
115
  ENABLE_SILENCE_TRIM = os.environ.get("ENABLE_SILENCE_TRIM", "true").lower() == "true"
116
- CONFIDENCE_THRESHOLD = float(os.environ.get("CONFIDENCE_THRESHOLD", "0.5"))
117
  MIN_VOICED_MS = int(os.environ.get("MIN_VOICED_MS", "500"))
118
  MIN_AUDIO_DURATION_MS = int(os.environ.get("MIN_AUDIO_DURATION_MS", "300"))
119
  MAX_AUDIO_DURATION_MS = int(os.environ.get("MAX_AUDIO_DURATION_MS", "10000"))
@@ -362,18 +362,19 @@ def preprocess_audio(audio_bytes: bytes) -> np.ndarray:
362
  except Exception as e:
363
  logger.warning(f"High-pass filter failed: {e}")
364
 
365
- # Optional noise reduction (spectral gating) - improved settings
366
  if ENABLE_DENOISE and nr is not None:
367
  try:
368
- # Use stationary noise reduction for better voice preservation
 
369
  audio_array = nr.reduce_noise(
370
  y=audio_array,
371
  sr=sample_rate,
372
- prop_decrease=0.8, # Slightly less aggressive (was 0.9)
373
  stationary=True, # Better for voice
374
- n_std_thresh_stationary=1.5 # More conservative threshold
375
  )
376
- logger.info("Applied noise reduction")
377
  except Exception as e:
378
  logger.warning(f"Noise reduction failed: {e}")
379
 
@@ -403,6 +404,7 @@ def preprocess_audio(audio_bytes: bytes) -> np.ndarray:
403
  def predict_emotion(audio_array: np.ndarray) -> dict:
404
  """
405
  Predict emotion from audio array using Wav2Vec2 model.
 
406
 
407
  Args:
408
  audio_array: Preprocessed audio array (float32, 16kHz, mono)
@@ -434,14 +436,19 @@ def predict_emotion(audio_array: np.ndarray) -> dict:
434
  with torch.no_grad():
435
  outputs = model(**inputs)
436
 
437
- # Get predicted class (emotion label index)
438
  logits = outputs.logits
439
- predicted_class = torch.argmax(logits, dim=-1).item()
440
 
441
- # Get probabilities for all emotions using softmax
442
- probabilities = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()[0]
 
 
443
 
444
- # Get confidence (probability of predicted emotion)
 
 
 
 
445
  confidence = float(probabilities[predicted_class])
446
 
447
  # Map class index to emotion label
@@ -453,16 +460,39 @@ def predict_emotion(audio_array: np.ndarray) -> dict:
453
  for i, prob in enumerate(probabilities)
454
  }
455
 
456
- logger.info(f"🎭 Detected emotion: {emotion_label} (confidence: {confidence:.2%})")
457
- logger.info(f"📊 Probability distribution: {emotion_probs}")
458
-
459
- # Improved confidence handling: use top-2 emotions for better accuracy
460
  sorted_probs = sorted(emotion_probs.items(), key=lambda x: x[1], reverse=True)
461
  top_emotion, top_conf = sorted_probs[0]
462
  second_emotion, second_conf = sorted_probs[1] if len(sorted_probs) > 1 else (None, 0.0)
 
463
 
464
- # If top two emotions are close, might be ambiguous
 
 
 
 
 
465
  confidence_diff = top_conf - second_conf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
 
467
  # Confidence gating with improved logic
468
  if confidence < CONFIDENCE_THRESHOLD:
@@ -470,17 +500,24 @@ def predict_emotion(audio_array: np.ndarray) -> dict:
470
  "emotion": "uncertain",
471
  "confidence": confidence,
472
  "probabilities": emotion_probs,
473
- "top_emotions": {top_emotion: top_conf, second_emotion: second_conf} if second_emotion else {top_emotion: top_conf},
474
- "note": f"Low confidence ({confidence:.2%} < {CONFIDENCE_THRESHOLD:.2%}). Top emotion: {top_emotion}."
 
 
 
 
475
  }
476
  elif confidence_diff < 0.15 and top_conf < 0.6:
477
- # Ambiguous case: top two emotions are close and confidence is moderate
478
  return {
479
- "emotion": top_emotion,
480
  "confidence": confidence,
481
  "probabilities": emotion_probs,
482
- "top_emotions": {top_emotion: top_conf, second_emotion: second_conf},
483
- "note": f"Ambiguous detection. Top: {top_emotion} ({top_conf:.2%}), Second: {second_emotion} ({second_conf:.2%})"
 
 
 
484
  }
485
  else:
486
  return {
 
113
  ENABLE_DENOISE = os.environ.get("ENABLE_DENOISE", "true").lower() == "true"
114
  ENABLE_HIGHPASS = os.environ.get("ENABLE_HIGHPASS", "true").lower() == "true"
115
  ENABLE_SILENCE_TRIM = os.environ.get("ENABLE_SILENCE_TRIM", "true").lower() == "true"
116
+ CONFIDENCE_THRESHOLD = float(os.environ.get("CONFIDENCE_THRESHOLD", "0.4"))
117
  MIN_VOICED_MS = int(os.environ.get("MIN_VOICED_MS", "500"))
118
  MIN_AUDIO_DURATION_MS = int(os.environ.get("MIN_AUDIO_DURATION_MS", "300"))
119
  MAX_AUDIO_DURATION_MS = int(os.environ.get("MAX_AUDIO_DURATION_MS", "10000"))
 
362
  except Exception as e:
363
  logger.warning(f"High-pass filter failed: {e}")
364
 
365
+ # Optional noise reduction (spectral gating) - less aggressive to preserve emotion cues
366
  if ENABLE_DENOISE and nr is not None:
367
  try:
368
+ # Use stationary noise reduction with less aggressive settings
369
+ # Less aggressive = preserves more emotion-relevant features
370
  audio_array = nr.reduce_noise(
371
  y=audio_array,
372
  sr=sample_rate,
373
+ prop_decrease=0.6, # Less aggressive (was 0.8) to preserve emotion features
374
  stationary=True, # Better for voice
375
+ n_std_thresh_stationary=2.0 # More conservative threshold
376
  )
377
+ logger.info("Applied noise reduction (conservative)")
378
  except Exception as e:
379
  logger.warning(f"Noise reduction failed: {e}")
380
 
 
404
  def predict_emotion(audio_array: np.ndarray) -> dict:
405
  """
406
  Predict emotion from audio array using Wav2Vec2 model.
407
+ Includes bias mitigation and calibration to prevent over-prediction of certain emotions.
408
 
409
  Args:
410
  audio_array: Preprocessed audio array (float32, 16kHz, mono)
 
436
  with torch.no_grad():
437
  outputs = model(**inputs)
438
 
439
+ # Get logits (raw model outputs before softmax)
440
  logits = outputs.logits
 
441
 
442
+ # Apply temperature scaling to reduce overconfidence and bias
443
+ # Higher temperature (1.5) makes the distribution more uniform, reducing bias
444
+ temperature = 1.5
445
+ scaled_logits = logits / temperature
446
 
447
+ # Get probabilities for all emotions using softmax on scaled logits
448
+ probabilities = torch.nn.functional.softmax(scaled_logits, dim=-1).cpu().numpy()[0]
449
+
450
+ # Get predicted class (emotion label index) from scaled probabilities
451
+ predicted_class = np.argmax(probabilities)
452
  confidence = float(probabilities[predicted_class])
453
 
454
  # Map class index to emotion label
 
460
  for i, prob in enumerate(probabilities)
461
  }
462
 
463
+ # Sort probabilities for analysis
 
 
 
464
  sorted_probs = sorted(emotion_probs.items(), key=lambda x: x[1], reverse=True)
465
  top_emotion, top_conf = sorted_probs[0]
466
  second_emotion, second_conf = sorted_probs[1] if len(sorted_probs) > 1 else (None, 0.0)
467
+ third_emotion, third_conf = sorted_probs[2] if len(sorted_probs) > 2 else (None, 0.0)
468
 
469
+ logger.info(f"🎭 Raw prediction: {emotion_label} (confidence: {confidence:.2%})")
470
+ logger.info(f"📊 Top 3: {top_emotion} ({top_conf:.2%}), {second_emotion} ({second_conf:.2%}), {third_emotion} ({third_conf:.2%})")
471
+ logger.info(f"📊 Full distribution: {emotion_probs}")
472
+
473
+ # Bias mitigation: If "angry" is predicted but confidence is not significantly higher,
474
+ # and other emotions are close, consider the second-best emotion
475
  confidence_diff = top_conf - second_conf
476
+ confidence_diff_2 = top_conf - third_conf if third_emotion else top_conf
477
+
478
+ # If "angry" is top but margin is small, prefer second emotion if it's more reasonable
479
+ if top_emotion == "angry" and confidence_diff < 0.2 and top_conf < 0.65:
480
+ # Check if second emotion has reasonable confidence
481
+ if second_conf > 0.25 and second_emotion != "angry":
482
+ logger.info(f"⚠️ Bias mitigation: 'angry' predicted but margin small. Using {second_emotion} instead.")
483
+ emotion_label = second_emotion
484
+ confidence = second_conf
485
+ top_emotion = second_emotion
486
+ top_conf = second_conf
487
+
488
+ # Additional check: If top emotion has very low confidence, use second if it's reasonable
489
+ if top_conf < 0.4 and second_conf > 0.25:
490
+ logger.info(f"⚠️ Low confidence on top emotion. Considering {second_emotion}.")
491
+ if second_conf > top_conf * 0.8: # Second is at least 80% of top
492
+ emotion_label = second_emotion
493
+ confidence = second_conf
494
+ top_emotion = second_emotion
495
+ top_conf = second_conf
496
 
497
  # Confidence gating with improved logic
498
  if confidence < CONFIDENCE_THRESHOLD:
 
500
  "emotion": "uncertain",
501
  "confidence": confidence,
502
  "probabilities": emotion_probs,
503
+ "top_emotions": {
504
+ "first": {top_emotion: top_conf},
505
+ "second": {second_emotion: second_conf} if second_emotion else None,
506
+ "third": {third_emotion: third_conf} if third_emotion else None
507
+ },
508
+ "note": f"Low confidence ({confidence:.2%} < {CONFIDENCE_THRESHOLD:.2%}). Top: {top_emotion}."
509
  }
510
  elif confidence_diff < 0.15 and top_conf < 0.6:
511
+ # Ambiguous case: top emotions are close
512
  return {
513
+ "emotion": emotion_label,
514
  "confidence": confidence,
515
  "probabilities": emotion_probs,
516
+ "top_emotions": {
517
+ "first": {top_emotion: top_conf},
518
+ "second": {second_emotion: second_conf} if second_emotion else None
519
+ },
520
+ "note": f"Ambiguous: {top_emotion} ({top_conf:.2%}) vs {second_emotion} ({second_conf:.2%})"
521
  }
522
  else:
523
  return {