Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| from typing import List, Dict | |
| import torch | |
| import gradio as gr | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.vectorstores import FAISS | |
| from langchain.chains import RetrievalQA | |
| from langchain.prompts import PromptTemplate | |
| from langchain.llms import HuggingFacePipeline | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
| import spaces | |
# Configure logging: one root configuration for the whole app
# (timestamp - level - message), plus this module's named logger.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Resolve the Hugging Face access token. HUGGINGFACE_TOKEN wins when it is
# set to a non-empty value; otherwise whatever HF_TOKEN holds is used.
hf_token = None
for _env_var in ('HUGGINGFACE_TOKEN', 'HF_TOKEN'):
    hf_token = os.environ.get(_env_var)
    if hf_token:
        break

# Abort start-up early: the gated model download below cannot work without it.
if not hf_token:
    for _line in (
        "No Hugging Face token found in environment variables",
        "Please set either HUGGINGFACE_TOKEN or HF_TOKEN in your Space settings",
    ):
        logger.error(_line)
    raise ValueError("Missing Hugging Face token. Please configure it in the Space settings under Repository Secrets.")
# Constants
# Hugging Face Hub id of the chat model used for answer generation.
MODEL_NAME: str = "meta-llama/Llama-2-7b-chat-hf"
# Directory scanned (non-recursively) for knowledge-base PDFs; "." is the
# current working directory.
KNOWLEDGE_BASE_DIR: str = "."
class DocumentLoader:
    """Load the knowledge-base PDFs from a directory.

    Every page returned by ``PyPDFLoader`` becomes one document whose metadata
    is extended with ``title`` (the file name), ``type`` (``'technical'`` for
    files whose name contains "valencia", ``'qa'`` otherwise), ``language``
    (always ``'en'``) and ``page``.
    """

    @staticmethod
    def load_pdfs(directory_path: str) -> List:
        """Load every matching PDF in *directory_path*.

        A file is selected when its name ends with ``.pdf`` and it either
        starts with ``valencia``, contains ``Valencia``, or contains
        ``fislac`` in any letter case. Files that fail to load are logged
        and skipped.

        This is a ``@staticmethod`` (it uses no instance state), so it can be
        called on the class or on an instance alike.

        Args:
            directory_path: Directory to scan (not recursive).

        Returns:
            List of loaded page documents; empty when nothing matched or loaded.
        """
        # Sorted so load order -- and therefore chunk/index order -- is
        # deterministic; os.listdir returns names in arbitrary order.
        pdf_files = sorted(
            name
            for name in os.listdir(directory_path)
            if name.endswith('.pdf')
            and (name.startswith('valencia') or 'Valencia' in name or 'fislac' in name.lower())
        )
        if not pdf_files:
            logger.warning(f"No matching PDF files found in {directory_path}")
            return []

        documents = []
        for pdf_file in pdf_files:
            pdf_path = os.path.join(directory_path, pdf_file)
            # 'valencia' in the lower-cased name also covers 'Valencia'.
            doc_type = 'technical' if 'valencia' in pdf_file.lower() else 'qa'
            try:
                pages = PyPDFLoader(pdf_path).load()
            except Exception as e:
                # Best effort: one unreadable PDF must not abort the whole load.
                logger.exception(f"Error loading {pdf_file}: {e}")
                continue
            for page in pages:
                page.metadata.update({
                    'title': pdf_file,
                    'type': doc_type,
                    'language': 'en',
                    'page': page.metadata.get('page', 0),
                })
            documents.extend(pages)
            logger.info(f"Document {pdf_file} loaded successfully")
        return documents
