"""Query routing and retrieval-augmented answering backed by ChromaDB.

Importing this module loads environment variables and immediately tries to
build the LLM client, the embedding model, the Chroma vector store and three
prompt chains (query router, company QA, cybersecurity). If that start-up
attempt fails the error is printed and the public entry points retry the
initialization lazily on first use.
"""

import os

import chromadb
from dotenv import load_dotenv
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import OpenAI
from langchain.chains.question_answering import load_qa_chain
from langchain_community.vectorstores import Chroma
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI

from config import Config

load_dotenv()

# ChromaDB configuration
CHROMA_HOST = Config.RAG_CHROMA_HOST
CHROMA_PORT = Config.RAG_CHROMA_PORT
COLLECTION_NAME = Config.RAG_COLLECTION_NAME

# LLM Provider Configuration
LLM_PROVIDER = Config.RAG_LLM_PROVIDER
LLM_API_KEY = Config.RAG_LLM_API_KEY
LLM_MODEL = Config.RAG_LLM_MODEL
LLM_TEMPERATURE = Config.RAG_LLM_TEMPERATURE
LLM_MAX_TOKENS = Config.RAG_LLM_MAX_TOKENS

# Provider-specific configurations (all are reached through ChatOpenAI with a
# different base URL).
PROVIDER_CONFIGS = {
    "openai": {
        "api_base": "https://api.openai.com/v1",
        "default_model": "gpt-3.5-turbo",
    },
    "groq": {
        "api_base": "https://api.groq.com/openai/v1",
        "default_model": "llama-3.3-70b-versatile",
    },
    "openrouter": {
        "api_base": "https://openrouter.ai/api/v1",
        "default_model": "mistralai/mistral-small-3.2-24b-instruct:free",
    },
}

# Text-splitting and retrieval parameters.
_CHUNK_SIZE = 1000
_CHUNK_OVERLAP = 200
_RETRIEVAL_TOP_K = 3

# Lazily populated singletons; filled in by initialize_pipelines().
vector_store = None
company_qa_chain = None
query_router_chain = None
cybersecurity_chain = None
llm = None


def _provider_label() -> str:
    """Return the configured provider name upper-cased for display.

    ``str()`` keeps this safe inside error handlers even when the provider
    setting is missing (``None``).
    """
    return str(LLM_PROVIDER).upper()


def _llm_identity() -> dict:
    """Return the ``provider``/``model`` pair attached to successful replies."""
    return {
        "provider": _provider_label(),
        "model": get_llm_config()["model"],
    }


def get_llm_config():
    """Get the appropriate LLM configuration based on the provider.

    Returns:
        Keyword arguments for ``ChatOpenAI``: model, API key, API base URL,
        temperature and max tokens.

    Raises:
        ValueError: If ``LLM_PROVIDER`` is not a key of ``PROVIDER_CONFIGS``.
    """
    if LLM_PROVIDER not in PROVIDER_CONFIGS:
        raise ValueError(
            f"Unsupported LLM provider: {LLM_PROVIDER}. "
            f"Supported: {list(PROVIDER_CONFIGS.keys())}"
        )

    provider_settings = PROVIDER_CONFIGS[LLM_PROVIDER]

    # "gpt-3.5-turbo" is treated as "no model chosen" (presumably the Config
    # default - confirm), and an empty/None setting is handled the same way so
    # a missing model name never reaches the API.
    model_unset = not LLM_MODEL or LLM_MODEL == "gpt-3.5-turbo"
    model = provider_settings["default_model"] if model_unset else LLM_MODEL

    return {
        "model": model,
        "openai_api_key": LLM_API_KEY,
        "openai_api_base": provider_settings["api_base"],
        "temperature": LLM_TEMPERATURE,
        "max_tokens": LLM_MAX_TOKENS,
    }


def initialize_llm():
    """Initialize the LLM based on the configured provider.

    Returns:
        A ``ChatOpenAI`` instance built from ``get_llm_config()``.

    Raises:
        ValueError: If no API key is configured or the provider is unsupported.
    """
    if not LLM_API_KEY:
        raise ValueError(
            f"LLM_API_KEY environment variable is required for {LLM_PROVIDER}"
        )

    config = get_llm_config()
    print(f"Initializing {_provider_label()} with model: {config['model']}")
    return ChatOpenAI(**config)


