File size: 6,528 Bytes
9f031f6
 
99db407
9f031f6
 
 
 
99db407
 
9f031f6
 
 
 
 
 
 
 
 
0d06c0b
c2659c1
99db407
c2659c1
 
 
 
 
 
9f031f6
 
 
 
 
 
 
 
 
99db407
9f031f6
 
99db407
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f031f6
 
99db407
 
 
9f031f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d06c0b
 
 
 
c2659c1
0d06c0b
9f031f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
from __future__ import annotations

import json
import os
from collections import defaultdict, deque
from collections.abc import Generator
from dataclasses import dataclass
from pathlib import Path
from threading import Lock

from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_ollama import ChatOllama

from src.ingest import get_or_build_vectorstore


# Runtime configuration; every value can be overridden through the environment.
MAX_MEMORY_TURNS = int(os.environ.get("RAG_MEMORY_TURNS", "6"))  # turns kept per conversation
LLM_MODEL = os.environ.get("LLM_MODEL", "hf.co/LiquidAI/LFM2-1.2B-RAG-GGUF:Q5_K_M")
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
# The first non-empty value wins: a dedicated Ollama token, then the Hugging Face tokens.
OLLAMA_AUTH_TOKEN = (
    os.environ.get("OLLAMA_AUTH_TOKEN")
    or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    or os.environ.get("HF_TOKEN")
)
MEMORY_FILE = os.environ.get("RAG_MEMORY_FILE", "data/conversation_memory.jsonl")


def _ollama_client_kwargs() -> dict:
    """Return extra keyword arguments for the Ollama HTTP client.

    When an auth token is configured, the result carries a bearer
    ``Authorization`` header. Otherwise it is an empty dict, so no headers
    are added.
    """
    token = OLLAMA_AUTH_TOKEN
    return {"headers": {"Authorization": f"Bearer {token}"}} if token else {}


@dataclass
class MemoryTurn:
    """One user/assistant exchange held in conversation memory."""

    user_message: str  # the question as the user sent it
    assistant_message: str  # the reply produced for that question


class ConversationMemory:
    """Per-conversation rolling chat history, persisted to a JSONL file.

    Each conversation keeps at most ``max_turns`` recent turns in memory; a
    bounded ``deque`` evicts the oldest turns. Every appended turn is also
    written to ``storage_path`` as one JSON object per line. On construction,
    that file is replayed to rebuild the in-memory store.
    """

    def __init__(self, max_turns: int = MAX_MEMORY_TURNS, storage_path: str = MEMORY_FILE) -> None:
        """
        Args:
            max_turns: maximum number of turns kept in memory per conversation.
            storage_path: path of the JSONL file used for persistence.
        """
        self._max_turns = max_turns
        self._store: dict[str, deque[MemoryTurn]] = defaultdict(lambda: deque(maxlen=self._max_turns))
        self._storage_path = Path(storage_path)
        # Serialises appends to the JSONL file.
        self._write_lock = Lock()
        # Guards the in-memory store. Creating a conversation's deque through the
        # defaultdict, and iterating a deque while another thread appends to it
        # (which raises RuntimeError), must not interleave.
        self._store_lock = Lock()
        self._load_from_disk()

    def _load_from_disk(self) -> None:
        """Rebuild the in-memory store by replaying the JSONL memory file.

        A missing file means there is no history. Blank lines, lines that are
        not JSON objects, and records with an empty user or assistant message
        are skipped one at a time. This way a single corrupt line (for example,
        a write truncated by a crash) does not discard the rest of the history.
        If the file cannot be opened or read, the error is printed and loading
        stops (best effort).
        """
        if not self._storage_path.exists():
            return
        try:
            with self._storage_path.open("r", encoding="utf-8") as f:
                for line_number, raw_line in enumerate(f, start=1):
                    line = raw_line.strip()
                    if not line:
                        continue
                    try:
                        item = json.loads(line)
                    except json.JSONDecodeError as e:
                        print(f"[memory] Skipping malformed line {line_number}: {e}")
                        continue
                    if not isinstance(item, dict):
                        print(f"[memory] Skipping non-object line {line_number}")
                        continue
                    conversation_id = str(item.get("conversation_id", "default"))
                    user_message = str(item.get("user_message", ""))
                    assistant_message = str(item.get("assistant_message", ""))
                    if user_message and assistant_message:
                        self._store[conversation_id].append(
                            MemoryTurn(user_message=user_message, assistant_message=assistant_message)
                        )
        except Exception as e:
            print(f"[memory] Failed to load memory file: {e}")

    def _append_to_disk(self, conversation_id: str, turn: MemoryTurn) -> None:
        """Append one turn to the JSONL file (best effort; failures are printed)."""
        try:
            self._storage_path.parent.mkdir(parents=True, exist_ok=True)
            payload = {
                "conversation_id": conversation_id,
                "user_message": turn.user_message,
                "assistant_message": turn.assistant_message,
            }
            with self._write_lock:
                with self._storage_path.open("a", encoding="utf-8") as f:
                    f.write(json.dumps(payload, ensure_ascii=True) + "\n")
        except Exception as e:
            print(f"[memory] Failed to persist memory turn: {e}")

    def append(self, conversation_id: str, user_message: str, assistant_message: str) -> None:
        """Record one exchange in memory and persist it to disk.

        A turn with an empty user or assistant message is ignored. The loader
        discards such records, so storing them would make the live history
        differ from the history rebuilt after a restart.
        """
        if not user_message or not assistant_message:
            return
        turn = MemoryTurn(user_message=user_message, assistant_message=assistant_message)
        with self._store_lock:
            self._store[conversation_id].append(turn)
        self._append_to_disk(conversation_id, turn)

    def format_history(self, conversation_id: str) -> str:
        """Render a conversation's stored turns as ``User:``/``Assistant:`` lines.

        Returns ``"No previous conversation."`` when nothing is stored for
        ``conversation_id``. The lookup uses ``.get``, so it does not create an
        empty entry.
        """
        with self._store_lock:
            history = self._store.get(conversation_id)
            # Snapshot under the lock so a concurrent append cannot mutate the
            # deque while it is being iterated below.
            turns = list(history) if history else []
        if not turns:
            return "No previous conversation."

        lines: list[str] = []
        for turn in turns:
            lines.append(f"User: {turn.user_message}")
            lines.append(f"Assistant: {turn.assistant_message}")
        return "\n".join(lines)


