import pandas as pd
import evaluate
import csv
import io
import tempfile
import os
import shutil
import sys
import traceback

# Dynamically add the mdd_eval directory to the path so its modules can be imported
current_dir = os.path.dirname(os.path.abspath(__file__))
mdd_eval_path = os.path.join(current_dir, "mdd_eval")
if mdd_eval_path not in sys.path:
    sys.path.append(mdd_eval_path)

try:
    from mdd_eval.align_data import evaluate_from_dfs
    from mdd_eval.ins_del_cor_sub_analysis import analyze_alignment
except ImportError:
    # If the package-style import fails, fall back to the bare modules made
    # importable by the sys.path entry added above.
    import align_data
    import ins_del_cor_sub_analysis
    evaluate_from_dfs = align_data.evaluate_from_dfs
    analyze_alignment = ins_del_cor_sub_analysis.analyze_alignment

wer = evaluate.load("wer")

def load_leaderboard(db_path):
    try:
        if os.path.exists(db_path):
            return pd.read_csv(db_path)
        return pd.DataFrame()
    except Exception:
        # Return an empty DataFrame on read errors (e.g. pd.errors.EmptyDataError)
        return pd.DataFrame()

try:
    from datasets import load_dataset
except ImportError:
    load_dataset = None

IDS_CACHE_FILE = "IDs.txt"

def get_allowed_ids():
    print("DEBUG: Entering get_allowed_ids...", flush=True)
    # Return the set of allowed IDs from the local cache, falling back to Hugging Face
    if os.path.exists(IDS_CACHE_FILE):
        print(f"DEBUG: Found {IDS_CACHE_FILE}, reading...", flush=True)
        with open(IDS_CACHE_FILE, 'r') as f:
            ids = set(line.strip() for line in f if line.strip())
        if ids:
            print(f"Loaded {len(ids)} allowed IDs from {IDS_CACHE_FILE}", flush=True)
            return ids
    # Not cached: fetch from the Hugging Face Hub
    hf_token = os.environ.get("SPACE_HF_TOKEN") or os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    print(f"DEBUG: Checking HF_TOKEN: {'Present' if hf_token else 'Missing'}", flush=True)
    if hf_token and load_dataset:
        try:
            print("Fetching allowed IDs from IqraEval/QuranMB.v2 (split='test')...", flush=True)
            # Use streaming to avoid downloading the audio files
            dataset = load_dataset("IqraEval/QuranMB.v2", split="test", token=hf_token, streaming=True)
            allowed_ids = set()
            print("DEBUG: Iterating dataset...", flush=True)
            count = 0
            for item in dataset:
                if "ID" in item:
                    allowed_ids.add(str(item["ID"]).strip())
                elif "id" in item:
                    allowed_ids.add(str(item["id"]).strip())
                count += 1
                if count % 1000 == 0:
                    print(f"DEBUG: Processed {count} items...", flush=True)
            if allowed_ids:
                with open(IDS_CACHE_FILE, 'w') as f:
                    for i in sorted(allowed_ids):
                        f.write(f"{i}\n")
                print(f"Cached {len(allowed_ids)} IDs to {IDS_CACHE_FILE}", flush=True)
                return allowed_ids
            else:
                print("DEBUG: No IDs found in dataset.", flush=True)
        except Exception as e:
            print(f"Error fetching allowed IDs: {e}", flush=True)
    print("DEBUG: Returning None from get_allowed_ids", flush=True)
    return None
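
# Usage sketch (not called by the app): a minimal, hedged example of warming the
# ID cache ahead of deployment, assuming HF_TOKEN is already set in the environment.
def _demo_warm_id_cache():
    ids = get_allowed_ids()  # writes IDS_CACHE_FILE on the first successful fetch
    print(f"Cached {len(ids)} allowed IDs" if ids else "No allowed IDs available")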

def load_ground_truth_references(ground_truth_path):
    # Get the allowed IDs first so references can be filtered to the test set
    allowed_ids = get_allowed_ids()
    # Attempt to load from Hugging Face if a token is present
    hf_token = os.environ.get("SPACE_HF_TOKEN") or os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if hf_token and load_dataset:
        try:
            print("Attempting to load ground truth from Hugging Face (IqraEval/QuranMB.v2.labels)...")
            # The labels dataset exposes its rows under the 'train' split
            dataset = load_dataset("IqraEval/QuranMB.v2.labels", split="train", token=hf_token)
            references = []
            for item in dataset:
                rid = str(item.get("ID", "")).strip()
                # Skip rows outside the allowed ID set, if one exists
                if allowed_ids and rid not in allowed_ids:
                    continue
                ref_data = {
                    "ID": rid,
                    "Reference_phn": item.get("reference"),
                    "Annotation_phn": item.get("canonical"),
                }
                # Fallback mapping in case the keys differ slightly in the actual dataset
                if not ref_data["Reference_phn"] and "Reference_phn" in item:
                    ref_data["Reference_phn"] = item["Reference_phn"]
                if not ref_data["Annotation_phn"] and "Annotation_phn" in item:
                    ref_data["Annotation_phn"] = item["Annotation_phn"]
                references.append(ref_data)
            print(f"Successfully loaded {len(references)} filtered references from Hugging Face.")
            return references
        except Exception as e:
            print(f"Failed to load from Hugging Face: {e}")
            print("Falling back to local CSV file.")
    # Fall back to the local file
    if not os.path.exists(ground_truth_path):
        print(f"Warning: Ground truth file not found at {ground_truth_path}")
        return []
    with open(ground_truth_path, newline='') as f:
        reader = csv.DictReader(f)
        references = []
        for row in reader:
            rid = str(row.get("id", "")).strip()
            # Skip rows outside the allowed ID set, if one exists
            if allowed_ids and rid not in allowed_ids:
                continue
            ref_data = {"ID": rid}
            if "Reference_phn" in row:
                ref_data["Reference_phn"] = row["Reference_phn"]
            elif "reference_phn" in row:
                ref_data["Reference_phn"] = row["reference_phn"]
            elif "correct_phoneme" in row:
                ref_data["Reference_phn"] = row["correct_phoneme"]
            if "Annotation_phn" in row:
                ref_data["Annotation_phn"] = row["Annotation_phn"]
            elif "annotation_phn" in row:
                ref_data["Annotation_phn"] = row["annotation_phn"]
            elif "Reference_phn" in ref_data:
                ref_data["Annotation_phn"] = ref_data["Reference_phn"]
            references.append(ref_data)
    print(f"Loaded {len(references)} filtered references from local file.")
    return references
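
# Illustrative only: the shape of a single reference row produced by the loader
# above. The ID and phoneme strings are placeholders, not real dataset values.
_EXAMPLE_REFERENCE_ROW = {
    "ID": "0001",
    "Reference_phn": "q u l",
    "Annotation_phn": "q u l",
}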

def parse_submission_csv(file_obj):
    # Handle the different input types we may receive (bytes, file-like object,
    # file path, or raw CSV string)
    content = ""
    if isinstance(file_obj, bytes):
        content = file_obj.decode("utf-8")
    elif hasattr(file_obj, 'read'):  # file-like object
        try:
            content = file_obj.read()
            if isinstance(content, bytes):
                content = content.decode("utf-8")
        except Exception:
            # Not readable in the way we expected; fall through to the checks below
            pass
    if not content and isinstance(file_obj, str):
        if os.path.exists(file_obj):  # file path
            with open(file_obj, 'r', encoding='utf-8') as f:
                content = f.read()
        else:
            # Treat the string itself as the CSV content
            content = file_obj
    if not content:
        raise ValueError("Could not read content from submission file.")
    text_stream = io.StringIO(content)
    reader = csv.DictReader(text_stream)
    if not reader.fieldnames:
        raise ValueError("CSV file is empty or missing headers.")
    # Case-insensitive header check with surrounding whitespace stripped
    lower_fieldnames = [f.lower().strip() for f in reader.fieldnames]
    # Validate the required columns
    id_col = None
    pred_col = None
    if "id" in lower_fieldnames:
        id_col = "id"
    if "predicted_phoneme" in lower_fieldnames:
        pred_col = "predicted_phoneme"
    elif "labels" in lower_fieldnames:
        pred_col = "labels"
    elif "prediction" in lower_fieldnames:  # accepted for resilience
        pred_col = "prediction"
    if not id_col or not pred_col:
        raise ValueError("Submission CSV must contain columns: 'id' and 'predicted_phoneme' (or 'labels'/'prediction').")
    results = []
    for row in reader:
        # Lower-case and strip the keys so lookups match the validated names
        row_lower = {k.lower().strip(): v for k, v in row.items()}
        results.append({"ID": row_lower.get("id"), "Prediction": row_lower.get(pred_col)})
    return results
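
# Usage sketch (not called by the app): parse_submission_csv accepts raw CSV text,
# so a minimal submission can be checked directly. The rows below are placeholders.
def _demo_parse_submission():
    sample_csv = "id,predicted_phoneme\n0001,q u l\n0002,b i s m\n"
    return parse_submission_csv(sample_csv)  # [{'ID': '0001', 'Prediction': 'q u l'}, ...]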

def calculate_comprehensive_metrics(submission_file_obj, references):
    metrics = {}
    error_message = None
    if not references:
        return None, ("Error: Ground truth references could not be loaded. Please ensure "
                      "'ground_truth.csv' exists locally or HF_TOKEN is set to access "
                      "'IqraEval/QuranMB.v2.labels'.")
    temp_dir = tempfile.mkdtemp()
    try:
        temp_aligned_dir = os.path.join(temp_dir, "aligned")
        os.makedirs(temp_aligned_dir, exist_ok=True)
        # Prepare the ground truth DataFrame with the expected column names
        truth_df = pd.DataFrame(references)
        truth_df = truth_df.rename(columns={"id": "ID", "reference_phn": "Reference_phn", "annotation_phn": "Annotation_phn"})
        # Prepare the prediction DataFrame
        predictions = parse_submission_csv(submission_file_obj)
        pred_df = pd.DataFrame(predictions)
        if not all(col in truth_df.columns for col in ["ID", "Reference_phn", "Annotation_phn"]):
            return None, "Error: Ground truth references missing required columns."
        if not all(col in pred_df.columns for col in ["ID", "Prediction"]):
            return None, "Error: Submission missing 'ID' or 'Prediction' columns."
        # Validate that the submitted IDs match the ground truth IDs
        truth_ids = set(truth_df["ID"].astype(str).str.strip())
        pred_ids = set(pred_df["ID"].astype(str).str.strip())
        if len(truth_ids) != len(pred_ids):
            return None, (f"Error: Mismatch in number of predictions. Expected {len(truth_ids)} IDs, "
                          f"but got {len(pred_ids)}. Please ensure your submission covers the entire test set.")
        missing_ids = truth_ids - pred_ids
        if missing_ids:
            # Show a few missing IDs as examples
            example_missing = list(missing_ids)[:3]
            return None, f"Error: Submission IDs do not match ground truth. Missing IDs example: {example_missing}"
        # --- Step 1: Run alignment ---
        corr_rate, acc = evaluate_from_dfs(
            truth_df=truth_df,
            pred_df=pred_df,
            output_dir=temp_aligned_dir,
            wov=False,
            print_output=False,
        )
        # --- Step 2: Run analysis ---
        metrics = analyze_alignment(temp_aligned_dir)
        if "Error" in metrics:
            return None, f"Analysis Error: {metrics['Error']}"
        # Ensure the sample count is included
        pred_ids = set(pred_df["ID"].unique())
        truth_ids = set(truth_df["ID"].unique())
        metrics["Samples"] = len(pred_ids.intersection(truth_ids))
        # Add the basic metrics if the analysis did not produce them
        if "Accuracy" not in metrics:
            metrics["Accuracy"] = acc
        if "PER" not in metrics:
            metrics["PER"] = 1.0 - acc
    except Exception as e:
        error_message = f"Error during metric calculation: {str(e)}"
        print(traceback.format_exc())
    finally:
        # Clean up temporary files
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
    return (metrics if not error_message else None), error_message

def calculate_per_score(submission_file_obj, references):
    """Legacy wrapper kept for backward compatibility; returns (PER, sample count)."""
    metrics, error_message = calculate_comprehensive_metrics(submission_file_obj, references)
    if metrics is None:
        return 0.0, 0
    return metrics.get("PER", 0.0), metrics.get("Samples", 0)
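
# Usage sketch (not called by the app): end-to-end scoring against a local
# submission. Both file names below are hypothetical placeholders.
def _demo_score_submission():
    references = load_ground_truth_references("ground_truth.csv")
    metrics, error = calculate_comprehensive_metrics("submission.csv", references)
    if error:
        print(error)
    else:
        print(f"PER: {metrics.get('PER', 0.0):.4f} over {metrics.get('Samples', 0)} samples")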

custom_css = """
#leaderboard-table td, #leaderboard-table th {
    white-space: nowrap;
    min-width: 100px;
    text-align: center !important;
}
#leaderboard-table th > div, #leaderboard-table th > span {
    justify-content: center !important;
    text-align: center !important;
    width: 100%;
    display: flex;
}
"""