Spaces:
Running
Running
| """ | |
KASITBot — Flask application entry point.
| Routes only. All config in config.py, all RAG logic in rag_pipeline.py. | |
| """ | |
| import json | |
| import os | |
| from typing import Dict, List | |
| from flask import Flask, jsonify, render_template, request, Response, stream_with_context | |
| from flask_cors import CORS | |
| from flask_limiter import Limiter | |
| from flask_limiter.util import get_remote_address | |
| from openai import OpenAI | |
| from config import ( | |
| OPENAI_API_KEY, OPENROUTER_API_KEY, OPENROUTER_BASE_URL, | |
| AVAILABLE_MODELS, DEFAULT_MODEL, | |
| MAX_MESSAGE_LENGTH, MAX_MESSAGES_PER_REQ, MAX_HISTORY_MESSAGES, | |
| TEMPERATURE, MAX_RESPONSE_TOKENS, | |
| ) | |
| from rag_pipeline import ( | |
| init_rag, hybrid_retrieve, build_context, build_system, format_sources, | |
| RERANKER_AVAILABLE, | |
| ) | |
# Flask application object used by the views and the __main__ entry point below.
app = Flask(__name__)
# flask-cors is applied with its default settings (no origin restrictions are
# configured here).
CORS(app)

# Rate limiter keyed by the caller's remote address. Counters live in process
# memory ("memory://") and no default limits are configured.
# NOTE(review): `limiter` is not referenced anywhere else in the visible
# source — confirm per-route limits are applied where intended.
limiter = Limiter(
    key_func=get_remote_address,
    app=app,
    storage_uri="memory://",
    default_limits=[],
)

# Module-level state; all three names are assigned by init().
_openai_client = _or_client = None
_rag_status: Dict = {}
def _get_client(model_id: str) -> OpenAI:
    """Pick the OpenAI-compatible client that serves ``model_id``.

    Looks the model up in ``AVAILABLE_MODELS``; when its ``provider`` entry is
    ``"openrouter"`` the OpenRouter client is returned, otherwise (including
    unknown model ids) the OpenAI client is returned.

    Both clients are module globals that stay ``None`` until ``init()`` runs.
    """
    provider = AVAILABLE_MODELS.get(model_id, {}).get("provider")
    if provider == "openrouter":
        return _or_client
    return _openai_client
# ── Startup ────────────────────────────────────────────────────────────────
def init():
    """Create the API clients and initialise the RAG pipeline.

    Assigns the module globals ``_openai_client``, ``_or_client`` and
    ``_rag_status`` (the value returned by ``init_rag``), then prints a
    one-line summary of the ``rag``, ``bm25`` and ``reranker`` status entries.
    """
    global _openai_client, _or_client, _rag_status

    _openai_client = OpenAI(api_key=OPENAI_API_KEY)
    _or_client = OpenAI(base_url=OPENROUTER_BASE_URL, api_key=OPENROUTER_API_KEY)
    # init_rag is handed the OpenAI client, not the OpenRouter one.
    _rag_status = init_rag(_openai_client)

    rag_flag = _rag_status.get("rag")
    bm25_flag = _rag_status.get("bm25")
    reranker_flag = _rag_status.get("reranker")
    print(f" RAG={rag_flag}, BM25={bm25_flag}, Reranker={reranker_flag}")
