Milch committed on
Commit ·
15cd2b4
1
Parent(s): 0d4c591
Update main.py
Browse files
main.py
CHANGED
|
@@ -33,6 +33,7 @@ file_handler.setFormatter(formatter)
|
|
| 33 |
logger = logging.getLogger('gunicorn.error')
|
| 34 |
logger.addHandler(file_handler)
|
| 35 |
|
|
|
|
| 36 |
model_path = os.environ.get('LLAMACPP_PATH', None)
|
| 37 |
|
| 38 |
|
|
@@ -75,19 +76,35 @@ async def read_device():
|
|
| 75 |
def create_generated_text(messages: list[dict[str, str]] = Body(...), temperature: float = Body(default=1.0)):
|
| 76 |
input_text = ''
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
if len(input_text) > 0:
|
| 85 |
llm = Llama(model_path=model_path, n_ctx=8192, n_gpu_layers=-1, n_batch=32, verbose=False)
|
| 86 |
choices = []
|
| 87 |
|
| 88 |
try:
|
| 89 |
-
for choice in llm(input_text
|
| 90 |
-
matches = re.findall(
|
| 91 |
|
| 92 |
if len(matches) > 0:
|
| 93 |
choices.append({'role': 'assistant', 'content': matches[len(matches) - 1]})
|
|
|
|
| 33 |
logger = logging.getLogger('gunicorn.error')
|
| 34 |
logger.addHandler(file_handler)
|
| 35 |
|
| 36 |
+
llm_prompt_format = os.getenv('LLM_PROMPT_FORMAT', None)
|
| 37 |
model_path = os.environ.get('LLAMACPP_PATH', None)
|
| 38 |
|
| 39 |
|
|
|
|
| 76 |
def create_generated_text(messages: list[dict[str, str]] = Body(...), temperature: float = Body(default=1.0)):
|
| 77 |
input_text = ''
|
| 78 |
|
| 79 |
+
if llm_prompt_format == 'Llama':
|
| 80 |
+
for message in messages:
|
| 81 |
+
if message['role'] == 'system':
|
| 82 |
+
input_text += f"<|start_header_id|>system<|end_header_id|>\n\n{message['content']}<|eot_id|>"
|
| 83 |
+
elif message['role'] == 'user':
|
| 84 |
+
input_text += f"<|start_header_id|>user<|end_header_id|>\n\n{message['content']}<|eot_id|>"
|
| 85 |
+
elif message['role'] == 'assistant':
|
| 86 |
+
input_text += f"<|start_header_id|>assistant<|end_header_id|>\n\n{message['content']}<|eot_id|>"
|
| 87 |
+
|
| 88 |
+
input_text += '<|start_header_id|>assistant<|end_header_id|>\n\n'
|
| 89 |
+
pattern = r'<|start_header_id|>assistant<|end_header_id|>\n\n(.+?)(?:(?:<|eot_id|>)|$)'
|
| 90 |
+
|
| 91 |
+
else:
|
| 92 |
+
for message in messages:
|
| 93 |
+
if message['role'] == 'system' or message['role'] == 'user':
|
| 94 |
+
input_text += f"<start_of_turn>user\n{message['content']}<end_of_turn>\n"
|
| 95 |
+
elif message['role'] == 'assistant':
|
| 96 |
+
input_text += f"<start_of_turn>model\n{message['content']}<end_of_turn>\n"
|
| 97 |
+
|
| 98 |
+
input_text += '<start_of_turn>model\n'
|
| 99 |
+
pattern = r'<start_of_turn>model\n(.+?)(?:(?:<end_of_turn>)|$)'
|
| 100 |
|
| 101 |
if len(input_text) > 0:
|
| 102 |
llm = Llama(model_path=model_path, n_ctx=8192, n_gpu_layers=-1, n_batch=32, verbose=False)
|
| 103 |
choices = []
|
| 104 |
|
| 105 |
try:
|
| 106 |
+
for choice in llm(input_text, max_tokens=2048, temperature=temperature, top_p=0.95, echo=True)['choices']:
|
| 107 |
+
matches = re.findall(pattern, choice['text'], re.DOTALL)
|
| 108 |
|
| 109 |
if len(matches) > 0:
|
| 110 |
choices.append({'role': 'assistant', 'content': matches[len(matches) - 1]})
|