Spaces:
Running
Running
| import joblib | |
| from pathlib import Path | |
| from typing import Optional, List, Dict, Union, Any, Literal | |
| import pandas as pd | |
| import numpy as np | |
| from sklearn.base import BaseEstimator, ClassifierMixin | |
| from sklearn.compose import ColumnTransformer | |
| from sklearn.preprocessing import StandardScaler, OneHotEncoder | |
| from sklearn.decomposition import TruncatedSVD | |
| from imblearn.over_sampling import SMOTE | |
| from imblearn.pipeline import Pipeline as ImbPipeline | |
| from sklearn.pipeline import Pipeline | |
| from sklearn.metrics import classification_report | |
| from sklearn.metrics import confusion_matrix | |
| from xgboost import XGBClassifier | |
| import optuna | |
| from optuna.samplers import QMCSampler | |
| from sklearn.metrics import accuracy_score, f1_score | |
| try: | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
| HAS_VISUALIZATION = True | |
| except ImportError: | |
| HAS_VISUALIZATION = False | |
| from .edge_features import extract_edge_features, get_edge_features | |
class GraphEdgeClassifier(BaseEstimator, ClassifierMixin):
    """
    Edge-level graph classifier for PROTACs with integrated pipeline building.

    Each input row describes one bond (edge) of a PROTAC molecule. Columns are
    routed through a ``ColumnTransformer`` (one-hot categorical features,
    optionally standardized graph/descriptor features, optionally
    SVD-compressed fingerprint bits), optionally re-balanced with SMOTE, and
    classified with XGBoost.

    Multiclass labels follow :meth:`build_multiclass_target`: with the default
    ids, 0 = no split, 1 = WH (POI-ligand) split, 2 = E3-binder split.

    Args:
        graph_features: Graph-derived columns ("standard" scaling per ``scaler_graph``).
        categorical_features: Columns one-hot encoded (unknown categories ignored).
        descriptor_features: Descriptor columns; used only when ``use_descriptors``.
        fingerprint_features: Fingerprint columns; used only when ``use_fingerprints``.
        scaler_graph, scaler_desc: "standard" applies StandardScaler; anything
            else passes the columns through unchanged.
        use_svd_fp: Compress fingerprint columns with TruncatedSVD(``n_svd_components``).
        binary: Use a binary (logistic) instead of a multiclass (softprob) objective.
        smote_k_neighbors: SMOTE ``k_neighbors``; ``None`` disables SMOTE.
        xgb_params: Extra XGBClassifier kwargs; they override the built-in defaults.
        n_bits, radius, descriptor_names: Feature-extraction settings used by the
            ``*_from_smiles`` helpers (``descriptor_names`` None/empty ->
            ``DEFAULT_DESCRIPTOR_NAMES``).
    """

    # Descriptor names used by the *_from_smiles helpers when descriptor_names is None/empty.
    DEFAULT_DESCRIPTOR_NAMES = (
        "MolWt", "HeavyAtomCount", "NumHAcceptors", "NumHDonors",
        "TPSA", "NumRotatableBonds", "RingCount", "MolLogP",
    )

    def __init__(
        self,
        graph_features: List[str],
        categorical_features: Optional[List[str]] = None,
        descriptor_features: Optional[List[str]] = None,
        fingerprint_features: Optional[List[str]] = None,
        use_descriptors: bool = True,
        use_fingerprints: bool = True,
        scaler_graph: Literal["passthrough", "standard"] = "passthrough",
        scaler_desc: Literal["passthrough", "standard"] = "passthrough",
        use_svd_fp: bool = True,
        n_svd_components: int = 100,
        binary: bool = False,
        smote_k_neighbors: Optional[int] = 5,
        xgb_params: Optional[dict] = None,
        n_bits: int = 512,
        radius: int = 6,
        descriptor_names: Optional[List[str]] = None,
    ):
        # Arguments are stored untouched: scikit-learn's get_params()/clone()
        # require __init__ not to replace them (e.g. `xgb_params or {}` made
        # clone() raise). Defaults are resolved where the values are used.
        self.graph_features = graph_features
        self.categorical_features = categorical_features
        self.descriptor_features = descriptor_features
        self.fingerprint_features = fingerprint_features
        self.use_descriptors = use_descriptors
        self.use_fingerprints = use_fingerprints
        self.scaler_graph = scaler_graph
        self.scaler_desc = scaler_desc
        self.use_svd_fp = use_svd_fp
        self.n_svd_components = n_svd_components
        self.binary = binary
        self.smote_k_neighbors = smote_k_neighbors
        self.xgb_params = xgb_params
        self.n_bits = n_bits
        self.radius = radius
        self.descriptor_names = descriptor_names
        # Built eagerly so `self.pipeline` exists right after construction;
        # fit() rebuilds it so that set_params() changes are honoured.
        self.pipeline = self._build_pipeline()

    def _resolved_descriptor_names(self) -> List[str]:
        """Descriptor names for feature extraction, falling back to the class defaults."""
        return list(self.descriptor_names or self.DEFAULT_DESCRIPTOR_NAMES)

    def _build_pipeline(self):
        """Assemble preprocessing, optional SMOTE and the XGBoost classifier."""
        transformers = []
        if self.categorical_features:
            transformers.append(
                ("cat", OneHotEncoder(handle_unknown="ignore"), self.categorical_features)
            )
        if self.graph_features:
            graph_step = StandardScaler() if self.scaler_graph == "standard" else "passthrough"
            transformers.append(("num", graph_step, self.graph_features))
        if self.use_descriptors and self.descriptor_features:
            desc_step = StandardScaler() if self.scaler_desc == "standard" else "passthrough"
            transformers.append(("desc", desc_step, self.descriptor_features))
        if self.use_fingerprints and self.fingerprint_features:
            if self.use_svd_fp:
                # Plain sklearn Pipeline: this sub-chain has no resampling step.
                fp_step = Pipeline([
                    ("svd", TruncatedSVD(n_components=self.n_svd_components, random_state=42)),
                ])
            else:
                fp_step = "passthrough"
            transformers.append(("fp", fp_step, self.fingerprint_features))
        preprocessor = ColumnTransformer(transformers)

        xgb_kwargs = {
            "random_state": 42,
            "eval_metric": "logloss" if self.binary else "mlogloss",
            "objective": "binary:logistic" if self.binary else "multi:softprob",
        }
        # User params override the defaults instead of raising a duplicate-keyword TypeError.
        xgb_kwargs.update(self.xgb_params or {})
        classifier = XGBClassifier(**xgb_kwargs)

        if self.smote_k_neighbors is not None:
            return ImbPipeline([
                ("preprocess", preprocessor),
                ("smote", SMOTE(random_state=42, k_neighbors=self.smote_k_neighbors)),
                ("clf", classifier),
            ])
        return Pipeline([
            ("preprocess", preprocessor),
            ("clf", classifier),
        ])

    def fit(self, X: Union[pd.DataFrame, List[Dict]], y: pd.Series) -> "GraphEdgeClassifier":
        """
        Fit the pipeline on ``X``/``y`` and return ``self``.

        ``X`` is validated and reduced to the used columns with ``_ensure_features``,
        the same way as at prediction time.
        """
        X_proc = self._ensure_features(X)
        self.pipeline = self._build_pipeline()
        self.pipeline.fit(X_proc, y)
        # Exposed for scikit-learn tooling that reads a fitted classifier's classes_.
        self.classes_ = self.pipeline.classes_
        return self

    def predict(self, X: Union[pd.DataFrame, List[Dict], List[str]]) -> Any:
        """Predict edge labels for a featurized DataFrame or a list of feature dicts."""
        X_proc = self._ensure_features(X)
        return self.pipeline.predict(X_proc)

    def predict_proba(self, X: Union[pd.DataFrame, List[Dict], List[str]]) -> Any:
        """Predict class probabilities for a featurized DataFrame or a list of feature dicts."""
        X_proc = self._ensure_features(X)
        return self.pipeline.predict_proba(X_proc)

    def save(self, path: Union[str, Path]):
        """Serialize the whole estimator (including its pipeline) to ``path`` with joblib."""
        joblib.dump(self, str(path))

    @classmethod
    def load(cls, path: Union[str, Path]) -> "GraphEdgeClassifier":
        """
        Load an estimator previously written by :meth:`save`.

        joblib.load unpickles arbitrary objects: only load trusted files.
        """
        return joblib.load(str(path))

    @staticmethod
    def extract_graph_features(
        protac_smiles: Union[str, List[str]],
        wh_smiles: Optional[Union[str, List[str]]] = None,
        lk_smiles: Optional[Union[str, List[str]]] = None,
        e3_smiles: Optional[Union[str, List[str]]] = None,
        n_bits: int = 512,
        radius: int = 6,
        descriptor_names: Optional[List[str]] = None,
        verbose: int = 0,
    ) -> pd.DataFrame:
        """
        Compute per-edge features for one or more PROTACs.

        If any of ``wh_smiles``/``lk_smiles``/``e3_smiles`` is None, only the
        PROTAC SMILES is featurized via ``extract_edge_features`` (inference;
        ``verbose`` is not forwarded). Otherwise ``get_edge_features`` is used
        with all components (training; per the original comments it also
        yields the labels).
        """
        if any(x is None for x in [wh_smiles, lk_smiles, e3_smiles]):
            return extract_edge_features(
                protac_smiles=protac_smiles,
                n_bits=n_bits,
                radius=radius,
                descriptor_names=descriptor_names,
            )
        return get_edge_features(
            protac_smiles=protac_smiles,
            wh_smiles=wh_smiles,
            lk_smiles=lk_smiles,
            e3_smiles=e3_smiles,
            n_bits=n_bits,
            radius=radius,
            descriptor_names=descriptor_names,
            verbose=verbose,
        )

    @staticmethod
    def build_multiclass_target(
        df: pd.DataFrame,
        poi_attachment_id: int = 1,
        e3_attachment_id: int = 2,
    ) -> pd.Series:
        """
        Return the multiclass target as int32.

        0 = no split, ``poi_attachment_id`` (default 1) = WH split,
        ``e3_attachment_id`` (default 2) = E3 split.

        Raises:
            ValueError: if an edge is flagged as both an E3 and a WH split
                (or the label sum is NaN).
        """
        # `~(... <= 1)` keeps the original semantics: NaN sums are also rejected.
        invalid = ~((df["label_e3_split"] + df["label_wh_split"]) <= 1)
        if invalid.any():
            raise ValueError(
                f"{int(invalid.sum())} edge(s) have label_e3_split + label_wh_split > 1 "
                "or missing labels; cannot build a single multiclass target."
            )
        y = (
            df["label_wh_split"] * poi_attachment_id
            + df["label_e3_split"] * e3_attachment_id
        )
        return y.astype("int32")

    def _ensure_features(self, X: Union[pd.DataFrame, List[Dict], List[str]]) -> pd.DataFrame:
        """
        Validate the input and keep only the columns the pipeline consumes.

        Descriptor/fingerprint columns are required only when their feature
        group is enabled.

        Raises:
            ValueError: for unsupported input (e.g. raw SMILES or an empty list)
                or when required columns are missing.
        """
        feature_groups = [self.graph_features, self.categorical_features]
        if self.use_descriptors:
            feature_groups.append(self.descriptor_features)
        if self.use_fingerprints:
            feature_groups.append(self.fingerprint_features)
        # dict.fromkeys de-duplicates while keeping first-seen order.
        required_columns = list(dict.fromkeys(
            col for group in feature_groups for col in (group or [])
        ))

        if isinstance(X, pd.DataFrame):
            Xf = X
        elif isinstance(X, list) and X and isinstance(X[0], dict):
            Xf = pd.DataFrame(X)
        else:
            raise ValueError("Provide either a DataFrame or list of feature dicts. Use extract_graph_features for SMILES.")
        missing = set(required_columns) - set(Xf.columns)
        if missing:
            raise ValueError(f"Input data missing required columns: {missing}")
        return Xf[required_columns].copy()

    def predict_proba_from_smiles(
        self,
        protac_smiles: Union[str, List[str]],
        wh_smiles: Optional[Union[str, List[str]]] = None,
        lk_smiles: Optional[Union[str, List[str]]] = None,
        e3_smiles: Optional[Union[str, List[str]]] = None,
        verbose: int = 0,
    ):
        """
        Featurize SMILES with this model's settings and return per-edge probabilities.

        Omitting any component SMILES uses the PROTAC-only extraction path of
        :meth:`extract_graph_features` (same as passing None explicitly).
        """
        features = self.extract_graph_features(
            protac_smiles, wh_smiles, lk_smiles, e3_smiles,
            n_bits=self.n_bits,
            radius=self.radius,
            descriptor_names=self._resolved_descriptor_names(),
            verbose=verbose,
        )
        Xf = self._ensure_features(features)
        return self.pipeline.predict_proba(Xf)

    def predict_from_smiles(
        self,
        protac_smiles: Union[str, List[str]],
        wh_smiles: Optional[Union[str, List[str]]] = None,
        lk_smiles: Optional[Union[str, List[str]]] = None,
        e3_smiles: Optional[Union[str, List[str]]] = None,
        top_n: int = 1,
        return_array: bool = True,
        verbose: int = 0,
    ) -> Union[pd.DataFrame, np.ndarray]:
        """
        Predict split bonds per molecule (grouped by the ``chem_mol_smiles`` column).

        With ``return_array=False`` the per-edge feature table is returned with
        ``pred_label`` and ``pred_proba`` (probability of class 1) columns added.

        Binary model (2 probability columns):
            For each molecule, the ``chem_bond_idx`` of the ``top_n`` edges with
            the highest class-1 probability (edges are NOT filtered by predicted
            label). Rows are padded with -1 when a molecule has fewer edges.
            ``top_n < 0`` returns all edges of each molecule, padded to the
            longest molecule.
        Multiclass model:
            Shape (num_molecules, 2): column 0 is the class-1 (WH split) edge and
            column 1 the class-2 (E3 split) edge, each chosen as the edge
            predicted as that class with the highest probability for that class;
            -1 when no edge is predicted as that class.

        NOTE(review): identical PROTAC SMILES in the input collapse into one row.
        """
        features = self.extract_graph_features(
            protac_smiles, wh_smiles, lk_smiles, e3_smiles,
            n_bits=self.n_bits,
            radius=self.radius,
            descriptor_names=self._resolved_descriptor_names(),
            verbose=verbose,
        )
        Xf = self._ensure_features(features)
        pred_proba = self.pipeline.predict_proba(Xf)
        pred_label = np.asarray(self.pipeline.predict(Xf))

        features = features.copy()
        features["pred_label"] = pred_label
        features["pred_proba"] = pred_proba[:, 1] if pred_proba.shape[1] > 1 else pred_proba[:, 0]
        if not return_array:
            return features

        # After reset_index, index labels equal row positions in pred_proba/pred_label.
        features = features.reset_index(drop=True)
        # The SMILES repeats for every edge: keep first-seen molecule order.
        unique_smiles = pd.Series(features["chem_mol_smiles"]).drop_duplicates().tolist()
        groupby = features.groupby("chem_mol_smiles")

        if pred_proba.shape[1] == 2:  # Binary case
            per_molecule = []
            for mol_smiles in unique_smiles:
                group = groupby.get_group(mol_smiles)
                # Per-molecule count: a negative top_n means "every edge of this molecule".
                n_take = len(group) if top_n < 0 else top_n
                top_edges = group.nlargest(n_take, "pred_proba")
                per_molecule.append(top_edges["chem_bond_idx"].to_numpy())
            width = top_n if top_n >= 0 else max((len(idxs) for idxs in per_molecule), default=0)
            if not per_molecule:
                return np.empty((0, width), dtype=int)
            return np.vstack([
                np.pad(idxs, (0, width - len(idxs)), constant_values=-1)
                for idxs in per_molecule
            ])

        # Multiclass case
        results = []
        for mol_smiles in unique_smiles:
            group = groupby.get_group(mol_smiles)
            positions = group.index.to_numpy()
            labels = pred_label[positions]
            bonds = group["chem_bond_idx"].to_numpy()
            row = []
            for cls in (1, 2):
                mask = labels == cls
                if mask.any():
                    # Rank by the probability of *this* class (not class 1).
                    cls_proba = pred_proba[positions[mask], cls]
                    row.append(bonds[mask][int(np.argmax(cls_proba))])
                else:
                    row.append(-1)
            results.append(row)
        return np.array(results, dtype=int).reshape(-1, 2)
