diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..bd38fdaa37297dbcfdf0366c6475d785501dd608 --- /dev/null +++ b/.gitignore @@ -0,0 +1,118 @@ +# Local run files +qa.db +**/qa.db +**/*qa*.db +**/test-reports + +# Byte-compiled / optimized / DLL files +__pycache__/ +/pycache/* +**/pycache/* +*/*/pycache/* +*/*/*/pycache/* +*/*/*/*/pycache/* +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pyflow +__pypackages__/ + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# PyCharm +.idea + +# VSCode +.vscode + +# http cache (requests-cache) +**/http_cache.sqlite + +# ruff +.ruff_cache diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..01098901f299bb40fe3c5f663c7a0dce478e8076 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,202 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "rag-pipelines" +version = "0.0.1" +description = 'Advanced Retrieval Augmented Generation Pipelines' +readme = "README.md" +requires-python = ">=3.9" +license = "MIT" +keywords = 
[] +authors = [ + { name = "Ashwin Mathur", email = "" }, + { name = "Varun Mathur", email = "" }, +] +classifiers = [ + "License :: OSI Approved :: MIT License", + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "dataloaders @ git+https://github.com/avnlp/dataloaders.git", + "langchain-core", + "langgraph", + "langchain-text-splitters", + "langchain-experimental", + "langchain-huggingface", + "langchain-groq", + "langchain_milvus", + "langchain-qdrant", + "langchain-pinecone", + "langchain-voyageai", + "spladerunner", + "haystack-ai", + "weave", + "edgartools", + "fastembed", + "pinecone-text[splade]", + "unstructured[pdf]", + "deepeval", + "arize-phoenix", + "dspy", + "dspy-ai[milvus]", + "optimum[onnxruntime]", +] + +[project.optional-dependencies] +dev = ["pytest"] + +[project.urls] +Documentation = "https://github.com/avnlp/rag-pipelines#readme" +Issues = "https://github.com/avnlp/rag-pipelines/issues" +Source = "https://github.com/avnlp/rag-pipelines" + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.targets.wheel] +packages = ["src/rag_pipelines"] + +[tool.hatch.envs.default] +installer = "uv" +dependencies = [ + "coverage[toml]>=6.5", + "pytest", + "pytest-rerunfailures", + "pytest-mock", +] + +[tool.hatch.envs.default.scripts] +test = "pytest -vv {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] + +[[tool.hatch.envs.test.matrix]] +python = ["39", "310", "311"] + +[tool.hatch.envs.lint] +installer = "uv" +detached = true 
+dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] + +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" +style = ["ruff check {args:}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix --unsafe-fixes {args:}", "style"] +all = ["style", "typing"] + +[tool.coverage.run] +source = ["rag_pipelines"] +branch = true +parallel = true + +[tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] + +[tool.ruff] +target-version = "py39" +line-length = 120 + +[tool.ruff.lint] +select = [ + "A", + "ARG", + "B", + "C", + "D", + "D401", + "DTZ", + "E", + "EM", + "F", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... 
True)` + "FBT003", + # Ignore checks for possible passwords + "S102", + "S105", + "S106", + "S107", + # Ignore complexity + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", + # Allow print statements + "T201", + # Ignore missing module docstrings + "D100", + "D104", + # Ignore Line too long + "E501", + # Ignore builtin argument shadowing + "A002", + # Ignore builtin module shadowing + "A005", + # Ignore Function calls in argument defaults + "B008", + "ARG002", + "ARG005", +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint.isort] +known-first-party = ["rag_pipelines"] + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "parents" + +[tool.ruff.lint.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "--strict-markers" +markers = ["integration: integration tests"] +log_cli = true + +[tool.black] +line-length = 120 + +[[tool.mypy.overrides]] +module = ["rag_pipelines.*", "pytest.*", "numpy.*"] +ignore_missing_imports = true diff --git a/src/rag_pipelines/__init__.py b/src/rag_pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/rag_pipelines/__pycache__/__init__.cpython-310.pyc b/src/rag_pipelines/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d15c57d8ddd5d95368c0098c4f56bd854ee416f2 Binary files /dev/null and b/src/rag_pipelines/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/rag_pipelines/embeddings/__init__.py b/src/rag_pipelines/embeddings/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bdd28eea36abe1a64848c71addd61486bbd3642 --- /dev/null +++ b/src/rag_pipelines/embeddings/__init__.py @@ -0,0 +1,6 @@ +from 
rag_pipelines.embeddings.dense import DenseEmbeddings +from rag_pipelines.embeddings.sparse_fastembed_qdrant import SparseEmbeddings +from rag_pipelines.embeddings.sparse_milvus import SparseEmbeddingsMilvus +from rag_pipelines.embeddings.sparse_pinecone_text import SparseEmbeddingsSplade + +__all__ = ["DenseEmbeddings", "SparseEmbeddings", "SparseEmbeddingsMilvus", "SparseEmbeddingsSplade"] diff --git a/src/rag_pipelines/embeddings/__pycache__/__init__.cpython-310.pyc b/src/rag_pipelines/embeddings/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ed88f9559f0d94e770c5ff8128acd8d382b8e40 Binary files /dev/null and b/src/rag_pipelines/embeddings/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/rag_pipelines/embeddings/__pycache__/dense.cpython-310.pyc b/src/rag_pipelines/embeddings/__pycache__/dense.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52c37b1dca7c0369f1cf89772307fee933e6e898 Binary files /dev/null and b/src/rag_pipelines/embeddings/__pycache__/dense.cpython-310.pyc differ diff --git a/src/rag_pipelines/embeddings/__pycache__/sparse_fastembed_qdrant.cpython-310.pyc b/src/rag_pipelines/embeddings/__pycache__/sparse_fastembed_qdrant.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ab6465bd7f530e1001fcd7b05811eb939f3632e Binary files /dev/null and b/src/rag_pipelines/embeddings/__pycache__/sparse_fastembed_qdrant.cpython-310.pyc differ diff --git a/src/rag_pipelines/embeddings/__pycache__/sparse_milvus.cpython-310.pyc b/src/rag_pipelines/embeddings/__pycache__/sparse_milvus.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b77c8d68b9a62b4380b5e908ef883a8769649158 Binary files /dev/null and b/src/rag_pipelines/embeddings/__pycache__/sparse_milvus.cpython-310.pyc differ diff --git a/src/rag_pipelines/embeddings/__pycache__/sparse_pinecone_text.cpython-310.pyc 
b/src/rag_pipelines/embeddings/__pycache__/sparse_pinecone_text.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e0425f8edc7c11b1d9a1bb7f373622d6bb92829 Binary files /dev/null and b/src/rag_pipelines/embeddings/__pycache__/sparse_pinecone_text.cpython-310.pyc differ diff --git a/src/rag_pipelines/embeddings/dense.py b/src/rag_pipelines/embeddings/dense.py new file mode 100644 index 0000000000000000000000000000000000000000..0bc322bb4073cbbbfba2aefd1af6b941a110d9fc --- /dev/null +++ b/src/rag_pipelines/embeddings/dense.py @@ -0,0 +1,85 @@ +from typing import Any, Optional + +import weave +from langchain_huggingface import HuggingFaceEmbeddings + + +class DenseEmbeddings(weave.Model): + """Generate dense embeddings for documents and queries using a specified SentenceTransformer model. + + This class leverages HuggingFace's `HuggingFaceEmbeddings` to compute dense embeddings for input text. + + Attributes: + model_name (str): The name of the pre-trained embedding model to use. + model_kwargs (Optional[Dict[str, Any]]): Additional configuration parameters for the embedding model. + encode_kwargs (Optional[Dict[str, Any]]): Parameters for fine-tuning the behavior of the encoding process. + embedding_model (HuggingFaceEmbeddings): The initialized HuggingFace embeddings model with the specified settings. + """ + + model_name: str + model_kwargs: Optional[dict[str, Any]] + encode_kwargs: Optional[dict[str, Any]] + show_progress: bool + embedding_model: Optional[HuggingFaceEmbeddings] = None + + def __init__( + self, + model_name: str = "sentence-transformers/all-MiniLM-L6-v2", + model_kwargs: Optional[dict[str, Any]] = None, + encode_kwargs: Optional[dict[str, Any]] = None, + show_progress: bool = True, + ): + """Initialize the DenseEmbeddings class with the specified model and configurations. + + Args: + model_name (str): The name of the pre-trained embedding model. Defaults to "sentence-transformers/all-MiniLM-L6-v2". 
+ model_kwargs (Optional[dict[str, Any]]): Additional model configuration parameters for initialization. Defaults to None. + encode_kwargs (Optional[dict[str, Any]]): Parameters for encoding settings. Defaults to None. + show_progress (bool): Whether to display progress during model operations. Defaults to True. + """ + if encode_kwargs is None: + encode_kwargs = {"normalize_embeddings": True} + if model_kwargs is None: + model_kwargs = {"device": "cpu"} + super().__init__( + model_name=model_name, + model_kwargs=model_kwargs, + encode_kwargs=encode_kwargs, + show_progress=show_progress, + ) + + self.model_name = model_name + self.model_kwargs = model_kwargs if model_kwargs is not None else {} + self.encode_kwargs = encode_kwargs if encode_kwargs is not None else {} + + # Initialize the embedding model with the specified parameters + self.embedding_model = HuggingFaceEmbeddings( + model_name=self.model_name, + model_kwargs=self.model_kwargs, + encode_kwargs=self.encode_kwargs, + show_progress=show_progress, + ) + + @weave.op() + def embed_texts(self, texts: list[str]) -> list[list[float]]: + """Embed a list of texts and return their embeddings. + + Args: + texts (list[str]): A list of texts to embed. + + Returns: + list[list[float]]: A list of embedding vectors corresponding to each input text. + """ + return self.embedding_model.embed_documents(texts) + + @weave.op() + def embed_query(self, text: str) -> list[float]: + """Embed a single query text and returns its embedding. + + Args: + text (str): The query text to be embedded. + + Returns: + List[float]: The embedding vector for the query text. 
+ """ + return self.embedding_model.embed_query(text) diff --git a/src/rag_pipelines/embeddings/sparse_fastembed_qdrant.py b/src/rag_pipelines/embeddings/sparse_fastembed_qdrant.py new file mode 100644 index 0000000000000000000000000000000000000000..f68fa10252f31a181bb951a2939bfe8de5fb51e7 --- /dev/null +++ b/src/rag_pipelines/embeddings/sparse_fastembed_qdrant.py @@ -0,0 +1,57 @@ +from typing import Any, Optional + +import weave +from langchain_qdrant.fastembed_sparse import FastEmbedSparse + + +class SparseEmbeddings(weave.Model): + """Generate sparse embeddings for documents and queries using the FastEmbedSparse model. + + Attributes: + model_name (str): The name of the sparse embedding model to use. + model_kwargs (Optional[dict[str, Any]]): Additional configuration parameters for the model. + sparse_embedding_model (FastEmbedSparse): The initialized FastEmbedSparse model with the specified parameters. + """ + + def __init__( + self, + model_name: str = "prithvida/Splade_PP_en_v1", + model_kwargs: Optional[dict[str, Any]] = None, + ): + """Initialize the SparseEmbeddings class with the specified model and configurations. + + Args: + model_name (str): The name of the sparse embedding model. Defaults to "prithvida/Splade_PP_en_v1". + model_kwargs (Optional[dict[str, Any]]): Additional model configuration parameters for initialization. Defaults to None. + """ + self.model_name = model_name + self.model_kwargs = model_kwargs if model_kwargs is not None else {} + + # Initialize the sparse embedding model with specified parameters + self.sparse_embedding_model = FastEmbedSparse(model_name=self.model_name, **self.model_kwargs) + + @weave.op() + def embed_texts(self, texts: list[str]) -> list[dict[str, float]]: + """Embed a list of texts and return their sparse embeddings. + + Args: + texts (list[str]): A list of document texts to embed. + + Returns: + list[dict[str, float]]: A list of sparse embedding dictionaries for each document text. 
+ Each dictionary maps terms to their corresponding weights. + """ + return self.sparse_embedding_model.embed_documents(texts) + + @weave.op() + def embed_query(self, text: str) -> dict[str, float]: + """Embed a single query text and return its sparse embedding. + + Args: + text (str): The query text to embed. + + Returns: + dict[str, float]: A sparse embedding dictionary for the query text, where keys are terms + and values are term weights. + """ + return self.sparse_embedding_model.embed_query(text) diff --git a/src/rag_pipelines/embeddings/sparse_milvus.py b/src/rag_pipelines/embeddings/sparse_milvus.py new file mode 100644 index 0000000000000000000000000000000000000000..df7aa4c2b31569fbff479f44125d9b5fe5f637f1 --- /dev/null +++ b/src/rag_pipelines/embeddings/sparse_milvus.py @@ -0,0 +1,67 @@ +from typing import Any, Optional + +import weave +from langchain_milvus.utils.sparse import BaseSparseEmbedding +from spladerunner import Expander + + +class SparseEmbeddingsMilvus(BaseSparseEmbedding): + """Generate sparse embeddings for documents and queries using the FastEmbedSparse model. + + Attributes: + model_name (str): The name of the sparse embedding model to use. + model_kwargs (Optional[dict[str, Any]]): Additional configuration parameters for the model. + sparse_embedding_model (FastEmbedSparse): The initialized FastEmbedSparse model with the specified parameters. + """ + + model_name: str + model_kwargs: Optional[dict[str, Any]] = None + sparse_embedding_model: Optional[Any] = None + + def __init__( + self, + model_name: str = "Splade_PP_en_v1", + max_length: int = 512, + ): + """Initialize the SparseEmbeddings class with the specified model and configurations. + + Args: + model_name (str): The name of the sparse embedding model. Defaults to "Splade_PP_en_v1". + model_kwargs (Optional[dict[str, Any]]): Additional model configuration parameters for initialization. Defaults to None. 
+ """ + self.model_name = model_name + self.max_length = max_length + + # Initialize the sparse embedding model with specified parameters + self.sparse_embedding_model = Expander(model_name=self.model_name, max_length=self.max_length) + + def _sparse_to_dict(self, sparse_vector: Any) -> dict[int, float]: + return dict(zip(sparse_vector["indices"], sparse_vector["values"])) + + @weave.op() + def embed_query(self, text: str) -> dict[int, float]: + """Embed a single query text and return its sparse embedding. + + Args: + text (str): The query text to embed. + + Returns: + dict[int, float]: A sparse embedding dictionary for the query text, where keys are terms + and values are term weights. + """ + sparse_embeddings = list(self.sparse_embedding_model.expand([text])) + return self._sparse_to_dict(sparse_embeddings[0]) + + @weave.op() + def embed_documents(self, texts: list[str]) -> list[dict[int, float]]: + """Embed a list of texts and return their sparse embeddings. + + Args: + texts (list[str]): A list of document texts to embed. + + Returns: + list[dict[int, float]]: A list of sparse embedding dictionaries for each document text. + Each dictionary maps terms to their corresponding weights. + """ + sparse_embeddings = list(self.sparse_embedding_model.expand(texts)) + return [self._sparse_to_dict(sparse_embeddings[i]) for i in range(len(texts))] diff --git a/src/rag_pipelines/embeddings/sparse_pinecone_text.py b/src/rag_pipelines/embeddings/sparse_pinecone_text.py new file mode 100644 index 0000000000000000000000000000000000000000..5aea69e3caea2b503926e7217534ec35f2af7c82 --- /dev/null +++ b/src/rag_pipelines/embeddings/sparse_pinecone_text.py @@ -0,0 +1,58 @@ +from typing import Any, Optional + +import weave +from pinecone_text.sparse import SpladeEncoder + + +class SparseEmbeddingsSplade(weave.Model): + """Generate sparse embeddings for documents and queries using the FastEmbedSparse model. 
+ + Attributes: + model_kwargs (Optional[dict[str, Any]]): Additional configuration parameters for the model. + sparse_embedding_model (SpladeEncoder): The FastEmbedSparse model initialized with the specified parameters. + """ + + model_kwargs: Optional[dict[str, Any]] + sparse_embedding_model: Optional[SpladeEncoder] = None + + def __init__( + self, + model_kwargs: Optional[dict[str, Any]] = None, + ): + """Initialize the SparseEmbeddings class with the specified model and configurations. + + Args: + model_kwargs (Optional[dict[str, Any]]): Additional model configuration parameters for initialization. + """ + super().__init__(model_kwargs=model_kwargs) + + self.model_kwargs = model_kwargs if model_kwargs is not None else {} + + # Initialize the sparse embedding model with specified parameters + self.sparse_embedding_model = SpladeEncoder(**self.model_kwargs) + + @weave.op() + def embed_texts(self, texts: list[str]) -> list[dict[str, float]]: + """Embed a list of texts and return their sparse embeddings. + + Args: + texts (list[str]): A list of document texts to embed. + + Returns: + list[dict[str, float]]: A list of sparse embedding dictionaries for each document text. + Each dictionary maps terms to their corresponding weights. + """ + return self.sparse_embedding_model.encode_documents(texts) + + @weave.op() + def embed_query(self, text: str) -> dict[str, float]: + """Embed a single query text and return its sparse embedding. + + Args: + text (str): The query text to embed. + + Returns: + dict[str, float]: A sparse embedding dictionary for the query text, where keys are terms + and values are term weights. 
+ """ + return self.sparse_embedding_model.encode_queries([text]) diff --git a/src/rag_pipelines/evaluation/__init__.py b/src/rag_pipelines/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d04d795c730a2aff253b8fb55737833022040885 --- /dev/null +++ b/src/rag_pipelines/evaluation/__init__.py @@ -0,0 +1,19 @@ +from rag_pipelines.evaluation.evaluator import Evaluator +from rag_pipelines.evaluation.response.answer_relevancy import AnswerRelevancyScorer +from rag_pipelines.evaluation.response.faithfulness import FaithfulnessScorer +from rag_pipelines.evaluation.response.hallucination import HallucinationScorer +from rag_pipelines.evaluation.response.summarization import SummarizationScorer +from rag_pipelines.evaluation.retrieval.contextual_precision import ContextualPrecisionScorer +from rag_pipelines.evaluation.retrieval.contextual_recall import ContextualRecallScorer +from rag_pipelines.evaluation.retrieval.contextual_relevancy import ContextualRelevancyScorer + +__all__ = [ + "AnswerRelevancyScorer", + "ContextualPrecisionScorer", + "ContextualRecallScorer", + "ContextualRelevancyScorer", + "Evaluator", + "FaithfulnessScorer", + "HallucinationScorer", + "SummarizationScorer", +] diff --git a/src/rag_pipelines/evaluation/evaluator.py b/src/rag_pipelines/evaluation/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..810dcb6249936ae06edfa84b5757100de340ccad --- /dev/null +++ b/src/rag_pipelines/evaluation/evaluator.py @@ -0,0 +1,54 @@ +import asyncio + +from weave import Dataset, Evaluation, Model, Scorer + + +class Evaluator: + """Evaluate a model on a dataset using a list of scorers. + + Attributes: + evaluation_name (str): The name of the evaluation run. + evaluation_dataset (Dataset): The dataset used for evaluation. + evaluation_scorers (list[Scorer]): A list of scorer objects used to evaluate the pipeline. + pipeline (Model): The pipeline (model) to be evaluated. 
+ """ + + def __init__( + self, + evaluation_name: str, + evaluation_dataset: Dataset, + evaluation_scorers: list[Scorer], + pipeline: Model, + ): + """Initialize the Evaluator instance with the specified evaluation parameters. + + Args: + evaluation_name (str): A unique identifier for the evaluation run. + evaluation_dataset (Dataset): A `Dataset` object representing the data for evaluation. + evaluation_scorers (list[Scorer]): A list of `Scorer` objects that calculate various metrics. + pipeline (Model): The model or pipeline to evaluate. + """ + self.evaluation_name = evaluation_name + self.evaluation_dataset = evaluation_dataset + self.evaluation_scorers = evaluation_scorers + self.pipeline = pipeline + + def evaluate(self) -> None: + """Perform evaluation of the pipeline using the specified dataset and scorers. + + This method creates an `Evaluation` object, executes the evaluation process, and + returns the results as a dictionary. + """ + evaluation = Evaluation( + evaluation_name=self.evaluation_name, + dataset=self.evaluation_dataset, + scorers=self.evaluation_scorers, + ) + + try: + evaluation_results = asyncio.run(evaluation.evaluate(self.pipeline)) + except Exception as exception: + msg = f"Evaluation run failed: {exception}" + raise RuntimeError(msg) from exception + + return evaluation_results diff --git a/src/rag_pipelines/evaluation/response/__init__.py b/src/rag_pipelines/evaluation/response/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/rag_pipelines/evaluation/response/answer_relevancy.py b/src/rag_pipelines/evaluation/response/answer_relevancy.py new file mode 100644 index 0000000000000000000000000000000000000000..a4b448b36deba72ffa1183e97c3aa085d1e0cf25 --- /dev/null +++ b/src/rag_pipelines/evaluation/response/answer_relevancy.py @@ -0,0 +1,152 @@ +from statistics import variance +from typing import Optional, Union + +import numpy as np +import weave 
+from deepeval.metrics import AnswerRelevancyMetric +from deepeval.test_case import LLMTestCase +from weave import Scorer + + +class AnswerRelevancyScorer(Scorer): + """Evaluate the relevancy of answers generated by a LLM. + + This scorer uses DeepEval's `AnswerRelevancy` Metric to assess the relevance and accuracy of LLM generated answers + compared to the input query. + + The answer relevancy metric measures the quality of the RAG pipeline's generator by determining how relevant the + actual output of an LLM application is in relation to the input query. + + Attributes: + threshold (float): The minimum passing threshold for relevancy, defaults to 0.5. + model (str): The name of the LLM model used for evaluation, defaults to "gpt-4". + include_reason (bool): Whether to include an explanation for the evaluation score, defaults to True. + strict_mode (bool): Enforces binary scoring (1 for perfect relevancy, 0 otherwise). Overrides the threshold to + 1. Defaults to False. + async_mode (bool): Whether to perform scoring asynchronously, defaults to True. + verbose (bool): Whether to print intermediate steps to the console, defaults to False. + metric (AnswerRelevancyMetric): An instance of AnswerRelevancyMetric to calculate the score. + """ + + threshold: float = Optional[None] + model: str = Optional[None] + include_reason: bool = Optional[None] + strict_mode: bool = Optional[None] + async_mode: bool = Optional[None] + verbose: bool = Optional[None] + metric: AnswerRelevancyMetric = Optional[None] + + def __init__( + self, + threshold: float = 0.5, + model: str = "gpt-4", + include_reason: bool = True, + strict_mode: bool = False, + async_mode: bool = True, + verbose: bool = False, + ): + """Initialize the AnswerRelevancy Scorer with the specified parameters. + + Args: + threshold (float): The minimum passing threshold for relevancy, defaults to 0.5. + model (str): The name of the LLM model used for evaluation, defaults to "gpt-4". 
+ include_reason (bool): Whether to include an explanation for the evaluation score, defaults to True. + strict_mode (bool): Enforces binary scoring (1 for perfect relevancy, 0 otherwise). Overrides the threshold to 1. Defaults to False. + async_mode (bool): Whether to perform scoring asynchronously, defaults to True. + verbose (bool): Whether to print intermediate steps to the console, defaults to False. + """ + super().__init__( + threshold=threshold, + model=model, + include_reason=include_reason, + strict_mode=strict_mode, + async_mode=async_mode, + verbose=verbose, + ) + + self.threshold = threshold + self.model = model + self.include_reason = include_reason + self.strict_mode = strict_mode + self.async_mode = async_mode + self.verbose = verbose + + self.metric = AnswerRelevancyMetric( + threshold=self.threshold, + model=self.model, + include_reason=self.include_reason, + async_mode=self.async_mode, + strict_mode=self.strict_mode, + verbose_mode=self.verbose, + ) + + @weave.op + def score( + self, + input: str, + output: Optional[dict] = None, + expected_output: Optional[str] = None, + context: Optional[list[str]] = None, + ) -> dict[str, Union[str, float]]: + """Evaluate the relevancy and accuracy of answers generated by a LLM. + + The AnswerRelevancy score is calculated according to the following equation: + + Answer Relevancy = Total Number of Statements / Number of Relevant Statements + + The AnswerRelevancy Scorer uses an LLM to extract all statements made in the `actual_output`, before using the same LLM to classify whether each statement is relevant to the input. + + + Args: + input (str): The input query or prompt that triggered the output. + output (dict): The LLM generated response to evaluate and the retrieval context. + expected_output (Optional[str]): The expected or reference output, defaults to None. + context (Optional[list[str]]): Additional context for the evaluation, defaults to None. 
+ + Returns: + dict[str, Union[str, float]]: A dictionary containing: + - "score" (float): The computed answer relevancy score. + """ + test_case = LLMTestCase( + input=input, + actual_output=output.get("output", ""), + expected_output=expected_output, + retrieval_context=output.get("retrieval_context", [""]), + context=context, + ) + + result: dict[str, Union[str, float]] = {} + + self.metric.measure(test_case) + result = {"score": self.metric.score} + + return result + + @weave.op() + def summarize(self, score_rows: list) -> dict: + """Summarize the results of the AnswerRelevancy Scorer. + + Args: + score_rows (list): A list of dictionaries containing the following keys: + - "score" (float): The computed answer relevancy score. + - "reason" (str): A detailed explanation for the assigned score. + + Returns: + dict: A dictionary containing the following keys: + - "answer_relevancy_score" (dict): A dictionary containing the following keys: + - "score" (float): The average answer relevancy score. + - "variance" (float): The variance of the answer relevancy scores. + - "std" (float): The standard deviation of the answer relevancy scores. + - "count" (int): The number of answer relevancy scores. 
+ """ + scores = [] + for row in score_rows: + score = row.get("score", 0.0) + scores.append(float(score)) + + score = np.mean(scores).item() + variance = np.var(scores).item() + std = np.std(scores).item() + count = len(scores) + + return {"answer_relevancy_score": {"score": score, "variance": variance, "std": std, "count": count}} diff --git a/src/rag_pipelines/evaluation/response/faithfulness.py b/src/rag_pipelines/evaluation/response/faithfulness.py new file mode 100644 index 0000000000000000000000000000000000000000..29c25f85f0cba41086478c3e9ee1085da738fc9f --- /dev/null +++ b/src/rag_pipelines/evaluation/response/faithfulness.py @@ -0,0 +1,132 @@ +from typing import Optional, Union + +import weave +from deepeval.metrics import FaithfulnessMetric +from deepeval.test_case import LLMTestCase +from weave import Scorer + + +class FaithfulnessScorer(Scorer): + """Evaluate the faithfulness of LLM generated outputs. + + This scorer uses DeepEval's `Faithfulness` Metric. + + The faithfulness metric measures the quality of your LLM generation by evaluating whether the `actual_output` factually aligns with the contents of your `retrieval_context`. + + Attributes: + threshold (float): The minimum score required to pass the faithfulness check, defaults to 0.5. + model (str): The LLM model used for evaluation, defaults to "gpt-4". + include_reason (bool): Whether to include an explanation for the assigned score, defaults to True. + strict_mode (bool): When True, enforces binary scoring (1 for perfect alignment, 0 otherwise). + Overrides the threshold to 1. Defaults to False. + async_mode (bool): Whether to perform scoring asynchronously, defaults to True. + verbose (bool): Whether to display intermediate steps during metric computation, defaults to False. + truths_extraction_limit (Optional[int]): Limits the number of key facts to extract from the retrieval + context for evaluation, ordered by importance. Defaults to None. 
+ metric (FaithfulnessMetric): An instance of DeepEval's `FaithfulnessMetric` for scoring. + """ + + threshold: float = Optional[None] + model: str = Optional[None] + include_reason: bool = Optional[None] + strict_mode: bool = Optional[None] + async_mode: bool = Optional[None] + verbose: bool = Optional[None] + truths_extraction_limit: Optional[int] = Optional[None] + metric: FaithfulnessMetric = Optional[None] + + def __init__( + self, + threshold: float = 0.5, + model: str = "gpt-4", + include_reason: bool = True, + strict_mode: bool = False, + async_mode: bool = True, + verbose: bool = False, + truths_extraction_limit: Optional[int] = None, + ): + """Initialize the Faithfulness Scorer with DeepEval's Faithfulness Metric. + + Args: + threshold (float): The minimum score required to pass the faithfulness check, defaults to 0.5. + model (str): The LLM model used for evaluation, defaults to "gpt-4". + include_reason (bool): Whether to include an explanation for the assigned score, defaults to True. + strict_mode (bool): Enforces binary scoring (1 for perfect alignment, 0 otherwise). + Overrides the threshold to 1. Defaults to False. + async_mode (bool): Whether to perform scoring asynchronously, defaults to True. + verbose (bool): Whether to display intermediate steps during metric computation, defaults to False. + truths_extraction_limit (Optional[int]): Limits the number of key facts to extract from the retrieval + context for evaluation, ordered by importance. Defaults to None. 
+ """ + super().__init__( + threshold=threshold, + model=model, + include_reason=include_reason, + strict_mode=strict_mode, + async_mode=async_mode, + verbose=verbose, + truths_extraction_limit=truths_extraction_limit, + ) + + self.threshold = threshold + self.model = model + self.include_reason = include_reason + self.strict_mode = strict_mode + self.async_mode = async_mode + self.verbose = verbose + self.truths_extraction_limit = truths_extraction_limit + + self.metric = FaithfulnessMetric( + threshold=self.threshold, + model=self.model, + include_reason=self.include_reason, + async_mode=self.async_mode, + strict_mode=self.strict_mode, + verbose_mode=self.verbose, + ) + + @weave.op + def score( + self, + input: str, + actual_output: str, + expected_output: Optional[str] = None, + retrieval_context: Optional[list[str]] = None, + context: Optional[list[str]] = None, + ) -> dict[str, Union[str, float]]: + """Evaluate the faithfulness of an LLM generated response. + + Faithfulness is calculated as: + + Faithfulness = (Number of Truthful Claims) / (Total Number of Claims). + + The Faithfulness Metric evaluates all claims in the `actual_output` and checks + whether they are truthful based on the facts in the `retrieval_context`. Claims + are marked truthful if they align with or do not contradict any facts in the context. + + Args: + input (str): The input query or prompt that triggered the output. + actual_output (str): The LLM generated response to evaluate. + expected_output (Optional[str]): The expected or reference output, defaults to None. + retrieval_context (Optional[list[str]]): The context containing factual information to compare against. + context (Optional[list[str]]): Additional context for the evaluation, defaults to None. + + Returns: + dict[str, Union[str, float]]: A dictionary containing: + - "score" (float): The computed faithfulness score. + - "reason" (str): A detailed explanation for the assigned score. 
+ """ + test_case = LLMTestCase( + input=input, + actual_output=actual_output, + expected_output=expected_output, + retrieval_context=retrieval_context, + context=context, + ) + + result: dict[str, Union[str, float]] = {} + + self.metric.measure(test_case) + result = {"score": self.metric.score, "reason": self.metric.reason} + + return result diff --git a/src/rag_pipelines/evaluation/response/hallucination.py b/src/rag_pipelines/evaluation/response/hallucination.py new file mode 100644 index 0000000000000000000000000000000000000000..18f6c6641f71993f7a09598c41ac60541d88fd53 --- /dev/null +++ b/src/rag_pipelines/evaluation/response/hallucination.py @@ -0,0 +1,127 @@ +from typing import Optional, Union + +import weave +from deepeval.metrics import HallucinationMetric +from deepeval.test_case import LLMTestCase +from weave import Scorer + + +class HallucinationScorer(Scorer): + """Evaluate the factual alignment of the generated output with the provided context. + + This scorer uses DeepEval's `Hallucination` Metric to assess how well the generated output + aligns with the reference context. + + The Hallucination metric determines whether your LLM generates factually correct information by comparing the `actual_output` to the provided `context`. + + Attributes: + threshold (float): A float representing the minimum passing threshold, defaults to 0.5. + model (str): The LLM model to use for scoring, defaults to "gpt-4". + include_reason (bool): Whether to include a reason for the evaluation score, defaults to True. + strict_mode (bool): A boolean which when set to True, enforces a binary metric score: 1 for perfection, + 0 otherwise. It also overrides the current threshold and sets it to 1. Defaults to False. + async_mode (bool): Whether to use asynchronous scoring, defaults to True. + verbose (bool): Whether to print the intermediate steps used to calculate said metric to the console, defaults + to False. + metric (HallucinationMetric): The DeepEval HallucinationMetric. 
+ """ + + threshold: float = Optional[None] + model: str = Optional[None] + include_reason: bool = Optional[None] + strict_mode: bool = Optional[None] + async_mode: bool = Optional[None] + verbose: bool = Optional[None] + metric: HallucinationMetric = Optional[None] + + def __init__( + self, + threshold: float = 0.5, + model: str = "gpt-4", + include_reason: bool = True, + strict_mode: bool = True, + async_mode: bool = True, + verbose: bool = False, + ): + """Initialize the Hallucination scorer using DeepEval's Hallucination Metric. + + Args: + threshold (float): A float representing the minimum passing threshold, defaults to 0.5. + model (str): The LLM model to use for scoring, defaults to "gpt-4". + include_reason (bool): Whether to include a reason for the evaluation score, defaults to True. + strict_mode (bool): A boolean which when set to True, enforces a binary metric score: 1 for perfection, + 0 otherwise. It also overrides the current threshold and sets it to 1. Defaults to False. + async_mode (bool): Whether to use asynchronous scoring, defaults to True. + verbose (bool): Whether to print the intermediate steps used to calculate said metric to the console, defaults + to False. 
+ """ + super().__init__( + threshold=threshold, + model=model, + include_reason=include_reason, + strict_mode=strict_mode, + async_mode=async_mode, + verbose=verbose, + ) + + self.threshold = threshold + self.model = model + self.include_reason = include_reason + self.strict_mode = strict_mode + self.async_mode = async_mode + self.verbose = verbose + + self.metric = HallucinationMetric( + threshold=self.threshold, + model=self.model, + include_reason=self.include_reason, + async_mode=self.async_mode, + strict_mode=self.strict_mode, + verbose_mode=self.verbose, + ) + + @weave.op + def score( + self, + input: str, + actual_output: str, + expected_output: Optional[str] = None, + retrieval_context: Optional[list[str]] = None, + context: Optional[list[str]] = None, + ) -> dict[str, Union[str, float]]: + """Evaluate the factual alignment of the generated output with the provided context. + + The Hallucination Score is calculated according to the following equation: + + Hallucination = Number of Contradicted Contexts / Total Number of Contexts + + The Hallucination Score uses an LLM to determine, for each context in `contexts`, whether there are any contradictions to the `actual_output`. + + Although extremely similar to the Faithfulness Scorer, the Hallucination Score is calculated differently since it uses `contexts` as the source of truth instead. Since `contexts` is the ideal segment of your knowledge base relevant to a specific input, the degree of hallucination can be measured by the degree of which the `contexts` is disagreed upon. + + Args: + input (str): The input query or prompt that triggered the output. + actual_output (str): The LLM generated response to evaluate. + expected_output (Optional[str]): The expected or reference output, defaults to None. + retrieval_context (Optional[list[str]]): The context containing factual information to compare against. + context (Optional[list[str]]): Additional context for the evaluation, defaults to None. 
+ + Returns: + dict[str, Union[str, float]]: A dictionary containing: + - "score" (float): The computed hallucination score. + - "reason" (str): A detailed explanation for the assigned score. + """ + test_case = LLMTestCase( + input=input, + actual_output=actual_output, + expected_output=expected_output, + retrieval_context=retrieval_context, + context=context, + ) + + result: dict[str, Union[str, float]] = {} + + self.metric.measure(test_case) + result = {"score": self.metric.score, "reason": self.metric.reason} + + return result diff --git a/src/rag_pipelines/evaluation/response/phoenix_hallucination.py b/src/rag_pipelines/evaluation/response/phoenix_hallucination.py new file mode 100644 index 0000000000000000000000000000000000000000..edab53119e275c5b654b68aeea74670c305c68a9 --- /dev/null +++ b/src/rag_pipelines/evaluation/response/phoenix_hallucination.py @@ -0,0 +1,107 @@ +from typing import Optional, Union + +import weave +from deepeval.metrics import AnswerRelevancyMetric +from deepeval.test_case import LLMTestCase +from weave import Scorer + + +class AnswerRelevancyScorer(Scorer): + """Evaluate the relevancy of answers generated by a LLM. + + This scorer uses DeepEval's `AnswerRelevancy` Metric to assess the relevance and accuracy of LLM generated answers + compared to the input query. + + The answer relevancy metric measures the quality of the RAG pipeline's generator by determining how relevant the + actual output of an LLM application is in relation to the input query. + + Attributes: + threshold (float): The minimum passing threshold for relevancy, defaults to 0.5. + model (str): The name of the LLM model used for evaluation, defaults to "gpt-4". + include_reason (bool): Whether to include an explanation for the evaluation score, defaults to True. + strict_mode (bool): Enforces binary scoring (1 for perfect relevancy, 0 otherwise). Overrides the threshold to + 1. Defaults to False. 
+ async_mode (bool): Whether to perform scoring asynchronously, defaults to True. + verbose (bool): Whether to print intermediate steps to the console, defaults to False. + metric (AnswerRelevancyMetric): An instance of AnswerRelevancyMetric to calculate the score. + """ + + def __init__( + self, + threshold: float = 0.5, + model: str = "gpt-4", + include_reason: bool = True, + strict_mode: bool = False, + async_mode: bool = True, + verbose: bool = False, + ): + """Initialize the AnswerRelevancy Scorer with the specified parameters. + + Args: + threshold (float): The minimum passing threshold for relevancy, defaults to 0.5. + model (str): The name of the LLM model used for evaluation, defaults to "gpt-4". + include_reason (bool): Whether to include an explanation for the evaluation score, defaults to True. + strict_mode (bool): Enforces binary scoring (1 for perfect relevancy, 0 otherwise). Overrides the threshold to 1. Defaults to False. + async_mode (bool): Whether to perform scoring asynchronously, defaults to True. + verbose (bool): Whether to print intermediate steps to the console, defaults to False. + """ + self.threshold = threshold + self.model = model + self.include_reason = include_reason + self.strict_mode = strict_mode + self.async_mode = async_mode + self.verbose = verbose + + self.metric = AnswerRelevancyMetric( + threshold=self.threshold, + model=self.model, + include_reason=self.include_reason, + async_mode=self.async_mode, + strict_mode=self.strict_mode, + verbose=self.verbose, + ) + + @weave.op + def score( + self, + input: str, + actual_output: str, + expected_output: Optional[str] = None, + retrieval_context: Optional[list[str]] = None, + context: Optional[list[str]] = None, + ) -> dict[str, Union[str, float]]: + """Evaluate the relevancy and accuracy of answers generated by a LLM. 
+
+        The AnswerRelevancy score is calculated according to the following equation:
+
+        Answer Relevancy = Number of Relevant Statements / Total Number of Statements
+
+        The AnswerRelevancy Scorer uses an LLM to extract all statements made in the `actual_output`, before using the same LLM to classify whether each statement is relevant to the input.
+
+
+        Args:
+            input (str): The input query or prompt that triggered the output.
+            actual_output (str): The LLM generated response to evaluate.
+            expected_output (Optional[str]): The expected or reference output, defaults to None.
+            retrieval_context (Optional[list[str]]): The context containing factual information to compare against.
+            context (Optional[list[str]]): Additional context for the evaluation, defaults to None.
+
+        Returns:
+            dict[str, Union[str, float]]: A dictionary containing:
+                - "score" (float): The computed answer relevancy score.
+                - "reason" (str): A detailed explanation for the assigned score.
+        """
+        test_case = LLMTestCase(
+            input=input,
+            actual_output=actual_output,
+            expected_output=expected_output,
+            retrieval_context=retrieval_context,
+            context=context,
+        )
+
+        result: dict[str, Union[str, float]] = {}
+
+        self.metric.measure(test_case)
+        result = {"score": self.metric.score, "reason": self.metric.reason}
+
+        return result
diff --git a/src/rag_pipelines/evaluation/response/summarization.py b/src/rag_pipelines/evaluation/response/summarization.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce4c8a84ce4365391f79015c31103962cde808e9
--- /dev/null
+++ b/src/rag_pipelines/evaluation/response/summarization.py
@@ -0,0 +1,158 @@
+from typing import Optional, Union
+
+import weave
+from deepeval.metrics import SummarizationMetric
+from deepeval.test_case import LLMTestCase
+from weave import Scorer
+
+
+class SummarizationScorer(Scorer):
+    """Summarization Scorer.
+
+    This scorer uses DeepEval's `Summarization` Metric to assess how well the generated output
+    aligns with the reference context.
+
+    The summarization metric uses LLMs to determine whether the LLM application is generating factually correct
+    summaries while including the necessary details from the original text.
+
+    Attributes:
+        threshold (float): Minimum passing threshold, defaults to 0.5.
+        model (str): LLM model for scoring, defaults to "gpt-4".
+        assessment_questions: a list of close-ended questions that can be answered with either a 'yes' or a 'no'.
+            These are questions you want your summary to be able to ideally answer,
+            and is especially helpful if you already know what a good summary for your use case looks like. If
+            not provided, the metric generates them at evaluation time.
+        include_reason (bool): Include reason for the evaluation score, defaults to True.
+        strict_mode (bool): Enforces binary metric scoring (1 or 0), defaults to False.
+        async_mode (bool): Use asynchronous scoring, defaults to True.
+        verbose (bool): Print intermediate steps used for scoring, defaults to False.
+        truths_extraction_limit (Optional[int]): Maximum number of factual truths to extract
+            from the retrieval_context. Defaults to None.
+        metric (SummarizationMetric): An instance of DeepEval's `SummarizationMetric` for scoring.
+ """ + + threshold: float = Optional[None] + model: str = Optional[None] + include_reason: bool = Optional[None] + strict_mode: bool = Optional[None] + async_mode: bool = Optional[None] + verbose: bool = Optional[None] + assessment_questions: Optional[list[str]] = Optional[None] + n: Optional[int] = Optional[None] + truths_extraction_limit: Optional[int] = Optional[None] + metric: SummarizationMetric = Optional[None] + + def __init__( + self, + threshold: float = 0.5, + model: str = "gpt-4", + include_reason: bool = True, + strict_mode: bool = False, + async_mode: bool = True, + verbose: bool = False, + assessment_questions: Optional[list[str]] = None, + n: Optional[int] = 5, + truths_extraction_limit: Optional[int] = None, + ): + """Initialize the Summarization Scorer with DeepEval's Summarization Metric. + + Args: + threshold (float): Minimum passing threshold, defaults to 0.5. + model (str): LLM model for scoring, defaults to "gpt-4". + include_reason (bool): Include reason for the evaluation score, defaults to True. + strict_mode (bool): Enforces binary metric scoring (1 or 0), defaults to False. + async_mode (bool): Use asynchronous scoring, defaults to True. + verbose (bool): Print intermediate steps used for scoring, defaults to False. + assessment_questions (Optional[list[str]]): a list of close-ended questions that can be answered with either + a 'yes' or a 'no'. These are questions you want your summary to be able to ideally answer, and is + especially helpful if you already know what a good summary for your use case looks like. If + `assessment_questions` is not provided, the metric will generate a set of `assessment_questions` at + evaluation time. + n (Optional[int]): The number of assessment questions to generate when `assessment_questions` is not + provided. Defaults to 5. + truths_extraction_limit (Optional[int]): Maximum number of factual truths to extract + from the retrieval_context. Defaults to None. 
+ """ + super().__init__( + threshold=threshold, + model=model, + include_reason=include_reason, + strict_mode=strict_mode, + async_mode=async_mode, + verbose=verbose, + assessment_questions=assessment_questions, + n=n, + truths_extraction_limit=truths_extraction_limit, + ) + + self.threshold = threshold + self.model = model + self.include_reason = include_reason + self.strict_mode = strict_mode + self.async_mode = async_mode + self.verbose = verbose + self.assessment_questions = assessment_questions + self.n = n + self.truths_extraction_limit = truths_extraction_limit + + self.metric = SummarizationMetric( + threshold=self.threshold, + model=self.model, + include_reason=self.include_reason, + async_mode=self.async_mode, + strict_mode=self.strict_mode, + verbose_mode=self.verbose, + assessment_questions=self.assessment_questions, + n=self.n, + truths_extraction_limit=self.truths_extraction_limit, + ) + + @weave.op + def score( + self, + input: str, + actual_output: str, + expected_output: Optional[str] = None, + retrieval_context: Optional[list[str]] = None, + context: Optional[list[str]] = None, + ) -> dict[str, Union[str, float]]: + """Evaluate the quality of summarization of an LLM generated response. + + The Summarization score is calculated according to the following equation: + + Summarization = min(Alignment Score, Coverage Score) + + where, + - Alignment Score: determines whether the summary contains hallucinated or contradictory information to the original text. + - Coverage Score: determines whether the summary contains the neccessary information from the original text. + + + While the Alignment Score is similar to that of the Hallucination Score, the Coverage Score is first calculated + by generating n closed-ended questions that can only be answered with either a 'yes or a 'no', before + calculating the ratio of which the original text and summary yields the same answer. + + Args: + input (str): The input query or prompt that triggered the output. 
+ actual_output (str): The LLM generated response to evaluate. + expected_output (Optional[str]): The expected or reference output, defaults to None. + retrieval_context (Optional[list[str]]): The context containing factual information to compare against. + context (Optional[list[str]]): Additional context for the evaluation, defaults to None. + + Returns: + dict[str, Union[str, float]]: A dictionary containing: + - "score" (float): The computed summarization score. + - "reason" (str): A detailed explanation for the assigned score. + """ + test_case = LLMTestCase( + input=input, + actual_output=actual_output, + expected_output=expected_output, + retrieval_context=retrieval_context, + context=context, + ) + + result: dict[str, Union[str, float]] = {} + + self.metric.measure(test_case) + result = {"score": self.metric.score, "reason": self.metric.reason} + + return result diff --git a/src/rag_pipelines/evaluation/retrieval/__init__.py b/src/rag_pipelines/evaluation/retrieval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/rag_pipelines/evaluation/retrieval/contextual_precision.py b/src/rag_pipelines/evaluation/retrieval/contextual_precision.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d1b03c4c939d0e604c5a26c964ab709c56edd2 --- /dev/null +++ b/src/rag_pipelines/evaluation/retrieval/contextual_precision.py @@ -0,0 +1,160 @@ +from typing import Optional, Union + +import numpy as np +import weave +from deepeval.metrics import ContextualPrecisionMetric +from deepeval.test_case import LLMTestCase +from weave import Scorer + + +class ContextualPrecisionScorer(Scorer): + """Evaluate the contextual precision of the generated output with the provided context. + + This scorer uses DeepEval's `Contextual Precision` Metric to assess how well the generated output + aligns with the reference context. 
+ + The contextual precision metric measures the quality of the pipeline's retriever by evaluating whether results in the `retrieval_context` that are relevant to the given input are ranked higher than irrelevant ones. + + Attributes: + threshold (float): A float representing the minimum passing threshold, defaults to 0.5. + model (str): The LLM model to use for scoring, defaults to "gpt-4". + include_reason (bool): Whether to include a reason for the evaluation score, defaults to True. + strict_mode (bool): A boolean which when set to True, enforces a binary metric score: 1 for perfection, + 0 otherwise. It also overrides the current threshold and sets it to 1. Defaults to False. + async_mode (bool): Whether to use asynchronous scoring, defaults to True. + verbose (bool): Whether to print the intermediate steps used to calculate said metric to the console, defaults + to False. + metric (ContextualPrecisionMetric): The DeepEval ContextualPrecisionMetric. + """ + + threshold: float = Optional[None] + model: str = Optional[None] + include_reason: bool = Optional[None] + strict_mode: bool = Optional[None] + async_mode: bool = Optional[None] + verbose: bool = Optional[None] + metric: ContextualPrecisionMetric = Optional[None] + + def __init__( + self, + threshold: float = 0.5, + model: str = "gpt-4", + include_reason: bool = True, + strict_mode: bool = True, + async_mode: bool = True, + verbose: bool = False, + ): + """Initialize the Contextual Precision Scorer using DeepEval's Contextual Precision Metric. + + Args: + threshold (float): A float representing the minimum passing threshold, defaults to 0.5. + model (str): The LLM model to use for scoring, defaults to "gpt-4". + include_reason (bool): Whether to include a reason for the evaluation score, defaults to True. + strict_mode (bool): A boolean which when set to True, enforces a binary metric score: 1 for perfection, + 0 otherwise. It also overrides the current threshold and sets it to 1. Defaults to False. 
+ async_mode (bool): Whether to use asynchronous scoring, defaults to True. + verbose (bool): Whether to print the intermediate steps used to calculate said metric to the console, defaults + to False. + """ + super().__init__( + threshold=threshold, + model=model, + include_reason=include_reason, + strict_mode=strict_mode, + async_mode=async_mode, + verbose=verbose, + ) + + self.threshold = threshold + self.model = model + self.include_reason = include_reason + self.strict_mode = strict_mode + self.async_mode = async_mode + self.verbose = verbose + + self.metric = ContextualPrecisionMetric( + threshold=self.threshold, + model=self.model, + include_reason=self.include_reason, + async_mode=self.async_mode, + strict_mode=self.strict_mode, + verbose_mode=self.verbose, + ) + + @weave.op + def score( + self, + input: str, + output: Optional[dict] = None, + expected_output: Optional[str] = None, + context: Optional[list[str]] = None, + ) -> dict[str, Union[str, float]]: + """Evaluate the contextual precision of the generated output with the provided context. + + The Contextual Precision Score is calculated according to the following equation: + + Contextual Precision = (1 / Number of Relevant Results) * (Sum(Number of Relevant Results up to position k) / k) * Binary Relevance of k'th result) + + where, + - k: The position of the result in the list of all results. + + The Contextual Precision Scorer first uses an LLM to determine for each result in the `retrieval_context` + whether it is relevant to the input based on information in the `expected_output`, before calculating the + weighted cumulative precision as the contextual precision score. + + Args: + input (str): The input query or prompt that triggered the output. + output (str): The LLM generated response to evaluate. + expected_output (Optional[str]): The expected or reference output, defaults to None. + retrieval_context (Optional[list[str]]): The context containing factual information to compare against. 
+ context (Optional[list[str]]): Additional context for the evaluation, defaults to None. + + Returns: + dict[str, Union[str, float]]: A dictionary containing: + - "score" (float): The computed contextual precision score. + """ + test_case = LLMTestCase( + input=input, + actual_output=output.get("output", ""), + expected_output=expected_output, + retrieval_context=output.get("retrieval_context", [""]), + context=context, + ) + + result: dict[str, Union[str, float]] = {} + + self.metric.measure(test_case) + result = { + "score": self.metric.score, + } + + return result + + @weave.op() + def summarize(self, score_rows: list) -> dict: + """Summarize the results of the Contextual Precision Scorer. + + Args: + score_rows (list): A list of dictionaries containing the following keys: + - "score" (float): The computed answer relevancy score. + - "reason" (str): A detailed explanation for the assigned score. + + Returns: + dict: A dictionary containing the following keys: + - "answer_relevancy_score" (dict): A dictionary containing the following keys: + - "score" (float): The average answer relevancy score. + - "variance" (float): The variance of the answer relevancy scores. + - "std" (float): The standard deviation of the answer relevancy scores. + - "count" (int): The number of answer relevancy scores. 
+ """ + scores = [] + for row in score_rows: + score = row.get("score", 0.0) + scores.append(float(score)) + + score = np.mean(scores).item() + variance = np.var(scores).item() + std = np.std(scores).item() + count = len(scores) + + return {"contextual_precision_score": {"score": score, "variance": variance, "std": std, "count": count}} diff --git a/src/rag_pipelines/evaluation/retrieval/contextual_recall.py b/src/rag_pipelines/evaluation/retrieval/contextual_recall.py new file mode 100644 index 0000000000000000000000000000000000000000..cd7e724264ecc7afbc85bbdef62e6b2924c80d48 --- /dev/null +++ b/src/rag_pipelines/evaluation/retrieval/contextual_recall.py @@ -0,0 +1,127 @@ +from typing import Optional, Union + +import weave +from deepeval.metrics import ContextualRecallMetric +from deepeval.test_case import LLMTestCase +from weave import Scorer + + +class ContextualRecallScorer(Scorer): + """Evaluate the contextual recall of the generated output with the provided context. + + This scorer uses DeepEval's `ContextualRecall` Metric to assess how well the generated output + aligns with the reference context. + + The contextual recall metric measures the quality of the pipeline's retriever by evaluating the extent of which the `retrieval_context` aligns with the `expected_output`. + + Attributes: + threshold (float): A float representing the minimum passing threshold, defaults to 0.5. + model (str): The LLM model to use for scoring, defaults to "gpt-4". + include_reason (bool): Whether to include a reason for the evaluation score, defaults to True. + strict_mode (bool): A boolean which when set to True, enforces a binary metric score: 1 for perfection, + 0 otherwise. It also overrides the current threshold and sets it to 1. Defaults to False. + async_mode (bool): Whether to use asynchronous scoring, defaults to True. + verbose (bool): Whether to print the intermediate steps used to calculate said metric to the console, defaults + to False. 
+ metric (ContextualRecallMetric): The DeepEval ContextualRecallMetric. + """ + + threshold: float = Optional[None] + model: str = Optional[None] + include_reason: bool = Optional[None] + strict_mode: bool = Optional[None] + async_mode: bool = Optional[None] + verbose: bool = Optional[None] + metric: ContextualRecallMetric = Optional[None] + + def __init__( + self, + threshold: float = 0.5, + model: str = "gpt-4", + include_reason: bool = True, + strict_mode: bool = True, + async_mode: bool = True, + verbose: bool = False, + ): + """Initialize the Contextual Recall Scorer using DeepEval's Contextual Recall Metric. + + Args: + threshold (float): A float representing the minimum passing threshold, defaults to 0.5. + model (str): The LLM model to use for scoring, defaults to "gpt-4". + include_reason (bool): Whether to include a reason for the evaluation score, defaults to True. + strict_mode (bool): A boolean which when set to True, enforces a binary metric score: 1 for perfection, + 0 otherwise. It also overrides the current threshold and sets it to 1. Defaults to False. + async_mode (bool): Whether to use asynchronous scoring, defaults to True. + verbose (bool): Whether to print the intermediate steps used to calculate said metric to the console, defaults + to False. 
+ """ + super().__init__( + threshold=threshold, + model=model, + include_reason=include_reason, + strict_mode=strict_mode, + async_mode=async_mode, + verbose=verbose, + ) + + self.threshold = threshold + self.model = model + self.include_reason = include_reason + self.strict_mode = strict_mode + self.async_mode = async_mode + self.verbose = verbose + + self.metric = ContextualRecallMetric( + threshold=self.threshold, + model=self.model, + include_reason=self.include_reason, + async_mode=self.async_mode, + strict_mode=self.strict_mode, + verbose_mode=self.verbose, + ) + + @weave.op + def score( + self, + input: str, + actual_output: str, + expected_output: Optional[str] = None, + retrieval_context: Optional[list[str]] = None, + context: Optional[list[str]] = None, + ) -> dict[str, Union[str, float]]: + """Evaluate the contextual recall of the generated output with the provided context. + + The Contextual Recall Score is calculated according to the following equation: + + Contextual Recall = Number of Attributable Results / Total Number of Results + + he Contextual Recall Scorer first uses an LLM to extract all statements made in the `expected_output`, before using the same LLM to classify whether each statement can be attributed to results in the `retrieval_context`. + + A higher contextual recall score represents a greater ability of the retrieval system to capture all relevant information from the total available relevant set within your knowledge base. + + Args: + input (str): The input query or prompt that triggered the output. + actual_output (str): The LLM generated response to evaluate. + expected_output (Optional[str]): The expected or reference output, defaults to None. + retrieval_context (Optional[list[str]]): The context containing factual information to compare against. + context (Optional[list[str]]): Additional context for the evaluation, defaults to None. 
+ + Returns: + dict[str, Union[str, float]]: A dictionary containing: + - "score" (float): The computed contextual recall score. + - "reason" (str): A detailed explanation for the assigned score. + """ + test_case = LLMTestCase( + input=input, + actual_output=actual_output, + expected_output=expected_output, + retrieval_context=retrieval_context, + context=context, + ) + + result: dict[str, Union[str, float]] = {} + + self.metric.measure(test_case) + result = {"score": self.metric.score, "reason": self.metric.reason} + + return result diff --git a/src/rag_pipelines/evaluation/retrieval/contextual_relevancy.py b/src/rag_pipelines/evaluation/retrieval/contextual_relevancy.py new file mode 100644 index 0000000000000000000000000000000000000000..1038f6b935ddf1755eefbf9339b4dcb8a2f51b20 --- /dev/null +++ b/src/rag_pipelines/evaluation/retrieval/contextual_relevancy.py @@ -0,0 +1,125 @@ +from typing import Optional, Union + +import weave +from deepeval.metrics import ContextualRelevancyMetric +from deepeval.test_case import LLMTestCase +from weave import Scorer + + +class ContextualRelevancyScorer(Scorer): + """Evaluate the contextual relevancy of the generated output with the provided context. + + This scorer uses DeepEval's `ContextualRelevancy` Metric to assess how well the generated output + aligns with the reference context. + + The contextual relevancy metric measures the quality of the RAG pipeline's retriever by evaluating the overall relevance of the information presented in the `retrieval_context` for a given input. + + Attributes: + threshold (float): A float representing the minimum passing threshold, defaults to 0.5. + model (str): The LLM model to use for scoring, defaults to "gpt-4". + include_reason (bool): Whether to include a reason for the evaluation score, defaults to True. + strict_mode (bool): A boolean which when set to True, enforces a binary metric score: 1 for perfection, + 0 otherwise. It also overrides the current threshold and sets it to 1. 
+            Defaults to True here (NOTE(review): DeepEval's own default is False — confirm intended).
+        async_mode (bool): Whether to use asynchronous scoring, defaults to True.
+        verbose (bool): Whether to print the intermediate steps used to calculate said metric to the console, defaults
+            to False.
+        metric (ContextualRelevancyMetric): The DeepEval ContextualRelevancyMetric.
+    """
+
+    threshold: float = Optional[None]
+    model: str = Optional[None]
+    include_reason: bool = Optional[None]
+    strict_mode: bool = Optional[None]
+    async_mode: bool = Optional[None]
+    verbose: bool = Optional[None]
+    metric: ContextualRelevancyMetric = Optional[None]
+
+    def __init__(
+        self,
+        threshold: float = 0.5,
+        model: str = "gpt-4",
+        include_reason: bool = True,
+        strict_mode: bool = True,
+        async_mode: bool = True,
+        verbose: bool = False,
+    ):
+        """Initialize the Contextual Relevancy Scorer using DeepEval's Contextual Relevancy Metric.
+
+        Args:
+            threshold (float): A float representing the minimum passing threshold, defaults to 0.5.
+            model (str): The LLM model to use for scoring, defaults to "gpt-4".
+            include_reason (bool): Whether to include a reason for the evaluation score, defaults to True.
+            strict_mode (bool): A boolean which when set to True, enforces a binary metric score: 1 for perfection,
+                0 otherwise. It also overrides the current threshold and sets it to 1. Defaults to True here
+                (NOTE(review): differs from DeepEval's default of False — confirm intended).
+            async_mode (bool): Whether to use asynchronous scoring, defaults to True.
+            verbose (bool): Whether to print the intermediate steps used to calculate said metric to the console, defaults
+                to False.
+ """ + super().__init__( + threshold=threshold, + model=model, + include_reason=include_reason, + strict_mode=strict_mode, + async_mode=async_mode, + verbose=verbose, + ) + + self.threshold = threshold + self.model = model + self.include_reason = include_reason + self.strict_mode = strict_mode + self.async_mode = async_mode + self.verbose = verbose + + self.metric = ContextualRelevancyMetric( + threshold=self.threshold, + model=self.model, + include_reason=self.include_reason, + async_mode=self.async_mode, + strict_mode=self.strict_mode, + verbose_mode=self.verbose, + ) + + @weave.op + def score( + self, + input: str, + actual_output: str, + expected_output: Optional[str] = None, + retrieval_context: Optional[list[str]] = None, + context: Optional[list[str]] = None, + ) -> dict[str, Union[str, float]]: + """Evaluate the contextual relevancy of the generated output with the provided context. + + The Contextual Relevancy Score is calculated according to the following equation: + + Contextual Relevancy = Number of Relevant Results / Total Number of Results + + Although similar to how the Answer Relevancy Score is calculated, the Contextual Relevancy Metric first uses an LLM to extract all statements made in the `retrieval_context` instead, before using the same LLM to classify whether each statement is relevant to the input. + + Args: + input (str): The input query or prompt that triggered the output. + actual_output (str): The LLM generated response to evaluate. + expected_output (Optional[str]): The expected or reference output, defaults to None. + retrieval_context (Optional[list[str]]): The context containing factual information to compare against. + context (Optional[list[str]]): Additional context for the evaluation, defaults to None. + + Returns: + dict[str, Union[str, float]]: A dictionary containing: + - "score" (float): The computed contextual relevancy score. + - "reason" (str): A detailed explanation for the assigned score. 
+ """ + test_case = LLMTestCase( + input=input, + actual_output=actual_output, + expected_output=expected_output, + retrieval_context=retrieval_context, + context=context, + ) + + result: dict[str, Union[str, float]] = {} + + self.metric.measure(test_case) + result = {"score": self.metric.score, "reason": self.metric.reason} + + return result diff --git a/src/rag_pipelines/llms/__init__.py b/src/rag_pipelines/llms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed60fdeefcdf91ad8e11a0908d9bb07a3af2b128 --- /dev/null +++ b/src/rag_pipelines/llms/__init__.py @@ -0,0 +1,3 @@ +from rag_pipelines.llms.groq import ChatGroqGenerator + +__all__ = ["ChatGroqGenerator"] diff --git a/src/rag_pipelines/llms/__pycache__/__init__.cpython-310.pyc b/src/rag_pipelines/llms/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9fed28bc7a7e4e339ef856ea51931068c67fb51 Binary files /dev/null and b/src/rag_pipelines/llms/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/rag_pipelines/llms/__pycache__/groq.cpython-310.pyc b/src/rag_pipelines/llms/__pycache__/groq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..683fb2750c62d2e2550eb1d6a84bd0e83c36657e Binary files /dev/null and b/src/rag_pipelines/llms/__pycache__/groq.cpython-310.pyc differ diff --git a/src/rag_pipelines/llms/groq.py b/src/rag_pipelines/llms/groq.py new file mode 100644 index 0000000000000000000000000000000000000000..147953837d5641c451611f2128b9b4cd6f5de64b --- /dev/null +++ b/src/rag_pipelines/llms/groq.py @@ -0,0 +1,99 @@ +import os +from typing import Any, Optional + +import weave +from langchain_core.prompts import ChatPromptTemplate +from langchain_groq import ChatGroq +from pydantic import BaseModel + +from rag_pipelines.prompts import STRUCTURED_RAG_PROMPT, RAGResponseModel + + +class ChatGroqGenerator: + """Interact with the ChatGroq model to generate responses based on user queries 
and documents. + + This class provides an interface for generating responses using the ChatGroq model. + It handles prompt formatting, LLM invocation, document integration, and result generation. + """ + + model: str + api_key: str + llm_params: dict[str, Any] + llm: Optional[ChatGroq] = None + structured_output_model: BaseModel + system_prompt: str + + def __init__( + self, + model: str, + api_key: Optional[str] = None, + llm_params: Optional[dict[str, Any]] = None, + structured_output_model: BaseModel = RAGResponseModel, + system_prompt: str = STRUCTURED_RAG_PROMPT, + ): + """Initialize the ChatGroqGenerator with configuration parameters. + + Args: + model (str): The name of the ChatGroq model to use. + api_key (Optional[str]): API key for the ChatGroq service. If not provided, + the "GROQ_API_KEY" environment variable will be used. + llm_params (Optional[dict]): Additional parameters for configuring the ChatGroq model. + structured_output_model (BaseModel): The output model for structured responses. + system_prompt (str): The system prompt for the ChatGroq model. + + Raises: + ValueError: If the API key is not provided and the "GROQ_API_KEY" environment variable is not set. + """ + if llm_params is None: + llm_params = {} + + api_key = api_key or os.environ.get("GROQ_API_KEY") + if api_key is None: + msg = "GROQ_API_KEY is not set. Please provide an API key or set it as an environment variable." + raise ValueError(msg) + + self.model = model + self.api_key = api_key + self.llm_params = llm_params + + self.structured_output_model = structured_output_model + self.system_prompt = system_prompt + + self.llm = ChatGroq(model=self.model, api_key=self.api_key, **llm_params) + + @weave.op() + def __call__(self, state: dict[str, Any]) -> dict[str, Any]: + """Generate a response using the current state of user prompts and graded documents. + + Args: + state (dict[str, Any]): The current state, containing: + - 'question': The user question. 
+ - 'context': A list of filtered document texts. + - 'documents': A list of retrieved documents. + + Returns: + dict[str, Any]: A dictionary containing: + - 'question': The user question. + - 'context': A list of filtered document texts. + - 'documents': A list of retrieved documents. + - 'answer': The generated response. + """ + question = state["question"] + context = state["context"] + documents = state["documents"] + + formatted_context = "\n".join(context) + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", self.system_prompt), + ] + ) + + response_chain = prompt | self.llm.with_structured_output(self.structured_output_model) + + response = response_chain.invoke({"question": question, "context": formatted_context}) + + answer = response.final_answer + + return {"question": question, "context": context, "documents": documents, "answer": answer} diff --git a/src/rag_pipelines/pipelines/__init__.py b/src/rag_pipelines/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2cbc03925ecf17a734f1b4c35886034df4c05b4 --- /dev/null +++ b/src/rag_pipelines/pipelines/__init__.py @@ -0,0 +1,3 @@ +from rag_pipelines.pipelines.self_rag import SelfRAGPipeline + +__all__ = ["SelfRAGPipeline"] diff --git a/src/rag_pipelines/pipelines/__pycache__/__init__.cpython-310.pyc b/src/rag_pipelines/pipelines/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b25de8daa571af2cd65dde07b1389e3bf8f84013 Binary files /dev/null and b/src/rag_pipelines/pipelines/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/rag_pipelines/pipelines/__pycache__/self_rag.cpython-310.pyc b/src/rag_pipelines/pipelines/__pycache__/self_rag.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7802384dbba09e10bd80f33434dfeac79d69a41 Binary files /dev/null and b/src/rag_pipelines/pipelines/__pycache__/self_rag.cpython-310.pyc differ diff --git 
a/src/rag_pipelines/pipelines/__pycache__/self_rag_graph_state.cpython-310.pyc b/src/rag_pipelines/pipelines/__pycache__/self_rag_graph_state.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6116ca1a1ddbeee9d3e40e665408c17975bf24f4 Binary files /dev/null and b/src/rag_pipelines/pipelines/__pycache__/self_rag_graph_state.cpython-310.pyc differ diff --git a/src/rag_pipelines/pipelines/adaptive_rag.py b/src/rag_pipelines/pipelines/adaptive_rag.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/rag_pipelines/pipelines/adaptive_rag_graph_state.py b/src/rag_pipelines/pipelines/adaptive_rag_graph_state.py new file mode 100644 index 0000000000000000000000000000000000000000..73ff617acb1202c126efb358b7496499763dcf89 --- /dev/null +++ b/src/rag_pipelines/pipelines/adaptive_rag_graph_state.py @@ -0,0 +1,18 @@ +from langchain_core.documents import Document +from typing_extensions import TypedDict + + +class AdaptiveRAGGraphState(TypedDict): + """Represents the state of the graph for the Adaptive Retrieval-Augmentation-Generation (Adaptive-RAG) pipeline. + + Attributes: + question (str): The input question for the pipeline. + answer (str): The generated response from the LLM. + documents (list[Document]): A list of LangChain documents that are retrieved and processed through the pipeline. + context (list[str]): The final list of context documents passed to the LLM for generating the answer. 
+ """ + + question: str + answer: str + documents: list[Document] + context: list[str] diff --git a/src/rag_pipelines/pipelines/crag.py b/src/rag_pipelines/pipelines/crag.py new file mode 100644 index 0000000000000000000000000000000000000000..e24739c70bbd1ef19f3b56817694c0535bfe9d4f --- /dev/null +++ b/src/rag_pipelines/pipelines/crag.py @@ -0,0 +1,172 @@ +import os +from typing import Any, Optional + +import weave +from langchain_community.retrievers import PineconeHybridSearchRetriever +from langchain_core.prompts.chat import ChatPromptTemplate +from langgraph.graph import END, START, StateGraph +from langgraph.graph.state import CompiledStateGraph +from weave.integrations.langchain import WeaveTracer + +from rag_pipelines.llms.groq import ChatGroqGenerator +from rag_pipelines.pipelines.crag_graph_state import CRAGGraphState +from rag_pipelines.query_transformer import QueryTransformer +from rag_pipelines.retrieval_evaluator import DocumentGrader, QueryDecisionMaker +from rag_pipelines.websearch import WebSearch + +# Disable global tracing explicitly +os.environ["WEAVE_TRACE_LANGCHAIN"] = "false" + + +class CorrectiveRAGPipeline(weave.Model): + """A corrective retrieval-augmented generation (RAG) pipeline using Weave for tracing and LangChain components. + + This pipeline integrates document retrieval, relevance evaluation, grading, query transformation, web search, + and LLM-based response generation to implement a corrective RAG system. It utilizes Weave for tracing execution + details and LangChain components for processing. + + Attributes: + retriever (Optional[PineconeHybridSearchRetriever]): The retrieval model used to fetch relevant documents based on a query. + prompt (Optional[ChatPromptTemplate]): The prompt template to generate questions for the LLM. + generator (Optional[ChatGroqGenerator]): The language model used to generate responses. + grader (Optional[DocumentGrader]): Grades documents based on evaluation results. 
+ query_transformer (Optional[QueryTransformer]): Transforms user queries to optimize retrieval. + web_search (Optional[WebSearch]): Performs web search for additional context. + tracing_project_name (str): The name of the Weave project for tracing. + weave_params (Dict[str, Any]): Parameters for initializing Weave. + tracer (Optional[WeaveTracer]): The tracer used to record execution details with Weave. + """ + + retriever: Optional[PineconeHybridSearchRetriever] = None + prompt: Optional[ChatPromptTemplate] = None + generator: Optional[ChatGroqGenerator] = None + grader: Optional[DocumentGrader] = None + query_transformer: Optional[QueryTransformer] = None + web_search: Optional[WebSearch] = None + tracing_project_name: str + weave_params: dict[str, Any] + tracer: Optional[WeaveTracer] = None + + def __init__( + self, + retriever: PineconeHybridSearchRetriever, + prompt: ChatPromptTemplate, + generator: ChatGroqGenerator, + grader: DocumentGrader, + query_transformer: QueryTransformer, + web_search: WebSearch, + tracing_project_name: str = "corrective_rag", + weave_params: Optional[dict[str, Any]] = None, + ): + """Initialize the CorrectiveRAGPipeline. + + Args: + retriever (PineconeHybridSearchRetriever): The retrieval model used to fetch documents for the RAG pipeline. + prompt (ChatPromptTemplate): The prompt template used to create questions for the LLM. + generator (ChatGroqGenerator): The language model used for response generation. + grader (DocumentGrader): Component to grade the relevance of evaluated documents. + query_transformer (QueryTransformer): Component to transform the user query. + web_search (WebSearch): Component to perform web search for additional context. + tracing_project_name (str): The name of the Weave project for tracing. Defaults to "corrective_rag". + weave_params (Dict[str, Any]): Additional parameters for initializing Weave. 
+ """ + if weave_params is None: + weave_params = {} + + super().__init__( + retriever=retriever, + prompt=prompt, + generator=generator, + grader=grader, + query_transformer=query_transformer, + web_search=web_search, + tracing_project_name=tracing_project_name, + weave_params=weave_params, + ) + + self.retriever = retriever + self.prompt = prompt + self.generator = generator + self.grader = grader + self.query_transformer = query_transformer + self.web_search = web_search + self.tracing_project_name = tracing_project_name + self.weave_params = weave_params + + self._initialize_weave(**weave_params) + + def _initialize_weave(self, **weave_params) -> None: + """Initialize Weave with the specified tracing project name. + + Sets up the Weave environment and creates a tracer for monitoring pipeline execution. + + Args: + weave_params (Dict[str, Any]): Additional parameters for configuring Weave. + """ + weave.init(self.tracing_project_name, **weave_params) + self.tracer = WeaveTracer() + + def _build_crag_graph(self) -> CompiledStateGraph: + """Build and compile the corrective RAG workflow graph. + + The graph defines the flow between components like retrieval, grading, query transformation, + web search, and generation. + + Returns: + CompiledStateGraph: The compiled state graph representing the corrective RAG pipeline workflow. 
+ """ + crag_workflow = StateGraph(CRAGGraphState) + + # Define the nodes + crag_workflow.add_node("retrieve", self.retriever) + crag_workflow.add_node("grade_documents", self.grader) + crag_workflow.add_node("generate", self.generator) + crag_workflow.add_node("transform_query", self.query_transformer) + crag_workflow.add_node("web_search_node", self.web_search) + + # Define edges between nodes + crag_workflow.add_edge(START, "retrieve") + crag_workflow.add_edge("retrieve", "grade_documents") + crag_workflow.add_conditional_edges( + "grade_documents", + QueryDecisionMaker(), + { + "transform_query": "transform_query", + "generate": "generate", + }, + ) + crag_workflow.add_edge("transform_query", "web_search_node") + crag_workflow.add_edge("web_search_node", "generate") + crag_workflow.add_edge("generate", END) + + # Compile the graph + crag_pipeline = crag_workflow.compile() + + return crag_pipeline + + @weave.op() + def predict(self, question: str) -> str: + """Execute the corrective RAG pipeline with the given question. + + The pipeline retrieves documents, evaluates and grades their relevance, and generates a final response + using the LLM. + + Args: + question (str): The input question to be answered. + + Returns: + str: The final answer generated by the LLM. + + Example: + ```python + pipeline = CorrectiveRAGPipeline(...) 
+ answer = pipeline.predict("What are the latest AI trends?") + print(answer) + ``` + """ + config = {"callbacks": [self.tracer]} + + crag_graph = self._build_crag_graph() + response = crag_graph.invoke(question, config=config) + + return response diff --git a/src/rag_pipelines/pipelines/crag_graph_state.py b/src/rag_pipelines/pipelines/crag_graph_state.py new file mode 100644 index 0000000000000000000000000000000000000000..f37a33e746b0d5f084e7dea0a566b06df974094e --- /dev/null +++ b/src/rag_pipelines/pipelines/crag_graph_state.py @@ -0,0 +1,17 @@ +from typing_extensions import TypedDict + + +class CRAGGraphState(TypedDict): + """Represents the state of the graph for the Corrective Retrieval-Augmentation-Generation (CRAG) pipeline. + + Attributes: + question (str): The input question for the pipeline. + generation (str): The generated response from the LLM. + web_search (str): Indicates whether a web search is required (e.g., "yes" or "no"). + documents (List[str]): A list of relevant documents retrieved or processed. 
+ """ + + question: str + generation: str + web_search: str + documents: list[str] diff --git a/src/rag_pipelines/pipelines/dspy/dspy_baseline_rag.py b/src/rag_pipelines/pipelines/dspy/dspy_baseline_rag.py new file mode 100644 index 0000000000000000000000000000000000000000..76c7562614293e3a444bcd54c484fb10ca6d74db --- /dev/null +++ b/src/rag_pipelines/pipelines/dspy/dspy_baseline_rag.py @@ -0,0 +1,46 @@ +import argparse + +import dspy +from datasets import load_dataset +from dspy_modules.evaluator import DSPyEvaluator +from dspy_modules.rag import DSPyRAG +from dspy_modules.weaviate_db import WeaviateVectorDB + + +def main(cluster_url, api_key, index_name, model_name, llm_model, llm_api_key): + # Load dataset + earnings_calls_data = load_dataset("lamini/earnings-calls-qa", split="train[:50]") + questions = earnings_calls_data["question"] + + # Split into datasets + [dspy.Example(question=q).with_inputs("question") for q in questions[:20]] + devset = [dspy.Example(question=q).with_inputs("question") for q in questions[20:30]] + [dspy.Example(question=q).with_inputs("question") for q in questions[30:]] + + # Initialize Weaviate VectorDB + weaviate_db = WeaviateVectorDB(cluster_url, api_key, index_name, model_name) + + # Initialize LLM + llm = dspy.LM(llm_model, api_key=llm_api_key, num_retries=120) + dspy.configure(lm=llm) + + # Initialize RAG + rag = DSPyRAG(weaviate_db) + + # Evaluate before compilation + evaluator = DSPyEvaluator() + evaluate = dspy.Evaluate(devset=devset, num_threads=1, display_progress=True, display_table=5) + evaluate(rag, metric=evaluator.llm_metric) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run DSPy-based RAG pipeline") + parser.add_argument("--cluster_url", type=str, required=True, help="Weaviate cluster URL") + parser.add_argument("--api_key", type=str, required=True, help="Weaviate API key") + parser.add_argument("--index_name", type=str, required=True, help="Weaviate index name") + 
parser.add_argument("--model_name", type=str, required=True, help="Embedding model name") + parser.add_argument("--llm_model", type=str, required=True, help="LLM model name") + parser.add_argument("--llm_api_key", type=str, required=True, help="LLM API key") + + args = parser.parse_args() + main(args.cluster_url, args.api_key, args.index_name, args.model_name, args.llm_model, args.llm_api_key) diff --git a/src/rag_pipelines/pipelines/dspy/dspy_bayesian_signature_optimization_rag.py b/src/rag_pipelines/pipelines/dspy/dspy_bayesian_signature_optimization_rag.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e718c718e6402c664f8b05933c05d2d6f52a75 --- /dev/null +++ b/src/rag_pipelines/pipelines/dspy/dspy_bayesian_signature_optimization_rag.py @@ -0,0 +1,124 @@ +import argparse + +import dspy +import weaviate +from datasets import load_dataset +from dspy.evaluate.evaluate import Evaluate +from dspy.primitives.prediction import Prediction +from dspy.teleprompt import BayesianSignatureOptimizer, BootstrapFewShotWithRandomSearch +from langchain_huggingface import HuggingFaceEmbeddings +from langchain_weaviate.vectorstores import WeaviateVectorStore +from weaviate.classes.init import Auth + +# Argument Parser +parser = argparse.ArgumentParser(description="RAG Optimization with DSPy") +parser.add_argument( + "--optimizer", + type=str, + choices=["bootstrap", "bayesian"], + default="bootstrap", + help="Choose the optimization method", +) +args = parser.parse_args() + +# Load dataset +earnings_calls_data = load_dataset("lamini/earnings-calls-qa", split="train[:50]") +questions = earnings_calls_data["question"] + +# Create DSPy datasets +trainset = [dspy.Example(question=q).with_inputs("question") for q in questions[:20]] +devset = [dspy.Example(question=q).with_inputs("question") for q in questions[20:30]] +testset = [dspy.Example(question=q).with_inputs("question") for q in questions[30:]] + +# Embeddings and Weaviate client +embeddings = 
HuggingFaceEmbeddings( + model_name="jinaai/jina-embeddings-v3", + model_kwargs={"device": "cpu", "trust_remote_code": True}, + encode_kwargs={"task": "retrieval.query", "prompt_name": "retrieval.query"}, +) + +weaviate_client = weaviate.connect_to_weaviate_cloud( + cluster_url="https://adrrwus9shkxkuijvazcrq.c0.us-west3.gcp.weaviate.cloud", + auth_credentials=Auth.api_key("J94gHySMWTWxggDDayGrF2ESGo23yOHZ1bUC"), +) +weaviate_db = WeaviateVectorStore( + index_name="LangChain_d73ad6159d514fec887456fa6db11e61", + embedding=embeddings, + client=weaviate_client, + text_key="text", +) + +# Configure LLM +llm = dspy.LM( + "groq/llama-3.3-70b-versatile", + api_key="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", + num_retries=120, +) +dspy.configure(lm=llm) + + +# Define DSPy Module +class GenerateAnswer(dspy.Signature): + context = dspy.InputField(desc="may contain relevant facts") + question = dspy.InputField() + answer = dspy.OutputField(desc="short and precise answer") + + +class RAG(dspy.Module): + def __init__(self): + super().__init__() + self.generate_answer = dspy.ChainOfThought(GenerateAnswer) + + def retrieve(self, question): + results = weaviate_db.similarity_search(query=question) + passages = [res.page_content for res in results] + return Prediction(passages=passages) + + def forward(self, question): + context = self.retrieve(question).passages + prediction = self.generate_answer(context=context, question=question) + return dspy.Prediction(context=context, answer=prediction.answer) + + +# Define LLM Metric +def llm_metric(gold, pred, trace=None): + predicted_answer = pred.answer + context = pred.context + detail = dspy.ChainOfThought(GenerateAnswer)( + context="N/A", assessed_question="Is the answer detailed?", assessed_answer=predicted_answer + ) + faithful = dspy.ChainOfThought(GenerateAnswer)( + context=context, assessed_question="Is it grounded in context?", assessed_answer=predicted_answer + ) + overall = dspy.ChainOfThought(GenerateAnswer)( 
+ context=context, assessed_question=f"Rate the answer: {predicted_answer}", assessed_answer=predicted_answer + ) + total = float(detail.answer) + float(faithful.answer) * 2 + float(overall.answer) + return total / 5.0 + + +# Evaluate before optimization +evaluate = Evaluate(devset=devset, num_threads=1, display_progress=True, display_table=5) +evaluate(RAG(), metric=llm_metric) + +# Select Optimizer +if args.optimizer == "bootstrap": + optimizer = BootstrapFewShotWithRandomSearch( + metric=llm_metric, + max_bootstrapped_demos=4, + max_labeled_demos=4, + max_rounds=1, + num_candidate_programs=2, + num_threads=2, + ) +else: + optimizer = BayesianSignatureOptimizer( + task_model=dspy.settings.lm, metric=llm_metric, prompt_model=dspy.settings.lm, n=5, verbose=False + ) + +# Compile optimized RAG +optimized_compiled_rag = optimizer.compile(RAG(), trainset=trainset) + +# Evaluate optimized RAG +evaluate = Evaluate(metric=llm_metric, devset=devset, num_threads=1, display_progress=True, display_table=5) +evaluate(optimized_compiled_rag) diff --git a/src/rag_pipelines/pipelines/dspy/dspy_bootstrap_few_shot_optimization_rag.py b/src/rag_pipelines/pipelines/dspy/dspy_bootstrap_few_shot_optimization_rag.py new file mode 100644 index 0000000000000000000000000000000000000000..3b0695718cf77661404b543b624871e94640afa5 --- /dev/null +++ b/src/rag_pipelines/pipelines/dspy/dspy_bootstrap_few_shot_optimization_rag.py @@ -0,0 +1,60 @@ +import argparse + +import dspy +from datasets import load_dataset +from dspy.evaluate.evaluate import Evaluate +from dspy.teleprompt import BootstrapFewShot +from dspy_modules.evaluator import llm_metric +from dspy_modules.rag import RAG +from dspy_modules.weaviate_db import WeaviateVectorDB + + +def main(args): + # Load dataset + earnings_calls_data = load_dataset("lamini/earnings-calls-qa", split="train[:50]") + questions = earnings_calls_data["question"] + + # Split dataset + trainset = [dspy.Example(question=q).with_inputs("question") for q in 
questions[:20]] + devset = [dspy.Example(question=q).with_inputs("question") for q in questions[20:30]] + [dspy.Example(question=q).with_inputs("question") for q in questions[30:]] + + # Initialize Weaviate VectorDB + weaviate_db = WeaviateVectorDB( + cluster_url=args.cluster_url, api_key=args.api_key, index_name=args.index_name, model_name=args.embedding_model + ) + + # Initialize LLM + llm = dspy.LM(args.llm_model, api_key=args.llm_api_key, num_retries=args.num_retries) + dspy.configure(lm=llm) + + # Initialize and evaluate unoptimized RAG + uncompiled_rag = RAG(weaviate_db) + evaluate = Evaluate(devset=devset, num_threads=1, display_progress=True, display_table=5) + evaluate(uncompiled_rag, metric=llm_metric) + + # Optimize RAG using BootstrapFewShot + optimizer = BootstrapFewShot(metric=llm_metric) + optimized_compiled_rag = optimizer.compile(uncompiled_rag, trainset=trainset) + + # Evaluate optimized RAG + evaluate = Evaluate(devset=devset, num_threads=1, display_progress=True, display_table=5) + evaluate(optimized_compiled_rag) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="DSPy RAG Optimization Pipeline") + + # Weaviate parameters + parser.add_argument("--cluster_url", type=str, required=True, help="Weaviate cluster URL") + parser.add_argument("--api_key", type=str, required=True, help="Weaviate API key") + parser.add_argument("--index_name", type=str, required=True, help="Weaviate index name") + parser.add_argument("--embedding_model", type=str, default="jinaai/jina-embeddings-v3", help="Embedding model name") + + # LLM parameters + parser.add_argument("--llm_model", type=str, default="groq/llama-3.3-70b-versatile", help="LLM model name") + parser.add_argument("--llm_api_key", type=str, required=True, help="LLM API key") + parser.add_argument("--num_retries", type=int, default=120, help="Number of retries for LLM calls") + + args = parser.parse_args() + main(args) diff --git 
# Adapted from:
# https://github.com/weaviate/recipes/blob/main/integrations/llm-frameworks/dspy/1.Getting-Started-with-RAG-in-DSPy.ipynb

import os

import dspy
import weaviate
from datasets import load_dataset
from dspy.evaluate.evaluate import Evaluate
from dspy.primitives.prediction import Prediction
from dspy.teleprompt import COPRO
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_weaviate.vectorstores import WeaviateVectorStore
from weaviate.classes.init import Auth

# Load the first 50 rows of the earnings-calls QA dataset.
earnings_calls_data = load_dataset("lamini/earnings-calls-qa", split="train[:50]")
questions = earnings_calls_data["question"]

# Create the DSPy datasets: train (20) / dev (10) / test (remaining 20).
trainset = [dspy.Example(question=q).with_inputs("question") for q in questions[:20]]
devset = [dspy.Example(question=q).with_inputs("question") for q in questions[20:30]]
testset = [dspy.Example(question=q).with_inputs("question") for q in questions[30:]]

# Embedding model used for retrieval queries.
model_name = "jinaai/jina-embeddings-v3"
task = "retrieval.query"
model_kwargs = {"device": "cpu", "trust_remote_code": True}
encode_kwargs = {"task": task, "prompt_name": task}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)

# SECURITY FIX: credentials were previously hard-coded in source (a Weaviate
# API key and a Groq `gsk_…` key). They are now read from the environment.
# The committed keys must be treated as compromised and rotated.
weaviate_client = weaviate.connect_to_weaviate_cloud(
    cluster_url=os.environ["WEAVIATE_CLUSTER_URL"],
    auth_credentials=Auth.api_key(os.environ["WEAVIATE_API_KEY"]),
)
weaviate_db = WeaviateVectorStore(
    index_name="LangChain_d73ad6159d514fec887456fa6db11e61",
    embedding=embeddings,
    client=weaviate_client,
    text_key="text",
)


llm = dspy.LM(
    "groq/llama-3.3-70b-versatile",
    api_key=os.environ["GROQ_API_KEY"],
    num_retries=120,
)
dspy.configure(lm=llm)


class GenerateAnswer(dspy.Signature):
    """Answer questions with short factoid answers."""

    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    answer = dspy.OutputField(desc="short and precise answer")


class RAG(dspy.Module):
    """RAG program: retrieve passages from Weaviate, then answer with chain-of-thought."""

    def __init__(self):
        super().__init__()
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

    # Custom retrieval method so the LangChain Weaviate integration (and thus
    # custom SentenceTransformers embeddings) can be used instead of a DSPy RM.
    def retrieve(self, question):
        results = weaviate_db.similarity_search(query=question)
        passages = [res.page_content for res in results]
        return Prediction(passages=passages)

    def forward(self, question):
        context = self.retrieve(question).passages
        prediction = self.generate_answer(context=context, question=question)
        return dspy.Prediction(context=context, answer=prediction.answer)


# LLM-as-a-judge evaluation metric (taken from the Weaviate recipe).


class Assess(dspy.Signature):
    """Assess the quality of an answer to a question."""

    context = dspy.InputField(desc="The context for answering the question.")
    assessed_question = dspy.InputField(desc="The evaluation criterion.")
    assessed_answer = dspy.InputField(desc="The answer to the question.")
    assessment_answer = dspy.OutputField(desc="A rating between 1 and 5. Only output the rating and nothing else.")


def llm_metric(gold, pred, trace=None):
    """Score a predicted answer with an LLM judge.

    Three criteria are each rated 1-5: detail, faithfulness to the retrieved
    context (double-weighted), and overall answer quality. The weighted sum is
    divided by 5.

    Args:
        gold: Example carrying the reference `question`.
        pred: Prediction carrying `answer` and `context`.
        trace: Unused; accepted for DSPy metric-signature compatibility.

    Returns:
        float: Weighted judge score.
    """
    predicted_answer = pred.answer
    context = pred.context
    question = gold.question

    print(f"Test Question: {question}")
    print(f"Predicted Answer: {predicted_answer}")

    detail = "Is the assessed answer detailed?"
    faithful = (
        "Is the assessed text grounded in the context? Say no if it includes significant facts not in the context."
    )
    overall = f"Please rate how well this answer answers the question, `{question}` based on the context.\n `{predicted_answer}`"

    detail = dspy.ChainOfThought(Assess)(context="N/A", assessed_question=detail, assessed_answer=predicted_answer)
    faithful = dspy.ChainOfThought(Assess)(
        context=context, assessed_question=faithful, assessed_answer=predicted_answer
    )
    overall = dspy.ChainOfThought(Assess)(context=context, assessed_question=overall, assessed_answer=predicted_answer)

    print(f"Faithful: {faithful.assessment_answer}")
    print(f"Detail: {detail.assessment_answer}")
    print(f"Overall: {overall.assessment_answer}")

    total = float(detail.assessment_answer) + float(faithful.assessment_answer) * 2 + float(overall.assessment_answer)

    return total / 5.0


# Evaluate the RAG program before it is compiled.
evaluate = Evaluate(devset=devset, num_threads=1, display_progress=True, display_table=5)
evaluate(RAG(), metric=llm_metric)


# Optimize the RAG program's instructions with COPRO.
optimizer = COPRO(
    prompt_model=dspy.settings.lm,
    metric=llm_metric,
    breadth=3,
    depth=2,
    init_temperature=0.25,
    verbose=False,
)


optimized_compiled_rag = optimizer.compile(
    RAG(),
    trainset=trainset,
    eval_kwargs={"num_threads": 1, "display_progress": True, "display_table": 0},
)

# Evaluate the optimized RAG program.
evaluate = Evaluate(
    metric=llm_metric,
    devset=devset,
    num_threads=1,
    display_progress=True,
    display_table=5,
)
evaluate(optimized_compiled_rag)
import argparse

import dspy
from datasets import load_dataset

from rag_pipelines.dspy.dspy_evaluator import DSPyEvaluator
from rag_pipelines.dspy.dspy_rag import DSPyRAG
from rag_pipelines.vectordb.weaviate import WeaviateVectorDB


def main(cluster_url, api_key, index_name, model_name, llm_model, llm_api_key):
    """Run the baseline (unoptimized) DSPy RAG pipeline end to end.

    Steps:
        1. Load the earnings-calls Q&A dataset (first 50 samples).
        2. Build the development split used for evaluation.
        3. Connect to the Weaviate vector database.
        4. Configure the LLM through DSPy.
        5. Evaluate the RAG pipeline with the LLM-as-judge metric.

    Args:
        cluster_url (str): URL of the Weaviate vector database cluster.
        api_key (str): API key for authenticating access to Weaviate.
        index_name (str): Name of the Weaviate index for storing vectors.
        model_name (str): Embedding model name for vectorization.
        llm_model (str): Name of the LLM used for inference.
        llm_api_key (str): API key for accessing the LLM.
    """
    dataset = load_dataset("lamini/earnings-calls-qa", split="train[:50]")
    all_questions = dataset["question"]

    # Questions 20-29 form the development split; the first 20 (train) and the
    # remainder (test) are reserved for the optimization scripts.
    devset = [dspy.Example(question=item).with_inputs("question") for item in all_questions[20:30]]

    # Vector database used by the RAG module for retrieval.
    vector_db = WeaviateVectorDB(
        cluster_url=cluster_url,
        api_key=api_key,
        index_name=index_name,
        model_name=model_name,
    )

    # Register the LLM globally for all DSPy modules.
    language_model = dspy.LM(llm_model, api_key=llm_api_key, num_retries=120)
    dspy.configure(lm=language_model)

    pipeline = DSPyRAG(vector_db)
    judge = DSPyEvaluator()

    # Evaluate the uncompiled pipeline on the dev split.
    runner = dspy.Evaluate(devset=devset, num_threads=1, display_progress=True, display_table=5)
    runner(pipeline, metric=judge.llm_metric)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run DSPy-based RAG pipeline")

    # Weaviate configuration parameters
    parser.add_argument("--cluster_url", type=str, required=True, help="Weaviate cluster URL.")
    parser.add_argument("--api_key", type=str, required=True, help="Weaviate API key.")
    parser.add_argument("--index_name", type=str, required=True, help="Weaviate index name.")
    parser.add_argument("--model_name", type=str, required=True, help="Embedding model name for vectorization.")

    # LLM configuration parameters
    parser.add_argument("--llm_model", type=str, required=True, help="LLM model name.")
    parser.add_argument("--llm_api_key", type=str, required=True, help="API key for LLM access.")

    args = parser.parse_args()
    main(args.cluster_url, args.api_key, args.index_name, args.model_name, args.llm_model, args.llm_api_key)
"--dataset_name", type=str, default="lamini/earnings-calls-qa", help="Name of the dataset to use." + ) + parser.add_argument("--dataset_size", type=int, default=50, help="Number of examples to load from the dataset.") + + # Weaviate Configuration + parser.add_argument("--weaviate_url", type=str, required=True, help="Weaviate cloud cluster URL.") + parser.add_argument("--weaviate_api_key", type=str, required=True, help="API key for Weaviate.") + parser.add_argument("--index_name", type=str, required=True, help="Index name in Weaviate.") + parser.add_argument( + "--embedding_model", type=str, default="jinaai/jina-embeddings-v3", help="Embedding model for Weaviate." + ) + + # LLM Configuration + parser.add_argument("--llm_model", type=str, default="groq/llama-3.3-70b-versatile", help="LLM model to use.") + parser.add_argument("--llm_api_key", type=str, required=True, help="API key for LLM.") + + # Optimization Method + parser.add_argument( + "--optimizer", + type=str, + choices=["bootstrap", "bayesian"], + default="bootstrap", + help="Choose the optimization method.", + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + + # Load dataset + dataset = load_dataset(args.dataset_name, split=f"train[:{args.dataset_size}]") + questions = dataset["question"] + + # Create DSPy datasets + trainset = [dspy.Example(question=q).with_inputs("question") for q in questions[:20]] + devset = [dspy.Example(question=q).with_inputs("question") for q in questions[20:30]] + testset = [dspy.Example(question=q).with_inputs("question") for q in questions[30:]] + + # Initialize embeddings + model_kwargs = {"device": "cpu", "trust_remote_code": True} + encode_kwargs = {"task": "retrieval.query", "prompt_name": "retrieval.query"} + embeddings = HuggingFaceEmbeddings( + model_name=args.embedding_model, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs + ) + + # Connect to Weaviate + weaviate_client = weaviate.connect_to_weaviate_cloud( + cluster_url=args.weaviate_url, + 
auth_credentials=Auth.api_key(args.weaviate_api_key), + ) + WeaviateVectorStore( + index_name=args.index_name, + embedding=embeddings, + client=weaviate_client, + text_key="text", + ) + + # Configure LLM + llm = dspy.LM(args.llm_model, api_key=args.llm_api_key, num_retries=120) + dspy.configure(lm=llm) + + # Evaluate before optimization + evaluate = Evaluate(devset=devset, num_threads=1, display_progress=True, display_table=5) + evaluate(DSPyRAG(), metric=DSPyEvaluator.llm_metric()) + + # Select Optimizer + if args.optimizer == "bootstrap": + optimizer = BootstrapFewShotWithRandomSearch( + metric=DSPyEvaluator.llm_metric(), + max_bootstrapped_demos=4, + max_labeled_demos=4, + max_rounds=1, + num_candidate_programs=2, + num_threads=2, + ) + else: + optimizer = BayesianSignatureOptimizer( + task_model=dspy.settings.lm, + metric=DSPyEvaluator.llm_metric(), + prompt_model=dspy.settings.lm, + n=5, + verbose=False, + ) + + # Compile optimized RAG + optimized_compiled_rag = optimizer.compile(DSPyRAG(), testset=testset, trainset=trainset) + + # Evaluate optimized RAG + evaluate = Evaluate( + metric=DSPyEvaluator.llm_metric(), devset=devset, num_threads=1, display_progress=True, display_table=5 + ) + evaluate(optimized_compiled_rag) + + +if __name__ == "__main__": + main() diff --git a/src/rag_pipelines/pipelines/dspy_bootstrap_few_shot_optimization_rag.py b/src/rag_pipelines/pipelines/dspy_bootstrap_few_shot_optimization_rag.py new file mode 100644 index 0000000000000000000000000000000000000000..df0ff9187e7375a3ed45fd6967e07e92dc92b5ec --- /dev/null +++ b/src/rag_pipelines/pipelines/dspy_bootstrap_few_shot_optimization_rag.py @@ -0,0 +1,91 @@ +import argparse + +import dspy +from datasets import load_dataset +from dspy.evaluate.evaluate import Evaluate +from dspy.teleprompt import BootstrapFewShot + +from rag_pipelines.dspy.dspy_evaluator import DSPyEvaluator +from rag_pipelines.dspy.dspy_rag import DSPyRAG +from rag_pipelines.vectordb.weaviate import WeaviateVectorDB + + 
def main(args):
    """Run the DSPy RAG optimization pipeline with BootstrapFewShot.

    This function:
        1. Loads the earnings calls dataset.
        2. Builds the training and development splits.
        3. Initializes a Weaviate vector database and an LLM.
        4. Evaluates the unoptimized RAG pipeline.
        5. Optimizes the pipeline with BootstrapFewShot and re-evaluates it.

    Args:
        args (argparse.Namespace): Command-line arguments for configuring the pipeline.
    """
    # Load the dataset (Earnings Calls QA dataset).
    earnings_calls_data = load_dataset("lamini/earnings-calls-qa", split="train[:50]")
    questions = earnings_calls_data["question"]

    # Split the dataset: training (20) and development (10) examples.
    # (The previous version also built the unused test split as a bare
    # expression whose result was discarded; that dead code is removed.)
    trainset = [dspy.Example(question=q).with_inputs("question") for q in questions[:20]]
    devset = [dspy.Example(question=q).with_inputs("question") for q in questions[20:30]]

    # Initialize Weaviate VectorDB for storing and retrieving embeddings.
    weaviate_db = WeaviateVectorDB(
        cluster_url=args.cluster_url,  # URL of the Weaviate cluster
        api_key=args.api_key,  # API key for authentication
        index_name=args.index_name,  # Name of the Weaviate index
        model_name=args.embedding_model,  # Embedding model for vector storage
    )

    # Initialize LLM with DSPy and register it globally.
    llm = dspy.LM(args.llm_model, api_key=args.llm_api_key, num_retries=args.num_retries)
    dspy.configure(lm=llm)

    # Initialize the unoptimized RAG pipeline.
    uncompiled_rag = DSPyRAG(weaviate_db)

    # FIX: pass the bound metric method (as the baseline script does);
    # DSPyEvaluator.llm_metric() invoked the unbound method with no arguments.
    evaluator = DSPyEvaluator()

    # Evaluate the unoptimized RAG pipeline on the development set.
    evaluate = Evaluate(devset=devset, num_threads=1, display_progress=True, display_table=5)
    evaluate(uncompiled_rag, metric=evaluator.llm_metric)

    # Optimize the RAG pipeline using BootstrapFewShot.
    optimizer = BootstrapFewShot(metric=evaluator.llm_metric)

    # Compile an optimized version of the RAG model using the training set.
    optimized_compiled_rag = optimizer.compile(uncompiled_rag, trainset=trainset)

    # Evaluate the optimized RAG pipeline.
    evaluate = Evaluate(devset=devset, num_threads=1, display_progress=True, display_table=5)
    evaluate(optimized_compiled_rag)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="DSPy RAG Optimization Pipeline")

    # Weaviate parameters (for vector storage and retrieval)
    parser.add_argument("--cluster_url", type=str, required=True, help="Weaviate cluster URL.")
    parser.add_argument("--api_key", type=str, required=True, help="Weaviate API key.")
    parser.add_argument("--index_name", type=str, required=True, help="Weaviate index name.")
    parser.add_argument(
        "--embedding_model",
        type=str,
        default="jinaai/jina-embeddings-v3",
        help="Embedding model used for document vectorization.",
    )

    # LLM parameters (for DSPy-based language model inference)
    parser.add_argument("--llm_model", type=str, default="groq/llama-3.3-70b-versatile", help="LLM model name.")
    parser.add_argument("--llm_api_key", type=str, required=True, help="API key for accessing the LLM service.")
    parser.add_argument("--num_retries", type=int, default=120, help="Number of retries for LLM API calls.")

    # Parse command-line arguments and execute the pipeline.
    args = parser.parse_args()
    main(args)
import argparse

import dspy
from datasets import load_dataset
from dspy.evaluate.evaluate import Evaluate
from dspy.teleprompt import BootstrapFewShotWithRandomSearch

from rag_pipelines.dspy.dspy_evaluator import DSPyEvaluator
from rag_pipelines.dspy.dspy_rag import DSPyRAG
from rag_pipelines.vectordb.weaviate import WeaviateVectorDB


def main(args):
    """Run the DSPy RAG optimization pipeline with BootstrapFewShotWithRandomSearch.

    Loads a dataset, initializes a Weaviate vector database and an LLM,
    evaluates the unoptimized RAG pipeline, optimizes it, and evaluates
    the optimized pipeline.

    Args:
        args (argparse.Namespace): Command-line arguments for configuring the pipeline.
    """
    # Load dataset (Earnings Calls QA).
    earnings_calls_data = load_dataset("lamini/earnings-calls-qa", split="train[:50]")
    questions = earnings_calls_data["question"]

    # Split dataset into training (20) and development (10) examples.
    # (The unused test split was previously built as a discarded bare
    # expression; that dead code is removed.)
    trainset = [dspy.Example(question=q).with_inputs("question") for q in questions[:20]]
    devset = [dspy.Example(question=q).with_inputs("question") for q in questions[20:30]]

    # Initialize Weaviate Vector Database.
    weaviate_db = WeaviateVectorDB(
        cluster_url=args.cluster_url,
        api_key=args.api_key,
        index_name=args.index_name,
        model_name=args.embedding_model,
    )

    # Initialize the LLM and register it globally with DSPy.
    llm = dspy.LM(args.llm_model, api_key=args.llm_api_key, num_retries=args.num_retries)
    dspy.configure(lm=llm)

    # Initialize the unoptimized RAG pipeline.
    uncompiled_rag = DSPyRAG(weaviate_db)

    # FIX: pass the bound metric method; DSPyEvaluator.llm_metric() invoked
    # the unbound method with no arguments instead of supplying a callable.
    evaluator = DSPyEvaluator()

    # Evaluate the unoptimized RAG model.
    evaluate = Evaluate(devset=devset, num_threads=1, display_progress=True, display_table=5)
    evaluate(uncompiled_rag, metric=evaluator.llm_metric)

    # Optimize RAG using BootstrapFewShotWithRandomSearch.
    # FIX: `num_threads` was passed twice, which is a SyntaxError
    # ("keyword argument repeated") — the file could not even be imported.
    optimizer = BootstrapFewShotWithRandomSearch(
        metric=evaluator.llm_metric,
        max_bootstrapped_demos=args.max_bootstrapped_demos,
        max_labeled_demos=args.max_labeled_demos,
        max_rounds=args.max_rounds,
        num_candidate_programs=args.num_candidate_programs,
        num_threads=args.num_threads,
    )

    # Compile an optimized version of the RAG model.
    optimized_compiled_rag = optimizer.compile(uncompiled_rag, trainset=trainset)

    # Evaluate the optimized RAG model.
    evaluate = Evaluate(devset=devset, num_threads=1, display_progress=True, display_table=5)
    evaluate(optimized_compiled_rag)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="DSPy RAG Optimization Pipeline")

    # Weaviate parameters
    parser.add_argument("--cluster_url", type=str, required=True, help="Weaviate cluster URL.")
    parser.add_argument("--api_key", type=str, required=True, help="Weaviate API key.")
    parser.add_argument("--index_name", type=str, required=True, help="Weaviate index name.")
    parser.add_argument(
        "--embedding_model",
        type=str,
        default="jinaai/jina-embeddings-v3",
        help="Embedding model to use for vector retrieval.",
    )

    # LLM parameters
    parser.add_argument("--llm_model", type=str, default="groq/llama-3.3-70b-versatile", help="LLM model name.")
    parser.add_argument("--llm_api_key", type=str, required=True, help="API key for accessing the LLM.")
    parser.add_argument("--num_retries", type=int, default=120, help="Number of retries for LLM calls.")

    # Optimization parameters
    parser.add_argument("--max_bootstrapped_demos", type=int, default=4, help="Max bootstrapped demonstrations.")
    parser.add_argument("--max_labeled_demos", type=int, default=4, help="Max labeled demonstrations.")
    parser.add_argument("--max_rounds", type=int, default=1, help="Max optimization rounds.")
    parser.add_argument("--num_candidate_programs", type=int, default=2, help="Number of candidate programs.")
    parser.add_argument("--num_threads", type=int, default=2, help="Number of threads for optimization.")

    # Parse arguments and run the main function.
    args = parser.parse_args()
    main(args)
import argparse

import dspy
import weaviate
from datasets import load_dataset
from dspy.evaluate.evaluate import Evaluate
from dspy.teleprompt import COPRO
from langchain_huggingface import HuggingFaceEmbeddings
from weaviate.classes.init import Auth

from rag_pipelines.dspy.dspy_evaluator import DSPyEvaluator
from rag_pipelines.dspy.dspy_rag import DSPyRAG

# NOTE(review): sibling scripts import WeaviateVectorStore from
# langchain_weaviate.vectorstores; confirm this module re-exports it.
from rag_pipelines.vectordb.weaviate import WeaviateVectorStore


def parse_args():
    """Parse command-line arguments for the DSPy RAG pipeline with Weaviate and LLM evaluation.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(description="Run DSPy RAG pipeline with Weaviate and LLM evaluation.")

    # Dataset Arguments
    parser.add_argument(
        "--dataset_name", type=str, default="lamini/earnings-calls-qa", help="Name of the dataset to use."
    )
    parser.add_argument("--dataset_size", type=int, default=50, help="Number of examples to load from the dataset.")

    # Weaviate Configuration
    parser.add_argument("--weaviate_url", type=str, required=True, help="Weaviate cloud cluster URL.")
    parser.add_argument("--weaviate_api_key", type=str, required=True, help="API key for Weaviate.")
    parser.add_argument("--index_name", type=str, required=True, help="Index name in Weaviate.")
    parser.add_argument(
        "--embedding_model", type=str, default="jinaai/jina-embeddings-v3", help="Embedding model for Weaviate."
    )

    # LLM Configuration
    parser.add_argument("--llm_model", type=str, default="groq/llama-3.3-70b-versatile", help="LLM model to use.")
    parser.add_argument("--llm_api_key", type=str, required=True, help="API key for LLM.")

    return parser.parse_args()


def main():
    """Run the DSPy RAG pipeline: load data, connect to Weaviate, evaluate, and optimize with COPRO."""
    args = parse_args()

    # Load dataset from Hugging Face and extract questions.
    dataset = load_dataset(args.dataset_name, split=f"train[:{args.dataset_size}]")
    questions = dataset["question"]

    # Create DSPy datasets: train (20) and dev (10). (The unused test split
    # was previously built as a discarded bare expression; removed.)
    trainset = [dspy.Example(question=q).with_inputs("question") for q in questions[:20]]
    devset = [dspy.Example(question=q).with_inputs("question") for q in questions[20:30]]

    # Initialize HuggingFace embeddings for retrieval queries.
    model_kwargs = {"device": "cpu", "trust_remote_code": True}
    encode_kwargs = {"task": "retrieval.query", "prompt_name": "retrieval.query"}
    embeddings = HuggingFaceEmbeddings(
        model_name=args.embedding_model, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
    )

    # Connect to Weaviate using the provided URL and API key.
    weaviate_client = weaviate.connect_to_weaviate_cloud(
        cluster_url=args.weaviate_url,
        auth_credentials=Auth.api_key(args.weaviate_api_key),
    )

    # FIX: the vector store was previously constructed and discarded, and
    # DSPyRAG() was built without a retriever (the baseline script passes one).
    weaviate_db = WeaviateVectorStore(
        index_name=args.index_name,
        embedding=embeddings,
        client=weaviate_client,
        text_key="text",
    )

    # Initialize the LLM and register it globally with DSPy.
    llm = dspy.LM(args.llm_model, api_key=args.llm_api_key, num_retries=120)
    dspy.configure(lm=llm)

    # FIX: pass the bound metric method; DSPyEvaluator.llm_metric() invoked
    # the unbound method with no arguments instead of supplying a callable.
    evaluator = DSPyEvaluator()

    # Evaluate the initial RAG pipeline.
    evaluate = Evaluate(devset=devset, num_threads=1, display_progress=True, display_table=5)
    evaluate(DSPyRAG(weaviate_db), metric=evaluator.llm_metric)

    # Optimize the RAG model's instructions using COPRO.
    optimizer = COPRO(
        prompt_model=dspy.settings.lm,
        metric=evaluator.llm_metric,
        breadth=3,
        depth=2,
        init_temperature=0.25,
        verbose=False,
    )

    # Compile the optimized RAG model with the training set.
    optimized_compiled_rag = optimizer.compile(
        DSPyRAG(weaviate_db),
        trainset=trainset,
        eval_kwargs={"num_threads": 1, "display_progress": True, "display_table": 0},
    )

    # Evaluate the optimized model on the development set.
    evaluate = Evaluate(
        metric=evaluator.llm_metric,
        devset=devset,
        num_threads=1,
        display_progress=True,
        display_table=5,
    )
    evaluate(optimized_compiled_rag)


if __name__ == "__main__":
    main()


# --- src/rag_pipelines/pipelines/dspy_rag.py ---

from typing import Any

import weave
from dspy import LM, Module


class DSPyRAGPipeline(weave.Model):
    """A Retrieval-Augmented Generation (RAG) model wrapper using DSPy.

    Attributes:
        llm (LM): The language model used for generating predictions.
        rag_module (Module): The module used for retrieval tasks.
    """

    llm: LM
    rag_module: Module

    def __init__(self, llm: LM, rag_module: Module) -> None:
        """Initialize the DSPyRAG model.

        Args:
            llm (LM): The language model to be used.
            rag_module (Module): The module to be used for retrieval tasks.
        """
        super().__init__(llm=llm, rag_module=rag_module)

        self.llm = llm
        self.rag_module = rag_module

        # Register the LM globally so DSPy modules invoked below use it.
        dspy.configure(lm=llm)

    @weave.op()
    def predict(self, input: str) -> dict[str, Any]:
        """Predict the answer to a given question using the RAG approach.

        Args:
            input (str): The question to be answered.

        Returns:
            dict[str, Any]: A dictionary containing:
                - "output" (str): The predicted answer to the question.
                - "retrieval_context" (Any): The retrieval context produced by
                  the RAG module.
        """
        # FIX (docs): the docstring previously promised "answer"/"context"
        # keys, but the method returns "output"/"retrieval_context".
        prediction = self.rag_module(input)

        return {"output": prediction.answer, "retrieval_context": prediction.retrieval_context}
+ """ + prediction = self.rag_module(input) + + return {"output": prediction.answer, "retrieval_context": prediction.retrieval_context} diff --git a/src/rag_pipelines/pipelines/dspy_rag_module.py b/src/rag_pipelines/pipelines/dspy_rag_module.py new file mode 100644 index 0000000000000000000000000000000000000000..a4e3479c24d1d4fce1ae03b4334d0c775327c647 --- /dev/null +++ b/src/rag_pipelines/pipelines/dspy_rag_module.py @@ -0,0 +1,39 @@ +from typing import Any + +from dspy import ChainOfThought, Module, Prediction + +from rag_pipelines.evaluation import retrieval +from rag_pipelines.prompts import GenerateAnswerFromContext + + +class RAG(Module): + """Retrieval-Augmented Generation (RAG) module that retrieves context based on a question and generates an answer using that context.""" + + def __init__(self, retriever: Any): + """Initialize the RAG module. + + Args: + retriever (Any): An object that has a `question` method returning + a `passages` attribute. Typically, this would be a retriever like + a Milvus Retriever. + """ + super().__init__() + self.retrieve = retriever + self.generate_answer = ChainOfThought(GenerateAnswerFromContext) + + def forward(self, question: str) -> Prediction: + """Process a question by retrieving context and generating an answer. + + Args: + question (str): The question to be answered. + + Returns: + Prediction: A Prediction object containing the context and the generated answer. + """ + # Use the retriever to get context for the question. + context = self.retrieve(question).passages + # Generate an answer using the retrieved context and the question. + prediction = self.generate_answer(context=context, question=question) + # Return a Prediction object with the context and answer. 
+ + return Prediction(retrieval_context=[item.long_text for item in context], answer=prediction.answer) diff --git a/src/rag_pipelines/pipelines/rag.py b/src/rag_pipelines/pipelines/rag.py new file mode 100644 index 0000000000000000000000000000000000000000..e6365c6f84f1a8c0c3c96961ad2f35fef857765b --- /dev/null +++ b/src/rag_pipelines/pipelines/rag.py @@ -0,0 +1,146 @@ +import os +from typing import Any, Optional + +import weave +from langchain_community.retrievers import PineconeHybridSearchRetriever +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts.chat import ChatPromptTemplate +from langchain_core.runnables import RunnablePassthrough +from langchain_groq import ChatGroq +from weave import Model +from weave.integrations.langchain import WeaveTracer + +# Disable global tracing explicitly +os.environ["WEAVE_TRACE_LANGCHAIN"] = "false" + + +class RAGPipeline(Model): + """A hybrid retrieval-augmented generation (RAG) pipeline using Weave for tracing and LangChain components. + + This pipeline integrates a retriever, prompt template, and language model (LLM) to implement a retrieval-augmented + generation system, where the LLM generates answers based on both retrieved documents and a prompt template. + Weave is used for tracing to monitor the pipeline's execution. + + Attributes: + retriever: The retrieval model used to fetch relevant documents based on a query. + prompt: The prompt template to generate questions for the LLM. + llm: The language model used to generate responses. + tracer: The tracer used to record the execution details with Weave. + tracing_project_name: The name of the Weave project for tracing. 
+ """ + + retriever: Optional[PineconeHybridSearchRetriever] = None + prompt: Optional[ChatPromptTemplate] = None + llm: Optional[ChatGroq] = None + tracing_project_name: str + weave_params: dict[str, Any] + tracer: Optional[WeaveTracer] = None + + def __init__(self, retriever, prompt, llm, tracing_project_name="hybrid_rag", weave_params=None): + """Initialize the HybridRAGPipeline. + + This constructor sets up the retriever, prompt, LLM, and integrates Weave tracing if specified. + + Args: + retriever: The retrieval model used to fetch documents for the RAG pipeline. + prompt: The prompt template used to create questions for the LLM. + llm: The language model used for response generation based on retrieved documents and prompt. + tracing_project_name (str): The name of the Weave project for tracing. Defaults to "hybrid_rag". + weave_params (dict): Additional parameters for initializing Weave. This can include configuration + details or authentication settings for the Weave service. + """ + super().__init__( + retriever=retriever, + prompt=prompt, + llm=llm, + tracing_project_name=tracing_project_name, + weave_params=weave_params, + ) + + if weave_params is None: + weave_params = {} + + self.retriever = retriever + self.prompt = prompt + self.llm = llm + self.tracing_project_name = tracing_project_name + + # Initialize Weave tracing if parameters are provided, otherwise default initialization. + if weave_params: + self._initialize_weave(**weave_params) + else: + self._initialize_weave() + + def _initialize_weave(self, **weave_params): + """Initialize Weave with the specified tracing project name. + + This method sets up the Weave environment and creates an instance of the WeaveTracer. + The tracer records the execution of each step in the RAG pipeline for monitoring and debugging purposes. 
+ """ + # Initialize the Weave project + weave.init(self.tracing_project_name, **weave_params) + # Set up the tracer for tracking pipeline execution + self.tracer = WeaveTracer() + + @weave.op() + def predict(self, question: str) -> str: + """Execute the Hybrid RAG pipeline with the given question. + + This method orchestrates the entire RAG pipeline. It first retrieves documents using the retriever, + formats them, generates a question using the prompt template, and then processes the final response + using the LLM. The process is traced using Weave for debugging and monitoring. + + Args: + question (str): The input question to be answered by the pipeline. + + Returns: + str: The answer generated by the LLM based on the retrieved documents and the question prompt. + """ + # Configuration for trace callbacks to record the execution process + config = {"callbacks": [self.tracer]} + + # Set up the RAG pipeline chain with document retrieval, formatting, prompting, LLM, and output parsing + rag_chain = ( + { + "context": self.retriever | self.format_docs, + "question": RunnablePassthrough(), + } + | self.prompt + | self.llm + | StrOutputParser() + ) + + # Invoke the pipeline with the specified question and configuration + return rag_chain.invoke(question, config=config) + + def format_docs(self, docs): + """Format retrieved documents into a string for input to the LLM. + + The documents are formatted with information such as filing date, accession number, summary, and image + descriptions. + This string will be passed as the context for the LLM to generate a response. + + Args: + docs (list): A list of document objects that have been retrieved based on the input question. + + Returns: + str: A formatted string of document contents, joined by newline characters. 
+ """ + context = "" + for doc in docs: + date = doc.metadata["filing_date"] + accession_no = doc.metadata["accession_no"] + summary = doc.metadata["summary"] + image_descriptions = doc.metadata["image_descriptions"] + context += ( + f"""# Report {accession_no} filed on {date}:\n\n## An excerpt from the report""" + f"""\n\n{doc.page_content}\n\n""" + ) + if len(image_descriptions) > 0: + context += f"""## Image descriptions\n\n{image_descriptions}\n\n""" + context += ( + f"""## Summary of the report\n\nHere's a summary of the report along with the some """ + f"""important keywords and phrases present in the report:\n\n{summary}\n\n""" + ) + + return context diff --git a/src/rag_pipelines/pipelines/self_rag.py b/src/rag_pipelines/pipelines/self_rag.py new file mode 100644 index 0000000000000000000000000000000000000000..a355d7091a41b89a17b127695478c21b441b0730 --- /dev/null +++ b/src/rag_pipelines/pipelines/self_rag.py @@ -0,0 +1,132 @@ +import os +from typing import Any, Optional + +import weave +from langgraph.graph import END, START, StateGraph +from langgraph.graph.state import CompiledStateGraph +from weave import Model +from weave.integrations.langchain import WeaveTracer + +from rag_pipelines.llms.groq import ChatGroqGenerator +from rag_pipelines.pipelines.self_rag_graph_state import SelfRAGGraphState +from rag_pipelines.retrieval_evaluation import RetrievalCritic, RetrievalEvaluator, UsefulnessEvaluator +from rag_pipelines.vectordb import MilvusRetriever + +# Disable global tracing explicitly +os.environ["WEAVE_TRACE_LANGCHAIN"] = "false" + + +class SelfRAGPipeline(Model): + """A self-reflective retrieval-augmented generation (Self-RAG) pipeline using Weave for tracing and LangChain components. + + The pipeline implements a workflow that retrieves documents, evaluates retrieval quality, critiques the results, + assesses usefulness, and generates final responses while maintaining traceability through Weave. 
+ + Attributes: + retriever (Optional[MilvusRetriever]): Vector store retriever for document fetching + generator (Optional[ChatGroqGenerator]): LLM for response generation + retrieval_evaluator (Optional[RetrievalEvaluator]): Component for evaluating retrieval quality + retrieval_critic (Optional[RetrievalCritic]): Component for critiquing retrieval results + usefulness_evaluator (Optional[UsefulnessEvaluator]): Component for assessing response usefulness + tracer (Optional[WeaveTracer]): Weave integration for execution tracing + """ + + retriever: Optional[MilvusRetriever] = None + generator: Optional[ChatGroqGenerator] = None + retrieval_evaluator: Optional[RetrievalEvaluator] = None + retrieval_critic: Optional[RetrievalCritic] = None + usefulness_evaluator: Optional[UsefulnessEvaluator] = None + tracer: Optional[WeaveTracer] = None + + def __init__( + self, + retriever: MilvusRetriever, + generator: ChatGroqGenerator, + retrieval_evaluator: RetrievalEvaluator, + retrieval_critic: RetrievalCritic, + usefulness_evaluator: UsefulnessEvaluator, + ) -> None: + """Initialize the Self-RAG pipeline with required components. 
+ + Args: + retriever (MilvusRetriever): Vector store retriever instance + generator (ChatGroqGenerator): LLM instance for response generation + retrieval_evaluator (RetrievalEvaluator): Evaluator for retrieval quality assessment + retrieval_critic (RetrievalCritic): Critic component for retrieval result analysis + usefulness_evaluator (UsefulnessEvaluator): Evaluator for response usefulness assessment + """ + super().__init__( + retriever=retriever, + generator=generator, + retrieval_evaluator=retrieval_evaluator, + retrieval_critic=retrieval_critic, + usefulness_evaluator=usefulness_evaluator, + ) + + self.retriever = retriever + self.generator = generator + self.retrieval_evaluator = retrieval_evaluator + self.retrieval_critic = retrieval_critic + self.usefulness_evaluator = usefulness_evaluator + self.tracer = WeaveTracer() + + def _build_self_rag_graph(self) -> CompiledStateGraph: + """Construct and compile the Self-RAG workflow state graph. + + Builds a LangGraph StateGraph with the following nodes: + - Retrieve: Fetch relevant documents + - Evaluate Retrieval: Assess retrieval quality + - Critic Retrieval: Analyze retrieval results + - Evaluate Usefulness: Determine response utility + - Generate: Produce final response + + Returns: + CompiledStateGraph: Compiled workflow graph ready for execution + """ + self_rag_workflow = StateGraph(SelfRAGGraphState) + + # Define the nodes + self_rag_workflow.add_node("retrieve", self.retriever) + self_rag_workflow.add_node("evaluate_retrieval", self.retrieval_evaluator) + self_rag_workflow.add_node("critic_retrieval", self.retrieval_critic) + self_rag_workflow.add_node("evaluate_usefulness", self.usefulness_evaluator) + self_rag_workflow.add_node("generate", self.generator) + + # Define edges between nodes + self_rag_workflow.add_edge(START, "retrieve") + self_rag_workflow.add_edge("retrieve", "evaluate_retrieval") + self_rag_workflow.add_edge("evaluate_retrieval", "critic_retrieval") + 
self_rag_workflow.add_edge("critic_retrieval", "evaluate_usefulness") + self_rag_workflow.add_edge("evaluate_usefulness", "generate") + self_rag_workflow.add_edge("generate", END) + + # Compile the graph + self_rag_pipeline = self_rag_workflow.compile() + return self_rag_pipeline + + @weave.op() + def predict(self, input: str) -> dict[str, Any]: + """Execute the Self-RAG pipeline to generate a response for the given input. + + Args: + input (str): User query to process through the pipeline + + Returns: + Dict[str, Any]: Result dictionary containing: + - output (str): Generated response text + - retrieval_context (List[str]): List of document excerpts used as context + + Note: + Traces execution through WeaveTracer for observability and debugging + """ + config: dict[str, list[WeaveTracer]] = {"callbacks": [self.tracer]} + + self_rag_graph: CompiledStateGraph = self._build_self_rag_graph() + response: dict[str, Any] = self_rag_graph.invoke({"question": input}, config=config) + + result: dict[str, Any] = { + "output": response["answer"], + "retrieval_context": response["context"], + } + + return result diff --git a/src/rag_pipelines/pipelines/self_rag_graph_state.py b/src/rag_pipelines/pipelines/self_rag_graph_state.py new file mode 100644 index 0000000000000000000000000000000000000000..9853c118e890143cce5c2e1593cce137803b8fe1 --- /dev/null +++ b/src/rag_pipelines/pipelines/self_rag_graph_state.py @@ -0,0 +1,18 @@ +from langchain_core.documents import Document +from typing_extensions import TypedDict + + +class SelfRAGGraphState(TypedDict): + """Represents the state of the graph for the Self-Reflective Retrieval-Augmentation-Generation (Self-RAG) pipeline. + + Attributes: + question (str): The input question for the pipeline. + answer (str): The generated response from the LLM. + documents (list[Document]): A list of LangChain documents that are retrieved and processed through the pipeline. 
+ context (list[str]): The final list of context documents passed to the LLM for generating the answer. + """ + + question: str + answer: str + documents: list[Document] + context: list[str] diff --git a/src/rag_pipelines/prompts/__init__.py b/src/rag_pipelines/prompts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..82d386b778b8387a763e97be2a667ed747f26301 --- /dev/null +++ b/src/rag_pipelines/prompts/__init__.py @@ -0,0 +1,22 @@ +from rag_pipelines.prompts.dspy_rag import GenerateAnswerFromContext +from rag_pipelines.prompts.financial_summary import FinancialSummaryKeywordsPrompt, FinancialSummaryPrompt +from rag_pipelines.prompts.rag_prompt import STRUCTURED_RAG_PROMPT, RAGResponseModel +from rag_pipelines.prompts.retrieval_critic import RETRIEVAL_CRITIC_PROMPT, RetrievalCriticResult +from rag_pipelines.prompts.retrieval_evalution import RETRIEVAL_EVALUATION_PROMPT, RetrievalEvaluationResult +from rag_pipelines.prompts.summarize_answers import SummarizeAnswersPrompt +from rag_pipelines.prompts.usefulness_evaluator import USEFULNESS_EVALUATOR_PROMPT, UsefulnessEvaluatorResult + +__all__ = [ + "RETRIEVAL_CRITIC_PROMPT", + "RETRIEVAL_EVALUATION_PROMPT", + "STRUCTURED_RAG_PROMPT", + "USEFULNESS_EVALUATOR_PROMPT", + "FinancialSummaryKeywordsPrompt", + "FinancialSummaryPrompt", + "GenerateAnswerFromContext", + "RAGResponseModel", + "RetrievalCriticResult", + "RetrievalEvaluationResult", + "SummarizeAnswersPrompt", + "UsefulnessEvaluatorResult", +] diff --git a/src/rag_pipelines/prompts/__pycache__/__init__.cpython-310.pyc b/src/rag_pipelines/prompts/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bab29dd69398b45839185762109e5013e13b3a8e Binary files /dev/null and b/src/rag_pipelines/prompts/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/rag_pipelines/prompts/__pycache__/dspy_rag.cpython-310.pyc b/src/rag_pipelines/prompts/__pycache__/dspy_rag.cpython-310.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..21e54f7f5ff56402f567dde16a443f02af331db1 Binary files /dev/null and b/src/rag_pipelines/prompts/__pycache__/dspy_rag.cpython-310.pyc differ diff --git a/src/rag_pipelines/prompts/__pycache__/financial_summary.cpython-310.pyc b/src/rag_pipelines/prompts/__pycache__/financial_summary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b497cc28dd2175d7ba82e5d205e0733a18247f41 Binary files /dev/null and b/src/rag_pipelines/prompts/__pycache__/financial_summary.cpython-310.pyc differ diff --git a/src/rag_pipelines/prompts/__pycache__/rag_prompt.cpython-310.pyc b/src/rag_pipelines/prompts/__pycache__/rag_prompt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..953ead661f986bca3aa5da54f6c5fd032097515f Binary files /dev/null and b/src/rag_pipelines/prompts/__pycache__/rag_prompt.cpython-310.pyc differ diff --git a/src/rag_pipelines/prompts/__pycache__/retrieval_critic.cpython-310.pyc b/src/rag_pipelines/prompts/__pycache__/retrieval_critic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..caa3b0d1f38b2a2bd5266b06e549eb73725dac91 Binary files /dev/null and b/src/rag_pipelines/prompts/__pycache__/retrieval_critic.cpython-310.pyc differ diff --git a/src/rag_pipelines/prompts/__pycache__/retrieval_evalution.cpython-310.pyc b/src/rag_pipelines/prompts/__pycache__/retrieval_evalution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15c83701fdb65bbc506cab43c8fe198cbe48254e Binary files /dev/null and b/src/rag_pipelines/prompts/__pycache__/retrieval_evalution.cpython-310.pyc differ diff --git a/src/rag_pipelines/prompts/__pycache__/summarize_answers.cpython-310.pyc b/src/rag_pipelines/prompts/__pycache__/summarize_answers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38c09255f1705170c3d3accb44fd6198586e02a2 Binary files /dev/null and 
b/src/rag_pipelines/prompts/__pycache__/summarize_answers.cpython-310.pyc differ diff --git a/src/rag_pipelines/prompts/__pycache__/usefulness_evaluator.cpython-310.pyc b/src/rag_pipelines/prompts/__pycache__/usefulness_evaluator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45546ee7cf4f2038f5b508a8087a948fc9ee3fe2 Binary files /dev/null and b/src/rag_pipelines/prompts/__pycache__/usefulness_evaluator.cpython-310.pyc differ diff --git a/src/rag_pipelines/prompts/dspy_evaluate.py b/src/rag_pipelines/prompts/dspy_evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..5ab43e89237c23f1e0818368fe197b30b41a655f --- /dev/null +++ b/src/rag_pipelines/prompts/dspy_evaluate.py @@ -0,0 +1,26 @@ +from dspy import InputField, OutputField, Signature + + +class Assess(Signature): + """Assessment Signature for Evaluating Answer Quality. + + This signature defines the input and output fields required for an evaluation module + that assesses the quality of an answer. The module considers the context in which the + answer is provided, the evaluation criteria (question), and the answer itself, and outputs + a rating on a scale from 1 to 5. + + Attributes: + context (InputField): The background or supporting information used to answer the question. + assessed_question (InputField): The evaluation criterion or the specific question that guides the assessment. + assessed_answer (InputField): The answer provided that will be evaluated. + assessment_answer (OutputField): The resulting quality rating, typically a number between 1 and 5. + """ + + # Input field containing the context that supports the answer. + context = InputField(desc="The context for answering the question.") + # Input field specifying the evaluation criterion or the question to be assessed. + assessed_question = InputField(desc="The evaluation criterion.") + # Input field holding the answer that is being evaluated. 
+ assessed_answer = InputField(desc="The answer to the question.") + # Output field that will hold the final assessment score, which is a rating between 1 and 5. + assessment_answer = OutputField(desc="A rating between 1 and 5.") diff --git a/src/rag_pipelines/prompts/dspy_rag.py b/src/rag_pipelines/prompts/dspy_rag.py new file mode 100644 index 0000000000000000000000000000000000000000..b144d8f2e3f91707bc0e726cd2391231dfbbc778 --- /dev/null +++ b/src/rag_pipelines/prompts/dspy_rag.py @@ -0,0 +1,22 @@ +from dspy import InputField, OutputField, Signature + + +class GenerateAnswerFromContext(Signature): + """DSPy Signature for Answer Generation. + + This signature specifies the input and output fields for a module responsible for generating + answers to questions. It takes in a context, which may include relevant facts or background information, + and a question. Based on these inputs, the module produces a short and precise answer. + + Attributes: + context (InputField): An input field containing background information or relevant facts that aid in answering the question. + question (InputField): An input field representing the question for which an answer is to be generated. + answer (OutputField): An output field that will contain the generated answer, designed to be short and precise. + """ + + # Input field for context information that may include relevant facts. + context = InputField(desc="contains relevant facts and context for the question") + # Input field for the question to be answered. + question = InputField() + # Output field for the generated answer. 
+ answer = OutputField(desc="short, detailed and precise answer") diff --git a/src/rag_pipelines/prompts/financial_summary.py b/src/rag_pipelines/prompts/financial_summary.py new file mode 100644 index 0000000000000000000000000000000000000000..bb321ad2f71079dbc76abc72970e6d379201dfe3 --- /dev/null +++ b/src/rag_pipelines/prompts/financial_summary.py @@ -0,0 +1,33 @@ +FinancialSummaryPrompt = """You are an expert financial analyst tasked with generating + descriptions for images in financial filings given a summary of the financial filing + and some important keywords present in the document along with the image. + + Here are some rules you should follow: + 1. If the image has text in it, you should first + generate a description of the image and then extract the text in markdown format. + 2. If the image does not have text in it, you should generate a description of the image. + 3. You should frame your reply in markdown format. + 4. The description should be a list of bullet points under the markdown header "Description of the image". + 5. The extracted text should be under the markdown header "Extracted text from the image". + 6. If there are tables or tabular data in the image, you should extract the data in markdown format. + 7. You should pay attention to the financial filing and use the information to generate the description. + + Here is the financial filing's summary: + + --- + {filing_summary} + ---""" + +FinancialSummaryKeywordsPrompt = """You are an expert financial analyst tasked with generating keywords for financial filings. +You should generate a summary of the financial filing and a list of important keywords from +the financial filing. + +Here are some rules you should follow: +1. The summary should be a list of bullet points under the markdown header "Summary of the financial filing". +2. The keywords should be a list of keywords under the markdown header "Important keywords from the financial filing". 
+ +Here is the financial filing: + +--- +{filing_data} +---""" diff --git a/src/rag_pipelines/prompts/rag_prompt.py b/src/rag_pipelines/prompts/rag_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..049c9e472e7e5df192e219f692d33babd9184889 --- /dev/null +++ b/src/rag_pipelines/prompts/rag_prompt.py @@ -0,0 +1,101 @@ +from pydantic import BaseModel, Field, field_validator + +STRUCTURED_RAG_PROMPT = """You are an analytical question-answering assistant. Follow these steps systematically: + +1. Context Evaluation: + - Analyze all context relevance to: "{question}" + - Identify directly relevant segments + - Flag gaps/contradictions in context + +2. Response Formulation: + - If sufficient context: + * Synthesize key evidence comprehensively + * Construct complete evidence-based answer + * Maintain objective tone + - If insufficient context: + * Specify missing information in detail + * State knowledge boundaries clearly + * Decline speculation + +3. Final Verification: + - Ensure full factual alignment with context + - Remove any external knowledge + +**Question:** {question} +**Relevant Context:** +{context} + +**Reasoning Process:** [Complete analysis with all required sections] +**Final Answer:** [Thorough response or detailed decline explanation]""" + + +class RAGResponseModel(BaseModel): + """Structured validation for RAG response process compliance.""" + + reasoning_chain: str = Field( + ..., + description=( + "Complete analytical process containing:\n" + "- 1. Context Evaluation: Relevance assessment\n" + "- 2. Response Formulation: Answer construction logic\n" + "- 3. 
Final Verification: Factual consistency check" + ), + ) + + final_answer: str = Field( + ..., + description=( + "Verified response containing either:\n" + "- Comprehensive evidence-based explanation\n" + "- Detailed decline with missing context specification" + ), + ) + + @field_validator("reasoning_chain") + @classmethod + def validate_analysis_steps(cls, chain: str) -> str: + """Validate the structure and completeness of the analytical reasoning chain. + + Ensures the reasoning chain contains all required section headers and meets length constraints. + + Args: + chain (str): The reasoning chain to validate. + + Returns: + str: The validated reasoning chain. + + Raises: + ValueError: If any required section headers are missing from the chain. + """ + required_sections = ["1. Context Evaluation", "2. Response Formulation", "3. Final Verification"] + + missing = [step for step in required_sections if step not in chain] + if missing: + msg = f"Missing required analysis steps: {missing}" + raise ValueError(msg) + + return chain + + model_config = { + "json_schema_extra": { + "example": { + "reasoning_chain": ( + "1. Context Evaluation: Analyzed 5 documents on climate models\n" + "Identified relevant sections on temperature projections\n" + "Noted absence of economic impact data\n\n" + "2. Response Formulation: Combined IPCC and NOAA projections\n" + "Structured timeline-based response\n" + "Excluded unrelated urban heat island content\n\n" + "3. Final Verification: Cross-referenced all statistics\n" + "Confirmed absence of external knowledge" + ), + "final_answer": ( + "Current climate models project a temperature increase range of 1.5°C to 4.5°C by 2100, " + "depending on emission scenarios. The IPCC AR6 report indicates a 66% likelihood of staying " + "below 2°C if net-zero emissions are achieved by 2070. NOAA data shows accelerated warming " + "trends in Arctic regions, with models predicting ice-free summers by 2050 under high-emission " + "pathways. 
Economic impact projections remain unavailable in provided context." + ), + } + } + } diff --git a/src/rag_pipelines/prompts/retrieval_critic.py b/src/rag_pipelines/prompts/retrieval_critic.py new file mode 100644 index 0000000000000000000000000000000000000000..2d8489ebd88621a282ca3c9b4a26ea767b562679 --- /dev/null +++ b/src/rag_pipelines/prompts/retrieval_critic.py @@ -0,0 +1,130 @@ +from typing import Literal + +from pydantic import BaseModel, Field, field_validator + +RETRIEVAL_CRITIC_PROMPT = """You are a senior researcher evaluating document support quality. Follow these steps: + +1. Answer Requirements: Identify exact claims needed for complete answer +2. Claim Verification: Check document's direct evidence for each requirement +3. Evidence Strength: Assess quality/reliability of supporting facts +4. Completeness Check: Identify missing elements or partial coverage +5. Support Synthesis: Combine analysis into final support classification + +**Question:** {question} + +**Document Excerpt:** {context} + +Provide detailed reasoning through all steps, then state final decision as either: +- "fully-supported" (covers all requirements with strong evidence) +- "partially-supported" (covers some requirements or weak evidence) +- "no-support" (contains no usable evidence)""" + + +class RetrievalCriticResult(BaseModel): + """Structured evaluation of a document's ability to support answering a query. + + Validates that the reasoning chain contains all required analytical steps and that + the final decision matches one of the predefined support classifications. + + Attributes: + reasoning_chain (str): Step-by-step analysis through verification stages. + Must contain all required section headers. + decision (Literal["fully-supported", "partially-supported", "no-support"]): + Final classification of document support quality. 
+ + Raises: + ValueError: If reasoning chain is missing required sections or is too short + ValidationError: If decision value doesn't match allowed literals + + Example: + >>> valid_result = RetrievalCriticResult( + ... reasoning_chain=( + ... "1. Answer Requirements: Needs 3 climate change impacts" + ... "2. Claim Verification: Documents sea level rise data" + ... "3. Evidence Strength: IPCC report citations provided" + ... "4. Completeness Check: Missing economic impact analysis" + ... "5. Support Synthesis: Covers 2/3 required impact areas" + ... ), + ... decision="partially-supported" + ... ) + >>> valid_result.decision + 'partially-supported' + """ + + reasoning_chain: str = Field( + ..., + description=( + "Systematic analysis through verification stages. Must contain:\n" + "- 1. Answer Requirements: Identification of needed claims\n" + "- 2. Claim Verification: Document evidence checking\n" + "- 3. Evidence Strength: Quality assessment of sources\n" + "- 4. Completeness Check: Missing elements analysis\n" + "- 5. Support Synthesis: Final classification rationale" + ), + ) + decision: Literal["fully-supported", "partially-supported", "no-support"] = Field( + ..., + description=( + "Final classification of document's support quality:\n" + "- 'fully-supported': Comprehensive evidence for all requirements\n" + "- 'partially-supported': Partial or weak evidence coverage\n" + "- 'no-support': No usable evidence found" + ), + ) + + @field_validator("reasoning_chain") + @classmethod + def validate_reasoning_steps(cls, chain_to_validate: str) -> str: + """Validate the structure and completeness of the analytical reasoning chain. + + Ensures the reasoning chain contains all required section headers and meets + minimum length requirements for meaningful analysis. 
+ + Args: + chain_to_validate (str): The raw reasoning chain text to validate + + Returns: + str: The validated reasoning chain if all requirements are met + + Raises: + ValueError: If any required section headers are missing from the chain + + Example: + >>> valid_chain = ( + ... "1. Answer Requirements: Needs 5 economic indicators" + ... "2. Claim Verification: GDP data verified" + ... "3. Evidence Strength: Government reports cited" + ... "4. Completeness Check: Missing unemployment figures" + ... "5. Support Synthesis: Covers 4/5 required indicators" + ... ) + >>> RetrievalCriticResult.validate_reasoning_steps(valid_chain) + '1. Answer Requirements: Needs 5 economic indicators...' + """ + required_steps = [ + "1. Answer Requirements", + "2. Claim Verification", + "3. Evidence Strength", + "4. Completeness Check", + "5. Support Synthesis", + ] + + missing: list[str] = [step for step in required_steps if step not in chain_to_validate] + if missing: + msg = f"Missing required analysis steps: {missing}" + raise ValueError(msg) + return chain_to_validate + + model_config = { + "json_schema_extra": { + "example": { + "reasoning_chain": ( + "1. Answer Requirements: Needs 3 main battery innovations\n" + "2. Claim Verification: Documents solid-state and lithium-air tech\n" + "3. Evidence Strength: Peer-reviewed study citations\n" + "4. Completeness Check: Missing third innovation details\n" + "5. 
Support Synthesis: Strong evidence for 2/3 requirements" + ), + "decision": "partially-supported", + } + } + } diff --git a/src/rag_pipelines/prompts/retrieval_evalution.py b/src/rag_pipelines/prompts/retrieval_evalution.py new file mode 100644 index 0000000000000000000000000000000000000000..d3baa83ba19d365a7247659a6f741b437777314f --- /dev/null +++ b/src/rag_pipelines/prompts/retrieval_evalution.py @@ -0,0 +1,110 @@ +from typing import Literal + +from pydantic import BaseModel, Field, field_validator + +RETRIEVAL_EVALUATION_PROMPT = """You are a senior analyst evaluating document relevance. Follow these steps: + +1. Question Core: Identify essential information needed +2. Document Facts: Extract concrete claims from text +3. Direct Overlap: Verify question-document concept matches +4. Indirect Support: Evaluate contextual relevance +5. Final Synthesis: Combine analysis into verdict + +**Question:** {question} + +**Document Excerpt:** {context} + +Provide detailed reasoning through all steps, then state final decision as 'relevant' or 'irrelevant'.""" + + +class RetrievalEvaluationResult(BaseModel): + """Pydantic model for validating structured relevance evaluation results. + + Enforces complete reasoning chain and constrained decision output through + validation rules. Designed for integration with LLM structured output pipelines. + + Attributes: + reasoning_chain (str): Sequential analysis following required evaluation steps. + Must contain all required section headers. + decision (Literal["relevant", "irrelevant"]): Final relevance determination. + + Raises: + ValidationError: If reasoning chain misses required sections or length constraints + + Examples: + >>> valid_instance = RetrievalEvaluationResult( + ... reasoning_chain=( + ... "1. Question Core: Cloud security practices" + ... "2. Document Facts: AWS IAM details" + ... "3. Direct Overlap: Security focus match" + ... "4. Indirect Support: Implementation examples" + ... "5. 
Final Synthesis: Directly addresses question" + ... ), + ... decision="relevant" + ... ) + >>> isinstance(valid_instance, RetrievalEvaluationResult) + True + """ + + reasoning_chain: str = Field( + default=..., + description="Sequential analysis through required evaluation stages. Must contain: 1. Question Core, " + "2. Document Facts, 3. Direct Overlap, 4. Indirect Support, 5. Final Synthesis sections.", + ) + decision: Literal["relevant", "irrelevant"] = Field( + default=..., + description="Binary relevance determination based on structured analysis of document content against query " + "requirements.", + ) + + @field_validator("reasoning_chain") + @classmethod + def validate_reasoning_steps(cls, chain_to_validate: str) -> str: + r"""Validate reasoning chain contains all required analysis sections. + + Args: + chain_to_validate (str): Input reasoning chain text to validate + + Returns: + str: Validated reasoning chain text if all sections present + + Raises: + ValueError: If any required section headers are missing from the text + + Example: + >>> valid_chain = ( + ... "1. Question Core: ... 2. Document Facts: ... 3. Direct Overlap: ...4. Indirect Support: ..." + ... "5. Final Synthesis: ..." + ... ) + >>> RetrievalEvaluationResult.validate_reasoning_steps(valid_chain) + '1. Question Core: ... 2. Document Facts: ... 3. Direct Overlap: ... 4. Indirect Support: ... + 5. Final Synthesis: ...' + """ + required_steps = [ + "1. Question Core", + "2. Document Facts", + "3. Direct Overlap", + "4. Indirect Support", + "5. Final Synthesis", + ] + + missing: list[str] = [step for step in required_steps if step not in chain_to_validate] + if missing: + msg = f"Missing required analysis steps: {missing}" + raise ValueError(msg) + return chain_to_validate + + model_config = { + "json_schema_extra": { + "example": { + "reasoning_chain": ( + "1. Question Core: Requires cloud security best practices\n" + "2. Document Facts: Details AWS IAM role configurations\n" + "3. 
from typing import Any, ClassVar, Literal

from pydantic import BaseModel, Field, field_validator

USEFULNESS_EVALUATOR_PROMPT: str = """You are a senior researcher evaluating document usefulness. Follow these steps:

1. Required Information: Identify key facts needed to answer the question
2. Factual Content: Check document for presence of required facts
3. Quality Assessment: Evaluate reliability/detail of presented information
4. Coverage Analysis: Determine percentage of required information covered
5. Score Synthesis: Combine factors into final usefulness score

**Question:** {question}

**Document Excerpt:** {context}

Provide detailed reasoning through all steps, then state final score as an integer from 1-5 using format: "final_score": ""

1 = No relevant facts, 3 = Some useful facts with gaps, 5 = Comprehensive high-quality information"""


class UsefulnessEvaluatorResult(BaseModel):
    """Validated result of a document-usefulness evaluation.

    Pairs the evaluator's step-by-step reasoning with its final 1-5 score, and
    enforces — via a field validator — that every required analysis section is
    present in the reasoning text before the result is accepted.

    Attributes:
        reasoning_chain (str): Step-by-step analysis covering all five required sections.
        decision (Literal["1", "2", "3", "4", "5"]): Final numerical usefulness score.

    Raises:
        ValueError: If the reasoning chain omits any required analysis section.
        ValidationError: If the score value is not one of "1"-"5".
    """

    reasoning_chain: str = Field(
        ...,
        description=(
            "Complete analysis chain containing required sections:\n"
            "1. Required Information: Key facts needed for comprehensive answer\n"
            "2. Factual Content: Presence verification of required facts\n"
            "3. Quality Assessment: Source reliability and detail depth\n"
            "4. Coverage Analysis: Percentage of requirements fulfilled\n"
            "5. Score Synthesis: Final numerical score justification"
        ),
    )
    decision: Literal["1", "2", "3", "4", "5"] = Field(
        ...,
        description=(
            "Numerical usefulness score with criteria:\n"
            "1 - Irrelevant/no facts | 2 - Minimal value | 3 - Partial with gaps\n"
            "4 - Good reliable coverage | 5 - Comprehensive high-quality"
        ),
    )

    @field_validator("reasoning_chain")
    @classmethod
    def validate_reasoning_steps(cls, chain_to_validate: str) -> str:
        """Reject reasoning chains that omit any required analysis section.

        Args:
            chain_to_validate (str): Raw reasoning-chain text to check.

        Returns:
            str: The unchanged chain when every section header is present.

        Raises:
            ValueError: If one or more required section headers are absent.
        """
        # Headers the prompt instructs the model to emit; each must appear verbatim.
        expected_headers = (
            "1. Required Information",
            "2. Factual Content",
            "3. Quality Assessment",
            "4. Coverage Analysis",
            "5. Score Synthesis",
        )

        missing: list[str] = []
        for header in expected_headers:
            if header not in chain_to_validate:
                missing.append(header)

        if missing:
            msg = f"Missing required analysis sections: {', '.join(missing)}"
            raise ValueError(msg)
        return chain_to_validate

    model_config: ClassVar[dict[str, Any]] = {
        "json_schema_extra": {
            "example": {
                "reasoning_chain": (
                    "1. Required Information: Needs 5 climate change impacts\n"
                    "2. Factual Content: Details 3 impacts with data\n"
                    "3. Quality Assessment: Peer-reviewed sources cited\n"
                    "4. Coverage Analysis: 60% of requirements met\n"
                    "5. Score Synthesis: Strong but incomplete coverage"
                ),
                "decision": "4",
            }
        }
    }
from typing import Any

from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Prompt asking the LLM to rewrite a question for better retrieval performance.
_REWRITE_PROMPT = (
    "You are generating questions that are well optimized for retrieval.\n"
    "Look at the input and try to reason about the underlying semantic intent.\n"
    "Here is the initial question:\n"
    "-------\n"
    "{question}\n"
    "-------\n"
    "Provide an improved question without any preamble, only respond with the updated question:"
)


class QueryTransformer:
    """Transform user questions into optimized versions for document retrieval.

    Uses an LLM to infer the semantic intent of a question and produce a
    retrieval-optimized rewrite. Designed to act as a node in a graph-based
    RAG pipeline via ``__call__``.

    Attributes:
        llm (Any): Language model used for query transformation; must be
            composable in a LangChain chain (i.e. support `invoke`).
        prompt (PromptTemplate): Prompt template, built once at init time.
        chain (Any): `prompt | llm | StrOutputParser()` pipeline.
    """

    def __init__(self, llm: Any):
        """Initialize the QueryTransformer with a language model.

        Args:
            llm (Any): The large language model that processes and transforms
                the query. Must implement an `invoke` method.
        """
        self.llm = llm
        # Build the prompt and chain once instead of on every call.
        # NOTE: `from_template` infers the input variables from the template,
        # so passing `input_variables=` explicitly (as the previous version
        # did) raises a duplicate-keyword error in current LangChain releases.
        self.prompt = PromptTemplate.from_template(_REWRITE_PROMPT)
        self.chain = self.prompt | self.llm | StrOutputParser()

    def transform_query(self, question: str) -> dict[str, str]:
        """Rewrite a question into a retrieval-optimized form.

        Args:
            question (str): The user question to transform.

        Returns:
            dict[str, str]: The transformed question under the key 'question'.

        Example:
            ```python
            transformer = QueryTransformer(llm)
            transformed = transformer.transform_query("What are the benefits of cloud computing?")
            print(transformed["question"])
            ```
        """
        rewritten_question = self.chain.invoke({"question": question})
        return {"question": rewritten_question}

    def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
        """Process a state dictionary, replacing the question with its rewrite.

        Args:
            state (dict[str, Any]): A dictionary with the following keys:
                - 'question' (str): The original user question.
                - 'documents' (list): Documents passed through unchanged.

        Returns:
            dict[str, Any]: A dictionary containing:
                - 'documents': The original documents.
                - 'question' (str): The transformed question.
        """
        question = state["question"]
        documents = state["documents"]

        # BUG FIX: the previous version stored the whole dict returned by
        # transform_query under "question"; unwrap it so the state's
        # "question" stays a plain string for downstream nodes.
        rewritten_question = self.transform_query(question)["question"]

        return {"documents": documents, "question": rewritten_question}
class SubQueryRouter:
    """Route sub-queries to specific retrieval mechanisms and combine responses.

    Breaks a complex query into simpler sub-queries with an LLM, retrieves
    relevant documents for each sub-query from a vector database, and
    synthesizes a final answer from the combined context.

    Attributes:
        llm (Any): LLM used for sub-query generation and answer synthesis;
            must expose a ``generate(prompt: str) -> str`` method.
        vectordb (Any): Vector database exposing ``search(embedding, top_k)``.
        embedding_model (Any): Embedding model exposing ``embed(text)``.
    """

    def __init__(self, llm: Any, vectordb: Any, embedding_model: Any):
        """Initialize the router with an LLM, a vector database, and an embedding model.

        Args:
            llm (Any): The LLM used for sub-query generation and response synthesis.
            vectordb (Any): The vector database for retrieving relevant documents.
            embedding_model (Any): Model for generating embeddings from queries.
        """
        self.llm = llm
        self.vectordb = vectordb
        self.embedding_model = embedding_model

    def generate_sub_queries(self, query: str) -> list[str]:
        """Use the LLM to generate sub-queries from a complex query.

        Args:
            query (str): The user-provided complex query.

        Returns:
            list[str]: A list of simpler sub-queries (blank lines removed).
        """
        # BUG FIX: annotations previously used `List[str]` without importing
        # `List` from typing, which raised NameError when the class was
        # defined; builtin generics (PEP 585) are used instead.
        prompt = f"Break down the following query into simpler sub-queries:\n\n{query}\n\nSub-queries:"
        response = self.llm.generate(prompt)
        # The LLM returns newline-separated sub-queries; drop blank lines so a
        # trailing newline doesn't produce an empty sub-query downstream.
        return [line for line in response.split("\n") if line.strip()]

    def retrieve_relevant_documents(self, sub_queries: list[str]) -> list[Any]:
        """Retrieve relevant documents for each sub-query from the vector database.

        Args:
            sub_queries (list[str]): A list of sub-queries.

        Returns:
            list[Any]: Documents retrieved from the vector database; each is
                expected to expose a ``page_content`` attribute.
        """
        retrieved_docs: list[Any] = []
        for sub_query in sub_queries:
            query_embedding = self.embedding_model.embed(sub_query)
            # Retrieve the top-K most relevant chunks per sub-query.
            results = self.vectordb.search(query_embedding, top_k=5)
            retrieved_docs.extend(results)
        return retrieved_docs

    def synthesize_response(self, query: str, documents: list[Any]) -> str:
        """Use the LLM to synthesize a final response based on retrieved documents.

        Args:
            query (str): The original complex query.
            documents (list[Any]): Relevant documents retrieved from the vector database.

        Returns:
            str: The synthesized response.
        """
        context = "\n\n".join(doc.page_content for doc in documents)
        prompt = f"Based on the following retrieved information, answer the query:\n\nQuery: {query}\n\nContext:\n{context}\n\nResponse:"
        return self.llm.generate(prompt)

    def __call__(self, query: str) -> str:
        """Process a query end-to-end: decompose, retrieve, and synthesize.

        Args:
            query (str): The user query.

        Returns:
            str: The final synthesized response.
        """
        sub_queries = self.generate_sub_queries(query)
        documents = self.retrieve_relevant_documents(sub_queries)
        return self.synthesize_response(query, documents)
/dev/null and b/src/rag_pipelines/retrieval_evaluation/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/rag_pipelines/retrieval_evaluation/__pycache__/document_grader.cpython-310.pyc b/src/rag_pipelines/retrieval_evaluation/__pycache__/document_grader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..386abfd493a10e5720d9e0056faa1fda715869dc Binary files /dev/null and b/src/rag_pipelines/retrieval_evaluation/__pycache__/document_grader.cpython-310.pyc differ diff --git a/src/rag_pipelines/retrieval_evaluation/__pycache__/query_decision_maker.cpython-310.pyc b/src/rag_pipelines/retrieval_evaluation/__pycache__/query_decision_maker.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0e08910ea1add5cbf90001acc3c19a17f872752 Binary files /dev/null and b/src/rag_pipelines/retrieval_evaluation/__pycache__/query_decision_maker.cpython-310.pyc differ diff --git a/src/rag_pipelines/retrieval_evaluation/__pycache__/retrieval_critic.cpython-310.pyc b/src/rag_pipelines/retrieval_evaluation/__pycache__/retrieval_critic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f8920bbcd0c906b7c20226079adb73f8cb607dc Binary files /dev/null and b/src/rag_pipelines/retrieval_evaluation/__pycache__/retrieval_critic.cpython-310.pyc differ diff --git a/src/rag_pipelines/retrieval_evaluation/__pycache__/retrieval_evaluator.cpython-310.pyc b/src/rag_pipelines/retrieval_evaluation/__pycache__/retrieval_evaluator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47a355bd3c3ee02f0f18b754ad2de53842becc25 Binary files /dev/null and b/src/rag_pipelines/retrieval_evaluation/__pycache__/retrieval_evaluator.cpython-310.pyc differ diff --git a/src/rag_pipelines/retrieval_evaluation/__pycache__/usefulness_evaluator.cpython-310.pyc b/src/rag_pipelines/retrieval_evaluation/__pycache__/usefulness_evaluator.cpython-310.pyc new file mode 100644 index 
class DocumentGrader:
    """Filter retrieved documents by relevance and flag when to fall back to web search.

    Scores each document with a `RetrievalEvaluator` and keeps only those
    judged relevant. If the fraction of relevant documents does not exceed the
    configured threshold, a web search is recommended instead.

    Attributes:
        threshold (float): Relevance ratio that must be exceeded to skip the web search.
        retrieval_evaluator (RetrievalEvaluator): Scorer used to grade each document.

    Methods:
        grade_documents(question, documents): Filter relevant documents and
            decide whether a web search is needed.
        __call__(state): State-based wrapper around `grade_documents`.
    """

    def __init__(self, retrieval_evaluator: RetrievalEvaluator, threshold: float = 0.4):
        """Store the evaluator and the relevance-ratio threshold.

        Args:
            retrieval_evaluator (RetrievalEvaluator): Relevance scorer for documents.
            threshold (float): Minimum relevance ratio required to avoid a web
                search. Defaults to 0.4.
        """
        self.threshold = threshold
        self.retrieval_evaluator = retrieval_evaluator

    def grade_documents(self, question: str, documents: list[Any]) -> dict[str, Any]:
        """Score documents against the question and keep only the relevant ones.

        Args:
            question (str): The user's query used to judge document relevance.
            documents (list[Any]): Retrieved documents; each must expose
                a `page_content` attribute.

        Returns:
            dict[str, Any]: A dictionary containing:
                - 'documents' (list[Any]): Documents judged relevant.
                - 'run_web_search' (str): "Yes" or "No" recommendation.
        """
        scored = self.retrieval_evaluator.score_documents(question=question, documents=documents)

        # Keep only the documents the evaluator marked relevant.
        kept = [entry["document"] for entry in scored if entry["relevance_score"] == "yes"]

        # Share of relevant documents; an empty result set counts as 0.
        ratio = len(kept) / len(scored) if scored else 0

        # Recommend a web search unless the relevant share exceeds the threshold.
        return {
            "documents": kept,
            "run_web_search": "No" if ratio > self.threshold else "Yes",
        }

    def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
        """Grade the state's documents and attach a web-search recommendation.

        Args:
            state (dict[str, Any]): The current state containing:
                - 'question' (str): The user's query.
                - 'documents' (list[Any]): The retrieved documents.

        Returns:
            dict[str, Any]: An updated state with:
                - 'documents' (list[Any]): The filtered relevant documents.
                - 'question' (str): The original query.
                - 'web_search' (str): "Yes" or "No" recommendation.
        """
        question = state["question"]
        graded = self.grade_documents(question=question, documents=state["documents"])

        return {
            "documents": graded["documents"],
            "question": question,
            "web_search": graded["run_web_search"],
        }
+ """ + + def __call__(self, state: dict[str, Any]) -> str: + """Determine the next step in the pipeline based on document relevance. + + This method evaluates the relevance of retrieved documents, as indicated in the `web_search` field. + If a web search is required due to irrelevant documents, it suggests transforming the query. + Otherwise, it suggests generating an answer based on the relevant documents. + + Args: + state (dict[str, Any]): The current state of the pipeline, containing: + - `question` (str): The user's query. + - `web_search` (str): A binary decision ("Yes" or "No") indicating if a web search is required. + - `documents` (list): A list of retrieved and graded documents. + + Returns: + str: The next step in the pipeline: + - `"transform_query"`: If documents are irrelevant and a web search is needed. + - `"generate"`: If relevant documents are available for answering the query. + + Example: + ```python + state = { + "question": "What is the capital of France?", + "web_search": "Yes", + "documents": [] + } + decision_maker = QueryDecisionMaker() + next_step = decision_maker(state) + print(next_step) # Output: "transform_query" + ``` + """ + logger.info("ASSESSING GRADED DOCUMENTS") + web_search = state["web_search"] + + if web_search == "Yes": + # All documents have been filtered as irrelevant + logger.info("DECISION: DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY") + return "transform_query" + else: + # Relevant documents are available, proceed to generate the answer + logger.info("DECISION: GENERATE") + return "generate" diff --git a/src/rag_pipelines/retrieval_evaluation/retrieval_critic.py b/src/rag_pipelines/retrieval_evaluation/retrieval_critic.py new file mode 100644 index 0000000000000000000000000000000000000000..090a0b9122aa6fa512cfc9a9f04fe4dd86d30cb9 --- /dev/null +++ b/src/rag_pipelines/retrieval_evaluation/retrieval_critic.py @@ -0,0 +1,84 @@ +from typing import Any, Optional + +import weave +from langchain_core.documents 
class RetrievalCritic:
    """Critique retrieved contexts and keep only those with accepted support levels.

    Runs a structured-output LLM chain over each (question, context) pair and
    filters out contexts whose support decision is not in the configured set.

    Attributes:
        llm (ChatGroq): Language model used for the critique.
        prompt (ChatPromptTemplate): System prompt wrapping RETRIEVAL_CRITIC_PROMPT.
        retrieval_critic_chain: prompt | structured-output LLM pipeline.
        support_levels (list[str]): Support decisions treated as relevant.
    """

    def __init__(self, llm: ChatGroq, support_levels: Optional[list[str]] = None) -> None:
        """Build the critic chain and configure accepted support levels.

        Args:
            llm (ChatGroq): Pre-configured ChatGroq instance for evaluation.
            support_levels (Optional[list[str]]): Accepted support levels.
                Defaults to ["fully-supported", "partially-supported", "no-support"].
        """
        self.llm = llm
        self.prompt = ChatPromptTemplate.from_messages([("system", RETRIEVAL_CRITIC_PROMPT)])
        self.retrieval_critic_chain = self.prompt | self.llm.with_structured_output(RetrievalCriticResult)
        self.support_levels = support_levels or ["fully-supported", "partially-supported", "no-support"]

    @weave.op()
    def score_context(self, question: str, context: str) -> str:
        """Return the support-level decision for a single context.

        Args:
            question (str): User question to evaluate against.
            context (str): Document text content to assess.

        Returns:
            str: "fully-supported", "partially-supported", or "no-support".
        """
        evaluation = self.retrieval_critic_chain.invoke({"question": question, "context": context})
        return evaluation.decision

    @weave.op()
    def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
        """Filter the state's contexts down to the accepted support levels.

        Args:
            state (dict[str, Any]): Input processing state containing:
                - "question" (str): Original user question.
                - "documents" (list[Any]): Retrieved documents (passed through).
                - "context" (list[str]): Extracted document texts to filter.

        Returns:
            dict[str, Any]: State with the same question and documents but only
                the contexts whose decision is in `support_levels`.
        """
        question: str = state["question"]
        documents: list[Document] = state["documents"]

        kept_contexts: list[str] = []
        for candidate in state["context"]:
            if self.score_context(question, candidate) in self.support_levels:
                kept_contexts.append(candidate)

        return {
            "question": question,
            "context": kept_contexts,
            "documents": documents,
        }
