faiz0983 commited on
Commit
ee4f5d4
·
verified ·
1 Parent(s): 051aa3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -66
app.py CHANGED
@@ -1,131 +1,190 @@
1
  import os
2
  import gradio as gr
3
 
4
- # Classic & Community Imports
5
- from langchain_classic.chains import ConversationalRetrievalChain
6
- from langchain_classic.memory import ConversationBufferMemory
 
 
 
 
7
  from langchain_groq import ChatGroq
8
- from langchain_community.vectorstores import FAISS
9
  from langchain_huggingface import HuggingFaceEmbeddings
10
- from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
11
- from langchain_text_splitters import RecursiveCharacterTextSplitter
 
 
 
 
 
 
12
  from langchain_community.retrievers import BM25Retriever
13
- from langchain_community.retrievers.ensemble import EnsembleRetriever
14
- from langchain.prompts import PromptTemplate
15
 
16
- # --- 1. SETUP API & SYSTEM PROMPT ---
17
- # Hugging Face uses os.getenv for secrets
18
- api_key = os.getenv("GROQ_API")
 
 
 
 
19
 
20
- STRICT_PROMPT_TEMPLATE = """You are a strict document-based assistant.
21
- Use the following pieces of context to answer the user's question.
22
 
23
- RESTRICTIONS:
24
- 1. ONLY use the information provided in the context below.
25
- 2. If the answer is not contained within the context, specifically say: "I'm sorry, but the provided documents do not contain information to answer this question."
26
- 3. Do NOT use your own outside knowledge.
27
 
28
  Context:
29
  {context}
30
 
31
  Question: {question}
32
- Helpful Answer:"""
 
 
33
 
34
  STRICT_PROMPT = PromptTemplate(
35
- template=STRICT_PROMPT_TEMPLATE,
36
  input_variables=["context", "question"]
37
  )
38
 
39
- # --- 2. LOADING LOGIC ---
 
 
40
  def load_any(path: str):
41
  p = path.lower()
42
- if p.endswith(".pdf"): return PyPDFLoader(path).load()
43
- if p.endswith(".txt"): return TextLoader(path, encoding="utf-8").load()
44
- if p.endswith(".docx"): return Docx2txtLoader(path).load()
 
 
 
45
  return []
46
 
47
- # --- 3. HYBRID PROCESSING ---
 
 
48
  def process_files(files, response_length):
49
- if not files or not api_key:
50
- return None, "⚠️ Missing files or GROQ_API key in Secrets."
51
 
52
  try:
53
  docs = []
54
- for file_obj in files:
55
- docs.extend(load_any(file_obj.name))
56
-
57
- splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
 
 
 
58
  chunks = splitter.split_documents(docs)
59
 
60
- # Hybrid Retrievers
61
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
 
 
62
  faiss_db = FAISS.from_documents(chunks, embeddings)
63
  faiss_retriever = faiss_db.as_retriever(search_kwargs={"k": 3})
64
-
65
  bm25_retriever = BM25Retriever.from_documents(chunks)
66
  bm25_retriever.k = 3
67
 
68
- ensemble_retriever = EnsembleRetriever(
69
  retrievers=[faiss_retriever, bm25_retriever],
70
  weights=[0.5, 0.5]
71
  )
72
 
73
  llm = ChatGroq(
74
- groq_api_key=api_key,
75
- model="llama-3.3-70b-versatile",
76
  temperature=0,
77
  max_tokens=int(response_length)
78
  )
79
-
80
  memory = ConversationBufferMemory(
81
- memory_key="chat_history",
82
- return_messages=True,
83
  output_key="answer"
84
  )
85
-
86
  chain = ConversationalRetrievalChain.from_llm(
87
  llm=llm,
88
- retriever=ensemble_retriever,
89
  combine_docs_chain_kwargs={"prompt": STRICT_PROMPT},
90
  memory=memory,
91
  return_source_documents=True,
92
  output_key="answer"
93
  )
94
-
95
- return chain, f"✅ Knowledge base built! Max answer length: {response_length} tokens."
96
 
97
  except Exception as e:
98
  return None, f"❌ Error: {str(e)}"
99
 
100
- # --- 4. CHAT FUNCTION ---
 
 
101
  def chat_function(message, history, chain):
