File size: 7,529 Bytes
d7d1415 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
"""
Custom Handler for MORBID v0.2.0 Insurance AI
HuggingFace Inference Endpoints - Mistral Small 22B Fine-tuned
"""
from typing import Dict, List, Any
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class EndpointHandler:
    """HuggingFace Inference Endpoints handler for Morbi v0.2.0.

    Loads a fine-tuned Mistral-style causal LM and serves single- or
    batch-text completions through ``__call__`` using the Mistral
    ``[INST] ... [/INST]`` instruct format.
    """

    # Role tags the model must never emit mid-generation (leaked chat scaffolding).
    _ROLE_TAGS = ("Human:", "User:", "Assistant:", "SYSTEM:", "System:")

    # Conversation boundaries at which a generated response is cut off.
    _STOP_MARKERS = (
        "\n[INST]", "[INST]", "</s>", "<s>",
        "\nHuman:", "\nUser:", "\nAssistant:",
    )

    # Hard upper bound on response length, in characters.
    _MAX_RESPONSE_CHARS = 2000

    # Returned when generation produces no usable text.
    _EMPTY_FALLBACK = "I'm here to help! Could you please rephrase your question?"

    def __init__(self, path: str = ""):
        """
        Load tokenizer and model, and precompute generation constraints.

        Args:
            path: Path to the model directory (as provided by the endpoint).
        """
        # bfloat16 on GPU, float32 on CPU (bfloat16 matmuls are slow/absent on CPU).
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=dtype,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
        # Inference only: disable dropout and other train-time behavior.
        self.model.eval()

        # Some causal-LM tokenizers ship without a pad token; generate() needs one.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Loop-invariant: tokenize the banned role tags once at startup
        # instead of once per request input.
        self.bad_words_ids = self._build_bad_words_ids()

        # System prompt for Morbi v0.2.0 (prepended to every user turn).
        self.system_prompt = """You are Morbi, an expert AI assistant specializing in health and life insurance, actuarial science, and risk analysis. You are:
1. KNOWLEDGEABLE: You have deep expertise in:
- Life insurance products (term, whole, universal, variable)
- Health insurance (medical, dental, disability, LTC)
- Actuarial mathematics (mortality tables, interest theory, reserving)
- Underwriting and risk classification
- Claims analysis and management
- Regulatory compliance (state, federal, NAIC)
- ICD-10 medical codes and cause-of-death classification
2. CONVERSATIONAL: You communicate naturally and warmly while maintaining professionalism.
3. ACCURATE: You provide factual, well-reasoned responses. You never make up statistics.
4. HELPFUL: You aim to assist users effectively with actionable information."""

    def _build_bad_words_ids(self) -> List[List[int]]:
        """
        Tokenize the banned role tags into id sequences for ``bad_words_ids``.

        Returns:
            List of non-empty token-id sequences; empty list if tokenization
            fails (the constraint is best-effort and must not break serving).
        """
        bad_words_ids: List[List[int]] = []
        try:
            # input_ids is a list of id-lists, one per tag string.
            tokenized = self.tokenizer(
                list(self._ROLE_TAGS), add_special_tokens=False
            ).input_ids
            for ids in tokenized:
                if isinstance(ids, list) and len(ids) > 0:
                    bad_words_ids.append(ids)
        except Exception:
            # Best-effort only: a missing decoding constraint is preferable
            # to failing every request at startup.
            pass
        return bad_words_ids

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process an inference request.

        Args:
            data: Request payload.
                - inputs (str or list): The input text(s). Any other type is
                  coerced with ``str()``.
                - parameters (dict): Optional generation parameter overrides.

        Returns:
            One dict per input with ``generated_text`` and a ``conversation``
            echo of the user/assistant turn.
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Normalize to a list of strings.
        if isinstance(inputs, str):
            inputs = [inputs]
        elif not isinstance(inputs, list):
            inputs = [str(inputs)]

        # Defaults tuned for Mistral Small 22B; callers may override any of them.
        generation_params = {
            "max_new_tokens": parameters.get("max_new_tokens", 512),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.9),
            "do_sample": parameters.get("do_sample", True),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        # Shared across all inputs in this request.
        decoding_kwargs = {
            **generation_params,
            # Encourage coherence and reduce repetition/artifacts.
            "no_repeat_ngram_size": 3,
        }
        if self.bad_words_ids:
            decoding_kwargs["bad_words_ids"] = self.bad_words_ids

        results = []
        for input_text in inputs:
            prompt = self._format_prompt(input_text)

            # add_special_tokens=False: the prompt already carries a literal
            # "<s>" BOS marker; letting the tokenizer prepend a second BOS
            # degrades Mistral-instruct outputs.
            encoded = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=4096,
                add_special_tokens=False,
            ).to(self.model.device)

            with torch.no_grad():
                outputs = self.model.generate(**encoded, **decoding_kwargs)

            # outputs[0]: single returned sequence (num_return_sequences default).
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Keep only the assistant turn, then trim at any stop marker.
            response = self._extract_response(generated_text, prompt)
            response = self._truncate_at_stops(response)

            results.append({
                "generated_text": response,
                "conversation": {
                    "user": input_text,
                    "assistant": response,
                },
            })
        return results

    def _format_prompt(self, user_input: str) -> str:
        """
        Format the user input into Mistral Instruct format.

        Args:
            user_input: The user's message.

        Returns:
            ``<s>[INST] {system}\\n\\n{user} [/INST]`` prompt string.
        """
        return f"<s>[INST] {self.system_prompt}\n\n{user_input} [/INST]"

    def _extract_response(self, generated_text: str, prompt: str) -> str:
        """
        Extract only the assistant's response from the generated text.

        Args:
            generated_text: Full generated text, usually including the prompt.
            prompt: The original prompt (fallback prefix to strip).

        Returns:
            The assistant's response, never empty (a canned fallback is
            substituted for blank output).
        """
        # In Mistral format the response follows the last "[/INST]".
        if "[/INST]" in generated_text:
            response = generated_text.split("[/INST]")[-1].strip()
        elif generated_text.startswith(prompt):
            response = generated_text[len(prompt):].strip()
        else:
            response = generated_text.strip()

        # Remove any trailing </s> token text that survived decoding.
        response = response.replace("</s>", "").strip()

        return response if response else self._EMPTY_FALLBACK

    def _truncate_at_stops(self, text: str) -> str:
        """Truncate model output at the earliest conversation stop marker."""
        cut_index = None
        for marker in self._STOP_MARKERS:
            idx = text.find(marker)
            if idx != -1:
                cut_index = idx if cut_index is None else min(cut_index, idx)
        if cut_index is not None:
            text = text[:cut_index].rstrip()

        # Keep response reasonably bounded.
        if len(text) > self._MAX_RESPONSE_CHARS:
            text = text[:self._MAX_RESPONSE_CHARS].rstrip()
        return text
|