File size: 8,691 Bytes
ae63b4a eb9d187 ae63b4a 4f6f03e eb9d187 3ba73d3 eb9d187 857b4ce 08e20a5 ae63b4a 08e20a5 857b4ce ae63b4a 896f562 08e20a5 3ba73d3 857b4ce 08e20a5 eb9d187 08e20a5 eb9d187 857b4ce eb9d187 857b4ce eb9d187 08e20a5 eb9d187 857b4ce eb9d187 3ba73d3 08e20a5 857b4ce 08e20a5 857b4ce 3ba73d3 4f6f03e 3ba73d3 896f562 08e20a5 3ba73d3 857b4ce 08e20a5 3ba73d3 857b4ce 08e20a5 60b74b8 08e20a5 eb9d187 857b4ce 08e20a5 857b4ce 08e20a5 857b4ce eb9d187 08e20a5 eb9d187 08e20a5 eb9d187 08e20a5 60b74b8 08e20a5 eb9d187 857b4ce eb9d187 857b4ce 21c3413 857b4ce 08e20a5 857b4ce 08e20a5 857b4ce 896f562 857b4ce 896f562 08e20a5 857b4ce 896f562 857b4ce 896f562 08e20a5 857b4ce 08e20a5 857b4ce 08e20a5 857b4ce 08e20a5 21c3413 0d94b00 857b4ce 3ba73d3 eb9d187 08e20a5 857b4ce 08e20a5 857b4ce 08e20a5 857b4ce 60b74b8 08e20a5 3ba73d3 eb9d187 3ba73d3 08e20a5 eb9d187 08e20a5 857b4ce 3ba73d3 ae63b4a eb9d187 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import html
import os
import pickle

import gradio as gr
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# ==========================================
# 1. SETUP & DATA LOADING
# ==========================================
# Paths for the final 10k dataset and embeddings
CSV_PATH = "cleaned_dataset_10k.csv"
PKL_PATH = "final_embeddings_10k.pkl"

# Fail fast at import time and name exactly which artifact is absent, so a
# broken deployment is diagnosable from the first traceback.
missing_files = [path for path in (CSV_PATH, PKL_PATH) if not os.path.exists(path)]
if missing_files:
    raise FileNotFoundError(
        "Missing required data files (CSV or PKL) in the root directory: "
        + ", ".join(missing_files)
    )

# Load the restaurant dataset and normalise the headers to snake_case; the
# rest of the module addresses columns by those names (e.g. 'reviewer_persona').
df = pd.read_csv(CSV_PATH)
df.columns = [c.strip().lower().replace(' ', '_') for c in df.columns]

# Load pre-computed embeddings (768 dimensions)
# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only ship a PKL that was produced by a trusted pipeline.
with open(PKL_PATH, 'rb') as f:
    embedding_data = pickle.load(f)

# The pickle may hold either the bare matrix or a dict wrapping it under
# the 'embeddings' key.
dataset_embeddings = embedding_data['embeddings'] if isinstance(embedding_data, dict) else embedding_data

# Coerce to an ndarray: the module later fancy-indexes this object with a
# pandas Index, which a plain Python list does not support.
dataset_embeddings = np.asarray(dataset_embeddings)

# Row i of the matrix must describe row i of the CSV; a mismatch would
# otherwise surface as an IndexError or as silently wrong recommendations.
if dataset_embeddings.ndim != 2 or dataset_embeddings.shape[0] != len(df):
    raise ValueError(
        f"Embeddings in {PKL_PATH} have shape {dataset_embeddings.shape}, "
        f"expected one row per CSV record ({len(df)} rows in {CSV_PATH})."
    )

