ongilLabs committed on
Commit
982f6ce
·
verified ·
1 Parent(s): 06b1920

Add custom handler for Inference Endpoints

Browse files
Files changed (1) hide show
  1. handler.py +93 -0
handler.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom handler for Hugging Face Inference Endpoints
3
+ Model: ongilLabs/IB-Math-Instruct-7B
4
+ """
5
+
6
+ from typing import Dict, List, Any
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ import torch
9
+
10
+
11
class EndpointHandler:
    """Custom inference handler for ongilLabs/IB-Math-Instruct-7B.

    Loads the tokenizer and model once at endpoint startup; each request is
    served through ``__call__`` with chat-template formatting and greedy or
    sampled generation.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer and model from *path* and switch to eval mode.

        Args:
            path: Local directory (or hub repo id) containing the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        self.model.eval()

        # Default system prompt used when the request does not supply one.
        self.default_system = """You are an expert IB Mathematics tutor. When solving problems:
1. Show your work step by step
2. Explain your reasoning clearly
3. Use proper mathematical notation
4. Provide the final answer clearly marked"""

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle an inference request.

        Args:
            data: Dictionary with 'inputs' (prompt string or chat message
                list) and optional 'parameters' (max_new_tokens, temperature,
                top_p, system_prompt).

        Returns:
            ``{"generated_text": str}`` on success, ``{"error": str}`` on
            malformed input.
        """
        inputs = data.get("inputs", "")
        # `or {}` also guards against an explicit `"parameters": null` in the
        # JSON payload, which `.get(..., {})` alone would not catch.
        parameters = data.get("parameters") or {}

        # Extract generation parameters with defaults.
        max_new_tokens = parameters.get("max_new_tokens", 1024)
        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.9)
        system_prompt = parameters.get("system_prompt", self.default_system)

        # Normalize the input into a chat message list; None means invalid.
        messages = self._build_messages(inputs, system_prompt)
        if messages is None:
            return {"error": "Invalid input format. Expected string or list of messages."}

        # Apply the model's chat template to produce the raw prompt string.
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        do_sample = temperature > 0
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        # Only pass sampling knobs when sampling: supplying temperature/top_p
        # with do_sample=False has no effect and triggers transformers warnings.
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p

        with torch.no_grad():
            outputs = self.model.generate(**model_inputs, **gen_kwargs)

        # Decode only the newly generated tokens (exclude the prompt).
        response = self.tokenizer.decode(
            outputs[0][model_inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        )

        return {"generated_text": response}

    def _build_messages(self, inputs: Any, system_prompt: str):
        """Normalize *inputs* into a chat message list.

        Returns the message list, or None when *inputs* is neither a string
        nor a list of dicts (previously a malformed list raised
        AttributeError and surfaced as a 500 instead of an error response).
        """
        if isinstance(inputs, str):
            return [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": inputs}
            ]
        if isinstance(inputs, list):
            # Reject lists whose elements are not message dicts.
            if not all(isinstance(m, dict) for m in inputs):
                return None
            messages = list(inputs)
            # Prepend the system prompt if the caller did not include one.
            if messages and messages[0].get("role") != "system":
                messages = [{"role": "system", "content": system_prompt}] + messages
            return messages
        return None