Fredaaaaaa commited on
Commit
608709a
·
verified ·
1 Parent(s): 51ca3d4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +236 -0
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ import torch
3
+ import torch.nn as nn
4
+ import pandas as pd
5
+ import numpy as np
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
+ from rdkit import Chem
8
+ from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator
9
+ import joblib
10
+ import pickle
11
+ import pubchempy as pcp
12
+ import logging
13
+ import os
14
+
15
+ app = Flask(__name__)
16
+
17
+ # Set up logging
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # Device setup
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ logger.info(f"Using device: {device}")
24
+
25
+ # Model name
26
+ model_name = "Fredaaaaaa/smiles"
27
+
28
+ # Load dataset
29
+ dataset_path = "/kaggle/input/labeled-data/labeled_severity.csv"
30
+ try:
31
+ df = pd.read_csv(dataset_path, encoding='latin1')
32
+ df.rename(columns={"Interaction Description": "interaction_description"}, inplace=True)
33
+ df['Drug 1_normalized'] = df['Drug 1_normalized'].str.lower()
34
+ df['Drug 2_normalized'] = df['Drug 2_normalized'].str.lower()
35
+ logger.info("Dataset loaded successfully")
36
+ except Exception as e:
37
+ logger.error(f"Failed to load dataset: {e}")
38
+ df = pd.DataFrame()
39
+
40
+ # Load model components
41
+ model_dir = "/kaggle/working/drug_interaction_model"
42
+ try:
43
+ # Load tokenizer and BioBERT model
44
+ if os.path.exists(model_dir):
45
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
46
+ text_model = AutoModelForSequenceClassification.from_pretrained(model_dir).to(device)
47
+ else:
48
+ logger.warning(f"Local model directory {model_dir} not found, falling back to {model_name}")
49
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
50
+ text_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3).to(device)
51
+ text_model.eval()
52
+ logger.info("BioBERT and tokenizer loaded")
53
+
54
+ # Load custom model components
55
+ checkpoint = torch.load(os.path.join(model_dir, 'custom_model.pt'), map_location=device)
56
+ input_size = checkpoint['input_size']
57
+ dropout_rate = checkpoint['dropout_rate']
58
+
59
+ # Define HybridModel
60
+ class HybridModel(nn.Module):
61
+ def __init__(self, text_model, input_size, dropout_rate=0.4):
62
+ super(HybridModel, self).__init__()
63
+ self.text_model = text_model
64
+ self.drug_branch = nn.Sequential(
65
+ nn.Linear(input_size, 2048),
66
+ nn.ReLU(),
67
+ nn.BatchNorm1d(2048),
68
+ nn.Dropout(dropout_rate),
69
+ nn.Linear(2048, 1024),
70
+ nn.ReLU(),
71
+ nn.BatchNorm1d(1024),
72
+ nn.Dropout(dropout_rate),
73
+ nn.Linear(1024, 512),
74
+ nn.ReLU(),
75
+ nn.BatchNorm1d(512),
76
+ nn.Dropout(dropout_rate),
77
+ nn.Linear(512, 256),
78
+ nn.ReLU(),
79
+ nn.BatchNorm1d(256),
80
+ nn.Dropout(dropout_rate),
81
+ nn.Linear(256, 128),
82
+ nn.ReLU()
83
+ )
84
+ self.fusion = nn.Sequential(
85
+ nn.Linear(128 + 3, 1024),
86
+ nn.ReLU(),
87
+ nn.BatchNorm1d(1024),
88
+ nn.Dropout(dropout_rate),
89
+ nn.Linear(1024, 512),
90
+ nn.ReLU(),
91
+ nn.BatchNorm1d(512),
92
+ nn.Dropout(dropout_rate),
93
+ nn.Linear(512, 256),
94
+ nn.ReLU(),
95
+ nn.BatchNorm1d(256),
96
+ nn.Dropout(dropout_rate),
97
+ nn.Linear(256, 128),
98
+ nn.ReLU(),
99
+ nn.BatchNorm1d(128),
100
+ nn.Dropout(dropout_rate),
101
+ nn.Linear(128, 3)
102
+ )
103
+
104
+ def forward(self, input_ids, attention_mask, drug_features):
105
+ text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
106
+ text_features = text_outputs.logits
107
+ drug_features = self.drug_branch(drug_features)
108
+ combined = torch.cat((text_features, drug_features), dim=1)
109
+ output = self.fusion(combined)
110
+ return output
111
+
112
+ model = HybridModel(text_model, input_size, dropout_rate).to(device)
113
+ model.drug_branch.load_state_dict(checkpoint['drug_branch_state_dict'])
114
+ model.fusion.load_state_dict(checkpoint['fusion_state_dict'])
115
+ model.eval()
116
+ logger.info("HybridModel loaded")
117
+
118
+ # Load Random Forest
119
+ rf_model = joblib.load(os.path.join(model_dir, 'rf_model.joblib'))
120
+ logger.info("Random Forest model loaded")
121
+
122
+ # Load label encoder
123
+ with open(os.path.join(model_dir, 'label_encoder.pkl'), 'rb') as f:
124
+ label_encoder = pickle.load(f)
125
+ logger.info("Label encoder loaded")
126
+ except Exception as e:
127
+ logger.error(f"Failed to load model components: {e}")
128
+ raise
129
+
130
+ # Function to fetch SMILES from PubChem
131
+ def get_smiles(drug_name):
132
+ try:
133
+ compounds = pcp.get_compounds(drug_name, 'name')
134
+ if compounds:
135
+ return compounds[0].canonical_smiles
136
+ logger.warning(f"No SMILES found for {drug_name}")
137
+ return None
138
+ except Exception as e:
139
+ logger.error(f"PubChem API error for {drug_name}: {e}")
140
+ return None
141
+
142
+ # Function to compute Morgan fingerprints
143
+ def preprocess_smiles(smiles):
144
+ try:
145
+ mol = Chem.MolFromSmiles(smiles)
146
+ if mol is None:
147
+ return np.zeros(1024)
148
+ morgan_gen = GetMorganGenerator(radius=2, fpSize=1024)
149
+ fingerprint = morgan_gen.GetFingerprint(mol)
150
+ return np.array(fingerprint)
151
+ except:
152
+ return np.zeros(1024)
153
+
154
+ # Prediction function
155
+ def predict_interaction(drug1, drug2, interaction_description):
156
+ drug1 = drug1.lower()
157
+ drug2 = drug2.lower()
158
+
159
+ # Check dataset for SMILES
160
+ smiles1 = None
161
+ smiles2 = None
162
+ if not df.empty:
163
+ drug1_matches = df[df['Drug 1_normalized'] == drug1]
164
+ drug2_matches = df[df['Drug 2_normalized'] == drug2]
165
+ if not drug1_matches.empty:
166
+ smiles1 = drug1_matches['SMILES'].iloc[0]
167
+ if not drug2_matches.empty:
168
+ smiles2 = drug2_matches['SMILES_2'].iloc[0]
169
+
170
+ # Fetch SMILES from PubChem if not in dataset
171
+ if smiles1 is None:
172
+ smiles1 = get_smiles(drug1)
173
+ if smiles2 is None:
174
+ smiles2 = get_smiles(drug2)
175
+
176
+ # Validate SMILES
177
+ if not smiles1 or not smiles2:
178
+ return {"error": "Unable to retrieve SMILES for one or both drugs"}
179
+
180
+ # Preprocess SMILES
181
+ drug1_features = preprocess_smiles(smiles1)
182
+ drug2_features = preprocess_smiles(smiles2)
183
+ drug_features = np.hstack([drug1_features, drug2_features])
184
+ drug_features_tensor = torch.tensor(drug_features, dtype=torch.float32).unsqueeze(0).to(device)
185
+
186
+ # Tokenize interaction description
187
+ encodings = tokenizer(interaction_description, truncation=True, padding=True, max_length=128, return_tensors='pt')
188
+ input_ids = encodings['input_ids'].to(device)
189
+ attention_mask = encodings['attention_mask'].to(device)
190
+
191
+ # Model prediction
192
+ with torch.no_grad():
193
+ outputs = model(input_ids, attention_mask, drug_features_tensor)
194
+ nn_pred = torch.argmax(outputs, dim=1).cpu().numpy()[0]
195
+
196
+ # Random Forest prediction
197
+ rf_pred = rf_model.predict(drug_features.reshape(1, -1))[0]
198
+
199
+ # Ensemble prediction
200
+ votes = [nn_pred] * 9 + [rf_pred] * 1
201
+ ensemble_pred = max(set(votes), key=votes.count)
202
+
203
+ # Decode prediction
204
+ severity = label_encoder.inverse_transform([ensemble_pred])[0]
205
+ return {"severity": severity}
206
+
207
+ # Flask routes
208
+ @app.route('/')
209
+ def index():
210
+ return """
211
+ <h1>Drug Interaction Severity Prediction</h1>
212
+ <form method="POST" action="/predict">
213
+ <label>Drug 1:</label><br>
214
+ <input type="text" name="drug1" required><br>
215
+ <label>Drug 2:</label><br>
216
+ <input type="text" name="drug2" required><br>
217
+ <label>Interaction Description:</label><br>
218
+ <textarea name="interaction_description" required></textarea><br>
219
+ <input type="submit" value="Predict">
220
+ </form>
221
+ """
222
+
223
+ @app.route('/predict', methods=['POST'])
224
+ def predict():
225
+ try:
226
+ drug1 = request.form['drug1']
227
+ drug2 = request.form['drug2']
228
+ interaction_description = request.form['interaction_description']
229
+ result = predict_interaction(drug1, drug2, interaction_description)
230
+ return jsonify(result)
231
+ except Exception as e:
232
+ logger.error(f"Prediction error: {e}")
233
+ return jsonify({"error": str(e)}), 500
234
+
235
+ if __name__ == '__main__':
236
+ app.run(debug=True, host='0.0.0.0', port=5000)