# bibframe-olmo-1b / handler.py
# jimfhahn's picture
# Upload handler.py with huggingface_hub
# 2827a15 verified
"""
Custom handler for HuggingFace Inference Endpoints.
This handles loading the LoRA adapter on top of the base model.
To deploy:
1. Push this handler.py to your model repo on HuggingFace
2. Create an Inference Endpoint pointing to jimfhahn/bibframe-olmo-1b-v2
3. The endpoint will automatically use this custom handler
"""
from typing import Dict, Any
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
class EndpointHandler:
    """Custom Inference Endpoints handler that repairs BIBFRAME RDF/XML.

    On construction it loads the base causal LM, applies the LoRA adapter
    found at the endpoint's model path and merges it into the base weights.
    Each call wraps the request's ``inputs`` in a ChatML prompt, generates a
    completion and returns the RDF/XML portion of it.
    """

    # Base checkpoint the LoRA adapter is applied on top of.
    _BASE_MODEL_ID = "amd/AMD-OLMo-1B"
    # Closing tag that marks the end of a complete RDF/XML document.
    _RDF_CLOSE_TAG = "</rdf:RDF>"
    # ChatML end-of-turn marker used by ``_build_prompt``.
    _TURN_END = "<|im_end|>"

    def __init__(self, path: str = ""):
        """Load the tokenizer, base model and LoRA adapter.

        Args:
            path: Path to the model directory (provided by Inference
                Endpoints); the LoRA adapter is loaded from here.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self._BASE_MODEL_ID)

        # Half precision, with placement across available devices left to
        # ``device_map="auto"``.
        self.model = AutoModelForCausalLM.from_pretrained(
            self._BASE_MODEL_ID,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

        # Attach the LoRA adapter stored at the endpoint path.
        self.model = PeftModel.from_pretrained(
            self.model,
            path,
            torch_dtype=torch.float16,
        )

        # Fold the adapter into the base weights for faster inference.
        self.model = self.model.merge_and_unload()
        self.model.eval()

        # Generation needs a pad token; fall back to EOS when none is set.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process one inference request.

        Args:
            data: Request payload. ``inputs`` holds the RDF/XML to correct;
                the optional ``parameters`` dict may override
                ``max_new_tokens``, ``temperature``, ``top_p`` and
                ``do_sample``.

        Returns:
            Dictionary with ``generated_text`` (the corrected RDF/XML) and
            ``prompt_used`` (the full prompt sent to the model).
        """
        # ``or`` also covers an explicit ``null`` in the JSON payload, which
        # ``dict.get(key, default)`` would pass through as ``None``.
        inputs = data.get("inputs") or ""
        parameters = data.get("parameters") or {}

        prompt = self._build_prompt(inputs)

        # NOTE(review): truncation presumably drops tokens from the end of
        # the prompt, which would cut off the trailing assistant header for
        # inputs longer than 2048 tokens - confirm the tokenizer's
        # truncation side.
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
        ).to(self.model.device)

        gen_kwargs = self._sampling_kwargs(parameters)
        gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
        gen_kwargs["eos_token_id"] = self.tokenizer.eos_token_id

        with torch.no_grad():
            outputs = self.model.generate(**encoded, **gen_kwargs)

        # The output sequence starts with the prompt tokens; decode only the
        # newly generated tail.
        prompt_len = encoded["input_ids"].shape[1]
        generated = self.tokenizer.decode(
            outputs[0][prompt_len:],
            skip_special_tokens=True,
        )

        return {
            "generated_text": self._extract_rdf(generated),
            "prompt_used": prompt,
        }

    @staticmethod
    def _sampling_kwargs(parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Translate request parameters into ``generate`` keyword arguments.

        ``temperature`` and ``top_p`` are only forwarded when sampling is
        enabled; with ``do_sample=False`` they have no effect and only
        trigger validation warnings.
        """
        do_sample = parameters.get("do_sample", True)
        kwargs: Dict[str, Any] = {
            "max_new_tokens": parameters.get("max_new_tokens", 1024),
            "do_sample": do_sample,
        }
        if do_sample:
            kwargs["temperature"] = parameters.get("temperature", 0.1)
            kwargs["top_p"] = parameters.get("top_p", 0.95)
        return kwargs

    def _build_prompt(self, rdf_input: str) -> str:
        """Build the prompt in ChatML format (matching training data)."""
        return (
            "<|im_start|>system\n"
            "You are a BIBFRAME expert. Fix the following malformed RDF/XML "
            "to produce valid BIBFRAME following Library of Congress conventions.<|im_end|>\n"
            f"<|im_start|>user\n{rdf_input}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

    def _extract_rdf(self, text: str) -> str:
        """Extract RDF/XML from generated text, dropping trailing content.

        The text is first cut at the ChatML end-of-turn marker, if present,
        so nothing generated after the assistant's turn leaks through. It is
        then cut just after the first ``</rdf:RDF>`` closing tag. When
        neither marker is found the stripped text is returned as is.
        """
        answer = text.partition(self._TURN_END)[0]
        body, closing_tag, _ = answer.partition(self._RDF_CLOSE_TAG)
        if closing_tag:
            return (body + closing_tag).strip()
        return answer.strip()