"""
Custom inference handler for the arxiv-classifier PEFT adapter.
This handler loads a LLaMA-3-8B base model with a LoRA adapter fine-tuned
for arXiv paper classification into 150 subfields.
"""
from typing import Dict, List, Any
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
# ArXiv subfield mapping (150 classes).
# Maps the classifier's integer class index -> arXiv category identifier
# (the standard arXiv taxonomy codes, e.g. "cs.LG", "math.AG").
# Keys are contiguous 0..149; the ordering presumably matches the label
# indices of the fine-tuned classification head — do not reorder or renumber
# without retraining/re-exporting the adapter (TODO: confirm against training code).
INVERSE_SUBFIELD_MAP = {
    0: "astro-ph", 1: "astro-ph.CO", 2: "astro-ph.EP", 3: "astro-ph.GA",
    4: "astro-ph.HE", 5: "astro-ph.IM", 6: "astro-ph.SR", 7: "cond-mat.dis-nn",
    8: "cond-mat.mes-hall", 9: "cond-mat.mtrl-sci", 10: "cond-mat.other",
    11: "cond-mat.quant-gas", 12: "cond-mat.soft", 13: "cond-mat.stat-mech",
    14: "cond-mat.str-el", 15: "cond-mat.supr-con", 16: "cs.AI", 17: "cs.AR",
    18: "cs.CC", 19: "cs.CE", 20: "cs.CG", 21: "cs.CL", 22: "cs.CR", 23: "cs.CV",
    24: "cs.CY", 25: "cs.DB", 26: "cs.DC", 27: "cs.DL", 28: "cs.DM", 29: "cs.DS",
    30: "cs.ET", 31: "cs.FL", 32: "cs.GL", 33: "cs.GR", 34: "cs.GT", 35: "cs.HC",
    36: "cs.IR", 37: "cs.IT", 38: "cs.LG", 39: "cs.LO", 40: "cs.MA", 41: "cs.MM",
    42: "cs.MS", 43: "cs.NE", 44: "cs.NI", 45: "cs.OH", 46: "cs.OS", 47: "cs.PF",
    48: "cs.PL", 49: "cs.RO", 50: "cs.SC", 51: "cs.SD", 52: "cs.SE", 53: "cs.SI",
    54: "econ.EM", 55: "econ.GN", 56: "econ.TH", 57: "eess.AS", 58: "eess.IV",
    59: "eess.SP", 60: "eess.SY", 61: "gr-qc", 62: "hep-ex", 63: "hep-lat",
    64: "hep-ph", 65: "hep-th", 66: "math-ph", 67: "math.AC", 68: "math.AG",
    69: "math.AP", 70: "math.AT", 71: "math.CA", 72: "math.CO", 73: "math.CT",
    74: "math.CV", 75: "math.DG", 76: "math.DS", 77: "math.FA", 78: "math.GM",
    79: "math.GN", 80: "math.GR", 81: "math.GT", 82: "math.HO", 83: "math.KT",
    84: "math.LO", 85: "math.MG", 86: "math.NA", 87: "math.NT", 88: "math.OA",
    89: "math.OC", 90: "math.PR", 91: "math.QA", 92: "math.RA", 93: "math.RT",
    94: "math.SG", 95: "math.SP", 96: "math.ST", 97: "nlin.AO", 98: "nlin.CD",
    99: "nlin.CG", 100: "nlin.PS", 101: "nlin.SI", 102: "nucl-ex", 103: "nucl-th",
    104: "physics.acc-ph", 105: "physics.ao-ph", 106: "physics.app-ph",
    107: "physics.atm-clus", 108: "physics.atom-ph", 109: "physics.bio-ph",
    110: "physics.chem-ph", 111: "physics.class-ph", 112: "physics.comp-ph",
    113: "physics.data-an", 114: "physics.ed-ph", 115: "physics.flu-dyn",
    116: "physics.gen-ph", 117: "physics.geo-ph", 118: "physics.hist-ph",
    119: "physics.ins-det", 120: "physics.med-ph", 121: "physics.optics",
    122: "physics.plasm-ph", 123: "physics.pop-ph", 124: "physics.soc-ph",
    125: "physics.space-ph", 126: "q-bio.BM", 127: "q-bio.CB", 128: "q-bio.GN",
    129: "q-bio.MN", 130: "q-bio.NC", 131: "q-bio.OT", 132: "q-bio.PE",
    133: "q-bio.QM", 134: "q-bio.SC", 135: "q-bio.TO", 136: "q-fin.CP",
    137: "q-fin.GN", 138: "q-fin.MF", 139: "q-fin.PM", 140: "q-fin.PR",
    141: "q-fin.RM", 142: "q-fin.ST", 143: "q-fin.TR", 144: "quant-ph",
    145: "stat.AP", 146: "stat.CO", 147: "stat.ME", 148: "stat.ML", 149: "stat.OT"
}
# Total number of classes; must equal len(INVERSE_SUBFIELD_MAP) and the
# num_labels the classification head was trained with.
N_SUBFIELDS = 150
class EndpointHandler:
    """Inference endpoint for arXiv subfield classification.

    Loads the LLaMA-3-8B base model in 8-bit for sequence classification
    and applies a LoRA adapter fine-tuned for the 150 arXiv subfields.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer.

        Args:
            path: Path to the model repository (adapter files)
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Base model configuration
        base_model_name = "meta-llama/Meta-Llama-3-8B"
        self.max_length = 2048

        # 8-bit quantization keeps the 8B base model within a single-GPU
        # memory budget at inference time.
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        # LLaMA ships without a pad token; reuse EOS so batches can be padded.
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # Load base model with a fresh 150-way classification head.
        base_model = AutoModelForSequenceClassification.from_pretrained(
            base_model_name,
            quantization_config=quantization_config,
            num_labels=N_SUBFIELDS,
            device_map="auto",
        )
        # The classification head uses pad_token_id to locate each
        # sequence's last non-padding token for pooling.
        base_model.config.pad_token_id = self.tokenizer.pad_token_id

        # Attach the fine-tuned LoRA adapter on top of the frozen base.
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Run inference on the input data.

        Args:
            data: Either a raw string to classify, or a dictionary containing:
                - inputs (str or List[str]): The text(s) to classify
                - top_k (int, optional): Number of top predictions to return (default: 5)
                - return_all_scores (bool, optional): Return scores for all classes (default: False)

        Returns:
            For a single input: a list of {"label", "score"} dicts.
            For a batch: a list of such lists, one per input.
        """
        # Accept a bare string payload as well as the usual {"inputs": ...}
        # envelope (HF endpoints may deliver either).
        if isinstance(data, str):
            data = {"inputs": data}
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            inputs = [inputs]
        # Empty batch: nothing to tokenize or rank.
        if not inputs:
            return []

        # Sanitize top_k: coerce to int and clamp to [1, N_SUBFIELDS].
        top_k = min(max(1, int(data.get("top_k", 5))), N_SUBFIELDS)
        return_all_scores = data.get("return_all_scores", False)

        # Dynamic padding: pad only to the longest sequence in the batch
        # rather than always to max_length. Pad positions are masked, so the
        # logits are unchanged while short inputs run much faster.
        encoded = self.tokenizer(
            inputs,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

        # Convert logits to class probabilities.
        probs = torch.softmax(logits, dim=-1)

        results = []
        for i in range(len(inputs)):
            if return_all_scores:
                # Full 150-class distribution for this input.
                scores = probs[i].cpu().tolist()
                result = [
                    {"label": INVERSE_SUBFIELD_MAP[j], "score": scores[j]}
                    for j in range(N_SUBFIELDS)
                ]
            else:
                # Top-k labels ranked by probability.
                top_probs, top_indices = torch.topk(probs[i], top_k)
                result = [
                    {"label": INVERSE_SUBFIELD_MAP[idx.item()], "score": prob.item()}
                    for prob, idx in zip(top_probs, top_indices)
                ]
            results.append(result)

        # Single input -> unwrap the outer list for a flat prediction list.
        if len(results) == 1:
            return results[0]
        return results