kamkol committed on
Commit
ac15e0f
·
1 Parent(s): 5f13732

Fix DNS resolution errors and restore original UI

Browse files
Files changed (1) hide show
  1. streamlit_app.py +224 -263
streamlit_app.py CHANGED
@@ -204,143 +204,85 @@ def get_chat_model():
204
  """Get the chat model for initial RAG."""
205
  print("Initializing chat model...")
206
  try:
207
- # Very minimal OpenAI initialization for Hugging Face compatibility
208
  openai_api_key = os.environ.get("OPENAI_API_KEY", "")
209
  if not openai_api_key:
210
  print("WARNING: OPENAI_API_KEY environment variable not set!")
211
  raise ValueError("OpenAI API key not found")
 
 
 
 
 
 
212
 
213
- # First try: OpenAI client
214
- try:
215
- openai_client = OpenAI(api_key=openai_api_key)
216
-
217
- # Test the connection
218
- print("Testing OpenAI chat API connection...")
219
- test_response = openai_client.chat.completions.create(
220
- model="gpt-3.5-turbo",
221
- messages=[{"role": "user", "content": "test"}],
222
- max_tokens=5
223
- )
224
- print("OpenAI chat API connection successful!")
225
-
226
- # Create a simplified wrapper that avoids any problematic parameters
227
- class SimpleOpenAIWrapper:
228
- def invoke(self, messages):
229
- print("Invoking chat model via client...")
230
- # Convert LangChain messages to OpenAI format
231
- openai_messages = []
232
- for msg in messages:
233
- role = "user"
234
- if hasattr(msg, "type"):
235
- role = "assistant" if msg.type == "ai" else "user"
236
- openai_messages.append({
237
- "role": role,
238
- "content": msg.content
239
- })
240
 
241
- # Call API directly with absolutely minimal parameters
242
- try:
243
- response = openai_client.chat.completions.create(
244
- model="gpt-3.5-turbo", # Use a minimal, widely supported model
245
- messages=openai_messages
246
- )
247
-
248
- # Create response object
249
- class SimpleResponse:
250
- def __init__(self, content):
251
- self.content = content
252
-
253
- result = SimpleResponse(response.choices[0].message.content)
254
- print(f"Got response of length: {len(result.content)}")
255
- return result
256
- except Exception as e:
257
- print(f"Error calling OpenAI API: {str(e)}")
258
- # Fallback to HTTP request
259
- content = http_chat_request(openai_messages, openai_api_key)
 
 
 
 
 
 
260
  return type('obj', (object,), {'content': content})
261
-
262
- return SimpleOpenAIWrapper()
263
-
264
- except Exception as e:
265
- print(f"OpenAI client approach failed: {str(e)}")
266
- print("Falling back to direct HTTP requests...")
267
- raise # Continue to HTTP fallback
268
-
269
- except Exception as e:
270
- print(f"Standard chat model approach failed: {str(e)}")
 
 
271
 
