Fredaaaaaa committed on
Commit
80a0bab
·
verified ·
1 Parent(s): 3df3aeb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +341 -100
app.py CHANGED
@@ -1,100 +1,341 @@
1
- import torch
2
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
- import requests
4
- import gradio as gr
5
-
6
- # Load the pre-trained model and tokenizer from the Hugging Face directory where the model is saved
7
- model = AutoModelForSequenceClassification.from_pretrained("Fredaaaaaa/hybrid_model")
8
-
9
- tokenizer = AutoTokenizer.from_pretrained("hybrid_model")
10
-
11
- # Function to fetch drug features from an external API (e.g., PubChem)
12
- def get_drug_features(drug1, drug2):
13
- # You can modify this function to fetch additional features based on your API choice.
14
- api_url = f'https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{drug1}/property/SMILES,Pharmacology,Toxicity,Mechanism-of-action,Route-of-elimination,Metabolism/JSON'
15
- response = requests.get(api_url)
16
-
17
- if response.status_code == 200:
18
- data = response.json()
19
- drug_features = {
20
- 'SMILES': data['PropertyTable']['Properties'][0].get('SMILES', 'No data'),
21
- 'pharmacology': data['PropertyTable']['Properties'][0].get('Pharmacology', 'No data'),
22
- 'toxicity': data['PropertyTable']['Properties'][0].get('Toxicity', 'No data'),
23
- 'mechanism-of-action': data['PropertyTable']['Properties'][0].get('Mechanism-of-action', 'No data'),
24
- 'route-of-elimination': data['PropertyTable']['Properties'][0].get('Route-of-elimination', 'No data'),
25
- 'metabolism': data['PropertyTable']['Properties'][0].get('Metabolism', 'No data'),
26
- }
27
- return drug_features
28
- else:
29
- return None # If no data is returned, handle the missing values gracefully
30
-
31
- # Define the Hybrid Model (already trained)
32
- class HybridModel(torch.nn.Module):
33
- def __init__(self, text_model, input_size, dropout_rate=0.3):
34
- super(HybridModel, self).__init__()
35
- self.text_model = text_model
36
- self.fc1 = torch.nn.Linear(input_size, 128) # Fully connected layer for drug features
37
- self.fc2 = torch.nn.Linear(128, 64) # Additional fully connected layer
38
- self.fc3 = torch.nn.Linear(64, 4) # Output layer (4 classes for severity)
39
- self.dropout = torch.nn.Dropout(dropout_rate) # Dropout layer to prevent overfitting
40
-
41
- def forward(self, input_ids, attention_mask, drug_features):
42
- # Process the text data (interaction description) through BioBERT
43
- text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
44
- text_features = text_outputs.logits # Shape: (batch_size, num_labels)
45
-
46
- # Pass the drug features through the fully connected layers
47
- x = torch.relu(self.fc1(drug_features)) # Apply ReLU activation
48
- x = self.dropout(x) # Apply Dropout
49
- x = torch.relu(self.fc2(x)) # Additional ReLU layer
50
- x = self.fc3(x) # Output layer (4 classes)
51
-
52
- # Combine text features and drug features (can use addition or concatenation)
53
- combined = text_features + x # Simple addition, can experiment with concatenation
54
- return combined
55
-
56
- # Initialize the model
57
- text_model = AutoModelForSequenceClassification.from_pretrained("hybrid_model")
58
- hybrid_model = HybridModel(text_model, input_size=3) # Example size, adjust based on your data
59
- hybrid_model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
60
-
61
- # Function to process the input and predict severity
62
- def predict_severity(drug1, drug2):
63
- # Fetch drug features from external API
64
- drug_features = get_drug_features(drug1, drug2)
65
-
66
- if not drug_features:
67
- return "Error: Could not fetch drug features from the API."
68
-
69
- # Preprocess text data (interaction description)
70
- interaction_description = f"{drug1} interacts with {drug2}" # Example interaction description
71
- inputs = tokenizer(interaction_description, return_tensors="pt", padding=True, truncation=True, max_length=128)
72
- input_ids = inputs['input_ids'].to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
73
- attention_mask = inputs['attention_mask'].to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
74
-
75
- # Process drug features (SMILES, pharmacology, toxicity, etc.)
76
- drug_feature_values = [drug_features['SMILES'], drug_features['pharmacology'], drug_features['toxicity']]
77
- drug_features_tensor = torch.tensor(drug_feature_values, dtype=torch.float32).unsqueeze(0).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
78
-
79
- # Run the model to get predictions
80
- hybrid_model.eval()
81
- with torch.no_grad():
82
- outputs = hybrid_model(input_ids, attention_mask, drug_features_tensor)
83
-
84
- # Get the predicted class
85
- prediction = torch.argmax(outputs, dim=1).item()
86
-
87
- # Map the predicted class index to the severity label
88
- severity_labels = ["No interaction", "Mild", "Moderate", "Severe"]
89
- return severity_labels[prediction]
90
-
91
- # Gradio Interface
92
- interface = gr.Interface(
93
- fn=predict_severity,
94
- inputs=[gr.Textbox(label="Drug 1"), gr.Textbox(label="Drug 2")],
95
- outputs="text",
96
- live=True
97
- )
98
-
99
- # Launch the interface
100
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import requests
3
+ from huggingface_hub import hf_hub_download
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ import torch
6
+ import gradio as gr
7
+ import pandas as pd
8
+ import re
9
+
10
# Hub repository that hosts the fine-tuned classifier and its side artifacts.
model_name = "Fredaaaaaa/hybrid_model"

