from flask import Flask, request, jsonify import torch import torch.nn as nn import pandas as pd import numpy as np from transformers import AutoTokenizer, AutoModelForSequenceClassification from rdkit import Chem from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator import joblib import pickle import pubchempy as pcp import logging import os app = Flask(__name__) # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Device setup device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Using device: {device}") # Model name model_name = "Fredaaaaaa/smiles" # Load dataset dataset_path = "/kaggle/input/labeled-data/labeled_severity.csv" try: df = pd.read_csv(dataset_path, encoding='latin1') df.rename(columns={"Interaction Description": "interaction_description"}, inplace=True) df['Drug 1_normalized'] = df['Drug 1_normalized'].str.lower() df['Drug 2_normalized'] = df['Drug 2_normalized'].str.lower() logger.info("Dataset loaded successfully") except Exception as e: logger.error(f"Failed to load dataset: {e}") df = pd.DataFrame() # Load model components model_dir = "/kaggle/working/drug_interaction_model" try: # Load tokenizer and BioBERT model if os.path.exists(model_dir): tokenizer = AutoTokenizer.from_pretrained(model_dir) text_model = AutoModelForSequenceClassification.from_pretrained(model_dir).to(device) else: logger.warning(f"Local model directory {model_dir} not found, falling back to {model_name}") tokenizer = AutoTokenizer.from_pretrained(model_name) text_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3).to(device) text_model.eval() logger.info("BioBERT and tokenizer loaded") # Load custom model components checkpoint = torch.load(os.path.join(model_dir, 'custom_model.pt'), map_location=device) input_size = checkpoint['input_size'] dropout_rate = checkpoint['dropout_rate'] # Define HybridModel class HybridModel(nn.Module): def __init__(self, text_model, input_size, dropout_rate=0.4): super(HybridModel, self).__init__() self.text_model = text_model self.drug_branch = nn.Sequential( nn.Linear(input_size, 2048), nn.ReLU(), nn.BatchNorm1d(2048), nn.Dropout(dropout_rate), nn.Linear(2048, 1024), nn.ReLU(), nn.BatchNorm1d(1024), nn.Dropout(dropout_rate), nn.Linear(1024, 512), nn.ReLU(), nn.BatchNorm1d(512), nn.Dropout(dropout_rate), nn.Linear(512, 256), nn.ReLU(), nn.BatchNorm1d(256), nn.Dropout(dropout_rate), nn.Linear(256, 128), nn.ReLU() ) self.fusion = nn.Sequential( nn.Linear(128 + 3, 1024), nn.ReLU(), nn.BatchNorm1d(1024), nn.Dropout(dropout_rate), nn.Linear(1024, 512), nn.ReLU(), nn.BatchNorm1d(512), nn.Dropout(dropout_rate), nn.Linear(512, 256), nn.ReLU(), nn.BatchNorm1d(256), nn.Dropout(dropout_rate), nn.Linear(256, 128), nn.ReLU(), nn.BatchNorm1d(128), nn.Dropout(dropout_rate), nn.Linear(128, 3) ) def forward(self, input_ids, attention_mask, drug_features): text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask) text_features = text_outputs.logits drug_features = self.drug_branch(drug_features) combined = torch.cat((text_features, drug_features), dim=1) output = self.fusion(combined) return output model = HybridModel(text_model, input_size, dropout_rate).to(device) model.drug_branch.load_state_dict(checkpoint['drug_branch_state_dict']) model.fusion.load_state_dict(checkpoint['fusion_state_dict']) model.eval() logger.info("HybridModel loaded") # Load Random Forest rf_model = joblib.load(os.path.join(model_dir, 'rf_model.joblib')) logger.info("Random Forest model loaded") # Load label encoder with open(os.path.join(model_dir, 'label_encoder.pkl'), 'rb') as f: label_encoder = pickle.load(f) logger.info("Label encoder loaded") except Exception as e: logger.error(f"Failed to load model components: {e}") raise # Function to fetch SMILES from PubChem def get_smiles(drug_name): try: compounds = pcp.get_compounds(drug_name, 'name') if compounds: return compounds[0].canonical_smiles logger.warning(f"No SMILES found for {drug_name}") return None except Exception as e: logger.error(f"PubChem API error for {drug_name}: {e}") return None # Function to compute Morgan fingerprints def preprocess_smiles(smiles): try: mol = Chem.MolFromSmiles(smiles) if mol is None: return np.zeros(1024) morgan_gen = GetMorganGenerator(radius=2, fpSize=1024) fingerprint = morgan_gen.GetFingerprint(mol) return np.array(fingerprint) except: return np.zeros(1024) # Prediction function def predict_interaction(drug1, drug2, interaction_description): drug1 = drug1.lower() drug2 = drug2.lower() # Check dataset for SMILES smiles1 = None smiles2 = None if not df.empty: drug1_matches = df[df['Drug 1_normalized'] == drug1] drug2_matches = df[df['Drug 2_normalized'] == drug2] if not drug1_matches.empty: smiles1 = drug1_matches['SMILES'].iloc[0] if not drug2_matches.empty: smiles2 = drug2_matches['SMILES_2'].iloc[0] # Fetch SMILES from PubChem if not in dataset if smiles1 is None: smiles1 = get_smiles(drug1) if smiles2 is None: smiles2 = get_smiles(drug2) # Validate SMILES if not smiles1 or not smiles2: return {"error": "Unable to retrieve SMILES for one or both drugs"} # Preprocess SMILES drug1_features = preprocess_smiles(smiles1) drug2_features = preprocess_smiles(smiles2) drug_features = np.hstack([drug1_features, drug2_features]) drug_features_tensor = torch.tensor(drug_features, dtype=torch.float32).unsqueeze(0).to(device) # Tokenize interaction description encodings = tokenizer(interaction_description, truncation=True, padding=True, max_length=128, return_tensors='pt') input_ids = encodings['input_ids'].to(device) attention_mask = encodings['attention_mask'].to(device) # Model prediction with torch.no_grad(): outputs = model(input_ids, attention_mask, drug_features_tensor) nn_pred = torch.argmax(outputs, dim=1).cpu().numpy()[0] # Random Forest prediction rf_pred = rf_model.predict(drug_features.reshape(1, -1))[0] # Ensemble prediction votes = [nn_pred] * 9 + [rf_pred] * 1 ensemble_pred = max(set(votes), key=votes.count) # Decode prediction severity = label_encoder.inverse_transform([ensemble_pred])[0] return {"severity": severity} # Flask routes @app.route('/') def index(): return """