tueniuu committed
Commit 8c15eb9 · verified · 1 Parent(s): e53b93b

Upload 4 files

data_processing/build_index.py ADDED
@@ -0,0 +1,45 @@
+import faiss
+import numpy as np
+
+def build_and_save_index(embedding_file, index_file):
+    """
+    Loads embeddings, NORMALIZES them, builds a FAISS IndexFlatIP index,
+    and saves the index to disk.
+    """
+    try:
+        # Load the embeddings from the .npy file
+        print(f"Loading embeddings from '{embedding_file}'...")
+        embeddings = np.load(embedding_file).astype('float32')
+        print(f"Embeddings loaded. Shape: {embeddings.shape}")
+
+        # --- FIX 1: NORMALIZE THE EMBEDDINGS ---
+        # This crucial step scales all vectors to a unit length of 1.
+        print("Normalizing embeddings to unit length...")
+        faiss.normalize_L2(embeddings)
+
+        embedding_dimension = embeddings.shape[1]
+
+        # --- FIX 2: USE IndexFlatIP FOR COSINE SIMILARITY ---
+        # IndexFlatIP (Inner Product) is the correct index for comparing normalized text vectors.
+        print(f"Building FAISS IndexFlatIP with dimension {embedding_dimension}...")
+        index = faiss.IndexFlatIP(embedding_dimension)
+
+        # Add the normalized embeddings to the index
+        index.add(embeddings)
+        print(f"Successfully added {index.ntotal} vectors to the index.")
+
+        # Save the Index
+        print(f"Saving index to '{index_file}'...")
+        faiss.write_index(index, index_file)
+        print("Index saved successfully!")
+
+    except FileNotFoundError:
+        print(f"ERROR: The file '{embedding_file}' was not found.")
+    except Exception as e:
+        print(f"AN UNEXPECTED ERROR OCCURRED: {e}")
+
+if __name__ == '__main__':
+    embedding_filename = 'location_embeddings.npy'
+    index_filename = 'location_index.faiss'
+
+    build_and_save_index(embedding_filename, index_filename)
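
For context, a minimal query sketch (not part of the commit) showing how the saved index can be used. It assumes the location_ids.npy mapping written by generate_embeddings.py below and the same all-MiniLM-L6-v2 encoder; the query vector must be L2-normalized too, otherwise the inner-product scores are not cosine similarities.

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

index = faiss.read_index('location_index.faiss')
location_ids = np.load('location_ids.npy')  # row i of the index -> location_ids[i]

model = SentenceTransformer('all-MiniLM-L6-v2')  # same model used for the corpus
query = model.encode(["quiet seaside cafe with a view"]).astype('float32')
faiss.normalize_L2(query)  # normalize the query so inner product == cosine

scores, rows = index.search(query, 5)  # top-5 nearest by cosine similarity
print(list(zip(location_ids[rows[0]], scores[0])))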
data_processing/excel_to_db.py ADDED
@@ -0,0 +1,101 @@
+import pandas as pd
+import psycopg2
+import json
+import numpy as np
+
+def format_operating_hours(time_str):
+    """
+    Transforms a simple time string (e.g., '11:00-22:00') into a
+    structured JSON object for all days of the week.
+    """
+    if not isinstance(time_str, str) or '-' not in time_str:
+        return None
+    schedule = {
+        "monday": time_str, "tuesday": time_str, "wednesday": time_str,
+        "thursday": time_str, "friday": time_str, "saturday": time_str,
+        "sunday": time_str,
+    }
+    return json.dumps(schedule)
+
+def parse_period_dates(period_str):
+    """
+    Parses a date range string like '2025.01.04 - 2025.12.27'
+    and returns a tuple of (start_date, end_date).
+    """
+    if not isinstance(period_str, str) or '-' not in period_str:
+        return None, None
+    try:
+        parts = period_str.split('-')
+        start_date = parts[0].strip().replace('.', '-')
+        end_date = parts[1].strip().replace('.', '-')
+        return start_date, end_date
+    except Exception:
+        return None, None
+
+def load_excel_to_postgres(excel_path, db_params):
+    """
+    Connects to PostgreSQL, reads an Excel file, and inserts the data.
+    """
+    conn = None
+    cur = None
+    try:
+        print(f"Reading data from '{excel_path}'...")
+        df = pd.read_excel(excel_path)
+        df = df.replace({np.nan: None})
+        print("Data read successfully. Preparing for database insertion...")
+
+        conn = psycopg2.connect(**db_params)
+        cur = conn.cursor()
+        print("Successfully connected to the PostgreSQL database.")
+
+        for index, row in df.iterrows():
+            operating_hours_json = format_operating_hours(row.get('time'))
+            start_date, end_date = parse_period_dates(row.get('period'))
+
+            sql = """
+                INSERT INTO locations (
+                    name, address, naver_url, region, primary_category, tags,
+                    price_level, indoor_outdoor, operating_hours,
+                    period_start_date, period_end_date, website, meal_type, geom
+                ) VALUES (
+                    %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
+                    ST_SetSRID(ST_MakePoint(%s, %s), 4326)
+                ) ON CONFLICT (name) DO NOTHING;
+            """
+
+            data_tuple = (
+                row.get('name'), row.get('address'), row.get('naver_url'), row.get('region'),
+                row.get('primary_category'), row.get('tags'), row.get('price_level'),
+                row.get('indoor_outdoor'), operating_hours_json, start_date, end_date,
+                row.get('website'), row.get('type'),
+                row.get('longitude'), row.get('latitude')
+            )
+
+            cur.execute(sql, data_tuple)
+
+        conn.commit()
+        print(f"\nSuccessfully processed {len(df)} rows for the 'locations' table (duplicate names skipped by ON CONFLICT).")
+
+    except FileNotFoundError:
+        print(f"ERROR: The file '{excel_path}' was not found.")
+    except psycopg2.Error as e:
+        print(f"DATABASE ERROR: {e}")
+    except Exception as e:
+        print(f"AN UNEXPECTED ERROR OCCURRED: {e}")
+    finally:
+        if cur is not None:
+            cur.close()
+        if conn is not None:
+            conn.close()
+            print("Database connection closed.")
+
+if __name__ == '__main__':
+    db_connection_params = {
+        "host": "localhost",
+        "database": "recommendation_locations",
+        "user": "postgres",
+        "password": "nafikova03",
+        "port": "5432"
+    }
+    excel_file_path = "loc_data.xlsx"
+    load_excel_to_postgres(excel_file_path, db_connection_params)
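
The INSERT above assumes a 'locations' table that already exists, with PostGIS enabled and a UNIQUE constraint on name (without one, ON CONFLICT (name) raises an error). A hypothetical setup sketch, not part of the commit: the column names and SRID come from the INSERT, while the column types are assumptions.

import psycopg2

DDL = """
CREATE EXTENSION IF NOT EXISTS postgis;
CREATE TABLE IF NOT EXISTS locations (
    id                SERIAL PRIMARY KEY,
    name              TEXT UNIQUE NOT NULL,
    address           TEXT,
    naver_url         TEXT,
    region            TEXT,
    primary_category  TEXT,
    tags              TEXT,
    price_level       TEXT,
    indoor_outdoor    TEXT,
    operating_hours   JSONB,
    period_start_date DATE,
    period_end_date   DATE,
    website           TEXT,
    meal_type         TEXT,
    geom              GEOMETRY(Point, 4326)
);
"""

conn = psycopg2.connect(**db_connection_params)  # params as defined in the script above
with conn, conn.cursor() as cur:                 # commits the transaction on clean exit
    cur.execute(DDL)
conn.close()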
data_processing/generate_embeddings.py ADDED
@@ -0,0 +1,150 @@
+import pandas as pd
+import psycopg2
+import numpy as np
+from sentence_transformers import SentenceTransformer
+import sys
+import os
+import time
+import json
+import random
+import google.generativeai as genai
+from google.api_core import exceptions
+
+# --- (The 'generate_gemini_description_json' function remains the same as the last version with exponential backoff) ---
+def generate_gemini_description_json(row, model):
+    # This function with the retry logic is still needed and is unchanged.
+    location_data = {
+        'name': row.get('name', ''),
+        'category': row.get('primary_category', ''),
+        'tags': row.get('tags', '')
+    }
+    location_data = {k: v for k, v in location_data.items() if v}
+    data_string = ", ".join([f"{key}: {value}" for key, value in location_data.items()])
+    prompt = f"""
+    You are a data enrichment specialist. Your task is to generate a JSON object containing a creative, appealing description for a location in Busan, South Korea.
+    **Instructions:**
+    - Use the provided data to create an engaging, fluent paragraph between 40 and 70 words.
+    - Your output MUST be a valid JSON object with a single key: "description".
+    **Input Data:**
+    {data_string}
+    **Output JSON:**
+    """
+    max_retries = 5
+    base_delay = 60
+    for attempt in range(max_retries):
+        try:
+            generation_config = genai.types.GenerationConfig(response_mime_type="application/json")
+            response = model.generate_content(prompt, generation_config=generation_config)
+            response_json = json.loads(response.text)
+            description = response_json['description']
+            return description.strip()
+        except exceptions.ResourceExhausted:
+            print(f"  - Rate limit hit for ID {row.get('id', 'N/A')}. Waiting... (Attempt {attempt + 1}/{max_retries})")
+            delay = base_delay * (2 ** attempt) + random.uniform(0, 5)
+            print(f"    ...backing off for {delay:.2f} seconds.")
+            time.sleep(delay)
+        except (json.JSONDecodeError, KeyError) as e:
+            print(f"  - JSON Parsing Error for ID {row.get('id', 'N/A')}: {e}. Retrying...")
+            time.sleep(5)
+        except Exception as e:
+            print(f"  - Unexpected API Error for ID {row.get('id', 'N/A')}: {e}. Retrying with backoff...")
+            delay = base_delay * (2 ** attempt) + random.uniform(0, 5)
+            time.sleep(delay)
+    print(f"  - All retries failed for ID {row.get('id', 'N/A')}. Using fallback.")
+    return f"{row.get('name', '')}. Tags include: {row.get('tags', '')}"
+
+
+# --- THE MAIN ORCHESTRATION SCRIPT ---
+def generate_and_save_embeddings(db_params, api_key):
+    # Configure the Gemini API
+    genai.configure(api_key=api_key)
+    gemini_model = genai.GenerativeModel('gemini-2.5-pro')
+
+    # --- NEW: RESUME LOGIC ---
+    PROGRESS_FILE = 'descriptions_progress.csv'
+
+    # 1. Fetch ALL locations from the database first
+    conn = psycopg2.connect(**db_params)
+    sql_query = "SELECT id, name, tags, primary_category, meal_type FROM locations ORDER BY id;"
+    df_all_locations = pd.read_sql_query(sql_query, conn)
+    conn.close()
+    print(f"Total locations to process: {len(df_all_locations)}")
+
+    # 2. Check for an existing progress file
+    processed_ids = set()
+    if os.path.exists(PROGRESS_FILE):
+        print(f"Found existing progress file: '{PROGRESS_FILE}'. Resuming...")
+        df_progress = pd.read_csv(PROGRESS_FILE)
+        processed_ids = set(df_progress['id'])
+        print(f"{len(processed_ids)} locations already have descriptions.")
+    else:
+        print("No progress file found. Starting a new session.")
+        df_progress = pd.DataFrame(columns=['id', 'description'])
+
+    # 3. Filter out the locations that are already processed
+    df_to_process = df_all_locations[~df_all_locations['id'].isin(processed_ids)]
+    print(f"{len(df_to_process)} locations remaining to be processed.")
+
+    if df_to_process.empty:
+        print("All locations have already been processed.")
+
+    # --- NEW: GRACEFUL SHUTDOWN AND INCREMENTAL SAVING ---
+    newly_processed_data = []
+    try:
+        if not df_to_process.empty:
+            print("\nStarting Gemini description generation...")
+            for index, row in df_to_process.iterrows():
+                print(f"Processing ID: {row['id']}...")
+                description = generate_gemini_description_json(row, gemini_model)
+                newly_processed_data.append({'id': row['id'], 'description': description})
+
+    except KeyboardInterrupt:
+        print("\n--- KeyboardInterrupt detected! Saving progress before exiting. ---")
+    finally:
+        if newly_processed_data:
+            df_new_progress = pd.DataFrame(newly_processed_data)
+            df_combined = pd.concat([df_progress, df_new_progress], ignore_index=True)
+            df_combined.to_csv(PROGRESS_FILE, index=False)
+            print(f"\nSuccessfully saved {len(newly_processed_data)} new descriptions to '{PROGRESS_FILE}'.")
+            df_progress = df_combined
+        else:
+            print("\nNo new descriptions were generated in this session.")
+
+    # --- FINAL EMBEDDING GENERATION (runs after the loop is complete) ---
+    print("\n--- All descriptions are now generated. Proceeding to create embeddings. ---")
+
+    # Merge the final descriptions with the original data to ensure correct order
+    df_final = df_all_locations.merge(df_progress, on='id', how='left')
+
+    # Check for any locations that might have been missed
+    if df_final['description'].isnull().any():
+        print("WARNING: Some locations are missing descriptions. Using fallback.")
+        df_final['description'] = df_final['description'].fillna("No description available.")
+
+    sbert_model = SentenceTransformer('all-MiniLM-L6-v2')
+    sentences = df_final['description'].tolist()
+    print(f"Encoding {len(sentences)} final descriptions into vectors...")
+
+    location_embeddings = sbert_model.encode(sentences, show_progress_bar=True)
+    location_ids = df_final['id'].to_numpy()
+
+    np.save('location_embeddings.npy', location_embeddings)
+    np.save('location_ids.npy', location_ids)
+
+    print("\nEmbeddings from Gemini descriptions generated successfully!")
+    print(f"Embeddings matrix shape: {location_embeddings.shape}")
+
+if __name__ == '__main__':
+    GEMINI_API_KEY = "YOUR_API_KEY_HERE"
+
+    if GEMINI_API_KEY == "YOUR_API_KEY_HERE":
+        print("ERROR: Please replace 'YOUR_API_KEY_HERE' with your actual Gemini API key.")
+    else:
+        db_connection_params = {
+            "host": "localhost",
+            "database": "recommendation_locations",
+            "user": "postgres",
+            "password": "nafikova03",
+            "port": "5432"
+        }
+        generate_and_save_embeddings(db_connection_params, GEMINI_API_KEY)
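
A small alignment check (a sketch, not part of the commit): build_index.py adds the embedding rows to the FAISS index in order, and search returns row positions rather than database ids, so location_embeddings.npy and location_ids.npy must stay row-aligned.

import numpy as np

emb = np.load('location_embeddings.npy')
ids = np.load('location_ids.npy')
assert emb.shape[0] == len(ids), "row i of emb must describe location ids[i]"
print(f"{emb.shape[0]} vectors of dimension {emb.shape[1]}; first id: {ids[0]}")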
data_processing/loc_data.xlsx ADDED
Binary file (74.1 kB)
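
The spreadsheet is consumed by excel_to_db.py above. A hypothetical column check, with the expected column names taken from the row.get(...) calls in that script:

import pandas as pd

EXPECTED = ['name', 'address', 'naver_url', 'region', 'primary_category',
            'tags', 'price_level', 'indoor_outdoor', 'time', 'period',
            'website', 'type', 'longitude', 'latitude']
df = pd.read_excel('loc_data.xlsx')
missing = [c for c in EXPECTED if c not in df.columns]
print("missing columns:", missing or "none")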