# Source: restaurants/scripts/create_embeddings.py
# Last commit: "Update embeddings" (818d7ad), briankchan
import csv
import json
from ast import literal_eval
from pathlib import Path
from typing import TypeVar
from docarray import DocList
from dotenv import load_dotenv
from azure_openai import AzureOpenaiEmbeddings
from data import RestaurantDescription, restaurant_index, Dish, Category, dish_index, category_index
def calculate_rating(low: str, medium: str, high: str) -> float:
    """Collapse three vote counts into a single weighted score.

    A ``high`` vote counts fully, a ``medium`` vote counts 0.7 and a ``low``
    vote counts nothing; the weighted sum is divided by the total number of
    votes.

    Args:
        low: Number of low votes, as an integer string.
        medium: Number of medium votes, as an integer string.
        high: Number of high votes, as an integer string.

    Returns:
        The weighted share of positive votes, or ``0.0`` when there are no
        votes at all.

    Raises:
        ValueError: If any argument cannot be parsed by ``int``.
    """
    low = int(low)
    medium = int(medium)
    high = int(high)
    total = low + medium + high
    if total == 0:
        # No votes recorded: avoid ZeroDivisionError and report a neutral 0.
        return 0.0
    return (medium*0.7 + high) / total
def normalize_dish(dish_name: str) -> str:
    """Return *dish_name* with non-breaking spaces removed, in title case."""
    # Splitting on U+00A0 and re-joining with nothing deletes every occurrence.
    pieces = dish_name.split('\xa0')
    cleaned = ''.join(pieces)
    return cleaned.title()
T = TypeVar('T')


def add_to_all(restaurant: RestaurantDescription, keys: list[str], mapping: dict[str, T], cls: type[T]) -> None:
    """Record *restaurant* under every key in *keys* inside *mapping*.

    For each distinct key, the matching entry in *mapping* is looked up (and
    created via ``cls(id=key, text=key, restaurants=[])`` when missing), then
    ``restaurant.id`` is appended to that entry's ``restaurants`` list.

    Args:
        restaurant: The restaurant whose ``id`` is recorded.
        keys: Keys (e.g. dish or category names) to file the restaurant under.
        mapping: Key -> entry dictionary; mutated in place.
        cls: Class used to build an entry for a key not yet in *mapping*.
    """
    # dict.fromkeys drops duplicate keys like set() would, but keeps the
    # first-seen order, so entries are created in the same order on every run
    # instead of following hash-randomised set iteration.
    for key in dict.fromkeys(keys):
        entry = mapping.get(key)
        if entry is None:
            entry = mapping[key] = cls(id=key, text=key, restaurants=[])
        entry.restaurants.append(restaurant.id)
def load_districts():
    """Load the district boundaries from ``data/district_boundary.json``.

    Returns:
        A list with one dict per entry of the file's ``features`` array,
        holding the district ``name`` and a matplotlib path ``polygon`` built
        from the first coordinate list of the feature's geometry (presumably
        the outer ring of a GeoJSON polygon -- confirm against the data file).
    """
    boundary_file = Path('data/district_boundary.json')
    with boundary_file.open('r', encoding='utf-8') as fh:
        geojson = json.load(fh)

    # Imported here, under an alias, so it does not shadow pathlib.Path.
    from matplotlib.path import Path as Polyline

    return [
        {
            "name": feature['properties']['District'],
            "polygon": Polyline(feature['geometry']['coordinates'][0]),
        }
        for feature in geojson['features']
    ]
# Parsed once at import time; get_district_name() scans this list on every call.
DISTRICTS = load_districts()
def get_district_name(lat: str, lon: str):
    """Return the names of all districts whose polygon contains the point.

    Args:
        lat: Latitude, as a string parseable by ``float``.
        lon: Longitude, as a string parseable by ``float``.

    Returns:
        A list of district names; empty when no polygon contains the point.
    """
    latitude = float(lat)
    longitude = float(lon)
    # The polygons are queried with (lon, lat) ordering, i.e. x before y.
    point = (longitude, latitude)
    return [
        entry['name']
        for entry in DISTRICTS
        if entry['polygon'].contains_point(point)
    ]