272
- # Direct HTTP fallback approach
273
- try:
274
- openai_api_key = os.environ.get("OPENAI_API_KEY", "")
275
- if not openai_api_key:
276
- raise ValueError("OpenAI API key not found")
277
-
278
- # Test HTTP connection
279
- print("Testing direct HTTP connection to OpenAI chat...")
280
- test_message = http_chat_request([{"role": "user", "content": "test"}], openai_api_key)
281
- if not test_message:
282
- raise ValueError("HTTP chat fallback test failed")
283
- print("Direct HTTP chat connection successful!")
284
-
285
- class HTTPChatModel:
286
- def invoke(self, messages):
287
- print("Invoking chat model via HTTP...")
288
- # Convert LangChain messages to OpenAI format
289
- openai_messages = []
290
- for msg in messages:
291
- role = "user"
292
- if hasattr(msg, "type"):
293
- role = "assistant" if msg.type == "ai" else "user"
294
- openai_messages.append({
295
- "role": role,
296
- "content": msg.content
297
- })
298
-
299
- content = http_chat_request(openai_messages, openai_api_key)
300
- return type('obj', (object,), {'content': content})
301
-
302
- return HTTPChatModel()
303
-
304
- except Exception as e:
305
- print(f"All chat model approaches failed: {str(e)}")
306
- # Create dummy for testing
307
- class DummyModel:
308
- def invoke(self, messages):
309
- print("WARNING: Using dummy model!")
310
- return type('obj', (object,), {'content': 'I apologize, but I cannot access the necessary data to answer this question due to API connectivity issues.'})
311
-
312
- return DummyModel()
313
-
314
- # Add HTTP chat completion function
315
- def http_chat_request(messages, api_key):
316
- """Make a direct HTTP request to OpenAI chat API."""
317
- import requests
318
- import json
319
-
320
- print("Using direct HTTP request for chat completion")
321
- url = "https://api.openai.com/v1/chat/completions"
322
- headers = {
323
- "Content-Type": "application/json",
324
- "Authorization": f"Bearer {api_key}"
325
- }
326
- data = {
327
- "model": "gpt-3.5-turbo",
328
- "messages": messages
329
- }
330
-
331
- try:
332
- response = requests.post(url, headers=headers, data=json.dumps(data))
333
- if response.status_code == 200:
334
- result = response.json()
335
- content = result["choices"][0]["message"]["content"]
336
- print(f"Successfully got chat response via HTTP (length: {len(content)})")
337
- return content
338
- else:
339
- print(f"HTTP chat request failed with status {response.status_code}: {response.text}")
340
- return "I apologize, but I encountered an error connecting to the AI service."
341
  except Exception as e:
342
- print(f"HTTP chat request exception: {str(e)}")
343
- return "I apologize, but I encountered a technical issue while processing your request."
 
 
 
 
 
 
344
 
345
  @st.cache_resource
346
  def get_agent_model():
@@ -354,127 +296,80 @@ def get_embedding_model():
354
  """Get the embedding model."""
355
  print("Initializing embedding model...")
356
  try:
357
- # Very minimal OpenAI initialization for Hugging Face compatibility
358
  openai_api_key = os.environ.get("OPENAI_API_KEY", "")
359
  if not openai_api_key:
360
  print("WARNING: OPENAI_API_KEY environment variable not set!")
361
  raise ValueError("OpenAI API key not found")
 
 
 
 
 
 
362
 
363
- # First try: Use OpenAI client
364
- try:
365
- openai_client = OpenAI(api_key=openai_api_key)
366
-
367
- # Test the connection
368
- print("Testing OpenAI API connection...")
369
- response = openai_client.embeddings.create(
370
- model="text-embedding-ada-002",
371
- input="Test"
372
- )
373
- print("OpenAI API connection successful!")
374
-
375
- # Create a wrapper that avoids any problematic parameters
376
- class SimpleEmbeddings:
377
- def embed_query(self, text):
378
- print(f"Embedding query of length: {len(text)}")
379
- try:
380
- response = openai_client.embeddings.create(
381
- model="text-embedding-ada-002", # Use older, more compatible model
382
- input=text
383
- )
 
 
 
 
 
384
  print("Successfully got embedding")
385
- return response.data[0].embedding
386
- except Exception as e:
387
- print(f"Error in embed_query: {str(e)}")
388
- # Fall back to direct HTTP request
389
- return http_embed_request(text, openai_api_key)
390
-
391
- def embed_documents(self, texts):
392
- print(f"Embedding {len(texts)} documents")
393
- results = []
394
- for i, text in enumerate(texts):
395
- results.append(self.embed_query(text))
396
- return results
397
-
398
- return SimpleEmbeddings()
399
-
400
- except Exception as e:
401
- print(f"OpenAI client failed: {str(e)}")
402
- print("Falling back to direct HTTP requests...")
403
- raise # Continue to HTTP fallback
404
 
 
 
 
 
 
 
 
 
405
  except Exception as e:
406
- print(f"Standard embedding approach failed: {str(e)}")
407
 