+ """Evaluates the relevance of retrieved documents in response to a user question. + + This class uses a language model to assess whether retrieved documents are relevant + to answering a given question. It provides both individual document scoring and + batch processing capabilities through a state-based interface. + + Attributes: + llm (ChatGroq): Language model used for relevance evaluation + prompt (ChatPromptTemplate): Template for structuring the evaluation prompt + retrieval_evaluation_chain (Chain): Configured processing chain for evaluations + """ + + def __init__(self, llm: ChatGroq) -> None: + """Initialize the evaluator with a language model and processing chain. + + Constructs a complete evaluation pipeline combining: + - Predefined prompt template + - Specified language model + - Structured output parser + + Args: + llm (ChatGroq): Configured ChatGroq instance for processing evaluations + """ + self.llm = llm + self.prompt = ChatPromptTemplate.from_messages([("system", RETRIEVAL_EVALUATION_PROMPT)]) + self.retrieval_evaluation_chain = self.prompt | self.llm.with_structured_output(RetrievalEvaluationResult) + + @weave.op() + def score_document(self, question: str, document: Document) -> str: + """Evaluate a single document's relevance to a given question. + + Args: + question (str): User query to evaluate against + document (Document): Document object to assess for relevance + + Returns: + str: Binary relevance decision - either 'relevant' or 'irrelevant' + + Example: + >>> evaluator.score_document("What is AI?", Document(page_content="AI is...")) + 'relevant' + """ + result = self.retrieval_evaluation_chain.invoke( + { + "question": question, + "context": document.page_content, + } + ) + return result.decision + + @weave.op() + def __call__(self, state: dict[str, Any]) -> dict[str, Any]: + """Process a state dictionary to filter relevant documents. 
+ + Takes a state containing a question and retrieved documents, returning an updated + state with filtered context while preserving original documents. + + Args: + state (dict[str, Any]): Processing state containing: + - question (str): Original user question + - documents (List[Document]): Retrieved documents to evaluate + + Returns: + dict[str, Any]: Updated state with: + - question (str): Original question + - context (List[str]): Content from relevant documents + - documents (List[Document]): Original document list (unmodified) + + Raises: + KeyError: If input state is missing required keys ('question' or 'documents') + """ + question = state["question"] + documents = state["documents"] + relevant_context = [ + document.page_content for document in documents if self.score_document(question, document) == "relevant" + ] + + return { + "question": question, + "context": relevant_context, + "documents": documents, + } diff --git a/src/rag_pipelines/retrieval_evaluation/usefulness_evaluator.py b/src/rag_pipelines/retrieval_evaluation/usefulness_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..1940964e81b00fae6fbe9032f1645145c081e44b --- /dev/null +++ b/src/rag_pipelines/retrieval_evaluation/usefulness_evaluator.py @@ -0,0 +1,86 @@ +from typing import Any, Optional + +import weave +from langchain_core.documents import Document +from langchain_core.prompts import ChatPromptTemplate +from langchain_groq import ChatGroq + +from rag_pipelines.prompts import USEFULNESS_EVALUATOR_PROMPT, UsefulnessEvaluatorResult + +SCORE_UPPER_LIMIT = 5 +SCORE_LOWER_LIMIT = 1 + + +class UsefulnessEvaluator: + """Evaluates and filters document contexts based on their usefulness score. + + Uses a language model chain to assess document usefulness and retains only contexts + meeting or exceeding the specified score threshold. + + Attributes: + llm (ChatGroq): Language model instance for scoring. 
+ prompt (ChatPromptTemplate): Template for usefulness evaluation prompt. + evaluator_chain (RunnableSequence): LangChain pipeline for evaluation. + score_threshold (int): Minimum score (1-5) required for context inclusion. + """ + + def __init__(self, llm: ChatGroq, score_threshold: Optional[int] = 3) -> None: + """Initialize evaluator with language model and scoring threshold. + + Args: + llm (ChatGroq): Configured ChatGroq instance for evaluation. + score_threshold (Optional[int]): Minimum usefulness score (1-5) to retain. + Defaults to 3. Must be between 1 and 5 inclusive. + """ + if not (SCORE_LOWER_LIMIT <= score_threshold <= SCORE_UPPER_LIMIT): + msg = "score_threshold must be between 1 and 5" + raise ValueError(msg) + + self.llm = llm + self.prompt = ChatPromptTemplate.from_messages([("system", USEFULNESS_EVALUATOR_PROMPT)]) + self.evaluator_chain = self.prompt | self.llm.with_structured_output(UsefulnessEvaluatorResult) + self.score_threshold = score_threshold + + @weave.op() + def score_context(self, question: str, context: str) -> int: + """Calculate usefulness score for a document context relative to a question. + + Args: + question (str): User question to evaluate against. + context (str): Document text content to assess. + + Returns: + int: Numerical usefulness score (1-5). + """ + result = self.evaluator_chain.invoke({"question": question, "context": context}) + return int(result.decision) + + @weave.op() + def __call__(self, state: dict[str, Any]) -> dict[str, Any]: + """Filter document contexts based on their usefulness scores. 
+ + Args: + state (dict[str, Any]): Processing state containing: + - "question" (str): Original user question + - "documents" (list[Document]): Retrieved documents + - "context" (list[str]): Document texts to filter + + Returns: + dict[str, Any]: Updated state with filtered contexts: + - "question": Original question + - "context": Texts with score >= threshold + - "documents": Original documents (unfiltered) + """ + question: str = state["question"] + documents: list[Document] = state["documents"] + relevant_context: list[str] = state["context"] + + filtered_context: list[str] = [ + context for context in relevant_context if self.score_context(question, context) >= self.score_threshold + ] + + return { + "question": question, + "context": filtered_context, + "documents": documents, + } diff --git a/src/rag_pipelines/unstructured/__init__.py b/src/rag_pipelines/unstructured/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b31709880843f0748f4bced9fc7eed93b3dcc82 --- /dev/null +++ b/src/rag_pipelines/unstructured/__init__.py @@ -0,0 +1,4 @@ +from rag_pipelines.unstructured.unstructured_chunker import UnstructuredChunker +from rag_pipelines.unstructured.unstructured_pdf_loader import UnstructuredDocumentLoader + +__all__ = ["UnstructuredChunker", "UnstructuredDocumentLoader"] diff --git a/src/rag_pipelines/unstructured/unstructured_chunker.py b/src/rag_pipelines/unstructured/unstructured_chunker.py new file mode 100644 index 0000000000000000000000000000000000000000..3ec9a29b416308696aac7eac8ddd2eacf429f9eb --- /dev/null +++ b/src/rag_pipelines/unstructured/unstructured_chunker.py @@ -0,0 +1,227 @@ +"""Unstructured document chunking library. + +This module provides the `UnstructuredChunker` class for chunking unstructured documents +using different strategies from the `unstructured` library. 
"""Unstructured document chunking library.

