Spaces:
Sleeping
Sleeping
File size: 5,447 Bytes
338a546 7f535c5 0f8c85b 7f535c5 0f8c85b 7f535c5 0f8c85b 338a546 0f8c85b 338a546 0f8c85b 338a546 0f8c85b 7f535c5 338a546 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
# First try to import with fallbacks
try:
    import torch
    from transformers import AutoTokenizer, AutoModel
    import joblib
    from huggingface_hub import hf_hub_download
    import json
except ImportError as e:
    print(f"Import error: {e}")
    # Try to install missing packages (this might not work in Spaces)
    # NOTE(review): best-effort self-install for environments that lack a
    # requirements file. check_call raises CalledProcessError if pip fails,
    # so a broken install surfaces immediately instead of at first use.
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "transformers", "joblib", "huggingface-hub"])
    # Retry the imports now that the packages should be present; a second
    # ImportError here propagates and aborts startup, which is intended.
    import torch
    from transformers import AutoTokenizer, AutoModel
    import joblib
    from huggingface_hub import hf_hub_download
    import json
class DrugInteractionClassifier(torch.nn.Module):
    """BERT encoder with a small MLP head for interaction-severity classification.

    The pretrained encoder (ClinicalBERT by default) produces token embeddings;
    the embedding of the first token is fed through a 2-layer classifier head.
    """

    def __init__(self, n_classes, bert_model_name="emilyalsentzer/Bio_ClinicalBERT"):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        hidden_size = self.bert.config.hidden_size
        # Head: hidden_size -> 256 -> n_classes, with ReLU + dropout in between.
        head_layers = [
            torch.nn.Linear(hidden_size, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, n_classes),
        ]
        self.classifier = torch.nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # First element of the encoder output is the token-level hidden states;
        # take the first token's vector as the sequence representation.
        cls_embedding = encoded[0][:, 0, :]
        return self.classifier(cls_embedding)
class DDIPredictor:
    """Downloads and wraps the drug-drug-interaction severity model.

    Fetches config, weights, tokenizer and label encoder from a Hugging Face
    Hub repo at construction time; `predict` maps free text describing a drug
    pair to a severity label with per-class probabilities.
    """

    def __init__(self, repo_id="Fredaaaaaa/drug_interaction_severity"):
        self.repo_id = repo_id
        print(f"Loading model from: {repo_id}")
        try:
            # Download model artifacts from the Hugging Face Hub (cached locally).
            print("Downloading config.json...")
            self.config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
            print("Downloading pytorch_model.bin...")
            self.model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
            print("Downloading label_encoder.joblib...")
            self.label_encoder_path = hf_hub_download(repo_id=repo_id, filename="label_encoder.joblib")
            # Hyperparameters: expects num_labels, bert_model_name, max_length.
            with open(self.config_path, "r") as f:
                self.config = json.load(f)
            print("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(repo_id)
            print("Loading label encoder...")
            self.label_encoder = joblib.load(self.label_encoder_path)
            print("Initializing model...")
            self.model = DrugInteractionClassifier(
                n_classes=self.config["num_labels"],
                bert_model_name=self.config["bert_model_name"],
            )
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"Loading weights on {device}...")
            # NOTE(review): torch.load unpickles arbitrary objects; on torch >= 2.0
            # prefer weights_only=True if the checkpoint is a plain state dict.
            self.model.load_state_dict(
                torch.load(self.model_path, map_location=device)
            )
            self.model.to(device)
            self.model.eval()  # disable dropout for inference
            self.device = device
            print(f"Model loaded successfully from {repo_id} on {device}")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise  # bare raise preserves the original traceback

    def predict(self, text, confidence_threshold=0.0):
        """Predict drug interaction severity.

        Args:
            text: interaction description to classify.
            confidence_threshold: minimum softmax confidence required to
                report the predicted label; below it the prediction is
                "Uncertain". FIX: this parameter was previously accepted but
                silently ignored. The default of 0.0 can never exceed a
                softmax maximum, so default behavior is unchanged.

        Returns:
            dict with keys "prediction" (label or "Invalid Input"/"Uncertain"/
            "Error: ..."), "confidence" (float), and "probabilities"
            (label -> softmax probability).
        """
        # Guard: empty or whitespace-only input gets a zeroed payload.
        if not text or not text.strip():
            return {
                "prediction": "Invalid Input",
                "confidence": 0.0,
                "probabilities": {label: 0.0 for label in self.label_encoder.classes_},
            }
        try:
            # Tokenize to a single-example batch of tensors.
            inputs = self.tokenizer(
                text,
                max_length=self.config.get("max_length", 128),
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            # Forward pass without gradient tracking.
            with torch.no_grad():
                outputs = self.model(inputs["input_ids"], inputs["attention_mask"])
                probabilities = torch.softmax(outputs, dim=1)
                confidence, predicted_idx = torch.max(probabilities, dim=1)
            predicted_label = self.label_encoder.inverse_transform([predicted_idx.item()])[0]
            # Per-class probabilities keyed by decoded label name.
            all_probs = {
                self.label_encoder.inverse_transform([i])[0]: prob.item()
                for i, prob in enumerate(probabilities[0])
            }
            # Apply the (previously dead) confidence threshold.
            if confidence.item() < confidence_threshold:
                predicted_label = "Uncertain"
            return {
                "prediction": predicted_label,
                "confidence": confidence.item(),
                "probabilities": all_probs,
            }
        except Exception as e:
            # Best-effort: surface the error in the payload rather than crash the UI.
            return {
                "prediction": f"Error: {str(e)}",
                "confidence": 0.0,
                "probabilities": {label: 0.0 for label in self.label_encoder.classes_},
            }
# Global predictor instance
def _load_predictor():
    """Build the shared DDIPredictor; return (instance, True) or (None, False)."""
    try:
        return DDIPredictor("Fredaaaaaa/drug_interaction_severity"), True
    except Exception as e:
        print(f"Failed to load model: {e}")
        return None, False

predictor, MODEL_LOADED = _load_predictor()