# Module-level handles that main() rebinds via `global`, so the loaded
# collections remain reachable after main() returns.
restaurants, dish_list, category_list = None, None, None
def main():
    """Build, embed and index restaurant, dish and category documents.

    Reads ``restaurants.csv``, turns each row with a not-yet-seen name into a
    ``RestaurantDescription`` (with a text summary used for embedding), and
    groups restaurant ids by dish and by category. It then opens an IPython
    shell, creates embeddings for all three collections using settings loaded
    from the environment, and finally indexes and persists them.

    The resulting collections are stored in the module-level ``restaurants``,
    ``dish_list`` and ``category_list`` names.
    """
    global restaurants, dish_list, category_list

    load_dotenv()

    csv_file = Path('restaurants.csv')
    restaurants = DocList[RestaurantDescription]()
    restaurant_names = set()
    dishes = {}
    categories = {}

    with csv_file.open(encoding='utf-8-sig', newline='') as f:
        reader = csv.DictReader(f)
        for row in reader:
            # Prefer the second-language name when present and keep the
            # first-language name as the alternative.
            if row['name_lang2']:
                name = row['name_lang2']
                name_alt = row['name_lang1']
            else:
                name = row['name_lang1']
                name_alt = None

            # for this demo, don't add multiple locations of the same restaurant chain
            if name in restaurant_names:
                continue

            ds = literal_eval(row['dishes'])
            ds = [normalize_dish(d) for d in ds]
            # Order-preserving dedupe: set() iteration order is hash-randomised,
            # which would make the "Dishes:" line of the embedded text (and so
            # the embedding itself) differ from run to run.
            ds = list(dict.fromkeys(ds))
            cs = literal_eval(row['categories'])

            location = get_district_name(row['map_latitude'], row['map_longitude'])

            price = int(row['price'])
            if price < 100:
                price_bucket = 'cheap'
            elif price < 300:
                price_bucket = 'moderate'
            else:
                price_bucket = 'expensive'

            extra_data = ' Romantic Dining.' if 'Romantic Dining' in cs else ""

            # The trailing backslashes keep the text free of a leading and a
            # trailing newline.
            text = f"""\
Name: {name}
Intro: {row['intro']}{extra_data}
Dishes: {", ".join(ds)}
Location: {", ".join(location)}
Price: {price_bucket}\
"""

            r = RestaurantDescription(
                embedding=None,  # batch create all embeddings later
                text=text,
                id=row['id'],
                name=name,
                name_alt=name_alt,
                intro=row['intro'],
                price=price,
                rating=calculate_rating(row['score_cry'], row['score_o_k'], row['score_smile']),
                categories=cs,
                dishes=ds,
                info_url=row['poi_url'],
                image_url=row['door_photos'],
                location=location,
            )
            restaurants.append(r)
            restaurant_names.add(name)

            add_to_all(r, ds, dishes, Dish)
            add_to_all(r, cs, categories, Category)

    dish_list = DocList[Dish](dishes.values())
    category_list = DocList[Category](categories.values())

    # NOTE(review): this opens an interactive shell and blocks until it is
    # exited, before any embedding is created. It looks like a debugging
    # hook -- confirm whether it should stay in the committed script.
    import IPython
    IPython.embed()

    embedding_settings = AzureOpenaiEmbeddings.load_from_env()
    RestaurantDescription.create_embeddings(restaurants, **embedding_settings.to_settings_dict())
    Dish.create_embeddings(dish_list, **embedding_settings.to_settings_dict())
    Category.create_embeddings(category_list, **embedding_settings.to_settings_dict())

    restaurant_index.index(restaurants)
    dish_index.index(dish_list)
    category_index.index(category_list)

    restaurant_index.persist()
    dish_index.persist()
    category_index.persist()
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()