File size: 37,639 Bytes
9fa16f1
 
 
 
ece2d3a
9fa16f1
 
 
4ddf811
9fa16f1
 
 
 
 
 
 
 
 
0aec276
b70309f
0aec276
2d27322
 
9fa16f1
 
 
 
 
 
 
 
 
 
 
 
525d5c5
 
 
 
db89085
 
 
 
 
 
 
 
ece2d3a
 
 
 
 
525d5c5
ece2d3a
 
 
 
 
db89085
ece2d3a
 
 
 
db89085
ece2d3a
 
 
06d2b05
 
 
 
 
ece2d3a
06d2b05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ece2d3a
06d2b05
 
 
 
 
 
 
 
 
 
 
 
 
 
ece2d3a
 
 
06d2b05
 
ece2d3a
 
 
 
db89085
ece2d3a
 
9fa16f1
ece2d3a
 
525d5c5
 
 
 
 
4ddf811
 
9fa16f1
 
 
 
 
 
4ddf811
9fa16f1
 
 
 
 
 
 
 
 
4ddf811
 
9fa16f1
4ddf811
 
9fa16f1
4ddf811
 
9fa16f1
 
4ddf811
 
 
 
9fa16f1
 
 
525d5c5
b70309f
 
 
 
 
4ddf811
b70309f
52657d6
 
 
b70309f
525d5c5
 
 
b70309f
52657d6
 
525d5c5
 
b70309f
9fa16f1
 
 
 
525d5c5
b70309f
ac15e0f
25ae1f7
5f13732
 
 
ac15e0f
 
 
 
 
 
5f13732
ac15e0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f13732
ac15e0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f13732
ac15e0f
 
 
 
 
 
 
 
 
 
 
 
25ae1f7
ac15e0f
25ae1f7
ac15e0f
 
 
 
 
 
 
 
9fa16f1
 
 
 
525d5c5
25ae1f7
 
9fa16f1
 
 
 
525d5c5
b70309f
ac15e0f
25ae1f7
5f13732
 
 
ac15e0f
 
 
 
 
 
5f13732
ac15e0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f13732
ac15e0f
 
 
 
 
 
 
 
 
 
 
 
 
5f13732
ac15e0f
 
 
 
 
 
 
 
5f13732
ac15e0f
25ae1f7
ac15e0f
 
 
 
 
5f13732
ac15e0f
 
 
 
9fa16f1
 
 
 
525d5c5
b70309f
 
 
 
52657d6
47892f0
525d5c5
52657d6
525d5c5
2d27322
 
b70309f
 
525d5c5
b70309f
525d5c5
2d27322
 
 
 
 
 
525d5c5
2d27322
 
 
 
 
 
525d5c5
 
2d27322
52657d6
47892f0
52657d6
525d5c5
 
2d27322
 
4f7e4ee
525d5c5
2d27322
 
 
4f7e4ee
2d27322
525d5c5
 
 
 
 
 
 
 
 
 
 
 
9fa16f1
5f13732
 
 
 
 
 
 
45f16a5
5f13732
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525d5c5
 
5f13732
 
525d5c5
 
 
 
5f13732
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ddf811
5f13732
 
 
ac15e0f
5f13732
 
 
4ddf811
5f13732
 
 
4ddf811
5f13732
 
4ddf811
5f13732
 
 
4ddf811
5f13732
ac15e0f
5f13732
 
 
 
ac15e0f
 
 
 
 
 
 
 
 
 
 
 
 
 
5f13732
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac15e0f
4ddf811
 
 
 
 
 
 
9fa16f1
4ddf811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fa16f1
 
4ddf811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fa16f1
 
4ddf811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fa16f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d27322
 
 
9fa16f1
 
 
 
 
4ddf811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d27322
 
4ddf811
 
2d27322
4ddf811
2d27322
4ddf811
 
 
 
 
 
 
 
 
 
 
 
 
2d27322
4ddf811
2d27322
 
4ddf811
9fa16f1
 
 
ac15e0f
9fa16f1
 
 
 
5f13732
 
ac15e0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fa16f1
ac15e0f
 
9fa16f1
ac15e0f
 
 
 
 
 
 
 
 
 
 
 
 
5f13732
ac15e0f
5f13732
ac15e0f
 
 
 
4ddf811
ac15e0f
 
5f13732
ac15e0f
 
 
 
 
 
5f13732
ac15e0f
5f13732
ac15e0f
 
 
 
 
 
 
 
 
 
 
 
 
5f13732
 
ac15e0f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
import os
import pickle
import streamlit as st
from pathlib import Path
import tarfile
from dotenv import load_dotenv
from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from qdrant_client import QdrantClient
from langchain_core.documents import Document
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_core.tools import tool
from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
from langchain_core.messages import AIMessage, HumanMessage
import requests
import json
from langchain_core.output_parsers import StrOutputParser
from openai import OpenAI
from qdrant_client.http.models import PointStruct

# Don't set proxy environment variables - they seem to cause issues in Hugging Face
# Instead, we'll handle this at the client level

# Global variable to store ArXiv sources.
# NOTE(review): not read or written by any code visible in this section;
# presumably used further down the file — verify before removing.
ARXIV_SOURCES = []

# Load environment variables from a .env file (if one exists) into os.environ.
# The message below is printed unconditionally: load_dotenv()'s return value
# (whether a .env file was actually found) is not checked.
load_dotenv()
print("Loaded .env file")

# Configure OpenAI API key: when OPENAI_API_KEY is unset OR empty, fall back to
# OPENAI_API_KEY_BACKUP, or to "" if that is missing too (the model factories
# treat an empty key as "not configured").
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY_BACKUP", "")

# Debugging: print the working directory and its contents, since every data
# path defined below is relative to it.
print(f"Current directory: {os.getcwd()}")
print(f"Directory contents: {os.listdir('.')}")

