""" Example inference code for BK Classification BART model """ import torch import torch.nn as nn from transformers import AutoTokenizer from huggingface_hub import hf_hub_download import json class BartWithClassifier(nn.Module): """BART classifier for multi-label BK classification""" def __init__(self, num_labels=1884, model_name="facebook/bart-large", dropout=0.1): super(BartWithClassifier, self).__init__() self.num_labels = num_labels # Load from local bart_backbone directory from transformers import BartModel self.bart = BartModel.from_pretrained("./bart_backbone") self.dropout = nn.Dropout(dropout) self.classifier = nn.Linear(self.bart.config.hidden_size, num_labels) def forward(self, input_ids, attention_mask=None): outputs = self.bart(input_ids=input_ids, attention_mask=attention_mask) last_hidden_state = outputs.last_hidden_state cls_output = last_hidden_state[:, 0, :] # Take [CLS] token representation cls_output = self.dropout(cls_output) logits = self.classifier(cls_output) return logits def load_model_from_hf(model_name="mrehank209/bk-classification-bart-two-stage"): """Load the complete model from Hugging Face Hub""" # Download files classifier_path = hf_hub_download(repo_id=model_name, filename="classifier_head.pt") config_path = hf_hub_download(repo_id=model_name, filename="config.json") label_map_path = hf_hub_download(repo_id=model_name, filename="label_map.json") # Load config with open(config_path, 'r') as f: config = json.load(f) # Load label mapping with open(label_map_path, 'r') as f: label_map = json.load(f) # Initialize model model = BartWithClassifier(num_labels=config["num_labels"]) # Load classifier head classifier_state = torch.load(classifier_path, map_location='cpu') model.classifier.load_state_dict(classifier_state) # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name) return model, tokenizer, label_map # Example usage if __name__ == "__main__": model, tokenizer, label_map = load_model_from_hf() # Example text text = """ Title: Künstliche Intelligenz in der Bibliothek Summary: Ein Überblick über moderne KI-Methoden für Bibliothekswesen Keywords: künstliche intelligenz, bibliothek, automatisierung LOC_Keywords: artificial intelligence, library science RVK: AN 73000 """ # Tokenize inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=768) # Predict model.eval() with torch.no_grad(): logits = model(**inputs) probs = torch.sigmoid(logits) # Get top predictions top_k = 5 top_probs, top_indices = torch.topk(probs[0], top_k) print("Top predictions:") idx_to_label = {v: k for k, v in label_map.items()} for prob, idx in zip(top_probs, top_indices): label = idx_to_label[idx.item()] print(f" {label}: {prob:.3f}")