File size: 5,447 Bytes
338a546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f535c5
0f8c85b
 
 
 
 
 
 
 
 
7f535c5
0f8c85b
 
 
 
 
7f535c5
 
0f8c85b
 
338a546
0f8c85b
338a546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f8c85b
 
 
 
 
 
 
 
 
 
 
 
 
 
338a546
0f8c85b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f535c5
338a546
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# First try to import with fallbacks
try:
    import torch
    from transformers import AutoTokenizer, AutoModel
    import joblib
    from huggingface_hub import hf_hub_download
    import json
except ImportError as e:
    print(f"Import error: {e}")
    # Try to install missing packages (this might not work in Spaces)
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "transformers", "joblib", "huggingface-hub"])
    import torch
    from transformers import AutoTokenizer, AutoModel
    import joblib
    from huggingface_hub import hf_hub_download
    import json

class DrugInteractionClassifier(torch.nn.Module):
    """Transformer encoder followed by a small MLP head for severity classification.

    The encoder is loaded with ``AutoModel.from_pretrained(bert_model_name)``.
    The vector at sequence position 0 of the encoder's first output is fed
    through Linear(hidden_size -> 256) -> ReLU -> Dropout(0.3) ->
    Linear(256 -> n_classes).

    Args:
        n_classes: Number of output logits.
        bert_model_name: Name or path passed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, n_classes, bert_model_name="emilyalsentzer/Bio_ClinicalBERT"):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        encoder_width = self.bert.config.hidden_size
        # The attribute names ("bert", "classifier") and the layer order inside
        # the Sequential determine the state_dict keys, so they must match the
        # checkpoint this module is later loaded from.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(encoder_width, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, n_classes),
        )

    def forward(self, input_ids, attention_mask):
        """Return raw (un-normalised) class logits, one row per batch item."""
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        token_states = encoder_out[0]
        # Take position 0 of every sequence (presumably the [CLS] token —
        # depends on the tokenizer used upstream).
        first_token = token_states[:, 0]
        return self.classifier(first_token)

class DDIPredictor:
    """Fetch the drug-interaction severity model from the Hugging Face Hub and run inference.

    Construction downloads ``config.json``, ``pytorch_model.bin`` and
    ``label_encoder.joblib`` from ``repo_id``. It loads the tokenizer from the
    same repo and builds a ``DrugInteractionClassifier``. The weights are
    loaded onto CUDA when available, otherwise onto CPU.

    Attributes:
        repo_id: Hub repository the artifacts come from.
        config: Parsed ``config.json``; must provide ``num_labels`` and
            ``bert_model_name`` and may provide ``max_length`` (default 128).
        tokenizer: Tokenizer loaded from ``repo_id``.
        label_encoder: Object unpickled from ``label_encoder.joblib``; used via
            ``classes_`` and ``inverse_transform`` (presumably a scikit-learn
            ``LabelEncoder`` — confirm against the training code).
        model: Classifier in eval mode on ``device``.
        device: ``torch.device`` the model was moved to.
    """

    def __init__(self, repo_id="Fredaaaaaa/drug_interaction_severity"):
        """Download all artifacts for ``repo_id`` and build the model.

        Raises:
            Exception: Any download/loading error is printed and then
                re-raised unchanged.
        """
        self.repo_id = repo_id
        print(f"πŸš€ Loading model from: {repo_id}")

        try:
            # Download model files from Hugging Face
            print("πŸ“₯ Downloading config.json...")
            self.config_path = hf_hub_download(repo_id=repo_id, filename="config.json")

            print("πŸ“₯ Downloading pytorch_model.bin...")
            self.model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")

            print("πŸ“₯ Downloading label_encoder.joblib...")
            self.label_encoder_path = hf_hub_download(repo_id=repo_id, filename="label_encoder.joblib")

            # JSON is UTF-8; don't depend on the platform's default encoding.
            with open(self.config_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)

            # Load tokenizer from repo
            print("πŸ”€ Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(repo_id)

            # SECURITY NOTE(review): joblib.load unpickles the file, which can
            # execute arbitrary code. Only load from repositories you trust.
            print("🏷️ Loading label encoder...")
            self.label_encoder = joblib.load(self.label_encoder_path)

            print("🧠 Initializing model...")
            self.model = DrugInteractionClassifier(
                n_classes=self.config["num_labels"],
                bert_model_name=self.config["bert_model_name"]
            )

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"βš™οΈ Loading weights on {device}...")
            # SECURITY NOTE(review): torch.load is pickle-based. On torch
            # versions where weights_only does not default to True, a
            # malicious checkpoint can run code. Consider passing
            # weights_only=True once the minimum torch version allows it.
            self.model.load_state_dict(
                torch.load(self.model_path, map_location=device)
            )
            self.model.to(device)
            self.model.eval()

            self.device = device
            print(f"βœ… Model loaded successfully from {repo_id} on {device}")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            # Bare raise re-raises the active exception with its original traceback.
            raise

    def _empty_result(self, prediction):
        """Build a result dict with zero confidence and all-zero class probabilities."""
        return {
            "prediction": prediction,
            "confidence": 0.0,
            "probabilities": {label: 0.0 for label in self.label_encoder.classes_}
        }

    def predict(self, text, confidence_threshold=0.0):
        """Predict drug interaction severity for a single text.

        Args:
            text: Text to classify. ``None``, non-string values, and strings
                that are empty or whitespace-only yield an ``"Invalid Input"``
                result.
            confidence_threshold: Currently unused; kept for API compatibility.
                NOTE(review): no thresholding is implemented in this class.

        Returns:
            A dict with these keys:

            - ``"prediction"``: the predicted label, ``"Invalid Input"``, or
              ``"Error: <message>"``.
            - ``"confidence"``: the highest softmax probability, or 0.0 for
              invalid input and errors.
            - ``"probabilities"``: a mapping from each label to its probability.

            Exceptions raised during tokenisation or inference are not
            propagated; they are reported through the ``"Error: ..."``
            prediction instead.
        """
        # Guard the type first: a truthy non-string (e.g. an int) would
        # otherwise raise AttributeError on .strip() outside the try below.
        if not isinstance(text, str) or not text.strip():
            return self._empty_result("Invalid Input")

        try:
            inputs = self.tokenizer(
                text,
                max_length=self.config.get("max_length", 128),
                padding=True,
                truncation=True,
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                logits = self.model(inputs["input_ids"], inputs["attention_mask"])
                probabilities = torch.softmax(logits, dim=1)
                confidence, predicted_idx = torch.max(probabilities, dim=1)

            predicted_label = self.label_encoder.inverse_transform([predicted_idx.item()])[0]

            # Decode every class index in a single call rather than one call per class.
            class_probs = probabilities[0]
            labels = self.label_encoder.inverse_transform(list(range(class_probs.shape[0])))
            all_probs = {label: prob.item() for label, prob in zip(labels, class_probs)}

            return {
                "prediction": predicted_label,
                "confidence": confidence.item(),
                "probabilities": all_probs
            }

        except Exception as e:
            return self._empty_result(f"Error: {str(e)}")

# Module-wide predictor built once at import time. Any failure while building
# it is printed and leaves predictor = None / MODEL_LOADED = False, so callers
# can check MODEL_LOADED before using predictor.
predictor = None
MODEL_LOADED = False
try:
    predictor = DDIPredictor("Fredaaaaaa/drug_interaction_severity")
except Exception as e:
    print(f"Failed to load model: {e}")
else:
    MODEL_LOADED = True