File size: 7,109 Bytes
71e1c4b
 
 
 
 
 
 
 
 
 
 
 
 
 
fe965f7
71e1c4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c26a997
71e1c4b
 
 
 
 
 
 
 
c26a997
71e1c4b
 
 
 
c26a997
 
71e1c4b
 
 
 
c26a997
71e1c4b
 
 
 
 
 
c26a997
 
71e1c4b
 
 
 
 
c26a997
71e1c4b
 
 
c26a997
 
71e1c4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c26a997
 
 
71e1c4b
 
c26a997
71e1c4b
 
 
c26a997
 
 
 
 
71e1c4b
 
 
 
 
 
 
 
 
 
c26a997
 
 
 
71e1c4b
 
c26a997
fe965f7
c26a997
 
 
 
 
fe965f7
 
71e1c4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c26a997
71e1c4b
 
 
 
 
c26a997
71e1c4b
 
 
c26a997
71e1c4b
 
 
c26a997
 
 
 
71e1c4b
 
 
 
 
 
 
 
c26a997
 
71e1c4b
 
 
 
 
 
 
 
 
 
 
 
 
 
c26a997
 
 
 
 
71e1c4b
 
 
 
c26a997
71e1c4b
 
 
c26a997
71e1c4b
 
 
c26a997
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import json
import os
from typing import Dict, List

from flask import Flask, jsonify, render_template, request, Response, stream_with_context
from flask_cors import CORS
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from openai import OpenAI

from config import (
    OPENAI_API_KEY, OPENROUTER_API_KEY, OPENROUTER_BASE_URL,
    AVAILABLE_MODELS, DEFAULT_MODEL,
    MAX_MESSAGE_LENGTH, MAX_MESSAGES_PER_REQ, MAX_HISTORY_MESSAGES,
    TEMPERATURE, MAX_RESPONSE_TOKENS,
)
from rag_pipeline import (
    init_rag, hybrid_retrieve, build_context, build_system, format_sources,
    RERANKER_AVAILABLE,
)

app = Flask(__name__)
# NOTE(review): CORS is enabled with no options, i.e. flask-cors' defaults
# (which allow any origin) — confirm that is intended for this API.
CORS(app)

# Rate limiter keyed on the caller's remote address. No global default limits
# are set; routes opt in with @limiter.limit(...). Counters are kept in this
# process's memory ("memory://"), so they are not shared between processes.
limiter = Limiter(
    key_func=get_remote_address,
    app=app,
    storage_uri="memory://",
    default_limits=[],
)

# Module-level state filled in by init(): the two API clients stay None and
# the RAG status dict stays empty until init() has run.
_openai_client = None
_or_client = None
_rag_status: Dict = {}


def _get_client(model_id: str) -> OpenAI:
    """Return the API client that serves *model_id*.

    Models whose AVAILABLE_MODELS entry has ``"provider": "openrouter"`` use the
    OpenRouter client; every other model (including ids that are not in
    AVAILABLE_MODELS at all) uses the OpenAI client.

    Raises:
        RuntimeError: if the selected client has not been created by init().
    """
    info = AVAILABLE_MODELS.get(model_id, {})
    if info.get("provider") == "openrouter":
        client, label = _or_client, "OpenRouter"
    else:
        client, label = _openai_client, "OpenAI"
    # The clients are module globals set by init(). Without this check a
    # missing init() surfaces later as an opaque AttributeError on None.
    if client is None:
        raise RuntimeError(f"{label} client not initialised; call init() first")
    return client


# ── Startup ───────────────────────────────────────────────────────────────

def init():
    """Create the OpenAI and OpenRouter clients and initialise the RAG pipeline.

    Side effects: sets the module globals ``_openai_client``, ``_or_client``
    and ``_rag_status`` (the value returned by ``init_rag``), then prints a
    one-line summary of the rag / bm25 / reranker flags.
    """
    global _openai_client, _or_client, _rag_status

    _openai_client = OpenAI(api_key=OPENAI_API_KEY)
    _or_client = OpenAI(base_url=OPENROUTER_BASE_URL, api_key=OPENROUTER_API_KEY)

    # The RAG pipeline is initialised with the OpenAI client, not OpenRouter.
    _rag_status = init_rag(_openai_client)

    flags = {name: _rag_status.get(name) for name in ("rag", "bm25", "reranker")}
    print(f"  RAG={flags['rag']}, BM25={flags['bm25']}, Reranker={flags['reranker']}")


