Spaces:
Sleeping
Sleeping
| import os | |
| import chromadb | |
| from dotenv import load_dotenv | |
| from langchain_core.documents import Document | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.llms import OpenAI | |
| from langchain.chains.question_answering import load_qa_chain | |
| from langchain_community.vectorstores import Chroma | |
| from langchain.chains import LLMChain | |
| from langchain.prompts import PromptTemplate | |
| from langchain.chat_models import ChatOpenAI | |
| from config import Config | |
load_dotenv()

# --- ChromaDB connection settings (read from config.Config) ---
CHROMA_HOST = Config.RAG_CHROMA_HOST
CHROMA_PORT = Config.RAG_CHROMA_PORT
COLLECTION_NAME = Config.RAG_COLLECTION_NAME

# --- LLM settings (read from config.Config) ---
LLM_PROVIDER = Config.RAG_LLM_PROVIDER
LLM_API_KEY = Config.RAG_LLM_API_KEY
LLM_MODEL = Config.RAG_LLM_MODEL
LLM_TEMPERATURE = Config.RAG_LLM_TEMPERATURE
LLM_MAX_TOKENS = Config.RAG_LLM_MAX_TOKENS

# Provider name -> OpenAI-compatible API base URL plus the model to use when
# LLM_MODEL is "gpt-3.5-turbo" (see get_llm_config).
PROVIDER_CONFIGS = {
    "openai": dict(
        api_base="https://api.openai.com/v1",
        default_model="gpt-3.5-turbo",
    ),
    "groq": dict(
        api_base="https://api.groq.com/openai/v1",
        default_model="llama-3.3-70b-versatile",
    ),
    "openrouter": dict(
        api_base="https://openrouter.ai/api/v1",
        default_model="mistralai/mistral-small-3.2-24b-instruct:free",
    ),
}

# Pipeline components; all start unset and are filled in by initialize_pipelines().
vector_store = None         # Chroma vector store
company_qa_chain = None     # LLMChain answering company questions from retrieved context
query_router_chain = None   # LLMChain classifying queries into a route label
cybersecurity_chain = None  # LLMChain answering general cybersecurity questions
llm = None                  # shared ChatOpenAI instance
def get_llm_config():
    """Build the keyword arguments used to construct the chat model.

    The provider's entry in ``PROVIDER_CONFIGS`` supplies the API base URL.
    ``LLM_MODEL`` is used as-is unless it is unset/empty or equals
    ``"gpt-3.5-turbo"``; in those cases the provider's ``default_model`` is used.

    Returns:
        dict with ``model``, ``openai_api_key``, ``openai_api_base``,
        ``temperature`` and ``max_tokens`` keys (passed to ``ChatOpenAI(**...)``
        by initialize_llm).

    Raises:
        ValueError: if ``LLM_PROVIDER`` is not a key of ``PROVIDER_CONFIGS``.
    """
    if LLM_PROVIDER not in PROVIDER_CONFIGS:
        raise ValueError(f"Unsupported LLM provider: {LLM_PROVIDER}. Supported: {list(PROVIDER_CONFIGS.keys())}")
    provider_config = PROVIDER_CONFIGS[LLM_PROVIDER]

    # "gpt-3.5-turbo" is treated as "no explicit choice" (NOTE(review): presumably
    # the Config default) so non-OpenAI providers get their own default model.
    # An unset/empty model also falls back, instead of sending model=None/"".
    if not LLM_MODEL or LLM_MODEL == "gpt-3.5-turbo":
        model = provider_config["default_model"]
    else:
        model = LLM_MODEL

    return {
        "model": model,
        "openai_api_key": LLM_API_KEY,
        "openai_api_base": provider_config["api_base"],
        "temperature": LLM_TEMPERATURE,
        "max_tokens": LLM_MAX_TOKENS,
    }
def initialize_llm():
    """Create the ChatOpenAI client for the configured provider.

    Every supported provider is reached through ChatOpenAI; get_llm_config()
    supplies the provider-specific API base URL and model.

    Raises:
        ValueError: if no API key is configured, or if get_llm_config()
            rejects the provider.
    """
    if LLM_API_KEY:
        llm_kwargs = get_llm_config()
        print(f"Initializing {LLM_PROVIDER.upper()} with model: {llm_kwargs['model']}")
        return ChatOpenAI(**llm_kwargs)
    raise ValueError(f"LLM_API_KEY environment variable is required for {LLM_PROVIDER}")
