# fastapi-hfspace / main.py
# Author: aghoraguru - commit e3d3637: "Add root endpoint to return a welcome message"
import os
import logging
import json
import uuid
from typing import Dict, List, Optional, Any, Union
from datetime import datetime
# FastAPI and related imports
from fastapi import (
FastAPI,
WebSocket,
WebSocketDisconnect,
HTTPException,
Body,
Query,
Depends
)
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, EmailStr
from dotenv import load_dotenv
# LangChain / RAG Pipeline Imports (placeholder imports—adjust for your project)
from langchain_core.documents import Document
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.tools import tool
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict
from bs4 import BeautifulSoup
import requests
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import Depends
# Supabase
from supabase import create_client, Client
###############################################################################
# ENV & LOGGING SETUP
###############################################################################
# Pull variables from a local .env file (if present) into os.environ before
# anything below reads them.
load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
SUPABASE_URL, SUPABASE_ANON_KEY, SUPABASE_SERVICE_ROLE_KEY = (
    os.getenv(var_name)
    for var_name in ("SUPABASE_URL", "SUPABASE_ANON_KEY", "SUPABASE_SERVICE_ROLE_KEY")
)

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    level=logging.INFO,
)

# Fail fast at import time: the OpenAI chat/embedding clients created later in
# this module are constructed without an explicit key.
if not OPENAI_API_KEY:
    raise ValueError("Missing OPENAI_API_KEY in environment!")
# Write the key back into the process environment so those clients can find it.
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
###############################################################################
# SUPABASE CREDENTIALS & CLIENT INITIALIZATION
###############################################################################
# SUPABASE_URL / SUPABASE_ANON_KEY / SUPABASE_SERVICE_ROLE_KEY are read from the
# environment in the ENV & LOGGING SETUP section. Validate them here so a
# misconfigured deployment fails with a message naming the missing variables.
_missing_supabase_vars = [
    var_name
    for var_name, value in (
        ("SUPABASE_URL", SUPABASE_URL),
        ("SUPABASE_ANON_KEY", SUPABASE_ANON_KEY),
        ("SUPABASE_SERVICE_ROLE_KEY", SUPABASE_SERVICE_ROLE_KEY),
    )
    if not value
]
if _missing_supabase_vars:
    raise ValueError(
        "Missing Supabase configuration in environment: "
        + ", ".join(_missing_supabase_vars)
    )

# Anon-key client: used for Supabase Auth calls (sign up, sign in, get_user).
supabase_client: Client = create_client(SUPABASE_URL, SUPABASE_ANON_KEY)
# Service-role client: used for server-side writes to the custom `users` table.
# The service-role key is privileged (in Supabase it bypasses row-level
# security), so this client must never be exposed to callers.
supabase_admin: Client = create_client(SUPABASE_URL, SUPABASE_SERVICE_ROLE_KEY)
###############################################################################
# OPTIONAL: CREATE TABLES / SCHEMA
###############################################################################
def create_db_schema() -> None:
    """
    Log the DDL for this app's Supabase tables; nothing is executed here.

    The SQL enables the uuid-ossp extension and creates (if absent) the
    `users`, `chats`, `chat_session`, `logs` and `ai_thought_table` tables.
    Run it once from a trusted admin environment, e.g. by pasting the logged
    SQL into the Supabase SQL editor.
    """
    ddl = """
-- Enable UUID generation if not enabled
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
CREATE TABLE IF NOT EXISTS public.users (
id uuid DEFAULT uuid_generate_v4() PRIMARY KEY,
created_at timestamp with time zone DEFAULT now(),
email text UNIQUE NOT NULL,
password_hash text,
full_name text,
last_login timestamp with time zone,
role text DEFAULT 'user'
);
CREATE TABLE IF NOT EXISTS public.chats (
chat_id uuid DEFAULT uuid_generate_v4() PRIMARY KEY,
user_id uuid REFERENCES public.users (id) ON DELETE CASCADE,
created_at timestamp with time zone DEFAULT now(),
title text,
last_updated timestamp with time zone DEFAULT now()
);
CREATE TABLE IF NOT EXISTS public.chat_session (
session_id uuid DEFAULT uuid_generate_v4() PRIMARY KEY,
chat_id uuid REFERENCES public.chats (chat_id) ON DELETE CASCADE,
created_at timestamp with time zone DEFAULT now(),
updated_at timestamp with time zone DEFAULT now(),
content jsonb DEFAULT '{}'::jsonb
);
CREATE TABLE IF NOT EXISTS public.logs (
log_id uuid DEFAULT uuid_generate_v4() PRIMARY KEY,
session_id uuid REFERENCES public.chat_session (session_id) ON DELETE CASCADE,
timestamp timestamp with time zone DEFAULT now(),
event_type text,
details jsonb DEFAULT '{}'::jsonb
);
CREATE TABLE IF NOT EXISTS public.ai_thought_table (
id uuid DEFAULT uuid_generate_v4() PRIMARY KEY,
created_at timestamp with time zone DEFAULT now(),
session_id uuid REFERENCES public.chat_session (session_id) ON DELETE CASCADE,
thought_process text,
decision_making jsonb DEFAULT '{}'::jsonb
);
"""
    logging.info("Schema creation SQL:\n%s", ddl)
    # To execute it instead of only logging it, route it through an RPC you
    # define yourself, e.g.:
    # supabase_admin.rpc('execute_sql', {'q': ddl}).execute()
