| from __future__ import annotations
|
|
|
| import json
|
| import os
|
| from collections import defaultdict, deque
|
| from collections.abc import Generator
|
| from dataclasses import dataclass
|
| from pathlib import Path
|
| from threading import Lock
|
|
|
| from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
| from langchain_ollama import ChatOllama
|
|
|
| from src.ingest import get_or_build_vectorstore
|
|
|
|
|
# Runtime configuration, resolved once at import time from environment
# variables. Each setting falls back to the default shown when unset.

# Default cap on the number of turns kept in memory per conversation.
MAX_MEMORY_TURNS = int(os.environ.get("RAG_MEMORY_TURNS", "6"))

# Model identifier handed to ChatOllama.
LLM_MODEL = os.environ.get("LLM_MODEL", "hf.co/LiquidAI/LFM2-1.2B-RAG-GGUF:Q5_K_M")

# Base URL of the Ollama server.
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")

# Optional bearer token; the first non-empty variable in this chain wins.
OLLAMA_AUTH_TOKEN = (
    os.environ.get("OLLAMA_AUTH_TOKEN")
    or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    or os.environ.get("HF_TOKEN")
)

# Default path of the JSONL file that conversation turns are appended to.
MEMORY_FILE = os.environ.get("RAG_MEMORY_FILE", "data/conversation_memory.jsonl")
|
|
|
|
|
def _ollama_client_kwargs() -> dict:
    """Return extra keyword arguments for the Ollama HTTP client.

    When ``OLLAMA_AUTH_TOKEN`` is set, the result carries an
    ``Authorization: Bearer <token>`` header; otherwise it is an empty dict.
    """
    if OLLAMA_AUTH_TOKEN:
        bearer = f"Bearer {OLLAMA_AUTH_TOKEN}"
        return {"headers": {"Authorization": bearer}}
    return {}
|
|
|
|
|
@dataclass
class MemoryTurn:
    """One exchange in a conversation: a user message and the assistant's reply."""

    # Text sent by the user.
    user_message: str
    # Text the assistant answered with.
    assistant_message: str