def initialize_pipelines():
    """Build the LLM, the Chroma vector store and the three prompt chains.

    Populates the module-level globals ``llm``, ``vector_store``,
    ``query_router_chain``, ``company_qa_chain`` and ``cybersecurity_chain``.
    Each global is assigned as soon as its component is built, so a failure
    part-way through leaves the earlier components set and the later ones
    untouched.

    Raises:
        ConnectionError: if the ChromaDB server does not answer a heartbeat.
        Exception: any other error (e.g. the ValueError raised by
            initialize_llm()) is printed and re-raised unchanged.
    """
    global vector_store, company_qa_chain, query_router_chain, cybersecurity_chain, llm

    def _build_chain(model, template_text, variables):
        # Wrap one prompt template and the shared model into an LLMChain.
        return LLMChain(
            llm=model,
            prompt=PromptTemplate(input_variables=variables, template=template_text),
        )

    # Prompt texts (runtime strings; keep wording exactly as-is).
    router_text = """You are a query classifier. Classify the following query into one of these categories:
- COMPANY: Questions about our company, its products, services, or general information
- CYBERSECURITY: Questions about cybersecurity, security threats, best practices, or vulnerabilities
- OFF_TOPIC: Questions that don't fit the above categories
Query: {query}
Respond with only the category name (COMPANY, CYBERSECURITY, or OFF_TOPIC):"""
    company_text = """You are a helpful assistant for CyberAlertNepal. Answer the following question about our company using the information provided and links if only available. Give a natural, direct and polite response.
Question: {question}
Information:
{context}
Answer:"""
    security_text = """You are a cybersecurity professional. Answer the following question truthfully and concisely.
If you are not 100% sure about the answer, simply respond with: "I am not sure about the answer."
Do not add extra explanations or assumptions. Do not provide false or speculative information.
Question: {question}
Provide a comprehensive and accurate answer about cybersecurity:"""

    try:
        llm = initialize_llm()

        # Embedding model run on CPU with normalized output vectors.
        embedder = HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )

        # Fail fast here (rather than on the first query) if ChromaDB is down.
        try:
            client = chromadb.HttpClient(host=CHROMA_HOST, port=CHROMA_PORT)
            client.heartbeat()
        except Exception as exc:
            raise ConnectionError("Failed to connect to ChromaDB.") from exc

        vector_store = Chroma(
            client=client,
            collection_name=COLLECTION_NAME,
            embedding_function=embedder,
        )

        # Chains are built (and their globals assigned) in this fixed order.
        query_router_chain = _build_chain(llm, router_text, ["query"])
        company_qa_chain = _build_chain(llm, company_text, ["question", "context"])
        cybersecurity_chain = _build_chain(llm, security_text, ["question"])

        print(f"Successfully initialized pipelines with {LLM_PROVIDER.upper()}")
    except Exception as exc:
        print(f"Error initializing pipelines: {exc}")
        raise
def add_document_to_rag(text: str, metadata: dict):
    """Split ``text`` into overlapping chunks and add them to the Chroma collection.

    Calls initialize_pipelines() first if the vector store has not been built;
    errors raised by that initialization propagate to the caller.

    Args:
        text: Raw document text; split into chunks of up to 1000 characters
            with 200 characters of overlap.
        metadata: Metadata attached to every chunk. ``None`` is treated as ``{}``.

    Returns:
        True if at least one chunk was stored; False if the text produced no
        chunks or storing failed (the error is printed, not raised).
    """
    if vector_store is None:
        initialize_pipelines()
    try:
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
        )
        # Document metadata must be a dict; a None here would fail validation
        # when the chunks are turned into Document objects.
        chunks = splitter.create_documents([text], metadatas=[metadata or {}])
        if not chunks:
            print("Document was empty after splitting, not adding to ChromaDB.")
            return False
        vector_store.add_documents(chunks)
        print("Successfully added documents.")
        return True
    except Exception as e:
        print(f"Error adding document to RAG: {e}")
        return False
