Pujan-Dev's picture
fixed :changed everything config
4d6298c
import os
import chromadb
from dotenv import load_dotenv
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import OpenAI
from langchain.chains.question_answering import load_qa_chain
from langchain_community.vectorstores import Chroma
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from config import Config
# Load variables from a .env file (if one is found) into the process environment.
load_dotenv()

# Connection settings for the ChromaDB server and the collection to use.
CHROMA_HOST = Config.RAG_CHROMA_HOST
CHROMA_PORT = Config.RAG_CHROMA_PORT
COLLECTION_NAME = Config.RAG_COLLECTION_NAME

# Settings for the LLM provider, all taken from the project Config object.
LLM_PROVIDER = Config.RAG_LLM_PROVIDER
LLM_API_KEY = Config.RAG_LLM_API_KEY
LLM_MODEL = Config.RAG_LLM_MODEL
LLM_TEMPERATURE = Config.RAG_LLM_TEMPERATURE
LLM_MAX_TOKENS = Config.RAG_LLM_MAX_TOKENS

# Per-provider API endpoint and fallback model, keyed by provider name.
PROVIDER_CONFIGS = dict(
    openai=dict(
        api_base="https://api.openai.com/v1",
        default_model="gpt-3.5-turbo",
    ),
    groq=dict(
        api_base="https://api.groq.com/openai/v1",
        default_model="llama-3.3-70b-versatile",
    ),
    openrouter=dict(
        api_base="https://openrouter.ai/api/v1",
        default_model="mistralai/mistral-small-3.2-24b-instruct:free",
    ),
)

# Module-level singletons; they stay None until initialize_pipelines() fills them.
vector_store = company_qa_chain = query_router_chain = cybersecurity_chain = llm = None
def get_llm_config():
    """Build the keyword arguments for the chat-model client.

    The provider name is matched against ``PROVIDER_CONFIGS`` ignoring case
    and surrounding whitespace. The configured model is used unless it is
    unset, blank, or equal to the ``"gpt-3.5-turbo"`` placeholder; in those
    cases the provider's own ``default_model`` is used instead.

    Returns:
        dict: ``model``, ``openai_api_key``, ``openai_api_base``,
        ``temperature`` and ``max_tokens``.

    Raises:
        ValueError: If the configured provider is not a key of
            ``PROVIDER_CONFIGS``.
    """
    # Normalise so that "OpenAI" or " groq " still select the right entry.
    provider_key = (LLM_PROVIDER or "").strip().lower()
    if provider_key not in PROVIDER_CONFIGS:
        raise ValueError(f"Unsupported LLM provider: {LLM_PROVIDER}. Supported: {list(PROVIDER_CONFIGS.keys())}")
    provider_settings = PROVIDER_CONFIGS[provider_key]

    # "gpt-3.5-turbo" is treated as "no explicit choice" (presumably the
    # Config default -- confirm against config.py). A missing or blank model
    # is handled the same way so that None is never sent to the API.
    configured_model = (LLM_MODEL or "").strip()
    if configured_model and configured_model != "gpt-3.5-turbo":
        model = configured_model
    else:
        model = provider_settings["default_model"]

    return {
        "model": model,
        "openai_api_key": LLM_API_KEY,
        "openai_api_base": provider_settings["api_base"],
        "temperature": LLM_TEMPERATURE,
        "max_tokens": LLM_MAX_TOKENS,
    }
def initialize_llm():
    """Create the chat-model client for the configured provider.

    Returns:
        ChatOpenAI: A client built from the keyword arguments produced by
        ``get_llm_config()``.

    Raises:
        ValueError: If no API key is configured, or propagated from
            ``get_llm_config()`` for an unsupported provider.
    """
    if LLM_API_KEY:
        config = get_llm_config()
        banner = f"Initializing {LLM_PROVIDER.upper()} with model: {config['model']}"
        print(banner)
        return ChatOpenAI(**config)

    # The key is checked before anything else, so a missing key is reported
    # even when the provider setting is also wrong.
    raise ValueError(f"LLM_API_KEY environment variable is required for {LLM_PROVIDER}")
