# Inference handler — uploaded by h3ir (commit 9810e3a, verified)
from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a causal LM.

    Serves OpenAI-compatible chat-completion requests against a model
    prompted in the Mistral-Instruct format.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer for inference.

        The model is loaded in bfloat16 with ``device_map="auto"`` (sharded
        across available accelerators) and switched to eval mode so dropout
        is disabled during generation.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle inference requests.

        Expected input format (OpenAI-compatible):
        {
            "messages": [
                {"role": "system", "content": "..."},
                {"role": "user", "content": "..."}
            ],
            "max_tokens": 512,
            "temperature": 0.7,
            "top_p": 0.9
        }

        Returns a one-element list wrapping an OpenAI-style ``choices``
        payload. ``finish_reason`` is ``"length"`` when generation was cut
        off by the token budget, ``"stop"`` otherwise.
        """
        # Extract parameters
        messages = data.get("messages", [])
        max_tokens = data.get("max_tokens", 512)
        temperature = data.get("temperature", 0.7)
        top_p = data.get("top_p", 0.9)

        # Build prompt from messages and tokenize onto the model's device
        prompt = self._build_prompt(messages)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # temperature <= 0 is the conventional request for deterministic
        # output; transformers raises on temperature=0 with do_sample=True,
        # so fall back to greedy decoding instead of crashing. Sampling
        # knobs are omitted entirely in the greedy branch to avoid
        # "unused generation flag" warnings.
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_tokens,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if temperature and temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
        else:
            gen_kwargs["do_sample"] = False

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        # Decode response (only the new tokens)
        prompt_len = inputs["input_ids"].shape[1]
        new_tokens = outputs[0][prompt_len:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        # Report truncation when the whole token budget was consumed.
        finish_reason = "length" if new_tokens.shape[0] >= max_tokens else "stop"

        # Return in OpenAI-compatible format
        return [{
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": response.strip()
                },
                "finish_reason": finish_reason
            }]
        }]

    def _build_prompt(self, messages: List[Dict[str, str]]) -> str:
        """
        Build prompt in Mistral Instruct format.

        A system message is folded into the next user turn. Assistant turns
        are closed with ``</s>`` as the Mistral multi-turn format requires
        (the previous version omitted it, corrupting multi-turn prompts).
        A trailing system message with no following user turn is dropped,
        matching the original behavior.
        """
        prompt_parts: List[str] = []
        system_content = ""
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")
            if role == "system":
                system_content = content
            elif role == "user":
                if system_content:
                    prompt_parts.append(f"[INST] {system_content}\n\n{content} [/INST]")
                    system_content = ""
                else:
                    prompt_parts.append(f"[INST] {content} [/INST]")
            elif role == "assistant":
                # Close the assistant turn so the next [INST] starts a
                # fresh exchange, per the Mistral chat template.
                prompt_parts.append(f"{content}</s>")
        return "".join(prompt_parts)