MusaR committed on
Commit
e8a5efc
·
1 Parent(s): 3d2ba7a

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +12 -12
  2. app.py +58 -0
  3. pipeline.py +190 -0
  4. requirements.txt +13 -0
README.md CHANGED
@@ -1,12 +1,12 @@
1
- ---
2
- title: NLP RAG World News
3
- emoji: 🏆
4
- colorFrom: yellow
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.34.2
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: NLP RAG World News
3
+ emoji: 🏆
4
+ colorFrom: yellow
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 5.34.2
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from pipeline import RAGPipeline
3
+
4
# --- Load the pipeline once globally ---
# Models are loaded a single time at startup; reloading them on every
# request would make each query prohibitively slow.
print("Initializing RAG Pipeline...")
try:
    rag_pipeline = RAGPipeline(artifacts_dir="rag_artifacts")
    print("RAG Pipeline initialized successfully.")
except Exception as e:
    print(f"FATAL: Failed to initialize RAG Pipeline: {e}")
    # Keep the app alive so the UI can surface the failure instead of
    # crashing the whole Space at import time.
    rag_pipeline = None
16
# --- Define the function that Gradio will call ---
def get_answer_from_pipeline(query):
    """Answer *query* via the RAG pipeline and return a single markdown string.

    The Interface is wired to exactly one Markdown output component, so every
    code path must return one value. (The original returned a 2-tuple on the
    error paths, which a single-output Interface cannot render correctly.)

    Args:
        query: Free-text user question from the textbox.

    Returns:
        Markdown string: the generated answer followed by a sources list,
        or a human-readable error message.
    """
    if rag_pipeline is None:
        return "Error: The RAG pipeline failed to load. Please check the server logs."

    print(f"Processing query in Gradio app: {query}")
    try:
        answer, _, sources = rag_pipeline.answer_query(query)
        # Combine the answer and sources into a single string for display.
        return answer + "\n\n" + sources
    except Exception as e:
        # Broad catch is deliberate at this UI boundary: log and report
        # rather than crash the request.
        print(f"Error during query processing: {e}")
        return f"An error occurred while processing your request: {e}"
31
# --- Build the Gradio Interface ---
title = "Ask the News: A RAG system for World News Articles"
description = """
This demo showcases a Retrieval-Augmented Generation (RAG) system built from scratch.
Enter a question about world events (e.g., Brexit, COVID-19, geopolitical conflicts),
and the system will retrieve relevant articles from a 30,000-document dataset and generate an answer.
"""
examples = [
    "What were the main arguments for and against Brexit?",
    "What was the initial response to the COVID-19 outbreak?",
    "Tell me about the conflict in South Ossetia in 2008."
]

# Using the gr.Markdown component to correctly render the links
iface = gr.Interface(
    fn=get_answer_from_pipeline,
    inputs=gr.Textbox(lines=2, placeholder="e.g., What happened with Brexit?", label="Question"),
    outputs=gr.Markdown(label="Answer"),  # Using Markdown to render links
    title=title,
    description=description,
    examples=examples,
    # Gradio 5 renamed `allow_flagging="never"` to `flagging_mode="never"`.
    # The Space pins sdk_version 5.34.2 (README), so use the current name.
    flagging_mode="never",
    theme=gr.themes.Soft()
)

# --- Launch the App ---
if __name__ == "__main__":
    iface.launch()
pipeline.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gc
import heapq
import os
import pickle
import time
import warnings
from pathlib import Path

import faiss
import nltk
import numpy as np
import pandas as pd
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
19
# --- Basic Configuration ---
# Silence library warning chatter and avoid tokenizer fork warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")

# Sentence-tokenizer data for NLTK (no-op when already downloaded).
nltk.download('punkt', quiet=True)

# Seed every RNG in play so runs are reproducible.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

# Force CPU — presumably this targets a CPU-only Hugging Face Space.
DEVICE = "cpu"
31
class RAGPipeline:
    """End-to-end retrieval-augmented generation over a news-article corpus.

    Combines BM25 (sparse) and FAISS (dense) retrieval, cross-encoder
    reranking, and a causal LLM for answer generation.
    """

    def __init__(self, artifacts_dir="rag_artifacts"):
        """Eagerly load all on-disk artifacts and models.

        Args:
            artifacts_dir: Directory holding the parquet corpora, the
                pickled BM25 index, and the FAISS index.
        """
        self.artifacts_dir = Path(artifacts_dir)

        # Filled in by load_artifacts().
        self.df = None
        self.chunks_df = None
        self.bm25 = None
        self.index_faiss = None

        # Filled in by load_models().
        self.embedding_model = None
        self.reranker_model = None
        self.llm_model = None
        self.llm_tokenizer = None

        self.load_artifacts()
        self.load_models()
