briankchan commited on
Commit
e6e69dc
·
1 Parent(s): 5a768b8
Files changed (9) hide show
  1. .gitignore +2 -0
  2. app.py +144 -21
  3. azure_openai.py +98 -0
  4. chain.py +18 -4
  5. data.py +72 -0
  6. launch.py +6 -0
  7. requirements.txt +1 -0
  8. scripts/__init__.py +0 -0
  9. scripts/create_embeddings.py +101 -0
.gitignore CHANGED
@@ -1,4 +1,5 @@
1
  .chainlit/chat.db
 
2
 
3
  # Created by https://www.toptal.com/developers/gitignore/api/python,intellij+all,visualstudiocode
4
  # Edit at https://www.toptal.com/developers/gitignore?templates=python,intellij+all,visualstudiocode
@@ -283,3 +284,4 @@ pyrightconfig.json
283
  .ionide
284
 
285
  # End of https://www.toptal.com/developers/gitignore/api/python,intellij+all,visualstudiocode
 
 
1
  .chainlit/chat.db
2
+ .chainlit/chat_files
3
 
4
  # Created by https://www.toptal.com/developers/gitignore/api/python,intellij+all,visualstudiocode
5
  # Edit at https://www.toptal.com/developers/gitignore?templates=python,intellij+all,visualstudiocode
 
284
  .ionide
285
 
286
  # End of https://www.toptal.com/developers/gitignore/api/python,intellij+all,visualstudiocode
287
+ /.chainlit/chat_files/3105f452-a667-4cd0-b40c-7502c8bc62d2/1bf41236-505a-44b3-9ac4-55efa37a393d.txt
app.py CHANGED
@@ -1,33 +1,156 @@
1
- import asyncio
 
2
 
3
  import chainlit as cl
 
4
 
 
5
  from chain import Chain
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  @cl.on_chat_start
9
  async def start_chat():
10
- chain = Chain(None)
11
- await chain.text("I will count to 5. How many concurrent times should I count?")
12
 
13
 
14
  @cl.on_message
15
  async def on_message(message: str, message_id: str):
