rishabhsetiya committed on
Commit
e56befa
·
verified ·
1 Parent(s): 056863e

Create rag.py

Browse files
Files changed (1) hide show
  1. rag.py +247 -0
rag.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ import re
4
+ import pickle
5
+ import faiss
6
+ import numpy as np
7
+ from typing import List, Dict
8
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
9
+ from rank_bm25 import BM25Okapi
10
+ import nltk
11
+ from nltk.corpus import stopwords
12
+ import requests
13
+ import json
14
+ from openai import OpenAI
15
+ import logging
16
+
17
+ load_dotenv()
18
+ # ---------------- Logging Setup ----------------
19
+ logging.basicConfig(
20
+ level=logging.INFO,
21
+ format='%(asctime)s %(levelname)s %(message)s',
22
+ handlers=[logging.StreamHandler()]
23
+ )
24
+ logger = logging.getLogger(__name__)
25
+
26
+ nltk.download("stopwords")
27
+ STOPWORDS = set(stopwords.words("english"))
28
+
29
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
30
+ # ...rest of your imports...
31
+
32
+ # ---------------- Paths & Models ----------------
33
+ EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
34
+ CROSS_ENCODER = "cross-encoder/ms-marco-MiniLM-L-6-v2"
35
+ OUT_DIR = "data/index_merged"
36
+
37
+ FAISS_PATH = os.path.join(OUT_DIR, "faiss_merged.index")
38
+ BM25_PATH = os.path.join(OUT_DIR, "bm25_merged.pkl")
39
+ META_PATH = os.path.join(OUT_DIR, "meta_merged.pkl")
40
+
41
+ BLOCKED_TERMS = ["weather","cricket","movie","song","football","holiday",
42
+ "travel","recipe","music","game","sports","politics","election"]
43
+
44
+ FINANCE_DOMAINS = [
45
+ "financial reporting","balance sheet","income statement","assets and liabilities",
46
+ "equity","revenue","profit and loss","goodwill impairment","cash flow","dividends",
47
+ "taxation","investment","valuation","capital structure","ownership interests",
48
+ "subsidiaries","shareholders equity","expenses","earnings","debt","amortization","depreciation"
49
+ ]
50
+
51
+ ALLOWED_COMPANY = ["make my trip","mmt"]
52
+
53
+ # crude regex to detect "company-like" words (any capitalized word(s) followed by Ltd, Inc, Company, etc.)
54
+ COMPANY_PATTERN = re.compile(r"\b([A-Z][a-zA-Z]+(?:\s+[A-Z][a-zA-Z]+)*\s+(?:Ltd|Limited|Inc|Corporation|Corp|LLC|Group|Company|Bank))\b", re.IGNORECASE)
55
+
56
+ # ---------------- Load Indexes ----------------
57
+ logger.info("Loading FAISS, BM25, metadata, and models...")
58
+ try:
59
+ faiss_index = faiss.read_index(FAISS_PATH)
60
+ with open(BM25_PATH, "rb") as f:
61
+ bm25_obj = pickle.load(f)
62
+ bm25 = bm25_obj["bm25"]
63
+ with open(META_PATH, "rb") as f:
64
+ meta: List[Dict] = pickle.load(f)
65
+ embed_model = SentenceTransformer(EMBED_MODEL)
66
+ reranker = CrossEncoder(CROSS_ENCODER)
67
+ api_key = os.getenv("HF_API_KEY")
68
+ if not api_key:
69
+ logger.error("HF_API_KEY environment variable not set. Please check your .env file or environment.")
70
+ raise ValueError("HF_API_KEY environment variable not set.")
71
+ client = OpenAI(
72
+ base_url="https://router.huggingface.co/v1",
73
+ api_key=api_key
74
+ )
75
+ except Exception as e:
76
+ logger.error(f"Error loading models or indexes: {e}")
77
+ raise
78
+
79
+ # ---------------- Hugging Face Mistral API ----------------
80
+ #HF_TOKEN = "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"
81
+ #HF_MODEL = "mistralai/Mistral-7B-Instruct-v0.2:featherless-ai"
82
+
83
+ def get_mistral_answer(query: str, context: str) -> str:
84
+ """
85
+ Calls Mistral 7B Instruct API via Hugging Face Inference API.
86
+ Adds error handling and logging.
87
+ """
88
+ prompt = f"Context:\n{context}\n\nQuestion: {query}\nAnswer in full sentences using context."
89
+ try:
90
+ logger.info(f"Calling Mistral API for query: {query}")
91
+ completion = client.chat.completions.create(
92
+ model="mistralai/Mistral-7B-Instruct-v0.2:featherless-ai",
93
+ messages=[
94
+ {
95
+ "role": "user",
96
+ "content": prompt
97
+ }
98
+ ]
99
+ )
100
+ answer = str(completion.choices[0].message.content)
101
+ logger.info(f"Mistral API response: {answer}")
102
+ return answer
103
+ except Exception as e:
104
+ logger.error(f"Error in Mistral API call: {e}")
105
+ return f"Error fetching answer from LLM: {e}"
106
+
107
+ # ---------------- Guardrails ----------------
108
+ finance_embeds = embed_model.encode(FINANCE_DOMAINS, convert_to_tensor=True)
109
+
110
+ def validate_query(query: str, threshold: float = 0.5) -> bool:
111
+ q_lower = query.lower()
112
+
113
+ # Blocklist check
114
+ if any(bad in q_lower for bad in BLOCKED_TERMS):
115
+ print("[Guardrail] Rejected by blocklist.")
116
+ return False
117
+
118
+ # Check for company mentions
119
+ companies_found = COMPANY_PATTERN.findall(query)
120
+ if companies_found:
121
+ # If any company is mentioned, only allow MakeMyTrip
122
+ if not any(ALLOWED_COMPANY in c.lower() for c in companies_found):
123
+ print(f"[Guardrail] Rejected: company mention {companies_found}, not {ALLOWED_COMPANY}.")
124
+ return False
125
+
126
+ # Semantic similarity check with financial domain
127
+ q_emb = embed_model.encode(query, convert_to_tensor=True)
128
+ sim_scores = util.cos_sim(q_emb, finance_embeds)
129
+ max_score = float(sim_scores.max())
130
+
131
+ if max_score > threshold:
132
+ print(f"[Guardrail] Accepted (semantic match {max_score:.2f})")
133
+ return True
134
+ else:
135
+ print(f"[Guardrail] Rejected (low semantic score {max_score:.2f})")
136
+ return False
137
+
138
+ #-------------------Output Guardrail------------------
139
+ def validate_output(answer: str, context_docs: List[Dict]) -> str:
140
+ combined_context = " ".join([doc["content"].lower() for doc in context_docs])
141
+ if answer.lower() in combined_context:
142
+ return answer
143
+ return "The information could not be verified in the financial statement attached."
144
+
145
+ # ---------------- Preprocess ----------------
146
+ def preprocess_query(query: str, remove_stopwords: bool = True) -> str:
147
+ query = query.lower()
148
+ query = re.sub(r"[^a-z0-9\s]", " ", query)
149
+ tokens = query.split()
150
+ if remove_stopwords:
151
+ tokens = [t for t in tokens if t not in STOPWORDS]
152
+ return " ".join(tokens)
153
+
154
+ # ---------------- Hybrid Retrieval ----------------
155
+ def hybrid_candidates(query: str, candidate_k: int = 50, alpha: float = 0.5) -> List[int]:
156
+ q_emb = embed_model.encode([preprocess_query(query, remove_stopwords=False)], convert_to_numpy=True, normalize_embeddings=True)
157
+ faiss_scores, faiss_ids = faiss_index.search(q_emb, max(candidate_k, 50))
158
+ faiss_ids = faiss_ids[0]
159
+ faiss_scores = faiss_scores[0]
160
+
161
+ tokenized_query = preprocess_query(query).split()
162
+ bm25_scores = bm25.get_scores(tokenized_query)
163
+
164
+ topN = max(candidate_k, 50)
165
+ bm25_top = np.argsort(bm25_scores)[::-1][:topN]
166
+ faiss_top = faiss_ids[:topN]
167
+ union_ids = np.unique(np.concatenate([bm25_top, faiss_top]))
168
+
169
+ faiss_score_map = {int(i): float(s) for i, s in zip(faiss_ids, faiss_scores)}
170
+ f_arr = np.array([faiss_score_map.get(int(i), -1.0) for i in union_ids], dtype=float)
171
+ f_min = np.min(f_arr)
172
+ if np.any(f_arr < 0):
173
+ f_arr = np.where(f_arr < 0, f_min, f_arr)
174
+ b_arr = np.array([bm25_scores[int(i)] for i in union_ids], dtype=float)
175
+
176
+ def _norm(x): return (x - np.min(x)) / (np.ptp(x) + 1e-9)
177
+ combined = alpha * _norm(f_arr) + (1 - alpha) * _norm(b_arr)
178
+ order = np.argsort(combined)[::-1]
179
+ return union_ids[order][:candidate_k].tolist()
180
+
181
+ # ---------------- Cross-Encoder Rerank ----------------
182
+ def rerank_cross_encoder(query: str, cand_ids: List[int], top_k: int = 10) -> List[Dict]:
183
+ pairs = [(query, meta[i]["content"]) for i in cand_ids]
184
+ scores = reranker.predict(pairs)
185
+ order = np.argsort(scores)[::-1][:top_k]
186
+ return [{"id": cand_ids[i], "chunk_size": meta[cand_ids[i]]["chunk_size"], "content": meta[cand_ids[i]]["content"], "rerank_score": float(scores[i])} for i in order]
187
+
188
+ # ---------------- Extract Numeric ----------------
189
+ def extract_value_for_year_and_concept(year: str, concept: str, context_docs: List[Dict]) -> str:
190
+ target_year = str(year)
191
+ concept_lower = concept.lower()
192
+ for doc in context_docs:
193
+ text = doc.get("content", "")
194
+ lines = [line for line in text.split("\n") if line.strip() and any(c.isdigit() for c in line)]
195
+ header_idx = None
196
+ year_to_col = {}
197
+ for idx, line in enumerate(lines):
198
+ years_in_line = re.findall(r"20\d{2}", line)
199
+ if years_in_line:
200
+ for col_idx, y in enumerate(years_in_line):
201
+ year_to_col[y] = col_idx
202
+ header_idx = idx
203
+ break
204
+ if target_year not in year_to_col or header_idx is None:
205
+ continue
206
+ for line in lines[header_idx+1:]:
207
+ if concept_lower in line.lower():
208
+ cols = re.split(r"\s{2,}|\t", line)
209
+ col_idx = year_to_col[target_year]
210
+ if col_idx < len(cols):
211
+ return cols[col_idx].replace(",", "")
212
+ return ""
213
+
214
+ # ---------------- RAG Pipeline ----------------
215
+ def generate_answer(query: str, top_k: int = 5, candidate_k: int = 50, alpha: float = 0.6):
216
+ logger.info(f"Received query: {query}")
217
+ try:
218
+ if not validate_query(query):
219
+ logger.warning("Query rejected: Not finance-related.")
220
+ return "Query rejected: Please ask finance-related questions.", []
221
+
222
+ cand_ids = hybrid_candidates(query, candidate_k=candidate_k, alpha=alpha)
223
+ logger.info(f"Hybrid candidates retrieved: {cand_ids}")
224
+ reranked = rerank_cross_encoder(query, cand_ids, top_k=top_k)
225
+ logger.info(f"Reranked top docs: {[d['id'] for d in reranked]}")
226
+
227
+ year_match = re.search(r"(20\d{2})", query)
228
+ year = year_match.group(0) if year_match else None
229
+ concept = re.sub(r"for the year 20\d{2}", "", query, flags=re.IGNORECASE).strip()
230
+
231
+ year_specific_answer = None
232
+ if year and concept:
233
+ year_specific_answer = extract_value_for_year_and_concept(year, concept, reranked)
234
+ logger.info(f"Year-specific answer: {year_specific_answer}")
235
+
236
+ if year_specific_answer:
237
+ answer = year_specific_answer
238
+ else:
239
+ # Pass top 5 chunks as context
240
+ context_text = "\n".join([d["content"] for d in reranked])
241
+ answer = get_mistral_answer(query, context_text)
242
+ final_answer = answer #validate_output(answer, reranked)
243
+ logger.info(f"Final Answer: {final_answer}")
244
+ return final_answer
245
+ except Exception as e:
246
+ logger.error(f"Error in RAG pipeline: {e}")
247
+ return f"Error in RAG pipeline: {e}", []