ahmadsanafarooq committed on
Commit b752b49 · verified · 1 Parent(s): 324f9b1

Update app.py

Files changed (1)
  1. app.py +198 -105
app.py CHANGED
@@ -1,110 +1,203 @@
- import gradio as gr
  import os
- import datetime
  from langchain.chains import RetrievalQA
- from langchain.vectorstores import Chroma
- from langchain.embeddings import HuggingFaceEmbeddings
- from langchain.llms import OpenAI
  from langchain_groq import ChatGroq
- from langchain.text_splitter import CharacterTextSplitter
- from langchain.document_loaders import TextLoader, PyPDFLoader
- from langchain.prompts import PromptTemplate
  from sklearn.feature_extraction.text import TfidfVectorizer
  from sklearn.metrics.pairwise import cosine_similarity
- from sentence_transformers import SentenceTransformer
- import numpy as np
-
- # Embedding Models
- hf_embed = HuggingFaceEmbeddings()
- fallback_model = SentenceTransformer('all-MiniLM-L6-v2')
-
- # Vector Store
- vector_store = Chroma(collection_name="ragstore", embedding_function=hf_embed)
-
- # LLM
- llm = ChatGroq(temperature=0, model_name="llama3-8b-8192")
-
- # Prompt Template
- prompt_template = PromptTemplate.from_template(
-     "Answer the following question using ONLY the context provided:\n\n{context}\n\nQuestion: {question}"
- )
-
- # RetrievalQA Chain
- qa_chain = RetrievalQA.from_chain_type(
-     llm=llm,
-     retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
-     chain_type="stuff",
-     chain_type_kwargs={"prompt": prompt_template}
- )
-
- # TF-IDF Fallback
- def tfidf_fallback(query, documents):
-     texts = [doc.page_content for doc in documents]
-     vectorizer = TfidfVectorizer().fit(texts + [query])
-     vectors = vectorizer.transform(texts + [query])
-     cosine_sim = cosine_similarity(vectors[-1], vectors[:-1]).flatten()
-     top_idx = np.argmax(cosine_sim)
-     return texts[top_idx], cosine_sim[top_idx]
-
- # Ingestion
- def ingest_files(files):
-     for file in files:
-         if file.name.endswith(".pdf"):
-             loader = PyPDFLoader(file.name)
-         else:
-             loader = TextLoader(file.name)
-         docs = loader.load()
-         chunks = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200).split_documents(docs)
-         vector_store.add_documents(chunks)
-
- # Evaluation Info
- def evaluate_retrieval(query):
-     docs = vector_store.similarity_search_with_score(query, k=3)
-     top_docs = [doc[0].page_content for doc in docs]
-     scores = [doc[1] for doc in docs]
-     similarities = [1 - s for s in scores]  # cosine similarity approximation
-     return top_docs, similarities
-
- # Final Response Generator
- def ask_question(query):
-     if not query.strip():
-         return "", "", "", "", ""
-
-     # Retrieve docs and similarities
-     docs, similarities = evaluate_retrieval(query)
-     formatted_docs = "\n\n".join([f"Doc {i+1} (Score: {similarities[i]*100:.2f}%)\n{docs[i]}" for i in range(len(docs))])
-     context_block = f"### Top Retrieved Documents:\n{formatted_docs}"
-
-     # Answer from RAG
-     answer = qa_chain.run(query)
-
-     # Baseline (Direct LLM, no context)
-     baseline = llm.invoke(query)
-
-     # Confidence score approximation
-     confidence = np.mean(similarities) * 100
-
-     return answer, context_block, f"{confidence:.2f}%", baseline, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-
- # Gradio UI
- with gr.Blocks() as demo:
-     gr.Markdown("# 🧠 RAG-Based Learning & Code Assistant\nUpload docs, ask questions, get answers with confidence & evidence.")
-
-     with gr.Row():
-         with gr.Column():
-             file_input = gr.File(label="Upload PDF or TXT", file_types=[".pdf", ".txt"], file_count="multiple")
-             ingest_btn = gr.Button("Ingest Documents")
-             question_input = gr.Textbox(label="Ask a Question")
-             ask_btn = gr.Button("Ask")
-         with gr.Column():
-             answer_output = gr.Textbox(label="RAG Answer", lines=5)
-             retrieved_docs_output = gr.Textbox(label="Top 3 Retrieved Documents", lines=10)
-             confidence_output = gr.Textbox(label="Confidence (%)")
-             baseline_output = gr.Textbox(label="Baseline (Direct LLM)", lines=5)
-             timestamp_output = gr.Textbox(label="Timestamp")
-
-     ingest_btn.click(fn=ingest_files, inputs=file_input, outputs=[])
-     ask_btn.click(fn=ask_question, inputs=question_input,
-                   outputs=[answer_output, retrieved_docs_output, confidence_output, baseline_output, timestamp_output])
-
- demo.launch()
  import os