def initialize_pipelines():
    """Initializes all required models, chains, and the vector store.

    Populates the module-level ``llm``, ``vector_store``,
    ``query_router_chain``, ``company_qa_chain`` and ``cybersecurity_chain``.
    Globals are assigned as each component is built, so a failure part-way
    leaves the earlier ones set.

    Raises:
        ConnectionError: If the ChromaDB heartbeat fails.
        Exception: Any other initialization error, re-raised after printing.
    """
    global vector_store, company_qa_chain, query_router_chain
    global cybersecurity_chain, llm

    try:
        # Initialize LLM
        llm = initialize_llm()

        # Initialize embeddings
        embeddings = HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True},
        )

        # Initialize ChromaDB client; the heartbeat verifies the server is
        # reachable before anything is built on top of it.
        try:
            chroma_client = chromadb.HttpClient(host=CHROMA_HOST, port=CHROMA_PORT)
            chroma_client.heartbeat()
        except Exception as e:
            raise ConnectionError("Failed to connect to ChromaDB.") from e

        # Initialize vector store
        vector_store = Chroma(
            client=chroma_client,
            collection_name=COLLECTION_NAME,
            embedding_function=embeddings,
        )

        # NOTE(review): the prompt texts below are kept on one logical line,
        # exactly as found; confirm whether line breaks were intended.

        # Query Router Chain
        router_template = (
            "You are a query classifier. "
            "Classify the following query into one of these categories: "
            "- COMPANY: Questions about our company, its products, services, "
            "or general information "
            "- CYBERSECURITY: Questions about cybersecurity, security threats, "
            "best practices, or vulnerabilities "
            "- OFF_TOPIC: Questions that don't fit the above categories "
            "Query: {query} "
            "Respond with only the category name "
            "(COMPANY, CYBERSECURITY, or OFF_TOPIC):"
        )
        router_prompt = PromptTemplate(
            input_variables=["query"],
            template=router_template,
        )
        query_router_chain = LLMChain(llm=llm, prompt=router_prompt)

        # Custom Company QA Chain
        company_qa_template = (
            "You are a helpful assistant for CyberAlertNepal. "
            "Answer the following question about our company using the "
            "information provided and links if only available. "
            "Give a natural, direct and polite response. "
            "Question: {question} "
            "Information: {context} "
            "Answer:"
        )
        company_qa_prompt = PromptTemplate(
            input_variables=["question", "context"],
            template=company_qa_template,
        )
        company_qa_chain = LLMChain(llm=llm, prompt=company_qa_prompt)

        # Cybersecurity Chain
        cybersecurity_template = (
            "You are a cybersecurity professional. "
            "Answer the following question truthfully and concisely. "
            "If you are not 100% sure about the answer, simply respond with: "
            '"I am not sure about the answer." '
            "Do not add extra explanations or assumptions. "
            "Do not provide false or speculative information. "
            "Question: {question} "
            "Provide a comprehensive and accurate answer about cybersecurity:"
        )
        cybersecurity_prompt = PromptTemplate(
            input_variables=["question"],
            template=cybersecurity_template,
        )
        cybersecurity_chain = LLMChain(llm=llm, prompt=cybersecurity_prompt)

        print(f"Successfully initialized pipelines with {_provider_label()}")

    except Exception as e:
        print(f"Error initializing pipelines: {e}")
        raise


def add_document_to_rag(text: str, metadata: dict) -> bool:
    """Splits a document and adds it to the ChromaDB index.

    Args:
        text: Raw document text to split into chunks.
        metadata: Metadata attached to the document's chunks.

    Returns:
        True when chunks were added; False when the text produced no chunks
        or adding them failed (the error is printed).

    Raises:
        Exception: If lazy pipeline initialization fails.
    """
    if vector_store is None:
        initialize_pipelines()

    try:
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=_CHUNK_SIZE,
            chunk_overlap=_CHUNK_OVERLAP,
        )
        docs = text_splitter.create_documents([text], metadatas=[metadata])

        if not docs:
            print("Document was empty after splitting, not adding to ChromaDB.")
            return False

        vector_store.add_documents(docs)
        print("Successfully added documents.")
        return True

    except Exception as e:
        print(f"Error adding document to RAG: {e}")
        return False


