# HM / app.py — Drug Interaction Severity Predictor (Hugging Face Space)
# Author: Fredaaaaaa — revision 67b0ff0 (verified)
import pickle
import requests
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import gradio as gr
import pandas as pd
import re
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
# ✅ Device setup: prefer GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Download label encoder from Hugging Face Hub.
# SECURITY NOTE: pickle.load executes arbitrary code from the file; acceptable
# here only because the artifact comes from this project's own model repo.
label_encoder_path = hf_hub_download(repo_id="Fredaaaaaa/hybrid_model", filename="label_encoder.pkl")
with open(label_encoder_path, 'rb') as f:
    label_encoder = pickle.load(f)

# Load model and tokenizer (fine-tuned sequence-classification checkpoint).
model_name = "Fredaaaaaa/hybrid_model"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.to(device) # Move model to appropriate device
model.eval()  # inference only: disables dropout/batch-norm updates
# Download the dataset from Hugging Face Hub
dataset_path = hf_hub_download(repo_id="Fredaaaaaa/hybrid_model", filename="labeled_severity.csv")
# Load the dataset with appropriate encoding (file is not UTF-8)
df = pd.read_csv(dataset_path, encoding='ISO-8859-1')
print(f"Dataset loaded successfully! Shape: {df.shape}")
# Check the columns and display first few rows for debugging
print(df.columns)
print(df.head())
# Get unique severity classes from the dataset (order of first appearance)
unique_classes = df['severity'].unique()
print(f"Unique severity classes in dataset: {unique_classes}")
# Calculate class weights to handle imbalanced classes.
# np.unique also SORTS the labels, which compute_class_weight expects.
class_weights = compute_class_weight('balanced', classes=np.unique(unique_classes), y=df['severity'])
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
# NOTE(review): loss_fn is never used in this inference-only app — looks like
# a leftover from training; confirm before removing.
loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
# Extract unique drug names from the dataset to create a list of known drugs.
# Several column-name variants are probed because the CSV schema is uncertain.
all_drugs = set()
for col in ['Drug1', 'Drug 1', 'drug1', 'drug_1', 'Drug 1_normalized']:
    if col in df.columns:
        all_drugs.update(df[col].astype(str).str.lower().str.strip().tolist())
for col in ['Drug2', 'Drug 2', 'drug2', 'drug_2', 'Drug 2_normalized']:
    if col in df.columns:
        all_drugs.update(df[col].astype(str).str.lower().str.strip().tolist())
# Remove empty strings and NaN values (NaN became the string 'nan' via astype(str))
all_drugs = {drug for drug in all_drugs if drug and drug != 'nan'}
print(f"Loaded {len(all_drugs)} unique drug names from dataset")
# Normalize user-supplied drug names before any lookup or comparison.
def clean_drug_name(drug_name):
    """Return *drug_name* lowercased, trimmed, with inner whitespace collapsed."""
    if not drug_name:
        return ""
    return " ".join(drug_name.lower().split())
# Function to validate if input is a legitimate drug name
def validate_drug_input(drug_name, known_drugs=None):
    """Validate a candidate drug name with cheap local rules plus a dataset lookup.

    Args:
        drug_name: raw user input (may be None/empty).
        known_drugs: optional set of normalized known drug names; defaults to
            the module-level ``all_drugs`` extracted from the dataset.

    Returns:
        (True, message)  -- name matches the dataset,
        (False, message) -- name is definitely invalid,
        (None, message)  -- plausible but unknown; caller should fall back to
                            API validation.
    """
    if known_drugs is None:
        known_drugs = all_drugs
    # Same normalization rule as clean_drug_name(), inlined so this validator
    # is self-contained: lowercase, trim, collapse internal whitespace.
    drug_name = re.sub(r'\s+', ' ', drug_name.strip().lower()) if drug_name else ""
    if not drug_name or len(drug_name) <= 1:
        return False, "Drug name is too short"
    # BUG FIX: the original re-tested `len(drug_name) == 1` here, which was
    # dead code (already rejected above); only the all-digits test matters.
    if drug_name.isdigit():
        return False, "Not a valid drug name"
    if not re.match(r'^[a-zA-Z0-9\s\-\+]+$', drug_name):
        return False, "Drug name contains invalid characters"
    if drug_name in known_drugs:
        return True, "Drug found in dataset"
    # NOTE(review): substring matching is deliberately loose (e.g. 'par'
    # matches 'paracetamol'); preserved to keep existing behaviour.
    for known_drug in known_drugs:
        if drug_name in known_drug or known_drug in drug_name:
            return True, f"Drug found in dataset (matched with '{known_drug}')"
    return None, "Drug not in dataset, needs API validation"
