# multi-pdf-rag-ui / src/streamlit_app.py
# Author: Hamza4100
# Last change: "Update src/streamlit_app.py"
# Commit: d82fc22 (verified)
"""
Streamlit Multi-PDF RAG System with User Authentication
========================================================
Multi-user UI that connects to HF backend API with:
- User login/signup
- Per-user document management
- RAG-based question answering
- Session management
"""
import streamlit as st
import requests
import os
import time
from typing import Optional, Dict, List
import hashlib
import json
from datetime import datetime
from user_management import create_hf_user_manager
# Chat history persistence helpers
def derive_user_id_from_api_key(api_key: str) -> str:
    """Return a short, stable identifier derived from ``api_key``.

    The identifier is the first 12 hex digits of the key's SHA-256 digest, so
    the raw key never appears in storage paths or file names. If the key
    cannot be hashed (e.g. it is ``None``), the sentinel ``"unknown"`` is
    returned instead of raising.
    """
    try:
        digest = hashlib.sha256(api_key.encode()).hexdigest()
    except Exception:
        return "unknown"
    return digest[:12]
def _chat_history_path_for_user(user_id: str) -> str:
    """Return the local chat-history JSON path for ``user_id``.

    The per-user directory ``users/<user_id>`` next to this source file is
    created on demand, so the returned file's parent directory exists.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    target_dir = os.path.join(here, 'users', user_id)
    os.makedirs(target_dir, exist_ok=True)
    return os.path.join(target_dir, 'chat_history.json')
def load_chat_history_for_user(user_id: str) -> List[Dict]:
    """Load the persisted chat history for ``user_id``.

    Prefers the HF-backed store when a user manager is cached in the session
    and enabled. Otherwise, or if that lookup raises, it falls back to the
    local ``users/<user_id>/chat_history.json`` file.

    Returns:
        The stored list of chat entries, or ``[]`` when nothing usable is
        stored. Non-list payloads are discarded so callers can always
        ``append`` to the result.
    """
    # Prefer HF-backed storage when available
    try:
        user_manager = st.session_state.get('user_manager')
        if user_manager and user_manager.enabled:
            data = user_manager.load_user_json(user_id, 'chat_history.json')
            return data if isinstance(data, list) else []
    except Exception:
        # Best-effort: any failure in the remote store falls back to disk.
        pass
    # Fallback to local file. Creating the per-user directory can itself fail
    # (e.g. read-only filesystem), so it lives inside the same guard.
    try:
        path = _chat_history_path_for_user(user_id)
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError):
        # Missing, unreadable, or corrupt (JSON / Unicode decode) file.
        return []
    return data if isinstance(data, list) else []
def save_chat_history_for_user(user_id: str, history: List[Dict]):
    """Persist ``history`` for ``user_id`` (best-effort, never raises on I/O).

    Writes to the HF-backed store when a user manager is cached in the
    session and enabled. If that is unavailable or raises, it writes the
    local ``users/<user_id>/chat_history.json`` file instead.

    NOTE(review): the return value of ``save_user_json`` is ignored. If it
    signals failure by returning a falsy value rather than raising, no local
    fallback happens — confirm against the user manager implementation.
    """
    # Prefer HF-backed storage when available
    try:
        user_manager = st.session_state.get('user_manager')
        if user_manager and user_manager.enabled:
            user_manager.save_user_json(user_id, 'chat_history.json', history, commit_message=f"Update chat history for {user_id}")
            return
    except Exception:
        pass
    # Fallback to local file, written atomically: dump to a temp file and
    # swap it in, so a crash mid-write never leaves a truncated history.
    tmp_path = None
    try:
        path = _chat_history_path_for_user(user_id)
        tmp_path = path + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(history, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, path)
    except (OSError, TypeError, ValueError):
        # Disk or serialization problems must not break the UI; just make
        # sure no partial temp file is left behind.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
# ============================================
# PAGE CONFIGURATION
# ============================================
# Browser-tab title/icon, plus the wide layout used by the multi-column UI.
st.set_page_config(
    page_title="Multi-PDF RAG System",
    page_icon="📚",
    layout="wide",
    initial_sidebar_state="expanded",
)
# ============================================
# CONFIGURATION
# ============================================
# Backend API URL (your HF backend)
# Trailing slashes are stripped because endpoints are built as
# f"{API_BASE_URL}/path"; "https://host/" would otherwise yield "//path".
API_BASE_URL = os.environ.get("API_BASE_URL", "https://hamza4100-multi-pdf-rag-api.hf.space").rstrip("/")
# Pre-configured API keys that match backend.
# Read from env `API_KEYS` (comma-separated). This defines the limited pool of keys.
raw_keys = os.environ.get("API_KEYS", "")
# Insertion-ordered pool: blank entries are skipped and duplicates collapse.
# Only the keys are consulted when a key is assigned to a new user.
BACKEND_API_KEYS = {k.strip(): "Available" for k in raw_keys.split(",") if k.strip()}
# ============================================
# USER AUTHENTICATION FUNCTIONS (HF-based)
# ============================================
def hash_password(password: str) -> str:
    """Return the hex-encoded SHA-256 digest of ``password`` (UTF-8 bytes).

    NOTE(review): this is a single unsalted SHA-256 round, which is weak for
    password storage: it is fast to brute-force, and identical passwords share
    a hash. The output is stored at sign-up and recomputed for comparison at
    login, so moving to a stronger scheme (e.g. ``hashlib.scrypt`` with a
    per-user salt) needs a migration path for existing accounts.
    """
    hasher = hashlib.sha256()
    hasher.update(password.encode("utf-8"))
    return hasher.hexdigest()
def generate_api_key(username: str) -> Optional[str]:
    """
    Pick the first key from the configured pool that no user holds yet.

    Despite its name, this does not mint a new key. It hands out one of the
    keys in ``BACKEND_API_KEYS`` (parsed from the ``API_KEYS`` env var),
    skipping any key already stored on a user record in the HF user database.

    Args:
        username: Currently unused; kept for interface compatibility.

    Returns:
        An unassigned key, or ``None`` when no pool is configured, no user
        manager is available to tell which keys are taken, or every key is
        already assigned.
    """
    # No pool configured: nothing to hand out, and no need to fetch users.
    if not BACKEND_API_KEYS:
        return None
    user_manager = st.session_state.get('user_manager')
    if not user_manager:
        # Without the user database we cannot tell which keys are taken, so
        # refuse rather than hand out a hard-coded or possibly shared key.
        return None
    all_users = user_manager.get_all_users() or {}
    # Get set of already assigned keys
    assigned_keys = {
        user_data.get("api_key")
        for user_data in all_users.values()
        if user_data.get("api_key")
    }
    # First available key from the configured pool, preserving env order.
    return next((key for key in BACKEND_API_KEYS if key not in assigned_keys), None)
def register_user(username: str, password: str, email: str) -> tuple[bool, str]:
    """
    Create a new account in the HF-backed user store.

    The user manager is created lazily from the ``HF_TOKEN`` / ``HF_REPO``
    environment variables and cached in ``st.session_state``. A backend API
    key is taken from the configured pool and saved on the new user record.

    Returns:
        ``(True, api_key)`` on success, otherwise ``(False, reason)`` where
        ``reason`` is a human-readable message.
    """
    if 'user_manager' not in st.session_state:
        st.session_state.user_manager = create_hf_user_manager(
            hf_token=os.environ.get("HF_TOKEN"),
            hf_repo=os.environ.get("HF_REPO", "Hamza4100/multi-pdf-storage"),
        )
    manager = st.session_state.user_manager

    if not manager.enabled:
        return False, "User management via HF repo is not configured"
    if manager.get_user(username):
        return False, "User already exists"

    api_key = generate_api_key(username)
    if not api_key:
        return False, "No backend API keys available; signup temporarily disabled"

    created = manager.create_user(
        username=username,
        password_hash=hash_password(password),
        email=email,
        api_key=api_key,
    )
    if not created:
        return False, "Failed to create user in HF repository"
    return True, api_key
def authenticate_user(username: str, password: str) -> Optional[str]:
    """
    Check ``username`` / ``password`` against the HF-backed user store.

    The user manager is created lazily (from ``HF_TOKEN`` / ``HF_REPO``) and
    cached in ``st.session_state``, just like at sign-up.

    Returns:
        The user's stored API key when the credentials verify, otherwise
        ``None``. ``None`` is also returned when user management is disabled
        or the user does not exist.
    """
    if 'user_manager' not in st.session_state:
        st.session_state.user_manager = create_hf_user_manager(
            hf_token=os.environ.get("HF_TOKEN"),
            hf_repo=os.environ.get("HF_REPO", "Hamza4100/multi-pdf-storage"),
        )
    manager = st.session_state.user_manager

    if not manager.enabled:
        return None
    record = manager.get_user(username)
    if not record:
        return None

    password_ok = manager.verify_password(username, hash_password(password))
    return record.get("api_key") if password_ok else None
# ============================================
# API CLIENT FUNCTIONS
# ============================================
def get_headers(api_key: str) -> Dict:
    """Build the request headers that authenticate a call to the backend.

    The key is sent verbatim in the ``X-API-KEY`` header.
    """
    headers = {}
    headers["X-API-KEY"] = api_key
    return headers
def check_backend_health() -> bool:
    """Return True when the backend's ``/health`` endpoint answers HTTP 200.

    Network-level failures (DNS, refused connection, timeout, ...) are
    reported as "offline" (False) rather than raised. Only request errors are
    caught: a bare ``except:`` would also swallow BaseException subclasses
    such as KeyboardInterrupt and SystemExit.
    """
    try:
        response = requests.get(f"{API_BASE_URL}/health", timeout=5)
    except requests.exceptions.RequestException:
        return False
    return response.status_code == 200
def upload_pdf(api_key: str, file) -> Dict:
    """
    Send one PDF to the backend ``/upload`` endpoint.

    Args:
        api_key (str): The logged-in user's backend API key.
        file: A Streamlit ``UploadedFile`` (anything exposing ``getvalue()``
            and ``name``).

    Returns:
        dict: ``{"success": True, "data": <response JSON>}`` on success,
        otherwise ``{"success": False, "error": <message>}``.

    A read timeout is retried ``max_retries`` times. Every other failure,
    including HTTP 403 (bad API key), is reported immediately.
    """
    if not api_key:
        return {"success": False, "error": "Missing API key. Please login first."}

    # Read the whole upload into memory once; it is reused on retry.
    try:
        raw_bytes = file.getvalue()
        payload = {"file": (file.name, raw_bytes, "application/pdf")}
    except Exception as exc:
        return {"success": False, "error": f"Failed to read file: {exc}"}

    max_retries = 1
    read_timeout = 180  # seconds
    attempts_left = max_retries + 1
    while attempts_left > 0:
        attempts_left -= 1
        try:
            response = requests.post(
                f"{API_BASE_URL}/upload",
                headers={"X-API-KEY": api_key},
                files=payload,
                timeout=(10, read_timeout),  # (connect, read)
            )
            if response.status_code == 403:
                return {"success": False, "error": "HTTP 403 Forbidden: Invalid or missing API key."}
            response.raise_for_status()
            try:
                body = response.json()
            except Exception:
                return {"success": False, "error": f"Invalid JSON response: {response.text}"}
            return {"success": True, "data": body}
        except requests.exceptions.ReadTimeout as exc:
            if attempts_left > 0:
                continue
            return {"success": False, "error": f"Read timeout after {read_timeout}s: {exc}"}
        except requests.exceptions.RequestException as exc:
            return {"success": False, "error": f"Request failed: {exc}"}
        except Exception as exc:
            return {"success": False, "error": f"Unexpected error: {exc}"}
def query_documents(api_key: str, question: str, top_k: int = 5, doc_id: str = None) -> Dict:
    """Ask the backend ``/query`` endpoint a question about the user's PDFs.

    Args:
        api_key: Backend API key, sent as ``X-API-KEY``.
        question: Natural-language question.
        top_k: Forwarded as ``top_k`` (presumably how many chunks the backend
            retrieves — confirm against the API).
        doc_id: When truthy, the query is scoped to this single document.

    Returns:
        ``{"success": True, "data": <JSON>}`` or
        ``{"success": False, "error": <message>}``.
    """
    try:
        payload = {"question": question, "top_k": top_k}
        payload.update({"doc_id": doc_id} if doc_id else {})
        resp = requests.post(
            url=f"{API_BASE_URL}/query",
            headers=get_headers(api_key),
            json=payload,
            timeout=90,
        )
        resp.raise_for_status()
        answer = resp.json()
    except Exception as err:
        return {"success": False, "error": str(err)}
    return {"success": True, "data": answer}
def list_documents(api_key: str) -> Dict:
    """Fetch the user's documents from ``/documents``.

    The backend may answer with either a bare JSON list or an object holding
    a ``documents`` list. Both are normalised to a plain list under ``data``.
    Any failure is returned as ``{"success": False, "error": <message>}``.
    """
    try:
        resp = requests.get(
            f"{API_BASE_URL}/documents",
            headers=get_headers(api_key),
            timeout=10,
        )
        resp.raise_for_status()
        body = resp.json()
        docs = body if isinstance(body, list) else body.get("documents", [])
    except Exception as err:
        return {"success": False, "error": str(err)}
    return {"success": True, "data": docs}
def delete_document(api_key: str, doc_id: str) -> Dict:
    """Delete one of the user's documents via ``DELETE /documents/<doc_id>``.

    Args:
        api_key: Backend API key, sent as ``X-API-KEY``.
        doc_id: Identifier of the document to delete. It is percent-encoded
            so characters like ``/`` or ``?`` cannot alter the request path.

    Returns:
        ``{"success": True, "data": ...}`` when the backend accepts the
        delete. ``data`` is the JSON body, or ``{}`` if the response has
        none. Otherwise ``{"success": False, "error": <message>}``.
    """
    from urllib.parse import quote

    if doc_id in (None, ""):
        # Avoid issuing DELETE /documents/None for a record without an id.
        return {"success": False, "error": "Missing document ID"}
    try:
        response = requests.delete(
            f"{API_BASE_URL}/documents/{quote(str(doc_id), safe='')}",
            headers=get_headers(api_key),
            timeout=10,
        )
        response.raise_for_status()
    except Exception as e:
        return {"success": False, "error": str(e)}
    # A successful DELETE may legitimately carry no body (e.g. HTTP 204).
    # That must not be reported as a failure.
    try:
        data = response.json()
    except ValueError:
        data = {}
    return {"success": True, "data": data}
def get_stats(api_key: str) -> Dict:
    """Fetch the user's index statistics from ``/stats``.

    Returns ``{"success": True, "data": <JSON>}`` on success. Any error
    (network, HTTP status, bad JSON) becomes
    ``{"success": False, "error": <message>}``.
    """
    try:
        resp = requests.get(
            f"{API_BASE_URL}/stats",
            headers=get_headers(api_key),
            timeout=10,
        )
        resp.raise_for_status()
        stats = resp.json()
    except Exception as err:
        return {"success": False, "error": str(err)}
    return {"success": True, "data": stats}
# ============================================
# SESSION STATE INITIALIZATION
# ============================================
# Seed per-session defaults once; later reruns keep whatever is already set.
for _state_key, _state_default in (
    ("logged_in", False),
    ("username", None),
    ("api_key", None),
    ("chat_history", []),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
del _state_key, _state_default
# ============================================
# LOGIN/SIGNUP UI
# ============================================
def _start_user_session(username: str, api_key: str):
    """Mark the session as logged in as ``username`` and reset per-user state."""
    st.session_state.logged_in = True
    st.session_state.username = username
    st.session_state.api_key = api_key
    # Nothing from a previous login in this browser session may leak into the
    # new user's view. The chat history is reloaded for the new user, and the
    # list of already-uploaded filenames starts empty so their uploads aren't
    # silently skipped.
    st.session_state.chat_history = []
    st.session_state.processed_files = []


def show_login_page():
    """Display the login / sign-up page.

    Shows a backend health indicator and halts the script run (``st.stop()``)
    when the backend is unreachable. After a successful login or sign-up the
    user is logged in and the script reruns into the main app.
    """
    st.title("📚 Multi-PDF RAG System")
    st.markdown("### Intelligent Document Q&A System")

    # Check backend status (centered in the middle column).
    _, col2, _ = st.columns([1, 2, 1])
    with col2:
        if check_backend_health():
            st.success("✅ Backend API is online")
        else:
            st.error("❌ Backend API is offline")
            st.stop()

    tab1, tab2 = st.tabs(["🔐 Login", "📝 Sign Up"])

    with tab1:
        st.subheader("Login to Your Account")
        with st.form("login_form"):
            username = st.text_input("Username", key="login_username")
            password = st.text_input("Password", type="password", key="login_password")
            submit = st.form_submit_button("Login", use_container_width=True)
        if submit:
            if not username or not password:
                st.error("Please enter both username and password")
            else:
                api_key = authenticate_user(username, password)
                if api_key:
                    _start_user_session(username, api_key)
                    st.success("✅ Login successful!")
                    st.rerun()
                else:
                    st.error("❌ Invalid username or password")

    with tab2:
        st.subheader("Create New Account")
        with st.form("signup_form"):
            new_username = st.text_input("Username", key="signup_username")
            new_email = st.text_input("Email", key="signup_email")
            new_password = st.text_input("Password", type="password", key="signup_password")
            new_password_confirm = st.text_input("Confirm Password", type="password", key="signup_password_confirm")
            submit = st.form_submit_button("Sign Up", use_container_width=True)
        if submit:
            if not new_username or not new_email or not new_password:
                st.error("Please fill all fields")
            elif new_password != new_password_confirm:
                st.error("Passwords don't match")
            elif len(new_password) < 6:
                st.error("Password must be at least 6 characters")
            else:
                success, result = register_user(new_username, new_password, new_email)
                if success:
                    # Auto-login the newly created user; `result` is their API key.
                    _start_user_session(new_username, result)
                    st.success(f"✅ Account created and logged in as {new_username}!")
                    st.rerun()
                else:
                    st.error(f"❌ {result}")
def _poll_upload_job(api_key: str, job_id: str, file_name: str, display_name: str,
                     max_poll_seconds: int = 600, poll_interval: int = 2) -> str:
    """Poll ``/upload-status/<job_id>`` until the backend reports a final status.

    Shows a transient progress line while waiting, then a success or error
    message once the job finishes (or the wait times out).

    Returns:
        ``"success"`` or ``"failed"`` as reported by the backend, or
        ``"timeout"`` if no final status arrived within ``max_poll_seconds``.
    """
    poll_text = st.empty()
    poll_progress = st.progress(0)
    elapsed = 0
    outcome = "timeout"
    while elapsed < max_poll_seconds:
        try:
            resp = requests.get(
                f"{API_BASE_URL}/upload-status/{job_id}",
                headers=get_headers(api_key),
                timeout=10,
            )
            if resp.status_code == 200:
                job = resp.json()
                status = job.get("status")
                if status in ("success", "failed"):
                    display_name = job.get("username") or display_name
                    if status == "success":
                        # "result" may be present but null.
                        res = job.get("result") or {}
                        st.success(f"✅ {file_name} processed for {display_name}: {res.get('filename', file_name)}")
                    else:
                        st.error(f"❌ {file_name} failed for {display_name}: {job.get('error', 'unknown error')}")
                    outcome = status
                    break
                poll_text.text(f"Processing {file_name} for {display_name}... elapsed {elapsed}s")
            else:
                poll_text.text(f"Waiting for processing... (status {resp.status_code})")
        except Exception as e:
            # Transient network/decoding problems: keep polling until timeout.
            poll_text.text(f"Waiting... ({e})")
        time.sleep(poll_interval)
        elapsed += poll_interval
        poll_progress.progress(min(1.0, elapsed / max_poll_seconds))
    else:
        # Loop exhausted without a final status.
        st.error(f"❌ {file_name}: Processing timed out after {max_poll_seconds}s")
    # Leave the polling widgets visible briefly, then clear them.
    time.sleep(2)
    poll_text.empty()
    poll_progress.empty()
    return outcome


def upload_and_poll_files(api_key: str, uploaded_files: list):
    """
    Upload each not-yet-processed PDF and wait for the backend to index it.

    Files whose names are already in ``st.session_state.processed_files`` are
    skipped. A file's name is recorded once it has been handed to the
    backend. The exceptions are when the upload itself fails or the backend
    explicitly reports a processing failure; those stay eligible for retry.

    Note: Streamlit cannot clear a ``file_uploader`` programmatically; the
    caller hides recorded files by filtering on ``processed_files``.
    """
    if "processed_files" not in st.session_state:
        st.session_state.processed_files = []
    files_to_process = [f for f in uploaded_files if f.name not in st.session_state.processed_files]
    if not files_to_process:
        return  # Nothing new to process

    progress_bar = st.progress(0)
    status_text = st.empty()
    total = len(files_to_process)
    succeeded = 0
    for done, file in enumerate(files_to_process, 1):
        status_text.text(f"Uploading {file.name}...")
        result = upload_pdf(api_key, file)
        if not result["success"]:
            st.error(f"❌ {file.name}: {result['error']}")
            progress_bar.progress(done / total)
            continue

        data = result.get("data") or {}
        job_id = data.get("document_id") or data.get("job_id")
        status = data.get("status", "queued")
        display_name = st.session_state.username or "You"
        if job_id:
            status = _poll_upload_job(api_key, job_id, file.name, display_name)
        elif status == "success":
            st.success(f"✅ {file.name} uploaded and processed successfully")
        else:
            st.info(f"ℹ️ {file.name} has unknown status: {status}")

        if status == "success":
            succeeded += 1
        if status != "failed":
            # Hide from the uploader list. Explicit failures stay retryable;
            # timeouts are recorded since the backend may still finish them.
            st.session_state.processed_files.append(file.name)
        progress_bar.progress(done / total)

    # Only claim success when every file was actually indexed.
    if succeeded == total:
        status_text.success("✅ Files uploaded and indexed. You may now ask questions or upload more files.")
    else:
        status_text.warning(f"⚠️ {succeeded} of {total} file(s) indexed successfully. See the messages above for details.")
    time.sleep(2)
    status_text.empty()
    progress_bar.empty()
# ============================================
# MAIN APPLICATION UI
# ============================================
def _render_sources(sources: List[Dict]):
    """Render numbered source citations: file name, page and a 200-char snippet."""
    for idx, source in enumerate(sources, 1):
        filename = source.get('file') or source.get('filename') or source.get('name') or 'Unknown'
        st.markdown(f"**{idx}.** {filename} (Page {source.get('page', 'N/A')})")
        snippet = source.get('text') or source.get('snippet') or ''
        if snippet:
            st.caption(f"_{snippet[:200]}..._")


def _clear_chat_history():
    """Empty the in-session chat history and persist the empty list."""
    st.session_state.chat_history = []
    if st.session_state.api_key:
        save_chat_history_for_user(derive_user_id_from_api_key(st.session_state.api_key), [])


def _ensure_chat_history_loaded():
    """Load the persisted chat history once per logged-in user.

    Loading is keyed on the derived user id instead of "history is empty".
    That avoids a storage round-trip on every rerun, and it keeps a
    just-cleared history from being reloaded from a stale store.
    """
    api_key = st.session_state.api_key
    if not api_key:
        return
    uid = derive_user_id_from_api_key(api_key)
    if st.session_state.get("chat_history_owner") == uid:
        return
    try:
        st.session_state.chat_history = load_chat_history_for_user(uid) or []
    except Exception:
        # Best-effort: keep the in-session history and retry on the next rerun.
        return
    st.session_state.chat_history_owner = uid


def _render_sidebar():
    """Render the sidebar: current user, logout button and backend statistics."""
    with st.sidebar:
        st.title("📚 PDF RAG System")
        st.markdown(f"👤 **User:** {st.session_state.username}")
        if st.button("🚪 Logout", use_container_width=True):
            st.session_state.logged_in = False
            st.session_state.username = None
            st.session_state.api_key = None
            st.session_state.chat_history = []
            # Per-user state must not leak into the next login in this session.
            st.session_state.processed_files = []
            st.session_state.chat_history_owner = None
            st.rerun()
        st.divider()
        st.subheader("📊 Your Statistics")
        stats_result = get_stats(st.session_state.api_key)
        if stats_result["success"]:
            stats = stats_result["data"]
            st.metric("Documents", stats.get("total_documents", 0))
            st.metric("Text Chunks", stats.get("total_chunks", 0))
            st.metric("Index Size", stats.get("index_size", 0))
        else:
            st.error("Failed to load stats")
        st.divider()
        # API key display intentionally omitted for security.


def _record_and_show_answer(question: str, data: Dict):
    """Append a Q&A entry to the chat history, persist it, and display it."""
    answer = data.get("answer", "No answer available")
    sources = data.get("sources") or []
    st.session_state.chat_history.append({
        "question": question,
        "answer": answer,
        "sources": sources,
        "timestamp": datetime.now().isoformat(),
    })
    if st.session_state.api_key:
        save_chat_history_for_user(
            derive_user_id_from_api_key(st.session_state.api_key),
            st.session_state.chat_history,
        )
    st.markdown("### Answer")
    st.success(answer)
    if sources:
        with st.expander(f"📚 Sources ({len(sources)})"):
            _render_sources(sources)


def _render_ask_tab() -> Dict:
    """Render the uploader, query-scope selector and question box.

    Returns:
        The ``list_documents`` result, fetched after any upload in this run
        so newly indexed files are selectable. It is returned so the
        documents tab can reuse it instead of calling the backend again.
    """
    st.header("💬 Ask Questions")
    st.markdown("Query your uploaded documents using natural language")

    uploaded_files = st.file_uploader(
        "Upload PDF files (optional)",
        type=["pdf"],
        accept_multiple_files=True,
        key="pdf_uploader",
    ) or []
    if "processed_files" not in st.session_state:
        st.session_state.processed_files = []
    # Streamlit cannot clear a file_uploader programmatically, so files that
    # were already sent are filtered out here rather than removed.
    pending = [f for f in uploaded_files if f.name not in st.session_state.processed_files]
    if pending:
        if st.button("📤 Upload All", use_container_width=True, type="secondary"):
            upload_and_poll_files(st.session_state.api_key, pending)
    elif uploaded_files:
        st.info("✅ All selected files have already been uploaded and processed.")

    # Scope selector: all documents, or one specific uploaded document.
    docs_result = list_documents(st.session_state.api_key)
    documents = docs_result["data"] if docs_result["success"] else []
    doc_options = [("General (All Documents)", None)]
    doc_options += [(f"{d.get('filename')} ({d.get('doc_id')})", d.get('doc_id')) for d in documents]
    _, selected_doc_id = st.selectbox(
        "Query scope",
        options=doc_options,
        format_func=lambda option: option[0],
        index=0,
    )

    question = st.text_input("Enter your question:", key="question_input", placeholder="What is this document about?")
    col1, col2 = st.columns([3, 1])
    with col1:
        ask_button = st.button("🔍 Ask", use_container_width=True, type="primary")
    with col2:
        clear_button = st.button("🗑️ Clear History", use_container_width=True)

    if clear_button:
        _clear_chat_history()

    # Whitespace-only input is treated as "no question".
    question = (question or "").strip()
    if ask_button and question:
        with st.spinner("🤔 Thinking..."):
            result = query_documents(st.session_state.api_key, question, doc_id=selected_doc_id)
        if result["success"]:
            _record_and_show_answer(question, result["data"])
        else:
            st.error(f"❌ Error: {result['error']}")
    return docs_result


def _render_documents_tab(docs_result: Dict):
    """List the user's documents, each with a delete button."""
    st.header("📁 Your Documents")
    st.markdown("View and manage your uploaded documents")
    if st.button("🔄 Refresh", key="refresh_docs"):
        st.rerun()
    if not docs_result["success"]:
        st.error(f"❌ Failed to load documents: {docs_result['error']}")
        return
    documents = docs_result["data"]
    if not documents:
        # Uploading happens in the "Ask Questions" tab; there is no separate upload tab.
        st.info("📭 No documents uploaded yet. Use the uploader in the 'Ask Questions' tab to add some!")
        return
    st.success(f"📚 {len(documents)} document(s) found")
    for position, doc in enumerate(documents):
        doc_id = doc.get('doc_id')
        with st.container():
            col1, col2, col3 = st.columns([3, 1, 1])
            with col1:
                st.markdown(f"**📄 {doc.get('filename', 'Unknown')}**")
                st.caption(f"ID: {doc.get('doc_id', 'N/A')} | Pages: {doc.get('num_pages', 0)} | Chunks: {doc.get('num_chunks', 0)}")
                st.caption(f"Uploaded: {doc.get('upload_timestamp', 'N/A')}")
            with col2:
                st.metric("Pages", doc.get('num_pages', 0))
            with col3:
                # Including the position keeps widget keys unique even when
                # doc_id is missing or repeated.
                if st.button("🗑️ Delete", key=f"delete_{position}_{doc_id}", type="secondary"):
                    with st.spinner("Deleting..."):
                        del_result = delete_document(st.session_state.api_key, doc_id)
                    if del_result["success"]:
                        st.success("✅ Deleted")
                        st.rerun()
                    else:
                        st.error(f"❌ {del_result['error']}")
            st.divider()