###############################################################################
# FASTAPI APP
###############################################################################
app = FastAPI(
    title="RAG-GENAI-Women",
    version="1.0.0",
    description=(
        "A production-ready pipeline with session-based JSON storage, plus "
        "auth endpoints for SignUp, Login, and more. "
        "Supports multiple concurrent WebSocket connections (one per session)."
    )
)

# Allowed CORS origins, read from the comma-separated CORS_ALLOW_ORIGINS env
# var (e.g. "https://app.example.com,http://localhost:3000"). If it is unset or
# empty, the wildcard ["*"] (any origin) is used.
# NOTE: a wildcard combined with allow_credentials=True lets any website make
# credentialed cross-origin requests, so set CORS_ALLOW_ORIGINS in production.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Bearer-token scheme for endpoints that require an access token (e.g. /auth/me).
security = HTTPBearer()
###############################################################################
# LLM & VECTOR STORE SETUP
###############################################################################
# Embedding function Chroma uses for the similarity searches in `retrieve`.
embeddings_model = OpenAIEmbeddings(model="text-embedding-3-large")

# Main chat model: used for key-point extraction, web-result summaries and the
# final answer.
llm = ChatOpenAI(model="gpt-4o")
# Smaller model used only for the 0/1 "is a web search needed?" decision.
llm_decision_maker = ChatOpenAI(model="gpt-4o-mini")

# Chroma vector store persisted on local disk under ./chroma_db.
vector_store = Chroma(
    embedding_function=embeddings_model,
    persist_directory="./chroma_db",
)
def get_time_date() -> str:
    """Return the current local wall-clock time as 'YYYY-MM-DD HH:MM:SS' (no timezone)."""
    return f"{datetime.now():%Y-%m-%d %H:%M:%S}"
def get_country_from_ip() -> str:
    """Placeholder geolocation that always reports "India".

    TODO: replace with a real IP-to-country lookup for the requesting client.
    """
    country = "India"
    return country
###############################################################################
# WEB SEARCH TOOL
###############################################################################
def _fetch_page_snippet(url: str, max_chars: int = 1000) -> Dict[str, str]:
    """Download *url* and return {"url", "content"} with up to *max_chars* of visible text.

    Never raises. A failure is logged and reported as content "Error: <reason>",
    so one bad page does not abort the whole search.
    """
    try:
        resp = requests.get(url, timeout=10)
        # Treat 4xx/5xx responses as failures instead of returning the text of
        # an error page as if it were content.
        resp.raise_for_status()
        soup = BeautifulSoup(resp.text, "html.parser")
        # Script/style/noscript bodies are not visible text and would otherwise
        # fill the snippet with code.
        for tag in soup(["script", "style", "noscript"]):
            tag.decompose()
        # Collapse the runs of whitespace/newlines that get_text() leaves behind
        # so the character budget is spent on words.
        text = " ".join(soup.get_text(separator=" ").split())
        return {"url": url, "content": text[:max_chars]}
    except Exception as e:
        logging.exception("Error fetching content from: %s", url)
        return {"url": url, "content": f"Error: {str(e)}"}


@tool
def web_search_tool(query: str) -> Dict[str, Any]:
    """
    Perform a web search and return a single dictionary:
    {"results": [...], "count": <int>}
    Each result is {"url": ..., "content": ...}; content is a text snippet of
    the page, or "Error: ..." if the page could not be fetched.
    """
    # Imported here so the module still loads if googlesearch is not installed.
    # The import error only surfaces when the tool is actually used.
    from googlesearch import search

    results: List[Dict[str, str]] = []
    try:
        for url in search(query, num_results=3):
            results.append(_fetch_page_snippet(url))
    except Exception:
        # Best-effort: keep whatever pages were fetched before the search failed.
        logging.exception("Error performing search")
    return {"results": results, "count": len(results)}
