|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import pickle |
|
|
import os |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Data files produced by the offline preprocessing step. Both are relative paths,
# so they are resolved against the process's current working directory.
CSV_PATH = "cleaned_dataset_10k.csv"
PKL_PATH = "final_embeddings_10k.pkl"

# Fail fast at import time: everything below needs both files. Report exactly which
# file(s) are absent and where we looked, so a misconfigured launch is easy to fix.
_missing_data_files = [p for p in (CSV_PATH, PKL_PATH) if not os.path.isfile(p)]
if _missing_data_files:
    raise FileNotFoundError(
        "Missing required data files (CSV or PKL): "
        + ", ".join(_missing_data_files)
        + f" (looked in working directory {os.getcwd()})"
    )
|
|
|
|
|
|
|
|
# Load the review dataset and normalise its headers (strip, lowercase, spaces -> "_")
# so later code can refer to columns such as 'reviewer_persona' and 'rating_score'.
df = pd.read_csv(CSV_PATH)
df.columns = [c.strip().lower().replace(' ', '_') for c in df.columns]

# Load the precomputed review embeddings. The pickle may hold either a dict with an
# 'embeddings' entry or the embeddings object itself.
# NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
with open(PKL_PATH, 'rb') as f:
    embedding_data = pickle.load(f)

# Coerce to an ndarray so row selection with index arrays / boolean masks works even
# when the pickle stored a plain list of vectors.
dataset_embeddings = np.asarray(
    embedding_data['embeddings'] if isinstance(embedding_data, dict) else embedding_data
)

# Embedding row i is paired with CSV row i purely by position. If the counts differ,
# that pairing is broken and every similarity score would be attributed to the wrong
# review, so refuse to start instead of serving silently wrong matches.
if len(dataset_embeddings) != len(df):
    raise ValueError(
        f"Embedding count ({len(dataset_embeddings)}) does not match the number of rows "
        f"in {CSV_PATH} ({len(df)}); embeddings are paired with rows by position."
    )
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face id of the sentence-embedding model used to encode user queries.
# NOTE(review): query vectors are compared by cosine similarity against the stored
# dataset embeddings, so this presumably must be the same model that produced them.
EMBEDDING_MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
model = SentenceTransformer(EMBEDDING_MODEL_NAME)
|
|
|
|
|
|
|
|
# One "taste profile" per reviewer persona: the mean embedding of every review written
# under that persona. Rows whose persona is missing (NaN/None) are skipped. Personas are
# inserted in order of first appearance in the CSV.
persona_profiles = {}
_persona_column = df['reviewer_persona']
_embedding_matrix = np.asarray(dataset_embeddings)

