|
|
|
|
|
""" |
|
|
Example inference code for BK Classification BART model |
|
|
""" |
|
|
|
|
|
import json

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, BartModel
|
|
|
|
class BartWithClassifier(nn.Module):
    """BART backbone with a linear head for multi-label BK classification.

    Produces one independent logit per BK class; callers apply a sigmoid
    (not softmax) to obtain per-label probabilities.
    """

    def __init__(self, num_labels=1884, model_name="./bart_backbone", dropout=0.1):
        """
        Args:
            num_labels: Number of BK classes (output dimension of the head).
            model_name: Hugging Face model id or local path for the BART
                backbone. Defaults to the local "./bart_backbone" export,
                which preserves the original default behavior.
            dropout: Dropout probability applied to the pooled representation.
        """
        super().__init__()
        self.num_labels = num_labels
        # BUGFIX: the original ignored `model_name` and always loaded the
        # hard-coded "./bart_backbone" path. The parameter is now honored,
        # with the old path kept as the default for backward compatibility.
        self.bart = BartModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bart.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return raw classification logits of shape (batch, num_labels)."""
        outputs = self.bart(input_ids=input_ids, attention_mask=attention_mask)
        # Pool with the hidden state of the first token — the usual
        # "CLS"-style pooling for BART sequence classification.
        cls_output = outputs.last_hidden_state[:, 0, :]
        cls_output = self.dropout(cls_output)
        return self.classifier(cls_output)
|
|
|
|
|
def load_model_from_hf(model_name="mrehank209/bk-classification-bart-two-stage"):
    """Load the BK-classification model, tokenizer and label map from the Hub.

    Downloads the classifier-head weights, the config, and the label map
    from the given repo. NOTE(review): the BART backbone itself is loaded
    by BartWithClassifier from a local "./bart_backbone" directory, not
    from this repo — confirm that directory exists before calling.

    Args:
        model_name: Hugging Face Hub repo id hosting the artifacts.

    Returns:
        Tuple of (model, tokenizer, label_map). `label_map` is the parsed
        label_map.json — presumably label name -> index (callers invert it);
        verify against the repo.
    """
    classifier_path = hf_hub_download(repo_id=model_name, filename="classifier_head.pt")
    config_path = hf_hub_download(repo_id=model_name, filename="config.json")
    label_map_path = hf_hub_download(repo_id=model_name, filename="label_map.json")

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    with open(label_map_path, "r", encoding="utf-8") as f:
        label_map = json.load(f)

    model = BartWithClassifier(num_labels=config["num_labels"])

    # weights_only=True: the file is a plain state dict, so avoid pickle's
    # arbitrary-code-execution surface when loading downloaded artifacts.
    classifier_state = torch.load(classifier_path, map_location="cpu", weights_only=True)
    model.classifier.load_state_dict(classifier_state)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    return model, tokenizer, label_map
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    model, tokenizer, label_map = load_model_from_hf()

    # Sample record in the structured prompt format the model expects.
    text = """
Title: Künstliche Intelligenz in der Bibliothek
Summary: Ein Überblick über moderne KI-Methoden für Bibliothekswesen
Keywords: künstliche intelligenz, bibliothek, automatisierung
LOC_Keywords: artificial intelligence, library science
RVK: AN 73000
"""

    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=768)

    # Inference only: disable dropout and gradient tracking.
    model.eval()
    with torch.no_grad():
        scores = torch.sigmoid(model(**encoded))

    # Single-item batch -> rank the per-label probabilities of row 0.
    best_probs, best_indices = torch.topk(scores[0], 5)

    print("Top predictions:")
    index_to_label = {index: name for name, index in label_map.items()}
    for p, i in zip(best_probs, best_indices):
        print(f" {index_to_label[i.item()]}: {p:.3f}")