16
- chain = Chain(message_id)
17
-
18
- try:
19
- num = int(message)
20
- except ValueError:
21
- await chain.text_stream("Sorry, that doesn't look like an integer to me.", final=True)
22
- return
23
-
24
- if num > 10:
25
- await chain.text_stream("Whoa, let's try a smaller number. (Max 10.)", final=True)
26
- return
27
-
28
- await chain.text("Alright, here we go:")
29
- coroutines = []
30
- for i in range(num):
31
- coroutines.append(chain.text_stream("1 2 3 4 5", delay=1, name=f"Counter {i + 1}"))
32
- await asyncio.gather(*coroutines)
33
- await chain.text_stream("Okay, I'm done counting now.", final=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import tomllib
3
 
4
  import chainlit as cl
5
+ from docarray.index.abstract import BaseDocIndex
6
 
7
+ from azure_openai import AzureOpenaiSettings, AzureOpenaiEmbeddings, patch_chainlit
8
  from chain import Chain
9
+ from data import embed, restaurant_index, RestaurantDescription
10
+
11
# Embedding-API settings (endpoint, deployment, version) shared by every search in this module.
embedding_settings = AzureOpenaiEmbeddings.load_from_env().to_settings_dict()

# Swap Chainlit's built-in /completion playground endpoint for the Azure-aware one.
patch_chainlit()
14
+
15
+
16
def search_embeddings(query: str, doc_index: BaseDocIndex):
    """Embed *query* and return the top-5 nearest documents from *doc_index*."""
    query_vector = embed(query, **embedding_settings)
    matches, _scores = doc_index.find(query_vector, 'embedding', 5)
    return matches
20
 
21
 
22
@cl.on_chat_start
async def start_chat():
    """Initialize a fresh chat session with an empty message history."""
    cl.user_session.set("history", [])
 
25
 
26
 
27
@cl.on_message
async def on_message(message: str, message_id: str):
    """Answer a user message with up to three restaurant recommendations.

    Pipeline: append the message to the session history -> summarize the
    conversation into a search query -> vector-search restaurants -> ask the
    LLM to mark matches as a TOML blob -> render the top three matches with
    image, description and a 'Book' action.
    """
    history = cl.user_session.get("history")

    # update history
    history.append({"role": "user", "content": message})

    # build AI response
    chain = Chain(message_id, llm_settings=AzureOpenaiSettings.load_from_env())

    query_msg = await chain.llm(
        """
        You are a conversation summarizer that condenses a conversation between
        a human and AI into a search query that can be used to find relevant
        restaurants.

        Conversation history

        ============

        {history}

        ============

        From this conversation, create a search query that would fit the human's needs.
        Do not say anything else; just the query.
        """,
        history=format_history(history),
    )

    results = search_embeddings(query_msg.content, restaurant_index)
    await chain.text(str(list(results)))  # TODO maybe json format would be better?

    restaurants = "\n".join(f"- ID: {r.id} | Description: {r.text}" for r in results)

    final_choices_msg = await chain.llm(
        """
        You are a search engine for restaurants.
        Output the restaurant IDs for the best matches to the following query:

        ----

        {query}

        ============


        List of restaurants

        ----

        {restaurants}

        ============


        Output your final answer as a TOML blob.
        Each restaurant should have a key for its ID, with a
        boolean value, where true means the restaurant is a good fit.

        For example:

        ---

        [answer]

        101 = false

        1350 = true

        02458 = false

        9315 = true

        128974 = true

        ============


        Make sure to include the IDs of ALL restaurants, but only mark true for ones that fit the query.

        """,
        query=query_msg.content,
        restaurants=restaurants
    )

    # Models often wrap the TOML in a ```toml fence; unwrap it when present so
    # tomllib doesn't choke on the backticks.
    match = re.match(r'```\s*toml\s*(.*)\s*```', final_choices_msg.content, re.DOTALL)
    toml_string = match.group(1) if match else final_choices_msg.content

    # don't output just the good values, since GPT doesn't think about each option
    # don't use json because curly braces brakes the template code...
    # TOML is easy to write and parse for both machines and humans :)
    obj = tomllib.loads(toml_string)
    final_ids = [rid for rid, val in obj['answer'].items() if val]

    # 1-based labels: users expect "Option 1", not "Option 0".
    for i, rid in enumerate(final_ids[:3], start=1):
        rid = str(rid)
        restaurant: RestaurantDescription = restaurant_index[rid]  # why no automatic typing?
        msg = await chain.text(f"Option {i}", final=True)
        msg.elements = [
            # note: image always displays above text
            cl.Image(name=restaurant.name, url=restaurant.image_url, display='inline', size='small'),
            cl.Text(name=restaurant.name, content=restaurant.text, display='inline'),
            # TODO text could also include categories/dishes/rating/price
        ]
        msg.actions = [
            cl.Action(name='book', value=rid, label='Book', description='Click to book this restaurant'),
        ]
        await msg.update()

    # TODO what should the history include? ids only? or also descriptions?
    # history.append({"role": "assistant", "content": response.content})
+
145
+
146
+ NAMES = {
147
+ # 'system': '',
148
+ 'user': 'Human',
149
+ 'assistant': 'AI',
150
+ }
151
+
152
+
153
+ def format_history(history: list[dict]) -> str:
154
+ """Formats list of messages into a single string."""
155
+ strings = [f'{NAMES[m["role"]]}: {m["content"]}' for m in history]
156
+ return "\n".join(strings)
azure_openai.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Self
3
+
4
+ from chainlit import LLMSettings
5
+ from chainlit.telemetry import trace_event
6
+ from chainlit.types import CompletionRequest
7
+ from pydantic.dataclasses import dataclass
8
+ from starlette.responses import PlainTextResponse
9
+
10
+
11
@dataclass
class AzureOpenaiSettings(LLMSettings):
    """Chainlit LLMSettings extended with Azure OpenAI connection fields."""
    api_type: str = 'azure'
    api_base: str = ''  # Azure resource endpoint URL
    engine: str = ''  # Azure deployment name
    api_version: str = '2023-05-15'

    def to_settings_dict(self):
        # Base LLM settings plus the Azure fields, shaped as kwargs for openai.* calls.
        return {
            **super().to_settings_dict(),
            "api_type": self.api_type,
            "api_base": self.api_base,
            "api_version": self.api_version,
            "engine": self.engine,
        }

    @classmethod
    def load_from_env(cls: type[Self], *args, **kwargs) -> Self:
        """Build settings from AZURE_OPENAI_* environment variables.

        NOTE(review): unset env vars yield None despite the `str` annotations —
        confirm that is intended.
        """
        return cls(
            *args,
            api_type='azure',
            api_base=os.environ.get('AZURE_OPENAI_ENDPOINT'),
            engine=os.environ.get('AZURE_OPENAI_DEPLOYMENT'),
            api_version=os.environ.get('AZURE_OPENAI_VERSION', '2023-05-15'),
            **kwargs,
        )
37
+
38
+
39
@dataclass
class AzureOpenaiEmbeddings:
    """Connection settings for an Azure OpenAI embeddings deployment."""
    api_type: str = 'azure'
    api_base: str = ''
    engine: str = ''
    api_version: str = '2023-05-15'

    def to_settings_dict(self):
        """Return the fields as keyword arguments for openai.Embedding.create."""
        return {
            "api_type": self.api_type,
            "api_base": self.api_base,
            "api_version": self.api_version,
            "engine": self.engine,
        }

    @classmethod
    def load_from_env(cls: type[Self], *args, **kwargs) -> Self:
        """Build settings from the AZURE_OPENAI_* environment variables."""
        env = os.environ
        return cls(
            *args,
            api_type='azure',
            api_base=env.get('AZURE_OPENAI_ENDPOINT'),
            engine=env.get('AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT'),
            api_version=env.get('AZURE_OPENAI_VERSION', '2023-05-15'),
            **kwargs,
        )
64
+
65
+
66
def patch_chainlit():
    """Replace Chainlit's prompt-playground /completion endpoint with an Azure-aware one.

    Removes the stock route from the FastAPI router, then registers a new
    handler that forwards the hard-coded Azure settings to the OpenAI client.
    """
    from chainlit.server import app

    # replace playground's completion endpoint with one that uses custom openai settings
    app.router.routes = list(filter(lambda route: route.path != '/completion', app.router.routes))

    @app.post("/completion")
    async def completion(request: CompletionRequest):
        """Handle a completion request from the prompt playground."""

        import openai

        trace_event("completion")

        # Prefer a per-user key supplied with the request; fall back to the process env.
        api_key = request.userEnv.get("OPENAI_API_KEY", os.environ.get("OPENAI_API_KEY"))

        stop = request.settings.stop
        # OpenAI doesn't support an empty stop array, clear it
        if isinstance(stop, list) and len(stop) == 0:
            stop = None

        response = await openai.ChatCompletion.acreate(
            api_key=api_key,
            messages=[{"role": "user", "content": request.prompt}],
            stop=stop,
            # **completion.settings.to_settings_dict(),
            # HACK: hard-code llm settings
            **dict(api_type='azure', api_base=os.environ.get('AZURE_OPENAI_ENDPOINT'),
                   engine=os.environ.get('AZURE_OPENAI_DEPLOYMENT'),
                   api_version=os.environ.get('AZURE_OPENAI_VERSION', '2023-05-15')
                   ),
        )
        return PlainTextResponse(content=response["choices"][0]["message"]["content"])
chain.py CHANGED
@@ -1,6 +1,7 @@
1
  import asyncio
2
  import os
3
  import re
 
4
 
5
  import chainlit as cl
6
  import openai
@@ -8,6 +9,14 @@ from chainlit import LLMSettings
8
  from chainlit.config import config
9
 
10
 
 
 
 
 
 
 
 
 
11
  # TODO each chain should be able to make a child chain?
12
  # root = Chain()
13
  # first = root.child("something")
@@ -26,11 +35,12 @@ class Chain:
26
  **kwargs,
27
  )