class TextProcessor:
    """Split loaded documents into overlapping character chunks for embedding.

    Documents whose ``metadata['type']`` is ``'technical'`` are split into
    800-character chunks with 200 characters of overlap; all other documents
    (including ones with no ``type`` key) use 500-character chunks with 100
    characters of overlap.
    """

    # Separators tried in order by RecursiveCharacterTextSplitter:
    # blank line, newline, sentence end, space, then single characters.
    _SEPARATORS = ("\n\n", "\n", ". ", " ", "")

    def __init__(self):
        self.technical_splitter = self._make_splitter(chunk_size=800, chunk_overlap=200)
        self.qa_splitter = self._make_splitter(chunk_size=500, chunk_overlap=100)

    @classmethod
    def _make_splitter(cls, chunk_size: int, chunk_overlap: int):
        """Build a character-length splitter sharing the class separators."""
        return RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=list(cls._SEPARATORS),
            length_function=len,
        )

    def process_documents(self, documents: List) -> List:
        """Split *documents* into chunks, choosing a splitter per document.

        Args:
            documents: Loaded documents with a ``metadata`` dict.

        Returns:
            The flat list of chunks; empty (with a warning) if *documents* is empty.
        """
        if not documents:
            logger.warning("No documents to process")
            return []

        processed_chunks = []
        for doc in documents:
            # .get(): documents lacking a 'type' key fall back to the Q&A
            # splitter instead of raising KeyError.
            is_technical = doc.metadata.get('type') == 'technical'
            splitter = self.technical_splitter if is_technical else self.qa_splitter
            processed_chunks.extend(splitter.split_documents([doc]))

        logger.info(f"Documents processed into {len(processed_chunks)} chunks")
        return processed_chunks
class RAGSystem:
    """Retrieval-augmented question answering over the knowledge-base PDFs.

    Usage::

        system = RAGSystem()
        system.initialize_system()   # loads PDFs, builds the index, loads the LLM
        system.generate_response("What is FISLAC?")

    All component attributes stay ``None`` until ``initialize_system`` succeeds.
    """

    # Characters of each retrieved chunk kept as a source preview.
    _PREVIEW_CHARS = 200
    # Number of chunks retrieved per question.
    _RETRIEVER_K = 6

    def __init__(self, model_name: str = MODEL_NAME):
        self.model_name = model_name
        self.embeddings = None    # HuggingFaceEmbeddings used for the index
        self.vector_store = None  # FAISS index over the document chunks
        self.qa_chain = None      # RetrievalQA chain used by generate_response
        self.tokenizer = None
        self.model = None

    def initialize_system(self):
        """Build the complete RAG pipeline and store it on ``self.qa_chain``.

        Steps: load and chunk the PDFs in KNOWLEDGE_BASE_DIR, embed them into
        a FAISS vector store, load the causal LM, and wire everything into a
        ``RetrievalQA`` chain.

        Raises:
            ValueError: if no documents or no chunks could be produced.
            Exception: any other failure is logged (with traceback) and re-raised.
        """
        try:
            logger.info("Starting RAG system initialization...")
            chunks = self._load_chunks()
            self._build_vector_store(chunks)
            llm = self._build_llm()
            self._build_qa_chain(llm)
            logger.info("RAG system initialized successfully")
        except Exception as e:
            logger.exception("Error during RAG system initialization: %s", e)
            raise

    def _load_chunks(self) -> List:
        """Load the knowledge-base PDFs and split them into chunks."""
        # load_pdfs takes no `self`, so it is called on the class.
        documents = DocumentLoader.load_pdfs(KNOWLEDGE_BASE_DIR)
        if not documents:
            raise ValueError("No documents were loaded. Please check the PDF files in the root directory.")
        processed_chunks = TextProcessor().process_documents(documents)
        if not processed_chunks:
            raise ValueError("No chunks were created from the documents.")
        return processed_chunks

    def _build_vector_store(self, chunks: List) -> None:
        """Embed *chunks* and index them in a FAISS vector store."""
        logger.info("Initializing embeddings...")
        # NOTE(review): multilingual-e5 models are usually fed "query: " /
        # "passage: " prefixes; none are added here -- confirm retrieval quality.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="intfloat/multilingual-e5-large",
            model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
            encode_kwargs={'normalize_embeddings': True},
        )
        logger.info("Creating vector store...")
        self.vector_store = FAISS.from_documents(chunks, self.embeddings)

    def _build_llm(self):
        """Load tokenizer and model and wrap them as a LangChain LLM."""
        logger.info("Initializing language model...")
        # NOTE(review): Llama-2 is supported natively by transformers, so
        # trust_remote_code is likely unnecessary; kept for other model_names.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            token=hf_token,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            token=hf_token,
            # fp16 on GPU; fp32 on CPU, where half-precision kernels are slow
            # or unsupported for generation.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True,
            device_map="auto",
        )

        logger.info("Creating generation pipeline...")
        # The model is already dispatched by from_pretrained(device_map="auto"),
        # so no device arguments are passed to the pipeline.
        pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=512,
            # temperature/top_p only apply when sampling; without
            # do_sample=True generation is greedy and they are ignored.
            do_sample=True,
            temperature=0.1,
            top_p=0.95,
            repetition_penalty=1.15,
        )
        return HuggingFacePipeline(pipeline=pipe)

    def _build_qa_chain(self, llm) -> None:
        """Create the "stuff" RetrievalQA chain over the vector store."""
        prompt_template = """
Context: {context}
Based on the context above, please provide a clear and concise answer to the following question.
If the information is not in the context, explicitly state so.
Question: {question}
"""
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=["context", "question"],
        )
        logger.info("Setting up QA chain...")
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=self.vector_store.as_retriever(search_kwargs={"k": self._RETRIEVER_K}),
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
        )

    def generate_response(self, question: str) -> Dict:
        """Answer *question* and collect the retrieved source chunks.

        Returns:
            ``{'answer': <chain 'result'>, 'sources': [...]}`` where each
            source is ``{'title', 'content', 'metadata'}`` and ``content`` is
            the chunk text cut to 200 characters (plus ``"..."`` when cut).

        Raises:
            RuntimeError: if ``initialize_system`` has not completed.
            Exception: errors from the chain are logged and re-raised.
        """
        try:
            if self.qa_chain is None:
                raise RuntimeError("RAG system is not initialized; call initialize_system() first.")
            result = self.qa_chain.invoke({"query": question})
            return {
                'answer': result['result'],
                'sources': [self._source_summary(doc) for doc in result['source_documents']],
            }
        except Exception as e:
            logger.exception("Error generating response: %s", e)
            raise

    @classmethod
    def _source_summary(cls, doc) -> Dict:
        """Return the title, preview text and metadata of a retrieved chunk."""
        text = doc.page_content
        if len(text) > cls._PREVIEW_CHARS:
            text = text[:cls._PREVIEW_CHARS] + "..."
        return {
            'title': doc.metadata.get('title', 'Unknown'),
            'content': text,
            'metadata': doc.metadata,
        }
