"""
Custom handler for Hugging Face Inference Endpoints.

Loads the LoRA adapter stored in this repository on top of its base model and
serves BIBFRAME RDF/XML correction requests.

To deploy:
1. Push this handler.py to your model repo on Hugging Face.
2. Create an Inference Endpoint pointing to jimfhahn/bibframe-olmo-1b-v2.
3. The endpoint will automatically use this custom handler.
"""

import logging
from typing import Any, Dict

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

class EndpointHandler:
    """Inference Endpoints handler for the BIBFRAME RDF/XML correction model.

    Loads the ``amd/AMD-OLMo-1B`` base model, applies the LoRA adapter found at
    the endpoint's model ``path``, merges the adapter into the base weights and
    answers each request by generating corrected RDF/XML from a ChatML prompt.
    """

    # Base model the LoRA adapter sits on; its weights are fetched from the Hub.
    BASE_MODEL_ID = "amd/AMD-OLMo-1B"
    # Prompt-length cap in tokens (system prompt + RDF payload + assistant cue).
    MAX_INPUT_TOKENS = 2048

    _logger = logging.getLogger(__name__)

    def __init__(self, path: str = ""):
        """Load the tokenizer and base model, then merge the LoRA adapter.

        Args:
            path: Directory holding the LoRA adapter weights (provided by
                Inference Endpoints).
        """
        # Half-precision kernels are only dependable on GPU; on CPU several
        # ops may lack float16 implementations, so fall back to float32 there.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        # trust_remote_code matches the model load below so a repo that ships
        # custom tokenizer code loads the same way the model does.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.BASE_MODEL_ID,
            trust_remote_code=True,
        )

        base_model = AutoModelForCausalLM.from_pretrained(
            self.BASE_MODEL_ID,
            torch_dtype=dtype,
            device_map="auto",
            trust_remote_code=True,
        )

        # Attach the adapter, then fold it into the base weights so generation
        # runs on a single merged model.
        peft_model = PeftModel.from_pretrained(base_model, path, torch_dtype=dtype)
        self.model = peft_model.merge_and_unload()
        self.model.eval()

        # generate() is given pad_token_id; reuse EOS when the tokenizer has no pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process one inference request.

        Args:
            data: Request payload. ``inputs`` holds the RDF/XML to correct;
                the optional ``parameters`` mapping may override
                ``max_new_tokens``, ``temperature``, ``top_p`` and ``do_sample``.

        Returns:
            Dictionary with ``generated_text`` (the corrected RDF/XML),
            ``prompt_used`` (the prompt actually tokenized) and
            ``input_truncated`` (True when the RDF/XML payload had to be
            shortened to fit ``MAX_INPUT_TOKENS``).
        """
        # Either field may arrive as an explicit JSON null.
        rdf_input = str(data.get("inputs") or "")
        gen_kwargs = self._generation_kwargs(data.get("parameters"))

        prompt, encoded, truncated = self._encode_prompt(rdf_input)
        encoded = encoded.to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(**encoded, **gen_kwargs)

        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = encoded["input_ids"].shape[1]
        generated = self.tokenizer.decode(
            outputs[0][prompt_len:],
            skip_special_tokens=True,
        )

        return {
            "generated_text": self._extract_rdf(generated),
            "prompt_used": prompt,
            "input_truncated": truncated,
        }

    def _generation_kwargs(self, parameters: Any) -> Dict[str, Any]:
        """Merge request overrides (a mapping or None) with default sampling settings."""
        params = parameters or {}
        return {
            "max_new_tokens": params.get("max_new_tokens", 1024),
            "temperature": params.get("temperature", 0.1),
            "top_p": params.get("top_p", 0.95),
            "do_sample": params.get("do_sample", True),
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

    def _encode_prompt(self, rdf_input: str):
        """Tokenize the ChatML prompt so it fits within ``MAX_INPUT_TOKENS``.

        Only the RDF/XML payload is shortened (from its end). The system prompt
        and the trailing ``<|im_start|>assistant`` cue are always kept, because
        without the cue the model would continue the user turn instead of
        answering it. Inputs that already fit are tokenized exactly once.

        Returns:
            ``(prompt, encoded, truncated)``: the prompt string that was
            tokenized, the tokenizer output as PyTorch tensors, and whether the
            payload was shortened.
        """
        payload_ids = None
        truncated = False
        while True:
            prompt = self._build_prompt(rdf_input)
            encoded = self.tokenizer(prompt, return_tensors="pt")
            overflow = encoded["input_ids"].shape[1] - self.MAX_INPUT_TOKENS
            if overflow <= 0 or not rdf_input:
                break
            if payload_ids is None:
                payload_ids = self.tokenizer(
                    rdf_input, add_special_tokens=False
                )["input_ids"]
            # Drop the excess from the payload tail. Re-tokenizing the rebuilt
            # prompt can merge tokens differently at the boundaries, so loop
            # until the whole prompt fits; the payload shrinks on every pass.
            payload_ids = payload_ids[: max(len(payload_ids) - overflow, 0)]
            rdf_input = self.tokenizer.decode(
                payload_ids, clean_up_tokenization_spaces=False
            )
            truncated = True

        if truncated:
            self._logger.warning(
                "RDF/XML input exceeded %d prompt tokens; payload was truncated.",
                self.MAX_INPUT_TOKENS,
            )
        return prompt, encoded, truncated

    def _build_prompt(self, rdf_input: str) -> str:
        """Build the prompt in ChatML format (matching training data)."""
        return (
            "<|im_start|>system\n"
            "You are a BIBFRAME expert. Fix the following malformed RDF/XML "
            "to produce valid BIBFRAME following Library of Congress conventions.<|im_end|>\n"
            f"<|im_start|>user\n{rdf_input}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

    def _extract_rdf(self, text: str) -> str:
        """Return ``text`` cut after the first ``</rdf:RDF>`` (if any), stripped."""
        head, closing_tag, _ = text.partition("</rdf:RDF>")
        return (head + closing_tag).strip()