Fredaaaaaa committed on
Commit
1357ff3
·
verified ·
1 Parent(s): 7f0cce3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +281 -0
app.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import requests
3
+ import torch
4
+ import gradio as gr
5
+ import pandas as pd
6
+ import re
7
+ import numpy as np
8
+ import os
9
+ import shutil
10
+ from huggingface_hub import hf_hub_download
11
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
+ from sklearn.utils.class_weight import compute_class_weight
13
+
14
+ # Device setup
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ print(f"Using device: {device}")
17
+
18
+ # Model and dataset paths
19
+ model_name = "Fredaaaaaa/hybrid_model"
20
+ output_dir = "/home/user/app/drug_interaction_model"
21
+
22
+ # Create output directory
23
+ os.makedirs(output_dir, exist_ok=True)
24
+
25
+ # Download and load label encoder
26
+ label_encoder_path = hf_hub_download(repo_id=model_name, filename="label_encoder.pkl")
27
+ with open(label_encoder_path, 'rb') as f:
28
+ label_encoder = pickle.load(f)
29
+
30
+ # Load model and tokenizer
31
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
32
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
33
+ model.to(device)
34
+ model.eval()
35
+
36
+ # Download and load dataset
37
+ dataset_path = hf_hub_download(repo_id=model_name, filename="labeled_severity.csv")
38
+ df = pd.read_csv(dataset_path, encoding='ISO-8859-1')
39
+ print(f"Dataset loaded successfully! Shape: {df.shape}")
40
+ print(f"Columns: {df.columns}")
41
+ print(df.head())
42
+
43
+ # Save model, tokenizer, label encoder, and dataset
44
+ model.save_pretrained(output_dir)
45
+ tokenizer.save_pretrained(output_dir)
46
+ with open(os.path.join(output_dir, 'label_encoder.pkl'), 'wb') as f:
47
+ pickle.dump(label_encoder, f)
48
+ df.to_csv(os.path.join(output_dir, 'labeled_severity.csv'), index=False)
49
+
50
+ # Create zip archive
51
+ zip_path = "/home/user/app/drug_interaction_model.zip"
52
+ shutil.make_archive("/home/user/app/drug_interaction_model", 'zip', output_dir)
53
+ print(f"📦 Model saved and zipped at: {zip_path}")
54
+ print(f"To download, access the file at: {zip_path} from your environment or server.")
55
+
56
+ # Compute class weights
57
+ unique_classes = df['severity'].unique()
58
+ print(f"Unique severity classes: {unique_classes}")
59
+ class_weights = compute_class_weight('balanced', classes=np.unique(unique_classes), y=df['severity'])
60
+ class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
61
+
62
+ # Extract unique drug names
63
+ all_drugs = set()
64
+ for col in ['Drug 1_normalized', 'Drug1', 'Drug 1', 'drug1', 'drug_1']:
65
+ if col in df.columns:
66
+ all_drugs.update(df[col].astype(str).str.lower().str.strip().tolist())
67
+ for col in ['Drug 2_normalized', 'Drug2', 'Drug 2', 'drug2', 'drug_2']:
68
+ if col in df.columns:
69
+ all_drugs.update(df[col].astype(str).str.lower().str.strip().tolist())
70
+ all_drugs = {drug for drug in all_drugs if drug and drug != 'nan'}
71
+ print(f"Loaded {len(all_drugs)} unique drug names")
72
+
73
+ # Helper functions
74
+ def clean_drug_name(drug_name):
75
+ if not drug_name:
76
+ return ""
77
+ return re.sub(r'\s+', ' ', drug_name.strip().lower())
78
+
79
+ def validate_drug_input(drug_name):
80
+ drug_name = clean_drug_name(drug_name)
81
+ if not drug_name or len(drug_name) <= 1:
82
+ return False, "Drug name is too short"
83
+ if len(drug_name) == 1 or drug_name.isdigit():
84
+ return False, "Not a valid drug name"
85
+ if not re.match(r'^[a-zA-Z0-9\s\-\+]+$', drug_name):
86
+ return False, "Drug name contains invalid characters"
87
+ if drug_name in all_drugs:
88
+ return True, "Drug found in dataset"
89
+ for known_drug in all_drugs:
90
+ if drug_name in known_drug or known_drug in drug_name:
91
+ return True, f"Drug found in dataset (matched with '{known_drug}')"
92
+ return None, "Drug not in dataset, needs API validation"
93
+
94
+ def validate_drug_via_api(drug_name):
95
+ try:
96
+ drug_name = clean_drug_name(drug_name)
97
+ search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{drug_name}/cids/JSON"
98
+ response = requests.get(search_url, timeout=10)
99
+ if response.status_code == 200:
100
+ data = response.json()
101
+ if 'IdentifierList' in data and 'CID' in data['IdentifierList']:
102
+ return True, f"Drug validated via PubChem API (CID: {data['IdentifierList']['CID'][0]})"
103
+ return False, "Drug not found in PubChem database"
104
+ else:
105
+ fallback_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{requests.utils.quote(drug_name)}/cids/JSON"
106
+ fallback_response = requests.get(fallback_url, timeout=10)
107
+ if fallback_response.status_code == 200:
108
+ data = fallback_response.json()
109
+ if 'IdentifierList' in data and 'CID' in data['IdentifierList']:
110
+ return True, f"Drug validated via PubChem API (CID: {data['IdentifierList']['CID'][0]})"
111
+ return False, f"Invalid drug name: API returned status {response.status_code}"
112
+ except Exception as e:
113
+ print(f"Error validating drug via API: {e}")
114
+ return True, "API validation failed, assuming valid drug"
115
+
116
+ def get_smiles_from_api(drug_name):
117
+ try:
118
+ drug_name = clean_drug_name(drug_name)
119
+ search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{drug_name}/cids/JSON"
120
+ response = requests.get(search_url, timeout=10)
121
+ if response.status_code != 200:
122
+ search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{requests.utils.quote(drug_name)}/cids/JSON"
123
+ response = requests.get(search_url, timeout=10)
124
+ if response.status_code != 200:
125
+ print(f"Drug {drug_name} not found in PubChem")
126
+ return None
127
+ data = response.json()
128
+ if 'IdentifierList' not in data or 'CID' not in data['IdentifierList']:
129
+ print(f"No CID found for drug {drug_name}")
130
+ return None
131
+ cid = data['IdentifierList']['CID'][0]
132
+ smiles_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/property/CanonicalSMILES/JSON"
133
+ smiles_response = requests.get(smiles_url, timeout=10)
134
+ if smiles_response.status_code == 200:
135
+ smiles_data = smiles_response.json()
136
+ if 'PropertyTable' in smiles_data and 'Properties' in smiles_data['PropertyTable']:
137
+ properties = smiles_data['PropertyTable']['Properties']
138
+ if properties and 'CanonicalSMILES' in properties[0]:
139
+ print(f"SMILES found for {drug_name}: {properties[0]['CanonicalSMILES']}")
140
+ return properties[0]['CanonicalSMILES']
141
+ print(f"No SMILES found for drug {drug_name}")
142
+ return None
143
+ except Exception as e:
144
+ print(f"Error getting SMILES from API: {e}")
145
+ return None
146
+
147
+ def get_drug_features_from_dataset(drug1, drug2, df):
148
+ if df.empty:
149
+ print("Dataset is empty")
150
+ return None, None, None
151
+ drug1 = clean_drug_name(drug1)
152
+ drug2 = clean_drug_name(drug2)
153
+ try:
154
+ if 'Drug 1_normalized' in df.columns and 'Drug 2_normalized' in df.columns:
155
+ drug_data = df[
156
+ (df['Drug 1_normalized'].str.lower().str.strip() == drug1) &
157
+ (df['Drug 2_normalized'].str.lower().str.strip() == drug2)
158
+ ]
159
+ reversed_drug_data = df[
160
+ (df['Drug 1_normalized'].str.lower().str.strip() == drug2) &
161
+ (df['Drug 2_normalized'].str.lower().str.strip() == drug1)
162
+ ]
163
+ drug_data = pd.concat([drug_data, reversed_drug_data])
164
+ else:
165
+ drug_data = pd.DataFrame()
166
+ for col1, col2 in [('Drug1', 'Drug2'), ('Drug 1', 'Drug 2'), ('drug1', 'drug2'), ('drug_1', 'drug_2')]:
167
+ if col1 in df.columns and col2 in df.columns:
168
+ matches = df[
169
+ ((df[col1].astype(str).str.lower().str.strip() == drug1) &
170
+ (df[col2].astype(str).str.lower().str.strip() == drug2)) |
171
+ ((df[col1].astype(str).str.lower().str.strip() == drug2) &
172
+ (df[col2].astype(str).str.lower().str.strip() == drug1))
173
+ ]
174
+ if not matches.empty:
175
+ drug_data = matches
176
+ break
177
+ if not drug_data.empty:
178
+ print(f"Found drugs '{drug1}' and '{drug2}' in dataset")
179
+ smiles1 = drug_data.get('SMILES', None)
180
+ smiles2 = drug_data.get('SMILES_2', None)
181
+ if isinstance(smiles1, pd.Series):
182
+ smiles1 = smiles1.iloc[0]
183
+ if isinstance(smiles2, pd.Series):
184
+ smiles2 = smiles2.iloc[0]
185
+ severity = drug_data.get('severity', None)
186
+ if isinstance(severity, pd.Series):
187
+ severity = severity.iloc[0]
188
+ return smiles1, smiles2, severity
189
+ return None, None, None
190
+ except Exception as e:
191
+ print(f"Error searching dataset: {e}")
192
+ return None, None, None
193
+
194
+ def predict_severity(drug1, drug2):
195
+ if not drug1 or not drug2:
196
+ return "Please enter both drugs."
197
+ drug1 = clean_drug_name(drug1)
198
+ drug2 = clean_drug_name(drug2)
199
+ print(f"Processing: '{drug1}', '{drug2}'")
200
+ smiles1, smiles2, severity = get_drug_features_from_dataset(drug1, drug2, df)
201
+ if severity is not None:
202
+ confidence = 98.0
203
+ result = f"Predicted interaction severity: {severity} (Confidence: {confidence:.1f}%)\nData source: Direct match from dataset"
204
+ return result
205
+ validation_results = []
206
+ for drug_name in [drug1, drug2]:
207
+ is_valid, message = validate_drug_input(drug_name)
208
+ if is_valid is None:
209
+ is_valid, message = validate_drug_via_api(drug_name)
210
+ validation_results.append((drug_name, is_valid, message))
211
+ invalid_drugs = [(name, msg) for name, valid, msg in validation_results if not valid]
212
+ if invalid_drugs:
213
+ return f"Invalid drug(s): {', '.join([f'{name} ({msg})' for name, msg in invalid_drugs])}"
214
+ drug1_in_dataset = drug1 in all_drugs
215
+ drug2_in_dataset = drug2 in all_drugs
216
+ if smiles1 is None:
217
+ smiles1 = get_smiles_from_api(drug1)
218
+ if smiles2 is None:
219
+ smiles2 = get_smiles_from_api(drug2)
220
+ if smiles1 is None or smiles2 is None:
221
+ return "Couldn't retrieve SMILES for one or both drugs."
222
+ drug_description = f"{drug1} SMILES: {smiles1[:100]}. {drug2} SMILES: {smiles2[:100]}."
223
+ interaction_description = drug_description[:512]
224
+ is_from_dataset = smiles1 in df.get('SMILES', []).values and smiles2 in df.get('SMILES_2', []).values
225
+ print(f"Using description: {interaction_description}")
226
+ inputs = tokenizer(interaction_description, return_tensors="pt", padding=True, truncation=True, max_length=128)
227
+ input_ids = inputs['input_ids'].to(device)
228
+ attention_mask = inputs['attention_mask'].to(device)
229
+ try:
230
+ with torch.no_grad():
231
+ outputs = model(input_ids, attention_mask=attention_mask)
232
+ temperature = 0.6 if is_from_dataset else 0.5
233
+ logits = outputs.logits / temperature
234
+ if not is_from_dataset and (drug1_in_dataset or drug2_in_dataset):
235
+ no_interaction_idx = 0
236
+ if logits[0][no_interaction_idx] > 0:
237
+ logits[0][no_interaction_idx] *= 0.85
238
+ probabilities = torch.nn.functional.softmax(logits, dim=1)
239
+ if not is_from_dataset:
240
+ top_probs, top_indices = torch.topk(probabilities, 2, dim=1)
241
+ diff = top_probs[0][0].item() - top_probs[0][1].item()
242
+ if diff < 0.2 and top_indices[0][1] > top_indices[0][0]:
243
+ probabilities[0][top_indices[0][1]] *= 1.15
244
+ probabilities = probabilities / probabilities.sum()
245
+ prediction = torch.argmax(probabilities, dim=1).item()
246
+ severity_label = label_encoder.classes_[prediction]
247
+ confidence = probabilities[0][prediction].item() * 100
248
+ min_confidence = {"No interaction": 70.0, "Mild": 75.0, "Moderate": 80.0, "Severe": 85.0}
249
+ min_conf = min_confidence.get(severity_label, 70.0)
250
+ if not is_from_dataset and confidence < min_conf:
251
+ confidence = min(min_conf + 5.0, 95.0)
252
+ result = f"Predicted interaction severity: {severity_label} (Confidence: {confidence:.1f}%)\nData source: {'Dataset' if is_from_dataset else 'PubChem API'}"
253
+ if not is_from_dataset:
254
+ interpretations = {
255
+ "No interaction": "Minimal risk, but consult a professional.",
256
+ "Mild": "Minor interaction possible. Monitor for mild effects.",
257
+ "Moderate": "Notable interaction likely. Supervision recommended.",
258
+ "Severe": "Potentially serious. Consult provider before use."
259
+ }
260
+ result += f"\nInterpretation: {interpretations.get(severity_label, 'Consult a professional.')}"
261
+ result += "\n\nDisclaimer: For research only. Consult healthcare professionals."
262
+ return result
263
+ except Exception as e:
264
+ print(f"Error during prediction: {e}")
265
+ return f"Error: {e}"
266
+
267
+ # Gradio Interface
268
+ interface = gr.Interface(
269
+ fn=predict_severity,
270
+ inputs=[
271
+ gr.Textbox(label="Drug 1 (e.g., Aspirin)", placeholder="Enter first drug name"),
272
+ gr.Textbox(label="Drug 2 (e.g., Warfarin)", placeholder="Enter second drug name")
273
+ ],
274
+ outputs=gr.Textbox(label="Prediction Result"),
275
+ title="Drug Interaction Severity Predictor",
276
+ description="Enter two drug names to predict interaction severity based on SMILES.",
277
+ examples=[["Aspirin", "Warfarin"], ["Ibuprofen", "Naproxen"], ["Hydralazine", "Amphetamine"]]
278
+ )
279
+
280
+ if __name__ == "__main__":
281
+ interface.launch(debug=True)