Fredaaaaaa commited on
Commit
9b06fe2
Β·
verified Β·
1 Parent(s): f422844

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -28
app.py CHANGED
@@ -1,15 +1,137 @@
1
  import gradio as gr
2
  import requests
3
  import time
 
 
 
 
4
 
5
- # Try to import the predictor
6
  try:
7
- from inference import predictor, MODEL_LOADED
8
- print("βœ… Inference module imported successfully")
9
- except ImportError as e:
10
- print(f"❌ Failed to import inference: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  MODEL_LOADED = False
12
- predictor = None
13
 
14
  def fetch_pubchem_data(drug_name):
15
  """Fetch drug data from PubChem by name"""
@@ -26,7 +148,7 @@ def fetch_pubchem_data(drug_name):
26
 
27
  cid = search_response.json()['IdentifierList']['CID'][0]
28
 
29
- # Fetch basic compound data
30
  compound_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/property/CanonicalSMILES,MolecularWeight,IUPACName/JSON"
31
  compound_response = requests.get(compound_url, timeout=10)
32
 
@@ -51,10 +173,12 @@ def generate_interaction_description(drug1_data, drug2_data):
51
  if mw1 and mw2:
52
  mw_diff = abs(mw1 - mw2)
53
  if mw_diff > 300:
54
- descriptions.append("Significant molecular size difference")
 
 
55
 
56
  if not descriptions:
57
- descriptions.append("Potential drug interaction")
58
 
59
  return ". ".join(descriptions) + ". Clinical evaluation recommended."
60
  except:
@@ -63,9 +187,6 @@ def generate_interaction_description(drug1_data, drug2_data):
63
  def predict_ddi(drug1_name, drug2_name):
64
  """Main prediction function"""
65
  try:
66
- if not MODEL_LOADED or predictor is None:
67
- return "Model not loaded. Please check requirements.txt", "", "", ""
68
-
69
  if not drug1_name or not drug2_name:
70
  return "Please enter both drug names", "", "", ""
71
 
@@ -84,36 +205,67 @@ def predict_ddi(drug1_name, drug2_name):
84
 
85
  # Prepare output
86
  drug_info = f"""
87
- **{drug1_name}**: MW={drug1_data.get('MolecularWeight', 'N/A')} g/mol
88
- **{drug2_name}**: MW={drug2_data.get('MolecularWeight', 'N/A')} g/mol
 
 
 
 
 
89
  """
90
 
91
  prediction_output = f"""
92
- **Prediction:** {result['prediction']}
 
 
93
  **Confidence:** {result['confidence']:.1%}
94
- **Description:** {interaction_description}
 
 
 
 
 
95
  """
96
 
97
- return prediction_output, drug_info, interaction_description, "βœ… Success"
 
 
98
 
99
  except Exception as e:
100
  return f"Error: {str(e)}", "", "", ""
101
 
102
- # Create simple interface
103
- with gr.Blocks(title="Drug Interaction Predictor") as demo:
104
- gr.Markdown("# πŸ’Š Drug Interaction Predictor")
105
- gr.Markdown("Model: Fredaaaaaa/drug_interaction_severity")
106
 
107
  with gr.Row():
108
- drug1 = gr.Textbox(label="Drug 1", placeholder="e.g., Aspirin")
109
- drug2 = gr.Textbox(label="Drug 2", placeholder="e.g., Warfarin")
110
 
111
- predict_btn = gr.Button("Predict", variant="primary")
112
 
113
- output = gr.Markdown("## Results will appear here")
114
- drug_info = gr.Markdown()
115
- interaction_desc = gr.Textbox(label="Generated Description")
116
- status = gr.Textbox(label="Status")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  predict_btn.click(
119
  predict_ddi,
 
1
  import gradio as gr
2
  import requests
3
  import time
4
+ import json
5
+ import joblib
6
+ from huggingface_hub import hf_hub_download
7
+ import numpy as np
8
 
9
+ # Try to import torch with fallback
10
  try:
11
+ import torch
12
+ from transformers import AutoTokenizer, AutoModel
13
+ TORCH_AVAILABLE = True
14
+ except ImportError:
15
+ print("Torch not available, using mock mode")
16
+ TORCH_AVAILABLE = False
17
+
18
+ class MockPredictor:
19
+ """Mock predictor for when torch is not available"""
20
+ def __init__(self):
21
+ self.classes = ["Mild", "Moderate", "No Interaction", "Severe"]
22
+
23
+ def predict(self, text):
24
+ return {
25
+ "prediction": "Moderate",
26
+ "confidence": 0.75,
27
+ "probabilities": {cls: 0.25 for cls in self.classes}
28
+ }
29
+
30
+ # Initialize predictor
31
+ if TORCH_AVAILABLE:
32
+ try:
33
+ # Define model class
34
+ class DrugInteractionClassifier(torch.nn.Module):
35
+ def __init__(self, n_classes, bert_model_name="emilyalsentzer/Bio_ClinicalBERT"):
36
+ super(DrugInteractionClassifier, self).__init__()
37
+ self.bert = AutoModel.from_pretrained(bert_model_name)
38
+ self.classifier = torch.nn.Sequential(
39
+ torch.nn.Linear(self.bert.config.hidden_size, 256),
40
+ torch.nn.ReLU(),
41
+ torch.nn.Dropout(0.3),
42
+ torch.nn.Linear(256, n_classes)
43
+ )
44
+
45
+ def forward(self, input_ids, attention_mask):
46
+ bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
47
+ pooled_output = bert_output[0][:, 0, :]
48
+ return self.classifier(pooled_output)
49
+
50
+ # Load model components
51
+ print("Downloading model files from Fredaaaaaa/drug_interaction_severity...")
52
+
53
+ # Download config
54
+ config_path = hf_hub_download(repo_id="Fredaaaaaa/drug_interaction_severity", filename="config.json")
55
+ with open(config_path, "r") as f:
56
+ config = json.load(f)
57
+
58
+ # Download label encoder
59
+ label_encoder_path = hf_hub_download(repo_id="Fredaaaaaa/drug_interaction_severity", filename="label_encoder.joblib")
60
+ label_encoder = joblib.load(label_encoder_path)
61
+
62
+ # Download model weights
63
+ model_path = hf_hub_download(repo_id="Fredaaaaaa/drug_interaction_severity", filename="pytorch_model.bin")
64
+
65
+ # Load tokenizer
66
+ tokenizer = AutoTokenizer.from_pretrained("Fredaaaaaa/drug_interaction_severity")
67
+
68
+ # Initialize model
69
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
+ model = DrugInteractionClassifier(n_classes=len(label_encoder.classes_))
71
+ model.load_state_dict(torch.load(model_path, map_location=device))
72
+ model.to(device)
73
+ model.eval()
74
+
75
+ print("βœ… Model loaded successfully!")
76
+
77
+ class RealPredictor:
78
+ def __init__(self, model, tokenizer, label_encoder, device, config):
79
+ self.model = model
80
+ self.tokenizer = tokenizer
81
+ self.label_encoder = label_encoder
82
+ self.device = device
83
+ self.config = config
84
+
85
+ def predict(self, text):
86
+ try:
87
+ # Tokenize
88
+ inputs = self.tokenizer(
89
+ text,
90
+ max_length=self.config.get("max_length", 128),
91
+ padding=True,
92
+ truncation=True,
93
+ return_tensors="pt"
94
+ )
95
+
96
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
97
+
98
+ # Predict
99
+ with torch.no_grad():
100
+ outputs = self.model(inputs["input_ids"], inputs["attention_mask"])
101
+ probabilities = torch.softmax(outputs, dim=1)
102
+ confidence, predicted_idx = torch.max(probabilities, dim=1)
103
+
104
+ predicted_label = self.label_encoder.inverse_transform([predicted_idx.item()])[0]
105
+
106
+ # Get all probabilities
107
+ all_probs = {
108
+ self.label_encoder.inverse_transform([i])[0]: prob.item()
109
+ for i, prob in enumerate(probabilities[0])
110
+ }
111
+
112
+ return {
113
+ "prediction": predicted_label,
114
+ "confidence": confidence.item(),
115
+ "probabilities": all_probs
116
+ }
117
+
118
+ except Exception as e:
119
+ return {
120
+ "prediction": f"Error: {str(e)}",
121
+ "confidence": 0.0,
122
+ "probabilities": {label: 0.0 for label in self.label_encoder.classes_}
123
+ }
124
+
125
+ predictor = RealPredictor(model, tokenizer, label_encoder, device, config)
126
+ MODEL_LOADED = True
127
+
128
+ except Exception as e:
129
+ print(f"Error loading real model: {e}")
130
+ predictor = MockPredictor()
131
+ MODEL_LOADED = False
132
+ else:
133
+ predictor = MockPredictor()
134
  MODEL_LOADED = False
 
135
 
136
  def fetch_pubchem_data(drug_name):
137
  """Fetch drug data from PubChem by name"""
 
148
 
149
  cid = search_response.json()['IdentifierList']['CID'][0]
150
 
151
+ # Fetch compound data
152
  compound_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/property/CanonicalSMILES,MolecularWeight,IUPACName/JSON"
153
  compound_response = requests.get(compound_url, timeout=10)
154
 
 
173
  if mw1 and mw2:
174
  mw_diff = abs(mw1 - mw2)
175
  if mw_diff > 300:
176
+ descriptions.append("Significant molecular size difference may affect metabolism")
177
+ elif mw_diff > 100:
178
+ descriptions.append("Moderate molecular size difference")
179
 
180
  if not descriptions:
181
+ descriptions.append("Potential pharmacokinetic interaction")
182
 
183
  return ". ".join(descriptions) + ". Clinical evaluation recommended."
184
  except:
 
187
  def predict_ddi(drug1_name, drug2_name):
188
  """Main prediction function"""
189
  try:
 
 
 
190
  if not drug1_name or not drug2_name:
191
  return "Please enter both drug names", "", "", ""
192
 
 
205
 
206
  # Prepare output
207
  drug_info = f"""
208
+ **{drug1_name}**:
209
+ - Molecular Weight: {drug1_data.get('MolecularWeight', 'N/A')} g/mol
210
+ - IUPAC Name: {drug1_data.get('IUPACName', 'N/A')}
211
+
212
+ **{drug2_name}**:
213
+ - Molecular Weight: {drug2_data.get('MolecularWeight', 'N/A')} g/mol
214
+ - IUPAC Name: {drug2_data.get('IUPACName', 'N/A')}
215
  """
216
 
217
  prediction_output = f"""
218
+ ## πŸ” Prediction Results
219
+
220
+ **Severity:** **{result['prediction']}**
221
  **Confidence:** {result['confidence']:.1%}
222
+
223
+ **Generated Description:**
224
+ {interaction_description}
225
+
226
+ **Probabilities:**
227
+ {', '.join([f'{k}: {v:.1%}' for k, v in result['probabilities'].items()])}
228
  """
229
 
230
+ status = "βœ… Success" if MODEL_LOADED else "⚠️ Using mock data (torch not available)"
231
+
232
+ return prediction_output, drug_info, interaction_description, status
233
 
234
  except Exception as e:
235
  return f"Error: {str(e)}", "", "", ""
236
 
237
+ # Create interface
238
+ with gr.Blocks(title="Drug Interaction Predictor", theme=gr.themes.Soft()) as demo:
239
+ gr.Markdown("# πŸ’Š Drug Interaction Severity Predictor")
240
+ gr.Markdown("Model: [Fredaaaaaa/drug_interaction_severity](https://huggingface.co/Fredaaaaaa/drug_interaction_severity)")
241
 
242
  with gr.Row():
243
+ drug1 = gr.Textbox(label="First Drug", placeholder="e.g., Warfarin", value="Warfarin")
244
+ drug2 = gr.Textbox(label="Second Drug", placeholder="e.g., Aspirin", value="Aspirin")
245
 
246
+ predict_btn = gr.Button("πŸ”¬ Predict Interaction", variant="primary")
247
 
248
+ with gr.Row():
249
+ output = gr.Markdown("## πŸ“Š Results will appear here")
250
+
251
+ with gr.Row():
252
+ drug_info = gr.Markdown("### πŸ’Š Drug Properties")
253
+
254
+ with gr.Row():
255
+ interaction_desc = gr.Textbox(label="πŸ“ Generated Description", interactive=False)
256
+ status = gr.Textbox(label="πŸ”„ Status", interactive=False)
257
+
258
+ # Examples
259
+ gr.Examples(
260
+ examples=[
261
+ ["Warfarin", "Aspirin"],
262
+ ["Simvastatin", "Clarithromycin"],
263
+ ["Digoxin", "Quinine"],
264
+ ["Metformin", "Ibuprofen"]
265
+ ],
266
+ inputs=[drug1, drug2],
267
+ label="πŸ’‘ Example Drug Pairs"
268
+ )
269
 
270
  predict_btn.click(
271
  predict_ddi,