# Download label encoder from Hugging Face Hub
label_encoder_path = hf_hub_download(repo_id=model_name, filename="label_encoder.pkl")
# SECURITY NOTE(review): pickle.load executes whatever code is embedded in the
# file. This is only acceptable while the Hub repository above is fully
# trusted; consider shipping the class names as JSON instead.
with open(label_encoder_path, 'rb') as f:
    label_encoder = pickle.load(f)

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Place the weights on the GPU when one is available: the tensors fed to the
# model must live on the same device as its parameters, otherwise the forward
# pass raises a device-mismatch RuntimeError.
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
# Inference only: disable dropout and other training-time behaviour.
model.eval()
20
+
21
# Download the dataset from Hugging Face Hub
dataset_path = hf_hub_download(repo_id="Fredaaaaaa/hybrid_model", filename="labeled_severity.csv")

# Fallback values, defined up front so every name exists even when loading or
# scanning fails.
df = pd.DataFrame()   # labelled interaction rows; empty when loading failed
all_drugs = set()     # lower-cased, stripped drug names seen in the dataset
drug_columns = []     # columns whose header looks like it holds drug names

# Load the dataset with appropriate encoding
try:
    df = pd.read_csv(dataset_path, encoding='ISO-8859-1')
except Exception as e:  # best-effort: the app still starts without the CSV
    print(f"Error reading the dataset: {e}")
    df = pd.DataFrame()
else:
    print(f"Dataset loaded successfully! Shape: {df.shape}")

    # Building the name index is kept separate from reading the file so that
    # a problem here does not throw away a dataset that loaded correctly.
    try:
        # Check which columns contain drug names
        drug_columns = [
            col for col in df.columns
            if 'drug' in str(col).lower() or 'medication' in str(col).lower()
        ]
        for col in drug_columns:
            # Add all drugs from this column to our set after cleaning
            clean_drugs = df[col].dropna().astype(str).str.strip().str.lower()
            all_drugs.update(clean_drugs.unique())

        # A whitespace-only cell cleans to "", and "" is a substring of every
        # string, so leaving it in would make substring checks match anything.
        all_drugs.discard("")
    except Exception as e:
        print(f"Error reading the dataset: {e}")
        all_drugs = set()
        drug_columns = []

    print(f"Found {len(all_drugs)} unique drugs in the dataset")
    print(f"Drug name columns found: {drug_columns}")
