"""Score phoneme predictions against reference and annotated transcriptions.

For every utterance ID present in both the truth table and the prediction
table, this script writes three alignment detail files (reference vs.
prediction, reference vs. annotation, annotation vs. prediction) and reports a
length-weighted Correct Rate and Accuracy of the prediction measured against
the annotation.
"""
import argparse
import os
from pathlib import Path

import numpy as np
import pandas as pd

try:
    from metric import Correct_Rate, Accuracy, Align, insertions, deletions, substitutions
except ImportError:
    # Handle import when running from parent directory or installed as package
    from .metric import Correct_Rate, Accuracy, Align, insertions, deletions, substitutions


def _classify_pair(s1, s2):
    """Return the edit-operation code for one aligned pair of tokens."""
    if s1 != "" and s2 == "":
        return 'D'
    if s1 == "" and s2 != "":
        return 'I'
    if s1 != s2:
        return 'S'
    return 'C'


def get_op(seq1, seq2):
    """Label each aligned position of two sequences with an operation code.

    An empty string in one sequence is treated as a gap: 'D' when only
    ``seq2`` has the gap, 'I' when only ``seq1`` has it. Otherwise the pair
    is 'S' when the tokens differ and 'C' when they are equal. Iteration
    stops at the shorter sequence (``zip`` semantics).

    Returns:
        A list of single-character codes, one per zipped pair.
    """
    return [_classify_pair(s1, s2) for s1, s2 in zip(seq1, seq2)]


def get_align(k, s1, s2):
    """Build the four report lines describing the alignment of ``s1``/``s2``.

    Args:
        k: Utterance key used as the prefix of every line. It is converted
            with ``str`` so that numeric IDs (e.g. integer IDs parsed from a
            CSV) do not break the string concatenation.
        s1: Token sequence written on the ``ref`` line.
        s2: Token sequence written on the ``hyp`` line.

    Returns:
        A list of four strings: the ``ref`` line, the ``hyp`` line, the
        ``op`` line, and a ``#csid`` line with the C, S, I, D counts.
    """
    k = str(k)
    a1, a2 = Align(s1, s2)
    n_ins = insertions(a1, a2)
    # Only the first element of the deletions() result is used as the count.
    n_del = deletions(a1, a2)[0]
    n_sub = substitutions(a1, a2)
    # Whatever is not an insertion, deletion or substitution counts as correct.
    n_cor = len(a1) - n_ins - n_del - n_sub
    return [
        k + ' ref ' + ' '.join(a1),
        k + ' hyp ' + ' '.join(a2),
        k + ' op ' + ' '.join(get_op(a1, a2)),
        k + f' #csid {n_cor} {n_sub} {n_ins} {n_del}'
    ]


def _split_phonemes(value):
    """Split a whitespace-separated phoneme cell into a list of tokens.

    A missing cell (pandas reads an empty CSV field as NaN) yields an empty
    list instead of raising ``AttributeError`` on ``float.split``.
    """
    if pd.isna(value):
        return []
    return str(value).split()


