Spaces:
Sleeping
Sleeping
| # Retrieve.py (fixed version) | |
| import os | |
| import json | |
| import numpy as np | |
| from typing import List, Dict, Optional, Any | |
| import requests | |
| from langchain_community.vectorstores import FAISS | |
| import numpy as np | |
| #from db_paths import (PERSONAL_INFO_CHUNKS_PATH,CHAT_HISTORY_CHUNKS_PATH) | |
| import json | |
| from supabase_ie import load_user_info, load_history_for_display, download_faiss_from_supabase | |
| from config import SUPABASE_URL, SUPABASE_HEADERS | |
| #used in embed_query | |
def normalize(v):
    """Return *v* scaled to unit Euclidean length; a zero vector passes through unchanged."""
    vec = np.array(v)
    magnitude = np.linalg.norm(vec)
    if magnitude == 0:
        return vec
    return vec / magnitude
| #used in retrieve_all_chunks | |
def embed_query(query: str, model) -> np.ndarray:
    """
    Embed *query* with the E5-style "query: " prefix and return a unit-normalized
    float32 row vector of shape (1, dim), ready for a FAISS search.
    """
    prefixed = f"query: {query.strip()}"
    raw_embedding = model.embed_query(prefixed)
    unit = normalize(raw_embedding)
    return unit.astype("float32").reshape(1, -1)
| #used in retrieve_all_chunks | |
def retrieve_chunks_from_vdb(
    db_key: str,
    query_vector: np.ndarray,
    model,
    query: str,
    username: str,
    k: int = 10,
    use_metadata_boost: bool = False,
    filter_keywords: List[str] = None,
    topic: str = None,
    db6_override_store: Optional[FAISS] = None,
) -> List[Dict]:
    """
    Download a FAISS index from Supabase (or use the injected override store),
    search it with *query_vector*, and return the top-k results.

    Args:
        db_key: Logical db identifier (e.g. "db1" … "db6").
        query_vector: Pre-embedded query, reshaped here to (1, dim) float32.
        model: Embedding model handed to FAISS.load_local for deserialization.
        query: Raw query text, used only for metadata boosting.
        username: Selects the per-user FAISS bucket in Supabase.
        k: Number of hits to return.
        use_metadata_boost: When True, add compute_metadata_boost() to each score.
        filter_keywords: Optional keyword list forwarded to the boost computation.
        topic: Conversation topic; "news" + db6 enables the override store path.
        db6_override_store: Pre-loaded FAISS store that bypasses the download.

    Returns:
        Up to k result dicts (content/metadata/source_db/score), sorted by
        descending score; [] on load failure or dimension mismatch.
    """
    try:
        if db_key == "db6" and topic == "news" and db6_override_store is not None:
            print("⚡ Using db6_override_store (fresh from db7)")
            vdb = db6_override_store
        else:
            # FIX: the original branched on db_key in ["db1","db2","db3"] vs the
            # rest, but both branches were identical — collapsed into one path.
            vdb_local = download_faiss_from_supabase(db_key=db_key, username=username)
            vdb = FAISS.load_local(vdb_local, model, allow_dangerous_deserialization=True)
    except Exception as e:
        print(f"❌ Failed to load FAISS index {db_key} from Supabase or override: {e}")
        return []
    query_vector = np.array(query_vector, dtype="float32").reshape(1, -1)
    if query_vector.shape[1] != vdb.index.d:
        print(f"❌ Dimension mismatch: query {query_vector.shape[1]} vs index {vdb.index.d}")
        return []
    D, I = vdb.index.search(query_vector, k=k)
    results = []
    for idx, score in zip(I[0], D[0]):
        doc_id = vdb.index_to_docstore_id.get(idx)
        if doc_id is None:
            # FAISS can return -1 / stale ids for under-filled indexes — skip them.
            continue
        doc = vdb.docstore.search(doc_id)
        meta_boost = compute_metadata_boost(doc.metadata, query, filter_keywords) if use_metadata_boost else 0.0
        final_score = score + meta_boost
        results.append({
            "content": doc.page_content,
            "metadata": {**doc.metadata, "source": db_key},
            "source_db": db_key,
            "score": final_score
        })
    # Re-sort because the metadata boost can reorder FAISS's distance ranking.
    return sorted(results, key=lambda x: x["score"], reverse=True)[:k]
