Spaces:
Runtime error
Runtime error
Commit
·
e6e69dc
1
Parent(s):
5a768b8
Add app
Browse files- .gitignore +2 -0
- app.py +144 -21
- azure_openai.py +98 -0
- chain.py +18 -4
- data.py +72 -0
- launch.py +6 -0
- requirements.txt +1 -0
- scripts/__init__.py +0 -0
- scripts/create_embeddings.py +101 -0
.gitignore
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
.chainlit/chat.db
|
|
|
|
| 2 |
|
| 3 |
# Created by https://www.toptal.com/developers/gitignore/api/python,intellij+all,visualstudiocode
|
| 4 |
# Edit at https://www.toptal.com/developers/gitignore?templates=python,intellij+all,visualstudiocode
|
|
@@ -283,3 +284,4 @@ pyrightconfig.json
|
|
| 283 |
.ionide
|
| 284 |
|
| 285 |
# End of https://www.toptal.com/developers/gitignore/api/python,intellij+all,visualstudiocode
|
|
|
|
|
|
| 1 |
.chainlit/chat.db
|
| 2 |
+
.chainlit/chat_files
|
| 3 |
|
| 4 |
# Created by https://www.toptal.com/developers/gitignore/api/python,intellij+all,visualstudiocode
|
| 5 |
# Edit at https://www.toptal.com/developers/gitignore?templates=python,intellij+all,visualstudiocode
|
|
|
|
| 284 |
.ionide
|
| 285 |
|
| 286 |
# End of https://www.toptal.com/developers/gitignore/api/python,intellij+all,visualstudiocode
|
| 287 |
+
/.chainlit/chat_files/3105f452-a667-4cd0-b40c-7502c8bc62d2/1bf41236-505a-44b3-9ac4-55efa37a393d.txt
|
app.py
CHANGED
|
@@ -1,33 +1,156 @@
|
|
| 1 |
-
import
|
|
|
|
| 2 |
|
| 3 |
import chainlit as cl
|
|
|
|
| 4 |
|
|
|
|
| 5 |
from chain import Chain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
@cl.on_chat_start
|
| 9 |
async def start_chat():
|
| 10 |
-
|
| 11 |
-
await chain.text("I will count to 5. How many concurrent times should I count?")
|
| 12 |
|
| 13 |
|
| 14 |
@cl.on_message
|
| 15 |
async def on_message(message: str, message_id: str):
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import tomllib
|
| 3 |
|
| 4 |
import chainlit as cl
|
| 5 |
+
from docarray.index.abstract import BaseDocIndex
|
| 6 |
|
| 7 |
+
from azure_openai import AzureOpenaiSettings, AzureOpenaiEmbeddings, patch_chainlit
|
| 8 |
from chain import Chain
|
| 9 |
+
from data import embed, restaurant_index, RestaurantDescription
|
| 10 |
+
|
| 11 |
+
# Embedding-endpoint settings resolved once at import time from environment
# variables (see AzureOpenaiEmbeddings.load_from_env).
embedding_settings = AzureOpenaiEmbeddings.load_from_env().to_settings_dict()

# Swap chainlit's playground /completion endpoint for an Azure-aware one.
patch_chainlit()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def search_embeddings(query: str, doc_index: BaseDocIndex):
    """Embed `query` and return the five nearest documents from `doc_index`."""
    query_vec = embed(query, **embedding_settings)
    matches, _scores = doc_index.find(query_vec, 'embedding', 5)
    return matches
