File size: 4,326 Bytes
9d5b280 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import os
from enum import Enum
from typing import Optional, Dict, Any, List, Union
from vllm import LLM, SamplingParams
from vllm.outputs import RequestOutput
from transformers import AutoTokenizer
DEFAULT_MAX_TOKENS = 16000
class ModelType(Enum):
    """Closed set of model flavors; selects the prompt-formatting strategy."""
    # Plain language model: prompt is passed (near-)verbatim.
    BASE = "base"
    # Chat/instruct-tuned model: prompt built via the tokenizer chat template.
    INSTRUCT = "instruct"
class VLLMClient:
    """Synchronous client around a local vLLM engine.

    Detects whether the model at ``model_path`` is a base or an
    instruct/chat model from its name and formats prompts accordingly
    (raw text for base models, the tokenizer's chat template otherwise).
    """

    def __init__(self, model_path: str):
        """Load the vLLM engine and the matching HF tokenizer.

        Args:
            model_path: Hugging Face hub id or local path of the model.
        """
        self.model_path = model_path
        self.model_type = self._detect_model_type(model_path)
        self.llm = LLM(model=model_path)
        # Load tokenizer for all models so chat templating is always available.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    @staticmethod
    def _detect_model_type(model_path: str) -> ModelType:
        """Classify the model as BASE or INSTRUCT from its path.

        Heuristic: any instruct-related keyword appearing
        (case-insensitively) anywhere in the path marks it INSTRUCT.
        """
        model_path_lower = model_path.lower()
        # NOTE(review): 'kista' looks like a project-specific model name — confirm.
        instruct_keywords = ('instruct', 'chat', 'dialogue', 'conversations', 'kista')
        is_instruct = any(keyword in model_path_lower for keyword in instruct_keywords)
        return ModelType.INSTRUCT if is_instruct else ModelType.BASE

    def _format_base_prompt(self, system: Optional[str], content: str) -> str:
        """Format a prompt for base models, prepending the system text if any."""
        if system:
            # Base models have no chat template; fall back to plain concatenation.
            return f"{system} {content}"
        return content

    def _format_instruct_prompt(self, system: Optional[str], content: str) -> str:
        """Format a prompt for instruct models via the tokenizer's chat template."""
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": content})
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    def _create_message_payload(self,
                                system: Optional[str],
                                content: str,
                                max_tokens: int,
                                temperature: float) -> Dict[str, Any]:
        """Build the formatted prompt and sampling parameters.

        Returns:
            Dict with keys ``"prompt"`` (str) and ``"sampling_params"``
            (``SamplingParams``).
        """
        if self.model_type == ModelType.BASE:
            formatted_prompt = self._format_base_prompt(system, content)
        else:
            formatted_prompt = self._format_instruct_prompt(system, content)
        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=0.95,
            presence_penalty=0.0,
            frequency_penalty=0.0,
        )
        return {
            "prompt": formatted_prompt,
            "sampling_params": sampling_params,
        }

    def send_message(self,
                     content: str,
                     system: Optional[str] = None,
                     max_tokens: int = 1000,
                     temperature: float = 0) -> Dict[str, Any]:
        """Send a message to the model and get a response.

        Args:
            content: User message or raw prompt.
            system: Optional system prompt (supported for both base and
                instruct models).
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.

        Returns:
            ``{'status': True, 'result': <text>}`` on success, or
            ``{'status': False, 'error': <message>}`` on failure.  If the
            generated text cannot be extracted from the engine output, the
            raw ``RequestOutput`` list is returned as a best-effort result.
        """
        try:
            payload = self._create_message_payload(
                system=system,
                content=content,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            outputs = self.llm.generate(
                prompts=[payload["prompt"]],
                sampling_params=payload["sampling_params"],
            )
            try:
                # One prompt in -> one RequestOutput; take its first completion.
                return {'status': True, 'result': outputs[0].outputs[0].text.strip()}
            except (IndexError, AttributeError):
                # Unexpected output shape: keep the best-effort contract and
                # hand back the raw engine output instead of failing.
                return {'status': True, 'result': outputs}
        except Exception as e:
            # Boundary handler: surface any engine/templating failure to the
            # caller as a structured error instead of raising.
            return {'status': False, 'error': str(e)}
|