Fredaaaaaa committed on
Commit
f21c58d
·
verified ·
1 Parent(s): c2c62e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +285 -1
app.py CHANGED
@@ -1,3 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Updated prediction function with improved confidence handling
2
  def predict_severity(drug1, drug2):
3
  if not drug1 or not drug2:
@@ -247,7 +531,7 @@ def predict_severity(drug1, drug2):
247
  print(f"Error during prediction: {e}")
248
  return f"Error making prediction: {e}"
249
 
250
- # Gradio Interface
251
  interface = gr.Interface(
252
  fn=predict_severity,
253
  inputs=[
 
1
+ import pickle
2
+ import requests
3
+ from huggingface_hub import hf_hub_download
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ import torch
6
+ import gradio as gr
7
+ import pandas as pd
8
+ import re
9
+ from sklearn.utils.class_weight import compute_class_weight
10
+ import numpy as np
11
+
12
+ # ✅ Device setup
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ print(f"Using device: {device}")
15
+
16
+ # Download label encoder from Hugging Face Hub
17
+ label_encoder_path = hf_hub_download(repo_id="Fredaaaaaa/hybrid_model", filename="label_encoder.pkl")
18
+ with open(label_encoder_path, 'rb') as f:
19
+ label_encoder = pickle.load(f)
20
+
21
+ # Load model and tokenizer
22
+ model_name = "Fredaaaaaa/hybrid_model"
23
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
24
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
25
+ model.to(device) # Move model to appropriate device
26
+ model.eval()
27
+
28
+ # Download the dataset from Hugging Face Hub
29
+ dataset_path = hf_hub_download(repo_id="Fredaaaaaa/hybrid_model", filename="labeled_severity.csv")
30
+
31
+ # Load the dataset with appropriate encoding
32
+ df = pd.read_csv(dataset_path, encoding='ISO-8859-1')
33
+ print(f"Dataset loaded successfully! Shape: {df.shape}")
34
+
35
+ # Check the columns and display first few rows for debugging
36
+ print(df.columns)
37
+ print(df.head())
38
+
39
+ # Get unique severity classes from the dataset
40
+ unique_classes = df['severity'].unique()
41
+ print(f"Unique severity classes in dataset: {unique_classes}")
42
+
43
+ # Calculate class weights to handle imbalanced classes
44
+ # Use the unique classes from the dataset for the `classes` parameter
45
+ class_weights = compute_class_weight('balanced', classes=np.unique(unique_classes), y=df['severity'])
46
+ class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
47
+ loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
48
+
49
+ # Extract unique drug names from the dataset to create a list of known drugs
50
+ all_drugs = set()
51
+ # Check the possible column names and add drugs to our set
52
+ for col in ['Drug1', 'Drug 1', 'drug1', 'drug_1', 'Drug 1_normalized']:
53
+ if col in df.columns:
54
+ # Convert to strings, clean and add to set
55
+ all_drugs.update(df[col].astype(str).str.lower().str.strip().tolist())
56
+ for col in ['Drug2', 'Drug 2', 'drug2', 'drug_2', 'Drug 2_normalized']:
57
+ if col in df.columns:
58
+ # Convert to strings, clean and add to set
59
+ all_drugs.update(df[col].astype(str).str.lower().str.strip().tolist())
60
+
61
+ # Remove any empty strings or NaN values
62
+ all_drugs = {drug for drug in all_drugs if drug and drug != 'nan'}
63
+ print(f"Loaded {len(all_drugs)} unique drug names from dataset")
64
+
65
+ # Function to properly clean drug names
66
+ def clean_drug_name(drug_name):
67
+ if not drug_name:
68
+ return ""
69
+ # Remove extra whitespace and standardize to lowercase
70
+ return re.sub(r'\s+', ' ', drug_name.strip().lower())
71
+
72
+ # Function to validate if input is a legitimate drug name
73
+ def validate_drug_input(drug_name):
74
+ # Clean the input
75
+ drug_name = clean_drug_name(drug_name)
76
+
77
+ if not drug_name or len(drug_name) <= 1:
78
+ return False, "Drug name is too short"
79
+
80
+ # Check if it's just a single letter or number
81
+ if len(drug_name) == 1 or drug_name.isdigit():
82
+ return False, "Not a valid drug name"
83
+
84
+ # Check if it contains weird characters
85
+ if not re.match(r'^[a-zA-Z0-9\s\-\+]+$', drug_name):
86
+ return False, "Drug name contains invalid characters"
87
+
88
+ # Check if it's in our known drug list
89
+ if drug_name in all_drugs:
90
+ return True, "Drug found in dataset"
91
+
92
+ # If we have a small drug list or need to be more forgiving, we can try fuzzy matching
93
+ for known_drug in all_drugs:
94
+ if drug_name in known_drug or known_drug in drug_name:
95
+ return True, f"Drug found in dataset (matched with '{known_drug}')"
96
+
97
+ # If not in dataset, we'll try the API validation
98
+ return None, "Drug not in dataset, needs API validation"
99
+
100
+ def validate_drug_via_api(drug_name):
101
+ """Validate a drug name using PubChem API"""
102
+ try:
103
+ # Clean the input
104
+ drug_name = clean_drug_name(drug_name)
105
+
106
+ # Use PubChem API to search for the drug
107
+ search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{drug_name}/cids/JSON"
108
+ response = requests.get(search_url, timeout=10)
109
+
110
+ if response.status_code == 200:
111
+ data = response.json()
112
+ # Check if we got a valid CID (PubChem Compound ID)
113
+ if 'IdentifierList' in data and 'CID' in data['IdentifierList']:
114
+ return True, f"Drug validated via PubChem API (CID: {data['IdentifierList']['CID'][0]})"
115
+ else:
116
+ return False, "Drug not found in PubChem database"
117
+ else:
118
+ # Try a fallback for compounds with special characters
119
+ fallback_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{requests.utils.quote(drug_name)}/cids/JSON"
120
+ fallback_response = requests.get(fallback_url, timeout=10)
121
+
122
+ if fallback_response.status_code == 200:
123
+ data = fallback_response.json()
124
+ if 'IdentifierList' in data and 'CID' in data['IdentifierList']:
125
+ return True, f"Drug validated via PubChem API (CID: {data['IdentifierList']['CID'][0]})"
126
+
127
+ return False, f"Invalid drug name: API returned status {response.status_code}"
128
+
129
+ except Exception as e:
130
+ print(f"Error validating drug via API: {e}")
131
+ # Be more lenient if API validation fails
132
+ return True, "API validation failed, assuming valid drug"
133
+
134
+ def get_drug_features_from_api(drug_name):
135
+ """Get drug features from PubChem API"""
136
+ try:
137
+ # Clean the input
138
+ drug_name = clean_drug_name(drug_name)
139
+
140
+ # First get the CID from PubChem
141
+ search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{drug_name}/cids/JSON"
142
+ response = requests.get(search_url, timeout=10)
143
+
144
+ if response.status_code != 200:
145
+ # Try URL encoding for drugs with special characters
146
+ search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{requests.utils.quote(drug_name)}/cids/JSON"
147
+ response = requests.get(search_url, timeout=10)
148
+
149
+ if response.status_code != 200:
150
+ print(f"Drug {drug_name} not found in PubChem")
151
+ return None
152
+
153
+ # Extract the CID
154
+ data = response.json()
155
+ if 'IdentifierList' not in data or 'CID' not in data['IdentifierList']:
156
+ print(f"No CID found for drug {drug_name}")
157
+ return None
158
+
159
+ cid = data['IdentifierList']['CID'][0]
160
+
161
+ # Get the SMILES notation
162
+ smiles_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/property/CanonicalSMILES/JSON"
163
+ smiles_response = requests.get(smiles_url, timeout=10)
164
+
165
+ # Initialize features dictionary
166
+ features = {
167
+ 'SMILES': 'No data',
168
+ 'pharmacodynamics': 'No data',
169
+ 'toxicity': 'No data'
170
+ }
171
+
172
+ # Extract SMILES if available
173
+ if smiles_response.status_code == 200:
174
+ smiles_data = smiles_response.json()
175
+ if 'PropertyTable' in smiles_data and 'Properties' in smiles_data['PropertyTable']:
176
+ properties = smiles_data['PropertyTable']['Properties']
177
+ if properties and 'CanonicalSMILES' in properties[0]:
178
+ features['SMILES'] = properties[0]['CanonicalSMILES']
179
+
180
+ # Get pharmacological information (we'll use this for both pharmacodynamics and toxicity)
181
+ info_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/data/compound/{cid}/JSON"
182
+ info_response = requests.get(info_url, timeout=15) # Increased timeout
183
+
184
+ if info_response.status_code == 200:
185
+ info_data = info_response.json()
186
+ if 'Record' in info_data and 'Section' in info_data['Record']:
187
+ # Search through sections for pharmacology information
188
+ for section in info_data['Record']['Section']:
189
+ if 'TOCHeading' in section:
190
+ # Look for Pharmacology section
191
+ if section['TOCHeading'] == 'Pharmacology':
192
+ if 'Section' in section:
193
+ for subsection in section['Section']:
194
+ if 'TOCHeading' in subsection:
195
+ # Extract pharmacodynamics
196
+ if subsection['TOCHeading'] == 'Mechanism of Action':
197
+ if 'Information' in subsection:
198
+ for info in subsection['Information']:
199
+ if 'Value' in info and 'StringWithMarkup' in info['Value']:
200
+ for text in info['Value']['StringWithMarkup']:
201
+ if 'String' in text:
202
+ features['pharmacodynamics'] = text['String'][:500] # Limit to 500 chars
203
+ break
204
+
205
+ # Look for toxicity information
206
+ if section['TOCHeading'] == 'Toxicity':
207
+ if 'Information' in section:
208
+ for info in section['Information']:
209
+ if 'Value' in info and 'StringWithMarkup' in info['Value']:
210
+ for text in info['Value']['StringWithMarkup']:
211
+ if 'String' in text:
212
+ features['toxicity'] = text['String'][:500] # Limit to 500 chars
213
+ break
214
+
215
+ return features
216
+
217
+ except Exception as e:
218
+ print(f"Error getting drug features from API: {e}")
219
+ return None
220
+
221
+ # Function to check if drugs are in the dataset
222
+ def get_drug_features_from_dataset(drug1, drug2, df):
223
+ if df.empty:
224
+ print("Dataset is empty, cannot search for drugs")
225
+ return None
226
+
227
+ # Normalize drug names for matching
228
+ drug1 = clean_drug_name(drug1)
229
+ drug2 = clean_drug_name(drug2)
230
+
231
+ print(f"Checking for drugs in dataset: '{drug1}', '{drug2}'")
232
+
233
+ try:
234
+ # First try with normalized columns
235
+ if 'Drug 1_normalized' in df.columns and 'Drug 2_normalized' in df.columns:
236
+ # Apply cleaning function to dataframe columns for comparison
237
+ drug_data = df[
238
+ (df['Drug 1_normalized'].str.lower().str.strip() == drug1) &
239
+ (df['Drug 2_normalized'].str.lower().str.strip() == drug2)
240
+ ]
241
+
242
+ # Also check the reverse combination
243
+ reversed_drug_data = df[
244
+ (df['Drug 1_normalized'].str.lower().str.strip() == drug2) &
245
+ (df['Drug 2_normalized'].str.lower().str.strip() == drug1)
246
+ ]
247
+
248
+ # Combine the results
249
+ drug_data = pd.concat([drug_data, reversed_drug_data])
250
+ else:
251
+ # Try with regular Drug1/Drug2 columns if normalized not available
252
+ possible_column_pairs = [
253
+ ('Drug1', 'Drug2'),
254
+ ('Drug 1', 'Drug 2'),
255
+ ('drug1', 'drug2'),
256
+ ('drug_1', 'drug_2')
257
+ ]
258
+
259
+ drug_data = pd.DataFrame() # Initialize as empty
260
+
261
+ for col1, col2 in possible_column_pairs:
262
+ if col1 in df.columns and col2 in df.columns:
263
+ # Clean the strings in the dataframe columns for comparison
264
+ matches = df[
265
+ ((df[col1].astype(str).str.lower().str.strip() == drug1) &
266
+ (df[col2].astype(str).str.lower().str.strip() == drug2)) |
267
+ ((df[col1].astype(str).str.lower().str.strip() == drug2) &
268
+ (df[col2].astype(str).str.lower().str.strip() == drug1))
269
+ ]
270
+ if not matches.empty:
271
+ drug_data = matches
272
+ break
273
+
274
+ if not drug_data.empty:
275
+ print(f"Found drugs '{drug1}' and '{drug2}' in the dataset!")
276
+ return drug_data.iloc[0] # Returns the first match
277
+ else:
278
+ print(f"Drugs '{drug1}' and '{drug2}' not found in the dataset.")
279
+ return None
280
+
281
+ except Exception as e:
282
+ print(f"Error searching for drugs in dataset: {e}")
283
+ return None
284
+
285
  # Updated prediction function with improved confidence handling
286
  def predict_severity(drug1, drug2):
287
  if not drug1 or not drug2:
 
531
  print(f"Error during prediction: {e}")
532
  return f"Error making prediction: {e}"
533
 
534
+ # Gradio Interface
535
  interface = gr.Interface(
536
  fn=predict_severity,
537
  inputs=[