| """e-Padoms FastAPI backend.""" |
import asyncio
import json
import logging
import time
import uuid
from pathlib import Path
from typing import Literal

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

from rag.pipeline import chat_stream
from rag.retriever import Retriever
|
|
app = FastAPI(title="e-Padoms API", version="1.0.0")

# Wide-open CORS: every origin, HTTP method and request header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# In-memory conversation store, keyed by conversation id. Each value holds the
# message history ("messages") and the time.time() of its last use ("last_active").
_conversations: dict[str, dict] = {}
# Seconds of inactivity after which a conversation is discarded.
CONVERSATION_TTL = 60 * 30

# One retriever shared by all requests, constructed at import time.
_retriever = Retriever()

# Starter questions returned by the /api/suggestions endpoint.
SUGGESTIONS = [
    "How do I register a SIA (limited liability company) in Latvia?",
    "What taxes apply to Latvian startups?",
    "What is the Startup Law and how do I qualify?",
    "What is LIAA and how does it support startups?",
    "Can a foreigner start a company in Latvia?",
    "What is the difference between SIA and IK?",
]
|
|
|
|
def _expire_conversations():
    """Remove every conversation idle for more than ``CONVERSATION_TTL`` seconds.

    Mutates the module-level ``_conversations`` dict in place. Stale ids are
    collected first so the dict is not resized while being iterated.
    """
    now = time.time()
    stale_ids = []
    for conv_id, conv in _conversations.items():
        if now - conv["last_active"] > CONVERSATION_TTL:
            stale_ids.append(conv_id)
    for conv_id in stale_ids:
        _conversations.pop(conv_id)
|
|
|
|
class ChatRequest(BaseModel):
    """Request body for ``POST /api/chat``."""

    # The user's question. The chat endpoint rejects blank text and text
    # longer than 2000 characters with HTTP 400.
    message: str
    # Conversation to continue; when omitted the endpoint generates a new id.
    conversation_id: str | None = None
    # Language hint forwarded to chat_stream (presumably Latvian / English /
    # auto-detect — confirm against rag.pipeline).
    language: Literal["lv", "en", "auto"] = "auto"
|
|
|
|
@app.get("/api/health")
async def health():
    """Report that the API is up, plus the retriever's ``count()`` of indexed documents."""
    indexed = _retriever.count()
    return dict(status="ok", documents_indexed=indexed)
|
|
|
|
@app.get("/api/suggestions")
async def suggestions():
    """Return the static ``SUGGESTIONS`` list wrapped in a JSON object."""
    payload = {"suggestions": SUGGESTIONS}
    return payload
|
|
|
|
_logger = logging.getLogger(__name__)

# Maximum number of messages (user and assistant combined) kept per conversation.
_MAX_HISTORY = 10

# Text sent to the client when streaming fails; the real exception is only logged,
# so internal details (paths, upstream errors, credentials in URLs) never leak.
_STREAM_ERROR_MESSAGE = "An internal error occurred while generating the answer."


def _trim_history(history: list[dict]) -> None:
    """Drop the oldest messages in place so at most ``_MAX_HISTORY`` remain."""
    if len(history) > _MAX_HISTORY:
        del history[:-_MAX_HISTORY]


def _touch_conversation(conv_id: str) -> dict:
    """Return the stored conversation for ``conv_id``, creating it if absent.

    The conversation's ``last_active`` timestamp is refreshed to now either way.
    """
    now = time.time()
    conv = _conversations.setdefault(conv_id, {"messages": [], "last_active": now})
    conv["last_active"] = now
    return conv


@app.post("/api/chat")
async def chat(req: ChatRequest):
    """Answer ``req.message`` as a server-sent event stream.

    Validates the message, expires idle conversations, then continues the
    conversation named by ``req.conversation_id`` (or starts a new one with a
    fresh UUID). The user message is appended to that conversation's history,
    which is capped at ``_MAX_HISTORY`` messages.

    Every event yielded by ``chat_stream`` is forwarded unchanged as an SSE
    ``data`` payload. The ``content`` of each ``"token"`` event is also collected
    so the full answer can be stored as the assistant turn when the stream ends,
    including when it ends early. If no answer text was produced, the user
    message is removed again so the history does not keep an unanswered turn.

    The conversation id is sent back in the ``X-Conversation-Id`` response
    header so a client that omitted it can continue the conversation.

    Raises:
        HTTPException: 400 if the message is blank or longer than 2000 characters.
    """
    if not req.message.strip():
        raise HTTPException(status_code=400, detail="Message cannot be empty")
    if len(req.message) > 2000:
        raise HTTPException(status_code=400, detail="Message too long (max 2000 chars)")

    _expire_conversations()

    conv_id = req.conversation_id or str(uuid.uuid4())
    history = _touch_conversation(conv_id)["messages"]

    user_turn = {"role": "user", "content": req.message}
    history.append(user_turn)
    _trim_history(history)
    # Snapshot the earlier turns now rather than when streaming starts, so a
    # concurrent request on the same conversation cannot alter this context.
    prior_turns = history[:-1]

    async def event_generator():
        answer_parts: list[str] = []
        try:
            async for event_data in chat_stream(req.message, prior_turns, req.language):
                parsed = json.loads(event_data)
                if parsed.get("type") == "token":
                    answer_parts.append(parsed["content"])
                yield {"data": event_data}
        except Exception:
            _logger.exception("chat_stream failed for conversation %s", conv_id)
            yield {"data": json.dumps({"type": "error", "content": _STREAM_ERROR_MESSAGE})}
        finally:
            # Runs on normal completion, on error, and when the client disconnects.
            full_response = "".join(answer_parts)
            if full_response:
                history.append({"role": "assistant", "content": full_response})
                _trim_history(history)
            elif history and history[-1] is user_turn:
                # No answer was produced: drop the dangling user turn so the next
                # request does not send two consecutive user messages.
                history.pop()

    return EventSourceResponse(
        event_generator(),
        headers={"X-Accel-Buffering": "no", "X-Conversation-Id": conv_id},
    )
# Serve the built frontend from ``frontend_dist`` next to this file. The mount at
# "/" matches every path, so it is registered after the API routes; any route
# added after it would be unreachable.
_here = Path(__file__).parent
_frontend = _here / "frontend_dist"
if _frontend.is_dir():
    app.mount("/", StaticFiles(directory=str(_frontend), html=True), name="frontend")
else:
    # StaticFiles raises RuntimeError at construction when the directory is
    # missing, which would prevent the API itself from starting (e.g. before the
    # frontend has been built). Serve the API alone and say so in the log.
    logging.getLogger(__name__).warning(
        "Frontend build not found at %s; serving API routes only", _frontend
    )
|
|