# Load the semantic transformer model (MPNet)
# This model converts natural language queries into semantic vectors
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
# Pre-calculate Taste Profiles: one mean embedding vector per reviewer persona.
# Rows whose persona is missing are left out; each profile is the column-wise
# average of the embeddings belonging to that persona's reviews.
persona_column = df['reviewer_persona']
persona_profiles = {
    label: np.mean(dataset_embeddings[df.index[persona_column == label]], axis=0)
    for label in persona_column.dropna().unique()
}
# ==========================================
# 2. HYBRID RECOMMENDATION ENGINE
# ==========================================
def _render_match_card(name, persona, rating, review, match_pct):
    """Build the HTML result card, escaping every data-driven field.

    Args:
        name: Restaurant name taken from the dataset row.
        persona: Label of the persona profile that was matched.
        rating: Original rating value of the matched review.
        review: Review text (already a ``str``, possibly empty).
        match_pct: Integer confidence percentage to display.

    Returns:
        The card as an HTML string with inline styles.
    """
    safe_name = html.escape(str(name))
    safe_persona = html.escape(str(persona))
    safe_rating = html.escape(str(rating))

    # Truncate BEFORE escaping so an entity such as "&amp;" is never cut in
    # half, and only add the ellipsis when text was actually dropped.
    snippet = review[:200]
    if len(review) > 200:
        snippet += "..."
    safe_snippet = html.escape(snippet)

    # Absolute colors (with !important) keep the card readable regardless of
    # the surrounding Gradio theme.
    return f"""
<div style="background: white; border-radius: 20px; padding: 25px; color: #0f172a !important; text-align: left; border-left: 10px solid #f97316; box-shadow: 0 10px 25px rgba(0,0,0,0.2);">
<div style="display:flex; justify-content:space-between; align-items: flex-start;">
<div style="flex: 1;">
<h2 style="margin:0; font-size: 24px; font-weight: 900; color: #0f172a !important;">{safe_name}</h2>
<div style="font-size: 14px; color: #475569 !important; font-weight: 700; margin-top: 4px;">Matched for: {safe_persona} profile</div>
</div>
<div style="text-align:right; background: #f8fafc; padding: 10px; border-radius: 12px; border: 1px solid #e2e8f0;">
<div style="font-size: 30px; font-weight: 950; color: #2563eb !important; line-height: 1;">{safe_rating}</div>
<div style="font-size:10px; font-weight:900; color: #64748b !important; letter-spacing: 1px; margin-top: 5px;">RATING</div>
</div>
</div>
<hr style="border:0; border-top: 1px solid #e2e8f0; margin: 15px 0;">
<p style="color: #1e293b !important; line-height:1.6; font-size: 16px; font-weight: 500;">
<i style="color: #334155 !important;">"{safe_snippet}"</i>
</p>
<div style="margin-top:20px; display: flex; justify-content: space-between; align-items: center;">
<span style="font-size: 13px; font-weight: 800; color: #f97316;">VEN Match Confidence: {match_pct}%</span>
<span style="font-size: 11px; background: #0f172a; color: white; padding: 4px 10px; border-radius: 6px; font-weight: 700;">AI MATCH</span>
</div>
</div>
"""


def run_ven_engine(budget, dietary, company, purpose, noise):
    """
    Finds the best restaurant match by combining Persona Similarity (General Taste)
    and Review Similarity (Contextual Fit).

    The five survey answers are folded into one sentence, embedded with the
    module-level ``model``, and compared first against the per-persona mean
    vectors in ``persona_profiles`` and then against every review embedding of
    the winning persona. Reviews are ranked by
    ``0.7 * cosine similarity + 0.3 * (rating_score / 5)``.

    Args:
        budget: Selected budget tier.
        dietary: Selected diet preference.
        company: Selected social context.
        purpose: Selected occasion.
        noise: Selected atmosphere / noise level.

    Returns:
        An HTML string containing the styled recommendation card for the
        top-ranked review. All dataset-derived text is HTML-escaped.
    """
    # Create the user's specific context string
    user_context = f"Searching for a {budget} experience, {dietary} friendly. Group: {company}. Occasion: {purpose}. Atmosphere: {noise}."
    query_vec = model.encode([user_context])

    # Step A: Identify the closest overall Persona Profile
    persona_sims = {
        p: cosine_similarity(query_vec, v.reshape(1, -1))[0][0]
        for p, v in persona_profiles.items()
    }
    closest_persona = max(persona_sims, key=persona_sims.get)

    # Step B: Search specifically within that persona's reviews for the best contextual fit
    persona_indices = df[df['reviewer_persona'] == closest_persona].index
    persona_embeddings = dataset_embeddings[persona_indices]

    # Compute similarity for every individual review in this persona group
    sub_similarities = cosine_similarity(query_vec, persona_embeddings)[0]

    persona_df = df.loc[persona_indices].copy()
    persona_df['semantic_fit'] = sub_similarities
    persona_df['norm_rating'] = persona_df['rating_score'] / 5.0

    # CALCULATE HYBRID SCORE (70% context match + 30% original rating)
    persona_df['final_score'] = (persona_df['semantic_fit'] * 0.7) + (persona_df['norm_rating'] * 0.3)

    # Retrieve the top re-ranked result (NaN scores sort last)
    top_match = persona_df.sort_values(by='final_score', ascending=False).iloc[0]

    # A NaN score (e.g. missing rating on every row) cannot be cast to int;
    # show 0% instead of crashing, and keep the figure inside 0-100.
    final_score = top_match['final_score']
    if pd.isna(final_score):
        match_pct = 0
    else:
        match_pct = max(0, min(100, int(final_score * 100)))

    # A missing review is read back from the CSV as NaN (a float), which
    # cannot be sliced; fall back to an empty quote.
    raw_review = top_match['review_content_clean']
    review_text = "" if pd.isna(raw_review) else str(raw_review)

    return _render_match_card(
        top_match['restaurant_name'],
        closest_persona,
        top_match['rating_score'],
        review_text,
        match_pct,
    )
