Spaces:
Running
Running
| import argparse | |
| import json | |
| import os | |
| import warnings | |
| from pathlib import Path | |
| from typing import Any, Dict, Literal, Tuple | |
| import joblib | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import optuna | |
| import seaborn as sns | |
| from sklearn.exceptions import ConvergenceWarning | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.metrics import accuracy_score, confusion_matrix, f1_score | |
| from sklearn.pipeline import Pipeline | |
| from sklearn.svm import LinearSVC | |
| from tqdm import tqdm | |
| from config import CFG | |
| from data_loader import get_raw_splits, load_ag_news | |
# Identifier of a supported classifier: "lr" (logistic regression) or "svm" (linear SVM).
ModelType = Literal["lr", "svm"]

# Cap on the number of training examples used when the search is not run with --full.
DEFAULT_MAX_TRAIN = 40000

# Reference test accuracy per model; the summary table reports the gain against it.
BASELINE_TEST_ACCURACY: Dict[ModelType, float] = dict(lr=0.9045, svm=0.9089)
def _storage_url() -> str:
    """Return the SQLite URL of the Optuna study database.

    The database file is ``optuna_studies.db`` inside ``CFG.outputs_dir``.
    The directory is created when missing: SQLite creates the database file
    on first connection, but not its parent directories, so opening the study
    on a fresh checkout would otherwise fail.

    Returns:
        A ``sqlite:///<absolute path>`` URL that uses forward slashes.
    """
    os.makedirs(CFG.outputs_dir, exist_ok=True)
    db_path = Path(os.path.abspath(os.path.join(CFG.outputs_dir, "optuna_studies.db")))
    # as_posix() keeps the URL well-formed on platforms whose native
    # path separator is a backslash.
    return f"sqlite:///{db_path.as_posix()}"
def _hp_outputs_dir() -> str:
    """Return the ``hp_search`` folder under ``CFG.outputs_dir``, creating it on demand."""
    hp_dir = os.path.join(CFG.outputs_dir, "hp_search")
    # exist_ok makes repeated calls harmless.
    os.makedirs(hp_dir, exist_ok=True)
    return hp_dir
def _trial_params(trial: optuna.trial.Trial, model_type: ModelType) -> Dict[str, Any]:
    """Sample one hyperparameter configuration from ``trial``.

    The TF-IDF settings are sampled for every model; the classifier settings
    depend on ``model_type`` ("lr" additionally samples a solver and uses a
    different range for ``C`` than the other model).

    Args:
        trial: Optuna trial used to draw the values.
        model_type: Which classifier the configuration is for.

    Returns:
        Mapping from parameter name to the sampled value.
    """
    # A dict display evaluates its values top to bottom, so the suggest calls
    # are issued in exactly the order the keys are listed.
    sampled: Dict[str, Any] = {
        "max_features": trial.suggest_int("max_features", 20_000, 100_000, step=10_000),
        # Fixed at 1: n-grams always start at unigrams.
        "ngram_min": trial.suggest_int("ngram_min", 1, 1),
        "ngram_max": trial.suggest_int("ngram_max", 1, 3),
        "min_df": trial.suggest_int("min_df", 1, 5),
        "sublinear_tf": trial.suggest_categorical("sublinear_tf", [True, False]),
    }

    if model_type != "lr":
        sampled["C"] = trial.suggest_float("C", 0.01, 10.0, log=True)
        return sampled

    sampled["C"] = trial.suggest_float("C", 0.1, 20.0, log=True)
    sampled["solver"] = trial.suggest_categorical("solver", ["saga", "lbfgs"])
    return sampled
def _build_pipeline(model_type: ModelType, params: Dict[str, Any]) -> Pipeline:
    """Assemble a TF-IDF + linear classifier pipeline from a parameter mapping.

    Args:
        model_type: "lr" builds a ``LogisticRegression``; anything else builds
            a ``LinearSVC``.
        params: Must contain ``max_features``, ``ngram_min``, ``ngram_max``,
            ``sublinear_tf``, ``min_df`` and ``C``; ``solver`` is also read
            when ``model_type`` is "lr". Values are coerced to the expected
            Python types before use.

    Returns:
        An unfitted two-step pipeline: ``"tfidf"`` followed by a step named
        after ``model_type``.
    """
    vectorizer_options = dict(
        max_features=int(params["max_features"]),
        ngram_range=(int(params["ngram_min"]), int(params["ngram_max"])),
        sublinear_tf=bool(params["sublinear_tf"]),
        min_df=int(params["min_df"]),
        strip_accents="unicode",
        analyzer="word",
        # Unlike the scikit-learn default, single-character tokens are kept.
        token_pattern=r"\w{1,}",
        dtype=np.float32,
    )
    vectorizer = TfidfVectorizer(**vectorizer_options)

    regularization = float(params["C"])
    if model_type == "lr":
        classifier = LogisticRegression(
            C=regularization,
            solver=str(params["solver"]),
            max_iter=2_000,
            n_jobs=-1,
            random_state=CFG.seed,
        )
    else:
        classifier = LinearSVC(
            C=regularization,
            max_iter=3_000,
            random_state=CFG.seed,
        )

    steps = [("tfidf", vectorizer), (model_type, classifier)]
    return Pipeline(steps)
