File size: 13,599 Bytes
3b28871
 
 
 
9995f1a
3b28871
 
 
9995f1a
3b28871
 
9995f1a
3b28871
9995f1a
3b28871
 
 
 
9995f1a
 
3b28871
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bb9213
 
 
 
 
 
 
5f0d937
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b28871
 
 
 
 
 
 
 
 
 
 
3bb9213
 
 
 
 
 
3b28871
 
5f0d937
 
 
 
3b28871
 
 
 
 
 
 
3bb9213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b28871
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bb9213
 
 
 
 
 
 
3b28871
 
 
 
 
3bb9213
 
 
 
 
 
 
3b28871
 
3bb9213
 
 
 
 
 
3b28871
 
 
 
 
3bb9213
 
 
 
 
 
 
 
 
 
 
 
 
 
3b28871
 
 
 
 
 
 
3bb9213
4330ca1
3b28871
 
 
 
 
3bb9213
3b28871
 
 
3bb9213
 
 
 
3b28871
 
 
 
 
 
 
 
 
 
3bb9213
3b28871
3bb9213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4330ca1
3bb9213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9995f1a
 
5f0d937
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9995f1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import AsyncIterator

import httpx
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from starlette.responses import FileResponse

from app.config import settings

# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)

# "static" directory that sits next to the directory containing this file.
STATIC_DIR = Path(__file__).resolve().parent.parent / "static"
# Configure root logging at INFO level as an import-time side effect.
logging.basicConfig(level=logging.INFO)

# The ASGI application object; every route below is registered on it.
app = FastAPI(title="Ask Jerry API")

# CORS: the allowed origins come from settings; every method and header is
# allowed and credentialed requests are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origin_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Base system prompt; _build_system_prompt returns it as-is or with the
# caller's extra persona text appended.
JERRY_SYSTEM_BASE = (
    "You are AI Jerry, a cybersecurity-focused assistant running on the BrainForge Security model. "
    "You give clear, practical guidance; you distinguish facts from speculation; you flag risks and "
    "compliance considerations when relevant. You are friendly and professional."
)

# Token window used for reply budgeting in chat_stream / chat_estimate and
# reported by /health.
# NOTE(review): hard-coded here; presumably must match the served model — confirm.
MODEL_CONTEXT_WINDOW = 8192
# System prompt sent by chat_summarize together with the flattened transcript.
SUMMARIZE_SYSTEM = (
    "You are a concise summarizer. Condense the following conversation into a short summary "
    "that preserves the key topics discussed, any conclusions reached, important facts shared, "
    "and the overall tone. Keep it under 300 words. Write in third person narrative form."
)

# search_references truncates the incoming statement to this many characters
# before sending it to the model.
_STATEMENT_MAX_CHARS = 4000

# System prompt used by search_references: it instructs the model to reply
# with only a web search query (or two queries separated by "; ").
SEARCH_REF_SYSTEM = """You help a user find web sources to **research** an AI assistant's cybersecurity answer.

You will receive the full text of that answer. Do this in order:

1. **Facts** — Identify the main factual claims (CVEs, standards, protocols, vendor names, regulations, definitions, and procedural steps).
2. **Meaning** — In a few words, capture the overall gist: what the answer is explaining or recommending.
3. **Search query** — Compose one concise web search query (or two short queries separated by `; `) optimized to find **authoritative references** that could verify or deepen those facts—e.g. NIST, CISA, vendor docs, RFCs, CWE/CVE pages, or reputable security guidance.

Rules:
- Prefer concrete, verifiable keywords from the text.
- The query should help someone **research** the topic, not merely restate the answer in different words.
- Do not include meta-commentary, labels, bullets, or step numbers in your output.
- Return **only** the search query string (or two `; `-separated queries), with no quotes around the whole thing and no preamble."""


# One conversation turn as sent by the client.
class ChatMessage(BaseModel):
    # Speaker role; _build_api_messages forwards only "user" and "assistant" turns.
    role: str
    # Message text; turns whose text is blank after stripping are not forwarded.
    content: str


