NimrodDev committed on
Commit
41b9355
Β·
1 Parent(s): 94f580e
Files changed (1) hide show
  1. rag.py +44 -68
rag.py CHANGED
@@ -1,37 +1,34 @@
1
- # rag.py (v8 – exact imports in 0.3.7 wheels)
2
  from __future__ import annotations
3
- import os, uuid, tempfile, requests, shutil, re
4
- from pathlib import Path
5
  from functools import lru_cache
6
  from typing import List, Tuple
 
7
  from datasets import load_dataset
8
- from langchain.text_splitter import RecursiveCharacterTextSplitter # ← exists in 0.3.7
9
  from langchain_community.vectorstores import FAISS
10
  from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
11
  from langchain_core.prompts import PromptTemplate
12
- from langchain.chains import RetrievalQA # ← exists in 0.3.7
13
  from supabase import create_client
14
 
15
-
16
- #________________________config_________________#
17
- HF_DS = "NimrodDev/LD_Events2"
18
- EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
19
- LLM_MODEL = "microsoft/DialoGPT-medium"
20
- FAISS_PATH = Path("faiss_db")
21
  SUPABASE_URL = os.getenv("SUPABASE_URL")
22
  SUPABASE_KEY = os.getenv("SUPABASE_KEY")
23
- HF_TOKEN = os.getenv("HF_TOKEN")
24
 
25
  supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
26
 
27
- # ---------- keyword filters ----------
28
- GREETING_RE = re.compile(r"\b(hi|hello|hey|good morning|good afternoon|good evening)\b", re.I)
29
- THANKS_RE = re.compile(r"\b(thank|thanks|appreciate)\b", re.I)
30
- BYE_RE = re.compile(r"\b(bye|goodbye|see you|later)\b", re.I)
31
- MONEY_RE = re.compile(r"\b(price|cost|budget|cheap|expensive|money|usd|ksh|payment|deposit)\b", re.I)
32
- COMPLAIN_RE = re.compile(r"\b(complain|bad|terrible|awful|disappointed|angry|slow|rude)\b", re.I)
33
 
34
- # ---------- company-specific fallbacks ----------
35
  FALLBACKS = {
36
  "LD Events": {
37
  "greeting": "Hello! πŸ‘‹ Welcome to LD Events – your trusted partner for weddings, graduations and corporate events.",
@@ -51,65 +48,45 @@ FALLBACKS = {
51
  }
52
  }
53
 
54
- # ---------- helpers ----------
55
  def _company_from_text(text: str) -> str:
56
  t = text.lower()
57
  if any(k in t for k in ("ld events", "event", "wedding", "venue", "graduation")):
58
  return "LD Events"
59
  if any(k in t for k in ("lamaki", "construction", "build", "site", "bungalow", "architect")):
60
  return "Lamaki Designs"
61
- return "LD Events" # default
62
 
63
  def _detect_intent(text: str) -> str:
64
- if GREETING_RE.search(text):
65
- return "greeting"
66
- if THANKS_RE.search(text):
67
- return "thanks"
68
- if BYE_RE.search(text):
69
- return "bye"
70
- if MONEY_RE.search(text):
71
- return "money"
72
- if COMPLAIN_RE.search(text):
73
- return "complain"
74
  return "normal"
75
 
76
  def _fallback_answer(company: str, intent: str) -> str:
77
  return FALLBACKS[company].get(intent, FALLBACKS[company]["default"])
78
 
79
- # ---------- pdf loader ----------
80
- def download_pdfs() -> List[Path]:
81
- data_dir = Path("data")
82
- data_dir.mkdir(exist_ok=True)
83
- ds = load_dataset(HF_DS, split="train", streaming=True)
84
- paths = []
85
- for row in ds:
86
- url = row["pdf_url"]
87
- name = row.get("name", uuid.uuid4().hex) + ".pdf"
88
- dest = data_dir / name
89
- if not dest.exists():
90
- r = requests.get(url, stream=True, timeout=30)
91
- r.raise_for_status()
92
- with open(dest, "wb") as f:
93
- shutil.copyfileobj(r.raw, f)
94
- paths.append(dest)
95
- return paths
96
-
97
- # ---------- vector store ----------
98
  @lru_cache(maxsize=1)
99
  def get_vectorstore() -> FAISS:
100
- if FAISS_PATH.exists():
101
- return FAISS.load_local(str(FAISS_PATH), HuggingFaceEmbeddings(model_name=EMBED_MODEL),
102
- allow_dangerous_deserialization=True)
103
- docs = []
104
  splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=50)