###############################################################################
# RAG PIPELINE
###############################################################################
class State(TypedDict):
    """State shared by the pipeline nodes; each node returns a partial update."""

    # Text the pipeline answers. The WebSocket handler puts the session's whole
    # "Role: content" transcript here, not just the latest message.
    question: str
    # A single Document with the LLM-filtered key points (set by `retrieve`).
    retrieved_context: List[Document]
    # "User from <country> at <timestamp>" (set by `add_demographic_context`).
    demographic_context: str
    # Output of `decide_web_search`; only the value 1 triggers a web search.
    web_search_needed: int
    # [{"url": ..., "summary": ...}] produced by `perform_web_search`.
    web_search_results: List[dict]
    # Dict built by `consolidate` with keys crunched_summary, full_answer,
    # sources and source_type.
    final_answer: Dict[str, Any]
    # "casual" selects the conversational prompts; anything else is "detailed".
    tone: str
def add_demographic_context(state: State):
    """Pipeline node: record the (stubbed) user country and the current time.

    `state` is not read. Returns {"demographic_context": "User from <country> at <time>"}.
    """
    demo = "User from {} at {}".format(get_country_from_ip(), get_time_date())
    logging.info("[add_demographic_context] %s", demo)
    return {"demographic_context": demo}
def retrieve(state: State):
    """Pipeline node: search the vector store, then have the LLM distil the hits.

    The raw similarity-search hits are not returned. The node returns a single
    Document (metadata {"source": "filtered"}) that contains the LLM's
    extracted key points for `state["question"]`.
    """
    logging.info("[retrieve] Searching vector store...")
    question = state["question"]
    hits = vector_store.similarity_search(question)
    docs_text = "\n\n".join([hit.page_content for hit in hits])

    if state.get("tone", "detailed") == "casual":
        system_prompt = "You are an assistant extracting key points in a conversational manner. Think like a lawyer and search for relevant details."
    else:
        system_prompt = "You are an assistant extracting key points. Focus on relevant details. Think like a lawyer and search for relevant details."

    user_prompt = f"Query:\n{question}\n\nDocs:\n{docs_text}\nExtract relevant points."
    reply = llm.invoke([
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ])
    key_points = reply.content.strip()
    filtered_doc = Document(page_content=key_points, metadata={"source": "filtered"})
    return {"retrieved_context": [filtered_doc]}
def decide_web_search(state: State):
    """Pipeline node: ask the small decision model whether a web search is needed.

    Returns {"web_search_needed": 0 | 1}.

    Raises:
        ValueError: if the model's reply cannot be read as 0 or 1.
    """
    logging.info("[decide_web_search]")
    retrieved_text = "\n\n".join(doc.page_content for doc in state["retrieved_context"])
    messages = [
        {
            "role": "system",
            "content": (
                "You are a decision-making assistant. "
                "Respond strictly with '1' if a web search is required, or '0' if not."
            )
        },
        {
            "role": "user",
            "content": f"Question:\n{state['question']}\n\nContext:\n{retrieved_text}"
        },
    ]
    response = llm_decision_maker.invoke(messages)
    decision = response.content.strip()
    logging.info("[decide_web_search] LLM decision raw: %s", decision)

    # Models sometimes wrap the digit in quotes/backticks or add a trailing
    # period despite the "strictly '1' or '0'" instruction. Strip that noise
    # before parsing instead of failing the whole pipeline.
    normalized = decision.strip("'\"`. \t\r\n")
    try:
        value = int(normalized)
    except ValueError:
        value = None
    if value not in (0, 1):
        logging.error("Invalid decision response: %s", decision)
        raise ValueError(f"Unexpected LLM response for web search decision: {decision}")
    return {"web_search_needed": value}
def perform_web_search(state: State):
    """Pipeline node: run and summarise a web search when one was requested.

    Returns {"web_search_results": [...]}, where each item is
    {"url": ..., "summary": ...}. The list is empty unless
    `state["web_search_needed"] == 1`. Results whose fetch failed (content
    starting with "Error:", the format `web_search_tool` uses) or that have no
    text are skipped rather than summarised. Otherwise an error message would
    be presented to the user as a source.
    """
    if state.get("web_search_needed", 0) != 1:
        logging.info("[perform_web_search] Skipping web search...")
        return {"web_search_results": []}

    logging.info("[perform_web_search] Searching the web...")
    demographic = state.get("demographic_context")
    query = f"{state['question']} ({demographic})" if demographic else state["question"]
    search_data = web_search_tool.invoke(query)  # {"results": [...], "count": int}

    summarized_results = []
    for item in search_data.get("results", []):
        url = item.get("url", "")
        content = item.get("content", "")
        if not content.strip() or content.startswith("Error:"):
            logging.warning("[perform_web_search] Skipping unusable result from %s", url)
            continue
        sum_resp = llm.invoke([
            {"role": "system", "content": "Summarize the content with short citation."},
            {"role": "user", "content": f"{content}\nURL: {url}"},
        ])
        summarized_results.append({"url": url, "summary": sum_resp.content.strip()})
    return {"web_search_results": summarized_results}
