# Source uploaded via Hugging Face ("Upload 14 files", revision ca6e669).
from flask import Flask, request, jsonify, Response, render_template
from flask_cors import CORS
import os
import logging
import functools
import pandas as pd
import threading
import time
import tempfile
import shutil
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Custom Imports
from rag_system import initialize_and_get_rag_system
from config import (
API_USERNAME, API_PASSWORD, RAG_SOURCES_DIR, RAG_STORAGE_PARENT_DIR,
GDRIVE_INDEX_ENABLED, GDRIVE_INDEX_ID_OR_URL,
GDRIVE_USERS_CSV_ENABLED, GDRIVE_USERS_CSV_ID_OR_URL,
ADMIN_USERNAME, ADMIN_PASSWORD, RAG_RERANKER_K,
EXTERNAL_URL, URL_UPDATE_PERIOD_MINUTES, URL_FETCH_ENABLED,
RAG_CSV_MAX_RESULTS, RAG_CSV_CONFIDENCE_THRESHOLD
)
from utils import download_and_unzip_gdrive_file, download_gdrive_file, fetch_and_clean_url
# Logging Setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Flask Init
app = Flask(__name__, static_folder='static', template_folder='templates')
CORS(app)  # allow cross-origin requests (frontend may be hosted separately)
# Global State
rag_system = None  # active RAG engine; set by run_startup_tasks and rebuild endpoints
user_df = None  # pandas DataFrame loaded from assets/users.csv, or None if unavailable
_APP_BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # absolute app root directory
# --- Helper: Load Users ---
def load_users_from_csv():
    """Refresh the global ``user_df`` from assets/users.csv.

    Leaves ``user_df`` as None when the file is missing or unreadable,
    so auth code can treat "no user database" uniformly.
    """
    global user_df
    assets_folder = os.path.join(_APP_BASE_DIR, 'assets')
    os.makedirs(assets_folder, exist_ok=True)
    users_csv_path = os.path.join(assets_folder, 'users.csv')
    try:
        if not os.path.exists(users_csv_path):
            logger.warning("users.csv not found in assets folder.")
            user_df = None
            return
        frame = pd.read_csv(users_csv_path)
        # Normalize email addresses so lookups are case/whitespace-insensitive
        if 'email' in frame.columns:
            frame['email'] = frame['email'].str.lower().str.strip()
        user_df = frame
        logger.info(f"Loaded {len(user_df)} users from CSV.")
    except Exception as e:
        logger.error(f"Failed to load users.csv: {e}")
        user_df = None
# --- Helper: Auth Decorators ---
def require_api_auth(f):
    """Basic-auth guard protecting the N8N webhook endpoint."""
    @functools.wraps(f)
    def decorated(*args, **kwargs):
        creds = request.authorization
        authorized = (
            creds is not None
            and creds.username == API_USERNAME
            and creds.password == API_PASSWORD
        )
        if authorized:
            return f(*args, **kwargs)
        return Response(
            'Unauthorized', 401,
            {'WWW-Authenticate': 'Basic realm="API Login Required"'}
        )
    return decorated
def require_admin_auth(f):
    """Basic-auth guard protecting the admin rebuild/update endpoints.

    Grants access to either:
      1. a user in the loaded users CSV whose ``role`` column is ``admin``
         and whose password matches, or
      2. the static ADMIN_USERNAME / ADMIN_PASSWORD credentials from config.

    Fixes over the original: passwords are compared with
    ``hmac.compare_digest`` (constant-time, avoids a timing side-channel on
    the plaintext comparison), and unset static credentials (``None``) can
    no longer match an empty/missing Basic-auth password.
    """
    @functools.wraps(f)
    def decorated(*args, **kwargs):
        import hmac  # local import: constant-time string comparison

        auth = request.authorization
        if not auth:
            return jsonify({"error": "Unauthorized"}), 401
        supplied_pw = auth.password or ""
        # 1) CSV-backed admin accounts (only when the users sheet is loaded)
        if user_df is not None:
            user_email = auth.username.lower().strip()
            user_record = user_df[user_df['email'] == user_email]
            if not user_record.empty:
                user_data = user_record.iloc[0]
                if (hmac.compare_digest(str(user_data['password']), supplied_pw)
                        and user_data['role'] == 'admin'):
                    return f(*args, **kwargs)
        # 2) Static admin credentials from config; require both to be set so
        #    an unset ADMIN_PASSWORD can never authenticate anyone.
        if (ADMIN_USERNAME is not None and ADMIN_PASSWORD is not None
                and auth.username == ADMIN_USERNAME
                and hmac.compare_digest(supplied_pw, str(ADMIN_PASSWORD))):
            return f(*args, **kwargs)
        return jsonify({"error": "Unauthorized"}), 401
    return decorated
