golyuval committed on
Commit
e5a785e
·
verified ·
1 Parent(s): 3ff3fca

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +58 -0
  2. requirements.txt +4 -0
handler.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+ # Hugging Face Inference Endpoint custom handler — April 2025 edition
3
+ from pathlib import Path
4
+ from typing import Dict, Any
5
+
6
+ import torch
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+ from peft import PeftModel
9
+
10
+ _BASE_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" # 4‑bit quantised base
11
+
12
class EndpointHandler:
    """Custom Hugging Face Inference Endpoint handler.

    Loads the 8B Llama-3.1 base model in 4-bit (bitsandbytes) and, if an
    ``adapter_config.json`` is present in the repository root, attaches the
    PEFT (LoRA/QLoRA) adapter onto it. Standard text-generation kwargs are
    accepted through the request's ``parameters`` dict.
    """

    def __init__(self, path: str = "."):
        """Load the tokenizer, the 4-bit base model, and an optional adapter.

        Args:
            path: Repository root that may contain tokenizer files and a
                PEFT adapter. Defaults to the current directory.
        """
        repo = Path(path)

        # 1. Tokenizer: prefer the repo-local tokenizer when one ships with
        #    the adapter, otherwise fall back to the base model's tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(
            repo if (repo / "tokenizer_config.json").exists() else _BASE_MODEL,
            padding_side="left",  # decoder-only models need left padding for batching
            trust_remote_code=True,
        )
        # Llama tokenizers ship without a pad token; reuse EOS — but only if
        # no pad token is already configured (don't clobber an existing one).
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # 2. Base model in 4-bit via bitsandbytes.
        # NOTE(review): `load_in_4bit=True` is deprecated in recent
        # transformers in favor of
        # `quantization_config=BitsAndBytesConfig(load_in_4bit=True)`;
        # kept as-is for compatibility with the pinned requirements — confirm
        # before upgrading transformers.
        self.model = AutoModelForCausalLM.from_pretrained(
            _BASE_MODEL,
            load_in_4bit=True,  # bitsandbytes NF4 quantisation
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )

        # 3. Attach the LoRA/QLoRA adapter if one ships with the repo.
        if (repo / "adapter_config.json").exists():
            self.model = PeftModel.from_pretrained(self.model, repo, is_trainable=False)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate text for one inference request.

        Args:
            data: Either a raw prompt string or a JSON payload of the form
                ``{"inputs": <prompt>, "parameters": {<generation kwargs>}}``.
                Supported parameters: ``max_new_tokens``, ``temperature``,
                ``top_p``, ``do_sample``, ``repetition_penalty``.

        Returns:
            ``{"generated_text": <decoded output, prompt included>}``.
        """
        # Accept both a bare string and the standard nested payload.
        # (The previous `data.get("inputs") or data` raised AttributeError on
        # raw-string input and fed the whole dict to the tokenizer when the
        # "inputs" key was missing.)
        if isinstance(data, dict):
            prompt = data.get("inputs") or data
            # `or {}` also guards against an explicit `"parameters": null`.
            gen_cfg = data.get("parameters") or {}
        else:
            prompt, gen_cfg = data, {}

        tok_in = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.inference_mode():
            out = self.model.generate(
                **tok_in,
                max_new_tokens=gen_cfg.get("max_new_tokens", 256),
                temperature=gen_cfg.get("temperature", 0.7),
                top_p=gen_cfg.get("top_p", 0.9),
                do_sample=gen_cfg.get("do_sample", True),
                repetition_penalty=gen_cfg.get("repetition_penalty", 1.1),
            )

        # Decode the full sequence (prompt + continuation), matching the
        # original handler's contract.
        return {"generated_text": self.tokenizer.decode(out[0], skip_special_tokens=True)}
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers>=4.42.0
2
+ peft>=0.11.1
3
+ accelerate>=0.29.3
4
+ bitsandbytes==0.43.2