# ── Request validation ────────────────────────────────────────────────────

def _parse_request(data):
    """Validate and normalise a chat request body.

    Args:
        data: the decoded JSON body (``None`` when the body was not valid JSON).

    Returns:
        ``(messages, lang, model, None)`` on success, where *messages* is a new
        list of ``{"role": ..., "content": ...}`` dicts; or
        ``(None, None, None, error)`` where *error* is a ``(response, 400)``
        tuple that a view can return directly. An unsupported *lang* falls back
        to ``"en"`` and an unknown *model* to DEFAULT_MODEL instead of erroring.
    """
    def fail(msg):
        # Build the 4-tuple error result; jsonify is only called on failure.
        return None, None, None, (jsonify({"error": msg}), 400)

    if not data or not isinstance(data, dict):
        return fail("Invalid JSON")

    messages = data.get("messages", [])
    lang = data.get("lang", "en")
    model = data.get("model", DEFAULT_MODEL)

    if not messages:
        return fail("No messages")
    # Check the type first so a non-list payload is not misreported as "too many".
    if not isinstance(messages, list):
        return fail("Invalid messages")
    if len(messages) > MAX_MESSAGES_PER_REQ:
        return fail("Too many messages")

    clean = []
    for m in messages:
        if not isinstance(m, dict) or m.get("role") not in ("user", "assistant"):
            return fail("Invalid message")
        content = m.get("content", "")
        if not isinstance(content, str):
            return fail("Invalid message")
        if len(content) > MAX_MESSAGE_LENGTH:
            return fail("Message too long")
        # Forward only role/content. Any extra client-supplied keys would
        # otherwise reach the LLM API verbatim. A missing "content" becomes an
        # explicit "" so downstream m["content"] lookups cannot raise KeyError.
        clean.append({"role": m["role"], "content": content})

    if lang not in ("en", "ar"):
        lang = "en"
    # The isinstance guard matters: an unhashable value (e.g. a JSON list) would
    # make the dict membership test raise TypeError and turn a bad request into a 500.
    if not isinstance(model, str) or model not in AVAILABLE_MODELS:
        model = DEFAULT_MODEL

    return clean, lang, model, None


def _last_user_msg(messages: List[Dict]) -> str:
    return next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")


def _trim(messages: List[Dict]) -> List[Dict]:
    """Limit the conversation to its newest MAX_HISTORY_MESSAGES entries.

    Returns the very same list object when it is already within the limit,
    otherwise a new list holding the tail slice.
    """
    if len(messages) <= MAX_HISTORY_MESSAGES:
        return messages
    # NOTE(review): with a limit of 0 this is messages[-0:], i.e. a full copy.
    return messages[-MAX_HISTORY_MESSAGES:]


# ── Routes ────────────────────────────────────────────────────────────────

@app.route("/")
def index():
    """Render the chat page with the model catalogue and default model as JSON text."""
    template_vars = {
        "models_json": json.dumps(AVAILABLE_MODELS),
        "default_model": json.dumps(DEFAULT_MODEL),
    }
    return render_template("index.html", **template_vars)


@app.route("/api/chat", methods=["POST"])
@limiter.limit("12 per minute; 60 per hour; 200 per day")
def chat():
    """Answer a chat request with a single (non-streaming) JSON response.

    Request body: ``{"messages": [...], "lang": ..., "model": ...}`` as validated
    by ``_parse_request``.

    Responses:
        200 ``{"answer": str, "sources": ...}`` on success.
        400 JSON error for invalid input.
        500 JSON error when retrieval or the LLM call fails.
    """
    messages, lang, model, err = _parse_request(request.get_json(silent=True))
    if err:
        return err

    messages = _trim(messages)
    query = _last_user_msg(messages)

    # Retrieval is guarded as well, matching chat_stream. A RAG failure then
    # yields a JSON error instead of Flask's default HTML 500 page.
    try:
        results = hybrid_retrieve(query, lang)
        context = build_context(results)
        system = build_system(lang, context)
        sources = format_sources(results)
    except Exception as exc:
        safe_msg = str(exc).encode("ascii", "replace").decode("ascii")
        print(f"[ERROR] Retrieval: {safe_msg}")
        return jsonify({"error": "Retrieval error"}), 500

    try:
        client = _get_client(model)
        history = [{"role": "system", "content": system}] + messages
        resp = client.chat.completions.create(
            model=model,
            messages=history,
            max_tokens=MAX_RESPONSE_TOKENS,
            temperature=TEMPERATURE,
        )
        answer = (resp.choices[0].message.content or "").strip()
    except Exception as exc:
        # ASCII-safe like chat_stream. Printing a non-ASCII message to a console
        # with a narrow encoding may itself raise inside this handler.
        safe_msg = str(exc).encode("ascii", "replace").decode("ascii")
        print(f"[ERROR] {safe_msg}")
        return jsonify({"error": "LLM error"}), 500

    return jsonify({"answer": answer, "sources": sources})


