| import os |
| import chromadb |
| from dotenv import load_dotenv |
| from langchain_core.documents import Document |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain_community.embeddings import HuggingFaceEmbeddings |
| from langchain_community.llms import OpenAI |
| from langchain.chains.question_answering import load_qa_chain |
| from langchain_community.vectorstores import Chroma |
| from langchain.chains import LLMChain |
| from langchain.prompts import PromptTemplate |
| from langchain.chat_models import ChatOpenAI |
| from config import Config |
|
|
|
|
# Populate os.environ from a .env file, if one is present.
# NOTE(review): `from config import Config` runs before this call, so if
# Config reads environment variables at import time it will not see values
# that live only in .env -- confirm config.py loads .env itself.
load_dotenv()

# ChromaDB server location and the collection that stores document chunks.
CHROMA_HOST, CHROMA_PORT, COLLECTION_NAME = (
    Config.RAG_CHROMA_HOST,
    Config.RAG_CHROMA_PORT,
    Config.RAG_COLLECTION_NAME,
)

# Chat-model settings; LLM_PROVIDER selects an entry of PROVIDER_CONFIGS.
LLM_PROVIDER, LLM_API_KEY, LLM_MODEL, LLM_TEMPERATURE, LLM_MAX_TOKENS = (
    Config.RAG_LLM_PROVIDER,
    Config.RAG_LLM_API_KEY,
    Config.RAG_LLM_MODEL,
    Config.RAG_LLM_TEMPERATURE,
    Config.RAG_LLM_MAX_TOKENS,
)
|
|
| |
# OpenAI-compatible endpoints, keyed by provider name. Each entry gives the
# base URL handed to ChatOpenAI and the model used when none is customised.
PROVIDER_CONFIGS = {
    provider_name: {"api_base": base_url, "default_model": fallback_model}
    for provider_name, base_url, fallback_model in (
        ("openai", "https://api.openai.com/v1", "gpt-3.5-turbo"),
        ("groq", "https://api.groq.com/openai/v1", "llama-3.3-70b-versatile"),
        ("openrouter", "https://openrouter.ai/api/v1", "mistralai/mistral-small-3.2-24b-instruct:free"),
    )
}
|
|
# Lazily-populated pipeline state; every name starts as None and is filled
# in by initialize_pipelines().
vector_store = company_qa_chain = query_router_chain = cybersecurity_chain = llm = None
|
|
def get_llm_config():
    """Build the keyword arguments used to construct the chat model.

    The provider named by ``LLM_PROVIDER`` is looked up in
    ``PROVIDER_CONFIGS``, ignoring case and surrounding whitespace. When
    ``LLM_MODEL`` is unset/empty or equals ``"gpt-3.5-turbo"``, the
    provider's own ``default_model`` is used instead.

    Returns:
        dict: ``model``, ``openai_api_key``, ``openai_api_base``,
        ``temperature`` and ``max_tokens`` -- keyword arguments for
        ChatOpenAI.

    Raises:
        ValueError: If the provider is not a key of ``PROVIDER_CONFIGS``.
    """
    provider = str(LLM_PROVIDER or "").strip().lower()
    if provider not in PROVIDER_CONFIGS:
        raise ValueError(f"Unsupported LLM provider: {LLM_PROVIDER}. Supported: {list(PROVIDER_CONFIGS.keys())}")

    provider_config = PROVIDER_CONFIGS[provider]

    # "gpt-3.5-turbo" is treated as "not customised" (presumably Config's
    # default -- confirm), so non-OpenAI providers get a model they serve.
    # An unset model falls back the same way instead of sending model=None.
    if not LLM_MODEL or LLM_MODEL == "gpt-3.5-turbo":
        model = provider_config["default_model"]
    else:
        model = LLM_MODEL

    return {
        "model": model,
        "openai_api_key": LLM_API_KEY,
        # Every provider is reached through the OpenAI-compatible client.
        "openai_api_base": provider_config["api_base"],
        "temperature": LLM_TEMPERATURE,
        "max_tokens": LLM_MAX_TOKENS,
    }
