Milch committed on
Commit
15cd2b4
·
1 Parent(s): 0d4c591

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +24 -7
main.py CHANGED
@@ -33,6 +33,7 @@ file_handler.setFormatter(formatter)
33
  logger = logging.getLogger('gunicorn.error')
34
  logger.addHandler(file_handler)
35
 
 
36
  model_path = os.environ.get('LLAMACPP_PATH', None)
37
 
38
 
@@ -75,19 +76,35 @@ async def read_device():
75
  def create_generated_text(messages: list[dict[str, str]] = Body(...), temperature: float = Body(default=1.0)):
76
  input_text = ''
77
 
78
- for message in messages:
79
- if message['role'] == 'system' or message['role'] == 'user':
80
- input_text += f"<start_of_turn>user\n{message['content']}<end_of_turn>\n"
81
- elif message['role'] == 'assistant':
82
- input_text += f"<start_of_turn>model\n{message['content']}<end_of_turn>\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  if len(input_text) > 0:
85
  llm = Llama(model_path=model_path, n_ctx=8192, n_gpu_layers=-1, n_batch=32, verbose=False)
86
  choices = []
87
 
88
  try:
89
- for choice in llm(input_text + '<start_of_turn>model\n', max_tokens=2048, temperature=temperature, top_p=0.95, echo=True)['choices']:
90
- matches = re.findall(r'<start_of_turn>model\n(.+?)(?:(?:<end_of_turn>)|$)', choice['text'], re.DOTALL)
91
 
92
  if len(matches) > 0:
93
  choices.append({'role': 'assistant', 'content': matches[len(matches) - 1]})
 
33
  logger = logging.getLogger('gunicorn.error')
34
  logger.addHandler(file_handler)
35
 
36
+ llm_prompt_format = os.getenv('LLM_PROMPT_FORMAT', None)
37
  model_path = os.environ.get('LLAMACPP_PATH', None)
38
 
39
 
 
76
  def create_generated_text(messages: list[dict[str, str]] = Body(...), temperature: float = Body(default=1.0)):
77
  input_text = ''
78
 
79
+ if llm_prompt_format == 'Llama':
80
+ for message in messages:
81
+ if message['role'] == 'system':
82
+ input_text += f"<|start_header_id|>system<|end_header_id|>\n\n{message['content']}<|eot_id|>"
83
+ elif message['role'] == 'user':
84
+ input_text += f"<|start_header_id|>user<|end_header_id|>\n\n{message['content']}<|eot_id|>"
85
+ elif message['role'] == 'assistant':
86
+ input_text += f"<|start_header_id|>assistant<|end_header_id|>\n\n{message['content']}<|eot_id|>"
87
+
88
+ input_text += '<|start_header_id|>assistant<|end_header_id|>\n\n'
89
+ pattern = r'<|start_header_id|>assistant<|end_header_id|>\n\n(.+?)(?:(?:<|eot_id|>)|$)'
90
+
91
+ else:
92
+ for message in messages:
93
+ if message['role'] == 'system' or message['role'] == 'user':
94
+ input_text += f"<start_of_turn>user\n{message['content']}<end_of_turn>\n"
95
+ elif message['role'] == 'assistant':
96
+ input_text += f"<start_of_turn>model\n{message['content']}<end_of_turn>\n"
97
+
98
+ input_text += '<start_of_turn>model\n'
99
+ pattern = r'<start_of_turn>model\n(.+?)(?:(?:<end_of_turn>)|$)'
100
 
101
  if len(input_text) > 0:
102
  llm = Llama(model_path=model_path, n_ctx=8192, n_gpu_layers=-1, n_batch=32, verbose=False)
103
  choices = []
104
 
105
  try:
106
+ for choice in llm(input_text, max_tokens=2048, temperature=temperature, top_p=0.95, echo=True)['choices']:
107
+ matches = re.findall(pattern, choice['text'], re.DOTALL)
108
 
109
  if len(matches) > 0:
110
  choices.append({'role': 'assistant', 'content': matches[len(matches) - 1]})