def _make_objective(model_type: ModelType, X_train, y_train, X_val, y_val):
    """Create the Optuna objective for one model type.

    Args:
        model_type: Which classifier the trials should tune.
        X_train: Training inputs handed to ``Pipeline.fit``.
        y_train: Training labels.
        X_val: Validation inputs used for scoring.
        y_val: Validation labels.

    Returns:
        A callable that takes a trial and returns the macro-averaged F1 score
        of the sampled configuration on the validation data.
    """

    def objective(trial: optuna.trial.Trial) -> float:
        """Fit one sampled configuration and score it on the validation data."""
        candidate = _build_pipeline(model_type, _trial_params(trial, model_type))

        with warnings.catch_warnings():
            # Keep the search output readable: these categories are silenced
            # only while the candidate is being fitted.
            for noisy in (ConvergenceWarning, FutureWarning):
                warnings.filterwarnings("ignore", category=noisy)
            candidate.fit(X_train, y_train)

        macro_f1 = f1_score(y_val, candidate.predict(X_val), average="macro")
        return float(macro_f1)

    return objective
def _save_confusion_matrix(
    cm,
    title: str,
    save_path: str,
) -> None:
    """Render a confusion matrix as an annotated heatmap and write it to disk.

    Args:
        cm: Matrix of integer counts (cells are rendered with the ``"d"``
            format). Both axes are labelled with ``CFG.label_names``.
        title: Title drawn above the heatmap.
        save_path: Destination image path. Missing parent directories are
            created; a bare file name is written to the current directory.
    """
    # os.path.dirname() is "" for a bare file name and os.makedirs("") raises,
    # so only create the parent when there is one.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=CFG.label_names,
            yticklabels=CFG.label_names,
            linewidths=0.5,
            ax=ax,
        )
        ax.set_xlabel("Predicted Label", fontsize=11)
        ax.set_ylabel("True Label", fontsize=11)
        ax.set_title(title, fontsize=13, fontweight="bold")
        # Operate on this figure explicitly instead of pyplot's implicit
        # "current figure".
        fig.tight_layout()
        fig.savefig(save_path, dpi=150)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
def _save_optuna_plots(study: optuna.Study, model_type: ModelType) -> Tuple[str, str]:
    """Write the parallel-coordinate and parameter-importance plots of a study.

    Both images are stored in the hyperparameter-search output folder and are
    prefixed with ``model_type``.

    Args:
        study: Study whose trials are visualised.
        model_type: Used as the file-name prefix.

    Returns:
        ``(parallel_coords_path, param_importance_path)``. The second entry is
        an empty string when producing the importance plot raises
        ``ValueError``.
    """
    target_dir = _hp_outputs_dir()
    import optuna.visualization.matplotlib as ovm

    silenced = (optuna.exceptions.ExperimentalWarning, UserWarning)

    coords_path = os.path.join(target_dir, f"{model_type}_parallel_coords.png")
    with warnings.catch_warnings():
        for category in silenced:
            warnings.filterwarnings("ignore", category=category)
        coords_fig = ovm.plot_parallel_coordinate(study).figure
        coords_fig.tight_layout()
        coords_fig.savefig(coords_path, dpi=150)
        plt.close(coords_fig)

    importance_path = os.path.join(target_dir, f"{model_type}_param_importance.png")
    try:
        with warnings.catch_warnings():
            for category in silenced:
                warnings.filterwarnings("ignore", category=category)
            importance_fig = ovm.plot_param_importances(study).figure
            importance_fig.tight_layout()
            importance_fig.savefig(importance_path, dpi=150)
            plt.close(importance_fig)
    except ValueError:
        # The importance plot is optional; signal its absence with "".
        importance_path = ""

    return coords_path, importance_path
