|
|
from typing import Dict, List, Any |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import torch |
|
|
|
|
|
class EndpointHandler:
    """Serve a causal language model behind an OpenAI-style chat interface.

    The tokenizer and model are loaded once in ``__init__``. Each call
    renders the request's chat messages into a single ``[INST] ... [/INST]``
    prompt, generates a continuation, and returns it wrapped in a
    ``choices`` payload.
    """

    # Defaults used when a request omits a field or sends it as ``null``.
    _DEFAULT_MAX_TOKENS = 512
    _DEFAULT_TEMPERATURE = 0.7
    _DEFAULT_TOP_P = 0.9

    def __init__(self, path: str = ""):
        """Initialize the model and tokenizer for inference.

        Args:
            path: Location handed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # SECURITY: trust_remote_code=True lets the repository at ``path``
        # run its own Python modelling code during loading. Only point this
        # handler at model repositories you trust.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle one inference request.

        Expected input format (OpenAI-compatible):
        {
            "messages": [
                {"role": "system", "content": "..."},
                {"role": "user", "content": "..."}
            ],
            "max_tokens": 512,
            "temperature": 0.7,
            "top_p": 0.9
        }

        Missing or ``null`` fields fall back to the defaults shown above.
        A ``temperature`` of 0 (or below) selects greedy decoding instead of
        sampling.

        Args:
            data: Request payload as described above.

        Returns:
            A one-element list holding a dict with a single ``choices``
            entry: the assistant ``message`` and a ``finish_reason`` of
            ``"stop"`` (generation ended on its own) or ``"length"`` (the
            ``max_tokens`` budget was exhausted first).
        """
        messages = data.get("messages") or []
        max_tokens = int(
            self._param(data, "max_tokens", self._DEFAULT_MAX_TOKENS)
        )
        temperature = self._param(
            data, "temperature", self._DEFAULT_TEMPERATURE
        )
        top_p = self._param(data, "top_p", self._DEFAULT_TOP_P)

        prompt = self._build_prompt(messages)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(
            self.model.device
        )
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                pad_token_id=self.tokenizer.eos_token_id,
                **self._sampling_kwargs(temperature, top_p),
            )

        # generate() returns prompt + continuation; keep only the new tokens.
        generated_ids = outputs[0][prompt_length:]
        response = self.tokenizer.decode(
            generated_ids,
            skip_special_tokens=True,
        )
        finish_reason = self._finish_reason(
            generated_ids.tolist(),
            max_tokens,
            self.tokenizer.eos_token_id,
        )

        return [{
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": response.strip(),
                },
                "finish_reason": finish_reason,
            }]
        }]

    @staticmethod
    def _param(data: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``data[key]``, or ``default`` when it is missing or None."""
        value = data.get(key)
        return default if value is None else value

    @staticmethod
    def _sampling_kwargs(temperature: Any, top_p: Any) -> Dict[str, Any]:
        """Translate request sampling options into ``generate`` kwargs.

        A non-positive temperature cannot be used for sampling, so it is
        mapped to greedy decoding and the sampling-only options are left
        out. Values are coerced to ``float`` because JSON clients often send
        integers such as ``1`` or ``2``.
        """
        temperature = float(temperature)
        if temperature <= 0.0:
            return {"do_sample": False}
        return {
            "do_sample": True,
            "temperature": temperature,
            "top_p": float(top_p),
        }

    @staticmethod
    def _finish_reason(
        generated_ids: List[int], max_tokens: int, eos_token_id: Any
    ) -> str:
        """Classify why generation ended.

        Returns ``"length"`` when the whole ``max_tokens`` budget was used
        and the last token is not ``eos_token_id``; otherwise ``"stop"``.
        """
        if len(generated_ids) < max_tokens:
            return "stop"
        if generated_ids and generated_ids[-1] == eos_token_id:
            return "stop"
        return "length"

    def _build_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Build prompt in Mistral Instruct format.

        A system message is held back and prepended (separated by a blank
        line) to the next user message. Assistant messages are appended
        verbatim. Messages with any other role are ignored, and ``null``
        content is treated as an empty string.

        NOTE(review): assistant turns are not followed by an end-of-sequence
        marker here; confirm this matches the format the served model was
        trained on.

        Args:
            messages: Chat messages, each with ``role`` and ``content`` keys.

        Returns:
            The concatenated prompt; an empty string for no usable messages.
        """
        prompt_parts = []
        pending_system = ""

        for msg in messages:
            role = msg.get("role", "")
            # `.get("content", "")` would still yield None for an explicit
            # null, which breaks both the f-strings and the final join.
            content = msg.get("content") or ""

            if role == "system":
                pending_system = content
            elif role == "user":
                if pending_system:
                    prompt_parts.append(
                        f"[INST] {pending_system}\n\n{content} [/INST]"
                    )
                    # The system text is consumed by the first user turn.
                    pending_system = ""
                else:
                    prompt_parts.append(f"[INST] {content} [/INST]")
            elif role == "assistant":
                prompt_parts.append(content)

        return "".join(prompt_parts)
|
|
|