# Author: h3ir
# Commit 9646f87 (verified): Add custom handler for inference endpoints
"""
Custom Handler for MORBID-Actuarial v0.1.0 Conversational Model
Hugging Face Inference Endpoints
"""
import logging
import os
import re
from typing import Dict, List, Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for the MORBID-Actuarial model.

    Loads a causal language model and its tokenizer once at start-up. For each
    request it wraps every input in a conversational prompt, generates a reply,
    and returns only the assistant's part of the generated text.
    """

    # Whole words that mark an input as a greeting / casual opener.
    _GREETING_WORDS = frozenset({"hi", "hello", "hey", "howdy"})
    # Substrings that route a (non-casual) input to the actuarial prompt.
    _ACTUARIAL_KEYWORDS = (
        "mortality", "life expectancy", "insurance", "premium", "annuity",
        "probability", "risk", "actuarial", "death", "survival",
    )
    # Role tags the model is discouraged from emitting via ``bad_words_ids``.
    _ROLE_TAGS = ("Human:", "User:", "Assistant:", "SYSTEM:", "System:")
    # Markers that begin a new conversation turn; output is cut at the first one.
    _STOP_MARKERS = ("\nHuman:", "\nUser:", "\nAssistant:", "\nSYSTEM:", "\nSystem:")
    # Upper bound (in characters) on a returned response.
    _MAX_RESPONSE_CHARS = 2000
    # Upper bound (in tokens) on the tokenized prompt.
    _MAX_PROMPT_TOKENS = 512
    # Returned when the model produced nothing usable.
    _FALLBACK_RESPONSE = "I'm here to help! Could you please rephrase your question?"

    def __init__(self, path: str = ""):
        """
        Initialize the handler with model and tokenizer.

        Args:
            path: Path to the model directory. If loading from it fails for any
                reason, the model named by the ``BASE_MODEL_ID`` environment
                variable (default ``TinyLlama/TinyLlama-1.1B-Chat-v1.0``) is
                loaded instead and the failure is logged.
        """
        fallback_model_id = os.getenv("BASE_MODEL_ID", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        # Half precision only when a GPU is present; None lets transformers pick.
        dtype = torch.float16 if torch.cuda.is_available() else None
        try:
            self.tokenizer, self.model = self._load_model_and_tokenizer(path, dtype)
        except Exception:
            # Deliberately broad: some repos use a non-standard model_type (or
            # are otherwise unloadable), so fall back to a known base model.
            # Log it so a silent model swap is visible in the endpoint logs.
            logging.getLogger(__name__).warning(
                "Could not load model from %r; falling back to %s",
                path,
                fallback_model_id,
                exc_info=True,
            )
            self.tokenizer, self.model = self._load_model_and_tokenizer(fallback_model_id, dtype)
        # Set padding token if not already set
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Over-long prompts are truncated from the LEFT so the trailing
        # "Assistant: " cue, which tells the model to start answering, is kept.
        self.tokenizer.truncation_side = "left"
        # System prompt for conversational behavior
        self.system_prompt = """You are MORBID.AI, a friendly and conversational actuarial assistant.
