# NOTE: "Spaces: Running / Running" banner removed — it was Hugging Face
# Spaces UI residue captured by the paste, not part of the source file.
| from flask import Flask, request, jsonify, Response, render_template | |
| from flask_cors import CORS | |
| import os | |
| import logging | |
| import functools | |
| import pandas as pd | |
| import threading | |
| import time | |
| import tempfile | |
| import shutil | |
| from dotenv import load_dotenv | |
| # Load environment variables | |
| load_dotenv() | |
| # Custom Imports | |
| from rag_system import initialize_and_get_rag_system | |
| from config import ( | |
| API_USERNAME, API_PASSWORD, RAG_SOURCES_DIR, RAG_STORAGE_PARENT_DIR, | |
| GDRIVE_INDEX_ENABLED, GDRIVE_INDEX_ID_OR_URL, | |
| GDRIVE_USERS_CSV_ENABLED, GDRIVE_USERS_CSV_ID_OR_URL, | |
| ADMIN_USERNAME, ADMIN_PASSWORD, RAG_RERANKER_K, | |
| EXTERNAL_URL, URL_UPDATE_PERIOD_MINUTES, URL_FETCH_ENABLED, | |
| RAG_CSV_MAX_RESULTS, RAG_CSV_CONFIDENCE_THRESHOLD | |
| ) | |
| from utils import download_and_unzip_gdrive_file, download_gdrive_file, fetch_and_clean_url | |
# Logging Setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Flask Init
app = Flask(__name__, static_folder='static', template_folder='templates')
CORS(app)  # allow cross-origin requests (frontend may be hosted separately)
# Global State
rag_system = None  # set in run_startup_tasks(); swapped atomically by trigger_url_update()
user_df = None  # pandas DataFrame of users from users.csv; None when unavailable
_APP_BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # absolute directory of this file
| # --- Helper: Load Users --- | |
def load_users_from_csv():
    """Load assets/users.csv into the module-level ``user_df`` DataFrame.

    Emails are lower-cased and stripped so later lookups can be
    case-insensitive. On a missing file or any read error, ``user_df``
    is reset to None.
    """
    global user_df
    assets_folder = os.path.join(_APP_BASE_DIR, 'assets')
    os.makedirs(assets_folder, exist_ok=True)
    users_csv_path = os.path.join(assets_folder, 'users.csv')
    try:
        if not os.path.exists(users_csv_path):
            logger.warning("users.csv not found in assets folder.")
            user_df = None
            return
        user_df = pd.read_csv(users_csv_path)
        # Normalize the email column for case-insensitive matching.
        if 'email' in user_df.columns:
            user_df['email'] = user_df['email'].str.lower().str.strip()
        logger.info(f"Loaded {len(user_df)} users from CSV.")
    except Exception as e:
        logger.error(f"Failed to load users.csv: {e}")
        user_df = None
| # --- Helper: Auth Decorators --- | |
def require_api_auth(f):
    """Basic-auth guard for the N8N webhook endpoint.

    Checks the request's HTTP Basic credentials against the static
    API_USERNAME / API_PASSWORD pair from config; replies 401 with a
    WWW-Authenticate challenge on failure.
    """
    # BUGFIX: functools.wraps preserves f.__name__ — without it every guarded
    # view function is named 'decorated' and Flask's route registration
    # collides on duplicate endpoint names. (functools was imported but unused.)
    @functools.wraps(f)
    def decorated(*args, **kwargs):
        auth = request.authorization
        if not auth or auth.username != API_USERNAME or auth.password != API_PASSWORD:
            return Response('Unauthorized', 401, {'WWW-Authenticate': 'Basic realm="API Login Required"'})
        return f(*args, **kwargs)
    return decorated
