Update src/evaluation.py
src/evaluation.py  (+52 -36)
@@ -11,7 +11,7 @@ from config import ALL_UG40_LANGUAGES, GOOGLE_SUPPORTED_LANGUAGES, METRICS_CONFIG
 from src.utils import get_all_language_pairs, get_google_comparable_pairs
 
 def calculate_sentence_metrics(reference: str, prediction: str) -> Dict[str, float]:
-    """Calculate all metrics for a single sentence pair."""
+    """Calculate all metrics for a single sentence pair - Fixed to match reference implementation."""
 
     # Handle empty predictions
     if not prediction or not isinstance(prediction, str):
@@ -27,14 +27,14 @@ def calculate_sentence_metrics(reference: str, prediction: str) -> Dict[str, float]:
 
     metrics = {}
 
-    # BLEU score
+    # BLEU score (keep as 0-100 scale initially)
     try:
         bleu = BLEU(effective_order=True)
         metrics['bleu'] = bleu.sentence_score(pred_norm, [ref_norm]).score
     except:
         metrics['bleu'] = 0.0
 
-    # ChrF score
+    # ChrF score (normalize to 0-1)
     try:
         chrf = CHRF()
         metrics['chrf'] = chrf.sentence_score(pred_norm, [ref_norm]).score / 100.0
@@ -83,19 +83,30 @@ def calculate_sentence_metrics(reference: str, prediction: str) -> Dict[str, float]:
         metrics['rouge2'] = 0.0
         metrics['rougeL'] = 0.0
 
-    # Quality score (composite metric)
+    # Quality score (composite metric) - Fixed to match reference
     try:
         quality_components = [
             metrics['bleu'] / 100.0,  # Normalize BLEU to 0-1
-            metrics['chrf'],
+            metrics['chrf'],  # Already 0-1
             1.0 - min(metrics['cer'], 1.0),  # Invert error rates
             1.0 - min(metrics['wer'], 1.0),
             metrics['rouge1'],
             metrics['rougeL']
         ]
         metrics['quality_score'] = np.mean(quality_components)
-    except:
-        metrics['quality_score'] = 0.0
+    except Exception as e:
+        # Fallback without ROUGE
+        print(f"Error calculating quality score: {e}")
+        try:
+            fallback_components = [
+                metrics['bleu'] / 100.0,
+                metrics['chrf'],
+                1.0 - min(metrics['cer'], 1.0),
+                1.0 - min(metrics['wer'], 1.0)
+            ]
+            metrics['quality_score'] = np.mean(fallback_components)
+        except:
+            metrics['quality_score'] = 0.0
 
     return metrics
 
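The scales above matter: sacrebleu reports both BLEU and chrF on a 0-100 scale, while the composite averages 0-1 components. Below is a minimal standalone sketch of that idea (not the project's code), built on sacrebleu only; the CER/WER/ROUGE terms are omitted, so it roughly mirrors the fallback branch, and the function name is made up.

import numpy as np
from sacrebleu.metrics import BLEU, CHRF

def quality_score_sketch(reference: str, prediction: str) -> float:
    # sentence_score() returns scores on a 0-100 scale
    bleu = BLEU(effective_order=True).sentence_score(prediction, [reference]).score
    chrf = CHRF().sentence_score(prediction, [reference]).score / 100.0  # normalize to 0-1
    # Average the 0-1 components, as quality_score does (CER/WER/ROUGE left out here)
    return float(np.mean([bleu / 100.0, chrf]))

print(quality_score_sketch("the cat sat on the mat", "the cat is on the mat"))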
@@ -132,7 +143,7 @@ def evaluate_predictions(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
 
     sample_df = pd.DataFrame(sample_metrics)
 
-    # Aggregate by language pairs
+    # Aggregate by language pairs - Fixed aggregation
     pair_metrics = {}
     overall_metrics = defaultdict(list)
     google_comparable_metrics = defaultdict(list)
@@ -153,16 +164,19 @@ def evaluate_predictions(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
         # Calculate averages for this pair
         for metric in METRICS_CONFIG['primary_metrics'] + METRICS_CONFIG['secondary_metrics']:
             if metric in pair_data.columns:
+                # Filter out invalid values
+                valid_values = pair_data[metric].replace([np.inf, -np.inf], np.nan).dropna()
+                if len(valid_values) > 0:
+                    avg_value = float(valid_values.mean())
+                    pair_metrics[pair_key][metric] = avg_value
+
+                    # Add to overall averages
+                    overall_metrics[metric].append(avg_value)
+
+                    # Add to Google comparable if applicable
+                    if (src_lang in GOOGLE_SUPPORTED_LANGUAGES and
+                        tgt_lang in GOOGLE_SUPPORTED_LANGUAGES):
+                        google_comparable_metrics[metric].append(avg_value)
 
         pair_metrics[pair_key]['sample_count'] = len(pair_data)
 
@@ -185,11 +199,12 @@ def evaluate_predictions(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
     # Generate evaluation summary
     summary = {
         'total_samples': len(sample_df),
-        'language_pairs_covered': len([k for k in pair_metrics if pair_metrics[k]
+        'language_pairs_covered': len([k for k in pair_metrics if pair_metrics[k].get('sample_count', 0) > 0]),
         'google_comparable_pairs': len([k for k in pair_metrics
                                         if '_to_' in k and
                                         k.split('_to_')[0] in GOOGLE_SUPPORTED_LANGUAGES and
-                                        k.split('_to_')[1] in GOOGLE_SUPPORTED_LANGUAGES
+                                        k.split('_to_')[1] in GOOGLE_SUPPORTED_LANGUAGES and
+                                        pair_metrics[k].get('sample_count', 0) > 0]),
         'primary_metrics': {metric: averages.get(metric, 0.0)
                             for metric in METRICS_CONFIG['primary_metrics']},
         'secondary_metrics': {metric: averages.get(metric, 0.0)
@@ -206,6 +221,7 @@ def evaluate_predictions(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
         'error': None
     }
 
+# Keep the rest of the functions unchanged...
 def compare_with_baseline(results: Dict, baseline_results: Dict = None) -> Dict:
     """Compare results with baseline (e.g., Google Translate)."""
 
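The new per-pair aggregation follows a standard pandas pattern: replace ±inf with NaN, drop NaN, then average per language pair. A small self-contained sketch of that pattern; the column names and language codes are illustrative, not necessarily the project's schema.

import numpy as np
import pandas as pd

# Toy per-sample metrics table (schema assumed for illustration)
sample_df = pd.DataFrame({
    "src_lang": ["lug", "lug", "ach"],
    "tgt_lang": ["eng", "eng", "eng"],
    "chrf": [0.41, np.inf, 0.38],
})

for (src, tgt), pair_data in sample_df.groupby(["src_lang", "tgt_lang"]):
    # Same filtering idea as the diff: keep only finite values before averaging
    valid = pair_data["chrf"].replace([np.inf, -np.inf], np.nan).dropna()
    if len(valid) > 0:
        print(f"{src}_to_{tgt}: chrf={float(valid.mean()):.3f} (n={len(pair_data)})")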
@@ -290,20 +306,19 @@ def generate_evaluation_report(results: Dict, model_name: str = "", comparison: Dict = None) -> str:
     report = []
 
     # Header
-    report.append(f"
-    report.append(f"Generated: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}")
+    report.append(f"## Evaluation Report: {model_name or 'Submission'}")
     report.append("")
 
     # Summary
     summary = results['summary']
-    report.append("
+    report.append("### Summary")
     report.append(f"- **Total Samples Evaluated**: {summary['total_samples']:,}")
     report.append(f"- **Language Pairs Covered**: {summary['language_pairs_covered']}")
     report.append(f"- **Google Comparable Pairs**: {summary['google_comparable_pairs']}")
     report.append("")
 
     # Primary metrics
-    report.append("
+    report.append("### Primary Metrics")
     for metric, value in summary['primary_metrics'].items():
         formatted_value = f"{value:.4f}" if metric != 'bleu' else f"{value:.2f}"
         report.append(f"- **{metric.upper()}**: {formatted_value}")
@@ -323,7 +338,7 @@ def generate_evaluation_report(results: Dict, model_name: str = "", comparison: Dict = None) -> str:
     report.append("")
 
     # Secondary metrics
-    report.append("
+    report.append("### Secondary Metrics")
     for metric, value in summary['secondary_metrics'].items():
         formatted_value = f"{value:.4f}"
         report.append(f"- **{metric.upper()}**: {formatted_value}")
@@ -339,26 +354,27 @@ def generate_evaluation_report(results: Dict, model_name: str = "", comparison: Dict = None) -> str:
         reverse=True
     )
 
-        report.append(f"- **{src} → {tgt}**: {score:.3f}")
-    if len(sorted_pairs) > 5:
-        report.append("")
-        report.append("## Challenging Language Pairs")
-        for pair, score in sorted_pairs[-3:]:
+    if sorted_pairs:
+        report.append("### Best Performing Language Pairs")
+        for pair, score in sorted_pairs[:5]:
             src, tgt = pair.replace('_to_', ' → ').split(' → ')
             report.append(f"- **{src} → {tgt}**: {score:.3f}")
+
+        if len(sorted_pairs) > 5:
+            report.append("")
+            report.append("### Challenging Language Pairs")
+            for pair, score in sorted_pairs[-3:]:
+                src, tgt = pair.replace('_to_', ' → ').split(' → ')
+                report.append(f"- **{src} → {tgt}**: {score:.3f}")
 
     # Comparison with baseline
     if comparison and comparison.get('comparison_available'):
         report.append("")
-        report.append("
+        report.append("### Comparison with Baseline")
 
         better_count = len(comparison.get('better_pairs', []))
         worse_count = len(comparison.get('worse_pairs', []))
-        total_comparable =
+        total_comparable = len(comparison.get('pair_comparisons', {}))
 
         if total_comparable > 0:
             report.append(f"- **Better than baseline**: {better_count}/{total_comparable} pairs")
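To see how the updated report section renders, here is a toy sketch with invented pair scores; only the <src>_to_<tgt> key convention and the markdown shape come from the diff.

# Invented scores, illustrative only
pair_quality = {"lug_to_eng": 0.62, "eng_to_lug": 0.55, "ach_to_eng": 0.31}
sorted_pairs = sorted(pair_quality.items(), key=lambda kv: kv[1], reverse=True)

report = ["### Best Performing Language Pairs"]
for pair, score in sorted_pairs[:5]:
    src, tgt = pair.replace('_to_', ' → ').split(' → ')
    report.append(f"- **{src} → {tgt}**: {score:.3f}")

print("\n".join(report))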