Spaces:
Running
Running
| from __future__ import annotations | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import AsyncIterator | |
| import httpx | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel, Field | |
| from starlette.responses import FileResponse | |
| from app.config import settings | |
# Root logging is configured at import time so INFO-level records are emitted.
logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger(__name__)

# "static" directory located in the parent of the directory containing this module.
STATIC_DIR = Path(__file__).resolve().parents[1] / "static"

app = FastAPI(title="Ask Jerry API")

# CORS: allowed origins come from settings; every method and header is accepted
# and credentialed requests are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origin_list,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Base system prompt. _build_system_prompt returns it unchanged, or with the
# caller's extra instructions appended.
JERRY_SYSTEM_BASE = (
    "You are AI Jerry, a cybersecurity-focused assistant running on the BrainForge Security model. "
    "You give clear, practical guidance; you distinguish facts from speculation; you flag risks and "
    "compliance considerations when relevant. You are friendly and professional."
)

# Token window that chat_stream and chat_estimate budget input + reply against.
MODEL_CONTEXT_WINDOW = 8192

# System prompt sent by chat_summarize together with the flattened transcript.
SUMMARIZE_SYSTEM = (
    "You are a concise summarizer. Condense the following conversation into a short summary "
    "that preserves the key topics discussed, any conclusions reached, important facts shared, "
    "and the overall tone. Keep it under 300 words. Write in third person narrative form."
)

# search_references truncates the submitted statement to this many characters.
_STATEMENT_MAX_CHARS = 4000

# System prompt sent by search_references; it instructs the model to reply with
# only the search query string.
SEARCH_REF_SYSTEM = """You help a user find web sources to **research** an AI assistant's cybersecurity answer.
You will receive the full text of that answer. Do this in order:
1. **Facts** — Identify the main factual claims (CVEs, standards, protocols, vendor names, regulations, definitions, and procedural steps).
2. **Meaning** — In a few words, capture the overall gist: what the answer is explaining or recommending.
3. **Search query** — Compose one concise web search query (or two short queries separated by `; `) optimized to find **authoritative references** that could verify or deepen those facts—e.g. NIST, CISA, vendor docs, RFCs, CWE/CVE pages, or reputable security guidance.
Rules:
- Prefer concrete, verifiable keywords from the text.
- The query should help someone **research** the topic, not merely restate the answer in different words.
- Do not include meta-commentary, labels, bullets, or step numbers in your output.
- Return **only** the search query string (or two `; `-separated queries), with no quotes around the whole thing and no preamble."""
# One conversation turn as submitted by the client.
class ChatMessage(BaseModel):
    # Speaker tag; _build_api_messages forwards only "user" and "assistant" turns.
    role: str
    # Message text; turns whose text is blank after stripping are not forwarded.
    content: str
# Request body consumed by chat_stream and chat_estimate.
class ChatStreamBody(BaseModel):
    # Conversation turns; at least one is required.
    messages: list[ChatMessage] = Field(..., min_length=1)
    # Optional extra instructions appended to the base system prompt.
    extra_persona: str = ""
    # Sampling temperature; None falls back to settings.temperature.
    temperature: float | None = None
    # Requested reply budget in tokens; None falls back to settings.max_tokens.
    max_tokens: int | None = None
    # Optional summary of earlier turns, sent as an additional system message.
    summary: str | None = None
# Request body consumed by chat_summarize.
class SummarizeBody(BaseModel):
    # Conversation turns to condense; at least one is required.
    messages: list[ChatMessage] = Field(..., min_length=1)
    # Accepted in the payload; chat_summarize in this file does not read it.
    extra_persona: str = ""
# Request body consumed by search_references.
class SearchRefBody(BaseModel):
    # Assistant answer to derive a search query from; an empty or blank value
    # yields an empty query.
    statement: str = ""
def _build_system_prompt(extra_persona: str) -> str:
    """Return the system prompt, optionally extended with user instructions.

    Args:
        extra_persona: Additional instructions supplied by the caller. ``None``
            or whitespace-only input is treated as absent.

    Returns:
        ``JERRY_SYSTEM_BASE`` alone when no extra instructions are given,
        otherwise the base prompt followed by a labelled section holding the
        stripped instructions.
    """
    addendum = (extra_persona or "").strip()
    if addendum:
        return f"{JERRY_SYSTEM_BASE}\n\nAdditional instructions from the user:\n{addendum}"
    return JERRY_SYSTEM_BASE
