MatanYehudaDataAnalyst commited on
Commit
4f6f03e
·
verified ·
1 Parent(s): 25cf791

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -33
app.py CHANGED
@@ -3,79 +3,76 @@ import pandas as pd
3
  import numpy as np
4
  import pickle
5
  import os
6
- from sentence_transformers import SentenceTransformer, util
7
- import torch
8
 
9
  # ==========================================
10
  # 1. SETUP & DATA LOADING
11
  # ==========================================
12
- # NOTE: Check your file names exactly!
13
  csv_path = "cleaned_dataset_10k.csv"
14
  pkl_path = "final_embeddings_10k.pkl"
15
 
 
16
  if not os.path.exists(csv_path) or not os.path.exists(pkl_path):
17
- # This error usually means the file names in the 'Files' tab are different
18
- raise FileNotFoundError(f"❌ FILES NOT FOUND. I see these files: {os.listdir('.')}")
19
 
20
- # Load Data & Normalize Columns
21
  df = pd.read_csv(csv_path)
 
 
22
  df.columns = [c.strip().lower().replace(' ', '_') for c in df.columns]
23
 
24
- # Helper to find columns even if names vary slightly
25
  def get_col(candidates, default):
26
  for c in candidates:
27
  if c in df.columns: return c
28
  return default
29
 
30
- col_name = get_col(['restaurant_name', 'name'], 'restaurant_name')
 
31
  col_rating = get_col(['rating', 'rating_score', 'stars'], 'rating')
32
  col_review = get_col(['review', 'review_content', 'review_content_clean'], 'review')
33
- col_persona = get_col(['reviewer_persona', 'persona'], 'reviewer_persona')
34
 
35
  # Load Embeddings
36
  with open(pkl_path, 'rb') as f:
37
  embedding_data = pickle.load(f)
38
- dataset_embeddings = embedding_data['embeddings']
 
 
 
 
39
 
40
  # Load Model
41
  model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
42
 
43
- # Calculate Persona Profiles (Average Vectors)
44
  persona_profiles = {}
45
  if col_persona in df.columns:
46
  for persona in df[col_persona].unique():
47
  if pd.isna(persona): continue
48
  indices = df[df[col_persona] == persona].index
49
- # valid_indices ensures we don't crash if indices mismatch
50
  valid_indices = [i for i in indices if i < len(dataset_embeddings)]
51
  if valid_indices:
52
  persona_vectors = dataset_embeddings[valid_indices]
53
- # Use torch/numpy to calculate mean
54
  persona_profiles[persona] = np.mean(persona_vectors, axis=0)
55
  else:
56
  persona_profiles['Default'] = np.mean(dataset_embeddings, axis=0)
57
 
58
  # ==========================================
59
- # 2. LOGIC ENGINE (Replaced Scikit-Learn with Util)
60
  # ==========================================
61
  def run_ven_engine(budget, dietary, company, purpose, noise):
62
  # 1. Create a search query
63
  user_context = f"Searching for a {budget} experience, {dietary} friendly. Group: {company}. Occasion: {purpose}. Atmosphere: {noise}."
64
 
65
  # 2. Encode query
66
- query_vec = model.encode(user_context, convert_to_tensor=True)
67
 
68
- # 3. Find closest Persona using Sentence-Transformers Utility (No Sklearn needed)
69
- best_score = -1
70
- closest_persona = list(persona_profiles.keys())[0]
71
-
72
- for persona, profile_vec in persona_profiles.items():
73
- # Convert profile to tensor for comparison
74
- profile_tensor = torch.tensor(profile_vec)
75
- score = util.cos_sim(query_vec, profile_tensor).item()
76
- if score > best_score:
77
- best_score = score
78
- closest_persona = persona
79
 
80
  # 4. Filter data
81
  if col_persona in df.columns:
@@ -88,18 +85,18 @@ def run_ven_engine(budget, dietary, company, purpose, noise):
88
  top_match = persona_df.sort_values(by=col_rating, ascending=False).iloc[0]
89
 
90
  # 6. Format Output
91
- review_text = str(top_match[col_review])[:180] + "..."
92
- match_pct = int(best_score * 100)
93
 