def process_response(user_input: str, chat_history: List) -> List:
    """Answer *user_input* with the global ``rag_system`` and append the turn.

    Args:
        user_input: The user's question.
        chat_history: Chatbot history as a list of ``(user, bot)`` tuples.
            ``None`` (e.g. after the chat was cleared) is treated as empty.

    Returns:
        The updated history list. On failure the appended bot message is an
        error text instead of an answer, so the UI never loses the turn.
    """
    if chat_history is None:
        chat_history = []
    try:
        response = rag_system.generate_response(user_input)

        # Keep only the text after the last "Answer:" marker, if the model echoed one.
        answer = response['answer']
        if "Answer:" in answer:
            answer = answer.split("Answer:")[-1].strip()

        # Unique titles of the top 3 sources, in retrieval order (a set would
        # list them in arbitrary order). Emoji/bullet restored from mis-encoded text.
        titles = list(dict.fromkeys(source['title'] for source in response['sources'][:3]))
        if titles:
            answer += "\n\n📚 Sources consulted:\n" + "\n".join(f"• {title}" for title in titles)

        chat_history.append((user_input, answer))
    except Exception as e:
        logger.exception("Error in process_response: %s", e)
        chat_history.append((user_input, f"Sorry, an error occurred: {str(e)}"))
    return chat_history
# Initialize RAG system.
# Built once at import time into the module-level `rag_system` that
# process_response reads. Failures are logged and re-raised so the app never
# starts with a half-initialized system.
logger.info("Initializing RAG system...")
try:
    rag_system = RAGSystem()
    rag_system.initialize_system()