def get_classification_report(y_true, y_pred, labels):
    """
    Build, print and return a classification report as a DataFrame.

    One row per class name in ``labels`` plus sklearn's summary rows; columns
    are sklearn's per-class metrics, rounded to two decimals.
    """
    report_dict = classification_report(
        y_true, y_pred, target_names=labels, output_dict=True
    )
    table = pd.DataFrame(report_dict).T.round(2)
    print(table)
    return table
def plot_confusion_matrix(y_true, y_pred, labels):
    """
    Show the confusion matrix as a seaborn heatmap.

    When seaborn/matplotlib are unavailable, print the raw matrix instead.
    Returns None.
    """
    matrix = confusion_matrix(y_true, y_pred)
    if not HAS_VISUALIZATION:
        print("Visualization libraries not available. Skipping confusion matrix plot.")
        print("Confusion Matrix:")
        print(matrix)
        return
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        matrix,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
    )
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()
def get_classification_report_and_plot(y_true, y_pred, labels):
    """Print and return the classification report, then show the confusion matrix."""
    report_df = get_classification_report(y_true, y_pred, labels)
    plot_confusion_matrix(y_true, y_pred, labels)
    return report_df
# Raw-input columns needed to extract edge features for one data split.
_EDGE_SMILES_COLUMNS = (
    "PROTAC SMILES",
    "POI Ligand SMILES with direction",
    "Linker SMILES with direction",
    "E3 Binder SMILES with direction",
)