# --- URL Zero-Downtime Updater ---
def trigger_url_update():
    """Fetch EXTERNAL_URL, rebuild the RAG index in temp dirs, then hot-swap.

    Returns a dict: ``{"status": ..., "message": ...}`` on success or
    ``{"error": ...}`` on any failure. The live ``rag_system`` keeps serving
    requests until the replacement index is fully built (zero downtime).
    """
    global rag_system
    if not URL_FETCH_ENABLED or not EXTERNAL_URL:
        return {"error": "External URL fetching is disabled or not configured"}
    logger.info(f"[URL_UPDATE] Starting zero-downtime fetch from {EXTERNAL_URL}")
    # 1. Create temporary staging folders
    temp_staging_sources = tempfile.mkdtemp(prefix="rag_sources_temp_")
    temp_index = tempfile.mkdtemp(prefix="rag_index_temp_")
    try:
        # 2. COMBINE SOURCES: Copy existing GDrive/Local sources to staging first
        if os.path.exists(RAG_SOURCES_DIR):
            shutil.copytree(RAG_SOURCES_DIR, temp_staging_sources, dirs_exist_ok=True)
        # 3. Fetch URL data — saved to <app_root>/tmp/ for persistence & inspection
        tmp_dir = os.path.join(_APP_BASE_DIR, 'tmp')
        os.makedirs(tmp_dir, exist_ok=True)
        url_out_path = os.path.join(tmp_dir, "url_data.txt")
        success = fetch_and_clean_url(EXTERNAL_URL, url_out_path)
        if not success:
            return {"error": "Failed to fetch or parse the URL."}
        # Copy from tmp/ into staging so it gets indexed alongside other sources
        shutil.copy2(url_out_path, os.path.join(temp_staging_sources, "url_data.txt"))
        # 4. Build a brand new RAG instance isolated in the temp directories
        new_rag = initialize_and_get_rag_system(
            force_rebuild=True,
            source_dir_override=temp_staging_sources,
            storage_dir_override=temp_index
        )
        if new_rag is None:
            raise Exception("Failed to build new RAG index from parsed text.")
        # 5. Atomic Swap (Now incoming requests hit the new DB immediately)
        # NOTE(review): a plain name rebind is atomic in CPython; requests that
        # already captured the old instance keep using it — acceptable here.
        rag_system = new_rag
        # 6. Backup/Replace persistent INDEX directory ONLY
        # NOTE(review): copytree with dirs_exist_ok=True merges into the target;
        # files left over from a previous index are NOT removed — confirm the
        # index loader tolerates stale files in RAG_STORAGE_PARENT_DIR.
        os.makedirs(RAG_STORAGE_PARENT_DIR, exist_ok=True)
        shutil.copytree(temp_index, RAG_STORAGE_PARENT_DIR, dirs_exist_ok=True)
        rag_system.index_storage_dir = RAG_STORAGE_PARENT_DIR
        logger.info("[URL_UPDATE] Success! RAG database updated combining Local, GDrive, and URL sources.")
        return {"status": "success", "message": "Database successfully updated using combined sources."}
    except Exception as e:
        logger.error(f"[URL_UPDATE] Error during update: {e}", exc_info=True)
        return {"error": str(e)}
    finally:
        # Staging dirs are always removed, even on early return or failure
        shutil.rmtree(temp_staging_sources, ignore_errors=True)
        shutil.rmtree(temp_index, ignore_errors=True)
