Spaces:
Sleeping
Sleeping
import ast
import asyncio
import logging
import operator
import os
from typing import Any, Dict, List, Optional
from urllib.parse import quote

import requests
import streamlit as st
import torch
from langchain.agents import AgentType, Tool, initialize_agent
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Set environment variables for Streamlit and cache.
# Everything is redirected to /tmp so the app only needs a writable temp dir.
# NOTE(review): these assignments run after the third-party imports at the top
# of the file; a library that reads its variables at import time may not see
# them -- confirm, or move this block above those imports.
os.environ.update(
    {
        "STREAMLIT_CONFIG_DIR": "/tmp/.streamlit",
        "STREAMLIT_DATA_DIR": "/tmp/.streamlit",
        "STREAMLIT_SERVER_FILE_WATCHER_TYPE": "none",
        "HF_HOME": "/tmp/hf_cache",
        "TRANSFORMERS_CACHE": "/tmp/hf_cache",
        "SENTENCE_TRANSFORMERS_HOME": "/tmp/st_cache",
    }
)
# Configure logging: timestamped INFO-level records, plus a module-level logger
# shared by every function in this file.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Streamlit bootstrap to handle event loop
try:
    import streamlit.web.bootstrap as bootstrap

    # Capture the real entry point BEFORE rebinding bootstrap.run. Looking it
    # up lazily inside the wrapper would resolve to the wrapper itself and
    # recurse without end.
    _original_bootstrap_run = bootstrap.run

    def patched_run(*args, **kwargs):
        """Ensure the current thread has an event loop, then run Streamlit.

        All positional and keyword arguments are forwarded unchanged to the
        original ``bootstrap.run``; its return value is passed back.
        """
        try:
            asyncio.get_running_loop()
            created_loop = None  # a loop is already running; leave it alone
        except RuntimeError:
            created_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(created_loop)

        result = _original_bootstrap_run(*args, **kwargs)

        # Only drive the result ourselves when it actually is a coroutine and
        # we own an idle loop; run_until_complete() on a running loop raises.
        if created_loop is not None and asyncio.iscoroutine(result):
            return created_loop.run_until_complete(result)
        return result

    bootstrap.run = patched_run
except ImportError:
    logger.warning("Could not patch Streamlit bootstrap; event loop errors may persist.")
# Application configuration: every tunable used by the loaders, the retriever
# and the text-generation pipeline lives in this single mapping.
CONFIG: Dict[str, Any] = {
    # Source text files; paths that do not exist are skipped at load time.
    "document_paths": [
        "./src/nvidia_overview.txt",
        "./src/geforce_now_faq.txt",
        "./src/rtx_50_series_specs.txt",
        "./src/rtx_pro_6000.txt",
        "./src/investor_faq.txt",
    ],
    # Text splitting
    "chunk_size": 500,
    "chunk_overlap": 50,
    # Models
    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
    "llm_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    # Generation settings
    "max_new_tokens": 100,  # Reduced for CPU performance
    "temperature": 0.7,
    "top_p": 0.9,
    # Retrieval
    "retriever_k": 1,  # Reduced for memory efficiency
}
# ---- Document Loading ---- #
def load_documents() -> List[Any]:
    """
    Read every configured text file and split it into chunks.

    Paths listed in CONFIG["document_paths"] that do not exist are skipped
    with a warning. Each chunk gets a "title" metadata entry holding the
    source file's base name.

    Returns:
        The combined list of chunks, or an empty list when nothing could be
        loaded or any error occurred (the error is logged and shown in the UI).
    """
    logger.info("Loading documents...")
    try:
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=CONFIG["chunk_size"],
            chunk_overlap=CONFIG["chunk_overlap"],
        )
        collected: List[Any] = []
        for doc_path in CONFIG["document_paths"]:
            if os.path.exists(doc_path):
                title = os.path.basename(doc_path)
                pieces = TextLoader(doc_path).load_and_split(splitter)
                for piece in pieces:
                    piece.metadata["title"] = title
                collected.extend(pieces)
            else:
                logger.warning(f"Document not found: {doc_path}, skipping...")

        if not collected:
            # Raised inside the try so it is reported like any other failure.
            raise ValueError("No documents loaded.")

        logger.info(f"Loaded {len(collected)} document chunks.")
        return collected
    except Exception as err:
        reason = str(err)
        logger.error(f"Error loading documents: {reason}")
        st.error(f"Failed to load documents: {reason}")
        return []