def validate_drug_via_api(drug_name):
    """Validate a drug name using PubChem API.

    Returns a (valid, message) pair. A second, percent-encoded lookup is
    attempted when the raw name fails; any exception is treated as "assume
    valid" so transient network problems never block a prediction.
    """
    try:
        name = clean_drug_name(drug_name)
        primary = requests.get(
            f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{name}/cids/JSON",
            timeout=10,
        )
        if primary.status_code == 200:
            id_list = primary.json().get('IdentifierList', {})
            if 'CID' in id_list:
                return True, f"Drug validated via PubChem API (CID: {id_list['CID'][0]})"
            return False, "Drug not found in PubChem database"
        # Retry once with a URL-quoted name (spaces/'+' break the raw URL).
        quoted = requests.utils.quote(name)
        retry = requests.get(
            f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{quoted}/cids/JSON",
            timeout=10,
        )
        if retry.status_code == 200:
            id_list = retry.json().get('IdentifierList', {})
            if 'CID' in id_list:
                return True, f"Drug validated via PubChem API (CID: {id_list['CID'][0]})"
        # Report the ORIGINAL request's status, matching prior behaviour.
        return False, f"Invalid drug name: API returned status {primary.status_code}"
    except Exception as e:
        print(f"Error validating drug via API: {e}")
        return True, "API validation failed, assuming valid drug"
