|
|
""" |
|
|
Custom inference handler for the arxiv-classifier PEFT adapter. |
|
|
|
|
|
This handler loads a LLaMA-3-8B base model with a LoRA adapter fine-tuned |
|
|
for arXiv paper classification into 150 subfields. |
|
|
""" |
|
|
|
|
|
from typing import Dict, List, Any |
|
|
import torch |
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig |
|
|
from peft import PeftModel |
|
|
|
|
|
|
|
|
# Mapping from classifier output index -> arXiv subfield (category) tag.
# The index ordering is the label ordering the LoRA adapter was fine-tuned
# with, so entries must not be reordered or renumbered without retraining.
# Covers 150 subfields across astro-ph, cond-mat, cs, econ, eess, math,
# nlin, physics, q-bio, q-fin, quant-ph, and stat.
INVERSE_SUBFIELD_MAP = {
    0: "astro-ph", 1: "astro-ph.CO", 2: "astro-ph.EP", 3: "astro-ph.GA",
    4: "astro-ph.HE", 5: "astro-ph.IM", 6: "astro-ph.SR", 7: "cond-mat.dis-nn",
    8: "cond-mat.mes-hall", 9: "cond-mat.mtrl-sci", 10: "cond-mat.other",
    11: "cond-mat.quant-gas", 12: "cond-mat.soft", 13: "cond-mat.stat-mech",
    14: "cond-mat.str-el", 15: "cond-mat.supr-con", 16: "cs.AI", 17: "cs.AR",
    18: "cs.CC", 19: "cs.CE", 20: "cs.CG", 21: "cs.CL", 22: "cs.CR", 23: "cs.CV",
    24: "cs.CY", 25: "cs.DB", 26: "cs.DC", 27: "cs.DL", 28: "cs.DM", 29: "cs.DS",
    30: "cs.ET", 31: "cs.FL", 32: "cs.GL", 33: "cs.GR", 34: "cs.GT", 35: "cs.HC",
    36: "cs.IR", 37: "cs.IT", 38: "cs.LG", 39: "cs.LO", 40: "cs.MA", 41: "cs.MM",
    42: "cs.MS", 43: "cs.NE", 44: "cs.NI", 45: "cs.OH", 46: "cs.OS", 47: "cs.PF",
    48: "cs.PL", 49: "cs.RO", 50: "cs.SC", 51: "cs.SD", 52: "cs.SE", 53: "cs.SI",
    54: "econ.EM", 55: "econ.GN", 56: "econ.TH", 57: "eess.AS", 58: "eess.IV",
    59: "eess.SP", 60: "eess.SY", 61: "gr-qc", 62: "hep-ex", 63: "hep-lat",
    64: "hep-ph", 65: "hep-th", 66: "math-ph", 67: "math.AC", 68: "math.AG",
    69: "math.AP", 70: "math.AT", 71: "math.CA", 72: "math.CO", 73: "math.CT",
    74: "math.CV", 75: "math.DG", 76: "math.DS", 77: "math.FA", 78: "math.GM",
    79: "math.GN", 80: "math.GR", 81: "math.GT", 82: "math.HO", 83: "math.KT",
    84: "math.LO", 85: "math.MG", 86: "math.NA", 87: "math.NT", 88: "math.OA",
    89: "math.OC", 90: "math.PR", 91: "math.QA", 92: "math.RA", 93: "math.RT",
    94: "math.SG", 95: "math.SP", 96: "math.ST", 97: "nlin.AO", 98: "nlin.CD",
    99: "nlin.CG", 100: "nlin.PS", 101: "nlin.SI", 102: "nucl-ex", 103: "nucl-th",
    104: "physics.acc-ph", 105: "physics.ao-ph", 106: "physics.app-ph",
    107: "physics.atm-clus", 108: "physics.atom-ph", 109: "physics.bio-ph",
    110: "physics.chem-ph", 111: "physics.class-ph", 112: "physics.comp-ph",
    113: "physics.data-an", 114: "physics.ed-ph", 115: "physics.flu-dyn",
    116: "physics.gen-ph", 117: "physics.geo-ph", 118: "physics.hist-ph",
    119: "physics.ins-det", 120: "physics.med-ph", 121: "physics.optics",
    122: "physics.plasm-ph", 123: "physics.pop-ph", 124: "physics.soc-ph",
    125: "physics.space-ph", 126: "q-bio.BM", 127: "q-bio.CB", 128: "q-bio.GN",
    129: "q-bio.MN", 130: "q-bio.NC", 131: "q-bio.OT", 132: "q-bio.PE",
    133: "q-bio.QM", 134: "q-bio.SC", 135: "q-bio.TO", 136: "q-fin.CP",
    137: "q-fin.GN", 138: "q-fin.MF", 139: "q-fin.PM", 140: "q-fin.PR",
    141: "q-fin.RM", 142: "q-fin.ST", 143: "q-fin.TR", 144: "quant-ph",
    145: "stat.AP", 146: "stat.CO", 147: "stat.ME", 148: "stat.ML", 149: "stat.OT"
}
|
|
|
|
|
# Number of classification labels. Derived from the label table so the two
# can never drift apart (previously hard-coded to 150).
N_SUBFIELDS = len(INVERSE_SUBFIELD_MAP)
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for arXiv subfield classification.

    Loads the LLaMA-3-8B base model (8-bit quantized) with a sequence
    classification head, applies the LoRA adapter stored at ``path``, and
    serves top-k (or full) probability distributions over the 150 arXiv
    subfields defined in ``INVERSE_SUBFIELD_MAP``.
    """

    def __init__(self, path: str = ""):
        """Initialize the tokenizer, quantized base model, and PEFT adapter.

        Args:
            path: Path to the model repository (adapter files).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        base_model_name = "meta-llama/Meta-Llama-3-8B"
        self.max_length = 2048

        # 8-bit quantization keeps the 8B base model within a single-GPU
        # memory budget; device_map="auto" lets accelerate place the weights.
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        # LLaMA has no dedicated pad token; reuse EOS so batched padding works.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        base_model = AutoModelForSequenceClassification.from_pretrained(
            base_model_name,
            quantization_config=quantization_config,
            num_labels=N_SUBFIELDS,
            device_map="auto",
        )
        # The classification head pools the last non-padding position, so it
        # must know which token id is padding.
        base_model.config.pad_token_id = self.tokenizer.pad_token_id

        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run inference on the input data.

        Args:
            data: Dictionary containing:
                - inputs (str or List[str]): The text(s) to classify
                - top_k (int, optional): Number of top predictions to return
                  (default: 5); clamped into [1, N_SUBFIELDS]
                - return_all_scores (bool, optional): Return scores for all
                  classes instead of only the top-k (default: False)

        Returns:
            For a single input: a list of ``{"label", "score"}`` dicts.
            For a batch: a list of such lists, one per input text.

        Raises:
            ValueError: If ``inputs`` is not a string or a list of strings.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            inputs = [inputs]
        # Fail fast on malformed payloads. The `data.get("inputs", data)`
        # fallback can yield the raw request dict, which previously was passed
        # straight to the tokenizer and failed with an opaque error.
        if not isinstance(inputs, list) or not all(isinstance(t, str) for t in inputs):
            raise ValueError('"inputs" must be a string or a list of strings')
        if not inputs:
            return []

        top_k = int(data.get("top_k", 5))
        return_all_scores = bool(data.get("return_all_scores", False))

        # padding=True pads only to the longest sequence in the batch. Padded
        # positions are excluded via the attention mask, so scores match
        # fixed max_length padding while avoiding full 2048-token tensors
        # for short abstracts.
        encoded = self.tokenizer(
            inputs,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Convert per-class logits to a probability distribution per input.
        probs = torch.softmax(outputs.logits, dim=-1)

        results = []
        for i in range(len(inputs)):
            if return_all_scores:
                scores = probs[i].cpu().tolist()
                result = [
                    {"label": INVERSE_SUBFIELD_MAP[j], "score": scores[j]}
                    for j in range(N_SUBFIELDS)
                ]
            else:
                # Clamp top_k so out-of-range client values cannot crash topk.
                k = min(max(top_k, 1), N_SUBFIELDS)
                top_probs, top_indices = torch.topk(probs[i], k)
                result = [
                    {"label": INVERSE_SUBFIELD_MAP[idx.item()], "score": prob.item()}
                    for prob, idx in zip(top_probs, top_indices)
                ]
            results.append(result)

        # Single input -> unwrap the batch dimension (HF pipeline convention).
        if len(results) == 1:
            return results[0]
        return results
|
|
|