def _drop_ambiguous_edges(feature_df: pd.DataFrame) -> pd.DataFrame:
    """Drop edges flagged as both an E3 split and a WH split (no single target)."""
    return feature_df[(feature_df["label_e3_split"] + feature_df["label_wh_split"]) <= 1]


def _load_or_extract_edge_features(
    set_name: str,
    df: Optional[pd.DataFrame],
    cache_dir: Optional[Union[str, Path]],
) -> Optional[pd.DataFrame]:
    """Return edge features for one split (from the CSV cache if present), or None if there is no data."""
    cache_path = Path(cache_dir) / f"{set_name}.csv" if cache_dir is not None else None
    if cache_path is not None:
        # Cache files are keyed only by split name: clear cache_dir when inputs change.
        if cache_path.exists():
            print(f"Loading cached features for {set_name} from {cache_path}")
            return pd.read_csv(cache_path)
        print(f"Cache not found for {set_name}, extracting features...")
    if df is None or df.empty:
        return None
    print(f"Set: {set_name}, size: {len(df):,}")
    if any(col not in df.columns for col in _EDGE_SMILES_COLUMNS):
        raise ValueError(f"DataFrame for {set_name} is missing required columns: 'PROTAC SMILES', 'POI Ligand SMILES with direction', 'Linker SMILES with direction', 'E3 Binder SMILES with direction'.")
    # NOTE(review): extraction uses extract_graph_features' own defaults for
    # n_bits/radius/descriptor_names, not values from edge_classifier_kwargs;
    # confirm they match the settings used at inference time.
    features = GraphEdgeClassifier.extract_graph_features(
        *(df[col].tolist() for col in _EDGE_SMILES_COLUMNS),
        verbose=1,
    )
    features = _drop_ambiguous_edges(features)
    print(f"Set: {set_name}, size: {len(features):,}")
    if cache_path is not None:
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        features.to_csv(cache_path, index=False)
        print(f"Saved {set_name} features to {cache_path}")
    return features


