from collections import Counter
from math import log
import numpy as np
import plotly.graph_objects as go
from pyvis.network import Network
from ai_waiter_chatbot import MenuItem, TinyRetriever, normalize
# Categorical colours handed out to menu sections. Lookups in this module wrap
# around with a modulo, so more sections than entries simply reuse colours.
PALETTE = [
    "#4C78A8",
    "#F58518",
    "#54A24B",
    "#E45756",
    "#72B7B2",
    "#EECA3B",
    "#B279A2",
    "#FF9DA6",
    "#9D755D",
    "#BAB0AC",
]
def _build_tfidf_vectors(
items: list[MenuItem], retriever: TinyRetriever
) -> tuple[list[str], np.ndarray]:
vocab = sorted(retriever.df.keys())
token_to_idx = {t: i for i, t in enumerate(vocab)}
n_docs = len(items)
matrix = np.zeros((n_docs, len(vocab)), dtype=float)
for row_idx, doc_tokens in enumerate(retriever.doc_tokens):
tf = Counter(doc_tokens)
for token, count in tf.items():
col_idx = token_to_idx.get(token)
if col_idx is None:
continue
idf = log((n_docs + 1) / (retriever.df[token] + 1)) + 1.0
matrix[row_idx, col_idx] = count * idf
return vocab, matrix
def _query_vector(query: str, vocab: list[str], retriever: TinyRetriever) -> np.ndarray:
    """Encode ``query`` as a TF-IDF vector aligned with ``vocab``.

    Args:
        query: Raw query text; it is tokenised with ``normalize``.
        vocab: Token list defining the vector's column order.
        retriever: Provides ``df`` (token -> document frequency) and
            ``doc_tokens`` (its length, at least 1, is the document count).

    Returns:
        A 1-D float array of length ``len(vocab)``. Query tokens that are not
        in ``vocab`` are ignored; the others hold
        ``count * (log((n + 1) / (df + 1)) + 1)``.
    """
    column_of = {token: col for col, token in enumerate(vocab)}
    vec = np.zeros((len(vocab),), dtype=float)
    term_counts = Counter(normalize(query))
    doc_count = max(len(retriever.doc_tokens), 1)
    for token, freq in term_counts.items():
        col = column_of.get(token)
        if col is not None:
            smoothed_idf = log((doc_count + 1) / (retriever.df[token] + 1)) + 1.0
            vec[col] = freq * smoothed_idf
    return vec