102
- if not chain:
103
- return "⚠️ Build the chatbot first."
104
-
105
- res = chain.invoke({"question": message})
106
- answer = res["answer"]
107
-
108
- sources = list(set([os.path.basename(d.metadata.get("source", "unknown")) for d in res.get("source_documents", [])]))
109
- source_display = "\n\n----- \n**Sources used:** " + ", ".join(sources)
110
-
111
- return answer + source_display
112
-
113
- # --- 5. UI BUILDING ---
 
 
 
 
 
 
 
 
 
 
 
114
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
115
- gr.Markdown("# 🛡️ Strict Hybrid Multi-RAG")
 
116
  chain_state = gr.State(None)
117
-
118
  with gr.Row():
119
  with gr.Column(scale=1):
120
- file_input = gr.File(file_count="multiple", label="1. Upload Documents")
121
- len_slider = gr.Slider(minimum=100, maximum=4000, value=1000, step=100, label="2. Response Length")
122
- build_btn = gr.Button("3. Build Restricted Chatbot", variant="primary")
123
- status = gr.Textbox(label="Status", interactive=False)
124
-
 
 
 
 
 
 
 
 
 
 
 
 
125
  with gr.Column(scale=2):
126
- gr.ChatInterface(fn=chat_function, additional_inputs=[chain_state])
 
 
 
127
 
128
- build_btn.click(process_files, inputs=[file_input, len_slider], outputs=[chain_state, status])
 
 
 
 
129
 
130
  if __name__ == "__main__":
131
- demo.launch()
 
1
  import os
2
  import gradio as gr
3
 
4
+ # LangChain Core
5
+ from langchain.chains import ConversationalRetrievalChain
6
+ from langchain.memory import ConversationBufferMemory
7
+ from langchain.prompts import PromptTemplate
8
+ from langchain.retrievers import EnsembleRetriever
9
+
10
+ # Providers
11
  from langchain_groq import ChatGroq
 
12
  from langchain_huggingface import HuggingFaceEmbeddings
13
+
14
+ # Community
15
+ from langchain_community.vectorstores import FAISS
16
+ from langchain_community.document_loaders import (
17
+ PyPDFLoader,
18
+ TextLoader,
19
+ Docx2txtLoader
20
+ )
21
  from langchain_community.retrievers import BM25Retriever
 
 
22
 
23
+ # Text Splitters
24
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
25
+
26
# --------------------------------------------------
# 1. API KEY
# --------------------------------------------------
# Read the Groq key from the environment (HF Spaces injects Secrets this
# way). os.getenv returns None when the secret is missing; process_files
# checks this before building the chain.
GROQ_API_KEY = os.getenv("GROQ_API")

# Prompt that restricts the LLM to the retrieved context only; it also
# fixes the exact refusal sentence the model must emit when the answer is
# not present in the documents.
STRICT_PROMPT_TEMPLATE = """You are a strict document-based assistant.
Use ONLY the information provided in the context.

RULES:
1. Do not use outside knowledge.
2. If the answer is not present, say:
"I'm sorry, but the provided documents do not contain information to answer this question."

Context:
{context}

Question: {question}

Answer:
"""

# PromptTemplate with the two variables ConversationalRetrievalChain's
# combine-docs step fills in: the retrieved chunks and the user question.
STRICT_PROMPT = PromptTemplate(
    template=STRICT_PROMPT_TEMPLATE,
    input_variables=["context", "question"]
)
51
 
52
# --------------------------------------------------
# 2. FILE LOADER
# --------------------------------------------------
def load_any(path: str):
    """Load one document file into a list of LangChain documents.

    Dispatches on the (case-insensitive) file extension: .pdf, .txt and
    .docx are supported; any other path yields an empty list.
    """
    lowered = path.lower()
    # Loader thunks keep third-party loader construction lazy: nothing is
    # instantiated unless the extension actually matches.
    dispatch = (
        (".pdf", lambda: PyPDFLoader(path).load()),
        (".txt", lambda: TextLoader(path, encoding="utf-8").load()),
        (".docx", lambda: Docx2txtLoader(path).load()),
    )
    for ext, make_docs in dispatch:
        if lowered.endswith(ext):
            return make_docs()
    return []
64
 