def url_periodic_loop():
    """Background worker: run one immediate URL update, then repeat forever
    every URL_UPDATE_PERIOD_MINUTES. Intended to run in a daemon thread."""
    enabled = URL_FETCH_ENABLED and EXTERNAL_URL and URL_UPDATE_PERIOD_MINUTES > 0
    if not enabled:
        logger.info("Periodic URL updates disabled.")
        return
    logger.info(f"[URL_UPDATE] Background thread started for: {EXTERNAL_URL}")
    trigger_url_update()
    interval_seconds = URL_UPDATE_PERIOD_MINUTES * 60
    while True:
        time.sleep(interval_seconds)
        logger.info(f"[URL_UPDATE] Triggering scheduled periodic update...")
        trigger_url_update()
# --- Startup Logic ---
def run_startup_tasks():
    """One-time boot sequence: fetch/load users CSV, seed and initialize the
    RAG system, and start the periodic URL refresher thread."""
    global rag_system
    logger.info("--- Executing Startup Tasks ---")
    # Optionally pull the users sheet from Google Drive before loading it
    if GDRIVE_USERS_CSV_ENABLED and GDRIVE_USERS_CSV_ID_OR_URL:
        csv_target = os.path.join(_APP_BASE_DIR, 'assets', 'users.csv')
        download_gdrive_file(GDRIVE_USERS_CSV_ID_OR_URL, csv_target)
    load_users_from_csv()
    # Optionally seed the FAISS index from a zipped GDrive snapshot
    if GDRIVE_INDEX_ENABLED and GDRIVE_INDEX_ID_OR_URL:
        download_and_unzip_gdrive_file(GDRIVE_INDEX_ID_OR_URL, os.getcwd())
    rag_system = initialize_and_get_rag_system()
    # Daemon thread: dies with the process, no explicit shutdown needed
    if URL_FETCH_ENABLED and EXTERNAL_URL:
        updater = threading.Thread(target=url_periodic_loop, daemon=True)
        updater.start()
    logger.info("--- Startup Tasks Complete ---")
# Run startup eagerly at import time so the app is initialized under any
# WSGI server, not only via the __main__ dev entrypoint below.
with app.app_context():
    run_startup_tasks()
# ===========================
# API ROUTES
# ===========================
@app.route('/webhook/search', methods=['POST'])
@require_api_auth
def search_knowledgebase_api():
    """POST JSON {query, final_k?, use_reranker?, cleaned?} -> ranked chunks.

    CSV-sourced chunks are filtered by RAG_CSV_CONFIDENCE_THRESHOLD and
    capped at RAG_CSV_MAX_RESULTS; chunks from other sources pass through
    untouched. With cleaned=True only the 'content' field is returned.
    """
    if not rag_system:
        return jsonify({"error": "RAG not initialized. Check server logs."}), 503
    data = request.json or {}
    query = data.get('query')
    if not query:
        return jsonify({"error": "Query field is required"}), 400
    top_k = data.get('final_k', RAG_RERANKER_K)
    use_reranker = data.get('use_reranker', True)
    # 1. NEW: Extract the 'cleaned' parameter (defaults to False)
    cleaned = data.get('cleaned', False)
    # NOTE(review): this mutates shared retriever state per-request; two
    # concurrent requests with different use_reranker values can race.
    if rag_system.retriever:
        if not use_reranker:
            rag_system.retriever.reranker = None
        elif use_reranker and rag_system.reranker:
            rag_system.retriever.reranker = rag_system.reranker
    try:
        raw_results = rag_system.search_knowledge_base(query, top_k=top_k)
        # Apply CSV limitations and thresholds
        final_results = []
        csv_count = 0
        for res in raw_results:
            is_csv = res["metadata"].get("source_type") == "csv" or res["metadata"].get("source_document_name", "").endswith(".csv")
            if is_csv:
                score = res["score"]
                passed_threshold = False
                # Check confidence limit depending on method used (reranker: higher is better | FAISS L2: lower is better)
                if rag_system.reranker:
                    confidence = score
                else:
                    # Convert FAISS L2 Distance into a 0-1 Confidence Score
                    confidence = 1 / (1 + score)
                res["score"] = confidence  # Update the result so the API shows the neat confidence score
                passed_threshold = confidence >= RAG_CSV_CONFIDENCE_THRESHOLD
                if passed_threshold and csv_count < RAG_CSV_MAX_RESULTS:
                    final_results.append(res)
                    csv_count += 1
            else:
                # Non-CSV chunks bypass the threshold and cap entirely
                final_results.append(res)
        # 2. NEW: If cleaned is True, strip out 'metadata' and 'score'
        if cleaned:
            final_results = [{"content": r["content"]} for r in final_results]
        return jsonify({"results": final_results, "count": len(final_results), "status": "success"})
    except Exception as e:
        logger.error(f"Search API Error: {e}")
        return jsonify({"error": str(e)}), 500