105
- for pdf in download_pdfs():
106
- text = Path(pdf).read_text(encoding="utf-8", errors="ignore")
107
- docs.extend(splitter.create_documents([text], metadatas=[{"source": pdf.name}]))
108
- vs = FAISS.from_documents(docs, HuggingFaceEmbeddings(model_name=EMBED_MODEL))
109
- vs.save_local(str(FAISS_PATH))
110
- return vs
111
-
112
- # ---------- llm ----------
113
  @lru_cache(maxsize=1)
114
  def get_llm():
115
  return HuggingFaceEndpoint(
@@ -125,26 +102,24 @@ Context: {context}
125
  Question: {question}
126
  Answer:""")
127
 
128
- # ---------- main entry ----------
129
  def ask_question(phone: str, question: str) -> Tuple[str, List]:
130
- intent = _detect_intent(question)
131
  company = _company_from_text(question)
132
 
133
- # short-circuit greetings/thanks/bye – no LLM, no context needed
134
  if intent in ("greeting", "thanks", "bye"):
135
  answer = _fallback_answer(company, intent)
136
  _save_chat(phone, question, answer)
137
  return answer, []
138
 
139
- # money or complaints: fallback if no docs
140
- vs = get_vectorstore()
141
  docs = vs.similarity_search(question, k=3)
142
  if not docs:
143
  answer = _fallback_answer(company, intent if intent in ("money", "complain") else "default")
144
  _save_chat(phone, question, answer)
145
  return answer, []
146
 
147
- # normal RAG
148
  qa = RetrievalQA.from_chain_type(
149
  llm=get_llm(),
150
  retriever=vs.as_retriever(search_kwargs={"k": 3}),
@@ -156,6 +131,7 @@ def ask_question(phone: str, question: str) -> Tuple[str, List]:
156
  _save_chat(phone, question, answer)
157
  return answer, result.get("source_documents", [])
158
 
 
159
  def _save_chat(phone: str, q: str, a: str) -> None:
160
  supabase.table("chat_memory").insert({"user_phone": phone, "role": "user", "message": q}).execute()
161
  supabase.table("chat_memory").insert({"user_phone": phone, "role": "assistant", "message": a}).execute()
 
1
+ # rag.py – zero-disk, HF-Space-safe edition
2
  from __future__ import annotations
3
+ import os, re
 
4
  from functools import lru_cache
5
  from typing import List, Tuple
6
+
7
  from datasets import load_dataset
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain_community.vectorstores import FAISS
10
  from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
11
  from langchain_core.prompts import PromptTemplate
12
+ from langchain.chains import RetrievalQA
13
  from supabase import create_client
14
 
15
+ # ---------------- config ----------------
16
+ HF_DS = "NimrodDev/LD_Events2"
17
+ EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
18
+ LLM_MODEL = "microsoft/DialoGPT-medium"
 
 
19
  SUPABASE_URL = os.getenv("SUPABASE_URL")
20
  SUPABASE_KEY = os.getenv("SUPABASE_KEY")
21
+ HF_TOKEN = os.getenv("HF_TOKEN")
22
 
23
  supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
24
 
25
+ # ---------------- keyword & fallback blocks (unchanged) ----------------
26
+ GREETING_RE = re.compile(r"\b(hi|hello|hey|good morning|good afternoon|good evening)\b", re.I)
27
+ THANKS_RE = re.compile(r"\b(thank|thanks|appreciate)\b", re.I)
28
+ BYE_RE = re.compile(r"\b(bye|goodbye|see you|later)\b", re.I)
29
+ MONEY_RE = re.compile(r"\b(price|cost|budget|cheap|expensive|money|usd|ksh|payment|deposit)\b", re.I)
30
+ COMPLAIN_RE = re.compile(r"\b(complain|bad|terrible|awful|disappointed|angry|slow|rude)\b", re.I)
31
 
 
32
  FALLBACKS = {
33
  "LD Events": {
34
  "greeting": "Hello! πŸ‘‹ Welcome to LD Events – your trusted partner for weddings, graduations and corporate events.",
 
48
  }
49
  }
50
 
51
+ # ---------------- helpers ----------------
52
  def _company_from_text(text: str) -> str:
53
  t = text.lower()
54
  if any(k in t for k in ("ld events", "event", "wedding", "venue", "graduation")):
55
  return "LD Events"
56
  if any(k in t for k in ("lamaki", "construction", "build", "site", "bungalow", "architect")):
57
  return "Lamaki Designs"
58
+ return "LD Events"
59
 
60
  def _detect_intent(text: str) -> str:
61
+ if GREETING_RE.search(text): return "greeting"
62
+ if THANKS_RE.search(text): return "thanks"
63
+ if BYE_RE.search(text): return "bye"
64
+ if MONEY_RE.search(text): return "money"
65
+ if COMPLAIN_RE.search(text): return "complain"
 
 
 
 
 
66
  return "normal"
67
 
68
  def _fallback_answer(company: str, intent: str) -> str:
69
  return FALLBACKS[company].get(intent, FALLBACKS[company]["default"])
70
 
71
+ # ---------------- doc loader (RAM only) ----------------
72
+ def load_texts() -> List[str]:
73
+ ds = load_dataset(HF_DS, split="train")
74
+ # adjust field name if your dataset uses something else
75
+ return [row["text"] for row in ds if row.get("text")]
76
+
77
+ # ---------------- vector store (RAM only) ----------------
 
 
 
 
 
 
 
 
 
 
 
 
78
  @lru_cache(maxsize=1)
79
  def get_vectorstore() -> FAISS:
80
+ texts = load_texts()
 
 
 
81
  splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=50)
82
+ docs = splitter.create_documents(
83
+ texts,
84
+ metadatas=[{"source": HF_DS}] * len(texts)
85
+ )
86
+ embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
87
+ return FAISS.from_documents(docs, embeddings)
88
+
89
+ # ---------------- llm ----------------
90
  @lru_cache(maxsize=1)
91
  def get_llm():
92
  return HuggingFaceEndpoint(
 
102
  Question: {question}
103
  Answer:""")