# Check for the Hugging Face Spaces data directory (where files uploaded
# through the Spaces UI are expected to live). This only logs whether it exists
# and what it contains; nothing in this section reads data from it.
HF_SPACES_DIR = "/data"
if os.path.exists(HF_SPACES_DIR):
    print(f"Found Hugging Face Spaces data directory at {HF_SPACES_DIR}")
    print(f"Contents: {os.listdir(HF_SPACES_DIR)}")
else:
    print(f"No Hugging Face Spaces data directory found at {HF_SPACES_DIR}")

# Paths to pre-processed data and package, all relative to the current working
# directory. PACKAGE_FILE is a gzipped tarball that extract_packaged_data()
# unpacks into PROCESSED_DATA_DIR (chunks pickle + local Qdrant storage).
PROCESSED_DATA_DIR = Path("processed_data")
CHUNKS_FILE = PROCESSED_DATA_DIR / "document_chunks.pkl"
QDRANT_DIR = PROCESSED_DATA_DIR / "qdrant_vectorstore"
PACKAGE_FILE = "processed_data.tar.gz"

# Extract packaged data if available
def extract_packaged_data():
    """Extract the packaged data archive (PACKAGE_FILE) into PROCESSED_DATA_DIR.

    Archive members are mapped to on-disk locations as follows:

    * a member whose basename is ``document_chunks.pkl`` goes to CHUNKS_FILE;
    * a member whose path contains ``qdrant_vectorstore`` keeps its
      sub-directory structure (minus a leading ``processed_data/``) under
      PROCESSED_DATA_DIR;
    * any other file is written directly into PROCESSED_DATA_DIR under its
      basename.

    Directory members are skipped, as are members whose mapped target would
    land outside PROCESSED_DATA_DIR (``../`` components or absolute paths in
    the archive).

    Returns:
        True if the archive was found and processed; False if it does not
        exist or extraction raised (the error is printed, not propagated).
    """
    import shutil
    import traceback

    if not os.path.exists(PACKAGE_FILE):
        print(f"No packaged data found: {PACKAGE_FILE}")
        return False

    print(f"Found packaged data: {PACKAGE_FILE}")

    # Create processed_data directory if it doesn't exist
    if not os.path.exists(PROCESSED_DATA_DIR):
        os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
        print(f"Created directory: {PROCESSED_DATA_DIR}")

    try:
        # Resolved once so every extraction target can be checked against it.
        base_dir = Path(PROCESSED_DATA_DIR).resolve()

        with tarfile.open(PACKAGE_FILE, "r:gz") as tar:
            members = tar.getmembers()

            print("Examining tar file contents before extraction:")
            for member in members:
                print(f"  File in archive: {member.name}")

            print("Extracting package...")
            for member in members:
                if member.isdir():
                    continue

                basename = os.path.basename(member.name)
                if basename == "document_chunks.pkl":
                    target_path = CHUNKS_FILE
                elif "qdrant_vectorstore" in member.name:
                    # Keep the Qdrant storage layout, but drop a leading
                    # 'processed_data/' so it is not nested twice.
                    relative_path = member.name
                    if relative_path.startswith("processed_data/"):
                        relative_path = relative_path[len("processed_data/"):]
                    target_path = PROCESSED_DATA_DIR / relative_path
                else:
                    # Other files go directly in processed_data
                    target_path = PROCESSED_DATA_DIR / basename

                # Member names are untrusted: "../" components or an absolute
                # name (Path / "/abs" discards the left side) would otherwise
                # let the Qdrant branch write outside PROCESSED_DATA_DIR.
                if base_dir not in Path(target_path).resolve().parents:
                    print(f"  Skipping {member.name}: target {target_path} is outside {PROCESSED_DATA_DIR}")
                    continue

                os.makedirs(os.path.dirname(target_path), exist_ok=True)

                print(f"  Extracting {member.name} to {target_path}")
                source = tar.extractfile(member)
                if source is None:
                    # Not a regular file (or an unresolvable link): nothing to copy.
                    continue
                # Close the member stream and copy in chunks rather than
                # reading the whole member into memory.
                with source, open(target_path, "wb") as out_file:
                    shutil.copyfileobj(source, out_file)

            print("Extraction complete")

            # Verify extraction worked
            print("Checking extracted files:")
            if os.path.exists(CHUNKS_FILE):
                print(f"  {CHUNKS_FILE} exists: ✓")
            else:
                print(f"  {CHUNKS_FILE} exists: ✗")

            if os.path.exists(QDRANT_DIR):
                print(f"  {QDRANT_DIR} exists: ✓")
                print(f"  Contents: {os.listdir(QDRANT_DIR)}")
            else:
                print(f"  {QDRANT_DIR} exists: ✗")

            return True
    except Exception as e:
        # Best-effort at startup: report and let the loaders cope with
        # missing files instead of crashing the app.
        print(f"Error extracting package: {str(e)}")
        traceback.print_exc()
        return False

# Extract packaged data on startup (runs at import time). The True/False result
# is ignored: startup continues either way, and the loaders below handle
# missing files (empty chunk list / fallback Qdrant client).
extract_packaged_data()

# Check if processed data exists — diagnostic logging only.
print(f"Checking for processed data...")
print(f"CHUNKS_FILE exists: {os.path.exists(CHUNKS_FILE)}")
print(f"QDRANT_DIR exists: {os.path.exists(QDRANT_DIR)}")
if os.path.exists(QDRANT_DIR):
    print(f"QDRANT_DIR contents: {os.listdir(QDRANT_DIR)}")

