import os
import csv
from pathlib import Path

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.nn.functional import softmax
from datasets import load_dataset
from tqdm import tqdm

# ============================
# CONFIG
# ============================

# 🔴 Target model you want to improve
MODEL_PATH = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\models\student_biomed_kd_fast\adni_srl_round13_smart"

# Where to save ANLI error buffers
OUTPUT_DIR = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\analysis"

BATCH_SIZE = 32
MAX_LENGTH = 192

# Standard 3-way NLI label mapping (must match the model's head ordering).
LABEL_ID2NAME = {
    0: "entailment",
    1: "neutral",
    2: "contradiction",
}

# ============================
# HELPERS
# ============================


def ensure_output_dir(path: str) -> None:
    """Create *path* (and parents) if it does not already exist."""
    os.makedirs(path, exist_ok=True)


def to_label_name(label_id: int) -> str:
    """Map a numeric NLI label id to its name; unknown ids become 'label_<id>'."""
    return LABEL_ID2NAME.get(int(label_id), f"label_{label_id}")


def compute_error_type(true_id: int, pred_id: int) -> str:
    """Return 'correct' or a compact confusion tag like 'E->N' / 'N->C' / 'C->E'."""
    if true_id == pred_id:
        return "correct"
    return f"{to_label_name(true_id)[0].upper()}->{to_label_name(pred_id)[0].upper()}"


def run_model_on_dataset(model, tokenizer, data, split_name: str, round_tag: str,
                         device: torch.device, output_dir: str) -> None:
    """
    Run the model on a dataset (list of dicts with keys:
    'premise', 'hypothesis', 'label', 'id').
    Save a CSV with detailed per-example info (predictions, errors, difficulty).

    NOTE: the CSV columns named 'logit_*' actually hold softmax PROBABILITIES,
    not raw logits. The names are kept for downstream schema compatibility.
    """
    rows = []
    print(f"\n=== Processing ANLI {split_name} ({len(data)} examples) ===")

    model.eval()
    with torch.no_grad():
        for idx in tqdm(range(0, len(data), BATCH_SIZE), desc=f"{split_name} batches"):
            batch_examples = data[idx:idx + BATCH_SIZE]
            premises = [ex["premise"] for ex in batch_examples]
            hypotheses = [ex["hypothesis"] for ex in batch_examples]
            labels = [int(ex["label"]) for ex in batch_examples]

            enc = tokenizer(
                premises,
                hypotheses,
                padding=True,
                truncation=True,
                max_length=MAX_LENGTH,
                return_tensors="pt",
            )
            input_ids = enc["input_ids"].to(device)
            attention_mask = enc["attention_mask"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits                      # [B, 3]
            probs = softmax(logits, dim=-1)              # [B, 3]
            pred_ids = torch.argmax(probs, dim=-1).cpu().tolist()
            probs_np = probs.cpu().tolist()

            for i, ex in enumerate(batch_examples):
                true_id = int(labels[i])
                pred_id = int(pred_ids[i])
                prob_vec = probs_np[i]
                prob_true = float(prob_vec[true_id])
                is_error = int(true_id != pred_id)
                err_type = compute_error_type(true_id, pred_id)

                # BUGFIX: dict.get("id", fallback) never falls back when the
                # "id" KEY exists with value None (which is what main() emits
                # for examples missing a 'uid'). Chain explicitly on None so
                # the positional fallback actually triggers.
                ex_id = ex.get("id")
                if ex_id is None:
                    ex_id = ex.get("uid")
                if ex_id is None:
                    ex_id = idx + i

                rows.append({
                    "id": ex_id,
                    "premise": ex["premise"],
                    "hypothesis": ex["hypothesis"],
                    "true_label_id": true_id,
                    "true_label": to_label_name(true_id),
                    "pred_label_id": pred_id,
                    "pred_label": to_label_name(pred_id),
                    "is_error": is_error,
                    "error_type": err_type,
                    # Softmax probabilities per class (column names are
                    # historical — see docstring).
                    "logit_entailment": float(prob_vec[0]),
                    "logit_neutral": float(prob_vec[1]),
                    "logit_contradiction": float(prob_vec[2]),
                    "conf_true_label": prob_true,
                    # Low confidence in the gold label => "hard" example.
                    "difficulty": 1.0 - prob_true,
                })

    ensure_output_dir(output_dir)
    out_path = os.path.join(output_dir, f"anli_error_buffer_{split_name}_{round_tag}.csv")

    fieldnames = [
        "id", "premise", "hypothesis",
        "true_label_id", "true_label",
        "pred_label_id", "pred_label",
        "is_error", "error_type",
        "logit_entailment", "logit_neutral", "logit_contradiction",
        "conf_true_label", "difficulty",
    ]

    with open(out_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)

    total = len(rows)
    errors = sum(r["is_error"] for r in rows)
    acc = 100.0 * (total - errors) / max(1, total)
    print(f"Saved {total} rows to: {out_path}")
    print(f"{split_name} accuracy (recomputed here): {acc:.2f}% (errors={errors})")


# ============================
# MAIN
# ============================


def main():
    """Load the target model, evaluate it on ANLI dev splits, write error CSVs."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print(f"\nLoading tokenizer and model from: {MODEL_PATH}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(device)

    # ANLI splits: dev_r1, dev_r2, dev_r3
    anli_splits = {
        "anli_r1_dev": "dev_r1",
        "anli_r2_dev": "dev_r2",
        "anli_r3_dev": "dev_r3",
    }

    for split_name, hf_split in anli_splits.items():
        print(f"\nLoading ANLI split: {hf_split}")
        ds = load_dataset("anli", split=hf_split)

        # Filter out unlabeled (-1) if present and map into a simple list of dicts
        data = []
        for ex in ds:
            label = int(ex["label"])
            if label < 0:
                continue
            data.append({
                "id": ex.get("uid", None),
                "premise": ex["premise"],
                "hypothesis": ex["hypothesis"],
                "label": label,
            })

        print(f"{split_name}: {len(data)} labeled examples")

        run_model_on_dataset(
            model=model,
            tokenizer=tokenizer,
            data=data,
            split_name=split_name,   # will appear in filename
            round_tag="round14",     # consistent with adni_error_buffer_*_round1
            device=device,
            output_dir=OUTPUT_DIR,
        )

    print("\nAll done. ANLI error buffers are ready for SRL fine-tuning.")


if __name__ == "__main__":
    main()