# ---- Vector Store Creation ---- #
def create_vector_store(_documents: List[Any]) -> Optional[FAISS]:
    """
    Embed the document chunks and index them in a FAISS vector store.

    Args:
        _documents: Document chunks to index.

    Returns:
        The populated FAISS store, or None when the input is empty or the
        build fails (the error is logged and shown in the UI).
    """
    logger.info("Creating vector store...")
    try:
        if not _documents:
            # Raised inside the try so it is reported like any other failure.
            raise ValueError("No documents provided for vector store.")

        embedder = HuggingFaceEmbeddings(
            cache_folder="/tmp/st_cache",
            model_name=CONFIG["embedding_model"],
        )
        store = FAISS.from_documents(_documents, embedder)
    except Exception as err:
        reason = str(err)
        logger.error(f"Error creating vector store: {reason}")
        st.error(f"Failed to create vector store: {reason}")
        return None

    logger.info("Vector store created successfully.")
    return store
# ---- LLM Initialization ---- #
def initialize_llm() -> Optional[HuggingFacePipeline]:
    """
    Load the configured causal LM on CPU in float32 and wrap it for LangChain.

    Returns:
        A HuggingFacePipeline around a sampling text-generation pipeline, or
        None when loading fails (the error is logged and shown in the UI).
    """
    logger.info("Initializing LLM...")
    try:
        model_name = CONFIG["llm_model"]
        # NOTE(review): trust_remote_code=True lets the model repository run
        # its own Python code at load time -- confirm it is really required.
        hub_options = {"trust_remote_code": True, "cache_dir": "/tmp/hf_cache"}

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="cpu",
            torch_dtype=torch.float32,
            **hub_options,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name, **hub_options)

        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=CONFIG["max_new_tokens"],
            do_sample=True,
            temperature=CONFIG["temperature"],
            top_p=CONFIG["top_p"],
            truncation=True,
        )
        wrapped = HuggingFacePipeline(pipeline=generator)
    except Exception as err:
        reason = str(err)
        logger.error(f"Error initializing LLM: {reason}")
        st.error(f"Failed to initialize LLM: {reason}")
        return None

    logger.info("LLM initialized successfully.")
    return wrapped
# ---- Context Retrieval ---- #
def retrieve_context(query: str, vector_store: FAISS, k: int = CONFIG["retriever_k"]) -> List[str]:
    """
    Look up the chunks most similar to the query and return their text.

    Args:
        query: User query string.
        vector_store: FAISS vector store to search.
        k: Number of chunks to fetch (default taken from CONFIG at import time).

    Returns:
        The page content of each match, or an empty list when the store is
        missing or the search fails (the error is logged and shown in the UI).
    """
    logger.info(f"Retrieving context for query: {query}")
    try:
        if not vector_store:
            # Raised inside the try so it is reported like any other failure.
            raise ValueError("Vector store is not initialized.")

        matches = vector_store.similarity_search(query, k=k)
        logger.info(f"Retrieved {len(matches)} context chunks.")
        return [match.page_content for match in matches]
    except Exception as err:
        reason = str(err)
        logger.error(f"Error retrieving context: {reason}")
        st.error(f"Failed to retrieve context: {reason}")
        return []
# ---- Answer Generation ---- #
def generate_answer(query: str, context_chunks: List[str], llm: HuggingFacePipeline) -> str:
    """
    Generate an answer for the query using the LLM and context chunks.

    The context chunks are joined with newlines and placed, together with the
    query, into a "Context / Question / Answer:" prompt.

    Args:
        query: User query string.
        context_chunks: Relevant document content; may be empty.
        llm: Initialized LLM pipeline.

    Returns:
        The generated answer text, or "Unable to generate answer." when the
        LLM is missing or generation fails (the error is logged and shown in
        the UI).
    """
    logger.info(f"Generating answer for query: {query}")
    try:
        if not llm:
            raise ValueError("LLM is not initialized.")

        context_str = "\n".join(context_chunks) if context_chunks else "No context available."
        prompt_template = PromptTemplate(
            input_variables=["context", "query"],
            template="Context:\n{context}\n\nQuestion: {query}\nAnswer:",
        )
        prompt = prompt_template.format(context=context_str, query=query)

        # invoke() is the supported entry point; calling the LLM object
        # directly is deprecated in current LangChain releases.
        response = llm.invoke(prompt)

        if response.startswith(prompt):
            # The pipeline echoed the prompt: drop exactly that prefix so an
            # "Answer:" produced by the model itself cannot truncate the reply.
            answer = response[len(prompt):].strip()
        elif "Answer:" in response:
            # Prompt was altered or partly echoed: fall back to the last marker.
            answer = response.split("Answer:")[-1].strip()
        else:
            answer = response.strip()

        logger.info("Answer generated successfully.")
        return answer
    except Exception as e:
        logger.error(f"Error generating answer: {str(e)}")
        st.error(f"Failed to generate answer: {str(e)}")
        return "Unable to generate answer."