# Define prompts exactly as in the notebook
#
# RAG_PROMPT: answer {question} strictly from {context}, replying
# "I don't know" when the context is insufficient.
# NOTE(review): in the code visible here, rag_chain_node builds its own
# f-string prompt and does not use RAG_PROMPT / rag_prompt.
RAG_PROMPT = """
CONTEXT:
{context}

QUERY:
{question}

You are a helpful assistant. Use the available context to answer the question. Do not use your own knowledge! If you cannot answer the question based on the context, you must say "I don't know".
"""

# REPHRASE_QUERY_PROMPT: asks the model to rewrite {question} into a more
# specific A/B-testing query before retrieval (used by
# retrieve_information_with_rephrased_query).
REPHRASE_QUERY_PROMPT = """
QUERY:
{question}

You are a helpful assistant. Rephrase the provided query to be more specific and to the point in order to improve retrieval in our RAG pipeline about AB Testing.
"""

# EVALUATE_RESPONSE_PROMPT: asks for a one-letter verdict. 'Y' means either the
# query is unrelated to A/B testing, or it is related and the response is
# extremely helpful; 'N' means a related query got a response that is not
# extremely helpful. evaluate_response treats a 'Y' verdict as "the initial
# answer is sufficient".
EVALUATE_RESPONSE_PROMPT = """
Given an initial query, determine if the initial query is related to AB Testing (even vaguely e.g. statistics, A/B testing, etc.) or not. If not related to AB Testing, return 'Y'. If related to AB Testing, then given the initial query and a final response, determine if the final response is extremely helpful or not. If extremely helpful, return 'Y'. If not extremely helpful, return 'N'.

Initial Query:
{initial_query}

Final Response:
{final_response}
"""

# Template objects built from the strings above. evaluate_prompt is a plain
# PromptTemplate; its .format() output is wrapped in a HumanMessage by
# evaluate_response.
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
rephrase_query_prompt = ChatPromptTemplate.from_template(REPHRASE_QUERY_PROMPT)
evaluate_prompt = PromptTemplate.from_template(EVALUATE_RESPONSE_PROMPT)

@st.cache_resource
def load_document_chunks():
    """Load the pickled list of pre-processed document chunks from CHUNKS_FILE.

    Returns:
        The unpickled chunk list, or ``[]`` if the file is missing or cannot
        be read/unpickled (errors are printed with a traceback, not raised).

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file. That is only acceptable because CHUNKS_FILE comes from the app's
    own bundled archive — never point it at user-supplied data.
    """
    print(f"Attempting to load document chunks from {CHUNKS_FILE}")

    if os.path.exists(CHUNKS_FILE):
        try:
            with open(CHUNKS_FILE, 'rb') as handle:
                loaded = pickle.load(handle)
                print(f"Successfully loaded {len(loaded)} document chunks")
                # Show the first chunk's metadata as a quick sanity check.
                if loaded:
                    print(f"First chunk metadata: {loaded[0].metadata}")
                return loaded
        except Exception as err:
            print(f"Error loading document chunks: {str(err)}")
            import traceback
            traceback.print_exc()
            return []

    # Missing file: log enough directory context to diagnose the deployment.
    print(f"WARNING: Chunks file not found at {CHUNKS_FILE}")
    print(f"Working directory contents: {os.listdir('.')}")
    if os.path.exists(PROCESSED_DATA_DIR):
        print(f"PROCESSED_DATA_DIR contents: {os.listdir(PROCESSED_DATA_DIR)}")
    return []

@st.cache_resource
def get_chat_model():
    """Return the chat model used for the initial RAG answer.

    The returned object exposes ``invoke(messages)``. ``messages`` is either
    a plain string (sent as a single user message) or a sequence of
    LangChain-style message objects (their ``type`` and ``content``
    attributes are read). ``invoke`` returns an object whose ``content``
    attribute holds the reply text, and raises ``Exception`` on non-200
    responses, timeouts and connection failures.

    Requests go directly to the OpenAI chat-completions REST endpoint
    (``gpt-3.5-turbo``). If ``OPENAI_API_KEY`` is unset or empty, a dummy
    model is returned whose ``invoke`` always answers with a fixed apology.
    The first result is cached by ``st.cache_resource``.
    """
    from types import SimpleNamespace

    print("Initializing chat model...")
    try:
        openai_api_key = os.environ.get("OPENAI_API_KEY", "")
        if not openai_api_key:
            print("WARNING: OPENAI_API_KEY environment variable not set!")
            raise ValueError("OpenAI API key not found")

        class TimeoutChatModel:
            """Thin OpenAI chat-completions client with explicit timeouts."""

            # LangChain message ``type`` -> OpenAI chat role. Unknown types and
            # objects without a ``type`` attribute are sent as "user".
            _ROLES = {"ai": "assistant", "system": "system"}

            def __init__(self, api_key):
                self.api_key = api_key
                # Connect timeout: short, to fail fast on DNS/network issues.
                self.timeout = 5
                # Read timeout: a non-streamed completion sends nothing until
                # generation finishes, so reusing the 5 s limit here aborted
                # any answer that took longer than that to generate.
                self.read_timeout = 60

            def _to_openai_messages(self, messages):
                """Convert a string or LangChain messages to OpenAI message dicts."""
                if isinstance(messages, str):
                    return [{"role": "user", "content": messages}]
                return [
                    {
                        "role": self._ROLES.get(getattr(msg, "type", None), "user"),
                        "content": msg.content,
                    }
                    for msg in messages
                ]

            def invoke(self, messages):
                print("Invoking chat model...")
                try:
                    response = requests.post(
                        "https://api.openai.com/v1/chat/completions",
                        headers={
                            "Content-Type": "application/json",
                            "Authorization": f"Bearer {self.api_key}",
                        },
                        data=json.dumps({
                            "model": "gpt-3.5-turbo",
                            "messages": self._to_openai_messages(messages),
                        }),
                        timeout=(self.timeout, self.read_timeout),
                    )

                    if response.status_code != 200:
                        print(f"API request failed with status {response.status_code}")
                        raise Exception(f"API request failed: {response.text}")

                    content = response.json()["choices"][0]["message"]["content"]
                    print(f"Got response of length: {len(content)}")
                    return SimpleNamespace(content=content)
                except requests.exceptions.Timeout as e:
                    print("Timeout connecting to OpenAI API")
                    raise Exception("Timeout connecting to OpenAI API") from e
                except requests.exceptions.ConnectionError as e:
                    print(f"Connection error to OpenAI API: {str(e)}")
                    raise Exception(f"Connection error: {str(e)}") from e
                except Exception as e:
                    print(f"Error in chat model: {str(e)}")
                    raise

        return TimeoutChatModel(openai_api_key)
    except Exception as e:
        print(f"Error initializing chat model: {str(e)}")

        class DummyModel:
            """Offline stand-in that always returns a fixed apology."""

            def invoke(self, messages):
                print("WARNING: Using dummy model!")
                return SimpleNamespace(content='I apologize, but I cannot access the necessary data to answer this question due to API connectivity issues.')

        return DummyModel()

