File size: 3,380 Bytes
1a6e63a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import pandas as pd
from sklearn.model_selection import train_test_split

# ============================
# INPUT FILES (already created)
# ============================

# Directory holding the per-round ANLI error buffers; the merged train/val
# outputs are written here too. Defaults to the original author's OneDrive
# location. Set the AETHERMIND_ANALYSIS_DIR environment variable (non-empty)
# to point the script at a different directory without editing this file.
_DEFAULT_ANALYSIS_DIR = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\analysis"
BASE_ANALYSIS_DIR = os.environ.get("AETHERMIND_ANALYSIS_DIR") or _DEFAULT_ANALYSIS_DIR

# Per-round ANLI dev-set error buffers (R1, R2, R3) produced by an earlier step.
ANLI_R1_CSV = os.path.join(BASE_ANALYSIS_DIR, "anli_error_buffer_anli_r1_dev_round14.csv")
ANLI_R2_CSV = os.path.join(BASE_ANALYSIS_DIR, "anli_error_buffer_anli_r2_dev_round14.csv")
ANLI_R3_CSV = os.path.join(BASE_ANALYSIS_DIR, "anli_error_buffer_anli_r3_dev_round14.csv")

# ============================
# OUTPUT FILES (global ANLI buffer)
# ============================

OUT_TRAIN = os.path.join(BASE_ANALYSIS_DIR, "global_error_buffer_anli_round14_train.csv")
OUT_VAL   = os.path.join(BASE_ANALYSIS_DIR, "global_error_buffer_anli_round14_val.csv")

RANDOM_SEED = 42    # seeds both the pre-shuffle and the train/val split
VAL_RATIO = 0.20   # 80% train / 20% val


# Label id -> name mapping used when reporting class distributions.
_LABEL_NAMES = {0: "entailment", 1: "neutral", 2: "contradiction"}

# Columns the downstream SRL pipeline needs from every ANLI error buffer.
_REQUIRED_COLS = ("premise", "hypothesis", "true_label_id", "is_error")


def _load_buffer(label, path):
    """Read one per-round ANLI error buffer CSV and check its columns.

    Each file is validated on its own, before concatenation. ``pd.concat``
    fills a column that is missing from only one input with NaN. A
    column-presence check on the combined frame would therefore pass, and
    the NaNs would leak into stratification and the saved buffers.

    Args:
        label: Short round name used in messages (e.g. ``"R1"``).
        path: Path of the CSV to read.

    Returns:
        The loaded DataFrame.

    Raises:
        ValueError: if any column in ``_REQUIRED_COLS`` is absent from the file.
    """
    df = pd.read_csv(path)
    missing = [c for c in _REQUIRED_COLS if c not in df.columns]
    if missing:
        raise ValueError(
            f"Missing required columns in ANLI buffers: {missing} "
            f"(round {label}, file {path})"
        )
    print(f"  {label} rows: {len(df)}")
    return df


def _show_dist(name, df):
    """Print label distribution and error/correct counts for one split.

    Args:
        name: Split name shown in the output (e.g. ``"TRAIN"``).
        df: Split DataFrame with ``true_label_id`` and ``is_error`` columns.
    """
    print(f"\n{name} - class distribution:")
    total = len(df)
    if total == 0:
        # Avoid ZeroDivisionError in the percentage computations below.
        print("  (empty split)")
        return

    for label_id, label_name in _LABEL_NAMES.items():
        count = (df["true_label_id"] == label_id).sum()
        print(f"  {label_name}: {count} ({100.0 * count / total:.1f}%)")

    # Report label ids outside the known set (including NaN) instead of
    # silently leaving them out of the breakdown.
    other = (~df["true_label_id"].isin(list(_LABEL_NAMES))).sum()
    if other:
        print(f"  other/unknown: {other} ({100.0 * other / total:.1f}%)")

    errors = df["is_error"].sum()
    print(f"{name} - errors: {errors} ({100.0 * errors / total:.1f}%), correct: {total - errors}")


def main():
    """Merge the R1/R2/R3 ANLI error buffers into a stratified train/val split.

    Steps:
      1. Read the CSVs at ANLI_R1_CSV / ANLI_R2_CSV / ANLI_R3_CSV and check
         each for the columns in ``_REQUIRED_COLS``.
      2. Concatenate them.
      3. Split into train/val with ratio VAL_RATIO, stratified on
         ``true_label_id``.
      4. Print per-split label and error distributions.
      5. Write the splits to OUT_TRAIN / OUT_VAL (UTF-8, no index column),
         overwriting any existing files.

    Raises:
        ValueError: if any input file lacks a required column.
    """
    print("============================================================")
    print("BUILD GLOBAL ANLI ERROR BUFFER (ROUND 1 → SRL SOURCE)")
    print("============================================================")

    # 1) Load and validate the three ANLI error CSVs
    print("\nLoading ANLI error buffers...")
    frames = [
        _load_buffer(label, path)
        for label, path in (("R1", ANLI_R1_CSV), ("R2", ANLI_R2_CSV), ("R3", ANLI_R3_CSV))
    ]

    # 2) Concatenate
    df_all = pd.concat(frames, ignore_index=True)
    print(f"\nTotal ANLI rows (R1+R2+R3): {len(df_all)}")

    # 3) Shuffle + split into train/val.
    # train_test_split(shuffle=True) already shuffles with the same seed, so
    # this pre-shuffle is redundant. It is kept on purpose: removing it would
    # change which rows land in each split compared with earlier runs.
    df_all = df_all.sample(frac=1.0, random_state=RANDOM_SEED).reset_index(drop=True)

    train_df, val_df = train_test_split(
        df_all,
        test_size=VAL_RATIO,
        random_state=RANDOM_SEED,
        shuffle=True,
        stratify=df_all["true_label_id"],  # keep class balance across splits
    )

    print(f"\nTrain size: {len(train_df)}")
    print(f"Val   size: {len(val_df)}")

    # 4) Show distributions
    _show_dist("TRAIN", train_df)
    _show_dist("VAL", val_df)

    # 5) Save
    train_df.to_csv(OUT_TRAIN, index=False, encoding="utf-8")
    val_df.to_csv(OUT_VAL, index=False, encoding="utf-8")

    print("\nSaved:")
    print(f"  Train: {OUT_TRAIN}")
    print(f"  Val  : {OUT_VAL}")
    print("\n✅ Global ANLI error buffers are ready for SRL.")
    print("Use them as input to the SRL buffer rebalance script.")


if __name__ == "__main__":
    main()