# File size: 6,631 Bytes
# 1a6e63a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import os
import csv
from pathlib import Path

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.nn.functional import softmax
from datasets import load_dataset
from tqdm import tqdm

# ============================
# CONFIG
# ============================

# 🔴 Target model you want to improve
MODEL_PATH = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\models\student_biomed_kd_fast\adni_srl_round13_smart"

# Directory that receives the per-split ANLI error-buffer CSV files.
OUTPUT_DIR = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\analysis"

# Number of examples per forward pass and the tokenizer truncation length.
BATCH_SIZE = 32
MAX_LENGTH = 192

# Class index -> human-readable label name (index order: 0, 1, 2).
# NOTE(review): assumed to match the output order of the model's
# classification head -- confirm against the checkpoint's config.
LABEL_ID2NAME = dict(enumerate(("entailment", "neutral", "contradiction")))


# ============================
# HELPERS
# ============================

def ensure_output_dir(path: str):
    """Create directory `path`, including missing parents.

    Does nothing when the directory already exists.
    """
    os.makedirs(name=path, exist_ok=True)


def to_label_name(label_id: int) -> str:
    """Return the label name for `label_id`.

    The id is converted with ``int()`` and looked up in ``LABEL_ID2NAME``.
    Ids that are not in the table yield the placeholder ``"label_<label_id>"``.
    """
    key = int(label_id)
    if key in LABEL_ID2NAME:
        return LABEL_ID2NAME[key]
    return f"label_{label_id}"


def compute_error_type(true_id: int, pred_id: int) -> str:
    """Describe a prediction as ``"correct"`` or as a confusion tag.

    A confusion tag is built from the upper-cased first letter of the true
    label name and of the predicted label name, joined by ``->``.
    Examples: ``E->N``, ``N->C``, ``C->E``.
    """
    if true_id != pred_id:
        gold_initial = to_label_name(true_id)[0].upper()
        pred_initial = to_label_name(pred_id)[0].upper()
        return f"{gold_initial}->{pred_initial}"
    return "correct"


