# smilesssssss / app.py — by Fredaaaaaa
# Last commit: 22bcd30 "Update app.py" (verified)
import pickle
import requests
import torch
import gradio as gr
import pandas as pd
import re
import numpy as np
import os
import shutil
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.utils.class_weight import compute_class_weight
from collections import defaultdict
# Log start-up so the output shows when the script began executing.
print("Starting script execution...")
# Helper functions (defined before the module-level setup code below, which calls
# clean_drug_name and get_pubchem_features while building the feature cache)
def clean_drug_name(drug_name):
    """Return a normalized form of ``drug_name`` for comparisons.

    Leading/trailing whitespace is removed, the text is lower-cased and every run
    of internal whitespace becomes a single space. ``None`` or ``""`` yields ``""``.
    """
    if drug_name:
        trimmed = drug_name.strip().lower()
        return re.sub(r'\s+', ' ', trimmed)
    return ""
def validate_drug_input(drug_name):
    """Validate a drug name locally, without any network access.

    The name is normalized with clean_drug_name first. Returns a
    ``(status, message)`` tuple where ``status`` is:

    * ``False`` - the name is shorter than two characters, purely numeric, or
      contains characters other than letters, digits, whitespace, '-' and '+';
    * ``True``  - the name is in ``all_drugs``, or is a substring/superstring of
      an entry in it;
    * ``None``  - the name is well-formed but unknown, so the caller should
      confirm it with PubChem.
    """
    name = clean_drug_name(drug_name)
    if len(name) < 2:
        return False, "Drug name is too short"
    if name.isdigit():
        return False, "Not a valid drug name"
    if re.match(r'^[a-zA-Z0-9\s\-\+]+$', name) is None:
        return False, "Drug name contains invalid characters"
    if name in all_drugs:
        return True, "Drug found in dataset"
    # Loose fallback: accept the first dataset name that contains, or is contained in, the input.
    partial = next((known for known in all_drugs if name in known or known in name), None)
    if partial is not None:
        return True, f"Drug found in dataset (matched with '{partial}')"
    return None, "Drug not in dataset, needs API validation"
def validate_drug_via_api(drug_name):
    """Confirm that PubChem knows ``drug_name`` by resolving it to a compound CID.

    The name is cleaned and looked up as-is. Only if that request does not return
    HTTP 200 is a second attempt made with the name URL-quoted.

    Returns:
        ``(True, message)`` when a CID is found; ``(False, message)`` when PubChem
        has no match. Any exception (network, JSON, ...) is logged and treated as
        valid (``True``) so that PubChem outages do not block predictions.
    """
    try:
        name = clean_drug_name(drug_name)
        lookup_template = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{}/cids/JSON"

        primary = requests.get(lookup_template.format(name), timeout=10)
        if primary.status_code != 200:
            # Retry once with an explicitly percent-encoded name.
            retry = requests.get(lookup_template.format(requests.utils.quote(name)), timeout=10)
            if retry.status_code == 200:
                payload = retry.json()
                if 'IdentifierList' in payload and 'CID' in payload['IdentifierList']:
                    return True, f"Drug validated via PubChem API (CID: {payload['IdentifierList']['CID'][0]})"
            return False, f"Invalid drug name: API returned status {primary.status_code}"

        payload = primary.json()
        if 'IdentifierList' in payload and 'CID' in payload['IdentifierList']:
            return True, f"Drug validated via PubChem API (CID: {payload['IdentifierList']['CID'][0]})"
        return False, "Drug not found in PubChem database"
    except Exception as e:
        print(f"Error validating drug via API: {e}")
        return True, "API validation failed, assuming valid drug"
