restaurants / app.py
briankchan's picture
Merge remote-tracking branch 'origin/main'
508c1af
import re
import tomllib
import chainlit as cl
from docarray.index.abstract import BaseDocIndex
from azure_openai import AzureOpenaiSettings, AzureOpenaiEmbeddings, patch_chainlit
from chain import Chain
from data import embed, restaurant_index, RestaurantDescription
# Built once at import time; these are the keyword arguments later splatted into
# `embed(...)` by `search_embeddings`.
# NOTE(review): presumably reads Azure OpenAI credentials from environment
# variables (per the `load_from_env` name) -- confirm in azure_openai.
embedding_settings = AzureOpenaiEmbeddings.load_from_env().to_settings_dict()
# Project-local patch imported from `azure_openai`, applied at import time before
# any Chainlit handler below is registered. What it patches is not visible here.
patch_chainlit()
def search_embeddings(query: str, doc_index: BaseDocIndex, limit: int = 5):
    """Embed ``query`` and return the matching documents from ``doc_index``.

    The query is embedded with ``embed`` using the module-level
    ``embedding_settings`` and searched against the index's ``embedding``
    field.

    Args:
        query: Free-text search query.
        doc_index: Index to search.
        limit: Maximum number of documents requested from the index.
            Defaults to 5, the value that used to be hard-coded.

    Returns:
        The documents reported by ``doc_index.find``; the accompanying
        scores are discarded.
    """
    vec = embed(query, **embedding_settings)
    # `find` yields a (documents, scores) pair; only the documents are needed.
    docs, _scores = doc_index.find(vec, 'embedding', limit)
    return docs
@cl.on_chat_start
async def start_chat():
    """Seed the user session with an empty conversation history."""
    fresh_history = []
    cl.user_session.set("history", fresh_history)
# Phrase (lower-cased, no punctuation) that the summarizer prompt instructs the
# model to emit when the user's input is nonsense.
_REPHRASE_MARKER = "please rephrase your question"