@st.cache_resource
def get_agent_model():
    """Return the model used by the agent and by response evaluation.

    Despite the separate name, this currently reuses the (cached) chat model
    from ``get_chat_model`` so that both paths share one configuration.
    """
    print("Initializing agent model...")
    shared_model = get_chat_model()
    return shared_model

@st.cache_resource
def get_embedding_model():
    """Return the embedding client exposing ``embed_query`` / ``embed_documents``.

    With ``OPENAI_API_KEY`` set, the client posts to the OpenAI embeddings
    endpoint (``text-embedding-ada-002``) with a 5 second timeout. Any failed
    request (timeout, connection error, non-200 status, malformed body) is
    logged and replaced by a 1536-element zero vector instead of raising.
    Without a key, a dummy client that always returns zero vectors is used.
    """
    print("Initializing embedding model...")

    # Length of the placeholder vectors returned when no real embedding is
    # available.
    fallback_dim = 1536

    class TimeoutEmbeddings:
        def __init__(self, api_key):
            self.api_key = api_key
            self.timeout = 5  # Short timeout to fail fast on DNS issues

        def embed_query(self, text):
            print(f"Embedding query of length: {len(text)}")
            try:
                reply = requests.post(
                    "https://api.openai.com/v1/embeddings",
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {self.api_key}",
                    },
                    data=json.dumps({"model": "text-embedding-ada-002", "input": text}),
                    timeout=self.timeout,
                )
                if reply.status_code != 200:
                    print(f"API request failed with status {reply.status_code}")
                    raise Exception(f"API request failed: {reply.text}")
                body = reply.json()
                print("Successfully got embedding")
                return body["data"][0]["embedding"]
            except requests.exceptions.Timeout:
                print("Timeout connecting to OpenAI API - using dummy embedding")
                return [0.0] * fallback_dim
            except requests.exceptions.ConnectionError:
                print("Connection error to OpenAI API - using dummy embedding")
                return [0.0] * fallback_dim
            except Exception as err:
                print(f"Error getting embeddings: {str(err)}")
                return [0.0] * fallback_dim

        def embed_documents(self, texts):
            print(f"Embedding {len(texts)} documents")
            # One request per text, in order.
            return [self.embed_query(item) for item in texts]

    class DummyEmbeddings:
        def embed_query(self, text):
            print("WARNING: Using dummy embeddings!")
            return [0.0] * fallback_dim

        def embed_documents(self, texts):
            return [[0.0] * fallback_dim for _ in range(len(texts))]

    try:
        api_key = os.environ.get("OPENAI_API_KEY", "")
        if not api_key:
            print("WARNING: OPENAI_API_KEY environment variable not set!")
            raise ValueError("OpenAI API key not found")
        return TimeoutEmbeddings(api_key)
    except Exception as e:
        print(f"Error initializing embedding model: {str(e)}")
        return DummyEmbeddings()

@st.cache_resource
def setup_qdrant_client():
    """Create the Qdrant client for the on-disk vector store at QDRANT_DIR.

    The store is opened in local mode via ``QdrantClient(path=...)``. As a
    diagnostic, the available collections are listed and a warning is printed
    if ``kohavi_ab_testing_pdf_collection`` is missing. Errors from that check
    are logged and do not prevent the client from being returned.

    If the on-disk client cannot be created, an empty in-memory client is
    returned so the app can still start (searches against it will not find
    the collection). If even that fails, the exception propagates.
    """
    import traceback

    collection_name = "kohavi_ab_testing_pdf_collection"

    print(f"Attempting to setup Qdrant client with path: {QDRANT_DIR}")
    if not os.path.exists(QDRANT_DIR):
        print(f"WARNING: Qdrant directory not found: {QDRANT_DIR}")
        print(f"Contents of {PROCESSED_DATA_DIR}: {os.listdir(PROCESSED_DATA_DIR) if os.path.exists(PROCESSED_DATA_DIR) else 'Not found'}")

    try:
        print("Trying to create QdrantClient with path parameter")
        client = QdrantClient(path=str(QDRANT_DIR))
        print("Successfully created Qdrant client with path parameter")
    except Exception as e:
        print(f"Error creating QdrantClient with path: {str(e)}")
        traceback.print_exc()

        # Do not retry with ``location=str(QDRANT_DIR)``. QdrantClient treats
        # any ``location`` other than ":memory:" as a server URL, so that retry
        # "succeeded" with a remote client pointing at a bogus host instead of
        # re-opening the local store, and it skipped the fallback below.
        try:
            print("FALLBACK: Creating in-memory QdrantClient")
            client = QdrantClient(":memory:")
            print("Created in-memory QdrantClient as fallback")
            return client
        except Exception as e3:
            print(f"Even in-memory Qdrant failed: {str(e3)}")
            traceback.print_exc()
            raise

    # Sanity-check the store; failures here are informational only.
    try:
        print("Trying to get collections from Qdrant")
        collections = client.get_collections()
        print(f"Available collections: {collections.collections}")

        if any(c.name == collection_name for c in collections.collections):
            print(f"Found our collection: {collection_name}")
        else:
            print(f"WARNING: Collection '{collection_name}' not found!")
    except Exception as e:
        print(f"Warning: Could not get collections: {str(e)}")
        traceback.print_exc()

    return client