| def _estimate_tokens(text: str) -> int: | |
| return max(1, len(text) // 4) | |
| def _estimate_messages_tokens(msgs: list[dict]) -> int: | |
| total = 0 | |
| for m in msgs: | |
| total += _estimate_tokens(m.get("content", "")) + 4 | |
| return total | |
| def _build_api_messages( | |
| system: str, | |
| body_messages: list[ChatMessage], | |
| summary: str | None, | |
| ) -> list[dict]: | |
| msgs: list[dict] = [{"role": "system", "content": system}] | |
| if summary: | |
| msgs.append({ | |
| "role": "system", | |
| "content": f"Summary of earlier conversation:\n{summary}", | |
| }) | |
| for m in body_messages: | |
| if m.role in ("user", "assistant") and m.content.strip(): | |
| msgs.append({"role": m.role, "content": m.content}) | |
| return msgs | |
| def _delta_text(delta: dict) -> str: | |
| c = delta.get("content") | |
| if c is None: | |
| return "" | |
| if isinstance(c, str): | |
| return c | |
| if isinstance(c, list): | |
| parts: list[str] = [] | |
| for p in c: | |
| if isinstance(p, str): | |
| parts.append(p) | |
| elif isinstance(p, dict): | |
| t = p.get("text") | |
| if isinstance(t, str): | |
| parts.append(t) | |
| return "".join(parts) | |
| return str(c) | |
| def _is_context_overflow(error_text: str) -> bool: | |
| indicators = ["context length", "max_tokens", "too large", "too many tokens"] | |
| lower = error_text.lower() | |
| return any(ind in lower for ind in indicators) | |
| async def _yield_sse_tokens(line_iter: AsyncIterator[str]) -> AsyncIterator[str]: | |
| async for line in line_iter: | |
| if not line: | |
| continue | |
| if not line.startswith("data: "): | |
| continue | |
| payload = line[6:].strip() | |
| if payload == "[DONE]": | |
| yield f"data: {json.dumps({'type': 'done'})}\n\n" | |
| return | |
| try: | |
| obj = json.loads(payload) | |
| except json.JSONDecodeError: | |
| continue | |
| err = obj.get("error") | |
| if err: | |
| yield f"data: {json.dumps({'type': 'error', 'detail': str(err)})}\n\n" | |
| return | |
| choices = obj.get("choices") or [] | |
| if not choices: | |
| continue | |
| ch0 = choices[0] if isinstance(choices[0], dict) else {} | |
| delta = ch0.get("delta") or {} | |
| if not isinstance(delta, dict): | |
| delta = {} | |
| piece = _delta_text(delta) | |
| if not piece and isinstance(ch0.get("message"), dict): | |
| piece = _delta_text(ch0["message"]) | |
| if piece: | |
| yield f"data: {json.dumps({'type': 'token', 'content': piece})}\n\n" | |
| yield f"data: {json.dumps({'type': 'done'})}\n\n" | |
def _vllm_headers() -> dict[str, str]:
    """Build the HTTP headers for requests to the model server.

    Returns:
        A JSON ``Content-Type`` header, plus a bearer ``Authorization`` header
        when ``settings.vllm_api_key`` is set to a truthy value.
    """
    api_key = settings.vllm_api_key
    if not api_key:
        return {"Content-Type": "application/json"}
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
async def health():
    """Return a static status report with the configured model and limits.

    The response carries a fixed ``"ok"`` status, the configured chat model
    id, the context window constant and the default reply budget from
    settings. No upstream call is made.
    """
    return dict(
        status="ok",
        model=settings.chat_model_id,
        context_window=MODEL_CONTEXT_WINDOW,
        max_tokens=settings.max_tokens,
    )
async def chat_stream(body: ChatStreamBody):
    """Stream a chat completion to the client as server-sent events.

    The handler builds the upstream message list, estimates its size, and
    fits the reply budget into ``MODEL_CONTEXT_WINDOW``. When fewer than 256
    tokens would remain for the reply, it answers with a single
    ``context_overflow`` event instead of calling the model server.

    Args:
        body: Conversation turns plus optional persona, summary, temperature
            and reply-budget overrides.

    Returns:
        A ``text/event-stream`` ``StreamingResponse``. Events are JSON objects
        with ``type`` set to ``token``, ``done``, ``error`` or
        ``context_overflow``. Upstream HTTP errors and transport failures are
        reported as events rather than raised.
    """
    system = _build_system_prompt(body.extra_persona)
    msgs = _build_api_messages(system, body.messages, body.summary)
    input_tokens = _estimate_messages_tokens(msgs)
    reply_budget = body.max_tokens if body.max_tokens is not None else settings.max_tokens

    min_reply_tokens = 256
    safety_margin = 64  # slack kept free; presumably covers estimate error

    if input_tokens + reply_budget > MODEL_CONTEXT_WINDOW:
        # Test the space that is actually left BEFORE applying any floor:
        # clamping with max(256, ...) first would make this check unreachable
        # and send an over-full request upstream.
        available = MODEL_CONTEXT_WINDOW - input_tokens - safety_margin
        if available < min_reply_tokens:
            detail = (
                f"Context too large: ~{input_tokens} input tokens with a {MODEL_CONTEXT_WINDOW} "
                f"token window leaves no room for a reply."
            )
            overflow_event = {
                "type": "context_overflow",
                "detail": detail,
                "input_tokens": input_tokens,
                "context_window": MODEL_CONTEXT_WINDOW,
            }

            async def overflow_gen():
                yield f"data: {json.dumps(overflow_event)}\n\n"

            return StreamingResponse(
                overflow_gen(),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache"},
            )
        reply_budget = available

    url = f"{settings.vllm_base_url.rstrip('/')}/chat/completions"
    req_body: dict = {
        "model": settings.chat_model_id,
        "messages": msgs,
        "stream": True,
        "temperature": body.temperature if body.temperature is not None else settings.temperature,
        "max_tokens": reply_budget,
        "stop": ["<|user|>", "<|end|>", "<|endoftext|>", "<|im_end|>", "</s>"],
    }

    async def event_gen():
        try:
            async with httpx.AsyncClient(timeout=httpx.Timeout(120.0, connect=15.0)) as client:
                async with client.stream("POST", url, json=req_body, headers=_vllm_headers()) as resp:
                    if resp.status_code >= 400:
                        # A streamed response has no body loaded yet; read it
                        # explicitly and cap what is logged and forwarded.
                        text = (await resp.aread()).decode("utf-8", errors="replace")[:2000]
                        LOG.warning("vLLM error %s: %s", resp.status_code, text)
                        if _is_context_overflow(text):
                            yield f"data: {json.dumps({'type': 'context_overflow', 'detail': text})}\n\n"
                        else:
                            yield f"data: {json.dumps({'type': 'error', 'detail': text or resp.reason_phrase})}\n\n"
                        return
                    async for chunk in _yield_sse_tokens(resp.aiter_lines()):
                        yield chunk
        except httpx.RequestError as e:
            LOG.exception("vLLM request failed")
            yield f"data: {json.dumps({'type': 'error', 'detail': str(e)})}\n\n"

    return StreamingResponse(
        event_gen(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
async def chat_summarize(body: SummarizeBody):
    """Ask the model server for a short summary of the given conversation.

    The turns are flattened into a plain-text transcript (``"User"`` for the
    ``user`` role, ``"AI Jerry"`` for every other role) and sent with the
    summarizer system prompt in a non-streaming request.

    Returns:
        ``{"summary": <text>}`` on success. On an upstream HTTP error, an
        empty choice list, or any exception, ``{"summary": "", "error": <detail>}``
        is returned; failures are logged and never raised.
    """
    transcript = "\n".join(
        ("User" if turn.role == "user" else "AI Jerry") + f": {turn.content}"
        for turn in body.messages
    )
    url = f"{settings.vllm_base_url.rstrip('/')}/chat/completions"
    req_body: dict = {
        "model": settings.chat_model_id,
        "messages": [
            {"role": "system", "content": SUMMARIZE_SYSTEM},
            {"role": "user", "content": transcript},
        ],
        "stream": False,
        "temperature": 0.3,
        "max_tokens": 1024,
        "stop": ["<|user|>", "<|end|>", "<|endoftext|>", "<|im_end|>", "</s>"],
    }
    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=15.0)) as client:
            resp = await client.post(url, json=req_body, headers=_vllm_headers())
            if resp.status_code >= 400:
                failure = resp.text[:500]
                LOG.warning("Summarize error %s: %s", resp.status_code, failure)
                return {"summary": "", "error": failure}
            choices = resp.json().get("choices") or []
            if not choices:
                return {"summary": "", "error": "No choices returned"}
            first_message = choices[0].get("message") or {}
            return {"summary": (first_message.get("content") or "").strip()}
    except Exception as e:
        # Best-effort endpoint: report the failure in the payload.
        LOG.exception("Summarize request failed")
        return {"summary": "", "error": str(e)}
async def chat_estimate(body: ChatStreamBody):
    """Report the estimated token usage for a prospective chat request.

    Returns:
        A dict with the estimated ``input_tokens``, the ``reply_budget`` (the
        request's ``max_tokens`` or the configured default), the
        ``context_window`` constant, and ``headroom`` — the window minus both
        figures, which is negative when the request would not fit.
    """
    prompt = _build_system_prompt(body.extra_persona)
    api_messages = _build_api_messages(prompt, body.messages, body.summary)
    input_tokens = _estimate_messages_tokens(api_messages)

    if body.max_tokens is None:
        reply_budget = settings.max_tokens
    else:
        reply_budget = body.max_tokens

    headroom = MODEL_CONTEXT_WINDOW - input_tokens - reply_budget
    return {
        "input_tokens": input_tokens,
        "reply_budget": reply_budget,
        "context_window": MODEL_CONTEXT_WINDOW,
        "headroom": headroom,
    }
| def _extract_message_content(data: dict) -> str: | |
| choices = data.get("choices") or [] | |
| if not choices: | |
| return "" | |
| msg = choices[0].get("message") or {} | |
| return (msg.get("content") or "").strip() | |
async def search_references(body: SearchRefBody):
    """Generate a web search query from an assistant answer (for Perplexity / copy).

    The statement is truncated to ``_STATEMENT_MAX_CHARS`` characters and sent
    to the model server with ``SEARCH_REF_SYSTEM`` in a non-streaming request.

    Args:
        body: Payload holding the assistant answer in ``statement``.

    Returns:
        ``{"search_query": <query>}``. The query is empty for a blank
        statement. When the upstream call fails, returns an HTTP error, or
        produces no text, the first 100 characters of the statement (stripped)
        are returned instead; failures are logged and never raised.
    """
    text = (body.statement or "")[:_STATEMENT_MAX_CHARS]
    if not text.strip():
        return {"search_query": ""}

    # Used whenever the model cannot supply a query.
    fallback = text[:100].strip()

    msgs = [
        {"role": "system", "content": SEARCH_REF_SYSTEM},
        {
            "role": "user",
            "content": f"Assistant answer to analyze for research search terms:\n\n{text}",
        },
    ]
    url = f"{settings.vllm_base_url.rstrip('/')}/chat/completions"
    req_body: dict = {
        "model": settings.chat_model_id,
        "messages": msgs,
        "stream": False,
        "temperature": 0.25,
        "max_tokens": 200,
        # Same stop sequences as the other request builders in this module;
        # the list previously carried "<|redacted_im_end|>" in place of
        # "<|im_end|>".
        "stop": ["<|user|>", "<|end|>", "<|endoftext|>", "<|im_end|>", "</s>"],
    }
    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=15.0)) as client:
            resp = await client.post(url, json=req_body, headers=_vllm_headers())
            if resp.status_code >= 400:
                LOG.warning("search-references error %s: %s", resp.status_code, resp.text[:500])
                return {"search_query": fallback}
            query = _extract_message_content(resp.json())
            return {"search_query": query or fallback}
    except Exception:
        # Best-effort endpoint: any failure degrades to the fallback query.
        LOG.exception("search-references failed")
        return {"search_query": fallback}