def run_model_on_dataset(model, tokenizer, data, split_name: str, round_tag: str,
                         device: torch.device, output_dir: str):
    """
    Run the model over `data` in batches and save per-example results as CSV.

    The output file is
    ``<output_dir>/anli_error_buffer_<split_name>_<round_tag>.csv`` and holds
    one row per example: texts, true/predicted label, error flag, error type,
    class probabilities, and a difficulty score (``1 - p(true label)``).

    Args:
        model: Classifier called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits``. It is switched to eval mode
            here. It is not moved to `device` by this function.
        tokenizer: Callable invoked on ``(premises, hypotheses)`` that returns
            a mapping with ``"input_ids"`` and ``"attention_mask"`` tensors.
        data: Sliceable sequence (e.g. a list) of dicts with the keys
            ``'premise'``, ``'hypothesis'``, ``'label'`` and optionally
            ``'id'`` / ``'uid'``.
        split_name: Name used in progress output and in the CSV filename.
        round_tag: Suffix used in the CSV filename.
        device: Device the input tensors are moved to before the forward pass.
        output_dir: Directory for the CSV; created if missing.

    Raises:
        ValueError: If an example has a negative label. A negative index
            would otherwise silently read a probability from the end of the
            vector.
    """
    # NOTE(review): the "logit_*" columns hold softmax probabilities, not raw
    # logits. The column names are kept unchanged so existing readers of these
    # CSV files keep working.
    fieldnames = [
        "id",
        "premise",
        "hypothesis",
        "true_label_id",
        "true_label",
        "pred_label_id",
        "pred_label",
        "is_error",
        "error_type",
        "logit_entailment",
        "logit_neutral",
        "logit_contradiction",
        "conf_true_label",
        "difficulty",
    ]
    rows = []

    print(f"\n=== Processing ANLI {split_name} ({len(data)} examples) ===")

    model.eval()
    with torch.no_grad():
        for idx in tqdm(range(0, len(data), BATCH_SIZE), desc=f"{split_name} batches"):
            batch_examples = data[idx:idx + BATCH_SIZE]

            premises = [ex["premise"] for ex in batch_examples]
            hypotheses = [ex["hypothesis"] for ex in batch_examples]
            labels = [int(ex["label"]) for ex in batch_examples]
            if any(label < 0 for label in labels):
                raise ValueError(
                    f"Negative label in batch starting at index {idx}; "
                    "filter unlabeled examples before calling run_model_on_dataset."
                )

            enc = tokenizer(
                premises,
                hypotheses,
                padding=True,
                truncation=True,
                max_length=MAX_LENGTH,
                return_tensors="pt",
            )

            input_ids = enc["input_ids"].to(device)
            attention_mask = enc["attention_mask"].to(device)

            # NOTE(review): only input_ids and attention_mask are forwarded.
            # If the tokenizer also returns token_type_ids they are dropped
            # here -- confirm this matches how the model was trained.
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            probs = softmax(outputs.logits, dim=-1)  # [B, num_classes]

            pred_ids = torch.argmax(probs, dim=-1).cpu().tolist()
            prob_rows = probs.cpu().tolist()

            for i, ex in enumerate(batch_examples):
                true_id = labels[i]
                pred_id = int(pred_ids[i])
                prob_vec = prob_rows[i]
                prob_true = float(prob_vec[true_id])

                # Fall back step by step: an "id" key that is present but
                # None must not hide the "uid" / positional fallbacks.
                ex_id = ex.get("id")
                if ex_id is None:
                    ex_id = ex.get("uid")
                if ex_id is None:
                    ex_id = idx + i

                rows.append({
                    "id": ex_id,
                    "premise": ex["premise"],
                    "hypothesis": ex["hypothesis"],
                    "true_label_id": true_id,
                    "true_label": to_label_name(true_id),
                    "pred_label_id": pred_id,
                    "pred_label": to_label_name(pred_id),
                    "is_error": int(true_id != pred_id),
                    "error_type": compute_error_type(true_id, pred_id),
                    "logit_entailment": float(prob_vec[0]),
                    "logit_neutral": float(prob_vec[1]),
                    "logit_contradiction": float(prob_vec[2]),
                    "conf_true_label": prob_true,
                    "difficulty": 1.0 - prob_true,
                })

    ensure_output_dir(output_dir)
    out_path = os.path.join(output_dir, f"anli_error_buffer_{split_name}_{round_tag}.csv")

    # newline="" lets the csv module control line endings itself.
    with open(out_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)

    total = len(rows)
    errors = sum(r["is_error"] for r in rows)
    # max(1, total) avoids a division by zero for an empty dataset.
    acc = 100.0 * (total - errors) / max(1, total)

    print(f"Saved {total} rows to: {out_path}")
    print(f"{split_name} accuracy (recomputed here): {acc:.2f}%  (errors={errors})")


# ============================
# MAIN
# ============================

def main():
    """Evaluate the model at MODEL_PATH on the three ANLI dev splits.

    Picks CUDA when available (otherwise CPU), loads tokenizer and model from
    MODEL_PATH, then for each of dev_r1 / dev_r2 / dev_r3 loads the split,
    drops examples with a negative label, and calls ``run_model_on_dataset``
    to write one CSV per split into OUTPUT_DIR.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Using device: {device}")

    print(f"\nLoading tokenizer and model from: {MODEL_PATH}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(device)

    # (name used in the output filename, Hugging Face split identifier)
    split_pairs = (
        ("anli_r1_dev", "dev_r1"),
        ("anli_r2_dev", "dev_r2"),
        ("anli_r3_dev", "dev_r3"),
    )

    for split_name, hf_split in split_pairs:
        print(f"\nLoading ANLI split: {hf_split}")
        ds = load_dataset("anli", split=hf_split)

        # Keep labeled examples only (label >= 0), reduced to plain dicts.
        data = [
            {
                "id": ex.get("uid", None),
                "premise": ex["premise"],
                "hypothesis": ex["hypothesis"],
                "label": int(ex["label"]),
            }
            for ex in ds
            if int(ex["label"]) >= 0
        ]

        print(f"{split_name}: {len(data)} labeled examples")

        run_model_on_dataset(
            model=model,
            tokenizer=tokenizer,
            data=data,
            split_name=split_name,
            round_tag="round14",  # tag embedded in the output filename
            device=device,
            output_dir=OUTPUT_DIR,
        )

    print("\nAll done. ANLI error buffers are ready for SRL fine-tuning.")


# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()