def setup_retriever():
    """Build a retriever over the ``kohavi_ab_testing_pdf_collection`` Qdrant collection.

    Returns:
        An object with ``get_relevant_documents(query)``. That method embeds
        ``query``, fetches the 5 nearest Qdrant points, and maps each point id
        to the chunk at that index in ``load_document_chunks()``. Points whose
        id has no matching chunk are logged and skipped. If both Qdrant query
        APIs fail, the exception propagates.
    """
    print("Setting up retriever...")
    client = setup_qdrant_client()
    collection_name = "kohavi_ab_testing_pdf_collection"
    embedding_model = get_embedding_model()

    chunks = load_document_chunks()
    print(f"Loaded {len(chunks)} document chunks")

    # Point ids are looked up as list positions in `chunks` (presumably how the
    # collection was indexed — verify against the indexing script). Built once
    # here instead of on every query.
    docs_by_id = dict(enumerate(chunks))

    class QdrantRetriever:
        """Retriever exposing ``get_relevant_documents`` for the RAG node."""

        def _search(self, query_embedding, limit=5):
            """Return the nearest points, preferring the current Qdrant query API."""
            try:
                # query_points takes the vector as ``query`` (``query_vector``
                # is not one of its parameters) and returns a QueryResponse
                # whose hits are in ``.points``.
                results = client.query_points(
                    collection_name=collection_name,
                    query=query_embedding,
                    limit=limit,
                ).points
                print(f"Found {len(results)} results using query_points")
            except Exception as e:
                print(f"query_points failed: {str(e)}")
                # Older clients: the legacy search API returns the hits directly.
                results = client.search(
                    collection_name=collection_name,
                    query_vector=query_embedding,
                    limit=limit,
                )
                print(f"Found {len(results)} results using search")
            return results

        def get_relevant_documents(self, query):
            print(f"Retrieving documents for: {query}")

            query_embedding = embedding_model.embed_query(query)
            print("Generated query embedding")

            print(f"Searching Qdrant collection '{collection_name}'...")
            results = self._search(query_embedding)

            docs = []
            for result in results:
                doc_id = result.id
                if doc_id in docs_by_id:
                    docs.append(docs_by_id[doc_id])
                    print(f"Added document {doc_id}")
                else:
                    print(f"Document ID {doc_id} not found in chunks")

            print(f"Returning {len(docs)} documents from Qdrant")
            return docs

    return QdrantRetriever()

def rag_chain_node(query, run_manager):
    """Answer ``query`` from retrieved documents and report where they came from.

    Used as a LangGraph-style node. ``run_manager`` is accepted for interface
    compatibility but is not used.

    Returns:
        ``(answer_text, sources)`` where ``sources`` has one dict per retrieved
        document with keys ``title`` ("Ron Kohavi: <file name without .pdf>"),
        ``page`` and ``type`` ("pdf").
    """
    print("Starting rag_chain_node...")
    print(f"Query: {query}")

    # Fetched in this order so initialization logs appear as before.
    chat_model = get_chat_model()
    retriever = setup_retriever()

    print("Retrieving documents...")
    relevant_docs = retriever.get_relevant_documents(query)
    print(f"Retrieved {len(relevant_docs)} documents")

    sources = []
    for position, doc in enumerate(relevant_docs, start=1):
        source = doc.metadata.get("source", "Unknown")
        page = doc.metadata.get("page", "Unknown")
        print(f"Document {position} source: {source}, Page: {page}")

        # Display name: last "/"-separated path component, minus a
        # case-insensitive ".pdf" suffix.
        display_name = source
        if "/" in display_name:
            display_name = display_name.rpartition("/")[2]
        if display_name.lower().endswith('.pdf'):
            display_name = display_name[:-4]

        sources.append({"title": f"Ron Kohavi: {display_name}", "page": page, "type": "pdf"})

    context_blocks = []
    for doc in relevant_docs:
        context_blocks.append(
            f"Document from {doc.metadata.get('source', 'Unknown')}, Page {doc.metadata.get('page', 'Unknown')}:\n{doc.page_content}"
        )
    formatted_docs = "\n\n".join(context_blocks)

    # Named prompt_text so it does not shadow the module-level `rag_prompt`.
    prompt_text = f"""You are an AI assistant specializing in A/B testing and online experimentation.
Answer the following question using only the information provided in the documents below.
If you don't know the answer or the documents don't contain the relevant information, just say so.
Do not make up information or draw from knowledge outside of these documents.

Documents:
{formatted_docs}

Question: {query}

Answer:"""

    print("Generating response...")
    answer = chat_model.invoke(prompt_text)
    print("Successfully generated response")
    return answer.content, sources