def get_smiles_from_api(drug_name):
    """Return the PubChem canonical SMILES for ``drug_name``, or None.

    The name is normalized with clean_drug_name *before* the cache lookup. The
    cache is keyed by cleaned names, so raw user input such as "Aspirin" now hits
    the entry stored for "aspirin". On a cache miss the name is resolved to a CID:
    the raw name is tried first, then a URL-quoted retry. The first CID's SMILES is
    then fetched and stored in ``drug_features_cache[name]['smiles']``.

    Any error (network, JSON, unexpected payload) is logged and yields None.
    """
    try:
        drug_name = clean_drug_name(drug_name)
        cached = drug_features_cache.get(drug_name)  # .get() avoids inserting into the defaultdict
        if cached and 'smiles' in cached:
            return cached['smiles']

        search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{drug_name}/cids/JSON"
        response = requests.get(search_url, timeout=10)
        if response.status_code != 200:
            search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{requests.utils.quote(drug_name)}/cids/JSON"
            response = requests.get(search_url, timeout=10)
        if response.status_code != 200:
            print(f"Drug {drug_name} not found in PubChem")
            return None

        data = response.json()
        cids = data.get('IdentifierList', {}).get('CID') if isinstance(data, dict) else None
        if not cids:
            # Covers a missing IdentifierList/CID and an empty CID list.
            print(f"No CID found for drug {drug_name}")
            return None
        cid = cids[0]

        smiles_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/property/CanonicalSMILES/JSON"
        smiles_response = requests.get(smiles_url, timeout=10)
        if smiles_response.status_code == 200:
            smiles_data = smiles_response.json()
            if 'PropertyTable' in smiles_data and 'Properties' in smiles_data['PropertyTable']:
                properties = smiles_data['PropertyTable']['Properties']
                if properties:
                    first = properties[0]
                    # NOTE(review): newer PubChem responses may report this property under
                    # 'ConnectivitySMILES' instead of 'CanonicalSMILES'; accept either.
                    smiles = first.get('CanonicalSMILES') or first.get('ConnectivitySMILES')
                    if smiles:
                        drug_features_cache[drug_name]['smiles'] = smiles
                        return smiles
        print(f"No SMILES found for drug {drug_name}")
        return None
    except Exception as e:
        print(f"Error getting SMILES from API: {e}")
        return None
def get_pubchem_features(smiles):
if smiles and any(drug for drug, data in drug_features_cache.items() if data.get('smiles') == smiles and 'features' in data):
for drug, data in drug_features_cache.items():
if data.get('smiles') == smiles and 'features' in data:
return data['features']
try:
properties_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/smiles/{requests.utils.quote(smiles)}/property/MolecularWeight,XLogP,TPSA,RotatableBondCount,HBondDonorCount,HBondAcceptorCount,Complexity,ExactMass/JSON"
response = requests.get(properties_url, timeout=10)
if response.status_code == 200:
data = response.json()
if 'PropertyTable' in data and 'Properties' in data['PropertyTable']:
props = data['PropertyTable']['Properties'][0]
features = {
'molecular_weight': props.get('MolecularWeight', 0),
'xlogp': props.get('XLogP', 0),
'tpsa': props.get('TPSA', 0),
'rotatable_bond_count': props.get('RotatableBondCount', 0),
'h_bond_donor_count': props.get('HBondDonorCount', 0),
'h_bond_acceptor_count': props.get('HBondAcceptorCount', 0),
'complexity': props.get('Complexity', 0),
'exact_mass': props.get('ExactMass', 0)
}
print(f"Extracted features for SMILES {smiles}: {features}")
if smiles in [data['smiles'] for data in drug_features_cache.values()]:
for drug, data in drug_features_cache.items():
if data.get('smiles') == smiles:
data['features'] = features
return features
print(f"Failed to retrieve features for SMILES {smiles}")
return None
except Exception as e:
print(f"Error getting PubChem features: {e}")
return None
def get_drug_features_from_dataset(drug1, drug2, df):
    """Look up a drug pair in the interaction dataset, in either order.

    Both names are normalized with clean_drug_name. They are compared with the
    dataset's drug-name columns after those are cast to str, lower-cased and
    trimmed. If 'Drug 1_normalized'/'Drug 2_normalized' exist, only they are
    searched. Otherwise the first present column pair among ('Drug1', 'Drug2'),
    ('Drug 1', 'Drug 2'), ('drug1', 'drug2'), ('drug_1', 'drug_2') that yields a
    match is used. Rows listing the pair as (drug1, drug2) take precedence over
    rows listing it as (drug2, drug1).

    Args:
        drug1: First drug name as entered by the user.
        drug2: Second drug name as entered by the user.
        df: Interaction dataset.

    Returns:
        ``(smiles1, smiles2, severity)`` from the first matching row. ``smiles1``
        always belongs to ``drug1`` and ``smiles2`` to ``drug2``, even when the row
        stores the pair reversed. A value from an absent column or an empty (NaN)
        cell is returned as None. ``(None, None, None)`` is returned when nothing
        matches, the dataset is empty, or the search raises (the error is logged).
    """
    if df.empty:
        print("Dataset is empty")
        return None, None, None
    drug1 = clean_drug_name(drug1)
    drug2 = clean_drug_name(drug2)

    def _normalized(column):
        # astype(str) keeps the .str accessor usable on non-string columns.
        return df[column].astype(str).str.lower().str.strip()

    def _missing_to_none(value):
        # Absent columns give None and empty CSV cells give NaN; report both as None.
        try:
            return None if pd.isna(value) else value
        except (TypeError, ValueError):
            return value

    try:
        if 'Drug 1_normalized' in df.columns and 'Drug 2_normalized' in df.columns:
            column_pairs = [('Drug 1_normalized', 'Drug 2_normalized')]
        else:
            column_pairs = [('Drug1', 'Drug2'), ('Drug 1', 'Drug 2'), ('drug1', 'drug2'), ('drug_1', 'drug_2')]

        for col1, col2 in column_pairs:
            if col1 not in df.columns or col2 not in df.columns:
                continue
            first, second = _normalized(col1), _normalized(col2)
            forward = df[(first == drug1) & (second == drug2)]
            reverse = df[(first == drug2) & (second == drug1)]
            if forward.empty and reverse.empty:
                continue

            print(f"Found drugs '{drug1}' and '{drug2}' in dataset")
            swapped = forward.empty
            row = (reverse if swapped else forward).iloc[0]
            smiles1 = _missing_to_none(row.get('canonical_smiles'))
            smiles2 = _missing_to_none(row.get('canonical_smiles_2'))
            if swapped:
                # The row stores the pair as (drug2, drug1), so its SMILES columns
                # are in the opposite order from the caller's arguments.
                smiles1, smiles2 = smiles2, smiles1
            severity = _missing_to_none(row.get('severity'))
            return smiles1, smiles2, severity
        return None, None, None
    except Exception as e:
        print(f"Error searching dataset: {e}")
        return None, None, None
