"""
Custom inference handler for the arxiv-classifier PEFT adapter.

This handler loads a LLaMA-3-8B base model with a LoRA adapter fine-tuned
for arXiv paper classification into 150 subfields.
"""

from typing import Any, Dict, List, Union

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# ArXiv subfield mapping (150 classes): classifier output index -> arXiv category.
INVERSE_SUBFIELD_MAP = {
    0: "astro-ph",
    1: "astro-ph.CO",
    2: "astro-ph.EP",
    3: "astro-ph.GA",
    4: "astro-ph.HE",
    5: "astro-ph.IM",
    6: "astro-ph.SR",
    7: "cond-mat.dis-nn",
    8: "cond-mat.mes-hall",
    9: "cond-mat.mtrl-sci",
    10: "cond-mat.other",
    11: "cond-mat.quant-gas",
    12: "cond-mat.soft",
    13: "cond-mat.stat-mech",
    14: "cond-mat.str-el",
    15: "cond-mat.supr-con",
    16: "cs.AI",
    17: "cs.AR",
    18: "cs.CC",
    19: "cs.CE",
    20: "cs.CG",
    21: "cs.CL",
    22: "cs.CR",
    23: "cs.CV",
    24: "cs.CY",
    25: "cs.DB",
    26: "cs.DC",
    27: "cs.DL",
    28: "cs.DM",
    29: "cs.DS",
    30: "cs.ET",
    31: "cs.FL",
    32: "cs.GL",
    33: "cs.GR",
    34: "cs.GT",
    35: "cs.HC",
    36: "cs.IR",
    37: "cs.IT",
    38: "cs.LG",
    39: "cs.LO",
    40: "cs.MA",
    41: "cs.MM",
    42: "cs.MS",
    43: "cs.NE",
    44: "cs.NI",
    45: "cs.OH",
    46: "cs.OS",
    47: "cs.PF",
    48: "cs.PL",
    49: "cs.RO",
    50: "cs.SC",
    51: "cs.SD",
    52: "cs.SE",
    53: "cs.SI",
    54: "econ.EM",
    55: "econ.GN",
    56: "econ.TH",
    57: "eess.AS",
    58: "eess.IV",
    59: "eess.SP",
    60: "eess.SY",
    61: "gr-qc",
    62: "hep-ex",
    63: "hep-lat",
    64: "hep-ph",
    65: "hep-th",
    66: "math-ph",
    67: "math.AC",
    68: "math.AG",
    69: "math.AP",
    70: "math.AT",
    71: "math.CA",
    72: "math.CO",
    73: "math.CT",
    74: "math.CV",
    75: "math.DG",
    76: "math.DS",
    77: "math.FA",
    78: "math.GM",
    79: "math.GN",
    80: "math.GR",
    81: "math.GT",
    82: "math.HO",
    83: "math.KT",
    84: "math.LO",
    85: "math.MG",
    86: "math.NA",
    87: "math.NT",
    88: "math.OA",
    89: "math.OC",
    90: "math.PR",
    91: "math.QA",
    92: "math.RA",
    93: "math.RT",
    94: "math.SG",
    95: "math.SP",
    96: "math.ST",
    97: "nlin.AO",
    98: "nlin.CD",
    99: "nlin.CG",
    100: "nlin.PS",
    101: "nlin.SI",
    102: "nucl-ex",
    103: "nucl-th",
    104: "physics.acc-ph",
    105: "physics.ao-ph",
    106: "physics.app-ph",
    107: "physics.atm-clus",
    108: "physics.atom-ph",
    109: "physics.bio-ph",
    110: "physics.chem-ph",
    111: "physics.class-ph",
    112: "physics.comp-ph",
    113: "physics.data-an",
    114: "physics.ed-ph",
    115: "physics.flu-dyn",
    116: "physics.gen-ph",
    117: "physics.geo-ph",
    118: "physics.hist-ph",
    119: "physics.ins-det",
    120: "physics.med-ph",
    121: "physics.optics",
    122: "physics.plasm-ph",
    123: "physics.pop-ph",
    124: "physics.soc-ph",
    125: "physics.space-ph",
    126: "q-bio.BM",
    127: "q-bio.CB",
    128: "q-bio.GN",
    129: "q-bio.MN",
    130: "q-bio.NC",
    131: "q-bio.OT",
    132: "q-bio.PE",
    133: "q-bio.QM",
    134: "q-bio.SC",
    135: "q-bio.TO",
    136: "q-fin.CP",
    137: "q-fin.GN",
    138: "q-fin.MF",
    139: "q-fin.PM",
    140: "q-fin.PR",
    141: "q-fin.RM",
    142: "q-fin.ST",
    143: "q-fin.TR",
    144: "quant-ph",
    145: "stat.AP",
    146: "stat.CO",
    147: "stat.ME",
    148: "stat.ML",
    149: "stat.OT",
}