You have expertise in:
- Life expectancy and mortality statistics
- Insurance and risk calculations
- Financial mathematics (FM exam - 100% accuracy)
- Probability theory (P exam - 100% accuracy)
- Investment and financial markets (IFM exam - 93.3% accuracy)
Be warm, helpful, and engaging. Respond naturally to greetings and casual conversation while maintaining your actuarial expertise.
When users greet you, respond warmly. When they ask for help, be supportive and clear.
Balance personality with precision when discussing technical topics."""

    @staticmethod
    def _load_model_and_tokenizer(model_id: str, dtype: Any):
        """Load and return ``(tokenizer, model)`` for ``model_id``."""
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
        return tokenizer, model

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request.

        Args:
            data: Dictionary containing the input data
                - inputs (str or list): The input text(s); other values are
                  converted with ``str()``, and ``None`` is treated as "".
                - parameters (dict, optional): Generation parameters
                  (max_new_tokens, temperature, top_p, do_sample,
                  repetition_penalty, no_repeat_ngram_size).

        Returns:
            One dict per input with keys "generated_text" and "conversation".
        """
        inputs = data.get("inputs", "")
        # A JSON null for "parameters" behaves like "not provided".
        parameters = data.get("parameters") or {}

        # Normalize to a list of strings.
        if inputs is None:
            inputs = [""]
        elif isinstance(inputs, str):
            inputs = [inputs]
        elif isinstance(inputs, list):
            inputs = [item if isinstance(item, str) else str(item) for item in inputs]
        else:
            inputs = [str(inputs)]

        # Decoding settings are identical for every input, so build them once.
        decoding_kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 200),
            "temperature": parameters.get("temperature", 0.8),
            "top_p": parameters.get("top_p", 0.95),
            "do_sample": parameters.get("do_sample", True),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
            # Encourage coherence and reduce repetition/artifacts.
            "no_repeat_ngram_size": parameters.get("no_repeat_ngram_size", 3),
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        bad_words_ids = self._role_tag_bad_words_ids()
        if bad_words_ids:
            decoding_kwargs["bad_words_ids"] = bad_words_ids

        results = []
        for input_text in inputs:
            prompt = self._format_prompt(input_text)
            encoded = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self._MAX_PROMPT_TOKENS,
            ).to(self.model.device)
            with torch.no_grad():
                outputs = self.model.generate(**encoded, **decoding_kwargs)
            # generate() returns prompt + continuation. Decode only the new
            # tokens so the prompt can never leak into the reply, even when it
            # was truncated and no longer matches ``prompt`` textually.
            prompt_length = encoded["input_ids"].shape[-1]
            generated_text = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            )
            # Cut at the first role marker BEFORE stripping, so a reply that
            # immediately starts a new turn ("\nHuman: ...") is caught too.
            response = self._truncate_at_stops(generated_text)
            response = self._extract_response(response, prompt)
            results.append({
                "generated_text": response,
                "conversation": {
                    "user": input_text,
                    "assistant": response,
                },
            })
        return results

    def _role_tag_bad_words_ids(self) -> List[List[int]]:
        """Return token-id sequences for the role tags, for ``bad_words_ids``.

        Best-effort: if the tokenizer cannot encode the tags, an empty list is
        returned (and the failure logged) so generation runs unconstrained.
        """
        try:
            tokenized = self.tokenizer(list(self._ROLE_TAGS), add_special_tokens=False).input_ids
        except Exception:
            logging.getLogger(__name__).warning(
                "Could not tokenize role tags; generating without bad_words_ids",
                exc_info=True,
            )
            return []
        # input_ids holds one list of ids per tag; skip tags that encode to nothing.
        return [ids for ids in tokenized if isinstance(ids, list) and ids]

    def _format_prompt(self, user_input: str) -> str:
        """
        Format the user input into a conversational prompt.

        Short inputs (<= 20 chars) and greetings get the full system prompt.
        Other inputs that mention an actuarial keyword get a shorter actuarial
        instruction. Everything else gets the full system prompt.

        Args:
            user_input: The user's message

        Returns:
            Formatted prompt string
        """
        lower_input = user_input.lower().strip()
        # Match greetings as whole words: a plain substring test would treat
        # "within", "high" or "this" as containing the greeting "hi".
        words = set(re.findall(r"[a-z]+", lower_input))
        is_casual = len(lower_input) <= 20 or not self._GREETING_WORDS.isdisjoint(words)
        if not is_casual and any(keyword in lower_input for keyword in self._ACTUARIAL_KEYWORDS):
            # Actuarial query - be precise but friendly
            return f"As a conversational actuarial AI assistant, provide a helpful and accurate response.\n\nHuman: {user_input}\nAssistant: "
        # Greetings and general conversation - be more casual
        return f"{self.system_prompt}\n\nHuman: {user_input}\nAssistant: "

    def _extract_response(self, generated_text: str, prompt: str) -> str:
        """
        Extract only the assistant's response from the generated text.

        Takes everything after the LAST "Assistant:" marker if present,
        otherwise strips ``prompt`` when it is an exact prefix, otherwise the
        whole text. A leading ":" is removed. An empty result is replaced by a
        fixed "please rephrase" message.

        Args:
            generated_text: Generated text (may or may not include the prompt)
            prompt: The original prompt

        Returns:
            Just the assistant's response (never empty)
        """
        if "Assistant:" in generated_text:
            response = generated_text.rpartition("Assistant:")[2].strip()
        elif generated_text.startswith(prompt):
            response = generated_text[len(prompt):].strip()
        else:
            response = generated_text.strip()
        # Clean up any remaining marker punctuation.
        if response.startswith(":"):
            response = response[1:].strip()
        return response or self._FALLBACK_RESPONSE

    def _truncate_at_stops(self, text: str) -> str:
        """Cut ``text`` at the first conversation stop marker and cap its length.

        This prevents echoing future turns ("\\nHuman:", "\\nUser:", ...). The
        result is right-stripped after a cut and limited to
        ``_MAX_RESPONSE_CHARS`` characters.
        """
        hits = [idx for idx in (text.find(marker) for marker in self._STOP_MARKERS) if idx != -1]
        if hits:
            text = text[:min(hits)].rstrip()
        # Keep response reasonably bounded
        if len(text) > self._MAX_RESPONSE_CHARS:
            text = text[:self._MAX_RESPONSE_CHARS].rstrip()
        return text