def consolidate(state: State):
    """Pipeline node: write the final answer from the retrieved points and web summaries.

    Makes two LLM calls. The first writes a comprehensive answer; the second
    condenses it for chat. Returns {"final_answer": {...}} with keys
    "crunched_summary" (condensed), "full_answer", "sources" (the web
    summaries, or None if there are none) and "source_type".
    """
    logging.info("[consolidate] Generating final answer...")
    retrieved_text = "\n\n".join(doc.page_content for doc in state["retrieved_context"])
    web_data = state.get("web_search_results", [])
    sources_text = "\n".join(
        f"URL: {item['url']}\nSummary: {item['summary']}" for item in web_data
    )

    casual = state.get("tone", "detailed") == "casual"
    system_prompt = (
        "You are a friendly assistant. Combine context and results in a final manner."
        if casual
        else "You are a precise assistant. Combine context and results into a final answer."
    )
    user_prompt = (
        f"Question:\n{state['question']}\n\n"
        f"Retrieved:\n{retrieved_text}\n\n"
        f"Web:\n{sources_text}\n\n"
        "Give a comprehensive final answer."
    )
    full_answer = llm.invoke([
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]).content.strip()

    # Second pass: a shorter version for the chat window.
    condensed = llm.invoke([
        {"role": "system", "content": "Provide a concise version of the answer, preserving key details."},
        {"role": "user", "content": full_answer},
    ]).content.strip()

    if web_data and retrieved_text:
        source_type = "Web + Retrieved"
    elif web_data:
        source_type = "Web"
    else:
        source_type = "Retrieved"

    return {
        "final_answer": {
            "crunched_summary": condensed,
            "full_answer": full_answer,
            "sources": web_data or None,
            "source_type": source_type,
        }
    }
###############################################################################
# PIPELINE GRAPH BUILD
###############################################################################
# The nodes run strictly in this order. `add_sequence` registers each function
# as a node named after the function and wires consecutive nodes together, so
# only the entry and exit edges need to be added explicitly.
_PIPELINE_STEPS = [
    add_demographic_context,
    retrieve,
    decide_web_search,
    perform_web_search,
    consolidate,
]
graph_builder = StateGraph(State).add_sequence(_PIPELINE_STEPS)
graph_builder.add_edge(START, _PIPELINE_STEPS[0].__name__)
graph_builder.add_edge(_PIPELINE_STEPS[-1].__name__, END)
pipeline_graph = graph_builder.compile()
###############################################################################
# SESSION-BASED JSON STORAGE
###############################################################################
# Each session is stored as one JSON file in this directory. The directory is
# not created at import time, so code that writes session files must make
# sure it exists.
SESSIONS_DIR = "sessions_data"


def generate_session_id() -> str:
    """Return a new random session id (a UUID4 string)."""
    return str(uuid.uuid4())


def get_session_file(session_id: str) -> str:
    """Return the JSON file path for *session_id* inside SESSIONS_DIR.

    Session ids come from clients (the WebSocket query string and the /reset
    body). They are therefore limited to letters, digits, '-', '_' and '.',
    which keeps the resulting path directly inside SESSIONS_DIR and blocks
    traversal such as "../../etc/x".

    Raises:
        ValueError: if *session_id* is empty, not a string, or contains other
            characters.
    """
    import re  # local import: the module's top-level imports do not include `re`

    if not isinstance(session_id, str) or not re.fullmatch(r"[A-Za-z0-9._-]+", session_id):
        raise ValueError(f"Invalid session_id: {session_id!r}")
    return os.path.join(SESSIONS_DIR, f"{session_id}.json")
def load_session_from_json(session_id: str) -> dict:
    """Load a session from its JSON file, creating the file if it does not exist.

    A new session starts as {"session_id", "started_at", "messages": []} and is
    written to disk immediately. The parent directory is created on demand,
    because SESSIONS_DIR is not created at import time.
    """
    path = get_session_file(session_id)
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    data = {
        "session_id": session_id,
        "started_at": get_time_date(),
        "messages": []
    }
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
    return data