28
 
29
- async def text(self, text, final=False, name=None):
30
  message = self.make_message(content=text, final=final, name=name)
31
  await message.send()
 
32
 
33
- async def text_stream(self, text: str, delay=.1, name=None, final=False):
34
  message = self.make_message(content='', final=final, name=name)
35
  tokens = text.split(" ")
36
  first = True
@@ -41,8 +51,12 @@ class Chain:
41
  await asyncio.sleep(delay)
42
  first = False
43
  await message.send()
 
 
 
 
 
44
 
45
- async def llm(self, template, *args, name=None, final=False, **kwargs) -> str:
46
  variables = re.findall(r'\{(.*?)}', template)
47
  if len(args) > 1:
48
  raise RuntimeError("If there is more than one argument, use kwargs")
@@ -66,4 +80,4 @@ class Chain:
66
  await message.stream_token(token)
67
 
68
  await message.send()
69
- return message.content
 
1
  import asyncio
2
  import os
3
  import re
4
+ from inspect import cleandoc
5
 
6
  import chainlit as cl
7
  import openai
 
9
  from chainlit.config import config
10
 
11
 
12
def replace_newlines(match: re.Match) -> str:
    """re.sub callback for runs of newlines: a single '\\n' becomes a space,
    and a run of n newlines is shortened by one (paragraph break preserved)."""
    run = match.group(0)
    return " " if len(run) <= 1 else run[1:]
