# File: kista_benchmarking/model_run.py
# Provenance: uploaded by BayesTensor via huggingface_hub (commit 9d5b280, verified)
import os
from enum import Enum
from typing import Optional, Dict, Any, List, Union
from vllm import LLM, SamplingParams
from vllm.outputs import RequestOutput
from transformers import AutoTokenizer
# Upper bound on generated tokens for a full-length run.
# NOTE(review): not referenced anywhere in this file — confirm external use,
# or consider wiring it into send_message's max_tokens default (currently 1000).
DEFAULT_MAX_TOKENS = 16000
class ModelType(Enum):
    """Kind of checkpoint: a raw base LM or an instruction/chat-tuned model.

    Detected heuristically from the model path (see
    ``VLLMClient._detect_model_type``) and used to pick the prompt format.
    """
    BASE = "base"          # plain completion model: prompt is raw text
    INSTRUCT = "instruct"  # chat model: prompt goes through the chat template
class VLLMClient:
    """Thin wrapper around a vLLM ``LLM`` engine.

    Formats prompts according to whether the checkpoint is a base or an
    instruct model, and exposes a single ``send_message`` entry point that
    returns a status/result dictionary instead of raising.
    """

    def __init__(self, model_path: str):
        """Load the model and its tokenizer.

        Args:
            model_path: Local path or hub id of the checkpoint. Also used to
                heuristically detect base vs. instruct models
                (see ``_detect_model_type``).
        """
        self.model_path = model_path
        self.model_type = self._detect_model_type(model_path)
        self.llm = LLM(model=model_path)
        # Load the tokenizer for all models so instruct prompts can use the
        # model's own chat template.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    @staticmethod
    def _detect_model_type(model_path: str) -> ModelType:
        """Heuristically classify the checkpoint from its path/name.

        Any of the keywords below appearing case-insensitively anywhere in
        the path marks the model as instruction-tuned; otherwise it is
        treated as a base LM.
        """
        model_path_lower = model_path.lower()
        instruct_keywords = ['instruct', 'chat', 'dialogue', 'conversations', 'kista']
        # Check if any instruct-related keyword is in the model path
        is_instruct = any(keyword in model_path_lower for keyword in instruct_keywords)
        return ModelType.INSTRUCT if is_instruct else ModelType.BASE

    def _format_base_prompt(self, system: Optional[str], content: str) -> str:
        """Format a prompt for base models, including the system prompt.

        Base models have no chat template, so the system text (if any) is
        simply prepended to the user content, separated by a space.
        """
        if system:
            # For base models we use a simple concatenation template.
            return f"{system} {content}"
        return content

    def _format_instruct_prompt(self, system: Optional[str], content: str) -> str:
        """Format a prompt for instruct models using the model's chat template.

        Builds an (optional system + user) message list and renders it to text
        with the generation prompt appended.
        """
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": content})
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    def _create_message_payload(self,
                                system: Optional[str],
                                content: str,
                                max_tokens: int,
                                temperature: float) -> Dict[str, Any]:
        """Create the sampling parameters and format the prompt by model type.

        Returns:
            Dict with keys ``"prompt"`` (formatted string) and
            ``"sampling_params"`` (a ``SamplingParams`` instance).
        """
        if self.model_type == ModelType.BASE:
            formatted_prompt = self._format_base_prompt(system, content)
        else:
            formatted_prompt = self._format_instruct_prompt(system, content)
        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=0.95,
            presence_penalty=0.0,
            frequency_penalty=0.0,
        )
        return {
            "prompt": formatted_prompt,
            "sampling_params": sampling_params,
        }

    def send_message(self,
                     content: str,
                     system: Optional[str] = None,
                     max_tokens: int = 1000,
                     temperature: float = 0) -> Dict[str, Any]:
        """Send a message to the model and get a response.

        Args:
            content: User message or raw prompt.
            system: System prompt (supported for both base and instruct models).
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.

        Returns:
            ``{'status': True, 'result': <generated text>}`` on success.
            If the generated text cannot be extracted from the vLLM outputs,
            ``result`` holds the raw output objects instead (best-effort
            fallback). ``{'status': False, 'error': <message>}`` when
            generation itself fails.
        """
        try:
            payload = self._create_message_payload(
                system=system,
                content=content,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            outputs = self.llm.generate(
                prompts=[payload["prompt"]],
                sampling_params=payload["sampling_params"],
            )
            try:
                # Normal case: one prompt in, one completion out.
                return {'status': True, 'result': outputs[0].outputs[0].text.strip()}
            except (IndexError, AttributeError):
                # Unexpected output shape: fall back to returning the raw
                # vLLM objects (preserves the original best-effort behavior);
                # any other error propagates to the outer handler below.
                return {'status': True, 'result': outputs}
        except Exception as e:
            return {'status': False, 'error': str(e)}