# Socrates_docker / Retrieve.py
# alesamodio's picture
# replace streamlit
# c7a3272
# Retrieve.py (fixed version)
import os
import json
import numpy as np
from typing import List, Dict, Optional, Any
import requests
from langchain_community.vectorstores import FAISS
import numpy as np
#from db_paths import (PERSONAL_INFO_CHUNKS_PATH,CHAT_HISTORY_CHUNKS_PATH)
import json
from supabase_ie import load_user_info, load_history_for_display, download_faiss_from_supabase
from config import SUPABASE_URL, SUPABASE_HEADERS
#used in embed_query
def normalize(v):
    """Return *v* scaled to unit L2 norm; a zero vector comes back unchanged."""
    arr = np.array(v)
    magnitude = np.linalg.norm(arr)
    if magnitude == 0:
        return arr
    return arr / magnitude
#used in retrieve_all_chunks
def embed_query(query: str, model) -> np.ndarray:
    """
    Embed *query* and return a unit-norm float32 row vector of shape (1, dim).

    The text is prefixed with "query: " before embedding — presumably a
    convention required by the embedding model in use (TODO confirm).
    """
    prefixed = f"query: {query.strip()}"
    raw_vector = model.embed_query(prefixed)
    unit_vector = normalize(raw_vector)
    return unit_vector.astype("float32").reshape(1, -1)
#used in retrieve_all_chunks
def retrieve_chunks_from_vdb(
    db_key: str,
    query_vector: np.ndarray,
    model,
    query: str,
    username: str,
    k: int = 10,
    use_metadata_boost: bool = False,
    filter_keywords: Optional[List[str]] = None,
    topic: Optional[str] = None,
    db6_override_store: Optional[FAISS] = None,
) -> List[Dict]:
    """
    Search one FAISS-backed db and return the top-k results.

    The index is downloaded from Supabase and loaded locally, unless this is
    the ``db6``/"news" case and a pre-built store is supplied, in which case
    that override store is searched directly.

    Args:
        db_key: Which db to search (e.g. "db1" … "db6").
        query_vector: Pre-embedded query; coerced to a (1, dim) float32 array.
        model: Embedding model, needed by FAISS.load_local to deserialize.
        query: Raw query text, used only for metadata boosting.
        username: Selects the per-user FAISS bucket in Supabase.
        k: Number of hits to return.
        use_metadata_boost: If True, add compute_metadata_boost() to each score.
        filter_keywords: Optional keyword list forwarded to the boost function.
        topic: Conversation topic; "news" enables the db6 override path.
        db6_override_store: Optional in-memory FAISS store for the db6 case.

    Returns:
        Up to k dicts with "content", "metadata", "source_db" and "score",
        sorted by score descending. Empty list on any load/search failure.
    """
    try:
        if db_key == "db6" and topic == "news" and db6_override_store is not None:
            print("⚡ Using db6_override_store (fresh from db7)")
            vdb = db6_override_store
        else:
            # NOTE: the original code branched on db_key in ["db1","db2","db3"]
            # here, but both branches were byte-identical — every FAISS-backed
            # db is fetched the same way, so the conditional was removed.
            vdb_local = download_faiss_from_supabase(db_key=db_key, username=username)
            vdb = FAISS.load_local(vdb_local, model, allow_dangerous_deserialization=True)
    except Exception as e:
        # Best-effort: a missing/broken index for one db must not kill the
        # whole multi-db retrieval, so log and return no hits.
        print(f"❌ Failed to load FAISS index {db_key} from Supabase or override: {e}")
        return []

    query_vector = np.array(query_vector, dtype="float32").reshape(1, -1)
    if query_vector.shape[1] != vdb.index.d:
        print(f"❌ Dimension mismatch: query {query_vector.shape[1]} vs index {vdb.index.d}")
        return []

    D, I = vdb.index.search(query_vector, k=k)
    results = []
    for idx, score in zip(I[0], D[0]):
        # FAISS pads with idx == -1 when the index holds fewer than k vectors;
        # -1 is absent from the docstore mapping, so the .get() guard skips it.
        doc_id = vdb.index_to_docstore_id.get(idx)
        if doc_id is None:
            continue
        doc = vdb.docstore.search(doc_id)
        meta_boost = compute_metadata_boost(doc.metadata, query, filter_keywords) if use_metadata_boost else 0.0
        final_score = score + meta_boost
        results.append({
            "content": doc.page_content,
            "metadata": {**doc.metadata, "source": db_key},
            "source_db": db_key,
            "score": final_score
        })
    # NOTE(review): sorting descending assumes higher raw FAISS scores are
    # better (true for inner-product indexes over normalized vectors; for L2
    # distance lower is better) — confirm the index metric used at build time.
    return sorted(results, key=lambda x: x["score"], reverse=True)[:k]