94
  return f"""
95
- <div style="background: white; border: 1px solid #e2e8f0; border-radius: 20px; padding: 24px; box-shadow: 0 10px 30px -10px rgba(0,0,0,0.1);">
96
  <div style="display:flex; justify-content:space-between;">
97
  <div>
98
- <div style="font-size: 24px; font-weight: 800; color: #1e293b;">{top_match[col_name]}</div>
99
- <div style="font-size: 14px; color: #64748b; font-weight: 600;">Top Match for {closest_persona}</div>
100
  </div>
101
  <div style="text-align:right;">
102
- <div style="font-size: 32px; font-weight: 900; color: #2563eb;">{top_match[col_rating]}</div>
103
  <div style="font-size:12px; font-weight:bold; color:#94a3b8;">RATING</div>
104
  </div>
105
  </div>
@@ -129,7 +126,8 @@ with gr.Blocks(css=ven_css, title="VEN Project") as demo:
129
  with gr.Column():
130
  output_ui = gr.HTML("<h4>Recommendation will appear here...</h4>")
131
 
132
- gr.Markdown("### 🚀 One-Click Examples")
 
133
  gr.Examples(
134
  examples=[
135
  ["Budget-friendly", "Vegetarian", "Friends", "Quick bite", "Moderate/Social"],
@@ -141,6 +139,7 @@ with gr.Blocks(css=ven_css, title="VEN Project") as demo:
141
  fn=run_ven_engine,
142
  cache_examples=True,
143
  )
 
144
  btn.click(run_ven_engine, inputs=[in_budget, in_diet, in_company, in_purpose, in_noise], outputs=output_ui)
145
 
146
  if __name__ == "__main__":
 
3
  import numpy as np
4
  import pickle
5
  import os
6
+ from sentence_transformers import SentenceTransformer
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
 
9
  # ==========================================
10
  # 1. SETUP & DATA LOADING
11
  # ==========================================
12
+ # We use the EXACT filenames you provided
13
  csv_path = "cleaned_dataset_10k.csv"
14
  pkl_path = "final_embeddings_10k.pkl"
15
 
16
+ # Check if files exist to prevent crashing
17
  if not os.path.exists(csv_path) or not os.path.exists(pkl_path):
18
+ raise FileNotFoundError(f"Error: Files not found. I see: {os.listdir('.')}")
 
19
 
20
+ # Load Data
21
  df = pd.read_csv(csv_path)
22
+
23
+ # Normalize column names (fixes 'Restaurant Name' vs 'restaurant_name' issues)
24
  df.columns = [c.strip().lower().replace(' ', '_') for c in df.columns]
25
 
26
+ # Helper to find the right column names
27
  def get_col(candidates, default):
28
  for c in candidates:
29
  if c in df.columns: return c
30
  return default
31
 
32
+ # Map your CSV columns to what the app needs
33
+ col_name = get_col(['restaurant_name', 'name', 'place'], 'restaurant_name')
34
  col_rating = get_col(['rating', 'rating_score', 'stars'], 'rating')
35
  col_review = get_col(['review', 'review_content', 'review_content_clean'], 'review')
36
+ col_persona = get_col(['reviewer_persona', 'persona', 'type'], 'reviewer_persona')
37
 
38
  # Load Embeddings
39
  with open(pkl_path, 'rb') as f:
40
  embedding_data = pickle.load(f)
41
+ # Handle if pickle is a dictionary or direct array
42
+ if isinstance(embedding_data, dict) and 'embeddings' in embedding_data:
43
+ dataset_embeddings = embedding_data['embeddings']
44
+ else:
45
+ dataset_embeddings = embedding_data
46
 
47
  # Load Model
48
  model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
49
 
50
+ # Calculate Persona Profiles
51
  persona_profiles = {}
52
  if col_persona in df.columns:
53
  for persona in df[col_persona].unique():
54
  if pd.isna(persona): continue
55
  indices = df[df[col_persona] == persona].index
 
56
  valid_indices = [i for i in indices if i < len(dataset_embeddings)]
57
  if valid_indices:
58
  persona_vectors = dataset_embeddings[valid_indices]
 
59
  persona_profiles[persona] = np.mean(persona_vectors, axis=0)
60
  else:
61
  persona_profiles['Default'] = np.mean(dataset_embeddings, axis=0)
62
 
63
  # ==========================================
64
+ # 2. LOGIC ENGINE
65
  # ==========================================
66
  def run_ven_engine(budget, dietary, company, purpose, noise):
67
  # 1. Create a search query
68
  user_context = f"Searching for a {budget} experience, {dietary} friendly. Group: {company}. Occasion: {purpose}. Atmosphere: {noise}."
69
 
70
  # 2. Encode query
71
+ query_vec = model.encode([user_context])
72
 
73
+ # 3. Find closest Persona
74
+ similarities = {p: cosine_similarity(query_vec, v.reshape(1, -1))[0][0] for p, v in persona_profiles.items()}
75
+ closest_persona = max(similarities, key=similarities.get)
 
 
 
 
 
 
 
 
76
 
77
  # 4. Filter data
78
  if col_persona in df.columns:
 
85
  top_match = persona_df.sort_values(by=col_rating, ascending=False).iloc[0]
86
 
87
  # 6. Format Output
88
+ match_pct = int(similarities[closest_persona] * 100)
89
+ review_text = str(top_match[col_review])[:160] + "..."
90
 
91
  return f"""
92
+ <div style="background: white; border: 1px solid #e2e8f0; border-radius: 20px; padding: 24px;">
93
  <div style="display:flex; justify-content:space-between;">
94
  <div>
95
+ <div style="font-size: 22px; font-weight: 800; color: #1e293b;">{top_match[col_name]}</div>
96
+ <div style="font-size: 14px; color: #64748b; font-weight: 600;">Match for: {closest_persona}</div>
97
  </div>
98
  <div style="text-align:right;">
99
+ <div style="font-size: 28px; font-weight: 900; color: #2563eb;">{top_match[col_rating]}</div>
100
  <div style="font-size:12px; font-weight:bold; color:#94a3b8;">RATING</div>
101
  </div>
102
  </div>
 
126
  with gr.Column():
127
  output_ui = gr.HTML("<h4>Recommendation will appear here...</h4>")
128
 
129
+ # --- THIS IS STEP 7: ONE-CLICK STARTERS ---
130
+ gr.Markdown("### 🚀 Quick Starters (One-Click)")
131
  gr.Examples(
132
  examples=[
133
  ["Budget-friendly", "Vegetarian", "Friends", "Quick bite", "Moderate/Social"],
 
139
  fn=run_ven_engine,
140
  cache_examples=True,
141
  )
142
+
143
  btn.click(run_ven_engine, inputs=[in_budget, in_diet, in_company, in_purpose, in_noise], outputs=output_ui)
144
 
145
  if __name__ == "__main__":