# Request body shared by /api/chat/stream and /api/chat/estimate.
class ChatStreamBody(BaseModel):
    # Conversation turns; at least one is required.
    messages: list[ChatMessage] = Field(..., min_length=1)
    # Optional extra instructions appended to the base system prompt.
    extra_persona: str = ""
    # Sampling temperature; the handler falls back to settings.temperature when None.
    temperature: float | None = None
    # Requested reply budget; the handlers fall back to settings.max_tokens when None.
    max_tokens: int | None = None
    # Optional summary of earlier turns, injected as an additional system message.
    summary: str | None = None


# Request body for /api/chat/summarize.
class SummarizeBody(BaseModel):
    # Conversation turns to condense; at least one is required.
    messages: list[ChatMessage] = Field(..., min_length=1)
    # NOTE(review): accepted but not read by chat_summarize in this file.
    extra_persona: str = ""


# Request body for /api/search-references.
class SearchRefBody(BaseModel):
    # Assistant answer to derive a search query from; the handler truncates it
    # to _STATEMENT_MAX_CHARS characters and returns an empty query when it is blank.
    statement: str = ""


def _build_system_prompt(extra_persona: str) -> str:
    """Return the system prompt, optionally extended with caller-supplied instructions.

    ``extra_persona`` is stripped first. A falsy or whitespace-only value
    yields the base prompt unchanged; otherwise the stripped text is appended
    under an "Additional instructions from the user:" heading.
    """
    persona = extra_persona.strip() if extra_persona else ""
    if persona:
        return f"{JERRY_SYSTEM_BASE}\n\nAdditional instructions from the user:\n{persona}"
    return JERRY_SYSTEM_BASE


def _estimate_tokens(text: str) -> int:
    """Roughly estimate a token count as one token per four characters, never below 1."""
    approx = len(text) // 4
    return approx if approx > 1 else 1


def _estimate_messages_tokens(msgs: list[dict]) -> int:
    """Sum the estimated token cost of a list of message dicts.

    Each message contributes the estimate for its ``"content"`` value (an
    empty string when the key is absent) plus a flat 4 — presumably a
    per-message framing overhead.
    """
    return sum(_estimate_tokens(entry.get("content", "")) + 4 for entry in msgs)


def _build_api_messages(
    system: str,
    body_messages: list[ChatMessage],
    summary: str | None,
) -> list[dict]:
    """Assemble the message list sent to the chat-completions endpoint.

    The result always starts with the system prompt. A truthy ``summary`` adds
    a second system message carrying it. After that come the client's turns,
    in order, keeping only ``user`` / ``assistant`` roles whose content is not
    blank.
    """
    header: list[dict] = [{"role": "system", "content": system}]
    if summary:
        header.append(
            {"role": "system", "content": f"Summary of earlier conversation:\n{summary}"}
        )
    turns = [
        {"role": item.role, "content": item.content}
        for item in body_messages
        if item.role in ("user", "assistant") and item.content.strip()
    ]
    return header + turns


def _delta_text(delta: dict) -> str:
    """Extract the text carried by the ``"content"`` entry of a delta/message dict.

    A missing or ``None`` content gives ``""`` and a string is returned as-is.
    For a list, the string items and the string ``"text"`` values of dict
    items are concatenated in order, and other items are skipped. Any other
    value is converted with ``str()``.
    """
    content = delta.get("content")
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    fragments: list[str] = []
    for part in content:
        if isinstance(part, str):
            fragments.append(part)
            continue
        text = part.get("text") if isinstance(part, dict) else None
        if isinstance(text, str):
            fragments.append(text)
    return "".join(fragments)


def _is_context_overflow(error_text: str) -> bool:
    """Report whether an upstream error message looks like a context-size failure.

    The check is a case-insensitive substring match against a fixed set of
    phrases.
    """
    haystack = error_text.lower()
    for marker in ("context length", "max_tokens", "too large", "too many tokens"):
        if marker in haystack:
            return True
    return False