408
- # Direct HTTP fallback approach
409
- try:
410
- openai_api_key = os.environ.get("OPENAI_API_KEY", "")
411
- if not openai_api_key:
412
- raise ValueError("OpenAI API key not found")
413
-
414
- # Test the connection with direct HTTP
415
- print("Testing direct HTTP connection to OpenAI...")
416
- test_embedding = http_embed_request("Test", openai_api_key)
417
- if not test_embedding:
418
- raise ValueError("HTTP fallback test failed")
419
- print("Direct HTTP connection successful!")
420
-
421
- class HTTPEmbeddings:
422
- def embed_query(self, text):
423
- print(f"HTTP embedding query of length: {len(text)}")
424
- return http_embed_request(text, openai_api_key)
425
-
426
- def embed_documents(self, texts):
427
- print(f"HTTP embedding {len(texts)} documents")
428
- results = []
429
- for text in texts:
430
- results.append(self.embed_query(text))
431
- return results
432
 
433
- return HTTPEmbeddings()
434
-
435
- except Exception as e:
436
- print(f"All embedding approaches failed: {str(e)}")
437
- # Last resort: Dummy implementation
438
- print("Using dummy embeddings as last resort")
439
- class DummyEmbeddings:
440
- def embed_query(self, text):
441
- print("WARNING: Using dummy embeddings!")
442
- return [0.0] * 1536
443
-
444
- def embed_documents(self, texts):
445
- return [[0.0] * 1536 for _ in range(len(texts))]
446
-
447
- return DummyEmbeddings()
448
-
449
- # Add HTTP fallback function
450
- def http_embed_request(text, api_key):
451
- """Make a direct HTTP request to OpenAI embeddings API."""
452
- import requests
453
- import json
454
-
455
- print("Using direct HTTP request for embedding")
456
- url = "https://api.openai.com/v1/embeddings"
457
- headers = {
458
- "Content-Type": "application/json",
459
- "Authorization": f"Bearer {api_key}"
460
- }
461
- data = {
462
- "model": "text-embedding-ada-002",
463
- "input": text
464
- }
465
-
466
- try:
467
- response = requests.post(url, headers=headers, data=json.dumps(data))
468
- if response.status_code == 200:
469
- result = response.json()
470
- print("Successfully got embedding via HTTP")
471
- return result["data"][0]["embedding"]
472
- else:
473
- print(f"HTTP request failed with status {response.status_code}: {response.text}")
474
- return None
475
- except Exception as e:
476
- print(f"HTTP request exception: {str(e)}")
477
- return None
478
 
479
  @st.cache_resource
480
  def setup_qdrant_client():
@@ -601,7 +496,7 @@ def setup_retriever():
601
  return QdrantRetriever()
602
 
603
  def rag_chain_node(query, run_manager):
604
- """A LangGraph node for retrieval augmented generation. Returns a string."""
605
  print("Starting rag_chain_node...")
606
  # Log the query
607
  print(f"Query: {query}")
@@ -618,10 +513,25 @@ def rag_chain_node(query, run_manager):
618
  print(f"Retrieved {len(relevant_docs)} documents")
619
 
620
  # Print document sources for debugging
 
621
  for i, doc in enumerate(relevant_docs):
622
  source = doc.metadata.get("source", "Unknown")
623
  page = doc.metadata.get("page", "Unknown")
