Spaces:
No application file
No application file
| # rag.py (v2 – with filters + fallbacks) | |
| from __future__ import annotations | |
| import os, uuid, tempfile, requests, shutil, re | |
| from pathlib import Path | |
| from functools import lru_cache | |
| from typing import List, Tuple | |
| from datasets import load_dataset | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain.chains import RetrievalQA | |
| from langchain_huggingface import HuggingFaceEndpoint | |
| from supabase import create_client | |
# ---------- config ----------
HF_DS = "NimrodDev/LD_Events2"                            # HF dataset listing company PDFs
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"    # embedding model for FAISS
LLM_MODEL = "microsoft/DialoGPT-medium"                   # hosted generation model
FAISS_PATH = Path("faiss_db")                             # on-disk FAISS index cache

# Credentials are read from the environment at import time.
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

# Module-level client used by _save_chat; created once at import.
supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
# ---------- keyword filters ----------
def _kw_pattern(*words: str) -> "re.Pattern[str]":
    """Compile a case-insensitive whole-word alternation for *words*."""
    return re.compile(r"\b(" + "|".join(words) + r")\b", re.I)

GREETING_RE = _kw_pattern("hi", "hello", "hey", "good morning", "good afternoon", "good evening")
THANKS_RE = _kw_pattern("thank", "thanks", "appreciate")
BYE_RE = _kw_pattern("bye", "goodbye", "see you", "later")
MONEY_RE = _kw_pattern("price", "cost", "budget", "cheap", "expensive", "money", "usd", "ksh", "payment", "deposit")
COMPLAIN_RE = _kw_pattern("complain", "bad", "terrible", "awful", "disappointed", "angry", "slow", "rude")
# ---------- company-specific fallbacks ----------
# Canned replies used when we can answer without (or instead of) the LLM.
# Keys per company: greeting / money / complain / thanks / bye / default.
_LD_EVENTS_REPLIES = {
    "greeting": "Hello! 👋 Welcome to LD Events – your trusted partner for weddings, graduations and corporate events.",
    "money": "Our pricing depends on venue, guest count and package. Please share a few details so we can give you a tailored quote.",
    "complain": "We’re sorry to hear this. A senior agent will contact you within 30 minutes to resolve the issue.",
    "thanks": "You’re welcome! If you need anything else, just text back.",
    "bye": "Thanks for chatting with LD Events. Have a lovely day!",
    "default": "I’m not sure about that, but a human agent will follow up shortly.",
}

_LAMAKI_REPLIES = {
    "greeting": "Karibu! 🏗️ Lamaki Designs here – quality construction, architectural plans and project management.",
    "money": "Cost varies by project size and materials. Kindly share your plot size / plan so we can estimate for you.",
    "complain": "We apologise for the inconvenience. Our site manager will call you within 30 minutes to sort it out.",
    "thanks": "Asante! Feel free to text any time.",
    "bye": "Good-bye and stay safe!",
    "default": "Let me get back to you on that.",
}

FALLBACKS = {
    "LD Events": _LD_EVENTS_REPLIES,
    "Lamaki Designs": _LAMAKI_REPLIES,
}
| # ---------- helpers ---------- | |
| def _company_from_text(text: str) -> str: | |
| t = text.lower() | |
| if any(k in t for k in ("ld events", "event", "wedding", "venue", "graduation")): | |
| return "LD Events" | |
| if any(k in t for k in ("lamaki", "construction", "build", "site", "bungalow", "architect")): | |
| return "Lamaki Designs" | |
| return "LD Events" # default | |
def _detect_intent(text: str) -> str:
    """Classify *text* into a coarse intent bucket using the keyword regexes.

    The check order is significant: greeting > thanks > bye > money >
    complain; anything that matches none of them is "normal".
    """
    checks = (
        (GREETING_RE, "greeting"),
        (THANKS_RE, "thanks"),
        (BYE_RE, "bye"),
        (MONEY_RE, "money"),
        (COMPLAIN_RE, "complain"),
    )
    for pattern, label in checks:
        if pattern.search(text):
            return label
    return "normal"
def _fallback_answer(company: str, intent: str) -> str:
    """Return the canned reply for *intent*, or the company's default reply."""
    replies = FALLBACKS[company]
    try:
        return replies[intent]
    except KeyError:
        return replies["default"]
