kamkol committed on
Commit
52657d6
·
1 Parent(s): a1492d7

Fix OpenAI client initialization and add robust error handling for Hugging Face compatibility

Browse files
Files changed (1) hide show
  1. streamlit_app.py +192 -117
streamlit_app.py CHANGED
@@ -95,148 +95,252 @@ Use these tools to provide the best possible answer.
95
  @st.cache_resource
96
  def load_document_chunks():
97
  """Load pre-processed document chunks from disk."""
98
- with open(CHUNKS_FILE, 'rb') as f:
99
- return pickle.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  @st.cache_resource
102
  def get_chat_model():
103
  """Get the chat model for initial RAG."""
 
104
  import os
105
- # Use openai_api_key and model_name for maximum compatibility
106
- return ChatOpenAI(
107
- openai_api_key=os.environ.get("OPENAI_API_KEY"),
108
- model_name="gpt-4.1-mini",
109
- temperature=0
110
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  @st.cache_resource
113
  def get_agent_model():
114
  """Get the more powerful model for agent and evaluation."""
 
115
  import os
116
- # Use openai_api_key and model_name for maximum compatibility
117
- return ChatOpenAI(
118
- openai_api_key=os.environ.get("OPENAI_API_KEY"),
119
- model_name="gpt-4.1",
120
- temperature=0
121
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  @st.cache_resource
124
  def get_embedding_model():
125
  """Get the embedding model."""
 
126
  import os
127
- # Use openai_api_key and model_name for maximum compatibility
128
- return OpenAIEmbeddings(
129
- openai_api_key=os.environ.get("OPENAI_API_KEY"),
130
- model_name="text-embedding-3-small"
131
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  @st.cache_resource
134
  def setup_qdrant_client():
135
  """Set up the Qdrant client."""
136
  import os
137
 
138
- # DEBUG START - HF Compatibility fix
139
- print(f"DEBUG: Setting up Qdrant client with path: {str(QDRANT_DIR)}")
140
- print(f"DEBUG: Qdrant directory exists: {os.path.exists(QDRANT_DIR)}")
141
- # DEBUG END
 
 
 
 
 
 
142
 
143
- # Check if directory exists
144
- if not os.path.exists(QDRANT_DIR):
145
- print(f"WARNING: Qdrant directory does not exist: {str(QDRANT_DIR)}")
146
- raise ValueError(f"Qdrant directory not found at {str(QDRANT_DIR)}")
147
 
148
- # Try creating the client with minimal parameters
149
  try:
150
- return QdrantClient(path=str(QDRANT_DIR))
 
 
151
  except Exception as e:
152
- # DEBUG START
153
- print(f"DEBUG: Error initializing QdrantClient with path: {str(e)}")
154
- # DEBUG END
155
 
156
  # Try with location parameter
157
  try:
158
- return QdrantClient(location=str(QDRANT_DIR))
 
 
159
  except Exception as e2:
160
- # DEBUG START
161
- print(f"DEBUG: Error initializing with location: {str(e2)}")
162
- # DEBUG END
163
 
164
- # Last attempt with in-memory client
165
- print("Attempting to create in-memory client")
166
  return QdrantClient(":memory:")
167
 
168
  def retrieve_documents(query, k=5):
169
  """Retrieve relevant documents for a query."""
170
  # Define collection name
171
  collection_name = "kohavi_ab_testing_pdf_collection"
172
-
173
- # DEBUG START - HF Compatibility fix
174
- print(f"DEBUG: Starting document retrieval for query: '{query[:30]}...'")
175
- print(f"DEBUG: PROCESSED_DATA_DIR exists: {os.path.exists(PROCESSED_DATA_DIR)}")
176
- print(f"DEBUG: CHUNKS_FILE exists: {os.path.exists(CHUNKS_FILE)}")
177
- print(f"DEBUG: QDRANT_DIR exists: {os.path.exists(QDRANT_DIR)}")
178
- # DEBUG END
179
 
180
  try:
 
 
 
181
  # Get models and data
182
  try:
183
  embedding_model = get_embedding_model()
 
184
  except Exception as e:
185
- # DEBUG START
186
- print(f"DEBUG: Error getting embedding model: {str(e)}")
187
- # DEBUG END
188
  return [], []
189
 
190
  try:
191
- chunks = load_document_chunks()
192
- # DEBUG START
193
- print(f"DEBUG: Loaded {len(chunks)} document chunks")
194
- # DEBUG END
 
 
 
 
195
  except Exception as e:
196
- # DEBUG START
197
- print(f"DEBUG: Error loading document chunks: {str(e)}")
198
- # DEBUG END
199
  return [], []
200
 
201
  try:
202
  client = setup_qdrant_client()
203
- # DEBUG START
204
- print("DEBUG: Successfully created Qdrant client")
205
- # DEBUG END
206
  except Exception as e:
207
- # DEBUG START
208
- print(f"DEBUG: Error setting up Qdrant client: {str(e)}")
209
- # DEBUG END
210
  return [], []
211
 
212
  # Check if collection exists
213
  try:
214
  collections = client.get_collections()
215
- # DEBUG START
216
- print(f"DEBUG: Available collections: {collections}")
217
 
218
  collection_info = client.get_collection(collection_name)
219
- print(f"DEBUG: Collection info: {collection_info}")
220
- # DEBUG END
221
  except Exception as e:
222
- # DEBUG START
223
- print(f"DEBUG: Error checking collection: {str(e)}")
224
- # DEBUG END
225
- return [], []
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
  # Create a mapping of IDs to documents
228
  docs_by_id = {i: doc for i, doc in enumerate(chunks)}
229
 
230
  # Get query embedding
231
  try:
 
232
  query_embedding = embedding_model.embed_query(query)
233
- # DEBUG START
234
- print(f"DEBUG: Generated embedding of length {len(query_embedding)}")
235
- # DEBUG END
236
  except Exception as e:
237
- # DEBUG START
238
- print(f"DEBUG: Error creating query embedding: {str(e)}")
239
- # DEBUG END
240
  return [], []
241
 
242
  # Search for relevant documents
@@ -244,50 +348,27 @@ def retrieve_documents(query, k=5):
244
 
245
  # Try different querying approaches
246
  try:
247
- # Simple query_points call
248
- results = client.query_points(
249
  collection_name=collection_name,
250
  query_vector=query_embedding,
251
  limit=k
252
  )
253
- # DEBUG START
254
- print(f"DEBUG: Retrieved {len(results)} results with query_points")
255
- # DEBUG END
256
  except Exception as e1:
257
- # DEBUG START
258
- print(f"DEBUG: First query approach failed: {str(e1)}")
259
- # DEBUG END
260
-
261
- try:
262
- # Try with minimum parameters
263
- results = client.search(
264
- collection_name=collection_name,
265
- query_vector=query_embedding,
266
- limit=k
267
- )
268
- # DEBUG START
269
- print(f"DEBUG: Retrieved {len(results)} results with search method")
270
- # DEBUG END
271
- except Exception as e2:
272
- # DEBUG START
273
- print(f"DEBUG: Second query approach failed: {str(e2)}")
274
- # DEBUG END
275
- return [], []
276
 
277
  # Handle empty results
278
  if not results:
279
- # DEBUG START
280
- print("DEBUG: No results found in vector store")
281
- # DEBUG END
282
  return [], []
283
 
284
  # Process results
285
  documents = []
286
  sources_dict = {}
287
 
288
- # DEBUG START
289
- print(f"DEBUG: Processing {len(results)} search results")
290
- # DEBUG END
291
 
292
  for result in results:
293
  try:
@@ -325,27 +406,21 @@ def retrieve_documents(query, k=5):
325
  "score": float(result.score),
326
  "type": "pdf"
327
  }
328
- # DEBUG START
329
- print(f"DEBUG: Added source: {title}, Page: {page}")
330
- # DEBUG END
331
  except Exception as e:
332
- # DEBUG START
333
- print(f"DEBUG: Error processing result: {str(e)}")
334
- # DEBUG END
335
  continue
336
 
337
  # Convert sources dictionary to list
338
  sources = list(sources_dict.values())
339
 
340
- # DEBUG START
341
- print(f"DEBUG: Returning {len(documents)} documents with {len(sources)} unique sources")
342
- # DEBUG END
343
  return documents, sources
344
 
345
  except Exception as e:
346
- # DEBUG START
347
- print(f"DEBUG: Unexpected error in retrieve_documents: {str(e)}")
348
- # DEBUG END
349
  return [], []
350
 
351
  def rephrase_query(query):
 
95
@st.cache_resource
def load_document_chunks():
    """Load pre-processed document chunks from disk.

    Returns:
        list: The unpickled chunk list, or ``[]`` when the chunks file is
        missing or cannot be unpickled.
    """
    try:
        print(f"Attempting to load chunks from: {CHUNKS_FILE}")
        if not os.path.exists(CHUNKS_FILE):
            print(f"ERROR: Chunks file not found at {CHUNKS_FILE}")
            return []

        with open(CHUNKS_FILE, 'rb') as f:
            chunks = pickle.load(f)
        print(f"Successfully loaded {len(chunks)} document chunks")
        return chunks
    except Exception as e:
        # NOTE: the previous "direct load without caching" fallback re-ran
        # the exact same open/unpickle inside the same call, so it could
        # never succeed where the first attempt failed. A single attempt
        # with a clear error message is equivalent and simpler.
        print(f"Error loading document chunks: {str(e)}")
        return []
120
 
121
@st.cache_resource
def get_chat_model():
    """Get the chat model for initial RAG.

    Returns:
        An object exposing ``invoke(messages)`` compatible with LangChain's
        ``ChatOpenAI``, backed directly by the raw OpenAI client.
    """
    from openai import OpenAI
    import os

    api_key = os.environ.get("OPENAI_API_KEY")
    client = OpenAI(api_key=api_key)

    # Map LangChain message types to roles the OpenAI API accepts.
    # BUG FIX: the previous code forwarded msg.type verbatim, but LangChain
    # uses "human"/"ai" while the API only accepts
    # "user"/"assistant"/"system", so requests built from LangChain
    # messages would be rejected.
    role_map = {"human": "user", "ai": "assistant", "system": "system"}

    class SimpleOpenAIWrapper:
        """Minimal ChatOpenAI stand-in with an ``invoke`` interface."""

        def __init__(self, client, model):
            self.client = client
            self.model = model

        def invoke(self, messages):
            # Convert LangChain messages to OpenAI chat format.
            openai_messages = []
            for msg in messages:
                msg_type = msg.type if hasattr(msg, "type") else "user"
                openai_messages.append({
                    "role": role_map.get(msg_type, "user"),
                    "content": msg.content,
                })

            # Call the OpenAI API directly.
            response = self.client.chat.completions.create(
                model=self.model,
                messages=openai_messages,
                temperature=0,
            )

            # Simple object with a .content attribute to match the
            # LangChain response interface.
            class SimpleResponse:
                def __init__(self, content):
                    self.content = content

            return SimpleResponse(response.choices[0].message.content)

    # Return wrapper that matches the LangChain interface.
    return SimpleOpenAIWrapper(client, "gpt-4.1-mini")
161
 
162
@st.cache_resource
def get_agent_model():
    """Get the more powerful model for agent and evaluation.

    Returns:
        An object exposing ``invoke(messages)`` compatible with LangChain's
        ``ChatOpenAI``, backed directly by the raw OpenAI client.
    """
    from openai import OpenAI
    import os

    api_key = os.environ.get("OPENAI_API_KEY")
    client = OpenAI(api_key=api_key)

    # Map LangChain message types to roles the OpenAI API accepts.
    # BUG FIX: the previous code forwarded msg.type verbatim, but LangChain
    # uses "human"/"ai" while the API only accepts
    # "user"/"assistant"/"system", so requests built from LangChain
    # messages would be rejected.
    role_map = {"human": "user", "ai": "assistant", "system": "system"}

    class SimpleOpenAIWrapper:
        """Minimal ChatOpenAI stand-in with an ``invoke`` interface."""

        def __init__(self, client, model):
            self.client = client
            self.model = model

        def invoke(self, messages):
            # Convert LangChain messages to OpenAI chat format.
            openai_messages = []
            for msg in messages:
                msg_type = msg.type if hasattr(msg, "type") else "user"
                openai_messages.append({
                    "role": role_map.get(msg_type, "user"),
                    "content": msg.content,
                })

            # Call the OpenAI API directly.
            response = self.client.chat.completions.create(
                model=self.model,
                messages=openai_messages,
                temperature=0,
            )

            # Simple object with a .content attribute to match the
            # LangChain response interface.
            class SimpleResponse:
                def __init__(self, content):
                    self.content = content

            return SimpleResponse(response.choices[0].message.content)

    # Return wrapper that matches the LangChain interface.
    return SimpleOpenAIWrapper(client, "gpt-4.1")
202
 
203
@st.cache_resource
def get_embedding_model():
    """Get the embedding model.

    Returns:
        An object exposing ``embed_query``/``embed_documents`` compatible
        with LangChain's ``OpenAIEmbeddings``, backed by the raw client.
    """
    from openai import OpenAI
    import os
    # NOTE: the previous version imported numpy here but never used it.

    api_key = os.environ.get("OPENAI_API_KEY")
    client = OpenAI(api_key=api_key)

    class SimpleEmbeddings:
        """Minimal OpenAIEmbeddings stand-in using the raw OpenAI client."""

        def __init__(self, client):
            self.client = client

        def embed_query(self, text):
            print(f"Embedding query: {text[:50]}...")
            response = self.client.embeddings.create(
                model="text-embedding-3-small",
                input=text
            )
            return response.data[0].embedding

        def embed_documents(self, texts):
            # Batch all texts into one API call instead of one request per
            # document — the embeddings endpoint accepts a list input.
            if not texts:
                return []
            response = self.client.embeddings.create(
                model="text-embedding-3-small",
                input=list(texts)
            )
            # The API reports each item's position in ``index``; sort to
            # guarantee output order matches input order.
            ordered = sorted(response.data, key=lambda item: item.index)
            return [item.embedding for item in ordered]

    return SimpleEmbeddings(client)
230
 
231
@st.cache_resource
def setup_qdrant_client():
    """Set up the Qdrant client.

    Tries an on-disk client via ``path=`` first, then via ``location=``,
    and finally falls back to an in-memory client so the app can still
    start (e.g. on Hugging Face Spaces without persisted storage).
    """
    import os

    # Diagnostic output to help debug deployment filesystem layout.
    has_processed_dir = os.path.exists(PROCESSED_DATA_DIR)
    print(f"PROCESSED_DATA_DIR exists: {has_processed_dir}")
    print(f"Contents of current directory: {os.listdir('.')}")
    if has_processed_dir:
        print(f"Contents of PROCESSED_DATA_DIR: {os.listdir(PROCESSED_DATA_DIR)}")

    has_qdrant_dir = os.path.exists(QDRANT_DIR)
    print(f"QDRANT_DIR exists: {has_qdrant_dir}")
    if has_qdrant_dir:
        print(f"Contents of QDRANT_DIR: {os.listdir(QDRANT_DIR)}")

    # Attempt ``path=`` then ``location=``; return on first success.
    for param_name in ("path", "location"):
        try:
            client = QdrantClient(**{param_name: str(QDRANT_DIR)})
        except Exception as err:
            print(f"Error creating QdrantClient with {param_name}: {str(err)}")
        else:
            print(f"Successfully created QdrantClient with {param_name} parameter")
            return client

    # Last resort - try in-memory.
    print("Creating in-memory QdrantClient as fallback")
    return QdrantClient(":memory:")
269
 
270
  def retrieve_documents(query, k=5):
271
  """Retrieve relevant documents for a query."""
272
  # Define collection name
273
  collection_name = "kohavi_ab_testing_pdf_collection"
274
+ print(f"======= QUERY: {query} =======")
 
 
 
 
 
 
275
 
276
  try:
277
+ # Check for processed data
278
+ print(f"CHUNKS_FILE exists: {os.path.exists(CHUNKS_FILE)}")
279
+
280
  # Get models and data
281
  try:
282
  embedding_model = get_embedding_model()
283
+ print("Successfully created embedding model")
284
  except Exception as e:
285
+ print(f"Error getting embedding model: {str(e)}")
286
+ # Try to fallback to direct API call instead of using LangChain
 
287
  return [], []
288
 
289
  try:
290
+ print("Loading document chunks...")
291
+ if not os.path.exists(CHUNKS_FILE):
292
+ print(f"ERROR: CHUNKS_FILE not found at {CHUNKS_FILE}")
293
+ return [], []
294
+
295
+ with open(CHUNKS_FILE, 'rb') as f:
296
+ chunks = pickle.load(f)
297
+ print(f"Successfully loaded {len(chunks)} document chunks")
298
  except Exception as e:
299
+ print(f"Error loading document chunks: {str(e)}")
 
 
300
  return [], []
301
 
302
  try:
303
  client = setup_qdrant_client()
304
+ print("Successfully created Qdrant client")
 
 
305
  except Exception as e:
306
+ print(f"Error setting up Qdrant client: {str(e)}")
 
 
307
  return [], []
308
 
309
  # Check if collection exists
310
  try:
311
  collections = client.get_collections()
312
+ print(f"Available collections: {collections}")
 
313
 
314
  collection_info = client.get_collection(collection_name)
315
+ print(f"Collection info: {collection_info}")
 
316
  except Exception as e:
317
+ print(f"Error checking collection: {str(e)}")
318
+ try:
319
+ # Try to initialize collection
320
+ print("Attempting to create collection...")
321
+ sample_embedding = embedding_model.embed_query("sample")
322
+ client.create_collection(
323
+ collection_name=collection_name,
324
+ vectors_config={
325
+ "size": len(sample_embedding),
326
+ "distance": "Cosine"
327
+ }
328
+ )
329
+ print(f"Created new collection {collection_name}")
330
+ except Exception as e2:
331
+ print(f"Failed to create collection: {str(e2)}")
332
+ return [], []
333
 
334
  # Create a mapping of IDs to documents
335
  docs_by_id = {i: doc for i, doc in enumerate(chunks)}
336
 
337
  # Get query embedding
338
  try:
339
+ print(f"Generating embedding for query: {query}")
340
  query_embedding = embedding_model.embed_query(query)
341
+ print(f"Successfully generated embedding of length {len(query_embedding)}")
 
 
342
  except Exception as e:
343
+ print(f"Error creating query embedding: {str(e)}")
 
 
344
  return [], []
345
 
346
  # Search for relevant documents
 
348
 
349
  # Try different querying approaches
350
  try:
351
+ print(f"Querying collection {collection_name}")
352
+ results = client.search(
353
  collection_name=collection_name,
354
  query_vector=query_embedding,
355
  limit=k
356
  )
357
+ print(f"Retrieved {len(results)} results with search method")
 
 
358
  except Exception as e1:
359
+ print(f"Search failed: {str(e1)}")
360
+ return [], []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
362
  # Handle empty results
363
  if not results:
364
+ print("No results found in vector store")
 
 
365
  return [], []
366
 
367
  # Process results
368
  documents = []
369
  sources_dict = {}
370
 
371
+ print(f"Processing {len(results)} search results")
 
 
372
 
373
  for result in results:
374
  try:
 
406
  "score": float(result.score),
407
  "type": "pdf"
408
  }
409
+ print(f"Added source: {title}, Page: {page}")
 
 
410
  except Exception as e:
411
+ print(f"Error processing result: {str(e)}")
 
 
412
  continue
413
 
414
  # Convert sources dictionary to list
415
  sources = list(sources_dict.values())
416
 
417
+ print(f"Returning {len(documents)} documents with {len(sources)} unique sources")
 
 
418
  return documents, sources
419
 
420
  except Exception as e:
421
+ print(f"Unexpected error in retrieve_documents: {str(e)}")
422
+ import traceback
423
+ traceback.print_exc()
424
  return [], []
425
 
426
  def rephrase_query(query):