ahmadsanafarooq committed on
Commit
324f9b1
·
verified ·
1 Parent(s): 92c2a42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -207
app.py CHANGED
@@ -1,212 +1,110 @@
1
- import os
2
  import gradio as gr
3
- from langchain.text_splitter import RecursiveCharacterTextSplitter
4
- from langchain_community.vectorstores import Chroma
5
  from langchain.chains import RetrievalQA
 
 
 
6
  from langchain_groq import ChatGroq
7
- from langchain_community.document_loaders import TextLoader, PyPDFLoader
8
- from langchain.schema import Document
9
- from pathlib import Path
10
- from typing import List
11
- import logging
12
- import numpy as np
13
  from sklearn.feature_extraction.text import TfidfVectorizer
14
  from sklearn.metrics.pairwise import cosine_similarity
15
- import pickle
16
- from dotenv import load_dotenv
17
-
18
- # Configure logging
19
- logging.basicConfig(level=logging.INFO)
20
- logger = logging.getLogger(__name__)
21
-
22
- class SimpleEmbeddings:
23
- def __init__(self):
24
- self.vectorizer = TfidfVectorizer(max_features=384, stop_words='english')
25
- self.fitted = False
26
-
27
- def embed_documents(self, texts: List[str]) -> List[List[float]]:
28
- if not self.fitted:
29
- self.vectorizer.fit(texts)
30
- self.fitted = True
31
- embeddings = self.vectorizer.transform(texts)
32
- return embeddings.toarray().tolist()
33
-
34
- def embed_query(self, text: str) -> List[float]:
35
- if not self.fitted:
36
- return [0.0] * 384
37
- embedding = self.vectorizer.transform([text])
38
- return embedding.toarray()[0].tolist()
39
-
40
- class RAGAssistant:
41
- def __init__(self, groq_api_key: str):
42
- self.groq_api_key = groq_api_key
43
- self.embeddings = self._init_embeddings()
44
- self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, length_function=len)
45
- self.learning_vectorstore = None
46
- self.code_vectorstore = None
47
- self.llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama3-70b-8192", temperature=0.1)
48
- self.learning_persist_dir = "./chroma_learning_db"
49
- self.code_persist_dir = "./chroma_code_db"
50
- self._init_vector_stores()
51
-
52
- def _init_embeddings(self):
53
- try:
54
- from langchain_huggingface import HuggingFaceEmbeddings
55
- models_to_try = ["all-MiniLM-L6-v2", "paraphrase-MiniLM-L3-v2", "all-mpnet-base-v2"]
56
- for model_name in models_to_try:
57
- try:
58
- embeddings = HuggingFaceEmbeddings(
59
- model_name=model_name,
60
- model_kwargs={'device': 'cpu'},
61
- encode_kwargs={'normalize_embeddings': False}
62
- )
63
- return embeddings
64
- except:
65
- continue
66
- except ImportError:
67
- pass
68
- return SimpleEmbeddings()
69
-
70
- def _init_vector_stores(self):
71
- try:
72
- self.learning_vectorstore = Chroma(
73
- persist_directory=self.learning_persist_dir,
74
- embedding_function=self.embeddings,
75
- collection_name="learning_materials"
76
- )
77
- self.code_vectorstore = Chroma(
78
- persist_directory=self.code_persist_dir,
79
- embedding_function=self.embeddings,
80
- collection_name="code_documentation"
81
- )
82
- except Exception as e:
83
- logger.error(f"Error initializing vector stores: {str(e)}")
84
- raise
85
-
86
- def load_documents(self, files: List[str], assistant_type: str) -> str:
87
- try:
88
- documents = []
89
- for file_path in files:
90
- try:
91
- loader = PyPDFLoader(file_path) if file_path.endswith('.pdf') else TextLoader(file_path, encoding='utf-8')
92
- docs = loader.load()
93
- documents.extend(docs)
94
- except Exception as e:
95
- print(f"Error loading {file_path}: {e}")
96
- continue
97
- if not documents:
98
- return "No documents could be loaded. Please check your files."
99
- chunks = self.text_splitter.split_documents(documents)
100
- for chunk in chunks:
101
- chunk.metadata['assistant_type'] = assistant_type
102
- if assistant_type == "learning":
103
- self.learning_vectorstore.add_documents(chunks)
104
- self.learning_vectorstore.persist()
105
- elif assistant_type == "code":
106
- self.code_vectorstore.add_documents(chunks)
107
- self.code_vectorstore.persist()
108
- return f"Successfully loaded {len(chunks)} chunks from {len(documents)} documents into {assistant_type} assistant."
109
- except Exception as e:
110
- logger.error(f"Error loading documents: {str(e)}")
111
- return f"Error loading documents: {str(e)}"
112
-
113
- def get_learning_tutor_response(self, question: str) -> str:
114
- try:
115
- if not self.learning_vectorstore:
116
- return "Please upload some learning materials first."
117
- qa_chain = RetrievalQA.from_chain_type(
118
- llm=self.llm,
119
- chain_type="stuff",
120
- retriever=self.learning_vectorstore.as_retriever(search_kwargs={"k": 3}),
121
- return_source_documents=True
122
- )
123
- prompt = f"""
124
- You are an AI learning assistant. Answer the following student question based on uploaded course materials.
125
- Question: {question}
126
- """
127
- result = qa_chain({"query": prompt})
128
- response = result['result']
129
- if result.get('source_documents'):
130
- response += "\n\n**Sources:**\n"
131
- for i, doc in enumerate(result['source_documents'][:3]):
132
- source = doc.metadata.get('source', 'Unknown')
133
- response += f"- {Path(source).name}\n"
134
- return response
135
- except Exception as e:
136
- logger.error(f"Error in learning tutor: {str(e)}")
137
- return f"Error generating response: {str(e)}"
138
-
139
- def get_code_helper_response(self, question: str) -> str:
140
- try:
141
- if not self.code_vectorstore:
142
- return "Please upload some code documentation first."
143
- qa_chain = RetrievalQA.from_chain_type(
144
- llm=self.llm,
145
- chain_type="stuff",
146
- retriever=self.code_vectorstore.as_retriever(search_kwargs={"k": 3}),
147
- return_source_documents=True
148
- )
149
- prompt = f"""
150
- You are a code assistant. Answer the following developer question based on uploaded technical documentation.
151
- Question: {question}
152
- """
153
- result = qa_chain({"query": prompt})
154
- response = result['result']
155
- if result.get('source_documents'):
156
- response += "\n\n**Documentation Sources:**\n"
157
- for i, doc in enumerate(result['source_documents'][:3]):
158
- source = doc.metadata.get('source', 'Unknown')
159
- response += f"- {Path(source).name}\n"
160
- return response
161
- except Exception as e:
162
- logger.error(f"Error in code helper: {str(e)}")
163
- return f"Error generating response: {str(e)}"
164
-
165
- def evaluate_retrieval(query: str, ground_truth_docs: List[str], retriever, k: int = 5):
166
- try:
167
- retrieved_docs = retriever.get_relevant_documents(query)
168
- top_k = [doc.page_content for doc in retrieved_docs[:k]]
169
- hits = sum([1 for doc in top_k if any(gt.lower() in doc.lower() for gt in ground_truth_docs)])
170
- precision = hits / k
171
- recall = hits / len(ground_truth_docs) if ground_truth_docs else 0.0
172
- print("\n Query:", query)
173
- print(" Top-K Retrieved Documents:")
174
- for i, doc in enumerate(top_k, 1):
175
- print(f"{i}. {doc[:200]}...")
176
- print(f"\n Evaluation Results:")
177
- print(f" Precision@{k}: {precision:.2f}")
178
- print(f" Recall@{k}: {recall:.2f}")
179
- return {
180
- f"Precision@{k}": precision,
181
- f"Recall@{k}": recall,
182
- "Hits": hits,
183
- "Retrieved": top_k
184
- }
185
- except Exception as e:
186
- logger.error(f"❌ Error during evaluation: {str(e)}")
187
- return {
188
- f"Precision@{k}": 0.0,
189
- f"Recall@{k}": 0.0,
190
- "Hits": 0,
191
- "Retrieved": []
192
- }
193
-
194
- def main():
195
- load_dotenv()
196
- groq_api_key = os.getenv("GROQ_API_KEY")
197
- if not groq_api_key:
198
- print("Please set your GROQ_API_KEY environment variable")
199
- return
200
- assistant = RAGAssistant(groq_api_key)
201
-
202
- # Example Evaluation
203
- query = "What is supervised learning?"
204
- ground_truth_docs = ["Supervised learning is a type of machine learning where the model learns from labeled data."]
205
- evaluate_retrieval(
206
- query=query,
207
- ground_truth_docs=ground_truth_docs,
208
- retriever=assistant.learning_vectorstore.as_retriever(search_kwargs={"k": 5})
209
- )
210
-
211
- if __name__ == "__main__":
212
- main()
 
 
1
  import gradio as gr