def _resolve_edge_classifier_kwargs(
    edge_classifier_kwargs: Optional[Dict[str, Any]],
    X_train: pd.DataFrame,
) -> Dict[str, Any]:
    """Build GraphEdgeClassifier kwargs, inferring feature-column lists that are missing or None."""
    inferred = {
        "graph_features": [c for c in X_train.columns if c.startswith("graph_")],
        "categorical_features": ["chem_bond_type", "chem_atom_u", "chem_atom_v"],
        "fingerprint_features": [c for c in X_train.columns if c.startswith("chem_mol_fp_")],
    }
    if not edge_classifier_kwargs:
        return {
            **inferred,
            "use_descriptors": False,
            "use_fingerprints": True,
            "n_svd_components": 50,
            "binary": True,
            "smote_k_neighbors": 10,
            "xgb_params": {
                "max_depth": 6,
                "learning_rate": 0.3,
                "alpha": 0.1,  # Default: 0
                "lambda": 0.5,  # Default: 1
                "gamma": 0.1,  # Default: 0
            },
        }
    resolved = dict(edge_classifier_kwargs)
    for key, value in inferred.items():
        # None/missing means "infer from the data"; pass [] to disable a group.
        if resolved.get(key) is None:
            resolved[key] = value
    return resolved


def _evaluate_edge_split(
    clf: "GraphEdgeClassifier",
    feature_df: pd.DataFrame,
    label_cols: List[str],
    target_labels: List[str],
    plot: bool,
    title: str,
) -> pd.DataFrame:
    """Predict on one labelled split, print and return its classification report."""
    split_df = _drop_ambiguous_edges(feature_df.dropna(subset=label_cols))
    X_eval = split_df.drop(columns=label_cols)
    if clf.binary:
        y_eval = split_df["label_is_split"].astype("int32")
    else:
        y_eval = GraphEdgeClassifier.build_multiclass_target(split_df)
    y_pred = clf.predict(X_eval)
    if plot:
        report = get_classification_report_and_plot(y_eval, y_pred, target_labels)
    else:
        report = get_classification_report(y_eval, y_pred, target_labels)
    # to_string keeps the class names (row index) and needs no optional `tabulate`.
    print(f"{title} set classification report:\n{report.to_string()}")
    return report


