PEFT
Safetensors
File size: 6,566 Bytes
b01da00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
Custom inference handler for the arxiv-classifier PEFT adapter.

This handler loads a LLaMA-3-8B base model with a LoRA adapter fine-tuned
for arXiv paper classification into 150 subfields.
"""

from typing import Dict, List, Any
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# ArXiv subfield mapping (150 classes)
# Maps the integer class index emitted by the classification head to the
# corresponding arXiv subfield identifier (e.g. 38 -> "cs.LG").
# Keys are contiguous 0..149; if entries are ever added or removed,
# N_SUBFIELDS below must be kept equal to len(INVERSE_SUBFIELD_MAP).
INVERSE_SUBFIELD_MAP = {
    0: "astro-ph", 1: "astro-ph.CO", 2: "astro-ph.EP", 3: "astro-ph.GA",
    4: "astro-ph.HE", 5: "astro-ph.IM", 6: "astro-ph.SR", 7: "cond-mat.dis-nn",
    8: "cond-mat.mes-hall", 9: "cond-mat.mtrl-sci", 10: "cond-mat.other",
    11: "cond-mat.quant-gas", 12: "cond-mat.soft", 13: "cond-mat.stat-mech",
    14: "cond-mat.str-el", 15: "cond-mat.supr-con", 16: "cs.AI", 17: "cs.AR",
    18: "cs.CC", 19: "cs.CE", 20: "cs.CG", 21: "cs.CL", 22: "cs.CR", 23: "cs.CV",
    24: "cs.CY", 25: "cs.DB", 26: "cs.DC", 27: "cs.DL", 28: "cs.DM", 29: "cs.DS",
    30: "cs.ET", 31: "cs.FL", 32: "cs.GL", 33: "cs.GR", 34: "cs.GT", 35: "cs.HC",
    36: "cs.IR", 37: "cs.IT", 38: "cs.LG", 39: "cs.LO", 40: "cs.MA", 41: "cs.MM",
    42: "cs.MS", 43: "cs.NE", 44: "cs.NI", 45: "cs.OH", 46: "cs.OS", 47: "cs.PF",
    48: "cs.PL", 49: "cs.RO", 50: "cs.SC", 51: "cs.SD", 52: "cs.SE", 53: "cs.SI",
    54: "econ.EM", 55: "econ.GN", 56: "econ.TH", 57: "eess.AS", 58: "eess.IV",
    59: "eess.SP", 60: "eess.SY", 61: "gr-qc", 62: "hep-ex", 63: "hep-lat",
    64: "hep-ph", 65: "hep-th", 66: "math-ph", 67: "math.AC", 68: "math.AG",
    69: "math.AP", 70: "math.AT", 71: "math.CA", 72: "math.CO", 73: "math.CT",
    74: "math.CV", 75: "math.DG", 76: "math.DS", 77: "math.FA", 78: "math.GM",
    79: "math.GN", 80: "math.GR", 81: "math.GT", 82: "math.HO", 83: "math.KT",
    84: "math.LO", 85: "math.MG", 86: "math.NA", 87: "math.NT", 88: "math.OA",
    89: "math.OC", 90: "math.PR", 91: "math.QA", 92: "math.RA", 93: "math.RT",
    94: "math.SG", 95: "math.SP", 96: "math.ST", 97: "nlin.AO", 98: "nlin.CD",
    99: "nlin.CG", 100: "nlin.PS", 101: "nlin.SI", 102: "nucl-ex", 103: "nucl-th",
    104: "physics.acc-ph", 105: "physics.ao-ph", 106: "physics.app-ph",
    107: "physics.atm-clus", 108: "physics.atom-ph", 109: "physics.bio-ph",
    110: "physics.chem-ph", 111: "physics.class-ph", 112: "physics.comp-ph",
    113: "physics.data-an", 114: "physics.ed-ph", 115: "physics.flu-dyn",
    116: "physics.gen-ph", 117: "physics.geo-ph", 118: "physics.hist-ph",
    119: "physics.ins-det", 120: "physics.med-ph", 121: "physics.optics",
    122: "physics.plasm-ph", 123: "physics.pop-ph", 124: "physics.soc-ph",
    125: "physics.space-ph", 126: "q-bio.BM", 127: "q-bio.CB", 128: "q-bio.GN",
    129: "q-bio.MN", 130: "q-bio.NC", 131: "q-bio.OT", 132: "q-bio.PE",
    133: "q-bio.QM", 134: "q-bio.SC", 135: "q-bio.TO", 136: "q-fin.CP",
    137: "q-fin.GN", 138: "q-fin.MF", 139: "q-fin.PM", 140: "q-fin.PR",
    141: "q-fin.RM", 142: "q-fin.ST", 143: "q-fin.TR", 144: "quant-ph",
    145: "stat.AP", 146: "stat.CO", 147: "stat.ME", 148: "stat.ML", 149: "stat.OT"
}

# Number of classification labels; passed as num_labels when building the
# classification head and used to bound top-k / all-scores output below.
N_SUBFIELDS = 150


class EndpointHandler:
    """Custom inference handler: LLaMA-3-8B + LoRA adapter, 150-way arXiv classifier."""

    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer.

        Loads the base model in 8-bit, attaches the PEFT adapter stored at
        ``path``, and puts the model in eval mode.

        Args:
            path: Path to the model repository (adapter files)
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Base model configuration
        base_model_name = "meta-llama/Meta-Llama-3-8B"
        self.max_length = 2048

        # Quantization config for 8-bit inference
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        # Load tokenizer; LLaMA-3 ships without a pad token, so reuse EOS
        # for padding (the attention mask keeps padded positions inert).
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # Load base model for sequence classification
        base_model = AutoModelForSequenceClassification.from_pretrained(
            base_model_name,
            quantization_config=quantization_config,
            num_labels=N_SUBFIELDS,
            device_map="auto",
        )
        # The classification head needs to know the pad id to locate the
        # last non-pad token of each sequence.
        base_model.config.pad_token_id = self.tokenizer.pad_token_id

        # Load PEFT adapter
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Run inference on the input data.

        Args:
            data: Dictionary containing:
                - inputs (str or List[str]): The text(s) to classify
                - top_k (int, optional): Number of top predictions to return (default: 5)
                - return_all_scores (bool, optional): Return scores for all classes (default: False)

        Returns:
            For a single input, a list of {"label", "score"} dicts; for a
            batch, a list of such lists (one per input).
        """
        # Accept either {"inputs": ...} or the raw payload itself.
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            inputs = [inputs]

        # Clamp top_k into a valid range: torch.topk rejects k < 1 and
        # k > number of classes.
        top_k = max(1, min(data.get("top_k", 5), N_SUBFIELDS))
        return_all_scores = data.get("return_all_scores", False)

        # Tokenize with dynamic padding: pad only to the longest sequence in
        # the batch (masked positions don't affect classification output)
        # instead of always padding to max_length, which wastes compute.
        encoded = self.tokenizer(
            inputs,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)

        # inference_mode is a stricter no_grad: disables autograd tracking
        # and view/version-counter bookkeeping for inference-only tensors.
        with torch.inference_mode():
            logits = self.model(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
            ).logits

        # Cast to float32 before softmax: with 8-bit/fp16 compute the logits
        # may be half precision, where softmax is numerically less stable.
        probs = torch.softmax(logits.float(), dim=-1)

        results = []
        for i in range(len(inputs)):
            if return_all_scores:
                # Full 150-class distribution, in class-index order.
                scores = probs[i].cpu().tolist()
                result = [
                    {"label": INVERSE_SUBFIELD_MAP[j], "score": scores[j]}
                    for j in range(N_SUBFIELDS)
                ]
            else:
                # Top-k classes, sorted by descending probability.
                top_probs, top_indices = torch.topk(probs[i], top_k)
                result = [
                    {"label": INVERSE_SUBFIELD_MAP[idx.item()], "score": prob.item()}
                    for prob, idx in zip(top_probs, top_indices)
                ]
            results.append(result)

        # Unwrap the batch dimension for a single input (API convenience).
        if len(results) == 1:
            return results[0]
        return results