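# Drug Interaction Severity Predictor: Gradio app for a Hugging Face Space.
# Loads a fine-tuned sequence-classification model, its label encoder, and a
# cleaned drug-interaction dataset from the "Fredaaaaaa/hybrid_model" repo,
# and falls back to the PubChem PUG REST API for drugs not in the dataset.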
import pickle
import requests
import torch
import gradio as gr
import pandas as pd
import re
import numpy as np
import os
import shutil
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.utils.class_weight import compute_class_weight
from collections import defaultdict

print("Starting script execution...")
# Helper functions (moved up so they are defined before the setup code that calls them)
def clean_drug_name(drug_name):
    if not drug_name:
        return ""
    # Collapse internal whitespace and lowercase for consistent matching
    return re.sub(r'\s+', ' ', drug_name.strip().lower())
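# Example of the normalization above (illustrative):
#   clean_drug_name("  Aspirin \t 81MG ") -> "aspirin 81mg"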
def validate_drug_input(drug_name):
    drug_name = clean_drug_name(drug_name)
    if not drug_name or len(drug_name) <= 1:
        return False, "Drug name is too short"
    if drug_name.isdigit():  # the len == 1 case is already covered by the check above
        return False, "Not a valid drug name"
    if not re.match(r'^[a-zA-Z0-9\s\-\+]+$', drug_name):
        return False, "Drug name contains invalid characters"
    if drug_name in all_drugs:
        return True, "Drug found in dataset"
    for known_drug in all_drugs:
        if drug_name in known_drug or known_drug in drug_name:
            return True, f"Drug found in dataset (matched with '{known_drug}')"
    return None, "Drug not in dataset, needs API validation"
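# Note the three-way contract: True/False are definitive local answers, while
# None means "unknown locally" and tells the caller to fall back to
# validate_drug_via_api() below.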
def validate_drug_via_api(drug_name):
    try:
        drug_name = clean_drug_name(drug_name)
        search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{drug_name}/cids/JSON"
        response = requests.get(search_url, timeout=10)
        if response.status_code == 200:
            data = response.json()
            if 'IdentifierList' in data and 'CID' in data['IdentifierList']:
                return True, f"Drug validated via PubChem API (CID: {data['IdentifierList']['CID'][0]})"
            return False, "Drug not found in PubChem database"
        else:
            # Retry with a URL-encoded name in case the raw name broke the request
            fallback_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{requests.utils.quote(drug_name)}/cids/JSON"
            fallback_response = requests.get(fallback_url, timeout=10)
            if fallback_response.status_code == 200:
                data = fallback_response.json()
                if 'IdentifierList' in data and 'CID' in data['IdentifierList']:
                    return True, f"Drug validated via PubChem API (CID: {data['IdentifierList']['CID'][0]})"
            return False, f"Invalid drug name: API returned status {response.status_code}"
    except Exception as e:
        print(f"Error validating drug via API: {e}")
        return True, "API validation failed, assuming valid drug"
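# Design note: the except branch fails open (returns True) so that transient
# network errors do not block predictions; a stricter deployment might fail
# closed instead.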
def get_smiles_from_api(drug_name):
    # Normalize first so the cache lookup uses the same key the cache was built with
    drug_name = clean_drug_name(drug_name)
    if drug_name in drug_features_cache and 'smiles' in drug_features_cache[drug_name]:
        return drug_features_cache[drug_name]['smiles']
    try:
        search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{drug_name}/cids/JSON"
        response = requests.get(search_url, timeout=10)
        if response.status_code != 200:
            search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{requests.utils.quote(drug_name)}/cids/JSON"
            response = requests.get(search_url, timeout=10)
            if response.status_code != 200:
                print(f"Drug {drug_name} not found in PubChem")
                return None
        data = response.json()
        if 'IdentifierList' not in data or 'CID' not in data['IdentifierList']:
            print(f"No CID found for drug {drug_name}")
            return None
        cid = data['IdentifierList']['CID'][0]
        smiles_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/property/CanonicalSMILES/JSON"
        smiles_response = requests.get(smiles_url, timeout=10)
        if smiles_response.status_code == 200:
            smiles_data = smiles_response.json()
            if 'PropertyTable' in smiles_data and 'Properties' in smiles_data['PropertyTable']:
                properties = smiles_data['PropertyTable']['Properties']
                if properties and 'CanonicalSMILES' in properties[0]:
                    smiles = properties[0]['CanonicalSMILES']
                    drug_features_cache[drug_name]['smiles'] = smiles
                    return smiles
        print(f"No SMILES found for drug {drug_name}")
        return None
    except Exception as e:
        print(f"Error getting SMILES from API: {e}")
        return None
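# The lookup above is PubChem's two-step PUG REST flow (illustrative URLs;
# 2244 is aspirin's CID):
#   name -> CID:    .../compound/name/aspirin/cids/JSON
#   CID -> SMILES:  .../compound/cid/2244/property/CanonicalSMILES/JSON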
def get_pubchem_features(smiles):
    # Reuse cached features if any cached drug already resolved to this SMILES
    if smiles:
        for drug, data in drug_features_cache.items():
            if data.get('smiles') == smiles and data.get('features') is not None:
                return data['features']
    try:
        properties_url = (
            f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/smiles/"
            f"{requests.utils.quote(smiles)}/property/MolecularWeight,XLogP,TPSA,"
            f"RotatableBondCount,HBondDonorCount,HBondAcceptorCount,Complexity,ExactMass/JSON"
        )
        response = requests.get(properties_url, timeout=10)
        if response.status_code == 200:
            data = response.json()
            if 'PropertyTable' in data and 'Properties' in data['PropertyTable']:
                props = data['PropertyTable']['Properties'][0]
                features = {
                    'molecular_weight': props.get('MolecularWeight', 0),
                    'xlogp': props.get('XLogP', 0),
                    'tpsa': props.get('TPSA', 0),
                    'rotatable_bond_count': props.get('RotatableBondCount', 0),
                    'h_bond_donor_count': props.get('HBondDonorCount', 0),
                    'h_bond_acceptor_count': props.get('HBondAcceptorCount', 0),
                    'complexity': props.get('Complexity', 0),
                    'exact_mass': props.get('ExactMass', 0)
                }
                print(f"Extracted features for SMILES {smiles}: {features}")
                # Cache on every drug entry that maps to this SMILES; .get() avoids
                # a KeyError on entries that have no 'smiles' key yet
                for drug, data in drug_features_cache.items():
                    if data.get('smiles') == smiles:
                        data['features'] = features
                return features
        print(f"Failed to retrieve features for SMILES {smiles}")
        return None
    except Exception as e:
        print(f"Error getting PubChem features: {e}")
        return None
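# On success this returns a flat feature dict, roughly (aspirin values, for
# illustration): {'molecular_weight': '180.16', 'xlogp': 1.2, 'tpsa': 63.6,
# 'rotatable_bond_count': 3, 'h_bond_donor_count': 1, 'h_bond_acceptor_count': 4,
# 'complexity': 212, 'exact_mass': '180.04'}. PubChem may return
# MolecularWeight/ExactMass as strings, which is why predict_severity() checks
# isinstance() before formatting molecular weights.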
def get_drug_features_from_dataset(drug1, drug2, df):
    if df.empty:
        print("Dataset is empty")
        return None, None, None
    drug1 = clean_drug_name(drug1)
    drug2 = clean_drug_name(drug2)
    try:
        if 'Drug 1_normalized' in df.columns and 'Drug 2_normalized' in df.columns:
            drug_data = df[
                (df['Drug 1_normalized'].str.lower().str.strip() == drug1) &
                (df['Drug 2_normalized'].str.lower().str.strip() == drug2)
            ]
            reversed_drug_data = df[
                (df['Drug 1_normalized'].str.lower().str.strip() == drug2) &
                (df['Drug 2_normalized'].str.lower().str.strip() == drug1)
            ]
            drug_data = pd.concat([drug_data, reversed_drug_data])
        else:
            drug_data = pd.DataFrame()
            for col1, col2 in [('Drug1', 'Drug2'), ('Drug 1', 'Drug 2'), ('drug1', 'drug2'), ('drug_1', 'drug_2')]:
                if col1 in df.columns and col2 in df.columns:
                    matches = df[
                        ((df[col1].astype(str).str.lower().str.strip() == drug1) &
                         (df[col2].astype(str).str.lower().str.strip() == drug2)) |
                        ((df[col1].astype(str).str.lower().str.strip() == drug2) &
                         (df[col2].astype(str).str.lower().str.strip() == drug1))
                    ]
                    if not matches.empty:
                        drug_data = matches
                        break
        if not drug_data.empty:
            print(f"Found drugs '{drug1}' and '{drug2}' in dataset")
            smiles1 = drug_data.get('canonical_smiles', None)
            smiles2 = drug_data.get('canonical_smiles_2', None)
            if isinstance(smiles1, pd.Series):
                smiles1 = smiles1.iloc[0]
            if isinstance(smiles2, pd.Series):
                smiles2 = smiles2.iloc[0]
            severity = drug_data.get('severity', None)
            if isinstance(severity, pd.Series):
                severity = severity.iloc[0]
            return smiles1, smiles2, severity
        return None, None, None
    except Exception as e:
        print(f"Error searching dataset: {e}")
        return None, None, None
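# Matching is symmetric ((drug1, drug2) and (drug2, drug1) both count as a hit),
# and the function returns a (smiles1, smiles2, severity) tuple, each element
# possibly None.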
def predict_severity(drug_input):
    if not drug_input:
        return "Please enter at least one drug pair separated by commas (e.g., 'Aspirin, Warfarin')."
    # Parse drug names from the comma-separated input; consecutive names form pairs
    drug_names = [name.strip() for name in drug_input.split(',') if name.strip()]
    if len(drug_names) < 2 or len(drug_names) % 2 != 0:
        return "Please enter drug pairs separated by commas (e.g., 'Aspirin, Warfarin' or 'Aspirin, Warfarin, Ibuprofen, Naproxen')."
    results = []
    for i in range(0, len(drug_names), 2):
        if i + 1 >= len(drug_names):
            break
        drug1, drug2 = drug_names[i], drug_names[i + 1]
        print(f"Processing: '{drug1}', '{drug2}'")
        smiles1, smiles2, severity = get_drug_features_from_dataset(drug1, drug2, df)
        if severity is not None:
            results.append(str(severity))  # ensure join() below always gets strings
            continue
        # Fall back to PubChem for drugs not in the dataset
        validation_results = []
        for drug_name in [drug1, drug2]:
            is_valid, message = validate_drug_input(drug_name)
            if is_valid is None:
                is_valid, message = validate_drug_via_api(drug_name)
            validation_results.append((drug_name, is_valid, message))
        invalid_drugs = [(name, msg) for name, valid, msg in validation_results if not valid]
        if invalid_drugs:
            results.append(f"Invalid drug(s) for {drug1} and {drug2}: "
                           + ", ".join(f"{name} ({msg})" for name, msg in invalid_drugs))
            continue
        # Use normalized names for dataset-membership and cache lookups
        drug1_key, drug2_key = clean_drug_name(drug1), clean_drug_name(drug2)
        drug1_in_dataset = drug1_key in all_drugs
        drug2_in_dataset = drug2_key in all_drugs
        # Fetch SMILES from PubChem if not found in the dataset
        if smiles1 is None:
            smiles1 = get_smiles_from_api(drug1)
        if smiles2 is None:
            smiles2 = get_smiles_from_api(drug2)
        if smiles1 is None or smiles2 is None:
            missing = [name for name, s in [(drug1, smiles1), (drug2, smiles2)] if s is None]
            results.append(f"Could not retrieve SMILES for {', '.join(missing)} from PubChem.")
            continue
        # Extract PubChem features, preferring cached values
        features1 = (drug_features_cache[drug1_key]['features']
                     if drug1_key in drug_features_cache and 'features' in drug_features_cache[drug1_key]
                     else get_pubchem_features(smiles1))
        features2 = (drug_features_cache[drug2_key]['features']
                     if drug2_key in drug_features_cache and 'features' in drug_features_cache[drug2_key]
                     else get_pubchem_features(smiles2))
        default_features = {'molecular_weight': 0, 'xlogp': 0, 'tpsa': 0, 'rotatable_bond_count': 0,
                            'h_bond_donor_count': 0, 'h_bond_acceptor_count': 0, 'complexity': 0,
                            'exact_mass': 0}
        if features1 is None:
            print(f"Warning: No features retrieved for {drug1}, using default values.")
            features1 = dict(default_features)
        if features2 is None:
            print(f"Warning: No features retrieved for {drug2}, using default values.")
            features2 = dict(default_features)
        # Combine SMILES and features into a short interaction description
        mw1 = features1.get('molecular_weight', 0)
        mw2 = features2.get('molecular_weight', 0)
        if not isinstance(mw1, (int, float)) or not isinstance(mw2, (int, float)):
            mw1, mw2 = 0, 0  # fall back if PubChem returned non-numeric values
        drug_description = (f"{drug1} SMILES: {smiles1[:50]}, MW: {mw1:.0f}. "
                            f"{drug2} SMILES: {smiles2[:50]}, MW: {mw2:.0f}.")
        interaction_description = drug_description[:256]  # cap the description length
        is_from_dataset = False
        if 'canonical_smiles' in df.columns and 'canonical_smiles_2' in df.columns:
            is_from_dataset = (smiles1 in df['canonical_smiles'].values
                               and smiles2 in df['canonical_smiles_2'].values)
        print(f"Using description: {interaction_description}")
        inputs = tokenizer(interaction_description, return_tensors="pt", padding=True,
                           truncation=True, max_length=128)
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)
        try:
            with torch.no_grad():
                outputs = model(input_ids, attention_mask=attention_mask)
                # Temperature-scale the logits (a lower temperature sharpens the
                # distribution for pairs not seen in the dataset)
                temperature = 0.6 if is_from_dataset else 0.5
                logits = outputs.logits / temperature
                if not is_from_dataset and (drug1_in_dataset or drug2_in_dataset):
                    # Dampen the "no interaction" logit when part of the pair is known
                    no_interaction_idx = 0
                    if logits[0][no_interaction_idx] > 0:
                        logits[0][no_interaction_idx] *= 0.85
                probabilities = torch.nn.functional.softmax(logits, dim=1)
                if not is_from_dataset:
                    # If the top two classes are close, nudge toward the
                    # higher-index class, then renormalize
                    top_probs, top_indices = torch.topk(probabilities, 2, dim=1)
                    diff = top_probs[0][0].item() - top_probs[0][1].item()
                    if diff < 0.2 and top_indices[0][1] > top_indices[0][0]:
                        probabilities[0][top_indices[0][1]] *= 1.15
                        probabilities = probabilities / probabilities.sum()
                prediction = torch.argmax(probabilities, dim=1).item()
                severity_label = label_encoder.classes_[prediction]
                results.append(severity_label)
        except Exception as e:
            print(f"Error during prediction for {drug1} and {drug2}: {e}")
            results.append(f"Error for {drug1} and {drug2}: {e}")
    return "\n\n".join(results) if results else "No valid predictions."
| print("Setting up device...") | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print(f"Using device: {device}") | |
| # Model and dataset paths | |
| print("Setting up model and dataset paths...") | |
| model_name = "Fredaaaaaa/hybrid_model" | |
| output_dir = "/home/user/app/drug_interaction_model" | |
| # Create output directory | |
| print("Creating output directory...") | |
| os.makedirs(output_dir, exist_ok=True) | |
| # Download and load label encoder | |
| print("Downloading and loading label encoder...") | |
| label_encoder_path = hf_hub_download(repo_id=model_name, filename="label_encoder.pkl") | |
| with open(label_encoder_path, 'rb') as f: | |
| label_encoder = pickle.load(f) | |
| # Load model and tokenizer | |
| print("Loading model and tokenizer...") | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| model = AutoModelForSequenceClassification.from_pretrained(model_name) | |
| model.to(device) | |
| model.eval() | |
| # Download and load dataset | |
| print("Downloading and loading dataset...") | |
| dataset_path = hf_hub_download(repo_id=model_name, filename="merged_cleaned_dataset.csv") | |
| df = pd.read_csv(dataset_path, encoding='ISO-8859-1') | |
| print(f"Dataset loaded successfully! Shape: {df.shape}") | |
| print(f"Columns: {df.columns.tolist()}") # Print columns to debug | |
| print(df.head()) | |
| # Save model, tokenizer, label encoder, and dataset | |
| print("Saving model, tokenizer, label encoder, and dataset...") | |
| model.save_pretrained(output_dir) | |
| tokenizer.save_pretrained(output_dir) | |
| with open(os.path.join(output_dir, 'label_encoder.pkl'), 'wb') as f: | |
| pickle.dump(label_encoder, f) | |
| df.to_csv(os.path.join(output_dir, 'merged_cleaned_dataset.csv'), index=False) | |
| # Create zip archive | |
| print("Creating zip archive...") | |
| zip_path = "/home/user/app/drug_interaction_model.zip" | |
| shutil.make_archive("/home/user/app/drug_interaction_model", 'zip', output_dir) | |
| print(f"📦 Model saved and zipped at: {zip_path}") | |
| print(f"To download, access the file at: {zip_path} from your environment or server.") | |
| # Compute class weights | |
| print("Computing class weights...") | |
| unique_classes = df['severity'].unique() | |
| print(f"Unique severity classes: {unique_classes}") | |
| class_weights = compute_class_weight('balanced', classes=np.unique(unique_classes), y=df['severity']) | |
| class_weights = torch.tensor(class_weights, dtype=torch.float).to(device) | |
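# Note: class_weights is computed here but never consumed at inference time in
# this script; during training it would typically feed a weighted loss, e.g.
# (sketch): loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights)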
# Extract unique drug names and precompute dataset features
print("Extracting unique drug names and precomputing features...")
all_drugs = set()
drug_features_cache = defaultdict(dict)
for col in ['Drug 1_normalized', 'Drug1', 'Drug 1', 'drug1', 'drug_1']:
    if col in df.columns:
        all_drugs.update(df[col].astype(str).str.lower().str.strip().tolist())
for col in ['Drug 2_normalized', 'Drug2', 'Drug 2', 'drug2', 'drug_2']:
    if col in df.columns:
        all_drugs.update(df[col].astype(str).str.lower().str.strip().tolist())
all_drugs = {drug for drug in all_drugs if drug and drug != 'nan'}
print(f"Loaded {len(all_drugs)} unique drug names")

# Precompute SMILES and features for dataset drugs
print("Precomputing SMILES and features...")
for index, row in df.iterrows():
    for col in ['Drug 1_normalized', 'Drug1', 'Drug 1', 'drug1', 'drug_1']:
        if col in df.columns and pd.notna(row[col]):
            drug = clean_drug_name(row[col])
            if 'canonical_smiles' in df.columns and pd.notna(row['canonical_smiles']):
                drug_features_cache[drug]['smiles'] = row['canonical_smiles']
                drug_features_cache[drug]['features'] = get_pubchem_features(row['canonical_smiles'])
    for col in ['Drug 2_normalized', 'Drug2', 'Drug 2', 'drug2', 'drug_2']:
        if col in df.columns and pd.notna(row[col]):
            drug = clean_drug_name(row[col])
            if 'canonical_smiles_2' in df.columns and pd.notna(row['canonical_smiles_2']):
                drug_features_cache[drug]['smiles'] = row['canonical_smiles_2']
                drug_features_cache[drug]['features'] = get_pubchem_features(row['canonical_smiles_2'])
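# Caution: this warm-up loop issues roughly one PubChem request per distinct
# SMILES (get_pubchem_features reuses cached results for repeated structures),
# so startup time scales with the number of unique structures in the dataset.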
# Gradio interface
print("Setting up Gradio interface...")
interface = gr.Interface(
    fn=predict_severity,
    inputs=gr.Textbox(label="Drug Pairs (e.g., 'Aspirin, Warfarin' or 'Aspirin, Warfarin, Ibuprofen, Naproxen')",
                      placeholder="Enter drug names separated by commas"),
    outputs=gr.Textbox(label="Prediction Result"),
    title="Drug Interaction Severity Predictor",
    description="Enter drug pairs separated by commas to predict interaction severity based on SMILES and PubChem features.",
    examples=[["Aspirin, Warfarin"], ["Ibuprofen, Naproxen, Hydralazine, Amphetamine"]]
)
| print("Launching Gradio interface...") | |
| if _name_ == "_main_": | |
| interface.launch(debug=True) |