"""
Custom handler for HuggingFace Inference Endpoints.

This handles loading the LoRA adapter on top of the base model.

To deploy:
1. Push this handler.py to your model repo on HuggingFace
2. Create an Inference Endpoint pointing to jimfhahn/bibframe-olmo-1b-v2
3. The endpoint will automatically use this custom handler
"""

from typing import Dict, Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Closing tag of the RDF/XML root element. Generated text is cut right after
# its first occurrence so trailing chatter from the model is discarded.
_RDF_CLOSE_TAG = "</rdf:RDF>"


class EndpointHandler:
    """Callable that serves a LoRA-adapted causal LM for RDF/XML correction.

    Construction loads the base model, applies the adapter found at ``path``
    and merges it; calling the instance runs one generation per request.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer.

        Args:
            path: Path to the model directory (provided by Inference Endpoints).
                It is passed to ``PeftModel.from_pretrained`` as the adapter
                location.
        """
        # The tokenizer and weights come from the base model; only the
        # adapter is read from ``path``.
        base_model_id = "amd/AMD-OLMo-1B"

        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)

        # Load base model with appropriate settings
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

        # Load LoRA adapter from the endpoint path
        self.model = PeftModel.from_pretrained(
            self.model,
            path,  # This is the model repo path
            torch_dtype=torch.float16,
        )

        # Merge adapter for faster inference (optional but recommended)
        self.model = self.model.merge_and_unload()
        self.model.eval()

        # Set pad token if not set; generate() below is given pad_token_id.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process inference request.

        Args:
            data: Request payload with 'inputs' key containing RDF/XML to
                correct. An optional 'parameters' mapping may override
                'max_new_tokens', 'temperature', 'top_p' and 'do_sample'.

        Returns:
            Dictionary with 'generated_text' containing corrected RDF/XML and
            'prompt_used' containing the exact prompt sent to the model.
        """
        # Extract input
        inputs = data.get("inputs", "")
        # A payload may carry an explicit ``"parameters": null``; fall back to
        # an empty mapping so the lookups below do not fail on None.
        parameters = data.get("parameters") or {}

        # Build prompt using the model's expected format
        prompt = self._build_prompt(inputs)

        # Tokenize
        # NOTE(review): truncation at 2048 tokens presumably cuts from the end
        # of the prompt (tokenizer default), which would drop the trailing
        # assistant marker for very long inputs; confirm this is acceptable.
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
        ).to(self.model.device)

        # Generation parameters (request values win over these defaults)
        gen_kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 1024),
            "temperature": parameters.get("temperature", 0.1),
            "top_p": parameters.get("top_p", 0.95),
            "do_sample": parameters.get("do_sample", True),
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(**encoded, **gen_kwargs)

        # Decode, removing the prompt: keep only tokens past the input length.
        prompt_length = encoded["input_ids"].shape[1]
        generated = self.tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True,
        )

        # Extract just the RDF/XML (stop at end markers if present)
        corrected = self._extract_rdf(generated)

        return {
            "generated_text": corrected,
            "prompt_used": prompt,
        }

    def _build_prompt(self, rdf_input: str) -> str:
        """Build the prompt in ChatML format (matching training data)."""
        return (
            "<|im_start|>system\n"
            "You are a BIBFRAME expert. Fix the following malformed RDF/XML "
            "to produce valid BIBFRAME following Library of Congress conventions.<|im_end|>\n"
            f"<|im_start|>user\n{rdf_input}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

    def _extract_rdf(self, text: str) -> str:
        """Extract RDF/XML from generated text, handling any trailing content.

        Args:
            text: Decoded model output.

        Returns:
            ``text`` up to and including the first closing ``rdf:RDF`` tag,
            stripped of surrounding whitespace. When no closing tag is
            present, the whole stripped text is returned.
        """
        # Searching for an empty string always matches at index 0 and would
        # discard the entire output, so the real closing tag must be used.
        head, tag, _ = text.partition(_RDF_CLOSE_TAG)
        if tag:
            return (head + tag).strip()
        return text.strip()