File size: 6,528 Bytes
9f031f6 99db407 9f031f6 99db407 9f031f6 0d06c0b c2659c1 99db407 c2659c1 9f031f6 99db407 9f031f6 99db407 9f031f6 99db407 9f031f6 0d06c0b c2659c1 0d06c0b 9f031f6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 | from __future__ import annotations
import json
import os
from collections import defaultdict, deque
from collections.abc import Generator
from dataclasses import dataclass
from pathlib import Path
from threading import Lock
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_ollama import ChatOllama
from src.ingest import get_or_build_vectorstore
# Runtime configuration; every value can be overridden through an environment variable.
# Maximum number of user/assistant exchanges remembered per conversation.
MAX_MEMORY_TURNS = int(os.environ.get("RAG_MEMORY_TURNS", "6"))
# Ollama model tag used to generate answers.
LLM_MODEL = os.environ.get("LLM_MODEL", "hf.co/LiquidAI/LFM2-1.2B-RAG-GGUF:Q5_K_M")
# Address of the Ollama server.
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
# Optional bearer token: the first non-empty value among these variables wins.
OLLAMA_AUTH_TOKEN = (
    os.environ.get("OLLAMA_AUTH_TOKEN")
    or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    or os.environ.get("HF_TOKEN")
)
# JSONL file where conversation turns are persisted.
MEMORY_FILE = os.environ.get("RAG_MEMORY_FILE", "data/conversation_memory.jsonl")
def _ollama_client_kwargs() -> dict:
    """Build the extra keyword arguments passed to the Ollama HTTP client.

    Returns:
        ``{"headers": {"Authorization": "Bearer <token>"}}`` when
        ``OLLAMA_AUTH_TOKEN`` is non-empty, otherwise an empty dict.
    """
    token = OLLAMA_AUTH_TOKEN
    if token:
        return {"headers": {"Authorization": f"Bearer {token}"}}
    return {}
@dataclass
class MemoryTurn:
    """A single completed exchange: what the user said and what the assistant replied."""

    # Message text sent by the user.
    user_message: str
    # Reply text produced by the assistant.
    assistant_message: str
class ConversationMemory:
    """Rolling per-conversation chat history backed by an append-only JSONL file.

    In memory, each conversation id maps to a deque bounded at ``max_turns``,
    so appending past that limit evicts the oldest turn. Every appended turn is
    also written to ``storage_path`` as one JSON object per line, and that file
    is replayed on construction so recent history survives a restart.

    NOTE: nothing in this class compacts the file; it grows by one line per
    turn and is read in full on every construction.
    """

    def __init__(self, max_turns: int = MAX_MEMORY_TURNS, storage_path: str = MEMORY_FILE) -> None:
        self._max_turns = max_turns
        self._store: dict[str, deque[MemoryTurn]] = defaultdict(lambda: deque(maxlen=self._max_turns))
        self._storage_path = Path(storage_path)
        # Guards the in-memory store as well as file appends, so concurrent
        # callers never iterate a deque that is being mutated, and the on-disk
        # order of turns matches the in-memory order.
        self._write_lock = Lock()
        self._load_from_disk()

    def _load_from_disk(self) -> None:
        """Replay persisted turns from the JSONL file into the in-memory store.

        A missing file means there is nothing to load. Lines are handled one
        at a time. Blank lines, lines that are not valid JSON or not a JSON
        object, and entries lacking a non-empty user or assistant message are
        skipped. One corrupt line, such as a partially written record, therefore
        no longer discards the rest of the history. Any other failure while
        reading (e.g. I/O or decoding errors) is reported and loading stops;
        this is best-effort.
        """
        if not self._storage_path.exists():
            return
        try:
            with self._storage_path.open("r", encoding="utf-8") as f:
                for line_no, raw_line in enumerate(f, start=1):
                    line = raw_line.strip()
                    if not line:
                        continue
                    try:
                        item = json.loads(line)
                    except json.JSONDecodeError as e:
                        print(f"[memory] Skipping malformed memory line {line_no}: {e}")
                        continue
                    if not isinstance(item, dict):
                        print(f"[memory] Skipping non-object memory line {line_no}")
                        continue
                    conversation_id = str(item.get("conversation_id", "default"))
                    user_message = str(item.get("user_message", ""))
                    assistant_message = str(item.get("assistant_message", ""))
                    if user_message and assistant_message:
                        self._store[conversation_id].append(
                            MemoryTurn(user_message=user_message, assistant_message=assistant_message)
                        )
        except Exception as e:
            # Best-effort boundary: an unreadable memory file must not prevent startup.
            print(f"[memory] Failed to load memory file: {e}")

    def _append_to_disk(self, conversation_id: str, turn: MemoryTurn) -> None:
        """Append one turn to the JSONL file as a single line.

        The caller must already hold ``self._write_lock``. File-system errors
        are reported and swallowed, so a persistence failure never breaks the
        chat flow.
        """
        payload = {
            "conversation_id": conversation_id,
            "user_message": turn.user_message,
            "assistant_message": turn.assistant_message,
        }
        try:
            self._storage_path.parent.mkdir(parents=True, exist_ok=True)
            with self._storage_path.open("a", encoding="utf-8") as f:
                f.write(json.dumps(payload, ensure_ascii=True) + "\n")
        except OSError as e:
            print(f"[memory] Failed to persist memory turn: {e}")

    def append(self, conversation_id: str, user_message: str, assistant_message: str) -> None:
        """Record a completed exchange in memory and persist it to disk.

        Turns with an empty message are still stored and written here, but
        ``_load_from_disk`` skips them when the file is replayed.
        """
        turn = MemoryTurn(user_message=user_message, assistant_message=assistant_message)
        with self._write_lock:
            self._store[conversation_id].append(turn)
            self._append_to_disk(conversation_id, turn)

    def format_history(self, conversation_id: str) -> str:
        """Render the remembered turns of a conversation as prompt text.

        Returns:
            Alternating ``User: ...`` / ``Assistant: ...`` lines joined by
            newlines, or ``"No previous conversation."`` when nothing is stored.
            Unknown ids do not create an entry (``dict.get`` is used).
        """
        with self._write_lock:
            # Snapshot under the lock; iterating a deque while another thread
            # appends to it raises RuntimeError.
            history = self._store.get(conversation_id)
            turns = list(history) if history else []
        if not turns:
            return "No previous conversation."
        lines: list[str] = []
        for turn in turns:
            lines.append(f"User: {turn.user_message}")
            lines.append(f"Assistant: {turn.assistant_message}")
        return "\n".join(lines)
