TheStrangerOne's picture
Update inference.py
541d778 verified
raw
history blame contribute delete
928 Bytes
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
# Define custom pipeline for multilabel classification
class MultilabelPipeline:
    """Minimal multilabel text-classification pipeline.

    Wraps a Hugging Face sequence-classification model and applies a
    sigmoid to the output logits so each label gets an independent
    probability (unlike softmax, which would force them to sum to 1).
    """

    def __init__(self, model_name):
        # BUG FIX: was `def init(...)` — without the dunder name this
        # method never ran, so `model`/`tokenizer` were never set.
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def __call__(self, input_text):
        """Return per-label probabilities for *input_text* as a nested list.

        BUG FIX: was `def call(...)` — instances were not callable, so
        `pipe("...")` raised TypeError.
        """
        inputs = self.tokenizer(input_text, return_tensors="pt")
        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs.logits
        # Apply sigmoid (not softmax) to get independent per-label
        # probabilities for multilabel classification.
        probabilities = torch.sigmoid(logits)
        return probabilities.tolist()
# Create instance of the custom pipeline.
# NOTE: this downloads/loads the checkpoint from the Hugging Face Hub
# via from_pretrained, so the first run requires network access.
pipe = MultilabelPipeline("TheStrangerOne/gemma-2-9b-it-bnb-4bit-lora-multilabel")
# Example input: calling the pipeline returns per-label sigmoid
# probabilities as a nested list (one inner list per input sequence).
probs = pipe("Your input prompt here")
print("Probabilities:", probs)