Fredaaaaaa committed on
Commit
338a546
·
verified ·
1 Parent(s): 0de66ef

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +73 -48
inference.py CHANGED
@@ -1,8 +1,21 @@
1
- import torch
2
- from transformers import AutoTokenizer, AutoModel
3
- import joblib
4
- from huggingface_hub import hf_hub_download
5
- import json
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  class DrugInteractionClassifier(torch.nn.Module):
8
  def __init__(self, n_classes, bert_model_name="emilyalsentzer/Bio_ClinicalBERT"):
@@ -23,38 +36,53 @@ class DrugInteractionClassifier(torch.nn.Module):
23
  class DDIPredictor:
24
  def __init__(self, repo_id="Fredaaaaaa/drug_interaction_severity"):
25
  self.repo_id = repo_id
 
26
 
27
- # Download model files from Hugging Face
28
- self.config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
29
- self.model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
30
- self.label_encoder_path = hf_hub_download(repo_id=repo_id, filename="label_encoder.joblib")
31
-
32
- # Load config
33
- with open(self.config_path, "r") as f:
34
- self.config = json.load(f)
35
-
36
- # Load tokenizer from repo
37
- self.tokenizer = AutoTokenizer.from_pretrained(repo_id)
38
-
39
- # Load label encoder
40
- self.label_encoder = joblib.load(self.label_encoder_path)
41
-
42
- # Initialize model
43
- self.model = DrugInteractionClassifier(
44
- n_classes=self.config["num_labels"],
45
- bert_model_name=self.config["bert_model_name"]
46
- )
47
-
48
- # Load weights
49
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
- self.model.load_state_dict(
51
- torch.load(self.model_path, map_location=device, weights_only=True)
52
- )
53
- self.model.to(device)
54
- self.model.eval()
55
-
56
- self.device = device
57
- print(f"βœ… Model loaded successfully from {repo_id} on {device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  def predict(self, text, confidence_threshold=0.0):
60
  """Predict drug interaction severity"""
@@ -69,7 +97,7 @@ class DDIPredictor:
69
  # Tokenize
70
  inputs = self.tokenizer(
71
  text,
72
- max_length=self.config["max_length"],
73
  padding=True,
74
  truncation=True,
75
  return_tensors="pt"
@@ -104,14 +132,11 @@ class DDIPredictor:
104
  "probabilities": {label: 0.0 for label in self.label_encoder.classes_}
105
  }
106
 
107
- # Simple test
108
- if __name__ == "__main__":
109
- try:
110
- predictor = DDIPredictor("Fredaaaaaa/drug_interaction_severity")
111
- test_text = "Drug interaction may increase bleeding risk"
112
- result = predictor.predict(test_text)
113
- print("βœ… Test successful!")
114
- print(f"Prediction: {result['prediction']}")
115
- print(f"Confidence: {result['confidence']:.3f}")
116
- except Exception as e:
117
- print(f"❌ Error: {e}")
 
1
# First try to import with fallbacks: on restricted hosts (e.g. HF Spaces)
# the packages may be absent, so we attempt a best-effort runtime install.
try:
    import torch
    from transformers import AutoTokenizer, AutoModel
    import joblib
    from huggingface_hub import hf_hub_download
    import json
except ImportError as e:
    print(f"Import error: {e}")
    # Try to install missing packages (this might not work in Spaces).
    import subprocess
    import sys
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install",
         "torch", "transformers", "joblib", "huggingface-hub"]
    )
    # Packages installed while the interpreter is running are not visible
    # until the import-system finder caches are invalidated.
    import importlib
    importlib.invalidate_caches()
    import torch
    from transformers import AutoTokenizer, AutoModel
    import joblib
    from huggingface_hub import hf_hub_download
    import json
19
 
20
  class DrugInteractionClassifier(torch.nn.Module):
21
  def __init__(self, n_classes, bert_model_name="emilyalsentzer/Bio_ClinicalBERT"):
 
36
  class DDIPredictor:
37
  def __init__(self, repo_id="Fredaaaaaa/drug_interaction_severity"):
38
  self.repo_id = repo_id
39
+ print(f"πŸš€ Loading model from: {repo_id}")
40
 
41
+ try:
42
+ # Download model files from Hugging Face
43
+ print("πŸ“₯ Downloading config.json...")
44
+ self.config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
45
+
46
+ print("πŸ“₯ Downloading pytorch_model.bin...")
47
+ self.model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
48
+
49
+ print("πŸ“₯ Downloading label_encoder.joblib...")
50
+ self.label_encoder_path = hf_hub_download(repo_id=repo_id, filename="label_encoder.joblib")
51
+
52
+ # Load config
53
+ with open(self.config_path, "r") as f:
54
+ self.config = json.load(f)
55
+
56
+ # Load tokenizer from repo
57
+ print("πŸ”€ Loading tokenizer...")
58
+ self.tokenizer = AutoTokenizer.from_pretrained(repo_id)
59
+
60
+ # Load label encoder
61
+ print("🏷️ Loading label encoder...")
62
+ self.label_encoder = joblib.load(self.label_encoder_path)
63
+
64
+ # Initialize model
65
+ print("🧠 Initializing model...")
66
+ self.model = DrugInteractionClassifier(
67
+ n_classes=self.config["num_labels"],
68
+ bert_model_name=self.config["bert_model_name"]
69
+ )
70
+
71
+ # Load weights
72
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
+ print(f"βš™οΈ Loading weights on {device}...")
74
+ self.model.load_state_dict(
75
+ torch.load(self.model_path, map_location=device)
76
+ )
77
+ self.model.to(device)
78
+ self.model.eval()
79
+
80
+ self.device = device
81
+ print(f"βœ… Model loaded successfully from {repo_id} on {device}")
82
+
83
+ except Exception as e:
84
+ print(f"❌ Error loading model: {e}")
85
+ raise e
86
 
87
  def predict(self, text, confidence_threshold=0.0):
88
  """Predict drug interaction severity"""
 
97
  # Tokenize
98
  inputs = self.tokenizer(
99
  text,
100
+ max_length=self.config.get("max_length", 128),
101
  padding=True,
102
  truncation=True,
103
  return_tensors="pt"
 
132
  "probabilities": {label: 0.0 for label in self.label_encoder.classes_}
133
  }
134
 
135
# Eagerly build a module-level predictor so importers can use it directly;
# on failure we fall back to a None sentinel plus a loaded flag.
predictor = None
MODEL_LOADED = False
try:
    predictor = DDIPredictor("Fredaaaaaa/drug_interaction_severity")
except Exception as e:
    print(f"Failed to load model: {e}")
else:
    MODEL_LOADED = True