Fredaaaaaa committed on
Commit
7f535c5
·
verified ·
1 Parent(s): 5d048cc

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +62 -0
inference.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference.py
2
+ import torch
3
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
+ from torch.utils.data import DataLoader, Dataset
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
class DrugInteractionDataset(Dataset):
    """Dataset that tokenizes drug-interaction descriptions on access.

    Args:
        description: Either one description or a list/tuple of descriptions.
            Anything that is not a list or tuple is stored as a one-item
            dataset, so single-string callers keep working unchanged.
        tokenizer: Callable invoked as ``tokenizer(text, padding='max_length',
            truncation=True, max_length=..., return_tensors='pt')``. Its
            result must be indexable by ``'input_ids'`` and
            ``'attention_mask'``.
        max_length: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, description, tokenizer, max_length=512):
        if isinstance(description, (list, tuple)):
            self.description = list(description)
        else:
            self.description = [description]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of stored descriptions."""
        return len(self.description)

    def __getitem__(self, idx):
        """Tokenize the description at ``idx``.

        Returns:
            A dict with ``'input_ids'`` and ``'attention_mask'``, each with
            the leading (batch) axis of the tokenizer output removed.
        """
        encoding = self.tokenizer(
            self.description[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # squeeze(0) drops only the leading axis. A bare squeeze() removes
        # every size-1 axis, which would also collapse the sequence axis
        # when max_length == 1.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
        }
30
+
31
class DDIPredictor:
    """Predict a severity label for a drug-interaction description.

    The constructor loads a tokenizer and a sequence-classification model
    with ``from_pretrained``, switches the model to eval mode, and moves it
    to CUDA when available (CPU otherwise).

    Args:
        model_repo: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.
    """

    # Class index -> human-readable label.
    # NOTE(review): placeholder mapping carried over unchanged; confirm it
    # matches the label order used when the model was trained.
    LABEL_MAP = {0: "No Interaction", 1: "Mild", 2: "Moderate", 3: "Severe"}

    # Padding/truncation length handed to the tokenizer.
    MAX_LENGTH = 512

    def __init__(self, model_repo="Fredaaaaaa/drug_interaction_severity"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_repo)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_repo)
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def predict(self, interaction_description):
        """Classify one interaction description.

        Args:
            interaction_description: Text handed to the tokenizer.

        Returns:
            A dict with:
              * ``"prediction"``: label from ``LABEL_MAP`` for the arg-max
                class, or ``"Unknown"`` if the index is not in the map.
              * ``"confidence"``: soft-max probability of that class times
                100, as a Python ``float``.
              * ``"probabilities"``: ``{class_index: probability}`` with
                Python ``float`` values.
        """
        # A single string needs no Dataset/DataLoader round trip: the
        # tokenizer output already carries the leading batch axis.
        encoding = self.tokenizer(
            interaction_description,
            padding='max_length',
            truncation=True,
            max_length=self.MAX_LENGTH,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # tolist() yields built-in floats, so the result can be serialized
        # (e.g. with json) without numpy scalar types leaking out.
        probabilities = torch.softmax(logits, dim=-1)[0].cpu().tolist()
        prediction = int(torch.argmax(logits[0], dim=-1).item())

        return {
            "prediction": self.LABEL_MAP.get(prediction, "Unknown"),
            "confidence": probabilities[prediction] * 100,
            "probabilities": dict(enumerate(probabilities)),
        }