"""
Custom Handler for MORBID-Actuarial v0.1.0 Conversational Model
Hugging Face Inference Endpoints
"""
from typing import Dict, List, Any
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the handler with model and tokenizer.

        Args:
            path: Path to the model directory
        """
        # Load tokenizer and model.
        # Some repos may have a non-standard model_type; in that case, fall back to a known base model.
        fallback_model_id = os.getenv("BASE_MODEL_ID", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        dtype = torch.float16 if torch.cuda.is_available() else None
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=dtype,
                device_map="auto",
                low_cpu_mem_usage=True,
            )
        except Exception:
            # Fall back to a supported base model
            self.tokenizer = AutoTokenizer.from_pretrained(fallback_model_id, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                fallback_model_id,
                torch_dtype=dtype,
                device_map="auto",
                low_cpu_mem_usage=True,
            )

        # Set padding token if not already set
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # System prompt for conversational behavior
        self.system_prompt = """You are MORBID.AI, a friendly and conversational actuarial assistant.
You have expertise in:
- Life expectancy and mortality statistics
- Insurance and risk calculations
- Financial mathematics (FM exam - 100% accuracy)
- Probability theory (P exam - 100% accuracy)
- Investment and financial markets (IFM exam - 93.3% accuracy)
Be warm, helpful, and engaging. Respond naturally to greetings and casual conversation while maintaining your actuarial expertise.
When users greet you, respond warmly. When they ask for help, be supportive and clear.
Balance personality with precision when discussing technical topics."""

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request.

        Args:
            data: Dictionary containing the input data
                - inputs (str or list): The input text(s)
                - parameters (dict): Generation parameters

        Returns:
            List of generated responses
        """
        # Extract inputs
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Handle both string and list inputs
        if isinstance(inputs, str):
            inputs = [inputs]
        elif not isinstance(inputs, list):
            inputs = [str(inputs)]

        # Set default generation parameters
        generation_params = {
            "max_new_tokens": parameters.get("max_new_tokens", 200),
            "temperature": parameters.get("temperature", 0.8),
            "top_p": parameters.get("top_p", 0.95),
            "do_sample": parameters.get("do_sample", True),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        # Process each input
        results = []
        for input_text in inputs:
            # Format the prompt with conversational context
            prompt = self._format_prompt(input_text)

            # Tokenize
            inputs_tokenized = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(self.model.device)

            # Prepare additional decoding constraints
            bad_words_ids = []
            try:
                # Disallow role-tag leakage in generations
                role_tokens = ["Human:", "User:", "Assistant:", "SYSTEM:", "System:"]
                tokenized = self.tokenizer(role_tokens, add_special_tokens=False).input_ids
                # input_ids can be nested lists (one per tokenized string)
                for ids in tokenized:
                    if isinstance(ids, list) and len(ids) > 0:
                        bad_words_ids.append(ids)
            except Exception:
                pass

            decoding_kwargs = {
                **generation_params,
                # Encourage coherence and reduce repetition/artifacts
                "no_repeat_ngram_size": 3,
            }
            if bad_words_ids:
                decoding_kwargs["bad_words_ids"] = bad_words_ids

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs_tokenized,
                    **decoding_kwargs,
                )

            # Decode the response
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract only the assistant's response and trim at stop sequences
            response = self._extract_response(generated_text, prompt)
            response = self._truncate_at_stops(response)

            results.append({
                "generated_text": response,
                "conversation": {
                    "user": input_text,
                    "assistant": response,
                },
            })

        return results

    def _format_prompt(self, user_input: str) -> str:
        """
        Format the user input into a conversational prompt.

        Args:
            user_input: The user's message

        Returns:
            Formatted prompt string
        """
        # Check if it's a greeting or casual message
        lower_input = user_input.lower().strip()

        # For very short inputs or greetings, add conversational context
        if len(lower_input) <= 20 or any(greet in lower_input for greet in ["hi", "hello", "hey", "howdy"]):
            return f"{self.system_prompt}\n\nHuman: {user_input}\nAssistant: "

        # For longer inputs, check if they're actuarial
        actuarial_keywords = ["mortality", "life expectancy", "insurance", "premium", "annuity",
                              "probability", "risk", "actuarial", "death", "survival"]
        if any(keyword in lower_input for keyword in actuarial_keywords):
            # Actuarial query - be precise but friendly
            return f"As a conversational actuarial AI assistant, provide a helpful and accurate response.\n\nHuman: {user_input}\nAssistant: "
        else:
            # General conversation - be more casual
            return f"{self.system_prompt}\n\nHuman: {user_input}\nAssistant: "

    def _extract_response(self, generated_text: str, prompt: str) -> str:
        """
        Extract only the assistant's response from the generated text.

        Args:
            generated_text: Full generated text, including the prompt
            prompt: The original prompt

        Returns:
            Just the assistant's response
        """
        # Strategy: take everything after the LAST "Assistant:" marker; fall back to stripping the prompt
        if "Assistant:" in generated_text:
            response = generated_text.split("Assistant:")[-1].strip()
        elif generated_text.startswith(prompt):
            response = generated_text[len(prompt):].strip()
        else:
            response = generated_text.strip()

        # Clean up any remaining markers
        if response.startswith(":"):
            response = response[1:].strip()

        # Ensure we have a response
        if not response:
            response = "I'm here to help! Could you please rephrase your question?"

        return response

    def _truncate_at_stops(self, text: str) -> str:
        """Truncate model output at conversation stop markers to avoid echoing future turns."""
        stop_markers = ["\nHuman:", "\nUser:", "\nAssistant:", "\nSYSTEM:", "\nSystem:"]
        cut_index = None
        for marker in stop_markers:
            idx = text.find(marker)
            if idx != -1:
                cut_index = idx if cut_index is None else min(cut_index, idx)
        if cut_index is not None:
            text = text[:cut_index].rstrip()

        # Keep response reasonably bounded
        if len(text) > 2000:
            text = text[:2000].rstrip()
        return text
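

# Minimal local smoke test: an illustrative sketch, not part of the Inference Endpoints contract.
# Assumes the model files live in the current directory ("."); adjust the path for your setup.
# The payload mirrors what __call__ expects: {"inputs": <str or list>, "parameters": {...}}.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    payload = {
        "inputs": "Hi there! Can you explain life expectancy at age 65?",
        "parameters": {"max_new_tokens": 120, "temperature": 0.7},
    }
    for result in handler(payload):
        print(result["generated_text"])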