kasitbot / app.py
snygginghani's picture
Disconnect student portal from RAG; compress system prompts 80%
fe965f7
"""
KASITBot - Flask application entry point.
Routes plus request validation and startup wiring. All config lives in config.py; all RAG logic lives in rag_pipeline.py.
"""
import json
import os
from typing import Dict, List
from flask import Flask, jsonify, render_template, request, Response, stream_with_context
from flask_cors import CORS
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from openai import OpenAI
from config import (
OPENAI_API_KEY, OPENROUTER_API_KEY, OPENROUTER_BASE_URL,
AVAILABLE_MODELS, DEFAULT_MODEL,
MAX_MESSAGE_LENGTH, MAX_MESSAGES_PER_REQ, MAX_HISTORY_MESSAGES,
TEMPERATURE, MAX_RESPONSE_TOKENS,
)
from rag_pipeline import (
init_rag, hybrid_retrieve, build_context, build_system, format_sources,
RERANKER_AVAILABLE,
)
# Flask application; CORS is enabled on every route using flask_cors defaults.
app = Flask(__name__)
CORS(app)

# Rate limiter keyed on the caller's remote address. There are no global
# default limits: each route declares its own via @limiter.limit(...).
# NOTE(review): "memory://" storage keeps counters inside this process, so
# presumably limits are per worker if the app runs multi-process — confirm.
limiter = Limiter(
    key_func=get_remote_address,
    app=app,
    storage_uri="memory://",
    default_limits=[],
)

# Module state, filled in by init(); stays None / empty until init() runs.
_openai_client = None  # client used for non-OpenRouter models
_or_client = None  # client pointed at OPENROUTER_BASE_URL
_rag_status: Dict = {}  # component flags returned by init_rag()
def _get_client(model_id: str) -> OpenAI:
    """Return the OpenAI-compatible client that should serve ``model_id``.

    Models whose ``AVAILABLE_MODELS`` entry has ``provider == "openrouter"``
    use the OpenRouter client. Every other model, including ids missing from
    ``AVAILABLE_MODELS``, uses the OpenAI client.

    Raises:
        RuntimeError: if the selected client has not been created yet
            (``init()`` was not called). Without this check the caller would
            receive ``None`` and fail later with an opaque AttributeError.
    """
    info = AVAILABLE_MODELS.get(model_id, {})
    use_openrouter = info.get("provider") == "openrouter"
    client = _or_client if use_openrouter else _openai_client
    if client is None:
        raise RuntimeError("LLM clients are not initialised; call init() first")
    return client
# ── Startup ────────────────────────────────────────────────────────────────
def init():
    """Build both LLM clients, start the RAG pipeline and print its status.

    Side effects: sets the module globals ``_openai_client``, ``_or_client``
    and ``_rag_status``. The OpenAI client (not the OpenRouter one) is the
    client handed to ``init_rag``.
    """
    global _openai_client, _or_client, _rag_status
    openai_client = OpenAI(api_key=OPENAI_API_KEY)
    _openai_client = openai_client
    _or_client = OpenAI(api_key=OPENROUTER_API_KEY, base_url=OPENROUTER_BASE_URL)
    _rag_status = init_rag(openai_client)
    status = _rag_status
    print(
        f" RAG={status.get('rag')}, BM25={status.get('bm25')}, "
        f"Reranker={status.get('reranker')}"
    )
