|
|
import os |
|
|
from enum import Enum |
|
|
from typing import Optional, Dict, Any, List, Union |
|
|
from vllm import LLM, SamplingParams |
|
|
from vllm.outputs import RequestOutput |
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
# Intended upper bound on generation length, in tokens.
# NOTE(review): appears unused in this file — send_message defaults to
# max_tokens=1000 instead; confirm whether external callers import this.
DEFAULT_MAX_TOKENS = 16000
|
|
|
|
|
class ModelType(Enum):
    """Kind of language model, which determines how prompts are formatted.

    ``BASE`` models get plain concatenated text; ``INSTRUCT`` models get a
    conversation rendered through the tokenizer's chat template.
    """

    # Plain completion model: prompt is raw text.
    BASE = "base"
    # Chat/instruction-tuned model: prompt goes through the chat template.
    INSTRUCT = "instruct"
|
|
|
|
|
class VLLMClient: |
|
|
def __init__(self, |
|
|
model_path: str): |
|
|
|
|
|
self.model_path = model_path |
|
|
self.model_type = self._detect_model_type(model_path) |
|
|
self.llm = LLM(model=model_path) |
|
|
|
|
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_path) |
|
|
|
|
|
@staticmethod |
|
|
def _detect_model_type(model_path: str) -> ModelType: |
|
|
|
|
|
model_path_lower = model_path.lower() |
|
|
instruct_keywords = ['instruct', 'chat', 'dialogue', 'conversations', 'kista'] |
|
|
|
|
|
|
|
|
is_instruct = any(keyword in model_path_lower for keyword in instruct_keywords) |
|
|
return ModelType.INSTRUCT if is_instruct else ModelType.BASE |
|
|
|
|
|
def _format_base_prompt(self, system: Optional[str], content: str) -> str: |
|
|
""" |
|
|
Format prompt for base models including system prompt. |
|
|
""" |
|
|
if system: |
|
|
|
|
|
return f"{system} {content}" |
|
|
return content |
|
|
|
|
|
def _format_instruct_prompt(self, system: Optional[str], content: str) -> str: |
|
|
""" |
|
|
Format prompt for instruct models using the model's chat template. |
|
|
""" |
|
|
messages = [] |
|
|
if system: |
|
|
messages.append({"role": "system", "content": system}) |
|
|
messages.append({"role": "user", "content": content}) |
|
|
|
|
|
return self.tokenizer.apply_chat_template( |
|
|
messages, |
|
|
tokenize=False, |
|
|
add_generation_prompt=True |
|
|
) |
|
|
|
|
|
def _create_message_payload(self, |
|
|
system: Optional[str], |
|
|
content: str, |
|
|
max_tokens: int, |
|
|
temperature: float) -> Dict[str, Any]: |
|
|
""" |
|
|
Create the sampling parameters and format the prompt based on model type. |
|
|
""" |
|
|
if self.model_type == ModelType.BASE: |
|
|
formatted_prompt = self._format_base_prompt(system, content) |
|
|
else: |
|
|
formatted_prompt = self._format_instruct_prompt(system, content) |
|
|
|
|
|
sampling_params = SamplingParams( |
|
|
max_tokens=max_tokens, |
|
|
temperature=temperature, |
|
|
top_p=0.95, |
|
|
presence_penalty=0.0, |
|
|
frequency_penalty=0.0, |
|
|
) |
|
|
|
|
|
return { |
|
|
"prompt": formatted_prompt, |
|
|
"sampling_params": sampling_params |
|
|
} |
|
|
|
|
|
|
|
|
def send_message(self, |
|
|
content: str, |
|
|
system: Optional[str] = None, |
|
|
max_tokens: int = 1000, |
|
|
temperature: float = 0) -> Dict[str, Any]: |
|
|
""" |
|
|
Send a message to the model and get a response. |
|
|
|
|
|
Args: |
|
|
content: User message or raw prompt |
|
|
system: System prompt (supported for both base and instruct models) |
|
|
max_tokens: Maximum number of tokens to generate |
|
|
temperature: Sampling temperature |
|
|
json_eval: Whether to parse the response as JSON |
|
|
|
|
|
Returns: |
|
|
Dictionary containing status and result/error |
|
|
""" |
|
|
try: |
|
|
payload = self._create_message_payload( |
|
|
system=system, |
|
|
content=content, |
|
|
max_tokens=max_tokens, |
|
|
temperature=temperature |
|
|
) |
|
|
|
|
|
outputs = self.llm.generate( |
|
|
prompts=[payload["prompt"]], |
|
|
sampling_params=payload["sampling_params"] |
|
|
) |
|
|
|
|
|
try: |
|
|
result_text = outputs[0].outputs[0].text.strip() |
|
|
result = result_text |
|
|
return {'status': True, 'result': result} |
|
|
except Exception as e: |
|
|
return {'status': True, 'result': outputs} |
|
|
|
|
|
except Exception as e: |
|
|
return {'status': False, 'error': str(e)} |
|
|
|