Spaces:
No application file
No application file
File size: 6,703 Bytes
37c33ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
# rag.py (v2 – with filters + fallbacks)
from __future__ import annotations
import os, uuid, tempfile, requests, shutil, re
from pathlib import Path
from functools import lru_cache
from typing import List, Tuple
from datasets import load_dataset
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_huggingface import HuggingFaceEndpoint
from supabase import create_client
# ---------- config ----------
# Hugging Face dataset that lists the company PDFs to index.
HF_DS = "NimrodDev/LD_Events2"
# Sentence-transformer used for embedding chunks into the FAISS index.
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Hosted model backing the RetrievalQA chain.
LLM_MODEL = "microsoft/DialoGPT-medium"
# On-disk location of the persisted FAISS index.
FAISS_PATH = Path("faiss_db")
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")
# NOTE(review): module-level side effect — this raises at import time when the
# SUPABASE_* env vars are unset (create_client(None, None) fails). Confirm the
# deploy environment always provides them, or move creation behind a getter.
supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
# ---------- keyword filters ----------
# Case-insensitive keyword regexes used by _detect_intent; the order they are
# checked in (greeting, thanks, bye, money, complain) is defined there.
GREETING_RE = re.compile(r"\b(hi|hello|hey|good morning|good afternoon|good evening)\b", re.I)
THANKS_RE = re.compile(r"\b(thank|thanks|appreciate)\b", re.I)
BYE_RE = re.compile(r"\b(bye|goodbye|see you|later)\b", re.I)
MONEY_RE = re.compile(r"\b(price|cost|budget|cheap|expensive|money|usd|ksh|payment|deposit)\b", re.I)
COMPLAIN_RE = re.compile(r"\b(complain|bad|terrible|awful|disappointed|angry|slow|rude)\b", re.I)
# ---------- company-specific fallbacks ----------
# Canned replies per company, keyed by intent label returned by _detect_intent.
# Every company entry must include a "default" key — _fallback_answer relies
# on it for unknown intents. Reply text is user-facing; do not edit casually.
FALLBACKS = {
"LD Events": {
"greeting": "Hello! 👋 Welcome to LD Events – your trusted partner for weddings, graduations and corporate events.",
"money": "Our pricing depends on venue, guest count and package. Please share a few details so we can give you a tailored quote.",
"complain": "We’re sorry to hear this. A senior agent will contact you within 30 minutes to resolve the issue.",
"thanks": "You’re welcome! If you need anything else, just text back.",
"bye": "Thanks for chatting with LD Events. Have a lovely day!",
"default": "I’m not sure about that, but a human agent will follow up shortly."
},
"Lamaki Designs": {
"greeting": "Karibu! 🏗️ Lamaki Designs here – quality construction, architectural plans and project management.",
"money": "Cost varies by project size and materials. Kindly share your plot size / plan so we can estimate for you.",
"complain": "We apologise for the inconvenience. Our site manager will call you within 30 minutes to sort it out.",
"thanks": "Asante! Feel free to text any time.",
"bye": "Good-bye and stay safe!",
"default": "Let me get back to you on that."
}
}
# ---------- helpers ----------
def _company_from_text(text: str) -> str:
t = text.lower()
if any(k in t for k in ("ld events", "event", "wedding", "venue", "graduation")):
return "LD Events"
if any(k in t for k in ("lamaki", "construction", "build", "site", "bungalow", "architect")):
return "Lamaki Designs"
return "LD Events" # default
def _detect_intent(text: str) -> str:
    """Classify a message into a coarse intent via the keyword regexes.

    Patterns are tried in priority order; the first hit wins. Messages that
    match nothing are labelled "normal".
    """
    checks = (
        (GREETING_RE, "greeting"),
        (THANKS_RE, "thanks"),
        (BYE_RE, "bye"),
        (MONEY_RE, "money"),
        (COMPLAIN_RE, "complain"),
    )
    for pattern, label in checks:
        if pattern.search(text):
            return label
    return "normal"
def _fallback_answer(company: str, intent: str) -> str:
    """Return the canned reply for *company*; unknown intents fall back to its "default" entry."""
    replies = FALLBACKS[company]
    return replies[intent] if intent in replies else replies["default"]
