nexa-classify-api / compare_results.py
Prototype6239's picture
Upload folder using huggingface_hub
a229747 verified
Raw
History Blame Contribute Delete
4.68 kB
"""
compare_results.py
──────────────────
Auto-discovers and evaluates all saved models side-by-side.
Handles multiple transformer architectures in saved_models/.
Usage
─────
python compare_results.py
"""
import logging
import os
from typing import Dict, List
import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score
from config import CFG
from data_loader import load_test_only
import traditional_model as tm
import transformer_model as trm
# Configure the root logger with a WARNING threshold for this script.
logging.basicConfig(level=logging.WARNING)
# ── Discovery ─────────────────────────────────────────────────────────────────
def _discover_transformer_models() -> List[str]:
    """List saved transformer model directories under ``CFG.models_dir``.

    A sub-directory counts as a transformer model when it contains a
    ``config.json`` file.

    Returns:
        Matching directory names (not full paths) in sorted order, or an
        empty list when ``CFG.models_dir`` is not an existing directory.
    """
    root = CFG.models_dir
    if not os.path.isdir(root):
        return []

    def _has_model_config(entry: str) -> bool:
        # Only directories that carry a config.json are treated as models.
        candidate = os.path.join(root, entry)
        return os.path.isdir(candidate) and os.path.exists(
            os.path.join(candidate, "config.json")
        )

    return [entry for entry in sorted(os.listdir(root)) if _has_model_config(entry)]
# ── Evaluation ────────────────────────────────────────────────────────────────
def _eval_traditional(name: str, X_test: List[str], y_test: List[int]) -> Dict:
    """Score one saved traditional model on the test set.

    Args:
        name: Identifier handed to ``tm.load_model``.
        X_test: Test inputs passed to the loaded pipeline's ``predict``.
        y_test: Ground-truth labels aligned with ``X_test``.

    Returns:
        A dict with ``"accuracy"`` and ``"f1_macro"`` (macro-averaged F1),
        or an empty dict when a ``FileNotFoundError`` is raised while
        loading or evaluating the model.
    """
    try:
        predictions = list(tm.load_model(name).predict(X_test))
        accuracy = accuracy_score(y_test, predictions)
        macro_f1 = f1_score(y_test, predictions, average="macro")
    except FileNotFoundError:
        return {}
    return {"accuracy": accuracy, "f1_macro": macro_f1}
def _eval_transformer(model_dir: str, X_test: List[str], y_test: List[int]) -> Dict:
    """Score one saved transformer checkpoint on the test set.

    The model and tokenizer are loaded from ``CFG.models_dir/<model_dir>``
    and predictions are produced in batches of 32 with gradients disabled.

    Args:
        model_dir: Name of the checkpoint directory inside ``CFG.models_dir``.
        X_test: Test texts to classify.
        y_test: Ground-truth labels aligned with ``X_test``.

    Returns:
        A dict with ``"accuracy"`` and ``"f1_macro"`` (macro-averaged F1).
        An empty dict is returned on failure: silently for
        ``FileNotFoundError``, and after printing the error for any other
        exception.
    """
    checkpoint = os.path.join(CFG.models_dir, model_dir)
    chunk = 32
    try:
        # Imported inside the try block, so an import failure is reported by
        # the generic handler below.
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model.eval()

        predicted: List[int] = []
        start = 0
        while start < len(X_test):
            texts = X_test[start : start + chunk]
            encoded = tokenizer(
                texts,
                truncation=True,
                max_length=CFG.max_length,
                padding=True,
                return_tensors="pt",
            )
            with torch.no_grad():
                output = model(**encoded)
            # The arg-max over the last dimension is the predicted class id.
            predicted.extend(output.logits.argmax(dim=-1).tolist())
            start += chunk

        scores = {
            "accuracy": accuracy_score(y_test, predicted),
            "f1_macro": f1_score(y_test, predicted, average="macro"),
        }
        return scores
    except FileNotFoundError:
        return {}
    except Exception as exc:
        print(f" [{model_dir}] Error: {exc}")
        return {}
# ── Main ──────────────────────────────────────────────────────────────────────
def main() -> None:
    """Evaluate every available saved model and print a ranked comparison.

    Loads the test set via ``load_test_only``, scores the traditional models
    (``"lr"`` and ``"svm"``) and every transformer directory returned by
    ``_discover_transformer_models``, then prints a table sorted by accuracy
    (highest first) with the most accurate model marked as best. Models that
    yield no metrics are skipped; when no model yields metrics a hint is
    printed instead of the table.
    """
    print("\n Loading AG News test set …")
    X_test, y_test = load_test_only()
    print(f" Loaded {len(X_test):,} examples.\n")

    results: Dict[str, Dict] = {}

    # Traditional models
    for name in ["lr", "svm"]:
        label = name.upper()
        print(f" Evaluating {label} …")
        metrics = _eval_traditional(name, X_test, y_test)
        if metrics:
            results[label] = metrics
        else:
            # Em dash restored: the literal previously held mojibake ("β€”").
            print(f" [{label}] not found — skipping.")

    # All saved transformer models
    transformer_dirs = _discover_transformer_models()
    if not transformer_dirs:
        print(" No transformer models found in saved_models/.")
    for model_dir in transformer_dirs:
        display_name = model_dir.replace("_", "-")
        print(f" Evaluating {display_name} …")
        metrics = _eval_transformer(model_dir, X_test, y_test)
        if metrics:
            results[display_name] = metrics

    if not results:
        print("\n No models found. Train at least one model first.\n")
        return

    # Determine the winner once instead of re-scanning results for every row.
    best_name = max(results, key=lambda k: results[k]["accuracy"])
    ranked = sorted(results.items(), key=lambda x: x[1]["accuracy"], reverse=True)

    # Print table
    print("\n" + "═" * 58)
    print(f" {'Model':<22} {'Accuracy':>10} {'F1-Macro':>10}")
    print("─" * 58)
    for name, m in ranked:
        # Marker glyph restored: the literal previously held mojibake ("β—€").
        star = " ◀ best" if name == best_name else ""
        print(
            f" {name:<22} {m['accuracy']*100:>9.2f}% "
            f"{m['f1_macro']:>10.4f}{star}"
        )
    print("═" * 58 + "\n")


# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()