File size: 2,333 Bytes
e5a785e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# handler.py
# Hugging Face Inference Endpoint custom handler — April 2025 edition
from pathlib import Path
from typing import Dict, Any

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

_BASE_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"   # Hub id of the 4-bit quantised base model; also the tokenizer fallback when the repo ships no tokenizer

class EndpointHandler:
    """Custom Hugging Face Inference Endpoint handler for a Llama 3.1 8B model.

    Loads the 4-bit base checkpoint named by ``_BASE_MODEL`` and, when an
    ``adapter_config.json`` is present in the repository root, attaches that
    PEFT (LoRA / QLoRA) adapter on top of it. Requests may set the
    text-generation parameters ``max_new_tokens``, ``temperature``, ``top_p``,
    ``do_sample``, ``repetition_penalty`` and ``return_full_text``.
    """

    def __init__(self, path: str = "."):
        """Load the tokenizer, the quantised base model and the optional adapter.

        Args:
            path: Directory of the endpoint repository. Its tokenizer is used
                instead of the base model's when ``tokenizer_config.json``
                exists there, and an ``adapter_config.json`` there causes the
                adapter to be attached.
        """
        repo = Path(path)

        # 1. Tokenizer: prefer the one shipped alongside the adapter.
        self.tokenizer = AutoTokenizer.from_pretrained(
            str(repo) if (repo / "tokenizer_config.json").exists() else _BASE_MODEL,
            padding_side="left",  # left padding keeps prompts flush against generated tokens
            trust_remote_code=True,
        )
        # Reuse EOS as the padding token so generate() has a pad id available.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # 2. Base model in 4-bit. Passing `load_in_4bit=` directly to
        #    from_pretrained is deprecated; use an explicit BitsAndBytesConfig.
        self.model = AutoModelForCausalLM.from_pretrained(
            _BASE_MODEL,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            ),
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )

        # 3. Attach the LoRA / QLoRA adapter if the repo contains one.
        if (repo / "adapter_config.json").exists():
            self.model = PeftModel.from_pretrained(self.model, str(repo), is_trainable=False)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate a completion for one request.

        Args:
            data: The request payload ``{"inputs": <prompt>, "parameters": {...}}``
                or a bare prompt string. ``parameters`` may be omitted or null.

        Returns:
            ``{"generated_text": <text>}``. By default the text is the prompt
            followed by the completion; set ``"return_full_text": false`` in
            ``parameters`` to receive only the newly generated tokens. If
            ``inputs`` is a list, only the first sequence is decoded.

        Raises:
            ValueError: if the payload carries no usable ``inputs`` prompt.
        """
        prompt, gen_cfg = self._parse_request(data)
        tok_in = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.inference_mode():
            out = self.model.generate(
                **tok_in,
                max_new_tokens=gen_cfg.get("max_new_tokens", 256),
                temperature=gen_cfg.get("temperature", 0.7),
                top_p=gen_cfg.get("top_p", 0.9),
                do_sample=gen_cfg.get("do_sample", True),
                repetition_penalty=gen_cfg.get("repetition_penalty", 1.1),
                # Pass the pad id explicitly instead of relying on generate()'s fallback.
                pad_token_id=self.tokenizer.pad_token_id,
            )

        sequence = out[0]
        if not gen_cfg.get("return_full_text", True):
            # generate() returns prompt + completion; drop the prompt tokens.
            sequence = sequence[tok_in["input_ids"].shape[-1]:]
        return {"generated_text": self.tokenizer.decode(sequence, skip_special_tokens=True)}

    @staticmethod
    def _parse_request(data: Any):
        """Split a request into ``(prompt, generation_parameters)``.

        Accepts a bare prompt string, or a dict-like payload with ``inputs`` and
        an optional ``parameters`` mapping (a null ``parameters`` becomes ``{}``).

        Raises:
            ValueError: if no non-empty string/list prompt is present.
        """
        if isinstance(data, str):
            prompt, params = data, {}
        else:
            prompt = data.get("inputs")
            params = data.get("parameters") or {}

        # Fail fast with a clear message rather than handing a dict/None to the tokenizer.
        if not prompt or not isinstance(prompt, (str, list)):
            raise ValueError("Request payload must contain a non-empty 'inputs' prompt.")
        return prompt, params