def initialize_pipelines():
    """Build the LLM, the embeddings, the vector store and the three chains.

    The module-level ``llm``, ``vector_store``, ``query_router_chain``,
    ``company_qa_chain`` and ``cybersecurity_chain`` are assigned in that
    order. If a step fails, the error is printed and re-raised; globals
    assigned before the failing step keep their new values.

    Raises:
        ConnectionError: If the ChromaDB client cannot be created or its
            heartbeat call fails.
        Exception: Any other error raised while building the components.
    """
    global vector_store, company_qa_chain, query_router_chain, cybersecurity_chain, llm

    def _chain_for(template, variables):
        # Pair a prompt template with the shared module-level LLM.
        prompt = PromptTemplate(input_variables=variables, template=template)
        return LLMChain(llm=llm, prompt=prompt)

    try:
        llm = initialize_llm()

        # Sentence-embedding model, run on CPU with normalised vectors.
        embedder = HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )

        # Connect to ChromaDB and check the server answers before using it.
        try:
            client = chromadb.HttpClient(host=CHROMA_HOST, port=CHROMA_PORT)
            client.heartbeat()
        except Exception as e:
            raise ConnectionError("Failed to connect to ChromaDB.") from e

        vector_store = Chroma(
            client=client,
            collection_name=COLLECTION_NAME,
            embedding_function=embedder,
        )

        # Chain that classifies a query as COMPANY / CYBERSECURITY / OFF_TOPIC.
        query_router_chain = _chain_for(
            """You are a query classifier. Classify the following query into one of these categories:
- COMPANY: Questions about our company, its products, services, or general information
- CYBERSECURITY: Questions about cybersecurity, security threats, best practices, or vulnerabilities
- OFF_TOPIC: Questions that don't fit the above categories
Query: {query}
Respond with only the category name (COMPANY, CYBERSECURITY, or OFF_TOPIC):""",
            ["query"],
        )

        # Chain that answers company questions from supplied context.
        company_qa_chain = _chain_for(
            """You are a helpful assistant for CyberAlertNepal. Answer the following question about our company using the information provided and links if only available. Give a natural, direct and polite response.
Question: {question}
Information:
{context}
Answer:""",
            ["question", "context"],
        )

        # Chain that answers cybersecurity questions without retrieval.
        cybersecurity_chain = _chain_for(
            """You are a cybersecurity professional. Answer the following question truthfully and concisely.
If you are not 100% sure about the answer, simply respond with: "I am not sure about the answer."
Do not add extra explanations or assumptions. Do not provide false or speculative information.
Question: {question}
Provide a comprehensive and accurate answer about cybersecurity:""",
            ["question"],
        )

        print(f"Successfully initialized pipelines with {LLM_PROVIDER.upper()}")
    except Exception as e:
        print(f"Error initializing pipelines: {e}")
        raise
def add_document_to_rag(text: str, metadata: dict):
    """Split *text* into chunks and add them to the ChromaDB collection.

    The pipelines are initialised on first use if the vector store has not
    been created yet.

    Args:
        text: Raw document text to index.
        metadata: Metadata attached to every chunk; ``None`` is treated as
            an empty dict.

    Returns:
        bool: ``True`` if chunks were added. ``False`` if the text produced
        no chunks or any error occurred, including a failure to initialise
        the pipelines.
    """
    try:
        # Compare with None rather than testing truthiness: a store object
        # that defines __len__ is falsy while its collection is empty, which
        # would trigger a full re-initialisation on every call.
        # Initialisation lives inside the try so that its failure yields
        # False like every other error here.
        if vector_store is None:
            initialize_pipelines()

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
        )
        docs = splitter.create_documents([text], metadatas=[metadata or {}])
        if not docs:
            print("Document was empty after splitting, not adding to ChromaDB.")
            return False

        vector_store.add_documents(docs)
        print("Successfully added documents.")
        return True
    except Exception as e:
        print(f"Error adding document to RAG: {e}")
        return False