This module provides the `UnstructuredChunker` class for chunking unstructured documents
using different strategies from the `unstructured` library. It supports:

- "basic" chunking: Splits documents into chunks based on character limits with
  optional overlap.
- "by_title" chunking: Splits documents into sections based on titles, considering
  options like maximum characters, combining of smaller text elements, and allowing
  sections to span multiple pages.

The chunker can be configured with various parameters for each strategy, allowing
fine-grained control over the resulting chunks.

Example Usage:
```python
from rag_pipelines.unstructured import UnstructuredChunker

# Initialize the chunker
chunker = UnstructuredChunker()

# Load your documents
documents = [...]

# Chunk the documents using the "basic" strategy
chunked_documents = chunker.transform_documents(documents)
```
"""

import logging
from typing import Any, Literal, Optional

import weave
from langchain_core.documents import Document
from unstructured.chunking.basic import chunk_elements
from unstructured.chunking.title import chunk_by_title
from unstructured.documents.elements import Element, NarrativeText

from rag_pipelines.utils import LoggerFactory

# Initialize logger
logger_factory = LoggerFactory(logger_name=__name__, log_level=logging.INFO)
logger = logger_factory.get_logger()


class UnstructuredChunker(weave.Model):
    """A class for chunking documents using different strategies provided by the unstructured library.

    Supports both "basic" and "by_title" chunking strategies.

    Attributes:
        chunking_strategy (Literal["basic", "by_title"]): The strategy to use for chunking elements.
        max_characters (int): Maximum number of characters in each chunk (both strategies).
        new_after_n_chars (int): Number of characters after which a new chunk is softly forced (both strategies).
        overlap (int): Number of characters to overlap between chunks (both strategies).
        overlap_all (bool): Whether to apply overlap to all chunks (both strategies).
        combine_text_under_n_chars (Optional[int]): Combine successive small sections below this size;
            only passed to the "by_title" strategy.
        include_orig_elements (Optional[bool]): Whether to include original elements in the chunk
            metadata (both strategies).
        multipage_sections (Optional[bool]): Whether to allow sections to span multiple pages;
            only passed to the "by_title" strategy.
    """

    chunking_strategy: Literal["basic", "by_title"]
    max_characters: int
    new_after_n_chars: int
    overlap: int
    overlap_all: bool
    combine_text_under_n_chars: Optional[int]
    include_orig_elements: Optional[bool]
    multipage_sections: Optional[bool]

    def __init__(
        self,
        chunking_strategy: Literal["basic", "by_title"] = "basic",
        max_characters: int = 500,
        new_after_n_chars: int = 500,
        overlap: int = 0,
        overlap_all: bool = False,
        combine_text_under_n_chars: Optional[int] = None,
        include_orig_elements: Optional[bool] = None,
        multipage_sections: Optional[bool] = None,
    ):
        """Initialize the chunker with the specified chunking strategy and parameters.

        Args:
            chunking_strategy (Literal["basic", "by_title"], optional): Chunking strategy to use. Defaults to "basic".
            max_characters (int, optional): Maximum characters in a chunk. Defaults to 500.
            new_after_n_chars (int, optional): Characters after which a new chunk is softly forced.
                Defaults to 500.
            overlap (int, optional): Characters to overlap between chunks. Defaults to 0.
            overlap_all (bool, optional): Apply overlap to all chunks. Defaults to False.
            combine_text_under_n_chars (Optional[int], optional): Combine small sections under this limit;
                used only by the "by_title" strategy. Defaults to None.
            include_orig_elements (Optional[bool], optional): Include original elements in chunk metadata.
                Defaults to None.
            multipage_sections (Optional[bool], optional): Allow sections to span multiple pages; used only
                by the "by_title" strategy. Defaults to None.
        """
        super().__init__(
            chunking_strategy=chunking_strategy,
            max_characters=max_characters,
            new_after_n_chars=new_after_n_chars,
            overlap=overlap,
            overlap_all=overlap_all,
            combine_text_under_n_chars=combine_text_under_n_chars,
            include_orig_elements=include_orig_elements,
            multipage_sections=multipage_sections,
        )

        # NOTE(review): super().__init__ above already assigns these pydantic
        # fields (weave.Model); the explicit reassignments below look redundant —
        # confirm they are intentional.
        self.chunking_strategy = chunking_strategy
        self.max_characters = max_characters
        self.new_after_n_chars = new_after_n_chars
        self.overlap = overlap
        self.overlap_all = overlap_all
        self.combine_text_under_n_chars = combine_text_under_n_chars
        self.include_orig_elements = include_orig_elements
        self.multipage_sections = multipage_sections

    def _convert_documents_to_elements(
        self, documents: list[Document]
    ) -> tuple[list[NarrativeText], list[dict[str, Any]]]:
        """Convert a list of LangChain documents to unstructured NarrativeText elements.

        This method takes in a list of LangChain Document objects and converts each
        document into a NarrativeText element. It also extracts and stores the metadata
        of each document separately in a list.

        Args:
            documents (list[Document]): A list of LangChain Document objects to be converted to
                NarrativeText elements.

        Returns:
            tuple[list[NarrativeText], list[dict[str, Any]]]:
                - A list of unstructured NarrativeText elements, one per document's text.
                - A list of metadata dictionaries, index-aligned with the elements list.
        """
        elements = []
        element_metadatas = []

        for document in documents:
            # Convert each document into a NarrativeText element
            element = NarrativeText(text=document.page_content)
            elements.append(element)

            # Store the metadata separately (re-attached to chunks in transform_documents)
            element_metadatas.append(document.metadata)

        logger.debug(f"Converted {len(documents)} documents to elements.")

        return elements, element_metadatas

    def _convert_chunked_elements_to_documents(self, elements: list[Element]) -> list[Document]:
        """Convert a list of chunked unstructured elements back to LangChain documents.

        NOTE(review): this helper is not called anywhere in this module —
        transform_documents builds Documents inline instead. Confirm whether it
        is part of the public surface or dead code.

        Args:
            elements (list[Element]): List of chunked unstructured elements.

        Returns:
            list[Document]: List of LangChain documents converted from elements.
        """
        documents = []
        for element in elements:
            document = Document(page_content=element.text, metadata=element.metadata.to_dict())
            documents.append(document)
        logger.debug(f"Converted {len(elements)} chunked elements to documents.")
        return documents

    def transform_documents(self, documents: list[Document]) -> list[Document]:
        """Chunks the provided documents based on the configured strategy.

        Args:
            documents (list[Document]): List of documents to be chunked.

        Returns:
            list[Document]: List of chunked documents; each chunk carries the
            metadata of the source document it was cut from.

        Raises:
            ValueError: If no documents are provided or if an unsupported chunking strategy is specified.
        """
        if not documents:
            msg = "No documents provided for transformation."
            logger.error(msg)
            raise ValueError(msg)

        logger.info(f"Transforming {len(documents)} documents using strategy: {self.chunking_strategy}")

        all_chunked_documents = []

        # Convert each document to unstructured elements and separate metadata
        elements, element_metadatas = self._convert_documents_to_elements(documents)

        # Apply the selected chunking strategy
        # NOTE(review): elements are chunked one at a time, so "by_title" can
        # never group titles with text from a different source document —
        # confirm that per-document sectioning is the intended behavior.
        for i, element in enumerate(elements):
            metadata = element_metadatas[i]  # Get the metadata for the current document

            if self.chunking_strategy == "basic":
                # chunk_elements does not accept combine_text_under_n_chars or
                # multipage_sections; those apply only to chunk_by_title below.
                chunked_elements = chunk_elements(
                    [element],  # Process one element at a time
                    max_characters=self.max_characters,
                    new_after_n_chars=self.new_after_n_chars,
                    overlap=self.overlap,
                    overlap_all=self.overlap_all,
                    include_orig_elements=self.include_orig_elements,
                )
            elif self.chunking_strategy == "by_title":
                chunked_elements = chunk_by_title(
                    [element],  # Process one element at a time
                    max_characters=self.max_characters,
                    new_after_n_chars=self.new_after_n_chars,
                    overlap=self.overlap,
                    overlap_all=self.overlap_all,
                    include_orig_elements=self.include_orig_elements,
                    multipage_sections=self.multipage_sections,
                    combine_text_under_n_chars=self.combine_text_under_n_chars,
                )
            else:
                msg = f"Unsupported chunking strategy: {self.chunking_strategy}"
                logger.error(msg)
                raise ValueError(msg)

            logger.info(f"Chunked element into {len(chunked_elements)} sub-elements.")

            # Add metadata to each chunk and convert back to Document
            for chunk in chunked_elements:
                chunked_document = Document(page_content=chunk.text, metadata=metadata)
                all_chunked_documents.append(chunked_document)

        logger.info(f"Combined all chunked documents into {len(all_chunked_documents)} documents.")
        return all_chunked_documents
from pathlib import Path
from typing import Optional

import weave
from langchain_community.document_loaders import UnstructuredPDFLoader
from langchain_core.documents import Document


class UnstructuredDocumentLoader(weave.Model):
    """Load PDFs from a directory and convert them into LangChain documents.

    Wraps LangChain's ``UnstructuredPDFLoader`` so that a whole directory of
    PDFs can be parsed with one configuration: text, tables, and images are
    extracted according to the chosen processing strategy.
    """

    strategy: str
    mode: str
    include_page_breaks: bool
    infer_table_structure: bool
    ocr_languages: Optional[str]
    languages: Optional[list[str]]
    hi_res_model_name: Optional[str]
    extract_images_in_pdf: bool
    extract_image_block_types: Optional[list[str]]
    extract_image_block_output_dir: Optional[str]
    extract_image_block_to_payload: bool
    starting_page_number: int
    extract_forms: bool
    form_extraction_skip_tables: bool

    def __init__(
        self,
        strategy: str = "hi_res",
        mode: str = "elements",
        include_page_breaks: bool = False,
        infer_table_structure: bool = False,
        ocr_languages: Optional[str] = None,
        languages: Optional[list[str]] = None,
        hi_res_model_name: Optional[str] = None,
        extract_images_in_pdf: bool = False,
        extract_image_block_types: Optional[list[str]] = None,
        extract_image_block_output_dir: Optional[str] = None,
        extract_image_block_to_payload: bool = False,
        starting_page_number: int = 1,
        extract_forms: bool = False,
        form_extraction_skip_tables: bool = True,
    ):
        """Store the parsing configuration shared by every loaded PDF.

        Args:
            strategy (str): Processing strategy for Unstructured (e.g. "hi_res").
            mode (str): Extraction mode (e.g. "elements").
            include_page_breaks (bool): Emit page-break elements.
            infer_table_structure (bool): Attempt to recover table structure.
            ocr_languages (Optional[str]): OCR language spec.
            languages (Optional[list[str]]): Languages for document processing.
            hi_res_model_name (Optional[str]): Model used by the "hi_res" strategy.
            extract_images_in_pdf (bool): Extract embedded images.
            extract_image_block_types (Optional[list[str]]): Image block types to extract.
            extract_image_block_output_dir (Optional[str]): Where extracted images are written.
            extract_image_block_to_payload (bool): Attach extracted images to the payload.
            starting_page_number (int): First page number assigned during extraction.
            extract_forms (bool): Extract form data.
            form_extraction_skip_tables (bool): Skip tables during form extraction.
        """
        # weave.Model is pydantic-backed: all fields are set via super().__init__.
        super().__init__(
            strategy=strategy,
            mode=mode,
            include_page_breaks=include_page_breaks,
            infer_table_structure=infer_table_structure,
            ocr_languages=ocr_languages,
            languages=languages,
            hi_res_model_name=hi_res_model_name,
            extract_images_in_pdf=extract_images_in_pdf,
            extract_image_block_types=extract_image_block_types,
            extract_image_block_output_dir=extract_image_block_output_dir,
            extract_image_block_to_payload=extract_image_block_to_payload,
            starting_page_number=starting_page_number,
            extract_forms=extract_forms,
            form_extraction_skip_tables=form_extraction_skip_tables,
        )

    def _get_all_file_paths_from_directory(self, directory_path: str) -> list[str]:
        """Recursively collect the path of every regular file under a directory.

        Args:
            directory_path (str): Path to the directory.

        Returns:
            list[str]: Absolute paths of all files found.

        Raises:
            ValueError: If the directory does not exist or is not a directory.
        """
        root = Path(directory_path).resolve()  # normalize to an absolute path

        # Guard clauses: fail early with a precise message.
        if not root.exists():
            msg = f"Directory does not exist: {directory_path}"
            raise ValueError(msg)
        if not root.is_dir():
            msg = f"Path is not a directory: {directory_path}"
            raise ValueError(msg)

        # rglob("*") walks the tree; keep regular files only.
        return [str(entry) for entry in root.rglob("*") if entry.is_file()]

    def transform_documents(self, directory_path: str) -> list[Document]:
        """Parse every PDF under ``directory_path`` into structured documents.

        Each file is fed through an ``UnstructuredPDFLoader`` configured with
        this instance's settings, and the per-file results are concatenated.

        Args:
            directory_path (str): Path to the directory containing PDF files.

        Returns:
            list[Document]: All documents produced across the directory.
        """
        collected: list[Document] = []

        for pdf_path in self._get_all_file_paths_from_directory(directory_path):
            loader = UnstructuredPDFLoader(
                file_path=pdf_path,
                mode=self.mode,
                strategy=self.strategy,
                include_page_breaks=self.include_page_breaks,
                infer_table_structure=self.infer_table_structure,
                ocr_languages=self.ocr_languages,
                languages=self.languages,
                hi_res_model_name=self.hi_res_model_name,
                extract_images_in_pdf=self.extract_images_in_pdf,
                extract_image_block_types=self.extract_image_block_types,
                extract_image_block_output_dir=self.extract_image_block_output_dir,
                extract_image_block_to_payload=self.extract_image_block_to_payload,
                starting_page_number=self.starting_page_number,
                extract_forms=self.extract_forms,
                form_extraction_skip_tables=self.form_extraction_skip_tables,
            )
            collected.extend(loader.load())

        return collected
class LoggerFactory:
    """Factory class to set up and configure a logger with customizable log level and format.

    This class ensures that logging is configured only once during the application's lifetime.
    It can configure the logger based on a given name, logging level, and logging format.

    Attributes:
        logger_name (str): The name of the logger to create.
        log_level (int): The logging level (e.g., logging.INFO, logging.DEBUG).
        log_format (str): The format for log messages.
        logger (logging.Logger): The configured logger instance.

    Methods:
        get_logger():
            Returns the configured logger instance.
        configure_from_env(logger_name, env_var="LOG_LEVEL"):
            Configures the logger based on an environment variable for dynamic log level setting.
    """

    # Class-level flag: logging.basicConfig must run at most once per process.
    _is_logger_initialized: bool = False

    def __init__(
        self,
        logger_name: str,
        log_level: int = logging.INFO,
        log_format: str = "%(asctime)s - %(levelname)s - %(message)s",
    ) -> None:
        """Initialize the LoggerFactory instance with the given logger name, log level, and format.

        Args:
            logger_name (str): The name of the logger to create.
            log_level (int, optional): The logging level to use (default is logging.INFO).
            log_format (str, optional): The format for log messages (default is
                "%(asctime)s - %(levelname)s - %(message)s").
        """
        self.logger_name = logger_name
        self.log_level = log_level
        self.log_format = log_format
        self.logger = self._initialize_logger()

    def _initialize_logger(self) -> logging.Logger:
        """Configure the root logging system once, then return the named logger.

        Returns:
            logging.Logger: A configured logger instance.
        """
        if not LoggerFactory._is_logger_initialized:
            logging.basicConfig(level=self.log_level, format=self.log_format)
            # BUGFIX: assign on the class, not the instance. The original
            # `self._is_logger_initialized = True` created an instance
            # attribute that shadowed the class flag, so every new factory
            # re-entered this branch and called basicConfig again.
            LoggerFactory._is_logger_initialized = True

        return logging.getLogger(self.logger_name)

    def get_logger(self) -> logging.Logger:
        """Return the configured logger instance.

        Returns:
            logging.Logger: The logger instance.
        """
        return self.logger

    @staticmethod
    def configure_from_env(logger_name: str, env_var: str = "LOG_LEVEL") -> "LoggerFactory":
        """Configure the logger based on an environment variable for dynamic log level setting.

        Args:
            logger_name (str): The name of the logger to create.
            env_var (str, optional): The environment variable to retrieve the log level from
                (default is "LOG_LEVEL").

        Returns:
            LoggerFactory: A LoggerFactory instance with the log level set based on the
            environment variable.
        """
        log_level_str = os.getenv(env_var, "INFO").upper()
        # Default to INFO if the string is not a recognized logging level name.
        log_level = getattr(logging, log_level_str, logging.INFO)
        return LoggerFactory(logger_name, log_level=log_level)
+ """ + log_level_str = os.getenv(env_var, "INFO").upper() + # Default to INFO if invalid level + log_level = getattr(logging, log_level_str, logging.INFO) + return LoggerFactory(logger_name, log_level=log_level) diff --git a/src/rag_pipelines/utils/parse_inputs.py b/src/rag_pipelines/utils/parse_inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..0d61e7a44598835fb7cf990e6be476d7b563eda5 --- /dev/null +++ b/src/rag_pipelines/utils/parse_inputs.py @@ -0,0 +1,12 @@ +from ast import literal_eval + + +def dict_type(input_value): + """Convert a string to a dictionary.""" + stripped_input = input_value.strip() + if not stripped_input or stripped_input in ("{}", "''", '""'): + return {} + try: + return literal_eval(stripped_input) + except (SyntaxError, ValueError): + return {} diff --git a/src/rag_pipelines/vectordb/__init__.py b/src/rag_pipelines/vectordb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7d44a3b2099037a7eafa365292d58d18a942635f --- /dev/null +++ b/src/rag_pipelines/vectordb/__init__.py @@ -0,0 +1,7 @@ +from rag_pipelines.vectordb.dspy_milvus_retriever import MilvusRetriever as DSPyMilvusRetriever +from rag_pipelines.vectordb.milvus import MilvusVectorDB +from rag_pipelines.vectordb.milvus_retriever import MilvusRetriever +from rag_pipelines.vectordb.pinecone import PineconeVectorDB +from rag_pipelines.vectordb.pinecone_retriever import PineconeRetriever + +__all__ = ["DSPyMilvusRetriever", "MilvusRetriever", "MilvusVectorDB", "PineconeRetriever", "PineconeVectorDB"] diff --git a/src/rag_pipelines/vectordb/__pycache__/__init__.cpython-310.pyc b/src/rag_pipelines/vectordb/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a008d125a9f96c8487195e14cfd0d4c10109793 Binary files /dev/null and b/src/rag_pipelines/vectordb/__pycache__/__init__.cpython-310.pyc differ diff --git 
a/src/rag_pipelines/vectordb/__pycache__/dspy_milvus_retriever.cpython-310.pyc b/src/rag_pipelines/vectordb/__pycache__/dspy_milvus_retriever.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad9b8f2d084828442b85755a628fe46eb62e3239 Binary files /dev/null and b/src/rag_pipelines/vectordb/__pycache__/dspy_milvus_retriever.cpython-310.pyc differ diff --git a/src/rag_pipelines/vectordb/__pycache__/milvus.cpython-310.pyc b/src/rag_pipelines/vectordb/__pycache__/milvus.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf55214ea579e909116f000ad55a507b81abe68b Binary files /dev/null and b/src/rag_pipelines/vectordb/__pycache__/milvus.cpython-310.pyc differ diff --git a/src/rag_pipelines/vectordb/__pycache__/milvus_retriever.cpython-310.pyc b/src/rag_pipelines/vectordb/__pycache__/milvus_retriever.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03c28b4eac798ff5ca66f39f8c604417a2bf7eb8 Binary files /dev/null and b/src/rag_pipelines/vectordb/__pycache__/milvus_retriever.cpython-310.pyc differ diff --git a/src/rag_pipelines/vectordb/__pycache__/pinecone.cpython-310.pyc b/src/rag_pipelines/vectordb/__pycache__/pinecone.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ada41711692d56ede2ec0626928ea9d3c15457db Binary files /dev/null and b/src/rag_pipelines/vectordb/__pycache__/pinecone.cpython-310.pyc differ diff --git a/src/rag_pipelines/vectordb/__pycache__/pinecone_retriever.cpython-310.pyc b/src/rag_pipelines/vectordb/__pycache__/pinecone_retriever.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..629a7b59067b3af61e23cd06c670e6a2d8f657d7 Binary files /dev/null and b/src/rag_pipelines/vectordb/__pycache__/pinecone_retriever.cpython-310.pyc differ diff --git a/src/rag_pipelines/vectordb/dspy_milvus_retriever.py b/src/rag_pipelines/vectordb/dspy_milvus_retriever.py new file mode 100644 index 
class MilvusRetriever(dspy.Retrieve):
    """Combine dense and sparse retrieval methods using Milvus.

    This class sets up a hybrid retriever that leverages dense embeddings (e.g., transformer-based)
    and sparse embeddings (e.g., BM25 or SPLADE) to enhance retrieval accuracy. By balancing
    between dense contextual similarity and sparse keyword-based matching, it improves search results.
    """

    collection: Optional[Collection] = None
    dense_embedding_model: Optional[HuggingFaceEmbeddings] = None
    sparse_embedding_model: Optional[SparseEmbeddings] = None
    anns_fields: Optional[list[str]] = None
    field_search_params: Optional[list[dict[str, Any]]] = None
    text_field: str = TEXT_FIELD
    top_k: int = 4
    rerank: Optional[WeightedRanker] = WeightedRanker(0.5, 0.5)
    hybrid_retriever: Optional[MilvusCollectionHybridSearchRetriever] = None

    def __init__(
        self,
        collection: Collection,
        dense_embedding_model: HuggingFaceEmbeddings,
        sparse_embedding_model: SparseEmbeddings,
        anns_fields: Optional[list[str]] = None,
        field_search_params: Optional[list[dict[str, Any]]] = None,
        text_field: str = TEXT_FIELD,
        rerank: Optional[WeightedRanker] = WeightedRanker(0.5, 0.5),
        top_k: int = 4,
    ) -> None:
        """Initialize the hybrid retriever with specified parameters.

        This constructor configures the retriever to balance between dense and sparse retrieval
        methods.

        Args:
            collection (Collection): The Milvus collection instance for document retrieval.
            dense_embedding_model (HuggingFaceEmbeddings): Model for generating dense embeddings.
            sparse_embedding_model (SparseEmbeddings): Model for generating sparse embeddings.
            anns_fields (Optional[list[str]]): List of fields to search in the collection.
                Defaults to the dense and sparse vector fields.
            field_search_params (Optional[list[dict[str, Any]]]): Parameters for field-specific search.
            text_field (str): Field name for text content. Defaults to "text".
            rerank (Optional[WeightedRanker]): Weighted ranker for reranking results.
            top_k (int): Maximum number of top documents to retrieve.
        """
        # NOTE(review): dspy.Retrieve.__init__ is never called here — confirm
        # that the dspy base class does not require it for this usage.
        if field_search_params is None:
            field_search_params = [{"metric_type": "IP"}, {"metric_type": "IP", "params": {}}]
        if anns_fields is None:
            anns_fields = [DENSE_FIELD, SPARSE_FIELD]

        self.collection = collection
        self.dense_embedding_model = dense_embedding_model
        self.sparse_embedding_model = sparse_embedding_model
        self.anns_fields = anns_fields
        self.field_search_params = field_search_params
        self.text_field = text_field
        self.top_k = top_k
        self.rerank = rerank

        self._initialize_retriever()

    def _initialize_retriever(self) -> None:
        """Build the LangChain hybrid retriever from the current settings."""
        self.hybrid_retriever = MilvusCollectionHybridSearchRetriever(
            collection=self.collection,
            rerank=self.rerank,
            anns_fields=self.anns_fields,
            field_embeddings=[self.dense_embedding_model, self.sparse_embedding_model],
            field_search_params=self.field_search_params,
            text_field=self.text_field,
            top_k=self.top_k,
        )

    @weave.op()
    def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = None) -> dspy.Prediction:
        """Query the hybrid retriever for relevant documents.

        Args:
            query_or_queries (Union[str, list[str]]): Either a single query or a list of queries.
            k (Optional[int]): Maximum number of top documents to retrieve per query.
                Defaults to None (keep the configured top_k).

        Returns:
            dspy.Prediction: A Prediction object containing the retrieved passages.
        """
        # BUGFIX: the original set self.top_k but the hybrid retriever captured
        # top_k at construction, so a per-call k had no effect. Rebuild the
        # retriever when k changes.
        if k and k != self.top_k:
            self.top_k = k
            self._initialize_retriever()

        queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries

        documents = []
        for question in queries:
            # BUGFIX: use extend, not append. invoke() returns a list of
            # Documents; the original appended the list itself for multi-query
            # input, which then broke on doc.page_content below.
            documents.extend(self.hybrid_retriever.invoke(question))

        return Prediction(passages=[dotdict({"long_text": doc.page_content}) for doc in documents])