async def _yield_sse_tokens(line_iter: AsyncIterator[str]) -> AsyncIterator[str]:
    """Translate an upstream chat-completions SSE line stream into this API's SSE events.

    Args:
        line_iter: Async iterator over the raw lines of the upstream response.

    Yields:
        Ready-to-send ``data: <json>\\n\\n`` frames whose JSON ``type`` is one of:

        * ``token`` - a piece of generated text in ``content``;
        * ``error`` - the upstream chunk carried a truthy ``error`` (stringified
          in ``detail``); the stream ends immediately after it;
        * ``done``  - emitted once, on ``[DONE]`` or when the upstream ends.

    Lines that are empty, are not ``data`` fields, or do not decode to a JSON
    object with a non-empty ``choices`` list are skipped, so one malformed
    chunk cannot abort the stream.
    """
    done_event = f"data: {json.dumps({'type': 'done'})}\n\n"

    async for line in line_iter:
        # The SSE format makes the single space after "data:" optional, so accept both forms.
        if not line or not line.startswith("data:"):
            continue
        payload = line[5:].strip()
        if payload == "[DONE]":
            yield done_event
            return
        try:
            obj = json.loads(payload)
        except json.JSONDecodeError:
            continue
        if not isinstance(obj, dict):
            # Valid JSON that is not an object (bare number, list, ...) has nothing to extract.
            continue

        err = obj.get("error")
        if err:
            yield f"data: {json.dumps({'type': 'error', 'detail': str(err)})}\n\n"
            return

        choices = obj.get("choices")
        if not isinstance(choices, list) or not choices:
            continue
        ch0 = choices[0] if isinstance(choices[0], dict) else {}
        delta = ch0.get("delta") or {}
        if not isinstance(delta, dict):
            delta = {}
        piece = _delta_text(delta)
        # Fall back to a full "message" object when the chunk carries no delta text.
        if not piece and isinstance(ch0.get("message"), dict):
            piece = _delta_text(ch0["message"])
        if piece:
            yield f"data: {json.dumps({'type': 'token', 'content': piece})}\n\n"

    # Upstream ended without an explicit [DONE] marker.
    yield done_event


def _vllm_headers() -> dict[str, str]:
    """Build the HTTP headers for requests to the model server.

    The JSON content type is always set; a bearer ``Authorization`` header is
    added only when ``settings.vllm_api_key`` is truthy.
    """
    api_key = settings.vllm_api_key
    if not api_key:
        return {"Content-Type": "application/json"}
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }


@app.get("/health")
async def health():
    # Liveness probe: a fixed "ok" status plus the configured model id, the
    # context window constant and the default reply budget from settings.
    return dict(
        status="ok",
        model=settings.chat_model_id,
        context_window=MODEL_CONTEXT_WINDOW,
        max_tokens=settings.max_tokens,
    )


