Fredaaaaaa committed on
Commit
0f8c85b
·
verified ·
1 Parent(s): 1cfa4ef

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +110 -55
inference.py CHANGED
@@ -1,62 +1,117 @@
1
- # inference.py
2
  import torch
3
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
- from torch.utils.data import DataLoader, Dataset
5
- import numpy as np
6
- import pandas as pd
7
 
class DrugInteractionDataset(Dataset):
    """Dataset that tokenizes drug-interaction descriptions on access.

    Args:
        description: A single description string, or a list/tuple of
            description strings. A bare string is treated as one sample.
        tokenizer: Callable invoked as
            ``tokenizer(text, padding=..., truncation=..., max_length=...,
            return_tensors='pt')`` that returns a mapping containing
            ``'input_ids'`` and ``'attention_mask'`` tensors with a leading
            batch axis of size 1.
        max_length: Length every sample is padded/truncated to.
    """

    def __init__(self, description, tokenizer, max_length=512):
        # Wrap a single value so it is stored as one sample; pass real
        # collections through instead of nesting them in another list.
        if isinstance(description, (list, tuple)):
            self.description = list(description)
        else:
            self.description = [description]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of stored descriptions."""
        return len(self.description)

    def __getitem__(self, idx):
        """Tokenize the description at ``idx`` and return its tensors.

        Returns:
            Dict with ``'input_ids'`` and ``'attention_mask'``, each with the
            tokenizer's leading batch axis removed.
        """
        encoding = self.tokenizer(
            self.description[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # collapse the sequence axis whenever max_length == 1.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
        }
 
30
 
class DDIPredictor:
    """Predict drug-interaction severity with a sequence-classification model.

    The tokenizer and model are loaded from ``model_repo`` via
    ``from_pretrained`` and the model is moved to CUDA when available,
    otherwise to the CPU.
    """

    # NOTE(review): index-to-label mapping must match the labels used at
    # training time; update it if the class set changes.
    _LABEL_MAP = {0: "No Interaction", 1: "Mild", 2: "Moderate", 3: "Severe"}

    # Every input is padded/truncated to this many tokens.
    _MAX_LENGTH = 512

    def __init__(self, model_repo="Fredaaaaaa/drug_interaction_severity"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_repo)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_repo)
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def predict(self, interaction_description):
        """Classify one interaction description.

        Args:
            interaction_description: Text describing the drug interaction.

        Returns:
            Dict with:
              * ``"prediction"``: label string (``"Unknown"`` when the
                predicted index is not in the label map),
              * ``"confidence"``: probability of the predicted class as a
                percentage (plain ``float``),
              * ``"probabilities"``: ``{class_index: probability}`` with
                plain ``float`` values, so the result is JSON-serializable.
        """
        # Tokenize directly: a one-sample Dataset/DataLoader produces the
        # same (1, max_length) batch with extra indirection.
        encoding = self.tokenizer(
            interaction_description,
            padding='max_length',
            truncation=True,
            max_length=self._MAX_LENGTH,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # tolist()/item() yield built-in Python numbers instead of NumPy
        # float32 scalars, which json.dumps cannot serialize.
        probabilities = torch.softmax(logits, dim=-1)[0].cpu().tolist()
        prediction = torch.argmax(logits, dim=-1)[0].cpu().item()

        return {
            "prediction": self._LABEL_MAP.get(prediction, "Unknown"),
            "confidence": probabilities[prediction] * 100,
            "probabilities": dict(enumerate(probabilities)),
        }
 
 
 
 
 
 
 
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModel
3
+ import joblib
4
+ from huggingface_hub import hf_hub_download
5
+ import json
6
 
class DrugInteractionClassifier(torch.nn.Module):
    """Pretrained transformer encoder followed by a two-layer MLP head.

    The attribute names ``bert`` and ``classifier`` and the layer order
    inside ``classifier`` define the state-dict keys, so they must stay as
    they are for previously saved weights to load.

    Args:
        n_classes: Number of output classes (size of the final layer).
        bert_model_name: Model identifier passed to
            ``AutoModel.from_pretrained``.
        hidden_dim: Width of the intermediate layer of the head.
        dropout: Dropout probability applied inside the head.
    """

    def __init__(
        self,
        n_classes,
        bert_model_name="emilyalsentzer/Bio_ClinicalBERT",
        hidden_dim=256,
        dropout=0.3,
    ):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.bert.config.hidden_size, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, n_classes),
        )

    def forward(self, input_ids, attention_mask):
        """Return the classification head's output for a batch.

        Args:
            input_ids: Token-id tensor forwarded to the encoder.
            attention_mask: Attention-mask tensor forwarded to the encoder.

        Returns:
            Output of ``self.classifier`` applied to the encoder's
            first-position hidden state of every sequence.
        """
        encoder_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Element 0 of the encoder output holds the per-token hidden states;
        # position 0 along the sequence axis is used as the sequence summary
        # (presumably the [CLS] token for BERT-style tokenizers).
        first_token_state = encoder_output[0][:, 0, :]
        return self.classifier(first_token_state)
22
 
class DDIPredictor:
    """Load the drug-interaction severity classifier and run predictions.

    On construction the config, weights and label encoder are downloaded
    from the Hugging Face Hub repository ``repo_id``, the model is built,
    its weights are loaded, and it is placed in eval mode on CUDA when
    available (otherwise CPU).
    """

    def __init__(self, repo_id="Fredaaaaaa/drug_interaction_severity"):
        self.repo_id = repo_id

        # Download model files from Hugging Face
        self.config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        self.model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        self.label_encoder_path = hf_hub_download(
            repo_id=repo_id, filename="label_encoder.joblib"
        )

        # Load config; the encoding is explicit so parsing does not depend on
        # the platform's default locale.
        with open(self.config_path, "r", encoding="utf-8") as f:
            self.config = json.load(f)

        # Load tokenizer from repo
        self.tokenizer = AutoTokenizer.from_pretrained(repo_id)

        # NOTE(review): joblib.load unpickles arbitrary objects; only point
        # repo_id at repositories you trust.
        self.label_encoder = joblib.load(self.label_encoder_path)

        # Initialize model
        self.model = DrugInteractionClassifier(
            n_classes=self.config["num_labels"],
            bert_model_name=self.config["bert_model_name"],
        )

        # Load weights
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.load_state_dict(
            torch.load(self.model_path, map_location=device, weights_only=True)
        )
        self.model.to(device)
        self.model.eval()

        self.device = device
        print(f"✅ Model loaded successfully from {repo_id} on {device}")

    def _fallback_result(self, prediction):
        """Build a zero-confidence result carrying ``prediction`` as its label."""
        return {
            "prediction": prediction,
            "confidence": 0.0,
            "probabilities": {label: 0.0 for label in self.label_encoder.classes_},
        }

    def predict(self, text, confidence_threshold=0.0):
        """Predict drug interaction severity for one piece of text.

        Args:
            text: Description of the interaction. Anything that is not a
                non-blank string yields an ``"Invalid Input"`` result.
            confidence_threshold: Accepted for interface compatibility; it is
                currently not applied to the result.

        Returns:
            Dict with ``"prediction"`` (decoded label), ``"confidence"``
            (probability of the predicted class, 0-1) and
            ``"probabilities"`` (``{label: probability}`` for every class).
            Failures during tokenization or inference are not raised; they
            are reported as ``"prediction": "Error: <message>"`` with zeroed
            confidence and probabilities.
        """
        # Checking the type first keeps non-string inputs (numbers, lists,
        # ...) from raising AttributeError on .strip().
        if not isinstance(text, str) or not text.strip():
            return self._fallback_result("Invalid Input")

        try:
            # Tokenize
            inputs = self.tokenizer(
                text,
                max_length=self.config["max_length"],
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Predict
            with torch.no_grad():
                outputs = self.model(inputs["input_ids"], inputs["attention_mask"])
                probabilities = torch.softmax(outputs, dim=1)
                confidence, predicted_idx = torch.max(probabilities, dim=1)

            class_probs = probabilities[0].tolist()
            # Decode every class index in one call instead of once per class.
            labels = self.label_encoder.inverse_transform(
                list(range(len(class_probs)))
            )

            return {
                "prediction": labels[predicted_idx.item()],
                "confidence": confidence.item(),
                "probabilities": dict(zip(labels, class_probs)),
            }

        except Exception as e:
            # Deliberate best-effort: callers receive the failure in-band.
            return self._fallback_result(f"Error: {str(e)}")
106
 
# Simple smoke test: load the predictor, classify one sentence, and report
# the outcome. Exits with status 1 on any failure so callers and CI can
# detect it.
if __name__ == "__main__":
    try:
        predictor = DDIPredictor("Fredaaaaaa/drug_interaction_severity")
        test_text = "Drug interaction may increase bleeding risk"
        result = predictor.predict(test_text)
        # predict() reports failures in-band as "Error: <message>" instead of
        # raising, so check for that before declaring success.
        error_prefix = "Error: "
        if str(result["prediction"]).startswith(error_prefix):
            raise RuntimeError(str(result["prediction"])[len(error_prefix):])
        print("✅ Test successful!")
        print(f"Prediction: {result['prediction']}")
        print(f"Confidence: {result['confidence']:.3f}")
    except Exception as e:
        print(f"❌ Error: {e}")
        raise SystemExit(1) from e