|
| 20 |
|
| 21 |
|
| 22 |
@cl.on_chat_start
async def start_chat():
    # Fresh conversation: seed this session with an empty message history.
    cl.user_session.set("history", [])
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
@cl.on_message
async def on_message(message: str, message_id: str):
    """Handle one user message.

    Pipeline: append the message to the session history, ask the LLM to
    condense the conversation into a search query, retrieve candidate
    restaurants by embedding similarity, ask the LLM to mark good fits as a
    TOML blob, then present up to three options with a Book action.
    """
    history = cl.user_session.get("history")

    # update history
    history.append({"role": "user", "content": message})

    # build AI response
    chain = Chain(message_id, llm_settings=AzureOpenaiSettings.load_from_env())

    query_msg = await chain.llm(
        """
        You are a conversation summarizer that condenses a conversation between
        a human and AI into a search query that can be used to find relevant
        restaurants.

        Conversation history

        ============

        {history}

        ============

        From this conversation, create a search query that would fit the human's needs.
        Do not say anything else; just the query.
        """,
        history=format_history(history),
    )

    results = search_embeddings(query_msg.content, restaurant_index)
    await chain.text(str(list(results)))  # TODO maybe json format would be better?

    restaurants = "\n".join(f"- ID: {r.id} | Description: {r.text}" for r in results)

    final_choices_msg = await chain.llm(
        """
        You are a search engine for restaurants.
        Output the restaurant IDs for the best matches to the following query:

        ----

        {query}

        ============


        List of restaurants

        ----

        {restaurants}

        ============


        Output your final answer as a TOML blob.
        Each restaurant should have a key for its ID, with a
        boolean value, where true means the restaurant is a good fit.

        For example:

        ---

        [answer]

        101 = false

        1350 = true

        02458 = false

        9315 = true

        128974 = true

        ============


        Make include IDs of ALL restaurants, but only mark true for ones that fit the query.

        """,
        query=query_msg.content,
        restaurants=restaurants
    )

    # match = re.match(r'```\s*toml\s*(.*)\s*```', final_choices_msg.content, re.DOTALL)
    # toml_string = match.group(1)
    # NOTE(review): assumes the model returns bare TOML with no code fence —
    # tomllib.loads raises TOMLDecodeError otherwise; confirm prompt behavior.
    toml_string = final_choices_msg.content

    # don't output just the good values, since GPT doesn't think about each option
    # final_ids = [x.strip() for x in final_choices_msg.content.split(',')]

    # don't use json because curly braces brakes the template code...
    # final_ids = json.loads(json_string)

    # TOML is easy to write and parse for both machines and humans :)
    obj = tomllib.loads(toml_string)
    final_ids = [id for id, val in obj['answer'].items() if val]

    # Present at most the first three matches as rich messages.
    for i, id in enumerate(final_ids[:3]):
        id = str(id)
        restaurant: RestaurantDescription = restaurant_index[id]  # why no automatic typing?
        msg = await chain.text(f"Option {i}", final=True)
        msg.elements = [
            # note: image always displays above text
            cl.Image(name=restaurant.name, url=restaurant.image_url, display='inline', size='small'),
            cl.Text(name=restaurant.name, content=restaurant.text, display='inline'),
            # TODO text could also include categories/dishes/rating/price
        ]
        msg.actions = [
            cl.Action(name='book', value=id, label='Book', description='Click to book this restaurant'),
        ]
        await msg.update()

    # TODO what should the history include? ids only? or also descriptions?
    # history.append({"role": "assistant", "content": response.content})
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# Display names used when rendering chat roles in a plain-text transcript.
NAMES = {
    # 'system' entries are deliberately not rendered.
    'user': 'Human',
    'assistant': 'AI',
}


def format_history(history: list[dict]) -> str:
    """Render messages as newline-separated '<Name>: <content>' lines."""
    return "\n".join(
        f'{NAMES[m["role"]]}: {m["content"]}' for m in history
    )
