| |
| |
| import json |
| import os |
| import re |
| import unittest |
| from pathlib import Path |
| from typing import Any, cast |
| from unittest import mock |
|
|
| import pytest |
| import yaml |
| from pydantic import ValidationError |
|
|
| import graphrag.config.defaults as defs |
| from graphrag.config import ( |
| ApiKeyMissingError, |
| AzureApiBaseMissingError, |
| AzureDeploymentNameMissingError, |
| CacheConfig, |
| CacheConfigInput, |
| CacheType, |
| ChunkingConfig, |
| ChunkingConfigInput, |
| ClaimExtractionConfig, |
| ClaimExtractionConfigInput, |
| ClusterGraphConfig, |
| ClusterGraphConfigInput, |
| CommunityReportsConfig, |
| CommunityReportsConfigInput, |
| EmbedGraphConfig, |
| EmbedGraphConfigInput, |
| EntityExtractionConfig, |
| EntityExtractionConfigInput, |
| GlobalSearchConfig, |
| GraphRagConfig, |
| GraphRagConfigInput, |
| InputConfig, |
| InputConfigInput, |
| InputFileType, |
| InputType, |
| LLMParameters, |
| LLMParametersInput, |
| LocalSearchConfig, |
| ParallelizationParameters, |
| ReportingConfig, |
| ReportingConfigInput, |
| ReportingType, |
| SnapshotsConfig, |
| SnapshotsConfigInput, |
| StorageConfig, |
| StorageConfigInput, |
| StorageType, |
| SummarizeDescriptionsConfig, |
| SummarizeDescriptionsConfigInput, |
| TextEmbeddingConfig, |
| TextEmbeddingConfigInput, |
| UmapConfig, |
| UmapConfigInput, |
| create_graphrag_config, |
| ) |
| from graphrag.index import ( |
| PipelineConfig, |
| PipelineCSVInputConfig, |
| PipelineFileCacheConfig, |
| PipelineFileReportingConfig, |
| PipelineFileStorageConfig, |
| PipelineInputConfig, |
| PipelineTextInputConfig, |
| PipelineWorkflowReference, |
| create_pipeline_config, |
| ) |
|
|
# Directory containing this test module; handy for resolving test-relative paths.
# NOTE(review): not referenced in the visible portion of this file — confirm it
# is used further down before removing.
current_dir = os.path.dirname(__file__)
|
|
# One concrete value for every GRAPHRAG_* environment variable the config loader
# reads. test_create_parameters_from_env_vars sets all of these at once and
# asserts each one surfaces on the parsed config; test_all_env_vars_is_accurate
# cross-checks this mapping against the documentation site.
ALL_ENV_VARS = {
    "GRAPHRAG_API_BASE": "http://some/base",
    "GRAPHRAG_API_KEY": "test",
    "GRAPHRAG_API_ORGANIZATION": "test_org",
    "GRAPHRAG_API_PROXY": "http://some/proxy",
    "GRAPHRAG_API_VERSION": "v1234",
    "GRAPHRAG_ASYNC_MODE": "asyncio",
    "GRAPHRAG_CACHE_STORAGE_ACCOUNT_BLOB_URL": "cache_account_blob_url",
    "GRAPHRAG_CACHE_BASE_DIR": "/some/cache/dir",
    "GRAPHRAG_CACHE_CONNECTION_STRING": "test_cs1",
    "GRAPHRAG_CACHE_CONTAINER_NAME": "test_cn1",
    "GRAPHRAG_CACHE_TYPE": "blob",
    "GRAPHRAG_CHUNK_BY_COLUMNS": "a,b",
    "GRAPHRAG_CHUNK_OVERLAP": "12",
    "GRAPHRAG_CHUNK_SIZE": "500",
    "GRAPHRAG_CLAIM_EXTRACTION_ENABLED": "True",
    "GRAPHRAG_CLAIM_EXTRACTION_DESCRIPTION": "test 123",
    "GRAPHRAG_CLAIM_EXTRACTION_MAX_GLEANINGS": "5000",
    "GRAPHRAG_CLAIM_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-a.txt",
    "GRAPHRAG_COMMUNITY_REPORTS_MAX_LENGTH": "23456",
    "GRAPHRAG_COMMUNITY_REPORTS_PROMPT_FILE": "tests/unit/config/prompt-b.txt",
    "GRAPHRAG_EMBEDDING_BATCH_MAX_TOKENS": "17",
    "GRAPHRAG_EMBEDDING_BATCH_SIZE": "1000000",
    "GRAPHRAG_EMBEDDING_CONCURRENT_REQUESTS": "12",
    "GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME": "model-deployment-name",
    "GRAPHRAG_EMBEDDING_MAX_RETRIES": "3",
    "GRAPHRAG_EMBEDDING_MAX_RETRY_WAIT": "0.1123",
    "GRAPHRAG_EMBEDDING_MODEL": "text-embedding-2",
    "GRAPHRAG_EMBEDDING_REQUESTS_PER_MINUTE": "500",
    "GRAPHRAG_EMBEDDING_SKIP": "a1,b1,c1",
    "GRAPHRAG_EMBEDDING_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
    "GRAPHRAG_EMBEDDING_TARGET": "all",
    "GRAPHRAG_EMBEDDING_THREAD_COUNT": "2345",
    "GRAPHRAG_EMBEDDING_THREAD_STAGGER": "0.456",
    "GRAPHRAG_EMBEDDING_TOKENS_PER_MINUTE": "7000",
    "GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
    "GRAPHRAG_ENCODING_MODEL": "test123",
    "GRAPHRAG_INPUT_STORAGE_ACCOUNT_BLOB_URL": "input_account_blob_url",
    "GRAPHRAG_ENTITY_EXTRACTION_ENTITY_TYPES": "cat,dog,elephant",
    "GRAPHRAG_ENTITY_EXTRACTION_MAX_GLEANINGS": "112",
    "GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-c.txt",
    "GRAPHRAG_INPUT_BASE_DIR": "/some/input/dir",
    "GRAPHRAG_INPUT_CONNECTION_STRING": "input_cs",
    "GRAPHRAG_INPUT_CONTAINER_NAME": "input_cn",
    "GRAPHRAG_INPUT_DOCUMENT_ATTRIBUTE_COLUMNS": "test1,test2",
    "GRAPHRAG_INPUT_ENCODING": "utf-16",
    "GRAPHRAG_INPUT_FILE_PATTERN": ".*\\test\\.txt$",
    "GRAPHRAG_INPUT_SOURCE_COLUMN": "test_source",
    "GRAPHRAG_INPUT_TYPE": "blob",
    "GRAPHRAG_INPUT_TEXT_COLUMN": "test_text",
    "GRAPHRAG_INPUT_TIMESTAMP_COLUMN": "test_timestamp",
    "GRAPHRAG_INPUT_TIMESTAMP_FORMAT": "test_format",
    "GRAPHRAG_INPUT_TITLE_COLUMN": "test_title",
    "GRAPHRAG_INPUT_FILE_TYPE": "text",
    "GRAPHRAG_LLM_CONCURRENT_REQUESTS": "12",
    "GRAPHRAG_LLM_DEPLOYMENT_NAME": "model-deployment-name-x",
    "GRAPHRAG_LLM_MAX_RETRIES": "312",
    "GRAPHRAG_LLM_MAX_RETRY_WAIT": "0.1122",
    "GRAPHRAG_LLM_MAX_TOKENS": "15000",
    "GRAPHRAG_LLM_MODEL_SUPPORTS_JSON": "true",
    "GRAPHRAG_LLM_MODEL": "test-llm",
    "GRAPHRAG_LLM_N": "1",
    "GRAPHRAG_LLM_REQUEST_TIMEOUT": "12.7",
    "GRAPHRAG_LLM_REQUESTS_PER_MINUTE": "900",
    "GRAPHRAG_LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
    "GRAPHRAG_LLM_THREAD_COUNT": "987",
    "GRAPHRAG_LLM_THREAD_STAGGER": "0.123",
    "GRAPHRAG_LLM_TOKENS_PER_MINUTE": "8000",
    "GRAPHRAG_LLM_TYPE": "azure_openai_chat",
    "GRAPHRAG_MAX_CLUSTER_SIZE": "123",
    "GRAPHRAG_NODE2VEC_ENABLED": "true",
    "GRAPHRAG_NODE2VEC_ITERATIONS": "878787",
    "GRAPHRAG_NODE2VEC_NUM_WALKS": "5000000",
    "GRAPHRAG_NODE2VEC_RANDOM_SEED": "010101",
    "GRAPHRAG_NODE2VEC_WALK_LENGTH": "555111",
    "GRAPHRAG_NODE2VEC_WINDOW_SIZE": "12345",
    "GRAPHRAG_REPORTING_STORAGE_ACCOUNT_BLOB_URL": "reporting_account_blob_url",
    "GRAPHRAG_REPORTING_BASE_DIR": "/some/reporting/dir",
    "GRAPHRAG_REPORTING_CONNECTION_STRING": "test_cs2",
    "GRAPHRAG_REPORTING_CONTAINER_NAME": "test_cn2",
    "GRAPHRAG_REPORTING_TYPE": "blob",
    "GRAPHRAG_SKIP_WORKFLOWS": "a,b,c",
    "GRAPHRAG_SNAPSHOT_GRAPHML": "true",
    "GRAPHRAG_SNAPSHOT_RAW_ENTITIES": "true",
    "GRAPHRAG_SNAPSHOT_TOP_LEVEL_NODES": "true",
    "GRAPHRAG_STORAGE_STORAGE_ACCOUNT_BLOB_URL": "storage_account_blob_url",
    "GRAPHRAG_STORAGE_BASE_DIR": "/some/storage/dir",
    "GRAPHRAG_STORAGE_CONNECTION_STRING": "test_cs",
    "GRAPHRAG_STORAGE_CONTAINER_NAME": "test_cn",
    "GRAPHRAG_STORAGE_TYPE": "blob",
    "GRAPHRAG_SUMMARIZE_DESCRIPTIONS_MAX_LENGTH": "12345",
    "GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE": "tests/unit/config/prompt-d.txt",
    "GRAPHRAG_LLM_TEMPERATURE": "0.0",
    "GRAPHRAG_LLM_TOP_P": "1.0",
    "GRAPHRAG_UMAP_ENABLED": "true",
    "GRAPHRAG_LOCAL_SEARCH_TEXT_UNIT_PROP": "0.713",
    "GRAPHRAG_LOCAL_SEARCH_COMMUNITY_PROP": "0.1234",
    "GRAPHRAG_LOCAL_SEARCH_LLM_TEMPERATURE": "0.1",
    "GRAPHRAG_LOCAL_SEARCH_LLM_TOP_P": "0.9",
    "GRAPHRAG_LOCAL_SEARCH_LLM_N": "2",
    "GRAPHRAG_LOCAL_SEARCH_LLM_MAX_TOKENS": "12",
    "GRAPHRAG_LOCAL_SEARCH_TOP_K_RELATIONSHIPS": "15",
    "GRAPHRAG_LOCAL_SEARCH_TOP_K_ENTITIES": "14",
    "GRAPHRAG_LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS": "2",
    "GRAPHRAG_LOCAL_SEARCH_MAX_TOKENS": "142435",
    "GRAPHRAG_GLOBAL_SEARCH_LLM_TEMPERATURE": "0.1",
    "GRAPHRAG_GLOBAL_SEARCH_LLM_TOP_P": "0.9",
    "GRAPHRAG_GLOBAL_SEARCH_LLM_N": "2",
    "GRAPHRAG_GLOBAL_SEARCH_MAX_TOKENS": "5123",
    "GRAPHRAG_GLOBAL_SEARCH_DATA_MAX_TOKENS": "123",
    "GRAPHRAG_GLOBAL_SEARCH_MAP_MAX_TOKENS": "4123",
    "GRAPHRAG_GLOBAL_SEARCH_CONCURRENCY": "7",
    "GRAPHRAG_GLOBAL_SEARCH_REDUCE_MAX_TOKENS": "15432",
}
|
|
|
|
| class TestDefaultConfig(unittest.TestCase): |
| def test_clear_warnings(self): |
| """Just clearing unused import warnings""" |
| assert CacheConfig is not None |
| assert ChunkingConfig is not None |
| assert ClaimExtractionConfig is not None |
| assert ClusterGraphConfig is not None |
| assert CommunityReportsConfig is not None |
| assert EmbedGraphConfig is not None |
| assert EntityExtractionConfig is not None |
| assert GlobalSearchConfig is not None |
| assert GraphRagConfig is not None |
| assert InputConfig is not None |
| assert LLMParameters is not None |
| assert LocalSearchConfig is not None |
| assert ParallelizationParameters is not None |
| assert ReportingConfig is not None |
| assert SnapshotsConfig is not None |
| assert StorageConfig is not None |
| assert SummarizeDescriptionsConfig is not None |
| assert TextEmbeddingConfig is not None |
| assert UmapConfig is not None |
| assert PipelineConfig is not None |
| assert PipelineFileReportingConfig is not None |
| assert PipelineFileStorageConfig is not None |
| assert PipelineInputConfig is not None |
| assert PipelineFileCacheConfig is not None |
| assert PipelineWorkflowReference is not None |
|
|
    @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test"}, clear=True)
    def test_string_repr(self):
        """str() of both config objects must be JSON; repr() must be eval()-able."""
        # __str__ is expected to emit valid JSON.
        config = create_graphrag_config()
        string_repr = str(config)
        assert string_repr is not None
        assert json.loads(string_repr) is not None

        # __repr__ should round-trip via eval(); the AsyncType enum's default
        # repr is not valid Python, so that fragment is stripped out first.
        # NOTE(review): eval() is acceptable here only because the string comes
        # from our own trusted object, never from external input.
        repr_str = config.__repr__()

        repr_str = repr_str.replace("async_mode=<AsyncType.Threaded: 'threaded'>,", "")
        assert eval(repr_str) is not None

        # Same JSON check for the derived pipeline config.
        pipeline_config = create_pipeline_config(config)
        string_repr = str(pipeline_config)
        assert string_repr is not None
        assert json.loads(string_repr) is not None

        # Same eval() round-trip; note the enum fragment is dict-style here.
        repr_str = pipeline_config.__repr__()

        repr_str = repr_str.replace(
            "'async_mode': <AsyncType.Threaded: 'threaded'>,", ""
        )
        assert eval(repr_str) is not None