2
+ import os
3
+ import datetime
4
  from langchain.chains import RetrievalQA
5
+ from langchain.vectorstores import Chroma
6
+ from langchain.embeddings import HuggingFaceEmbeddings
7
+ from langchain.llms import OpenAI
8
  from langchain_groq import ChatGroq
9
+ from langchain.text_splitter import CharacterTextSplitter
10
+ from langchain.document_loaders import TextLoader, PyPDFLoader
11
+ from langchain.prompts import PromptTemplate
 
 
 
12
  from sklearn.feature_extraction.text import TfidfVectorizer
13
  from sklearn.metrics.pairwise import cosine_similarity
14
+ from sentence_transformers import SentenceTransformer
15
+ import numpy as np
16
+
17
+ # Embedding Models
18
+ hf_embed = HuggingFaceEmbeddings()
19
+ fallback_model = SentenceTransformer('all-MiniLM-L6-v2')
20
+
21
+ # Vector Store
22
+ vector_store = Chroma(collection_name="ragstore", embedding_function=hf_embed)
23
+
24
+ # LLM
25
+ llm = ChatGroq(temperature=0, model_name="llama3-8b-8192")
26
+
27
+ # Prompt Template
28
+ prompt_template = PromptTemplate.from_template(
29
+ "Answer the following question using ONLY the context provided:\n\n{context}\n\nQuestion: {question}"
30
+ )
31
+
32
+ # RetrievalQA Chain
33
+ qa_chain = RetrievalQA.from_chain_type(
34
+ llm=llm,
35
+ retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
36
+ chain_type="stuff",
37
+ chain_type_kwargs={"prompt": prompt_template}
38
+ )
39
+
40
+ # TF-IDF Fallback
41
def tfidf_fallback(query, documents):
    """Return the single document most similar to *query* via TF-IDF cosine.

    Sparse-retrieval fallback for when dense retrieval is unavailable.

    Args:
        query: The user's question as a plain string.
        documents: Sequence of objects exposing a ``page_content`` string
            (e.g. LangChain ``Document`` instances).

    Returns:
        Tuple ``(best_text, score)``: the best-matching document's text and
        its cosine similarity to the query. Returns ``(None, 0.0)`` when
        *documents* is empty (the previous version raised on this input
        because ``argmax`` of an empty array is an error).
    """
    texts = [doc.page_content for doc in documents]
    if not texts:
        # Guard: np.argmax over an empty similarity array would raise.
        return None, 0.0
    # fit_transform in one pass instead of fit() followed by transform()
    # over the same corpus; the query is the last row of the matrix.
    matrix = TfidfVectorizer().fit_transform(texts + [query])
    cosine_sim = cosine_similarity(matrix[-1], matrix[:-1]).flatten()
    top_idx = int(np.argmax(cosine_sim))
    return texts[top_idx], cosine_sim[top_idx]