18
+
19
+
20
  # TODO each chain should be able to make a child chain?
21
  # root = Chain()
22
  # first = root.child("something")
 
35
  **kwargs,
36
  )
37
 
38
    async def text(self, text, final=False, name=None) -> cl.Message:
        """Send *text* as a single chat message and return the sent cl.Message."""
        message = self.make_message(content=text, final=final, name=name)
        await message.send()
        return message
42
 
43
+ async def text_stream(self, text: str, delay=.1, name=None, final=False) -> cl.Message:
44
  message = self.make_message(content='', final=final, name=name)
45
  tokens = text.split(" ")
46
  first = True
 
51
  await asyncio.sleep(delay)
52
  first = False
53
  await message.send()
54
+ return message
55
+
56
+ async def llm(self, template, *args, name=None, final=False, **kwargs) -> cl.Message:
57
+ template = cleandoc(template)
58
+ template = re.sub('\n+', replace_newlines, template) # remove a newline
59
 
 
60
  variables = re.findall(r'\{(.*?)}', template)
61
  if len(args) > 1:
62
  raise RuntimeError("If there is more than one argument, use kwargs")
 
80
  await message.stream_token(token)
81
 
82
  await message.send()
83
+ return message
data.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Sequence
3
+
4
+ import numpy as np
5
+ import openai
6
+ from docarray.documents import TextDoc
7
+ from docarray.index import InMemoryExactNNIndex
8
+ from docarray.typing import NdArray
9
+
10
+
11
class OpenaiEmbeddingDoc(TextDoc):
    """TextDoc carrying an OpenAI embedding of its `text` field."""
    # None until create_embeddings() has been run over the doc.
    embedding: NdArray[1536] | None

    @staticmethod
    def create_embeddings(docs: Sequence['OpenaiEmbeddingDoc'], **kwargs):
        """Fill `doc.embedding` in place for every doc, batching API requests.

        kwargs are forwarded to openai.Embedding.create (engine, api_type,
        api_base, api_version, ...).  An `api_key` kwarg is used only when the
        OPENAI_API_KEY env var is not set; it is popped so it is never passed
        twice to the API call.
        """
        batch_size = 16  # max inputs per request allowed by azure
        if len(docs) > batch_size:
            for i in range(0, len(docs), batch_size):
                print(f"Processing 16 starting from index {i}")
                OpenaiEmbeddingDoc.create_embeddings(docs[i:i + batch_size], **kwargs)
            return

        texts = [d.text for d in docs]
        # Pop api_key so **kwargs can't collide with the explicit argument.
        fallback_key = kwargs.pop('api_key', None)
        response = openai.Embedding.create(
            input=texts,
            api_key=os.environ.get('OPENAI_API_KEY', fallback_key),
            **kwargs,
        )
        embeddings = response['data']
        assert len(embeddings) == len(docs)
        # Responses carry an 'index' so results can be matched back to inputs.
        for obj in embeddings:
            docs[obj['index']].embedding = np.array(obj['embedding'])
