"""
Custom Handler for MORBID-Actuarial v0.1.0 Conversational Model
Hugging Face Inference Endpoints
"""

from typing import Dict, List, Any
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the handler with model and tokenizer
        
        Args:
            path: Path to the model directory
        """
        # Load tokenizer and model
        # Some repos may have a non-standard model_type. In that case, fall back to a known base model.
        fallback_model_id = os.getenv("BASE_MODEL_ID", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
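        # Use fp16 on GPU; fall back to the framework's default dtype on CPU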
        dtype = torch.float16 if torch.cuda.is_available() else None
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=dtype,
                device_map="auto",
                low_cpu_mem_usage=True
            )
        except Exception:
            # Fallback to a supported base model
            self.tokenizer = AutoTokenizer.from_pretrained(fallback_model_id, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                fallback_model_id,
                torch_dtype=dtype,
                device_map="auto",
                low_cpu_mem_usage=True
            )
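
            # Note: the fallback target is configurable via the BASE_MODEL_ID
            # environment variable, e.g. (illustrative):
            #   BASE_MODEL_ID=mistralai/Mistral-7B-Instruct-v0.2
            # Any causal LM with a standard model_type should work here.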
        
        # Set padding token if not already set
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # System prompt for conversational behavior
        self.system_prompt = """You are MORBID.AI, a friendly and conversational actuarial assistant. 
You have expertise in:
- Life expectancy and mortality statistics
- Insurance and risk calculations
- Financial mathematics (FM exam - 100% accuracy)
- Probability theory (P exam - 100% accuracy)
- Investment and financial markets (IFM exam - 93.3% accuracy)

Be warm, helpful, and engaging. Respond naturally to greetings and casual conversation while maintaining your actuarial expertise.
When users greet you, respond warmly. When they ask for help, be supportive and clear.
Balance personality with precision when discussing technical topics."""

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request
        
        Args:
            data: Dictionary containing the input data
                - inputs (str or list): The input text(s)
                - parameters (dict): Generation parameters
        
        Returns:
            List of generated responses
        """
        # Extract inputs
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})
        
        # Handle both string and list inputs
        if isinstance(inputs, str):
            inputs = [inputs]
        elif not isinstance(inputs, list):
            inputs = [str(inputs)]
        
        # Set default generation parameters
        generation_params = {
            "max_new_tokens": parameters.get("max_new_tokens", 200),
            "temperature": parameters.get("temperature", 0.8),
            "top_p": parameters.get("top_p", 0.95),
            "do_sample": parameters.get("do_sample", True),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
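
        # Any of the defaults above can be overridden per request, e.g.
        # (illustrative): {"parameters": {"temperature": 0.3, "do_sample": False}}
        # for more deterministic answers.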
        
        # Process each input
        results = []
        for input_text in inputs:
            # Format the prompt with conversational context
            prompt = self._format_prompt(input_text)
            
            # Tokenize
            inputs_tokenized = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.model.device)
            
            # Generate response
            # Prepare additional decoding constraints
            bad_words_ids = []
            try:
                # Disallow role-tag leakage in generations
                role_tokens = ["Human:", "User:", "Assistant:", "SYSTEM:", "System:"]
                tokenized = self.tokenizer(role_tokens, add_special_tokens=False).input_ids
                # input_ids can be nested lists (one per tokenized string)
                for ids in tokenized:
                    if isinstance(ids, list) and len(ids) > 0:
                        bad_words_ids.append(ids)
            except Exception:
                pass

            decoding_kwargs = {
                **generation_params,
                # Encourage coherence and reduce repetition/artifacts
                "no_repeat_ngram_size": 3,
            }
            if bad_words_ids:
                decoding_kwargs["bad_words_ids"] = bad_words_ids

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs_tokenized,
                    **decoding_kwargs
                )
            
            # Decode the response
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract only the assistant's response and trim at stop sequences
            response = self._extract_response(generated_text, prompt)
            response = self._truncate_at_stops(response)
            
            results.append({
                "generated_text": response,
                "conversation": {
                    "user": input_text,
                    "assistant": response
                }
            })
        
        return results
    
    def _format_prompt(self, user_input: str) -> str:
        """
        Format the user input into a conversational prompt
        
        Args:
            user_input: The user's message
            
        Returns:
            Formatted prompt string
        """
        # Check if it's a greeting or casual message
        lower_input = user_input.lower().strip()

        # For very short inputs or greetings, add conversational context.
        # Greetings are matched on whole words so that e.g. "hi" inside
        # "this" or "which" does not trigger the greeting path.
        greeting_words = {"hi", "hello", "hey", "howdy"}
        input_words = {word.strip("!?.,") for word in lower_input.split()}
        if len(lower_input) <= 20 or input_words & greeting_words:
            return f"{self.system_prompt}\n\nHuman: {user_input}\nAssistant: "
        
        # For longer inputs, check if they're actuarial
        actuarial_keywords = ["mortality", "life expectancy", "insurance", "premium", "annuity", 
                             "probability", "risk", "actuarial", "death", "survival"]
        
        if any(keyword in lower_input for keyword in actuarial_keywords):
            # Actuarial query - be precise but friendly
            return f"As a conversational actuarial AI assistant, provide a helpful and accurate response.\n\nHuman: {user_input}\nAssistant: "
        else:
            # General conversation - be more casual
            return f"{self.system_prompt}\n\nHuman: {user_input}\nAssistant: "
    
    def _extract_response(self, generated_text: str, prompt: str) -> str:
        """
        Extract only the assistant's response from the generated text
        
        Args:
            generated_text: Full generated text including prompt
            prompt: The original prompt
            
        Returns:
            Just the assistant's response
        """
        # Strategy: take everything after the LAST "Assistant:" marker; fallback to stripping prompt
        if "Assistant:" in generated_text:
            response = generated_text.split("Assistant:")[-1].strip()
        elif generated_text.startswith(prompt):
            response = generated_text[len(prompt):].strip()
        else:
            response = generated_text.strip()
        
        # Clean up any remaining markers
        if response.startswith(":"):
            response = response[1:].strip()
        
        # Ensure we have a response
        if not response:
            response = "I'm here to help! Could you please rephrase your question?"
        
        return response

    def _truncate_at_stops(self, text: str) -> str:
        """Truncate model output at conversation stop markers to avoid echoing future turns."""
        stop_markers = ["\nHuman:", "\nUser:", "\nAssistant:", "\nSYSTEM:", "\nSystem:"]
        cut_index = None
        for marker in stop_markers:
            idx = text.find(marker)
            if idx != -1:
                cut_index = idx if cut_index is None else min(cut_index, idx)
        if cut_index is not None:
            text = text[:cut_index].rstrip()
        # Keep response reasonably bounded
        if len(text) > 2000:
            text = text[:2000].rstrip()
        return text
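

# Minimal local smoke test -- a sketch for debugging outside Inference
# Endpoints. Assumes the model (or the fallback) can be loaded from the
# current directory; adjust the path as needed.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    result = handler({"inputs": "Hi there!", "parameters": {"max_new_tokens": 50}})
    print(result[0]["generated_text"])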