File size: 6,521 Bytes
ec17199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46411f2
ec17199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46411f2
ec17199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87c739b
 
 
 
 
46411f2
 
 
 
 
87c739b
 
 
ec17199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46411f2
ec17199
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import itertools as it
import os

import joblib
import numpy as np
import pandas as pd
import pkg_resources
import streamlit as st
from b3clf.descriptor_padel import compute_descriptors
from b3clf.geometry_opt import geometry_optimize
from b3clf.utils import get_descriptors, scale_descriptors, select_descriptors


@st.cache_resource
def load_all_models():
    """Load every fitted b3clf classifier bundled with the package.

    Returns
    -------
    dict
        Maps ``"<classifier>_<sampling>"`` (e.g. ``"xgb_classic_ADASYN"``)
        to the fitted estimator deserialized from the package's
        ``pre_trained`` joblib resources.
    """
    clf_list = ["dtree", "knn", "logreg", "xgb"]
    sampling_list = [
        "borderline_SMOTE",
        "classic_ADASYN",
        "classic_RandUndersampling",
        "classic_SMOTE",
        "kmeans_SMOTE",
        "common",
    ]

    model_dict = {}
    package_name = "b3clf"

    # Eagerly load every classifier/sampling combination so the Streamlit
    # resource cache holds one shared dict for the whole session.
    for clf_str, sampling_str in it.product(clf_list, sampling_list):
        joblib_path_str = f"pre_trained/b3clf_{clf_str}_{sampling_str}.joblib"
        with pkg_resources.resource_stream(package_name, joblib_path_str) as f:
            model_dict[f"{clf_str}_{sampling_str}"] = joblib.load(f)

    return model_dict


@st.cache_resource
def predict_permeability(
    clf_str, sampling_str, _models_dict, mol_features, info_df, threshold="none"
):
    """Compute permeability predictions for pre-computed feature data.

    Parameters
    ----------
    clf_str : str
        Classifier key, e.g. ``"xgb"``.
    sampling_str : str
        Resampling-strategy key, e.g. ``"classic_ADASYN"``.
    _models_dict : dict
        Output of ``load_all_models``; the leading underscore tells
        Streamlit not to hash it for caching.
    mol_features : pd.DataFrame or array-like
        Scaled descriptor matrix, one row per molecule.
    info_df : pd.DataFrame
        Molecule metadata indexed identically to ``mol_features``;
        mutated in place with prediction columns and then index-reset.
    threshold : str
        Column name in the bundled thresholds spreadsheet (default
        ``"none"``).

    Returns
    -------
    pd.DataFrame
        ``info_df`` with ``B3clf_predicted_probability`` and
        ``B3clf_predicted_label`` columns added.

    Raises
    ------
    ValueError
        If ``mol_features`` is a DataFrame whose index disagrees with
        ``info_df``.
    """
    # look up the requested fitted model
    pred_model = _models_dict[clf_str + "_" + sampling_str]

    # load the per-model decision-threshold table shipped with the package
    package_name = "b3clf"
    with pkg_resources.resource_stream(package_name, "data/B3clf_thresholds.xlsx") as f:
        df_thres = pd.read_excel(f, index_col=0, engine="openpyxl")

    # default label is 0; flipped to 1 below where probability clears the threshold
    label_pool = np.zeros(mol_features.shape[0], dtype=int)

    if isinstance(mol_features, pd.DataFrame):
        if mol_features.index.tolist() != info_df.index.tolist():
            raise ValueError("mol_features and Info_df do not have the same index.")

    # probability of the positive (permeable) class
    info_df.loc[:, "B3clf_predicted_probability"] = pred_model.predict_proba(
        mol_features
    )[:, 1]
    # NOTE(review): the threshold row is hard-coded to "xgb-classic_ADASYN"
    # for every model instead of f"{clf_str}-{sampling_str}" (see the
    # commented-out original). Confirm this is intentional — e.g. that only
    # this row is populated — before parameterizing it.
    mask = np.greater_equal(
        info_df["B3clf_predicted_probability"].to_numpy(),
        # df_thres.loc[clf_str + "-" + sampling_str, threshold])
        df_thres.loc["xgb-classic_ADASYN", threshold],
    )
    label_pool[mask] = 1

    # persist the thresholded labels and flatten the index for display
    info_df["B3clf_predicted_label"] = label_pool
    info_df.reset_index(inplace=True)

    return info_df


