TheStrangerOne commited on
Commit
541d778
·
verified ·
1 Parent(s): 17c3169

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +25 -25
inference.py CHANGED
@@ -1,26 +1,26 @@
1
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
2
- import torch
3
-
4
- # Define custom pipeline for multilabel classification
5
- class MultilabelPipeline:
6
- def init(self, model_name):
7
- self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
8
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
9
-
10
- def call(self, input_text):
11
- inputs = self.tokenizer(input_text, return_tensors="pt")
12
- with torch.no_grad():
13
- outputs = self.model(**inputs)
14
- logits = outputs.logits
15
-
16
- # Apply sigmoid to get probabilities for multilabel classification
17
- probabilities = torch.sigmoid(logits)
18
-
19
- return probabilities.tolist()
20
-
21
- # Create instance of the custom pipeline
22
- pipe = MultilabelPipeline("your-username/gemma-2-9b-it-bnb-4bit-lora")
23
-
24
- # Example input
25
- probs = pipe("Your input prompt here")
26
  print("Probabilities:", probs)
 
1
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
2
+ import torch
3
+
4
+ # Define custom pipeline for multilabel classification
5
+ class MultilabelPipeline:
6
+ def init(self, model_name):
7
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
8
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+
10
+ def call(self, input_text):
11
+ inputs = self.tokenizer(input_text, return_tensors="pt")
12
+ with torch.no_grad():
13
+ outputs = self.model(**inputs)
14
+ logits = outputs.logits
15
+
16
+ # Apply sigmoid to get probabilities for multilabel classification
17
+ probabilities = torch.sigmoid(logits)
18
+
19
+ return probabilities.tolist()
20
+
21
+ # Create instance of the custom pipeline
22
+ pipe = MultilabelPipeline("TheStrangerOne/gemma-2-9b-it-bnb-4bit-lora-multilabel")
23
+
24
+ # Example input
25
+ probs = pipe("Your input prompt here")
26
  print("Probabilities:", probs)