# CocoRoF's picture
# Update model.py
# b6bc975 verified
from transformers import AutoTokenizer, TextClassificationPipeline, AutoModelForSequenceClassification;
import torch
class MajorClassifier(TextClassificationPipeline):
    """Top-k text classifier over a Hugging Face sequence-classification checkpoint.

    Loads the tokenizer and ``AutoModelForSequenceClassification`` stored at
    ``repo_id``. Calling the instance returns the most probable labels
    together with their softmax scores.

    NOTE(review): although this subclasses ``TextClassificationPipeline``, it
    never calls ``super().__init__``, so the inherited pipeline state is not
    initialized. Only the ``__call__`` defined here should be relied on.
    """

    def __init__(self, repo_id, power_device=None):
        """Load tokenizer and model from ``repo_id`` and place the model on a device.

        Args:
            repo_id: Hub repo id or local path passed to ``from_pretrained``.
            power_device: Optional explicit device (e.g. ``"cuda:1"`` or a
                ``torch.device``). When ``None``, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.
        """
        auto_device = "cuda" if torch.cuda.is_available() else "cpu"
        # Record the device the model is actually moved to, so self.device
        # never disagrees with the model's placement.
        self.device = power_device if power_device is not None else auto_device
        self.tokenizer = AutoTokenizer.from_pretrained(repo_id)
        self.model = AutoModelForSequenceClassification.from_pretrained(repo_id)
        self.model.to(self.device)
        print(f"Using device: {self.device}")

    def __call__(self, inputs, top_k=5, **kwargs):
        """Classify ``inputs`` and return the ``top_k`` most probable labels.

        Args:
            inputs: A single string, or a list of strings to classify as a batch.
            top_k: Number of labels to return per input, in descending score
                order. It is clamped to the number of labels the model defines.
                ``None`` returns every label.
            **kwargs: Extra tokenizer arguments. They override the defaults
                ``truncation=True``, ``padding=True``, ``max_length=512``.
                ``return_tensors`` is always forced to ``"pt"``.

        Returns:
            For a single string: a list of dicts with keys ``"label"`` (int
            class id), ``"label_decode"`` (name from ``config.id2label``) and
            ``"score"`` (softmax probability as a float). For a list of
            strings: one such list per input, in input order.
        """
        # Caller kwargs may override the defaults instead of colliding with them.
        # PyTorch tensors are required below, so return_tensors is always "pt".
        tokenizer_kwargs = {
            "truncation": True,
            "padding": True,
            "max_length": 512,
            **kwargs,
            "return_tensors": "pt",
        }
        encoded = self.tokenizer(inputs, **tokenizer_kwargs)

        # Send the inputs to wherever the model's weights currently live.
        model_device = next(self.model.parameters()).device
        encoded = {name: tensor.to(model_device) for name, tensor in encoded.items()}

        with torch.no_grad():
            outputs = self.model(**encoded)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # torch.topk raises if k exceeds the size of the label dimension.
        num_labels = probs.shape[-1]
        k = num_labels if top_k is None else min(top_k, num_labels)
        scores, indices = torch.topk(probs, k, dim=-1)

        id2label = self.model.config.id2label
        results = [
            [
                {"label": label_id, "label_decode": id2label[label_id], "score": score}
                for score, label_id in zip(row_scores, row_ids)
            ]
            for row_scores, row_ids in zip(scores.tolist(), indices.tolist())
        ]

        # A single string yields a one-row batch; unwrap it. Batches return
        # every row instead of silently dropping all but the first.
        return results[0] if isinstance(inputs, str) else results