@app.post("/api/chat/stream")
async def chat_stream(body: ChatStreamBody):
    # Stream a chat completion to the client as server-sent events.
    #
    # The prompt is built from the system prompt (plus optional persona), the
    # optional summary and the client's turns. The reply budget is the
    # requested max_tokens, or settings.max_tokens when none is given.
    #
    # Events emitted (JSON after "data: "):
    #   token            - a piece of generated text
    #   done             - end of stream
    #   error            - upstream HTTP error or transport failure
    #   context_overflow - the prompt leaves no usable room for a reply, either
    #                      by local estimate or as reported by the model server
    system = _build_system_prompt(body.extra_persona)
    msgs = _build_api_messages(system, body.messages, body.summary)
    input_tokens = _estimate_messages_tokens(msgs)
    reply_budget = body.max_tokens if body.max_tokens is not None else settings.max_tokens

    min_reply_tokens = 256  # smallest reply budget worth sending upstream
    safety_margin = 64      # slack kept because the token estimate is approximate

    if input_tokens + reply_budget > MODEL_CONTEXT_WINDOW:
        # Shrink the budget to what still fits. The check must use the raw
        # remainder: clamping it to the minimum first would make the overflow
        # branch unreachable and always forward an oversized prompt.
        available = MODEL_CONTEXT_WINDOW - input_tokens - safety_margin
        if available < min_reply_tokens:
            detail = (
                f"Context too large: ~{input_tokens} input tokens with a {MODEL_CONTEXT_WINDOW} "
                f"token window leaves no room for a reply."
            )
            async def overflow_gen():
                yield f"data: {json.dumps({'type': 'context_overflow', 'detail': detail, 'input_tokens': input_tokens, 'context_window': MODEL_CONTEXT_WINDOW})}\n\n"
            return StreamingResponse(overflow_gen(), media_type="text/event-stream", headers={"Cache-Control": "no-cache"})
        reply_budget = available

    url = f"{settings.vllm_base_url.rstrip('/')}/chat/completions"
    req_body: dict = {
        "model": settings.chat_model_id,
        "messages": msgs,
        "stream": True,
        "temperature": body.temperature if body.temperature is not None else settings.temperature,
        "max_tokens": reply_budget,
        "stop": ["<|user|>", "<|end|>", "<|endoftext|>", "<|im_end|>", "</s>"],
    }

    async def event_gen():
        # Proxy the upstream stream, turning failures into SSE events rather
        # than exceptions, because the response headers are already sent by
        # the time this generator runs.
        try:
            async with httpx.AsyncClient(timeout=httpx.Timeout(120.0, connect=15.0)) as client:
                async with client.stream("POST", url, json=req_body, headers=_vllm_headers()) as resp:
                    if resp.status_code >= 400:
                        # Cap the error body so a huge upstream response is not relayed verbatim.
                        text = (await resp.aread()).decode("utf-8", errors="replace")[:2000]
                        LOG.warning("vLLM error %s: %s", resp.status_code, text)
                        if _is_context_overflow(text):
                            yield f"data: {json.dumps({'type': 'context_overflow', 'detail': text})}\n\n"
                        else:
                            yield f"data: {json.dumps({'type': 'error', 'detail': text or resp.reason_phrase})}\n\n"
                        return
                    async for chunk in _yield_sse_tokens(resp.aiter_lines()):
                        yield chunk
        except httpx.RequestError as e:
            LOG.exception("vLLM request failed")
            yield f"data: {json.dumps({'type': 'error', 'detail': str(e)})}\n\n"

    return StreamingResponse(
        event_gen(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


@app.post("/api/chat/summarize")
async def chat_summarize(body: SummarizeBody):
    # Ask the model for a short summary of the supplied conversation.
    #
    # Always returns a dict with "summary". On any failure (HTTP error, no
    # choices, exception) "summary" is empty and "error" describes the problem.

    # Flatten the turns into "<speaker>: <text>" lines; every role other than
    # "user" is labelled "AI Jerry".
    transcript = "\n".join(
        f"{'User' if entry.role == 'user' else 'AI Jerry'}: {entry.content}"
        for entry in body.messages
    )

    url = f"{settings.vllm_base_url.rstrip('/')}/chat/completions"
    req_body: dict = {
        "model": settings.chat_model_id,
        "messages": [
            {"role": "system", "content": SUMMARIZE_SYSTEM},
            {"role": "user", "content": transcript},
        ],
        "stream": False,
        "temperature": 0.3,
        "max_tokens": 1024,
        "stop": ["<|user|>", "<|end|>", "<|endoftext|>", "<|im_end|>", "</s>"],
    }

    try:
        timeout = httpx.Timeout(60.0, connect=15.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(url, json=req_body, headers=_vllm_headers())
            if resp.status_code >= 400:
                snippet = resp.text[:500]
                LOG.warning("Summarize error %s: %s", resp.status_code, snippet)
                return {"summary": "", "error": snippet}

            choices = resp.json().get("choices") or []
            if not choices:
                return {"summary": "", "error": "No choices returned"}

            message = choices[0].get("message") or {}
            content = message.get("content") or ""
            return {"summary": content.strip()}
    except Exception as exc:
        # Best-effort endpoint: report the failure in the body instead of raising.
        LOG.exception("Summarize request failed")
        return {"summary": "", "error": str(exc)}


@app.post("/api/chat/estimate")
async def chat_estimate(body: ChatStreamBody):
    # Report the estimated token usage for a prospective chat request without
    # calling the model. "headroom" is what remains of the context window
    # after the input and the reply budget; it can be negative.
    prompt = _build_system_prompt(body.extra_persona)
    api_msgs = _build_api_messages(prompt, body.messages, body.summary)
    used = _estimate_messages_tokens(api_msgs)

    if body.max_tokens is None:
        budget = settings.max_tokens
    else:
        budget = body.max_tokens

    remaining = MODEL_CONTEXT_WINDOW - used - budget
    return {
        "input_tokens": used,
        "reply_budget": budget,
        "context_window": MODEL_CONTEXT_WINDOW,
        "headroom": remaining,
    }


def _extract_message_content(data: dict) -> str:
    """Return the stripped text of the first choice's message, or ``""`` when absent.

    A missing or empty ``choices``, a missing or empty ``message``, and a
    missing or empty ``content`` all give an empty string.
    """
    choices = data.get("choices")
    if not choices:
        return ""
    message = choices[0].get("message")
    if not message:
        return ""
    content = message.get("content")
    return content.strip() if content else ""


@app.post("/api/search-references")
async def search_references(body: SearchRefBody):
    """Generate a web search query from an assistant answer (for Perplexity / copy)."""
    # Returns {"search_query": ...}. A blank statement gives an empty query.
    # When the model call fails or returns nothing, the first 100 characters
    # of the (truncated) statement are used as a fallback query.
    text = (body.statement or "")[:_STATEMENT_MAX_CHARS]
    if not text.strip():
        return {"search_query": ""}

    msgs = [
        {"role": "system", "content": SEARCH_REF_SYSTEM},
        {
            "role": "user",
            "content": f"Assistant answer to analyze for research search terms:\n\n{text}",
        },
    ]
    url = f"{settings.vllm_base_url.rstrip('/')}/chat/completions"
    req_body: dict = {
        "model": settings.chat_model_id,
        "messages": msgs,
        "stream": False,
        "temperature": 0.25,
        "max_tokens": 200,
        # Same stop list as the other endpoints in this file. The ChatML end
        # marker is "<|im_end|>"; the previous "<|redacted_im_end|>" entry
        # could never match it.
        "stop": ["<|user|>", "<|end|>", "<|endoftext|>", "<|im_end|>", "</s>"],
    }

    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=15.0)) as client:
            resp = await client.post(url, json=req_body, headers=_vllm_headers())
            if resp.status_code >= 400:
                LOG.warning("search-references error %s: %s", resp.status_code, resp.text[:500])
                return {"search_query": text[:100].strip()}
            data = resp.json()
            q = _extract_message_content(data)
            return {"search_query": q or text[:100].strip()}
    except Exception:
        # Best-effort endpoint: log the failure and degrade to the fallback query.
        LOG.exception("search-references failed")
        return {"search_query": text[:100].strip()}


# Production / Hugging Face Spaces: Vite build copied to ./static (see Dockerfile)
# The SPA routes are registered only when the static directory exists.
if STATIC_DIR.is_dir():
    assets_dir = STATIC_DIR / "assets"
    if assets_dir.is_dir():
        app.mount("/assets", StaticFiles(directory=str(assets_dir)), name="vite-assets")

    @app.get("/")
    async def spa_index():
        # Entry page of the single-page app.
        return FileResponse(STATIC_DIR / "index.html")

    @app.get("/{full_path:path}")
    async def spa_fallback(full_path: str):
        # Catch-all: serve a real file from the static directory when one
        # matches, otherwise fall back to index.html.
        # NOTE(review): the prefix test also 404s non-API paths that merely
        # start with "api" (e.g. "/apiary") — confirm that is intended.
        if full_path.startswith("api"):
            raise HTTPException(status_code=404, detail="Not found")
        if full_path:
            # full_path is request-controlled: ".." segments or a leading "/"
            # (from a "//..." URL) would otherwise escape STATIC_DIR. Resolve
            # the candidate and serve it only if it is still inside the
            # static root.
            try:
                static_root = STATIC_DIR.resolve()
                candidate = (static_root / full_path).resolve()
                servable = candidate.is_relative_to(static_root) and candidate.is_file()
            except (OSError, ValueError, RuntimeError):
                # Unresolvable path (embedded NUL, symlink loop, ...): treat as not found.
                servable = False
            if servable:
                return FileResponse(candidate)
        return FileResponse(STATIC_DIR / "index.html")