chatbot / src /chat_service.py
okoliechykwuka
Persist conversation memory across restarts
99db407
from __future__ import annotations
import json
import os
from collections import defaultdict, deque
from collections.abc import Generator
from dataclasses import dataclass
from pathlib import Path
from threading import Lock
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_ollama import ChatOllama
from src.ingest import get_or_build_vectorstore
# Runtime configuration, read once from the environment at import time.

# Default cap on the number of turns kept per conversation.
MAX_MEMORY_TURNS = int(os.environ.get("RAG_MEMORY_TURNS", "6"))

# Model identifier and server URL handed to the Ollama chat client.
LLM_MODEL = os.environ.get("LLM_MODEL", "hf.co/LiquidAI/LFM2-1.2B-RAG-GGUF:Q5_K_M")
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")

# Optional bearer token; the first non-empty variable in this order wins.
OLLAMA_AUTH_TOKEN = (
    os.environ.get("OLLAMA_AUTH_TOKEN")
    or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    or os.environ.get("HF_TOKEN")
)

# Default path of the JSONL file that conversation turns are appended to.
MEMORY_FILE = os.environ.get("RAG_MEMORY_FILE", "data/conversation_memory.jsonl")
def _ollama_client_kwargs() -> dict:
    """Return extra client keyword arguments for the Ollama connection.

    Returns:
        A dict carrying an ``Authorization: Bearer <token>`` header when
        ``OLLAMA_AUTH_TOKEN`` is set to a non-empty value, otherwise an
        empty dict.
    """
    if OLLAMA_AUTH_TOKEN:
        return {"headers": {"Authorization": f"Bearer {OLLAMA_AUTH_TOKEN}"}}
    return {}
@dataclass
class MemoryTurn:
    """A single completed exchange in a conversation."""

    # What the user sent.
    user_message: str
    # What the assistant answered.
    assistant_message: str
