MatanYehudaDataAnalyst committed on
Commit
0d94b00
·
verified ·
1 Parent(s): 1079e48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -46
app.py CHANGED
@@ -3,35 +3,34 @@ import pandas as pd
3
  import numpy as np
4
  import pickle
5
  import os
6
- from sentence_transformers import SentenceTransformer
7
- from sklearn.metrics.pairwise import cosine_similarity
8
 
9
  # ==========================================
10
  # 1. SETUP & DATA LOADING
11
  # ==========================================
 
12
  csv_path = "cleaned_dataset_10k.csv"
13
  pkl_path = "final_embeddings_10k.pkl"
14
 
15
  if not os.path.exists(csv_path) or not os.path.exists(pkl_path):
16
- raise FileNotFoundError("❌ Missing files! Please upload 'cleaned_dataset_10k.csv' and 'final_embeddings_10k.pkl'")
 
17
 
18
- # Load Data
19
  df = pd.read_csv(csv_path)
20
-
21
- # --- SAFETY FIX: Normalize Column Names ---
22
- # This ensures it works whether your CSV has "Restaurant Name" or "restaurant_name"
23
  df.columns = [c.strip().lower().replace(' ', '_') for c in df.columns]
24
 
25
- # Check for required columns and fallback if missing
26
  def get_col(candidates, default):
27
  for c in candidates:
28
  if c in df.columns: return c
29
  return default
30
 
31
- col_name = get_col(['restaurant_name', 'name', 'place'], 'restaurant_name')
32
  col_rating = get_col(['rating', 'rating_score', 'stars'], 'rating')
33
  col_review = get_col(['review', 'review_content', 'review_content_clean'], 'review')
34
- col_persona = get_col(['reviewer_persona', 'persona', 'type'], 'reviewer_persona')
35
 
36
  # Load Embeddings
37
  with open(pkl_path, 'rb') as f:
@@ -41,69 +40,66 @@ with open(pkl_path, 'rb') as f:
41
  # Load Model
42
  model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
43
 
44
- # Calculate Persona Profiles
45
  persona_profiles = {}
46
  if col_persona in df.columns:
47
  for persona in df[col_persona].unique():
48
  if pd.isna(persona): continue
49
  indices = df[df[col_persona] == persona].index
 
50
  valid_indices = [i for i in indices if i < len(dataset_embeddings)]
51
  if valid_indices:
52
  persona_vectors = dataset_embeddings[valid_indices]
 
53
  persona_profiles[persona] = np.mean(persona_vectors, axis=0)
54
  else:
55
- # Fallback if no persona column exists
56
  persona_profiles['Default'] = np.mean(dataset_embeddings, axis=0)
57
 
58
  # ==========================================
59
- # 2. UI DESIGN (VEN BRANDING)
60
- # ==========================================
61
- ven_css = """
62
- body { background: radial-gradient(1200px 600px at 20% 0%, #eef6ff 0%, #f8fafc 45%, #ffffff 100%) !important; font-family: sans-serif !important; }
63
- .ven-card { background: white; border: 1px solid #e2e8f0; border-radius: 20px; padding: 24px; box-shadow: 0 10px 30px -10px rgba(0,0,0,0.1); }
64
- .ven-header { font-size: 24px; font-weight: 800; color: #1e293b; margin-bottom: 5px; }
65
- .ven-sub { font-size: 14px; color: #64748b; font-weight: 600; margin-bottom: 20px; }
66
- .ven-score { font-size: 32px; font-weight: 900; color: #2563eb; }
67
- .ven-btn { background: #2563eb; color: white; border: none; font-weight: 700; border-radius: 12px; }
68
- """
69
-
70
- # ==========================================
71
- # 3. LOGIC ENGINE
72
  # ==========================================
73
  def run_ven_engine(budget, dietary, company, purpose, noise):
74
- # 1. Create a search query from the dropdowns
75
  user_context = f"Searching for a {budget} experience, {dietary} friendly. Group: {company}. Occasion: {purpose}. Atmosphere: {noise}."
76
 
77
- # 2. Encode the query
78
- query_vec = model.encode([user_context])
79
 
80
- # 3. Find the closest Persona (Cluster)
81
- similarities = {p: cosine_similarity(query_vec, v.reshape(1, -1))[0][0] for p, v in persona_profiles.items()}
82
- closest_persona = max(similarities, key=similarities.get)
 
 
 
 
 
 
 
 
83
 
84
- # 4. Filter data for that persona
85
  if col_persona in df.columns:
86
  persona_df = df[df[col_persona] == closest_persona]
87
  if persona_df.empty: persona_df = df
88
  else:
89
  persona_df = df
90
 
91
- # 5. Get the highest rated restaurant in that group
92
  top_match = persona_df.sort_values(by=col_rating, ascending=False).iloc[0]
