"""Build the global ANLI error buffer (round 14) for the SRL pipeline.

Loads the three per-round ANLI error CSVs (R1/R2/R3 dev), concatenates them,
performs a seeded, label-stratified 80/20 train/val split, reports class and
error distributions, and writes the two output CSVs consumed by the SRL
buffer-rebalance script.
"""

import os

import pandas as pd
from sklearn.model_selection import train_test_split

# ============================
# INPUT FILES (already created)
# ============================
BASE_ANALYSIS_DIR = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\analysis"

ANLI_R1_CSV = os.path.join(BASE_ANALYSIS_DIR, "anli_error_buffer_anli_r1_dev_round14.csv")
ANLI_R2_CSV = os.path.join(BASE_ANALYSIS_DIR, "anli_error_buffer_anli_r2_dev_round14.csv")
ANLI_R3_CSV = os.path.join(BASE_ANALYSIS_DIR, "anli_error_buffer_anli_r3_dev_round14.csv")

# ============================
# OUTPUT FILES (global ANLI buffer)
# ============================
OUT_TRAIN = os.path.join(BASE_ANALYSIS_DIR, "global_error_buffer_anli_round14_train.csv")
OUT_VAL = os.path.join(BASE_ANALYSIS_DIR, "global_error_buffer_anli_round14_val.csv")

RANDOM_SEED = 42
VAL_RATIO = 0.20  # 80% train / 20% val

# Label-id -> name mapping used for the distribution report (matches the
# mapping this script has always printed).
LABEL_NAMES = {0: "entailment", 1: "neutral", 2: "contradiction"}

# Columns the downstream SRL pipeline requires in every buffer row.
REQUIRED_COLS = ["premise", "hypothesis", "true_label_id", "is_error"]


def _show_dist(name: str, df: pd.DataFrame) -> None:
    """Print the class distribution and error rate for one split.

    Guards against an empty frame so a degenerate split prints 0.0%
    instead of raising ZeroDivisionError.
    """
    total = len(df)
    print(f"\n{name} - class distribution:")
    for label_id, label_name in LABEL_NAMES.items():
        count = (df["true_label_id"] == label_id).sum()
        pct = 100.0 * count / total if total else 0.0
        print(f"  {label_name}: {count} ({pct:.1f}%)")
    errors = df["is_error"].sum()
    err_pct = 100.0 * errors / total if total else 0.0
    print(f"{name} - errors: {errors} ({err_pct:.1f}%), correct: {total - errors}")


def main() -> None:
    """Load, merge, validate, split, report, and save the ANLI error buffers."""
    print("============================================================")
    print("BUILD GLOBAL ANLI ERROR BUFFER (ROUND 1 → SRL SOURCE)")
    print("============================================================")

    # 1) Load the three ANLI error CSVs (loop instead of three copies).
    print("\nLoading ANLI error buffers...")
    round_frames = []
    for tag, path in (("R1", ANLI_R1_CSV), ("R2", ANLI_R2_CSV), ("R3", ANLI_R3_CSV)):
        df_round = pd.read_csv(path)
        print(f"  {tag} rows: {len(df_round)}")
        round_frames.append(df_round)

    # 2) Concatenate
    df_all = pd.concat(round_frames, ignore_index=True)
    print(f"\nTotal ANLI rows (R1+R2+R3): {len(df_all)}")

    # Sanity: required columns for SRL pipeline
    missing = [c for c in REQUIRED_COLS if c not in df_all.columns]
    if missing:
        raise ValueError(f"Missing required columns in ANLI buffers: {missing}")

    # 3) Shuffle + split into train/val.
    # NOTE: the pre-shuffle before train_test_split is redundant for
    # randomness, but it is kept deliberately so the seeded split (and the
    # saved output files) stay byte-identical with earlier runs.
    df_all = df_all.sample(frac=1.0, random_state=RANDOM_SEED).reset_index(drop=True)

    train_df, val_df = train_test_split(
        df_all,
        test_size=VAL_RATIO,
        random_state=RANDOM_SEED,
        shuffle=True,
        stratify=df_all["true_label_id"],  # keep class balance
    )

    print(f"\nTrain size: {len(train_df)}")
    print(f"Val size: {len(val_df)}")

    # 4) Show distributions
    _show_dist("TRAIN", train_df)
    _show_dist("VAL", val_df)

    # 5) Save
    train_df.to_csv(OUT_TRAIN, index=False, encoding="utf-8")
    val_df.to_csv(OUT_VAL, index=False, encoding="utf-8")

    print("\nSaved:")
    print(f"  Train: {OUT_TRAIN}")
    print(f"  Val  : {OUT_VAL}")

    print("\nāœ… Global ANLI error buffers are ready for SRL.")
    print("Use them as input to the SRL buffer rebalance script.")


if __name__ == "__main__":
    main()