def get_drug_features_from_api(drug_name):
    """Get drug features from PubChem API.

    Resolves *drug_name* to a PubChem CID, fetches the canonical SMILES, then
    scans the PUG-View record's sections for descriptive text. Returns a dict
    of feature strings ('No data' placeholders where nothing matched), or
    None when the drug cannot be resolved or the initial requests fail.
    """
    try:
        drug_name = clean_drug_name(drug_name)
        # First attempt with the raw name; retry once with a percent-encoded name.
        search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{drug_name}/cids/JSON"
        response = requests.get(search_url, timeout=10)
        if response.status_code != 200:
            search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{requests.utils.quote(drug_name)}/cids/JSON"
            response = requests.get(search_url, timeout=10)
            if response.status_code != 200:
                print(f"Drug {drug_name} not found in PubChem")
                return None
        data = response.json()
        if 'IdentifierList' not in data or 'CID' not in data['IdentifierList']:
            print(f"No CID found for drug {drug_name}")
            return None
        cid = data['IdentifierList']['CID'][0]  # take the first compound match
        smiles_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/property/CanonicalSMILES/JSON"
        smiles_response = requests.get(smiles_url, timeout=10)
        # Initialize features dictionary with placeholders
        features = {
            'SMILES': 'No data',
            'pharmacodynamics': 'No data',
            'toxicity': 'No data',
            'mechanism': 'No data',
            'metabolism': 'No data',
            'route-of-elimination': 'No data',
            'half-life': 'No data'
        }
        if smiles_response.status_code == 200:
            smiles_data = smiles_response.json()
            if 'PropertyTable' in smiles_data and 'Properties' in smiles_data['PropertyTable']:
                properties = smiles_data['PropertyTable']['Properties']
                if properties and 'CanonicalSMILES' in properties[0]:
                    features['SMILES'] = properties[0]['CanonicalSMILES']
        # Scan the full PUG-View record for textual sections. Each inner
        # `break` stops at the FIRST string of an Information entry, but the
        # enclosing Information loop continues, so a later entry can overwrite
        # the value (net effect: first string of the LAST matching entry).
        info_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/data/compound/{cid}/JSON"
        info_response = requests.get(info_url, timeout=15)
        if info_response.status_code == 200:
            info_data = info_response.json()
            if 'Record' in info_data and 'Section' in info_data['Record']:
                for section in info_data['Record']['Section']:
                    if 'TOCHeading' in section:
                        # 'pharmacodynamics' is filled from the 'Mechanism of
                        # Action' subsection of 'Pharmacology' — intentional? 
                        if section['TOCHeading'] == 'Pharmacology':
                            if 'Section' in section:
                                for subsection in section['Section']:
                                    if 'TOCHeading' in subsection:
                                        if subsection['TOCHeading'] == 'Mechanism of Action':
                                            if 'Information' in subsection:
                                                for info in subsection['Information']:
                                                    if 'Value' in info and 'StringWithMarkup' in info['Value']:
                                                        for text in info['Value']['StringWithMarkup']:
                                                            if 'String' in text:
                                                                features['pharmacodynamics'] = text['String'][:500]
                                                                break
                        if section['TOCHeading'] == 'Toxicity':
                            if 'Information' in section:
                                for info in section['Information']:
                                    if 'Value' in info and 'StringWithMarkup' in info['Value']:
                                        for text in info['Value']['StringWithMarkup']:
                                            if 'String' in text:
                                                features['toxicity'] = text['String'][:500]
                                                break
                        # NOTE(review): the four lowercase/hyphenated headings
                        # below ('mechanism', 'metabolism',
                        # 'route-of-elimination', 'half-life') do not look like
                        # PubChem's capitalized TOCHeadings, so these branches
                        # probably never match — confirm against a live record.
                        if section['TOCHeading'] == 'mechanism':
                            if 'Information' in section:
                                for info in section['Information']:
                                    if 'Value' in info and 'StringWithMarkup' in info['Value']:
                                        for text in info['Value']['StringWithMarkup']:
                                            if 'String' in text:
                                                features['mechanism'] = text['String'][:500]
                                                break
                        if section['TOCHeading'] == 'metabolism':
                            if 'Information' in section:
                                for info in section['Information']:
                                    if 'Value' in info and 'StringWithMarkup' in info['Value']:
                                        for text in info['Value']['StringWithMarkup']:
                                            if 'String' in text:
                                                features['metabolism'] = text['String'][:500]
                                                break
                        if section['TOCHeading'] == 'route-of-elimination':
                            if 'Information' in section:
                                for info in section['Information']:
                                    if 'Value' in info and 'StringWithMarkup' in info['Value']:
                                        for text in info['Value']['StringWithMarkup']:
                                            if 'String' in text:
                                                features['route-of-elimination'] = text['String'][:500]
                                                break
                        if section['TOCHeading'] == 'half-life':
                            if 'Information' in section:
                                for info in section['Information']:
                                    if 'Value' in info and 'StringWithMarkup' in info['Value']:
                                        for text in info['Value']['StringWithMarkup']:
                                            if 'String' in text:
                                                features['half-life'] = text['String'][:500]
                                                break
        return features
    except Exception as e:
        print(f"Error getting drug features from API: {e}")
        return None
