|
|
import os
|
|
|
import pandas as pd
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Root directory holding the per-round ANLI error-buffer CSVs.
# NOTE(review): absolute Windows/OneDrive path — this script only runs on
# this exact machine/layout; consider making it configurable.
BASE_ANALYSIS_DIR = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\analysis"

# Inputs: error buffers for the ANLI dev rounds R1/R2/R3 (round-14 run).
ANLI_R1_CSV = os.path.join(BASE_ANALYSIS_DIR, "anli_error_buffer_anli_r1_dev_round14.csv")

ANLI_R2_CSV = os.path.join(BASE_ANALYSIS_DIR, "anli_error_buffer_anli_r2_dev_round14.csv")

ANLI_R3_CSV = os.path.join(BASE_ANALYSIS_DIR, "anli_error_buffer_anli_r3_dev_round14.csv")

# Outputs: the merged global buffer, split into train/validation CSVs.
OUT_TRAIN = os.path.join(BASE_ANALYSIS_DIR, "global_error_buffer_anli_round14_train.csv")

OUT_VAL = os.path.join(BASE_ANALYSIS_DIR, "global_error_buffer_anli_round14_val.csv")

# Seed shared by the pre-split shuffle and train_test_split for reproducibility.
RANDOM_SEED = 42

# Fraction of rows held out for validation (split is stratified by true_label_id).
VAL_RATIO = 0.20
|
|
|
|
|
|
|
|
|
def main():
    """Build a global ANLI error buffer and split it for SRL training.

    Loads the R1/R2/R3 ANLI error-buffer CSVs, concatenates them,
    validates the columns the downstream SRL pipeline needs, shuffles
    deterministically, performs a stratified train/val split, prints
    distribution diagnostics, and writes OUT_TRAIN / OUT_VAL.

    Raises:
        FileNotFoundError: if any of the three input CSVs is missing.
        ValueError: if a required column is absent from the merged frame.
    """
    print("============================================================")
    print("BUILD GLOBAL ANLI ERROR BUFFER (ROUND 1 → SRL SOURCE)")
    print("============================================================")

    # Fail fast with a clear message instead of a raw pandas error deep
    # into the run when an input buffer is missing.
    for path in (ANLI_R1_CSV, ANLI_R2_CSV, ANLI_R3_CSV):
        if not os.path.isfile(path):
            raise FileNotFoundError(f"ANLI error buffer not found: {path}")

    print("\nLoading ANLI error buffers...")
    df_r1 = pd.read_csv(ANLI_R1_CSV)
    df_r2 = pd.read_csv(ANLI_R2_CSV)
    df_r3 = pd.read_csv(ANLI_R3_CSV)

    print(f" R1 rows: {len(df_r1)}")
    print(f" R2 rows: {len(df_r2)}")
    print(f" R3 rows: {len(df_r3)}")

    df_all = pd.concat([df_r1, df_r2, df_r3], ignore_index=True)
    print(f"\nTotal ANLI rows (R1+R2+R3): {len(df_all)}")

    # Columns the SRL rebalance step depends on.
    required_cols = ["premise", "hypothesis", "true_label_id", "is_error"]
    missing = [c for c in required_cols if c not in df_all.columns]
    if missing:
        raise ValueError(f"Missing required columns in ANLI buffers: {missing}")

    # Deterministic pre-shuffle: train_test_split shuffles again, but this
    # makes row order independent of the input-file concatenation order.
    df_all = df_all.sample(frac=1.0, random_state=RANDOM_SEED).reset_index(drop=True)

    train_df, val_df = train_test_split(
        df_all,
        test_size=VAL_RATIO,
        random_state=RANDOM_SEED,
        shuffle=True,
        stratify=df_all["true_label_id"],  # keep class balance in both splits
    )

    print(f"\nTrain size: {len(train_df)}")
    print(f"Val size: {len(val_df)}")

    def show_dist(name, df):
        """Print per-class counts/percentages and error-rate stats for one split."""
        print(f"\n{name} - class distribution:")
        total = len(df)
        # Guard the percentage math: the original divided by `total`
        # unconditionally and crashed with ZeroDivisionError on an empty split.
        denom = total if total else 1
        for label_id, label_name in {0: "entailment", 1: "neutral", 2: "contradiction"}.items():
            count = (df["true_label_id"] == label_id).sum()
            print(f" {label_name}: {count} ({100.0 * count / denom:.1f}%)")

        errors = df["is_error"].sum()
        print(f"{name} - errors: {errors} ({100.0 * errors / denom:.1f}%), correct: {total - errors}")

    show_dist("TRAIN", train_df)
    show_dist("VAL", val_df)

    train_df.to_csv(OUT_TRAIN, index=False, encoding="utf-8")
    val_df.to_csv(OUT_VAL, index=False, encoding="utf-8")

    print("\nSaved:")
    print(f" Train: {OUT_TRAIN}")
    print(f" Val : {OUT_VAL}")
    print("\n✅ Global ANLI error buffers are ready for SRL.")
    print("Use them as input to the SRL buffer rebalance script.")
|
|
|
|
|
|
|
|
|
# Script entry point: run the merge/split only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|