|
azure_openai.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Self
|
| 3 |
+
|
| 4 |
+
from chainlit import LLMSettings
|
| 5 |
+
from chainlit.telemetry import trace_event
|
| 6 |
+
from chainlit.types import CompletionRequest
|
| 7 |
+
from pydantic.dataclasses import dataclass
|
| 8 |
+
from starlette.responses import PlainTextResponse
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class AzureOpenaiSettings(LLMSettings):
    """Chainlit LLM settings extended with the fields the Azure OpenAI API needs.

    The resulting dict is forwarded to the `openai` client as keyword
    arguments (api_type / api_base / api_version / engine).
    """
    api_type: str = 'azure'
    api_base: str = ''  # Azure endpoint URL, e.g. https://<resource>.openai.azure.com
    engine: str = ''  # Azure deployment name
    api_version: str = '2023-05-15'

    def to_settings_dict(self):
        # Extend the base chainlit settings dict with the Azure-specific keys.
        return {
            **super().to_settings_dict(),
            "api_type": self.api_type,
            "api_base": self.api_base,
            "api_version": self.api_version,
            "engine": self.engine,
        }

    @classmethod
    def load_from_env(cls: type[Self], *args, **kwargs) -> Self:
        """Build settings from AZURE_OPENAI_* environment variables."""
        # NOTE(review): os.environ.get returns None for unset variables, which
        # violates the declared `str` types — confirm the env vars are set.
        return cls(
            *args,
            api_type='azure',
            api_base=os.environ.get('AZURE_OPENAI_ENDPOINT'),
            engine=os.environ.get('AZURE_OPENAI_DEPLOYMENT'),
            api_version=os.environ.get('AZURE_OPENAI_VERSION', '2023-05-15'),
            **kwargs,
        )
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass
class AzureOpenaiEmbeddings:
    """Connection settings for the Azure OpenAI *embeddings* deployment.

    Near-duplicate of AzureOpenaiSettings, but standalone (no LLMSettings
    base) and reading AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT for the engine.
    """
    api_type: str = 'azure'
    api_base: str = ''  # Azure endpoint URL
    engine: str = ''  # embeddings deployment name
    api_version: str = '2023-05-15'

    def to_settings_dict(self):
        # Keyword arguments suitable for openai.Embedding.create(...).
        return {
            "api_type": self.api_type,
            "api_base": self.api_base,
            "api_version": self.api_version,
            "engine": self.engine,
        }

    @classmethod
    def load_from_env(cls: type[Self], *args, **kwargs) -> Self:
        """Build settings from AZURE_OPENAI_* environment variables."""
        # NOTE(review): unset env vars yield None here despite the `str` types.
        return cls(
            *args,
            api_type='azure',
            api_base=os.environ.get('AZURE_OPENAI_ENDPOINT'),
            engine=os.environ.get('AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT'),
            api_version=os.environ.get('AZURE_OPENAI_VERSION', '2023-05-15'),
            **kwargs,
        )
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def patch_chainlit():
    """Monkey-patch chainlit's server so the prompt playground talks to
    Azure OpenAI instead of the default OpenAI endpoint."""
    from chainlit.server import app

    # replace playground's completion endpoint with one that uses custom openai settings
    app.router.routes = list(filter(lambda route: route.path != '/completion', app.router.routes))

    @app.post("/completion")
    async def completion(request: CompletionRequest):
        """Handle a completion request from the prompt playground."""

        import openai

        trace_event("completion")

        # A per-user key supplied with the request wins over the process env.
        api_key = request.userEnv.get("OPENAI_API_KEY", os.environ.get("OPENAI_API_KEY"))

        stop = request.settings.stop
        # OpenAI doesn't support an empty stop array, clear it
        if isinstance(stop, list) and len(stop) == 0:
            stop = None

        response = await openai.ChatCompletion.acreate(
            api_key=api_key,
            messages=[{"role": "user", "content": request.prompt}],
            stop=stop,
            # **completion.settings.to_settings_dict(),
            # HACK: hard-code llm settings
            **dict(api_type='azure', api_base=os.environ.get('AZURE_OPENAI_ENDPOINT'),
                   engine=os.environ.get('AZURE_OPENAI_DEPLOYMENT'),
                   api_version=os.environ.get('AZURE_OPENAI_VERSION', '2023-05-15')
                   ),
        )
        return PlainTextResponse(content=response["choices"][0]["message"]["content"])
