Spaces:
Running
Running
| """ | |
| ensemble.py | |
| ----------- | |
| Soft-voting ensemble that combines all trained classifiers. | |
| Each model's class probabilities are weighted and summed for a final prediction. | |
| Usage | |
| ----- | |
| # Interactive predictions | |
| python ensemble.py --interactive | |
| # Single prediction | |
| python ensemble.py --text "Tesla stock hits all-time high after earnings beat" | |
| # Custom weights (auto-normalised, so they need not sum to 1.0) | |
| python ensemble.py --text "..." --weights 0.05 0.10 0.85 | |
| # Use optimised weights from optimal_weights.json | |
| python ensemble.py --text "..." --optimal | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| import os | |
| import sys | |
| from typing import Dict, List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| from config import CFG | |
| import traditional_model as tm | |
| import transformer_model as trm | |
# Only WARNING and above are emitted through the root logger; the
# logging.info() call in load_optimal_weights() is therefore silent.
logging.basicConfig(level=logging.WARNING)

# optimize_ensemble.py writes its best-found weights to this JSON file.
_OPTIMAL_WEIGHTS_FILE = os.path.join(CFG.outputs_dir, "ensemble_cache", "optimal_weights.json")

# Default ensemble members and their (un-normalised) weights, paired by position.
_DEFAULT_MODELS = [
    "lr",
    "svm",
    "distilbert_base_uncased",
]
_DEFAULT_WEIGHTS = [
    0.10,
    0.15,
    0.75,
]
| # -- Probability helpers ------------------------------------------------------ | |
def _proba_sklearn(text: str, pipeline) -> np.ndarray:
    """
    Return class probabilities for *text* from a fitted sklearn Pipeline.

    Uses ``predict_proba`` when the final pipeline step provides it.
    Otherwise (e.g. LinearSVC) pseudo-probabilities are computed as a
    softmax over the ``decision_function`` scores.  Binary classifiers emit
    a single signed score per sample; it is expanded to ``[0, score]`` so
    the result still has one entry per class (equivalent to a sigmoid).

    Parameters
    ----------
    text : str
        Raw document to classify.
    pipeline : sklearn.pipeline.Pipeline
        Fitted pipeline whose last named step is the classifier.

    Returns
    -------
    np.ndarray
        1-D array of per-class probabilities summing to 1.
    """
    clf = list(pipeline.named_steps.values())[-1]
    if hasattr(clf, "predict_proba"):
        return pipeline.predict_proba([text])[0]

    scores = np.atleast_1d(
        np.asarray(pipeline.decision_function([text])[0], dtype=float)
    )
    if scores.size == 1:
        # Binary decision_function: a positive score favours the second class.
        scores = np.array([0.0, float(scores[0])])
    # Shift by the max for numerical stability.  Done out-of-place so the
    # array returned by decision_function is never mutated.
    shifted = scores - scores.max()
    exp = np.exp(shifted)
    return exp / exp.sum()
def _proba_transformer(
    text: str,
    model,
    tokenizer,
    max_length: Optional[int] = None,
) -> np.ndarray:
    """
    Return softmax class probabilities for *text* from a transformer model.

    Parameters
    ----------
    text : str
        Raw document to classify.
    model : callable
        Called with the tokenizer output as keyword arguments; its result
        must expose ``.logits`` (presumably a Hugging Face
        sequence-classification model -- confirm against transformer_model).
    tokenizer : callable
        Called with ``truncation=True``, ``max_length`` and
        ``return_tensors="pt"``; must return a mapping of tensors.
    max_length : int, optional
        Truncation length in tokens; defaults to ``CFG.max_length``.

    Returns
    -------
    np.ndarray
        1-D float32 array of class probabilities, always on the CPU.
    """
    if max_length is None:
        max_length = CFG.max_length
    enc = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    # Keep inputs on the model's device (a no-op on CPU).  Models without a
    # ``device`` attribute receive the tokenizer output unchanged.
    device = getattr(model, "device", None)
    if device is not None:
        enc = {key: value.to(device) for key, value in enc.items()}
    with torch.no_grad():
        logits = model(**enc).logits[0]
    # .float(): numpy has no bfloat16 and fp16 softmax loses precision.
    # .cpu(): .numpy() raises on accelerator-resident tensors.
    return torch.softmax(logits.float(), dim=-1).cpu().numpy()
| # -- Optimal weights loader --------------------------------------------------- | |
def load_optimal_weights(
    model_names: List[str],
    path: Optional[str] = None,
) -> Optional[Dict[str, float]]:
    """
    Attempt to load optimised weights from optimal_weights.json.

    The file must contain a JSON object with one numeric entry per requested
    model name.  The optional ``method`` and ``val_f1_macro`` keys are only
    used for the informational log line.

    Parameters
    ----------
    model_names : list of str
        Models whose weights are required.
    path : str, optional
        File to read; defaults to the module-level ``_OPTIMAL_WEIGHTS_FILE``.

    Returns
    -------
    dict or None
        Mapping model_name -> weight (as float), or None when the file is
        missing, unreadable or not valid JSON, lacks a weight for any
        requested model, or holds a weight that is not a finite
        non-negative number (or all weights are zero).  Callers treat None
        as "fall back to manual weights".
    """
    if path is None:
        path = _OPTIMAL_WEIGHTS_FILE

    try:
        with open(path, encoding="utf-8") as fh:
            data = json.load(fh)
    except FileNotFoundError:
        logging.warning(
            "[Ensemble] Optimal weights file not found at '%s'. "
            "Run: python optimize_ensemble.py",
            path,
        )
        return None
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logging.warning(
            "[Ensemble] Could not load optimal_weights.json: %s. "
            "Falling back to manual weights.",
            exc,
        )
        return None

    if not isinstance(data, dict) or any(name not in data for name in model_names):
        logging.warning(
            "[Ensemble] optimal_weights.json does not contain weights "
            "for all requested models. Falling back to manual weights."
        )
        return None

    try:
        weights = {name: float(data[name]) for name in model_names}
    except (TypeError, ValueError) as exc:
        logging.warning(
            "[Ensemble] Could not load optimal_weights.json: %s. "
            "Falling back to manual weights.",
            exc,
        )
        return None

    values = list(weights.values())
    # Negative / NaN / all-zero weights would break normalisation downstream.
    if not all(np.isfinite(w) and w >= 0 for w in values) or (
        values and sum(values) <= 0
    ):
        logging.warning(
            "[Ensemble] optimal_weights.json contains invalid weights %s "
            "(must be finite, non-negative and not all zero). "
            "Falling back to manual weights.",
            weights,
        )
        return None

    logging.info(
        "[Ensemble] Loaded optimal weights (method=%s, val_f1=%s): %s",
        data.get("method"),
        data.get("val_f1_macro"),
        weights,
    )
    return weights
| # -- Ensemble class ----------------------------------------------------------- | |
class Ensemble:
    """
    Weighted soft-voting ensemble.

    Each member model produces a class-probability vector for the input
    text; the ensemble prediction is the argmax of the weighted sum of those
    vectors.

    Parameters
    ----------
    model_weights : list of (model_name, weight) tuples.
        Weights must be finite and non-negative with a positive sum; they are
        normalised automatically.
        model_name must match a key in saved_models/
        ('lr', 'svm', 'distilbert_base_uncased', etc.).  'lr' and 'svm' are
        loaded through traditional_model; every other name is treated as a
        transformer and loaded through transformer_model.
    use_optimal_weights : bool, default True
        If True, attempt to load weights from
        outputs/ensemble_cache/optimal_weights.json and
        override the provided model_weights.
        Falls back to the provided weights if the file is
        missing or malformed.

    Raises
    ------
    ValueError
        If the final weights are empty, negative, non-finite or sum to zero.
        Raised before any model is loaded.

    Example
    -------
    >>> e = Ensemble([("lr", 0.10), ("svm", 0.15), ("distilbert_base_uncased", 0.75)])
    >>> e.predict("Apple M5 chip breaks all benchmarks")
    >>> # Load with auto-optimised weights
    >>> e = Ensemble.from_optimal()
    """

    def __init__(
        self,
        model_weights: List[Tuple[str, float]],
        use_optimal_weights: bool = True,
    ):
        # Attempt to override with optimised weights
        if use_optimal_weights:
            names = [name for name, _ in model_weights]
            optimal = load_optimal_weights(names)
            if optimal is not None:
                model_weights = [(name, optimal[name]) for name in names]
                print(
                    f" [Ensemble] Using optimal weights from "
                    f"{_OPTIMAL_WEIGHTS_FILE}"
                )
        # Validate and normalise before the (potentially slow) model loading.
        self._weights: Dict[str, float] = self._normalise(model_weights)
        # name -> sklearn pipeline, or (model, tokenizer) pair for transformers
        self._loaded: Dict = {}
        # name -> "sklearn" | "transformer"
        self._kinds: Dict[str, str] = {}
        self._load_all()

    # -- Class methods --------------------------------------------------------

    @classmethod
    def from_optimal(cls, fallback_weights: Optional[List[Tuple[str, float]]] = None):
        """
        Build an Ensemble using weights from optimal_weights.json.

        If the file is missing, falls back to `fallback_weights` (or the
        module-level defaults).

        Parameters
        ----------
        fallback_weights : list of (model_name, weight) tuples, optional.
            Used when optimal_weights.json cannot be loaded.

        Returns
        -------
        Ensemble instance
        """
        if fallback_weights is None:
            fallback_weights = list(zip(_DEFAULT_MODELS, _DEFAULT_WEIGHTS))
        # Try loading the optimal weights file directly
        optimal = load_optimal_weights([name for name, _ in fallback_weights])
        if optimal is not None:
            weights = [(name, optimal[name]) for name, _ in fallback_weights]
        else:
            weights = fallback_weights
        # Pass use_optimal_weights=False to avoid double-loading
        return cls(weights, use_optimal_weights=False)

    # -- Internal helpers -----------------------------------------------------

    @staticmethod
    def _normalise(model_weights: List[Tuple[str, float]]) -> Dict[str, float]:
        """Validate weights and scale them to sum to 1 (last duplicate name wins)."""
        if not model_weights:
            raise ValueError("[Ensemble] model_weights must not be empty.")
        pairs = [(name, float(w)) for name, w in model_weights]
        if any(not np.isfinite(w) or w < 0 for _, w in pairs):
            raise ValueError(
                f"[Ensemble] Weights must be finite and non-negative, got {pairs}."
            )
        total = sum(w for _, w in pairs)
        if total <= 0:
            raise ValueError("[Ensemble] Weights must not all be zero.")
        return {name: w / total for name, w in pairs}

    @staticmethod
    def _label_dict(probs: np.ndarray) -> Dict[str, float]:
        """Map each label name to its probability, rounded to 4 decimals."""
        return {
            CFG.label_names[i]: round(float(probs[i]), 4)
            for i in range(CFG.num_labels)
        }

    def _load_all(self) -> None:
        """Load every weighted model, recording whether it is sklearn or transformer."""
        for name in self._weights:
            print(f" Loading: {name} ...")
            if name in ("lr", "svm"):
                self._loaded[name] = tm.load_model(name)
                self._kinds[name] = "sklearn"
            else:
                # Transformer: name is the directory under saved_models/
                self._loaded[name] = trm.load_model(name)
                self._kinds[name] = "transformer"
        print()

    def _proba(self, text: str, name: str) -> np.ndarray:
        """Return the class-probability vector of model *name* for *text*."""
        if self._kinds[name] == "sklearn":
            return _proba_sklearn(text, self._loaded[name])
        model, tokenizer = self._loaded[name]
        return _proba_transformer(text, model, tokenizer)

    # -- Public API -----------------------------------------------------------

    def predict(self, text: str) -> Dict:
        """
        Compute the weighted ensemble prediction for a single text.

        Returns predicted label, ensemble probabilities, and per-model
        debug info.

        Raises
        ------
        ValueError
            If a member model returns a probability vector whose shape is
            not ``(CFG.num_labels,)`` (instead of silently broadcasting).
        """
        n_labels = CFG.num_labels
        combined = np.zeros(n_labels, dtype=float)
        model_probs = {}
        for name, weight in self._weights.items():
            p = np.asarray(self._proba(text, name), dtype=float)
            if p.shape != (n_labels,):
                raise ValueError(
                    f"[Ensemble] Model '{name}' returned probabilities of "
                    f"shape {p.shape}; expected ({n_labels},)."
                )
            combined += weight * p
            model_probs[name] = self._label_dict(p)
        pred_id = int(np.argmax(combined))
        return {
            "text": text,
            "label_id": pred_id,
            "label": CFG.label_names[pred_id],
            "confidence": round(float(combined[pred_id]), 4),
            "ensemble_probabilities": self._label_dict(combined),
            "per_model": model_probs,
        }

    def weights(self) -> Dict[str, float]:
        """Return a copy of the normalised per-model weights (call it; not a property)."""
        return dict(self._weights)
| # -- Display ------------------------------------------------------------------ | |
def display(result: Dict) -> None:
    """
    Pretty-print a prediction dict produced by ``Ensemble.predict``.

    Shows the input (truncated to 88 characters with a trailing "..."),
    the predicted label and confidence, then every label's ensemble score
    as a 28-character text bar, highest score first.
    """
    text = result["text"]
    snippet = text[:88]
    if len(text) > 88:
        snippet += "..."

    print(f"\n Input : {snippet}")
    print(f" Label : [{result['label_id']}] {result['label']}")
    print(f" Confidence : {result['confidence']:.4f}")
    print(" Ensemble Scores:")

    ranked = sorted(
        result["ensemble_probabilities"].items(),
        key=lambda entry: entry[1],
        reverse=True,
    )
    for label, prob in ranked:
        bar = "#" * round(prob * 28)
        print(f" {label:<12} [{bar:<28}] {prob:.4f}")
    print()
| # -- CLI ---------------------------------------------------------------------- | |
def main() -> None:
    """
    Command-line entry point: parse arguments, build the ensemble, predict.

    Weight selection:
      * default       -- use --weights, overridden by optimal_weights.json
                         when that file has weights for every model;
      * --optimal     -- same as the default (explicit opt-in);
      * --no-optimal  -- use --weights only; optimal_weights.json is ignored.

    Prints help and exits with status 1 when neither --text nor
    --interactive is given; this check happens before any model is loaded.
    Invalid --weights (negative, non-finite or all zero) exit via
    ``parser.error`` (status 2).
    """
    parser = argparse.ArgumentParser(description="Ensemble Document Classifier")
    parser.add_argument(
        "--text", type=str, default=None, help="Single text to classify"
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Enter interactive prediction loop",
    )
    parser.add_argument(
        "--weights",
        nargs=3,
        type=float,
        default=_DEFAULT_WEIGHTS,
        metavar=("LR_W", "SVM_W", "DISTILBERT_W"),
        help="Weights for LR, SVM, DistilBERT (auto-normalised)",
    )
    parser.add_argument(
        "--optimal",
        action="store_true",
        help="Load weights from optimal_weights.json (ignores --weights)",
    )
    parser.add_argument(
        "--no-optimal",
        dest="optimal",
        action="store_false",
        help="Disable automatic loading of optimal weights",
    )
    # Tri-state: None = neither flag given, True = --optimal,
    # False = --no-optimal (both flags share one dest; the last one wins).
    parser.set_defaults(optimal=None)
    args = parser.parse_args()

    # Bail out before loading any (potentially large) models.
    if not args.interactive and not args.text:
        parser.print_help()
        sys.exit(1)

    if not all(np.isfinite(w) and w >= 0 for w in args.weights) or sum(args.weights) <= 0:
        parser.error("--weights must be finite, non-negative and not all zero")

    print("\n Building Ensemble ...")
    model_weights = list(zip(_DEFAULT_MODELS, args.weights))

    # --no-optimal skips optimal_weights.json entirely; --optimal and the
    # default both try it and fall back to --weights if it is unusable.
    use_optimal = args.optimal is not False
    ensemble = Ensemble(model_weights, use_optimal_weights=use_optimal)
    print(f" Ensemble ready. Active weights: {ensemble.weights()}\n")

    if args.interactive:
        print(" Ensemble -- Interactive Mode | Type 'q' to exit\n")
        while True:
            try:
                text = input(" >> ").strip()
            except (KeyboardInterrupt, EOFError):
                print("\n Bye.")
                break
            if not text:
                continue
            if text.lower() in {"q", "quit", "exit"}:
                print(" Bye.")
                break
            display(ensemble.predict(text))
    else:
        display(ensemble.predict(args.text))


if __name__ == "__main__":
    main()