File size: 3,616 Bytes
5d50e27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""e-Padoms FastAPI backend."""
import asyncio
import json
import logging
import time
import uuid
from pathlib import Path
from typing import Literal

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

from rag.pipeline import chat_stream
from rag.retriever import Retriever

# The ASGI application; every route and mount in this module registers on it.
app = FastAPI(title="e-Padoms API", version="1.0.0")

# Wide-open CORS: any origin may call any method with any request header.
# allow_credentials is not passed, so Starlette's default (False) applies.
# NOTE(review): the built frontend is served by this same app (StaticFiles
# mount at the bottom of the module), i.e. same-origin; confirm whether other
# origins genuinely need access before keeping the "*" wildcard.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory conversation store: {conversation_id: {"messages": [...], "last_active": float}}
# "messages" holds {"role", "content"} dicts; "last_active" is a time.time()
# timestamp. The store is a plain module-level dict: it is lost on restart,
# and if the app runs with several worker processes each one has its own copy.
_conversations: dict[str, dict] = {}
# Idle time, in seconds, after which _expire_conversations() drops an entry.
CONVERSATION_TTL = 30 * 60  # 30 minutes

# Single Retriever built at import time; the health endpoint reports its count().
_retriever = Retriever()

# Example questions returned verbatim by GET /api/suggestions.
SUGGESTIONS = [
    "How do I register a SIA (limited liability company) in Latvia?",
    "What taxes apply to Latvian startups?",
    "What is the Startup Law and how do I qualify?",
    "What is LIAA and how does it support startups?",
    "Can a foreigner start a company in Latvia?",
    "What is the difference between SIA and IK?",
]


def _expire_conversations():
    """Evict every conversation idle for longer than ``CONVERSATION_TTL`` seconds.

    Idle time is measured against one ``time.time()`` reading taken on entry.
    Matching entries are removed from the module-level ``_conversations`` dict
    in place (the dict object itself is kept).
    """
    reference_time = time.time()
    stale_ids = []
    for conv_id, record in _conversations.items():
        idle_for = reference_time - record["last_active"]
        if idle_for > CONVERSATION_TTL:
            stale_ids.append(conv_id)
    # Delete only after the scan: mutating a dict while iterating it raises.
    for conv_id in stale_ids:
        _conversations.pop(conv_id)


class ChatRequest(BaseModel):
    # Request body for POST /api/chat.
    # The user's question; chat() rejects blank text and text over 2000 characters.
    message: str
    # Omit to start a new conversation; an ID not present in the store also starts one.
    conversation_id: str | None = None
    # Passed unchanged to chat_stream(); "auto" is the default.
    language: Literal["lv", "en", "auto"] = "auto"


@app.get("/api/health")
async def health():
    """Report liveness together with the retriever's document count.

    Returns a JSON object with a constant ``status`` of ``"ok"`` and
    ``documents_indexed`` as reported by ``_retriever.count()``.
    """
    indexed_count = _retriever.count()
    return dict(status="ok", documents_indexed=indexed_count)


@app.get("/api/suggestions")
async def suggestions():
    """Return the static example questions under the ``suggestions`` key."""
    payload = dict(suggestions=SUGGESTIONS)
    return payload


# Maximum number of messages (user and assistant combined) kept per
# conversation -- roughly five question/answer turns.
MAX_HISTORY_MESSAGES = 10
# Longest accepted user message, in characters.
MAX_MESSAGE_CHARS = 2000

_logger = logging.getLogger(__name__)


def _append_trimmed(history: list[dict], role: str, content: str) -> None:
    """Append a ``{"role", "content"}`` message, keeping only the newest entries.

    Trims via slice assignment so the list object stored in ``_conversations``
    is updated in place rather than replaced.
    """
    history.append({"role": role, "content": content})
    if len(history) > MAX_HISTORY_MESSAGES:
        history[:] = history[-MAX_HISTORY_MESSAGES:]


@app.post("/api/chat")
async def chat(req: ChatRequest):
    """Answer ``req.message`` as a server-sent-event stream.

    The user message is added to the conversation history before streaming.
    The ``"token"`` events received from ``chat_stream`` are concatenated and
    stored as the assistant reply when the stream ends. Whatever tokens
    arrived are stored even if the stream ends early.

    When the request carries no ``conversation_id``, a new one is generated
    and returned in the ``X-Conversation-Id`` response header so the client
    can continue the conversation. Cross-origin callers additionally need
    that header listed in the CORS ``expose_headers`` to read it.

    Raises:
        HTTPException: 400 if the message is blank or longer than
            ``MAX_MESSAGE_CHARS`` characters.
    """
    if not req.message.strip():
        raise HTTPException(status_code=400, detail="Message cannot be empty")
    if len(req.message) > MAX_MESSAGE_CHARS:
        raise HTTPException(
            status_code=400,
            detail=f"Message too long (max {MAX_MESSAGE_CHARS} chars)",
        )

    _expire_conversations()

    conv_id = req.conversation_id or str(uuid.uuid4())
    conv = _conversations.setdefault(conv_id, {"messages": [], "last_active": time.time()})
    conv["last_active"] = time.time()
    history = conv["messages"]

    _append_trimmed(history, "user", req.message)
    # Snapshot the prior turns now. The generator below only starts once the
    # response begins streaming, by which time another request on the same
    # conversation may already have appended to ``history``.
    prior_history = history[:-1]

    async def event_generator():
        assistant_parts: list[str] = []
        try:
            async for event_data in chat_stream(req.message, prior_history, req.language):
                # Events are JSON strings: peek at the type to capture answer
                # tokens, then forward the raw payload unchanged.
                parsed = json.loads(event_data)
                if parsed.get("type") == "token":
                    assistant_parts.append(parsed["content"])
                yield {"data": event_data}
        except Exception as e:
            # Keep the server-side traceback. The client still receives the
            # exception text as before.
            # NOTE(review): str(e) may expose internal details to end users.
            _logger.exception("chat stream failed (conversation_id=%s)", conv_id)
            yield {"data": json.dumps({"type": "error", "content": str(e)})}
        finally:
            full_response = "".join(assistant_parts)
            if full_response:
                _append_trimmed(history, "assistant", full_response)
            # A long stream should not count as idle time.
            conv["last_active"] = time.time()

    headers = {"X-Accel-Buffering": "no"}
    if not req.conversation_id:
        # Only server-generated IDs (plain UUIDs) are echoed back. They are
        # always header-safe, unlike arbitrary client-supplied strings.
        headers["X-Conversation-Id"] = conv_id

    return EventSourceResponse(event_generator(), headers=headers)
# Serve the built frontend from the same app.
# This mount must remain the last route registration. A mount at "/"
# matches every path, and Starlette tries routes in registration order, so
# any route added after it would be shadowed. With html=True, requests for a
# directory are answered with its index.html. StaticFiles raises at startup
# if frontend_dist does not exist.
_here = Path(__file__).parent
_frontend = _here / "frontend_dist"
app.mount("/", StaticFiles(directory=str(_frontend), html=True), name="frontend")