Spaces:
Sleeping
Sleeping
| import pickle | |
| import requests | |
| from huggingface_hub import hf_hub_download | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import gradio as gr | |
| import pandas as pd | |
| import re | |
| from sklearn.utils.class_weight import compute_class_weight | |
| import numpy as np | |
# ✅ Device setup: run on a CUDA GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# Model artifacts (label encoder, tokenizer, classifier) all come from this Hub repository.
model_name = "Fredaaaaaa/hybrid_model"

# Download label encoder from Hugging Face Hub.
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the file,
# so this fully trusts whatever is currently in the remote repo. Consider pinning a
# revision in hf_hub_download or storing the label list in a non-pickle format.
label_encoder_path = hf_hub_download(repo_id=model_name, filename="label_encoder.pkl")
with open(label_encoder_path, 'rb') as f:
    label_encoder = pickle.load(f)

# Tokenizer and sequence-classification model from the same repo; the model is
# moved to `device` and switched to eval mode for inference.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
model.eval()
# Download the labelled severity dataset from the same Hub repository as the model
# and parse it as ISO-8859-1 (Latin-1) text.
dataset_path = hf_hub_download(repo_id="Fredaaaaaa/hybrid_model", filename="labeled_severity.csv")
df = pd.read_csv(dataset_path, encoding='ISO-8859-1')
print(f"Dataset loaded successfully! Shape: {df.shape}")

# Debug output: column names and a preview of the first rows.
print(df.columns)
print(df.head())

# Distinct severity labels present in the data.
unique_classes = df['severity'].unique()
print(f"Unique severity classes in dataset: {unique_classes}")

# 'balanced' class weights (inversely proportional to class frequency), wrapped in a
# weighted cross-entropy loss.
# NOTE(review): neither class_weights nor loss_fn is used by the inference code in this
# file; this looks like a training leftover.
class_weights = torch.tensor(
    compute_class_weight('balanced', classes=np.unique(unique_classes), y=df['severity']),
    dtype=torch.float,
).to(device)
loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
# Build the vocabulary of known drug names from whichever drug-name columns the
# dataset has (first-drug columns, then second-drug columns).
# Names are normalised exactly like clean_drug_name() normalises user input:
# trimmed, lower-cased, with internal whitespace runs collapsed to one space.
# This makes `name in all_drugs` checks agree with cleaned input. Previously,
# internal whitespace was left as-is, so e.g. "insulin  glargine" never matched.
_DRUG_NAME_COLUMNS = (
    'Drug1', 'Drug 1', 'drug1', 'drug_1', 'Drug 1_normalized',
    'Drug2', 'Drug 2', 'drug2', 'drug_2', 'Drug 2_normalized',
)
all_drugs = set()
for col in _DRUG_NAME_COLUMNS:
    if col in df.columns:
        # dropna() first so missing cells never become the literal string 'nan'.
        all_drugs.update(
            df[col].dropna().astype(str).str.strip().str.lower()
            .str.replace(r'\s+', ' ', regex=True)
        )
# Remove empty strings and literal 'nan' entries.
all_drugs = {drug for drug in all_drugs if drug and drug != 'nan'}
print(f"Loaded {len(all_drugs)} unique drug names from dataset")
# Normalise drug names so user input and dataset entries compare consistently.
def clean_drug_name(drug_name):
    """Return a canonical form of ``drug_name`` for comparisons.

    Falsy input (``None`` or ``""``) yields ``""``. Otherwise, leading and trailing
    whitespace is removed, the text is lower-cased, and each run of whitespace is
    collapsed into a single space.
    """
    if drug_name:
        trimmed_lower = drug_name.strip().lower()
        return re.sub(r'\s+', ' ', trimmed_lower)
    return ""
