NimrodDev committed on
Commit
c66dd86
·
1 Parent(s): dd7779a
Files changed (1) hide show
  1. rag.py +70 -82
rag.py CHANGED
@@ -1,30 +1,29 @@
1
- # rag.py bullet-proof and container-safe
2
  from __future__ import annotations
3
- import os, re, json, requests, traceback
4
- from functools import lru_cache
5
  from typing import List, Tuple
 
 
 
6
 
 
 
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
  from langchain_community.vectorstores import FAISS
 
9
  from langchain_huggingface import HuggingFaceEndpoint
10
  from langchain_core.prompts import PromptTemplate
11
  from langchain.chains import RetrievalQA
12
- from supabase import create_client
13
 
14
  # ------------------------------------------------------------------ CONFIG
15
- DATASET_API = "https://datasets-server.huggingface.co/rows"
16
- DATASET = "NimrodDev/LD_Events2"
17
- CONFIG = "default"
18
- SPLIT = "train"
19
- LIMIT = 500
20
- LLM_MODEL = "microsoft/DialoGPT-medium"
21
-
22
  SUPABASE_URL = os.getenv("SUPABASE_URL")
23
  SUPABASE_KEY = os.getenv("SUPABASE_KEY")
24
- HF_TOKEN = os.getenv("HF_TOKEN")
 
 
25
 
26
- # ------------------------------------------------------------------ SAFE CACHE
27
- CACHE_DIR = os.getenv("HF_HOME", "/tmp/hf_cache")
28
  os.makedirs(CACHE_DIR, exist_ok=True)
29
  os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
30
  os.environ["HF_HOME"] = CACHE_DIR
@@ -32,14 +31,13 @@ os.environ["HF_HUB_CACHE"] = CACHE_DIR
32
 
33
  supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
34
 
35
- # ------------------------------------------------------------------ INTENTS
36
  GREETING_RE = re.compile(r"\b(hi|hello|hey|good morning|good afternoon|good evening)\b", re.I)
37
  THANKS_RE = re.compile(r"\b(thank|thanks|appreciate)\b", re.I)
38
  BYE_RE = re.compile(r"\b(bye|goodbye|see you|later)\b", re.I)
39
  MONEY_RE = re.compile(r"\b(price|cost|budget|cheap|expensive|money|usd|ksh|payment|deposit)\b", re.I)
40
  COMPLAIN_RE = re.compile(r"\b(complain|bad|terrible|awful|disappointed|angry|slow|rude)\b", re.I)
41
 