65
# --------------------------------------------------
# 3. PROCESS FILES / BUILD CHAIN
# --------------------------------------------------
def process_files(files, response_length):
    """Build a strict, hybrid-retrieval conversational chain from uploads.

    Parameters:
        files: iterable of gradio file objects (path in ``.name``) or plain
            path strings.
        response_length: max completion tokens for the LLM (coerced to int).

    Returns:
        (chain, status_message) — ``chain`` is None on any failure, so the
        gradio State output stays unset and chat_function keeps refusing.
    """
    if not files or not GROQ_API_KEY:
        return None, "⚠️ Missing documents or GROQ_API key."

    try:
        docs = []
        for f in files:
            # gradio file objects expose the temp path as .name; fall back
            # to the object itself so plain string paths also work.
            docs.extend(load_any(getattr(f, "name", f)))

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=100
        )
        chunks = splitter.split_documents(docs)

        # Fail early with a clear message instead of the opaque error
        # FAISS/BM25 raise on an empty document list (e.g. all uploads
        # were unsupported file types or contained no extractable text).
        if not chunks:
            return None, "⚠️ No readable text found in the uploaded documents."

        # --- Hybrid Retrieval: dense (FAISS) + sparse (BM25), merged ---
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        faiss_db = FAISS.from_documents(chunks, embeddings)
        faiss_retriever = faiss_db.as_retriever(search_kwargs={"k": 3})

        bm25_retriever = BM25Retriever.from_documents(chunks)
        bm25_retriever.k = 3

        # Equal weighting between semantic and keyword retrieval.
        retriever = EnsembleRetriever(
            retrievers=[faiss_retriever, bm25_retriever],
            weights=[0.5, 0.5]
        )

        # temperature=0 keeps answers deterministic and grounded.
        llm = ChatGroq(
            groq_api_key=GROQ_API_KEY,
            model="llama-3.3-70b-versatile",
            temperature=0,
            max_tokens=int(response_length)
        )

        # output_key="answer" tells the memory which chain output to store,
        # since the chain also returns source_documents.
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer"
        )

        chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            combine_docs_chain_kwargs={"prompt": STRICT_PROMPT},
            memory=memory,
            return_source_documents=True,
            output_key="answer"
        )

        return chain, f"✅ Chatbot ready (max {response_length} tokens)"

    except Exception as e:
        # Surface the failure in the Status textbox rather than crashing
        # the gradio event handler.
        return None, f"❌ Error: {str(e)}"
125
 
126
# --------------------------------------------------
# 4. CHAT FUNCTION
# --------------------------------------------------
def chat_function(message, history, chain):
    """Answer one chat turn via the retrieval chain, appending sources.

    Parameters:
        message: the user's question.
        history: gradio ChatInterface history (unused — see note below).
        chain: the ConversationalRetrievalChain from chain_state, or None
            if "Build Chatbot" has not been clicked yet.

    Returns:
        The answer string, with a "Sources:" footer when source documents
        were retrieved.
    """
    if chain is None:
        return "⚠️ Please build the chatbot first."

    # NOTE(review): the chain was built with its own ConversationBufferMemory
    # (memory_key="chat_history"), which supplies the conversation history
    # itself; gradio's pair-format `history` is not in LangChain message
    # format, so it is deliberately not forwarded here.
    result = chain.invoke({"question": message})

    answer = result["answer"]

    # De-duplicate source file names across the retrieved chunks.
    sources = {
        os.path.basename(doc.metadata.get("source", "unknown"))
        for doc in result.get("source_documents", [])
    }

    if sources:
        answer += "\n\n---\n**Sources:** " + ", ".join(sources)

    return answer
149
+
150
# --------------------------------------------------
# 5. GRADIO UI
# --------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🛡️ Strict Hybrid Multi-RAG (Groq + FAISS + BM25)")

    # Holds the built ConversationalRetrievalChain between events;
    # stays None until the user clicks "Build Chatbot".
    chain_state = gr.State(None)

    with gr.Row():
        # Left column: document upload and build controls.
        with gr.Column(scale=1):
            file_input = gr.File(
                file_count="multiple",
                label="Upload Documents"
            )
            # Value is forwarded to process_files as response_length
            # (the LLM's max_tokens).
            len_slider = gr.Slider(
                100, 4000, value=1000, step=100,
                label="Max Answer Tokens"
            )
            build_btn = gr.Button(
                "Build Chatbot",
                variant="primary"
            )
            status = gr.Textbox(
                label="Status",
                interactive=False
            )

        # Right column: the chat panel. chain_state is passed to
        # chat_function as its third argument after (message, history).
        with gr.Column(scale=2):
            gr.ChatInterface(
                fn=chat_function,
                additional_inputs=[chain_state]
            )

    # Building replaces both the stored chain and the status message.
    build_btn.click(
        process_files,
        inputs=[file_input, len_slider],
        outputs=[chain_state, status]
    )

if __name__ == "__main__":
    demo.launch()