# Function to validate if input is a legitimate drug name (local checks only)
def validate_drug_input(drug_name, min_fuzzy_len=4):
    """Validate a drug name against simple rules and the dataset vocabulary.

    Args:
        drug_name: Raw user input; normalised with clean_drug_name() first.
        min_fuzzy_len: Shortest string allowed to count as a substring match
            (in either direction) against a known drug name. This stops short
            inputs like "in" from "matching" "aspirin", and stops very short
            dataset entries from matching almost any input.

    Returns:
        ``(is_valid, message)`` where ``is_valid`` is
          False - rejected locally (too short, all digits, disallowed characters);
          True  - exact or substring match against ``all_drugs``;
          None  - passes local checks but is unknown; caller should try the API.
    """
    drug_name = clean_drug_name(drug_name)
    if len(drug_name) <= 1:
        return False, "Drug name is too short"
    if drug_name.isdigit():
        return False, "Not a valid drug name"
    # Letters, digits, spaces, and punctuation that occurs in real drug names
    # (e.g. "co-trimoxazole", "st. john's wort", "1,3-dimethylxanthine").
    # '/' stays excluded: names are interpolated into PubChem URL paths.
    if not re.fullmatch(r"[a-zA-Z0-9\s\-+.,'()]+", drug_name):
        return False, "Drug name contains invalid characters"
    if drug_name in all_drugs:
        return True, "Drug found in dataset"
    matches = [
        known for known in all_drugs
        if (drug_name in known or known in drug_name)
        and min(len(known), len(drug_name)) >= min_fuzzy_len
    ]
    if matches:
        # Set iteration order varies between runs, so pick deterministically:
        # the closest length first, then alphabetical order.
        best = min(matches, key=lambda known: (abs(len(known) - len(drug_name)), known))
        return True, f"Drug found in dataset (matched with '{best}')"
    return None, "Drug not in dataset, needs API validation"
def validate_drug_via_api(drug_name):
    """Validate a drug name using PubChem API

    Looks up the cleaned name with PubChem PUG REST (``compound/name/<name>/cids``).
    The name is tried as-is first, then percent-encoded if encoding changes the URL.

    Returns:
        ``(True, message)``  if PubChem returns at least one CID (message quotes the first);
        ``(False, message)`` if PubChem answers 200 without a CID, or every attempt returns
                             a non-200 status (message reports the last status seen);
        ``(True, message)``  if the lookup raises (network error, bad JSON, ...). This is a
                             deliberate fail-open, so an API outage does not block users.
    """
    base_url = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{}/cids/JSON"
    try:
        drug_name = clean_drug_name(drug_name)
        urls = [base_url.format(drug_name)]
        # Only retry with explicit percent-encoding when that actually changes the URL.
        # Previously, names like "aspirin" triggered a second, identical request on failure.
        encoded_url = base_url.format(requests.utils.quote(drug_name))
        if encoded_url != urls[0]:
            urls.append(encoded_url)
        for url in urls:
            response = requests.get(url, timeout=10)
            if response.status_code == 200:
                cids = response.json().get('IdentifierList', {}).get('CID')
                if cids:
                    return True, f"Drug validated via PubChem API (CID: {cids[0]})"
                return False, "Drug not found in PubChem database"
        # Report the status of the request that actually decided the outcome.
        return False, f"Invalid drug name: API returned status {response.status_code}"
    except Exception as e:
        print(f"Error validating drug via API: {e}")
        return True, "API validation failed, assuming valid drug"
# PUG-View table-of-contents headings searched for each feature, in priority order.
# Headings are matched case-insensitively at any depth of the compound record.
# NOTE(review): the title-case headings reflect my understanding of PubChem's TOC
# (e.g. "Mechanism of Action" under "Pharmacology and Biochemistry", "Toxicity Summary"
# under "Toxicity"); confirm against a live record such as CID 2244 (aspirin).
# The lower-case names are the headings the previous implementation looked for,
# kept as fallbacks.
_PUBCHEM_FEATURE_HEADINGS = {
    'pharmacodynamics': ('Mechanism of Action', 'Pharmacodynamics'),
    'toxicity': ('Toxicity Summary', 'Toxicity'),
    'mechanism': ('Mechanism of Action', 'mechanism'),
    'metabolism': ('Metabolism/Metabolites', 'metabolism'),
    'route-of-elimination': ('Route of Elimination', 'Absorption, Distribution and Excretion',
                             'route-of-elimination'),
    'half-life': ('Biological Half-Life', 'half-life'),
}


def _pubchem_first_text(section):
    """Return the first non-empty 'String' in a PUG-View section (own Information first, then children depth-first), else None."""
    for info in section.get('Information', []):
        for markup in info.get('Value', {}).get('StringWithMarkup', []):
            text = markup.get('String')
            if text:
                return text
    for child in section.get('Section', []):
        text = _pubchem_first_text(child)
        if text:
            return text
    return None