def predict_severity(drug_input):
if not drug_input:
return "Please enter at least one drug pair separated by commas (e.g., 'Aspirin, Warfarin')."
# Parse drug pairs from comma-separated input
drug_pairs = [pair.strip() for pair in drug_input.split(',') if pair.strip()]
if len(drug_pairs) < 2 or len(drug_pairs) % 2 != 0:
return "Please enter drug pairs separated by commas (e.g., 'Aspirin, Warfarin' or 'Aspirin, Warfarin, Ibuprofen, Naproxen')."
results = []
for i in range(0, len(drug_pairs), 2):
if i + 1 >= len(drug_pairs):
break
drug1, drug2 = drug_pairs[i], drug_pairs[i + 1]
print(f"Processing: '{drug1}', '{drug2}'")
smiles1, smiles2, severity = get_drug_features_from_dataset(drug1, drug2, df)
if severity is not None:
results.append(severity)
continue
# Fallback to PubChem for drugs not in dataset
validation_results = []
for drug_name in [drug1, drug2]:
is_valid, message = validate_drug_input(drug_name)
if is_valid is None:
is_valid, message = validate_drug_via_api(drug_name)
validation_results.append((drug_name, is_valid, message))
invalid_drugs = [(name, msg) for name, valid, msg in validation_results if not valid]
if invalid_drugs:
results.append(f"Invalid drug(s) for {drug1} and {drug2}: {', '.join([f'{name} ({msg})' for name, msg in invalid_drugs])}")
continue
# Fetch SMILES from PubChem if not in dataset
drug1_in_dataset = drug1 in all_drugs
drug2_in_dataset = drug2 in all_drugs
if smiles1 is None:
smiles1 = get_smiles_from_api(drug1)
if smiles2 is None:
smiles2 = get_smiles_from_api(drug2)
if smiles1 is None or smiles2 is None:
results.append(f"Could not retrieve SMILES for {drug1 if smiles1 is None else ''}{', ' if smiles1 is None and smiles2 is None else ''}{drug2 if smiles2 is None else ''} from PubChem.")
continue
# Extract PubChem features with fallback
features1 = drug_features_cache[drug1].get('features') if drug1 in drug_features_cache and 'features' in drug_features_cache[drug1] else get_pubchem_features(smiles1)
features2 = drug_features_cache[drug2].get('features') if drug2 in drug_features_cache and 'features' in drug_features_cache[drug2] else get_pubchem_features(smiles2)
if features1 is None:
print(f"Warning: No features retrieved for {drug1}, using default values.")
features1 = {'molecular_weight': 0, 'xlogp': 0, 'tpsa': 0, 'rotatable_bond_count': 0,
'h_bond_donor_count': 0, 'h_bond_acceptor_count': 0, 'complexity': 0, 'exact_mass': 0}
if features2 is None:
print(f"Warning: No features retrieved for {drug2}, using default values.")
features2 = {'molecular_weight': 0, 'xlogp': 0, 'tpsa': 0, 'rotatable_bond_count': 0,
'h_bond_donor_count': 0, 'h_bond_acceptor_count': 0, 'complexity': 0, 'exact_mass': 0}
# Combine SMILES and features into interaction description
mw1 = features1.get('molecular_weight', 0)
mw2 = features2.get('molecular_weight', 0)
if not isinstance(mw1, (int, float)) or not isinstance(mw2, (int, float)):
mw1, mw2 = 0, 0 # Fallback if not numeric
drug_description = (f"{drug1} SMILES: {smiles1[:50]}, MW: {mw1:.0f}. "
f"{drug2} SMILES: {smiles2[:50]}, MW: {mw2:.0f}.")
interaction_description = drug_description[:256] # Reduced max length
is_from_dataset = False
if 'canonical_smiles' in df.columns and 'canonical_smiles_2' in df.columns:
is_from_dataset = smiles1 in df['canonical_smiles'].values and smiles2 in df['canonical_smiles_2'].values
print(f"Using description: {interaction_description}")
inputs = tokenizer(interaction_description, return_tensors="pt", padding=True, truncation=True, max_length=128)
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)
try:
with torch.no_grad():
outputs = model(input_ids, attention_mask=attention_mask)
temperature = 0.6 if is_from_dataset else 0.5
logits = outputs.logits / temperature
if not is_from_dataset and (drug1_in_dataset or drug2_in_dataset):
no_interaction_idx = 0
if logits[0][no_interaction_idx] > 0:
logits[0][no_interaction_idx] *= 0.85
probabilities = torch.nn.functional.softmax(logits, dim=1)
if not is_from_dataset:
top_probs, top_indices = torch.topk(probabilities, 2, dim=1)
diff = top_probs[0][0].item() - top_probs[0][1].item()
if diff < 0.2 and top_indices[0][1] > top_indices[0][0]:
probabilities[0][top_indices[0][1]] *= 1.15
probabilities = probabilities / probabilities.sum()
prediction = torch.argmax(probabilities, dim=1).item()
severity_label = label_encoder.classes_[prediction]
results.append(severity_label)
except Exception as e:
print(f"Error during prediction for {drug1} and {drug2}: {e}")
results.append(f"Error for {drug1} and {drug2}: {e}")
return "\n\n".join(results) if results else "No valid predictions."
# --- Runtime device ------------------------------------------------------------
print("Setting up device...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Hub repository and local export location ---------------------------------
print("Setting up model and dataset paths...")
model_name = "Fredaaaaaa/hybrid_model"
output_dir = "/home/user/app/drug_interaction_model"

print("Creating output directory...")
os.makedirs(output_dir, exist_ok=True)

# --- Label encoder -------------------------------------------------------------
print("Downloading and loading label encoder...")
label_encoder_path = hf_hub_download(repo_id=model_name, filename="label_encoder.pkl")
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the file,
# so this fully trusts the contents of the Hub repo named by model_name.
with open(label_encoder_path, 'rb') as encoder_file:
    label_encoder = pickle.load(encoder_file)

# --- Model and tokenizer (inference mode on the selected device) ---------------
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device).eval()