@app.route("/api/chat/stream", methods=["POST"])
@limiter.limit("12 per minute; 60 per hour; 200 per day")
def chat_stream():
    """Answer a chat request as a Server-Sent Events stream.

    Each event is a ``data: <json>`` line followed by a blank line. Events, in order:
      - ``{"type": "sources", ...}`` once retrieval succeeds;
      - zero or more ``{"type": "token", "content": ...}``;
      - ``{"type": "error", ...}`` if anything fails;
      - ``{"type": "done"}`` at the end, unless the client disconnected first.

    Invalid input is rejected up front with a 400 JSON error instead of a stream.
    """
    messages, lang, model, err = _parse_request(request.get_json(silent=True))
    if err:
        return err

    messages = _trim(messages)
    query = _last_user_msg(messages)

    def sse(payload):
        # Serialise one SSE event.
        return f"data: {json.dumps(payload)}\n\n"

    def generate():
        stream = None
        try:
            results = hybrid_retrieve(query, lang)
            context = build_context(results)
            system = build_system(lang, context)

            yield sse({"type": "sources", "sources": format_sources(results)})

            client = _get_client(model)
            history = [{"role": "system", "content": system}] + messages

            stream = client.chat.completions.create(
                model=model,
                messages=history,
                max_tokens=MAX_RESPONSE_TOKENS,
                temperature=TEMPERATURE,
                stream=True,
            )
            for chunk in stream:
                # Providers may send chunks with an empty choices list (e.g.
                # usage-only or metadata chunks). Indexing [0] on those would
                # raise and abort an otherwise healthy stream.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta and delta.content:
                    yield sse({"type": "token", "content": delta.content})

        except Exception as exc:
            safe_msg = str(exc).encode("ascii", "replace").decode("ascii")
            print(f"[ERROR] Stream: {safe_msg}")
            yield sse({"type": "error", "message": "Something went wrong. Please try again."})
        finally:
            # Release the upstream HTTP connection promptly. This also covers a
            # client disconnect, where GeneratorExit skips the rest of the loop.
            close = getattr(stream, "close", None)
            if close is not None:
                close()

        yield sse({"type": "done"})

    return Response(
        stream_with_context(generate()),
        content_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


@app.route("/api/health")
def health():
    """Report liveness, RAG component flags and the configured models.

    The rag / bm25 / reranker flags read False until init() has populated
    the RAG status.
    """
    report = {"status": "ok"}
    for component in ("rag", "bm25", "reranker"):
        report[component] = _rag_status.get(component, False)
    report["default_model"] = DEFAULT_MODEL
    report["available_models"] = list(AVAILABLE_MODELS)
    return jsonify(report)


# ── Entry point ───────────────────────────────────────────────────────────

if __name__ == "__main__":
    print("=" * 56)
    # ASCII-only banner (the old text held a mis-decoded em dash, "β€”"), so it
    # prints correctly whatever the console encoding is.
    print("  KASITBot v7 - Hybrid RAG + Streaming")
    print("=" * 56)
    # NOTE: init() only runs under this entry point. A WSGI server that imports
    # this module must call init() itself, or the API clients stay None.
    init()
    print("  http://localhost:5000\n")
    # NOTE(review): FLASK_DEBUG=1 together with host="0.0.0.0" exposes the
    # Werkzeug debugger on every interface — keep debug off outside local dev.
    app.run(
        debug=os.environ.get("FLASK_DEBUG", "0") == "1",
        host="0.0.0.0",
        port=5000,
    )