# PEFT
# Safetensors
# llama-quantized-1 / handler.py
# Albert Gong
# Add custom inference handler for Inference Endpoints
# b01da00 verified
"""
Custom inference handler for the arxiv-classifier PEFT adapter.
This handler loads a LLaMA-3-8B base model with a LoRA adapter fine-tuned
for arXiv paper classification into 150 subfields.
"""
from typing import Dict, List, Any
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
# ArXiv subfield mapping: class index -> arXiv category label (150 classes).
# The handler reports logit column j as INVERSE_SUBFIELD_MAP[j], so this order
# must match the label order of the classification head.
INVERSE_SUBFIELD_MAP = {
    0: "astro-ph", 1: "astro-ph.CO", 2: "astro-ph.EP", 3: "astro-ph.GA",
    4: "astro-ph.HE", 5: "astro-ph.IM", 6: "astro-ph.SR", 7: "cond-mat.dis-nn",
    8: "cond-mat.mes-hall", 9: "cond-mat.mtrl-sci", 10: "cond-mat.other",
    11: "cond-mat.quant-gas", 12: "cond-mat.soft", 13: "cond-mat.stat-mech",
    14: "cond-mat.str-el", 15: "cond-mat.supr-con", 16: "cs.AI", 17: "cs.AR",
    18: "cs.CC", 19: "cs.CE", 20: "cs.CG", 21: "cs.CL", 22: "cs.CR", 23: "cs.CV",
    24: "cs.CY", 25: "cs.DB", 26: "cs.DC", 27: "cs.DL", 28: "cs.DM", 29: "cs.DS",
    30: "cs.ET", 31: "cs.FL", 32: "cs.GL", 33: "cs.GR", 34: "cs.GT", 35: "cs.HC",
    36: "cs.IR", 37: "cs.IT", 38: "cs.LG", 39: "cs.LO", 40: "cs.MA", 41: "cs.MM",
    42: "cs.MS", 43: "cs.NE", 44: "cs.NI", 45: "cs.OH", 46: "cs.OS", 47: "cs.PF",
    48: "cs.PL", 49: "cs.RO", 50: "cs.SC", 51: "cs.SD", 52: "cs.SE", 53: "cs.SI",
    54: "econ.EM", 55: "econ.GN", 56: "econ.TH", 57: "eess.AS", 58: "eess.IV",
    59: "eess.SP", 60: "eess.SY", 61: "gr-qc", 62: "hep-ex", 63: "hep-lat",
    64: "hep-ph", 65: "hep-th", 66: "math-ph", 67: "math.AC", 68: "math.AG",
    69: "math.AP", 70: "math.AT", 71: "math.CA", 72: "math.CO", 73: "math.CT",
    74: "math.CV", 75: "math.DG", 76: "math.DS", 77: "math.FA", 78: "math.GM",
    79: "math.GN", 80: "math.GR", 81: "math.GT", 82: "math.HO", 83: "math.KT",
    84: "math.LO", 85: "math.MG", 86: "math.NA", 87: "math.NT", 88: "math.OA",
    89: "math.OC", 90: "math.PR", 91: "math.QA", 92: "math.RA", 93: "math.RT",
    94: "math.SG", 95: "math.SP", 96: "math.ST", 97: "nlin.AO", 98: "nlin.CD",
    99: "nlin.CG", 100: "nlin.PS", 101: "nlin.SI", 102: "nucl-ex", 103: "nucl-th",
    104: "physics.acc-ph", 105: "physics.ao-ph", 106: "physics.app-ph",
    107: "physics.atm-clus", 108: "physics.atom-ph", 109: "physics.bio-ph",
    110: "physics.chem-ph", 111: "physics.class-ph", 112: "physics.comp-ph",
    113: "physics.data-an", 114: "physics.ed-ph", 115: "physics.flu-dyn",
    116: "physics.gen-ph", 117: "physics.geo-ph", 118: "physics.hist-ph",
    119: "physics.ins-det", 120: "physics.med-ph", 121: "physics.optics",
    122: "physics.plasm-ph", 123: "physics.pop-ph", 124: "physics.soc-ph",
    125: "physics.space-ph", 126: "q-bio.BM", 127: "q-bio.CB", 128: "q-bio.GN",
    129: "q-bio.MN", 130: "q-bio.NC", 131: "q-bio.OT", 132: "q-bio.PE",
    133: "q-bio.QM", 134: "q-bio.SC", 135: "q-bio.TO", 136: "q-fin.CP",
    137: "q-fin.GN", 138: "q-fin.MF", 139: "q-fin.PM", 140: "q-fin.PR",
    141: "q-fin.RM", 142: "q-fin.ST", 143: "q-fin.TR", 144: "quant-ph",
    145: "stat.AP", 146: "stat.CO", 147: "stat.ME", 148: "stat.ML", 149: "stat.OT",
}
# Number of output classes of the classification head. Derived from the label
# table so the two can never disagree.
N_SUBFIELDS = len(INVERSE_SUBFIELD_MAP)
class EndpointHandler:
    """Inference Endpoints handler for the arXiv subfield classifier.

    Loads ``meta-llama/Meta-Llama-3-8B`` as an 8-bit sequence classifier with
    ``N_SUBFIELDS`` output labels, applies the PEFT adapter found at ``path``,
    and maps softmax scores back to arXiv category names.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer.

        Args:
            path: Path to the model repository (adapter files); passed to
                ``PeftModel.from_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Base model configuration
        base_model_name = "meta-llama/Meta-Llama-3-8B"
        self.max_length = 2048

        # Quantization config for 8-bit (bitsandbytes) inference
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        # Load tokenizer; EOS doubles as the padding token.
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # Load base model for sequence classification
        base_model = AutoModelForSequenceClassification.from_pretrained(
            base_model_name,
            quantization_config=quantization_config,
            num_labels=N_SUBFIELDS,
            device_map="auto",
        )
        # Keep the model config's pad id in sync with the tokenizer;
        # transformers' *ForSequenceClassification models read it to locate
        # the last non-padding token of each sequence.
        base_model.config.pad_token_id = self.tokenizer.pad_token_id

        # Load PEFT adapter
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """
        Classify one or more texts into arXiv subfields.

        Args:
            data: Request payload containing:
                - inputs (str or List[str]): The text(s) to classify
                - top_k (int, optional): Number of top predictions to return
                  (default: 5; capped at the number of classes)
                - return_all_scores (bool, optional): Return scores for all
                  classes instead of the top-k (default: False)
                ``top_k`` and ``return_all_scores`` may also be given inside a
                ``parameters`` dict; top-level keys take precedence.

        Returns:
            One list of ``{"label": str, "score": float}`` dicts per input
            text: in descending-score order in top-k mode, in class-index order
            when ``return_all_scores`` is set. If the batch holds exactly one
            text (sent as a string or as a one-element list) that single inner
            list is returned unwrapped. An empty batch returns ``[]``.

        Raises:
            ValueError: If ``inputs`` is missing or is not a string / list of
                strings, or ``top_k`` is not a non-negative integer.
        """
        texts, top_k, return_all_scores = self._parse_request(data)
        if not texts:
            # Nothing to classify; avoid running the model on an empty batch.
            return []

        probs = self._predict_probs(texts)
        results = [self._format_scores(row, top_k, return_all_scores) for row in probs]

        # Return single result if single input. Kept for backward compatibility:
        # this also unwraps a one-element list input.
        if len(results) == 1:
            return results[0]
        return results

    @staticmethod
    def _parse_request(data: Dict[str, Any]) -> tuple:
        """Validate the payload and return ``(texts, top_k, return_all_scores)``."""
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            texts = [inputs]
        elif isinstance(inputs, (list, tuple)) and all(isinstance(t, str) for t in inputs):
            texts = list(inputs)
        else:
            # Also reached when "inputs" is absent (inputs is then the dict itself).
            raise ValueError("'inputs' must be a string or a list of strings")

        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            raise ValueError("'parameters' must be an object")

        raw_top_k = data.get("top_k", params.get("top_k", 5))
        try:
            # JSON clients may send "3" or 3.0; normalize to int for torch.topk.
            top_k = int(raw_top_k)
        except (TypeError, ValueError) as err:
            raise ValueError(f"'top_k' must be an integer, got {raw_top_k!r}") from err
        if top_k < 0:
            raise ValueError(f"'top_k' must be non-negative, got {top_k}")

        return_all_scores = bool(
            data.get("return_all_scores", params.get("return_all_scores", False))
        )
        return texts, top_k, return_all_scores

    def _predict_probs(self, texts: List[str]) -> torch.Tensor:
        """Tokenize ``texts``, run the model, and return float32 class probabilities on CPU."""
        # NOTE(review): padding every request to max_length (2048 tokens) is
        # expensive for short texts. Dynamic padding (padding=True) should give
        # the same logits if the tokenizer pads on the right — confirm
        # tokenizer.padding_side before switching.
        encoded = self.tokenizer(
            texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits

        # Upcast before softmax: the quantized model may emit half-precision
        # logits, and a half-precision softmax loses resolution in small
        # probabilities. A single .cpu() replaces per-row device transfers.
        return torch.softmax(logits.float(), dim=-1).cpu()

    @staticmethod
    def _format_scores(
        row: torch.Tensor, top_k: int, return_all_scores: bool
    ) -> List[Dict[str, Any]]:
        """Turn one row of class probabilities into ``{"label", "score"}`` dicts."""
        if return_all_scores:
            return [
                {"label": INVERSE_SUBFIELD_MAP[j], "score": score}
                for j, score in enumerate(row.tolist())
            ]
        # torch.topk returns values sorted in descending order by default.
        top_probs, top_indices = torch.topk(row, min(top_k, row.numel()))
        return [
            {"label": INVERSE_SUBFIELD_MAP[idx], "score": prob}
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        ]