faiz0983 committed on
Commit
ad21633
·
verified ·
1 Parent(s): 55bcaec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -38
app.py CHANGED
@@ -1,21 +1,42 @@
1
  import os
2
  import gradio as gr
3
- from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
4
- from langchain_text_splitters import RecursiveCharacterTextSplitter
5
- from langchain_huggingface import HuggingFaceEmbeddings
6
- from langchain_community.vectorstores import FAISS
7
- from langchain_groq import ChatGroq
8
  from langchain_classic.chains import ConversationalRetrievalChain
9
  from langchain_classic.memory import ConversationBufferMemory
10
-
11
- # --- NEW IMPORTS FOR HYBRID SEARCH ---
 
 
 
12
  from langchain_community.retrievers import BM25Retriever
13
  from langchain.retrievers import EnsembleRetriever
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- # 1. SETUP API
16
- api_key = os.environ.get("GROQ_API")
 
 
17
 
18
- # 2. FILE LOADING LOGIC
19
  def load_any(path: str):
20
  p = path.lower()
21
  if p.endswith(".pdf"): return PyPDFLoader(path).load()
@@ -23,41 +44,39 @@ def load_any(path: str):
23
  if p.endswith(".docx"): return Docx2txtLoader(path).load()
24
  return []
25
 
26
- # 3. HYBRID PROCESSING FUNCTION
27
- def process_files(files):
28
  if not files or not api_key:
29
- return None, "⚠️ Missing files or API key."
30
 
31
  try:
32
- # Load all documents
33
  docs = []
34
  for file_obj in files:
35
  docs.extend(load_any(file_obj.name))
36
 
37
- if not docs:
38
- return None, "⚠️ No readable text found."
39
-
40
- # Split into chunks
41
- splitter = RecursiveCharacterTextSplitter(chunk_size=700, chunk_overlap=100)
42
  chunks = splitter.split_documents(docs)
43
 
44
- # A. Semantic Search (FAISS)
45
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
46
  faiss_db = FAISS.from_documents(chunks, embeddings)
47
  faiss_retriever = faiss_db.as_retriever(search_kwargs={"k": 3})
48
-
49
- # B. Keyword Search (BM25) - THIS IS THE MULTI-RETRIEVER ADDITION
50
  bm25_retriever = BM25Retriever.from_documents(chunks)
51
  bm25_retriever.k = 3
52
 
53
- # C. Ensemble (Hybrid Search)
54
  ensemble_retriever = EnsembleRetriever(
55
  retrievers=[faiss_retriever, bm25_retriever],
56
- weights=[0.6, 0.4] # 60% Semantic, 40% Keyword
57
  )
58
 
59
- # D. Classic Chain Setup
60
- llm = ChatGroq(groq_api_key=api_key, model="llama-3.3-70b-versatile", temperature=0)
 
 
 
 
 