| #used in app_nn.py | |
def retrieve_all_chunks(
    query: str,
    model,
    user_id: str,  # UUID → for tables
    username: str,  # username → for FAISS buckets
    k: int = 10,
    filter_keywords: List[str] = None,
    topic: str = None,
    topic_to_dbs: Dict[str, List[str]] = None,
    db6_override_store: Optional[FAISS] = None
) -> List[Dict]:
    """
    Retrieve chunks across all Supabase sources (dbs + personal_info + chat_history).
    """
    vector = embed_query(query, model)
    # Topic routing: fall back to the full source list when no mapping applies.
    if topic_to_dbs and topic:
        allowed_dbs = topic_to_dbs.get(topic, topic_to_dbs.get("default", []))
    else:
        allowed_dbs = ["db1", "db2", "db3", "db4", "db5", "db6", "personal_info", "chat_history"]
    print(f"[DEBUG][RETRIEVE] Topic={topic}, allowed_dbs={allowed_dbs}")
    collected: List[Dict] = []
    for db_key in allowed_dbs:
        if db_key in ("personal_info", "chat_history"):
            # Table-backed sources are injected as a single JSON chunk, score 0.
            if db_key == "personal_info":
                payload = load_user_info(user_id=user_id)
            else:
                payload = load_history_for_display(user_id=user_id)
            collected.append({
                "content": json.dumps(payload, indent=2),
                "metadata": {"source": db_key},
                "source_db": db_key,
                "score": 0.0
            })
        else:
            # Only db1/db2 carry metadata rich enough to be worth boosting.
            collected.extend(retrieve_chunks_from_vdb(
                db_key=db_key,
                query_vector=vector,
                model=model,
                query=query,
                username=username,
                k=k,
                use_metadata_boost=db_key in ["db1", "db2"],
                filter_keywords=filter_keywords,
                topic=topic,
                db6_override_store=db6_override_store,
            ))
    return collected
| #used in db3and6_utils.py | |
def retrieve_from_db(
    db_key: str,
    query: str,
    model,
    username: str,
    k: int = 5,
    db6_override_store: Optional[FAISS] = None,
) -> List[Dict]:
    """
    Retrieve top-k chunks from a single Supabase FAISS db (e.g. db6).
    """
    # Embed once, then delegate the actual FAISS search.
    vec = embed_query(query, model)
    hits = retrieve_chunks_from_vdb(
        db_key=db_key,
        query_vector=vec,
        model=model,
        query=query,
        username=username,
        k=k,
        db6_override_store=db6_override_store,
    )
    return hits
| # used in retrieve_chunks_from_vdb | |
def compute_metadata_boost(metadata: Dict, query: str, filter_keywords: List[str] = None) -> float:
    """
    Compute an additional score boost based on how well the query matches document metadata.

    Priority keys add 0.2 per substring hit in the query (plus 0.15 per
    filter-keyword hit); secondary keys add 0.1 (plus 0.05).
    """
    q_lower = query.lower()
    weight_table = [
        (
            ["topic", "theme", "tone", "style",
             "dialogue_name", "source_title", "characters"],
            0.2, 0.15,
        ),
        (
            ["period", "period_covered", "location"],
            0.1, 0.05,
        ),
    ]
    total = 0.0
    for keys, direct_weight, filter_weight in weight_table:
        for key in keys:
            if key not in metadata:
                continue
            raw = metadata[key]
            candidates = raw if isinstance(raw, list) else [raw]
            for candidate in candidates:
                text = str(candidate).lower()
                if text in q_lower:
                    total += direct_weight
                if filter_keywords and text in filter_keywords:
                    total += filter_weight
    return total
| # PERSONAL_BUCKET = {"personal", "advice","philosophical"} | |
| # def get_story_from_supabase( | |
| # user_id: str, | |
| # username: str, | |
| # conversation_type: str, | |
| # topic_for_story: str | None, | |
| # ) -> dict | None: | |
| # """ | |
| # If conversation_type is personal/advice and topic_for_story is set, | |
| # call a Supabase RPC that: | |
| # - selects a story with your rules (unseen first; else seen<=1 and >90d ago), | |
| # - logs usage, | |
| # - returns a compact JSON payload for the prompt builder. | |
| # Returns None if no suitable story. | |
| # """ | |
| # print(f"[DEBUG] SOCRATIC_STORY = in the function") | |
| # if conversation_type not in PERSONAL_BUCKET: | |
| # print(f"[DEBUG] SOCRATIC_STORY = NOT IN PERSONAL_BUCKET") | |
| # return None | |
| # if not topic_for_story or topic_for_story == "none": | |
| # print(f"[DEBUG] SOCRATIC_STORY = topic_for_story = none") | |
| # return None | |
| # fn = "pick_and_log_story_with_history_rpc" | |
| # payload = { | |
| # "p_user_id": user_id, | |
| # "p_topic": topic_for_story, | |
| # } | |
| # url = f"{SUPABASE_URL}/rest/v1/rpc/{fn}" | |
| # r = requests.post(url, headers=SUPABASE_HEADERS, json=payload, timeout=20) | |
| # if r.status_code == 404 or not r.text or r.text == "null": | |
| # return None | |
| # r.raise_for_status() | |
| # story = r.json() | |
| # return story | |
| # #return r.json() |