added USE_LLAMA_2_PROMPT_TEMPLATE

Files changed:
- .env.example (+1, -0)
- app_modules/llm_chat_chain.py (+22, -1)
.env.example
@@ -19,6 +19,7 @@ HF_PIPELINE_DEVICE_TYPE=
 # LOAD_QUANTIZED_MODEL=4bit
 # LOAD_QUANTIZED_MODEL=8bit
 
+USE_LLAMA_2_PROMPT_TEMPLATE=true
 DISABLE_MODEL_PRELOADING=true
 CHAT_HISTORY_ENABLED=true
 SHOW_PARAM_SETTINGS=false
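
Note: environment values are plain strings, so the new flag only takes effect when set to the literal value "true". A minimal sketch of the check, mirroring the comparison added in app_modules/llm_chat_chain.py below:

    import os

    # Only the exact string "true" enables the Llama 2 template;
    # "True", "1", or an unset variable leave the default template in place.
    use_llama_2_template = os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"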
app_modules/llm_chat_chain.py
@@ -1,3 +1,5 @@
+import os
+
 from langchain import LLMChain, PromptTemplate
 from langchain.chains import ConversationalRetrievalChain
 from langchain.chains.base import Chain
@@ -6,19 +8,38 @@ from langchain.memory import ConversationBufferMemory
 from app_modules.llm_inference import LLMInference
 
 
+def get_llama_2_prompt_template():
+    B_INST, E_INST = "[INST]", "[/INST]"
+    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+    instruction = "Chat History:\n\n{chat_history} \n\nUser: {question}"
+    system_prompt = "You are a helpful assistant, you always only answer for the assistant then you stop. read the chat history to get context"
+
+    SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
+    prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
+    return prompt_template
+
+
 class ChatChain(LLMInference):
     def __init__(self, llm_loader):
         super().__init__(llm_loader)
 
     def create_chain(self) -> Chain:
-        template = """You are a chatbot having a conversation with a human.
+        template = (
+            get_llama_2_prompt_template()
+            if os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
+            else """You are a chatbot having a conversation with a human.
 {chat_history}
 Human: {question}
 Chatbot:"""
+        )
+
+        print(f"template: {template}")
 
         prompt = PromptTemplate(
             input_variables=["chat_history", "question"], template=template
         )
+
         memory = ConversationBufferMemory(memory_key="chat_history")
 
         llm_chain = LLMChain(
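
For reference, a standalone sketch (not part of the commit) that reproduces the string get_llama_2_prompt_template() builds and renders it with sample values; the sample chat_history and question are illustrative only, since at runtime LLMChain fills them in through the PromptTemplate:

    # Reproduces the helper added above, outside the class for quick inspection.
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    instruction = "Chat History:\n\n{chat_history} \n\nUser: {question}"
    system_prompt = (
        "You are a helpful assistant, you always only answer for the assistant "
        "then you stop. read the chat history to get context"
    )

    template = B_INST + B_SYS + system_prompt + E_SYS + instruction + E_INST

    # Prints a single-turn Llama 2 prompt of the form:
    # [INST]<<SYS>>
    # You are a helpful assistant, ... read the chat history to get context
    # <</SYS>>
    #
    # Chat History:
    #
    # Human: hi
    # AI: hello
    #
    # User: How are you?[/INST]
    print(template.format(chat_history="Human: hi\nAI: hello", question="How are you?"))

Note the trailing space in the instruction string after {chat_history}; it is carried through into the rendered prompt.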