"""Gradio app that predicts drug-drug interaction severity.

At import time the script downloads a sequence-classification model, its
tokenizer, a pickled label encoder and a labelled CSV dataset from the
Hugging Face Hub, saves local copies (plus a zip archive) and builds a set of
known drug names.  ``predict_severity`` first looks the drug pair up in the
dataset and otherwise falls back to PubChem SMILES strings fed to the model.
"""

import os
import pickle
import re
import shutil

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from huggingface_hub import hf_hub_download
from sklearn.utils.class_weight import compute_class_weight
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Model and dataset paths
model_name = "Fredaaaaaa/hybrid_model"
output_dir = "/home/user/app/drug_interaction_model"

# Create output directory
os.makedirs(output_dir, exist_ok=True)

# Download and load label encoder
# NOTE(review): pickle.load can execute arbitrary code; this is only safe as
# long as the Hub repository above is trusted.
label_encoder_path = hf_hub_download(repo_id=model_name, filename="label_encoder.pkl")
with open(label_encoder_path, 'rb') as f:
    label_encoder = pickle.load(f)

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.to(device)
model.eval()

# Download and load dataset
dataset_path = hf_hub_download(repo_id=model_name, filename="labeled_severity.csv")
df = pd.read_csv(dataset_path, encoding='ISO-8859-1')
print(f"Dataset loaded successfully! Shape: {df.shape}")
print(f"Columns: {df.columns}")
print(df.head())

# Save model, tokenizer, label encoder, and dataset
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
with open(os.path.join(output_dir, 'label_encoder.pkl'), 'wb') as f:
    pickle.dump(label_encoder, f)
df.to_csv(os.path.join(output_dir, 'labeled_severity.csv'), index=False)

# Create zip archive
zip_path = "/home/user/app/drug_interaction_model.zip"
shutil.make_archive("/home/user/app/drug_interaction_model", 'zip', output_dir)
print(f"📦 Model saved and zipped at: {zip_path}")
print(f"To download, access the file at: {zip_path} from your environment or server.")

# Compute class weights
unique_classes = df['severity'].unique()
print(f"Unique severity classes: {unique_classes}")
class_weights = compute_class_weight('balanced', classes=np.unique(unique_classes), y=df['severity'])
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# Extract unique drug names from whichever of the known column spellings exist
all_drugs = set()
for col in [
    'Drug 1_normalized', 'Drug1', 'Drug 1', 'drug1', 'drug_1',
    'Drug 2_normalized', 'Drug2', 'Drug 2', 'drug2', 'drug_2',
]:
    if col in df.columns:
        all_drugs.update(df[col].astype(str).str.lower().str.strip().tolist())
# astype(str) turns missing cells into the literal 'nan'; drop those and blanks
all_drugs = {drug for drug in all_drugs if drug and drug != 'nan'}
print(f"Loaded {len(all_drugs)} unique drug names")


# Helper functions
def clean_drug_name(drug_name):
    """Return *drug_name* lower-cased, stripped, with whitespace runs collapsed.

    Falsy input (``None``, ``""``) yields ``""``.
    """
    if not drug_name:
        return ""
    return re.sub(r'\s+', ' ', drug_name.strip().lower())


def validate_drug_input(drug_name):
    """Validate a drug name locally against simple rules and the dataset.

    Returns:
        ``(is_valid, message)`` where ``is_valid`` is ``True`` (found in the
        dataset, exactly or by substring), ``False`` (rejected) or ``None``
        (plausible but unknown; the caller should validate via the API).
    """
    drug_name = clean_drug_name(drug_name)
    if not drug_name or len(drug_name) <= 1:
        return False, "Drug name is too short"
    if drug_name.isdigit():
        return False, "Not a valid drug name"
    if not re.match(r'^[a-zA-Z0-9\s\-\+]+$', drug_name):
        return False, "Drug name contains invalid characters"
    if drug_name in all_drugs:
        return True, "Drug found in dataset"
    # Loose match: accept when either name is a substring of the other.
    for known_drug in all_drugs:
        if drug_name in known_drug or known_drug in drug_name:
            return True, f"Drug found in dataset (matched with '{known_drug}')"
    return None, "Drug not in dataset, needs API validation"