48
+
49
+ # Ingestion
50
def ingest_files(files):
    """Load uploaded files, split them into chunks, and index them in Chroma.

    Args:
        files: List of uploaded file objects exposing a ``.name`` path
            (Gradio file objects), or ``None`` when nothing was uploaded.

    Side effects:
        Adds the resulting chunks to the module-level ``vector_store``.
        Per-file failures are logged and skipped rather than aborting the
        whole ingestion (the previous version stopped at the first bad file).
    """
    if not files:
        # Gradio passes None when the button is clicked with no upload.
        return
    # The splitter is loop-invariant — build it once, not per file.
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    for file in files:
        try:
            if file.name.endswith(".pdf"):
                loader = PyPDFLoader(file.name)
            else:
                # Explicit encoding avoids platform-dependent decode errors.
                loader = TextLoader(file.name, encoding="utf-8")
            docs = loader.load()
            chunks = splitter.split_documents(docs)
            if chunks:
                vector_store.add_documents(chunks)
        except Exception as e:
            # One unreadable file must not abort the remaining files.
            print(f"Error ingesting {file.name}: {e}")
59
+
60
+ # Evaluation Info
61
def evaluate_retrieval(query):
    """Fetch the top-3 chunks for *query* plus approximate similarities.

    Returns:
        Tuple ``(top_docs, similarities)``: the chunk texts and
        ``1 - distance`` per chunk — an approximation of cosine similarity.
    """
    scored = vector_store.similarity_search_with_score(query, k=3)
    top_docs = []
    similarities = []
    for document, distance in scored:
        top_docs.append(document.page_content)
        similarities.append(1 - distance)  # cosine similarity approximation
    return top_docs, similarities