def evaluate_response(query, response):
    """Decide whether the initial RAG answer is good enough to return as-is.

    Sends ``evaluate_prompt`` (filled with ``query`` and ``response``) to the
    agent model, which is asked to reply 'Y' (the query is unrelated to A/B
    testing, or the answer is extremely helpful) or 'N'.

    Args:
        query: The user's original question.
        response: The initial RAG answer text.

    Returns:
        True if the model's verdict is 'Y', False otherwise.
    """
    import re

    print(f"Evaluating response for '{query}'")
    agent_model = get_agent_model()

    formatted_prompt = evaluate_prompt.format(
        initial_query=query,
        final_response=response,
    )
    verdict_text = agent_model.invoke([HumanMessage(content=formatted_prompt)]).content

    # Parse the verdict rather than testing `"Y" in verdict_text`: a substring
    # test counts any reply containing a capital Y anywhere (e.g. "N - Your
    # answer lacks detail") as a pass. Use the first standalone Y/N token; if
    # there is none, fall back to whether the reply starts with "y" ("Yes").
    token = re.search(r"\b([YN])\b", verdict_text)
    if token:
        is_sufficient = token.group(1) == "Y"
    else:
        is_sufficient = verdict_text.strip().lstrip("'\"`*").upper().startswith("Y")

    if is_sufficient:
        print("Evaluation: Initial response is sufficient")
        return True
    print("Evaluation: Initial response is NOT sufficient, need to use agent")
    return False

@tool
def retrieve_information(query: str) -> str:
    """Use Retrieval Augmented Generation to retrieve information about AB Testing."""
    # The docstring above is the tool description @tool exposes to the agent's
    # LLM, so it is kept verbatim; implementation notes live in comments.
    #
    # Flow: embed `query`, fetch the 5 nearest points from Qdrant, map point
    # ids to chunks by list position, record display sources on
    # `retrieve_information.last_sources`, and return the chunk texts as
    # numbered "Retrieved Information: <n>" sections. Returns the string
    # "Error retrieving documents." if neither Qdrant query API works.
    client = setup_qdrant_client()
    collection_name = "kohavi_ab_testing_pdf_collection"

    embedding_model = get_embedding_model()
    query_embedding = embedding_model.embed_query(query)

    chunks = load_document_chunks()
    # Point ids are looked up as list positions in `chunks`.
    docs_by_id = dict(enumerate(chunks))

    try:
        # query_points takes the vector as `query` and returns a QueryResponse
        # whose hits are in `.points` (passing `query_vector=` and iterating
        # the response itself could never yield points).
        search_results = client.query_points(
            collection_name=collection_name,
            query=query_embedding,
            limit=5,
        ).points
    except Exception as e:
        print(f"Error in query_points: {str(e)}")
        try:
            # Older clients: the legacy search API returns the hits directly.
            search_results = client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=5,
            )
        except Exception as e2:
            print(f"Error in search: {str(e2)}")
            return "Error retrieving documents."

    # Ids with no matching chunk are skipped.
    docs = [docs_by_id[result.id] for result in search_results if result.id in docs_by_id]

    sources = []
    for doc in docs:
        source_path = doc.metadata.get("source", "")
        filename = source_path.split("/")[-1] if "/" in source_path else source_path

        # Remove .pdf extension if present
        if filename.lower().endswith('.pdf'):
            filename = filename[:-4]

        sources.append({
            "title": f"Ron Kohavi: {filename}",
            "page": doc.metadata.get("page", "unknown"),
            "type": "pdf"
        })

    # `retrieve_information` is the tool object produced by @tool, a
    # pydantic-based model whose __setattr__ rejects undeclared attributes;
    # bypass that so callers can still read `.last_sources` afterwards.
    object.__setattr__(retrieve_information, "last_sources", sources)

    return "\n\n".join(
        f"Retrieved Information: {i + 1}\n{doc.page_content}" for i, doc in enumerate(docs)
    )

@tool
def retrieve_information_with_rephrased_query(query: str) -> str:
    """This tool will intelligently rephrase your AB testing query and then will use Retrieval Augmented Generation to retrieve information about the rephrased query."""
    # NOTE: @tool uses the docstring above as the tool description shown to the
    # agent's LLM, so it is intentionally left unchanged.
    #
    # Pipeline: rephrase `query` with the chat model, embed the rephrased text,
    # fetch the 5 nearest chunks from the Qdrant collection, record a citation
    # dict per chunk on `retrieve_information_with_rephrased_query.last_sources`,
    # and return the chunk texts prefixed with the rephrased query. Search
    # failures are reported as a string rather than raised.

    # 1. Rephrase the query first
    chat_model = get_chat_model()
    rephrased_query_msg = rephrase_query_prompt.format(question=query)
    rephrased_query_response = chat_model.invoke(rephrased_query_msg)
    # Trim stray whitespace/newlines from the model output; if it produced
    # nothing usable, search with the user's own wording instead of "".
    rephrased_query = str(rephrased_query_response.content).strip() or query
    
    # 2. Retrieve documents using the rephrased query
    client = setup_qdrant_client()
    collection_name = "kohavi_ab_testing_pdf_collection"
    
    # Get embedding for the query
    embedding_model = get_embedding_model()
    query_embedding = embedding_model.embed_query(rephrased_query)
    
    # Point ids are looked up by position in the chunk list
    # (assumes points were indexed with their enumerate() index — verify
    # against the ingestion code).
    chunks = load_document_chunks()
    docs_by_id = {i: doc for i, doc in enumerate(chunks)}
    
    # Search for relevant documents
    try:
        search_results = client.search(
            collection_name=collection_name,
            query_vector=query_embedding,
            limit=5
        )
    except Exception as e:
        print(f"Error in search: {str(e)}")
        try:
            # query_points() takes the vector as `query=` (not `query_vector=`)
            # and wraps the hits in a QueryResponse whose `.points` holds them.
            search_results = client.query_points(
                collection_name=collection_name,
                query=query_embedding,
                limit=5
            ).points
        except Exception as e2:
            print(f"Error in query_points: {str(e2)}")
            return f"Error retrieving documents with rephrased query: {rephrased_query}"
    
    # Convert search results to documents, ignoring ids with no known chunk
    docs = [docs_by_id[result.id] for result in search_results if result.id in docs_by_id]
    
    # 3. Extract and store sources
    sources = []
    for doc in docs:
        source_path = doc.metadata.get("source", "")
        # Keep only the file name (everything after the last "/").
        filename = source_path.rsplit("/", 1)[-1]
        
        # Remove .pdf extension if present (case-insensitive)
        if filename.lower().endswith('.pdf'):
            filename = filename[:-4]
        
        sources.append({
            "title": f"Ron Kohavi: {filename}",
            "page": doc.metadata.get("page", "unknown"),
            "type": "pdf"
        })
    
    # Store sources for later access.
    # NOTE(review): @tool turns this function into a tool object; if that object
    # is a pydantic model that rejects undeclared attributes, this assignment
    # raises — confirm against the installed LangChain version.
    retrieve_information_with_rephrased_query.last_sources = sources
    
    # 4. Return formatted content with rephrased query
    header = f"Rephrased query: {rephrased_query}\n\n"
    if not docs:
        # Tell the agent explicitly that nothing matched, so it can try another tool.
        return header + "No relevant documents were found."
    return header + "\n\n".join(
        [f"Retrieved Information: {i+1}\n{doc.page_content}" for i, doc in enumerate(docs)]
    )

