| |
"""
Predict Tg using a pre-trained MLP + MACCS model for a given PID.

Requirements:
- Trained model: Files/model_mlp_tg.pkl
- SMILES.csv file with columns: PID, SMILES
(both paths are relative to the current working directory)

Usage:
    python predict_tg.py P000001
"""
|
|
| import sys |
| import os |
| import numpy as np |
| import pandas as pd |
| from rdkit import Chem |
| from rdkit.Chem import MACCSkeys, DataStructs |
| import joblib |
|
|
| |
| |
# Location of the serialized model handed to load_model(); being a relative
# path, it is resolved against the current working directory.
MODEL_PATH: str = "Files/model_mlp_tg.pkl"
# CSV file holding the PID -> SMILES lookup table read by
# load_smiles_for_pid(); also relative to the current working directory.
SMILES_CSV_PATH: str = "SMILES.csv"
| |
|
|
|
|
def load_model(model_path: str):
    """Deserialize the pre-trained model stored at ``model_path``.

    Args:
        model_path: Filesystem path of the joblib-serialized model.

    Returns:
        The object that ``joblib.load`` reconstructs from the file.

    Raises:
        FileNotFoundError: If nothing exists at ``model_path``.
    """
    if os.path.exists(model_path):
        return joblib.load(model_path)
    raise FileNotFoundError(f"Model file not found: {model_path}")
|
|
|
|
def load_smiles_for_pid(pid: str, smiles_csv_path: str) -> str:
    """Return the SMILES string recorded for ``pid`` in the SMILES CSV.

    PIDs are compared as text exactly as written in the file, so identifiers
    with leading zeros (e.g. ``000123``) match instead of being collapsed to
    integers by pandas' type inference.

    Args:
        pid: Identifier to look up; converted with ``str()`` before comparing.
        smiles_csv_path: Path to a CSV file with ``PID`` and ``SMILES`` columns.

    Returns:
        The SMILES value of the first row whose PID equals ``pid``.

    Raises:
        FileNotFoundError: If ``smiles_csv_path`` does not exist.
        ValueError: If a required column is missing, the PID is not present,
            or the matching SMILES cell is empty or not a string.
    """
    if not os.path.exists(smiles_csv_path):
        raise FileNotFoundError(f"SMILES CSV not found: {smiles_csv_path}")

    # Force PID to be read as text: with default inference an all-numeric
    # column becomes integers and "000123" would silently turn into "123".
    df = pd.read_csv(smiles_csv_path, dtype={"PID": str})

    if "PID" not in df.columns or "SMILES" not in df.columns:
        raise ValueError("SMILES.csv must contain 'PID' and 'SMILES' columns.")

    pid_str = str(pid)

    # astype(str) keeps blank PID cells comparable; the frame is not mutated.
    matches = df.loc[df["PID"].astype(str) == pid_str, "SMILES"]
    if matches.empty:
        raise ValueError(f"PID '{pid_str}' not found in {smiles_csv_path}.")

    # Duplicate PIDs resolve to the first occurrence in file order.
    smiles = matches.iloc[0]
    # A blank cell is read as NaN (a float), hence the isinstance check.
    if not isinstance(smiles, str) or not smiles.strip():
        raise ValueError(f"Empty or invalid SMILES for PID '{pid_str}'.")

    return smiles
|
|
|
|
def smiles_to_maccs(smiles: str) -> np.ndarray:
    """Encode one SMILES string as a single-row MACCS fingerprint matrix.

    Args:
        smiles: SMILES text to parse with RDKit.

    Returns:
        A ``float32`` array of shape ``(1, n_bits)`` holding the fingerprint
        bits.

    Raises:
        ValueError: If RDKit cannot parse ``smiles`` into a molecule.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError(f"Invalid SMILES string: {smiles}")

    fingerprint = MACCSkeys.GenMACCSKeys(molecule)

    # RDKit fills a pre-allocated 1-D buffer in place.
    bits = np.zeros((fingerprint.GetNumBits(),), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(fingerprint, bits)

    # Add a leading axis so the result is a one-row 2-D matrix, then convert
    # the int8 bits to float32.
    return bits[np.newaxis, :].astype(np.float32)
|
|
|
|
def main() -> None:
    """Command-line entry point: print the predicted Tg for one PID.

    Expects exactly one command-line argument, the PID. On success the
    prediction is written to stdout as a single number with six decimal
    places. On a wrong argument count a usage message is written to stderr
    and the process exits with status 1, so stdout only ever carries the
    prediction.

    Raises:
        SystemExit: With code 1 when the argument count is not exactly one.
        FileNotFoundError, ValueError: Propagated unchanged from the loading
            and featurization helpers.
    """
    if len(sys.argv) != 2:
        # Report the name the script was actually invoked with, rather than a
        # hard-coded file name that can drift from the real one.
        invoked = sys.argv[0] if sys.argv else ""
        script = os.path.basename(invoked) or "predict_tg_mlp_maccs.py"
        print(f"Usage: python {script} <PID>", file=sys.stderr)
        sys.exit(1)

    pid = sys.argv[1]

    model = load_model(MODEL_PATH)
    smiles = load_smiles_for_pid(pid, SMILES_CSV_PATH)
    features = smiles_to_maccs(smiles)

    # Only one fingerprint row is passed in, so take element 0 of the output.
    tg_pred = model.predict(features)[0]

    print(f"{tg_pred:.6f}")
|
|
|
|
# Run the prediction only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
|
|