Spaces:
Running
Running
| import pandas as pd | |
| import numpy as np | |
| import argparse | |
| from pathlib import Path | |
| import os | |
| try: | |
| from metric import Correct_Rate, Accuracy, Align, insertions, deletions, substitutions | |
| except ImportError: | |
| # Handle import when running from parent directory or installed as package | |
| from .metric import Correct_Rate, Accuracy, Align, insertions, deletions, substitutions | |
def get_op(seq1, seq2):
    """Label each aligned token pair with its edit operation.

    Tokens equal to the gap symbol "<eps>" mark alignment gaps.
    Returns a list with one code per pair:
    'D' deletion (ref token aligned to a gap), 'I' insertion
    (gap aligned to a hyp token), 'S' substitution, 'C' correct.
    """
    ops = []
    for ref_tok, hyp_tok in zip(seq1, seq2):
        if ref_tok != "<eps>" and hyp_tok == "<eps>":
            ops.append('D')
        elif ref_tok == "<eps>" and hyp_tok != "<eps>":
            ops.append('I')
        elif ref_tok != hyp_tok:
            ops.append('S')
        else:
            ops.append('C')
    return ops
def get_align(k, s1, s2):
    """Align two phoneme sequences and format a four-line detail record.

    `k` is the utterance ID used to prefix every line. Returns a list of
    four strings: the aligned reference, the aligned hypothesis, the
    per-token edit-operation codes, and a '#csid' summary line with the
    correct/substitution/insertion/deletion counts.
    """
    ref_aln, hyp_aln = Align(s1, s2)
    n_ins = insertions(ref_aln, hyp_aln)
    # NOTE(review): deletions() appears to return a sequence whose first
    # element is the count, unlike insertions()/substitutions() — confirm
    # against metric.py.
    n_del = deletions(ref_aln, hyp_aln)[0]
    n_sub = substitutions(ref_aln, hyp_aln)
    # Correct tokens are whatever is left once the three error types are removed.
    n_cor = len(ref_aln) - n_ins - n_del - n_sub
    return [
        f"{k} ref {' '.join(ref_aln)}",
        f"{k} hyp {' '.join(hyp_aln)}",
        f"{k} op {' '.join(get_op(ref_aln, hyp_aln))}",
        f"{k} #csid {n_cor} {n_sub} {n_ins} {n_del}",
    ]
def evaluate_from_dfs(truth_df: pd.DataFrame,
                      pred_df: pd.DataFrame,
                      output_dir: str = 'aligned',
                      wov: bool = False,
                      print_output: bool = True, output_text=None):
    """Evaluate predictions against human annotations on overlapping IDs.

    Writes three per-utterance alignment detail files into `output_dir`
    (reference-vs-prediction, reference-vs-annotation, annotation-vs-
    prediction), then computes length-weighted Correct Rate and Accuracy
    of the predictions against the human annotation.

    Parameters
    ----------
    truth_df : DataFrame with 'ID' plus the reference/annotation phoneme
        columns (the `_vow_phn` variants when `wov` is True).
    pred_df : DataFrame with 'ID' and 'Prediction' columns.
    output_dir : directory for the alignment detail files (created if absent).
    wov : use the vowel-only phoneme columns instead of the full ones.
    print_output : print the summary to stdout.
    output_text : optional path; summary lines are appended to this file.

    Returns
    -------
    (corr_rate, acc) rounded to 4 decimals; (0.0, 0.0) when no rows
    could be scored.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    ref_col = 'Reference_vow_phn' if wov else 'Reference_phn'
    ann_col = 'Annotation_vow_phn' if wov else 'Annotation_phn'

    # Score only IDs present in both frames, keeping just the needed columns.
    valid_ids = set(truth_df['ID']) & set(pred_df['ID'])
    truth_df = truth_df[truth_df['ID'].isin(valid_ids)][['ID', ref_col, ann_col]]
    pred_df = pred_df[pred_df['ID'].isin(valid_ids)][['ID', 'Prediction']]
    df = pd.merge(truth_df, pred_df, on='ID')

    list_correct_rate, list_accuracy, list_len = [], [], []
    # BUGFIX: the handles were previously plain open()/close() pairs, so an
    # exception inside the loop leaked all three files. Context managers
    # guarantee they are closed on every path.
    with open(output_dir / 'ref_our_detail', 'w') as canon_pred_f, \
         open(output_dir / 'ref_human_detail', 'w') as canon_trans_f, \
         open(output_dir / 'human_our_detail', 'w') as trans_pred_f:
        for _, row in df.iterrows():
            data_id = row['ID']
            pred = row['Prediction'].split()
            trans = row[ann_col].split()
            canon = row[ref_col].split()
            canon_pred_f.write('\n'.join(get_align(data_id, canon, pred)) + '\n')
            canon_trans_f.write('\n'.join(get_align(data_id, canon, trans)) + '\n')
            trans_pred_f.write('\n'.join(get_align(data_id, trans, pred)) + '\n')
            correct_rate, len_, _ = Correct_Rate(trans, pred)
            acc, len_ = Accuracy(trans, pred)
            if len_ == 0:
                # BUGFIX: an empty annotation previously raised
                # ZeroDivisionError; such rows carry no score weight, skip them.
                continue
            list_correct_rate.append((len_ - correct_rate) / len_)
            list_accuracy.append((len_ - acc) / len_)
            list_len.append(len_)

    if not list_len:
        # Empty intersection or no scorable rows.
        if print_output:
            print("No valid overlapping IDs found.")
        return 0.0, 0.0

    # Weight each utterance's rate by its annotation length.
    weights = np.array(list_len)
    corr_rate = round(np.sum(np.array(list_correct_rate) * weights) / weights.sum(), 4)
    acc = round(np.sum(np.array(list_accuracy) * weights) / weights.sum(), 4)
    if print_output:
        print("** MD&D Evaluation **")
        print("Correct Rate:", corr_rate)
        print("Accuracy:", acc)
    if output_text:
        with open(output_text, "a") as w:
            w.write(f"Correct Rate: {corr_rate}\n")
            w.write(f"Accuracy: {acc}\n")
    return corr_rate, acc
def main():
    """CLI entry point: load the truth and prediction CSVs and evaluate."""
    parser = argparse.ArgumentParser(
        description="Evaluate prediction accuracy against annotated and canonical phonemes."
    )
    parser.add_argument(
        '--truth-file', type=str, required=True,
        help="Path to CSV containing truth data (with canonical and annotated phonemes)")
    parser.add_argument(
        '--pred-file', type=str, required=True,
        help="Path to CSV containing predictions (with ID and Prediction columns)")
    parser.add_argument(
        '--output-dir', type=str, default='./aligned',
        help="Directory to save alignment output files")
    parser.add_argument(
        '--wov', action='store_true',
        help="Use vowel-only phoneme columns instead of full ones")
    parser.add_argument(
        '--output_text', type=str, default=None,
        help="File to save text output")
    args = parser.parse_args()

    evaluate_from_dfs(
        truth_df=pd.read_csv(args.truth_file),
        pred_df=pd.read_csv(args.pred_file),
        output_dir=args.output_dir,
        wov=args.wov,
        output_text=args.output_text,
    )
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == '__main__':
    main()