for persona in _persona_column.dropna().unique():
    # Positional row numbers (not index labels) keep the lookup aligned with the
    # embedding matrix even if df's index is not a plain 0..N-1 RangeIndex.
    rows = np.flatnonzero((_persona_column == persona).to_numpy())
    persona_profiles[persona] = np.mean(_embedding_matrix[rows], axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_ven_engine(budget, dietary, company, purpose, noise):
    """
    Finds the best restaurant match by combining Persona Similarity (General Taste)
    and Review Similarity (Contextual Fit).

    The survey answers are turned into one natural-language query and embedded with
    the module-level ``model``. Stage 1 picks the reviewer persona whose profile vector
    (from ``persona_profiles``) is most cosine-similar to the query. Stage 2 ranks only
    that persona's reviews by ``0.7 * semantic_fit + 0.3 * rating_score / 5``.

    Args:
        budget, dietary, company, purpose, noise: survey answers from the UI controls.

    Returns:
        str: an HTML card for the top-ranked review. It shows the restaurant, its rating,
        the first 200 characters of the review and the combined score as a percentage.
        All dataset text is HTML-escaped before it is inserted into the card.
    """
    import html  # local import: only used to escape dataset text for the HTML card

    if not persona_profiles:
        # max/argmax over no personas would raise; show a readable message instead.
        return (
            "<div style='text-align:center; padding:40px; color:#64748b; font-weight:700;'>"
            "No reviewer personas are available to match against.</div>"
        )

    user_context = f"Searching for a {budget} experience, {dietary} friendly. Group: {company}. Occasion: {purpose}. Atmosphere: {noise}."
    query_vec = model.encode([user_context])

    # Stage 1: score every persona profile in one vectorised call.
    persona_names = list(persona_profiles)
    profile_matrix = np.vstack(
        [np.asarray(persona_profiles[p]).reshape(1, -1) for p in persona_names]
    )
    persona_sims = cosine_similarity(query_vec, profile_matrix)[0]
    closest_persona = persona_names[int(np.argmax(persona_sims))]

    # Stage 2: restrict to that persona's reviews. A boolean mask selects the same
    # positions in df and in the embedding matrix, keeping them aligned.
    in_persona = (df['reviewer_persona'] == closest_persona).to_numpy()
    persona_df = df.loc[in_persona].copy()
    persona_embeddings = np.asarray(dataset_embeddings)[in_persona]

    persona_df['semantic_fit'] = cosine_similarity(query_vec, persona_embeddings)[0]
    # NOTE(review): presumably ratings are on a 0-5 scale — confirm against the dataset.
    persona_df['norm_rating'] = persona_df['rating_score'] / 5.0
    persona_df['final_score'] = (persona_df['semantic_fit'] * 0.7) + (persona_df['norm_rating'] * 0.3)

    top_match = persona_df.sort_values(by='final_score', ascending=False).iloc[0]
    match_pct = int(top_match['final_score'] * 100)

    # A missing review would make slicing raise; only add "..." when text was cut.
    review = top_match['review_content_clean']
    review = "" if pd.isna(review) else str(review)
    snippet = review[:200] + ("..." if len(review) > 200 else "")

    # Reviews and names are third-party text: escape them so markup in the data cannot
    # break (or inject into) the rendered card.
    restaurant = html.escape(str(top_match['restaurant_name']))
    persona_label = html.escape(str(closest_persona))
    rating = html.escape(str(top_match['rating_score']))
    snippet = html.escape(snippet)

    return f"""
<div style="background: white; border-radius: 20px; padding: 25px; color: #0f172a !important; text-align: left; border-left: 10px solid #f97316; box-shadow: 0 10px 25px rgba(0,0,0,0.2);">
<div style="display:flex; justify-content:space-between; align-items: flex-start;">
<div style="flex: 1;">
<h2 style="margin:0; font-size: 24px; font-weight: 900; color: #0f172a !important;">{restaurant}</h2>
<div style="font-size: 14px; color: #475569 !important; font-weight: 700; margin-top: 4px;">Matched for: {persona_label} profile</div>
</div>
<div style="text-align:right; background: #f8fafc; padding: 10px; border-radius: 12px; border: 1px solid #e2e8f0;">
<div style="font-size: 30px; font-weight: 950; color: #2563eb !important; line-height: 1;">{rating}</div>
<div style="font-size:10px; font-weight:900; color: #64748b !important; letter-spacing: 1px; margin-top: 5px;">RATING</div>
</div>
</div>
<hr style="border:0; border-top: 1px solid #e2e8f0; margin: 15px 0;">
<p style="color: #1e293b !important; line-height:1.6; font-size: 16px; font-weight: 500;">
<i style="color: #334155 !important;">"{snippet}"</i>
</p>
<div style="margin-top:20px; display: flex; justify-content: space-between; align-items: center;">
<span style="font-size: 13px; font-weight: 800; color: #f97316;">VEN Match Confidence: {match_pct}%</span>
<span style="font-size: 11px; background: #0f172a; color: white; padding: 4px 10px; border-radius: 6px; font-weight: 700;">AI MATCH</span>
</div>
</div>
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Custom stylesheet passed to gr.Blocks(css=...) for the app UI. It sets a dark page
# background and white input labels, forces dark text on radio choices, and styles the
# element tagged with elem_classes="ven-button" plus the examples table.
# NOTE(review): `.gr-radio` / `.gr-samples-table` look like Gradio-version-specific class
# names; confirm they still match the DOM of the installed Gradio version.
ven_css = """
/* Global Container */
.gradio-container { background-color: #0f172a !important; }
h1 { color: white !important; text-align: center; font-weight: 950 !important; font-size: 2.5rem !important; margin-bottom: 20px !important; }

/* Input Header Labels (The questions above the inputs) */
label span {
    color: white !important;
    font-weight: 700 !important;
    font-size: 14px !important;
    margin-bottom: 5px !important;
}

/* RADIO CHOICE TEXT (The actual options like 'Solo' or 'Lively') */
/* These elements are usually inside white/grey boxes in Gradio theme,
so we force the text to be DARK for visibility. */
.gr-radio label span,
.gr-radio span,
[data-testid="block-info"] + div label span {
    color: #1e293b !important;
    font-weight: 700 !important;
}

/* Style the primary orange button */
.ven-button {
    background-color: #f97316 !important;
    color: white !important;
    border: none !important;
    font-weight: 900 !important;
    font-size: 18px !important;
    height: 52px !important;
    border-radius: 12px !important;
}

/* Quick Vibe Starters table styling */
.gr-samples-table { background-color: #1e293b !important; color: white !important; }
"""
|
|
|
|
|
# App layout: the survey inputs and the "Find My Table" button are on the left, and the
# generated match card is on the right. Below them is a row of example presets that
# fill in the survey.
with gr.Blocks(css=ven_css, title="VEN - AI Restaurant Matchmaker") as demo:
    # The heading previously held a mis-encoded emoji byte sequence; it has been removed.
    gr.Markdown("# VEN: Restaurant Matchmaker")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                in_budget = gr.Dropdown(["Budget-friendly", "Mid-range", "Premium"], label="1. Wallet Size", value="Mid-range")
                in_diet = gr.Dropdown(["Anything", "Vegetarian", "Vegan", "Meat-lover"], label="2. Diet Preference", value="Anything")
                in_company = gr.Radio(["Solo", "Date/Couple", "Friends", "Business"], label="3. Social Context", value="Date/Couple")
                in_purpose = gr.Dropdown(["Casual dinner", "Special occasion", "Quick bite", "Professional meeting"], label="4. The Mission", value="Casual dinner")
                in_noise = gr.Radio(["Quiet/Intimate", "Moderate/Social", "Lively/Music"], label="5. Vibe / Noise", value="Moderate/Social")

            btn = gr.Button("Find My Table", variant="primary", elem_classes="ven-button")

        with gr.Column(scale=1):
            output_ui = gr.HTML("<div style='text-align:center; padding:100px; color:#64748b; font-weight:700; border: 2px dashed #1e293b; border-radius: 20px;'>Fill the survey to generate your AI match</div>")

    # Single source of truth for the argument order of run_ven_engine
    # (budget, dietary, company, purpose, noise). The examples and the button share it.
    survey_inputs = [in_budget, in_diet, in_company, in_purpose, in_noise]

    gr.Markdown("### Quick Vibe Starters")
    gr.Examples(
        examples=[
            ["Budget-friendly", "Vegetarian", "Friends", "Quick bite", "Moderate/Social"],
            ["Premium", "Meat-lover", "Date/Couple", "Special occasion", "Quiet/Intimate"],
            ["Mid-range", "Anything", "Business", "Professional meeting", "Quiet/Intimate"],
        ],
        inputs=survey_inputs,
        outputs=output_ui,
        fn=run_ven_engine,
        cache_examples=False,
    )

    btn.click(run_ven_engine, inputs=survey_inputs, outputs=output_ui)
|
|
|
|
|
def _main():
    """Start the Gradio server for the app defined above, using launch() defaults."""
    demo.launch()


if __name__ == "__main__":
    _main()