|
chain.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import asyncio
|
| 2 |
import os
|
| 3 |
import re
|
|
|
|
| 4 |
|
| 5 |
import chainlit as cl
|
| 6 |
import openai
|
|
@@ -8,6 +9,14 @@ from chainlit import LLMSettings
|
|
| 8 |
from chainlit.config import config
|
| 9 |
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
# TODO each chain should be able to make a child chain?
|
| 12 |
# root = Chain()
|
| 13 |
# first = root.child("something")
|
|
@@ -26,11 +35,12 @@ class Chain:
|
|
| 26 |
**kwargs,
|
| 27 |
)
|
| 28 |
|
| 29 |
-
async def text(self, text, final=False, name=None):
|
| 30 |
message = self.make_message(content=text, final=final, name=name)
|
| 31 |
await message.send()
|
|
|
|
| 32 |
|
| 33 |
-
async def text_stream(self, text: str, delay=.1, name=None, final=False):
|
| 34 |
message = self.make_message(content='', final=final, name=name)
|
| 35 |
tokens = text.split(" ")
|
| 36 |
first = True
|
|
@@ -41,8 +51,12 @@ class Chain:
|
|
| 41 |
await asyncio.sleep(delay)
|
| 42 |
first = False
|
| 43 |
await message.send()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
-
async def llm(self, template, *args, name=None, final=False, **kwargs) -> str:
|
| 46 |
variables = re.findall(r'\{(.*?)}', template)
|
| 47 |
if len(args) > 1:
|
| 48 |
raise RuntimeError("If there is more than one argument, use kwargs")
|
|
@@ -66,4 +80,4 @@ class Chain:
|
|
| 66 |
await message.stream_token(token)
|
| 67 |
|
| 68 |
await message.send()
|
| 69 |
-
return message
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
+
from inspect import cleandoc
|
| 5 |
|
| 6 |
import chainlit as cl
|
| 7 |
import openai
|
|
|
|
| 9 |
from chainlit.config import config
|
| 10 |
|
| 11 |
|
| 12 |
+
def replace_newlines(match: re.Match) -> str:
    """re.sub callback: a single newline becomes a space; a run of N > 1
    newlines is shortened by exactly one newline."""
    run = match.group(0)
    return " " if len(run) <= 1 else run[1:]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
# TODO each chain should be able to make a child chain?
|
| 21 |
# root = Chain()
|
| 22 |
# first = root.child("something")
|
|
|
|
| 35 |
**kwargs,
|
| 36 |
)
|
| 37 |
|
| 38 |
+
    async def text(self, text, final=False, name=None) -> cl.Message:
        """Send `text` as a single chainlit message and return the sent message."""
        message = self.make_message(content=text, final=final, name=name)
        await message.send()
        return message