# Optional Markdown code fence (``` or ```toml) the model may wrap its TOML in.
_TOML_FENCE = re.compile(r"```(?:\s*toml)?\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)

# Prompt that condenses the conversation history into a single search query.
_QUERY_PROMPT = """
You are a conversation summarizer that condenses a conversation between
a human and a brilllian AI into a search query that can be used to find relevant
restaurants. If a conversation is "normal" the AI answers it. If the question is
"nonsense" the AI says "Please rephrase your question".
Conversation history
============
{history}
============
From this conversation, create a search query that would fit the human's needs.
Do not say anything else; just the query. If a conversation is "normal" the AI answers it. If the question is
"nonsense" the AI says "Please rephrase your question".
"""

# Prompt that asks the model to mark every candidate restaurant true/false.
# Every option is listed (not just the good ones) so the model has to consider
# each candidate, and TOML is used instead of JSON because curly braces would
# clash with the template placeholders.
_CHOICE_PROMPT = """
You are a search engine for restaurants.
Output the restaurant IDs for the best matches to the following query:
----
{query}
============
List of restaurants
----
{restaurants}
============
Output your final answer as a TOML blob.
Each restaurant should have a key for its ID, with a
boolean value, where true means the restaurant is a good fit
for all parts of the query.
For example:
---
[answer]
101 = false
1350 = true
02458 = false
9315 = true
128974 = true
============
Make include IDs of ALL restaurants, but only mark true for ones that fit the query.
"""


def _is_rephrase_request(reply: str) -> bool:
    """Return True when the summarizer reply contains the rephrase marker."""
    # Case-insensitive containment tolerates a trailing period, surrounding
    # quotes or extra words, all of which an exact `==` comparison would miss.
    return _REPHRASE_MARKER in reply.casefold()


def _parse_selected_ids(reply: str) -> list[str]:
    """Extract the IDs marked truthy in the ``[answer]`` table of a TOML reply.

    Returns an empty list when the reply is not valid TOML or has no
    ``answer`` table, instead of raising.
    """
    fenced = _TOML_FENCE.search(reply)
    toml_text = fenced.group(1) if fenced else reply
    try:
        answer = tomllib.loads(toml_text).get("answer", {})
    except tomllib.TOMLDecodeError:
        return []
    if not isinstance(answer, dict):
        return []
    return [str(key) for key, keep in answer.items() if keep]


@cl.on_message
async def on_message(message: str, message_id: str):
    """Handle one user message: build a search query, search, show options.

    Steps:
      1. Append the message to the session history and ask the LLM to turn
         the history into a restaurant search query.
      2. If the LLM asks for a rephrase, relay that to the user and stop.
      3. Otherwise run an embedding search, ask the LLM which of the hits
         fit the query, and display up to three of them with a "Book" action.

    Args:
        message: Text the user sent.
        message_id: Identifier passed through to ``Chain``.
    """
    history = cl.user_session.get("history")
    history.append({"role": "user", "content": message})

    chain = Chain(message_id, llm_settings=AzureOpenaiSettings.load_from_env())
    query_msg = await chain.llm(_QUERY_PROMPT, history=format_history(history))

    # If the question is gibberish, stop and make the user rephrase it.
    if _is_rephrase_request(query_msg.content):
        response_text = await chain.text("Please rephrase your query.", final=True)
        await response_text.update()
        return

    results = search_embeddings(query_msg.content, restaurant_index)
    await chain.text(str(list(results)))  # TODO maybe json format would be better?

    restaurants = "\n".join(f"- ID: {r.id} | {r.text}" for r in results)
    final_choices_msg = await chain.llm(
        _CHOICE_PROMPT,
        query=query_msg.content,
        restaurants=restaurants,
    )

    # Only accept IDs that were actually offered to the model, so an ID it
    # invented or mangled cannot break the index lookup below.
    known_ids = {str(r.id) for r in results}
    final_ids = [
        restaurant_id
        for restaurant_id in _parse_selected_ids(final_choices_msg.content)
        if restaurant_id in known_ids
    ]

    if not final_ids:
        await chain.text("Sorry, no restaurants found. Please try another query.", final=True)
        return

    # Show at most three options, numbered from 1 for the user.
    for option_number, restaurant_id in enumerate(final_ids[:3], start=1):
        restaurant: RestaurantDescription = restaurant_index[restaurant_id]

        # Dishes and categories are lists; flatten them for display.
        dishes_as_string = ', '.join(restaurant.dishes)
        categories_as_string = ', '.join(restaurant.categories)

        msg = await chain.text(f"Option {option_number}", final=True)
        msg.elements = [
            # note: image always displays above text
            cl.Image(name=restaurant.name, url=restaurant.image_url, display='inline', size='small'),
            cl.Text(name=restaurant.name, content=restaurant.intro, display='inline'),
            cl.Text(name="Example Dishes:", content=dishes_as_string, display='inline'),
            cl.Text(name="Category:", content=categories_as_string, display='inline'),
            cl.Text(name="Estimated Average Price (HKD):", content=restaurant.price, display="inline"),
            # TODO text could also include categories/dishes/rating/price/location
        ]
        msg.actions = [
            cl.Action(name='book', value=restaurant_id, label='Book', description='Click to book this restaurant'),
        ]
        await msg.update()

    # TODO what should the history include? ids only? or also descriptions?
    # (The assistant's reply is currently not appended to the history.)
# Display names for chat roles when rendering the history into a prompt.
NAMES = {
    # 'system': '',
    'user': 'Human',
    'assistant': 'AI',
}


def format_history(history: list[dict]) -> str:
    """Format a list of chat messages into a single newline-joined string.

    Each message is rendered as ``"<speaker>: <content>"``, where the speaker
    is looked up in ``NAMES``. A role with no entry in ``NAMES`` is rendered
    under its raw role name rather than raising ``KeyError``.

    Args:
        history: Messages, each a dict with ``"role"`` and ``"content"`` keys.

    Returns:
        One line per message, or an empty string for an empty history.
    """
    return "\n".join(
        f'{NAMES.get(m["role"], m["role"])}: {m["content"]}' for m in history
    )