Spaces:
Runtime error
Runtime error
| from flask import Flask, request, jsonify | |
| import torch | |
| import torch.nn as nn | |
| import pandas as pd | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from rdkit import Chem | |
| from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator | |
| import joblib | |
| import pickle | |
| import pubchempy as pcp | |
| import logging | |
| import os | |
# WSGI application object; the view functions further down attach to it.
app = Flask(__name__)

# Root logging at INFO so model-loading and prediction messages are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Run on CUDA when PyTorch can see a GPU, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")
# Hugging Face Hub id of the fine-tuned text classifier; used when no local
# copy of the transformer is available.
model_name = "Fredaaaaaa/smiles"

# Labelled interaction dataset, used to look up SMILES for known drug names.
# Defaults to the original Kaggle location (which only exists inside that
# notebook environment); override with the DATASET_PATH environment variable.
dataset_path = os.environ.get(
    "DATASET_PATH", "/kaggle/input/labeled-data/labeled_severity.csv"
)
try:
    df = pd.read_csv(dataset_path, encoding='latin1')
    df.rename(columns={"Interaction Description": "interaction_description"}, inplace=True)
    # Trim and lower-case drug names so exact-match lookups are insensitive to
    # case and stray whitespace in the CSV.
    for _name_column in ('Drug 1_normalized', 'Drug 2_normalized'):
        df[_name_column] = df[_name_column].str.strip().str.lower()
    logger.info("Dataset loaded successfully")
except Exception as e:
    # Best effort: with an empty frame, callers that check df.empty skip the
    # dataset lookup entirely.
    logger.error(f"Failed to load dataset: {e}")
    df = pd.DataFrame()
# Directory holding the trained artifacts: the custom head checkpoint, the
# Random Forest, the label encoder and (optionally) a saved transformer.
# Defaults to the original Kaggle output path; override with MODEL_DIR.
model_dir = os.environ.get("MODEL_DIR", "/kaggle/working/drug_interaction_model")

# Files that are only ever read from model_dir; there is no remote fallback
# for them, so the app cannot start without them.
_REQUIRED_ARTIFACTS = ("custom_model.pt", "rf_model.joblib", "label_encoder.pkl")


def _missing_artifacts(directory):
    """Return the names in _REQUIRED_ARTIFACTS that are not files in ``directory``."""
    return [
        name for name in _REQUIRED_ARTIFACTS
        if not os.path.isfile(os.path.join(directory, name))
    ]


class HybridModel(nn.Module):
    """Fuse a text classifier's logits with an MLP embedding of drug features.

    Args:
        text_model: Sequence-classification model; its ``logits`` are used as
            text features. The fusion head's input width assumes 3 logits.
        input_size: Length of the drug feature vector fed to ``drug_branch``.
        dropout_rate: Dropout probability applied after hidden layers.
    """

    def __init__(self, text_model, input_size, dropout_rate=0.4):
        super().__init__()
        self.text_model = text_model
        # Layer order must stay exactly as trained: the checkpoint's state
        # dicts for drug_branch/fusion are keyed by Sequential index.
        self.drug_branch = nn.Sequential(
            nn.Linear(input_size, 2048),
            nn.ReLU(),
            nn.BatchNorm1d(2048),
            nn.Dropout(dropout_rate),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128),
            nn.ReLU(),
        )
        # Input: the 3 text logits concatenated with the 128-d drug embedding.
        self.fusion = nn.Sequential(
            nn.Linear(128 + 3, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 3),
        )

    def forward(self, input_ids, attention_mask, drug_features):
        """Return unnormalised scores for the 3 classes, one row per batch item."""
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_features = text_outputs.logits
        drug_features = self.drug_branch(drug_features)
        combined = torch.cat((text_features, drug_features), dim=1)
        return self.fusion(combined)