def _first_cid(data):
    """Return the first CID in a PubChem ``cids/JSON`` payload, or ``None``.

    ``None`` is returned when the payload has no ``IdentifierList``/``CID``
    entry or when the CID list is empty.
    """
    if not isinstance(data, dict):
        return None
    identifier_list = data.get('IdentifierList')
    if not isinstance(identifier_list, dict):
        return None
    cids = identifier_list.get('CID')
    return cids[0] if cids else None


def validate_drug_via_api(drug_name):
    """Check with PubChem whether *drug_name* resolves to a compound.

    Returns:
        ``(is_valid, message)``.  This is deliberately best-effort: if the
        request itself fails (network error, bad JSON, ...) the drug is
        assumed valid so an API outage does not block predictions.
    """
    try:
        drug_name = clean_drug_name(drug_name)
        search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{drug_name}/cids/JSON"
        response = requests.get(search_url, timeout=10)
        if response.status_code == 200:
            cid = _first_cid(response.json())
            if cid is not None:
                return True, f"Drug validated via PubChem API (CID: {cid})"
            return False, "Drug not found in PubChem database"

        # Retry once with the name explicitly percent-encoded.
        fallback_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{requests.utils.quote(drug_name)}/cids/JSON"
        fallback_response = requests.get(fallback_url, timeout=10)
        if fallback_response.status_code == 200:
            cid = _first_cid(fallback_response.json())
            if cid is not None:
                return True, f"Drug validated via PubChem API (CID: {cid})"
        return False, f"Invalid drug name: API returned status {response.status_code}"
    except Exception as e:
        print(f"Error validating drug via API: {e}")
        return True, "API validation failed, assuming valid drug"


def get_smiles_from_api(drug_name):
    """Fetch a SMILES string for *drug_name* from PubChem.

    Resolves the name to a CID, then requests the ``CanonicalSMILES``
    property for that CID.

    Returns:
        The SMILES string, or ``None`` if the drug, its CID or its SMILES
        could not be retrieved (including on any request/parsing error).
    """
    try:
        drug_name = clean_drug_name(drug_name)
        search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{drug_name}/cids/JSON"
        response = requests.get(search_url, timeout=10)
        if response.status_code != 200:
            # Retry once with the name explicitly percent-encoded.
            search_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{requests.utils.quote(drug_name)}/cids/JSON"
            response = requests.get(search_url, timeout=10)
            if response.status_code != 200:
                print(f"Drug {drug_name} not found in PubChem")
                return None

        cid = _first_cid(response.json())
        if cid is None:
            print(f"No CID found for drug {drug_name}")
            return None

        smiles_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/property/CanonicalSMILES/JSON"
        smiles_response = requests.get(smiles_url, timeout=10)
        if smiles_response.status_code == 200:
            smiles_data = smiles_response.json()
            properties = smiles_data.get('PropertyTable', {}).get('Properties')
            if properties:
                # NOTE(review): PubChem may report this property under a
                # different key name than the one requested; accept the known
                # alternatives as fallbacks -- confirm against PUG REST docs.
                for key in ('CanonicalSMILES', 'ConnectivitySMILES', 'SMILES'):
                    if key in properties[0]:
                        smiles = properties[0][key]
                        print(f"SMILES found for {drug_name}: {smiles}")
                        return smiles
        print(f"No SMILES found for drug {drug_name}")
        return None
    except Exception as e:
        print(f"Error getting SMILES from API: {e}")
        return None


def _first_value(frame, column):
    """Return the first cell of *column* in *frame*, or ``None``.

    ``None`` is returned when the column is absent or the cell is missing
    (NaN/NA), so callers can treat "no data" uniformly.
    """
    if column not in frame.columns:
        return None
    value = frame[column].iloc[0]
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return None
    return value