def train_edge_classifier(
    train_df: pd.DataFrame,
    val_df: Optional[pd.DataFrame] = None,
    test_df: Optional[pd.DataFrame] = None,
    model_filename: Optional[Union[str, Path]] = None,
    edge_classifier_kwargs: Optional[Dict[str, Any]] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    return_reports: bool = True,
    plot_confusion_matrix: bool = False,
) -> "Union[GraphEdgeClassifier, tuple]":
    """
    Train an edge-level graph classifier for PROTACs.

    Args:
        train_df (pd.DataFrame): Training data with columns:
            - 'PROTAC SMILES'
            - 'POI Ligand SMILES with direction'
            - 'Linker SMILES with direction'
            - 'E3 Binder SMILES with direction'
        val_df (Optional[pd.DataFrame]): Validation data, same format as train_df.
        test_df (Optional[pd.DataFrame]): Test data, same format as train_df.
        model_filename (Optional[Union[str, Path]]): Path to save the trained model.
        edge_classifier_kwargs (Optional[Dict[str, Any]]): Parameters for
            GraphEdgeClassifier. If empty/None, built-in defaults are used.
            Otherwise, ``graph_features``, ``categorical_features`` and
            ``fingerprint_features`` that are missing or None are inferred from
            the training columns (pass ``[]`` to disable a group).
        cache_dir (Optional[Union[str, Path]]): Directory for per-split feature
            CSVs (``train.csv``, ``val.csv``, ``test.csv``); an existing file
            is loaded instead of re-extracting features.
        return_reports (bool): If True, return ``(clf, report)``. ``report``
            is the test report when a test split exists, else the validation
            report, else None.
        plot_confusion_matrix (bool): Also plot a confusion matrix per evaluated split.

    Returns:
        GraphEdgeClassifier, or ``(GraphEdgeClassifier, report)`` when ``return_reports``.

    Raises:
        ValueError: If no training features are available or required columns are missing.
    """
    sets = {}
    for set_name, df in (("train", train_df), ("val", val_df), ("test", test_df)):
        features = _load_or_extract_edge_features(set_name, df, cache_dir)
        if features is not None:
            sets[set_name] = features
    if "train" not in sets:
        raise ValueError("No training data: train_df is empty and no cached training features were found.")

    train_set = sets["train"]
    label_cols = [c for c in train_set.columns if c.startswith("label_")]
    train_set = _drop_ambiguous_edges(train_set.dropna(subset=label_cols))
    X_train = train_set.drop(columns=label_cols)

    clf = GraphEdgeClassifier(**_resolve_edge_classifier_kwargs(edge_classifier_kwargs, X_train))

    # Prepare target variable according to classification type
    if clf.binary:
        y_train = train_set["label_is_split"].astype("int32")
    else:
        y_train = GraphEdgeClassifier.build_multiclass_target(train_set)

    print(f"Training set size: {len(X_train):,}, labels: {y_train.unique()}")
    clf.fit(X_train, y_train)
    print("Training complete.")
    if model_filename is not None:
        clf.save(model_filename)
        print(f"Model saved to {model_filename}")

    # Order matches the integer targets (multiclass: 1 = WH split, 2 = E3 split).
    target_labels = ["No Split", "Split"] if clf.binary else ["No Split", "WH-Linker", "E3-Linker"]
    report = None
    if "val" in sets:
        report = _evaluate_edge_split(
            clf, sets["val"], label_cols, target_labels, plot_confusion_matrix, "Validation"
        )
    if "test" in sets:
        report = _evaluate_edge_split(
            clf, sets["test"], label_cols, target_labels, plot_confusion_matrix, "Test"
        )

    if return_reports:
        return clf, report
    return clf
