h3ir committed on
Commit
6d80aa7
·
verified ·
1 Parent(s): 9f53f62

Add custom handler for inference endpoints

Browse files
Files changed (1) hide show
  1. handler.py +176 -0
handler.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Handler for MORBID-Actuarial v0.1.0 Conversational Model
3
+ Hugging Face Inference Endpoints
4
+ """
5
+
6
+ from typing import Dict, List, Any
7
+ import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
+
10
+
11
class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler for the
    MORBID-Actuarial conversational model.

    Implements the endpoint contract: ``__init__(path)`` loads the model and
    tokenizer, ``__call__(data)`` generates a response for each input text
    and returns a list of result dicts.
    """

    # Greeting words are matched as whole words. The previous substring check
    # fired on e.g. "hi" inside "this"/"higher" or "hey" inside "they", which
    # misrouted actuarial questions to the casual-conversation prompt.
    _GREETING_WORDS = frozenset({"hi", "hello", "hey", "howdy"})

    # Keywords that route a longer input to the precise actuarial prompt.
    _ACTUARIAL_KEYWORDS = (
        "mortality", "life expectancy", "insurance", "premium", "annuity",
        "probability", "risk", "actuarial", "death", "survival",
    )

    def __init__(self, path: str = ""):
        """
        Initialize the handler with model and tokenizer.

        Args:
            path: Path to the model directory.
        """
        # Load tokenizer and model (fp16, auto device placement).
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
        )

        # Generation needs a pad token; fall back to EOS when none is set.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # System prompt for conversational behavior.
        self.system_prompt = """You are MORBID.AI, a friendly and conversational actuarial assistant.
You have expertise in:
- Life expectancy and mortality statistics
- Insurance and risk calculations
- Financial mathematics (FM exam - 100% accuracy)
- Probability theory (P exam - 100% accuracy)
- Investment and financial markets (IFM exam - 93.3% accuracy)

Be warm, helpful, and engaging. Respond naturally to greetings and casual conversation while maintaining your actuarial expertise.
When users greet you, respond warmly. When they ask for help, be supportive and clear.
Balance personality with precision when discussing technical topics."""

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process an inference request.

        Args:
            data: Request payload.
                - inputs (str or list): the input text(s)
                - parameters (dict): optional generation parameters

        Returns:
            One dict per input with "generated_text" and a "conversation"
            dict holding the user/assistant pair.
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Normalize inputs to a list of strings.
        if isinstance(inputs, str):
            inputs = [inputs]
        elif not isinstance(inputs, list):
            inputs = [str(inputs)]

        # Caller-supplied parameters override these defaults.
        generation_params = {
            "max_new_tokens": parameters.get("max_new_tokens", 200),
            "temperature": parameters.get("temperature", 0.8),
            "top_p": parameters.get("top_p", 0.95),
            "do_sample": parameters.get("do_sample", True),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        results = []
        for input_text in inputs:
            # Format the prompt with conversational context.
            prompt = self._format_prompt(input_text)

            encoded = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(self.model.device)

            with torch.no_grad():
                outputs = self.model.generate(**encoded, **generation_params)

            # Decode only the newly generated tokens. Decoding the full
            # sequence and string-stripping the prompt is fragile because
            # skip_special_tokens can change the decoded prompt prefix.
            prompt_len = encoded["input_ids"].shape[1]
            generated_text = self.tokenizer.decode(
                outputs[0][prompt_len:],
                skip_special_tokens=True,
            )

            # Clean up any echoed markers / empty output.
            response = self._extract_response(generated_text, prompt)

            results.append({
                "generated_text": response,
                "conversation": {
                    "user": input_text,
                    "assistant": response,
                },
            })

        return results

    def _format_prompt(self, user_input: str) -> str:
        """
        Format the user input into a conversational prompt.

        Args:
            user_input: The user's message.

        Returns:
            Formatted prompt string ending in "Assistant: ".
        """
        lower_input = user_input.lower().strip()
        # Whole-word tokens, with surrounding punctuation stripped, so that
        # e.g. "hello!" matches but "this" does not match "hi".
        words = {w.strip(".,!?;:'\"") for w in lower_input.split()}

        # Very short inputs and greetings get the full conversational prompt.
        if len(lower_input) <= 20 or words & self._GREETING_WORDS:
            return f"{self.system_prompt}\n\nHuman: {user_input}\nAssistant: "

        # Actuarial query - be precise but friendly.
        if any(keyword in lower_input for keyword in self._ACTUARIAL_KEYWORDS):
            return f"As a conversational actuarial AI assistant, provide a helpful and accurate response.\n\nHuman: {user_input}\nAssistant: "

        # General conversation - be more casual.
        return f"{self.system_prompt}\n\nHuman: {user_input}\nAssistant: "

    def _extract_response(self, generated_text: str, prompt: str) -> str:
        """
        Extract only the assistant's response from generated text.

        Args:
            generated_text: Generated text (may still echo prompt/markers).
            prompt: The original prompt.

        Returns:
            Just the assistant's response, never empty.
        """
        # Remove the prompt if it was echoed at the beginning.
        if generated_text.startswith(prompt):
            response = generated_text[len(prompt):].strip()
        elif "Assistant:" in generated_text:
            # Fall back to the last "Assistant:" marker.
            response = generated_text.split("Assistant:")[-1].strip()
        else:
            response = generated_text.strip()

        # Clean up a stray leading colon from a partially echoed marker.
        if response.startswith(":"):
            response = response[1:].strip()

        # Never return an empty response.
        if not response:
            response = "I'm here to help! Could you please rephrase your question?"

        return response