33
+
34
+
35
def embed(text: str, **kwargs) -> np.ndarray:
    """Return the embedding vector for *text* via openai.Embedding.create.

    kwargs are forwarded to the API call (engine, api_type, api_base, ...);
    an `api_key` kwarg is used only when OPENAI_API_KEY is not set, and is
    popped so it cannot be passed to the API twice.
    """
    fallback_key = kwargs.pop('api_key', None)
    response = openai.Embedding.create(
        input=text,
        api_key=os.environ.get('OPENAI_API_KEY', fallback_key),
        **kwargs,
    )
    return np.array(response['data'][0]['embedding'])
42
+
43
+
44
class RestaurantDescription(OpenaiEmbeddingDoc):
    """A restaurant record; the inherited `text` field holds its intro blurb."""
    id: str = ''  # a number string
    name: str
    name_alt: str | None  # alternate-language name, if any
    categories: list[str]
    dishes: list[str]
    rating: float  # 0-1
    price: int  # HKD
    info_url: str
    image_url: str
54
+
55
+
56
class Category(OpenaiEmbeddingDoc):
    """A cuisine/category keyword; the inherited `text` field holds its name."""
    id: str = ''  # same as text
    restaurants: list[str]  # list of ids? or we could just search the restaurants?
59
+
60
+
61
class Dish(OpenaiEmbeddingDoc):
    """
    A dish name; the inherited `text` field holds its name.

    Note: Not all dish names are meaningful, e.g., 'Trip to Bali', 'Oakland Breeze'
    May include duplicates?
    """
    id: str = ''  # same as text
    restaurants: list[str]  # list of ids
68
+
69
+
70
# In-memory exact nearest-neighbour indexes, loaded from pre-built files under data/.
restaurant_index = InMemoryExactNNIndex[RestaurantDescription](index_file_path='data/restaurants.bin')
category_index = InMemoryExactNNIndex[Category](index_file_path='data/categories.bin')
dish_index = InMemoryExactNNIndex[Dish](index_file_path='data/dishes.bin')
launch.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
from chainlit.cli import run_chainlit
from chainlit.config import config

if __name__ == '__main__':
    # Enable file watching (auto-reload on change) and launch the Chainlit app.
    config.run.watch = True
    run_chainlit('app.py')
requirements.txt CHANGED
@@ -1 +1,2 @@
1
  chainlit==0.6.2
 
 
1
  chainlit==0.6.2
2
+ docarray>=0.37.0
scripts/__init__.py ADDED
File without changes
scripts/create_embeddings.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ from ast import literal_eval
3
+ from pathlib import Path
4
+ from typing import TypeVar
5
+
6
+ from docarray import DocList
7
+ from dotenv import load_dotenv
8
+
9
+ from azure_openai import AzureOpenaiEmbeddings
10
+ from data import RestaurantDescription, restaurant_index, Dish, Category, dish_index, category_index
11
+
12
+
13
def calculate_rating(low: str, medium: str, high: str) -> float:
    """Collapse three review-count buckets into a single 0-1 score.

    *low*/*medium*/*high* are counts (as CSV strings) of negative, neutral
    and positive reviews; neutral reviews count at weight 0.7.
    Returns 0.0 when there are no reviews at all (avoids ZeroDivisionError).
    """
    low_n, medium_n, high_n = int(low), int(medium), int(high)
    total = low_n + medium_n + high_n
    if total == 0:
        return 0.0
    return (medium_n * 0.7 + high_n) / total