# ==========================================
# 3. GRADIO UI SETUP (VISIBILITY FIX)
# ==========================================
# Custom stylesheet handed to gr.Blocks(css=...) when the UI is built.
# We use targeted CSS to ensure labels are white on the dark page background,
# while the choice text inside radio buttons stays dark on its light background.
# It also styles the element carrying the "ven-button" class (the submit
# button) and the examples table.
# NOTE(review): selectors such as .gr-radio and .gr-samples-table depend on the
# class names emitted by the installed Gradio version - confirm they still match.
ven_css = """
/* Global Container */
.gradio-container { background-color: #0f172a !important; }
h1 { color: white !important; text-align: center; font-weight: 950 !important; font-size: 2.5rem !important; margin-bottom: 20px !important; }
/* Input Header Labels (The questions above the inputs) */
label span {
color: white !important;
font-weight: 700 !important;
font-size: 14px !important;
margin-bottom: 5px !important;
}
/* RADIO CHOICE TEXT (The actual options like 'Solo' or 'Lively') */
/* These elements are usually inside white/grey boxes in Gradio theme,
so we force the text to be DARK for visibility. */
.gr-radio label span,
.gr-radio span,
[data-testid="block-info"] + div label span {
color: #1e293b !important;
font-weight: 700 !important;
}
/* Style the primary orange button */
.ven-button {
background-color: #f97316 !important;
color: white !important;
border: none !important;
font-weight: 900 !important;
font-size: 18px !important;
height: 52px !important;
border-radius: 12px !important;
}
/* Quick Vibe Starters table styling */
.gr-samples-table { background-color: #1e293b !important; color: white !important; }
"""
# Two-column layout: survey on the left, recommendation card on the right,
# with a row of clickable example answers underneath.
# NOTE(review): nesting was reconstructed from a copy whose indentation was
# lost - confirm the button is meant to sit below (not inside) the Group.
with gr.Blocks(css=ven_css, title="VEN - AI Restaurant Matchmaker") as demo:
    gr.Markdown("# ๐ VEN: Restaurant Matchmaker")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                # Survey inputs for the 5 dimensions
                in_budget = gr.Dropdown(
                    ["Budget-friendly", "Mid-range", "Premium"],
                    label="1. Wallet Size",
                    value="Mid-range",
                )
                in_diet = gr.Dropdown(
                    ["Anything", "Vegetarian", "Vegan", "Meat-lover"],
                    label="2. Diet Preference",
                    value="Anything",
                )
                in_company = gr.Radio(
                    ["Solo", "Date/Couple", "Friends", "Business"],
                    label="3. Social Context",
                    value="Date/Couple",
                )
                in_purpose = gr.Dropdown(
                    ["Casual dinner", "Special occasion", "Quick bite", "Professional meeting"],
                    label="4. The Mission",
                    value="Casual dinner",
                )
                in_noise = gr.Radio(
                    ["Quiet/Intimate", "Moderate/Social", "Lively/Music"],
                    label="5. Vibe / Noise",
                    value="Moderate/Social",
                )
            btn = gr.Button("Find My Table", variant="primary", elem_classes="ven-button")

        with gr.Column(scale=1):
            # Placeholder shown until the first recommendation is generated
            output_ui = gr.HTML("<div style='text-align:center; padding:100px; color:#64748b; font-weight:700; border: 2px dashed #1e293b; border-radius: 20px;'>Fill the survey to generate your AI match</div>")

    # The same five components feed both the examples and the button click,
    # in the positional order run_ven_engine expects.
    survey_inputs = [in_budget, in_diet, in_company, in_purpose, in_noise]

    gr.Markdown("### ๐ Quick Vibe Starters")
    gr.Examples(
        examples=[
            ["Budget-friendly", "Vegetarian", "Friends", "Quick bite", "Moderate/Social"],
            ["Premium", "Meat-lover", "Date/Couple", "Special occasion", "Quiet/Intimate"],
            ["Mid-range", "Anything", "Business", "Professional meeting", "Quiet/Intimate"],
        ],
        inputs=survey_inputs,
        outputs=output_ui,
        fn=run_ven_engine,
        cache_examples=False,
    )

    # Event binding
    btn.click(run_ven_engine, inputs=survey_inputs, outputs=output_ui)
# Start the Gradio server only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()