# Leaderboard / utils.py
import pandas as pd
import evaluate
import csv
import io
import tempfile
import os
import shutil
import sys
import traceback
# Dynamically add the mdd_eval directory to path to allow imports
current_dir = os.path.dirname(os.path.abspath(__file__))
mdd_eval_path = os.path.join(current_dir, "mdd_eval")
if mdd_eval_path not in sys.path:
    sys.path.append(mdd_eval_path)
try:
    from mdd_eval.align_data import evaluate_from_dfs
    from mdd_eval.ins_del_cor_sub_analysis import analyze_alignment
except ImportError:
    # Package-style import failed; fall back to flat imports via the
    # mdd_eval path appended to sys.path above
    import align_data
    import ins_del_cor_sub_analysis
    evaluate_from_dfs = align_data.evaluate_from_dfs
    analyze_alignment = ins_del_cor_sub_analysis.analyze_alignment
wer = evaluate.load("wer")  # WER metric handle (not referenced directly in this module)
def load_leaderboard(db_path):
    """Load the leaderboard CSV into a DataFrame, or return an empty one."""
    try:
        if os.path.exists(db_path):
            df = pd.read_csv(db_path)
            return df
        else:
            return pd.DataFrame()
    except Exception:
        # Return an empty DataFrame on error (e.g. pd.errors.EmptyDataError)
        return pd.DataFrame()
try:
    from datasets import load_dataset
except ImportError:
    load_dataset = None
IDS_CACHE_FILE = "IDs.txt"
def get_allowed_ids():
    """Return the set of allowed test-set IDs from the local cache, or fetch them from Hugging Face."""
    print("DEBUG: Entering get_allowed_ids...", flush=True)
    # Serve the cached IDs if available
    if os.path.exists(IDS_CACHE_FILE):
        print(f"DEBUG: Found {IDS_CACHE_FILE}, reading...", flush=True)
        with open(IDS_CACHE_FILE, 'r') as f:
            ids = set(line.strip() for line in f if line.strip())
        if ids:
            print(f"Loaded {len(ids)} allowed IDs from {IDS_CACHE_FILE}", flush=True)
            return ids
    # Not cached: load from Hugging Face
    hf_token = os.environ.get("SPACE_HF_TOKEN") or os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    print(f"DEBUG: Checking HF_TOKEN: {'Present' if hf_token else 'Missing'}", flush=True)
    if hf_token and load_dataset:
        try:
            print("Fetching allowed IDs from IqraEval/QuranMB.v2 (split='test')...", flush=True)
            # Use streaming to avoid downloading the audio files
            dataset = load_dataset("IqraEval/QuranMB.v2", split="test", token=hf_token, streaming=True)
            allowed_ids = set()
            print("DEBUG: Iterating dataset...", flush=True)
            count = 0
            for item in dataset:
                if "ID" in item:
                    allowed_ids.add(str(item["ID"]).strip())
                elif "id" in item:
                    allowed_ids.add(str(item["id"]).strip())
                count += 1
                if count % 1000 == 0:
                    print(f"DEBUG: Processed {count} items...", flush=True)
            if allowed_ids:
                with open(IDS_CACHE_FILE, 'w') as f:
                    for i in sorted(allowed_ids):
                        f.write(f"{i}\n")
                print(f"Cached {len(allowed_ids)} IDs to {IDS_CACHE_FILE}", flush=True)
                return allowed_ids
            else:
                print("DEBUG: No IDs found in dataset.", flush=True)
        except Exception as e:
            print(f"Error fetching allowed IDs: {e}", flush=True)
    print("DEBUG: Returning None from get_allowed_ids", flush=True)
    return None
def load_ground_truth_references(ground_truth_path):
    """Load the ground-truth references, preferring Hugging Face and falling back to a local CSV."""
    # Get the allowed IDs first
    allowed_ids = get_allowed_ids()
    # Attempt to load from Hugging Face if a token is present
    hf_token = os.environ.get("SPACE_HF_TOKEN") or os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if hf_token and load_dataset:
        try:
            print("Attempting to load ground truth from Hugging Face (IqraEval/QuranMB.v2.labels)...")
            # The labels dataset is published under the 'train' split
            dataset = load_dataset("IqraEval/QuranMB.v2.labels", split="train", token=hf_token)
            references = []
            for item in dataset:
                rid = str(item.get("ID", "")).strip()
                # Filter against the allowed-ID set if one exists
                if allowed_ids and rid not in allowed_ids:
                    continue
                ref_data = {
                    "ID": rid,
                    "Reference_phn": item.get("reference"),
                    "Annotation_phn": item.get("canonical")
                }
                # Fallback mapping in case keys differ slightly in the actual dataset
                if not ref_data["Reference_phn"] and "Reference_phn" in item:
                    ref_data["Reference_phn"] = item["Reference_phn"]
                if not ref_data["Annotation_phn"] and "Annotation_phn" in item:
                    ref_data["Annotation_phn"] = item["Annotation_phn"]
                references.append(ref_data)
            print(f"Successfully loaded {len(references)} filtered references from Hugging Face.")
            return references
        except Exception as e:
            print(f"Failed to load from Hugging Face: {e}")
            print("Falling back to local CSV file.")
    # Fallback to the local file
    if not os.path.exists(ground_truth_path):
        print(f"Warning: Ground truth file not found at {ground_truth_path}")
        return []
    with open(ground_truth_path, newline='') as f:
        reader = csv.DictReader(f)
        references = []
        for row in reader:
            rid = str(row.get("id", "")).strip()
            # Filter against the allowed-ID set if one exists
            if allowed_ids and rid not in allowed_ids:
                continue
            ref_data = {"ID": rid}
            if "Reference_phn" in row:
                ref_data["Reference_phn"] = row["Reference_phn"]
            elif "reference_phn" in row:
                ref_data["Reference_phn"] = row["reference_phn"]
            elif "correct_phoneme" in row:
                ref_data["Reference_phn"] = row["correct_phoneme"]
            if "Annotation_phn" in row:
                ref_data["Annotation_phn"] = row["Annotation_phn"]
            elif "annotation_phn" in row:
                ref_data["Annotation_phn"] = row["annotation_phn"]
            elif "Reference_phn" in ref_data:
                ref_data["Annotation_phn"] = ref_data["Reference_phn"]
            references.append(ref_data)
    print(f"Loaded {len(references)} filtered references from local file.")
    return references
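# For reference, a minimal local fallback CSV that this loader accepts (the
# header variants 'reference_phn'/'correct_phoneme' and 'annotation_phn' also
# work; the rows below are purely illustrative, not real data):
#
#     id,Reference_phn,Annotation_phn
#     0001,b i s m i,b i s m i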
def parse_submission_csv(file_obj):
    """Parse a submission into a list of {"ID", "Prediction"} dicts.

    Accepts bytes, a file-like object, a file path, or raw CSV text.
    """
    content = ""
    if isinstance(file_obj, bytes):
        content = file_obj.decode("utf-8")
    elif hasattr(file_obj, 'read'):  # File-like object
        try:
            content = file_obj.read()
            if isinstance(content, bytes):
                content = content.decode("utf-8")
        except Exception:
            # read() failed; fall through to the string handling below
            pass
    if not content and isinstance(file_obj, str):
        if os.path.exists(file_obj):  # File path
            with open(file_obj, 'r', encoding='utf-8') as f:
                content = f.read()
        else:
            # Treat the string itself as the CSV content
            content = file_obj
    if not content:
        raise ValueError("Could not read content from submission file.")
    text_stream = io.StringIO(content)
    reader = csv.DictReader(text_stream)
    if not reader.fieldnames:
        raise ValueError("CSV file is empty or missing headers.")
    # Case-insensitive column check, with whitespace stripped
    lower_fieldnames = [f.lower().strip() for f in reader.fieldnames]
    # Validate the columns
    id_col = None
    pred_col = None
    if "id" in lower_fieldnames:
        id_col = "id"
    if "predicted_phoneme" in lower_fieldnames:
        pred_col = "predicted_phoneme"
    elif "labels" in lower_fieldnames:
        pred_col = "labels"
    elif "prediction" in lower_fieldnames:  # Accepted for resilience
        pred_col = "prediction"
    if not id_col or not pred_col:
        raise ValueError("Submission CSV must contain columns: 'id' and 'predicted_phoneme' (or 'labels'/'prediction').")
    results = []
    for row in reader:
        # Map lower-cased, whitespace-stripped keys to values
        row_lower = {k.lower().strip(): v for k, v in row.items()}
        id_val = row_lower.get("id")
        pred_val = row_lower.get(pred_col)
        results.append({"ID": id_val, "Prediction": pred_val})
    return results
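# A minimal usage sketch for parse_submission_csv: the header names are the
# ones the parser accepts above, while the ID and phoneme values are
# illustrative only:
#
#     rows = parse_submission_csv("id,predicted_phoneme\n0001,b i s m i\n")
#     # rows == [{"ID": "0001", "Prediction": "b i s m i"}]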
def calculate_comprehensive_metrics(submission_file_obj, references):
    """Align a submission against the references and return (metrics dict, error message)."""
    metrics = {}
    error_message = None
    if not references:
        return None, "Error: Ground truth references could not be loaded. Please ensure 'ground_truth.csv' exists locally or HF_TOKEN is set to access 'IqraEval/QuranMB.v2.labels'."
    temp_dir = tempfile.mkdtemp()
    try:
        temp_aligned_dir = os.path.join(temp_dir, "aligned")
        os.makedirs(temp_aligned_dir, exist_ok=True)
        # Prepare the ground truth DataFrame
        truth_df = pd.DataFrame(references)
        # Ensure the expected column names
        truth_df = truth_df.rename(columns={"id": "ID", "reference_phn": "Reference_phn", "annotation_phn": "Annotation_phn"})
        # Prepare the prediction DataFrame
        predictions = parse_submission_csv(submission_file_obj)
        pred_df = pd.DataFrame(predictions)
        if not all(col in truth_df.columns for col in ["ID", "Reference_phn", "Annotation_phn"]):
            return None, "Error: Ground truth references missing required columns."
        if not all(col in pred_df.columns for col in ["ID", "Prediction"]):
            return None, "Error: Submission missing 'ID' or 'Prediction' columns."
        # Validate that the submitted IDs match the ground truth
        truth_ids = set(truth_df["ID"].astype(str).str.strip())
        pred_ids = set(pred_df["ID"].astype(str).str.strip())
        if len(truth_ids) != len(pred_ids):
            return None, f"Error: Mismatch in number of predictions. Expected {len(truth_ids)} IDs, but got {len(pred_ids)}. Please ensure your submission covers the entire test set."
        missing_ids = truth_ids - pred_ids
        if missing_ids:
            # Show a few missing IDs as examples
            example_missing = list(missing_ids)[:3]
            return None, f"Error: Submission IDs do not match ground truth. Missing IDs example: {example_missing}"
        # --- Step 1: Run alignment ---
        corr_rate, acc = evaluate_from_dfs(
            truth_df=truth_df,
            pred_df=pred_df,
            output_dir=temp_aligned_dir,
            wov=False,
            print_output=False
        )
        # --- Step 2: Run analysis ---
        metrics = analyze_alignment(temp_aligned_dir)
        if "Error" in metrics:
            return None, f"Analysis Error: {metrics['Error']}"
        # Report the number of samples actually scored
        pred_ids = set(pred_df["ID"].unique())
        truth_ids = set(truth_df["ID"].unique())
        metrics["Samples"] = len(pred_ids.intersection(truth_ids))
        # Add the basic metrics if missing
        if "Accuracy" not in metrics:
            metrics["Accuracy"] = acc
        if "PER" not in metrics:
            metrics["PER"] = 1.0 - acc
    except Exception as e:
        error_message = f"Error during metric calculation: {str(e)}"
        print(traceback.format_exc())
    finally:
        # Clean up the temporary files
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
    return metrics if not error_message else None, error_message
def calculate_per_score(submission_file_obj, references):
    """Legacy function for backward compatibility."""
    metrics, error_message = calculate_comprehensive_metrics(submission_file_obj, references)
    if metrics is None:
        return 0.0, 0
    return metrics.get("PER", 0.0), metrics.get("Samples", 0)
custom_css = """
#leaderboard-table td, #leaderboard-table th {
white-space: nowrap;
min-width: 100px;
text-align: center !important;
}
#leaderboard-table th > div, #leaderboard-table th > span {
justify-content: center !important;
text-align: center !important;
width: 100%;
display: flex;
}
"""