104
 
105
+ # ---------------- main entry ----------------
106
  def ask_question(phone: str, question: str) -> Tuple[str, List]:
107
+ intent = _detect_intent(question)
108
  company = _company_from_text(question)
109
 
110
+ # short-circuit small-talk
111
  if intent in ("greeting", "thanks", "bye"):
112
  answer = _fallback_answer(company, intent)
113
  _save_chat(phone, question, answer)
114
  return answer, []
115
 
116
+ vs = get_vectorstore()
 
117
  docs = vs.similarity_search(question, k=3)
118
  if not docs:
119
  answer = _fallback_answer(company, intent if intent in ("money", "complain") else "default")
120
  _save_chat(phone, question, answer)
121
  return answer, []
122
 
 
123
  qa = RetrievalQA.from_chain_type(
124
  llm=get_llm(),
125
  retriever=vs.as_retriever(search_kwargs={"k": 3}),
 
131
  _save_chat(phone, question, answer)
132
  return answer, result.get("source_documents", [])
133
 
134
+ # ---------------- chat persistence ----------------
135
  def _save_chat(phone: str, q: str, a: str) -> None:
136
  supabase.table("chat_memory").insert({"user_phone": phone, "role": "user", "message": q}).execute()
137
  supabase.table("chat_memory").insert({"user_phone": phone, "role": "assistant", "message": a}).execute()