624
  print(f"Document {i+1} source: {source}, Page: {page}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
625
 
626
  # Format documents to include in the prompt
627
  formatted_docs = "\n\n".join([f"Document from {doc.metadata.get('source', 'Unknown')}, Page {doc.metadata.get('page', 'Unknown')}:\n{doc.page_content}" for doc in relevant_docs])
@@ -645,7 +555,7 @@ Answer:"""
645
  # Generate response
646
  response = chat_model.invoke(rag_prompt)
647
  print("Successfully generated response")
648
- return response.content
649
 
650
  def evaluate_response(query, response):
651
  """
@@ -944,41 +854,92 @@ def execute_agent(agent, query):
944
 
945
  # Streamlit UI
946
  st.set_page_config(
947
- page_title="AB Testing RAG Agent",
948
  page_icon="📊",
949
  layout="wide"
950
  )
951
 
952
  def main():
953
  """Main function for the Streamlit app."""
954
- st.title("A/B Testing RAG Assistant")
955
- st.write("Ask me about A/B testing concepts in the pdfs!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
956
 
957
- # Add a text input for the query
958
- query = st.text_input("Ask a question:")
959
 
960
- # Process the query when submitted
961
- if st.button("Submit") or query:
962
- if query:
963
- with st.spinner("Thinking..."):
 
 
 
 
 
 
 
 
 
964
  try:
965
- # First try the RAG approach
966
  st.write("Starting with Initial RAG...")
967
- response = rag_chain_node(query, None)
 
 
 
968
 
969
- # Display the initial response
970
- st.write(response)
971
 
972
- # For debugging only
973
- print("Initial response complete")
 
 
 
 
974
 
 
975
  except Exception as e:
976
- st.error(f"An error occurred during document retrieval or response generation: {str(e)}")
977
- print(f"Error in main function: {str(e)}")
978
- import traceback
979
- traceback.print_exc()
980
- else:
981
- st.write("Please enter a question.")
 
 
 
 
 
 
 
982
 
983
  if __name__ == "__main__":
984
- main()
 
 
204
  """Get the chat model for initial RAG."""
205
  print("Initializing chat model...")
206
  try:
207
+ # Set API key from environment
208
  openai_api_key = os.environ.get("OPENAI_API_KEY", "")
209
  if not openai_api_key:
210
  print("WARNING: OPENAI_API_KEY environment variable not set!")
211
  raise ValueError("OpenAI API key not found")
212
+
213
+ # Create a wrapper class with a shorter timeout to fail faster on DNS issues
214
+ class TimeoutChatModel:
215
+ def __init__(self, api_key):
216
+ self.api_key = api_key
217
+ self.timeout = 5 # Short timeout to fail fast on DNS issues
218
 
219
+ def invoke(self, messages):
220
+ print("Invoking chat model...")
221
+ try:
222
+ # Convert string input to message format if needed
223
+ if isinstance(messages, str):
224
+ openai_messages = [{"role": "user", "content": messages}]
225
+ else:
226
+ # Convert LangChain messages to OpenAI format
227
+ openai_messages = []
228
+ for msg in messages:
229
+ role = "user"
230
+ if hasattr(msg, "type"):
231
+ role = "assistant" if msg.type == "ai" else "user"
232
+ openai_messages.append({
233
+ "role": role,
234
+ "content": msg.content
235
+ })
 
 
 
 
 
 
 
 
 
 
236
 
237
+ # Direct API call with timeout
238
+ import requests
239
+ import json
240
+
241
+ url = "https://api.openai.com/v1/chat/completions"
242
+ headers = {
243
+ "Content-Type": "application/json",
244
+ "Authorization": f"Bearer {self.api_key}"
245
+ }
246
+ data = {
247
+ "model": "gpt-3.5-turbo",
248
+ "messages": openai_messages
249
+ }
250
+
251
+ response = requests.post(
252
+ url,
253
+ headers=headers,
254
+ data=json.dumps(data),
255
+ timeout=self.timeout
256
+ )
257
+
258
+ if response.status_code == 200:
259
+ result = response.json()
260
+ content = result["choices"][0]["message"]["content"]
261
+ print(f"Got response of length: {len(content)}")
262
  return type('obj', (object,), {'content': content})
263
+ else:
264
+ print(f"API request failed with status {response.status_code}")
265
+ raise Exception(f"API request failed: {response.text}")
266
+ except requests.exceptions.Timeout:
267
+ print("Timeout connecting to OpenAI API")
268
+ raise Exception("Timeout connecting to OpenAI API")
269
+ except requests.exceptions.ConnectionError as e:
270
+ print(f"Connection error to OpenAI API: {str(e)}")
271
+ raise Exception(f"Connection error: {str(e)}")
272
+ except Exception as e:
273
+ print(f"Error in chat model: {str(e)}")
274
+ raise
275
 
276
+ return TimeoutChatModel(openai_api_key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  except Exception as e:
278
+ print(f"Error initializing chat model: {str(e)}")
279
+ # Create dummy for testing
280
+ class DummyModel:
281
+ def invoke(self, messages):
282
+ print("WARNING: Using dummy model!")
283
+ return type('obj', (object,), {'content': 'I apologize, but I cannot access the necessary data to answer this question due to API connectivity issues.'})
284
+
285
+ return DummyModel()
286
 
287
  @st.cache_resource
288
  def get_agent_model():
 
296
  """Get the embedding model."""
297
  print("Initializing embedding model...")
298
  try:
299
+ # Set API key from environment
300
  openai_api_key = os.environ.get("OPENAI_API_KEY", "")
301
  if not openai_api_key:
302
  print("WARNING: OPENAI_API_KEY environment variable not set!")
303
  raise ValueError("OpenAI API key not found")
304
+
305
+ # Create a wrapper class with a shorter timeout to fail faster on DNS issues
306
+ class TimeoutEmbeddings:
307
+ def __init__(self, api_key):
308
+ self.api_key = api_key
309
+ self.timeout = 5 # Short timeout to fail fast on DNS issues
310
 
311
+ def embed_query(self, text):
312
+ print(f"Embedding query of length: {len(text)}")
313
+ try:
314
+ # Direct API call with timeout
315
+ import requests
316
+ import json
317
+
318
+ url = "https://api.openai.com/v1/embeddings"
319
+ headers = {
320
+ "Content-Type": "application/json",
321
+ "Authorization": f"Bearer {self.api_key}"
322
+ }
323
+ data = {
324
+ "model": "text-embedding-ada-002",
325
+ "input": text
326
+ }
327
+
328
+ response = requests.post(
329
+ url,
330
+ headers=headers,
331
+ data=json.dumps(data),
332
+ timeout=self.timeout
333
+ )
334
+
335
+ if response.status_code == 200:
336
+ result = response.json()
337
  print("Successfully got embedding")
338
+ return result["data"][0]["embedding"]
339
+ else:
340
+ print(f"API request failed with status {response.status_code}")
341
+ raise Exception(f"API request failed: {response.text}")
342
+ except requests.exceptions.Timeout:
343
+ print("Timeout connecting to OpenAI API - using dummy embedding")
344
+ return [0.0] * 1536
345
+ except requests.exceptions.ConnectionError:
346
+ print("Connection error to OpenAI API - using dummy embedding")
347
+ return [0.0] * 1536
348
+ except Exception as e:
349
+ print(f"Error getting embeddings: {str(e)}")
350
+ return [0.0] * 1536
 
 
 
 
 
 
351
 
352
+ def embed_documents(self, texts):
353
+ print(f"Embedding {len(texts)} documents")
354
+ results = []
355
+ for i, text in enumerate(texts):
356
+ results.append(self.embed_query(text))
357
+ return results
358
+
359
+ return TimeoutEmbeddings(openai_api_key)
360
  except Exception as e:
361
+ print(f"Error initializing embedding model: {str(e)}")
362
 
363
+ # Create dummy for testing
364
+ class DummyEmbeddings:
365
+ def embed_query(self, text):
366
+ print("WARNING: Using dummy embeddings!")
367
+ return [0.0] * 1536
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
 
369
+ def embed_documents(self, texts):
370
+ return [[0.0] * 1536 for _ in range(len(texts))]
371
+
372
+ return DummyEmbeddings()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
  @st.cache_resource
375
  def setup_qdrant_client():
 
496
  return QdrantRetriever()
497
 
498
  def rag_chain_node(query, run_manager):
499
+ """A LangGraph node for retrieval augmented generation. Returns a string and sources."""
500
  print("Starting rag_chain_node...")
501
  # Log the query
502
  print(f"Query: {query}")
 
513
  print(f"Retrieved {len(relevant_docs)} documents")
514
 
515
  # Print document sources for debugging
516
+ sources = []
517
  for i, doc in enumerate(relevant_docs):
518
  source = doc.metadata.get("source", "Unknown")
519
  page = doc.metadata.get("page", "Unknown")
520
  print(f"Document {i+1} source: {source}, Page: {page}")
521
+
522
+ # Extract source information for display
523
+ source_path = source
524
+ filename = source_path.split("/")[-1] if "/" in source_path else source_path
525
+
526
+ # Remove .pdf extension if present
527
+ if filename.lower().endswith('.pdf'):
528
+ filename = filename[:-4]
529
+
530
+ sources.append({
531
+ "title": f"Ron Kohavi: {filename}",
532
+ "page": page,
533
+ "type": "pdf"
534
+ })
535
 
536
  # Format documents to include in the prompt
537
  formatted_docs = "\n\n".join([f"Document from {doc.metadata.get('source', 'Unknown')}, Page {doc.metadata.get('page', 'Unknown')}:\n{doc.page_content}" for doc in relevant_docs])
 
555
  # Generate response
556
  response = chat_model.invoke(rag_prompt)
557
  print("Successfully generated response")
558
+ return response.content, sources
559
 
560
  def evaluate_response(query, response):
561
  """
 
854
 
855
  # Streamlit UI
856
  st.set_page_config(
857
+ page_title="📊 AB Testing RAG Agent",
858
  page_icon="📊",
859
  layout="wide"
860
  )
861
 
862
  def main():
863
  """Main function for the Streamlit app."""
864
+ st.title("📊 AB Testing RAG Agent")
865
+ st.markdown("""
866
+ This specialized agent can answer questions about A/B Testing using a collection of Ron Kohavi's work. If it can't fully answer your A/B Testing questions using this collection, it will then automatically search Arxiv. Let's begin!
867
+ """)
868
+
869
+ # Initialize chat history
870
+ if "messages" not in st.session_state:
871
+ st.session_state.messages = []
872
+
873
+ # Display chat history
874
+ for message in st.session_state.messages:
875
+ with st.chat_message(message["role"]):
876
+ st.markdown(message["content"])
877
+
878
+ # Display sources if available
879
+ if "sources" in message and message["sources"]:
880
+ st.markdown("#### Sources")
881
+ for i, source in enumerate(message["sources"]):
882
+ title = source.get("title", "Unknown")
883
+
884
+ # Display differently based on source type
885
+ if source.get("type") == "arxiv":
886
+ authors = source.get("authors", "Unknown authors")
887
+ st.markdown(f"**{i+1}. {title}**\nAuthors: {authors}")
888
+ else:
889
+ # PDF source with page number
890
+ page = source.get("page", "Unknown")
891
+ st.markdown(f"**{i+1}. {title}** (Page: {page})")
892
 
893
+ # Input for new question
894
+ query = st.chat_input("Ask a question about A/B Testing")
895
 
896
+ if query:
897
+ # Add user message to chat history
898
+ st.session_state.messages.append({"role": "user", "content": query})
899
+
900
+ # Display user message
901
+ with st.chat_message("user"):
902
+ st.markdown(query)
903
+
904
+ # Display assistant response
905
+ with st.chat_message("assistant"):
906
+ message_placeholder = st.empty()
907
+
908
+ with st.status("Processing your query...", expanded=True) as status:
909
  try:
910
+ # Use the RAG approach with a timeout
911
  st.write("Starting with Initial RAG...")
912
+ print("Starting RAG process for query:", query)
913
+
914
+ # Step 1: Initial RAG
915
+ response, sources = rag_chain_node(query, None)
916
 
917
+ # Display the processed response
918
+ message_placeholder.markdown(response)
919
 
920
+ # Add assistant message to chat history
921
+ st.session_state.messages.append({
922
+ "role": "assistant",
923
+ "content": response,
924
+ "sources": sources
925
+ })
926
 
927
+ status.update(label="Completed!", state="complete", expanded=False)
928
  except Exception as e:
929
+ error_msg = str(e)
930
+ if "Name or service not known" in error_msg:
931
+ response = "I'm having trouble connecting to the language model API due to network restrictions. The Hugging Face environment may be blocking external API calls."
932
+ else:
933
+ response = f"An error occurred: {error_msg}"
934
+
935
+ message_placeholder.markdown(response)
936
+ st.session_state.messages.append({
937
+ "role": "assistant",
938
+ "content": response,
939
+ "sources": []
940
+ })
941
+ status.update(label="Error", state="error", expanded=False)
942
 
943
  if __name__ == "__main__":
944
+ if query:
945
+ main()