def _project_2d(matrix: np.ndarray, query_vec: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
full = np.vstack([matrix, query_vec.reshape(1, -1)])
full = full - full.mean(axis=0, keepdims=True)
# SVD gives a stable 2D projection without extra dependencies.
_, _, vt = np.linalg.svd(full, full_matrices=False)
axes = vt[:2].T
coords = full @ axes
return coords[:-1], coords[-1]
def _project_nd(
matrix: np.ndarray, query_vec: np.ndarray, dims: int
) -> tuple[np.ndarray, np.ndarray]:
full = np.vstack([matrix, query_vec.reshape(1, -1)])
full = full - full.mean(axis=0, keepdims=True)
_, _, vt = np.linalg.svd(full, full_matrices=False)
take = min(dims, vt.shape[0])
axes = vt[:take].T
coords = full @ axes
if take < dims:
pad = np.zeros((coords.shape[0], dims - take))
coords = np.hstack([coords, pad])
return coords[:-1], coords[-1]
def _section_color_map(items: list[MenuItem]) -> dict[str, str]:
    """Assign a palette colour to every distinct menu section.

    Sections are sorted so the assignment is deterministic; colours are taken
    from ``PALETTE`` in order and wrap around when sections outnumber entries.
    """
    palette_size = len(PALETTE)
    ordered_sections = sorted({entry.section for entry in items})
    colors = {}
    for position, section_name in enumerate(ordered_sections):
        colors[section_name] = PALETTE[position % palette_size]
    return colors
def build_embedding_figure(
    query: str, items: list[MenuItem], retriever: TinyRetriever, top_k: int = 6
) -> go.Figure:
    """Build a 2D Plotly scatter of menu items, retrieved hits and the query.

    Items and the query are turned into TF-IDF vectors and projected to two
    dimensions. Each section gets its own semi-transparent marker trace, the
    items returned by ``retriever.retrieve`` are overlaid as open black
    diamonds, and the query is drawn as a green star.

    Args:
        query: The user's question.
        items: Menu items to plot; row ``i`` of the projection belongs to
            ``items[i]``.
        retriever: Supplies the TF-IDF statistics and ``retrieve(query, top_k=...)``.
        top_k: Number of results requested from the retriever; also shown in
            the figure title.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    vocab, matrix = _build_tfidf_vectors(items, retriever)
    q_vec = _query_vector(query, vocab, retriever)
    item_xy, query_xy = _project_2d(matrix, q_vec)
    retrieved = retriever.retrieve(query, top_k=top_k)
    # Retrieved results are matched back to items by this identifying triple.
    retrieved_ids = {(r.section, r.name, r.price) for r in retrieved}
    section_colors = _section_color_map(items)
    by_section: dict[str, dict[str, list]] = {}
    hit_x, hit_y, hit_text = [], [], []
    for idx, item in enumerate(items):
        # Plotly hover text breaks lines on <br>; a raw newline inside the
        # literal is both a syntax error and ignored by the hover renderer.
        payload = f"{item.name}<br>{item.section}<br>{item.price}"
        target = (item.section, item.name, item.price)
        point_x = float(item_xy[idx, 0])
        point_y = float(item_xy[idx, 1])
        bucket = by_section.setdefault(item.section, {"x": [], "y": [], "text": []})
        bucket["x"].append(point_x)
        bucket["y"].append(point_y)
        bucket["text"].append(payload)
        if target in retrieved_ids:
            hit_x.append(point_x)
            hit_y.append(point_y)
            hit_text.append(payload)
    fig = go.Figure()
    # One trace per section so each section gets its own legend entry and colour.
    for section, points in by_section.items():
        fig.add_trace(
            go.Scatter(
                x=points["x"],
                y=points["y"],
                mode="markers",
                name=section,
                marker={
                    "size": 8,
                    "opacity": 0.50,
                    "color": section_colors[section],
                },
                text=points["text"],
                hovertemplate="%{text}",
            )
        )
    # Overlay for the retrieved items, drawn on top of the section markers.
    fig.add_trace(
        go.Scatter(
            x=hit_x,
            y=hit_y,
            mode="markers",
            name="Retrieved nodes",
            marker={
                "size": 13,
                "opacity": 1.0,
                "color": "#111111",
                "symbol": "diamond-open",
                "line": {"width": 2},
            },
            text=hit_text,
            hovertemplate="%{text}",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=[float(query_xy[0])],
            y=[float(query_xy[1])],
            mode="markers",
            name="Your question",
            marker={"size": 16, "symbol": "star", "color": "#54A24B"},
            text=[query],
            hovertemplate="Query: %{text}",
        )
    )
    fig.update_layout(
        title=f"Interactive Embedding Space (Top-k retrieved: {top_k})",
        xaxis_title="Embedding Axis 1",
        yaxis_title="Embedding Axis 2",
        template="plotly_white",
        legend={"orientation": "h", "y": -0.15},
        margin={"l": 20, "r": 20, "t": 50, "b": 20},
    )
    return fig
def build_embedding_figure_3d(
    query: str, items: list[MenuItem], retriever: TinyRetriever, top_k: int = 6
) -> go.Figure:
    """Build a 3D Plotly scatter of menu items, retrieved hits and the query.

    Items and the query are turned into TF-IDF vectors and projected to three
    dimensions. Each section gets its own marker trace, the items returned by
    ``retriever.retrieve`` are overlaid as black diamonds, and the query is
    drawn as a green diamond.

    Args:
        query: The user's question.
        items: Menu items to plot; row ``i`` of the projection belongs to
            ``items[i]``.
        retriever: Supplies the TF-IDF statistics and ``retrieve(query, top_k=...)``.
        top_k: Number of results requested from the retriever; also shown in
            the figure title.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    vocab, matrix = _build_tfidf_vectors(items, retriever)
    q_vec = _query_vector(query, vocab, retriever)
    item_xyz, query_xyz = _project_nd(matrix, q_vec, dims=3)
    retrieved = retriever.retrieve(query, top_k=top_k)
    # Retrieved results are matched back to items by this identifying triple.
    retrieved_ids = {(r.section, r.name, r.price) for r in retrieved}
    section_colors = _section_color_map(items)
    by_section: dict[str, dict[str, list]] = {}
    hit_x, hit_y, hit_z, hit_text = [], [], [], []
    for idx, item in enumerate(items):
        # Plotly hover text breaks lines on <br>; a raw newline inside the
        # literal is both a syntax error and ignored by the hover renderer.
        payload = f"{item.name}<br>{item.section}<br>{item.price}"
        point_x = float(item_xyz[idx, 0])
        point_y = float(item_xyz[idx, 1])
        point_z = float(item_xyz[idx, 2])
        target = (item.section, item.name, item.price)
        bucket = by_section.setdefault(
            item.section, {"x": [], "y": [], "z": [], "text": []}
        )
        bucket["x"].append(point_x)
        bucket["y"].append(point_y)
        bucket["z"].append(point_z)
        bucket["text"].append(payload)
        if target in retrieved_ids:
            hit_x.append(point_x)
            hit_y.append(point_y)
            hit_z.append(point_z)
            hit_text.append(payload)
    fig = go.Figure()
    # One trace per section so each section gets its own legend entry and colour.
    for section, points in by_section.items():
        fig.add_trace(
            go.Scatter3d(
                x=points["x"],
                y=points["y"],
                z=points["z"],
                mode="markers",
                name=section,
                marker={"size": 3, "opacity": 0.55, "color": section_colors[section]},
                text=points["text"],
                hovertemplate="%{text}",
            )
        )
    # Overlay for the retrieved items, drawn on top of the section markers.
    fig.add_trace(
        go.Scatter3d(
            x=hit_x,
            y=hit_y,
            z=hit_z,
            mode="markers",
            name="Retrieved nodes",
            marker={"size": 6, "opacity": 1.0, "color": "#111111", "symbol": "diamond"},
            text=hit_text,
            hovertemplate="%{text}",
        )
    )
    fig.add_trace(
        go.Scatter3d(
            x=[float(query_xyz[0])],
            y=[float(query_xyz[1])],
            z=[float(query_xyz[2])],
            mode="markers",
            name="Your question",
            marker={"size": 8, "color": "#54A24B", "symbol": "diamond"},
            text=[query],
            hovertemplate="Query: %{text}",
        )
    )
    fig.update_layout(
        title=f"3D Embedding Explorer (Top-k retrieved: {top_k})",
        template="plotly_white",
        scene={
            "xaxis_title": "Axis 1",
            "yaxis_title": "Axis 2",
            "zaxis_title": "Axis 3",
        },
        legend={"orientation": "h", "y": -0.12},
        margin={"l": 10, "r": 10, "t": 50, "b": 10},
    )
    return fig
def build_pyvis_network_html(
    query: str, items: list[MenuItem], retriever: TinyRetriever, top_k: int = 6
) -> str:
    """Render a pyvis network linking the query to its retrieved menu items.

    A green node represents the query. Every menu item becomes a node filled
    with its section colour; items returned by ``retriever.retrieve`` are
    larger, get a black border, and are connected to the query node.

    Args:
        query: The user's question, shown as the query node's tooltip.
        items: Menu items to draw, one node each.
        retriever: Supplies ``retrieve(query, top_k=...)``.
        top_k: Number of results requested from the retriever.

    Returns:
        The HTML produced by ``Network.generate_html``.
    """
    net = Network(height="520px", width="100%", bgcolor="#ffffff", font_color="#222222")
    section_colors = _section_color_map(items)
    query_id = "query-node"
    net.add_node(query_id, label="Your Question", title=query, color="#54A24B", size=28)
    retrieved = retriever.retrieve(query, top_k=top_k)
    # Retrieved results are matched back to items by this identifying triple.
    retrieved_ids = {(r.section, r.name, r.price) for r in retrieved}
    for idx, item in enumerate(items):
        node_id = f"item-{idx}"
        fill = section_colors[item.section]
        is_retrieved = (item.section, item.name, item.price) in retrieved_ids
        size = 16 if is_retrieved else 10
        border = "#111111" if is_retrieved else fill
        # The tooltip must be a single string literal; its lines are joined
        # with <br>, the separator the pyvis tutorial uses for node titles.
        # NOTE(review): some vis-network builds show string titles as plain
        # text - confirm the tooltip renders as three lines.
        title = f"{item.name}<br>{item.section}<br>{item.price}"
        net.add_node(
            node_id,
            label=item.name[:28],  # labels are cut to 28 characters
            title=title,
            color={"background": fill, "border": border},
            size=size,
        )
        if is_retrieved:
            net.add_edge(query_id, node_id, color="#111111", width=2)
    net.force_atlas_2based(gravity=-45, central_gravity=0.006, spring_length=120)
    net.show_buttons(filter_=["physics"])
    return net.generate_html(name="embedding_network.html", notebook=False)