# ── Request validation ─────────────────────────────────────────────────────
def _parse_request(data):
    """Validate and normalise a chat request body.

    Args:
        data: The parsed JSON body (``request.get_json(silent=True)``); may be
            ``None`` or any JSON value.

    Returns:
        ``(messages, lang, model, None)`` on success, or
        ``(None, None, None, (response, 400))`` on failure. On failure the last
        item is a ready-to-return Flask ``(jsonify(...), status)`` tuple.

        On success:
        - ``messages`` is a fresh list of ``{"role", "content"}`` dicts. A
          missing content becomes ``""``, and any other keys the client sent
          are dropped so they are never forwarded to the LLM API.
        - ``lang`` falls back to ``"en"`` unless it is ``"en"`` or ``"ar"``.
        - ``model`` falls back to ``DEFAULT_MODEL`` unless it is a key of
          ``AVAILABLE_MODELS``.
    """
    def _fail(message):
        return None, None, None, (jsonify({"error": message}), 400)

    if not data or not isinstance(data, dict):
        return _fail("Invalid JSON")
    messages = data.get("messages", [])
    lang = data.get("lang", "en")
    model = data.get("model", DEFAULT_MODEL)
    if not messages:
        return _fail("No messages")
    # Check the type separately so that a non-list (string, object) is not
    # misreported as "Too many messages".
    if not isinstance(messages, list):
        return _fail("Invalid message")
    if len(messages) > MAX_MESSAGES_PER_REQ:
        return _fail("Too many messages")

    clean = []
    for m in messages:
        if not isinstance(m, dict) or m.get("role") not in ("user", "assistant"):
            return _fail("Invalid message")
        # Read content once: the old code validated m.get("content", "") but
        # then indexed m["content"], raising KeyError (HTTP 500) when absent.
        content = m.get("content", "")
        if not isinstance(content, str):
            return _fail("Invalid message")
        if len(content) > MAX_MESSAGE_LENGTH:
            return _fail("Message too long")
        clean.append({"role": m["role"], "content": content})

    if lang not in ("en", "ar"):
        lang = "en"
    # A JSON list/object is unhashable; a dict membership test on it would
    # raise TypeError, so non-strings fall back to the default model.
    if not isinstance(model, str) or model not in AVAILABLE_MODELS:
        model = DEFAULT_MODEL
    return clean, lang, model, None
def _last_user_msg(messages: List[Dict]) -> str:
return next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
def _trim(messages: List[Dict]) -> List[Dict]:
    """Keep only the newest ``MAX_HISTORY_MESSAGES`` messages.

    When the list is already short enough, the same list object is returned
    (no copy is made).

    NOTE(review): if MAX_HISTORY_MESSAGES were 0, the slice ``[-0:]`` equals
    ``[0:]`` and would keep every message rather than none.
    """
    if len(messages) <= MAX_HISTORY_MESSAGES:
        return messages
    return messages[-MAX_HISTORY_MESSAGES:]
# ── Routes ─────────────────────────────────────────────────────────────────
@app.route("/")
def index():
    """Render the chat page, embedding the model catalogue and default model as JSON strings."""
    return render_template(
        "index.html",
        models_json=json.dumps(AVAILABLE_MODELS),
        default_model=json.dumps(DEFAULT_MODEL),
    )
def _safe_log(label: str, text: str) -> None:
"""Print without crashing on Windows cp1252 β€” strips non-ASCII."""
safe = text.encode("ascii", "replace").decode("ascii")
print(f" [{label}] {safe}")
@app.route("/api/chat", methods=["POST"])
@limiter.limit("12 per minute; 60 per hour; 200 per day")
def chat():
    """Non-streaming chat endpoint.

    Validates the request and retrieves context for the latest user message.
    It then asks the selected model for one completion and returns
    ``{"answer": ..., "sources": ...}``. Validation failures come back as the
    400 response built by ``_parse_request``. Any failure in retrieval or in
    the LLM call is logged and reported as ``{"error": "LLM error"}`` with
    HTTP 500.
    """
    messages, lang, model, err = _parse_request(request.get_json(silent=True))
    if err:
        return err
    messages = _trim(messages)
    query = _last_user_msg(messages)
    try:
        # Retrieval sits inside the try (as in chat_stream) so a RAG failure
        # yields a JSON error instead of an unhandled exception.
        results = hybrid_retrieve(query, lang)
        system = build_system(lang, build_context(results))
        client = _get_client(model)
        history = [{"role": "system", "content": system}] + messages
        resp = client.chat.completions.create(
            model=model, messages=history,
            max_tokens=MAX_RESPONSE_TOKENS, temperature=TEMPERATURE,
        )
        answer = (resp.choices[0].message.content or "").strip()
    except Exception as e:
        # Exception text may contain non-ASCII (e.g. Arabic queries). A plain
        # print could raise UnicodeEncodeError on a cp1252 console and turn
        # this handled error into an unhandled one, so strip it first.
        _safe_log("ERROR", str(e))
        return jsonify({"error": "LLM error"}), 500
    return jsonify({"answer": answer, "sources": format_sources(results)})