# --- Interaction dataset -------------------------------------------------------
print("Downloading and loading dataset...")
dataset_path = hf_hub_download(repo_id=model_name, filename="merged_cleaned_dataset.csv")
df = pd.read_csv(dataset_path, encoding='ISO-8859-1')
print(f"Dataset loaded successfully! Shape: {df.shape}")
print(f"Columns: {df.columns.tolist()}")  # Print columns to debug
print(df.head())

# --- Export everything to output_dir and zip it --------------------------------
print("Saving model, tokenizer, label encoder, and dataset...")
for artifact in (model, tokenizer):
    artifact.save_pretrained(output_dir)
with open(os.path.join(output_dir, 'label_encoder.pkl'), 'wb') as encoder_file:
    pickle.dump(label_encoder, encoder_file)
df.to_csv(os.path.join(output_dir, 'merged_cleaned_dataset.csv'), index=False)

print("Creating zip archive...")
# make_archive appends ".zip" to the base name and returns the archive path:
# "/home/user/app/drug_interaction_model.zip".
zip_path = shutil.make_archive("/home/user/app/drug_interaction_model", 'zip', output_dir)
print(f"📦 Model saved and zipped at: {zip_path}")
print(f"To download, access the file at: {zip_path} from your environment or server.")

# --- Class weights -------------------------------------------------------------
print("Computing class weights...")
unique_classes = df['severity'].unique()
print(f"Unique severity classes: {unique_classes}")
# NOTE(review): class_weights is not referenced anywhere else in the visible code.
class_weights = torch.tensor(
    compute_class_weight('balanced', classes=np.unique(unique_classes), y=df['severity']),
    dtype=torch.float,
).to(device)
# Extract unique drug names and precompute dataset features
print("Extracting unique drug names and precomputing features...")
# Candidate column names for the first and second drug of each interaction row.
_DRUG1_COLUMNS = ['Drug 1_normalized', 'Drug1', 'Drug 1', 'drug1', 'drug_1']
_DRUG2_COLUMNS = ['Drug 2_normalized', 'Drug2', 'Drug 2', 'drug2', 'drug_2']