# ---- Tool Definitions ---- #
# Limits for exponentiation so a short expression such as "9 ** 9 ** 9"
# cannot stall the app or exhaust memory.
_CALC_MAX_EXPONENT = 1000
_CALC_MAX_POWER_BASE = 10 ** 6


def _bounded_pow(base, exponent):
    """Return base ** exponent, rejecting operands beyond the calculator limits."""
    if abs(exponent) > _CALC_MAX_EXPONENT or abs(base) > _CALC_MAX_POWER_BASE:
        raise ValueError("Exponentiation operands are too large.")
    return operator.pow(base, exponent)


# Whitelist of operators the calculator may apply, keyed by AST node type.
_CALC_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: _bounded_pow,
}
_CALC_UNARY_OPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}


def _evaluate_arithmetic(node):
    """Evaluate an AST node built only from numeric literals and whitelisted operators."""
    if isinstance(node, ast.Constant):
        value = node.value
        # bool is a subclass of int; exclude it so "True + 1" is not accepted.
        if isinstance(value, bool) or not isinstance(value, (int, float, complex)):
            raise ValueError("Only numeric literals are allowed.")
        return value
    if isinstance(node, ast.BinOp) and type(node.op) in _CALC_BINARY_OPS:
        left = _evaluate_arithmetic(node.left)
        right = _evaluate_arithmetic(node.right)
        return _CALC_BINARY_OPS[type(node.op)](left, right)
    if isinstance(node, ast.UnaryOp) and type(node.op) in _CALC_UNARY_OPS:
        return _CALC_UNARY_OPS[type(node.op)](_evaluate_arithmetic(node.operand))
    # Names, calls, attribute access, containers, strings, ... are all refused.
    raise ValueError("Unsupported expression.")


def calculator_tool(expr: str) -> str:
    """
    Evaluate an arithmetic expression safely.

    Supports numeric literals, parentheses, unary +/- and the binary operators
    + - * / // % **. The expression is parsed to an AST and only whitelisted
    nodes are evaluated, so no arbitrary code can run. A leading "calculate"
    keyword (as in "calculate 5 + 3") is ignored.

    Args:
        expr: String containing the mathematical expression.

    Returns:
        The result as a string, or "Error in calculation: Invalid expression."
        for anything that cannot be evaluated (syntax errors, unsupported
        constructs, division by zero, oversized powers).
    """
    logger.info(f"Evaluating expression: {expr}")
    try:
        cleaned = expr.strip()
        keyword, _, remainder = cleaned.partition(" ")
        if keyword.lower() == "calculate" and remainder.strip():
            cleaned = remainder.strip()

        tree = ast.parse(cleaned, mode="eval")
        result = str(_evaluate_arithmetic(tree.body))
        logger.info(f"Calculation result: {result}")
        return result
    except Exception as e:
        # Tool boundary: the agent needs a string back, never an exception.
        logger.error(f"Calculation error: {str(e)}")
        return "Error in calculation: Invalid expression."
def dictionary_tool(word: str) -> str:
    """
    Fetch the definition of a word from an online dictionary API.

    The word is stripped of surrounding whitespace and percent-encoded before
    it is placed in the request path, so spaces, slashes or "?" in the input
    cannot change which URL is requested. A leading "define" keyword (as in
    "define GPU") is ignored. Blank input returns immediately without a
    network call.

    Args:
        word: Word to define.

    Returns:
        The first definition of the first meaning returned by the API, or
        "Definition not found." on blank input or any failure.
    """
    logger.info(f"Fetching definition for word: {word}")
    try:
        term = word.strip()
        keyword, _, remainder = term.partition(" ")
        if keyword.lower() == "define" and remainder.strip():
            term = remainder.strip()
        if not term:
            logger.error("Dictionary error: empty word.")
            return "Definition not found."

        # safe="" also encodes "/" so the term stays a single path segment.
        url = f"https://api.dictionaryapi.dev/api/v2/entries/en/{quote(term, safe='')}"
        response = requests.get(url, timeout=5)
        response.raise_for_status()
        definition = response.json()[0]["meanings"][0]["definitions"][0]["definition"]
        logger.info(f"Definition found: {definition}")
        return definition
    except Exception as e:
        # Tool boundary: the agent needs a string back, never an exception.
        logger.error(f"Dictionary error: {str(e)}")
        return "Definition not found."