@app.route('/user-login', methods=['POST'])
def user_login():
    """Authenticate a user against the loaded users CSV.

    Expects JSON ``{"email": ..., "password": ...}``. On success returns the
    user's CSV row minus the password column (200); otherwise 400/401/503.

    Fixes over the original: a missing or non-JSON request body no longer
    crashes with a 500 (``get_json(silent=True)``), a null email value is
    handled, and the password comparison is constant-time.
    """
    if user_df is None:
        return jsonify({"error": "User database not available."}), 503
    # silent=True turns malformed/absent JSON into None instead of raising
    data = request.get_json(silent=True) or {}
    email = str(data.get('email') or '').lower().strip()
    password = data.get('password')
    if not email or not password:
        return jsonify({"error": "Email and password required"}), 400
    user_record = user_df[user_df['email'] == email]
    if not user_record.empty:
        u_data = user_record.iloc[0]
        import hmac  # local import: constant-time comparison
        # NOTE(review): passwords are stored in plaintext in the CSV;
        # compare_digest only mitigates timing attacks, not storage risk.
        if hmac.compare_digest(str(u_data['password']), str(password)):
            resp = u_data.to_dict()
            resp.pop('password', None)  # never echo the password back
            return jsonify(resp), 200
    return jsonify({"error": "Invalid credentials"}), 401
@app.route('/')
def index_route():
    """Serve the chat frontend page."""
    return render_template('chat-bot.html')
@app.route('/admin/login', methods=['POST'])
@require_admin_auth
def admin_login():
    """Credential probe; the actual check lives in require_admin_auth."""
    payload = {"status": "success", "message": "Authenticated"}
    return jsonify(payload), 200
@app.route('/admin/update_faiss_index', methods=['POST'])
@require_admin_auth
def update_faiss_index():
    """Incrementally index new files from RAG_SOURCES_DIR.

    Optional JSON body: {"max_new_files": N} caps how many files are added.
    """
    if not rag_system:
        return jsonify({"error": "RAG system not initialized"}), 503
    payload = request.json or {}
    limit = payload.get('max_new_files')
    try:
        outcome = rag_system.update_index_with_new_files(RAG_SOURCES_DIR, limit)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    return jsonify(outcome), 200
@app.route('/admin/rebuild_index', methods=['POST'])
@require_admin_auth
def rebuild_index():
    """Force a full index rebuild.

    Uses the zero-downtime URL updater when URL fetching is configured;
    otherwise rebuilds from local sources only.
    """
    global rag_system
    try:
        if URL_FETCH_ENABLED and EXTERNAL_URL:
            outcome = trigger_url_update()
            if "error" in outcome:
                return jsonify(outcome), 500
            return jsonify({"status": "Index rebuilt successfully using combined local & URL sources"}), 200
        rag_system = initialize_and_get_rag_system(force_rebuild=True)
        return jsonify({"status": "Index rebuilt successfully"}), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# Retained specific endpoint name to ensure the frontend doesn't break
@app.route('/admin/fetch_rentry', methods=['POST'])
@require_admin_auth
def api_fetch_url():
    """Manually trigger one URL fetch + index rebuild cycle."""
    outcome = trigger_url_update()
    http_status = 500 if "error" in outcome else 200
    return jsonify(outcome), http_status
@app.route('/status', methods=['GET'])
def status_route():
    """Lightweight health check for uptime monitors."""
    payload = {
        "status": "online",
        "rag_initialized": rag_system is not None,
        "users_loaded": user_df is not None,
    }
    return jsonify(payload)
if __name__ == '__main__':
    # Dev entrypoint only; a production deployment should front this with a
    # proper WSGI server. PORT defaults to 7860 (Hugging Face Spaces default).
    listen_port = int(os.environ.get("PORT", 7860))
    app.run(host='0.0.0.0', port=listen_port)