""" Custom Handler for MORBID v0.2.0 Insurance AI HuggingFace Inference Endpoints - Mistral Small 22B Fine-tuned """ from typing import Dict, List, Any import os import torch from transformers import AutoModelForCausalLM, AutoTokenizer class EndpointHandler: def __init__(self, path: str = ""): """ Initialize the handler with model and tokenizer Args: path: Path to the model directory """ # Load tokenizer and model dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) self.model = AutoModelForCausalLM.from_pretrained( path, torch_dtype=dtype, device_map="auto", low_cpu_mem_usage=True ) # Set padding token if not already set if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token # System prompt for Morbi v0.2.0 self.system_prompt = """You are Morbi, an expert AI assistant specializing in health and life insurance, actuarial science, and risk analysis. You are: 1. KNOWLEDGEABLE: You have deep expertise in: - Life insurance products (term, whole, universal, variable) - Health insurance (medical, dental, disability, LTC) - Actuarial mathematics (mortality tables, interest theory, reserving) - Underwriting and risk classification - Claims analysis and management - Regulatory compliance (state, federal, NAIC) - ICD-10 medical codes and cause-of-death classification 2. CONVERSATIONAL: You communicate naturally and warmly while maintaining professionalism. 3. ACCURATE: You provide factual, well-reasoned responses. You never make up statistics. 4. HELPFUL: You aim to assist users effectively with actionable information.""" def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: """ Process the inference request Args: data: Dictionary containing the input data - inputs (str or list): The input text(s) - parameters (dict): Generation parameters Returns: List of generated responses """ # Extract inputs inputs = data.get("inputs", "") parameters = data.get("parameters", {}) # Handle both string and list inputs if isinstance(inputs, str): inputs = [inputs] elif not isinstance(inputs, list): inputs = [str(inputs)] # Set default generation parameters (optimized for Mistral Small 22B) generation_params = { "max_new_tokens": parameters.get("max_new_tokens", 512), "temperature": parameters.get("temperature", 0.7), "top_p": parameters.get("top_p", 0.9), "do_sample": parameters.get("do_sample", True), "repetition_penalty": parameters.get("repetition_penalty", 1.1), "pad_token_id": self.tokenizer.pad_token_id, "eos_token_id": self.tokenizer.eos_token_id, } # Process each input results = [] for input_text in inputs: # Format the prompt with conversational context prompt = self._format_prompt(input_text) # Tokenize inputs_tokenized = self.tokenizer( prompt, return_tensors="pt", padding=True, truncation=True, max_length=4096 ).to(self.model.device) # Generate response # Prepare additional decoding constraints bad_words_ids = [] try: # Disallow role-tag leakage in generations role_tokens = ["Human:", "User:", "Assistant:", "SYSTEM:", "System:"] tokenized = self.tokenizer(role_tokens, add_special_tokens=False).input_ids # input_ids can be nested lists (one per tokenized string) for ids in tokenized: if isinstance(ids, list) and len(ids) > 0: bad_words_ids.append(ids) except Exception: pass decoding_kwargs = { **generation_params, # Encourage coherence and reduce repetition/artifacts "no_repeat_ngram_size": 3, } if bad_words_ids: decoding_kwargs["bad_words_ids"] = bad_words_ids with torch.no_grad(): outputs = self.model.generate( **inputs_tokenized, **decoding_kwargs ) # Decode the response generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) # Extract only the assistant's response and trim at stop sequences response = self._extract_response(generated_text, prompt) response = self._truncate_at_stops(response) results.append({ "generated_text": response, "conversation": { "user": input_text, "assistant": response } }) return results def _format_prompt(self, user_input: str) -> str: """ Format the user input into Mistral Instruct format Args: user_input: The user's message Returns: Formatted prompt string in Mistral format """ # Mistral Instruct format: [INST] system\n\nuser [/INST] return f"[INST] {self.system_prompt}\n\n{user_input} [/INST]" def _extract_response(self, generated_text: str, prompt: str) -> str: """ Extract only the assistant's response from the generated text Args: generated_text: Full generated text including prompt prompt: The original prompt Returns: Just the assistant's response """ # For Mistral format, response comes after [/INST] if "[/INST]" in generated_text: response = generated_text.split("[/INST]")[-1].strip() elif generated_text.startswith(prompt): response = generated_text[len(prompt):].strip() else: response = generated_text.strip() # Remove any trailing token response = response.replace("", "").strip() # Ensure we have a response if not response: response = "I'm here to help! Could you please rephrase your question?" return response def _truncate_at_stops(self, text: str) -> str: """Truncate model output at conversation stop markers.""" # Mistral stop markers stop_markers = [ "\n[INST]", "[INST]", "", "", "\nHuman:", "\nUser:", "\nAssistant:", ] cut_index = None for marker in stop_markers: idx = text.find(marker) if idx != -1: cut_index = idx if cut_index is None else min(cut_index, idx) if cut_index is not None: text = text[:cut_index].rstrip() # Keep response reasonably bounded if len(text) > 2000: text = text[:2000].rstrip() return text