syberWolf committed on
Commit
b3aebd1
·
1 Parent(s): c145e37

Add handler change

Browse files
Files changed (1) hide show
  1. handler.py +88 -74
handler.py CHANGED
@@ -1,96 +1,110 @@
1
  from llama_cpp import Llama
2
- from typing import Dict, List, Any
3
  import os
4
 
5
 
6
  class EndpointHandler:
 
 
 
 
 
 
 
 
 
7
  def __init__(self, model_path=""):
8
- # Construct the model path assuming the model is in the same directory as the handler file
9
- script_dir = os.path.dirname(os.path.abspath(__file__))
10
- model_filename = "Phi-3-medium-128k-instruct-IQ2_XS.gguf"
11
- self.model_path = os.path.join(script_dir, model_filename)
12
-
13
- # Check if the model file exists
14
- if not os.path.exists(self.model_path):
15
- raise ValueError(f"Model path does not exist: {self.model_path}")
16
-
17
- # Load the GGUF model using llama_cpp
18
- self.llm = Llama(
19
- model_path=self.model_path,
20
- n_ctx=5000, # Set context length to 5000 tokens
21
- # n_threads=12, # Adjust the number of CPU threads as per your machine
22
- n_gpu_layers=-1 # Adjust based on GPU availability
23
- )
24
-
25
- # Define generation kwargs for the model
26
- self.generation_kwargs = {
27
- "max_tokens": 400, # Respond with up to 400 tokens
28
- "stop": ["<|end|>", "<|user|>", "<|assistant|>"],
29
- "top_k": 1 # Greedy decoding
30
- }
31
-
32
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
 
 
33
  """
34
  Data args:
35
  inputs (:obj:`dict`): The input prompts for the LLM including system instructions and user messages.
36
-
 
37
  Return:
38
  A :obj:`list` | `dict`: will be serialized and returned.
39
  """
40
- # Extract inputs
41
- inputs = data.get("inputs", {})
42
- system_instructions = inputs.get("system", "")
43
- user_message = inputs.get("message", "")
44
-
45
- if not user_message:
46
- raise ValueError("No user message provided for the model.")
47
-
48
- # Combine system instructions and user message
49
- final_input = f"{system_instructions}\n{user_message}"
50
-
51
- # Run inference with llama_cpp
52
- response = self.llm.create_chat_completion(
53
- messages=[
54
- {"role": "system", "content": system_instructions},
55
- {"role": "user", "content": user_message}
56
- ],
57
- **self.generation_kwargs
58
- )
59
-
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  # Access generated text based on the response structure
61
  try:
62
  generated_text = response["choices"][0]["message"].get("content", "")
63
  except (KeyError, IndexError):
64
  raise ValueError("Unexpected response structure: missing 'content' in 'choices[0]['message']'")
65
-
66
  # Return the generated text
67
  return [{"generated_text": generated_text}]
68
 
69
 
70
- # Example usage:
 
 
 
 
 
 
 
 
 
 
 
 
71
  if __name__ == "__main__":
72
- # Instantiate the handler ONCE
73
- handler = EndpointHandler()
74
-
75
- # Handlers can be called multiple times with different inputs and the model will remain in memory
76
- data1 = {
77
- "inputs": {
78
- "system": "You are a helpful assistant.",
79
- "message": "What is the meaning of life?"
80
- }
81
- }
82
-
83
- data2 = {
84
- "inputs": {
85
- "system": "You are a knowledgeable assistant.",
86
- "message": "Tell me about the history of the internet."
87
- }
88
- }
89
-
90
- # First call - model already in memory
91
- response1 = handler(data1)
92
- print(response1)
93
-
94
- # Second call - model still in memory
95
- response2 = handler(data2)
96
- print(response2)
 
1
  from llama_cpp import Llama
2
+ from typing import Dict, List, Any, Union
3
  import os
4
 
5
 