try:
    # Fail fast with an actionable message instead of downloading the Hub
    # model and then crashing on the first local artifact load.
    _missing = _missing_artifacts(model_dir)
    if _missing:
        raise FileNotFoundError(
            f"Model artifact(s) {', '.join(_missing)} not found in {model_dir}; "
            "set MODEL_DIR to the directory produced by training"
        )

    # Text model: prefer the transformer saved alongside the other artifacts;
    # fall back to the Hub checkpoint when the directory holds no config.json.
    if os.path.isfile(os.path.join(model_dir, "config.json")):
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        text_model = AutoModelForSequenceClassification.from_pretrained(model_dir).to(device)
    else:
        logger.warning(f"No saved transformer in {model_dir}, falling back to {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        text_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3).to(device)
    text_model.eval()
    logger.info("BioBERT and tokenizer loaded")

    # NOTE(review): depending on the PyTorch version's weights_only default,
    # torch.load may unpickle arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(os.path.join(model_dir, 'custom_model.pt'), map_location=device)
    input_size = checkpoint['input_size']
    dropout_rate = checkpoint['dropout_rate']

    model = HybridModel(text_model, input_size, dropout_rate).to(device)
    model.drug_branch.load_state_dict(checkpoint['drug_branch_state_dict'])
    model.fusion.load_state_dict(checkpoint['fusion_state_dict'])
    model.eval()
    logger.info("HybridModel loaded")

    # joblib and pickle execute code on load: these must be trusted local files.
    rf_model = joblib.load(os.path.join(model_dir, 'rf_model.joblib'))
    logger.info("Random Forest model loaded")

    with open(os.path.join(model_dir, 'label_encoder.pkl'), 'rb') as f:
        label_encoder = pickle.load(f)
    logger.info("Label encoder loaded")
except Exception as e:
    # The app is unusable without its models: log, then let startup fail.
    logger.error(f"Failed to load model components: {e}")
    raise
# PubChem name lookup used when the local dataset has no SMILES for a drug.
def get_smiles(drug_name):
    """Look up a SMILES string for ``drug_name`` through PubChem's name search.

    Returns the ``canonical_smiles`` of the first compound PubChem reports, or
    ``None`` when the search returns nothing or any step raises (the error is
    logged, not propagated).

    NOTE(review): recent PubChemPy releases appear to deprecate
    ``canonical_smiles`` in favour of newer SMILES properties; confirm against
    the pinned PubChemPy version.
    """
    try:
        hits = pcp.get_compounds(drug_name, 'name')
        if not hits:
            logger.warning(f"No SMILES found for {drug_name}")
            return None
        return hits[0].canonical_smiles
    except Exception as exc:
        logger.error(f"PubChem API error for {drug_name}: {exc}")
        return None
# Morgan fingerprint featurisation of a single SMILES string.
def preprocess_smiles(smiles, radius=2, fp_size=1024):
    """Encode ``smiles`` as a Morgan fingerprint bit vector.

    Args:
        smiles: SMILES string to encode.
        radius: Morgan radius (default 2, the value the models were built with).
        fp_size: Number of fingerprint bits (default 1024).

    Returns:
        A 1-D numpy array of length ``fp_size``. An all-zero vector is returned
        when RDKit cannot parse ``smiles`` or fingerprinting raises, so callers
        always get a fixed-size vector; such cases are logged as warnings.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            logger.warning(f"RDKit could not parse SMILES {smiles!r}; using a zero fingerprint")
            return np.zeros(fp_size)
        generator = GetMorganGenerator(radius=radius, fpSize=fp_size)
        return np.array(generator.GetFingerprint(mol))
    except Exception as exc:
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit propagate.
        logger.warning(f"Fingerprinting failed for SMILES {smiles!r}: {exc}; using a zero fingerprint")
        return np.zeros(fp_size)
# (drug-name column, SMILES column) pairs in the dataset, "Drug 1" side first.
_DATASET_SMILES_COLUMNS = (
    ("Drug 1_normalized", "SMILES"),
    ("Drug 2_normalized", "SMILES_2"),
)


def _dataset_smiles(drug_name, column_pairs):
    """Return the first non-missing SMILES for ``drug_name`` in ``df``, or None.

    ``column_pairs`` is searched in order, so the caller decides which side of
    the dataset is preferred. Missing (NaN) SMILES cells are skipped.
    """
    if df.empty:
        return None
    for name_col, smiles_col in column_pairs:
        hits = df.loc[df[name_col] == drug_name, smiles_col].dropna()
        if not hits.empty:
            return hits.iloc[0]
    return None


def predict_interaction(drug1, drug2, interaction_description):
    """Predict the severity label for an interaction between two drugs.

    SMILES are taken from the local dataset when available, otherwise from
    PubChem. Both drugs are fingerprinted, the description is tokenised, and
    the hybrid network and Random Forest predictions are combined.

    Returns:
        ``{"severity": label}`` on success, or ``{"error": message}`` when a
        SMILES string cannot be obtained for either drug.
    """
    drug1 = drug1.strip().lower()
    drug2 = drug2.strip().lower()

    # Dataset first (drug1 on its "Drug 1" side first, drug2 on its "Drug 2"
    # side first, then the other side); PubChem only if nothing usable is
    # found. A NaN or empty cell no longer blocks the PubChem fallback.
    smiles1 = _dataset_smiles(drug1, _DATASET_SMILES_COLUMNS) or get_smiles(drug1)
    smiles2 = _dataset_smiles(drug2, _DATASET_SMILES_COLUMNS[::-1]) or get_smiles(drug2)
    if not smiles1 or not smiles2:
        return {"error": "Unable to retrieve SMILES for one or both drugs"}

    drug_features = np.hstack([preprocess_smiles(smiles1), preprocess_smiles(smiles2)])
    drug_features_tensor = torch.tensor(drug_features, dtype=torch.float32).unsqueeze(0).to(device)

    encodings = tokenizer(interaction_description, truncation=True, padding=True, max_length=128, return_tensors='pt')
    input_ids = encodings['input_ids'].to(device)
    attention_mask = encodings['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask, drug_features_tensor)
    nn_pred = torch.argmax(outputs, dim=1).cpu().numpy()[0]

    rf_pred = rf_model.predict(drug_features.reshape(1, -1))[0]

    # NOTE(review): with a 9:1 hard vote the single RF vote can never win, so
    # ensemble_pred always equals nn_pred. Kept as-is to preserve current
    # predictions; change the weights (or use soft voting) to let the RF count.
    votes = [nn_pred] * 9 + [rf_pred] * 1
    ensemble_pred = max(set(votes), key=votes.count)

    severity = label_encoder.inverse_transform([ensemble_pred])[0]
    # inverse_transform yields numpy scalars; unwrap to a plain Python value so
    # the result is JSON-serialisable whatever the label dtype.
    if isinstance(severity, np.generic):
        severity = severity.item()
    return {"severity": severity}
# Flask routes
@app.route("/")
def index():
    """Serve the HTML input form; it POSTs two drug names and a description to /predict."""
    return """
<h1>Drug Interaction Severity Prediction</h1>
<form method="POST" action="/predict">
<label>Drug 1:</label><br>
<input type="text" name="drug1" required><br>
<label>Drug 2:</label><br>
<input type="text" name="drug2" required><br>
<label>Interaction Description:</label><br>
<textarea name="interaction_description" required></textarea><br>
<input type="submit" value="Predict">
</form>
"""
@app.route("/predict", methods=["POST"])
def predict():
    """Handle the form POST and return the prediction as JSON.

    Responds with 400 when a required field is missing or blank. On success it
    returns the dict from predict_interaction (a ``severity`` or ``error``
    key). If prediction raises, it returns a generic 500 and logs the traceback.
    """
    fields = ("drug1", "drug2", "interaction_description")
    values = {name: request.form.get(name, "").strip() for name in fields}
    missing = [name for name, value in values.items() if not value]
    if missing:
        return jsonify({"error": f"Missing required field(s): {', '.join(missing)}"}), 400
    try:
        result = predict_interaction(
            values["drug1"], values["drug2"], values["interaction_description"]
        )
    except Exception:
        # Keep internals (paths, stack details) out of the response body.
        logger.exception("Prediction error")
        return jsonify({"error": "Internal error while predicting interaction"}), 500
    return jsonify(result)
if __name__ == '__main__':
    # The Werkzeug debugger allows arbitrary code execution from the browser, so
    # it must not be on by default while listening on all interfaces. Opt in
    # with FLASK_DEBUG=1; choose the port with PORT (default 5000).
    debug = os.environ.get("FLASK_DEBUG", "0").strip().lower() in ("1", "true", "yes")
    port = int(os.environ.get("PORT", "5000"))
    app.run(debug=debug, host='0.0.0.0', port=port)