|
|
def initialize_llm():
    """Create the ChatOpenAI client for the configured provider.

    The client is always ChatOpenAI; the provider is selected by the
    ``openai_api_base`` that get_llm_config() supplies.

    Returns:
        ChatOpenAI: The configured chat model.

    Raises:
        ValueError: If ``LLM_API_KEY`` is empty, or (via get_llm_config)
            the provider is unsupported.
    """
    if LLM_API_KEY:
        llm_kwargs = get_llm_config()
        print(f"Initializing {LLM_PROVIDER.upper()} with model: {llm_kwargs['model']}")
        return ChatOpenAI(**llm_kwargs)

    raise ValueError(f"LLM_API_KEY environment variable is required for {LLM_PROVIDER}")
|
|
def _build_embeddings():
    """Return the HuggingFace embeddings (all-MiniLM-L6-v2 on CPU, normalized vectors)."""
    return HuggingFaceEmbeddings(
        model_name="all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True}
    )


def _connect_vector_store(embeddings):
    """Connect to the ChromaDB server and wrap the configured collection.

    Raises:
        ConnectionError: If the client cannot be created or the heartbeat fails.
    """
    try:
        chroma_client = chromadb.HttpClient(host=CHROMA_HOST, port=CHROMA_PORT)
        chroma_client.heartbeat()
    except Exception as e:
        raise ConnectionError("Failed to connect to ChromaDB.") from e

    return Chroma(
        client=chroma_client,
        collection_name=COLLECTION_NAME,
        embedding_function=embeddings,
    )


def _make_chain(chat_llm, template, input_variables):
    """Wrap ``template`` in a PromptTemplate and bind it to ``chat_llm`` as an LLMChain."""
    prompt = PromptTemplate(input_variables=input_variables, template=template)
    return LLMChain(llm=chat_llm, prompt=prompt)


def _build_chains(chat_llm):
    """Return the (router, company Q&A, cybersecurity) chains bound to ``chat_llm``."""
    router_template = """You are a query classifier. Classify the following query into one of these categories:
- COMPANY: Questions about our company, its products, services, or general information
- CYBERSECURITY: Questions about cybersecurity, security threats, best practices, or vulnerabilities
- OFF_TOPIC: Questions that don't fit the above categories

Query: {query}

Respond with only the category name (COMPANY, CYBERSECURITY, or OFF_TOPIC):"""

    company_qa_template = """You are a helpful assistant for CyberAlertNepal. Answer the following question about our company using the information provided and links if only available. Give a natural, direct and polite response.

Question: {question}

Information:
{context}

Answer:"""

    # NOTE(review): this prompt asks for a "concise" answer with no extra
    # explanation but ends by asking for a "comprehensive" one -- consider
    # reconciling the two instructions.
    cybersecurity_template = """You are a cybersecurity professional. Answer the following question truthfully and concisely.
If you are not 100% sure about the answer, simply respond with: "I am not sure about the answer."
Do not add extra explanations or assumptions. Do not provide false or speculative information.

Question: {question}

Provide a comprehensive and accurate answer about cybersecurity:"""

    router_chain = _make_chain(chat_llm, router_template, ["query"])
    company_chain = _make_chain(chat_llm, company_qa_template, ["question", "context"])
    cyber_chain = _make_chain(chat_llm, cybersecurity_template, ["question"])
    return router_chain, company_chain, cyber_chain