+ import gradio as gr
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import Chroma
  from langchain.chains import RetrievalQA
  from langchain_groq import ChatGroq
+ from langchain_community.document_loaders import TextLoader, PyPDFLoader
+ from langchain.schema import Document
+ from pathlib import Path
+ from typing import List
+ import logging
+ import numpy as np
  from sklearn.feature_extraction.text import TfidfVectorizer
  from sklearn.metrics.pairwise import cosine_similarity
+ from dotenv import load_dotenv
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class SimpleEmbeddings:
+     def __init__(self):
+         self.vectorizer = TfidfVectorizer(max_features=384, stop_words='english')
+         self.fitted = False
+
+     def embed_documents(self, texts: List[str]) -> List[List[float]]:
+         if not self.fitted:
+             self.vectorizer.fit(texts)
+             self.fitted = True
+         embeddings = self.vectorizer.transform(texts)
+         return embeddings.toarray().tolist()
+
+     def embed_query(self, text: str) -> List[float]:
+         if not self.fitted:
+             return [0.0] * 384
+         embedding = self.vectorizer.transform([text])
+         return embedding.toarray()[0].tolist()
+
+ class RetrieverEvaluator:
+     def __init__(self, retriever, ground_truth, k=3):
+         self.retriever = retriever
+         self.ground_truth = ground_truth
+         self.k = k
+
+     def recall_at_k(self):
+         correct = 0
+         for query, relevant_docs in self.ground_truth.items():
+             results = self.retriever.get_relevant_documents(query)
+             retrieved = [Path(doc.metadata.get("source", "")).name for doc in results]
+             if any(doc in retrieved[:self.k] for doc in relevant_docs):
+                 correct += 1
+         recall = correct / len(self.ground_truth)
+         print(f"Recall@{self.k}: {recall:.2f}")
+         return recall
+
+     def mean_reciprocal_rank(self):
+         mrr_total = 0
+         for query, relevant_docs in self.ground_truth.items():
+             results = self.retriever.get_relevant_documents(query)
+             retrieved = [Path(doc.metadata.get("source", "")).name for doc in results]
+             for rank, doc in enumerate(retrieved[:self.k], 1):
+                 if doc in relevant_docs:
+                     mrr_total += 1 / rank
+                     break
+         mrr = mrr_total / len(self.ground_truth)
+         print(f"MRR@{self.k}: {mrr:.2f}")
+         return mrr
+
+ class RAGAssistant:
+     def __init__(self, groq_api_key: str):
+         self.groq_api_key = groq_api_key
+         self.embeddings = self._init_embeddings()
+         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+         self.learning_vectorstore = None
+         self.llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama3-70b-8192", temperature=0.1)
+         self.learning_persist_dir = "./chroma_learning_db"
+         self._init_vector_store()
+
+     def _init_embeddings(self):
+         try:
+             from langchain_huggingface import HuggingFaceEmbeddings
+             for model_name in ["all-MiniLM-L6-v2", "paraphrase-MiniLM-L3-v2", "all-mpnet-base-v2"]:
+                 try:
+                     return HuggingFaceEmbeddings(model_name=model_name, model_kwargs={'device': 'cpu'})
+                 except:
+                     continue
+         except ImportError:
+             pass
+         return SimpleEmbeddings()
+
+     def _init_vector_store(self):
+         self.learning_vectorstore = Chroma(
+             persist_directory=self.learning_persist_dir,
+             embedding_function=self.embeddings,
+             collection_name="learning_materials"
+         )
+
+     def load_documents(self, files: List[str]) -> str:
+         documents = []
+         for file_path in files:
+             try:
+                 loader = PyPDFLoader(file_path) if file_path.endswith(".pdf") else TextLoader(file_path, encoding="utf-8")
+                 docs = loader.load()
+                 documents.extend(docs)
+             except Exception as e:
+                 print(f"Error loading {file_path}: {e}")
+         if not documents:
+             return "No valid documents found."
+         chunks = self.text_splitter.split_documents(documents)
+         for chunk in chunks:
+             chunk.metadata['source'] = chunk.metadata.get('source', 'unknown')
+         self.learning_vectorstore.add_documents(chunks)
+         self.learning_vectorstore.persist()
+         return f"Loaded {len(chunks)} document chunks."
+
+     def get_response(self, query: str) -> str:
+         if not self.learning_vectorstore:
+             return "Please upload learning materials first."
+         qa_chain = RetrievalQA.from_chain_type(
+             llm=self.llm,
+             chain_type="stuff",
+             retriever=self.learning_vectorstore.as_retriever(search_kwargs={"k": 3}),
+             return_source_documents=True
+         )
+         prompt = f"""
+         You are a helpful educational assistant.
+         Answer the student's question clearly and provide references if applicable.
+
+         Question: {query}
+         """
+         result = qa_chain({"query": prompt})
+         response = result['result']
+         if result.get("source_documents"):
+             response += "\n\n**Sources:**\n"
+             for doc in result["source_documents"]:
+                 response += f"- {Path(doc.metadata.get('source', 'Unknown')).name}\n"
+         return response
+
+     def evaluate_retriever(self, user_queries: List[str], file_names: List[str]):
+         """Evaluate with user-provided queries and expected file names"""
+         ground_truth = dict(zip(user_queries, file_names))
+         retriever = self.learning_vectorstore.as_retriever(search_kwargs={"k": 3})
+         evaluator = RetrieverEvaluator(retriever, ground_truth, k=3)
+         recall = evaluator.recall_at_k()
+         mrr = evaluator.mean_reciprocal_rank()
+         return f"Recall@3: {recall:.2f}, MRR@3: {mrr:.2f}"
+
+ def create_interface(assistant: RAGAssistant):
+     def upload_files(files):
+         file_paths = [f.name for f in files]
+         return assistant.load_documents(file_paths)
+
+     def chat_fn(message, history):
+         response = assistant.get_response(message)
+         history.append((message, response))
+         return history, ""
+
+     def evaluate_fn(queries, file_names):
+         query_list = [q.strip() for q in queries.split('\n') if q.strip()]
+         file_list = [f.strip() for f in file_names.split('\n') if f.strip()]
+         if len(query_list) != len(file_list):
+             return "Number of queries and expected file names must match."
+         return assistant.evaluate_retriever(query_list, file_list)
+
+     with gr.Blocks(title="RAG Assistant") as demo:
+         gr.Markdown("# 📘 RAG-Based Assistant")
+         with gr.Tab("📄 Upload & Chat"):
+             file_input = gr.File(label="Upload PDFs or Text Files", file_count="multiple", file_types=[".pdf", ".txt"])
+             upload_btn = gr.Button("Load Documents")
+             status = gr.Textbox(label="Status", interactive=False)
+             chatbot = gr.Chatbot()
+             user_input = gr.Textbox(label="Ask a question")
+             send_btn = gr.Button("Send")
+
+             upload_btn.click(fn=upload_files, inputs=[file_input], outputs=[status])
+             send_btn.click(fn=chat_fn, inputs=[user_input, chatbot], outputs=[chatbot, user_input])
+             user_input.submit(fn=chat_fn, inputs=[user_input, chatbot], outputs=[chatbot, user_input])
+
+         with gr.Tab("📊 Evaluate Retriever"):
+             gr.Markdown("Paste queries and expected file names (one per line).")
+             queries = gr.Textbox(lines=5, label="Queries")
+             filenames = gr.Textbox(lines=5, label="Expected File Names")
+             eval_btn = gr.Button("Run Evaluation")
+             eval_result = gr.Textbox(label="Evaluation Result")
+             eval_btn.click(fn=evaluate_fn, inputs=[queries, filenames], outputs=[eval_result])
+
+         gr.Markdown("---")
+         gr.Markdown("*Powered by LangChain, ChromaDB, and Groq API*")
+
+     return demo
+
+ def main():
+     load_dotenv()
+     groq_api_key = os.getenv("GROQ_API_KEY")
+     if not groq_api_key:
+         print("Missing GROQ_API_KEY. Set it in your environment.")
+         return
+     assistant = RAGAssistant(groq_api_key)
+     app = create_interface(assistant)
+     app.launch(server_name="0.0.0.0", server_port=7860, share=True)
+
+ if __name__ == "__main__":
+     main()