def require_admin_auth(f):
    """Basic-auth guard for admin rebuild/update endpoints.

    Authorizes either:
      1. a CSV-backed user whose ``role`` column is 'admin', or
      2. the static ADMIN_USERNAME / ADMIN_PASSWORD fallback from config.
    Replies with a JSON 401 otherwise.
    """
    # BUGFIX: functools.wraps preserves f.__name__ so Flask endpoint names
    # stay unique across decorated views. (functools was imported but unused.)
    @functools.wraps(f)
    def decorated(*args, **kwargs):
        auth = request.authorization
        if not auth:
            return jsonify({"error": "Unauthorized"}), 401
        # CSV-backed admin lookup (email matching is case-insensitive).
        if user_df is not None:
            user_email = auth.username.lower().strip()
            user_record = user_df[user_df['email'] == user_email]
            if not user_record.empty:
                user_data = user_record.iloc[0]
                # NOTE(review): plaintext password comparison — consider
                # storing/checking hashes instead.
                if str(user_data['password']) == auth.password and user_data['role'] == 'admin':
                    return f(*args, **kwargs)
        # Static fallback credentials from config.
        if auth.username == ADMIN_USERNAME and auth.password == ADMIN_PASSWORD:
            return f(*args, **kwargs)
        return jsonify({"error": "Unauthorized"}), 401
    return decorated
| # --- URL Zero-Downtime Updater --- | |
def trigger_url_update():
    """Rebuild the RAG index from EXTERNAL_URL with zero downtime.

    Strategy: stage all sources plus freshly fetched URL text in temp
    directories, build a brand-new RAG instance there while the live
    ``rag_system`` keeps serving, then swap the module-level reference and
    persist the new index. Returns a dict with either a "status" or an
    "error" key (callers test for "error").
    """
    global rag_system
    if not URL_FETCH_ENABLED or not EXTERNAL_URL:
        return {"error": "External URL fetching is disabled or not configured"}
    logger.info(f"[URL_UPDATE] Starting zero-downtime fetch from {EXTERNAL_URL}")
    # 1. Create temporary staging folders
    temp_staging_sources = tempfile.mkdtemp(prefix="rag_sources_temp_")
    temp_index = tempfile.mkdtemp(prefix="rag_index_temp_")
    try:
        # 2. COMBINE SOURCES: Copy existing GDrive/Local sources to staging first
        if os.path.exists(RAG_SOURCES_DIR):
            shutil.copytree(RAG_SOURCES_DIR, temp_staging_sources, dirs_exist_ok=True)
        # 3. Fetch URL data — saved to <app_root>/tmp/ for persistence & inspection
        tmp_dir = os.path.join(_APP_BASE_DIR, 'tmp')
        os.makedirs(tmp_dir, exist_ok=True)
        url_out_path = os.path.join(tmp_dir, "url_data.txt")
        success = fetch_and_clean_url(EXTERNAL_URL, url_out_path)
        if not success:
            # Early return still runs the `finally` cleanup below.
            return {"error": "Failed to fetch or parse the URL."}
        # Copy from tmp/ into staging so it gets indexed alongside other sources
        shutil.copy2(url_out_path, os.path.join(temp_staging_sources, "url_data.txt"))
        # 4. Build a brand new RAG instance isolated in the temp directories
        new_rag = initialize_and_get_rag_system(
            force_rebuild=True,
            source_dir_override=temp_staging_sources,
            storage_dir_override=temp_index
        )
        if new_rag is None:
            raise Exception("Failed to build new RAG index from parsed text.")
        # 5. Atomic Swap (Now incoming requests hit the new DB immediately)
        rag_system = new_rag
        # 6. Backup/Replace persistent INDEX directory ONLY
        # NOTE(review): dirs_exist_ok=True merges into the existing directory,
        # so files removed from the new index linger on disk — confirm that is
        # acceptable for this index format.
        os.makedirs(RAG_STORAGE_PARENT_DIR, exist_ok=True)
        shutil.copytree(temp_index, RAG_STORAGE_PARENT_DIR, dirs_exist_ok=True)
        rag_system.index_storage_dir = RAG_STORAGE_PARENT_DIR
        logger.info("[URL_UPDATE] Success! RAG database updated combining Local, GDrive, and URL sources.")
        return {"status": "success", "message": "Database successfully updated using combined sources."}
    except Exception as e:
        logger.error(f"[URL_UPDATE] Error during update: {e}", exc_info=True)
        return {"error": str(e)}
    finally:
        # Temp staging always cleaned; the durable copies live in
        # RAG_STORAGE_PARENT_DIR and <app_root>/tmp/url_data.txt.
        shutil.rmtree(temp_staging_sources, ignore_errors=True)
        shutil.rmtree(temp_index, ignore_errors=True)