def save_session_to_json(session_data: dict):
    """Write *session_data* to the file for its "session_id" key.

    The data is written to a uniquely named temporary file next to the target
    and then swapped in with os.replace. This means a crash mid-write cannot
    leave a truncated JSON file that later loads would fail to parse. The
    parent directory is created if it is missing.
    """
    session_id = session_data["session_id"]
    path = get_session_file(session_id)
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    tmp_path = f"{path}.{uuid.uuid4().hex}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(session_data, f, indent=2)
        os.replace(tmp_path, path)
    except Exception:
        # Don't leave a stray temp file behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def append_message(session_id: str, role: str, content: str):
    """Append one {"role", "content", "timestamp"} entry to a session and save it.

    The session file is created first if it does not exist yet.
    """
    session = load_session_from_json(session_id)
    entry = {"role": role, "content": content, "timestamp": get_time_date()}
    session["messages"].append(entry)
    save_session_to_json(session)
###############################################################################
# AUTH & USER MODELS
###############################################################################
class SignupRequest(BaseModel):
    # Body of POST /auth/signup; full_name is optional.
    email: EmailStr
    password: str
    full_name: Union[str, None] = None


class SignupResponse(BaseModel):
    # user_id is None when Supabase did not return a user (sign-up failed or
    # email confirmation is pending).
    user_id: Union[str, None]
    message: str


class LoginRequest(BaseModel):
    # Body of POST /auth/login.
    email: EmailStr
    password: str


class LoginResponse(BaseModel):
    # access_token and user_id are None when login fails; `message` says why.
    access_token: Union[str, None]
    token_type: str = "bearer"
    user_id: Union[str, None]
    message: str


class LogoutResponse(BaseModel):
    # Body returned by POST /auth/logout.
    message: str
class Identity(BaseModel):
    # One auth identity linked to a user, as reported by Supabase Auth.
    provider: str
    identity_id: str
    created_at: Union[datetime, str]
    # Nullable, so an identity that has never been used to sign in does not
    # fail validation.
    last_sign_in_at: Optional[Union[datetime, str]]


class UserProfile(BaseModel):
    # Response model for GET /auth/me.
    user_id: str
    email: str
    full_name: Optional[str]
    role: str
    created_at: datetime
    updated_at: Optional[datetime]
    last_sign_in_at: Optional[datetime]
    email_verified: bool
    phone_verified: bool
    is_anonymous: bool
    # Supabase metadata is free-form JSON: values may be bools, numbers, lists
    # or nested objects. Dict[str, Any] passes them through unchanged instead
    # of rejecting them or coercing them to strings.
    app_metadata: Dict[str, Any]
    user_metadata: Dict[str, Any]
    identities: List[Identity]
###############################################################################
# HTTP MODELS
###############################################################################
class AskRequest(BaseModel):
    # Body of POST /ask. NOTE: the /ask handler currently ignores `tone`.
    user_input: str
    tone: Union[str, None] = "detailed"


class AskResponse(BaseModel):
    # The id of the newly created session plus a human-readable status message.
    session_id: str
    message: str
###############################################################################
# HTTP AUTH ENDPOINTS
###############################################################################
@app.get("/")
def read_root():
    """Root endpoint: return a static greeting confirming the app is up."""
    greeting = "Hello from FastAPI on Hugging Face Spaces!"
    return {"message": greeting}
@app.post("/auth/signup", response_model=SignupResponse)
def signup(payload: SignupRequest):
    """
    Sign up a new user using Supabase Auth.

    Once the auth account exists, a matching row (same id, email, full_name,
    role "user") is inserted into the custom `users` table. That insert is
    best-effort: a failure is logged and the sign-up is still reported as
    successful. Auth failures are returned in the response body (HTTP 200)
    with user_id=None rather than as HTTP errors.
    """
    from datetime import timezone  # local import keeps the top-level import block unchanged

    # 1) Use Supabase Auth to create the user
    try:
        result = supabase_client.auth.sign_up(
            {
                "email": payload.email,
                "password": payload.password
            }
        )
    except Exception as e:
        logging.exception("[signup] Error from Supabase Auth sign_up")
        return SignupResponse(user_id=None, message=f"Sign up failed: {str(e)}")

    if result.user is None:
        # Possibly means "Confirm email" is enabled, user needs to verify
        return SignupResponse(
            user_id=None,
            message="User created, but email confirmation required."
        )

    # 2) The auth user exists; mirror extra profile data into our own table.
    user_id = result.user.id
    # Timezone-aware UTC. datetime.utcnow() returns a naive value (and is
    # deprecated since Python 3.12), and its isoformat() has no UTC offset for
    # the `timestamp with time zone` column to interpret.
    now = datetime.now(timezone.utc)

    try:
        insert_res = supabase_admin.table("users").insert({
            "id": user_id,
            "email": payload.email,
            "password_hash": "N/A (Using Supabase Auth)",
            "full_name": payload.full_name or "",
            "created_at": now.isoformat(),
            "last_login": None,
            "role": "user"
        }).execute()
        logging.info("[signup] Inserted custom user record: %s", insert_res.data)
    except Exception:
        # Best-effort: the auth account already exists, so this does not fail the sign-up.
        logging.exception("[signup] Error inserting into 'users' table")

    return SignupResponse(user_id=user_id, message="Sign up successful.")
