import os
from enum import Enum
from typing import Optional, Dict, Any, List, Union

from vllm import LLM, SamplingParams
from vllm.outputs import RequestOutput
from transformers import AutoTokenizer

# NOTE(review): this constant is currently unused — send_message() defaults to
# 1000 tokens. Kept for callers that pass it explicitly; consider wiring it in.
DEFAULT_MAX_TOKENS = 16000


class ModelType(Enum):
    """Distinguishes raw-completion (base) models from chat/instruct models."""
    BASE = "base"
    INSTRUCT = "instruct"


class VLLMClient:
    """Thin wrapper around a vLLM ``LLM`` engine.

    Detects whether the model is a base or instruct model from its path and
    formats prompts accordingly before generation.
    """

    def __init__(self, model_path: str):
        """Load the model and its tokenizer.

        Args:
            model_path: Local path or hub identifier of the model; also used
                for the name-based instruct/base detection heuristic.
        """
        self.model_path = model_path
        self.model_type = self._detect_model_type(model_path)
        self.llm = LLM(model=model_path)
        # Load tokenizer for all models to handle proper text formatting
        # (instruct models need it for apply_chat_template).
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    @staticmethod
    def _detect_model_type(model_path: str) -> ModelType:
        """Classify the model as INSTRUCT or BASE from keywords in its path.

        NOTE(review): purely name-based — an instruction-tuned model whose
        path contains none of these keywords is treated as a base model.
        """
        model_path_lower = model_path.lower()
        instruct_keywords = ['instruct', 'chat', 'dialogue', 'conversations', 'kista']
        # Check if any instruct-related keyword is in the model path
        is_instruct = any(keyword in model_path_lower for keyword in instruct_keywords)
        return ModelType.INSTRUCT if is_instruct else ModelType.BASE

    def _format_base_prompt(self, system: Optional[str], content: str) -> str:
        """ Format prompt for base models including system prompt. """
        if system:
            # For base models, we'll use a simple template
            return f"{system} {content}"
        return content

    def _format_instruct_prompt(self, system: Optional[str], content: str) -> str:
        """ Format prompt for instruct models using the model's chat template. """
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": content})
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

    def _create_message_payload(self, system: Optional[str], content: str,
                                max_tokens: int, temperature: float) -> Dict[str, Any]:
        """ Create the sampling parameters and format the prompt based on model type. """
        if self.model_type == ModelType.BASE:
            formatted_prompt = self._format_base_prompt(system, content)
        else:
            formatted_prompt = self._format_instruct_prompt(system, content)
        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=0.95,
            presence_penalty=0.0,
            frequency_penalty=0.0,
        )
        return {
            "prompt": formatted_prompt,
            "sampling_params": sampling_params
        }

    def send_message(self, content: str, system: Optional[str] = None,
                     max_tokens: int = 1000, temperature: float = 0) -> Dict[str, Any]:
        """ Send a message to the model and get a response.

        Args:
            content: User message or raw prompt
            system: System prompt (supported for both base and instruct models)
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature

        Returns:
            Dictionary containing status and result/error:
            ``{'status': True, 'result': <text>}`` on success, or
            ``{'status': False, 'error': <message>}`` on failure.
        """
        try:
            payload = self._create_message_payload(
                system=system,
                content=content,
                max_tokens=max_tokens,
                temperature=temperature
            )
            outputs = self.llm.generate(
                prompts=[payload["prompt"]],
                sampling_params=payload["sampling_params"]
            )
            try:
                result = outputs[0].outputs[0].text.strip()
                return {'status': True, 'result': result}
            except Exception:
                # Best-effort fallback: if the output structure is not the
                # expected shape, hand back the raw outputs rather than fail.
                return {'status': True, 'result': outputs}
        except Exception as e:
            return {'status': False, 'error': str(e)}