# ---- Agent Initialization ---- #
def initialize_agent_with_tools(llm: HuggingFacePipeline, vector_store: FAISS) -> Optional[Any]:
    """
    Initialize a LangChain agent with RAG, calculator, and dictionary tools.

    The RAG tool retrieves context from the vector store for its input and
    asks the LLM to answer from that context.

    Args:
        llm: Initialized LLM pipeline.
        vector_store: FAISS vector store.

    Returns:
        Initialized agent, or None if either dependency is missing or
        initialization fails (the error is logged and shown in the UI).
    """
    logger.info("Initializing agent...")
    try:
        if not llm or not vector_store:
            raise ValueError("LLM or vector store not initialized.")

        def rag_tool(query: str) -> str:
            """Answer a query from retrieved document context."""
            context = retrieve_context(query, vector_store)
            return generate_answer(query, context, llm)

        tools = [
            Tool(
                name="Calculator",
                func=calculator_tool,
                description="Use for mathematical calculations (e.g., 'calculate 5 + 3').",
            ),
            Tool(
                name="Dictionary",
                func=dictionary_tool,
                description="Use to find definitions of words (e.g., 'define GPU').",
            ),
            Tool(
                name="RAG",
                func=rag_tool,
                description="Use for general knowledge and contextual answers about NVIDIA.",
            ),
        ]
        agent = initialize_agent(
            tools=tools,
            llm=llm,
            # The keyword is `agent`; `agent_type` is not a parameter of
            # initialize_agent and would leak into the executor's kwargs.
            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            verbose=True,
            handle_parsing_errors=True,
            max_iterations=15,
            max_execution_time=30,
        )
        logger.info("Agent initialized successfully.")
        return agent
    except Exception as e:
        logger.error(f"Error initializing agent: {str(e)}")
        st.error(f"Failed to initialize agent: {str(e)}")
        return None
# ---- Streamlit Application ---- #
def main() -> None:
    """
    Run the Streamlit application.

    On the first run of a session the documents, vector store, LLM and agent
    are built and cached in ``st.session_state``; if any step fails the script
    run is stopped. Each submitted query is then answered by the agent, and
    the answer, the tool(s) the agent actually called, and (for RAG) the
    retrieved snippets are displayed.
    """
    st.set_page_config(page_title="NVIDIA Q&A Assistant", layout="centered")
    st.title("🤖 RAG-Powered NVIDIA Q&A Assistant")
    st.markdown(
        "Ask about NVIDIA products, company, or services, or use 'calculate' or 'define' for specific tasks."
    )

    # Initialize session state
    if "initialized" not in st.session_state:
        with st.spinner("Initializing assistant..."):
            documents = load_documents()
            if not documents:
                st.stop()
            vector_store = create_vector_store(documents)
            if not vector_store:
                st.stop()
            llm = initialize_llm()
            if not llm:
                st.stop()
            agent = initialize_agent_with_tools(llm, vector_store)
            if not agent:
                st.stop()
            # Have the executor report its (action, observation) steps so the
            # tool that was really used can be shown, instead of guessing it
            # from the wording of the final answer.
            agent.return_intermediate_steps = True
            st.session_state.update({
                "store": vector_store,
                "llm": llm,
                "agent": agent,
                "documents": documents,
                "initialized": True,
            })
            logger.info("Application initialized successfully.")

    # User input
    query = st.text_input(
        "Enter your question (e.g., 'What is GeForce NOW?', 'calculate 5 * 3', 'define GPU'):"
    )
    if not query:
        return
    if len(query.strip()) < 3:
        st.error("Query is too short. Please provide a more detailed question.")
        return

    known_tools = ("RAG", "Calculator", "Dictionary")
    with st.spinner("Processing query..."):
        try:
            agent = st.session_state.agent
            vector_store = st.session_state.store

            # invoke() returns a dict; with intermediate steps enabled the
            # executor has more than one output key, which run() cannot handle.
            response = agent.invoke({"input": query})
            result = response.get("output", "")
            steps = response.get("intermediate_steps") or []

            # Collect the distinct known tools in call order; entries for
            # parsing-error recovery or unknown tools are ignored.
            tools_called = []
            for action, _observation in steps:
                tool_name = getattr(action, "tool", None)
                if tool_name in known_tools and tool_name not in tools_called:
                    tools_called.append(tool_name)
            tool_used = ", ".join(tools_called) if tools_called else "Agent (unknown tool)"

            # Snippets are re-retrieved for display only when RAG was used.
            context = retrieve_context(query, vector_store) if "RAG" in tools_called else []

            # Display results
            if context:
                with st.expander("📚 Retrieved Context"):
                    for i, chunk in enumerate(context, 1):
                        st.markdown(f"**Snippet {i}:** {chunk}")
            else:
                st.info("No context snippets available for this query (e.g., calculator or dictionary task).")
            st.success(f"💡 Answer: {result}")
            st.markdown(f"**Tool Used:** {tool_used}")
            logger.info(f"Query processed successfully: {query}")
        except Exception as e:
            logger.error(f"Error processing query: {str(e)}")
            st.error(f"Failed to process query: {str(e)}")


if __name__ == "__main__":
    main()