93
 
94
- # 6. Format the output HTML
95
- match_pct = int(similarities[closest_persona] * 100)
96
  review_text = str(top_match[col_review])[:180] + "..."
 
97
 
98
  return f"""
99
- <div class="ven-card">
100
  <div style="display:flex; justify-content:space-between;">
101
  <div>
102
- <div class="ven-header">{top_match[col_name]}</div>
103
- <div class="ven-sub">Top Match for {closest_persona}</div>
104
  </div>
105
  <div style="text-align:right;">
106
- <div class="ven-score">{top_match[col_rating]}</div>
107
  <div style="font-size:12px; font-weight:bold; color:#94a3b8;">RATING</div>
108
  </div>
109
  </div>
@@ -114,11 +110,12 @@ def run_ven_engine(budget, dietary, company, purpose, noise):
114
  """
115
 
116
  # ==========================================
117
- # 4. APP LAYOUT
118
  # ==========================================
 
 
119
  with gr.Blocks(css=ven_css, title="VEN Project") as demo:
120
  gr.Markdown("# πŸ” VEN: Restaurant Matchmaker")
121
- gr.Markdown("Select your vibe below to get a personalized recommendation.")
122
 
123
  with gr.Row():
124
  with gr.Column():
@@ -127,14 +124,12 @@ with gr.Blocks(css=ven_css, title="VEN Project") as demo:
127
  in_company = gr.Radio(["Solo", "Date/Couple", "Friends", "Business"], label="With who?", value="Date/Couple")
128
  in_purpose = gr.Dropdown(["Casual dinner", "Special occasion", "Quick bite"], label="Occasion", value="Casual dinner")
129
  in_noise = gr.Radio(["Quiet/Intimate", "Moderate/Social", "Lively/Music"], label="Environment", value="Moderate/Social")
130
-
131
  btn = gr.Button("Find My Table", variant="primary")
132
 
133
  with gr.Column():
134
  output_ui = gr.HTML("<h4>Recommendation will appear here...</h4>")
135
 