# ---------- pdf loader ----------
def download_pdfs() -> List[Path]:
    """Download every PDF listed in the HF_DS dataset into ./data.

    Files already present on disk are not re-fetched. Returns the local
    paths of all dataset files (downloaded now or previously).

    Raises:
        requests.HTTPError: if a download responds with an error status.
    """
    data_dir = Path("data")
    data_dir.mkdir(exist_ok=True)
    ds = load_dataset(HF_DS, split="train", streaming=True)
    paths: List[Path] = []
    for row in ds:
        url = row["pdf_url"]
        # Rows without a "name" get a random one so they still land on disk.
        name = row.get("name", uuid.uuid4().hex) + ".pdf"
        dest = data_dir / name
        if not dest.exists():
            # Stream to a temp file and rename atomically: the original wrote
            # straight to dest, so an interrupted transfer left a truncated
            # file that the exists() check would then skip forever.
            tmp = dest.with_suffix(".pdf.part")
            with requests.get(url, stream=True, timeout=30) as r:
                r.raise_for_status()
                with open(tmp, "wb") as f:
                    shutil.copyfileobj(r.raw, f)
            tmp.replace(dest)
        paths.append(dest)
    return paths
# ---------- vector store ----------
@lru_cache(maxsize=1)
def get_vectorstore() -> FAISS:
    """Load the persisted FAISS index, or build and persist it from the PDFs.

    Cached (maxsize=1) so the index is constructed at most once per process.
    """
    # Hoisted: the original built a separate HuggingFaceEmbeddings instance
    # in each branch for the same model.
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
    if FAISS_PATH.exists():
        return FAISS.load_local(str(FAISS_PATH), embeddings,
                                allow_dangerous_deserialization=True)
    docs = []
    splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=50)
    for pdf in download_pdfs():
        # NOTE(review): these files are PDFs, but read_text() decodes them as
        # UTF-8 with errors ignored, so the indexed "text" is largely binary
        # noise unless the dataset's files are actually plain text despite the
        # .pdf suffix — confirm, or use a real PDF extractor (e.g. pypdf).
        text = pdf.read_text(encoding="utf-8", errors="ignore")
        docs.extend(splitter.create_documents([text], metadatas=[{"source": pdf.name}]))
    vs = FAISS.from_documents(docs, embeddings)
    vs.save_local(str(FAISS_PATH))
    return vs
# ---------- llm ----------
@lru_cache(maxsize=1)
def get_llm():
    """Build the hosted HF inference endpoint (constructed once, then cached)."""
    endpoint = HuggingFaceEndpoint(
        repo_id=LLM_MODEL,
        max_new_tokens=150,
        temperature=0.1,
        huggingfacehub_api_token=HF_TOKEN,
    )
    return endpoint
# RAG prompt. Variables: {company} must be bound by the caller (e.g. via
# PROMPT.partial); {context} and {question} are filled by the QA chain.
PROMPT = PromptTemplate.from_template("""You are Amina, assistant for {company}.
Use only the context below. If unsure, say: “A human agent will follow up.”
Context: {context}
Question: {question}
Answer:""")
# ---------- main entry ----------
def ask_question(phone: str, question: str) -> Tuple[str, List]:
    """Answer a user message, logging the exchange to Supabase.

    Args:
        phone: user phone number used as the chat-memory key.
        question: raw user message text.

    Returns:
        (answer, source_documents) — source_documents is empty for
        fallback/short-circuit answers.
    """
    intent = _detect_intent(question)
    company = _company_from_text(question)
    # short-circuit greetings/thanks/bye – no LLM, no context needed
    if intent in ("greeting", "thanks", "bye"):
        answer = _fallback_answer(company, intent)
        _save_chat(phone, question, answer)
        return answer, []
    # money or complaints: fallback if no docs
    vs = get_vectorstore()
    docs = vs.similarity_search(question, k=3)
    if not docs:
        answer = _fallback_answer(company, intent if intent in ("money", "complain") else "default")
        _save_chat(phone, question, answer)
        return answer, []
    # normal RAG. Bug fix: RetrievalQA only maps its "query" input into the
    # prompt's {question} (plus retrieved {context}); the original passed
    # "company" in the call dict, which never reaches the template and leaves
    # {company} unbound. Bind it up front with partial() instead.
    qa = RetrievalQA.from_chain_type(
        llm=get_llm(),
        retriever=vs.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": PROMPT.partial(company=company)}
    )
    result = qa({"query": question})
    answer = result["result"].strip()
    _save_chat(phone, question, answer)
    return answer, result.get("source_documents", [])
def _save_chat(phone: str, q: str, a: str) -> None:
    """Persist one question/answer exchange as two rows in chat_memory."""
    for role, message in (("user", q), ("assistant", a)):
        supabase.table("chat_memory").insert(
            {"user_phone": phone, "role": role, "message": message}
        ).execute()