File size: 28,673 Bytes
521bc19
 
 
 
 
 
 
 
 
 
 
00ddd85
 
 
 
 
cdee3b2
00ddd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39e1b73
00ddd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39e1b73
00ddd85
 
 
 
 
 
 
 
 
 
 
 
39e1b73
 
00ddd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39e1b73
00ddd85
 
 
 
 
 
 
 
 
39e1b73
00ddd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39e1b73
00ddd85
39e1b73
00ddd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39e1b73
00ddd85
 
 
39e1b73
00ddd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39e1b73
00ddd85
 
 
 
 
 
 
 
 
 
39e1b73
 
00ddd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39e1b73
00ddd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39e1b73
00ddd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39e1b73
00ddd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39e1b73
00ddd85
 
 
 
 
39e1b73
00ddd85
 
 
 
 
 
 
 
 
 
cdee3b2
00ddd85
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
import os
import sys

# CRITICAL: this must run BEFORE `import streamlit` further down the file.
# All three Streamlit directory variables are pointed at the same scratch
# location to avoid permission errors in restricted environments.
os.environ.update(
    {
        "STREAMLIT_CONFIG_DIR": "/tmp/.streamlit",
        "STREAMLIT_HOME": "/tmp/.streamlit",
        "STREAMLIT_GLOBAL_CONFIG_DIR": "/tmp/.streamlit",
    }
)

# Create the scratch directory right away, at import time.
os.makedirs(os.environ["STREAMLIT_HOME"], exist_ok=True)