def _pubchem_heading_texts(sections, wanted):
    """Map each lower-cased heading in ``wanted`` to the first text under a section with that TOCHeading, in document order."""
    found = {}
    stack = list(reversed(sections))  # reversed so pops follow document order
    while stack:
        section = stack.pop()
        heading = str(section.get('TOCHeading', '')).lower()
        if heading in wanted and heading not in found:
            text = _pubchem_first_text(section)
            if text:
                found[heading] = text
        stack.extend(reversed(section.get('Section', [])))
    return found


def get_drug_features_from_api(drug_name):
    """Get drug features from PubChem API

    Resolves ``drug_name`` (after clean_drug_name) to a PubChem CID. It then collects
    the compound's SMILES string and short text snippets (truncated to 500 characters)
    from its PUG-View record.

    Returns:
        dict with exactly the keys 'SMILES', 'pharmacodynamics', 'toxicity',
        'mechanism', 'metabolism', 'route-of-elimination' and 'half-life'. Any
        feature that could not be found is 'No data'.
        None if the name does not resolve to a CID, or if any step raises.
    """
    try:
        drug_name = clean_drug_name(drug_name)
        search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{drug_name}/cids/JSON"
        response = requests.get(search_url, timeout=10)
        encoded_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{requests.utils.quote(drug_name)}/cids/JSON"
        # Retry with explicit percent-encoding only if that changes the URL.
        if response.status_code != 200 and encoded_url != search_url:
            response = requests.get(encoded_url, timeout=10)
        if response.status_code != 200:
            print(f"Drug {drug_name} not found in PubChem")
            return None
        cids = response.json().get('IdentifierList', {}).get('CID')
        if not cids:
            print(f"No CID found for drug {drug_name}")
            return None
        cid = cids[0]

        features = {
            'SMILES': 'No data',
            'pharmacodynamics': 'No data',
            'toxicity': 'No data',
            'mechanism': 'No data',
            'metabolism': 'No data',
            'route-of-elimination': 'No data',
            'half-life': 'No data'
        }

        smiles_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/property/CanonicalSMILES/JSON"
        smiles_response = requests.get(smiles_url, timeout=10)
        if smiles_response.status_code == 200:
            properties = smiles_response.json().get('PropertyTable', {}).get('Properties') or []
            if properties:
                # NOTE(review): newer PubChem responses may report this property as
                # 'ConnectivitySMILES'; accept either key.
                smiles = properties[0].get('CanonicalSMILES') or properties[0].get('ConnectivitySMILES')
                if smiles:
                    features['SMILES'] = smiles

        info_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/data/compound/{cid}/JSON"
        info_response = requests.get(info_url, timeout=15)
        if info_response.status_code == 200:
            sections = info_response.json().get('Record', {}).get('Section', [])
            wanted = {heading.lower()
                      for headings in _PUBCHEM_FEATURE_HEADINGS.values()
                      for heading in headings}
            texts = _pubchem_heading_texts(sections, wanted)
            for feature, headings in _PUBCHEM_FEATURE_HEADINGS.items():
                text = next((texts[h.lower()] for h in headings if h.lower() in texts), None)
                if text:
                    features[feature] = text[:500]
        return features
    except Exception as e:
        print(f"Error getting drug features from API: {e}")
        return None