|
|
| @mock.patch.dict(os.environ, {}, clear=True) |
| def test_default_config_with_no_env_vars_throws(self): |
| with pytest.raises(ApiKeyMissingError): |
| |
| create_pipeline_config(create_graphrag_config()) |
|
|
| @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) |
| def test_default_config_with_api_key_passes(self): |
| |
| config = create_pipeline_config(create_graphrag_config()) |
| assert config is not None |
|
|
| @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test"}, clear=True) |
| def test_default_config_with_oai_key_passes_envvar(self): |
| |
| config = create_pipeline_config(create_graphrag_config()) |
| assert config is not None |
|
|
| def test_default_config_with_oai_key_passes_obj(self): |
| |
| config = create_pipeline_config( |
| create_graphrag_config({"llm": {"api_key": "test"}}) |
| ) |
| assert config is not None |
|
|
| @mock.patch.dict( |
| os.environ, |
| {"GRAPHRAG_API_KEY": "test", "GRAPHRAG_LLM_TYPE": "azure_openai_chat"}, |
| clear=True, |
| ) |
| def test_throws_if_azure_is_used_without_api_base_envvar(self): |
| with pytest.raises(AzureApiBaseMissingError): |
| create_graphrag_config() |
|
|
| @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) |
| def test_throws_if_azure_is_used_without_api_base_obj(self): |
| with pytest.raises(AzureApiBaseMissingError): |
| create_graphrag_config( |
| GraphRagConfigInput(llm=LLMParametersInput(type="azure_openai_chat")) |
| ) |
|
|
| @mock.patch.dict( |
| os.environ, |
| { |
| "GRAPHRAG_API_KEY": "test", |
| "GRAPHRAG_LLM_TYPE": "azure_openai_chat", |
| "GRAPHRAG_API_BASE": "http://some/base", |
| }, |
| clear=True, |
| ) |
| def test_throws_if_azure_is_used_without_llm_deployment_name_envvar(self): |
| with pytest.raises(AzureDeploymentNameMissingError): |
| create_graphrag_config() |
|
|
| @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) |
| def test_throws_if_azure_is_used_without_llm_deployment_name_obj(self): |
| with pytest.raises(AzureDeploymentNameMissingError): |
| create_graphrag_config( |
| GraphRagConfigInput( |
| llm=LLMParametersInput( |
| type="azure_openai_chat", api_base="http://some/base" |
| ) |
| ) |
| ) |
|
|
| @mock.patch.dict( |
| os.environ, |
| { |
| "GRAPHRAG_API_KEY": "test", |
| "GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding", |
| "GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME": "x", |
| }, |
| clear=True, |
| ) |
| def test_throws_if_azure_is_used_without_embedding_api_base_envvar(self): |
| with pytest.raises(AzureApiBaseMissingError): |
| create_graphrag_config() |
|
|
| @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) |
| def test_throws_if_azure_is_used_without_embedding_api_base_obj(self): |
| with pytest.raises(AzureApiBaseMissingError): |
| create_graphrag_config( |
| GraphRagConfigInput( |
| embeddings=TextEmbeddingConfigInput( |
| llm=LLMParametersInput( |
| type="azure_openai_embedding", |
| deployment_name="x", |
| ) |
| ), |
| ) |
| ) |
|
|
| @mock.patch.dict( |
| os.environ, |
| { |
| "GRAPHRAG_API_KEY": "test", |
| "GRAPHRAG_API_BASE": "http://some/base", |
| "GRAPHRAG_LLM_DEPLOYMENT_NAME": "x", |
| "GRAPHRAG_LLM_TYPE": "azure_openai_chat", |
| "GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding", |
| }, |
| clear=True, |
| ) |
| def test_throws_if_azure_is_used_without_embedding_deployment_name_envvar(self): |
| with pytest.raises(AzureDeploymentNameMissingError): |
| create_graphrag_config() |
|
|
| @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) |
| def test_throws_if_azure_is_used_without_embedding_deployment_name_obj(self): |
| with pytest.raises(AzureDeploymentNameMissingError): |
| create_graphrag_config( |
| GraphRagConfigInput( |
| llm=LLMParametersInput( |
| type="azure_openai_chat", |
| api_base="http://some/base", |
| deployment_name="model-deployment-name-x", |
| ), |
| embeddings=TextEmbeddingConfigInput( |
| llm=LLMParametersInput( |
| type="azure_openai_embedding", |
| ) |
| ), |
| ) |
| ) |
|
|
| @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) |
| def test_minimim_azure_config_object(self): |
| config = create_graphrag_config( |
| GraphRagConfigInput( |
| llm=LLMParametersInput( |
| type="azure_openai_chat", |
| api_base="http://some/base", |
| deployment_name="model-deployment-name-x", |
| ), |
| embeddings=TextEmbeddingConfigInput( |
| llm=LLMParametersInput( |
| type="azure_openai_embedding", |
| deployment_name="model-deployment-name", |
| ) |
| ), |
| ) |
| ) |
| assert config is not None |
|
|
| @mock.patch.dict( |
| os.environ, |
| { |
| "GRAPHRAG_API_KEY": "test", |
| "GRAPHRAG_LLM_TYPE": "azure_openai_chat", |
| "GRAPHRAG_LLM_DEPLOYMENT_NAME": "x", |
| }, |
| clear=True, |
| ) |
| def test_throws_if_azure_is_used_without_api_base(self): |
| with pytest.raises(AzureApiBaseMissingError): |
| create_graphrag_config() |
|
|
| @mock.patch.dict( |
| os.environ, |
| { |
| "GRAPHRAG_API_KEY": "test", |
| "GRAPHRAG_LLM_TYPE": "azure_openai_chat", |
| "GRAPHRAG_LLM_API_BASE": "http://some/base", |
| }, |
| clear=True, |
| ) |
| def test_throws_if_azure_is_used_without_llm_deployment_name(self): |
| with pytest.raises(AzureDeploymentNameMissingError): |
| create_graphrag_config() |
|
|
| @mock.patch.dict( |
| os.environ, |
| { |
| "GRAPHRAG_API_KEY": "test", |
| "GRAPHRAG_LLM_TYPE": "azure_openai_chat", |
| "GRAPHRAG_API_BASE": "http://some/base", |
| "GRAPHRAG_LLM_DEPLOYMENT_NAME": "model-deployment-name-x", |
| "GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding", |
| }, |
| clear=True, |
| ) |
| def test_throws_if_azure_is_used_without_embedding_deployment_name(self): |
| with pytest.raises(AzureDeploymentNameMissingError): |
| create_graphrag_config() |
|
|
| @mock.patch.dict( |
| os.environ, |
| {"GRAPHRAG_API_KEY": "test", "GRAPHRAG_INPUT_FILE_TYPE": "csv"}, |
| clear=True, |
| ) |
| def test_csv_input_returns_correct_config(self): |
| config = create_pipeline_config(create_graphrag_config(root_dir="/some/root")) |
| assert config.root_dir == "/some/root" |
| |
| assert isinstance(config.input, PipelineCSVInputConfig) |
| assert (config.input.file_pattern or "") == ".*\\.csv$" |
|
|
| @mock.patch.dict( |
| os.environ, |
| {"GRAPHRAG_API_KEY": "test", "GRAPHRAG_INPUT_FILE_TYPE": "text"}, |
| clear=True, |
| ) |
| def test_text_input_returns_correct_config(self): |
| config = create_pipeline_config(create_graphrag_config(root_dir=".")) |
| assert isinstance(config.input, PipelineTextInputConfig) |
| assert config.input is not None |
| assert (config.input.file_pattern or "") == ".*\\.txt$" |
|
|
| def test_all_env_vars_is_accurate(self): |
| env_var_docs_path = Path("docsite/posts/config/env_vars.md") |
| query_docs_path = Path("docsite/posts/query/3-cli.md") |
|
|
| env_var_docs = env_var_docs_path.read_text(encoding="utf-8") |
| query_docs = query_docs_path.read_text(encoding="utf-8") |
|
|
| def find_envvar_names(text) -> set[str]: |
| pattern = r"`(GRAPHRAG_[^`]+)`" |
| found = re.findall(pattern, text) |
| found = {f for f in found if not f.endswith("_")} |
| return {*found} |
|
|
| graphrag_strings = find_envvar_names(env_var_docs) | find_envvar_names( |
| query_docs |
| ) |
|
|
| missing = {s for s in graphrag_strings if s not in ALL_ENV_VARS} - { |
| |
| "GRAPHRAG_LLM_API_KEY", |
| "GRAPHRAG_LLM_API_BASE", |
| "GRAPHRAG_LLM_API_VERSION", |
| "GRAPHRAG_LLM_API_ORGANIZATION", |
| "GRAPHRAG_LLM_API_PROXY", |
| "GRAPHRAG_EMBEDDING_API_KEY", |
| "GRAPHRAG_EMBEDDING_API_BASE", |
| "GRAPHRAG_EMBEDDING_API_VERSION", |
| "GRAPHRAG_EMBEDDING_API_ORGANIZATION", |
| "GRAPHRAG_EMBEDDING_API_PROXY", |
| } |
| if missing: |
| msg = f"{len(missing)} missing env vars: {missing}" |
| print(msg) |
| raise ValueError(msg) |
|
|
| @mock.patch.dict( |
| os.environ, |
| {"GRAPHRAG_API_KEY": "test"}, |
| clear=True, |
| ) |
| def test_malformed_input_dict_throws(self): |
| with pytest.raises(ValidationError): |
| create_graphrag_config(cast(Any, {"llm": 12})) |
|
|
    @mock.patch.dict(
        os.environ,
        ALL_ENV_VARS,
        clear=True,
    )
    def test_create_parameters_from_env_vars(self) -> None:
        """Every env var in ALL_ENV_VARS must surface on the parsed config.

        Expected values mirror the string literals in ALL_ENV_VARS, with the
        loader's coercion applied (ints, floats, bools, comma-split lists,
        enum members).
        """
        parameters = create_graphrag_config()
        assert parameters.async_mode == "asyncio"
        # Cache section
        assert parameters.cache.storage_account_blob_url == "cache_account_blob_url"
        assert parameters.cache.base_dir == "/some/cache/dir"
        assert parameters.cache.connection_string == "test_cs1"
        assert parameters.cache.container_name == "test_cn1"
        assert parameters.cache.type == CacheType.blob
        # Chunking section; "a,b" is comma-split into a list
        assert parameters.chunks.group_by_columns == ["a", "b"]
        assert parameters.chunks.overlap == 12
        assert parameters.chunks.size == 500
        # Claim extraction section
        assert parameters.claim_extraction.enabled
        assert parameters.claim_extraction.description == "test 123"
        assert parameters.claim_extraction.max_gleanings == 5000
        assert parameters.claim_extraction.prompt == "tests/unit/config/prompt-a.txt"
        assert parameters.cluster_graph.max_cluster_size == 123
        # Community reports section
        assert parameters.community_reports.max_length == 23456
        assert parameters.community_reports.prompt == "tests/unit/config/prompt-b.txt"
        # Graph embedding (node2vec); note "010101" parses to the int 10101
        assert parameters.embed_graph.enabled
        assert parameters.embed_graph.iterations == 878787
        assert parameters.embed_graph.num_walks == 5_000_000
        assert parameters.embed_graph.random_seed == 10101
        assert parameters.embed_graph.walk_length == 555111
        assert parameters.embed_graph.window_size == 12345
        # Text embedding section (including its nested LLM settings)
        assert parameters.embeddings.batch_max_tokens == 17
        assert parameters.embeddings.batch_size == 1_000_000
        assert parameters.embeddings.llm.concurrent_requests == 12
        assert parameters.embeddings.llm.deployment_name == "model-deployment-name"
        assert parameters.embeddings.llm.max_retries == 3
        assert parameters.embeddings.llm.max_retry_wait == 0.1123
        assert parameters.embeddings.llm.model == "text-embedding-2"
        assert parameters.embeddings.llm.requests_per_minute == 500
        assert parameters.embeddings.llm.sleep_on_rate_limit_recommendation is False
        assert parameters.embeddings.llm.tokens_per_minute == 7000
        assert parameters.embeddings.llm.type == "azure_openai_embedding"
        assert parameters.embeddings.parallelization.num_threads == 2345
        assert parameters.embeddings.parallelization.stagger == 0.456
        assert parameters.embeddings.skip == ["a1", "b1", "c1"]
        assert parameters.embeddings.target == "all"
        assert parameters.encoding_model == "test123"
        # Entity extraction; its LLM inherits the root GRAPHRAG_API_BASE
        assert parameters.entity_extraction.entity_types == ["cat", "dog", "elephant"]
        assert parameters.entity_extraction.llm.api_base == "http://some/base"
        assert parameters.entity_extraction.max_gleanings == 112
        assert parameters.entity_extraction.prompt == "tests/unit/config/prompt-c.txt"
        # Input section
        assert parameters.input.storage_account_blob_url == "input_account_blob_url"
        assert parameters.input.base_dir == "/some/input/dir"
        assert parameters.input.connection_string == "input_cs"
        assert parameters.input.container_name == "input_cn"
        assert parameters.input.document_attribute_columns == ["test1", "test2"]
        assert parameters.input.encoding == "utf-16"
        assert parameters.input.file_pattern == ".*\\test\\.txt$"
        assert parameters.input.file_type == InputFileType.text
        assert parameters.input.source_column == "test_source"
        assert parameters.input.text_column == "test_text"
        assert parameters.input.timestamp_column == "test_timestamp"
        assert parameters.input.timestamp_format == "test_format"
        assert parameters.input.title_column == "test_title"
        assert parameters.input.type == InputType.blob
        # Root LLM section
        assert parameters.llm.api_base == "http://some/base"
        assert parameters.llm.api_key == "test"
        assert parameters.llm.api_version == "v1234"
        assert parameters.llm.concurrent_requests == 12
        assert parameters.llm.deployment_name == "model-deployment-name-x"
        assert parameters.llm.max_retries == 312
        assert parameters.llm.max_retry_wait == 0.1122
        assert parameters.llm.max_tokens == 15000
        assert parameters.llm.model == "test-llm"
        assert parameters.llm.model_supports_json
        assert parameters.llm.n == 1
        assert parameters.llm.organization == "test_org"
        assert parameters.llm.proxy == "http://some/proxy"
        assert parameters.llm.request_timeout == 12.7
        assert parameters.llm.requests_per_minute == 900
        assert parameters.llm.sleep_on_rate_limit_recommendation is False
        assert parameters.llm.temperature == 0.0
        assert parameters.llm.top_p == 1.0
        assert parameters.llm.tokens_per_minute == 8000
        assert parameters.llm.type == "azure_openai_chat"
        assert parameters.parallelization.num_threads == 987
        assert parameters.parallelization.stagger == 0.123
        # Reporting section
        assert (
            parameters.reporting.storage_account_blob_url
            == "reporting_account_blob_url"
        )
        assert parameters.reporting.base_dir == "/some/reporting/dir"
        assert parameters.reporting.connection_string == "test_cs2"
        assert parameters.reporting.container_name == "test_cn2"
        assert parameters.reporting.type == ReportingType.blob
        assert parameters.skip_workflows == ["a", "b", "c"]
        # Snapshots section
        assert parameters.snapshots.graphml
        assert parameters.snapshots.raw_entities
        assert parameters.snapshots.top_level_nodes
        # Storage section
        assert parameters.storage.storage_account_blob_url == "storage_account_blob_url"
        assert parameters.storage.base_dir == "/some/storage/dir"
        assert parameters.storage.connection_string == "test_cs"
        assert parameters.storage.container_name == "test_cn"
        assert parameters.storage.type == StorageType.blob
        # Summarize descriptions section
        assert parameters.summarize_descriptions.max_length == 12345
        assert (
            parameters.summarize_descriptions.prompt == "tests/unit/config/prompt-d.txt"
        )
        assert parameters.umap.enabled
        # Local search section
        assert parameters.local_search.text_unit_prop == 0.713
        assert parameters.local_search.community_prop == 0.1234
        assert parameters.local_search.llm_max_tokens == 12
        assert parameters.local_search.top_k_relationships == 15
        assert parameters.local_search.conversation_history_max_turns == 2
        assert parameters.local_search.top_k_entities == 14
        assert parameters.local_search.temperature == 0.1
        assert parameters.local_search.top_p == 0.9
        assert parameters.local_search.n == 2
        assert parameters.local_search.max_tokens == 142435

        # Global search section
        assert parameters.global_search.temperature == 0.1
        assert parameters.global_search.top_p == 0.9
        assert parameters.global_search.n == 2
        assert parameters.global_search.max_tokens == 5123
        assert parameters.global_search.data_max_tokens == 123
        assert parameters.global_search.map_max_tokens == 4123
        assert parameters.global_search.concurrency == 7
        assert parameters.global_search.reduce_max_tokens == 15432