@app.post("/auth/login", response_model=LoginResponse)
def login(payload: LoginRequest):
    """
    Log in an existing user with Supabase Auth.

    Returns the access token, which the client sends as a Bearer token on later
    requests. Failures are reported in the response body (HTTP 200) with a
    null token, never as HTTP errors. On success, `last_login` in the custom
    `users` table is updated on a best-effort basis.

    NOTE(review): sign_in_with_password is called on the module-level shared
    `supabase_client`. Presumably that client keeps the signed-in session
    afterwards, which would affect other requests using it. Confirm.
    """
    from datetime import timezone  # local import keeps the top-level import block unchanged

    try:
        result = supabase_client.auth.sign_in_with_password(
            {
                "email": payload.email,
                "password": payload.password
            }
        )
    except Exception as e:
        logging.exception("[login] Error from Supabase Auth sign_in_with_password")
        return LoginResponse(
            access_token=None,
            user_id=None,
            message=f"Login error: {str(e)}"
        )

    if result.user is None:
        return LoginResponse(
            access_token=None,
            user_id=None,
            message="Login failed: invalid credentials or user not confirmed."
        )

    user_id = result.user.id
    access_token = result.session.access_token if result.session else None

    # Best-effort bookkeeping: a failure here must not block the login.
    # Timezone-aware UTC so the stored timestamp carries an explicit offset.
    try:
        supabase_admin.table("users").update({
            "last_login": datetime.now(timezone.utc).isoformat()
        }).eq("id", user_id).execute()
    except Exception:
        logging.exception("[login] Error updating last_login in 'users' table")

    return LoginResponse(
        access_token=access_token,
        user_id=user_id,
        message="Login success."
    )
@app.post("/auth/logout", response_model=LogoutResponse)
def logout():
    """
    Sign out through Supabase Auth on the server's shared client.

    Token-based clients should also discard their stored access token. Any
    error from Supabase is logged and returned as HTTP 500 "Logout failed.".

    NOTE(review): this endpoint receives no token, so sign_out() acts on
    whatever session the module-level `supabase_client` currently holds, which
    is not necessarily the caller's. Confirm this is intended.
    """
    try:
        supabase_client.auth.sign_out()
    except Exception:
        logging.exception("[logout] Error from Supabase Auth sign_out")
        raise HTTPException(status_code=500, detail="Logout failed.")
    return LogoutResponse(message="Logout successful.")
@app.get("/auth/me", response_model=UserProfile)
def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """
    Retrieve info about the currently logged-in user.

    The Bearer token from the Authorization header is verified with Supabase
    Auth. `full_name` comes from the custom `users` table when a row exists
    for the user, and is None otherwise.

    Responds 401 when Supabase returns no user for the token, and 500 on any
    other failure.
    """
    try:
        user_response = supabase_client.auth.get_user(credentials.credentials)
        if not user_response or not user_response.user:
            raise HTTPException(status_code=401, detail="User not authenticated.")
        user = user_response.user

        # Read the profile row with the service-role client, matching the
        # signup/login writes. The filter uses the id Supabase just verified,
        # so only the caller's own row can be returned. `.limit(1)` is used
        # instead of `.single()` because the signup insert is best-effort, so
        # the row may be missing, and `.single()` errors on zero rows.
        res = supabase_admin.table("users").select("*").eq("id", user.id).limit(1).execute()
        rows = res.data or []
        record = rows[0] if rows else None

        return UserProfile(
            user_id=user.id,
            email=user.email,
            full_name=record.get("full_name") if record else None,
            role=user.role,
            created_at=user.created_at,
            updated_at=user.updated_at,
            last_sign_in_at=user.last_sign_in_at,
            email_verified=user.user_metadata.get("email_verified", False),
            phone_verified=user.user_metadata.get("phone_verified", False),
            is_anonymous=user.is_anonymous,
            app_metadata=user.app_metadata,
            user_metadata=user.user_metadata,
            identities=[
                Identity(
                    provider=identity.provider,
                    identity_id=identity.identity_id,
                    created_at=str(identity.created_at) if isinstance(identity.created_at, datetime) else identity.created_at,
                    last_sign_in_at=str(identity.last_sign_in_at) if isinstance(identity.last_sign_in_at, datetime) else identity.last_sign_in_at,
                )
                for identity in (user.identities or [])
            ],
        )
    except HTTPException:
        # Let deliberate HTTP errors (the 401 above) through unchanged instead
        # of turning them into a 500 below.
        raise
    except Exception as e:
        logging.exception("[get_current_user] Error retrieving user info")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/auth/confirm")