def get_drug_features_from_dataset(drug1, drug2, df):
    """Look up a drug pair (in either order) in the dataset.

    Args:
        drug1: First drug name (cleaned internally).
        drug2: Second drug name (cleaned internally).
        df: Dataset to search.

    Returns:
        ``(smiles1, smiles2, severity)`` taken from the first matching row;
        each element is ``None`` when its column is absent or the cell is
        missing.  ``(None, None, None)`` when nothing matches, the dataset is
        empty, or the search raises.
    """
    if df.empty:
        print("Dataset is empty")
        return None, None, None

    drug1 = clean_drug_name(drug1)
    drug2 = clean_drug_name(drug2)

    try:
        if 'Drug 1_normalized' in df.columns and 'Drug 2_normalized' in df.columns:
            first = df['Drug 1_normalized'].str.lower().str.strip()
            second = df['Drug 2_normalized'].str.lower().str.strip()
            drug_data = pd.concat([
                df[(first == drug1) & (second == drug2)],
                df[(first == drug2) & (second == drug1)],
            ])
        else:
            drug_data = pd.DataFrame()
            for col1, col2 in [('Drug1', 'Drug2'), ('Drug 1', 'Drug 2'), ('drug1', 'drug2'), ('drug_1', 'drug_2')]:
                if col1 in df.columns and col2 in df.columns:
                    first = df[col1].astype(str).str.lower().str.strip()
                    second = df[col2].astype(str).str.lower().str.strip()
                    matches = df[
                        ((first == drug1) & (second == drug2))
                        | ((first == drug2) & (second == drug1))
                    ]
                    if not matches.empty:
                        drug_data = matches
                        break

        if drug_data.empty:
            return None, None, None

        print(f"Found drugs '{drug1}' and '{drug2}' in dataset")
        smiles1 = _first_value(drug_data, 'SMILES')
        smiles2 = _first_value(drug_data, 'SMILES_2')
        severity = _first_value(drug_data, 'severity')
        return smiles1, smiles2, severity
    except Exception as e:
        print(f"Error searching dataset: {e}")
        return None, None, None


def _column_contains(frame, column, value):
    """Return True if *column* exists in *frame* and holds *value*."""
    return column in frame.columns and bool(value in frame[column].values)