|
|
    @mock.patch.dict(os.environ, {"API_KEY_X": "test"}, clear=True)
    def test_create_parameters(self) -> None:
        """A fully populated input object must surface verbatim on the config.

        Also exercises ${API_KEY_X} environment-variable interpolation in the
        api_key field, and passes "." as the root_dir positional argument.
        """
        parameters = create_graphrag_config(
            GraphRagConfigInput(
                # api_key uses ${...} interpolation against the patched environ
                llm=LLMParametersInput(api_key="${API_KEY_X}", model="test-llm"),
                storage=StorageConfigInput(
                    type=StorageType.blob,
                    connection_string="test_cs",
                    container_name="test_cn",
                    base_dir="/some/storage/dir",
                    storage_account_blob_url="storage_account_blob_url",
                ),
                cache=CacheConfigInput(
                    type=CacheType.blob,
                    connection_string="test_cs1",
                    container_name="test_cn1",
                    base_dir="/some/cache/dir",
                    storage_account_blob_url="cache_account_blob_url",
                ),
                reporting=ReportingConfigInput(
                    type=ReportingType.blob,
                    connection_string="test_cs2",
                    container_name="test_cn2",
                    base_dir="/some/reporting/dir",
                    storage_account_blob_url="reporting_account_blob_url",
                ),
                input=InputConfigInput(
                    file_type=InputFileType.text,
                    file_encoding="utf-16",
                    document_attribute_columns=["test1", "test2"],
                    base_dir="/some/input/dir",
                    connection_string="input_cs",
                    container_name="input_cn",
                    file_pattern=".*\\test\\.txt$",
                    source_column="test_source",
                    text_column="test_text",
                    timestamp_column="test_timestamp",
                    timestamp_format="test_format",
                    title_column="test_title",
                    type="blob",
                    storage_account_blob_url="input_account_blob_url",
                ),
                embed_graph=EmbedGraphConfigInput(
                    enabled=True,
                    num_walks=5_000_000,
                    iterations=878787,
                    random_seed=10101,
                    walk_length=555111,
                ),
                embeddings=TextEmbeddingConfigInput(
                    batch_size=1_000_000,
                    batch_max_tokens=8000,
                    skip=["a1", "b1", "c1"],
                    llm=LLMParametersInput(model="text-embedding-2"),
                ),
                chunks=ChunkingConfigInput(
                    size=500, overlap=12, group_by_columns=["a", "b"]
                ),
                snapshots=SnapshotsConfigInput(
                    graphml=True,
                    raw_entities=True,
                    top_level_nodes=True,
                ),
                entity_extraction=EntityExtractionConfigInput(
                    max_gleanings=112,
                    entity_types=["cat", "dog", "elephant"],
                    prompt="entity_extraction_prompt_file.txt",
                ),
                summarize_descriptions=SummarizeDescriptionsConfigInput(
                    max_length=12345, prompt="summarize_prompt_file.txt"
                ),
                community_reports=CommunityReportsConfigInput(
                    max_length=23456,
                    prompt="community_report_prompt_file.txt",
                    max_input_length=12345,
                ),
                claim_extraction=ClaimExtractionConfigInput(
                    description="test 123",
                    max_gleanings=5000,
                    prompt="claim_extraction_prompt_file.txt",
                ),
                cluster_graph=ClusterGraphConfigInput(
                    max_cluster_size=123,
                ),
                umap=UmapConfigInput(enabled=True),
                encoding_model="test123",
                skip_workflows=["a", "b", "c"],
            ),
            # root_dir
            ".",
        )

        # Every field below must come back exactly as supplied above.
        assert parameters.cache.base_dir == "/some/cache/dir"
        assert parameters.cache.connection_string == "test_cs1"
        assert parameters.cache.container_name == "test_cn1"
        assert parameters.cache.type == CacheType.blob
        assert parameters.cache.storage_account_blob_url == "cache_account_blob_url"
        assert parameters.chunks.group_by_columns == ["a", "b"]
        assert parameters.chunks.overlap == 12
        assert parameters.chunks.size == 500
        assert parameters.claim_extraction.description == "test 123"
        assert parameters.claim_extraction.max_gleanings == 5000
        assert parameters.claim_extraction.prompt == "claim_extraction_prompt_file.txt"
        assert parameters.cluster_graph.max_cluster_size == 123
        assert parameters.community_reports.max_input_length == 12345
        assert parameters.community_reports.max_length == 23456
        assert parameters.community_reports.prompt == "community_report_prompt_file.txt"
        assert parameters.embed_graph.enabled
        assert parameters.embed_graph.iterations == 878787
        assert parameters.embed_graph.num_walks == 5_000_000
        assert parameters.embed_graph.random_seed == 10101
        assert parameters.embed_graph.walk_length == 555111
        assert parameters.embeddings.batch_max_tokens == 8000
        assert parameters.embeddings.batch_size == 1_000_000
        assert parameters.embeddings.llm.model == "text-embedding-2"
        assert parameters.embeddings.skip == ["a1", "b1", "c1"]
        assert parameters.encoding_model == "test123"
        assert parameters.entity_extraction.entity_types == ["cat", "dog", "elephant"]
        assert parameters.entity_extraction.max_gleanings == 112
        assert (
            parameters.entity_extraction.prompt == "entity_extraction_prompt_file.txt"
        )
        assert parameters.input.base_dir == "/some/input/dir"
        assert parameters.input.connection_string == "input_cs"
        assert parameters.input.container_name == "input_cn"
        assert parameters.input.document_attribute_columns == ["test1", "test2"]
        assert parameters.input.encoding == "utf-16"
        assert parameters.input.file_pattern == ".*\\test\\.txt$"
        assert parameters.input.source_column == "test_source"
        assert parameters.input.type == "blob"
        assert parameters.input.text_column == "test_text"
        assert parameters.input.timestamp_column == "test_timestamp"
        assert parameters.input.timestamp_format == "test_format"
        assert parameters.input.title_column == "test_title"
        assert parameters.input.file_type == InputFileType.text
        assert parameters.input.storage_account_blob_url == "input_account_blob_url"
        # "${API_KEY_X}" interpolated from the patched environment
        assert parameters.llm.api_key == "test"
        assert parameters.llm.model == "test-llm"
        assert parameters.reporting.base_dir == "/some/reporting/dir"
        assert parameters.reporting.connection_string == "test_cs2"
        assert parameters.reporting.container_name == "test_cn2"
        assert parameters.reporting.type == ReportingType.blob
        assert (
            parameters.reporting.storage_account_blob_url
            == "reporting_account_blob_url"
        )
        assert parameters.skip_workflows == ["a", "b", "c"]
        assert parameters.snapshots.graphml
        assert parameters.snapshots.raw_entities
        assert parameters.snapshots.top_level_nodes
        assert parameters.storage.base_dir == "/some/storage/dir"
        assert parameters.storage.connection_string == "test_cs"
        assert parameters.storage.container_name == "test_cn"
        assert parameters.storage.type == StorageType.blob
        assert parameters.storage.storage_account_blob_url == "storage_account_blob_url"
        assert parameters.summarize_descriptions.max_length == 12345
        assert parameters.summarize_descriptions.prompt == "summarize_prompt_file.txt"
        assert parameters.umap.enabled