# Number of classifier outputs; must match the number of entries in INVERSE_SUBFIELD_MAP.
N_SUBFIELDS = 150


class EndpointHandler:
    """Serve arXiv subfield predictions from the 8-bit base model plus LoRA adapter."""

    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer.

        Args:
            path: Path to the model repository (adapter files)
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Base model configuration
        base_model_name = "meta-llama/Meta-Llama-3-8B"
        # Upper bound on tokens per input; longer texts are truncated.
        self.max_length = 2048

        # Quantization config for 8-bit inference.
        # NOTE(review): bitsandbytes 8-bit loading generally needs a CUDA GPU, so the
        # CPU fallback chosen for self.device above may not work end-to-end — confirm.
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        # Load tokenizer and reuse EOS as the padding token (the code assumes the base
        # tokenizer does not define one). Pad positions are excluded via attention_mask.
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # Load base model for sequence classification with one logit per subfield.
        base_model = AutoModelForSequenceClassification.from_pretrained(
            base_model_name,
            quantization_config=quantization_config,
            num_labels=N_SUBFIELDS,
            device_map="auto",
        )
        # Decoder-only sequence-classification heads use config.pad_token_id to find
        # the last non-padding token of each row, so it must match the tokenizer.
        base_model.config.pad_token_id = self.tokenizer.pad_token_id

        # Load PEFT adapter on top of the base model and switch to inference mode.
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

    def __call__(
        self, data: Dict[str, Any]
    ) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """
        Run inference on the input data.

        Args:
            data: Dictionary containing:
                - inputs (str or List[str]): The text(s) to classify
                - top_k (int, optional): Number of top predictions to return (default: 5)
                - return_all_scores (bool, optional): Return scores for all classes
                  (default: False)

        Returns:
            For a single text (a str, or a list holding exactly one text): a list of
            {"label": str, "score": float} dicts. For several texts: one such list per
            text, in input order. An empty input list yields an empty list.

        Raises:
            ValueError: If ``data`` has no ``inputs`` entry.
        """
        inputs = data.get("inputs")
        if inputs is None:
            # The previous fallback fed the whole payload dict to the tokenizer,
            # which always failed with an unrelated error; fail clearly instead.
            raise ValueError("Request payload must contain an 'inputs' field.")
        if isinstance(inputs, str):
            inputs = [inputs]
        if not inputs:
            # Nothing to classify; avoid running the tokenizer/model on an empty batch.
            return []

        top_k = data.get("top_k", 5)
        return_all_scores = data.get("return_all_scores", False)

        # Pad only to the longest text in the batch (still truncating at max_length).
        # With right padding and an attention mask, the result per row matches padding
        # every request to max_length, without the wasted compute.
        encoded = self.tokenizer(
            inputs,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits

        # Softmax in float32 (logits from a quantized model may be half precision) and
        # move the whole probability matrix to the CPU once instead of per row.
        probs = torch.softmax(logits.float(), dim=-1).cpu()

        results = [self._format_row(row, top_k, return_all_scores) for row in probs]

        # Return single result if single input
        if len(results) == 1:
            return results[0]
        return results

    @staticmethod
    def _format_row(
        row: torch.Tensor, top_k: int, return_all_scores: bool
    ) -> List[Dict[str, Any]]:
        """Turn one row of class probabilities into a list of label/score dicts.

        Returns every class in index order when ``return_all_scores`` is true, else the
        ``min(top_k, N_SUBFIELDS)`` most probable classes in descending score order.
        """
        if return_all_scores:
            return [
                {"label": INVERSE_SUBFIELD_MAP[j], "score": score}
                for j, score in enumerate(row.tolist())
            ]
        top_probs, top_indices = torch.topk(row, min(top_k, N_SUBFIELDS))
        return [
            {"label": INVERSE_SUBFIELD_MAP[idx], "score": prob}
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        ]