#used in app_nn.py
def retrieve_all_chunks(
    query: str,
    model,
    user_id: str,  # UUID → for tables
    username: str,  # username → for FAISS buckets
    k: int = 10,
    filter_keywords: List[str] = None,
    topic: str = None,
    topic_to_dbs: Dict[str, List[str]] = None,
    db6_override_store: Optional[FAISS] = None  # NEW
) -> List[Dict]:
    """
    Gather candidate chunks from every allowed Supabase source.

    FAISS dbs are searched with the embedded query; the non-vector sources
    ("personal_info", "chat_history") are loaded whole, JSON-encoded, and
    appended with a neutral score of 0.0.
    """
    embedded = embed_query(query, model)

    # Restrict the search to the dbs mapped to this topic, if a mapping exists;
    # otherwise fall back to every known source.
    if topic_to_dbs and topic:
        allowed_dbs = topic_to_dbs.get(topic, topic_to_dbs.get("default", []))
    else:
        allowed_dbs = ["db1", "db2", "db3", "db4", "db5", "db6", "personal_info", "chat_history"]
    print(f"[DEBUG][RETRIEVE] Topic={topic}, allowed_dbs={allowed_dbs}")

    collected: List[Dict] = []
    for source in allowed_dbs:
        if source == "personal_info":
            # Whole user profile as one pseudo-chunk (no vector search).
            collected.append({
                "content": json.dumps(load_user_info(user_id=user_id), indent=2),
                "metadata": {"source": "personal_info"},
                "source_db": "personal_info",
                "score": 0.0
            })
        elif source == "chat_history":
            # Whole display history as one pseudo-chunk (no vector search).
            collected.append({
                "content": json.dumps(load_history_for_display(user_id=user_id), indent=2),
                "metadata": {"source": "chat_history"},
                "source_db": "chat_history",
                "score": 0.0
            })
        else:
            collected += retrieve_chunks_from_vdb(
                db_key=source,
                query_vector=embedded,
                model=model,
                query=query,
                username=username,
                k=k,
                # Metadata boosting is only enabled for db1/db2.
                use_metadata_boost=source in ("db1", "db2"),
                filter_keywords=filter_keywords,
                topic=topic,
                db6_override_store=db6_override_store,  # NEW: passed through
            )
    return collected
#used in db3and6_utils.py
def retrieve_from_db(
    db_key: str,
    query: str,
    model,
    username: str,
    k: int = 5,
    db6_override_store: Optional[FAISS] = None,  # NEW (optional)
) -> List[Dict]:
    """
    Convenience wrapper: embed *query* and search a single Supabase FAISS db
    (e.g. db6), returning its top-k chunks.
    """
    return retrieve_chunks_from_vdb(
        db_key=db_key,
        query_vector=embed_query(query, model),
        model=model,
        query=query,
        username=username,
        k=k,
        db6_override_store=db6_override_store,  # pass through
    )
# used in retrieve_chunks_from_vdb
def compute_metadata_boost(metadata: Dict, query: str, filter_keywords: List[str] = None) -> float:
    """
    Score bonus for documents whose metadata overlaps the query text.

    High-signal keys (topic, theme, …) earn +0.2 per value found as a
    substring of the lowercased query, secondary keys (period, location, …)
    +0.1; an exact match against *filter_keywords* adds a further
    +0.15 / +0.05 respectively.
    """
    query_lower = query.lower()
    weighted_groups = (
        (["topic", "theme", "tone", "style",
          "dialogue_name", "source_title", "characters"], 0.2, 0.15),
        (["period", "period_covered", "location"], 0.1, 0.05),
    )

    total = 0.0
    for keys, direct_weight, filter_weight in weighted_groups:
        for key in keys:
            if key not in metadata:
                continue
            raw = metadata[key]
            candidates = raw if isinstance(raw, list) else [raw]
            for candidate in candidates:
                text = str(candidate).lower()
                if text in query_lower:
                    total += direct_weight
                if filter_keywords and text in filter_keywords:
                    total += filter_weight
    return total
# PERSONAL_BUCKET = {"personal", "advice","philosophical"}
# def get_story_from_supabase(
# user_id: str,
# username: str,
# conversation_type: str,
# topic_for_story: str | None,
# ) -> dict | None:
# """
# If conversation_type is personal/advice and topic_for_story is set,
# call a Supabase RPC that:
# - selects a story with your rules (unseen first; else seen<=1 and >90d ago),
# - logs usage,
# - returns a compact JSON payload for the prompt builder.
# Returns None if no suitable story.
# """
# print(f"[DEBUG] SOCRATIC_STORY = in the function")
# if conversation_type not in PERSONAL_BUCKET:
# print(f"[DEBUG] SOCRATIC_STORY = NOT IN PERSONAL_BUCKET")
# return None
# if not topic_for_story or topic_for_story == "none":
# print(f"[DEBUG] SOCRATIC_STORY = topic_for_story = none")
# return None
# fn = "pick_and_log_story_with_history_rpc"
# payload = {
# "p_user_id": user_id,
# "p_topic": topic_for_story,
# }
# url = f"{SUPABASE_URL}/rest/v1/rpc/{fn}"
# r = requests.post(url, headers=SUPABASE_HEADERS, json=payload, timeout=20)
# if r.status_code == 404 or not r.text or r.text == "null":
# return None
# r.raise_for_status()
# story = r.json()
# return story
# #return r.json()