@tool
def search_arxiv(query: str) -> str:
    """Search ArXiv for academic papers related to the query."""
    # NOTE: @tool uses the docstring above as the tool description shown to the
    # agent's LLM, so it is intentionally left unchanged.
    #
    # Queries the arXiv Atom API for up to 5 papers and records one citation
    # dict per paper (title, authors, type="arxiv") in the module-level
    # ARXIV_SOURCES list. Returns a plain-text listing (title, authors, first
    # 300 characters of the abstract, link). Errors are returned as a string
    # rather than raised.
    import re
    import traceback
    import xml.etree.ElementTree as ET

    global ARXIV_SOURCES
    # Rebind (rather than clear) so lists handed out by earlier calls stay intact.
    ARXIV_SOURCES = []

    ns = {'atom': 'http://www.w3.org/2005/Atom'}

    def _entry_text(element, tag):
        """Return whitespace-normalised text of child *tag*, or '' if absent."""
        node = element.find(tag, ns)
        if node is None or node.text is None:
            return ""
        # arXiv hard-wraps titles and abstracts; collapse the embedded newlines.
        return " ".join(node.text.split())

    try:
        lowered = query.lower()
        # Check if the query is looking for a specific paper by title
        looks_like_title_search = "paper" in lowered and (
            "title" in lowered
            or "called" in lowered
            or "named" in lowered
            or "'" in query
            or '"' in query
        )

        title_match = None
        if looks_like_title_search:
            # Prefer "double-quoted" titles; accept 'single-quoted' ones only when
            # the quote is not inside a word, so apostrophes ("Kohavi's") are not
            # mistaken for title delimiters.
            title_match = re.search(r'''"([^"]+)"|(?<!\w)'([^']+)'(?!\w)''', query)

        if title_match:
            paper_title = title_match.group(1) or title_match.group(2)
            # Use title-specific search with exact match
            formatted_query = f'ti:"{paper_title}"'
        else:
            # General (or unquoted title) search across all fields
            formatted_query = f'all:{query}'

        print(f"Searching ArXiv with query: {formatted_query}")
        # Let requests URL-encode the parameters so characters such as '&', '#'
        # or '%' in user text cannot corrupt the request (spaces are sent as '+').
        # The timeout keeps a stalled arXiv endpoint from hanging the agent.
        response = requests.get(
            "http://export.arxiv.org/api/query",
            params={"search_query": formatted_query, "start": 0, "max_results": 5},
            timeout=30,
        )
        if response.status_code != 200:
            return "Error fetching data from ArXiv"

        # Parse the raw bytes so the XML declaration's encoding is honoured.
        root = ET.fromstring(response.content)
        entries = root.findall('atom:entry', ns)
        print(f"Found {len(entries)} papers on ArXiv")

        results = []
        for entry in entries:
            title = _entry_text(entry, 'atom:title')
            authors = ", ".join(
                _entry_text(author, 'atom:name')
                for author in entry.findall('atom:author', ns)
            )
            summary = _entry_text(entry, 'atom:summary')
            link = _entry_text(entry, 'atom:id')

            # Add to global sources list
            ARXIV_SOURCES.append({
                "title": title,
                "authors": authors,
                "type": "arxiv"
            })

            results.append({
                "title": title,
                "authors": authors,
                "summary": summary,
                "link": link
            })

        if not results:
            return "No papers found on ArXiv matching the query"

        # Format results as text
        text_results = [
            f"Paper {i+1}:\nTitle: {paper['title']}\nAuthors: {paper['authors']}\nSummary: {paper['summary'][:300]}...\nLink: {paper['link']}\n"
            for i, paper in enumerate(results)
        ]
        return "\n".join(text_results)
    except Exception as e:
        print(f"Error searching ArXiv: {str(e)}")
        traceback.print_exc()
        return f"Error searching ArXiv: {str(e)}"

def setup_agent():
    """Build the OpenAI-tools agent for the AB Testing assistant.

    The agent is wired to the three tools ``retrieve_information``,
    ``retrieve_information_with_rephrased_query`` and ``search_arxiv``.

    Returns:
        The agent runnable produced by ``create_openai_tools_agent``, or
        ``None`` if construction fails (the error is printed).
    """
    agent_model = get_agent_model()
    tools = [retrieve_information, retrieve_information_with_rephrased_query, search_arxiv]
    
    try:
        # Local import: the scratchpad needs a message placeholder, not a text template.
        from langchain_core.prompts import MessagesPlaceholder

        return create_openai_tools_agent(
            llm=agent_model,
            tools=tools,
            prompt=ChatPromptTemplate.from_messages([
                ("system", "You are an expert AB Testing assistant. Your job is to provide helpful, accurate information about AB Testing topics."),
                ("human", "{input}"),
                # The tools agent fills agent_scratchpad with a *list* of
                # AI/Tool messages. A ("ai", "{agent_scratchpad}") string template
                # would str() that list into one AI message and lose the
                # tool-call structure.
                MessagesPlaceholder(variable_name="agent_scratchpad"),
            ])
        )
    except Exception as e:
        print(f"Error creating agent: {str(e)}")
        return None

