# RAG_SPACE / rag.py — source: NimrodDev's Hugging Face Space (commit f56ceca, message "cmc")
# rag.py — Dual-company RAG pipeline (LD Events + Lamaki Designs)
from __future__ import annotations
import os, re, json, pickle, tempfile
from typing import List, Tuple
from functools import lru_cache
from supabase import create_client
from datasets import load_dataset
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEndpoint
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate
# ---------------------------------------------------------------- CONFIG
# Hugging Face dataset that holds the knowledge-base text chunks.
DATASET_ID = "NimrodDev/LD_Events2"
# Hosted LLM used for answer generation.
LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
# Writable cache directory (Spaces containers may only allow /tmp).
CACHE_DIR = os.getenv("HF_HOME", tempfile.gettempdir())
FAISS_PATH = os.path.join(CACHE_DIR, "faiss_index.pkl")
# Accept either env-var spelling of the Hugging Face API token.
HF_TOKEN = os.getenv("HF_TOKEN", os.getenv("HUGGINGFACEHUB_API_TOKEN", ""))

SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")

# Optional Supabase client for chat logging; stays None when credentials are
# missing or initialization fails, so logging is strictly best-effort.
supabase = None
if SUPABASE_URL and SUPABASE_KEY:
    try:
        supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
        print("✅ Supabase client initialized.")
    except Exception as e:
        print(f"⚠️ Supabase init failed: {e}")

os.makedirs(CACHE_DIR, exist_ok=True)
# Point every HF cache knob at the same writable directory.
os.environ.update({
    "HF_HOME": CACHE_DIR,
    "HF_HUB_CACHE": CACHE_DIR,
    "TRANSFORMERS_CACHE": CACHE_DIR
})
# ---------------------------------------------------------------- INTENT DETECTION
# Pre-compiled, case-insensitive keyword patterns used to classify an incoming
# message before (or instead of) running the full retrieval pipeline.
GREETING_RE = re.compile(r"\b(hi|hello|hey|good morning|good afternoon|good evening)\b", re.IGNORECASE)
THANKS_RE = re.compile(r"\b(thank|thanks|appreciate)\b", re.IGNORECASE)
BYE_RE = re.compile(r"\b(bye|goodbye|see you|later)\b", re.IGNORECASE)
MONEY_RE = re.compile(r"\b(price|cost|budget|cheap|expensive|money|usd|ksh|payment|deposit|fee|quote)\b", re.IGNORECASE)
COMPLAIN_RE = re.compile(r"\b(complain|bad|terrible|awful|disappointed|angry|slow|rude|issue|problem)\b", re.IGNORECASE)
HUMAN_RE = re.compile(r"\b(agent|human|representative|manager|someone|person)\b", re.IGNORECASE)
# ---------------------------------------------------------------- COMPANY FALLBACKS
# Canned replies per company, keyed by the intent labels produced by
# _detect_intent; "default" is the catch-all used by _fallback_answer.
FALLBACKS = {
    "LD Events": {
        "greeting": "Hello! 👋 I’m *Amina*, your assistant for **LD Events** (weddings, graduations, corporate events) and **Lamaki Designs** (construction & architecture). How may I help you today?",
        "money": "Our event packages vary depending on venue and number of guests. Could you share a few details so we can estimate a quote?",
        "complain": "I’m sorry to hear that 😔. I’ll alert our support team — expect a call from a senior agent shortly.",
        "thanks": "You’re most welcome! 💐",
        "bye": "Thanks for chatting with LD Events. Have a beautiful day!",
        "handoff": "Sure! I’ll connect you to a human representative now. Please hold on a moment.",
        "default": "Let me get back to you on that. I’ve forwarded your question to a senior planner.",
    },
    "Lamaki Designs": {
        "greeting": "Karibu! 🏗️ I’m *Amina*, assistant for **Lamaki Designs** (construction, architectural plans, project management) and **LD Events** (weddings, graduations, corporate events). How may I assist?",
        "money": "Construction costs depend on project scope and materials. Kindly share your plot size or design type for an accurate estimate.",
        "complain": "We’re truly sorry for the inconvenience. Our site supervisor will reach out within 30 minutes to help.",
        "thanks": "Asante! We appreciate your time.",
        "bye": "Goodbye 👋 and thank you for trusting Lamaki Designs.",
        "handoff": "No problem. A Lamaki Designs representative will join the chat soon.",
        "default": "Let me get back to you on that — I’ll forward this to our design team.",
    },
}
# ---------------------------------------------------------------- HELPERS
def _company_from_text(text: str) -> str:
t = text.lower()
if any(k in t for k in ("lamaki", "construction", "architect", "plan", "bungalow", "site", "building")):
return "Lamaki Designs"
if any(k in t for k in ("ld events", "event", "wedding", "graduation", "venue", "party")):
return "LD Events"
return "LD Events"
def _detect_intent(text: str) -> str:
    """Classify *text* into a coarse intent label; the first matching pattern wins."""
    checks = (
        (GREETING_RE, "greeting"),
        (THANKS_RE, "thanks"),
        (BYE_RE, "bye"),
        (MONEY_RE, "money"),
        (COMPLAIN_RE, "complain"),
        (HUMAN_RE, "handoff"),
    )
    for pattern, label in checks:
        if pattern.search(text):
            return label
    return "normal"
def _fallback_answer(company: str, intent: str) -> str:
    """Return the canned reply for *intent*, falling back to the company's default line."""
    replies = FALLBACKS[company]
    return replies.get(intent, replies["default"])