19
+
20
+
21
def normalize_dish(dish_name: str) -> str:
    """Drop non-breaking spaces and title-case the dish name."""
    return dish_name.replace('\xa0', '').title()
24
+
25
+
26
T = TypeVar('T')


def add_to_all(restaurant: 'RestaurantDescription', keys: list[str], mapping: dict[str, T], cls: type[T]):
    """Register *restaurant* under every key in *keys*, creating missing entries.

    Each new entry is built as cls(id=key, text=key, restaurants=[]); every
    entry collects the ids of all restaurants associated with that key.
    (Annotation fix: `dict[T]` is not a valid generic — a dict takes a key and
    a value type; the restaurant annotation is quoted to stay lazy.)
    """
    for key in set(keys):  # de-duplicate so a restaurant is added once per key
        entry = mapping.get(key)
        if entry is None:
            entry = mapping[key] = cls(id=key, text=key, restaurants=[])
        entry.restaurants.append(restaurant.id)
36
+
37
+
38
# Kept at module level so the results can be inspected interactively after a run.
restaurants, dish_list, category_list = None, None, None


def main():
    """Load restaurants.csv, build Dish/Category docs, embed everything and persist the indexes."""
    global restaurants, dish_list, category_list
    load_dotenv()  # pull AZURE_OPENAI_* / OPENAI_API_KEY from .env

    csv_file = Path('restaurants.csv')

    restaurants = DocList[RestaurantDescription]()
    dishes = {}      # dish name -> Dish doc
    categories = {}  # category name -> Category doc

    with csv_file.open(encoding='utf-8-sig', newline='') as f:
        reader = csv.DictReader(f)
        for row in reader:
            # Prefer the second-language name when present; keep the other as alternate.
            if row['name_lang2']:
                name = row['name_lang2']
                name_alt = row['name_lang1']
            else:
                name = row['name_lang1']
                name_alt = None

            # The 'dishes'/'categories' columns hold Python-literal lists.
            ds = literal_eval(row['dishes'])
            ds = [normalize_dish(d) for d in ds]
            cs = literal_eval(row['categories'])

            r = RestaurantDescription(
                embedding=None,  # batch create all embeddings later
                id=row['id'],
                name=name,
                name_alt=name_alt,
                text=row['intro'],
                price=int(row['price']),
                rating=calculate_rating(row['score_cry'], row['score_o_k'], row['score_smile']),
                categories=cs,
                dishes=ds,
                info_url=row['poi_url'],
                image_url=row['door_photos'],
            )

            restaurants.append(r)
            add_to_all(r, ds, dishes, Dish)
            add_to_all(r, cs, categories, Category)
    # NOTE(review): indentation reconstructed — these two lines are assumed to run
    # once after the CSV loop, not per row; confirm against the original file.
    dish_list = DocList[Dish](dishes.values())
    category_list = DocList[Category](categories.values())

    embedding_settings = AzureOpenaiEmbeddings.load_from_env()

    # Batch-compute embeddings for every doc collection.
    RestaurantDescription.create_embeddings(restaurants, **embedding_settings.to_settings_dict())
    Dish.create_embeddings(dish_list, **embedding_settings.to_settings_dict())
    Category.create_embeddings(category_list, **embedding_settings.to_settings_dict())

    restaurant_index.index(restaurants)
    dish_index.index(dish_list)
    category_index.index(category_list)

    # Write the indexes to their configured data/*.bin files.
    restaurant_index.persist()
    dish_index.persist()
    category_index.persist()


if __name__ == '__main__':
    main()