# newww / app.py
# Fredaaaaaa's picture
# Update app.py
# a05f3a1 verified
import joblib
import numpy as np
from pubchempy import get_compounds
from rdkit import Chem
from rdkit.Chem import AllChem
import gradio as gr
print("Loading model and preprocessors...")
try:
    # Deserialize the three saved artifacts in order. The paths are relative,
    # so they are resolved against the current working directory.
    model, scaler, le = map(
        joblib.load,
        (
            'random_forest_model.joblib',
            'standard_scaler.joblib',
            'label_encoder.joblib',
        ),
    )
    print(f"Model loaded successfully. Type: {type(model).__name__}")
except Exception as e:
    # Report the failure on stdout, then re-raise so startup aborts instead of
    # continuing without a usable model.
    print(f"Error loading model: {e}")
    raise
# Numerical feature columns, in the order used at training time (must match
# exactly). Every base property appears twice in a row: first unsuffixed, then
# with the "_2" suffix.
numerical_cols = [
    base + suffix
    for base in (
        'molecular_weight', 'xlogp', 'tpsa', 'rotatable_bond_count',
        'h_bond_donor_count', 'h_bond_acceptor_count', 'complexity',
        'charge', 'exact_mass',
    )
    for suffix in ('', '_2')
]
# Preprocessing function
def get_morgan_fingerprint(smiles, radius=2, n_bits=512):
    """Return the Morgan fingerprint of a SMILES string as a 1-D NumPy array.

    This is a best-effort helper: it never raises for bad input. A missing,
    empty, or unparsable SMILES string, or any RDKit failure, yields an
    all-zero vector of length ``n_bits`` so the caller can still assemble a
    fixed-width feature row.

    Args:
        smiles: SMILES string of the molecule; may be ``None`` or empty.
        radius: Morgan fingerprint radius passed to RDKit.
        n_bits: Length of the returned bit vector.

    Returns:
        ``numpy.ndarray`` of shape ``(n_bits,)``.
    """
    # No structure available: skip RDKit entirely and return the zero vector.
    if not smiles:
        return np.zeros(n_bits)
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # RDKit signals an unparsable SMILES by returning None.
            return np.zeros(n_bits)
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
        return np.array(fp)
    except Exception as exc:
        # Deliberate best-effort fallback, but no longer a bare ``except``:
        # KeyboardInterrupt/SystemExit propagate and the failure is reported.
        print(f"Fingerprint generation failed for {smiles!r}: {exc}")
        return np.zeros(n_bits)
# Function to extract features from PubChem
# Properties read from each PubChem compound. 'charge' is not listed because
# it is hard-coded to 0 for both drugs (see _compound_properties).
_PUBCHEM_PROPERTIES = (
    'molecular_weight', 'xlogp', 'tpsa', 'rotatable_bond_count',
    'h_bond_donor_count', 'h_bond_acceptor_count', 'complexity', 'exact_mass',
)


