Spaces:
Sleeping
Sleeping
Upload 4 files
Browse files- README.md +12 -12
- app.py +58 -0
- pipeline.py +190 -0
- requirements.txt +13 -0
README.md
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: NLP RAG World News
|
| 3 |
-
emoji: 🏆
|
| 4 |
-
colorFrom: yellow
|
| 5 |
-
colorTo: gray
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.34.2
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
---
|
| 11 |
-
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: NLP RAG World News
|
| 3 |
+
emoji: 🏆
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: gray
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.34.2
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from pipeline import RAGPipeline
|
| 3 |
+
|
| 4 |
+
# --- Load the pipeline once globally ---
|
| 5 |
+
# This is crucial for performance, so models are not reloaded on every request.
|
| 6 |
+
print("Initializing RAG Pipeline...")
|
| 7 |
+
try:
|
| 8 |
+
rag_pipeline = RAGPipeline(artifacts_dir="rag_artifacts")
|
| 9 |
+
print("RAG Pipeline initialized successfully.")
|
| 10 |
+
except Exception as e:
|
| 11 |
+
print(f"FATAL: Failed to initialize RAG Pipeline: {e}")
|
| 12 |
+
# If the pipeline fails to load, we can't run the app.
|
| 13 |
+
# We'll display an error in the Gradio interface.
|
| 14 |
+
rag_pipeline = None
|
| 15 |
+
|
| 16 |
+
# --- Define the function that Gradio will call ---
|
| 17 |
+
def get_answer_from_pipeline(query):
|
| 18 |
+
if rag_pipeline is None:
|
| 19 |
+
return "Error: The RAG pipeline failed to load. Please check the server logs.", ""
|
| 20 |
+
|
| 21 |
+
print(f"Processing query in Gradio app: {query}")
|
| 22 |
+
try:
|
| 23 |
+
answer, _, sources = rag_pipeline.answer_query(query)
|
| 24 |
+
# Combine the answer and sources into a single string for display
|
| 25 |
+
full_response = answer + "\n\n" + sources
|
| 26 |
+
return full_response
|
| 27 |
+
except Exception as e:
|
| 28 |
+
print(f"Error during query processing: {e}")
|
| 29 |
+
return f"An error occurred while processing your request: {e}", ""
|
| 30 |
+
|
| 31 |
+
# --- Build the Gradio Interface ---
|
| 32 |
+
title = "Ask the News: A RAG system for World News Articles"
|
| 33 |
+
description = """
|
| 34 |
+
This demo showcases a Retrieval-Augmented Generation (RAG) system built from scratch.
|
| 35 |
+
Enter a question about world events (e.g., Brexit, COVID-19, geopolitical conflicts),
|
| 36 |
+
and the system will retrieve relevant articles from a 30,000-document dataset and generate an answer.
|
| 37 |
+
"""
|
| 38 |
+
examples = [
|
| 39 |
+
"What were the main arguments for and against Brexit?",
|
| 40 |
+
"What was the initial response to the COVID-19 outbreak?",
|
| 41 |
+
"Tell me about the conflict in South Ossetia in 2008."
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
# Using the gr.Markdown component to correctly render the links
|
| 45 |
+
iface = gr.Interface(
|
| 46 |
+
fn=get_answer_from_pipeline,
|
| 47 |
+
inputs=gr.Textbox(lines=2, placeholder="e.g., What happened with Brexit?", label="Question"),
|
| 48 |
+
outputs=gr.Markdown(label="Answer"), # Using Markdown to render links
|
| 49 |
+
title=title,
|
| 50 |
+
description=description,
|
| 51 |
+
examples=examples,
|
| 52 |
+
allow_flagging="never",
|
| 53 |
+
theme=gr.themes.Soft()
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
# --- Launch the App ---
|
| 57 |
+
if __name__ == "__main__":
|
| 58 |
+
iface.launch()
|
pipeline.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gc
|
| 3 |
+
import time
|
| 4 |
+
import pickle
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import warnings
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import numpy as np
|
| 10 |
+
from tqdm.auto import tqdm
|
| 11 |
+
import nltk
|
| 12 |
+
import faiss
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from rank_bm25 import BM25Okapi
|
| 16 |
+
from sentence_transformers import SentenceTransformer, CrossEncoder
|
| 17 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 18 |
+
|
| 19 |
+
# --- Basic Configuration ---
|
| 20 |
+
warnings.filterwarnings("ignore")
|
| 21 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 22 |
+
nltk.download('punkt', quiet=True)
|
| 23 |
+
RANDOM_SEED = 42
|
| 24 |
+
np.random.seed(RANDOM_SEED)
|
| 25 |
+
torch.manual_seed(RANDOM_SEED)
|
| 26 |
+
if torch.cuda.is_available():
|
| 27 |
+
torch.cuda.manual_seed_all(RANDOM_SEED)
|
| 28 |
+
|
| 29 |
+
DEVICE = "cpu"
|
| 30 |
+
|
| 31 |
+
class RAGPipeline:
|
| 32 |
+
def __init__(self, artifacts_dir="rag_artifacts"):
|
| 33 |
+
self.artifacts_dir = Path(artifacts_dir)
|
| 34 |
+
self.df = None
|
| 35 |
+
self.chunks_df = None
|
| 36 |
+
self.bm25 = None
|
| 37 |
+
self.index_faiss = None
|
| 38 |
+
self.embedding_model = None
|
| 39 |
+
self.reranker_model = None
|
| 40 |
+
self.llm_model = None
|
| 41 |
+
self.llm_tokenizer = None
|
| 42 |
+
self.load_artifacts()
|
| 43 |
+
self.load_models()
|
| 44 |
+
|
| 45 |
+
def load_artifacts(self):
|
| 46 |
+
print(f"--> Loading artifacts from {self.artifacts_dir}")
|
| 47 |
+
self.df = pd.read_parquet(self.artifacts_dir / "final_df.parquet")
|
| 48 |
+
self.chunks_df = pd.read_parquet(self.artifacts_dir / "chunks_df.parquet")
|
| 49 |
+
print(f"Loaded {len(self.df)} documents and {len(self.chunks_df)} chunks.")
|
| 50 |
+
|
| 51 |
+
with open(self.artifacts_dir / "bm25_index.pkl", "rb") as f:
|
| 52 |
+
self.bm25 = pickle.load(f)
|
| 53 |
+
print("Loaded BM25 index.")
|
| 54 |
+
|
| 55 |
+
self.index_faiss = faiss.read_index(str(self.artifacts_dir / "news_chunks.faiss_index"))
|
| 56 |
+
print(f"Loaded FAISS index with {self.index_faiss.ntotal} vectors.")
|
| 57 |
+
|
| 58 |
+
def load_models(self):
|
| 59 |
+
print("--> Loading models...")
|
| 60 |
+
# Dense Retriever
|
| 61 |
+
EMBEDDING_MODEL_NAME = 'multi-qa-MiniLM-L6-cos-v1'
|
| 62 |
+
self.embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device=DEVICE)
|
| 63 |
+
print(f"Embedding model '{EMBEDDING_MODEL_NAME}' loaded.")
|
| 64 |
+
|
| 65 |
+
# Reranker
|
| 66 |
+
CROSS_ENCODER_MODEL_NAME = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
|
| 67 |
+
self.reranker_model = CrossEncoder(CROSS_ENCODER_MODEL_NAME, device=DEVICE, max_length=512)
|
| 68 |
+
print(f"Reranker model '{CROSS_ENCODER_MODEL_NAME}' loaded.")
|
| 69 |
+
|
| 70 |
+
# LLM
|
| 71 |
+
LLM_MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
|
| 72 |
+
print(f"Loading LLM: {LLM_MODEL_NAME}...")
|
| 73 |
+
self.llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME, trust_remote_code=True)
|
| 74 |
+
self.llm_model = AutoModelForCausalLM.from_pretrained(
|
| 75 |
+
LLM_MODEL_NAME,
|
| 76 |
+
trust_remote_code=True
|
| 77 |
+
)
|
| 78 |
+
self.llm_model.to(DEVICE)
|
| 79 |
+
self.llm_model.eval()
|
| 80 |
+
|
| 81 |
+
if self.llm_tokenizer.pad_token is None:
|
| 82 |
+
self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
|
| 83 |
+
if hasattr(self.llm_model, 'config'):
|
| 84 |
+
self.llm_model.config.pad_token_id = self.llm_model.config.eos_token_id
|
| 85 |
+
print("LLM loaded successfully.")
|
| 86 |
+
|
| 87 |
+
def search_bm25(self, query: str, k: int = 5):
|
| 88 |
+
tokenized_query = query.lower().split()
|
| 89 |
+
scores = self.bm25.get_scores(tokenized_query)
|
| 90 |
+
topk_indices_scores = sorted(zip(range(len(scores)), scores), key=lambda x: x[1], reverse=True)[:k]
|
| 91 |
+
results = []
|
| 92 |
+
for i, score in topk_indices_scores:
|
| 93 |
+
chunk_info = self.chunks_df.iloc[i]
|
| 94 |
+
results.append({
|
| 95 |
+
'chunk_id': chunk_info['chunk_id'], 'doc_id': chunk_info['doc_id'],
|
| 96 |
+
'score': score, 'text': chunk_info['chunk_text'],
|
| 97 |
+
'title': chunk_info['original_title'], 'url': chunk_info['original_url']
|
| 98 |
+
})
|
| 99 |
+
return results
|
| 100 |
+
|
| 101 |
+
def search_faiss(self, query: str, k: int = 5):
|
| 102 |
+
query_embedding = self.embedding_model.encode(query, convert_to_tensor=True, device=DEVICE)
|
| 103 |
+
query_embedding_cpu = query_embedding.cpu().numpy().reshape(1, -1)
|
| 104 |
+
faiss.normalize_L2(query_embedding_cpu)
|
| 105 |
+
distances, indices = self.index_faiss.search(query_embedding_cpu, k)
|
| 106 |
+
results = []
|
| 107 |
+
for i in range(len(indices[0])):
|
| 108 |
+
idx = indices[0][i]
|
| 109 |
+
score = distances[0][i]
|
| 110 |
+
chunk_info = self.chunks_df.iloc[idx]
|
| 111 |
+
results.append({
|
| 112 |
+
'chunk_id': chunk_info['chunk_id'], 'doc_id': chunk_info['doc_id'],
|
| 113 |
+
'score': score, 'text': chunk_info['chunk_text'],
|
| 114 |
+
'title': chunk_info['original_title'], 'url': chunk_info['original_url']
|
| 115 |
+
})
|
| 116 |
+
return results
|
| 117 |
+
|
| 118 |
+
def hybrid_search_and_rerank(self, query: str, bm25_k: int = 20, faiss_k: int = 20, rerank_top_n: int = 5):
|
| 119 |
+
bm25_res = self.search_bm25(query, k=bm25_k)
|
| 120 |
+
faiss_res = self.search_faiss(query, k=faiss_k)
|
| 121 |
+
|
| 122 |
+
combined_results_dict = {}
|
| 123 |
+
for res_item in bm25_res + faiss_res:
|
| 124 |
+
chunk_id = res_item['chunk_id']
|
| 125 |
+
if chunk_id not in combined_results_dict:
|
| 126 |
+
combined_results_dict[chunk_id] = res_item
|
| 127 |
+
|
| 128 |
+
candidate_chunks = list(combined_results_dict.values())
|
| 129 |
+
if not candidate_chunks:
|
| 130 |
+
return []
|
| 131 |
+
|
| 132 |
+
reranker_pairs = [[query, chunk['text']] for chunk in candidate_chunks]
|
| 133 |
+
rerank_scores = self.reranker_model.predict(reranker_pairs, show_progress_bar=False)
|
| 134 |
+
|
| 135 |
+
for chunk, score in zip(candidate_chunks, rerank_scores):
|
| 136 |
+
chunk['rerank_score'] = score
|
| 137 |
+
|
| 138 |
+
reranked_results = sorted(candidate_chunks, key=lambda x: x['rerank_score'], reverse=True)
|
| 139 |
+
return reranked_results[:rerank_top_n]
|
| 140 |
+
|
| 141 |
+
def format_rag_prompt(self, query: str, context_chunks: list):
|
| 142 |
+
context_str = "\n\n---\n\n".join([chunk['text'] for chunk in context_chunks])
|
| 143 |
+
system_message = "You are a helpful AI assistant. Answer the user's QUESTION based *only* on the provided CONTEXT. If the context does not contain the answer, say 'I cannot answer the question based on the provided context.' Do not use any prior knowledge. Be concise and directly answer the question."
|
| 144 |
+
user_message_content = f"CONTEXT:\n{context_str}\n\nQUESTION: {query}"
|
| 145 |
+
messages = [
|
| 146 |
+
{"role": "system", "content": system_message},
|
| 147 |
+
{"role": "user", "content": user_message_content}
|
| 148 |
+
]
|
| 149 |
+
prompt = self.llm_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 150 |
+
return prompt
|
| 151 |
+
|
| 152 |
+
def generate_llm_answer(self, query: str, context_chunks: list):
|
| 153 |
+
if not context_chunks:
|
| 154 |
+
return "No relevant context found to answer the question.", []
|
| 155 |
+
|
| 156 |
+
formatted_prompt = self.format_rag_prompt(query, context_chunks)
|
| 157 |
+
inputs = self.llm_tokenizer(formatted_prompt, return_tensors="pt", padding=True, truncation=True, max_length=3800).to(DEVICE)
|
| 158 |
+
|
| 159 |
+
generation_args = {
|
| 160 |
+
"max_new_tokens": 250, "temperature": 0.1, "do_sample": True,
|
| 161 |
+
"top_p": 0.9, "eos_token_id": self.llm_tokenizer.eos_token_id,
|
| 162 |
+
"pad_token_id": self.llm_tokenizer.pad_token_id
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
with torch.no_grad():
|
| 166 |
+
output_ids = self.llm_model.generate(**inputs, **generation_args)
|
| 167 |
+
|
| 168 |
+
answer = self.llm_tokenizer.decode(output_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
| 169 |
+
return answer.strip(), context_chunks
|
| 170 |
+
|
| 171 |
+
def answer_query(self, query: str):
|
| 172 |
+
print(f"Received query: {query}")
|
| 173 |
+
# 1. Retrieve and Rerank
|
| 174 |
+
retrieved_context = self.hybrid_search_and_rerank(query, bm25_k=15, faiss_k=15, rerank_top_n=3)
|
| 175 |
+
|
| 176 |
+
if not retrieved_context:
|
| 177 |
+
return "Could not find any relevant documents to answer your question.", [], "No context found."
|
| 178 |
+
|
| 179 |
+
# 2. Generate Answer
|
| 180 |
+
llm_answer, used_context_chunks = self.generate_llm_answer(query, retrieved_context)
|
| 181 |
+
|
| 182 |
+
# 3. Format sources
|
| 183 |
+
sources_text = "\n\n**Sources:**\n"
|
| 184 |
+
seen_urls = set()
|
| 185 |
+
for i, chunk in enumerate(used_context_chunks):
|
| 186 |
+
if chunk['url'] not in seen_urls:
|
| 187 |
+
sources_text += f"- [{chunk['title']}]({chunk['url']})\n"
|
| 188 |
+
seen_urls.add(chunk['url'])
|
| 189 |
+
|
| 190 |
+
return llm_answer, used_context_chunks, sources_text
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pandas
|
| 2 |
+
pyarrow
|
| 3 |
+
datasets==2.19.*
|
| 4 |
+
sentence-transformers==2.7.0
|
| 5 |
+
faiss-cpu==1.8.0
|
| 6 |
+
rank_bm25==0.2.2
|
| 7 |
+
nltk==3.8.1
|
| 8 |
+
tqdm==4.66.1
|
| 9 |
+
transformers==4.40.*
|
| 10 |
+
accelerate==0.29.*
|
| 11 |
+
langchain
|
| 12 |
+
gradio
|
| 13 |
+
torch
|