"""
KASITBot — Flask application entry point.

Routes only. All config lives in config.py, all RAG logic in rag_pipeline.py.
"""

import json
import os
from typing import Dict, List, Optional, Tuple

from flask import Flask, jsonify, render_template, request, Response, stream_with_context
from flask_cors import CORS
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from openai import OpenAI

from config import (
    OPENAI_API_KEY, OPENROUTER_API_KEY, OPENROUTER_BASE_URL,
    AVAILABLE_MODELS, DEFAULT_MODEL,
    MAX_MESSAGE_LENGTH, MAX_MESSAGES_PER_REQ, MAX_HISTORY_MESSAGES,
    TEMPERATURE, MAX_RESPONSE_TOKENS,
)
from rag_pipeline import (
    init_rag, hybrid_retrieve, build_context, build_system,
    format_sources, RERANKER_AVAILABLE,
)

app = Flask(__name__)
CORS(app)

limiter = Limiter(
    app=app,
    key_func=get_remote_address,
    default_limits=[],
    storage_uri="memory://",
)

# Shared by both chat endpoints so their limits cannot drift apart.
_CHAT_RATE_LIMIT = "12 per minute; 60 per hour; 200 per day"

# Set by init(); None until it has run.
_openai_client: Optional[OpenAI] = None
_or_client: Optional[OpenAI] = None
_rag_status: Dict = {}


def _get_client(model_id: str) -> OpenAI:
    """Return the OpenAI-compatible client that serves *model_id*.

    Models whose AVAILABLE_MODELS entry has ``provider == "openrouter"`` use
    the OpenRouter client; every other id (including unknown ones) uses the
    OpenAI client.

    Raises:
        RuntimeError: if init() has not created the clients yet, so the
            failure is logged with a clear cause instead of an
            AttributeError on ``None``.
    """
    info = AVAILABLE_MODELS.get(model_id, {})
    client = _or_client if info.get("provider") == "openrouter" else _openai_client
    if client is None:
        # In this module init() is only called from the __main__ block; any
        # other launcher (e.g. `flask run`, a WSGI server) must call it itself.
        raise RuntimeError("LLM clients not initialised - call init() before serving requests")
    return client


# ── Startup ────────────────────────────────────────────────────────────────

def init() -> None:
    """Create both LLM clients and initialise the RAG pipeline.

    The status dict returned by init_rag() is kept for /api/health.
    """
    global _openai_client, _or_client, _rag_status
    _openai_client = OpenAI(api_key=OPENAI_API_KEY)
    _or_client = OpenAI(api_key=OPENROUTER_API_KEY, base_url=OPENROUTER_BASE_URL)
    _rag_status = init_rag(_openai_client)
    print(f"  RAG={_rag_status.get('rag')}, BM25={_rag_status.get('bm25')}, "
          f"Reranker={_rag_status.get('reranker')}")


# ── Helpers ────────────────────────────────────────────────────────────────

def _safe_log(label: str, text: str) -> None:
    """Print ``[label] text`` without crashing on Windows cp1252 consoles.

    Non-ASCII characters are replaced with '?' before printing.
    """
    safe = text.encode("ascii", "replace").decode("ascii")
    print(f"  [{label}] {safe}")


