# AetherMind_SRL / analyze_anli_errors_round1.py
# Uploaded by samerzaher80 ("Upload 4 files", commit 1a6e63a, verified).
import os
import csv
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.nn.functional import softmax
from datasets import load_dataset
from tqdm import tqdm
# ============================
# CONFIG
# ============================
# 🔴 Target model you want to improve (fine-tuned 3-way NLI checkpoint;
# must be loadable by AutoModelForSequenceClassification).
MODEL_PATH = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\models\student_biomed_kd_fast\adni_srl_round13_smart"
# Where to save ANLI error buffers (one per-example CSV per dev split)
OUTPUT_DIR = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\analysis"
# Inference batch size — evaluation only (no gradients), so this trades memory for speed.
BATCH_SIZE = 32
# Max token length for the encoded premise+hypothesis pair; longer pairs are truncated.
MAX_LENGTH = 192
LABEL_ID2NAME = {
    0: "entailment",
    1: "neutral",
    2: "contradiction",
}
# ============================
# HELPERS
# ============================
def ensure_output_dir(path: str):
    """Create *path* (including any missing parents); a no-op if it already exists."""
    Path(path).mkdir(parents=True, exist_ok=True)


def to_label_name(label_id: int) -> str:
    """Translate an NLI label id into its name; unknown ids become 'label_<id>'."""
    return LABEL_ID2NAME.get(int(label_id), f"label_{label_id}")


def compute_error_type(true_id: int, pred_id: int) -> str:
    """Describe a prediction as 'correct' or as a true->pred initials arrow.

    Examples of error codes: E->N, N->C, C->E.
    """
    if true_id == pred_id:
        return "correct"
    true_initial, pred_initial = (to_label_name(x)[0].upper() for x in (true_id, pred_id))
    return f"{true_initial}->{pred_initial}"
def run_model_on_dataset(model, tokenizer, data, split_name: str, round_tag: str,
                         device: torch.device, output_dir: str):
    """
    Run the model over a dataset and save a per-example error-analysis CSV.

    Args:
        model: 3-way sequence-classification model (entailment/neutral/contradiction).
        tokenizer: tokenizer matching the model; encodes (premise, hypothesis) pairs.
        data: list of dicts with keys 'premise', 'hypothesis', 'label' (0..2) and
            optionally 'id' / 'uid' (either may be None).
        split_name: split label used in log lines and in the output filename.
        round_tag: tag appended to the output filename.
        device: device the model already lives on; inputs are moved there.
        output_dir: destination directory for the CSV (created if missing).
    """
    rows = []
    print(f"\n=== Processing ANLI {split_name} ({len(data)} examples) ===")
    model.eval()
    with torch.no_grad():
        for idx in tqdm(range(0, len(data), BATCH_SIZE), desc=f"{split_name} batches"):
            batch_examples = data[idx:idx + BATCH_SIZE]
            premises = [ex["premise"] for ex in batch_examples]
            hypotheses = [ex["hypothesis"] for ex in batch_examples]
            labels = [int(ex["label"]) for ex in batch_examples]
            enc = tokenizer(
                premises,
                hypotheses,
                padding=True,
                truncation=True,
                max_length=MAX_LENGTH,
                return_tensors="pt",
            )
            input_ids = enc["input_ids"].to(device)
            attention_mask = enc["attention_mask"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits          # [B, 3]
            probs = softmax(logits, dim=-1)  # [B, 3]
            pred_ids = torch.argmax(probs, dim=-1).cpu().tolist()
            probs_list = probs.cpu().tolist()
            for i, ex in enumerate(batch_examples):
                true_id = labels[i]
                pred_id = int(pred_ids[i])
                prob_vec = probs_list[i]
                prob_true = float(prob_vec[true_id])
                is_error = int(true_id != pred_id)
                err_type = compute_error_type(true_id, pred_id)
                # BUGFIX: the old `ex.get("id", ex.get("uid", idx + i))` returned
                # the *stored* None whenever the 'id' key existed with value None
                # (which is exactly what main() produces when 'uid' is missing),
                # so the uid / positional-index fallbacks never fired. Chain on
                # None explicitly instead.
                ex_id = ex.get("id")
                if ex_id is None:
                    ex_id = ex.get("uid")
                if ex_id is None:
                    ex_id = idx + i
                rows.append({
                    "id": ex_id,
                    "premise": ex["premise"],
                    "hypothesis": ex["hypothesis"],
                    "true_label_id": true_id,
                    "true_label": to_label_name(true_id),
                    "pred_label_id": pred_id,
                    "pred_label": to_label_name(pred_id),
                    "is_error": is_error,
                    "error_type": err_type,
                    # NOTE: despite the 'logit_*' column names (kept for
                    # downstream compatibility), these are softmax probabilities.
                    "logit_entailment": float(prob_vec[0]),
                    "logit_neutral": float(prob_vec[1]),
                    "logit_contradiction": float(prob_vec[2]),
                    "conf_true_label": prob_true,
                    "difficulty": 1.0 - prob_true,  # low confidence in gold => hard example
                })
    ensure_output_dir(output_dir)
    out_path = os.path.join(output_dir, f"anli_error_buffer_{split_name}_{round_tag}.csv")
    fieldnames = [
        "id",
        "premise",
        "hypothesis",
        "true_label_id",
        "true_label",
        "pred_label_id",
        "pred_label",
        "is_error",
        "error_type",
        "logit_entailment",
        "logit_neutral",
        "logit_contradiction",
        "conf_true_label",
        "difficulty",
    ]
    with open(out_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    total = len(rows)
    errors = sum(r["is_error"] for r in rows)
    acc = 100.0 * (total - errors) / max(1, total)
    print(f"Saved {total} rows to: {out_path}")
    print(f"{split_name} accuracy (recomputed here): {acc:.2f}% (errors={errors})")
# ============================
# MAIN
# ============================
def main():
    """Evaluate the target model on the ANLI dev splits and write error buffers."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"\nLoading tokenizer and model from: {MODEL_PATH}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(device)
    # ANLI splits: dev_r1, dev_r2, dev_r3
    anli_splits = {
        "anli_r1_dev": "dev_r1",
        "anli_r2_dev": "dev_r2",
        "anli_r3_dev": "dev_r3",
    }
    for split_name, hf_split in anli_splits.items():
        print(f"\nLoading ANLI split: {hf_split}")
        ds = load_dataset("anli", split=hf_split)
        # Drop unlabeled examples (label == -1) and keep only the fields
        # run_model_on_dataset needs.
        data = [
            {
                "id": ex.get("uid"),
                "premise": ex["premise"],
                "hypothesis": ex["hypothesis"],
                "label": int(ex["label"]),
            }
            for ex in ds
            if int(ex["label"]) >= 0
        ]
        print(f"{split_name}: {len(data)} labeled examples")
        run_model_on_dataset(
            model=model,
            tokenizer=tokenizer,
            data=data,
            split_name=split_name,  # appears in the output filename
            round_tag="round14",    # NOTE(review): original comment says "round1" — confirm tag
            device=device,
            output_dir=OUTPUT_DIR,
        )
    print("\nAll done. ANLI error buffers are ready for SRL fine-tuning.")


if __name__ == "__main__":
    main()