File size: 4,326 Bytes
9d5b280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
from enum import Enum
from typing import Optional, Dict, Any, List, Union
from vllm import LLM, SamplingParams
from vllm.outputs import RequestOutput
from transformers import AutoTokenizer

# Default cap on generated tokens.
# NOTE(review): nothing in this file references this constant —
# send_message below hard-codes its own default of 1000; confirm whether
# callers are expected to pass DEFAULT_MAX_TOKENS explicitly.
DEFAULT_MAX_TOKENS = 16000

class ModelType(Enum):
    """Coarse classification of a checkpoint, used to pick a prompt format."""
    BASE = "base"          # completion-style model: prompt is plain text
    INSTRUCT = "instruct"  # chat model: prompt goes through the chat template

class VLLMClient:
    """
    Thin synchronous wrapper around a local vLLM ``LLM`` engine.

    On construction the client inspects the model path to decide whether the
    checkpoint is a base (completion) model or an instruction/chat model, and
    formats prompts accordingly before generation.
    """

    def __init__(self,
                 model_path: str):
        """
        Load the vLLM engine and tokenizer for ``model_path``.

        Args:
            model_path: Hugging Face hub id or local path of the checkpoint.
        """
        self.model_path = model_path
        self.model_type = self._detect_model_type(model_path)
        self.llm = LLM(model=model_path)

        # Tokenizer is loaded for every model type; instruct models rely on
        # its chat template in _format_instruct_prompt.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    @staticmethod
    def _detect_model_type(model_path: str) -> ModelType:
        """
        Heuristically classify the checkpoint from its path/name.

        Any of the keywords below appearing (case-insensitively) anywhere in
        the path marks the model as instruction-tuned; otherwise it is treated
        as a base model. NOTE(review): 'kista' is presumably a project-specific
        model name — confirm.
        """
        model_path_lower = model_path.lower()
        instruct_keywords = ['instruct', 'chat', 'dialogue', 'conversations', 'kista']

        is_instruct = any(keyword in model_path_lower for keyword in instruct_keywords)
        return ModelType.INSTRUCT if is_instruct else ModelType.BASE

    def _format_base_prompt(self, system: Optional[str], content: str) -> str:
        """
        Format a prompt for a base model.

        Base models have no chat template, so the system prompt (if any) is
        simply prepended to the user content, separated by a single space.
        """
        if system:
            return f"{system} {content}"
        return content

    def _format_instruct_prompt(self, system: Optional[str], content: str) -> str:
        """
        Format a prompt for an instruct model via the tokenizer's chat
        template (untokenized, with the generation prompt appended).
        """
        messages: List[Dict[str, str]] = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": content})

        return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

    def _create_message_payload(self,
                              system: Optional[str],
                              content: str,
                              max_tokens: int,
                              temperature: float) -> Dict[str, Any]:
        """
        Build the formatted prompt and SamplingParams for one request.

        Returns:
            Dict with keys ``"prompt"`` (str) and ``"sampling_params"``
            (vllm.SamplingParams).
        """
        if self.model_type == ModelType.BASE:
            formatted_prompt = self._format_base_prompt(system, content)
        else:
            formatted_prompt = self._format_instruct_prompt(system, content)

        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=0.95,
            presence_penalty=0.0,
            frequency_penalty=0.0,
        )

        return {
            "prompt": formatted_prompt,
            "sampling_params": sampling_params
        }

    def send_message(self,
                    content: str,
                    system: Optional[str] = None,
                    max_tokens: int = 1000,
                    temperature: float = 0.0) -> Dict[str, Any]:
        """
        Send a message to the model and get a response.

        Args:
            content: User message or raw prompt
            system: System prompt (supported for both base and instruct models)
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature

        Returns:
            ``{'status': True, 'result': <generated text>}`` on success
            (or the raw outputs if they have an unexpected shape), otherwise
            ``{'status': False, 'error': <message>}``.
        """
        try:
            payload = self._create_message_payload(
                system=system,
                content=content,
                max_tokens=max_tokens,
                temperature=temperature
            )

            outputs = self.llm.generate(
                prompts=[payload["prompt"]],
                sampling_params=payload["sampling_params"]
            )

            try:
                result_text = outputs[0].outputs[0].text.strip()
            except (IndexError, AttributeError):
                # Unexpected output shape: surface the raw outputs rather
                # than failing, preserving the original best-effort behavior.
                return {'status': True, 'result': outputs}
            return {'status': True, 'result': result_text}

        except Exception as e:
            # Boundary catch: generation errors are reported, not raised.
            return {'status': False, 'error': str(e)}