Spaces:
Runtime error
Runtime error
| import chromadb | |
| import sqlite3 | |
| import hashlib | |
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer | |
| #--- Initialize ChromaDB and SentenceTransformer --- | |
| SCHEMA_DESCRIPTIONS = { | |
| "restaurants": """Table restaurants contains restaurant details: | |
| - id: unique identifier | |
| - name: restaurant name | |
| - cuisine: type of cuisine | |
| - location: area or neighborhood | |
| - seating_capacity: total seats | |
| - rating: average rating | |
| - address: full address | |
| - contact: phone or email | |
| - price_range: price category | |
| - special_features: amenities or highlights""", | |
| "tables": """Table tables contains table details: | |
| - id: unique identifier | |
| - restaurant_id: links to restaurants.id | |
| - capacity: number of seats (default 4)""", | |
| "slots": """Table slots contains reservation time slots: | |
| - id: unique identifier | |
| - table_id: links to tables.id | |
| - date: reservation date | |
| - hour: reservation hour | |
| - is_reserved: 0=available, 1=booked""" | |
| } | |
| class SchemaVectorDB: | |
| def __init__(self): | |
| self.client = chromadb.Client() | |
| self.collection = self.client.get_or_create_collection("schema") | |
| self.model = SentenceTransformer('all-MiniLM-L6-v2') | |
| for idx, (name, desc) in enumerate(SCHEMA_DESCRIPTIONS.items()): | |
| self.collection.add(ids=str(idx), documents=desc, metadatas={"name": name}) | |
| def get_relevant_schema(self, query, k=2): | |
| query_embedding = self.model.encode(query).tolist() | |
| results = self.collection.query(query_embeddings=[query_embedding], n_results=k) | |
| # results['metadatas'] is a list of lists: [[{...}, {...}], ...] | |
| # We only have one query, so grab the first list | |
| metadatas = results['metadatas'][0] if results['metadatas'] else [] | |
| return [m['name'] for m in metadatas if m and 'name' in m] | |
| class FullVectorDB: | |
| def __init__(self): | |
| self.client = chromadb.PersistentClient(path="db/chroma") | |
| self.model = SentenceTransformer('all-MiniLM-L6-v2') | |
| # Get existing collections or create if not exist | |
| self.restaurants_col = self.client.get_or_create_collection("restaurants") | |
| self.tables_col = self.client.get_or_create_collection("tables") | |
| self.slots_col = self.client.get_or_create_collection("slots") | |
| # Initialize only if collections are empty | |
| if len(self.restaurants_col.get()['ids']) == 0: | |
| self._initialize_collections() | |
| def _row_to_text(self, row): | |
| return ' '.join(str(v) for v in row.values if pd.notnull(v)) | |
| def _row_hash(self, row): | |
| return hashlib.sha256(str(row.values).encode()).hexdigest() | |
| def _initialize_collections(self): | |
| conn = sqlite3.connect("db/restaurant_reservation.db") | |
| # Create external changelog table | |
| conn.execute(""" | |
| CREATE TABLE IF NOT EXISTS chroma_changelog ( | |
| id INTEGER PRIMARY KEY, | |
| table_name TEXT, | |
| record_id INTEGER, | |
| content_hash TEXT, | |
| UNIQUE(table_name, record_id) | |
| ) | |
| """) | |
| conn.commit() | |
| # Process tables | |
| self._process_table(conn, "restaurants", self.restaurants_col) | |
| self._process_table(conn, "tables", self.tables_col) | |
| self._process_table(conn, "slots", self.slots_col) | |
| conn.close() | |
| def _process_table(self, conn, table_name, collection): | |
| # Get existing records from Chroma | |
| existing_ids = set(collection.get()['ids']) | |
| # Get all records from SQLite with hash | |
| df = pd.read_sql(f"SELECT * FROM {table_name}", conn) | |
| # Process each row | |
| for _, row in df.iterrows(): | |
| chroma_id = f"{table_name}_{row['id']}" | |
| current_hash = self._row_hash(row) | |
| # Check if exists in changelog | |
| changelog = pd.read_sql(f""" | |
| SELECT content_hash | |
| FROM chroma_changelog | |
| WHERE table_name = ? AND record_id = ? | |
| """, conn, params=(table_name, row['id'])) | |
| # Skip if hash matches | |
| if not changelog.empty and changelog.iloc[0]['content_hash'] == current_hash: | |
| continue | |
| # Generate embedding | |
| embedding = self.model.encode(self._row_to_text(row)) | |
| # Update Chroma | |
| collection.upsert( | |
| ids=[chroma_id], | |
| embeddings=[embedding.tolist()], | |
| metadatas=[row.to_dict()] | |
| ) | |
| # Update changelog | |
| conn.execute(""" | |
| INSERT OR REPLACE INTO chroma_changelog | |
| (table_name, record_id, content_hash) | |
| VALUES (?, ?, ?) | |
| """, (table_name, row['id'], current_hash)) | |
| conn.commit() | |
| def semantic_search(self, query, collection_name, k=5): | |
| query_embedding = self.model.encode(query).tolist() | |
| collection = getattr(self, f"{collection_name}_col") | |
| results = collection.query( | |
| query_embeddings=[query_embedding], | |
| n_results=k, | |
| include=["metadatas"] | |
| ) | |
| return results['metadatas'][0] | |