Spaces:
Running
Running
"""
compare_results.py
──────────────────
Auto-discovers and evaluates all saved models side-by-side.
Handles multiple transformer architectures in saved_models/.

Usage
─────
python compare_results.py
"""
| import logging | |
| import os | |
| from typing import Dict, List | |
| import numpy as np | |
| import torch | |
| from sklearn.metrics import accuracy_score, f1_score | |
| from config import CFG | |
| from data_loader import load_test_only | |
| import traditional_model as tm | |
| import transformer_model as trm | |
# Configure the root logger so that only WARNING and more severe records are emitted.
logging.basicConfig(level=logging.WARNING)
# ── Discovery ─────────────────────────────────────────────────────────────────
def _discover_transformer_models() -> List[str]:
    """Return the names of saved transformer model directories, sorted.

    An entry of ``CFG.models_dir`` qualifies when it is a directory that
    contains a ``config.json`` entry.  If ``CFG.models_dir`` itself is not
    an existing directory, an empty list is returned.
    """
    root = CFG.models_dir
    # Guard clause: nothing to scan when the models directory is absent.
    if not os.path.isdir(root):
        return []
    return [
        entry
        for entry in sorted(os.listdir(root))
        if os.path.isdir(os.path.join(root, entry))
        and os.path.exists(os.path.join(root, entry, "config.json"))
    ]
# ── Evaluation ────────────────────────────────────────────────────────────────
def _eval_traditional(name: str, X_test: List[str], y_test: List[int]) -> Dict:
    """Score the saved traditional model called *name* on the test split.

    The model is obtained from ``tm.load_model(name)`` and must expose a
    ``predict`` method that accepts ``X_test``.

    Returns:
        A dict with keys ``"accuracy"`` and ``"f1_macro"`` (macro-averaged
        F1), or an empty dict when a ``FileNotFoundError`` is raised while
        loading or scoring the model.
    """
    try:
        model = tm.load_model(name)
        predicted = list(model.predict(X_test))
        metrics = {
            "accuracy": accuracy_score(y_test, predicted),
            "f1_macro": f1_score(y_test, predicted, average="macro"),
        }
    except FileNotFoundError:
        # An empty dict is the "nothing to report" signal for the caller.
        return {}
    return metrics
def _eval_transformer(model_dir: str, X_test: List[str], y_test: List[int]) -> Dict:
    """Score the transformer saved under ``CFG.models_dir/model_dir``.

    The model and tokenizer are loaded with the ``transformers`` Auto
    classes, the test texts are classified in batches of 32 with gradient
    tracking disabled, and the arg-max of the logits is taken as the
    predicted label.

    Returns:
        A dict with keys ``"accuracy"`` and ``"f1_macro"`` (macro-averaged
        F1).  An empty dict is returned silently on ``FileNotFoundError``;
        for any other exception the error is printed and an empty dict is
        returned.
    """
    checkpoint = os.path.join(CFG.models_dir, model_dir)
    try:
        # Imported inside the try block so that a failing import is
        # reported through the generic handler below.
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model.eval()

        step = 32
        predicted: List[int] = []
        with torch.no_grad():
            for start in range(0, len(X_test), step):
                encoded = tokenizer(
                    X_test[start : start + step],
                    truncation=True,
                    max_length=CFG.max_length,
                    padding=True,
                    return_tensors="pt",
                )
                scores = model(**encoded).logits
                predicted += scores.argmax(dim=-1).tolist()

        return {
            "accuracy": accuracy_score(y_test, predicted),
            "f1_macro": f1_score(y_test, predicted, average="macro"),
        }
    except FileNotFoundError:
        return {}
    except Exception as exc:
        print(f" [{model_dir}] Error: {exc}")
        return {}
# ── Main ──────────────────────────────────────────────────────────────────────
def main() -> None:
    """Evaluate every available model on the test set and print a ranked table.

    Steps:
      1. Load the test split with ``load_test_only()``.
      2. Score the traditional models ``"lr"`` and ``"svm"``; a model whose
         evaluation returns an empty dict is reported as not found.
      3. Score every transformer directory returned by
         ``_discover_transformer_models()``; directories whose evaluation
         returns an empty dict are left out of the table.
      4. Print accuracy and macro F1 for each scored model, sorted by
         accuracy in descending order, with the top model marked.

    Returns early, after printing a hint, when no model could be scored.
    """
    # NOTE(review): the non-ASCII glyphs in the messages below look
    # mis-encoded; they are kept as-is -- confirm against the original file.
    print("\n Loading AG News test set β¦")
    X_test, y_test = load_test_only()
    print(f" Loaded {len(X_test):,} examples.\n")

    results: Dict[str, Dict] = {}

    # Traditional models
    for name in ["lr", "svm"]:
        label = name.upper()
        print(f" Evaluating {label} β¦")
        r = _eval_traditional(name, X_test, y_test)
        if r:
            results[label] = r
        else:
            print(f" [{label}] not found β skipping.")

    # All saved transformer models
    transformer_dirs = _discover_transformer_models()
    if not transformer_dirs:
        print(" No transformer models found in saved_models/.")
    for model_dir in transformer_dirs:
        display_name = model_dir.replace("_", "-")
        print(f" Evaluating {display_name} β¦")
        r = _eval_transformer(model_dir, X_test, y_test)
        if r:
            results[display_name] = r

    if not results:
        print("\n No models found. Train at least one model first.\n")
        return

    # The best model does not change while the table is printed, so it is
    # determined once here instead of once per row.
    best_name = max(results, key=lambda k: results[k]["accuracy"])

    # Print table
    print("\n" + "β" * 58)
    print(f" {'Model':<22} {'Accuracy':>10} {'F1-Macro':>10}")
    print("β" * 58)
    for name, m in sorted(results.items(), key=lambda x: x[1]["accuracy"], reverse=True):
        star = " β best" if name == best_name else ""
        print(
            f" {name:<22} {m['accuracy']*100:>9.2f}% "
            f"{m['f1_macro']:>10.4f}{star}"
        )
    print("β" * 58 + "\n")
# Run the comparison only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()