File size: 7,529 Bytes
d7d1415
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""
Custom Handler for MORBID v0.2.0 Insurance AI
HuggingFace Inference Endpoints - Mistral Small 22B Fine-tuned
"""

from typing import Dict, List, Any
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    """HuggingFace Inference Endpoints handler for the MORBID insurance model.

    Wraps a fine-tuned Mistral-Instruct causal LM: formats incoming text into
    the Mistral ``[INST]`` prompt template, generates with tuned defaults, and
    post-processes the raw output (assistant-response extraction plus
    stop-marker truncation).
    """

    # Role tags that must never leak into generations; turned into
    # `bad_words_ids` sequences once at load time.
    _ROLE_TAGS = ["Human:", "User:", "Assistant:", "SYSTEM:", "System:"]

    def __init__(self, path: str = ""):
        """
        Load tokenizer and model, then precompute request-invariant settings.

        Args:
            path: Path (or hub id) of the model directory.
        """
        # bfloat16 on GPU for memory/throughput; fall back to fp32 on CPU.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=dtype,
            device_map="auto",
            low_cpu_mem_usage=True,
        )

        # generate() needs a pad token; reuse EOS when the checkpoint has none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Hoisted out of the per-request loop: these token-id sequences only
        # depend on the tokenizer, so compute them exactly once.
        self.bad_words_ids: List[List[int]] = self._compute_bad_words_ids()

        # System prompt for Morbi v0.2.0
        self.system_prompt = """You are Morbi, an expert AI assistant specializing in health and life insurance, actuarial science, and risk analysis. You are:

1. KNOWLEDGEABLE: You have deep expertise in:
   - Life insurance products (term, whole, universal, variable)
   - Health insurance (medical, dental, disability, LTC)
   - Actuarial mathematics (mortality tables, interest theory, reserving)
   - Underwriting and risk classification
   - Claims analysis and management
   - Regulatory compliance (state, federal, NAIC)
   - ICD-10 medical codes and cause-of-death classification

2. CONVERSATIONAL: You communicate naturally and warmly while maintaining professionalism.

3. ACCURATE: You provide factual, well-reasoned responses. You never make up statistics.

4. HELPFUL: You aim to assist users effectively with actionable information."""

    def _compute_bad_words_ids(self) -> List[List[int]]:
        """Tokenize the forbidden role tags into id sequences for generation.

        Returns:
            A list of token-id lists suitable for ``generate(bad_words_ids=...)``;
            empty when the tokenizer cannot batch-encode the tags.
        """
        bad_words_ids: List[List[int]] = []
        try:
            # Batch tokenization returns one id-list per input string.
            tokenized = self.tokenizer(self._ROLE_TAGS, add_special_tokens=False).input_ids
            for ids in tokenized:
                if isinstance(ids, list) and len(ids) > 0:
                    bad_words_ids.append(ids)
        except Exception:
            # Best effort: an incompatible tokenizer simply skips the constraint
            # rather than failing the whole endpoint at startup.
            pass
        return bad_words_ids

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process an inference request.

        Args:
            data: Request payload.
                - inputs (str or list): The input text(s). Any other type is
                  coerced with ``str()``.
                - parameters (dict): Optional generation-parameter overrides.

        Returns:
            One result dict per input, each with ``generated_text`` and a
            ``conversation`` {user, assistant} echo.
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Normalize to a list of strings so the loop below is uniform.
        if isinstance(inputs, str):
            inputs = [inputs]
        elif not isinstance(inputs, list):
            inputs = [str(inputs)]

        # Defaults tuned for Mistral Small 22B; any key may be overridden
        # per-request via `parameters`.
        generation_params = {
            "max_new_tokens": parameters.get("max_new_tokens", 512),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.9),
            "do_sample": parameters.get("do_sample", True),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        # Decoding kwargs are identical for every input in the batch, so build
        # them once before the loop (previously rebuilt per input).
        decoding_kwargs = {
            **generation_params,
            # Encourage coherence and reduce repetition/artifacts
            "no_repeat_ngram_size": 3,
        }
        if self.bad_words_ids:
            decoding_kwargs["bad_words_ids"] = self.bad_words_ids

        results = []
        for input_text in inputs:
            # Wrap the user message in the Mistral Instruct template.
            prompt = self._format_prompt(input_text)

            # Truncate at 4096 tokens to stay within the model context window.
            inputs_tokenized = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=4096
            ).to(self.model.device)

            # Inference only — no autograd bookkeeping.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs_tokenized,
                    **decoding_kwargs
                )

            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Keep only the assistant turn, then cut at any stop sequence.
            response = self._extract_response(generated_text, prompt)
            response = self._truncate_at_stops(response)

            results.append({
                "generated_text": response,
                "conversation": {
                    "user": input_text,
                    "assistant": response
                }
            })

        return results

    def _format_prompt(self, user_input: str) -> str:
        """
        Format the user input into Mistral Instruct format.

        Args:
            user_input: The user's message.

        Returns:
            Formatted prompt string in Mistral format.
        """
        # Mistral Instruct format: <s>[INST] system\n\nuser [/INST]
        return f"<s>[INST] {self.system_prompt}\n\n{user_input} [/INST]"

    def _extract_response(self, generated_text: str, prompt: str) -> str:
        """
        Extract only the assistant's response from the generated text.

        Args:
            generated_text: Full generated text, usually including the prompt.
            prompt: The original prompt (fallback prefix to strip).

        Returns:
            Just the assistant's response; a canned fallback if it is empty.
        """
        # Mistral places the answer after the last [/INST]; otherwise strip the
        # echoed prompt prefix; otherwise take the text as-is.
        if "[/INST]" in generated_text:
            response = generated_text.split("[/INST]")[-1].strip()
        elif generated_text.startswith(prompt):
            response = generated_text[len(prompt):].strip()
        else:
            response = generated_text.strip()

        # Remove any trailing </s> token the decoder may have left in.
        response = response.replace("</s>", "").strip()

        # Never return an empty string to the caller.
        if not response:
            response = "I'm here to help! Could you please rephrase your question?"

        return response

    def _truncate_at_stops(self, text: str) -> str:
        """Truncate model output at conversation stop markers.

        Cuts at the EARLIEST occurrence of any marker, then caps the result at
        2000 characters to keep responses bounded.
        """
        # Mistral template tokens plus role-tag hallucinations.
        stop_markers = [
            "\n[INST]", "[INST]", "</s>", "<s>",
            "\nHuman:", "\nUser:", "\nAssistant:",
        ]
        cut_index = None
        for marker in stop_markers:
            idx = text.find(marker)
            if idx != -1:
                cut_index = idx if cut_index is None else min(cut_index, idx)
        if cut_index is not None:
            text = text[:cut_index].rstrip()
        # Keep response reasonably bounded
        if len(text) > 2000:
            text = text[:2000].rstrip()
        return text