MerveA committed
Commit 3445f05 · 1 Parent(s): 97a95aa

Fix langchain dependency for HF Space

Files changed (2):
  1. app.py +88 -185
  2. requirements.txt +9 -13
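In short: app.py stops importing langchain / langchain_google_genai at module load (which was breaking the Space) and instead imports its heavy dependencies lazily inside initialize_rag_system, calling Gemini through the google-generativeai client directly. A minimal sketch of the swap, using only calls that appear in the new app.py (the API key and prompt below are placeholders):

import google.generativeai as genai

# Old path (removed): ChatGoogleGenerativeAI(...).invoke([...]) via LangChain.
# New path: the direct client, as used in the new app.py.
genai.configure(api_key="YOUR_GOOGLE_API_KEY")  # placeholder; the app takes the key from user input
model = genai.GenerativeModel("gemini-2.0-flash-exp")
response = model.generate_content("What is overfitting?")  # illustrative prompt
print(response.text)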
app.py CHANGED
@@ -1,19 +1,8 @@
 import streamlit as st
 import os
 import json
-import chromadb
-from chromadb.config import Settings
-from sentence_transformers import SentenceTransformer
-from langchain_google_genai import ChatGoogleGenerativeAI
-from langchain.schema import HumanMessage, SystemMessage
 import time
 from datetime import datetime
-import uuid
-import pandas as pd
-import numpy as np
-from datasets import load_dataset
-from tqdm import tqdm
-import re

 # Page configuration
 st.set_page_config(
@@ -69,80 +58,26 @@ if 'rag_system' not in st.session_state:
 if 'initialized' not in st.session_state:
     st.session_state.initialized = False

-# RAG System Functions (from notebook)
-def chunk_text(text, chunk_size=500, overlap=50):
-    """Split text into overlapping chunks"""
-    words = text.split()
-    chunks = []
-
-    for i in range(0, len(words), chunk_size - overlap):
-        chunk = ' '.join(words[i:i + chunk_size])
-        if len(chunk.strip()) > 50:  # Only keep substantial chunks
-            chunks.append(chunk)
-
-    return chunks
-
-def load_and_process_dataset():
-    """Load and process The Pile dataset"""
-    print("📚 Loading The Pile dataset...")
-
-    try:
-        # Load a specific subset that contains ML/AI content
-        dataset = load_dataset("EleutherAI/the_pile", split="train", streaming=True)
-
-        # Take first 1000 samples for demonstration
-        texts = []
-        ml_keywords = ['machine learning', 'deep learning', 'neural network', 'artificial intelligence',
-                       'algorithm', 'model', 'training', 'data', 'feature', 'classification',
-                       'regression', 'clustering', 'optimization', 'gradient', 'tensor']
-
-        print("🔍 Filtering ML/AI related content...")
-        count = 0
-        for sample in tqdm(dataset, desc="Processing samples"):
-            if count >= 1000:  # Limit to 1000 samples for demo
-                break
-
-            text = sample['text']
-            # Check if text contains ML/AI keywords
-            if any(keyword in text.lower() for keyword in ml_keywords):
-                # Clean and preprocess text
-                text = re.sub(r'\s+', ' ', text)  # Remove extra whitespace
-                text = text.strip()
-
-                # Only keep texts that are reasonable length (not too short or too long)
-                if 100 <= len(text) <= 2000:
-                    texts.append(text)
-                    count += 1
-
-        print(f"✅ Loaded {len(texts)} ML/AI related text samples")
-        return texts
-
-    except Exception as e:
-        print(f"❌ Error loading dataset: {e}")
-        print("🔄 Using fallback sample data...")
-
-        # Fallback sample data if The Pile is not accessible
-        texts = [
-            "Machine learning is a subset of artificial intelligence that focuses on algorithms that can learn from data. Deep learning uses neural networks with multiple layers to process complex patterns in data.",
-            "Neural networks are computing systems inspired by biological neural networks. They consist of interconnected nodes that process information using a connectionist approach.",
-            "Supervised learning uses labeled training data to learn a mapping from inputs to outputs. Common algorithms include linear regression, decision trees, and support vector machines.",
-            "Unsupervised learning finds hidden patterns in data without labeled examples. Clustering algorithms like K-means group similar data points together.",
-            "Natural language processing combines computational linguistics with machine learning to help computers understand human language. It includes tasks like text classification and sentiment analysis.",
-            "Computer vision enables machines to interpret and understand visual information from the world. It uses deep learning models like convolutional neural networks.",
-            "Reinforcement learning is a type of machine learning where agents learn to make decisions by interacting with an environment and receiving rewards or penalties.",
-            "Feature engineering is the process of selecting and transforming raw data into features that can be used by machine learning algorithms. Good features can significantly improve model performance.",
-            "Cross-validation is a technique used to assess how well a machine learning model generalizes to new data. It involves splitting data into training and validation sets multiple times.",
-            "Overfitting occurs when a model learns the training data too well and performs poorly on new data. Regularization techniques help prevent overfitting."
-        ]
-        print(f"✅ Using {len(texts)} sample texts")
-        return texts
-
+# RAG System Functions
 def initialize_rag_system(api_key):
     """Initialize the RAG system with all components"""
     try:
         # Set API key
         os.environ['GOOGLE_API_KEY'] = api_key

+        # Import required libraries with error handling
+        try:
+            from sentence_transformers import SentenceTransformer
+            import chromadb
+            from chromadb.config import Settings
+            import google.generativeai as genai
+            from datasets import load_dataset
+            from tqdm import tqdm
+            import re
+        except ImportError as e:
+            st.error(f"Import error: {e}")
+            return None
+
         # Initialize embedding model
         embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

@@ -155,148 +90,117 @@ def initialize_rag_system(api_key):
         collection_name = "ml_ai_knowledge"
         try:
             collection = chroma_client.get_collection(collection_name)
-            print(f"✅ Found existing collection: {collection_name}")
         except:
             collection = chroma_client.create_collection(
                 name=collection_name,
                 metadata={"description": "ML/AI knowledge base from The Pile dataset"}
             )
-            print(f"✅ Created new collection: {collection_name}")

         # Check if collection already has data
         existing_count = collection.count()
-        print(f"📊 Current documents in collection: {existing_count}")

         if existing_count == 0:
-            print("🔄 Adding new documents to collection...")
-
-            # Load and process dataset
-            texts = load_and_process_dataset()
+            # Load sample data for demo
+            sample_texts = [
+                "Machine learning is a subset of artificial intelligence that focuses on algorithms that can learn from data. Deep learning uses neural networks with multiple layers to process complex patterns in data.",
+                "Neural networks are computing systems inspired by biological neural networks. They consist of interconnected nodes that process information using a connectionist approach.",
+                "Supervised learning uses labeled training data to learn a mapping from inputs to outputs. Common algorithms include linear regression, decision trees, and support vector machines.",
+                "Unsupervised learning finds hidden patterns in data without labeled examples. Clustering algorithms like K-means group similar data points together.",
+                "Natural language processing combines computational linguistics with machine learning to help computers understand human language. It includes tasks like text classification and sentiment analysis.",
+                "Computer vision enables machines to interpret and understand visual information from the world. It uses deep learning models like convolutional neural networks.",
+                "Reinforcement learning is a type of machine learning where agents learn to make decisions by interacting with an environment and receiving rewards or penalties.",
+                "Feature engineering is the process of selecting and transforming raw data into features that can be used by machine learning algorithms. Good features can significantly improve model performance.",
+                "Cross-validation is a technique used to assess how well a machine learning model generalizes to new data. It involves splitting data into training and validation sets multiple times.",
+                "Overfitting occurs when a model learns the training data too well and performs poorly on new data. Regularization techniques help prevent overfitting.",
+                "Gradient descent is an optimization algorithm used to minimize the cost function in machine learning models. It iteratively adjusts parameters to find the minimum of the function.",
+                "Backpropagation is a method used to train neural networks by calculating gradients and updating weights. It works by propagating errors backward through the network layers.",
+                "Convolutional Neural Networks (CNNs) are specialized neural networks designed for processing grid-like data such as images. They use convolutional layers to detect local features.",
+                "Transformers are a type of neural network architecture that uses attention mechanisms to process sequential data. They are the foundation of modern language models like GPT.",
+                "Large Language Models (LLMs) are AI systems trained on vast amounts of text data to understand and generate human-like text. They can perform various language tasks.",
+                "Generative AI refers to AI systems that can create new content, such as text, images, or code. It differs from predictive AI which focuses on making predictions.",
+                "Transfer learning is a technique where a model trained on one task is adapted for a different but related task. It can significantly reduce training time and improve performance.",
+                "Hyperparameter tuning is the process of finding the optimal hyperparameters for a machine learning model. Common methods include grid search and random search.",
+                "Regularization techniques like L1 and L2 regularization help prevent overfitting by adding penalty terms to the loss function. They encourage simpler models.",
+                "Activation functions introduce non-linearity into neural networks. Common activation functions include ReLU, sigmoid, and tanh."
+            ]

+            # Add sample documents to Chroma
             all_chunks = []
             chunk_ids = []
             chunk_metadatas = []

-            for i, text in enumerate(tqdm(texts, desc="Processing texts")):
-                chunks = chunk_text(text)
+            for i, text in enumerate(sample_texts):
+                chunk_id = f"sample_doc_{i}"
+                metadata = {
+                    "source": f"sample_doc_{i}",
+                    "chunk_index": 0,
+                    "total_chunks": 1,
+                    "text_length": len(text)
+                }

-                for j, chunk in enumerate(chunks):
-                    chunk_id = f"doc_{i}_chunk_{j}"
-                    metadata = {
-                        "source": f"the_pile_doc_{i}",
-                        "chunk_index": j,
-                        "total_chunks": len(chunks),
-                        "text_length": len(chunk)
-                    }
-
-                    all_chunks.append(chunk)
-                    chunk_ids.append(chunk_id)
-                    chunk_metadatas.append(metadata)
-
-            print(f"📊 Created {len(all_chunks)} text chunks")
+                all_chunks.append(text)
+                chunk_ids.append(chunk_id)
+                chunk_metadatas.append(metadata)

-            # Add documents to Chroma in batches to avoid memory issues
-            batch_size = 100
-            for i in tqdm(range(0, len(all_chunks), batch_size), desc="Adding to Chroma"):
-                batch_chunks = all_chunks[i:i + batch_size]
-                batch_ids = chunk_ids[i:i + batch_size]
-                batch_metadatas = chunk_metadatas[i:i + batch_size]
-
-                collection.add(
-                    documents=batch_chunks,
-                    ids=batch_ids,
-                    metadatas=batch_metadatas
-                )
-
-            print("✅ All documents added to Chroma!")
-        else:
-            print("✅ Collection already contains data, skipping addition")
+            # Add documents to Chroma
+            collection.add(
+                documents=all_chunks,
+                ids=chunk_ids,
+                metadatas=chunk_metadatas
+            )

-        # Initialize Gemini
-        llm = ChatGoogleGenerativeAI(
-            model="gemini-2.0-flash-exp",
-            temperature=0.7,
-            max_output_tokens=1024,
-            convert_system_message_to_human=True
-        )
+        # Initialize Gemini using direct API instead of LangChain
+        genai.configure(api_key=api_key)

         return {
             'embedding_model': embedding_model,
             'chroma_client': chroma_client,
             'collection': collection,
-            'llm': llm
+            'genai': genai
         }
     except Exception as e:
         st.error(f"Error initializing RAG system: {e}")
         return None

-def retrieve_relevant_docs(query, collection, n_results=5):
-    """Retrieve relevant documents from Chroma"""
+def rag_pipeline(query, rag_system, n_results=5):
+    """Complete RAG pipeline using direct Gemini API"""
     try:
+        collection = rag_system['collection']
+        genai = rag_system['genai']
+
+        # Retrieve relevant documents
         results = collection.query(
             query_texts=[query],
             n_results=n_results
         )

-        # Extract documents and metadata
         documents = results['documents'][0]
-        metadatas = results['metadatas'][0]
         distances = results['distances'][0]

-        return documents, metadatas, distances
-    except Exception as e:
-        print(f"Error retrieving documents: {e}")
-        return [], [], []
-
-def create_context(documents):
-    """Create context string from retrieved documents"""
-    context = "\n\n".join(documents)
-    return context
-
-def generate_answer(query, context, llm):
-    """Generate answer using Gemini with retrieved context"""
-    system_prompt = """You are an AI assistant specialized in machine learning, deep learning, and artificial intelligence.
-    Use the provided context to answer questions accurately and comprehensively. If the context doesn't contain enough
-    information, you can supplement with your general knowledge, but always prioritize the provided context.
-
-    Provide clear, well-structured answers with examples when appropriate."""
-
-    user_prompt = f"""Context:
-    {context}
-
-    Question: {query}
-
-    Please provide a comprehensive answer based on the context above."""
-
-    try:
-        messages = [
-            SystemMessage(content=system_prompt),
-            HumanMessage(content=user_prompt)
-        ]
-
-        response = llm.invoke(messages)
-        return response.content
-    except Exception as e:
-        return f"Error generating answer: {e}"
-
-def rag_pipeline(query, rag_system, n_results=5):
-    """Complete RAG pipeline"""
-    try:
-        collection = rag_system['collection']
-        llm = rag_system['llm']
-
-        # Retrieve relevant documents
-        documents, metadatas, distances = retrieve_relevant_docs(query, collection, n_results)
-
         if not documents:
             return "I couldn't find relevant information for your query. Please try asking about machine learning, deep learning, or AI topics."

         # Create context
-        context = create_context(documents)
+        context = "\n\n".join(documents)
+
+        # Generate answer using direct Gemini API
+        model = genai.GenerativeModel('gemini-2.0-flash-exp')
+
+        prompt = f"""You are an AI assistant specialized in machine learning, deep learning, and artificial intelligence.
+        Use the provided context to answer questions accurately and comprehensively. If the context doesn't contain enough
+        information, you can supplement with your general knowledge, but always prioritize the provided context.
+
+        Provide clear, well-structured answers with examples when appropriate.
+
+        Context:
+        {context}
+
+        Question: {query}
+
+        Please provide a comprehensive answer based on the context above."""

-        # Generate answer
-        answer = generate_answer(query, context, llm)
-        return answer, documents, distances
+        response = model.generate_content(prompt)
+        return response.text, documents, distances

     except Exception as e:
         return f"Error generating response: {e}", [], []
@@ -305,7 +209,7 @@ def rag_pipeline(query, rag_system, n_results=5):
 st.markdown("""
 <div class="main-header">
     <h1>🤖 RAG Chatbot: ML/AI Assistant</h1>
-    <p>Powered by Google Gemini 2.5 Flash + LangChain + Chroma</p>
+    <p>Powered by Google Gemini 2.5 Flash + Chroma + Direct API</p>
 </div>
 """, unsafe_allow_html=True)
@@ -379,14 +283,13 @@ if not st.session_state.initialized:
 deep learning, AI, and related topics using:

 - **🤖 Generation Model**: Google Gemini 2.5 Flash
-- **🔗 RAG Framework**: LangChain
 - **🗄️ Vector Database**: Chroma
-- **📚 Dataset**: The Pile (EleutherAI/the_pile) from Hugging Face
+- **📚 Dataset**: Sample ML/AI knowledge base
 - **🌐 Interface**: Streamlit

 ### 🚀 How It Works

-1. **Data Loading**: Text data from The Pile dataset is loaded and filtered for ML/AI content
+1. **Data Loading**: Sample ML/AI content is loaded
 2. **Embedding**: Text is processed and embedded using sentence transformers
 3. **Storage**: Embeddings are stored in Chroma vector database
 4. **Retrieval**: Relevant context is retrieved for user queries
@@ -459,7 +362,7 @@ else:
 st.markdown("---")
 st.markdown("""
 <div style="text-align: center; color: #666; padding: 1rem;">
-    <p>🤖 RAG Chatbot | Powered by Google Gemini 2.5 Flash + LangChain + Chroma</p>
-    <p>📚 Knowledge Base: The Pile Dataset (EleutherAI/the_pile)</p>
+    <p>🤖 RAG Chatbot | Powered by Google Gemini 2.5 Flash + Chroma</p>
+    <p>📚 Knowledge Base: ML/AI Sample Dataset</p>
 </div>
 """, unsafe_allow_html=True)
requirements.txt CHANGED
@@ -1,13 +1,9 @@
-streamlit
-langchain
-langchain-community
-langchain-google-genai
-chromadb
-datasets
-transformers
-sentence-transformers
-google-generativeai
-tiktoken
-numpy
-pandas
-tqdm
+# Core dependencies for Hugging Face Spaces
+streamlit==1.28.1
+chromadb==0.4.18
+sentence-transformers==2.2.2
+google-generativeai==0.3.2
+numpy==1.24.3
+pandas==2.0.3
+tqdm==4.66.1
+huggingface-hub>=0.16.4,<1.0.0
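A quick, hypothetical smoke test for the slimmed-down dependency set (not part of the commit), confirming langchain is gone and the pinned imports resolve inside the Space:

import importlib.util

# langchain should no longer be installed
assert importlib.util.find_spec("langchain") is None, "langchain still present"

# the pinned runtime dependencies should all resolve
for mod in ("streamlit", "chromadb", "sentence_transformers", "google.generativeai"):
    assert importlib.util.find_spec(mod) is not None, f"missing module: {mod}"
print("dependency check passed")

One caveat worth knowing: the lazy import block in the new app.py still references datasets, which is no longer pinned here, so that branch relies on the ImportError guard (or a transitive install) rather than an explicit requirement.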