# Function to check if drugs are in the dataset
def get_drug_features_from_dataset(drug1, drug2, df):
    """Return the first dataset row matching (drug1, drug2) in either order, or None."""
    if df.empty:
        print("Dataset is empty, cannot search for drugs")
        return None
    drug1 = clean_drug_name(drug1)
    drug2 = clean_drug_name(drug2)
    print(f"Checking for drugs in dataset: '{drug1}', '{drug2}'")
    try:
        if 'Drug 1_normalized' in df.columns and 'Drug 2_normalized' in df.columns:
            # Preferred path: pre-normalized columns exist; match both orderings.
            left = df['Drug 1_normalized'].str.lower().str.strip()
            right = df['Drug 2_normalized'].str.lower().str.strip()
            forward = df[(left == drug1) & (right == drug2)]
            backward = df[(left == drug2) & (right == drug1)]
            drug_data = pd.concat([forward, backward])
        else:
            # Fallback: probe common raw column-name variants until one matches.
            drug_data = pd.DataFrame()
            for col1, col2 in (('Drug1', 'Drug2'), ('Drug 1', 'Drug 2'),
                               ('drug1', 'drug2'), ('drug_1', 'drug_2')):
                if col1 not in df.columns or col2 not in df.columns:
                    continue
                left = df[col1].astype(str).str.lower().str.strip()
                right = df[col2].astype(str).str.lower().str.strip()
                matches = df[((left == drug1) & (right == drug2)) |
                             ((left == drug2) & (right == drug1))]
                if not matches.empty:
                    drug_data = matches
                    break
        if drug_data.empty:
            print(f"Drugs '{drug1}' and '{drug2}' not found in the dataset.")
            return None
        print(f"Found drugs '{drug1}' and '{drug2}' in the dataset!")
        return drug_data.iloc[0]
    except Exception as e:
        print(f"Error searching for drugs in dataset: {e}")
        return None