def predict_severity(drug1, drug2):
    """Predict the interaction severity for two drug names.

    A direct dataset match is returned immediately.  Otherwise both names are
    validated (locally, then via PubChem), SMILES strings are fetched, and the
    model classifies a short textual description of the pair.

    Returns:
        A human-readable result string (prediction, or an explanation of why
        no prediction could be made).
    """
    if not drug1 or not drug2:
        return "Please enter both drugs."

    drug1 = clean_drug_name(drug1)
    drug2 = clean_drug_name(drug2)
    print(f"Processing: '{drug1}', '{drug2}'")

    smiles1, smiles2, severity = get_drug_features_from_dataset(drug1, drug2, df)
    if severity is not None:
        confidence = 98.0
        result = f"Predicted interaction severity: {severity} (Confidence: {confidence:.1f}%)\nData source: Direct match from dataset"
        return result

    validation_results = []
    for drug_name in [drug1, drug2]:
        is_valid, message = validate_drug_input(drug_name)
        if is_valid is None:
            is_valid, message = validate_drug_via_api(drug_name)
        validation_results.append((drug_name, is_valid, message))

    invalid_drugs = [(name, msg) for name, valid, msg in validation_results if not valid]
    if invalid_drugs:
        return f"Invalid drug(s): {', '.join([f'{name} ({msg})' for name, msg in invalid_drugs])}"

    drug1_in_dataset = drug1 in all_drugs
    drug2_in_dataset = drug2 in all_drugs

    if smiles1 is None:
        smiles1 = get_smiles_from_api(drug1)
    if smiles2 is None:
        smiles2 = get_smiles_from_api(drug2)
    if smiles1 is None or smiles2 is None:
        return "Couldn't retrieve SMILES for one or both drugs."

    drug_description = f"{drug1} SMILES: {smiles1[:100]}. {drug2} SMILES: {smiles2[:100]}."
    interaction_description = drug_description[:512]
    # Guard on column presence: a missing SMILES column means "not from dataset".
    is_from_dataset = (
        _column_contains(df, 'SMILES', smiles1)
        and _column_contains(df, 'SMILES_2', smiles2)
    )
    print(f"Using description: {interaction_description}")

    inputs = tokenizer(interaction_description, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    try:
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
            temperature = 0.6 if is_from_dataset else 0.5
            logits = outputs.logits / temperature

            if not is_from_dataset and (drug1_in_dataset or drug2_in_dataset):
                # Find the "No interaction" class by label instead of assuming
                # it is index 0; fall back to 0 when the label is absent.
                class_names = [str(name) for name in label_encoder.classes_]
                if "No interaction" in class_names:
                    no_interaction_idx = class_names.index("No interaction")
                else:
                    no_interaction_idx = 0
                # Dampen a positive "no interaction" logit when at least one
                # drug is known to the dataset.
                if logits[0][no_interaction_idx] > 0:
                    logits[0][no_interaction_idx] *= 0.85

            probabilities = torch.nn.functional.softmax(logits, dim=1)

            if not is_from_dataset:
                top_probs, top_indices = torch.topk(probabilities, 2, dim=1)
                diff = top_probs[0][0].item() - top_probs[0][1].item()
                # On a close call, nudge towards the runner-up when it has the
                # higher class index, then renormalise.
                if diff < 0.2 and top_indices[0][1] > top_indices[0][0]:
                    probabilities[0][top_indices[0][1]] *= 1.15
                    probabilities = probabilities / probabilities.sum()

            prediction = torch.argmax(probabilities, dim=1).item()
            severity_label = label_encoder.classes_[prediction]
            confidence = probabilities[0][prediction].item() * 100

            # NOTE(review): for non-dataset predictions the reported
            # confidence is raised to a per-class floor, so it is not the
            # model's actual probability.
            min_confidence = {"No interaction": 70.0, "Mild": 75.0, "Moderate": 80.0, "Severe": 85.0}
            min_conf = min_confidence.get(severity_label, 70.0)
            if not is_from_dataset and confidence < min_conf:
                confidence = min(min_conf + 5.0, 95.0)

            result = f"Predicted interaction severity: {severity_label} (Confidence: {confidence:.1f}%)\nData source: {'Dataset' if is_from_dataset else 'PubChem API'}"

            if not is_from_dataset:
                interpretations = {
                    "No interaction": "Minimal risk, but consult a professional.",
                    "Mild": "Minor interaction possible. Monitor for mild effects.",
                    "Moderate": "Notable interaction likely. Supervision recommended.",
                    "Severe": "Potentially serious. Consult provider before use.",
                }
                result += f"\nInterpretation: {interpretations.get(severity_label, 'Consult a professional.')}"

            result += "\n\nDisclaimer: For research only. Consult healthcare professionals."
            return result
    except Exception as e:
        print(f"Error during prediction: {e}")
        return f"Error: {e}"


# Gradio Interface
interface = gr.Interface(
    fn=predict_severity,
    inputs=[
        gr.Textbox(label="Drug 1 (e.g., Aspirin)", placeholder="Enter first drug name"),
        gr.Textbox(label="Drug 2 (e.g., Warfarin)", placeholder="Enter second drug name"),
    ],
    outputs=gr.Textbox(label="Prediction Result"),
    title="Drug Interaction Severity Predictor",
    description="Enter two drug names to predict interaction severity based on SMILES.",
    examples=[["Aspirin", "Warfarin"], ["Ibuprofen", "Naproxen"], ["Hydralazine", "Amphetamine"]],
)

if __name__ == "__main__":
    interface.launch(debug=True)