from typing import Dict, List, Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer for inference.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle inference requests.

        Expected input format (OpenAI-compatible):
        {
            "messages": [
                {"role": "system", "content": "..."},
                {"role": "user", "content": "..."}
            ],
            "max_tokens": 512,
            "temperature": 0.7,
            "top_p": 0.9
        }
        """
        # Extract parameters, falling back to sensible defaults
        messages = data.get("messages", [])
        max_tokens = data.get("max_tokens", 512)
        temperature = data.get("temperature", 0.7)
        top_p = data.get("top_p", 0.9)

        # Build prompt from messages
        prompt = self._build_prompt(messages)

        # Tokenize and move tensors to the model's device
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Generate. Fall back to greedy decoding when temperature is 0:
        # transformers raises a ValueError if do_sample=True with temperature=0.
        do_sample = temperature > 0
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature if do_sample else 1.0,
                top_p=top_p if do_sample else 1.0,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens (skip the prompt)
        response = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        # Return in OpenAI-compatible format
        return [{
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": response.strip(),
                },
                "finish_reason": "stop",
            }]
        }]

    def _build_prompt(self, messages: List[Dict[str, str]]) -> str:
        """
        Build a prompt in Mistral Instruct format:
        [INST] user text [/INST] assistant text</s>[INST] ... [/INST]

        The format has no dedicated system role, so a system message is
        folded into the next user turn.
        """
        prompt_parts = []
        system_content = ""

        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")

            if role == "system":
                # Hold system content; it is prepended to the next user turn
                system_content = content
            elif role == "user":
                if system_content:
                    prompt_parts.append(f"[INST] {system_content}\n\n{content} [/INST]")
                    system_content = ""
                else:
                    prompt_parts.append(f"[INST] {content} [/INST]")
            elif role == "assistant":
                # Close the assistant turn with EOS so multi-turn prompts
                # match the Mistral Instruct training format
                prompt_parts.append(f" {content}{self.tokenizer.eos_token}")

        return "".join(prompt_parts)
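

# ---------------------------------------------------------------------------
# Local smoke test: an illustrative sketch, not part of the endpoint contract.
# The model id below is an assumption; substitute any Mistral-style instruct
# checkpoint you have access to. Running this loads the full model, so it
# needs a machine with enough memory for the weights in bfloat16.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Assumed model id for illustration only
    handler = EndpointHandler(path="mistralai/Mistral-7B-Instruct-v0.2")
    result = handler({
        "messages": [
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Summarize what a tokenizer does in one sentence."},
        ],
        "max_tokens": 64,
        "temperature": 0.7,
    })
    print(result[0]["choices"][0]["message"]["content"])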