def confirm_email(
    access_token: str = Query(...),
    refresh_token: str = Query(...),
    expires_in: int = Query(...),
    token_type: str = Query(...)
):
    """
    Endpoint to handle confirmation links sent via email.

    Only `access_token` is used: it is checked with Supabase Auth and the user
    is returned on success. The other query parameters are required but
    currently unused. Responds 400 when Supabase returns no user for the token,
    and 500 if the lookup itself fails.
    """
    try:
        result = supabase_client.auth.get_user(access_token)
    except Exception as e:
        logging.exception("[confirm_email] Error during confirmation")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # This check sits outside the try so the 400 is not caught and re-raised
    # as a 500.
    if not result or not result.user:
        raise HTTPException(status_code=400, detail="Invalid or expired confirmation link.")
    return {"status": "success", "message": "Email confirmed successfully.", "user": result.user}
###############################################################################
# HTTP ENDPOINTS
###############################################################################
@app.get("/health")
def health_check():
    """Liveness probe: always reports ok; no dependencies are checked."""
    return dict(status="ok", message="Service is healthy.")
@app.post("/ask", response_model=AskResponse)
def ask_endpoint(payload: AskRequest):
    """
    Start a new session and record the user's first message in it.

    The pipeline is not run here. The client continues on the WebSocket
    endpoint using the returned session_id. `payload.tone` is not stored or
    used by this endpoint.
    """
    new_session_id = generate_session_id()
    append_message(new_session_id, "user", payload.user_input)
    return AskResponse(
        session_id=new_session_id,
        message="Session created. Connect via WS to continue."
    )
@app.post("/reset")
def reset_session(session_id: str = Body(..., embed=True)):
    """
    Delete the session's JSON file, effectively resetting the conversation.

    Responds 404 if the session has no file. Responds 400 if `session_id` is
    rejected by get_session_file or would resolve to a file outside
    SESSIONS_DIR (e.g. "../x"), so client input can never delete arbitrary
    files.
    """
    try:
        path = get_session_file(session_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid session_id.") from None

    # Defence in depth before a destructive call: only delete files that live
    # directly inside SESSIONS_DIR.
    if os.path.dirname(os.path.abspath(path)) != os.path.abspath(SESSIONS_DIR):
        raise HTTPException(status_code=400, detail="Invalid session_id.")

    # EAFP: a separate exists() check would race with concurrent deletes.
    try:
        os.remove(path)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="Session not found.") from None
    return {"status": "ok", "message": f"Session {session_id} reset."}
###############################################################################
# WEBSOCKET CONCURRENCY
###############################################################################
class ConnectionManager:
    """
    Track at most one active WebSocket per session_id.

    When a new WebSocket arrives for a session that already has one, the old
    socket is closed (code 4000) and replaced by the new one.
    """

    def __init__(self):
        # session_id -> the currently active socket for that session
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, session_id: str, websocket: WebSocket):
        """Close any existing socket for *session_id*, then accept and register *websocket*."""
        old_ws = self.active_connections.get(session_id)
        if old_ws is not None:
            logging.info("[WS] Closing old connection for session %s to allow new one.", session_id)
            try:
                await old_ws.close(code=4000, reason="Replaced by a new connection")
            except Exception:
                # The old client may already be gone before its handler has
                # unregistered it. Failing to close it must not block the new
                # connection.
                logging.warning(
                    "[WS] Could not close old connection for session %s", session_id, exc_info=True
                )
        logging.info("[WS] Accepting WebSocket for session: %s", session_id)
        await websocket.accept()
        self.active_connections[session_id] = websocket
        logging.info("[WS] Session %s connected. Total active sessions: %d",
                     session_id, len(self.active_connections))

    def disconnect(self, session_id: str, websocket: WebSocket):
        """Unregister *websocket*, but only if it is still the active socket for *session_id*.

        The identity check keeps a replaced connection's cleanup from removing
        its replacement.
        """
        if self.active_connections.get(session_id) is websocket:
            del self.active_connections[session_id]
            logging.info("[WS] Session %s disconnected. Remaining active sessions: %d",
                         session_id, len(self.active_connections))

    async def send_json(self, session_id: str, data: dict):
        """Send *data* to the session's active socket; do nothing if none is registered."""
        ws = self.active_connections.get(session_id)
        if ws is not None:
            await ws.send_json(data)


manager = ConnectionManager()
def _render_transcript(messages: List[dict]) -> str:
    """Render stored session messages as "Role: content" lines, oldest first."""
    return "".join(f"{msg['role'].capitalize()}: {msg['content']}\n" for msg in messages)