def evaluate_from_dfs(truth_df: pd.DataFrame, pred_df: pd.DataFrame, output_dir: str = 'aligned', wov: bool = False, print_output: bool = True, output_text=None):
    """Write alignment details and compute weighted Correct Rate / Accuracy.

    Args:
        truth_df: Table with an ``ID`` column plus the reference and
            annotation phoneme columns (``Reference_phn``/``Annotation_phn``,
            or the ``*_vow_phn`` variants when ``wov`` is true).
        pred_df: Table with ``ID`` and ``Prediction`` columns.
        output_dir: Directory (created if needed) that receives the files
            ``ref_our_detail``, ``ref_human_detail`` and ``human_our_detail``.
        wov: Select the ``*_vow_phn`` columns instead of the ``*_phn`` ones.
        print_output: Print the summary (or the no-data notice) to stdout.
        output_text: Optional path; when truthy, the two scores are appended
            to this file.

    Returns:
        ``(correct_rate, accuracy)``, each rounded to 4 decimals, or
        ``(0.0, 0.0)`` when no row could be scored.

    Raises:
        KeyError: If a required column is missing from either table.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    ref_col = 'Reference_vow_phn' if wov else 'Reference_phn'
    ann_col = 'Annotation_vow_phn' if wov else 'Annotation_phn'

    valid_ids = set(truth_df['ID']) & set(pred_df['ID'])
    truth_df = truth_df[truth_df['ID'].isin(valid_ids)][['ID', ref_col, ann_col]]
    pred_df = pred_df[pred_df['ID'].isin(valid_ids)][['ID', 'Prediction']]

    # Merge the necessary columns only
    df = pd.merge(truth_df, pred_df, on='ID')

    ref_our_detail_path = output_dir / 'ref_our_detail'
    ref_human_detail_path = output_dir / 'ref_human_detail'
    human_our_detail_path = output_dir / 'human_our_detail'

    list_correct_rate, list_accuracy, list_len = [], [], []

    # Context managers guarantee the three files are closed even if a row
    # raises part-way through; UTF-8 keeps non-ASCII phoneme symbols writable
    # regardless of the platform's default encoding.
    with open(ref_our_detail_path, 'w', encoding='utf-8') as canon_pred_f, \
            open(ref_human_detail_path, 'w', encoding='utf-8') as canon_trans_f, \
            open(human_our_detail_path, 'w', encoding='utf-8') as trans_pred_f:
        for _, row in df.iterrows():
            data_id = row['ID']
            pred = _split_phonemes(row['Prediction'])
            trans = _split_phonemes(row[ann_col])
            canon = _split_phonemes(row[ref_col])

            canon_pred_f.write('\n'.join(get_align(data_id, canon, pred)) + '\n')
            canon_trans_f.write('\n'.join(get_align(data_id, canon, trans)) + '\n')
            trans_pred_f.write('\n'.join(get_align(data_id, trans, pred)) + '\n')

            # NOTE(review): the first return values are presumably error
            # counts (they are subtracted from the length below) -- confirm
            # against the metric module.
            cr_count, _cr_len, _unused = Correct_Rate(trans, pred)
            # The length reported by Accuracy() is the one used for both
            # scores and for the weight, as in the original computation.
            acc_count, ref_len = Accuracy(trans, pred)

            if ref_len == 0:
                # A zero length would divide by zero below and would carry
                # zero weight anyway, so the row is left out of the scores.
                continue

            list_correct_rate.append((ref_len - cr_count) / ref_len)
            list_accuracy.append((ref_len - acc_count) / ref_len)
            list_len.append(ref_len)

    if not list_len:
        # Handle empty intersection or empty input
        if print_output:
            print("No valid overlapping IDs found.")
        return 0.0, 0.0

    # Each utterance is weighted by its length so long utterances count more.
    weights = np.array(list_len)
    corr_rate = round(np.sum(np.array(list_correct_rate) * weights) / weights.sum(), 4)
    acc = round(np.sum(np.array(list_accuracy) * weights) / weights.sum(), 4)

    if print_output:
        print("** MD&D Evaluation **")
        print("Correct Rate:", corr_rate)
        print("Accuracy:", acc)

    if output_text:
        # Append so that repeated runs accumulate in the same report file.
        with open(output_text, "a", encoding='utf-8') as w:
            w.write(f"Correct Rate: {corr_rate}\n")
            w.write(f"Accuracy: {acc}\n")

    return corr_rate, acc


def main():
    """Parse command-line arguments, load both CSV files and run evaluation."""
    parser = argparse.ArgumentParser(description="Evaluate prediction accuracy against annotated and canonical phonemes.")
    parser.add_argument('--truth-file', type=str, required=True,
                        help="Path to CSV containing truth data (with canonical and annotated phonemes)")
    parser.add_argument('--pred-file', type=str, required=True,
                        help="Path to CSV containing predictions (with ID and Prediction columns)")
    parser.add_argument('--output-dir', type=str, default='./aligned',
                        help="Directory to save alignment output files")
    parser.add_argument('--wov', action='store_true',
                        help="Use vowel-only phoneme columns instead of full ones")
    # '--output-text' matches the hyphenated style of the other options;
    # the original '--output_text' spelling is kept as an alias.
    parser.add_argument('--output-text', '--output_text', dest='output_text', type=str, default=None,
                        help="File to save text output")

    args = parser.parse_args()

    # Load data
    truth_df = pd.read_csv(args.truth_file)
    pred_df = pd.read_csv(args.pred_file)

    # Run evaluation
    evaluate_from_dfs(
        truth_df=truth_df,
        pred_df=pred_df,
        output_dir=args.output_dir,
        wov=args.wov,
        output_text=args.output_text
    )


if __name__ == '__main__':
    main()