def _report_value(report: Optional[pd.DataFrame], row: str, column: str) -> float:
    """Read one cell of a get_classification_report() table; 0.0 if the report or the cell is absent."""
    if report is None:
        return 0.0
    try:
        return float(report.loc[row, column])
    except KeyError:
        return 0.0


def objective(trial, train_df, val_df):
    """
    Optuna objective for the binary edge classifier.

    Samples XGBoost/SVD/SMOTE/feature-group hyper-parameters, trains via
    train_edge_classifier and returns ``0.5 * accuracy + 0.5 * F1("Split")``
    on the validation split. The report is indexed by class name (plus an
    "accuracy" row), so cells are read with ``.loc[row, column]``; the
    previous lookup of a non-existent "Label" column always fell back to 0.0.
    Report values are rounded to 2 decimals by get_classification_report.
    """
    # HP space (suggest order kept stable for reproducible sampling)
    max_depth = trial.suggest_int("max_depth", 3, 10)
    learning_rate = trial.suggest_float("learning_rate", 0.01, 0.3, log=True)
    alpha = trial.suggest_float("alpha", 0.0, 2.0)
    reg_lambda = trial.suggest_float("lambda", 0.0, 2.0)
    gamma = trial.suggest_float("gamma", 0.0, 1.0)
    n_svd_components = trial.suggest_int("n_svd_components", 16, 128)
    smote_k_neighbors = trial.suggest_int("smote_k_neighbors", 3, 15)
    use_descriptors = trial.suggest_categorical("use_descriptors", [False, True])
    use_fingerprints = trial.suggest_categorical("use_fingerprints", [True, False])
    edge_classifier_kwargs = {
        "graph_features": None,  # Will be set in train_edge_classifier
        "categorical_features": None,
        "fingerprint_features": None,
        "use_descriptors": use_descriptors,
        "use_fingerprints": use_fingerprints,
        "n_svd_components": n_svd_components,
        "binary": True,
        "smote_k_neighbors": smote_k_neighbors,
        "xgb_params": {
            "max_depth": max_depth,
            "learning_rate": learning_rate,
            "alpha": alpha,
            "lambda": reg_lambda,
            "gamma": gamma,
        },
    }
    _, val_report = train_edge_classifier(
        train_df=train_df,
        val_df=val_df,
        edge_classifier_kwargs=edge_classifier_kwargs,
        return_reports=True,
    )
    # "Split" is the positive-class row name used for binary models.
    f1_split = _report_value(val_report, "Split", "f1-score")
    # The "accuracy" row carries the overall accuracy in every metric column.
    acc = _report_value(val_report, "accuracy", "f1-score")
    # Multi-objective: prioritize F1 for minority class, but keep accuracy
    # Adjust weight depending on task (here equal)
    return 0.5 * acc + 0.5 * f1_split