import requests
import arxiv
from pathlib import Path
from typing import List, Dict, Optional
import streamlit as st
from llama_index.core import (
    VectorStoreIndex, 
    SimpleDirectoryReader, 
    Settings,
    Document,
    StorageContext,
    load_index_from_storage
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.llms.groq import Groq
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.chat_engine import CondensePlusContextChatEngine
import logging
import hashlib
import json
import time
from datetime import datetime

# Configure logging: the root logger is set to INFO, and `logger` is the
# module-level logger used by the class and UI code in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class AcademicPaperQA:
    def __init__(self, model_name="llama3-70b-8192", groq_api_key=None):
        """Initialize the Academic Paper Q&A system with Groq API"""
        self.data_dir = Path("./papers")
        self.storage_dir = Path("./storage")
        self.model_name = model_name
        self.groq_api_key = groq_api_key
        
        # Create directories
        self.data_dir.mkdir(exist_ok=True)
        self.storage_dir.mkdir(exist_ok=True)
        
        # Initialize models
        self._setup_models()
        
        # Initialize index and chat engine
        self.index = None
        self.query_engine = None
        self.chat_engine = None
        self.current_papers_hash = None
        self.is_ready = False
        
        # Chat history
        self.chat_history = []
    
    def _setup_models(self):
        """Build the Groq LLM and a CPU embedding model and register them globally.

        The LLM and embedding model are stored on the instance and also
        assigned to the LlamaIndex ``Settings`` object together with the
        chunk size and overlap. If the first embedding model cannot be
        loaded, a second model is tried once.

        Raises:
            ValueError: If no Groq API key is set on the instance.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        assistant_prompt = (
            "You are an expert academic research assistant. Provide comprehensive, detailed responses about research papers including:\n"
            "1. Direct answers to questions\n"
            "2. Relevant background context\n"
            "3. Specific details from papers including methodologies and findings\n"
            "4. Analysis and interpretation\n"
            "5. Connections between concepts when relevant\n"
            "Keep responses thorough but concise to stay within token limits."
        )
        # Options shared by the primary and the fallback embedding model.
        cpu_embedding_options = {"device": "cpu", "max_length": 512}

        try:
            if not self.groq_api_key:
                raise ValueError("Groq API key is required. Please set GROQ_API_KEY environment variable or pass it directly.")

            # LLM served through the Groq API, with conservative token settings
            self.llm = Groq(
                model=self.model_name,
                api_key=self.groq_api_key,
                temperature=0.3,
                max_tokens=2048,
                top_p=0.9,
                system_prompt=assistant_prompt,
            )

            # Lightweight embedding model on CPU, with one alternative on failure
            try:
                self.embed_model = HuggingFaceEmbedding(
                    model_name="sentence-transformers/all-MiniLM-L6-v2",
                    **cpu_embedding_options,
                )
            except Exception as primary_error:
                logger.warning(f"Failed to load HuggingFace embedding, trying alternative: {primary_error}")
                self.embed_model = HuggingFaceEmbedding(
                    model_name="BAAI/bge-small-en-v1.5",
                    **cpu_embedding_options,
                )

            # Global LlamaIndex settings
            Settings.llm = self.llm
            Settings.embed_model = self.embed_model
            Settings.chunk_size = 256
            Settings.chunk_overlap = 25

            logger.info(f"Models initialized successfully with {self.model_name} via Groq API")

        except Exception as setup_error:
            logger.error(f"Error setting up models: {setup_error}")
            raise
    
    def _get_papers_hash(self) -> str:
        """Generate hash of current papers in directory"""
        pdf_files = list(self.data_dir.glob("*.pdf"))
        if not pdf_files:
            return ""
        
        file_info = []
        for pdf_file in sorted(pdf_files):
            file_info.append(f"{pdf_file.name}:{pdf_file.stat().st_size}")
        
        papers_string = "|".join(file_info)
        return hashlib.md5(papers_string.encode()).hexdigest()
    
    def _save_papers_metadata(self, papers_hash: str):
        """Save metadata about current papers"""
        metadata_file = self.storage_dir / "papers_metadata.json"
        metadata = {
            "papers_hash": papers_hash,
            "model_name": self.model_name
        }
        with open(metadata_file, "w") as f:
            json.dump(metadata, f)
    
    def _load_papers_metadata(self) -> Dict:
        """Load metadata about papers"""
        metadata_file = self.storage_dir / "papers_metadata.json"
        if metadata_file.exists():
            with open(metadata_file, "r") as f:
                return json.load(f)
        return {}
    
    def download_arxiv_paper(self, arxiv_id: str) -> Optional[str]:
        """Download a paper from arXiv into the papers directory.

        Args:
            arxiv_id: arXiv identifier; surrounding whitespace is ignored.

        Returns:
            The path of the saved PDF as a string, or ``None`` when the ID is
            empty, no paper matches it, or the download fails. Failures are
            logged rather than raised.
        """
        try:
            arxiv_id = (arxiv_id or "").strip()
            if not arxiv_id:
                raise ValueError("arXiv ID must not be empty")

            search = arxiv.Search(id_list=[arxiv_id])
            # Client.results() replaces the deprecated Search.results().
            # The None default turns "no match" into a readable error instead
            # of a bare StopIteration.
            paper = next(arxiv.Client().results(search), None)
            if paper is None:
                raise LookupError(f"No arXiv paper found for ID {arxiv_id!r}")

            # Old-style IDs contain a slash, which is not valid in a file name.
            filename = f"{arxiv_id.replace('/', '_')}.pdf"
            filepath = self.data_dir / filename

            paper.download_pdf(dirpath=str(self.data_dir), filename=filename)

            logger.info(f"Downloaded paper: {paper.title}")
            return str(filepath)

        except Exception as e:
            logger.error(f"Error downloading paper {arxiv_id}: {e}")
            return None
    
    def load_documents(self, file_paths: List[str] = None) -> List[Document]:
        """Read PDF documents and return the ones with usable text.

        Args:
            file_paths: Explicit files to read. When ``None``, every ``.pdf``
                directly inside the papers directory is read.

        Returns:
            Documents whose stripped text is longer than 50 characters. Text
            longer than 50000 characters is cut and marked as truncated. An
            empty list is returned if loading fails; the error is logged.
        """
        max_chars = 50000
        try:
            if file_paths is not None:
                reader_options = {"input_files": file_paths}
            else:
                reader_options = {
                    "input_dir": str(self.data_dir),
                    "required_exts": [".pdf"],
                    "recursive": False,
                }

            documents = SimpleDirectoryReader(**reader_options).load_data()
            logger.info(f"Loaded {len(documents)} documents")

            cleaned_documents = []
            for document in documents:
                # Skip documents that are empty or too short to be useful.
                if not document.text or len(document.text.strip()) <= 50:
                    continue
                if len(document.text) > max_chars:
                    document.text = document.text[:max_chars] + "... [Document truncated]"
                cleaned_documents.append(document)

            logger.info(f"After cleaning: {len(cleaned_documents)} valid documents")
            return cleaned_documents

        except Exception as load_error:
            logger.error(f"Error loading documents: {load_error}")
            return []
    
    def create_index(self, documents: List[Document], save_index: bool = True):
        """Split documents into nodes, build the vector index and create both engines.

        Args:
            documents: Documents to index; must not be empty.
            save_index: When true, persist the index to the storage directory
                and record the current papers hash alongside it.

        Raises:
            ValueError: If ``documents`` is empty.
            Exception: Any other failure is logged, ``is_ready`` is set to
                ``False`` and the error is re-raised.
        """
        try:
            if not documents:
                raise ValueError("No documents provided for indexing")

            doc_count = len(documents)
            logger.info(f"Creating index from {doc_count} documents")

            splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25, separator=" ")

            # Documents are split five at a time; the nodes are collected
            # into one list before the index is built.
            batch_size = 5
            total_batches = (doc_count + batch_size - 1) // batch_size
            all_nodes = []
            for batch_number, start in enumerate(range(0, doc_count, batch_size), start=1):
                logger.info(f"Processing batch {batch_number}/{total_batches}")
                batch = documents[start:start + batch_size]
                all_nodes += splitter.get_nodes_from_documents(batch)

            self.index = VectorStoreIndex(nodes=all_nodes, show_progress=True)

            if save_index:
                self.index.storage_context.persist(persist_dir=str(self.storage_dir))
                papers_hash = self._get_papers_hash()
                self._save_papers_metadata(papers_hash)
                self.current_papers_hash = papers_hash
                logger.info("Index saved to storage")

            self._create_query_engine()
            self._create_chat_engine()
            self.is_ready = True
            logger.info("Vector index created successfully")

        except Exception as index_error:
            logger.error(f"Error creating index: {index_error}")
            self.is_ready = False
            raise
    
    def should_rebuild_index(self) -> bool:
        """Check if index should be rebuilt based on papers"""
        current_hash = self._get_papers_hash()
        
        if not current_hash:
            return False
            
        metadata = self._load_papers_metadata()
        
        if not metadata:
            logger.info("No metadata found, rebuilding index")
            return True
        
        if metadata.get("papers_hash") != current_hash:
            logger.info("Papers hash changed, rebuilding index")
            return True
            
        if metadata.get("model_name") != self.model_name:
            logger.info("Model changed, rebuilding index")
            return True
            
        return False
    
    def load_index(self) -> bool:
        """Load the persisted index if it still matches the current papers.

        Returns:
            ``True`` when the index was loaded and both engines were created.
            ``False`` when a rebuild is needed, the storage directory is
            empty, or loading failed (the error is logged and ``is_ready``
            is set to ``False``).
        """
        try:
            if self.should_rebuild_index():
                logger.info("Index needs to be rebuilt due to changes")
                return False

            if not any(self.storage_dir.glob("*")):
                logger.info("No index files found")
                return False

            self.index = load_index_from_storage(
                StorageContext.from_defaults(persist_dir=str(self.storage_dir))
            )
            self._create_query_engine()
            self._create_chat_engine()
            self.current_papers_hash = self._get_papers_hash()
            self.is_ready = True

            logger.info("Index loaded from storage successfully")
            return True

        except Exception as load_error:
            logger.error(f"Error loading index: {load_error}")
            self.is_ready = False
            return False
    
    def _create_query_engine(self):
        """Create query engine with settings for detailed responses"""
        try:
            if not self.index:
                raise ValueError("No index available for query engine")
                
            retriever = VectorIndexRetriever(
                index=self.index,
                similarity_top_k=2
            )
            
            response_synthesizer = get_response_synthesizer(
                response_mode="compact",
                streaming=False,
                text_qa_template="""Context information is below.
---------------------
{context_str}
---------------------
Based on the context information, provide a comprehensive answer to the question. Include specific details from the research papers and explain key concepts clearly.
Question: {query_str}
Answer: """
            )
            
            self.query_engine = RetrieverQueryEngine(
                retriever=retriever,
                response_synthesizer=response_synthesizer
            )
            
            logger.info("Query engine created successfully")
            
        except Exception as e:
            logger.error(f"Error creating query engine: {e}")
            raise
    
    def _create_chat_engine(self):
        """Create the conversational chat engine on top of the current index.

        The engine uses a top-2 vector retriever, a chat memory buffer with
        a 1000-token limit, and the instance's LLM.

        Raises:
            ValueError: If no index has been built or loaded yet.
            Exception: Any other failure is logged and re-raised.
        """
        try:
            if not self.index:
                raise ValueError("No index available for chat engine")

            context_prompt = (
                "You are an expert academic research assistant. "
                "Use the following context to answer questions thoroughly but concisely. "
                "Context:\n{context_str}\n"
                "Answer the user's question based on the provided context."
            )
            chat_memory = ChatMemoryBuffer.from_defaults(token_limit=1000)
            chat_retriever = VectorIndexRetriever(index=self.index, similarity_top_k=2)

            self.chat_engine = CondensePlusContextChatEngine.from_defaults(
                retriever=chat_retriever,
                memory=chat_memory,
                llm=self.llm,
                context_prompt=context_prompt,
                verbose=True,
                context_window=4096,
                max_tokens=1500,
            )

            logger.info("Chat engine created successfully")

        except Exception as engine_error:
            logger.error(f"Error creating chat engine: {engine_error}")
            raise
    
    def get_loaded_papers_info(self) -> List[str]:
        """Get list of currently loaded papers"""
        pdf_files = list(self.data_dir.glob("*.pdf"))
        return [pdf_file.name for pdf_file in pdf_files]
    
    def clear_papers(self):
        """Clear all papers and reset index"""
        try:
            for pdf_file in self.data_dir.glob("*.pdf"):
                pdf_file.unlink()
            
            if self.storage_dir.exists():
                import shutil
                shutil.rmtree(self.storage_dir)
                self.storage_dir.mkdir(exist_ok=True)
            
            self.index = None
            self.query_engine = None
            self.chat_engine = None
            self.current_papers_hash = None
            self.is_ready = False
            self.chat_history = []
            
            logger.info("Papers and index cleared")
            return True
            
        except Exception as e:
            logger.error(f"Error clearing papers: {e}")
            return False
    
    def clear_chat_history(self):
        """Clear chat history and reset memory"""
        try:
            self.chat_history = []
            if self.chat_engine and hasattr(self.chat_engine, 'memory'):
                self.chat_engine.memory.reset()
            logger.info("Chat history cleared")
        except Exception as e:
            logger.error(f"Error clearing chat history: {e}")
    
    def process_all_papers(self) -> Dict[str, str]:
        """Process all papers in the directory and create/load index"""
        try:
            current_papers = self.get_loaded_papers_info()
            if not current_papers:
                return {"error": "No papers found in directory"}
            
            logger.info(f"Processing {len(current_papers)} papers: {current_papers}")
            
            if self.load_index():
                return {"success": f"Loaded existing index for {len(current_papers)} papers"}
            
            logger.info("Creating new index from documents...")
            documents = self.load_documents()
            
            if not documents:
                return {"error": "Failed to load documents from PDF files"}
            
            self.create_index(documents)
            
            if self.is_ready:
                return {"success": f"Successfully created index for {len(current_papers)} papers"}
            else:
                return {"error": "Failed to create index"}
                
        except Exception as e:
            logger.error(f"Error processing papers: {e}")
            return {"error": f"Error processing papers: {str(e)}"}
    
    def ask_question(self, question: str, use_chat_engine: bool = True) -> Dict[str, any]:
        """Ask a question using either chat engine (conversational) or query engine (standalone)"""
        if not self.is_ready:
            return {"error": "System not ready. Please process papers first."}
        
        try:
            logger.info(f"Asking question: {question}")
            
            if len(question) > 500:
                question = question[:500] + "..."
                logger.warning("Question truncated to prevent context overflow")
            
            if use_chat_engine and self.chat_engine:
                try:
                    response = self.chat_engine.chat(question)
                    answer = str(response)
                except Exception as chat_error:
                    logger.warning(f"Chat engine failed, falling back to query engine: {chat_error}")
                    response = self.query_engine.query(question)
                    answer = str(response)
                    use_chat_engine = False
                
            else:
                response = self.query_engine.query(question)  
                answer = str(response)
            
            self.chat_history.append({
                "timestamp": datetime.now().strftime("%H:%M:%S"),
                "question": question,
                "answer": answer,
                "type": "chat" if use_chat_engine else "query"
            })
            
            sources = []
            if hasattr(response, 'source_nodes') and response.source_nodes:
                for i, node in enumerate(response.source_nodes):
                    sources.append({
                        'text': node.text[:300] + "..." if len(node.text) > 300 else node.text,
                        'score': node.score if hasattr(node, 'score') else 'N/A'
                    })
            
            logger.info(f"Generated answer length: {len(answer)} characters")
            
            return {
                "answer": answer,
                "sources": sources,
                "timestamp": datetime.now().strftime("%H:%M:%S")
            }
            
        except Exception as e:
            logger.error(f"Error answering question: {e}")
            return {"error": f"Error processing question: {str(e)}"}

def create_streamlit_app():
    """Create Streamlit web interface with chat functionality.

    The whole page is rendered on each call: API key and model selection in
    the sidebar, paper download/upload/processing while the Q&A system is not
    ready, and the chat view once it is. State that must survive a rerun
    (the Q&A system, the selected model and key, the latest sources) is kept
    in ``st.session_state``.
    """
    import html  # function-scope import: only this UI code escapes HTML

    st.set_page_config(
        page_title="Academic Paper Q&A Bot (Groq Powered)",
        page_icon="🔬",
        layout="wide"
    )

    st.title("🔬 Academic Paper Q&A Bot (Groq Powered)")

    st.markdown("""
    <style>
    .chat-message {
        padding: 1rem;
        border-radius: 0.5rem;
        margin-bottom: 1rem;
        display: flex;
        flex-direction: column;
    }
    .user-message {
        background-color: #e3f2fd;
        margin-left: 20%;
    }
    .bot-message {
        background-color: #f5f5f5;
        margin-right: 20%;
    }
    .message-content {
        margin: 0.5rem 0;
    }
    .message-timestamp {
        font-size: 0.8rem;
        color: #666;
        align-self: flex-end;
    }
    </style>
    """, unsafe_allow_html=True)

    # --- Sidebar: API key -------------------------------------------------
    st.sidebar.header("🔑 API Configuration")
    groq_api_key = st.sidebar.text_input(
        "Groq API Key:",
        type="password",
        help="Get your free API key from https://console.groq.com/keys"
    )

    # Fall back to the environment when nothing was typed in.
    if not groq_api_key:
        groq_api_key = os.getenv("GROQ_API_KEY")

    if not groq_api_key:
        st.sidebar.error("Please enter your Groq API key or set GROQ_API_KEY environment variable")
        st.info("🔑 **To get started:**\n1. Go to https://console.groq.com/keys\n2. Create a free account\n3. Generate an API key\n4. Enter it in the sidebar")
        st.stop()

    # --- Sidebar: model selection ------------------------------------------
    st.sidebar.header("⚙️ Configuration")
    model_options = {
        "Llama3 8B (Fast & Stable)": "llama3-8b-8192",
        "Llama3 70B (Most Capable)": "llama3-70b-8192",
        "Mixtral 8x7B (Balanced)": "mixtral-8x7b-32768",
        "Gemma 7B (Efficient)": "gemma-7b-it"
    }

    selected_model = st.sidebar.selectbox(
        "Choose Groq Model:",
        list(model_options.keys()),
        index=0
    )

    model_name = model_options[selected_model]

    # (Re)create the Q&A system on first run or when the model/key changed.
    if ('qa_system' not in st.session_state or
        st.session_state.get('current_model') != model_name or
        st.session_state.get('current_api_key') != groq_api_key):

        with st.spinner(f"Initializing system with {selected_model}..."):
            try:
                st.session_state.qa_system = AcademicPaperQA(
                    model_name=model_name,
                    groq_api_key=groq_api_key
                )
                st.session_state.current_model = model_name
                st.session_state.current_api_key = groq_api_key
                st.session_state.papers_loaded = False
                st.session_state.last_sources = []
                st.success(f"System initialized with {selected_model} via Groq API!")
            except Exception as e:
                st.error(f"Error initializing system: {e}")
                st.info("Please check your Groq API key and try again.")
                st.stop()

    if 'papers_loaded' not in st.session_state:
        st.session_state.papers_loaded = False

    st.sidebar.info(f"**Current model:** {selected_model}")
    st.sidebar.success("✅ Using Groq API (Cloud)")
    st.sidebar.info("💬 Conversational Mode: ON")

    if hasattr(st.session_state.qa_system, 'is_ready'):
        if st.session_state.qa_system.is_ready:
            st.sidebar.success("✅ System Ready")
        else:
            st.sidebar.warning("⚠️ Process papers first")

    # --- Sidebar: loaded papers and chat controls ----------------------------
    current_papers = st.session_state.qa_system.get_loaded_papers_info()
    if current_papers:
        st.sidebar.subheader("📚 Loaded Papers:")
        for paper in current_papers:
            st.sidebar.text(f"📄 {paper}")

        if st.sidebar.button("🗑️ Clear All Papers"):
            with st.spinner("Clearing papers..."):
                if st.session_state.qa_system.clear_papers():
                    st.session_state.papers_loaded = False
                    st.session_state.last_sources = []
                    st.sidebar.success("Papers cleared!")
                    st.rerun()

    st.sidebar.subheader("💬 Chat Controls")
    if st.sidebar.button("🧹 Clear Chat History"):
        st.session_state.qa_system.clear_chat_history()
        st.session_state.last_sources = []
        st.sidebar.success("Chat cleared!")
        st.rerun()

    if not st.session_state.qa_system.is_ready:
        # --- Main area: load and process papers -----------------------------
        st.header("📥 Load Academic Papers")

        col1, col2 = st.columns([1, 1])

        with col1:
            st.subheader("From arXiv")
            arxiv_id = st.text_input("Enter arXiv ID (e.g., 2301.00001)")
            if st.button("Download from arXiv"):
                if arxiv_id:
                    with st.spinner("Downloading paper..."):
                        filepath = st.session_state.qa_system.download_arxiv_paper(arxiv_id)
                        if filepath:
                            st.success("Downloaded paper")
                            st.session_state.papers_loaded = False
                        else:
                            st.error("Failed to download paper")

        with col2:
            st.subheader("Upload PDF Files")
            uploaded_files = st.file_uploader(
                "Choose PDF files",
                type="pdf",
                accept_multiple_files=True
            )

            if uploaded_files:
                saved_files = []
                for uploaded_file in uploaded_files:
                    # Keep only the final path component so a crafted name
                    # such as "../../x.pdf" cannot escape the papers directory.
                    safe_name = Path(uploaded_file.name).name
                    if not safe_name:
                        continue
                    file_path = st.session_state.qa_system.data_dir / safe_name
                    with open(file_path, "wb") as f:
                        f.write(uploaded_file.getbuffer())
                    saved_files.append(str(file_path))

                st.success(f"Uploaded {len(saved_files)} files")
                st.session_state.papers_loaded = False

        st.subheader("🔄 Process Papers")
        current_papers = st.session_state.qa_system.get_loaded_papers_info()

        if not current_papers:
            st.info("No papers found. Please upload or download papers first.")
        else:
            st.info(f"Found {len(current_papers)} paper(s): {', '.join(current_papers)}")

            if st.button("🚀 Process Papers", type="primary"):
                with st.spinner("Processing papers (creating embeddings on CPU)..."):
                    result = st.session_state.qa_system.process_all_papers()

                    if "error" in result:
                        st.error(result["error"])
                        st.session_state.papers_loaded = False
                    else:
                        st.success(result["success"])
                        st.session_state.papers_loaded = True
                        st.rerun()

    else:
        # --- Main area: chat -------------------------------------------------
        st.header("💬 Chat with Your Papers")

        loaded_papers = st.session_state.qa_system.get_loaded_papers_info()
        st.info(f"📚 Chatting with {len(loaded_papers)} paper(s): {', '.join(loaded_papers)}")

        chat_container = st.container()

        with chat_container:
            # Only the ten most recent exchanges are shown.
            for message in st.session_state.qa_system.chat_history[-10:]:
                # The question is user input rendered with unsafe_allow_html,
                # so escape it; newlines become <br> so a blank line cannot
                # break out of the surrounding HTML block.
                safe_question = html.escape(str(message['question'])).replace("\n", "<br>")
                safe_timestamp = html.escape(str(message['timestamp']))
                st.markdown(f"""
                    <div class="chat-message user-message">
                        <div class="message-content" style="color: black;">
                            <strong>You:</strong> {safe_question}
                        </div>
                        <div class="message-timestamp">{safe_timestamp}</div>
                    </div>
                    """, unsafe_allow_html=True)

                st.markdown("""
                <div class="chat-message bot-message">
                    <div class="message-content"><strong style="color: black;">Assistant:</strong></div>
                </div>
                """, unsafe_allow_html=True)

                st.write(message['answer'])
                st.markdown("---")

        st.subheader("🚀 Quick Questions")
        col1, col2, col3 = st.columns(3)

        quick_question = None
        with col1:
            if st.button("🎯 Main Research Question"):
                quick_question = "What is the main research question addressed in this paper?"
            if st.button("🔬 Methodology"):
                quick_question = "What methodology was used in this study?"

        with col2:
            if st.button("📊 Key Findings"):
                quick_question = "What are the key findings of this research?"
            if st.button("🎯 Conclusions"):
                quick_question = "What are the main conclusions of this research?"

        with col3:
            if st.button("⚠️ Limitations"):
                quick_question = "What are the limitations of this study?"
            if st.button("📋 Summary"):
                quick_question = "Please provide a summary of this paper."

        st.subheader("💭 Ask Your Question")
        user_question = st.text_area("Type your question here...", height=100, placeholder="Ask anything about your papers...")

        send_clicked = st.button("Send Message", type="primary", disabled=not user_question)

        # A button is only "clicked" during the rerun its click triggers, so
        # a quick question must be sent in that same run; waiting for a
        # later "Send Message" click would lose it.
        if quick_question:
            question_to_ask = quick_question
        elif send_clicked:
            question_to_ask = user_question
        else:
            question_to_ask = None

        if question_to_ask:
            with st.spinner("Thinking... (Processing via Groq API)"):
                result = st.session_state.qa_system.ask_question(
                    question_to_ask,
                    use_chat_engine=True
                )

                if "error" in result:
                    st.error(result["error"])
                else:
                    # Keep the sources across the rerun so they can be shown
                    # below the refreshed chat.
                    st.session_state.last_sources = result.get("sources", [])
                    st.rerun()

        # Show sources for the latest answer: taken from the history entry
        # when it carries them, otherwise from the last successful result.
        history = st.session_state.qa_system.chat_history
        sources = []
        if history:
            sources = history[-1].get('sources') or st.session_state.get('last_sources') or []

        if sources:
            with st.expander("📚 View Sources", expanded=False):
                for i, source in enumerate(sources, 1):
                    st.markdown(f"**Source {i}** (Relevance: {source['score']})")
                    st.text(source['text'])
                    st.markdown("---")

# Script entry point: build the Streamlit UI only when this file is run as
# the main module, not when it is imported.
if __name__ == "__main__":
    create_streamlit_app()