def route_and_process_query(query: str):
    """Classify *query* and answer it with the matching pipeline.

    The router chain labels the query. A label containing ``CYBERSECURITY``
    is answered by the cybersecurity chain. A label containing ``COMPANY``
    is answered by the company QA chain, using the three most similar
    documents from the vector store as context. Anything else gets a fixed
    off-topic reply.

    Args:
        query: The user's question.

    Returns:
        dict: Always contains ``answer``, ``source``, ``route`` and
        ``provider``. Successful results also contain ``model``. COMPANY
        answers built from retrieved documents add ``documents`` (unique
        sources in retrieval order). On any failure the dict has
        ``source == "Error"``, ``route`` and ``documents`` set to ``None``,
        and an ``error`` message.
    """
    try:
        # Identity checks instead of all([...]): a vector store that defines
        # __len__ is falsy while its collection is empty, which would force a
        # re-initialisation on every query. Initialisation is inside the try
        # so that a failure is reported through the error dict below.
        pipeline_parts = (
            query_router_chain,
            vector_store,
            company_qa_chain,
            cybersecurity_chain,
        )
        if any(part is None for part in pipeline_parts):
            initialize_pipelines()

        # Resolved once; appended last to every successful result.
        llm_info = {
            "provider": LLM_PROVIDER.upper(),
            "model": get_llm_config()["model"],
        }

        # 1. Classify the query
        route = query_router_chain.run(query).strip().upper()

        # 2. Route to appropriate logic
        if "CYBERSECURITY" in route:
            answer = cybersecurity_chain.run(question=query)
            return {
                "answer": answer,
                "source": "Cybersecurity Knowledge Base",
                "route": "CYBERSECURITY",
                **llm_info,
            }

        if "COMPANY" in route:
            # Perform similarity search on ChromaDB
            docs = vector_store.similarity_search(query, k=3)
            if not docs:
                return {
                    "answer": "I could not find any relevant information to answer your question.",
                    "source": "Company Documents",
                    "route": "COMPANY",
                    **llm_info,
                }

            # Combine document content for context
            context = "\n\n".join(doc.page_content for doc in docs)
            answer = company_qa_chain.run(question=query, context=context)

            # dict.fromkeys removes duplicates but keeps retrieval order,
            # so the result is deterministic (a set is not).
            sources = list(
                dict.fromkeys(doc.metadata.get("source", "Unknown") for doc in docs)
            )
            return {
                "answer": answer,
                "source": "Company Documents",
                "documents": sources,
                "route": "COMPANY",
                **llm_info,
            }

        # OFF_TOPIC, or any label the router was not expected to produce
        return {
            "answer": "I am a specialized assistant of CyberAlertNepal. I cannot answer questions outside of cybersecurity topics.",
            "source": "N/A",
            "route": "OFF_TOPIC",
            **llm_info,
        }
    except Exception as e:
        print(f"Error processing query: {e}")
        return {
            "answer": "I encountered an error while processing your query. Please try again.",
            "source": "Error",
            "route": None,
            "documents": None,
            # str() keeps this handler from raising if the provider is unset.
            "provider": str(LLM_PROVIDER).upper(),
            "error": str(e),
        }
def check_system_health():
    """Report whether every pipeline component has been initialised.

    If a vector store exists, its client's ``heartbeat()`` is called first;
    an exception from that call (or from anything else here) produces an
    unhealthy result carrying the error text.

    Returns:
        dict: ``status`` (``"healthy"`` or ``"unhealthy"``) and ``provider``.
        On the normal path it also contains ``components`` (component name
        mapped to a bool) and ``model``. On the error path it contains
        ``error`` instead.
    """
    try:
        # Test ChromaDB connection. Compare with None rather than testing
        # truthiness: a store that defines __len__ is falsy while its
        # collection is empty, and the heartbeat must still run then.
        if vector_store is not None:
            vector_store._client.heartbeat()

        # Test if all chains are initialized
        components = {
            "vector_store": vector_store is not None,
            "company_qa_chain": company_qa_chain is not None,
            "query_router_chain": query_router_chain is not None,
            "cybersecurity_chain": cybersecurity_chain is not None,
            "llm": llm is not None,
        }
        return {
            "status": "healthy" if all(components.values()) else "unhealthy",
            "components": components,
            "provider": LLM_PROVIDER.upper(),
            "model": get_llm_config()["model"] if llm is not None else "Not initialized",
        }
    except Exception as e:
        return {
            "status": "unhealthy",
            "error": str(e),
            "provider": LLM_PROVIDER.upper(),
        }
def test_llm_connection():
    """Send a short prompt to the LLM and report whether it answered.

    The pipelines are initialised first if no LLM client exists yet.

    Returns:
        dict: ``success`` and ``provider``. On success it also contains
        ``model`` and ``response`` (the reply's ``content`` attribute if it
        has one, otherwise the reply itself, as a string). On failure it
        contains ``error`` instead.
    """
    try:
        if llm is None:
            initialize_pipelines()

        # Chat models take a plain string through invoke(); calling the
        # model object directly expects a list of messages instead.
        reply = llm.invoke("Say 'Hello, LLM is working!'")
        reply_text = getattr(reply, "content", reply)
        return {
            "success": True,
            "provider": LLM_PROVIDER.upper(),
            "model": get_llm_config()["model"],
            "response": str(reply_text),
        }
    except Exception as e:
        return {
            "success": False,
            "provider": LLM_PROVIDER.upper(),
            "error": str(e),
        }
# Build the pipelines as soon as the module is imported. A failure is only
# reported here, so the import itself still succeeds.
try:
    initialize_pipelines()
except Exception as e:
    startup_message = f"Failed to initialize pipelines on startup: {e}"
    print(startup_message)