all_drugs = set()
drug_features_cache = defaultdict(dict)
for col in _DRUG1_COLUMNS + _DRUG2_COLUMNS:
    if col in df.columns:
        # Normalize the same way as clean_drug_name (trim, lower-case, collapse
        # internal whitespace) so membership tests with cleaned user input match.
        names = df[col].astype(str).str.strip().str.lower().str.replace(r'\s+', ' ', regex=True)
        all_drugs.update(names.tolist())
# astype(str) turns missing cells into the string 'nan'; drop those and empties.
all_drugs = {drug for drug in all_drugs if drug and drug != 'nan'}
print(f"Loaded {len(all_drugs)} unique drug names")


def _cache_dataset_drug(raw_name, smiles):
    """Record a dataset drug's SMILES in drug_features_cache and attach its PubChem features."""
    entry = drug_features_cache[clean_drug_name(str(raw_name))]
    if entry.get('smiles') == smiles and 'features' in entry:
        return  # Features for this exact SMILES are already cached.
    entry['smiles'] = smiles
    # Discard features computed for a previous, different SMILES of this drug.
    # Otherwise get_pubchem_features' cache scan would match this very entry (now
    # carrying the new SMILES) and return the stale features.
    entry.pop('features', None)
    entry['features'] = get_pubchem_features(smiles)


# Precompute SMILES and features for dataset drugs (row by row, first drug then second)
print("Precomputing SMILES and features...")
for _, row in df.iterrows():
    for name_columns, smiles_column in ((_DRUG1_COLUMNS, 'canonical_smiles'),
                                        (_DRUG2_COLUMNS, 'canonical_smiles_2')):
        if smiles_column not in df.columns or not pd.notna(row[smiles_column]):
            continue
        for col in name_columns:
            if col in df.columns and pd.notna(row[col]):
                _cache_dataset_drug(row[col], row[smiles_column])
# Gradio Interface
print("Setting up Gradio interface...")
interface = gr.Interface(
    fn=predict_severity,
    inputs=gr.Textbox(label="Drug Pairs (e.g., 'Aspirin, Warfarin' or 'Aspirin, Warfarin, Ibuprofen, Naproxen')",
                      placeholder="Enter drug names separated by commas"),
    outputs=gr.Textbox(label="Prediction Result"),
    title="Drug Interaction Severity Predictor",
    description="Enter drug pairs separated by commas to predict interaction severity based on SMILES and PubChem features.",
    examples=[["Aspirin, Warfarin"], ["Ibuprofen, Naproxen, Hydralazine, Amphetamine"]],
)
print("Launching Gradio interface...")
# The dunder names are required: single-underscore `_name_` / `_main_` are
# undefined and would raise NameError instead of launching the app.
if __name__ == "__main__":
    interface.launch(debug=True)