def _sse(payload: Dict) -> str:
    """Encode *payload* as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


# ── Request validation ─────────────────────────────────────────────────────

def _parse_request(data) -> Tuple[Optional[List[Dict]], Optional[str], Optional[str], Optional[tuple]]:
    """Validate a chat request body.

    Returns:
        ``(messages, lang, model, None)`` on success, or
        ``(None, None, None, (response, 400))`` on failure.

        The returned messages contain only ``role`` and ``content``, so extra
        client-supplied fields are never forwarded to the LLM API. An
        unsupported ``lang`` falls back to "en"; an unknown or non-string
        ``model`` falls back to DEFAULT_MODEL.
    """
    def fail(message: str):
        return None, None, None, (jsonify({"error": message}), 400)

    if not data or not isinstance(data, dict):
        return fail("Invalid JSON")

    messages = data.get("messages", [])
    lang = data.get("lang", "en")
    model = data.get("model", DEFAULT_MODEL)

    if not messages:
        return fail("No messages")
    if not isinstance(messages, list):
        return fail("Invalid messages")
    if len(messages) > MAX_MESSAGES_PER_REQ:
        return fail("Too many messages")

    cleaned: List[Dict] = []
    for m in messages:
        if not isinstance(m, dict) or m.get("role") not in ("user", "assistant"):
            return fail("Invalid message")
        # A missing "content" key is treated as an empty string (the original
        # intent of the .get default); anything non-string is rejected.
        content = m.get("content", "")
        if not isinstance(content, str):
            return fail("Invalid message")
        if len(content) > MAX_MESSAGE_LENGTH:
            return fail("Message too long")
        cleaned.append({"role": m["role"], "content": content})

    if lang not in ("en", "ar"):
        lang = "en"
    # The isinstance guard matters: an unhashable value (e.g. a JSON list)
    # would make the dict membership test raise TypeError -> HTTP 500.
    if not isinstance(model, str) or model not in AVAILABLE_MODELS:
        model = DEFAULT_MODEL
    return cleaned, lang, model, None


def _last_user_msg(messages: List[Dict]) -> str:
    """Return the content of the most recent user message, or "" if none."""
    return next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")


def _trim(messages: List[Dict]) -> List[Dict]:
    """Keep at most the last MAX_HISTORY_MESSAGES messages."""
    return messages[-MAX_HISTORY_MESSAGES:] if len(messages) > MAX_HISTORY_MESSAGES else messages


def _rag_prompt(query: str, lang: str):
    """Retrieve context for *query* and build the system prompt.

    Returns:
        ``(results, system_prompt)``. ``results`` is whatever
        hybrid_retrieve() returns; it is later passed to format_sources().
    """
    results = hybrid_retrieve(query, lang)
    system = build_system(lang, build_context(results))
    return results, system


# ── Routes ─────────────────────────────────────────────────────────────────

@app.route("/")
def index():
    """Render the chat UI with the model catalogue and default model as JSON."""
    return render_template(
        "index.html",
        models_json=json.dumps(AVAILABLE_MODELS),
        default_model=json.dumps(DEFAULT_MODEL),
    )


@app.route("/api/chat", methods=["POST"])
@limiter.limit(_CHAT_RATE_LIMIT)
def chat():
    """Non-streaming chat: return ``{"answer", "sources"}`` as JSON.

    Any failure in retrieval or the LLM call is logged and reported as a
    JSON 500 ``{"error": "LLM error"}``.
    """
    messages, lang, model, err = _parse_request(request.get_json(silent=True))
    if err:
        return err
    messages = _trim(messages)
    query = _last_user_msg(messages)

    try:
        results, system = _rag_prompt(query, lang)
        resp = _get_client(model).chat.completions.create(
            model=model,
            messages=[{"role": "system", "content": system}] + messages,
            max_tokens=MAX_RESPONSE_TOKENS,
            temperature=TEMPERATURE,
        )
        answer = (resp.choices[0].message.content or "").strip()
    except Exception as exc:  # request boundary: log and return a JSON error
        _safe_log("ERROR", str(exc))
        return jsonify({"error": "LLM error"}), 500

    return jsonify({"answer": answer, "sources": format_sources(results)})


@app.route("/api/chat/stream", methods=["POST"])
@limiter.limit(_CHAT_RATE_LIMIT)
def chat_stream():
    """Streaming chat over Server-Sent Events.

    Event sequence: one ``sources`` event, zero or more ``token`` events, an
    ``error`` event if anything fails, and always a final ``done`` event.
    """
    messages, lang, model, err = _parse_request(request.get_json(silent=True))
    if err:
        return err
    messages = _trim(messages)
    query = _last_user_msg(messages)

    def generate():
        try:
            results, system = _rag_prompt(query, lang)
            yield _sse({"type": "sources", "sources": format_sources(results)})

            # Single streaming call — no tool probe, no double round-trip.
            # The `with` block closes the upstream HTTP response even if the
            # client disconnects mid-stream (GeneratorExit raised at a yield).
            with _get_client(model).chat.completions.create(
                model=model,
                messages=[{"role": "system", "content": system}] + messages,
                max_tokens=MAX_RESPONSE_TOKENS,
                temperature=TEMPERATURE,
                stream=True,
            ) as stream:
                for chunk in stream:
                    # Some OpenAI-compatible providers emit chunks with an
                    # empty `choices` list; skip them instead of raising.
                    if not chunk.choices:
                        continue
                    delta = chunk.choices[0].delta
                    if delta and delta.content:
                        yield _sse({"type": "token", "content": delta.content})
        except Exception as exc:  # stream boundary: report, then finish cleanly
            _safe_log("ERROR", f"Stream: {exc}")
            yield _sse({"type": "error", "message": "Something went wrong. Please try again."})
        yield _sse({"type": "done"})

    return Response(
        stream_with_context(generate()),
        content_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


@app.route("/api/health")
def health():
    """Report RAG component status and the configured models."""
    return jsonify({
        "status": "ok",
        "rag": _rag_status.get("rag", False),
        "bm25": _rag_status.get("bm25", False),
        "reranker": _rag_status.get("reranker", False),
        "default_model": DEFAULT_MODEL,
        "available_models": list(AVAILABLE_MODELS.keys()),
    })


# ── Entry point ────────────────────────────────────────────────────────────

if __name__ == "__main__":
    print("=" * 56)
    print("  KASITBot v6 — Hybrid RAG + Streaming")
    print("=" * 56)
    init()
    print("  http://localhost:5000\n")
    app.run(debug=os.environ.get("FLASK_DEBUG", "0") == "1", host="0.0.0.0", port=5000)