|
|
|
|
|
class ConversationMemory:
    """Per-conversation chat history, bounded in memory and appended to a JSONL file.

    Each conversation keeps at most ``max_turns`` turns in a ``deque``; older
    turns are evicted automatically. Every appended turn is also written as
    one JSON object per line to ``storage_path``, and that file is replayed
    on construction. Disk I/O is best-effort: failures are printed, never
    raised. File appends are serialized with a lock; the in-memory store is
    not guarded by that lock.
    """

    def __init__(self, max_turns: int = MAX_MEMORY_TURNS, storage_path: str = MEMORY_FILE) -> None:
        """Create the store and replay any turns already persisted on disk.

        Args:
            max_turns: Maximum number of turns retained per conversation.
            storage_path: Path of the JSONL file used for persistence.
        """
        self._max_turns = max_turns
        self._store: dict[str, deque[MemoryTurn]] = defaultdict(lambda: deque(maxlen=self._max_turns))
        self._storage_path = Path(storage_path)
        self._write_lock = Lock()
        self._load_from_disk()

    @staticmethod
    def _parse_record(raw: str) -> tuple[str, MemoryTurn] | None:
        """Decode one JSONL line into ``(conversation_id, turn)``.

        Returns:
            The parsed pair, or ``None`` when the record lacks a non-empty
            user or assistant message.

        Raises:
            ValueError: If the line is not valid JSON or is not a JSON object
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        item = json.loads(raw)
        if not isinstance(item, dict):
            raise ValueError(f"expected a JSON object, got {type(item).__name__}")

        conversation_id = str(item.get("conversation_id", "default"))
        # str(None) would yield the literal text "None", so treat JSON null
        # the same as a missing field.
        user_message, assistant_message = (
            "" if value is None else str(value)
            for value in (item.get("user_message"), item.get("assistant_message"))
        )
        if not (user_message and assistant_message):
            return None
        return conversation_id, MemoryTurn(user_message=user_message, assistant_message=assistant_message)

    def _load_from_disk(self) -> None:
        """Replay persisted turns into the in-memory store.

        A missing file is not an error. Malformed lines are reported and
        skipped individually so one corrupt record does not discard the rest
        of the history. Any other failure is printed and loading stops.
        """
        try:
            with self._storage_path.open("r", encoding="utf-8") as f:
                for line_number, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        record = self._parse_record(line)
                    except ValueError as e:
                        print(f"[memory] Skipping malformed memory line {line_number}: {e}")
                        continue
                    if record is None:
                        continue
                    conversation_id, turn = record
                    self._store[conversation_id].append(turn)
        except FileNotFoundError:
            # Nothing has been persisted yet.
            return
        except Exception as e:
            # Deliberately broad: loading history is best-effort and must not
            # prevent the service from starting.
            print(f"[memory] Failed to load memory file: {e}")

    def _append_to_disk(self, conversation_id: str, turn: MemoryTurn) -> None:
        """Append one turn to the JSONL file, creating parent directories as needed.

        Failures are printed rather than raised so that a persistence problem
        does not break the caller.
        """
        try:
            self._storage_path.parent.mkdir(parents=True, exist_ok=True)
            payload = {
                "conversation_id": conversation_id,
                "user_message": turn.user_message,
                "assistant_message": turn.assistant_message,
            }
            # Serialize writers so concurrent appends cannot interleave lines.
            with self._write_lock:
                with self._storage_path.open("a", encoding="utf-8") as f:
                    f.write(json.dumps(payload, ensure_ascii=True) + "\n")
        except Exception as e:
            # Deliberately broad: persistence is best-effort.
            print(f"[memory] Failed to persist memory turn: {e}")

    def append(self, conversation_id: str, user_message: str, assistant_message: str) -> None:
        """Record a completed turn in memory and persist it to disk."""
        turn = MemoryTurn(user_message=user_message, assistant_message=assistant_message)
        self._store[conversation_id].append(turn)
        self._append_to_disk(conversation_id, turn)

    def format_history(self, conversation_id: str) -> str:
        """Render the retained turns of a conversation as ``User:``/``Assistant:`` lines.

        Returns ``"No previous conversation."`` when nothing is stored for
        ``conversation_id``.
        """
        # .get() avoids creating an empty deque in the defaultdict for
        # conversations that have no history yet.
        history = self._store.get(conversation_id)
        if not history:
            return "No previous conversation."

        lines: list[str] = []
        for turn in history:
            lines.append(f"User: {turn.user_message}")
            lines.append(f"Assistant: {turn.assistant_message}")
        return "\n".join(lines)
|
|
|
|
|
class RagChatService:
    """Retrieval-augmented chat service with per-conversation memory.

    The vector store retriever and the chat model are created lazily on first
    use. Each completed question/answer pair is recorded in a
    ``ConversationMemory``.
    """

    def __init__(self, k: int = 4) -> None:
        """Initialise the service.

        Args:
            k: Number of documents requested from the retriever per question.
        """
        self._k = k
        self._vectorstore = None
        self._retriever = None
        self._llm = None
        self._memory = ConversationMemory()

    @staticmethod
    def _content_to_text(content: object) -> str:
        """Flatten LangChain message content into plain text.

        Message content may be a string or a list of content blocks (plain
        strings or dicts). Calling ``str()`` on a list would leak a Python
        repr into the answer, so the text blocks are joined instead and
        non-text blocks are ignored.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            pieces: list[str] = []
            for block in content:
                if isinstance(block, str):
                    pieces.append(block)
                elif (
                    isinstance(block, dict)
                    and block.get("type") == "text"
                    and isinstance(block.get("text"), str)
                ):
                    pieces.append(block["text"])
            return "".join(pieces)
        return str(content)

    def _get_retriever(self):
        """Return the retriever, building the vector store on first call."""
        if self._retriever is None:
            self._vectorstore = get_or_build_vectorstore()
            self._retriever = self._vectorstore.as_retriever(search_kwargs={"k": self._k})
        return self._retriever

    def _get_llm(self) -> ChatOllama:
        """Return the chat model, creating it on first call."""
        if self._llm is None:
            self._llm = ChatOllama(
                model=LLM_MODEL,
                temperature=0.2,
                base_url=OLLAMA_BASE_URL,
                client_kwargs=_ollama_client_kwargs(),
            )
        return self._llm

    def _format_context(self, question: str) -> str:
        """Retrieve documents for ``question`` and join their text with blank lines."""
        docs = self._get_retriever().invoke(question)
        if not docs:
            return "No relevant FAQ context found."
        return "\n\n".join(doc.page_content for doc in docs)

    def _build_messages(self, question: str, conversation_id: str) -> list[BaseMessage]:
        """Build the system and human messages for one question.

        The system prompt embeds the formatted conversation history and the
        retrieved FAQ context.
        """
        history = self._memory.format_history(conversation_id)
        context = self._format_context(question)
        system_prompt = (
            "You are a concise and helpful support assistant for 9jaLingo, a voice AI platform. "
            "Use only the provided FAQ context and recent conversation history. "
            "If the answer is not in the context, say that clearly and direct the user to official support.\n\n"
            f"Conversation history:\n{history}\n\n"
            f"FAQ context:\n{context}"
        )
        return [
            SystemMessage(content=system_prompt),
            HumanMessage(content=question),
        ]

    def chat(self, question: str, conversation_id: str) -> str:
        """Answer ``question`` in one call and record the turn in memory.

        Returns:
            The assistant's answer as plain text.
        """
        messages = self._build_messages(question, conversation_id)
        response = self._get_llm().invoke(messages)
        answer = self._content_to_text(response.content)
        self._memory.append(conversation_id, question, answer)
        return answer

    def stream(self, question: str, conversation_id: str) -> Generator[str, None, None]:
        """Yield the answer to ``question`` incrementally as text chunks.

        Empty chunks are skipped. The turn is recorded in memory only after
        the model's stream has been fully consumed.
        """
        messages = self._build_messages(question, conversation_id)
        parts: list[str] = []
        for chunk in self._get_llm().stream(messages):
            content = self._content_to_text(chunk.content)
            if not content:
                continue
            parts.append(content)
            yield content

        self._memory.append(conversation_id, question, "".join(parts))
|
|
|
|
|
# Shared module-level service instance. Constructing it also creates its
# ConversationMemory, which replays any persisted turns from disk.
chat_service = RagChatService()