+ """ + + uri: str + token: str + collection_name: str + collection_schema: Optional[CollectionSchema] = None + pk_field: str + dense_field: str + sparse_field: str + text_field: str + metadata_field: str + dense_index_params: Optional[dict] = None + sparse_index_params: Optional[dict] = None + create_new_collection: bool + collection: Optional[Collection] = None + + def __init__( + self, + uri: str, + token: str, + collection_name: str = "default", + collection_schema: Optional[CollectionSchema] = None, + pk_field: str = ID_FIELD, + dense_field: str = DENSE_FIELD, + sparse_field: str = SPARSE_FIELD, + text_field: str = TEXT_FIELD, + metadata_field: str = METADATA_FIELD, + dense_index_params: Optional[dict] = None, + sparse_index_params: Optional[dict] = None, + create_new_collection: bool = False, + ) -> None: + """Initialize the MilvusVectorDB. + + This sets up the Milvus client, creates a collection if it does not already exist, + and prepares it for hybrid search using dense and sparse embeddings. + + Args: + uri (str): Milvus server URI. + token (str): Milvus server token. + collection_name (str): Name of the Milvus collection to use or create. Defaults to "default". + collection_schema (Optional[CollectionSchema]): Schema for the Milvus collection. Defaults to None. + dense_field (str): Field name for dense embeddings. Defaults to "dense_vector". + sparse_field (str): Field name for sparse embeddings. Defaults to "sparse_vector". + text_field (str): Field name for text content. Defaults to "text". + metadata_field (str): Field name for metadata. Defaults to "metadata". + dense_index_params (Optional[dict]): Index parameters for dense embeddings. Defaults to None. + sparse_index_params (Optional[dict]): Index parameters for sparse embeddings. Defaults to None. + create_new_collection (bool): Flag indicating whether to create a new collection. Defaults to False. 
+ """ + self.uri = uri + self.token = token + self.collection_name = collection_name + self.collection_schema = collection_schema + self.pk_field = pk_field + self.text_field = text_field + self.metadata_field = metadata_field + self.dense_field = dense_field + self.sparse_field = sparse_field + self.dense_index_params = dense_index_params or {} + self.sparse_index_params = sparse_index_params or {} + + connections.connect(uri=self.uri, token=self.token) + + if create_new_collection: + self.create_collection() + + else: + self.collection = Collection(name=self.collection_name) + + def create_collection(self) -> None: + """Create the Milvus collection if it does not already exist. + + This method connects to the specified Milvus server and creates a new collection + with the provided schema, dense field, and index parameters. + """ + self.collection = Collection(name=self.collection_name, schema=self.collection_schema) + + self.collection.create_index(self.dense_field, self.dense_index_params) + self.collection.create_index(self.sparse_field, self.sparse_index_params) + self.collection.flush() + + @weave.op() + def add_documents( + self, + documents: list[Document], + dense_embedding_model, + sparse_embedding_model, + batch_size: int = 100, + ) -> None: + """Add documents to the Milvus collection with hybrid embeddings in batches. + + This method processes documents in batches, generates dense and sparse embeddings, + and upserts them into the Milvus collection. + + Args: + documents (list[Document]): List of documents to add to the collection. + dense_embedding_model: Model for generating dense embeddings. + sparse_embedding_model: Model for generating sparse embeddings. + batch_size (int): Number of documents to process per batch. Defaults to 100. + """ + if not documents: + msg = "The documents list is empty, provide valid documents to add." 
+ raise ValueError(msg) + + texts = [doc.page_content for doc in documents] + metadata_list = [doc.metadata for doc in documents] + + with tqdm(total=len(texts), desc="Processing batches") as pbar: + for i in range(0, len(texts), batch_size): + batch_texts = texts[i : i + batch_size] + batch_metadata = metadata_list[i : i + batch_size] + + # Batch embedding generation + dense_embeddings = dense_embedding_model.embed_documents(batch_texts) + sparse_embeddings = sparse_embedding_model.embed_documents(batch_texts) + + # Create entities for current batch + batch_entities = [ + { + self.pk_field: str(i + j), + self.text_field: text, + self.dense_field: dense_emb, + self.sparse_field: sparse_emb, + self.metadata_field: meta, + } + for j, (text, meta, dense_emb, sparse_emb) in enumerate( + zip(batch_texts, batch_metadata, dense_embeddings, sparse_embeddings) + ) + ] + + # Upsert batch into Milvus + self.collection.upsert(batch_entities) + pbar.update(len(batch_texts)) + + self.collection.load() diff --git a/src/rag_pipelines/vectordb/milvus_retriever.py b/src/rag_pipelines/vectordb/milvus_retriever.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddaee1613834be979b2bc8bfbe9efdf1cbbd7f3 --- /dev/null +++ b/src/rag_pipelines/vectordb/milvus_retriever.py @@ -0,0 +1,129 @@ +from typing import Any, Optional + +import weave +# from langchain_huggingface import HuggingFaceEmbeddings +from langchain_voyageai import VoyageAIEmbeddings +from langchain_milvus.retrievers import MilvusCollectionHybridSearchRetriever +from pymilvus import ( + Collection, + WeightedRanker, +) + +from rag_pipelines.embeddings.sparse_milvus import SparseEmbeddingsMilvus as SparseEmbeddings + +TEXT_FIELD = "text" +DENSE_FIELD = "dense_vector" +SPARSE_FIELD = "sparse_vector" + + +class MilvusRetriever(weave.Model): + """Combine dense and sparse retrieval methods using Milvus. 
+
+    This class sets up a hybrid retriever that leverages dense embeddings (e.g., transformer-based)
+    and sparse embeddings (e.g., BM25 or SPLADE) to enhance retrieval accuracy. By balancing
+    between dense contextual similarity and sparse keyword-based matching, it improves search results.
+    """
+
+    collection: Optional[Collection] = None
+    dense_embedding_model: Optional[VoyageAIEmbeddings] = None
+    sparse_embedding_model: Optional[SparseEmbeddings] = None
+    anns_fields: Optional[list[str]] = None
+    field_search_params: Optional[list[dict[str, Any]]] = None
+    text_field: str = TEXT_FIELD
+    top_k: int = 4
+    rerank: Optional[WeightedRanker] = WeightedRanker(0.5, 0.5)
+    hybrid_retriever: Optional[MilvusCollectionHybridSearchRetriever] = None
+
+    def __init__(
+        self,
+        collection: Collection,
+        dense_embedding_model: VoyageAIEmbeddings,
+        sparse_embedding_model: SparseEmbeddings,
+        anns_fields: Optional[list[str]] = None,
+        field_search_params: Optional[list[dict[str, Any]]] = None,
+        text_field: str = TEXT_FIELD,
+        rerank: Optional[WeightedRanker] = WeightedRanker(0.5, 0.5),
+        top_k: int = 4,
+    ) -> None:
+        """Initialize the hybrid retriever with specified parameters.
+
+        This constructor configures the retriever to balance between dense and sparse retrieval
+        methods.
+
+        Args:
+            collection (Collection): The Milvus collection instance for document retrieval.
+            dense_embedding_model (VoyageAIEmbeddings): Model for generating dense embeddings.
+            sparse_embedding_model (SparseEmbeddings): Model for generating sparse embeddings.
+            anns_fields (Optional[list[str]]): List of fields to search in the collection.
+            field_search_params (Optional[list[dict[str, Any]]]): Per-field search parameters, one dict per entry in ``anns_fields``.
+            text_field (str): Field name for text content. Defaults to "text".
+            rerank (Optional[WeightedRanker]): Weighted ranker for reranking results.
+            top_k (int): Maximum number of top documents to retrieve.
+ """ + super().__init__( + collection=collection, + dense_embedding_model=dense_embedding_model, + sparse_embedding_model=sparse_embedding_model, + anns_fields=anns_fields, + field_search_params=field_search_params, + text_field=text_field, + top_k=top_k, + rerank=rerank, + ) + + if field_search_params is None: + field_search_params = [{"metric_type": "IP"}, {"metric_type": "IP", "params": {}}] + if anns_fields is None: + anns_fields = [DENSE_FIELD, SPARSE_FIELD] + + self.collection = collection + self.dense_embedding_model = dense_embedding_model + self.sparse_embedding_model = sparse_embedding_model + self.anns_fields = anns_fields + self.field_search_params = field_search_params + self.text_field = text_field + self.top_k = top_k + self.rerank = rerank + + self._initialize_retriever() + + def _initialize_retriever(self): + self.hybrid_retriever = MilvusCollectionHybridSearchRetriever( + collection=self.collection, + rerank=self.rerank, + anns_fields=self.anns_fields, + field_embeddings=[self.dense_embedding_model, self.sparse_embedding_model], + field_search_params=self.field_search_params, + text_field=self.text_field, + top_k=self.top_k, + ) + + @weave.op() + def __call__(self, state: dict[str, Any]) -> dict[str, Any]: + """Retrieve documents and update the state with the results. + + This method accepts a state dictionary containing a query string under the key "question." + It uses the hybrid retriever to fetch relevant documents and updates the state with the + retrieved documents. + + Args: + state (Dict[str, Any]): Input state containing: + - "question" (str): The query string to retrieve documents for. + + Returns: + Dict[str, Any]: Updated state containing: + - "documents" (List[Any]): Retrieved documents matching the query. + - "question" (str): The original query string. 
+
+        Example:
+            ```python
+            retriever = MilvusRetriever(collection, dense_model, sparse_model)
+            state = {"question": "What is quantum computing?"}
+            updated_state = retriever(state)
+            print(updated_state["documents"])  # Output: Retrieved documents
+            ```
+        """
+        question = state["question"]
+        documents = self.hybrid_retriever.invoke(question)
+
+        return {"documents": documents, "question": question}
diff --git a/src/rag_pipelines/vectordb/pinecone.py b/src/rag_pipelines/vectordb/pinecone.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9fdaadd830ac9ffab3750afbb023047e9b5f633
--- /dev/null
+++ b/src/rag_pipelines/vectordb/pinecone.py
+import os
+from typing import Any, Optional
+
+from langchain_community.retrievers import PineconeHybridSearchRetriever
+from langchain_core.documents import Document
+from pinecone import Pinecone, ServerlessSpec
+
+
+class PineconeVectorDB:
+    """Manage interactions with a Pinecone vector index for hybrid search.
+
+    This class facilitates initializing a Pinecone index, adding documents with dense and sparse embeddings,
+    and configuring a hybrid retriever for effective search across both embeddings.
+
+    Attributes:
+        pinecone_client (Pinecone): Client instance for interacting with Pinecone services.
+        index (Pinecone.Index): The active Pinecone index for storing and retrieving vector data.
+        hybrid_retriever (Optional[PineconeHybridSearchRetriever]): A retriever supporting hybrid search combining dense and sparse embeddings.
+    """
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        index_name: str = "default",
+        dimension: int = 512,
+        metric: str = "dotproduct",
+        region: str = "us-east-1",
+        cloud: str = "aws",
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the PineconeVectorDB.
+
+        This sets up the Pinecone client, creates an index if it does not already exist,
+        and prepares it for hybrid search using dense and sparse embeddings.
+ + Args: + api_key (Optional[str]): Pinecone API key; if not provided, attempts to retrieve it from environment variables. + index_name (str): Name of the Pinecone index to use or create. Defaults to "default". + dimension (int): Dimensionality of vectors in the index. Defaults to 512. + metric (str): Similarity metric for vector search. Defaults to "dotproduct". + region (str): Region where the index will be created. Defaults to "us-east-1". + cloud (str): Cloud provider for the index. Defaults to "aws". + **kwargs: Additional parameters for configuring the index. + + Raises: + OSError: If the API key is not provided and cannot be retrieved from environment variables. + """ + # Retrieve API key from environment or provided input + self.api_key = api_key or os.getenv("PINECONE_API_KEY") + if not self.api_key: + msg = "Pinecone API key is missing; provide it directly or via environment variables." + raise OSError(msg) + + # Initialize Pinecone client + self.pinecone_client = Pinecone(api_key=self.api_key) + + # Set up index parameters + self.index_name = index_name + self.index_params = { + "dimension": dimension, + "metric": metric, + "region": region, + "cloud": cloud, + **kwargs, + } + + # Initialize or create the specified index + self.initialize_index() + + # Access the specified index + self.index = self.pinecone_client.Index(self.index_name) + + # Hybrid retriever, initially unset + self.hybrid_retriever: Optional[PineconeHybridSearchRetriever] = None + + def initialize_index(self) -> None: + """Create the Pinecone index if it does not exist. + + This method checks for the existence of the index specified during initialization. + If the index does not exist, it creates one with the specified parameters. + + Raises: + Exception: If there is an issue creating or accessing the index. 
+ """ + # Check if the index exists; if not, create it + if self.index_name not in self.pinecone_client.list_indexes().names(): + self.pinecone_client.create_index( + name=self.index_name, + dimension=self.index_params["dimension"], + metric=self.index_params["metric"], + spec=ServerlessSpec(cloud=self.index_params["cloud"], region=self.index_params["region"]), + ) + + def add_documents( + self, + documents: list[Document], + dense_embedding_model, + sparse_embedding_model, + namespace: str, + ) -> None: + """Add documents to the Pinecone index with hybrid embeddings. + + This method processes a list of documents, generates dense and sparse embeddings, + and stores them in the Pinecone index under the specified namespace. + + Args: + documents (List[Document]): List of documents to be added, each containing `page_content` and metadata. + dense_embedding_model: Model used to generate dense embeddings for the documents. + sparse_embedding_model: Model used to generate sparse embeddings for the documents. + namespace (str): Namespace in the Pinecone index to isolate these documents. + + Raises: + ValueError: If documents list is empty or lacks required attributes (`page_content` or `metadata`). + """ + # Validate input documents + if not documents: + msg = "The documents list is empty, provide valid documents to add." 
+ raise ValueError(msg) + + # Extract text, metadata, and unique IDs from the documents + texts = [doc.page_content for doc in documents] + metadata_list = [doc.metadata for doc in documents] + ids = [doc.id for doc in documents] + + # Initialize the hybrid retriever if it hasn't been initialized yet + if not self.hybrid_retriever: + self.hybrid_retriever = PineconeHybridSearchRetriever( + embeddings=dense_embedding_model, + sparse_encoder=sparse_embedding_model, + index=self.index, + ) + + # Add texts to the Pinecone index using the retriever + self.hybrid_retriever.add_texts(texts=texts, ids=ids, metadatas=metadata_list, namespace=namespace) diff --git a/src/rag_pipelines/vectordb/pinecone_retriever.py b/src/rag_pipelines/vectordb/pinecone_retriever.py new file mode 100644 index 0000000000000000000000000000000000000000..412d38b76cf677d093c854cd22a5e22ae51c6790 --- /dev/null +++ b/src/rag_pipelines/vectordb/pinecone_retriever.py @@ -0,0 +1,111 @@ +from typing import Any, Optional + +import weave +from langchain_community.retrievers import PineconeHybridSearchRetriever +from langchain_huggingface import HuggingFaceEmbeddings +from pinecone.data.index import Index +from pinecone_text.sparse import SpladeEncoder + + +class PineconeRetriever(weave.Model): + """Combine dense and sparse retrieval methods using Pinecone. + + This class sets up a hybrid retriever that leverages dense embeddings (e.g., transformer-based) + and sparse embeddings (e.g., BM25 or SPLADE) to enhance retrieval accuracy. By balancing + between dense contextual similarity and sparse keyword-based matching, it improves search results. + + Attributes: + index (Optional[Index]): The Pinecone index instance used for document retrieval. + dense_embedding_model (Optional[HuggingFaceEmbeddings]): Model for generating dense embeddings. + sparse_embedding_model (Optional[SpladeEncoder]): Model for generating sparse embeddings. 
+ namespace (str): Namespace for isolating the retrieval space within the Pinecone index. + alpha (float): Weighting factor for dense vs. sparse retrieval. Closer to 1 favors dense, closer to 0 favors sparse. + top_k (int): Maximum number of top documents to retrieve. + hybrid_retriever (Optional[PineconeHybridSearchRetriever]): Configured retriever instance. + """ + + index: Optional[Index] = None + dense_embedding_model: Optional[HuggingFaceEmbeddings] = None + sparse_embedding_model: Optional[SpladeEncoder] = None + namespace: str + alpha: float + top_k: int + hybrid_retriever: Optional[PineconeHybridSearchRetriever] = None + + def __init__( + self, + index: Index, + dense_embedding_model: HuggingFaceEmbeddings, + sparse_embedding_model: SpladeEncoder, + namespace: str, + alpha: float = 0.5, + top_k: int = 4, + ) -> None: + """Initialize the hybrid retriever with specified parameters. + + This constructor configures the retriever to balance between dense and sparse retrieval + methods and sets up a namespace for isolated searches in the Pinecone index. + + Args: + index (Index): The Pinecone index instance for document retrieval. + dense_embedding_model (HuggingFaceEmbeddings): Model for generating dense embeddings. + sparse_embedding_model (SpladeEncoder): Model for generating sparse embeddings. + namespace (str): String to isolate the retrieval context within the Pinecone index. + alpha (float): Weighting factor for dense vs. sparse retrieval, in the range [0, 1]. Defaults to 0.5. + top_k (int): Maximum number of documents to retrieve. Defaults to 4. 
+ """ + super().__init__( + index=index, + dense_embedding_model=dense_embedding_model, + sparse_embedding_model=sparse_embedding_model, + namespace=namespace, + alpha=alpha, + top_k=top_k, + ) + + self.index = index + self.dense_embedding_model = dense_embedding_model + self.sparse_embedding_model = sparse_embedding_model + self.namespace = namespace + self.alpha = alpha + self.top_k = top_k + + # Initialize the hybrid retriever + if not self.hybrid_retriever: + self.hybrid_retriever = PineconeHybridSearchRetriever( + embeddings=self.dense_embedding_model, + sparse_encoder=self.sparse_embedding_model, + index=self.index, + namespace=self.namespace, + alpha=self.alpha, + top_k=self.top_k, + ) + + def __call__(self, state: dict[str, Any]) -> dict[str, Any]: + """Retrieve documents and update the state with the results. + + This method accepts a state dictionary containing a query string under the key "question." + It uses the hybrid retriever to fetch relevant documents and updates the state with the + retrieved documents. + + Args: + state (Dict[str, Any]): Input state containing: + - "question" (str): The query string to retrieve documents for. + + Returns: + Dict[str, Any]: Updated state containing: + - "documents" (List[Any]): Retrieved documents matching the query. + - "question" (str): The original query string. 
+
+        Example:
+            ```python
+            retriever = PineconeRetriever(index, dense_model, sparse_model, "my_namespace")
+            state = {"question": "What is quantum computing?"}
+            updated_state = retriever(state)
+            print(updated_state["documents"])  # Output: Retrieved documents
+            ```
+        """
+        question = state["question"]
+        documents = self.hybrid_retriever.invoke(question)
+
+        return {"documents": documents, "question": question}
diff --git a/src/rag_pipelines/websearch/__init__.py b/src/rag_pipelines/websearch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b5c5b82a0d45080a04276307493f7be5e016a7d
--- /dev/null
+++ b/src/rag_pipelines/websearch/__init__.py
+from rag_pipelines.websearch.web_search import WebSearch
+
+__all__ = [
+    "WebSearch",
+]
diff --git a/src/rag_pipelines/websearch/web_search.py b/src/rag_pipelines/websearch/web_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..374858e72fa8c7124b47f6d077e64d2506391a42
--- /dev/null
+++ b/src/rag_pipelines/websearch/web_search.py
+from typing import Any
+
+from langchain_community.tools import DuckDuckGoSearchRun
+from langchain_core.documents import Document
+
+
+class WebSearch:
+    """Perform web searches and append results to filtered documents.
+
+    This class uses the DuckDuckGo search tool to perform web searches based on user queries. The search results
+    are converted into `Document` objects and appended to an existing list of documents for further use.
+
+    Attributes:
+        search_tool (DuckDuckGoSearchRun): The DuckDuckGo search tool used to perform web searches.
+    """
+
+    def __init__(self):
+        """Initialize the WebSearch class with the DuckDuckGo search tool.
+ + Example: + ```python + web_searcher = WebSearch() + ``` + """ + # Initialize the DuckDuckGo search tool + self.search_tool = DuckDuckGoSearchRun() + + def search(self, question: str, filtered_docs: list[Document]) -> list[Document]: + """Perform a web search using DuckDuckGo and append results to filtered documents. + + This method conducts a web search based on the provided question and appends the resulting web search + results to the given list of filtered documents. + + Args: + question (str): The user query used for the web search. + filtered_docs (list[Document]): A list of existing documents to which web search results will be added. + + Returns: + list[Document]: An updated list of documents containing the original filtered documents along with + the newly retrieved web search results. + + Example: + ```python + question = "What is the latest AI research?" + filtered_docs = [Document(page_content="Existing content")] + + updated_docs = web_searcher.search(question, filtered_docs) + print(updated_docs) + ``` + """ + # Conduct a web search using the DuckDuckGo search tool + search_results = self.search_tool.invoke({"query": question}) + + # Create a Document object from the search results + web_results = Document(page_content=search_results) + + # Append the web search results to the filtered documents list + filtered_docs.append(web_results) + + # Return the updated list of documents + return filtered_docs + + def __call__(self, state: dict[str, Any]) -> dict[str, Any]: + """Process the input state and append web search results to its documents. + + This callable interface extracts a question and a list of documents from the input state, performs a web + search, and appends the results to the documents in the state. + + Args: + state (dict[str, Any]): A dictionary containing: + - 'question' (str): The user query for the web search. + - 'documents' (list[Document]): Existing documents to which web search results will be added. 
+ + Returns: + dict[str, Any]: An updated state dictionary containing: + - 'question' (str): The original query. + - 'documents' (list[Document]): The updated list of documents with appended web search results. + + Example: + ```python + state = { + "question": "What is the latest AI research?", + "documents": [Document(page_content="Existing content")] + } + web_searcher = WebSearch() + updated_state = web_searcher(state) + print(updated_state["documents"]) + ``` + """ + # Extract the question and documents from the state + question = state["question"] + documents = state["documents"] + + # Perform a web search and append the results to the existing documents + updated_docs = self.search(question=question, filtered_docs=documents) + + # Return the updated state with the new documents + return {"question": question, "documents": updated_docs}