campedersen committed
Commit f8de50e · verified · 1 Parent(s): 862f58d

Add custom handler for Inference Endpoints

Files changed (1): handler.py (+90, -0)
handler.py ADDED
@@ -0,0 +1,90 @@
"""
Custom handler for the cad0 Hugging Face Inference Endpoint.

This loads the Qwen2.5-Coder-7B-Instruct base model with the cad0 LoRA adapter.
Upload this file to the campedersen/cad0 model repo.
"""

from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


class EndpointHandler:
    def __init__(self, path: str = ""):
        """Load the model and tokenizer."""
        # Base model that cad0 was fine-tuned from
        base_model = "Qwen/Qwen2.5-Coder-7B-Instruct"

        # Load the tokenizer from the base model
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_model,
            trust_remote_code=True,
        )

        # 4-bit quantization config for memory-efficient inference
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        )

        # Load the fine-tuned model; `path` points to the model repo.
        # If the repo holds only a LoRA adapter, this relies on the
        # transformers/PEFT integration (peft must be installed) to resolve
        # and load the base model automatically.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            quantization_config=bnb_config,
            trust_remote_code=True,
            device_map="auto",
            low_cpu_mem_usage=True,
        )

        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle an inference request.

        Expected input format:
        {
            "inputs": "prompt text or chat-formatted text",
            "parameters": {
                "max_new_tokens": 256,
                "temperature": 0.1,
                "do_sample": true,
                "return_full_text": false
            }
        }
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Default generation parameters; a temperature of 0 means greedy decoding
        max_new_tokens = parameters.get("max_new_tokens", 256)
        temperature = parameters.get("temperature", 0.1)
        do_sample = parameters.get("do_sample", temperature > 0)
        return_full_text = parameters.get("return_full_text", False)

        # Tokenize the prompt and move it to the model's device
        encoded = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
        input_length = encoded.input_ids.shape[1]

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                temperature=temperature if temperature > 0 else 1.0,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode, dropping the prompt tokens unless return_full_text is set
        if return_full_text:
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        else:
            generated_text = self.tokenizer.decode(
                outputs[0][input_length:],
                skip_special_tokens=True,
            )

        return {"generated_text": generated_text}
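For reference, a minimal local smoke test of this handler (not part of the commit), assuming handler.py sits in the working directory, that transformers, peft, bitsandbytes, and accelerate are installed, and that a CUDA GPU is available; the prompt text is purely illustrative:

# Instantiate the handler with the campedersen/cad0 repo id and send one
# request in the payload format documented in the __call__ docstring.
from handler import EndpointHandler

handler = EndpointHandler(path="campedersen/cad0")
response = handler({
    "inputs": "Write an OpenSCAD module for a 20mm cube with a 5mm hole.",  # placeholder prompt
    "parameters": {
        "max_new_tokens": 256,
        "temperature": 0.1,
        "do_sample": True,
        "return_full_text": False,
    },
})
print(response["generated_text"])

Once the file is uploaded to the repo, an Inference Endpoint created from it should pick up the EndpointHandler class and route POST requests with the same JSON body to __call__.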