48
+
49
def clean_drug_name(drug_name):
    """Return *drug_name* in the canonical form used for comparisons.

    Surrounding whitespace is removed, the text is lower-cased and every
    internal run of whitespace is collapsed to a single space. Falsy input
    (``None`` or ``""``) yields the empty string.
    """
    if drug_name:
        lowered = drug_name.strip().lower()
        # Collapse tabs, newlines and repeated spaces between words.
        return re.sub(r'\s+', ' ', lowered)
    return ""
55
+
56
def validate_drug_input(drug_name):
    """Check a user-supplied drug name against the local dataset.

    The name is cleaned with ``clean_drug_name`` and then compared with the
    module-level ``all_drugs`` set.

    Returns:
        A ``(status, message)`` tuple where ``status`` is
        ``False`` when the text cannot be a drug name (too short, digits only,
        or containing characters other than letters, digits, whitespace,
        ``-`` and ``+``),
        ``True`` when the name is known to the dataset (exactly, or as whole
        words of a known name / containing a known name as whole words), and
        ``None`` when the name looks plausible but is unknown locally, so the
        caller should fall back to an external check.
    """
    # Clean the input
    drug_name = clean_drug_name(drug_name)

    if len(drug_name) <= 1:
        return False, "Drug name is too short"

    # Check if it's just a number
    if drug_name.isdigit():
        return False, "Not a valid drug name"

    # Check if it contains weird characters
    if not re.match(r'^[a-zA-Z0-9\s\-\+]+$', drug_name):
        return False, "Drug name contains invalid characters"

    in_dataset = drug_name in all_drugs

    # Print for debugging
    print(f"Validating drug: '{drug_name}'")
    print(f"Drug in dataset: {in_dataset}")

    # Check if it's in our known drug list
    if in_dataset:
        return True, "Drug found in dataset"

    # Forgiving match on whole words only, e.g. "warfarin" <-> "warfarin sodium".
    # A raw substring test would accept fragments ("in" is inside "insulin")
    # and, if the set ever held "", any input at all. Padding both sides with
    # spaces restricts the match to whitespace-delimited words.
    padded_input = f" {drug_name} "
    # sorted() makes the reported match reproducible; set order is not.
    for known_drug in sorted(all_drugs):
        if len(known_drug) <= 1:
            # Same rule as for user input: such entries are not drug names.
            continue
        padded_known = f" {known_drug} "
        if padded_input in padded_known or padded_known in padded_input:
            return True, f"Drug found in dataset (matched with '{known_drug}')"

    # If not in dataset, we'll try the API validation
    return None, "Drug not in dataset, needs API validation"
87
+
88
def validate_drug_via_api(drug_name):
    """Ask PubChem whether *drug_name* resolves to at least one compound.

    The cleaned name is looked up through the PUG REST
    ``compound/name/<name>/cids/JSON`` endpoint with a 10 second timeout.

    Returns:
        ``(True, message)`` when the response is HTTP 200 and lists at least
        one CID; ``(False, message)`` otherwise, including when the request
        fails or the body is not valid JSON. Request errors are reported in
        the message rather than raised.
    """
    # Local import keeps this function self-contained.
    from urllib.parse import quote

    drug_name = clean_drug_name(drug_name)
    # The name becomes a URL path segment, so characters such as spaces and
    # '+' must be percent-encoded to reach PubChem unchanged.
    encoded_name = quote(drug_name, safe='')
    api_url = f'https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{encoded_name}/cids/JSON'
    print(f"Calling PubChem API: {api_url}")

    try:
        response = requests.get(api_url, timeout=10)

        if response.status_code == 200:
            data = response.json()
            # Tolerate unexpected payload shapes instead of raising.
            identifier_list = data.get('IdentifierList') if isinstance(data, dict) else None
            cids = identifier_list.get('CID') if isinstance(identifier_list, dict) else None
            if cids:
                return True, f"Drug validated via PubChem (CID: {cids[0]})"

        return False, f"Not found in PubChem database (Status: {response.status_code})"
    except (requests.RequestException, ValueError) as e:
        # RequestException: network/HTTP failure; ValueError: undecodable JSON.
        print(f"API validation error: {e}")
        return False, f"API validation error: {e}"