42
- # ------------------------------------------------------------------ FALLBACKS
43
  FALLBACKS = {
44
  "LD Events": {
45
  "greeting": "Hello! 👋 I’m Amina, your assistant for *LD Events* (weddings, graduations, corporate events) "
@@ -83,72 +81,67 @@ def _detect_intent(text: str) -> str:
83
  def _fallback_answer(company: str, intent: str) -> str:
84
  return FALLBACKS[company].get(intent, FALLBACKS[company]["default"])
85
 
86
- # ------------------------------------------------------------------ DATA FETCH
87
  @lru_cache(maxsize=1)
88
  def get_texts() -> List[str]:
89
  try:
90
- url = f"{DATASET_API}?dataset={DATASET}&config={CONFIG}&split={SPLIT}&offset=0&length={LIMIT}"
91
- r = requests.get(url, timeout=60)
92
- r.raise_for_status()
93
- rows = r.json()["rows"]
94
- texts = [row["row"]["text"] for row in rows if row["row"].get("text")]
95
- print(f"✓ Fetched {len(texts)} texts from {DATASET}")
96
  return texts
97
  except Exception as e:
98
- print(f" Dataset fetch failed: {e} – using empty corpus")
99
  return []
100
 
101
- # ------------------------------------------------------------------ EMBEDDINGS
102
  @lru_cache(maxsize=1)
103
  def get_vectorstore() -> FAISS:
104
- texts = get_texts()
105
- embeddings = None
106
-
107
- try:
108
- # Primary attempt: use Inference API embeddings
109
- from langchain_huggingface import HuggingFaceInferenceAPIEmbeddings
110
- embeddings = HuggingFaceInferenceAPIEmbeddings(
111
- api_key=HF_TOKEN,
112
- model_name="sentence-transformers/all-MiniLM-L6-v2"
113
- )
114
- print("✅ Using HuggingFaceInferenceAPIEmbeddings")
115
- except ImportError:
116
- # Fallback: local embeddings (container-safe)
117
  try:
118
- from langchain_huggingface import HuggingFaceEmbeddings
119
- embeddings = HuggingFaceEmbeddings(
120
- model_name="sentence-transformers/all-MiniLM-L6-v2",
121
- cache_folder=CACHE_DIR
122
- )
123
- print("⚙️ Using local HuggingFaceEmbeddings fallback")
124
- except Exception as e:
125
- traceback.print_exc()
126
- raise RuntimeError(f"❌ Failed to load embeddings: {e}")
127
 
 
128
  if not texts:
129
- print(" No dataset texts available; creating dummy FAISS index.")
130
- return FAISS.from_texts([""], embeddings)
131
 
132
- splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=50)
133
- docs = splitter.create_documents(texts, metadatas=[{"source": DATASET}] * len(texts))
134
- print(f"🔹 Building FAISS index with {len(docs)} chunks")
135
- return FAISS.from_documents(docs, embeddings)
 
 
 
 
 
136
 
137
  # ------------------------------------------------------------------ LLM
138
  @lru_cache(maxsize=1)
139
  def get_llm():
140
  return HuggingFaceEndpoint(
141
  repo_id=LLM_MODEL,
 
142
  temperature=0.1,
143
- max_new_tokens=150,
144
- huggingfacehub_api_token=HF_TOKEN
145
  )
146
 
147
  PROMPT = PromptTemplate.from_template("""
148
  You are Amina, assistant for {company}.
149
- Use only the context below. If unsure, say: "A human agent will follow up."
150
- Context: {context}
151
- Question: {question}
 
 
 
 
 
 
152
  Answer:
153
  """)
154
 
@@ -157,40 +150,35 @@ def ask_question(phone: str, question: str) -> Tuple[str, List]:
157
  intent = _detect_intent(question)
158
  company = _company_from_text(question)
159
 
 
160
  if intent in ("greeting", "thanks", "bye"):
161
  answer = _fallback_answer(company, intent)
162
  _save_chat(phone, question, answer)
163
  return answer, []
164
 
165
- try:
166
- vs = get_vectorstore()
167
- docs = vs.similarity_search(question, k=3)
168
- except Exception as e:
169
- print(f"❌ Vector store retrieval failed: {e}")
170
- docs = []
171
-
172
- if not docs or docs[0].page_content.strip() == "":
173
- answer = _fallback_answer(company, intent if intent in ("money", "complain") else "default")
174
- _save_chat(phone, question, answer)
175
- return answer, []
176
 
177
  try:
178
- qa = RetrievalQA.from_chain_type(
179
- llm=get_llm(),
180
- retriever=vs.as_retriever(search_kwargs={"k": 3}),
181
- return_source_documents=True,
182
- chain_type_kwargs={"prompt": PROMPT}
183
- )
184
  result = qa({"query": question, "company": company})
185
- answer = result.get("result", "").strip() or _fallback_answer(company, "default")
186
- docs_used = result.get("source_documents", [])
187
  except Exception as e:
188
- traceback.print_exc()
189
- answer = _fallback_answer(company, "default")
190
- docs_used = []
 
 
 
191
 
192
  _save_chat(phone, question, answer)
193
- return answer, docs_used
194
 
195
  # ------------------------------------------------------------------ SUPABASE LOGGING
196
  def _save_chat(phone: str, q: str, a: str) -> None:
@@ -198,4 +186,4 @@ def _save_chat(phone: str, q: str, a: str) -> None:
198
  supabase.table("chat_memory").insert({"user_phone": phone, "role": "user", "message": q}).execute()
199
  supabase.table("chat_memory").insert({"user_phone": phone, "role": "assistant", "message": a}).execute()
200
  except Exception as e:
201
- print(f" Chat log save failed: {e}")
 
1
+ # rag.py production-grade, Parquet-ready, FAISS-cached RAG pipeline
2
  from __future__ import annotations
3
+ import os, re, json, pickle
 
4
  from typing import List, Tuple
5
+ from functools import lru_cache
6
+ import requests
7
+ from supabase import create_client
8
 
9
+ # LangChain + HF
10
+ from datasets import load_dataset
11
  from langchain.text_splitter import RecursiveCharacterTextSplitter
12
  from langchain_community.vectorstores import FAISS
13
+ from langchain_community.embeddings import HuggingFaceEmbeddings
14
  from langchain_huggingface import HuggingFaceEndpoint
15
  from langchain_core.prompts import PromptTemplate
16
  from langchain.chains import RetrievalQA
 
17
 
18
  # ------------------------------------------------------------------ CONFIG
19
+ DATASET_ID = "NimrodDev/LD_Events2"
20
+ HF_TOKEN = os.getenv("HF_TOKEN")
 
 
 
 
 
21
  SUPABASE_URL = os.getenv("SUPABASE_URL")
22
  SUPABASE_KEY = os.getenv("SUPABASE_KEY")
23
+ LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3" # stronger model
24
+ CACHE_DIR = os.getenv("HF_HOME", "/tmp/hf_cache")
25
+ FAISS_PATH = "/tmp/faiss_index.pkl"
26
 
 
 
27
  os.makedirs(CACHE_DIR, exist_ok=True)
28
  os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
29
  os.environ["HF_HOME"] = CACHE_DIR
 
31
 
32
  supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
33
 
34
+ # ------------------------------------------------------------------ INTENT DETECTION
35
  GREETING_RE = re.compile(r"\b(hi|hello|hey|good morning|good afternoon|good evening)\b", re.I)
36
  THANKS_RE = re.compile(r"\b(thank|thanks|appreciate)\b", re.I)
37
  BYE_RE = re.compile(r"\b(bye|goodbye|see you|later)\b", re.I)
38
  MONEY_RE = re.compile(r"\b(price|cost|budget|cheap|expensive|money|usd|ksh|payment|deposit)\b", re.I)
39
  COMPLAIN_RE = re.compile(r"\b(complain|bad|terrible|awful|disappointed|angry|slow|rude)\b", re.I)
40
 
 
41
  FALLBACKS = {
42
  "LD Events": {
43
  "greeting": "Hello! 👋 I’m Amina, your assistant for *LD Events* (weddings, graduations, corporate events) "
 
81
  def _fallback_answer(company: str, intent: str) -> str:
82
  return FALLBACKS[company].get(intent, FALLBACKS[company]["default"])
83
 
84
+ # ------------------------------------------------------------------ DATA FETCH (Parquet)
85
  @lru_cache(maxsize=1)
86
  def get_texts() -> List[str]:
87
  try:
88
+ print("🔍 Loading dataset from Parquet branch...")
89
+ ds = load_dataset(DATASET_ID, split="train", revision="refs/convert/parquet")
90
+ texts = [str(row["text"]) for row in ds if row.get("text")]
91
+ print(f"✅ Loaded {len(texts)} rows from {DATASET_ID}")
 
 
92
  return texts
93
  except Exception as e:
94
+ print(f"⚠️ Failed to load Parquet dataset: {e}")
95
  return []
96
 
97
+ # ------------------------------------------------------------------ VECTORSTORE
98
  @lru_cache(maxsize=1)
99
  def get_vectorstore() -> FAISS:
100
+ if os.path.exists(FAISS_PATH):
 
 
 
 
 
 
 
 
 
 
 
 
101
  try:
102
+ with open(FAISS_PATH, "rb") as f:
103
+ vs = pickle.load(f)
104
+ print("📦 Loaded cached FAISS index.")
105
+ return vs
106
+ except Exception:
107
+ print("⚠️ Cached FAISS corrupted, rebuilding...")
 
 
 
108
 
109
+ texts = get_texts()
110
  if not texts:
111
+ print("⚠️ No dataset found; creating dummy FAISS index.")
112
+ return FAISS.from_texts(["No context available."], HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"))
113
 
114
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
115
+ splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=100)
116
+ docs = splitter.create_documents(texts)
117
+ vs = FAISS.from_documents(docs, embeddings)
118
+
119
+ with open(FAISS_PATH, "wb") as f:
120
+ pickle.dump(vs, f)
121
+ print("✅ FAISS index created and cached.")
122
+ return vs
123
 
124
  # ------------------------------------------------------------------ LLM
125
  @lru_cache(maxsize=1)
126
  def get_llm():
127
  return HuggingFaceEndpoint(
128
  repo_id=LLM_MODEL,
129
+ huggingfacehub_api_token=HF_TOKEN,
130
  temperature=0.1,
131
+ max_new_tokens=200
 
132
  )
133
 
134
  PROMPT = PromptTemplate.from_template("""
135
  You are Amina, assistant for {company}.
136
+ Answer clearly and politely using the context below.
137
+ If context lacks information, say: "Let me get back to you on that."
138
+
139
+ Context:
140
+ {context}
141
+
142
+ Question:
143
+ {question}
144
+
145
  Answer:
146
  """)
147
 
 
150
  intent = _detect_intent(question)
151
  company = _company_from_text(question)
152
 
153
+ # Quick intents skip RAG
154
  if intent in ("greeting", "thanks", "bye"):
155
  answer = _fallback_answer(company, intent)
156
  _save_chat(phone, question, answer)
157
  return answer, []
158
 
159
+ vs = get_vectorstore()
160
+ retriever = vs.as_retriever(search_kwargs={"k": 4})
161
+ qa = RetrievalQA.from_chain_type(
162
+ llm=get_llm(),
163
+ retriever=retriever,
164
+ chain_type_kwargs={"prompt": PROMPT},
165
+ return_source_documents=True
166
+ )
 
 
 
167
 
168
  try:
 
 
 
 
 
 
169
  result = qa({"query": question, "company": company})
170
+ answer = result.get("result", "").strip()
171
+ docs = result.get("source_documents", [])
172
  except Exception as e:
173
+ print(f"❌ QA error: {e}")
174
+ answer, docs = "", []
175
+
176
+ # Fallback logic
177
+ if not answer or len(answer.split()) < 3:
178
+ answer = _fallback_answer(company, intent if intent in ("money", "complain") else "default")
179
 
180
  _save_chat(phone, question, answer)
181
+ return answer, docs
182
 
183
  # ------------------------------------------------------------------ SUPABASE LOGGING
184
  def _save_chat(phone: str, q: str, a: str) -> None:
 
186
  supabase.table("chat_memory").insert({"user_phone": phone, "role": "user", "message": q}).execute()
187
  supabase.table("chat_memory").insert({"user_phone": phone, "role": "assistant", "message": a}).execute()
188
  except Exception as e:
189
+ print(f"⚠️ Supabase logging failed: {e}")