def initialize_pipelines():
    """Initialize the LLM, the Chroma vector store and all LLM chains.

    Components are built in order: LLM, embeddings, vector store, chains.
    The module globals are assigned only after every component has been
    built successfully. A failure therefore leaves the previous state
    untouched, rather than a half-initialized mix (e.g. ``llm`` set while
    the vector store and chains are still None).

    Raises:
        ValueError: Missing API key or unsupported provider (from initialize_llm).
        ConnectionError: ChromaDB is unreachable.
        Exception: Any other construction error; it is printed and re-raised.
    """
    global vector_store, company_qa_chain, query_router_chain, cybersecurity_chain, llm

    try:
        new_llm = initialize_llm()
        new_vector_store = _connect_vector_store(_build_embeddings())
        new_router_chain, new_company_qa_chain, new_cybersecurity_chain = _build_chains(new_llm)
    except Exception as e:
        print(f"Error initializing pipelines: {e}")
        raise

    # Publish everything together so callers never observe partial state.
    llm = new_llm
    vector_store = new_vector_store
    query_router_chain = new_router_chain
    company_qa_chain = new_company_qa_chain
    cybersecurity_chain = new_cybersecurity_chain

    print(f"Successfully initialized pipelines with {LLM_PROVIDER.upper()}")
|
|
def add_document_to_rag(text: str, metadata: dict, *, chunk_size: int = 1000, chunk_overlap: int = 200):
    """Split a document into overlapping chunks and add them to the ChromaDB index.

    The pipelines are initialized lazily if no vector store exists yet.

    Args:
        text: Raw document text.
        metadata: Metadata attached to every chunk. ``None`` is treated as ``{}``.
        chunk_size: Maximum characters per chunk (defaults to the previous
            hard-coded 1000).
        chunk_overlap: Characters shared by consecutive chunks (defaults to
            the previous hard-coded 200).

    Returns:
        bool: True if chunks were added. False if the text produced no
        chunks or splitting/adding raised (the error is printed).

    Raises:
        Exception: Whatever initialize_pipelines() raises during lazy init.
    """
    # Explicit None check: LangChain's Chroma store implements __len__ as
    # the collection's document count, so an *empty* collection is falsy.
    # `not vector_store` would then rebuild every pipeline on each call.
    if vector_store is None:
        initialize_pipelines()

    try:
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        # One metadata dict per input text; `or {}` keeps a None metadata
        # from producing invalid Documents.
        docs = text_splitter.create_documents([text], metadatas=[metadata or {}])

        if not docs:
            print("Document was empty after splitting, not adding to ChromaDB.")
            return False

        vector_store.add_documents(docs)
        print("Successfully added documents.")
        return True

    except Exception as e:
        print(f"Error adding document to RAG: {e}")
        return False
|
|
def route_and_process_query(query: str):
    """Classify a query and answer it with the matching pipeline.

    The router chain's output is upper-cased and substring-matched, with
    CYBERSECURITY checked before COMPANY:

    * CYBERSECURITY -- answered directly by the cybersecurity chain.
    * COMPANY -- the 3 most similar chunks are retrieved from the vector
      store and passed as context to the company Q&A chain.
    * anything else -- a fixed "outside my scope" reply.

    Pipelines are initialized lazily if any component is missing. An
    initialization failure propagates to the caller.

    Returns:
        dict: ``answer``, ``source``, ``route``, ``provider`` and ``model``.
        A COMPANY answer built from retrieved chunks also carries
        ``documents``: the distinct ``source`` metadata values, in retrieval
        order. Errors raised while routing/answering are caught; the dict
        then has ``route`` None, ``documents`` None and an ``error`` string.
    """
    # `is None` rather than truthiness: LangChain's Chroma store is falsy
    # while its collection is empty, which used to force re-initialization
    # on every query.
    pipeline_parts = (query_router_chain, vector_store, company_qa_chain, cybersecurity_chain)
    if any(part is None for part in pipeline_parts):
        initialize_pipelines()

    try:
        route = query_router_chain.run(query).strip().upper()
        provider = LLM_PROVIDER.upper()
        model = get_llm_config()["model"]

        if "CYBERSECURITY" in route:
            answer = cybersecurity_chain.run(question=query)
            return {
                "answer": answer,
                "source": "Cybersecurity Knowledge Base",
                "route": "CYBERSECURITY",
                "provider": provider,
                "model": model
            }

        if "COMPANY" in route:
            docs = vector_store.similarity_search(query, k=3)

            if not docs:
                return {
                    "answer": "I could not find any relevant information to answer your question.",
                    "source": "Company Documents",
                    "route": "COMPANY",
                    "provider": provider,
                    "model": model
                }

            context = "\n\n".join(doc.page_content for doc in docs)
            answer = company_qa_chain.run(question=query, context=context)
            # dict.fromkeys de-duplicates while keeping relevance order
            # (set() returned the sources in arbitrary order).
            sources = list(dict.fromkeys(doc.metadata.get("source", "Unknown") for doc in docs))

            return {
                "answer": answer,
                "source": "Company Documents",
                "documents": sources,
                "route": "COMPANY",
                "provider": provider,
                "model": model
            }

        return {
            "answer": "I am a specialized assistant of CyberAlertNepal. I cannot answer questions outside of cybersecurity topics.",
            "source": "N/A",
            "route": "OFF_TOPIC",
            "provider": provider,
            "model": model
        }

    except Exception as e:
        print(f"Error processing query: {e}")
        return {
            "answer": "I encountered an error while processing your query. Please try again.",
            "source": "Error",
            "route": None,
            "documents": None,
            "provider": LLM_PROVIDER.upper(),
            "error": str(e)
        }
