add stream pipeline
Browse files- aimakerspace/openai_utils/chatmodel.py +7 -0
- app.py +17 -1
aimakerspace/openai_utils/chatmodel.py
CHANGED
|
@@ -25,3 +25,10 @@ class ChatOpenAI:
|
|
| 25 |
return response.choices[0].message.content
|
| 26 |
|
| 27 |
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
return response.choices[0].message.content
|
| 26 |
|
| 27 |
return response
|
| 28 |
+
|
| 29 |
+
def run_stream(self, messages, settings, chainlit_msg, text_only: bool = True):
|
| 30 |
+
async for stream_resp in await openai.Completion.acreate(
|
| 31 |
+
model=self.model_name, prompt=messages, stream=True, **settings
|
| 32 |
+
):
|
| 33 |
+
token = stream_resp.get("choices")[0].get("text")
|
| 34 |
+
await chainlit_msg.stream_token(token)
|
app.py
CHANGED
|
@@ -76,6 +76,19 @@ class RetrievalAugmentedQAPipeline:
|
|
| 76 |
|
| 77 |
return self.llm.run([formatted_system_prompt, formatted_user_prompt])
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
@cl.on_chat_start # marks a function that will be executed at the start of a user session
|
| 81 |
def start_chat():
|
|
@@ -97,7 +110,10 @@ def start_chat():
|
|
| 97 |
async def main(message: str):
|
| 98 |
|
| 99 |
qaPipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db, llm=chat_openai)
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
|
| 103 |
|
|
|
|
| 76 |
|
| 77 |
return self.llm.run([formatted_system_prompt, formatted_user_prompt])
|
| 78 |
|
| 79 |
+
def stream_pipeline(self, user_query: str, msg: cl.Message) -> str:
|
| 80 |
+
context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
|
| 81 |
+
|
| 82 |
+
context_prompt = ""
|
| 83 |
+
for context in context_list:
|
| 84 |
+
context_prompt += context[0] + "\n"
|
| 85 |
+
|
| 86 |
+
formatted_system_prompt = raqa_prompt.create_message(context=context_prompt)
|
| 87 |
+
|
| 88 |
+
formatted_user_prompt = user_prompt.create_message(user_query=user_query)
|
| 89 |
+
|
| 90 |
+
self.llm.stream([formatted_system_prompt, formatted_user_prompt])
|
| 91 |
+
|
| 92 |
|
| 93 |
@cl.on_chat_start # marks a function that will be executed at the start of a user session
|
| 94 |
def start_chat():
|
|
|
|
| 110 |
async def main(message: str):
|
| 111 |
|
| 112 |
qaPipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db, llm=chat_openai)
|
| 113 |
+
msg = cl.Message(content="")
|
| 114 |
+
|
| 115 |
+
qaPipeline.stream_pipeline(user_query=message, msg=msg)
|
| 116 |
+
await msg.send()
|
| 117 |
|
| 118 |
|
| 119 |
|