107
+
108
def get_drug_features_from_dataset(drug1, drug2, df):
    """Find the dataset row describing the pair (*drug1*, *drug2*).

    Both names are cleaned with ``clean_drug_name`` and compared, in either
    order, with lower-cased and stripped column values.

    Column selection:
        * If ``'Drug 1_normalized'`` and ``'Drug 2_normalized'`` both exist,
          only those are searched; rows in the given order are preferred over
          rows in the reversed order.
        * Otherwise the pairs ``Drug1/Drug2``, ``Drug 1/Drug 2``,
          ``drug1/drug2`` and ``drug_1/drug_2`` are tried in that order and
          the first pair that yields a match wins.

    Args:
        drug1: First drug name as entered by the user.
        drug2: Second drug name as entered by the user.
        df: DataFrame to search.

    Returns:
        The first matching row, or ``None`` when *df* is empty, nothing
        matches, or the search fails.
    """
    if df.empty:
        print("Dataset is empty, cannot search for drugs")
        return None

    # Normalize drug names for matching
    drug1 = clean_drug_name(drug1)
    drug2 = clean_drug_name(drug2)

    print(f"Checking for drugs in dataset: '{drug1}', '{drug2}'")

    def _normalized(column):
        # astype(str) lets .str work on non-string columns; .where() then
        # restores missing cells to NaN so they never compare equal to a name
        # (without it NaN would become the literal text "nan").
        series = df[column]
        return series.astype(str).str.lower().str.strip().where(series.notna())

    try:
        # First try with normalized columns
        if 'Drug 1_normalized' in df.columns and 'Drug 2_normalized' in df.columns:
            first = _normalized('Drug 1_normalized')
            second = _normalized('Drug 2_normalized')

            # Given order first, then the reverse combination.
            drug_data = pd.concat([
                df[(first == drug1) & (second == drug2)],
                df[(first == drug2) & (second == drug1)],
            ])
        else:
            # Try with regular Drug1/Drug2 columns if normalized not available
            possible_column_pairs = [
                ('Drug1', 'Drug2'),
                ('Drug 1', 'Drug 2'),
                ('drug1', 'drug2'),
                ('drug_1', 'drug_2'),
            ]

            drug_data = pd.DataFrame()  # Initialize as empty

            for col1, col2 in possible_column_pairs:
                if col1 not in df.columns or col2 not in df.columns:
                    continue
                # Normalize each column once per pair rather than per comparison.
                first = _normalized(col1)
                second = _normalized(col2)
                matches = df[
                    ((first == drug1) & (second == drug2)) |
                    ((first == drug2) & (second == drug1))
                ]
                if not matches.empty:
                    drug_data = matches
                    break

        if drug_data.empty:
            print(f"Drugs '{drug1}' and '{drug2}' not found in the dataset.")
            return None

        print(f"Found drugs '{drug1}' and '{drug2}' in the dataset!")
        return drug_data.iloc[0]  # Returns the first match

    except Exception as e:  # best-effort lookup: callers fall back on None
        print(f"Error searching for drugs in dataset: {e}")
        return None