@st.cache_resource
def generate_predictions(
    input_fname: str = None,
    sep: str = r"\s+|\t+",
    clf: str = "xgb",
    _models_dict: dict = None,
    keep_sdf: str = "no",
    sampling: str = "classic_ADASYN",
    time_per_mol: int = 120,
    mol_features: pd.DataFrame = None,
    info_df: pd.DataFrame = None,
):
    """Generate permeability predictions from a molecule file or features.

    Either ``input_fname`` (a .csv/.txt/.smi/.sdf file to be optimized and
    featurized) or a precomputed ``mol_features``/``info_df`` pair must be
    provided.

    Parameters
    ----------
    input_fname : str, optional
        Path to the input molecule file; required when features are not given.
    sep : str
        Regex separator for text inputs (overridden by file extension below).
    clf : str
        Classifier key passed to ``predict_permeability``.
    _models_dict : dict
        Output of ``load_all_models`` (underscore keeps Streamlit from hashing it).
    keep_sdf : str
        ``"no"`` (default) deletes the intermediate optimized SDF file.
    time_per_mol : int
        Per-molecule descriptor-computation budget in seconds.
    mol_features, info_df : pd.DataFrame, optional
        Precomputed scaled features and matching metadata.

    Returns
    -------
    tuple
        ``(mol_features, info_df, result_df)`` where ``result_df`` holds the
        ID/SMILES/probability/label display columns.

    Raises
    ------
    ValueError
        On missing input, unsupported file type, or no valid rows after cleaning.
    """
    try:
        if mol_features is None and info_df is None:
            if input_fname is None:
                raise ValueError("Either input_fname or mol_features/info_df must be provided")

            mol_tag = os.path.basename(input_fname).split(".")[0]
            file_ext = os.path.splitext(input_fname)[1].lower()
            internal_sdf = f"{mol_tag}_optimized_3d.sdf"

            try:
                # Choose the separator from the file type; .sdf needs none.
                if file_ext == '.csv':
                    sep = ','
                elif file_ext == '.txt' or file_ext == '.smi':
                    sep = r'\s+|\t+'
                elif file_ext != '.sdf':
                    raise ValueError(f"Unsupported file type: {file_ext}")

                # Geometry optimization writes the intermediate 3D SDF.
                geometry_optimize(input_fname=input_fname, output_sdf=internal_sdf, sep=sep)

                # Compute descriptors with timeout handling.
                df_features = compute_descriptors(
                    sdf_file=internal_sdf,
                    excel_out=None,
                    output_csv=None,
                    timeout=time_per_mol * 2,  # double the per-molecule time for total timeout
                    time_per_molecule=time_per_mol,
                )

                # Split descriptors from molecule metadata.
                mol_features, info_df = get_descriptors(df=df_features)

                # Keep only the descriptors the models were trained on.
                mol_features = select_descriptors(df=mol_features)

                # Clean before scaling: coerce to numeric and drop rows
                # containing NaN so the scaler never sees invalid values.
                mol_features = mol_features.replace('', np.nan)
                mol_features = mol_features.apply(pd.to_numeric, errors='coerce')
                if mol_features.isnull().any().any():
                    st.warning("Some descriptors contained invalid values and were removed")
                    # Get indices of valid rows
                    valid_indices = ~mol_features.isnull().any(axis=1)
                    # Update both dataframes to keep only valid rows
                    mol_features = mol_features[valid_indices]
                    info_df = info_df[valid_indices]
                    if len(mol_features) == 0:
                        raise ValueError("No valid data remains after cleaning")

                # Scale descriptors in place to preserve index/columns.
                mol_features.iloc[:, :] = scale_descriptors(df=mol_features)

            finally:
                # Best-effort cleanup of the intermediate SDF.
                if os.path.exists(internal_sdf) and keep_sdf == "no":
                    try:
                        os.remove(internal_sdf)
                    except OSError:
                        # Cleanup is best-effort; never mask the real error.
                        pass

        # Get predictions
        result_df = predict_permeability(
            clf_str=clf,
            sampling_str=sampling,
            _models_dict=_models_dict,
            mol_features=mol_features,
            info_df=info_df,
            threshold="none",
        )

        # Restrict the result to the user-facing display columns.
        display_cols = [
            "ID",
            "SMILES",
            "B3clf_predicted_probability",
            "B3clf_predicted_label",
        ]

        result_df = result_df[
            [col for col in result_df.columns.to_list() if col in display_cols]
        ]

        return mol_features, info_df, result_df

    except Exception as e:
        import traceback
        st.error(f"Error in generate_predictions: {str(e)}\n{traceback.format_exc()}")
        raise