def run_optuna_search(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    n_trials: int = 50,
    study_name: str = "edge_classifier_hp_search",
    study_dir: str = "./optuna_studies",
    seed: int = 42,
) -> Any:
    """
    Run an Optuna hyper-parameter search, then retrain with the best trial.

    The study is persisted in ``<study_dir>/<study_name>.db`` (SQLite,
    resumed if it already exists) and additionally dumped with joblib to
    ``<study_dir>/<study_name>_study.pkl``.

    Returns:
        ``(clf, study)``: the classifier retrained with the best parameters
        and the Optuna study.
    """
    import os
    import joblib

    os.makedirs(study_dir, exist_ok=True)
    storage_url = f"sqlite:///{os.path.join(study_dir, study_name)}.db"
    study = optuna.create_study(
        study_name=study_name,
        direction="maximize",
        sampler=QMCSampler(seed=seed, qmc_type="sobol"),
        storage=storage_url,
        load_if_exists=True,
    )

    def _run_trial(trial):
        return objective(trial, train_df, val_df)

    study.optimize(_run_trial, n_trials=n_trials, show_progress_bar=True)

    best = study.best_trial
    print("Best trial:")
    print(best)

    # Train classifier with best HP and return it
    params = best.params
    xgb_keys = ("max_depth", "learning_rate", "alpha", "lambda", "gamma")
    edge_classifier_kwargs = {
        "graph_features": None,
        "categorical_features": None,
        "fingerprint_features": None,
        "use_descriptors": params["use_descriptors"],
        "use_fingerprints": params["use_fingerprints"],
        "n_svd_components": params["n_svd_components"],
        "binary": True,
        "smote_k_neighbors": params["smote_k_neighbors"],
        "xgb_params": {key: params[key] for key in xgb_keys},
    }
    best_clf, _ = train_edge_classifier(
        train_df=train_df,
        val_df=val_df,
        edge_classifier_kwargs=edge_classifier_kwargs,
        return_reports=True,
    )

    pickle_path = os.path.join(study_dir, f"{study_name}_study.pkl")
    joblib.dump(study, pickle_path)
    print(f"Optuna study saved to {pickle_path}")
    return best_clf, study