def _as_float(value):
    """Coerce a property value to float; missing or unparsable values become 0.0."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _compound_properties(compound, suffix=''):
    """Return ``{property name + suffix: float}`` for a single compound."""
    props = {
        name + suffix: _as_float(getattr(compound, name, None))
        for name in _PUBCHEM_PROPERTIES
    }
    # Charge is assumed to be 0, exactly as the original feature set did.
    # NOTE(review): confirm this matches how the training data was built.
    props['charge' + suffix] = 0.0
    return props


def extract_features(drug1_name, drug2_name):
    """Build the model input row for a pair of drugs looked up on PubChem.

    Args:
        drug1_name: Name of the first drug.
        drug2_name: Name of the second drug.

    Returns:
        A 4-tuple ``(X, error, smiles1, smiles2)``.

        * On success ``X`` is a ``(1, n_features)`` array made of the values
          for ``numerical_cols`` (in that order), the Morgan fingerprint of
          each drug, and 768 trailing zeros; ``error`` is ``None``.
        * On failure (blank name, or a drug not found) ``X`` and both SMILES
          are ``None`` and ``error`` is a human-readable message.

        The tuple always has four elements so callers can unpack it
        unconditionally.

    Raises:
        Whatever ``get_compounds`` raises for transport/HTTP failures; those
        are not converted into an error message here.
    """
    name1 = (drug1_name or '').strip()
    name2 = (drug2_name or '').strip()
    # Reject blank input before issuing any PubChem request.
    if not name1 or not name2:
        return None, "Please enter both drug names.", None, None

    # Fetch compounds from PubChem
    compounds1 = get_compounds(name1, 'name')
    compounds2 = get_compounds(name2, 'name')
    if not compounds1 or not compounds2:
        return None, f"One or both drugs not found: {drug1_name}, {drug2_name}", None, None

    # Take the first match for each name.
    compound1, compound2 = compounds1[0], compounds2[0]

    # Merge both property sets and read them in numerical_cols order. The
    # column list interleaves drug-1 and drug-2 names (x, x_2, y, y_2, ...),
    # so slicing it in half and consulting only one dict per half would leave
    # every mismatched column at 0.
    props = {
        **_compound_properties(compound1),
        **_compound_properties(compound2, '_2'),
    }
    features = [props.get(col, 0.0) for col in numerical_cols]

    # Get SMILES for fingerprints
    smiles1 = compound1.canonical_smiles
    smiles2 = compound2.canonical_smiles
    fp1 = get_morgan_fingerprint(smiles1)
    fp2 = get_morgan_fingerprint(smiles2)

    # Combine all features with padding for BioBERT (768 dimensions).
    # NOTE(review): the zero padding presumably stands in for text embeddings
    # used at training time -- confirm against the training pipeline.
    X = np.hstack([
        np.array(features, dtype=float).reshape(1, -1),
        fp1.reshape(1, -1),
        fp2.reshape(1, -1),
        np.zeros((1, 768)),
    ])
    return X, None, smiles1, smiles2
# Function to predict severity
def predict_severity(drug1, drug2):
    """Predict the interaction severity for two drug names.

    Args:
        drug1: Name of the first drug, as typed by the user.
        drug2: Name of the second drug, as typed by the user.

    Returns:
        A multi-line text report (predicted severity, both SMILES strings and
        the per-class probabilities), or an error message when a name is blank
        or feature extraction reports a problem.
    """
    # The UI can call this while a box is still empty; answer without doing
    # any lookup in that case.
    if not (drug1 and drug1.strip() and drug2 and drug2.strip()):
        return "Please enter both drug names."

    # Fetch drug features from PubChem. Index the result instead of unpacking
    # four names: the error path may yield fewer elements than the success
    # path, and unpacking would then raise ValueError.
    outcome = extract_features(drug1, drug2)
    X, error = outcome[0], outcome[1]
    if error:
        return error
    smiles1, smiles2 = outcome[2], outcome[3]

    # Scale and predict
    X_scaled = scaler.transform(X)
    prediction = model.predict(X_scaled)
    severity = le.inverse_transform(prediction)[0]
    probabilities = model.predict_proba(X_scaled)[0]
    # predict_proba columns follow model.classes_, so decode those to get the
    # label for each column rather than assuming le.classes_ lines up.
    labels = le.inverse_transform(model.classes_)

    # Format output with SMILES
    lines = [
        f"Predicted Severity: {severity}",
        f"Drug 1 SMILES: {smiles1}",
        f"Drug 2 SMILES: {smiles2}",
        "Prediction Probabilities:",
    ]
    lines.extend(
        f" {label}: {probability:.2%}"
        for label, probability in zip(labels, probabilities)
    )
    return "\n".join(lines) + "\n"
# Gradio UI: two free-text drug-name boxes in, one plain-text report out.
# NOTE(review): live=True re-runs predict_severity as the inputs change, and
# each run performs PubChem lookups -- consider whether that is intended.
interface = gr.Interface(
    predict_severity,
    [gr.Textbox(label=caption) for caption in ("Drug 1", "Drug 2")],
    "text",
    title="Drug Interaction Severity Predictor",
    description="Enter two drug names to predict the severity of their interaction.",
    live=True,
)

# Start serving the app.
interface.launch()