def url_periodic_loop():
    """Background worker: refresh the index from EXTERNAL_URL forever.

    Performs one update immediately, then repeats every
    URL_UPDATE_PERIOD_MINUTES. Returns right away when URL fetching is
    disabled or unconfigured.
    """
    updates_enabled = URL_FETCH_ENABLED and EXTERNAL_URL and URL_UPDATE_PERIOD_MINUTES > 0
    if not updates_enabled:
        logger.info("Periodic URL updates disabled.")
        return
    logger.info(f"[URL_UPDATE] Background thread started for: {EXTERNAL_URL}")
    trigger_url_update()
    interval_seconds = URL_UPDATE_PERIOD_MINUTES * 60
    while True:
        time.sleep(interval_seconds)
        logger.info("[URL_UPDATE] Triggering scheduled periodic update...")
        trigger_url_update()
| # --- Startup Logic --- | |
def run_startup_tasks():
    """One-time boot sequence: sync users CSV, build/restore the RAG index,
    and launch the periodic URL updater thread."""
    global rag_system
    logger.info("--- Executing Startup Tasks ---")
    # Optionally pull a fresh users.csv from Google Drive before loading it.
    if GDRIVE_USERS_CSV_ENABLED and GDRIVE_USERS_CSV_ID_OR_URL:
        destination = os.path.join(_APP_BASE_DIR, 'assets', 'users.csv')
        download_gdrive_file(GDRIVE_USERS_CSV_ID_OR_URL, destination)
    load_users_from_csv()
    # Optionally restore a prebuilt index archive from Google Drive.
    if GDRIVE_INDEX_ENABLED and GDRIVE_INDEX_ID_OR_URL:
        download_and_unzip_gdrive_file(GDRIVE_INDEX_ID_OR_URL, os.getcwd())
    rag_system = initialize_and_get_rag_system()
    # Daemon thread so it never blocks interpreter shutdown.
    if URL_FETCH_ENABLED and EXTERNAL_URL:
        threading.Thread(target=url_periodic_loop, daemon=True).start()
    logger.info("--- Startup Tasks Complete ---")
# Run startup at import time so WSGI servers (gunicorn etc.) initialize too,
# not only the __main__ dev-server path.
with app.app_context():
    run_startup_tasks()
| # =========================== | |
| # API ROUTES | |
| # =========================== | |
def search_knowledgebase_api():
    """Search the RAG knowledge base.

    JSON body: ``query`` (required), ``final_k`` (default RAG_RERANKER_K),
    ``use_reranker`` (default True), ``cleaned`` (default False — when True,
    results are stripped down to just their "content" field).
    CSV-sourced hits are filtered by a confidence threshold and capped at
    RAG_CSV_MAX_RESULTS.
    """
    if not rag_system:
        return jsonify({"error": "RAG not initialized. Check server logs."}), 503
    data = request.json or {}
    query = data.get('query')
    if not query:
        return jsonify({"error": "Query field is required"}), 400
    top_k = data.get('final_k', RAG_RERANKER_K)
    use_reranker = data.get('use_reranker', True)
    cleaned = data.get('cleaned', False)
    # Toggle the reranker for this request.
    # NOTE(review): this mutates shared retriever state, so concurrent
    # requests with different use_reranker values can interfere.
    if rag_system.retriever:
        if not use_reranker:
            rag_system.retriever.reranker = None
        elif use_reranker and rag_system.reranker:
            rag_system.retriever.reranker = rag_system.reranker
    # BUGFIX: interpret scores based on whether a reranker is actually active
    # for THIS request, not merely loaded on rag_system. Previously,
    # use_reranker=False with a loaded reranker treated raw FAISS L2 distances
    # (lower is better) as confidence scores (higher is better).
    reranker_active = bool(rag_system.retriever and rag_system.retriever.reranker)
    try:
        raw_results = rag_system.search_knowledge_base(query, top_k=top_k)
        # Apply CSV limitations and thresholds
        final_results = []
        csv_count = 0
        for res in raw_results:
            meta = res["metadata"]
            is_csv = (meta.get("source_type") == "csv"
                      or meta.get("source_document_name", "").endswith(".csv"))
            if not is_csv:
                final_results.append(res)
                continue
            score = res["score"]
            if reranker_active:
                # Reranker scores: higher is better; already a confidence.
                confidence = score
            else:
                # Convert FAISS L2 distance into a 0-1 confidence score.
                confidence = 1 / (1 + score)
                res["score"] = confidence  # expose the neat confidence in the API
            if confidence >= RAG_CSV_CONFIDENCE_THRESHOLD and csv_count < RAG_CSV_MAX_RESULTS:
                final_results.append(res)
                csv_count += 1
        # If cleaned is True, strip out 'metadata' and 'score'.
        if cleaned:
            final_results = [{"content": r["content"]} for r in final_results]
        return jsonify({"results": final_results, "count": len(final_results), "status": "success"})
    except Exception as e:
        logger.error(f"Search API Error: {e}", exc_info=True)
        return jsonify({"error": str(e)}), 500