# Function to check if drugs are in the dataset
def get_drug_features_from_dataset(drug1, drug2, df):
    """Look up a drug pair, in either order, in the interaction dataset.

    Both names are normalised with clean_drug_name(). Dataset cells get the same
    normalisation (trimmed, lower-cased, whitespace runs collapsed), so stray spacing
    cannot defeat a lookup. Missing cells never match.

    If both 'Drug 1_normalized' and 'Drug 2_normalized' exist, only those columns are
    searched: same-order matches come first, then reversed-order matches. Otherwise the
    first known column-name pair that yields any match is used, in row order.

    Returns:
        The first matching row (a pandas Series), or None when the dataset is empty,
        nothing matches, or the lookup raises.
    """
    if df.empty:
        print("Dataset is empty, cannot search for drugs")
        return None
    drug1 = clean_drug_name(drug1)
    drug2 = clean_drug_name(drug2)
    print(f"Checking for drugs in dataset: '{drug1}', '{drug2}'")

    def normalised(column):
        # Normalise a whole column once (previously redone for every comparison).
        # .where() keeps missing cells as NaN instead of the text 'nan'.
        raw = df[column]
        return (raw.astype(str).str.strip().str.lower()
                .str.replace(r'\s+', ' ', regex=True)
                .where(raw.notna()))

    try:
        if 'Drug 1_normalized' in df.columns and 'Drug 2_normalized' in df.columns:
            first = normalised('Drug 1_normalized')
            second = normalised('Drug 2_normalized')
            drug_data = pd.concat([
                df[(first == drug1) & (second == drug2)],
                df[(first == drug2) & (second == drug1)],
            ])
        else:
            possible_column_pairs = [
                ('Drug1', 'Drug2'),
                ('Drug 1', 'Drug 2'),
                ('drug1', 'drug2'),
                ('drug_1', 'drug_2')
            ]
            drug_data = pd.DataFrame()
            for col1, col2 in possible_column_pairs:
                if col1 in df.columns and col2 in df.columns:
                    first = normalised(col1)
                    second = normalised(col2)
                    matches = df[((first == drug1) & (second == drug2)) |
                                 ((first == drug2) & (second == drug1))]
                    if not matches.empty:
                        drug_data = matches
                        break
        if drug_data.empty:
            print(f"Drugs '{drug1}' and '{drug2}' not found in the dataset.")
            return None
        print(f"Found drugs '{drug1}' and '{drug2}' in the dataset!")
        return drug_data.iloc[0]
    except Exception as e:
        print(f"Error searching for drugs in dataset: {e}")
        return None
# Values that mean "feature unknown": dataset lookups use 'No data',
# the API fallback record uses 'No data from API'.
_MISSING_FEATURE_VALUES = ('No data', 'No data from API')

# Keys of the feature dicts produced by get_drug_features_from_api().
_API_FEATURE_KEYS = ('SMILES', 'pharmacodynamics', 'toxicity', 'mechanism',
                     'metabolism', 'route-of-elimination', 'half-life')

# Severity ordering (by lower-cased label) used by the probability nudges.
# The thresholds and interpretation texts below assume these same four labels.
_SEVERITY_RANK = {'no interaction': 0, 'mild': 1, 'moderate': 2, 'severe': 3}

# Minimum confidence displayed for API-based predictions, per label.
_MIN_DISPLAY_CONFIDENCE = {
    "No interaction": 70.0,
    "Mild": 75.0,
    "Moderate": 80.0,
    "Severe": 85.0
}

_INTERPRETATIONS = {
    "No interaction": "\nInterpretation: Model suggests minimal risk of interaction, but consult a healthcare professional.",
    "Mild": "\nInterpretation: Minor interaction possible. Monitor for mild side effects.",
    "Moderate": "\nInterpretation: Notable interaction likely. Healthcare supervision recommended.",
    "Severe": "\nInterpretation: Potentially serious interaction. Consult healthcare provider before combined use.",
}


def _has_feature(value):
    """True if ``value`` is a non-empty string other than a 'No data' placeholder (NaN cells count as missing)."""
    return isinstance(value, str) and bool(value) and value not in _MISSING_FEATURE_VALUES


def _severity_class_names():
    """Class labels in model-output index order: label_encoder.classes_ if present, else a fixed fallback.

    NOTE(review): if label_encoder is a scikit-learn LabelEncoder, classes_ is sorted
    alphabetically, so index order is NOT severity order; hence the label-based ranks.
    """
    if hasattr(label_encoder, 'classes_'):
        return list(label_encoder.classes_)
    return ["No interaction", "Mild", "Moderate", "Severe"]


def _severity_rank(class_names, index):
    """Severity rank of the label at ``index``, or None if the label is unrecognised or out of range."""
    if 0 <= index < len(class_names):
        return _SEVERITY_RANK.get(str(class_names[index]).strip().lower())
    return None


