SIMPLE_AI / db.py
tiahchia's picture
Upload 6 files
015dbc8 verified
import json
import os
from typing import Any, Dict, List, Optional
from supabase import Client, create_client
_SUPABASE_CLIENT: Optional[Client] = None


def _get_client() -> Client:
    """Return the shared Supabase client, building it on first use.

    Credentials are read from the ``SUPABASE_URL`` and
    ``SUPABASE_SERVICE_ROLE_KEY`` environment variables, and only while no
    client has been cached yet; later calls reuse the cached instance.

    Raises:
        RuntimeError: If either environment variable is unset or empty.
    """
    global _SUPABASE_CLIENT
    if _SUPABASE_CLIENT is not None:
        return _SUPABASE_CLIENT
    supabase_url = os.getenv("SUPABASE_URL")
    service_key = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
    if not (supabase_url and service_key):
        raise RuntimeError(
            "SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY must be set for database access."
        )
    _SUPABASE_CLIENT = create_client(supabase_url, service_key)
    return _SUPABASE_CLIENT
def _execute(query):
response = query.execute()
if response.error:
raise RuntimeError(f"Supabase error: {response.error.message}")
return response
def _ensure_user_profile(user_id: str) -> None:
    """Upsert a bare ``user_profiles`` row keyed on *user_id*.

    The upsert targets the ``user_id`` conflict column, so an existing row
    is matched instead of a duplicate being inserted.
    """
    profile_row = {"user_id": user_id}
    upsert_query = _get_client().table("user_profiles").upsert(
        profile_row, on_conflict="user_id"
    )
    _execute(upsert_query)
def get_user_profile(user_id: str) -> Optional[Dict[str, Any]]:
    """Fetch the profile row for *user_id*, or ``None`` when no row exists.

    A ``preferences`` value that comes back as a string is decoded as JSON;
    text that fails to decode is replaced with an empty dict. Non-string
    ``preferences`` values are returned unchanged.
    """
    columns = "user_id, name, preferences, personality_summary, last_updated, created_at"
    profile_query = (
        _get_client()
        .table("user_profiles")
        .select(columns)
        .eq("user_id", user_id)
        .limit(1)
    )
    rows = _execute(profile_query).data
    if not rows:
        return None
    profile = rows[0]
    raw_prefs = profile.get("preferences")
    if not isinstance(raw_prefs, str):
        return profile
    try:
        profile["preferences"] = json.loads(raw_prefs)
    except json.JSONDecodeError:
        profile["preferences"] = {}
    return profile
def update_user_profile(
    user_id: str,
    *,
    name: Optional[str] = None,
    preferences: Optional[str] = None,
    personality_summary: Optional[str] = None,
) -> None:
    """Write the supplied profile fields for *user_id*; ``None`` means "leave as is".

    If every field is ``None`` the function returns without touching the
    database. Otherwise the profile row is ensured first and then updated
    with only the non-``None`` fields.
    """
    candidate_fields = {
        "name": name,
        "preferences": preferences,
        "personality_summary": personality_summary,
    }
    updates: Dict[str, Any] = {
        column: value for column, value in candidate_fields.items() if value is not None
    }
    if not updates:
        return
    client = _get_client()
    _ensure_user_profile(user_id)
    _execute(client.table("user_profiles").update(updates).eq("user_id", user_id))
def save_conversation(user_id: str, user_message: str, ai_response: str) -> str:
    """Store one user/AI exchange and return the new row's id as a string.

    The user's profile row is ensured first, then a ``conversations`` row
    holding the message pair is inserted.

    Args:
        user_id: Owner of the conversation.
        user_message: Text the user sent.
        ai_response: Text the assistant replied with.

    Returns:
        ``str`` of the inserted row's ``id``.

    Raises:
        RuntimeError: If the insert fails, or if the response does not
            include the inserted row with a non-``None`` ``id``.
    """
    client = _get_client()
    _ensure_user_profile(user_id)
    response = _execute(
        client.table("conversations").insert(
            {
                "user_id": user_id,
                "user_message": user_message,
                "ai_response": ai_response,
            }
        )
    )
    rows = response.data or []
    # Without this check an empty result would raise an opaque IndexError and
    # a missing id would be returned as the meaningless string "None".
    if not rows or rows[0].get("id") is None:
        raise RuntimeError("Supabase error: inserted conversation row was not returned.")
    return str(rows[0]["id"])
def get_recent_conversations(user_id: str, limit: Optional[int] = None) -> List[Dict[str, Any]]:
    """Return *user_id*'s conversation rows ordered newest first.

    Args:
        user_id: Owner of the conversations.
        limit: Maximum number of rows to return; ``None`` returns all rows.

    Returns:
        Dicts with ``user_message``, ``ai_response`` and ``created_at``, or an
        empty list when the response has no data.
    """
    newest_first = (
        _get_client()
        .table("conversations")
        .select("user_message, ai_response, created_at")
        .eq("user_id", user_id)
        .order("created_at", desc=True)
    )
    if limit is not None:
        newest_first = newest_first.limit(limit)
    return _execute(newest_first).data or []
def get_conversation_history(user_id: str) -> List[Dict[str, Any]]:
    """Return every conversation row for *user_id* in ascending ``created_at`` order.

    Each dict holds ``user_message``, ``ai_response`` and ``created_at``; an
    empty list is returned when the response has no data.
    """
    oldest_first = (
        _get_client()
        .table("conversations")
        .select("user_message, ai_response, created_at")
        .eq("user_id", user_id)
        .order("created_at", desc=False)
    )
    rows = _execute(oldest_first).data
    return rows if rows else []
def count_user_messages(user_id: str) -> int:
    """Return the number of ``conversations`` rows belonging to *user_id*.

    The query requests an exact count; a missing (``None``) count is
    reported as 0.
    """
    count_query = (
        _get_client()
        .table("conversations")
        .select("id", count="exact")
        .eq("user_id", user_id)
    )
    total = _execute(count_query).count
    return total if total else 0
def update_user_profile_summary(user_id: str, summary: str) -> None:
    """Store *summary* as the ``personality_summary`` for *user_id*.

    Convenience wrapper around ``update_user_profile`` that sets only that
    field (a ``None`` summary results in no update).
    """
    update_user_profile(user_id=user_id, personality_summary=summary)
def get_user_embeddings(user_id: str) -> List[Dict[str, Any]]:
    """Return *user_id*'s stored embeddings, most recently created first.

    Each item is ``{"text": ..., "embedding": ...}``. An embedding returned
    as a string is decoded with ``json.loads`` (undecodable text becomes
    ``[]``). Any other value is passed through unchanged.
    """
    def _decode_embedding(raw: Any) -> Any:
        # String-encoded vectors (e.g. "[0.1, 0.2]") are parsed as JSON.
        if not isinstance(raw, str):
            return raw
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return []

    embedding_query = (
        _get_client()
        .table("embeddings")
        .select("text, embedding")
        .eq("user_id", user_id)
        .order("created_at", desc=True)
    )
    rows = _execute(embedding_query).data or []
    return [
        {"text": row.get("text", ""), "embedding": _decode_embedding(row.get("embedding"))}
        for row in rows
    ]
def add_embedding(user_id: str, text: str, embedding: List[float]) -> None:
    """Insert one text/embedding pair for *user_id* into ``embeddings``.

    The user's profile row is ensured before the insert. The embedding list
    is passed to the insert as-is, without serialisation in this function.
    """
    client = _get_client()
    _ensure_user_profile(user_id)
    new_row = {"user_id": user_id, "text": text, "embedding": embedding}
    _execute(client.table("embeddings").insert(new_row))