# Updated prediction function with improved confidence handling
def predict_severity(drug1, drug2):
    """Predict interaction severity for a pair of drug names.

    Resolution order:
      1. exact pair match in the curated dataset -> return its severity directly;
      2. otherwise validate both names (rules, then PubChem), build a textual
         interaction description from dataset or API features, and classify it
         with the fine-tuned sequence-classification model.

    Always returns a human-readable string, even on error.
    """
    if not drug1 or not drug2:
        return "Please enter both drugs to predict interaction severity."
    drug1 = clean_drug_name(drug1)
    drug2 = clean_drug_name(drug2)
    print(f"Processing request for drugs: '{drug1}' and '{drug2}'")
    drug_data = get_drug_features_from_dataset(drug1, drug2, df)
    if drug_data is not None:
        print(f"Found drugs in dataset, using known severity data")
        if 'severity' in drug_data:
            severity_label = drug_data['severity']
            # NOTE(review): hard-coded 98% confidence for curated matches —
            # this is a display convention, not a model probability.
            confidence = 98.0
            result = f"Predicted interaction severity: {severity_label} (Confidence: {confidence:.1f}%)"
            result += "\nData source: Direct match from curated dataset"
            return result
        else:
            print(f"Using dataset features for '{drug1}' and '{drug2}'")
            # Pair exists in the dataset, so both names are trusted as valid.
            is_valid_drug1 = True
            is_valid_drug2 = True
    else:
        print("Drugs not found in dataset, validating through other means")
        validation_results = []
        for drug_name in [drug1, drug2]:
            is_valid, message = validate_drug_input(drug_name)
            if is_valid is None:  # rule-check inconclusive -> ask PubChem
                is_valid, message = validate_drug_via_api(drug_name)
            validation_results.append((drug_name, is_valid, message))
        invalid_drugs = [(name, msg) for name, valid, msg in validation_results if not valid]
        if invalid_drugs:
            invalid_names = ", ".join([f"'{name}' ({msg})" for name, msg in invalid_drugs])
            return f"Invalid drug name(s): {invalid_names}. Please enter valid drug names."
        is_valid_drug1 = validation_results[0][1]
        is_valid_drug2 = validation_results[1][1]
    if drug_data is not None:
        # Build the model input text from dataset columns; several column-name
        # variants are probed per feature because the CSV schema is uncertain.
        try:
            drug_features = {}
            column_mappings = {
                'SMILES': ['SMILES', 'smiles'],
                'pharmacodynamics': ['pharmacodynamics', 'Pharmacodynamics', 'pharmacology'],
                'toxicity': ['toxicity', 'Toxicity'],
                'mechanism': ['mechanism', 'Mechanism'],
                # NOTE(review): this key looks garbled (probably 'metabolism');
                # harmless, since only SMILES/pharmacodynamics are used below.
                'met/nullabolism': ['metabolism', 'Metabolism'],
                'route-of-elimination': ['route-of-elimination', 'Route-of-elimination'],
                'half-life': ['half-life', 'Half-life']
            }
            for feature, possible_cols in column_mappings.items():
                feature_found = False
                for col in possible_cols:
                    if col in drug_data.index or col in drug_data:
                        try:
                            drug_features[feature] = drug_data[col]
                            feature_found = True
                            break
                        except Exception as e:
                            print(f"Error accessing column {col}: {e}")
                            continue
                if not feature_found:
                    drug_features[feature] = 'No data'
            drug_description = f"{drug1} interacts with {drug2}. "
            if drug_features.get('SMILES', 'No data') != 'No data':
                drug_description += f"Molecular structures: {drug_features.get('SMILES')}. "
            if drug_features.get('pharmacodynamics', 'No data') != 'No data':
                drug_description += f"Mechanism: {drug_features.get('pharmacodynamics')}. "
            interaction_description = drug_description[:512]  # bound the prompt
            is_from_dataset = True
        except Exception as e:
            print(f"Error extracting features from dataset: {e}")
            return f"Error processing drug data: {e}"
    else:
        print(f"Fetching API data for '{drug1}' and '{drug2}'")
        # Whether each single drug (not the pair) is known to the dataset;
        # used below to damp the "No interaction" logit.
        drug1_in_dataset = drug1 in all_drugs
        drug2_in_dataset = drug2 in all_drugs
        drug1_features = get_drug_features_from_api(drug1)
        if drug1_features is None and is_valid_drug1:
            # Validated drug with no PubChem record: fall back to placeholders.
            drug1_features = {
                'SMILES': 'No data from API',
                'pharmacodynamics': 'No data from API',
                'toxicity': 'No data from API',
                'mechanism': 'No data from API',
                'metabolism': 'No data from API',
                'route-of-elimination': 'No data from API',
                'half-life': 'No data from API'
            }
        drug2_features = get_drug_features_from_api(drug2)
        if drug2_features is None and is_valid_drug2:
            drug2_features = {
                'SMILES': 'No data from API',
                'pharmacodynamics': 'No data from API',
                'toxicity': 'No data from API',
                'mechanism': 'No data from API',
                'metabolism': 'No data from API',
                'route-of-elimination': 'No data from API',
                'half-life': 'No data from API'
            }
        if drug1_features is None or drug2_features is None:
            return "Couldn't retrieve sufficient data for one or both drugs. Please try different drugs or check your spelling."
        drug_description = f"{drug1} interacts with {drug2}. "
        if drug1_features['SMILES'] != 'No data from API':
            drug_description += f"{drug1} has molecular structure: {drug1_features['SMILES'][:100]}. "
        if drug2_features['SMILES'] != 'No data from API':
            drug_description += f"{drug2} has molecular structure: {drug2_features['SMILES'][:100]}. "
        if drug1_features.get('pharmacodynamics', 'No data') not in ['No data', 'No data from API']:
            drug_description += f"{drug1} mechanism: {drug1_features['pharmacodynamics'][:150]}. "
        if drug2_features.get('pharmacodynamics', 'No data') not in ['No data', 'No data from API']:
            drug_description += f"{drug2} mechanism: {drug2_features['pharmacodynamics'][:150]}. "
        interaction_description = drug_description[:512]
        is_from_dataset = False
    print(f"Using description: {interaction_description}")
    inputs = tokenizer(interaction_description, return_tensors="pt", padding=True, truncation=True, max_length=128)
    # Shadows the module-level `device` with an identical value (redundant).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    try:
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
        # Temperature < 1 sharpens the softmax; dataset text slightly less so.
        if is_from_dataset:
            temperature = 0.6
        else:
            temperature = 0.5
        logits = outputs.logits / temperature
        if not is_from_dataset and (drug1_in_dataset or drug2_in_dataset):
            # Damp the "No interaction" logit when either drug is known,
            # biasing toward reporting some interaction.
            # NOTE(review): assumes class index 0 is "No interaction" —
            # confirm against label_encoder.classes_.
            no_interaction_idx = 0
            if logits[0][no_interaction_idx] > 0:
                logits[0][no_interaction_idx] *= 0.85
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        if not is_from_dataset:
            # Near-ties between the top two classes are nudged toward the
            # higher class index (assumed to be the more severe label).
            top_probs, top_indices = torch.topk(probabilities, 2, dim=1)
            diff = top_probs[0][0].item() - top_probs[0][1].item()
            if diff < 0.2 and top_indices[0][1] > top_indices[0][0]:
                probabilities[0][top_indices[0][1]] *= 1.15
                probabilities = probabilities / probabilities.sum()
        prediction = torch.argmax(probabilities, dim=1).item()
        if hasattr(label_encoder, 'classes_'):
            severity_label = label_encoder.classes_[prediction]
        else:
            # Fallback label ordering if the pickle wasn't a fitted LabelEncoder.
            severity_labels = ["No interaction", "Mild", "Moderate", "Severe"]
            severity_label = severity_labels[prediction]
        confidence = probabilities[0][prediction].item() * 100
        if not is_from_dataset:
            # NOTE(review): reported confidence is floored per class, so the
            # displayed percentage can exceed the model's actual probability.
            min_confidence = {
                "No interaction": 70.0,
                "Mild": 75.0,
                "Moderate": 80.0,
                "Severe": 85.0
            }
            min_conf = min_confidence.get(severity_label, 70.0)
            if confidence < min_conf:
                confidence = min(min_conf + 5.0, 95.0)
        result = f"Predicted interaction severity: {severity_label} (Confidence: {confidence:.1f}%)"
        if is_from_dataset:
            result += "\nData source: Features from dataset (higher reliability)"
        else:
            result += "\nData source: Features from PubChem API"
        if severity_label == "No interaction":
            result += "\nInterpretation: Model suggests minimal risk of interaction, but consult a healthcare professional."
        elif severity_label == "Mild":
            result += "\nInterpretation: Minor interaction possible. Monitor for mild side effects."
        elif severity_label == "Moderate":
            result += "\nInterpretation: Notable interaction likely. Healthcare supervision recommended."
        elif severity_label == "Severe":
            result += "\nInterpretation: Potentially serious interaction. Consult healthcare provider before combined use."
        result += "\n\nDisclaimer: This prediction is for research purposes only. Always consult healthcare professionals."
        return result
    except Exception as e:
        print(f"Error during prediction: {e}")
        return f"Error making prediction: {e}"
# Gradio Interface: two free-text drug-name inputs -> one text prediction.
interface = gr.Interface(
    fn=predict_severity,
    inputs=[
        gr.Textbox(label="Drug 1 (e.g., Aspirin)", placeholder="Enter first drug name"),
        gr.Textbox(label="Drug 2 (e.g., Warfarin)", placeholder="Enter second drug name")
    ],
    outputs=gr.Textbox(label="Prediction Result"),
    title="Drug Interaction Severity Predictor",
    description="Enter two drug names to predict the severity of their interaction.",
    examples=[["Aspirin", "Warfarin"], ["Ibuprofen", "Naproxen"], ["Hydralazine", "Amphetamine"]]
)
# Launch the interface only when run as a script (Spaces runs this file directly).
if __name__ == "__main__":
    interface.launch(debug=True)  # debug=True surfaces tracebacks in the logs