136
- # --- STEP 7: REQUIRED QUICK STARTERS ---
137
- gr.Markdown("### 🚀 One-Click Examples (Quick Starters)")
138
  gr.Examples(
139
  examples=[
140
  ["Budget-friendly", "Vegetarian", "Friends", "Quick bite", "Moderate/Social"],
@@ -144,9 +139,8 @@ with gr.Blocks(css=ven_css, title="VEN Project") as demo:
144
  inputs=[in_budget, in_diet, in_company, in_purpose, in_noise],
145
  outputs=output_ui,
146
  fn=run_ven_engine,
147
- cache_examples=True, # This pre-runs the examples so they are instant!
148
  )
149
-
150
  btn.click(run_ven_engine, inputs=[in_budget, in_diet, in_company, in_purpose, in_noise], outputs=output_ui)
151
 
152
  if __name__ == "__main__":
 
3
  import numpy as np
4
  import pickle
5
  import os
6
+ from sentence_transformers import SentenceTransformer, util
7
+ import torch
8
 
9
  # ==========================================
10
  # 1. SETUP & DATA LOADING
11
  # ==========================================
12
+ # NOTE: Check your file names exactly!
13
  csv_path = "cleaned_dataset_10k.csv"
14
  pkl_path = "final_embeddings_10k.pkl"
15
 
16
  if not os.path.exists(csv_path) or not os.path.exists(pkl_path):
17
+ # This error usually means the file names in the 'Files' tab are different
18
+ raise FileNotFoundError(f"❌ FILES NOT FOUND. I see these files: {os.listdir('.')}")
19
 
20
+ # Load Data & Normalize Columns
21
  df = pd.read_csv(csv_path)
 
 
 
22
  df.columns = [c.strip().lower().replace(' ', '_') for c in df.columns]
23
 
24
+ # Helper to find columns even if names vary slightly
25
  def get_col(candidates, default):
26
  for c in candidates:
27
  if c in df.columns: return c
28
  return default
29
 
30
+ col_name = get_col(['restaurant_name', 'name'], 'restaurant_name')
31
  col_rating = get_col(['rating', 'rating_score', 'stars'], 'rating')
32
  col_review = get_col(['review', 'review_content', 'review_content_clean'], 'review')
33
+ col_persona = get_col(['reviewer_persona', 'persona'], 'reviewer_persona')
34
 
35
  # Load Embeddings
36
  with open(pkl_path, 'rb') as f:
 
40
  # Load Model
41
  model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
42
 
43
+ # Calculate Persona Profiles (Average Vectors)
44
  persona_profiles = {}
45
  if col_persona in df.columns:
46
  for persona in df[col_persona].unique():
47
  if pd.isna(persona): continue
48
  indices = df[df[col_persona] == persona].index
49
+ # valid_indices ensures we don't crash if indices mismatch
50
  valid_indices = [i for i in indices if i < len(dataset_embeddings)]
51
  if valid_indices:
52
  persona_vectors = dataset_embeddings[valid_indices]
53
+ # Use torch/numpy to calculate mean
54
  persona_profiles[persona] = np.mean(persona_vectors, axis=0)
55
  else:
 
56
  persona_profiles['Default'] = np.mean(dataset_embeddings, axis=0)
57
 
58
  # ==========================================
59
+ # 2. LOGIC ENGINE (Replaced Scikit-Learn with Util)
 
 
 
 
 
 
 
 
 
 
 
 
60
  # ==========================================
61
  def run_ven_engine(budget, dietary, company, purpose, noise):
62
+ # 1. Create a search query
63
  user_context = f"Searching for a {budget} experience, {dietary} friendly. Group: {company}. Occasion: {purpose}. Atmosphere: {noise}."
64
 
65
+ # 2. Encode query
66
+ query_vec = model.encode(user_context, convert_to_tensor=True)
67
 
68
+ # 3. Find closest Persona using Sentence-Transformers Utility (No Sklearn needed)
69
+ best_score = -1
70
+ closest_persona = list(persona_profiles.keys())[0]
71
+
72
+ for persona, profile_vec in persona_profiles.items():
73
+ # Convert profile to tensor for comparison
74
+ profile_tensor = torch.tensor(profile_vec)
75
+ score = util.cos_sim(query_vec, profile_tensor).item()
76
+ if score > best_score:
77
+ best_score = score
78
+ closest_persona = persona
79
 
80
+ # 4. Filter data
81
  if col_persona in df.columns:
82
  persona_df = df[df[col_persona] == closest_persona]
83
  if persona_df.empty: persona_df = df
84
  else:
85
  persona_df = df
86
 
87
+ # 5. Get top result
88
  top_match = persona_df.sort_values(by=col_rating, ascending=False).iloc[0]
89
 
90
+ # 6. Format Output
 
91
  review_text = str(top_match[col_review])[:180] + "..."
92
+ match_pct = int(best_score * 100)
93
 
94
  return f"""
95
+ <div style="background: white; border: 1px solid #e2e8f0; border-radius: 20px; padding: 24px; box-shadow: 0 10px 30px -10px rgba(0,0,0,0.1);">
96
  <div style="display:flex; justify-content:space-between;">
97
  <div>
98
+ <div style="font-size: 24px; font-weight: 800; color: #1e293b;">{top_match[col_name]}</div>
99
+ <div style="font-size: 14px; color: #64748b; font-weight: 600;">Top Match for {closest_persona}</div>
100
  </div>
101
  <div style="text-align:right;">
102
+ <div style="font-size: 32px; font-weight: 900; color: #2563eb;">{top_match[col_rating]}</div>
103
  <div style="font-size:12px; font-weight:bold; color:#94a3b8;">RATING</div>
104
  </div>
105
  </div>
 
110
  """
111
 
112
  # ==========================================
113
+ # 3. APP UI
114
  # ==========================================
115
+ ven_css = """body { background: radial-gradient(1200px 600px at 20% 0%, #eef6ff 0%, #f8fafc 45%, #ffffff 100%) !important; font-family: sans-serif !important; }"""
116
+
117
  with gr.Blocks(css=ven_css, title="VEN Project") as demo:
118
  gr.Markdown("# πŸ” VEN: Restaurant Matchmaker")
 
119
 
120
  with gr.Row():
121
  with gr.Column():
 
124
  in_company = gr.Radio(["Solo", "Date/Couple", "Friends", "Business"], label="With who?", value="Date/Couple")
125
  in_purpose = gr.Dropdown(["Casual dinner", "Special occasion", "Quick bite"], label="Occasion", value="Casual dinner")
126
  in_noise = gr.Radio(["Quiet/Intimate", "Moderate/Social", "Lively/Music"], label="Environment", value="Moderate/Social")
 
127
  btn = gr.Button("Find My Table", variant="primary")
128
 
129
  with gr.Column():
130
  output_ui = gr.HTML("<h4>Recommendation will appear here...</h4>")
131
 
132
+ gr.Markdown("### 🚀 One-Click Examples")
 
133
  gr.Examples(
134
  examples=[
135
  ["Budget-friendly", "Vegetarian", "Friends", "Quick bite", "Moderate/Social"],
 
139
  inputs=[in_budget, in_diet, in_company, in_purpose, in_noise],
140
  outputs=output_ui,
141
  fn=run_ven_engine,
142
+ cache_examples=True,
143
  )
 
144
  btn.click(run_ven_engine, inputs=[in_budget, in_diet, in_company, in_purpose, in_noise], outputs=output_ui)
145
 
146
  if __name__ == "__main__":