61
  memory = ConversationBufferMemory(
62
  memory_key="chat_history",
63
  return_messages=True,
@@ -66,18 +85,19 @@ def process_files(files):
66
 
67
  chain = ConversationalRetrievalChain.from_llm(
68
  llm=llm,
69
- retriever=ensemble_retriever, # Use Hybrid Retriever
 
70
  memory=memory,
71
  return_source_documents=True,
72
  output_key="answer"
73
  )
74
 
75
- return chain, f"✅ Hybrid Multi-RAG Ready! ({len(chunks)} chunks)"
76
 
77
  except Exception as e:
78
  return None, f"❌ Error: {str(e)}"
79
 
80
- # 4. CHAT FUNCTION
81
  def chat_function(message, history, chain):
82
  if not chain:
83
  return "⚠️ Build the chatbot first."
@@ -85,27 +105,27 @@ def chat_function(message, history, chain):
85
  res = chain.invoke({"question": message})
86
  answer = res["answer"]
87
 
88
- # Format Sources
89
  sources = list(set([os.path.basename(d.metadata.get("source", "unknown")) for d in res.get("source_documents", [])]))
90
- source_text = "\n\n---\n**Sources:** " + ", ".join(sources)
91
 
92
- return answer + source_text
93
 
94
- # 5. UI
95
- with gr.Blocks(title="Hybrid RAG") as demo:
96
- gr.Markdown("# 🚀 Hybrid Multi-RAG Chatbot")
97
  chain_state = gr.State(None)
98
 
99
  with gr.Row():
100
  with gr.Column(scale=1):
101
- file_input = gr.File(file_count="multiple", label="Upload Docs")
102
- build_btn = gr.Button("Build Hybrid RAG", variant="primary")
 
103
  status = gr.Textbox(label="Status", interactive=False)
104
 
105
  with gr.Column(scale=2):
106
  gr.ChatInterface(fn=chat_function, additional_inputs=[chain_state])
107
 
108
- build_btn.click(process_files, inputs=[file_input], outputs=[chain_state, status])
109
 
110
  if __name__ == "__main__":
111
  demo.launch()
 
1
  import os
2
  import gradio as gr
3
+
4
+ # Classic & Community Imports
 
 
 
5
  from langchain_classic.chains import ConversationalRetrievalChain
6
  from langchain_classic.memory import ConversationBufferMemory
7
+ from langchain_groq import ChatGroq
8
+ from langchain_community.vectorstores import FAISS
9
+ from langchain_huggingface import HuggingFaceEmbeddings
10
+ from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
11
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
12
  from langchain_community.retrievers import BM25Retriever
13
  from langchain.retrievers import EnsembleRetriever
14
+ from langchain.prompts import PromptTemplate
15
+
16
# --- 1. SETUP API & SYSTEM PROMPT ---
# Hugging Face Spaces injects secrets as environment variables; the Groq key
# is expected under the secret name "GROQ_API".
api_key = os.getenv("GROQ_API")

# Anti-hallucination prompt: forces the LLM to answer only from the retrieved
# context and to give a fixed refusal sentence otherwise. {context} and
# {question} are filled in by the ConversationalRetrievalChain.
STRICT_PROMPT_TEMPLATE = """You are a strict document-based assistant.
Use the following pieces of context to answer the user's question.

RESTRICTIONS:
1. ONLY use the information provided in the context below.
2. If the answer is not contained within the context, specifically say: "I'm sorry, but the provided documents do not contain information to answer this question."
3. Do NOT use your own outside knowledge.

Context:
{context}

Question: {question}
Helpful Answer:"""

# Wrapped as a PromptTemplate so it can be passed to the chain via
# combine_docs_chain_kwargs.
STRICT_PROMPT = PromptTemplate(
    template=STRICT_PROMPT_TEMPLATE,
    input_variables=["context", "question"]
)
38
 
39
# --- 2. LOADING LOGIC ---
def load_any(path: str):
    """Load one document from *path* with the loader matching its extension.

    Supported: .pdf (PyPDFLoader), .txt (TextLoader), .docx (Docx2txtLoader).
    Returns a list of langchain Documents; unsupported extensions return an
    empty list instead of raising, so callers can simply extend().
    """
    p = path.lower()  # case-insensitive extension match
    if p.endswith(".pdf"): return PyPDFLoader(path).load()
    # Fix: .txt branch — TextLoader is imported at the top of the file but no
    # visible branch used it, so plain-text uploads were silently dropped.
    if p.endswith(".txt"): return TextLoader(path).load()
    if p.endswith(".docx"): return Docx2txtLoader(path).load()
    return []
46
 
47
# --- 3. HYBRID PROCESSING ---
def process_files(files, response_length):
    """Build a strict hybrid-RAG conversational chain from uploaded files.

    Parameters:
        files: list of gradio file objects (each exposes a .name path), or None.
        response_length: slider value; cast to int and used as the LLM
            max_tokens cap.

    Returns:
        (chain, status_message) — chain is None on any failure, with the
        reason in status_message for display in the UI status box.
    """
    if not files or not api_key:
        return None, "⚠️ Missing files or GROQ_API key in Secrets."

    try:
        docs = []
        for file_obj in files:
            docs.extend(load_any(file_obj.name))

        # Fix: restore the empty-corpus guard (present in the previous
        # revision) — FAISS/BM25 fail opaquely on zero documents, which would
        # otherwise surface as a cryptic "❌ Error".
        if not docs:
            return None, "⚠️ No readable text found."

        splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
        chunks = splitter.split_documents(docs)

        # Hybrid retrieval: dense/semantic (FAISS over MiniLM embeddings)
        # plus sparse/keyword (BM25), 3 hits each.
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        faiss_db = FAISS.from_documents(chunks, embeddings)
        faiss_retriever = faiss_db.as_retriever(search_kwargs={"k": 3})

        bm25_retriever = BM25Retriever.from_documents(chunks)
        bm25_retriever.k = 3

        # Equal weighting between semantic and keyword results.
        ensemble_retriever = EnsembleRetriever(
            retrievers=[faiss_retriever, bm25_retriever],
            weights=[0.5, 0.5]
        )

        llm = ChatGroq(
            groq_api_key=api_key,
            model="llama-3.3-70b-versatile",
            temperature=0,  # deterministic, document-grounded answers
            max_tokens=int(response_length)
        )

        memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            # The chain returns several keys (answer + source_documents);
            # memory must be told which one to store.
            output_key="answer"
        )

        chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=ensemble_retriever,
            combine_docs_chain_kwargs={"prompt": STRICT_PROMPT},
            memory=memory,
            return_source_documents=True,
            output_key="answer"
        )

        return chain, f"✅ Knowledge base built! Max answer length: {response_length} tokens."

    except Exception as e:
        # Boundary handler: surface any build failure in the UI status box
        # rather than crashing the gradio callback.
        return None, f"❌ Error: {str(e)}"