def _create_or_reset_study(study_name: str, storage: str, resume: bool) -> optuna.Study:
    """Open the named maximisation study, wiping it first unless resuming.

    Args:
        study_name: Name of the study inside ``storage``.
        storage: Optuna storage URL.
        resume: When true, an existing study is reused as is; otherwise it is
            deleted before being (re)created.

    Returns:
        The study, using a TPE sampler seeded with ``CFG.seed``.
    """
    tpe_sampler = optuna.samplers.TPESampler(seed=CFG.seed)

    if not resume:
        try:
            optuna.delete_study(study_name=study_name, storage=storage)
        except KeyError:
            # Nothing stored under this name yet, so there is nothing to reset.
            pass

    study_options = dict(
        direction="maximize",
        study_name=study_name,
        storage=storage,
        load_if_exists=True,
        sampler=tpe_sampler,
    )
    return optuna.create_study(**study_options)
def _run_model_search(
    model_type: ModelType,
    n_trials_total: int,
    full: bool,
    resume: bool,
) -> Dict[str, Any]:
    """Tune, refit, evaluate and persist one model type.

    The function loads the data, runs (or resumes) the Optuna study until it
    holds ``n_trials_total`` completed trials, refits a pipeline with the best
    parameters, scores it on the test split, and writes the confusion matrix,
    the fitted pipeline and the Optuna plots to disk.

    Args:
        model_type: Which classifier to tune.
        n_trials_total: Target number of completed trials for the study,
            including trials already stored when resuming.
        full: When true, the training set is not capped and the final model is
            fitted on train + validation; otherwise training is capped at
            ``DEFAULT_MAX_TRAIN`` examples and the final model is fitted on
            the training split only.
        resume: Whether to continue the stored study instead of resetting it.

    Returns:
        A JSON-serialisable summary with the best parameters, the best
        validation macro-F1, the test accuracy and the paths of all artifacts.

    Raises:
        ValueError: If the study has no completed trial and the trial budget
            leaves nothing to run, so no best parameters exist.
    """
    max_train = None if full else DEFAULT_MAX_TRAIN
    dataset = load_ag_news(max_train=max_train, max_eval=None, max_test=None)
    X_train, y_train, X_val, y_val, X_test, y_test = get_raw_splits(dataset)

    storage = _storage_url()
    study_name = f"{model_type}_hyperparams"
    study = _create_or_reset_study(study_name=study_name, storage=storage, resume=resume)

    # Only completed trials count towards the budget. Trials left in a failed
    # or stale running state by an interrupted run carry no result, and
    # counting them would silently shrink a resumed search.
    completed = study.get_trials(
        deepcopy=False,
        states=(optuna.trial.TrialState.COMPLETE,),
    )
    remaining = max(0, int(n_trials_total) - len(completed))

    if remaining == 0 and not completed:
        raise ValueError(
            f"Study '{study_name}' has no completed trials and n_trials_total="
            f"{n_trials_total} leaves nothing to run."
        )

    if remaining > 0:
        objective = _make_objective(model_type, X_train, y_train, X_val, y_val)
        # The context manager closes the bar even if a trial raises.
        with tqdm(total=remaining, desc=f"{model_type.upper()} Optuna", unit="trial") as pbar:

            def _cb(_study: optuna.Study, _trial: optuna.trial.FrozenTrial) -> None:
                pbar.update(1)

            study.optimize(
                objective,
                n_trials=remaining,
                callbacks=[_cb],
                gc_after_trial=True,
                show_progress_bar=False,
            )

    best_params = study.best_params
    best_val_f1 = float(study.best_value)

    pipeline = _build_pipeline(model_type, best_params)
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ConvergenceWarning)
        warnings.filterwarnings("ignore", category=FutureWarning)
        if full:
            # The validation split is no longer needed for model selection,
            # so it is folded into the final training data.
            X_final = list(X_train) + list(X_val)
            y_final = list(y_train) + list(y_val)
            pipeline.fit(X_final, y_final)
        else:
            pipeline.fit(X_train, y_train)

    test_preds = pipeline.predict(X_test)
    test_acc = float(accuracy_score(y_test, test_preds))
    cm = confusion_matrix(y_test, test_preds)

    out_dir = _hp_outputs_dir()
    cm_path = os.path.join(out_dir, f"{model_type}_confusion_matrix.png")
    _save_confusion_matrix(
        cm,
        title=f"Optimized {model_type.upper()} — Confusion Matrix",
        save_path=cm_path,
    )

    # joblib.dump does not create missing directories.
    os.makedirs(CFG.models_dir, exist_ok=True)
    model_path = os.path.join(CFG.models_dir, f"traditional_{model_type}_optimized.joblib")
    joblib.dump(pipeline, model_path)

    p1, p2 = _save_optuna_plots(study, model_type=model_type)

    return {
        "model": model_type,
        "study_name": study_name,
        "storage": storage,
        "max_train": max_train,
        "best_val_f1_macro": best_val_f1,
        "best_params": best_params,
        "test_accuracy": test_acc,
        "confusion_matrix_path": cm_path,
        "model_path": model_path,
        "plot_parallel_coords_path": p1,
        "plot_param_importance_path": p2,
    }
