# Captured from a Hugging Face Space page (Space status at capture time: Sleeping).
# json is part of the standard library, so it cannot be the missing dependency;
# import it unconditionally instead of inside the guarded block.
import json

# First try to import the third-party dependencies; fall back to installing them.
try:
    import torch
    from transformers import AutoTokenizer, AutoModel
    import joblib
    from huggingface_hub import hf_hub_download
except ImportError as e:
    print(f"Import error: {e}")
    # Best-effort fallback: install the missing packages into the running
    # interpreter (this might not work in Spaces). NOTE(review): versions are
    # unpinned and installed at import time - prefer a requirements file.
    import importlib
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "transformers", "joblib", "huggingface-hub"])
    # Modules installed after interpreter start-up may not be visible until the
    # import system's finder caches are invalidated.
    importlib.invalidate_caches()
    import torch
    from transformers import AutoTokenizer, AutoModel
    import joblib
    from huggingface_hub import hf_hub_download
class DrugInteractionClassifier(torch.nn.Module):
    """Pretrained encoder followed by a small MLP head producing class logits.

    The head is ``Linear(hidden_size, hidden_dim) -> ReLU -> Dropout(dropout)
    -> Linear(hidden_dim, n_classes)``, held in a ``torch.nn.Sequential`` named
    ``classifier``. The layer order (and therefore the state-dict keys) is the
    same for any ``hidden_dim``/``dropout``, so checkpoints saved with the
    defaults (256, 0.3) still load.

    Args:
        n_classes: number of output classes (size of the final Linear layer).
        bert_model_name: model id passed to ``AutoModel.from_pretrained``.
        hidden_dim: width of the intermediate layer of the head.
        dropout: dropout probability applied inside the head.
    """

    def __init__(self, n_classes, bert_model_name="emilyalsentzer/Bio_ClinicalBERT",
                 hidden_dim=256, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.bert.config.hidden_size, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, n_classes),
        )

    def forward(self, input_ids, attention_mask=None):
        """Return unnormalised logits for each input sequence.

        The sequence representation is the encoder's first output taken at
        position 0 along the sequence axis (``[:, 0, :]``). Presumably this is
        ``last_hidden_state`` of shape (batch, seq, hidden) and position 0 is the
        ``[CLS]`` token - confirm for non-BERT encoders.

        Args:
            input_ids: token id tensor passed straight to the encoder.
            attention_mask: optional mask tensor passed to the encoder; ``None``
                lets the encoder use its default (attend to every position).
        """
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = bert_output[0][:, 0, :]
        return self.classifier(cls_embedding)
class DDIPredictor:
    """Drug-interaction severity predictor loaded from a Hugging Face Hub repo.

    On construction it downloads ``config.json``, ``pytorch_model.bin`` and
    ``label_encoder.joblib`` from ``repo_id``, loads the tokenizer from the
    same repo, builds a ``DrugInteractionClassifier`` and loads the saved
    weights on CUDA when available, otherwise on CPU.

    NOTE(review): the emoji prefixes in the progress messages look mis-encoded
    (UTF-8 bytes decoded with a Greek code page); they are kept verbatim here.
    """

    def __init__(self, repo_id="Fredaaaaaa/drug_interaction_severity"):
        """Download the model artefacts from ``repo_id`` and build the model in eval mode.

        Raises:
            Exception: any error raised while downloading or loading is printed
                and then re-raised unchanged.
        """
        self.repo_id = repo_id
        print(f"π Loading model from: {repo_id}")
        try:
            # Download model files from Hugging Face
            print("π₯ Downloading config.json...")
            self.config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
            print("π₯ Downloading pytorch_model.bin...")
            self.model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
            print("π₯ Downloading label_encoder.joblib...")
            self.label_encoder_path = hf_hub_download(repo_id=repo_id, filename="label_encoder.joblib")

            # config.json must provide "num_labels" and "bert_model_name";
            # "max_length" is optional (see predict). JSON is UTF-8 by spec, so
            # do not rely on the platform's default text encoding.
            with open(self.config_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)

            print("π€ Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(repo_id)

            # NOTE(review): joblib.load and torch.load below are pickle-based and
            # can execute code embedded in the downloaded files - only use a
            # trusted repo_id. Since the weights file is consumed by
            # load_state_dict, consider torch.load(..., weights_only=True)
            # (PyTorch >= 1.13).
            print("π·οΈ Loading label encoder...")
            self.label_encoder = joblib.load(self.label_encoder_path)

            print("π§ Initializing model...")
            self.model = DrugInteractionClassifier(
                n_classes=self.config["num_labels"],
                bert_model_name=self.config["bert_model_name"],
            )

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"βοΈ Loading weights on {device}...")
            self.model.load_state_dict(
                torch.load(self.model_path, map_location=device)
            )
            self.model.to(device)
            self.model.eval()
            self.device = device
            print(f"β Model loaded successfully from {repo_id} on {device}")
        except Exception as e:
            print(f"β Error loading model: {e}")
            raise

    def _blank_result(self, prediction):
        """Return a result dict with zero confidence and a 0.0 score for every class."""
        return {
            "prediction": prediction,
            "confidence": 0.0,
            "probabilities": {label: 0.0 for label in self.label_encoder.classes_},
        }

    def predict(self, text, confidence_threshold=0.0):
        """Predict drug interaction severity for a single text.

        Args:
            text: input string. Falsy values and whitespace-only strings are
                rejected as invalid input.
            confidence_threshold: accepted for API compatibility but currently
                NOT used. NOTE(review): define its intended effect before relying on it.

        Returns:
            dict with ``"prediction"`` (top label), ``"confidence"`` (its softmax
            probability as a float) and ``"probabilities"`` (label -> float for
            every class). Invalid input yields prediction ``"Invalid Input"``;
            any error during inference yields ``"Error: <message>"``. In both
            cases all scores are 0.0. This method does not raise on inference
            errors.
        """
        if not text or not text.strip():
            return self._blank_result("Invalid Input")

        try:
            inputs = self.tokenizer(
                text,
                max_length=self.config.get("max_length", 128),
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                logits = self.model(inputs["input_ids"], inputs["attention_mask"])
                probabilities = torch.softmax(logits, dim=1)
                confidence, predicted_idx = torch.max(probabilities, dim=1)

            # A single text gives a batch of one. Map every class index to its
            # label with one inverse_transform call instead of one call per class.
            scores = probabilities[0].tolist()
            labels = self.label_encoder.inverse_transform(list(range(len(scores))))
            return {
                "prediction": labels[predicted_idx.item()],
                "confidence": confidence.item(),
                "probabilities": dict(zip(labels, scores)),
            }
        except Exception as e:
            # Deliberate best-effort: report the failure in the result, don't raise.
            return self._blank_result(f"Error: {str(e)}")
# Module-level predictor, built once at import time (this downloads the model
# artefacts from the Hub). A load failure is printed and tolerated: the module
# still imports, with predictor set to None and MODEL_LOADED set to False so
# code reading these globals can tell whether loading succeeded.
try:
    predictor = DDIPredictor("Fredaaaaaa/drug_interaction_severity")
except Exception as e:
    print(f"Failed to load model: {e}")
    predictor = None
    MODEL_LOADED = False
else:
    MODEL_LOADED = True