99
 
100
# --- 4. CHAT FUNCTION ---
def chat_function(message, history, chain):
    """Answer *message* via the prebuilt chain and append source filenames.

    `history` is supplied by gr.ChatInterface but unused here — the chain
    keeps its own ConversationBufferMemory. `chain` arrives through the
    gr.State additional input and is None until the build button is used.
    """
    if not chain:
        return "⚠️ Build the chatbot first."

    res = chain.invoke({"question": message})
    answer = res["answer"]

    # Deduplicate source filenames from the retrieved documents.
    sources = list(set([os.path.basename(d.metadata.get("source", "unknown")) for d in res.get("source_documents", [])]))

    # Fix: previously an empty retrieval still appended a dangling
    # "**Sources used:** " trailer with nothing after it.
    if not sources:
        return answer

    source_display = "\n\n----- \n**Sources used:** " + ", ".join(sources)
    return answer + source_display
112
 
113
# --- 5. UI BUILDING ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🛡️ Strict Hybrid Multi-RAG")
    # Per-session holder for the built ConversationalRetrievalChain; starts
    # empty so chat_function can tell the user to build first.
    chain_state = gr.State(None)

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(file_count="multiple", label="1. Upload Documents")
            len_slider = gr.Slider(minimum=100, maximum=4000, value=1000, step=100, label="2. Response Length")
            build_btn = gr.Button("3. Build Restricted Chatbot", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)

        with gr.Column(scale=2):
            # chain_state is threaded into chat_function as an extra input.
            gr.ChatInterface(fn=chat_function, additional_inputs=[chain_state])

    # Build handler: produces (chain, status message) from the uploads and
    # the chosen response length.
    build_btn.click(process_files, inputs=[file_input, len_slider], outputs=[chain_state, status])

if __name__ == "__main__":
    demo.launch()