# ---------------------------------------------------------------- DATA FETCH
@lru_cache(maxsize=1)
def get_texts() -> List[str]:
    """Download the dataset once and return its non-empty "text" fields.

    Best-effort: any load failure is logged and an empty list is returned so
    the caller can fall back to a dummy index.
    """
    try:
        print("🔍 Loading dataset from Parquet...")
        dataset = load_dataset(DATASET_ID, split="train", revision="refs/convert/parquet")
        chunks = [str(row["text"]) for row in dataset if row.get("text")]
        print(f"✅ Loaded {len(chunks)} text chunks from {DATASET_ID}")
        return chunks
    except Exception as e:
        print(f"⚠️ Dataset load failed: {e}")
        return []
# ---------------------------------------------------------------- VECTORSTORE
# Single place for the embedding model name (was duplicated inline twice).
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

@lru_cache(maxsize=1)
def get_vectorstore() -> FAISS:
    """Return the FAISS vector store, loading it from the pickle cache when possible.

    Falls back to a one-document dummy index when the dataset is unavailable,
    so downstream retrieval never crashes.
    """
    # NOTE(security): pickle.load executes arbitrary code on malicious input;
    # tolerable only because FAISS_PATH lives in our own cache dir. Consider
    # FAISS.save_local/load_local instead of raw pickle.
    if os.path.exists(FAISS_PATH):
        try:
            with open(FAISS_PATH, "rb") as f:
                print("📦 Using cached FAISS index.")
                return pickle.load(f)
        except Exception as e:
            print(f"⚠️ Failed to load FAISS cache: {e}, rebuilding...")

    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
    texts = get_texts()
    if not texts:
        print("⚠️ No dataset found; using dummy FAISS index.")
        return FAISS.from_texts(["No context available."], embeddings)

    splitter = RecursiveCharacterTextSplitter(chunk_size=700, chunk_overlap=100)
    docs = splitter.create_documents(texts)
    vs = FAISS.from_documents(docs, embeddings)
    try:
        # Best-effort cache write: a failed dump must not break the request.
        with open(FAISS_PATH, "wb") as f:
            pickle.dump(vs, f)
        print("✅ FAISS index created and cached.")
    except Exception as e:
        print(f"⚠️ Failed to cache FAISS index: {e}")
    return vs
# ---------------------------------------------------------------- LLM
@lru_cache(maxsize=1)
def get_llm():
    """Build the hosted Mistral endpoint once; raises ValueError when no API token is set."""
    if not HF_TOKEN:
        raise ValueError("Hugging Face token missing! Please set HF_TOKEN or HUGGINGFACEHUB_API_TOKEN.")
    endpoint_kwargs = dict(
        repo_id=LLM_MODEL,
        huggingfacehub_api_token=HF_TOKEN,
        temperature=0.3,
        max_new_tokens=300,
    )
    return HuggingFaceEndpoint(**endpoint_kwargs)
# Prompt shared by both companies; {company} is bound per request via
# PROMPT.partial(...), {context} and {question} by the retrieval chain.
_PROMPT_TEMPLATE = """
You are Amina, a friendly virtual assistant for {company}.
Use the context below to answer questions concisely and politely.
If unsure, say: "Let me get back to you on that."
Context:
{context}
Question:
{question}
Answer:
"""
PROMPT = PromptTemplate.from_template(_PROMPT_TEMPLATE)
# ---------------------------------------------------------------- MAIN CHAT LOGIC
def ask_question(phone: str, question: str) -> Tuple[str, List]:
    """Answer *question* for the user identified by *phone*.

    Cheap intents (greeting/thanks/bye/handoff) get a canned reply without
    touching the retrieval pipeline; everything else goes through the RAG
    chain, with canned fallbacks when the model fails or answers too briefly.
    Returns (answer, source_documents); the exchange is logged best-effort.
    """
    intent = _detect_intent(question)
    company = _company_from_text(question)

    # Fast path: canned replies skip the expensive retrieval pipeline.
    if intent in ("greeting", "thanks", "bye", "handoff"):
        reply = _fallback_answer(company, intent)
        _save_chat(phone, question, reply)
        return reply, []

    chain = RetrievalQA.from_chain_type(
        llm=get_llm(),
        retriever=get_vectorstore().as_retriever(search_kwargs={"k": 4}),
        chain_type_kwargs={"prompt": PROMPT.partial(company=company)},
        return_source_documents=True,
    )

    answer, docs = "", []
    try:
        result = chain({"query": question})
    except Exception as e:
        print(f"❌ QA pipeline error: {e}")
    else:
        answer = result.get("result", "").strip()
        docs = result.get("source_documents", [])

    # Smart fallback (pricing, complaints, or missing/too-short answers).
    if not answer or len(answer.split()) < 4:
        fallback_intent = intent if intent in ("money", "complain") else "default"
        answer = _fallback_answer(company, fallback_intent)

    _save_chat(phone, question, answer)
    return answer, docs
# ---------------------------------------------------------------- SUPABASE LOGGING
def _save_chat(phone: str, q: str, a: str) -> None:
    """Best-effort: log both sides of the exchange to Supabase (no-op without a client)."""
    if not supabase:
        return
    rows = [
        {"user_phone": phone, "role": "user", "message": q},
        {"user_phone": phone, "role": "assistant", "message": a},
    ]
    try:
        supabase.table("chat_memory").insert(rows).execute()
    except Exception as e:
        # Logging must never take down the chat flow.
        print(f"⚠️ Chat log failed: {e}")