class RagChatService:
    """Answer support questions with retrieval-augmented generation.

    For each question the service retrieves the top-``k`` FAQ chunks and
    combines them with the conversation's recent history in a system prompt.
    It then queries a ChatOllama model. The vector store, retriever, and LLM
    client are created lazily on first use. Each finished exchange is recorded
    in a ConversationMemory.
    """

    def __init__(self, k: int = 4) -> None:
        self._k = k
        # Built on demand by _get_retriever / _get_llm.
        self._vectorstore = None
        self._retriever = None
        self._llm = None
        self._memory = ConversationMemory()

    @staticmethod
    def _as_text(content) -> str:
        """Return message content unchanged if it is a str, else its str() form."""
        return content if isinstance(content, str) else str(content)

    def _get_retriever(self):
        """Return the retriever, loading or building the vector store on first call."""
        if self._retriever is not None:
            return self._retriever
        self._vectorstore = get_or_build_vectorstore()
        self._retriever = self._vectorstore.as_retriever(search_kwargs={"k": self._k})
        return self._retriever

    def _get_llm(self) -> ChatOllama:
        """Return the ChatOllama client, constructing it on first call."""
        if self._llm is not None:
            return self._llm
        self._llm = ChatOllama(
            model=LLM_MODEL,
            temperature=0.2,
            base_url=OLLAMA_BASE_URL,
            client_kwargs=_ollama_client_kwargs(),
        )
        return self._llm

    def _format_context(self, question: str) -> str:
        """Retrieve documents for *question* and join their text with blank lines."""
        retriever = self._get_retriever()
        documents = retriever.invoke(question)
        if documents:
            return "\n\n".join([document.page_content for document in documents])
        return "No relevant FAQ context found."

    def _build_messages(self, question: str, conversation_id: str) -> list[BaseMessage]:
        """Assemble the system prompt (history + FAQ context) and the user message."""
        # History is read before retrieval, matching the original call order.
        history = self._memory.format_history(conversation_id)
        context = self._format_context(question)
        system_prompt = (
            "You are a concise and helpful support assistant for 9jaLingo, a voice AI platform. "
            "Use only the provided FAQ context and recent conversation history. "
            "If the answer is not in the context, say that clearly and direct the user to official support.\n\n"
            f"Conversation history:\n{history}\n\n"
            f"FAQ context:\n{context}"
        )
        system = SystemMessage(content=system_prompt)
        user = HumanMessage(content=question)
        return [system, user]

    def chat(self, question: str, conversation_id: str) -> str:
        """Return a complete answer to *question* and record the exchange."""
        prompt = self._build_messages(question, conversation_id)
        reply = self._get_llm().invoke(prompt)
        answer = self._as_text(reply.content)
        self._memory.append(conversation_id, question, answer)
        return answer

    def stream(self, question: str, conversation_id: str) -> Generator[str, None, None]:
        """Yield non-empty answer fragments as they arrive.

        The concatenated answer is recorded only after the model stream is
        fully consumed.
        """
        prompt = self._build_messages(question, conversation_id)
        pieces: list[str] = []
        for chunk in self._get_llm().stream(prompt):
            text = self._as_text(chunk.content)
            if text:
                pieces.append(text)
                yield text
        self._memory.append(conversation_id, question, "".join(pieces))
# Module-level service instance shared by importers. Constructing it builds a
# ConversationMemory, which reads the persisted memory file at import time.
chat_service = RagChatService()