# handler.py
# Hugging Face Inference Endpoint custom handler — April 2025 edition
from pathlib import Path
from typing import Dict, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
_BASE_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" # 4‑bit quantised base
class EndpointHandler:
    """Custom Hugging Face Inference Endpoint handler.

    Loads the Llama-3.1-8B base model in 4-bit (bitsandbytes) and, when an
    ``adapter_config.json`` is present in the repository root, attaches the
    PEFT (LoRA / QLoRA) adapter on top of it.  Requests follow the standard
    text-generation schema: ``{"inputs": "...", "parameters": {...}}`` — a
    bare prompt string is also accepted.
    """

    def __init__(self, path: str = "."):
        """Load the tokenizer, the 4-bit base model, and an optional adapter.

        Args:
            path: Repository root of the deployed endpoint.  Searched for a
                repo-local tokenizer (``tokenizer_config.json``) and a PEFT
                adapter (``adapter_config.json``).
        """
        repo = Path(path)

        # 1) Tokenizer: prefer the repo-local one (it may carry a chat
        #    template or added tokens); otherwise fall back to the base model.
        self.tokenizer = AutoTokenizer.from_pretrained(
            repo if (repo / "tokenizer_config.json").exists() else _BASE_MODEL,
            padding_side="left",  # left padding for decoder-only generation
            trust_remote_code=True,
        )
        # Llama tokenizers ship without a pad token; reuse EOS — but do not
        # clobber a pad token the repo tokenizer may already define.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # 2) Base model in 4-bit (bitsandbytes), sharded across devices.
        # NOTE(review): `load_in_4bit=True` is deprecated in recent
        # transformers releases in favour of
        # `quantization_config=BitsAndBytesConfig(load_in_4bit=True)` —
        # confirm against the version pinned for this endpoint.
        self.model = AutoModelForCausalLM.from_pretrained(
            _BASE_MODEL,
            load_in_4bit=True,  # bitsandbytes
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )

        # 3) Attach the LoRA / QLoRA adapter if one is bundled with the repo.
        if (repo / "adapter_config.json").exists():
            self.model = PeftModel.from_pretrained(
                self.model, repo, is_trainable=False
            )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate a completion for one request.

        Args:
            data: Either a raw prompt string, or a dict with an ``inputs``
                string and an optional ``parameters`` dict of generation
                kwargs (``max_new_tokens``, ``temperature``, ``top_p``,
                ``do_sample``, ``repetition_penalty``).

        Returns:
            ``{"generated_text": ...}`` — the decoded output (prompt
            included, matching the original behaviour).

        Raises:
            ValueError: If a dict request carries no ``inputs`` field.
        """
        # BUG FIX: the old `data.get("inputs") or data` crashed on raw-string
        # requests (str has no .get) and silently fed the whole dict to the
        # tokenizer when "inputs" was missing.
        if isinstance(data, str):
            prompt: str = data
            gen_cfg: Dict[str, Any] = {}
        else:
            prompt = data.get("inputs")
            if prompt is None:
                raise ValueError(
                    "Request must be a raw string or contain an 'inputs' field"
                )
            # A JSON body of `"parameters": null` must not crash — treat it
            # the same as an absent parameters object.
            gen_cfg = data.get("parameters") or {}

        tok_in = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.inference_mode():
            out = self.model.generate(
                **tok_in,
                max_new_tokens=gen_cfg.get("max_new_tokens", 256),
                temperature=gen_cfg.get("temperature", 0.7),
                top_p=gen_cfg.get("top_p", 0.9),
                do_sample=gen_cfg.get("do_sample", True),
                repetition_penalty=gen_cfg.get("repetition_penalty", 1.1),
            )
        return {"generated_text": self.tokenizer.decode(out[0], skip_special_tokens=True)}