except Exception as init_error:
    logger.error(f"Failed to initialize RAG system: {str(init_error)}")
    raise
else:
    logger.info("RAG system initialization completed")
# Create Gradio interface
try:
    logger.info("Creating Gradio interface...")
    # NOTE(review): the emoji in the UI strings were mis-encoded in the source
    # (e.g. "β’" for "•"). Unambiguous sequences were restored exactly; the
    # title, placeholder and "About" emoji could not be decoded and are
    # best guesses -- confirm.
    with gr.Blocks(css="div.gradio-container {background-color: #f0f2f6}") as demo:
        # Header
        gr.HTML("""
        <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
            <h1 style="color: #2d333a;">📊 FislacBot</h1>
            <p style="color: #4a5568;">
                AI Assistant specialized in fiscal analysis and FISLAC documentation
            </p>
        </div>
        """)

        chatbot = gr.Chatbot(
            show_label=False,
            container=True,
            height=500,
            bubble_full_width=True,
            show_copy_button=True,
            scale=2,
        )

        with gr.Row():
            message = gr.Textbox(
                placeholder="📝 Type your question here...",
                show_label=False,
                container=False,
                scale=8,
                autofocus=True,
            )
            clear = gr.Button("🗑️ Clear", size="sm", scale=1)

        # Suggested questions
        gr.HTML('<p style="color: #2d333a; font-weight: bold; margin: 20px 0 10px 0;">💡 Suggested questions:</p>')
        with gr.Row():
            suggestion1 = gr.Button("What is FISLAC?", scale=1)
            suggestion2 = gr.Button("What are the main modules of FISLAC?", scale=1)
        with gr.Row():
            suggestion3 = gr.Button("What macroeconomic variables are relevant for advanced economies?", scale=1)
            suggestion4 = gr.Button("How does fiscal risk compare between emerging and advanced countries?", scale=1)

        # Footer
        gr.HTML("""
        <div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px;
                    background-color: #f8f9fa; border-radius: 10px;">
            <div style="margin-bottom: 15px;">
                <h3 style="color: #2d333a;">📖 About this assistant</h3>
                <p style="color: #666; font-size: 14px;">
                    This bot uses RAG (Retrieval Augmented Generation) technology combining:
                </p>
                <ul style="list-style: none; color: #666; font-size: 14px;">
                    <li>🔹 LLM Engine: Llama-2-7b-chat-hf</li>
                    <li>🔹 Embeddings: multilingual-e5-large</li>
                    <li>🔹 Vector Store: FAISS</li>
                </ul>
            </div>
            <div style="border-top: 1px solid #ddd; padding-top: 15px;">
                <p style="color: #666; font-size: 14px;">
                    <strong>Current Knowledge Base:</strong><br>
                    • Valencia et al. (2022) - "Assessing macro-fiscal risk for Latin American and Caribbean countries"<br>
                    • FISLAC Technical Documentation
                </p>
            </div>
            <div style="border-top: 1px solid #ddd; margin-top: 15px; padding-top: 15px;">
                <p style="color: #666; font-size: 14px;">
                    Created by <a href="https://www.linkedin.com/in/camilo-vega-169084b1/"
                    target="_blank" style="color: #2196F3; text-decoration: none;">Camilo Vega</a>,
                    AI Consultant 🤖
                </p>
            </div>
        </div>
        """)

        # Event handlers: process_response appends the new turn and returns
        # the updated history, which becomes the chatbot's new value.
        message.submit(process_response, [message, chatbot], [chatbot])
        # Reset to an empty list (not None) so the next turn has a list to append to.
        clear.click(lambda: [], None, chatbot)
        # Used as an input, a button supplies its label, so each suggestion is
        # submitted as the question text.
        for suggestion in (suggestion1, suggestion2, suggestion3, suggestion4):
            suggestion.click(process_response, [suggestion, chatbot], [chatbot])

    logger.info("Gradio interface created successfully")
    demo.launch()
except Exception as e:
    logger.error(f"Error in Gradio interface creation: {str(e)}")
    raise