# Source: Hugging Face model repository — custom inference handler
# (uploaded by h3ir, commit d7d1415 "Add custom inference handler")
"""
Custom Handler for MORBID v0.2.0 Insurance AI
HuggingFace Inference Endpoints - Mistral Small 22B Fine-tuned
"""
from typing import Dict, List, Any
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class EndpointHandler:
    """Inference handler for the MORBID v0.2.0 insurance fine-tune of Mistral Small 22B.

    Implements the HuggingFace Inference Endpoints contract:
    ``__init__(path)`` loads the model/tokenizer, ``__call__(data)`` maps a
    request payload to a list of generation results.
    """

    # Role-tag strings the model must never emit — blocks hallucinated
    # multi-turn transcripts ("Human: ..." continuations) in generations.
    _ROLE_TOKENS = ("Human:", "User:", "Assistant:", "SYSTEM:", "System:")

    def __init__(self, path: str = ""):
        """
        Initialize the handler with model and tokenizer.

        Args:
            path: Path to the model directory (as provided by the endpoint runtime).
        """
        # bf16 on GPU, fp32 on CPU (bf16 matmuls are slow/unsupported on most CPUs)
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=dtype,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
        # Set padding token if not already set (Mistral tokenizers ship without one)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Precompute once: these id sequences are invariant across requests.
        # (Previously rebuilt per input inside the generation loop.)
        self.bad_words_ids: List[List[int]] = self._build_bad_words_ids()
        # System prompt for Morbi v0.2.0
        self.system_prompt = """You are Morbi, an expert AI assistant specializing in health and life insurance, actuarial science, and risk analysis. You are:
1. KNOWLEDGEABLE: You have deep expertise in:
- Life insurance products (term, whole, universal, variable)
- Health insurance (medical, dental, disability, LTC)
- Actuarial mathematics (mortality tables, interest theory, reserving)
- Underwriting and risk classification
- Claims analysis and management
- Regulatory compliance (state, federal, NAIC)
- ICD-10 medical codes and cause-of-death classification
2. CONVERSATIONAL: You communicate naturally and warmly while maintaining professionalism.
3. ACCURATE: You provide factual, well-reasoned responses. You never make up statistics.
4. HELPFUL: You aim to assist users effectively with actionable information."""

    def _build_bad_words_ids(self) -> List[List[int]]:
        """Tokenize the role tags into id sequences for ``generate(bad_words_ids=...)``.

        Best-effort: a tokenizer that cannot handle the batch simply yields an
        empty list, in which case no constraint is applied.
        """
        bad_words_ids: List[List[int]] = []
        try:
            tokenized = self.tokenizer(
                list(self._ROLE_TOKENS), add_special_tokens=False
            ).input_ids
            # input_ids is one list of ids per tokenized string
            for ids in tokenized:
                if isinstance(ids, list) and ids:
                    bad_words_ids.append(ids)
        except Exception:
            # Deliberate best-effort: constraint is an enhancement, not a requirement.
            pass
        return bad_words_ids

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request.

        Args:
            data: Dictionary containing the input data
                - inputs (str or list): The input text(s)
                - parameters (dict): Generation parameters

        Returns:
            List of dicts, one per input, each with ``generated_text`` and a
            ``conversation`` record of the user/assistant turn.
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})
        # Normalize to a list of strings
        if isinstance(inputs, str):
            inputs = [inputs]
        elif not isinstance(inputs, list):
            inputs = [str(inputs)]

        # Default generation parameters (tuned for Mistral Small 22B)
        generation_params = {
            "max_new_tokens": parameters.get("max_new_tokens", 512),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.9),
            "do_sample": parameters.get("do_sample", True),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        # Shared decoding kwargs — invariant across inputs, built once per request
        decoding_kwargs = {
            **generation_params,
            # Encourage coherence and reduce repetition/artifacts
            "no_repeat_ngram_size": 3,
        }
        if self.bad_words_ids:
            decoding_kwargs["bad_words_ids"] = self.bad_words_ids

        results = []
        for input_text in inputs:
            prompt = self._format_prompt(input_text)
            encoded = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=4096,
            ).to(self.model.device)

            with torch.no_grad():
                outputs = self.model.generate(**encoded, **decoding_kwargs)

            # Decode only the newly generated tokens: slicing at the token
            # level is robust, whereas string-matching the decoded prompt is
            # not (skip_special_tokens strips the leading <s>, so the decoded
            # text never starts with the original prompt string).
            prompt_len = encoded["input_ids"].shape[-1]
            generated_text = self.tokenizer.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            )
            # Secondary cleanup: strip any residual template markers and trim
            response = self._extract_response(generated_text, prompt)
            response = self._truncate_at_stops(response)
            results.append({
                "generated_text": response,
                "conversation": {
                    "user": input_text,
                    "assistant": response,
                },
            })
        return results

    def _format_prompt(self, user_input: str) -> str:
        """
        Format the user input into Mistral Instruct format.

        Args:
            user_input: The user's message

        Returns:
            Formatted prompt string: ``<s>[INST] system\\n\\nuser [/INST]``
        """
        return f"<s>[INST] {self.system_prompt}\n\n{user_input} [/INST]"

    def _extract_response(self, generated_text: str, prompt: str) -> str:
        """
        Extract only the assistant's response from the generated text.

        Args:
            generated_text: Generated text (may still contain template markers)
            prompt: The original prompt

        Returns:
            Just the assistant's response (never empty — a fallback is returned).
        """
        # For Mistral format, the response follows the last [/INST]
        if "[/INST]" in generated_text:
            response = generated_text.split("[/INST]")[-1].strip()
        elif generated_text.startswith(prompt):
            response = generated_text[len(prompt):].strip()
        else:
            response = generated_text.strip()
        # Remove any trailing </s> token
        response = response.replace("</s>", "").strip()
        # Never return an empty answer to the client
        if not response:
            response = "I'm here to help! Could you please rephrase your question?"
        return response

    def _truncate_at_stops(self, text: str) -> str:
        """Truncate model output at conversation stop markers and bound its length."""
        # Mistral stop markers plus role-tag leakage
        stop_markers = [
            "\n[INST]", "[INST]", "</s>", "<s>",
            "\nHuman:", "\nUser:", "\nAssistant:",
        ]
        # Cut at the earliest marker occurrence, if any
        cut_index = None
        for marker in stop_markers:
            idx = text.find(marker)
            if idx != -1:
                cut_index = idx if cut_index is None else min(cut_index, idx)
        if cut_index is not None:
            text = text[:cut_index].rstrip()
        # Keep response reasonably bounded
        if len(text) > 2000:
            text = text[:2000].rstrip()
        return text