# Production / Hugging Face Spaces: Vite build copied to ./static (see Dockerfile)
# NOTE(review): no route decorators are visible on the two handlers below in
# this view of the file — confirm they are registered with the app.
if STATIC_DIR.is_dir():
    assets_dir = STATIC_DIR / "assets"
    if assets_dir.is_dir():
        app.mount("/assets", StaticFiles(directory=str(assets_dir)), name="vite-assets")

    async def spa_index():
        """Serve the front-end entry page (``index.html`` in ``STATIC_DIR``)."""
        return FileResponse(STATIC_DIR / "index.html")

    async def spa_fallback(full_path: str):
        """Serve a static file if it exists, otherwise the front-end entry page.

        Args:
            full_path: Requested path relative to the site root.

        Returns:
            A ``FileResponse`` for the matching file inside ``STATIC_DIR``, or
            for ``index.html`` when there is no such file.

        Raises:
            HTTPException: 404 when ``full_path`` starts with ``"api"``, so
                unknown API paths are not answered with the HTML page.
        """
        if full_path.startswith("api"):
            raise HTTPException(status_code=404, detail="Not found")
        static_root = STATIC_DIR.resolve()
        # Resolve ".." segments and symlinks before serving: joining an
        # absolute or parent-relative path would otherwise expose files
        # outside the static directory.
        file_path = (static_root / full_path).resolve()
        if full_path and file_path.is_relative_to(static_root) and file_path.is_file():
            return FileResponse(file_path)
        return FileResponse(STATIC_DIR / "index.html")