class ConversationMemory:
    """Bounded per-conversation chat history mirrored to a JSONL file.

    Each conversation id maps to a ``deque`` capped at ``max_turns`` entries,
    so only the most recent turns are kept in memory. Every appended turn is
    also written as one JSON object per line to ``storage_path``; that file is
    replayed on construction so history survives a restart.

    Persistence is best-effort: read and write failures are reported with
    ``print`` and never raised to the caller.
    """

    def __init__(self, max_turns: int = MAX_MEMORY_TURNS, storage_path: str = MEMORY_FILE) -> None:
        """Create the store and replay any turns already on disk.

        Args:
            max_turns: Maximum number of turns retained per conversation.
            storage_path: Path of the JSONL file used for persistence.
        """
        self._max_turns = max_turns
        self._store: dict[str, deque[MemoryTurn]] = defaultdict(lambda: deque(maxlen=self._max_turns))
        self._storage_path = Path(storage_path)
        # Serializes appends to the storage file; the in-memory store is not
        # guarded by this lock.
        self._write_lock = Lock()
        self._load_from_disk()

    def _load_from_disk(self) -> None:
        """Replay persisted turns from the storage file into memory.

        Malformed lines (invalid JSON, or JSON that is not an object) are
        skipped individually so that one bad record, such as a partially
        written final line, does not discard the turns that follow it.
        Records with an empty user or assistant message are ignored.
        """
        if not self._storage_path.exists():
            return

        skipped = 0
        try:
            with self._storage_path.open("r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        item = json.loads(line)
                    except ValueError:
                        # json.JSONDecodeError is a ValueError subclass.
                        skipped += 1
                        continue
                    if not isinstance(item, dict):
                        skipped += 1
                        continue
                    conversation_id = str(item.get("conversation_id", "default"))
                    user_message = str(item.get("user_message", ""))
                    assistant_message = str(item.get("assistant_message", ""))
                    if user_message and assistant_message:
                        self._store[conversation_id].append(
                            MemoryTurn(user_message=user_message, assistant_message=assistant_message)
                        )
        except Exception as e:
            # Deliberate best-effort boundary: an unreadable memory file must
            # not prevent the service from starting.
            print(f"[memory] Failed to load memory file: {e}")

        if skipped:
            print(f"[memory] Skipped {skipped} malformed line(s) in memory file")

    def _append_to_disk(self, conversation_id: str, turn: MemoryTurn) -> None:
        """Append one turn to the storage file as a single JSON line.

        The parent directory is created if missing. Any failure is printed
        and swallowed, leaving the in-memory history untouched.
        """
        try:
            self._storage_path.parent.mkdir(parents=True, exist_ok=True)
            payload = {
                "conversation_id": conversation_id,
                "user_message": turn.user_message,
                "assistant_message": turn.assistant_message,
            }
            with self._write_lock:
                with self._storage_path.open("a", encoding="utf-8") as f:
                    f.write(json.dumps(payload, ensure_ascii=True) + "\n")
        except Exception as e:
            print(f"[memory] Failed to persist memory turn: {e}")

    def append(self, conversation_id: str, user_message: str, assistant_message: str) -> None:
        """Record a completed turn in memory and persist it to disk."""
        turn = MemoryTurn(user_message=user_message, assistant_message=assistant_message)
        self._store[conversation_id].append(turn)
        self._append_to_disk(conversation_id, turn)

    def format_history(self, conversation_id: str) -> str:
        """Render the retained turns of a conversation as plain text.

        Returns:
            Alternating ``User:`` / ``Assistant:`` lines joined by newlines,
            or ``"No previous conversation."`` when nothing is stored.
        """
        # .get() avoids creating an empty deque for unknown conversation ids.
        history = self._store.get(conversation_id)
        if not history:
            return "No previous conversation."

        lines: list[str] = []
        for turn in history:
            lines.append(f"User: {turn.user_message}")
            lines.append(f"Assistant: {turn.assistant_message}")
        return "\n".join(lines)
class RagChatService:
    """Answer questions using retrieved FAQ context and conversation memory.

    The vector store retriever and the chat model are created lazily on first
    use and then reused for the lifetime of the instance.
    """

    def __init__(self, k: int = 4) -> None:
        """Set up the service.

        Args:
            k: Number of documents requested from the retriever per question.
        """
        self._k = k
        self._vectorstore = None
        self._retriever = None
        self._llm = None
        self._memory = ConversationMemory()

    def _get_retriever(self):
        """Return the retriever, building it from the vector store on first call."""
        if self._retriever is not None:
            return self._retriever
        self._vectorstore = get_or_build_vectorstore()
        self._retriever = self._vectorstore.as_retriever(search_kwargs={"k": self._k})
        return self._retriever

    def _get_llm(self) -> ChatOllama:
        """Return the chat model client, constructing it on first call."""
        if self._llm is not None:
            return self._llm
        self._llm = ChatOllama(
            model=LLM_MODEL,
            temperature=0.2,
            base_url=OLLAMA_BASE_URL,
            client_kwargs=_ollama_client_kwargs(),
        )
        return self._llm

    def _format_context(self, question: str) -> str:
        """Retrieve documents for *question* and join their text.

        Returns:
            The ``page_content`` of each retrieved document separated by blank
            lines, or a fixed placeholder when nothing is retrieved.
        """
        matches = self._get_retriever().invoke(question)
        if matches:
            return "\n\n".join(match.page_content for match in matches)
        return "No relevant FAQ context found."

    def _build_messages(self, question: str, conversation_id: str) -> list[BaseMessage]:
        """Build the system + human message pair sent to the model.

        The system message embeds the stored conversation history and the
        retrieved FAQ context; the human message carries the question itself.
        """
        history = self._memory.format_history(conversation_id)
        context = self._format_context(question)
        system_prompt = (
            "You are a concise and helpful support assistant for 9jaLingo, a voice AI platform. "
            "Use only the provided FAQ context and recent conversation history. "
            "If the answer is not in the context, say that clearly and direct the user to official support.\n\n"
            f"Conversation history:\n{history}\n\n"
            f"FAQ context:\n{context}"
        )
        system_message = SystemMessage(content=system_prompt)
        human_message = HumanMessage(content=question)
        return [system_message, human_message]

    def chat(self, question: str, conversation_id: str) -> str:
        """Answer *question* in one shot and record the turn in memory.

        Returns:
            The model's reply as a string.
        """
        prompt = self._build_messages(question, conversation_id)
        reply = self._get_llm().invoke(prompt)
        # Non-string content is coerced with str().
        answer = reply.content if isinstance(reply.content, str) else str(reply.content)
        self._memory.append(conversation_id, question, answer)
        return answer

    def stream(self, question: str, conversation_id: str) -> Generator[str, None, None]:
        """Yield the answer to *question* piece by piece.

        Empty chunks are dropped. Once the model's stream is exhausted, the
        concatenated answer is recorded in memory.
        """
        prompt = self._build_messages(question, conversation_id)
        collected: list[str] = []
        for chunk in self._get_llm().stream(prompt):
            text = chunk.content if isinstance(chunk.content, str) else str(chunk.content)
            if text:
                collected.append(text)
                yield text
        self._memory.append(conversation_id, question, "".join(collected))
# Shared module-level service instance. Constructing it creates a
# ConversationMemory, which reads the memory file (if present) at import time.
chat_service = RagChatService()