def _status_lines(update: dict) -> List[str]:
    """Turn one pipeline node's state update into human-readable status strings."""
    lines: List[str] = []
    if "demographic_context" in update:
        lines.append(f"Demographic: {update['demographic_context']}")
    if "retrieved_context" in update:
        docs = update["retrieved_context"]
        excerpt = docs[0].page_content[:60] if docs else ""
        lines.append(f"Retrieved context: {excerpt}...")
    if "web_search_needed" in update:
        lines.append(f"Web search needed = {update['web_search_needed']}")
    if "web_search_results" in update:
        lines.append(f"Web search returned {len(update['web_search_results'])} results.")
    return lines


async def _send_error(session_id: str, message: str) -> bool:
    """Best-effort error report to the session's client; return False if it could not be sent."""
    try:
        await manager.send_json(session_id, {"type": "error", "message": message})
    except Exception:
        logging.warning("[WS] Could not deliver error message for session %s", session_id, exc_info=True)
        return False
    return True


async def _answer_message(session_id: str, user_input: str, tone: str) -> None:
    """Record the user turn, run the RAG pipeline, and stream progress plus the answer."""
    append_message(session_id, "user", user_input)
    session_data = load_session_from_json(session_id)
    chain_state = {
        "question": _render_transcript(session_data["messages"]),
        "tone": tone
    }
    await manager.send_json(session_id, {
        "type": "status",
        "message": "Starting pipeline..."
    })
    # stream_mode="updates" yields {node_name: what that node returned} once per
    # node. "values" would re-yield the whole accumulated state after every
    # step, re-sending every earlier status message each time.
    async for chunk in pipeline_graph.astream(chain_state, stream_mode="updates"):
        for update in chunk.values():
            if not isinstance(update, dict):
                continue
            for line in _status_lines(update):
                await manager.send_json(session_id, {"type": "status", "message": line})
            if "final_answer" in update:
                final_ans = update["final_answer"]
                short_answer = final_ans["crunched_summary"]
                append_message(session_id, "assistant", short_answer)
                await manager.send_json(session_id, {
                    "type": "final_answer",
                    "short_answer": short_answer,
                    "full_answer": final_ans["full_answer"],
                    "sources": final_ans["sources"],
                    "source_type": final_ans["source_type"]
                })


@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    session_id: Optional[str] = Query(None),
    tone: str = Query("detailed")
):
    """
    WebSocket endpoint.
    - The user can pass `session_id` and `tone` as query parameters, e.g.:
      ws://localhost:8000/ws?session_id=abc-123&tone=casual
    - Or omit `session_id` to generate one automatically.
    - Each message from the client must be a JSON object with a non-empty
      string field: {"user_input": "..."}.
    - For every message the server streams {"type": "status"} updates, then
      one {"type": "final_answer"}, or {"type": "error"} on failure.
    """
    if not session_id:
        session_id = generate_session_id()
        logging.info("[WS] No session_id provided. Created new: %s", session_id)

    await manager.connect(session_id, websocket)
    try:
        while True:
            try:
                data = await websocket.receive_json()
            except WebSocketDisconnect:
                logging.info("[WS] Client disconnected for session %s", session_id)
                break
            except RuntimeError:
                # Starlette raises RuntimeError when receiving on a socket that
                # is no longer connected, e.g. one that manager.connect closed
                # because a newer connection for this session arrived. It
                # would fail again on every retry, so stop instead of spinning.
                logging.info("[WS] Socket for session %s is no longer connected; stopping.", session_id)
                break
            except Exception as e:
                # e.g. malformed JSON: report it; the client may continue with valid input.
                logging.exception("[WS] Error reading JSON from WebSocket.")
                if not await _send_error(session_id, str(e)):
                    break
                continue

            user_input = data.get("user_input") if isinstance(data, dict) else None
            if not isinstance(user_input, str) or not user_input.strip():
                if not await _send_error(
                    session_id,
                    "Each message must be a JSON object with a non-empty string 'user_input'."
                ):
                    break
                continue

            try:
                await _answer_message(session_id, user_input, tone)
            except WebSocketDisconnect:
                logging.info("[WS] Client disconnected for session %s", session_id)
                break
            except Exception as e:
                logging.exception("[WS] Error during pipeline streaming.")
                if not await _send_error(session_id, str(e)):
                    break
    finally:
        # Always unregister, whatever ended the loop. This is safe even if a
        # newer socket took over the session: disconnect() only removes the
        # entry while it still points at *this* websocket.
        manager.disconnect(session_id, websocket)
###############################################################################
# LOCAL DEV ENTRY POINT
###############################################################################
# if __name__ == "__main__":
# import uvicorn
# # Uncomment to run the server locally; you can also call create_db_schema() here to log the DDL.
# uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)