def _calibrated_probabilities(logits, class_names, is_from_dataset, either_drug_known):
    """Apply the app's heuristic calibration to single-row logits and return softmax probabilities.

    - Temperature-scale the logits (0.6 for dataset features, 0.5 for API features).
    - API path, when either drug is in the dataset vocabulary: shrink a positive
      "No interaction" logit by 15%. Skipped if no such label is recognised.
    - API path: if the top two classes are within 0.2 probability of each other and
      the runner-up is more severe, boost the runner-up by 15% and renormalise.
      Severity comes from the label ranks; if either label is unrecognised, it falls
      back to "higher index = more severe".
    """
    logits = logits / (0.6 if is_from_dataset else 0.5)
    if not is_from_dataset and either_drug_known:
        no_interaction_idx = next(
            (i for i in range(min(len(class_names), logits.shape[1]))
             if _severity_rank(class_names, i) == 0),
            None,
        )
        if no_interaction_idx is not None and logits[0][no_interaction_idx] > 0:
            logits[0][no_interaction_idx] *= 0.85
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    if not is_from_dataset and probabilities.shape[1] >= 2:
        top_probs, top_indices = torch.topk(probabilities, 2, dim=1)
        best = top_indices[0][0].item()
        runner_up = top_indices[0][1].item()
        rank_best = _severity_rank(class_names, best)
        rank_runner_up = _severity_rank(class_names, runner_up)
        if rank_best is None or rank_runner_up is None:
            runner_up_more_severe = runner_up > best
        else:
            runner_up_more_severe = rank_runner_up > rank_best
        if top_probs[0][0].item() - top_probs[0][1].item() < 0.2 and runner_up_more_severe:
            probabilities[0][runner_up] *= 1.15
            probabilities = probabilities / probabilities.sum()
    return probabilities


def _api_placeholder_features():
    """All-placeholder feature record used when PubChem returned nothing for a validated drug."""
    return {key: 'No data from API' for key in _API_FEATURE_KEYS}


def _describe_from_dataset_row(drug1, drug2, drug_data):
    """Build the model input text from a matched dataset row, using its SMILES and pharmacodynamics columns if present."""
    def first_column_value(candidates):
        for col in candidates:
            if col in drug_data.index:
                return drug_data[col]
        return 'No data'

    smiles = first_column_value(('SMILES', 'smiles'))
    pharmacodynamics = first_column_value(('pharmacodynamics', 'Pharmacodynamics', 'pharmacology'))
    description = f"{drug1} interacts with {drug2}. "
    if _has_feature(smiles):
        description += f"Molecular structures: {smiles}. "
    if _has_feature(pharmacodynamics):
        description += f"Mechanism: {pharmacodynamics}. "
    return description[:512]


def _describe_from_api(drug1, drug2, is_valid_drug1, is_valid_drug2):
    """Build the model input text from PubChem features; None if either drug has no usable record."""
    drug1_features = get_drug_features_from_api(drug1)
    if drug1_features is None and is_valid_drug1:
        drug1_features = _api_placeholder_features()
    drug2_features = get_drug_features_from_api(drug2)
    if drug2_features is None and is_valid_drug2:
        drug2_features = _api_placeholder_features()
    if drug1_features is None or drug2_features is None:
        return None

    pairs = ((drug1, drug1_features), (drug2, drug2_features))
    description = f"{drug1} interacts with {drug2}. "
    # Both placeholders are excluded. Previously only 'No data from API' was, so a
    # missing SMILES ('No data') leaked into the text as "has molecular structure: No data".
    for name, features in pairs:
        smiles = features.get('SMILES')
        if _has_feature(smiles):
            description += f"{name} has molecular structure: {smiles[:100]}. "
    for name, features in pairs:
        mechanism = features.get('pharmacodynamics')
        if _has_feature(mechanism):
            description += f"{name} mechanism: {mechanism[:150]}. "
    return description[:512]