45
+ def load_artifacts(self):
46
+ print(f"--> Loading artifacts from {self.artifacts_dir}")
47
+ self.df = pd.read_parquet(self.artifacts_dir / "final_df.parquet")
48
+ self.chunks_df = pd.read_parquet(self.artifacts_dir / "chunks_df.parquet")
49
+ print(f"Loaded {len(self.df)} documents and {len(self.chunks_df)} chunks.")
50
+
51
+ with open(self.artifacts_dir / "bm25_index.pkl", "rb") as f:
52
+ self.bm25 = pickle.load(f)
53
+ print("Loaded BM25 index.")
54
+
55
+ self.index_faiss = faiss.read_index(str(self.artifacts_dir / "news_chunks.faiss_index"))
56
+ print(f"Loaded FAISS index with {self.index_faiss.ntotal} vectors.")
57
+
58
+ def load_models(self):
59
+ print("--> Loading models...")
60
+ # Dense Retriever
61
+ EMBEDDING_MODEL_NAME = 'multi-qa-MiniLM-L6-cos-v1'
62
+ self.embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device=DEVICE)
63
+ print(f"Embedding model '{EMBEDDING_MODEL_NAME}' loaded.")
64
+
65
+ # Reranker
66
+ CROSS_ENCODER_MODEL_NAME = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
67
+ self.reranker_model = CrossEncoder(CROSS_ENCODER_MODEL_NAME, device=DEVICE, max_length=512)
68
+ print(f"Reranker model '{CROSS_ENCODER_MODEL_NAME}' loaded.")
69
+
70
+ # LLM
71
+ LLM_MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
72
+ print(f"Loading LLM: {LLM_MODEL_NAME}...")
73
+ self.llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME, trust_remote_code=True)
74
+ self.llm_model = AutoModelForCausalLM.from_pretrained(
75
+ LLM_MODEL_NAME,
76
+ trust_remote_code=True
77
+ )
78
+ self.llm_model.to(DEVICE)
79
+ self.llm_model.eval()
80
+
81
+ if self.llm_tokenizer.pad_token is None:
82
+ self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
83
+ if hasattr(self.llm_model, 'config'):
84
+ self.llm_model.config.pad_token_id = self.llm_model.config.eos_token_id
85
+ print("LLM loaded successfully.")
86
+
87
+ def search_bm25(self, query: str, k: int = 5):
88
+ tokenized_query = query.lower().split()
89
+ scores = self.bm25.get_scores(tokenized_query)
90
+ topk_indices_scores = sorted(zip(range(len(scores)), scores), key=lambda x: x[1], reverse=True)[:k]
91
+ results = []
92
+ for i, score in topk_indices_scores:
93
+ chunk_info = self.chunks_df.iloc[i]
94
+ results.append({
95
+ 'chunk_id': chunk_info['chunk_id'], 'doc_id': chunk_info['doc_id'],
96
+ 'score': score, 'text': chunk_info['chunk_text'],
97
+ 'title': chunk_info['original_title'], 'url': chunk_info['original_url']
98
+ })
99
+ return results
100
+
101
+ def search_faiss(self, query: str, k: int = 5):
102
+ query_embedding = self.embedding_model.encode(query, convert_to_tensor=True, device=DEVICE)
103
+ query_embedding_cpu = query_embedding.cpu().numpy().reshape(1, -1)
104
+ faiss.normalize_L2(query_embedding_cpu)
105
+ distances, indices = self.index_faiss.search(query_embedding_cpu, k)
106
+ results = []
107
+ for i in range(len(indices[0])):
108
+ idx = indices[0][i]
109
+ score = distances[0][i]
110
+ chunk_info = self.chunks_df.iloc[idx]
111
+ results.append({
112
+ 'chunk_id': chunk_info['chunk_id'], 'doc_id': chunk_info['doc_id'],
113
+ 'score': score, 'text': chunk_info['chunk_text'],
114
+ 'title': chunk_info['original_title'], 'url': chunk_info['original_url']
115
+ })
116
+ return results
117
+
118
+ def hybrid_search_and_rerank(self, query: str, bm25_k: int = 20, faiss_k: int = 20, rerank_top_n: int = 5):
119
+ bm25_res = self.search_bm25(query, k=bm25_k)
120
+ faiss_res = self.search_faiss(query, k=faiss_k)
121
+
122
+ combined_results_dict = {}
123
+ for res_item in bm25_res + faiss_res:
124
+ chunk_id = res_item['chunk_id']
125
+ if chunk_id not in combined_results_dict:
126
+ combined_results_dict[chunk_id] = res_item
127
+
128
+ candidate_chunks = list(combined_results_dict.values())
129
+ if not candidate_chunks:
130
+ return []
131
+
132
+ reranker_pairs = [[query, chunk['text']] for chunk in candidate_chunks]
133
+ rerank_scores = self.reranker_model.predict(reranker_pairs, show_progress_bar=False)
134
+
135
+ for chunk, score in zip(candidate_chunks, rerank_scores):
136
+ chunk['rerank_score'] = score
137
+
138
+ reranked_results = sorted(candidate_chunks, key=lambda x: x['rerank_score'], reverse=True)
139
+ return reranked_results[:rerank_top_n]
140
+
141
+ def format_rag_prompt(self, query: str, context_chunks: list):
142
+ context_str = "\n\n---\n\n".join([chunk['text'] for chunk in context_chunks])
143
+ system_message = "You are a helpful AI assistant. Answer the user's QUESTION based *only* on the provided CONTEXT. If the context does not contain the answer, say 'I cannot answer the question based on the provided context.' Do not use any prior knowledge. Be concise and directly answer the question."
144
+ user_message_content = f"CONTEXT:\n{context_str}\n\nQUESTION: {query}"
145
+ messages = [
146
+ {"role": "system", "content": system_message},
147
+ {"role": "user", "content": user_message_content}
148
+ ]
149
+ prompt = self.llm_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
150
+ return prompt
151
+
152
+ def generate_llm_answer(self, query: str, context_chunks: list):
153
+ if not context_chunks:
154
+ return "No relevant context found to answer the question.", []
155
+
156
+ formatted_prompt = self.format_rag_prompt(query, context_chunks)
157
+ inputs = self.llm_tokenizer(formatted_prompt, return_tensors="pt", padding=True, truncation=True, max_length=3800).to(DEVICE)
158
+
159
+ generation_args = {
160
+ "max_new_tokens": 250, "temperature": 0.1, "do_sample": True,
161
+ "top_p": 0.9, "eos_token_id": self.llm_tokenizer.eos_token_id,
162
+ "pad_token_id": self.llm_tokenizer.pad_token_id
163
+ }
164
+
165
+ with torch.no_grad():
166
+ output_ids = self.llm_model.generate(**inputs, **generation_args)
167
+
168
+ answer = self.llm_tokenizer.decode(output_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
169
+ return answer.strip(), context_chunks
170
+
171
+ def answer_query(self, query: str):
172
+ print(f"Received query: {query}")
173
+ # 1. Retrieve and Rerank
174
+ retrieved_context = self.hybrid_search_and_rerank(query, bm25_k=15, faiss_k=15, rerank_top_n=3)
175
+
176
+ if not retrieved_context:
177
+ return "Could not find any relevant documents to answer your question.", [], "No context found."
178
+
179
+ # 2. Generate Answer
180
+ llm_answer, used_context_chunks = self.generate_llm_answer(query, retrieved_context)
181
+
182
+ # 3. Format sources
183
+ sources_text = "\n\n**Sources:**\n"
184
+ seen_urls = set()
185
+ for i, chunk in enumerate(used_context_chunks):
186
+ if chunk['url'] not in seen_urls:
187
+ sources_text += f"- [{chunk['title']}]({chunk['url']})\n"
188
+ seen_urls.add(chunk['url'])
189
+
190
+ return llm_answer, used_context_chunks, sources_text
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pandas
2
+ pyarrow
3
+ datasets==2.19.*
4
+ sentence-transformers==2.7.0
5
+ faiss-cpu==1.8.0
6
+ rank_bm25==0.2.2
7
+ nltk==3.8.1
8
+ tqdm==4.66.1
9
+ transformers==4.40.*
10
+ accelerate==0.29.*
11
+ langchain
12
+ gradio
13
+ torch