67
+
68
+ # Final Response Generator
69
def ask_question(query):
    """Answer *query* with the RAG chain and return all Gradio output fields.

    Args:
        query: The user's question; blank/whitespace input short-circuits.

    Returns:
        5-tuple for the Gradio outputs:
        ``(rag_answer, retrieved_context, confidence_pct, baseline_answer,
        timestamp)``.
    """
    if not query.strip():
        return "", "", "", "", ""

    # Retrieve supporting chunks and their similarity estimates.
    docs, similarities = evaluate_retrieval(query)
    if not docs:
        # Guard: np.mean([]) yields NaN (plus a runtime warning), which the
        # previous version displayed as "nan%". Surface a clear message.
        ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return ("No documents indexed yet — please ingest files first.",
                "", "0.00%", "", ts)

    formatted_docs = "\n\n".join(
        f"Doc {i+1} (Score: {sim*100:.2f}%)\n{doc}"
        for i, (doc, sim) in enumerate(zip(docs, similarities))
    )
    context_block = f"### Top Retrieved Documents:\n{formatted_docs}"

    # Answer from RAG
    answer = qa_chain.run(query)

    # Baseline (direct LLM, no retrieved context). ChatGroq is a chat model,
    # so invoke() returns a message object — unwrap .content so the Textbox
    # shows plain text instead of the object repr.
    baseline = llm.invoke(query)
    baseline_text = getattr(baseline, "content", baseline)

    # Confidence approximation: mean retrieval similarity as a percentage.
    confidence = float(np.mean(similarities)) * 100

    return (
        answer,
        context_block,
        f"{confidence:.2f}%",
        baseline_text,
        datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    )
88
+
89
# Gradio UI
# Two-column layout: left column takes uploads and the question; right column
# shows the RAG answer, its evidence, a confidence estimate, a no-context
# baseline answer, and a timestamp.
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 RAG-Based Learning & Code Assistant\nUpload docs, ask questions, get answers with confidence & evidence.")

    with gr.Row():
        with gr.Column():
            # Inputs: multi-file upload restricted to PDF/TXT, plus a question box.
            file_input = gr.File(label="Upload PDF or TXT", file_types=[".pdf", ".txt"], file_count="multiple")
            ingest_btn = gr.Button("Ingest Documents")
            question_input = gr.Textbox(label="Ask a Question")
            ask_btn = gr.Button("Ask")
        with gr.Column():
            # Outputs: one Textbox per field returned by ask_question.
            answer_output = gr.Textbox(label="RAG Answer", lines=5)
            retrieved_docs_output = gr.Textbox(label="Top 3 Retrieved Documents", lines=10)
            confidence_output = gr.Textbox(label="Confidence (%)")
            baseline_output = gr.Textbox(label="Baseline (Direct LLM)", lines=5)
            timestamp_output = gr.Textbox(label="Timestamp")

    # ingest_files indexes into the shared vector store; it produces no UI
    # output, hence outputs=[].
    ingest_btn.click(fn=ingest_files, inputs=file_input, outputs=[])
    # ask_question returns a 5-tuple mapped positionally onto these outputs.
    ask_btn.click(fn=ask_question, inputs=question_input,
                  outputs=[answer_output, retrieved_docs_output, confidence_output, baseline_output, timestamp_output])

demo.launch()