# Updated prediction function with improved confidence handling
def predict_severity(drug1, drug2):
    """Predict the interaction severity for two drug names (Gradio callback).

    1. If the pair is in the curated dataset and the row has a 'severity' value,
       that label is returned directly.
    2. Otherwise each name is validated locally (validate_drug_input), then via
       PubChem when unknown; any invalid name produces an explanatory message.
    3. An interaction description built from dataset or PubChem features is
       classified by the model, with heuristic probability adjustments.

    Returns:
        A multi-line, human-readable result string. Errors are reported as text,
        except exceptions from tokenisation, which propagate.
    """
    if not drug1 or not drug2:
        return "Please enter both drugs to predict interaction severity."
    drug1 = clean_drug_name(drug1)
    drug2 = clean_drug_name(drug2)
    print(f"Processing request for drugs: '{drug1}' and '{drug2}'")

    drug_data = get_drug_features_from_dataset(drug1, drug2, df)
    if drug_data is not None:
        print(f"Found drugs in dataset, using known severity data")
        if 'severity' in drug_data:
            # NOTE(review): a fixed 98% is reported, not a measured confidence.
            confidence = 98.0
            result = f"Predicted interaction severity: {drug_data['severity']} (Confidence: {confidence:.1f}%)"
            result += "\nData source: Direct match from curated dataset"
            return result
        print(f"Using dataset features for '{drug1}' and '{drug2}'")
        try:
            interaction_description = _describe_from_dataset_row(drug1, drug2, drug_data)
        except Exception as e:
            print(f"Error extracting features from dataset: {e}")
            return f"Error processing drug data: {e}"
        is_from_dataset = True
        either_drug_known = True  # only consulted on the API path
    else:
        print("Drugs not found in dataset, validating through other means")
        validation_results = []
        for drug_name in (drug1, drug2):
            is_valid, message = validate_drug_input(drug_name)
            if is_valid is None:
                is_valid, message = validate_drug_via_api(drug_name)
            validation_results.append((drug_name, is_valid, message))
        invalid_drugs = [(name, msg) for name, valid, msg in validation_results if not valid]
        if invalid_drugs:
            invalid_names = ", ".join([f"'{name}' ({msg})" for name, msg in invalid_drugs])
            return f"Invalid drug name(s): {invalid_names}. Please enter valid drug names."

        print(f"Fetching API data for '{drug1}' and '{drug2}'")
        either_drug_known = drug1 in all_drugs or drug2 in all_drugs
        interaction_description = _describe_from_api(
            drug1, drug2, validation_results[0][1], validation_results[1][1])
        if interaction_description is None:
            return "Couldn't retrieve sufficient data for one or both drugs. Please try different drugs or check your spelling."
        is_from_dataset = False

    print(f"Using description: {interaction_description}")
    inputs = tokenizer(interaction_description, return_tensors="pt", padding=True, truncation=True, max_length=128)
    # Uses the module-level `device` the model was moved to (no local re-creation).
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    try:
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
            class_names = _severity_class_names()
            probabilities = _calibrated_probabilities(
                outputs.logits, class_names, is_from_dataset, either_drug_known)
            prediction = torch.argmax(probabilities, dim=1).item()
            severity_label = class_names[prediction]
            confidence = probabilities[0][prediction].item() * 100
            if not is_from_dataset:
                # NOTE(review): this raises low model confidence to a per-label floor, so
                # the displayed figure can exceed the model's actual probability.
                min_conf = _MIN_DISPLAY_CONFIDENCE.get(severity_label, 70.0)
                if confidence < min_conf:
                    confidence = min(min_conf + 5.0, 95.0)

            result = f"Predicted interaction severity: {severity_label} (Confidence: {confidence:.1f}%)"
            if is_from_dataset:
                result += "\nData source: Features from dataset (higher reliability)"
            else:
                result += "\nData source: Features from PubChem API"
            result += _INTERPRETATIONS.get(severity_label, "")
            result += "\n\nDisclaimer: This prediction is for research purposes only. Always consult healthcare professionals."
            return result
    except Exception as e:
        print(f"Error during prediction: {e}")
        return f"Error making prediction: {e}"
# Gradio Interface
# Wire predict_severity() to a simple form: two drug-name text boxes in, one text box
# out, plus a few clickable example pairs.
_drug1_box = gr.Textbox(label="Drug 1 (e.g., Aspirin)", placeholder="Enter first drug name")
_drug2_box = gr.Textbox(label="Drug 2 (e.g., Warfarin)", placeholder="Enter second drug name")
_result_box = gr.Textbox(label="Prediction Result")
_example_pairs = [
    ["Aspirin", "Warfarin"],
    ["Ibuprofen", "Naproxen"],
    ["Hydralazine", "Amphetamine"],
]

interface = gr.Interface(
    fn=predict_severity,
    inputs=[_drug1_box, _drug2_box],
    outputs=_result_box,
    examples=_example_pairs,
    title="Drug Interaction Severity Predictor",
    description="Enter two drug names to predict the severity of their interaction.",
)

# Launch the interface only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch(debug=True)