import os
import json
from datetime import datetime
from pathlib import Path

import torch
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from tqdm.auto import tqdm

# ============================
# CONFIG
# ============================
MODEL_PATH = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\models\student_biomed_kd_fast\adni_srl_round14_smart"
OUTPUT_DIR = r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\evaluation_results\adni_srl_round2_fixed"
BATCH_SIZE = 64
MAX_LENGTH = 192

# HuggingFace datasets: (display name, HF dataset id, split, config)
DATASETS_CONFIG = [
    ("SNLI", "snli", "test", None),
    ("MNLI M", "nyu-mll/multi_nli", "validation_matched", None),
    ("MNLI MM", "nyu-mll/multi_nli", "validation_mismatched", None),
    ("ANLI R1", "facebook/anli", "test_r1", None),
    ("ANLI R2", "facebook/anli", "test_r2", None),
    ("ANLI R3", "facebook/anli", "test_r3", None),
    ("XNLI", "facebook/xnli", "validation", "en"),
]

# Local ADNI NLI JSON files: (display name, path)
ADNI_DATASETS = [
    ("ADNI Train", r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\splits\adni_nli_train.json"),
    ("ADNI Val", r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\splits\adni_nli_val.json"),
    ("ADNI Test", r"C:\Users\Sam\OneDrive\AetherMind\AetherMindProject\AetherMind_for_Alzheimers_Research\data\claims\splits\adni_nli_test.json"),
]

LABEL_NAMES = ["entailment", "neutral", "contradiction"]


# ============================
# HELPER FUNCTIONS
# ============================
def load_model_and_tokenizer(model_path: str, device: str):
    print(f"\n{'='*60}")
    print("Loading Model and Tokenizer")
    print(f"{'='*60}")
    print(f"Model: {model_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.to(device)
    model.eval()

    print(f"Device: {device}")
    print("Model loaded successfully!")
    return tokenizer, model


def compute_metrics_from_predictions(name, labels, preds):
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, support = precision_recall_fscore_support(
        labels, preds, average=None, labels=[0, 1, 2], zero_division=0
    )
    macro_precision = float(np.mean(precision))
    macro_recall = float(np.mean(recall))
    macro_f1 = float(np.mean(f1))
    conf_matrix = confusion_matrix(labels, preds, labels=[0, 1, 2])

    print(f"\n{'='*60}")
    print(f"RESULTS: {name}")
    print(f"{'='*60}")
    print(f"Samples: {len(labels)}")
    print(f"Accuracy: {accuracy*100:.2f}%")
    print(f"Macro F1: {macro_f1*100:.2f}%")
    print("\nPer-Class Performance:")
    for i, label_name in enumerate(LABEL_NAMES):
        print(
            f"  {label_name.upper():13} "
            f"P: {precision[i]*100:.2f}% "
            f"R: {recall[i]*100:.2f}% "
            f"F1: {f1[i]*100:.2f}% (n={support[i]})"
        )

    result = {
        "dataset": name,
        "accuracy": float(accuracy),
        "macro_precision": macro_precision,
        "macro_recall": macro_recall,
        "macro_f1": macro_f1,
        "per_class": {
            LABEL_NAMES[i]: {
                "precision": float(precision[i]),
                "recall": float(recall[i]),
                "f1": float(f1[i]),
                "support": int(support[i]),
            }
            for i in range(3)
        },
        "confusion_matrix": conf_matrix.tolist(),
        "total_samples": len(labels),
    }
    return result


def evaluate_dataset(
    name: str,
    hf_name: str,
    split: str,
    config: str,
    tokenizer,
    model,
    device: str,
    batch_size: int,
    max_length: int,
):
    print(f"\n{'='*60}")
    print(f"Loading {name} Dataset")
    print(f"{'='*60}")

    if config:
        dataset = load_dataset(hf_name, config, split=split, trust_remote_code=False)
    else:
        dataset = load_dataset(hf_name, split=split, trust_remote_code=False)

    # Drop unlabeled examples (label == -1 in SNLI/MNLI)
    if "label" in dataset.column_names:
        dataset = dataset.filter(lambda ex: ex["label"] != -1)
    print(f"✅ Loaded {len(dataset)} valid examples")

    premises = [str(ex["premise"]) for ex in dataset]
    hypotheses = [str(ex["hypothesis"]) for ex in dataset]
    labels = [int(ex["label"]) for ex in dataset]

    label_counts = {0: 0, 1: 0, 2: 0}
    for lab in labels:
        label_counts[lab] = label_counts.get(lab, 0) + 1
    print(f"Label distribution: {label_counts}")

    print(f"\n{'='*60}")
    print(f"Evaluating: {name}")
    print(f"{'='*60}")

    all_preds = []
    num_batches = (len(labels) + batch_size - 1) // batch_size
    with torch.no_grad():
        for i in tqdm(range(0, len(labels), batch_size), total=num_batches, desc=name):
            batch_premises = premises[i:i + batch_size]
            batch_hypotheses = hypotheses[i:i + batch_size]
            encodings = tokenizer(
                batch_premises,
                batch_hypotheses,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(device)
            outputs = model(**encodings)
            preds = torch.argmax(outputs.logits, dim=-1).cpu().tolist()
            all_preds.extend(preds)

    return compute_metrics_from_predictions(name, labels, all_preds)


def extract_label(rec):
    """
    Robustly extract a label as int 0/1/2 from a JSON record.

    Handles:
      - rec['label'] as int or string
      - rec['true_label_id'] as int
      - rec['gold_label'] as string
    """
    mapping = {
        "entailment": 0, "e": 0,
        "neutral": 1, "n": 1,
        "contradiction": 2, "c": 2,
    }
    if "label" in rec:
        v = rec["label"]
        if isinstance(v, int):
            return v
        v_str = str(v).strip().lower()
        if v_str in mapping:
            return mapping[v_str]
        raise ValueError(f"Unknown string label in 'label': {v}")
    if "true_label_id" in rec:
        return int(rec["true_label_id"])
    if "gold_label" in rec:
        v_str = str(rec["gold_label"]).strip().lower()
        if v_str in mapping:
            return mapping[v_str]
        raise ValueError(f"Unknown string label in 'gold_label': {rec['gold_label']}")
    raise ValueError(f"Could not extract label from record keys: {list(rec.keys())}")


def evaluate_local_json_dataset(
    name: str,
    json_path: str,
    tokenizer,
    model,
    device: str,
    batch_size: int,
    max_length: int,
):
    print(f"\n{'='*60}")
    print(f"Loading {name} (local JSON)")
    print(f"{'='*60}")
    print(f"Path: {json_path}")

    if not os.path.exists(json_path):
        raise FileNotFoundError(f"JSON file not found: {json_path}")

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Accept either a bare list of records or a {"data": [...]} wrapper
    if isinstance(data, dict) and "data" in data:
        records = data["data"]
    else:
        records = data

    premises = []
    hypotheses = []
    labels = []
    for rec in records:
        premise = rec.get("premise")
        hypothesis = rec.get("hypothesis")
        if premise is None or hypothesis is None:
            raise ValueError("Expected 'premise' and 'hypothesis' keys in ADNI JSON records.")
        label = extract_label(rec)
        if label == -1:
            continue  # skip unlabeled records
        premises.append(str(premise))
        hypotheses.append(str(hypothesis))
        labels.append(int(label))

    print(f"✅ Loaded {len(labels)} valid examples")

    label_counts = {0: 0, 1: 0, 2: 0}
    for lab in labels:
        label_counts[lab] = label_counts.get(lab, 0) + 1
    print(f"Label distribution: {label_counts}")

    print(f"\n{'='*60}")
    print(f"Evaluating: {name}")
    print(f"{'='*60}")

    all_preds = []
    num_batches = (len(labels) + batch_size - 1) // batch_size
    with torch.no_grad():
        for i in tqdm(range(0, len(labels), batch_size), total=num_batches, desc=name):
            batch_premises = premises[i:i + batch_size]
            batch_hypotheses = hypotheses[i:i + batch_size]
            encodings = tokenizer(
                batch_premises,
                batch_hypotheses,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(device)
            outputs = model(**encodings)
            preds = torch.argmax(outputs.logits, dim=-1).cpu().tolist()
            all_preds.extend(preds)

    return compute_metrics_from_predictions(name, labels, all_preds)


def save_results(results: list, output_dir: str, model_path: str):
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_name = Path(model_path).name

    json_path = os.path.join(output_dir, f"results_{model_name}_{timestamp}.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    summary_path = os.path.join(output_dir, f"summary_{model_name}_{timestamp}.txt")
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("=" * 80 + "\n")
        f.write("COMPREHENSIVE NLI MODEL EVALUATION SUMMARY\n")
        f.write("=" * 80 + "\n")
        f.write(f"Model: {model_path}\n")
        f.write(f"Timestamp: {timestamp}\n")
        f.write("=" * 80 + "\n\n")
        for result in results:
            f.write(f"{result['dataset']}\n")
            f.write("-" * 40 + "\n")
            f.write(f"Accuracy: {result['accuracy']*100:.2f}%\n")
            f.write(f"Macro F1: {result['macro_f1']*100:.2f}%\n")
            f.write(f"Samples: {result['total_samples']}\n")
            f.write("\n")
        f.write("\n" + "=" * 80 + "\n")
        f.write("OVERALL STATISTICS\n")
        f.write("=" * 80 + "\n")
        avg_accuracy = np.mean([r["accuracy"] for r in results])
        avg_f1 = np.mean([r["macro_f1"] for r in results])
        f.write(f"Average Accuracy: {avg_accuracy*100:.2f}%\n")
        f.write(f"Average Macro F1: {avg_f1*100:.2f}%\n")

    print("\n✅ Results saved:")
    print(f"  JSON: {json_path}")
    print(f"  Summary: {summary_path}")
    return json_path, summary_path


# ============================
# MAIN
# ============================
def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("=" * 80)
    print("COMPREHENSIVE NLI MODEL EVALUATION")
    print("=" * 80)
    print(f"Model: {MODEL_PATH}")
    all_names = [d[0] for d in DATASETS_CONFIG] + [d[0] for d in ADNI_DATASETS]
    print(f"Datasets: {', '.join(all_names)}")
    print("=" * 80)

    tokenizer, model = load_model_and_tokenizer(MODEL_PATH, device)

    all_results = []
    for name, hf_name, split, config in DATASETS_CONFIG:
        result = evaluate_dataset(
            name=name,
            hf_name=hf_name,
            split=split,
            config=config,
            tokenizer=tokenizer,
            model=model,
            device=device,
            batch_size=BATCH_SIZE,
            max_length=MAX_LENGTH,
        )
        all_results.append(result)

    for name, path in ADNI_DATASETS:
        result = evaluate_local_json_dataset(
            name=name,
            json_path=path,
            tokenizer=tokenizer,
            model=model,
            device=device,
            batch_size=BATCH_SIZE,
            max_length=MAX_LENGTH,
        )
        all_results.append(result)

    save_results(all_results, OUTPUT_DIR, MODEL_PATH)

    print(f"\n{'='*80}")
    print("EVALUATION COMPLETE - FINAL SUMMARY")
    print(f"{'='*80}\n")
    print(f"{'Dataset':<15} {'Accuracy':<12} {'Macro F1':<12} {'Samples':<10}")
    print("-" * 50)
    for result in all_results:
        print(
            f"{result['dataset']:<15} "
            f"{result['accuracy']*100:>6.2f}%      "
            f"{result['macro_f1']*100:>6.2f}%      "
            f"{result['total_samples']:>6}"
        )
    print("-" * 50)
    avg_accuracy = np.mean([r["accuracy"] for r in all_results])
    avg_f1 = np.mean([r["macro_f1"] for r in all_results])
    print(f"{'AVERAGE':<15} {avg_accuracy*100:>6.2f}%      {avg_f1*100:>6.2f}%")
    print("=" * 80)


if __name__ == "__main__":
    main()
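# ----------------------------------------------------------------------
# For reference, the record shapes that extract_label() accepts, inferred
# from the key handling in that function. The exact ADNI JSON schema is an
# assumption here; only the keys this script actually reads are shown:
#
#   {"premise": "...", "hypothesis": "...", "label": 0}
#   {"premise": "...", "hypothesis": "...", "label": "entailment"}
#   {"premise": "...", "hypothesis": "...", "true_label_id": 2}
#   {"premise": "...", "hypothesis": "...", "gold_label": "neutral"}
# ----------------------------------------------------------------------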