"""
Custom handler for the MORBID-Actuarial v0.1.0 conversational model,
served via Hugging Face Inference Endpoints.
"""

from typing import Dict, List, Any
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the handler with the model and tokenizer.

        Args:
            path: Path to the model directory.
        """
        # Prefer the model shipped with the endpoint; fall back to a public base
        # model (overridable via BASE_MODEL_ID) if that load fails.
        fallback_model_id = os.getenv("BASE_MODEL_ID", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        dtype = torch.float16 if torch.cuda.is_available() else None

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=dtype,
                device_map="auto",
                low_cpu_mem_usage=True
            )
        except Exception:
            self.tokenizer = AutoTokenizer.from_pretrained(fallback_model_id, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                fallback_model_id,
                torch_dtype=dtype,
                device_map="auto",
                low_cpu_mem_usage=True
            )

        # Ensure a pad token exists so padded batches can be tokenized and generated.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
self.system_prompt = """You are MORBID.AI, a friendly and conversational actuarial assistant. |
|
|
You have expertise in: |
|
|
- Life expectancy and mortality statistics |
|
|
- Insurance and risk calculations |
|
|
- Financial mathematics (FM exam - 100% accuracy) |
|
|
- Probability theory (P exam - 100% accuracy) |
|
|
- Investment and financial markets (IFM exam - 93.3% accuracy) |
|
|
|
|
|
Be warm, helpful, and engaging. Respond naturally to greetings and casual conversation while maintaining your actuarial expertise. |
|
|
When users greet you, respond warmly. When they ask for help, be supportive and clear. |
|
|
Balance personality with precision when discussing technical topics.""" |
|
|
|
|
|

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request.

        Args:
            data: Dictionary containing the input data.
                - inputs (str or list): The input text(s)
                - parameters (dict): Generation parameters

        Returns:
            List of generated responses, one per input.
        """
inputs = data.get("inputs", "") |
|
|
parameters = data.get("parameters", {}) |
|
|
|
|
|
|
|
|
if isinstance(inputs, str): |
|
|
inputs = [inputs] |
|
|
elif not isinstance(inputs, list): |
|
|
inputs = [str(inputs)] |
|
|
|
|
|
|
|
|

        # Generation defaults; each value can be overridden via "parameters" in the request.
        generation_params = {
            "max_new_tokens": parameters.get("max_new_tokens", 200),
            "temperature": parameters.get("temperature", 0.8),
            "top_p": parameters.get("top_p", 0.95),
            "do_sample": parameters.get("do_sample", True),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        results = []
        for input_text in inputs:
            prompt = self._format_prompt(input_text)

            inputs_tokenized = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.model.device)

            # Discourage the model from emitting role markers that would start a new turn.
            bad_words_ids = []
            try:
                role_tokens = ["Human:", "User:", "Assistant:", "SYSTEM:", "System:"]
                tokenized = self.tokenizer(role_tokens, add_special_tokens=False).input_ids

                for ids in tokenized:
                    if isinstance(ids, list) and len(ids) > 0:
                        bad_words_ids.append(ids)
            except Exception:
                pass

            decoding_kwargs = {
                **generation_params,
                "no_repeat_ngram_size": 3,
            }
            if bad_words_ids:
                decoding_kwargs["bad_words_ids"] = bad_words_ids

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs_tokenized,
                    **decoding_kwargs
                )

            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            response = self._extract_response(generated_text, prompt)
            response = self._truncate_at_stops(response)

            results.append({
                "generated_text": response,
                "conversation": {
                    "user": input_text,
                    "assistant": response
                }
            })

        return results

    def _format_prompt(self, user_input: str) -> str:
        """
        Format the user input into a conversational prompt.

        Args:
            user_input: The user's message.

        Returns:
            Formatted prompt string.
        """
        lower_input = user_input.lower().strip()

        # Treat very short messages and explicit greetings as casual conversation.
        # Greetings are matched on whole words so that e.g. "this" does not count as "hi".
        greeting_words = {"hi", "hello", "hey", "howdy"}
        words = {word.strip(".,!?") for word in lower_input.split()}
        if len(lower_input) <= 20 or greeting_words & words:
            return f"{self.system_prompt}\n\nHuman: {user_input}\nAssistant: "

        actuarial_keywords = ["mortality", "life expectancy", "insurance", "premium", "annuity",
                              "probability", "risk", "actuarial", "death", "survival"]

        if any(keyword in lower_input for keyword in actuarial_keywords):
            # Technical questions get a shorter, task-focused instruction.
            return f"As a conversational actuarial AI assistant, provide a helpful and accurate response.\n\nHuman: {user_input}\nAssistant: "
        else:
            return f"{self.system_prompt}\n\nHuman: {user_input}\nAssistant: "

    def _extract_response(self, generated_text: str, prompt: str) -> str:
        """
        Extract only the assistant's response from the generated text.

        Args:
            generated_text: Full generated text, including the prompt.
            prompt: The original prompt.

        Returns:
            Just the assistant's response.
        """
        if "Assistant:" in generated_text:
            response = generated_text.split("Assistant:")[-1].strip()
        elif generated_text.startswith(prompt):
            response = generated_text[len(prompt):].strip()
        else:
            response = generated_text.strip()

        # Drop a stray leading colon left over from the role marker.
        if response.startswith(":"):
            response = response[1:].strip()

        # Never return an empty reply.
        if not response:
            response = "I'm here to help! Could you please rephrase your question?"

        return response

    def _truncate_at_stops(self, text: str) -> str:
        """Truncate model output at conversation stop markers to avoid echoing future turns."""
        stop_markers = ["\nHuman:", "\nUser:", "\nAssistant:", "\nSYSTEM:", "\nSystem:"]
        cut_index = None
        for marker in stop_markers:
            idx = text.find(marker)
            if idx != -1:
                cut_index = idx if cut_index is None else min(cut_index, idx)
        if cut_index is not None:
            text = text[:cut_index].rstrip()

        # Hard cap on response length.
        if len(text) > 2000:
            text = text[:2000].rstrip()
        return text
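

if __name__ == "__main__":
    # Minimal local smoke test, not part of the Inference Endpoints contract.
    # With an empty path the handler falls back to BASE_MODEL_ID, so this sketch
    # assumes that base model can be downloaded in the current environment.
    handler = EndpointHandler(path="")
    demo_request = {
        "inputs": "What does a mortality table tell an actuary?",
        "parameters": {"max_new_tokens": 100, "temperature": 0.7},
    }
    for item in handler(demo_request):
        print(item["generated_text"])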