|
| 42 |
|
| 43 |
+
async def text_stream(self, text: str, delay=.1, name=None, final=False) -> cl.Message:
|
| 44 |
message = self.make_message(content='', final=final, name=name)
|
| 45 |
tokens = text.split(" ")
|
| 46 |
first = True
|
|
|
|
| 51 |
await asyncio.sleep(delay)
|
| 52 |
first = False
|
| 53 |
await message.send()
|
| 54 |
+
return message
|
| 55 |
+
|
| 56 |
+
async def llm(self, template, *args, name=None, final=False, **kwargs) -> cl.Message:
|
| 57 |
+
template = cleandoc(template)
|
| 58 |
+
template = re.sub('\n+', replace_newlines, template) # remove a newline
|
| 59 |
|
|
|
|
| 60 |
variables = re.findall(r'\{(.*?)}', template)
|
| 61 |
if len(args) > 1:
|
| 62 |
raise RuntimeError("If there is more than one argument, use kwargs")
|
|
|
|
| 80 |
await message.stream_token(token)
|
| 81 |
|
| 82 |
await message.send()
|
| 83 |
+
return message
|
data.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Sequence
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import openai
|
| 6 |
+
from docarray.documents import TextDoc
|
| 7 |
+
from docarray.index import InMemoryExactNNIndex
|
| 8 |
+
from docarray.typing import NdArray
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class OpenaiEmbeddingDoc(TextDoc):
    """Text document carrying an OpenAI embedding vector.

    `embedding` stays None until `create_embeddings` populates it in place.
    """
    # 1536 matches the output dimensionality used by the OpenAI embeddings API.
    embedding: NdArray[1536] | None

    @staticmethod
    def create_embeddings(docs: Sequence['OpenaiEmbeddingDoc'], **kwargs):
        """Populate `doc.embedding` in place for every doc in `docs`.

        Azure caps a single embeddings request at 16 inputs, so larger
        sequences are processed recursively in chunks of 16. Remaining
        kwargs (engine, api_type, api_base, api_version, ...) are forwarded
        to openai.Embedding.create.
        """
        if len(docs) > 16:  # max allowed by azure
            for i in range(0, len(docs), 16):
                print(f"Processing 16 starting from index {i}")
                OpenaiEmbeddingDoc.create_embeddings(docs[i:i+16], **kwargs)
        else:
            texts = [d.text for d in docs]
            # FIX: removed the truncated leftover `kwargs.setdefault('api_')`,
            # which injected a bogus `api_=None` keyword into the API call.
            # Also pop api_key out of kwargs so it cannot collide with the
            # explicit api_key argument below (duplicate kwarg -> TypeError).
            response = openai.Embedding.create(
                input=texts,
                api_key=os.environ.get('OPENAI_API_KEY', kwargs.pop('api_key', None)),
                **kwargs
            )
            embeddings = response['data']
            assert len(embeddings) == len(docs)
            # Results may arrive out of order; match each vector to its doc
            # via the 'index' field of the response entry.
            for obj in embeddings:
                doc = docs[obj['index']]
                doc.embedding = np.array(obj['embedding'])
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def embed(text: str, **kwargs) -> np.ndarray[1536]:
    """Return the embedding vector for a single string.

    kwargs are forwarded to openai.Embedding.create. The API key comes from
    the OPENAI_API_KEY environment variable, falling back to an `api_key`
    entry supplied in kwargs.
    """
    # FIX: pop (not get) api_key so it is not passed twice when a caller
    # supplies it inside kwargs — a duplicate keyword raises TypeError.
    response = openai.Embedding.create(
        input=text,
        api_key=os.environ.get('OPENAI_API_KEY', kwargs.pop('api_key', None)),
        **kwargs
    )
    return np.array(response['data'][0]['embedding'])
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class RestaurantDescription(OpenaiEmbeddingDoc):
    """A restaurant profile; the inherited `text` field holds the intro
    blurb that gets embedded for similarity search."""
    id: str = ''  # a number string
    name: str
    name_alt: str | None  # secondary-language name, if any
    categories: list[str]
    dishes: list[str]
    rating: float  # 0-1
    price: int  # HKD
    info_url: str
    image_url: str


class Category(OpenaiEmbeddingDoc):
    """A category keyword with back-references to the restaurants that carry it."""
    id: str = ''  # same as text
    restaurants: list[str]  # list of ids? or we could just search the restaurants?


class Dish(OpenaiEmbeddingDoc):
    """
    A dish name with back-references to the restaurants that serve it.

    Note: Not all dish names are meaningful, e.g., 'Trip to Bali', 'Oakland Breeze'
    May include duplicates?
    """
    id: str = ''  # same as text
    restaurants: list[str]  # list of ids
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Exact (brute-force) nearest-neighbour indexes backed by the given files.
# NOTE(review): presumably these files are produced by
# scripts/create_embeddings.py and must exist before the app starts — confirm.
restaurant_index = InMemoryExactNNIndex[RestaurantDescription](index_file_path='data/restaurants.bin')
category_index = InMemoryExactNNIndex[Category](index_file_path='data/categories.bin')
dish_index = InMemoryExactNNIndex[Dish](index_file_path='data/dishes.bin')
|
launch.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from chainlit.cli import run_chainlit
from chainlit.config import config

if __name__ == '__main__':
    # Enable file watching (auto-reload on change) and start the chainlit app.
    config.run.watch = True
    run_chainlit('app.py')