def _print_summary(results: Dict[ModelType, Dict[str, Any]]) -> None:
    """Print a fixed-width table summarising the search results.

    Each row shows the model, its best validation macro-F1, its best
    parameters (as JSON, shortened to at most 110 characters) and the change
    in test accuracy versus ``BASELINE_TEST_ACCURACY`` in percentage points.

    Args:
        results: Per-model result mappings; each must provide
            ``test_accuracy``, ``best_val_f1_macro`` and ``best_params``.
            An empty mapping prints the header and separator only.
    """

    def _fmt_params(d: Dict[str, Any]) -> str:
        """Serialise ``d`` with sorted keys, truncating long output with "..."."""
        s = json.dumps(d, sort_keys=True)
        return s if len(s) <= 110 else s[:107] + "..."

    headers = ["Model", "Best Val F1", "Best Params", "Improvement vs Baseline"]

    rows = []
    for model_type, r in results.items():
        baseline = BASELINE_TEST_ACCURACY[model_type]
        improvement_pp = (float(r["test_accuracy"]) - baseline) * 100.0
        rows.append(
            [
                model_type.upper(),
                f"{float(r['best_val_f1_macro']):.4f}",
                _fmt_params(r["best_params"]),
                f"{improvement_pp:+.2f} pp",
            ]
        )

    # default=0 keeps the width computation valid when there are no rows;
    # max() over an empty sequence would otherwise raise ValueError.
    col_widths = [
        max(len(header), max((len(str(row[i])) for row in rows), default=0))
        for i, header in enumerate(headers)
    ]
    fmt = " | ".join(f"{{:<{w}}}" for w in col_widths)
    sep = "-+-".join("-" * w for w in col_widths)

    print("\n" + fmt.format(*headers))
    print(sep)
    for row in rows:
        print(fmt.format(*row))
def main() -> None:
    """Command-line entry point for the hyperparameter search.

    Parses the arguments, runs the search for each selected model, prints the
    summary table and stores all results as ``best_params.json`` in the
    hyperparameter-search output folder.
    """
    parser = argparse.ArgumentParser(description="Optuna hyperparameter search for TF-IDF + LR/SVM.")
    parser.add_argument(
        "--model",
        choices=["lr", "svm", "all"],
        default="all",
        help="Which model study to run.",
    )
    parser.add_argument(
        "--n-trials",
        type=int,
        default=30,
        help="Total trials per model study (respects --resume).",
    )
    parser.add_argument(
        "--full",
        action="store_true",
        help="Use full 120K training examples (much slower per trial).",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume from the SQLite study DB (otherwise, resets the study).",
    )
    args = parser.parse_args()

    # Per-trial INFO logs would drown out the progress bar.
    optuna.logging.set_verbosity(optuna.logging.WARNING)

    if args.full:
        train_note = "FULL (120K)"
    else:
        train_note = f"CAPPED ({DEFAULT_MAX_TRAIN:,} max_train for i3 CPU)"
    print(
        f"[HP Search] Train size: {train_note}. "
        f"Override with --full to use the complete dataset."
    )
    print(f"[HP Search] Storage: {_storage_url()}")

    results: Dict[ModelType, Dict[str, Any]] = {}
    # Fixed order: logistic regression first, then the SVM.
    for candidate in ("lr", "svm"):
        if args.model not in (candidate, "all"):
            continue
        results[candidate] = _run_model_search(
            model_type=candidate,
            n_trials_total=args.n_trials,
            full=args.full,
            resume=args.resume,
        )

    _print_summary(results)

    best_params_path = os.path.join(_hp_outputs_dir(), "best_params.json")
    with open(best_params_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\n[HP Search] Best params -> {best_params_path}")


if __name__ == "__main__":
    main()