from pathlib import Path
from typing import Dict, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Base checkpoint: Llama-3.1-8B-Instruct, pre-quantized to 4-bit with bitsandbytes.
_BASE_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"


class EndpointHandler:
    """
    Loads the Llama-3.1 8B base model in 4-bit and stitches the PEFT adapter
    found in the repository root onto it, if one is present. Standard
    text-generation kwargs are accepted via the request's "parameters" field
    (see the smoke test at the bottom of the module).
    """

    def __init__(self, path: str = "."):
        repo = Path(path)

        # Use the tokenizer shipped with the adapter repo when available,
        # otherwise fall back to the base model's tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(
            repo if (repo / "tokenizer_config.json").exists() else _BASE_MODEL,
            padding_side="left",
            trust_remote_code=True,
        )
        # Llama tokenizers ship without a pad token; reuse EOS so padding works.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # The base checkpoint is already bnb-4bit quantized; load_in_4bit is
        # kept so older transformers versions also load it in 4-bit.
        self.model = AutoModelForCausalLM.from_pretrained(
            _BASE_MODEL,
            load_in_4bit=True,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )

        # Stitch the adapter onto the base model if the repo root contains one.
        if (repo / "adapter_config.json").exists():
            self.model = PeftModel.from_pretrained(self.model, repo, is_trainable=False)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        # Inference Endpoints deliver {"inputs": ..., "parameters": {...}};
        # fall back to an empty prompt rather than tokenizing the raw dict.
        prompt = data.get("inputs") or ""
        gen_cfg = data.get("parameters", {})
        tok_in = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.inference_mode():
            out = self.model.generate(
                **tok_in,
                max_new_tokens=gen_cfg.get("max_new_tokens", 256),
                temperature=gen_cfg.get("temperature", 0.7),
                top_p=gen_cfg.get("top_p", 0.9),
                do_sample=gen_cfg.get("do_sample", True),
                repetition_penalty=gen_cfg.get("repetition_penalty", 1.1),
            )

        # out[0] contains the prompt followed by the completion.
        return {"generated_text": self.tokenizer.decode(out[0], skip_special_tokens=True)}