def route_and_process_query(query: str):
    """Classify ``query`` with the router chain and answer it with the matching pipeline.

    Routes (decided by substring match on the upper-cased router output,
    CYBERSECURITY checked first):
      * CYBERSECURITY -> cybersecurity_chain, no retrieval.
      * COMPANY -> top-3 Chroma similarity search, then company_qa_chain over the
        retrieved chunks. ``documents`` lists the distinct ``source`` metadata
        values in retrieval order.
      * anything else -> fixed OFF_TOPIC refusal.

    Calls initialize_pipelines() if any component is missing; errors from that
    initialization propagate. Errors while routing/answering are printed and
    reported in the returned dict (``source == "Error"`` plus ``error``).

    Returns:
        dict with ``answer``, ``source``, ``route`` and ``provider`` keys; success
        responses also carry ``model``, and COMPANY answers built from retrieved
        chunks carry ``documents``.
    """
    components = (query_router_chain, vector_store, company_qa_chain, cybersecurity_chain)
    if any(component is None for component in components):
        initialize_pipelines()

    try:
        # 1. Classify the query.
        route = query_router_chain.run(query).strip().upper()
        provider = LLM_PROVIDER.upper()
        model = get_llm_config()["model"]

        # 2. Dispatch. Substring checks tolerate extra text around the label.
        if "CYBERSECURITY" in route:
            answer = cybersecurity_chain.run(question=query)
            return {
                "answer": answer,
                "source": "Cybersecurity Knowledge Base",
                "route": "CYBERSECURITY",
                "provider": provider,
                "model": model,
            }

        if "COMPANY" in route:
            docs = vector_store.similarity_search(query, k=3)
            if not docs:
                return {
                    "answer": "I could not find any relevant information to answer your question.",
                    "source": "Company Documents",
                    "route": "COMPANY",
                    "provider": provider,
                    "model": model,
                }
            context = "\n\n".join(doc.page_content for doc in docs)
            answer = company_qa_chain.run(question=query, context=context)
            # dict.fromkeys de-duplicates while keeping retrieval (relevance) order;
            # set() ordering of strings varies between interpreter runs.
            sources = list(dict.fromkeys(doc.metadata.get("source", "Unknown") for doc in docs))
            return {
                "answer": answer,
                "source": "Company Documents",
                "documents": sources,
                "route": "COMPANY",
                "provider": provider,
                "model": model,
            }

        # OFF_TOPIC, or any label the router produced that matched neither route.
        return {
            "answer": "I am a specialized assistant of CyberAlertNepal. I cannot answer questions outside of cybersecurity topics.",
            "source": "N/A",
            "route": "OFF_TOPIC",
            "provider": provider,
            "model": model,
        }
    except Exception as e:
        print(f"Error processing query: {e}")
        return {
            "answer": "I encountered an error while processing your query. Please try again.",
            "source": "Error",
            "route": None,
            "documents": None,
            "provider": LLM_PROVIDER.upper(),
            "error": str(e),
        }
def check_system_health():
    """Report which pipeline components are initialized and whether ChromaDB responds.

    Returns:
        On success: ``status`` ("healthy" only when every component is set),
        ``components`` (name -> bool), ``provider`` and ``model``
        ("Not initialized" when the LLM is unset).
        If the ChromaDB heartbeat or anything else raises: ``status`` of
        "unhealthy" with the error text and ``provider``, without the
        component map.
    """
    try:
        if vector_store:
            # Ping the server through the LangChain wrapper's private client handle.
            vector_store._client.heartbeat()

        tracked = (
            ("vector_store", vector_store),
            ("company_qa_chain", company_qa_chain),
            ("query_router_chain", query_router_chain),
            ("cybersecurity_chain", cybersecurity_chain),
            ("llm", llm),
        )
        flags = {name: obj is not None for name, obj in tracked}

        if all(flags.values()):
            overall = "healthy"
        else:
            overall = "unhealthy"
        provider = LLM_PROVIDER.upper()
        model = get_llm_config()["model"] if llm else "Not initialized"
        return {
            "status": overall,
            "components": flags,
            "provider": provider,
            "model": model,
        }
    except Exception as exc:
        return {
            "status": "unhealthy",
            "error": str(exc),
            "provider": LLM_PROVIDER.upper(),
        }
def test_llm_connection():
    """Send a fixed prompt to the configured LLM and report whether it answered.

    Calls initialize_pipelines() first if the LLM has not been created. All
    errors, including initialization failures, are caught and reported.

    Returns:
        dict with ``success`` and ``provider``. On success it also has ``model``
        and the reply text under ``response``; on failure it has ``error``.
    """
    try:
        if llm is None:
            initialize_pipelines()
        # A chat model's __call__ (deprecated) expects a list of messages, not a
        # bare string; invoke() accepts a plain string prompt and returns a
        # message object whose text is in ``.content``.
        reply = llm.invoke("Say 'Hello, LLM is working!'")
        return {
            "success": True,
            "provider": LLM_PROVIDER.upper(),
            "model": get_llm_config()["model"],
            "response": str(getattr(reply, "content", reply)),
        }
    except Exception as e:
        return {
            "success": False,
            "provider": LLM_PROVIDER.upper(),
            "error": str(e),
        }
# Build the pipelines eagerly at import time. A failure is only printed, not
# raised; the public functions call initialize_pipelines() again when their
# components are still unset.
try:
    initialize_pipelines()
except Exception as startup_error:
    print(f"Failed to initialize pipelines on startup: {startup_error}")