File size: 3,165 Bytes
dd77677
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92

"""
Example inference code for BK Classification BART model
"""

import torch
import torch.nn as nn
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import json

class BartWithClassifier(nn.Module):
    """BART backbone with a linear head for multi-label BK classification.

    The backbone is loaded with ``BartModel.from_pretrained`` from
    ``backbone_path``; a dropout layer and a single ``nn.Linear`` layer map
    the pooled hidden state to ``num_labels`` logits.

    Args:
        num_labels: Size of the classifier output (number of labels).
        model_name: Not used by this class; the parameter is kept so that
            existing callers passing it keep working.
        dropout: Dropout probability applied to the pooled representation.
        backbone_path: Directory or model identifier handed to
            ``BartModel.from_pretrained``. Defaults to the local
            ``./bart_backbone`` directory the class always used.
    """

    def __init__(
        self,
        num_labels=1884,
        model_name="facebook/bart-large",
        dropout=0.1,
        backbone_path="./bart_backbone",
    ):
        super().__init__()

        self.num_labels = num_labels
        # Imported lazily so that merely importing this module does not
        # require the backbone class until a model is actually built.
        from transformers import BartModel
        self.bart = BartModel.from_pretrained(backbone_path)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bart.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return raw (pre-sigmoid) logits of shape ``(batch, num_labels)``.

        Args:
            input_ids: Token ids forwarded to the backbone.
            attention_mask: Optional mask forwarded to the backbone.
        """
        outputs = self.bart(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        # Pool by taking the hidden state at sequence position 0.
        pooled = last_hidden_state[:, 0, :]
        pooled = self.dropout(pooled)
        return self.classifier(pooled)

def load_model_from_hf(model_name="mrehank209/bk-classification-bart-two-stage"):
    """Load the classifier, tokenizer and label map for a Hugging Face repo.

    Downloads ``classifier_head.pt``, ``config.json`` and ``label_map.json``
    from the repository ``model_name``, builds a ``BartWithClassifier`` sized
    by ``config["num_labels"]``, loads the classifier-head weights into it and
    loads the tokenizer from the same repository.

    Args:
        model_name: Hugging Face Hub repository id to download from.

    Returns:
        A ``(model, tokenizer, label_map)`` tuple, where ``label_map`` is the
        parsed content of ``label_map.json``.

    Raises:
        KeyError: If ``config.json`` has no ``"num_labels"`` entry.
    """

    # Download files
    classifier_path = hf_hub_download(repo_id=model_name, filename="classifier_head.pt")
    config_path = hf_hub_download(repo_id=model_name, filename="config.json")
    label_map_path = hf_hub_download(repo_id=model_name, filename="label_map.json")

    # JSON is read explicitly as UTF-8 so that non-ASCII content does not
    # depend on the platform's default encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    with open(label_map_path, 'r', encoding='utf-8') as f:
        label_map = json.load(f)

    # Initialize model
    model = BartWithClassifier(num_labels=config["num_labels"])

    # The checkpoint is a downloaded file: restrict unpickling to tensors and
    # plain containers. The result is passed straight to load_state_dict, so
    # nothing beyond a state dict is needed.
    classifier_state = torch.load(
        classifier_path, map_location='cpu', weights_only=True
    )
    model.classifier.load_state_dict(classifier_state)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    return model, tokenizer, label_map

def get_top_predictions(probs, label_map, top_k=5):
    """Return the highest-scoring ``(label, probability)`` pairs.

    Args:
        probs: 1-D tensor of per-label probabilities.
        label_map: Mapping that is inverted to look up a label by index.
            Assumes it maps label -> integer index — TODO confirm against
            the published ``label_map.json``.
        top_k: Maximum number of predictions to return.

    Returns:
        A list of ``(label, probability)`` tuples, highest probability first,
        with at most ``top_k`` entries.
    """
    # Clamp so torch.topk does not raise when there are fewer labels than top_k.
    count = min(top_k, probs.shape[-1])
    top_probs, top_indices = torch.topk(probs, count)
    index_to_label = {index: label for label, index in label_map.items()}
    return [
        (index_to_label[index], prob)
        for prob, index in zip(top_probs.tolist(), top_indices.tolist())
    ]


def main():
    """Run the example: load the model and print the top predictions."""
    model, tokenizer, label_map = load_model_from_hf()

    # Example text
    text = """
    Title: Künstliche Intelligenz in der Bibliothek
    Summary: Ein Überblick über moderne KI-Methoden für Bibliothekswesen
    Keywords: künstliche intelligenz, bibliothek, automatisierung
    LOC_Keywords: artificial intelligence, library science
    RVK: AN 73000
    """

    # Tokenize
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=768)

    # Predict with dropout disabled and without tracking gradients.
    model.eval()
    with torch.no_grad():
        logits = model(**inputs)
        # Sigmoid (not softmax): each label is scored independently.
        probs = torch.sigmoid(logits)

    print("Top predictions:")
    for label, prob in get_top_predictions(probs[0], label_map, top_k=5):
        print(f"  {label}: {prob:.3f}")


# Example usage
if __name__ == "__main__":
    main()