171
+
172
def predict_severity(drug1, drug2):
    """Predict the interaction severity for two drug names.

    Steps:
        1. Clean both names; ask for input again if either is empty.
        2. Look the pair up in the dataset. If it is absent, validate each
           name individually (``validate_drug_input`` first, then
           ``validate_drug_via_api`` when the local check is inconclusive)
           and stop with an explanatory message if either name is rejected.
        3. Classify the sentence ``"<drug1> interacts with <drug2>"`` with the
           sequence-classification model and report label and confidence.

    The classifier only receives that sentence. Earlier code also assembled a
    SMILES/pharmacodynamics/toxicity dictionary, but it was never passed to
    the model and relied on a ``get_drug_features_from_api`` helper that this
    module does not define, so any pair missing from the dataset crashed with
    ``NameError``. That unused work has been removed.

    Args:
        drug1: First drug name entered by the user.
        drug2: Second drug name entered by the user.

    Returns:
        A human-readable string: the prediction with its confidence and data
        source, or a message describing why no prediction was made.
    """
    # Clean input before processing; cleaning first means whitespace-only
    # text is treated the same as an empty field.
    drug1 = clean_drug_name(drug1)
    drug2 = clean_drug_name(drug2)

    if not drug1 or not drug2:
        return "Please enter both drugs to predict interaction severity."

    print(f"Processing request for drugs: '{drug1}' and '{drug2}'")

    # A pair that exists in the dataset needs no further validation.
    drug_data = get_drug_features_from_dataset(drug1, drug2, df)

    if drug_data is not None:
        print("Found drugs in dataset, bypassing validation")
    else:
        print("Drugs not found in dataset, validating through other means")

        invalid_drugs = []
        for drug_name in (drug1, drug2):
            # Try dataset validation first (individual drug)
            is_valid, message = validate_drug_input(drug_name)

            # None means "plausible but unknown locally": ask the API.
            if is_valid is None:
                is_valid, message = validate_drug_via_api(drug_name)

            if not is_valid:
                invalid_drugs.append((drug_name, message))

        # If either drug failed validation, return error
        if invalid_drugs:
            invalid_names = ", ".join([f"'{name}' ({msg})" for name, msg in invalid_drugs])
            return f"Invalid drug name(s): {invalid_names}. Please enter valid drug names."

    # Create interaction description
    interaction_description = f"{drug1} interacts with {drug2}"

    # Tokenize the input for the model
    inputs = tokenizer(interaction_description, return_tensors="pt", padding=True, truncation=True, max_length=128)

    # Send the inputs to wherever the model's weights actually are. Choosing
    # CUDA independently of the model fails when the model is on the CPU.
    device = next(model.parameters()).device
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    try:
        # Run the model to get predictions
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)

        # Get the predicted class
        prediction = torch.argmax(outputs.logits, dim=1).item()

        # Map the predicted class index to the severity label using label encoder if available
        if hasattr(label_encoder, 'classes_'):
            severity_label = label_encoder.classes_[prediction]
        else:
            # Fallback labels if encoder doesn't work
            severity_labels = ["No interaction", "Mild", "Moderate", "Severe"]
            severity_label = severity_labels[prediction]

        # Calculate confidence score
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
        confidence = probabilities[0][prediction].item() * 100

        result = f"Predicted interaction severity: {severity_label} (Confidence: {confidence:.1f}%)"

        # Add source information.
        # NOTE(review): these labels only record whether the pair was found in
        # the dataset; no dataset or PubChem features are fed to the model.
        # Consider rewording them.
        if drug_data is not None:
            result += "\nData source: Features from dataset"
        else:
            result += "\nData source: Features from PubChem API"

        return result

    except Exception as e:  # UI boundary: show the failure instead of crashing
        print(f"Error during prediction: {e}")
        return f"Error making prediction: {e}"
325
+
326
# Gradio Interface: two free-text drug names in, one text result out.
interface = gr.Interface(
    title="Drug Interaction Severity Predictor",
    description="Enter two drug names to predict the severity of their interaction.",
    fn=predict_severity,
    inputs=[
        gr.Textbox(label="Drug 1 (e.g., Aspirin)", placeholder="Enter first drug name"),
        gr.Textbox(label="Drug 2 (e.g., Warfarin)", placeholder="Enter second drug name"),
    ],
    outputs=gr.Textbox(label="Prediction Result"),
    # Example pairs offered by the UI.
    examples=[
        ["Aspirin", "Warfarin"],
        ["Ibuprofen", "Naproxen"],
        ["Hydralazine", "Amphetamine"],
    ],
)

# Launch the interface only when executed as a script, not on import.
if __name__ == "__main__":
    interface.launch(debug=True)