def route_and_process_query(query: str) -> dict:
    """Routes the query and processes it using the appropriate pipeline.

    The router chain labels the query; CYBERSECURITY queries go straight to
    the cybersecurity chain, COMPANY queries are answered from the top
    matching documents in the vector store, and anything else gets a fixed
    off-topic reply.

    Args:
        query: The user's question.

    Returns:
        A dict with at least ``answer``, ``source``, ``route`` and
        ``provider``. Successful replies also carry ``model``; COMPANY replies
        with matches carry ``documents``; failures carry ``error``.

    Raises:
        Exception: If lazy pipeline initialization fails.
    """
    required = (query_router_chain, vector_store, company_qa_chain, cybersecurity_chain)
    if any(component is None for component in required):
        initialize_pipelines()

    try:
        # 1. Classify the query
        route_result = query_router_chain.run(query)
        route = route_result.strip().upper()

        # 2. Route to appropriate logic (substring match tolerates extra
        # words around the label in the model's reply).
        if "CYBERSECURITY" in route:
            answer = cybersecurity_chain.run(question=query)
            return {
                "answer": answer,
                "source": "Cybersecurity Knowledge Base",
                "route": "CYBERSECURITY",
                **_llm_identity(),
            }

        if "COMPANY" in route:
            # Perform similarity search on ChromaDB
            docs = vector_store.similarity_search(query, k=_RETRIEVAL_TOP_K)

            if not docs:
                return {
                    "answer": "I could not find any relevant information to answer your question.",
                    "source": "Company Documents",
                    "route": "COMPANY",
                    **_llm_identity(),
                }

            # Combine document content for context
            context = "\n\n".join(doc.page_content for doc in docs)

            # Run the custom QA chain
            answer = company_qa_chain.run(question=query, context=context)

            # De-duplicate while keeping retrieval order so the result is
            # deterministic.
            sources = list(
                dict.fromkeys(doc.metadata.get("source", "Unknown") for doc in docs)
            )

            return {
                "answer": answer,
                "source": "Company Documents",
                "documents": sources,
                "route": "COMPANY",
                **_llm_identity(),
            }

        # OFF_TOPIC
        return {
            "answer": "I am a specialized assistant of CyberAlertNepal. I cannot answer questions outside of cybersecurity topics.",
            "source": "N/A",
            "route": "OFF_TOPIC",
            **_llm_identity(),
        }

    except Exception as e:
        print(f"Error processing query: {e}")
        return {
            "answer": "I encountered an error while processing your query. Please try again.",
            "source": "Error",
            "route": None,
            "documents": None,
            "provider": _provider_label(),
            "error": str(e),
        }


def check_system_health() -> dict:
    """Check if all components are properly initialized.

    Returns:
        A dict with ``status`` ("healthy"/"unhealthy") and ``provider``, plus
        ``components`` and ``model`` on the normal path or ``error`` when the
        check itself raised.
    """
    try:
        # Test ChromaDB connection
        if vector_store is not None:
            vector_store._client.heartbeat()

        # Test if all chains are initialized
        components = {
            "vector_store": vector_store is not None,
            "company_qa_chain": company_qa_chain is not None,
            "query_router_chain": query_router_chain is not None,
            "cybersecurity_chain": cybersecurity_chain is not None,
            "llm": llm is not None,
        }

        return {
            "status": "healthy" if all(components.values()) else "unhealthy",
            "components": components,
            "provider": _provider_label(),
            "model": get_llm_config()["model"] if llm is not None else "Not initialized",
        }
    except Exception as e:
        return {
            "status": "unhealthy",
            "error": str(e),
            "provider": _provider_label(),
        }


def test_llm_connection() -> dict:
    """Test the LLM API connection.

    Returns:
        ``{"success": True, "provider", "model", "response"}`` when the model
        answered, otherwise ``{"success": False, "provider", "error"}``.
    """
    try:
        if llm is None:
            initialize_pipelines()

        # Simple test query. Chat models take a plain string through
        # invoke(); calling the model object directly expects a message list.
        reply = llm.invoke("Say 'Hello, LLM is working!'")

        return {
            "success": True,
            "provider": _provider_label(),
            "model": get_llm_config()["model"],
            # Chat models return a message object; report its text.
            "response": str(getattr(reply, "content", reply)),
        }
    except Exception as e:
        return {
            "success": False,
            "provider": _provider_label(),
            "error": str(e),
        }


# Initialize pipelines on module import
try:
    initialize_pipelines()
except Exception as e:
    print(f"Failed to initialize pipelines on startup: {e}")