|
requirements.txt
CHANGED
|
@@ -1 +1,2 @@
|
|
| 1 |
chainlit==0.6.2
|
|
|
|
|
|
| 1 |
chainlit==0.6.2
|
| 2 |
+
docarray>=0.37.0
|
scripts/__init__.py
ADDED
|
File without changes
|
scripts/create_embeddings.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
from ast import literal_eval
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import TypeVar
|
| 5 |
+
|
| 6 |
+
from docarray import DocList
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
|
| 9 |
+
from azure_openai import AzureOpenaiEmbeddings
|
| 10 |
+
from data import RestaurantDescription, restaurant_index, Dish, Category, dish_index, category_index
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def calculate_rating(low: str, medium: str, high: str) -> float:
    """Collapse three vote-count strings into a single score in [0, 1].

    Inputs are the CSV columns score_cry / score_o_k / score_smile as strings.
    'medium' votes contribute weight 0.7, 'high' votes 1.0, 'low' votes 0,
    normalized by the total vote count.
    """
    low_votes = int(low)
    medium_votes = int(medium)
    high_votes = int(high)
    total = low_votes + medium_votes + high_votes
    if total == 0:
        # FIX: a row with no votes at all used to raise ZeroDivisionError;
        # treat it as the lowest possible score instead.
        return 0.0
    return (medium_votes*0.7 + high_votes) / total
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def normalize_dish(dish_name: str) -> str:
    """Strip non-breaking spaces from a dish name and Title-Case it."""
    return dish_name.replace('\xa0', '').title()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
T = TypeVar('T')


def add_to_all(restaurant: RestaurantDescription, keys: list[str], mapping: dict[str, T], cls: type[T]):
    """Register `restaurant` under every key in `keys`.

    For any unseen key, creates a `cls(id=key, text=key, restaurants=[])`
    entry in `mapping`; then appends the restaurant's id to each entry's
    `restaurants` list.
    """
    # FIX: the `mapping` annotation was `dict[T]` — the dict generic takes a
    # key type and a value type; keys here are strings.
    unique_keys = set(keys)  # guard against duplicates within one restaurant
    for key in unique_keys:
        entry = mapping.get(key)
        if entry is None:
            entry = mapping[key] = cls(id=key, text=key, restaurants=[])
        entry.restaurants.append(restaurant.id)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Module-level handles populated by main(); kept global so the parsed data
# can be inspected interactively after a run.
restaurants, dish_list, category_list = None, None, None


def main():
    """Parse restaurants.csv, build dish/category cross-references, create
    embeddings via Azure OpenAI, and persist the three search indexes."""
    global restaurants, dish_list, category_list
    load_dotenv()

    csv_file = Path('restaurants.csv')

    restaurants = DocList[RestaurantDescription]()
    dishes = {}  # dish name -> Dish
    categories = {}  # category name -> Category

    with csv_file.open(encoding='utf-8-sig', newline='') as f:
        reader = csv.DictReader(f)
        for row in reader:
            # Prefer the second-language name when present; keep the
            # first-language one as the alternate.
            if row['name_lang2']:
                name = row['name_lang2']
                name_alt = row['name_lang1']
            else:
                name = row['name_lang1']
                name_alt = None

            # The 'dishes'/'categories' columns hold Python-literal lists.
            ds = literal_eval(row['dishes'])
            ds = [normalize_dish(d) for d in ds]
            cs = literal_eval(row['categories'])

            r = RestaurantDescription(
                embedding=None,  # batch create all embeddings later
                id=row['id'],
                name=name,
                name_alt=name_alt,
                text=row['intro'],
                price=int(row['price']),
                rating=calculate_rating(row['score_cry'], row['score_o_k'], row['score_smile']),
                categories=cs,
                dishes=ds,
                info_url=row['poi_url'],
                image_url=row['door_photos'],
            )

            restaurants.append(r)
            add_to_all(r, ds, dishes, Dish)
            add_to_all(r, cs, categories, Category)
    dish_list = DocList[Dish](dishes.values())
    category_list = DocList[Category](categories.values())

    embedding_settings = AzureOpenaiEmbeddings.load_from_env()

    # Embeddings are filled in place on each doc (batched in groups of 16).
    RestaurantDescription.create_embeddings(restaurants, **embedding_settings.to_settings_dict())
    Dish.create_embeddings(dish_list, **embedding_settings.to_settings_dict())
    Category.create_embeddings(category_list, **embedding_settings.to_settings_dict())

    restaurant_index.index(restaurants)
    dish_index.index(dish_list)
    category_index.index(category_list)

    # Write the indexes to their configured data/*.bin files for the app.
    restaurant_index.persist()
    dish_index.persist()
    category_index.persist()


if __name__ == '__main__':
    main()
|