# ── Request validation ─────────────────────────────────────────────────────
def _parse_request(data):
    """Validate a decoded chat request body.

    Args:
        data: The decoded JSON payload (expected to be a dict with a
            ``messages`` list and optional ``lang`` / ``model`` keys).

    Returns:
        A 4-tuple ``(messages, lang, model, error)``. On success ``error`` is
        ``None``. On failure the first three items are ``None`` and ``error``
        is a ``(json_response, 400)`` pair that a view can return directly.

    Unknown ``lang`` values fall back to ``"en"`` and unknown ``model`` values
    fall back to ``DEFAULT_MODEL`` instead of being rejected.
    """
    def _reject(reason):
        # Uniform shape for every validation failure.
        return None, None, None, (jsonify({"error": reason}), 400)

    if not data or not isinstance(data, dict):
        return _reject("Invalid JSON")

    messages = data.get("messages", [])
    lang = data.get("lang", "en")
    model = data.get("model", DEFAULT_MODEL)

    if not messages:
        return _reject("No messages")
    if not isinstance(messages, list):
        # A truthy non-list (string, dict, number) is malformed, not "too many".
        return _reject("Invalid message")
    if len(messages) > MAX_MESSAGES_PER_REQ:
        return _reject("Too many messages")

    for m in messages:
        if not isinstance(m, dict) or m.get("role") not in ("user", "assistant"):
            return _reject("Invalid message")
        # Fetch once with .get(): a message without a "content" key must be
        # rejected here rather than raising KeyError on m["content"].
        content = m.get("content")
        if not isinstance(content, str):
            return _reject("Invalid message")
        if len(content) > MAX_MESSAGE_LENGTH:
            return _reject("Message too long")

    if lang not in ("en", "ar"):
        lang = "en"
    # The isinstance guard keeps unhashable JSON values (lists, objects) from
    # raising TypeError in the membership test against AVAILABLE_MODELS.
    if not isinstance(model, str) or model not in AVAILABLE_MODELS:
        model = DEFAULT_MODEL
    return messages, lang, model, None
def _last_user_msg(messages: List[Dict]) -> str:
    """Return the ``content`` of the most recent ``user`` message.

    Scans ``messages`` from the end; returns ``""`` when no message has the
    role ``"user"``.
    """
    for msg in reversed(messages):
        if msg["role"] == "user":
            return msg["content"]
    return ""
def _trim(messages: List[Dict]) -> List[Dict]:
    """Limit the conversation to the newest ``MAX_HISTORY_MESSAGES`` entries.

    A list that is already within the limit is returned as-is (same object);
    otherwise a new list holding only the tail of the conversation is returned.
    """
    if len(messages) <= MAX_HISTORY_MESSAGES:
        return messages
    return messages[-MAX_HISTORY_MESSAGES:]
# ── Routes ─────────────────────────────────────────────────────────────────
def index():
    """Render ``index.html`` with the model catalogue serialised as JSON.

    The template receives ``models_json`` (``AVAILABLE_MODELS``) and
    ``default_model`` (``DEFAULT_MODEL``), both already JSON-encoded strings.
    """
    # NOTE(review): no route decorator is visible on this view in the source
    # under review — confirm how it is registered with `app`.
    template_vars = {
        "models_json": json.dumps(AVAILABLE_MODELS),
        "default_model": json.dumps(DEFAULT_MODEL),
    }
    return render_template("index.html", **template_vars)
def _safe_log(label: str, text: str) -> None:
    """Print ``text`` tagged with ``label`` using ASCII characters only.

    Every non-ASCII character in ``text`` is replaced with ``?`` so the print
    cannot fail on consoles with a limited encoding (e.g. Windows cp1252).
    """
    ascii_bytes = text.encode("ascii", errors="replace")
    ascii_text = ascii_bytes.decode("ascii")
    print(f" [{label}] {ascii_text}")
def chat():
    """Answer a chat request in a single (non-streaming) JSON response.

    Validates the JSON body, trims the history, retrieves context for the last
    user message, and asks the selected model for a completion.

    Returns:
        ``{"answer": ..., "sources": ...}`` on success; the validation error
        pair from ``_parse_request`` on bad input; or
        ``({"error": "LLM error"}, 500)`` when retrieval or the model call fails.
    """
    # NOTE(review): no route decorator is visible on this view in the source
    # under review — confirm how it is registered with `app`.
    messages, lang, model, err = _parse_request(request.get_json(silent=True))
    if err:
        return err
    messages = _trim(messages)
    query = _last_user_msg(messages)
    try:
        # Retrieval and prompt building are inside the try (as in chat_stream)
        # so a failure there also yields the JSON error below rather than an
        # unhandled exception.
        results = hybrid_retrieve(query, lang)
        context = build_context(results)
        system = build_system(lang, context)
        client = _get_client(model)
        history = [{"role": "system", "content": system}] + messages
        resp = client.chat.completions.create(
            model=model, messages=history,
            max_tokens=MAX_RESPONSE_TOKENS, temperature=TEMPERATURE,
        )
        # message.content may be None; normalise to an empty string.
        answer = (resp.choices[0].message.content or "").strip()
        sources = format_sources(results)
    except Exception as exc:
        # ASCII-only log text: the error message may contain characters the
        # console encoding cannot print, which would raise inside the handler.
        err_msg = str(exc).encode("ascii", "replace").decode("ascii")
        print(f"[ERROR] {err_msg}")
        return jsonify({"error": "LLM error"}), 500
    return jsonify({"answer": answer, "sources": sources})