def execute_agent(agent, query):
    """Run *agent* on *query* and collect citations from the tools it used.

    Args:
        agent: Agent runnable, e.g. as returned by ``setup_agent``.
        query: The user's question.

    Returns:
        An ``(answer_text, sources)`` tuple. ``sources`` concatenates the
        citation dicts recorded by each tool the agent actually called during
        this run. On any failure, a generic apology and ``[]`` are returned
        (the traceback is printed).
    """
    import traceback

    try:
        executor = AgentExecutor(
            agent=agent,
            tools=[retrieve_information, retrieve_information_with_rephrased_query, search_arxiv],
            verbose=True,
            handle_parsing_errors=True,
            # Lets us see which tools ran for *this* query (see below).
            return_intermediate_steps=True,
        )
        
        response = executor.invoke({"input": query})
        
        # Attribute sources only to tools invoked in this run. `last_sources`
        # stays on the tool objects across queries, so checking hasattr() alone
        # would return stale citations from earlier questions and would hide
        # the other tools' sources behind the first one ever used.
        used_tool_names = {
            action.tool for action, _observation in response.get("intermediate_steps", [])
        }
        sources = []
        for retrieval_tool in (retrieve_information, retrieve_information_with_rephrased_query):
            if retrieval_tool.name in used_tool_names:
                sources.extend(getattr(retrieval_tool, "last_sources", []))
        if search_arxiv.name in used_tool_names:
            sources.extend(ARXIV_SOURCES)
        
        return response["output"], sources
    except Exception as e:
        print(f"Error executing agent: {str(e)}")
        traceback.print_exc()
        return "I'm having trouble processing your request. Please try again.", []

# Streamlit UI
# Page-level configuration (browser tab title/icon and wide layout).
st.set_page_config(
    page_title="📊 AB Testing RAG Agent",
    page_icon="📊",
    layout="wide"
)

def _render_sources(sources):
    """Render the numbered "Sources" list shown under an assistant answer.

    Args:
        sources: List of citation dicts. Entries with ``type == "arxiv"`` are
            shown with their ``authors``; anything else is treated as a PDF
            chunk and shown with its ``page``. A falsy value renders nothing.
    """
    if not sources:
        return
    st.markdown("#### Sources")
    for number, source in enumerate(sources, start=1):
        title = source.get("title", "Unknown")

        # Display differently based on source type
        if source.get("type") == "arxiv":
            authors = source.get("authors", "Unknown authors")
            # Two trailing spaces make a Markdown hard line break; a bare "\n"
            # would run "Authors" onto the title line.
            st.markdown(f"**{number}. {title}**  \nAuthors: {authors}")
        else:
            # PDF source with page number
            page = source.get("page", "Unknown")
            st.markdown(f"**{number}. {title}** (Page: {page})")

def main():
    """Render the chat UI and answer a newly submitted question, if any.

    Replays the conversation stored in ``st.session_state.messages`` along with
    each answer's sources. When the user submits a question, runs
    ``rag_chain_node`` and shows the answer and its sources immediately. It
    then appends both to the history. Errors are shown in the chat instead of
    being raised.
    """
    st.title("📊 AB Testing RAG Agent")
    st.markdown("""
This specialized agent can answer questions about A/B Testing using a collection of Ron Kohavi's work. If it can't fully answer your A/B Testing questions using this collection, it will then automatically search Arxiv. Let's begin!
""")

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            _render_sources(message.get("sources"))

    # Input for new question (None/empty until the user submits one)
    query = st.chat_input("Ask a question about A/B Testing")

    if query:
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": query})
        
        # Display user message
        with st.chat_message("user"):
            st.markdown(query)
        
        # Display assistant response
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            # Reserve a slot directly under the answer so its sources render
            # now, outside the (collapsible) status box.
            sources_placeholder = st.empty()
            
            with st.status("Processing your query...", expanded=True) as status:
                try:
                    # Run the RAG pipeline (no timeout is applied here)
                    st.write("Starting with Initial RAG...")
                    print("Starting RAG process for query:", query)
                    
                    # Step 1: Initial RAG
                    response, sources = rag_chain_node(query, None)
                    
                    # Display the processed response and its sources
                    message_placeholder.markdown(response)
                    with sources_placeholder.container():
                        _render_sources(sources)
                    
                    # Add assistant message to chat history
                    st.session_state.messages.append({
                        "role": "assistant", 
                        "content": response,
                        "sources": sources
                    })
                    
                    status.update(label="Completed!", state="complete", expanded=False)
                except Exception as e:
                    error_msg = str(e)
                    # DNS failures usually mean the host blocks outbound API calls.
                    if "Name or service not known" in error_msg:
                        response = "I'm having trouble connecting to the language model API due to network restrictions. The Hugging Face environment may be blocking external API calls."
                    else:
                        response = f"An error occurred: {error_msg}"
                    
                    message_placeholder.markdown(response)
                    st.session_state.messages.append({
                        "role": "assistant", 
                        "content": response,
                        "sources": []
                    })
                    status.update(label="Error", state="error", expanded=False)

if __name__ == "__main__":
    # `streamlit run` executes this script as __main__ on every rerun. main()
    # itself handles the "no question submitted yet" case, so always call it
    # (there is no module-level `query` to test here).
    main()