class RagChatService:
    """Retrieval-augmented chat over the FAQ vector store with per-conversation memory.

    The vector store, retriever and LLM client are created lazily on first use,
    so constructing the service touches neither the index nor the Ollama
    server. Construction does create a ``ConversationMemory``.
    """

    def __init__(self, k: int = 4) -> None:
        """
        Args:
            k: passed to the retriever as ``search_kwargs["k"]``, i.e. how many
                FAQ documents to retrieve per question.
        """
        self._k = k
        self._vectorstore = None
        self._retriever = None
        self._llm = None
        self._memory = ConversationMemory()

    # NOTE(review): the lazy initialisers below are not synchronised; concurrent
    # first calls may each build a vector store / client.
    def _get_retriever(self):
        """Return the retriever, building the vector store on first use."""
        if self._retriever is None:
            self._vectorstore = get_or_build_vectorstore()
            self._retriever = self._vectorstore.as_retriever(search_kwargs={"k": self._k})
        return self._retriever

    def _get_llm(self) -> ChatOllama:
        """Return the ChatOllama client, creating it on first use."""
        if self._llm is None:
            self._llm = ChatOllama(
                model=LLM_MODEL,
                temperature=0.2,
                base_url=OLLAMA_BASE_URL,
                client_kwargs=_ollama_client_kwargs(),
            )
        return self._llm

    @staticmethod
    def _content_to_text(content: object) -> str:
        """Flatten LangChain message content into plain text.

        Message content is usually a ``str``. It may also be a list of content
        parts: plain strings or dicts such as ``{"type": "text", "text": ...}``.
        Calling ``str()`` on such a list would leak its Python repr into the
        answer, so text parts are concatenated and non-text parts are dropped.
        Any other type falls back to ``str()``.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            pieces: list[str] = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and part.get("type") == "text":
                    pieces.append(str(part.get("text", "")))
            return "".join(pieces)
        return str(content)

    def _format_context(self, question: str) -> str:
        """Retrieve documents for ``question`` and join their page content."""
        docs = self._get_retriever().invoke(question)
        if not docs:
            return "No relevant FAQ context found."
        return "\n\n".join(doc.page_content for doc in docs)

    def _build_messages(self, question: str, conversation_id: str) -> list[BaseMessage]:
        """Build the system prompt (history + FAQ context) and the user message."""
        history = self._memory.format_history(conversation_id)
        context = self._format_context(question)
        system_prompt = (
            "You are a concise and helpful support assistant for 9jaLingo, a voice AI platform. "
            "Use only the provided FAQ context and recent conversation history. "
            "If the answer is not in the context, say that clearly and direct the user to official support.\n\n"
            f"Conversation history:\n{history}\n\n"
            f"FAQ context:\n{context}"
        )
        return [
            SystemMessage(content=system_prompt),
            HumanMessage(content=question),
        ]

    def chat(self, question: str, conversation_id: str) -> str:
        """Answer ``question`` in one call and record the exchange in memory."""
        messages = self._build_messages(question, conversation_id)
        response = self._get_llm().invoke(messages)
        answer = self._content_to_text(response.content)
        self._memory.append(conversation_id, question, answer)
        return answer

    def stream(self, question: str, conversation_id: str) -> Generator[str, None, None]:
        """Yield the answer to ``question`` chunk by chunk; empty chunks are skipped.

        The exchange is recorded in memory only after the model stream has been
        fully consumed. If the caller stops iterating early, nothing is stored.
        """
        messages = self._build_messages(question, conversation_id)
        parts: list[str] = []
        for chunk in self._get_llm().stream(messages):
            content = self._content_to_text(chunk.content)
            if not content:
                continue
            parts.append(content)
            yield content

        self._memory.append(conversation_id, question, "".join(parts))


# Module-level service instance. Creating it at import time builds a
# ConversationMemory, which replays the memory file if one exists. The vector
# store and LLM client are only created on first use.
chat_service = RagChatService()