# restaurants / data.py
# briankchan's picture
# Add more data to embeddings; filter out restaurant chains; add location
# 1ef3ac7
import os
from typing import Sequence
import numpy as np
import openai
from docarray.documents import TextDoc
from docarray.index import InMemoryExactNNIndex
from docarray.typing import NdArray
class OpenaiEmbeddingDoc(TextDoc):
    """Text document whose ``embedding`` is filled from the OpenAI embeddings API.

    The embedding is declared as a 1536-element array and may be ``None``
    until :meth:`create_embeddings` has been run on the document.
    """

    # Explicit ``None`` default: the field is nullable, so it should also be
    # optional at construction time regardless of the validation backend.
    embedding: NdArray[1536] | None = None

    @staticmethod
    def create_embeddings(docs: Sequence['OpenaiEmbeddingDoc'], **kwargs):
        """Compute embeddings for ``docs`` and store them on each document.

        The documents are sent to ``openai.Embedding.create`` in batches of at
        most 16 texts; each returned vector is written to the ``embedding``
        attribute of the document it belongs to (documents are mutated in
        place, nothing is returned).

        Args:
            docs: Documents to embed; ``doc.text`` is used as the input text.
            **kwargs: Forwarded to ``openai.Embedding.create`` (model/engine,
                api_type, ...). An ``api_key`` entry is only used when the
                ``OPENAI_API_KEY`` environment variable is not set.

        Raises:
            AssertionError: If a response does not contain exactly one
                embedding per document of the batch.
        """
        batch_size = 16  # max allowed by azure

        if not docs:
            # Nothing to embed; do not send a request with an empty input.
            return

        # Take ``api_key`` out of the forwarded kwargs so it is not passed
        # twice (explicit keyword + **kwargs), which would raise TypeError.
        request_kwargs = dict(kwargs)
        fallback_key = request_kwargs.pop('api_key', None)
        api_key = os.environ.get('OPENAI_API_KEY', fallback_key)

        # Progress is only reported when the input has to be split.
        needs_split = len(docs) > batch_size

        for start in range(0, len(docs), batch_size):
            if needs_split:
                print(f"Processing 16 starting from index {start}")
            batch = docs[start:start + batch_size]

            response = openai.Embedding.create(
                input=[doc.text for doc in batch],
                api_key=api_key,
                **request_kwargs,
            )
            embeddings = response['data']
            assert len(embeddings) == len(batch), (
                f"expected {len(batch)} embeddings, got {len(embeddings)}"
            )

            # Each result carries the position of its input text in the
            # request, so match by that index rather than by result order.
            for item in embeddings:
                batch[item['index']].embedding = np.array(item['embedding'])
def embed(text: str, **kwargs) -> np.ndarray:
    """Return the OpenAI embedding of ``text`` as a NumPy array.

    Args:
        text: The text to embed; sent as the single ``input`` of the request.
        **kwargs: Forwarded to ``openai.Embedding.create`` (model/engine,
            api_type, ...). An ``api_key`` entry is only used when the
            ``OPENAI_API_KEY`` environment variable is not set.

    Returns:
        The first embedding of the response converted with ``np.array``.
        NOTE(review): the original annotation suggested a length of 1536;
        the actual length depends on the model used - confirm with callers.
    """
    # Take ``api_key`` out of the forwarded kwargs so it is not passed twice
    # (explicit keyword + **kwargs), which would raise TypeError.
    request_kwargs = dict(kwargs)
    fallback_key = request_kwargs.pop('api_key', None)

    response = openai.Embedding.create(
        input=text,
        api_key=os.environ.get('OPENAI_API_KEY', fallback_key),
        **request_kwargs,
    )
    return np.array(response['data'][0]['embedding'])
class RestaurantDescription(OpenaiEmbeddingDoc):
    """Embeddable document describing a single restaurant.

    Inherits the text/embedding fields from ``OpenaiEmbeddingDoc`` and adds
    the restaurant's descriptive attributes below.
    """

    id: str = ''  # a number string
    name: str
    # Alternative name; nullable, so it defaults to ``None`` explicitly
    # instead of relying on the validation backend's implicit-optional rule.
    name_alt: str | None = None
    intro: str
    categories: list[str]
    dishes: list[str]
    rating: float  # 0-1
    price: int  # HKD
    info_url: str
    image_url: str
    location: list[str]  # format of the entries is not visible in this file
class Category(OpenaiEmbeddingDoc):
    """Embeddable document for one restaurant category."""

    # Same value as the document's text.
    id: str = ""

    # List of restaurant ids? (open question from the original author:
    # the restaurants could instead be searched directly.)
    restaurants: list[str]
class Dish(OpenaiEmbeddingDoc):
    """
    Embeddable document for one dish name.

    Caveat: some dish names carry no real meaning (for example
    'Trip to Bali' or 'Oakland Breeze'), and duplicates may be present.
    """

    # Same value as the document's text.
    id: str = ""

    # Ids of the restaurants associated with this dish.
    restaurants: list[str]
# Module-level in-memory exact nearest-neighbour indexes, one per document
# type, each bound to its own binary file under ``data/``.
restaurant_index = InMemoryExactNNIndex[RestaurantDescription](
    index_file_path='data/restaurants.bin',
)
category_index = InMemoryExactNNIndex[Category](
    index_file_path='data/categories.bin',
)
dish_index = InMemoryExactNNIndex[Dish](
    index_file_path='data/dishes.bin',
)