6
class EndpointHandler:
    """Singleton inference handler around a local GGUF model served via llama_cpp.

    The model is loaded once per process: the first construction initialises the
    ``Llama`` instance, and every later ``EndpointHandler(...)`` call returns the
    same object with the model already resident in memory.
    """

    _instance = None       # the single shared instance
    _model_loaded = False  # True once the Llama model has been initialised

    def __new__(cls, *args, **kwargs):
        # BUG FIX: object.__new__ accepts no extra arguments, so forwarding
        # *args/**kwargs raised TypeError whenever model_path was supplied.
        # Create the bare instance without arguments; __init__ receives them.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._model_loaded = False
        return cls._instance

    def __init__(self, model_path=""):
        """Load the GGUF model on first construction only.

        Args:
            model_path: Accepted for interface compatibility; the actual path is
                derived from this file's directory (see below).

        Raises:
            ValueError: If the expected model file does not exist on disk.
        """
        # Singleton guard: skip re-initialisation on every later construction.
        if self._model_loaded:
            return

        # The GGUF file is expected to sit next to this handler file.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        model_filename = "Phi-3-medium-128k-instruct-IQ2_XS.gguf"
        self.model_path = os.path.join(script_dir, model_filename)

        # Fail fast with a clear message rather than inside llama_cpp.
        if not os.path.exists(self.model_path):
            raise ValueError(f"Model path does not exist: {self.model_path}")

        # Load the GGUF model using llama_cpp.
        self.llm = Llama(
            model_path=self.model_path,
            n_ctx=5000,       # context window of 5000 tokens
            # n_threads=12,   # tune CPU thread count per machine
            n_gpu_layers=-1,  # offload all layers to GPU when available
        )

        # Generation parameters shared by every completion request.
        self.generation_kwargs = {
            "max_tokens": 400,  # respond with up to 400 tokens
            "stop": ["<|end|>", "<|user|>", "<|assistant|>"],
            "top_k": 1,         # greedy decoding
        }

        self._model_loaded = True

    def __call__(self, data: Union[Dict[str, Any], str]) -> List[Dict[str, Any]]:
        """Run one chat completion.

        Data args:
            inputs (:obj:`dict`): The input prompts for the LLM including system
                instructions (``"system"``) and the user message (``"message"``).
            str: A string input which will create a chat completion with the
                string as the sole user message.

        Return:
            A :obj:`list` | `dict`: will be serialized and returned, e.g.
            ``[{"generated_text": "..."}]``.

        Raises:
            ValueError: If the dict carries no user message, if ``data`` is
                neither dict nor str, or if the llama_cpp response is missing
                the expected ``choices[0]["message"]["content"]`` structure.
        """
        if isinstance(data, dict):
            # Structured request: pull system instructions and user message.
            inputs = data.get("inputs", {})
            system_instructions = inputs.get("system", "")
            user_message = inputs.get("message", "")

            if not user_message:
                raise ValueError("No user message provided for the model.")

            # Run inference with llama_cpp.
            response = self.llm.create_chat_completion(
                messages=[
                    {"role": "system", "content": system_instructions},
                    {"role": "user", "content": user_message},
                ],
                **self.generation_kwargs,
            )

        elif isinstance(data, str):
            # Bare string: treat it as a single user message.
            response = self.llm.create_chat_completion(
                messages=[
                    {"role": "user", "content": data}
                ],
                **self.generation_kwargs,
            )

        else:
            raise ValueError("Invalid input type. Expected dict or str, got {}".format(type(data)))

        # Access generated text based on the response structure.
        try:
            generated_text = response["choices"][0]["message"].get("content", "")
        except (KeyError, IndexError):
            raise ValueError("Unexpected response structure: missing 'content' in 'choices[0]['message']'")

        # Return the generated text.
        return [{"generated_text": generated_text}]
94
 
95
 
96
def main():
    """Smoke-test the handler with both supported input forms."""
    handler = EndpointHandler()

    # Structured request: system instructions plus a user message.
    structured = {"inputs": {"system": "System instructions", "message": "Hello, how are you?"}}
    print("Dictionary input result:", handler(structured))

    # Bare-string request: treated as a single user message.
    print("String input result:", handler("Hello, how are you?"))


if __name__ == "__main__":
    main()