File size: 4,164 Bytes
7c19d46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# =============================================================================
# RAG Pipeline — DevSecOps Knowledge Assistant
# =============================================================================
# Stack: LangChain + HuggingFace Embeddings + ChromaDB + vLLM
# =============================================================================

import os
from typing import List, Optional
from dataclasses import dataclass

from langchain_community.document_loaders import (
    DirectoryLoader,
    GitLoader,
    PyPDFLoader,
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import VLLM


@dataclass
class RAGConfig:
    """Tunable settings for the DevSecOps RAG pipeline.

    Every field carries a default, so ``RAGConfig()`` on its own is a complete
    configuration; override individual fields by keyword as needed.
    """

    # -- Models ---------------------------------------------------------------
    # Model name handed to the embedding backend.
    embedding_model: str = "BAAI/bge-large-en-v1.5"
    # Model name handed to the LLM backend.
    llm_model: str = "meta-llama/Llama-3.1-8B-Instruct"

    # -- Chunking -------------------------------------------------------------
    # Forwarded to the text splitter as ``chunk_size`` / ``chunk_overlap``.
    chunk_size: int = 512
    chunk_overlap: int = 64

    # -- Retrieval and storage ------------------------------------------------
    # Number of documents requested from the retriever per question.
    retriever_k: int = 4
    # Directory in which the vector store is persisted.
    persist_dir: str = "/data/chromadb"

    # -- Hardware -------------------------------------------------------------
    # Device string passed to the embedding model.
    device: str = "cuda"


class DevSecOpsRAG:
    """Retrieval-Augmented Generation pipeline for DevSecOps knowledge.

    Wires together a HuggingFace embedding model, a Chroma vector store, a
    vLLM language model and a recursive text splitter. Call
    :meth:`ingest_documents` to build the index and :meth:`query` to answer a
    question from it.
    """

    # File suffixes picked up by ``ingest_documents``. Each suffix is globbed
    # with its own pattern because pathlib-style globbing has no ``{a,b}``
    # brace expansion: a single "**/*.{md,txt,...}" pattern matches nothing.
    INGEST_SUFFIXES = ("md", "txt", "rst", "py", "yaml", "yml", "json", "tf")

    # Collection settings shared by the ingest and query paths so that a
    # collection is always opened with the same distance metric.
    _COLLECTION_METADATA = {"hnsw:space": "cosine"}

    def __init__(self, config: Optional[RAGConfig] = None):
        """Build the embedding model, the LLM and the text splitter.

        Args:
            config: Pipeline settings; a default ``RAGConfig()`` is used when
                omitted.

        The vector store is not opened here; it is created by
        :meth:`ingest_documents` or opened on first use by :meth:`query`.
        """
        self.config = config or RAGConfig()
        self.embeddings = HuggingFaceEmbeddings(
            model_name=self.config.embedding_model,
            model_kwargs={"device": self.config.device},
            encode_kwargs={"normalize_embeddings": True},
        )
        self.vectorstore = None
        # NOTE(review): trust_remote_code=True permits code shipped inside the
        # model repository to run locally. Confirm the configured model really
        # needs it before pointing ``llm_model`` at an untrusted repository.
        self.llm = VLLM(
            model=self.config.llm_model,
            trust_remote_code=True,
            tensor_parallel_size=1,
            gpu_memory_utilization=0.85,
            max_model_len=4096,
        )
        # Separators are tried in order: Markdown headings first, then
        # paragraphs, lines and finally single spaces.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.config.chunk_size,
            chunk_overlap=self.config.chunk_overlap,
            separators=["\n## ", "\n### ", "\n\n", "\n", " "],
        )

    def ingest_documents(self, source_path: str) -> int:
        """Load, split and index documents found under *source_path*.

        Files whose suffix is listed in ``INGEST_SUFFIXES`` are loaded
        recursively, split into chunks and written to a Chroma collection in
        ``config.persist_dir``.

        Args:
            source_path: Root directory to scan.

        Returns:
            The number of chunks indexed, or ``0`` when no matching content
            was found (in which case the vector store is left untouched).
        """
        documents = []
        for suffix in self.INGEST_SUFFIXES:
            loader = DirectoryLoader(
                source_path,
                glob=f"**/*.{suffix}",
                show_progress=True,
            )
            documents.extend(loader.load())

        chunks = self.text_splitter.split_documents(documents)
        if not chunks:
            # Nothing to embed: skip the vector-store write instead of handing
            # Chroma an empty batch.
            return 0

        self.vectorstore = Chroma.from_documents(
            documents=chunks,
            embedding=self.embeddings,
            persist_directory=self.config.persist_dir,
            collection_metadata=dict(self._COLLECTION_METADATA),
        )
        # Explicit persist() is deprecated and missing from newer Chroma
        # wrappers, which write to disk automatically; call it only if offered.
        persist = getattr(self.vectorstore, "persist", None)
        if callable(persist):
            persist()
        return len(chunks)

    @staticmethod
    def _format_context(docs) -> str:
        """Join retrieved chunks into one context string, each tagged with its source."""
        sections = []
        for doc in docs:
            # Loaders are expected to record the originating file under the
            # "source" metadata key; fall back to "unknown" when it is absent.
            source = (doc.metadata or {}).get("source", "unknown")
            sections.append(f"[Source: {source}]\n{doc.page_content}")
        return "\n\n---\n\n".join(sections)

    def query(self, question: str) -> dict:
        """Answer *question* from the indexed documents.

        Opens the persisted collection on first use if no ingestion has
        happened in this process, retrieves ``config.retriever_k`` chunks with
        MMR search, and asks the LLM to answer from them.

        Args:
            question: The user's question.

        Returns:
            A dict with the keys ``"question"`` (the input), ``"answer"`` (the
            LLM response) and ``"sources"`` (one entry per retrieved chunk
            holding the first 200 characters of its text and its metadata).
        """
        if self.vectorstore is None:
            self.vectorstore = Chroma(
                persist_directory=self.config.persist_dir,
                embedding_function=self.embeddings,
                collection_metadata=dict(self._COLLECTION_METADATA),
            )

        retriever = self.vectorstore.as_retriever(
            search_type="mmr",
            search_kwargs={"k": self.config.retriever_k},
        )
        docs = retriever.invoke(question)
        # The prompt asks the model to cite its sources, so every chunk is
        # labelled with where it came from.
        context = self._format_context(docs)

        prompt = f"""You are a DevSecOps expert assistant. Answer the question
based on the context below. If the context doesn't contain enough information,
say so clearly. Always cite which document/section the answer comes from.

Context:
{context}

Question: {question}

Answer:"""

        response = self.llm.invoke(prompt)
        return {
            "question": question,
            "answer": response,
            "sources": [
                {"content": d.page_content[:200], "metadata": d.metadata}
                for d in docs
            ],
        }


if __name__ == "__main__":
    # Demo run: index the platform documentation, then ask one sample question.
    pipeline = DevSecOpsRAG()

    chunk_count = pipeline.ingest_documents("/app/devsecops-platform")
    print(f"Ingested {chunk_count} chunks")

    sample_question = "What security policies are enforced in the Kubernetes cluster?"
    reply = pipeline.query(sample_question)
    # Echo the question and the generated answer, one labelled line each.
    for label, key in (("Q", "question"), ("A", "answer")):
        print(f"{label}: {reply[key]}")