samerzaher80 committed
Commit 1a6e63a · verified · 1 Parent(s): af05e59

Upload 4 files

analyze_anli_errors_round1.py ADDED
@@ -0,0 +1,203 @@
+ import os
+ import csv
+
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ from torch.nn.functional import softmax
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+ # ============================
+ # CONFIG
+ # ============================
+
+ # 🔴 Target model you want to improve
+ MODEL_PATH = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\models\student_biomed_kd_fast\adni_srl_round13_smart"
+
+ # Where to save ANLI error buffers
+ OUTPUT_DIR = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\analysis"
+
+ BATCH_SIZE = 32
+ MAX_LENGTH = 192
+
+ LABEL_ID2NAME = {
+     0: "entailment",
+     1: "neutral",
+     2: "contradiction",
+ }
+
+
+ # ============================
+ # HELPERS
+ # ============================
+
+ def ensure_output_dir(path: str):
+     os.makedirs(path, exist_ok=True)
+
+
+ def to_label_name(label_id: int) -> str:
+     return LABEL_ID2NAME.get(int(label_id), f"label_{label_id}")
+
+
+ def compute_error_type(true_id: int, pred_id: int) -> str:
+     if true_id == pred_id:
+         return "correct"
+     # Example: E->N, N->C, C->E
+     return f"{to_label_name(true_id)[0].upper()}->{to_label_name(pred_id)[0].upper()}"
+
+
+ def run_model_on_dataset(model, tokenizer, data, split_name: str, round_tag: str,
+                          device: torch.device, output_dir: str):
+     """
+     Run the model on a dataset (list of dicts with keys: 'premise', 'hypothesis', 'label', 'id').
+     Save a CSV with detailed per-example info.
+     """
+     rows = []
+
+     print(f"\n=== Processing ANLI {split_name} ({len(data)} examples) ===")
+
+     model.eval()
+     with torch.no_grad():
+         for idx in tqdm(range(0, len(data), BATCH_SIZE), desc=f"{split_name} batches"):
+             batch_examples = data[idx:idx + BATCH_SIZE]
+
+             premises = [ex["premise"] for ex in batch_examples]
+             hypotheses = [ex["hypothesis"] for ex in batch_examples]
+             labels = [int(ex["label"]) for ex in batch_examples]
+
+             enc = tokenizer(
+                 premises,
+                 hypotheses,
+                 padding=True,
+                 truncation=True,
+                 max_length=MAX_LENGTH,
+                 return_tensors="pt",
+             )
+
+             input_ids = enc["input_ids"].to(device)
+             attention_mask = enc["attention_mask"].to(device)
+
+             outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+             logits = outputs.logits          # [B, 3]
+             probs = softmax(logits, dim=-1)  # [B, 3]
+
+             pred_ids = torch.argmax(probs, dim=-1).cpu().tolist()
+             probs_np = probs.cpu().tolist()
+
+             for i, ex in enumerate(batch_examples):
+                 true_id = int(labels[i])
+                 pred_id = int(pred_ids[i])
+                 prob_vec = probs_np[i]
+                 prob_true = float(prob_vec[true_id])
+
+                 is_error = int(true_id != pred_id)
+                 err_type = compute_error_type(true_id, pred_id)
+
+                 # Fall back to 'uid' (or the running index) when 'id' is missing or None.
+                 ex_id = ex.get("id") or ex.get("uid") or (idx + i)
+
+                 rows.append({
+                     "id": ex_id,
+                     "premise": ex["premise"],
+                     "hypothesis": ex["hypothesis"],
+                     "true_label_id": true_id,
+                     "true_label": to_label_name(true_id),
+                     "pred_label_id": pred_id,
+                     "pred_label": to_label_name(pred_id),
+                     "is_error": is_error,
+                     "error_type": err_type,
+                     # NOTE: despite the "logit_" names, these are softmax probabilities.
+                     "logit_entailment": float(prob_vec[0]),
+                     "logit_neutral": float(prob_vec[1]),
+                     "logit_contradiction": float(prob_vec[2]),
+                     "conf_true_label": prob_true,
+                     "difficulty": 1.0 - prob_true,
+                 })
+
+     ensure_output_dir(output_dir)
+     out_path = os.path.join(output_dir, f"anli_error_buffer_{split_name}_{round_tag}.csv")
+
+     fieldnames = [
+         "id",
+         "premise",
+         "hypothesis",
+         "true_label_id",
+         "true_label",
+         "pred_label_id",
+         "pred_label",
+         "is_error",
+         "error_type",
+         "logit_entailment",
+         "logit_neutral",
+         "logit_contradiction",
+         "conf_true_label",
+         "difficulty",
+     ]
+
+     with open(out_path, "w", encoding="utf-8", newline="") as f:
+         writer = csv.DictWriter(f, fieldnames=fieldnames)
+         writer.writeheader()
+         for row in rows:
+             writer.writerow(row)
+
+     total = len(rows)
+     errors = sum(r["is_error"] for r in rows)
+     acc = 100.0 * (total - errors) / max(1, total)
+
+     print(f"Saved {total} rows to: {out_path}")
+     print(f"{split_name} accuracy (recomputed here): {acc:.2f}% (errors={errors})")
+
+
+ # ============================
+ # MAIN
+ # ============================
+
+ def main():
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {device}")
+
+     print(f"\nLoading tokenizer and model from: {MODEL_PATH}")
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+     model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
+     model.to(device)
+
+     # ANLI splits: dev_r1, dev_r2, dev_r3
+     anli_splits = {
+         "anli_r1_dev": "dev_r1",
+         "anli_r2_dev": "dev_r2",
+         "anli_r3_dev": "dev_r3",
+     }
+
+     for split_name, hf_split in anli_splits.items():
+         print(f"\nLoading ANLI split: {hf_split}")
+         ds = load_dataset("anli", split=hf_split)
+
+         # Filter out unlabeled (-1) if present and map into a simple list of dicts
+         data = []
+         for ex in ds:
+             label = int(ex["label"])
+             if label < 0:
+                 continue
+             data.append({
+                 "id": ex.get("uid", None),
+                 "premise": ex["premise"],
+                 "hypothesis": ex["hypothesis"],
+                 "label": label,
+             })
+
+         print(f"{split_name}: {len(data)} labeled examples")
+
+         run_model_on_dataset(
+             model=model,
+             tokenizer=tokenizer,
+             data=data,
+             split_name=split_name,  # will appear in the output filename
+             round_tag="round14",    # matches the round14 buffers consumed downstream
+             device=device,
+             output_dir=OUTPUT_DIR,
+         )
+
+     print("\nAll done. ANLI error buffers are ready for SRL fine-tuning.")
+
+
+ if __name__ == "__main__":
+     main()
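
A quick sanity check on the buffers this script writes (a minimal sketch; the directory and filename follow the OUTPUT_DIR and round14 tag configured above, and pandas is assumed to be available):

import os
import pandas as pd

ANALYSIS_DIR = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\analysis"
buf = pd.read_csv(os.path.join(ANALYSIS_DIR, "anli_error_buffer_anli_r1_dev_round14.csv"))

# Errors only, hardest first (difficulty = 1 - probability of the true label).
errors = buf[buf["is_error"] == 1].sort_values("difficulty", ascending=False)
print(errors["error_type"].value_counts())  # e.g. counts of N->C, E->N, ...
print(errors[["premise", "hypothesis", "true_label", "pred_label"]].head())
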
build_anli_global_error_buffer_round1.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ import pandas as pd
+ from sklearn.model_selection import train_test_split
+
+ # ============================
+ # INPUT FILES (already created)
+ # ============================
+
+ BASE_ANALYSIS_DIR = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\analysis"
+
+ ANLI_R1_CSV = os.path.join(BASE_ANALYSIS_DIR, "anli_error_buffer_anli_r1_dev_round14.csv")
+ ANLI_R2_CSV = os.path.join(BASE_ANALYSIS_DIR, "anli_error_buffer_anli_r2_dev_round14.csv")
+ ANLI_R3_CSV = os.path.join(BASE_ANALYSIS_DIR, "anli_error_buffer_anli_r3_dev_round14.csv")
+
+ # ============================
+ # OUTPUT FILES (global ANLI buffer)
+ # ============================
+
+ OUT_TRAIN = os.path.join(BASE_ANALYSIS_DIR, "global_error_buffer_anli_round14_train.csv")
+ OUT_VAL = os.path.join(BASE_ANALYSIS_DIR, "global_error_buffer_anli_round14_val.csv")
+
+ RANDOM_SEED = 42
+ VAL_RATIO = 0.20  # 80% train / 20% val
+
+
+ def main():
+     print("============================================================")
+     print("BUILD GLOBAL ANLI ERROR BUFFER (ROUND 1 → SRL SOURCE)")
+     print("============================================================")
+
+     # 1) Load the three ANLI error CSVs
+     print("\nLoading ANLI error buffers...")
+     df_r1 = pd.read_csv(ANLI_R1_CSV)
+     df_r2 = pd.read_csv(ANLI_R2_CSV)
+     df_r3 = pd.read_csv(ANLI_R3_CSV)
+
+     print(f" R1 rows: {len(df_r1)}")
+     print(f" R2 rows: {len(df_r2)}")
+     print(f" R3 rows: {len(df_r3)}")
+
+     # 2) Concatenate
+     df_all = pd.concat([df_r1, df_r2, df_r3], ignore_index=True)
+     print(f"\nTotal ANLI rows (R1+R2+R3): {len(df_all)}")
+
+     # Sanity: required columns for SRL pipeline
+     required_cols = ["premise", "hypothesis", "true_label_id", "is_error"]
+     missing = [c for c in required_cols if c not in df_all.columns]
+     if missing:
+         raise ValueError(f"Missing required columns in ANLI buffers: {missing}")
+
+     # 3) Shuffle + split into train/val
+     df_all = df_all.sample(frac=1.0, random_state=RANDOM_SEED).reset_index(drop=True)
+
+     train_df, val_df = train_test_split(
+         df_all,
+         test_size=VAL_RATIO,
+         random_state=RANDOM_SEED,
+         shuffle=True,
+         stratify=df_all["true_label_id"],  # keep class balance
+     )
+
+     print(f"\nTrain size: {len(train_df)}")
+     print(f"Val size: {len(val_df)}")
+
+     # 4) Show distributions
+     def show_dist(name, df):
+         print(f"\n{name} - class distribution:")
+         total = len(df)
+         for label_id, label_name in {0: "entailment", 1: "neutral", 2: "contradiction"}.items():
+             count = (df["true_label_id"] == label_id).sum()
+             print(f" {label_name}: {count} ({100.0 * count / total:.1f}%)")
+
+         errors = df["is_error"].sum()
+         print(f"{name} - errors: {errors} ({100.0 * errors / total:.1f}%), correct: {total - errors}")
+
+     show_dist("TRAIN", train_df)
+     show_dist("VAL", val_df)
+
+     # 5) Save
+     train_df.to_csv(OUT_TRAIN, index=False, encoding="utf-8")
+     val_df.to_csv(OUT_VAL, index=False, encoding="utf-8")
+
+     print("\nSaved:")
+     print(f" Train: {OUT_TRAIN}")
+     print(f" Val : {OUT_VAL}")
+     print("\n✅ Global ANLI error buffers are ready for SRL.")
+     print("Use them as input to the SRL buffer rebalance script.")
+
+
+ if __name__ == "__main__":
+     main()
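
To confirm the stratified split behaved as intended, the saved CSVs can be checked directly (a minimal sketch reusing the output paths defined above):

import os
import pandas as pd

BASE = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\analysis"
train = pd.read_csv(os.path.join(BASE, "global_error_buffer_anli_round14_train.csv"))
val = pd.read_csv(os.path.join(BASE, "global_error_buffer_anli_round14_val.csv"))

# Stratification should make the per-class shares nearly identical across splits.
print(train["true_label_id"].value_counts(normalize=True).sort_index())
print(val["true_label_id"].value_counts(normalize=True).sort_index())
print(f"val ratio: {len(val) / (len(train) + len(val)):.2f}")  # ~0.20
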
evaluate_model_hf_only (2).py ADDED
@@ -0,0 +1,403 @@
+ import os
+ import json
+ from datetime import datetime
+ from pathlib import Path
+
+ import torch
+ import numpy as np
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
+ from tqdm.auto import tqdm
+
+ # ============================
+ # CONFIG
+ # ============================
+
+ MODEL_PATH = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\models\student_biomed_kd_fast\adni_srl_round14_smart"
+ OUTPUT_DIR = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\evaluation_results\adni_srl_round2_fixed"
+
+ BATCH_SIZE = 64
+ MAX_LENGTH = 192
+
+ # HuggingFace datasets: (display name, HF dataset id, split, config)
+ DATASETS_CONFIG = [
+     ("SNLI", "snli", "test", None),
+     ("MNLI M", "nyu-mll/multi_nli", "validation_matched", None),
+     ("MNLI MM", "nyu-mll/multi_nli", "validation_mismatched", None),
+     ("ANLI R1", "facebook/anli", "test_r1", None),
+     ("ANLI R2", "facebook/anli", "test_r2", None),
+     ("ANLI R3", "facebook/anli", "test_r3", None),
+     ("XNLI", "facebook/xnli", "validation", "en"),
+ ]
+
+ # Local ADNI NLI JSON files
+ ADNI_DATASETS = [
+     ("ADNI Train", r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\splits\adni_nli_train.json"),
+     ("ADNI Val", r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\splits\adni_nli_val.json"),
+     ("ADNI Test", r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\splits\adni_nli_test.json"),
+ ]
+
+ LABEL_NAMES = ["entailment", "neutral", "contradiction"]
+
+
+ # ============================
+ # HELPER FUNCTIONS
+ # ============================
+
+ def load_model_and_tokenizer(model_path: str, device: str):
+     print(f"\n{'='*60}")
+     print("Loading Model and Tokenizer")
+     print(f"{'='*60}")
+     print(f"Model: {model_path}")
+
+     tokenizer = AutoTokenizer.from_pretrained(model_path)
+     model = AutoModelForSequenceClassification.from_pretrained(model_path)
+     model.to(device)
+     model.eval()
+
+     print(f"Device: {device}")
+     print("Model loaded successfully!")
+
+     return tokenizer, model
+
+
+ def compute_metrics_from_predictions(name, labels, preds):
+     accuracy = accuracy_score(labels, preds)
+     precision, recall, f1, support = precision_recall_fscore_support(
+         labels, preds, average=None, labels=[0, 1, 2], zero_division=0
+     )
+     macro_precision = float(np.mean(precision))
+     macro_recall = float(np.mean(recall))
+     macro_f1 = float(np.mean(f1))
+     conf_matrix = confusion_matrix(labels, preds, labels=[0, 1, 2])
+
+     print(f"\n{'='*60}")
+     print(f"RESULTS: {name}")
+     print(f"{'='*60}")
+     print(f"Samples: {len(labels)}")
+     print(f"Accuracy: {accuracy*100:.2f}%")
+     print(f"Macro F1: {macro_f1*100:.2f}%")
+     print("\nPer-Class Performance:")
+     for i, label_name in enumerate(LABEL_NAMES):
+         print(
+             f" {label_name.upper():13} "
+             f"P: {precision[i]*100:.2f}% "
+             f"R: {recall[i]*100:.2f}% "
+             f"F1: {f1[i]*100:.2f}% (n={support[i]})"
+         )
+
+     result = {
+         "dataset": name,
+         "accuracy": float(accuracy),
+         "macro_precision": macro_precision,
+         "macro_recall": macro_recall,
+         "macro_f1": macro_f1,
+         "per_class": {
+             LABEL_NAMES[i]: {
+                 "precision": float(precision[i]),
+                 "recall": float(recall[i]),
+                 "f1": float(f1[i]),
+                 "support": int(support[i]),
+             }
+             for i in range(3)
+         },
+         "confusion_matrix": conf_matrix.tolist(),
+         "total_samples": len(labels),
+     }
+     return result
+
+
+ def evaluate_dataset(
+     name: str,
+     hf_name: str,
+     split: str,
+     config: str,
+     tokenizer,
+     model,
+     device: str,
+     batch_size: int,
+     max_length: int,
+ ):
+     print(f"\n{'='*60}")
+     print(f"Loading {name} Dataset")
+     print(f"{'='*60}")
+
+     if config:
+         dataset = load_dataset(hf_name, config, split=split, trust_remote_code=False)
+     else:
+         dataset = load_dataset(hf_name, split=split, trust_remote_code=False)
+
+     # Drop unlabeled examples (label == -1), as in SNLI/MNLI test residue
+     if "label" in dataset.column_names:
+         dataset = dataset.filter(lambda ex: ex["label"] != -1)
+
+     print(f"✅ Loaded {len(dataset)} valid examples")
+
+     premises = [str(ex["premise"]) for ex in dataset]
+     hypotheses = [str(ex["hypothesis"]) for ex in dataset]
+     labels = [int(ex["label"]) for ex in dataset]
+
+     label_counts = {0: 0, 1: 0, 2: 0}
+     for lab in labels:
+         label_counts[lab] = label_counts.get(lab, 0) + 1
+     print(f"Label distribution: {label_counts}")
+
+     print(f"\n{'='*60}")
+     print(f"Evaluating: {name}")
+     print(f"{'='*60}")
+
+     all_preds = []
+     num_batches = (len(labels) + batch_size - 1) // batch_size
+
+     with torch.no_grad():
+         for i in tqdm(range(0, len(labels), batch_size), total=num_batches, desc=name):
+             batch_premises = premises[i:i + batch_size]
+             batch_hypotheses = hypotheses[i:i + batch_size]
+
+             encodings = tokenizer(
+                 batch_premises,
+                 batch_hypotheses,
+                 padding=True,
+                 truncation=True,
+                 max_length=max_length,
+                 return_tensors="pt",
+             ).to(device)
+
+             outputs = model(**encodings)
+             preds = torch.argmax(outputs.logits, dim=-1).cpu().tolist()
+             all_preds.extend(preds)
+
+     return compute_metrics_from_predictions(name, labels, all_preds)
+
+
+ def extract_label(rec):
+     """
+     Robustly extract a label as int 0/1/2 from a JSON record.
+     Handles:
+       - rec['label'] as int or string
+       - rec['true_label_id'] as int
+       - rec['gold_label'] as string
+     """
+     mapping = {
+         "entailment": 0,
+         "e": 0,
+         "neutral": 1,
+         "n": 1,
+         "contradiction": 2,
+         "c": 2,
+     }
+
+     if "label" in rec:
+         v = rec["label"]
+         if isinstance(v, int):
+             return v
+         v_str = str(v).strip().lower()
+         if v_str in mapping:
+             return mapping[v_str]
+         raise ValueError(f"Unknown string label in 'label': {v}")
+
+     if "true_label_id" in rec:
+         return int(rec["true_label_id"])
+
+     if "gold_label" in rec:
+         v_str = str(rec["gold_label"]).strip().lower()
+         if v_str in mapping:
+             return mapping[v_str]
+         raise ValueError(f"Unknown string label in 'gold_label': {rec['gold_label']}")
+
+     raise ValueError(f"Could not extract label from record keys: {list(rec.keys())}")
+
+
+ def evaluate_local_json_dataset(
+     name: str,
+     json_path: str,
+     tokenizer,
+     model,
+     device: str,
+     batch_size: int,
+     max_length: int,
+ ):
+     print(f"\n{'='*60}")
+     print(f"Loading {name} (local JSON)")
+     print(f"{'='*60}")
+     print(f"Path: {json_path}")
+
+     if not os.path.exists(json_path):
+         raise FileNotFoundError(f"JSON file not found: {json_path}")
+
+     with open(json_path, "r", encoding="utf-8") as f:
+         data = json.load(f)
+
+     # Accept either a bare list of records or a {"data": [...]} wrapper
+     if isinstance(data, dict) and "data" in data:
+         records = data["data"]
+     else:
+         records = data
+
+     premises = []
+     hypotheses = []
+     labels = []
+
+     for rec in records:
+         premise = rec.get("premise")
+         hypothesis = rec.get("hypothesis")
+         if premise is None or hypothesis is None:
+             raise ValueError("Expected 'premise' and 'hypothesis' keys in ADNI JSON records.")
+         label = extract_label(rec)
+
+         if label == -1:
+             continue
+
+         premises.append(str(premise))
+         hypotheses.append(str(hypothesis))
+         labels.append(int(label))
+
+     print(f"✅ Loaded {len(labels)} valid examples")
+
+     label_counts = {0: 0, 1: 0, 2: 0}
+     for lab in labels:
+         label_counts[lab] = label_counts.get(lab, 0) + 1
+     print(f"Label distribution: {label_counts}")
+
+     print(f"\n{'='*60}")
+     print(f"Evaluating: {name}")
+     print(f"{'='*60}")
+
+     all_preds = []
+     num_batches = (len(labels) + batch_size - 1) // batch_size
+
+     with torch.no_grad():
+         for i in tqdm(range(0, len(labels), batch_size), total=num_batches, desc=name):
+             batch_premises = premises[i:i + batch_size]
+             batch_hypotheses = hypotheses[i:i + batch_size]
+
+             encodings = tokenizer(
+                 batch_premises,
+                 batch_hypotheses,
+                 padding=True,
+                 truncation=True,
+                 max_length=max_length,
+                 return_tensors="pt",
+             ).to(device)
+
+             outputs = model(**encodings)
+             preds = torch.argmax(outputs.logits, dim=-1).cpu().tolist()
+             all_preds.extend(preds)
+
+     return compute_metrics_from_predictions(name, labels, all_preds)
+
+
+ def save_results(results: list, output_dir: str, model_path: str):
+     os.makedirs(output_dir, exist_ok=True)
+
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     model_name = Path(model_path).name
+
+     json_path = os.path.join(output_dir, f"results_{model_name}_{timestamp}.json")
+     with open(json_path, "w", encoding="utf-8") as f:
+         json.dump(results, f, indent=2)
+
+     summary_path = os.path.join(output_dir, f"summary_{model_name}_{timestamp}.txt")
+     with open(summary_path, "w", encoding="utf-8") as f:
+         f.write("="*80 + "\n")
+         f.write("COMPREHENSIVE NLI MODEL EVALUATION SUMMARY\n")
+         f.write("="*80 + "\n")
+         f.write(f"Model: {model_path}\n")
+         f.write(f"Timestamp: {timestamp}\n")
+         f.write("="*80 + "\n\n")
+
+         for result in results:
+             f.write(f"{result['dataset']}\n")
+             f.write("-" * 40 + "\n")
+             f.write(f"Accuracy: {result['accuracy']*100:.2f}%\n")
+             f.write(f"Macro F1: {result['macro_f1']*100:.2f}%\n")
+             f.write(f"Samples: {result['total_samples']}\n")
+             f.write("\n")
+
+         f.write("\n" + "="*80 + "\n")
+         f.write("OVERALL STATISTICS\n")
+         f.write("="*80 + "\n")
+
+         avg_accuracy = np.mean([r['accuracy'] for r in results])
+         avg_f1 = np.mean([r['macro_f1'] for r in results])
+
+         f.write(f"Average Accuracy: {avg_accuracy*100:.2f}%\n")
+         f.write(f"Average Macro F1: {avg_f1*100:.2f}%\n")
+
+     print("\n✅ Results saved:")
+     print(f" JSON: {json_path}")
+     print(f" Summary: {summary_path}")
+
+     return json_path, summary_path
+
+
+ # ============================
+ # MAIN
+ # ============================
+
+ def main():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     print("="*80)
+     print("COMPREHENSIVE NLI MODEL EVALUATION")
+     print("="*80)
+     print(f"Model: {MODEL_PATH}")
+     all_names = [d[0] for d in DATASETS_CONFIG] + [d[0] for d in ADNI_DATASETS]
+     print(f"Datasets: {', '.join(all_names)}")
+     print("="*80)
+
+     tokenizer, model = load_model_and_tokenizer(MODEL_PATH, device)
+
+     all_results = []
+
+     for name, hf_name, split, config in DATASETS_CONFIG:
+         result = evaluate_dataset(
+             name=name,
+             hf_name=hf_name,
+             split=split,
+             config=config,
+             tokenizer=tokenizer,
+             model=model,
+             device=device,
+             batch_size=BATCH_SIZE,
+             max_length=MAX_LENGTH,
+         )
+         all_results.append(result)
+
+     for name, path in ADNI_DATASETS:
+         result = evaluate_local_json_dataset(
+             name=name,
+             json_path=path,
+             tokenizer=tokenizer,
+             model=model,
+             device=device,
+             batch_size=BATCH_SIZE,
+             max_length=MAX_LENGTH,
+         )
+         all_results.append(result)
+
+     save_results(all_results, OUTPUT_DIR, MODEL_PATH)
+
+     print(f"\n{'='*80}")
+     print("EVALUATION COMPLETE - FINAL SUMMARY")
+     print(f"{'='*80}\n")
+
+     print(f"{'Dataset':<15} {'Accuracy':<12} {'Macro F1':<12} {'Samples':<10}")
+     print("-" * 50)
+
+     for result in all_results:
+         print(
+             f"{result['dataset']:<15} "
+             f"{result['accuracy']*100:>6.2f}% "
+             f"{result['macro_f1']*100:>6.2f}% "
+             f"{result['total_samples']:>6}"
+         )
+
+     print("-" * 50)
+     avg_accuracy = np.mean([r['accuracy'] for r in all_results])
+     avg_f1 = np.mean([r['macro_f1'] for r in all_results])
+     print(f"{'AVERAGE':<15} {avg_accuracy*100:>6.2f}% {avg_f1*100:>6.2f}%")
+     print("="*80)
+
+
+ if __name__ == "__main__":
+     main()
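
The per-run JSON that save_results writes can be reloaded later for comparison across rounds (a minimal sketch; the glob pattern matches the results_{model_name}_{timestamp}.json naming above and assumes at least one completed run):

import glob
import json
import os

OUTPUT_DIR = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\evaluation_results\adni_srl_round2_fixed"

# Pick the most recently written results file.
latest = max(glob.glob(os.path.join(OUTPUT_DIR, "results_*.json")), key=os.path.getmtime)
with open(latest, encoding="utf-8") as f:
    results = json.load(f)

for r in results:
    print(f"{r['dataset']:<15} acc={r['accuracy']*100:.2f}%  macro_f1={r['macro_f1']*100:.2f}%")
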
srl_finetune_round5_smart.py ADDED
@@ -0,0 +1,304 @@
+ """
+ SRL Round 5 - Smart ANLI Fine-tune (Small, Safe Correction)
+
+ - Base model: best global checkpoint (adni_srl_round3_final)
+ - Data: smart ANLI SRL buffer (60% errors / 40% correct, pattern-tagged)
+ - Goal: improve ANLI robustness on real failure patterns without hurting SNLI/MNLI/XNLI
+ """
+
+ import os
+ os.environ["WANDB_DISABLED"] = "true"
+
+ from dataclasses import dataclass
+ from typing import Dict, List, Union
+
+ import pandas as pd
+ import numpy as np
+ import torch
+ from torch import nn
+
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForSequenceClassification,
+     Trainer,
+     TrainingArguments,
+ )
+
+ # =============================================================================
+ # CONFIG
+ # =============================================================================
+
+ # Best model so far (global NLI + ADNI)
+ BASE_MODEL_PATH = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\models\student_biomed_kd_fast\adni_srl_round13_smart"
+
+ # Smart SRL buffers (ANLI Round 1 error patterns)
+ SMART_TRAIN_CSV = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\analysis\global_error_buffer_anli_round14_train.csv"
+ SMART_VAL_CSV = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\analysis\global_error_buffer_anli_round14_val.csv"
+
+ # Output directory for the new model
+ OUTPUT_DIR = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\models\student_biomed_kd_fast\adni_srl_round14_smart"
+
+ # Max sequence length
+ MAX_LENGTH = 192
+
+ # Training hyper-parameters (small, safe SRL step)
+ NUM_EPOCHS = 1
+ BATCH_SIZE = 16
+ LEARNING_RATE = 2e-6
+
+ # Class weights (E, N, C) – mild bias towards entailment and contradiction
+ CLASS_WEIGHTS = torch.tensor([1.5, 1.0, 1.3], dtype=torch.float32)
+
+ # Error vs correct weighting
+ ERROR_WEIGHT = 2.0  # errors * 2.0, correct * 1.0
+
+ # Seed
+ SEED = 42
+
+
+ # =============================================================================
+ # DATASET
+ # =============================================================================
+
+ class NLIDataset(torch.utils.data.Dataset):
+     def __init__(self, df: pd.DataFrame, tokenizer: AutoTokenizer, max_length: int = 128):
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+
+         # Expect columns: premise, hypothesis, true_label_id, is_error
+         premises = df["premise"].astype(str).tolist()
+         hypotheses = df["hypothesis"].astype(str).tolist()
+         labels = df["true_label_id"].astype(int).tolist()
+
+         is_error = df["is_error"].astype(int).tolist()
+         error_weights = [ERROR_WEIGHT if e == 1 else 1.0 for e in is_error]
+
+         encodings = tokenizer(
+             premises,
+             hypotheses,
+             truncation=True,
+             padding="max_length",
+             max_length=max_length,
+         )
+
+         self.input_ids = torch.tensor(encodings["input_ids"], dtype=torch.long)
+         self.attention_mask = torch.tensor(encodings["attention_mask"], dtype=torch.long)
+         if "token_type_ids" in encodings:
+             self.token_type_ids = torch.tensor(encodings["token_type_ids"], dtype=torch.long)
+         else:
+             self.token_type_ids = None
+
+         self.labels = torch.tensor(labels, dtype=torch.long)
+         self.error_weights = torch.tensor(error_weights, dtype=torch.float32)
+
+     def __len__(self):
+         return self.labels.size(0)
+
+     def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+         item = {
+             "input_ids": self.input_ids[idx],
+             "attention_mask": self.attention_mask[idx],
+             "labels": self.labels[idx],
+             "error_weight": self.error_weights[idx],
+         }
+         if self.token_type_ids is not None:
+             item["token_type_ids"] = self.token_type_ids[idx]
+         return item
+
+
+ @dataclass
+ class DataCollatorWithWeights:
+     """
+     Simple collator: all sequences are already padded to max_length.
+     Just stacks tensors and keeps error_weight.
+     """
+     def __call__(self, features: List[Dict[str, Union[torch.Tensor, int, float]]]) -> Dict[str, torch.Tensor]:
+         batch: Dict[str, torch.Tensor] = {}
+         keys = features[0].keys()
+         for key in keys:
+             batch[key] = torch.stack([f[key] for f in features])
+         return batch
+
+
+ # =============================================================================
+ # TRAINER WITH CLASS + ERROR WEIGHTED LOSS
+ # =============================================================================
+
+ class ClassAndErrorWeightedTrainer(Trainer):
+     def __init__(self, *args, class_weights: torch.Tensor = None, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.class_weights = class_weights
+
+     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+         labels = inputs.pop("labels")
+         error_weight = inputs.pop("error_weight", None)
+
+         outputs = model(**inputs)
+         logits = outputs.logits
+
+         # Move class weights to the correct device
+         cw = self.class_weights.to(logits.device) if self.class_weights is not None else None
+
+         loss_fct = nn.CrossEntropyLoss(weight=cw, reduction="none")
+         per_sample_loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+
+         if error_weight is not None:
+             ew = error_weight.to(per_sample_loss.device).view(-1)
+             per_sample_loss = per_sample_loss * ew
+
+         loss = per_sample_loss.mean()
+
+         if return_outputs:
+             return loss, outputs
+         return loss
+
+
+ # =============================================================================
+ # METRICS
+ # =============================================================================
+
+ def compute_metrics(eval_pred):
+     logits, labels = eval_pred
+     preds = np.argmax(logits, axis=-1)
+
+     labels = labels.astype(int)
+     preds = preds.astype(int)
+
+     acc = (preds == labels).mean()
+
+     # Per-class metrics
+     num_classes = 3
+     f1s = []
+     recalls = []
+     for cls in range(num_classes):
+         tp = np.logical_and(preds == cls, labels == cls).sum()
+         fp = np.logical_and(preds == cls, labels != cls).sum()
+         fn = np.logical_and(preds != cls, labels == cls).sum()
+
+         prec = tp / (tp + fp + 1e-8)
+         rec = tp / (tp + fn + 1e-8)
+         f1 = 2 * prec * rec / (prec + rec + 1e-8)
+
+         f1s.append(f1)
+         recalls.append(rec)
+
+     macro_f1 = float(np.mean(f1s))
+
+     return {
+         "accuracy": float(acc),
+         "macro_f1": macro_f1,
+         "entailment_recall": float(recalls[0]),
+         "neutral_recall": float(recalls[1]),
+         "contradiction_recall": float(recalls[2]),
+     }
+
+
+ # =============================================================================
+ # MAIN
+ # =============================================================================
+
+ def main():
+     torch.manual_seed(SEED)
+     np.random.seed(SEED)
+
+     print("=" * 80)
+     print("SRL ROUND 5 - SMART ANLI FINE-TUNE")
+     print("=" * 80)
+     print(f"Base model : {BASE_MODEL_PATH}")
+     print(f"Train CSV (SRL) : {SMART_TRAIN_CSV}")
+     print(f"Val CSV (SRL) : {SMART_VAL_CSV}")
+     print(f"Output dir : {OUTPUT_DIR}")
+     print("=" * 80)
+
+     # ---------------------------------------------------------
+     # Load tokenizer + model
+     # ---------------------------------------------------------
+     tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
+     model = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL_PATH)
+
+     # ---------------------------------------------------------
+     # Load SRL buffers
+     # ---------------------------------------------------------
+     train_df = pd.read_csv(SMART_TRAIN_CSV)
+     val_df = pd.read_csv(SMART_VAL_CSV)
+
+     print("\nSMART SRL TRAIN BUFFER")
+     print("----------------------")
+     print(f"Rows: {len(train_df)}")
+     print(train_df["true_label_id"].value_counts(normalize=True).sort_index())
+
+     print("\nSMART SRL VAL BUFFER")
+     print("--------------------")
+     print(f"Rows: {len(val_df)}")
+     print(val_df["true_label_id"].value_counts(normalize=True).sort_index())
+
+     # ---------------------------------------------------------
+     # Build datasets
+     # ---------------------------------------------------------
+     train_dataset = NLIDataset(train_df, tokenizer, max_length=MAX_LENGTH)
+     val_dataset = NLIDataset(val_df, tokenizer, max_length=MAX_LENGTH)
+
+     # ---------------------------------------------------------
+     # Training args (SMALL, SAFE)
+     # ---------------------------------------------------------
+     training_args = TrainingArguments(
+         output_dir=OUTPUT_DIR,
+         overwrite_output_dir=True,
+         num_train_epochs=NUM_EPOCHS,
+         per_device_train_batch_size=BATCH_SIZE,
+         per_device_eval_batch_size=BATCH_SIZE,
+         learning_rate=LEARNING_RATE,
+         weight_decay=0.01,
+         logging_steps=50,
+         eval_strategy="epoch",  # same as your other SRL scripts
+         save_strategy="epoch",
+         save_total_limit=2,
+         load_best_model_at_end=True,
+         metric_for_best_model="macro_f1",
+         remove_unused_columns=False,
+         report_to=[],
+     )
+
+     data_collator = DataCollatorWithWeights()
+
+     trainer = ClassAndErrorWeightedTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=val_dataset,
+         data_collator=data_collator,
+         tokenizer=tokenizer,
+         compute_metrics=compute_metrics,
+         class_weights=CLASS_WEIGHTS,
+     )
+
+     # ---------------------------------------------------------
+     # Train
+     # ---------------------------------------------------------
+     print("\nStarting SRL Round 5 (smart ANLI fine-tune)...")
+     trainer.train()
+
+     print("\nFinal evaluation on SRL val buffer:")
+     metrics = trainer.evaluate()
+     for k, v in metrics.items():
+         print(f" {k}: {v:.4f}" if isinstance(v, float) else f" {k}: {v}")
+
+     # ---------------------------------------------------------
+     # Save
+     # ---------------------------------------------------------
+     print("\nSaving final SRL Round 5 model...")
+     trainer.save_model(OUTPUT_DIR)
+     tokenizer.save_pretrained(OUTPUT_DIR)
+
+     print("\n" + "=" * 80)
+     print("✅ SRL ROUND 5 SMART FINE-TUNE COMPLETE")
+     print("=" * 80)
+     print(f"Model saved to: {OUTPUT_DIR}")
+     print("Next: run evaluate_model_hf_only.py with this path as MODEL_PATH.")
+     print("=" * 80)
+
+
+ if __name__ == "__main__":
+     main()
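
The combined weighting in ClassAndErrorWeightedTrainer.compute_loss is easiest to see on a toy batch (a minimal sketch with made-up logits and labels; the class weights mirror CLASS_WEIGHTS above):

import torch
from torch import nn

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 0.3]])       # [B=2, 3] fake model outputs
labels = torch.tensor([0, 2])                  # true classes: E, C
error_weight = torch.tensor([1.0, 2.0])        # second example was a past error

class_weights = torch.tensor([1.5, 1.0, 1.3])  # same values as CLASS_WEIGHTS
loss_fct = nn.CrossEntropyLoss(weight=class_weights, reduction="none")

per_sample = loss_fct(logits, labels)          # class-weighted CE per example
weighted = per_sample * error_weight           # past errors count double
print(weighted.mean())                         # the scalar the trainer backprops

Note that the trainer takes a plain .mean() of the weighted per-sample losses rather than normalizing by the sum of weights, so the effective learning rate scales mildly with the error fraction of each batch.
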