|
|
def check_system_health():
    """Report whether every pipeline component is initialized and ChromaDB responds.

    Returns:
        dict: ``status`` is "healthy" only if every component is set and, when
        a vector store exists, the ChromaDB heartbeat succeeds. It also
        carries ``components`` (name -> initialized flag), ``provider`` and
        ``model`` ("Not initialized" when no LLM exists). If anything raises,
        the dict has status "unhealthy" with an ``error`` string instead.
    """
    try:
        # Explicit None check: an empty Chroma collection is falsy (its
        # __len__ is the document count), which used to skip the heartbeat.
        if vector_store is not None:
            # NOTE(review): `_client` is a private attribute of the LangChain
            # Chroma wrapper and may change between versions.
            vector_store._client.heartbeat()

        components = {
            "vector_store": vector_store is not None,
            "company_qa_chain": company_qa_chain is not None,
            "query_router_chain": query_router_chain is not None,
            "cybersecurity_chain": cybersecurity_chain is not None,
            "llm": llm is not None
        }

        return {
            "status": "healthy" if all(components.values()) else "unhealthy",
            "components": components,
            "provider": LLM_PROVIDER.upper(),
            "model": get_llm_config()["model"] if llm is not None else "Not initialized"
        }

    except Exception as e:
        return {
            "status": "unhealthy",
            "error": str(e),
            "provider": LLM_PROVIDER.upper()
        }
|
|
def test_llm_connection():
    """Send a fixed prompt to the chat model and report whether it replied.

    Initializes the pipelines first if no LLM exists yet. Exceptions raised
    during initialization or the call are caught and reported.

    Returns:
        dict: On success, ``success`` True plus ``provider``, ``model`` and
        ``response`` (the reply text). On failure, ``success`` False plus
        ``provider`` and ``error``.
    """
    try:
        if llm is None:
            initialize_pipelines()

        # Calling a chat model directly (llm("...")) goes through the
        # deprecated __call__, which expects a list of messages, not a str.
        # .invoke() accepts a plain string and wraps it as a human message.
        reply = llm.invoke("Say 'Hello, LLM is working!'")
        # Chat models return a message object; report its text, not its repr.
        reply_text = getattr(reply, "content", reply)
        return {
            "success": True,
            "provider": LLM_PROVIDER.upper(),
            "model": get_llm_config()["model"],
            "response": str(reply_text)
        }
    except Exception as e:
        return {
            "success": False,
            "provider": LLM_PROVIDER.upper(),
            "error": str(e)
        }
|
|
| |
def _initialize_on_import():
    """Eagerly initialize the pipelines on a best-effort basis.

    A failure is printed rather than raised, so importing this module never
    fails because ChromaDB or the LLM provider is unavailable. The
    public functions above re-run initialize_pipelines() when components are
    missing.
    """
    try:
        initialize_pipelines()
    except Exception as startup_error:
        print(f"Failed to initialize pipelines on startup: {startup_error}")


_initialize_on_import()