# ---------- pdf loader ----------
def download_pdfs() -> List[Path]:
    """Download every PDF referenced by the HF dataset into ./data.

    Returns the list of local paths (one per dataset row). Files already on
    disk are reused. Fixes over the previous version:
      * the streamed response is used as a context manager so the HTTP
        connection is always released;
      * ``iter_content`` replaces ``r.raw`` + ``copyfileobj``, so transport
        content-encoding is decoded correctly;
      * the download is written to a ``.part`` temp file and renamed only on
        success, so a failed transfer never leaves a truncated file that a
        later run would mistake for a complete PDF.
    """
    data_dir = Path("data")
    data_dir.mkdir(exist_ok=True)
    ds = load_dataset(HF_DS, split="train", streaming=True)
    paths: List[Path] = []
    for row in ds:
        url = row["pdf_url"]
        name = row.get("name", uuid.uuid4().hex) + ".pdf"
        dest = data_dir / name
        if not dest.exists():
            tmp = dest.with_suffix(".pdf.part")
            with requests.get(url, stream=True, timeout=30) as r:
                r.raise_for_status()
                with open(tmp, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1 << 16):
                        f.write(chunk)
            tmp.replace(dest)  # atomic: dest only appears once fully written
        paths.append(dest)
    return paths
# ---------- vector store ----------
def get_vectorstore() -> FAISS:
    """Load the cached FAISS index, or build it from the PDFs and persist it."""
    # Build the embeddings object once; the original constructed it in both
    # the load and the build branch.
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
    if FAISS_PATH.exists():
        # allow_dangerous_deserialization: the index is produced locally by
        # this very module, so unpickling it is acceptable.
        return FAISS.load_local(str(FAISS_PATH), embeddings,
                                allow_dangerous_deserialization=True)
    splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=50)
    docs = []
    for pdf in download_pdfs():
        # NOTE(review): PDFs are binary; read_text(errors="ignore") will yield
        # mostly garbage for real PDF files. A proper extractor (e.g. pypdf)
        # is needed here, but that adds a dependency — flagging, not changing.
        text = pdf.read_text(encoding="utf-8", errors="ignore")
        docs.extend(splitter.create_documents([text], metadatas=[{"source": pdf.name}]))
    vs = FAISS.from_documents(docs, embeddings)
    vs.save_local(str(FAISS_PATH))
    return vs
# ---------- llm ----------
def get_llm():
    """Build the hosted inference-endpoint LLM used by the QA chain."""
    endpoint_kwargs = {
        "repo_id": LLM_MODEL,
        "temperature": 0.1,        # near-deterministic answers
        "max_new_tokens": 150,     # keep replies WhatsApp-sized
        "huggingfacehub_api_token": HF_TOKEN,
    }
    return HuggingFaceEndpoint(**endpoint_kwargs)
# Prompt for the stuff-documents chain. NOTE(review): it declares three
# variables ({company}, {context}, {question}) but RetrievalQA's document
# chain only supplies context/question — "company" must be bound (e.g. via
# PROMPT.partial) before the chain runs; verify against the caller.
PROMPT = PromptTemplate.from_template("""You are Amina, assistant for {company}.
Use only the context below. If unsure, say: “A human agent will follow up.”
Context: {context}
Question: {question}
Answer:""")
# ---------- main entry ----------
def ask_question(phone: str, question: str) -> Tuple[str, List]:
    """Answer *question* for the user identified by *phone*.

    Returns ``(answer, source_documents)``. Cheap intents (greeting, thanks,
    bye) are answered from canned replies without touching the LLM; money
    and complaint questions fall back to canned replies when retrieval finds
    nothing. Every exchange is persisted to Supabase via _save_chat.
    """
    intent = _detect_intent(question)
    company = _company_from_text(question)

    # Short-circuit greetings/thanks/bye – no LLM, no context needed.
    if intent in ("greeting", "thanks", "bye"):
        answer = _fallback_answer(company, intent)
        _save_chat(phone, question, answer)
        return answer, []

    # Money or complaints: canned fallback if retrieval comes up empty.
    vs = get_vectorstore()
    docs = vs.similarity_search(question, k=3)
    if not docs:
        answer = _fallback_answer(company, intent if intent in ("money", "complain") else "default")
        _save_chat(phone, question, answer)
        return answer, []

    # Normal RAG. Bind {company} into the prompt up front: RetrievalQA only
    # feeds context/question to the prompt, so passing "company" as an extra
    # chain input (as before) leaves the variable unbound and trips the
    # chain's input-key validation.
    qa = RetrievalQA.from_chain_type(
        llm=get_llm(),
        retriever=vs.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": PROMPT.partial(company=company)},
    )
    result = qa({"query": question})
    answer = result["result"].strip()
    _save_chat(phone, question, answer)
    return answer, result.get("source_documents", [])
def _save_chat(phone: str, q: str, a: str) -> None:
    """Persist one user/assistant exchange to the Supabase chat_memory table."""
    for role, message in (("user", q), ("assistant", a)):
        row = {"user_phone": phone, "role": role, "message": message}
        supabase.table("chat_memory").insert(row).execute()