|
|
| @mock.patch.dict( |
| os.environ, |
| {"GRAPHRAG_API_KEY": "test"}, |
| clear=True, |
| ) |
| def test_default_values(self) -> None: |
| parameters = create_graphrag_config() |
| assert parameters.async_mode == defs.ASYNC_MODE |
| assert parameters.cache.base_dir == defs.CACHE_BASE_DIR |
| assert parameters.cache.type == defs.CACHE_TYPE |
| assert parameters.cache.base_dir == defs.CACHE_BASE_DIR |
| assert parameters.chunks.group_by_columns == defs.CHUNK_GROUP_BY_COLUMNS |
| assert parameters.chunks.overlap == defs.CHUNK_OVERLAP |
| assert parameters.chunks.size == defs.CHUNK_SIZE |
| assert parameters.claim_extraction.description == defs.CLAIM_DESCRIPTION |
| assert parameters.claim_extraction.max_gleanings == defs.CLAIM_MAX_GLEANINGS |
| assert ( |
| parameters.community_reports.max_input_length |
| == defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH |
| ) |
| assert ( |
| parameters.community_reports.max_length == defs.COMMUNITY_REPORT_MAX_LENGTH |
| ) |
| assert parameters.embeddings.batch_max_tokens == defs.EMBEDDING_BATCH_MAX_TOKENS |
| assert parameters.embeddings.batch_size == defs.EMBEDDING_BATCH_SIZE |
| assert parameters.embeddings.llm.model == defs.EMBEDDING_MODEL |
| assert parameters.embeddings.target == defs.EMBEDDING_TARGET |
| assert parameters.embeddings.llm.type == defs.EMBEDDING_TYPE |
| assert ( |
| parameters.embeddings.llm.requests_per_minute |
| == defs.LLM_REQUESTS_PER_MINUTE |
| ) |
| assert parameters.embeddings.llm.tokens_per_minute == defs.LLM_TOKENS_PER_MINUTE |
| assert ( |
| parameters.embeddings.llm.sleep_on_rate_limit_recommendation |
| == defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION |
| ) |
| assert ( |
| parameters.entity_extraction.entity_types |
| == defs.ENTITY_EXTRACTION_ENTITY_TYPES |
| ) |
| assert ( |
| parameters.entity_extraction.max_gleanings |
| == defs.ENTITY_EXTRACTION_MAX_GLEANINGS |
| ) |
| assert parameters.encoding_model == defs.ENCODING_MODEL |
| assert parameters.input.base_dir == defs.INPUT_BASE_DIR |
| assert parameters.input.file_pattern == defs.INPUT_CSV_PATTERN |
| assert parameters.input.encoding == defs.INPUT_FILE_ENCODING |
| assert parameters.input.type == defs.INPUT_TYPE |
| assert parameters.input.base_dir == defs.INPUT_BASE_DIR |
| assert parameters.input.text_column == defs.INPUT_TEXT_COLUMN |
| assert parameters.input.file_type == defs.INPUT_FILE_TYPE |
| assert parameters.llm.concurrent_requests == defs.LLM_CONCURRENT_REQUESTS |
| assert parameters.llm.max_retries == defs.LLM_MAX_RETRIES |
| assert parameters.llm.max_retry_wait == defs.LLM_MAX_RETRY_WAIT |
| assert parameters.llm.max_tokens == defs.LLM_MAX_TOKENS |
| assert parameters.llm.model == defs.LLM_MODEL |
| assert parameters.llm.request_timeout == defs.LLM_REQUEST_TIMEOUT |
| assert parameters.llm.requests_per_minute == defs.LLM_REQUESTS_PER_MINUTE |
| assert parameters.llm.tokens_per_minute == defs.LLM_TOKENS_PER_MINUTE |
| assert ( |
| parameters.llm.sleep_on_rate_limit_recommendation |
| == defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION |
| ) |
| assert parameters.llm.type == defs.LLM_TYPE |
| assert parameters.cluster_graph.max_cluster_size == defs.MAX_CLUSTER_SIZE |
| assert parameters.embed_graph.enabled == defs.NODE2VEC_ENABLED |
| assert parameters.embed_graph.iterations == defs.NODE2VEC_ITERATIONS |
| assert parameters.embed_graph.num_walks == defs.NODE2VEC_NUM_WALKS |
| assert parameters.embed_graph.random_seed == defs.NODE2VEC_RANDOM_SEED |
| assert parameters.embed_graph.walk_length == defs.NODE2VEC_WALK_LENGTH |
| assert parameters.embed_graph.window_size == defs.NODE2VEC_WINDOW_SIZE |
| assert ( |
| parameters.parallelization.num_threads == defs.PARALLELIZATION_NUM_THREADS |
| ) |
| assert parameters.parallelization.stagger == defs.PARALLELIZATION_STAGGER |
| assert parameters.reporting.type == defs.REPORTING_TYPE |
| assert parameters.reporting.base_dir == defs.REPORTING_BASE_DIR |
| assert parameters.snapshots.graphml == defs.SNAPSHOTS_GRAPHML |
| assert parameters.snapshots.raw_entities == defs.SNAPSHOTS_RAW_ENTITIES |
| assert parameters.snapshots.top_level_nodes == defs.SNAPSHOTS_TOP_LEVEL_NODES |
| assert parameters.storage.base_dir == defs.STORAGE_BASE_DIR |
| assert parameters.storage.type == defs.STORAGE_TYPE |
| assert parameters.umap.enabled == defs.UMAP_ENABLED |
|
|
| @mock.patch.dict( |
| os.environ, |
| {"GRAPHRAG_API_KEY": "test"}, |
| clear=True, |
| ) |
| def test_prompt_file_reading(self): |
| config = create_graphrag_config({ |
| "entity_extraction": {"prompt": "tests/unit/config/prompt-a.txt"}, |
| "claim_extraction": {"prompt": "tests/unit/config/prompt-b.txt"}, |
| "community_reports": {"prompt": "tests/unit/config/prompt-c.txt"}, |
| "summarize_descriptions": {"prompt": "tests/unit/config/prompt-d.txt"}, |
| }) |
| strategy = config.entity_extraction.resolved_strategy(".", "abc123") |
| assert strategy["extraction_prompt"] == "Hello, World! A" |
| assert strategy["encoding_name"] == "abc123" |
|
|
| strategy = config.claim_extraction.resolved_strategy(".") |
| assert strategy["extraction_prompt"] == "Hello, World! B" |
|
|
| strategy = config.community_reports.resolved_strategy(".") |
| assert strategy["extraction_prompt"] == "Hello, World! C" |
|
|
| strategy = config.summarize_descriptions.resolved_strategy(".") |
| assert strategy["summarize_prompt"] == "Hello, World! D" |
|
|
|
|
@mock.patch.dict(
    os.environ,
    {
        "PIPELINE_LLM_API_KEY": "test",
        "PIPELINE_LLM_API_BASE": "http://test",
        "PIPELINE_LLM_API_VERSION": "v1",
        "PIPELINE_LLM_MODEL": "test-llm",
        "PIPELINE_LLM_DEPLOYMENT_NAME": "test",
    },
    clear=True,
)
def test_yaml_load_e2e():
    """End-to-end: ${ENV_VAR} tokens in a YAML config resolve to env values and never leak into the serialized pipeline."""
    raw_config = yaml.safe_load(
        """
input:
  file_type: text

llm:
  type: azure_openai_chat
  api_key: ${PIPELINE_LLM_API_KEY}
  api_base: ${PIPELINE_LLM_API_BASE}
  api_version: ${PIPELINE_LLM_API_VERSION}
  model: ${PIPELINE_LLM_MODEL}
  deployment_name: ${PIPELINE_LLM_DEPLOYMENT_NAME}
  model_supports_json: True
  tokens_per_minute: 80000
  requests_per_minute: 900
  thread_count: 50
  concurrent_requests: 25
"""
    )
    parameters = create_graphrag_config(raw_config, ".")

    # Substitution happened: concrete env values, not the raw tokens.
    assert parameters.llm.api_key == "test"
    assert parameters.llm.model == "test-llm"
    assert parameters.llm.api_base == "http://test"
    assert parameters.llm.api_version == "v1"
    assert parameters.llm.deployment_name == "test"

    # Generate the pipeline config and make sure no unresolved ${...}
    # placeholder survives into its JSON serialization.
    pipeline_config = create_pipeline_config(parameters, True)
    config_str = pipeline_config.model_dump_json()
    for env_name in (
        "PIPELINE_LLM_API_KEY",
        "PIPELINE_LLM_API_BASE",
        "PIPELINE_LLM_API_VERSION",
        "PIPELINE_LLM_MODEL",
        "PIPELINE_LLM_DEPLOYMENT_NAME",
    ):
        assert "${" + env_name + "}" not in config_str
|
|