def user_login():
    """Authenticate a user against the CSV-backed user table.

    Expects JSON {"email": ..., "password": ...}. Returns the user's row
    as JSON (password removed) on success; 401 on bad credentials; 503
    when the user table is not loaded.
    """
    if user_df is None:
        return jsonify({"error": "User database not available."}), 503
    # BUGFIX: request.json is None when the body is absent or not JSON,
    # which crashed at .get(); default to {} like the other JSON endpoints.
    data = request.json or {}
    email = data.get('email', '').lower().strip()
    password = data.get('password')
    if not email or not password:
        return jsonify({"error": "Email and password required"}), 400
    user_record = user_df[user_df['email'] == email]
    if not user_record.empty:
        u_data = user_record.iloc[0]
        # NOTE(review): plaintext password comparison — consider hashing.
        if str(u_data['password']) == str(password):
            resp = u_data.to_dict()
            resp.pop('password', None)  # never echo credentials back
            return jsonify(resp), 200
    return jsonify({"error": "Invalid credentials"}), 401
def index_route():
    """Serve the chat-bot single-page UI from the templates folder."""
    return render_template('chat-bot.html')
def admin_login():
    """Credential-check endpoint: reaching the body means auth already passed."""
    payload = {"status": "success", "message": "Authenticated"}
    return jsonify(payload), 200
def update_faiss_index():
    """Incrementally index new files from RAG_SOURCES_DIR (no full rebuild).

    Optional JSON body field ``max_new_files`` caps how many new files are
    processed in one call.
    """
    if not rag_system:
        return jsonify({"error": "RAG system not initialized"}), 503
    payload = request.json or {}
    limit = payload.get('max_new_files')
    try:
        outcome = rag_system.update_index_with_new_files(RAG_SOURCES_DIR, limit)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    return jsonify(outcome), 200
def rebuild_index():
    """Force a full index rebuild.

    When URL fetching is configured, delegates to trigger_url_update() so the
    rebuild combines local, GDrive, and URL sources; otherwise rebuilds from
    local sources only.
    """
    global rag_system
    try:
        if not (URL_FETCH_ENABLED and EXTERNAL_URL):
            rag_system = initialize_and_get_rag_system(force_rebuild=True)
            return jsonify({"status": "Index rebuilt successfully"}), 200
        outcome = trigger_url_update()
        if "error" in outcome:
            return jsonify(outcome), 500
        return jsonify({"status": "Index rebuilt successfully using combined local & URL sources"}), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500
| # Retained specific endpoint name to ensure the frontend doesn't break | |
# Retained specific endpoint name to ensure the frontend doesn't break
def api_fetch_url():
    """Manually trigger a URL fetch + zero-downtime reindex."""
    outcome = trigger_url_update()
    status_code = 500 if "error" in outcome else 200
    return jsonify(outcome), status_code
def status_route():
    """Lightweight health check: reports RAG and user-table readiness."""
    payload = {
        "status": "online",
        "rag_initialized": rag_system is not None,
        "users_loaded": user_df is not None,
    }
    return jsonify(payload)
if __name__ == '__main__':
    # PORT env var lets hosting platforms override the bind port; 7860 default.
    port = int(os.environ.get("PORT", 7860))
    # 0.0.0.0 so the dev server is reachable from outside the container.
    app.run(host='0.0.0.0', port=port)