def chat_stream():
    """Answer a chat request as a Server-Sent Events stream.

    Each event is a ``data: <json>`` line followed by a blank line. Events are
    emitted in this order: one ``sources`` event, zero or more ``token``
    events, an ``error`` event if anything failed, and always a final ``done``
    event.

    Returns:
        A streaming ``text/event-stream`` response, or the validation error
        pair from ``_parse_request`` on bad input.
    """
    # NOTE(review): no route decorator is visible on this view in the source
    # under review — confirm how it is registered with `app`.
    messages, lang, model, err = _parse_request(request.get_json(silent=True))
    if err:
        return err
    messages = _trim(messages)
    query = _last_user_msg(messages)

    def _sse(payload):
        # Serialise one event in SSE wire format.
        return f"data: {json.dumps(payload)}\n\n"

    def generate():
        stream = None
        try:
            results = hybrid_retrieve(query, lang)
            context = build_context(results)
            system = build_system(lang, context)
            # Sources go out first so the client has them before any token.
            yield _sse({"type": "sources", "sources": format_sources(results)})
            client = _get_client(model)
            history = [{"role": "system", "content": system}] + messages
            # Single streaming call — no tool probe, no double round-trip
            stream = client.chat.completions.create(
                model=model, messages=history,
                max_tokens=MAX_RESPONSE_TOKENS, temperature=TEMPERATURE,
                stream=True,
            )
            for chunk in stream:
                # Some chunks (e.g. usage-only or provider metadata chunks)
                # carry an empty `choices` list; skip them instead of failing
                # on choices[0].
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta and delta.content:
                    yield _sse({"type": "token", "content": delta.content})
        except Exception as exc:
            err_msg = str(exc).encode("ascii", "replace").decode("ascii")
            print(f"[ERROR] Stream: {err_msg}")
            yield _sse({"type": "error", "message": "Something went wrong. Please try again."})
        finally:
            # Also runs when the client disconnects mid-stream (the generator
            # is closed at a yield), releasing the upstream connection.
            close = getattr(stream, "close", None)
            if callable(close):
                close()
        yield _sse({"type": "done"})

    return Response(
        stream_with_context(generate()),
        content_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
def health():
    """Report service status as JSON.

    The payload carries a fixed ``"status": "ok"``, the ``rag`` / ``bm25`` /
    ``reranker`` entries of ``_rag_status`` (``False`` when absent), the
    default model id, and the list of available model ids.
    """
    # NOTE(review): no route decorator is visible on this view in the source
    # under review — confirm how it is registered with `app`.
    payload = {"status": "ok"}
    for component in ("rag", "bm25", "reranker"):
        payload[component] = _rag_status.get(component, False)
    payload["default_model"] = DEFAULT_MODEL
    payload["available_models"] = list(AVAILABLE_MODELS.keys())
    return jsonify(payload)
# ── Entry point ────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Startup banner (the garbled separator character in the title is
    # restored to an em dash).
    print("=" * 56)
    print(" KASITBot v6 — Hybrid RAG + Streaming")
    print("=" * 56)
    # Build the API clients and the RAG pipeline before serving requests.
    init()
    print(" http://localhost:5000\n")
    # Debug mode is opt-in via FLASK_DEBUG=1.
    # NOTE(review): host="0.0.0.0" listens on all interfaces; with debug
    # enabled the Werkzeug interactive debugger would be reachable from other
    # machines, so keep FLASK_DEBUG off outside local development.
    app.run(debug=os.environ.get("FLASK_DEBUG", "0") == "1",
            host="0.0.0.0", port=5000)