def _render_history_tab():
    """Render the persisted Q&A history, newest first, with a clear button."""
    st.header("💬 Chat History")
    st.markdown("Your conversation history with the system (persisted per user)")
    history = st.session_state.chat_history
    if not history:
        st.info("No chat history yet. Ask a question to start a conversation.")
        return
    total = len(history)
    for offset, chat in enumerate(reversed(history)):
        raw_q = chat.get('question') or ''
        # Short label for the expander header.
        label = raw_q if len(raw_q) <= 80 else raw_q[:77] + '...'
        with st.expander(f"Q{total - offset}: {label}"):
            st.markdown(f"**Question:** {raw_q}")
            st.markdown("**Answer:**")
            st.success(chat.get('answer', ''))
            sources = chat.get('sources') or []
            if sources:
                # Expanders cannot be nested, so sources are listed inline here.
                st.markdown(f"**📚 Sources ({len(sources)})**")
                _render_sources(sources)
            ts = chat.get('timestamp') or ''
            if ts:
                st.caption(f"Asked: {ts}")
        st.divider()
    if st.button("🗑️ Clear Chat History", use_container_width=True):
        _clear_chat_history()
        # st.experimental_rerun is deprecated (removed in newer Streamlit);
        # st.rerun is the supported API and what the rest of this file uses.
        st.rerun()


def show_main_app():
    """Display the main application for a logged-in user.

    Loads the user's persisted chat history (once per user) and renders the
    sidebar, then three tabs: ask questions (with uploader), manage
    documents, and chat history. The document list is fetched once per run,
    inside the first tab after any upload, and reused by the documents tab.
    """
    _ensure_chat_history_loaded()
    _render_sidebar()
    tab1, tab2, tab3 = st.tabs(["💬 Ask Questions", "📁 Manage Documents", "📚 Chat History"])
    with tab1:
        docs_result = _render_ask_tab()
    with tab2:
        _render_documents_tab(docs_result)
    with tab3:
        # Rendered last so a question asked (or history cleared) in tab 1
        # during this run is already reflected.
        _render_history_tab()
# ============================================
# MAIN APPLICATION ENTRY POINT
# ============================================
def main():
    """Route the script run to the main app or to the login page.

    The choice depends solely on ``st.session_state.logged_in``.
    """
    if st.session_state.logged_in:
        show_main_app()
    else:
        show_login_page()


if __name__ == "__main__":
    main()