# RAG_SPACE_V2 / rag.py
# (HF Space page header preserved as comments — the raw scrape text broke the file)
# NimrodDev's picture
# clean push – no secrets
# 37c33ae
# rag.py (v2 – with filters + fallbacks)
from __future__ import annotations
import os, uuid, tempfile, requests, shutil, re
from pathlib import Path
from functools import lru_cache
from typing import List, Tuple
from datasets import load_dataset
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_huggingface import HuggingFaceEndpoint
from supabase import create_client
# ---------- config ----------
HF_DS = "NimrodDev/LD_Events2"
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "microsoft/DialoGPT-medium"
FAISS_PATH = Path("faiss_db")
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")
supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
# ---------- keyword filters ----------
GREETING_RE = re.compile(r"\b(hi|hello|hey|good morning|good afternoon|good evening)\b", re.I)
THANKS_RE = re.compile(r"\b(thank|thanks|appreciate)\b", re.I)
BYE_RE = re.compile(r"\b(bye|goodbye|see you|later)\b", re.I)
MONEY_RE = re.compile(r"\b(price|cost|budget|cheap|expensive|money|usd|ksh|payment|deposit)\b", re.I)
COMPLAIN_RE = re.compile(r"\b(complain|bad|terrible|awful|disappointed|angry|slow|rude)\b", re.I)
# ---------- company-specific fallbacks ----------
FALLBACKS = {
"LD Events": {
"greeting": "Hello! 👋 Welcome to LD Events – your trusted partner for weddings, graduations and corporate events.",
"money": "Our pricing depends on venue, guest count and package. Please share a few details so we can give you a tailored quote.",
"complain": "We’re sorry to hear this. A senior agent will contact you within 30 minutes to resolve the issue.",
"thanks": "You’re welcome! If you need anything else, just text back.",
"bye": "Thanks for chatting with LD Events. Have a lovely day!",
"default": "I’m not sure about that, but a human agent will follow up shortly."
},
"Lamaki Designs": {
"greeting": "Karibu! 🏗️ Lamaki Designs here – quality construction, architectural plans and project management.",
"money": "Cost varies by project size and materials. Kindly share your plot size / plan so we can estimate for you.",
"complain": "We apologise for the inconvenience. Our site manager will call you within 30 minutes to sort it out.",
"thanks": "Asante! Feel free to text any time.",
"bye": "Good-bye and stay safe!",
"default": "Let me get back to you on that."
}
}
# ---------- helpers ----------
def _company_from_text(text: str) -> str:
t = text.lower()
if any(k in t for k in ("ld events", "event", "wedding", "venue", "graduation")):
return "LD Events"
if any(k in t for k in ("lamaki", "construction", "build", "site", "bungalow", "architect")):
return "Lamaki Designs"
return "LD Events" # default
def _detect_intent(text: str) -> str:
if GREETING_RE.search(text):
return "greeting"
if THANKS_RE.search(text):
return "thanks"
if BYE_RE.search(text):
return "bye"
if MONEY_RE.search(text):
return "money"
if COMPLAIN_RE.search(text):
return "complain"
return "normal"
def _fallback_answer(company: str, intent: str) -> str:
return FALLBACKS[company].get(intent, FALLBACKS[company]["default"])
# ---------- pdf loader ----------
def download_pdfs() -> List[Path]:
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
ds = load_dataset(HF_DS, split="train", streaming=True)
paths = []
for row in ds:
url = row["pdf_url"]
name = row.get("name", uuid.uuid4().hex) + ".pdf"
dest = data_dir / name
if not dest.exists():
r = requests.get(url, stream=True, timeout=30)
r.raise_for_status()
with open(dest, "wb") as f:
shutil.copyfileobj(r.raw, f)
paths.append(dest)
return paths
# ---------- vector store ----------
@lru_cache(maxsize=1)
def get_vectorstore() -> FAISS:
if FAISS_PATH.exists():
return FAISS.load_local(str(FAISS_PATH), HuggingFaceEmbeddings(model_name=EMBED_MODEL),
allow_dangerous_deserialization=True)
docs = []
splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=50)
for pdf in download_pdfs():
text = Path(pdf).read_text(encoding="utf-8", errors="ignore")
docs.extend(splitter.create_documents([text], metadatas=[{"source": pdf.name}]))
vs = FAISS.from_documents(docs, HuggingFaceEmbeddings(model_name=EMBED_MODEL))
vs.save_local(str(FAISS_PATH))
return vs
# ---------- llm ----------
@lru_cache(maxsize=1)
def get_llm():
return HuggingFaceEndpoint(
repo_id=LLM_MODEL,
temperature=0.1,
max_new_tokens=150,
huggingfacehub_api_token=HF_TOKEN
)
PROMPT = PromptTemplate.from_template("""You are Amina, assistant for {company}.
Use only the context below. If unsure, say: “A human agent will follow up.”
Context: {context}
Question: {question}
Answer:""")
# ---------- main entry ----------
def ask_question(phone: str, question: str) -> Tuple[str, List]:
intent = _detect_intent(question)
company = _company_from_text(question)
# short-circuit greetings/thanks/bye – no LLM, no context needed
if intent in ("greeting", "thanks", "bye"):
answer = _fallback_answer(company, intent)
_save_chat(phone, question, answer)
return answer, []
# money or complaints: fallback if no docs
vs = get_vectorstore()
docs = vs.similarity_search(question, k=3)
if not docs:
answer = _fallback_answer(company, intent if intent in ("money", "complain") else "default")
_save_chat(phone, question, answer)
return answer, []
# normal RAG
qa = RetrievalQA.from_chain_type(
llm=get_llm(),
retriever=vs.as_retriever(search_kwargs={"k": 3}),
return_source_documents=True,
chain_type_kwargs={"prompt": PROMPT}
)
result = qa({"query": question, "company": company})
answer = result["result"].strip()
_save_chat(phone, question, answer)
return answer, result.get("source_documents", [])
def _save_chat(phone: str, q: str, a: str) -> None:
supabase.table("chat_memory").insert({"user_phone": phone, "role": "user", "message": q}).execute()
supabase.table("chat_memory").insert({"user_phone": phone, "role": "assistant", "message": a}).execute()