File size: 6,270 Bytes
e6e69dc
 
5a768b8
 
e6e69dc
5a768b8
e6e69dc
5a768b8
e6e69dc
 
 
 
 
 
 
 
 
 
 
5a768b8
 
 
 
e6e69dc
5a768b8
 
 
 
e746978
 
e6e69dc
 
 
 
 
 
 
 
 
 
d446300
 
 
e6e69dc
 
 
 
 
 
 
 
 
 
d446300
 
e6e69dc
 
 
 
d446300
 
4826640
d446300
 
 
 
 
 
 
 
1ef3ac7
d446300
 
 
1b2ae99
d446300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b2ae99
d446300
1b2ae99
 
d446300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b2ae99
d446300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4826640
 
d446300
 
 
 
 
 
 
 
 
 
 
 
1ef3ac7
d446300
 
 
6949531
d446300
 
 
 
 
 
 
 
 
e6e69dc
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import re
import tomllib

import chainlit as cl
from docarray.index.abstract import BaseDocIndex

from azure_openai import AzureOpenaiSettings, AzureOpenaiEmbeddings, patch_chainlit
from chain import Chain
from data import embed, restaurant_index, RestaurantDescription

# Embedding settings are loaded once at import time and reused for every
# query embedding (see search_embeddings below).
embedding_settings = AzureOpenaiEmbeddings.load_from_env().to_settings_dict()

# NOTE(review): presumably monkey-patches chainlit for Azure OpenAI
# compatibility — behavior defined in the local azure_openai module; confirm.
patch_chainlit()


def search_embeddings(query: str, doc_index: BaseDocIndex, limit: int = 5):
    """Embed *query* and return the top matching documents from *doc_index*.

    Args:
        query: Free-text search query to embed.
        doc_index: DocArray index searched on its 'embedding' field.
        limit: Maximum number of documents to return. Defaults to 5,
            matching the previously hard-coded value.

    Returns:
        The matched documents; similarity scores are discarded.
    """
    vec = embed(query, **embedding_settings)
    # find() returns (docs, scores); callers only need the documents.
    docs, _scores = doc_index.find(vec, 'embedding', limit)
    return docs


@cl.on_chat_start
async def start_chat():
    """Reset the per-session conversation history when a new chat begins."""
    fresh_history: list[dict] = []
    cl.user_session.set("history", fresh_history)


@cl.on_message
async def on_message(message: str, message_id: str):
    """Handle one user message end-to-end.

    Pipeline:
      1. Append the message to the session history.
      2. Ask the LLM to condense the conversation into a search query
         (or a rephrase request if the input is nonsense).
      3. Vector-search the restaurant index with that query.
      4. Ask the LLM to mark which candidates fit, as a TOML blob.
      5. Display up to three matching restaurants with booking actions.
    """
    history = cl.user_session.get("history")

    # update history
    history.append({"role": "user", "content": message})

    # build AI response
    chain = Chain(message_id, llm_settings=AzureOpenaiSettings.load_from_env())

    query_msg = await chain.llm(
        """
        You are a conversation summarizer that condenses a conversation between
        a human and a brilliant AI into a search query that can be used to find relevant
        restaurants. If a conversation is "normal" the AI answers it. If the question is 
        "nonsense" the AI says "Please rephrase your question".
        
        Conversation history
        
        ============
        
        {history}
        
        ============
        
        From this conversation, create a search query that would fit the human's needs.
        Do not say anything else; just the query. If a conversation is "normal" the AI answers it. If the question is 
        "nonsense" the AI says "Please rephrase your question".
        """,
        history=format_history(history),
    )

    # If the question is gibberish, stop the querying and make the user rephrase.
    # BUGFIX: the old exact-equality check required a trailing period the prompt
    # never asks for; normalize and match the phrase instead.
    if "please rephrase your question" in query_msg.content.strip().lower():
        response_text = await chain.text("Please rephrase your query.", final=True)
        await response_text.update()
        return

    results = search_embeddings(query_msg.content, restaurant_index)

    await chain.text(str(list(results)))  # TODO maybe json format would be better?

    restaurants = "\n".join(f"- ID: {r.id} | {r.text}" for r in results)

    final_choices_msg = await chain.llm(
        """
        You are a search engine for restaurants.
        Output the restaurant IDs for the best matches to the following query:
        
        ----
        
        {query}
        
        ============
        
        
        List of restaurants
        
        ----
        
        {restaurants}
        
        ============
        
        
        Output your final answer as a TOML blob.
        Each restaurant should have a key for its ID, with a
        boolean value, where true means the restaurant is a good fit
        for all parts of the query.
        
        For example:
        
        ---
        
        [answer]
        
        101 = false
        
        1350 = true
        
        02458 = false
        
        9315 = true
        
        128974 = true
        
        ============
        
        
        Make include IDs of ALL restaurants, but only mark true for ones that fit the query.
        """,
        query=query_msg.content,
        restaurants=restaurants
    )

    # TOML is easy to write and parse for both machines and humans, but the
    # model sometimes wraps its answer in a markdown code fence — strip it.
    content = final_choices_msg.content
    fence = re.search(r'```(?:\s*toml)?\s*(.*?)\s*```', content, re.DOTALL)
    toml_string = fence.group(1) if fence else content

    # Require the model to label every ID (not just good ones) so it considers
    # each option; keep only the IDs it marked true.
    obj = tomllib.loads(toml_string)
    final_ids = [rid for rid, val in obj['answer'].items() if val]

    if not final_ids:
        await chain.text("Sorry, no restaurants found. Please try another query.", final=True)
        return

    for i, restaurant_id in enumerate(final_ids[:3]):
        restaurant_id = str(restaurant_id)
        restaurant: RestaurantDescription = restaurant_index[restaurant_id]  # why no automatic typing?
        # Getting dishes and categories from a list form to string
        dishes_as_string = ', '.join(restaurant.dishes)
        categories_as_string = ', '.join(restaurant.categories)
        msg = await chain.text(f"Option {i}", final=True)

        msg.elements = [
            # note: image always displays above text
            cl.Image(name=restaurant.name, url=restaurant.image_url, display='inline', size='small'),
            cl.Text(name=restaurant.name, content=restaurant.intro, display='inline'),
            cl.Text(name="Example Dishes:", content=dishes_as_string, display='inline'),
            cl.Text(name="Category:", content=categories_as_string, display='inline'),
            cl.Text(name="Estimated Average Price (HKD):", content=restaurant.price, display="inline"),
            # TODO text could also include categories/dishes/rating/price/location
        ]
        msg.actions = [
            cl.Action(name='book', value=restaurant_id, label='Book', description='Click to book this restaurant'),
        ]
        await msg.update()

    # TODO what should the history include? ids only? or also descriptions?
    # history.append({"role": "assistant", "content": response.content})

# Maps message roles to the speaker labels used in prompt transcripts.
NAMES = {
    # 'system': '',
    'user': 'Human',
    'assistant': 'AI',
}


def format_history(history: list[dict]) -> str:
    """Render a message list as newline-separated "<Speaker>: <text>" lines.

    Each message dict must carry a 'role' key present in NAMES and a
    'content' key with the message text.
    """
    return "\n".join(
        f'{NAMES[msg["role"]]}: {msg["content"]}' for msg in history
    )