@app.route("/api/chat/stream", methods=["POST"])
@limiter.limit("12 per minute; 60 per hour; 200 per day")
def chat_stream():
    """Streaming chat endpoint using Server-Sent Events.

    Validation errors are returned as the 400 JSON response from
    ``_parse_request`` before streaming starts. Otherwise the response emits,
    in order:
    - one ``{"type": "sources"}`` event;
    - zero or more ``{"type": "token"}`` events carrying content deltas;
    - an ``{"type": "error"}`` event if anything fails;
    - a final ``{"type": "done"}`` event.
    """
    messages, lang, model, err = _parse_request(request.get_json(silent=True))
    if err:
        return err
    messages = _trim(messages)
    query = _last_user_msg(messages)

    def generate():
        try:
            results = hybrid_retrieve(query, lang)
            context = build_context(results)
            system = build_system(lang, context)
            yield f"data: {json.dumps({'type': 'sources', 'sources': format_sources(results)})}\n\n"
            client = _get_client(model)
            history = [{"role": "system", "content": system}] + messages
            # Single streaming call - no tool probe, no double round-trip
            stream = client.chat.completions.create(
                model=model, messages=history,
                max_tokens=MAX_RESPONSE_TOKENS, temperature=TEMPERATURE,
                stream=True,
            )
            try:
                for chunk in stream:
                    # Some providers emit chunks with an empty `choices` list
                    # (e.g. usage or keep-alive chunks). Skip them instead of
                    # raising IndexError, which would abort a valid answer.
                    if not chunk.choices:
                        continue
                    delta = chunk.choices[0].delta
                    if delta and delta.content:
                        yield f"data: {json.dumps({'type': 'token', 'content': delta.content})}\n\n"
            finally:
                # Release the upstream HTTP response even when the client
                # disconnects mid-stream (GeneratorExit raised at a yield).
                close = getattr(stream, "close", None)
                if callable(close):
                    close()
        except Exception as exc:
            err_msg = str(exc).encode("ascii", "replace").decode("ascii")
            print(f"[ERROR] Stream: {err_msg}")
            yield f"data: {json.dumps({'type': 'error', 'message': 'Something went wrong. Please try again.'})}\n\n"
        yield f"data: {json.dumps({'type': 'done'})}\n\n"

    return Response(
        stream_with_context(generate()),
        content_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
@app.route("/api/health")
def health():
    """Report liveness, RAG component flags (False if unknown) and the configured models."""
    payload = {
        "status": "ok",
        **{part: _rag_status.get(part, False) for part in ("rag", "bm25", "reranker")},
        "default_model": DEFAULT_MODEL,
        "available_models": list(AVAILABLE_MODELS),
    }
    return jsonify(payload)
# ── Entry point ────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # NOTE(review): init() runs only when this file is executed directly.
    # A server that imports `app` instead (e.g. a WSGI runner) would skip it
    # and leave the LLM clients unset; such an entry point must call init().
    print("=" * 56)
    # Keep the banner ASCII-only. Non-ASCII characters such as the previously
    # mis-decoded dash cannot be printed by a Windows cp1252 console and
    # would raise UnicodeEncodeError at startup.
    print(" KASITBot v6 - Hybrid RAG + Streaming")
    print("=" * 56)
    init()
    print(" http://localhost:5000\n")
    app.run(debug=os.environ.get("FLASK_DEBUG", "0") == "1",
            host="0.0.0.0", port=5000)