# bk-classification-bart-two-stage / inference_example.py
# Uploaded by mrehank209 via huggingface_hub (commit dd77677, verified)
"""
Example inference code for BK Classification BART model
"""
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import json
class BartWithClassifier(nn.Module):
    """BART backbone with a linear head for multi-label BK classification.

    Scores each of ``num_labels`` BK (Basisklassifikation) classes from the
    first-token representation of the BART output; callers apply a sigmoid
    per label (multi-label, not softmax).

    Args:
        num_labels: Size of the classification layer (number of BK classes).
        model_name: Kept for interface compatibility; NOTE(review): the
            original ignored this and always loaded the local backbone, so
            this still does — pass ``backbone_path`` to load from elsewhere.
        dropout: Dropout probability applied before the classifier head.
        backbone_path: Directory (or model id) the BART weights are loaded
            from; defaults to the local ``./bart_backbone`` directory, same
            as the original behavior.
    """

    def __init__(self, num_labels=1884, model_name="facebook/bart-large",
                 dropout=0.1, backbone_path="./bart_backbone"):
        super().__init__()
        self.num_labels = num_labels
        # Local import so the class definition itself does not require
        # transformers at module-import time.
        from transformers import BartModel
        self.bart = BartModel.from_pretrained(backbone_path)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bart.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return raw (pre-sigmoid) logits of shape (batch, num_labels)."""
        outputs = self.bart(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        # BART has no [CLS] token; use the first token's hidden state
        # (the <s>/BOS position) as the pooled sequence representation.
        pooled = last_hidden_state[:, 0, :]
        pooled = self.dropout(pooled)
        return self.classifier(pooled)
def load_model_from_hf(model_name="mrehank209/bk-classification-bart-two-stage"):
    """Load the complete model, tokenizer, and label map from the HF Hub.

    Args:
        model_name: Hub repo id holding ``classifier_head.pt``,
            ``config.json`` and ``label_map.json``.

    Returns:
        ``(model, tokenizer, label_map)`` where ``model`` is a
        ``BartWithClassifier`` with the downloaded head weights loaded and
        ``label_map`` is the JSON mapping from the repo (presumably
        label -> index; verify against the repo files).
    """
    # Download the head weights, config, and label map from the Hub
    # (cached locally by huggingface_hub).
    classifier_path = hf_hub_download(repo_id=model_name, filename="classifier_head.pt")
    config_path = hf_hub_download(repo_id=model_name, filename="config.json")
    label_map_path = hf_hub_download(repo_id=model_name, filename="label_map.json")

    # JSON is UTF-8 by spec; be explicit so this doesn't depend on the
    # platform's default locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    with open(label_map_path, 'r', encoding='utf-8') as f:
        label_map = json.load(f)

    # Build the model (backbone comes from the local ./bart_backbone dir).
    model = BartWithClassifier(num_labels=config["num_labels"])

    # weights_only=True: the file is a plain state dict of tensors, and this
    # prevents arbitrary-code execution from a malicious pickle.
    classifier_state = torch.load(classifier_path, map_location='cpu',
                                  weights_only=True)
    model.classifier.load_state_dict(classifier_state)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer, label_map
# Example usage
if __name__ == "__main__":
model, tokenizer, label_map = load_model_from_hf()
# Example text
text = """
Title: Künstliche Intelligenz in der Bibliothek
Summary: Ein Überblick über moderne KI-Methoden für Bibliothekswesen
Keywords: künstliche intelligenz, bibliothek, automatisierung
LOC_Keywords: artificial intelligence, library science
RVK: AN 73000
"""
# Tokenize
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=768)
# Predict
model.eval()
with torch.no_grad():
logits = model(**inputs)
probs = torch.sigmoid(logits)
# Get top predictions
top_k = 5
top_probs, top_indices = torch.topk(probs[0], top_k)
print("Top predictions:")
idx_to_label = {v: k for k, v in label_map.items()}
for prob, idx in zip(top_probs, top_indices):
label = idx_to_label[idx.item()]
print(f" {label}: {prob:.3f}")