koichi12 commited on
Commit
b163abe
·
verified ·
1 Parent(s): cf58bb1

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. llm_tutorial/merged_models4/ja-en_en-ja_ties01/config.json +30 -0
  2. llm_tutorial/merged_models4/ja-en_en-ja_ties01/mergekit_config.yml +12 -0
  3. llm_tutorial/merged_models4/ja-en_en-ja_ties01/special_tokens_map.json +10 -0
  4. llm_tutorial/merged_models4/ja-en_en-ja_ties01/tokenizer_config.json +18 -0
  5. llm_tutorial/merged_models4/ja-en_en-ja_ties03/tokenizer.json +0 -0
  6. llm_tutorial/merged_models4/ja-en_en-ja_ties04-instruct/config.json +30 -0
  7. llm_tutorial/merged_models4/ja-en_en-ja_ties04-instruct/mergekit_config.yml +12 -0
  8. llm_tutorial/merged_models4/ja-en_en-ja_ties04-instruct/model.safetensors.index.json +1 -0
  9. llm_tutorial/merged_models4/ja-en_en-ja_ties04-instruct/tokenizer_config.json +84 -0
  10. llm_tutorial/merged_models4/ja-en_en-ja_ties04/config.json +30 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/markupsafe/py.typed +0 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/categories/tests/test_drawing.py +919 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__init__.py +24 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/__init__.cpython-311.pyc +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/abstract_nodes.cpython-311.pyc +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/algorithms.cpython-311.pyc +0 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/approximations.cpython-311.pyc +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/ast.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/cfunctions.cpython-311.pyc +0 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/cnodes.cpython-311.pyc +0 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/cutils.cpython-311.pyc +0 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/cxxnodes.cpython-311.pyc +0 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/fnodes.cpython-311.pyc +0 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/futils.cpython-311.pyc +0 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/rewriting.cpython-311.pyc +0 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/tests/__pycache__/test_algorithms.cpython-311.pyc +0 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/tests/__pycache__/test_ast.cpython-311.pyc +0 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/tests/__pycache__/test_rewriting.cpython-311.pyc +0 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/tests/test_algorithms.py +179 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/tests/test_cxxnodes.py +14 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/tests/test_matrix_nodes.py +50 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/tests/test_scipy_nodes.py +44 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/__pycache__/__init__.cpython-311.pyc +0 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/__pycache__/gosper.cpython-311.pyc +0 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/tests/__pycache__/__init__.cpython-311.pyc +0 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/tests/__pycache__/test_delta.cpython-311.pyc +0 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/tests/__pycache__/test_gosper.cpython-311.pyc +0 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/tests/test_delta.py +499 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/tests/test_products.py +410 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/tests/test_sums_products.py +1646 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/common.py +3263 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/decompositions.py +1621 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_blockmatrix.py +469 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_diagonal.py +156 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_funcmatrix.py +54 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_indexing.py +299 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_matadd.py +58 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_matexpr.py +592 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_matmul.py +186 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_permutation.py +166 -0
llm_tutorial/merged_models4/ja-en_en-ja_ties01/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/work01/llm_tutorial/pretrained_lm/llm-jp/llm-jp-v3-3.7b",
3
+ "architectures": [
4
+ "LlamaForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "bos_token_id": 1,
9
+ "eos_token_id": 2,
10
+ "head_dim": 128,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 3072,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 8192,
15
+ "max_position_embeddings": 4096,
16
+ "mlp_bias": false,
17
+ "model_type": "llama",
18
+ "num_attention_heads": 24,
19
+ "num_hidden_layers": 28,
20
+ "num_key_value_heads": 24,
21
+ "pretraining_tp": 1,
22
+ "rms_norm_eps": 1e-05,
23
+ "rope_scaling": null,
24
+ "rope_theta": 10000,
25
+ "tie_word_embeddings": false,
26
+ "torch_dtype": "float16",
27
+ "transformers_version": "4.46.2",
28
+ "use_cache": true,
29
+ "vocab_size": 99584
30
+ }
llm_tutorial/merged_models4/ja-en_en-ja_ties01/mergekit_config.yml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ models:
2
+ - model: /home/koiwa/work/llm_tutorial/llm_recipes/models/hf-model-eval/llm-jp-v3-3.7b_ja-en_3M-pairs_3.5e-5/iter_0000698
3
+ parameters:
4
+ density: 0.1
5
+ weight: 0.5
6
+ - model: /home/koiwa/work/llm_tutorial/llm_recipes/models/hf-model-eval/llm-jp-v3-3.7b_en-ja_3M-pairs_3.5e-5/iter_0000698
7
+ parameters:
8
+ density: 0.1
9
+ weight: 0.5
10
+ base_model: /work01/llm_tutorial/pretrained_lm/llm-jp/llm-jp-v3-3.7b
11
+ merge_method: ties
12
+ dtype: float16
llm_tutorial/merged_models4/ja-en_en-ja_ties01/special_tokens_map.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "cls_token": "<CLS|LLM-jp>",
4
+ "eod_token": "</s>",
5
+ "eos_token": "</s>",
6
+ "mask_token": "<MASK|LLM-jp>",
7
+ "pad_token": "<PAD|LLM-jp>",
8
+ "sep_token": "<SEP|LLM-jp>",
9
+ "unk_token": "<unk>"
10
+ }
llm_tutorial/merged_models4/ja-en_en-ja_ties01/tokenizer_config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "unk_token": "<unk>",
5
+ "bos_token": "<s>",
6
+ "eos_token": "</s>",
7
+ "pad_token": "<PAD|LLM-jp>",
8
+ "cls_token": "<CLS|LLM-jp>",
9
+ "sep_token": "<SEP|LLM-jp>",
10
+ "eod_token": "</s>",
11
+ "mask_token": "<MASK|LLM-jp>",
12
+ "extra_ids": 0,
13
+ "sp_model_kwargs": {},
14
+ "model_max_length": 1000000000000000019884624838656,
15
+ "clean_up_tokenization_spaces": false,
16
+ "special_tokens_map_file": null,
17
+ "tokenizer_class": "PreTrainedTokenizerFast"
18
+ }
llm_tutorial/merged_models4/ja-en_en-ja_ties03/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
llm_tutorial/merged_models4/ja-en_en-ja_ties04-instruct/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/work01/llm_tutorial/pretrained_lm/llm-jp/llm-jp-v3-3.7b-instruct",
3
+ "architectures": [
4
+ "LlamaForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "bos_token_id": 1,
9
+ "eos_token_id": 2,
10
+ "head_dim": 128,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 3072,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 8192,
15
+ "max_position_embeddings": 4096,
16
+ "mlp_bias": false,
17
+ "model_type": "llama",
18
+ "num_attention_heads": 24,
19
+ "num_hidden_layers": 28,
20
+ "num_key_value_heads": 24,
21
+ "pretraining_tp": 1,
22
+ "rms_norm_eps": 1e-05,
23
+ "rope_scaling": null,
24
+ "rope_theta": 10000,
25
+ "tie_word_embeddings": false,
26
+ "torch_dtype": "float16",
27
+ "transformers_version": "4.46.2",
28
+ "use_cache": true,
29
+ "vocab_size": 99584
30
+ }
llm_tutorial/merged_models4/ja-en_en-ja_ties04-instruct/mergekit_config.yml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ models:
2
+ - model: /home/koiwa/work/llm_tutorial/llm_recipes/models/hf-model-eval/llm-jp-v3-3.7b_ja-en-instruct_3M-pairs/iter_0000698
3
+ parameters:
4
+ weight: 0.5
5
+ density: 0.4
6
+ - model: /home/koiwa/work/llm_tutorial/llm_recipes/models/hf-model-eval/llm-jp-v3-3.7b_en-ja-instruct_3M-pairs/iter_0000698
7
+ parameters:
8
+ weight: 0.5
9
+ density: 0.4
10
+ base_model: /work01/llm_tutorial/pretrained_lm/llm-jp/llm-jp-v3-3.7b-instruct
11
+ merge_method: ties
12
+ dtype: float16
llm_tutorial/merged_models4/ja-en_en-ja_ties04-instruct/model.safetensors.index.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"metadata": {"mergekit_version": "0.0.5.1", "total_size": 7565826048}, "weight_map": {"lm_head.weight": "model-00001-of-00002.safetensors", "model.embed_tokens.weight": "model-00001-of-00002.safetensors", "model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.10.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.10.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.12.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.12.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.12.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.13.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.13.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.13.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.13.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.13.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.13.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.13.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.14.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.14.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.14.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.14.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.14.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.14.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.14.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.15.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.15.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.15.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.15.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.15.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.15.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.15.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.16.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.16.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.16.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.16.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.16.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.16.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.16.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.17.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.17.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.17.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.17.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.17.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.17.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.17.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.18.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.18.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.18.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.18.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.18.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.18.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.18.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.18.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.18.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.19.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.19.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.19.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.19.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.19.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.19.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.19.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.19.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.19.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.20.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.20.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.20.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.20.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.20.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.20.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.20.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.20.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.21.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.21.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.21.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.21.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.21.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.21.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.21.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.21.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.21.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.22.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.22.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.22.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.22.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.22.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.22.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.22.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.22.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.22.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.23.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.23.mlp.down_proj.weight": "model-00001-of-00002.safetensors", "model.layers.23.mlp.gate_proj.weight": "model-00001-of-00002.safetensors", "model.layers.23.mlp.up_proj.weight": "model-00001-of-00002.safetensors", "model.layers.23.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00002.safetensors", "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00002.safetensors", "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00002.safetensors", "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00002.safetensors", "model.layers.24.input_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.24.mlp.down_proj.weight": "model-00002-of-00002.safetensors", "model.layers.24.mlp.gate_proj.weight": "model-00002-of-00002.safetensors", "model.layers.24.mlp.up_proj.weight": "model-00002-of-00002.safetensors", "model.layers.24.post_attention_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.24.self_attn.k_proj.weight": "model-00002-of-00002.safetensors", "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00002.safetensors", "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00002.safetensors", "model.layers.24.self_attn.v_proj.weight": "model-00002-of-00002.safetensors", "model.layers.25.input_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.25.mlp.down_proj.weight": "model-00002-of-00002.safetensors", "model.layers.25.mlp.gate_proj.weight": "model-00002-of-00002.safetensors", "model.layers.25.mlp.up_proj.weight": "model-00002-of-00002.safetensors", "model.layers.25.post_attention_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.25.self_attn.k_proj.weight": "model-00002-of-00002.safetensors", "model.layers.25.self_attn.o_proj.weight": "model-00002-of-00002.safetensors", "model.layers.25.self_attn.q_proj.weight": "model-00002-of-00002.safetensors", "model.layers.25.self_attn.v_proj.weight": "model-00002-of-00002.safetensors", "model.layers.26.input_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.26.mlp.down_proj.weight": "model-00002-of-00002.safetensors", "model.layers.26.mlp.gate_proj.weight": "model-00002-of-00002.safetensors", "model.layers.26.mlp.up_proj.weight": "model-00002-of-00002.safetensors", "model.layers.26.post_attention_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.26.self_attn.k_proj.weight": "model-00002-of-00002.safetensors", "model.layers.26.self_attn.o_proj.weight": "model-00002-of-00002.safetensors", "model.layers.26.self_attn.q_proj.weight": "model-00002-of-00002.safetensors", "model.layers.26.self_attn.v_proj.weight": "model-00002-of-00002.safetensors", "model.layers.27.input_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.27.mlp.down_proj.weight": "model-00002-of-00002.safetensors", "model.layers.27.mlp.gate_proj.weight": "model-00002-of-00002.safetensors", "model.layers.27.mlp.up_proj.weight": "model-00002-of-00002.safetensors", "model.layers.27.post_attention_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.27.self_attn.k_proj.weight": "model-00002-of-00002.safetensors", "model.layers.27.self_attn.o_proj.weight": "model-00002-of-00002.safetensors", "model.layers.27.self_attn.q_proj.weight": "model-00002-of-00002.safetensors", "model.layers.27.self_attn.v_proj.weight": "model-00002-of-00002.safetensors", "model.layers.3.input_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.3.mlp.down_proj.weight": "model-00002-of-00002.safetensors", "model.layers.3.mlp.gate_proj.weight": "model-00002-of-00002.safetensors", "model.layers.3.mlp.up_proj.weight": "model-00002-of-00002.safetensors", "model.layers.3.post_attention_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.3.self_attn.k_proj.weight": "model-00002-of-00002.safetensors", "model.layers.3.self_attn.o_proj.weight": "model-00002-of-00002.safetensors", "model.layers.3.self_attn.q_proj.weight": "model-00002-of-00002.safetensors", "model.layers.3.self_attn.v_proj.weight": "model-00002-of-00002.safetensors", "model.layers.4.input_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.4.mlp.down_proj.weight": "model-00002-of-00002.safetensors", "model.layers.4.mlp.gate_proj.weight": "model-00002-of-00002.safetensors", "model.layers.4.mlp.up_proj.weight": "model-00002-of-00002.safetensors", "model.layers.4.post_attention_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.4.self_attn.k_proj.weight": "model-00002-of-00002.safetensors", "model.layers.4.self_attn.o_proj.weight": "model-00002-of-00002.safetensors", "model.layers.4.self_attn.q_proj.weight": "model-00002-of-00002.safetensors", "model.layers.4.self_attn.v_proj.weight": "model-00002-of-00002.safetensors", "model.layers.5.input_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.5.mlp.down_proj.weight": "model-00002-of-00002.safetensors", "model.layers.5.mlp.gate_proj.weight": "model-00002-of-00002.safetensors", "model.layers.5.mlp.up_proj.weight": "model-00002-of-00002.safetensors", "model.layers.5.post_attention_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.5.self_attn.k_proj.weight": "model-00002-of-00002.safetensors", "model.layers.5.self_attn.o_proj.weight": "model-00002-of-00002.safetensors", "model.layers.5.self_attn.q_proj.weight": "model-00002-of-00002.safetensors", "model.layers.5.self_attn.v_proj.weight": "model-00002-of-00002.safetensors", "model.layers.6.input_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.6.mlp.down_proj.weight": "model-00002-of-00002.safetensors", "model.layers.6.mlp.gate_proj.weight": "model-00002-of-00002.safetensors", "model.layers.6.mlp.up_proj.weight": "model-00002-of-00002.safetensors", "model.layers.6.post_attention_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.6.self_attn.k_proj.weight": "model-00002-of-00002.safetensors", "model.layers.6.self_attn.o_proj.weight": "model-00002-of-00002.safetensors", "model.layers.6.self_attn.q_proj.weight": "model-00002-of-00002.safetensors", "model.layers.6.self_attn.v_proj.weight": "model-00002-of-00002.safetensors", "model.layers.7.input_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.7.mlp.down_proj.weight": "model-00002-of-00002.safetensors", "model.layers.7.mlp.gate_proj.weight": "model-00002-of-00002.safetensors", "model.layers.7.mlp.up_proj.weight": "model-00002-of-00002.safetensors", "model.layers.7.post_attention_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.7.self_attn.k_proj.weight": "model-00002-of-00002.safetensors", "model.layers.7.self_attn.o_proj.weight": "model-00002-of-00002.safetensors", "model.layers.7.self_attn.q_proj.weight": "model-00002-of-00002.safetensors", "model.layers.7.self_attn.v_proj.weight": "model-00002-of-00002.safetensors", "model.layers.8.input_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.8.mlp.down_proj.weight": "model-00002-of-00002.safetensors", "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00002.safetensors", "model.layers.8.mlp.up_proj.weight": "model-00002-of-00002.safetensors", "model.layers.8.post_attention_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.8.self_attn.k_proj.weight": "model-00002-of-00002.safetensors", "model.layers.8.self_attn.o_proj.weight": "model-00002-of-00002.safetensors", "model.layers.8.self_attn.q_proj.weight": "model-00002-of-00002.safetensors", "model.layers.8.self_attn.v_proj.weight": "model-00002-of-00002.safetensors", "model.layers.9.input_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.9.mlp.down_proj.weight": "model-00002-of-00002.safetensors", "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00002.safetensors", "model.layers.9.mlp.up_proj.weight": "model-00002-of-00002.safetensors", "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00002.safetensors", "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00002.safetensors", "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00002.safetensors", "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00002.safetensors", "model.norm.weight": "model-00002-of-00002.safetensors"}}
llm_tutorial/merged_models4/ja-en_en-ja_ties04-instruct/tokenizer_config.json ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "3": {
30
+ "content": "<MASK|LLM-jp>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "4": {
38
+ "content": "<PAD|LLM-jp>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "5": {
46
+ "content": "<CLS|LLM-jp>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "6": {
54
+ "content": "<SEP|LLM-jp>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "7": {
62
+ "content": "<EOD|LLM-jp>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ }
69
+ },
70
+ "bos_token": "<s>",
71
+ "chat_template": "{{bos_token}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '\\n\\n### 指示:\\n' + message['content'] }}{% elif message['role'] == 'system' %}{{ '以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。' }}{% elif message['role'] == 'assistant' %}{{ '\\n\\n### 応答:\\n' + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '\\n\\n### 応答:\\n' }}{% endif %}{% endfor %}",
72
+ "clean_up_tokenization_spaces": false,
73
+ "cls_token": "<CLS|LLM-jp>",
74
+ "eod_token": "</s>",
75
+ "eos_token": "</s>",
76
+ "extra_ids": 0,
77
+ "mask_token": "<MASK|LLM-jp>",
78
+ "model_max_length": 1000000000000000019884624838656,
79
+ "pad_token": "<PAD|LLM-jp>",
80
+ "sep_token": "<SEP|LLM-jp>",
81
+ "sp_model_kwargs": {},
82
+ "tokenizer_class": "PreTrainedTokenizerFast",
83
+ "unk_token": "<unk>"
84
+ }
llm_tutorial/merged_models4/ja-en_en-ja_ties04/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/work01/llm_tutorial/pretrained_lm/llm-jp/llm-jp-v3-3.7b",
3
+ "architectures": [
4
+ "LlamaForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "bos_token_id": 1,
9
+ "eos_token_id": 2,
10
+ "head_dim": 128,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 3072,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 8192,
15
+ "max_position_embeddings": 4096,
16
+ "mlp_bias": false,
17
+ "model_type": "llama",
18
+ "num_attention_heads": 24,
19
+ "num_hidden_layers": 28,
20
+ "num_key_value_heads": 24,
21
+ "pretraining_tp": 1,
22
+ "rms_norm_eps": 1e-05,
23
+ "rope_scaling": null,
24
+ "rope_theta": 10000,
25
+ "tie_word_embeddings": false,
26
+ "torch_dtype": "float16",
27
+ "transformers_version": "4.46.2",
28
+ "use_cache": true,
29
+ "vocab_size": 99584
30
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/markupsafe/py.typed ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/categories/tests/test_drawing.py ADDED
@@ -0,0 +1,919 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.categories.diagram_drawing import _GrowableGrid, ArrowStringDescription
2
+ from sympy.categories import (DiagramGrid, Object, NamedMorphism,
3
+ Diagram, XypicDiagramDrawer, xypic_draw_diagram)
4
+ from sympy.sets.sets import FiniteSet
5
+
6
+
7
+ def test_GrowableGrid():
8
+ grid = _GrowableGrid(1, 2)
9
+
10
+ # Check dimensions.
11
+ assert grid.width == 1
12
+ assert grid.height == 2
13
+
14
+ # Check initialization of elements.
15
+ assert grid[0, 0] is None
16
+ assert grid[1, 0] is None
17
+
18
+ # Check assignment to elements.
19
+ grid[0, 0] = 1
20
+ grid[1, 0] = "two"
21
+
22
+ assert grid[0, 0] == 1
23
+ assert grid[1, 0] == "two"
24
+
25
+ # Check appending a row.
26
+ grid.append_row()
27
+
28
+ assert grid.width == 1
29
+ assert grid.height == 3
30
+
31
+ assert grid[0, 0] == 1
32
+ assert grid[1, 0] == "two"
33
+ assert grid[2, 0] is None
34
+
35
+ # Check appending a column.
36
+ grid.append_column()
37
+ assert grid.width == 2
38
+ assert grid.height == 3
39
+
40
+ assert grid[0, 0] == 1
41
+ assert grid[1, 0] == "two"
42
+ assert grid[2, 0] is None
43
+
44
+ assert grid[0, 1] is None
45
+ assert grid[1, 1] is None
46
+ assert grid[2, 1] is None
47
+
48
+ grid = _GrowableGrid(1, 2)
49
+ grid[0, 0] = 1
50
+ grid[1, 0] = "two"
51
+
52
+ # Check prepending a row.
53
+ grid.prepend_row()
54
+ assert grid.width == 1
55
+ assert grid.height == 3
56
+
57
+ assert grid[0, 0] is None
58
+ assert grid[1, 0] == 1
59
+ assert grid[2, 0] == "two"
60
+
61
+ # Check prepending a column.
62
+ grid.prepend_column()
63
+ assert grid.width == 2
64
+ assert grid.height == 3
65
+
66
+ assert grid[0, 0] is None
67
+ assert grid[1, 0] is None
68
+ assert grid[2, 0] is None
69
+
70
+ assert grid[0, 1] is None
71
+ assert grid[1, 1] == 1
72
+ assert grid[2, 1] == "two"
73
+
74
+
75
+ def test_DiagramGrid():
76
+ # Set up some objects and morphisms.
77
+ A = Object("A")
78
+ B = Object("B")
79
+ C = Object("C")
80
+ D = Object("D")
81
+ E = Object("E")
82
+
83
+ f = NamedMorphism(A, B, "f")
84
+ g = NamedMorphism(B, C, "g")
85
+ h = NamedMorphism(D, A, "h")
86
+ k = NamedMorphism(D, B, "k")
87
+
88
+ # A one-morphism diagram.
89
+ d = Diagram([f])
90
+ grid = DiagramGrid(d)
91
+
92
+ assert grid.width == 2
93
+ assert grid.height == 1
94
+ assert grid[0, 0] == A
95
+ assert grid[0, 1] == B
96
+ assert grid.morphisms == {f: FiniteSet()}
97
+
98
+ # A triangle.
99
+ d = Diagram([f, g], {g * f: "unique"})
100
+ grid = DiagramGrid(d)
101
+
102
+ assert grid.width == 2
103
+ assert grid.height == 2
104
+ assert grid[0, 0] == A
105
+ assert grid[0, 1] == B
106
+ assert grid[1, 0] == C
107
+ assert grid[1, 1] is None
108
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(),
109
+ g * f: FiniteSet("unique")}
110
+
111
+ # A triangle with a "loop" morphism.
112
+ l_A = NamedMorphism(A, A, "l_A")
113
+ d = Diagram([f, g, l_A])
114
+ grid = DiagramGrid(d)
115
+
116
+ assert grid.width == 2
117
+ assert grid.height == 2
118
+ assert grid[0, 0] == A
119
+ assert grid[0, 1] == B
120
+ assert grid[1, 0] is None
121
+ assert grid[1, 1] == C
122
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), l_A: FiniteSet()}
123
+
124
+ # A simple diagram.
125
+ d = Diagram([f, g, h, k])
126
+ grid = DiagramGrid(d)
127
+
128
+ assert grid.width == 3
129
+ assert grid.height == 2
130
+ assert grid[0, 0] == A
131
+ assert grid[0, 1] == B
132
+ assert grid[0, 2] == D
133
+ assert grid[1, 0] is None
134
+ assert grid[1, 1] == C
135
+ assert grid[1, 2] is None
136
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(),
137
+ k: FiniteSet()}
138
+
139
+ assert str(grid) == '[[Object("A"), Object("B"), Object("D")], ' \
140
+ '[None, Object("C"), None]]'
141
+
142
+ # A chain of morphisms.
143
+ f = NamedMorphism(A, B, "f")
144
+ g = NamedMorphism(B, C, "g")
145
+ h = NamedMorphism(C, D, "h")
146
+ k = NamedMorphism(D, E, "k")
147
+ d = Diagram([f, g, h, k])
148
+ grid = DiagramGrid(d)
149
+
150
+ assert grid.width == 3
151
+ assert grid.height == 3
152
+ assert grid[0, 0] == A
153
+ assert grid[0, 1] == B
154
+ assert grid[0, 2] is None
155
+ assert grid[1, 0] is None
156
+ assert grid[1, 1] == C
157
+ assert grid[1, 2] == D
158
+ assert grid[2, 0] is None
159
+ assert grid[2, 1] is None
160
+ assert grid[2, 2] == E
161
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(),
162
+ k: FiniteSet()}
163
+
164
+ # A square.
165
+ f = NamedMorphism(A, B, "f")
166
+ g = NamedMorphism(B, D, "g")
167
+ h = NamedMorphism(A, C, "h")
168
+ k = NamedMorphism(C, D, "k")
169
+ d = Diagram([f, g, h, k])
170
+ grid = DiagramGrid(d)
171
+
172
+ assert grid.width == 2
173
+ assert grid.height == 2
174
+ assert grid[0, 0] == A
175
+ assert grid[0, 1] == B
176
+ assert grid[1, 0] == C
177
+ assert grid[1, 1] == D
178
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(),
179
+ k: FiniteSet()}
180
+
181
+ # A strange diagram which resulted from a typo when creating a
182
+ # test for five lemma, but which allowed to stop one extra problem
183
+ # in the algorithm.
184
+ A = Object("A")
185
+ B = Object("B")
186
+ C = Object("C")
187
+ D = Object("D")
188
+ E = Object("E")
189
+ A_ = Object("A'")
190
+ B_ = Object("B'")
191
+ C_ = Object("C'")
192
+ D_ = Object("D'")
193
+ E_ = Object("E'")
194
+
195
+ f = NamedMorphism(A, B, "f")
196
+ g = NamedMorphism(B, C, "g")
197
+ h = NamedMorphism(C, D, "h")
198
+ i = NamedMorphism(D, E, "i")
199
+
200
+ # These 4 morphisms should be between primed objects.
201
+ j = NamedMorphism(A, B, "j")
202
+ k = NamedMorphism(B, C, "k")
203
+ l = NamedMorphism(C, D, "l")
204
+ m = NamedMorphism(D, E, "m")
205
+
206
+ o = NamedMorphism(A, A_, "o")
207
+ p = NamedMorphism(B, B_, "p")
208
+ q = NamedMorphism(C, C_, "q")
209
+ r = NamedMorphism(D, D_, "r")
210
+ s = NamedMorphism(E, E_, "s")
211
+
212
+ d = Diagram([f, g, h, i, j, k, l, m, o, p, q, r, s])
213
+ grid = DiagramGrid(d)
214
+
215
+ assert grid.width == 3
216
+ assert grid.height == 4
217
+ assert grid[0, 0] is None
218
+ assert grid[0, 1] == A
219
+ assert grid[0, 2] == A_
220
+ assert grid[1, 0] == C
221
+ assert grid[1, 1] == B
222
+ assert grid[1, 2] == B_
223
+ assert grid[2, 0] == C_
224
+ assert grid[2, 1] == D
225
+ assert grid[2, 2] == D_
226
+ assert grid[3, 0] is None
227
+ assert grid[3, 1] == E
228
+ assert grid[3, 2] == E_
229
+
230
+ morphisms = {}
231
+ for m in [f, g, h, i, j, k, l, m, o, p, q, r, s]:
232
+ morphisms[m] = FiniteSet()
233
+ assert grid.morphisms == morphisms
234
+
235
+ # A cube.
236
+ A1 = Object("A1")
237
+ A2 = Object("A2")
238
+ A3 = Object("A3")
239
+ A4 = Object("A4")
240
+ A5 = Object("A5")
241
+ A6 = Object("A6")
242
+ A7 = Object("A7")
243
+ A8 = Object("A8")
244
+
245
+ # The top face of the cube.
246
+ f1 = NamedMorphism(A1, A2, "f1")
247
+ f2 = NamedMorphism(A1, A3, "f2")
248
+ f3 = NamedMorphism(A2, A4, "f3")
249
+ f4 = NamedMorphism(A3, A4, "f3")
250
+
251
+ # The bottom face of the cube.
252
+ f5 = NamedMorphism(A5, A6, "f5")
253
+ f6 = NamedMorphism(A5, A7, "f6")
254
+ f7 = NamedMorphism(A6, A8, "f7")
255
+ f8 = NamedMorphism(A7, A8, "f8")
256
+
257
+ # The remaining morphisms.
258
+ f9 = NamedMorphism(A1, A5, "f9")
259
+ f10 = NamedMorphism(A2, A6, "f10")
260
+ f11 = NamedMorphism(A3, A7, "f11")
261
+ f12 = NamedMorphism(A4, A8, "f11")
262
+
263
+ d = Diagram([f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12])
264
+ grid = DiagramGrid(d)
265
+
266
+ assert grid.width == 4
267
+ assert grid.height == 3
268
+ assert grid[0, 0] is None
269
+ assert grid[0, 1] == A5
270
+ assert grid[0, 2] == A6
271
+ assert grid[0, 3] is None
272
+ assert grid[1, 0] is None
273
+ assert grid[1, 1] == A1
274
+ assert grid[1, 2] == A2
275
+ assert grid[1, 3] is None
276
+ assert grid[2, 0] == A7
277
+ assert grid[2, 1] == A3
278
+ assert grid[2, 2] == A4
279
+ assert grid[2, 3] == A8
280
+
281
+ morphisms = {}
282
+ for m in [f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12]:
283
+ morphisms[m] = FiniteSet()
284
+ assert grid.morphisms == morphisms
285
+
286
+ # A line diagram.
287
+ A = Object("A")
288
+ B = Object("B")
289
+ C = Object("C")
290
+ D = Object("D")
291
+ E = Object("E")
292
+
293
+ f = NamedMorphism(A, B, "f")
294
+ g = NamedMorphism(B, C, "g")
295
+ h = NamedMorphism(C, D, "h")
296
+ i = NamedMorphism(D, E, "i")
297
+ d = Diagram([f, g, h, i])
298
+ grid = DiagramGrid(d, layout="sequential")
299
+
300
+ assert grid.width == 5
301
+ assert grid.height == 1
302
+ assert grid[0, 0] == A
303
+ assert grid[0, 1] == B
304
+ assert grid[0, 2] == C
305
+ assert grid[0, 3] == D
306
+ assert grid[0, 4] == E
307
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(),
308
+ i: FiniteSet()}
309
+
310
+ # Test the transposed version.
311
+ grid = DiagramGrid(d, layout="sequential", transpose=True)
312
+
313
+ assert grid.width == 1
314
+ assert grid.height == 5
315
+ assert grid[0, 0] == A
316
+ assert grid[1, 0] == B
317
+ assert grid[2, 0] == C
318
+ assert grid[3, 0] == D
319
+ assert grid[4, 0] == E
320
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(),
321
+ i: FiniteSet()}
322
+
323
+ # A pullback.
324
+ m1 = NamedMorphism(A, B, "m1")
325
+ m2 = NamedMorphism(A, C, "m2")
326
+ s1 = NamedMorphism(B, D, "s1")
327
+ s2 = NamedMorphism(C, D, "s2")
328
+ f1 = NamedMorphism(E, B, "f1")
329
+ f2 = NamedMorphism(E, C, "f2")
330
+ g = NamedMorphism(E, A, "g")
331
+
332
+ d = Diagram([m1, m2, s1, s2, f1, f2], {g: "unique"})
333
+ grid = DiagramGrid(d)
334
+
335
+ assert grid.width == 3
336
+ assert grid.height == 2
337
+ assert grid[0, 0] == A
338
+ assert grid[0, 1] == B
339
+ assert grid[0, 2] == E
340
+ assert grid[1, 0] == C
341
+ assert grid[1, 1] == D
342
+ assert grid[1, 2] is None
343
+
344
+ morphisms = {g: FiniteSet("unique")}
345
+ for m in [m1, m2, s1, s2, f1, f2]:
346
+ morphisms[m] = FiniteSet()
347
+ assert grid.morphisms == morphisms
348
+
349
+ # Test the pullback with sequential layout, just for stress
350
+ # testing.
351
+ grid = DiagramGrid(d, layout="sequential")
352
+
353
+ assert grid.width == 5
354
+ assert grid.height == 1
355
+ assert grid[0, 0] == D
356
+ assert grid[0, 1] == B
357
+ assert grid[0, 2] == A
358
+ assert grid[0, 3] == C
359
+ assert grid[0, 4] == E
360
+ assert grid.morphisms == morphisms
361
+
362
+ # Test a pullback with object grouping.
363
+ grid = DiagramGrid(d, groups=FiniteSet(E, FiniteSet(A, B, C, D)))
364
+
365
+ assert grid.width == 3
366
+ assert grid.height == 2
367
+ assert grid[0, 0] == E
368
+ assert grid[0, 1] == A
369
+ assert grid[0, 2] == B
370
+ assert grid[1, 0] is None
371
+ assert grid[1, 1] == C
372
+ assert grid[1, 2] == D
373
+ assert grid.morphisms == morphisms
374
+
375
+ # Five lemma, actually.
376
+ A = Object("A")
377
+ B = Object("B")
378
+ C = Object("C")
379
+ D = Object("D")
380
+ E = Object("E")
381
+ A_ = Object("A'")
382
+ B_ = Object("B'")
383
+ C_ = Object("C'")
384
+ D_ = Object("D'")
385
+ E_ = Object("E'")
386
+
387
+ f = NamedMorphism(A, B, "f")
388
+ g = NamedMorphism(B, C, "g")
389
+ h = NamedMorphism(C, D, "h")
390
+ i = NamedMorphism(D, E, "i")
391
+
392
+ j = NamedMorphism(A_, B_, "j")
393
+ k = NamedMorphism(B_, C_, "k")
394
+ l = NamedMorphism(C_, D_, "l")
395
+ m = NamedMorphism(D_, E_, "m")
396
+
397
+ o = NamedMorphism(A, A_, "o")
398
+ p = NamedMorphism(B, B_, "p")
399
+ q = NamedMorphism(C, C_, "q")
400
+ r = NamedMorphism(D, D_, "r")
401
+ s = NamedMorphism(E, E_, "s")
402
+
403
+ d = Diagram([f, g, h, i, j, k, l, m, o, p, q, r, s])
404
+ grid = DiagramGrid(d)
405
+
406
+ assert grid.width == 5
407
+ assert grid.height == 3
408
+ assert grid[0, 0] is None
409
+ assert grid[0, 1] == A
410
+ assert grid[0, 2] == A_
411
+ assert grid[0, 3] is None
412
+ assert grid[0, 4] is None
413
+ assert grid[1, 0] == C
414
+ assert grid[1, 1] == B
415
+ assert grid[1, 2] == B_
416
+ assert grid[1, 3] == C_
417
+ assert grid[1, 4] is None
418
+ assert grid[2, 0] == D
419
+ assert grid[2, 1] == E
420
+ assert grid[2, 2] is None
421
+ assert grid[2, 3] == D_
422
+ assert grid[2, 4] == E_
423
+
424
+ morphisms = {}
425
+ for m in [f, g, h, i, j, k, l, m, o, p, q, r, s]:
426
+ morphisms[m] = FiniteSet()
427
+ assert grid.morphisms == morphisms
428
+
429
+ # Test the five lemma with object grouping.
430
+ grid = DiagramGrid(d, FiniteSet(
431
+ FiniteSet(A, B, C, D, E), FiniteSet(A_, B_, C_, D_, E_)))
432
+
433
+ assert grid.width == 6
434
+ assert grid.height == 3
435
+ assert grid[0, 0] == A
436
+ assert grid[0, 1] == B
437
+ assert grid[0, 2] is None
438
+ assert grid[0, 3] == A_
439
+ assert grid[0, 4] == B_
440
+ assert grid[0, 5] is None
441
+ assert grid[1, 0] is None
442
+ assert grid[1, 1] == C
443
+ assert grid[1, 2] == D
444
+ assert grid[1, 3] is None
445
+ assert grid[1, 4] == C_
446
+ assert grid[1, 5] == D_
447
+ assert grid[2, 0] is None
448
+ assert grid[2, 1] is None
449
+ assert grid[2, 2] == E
450
+ assert grid[2, 3] is None
451
+ assert grid[2, 4] is None
452
+ assert grid[2, 5] == E_
453
+ assert grid.morphisms == morphisms
454
+
455
+ # Test the five lemma with object grouping, but mixing containers
456
+ # to represent groups.
457
+ grid = DiagramGrid(d, [(A, B, C, D, E), {A_, B_, C_, D_, E_}])
458
+
459
+ assert grid.width == 6
460
+ assert grid.height == 3
461
+ assert grid[0, 0] == A
462
+ assert grid[0, 1] == B
463
+ assert grid[0, 2] is None
464
+ assert grid[0, 3] == A_
465
+ assert grid[0, 4] == B_
466
+ assert grid[0, 5] is None
467
+ assert grid[1, 0] is None
468
+ assert grid[1, 1] == C
469
+ assert grid[1, 2] == D
470
+ assert grid[1, 3] is None
471
+ assert grid[1, 4] == C_
472
+ assert grid[1, 5] == D_
473
+ assert grid[2, 0] is None
474
+ assert grid[2, 1] is None
475
+ assert grid[2, 2] == E
476
+ assert grid[2, 3] is None
477
+ assert grid[2, 4] is None
478
+ assert grid[2, 5] == E_
479
+ assert grid.morphisms == morphisms
480
+
481
+ # Test the five lemma with object grouping and hints.
482
+ grid = DiagramGrid(d, {
483
+ FiniteSet(A, B, C, D, E): {"layout": "sequential",
484
+ "transpose": True},
485
+ FiniteSet(A_, B_, C_, D_, E_): {"layout": "sequential",
486
+ "transpose": True}},
487
+ transpose=True)
488
+
489
+ assert grid.width == 5
490
+ assert grid.height == 2
491
+ assert grid[0, 0] == A
492
+ assert grid[0, 1] == B
493
+ assert grid[0, 2] == C
494
+ assert grid[0, 3] == D
495
+ assert grid[0, 4] == E
496
+ assert grid[1, 0] == A_
497
+ assert grid[1, 1] == B_
498
+ assert grid[1, 2] == C_
499
+ assert grid[1, 3] == D_
500
+ assert grid[1, 4] == E_
501
+ assert grid.morphisms == morphisms
502
+
503
+ # A two-triangle disconnected diagram.
504
+ f = NamedMorphism(A, B, "f")
505
+ g = NamedMorphism(B, C, "g")
506
+ f_ = NamedMorphism(A_, B_, "f")
507
+ g_ = NamedMorphism(B_, C_, "g")
508
+ d = Diagram([f, g, f_, g_], {g * f: "unique", g_ * f_: "unique"})
509
+ grid = DiagramGrid(d)
510
+
511
+ assert grid.width == 4
512
+ assert grid.height == 2
513
+ assert grid[0, 0] == A
514
+ assert grid[0, 1] == B
515
+ assert grid[0, 2] == A_
516
+ assert grid[0, 3] == B_
517
+ assert grid[1, 0] == C
518
+ assert grid[1, 1] is None
519
+ assert grid[1, 2] == C_
520
+ assert grid[1, 3] is None
521
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), f_: FiniteSet(),
522
+ g_: FiniteSet(), g * f: FiniteSet("unique"),
523
+ g_ * f_: FiniteSet("unique")}
524
+
525
+ # A two-morphism disconnected diagram.
526
+ f = NamedMorphism(A, B, "f")
527
+ g = NamedMorphism(C, D, "g")
528
+ d = Diagram([f, g])
529
+ grid = DiagramGrid(d)
530
+
531
+ assert grid.width == 4
532
+ assert grid.height == 1
533
+ assert grid[0, 0] == A
534
+ assert grid[0, 1] == B
535
+ assert grid[0, 2] == C
536
+ assert grid[0, 3] == D
537
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet()}
538
+
539
+ # Test a one-object diagram.
540
+ f = NamedMorphism(A, A, "f")
541
+ d = Diagram([f])
542
+ grid = DiagramGrid(d)
543
+
544
+ assert grid.width == 1
545
+ assert grid.height == 1
546
+ assert grid[0, 0] == A
547
+
548
+ # Test a two-object disconnected diagram.
549
+ g = NamedMorphism(B, B, "g")
550
+ d = Diagram([f, g])
551
+ grid = DiagramGrid(d)
552
+
553
+ assert grid.width == 2
554
+ assert grid.height == 1
555
+ assert grid[0, 0] == A
556
+ assert grid[0, 1] == B
557
+
558
+
559
+ def test_DiagramGrid_pseudopod():
560
+ # Test a diagram in which even growing a pseudopod does not
561
+ # eventually help.
562
+ A = Object("A")
563
+ B = Object("B")
564
+ C = Object("C")
565
+ D = Object("D")
566
+ E = Object("E")
567
+ F = Object("F")
568
+ A_ = Object("A'")
569
+ B_ = Object("B'")
570
+ C_ = Object("C'")
571
+ D_ = Object("D'")
572
+ E_ = Object("E'")
573
+
574
+ f1 = NamedMorphism(A, B, "f1")
575
+ f2 = NamedMorphism(A, C, "f2")
576
+ f3 = NamedMorphism(A, D, "f3")
577
+ f4 = NamedMorphism(A, E, "f4")
578
+ f5 = NamedMorphism(A, A_, "f5")
579
+ f6 = NamedMorphism(A, B_, "f6")
580
+ f7 = NamedMorphism(A, C_, "f7")
581
+ f8 = NamedMorphism(A, D_, "f8")
582
+ f9 = NamedMorphism(A, E_, "f9")
583
+ f10 = NamedMorphism(A, F, "f10")
584
+ d = Diagram([f1, f2, f3, f4, f5, f6, f7, f8, f9, f10])
585
+ grid = DiagramGrid(d)
586
+
587
+ assert grid.width == 5
588
+ assert grid.height == 3
589
+ assert grid[0, 0] == E
590
+ assert grid[0, 1] == C
591
+ assert grid[0, 2] == C_
592
+ assert grid[0, 3] == E_
593
+ assert grid[0, 4] == F
594
+ assert grid[1, 0] == D
595
+ assert grid[1, 1] == A
596
+ assert grid[1, 2] == A_
597
+ assert grid[1, 3] is None
598
+ assert grid[1, 4] is None
599
+ assert grid[2, 0] == D_
600
+ assert grid[2, 1] == B
601
+ assert grid[2, 2] == B_
602
+ assert grid[2, 3] is None
603
+ assert grid[2, 4] is None
604
+
605
+ morphisms = {}
606
+ for f in [f1, f2, f3, f4, f5, f6, f7, f8, f9, f10]:
607
+ morphisms[f] = FiniteSet()
608
+ assert grid.morphisms == morphisms
609
+
610
+
611
+ def test_ArrowStringDescription():
612
+ astr = ArrowStringDescription("cm", "", None, "", "", "d", "r", "_", "f")
613
+ assert str(astr) == "\\ar[dr]_{f}"
614
+
615
+ astr = ArrowStringDescription("cm", "", 12, "", "", "d", "r", "_", "f")
616
+ assert str(astr) == "\\ar[dr]_{f}"
617
+
618
+ astr = ArrowStringDescription("cm", "^", 12, "", "", "d", "r", "_", "f")
619
+ assert str(astr) == "\\ar@/^12cm/[dr]_{f}"
620
+
621
+ astr = ArrowStringDescription("cm", "", 12, "r", "", "d", "r", "_", "f")
622
+ assert str(astr) == "\\ar[dr]_{f}"
623
+
624
+ astr = ArrowStringDescription("cm", "", 12, "r", "u", "d", "r", "_", "f")
625
+ assert str(astr) == "\\ar@(r,u)[dr]_{f}"
626
+
627
+ astr = ArrowStringDescription("cm", "", 12, "r", "u", "d", "r", "_", "f")
628
+ assert str(astr) == "\\ar@(r,u)[dr]_{f}"
629
+
630
+ astr = ArrowStringDescription("cm", "", 12, "r", "u", "d", "r", "_", "f")
631
+ astr.arrow_style = "{-->}"
632
+ assert str(astr) == "\\ar@(r,u)@{-->}[dr]_{f}"
633
+
634
+ astr = ArrowStringDescription("cm", "_", 12, "", "", "d", "r", "_", "f")
635
+ astr.arrow_style = "{-->}"
636
+ assert str(astr) == "\\ar@/_12cm/@{-->}[dr]_{f}"
637
+
638
+
639
+ def test_XypicDiagramDrawer_line():
640
+ # A linear diagram.
641
+ A = Object("A")
642
+ B = Object("B")
643
+ C = Object("C")
644
+ D = Object("D")
645
+ E = Object("E")
646
+
647
+ f = NamedMorphism(A, B, "f")
648
+ g = NamedMorphism(B, C, "g")
649
+ h = NamedMorphism(C, D, "h")
650
+ i = NamedMorphism(D, E, "i")
651
+ d = Diagram([f, g, h, i])
652
+ grid = DiagramGrid(d, layout="sequential")
653
+ drawer = XypicDiagramDrawer()
654
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
655
+ "A \\ar[r]^{f} & B \\ar[r]^{g} & C \\ar[r]^{h} & D \\ar[r]^{i} & E \n" \
656
+ "}\n"
657
+
658
+ # The same diagram, transposed.
659
+ grid = DiagramGrid(d, layout="sequential", transpose=True)
660
+ drawer = XypicDiagramDrawer()
661
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
662
+ "A \\ar[d]^{f} \\\\\n" \
663
+ "B \\ar[d]^{g} \\\\\n" \
664
+ "C \\ar[d]^{h} \\\\\n" \
665
+ "D \\ar[d]^{i} \\\\\n" \
666
+ "E \n" \
667
+ "}\n"
668
+
669
+
670
+ def test_XypicDiagramDrawer_triangle():
671
+ # A triangle diagram.
672
+ A = Object("A")
673
+ B = Object("B")
674
+ C = Object("C")
675
+ f = NamedMorphism(A, B, "f")
676
+ g = NamedMorphism(B, C, "g")
677
+
678
+ d = Diagram([f, g], {g * f: "unique"})
679
+ grid = DiagramGrid(d)
680
+ drawer = XypicDiagramDrawer()
681
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
682
+ "A \\ar[d]_{g\\circ f} \\ar[r]^{f} & B \\ar[ld]^{g} \\\\\n" \
683
+ "C & \n" \
684
+ "}\n"
685
+
686
+ # The same diagram, transposed.
687
+ grid = DiagramGrid(d, transpose=True)
688
+ drawer = XypicDiagramDrawer()
689
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
690
+ "A \\ar[r]^{g\\circ f} \\ar[d]_{f} & C \\\\\n" \
691
+ "B \\ar[ru]_{g} & \n" \
692
+ "}\n"
693
+
694
+ # The same diagram, with a masked morphism.
695
+ assert drawer.draw(d, grid, masked=[g]) == "\\xymatrix{\n" \
696
+ "A \\ar[r]^{g\\circ f} \\ar[d]_{f} & C \\\\\n" \
697
+ "B & \n" \
698
+ "}\n"
699
+
700
+ # The same diagram with a formatter for "unique".
701
+ def formatter(astr):
702
+ astr.label = "\\exists !" + astr.label
703
+ astr.arrow_style = "{-->}"
704
+
705
+ drawer.arrow_formatters["unique"] = formatter
706
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
707
+ "A \\ar@{-->}[r]^{\\exists !g\\circ f} \\ar[d]_{f} & C \\\\\n" \
708
+ "B \\ar[ru]_{g} & \n" \
709
+ "}\n"
710
+
711
+ # The same diagram with a default formatter.
712
+ def default_formatter(astr):
713
+ astr.label_displacement = "(0.45)"
714
+
715
+ drawer.default_arrow_formatter = default_formatter
716
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
717
+ "A \\ar@{-->}[r]^(0.45){\\exists !g\\circ f} \\ar[d]_(0.45){f} & C \\\\\n" \
718
+ "B \\ar[ru]_(0.45){g} & \n" \
719
+ "}\n"
720
+
721
+ # A triangle diagram with a lot of morphisms between the same
722
+ # objects.
723
+ f1 = NamedMorphism(B, A, "f1")
724
+ f2 = NamedMorphism(A, B, "f2")
725
+ g1 = NamedMorphism(C, B, "g1")
726
+ g2 = NamedMorphism(B, C, "g2")
727
+ d = Diagram([f, f1, f2, g, g1, g2], {f1 * g1: "unique", g2 * f2: "unique"})
728
+
729
+ grid = DiagramGrid(d, transpose=True)
730
+ drawer = XypicDiagramDrawer()
731
+ assert drawer.draw(d, grid, masked=[f1*g1*g2*f2, g2*f2*f1*g1]) == \
732
+ "\\xymatrix{\n" \
733
+ "A \\ar[r]^{g_{2}\\circ f_{2}} \\ar[d]_{f} \\ar@/^3mm/[d]^{f_{2}} " \
734
+ "& C \\ar@/^3mm/[l]^{f_{1}\\circ g_{1}} \\ar@/^3mm/[ld]^{g_{1}} \\\\\n" \
735
+ "B \\ar@/^3mm/[u]^{f_{1}} \\ar[ru]_{g} \\ar@/^3mm/[ru]^{g_{2}} & \n" \
736
+ "}\n"
737
+
738
+
739
+ def test_XypicDiagramDrawer_cube():
740
+ # A cube diagram.
741
+ A1 = Object("A1")
742
+ A2 = Object("A2")
743
+ A3 = Object("A3")
744
+ A4 = Object("A4")
745
+ A5 = Object("A5")
746
+ A6 = Object("A6")
747
+ A7 = Object("A7")
748
+ A8 = Object("A8")
749
+
750
+ # The top face of the cube.
751
+ f1 = NamedMorphism(A1, A2, "f1")
752
+ f2 = NamedMorphism(A1, A3, "f2")
753
+ f3 = NamedMorphism(A2, A4, "f3")
754
+ f4 = NamedMorphism(A3, A4, "f3")
755
+
756
+ # The bottom face of the cube.
757
+ f5 = NamedMorphism(A5, A6, "f5")
758
+ f6 = NamedMorphism(A5, A7, "f6")
759
+ f7 = NamedMorphism(A6, A8, "f7")
760
+ f8 = NamedMorphism(A7, A8, "f8")
761
+
762
+ # The remaining morphisms.
763
+ f9 = NamedMorphism(A1, A5, "f9")
764
+ f10 = NamedMorphism(A2, A6, "f10")
765
+ f11 = NamedMorphism(A3, A7, "f11")
766
+ f12 = NamedMorphism(A4, A8, "f11")
767
+
768
+ d = Diagram([f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12])
769
+ grid = DiagramGrid(d)
770
+ drawer = XypicDiagramDrawer()
771
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
772
+ "& A_{5} \\ar[r]^{f_{5}} \\ar[ldd]_{f_{6}} & A_{6} \\ar[rdd]^{f_{7}} " \
773
+ "& \\\\\n" \
774
+ "& A_{1} \\ar[r]^{f_{1}} \\ar[d]^{f_{2}} \\ar[u]^{f_{9}} & A_{2} " \
775
+ "\\ar[d]^{f_{3}} \\ar[u]_{f_{10}} & \\\\\n" \
776
+ "A_{7} \\ar@/_3mm/[rrr]_{f_{8}} & A_{3} \\ar[r]^{f_{3}} \\ar[l]_{f_{11}} " \
777
+ "& A_{4} \\ar[r]^{f_{11}} & A_{8} \n" \
778
+ "}\n"
779
+
780
+ # The same diagram, transposed.
781
+ grid = DiagramGrid(d, transpose=True)
782
+ drawer = XypicDiagramDrawer()
783
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
784
+ "& & A_{7} \\ar@/^3mm/[ddd]^{f_{8}} \\\\\n" \
785
+ "A_{5} \\ar[d]_{f_{5}} \\ar[rru]^{f_{6}} & A_{1} \\ar[d]^{f_{1}} " \
786
+ "\\ar[r]^{f_{2}} \\ar[l]^{f_{9}} & A_{3} \\ar[d]_{f_{3}} " \
787
+ "\\ar[u]^{f_{11}} \\\\\n" \
788
+ "A_{6} \\ar[rrd]_{f_{7}} & A_{2} \\ar[r]^{f_{3}} \\ar[l]^{f_{10}} " \
789
+ "& A_{4} \\ar[d]_{f_{11}} \\\\\n" \
790
+ "& & A_{8} \n" \
791
+ "}\n"
792
+
793
+
794
+ def test_XypicDiagramDrawer_curved_and_loops():
795
+ # A simple diagram, with a curved arrow.
796
+ A = Object("A")
797
+ B = Object("B")
798
+ C = Object("C")
799
+ D = Object("D")
800
+
801
+ f = NamedMorphism(A, B, "f")
802
+ g = NamedMorphism(B, C, "g")
803
+ h = NamedMorphism(D, A, "h")
804
+ k = NamedMorphism(D, B, "k")
805
+ d = Diagram([f, g, h, k])
806
+ grid = DiagramGrid(d)
807
+ drawer = XypicDiagramDrawer()
808
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
809
+ "A \\ar[r]_{f} & B \\ar[d]^{g} & D \\ar[l]^{k} \\ar@/_3mm/[ll]_{h} \\\\\n" \
810
+ "& C & \n" \
811
+ "}\n"
812
+
813
+ # The same diagram, transposed.
814
+ grid = DiagramGrid(d, transpose=True)
815
+ drawer = XypicDiagramDrawer()
816
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
817
+ "A \\ar[d]^{f} & \\\\\n" \
818
+ "B \\ar[r]^{g} & C \\\\\n" \
819
+ "D \\ar[u]_{k} \\ar@/^3mm/[uu]^{h} & \n" \
820
+ "}\n"
821
+
822
+ # The same diagram, larger and rotated.
823
+ assert drawer.draw(d, grid, diagram_format="@+1cm@dr") == \
824
+ "\\xymatrix@+1cm@dr{\n" \
825
+ "A \\ar[d]^{f} & \\\\\n" \
826
+ "B \\ar[r]^{g} & C \\\\\n" \
827
+ "D \\ar[u]_{k} \\ar@/^3mm/[uu]^{h} & \n" \
828
+ "}\n"
829
+
830
+ # A simple diagram with three curved arrows.
831
+ h1 = NamedMorphism(D, A, "h1")
832
+ h2 = NamedMorphism(A, D, "h2")
833
+ k = NamedMorphism(D, B, "k")
834
+ d = Diagram([f, g, h, k, h1, h2])
835
+ grid = DiagramGrid(d)
836
+ drawer = XypicDiagramDrawer()
837
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
838
+ "A \\ar[r]_{f} \\ar@/^3mm/[rr]^{h_{2}} & B \\ar[d]^{g} & D \\ar[l]^{k} " \
839
+ "\\ar@/_7mm/[ll]_{h} \\ar@/_11mm/[ll]_{h_{1}} \\\\\n" \
840
+ "& C & \n" \
841
+ "}\n"
842
+
843
+ # The same diagram, transposed.
844
+ grid = DiagramGrid(d, transpose=True)
845
+ drawer = XypicDiagramDrawer()
846
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
847
+ "A \\ar[d]^{f} \\ar@/_3mm/[dd]_{h_{2}} & \\\\\n" \
848
+ "B \\ar[r]^{g} & C \\\\\n" \
849
+ "D \\ar[u]_{k} \\ar@/^7mm/[uu]^{h} \\ar@/^11mm/[uu]^{h_{1}} & \n" \
850
+ "}\n"
851
+
852
+ # The same diagram, with "loop" morphisms.
853
+ l_A = NamedMorphism(A, A, "l_A")
854
+ l_D = NamedMorphism(D, D, "l_D")
855
+ l_C = NamedMorphism(C, C, "l_C")
856
+ d = Diagram([f, g, h, k, h1, h2, l_A, l_D, l_C])
857
+ grid = DiagramGrid(d)
858
+ drawer = XypicDiagramDrawer()
859
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
860
+ "A \\ar[r]_{f} \\ar@/^3mm/[rr]^{h_{2}} \\ar@(u,l)[]^{l_{A}} " \
861
+ "& B \\ar[d]^{g} & D \\ar[l]^{k} \\ar@/_7mm/[ll]_{h} " \
862
+ "\\ar@/_11mm/[ll]_{h_{1}} \\ar@(r,u)[]^{l_{D}} \\\\\n" \
863
+ "& C \\ar@(l,d)[]^{l_{C}} & \n" \
864
+ "}\n"
865
+
866
+ # The same diagram with "loop" morphisms, transposed.
867
+ grid = DiagramGrid(d, transpose=True)
868
+ drawer = XypicDiagramDrawer()
869
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
870
+ "A \\ar[d]^{f} \\ar@/_3mm/[dd]_{h_{2}} \\ar@(r,u)[]^{l_{A}} & \\\\\n" \
871
+ "B \\ar[r]^{g} & C \\ar@(r,u)[]^{l_{C}} \\\\\n" \
872
+ "D \\ar[u]_{k} \\ar@/^7mm/[uu]^{h} \\ar@/^11mm/[uu]^{h_{1}} " \
873
+ "\\ar@(l,d)[]^{l_{D}} & \n" \
874
+ "}\n"
875
+
876
+ # The same diagram with two "loop" morphisms per object.
877
+ l_A_ = NamedMorphism(A, A, "n_A")
878
+ l_D_ = NamedMorphism(D, D, "n_D")
879
+ l_C_ = NamedMorphism(C, C, "n_C")
880
+ d = Diagram([f, g, h, k, h1, h2, l_A, l_D, l_C, l_A_, l_D_, l_C_])
881
+ grid = DiagramGrid(d)
882
+ drawer = XypicDiagramDrawer()
883
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
884
+ "A \\ar[r]_{f} \\ar@/^3mm/[rr]^{h_{2}} \\ar@(u,l)[]^{l_{A}} " \
885
+ "\\ar@/^3mm/@(l,d)[]^{n_{A}} & B \\ar[d]^{g} & D \\ar[l]^{k} " \
886
+ "\\ar@/_7mm/[ll]_{h} \\ar@/_11mm/[ll]_{h_{1}} \\ar@(r,u)[]^{l_{D}} " \
887
+ "\\ar@/^3mm/@(d,r)[]^{n_{D}} \\\\\n" \
888
+ "& C \\ar@(l,d)[]^{l_{C}} \\ar@/^3mm/@(d,r)[]^{n_{C}} & \n" \
889
+ "}\n"
890
+
891
+ # The same diagram with two "loop" morphisms per object, transposed.
892
+ grid = DiagramGrid(d, transpose=True)
893
+ drawer = XypicDiagramDrawer()
894
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
895
+ "A \\ar[d]^{f} \\ar@/_3mm/[dd]_{h_{2}} \\ar@(r,u)[]^{l_{A}} " \
896
+ "\\ar@/^3mm/@(u,l)[]^{n_{A}} & \\\\\n" \
897
+ "B \\ar[r]^{g} & C \\ar@(r,u)[]^{l_{C}} \\ar@/^3mm/@(d,r)[]^{n_{C}} \\\\\n" \
898
+ "D \\ar[u]_{k} \\ar@/^7mm/[uu]^{h} \\ar@/^11mm/[uu]^{h_{1}} " \
899
+ "\\ar@(l,d)[]^{l_{D}} \\ar@/^3mm/@(d,r)[]^{n_{D}} & \n" \
900
+ "}\n"
901
+
902
+
903
+ def test_xypic_draw_diagram():
904
+ # A linear diagram.
905
+ A = Object("A")
906
+ B = Object("B")
907
+ C = Object("C")
908
+ D = Object("D")
909
+ E = Object("E")
910
+
911
+ f = NamedMorphism(A, B, "f")
912
+ g = NamedMorphism(B, C, "g")
913
+ h = NamedMorphism(C, D, "h")
914
+ i = NamedMorphism(D, E, "i")
915
+ d = Diagram([f, g, h, i])
916
+
917
+ grid = DiagramGrid(d, layout="sequential")
918
+ drawer = XypicDiagramDrawer()
919
+ assert drawer.draw(d, grid) == xypic_draw_diagram(d, layout="sequential")
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ The ``sympy.codegen`` module contains classes and functions for building
2
+ abstract syntax trees of algorithms. These trees may then be printed by the
3
+ code-printers in ``sympy.printing``.
4
+
5
+ There are several submodules available:
6
+ - ``sympy.codegen.ast``: AST nodes useful across multiple languages.
7
+ - ``sympy.codegen.cnodes``: AST nodes useful for the C family of languages.
8
+ - ``sympy.codegen.fnodes``: AST nodes useful for Fortran.
9
+ - ``sympy.codegen.cfunctions``: functions specific to C (C99 math functions)
10
+ - ``sympy.codegen.ffunctions``: functions specific to Fortran (e.g. ``kind``).
11
+
12
+
13
+
14
+ """
15
+ from .ast import (
16
+ Assignment, aug_assign, CodeBlock, For, Attribute, Variable, Declaration,
17
+ While, Scope, Print, FunctionPrototype, FunctionDefinition, FunctionCall
18
+ )
19
+
20
+ __all__ = [
21
+ 'Assignment', 'aug_assign', 'CodeBlock', 'For', 'Attribute', 'Variable',
22
+ 'Declaration', 'While', 'Scope', 'Print', 'FunctionPrototype',
23
+ 'FunctionDefinition', 'FunctionCall',
24
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.29 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/abstract_nodes.cpython-311.pyc ADDED
Binary file (1.32 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/algorithms.cpython-311.pyc ADDED
Binary file (9.97 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/approximations.cpython-311.pyc ADDED
Binary file (9.71 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/ast.cpython-311.pyc ADDED
Binary file (82.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/cfunctions.cpython-311.pyc ADDED
Binary file (19.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/cnodes.cpython-311.pyc ADDED
Binary file (6.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/cutils.cpython-311.pyc ADDED
Binary file (1.01 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/cxxnodes.cpython-311.pyc ADDED
Binary file (908 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/fnodes.cpython-311.pyc ADDED
Binary file (29.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/futils.cpython-311.pyc ADDED
Binary file (3.04 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/__pycache__/rewriting.cpython-311.pyc ADDED
Binary file (20 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/tests/__pycache__/test_algorithms.cpython-311.pyc ADDED
Binary file (13.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/tests/__pycache__/test_ast.cpython-311.pyc ADDED
Binary file (53.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/tests/__pycache__/test_rewriting.cpython-311.pyc ADDED
Binary file (36 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/tests/test_algorithms.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ from sympy import log, Min, Max, sqrt
3
+ from sympy.core.numbers import Float
4
+ from sympy.core.symbol import Symbol, symbols
5
+ from sympy.functions.elementary.trigonometric import cos
6
+ from sympy.codegen.ast import Assignment, Raise, RuntimeError_, QuotedString
7
+ from sympy.codegen.algorithms import newtons_method, newtons_method_function
8
+ from sympy.codegen.cfunctions import expm1
9
+ from sympy.codegen.fnodes import bind_C
10
+ from sympy.codegen.futils import render_as_module as f_module
11
+ from sympy.codegen.pyutils import render_as_module as py_module
12
+ from sympy.external import import_module
13
+ from sympy.printing.codeprinter import ccode
14
+ from sympy.utilities._compilation import compile_link_import_strings, has_c, has_fortran
15
+ from sympy.utilities._compilation.util import may_xfail
16
+ from sympy.testing.pytest import skip, raises
17
+
18
+ cython = import_module('cython')
19
+ wurlitzer = import_module('wurlitzer')
20
+
21
+ def test_newtons_method():
22
+ x, dx, atol = symbols('x dx atol')
23
+ expr = cos(x) - x**3
24
+ algo = newtons_method(expr, x, atol, dx)
25
+ assert algo.has(Assignment(dx, -expr/expr.diff(x)))
26
+
27
+
28
+ @may_xfail
29
+ def test_newtons_method_function__ccode():
30
+ x = Symbol('x', real=True)
31
+ expr = cos(x) - x**3
32
+ func = newtons_method_function(expr, x)
33
+
34
+ if not cython:
35
+ skip("cython not installed.")
36
+ if not has_c():
37
+ skip("No C compiler found.")
38
+
39
+ compile_kw = {"std": 'c99'}
40
+ with tempfile.TemporaryDirectory() as folder:
41
+ mod, info = compile_link_import_strings([
42
+ ('newton.c', ('#include <math.h>\n'
43
+ '#include <stdio.h>\n') + ccode(func)),
44
+ ('_newton.pyx', ("#cython: language_level={}\n".format("3") +
45
+ "cdef extern double newton(double)\n"
46
+ "def py_newton(x):\n"
47
+ " return newton(x)\n"))
48
+ ], build_dir=folder, compile_kwargs=compile_kw)
49
+ assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
50
+
51
+
52
+ @may_xfail
53
+ def test_newtons_method_function__fcode():
54
+ x = Symbol('x', real=True)
55
+ expr = cos(x) - x**3
56
+ func = newtons_method_function(expr, x, attrs=[bind_C(name='newton')])
57
+
58
+ if not cython:
59
+ skip("cython not installed.")
60
+ if not has_fortran():
61
+ skip("No Fortran compiler found.")
62
+
63
+ f_mod = f_module([func], 'mod_newton')
64
+ with tempfile.TemporaryDirectory() as folder:
65
+ mod, info = compile_link_import_strings([
66
+ ('newton.f90', f_mod),
67
+ ('_newton.pyx', ("#cython: language_level={}\n".format("3") +
68
+ "cdef extern double newton(double*)\n"
69
+ "def py_newton(double x):\n"
70
+ " return newton(&x)\n"))
71
+ ], build_dir=folder)
72
+ assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
73
+
74
+
75
+ def test_newtons_method_function__pycode():
76
+ x = Symbol('x', real=True)
77
+ expr = cos(x) - x**3
78
+ func = newtons_method_function(expr, x)
79
+ py_mod = py_module(func)
80
+ namespace = {}
81
+ exec(py_mod, namespace, namespace)
82
+ res = eval('newton(0.5)', namespace)
83
+ assert abs(res - 0.865474033102) < 1e-12
84
+
85
+
86
+ @may_xfail
87
+ def test_newtons_method_function__ccode_parameters():
88
+ args = x, A, k, p = symbols('x A k p')
89
+ expr = A*cos(k*x) - p*x**3
90
+ raises(ValueError, lambda: newtons_method_function(expr, x))
91
+ use_wurlitzer = wurlitzer
92
+
93
+ func = newtons_method_function(expr, x, args, debug=use_wurlitzer)
94
+
95
+ if not has_c():
96
+ skip("No C compiler found.")
97
+ if not cython:
98
+ skip("cython not installed.")
99
+
100
+ compile_kw = {"std": 'c99'}
101
+ with tempfile.TemporaryDirectory() as folder:
102
+ mod, info = compile_link_import_strings([
103
+ ('newton_par.c', ('#include <math.h>\n'
104
+ '#include <stdio.h>\n') + ccode(func)),
105
+ ('_newton_par.pyx', ("#cython: language_level={}\n".format("3") +
106
+ "cdef extern double newton(double, double, double, double)\n"
107
+ "def py_newton(x, A=1, k=1, p=1):\n"
108
+ " return newton(x, A, k, p)\n"))
109
+ ], compile_kwargs=compile_kw, build_dir=folder)
110
+
111
+ if use_wurlitzer:
112
+ with wurlitzer.pipes() as (out, err):
113
+ result = mod.py_newton(0.5)
114
+ else:
115
+ result = mod.py_newton(0.5)
116
+
117
+ assert abs(result - 0.865474033102) < 1e-12
118
+
119
+ if not use_wurlitzer:
120
+ skip("C-level output only tested when package 'wurlitzer' is available.")
121
+
122
+ out, err = out.read(), err.read()
123
+ assert err == ''
124
+ assert out == """\
125
+ x= 0.5
126
+ x= 1.1121 d_x= 0.61214
127
+ x= 0.90967 d_x= -0.20247
128
+ x= 0.86726 d_x= -0.042409
129
+ x= 0.86548 d_x= -0.0017867
130
+ x= 0.86547 d_x= -3.1022e-06
131
+ x= 0.86547 d_x= -9.3421e-12
132
+ x= 0.86547 d_x= 3.6902e-17
133
+ """ # try to run tests with LC_ALL=C if this assertion fails
134
+
135
+
136
+ def test_newtons_method_function__rtol_cse_nan():
137
+ a, b, c, N_geo, N_tot = symbols('a b c N_geo N_tot', real=True, nonnegative=True)
138
+ i = Symbol('i', integer=True, nonnegative=True)
139
+ N_ari = N_tot - N_geo - 1
140
+ delta_ari = (c-b)/N_ari
141
+ ln_delta_geo = log(b) + log(-expm1((log(a)-log(b))/N_geo))
142
+ eqb_log = ln_delta_geo - log(delta_ari)
143
+
144
+ def _clamp(low, expr, high):
145
+ return Min(Max(low, expr), high)
146
+
147
+ meth_kw = {
148
+ 'clamped_newton': {'delta_fn': lambda e, x: _clamp(
149
+ (sqrt(a*x)-x)*0.99,
150
+ -e/e.diff(x),
151
+ (sqrt(c*x)-x)*0.99
152
+ )},
153
+ 'halley': {'delta_fn': lambda e, x: (-2*(e*e.diff(x))/(2*e.diff(x)**2 - e*e.diff(x, 2)))},
154
+ 'halley_alt': {'delta_fn': lambda e, x: (-e/e.diff(x)/(1-e/e.diff(x)*e.diff(x,2)/2/e.diff(x)))},
155
+ }
156
+ args = eqb_log, b
157
+ for use_cse in [False, True]:
158
+ kwargs = {
159
+ 'params': (b, a, c, N_geo, N_tot), 'itermax': 60, 'debug': True, 'cse': use_cse,
160
+ 'counter': i, 'atol': 1e-100, 'rtol': 2e-16, 'bounds': (a,c),
161
+ 'handle_nan': Raise(RuntimeError_(QuotedString("encountered NaN.")))
162
+ }
163
+ func = {k: newtons_method_function(*args, func_name=f"{k}_b", **dict(kwargs, **kw)) for k, kw in meth_kw.items()}
164
+ py_mod = {k: py_module(v) for k, v in func.items()}
165
+ namespace = {}
166
+ root_find_b = {}
167
+ for k, v in py_mod.items():
168
+ ns = namespace[k] = {}
169
+ exec(v, ns, ns)
170
+ root_find_b[k] = ns[f'{k}_b']
171
+ ref = Float('13.2261515064168768938151923226496')
172
+ reftol = {'clamped_newton': 2e-16, 'halley': 2e-16, 'halley_alt': 3e-16}
173
+ guess = 4.0
174
+ for meth, func in root_find_b.items():
175
+ result = func(guess, 1e-2, 1e2, 50, 100)
176
+ req = ref*reftol[meth]
177
+ if use_cse:
178
+ req *= 2
179
+ assert abs(result - ref) < req
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/tests/test_cxxnodes.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.symbol import Symbol
2
+ from sympy.codegen.ast import Type
3
+ from sympy.codegen.cxxnodes import using
4
+ from sympy.printing.codeprinter import cxxcode
5
+
6
+ x = Symbol('x')
7
+
8
+ def test_using():
9
+ v = Type('std::vector')
10
+ u1 = using(v)
11
+ assert cxxcode(u1) == 'using std::vector'
12
+
13
+ u2 = using(v, 'vec')
14
+ assert cxxcode(u2) == 'using vec = std::vector'
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/tests/test_matrix_nodes.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.symbol import symbols
2
+ from sympy.core.function import Function
3
+ from sympy.matrices.dense import Matrix
4
+ from sympy.matrices.dense import zeros
5
+ from sympy.simplify.simplify import simplify
6
+ from sympy.codegen.matrix_nodes import MatrixSolve
7
+ from sympy.utilities.lambdify import lambdify
8
+ from sympy.printing.numpy import NumPyPrinter
9
+ from sympy.testing.pytest import skip
10
+ from sympy.external import import_module
11
+
12
+
13
+ def test_matrix_solve_issue_24862():
14
+ A = Matrix(3, 3, symbols('a:9'))
15
+ b = Matrix(3, 1, symbols('b:3'))
16
+ hash(MatrixSolve(A, b))
17
+
18
+
19
+ def test_matrix_solve_derivative_exact():
20
+ q = symbols('q')
21
+ a11, a12, a21, a22, b1, b2 = (
22
+ f(q) for f in symbols('a11 a12 a21 a22 b1 b2', cls=Function))
23
+ A = Matrix([[a11, a12], [a21, a22]])
24
+ b = Matrix([b1, b2])
25
+ x_lu = A.LUsolve(b)
26
+ dxdq_lu = A.LUsolve(b.diff(q) - A.diff(q) * A.LUsolve(b))
27
+ assert simplify(x_lu.diff(q) - dxdq_lu) == zeros(2, 1)
28
+ # dxdq_ms is the MatrixSolve equivalent of dxdq_lu
29
+ dxdq_ms = MatrixSolve(A, b.diff(q) - A.diff(q) * MatrixSolve(A, b))
30
+ assert MatrixSolve(A, b).diff(q) == dxdq_ms
31
+
32
+
33
+ def test_matrix_solve_derivative_numpy():
34
+ np = import_module('numpy')
35
+ if not np:
36
+ skip("numpy not installed.")
37
+ q = symbols('q')
38
+ a11, a12, a21, a22, b1, b2 = (
39
+ f(q) for f in symbols('a11 a12 a21 a22 b1 b2', cls=Function))
40
+ A = Matrix([[a11, a12], [a21, a22]])
41
+ b = Matrix([b1, b2])
42
+ dx_lu = A.LUsolve(b).diff(q)
43
+ subs = {a11.diff(q): 0.2, a12.diff(q): 0.3, a21.diff(q): 0.1,
44
+ a22.diff(q): 0.5, b1.diff(q): 0.4, b2.diff(q): 0.9,
45
+ a11: 1.3, a12: 0.5, a21: 1.2, a22: 4, b1: 6.2, b2: 3.5}
46
+ p, p_vals = zip(*subs.items())
47
+ dx_sm = MatrixSolve(A, b).diff(q)
48
+ np.testing.assert_allclose(
49
+ lambdify(p, dx_sm, printer=NumPyPrinter)(*p_vals),
50
+ lambdify(p, dx_lu, printer=NumPyPrinter)(*p_vals))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/codegen/tests/test_scipy_nodes.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import product
2
+ from sympy.core.power import Pow
3
+ from sympy.core.symbol import symbols
4
+ from sympy.functions.elementary.exponential import exp, log
5
+ from sympy.functions.elementary.trigonometric import cos
6
+ from sympy.core.numbers import pi
7
+ from sympy.codegen.scipy_nodes import cosm1, powm1
8
+
9
+ x, y, z = symbols('x y z')
10
+
11
+
12
+ def test_cosm1():
13
+ cm1_xy = cosm1(x*y)
14
+ ref_xy = cos(x*y) - 1
15
+ for wrt, deriv_order in product([x, y, z], range(3)):
16
+ assert (
17
+ cm1_xy.diff(wrt, deriv_order) -
18
+ ref_xy.diff(wrt, deriv_order)
19
+ ).rewrite(cos).simplify() == 0
20
+
21
+ expr_minus2 = cosm1(pi)
22
+ assert expr_minus2.rewrite(cos) == -2
23
+ assert cosm1(3.14).simplify() == cosm1(3.14) # cannot simplify with 3.14
24
+ assert cosm1(pi/2).simplify() == -1
25
+ assert (1/cos(x) - 1 + cosm1(x)/cos(x)).simplify() == 0
26
+
27
+
28
+ def test_powm1():
29
+ cases = {
30
+ powm1(x, y): x**y - 1,
31
+ powm1(x*y, z): (x*y)**z - 1,
32
+ powm1(x, y*z): x**(y*z)-1,
33
+ powm1(x*y*z, x*y*z): (x*y*z)**(x*y*z)-1
34
+ }
35
+ for pm1_e, ref_e in cases.items():
36
+ for wrt, deriv_order in product([x, y, z], range(3)):
37
+ der = pm1_e.diff(wrt, deriv_order)
38
+ ref = ref_e.diff(wrt, deriv_order)
39
+ delta = (der - ref).rewrite(Pow)
40
+ assert delta.simplify() == 0
41
+
42
+ eulers_constant_m1 = powm1(x, 1/log(x))
43
+ assert eulers_constant_m1.rewrite(Pow) == exp(1) - 1
44
+ assert eulers_constant_m1.simplify() == exp(1) - 1
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (427 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/__pycache__/gosper.cpython-311.pyc ADDED
Binary file (8.83 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/tests/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (221 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/tests/__pycache__/test_delta.cpython-311.pyc ADDED
Binary file (93.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/tests/__pycache__/test_gosper.cpython-311.pyc ADDED
Binary file (23.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/tests/test_delta.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.concrete import Sum
2
+ from sympy.concrete.delta import deltaproduct as dp, deltasummation as ds, _extract_delta
3
+ from sympy.core import Eq, S, symbols, oo
4
+ from sympy.functions import KroneckerDelta as KD, Piecewise, piecewise_fold
5
+ from sympy.logic import And
6
+ from sympy.testing.pytest import raises
7
+
8
+ i, j, k, l, m = symbols("i j k l m", integer=True, finite=True)
9
+ x, y = symbols("x y", commutative=False)
10
+
11
+
12
+ def test_deltaproduct_trivial():
13
+ assert dp(x, (j, 1, 0)) == 1
14
+ assert dp(x, (j, 1, 3)) == x**3
15
+ assert dp(x + y, (j, 1, 3)) == (x + y)**3
16
+ assert dp(x*y, (j, 1, 3)) == (x*y)**3
17
+ assert dp(KD(i, j), (k, 1, 3)) == KD(i, j)
18
+ assert dp(x*KD(i, j), (k, 1, 3)) == x**3*KD(i, j)
19
+ assert dp(x*y*KD(i, j), (k, 1, 3)) == (x*y)**3*KD(i, j)
20
+
21
+
22
+ def test_deltaproduct_basic():
23
+ assert dp(KD(i, j), (j, 1, 3)) == 0
24
+ assert dp(KD(i, j), (j, 1, 1)) == KD(i, 1)
25
+ assert dp(KD(i, j), (j, 2, 2)) == KD(i, 2)
26
+ assert dp(KD(i, j), (j, 3, 3)) == KD(i, 3)
27
+ assert dp(KD(i, j), (j, 1, k)) == KD(i, 1)*KD(k, 1) + KD(k, 0)
28
+ assert dp(KD(i, j), (j, k, 3)) == KD(i, 3)*KD(k, 3) + KD(k, 4)
29
+ assert dp(KD(i, j), (j, k, l)) == KD(i, l)*KD(k, l) + KD(k, l + 1)
30
+
31
+
32
+ def test_deltaproduct_mul_x_kd():
33
+ assert dp(x*KD(i, j), (j, 1, 3)) == 0
34
+ assert dp(x*KD(i, j), (j, 1, 1)) == x*KD(i, 1)
35
+ assert dp(x*KD(i, j), (j, 2, 2)) == x*KD(i, 2)
36
+ assert dp(x*KD(i, j), (j, 3, 3)) == x*KD(i, 3)
37
+ assert dp(x*KD(i, j), (j, 1, k)) == x*KD(i, 1)*KD(k, 1) + KD(k, 0)
38
+ assert dp(x*KD(i, j), (j, k, 3)) == x*KD(i, 3)*KD(k, 3) + KD(k, 4)
39
+ assert dp(x*KD(i, j), (j, k, l)) == x*KD(i, l)*KD(k, l) + KD(k, l + 1)
40
+
41
+
42
+ def test_deltaproduct_mul_add_x_y_kd():
43
+ assert dp((x + y)*KD(i, j), (j, 1, 3)) == 0
44
+ assert dp((x + y)*KD(i, j), (j, 1, 1)) == (x + y)*KD(i, 1)
45
+ assert dp((x + y)*KD(i, j), (j, 2, 2)) == (x + y)*KD(i, 2)
46
+ assert dp((x + y)*KD(i, j), (j, 3, 3)) == (x + y)*KD(i, 3)
47
+ assert dp((x + y)*KD(i, j), (j, 1, k)) == \
48
+ (x + y)*KD(i, 1)*KD(k, 1) + KD(k, 0)
49
+ assert dp((x + y)*KD(i, j), (j, k, 3)) == \
50
+ (x + y)*KD(i, 3)*KD(k, 3) + KD(k, 4)
51
+ assert dp((x + y)*KD(i, j), (j, k, l)) == \
52
+ (x + y)*KD(i, l)*KD(k, l) + KD(k, l + 1)
53
+
54
+
55
+ def test_deltaproduct_add_kd_kd():
56
+ assert dp(KD(i, k) + KD(j, k), (k, 1, 3)) == 0
57
+ assert dp(KD(i, k) + KD(j, k), (k, 1, 1)) == KD(i, 1) + KD(j, 1)
58
+ assert dp(KD(i, k) + KD(j, k), (k, 2, 2)) == KD(i, 2) + KD(j, 2)
59
+ assert dp(KD(i, k) + KD(j, k), (k, 3, 3)) == KD(i, 3) + KD(j, 3)
60
+ assert dp(KD(i, k) + KD(j, k), (k, 1, l)) == KD(l, 0) + \
61
+ KD(i, 1)*KD(l, 1) + KD(j, 1)*KD(l, 1) + \
62
+ KD(i, 1)*KD(j, 2)*KD(l, 2) + KD(j, 1)*KD(i, 2)*KD(l, 2)
63
+ assert dp(KD(i, k) + KD(j, k), (k, l, 3)) == KD(l, 4) + \
64
+ KD(i, 3)*KD(l, 3) + KD(j, 3)*KD(l, 3) + \
65
+ KD(i, 2)*KD(j, 3)*KD(l, 2) + KD(i, 3)*KD(j, 2)*KD(l, 2)
66
+ assert dp(KD(i, k) + KD(j, k), (k, l, m)) == KD(l, m + 1) + \
67
+ KD(i, m)*KD(l, m) + KD(j, m)*KD(l, m) + \
68
+ KD(i, m)*KD(j, m - 1)*KD(l, m - 1) + KD(i, m - 1)*KD(j, m)*KD(l, m - 1)
69
+
70
+
71
+ def test_deltaproduct_mul_x_add_kd_kd():
72
+ assert dp(x*(KD(i, k) + KD(j, k)), (k, 1, 3)) == 0
73
+ assert dp(x*(KD(i, k) + KD(j, k)), (k, 1, 1)) == x*(KD(i, 1) + KD(j, 1))
74
+ assert dp(x*(KD(i, k) + KD(j, k)), (k, 2, 2)) == x*(KD(i, 2) + KD(j, 2))
75
+ assert dp(x*(KD(i, k) + KD(j, k)), (k, 3, 3)) == x*(KD(i, 3) + KD(j, 3))
76
+ assert dp(x*(KD(i, k) + KD(j, k)), (k, 1, l)) == KD(l, 0) + \
77
+ x*KD(i, 1)*KD(l, 1) + x*KD(j, 1)*KD(l, 1) + \
78
+ x**2*KD(i, 1)*KD(j, 2)*KD(l, 2) + x**2*KD(j, 1)*KD(i, 2)*KD(l, 2)
79
+ assert dp(x*(KD(i, k) + KD(j, k)), (k, l, 3)) == KD(l, 4) + \
80
+ x*KD(i, 3)*KD(l, 3) + x*KD(j, 3)*KD(l, 3) + \
81
+ x**2*KD(i, 2)*KD(j, 3)*KD(l, 2) + x**2*KD(i, 3)*KD(j, 2)*KD(l, 2)
82
+ assert dp(x*(KD(i, k) + KD(j, k)), (k, l, m)) == KD(l, m + 1) + \
83
+ x*KD(i, m)*KD(l, m) + x*KD(j, m)*KD(l, m) + \
84
+ x**2*KD(i, m - 1)*KD(j, m)*KD(l, m - 1) + \
85
+ x**2*KD(i, m)*KD(j, m - 1)*KD(l, m - 1)
86
+
87
+
88
+ def test_deltaproduct_mul_add_x_y_add_kd_kd():
89
+ assert dp((x + y)*(KD(i, k) + KD(j, k)), (k, 1, 3)) == 0
90
+ assert dp((x + y)*(KD(i, k) + KD(j, k)), (k, 1, 1)) == \
91
+ (x + y)*(KD(i, 1) + KD(j, 1))
92
+ assert dp((x + y)*(KD(i, k) + KD(j, k)), (k, 2, 2)) == \
93
+ (x + y)*(KD(i, 2) + KD(j, 2))
94
+ assert dp((x + y)*(KD(i, k) + KD(j, k)), (k, 3, 3)) == \
95
+ (x + y)*(KD(i, 3) + KD(j, 3))
96
+ assert dp((x + y)*(KD(i, k) + KD(j, k)), (k, 1, l)) == KD(l, 0) + \
97
+ (x + y)*KD(i, 1)*KD(l, 1) + (x + y)*KD(j, 1)*KD(l, 1) + \
98
+ (x + y)**2*KD(i, 1)*KD(j, 2)*KD(l, 2) + \
99
+ (x + y)**2*KD(j, 1)*KD(i, 2)*KD(l, 2)
100
+ assert dp((x + y)*(KD(i, k) + KD(j, k)), (k, l, 3)) == KD(l, 4) + \
101
+ (x + y)*KD(i, 3)*KD(l, 3) + (x + y)*KD(j, 3)*KD(l, 3) + \
102
+ (x + y)**2*KD(i, 2)*KD(j, 3)*KD(l, 2) + \
103
+ (x + y)**2*KD(i, 3)*KD(j, 2)*KD(l, 2)
104
+ assert dp((x + y)*(KD(i, k) + KD(j, k)), (k, l, m)) == KD(l, m + 1) + \
105
+ (x + y)*KD(i, m)*KD(l, m) + (x + y)*KD(j, m)*KD(l, m) + \
106
+ (x + y)**2*KD(i, m - 1)*KD(j, m)*KD(l, m - 1) + \
107
+ (x + y)**2*KD(i, m)*KD(j, m - 1)*KD(l, m - 1)
108
+
109
+
110
+ def test_deltaproduct_add_mul_x_y_mul_x_kd():
111
+ assert dp(x*y + x*KD(i, j), (j, 1, 3)) == (x*y)**3 + \
112
+ x*(x*y)**2*KD(i, 1) + (x*y)*x*(x*y)*KD(i, 2) + (x*y)**2*x*KD(i, 3)
113
+ assert dp(x*y + x*KD(i, j), (j, 1, 1)) == x*y + x*KD(i, 1)
114
+ assert dp(x*y + x*KD(i, j), (j, 2, 2)) == x*y + x*KD(i, 2)
115
+ assert dp(x*y + x*KD(i, j), (j, 3, 3)) == x*y + x*KD(i, 3)
116
+ assert dp(x*y + x*KD(i, j), (j, 1, k)) == \
117
+ (x*y)**k + Piecewise(
118
+ ((x*y)**(i - 1)*x*(x*y)**(k - i), And(1 <= i, i <= k)),
119
+ (0, True)
120
+ )
121
+ assert dp(x*y + x*KD(i, j), (j, k, 3)) == \
122
+ (x*y)**(-k + 4) + Piecewise(
123
+ ((x*y)**(i - k)*x*(x*y)**(3 - i), And(k <= i, i <= 3)),
124
+ (0, True)
125
+ )
126
+ assert dp(x*y + x*KD(i, j), (j, k, l)) == \
127
+ (x*y)**(-k + l + 1) + Piecewise(
128
+ ((x*y)**(i - k)*x*(x*y)**(l - i), And(k <= i, i <= l)),
129
+ (0, True)
130
+ )
131
+
132
+
133
+ def test_deltaproduct_mul_x_add_y_kd():
134
+ assert dp(x*(y + KD(i, j)), (j, 1, 3)) == (x*y)**3 + \
135
+ x*(x*y)**2*KD(i, 1) + (x*y)*x*(x*y)*KD(i, 2) + (x*y)**2*x*KD(i, 3)
136
+ assert dp(x*(y + KD(i, j)), (j, 1, 1)) == x*(y + KD(i, 1))
137
+ assert dp(x*(y + KD(i, j)), (j, 2, 2)) == x*(y + KD(i, 2))
138
+ assert dp(x*(y + KD(i, j)), (j, 3, 3)) == x*(y + KD(i, 3))
139
+ assert dp(x*(y + KD(i, j)), (j, 1, k)) == \
140
+ (x*y)**k + Piecewise(
141
+ ((x*y)**(i - 1)*x*(x*y)**(k - i), And(1 <= i, i <= k)),
142
+ (0, True)
143
+ ).expand()
144
+ assert dp(x*(y + KD(i, j)), (j, k, 3)) == \
145
+ ((x*y)**(-k + 4) + Piecewise(
146
+ ((x*y)**(i - k)*x*(x*y)**(3 - i), And(k <= i, i <= 3)),
147
+ (0, True)
148
+ )).expand()
149
+ assert dp(x*(y + KD(i, j)), (j, k, l)) == \
150
+ ((x*y)**(-k + l + 1) + Piecewise(
151
+ ((x*y)**(i - k)*x*(x*y)**(l - i), And(k <= i, i <= l)),
152
+ (0, True)
153
+ )).expand()
154
+
155
+
156
+ def test_deltaproduct_mul_x_add_y_twokd():
157
+ assert dp(x*(y + 2*KD(i, j)), (j, 1, 3)) == (x*y)**3 + \
158
+ 2*x*(x*y)**2*KD(i, 1) + 2*x*y*x*x*y*KD(i, 2) + 2*(x*y)**2*x*KD(i, 3)
159
+ assert dp(x*(y + 2*KD(i, j)), (j, 1, 1)) == x*(y + 2*KD(i, 1))
160
+ assert dp(x*(y + 2*KD(i, j)), (j, 2, 2)) == x*(y + 2*KD(i, 2))
161
+ assert dp(x*(y + 2*KD(i, j)), (j, 3, 3)) == x*(y + 2*KD(i, 3))
162
+ assert dp(x*(y + 2*KD(i, j)), (j, 1, k)) == \
163
+ (x*y)**k + Piecewise(
164
+ (2*(x*y)**(i - 1)*x*(x*y)**(k - i), And(1 <= i, i <= k)),
165
+ (0, True)
166
+ ).expand()
167
+ assert dp(x*(y + 2*KD(i, j)), (j, k, 3)) == \
168
+ ((x*y)**(-k + 4) + Piecewise(
169
+ (2*(x*y)**(i - k)*x*(x*y)**(3 - i), And(k <= i, i <= 3)),
170
+ (0, True)
171
+ )).expand()
172
+ assert dp(x*(y + 2*KD(i, j)), (j, k, l)) == \
173
+ ((x*y)**(-k + l + 1) + Piecewise(
174
+ (2*(x*y)**(i - k)*x*(x*y)**(l - i), And(k <= i, i <= l)),
175
+ (0, True)
176
+ )).expand()
177
+
178
+
179
+ def test_deltaproduct_mul_add_x_y_add_y_kd():
180
+ assert dp((x + y)*(y + KD(i, j)), (j, 1, 3)) == ((x + y)*y)**3 + \
181
+ (x + y)*((x + y)*y)**2*KD(i, 1) + \
182
+ (x + y)*y*(x + y)**2*y*KD(i, 2) + \
183
+ ((x + y)*y)**2*(x + y)*KD(i, 3)
184
+ assert dp((x + y)*(y + KD(i, j)), (j, 1, 1)) == (x + y)*(y + KD(i, 1))
185
+ assert dp((x + y)*(y + KD(i, j)), (j, 2, 2)) == (x + y)*(y + KD(i, 2))
186
+ assert dp((x + y)*(y + KD(i, j)), (j, 3, 3)) == (x + y)*(y + KD(i, 3))
187
+ assert dp((x + y)*(y + KD(i, j)), (j, 1, k)) == \
188
+ ((x + y)*y)**k + Piecewise(
189
+ (((x + y)*y)**(-1)*((x + y)*y)**i*(x + y)*((x + y)*y
190
+ )**k*((x + y)*y)**(-i), (i >= 1) & (i <= k)), (0, True))
191
+ assert dp((x + y)*(y + KD(i, j)), (j, k, 3)) == (
192
+ (x + y)*y)**4*((x + y)*y)**(-k) + Piecewise((((x + y)*y)**i*(
193
+ (x + y)*y)**(-k)*(x + y)*((x + y)*y)**3*((x + y)*y)**(-i),
194
+ (i >= k) & (i <= 3)), (0, True))
195
+ assert dp((x + y)*(y + KD(i, j)), (j, k, l)) == \
196
+ (x + y)*y*((x + y)*y)**l*((x + y)*y)**(-k) + Piecewise(
197
+ (((x + y)*y)**i*((x + y)*y)**(-k)*(x + y)*((x + y)*y
198
+ )**l*((x + y)*y)**(-i), (i >= k) & (i <= l)), (0, True))
199
+
200
+
201
+ def test_deltaproduct_mul_add_x_kd_add_y_kd():
202
+ assert dp((x + KD(i, k))*(y + KD(i, j)), (j, 1, 3)) == \
203
+ KD(i, 1)*(KD(i, k) + x)*((KD(i, k) + x)*y)**2 + \
204
+ KD(i, 2)*(KD(i, k) + x)*y*(KD(i, k) + x)**2*y + \
205
+ KD(i, 3)*((KD(i, k) + x)*y)**2*(KD(i, k) + x) + \
206
+ ((KD(i, k) + x)*y)**3
207
+ assert dp((x + KD(i, k))*(y + KD(i, j)), (j, 1, 1)) == \
208
+ (x + KD(i, k))*(y + KD(i, 1))
209
+ assert dp((x + KD(i, k))*(y + KD(i, j)), (j, 2, 2)) == \
210
+ (x + KD(i, k))*(y + KD(i, 2))
211
+ assert dp((x + KD(i, k))*(y + KD(i, j)), (j, 3, 3)) == \
212
+ (x + KD(i, k))*(y + KD(i, 3))
213
+ assert dp((x + KD(i, k))*(y + KD(i, j)), (j, 1, k)) == \
214
+ ((KD(i, k) + x)*y)**k + Piecewise(
215
+ (((KD(i, k) + x)*y)**(-1)*((KD(i, k) + x)*y)**i*(KD(i, k) + x
216
+ )*((KD(i, k) + x)*y)**k*((KD(i, k) + x)*y)**(-i), (i >= 1
217
+ ) & (i <= k)), (0, True))
218
+ assert dp((x + KD(i, k))*(y + KD(i, j)), (j, k, 3)) == (
219
+ (KD(i, k) + x)*y)**4*((KD(i, k) + x)*y)**(-k) + Piecewise(
220
+ (((KD(i, k) + x)*y)**i*((KD(i, k) + x)*y)**(-k)*(KD(i, k)
221
+ + x)*((KD(i, k) + x)*y)**3*((KD(i, k) + x)*y)**(-i),
222
+ (i >= k) & (i <= 3)), (0, True))
223
+ assert dp((x + KD(i, k))*(y + KD(i, j)), (j, k, l)) == (
224
+ KD(i, k) + x)*y*((KD(i, k) + x)*y)**l*((KD(i, k) + x)*y
225
+ )**(-k) + Piecewise((((KD(i, k) + x)*y)**i*((KD(i, k) + x
226
+ )*y)**(-k)*(KD(i, k) + x)*((KD(i, k) + x)*y)**l*((KD(i, k) + x
227
+ )*y)**(-i), (i >= k) & (i <= l)), (0, True))
228
+
229
+
230
+ def test_deltasummation_trivial():
231
+ assert ds(x, (j, 1, 0)) == 0
232
+ assert ds(x, (j, 1, 3)) == 3*x
233
+ assert ds(x + y, (j, 1, 3)) == 3*(x + y)
234
+ assert ds(x*y, (j, 1, 3)) == 3*x*y
235
+ assert ds(KD(i, j), (k, 1, 3)) == 3*KD(i, j)
236
+ assert ds(x*KD(i, j), (k, 1, 3)) == 3*x*KD(i, j)
237
+ assert ds(x*y*KD(i, j), (k, 1, 3)) == 3*x*y*KD(i, j)
238
+
239
+
240
+ def test_deltasummation_basic_numerical():
241
+ n = symbols('n', integer=True, nonzero=True)
242
+ assert ds(KD(n, 0), (n, 1, 3)) == 0
243
+
244
+ # return unevaluated, until it gets implemented
245
+ assert ds(KD(i**2, j**2), (j, -oo, oo)) == \
246
+ Sum(KD(i**2, j**2), (j, -oo, oo))
247
+
248
+ assert Piecewise((KD(i, k), And(1 <= i, i <= 3)), (0, True)) == \
249
+ ds(KD(i, j)*KD(j, k), (j, 1, 3)) == \
250
+ ds(KD(j, k)*KD(i, j), (j, 1, 3))
251
+
252
+ assert ds(KD(i, k), (k, -oo, oo)) == 1
253
+ assert ds(KD(i, k), (k, 0, oo)) == Piecewise((1, S.Zero <= i), (0, True))
254
+ assert ds(KD(i, k), (k, 1, 3)) == \
255
+ Piecewise((1, And(1 <= i, i <= 3)), (0, True))
256
+ assert ds(k*KD(i, j)*KD(j, k), (k, -oo, oo)) == j*KD(i, j)
257
+ assert ds(j*KD(i, j), (j, -oo, oo)) == i
258
+ assert ds(i*KD(i, j), (i, -oo, oo)) == j
259
+ assert ds(x, (i, 1, 3)) == 3*x
260
+ assert ds((i + j)*KD(i, j), (j, -oo, oo)) == 2*i
261
+
262
+
263
+ def test_deltasummation_basic_symbolic():
264
+ assert ds(KD(i, j), (j, 1, 3)) == \
265
+ Piecewise((1, And(1 <= i, i <= 3)), (0, True))
266
+ assert ds(KD(i, j), (j, 1, 1)) == Piecewise((1, Eq(i, 1)), (0, True))
267
+ assert ds(KD(i, j), (j, 2, 2)) == Piecewise((1, Eq(i, 2)), (0, True))
268
+ assert ds(KD(i, j), (j, 3, 3)) == Piecewise((1, Eq(i, 3)), (0, True))
269
+ assert ds(KD(i, j), (j, 1, k)) == \
270
+ Piecewise((1, And(1 <= i, i <= k)), (0, True))
271
+ assert ds(KD(i, j), (j, k, 3)) == \
272
+ Piecewise((1, And(k <= i, i <= 3)), (0, True))
273
+ assert ds(KD(i, j), (j, k, l)) == \
274
+ Piecewise((1, And(k <= i, i <= l)), (0, True))
275
+
276
+
277
+ def test_deltasummation_mul_x_kd():
278
+ assert ds(x*KD(i, j), (j, 1, 3)) == \
279
+ Piecewise((x, And(1 <= i, i <= 3)), (0, True))
280
+ assert ds(x*KD(i, j), (j, 1, 1)) == Piecewise((x, Eq(i, 1)), (0, True))
281
+ assert ds(x*KD(i, j), (j, 2, 2)) == Piecewise((x, Eq(i, 2)), (0, True))
282
+ assert ds(x*KD(i, j), (j, 3, 3)) == Piecewise((x, Eq(i, 3)), (0, True))
283
+ assert ds(x*KD(i, j), (j, 1, k)) == \
284
+ Piecewise((x, And(1 <= i, i <= k)), (0, True))
285
+ assert ds(x*KD(i, j), (j, k, 3)) == \
286
+ Piecewise((x, And(k <= i, i <= 3)), (0, True))
287
+ assert ds(x*KD(i, j), (j, k, l)) == \
288
+ Piecewise((x, And(k <= i, i <= l)), (0, True))
289
+
290
+
291
+ def test_deltasummation_mul_add_x_y_kd():
292
+ assert ds((x + y)*KD(i, j), (j, 1, 3)) == \
293
+ Piecewise((x + y, And(1 <= i, i <= 3)), (0, True))
294
+ assert ds((x + y)*KD(i, j), (j, 1, 1)) == \
295
+ Piecewise((x + y, Eq(i, 1)), (0, True))
296
+ assert ds((x + y)*KD(i, j), (j, 2, 2)) == \
297
+ Piecewise((x + y, Eq(i, 2)), (0, True))
298
+ assert ds((x + y)*KD(i, j), (j, 3, 3)) == \
299
+ Piecewise((x + y, Eq(i, 3)), (0, True))
300
+ assert ds((x + y)*KD(i, j), (j, 1, k)) == \
301
+ Piecewise((x + y, And(1 <= i, i <= k)), (0, True))
302
+ assert ds((x + y)*KD(i, j), (j, k, 3)) == \
303
+ Piecewise((x + y, And(k <= i, i <= 3)), (0, True))
304
+ assert ds((x + y)*KD(i, j), (j, k, l)) == \
305
+ Piecewise((x + y, And(k <= i, i <= l)), (0, True))
306
+
307
+
308
+ def test_deltasummation_add_kd_kd():
309
+ assert ds(KD(i, k) + KD(j, k), (k, 1, 3)) == piecewise_fold(
310
+ Piecewise((1, And(1 <= i, i <= 3)), (0, True)) +
311
+ Piecewise((1, And(1 <= j, j <= 3)), (0, True)))
312
+ assert ds(KD(i, k) + KD(j, k), (k, 1, 1)) == piecewise_fold(
313
+ Piecewise((1, Eq(i, 1)), (0, True)) +
314
+ Piecewise((1, Eq(j, 1)), (0, True)))
315
+ assert ds(KD(i, k) + KD(j, k), (k, 2, 2)) == piecewise_fold(
316
+ Piecewise((1, Eq(i, 2)), (0, True)) +
317
+ Piecewise((1, Eq(j, 2)), (0, True)))
318
+ assert ds(KD(i, k) + KD(j, k), (k, 3, 3)) == piecewise_fold(
319
+ Piecewise((1, Eq(i, 3)), (0, True)) +
320
+ Piecewise((1, Eq(j, 3)), (0, True)))
321
+ assert ds(KD(i, k) + KD(j, k), (k, 1, l)) == piecewise_fold(
322
+ Piecewise((1, And(1 <= i, i <= l)), (0, True)) +
323
+ Piecewise((1, And(1 <= j, j <= l)), (0, True)))
324
+ assert ds(KD(i, k) + KD(j, k), (k, l, 3)) == piecewise_fold(
325
+ Piecewise((1, And(l <= i, i <= 3)), (0, True)) +
326
+ Piecewise((1, And(l <= j, j <= 3)), (0, True)))
327
+ assert ds(KD(i, k) + KD(j, k), (k, l, m)) == piecewise_fold(
328
+ Piecewise((1, And(l <= i, i <= m)), (0, True)) +
329
+ Piecewise((1, And(l <= j, j <= m)), (0, True)))
330
+
331
+
332
+ def test_deltasummation_add_mul_x_kd_kd():
333
+ assert ds(x*KD(i, k) + KD(j, k), (k, 1, 3)) == piecewise_fold(
334
+ Piecewise((x, And(1 <= i, i <= 3)), (0, True)) +
335
+ Piecewise((1, And(1 <= j, j <= 3)), (0, True)))
336
+ assert ds(x*KD(i, k) + KD(j, k), (k, 1, 1)) == piecewise_fold(
337
+ Piecewise((x, Eq(i, 1)), (0, True)) +
338
+ Piecewise((1, Eq(j, 1)), (0, True)))
339
+ assert ds(x*KD(i, k) + KD(j, k), (k, 2, 2)) == piecewise_fold(
340
+ Piecewise((x, Eq(i, 2)), (0, True)) +
341
+ Piecewise((1, Eq(j, 2)), (0, True)))
342
+ assert ds(x*KD(i, k) + KD(j, k), (k, 3, 3)) == piecewise_fold(
343
+ Piecewise((x, Eq(i, 3)), (0, True)) +
344
+ Piecewise((1, Eq(j, 3)), (0, True)))
345
+ assert ds(x*KD(i, k) + KD(j, k), (k, 1, l)) == piecewise_fold(
346
+ Piecewise((x, And(1 <= i, i <= l)), (0, True)) +
347
+ Piecewise((1, And(1 <= j, j <= l)), (0, True)))
348
+ assert ds(x*KD(i, k) + KD(j, k), (k, l, 3)) == piecewise_fold(
349
+ Piecewise((x, And(l <= i, i <= 3)), (0, True)) +
350
+ Piecewise((1, And(l <= j, j <= 3)), (0, True)))
351
+ assert ds(x*KD(i, k) + KD(j, k), (k, l, m)) == piecewise_fold(
352
+ Piecewise((x, And(l <= i, i <= m)), (0, True)) +
353
+ Piecewise((1, And(l <= j, j <= m)), (0, True)))
354
+
355
+
356
+ def test_deltasummation_mul_x_add_kd_kd():
357
+ assert ds(x*(KD(i, k) + KD(j, k)), (k, 1, 3)) == piecewise_fold(
358
+ Piecewise((x, And(1 <= i, i <= 3)), (0, True)) +
359
+ Piecewise((x, And(1 <= j, j <= 3)), (0, True)))
360
+ assert ds(x*(KD(i, k) + KD(j, k)), (k, 1, 1)) == piecewise_fold(
361
+ Piecewise((x, Eq(i, 1)), (0, True)) +
362
+ Piecewise((x, Eq(j, 1)), (0, True)))
363
+ assert ds(x*(KD(i, k) + KD(j, k)), (k, 2, 2)) == piecewise_fold(
364
+ Piecewise((x, Eq(i, 2)), (0, True)) +
365
+ Piecewise((x, Eq(j, 2)), (0, True)))
366
+ assert ds(x*(KD(i, k) + KD(j, k)), (k, 3, 3)) == piecewise_fold(
367
+ Piecewise((x, Eq(i, 3)), (0, True)) +
368
+ Piecewise((x, Eq(j, 3)), (0, True)))
369
+ assert ds(x*(KD(i, k) + KD(j, k)), (k, 1, l)) == piecewise_fold(
370
+ Piecewise((x, And(1 <= i, i <= l)), (0, True)) +
371
+ Piecewise((x, And(1 <= j, j <= l)), (0, True)))
372
+ assert ds(x*(KD(i, k) + KD(j, k)), (k, l, 3)) == piecewise_fold(
373
+ Piecewise((x, And(l <= i, i <= 3)), (0, True)) +
374
+ Piecewise((x, And(l <= j, j <= 3)), (0, True)))
375
+ assert ds(x*(KD(i, k) + KD(j, k)), (k, l, m)) == piecewise_fold(
376
+ Piecewise((x, And(l <= i, i <= m)), (0, True)) +
377
+ Piecewise((x, And(l <= j, j <= m)), (0, True)))
378
+
379
+
380
+ def test_deltasummation_mul_add_x_y_add_kd_kd():
381
+ assert ds((x + y)*(KD(i, k) + KD(j, k)), (k, 1, 3)) == piecewise_fold(
382
+ Piecewise((x + y, And(1 <= i, i <= 3)), (0, True)) +
383
+ Piecewise((x + y, And(1 <= j, j <= 3)), (0, True)))
384
+ assert ds((x + y)*(KD(i, k) + KD(j, k)), (k, 1, 1)) == piecewise_fold(
385
+ Piecewise((x + y, Eq(i, 1)), (0, True)) +
386
+ Piecewise((x + y, Eq(j, 1)), (0, True)))
387
+ assert ds((x + y)*(KD(i, k) + KD(j, k)), (k, 2, 2)) == piecewise_fold(
388
+ Piecewise((x + y, Eq(i, 2)), (0, True)) +
389
+ Piecewise((x + y, Eq(j, 2)), (0, True)))
390
+ assert ds((x + y)*(KD(i, k) + KD(j, k)), (k, 3, 3)) == piecewise_fold(
391
+ Piecewise((x + y, Eq(i, 3)), (0, True)) +
392
+ Piecewise((x + y, Eq(j, 3)), (0, True)))
393
+ assert ds((x + y)*(KD(i, k) + KD(j, k)), (k, 1, l)) == piecewise_fold(
394
+ Piecewise((x + y, And(1 <= i, i <= l)), (0, True)) +
395
+ Piecewise((x + y, And(1 <= j, j <= l)), (0, True)))
396
+ assert ds((x + y)*(KD(i, k) + KD(j, k)), (k, l, 3)) == piecewise_fold(
397
+ Piecewise((x + y, And(l <= i, i <= 3)), (0, True)) +
398
+ Piecewise((x + y, And(l <= j, j <= 3)), (0, True)))
399
+ assert ds((x + y)*(KD(i, k) + KD(j, k)), (k, l, m)) == piecewise_fold(
400
+ Piecewise((x + y, And(l <= i, i <= m)), (0, True)) +
401
+ Piecewise((x + y, And(l <= j, j <= m)), (0, True)))
402
+
403
+
404
+ def test_deltasummation_add_mul_x_y_mul_x_kd():
405
+ assert ds(x*y + x*KD(i, j), (j, 1, 3)) == \
406
+ Piecewise((3*x*y + x, And(1 <= i, i <= 3)), (3*x*y, True))
407
+ assert ds(x*y + x*KD(i, j), (j, 1, 1)) == \
408
+ Piecewise((x*y + x, Eq(i, 1)), (x*y, True))
409
+ assert ds(x*y + x*KD(i, j), (j, 2, 2)) == \
410
+ Piecewise((x*y + x, Eq(i, 2)), (x*y, True))
411
+ assert ds(x*y + x*KD(i, j), (j, 3, 3)) == \
412
+ Piecewise((x*y + x, Eq(i, 3)), (x*y, True))
413
+ assert ds(x*y + x*KD(i, j), (j, 1, k)) == \
414
+ Piecewise((k*x*y + x, And(1 <= i, i <= k)), (k*x*y, True))
415
+ assert ds(x*y + x*KD(i, j), (j, k, 3)) == \
416
+ Piecewise(((4 - k)*x*y + x, And(k <= i, i <= 3)), ((4 - k)*x*y, True))
417
+ assert ds(x*y + x*KD(i, j), (j, k, l)) == Piecewise(
418
+ ((l - k + 1)*x*y + x, And(k <= i, i <= l)), ((l - k + 1)*x*y, True))
419
+
420
+
421
+ def test_deltasummation_mul_x_add_y_kd():
422
+ assert ds(x*(y + KD(i, j)), (j, 1, 3)) == \
423
+ Piecewise((3*x*y + x, And(1 <= i, i <= 3)), (3*x*y, True))
424
+ assert ds(x*(y + KD(i, j)), (j, 1, 1)) == \
425
+ Piecewise((x*y + x, Eq(i, 1)), (x*y, True))
426
+ assert ds(x*(y + KD(i, j)), (j, 2, 2)) == \
427
+ Piecewise((x*y + x, Eq(i, 2)), (x*y, True))
428
+ assert ds(x*(y + KD(i, j)), (j, 3, 3)) == \
429
+ Piecewise((x*y + x, Eq(i, 3)), (x*y, True))
430
+ assert ds(x*(y + KD(i, j)), (j, 1, k)) == \
431
+ Piecewise((k*x*y + x, And(1 <= i, i <= k)), (k*x*y, True))
432
+ assert ds(x*(y + KD(i, j)), (j, k, 3)) == \
433
+ Piecewise(((4 - k)*x*y + x, And(k <= i, i <= 3)), ((4 - k)*x*y, True))
434
+ assert ds(x*(y + KD(i, j)), (j, k, l)) == Piecewise(
435
+ ((l - k + 1)*x*y + x, And(k <= i, i <= l)), ((l - k + 1)*x*y, True))
436
+
437
+
438
+ def test_deltasummation_mul_x_add_y_twokd():
439
+ assert ds(x*(y + 2*KD(i, j)), (j, 1, 3)) == \
440
+ Piecewise((3*x*y + 2*x, And(1 <= i, i <= 3)), (3*x*y, True))
441
+ assert ds(x*(y + 2*KD(i, j)), (j, 1, 1)) == \
442
+ Piecewise((x*y + 2*x, Eq(i, 1)), (x*y, True))
443
+ assert ds(x*(y + 2*KD(i, j)), (j, 2, 2)) == \
444
+ Piecewise((x*y + 2*x, Eq(i, 2)), (x*y, True))
445
+ assert ds(x*(y + 2*KD(i, j)), (j, 3, 3)) == \
446
+ Piecewise((x*y + 2*x, Eq(i, 3)), (x*y, True))
447
+ assert ds(x*(y + 2*KD(i, j)), (j, 1, k)) == \
448
+ Piecewise((k*x*y + 2*x, And(1 <= i, i <= k)), (k*x*y, True))
449
+ assert ds(x*(y + 2*KD(i, j)), (j, k, 3)) == Piecewise(
450
+ ((4 - k)*x*y + 2*x, And(k <= i, i <= 3)), ((4 - k)*x*y, True))
451
+ assert ds(x*(y + 2*KD(i, j)), (j, k, l)) == Piecewise(
452
+ ((l - k + 1)*x*y + 2*x, And(k <= i, i <= l)), ((l - k + 1)*x*y, True))
453
+
454
+
455
+ def test_deltasummation_mul_add_x_y_add_y_kd():
456
+ assert ds((x + y)*(y + KD(i, j)), (j, 1, 3)) == Piecewise(
457
+ (3*(x + y)*y + x + y, And(1 <= i, i <= 3)), (3*(x + y)*y, True))
458
+ assert ds((x + y)*(y + KD(i, j)), (j, 1, 1)) == \
459
+ Piecewise(((x + y)*y + x + y, Eq(i, 1)), ((x + y)*y, True))
460
+ assert ds((x + y)*(y + KD(i, j)), (j, 2, 2)) == \
461
+ Piecewise(((x + y)*y + x + y, Eq(i, 2)), ((x + y)*y, True))
462
+ assert ds((x + y)*(y + KD(i, j)), (j, 3, 3)) == \
463
+ Piecewise(((x + y)*y + x + y, Eq(i, 3)), ((x + y)*y, True))
464
+ assert ds((x + y)*(y + KD(i, j)), (j, 1, k)) == Piecewise(
465
+ (k*(x + y)*y + x + y, And(1 <= i, i <= k)), (k*(x + y)*y, True))
466
+ assert ds((x + y)*(y + KD(i, j)), (j, k, 3)) == Piecewise(
467
+ ((4 - k)*(x + y)*y + x + y, And(k <= i, i <= 3)),
468
+ ((4 - k)*(x + y)*y, True))
469
+ assert ds((x + y)*(y + KD(i, j)), (j, k, l)) == Piecewise(
470
+ ((l - k + 1)*(x + y)*y + x + y, And(k <= i, i <= l)),
471
+ ((l - k + 1)*(x + y)*y, True))
472
+
473
+
474
+ def test_deltasummation_mul_add_x_kd_add_y_kd():
475
+ assert ds((x + KD(i, k))*(y + KD(i, j)), (j, 1, 3)) == piecewise_fold(
476
+ Piecewise((KD(i, k) + x, And(1 <= i, i <= 3)), (0, True)) +
477
+ 3*(KD(i, k) + x)*y)
478
+ assert ds((x + KD(i, k))*(y + KD(i, j)), (j, 1, 1)) == piecewise_fold(
479
+ Piecewise((KD(i, k) + x, Eq(i, 1)), (0, True)) +
480
+ (KD(i, k) + x)*y)
481
+ assert ds((x + KD(i, k))*(y + KD(i, j)), (j, 2, 2)) == piecewise_fold(
482
+ Piecewise((KD(i, k) + x, Eq(i, 2)), (0, True)) +
483
+ (KD(i, k) + x)*y)
484
+ assert ds((x + KD(i, k))*(y + KD(i, j)), (j, 3, 3)) == piecewise_fold(
485
+ Piecewise((KD(i, k) + x, Eq(i, 3)), (0, True)) +
486
+ (KD(i, k) + x)*y)
487
+ assert ds((x + KD(i, k))*(y + KD(i, j)), (j, 1, k)) == piecewise_fold(
488
+ Piecewise((KD(i, k) + x, And(1 <= i, i <= k)), (0, True)) +
489
+ k*(KD(i, k) + x)*y)
490
+ assert ds((x + KD(i, k))*(y + KD(i, j)), (j, k, 3)) == piecewise_fold(
491
+ Piecewise((KD(i, k) + x, And(k <= i, i <= 3)), (0, True)) +
492
+ (4 - k)*(KD(i, k) + x)*y)
493
+ assert ds((x + KD(i, k))*(y + KD(i, j)), (j, k, l)) == piecewise_fold(
494
+ Piecewise((KD(i, k) + x, And(k <= i, i <= l)), (0, True)) +
495
+ (l - k + 1)*(KD(i, k) + x)*y)
496
+
497
+
498
+ def test_extract_delta():
499
+ raises(ValueError, lambda: _extract_delta(KD(i, j) + KD(k, l), i))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/tests/test_products.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.concrete.products import (Product, product)
2
+ from sympy.concrete.summations import Sum
3
+ from sympy.core.function import (Derivative, Function, diff)
4
+ from sympy.core.numbers import (Rational, oo, pi)
5
+ from sympy.core.singleton import S
6
+ from sympy.core.symbol import (Dummy, Symbol, symbols)
7
+ from sympy.functions.combinatorial.factorials import (rf, factorial)
8
+ from sympy.functions.elementary.exponential import (exp, log)
9
+ from sympy.functions.elementary.miscellaneous import sqrt
10
+ from sympy.functions.elementary.trigonometric import (cos, sin)
11
+ from sympy.functions.special.tensor_functions import KroneckerDelta
12
+ from sympy.simplify.combsimp import combsimp
13
+ from sympy.simplify.simplify import simplify
14
+ from sympy.testing.pytest import raises
15
+
16
+ a, k, n, m, x = symbols('a,k,n,m,x', integer=True)
17
+ f = Function('f')
18
+
19
+
20
+ def test_karr_convention():
21
+ # Test the Karr product convention that we want to hold.
22
+ # See his paper "Summation in Finite Terms" for a detailed
23
+ # reasoning why we really want exactly this definition.
24
+ # The convention is described for sums on page 309 and
25
+ # essentially in section 1.4, definition 3. For products
26
+ # we can find in analogy:
27
+ #
28
+ # \prod_{m <= i < n} f(i) 'has the obvious meaning' for m < n
29
+ # \prod_{m <= i < n} f(i) = 0 for m = n
30
+ # \prod_{m <= i < n} f(i) = 1 / \prod_{n <= i < m} f(i) for m > n
31
+ #
32
+ # It is important to note that he defines all products with
33
+ # the upper limit being *exclusive*.
34
+ # In contrast, SymPy and the usual mathematical notation has:
35
+ #
36
+ # prod_{i = a}^b f(i) = f(a) * f(a+1) * ... * f(b-1) * f(b)
37
+ #
38
+ # with the upper limit *inclusive*. So translating between
39
+ # the two we find that:
40
+ #
41
+ # \prod_{m <= i < n} f(i) = \prod_{i = m}^{n-1} f(i)
42
+ #
43
+ # where we intentionally used two different ways to typeset the
44
+ # products and its limits.
45
+
46
+ i = Symbol("i", integer=True)
47
+ k = Symbol("k", integer=True)
48
+ j = Symbol("j", integer=True, positive=True)
49
+
50
+ # A simple example with a concrete factors and symbolic limits.
51
+
52
+ # The normal product: m = k and n = k + j and therefore m < n:
53
+ m = k
54
+ n = k + j
55
+
56
+ a = m
57
+ b = n - 1
58
+ S1 = Product(i**2, (i, a, b)).doit()
59
+
60
+ # The reversed product: m = k + j and n = k and therefore m > n:
61
+ m = k + j
62
+ n = k
63
+
64
+ a = m
65
+ b = n - 1
66
+ S2 = Product(i**2, (i, a, b)).doit()
67
+
68
+ assert S1 * S2 == 1
69
+
70
+ # Test the empty product: m = k and n = k and therefore m = n:
71
+ m = k
72
+ n = k
73
+
74
+ a = m
75
+ b = n - 1
76
+ Sz = Product(i**2, (i, a, b)).doit()
77
+
78
+ assert Sz == 1
79
+
80
+ # Another example this time with an unspecified factor and
81
+ # numeric limits. (We can not do both tests in the same example.)
82
+ f = Function("f")
83
+
84
+ # The normal product with m < n:
85
+ m = 2
86
+ n = 11
87
+
88
+ a = m
89
+ b = n - 1
90
+ S1 = Product(f(i), (i, a, b)).doit()
91
+
92
+ # The reversed product with m > n:
93
+ m = 11
94
+ n = 2
95
+
96
+ a = m
97
+ b = n - 1
98
+ S2 = Product(f(i), (i, a, b)).doit()
99
+
100
+ assert simplify(S1 * S2) == 1
101
+
102
+ # Test the empty product with m = n:
103
+ m = 5
104
+ n = 5
105
+
106
+ a = m
107
+ b = n - 1
108
+ Sz = Product(f(i), (i, a, b)).doit()
109
+
110
+ assert Sz == 1
111
+
112
+
113
+ def test_karr_proposition_2a():
114
+ # Test Karr, page 309, proposition 2, part a
115
+ i, u, v = symbols('i u v', integer=True)
116
+
117
+ def test_the_product(m, n):
118
+ # g
119
+ g = i**3 + 2*i**2 - 3*i
120
+ # f = Delta g
121
+ f = simplify(g.subs(i, i+1) / g)
122
+ # The product
123
+ a = m
124
+ b = n - 1
125
+ P = Product(f, (i, a, b)).doit()
126
+ # Test if Product_{m <= i < n} f(i) = g(n) / g(m)
127
+ assert combsimp(P / (g.subs(i, n) / g.subs(i, m))) == 1
128
+
129
+ # m < n
130
+ test_the_product(u, u + v)
131
+ # m = n
132
+ test_the_product(u, u)
133
+ # m > n
134
+ test_the_product(u + v, u)
135
+
136
+
137
+ def test_karr_proposition_2b():
138
+ # Test Karr, page 309, proposition 2, part b
139
+ i, u, v, w = symbols('i u v w', integer=True)
140
+
141
+ def test_the_product(l, n, m):
142
+ # Productmand
143
+ s = i**3
144
+ # First product
145
+ a = l
146
+ b = n - 1
147
+ S1 = Product(s, (i, a, b)).doit()
148
+ # Second product
149
+ a = l
150
+ b = m - 1
151
+ S2 = Product(s, (i, a, b)).doit()
152
+ # Third product
153
+ a = m
154
+ b = n - 1
155
+ S3 = Product(s, (i, a, b)).doit()
156
+ # Test if S1 = S2 * S3 as required
157
+ assert combsimp(S1 / (S2 * S3)) == 1
158
+
159
+ # l < m < n
160
+ test_the_product(u, u + v, u + v + w)
161
+ # l < m = n
162
+ test_the_product(u, u + v, u + v)
163
+ # l < m > n
164
+ test_the_product(u, u + v + w, v)
165
+ # l = m < n
166
+ test_the_product(u, u, u + v)
167
+ # l = m = n
168
+ test_the_product(u, u, u)
169
+ # l = m > n
170
+ test_the_product(u + v, u + v, u)
171
+ # l > m < n
172
+ test_the_product(u + v, u, u + w)
173
+ # l > m = n
174
+ test_the_product(u + v, u, u)
175
+ # l > m > n
176
+ test_the_product(u + v + w, u + v, u)
177
+
178
+
179
+ def test_simple_products():
180
+ assert product(2, (k, a, n)) == 2**(n - a + 1)
181
+ assert product(k, (k, 1, n)) == factorial(n)
182
+ assert product(k**3, (k, 1, n)) == factorial(n)**3
183
+
184
+ assert product(k + 1, (k, 0, n - 1)) == factorial(n)
185
+ assert product(k + 1, (k, a, n - 1)) == rf(1 + a, n - a)
186
+
187
+ assert product(cos(k), (k, 0, 5)) == cos(1)*cos(2)*cos(3)*cos(4)*cos(5)
188
+ assert product(cos(k), (k, 3, 5)) == cos(3)*cos(4)*cos(5)
189
+ assert product(cos(k), (k, 1, Rational(5, 2))) != cos(1)*cos(2)
190
+
191
+ assert isinstance(product(k**k, (k, 1, n)), Product)
192
+
193
+ assert Product(x**k, (k, 1, n)).variables == [k]
194
+
195
+ raises(ValueError, lambda: Product(n))
196
+ raises(ValueError, lambda: Product(n, k))
197
+ raises(ValueError, lambda: Product(n, k, 1))
198
+ raises(ValueError, lambda: Product(n, k, 1, 10))
199
+ raises(ValueError, lambda: Product(n, (k, 1)))
200
+
201
+ assert product(1, (n, 1, oo)) == 1 # issue 8301
202
+ assert product(2, (n, 1, oo)) is oo
203
+ assert product(-1, (n, 1, oo)).func is Product
204
+
205
+
206
+ def test_multiple_products():
207
+ assert product(x, (n, 1, k), (k, 1, m)) == x**(m**2/2 + m/2)
208
+ assert product(f(n), (
209
+ n, 1, m), (m, 1, k)) == Product(f(n), (n, 1, m), (m, 1, k)).doit()
210
+ assert Product(f(n), (m, 1, k), (n, 1, k)).doit() == \
211
+ Product(Product(f(n), (m, 1, k)), (n, 1, k)).doit() == \
212
+ product(f(n), (m, 1, k), (n, 1, k)) == \
213
+ product(product(f(n), (m, 1, k)), (n, 1, k)) == \
214
+ Product(f(n)**k, (n, 1, k))
215
+ assert Product(
216
+ x, (x, 1, k), (k, 1, n)).doit() == Product(factorial(k), (k, 1, n))
217
+
218
+ assert Product(x**k, (n, 1, k), (k, 1, m)).variables == [n, k]
219
+
220
+
221
+ def test_rational_products():
222
+ assert product(1 + 1/k, (k, 1, n)) == rf(2, n)/factorial(n)
223
+
224
+
225
+ def test_special_products():
226
+ # Wallis product
227
+ assert product((4*k)**2 / (4*k**2 - 1), (k, 1, n)) == \
228
+ 4**n*factorial(n)**2/rf(S.Half, n)/rf(Rational(3, 2), n)
229
+
230
+ # Euler's product formula for sin
231
+ assert product(1 + a/k**2, (k, 1, n)) == \
232
+ rf(1 - sqrt(-a), n)*rf(1 + sqrt(-a), n)/factorial(n)**2
233
+
234
+
235
+ def test__eval_product():
236
+ from sympy.abc import i, n
237
+ # issue 4809
238
+ a = Function('a')
239
+ assert product(2*a(i), (i, 1, n)) == 2**n * Product(a(i), (i, 1, n))
240
+ # issue 4810
241
+ assert product(2**i, (i, 1, n)) == 2**(n*(n + 1)/2)
242
+ k, m = symbols('k m', integer=True)
243
+ assert product(2**i, (i, k, m)) == 2**(-k**2/2 + k/2 + m**2/2 + m/2)
244
+ n = Symbol('n', negative=True, integer=True)
245
+ p = Symbol('p', positive=True, integer=True)
246
+ assert product(2**i, (i, n, p)) == 2**(-n**2/2 + n/2 + p**2/2 + p/2)
247
+ assert product(2**i, (i, p, n)) == 2**(n**2/2 + n/2 - p**2/2 + p/2)
248
+
249
+
250
+ def test_product_pow():
251
+ # issue 4817
252
+ assert product(2**f(k), (k, 1, n)) == 2**Sum(f(k), (k, 1, n))
253
+ assert product(2**(2*f(k)), (k, 1, n)) == 2**Sum(2*f(k), (k, 1, n))
254
+
255
+
256
+ def test_infinite_product():
257
+ # issue 5737
258
+ assert isinstance(Product(2**(1/factorial(n)), (n, 0, oo)), Product)
259
+
260
+
261
+ def test_conjugate_transpose():
262
+ p = Product(x**k, (k, 1, 3))
263
+ assert p.adjoint().doit() == p.doit().adjoint()
264
+ assert p.conjugate().doit() == p.doit().conjugate()
265
+ assert p.transpose().doit() == p.doit().transpose()
266
+
267
+ A, B = symbols("A B", commutative=False)
268
+ p = Product(A*B**k, (k, 1, 3))
269
+ assert p.adjoint().doit() == p.doit().adjoint()
270
+ assert p.conjugate().doit() == p.doit().conjugate()
271
+ assert p.transpose().doit() == p.doit().transpose()
272
+
273
+ p = Product(B**k*A, (k, 1, 3))
274
+ assert p.adjoint().doit() == p.doit().adjoint()
275
+ assert p.conjugate().doit() == p.doit().conjugate()
276
+ assert p.transpose().doit() == p.doit().transpose()
277
+
278
+
279
+ def test_simplify_prod():
280
+ y, t, b, c, v, d = symbols('y, t, b, c, v, d', integer = True)
281
+
282
+ _simplify = lambda e: simplify(e, doit=False)
283
+ assert _simplify(Product(x*y, (x, n, m), (y, a, k)) * \
284
+ Product(y, (x, n, m), (y, a, k))) == \
285
+ Product(x*y**2, (x, n, m), (y, a, k))
286
+ assert _simplify(3 * y* Product(x, (x, n, m)) * Product(x, (x, m + 1, a))) \
287
+ == 3 * y * Product(x, (x, n, a))
288
+ assert _simplify(Product(x, (x, k + 1, a)) * Product(x, (x, n, k))) == \
289
+ Product(x, (x, n, a))
290
+ assert _simplify(Product(x, (x, k + 1, a)) * Product(x + 1, (x, n, k))) == \
291
+ Product(x, (x, k + 1, a)) * Product(x + 1, (x, n, k))
292
+ assert _simplify(Product(x, (t, a, b)) * Product(y, (t, a, b)) * \
293
+ Product(x, (t, b+1, c))) == Product(x*y, (t, a, b)) * \
294
+ Product(x, (t, b+1, c))
295
+ assert _simplify(Product(x, (t, a, b)) * Product(x, (t, b+1, c)) * \
296
+ Product(y, (t, a, b))) == Product(x*y, (t, a, b)) * \
297
+ Product(x, (t, b+1, c))
298
+ assert _simplify(Product(sin(t)**2 + cos(t)**2 + 1, (t, a, b))) == \
299
+ Product(2, (t, a, b))
300
+ assert _simplify(Product(sin(t)**2 + cos(t)**2 - 1, (t, a, b))) == \
301
+ Product(0, (t, a, b))
302
+ assert _simplify(Product(v*Product(sin(t)**2 + cos(t)**2, (t, a, b)),
303
+ (v, c, d))) == Product(v*Product(1, (t, a, b)), (v, c, d))
304
+
305
+
306
+ def test_change_index():
307
+ b, y, c, d, z = symbols('b, y, c, d, z', integer = True)
308
+
309
+ assert Product(x, (x, a, b)).change_index(x, x + 1, y) == \
310
+ Product(y - 1, (y, a + 1, b + 1))
311
+ assert Product(x**2, (x, a, b)).change_index(x, x - 1) == \
312
+ Product((x + 1)**2, (x, a - 1, b - 1))
313
+ assert Product(x**2, (x, a, b)).change_index(x, -x, y) == \
314
+ Product((-y)**2, (y, -b, -a))
315
+ assert Product(x, (x, a, b)).change_index(x, -x - 1) == \
316
+ Product(-x - 1, (x, - b - 1, -a - 1))
317
+ assert Product(x*y, (x, a, b), (y, c, d)).change_index(x, x - 1, z) == \
318
+ Product((z + 1)*y, (z, a - 1, b - 1), (y, c, d))
319
+
320
+
321
+ def test_reorder():
322
+ b, y, c, d, z = symbols('b, y, c, d, z', integer = True)
323
+
324
+ assert Product(x*y, (x, a, b), (y, c, d)).reorder((0, 1)) == \
325
+ Product(x*y, (y, c, d), (x, a, b))
326
+ assert Product(x, (x, a, b), (x, c, d)).reorder((0, 1)) == \
327
+ Product(x, (x, c, d), (x, a, b))
328
+ assert Product(x*y + z, (x, a, b), (z, m, n), (y, c, d)).reorder(\
329
+ (2, 0), (0, 1)) == Product(x*y + z, (z, m, n), (y, c, d), (x, a, b))
330
+ assert Product(x*y*z, (x, a, b), (y, c, d), (z, m, n)).reorder(\
331
+ (0, 1), (1, 2), (0, 2)) == \
332
+ Product(x*y*z, (x, a, b), (z, m, n), (y, c, d))
333
+ assert Product(x*y*z, (x, a, b), (y, c, d), (z, m, n)).reorder(\
334
+ (x, y), (y, z), (x, z)) == \
335
+ Product(x*y*z, (x, a, b), (z, m, n), (y, c, d))
336
+ assert Product(x*y, (x, a, b), (y, c, d)).reorder((x, 1)) == \
337
+ Product(x*y, (y, c, d), (x, a, b))
338
+ assert Product(x*y, (x, a, b), (y, c, d)).reorder((y, x)) == \
339
+ Product(x*y, (y, c, d), (x, a, b))
340
+
341
+
342
+ def test_Product_is_convergent():
343
+ assert Product(1/n**2, (n, 1, oo)).is_convergent() is S.false
344
+ assert Product(exp(1/n**2), (n, 1, oo)).is_convergent() is S.true
345
+ assert Product(1/n, (n, 1, oo)).is_convergent() is S.false
346
+ assert Product(1 + 1/n, (n, 1, oo)).is_convergent() is S.false
347
+ assert Product(1 + 1/n**2, (n, 1, oo)).is_convergent() is S.true
348
+
349
+
350
+ def test_reverse_order():
351
+ x, y, a, b, c, d= symbols('x, y, a, b, c, d', integer = True)
352
+
353
+ assert Product(x, (x, 0, 3)).reverse_order(0) == Product(1/x, (x, 4, -1))
354
+ assert Product(x*y, (x, 1, 5), (y, 0, 6)).reverse_order(0, 1) == \
355
+ Product(x*y, (x, 6, 0), (y, 7, -1))
356
+ assert Product(x, (x, 1, 2)).reverse_order(0) == Product(1/x, (x, 3, 0))
357
+ assert Product(x, (x, 1, 3)).reverse_order(0) == Product(1/x, (x, 4, 0))
358
+ assert Product(x, (x, 1, a)).reverse_order(0) == Product(1/x, (x, a + 1, 0))
359
+ assert Product(x, (x, a, 5)).reverse_order(0) == Product(1/x, (x, 6, a - 1))
360
+ assert Product(x, (x, a + 1, a + 5)).reverse_order(0) == \
361
+ Product(1/x, (x, a + 6, a))
362
+ assert Product(x, (x, a + 1, a + 2)).reverse_order(0) == \
363
+ Product(1/x, (x, a + 3, a))
364
+ assert Product(x, (x, a + 1, a + 1)).reverse_order(0) == \
365
+ Product(1/x, (x, a + 2, a))
366
+ assert Product(x, (x, a, b)).reverse_order(0) == Product(1/x, (x, b + 1, a - 1))
367
+ assert Product(x, (x, a, b)).reverse_order(x) == Product(1/x, (x, b + 1, a - 1))
368
+ assert Product(x*y, (x, a, b), (y, 2, 5)).reverse_order(x, 1) == \
369
+ Product(x*y, (x, b + 1, a - 1), (y, 6, 1))
370
+ assert Product(x*y, (x, a, b), (y, 2, 5)).reverse_order(y, x) == \
371
+ Product(x*y, (x, b + 1, a - 1), (y, 6, 1))
372
+
373
+
374
+ def test_issue_9983():
375
+ n = Symbol('n', integer=True, positive=True)
376
+ p = Product(1 + 1/n**Rational(2, 3), (n, 1, oo))
377
+ assert p.is_convergent() is S.false
378
+ assert product(1 + 1/n**Rational(2, 3), (n, 1, oo)) == p.doit()
379
+
380
+
381
+ def test_issue_13546():
382
+ n = Symbol('n')
383
+ k = Symbol('k')
384
+ p = Product(n + 1 / 2**k, (k, 0, n-1)).doit()
385
+ assert p.subs(n, 2).doit() == Rational(15, 2)
386
+
387
+
388
+ def test_issue_14036():
389
+ a, n = symbols('a n')
390
+ assert product(1 - a**2 / (n*pi)**2, [n, 1, oo]) != 0
391
+
392
+
393
+ def test_rewrite_Sum():
394
+ assert Product(1 - S.Half**2/k**2, (k, 1, oo)).rewrite(Sum) == \
395
+ exp(Sum(log(1 - 1/(4*k**2)), (k, 1, oo)))
396
+
397
+
398
+ def test_KroneckerDelta_Product():
399
+ y = Symbol('y')
400
+ assert Product(x*KroneckerDelta(x, y), (x, 0, 1)).doit() == 0
401
+
402
+
403
+ def test_issue_20848():
404
+ _i = Dummy('i')
405
+ t, y, z = symbols('t y z')
406
+ assert diff(Product(x, (y, 1, z)), x).as_dummy() == Sum(Product(x, (y, 1, _i - 1))*Product(x, (y, _i + 1, z)), (_i, 1, z)).as_dummy()
407
+ assert diff(Product(x, (y, 1, z)), x).doit() == x**(z - 1)*z
408
+ assert diff(Product(x, (y, x, z)), x) == Derivative(Product(x, (y, x, z)), x)
409
+ assert diff(Product(t, (x, 1, z)), x) == S(0)
410
+ assert Product(sin(n*x), (n, -1, 1)).diff(x).doit() == S(0)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/concrete/tests/test_sums_products.py ADDED
@@ -0,0 +1,1646 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import prod
2
+
3
+ from sympy.concrete.expr_with_intlimits import ReorderError
4
+ from sympy.concrete.products import (Product, product)
5
+ from sympy.concrete.summations import (Sum, summation, telescopic,
6
+ eval_sum_residue, _dummy_with_inherited_properties_concrete)
7
+ from sympy.core.function import (Derivative, Function)
8
+ from sympy.core import (Catalan, EulerGamma)
9
+ from sympy.core.facts import InconsistentAssumptions
10
+ from sympy.core.mod import Mod
11
+ from sympy.core.numbers import (E, I, Rational, nan, oo, pi)
12
+ from sympy.core.relational import Eq
13
+ from sympy.core.numbers import Float
14
+ from sympy.core.singleton import S
15
+ from sympy.core.symbol import (Dummy, Symbol, symbols)
16
+ from sympy.core.sympify import sympify
17
+ from sympy.functions.combinatorial.factorials import (rf, binomial, factorial)
18
+ from sympy.functions.combinatorial.numbers import harmonic
19
+ from sympy.functions.elementary.complexes import Abs
20
+ from sympy.functions.elementary.exponential import (exp, log)
21
+ from sympy.functions.elementary.hyperbolic import (sinh, tanh)
22
+ from sympy.functions.elementary.integers import floor
23
+ from sympy.functions.elementary.miscellaneous import sqrt
24
+ from sympy.functions.elementary.piecewise import Piecewise
25
+ from sympy.functions.elementary.trigonometric import (cos, sin)
26
+ from sympy.functions.special.gamma_functions import (gamma, lowergamma)
27
+ from sympy.functions.special.tensor_functions import KroneckerDelta
28
+ from sympy.functions.special.zeta_functions import zeta
29
+ from sympy.integrals.integrals import Integral
30
+ from sympy.logic.boolalg import And, Or
31
+ from sympy.matrices.expressions.matexpr import MatrixSymbol
32
+ from sympy.matrices.expressions.special import Identity
33
+ from sympy.matrices import (Matrix, SparseMatrix,
34
+ ImmutableDenseMatrix, ImmutableSparseMatrix, diag)
35
+ from sympy.sets.fancysets import Range
36
+ from sympy.sets.sets import Interval
37
+ from sympy.simplify.combsimp import combsimp
38
+ from sympy.simplify.simplify import simplify
39
+ from sympy.tensor.indexed import (Idx, Indexed, IndexedBase)
40
+ from sympy.testing.pytest import XFAIL, raises, slow
41
+ from sympy.abc import a, b, c, d, k, m, x, y, z
42
+
43
+ n = Symbol('n', integer=True)
44
+ f, g = symbols('f g', cls=Function)
45
+
46
+ def test_karr_convention():
47
+ # Test the Karr summation convention that we want to hold.
48
+ # See his paper "Summation in Finite Terms" for a detailed
49
+ # reasoning why we really want exactly this definition.
50
+ # The convention is described on page 309 and essentially
51
+ # in section 1.4, definition 3:
52
+ #
53
+ # \sum_{m <= i < n} f(i) 'has the obvious meaning' for m < n
54
+ # \sum_{m <= i < n} f(i) = 0 for m = n
55
+ # \sum_{m <= i < n} f(i) = - \sum_{n <= i < m} f(i) for m > n
56
+ #
57
+ # It is important to note that he defines all sums with
58
+ # the upper limit being *exclusive*.
59
+ # In contrast, SymPy and the usual mathematical notation has:
60
+ #
61
+ # sum_{i = a}^b f(i) = f(a) + f(a+1) + ... + f(b-1) + f(b)
62
+ #
63
+ # with the upper limit *inclusive*. So translating between
64
+ # the two we find that:
65
+ #
66
+ # \sum_{m <= i < n} f(i) = \sum_{i = m}^{n-1} f(i)
67
+ #
68
+ # where we intentionally used two different ways to typeset the
69
+ # sum and its limits.
70
+
71
+ i = Symbol("i", integer=True)
72
+ k = Symbol("k", integer=True)
73
+ j = Symbol("j", integer=True)
74
+
75
+ # A simple example with a concrete summand and symbolic limits.
76
+
77
+ # The normal sum: m = k and n = k + j and therefore m < n:
78
+ m = k
79
+ n = k + j
80
+
81
+ a = m
82
+ b = n - 1
83
+ S1 = Sum(i**2, (i, a, b)).doit()
84
+
85
+ # The reversed sum: m = k + j and n = k and therefore m > n:
86
+ m = k + j
87
+ n = k
88
+
89
+ a = m
90
+ b = n - 1
91
+ S2 = Sum(i**2, (i, a, b)).doit()
92
+
93
+ assert simplify(S1 + S2) == 0
94
+
95
+ # Test the empty sum: m = k and n = k and therefore m = n:
96
+ m = k
97
+ n = k
98
+
99
+ a = m
100
+ b = n - 1
101
+ Sz = Sum(i**2, (i, a, b)).doit()
102
+
103
+ assert Sz == 0
104
+
105
+ # Another example this time with an unspecified summand and
106
+ # numeric limits. (We can not do both tests in the same example.)
107
+
108
+ # The normal sum with m < n:
109
+ m = 2
110
+ n = 11
111
+
112
+ a = m
113
+ b = n - 1
114
+ S1 = Sum(f(i), (i, a, b)).doit()
115
+
116
+ # The reversed sum with m > n:
117
+ m = 11
118
+ n = 2
119
+
120
+ a = m
121
+ b = n - 1
122
+ S2 = Sum(f(i), (i, a, b)).doit()
123
+
124
+ assert simplify(S1 + S2) == 0
125
+
126
+ # Test the empty sum with m = n:
127
+ m = 5
128
+ n = 5
129
+
130
+ a = m
131
+ b = n - 1
132
+ Sz = Sum(f(i), (i, a, b)).doit()
133
+
134
+ assert Sz == 0
135
+
136
+ e = Piecewise((exp(-i), Mod(i, 2) > 0), (0, True))
137
+ s = Sum(e, (i, 0, 11))
138
+ assert s.n(3) == s.doit().n(3)
139
+
140
+
141
+ def test_karr_proposition_2a():
142
+ # Test Karr, page 309, proposition 2, part a
143
+ i = Symbol("i", integer=True)
144
+ u = Symbol("u", integer=True)
145
+ v = Symbol("v", integer=True)
146
+
147
+ def test_the_sum(m, n):
148
+ # g
149
+ g = i**3 + 2*i**2 - 3*i
150
+ # f = Delta g
151
+ f = simplify(g.subs(i, i+1) - g)
152
+ # The sum
153
+ a = m
154
+ b = n - 1
155
+ S = Sum(f, (i, a, b)).doit()
156
+ # Test if Sum_{m <= i < n} f(i) = g(n) - g(m)
157
+ assert simplify(S - (g.subs(i, n) - g.subs(i, m))) == 0
158
+
159
+ # m < n
160
+ test_the_sum(u, u+v)
161
+ # m = n
162
+ test_the_sum(u, u )
163
+ # m > n
164
+ test_the_sum(u+v, u )
165
+
166
+
167
+ def test_karr_proposition_2b():
168
+ # Test Karr, page 309, proposition 2, part b
169
+ i = Symbol("i", integer=True)
170
+ u = Symbol("u", integer=True)
171
+ v = Symbol("v", integer=True)
172
+ w = Symbol("w", integer=True)
173
+
174
+ def test_the_sum(l, n, m):
175
+ # Summand
176
+ s = i**3
177
+ # First sum
178
+ a = l
179
+ b = n - 1
180
+ S1 = Sum(s, (i, a, b)).doit()
181
+ # Second sum
182
+ a = l
183
+ b = m - 1
184
+ S2 = Sum(s, (i, a, b)).doit()
185
+ # Third sum
186
+ a = m
187
+ b = n - 1
188
+ S3 = Sum(s, (i, a, b)).doit()
189
+ # Test if S1 = S2 + S3 as required
190
+ assert S1 - (S2 + S3) == 0
191
+
192
+ # l < m < n
193
+ test_the_sum(u, u+v, u+v+w)
194
+ # l < m = n
195
+ test_the_sum(u, u+v, u+v )
196
+ # l < m > n
197
+ test_the_sum(u, u+v+w, v )
198
+ # l = m < n
199
+ test_the_sum(u, u, u+v )
200
+ # l = m = n
201
+ test_the_sum(u, u, u )
202
+ # l = m > n
203
+ test_the_sum(u+v, u+v, u )
204
+ # l > m < n
205
+ test_the_sum(u+v, u, u+w )
206
+ # l > m = n
207
+ test_the_sum(u+v, u, u )
208
+ # l > m > n
209
+ test_the_sum(u+v+w, u+v, u )
210
+
211
+
212
+ def test_arithmetic_sums():
213
+ assert summation(1, (n, a, b)) == b - a + 1
214
+ assert Sum(S.NaN, (n, a, b)) is S.NaN
215
+ assert Sum(x, (n, a, a)).doit() == x
216
+ assert Sum(x, (x, a, a)).doit() == a
217
+ assert Sum(x, (n, 1, a)).doit() == a*x
218
+ assert Sum(x, (x, Range(1, 11))).doit() == 55
219
+ assert Sum(x, (x, Range(1, 11, 2))).doit() == 25
220
+ assert Sum(x, (x, Range(1, 10, 2))) == Sum(x, (x, Range(9, 0, -2)))
221
+ lo, hi = 1, 2
222
+ s1 = Sum(n, (n, lo, hi))
223
+ s2 = Sum(n, (n, hi, lo))
224
+ assert s1 != s2
225
+ assert s1.doit() == 3 and s2.doit() == 0
226
+ lo, hi = x, x + 1
227
+ s1 = Sum(n, (n, lo, hi))
228
+ s2 = Sum(n, (n, hi, lo))
229
+ assert s1 != s2
230
+ assert s1.doit() == 2*x + 1 and s2.doit() == 0
231
+ assert Sum(Integral(x, (x, 1, y)) + x, (x, 1, 2)).doit() == \
232
+ y**2 + 2
233
+ assert summation(1, (n, 1, 10)) == 10
234
+ assert summation(2*n, (n, 0, 10**10)) == 100000000010000000000
235
+ assert summation(4*n*m, (n, a, 1), (m, 1, d)).expand() == \
236
+ 2*d + 2*d**2 + a*d + a*d**2 - d*a**2 - a**2*d**2
237
+ assert summation(cos(n), (n, -2, 1)) == cos(-2) + cos(-1) + cos(0) + cos(1)
238
+ assert summation(cos(n), (n, x, x + 2)) == cos(x) + cos(x + 1) + cos(x + 2)
239
+ assert isinstance(summation(cos(n), (n, x, x + S.Half)), Sum)
240
+ assert summation(k, (k, 0, oo)) is oo
241
+ assert summation(k, (k, Range(1, 11))) == 55
242
+
243
+
244
+ def test_polynomial_sums():
245
+ assert summation(n**2, (n, 3, 8)) == 199
246
+ assert summation(n, (n, a, b)) == \
247
+ ((a + b)*(b - a + 1)/2).expand()
248
+ assert summation(n**2, (n, 1, b)) == \
249
+ ((2*b**3 + 3*b**2 + b)/6).expand()
250
+ assert summation(n**3, (n, 1, b)) == \
251
+ ((b**4 + 2*b**3 + b**2)/4).expand()
252
+ assert summation(n**6, (n, 1, b)) == \
253
+ ((6*b**7 + 21*b**6 + 21*b**5 - 7*b**3 + b)/42).expand()
254
+
255
+
256
+ def test_geometric_sums():
257
+ assert summation(pi**n, (n, 0, b)) == (1 - pi**(b + 1)) / (1 - pi)
258
+ assert summation(2 * 3**n, (n, 0, b)) == 3**(b + 1) - 1
259
+ assert summation(S.Half**n, (n, 1, oo)) == 1
260
+ assert summation(2**n, (n, 0, b)) == 2**(b + 1) - 1
261
+ assert summation(2**n, (n, 1, oo)) is oo
262
+ assert summation(2**(-n), (n, 1, oo)) == 1
263
+ assert summation(3**(-n), (n, 4, oo)) == Rational(1, 54)
264
+ assert summation(2**(-4*n + 3), (n, 1, oo)) == Rational(8, 15)
265
+ assert summation(2**(n + 1), (n, 1, b)).expand() == 4*(2**b - 1)
266
+
267
+ # issue 6664:
268
+ assert summation(x**n, (n, 0, oo)) == \
269
+ Piecewise((1/(-x + 1), Abs(x) < 1), (Sum(x**n, (n, 0, oo)), True))
270
+
271
+ assert summation(-2**n, (n, 0, oo)) is -oo
272
+ assert summation(I**n, (n, 0, oo)) == Sum(I**n, (n, 0, oo))
273
+
274
+ # issue 6802:
275
+ assert summation((-1)**(2*x + 2), (x, 0, n)) == n + 1
276
+ assert summation((-2)**(2*x + 2), (x, 0, n)) == 4*4**(n + 1)/S(3) - Rational(4, 3)
277
+ assert summation((-1)**x, (x, 0, n)) == -(-1)**(n + 1)/S(2) + S.Half
278
+ assert summation(y**x, (x, a, b)) == \
279
+ Piecewise((-a + b + 1, Eq(y, 1)), ((y**a - y**(b + 1))/(-y + 1), True))
280
+ assert summation((-2)**(y*x + 2), (x, 0, n)) == \
281
+ 4*Piecewise((n + 1, Eq((-2)**y, 1)),
282
+ ((-(-2)**(y*(n + 1)) + 1)/(-(-2)**y + 1), True))
283
+
284
+ # issue 8251:
285
+ assert summation((1/(n + 1)**2)*n**2, (n, 0, oo)) is oo
286
+
287
+ #issue 9908:
288
+ assert Sum(1/(n**3 - 1), (n, -oo, -2)).doit() == summation(1/(n**3 - 1), (n, -oo, -2))
289
+
290
+ #issue 11642:
291
+ result = Sum(0.5**n, (n, 1, oo)).doit()
292
+ assert result == 1.0
293
+ assert result.is_Float
294
+
295
+ result = Sum(0.25**n, (n, 1, oo)).doit()
296
+ assert result == 1/3.
297
+ assert result.is_Float
298
+
299
+ result = Sum(0.99999**n, (n, 1, oo)).doit()
300
+ assert result == 99999.0
301
+ assert result.is_Float
302
+
303
+ result = Sum(S.Half**n, (n, 1, oo)).doit()
304
+ assert result == 1
305
+ assert not result.is_Float
306
+
307
+ result = Sum(Rational(3, 5)**n, (n, 1, oo)).doit()
308
+ assert result == Rational(3, 2)
309
+ assert not result.is_Float
310
+
311
+ assert Sum(1.0**n, (n, 1, oo)).doit() is oo
312
+ assert Sum(2.43**n, (n, 1, oo)).doit() is oo
313
+
314
+ # Issue 13979
315
+ i, k, q = symbols('i k q', integer=True)
316
+ result = summation(
317
+ exp(-2*I*pi*k*i/n) * exp(2*I*pi*q*i/n) / n, (i, 0, n - 1)
318
+ )
319
+ assert result.simplify() == Piecewise(
320
+ (1, Eq(exp(-2*I*pi*(k - q)/n), 1)), (0, True)
321
+ )
322
+
323
+ #Issue 23491
324
+ assert Sum(1/(n**2 + 1), (n, 1, oo)).doit() == S(-1)/2 + pi/(2*tanh(pi))
325
+
326
+ def test_harmonic_sums():
327
+ assert summation(1/k, (k, 0, n)) == Sum(1/k, (k, 0, n))
328
+ assert summation(1/k, (k, 1, n)) == harmonic(n)
329
+ assert summation(n/k, (k, 1, n)) == n*harmonic(n)
330
+ assert summation(1/k, (k, 5, n)) == harmonic(n) - harmonic(4)
331
+
332
+
333
+ def test_composite_sums():
334
+ f = S.Half*(7 - 6*n + Rational(1, 7)*n**3)
335
+ s = summation(f, (n, a, b))
336
+ assert not isinstance(s, Sum)
337
+ A = 0
338
+ for i in range(-3, 5):
339
+ A += f.subs(n, i)
340
+ B = s.subs(a, -3).subs(b, 4)
341
+ assert A == B
342
+
343
+
344
+ def test_hypergeometric_sums():
345
+ assert summation(
346
+ binomial(2*k, k)/4**k, (k, 0, n)) == (1 + 2*n)*binomial(2*n, n)/4**n
347
+ assert summation(binomial(2*k, k)/5**k, (k, -oo, oo)) == sqrt(5)
348
+
349
+
350
+ def test_other_sums():
351
+ f = m**2 + m*exp(m)
352
+ g = 3*exp(Rational(3, 2))/2 + exp(S.Half)/2 - exp(Rational(-1, 2))/2 - 3*exp(Rational(-3, 2))/2 + 5
353
+
354
+ assert summation(f, (m, Rational(-3, 2), Rational(3, 2))) == g
355
+ assert summation(f, (m, -1.5, 1.5)).evalf().epsilon_eq(g.evalf(), 1e-10)
356
+
357
+ fac = factorial
358
+
359
+
360
+ def NS(e, n=15, **options):
361
+ return str(sympify(e).evalf(n, **options))
362
+
363
+
364
+ def test_evalf_fast_series():
365
+ # Euler transformed series for sqrt(1+x)
366
+ assert NS(Sum(
367
+ fac(2*n + 1)/fac(n)**2/2**(3*n + 1), (n, 0, oo)), 100) == NS(sqrt(2), 100)
368
+
369
+ # Some series for exp(1)
370
+ estr = NS(E, 100)
371
+ assert NS(Sum(1/fac(n), (n, 0, oo)), 100) == estr
372
+ assert NS(1/Sum((1 - 2*n)/fac(2*n), (n, 0, oo)), 100) == estr
373
+ assert NS(Sum((2*n + 1)/fac(2*n), (n, 0, oo)), 100) == estr
374
+ assert NS(Sum((4*n + 3)/2**(2*n + 1)/fac(2*n + 1), (n, 0, oo))**2, 100) == estr
375
+
376
+ pistr = NS(pi, 100)
377
+ # Ramanujan series for pi
378
+ assert NS(9801/sqrt(8)/Sum(fac(
379
+ 4*n)*(1103 + 26390*n)/fac(n)**4/396**(4*n), (n, 0, oo)), 100) == pistr
380
+ assert NS(1/Sum(
381
+ binomial(2*n, n)**3 * (42*n + 5)/2**(12*n + 4), (n, 0, oo)), 100) == pistr
382
+ # Machin's formula for pi
383
+ assert NS(16*Sum((-1)**n/(2*n + 1)/5**(2*n + 1), (n, 0, oo)) -
384
+ 4*Sum((-1)**n/(2*n + 1)/239**(2*n + 1), (n, 0, oo)), 100) == pistr
385
+
386
+ # Apery's constant
387
+ astr = NS(zeta(3), 100)
388
+ P = 126392*n**5 + 412708*n**4 + 531578*n**3 + 336367*n**2 + 104000* \
389
+ n + 12463
390
+ assert NS(Sum((-1)**n * P / 24 * (fac(2*n + 1)*fac(2*n)*fac(
391
+ n))**3 / fac(3*n + 2) / fac(4*n + 3)**3, (n, 0, oo)), 100) == astr
392
+ assert NS(Sum((-1)**n * (205*n**2 + 250*n + 77)/64 * fac(n)**10 /
393
+ fac(2*n + 1)**5, (n, 0, oo)), 100) == astr
394
+
395
+
396
+ def test_evalf_fast_series_issue_4021():
397
+ # Catalan's constant
398
+ assert NS(Sum((-1)**(n - 1)*2**(8*n)*(40*n**2 - 24*n + 3)*fac(2*n)**3*
399
+ fac(n)**2/n**3/(2*n - 1)/fac(4*n)**2, (n, 1, oo))/64, 100) == \
400
+ NS(Catalan, 100)
401
+ astr = NS(zeta(3), 100)
402
+ assert NS(5*Sum(
403
+ (-1)**(n - 1)*fac(n)**2 / n**3 / fac(2*n), (n, 1, oo))/2, 100) == astr
404
+ assert NS(Sum((-1)**(n - 1)*(56*n**2 - 32*n + 5) / (2*n - 1)**2 * fac(n - 1)
405
+ **3 / fac(3*n), (n, 1, oo))/4, 100) == astr
406
+
407
+
408
+ def test_evalf_slow_series():
409
+ assert NS(Sum((-1)**n / n, (n, 1, oo)), 15) == NS(-log(2), 15)
410
+ assert NS(Sum((-1)**n / n, (n, 1, oo)), 50) == NS(-log(2), 50)
411
+ assert NS(Sum(1/n**2, (n, 1, oo)), 15) == NS(pi**2/6, 15)
412
+ assert NS(Sum(1/n**2, (n, 1, oo)), 100) == NS(pi**2/6, 100)
413
+ assert NS(Sum(1/n**2, (n, 1, oo)), 500) == NS(pi**2/6, 500)
414
+ assert NS(Sum((-1)**n / (2*n + 1)**3, (n, 0, oo)), 15) == NS(pi**3/32, 15)
415
+ assert NS(Sum((-1)**n / (2*n + 1)**3, (n, 0, oo)), 50) == NS(pi**3/32, 50)
416
+
417
+
418
+ def test_evalf_oo_to_oo():
419
+ # There used to be an error in certain cases
420
+ # Does not evaluate, but at least do not throw an error
421
+ # Evaluates symbolically to 0, which is not correct
422
+ assert Sum(1/(n**2+1), (n, -oo, oo)).evalf() == Sum(1/(n**2+1), (n, -oo, oo))
423
+ # This evaluates if from 1 to oo and symbolically
424
+ assert Sum(1/(factorial(abs(n))), (n, -oo, -1)).evalf() == Sum(1/(factorial(abs(n))), (n, -oo, -1))
425
+
426
+
427
+ def test_euler_maclaurin():
428
+ # Exact polynomial sums with E-M
429
+ def check_exact(f, a, b, m, n):
430
+ A = Sum(f, (k, a, b))
431
+ s, e = A.euler_maclaurin(m, n)
432
+ assert (e == 0) and (s.expand() == A.doit())
433
+ check_exact(k**4, a, b, 0, 2)
434
+ check_exact(k**4 + 2*k, a, b, 1, 2)
435
+ check_exact(k**4 + k**2, a, b, 1, 5)
436
+ check_exact(k**5, 2, 6, 1, 2)
437
+ check_exact(k**5, 2, 6, 1, 3)
438
+ assert Sum(x-1, (x, 0, 2)).euler_maclaurin(m=30, n=30, eps=2**-15) == (0, 0)
439
+ # Not exact
440
+ assert Sum(k**6, (k, a, b)).euler_maclaurin(0, 2)[1] != 0
441
+ # Numerical test
442
+ for mi, ni in [(2, 4), (2, 20), (10, 20), (18, 20)]:
443
+ A = Sum(1/k**3, (k, 1, oo))
444
+ s, e = A.euler_maclaurin(mi, ni)
445
+ assert abs((s - zeta(3)).evalf()) < e.evalf()
446
+
447
+ raises(ValueError, lambda: Sum(1, (x, 0, 1), (k, 0, 1)).euler_maclaurin())
448
+
449
+
450
+ @slow
451
+ def test_evalf_euler_maclaurin():
452
+ assert NS(Sum(1/k**k, (k, 1, oo)), 15) == '1.29128599706266'
453
+ assert NS(Sum(1/k**k, (k, 1, oo)),
454
+ 50) == '1.2912859970626635404072825905956005414986193682745'
455
+ assert NS(Sum(1/k - log(1 + 1/k), (k, 1, oo)), 15) == NS(EulerGamma, 15)
456
+ assert NS(Sum(1/k - log(1 + 1/k), (k, 1, oo)), 50) == NS(EulerGamma, 50)
457
+ assert NS(Sum(log(k)/k**2, (k, 1, oo)), 15) == '0.937548254315844'
458
+ assert NS(Sum(log(k)/k**2, (k, 1, oo)),
459
+ 50) == '0.93754825431584375370257409456786497789786028861483'
460
+ assert NS(Sum(1/k, (k, 1000000, 2000000)), 15) == '0.693147930560008'
461
+ assert NS(Sum(1/k, (k, 1000000, 2000000)),
462
+ 50) == '0.69314793056000780941723211364567656807940638436025'
463
+
464
+
465
+ def test_evalf_symbolic():
466
+ # issue 6328
467
+ expr = Sum(f(x), (x, 1, 3)) + Sum(g(x), (x, 1, 3))
468
+ assert expr.evalf() == expr
469
+
470
+
471
+ def test_evalf_issue_3273():
472
+ assert Sum(0, (k, 1, oo)).evalf() == 0
473
+
474
+
475
+ def test_simple_products():
476
+ assert Product(S.NaN, (x, 1, 3)) is S.NaN
477
+ assert product(S.NaN, (x, 1, 3)) is S.NaN
478
+ assert Product(x, (n, a, a)).doit() == x
479
+ assert Product(x, (x, a, a)).doit() == a
480
+ assert Product(x, (y, 1, a)).doit() == x**a
481
+
482
+ lo, hi = 1, 2
483
+ s1 = Product(n, (n, lo, hi))
484
+ s2 = Product(n, (n, hi, lo))
485
+ assert s1 != s2
486
+ # This IS correct according to Karr product convention
487
+ assert s1.doit() == 2
488
+ assert s2.doit() == 1
489
+
490
+ lo, hi = x, x + 1
491
+ s1 = Product(n, (n, lo, hi))
492
+ s2 = Product(n, (n, hi, lo))
493
+ s3 = 1 / Product(n, (n, hi + 1, lo - 1))
494
+ assert s1 != s2
495
+ # This IS correct according to Karr product convention
496
+ assert s1.doit() == x*(x + 1)
497
+ assert s2.doit() == 1
498
+ assert s3.doit() == x*(x + 1)
499
+
500
+ assert Product(Integral(2*x, (x, 1, y)) + 2*x, (x, 1, 2)).doit() == \
501
+ (y**2 + 1)*(y**2 + 3)
502
+ assert product(2, (n, a, b)) == 2**(b - a + 1)
503
+ assert product(n, (n, 1, b)) == factorial(b)
504
+ assert product(n**3, (n, 1, b)) == factorial(b)**3
505
+ assert product(3**(2 + n), (n, a, b)) \
506
+ == 3**(2*(1 - a + b) + b/2 + (b**2)/2 + a/2 - (a**2)/2)
507
+ assert product(cos(n), (n, 3, 5)) == cos(3)*cos(4)*cos(5)
508
+ assert product(cos(n), (n, x, x + 2)) == cos(x)*cos(x + 1)*cos(x + 2)
509
+ assert isinstance(product(cos(n), (n, x, x + S.Half)), Product)
510
+ # If Product managed to evaluate this one, it most likely got it wrong!
511
+ assert isinstance(Product(n**n, (n, 1, b)), Product)
512
+
513
+
514
+ def test_rational_products():
515
+ assert combsimp(product(1 + 1/n, (n, a, b))) == (1 + b)/a
516
+ assert combsimp(product(n + 1, (n, a, b))) == gamma(2 + b)/gamma(1 + a)
517
+ assert combsimp(product((n + 1)/(n - 1), (n, a, b))) == b*(1 + b)/(a*(a - 1))
518
+ assert combsimp(product(n/(n + 1)/(n + 2), (n, a, b))) == \
519
+ a*gamma(a + 2)/(b + 1)/gamma(b + 3)
520
+ assert combsimp(product(n*(n + 1)/(n - 1)/(n - 2), (n, a, b))) == \
521
+ b**2*(b - 1)*(1 + b)/(a - 1)**2/(a*(a - 2))
522
+
523
+
524
+ def test_wallis_product():
525
+ # Wallis product, given in two different forms to ensure that Product
526
+ # can factor simple rational expressions
527
+ A = Product(4*n**2 / (4*n**2 - 1), (n, 1, b))
528
+ B = Product((2*n)*(2*n)/(2*n - 1)/(2*n + 1), (n, 1, b))
529
+ R = pi*gamma(b + 1)**2/(2*gamma(b + S.Half)*gamma(b + Rational(3, 2)))
530
+ assert simplify(A.doit()) == R
531
+ assert simplify(B.doit()) == R
532
+ # This one should eventually also be doable (Euler's product formula for sin)
533
+ # assert Product(1+x/n**2, (n, 1, b)) == ...
534
+
535
+
536
+ def test_telescopic_sums():
537
+ #checks also input 2 of comment 1 issue 4127
538
+ assert Sum(1/k - 1/(k + 1), (k, 1, n)).doit() == 1 - 1/(1 + n)
539
+ assert Sum(
540
+ f(k) - f(k + 2), (k, m, n)).doit() == -f(1 + n) - f(2 + n) + f(m) + f(1 + m)
541
+ assert Sum(cos(k) - cos(k + 3), (k, 1, n)).doit() == -cos(1 + n) - \
542
+ cos(2 + n) - cos(3 + n) + cos(1) + cos(2) + cos(3)
543
+
544
+ # dummy variable shouldn't matter
545
+ assert telescopic(1/m, -m/(1 + m), (m, n - 1, n)) == \
546
+ telescopic(1/k, -k/(1 + k), (k, n - 1, n))
547
+
548
+ assert Sum(1/x/(x - 1), (x, a, b)).doit() == 1/(a - 1) - 1/b
549
+ eq = 1/((5*n + 2)*(5*(n + 1) + 2))
550
+ assert Sum(eq, (n, 0, oo)).doit() == S(1)/10
551
+ nz = symbols('nz', nonzero=True)
552
+ v = Sum(eq.subs(5, nz), (n, 0, oo)).doit()
553
+ assert v.subs(nz, 5).simplify() == S(1)/10
554
+ # check that apart is being used in non-symbolic case
555
+ s = Sum(eq, (n, 0, k)).doit()
556
+ v = Sum(eq, (n, 0, 10**100)).doit()
557
+ assert v == s.subs(k, 10**100)
558
+
559
+
560
+ def test_sum_reconstruct():
561
+ s = Sum(n**2, (n, -1, 1))
562
+ assert s == Sum(*s.args)
563
+ raises(ValueError, lambda: Sum(x, x))
564
+ raises(ValueError, lambda: Sum(x, (x, 1)))
565
+
566
+
567
+ def test_limit_subs():
568
+ for F in (Sum, Product, Integral):
569
+ assert F(a*exp(a), (a, -2, 2)) == F(a*exp(a), (a, -b, b)).subs(b, 2)
570
+ assert F(a, (a, F(b, (b, 1, 2)), 4)).subs(F(b, (b, 1, 2)), c) == \
571
+ F(a, (a, c, 4))
572
+ assert F(x, (x, 1, x + y)).subs(x, 1) == F(x, (x, 1, y + 1))
573
+
574
+
575
+ def test_function_subs():
576
+ S = Sum(x*f(y),(x,0,oo),(y,0,oo))
577
+ assert S.subs(f(y),y) == Sum(x*y,(x,0,oo),(y,0,oo))
578
+ assert S.subs(f(x),x) == S
579
+ raises(ValueError, lambda: S.subs(f(y),x+y) )
580
+ S = Sum(x*log(y),(x,0,oo),(y,0,oo))
581
+ assert S.subs(log(y),y) == S
582
+ S = Sum(x*f(y),(x,0,oo),(y,0,oo))
583
+ assert S.subs(f(y),y) == Sum(x*y,(x,0,oo),(y,0,oo))
584
+
585
+
586
+ def test_equality():
587
+ # if this fails remove special handling below
588
+ raises(ValueError, lambda: Sum(x, x))
589
+ r = symbols('x', real=True)
590
+ for F in (Sum, Product, Integral):
591
+ try:
592
+ assert F(x, x) != F(y, y)
593
+ assert F(x, (x, 1, 2)) != F(x, x)
594
+ assert F(x, (x, x)) != F(x, x) # or else they print the same
595
+ assert F(1, x) != F(1, y)
596
+ except ValueError:
597
+ pass
598
+ assert F(a, (x, 1, 2)) != F(a, (x, 1, 3)) # diff limit
599
+ assert F(a, (x, 1, x)) != F(a, (y, 1, y))
600
+ assert F(a, (x, 1, 2)) != F(b, (x, 1, 2)) # diff expression
601
+ assert F(x, (x, 1, 2)) != F(r, (r, 1, 2)) # diff assumptions
602
+ assert F(1, (x, 1, x)) != F(1, (y, 1, x)) # only dummy is diff
603
+ assert F(1, (x, 1, x)).dummy_eq(F(1, (y, 1, x)))
604
+
605
+ # issue 5265
606
+ assert Sum(x, (x, 1, x)).subs(x, a) == Sum(x, (x, 1, a))
607
+
608
+
609
+ def test_Sum_doit():
610
+ assert Sum(n*Integral(a**2), (n, 0, 2)).doit() == a**3
611
+ assert Sum(n*Integral(a**2), (n, 0, 2)).doit(deep=False) == \
612
+ 3*Integral(a**2)
613
+ assert summation(n*Integral(a**2), (n, 0, 2)) == 3*Integral(a**2)
614
+
615
+ # test nested sum evaluation
616
+ s = Sum( Sum( Sum(2,(z,1,n+1)), (y,x+1,n)), (x,1,n))
617
+ assert 0 == (s.doit() - n*(n+1)*(n-1)).factor()
618
+
619
+ # Integer assumes finite
620
+ assert Sum(KroneckerDelta(x, y), (x, -oo, oo)).doit() == Piecewise((1, And(-oo < y, y < oo)), (0, True))
621
+ assert Sum(KroneckerDelta(m, n), (m, -oo, oo)).doit() == 1
622
+ assert Sum(m*KroneckerDelta(x, y), (x, -oo, oo)).doit() == Piecewise((m, And(-oo < y, y < oo)), (0, True))
623
+ assert Sum(x*KroneckerDelta(m, n), (m, -oo, oo)).doit() == x
624
+ assert Sum(Sum(KroneckerDelta(m, n), (m, 1, 3)), (n, 1, 3)).doit() == 3
625
+ assert Sum(Sum(KroneckerDelta(k, m), (m, 1, 3)), (n, 1, 3)).doit() == \
626
+ 3 * Piecewise((1, And(1 <= k, k <= 3)), (0, True))
627
+ assert Sum(f(n) * Sum(KroneckerDelta(m, n), (m, 0, oo)), (n, 1, 3)).doit() == \
628
+ f(1) + f(2) + f(3)
629
+ assert Sum(f(n) * Sum(KroneckerDelta(m, n), (m, 0, oo)), (n, 1, oo)).doit() == \
630
+ Sum(f(n), (n, 1, oo))
631
+
632
+ # issue 2597
633
+ nmax = symbols('N', integer=True, positive=True)
634
+ pw = Piecewise((1, And(1 <= n, n <= nmax)), (0, True))
635
+ assert Sum(pw, (n, 1, nmax)).doit() == Sum(Piecewise((1, nmax >= n),
636
+ (0, True)), (n, 1, nmax))
637
+
638
+ q, s = symbols('q, s')
639
+ assert summation(1/n**(2*s), (n, 1, oo)) == Piecewise((zeta(2*s), 2*s > 1),
640
+ (Sum(n**(-2*s), (n, 1, oo)), True))
641
+ assert summation(1/(n+1)**s, (n, 0, oo)) == Piecewise((zeta(s), s > 1),
642
+ (Sum((n + 1)**(-s), (n, 0, oo)), True))
643
+ assert summation(1/(n+q)**s, (n, 0, oo)) == Piecewise(
644
+ (zeta(s, q), And(q > 0, s > 1)),
645
+ (Sum((n + q)**(-s), (n, 0, oo)), True))
646
+ assert summation(1/(n+q)**s, (n, q, oo)) == Piecewise(
647
+ (zeta(s, 2*q), And(2*q > 0, s > 1)),
648
+ (Sum((n + q)**(-s), (n, q, oo)), True))
649
+ assert summation(1/n**2, (n, 1, oo)) == zeta(2)
650
+ assert summation(1/n**s, (n, 0, oo)) == Sum(n**(-s), (n, 0, oo))
651
+
652
+
653
+ def test_Product_doit():
654
+ assert Product(n*Integral(a**2), (n, 1, 3)).doit() == 2 * a**9 / 9
655
+ assert Product(n*Integral(a**2), (n, 1, 3)).doit(deep=False) == \
656
+ 6*Integral(a**2)**3
657
+ assert product(n*Integral(a**2), (n, 1, 3)) == 6*Integral(a**2)**3
658
+
659
+
660
+ def test_Sum_interface():
661
+ assert isinstance(Sum(0, (n, 0, 2)), Sum)
662
+ assert Sum(nan, (n, 0, 2)) is nan
663
+ assert Sum(nan, (n, 0, oo)) is nan
664
+ assert Sum(0, (n, 0, 2)).doit() == 0
665
+ assert isinstance(Sum(0, (n, 0, oo)), Sum)
666
+ assert Sum(0, (n, 0, oo)).doit() == 0
667
+ raises(ValueError, lambda: Sum(1))
668
+ raises(ValueError, lambda: summation(1))
669
+
670
+
671
+ def test_diff():
672
+ assert Sum(x, (x, 1, 2)).diff(x) == 0
673
+ assert Sum(x*y, (x, 1, 2)).diff(x) == 0
674
+ assert Sum(x*y, (y, 1, 2)).diff(x) == Sum(y, (y, 1, 2))
675
+ e = Sum(x*y, (x, 1, a))
676
+ assert e.diff(a) == Derivative(e, a)
677
+ assert Sum(x*y, (x, 1, 3), (a, 2, 5)).diff(y).doit() == \
678
+ Sum(x*y, (x, 1, 3), (a, 2, 5)).doit().diff(y) == 24
679
+ assert Sum(x, (x, 1, 2)).diff(y) == 0
680
+
681
+
682
+ def test_hypersum():
683
+ assert simplify(summation(x**n/fac(n), (n, 1, oo))) == -1 + exp(x)
684
+ assert summation((-1)**n * x**(2*n) / fac(2*n), (n, 0, oo)) == cos(x)
685
+ assert simplify(summation((-1)**n*x**(2*n + 1) /
686
+ factorial(2*n + 1), (n, 3, oo))) == -x + sin(x) + x**3/6 - x**5/120
687
+
688
+ assert summation(1/(n + 2)**3, (n, 1, oo)) == Rational(-9, 8) + zeta(3)
689
+ assert summation(1/n**4, (n, 1, oo)) == pi**4/90
690
+
691
+ s = summation(x**n*n, (n, -oo, 0))
692
+ assert s.is_Piecewise
693
+ assert s.args[0].args[0] == -1/(x*(1 - 1/x)**2)
694
+ assert s.args[0].args[1] == (abs(1/x) < 1)
695
+
696
+ m = Symbol('n', integer=True, positive=True)
697
+ assert summation(binomial(m, k), (k, 0, m)) == 2**m
698
+
699
+
700
+ def test_issue_4170():
701
+ assert summation(1/factorial(k), (k, 0, oo)) == E
702
+
703
+
704
+ def test_is_commutative():
705
+ from sympy.physics.secondquant import NO, F, Fd
706
+ m = Symbol('m', commutative=False)
707
+ for f in (Sum, Product, Integral):
708
+ assert f(z, (z, 1, 1)).is_commutative is True
709
+ assert f(z*y, (z, 1, 6)).is_commutative is True
710
+ assert f(m*x, (x, 1, 2)).is_commutative is False
711
+
712
+ assert f(NO(Fd(x)*F(y))*z, (z, 1, 2)).is_commutative is False
713
+
714
+
715
+ def test_is_zero():
716
+ for func in [Sum, Product]:
717
+ assert func(0, (x, 1, 1)).is_zero is True
718
+ assert func(x, (x, 1, 1)).is_zero is None
719
+
720
+ assert Sum(0, (x, 1, 0)).is_zero is True
721
+ assert Product(0, (x, 1, 0)).is_zero is False
722
+
723
+
724
+ def test_is_number():
725
+ # is number should not rely on evaluation or assumptions,
726
+ # it should be equivalent to `not foo.free_symbols`
727
+ assert Sum(1, (x, 1, 1)).is_number is True
728
+ assert Sum(1, (x, 1, x)).is_number is False
729
+ assert Sum(0, (x, y, z)).is_number is False
730
+ assert Sum(x, (y, 1, 2)).is_number is False
731
+ assert Sum(x, (y, 1, 1)).is_number is False
732
+ assert Sum(x, (x, 1, 2)).is_number is True
733
+ assert Sum(x*y, (x, 1, 2), (y, 1, 3)).is_number is True
734
+
735
+ assert Product(2, (x, 1, 1)).is_number is True
736
+ assert Product(2, (x, 1, y)).is_number is False
737
+ assert Product(0, (x, y, z)).is_number is False
738
+ assert Product(1, (x, y, z)).is_number is False
739
+ assert Product(x, (y, 1, x)).is_number is False
740
+ assert Product(x, (y, 1, 2)).is_number is False
741
+ assert Product(x, (y, 1, 1)).is_number is False
742
+ assert Product(x, (x, 1, 2)).is_number is True
743
+
744
+
745
+ def test_free_symbols():
746
+ for func in [Sum, Product]:
747
+ assert func(1, (x, 1, 2)).free_symbols == set()
748
+ assert func(0, (x, 1, y)).free_symbols == {y}
749
+ assert func(2, (x, 1, y)).free_symbols == {y}
750
+ assert func(x, (x, 1, 2)).free_symbols == set()
751
+ assert func(x, (x, 1, y)).free_symbols == {y}
752
+ assert func(x, (y, 1, y)).free_symbols == {x, y}
753
+ assert func(x, (y, 1, 2)).free_symbols == {x}
754
+ assert func(x, (y, 1, 1)).free_symbols == {x}
755
+ assert func(x, (y, 1, z)).free_symbols == {x, z}
756
+ assert func(x, (x, 1, y), (y, 1, 2)).free_symbols == set()
757
+ assert func(x, (x, 1, y), (y, 1, z)).free_symbols == {z}
758
+ assert func(x, (x, 1, y), (y, 1, y)).free_symbols == {y}
759
+ assert func(x, (y, 1, y), (y, 1, z)).free_symbols == {x, z}
760
+ assert Sum(1, (x, 1, y)).free_symbols == {y}
761
+ # free_symbols answers whether the object *as written* has free symbols,
762
+ # not whether the evaluated expression has free symbols
763
+ assert Product(1, (x, 1, y)).free_symbols == {y}
764
+ # don't count free symbols that are not independent of integration
765
+ # variable(s)
766
+ assert func(f(x), (f(x), 1, 2)).free_symbols == set()
767
+ assert func(f(x), (f(x), 1, x)).free_symbols == {x}
768
+ assert func(f(x), (f(x), 1, y)).free_symbols == {y}
769
+ assert func(f(x), (z, 1, y)).free_symbols == {x, y}
770
+
771
+
772
+ def test_conjugate_transpose():
773
+ A, B = symbols("A B", commutative=False)
774
+ p = Sum(A*B**n, (n, 1, 3))
775
+ assert p.adjoint().doit() == p.doit().adjoint()
776
+ assert p.conjugate().doit() == p.doit().conjugate()
777
+ assert p.transpose().doit() == p.doit().transpose()
778
+
779
+ p = Sum(B**n*A, (n, 1, 3))
780
+ assert p.adjoint().doit() == p.doit().adjoint()
781
+ assert p.conjugate().doit() == p.doit().conjugate()
782
+ assert p.transpose().doit() == p.doit().transpose()
783
+
784
+
785
+ def test_noncommutativity_honoured():
786
+ A, B = symbols("A B", commutative=False)
787
+ M = symbols('M', integer=True, positive=True)
788
+ p = Sum(A*B**n, (n, 1, M))
789
+ assert p.doit() == A*Piecewise((M, Eq(B, 1)),
790
+ ((B - B**(M + 1))*(1 - B)**(-1), True))
791
+
792
+ p = Sum(B**n*A, (n, 1, M))
793
+ assert p.doit() == Piecewise((M, Eq(B, 1)),
794
+ ((B - B**(M + 1))*(1 - B)**(-1), True))*A
795
+
796
+ p = Sum(B**n*A*B**n, (n, 1, M))
797
+ assert p.doit() == p
798
+
799
+
800
+ def test_issue_4171():
801
+ assert summation(factorial(2*k + 1)/factorial(2*k), (k, 0, oo)) is oo
802
+ assert summation(2*k + 1, (k, 0, oo)) is oo
803
+
804
+
805
+ def test_issue_6273():
806
+ assert Sum(x, (x, 1, n)).n(2, subs={n: 1}) == Float(1, 2)
807
+
808
+
809
+ def test_issue_6274():
810
+ assert Sum(x, (x, 1, 0)).doit() == 0
811
+ assert NS(Sum(x, (x, 1, 0))) == '0'
812
+ assert Sum(n, (n, 10, 5)).doit() == -30
813
+ assert NS(Sum(n, (n, 10, 5))) == '-30.0000000000000'
814
+
815
+
816
+ def test_simplify_sum():
817
+ y, t, v = symbols('y, t, v')
818
+
819
+ _simplify = lambda e: simplify(e, doit=False)
820
+ assert _simplify(Sum(x*y, (x, n, m), (y, a, k)) + \
821
+ Sum(y, (x, n, m), (y, a, k))) == Sum(y * (x + 1), (x, n, m), (y, a, k))
822
+ assert _simplify(Sum(x, (x, n, m)) + Sum(x, (x, m + 1, a))) == \
823
+ Sum(x, (x, n, a))
824
+ assert _simplify(Sum(x, (x, k + 1, a)) + Sum(x, (x, n, k))) == \
825
+ Sum(x, (x, n, a))
826
+ assert _simplify(Sum(x, (x, k + 1, a)) + Sum(x + 1, (x, n, k))) == \
827
+ Sum(x, (x, n, a)) + Sum(1, (x, n, k))
828
+ assert _simplify(Sum(x, (x, 0, 3)) * 3 + 3 * Sum(x, (x, 4, 6)) + \
829
+ 4 * Sum(z, (z, 0, 1))) == 4*Sum(z, (z, 0, 1)) + 3*Sum(x, (x, 0, 6))
830
+ assert _simplify(3*Sum(x**2, (x, a, b)) + Sum(x, (x, a, b))) == \
831
+ Sum(x*(3*x + 1), (x, a, b))
832
+ assert _simplify(Sum(x**3, (x, n, k)) * 3 + 3 * Sum(x, (x, n, k)) + \
833
+ 4 * y * Sum(z, (z, n, k))) + 1 == \
834
+ 4*y*Sum(z, (z, n, k)) + 3*Sum(x**3 + x, (x, n, k)) + 1
835
+ assert _simplify(Sum(x, (x, a, b)) + 1 + Sum(x, (x, b + 1, c))) == \
836
+ 1 + Sum(x, (x, a, c))
837
+ assert _simplify(Sum(x, (t, a, b)) + Sum(y, (t, a, b)) + \
838
+ Sum(x, (t, b+1, c))) == x * Sum(1, (t, a, c)) + y * Sum(1, (t, a, b))
839
+ assert _simplify(Sum(x, (t, a, b)) + Sum(x, (t, b+1, c)) + \
840
+ Sum(y, (t, a, b))) == x * Sum(1, (t, a, c)) + y * Sum(1, (t, a, b))
841
+ assert _simplify(Sum(x, (t, a, b)) + 2 * Sum(x, (t, b+1, c))) == \
842
+ _simplify(Sum(x, (t, a, b)) + Sum(x, (t, b+1, c)) + Sum(x, (t, b+1, c)))
843
+ assert _simplify(Sum(x, (x, a, b))*Sum(x**2, (x, a, b))) == \
844
+ Sum(x, (x, a, b)) * Sum(x**2, (x, a, b))
845
+ assert _simplify(Sum(x, (t, a, b)) + Sum(y, (t, a, b)) + Sum(z, (t, a, b))) \
846
+ == (x + y + z) * Sum(1, (t, a, b)) # issue 8596
847
+ assert _simplify(Sum(x, (t, a, b)) + Sum(y, (t, a, b)) + Sum(z, (t, a, b)) + \
848
+ Sum(v, (t, a, b))) == (x + y + z + v) * Sum(1, (t, a, b)) # issue 8596
849
+ assert _simplify(Sum(x * y, (x, a, b)) / (3 * y)) == \
850
+ (Sum(x, (x, a, b)) / 3)
851
+ assert _simplify(Sum(f(x) * y * z, (x, a, b)) / (y * z)) \
852
+ == Sum(f(x), (x, a, b))
853
+ assert _simplify(Sum(c * x, (x, a, b)) - c * Sum(x, (x, a, b))) == 0
854
+ assert _simplify(c * (Sum(x, (x, a, b)) + y)) == c * (y + Sum(x, (x, a, b)))
855
+ assert _simplify(c * (Sum(x, (x, a, b)) + y * Sum(x, (x, a, b)))) == \
856
+ c * (y + 1) * Sum(x, (x, a, b))
857
+ assert _simplify(Sum(Sum(c * x, (x, a, b)), (y, a, b))) == \
858
+ c * Sum(x, (x, a, b), (y, a, b))
859
+ assert _simplify(Sum((3 + y) * Sum(c * x, (x, a, b)), (y, a, b))) == \
860
+ c * Sum((3 + y), (y, a, b)) * Sum(x, (x, a, b))
861
+ assert _simplify(Sum((3 + t) * Sum(c * t, (x, a, b)), (y, a, b))) == \
862
+ c*t*(t + 3)*Sum(1, (x, a, b))*Sum(1, (y, a, b))
863
+ assert _simplify(Sum(Sum(d * t, (x, a, b - 1)) + \
864
+ Sum(d * t, (x, b, c)), (t, a, b))) == \
865
+ d * Sum(1, (x, a, c)) * Sum(t, (t, a, b))
866
+ assert _simplify(Sum(sin(t)**2 + cos(t)**2 + 1, (t, a, b))) == \
867
+ 2 * Sum(1, (t, a, b))
868
+
869
+
870
+ def test_change_index():
871
+ b, v, w = symbols('b, v, w', integer = True)
872
+
873
+ assert Sum(x, (x, a, b)).change_index(x, x + 1, y) == \
874
+ Sum(y - 1, (y, a + 1, b + 1))
875
+ assert Sum(x**2, (x, a, b)).change_index( x, x - 1) == \
876
+ Sum((x+1)**2, (x, a - 1, b - 1))
877
+ assert Sum(x**2, (x, a, b)).change_index( x, -x, y) == \
878
+ Sum((-y)**2, (y, -b, -a))
879
+ assert Sum(x, (x, a, b)).change_index( x, -x - 1) == \
880
+ Sum(-x - 1, (x, -b - 1, -a - 1))
881
+ assert Sum(x*y, (x, a, b), (y, c, d)).change_index( x, x - 1, z) == \
882
+ Sum((z + 1)*y, (z, a - 1, b - 1), (y, c, d))
883
+ assert Sum(x, (x, a, b)).change_index( x, x + v) == \
884
+ Sum(-v + x, (x, a + v, b + v))
885
+ assert Sum(x, (x, a, b)).change_index( x, -x - v) == \
886
+ Sum(-v - x, (x, -b - v, -a - v))
887
+ assert Sum(x, (x, a, b)).change_index(x, w*x, v) == \
888
+ Sum(v/w, (v, b*w, a*w))
889
+ raises(ValueError, lambda: Sum(x, (x, a, b)).change_index(x, 2*x))
890
+
891
+
892
+ def test_reorder():
893
+ b, y, c, d, z = symbols('b, y, c, d, z', integer = True)
894
+
895
+ assert Sum(x*y, (x, a, b), (y, c, d)).reorder((0, 1)) == \
896
+ Sum(x*y, (y, c, d), (x, a, b))
897
+ assert Sum(x, (x, a, b), (x, c, d)).reorder((0, 1)) == \
898
+ Sum(x, (x, c, d), (x, a, b))
899
+ assert Sum(x*y + z, (x, a, b), (z, m, n), (y, c, d)).reorder(\
900
+ (2, 0), (0, 1)) == Sum(x*y + z, (z, m, n), (y, c, d), (x, a, b))
901
+ assert Sum(x*y*z, (x, a, b), (y, c, d), (z, m, n)).reorder(\
902
+ (0, 1), (1, 2), (0, 2)) == Sum(x*y*z, (x, a, b), (z, m, n), (y, c, d))
903
+ assert Sum(x*y*z, (x, a, b), (y, c, d), (z, m, n)).reorder(\
904
+ (x, y), (y, z), (x, z)) == Sum(x*y*z, (x, a, b), (z, m, n), (y, c, d))
905
+ assert Sum(x*y, (x, a, b), (y, c, d)).reorder((x, 1)) == \
906
+ Sum(x*y, (y, c, d), (x, a, b))
907
+ assert Sum(x*y, (x, a, b), (y, c, d)).reorder((y, x)) == \
908
+ Sum(x*y, (y, c, d), (x, a, b))
909
+
910
+
911
+ def test_reverse_order():
912
+ assert Sum(x, (x, 0, 3)).reverse_order(0) == Sum(-x, (x, 4, -1))
913
+ assert Sum(x*y, (x, 1, 5), (y, 0, 6)).reverse_order(0, 1) == \
914
+ Sum(x*y, (x, 6, 0), (y, 7, -1))
915
+ assert Sum(x, (x, 1, 2)).reverse_order(0) == Sum(-x, (x, 3, 0))
916
+ assert Sum(x, (x, 1, 3)).reverse_order(0) == Sum(-x, (x, 4, 0))
917
+ assert Sum(x, (x, 1, a)).reverse_order(0) == Sum(-x, (x, a + 1, 0))
918
+ assert Sum(x, (x, a, 5)).reverse_order(0) == Sum(-x, (x, 6, a - 1))
919
+ assert Sum(x, (x, a + 1, a + 5)).reverse_order(0) == \
920
+ Sum(-x, (x, a + 6, a))
921
+ assert Sum(x, (x, a + 1, a + 2)).reverse_order(0) == \
922
+ Sum(-x, (x, a + 3, a))
923
+ assert Sum(x, (x, a + 1, a + 1)).reverse_order(0) == \
924
+ Sum(-x, (x, a + 2, a))
925
+ assert Sum(x, (x, a, b)).reverse_order(0) == Sum(-x, (x, b + 1, a - 1))
926
+ assert Sum(x, (x, a, b)).reverse_order(x) == Sum(-x, (x, b + 1, a - 1))
927
+ assert Sum(x*y, (x, a, b), (y, 2, 5)).reverse_order(x, 1) == \
928
+ Sum(x*y, (x, b + 1, a - 1), (y, 6, 1))
929
+ assert Sum(x*y, (x, a, b), (y, 2, 5)).reverse_order(y, x) == \
930
+ Sum(x*y, (x, b + 1, a - 1), (y, 6, 1))
931
+
932
+
933
+ def test_issue_7097():
934
+ assert sum(x**n/n for n in range(1, 401)) == summation(x**n/n, (n, 1, 400))
935
+
936
+
937
+ def test_factor_expand_subs():
938
+ # test factoring
939
+ assert Sum(4 * x, (x, 1, y)).factor() == 4 * Sum(x, (x, 1, y))
940
+ assert Sum(x * a, (x, 1, y)).factor() == a * Sum(x, (x, 1, y))
941
+ assert Sum(4 * x * a, (x, 1, y)).factor() == 4 * a * Sum(x, (x, 1, y))
942
+ assert Sum(4 * x * y, (x, 1, y)).factor() == 4 * y * Sum(x, (x, 1, y))
943
+
944
+ # test expand
945
+ _x = Symbol('x', zero=False)
946
+ assert Sum(x+1,(x,1,y)).expand() == Sum(x,(x,1,y)) + Sum(1,(x,1,y))
947
+ assert Sum(x+a*x**2,(x,1,y)).expand() == Sum(x,(x,1,y)) + Sum(a*x**2,(x,1,y))
948
+ assert Sum(_x**(n + 1)*(n + 1), (n, -1, oo)).expand() \
949
+ == Sum(n*_x*_x**n + _x*_x**n, (n, -1, oo))
950
+ assert Sum(x**(n + 1)*(n + 1), (n, -1, oo)).expand(power_exp=False) \
951
+ == Sum(n*x**(n + 1) + x**(n + 1), (n, -1, oo))
952
+ assert Sum(x**(n + 1)*(n + 1), (n, -1, oo)).expand(force=True) \
953
+ == Sum(x*x**n, (n, -1, oo)) + Sum(n*x*x**n, (n, -1, oo))
954
+ assert Sum(a*n+a*n**2,(n,0,4)).expand() \
955
+ == Sum(a*n,(n,0,4)) + Sum(a*n**2,(n,0,4))
956
+ assert Sum(_x**a*_x**n,(x,0,3)) \
957
+ == Sum(_x**(a+n),(x,0,3)).expand(power_exp=True)
958
+ _a, _n = symbols('a n', positive=True)
959
+ assert Sum(x**(_a+_n),(x,0,3)).expand(power_exp=True) \
960
+ == Sum(x**_a*x**_n, (x, 0, 3))
961
+ assert Sum(x**(_a-_n),(x,0,3)).expand(power_exp=True) \
962
+ == Sum(x**(_a-_n),(x,0,3)).expand(power_exp=False)
963
+
964
+ # test subs
965
+ assert Sum(1/(1+a*x**2),(x,0,3)).subs([(a,3)]) == Sum(1/(1+3*x**2),(x,0,3))
966
+ assert Sum(x*y,(x,0,y),(y,0,x)).subs([(x,3)]) == Sum(x*y,(x,0,y),(y,0,3))
967
+ assert Sum(x,(x,1,10)).subs([(x,y-2)]) == Sum(x,(x,1,10))
968
+ assert Sum(1/x,(x,1,10)).subs([(x,(3+n)**3)]) == Sum(1/x,(x,1,10))
969
+ assert Sum(1/x,(x,1,10)).subs([(x,3*x-2)]) == Sum(1/x,(x,1,10))
970
+
971
+
972
+ def test_distribution_over_equality():
973
+ assert Product(Eq(x*2, f(x)), (x, 1, 3)).doit() == Eq(48, f(1)*f(2)*f(3))
974
+ assert Sum(Eq(f(x), x**2), (x, 0, y)) == \
975
+ Eq(Sum(f(x), (x, 0, y)), Sum(x**2, (x, 0, y)))
976
+
977
+
978
+ def test_issue_2787():
979
+ n, k = symbols('n k', positive=True, integer=True)
980
+ p = symbols('p', positive=True)
981
+ binomial_dist = binomial(n, k)*p**k*(1 - p)**(n - k)
982
+ s = Sum(binomial_dist*k, (k, 0, n))
983
+ res = s.doit().simplify()
984
+ ans = Piecewise(
985
+ (n*p, x),
986
+ (Sum(k*p**k*binomial(n, k)*(1 - p)**(n - k), (k, 0, n)),
987
+ True)).subs(x, (Eq(n, 1) | (n > 1)) & (p/Abs(p - 1) <= 1))
988
+ ans2 = Piecewise(
989
+ (n*p, x),
990
+ (factorial(n)*Sum(p**k*(1 - p)**(-k + n)/
991
+ (factorial(-k + n)*factorial(k - 1)), (k, 0, n)),
992
+ True)).subs(x, (Eq(n, 1) | (n > 1)) & (p/Abs(p - 1) <= 1))
993
+ assert res in [ans, ans2] # XXX system dependent
994
+ # Issue #17165: make sure that another simplify does not complicate
995
+ # the result by much. Why didn't first simplify replace
996
+ # Eq(n, 1) | (n > 1) with True?
997
+ assert res.simplify().count_ops() <= res.count_ops() + 2
998
+
999
+
1000
+ def test_issue_4668():
1001
+ assert summation(1/n, (n, 2, oo)) is oo
1002
+
1003
+
1004
+ def test_matrix_sum():
1005
+ A = Matrix([[0, 1], [n, 0]])
1006
+
1007
+ result = Sum(A, (n, 0, 3)).doit()
1008
+ assert result == Matrix([[0, 4], [6, 0]])
1009
+ assert result.__class__ == ImmutableDenseMatrix
1010
+
1011
+ A = SparseMatrix([[0, 1], [n, 0]])
1012
+
1013
+ result = Sum(A, (n, 0, 3)).doit()
1014
+ assert result.__class__ == ImmutableSparseMatrix
1015
+
1016
+
1017
+ def test_failing_matrix_sum():
1018
+ n = Symbol('n')
1019
+ # TODO Implement matrix geometric series summation.
1020
+ A = Matrix([[0, 1, 0], [-1, 0, 0], [0, 0, 0]])
1021
+ assert Sum(A ** n, (n, 1, 4)).doit() == \
1022
+ Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
1023
+ # issue sympy/sympy#16989
1024
+ assert summation(A**n, (n, 1, 1)) == A
1025
+
1026
+
1027
+ def test_indexed_idx_sum():
1028
+ i = symbols('i', cls=Idx)
1029
+ r = Indexed('r', i)
1030
+ assert Sum(r, (i, 0, 3)).doit() == sum(r.xreplace({i: j}) for j in range(4))
1031
+ assert Product(r, (i, 0, 3)).doit() == prod([r.xreplace({i: j}) for j in range(4)])
1032
+
1033
+ j = symbols('j', integer=True)
1034
+ assert Sum(r, (i, j, j+2)).doit() == sum(r.xreplace({i: j+k}) for k in range(3))
1035
+ assert Product(r, (i, j, j+2)).doit() == prod([r.xreplace({i: j+k}) for k in range(3)])
1036
+
1037
+ k = Idx('k', range=(1, 3))
1038
+ A = IndexedBase('A')
1039
+ assert Sum(A[k], k).doit() == sum(A[Idx(j, (1, 3))] for j in range(1, 4))
1040
+ assert Product(A[k], k).doit() == prod([A[Idx(j, (1, 3))] for j in range(1, 4)])
1041
+
1042
+ raises(ValueError, lambda: Sum(A[k], (k, 1, 4)))
1043
+ raises(ValueError, lambda: Sum(A[k], (k, 0, 3)))
1044
+ raises(ValueError, lambda: Sum(A[k], (k, 2, oo)))
1045
+
1046
+ raises(ValueError, lambda: Product(A[k], (k, 1, 4)))
1047
+ raises(ValueError, lambda: Product(A[k], (k, 0, 3)))
1048
+ raises(ValueError, lambda: Product(A[k], (k, 2, oo)))
1049
+
1050
+
1051
+ @slow
1052
+ def test_is_convergent():
1053
+ # divergence tests --
1054
+ assert Sum(n/(2*n + 1), (n, 1, oo)).is_convergent() is S.false
1055
+ assert Sum(factorial(n)/5**n, (n, 1, oo)).is_convergent() is S.false
1056
+ assert Sum(3**(-2*n - 1)*n**n, (n, 1, oo)).is_convergent() is S.false
1057
+ assert Sum((-1)**n*n, (n, 3, oo)).is_convergent() is S.false
1058
+ assert Sum((-1)**n, (n, 1, oo)).is_convergent() is S.false
1059
+ assert Sum(log(1/n), (n, 2, oo)).is_convergent() is S.false
1060
+
1061
+ # Raabe's test --
1062
+ assert Sum(Product((3*m),(m,1,n))/Product((3*m+4),(m,1,n)),(n,1,oo)).is_convergent() is S.true
1063
+
1064
+ # root test --
1065
+ assert Sum((-12)**n/n, (n, 1, oo)).is_convergent() is S.false
1066
+
1067
+ # integral test --
1068
+
1069
+ # p-series test --
1070
+ assert Sum(1/(n**2 + 1), (n, 1, oo)).is_convergent() is S.true
1071
+ assert Sum(1/n**Rational(6, 5), (n, 1, oo)).is_convergent() is S.true
1072
+ assert Sum(2/(n*sqrt(n - 1)), (n, 2, oo)).is_convergent() is S.true
1073
+ assert Sum(1/(sqrt(n)*sqrt(n)), (n, 2, oo)).is_convergent() is S.false
1074
+ assert Sum(factorial(n) / factorial(n+2), (n, 1, oo)).is_convergent() is S.true
1075
+ assert Sum(rf(5,n)/rf(7,n),(n,1,oo)).is_convergent() is S.true
1076
+ assert Sum((rf(1, n)*rf(2, n))/(rf(3, n)*factorial(n)),(n,1,oo)).is_convergent() is S.false
1077
+
1078
+ # comparison test --
1079
+ assert Sum(1/(n + log(n)), (n, 1, oo)).is_convergent() is S.false
1080
+ assert Sum(1/(n**2*log(n)), (n, 2, oo)).is_convergent() is S.true
1081
+ assert Sum(1/(n*log(n)), (n, 2, oo)).is_convergent() is S.false
1082
+ assert Sum(2/(n*log(n)*log(log(n))**2), (n, 5, oo)).is_convergent() is S.true
1083
+ assert Sum(2/(n*log(n)**2), (n, 2, oo)).is_convergent() is S.true
1084
+ assert Sum((n - 1)/(n**2*log(n)**3), (n, 2, oo)).is_convergent() is S.true
1085
+ assert Sum(1/(n*log(n)*log(log(n))), (n, 5, oo)).is_convergent() is S.false
1086
+ assert Sum((n - 1)/(n*log(n)**3), (n, 3, oo)).is_convergent() is S.false
1087
+ assert Sum(2/(n**2*log(n)), (n, 2, oo)).is_convergent() is S.true
1088
+ assert Sum(1/(n*sqrt(log(n))*log(log(n))), (n, 100, oo)).is_convergent() is S.false
1089
+ assert Sum(log(log(n))/(n*log(n)**2), (n, 100, oo)).is_convergent() is S.true
1090
+ assert Sum(log(n)/n**2, (n, 5, oo)).is_convergent() is S.true
1091
+
1092
+ # alternating series tests --
1093
+ assert Sum((-1)**(n - 1)/(n**2 - 1), (n, 3, oo)).is_convergent() is S.true
1094
+
1095
+ # with -negativeInfinite Limits
1096
+ assert Sum(1/(n**2 + 1), (n, -oo, 1)).is_convergent() is S.true
1097
+ assert Sum(1/(n - 1), (n, -oo, -1)).is_convergent() is S.false
1098
+ assert Sum(1/(n**2 - 1), (n, -oo, -5)).is_convergent() is S.true
1099
+ assert Sum(1/(n**2 - 1), (n, -oo, 2)).is_convergent() is S.true
1100
+ assert Sum(1/(n**2 - 1), (n, -oo, oo)).is_convergent() is S.true
1101
+
1102
+ # piecewise functions
1103
+ f = Piecewise((n**(-2), n <= 1), (n**2, n > 1))
1104
+ assert Sum(f, (n, 1, oo)).is_convergent() is S.false
1105
+ assert Sum(f, (n, -oo, oo)).is_convergent() is S.false
1106
+ assert Sum(f, (n, 1, 100)).is_convergent() is S.true
1107
+ #assert Sum(f, (n, -oo, 1)).is_convergent() is S.true
1108
+
1109
+ # integral test
1110
+
1111
+ assert Sum(log(n)/n**3, (n, 1, oo)).is_convergent() is S.true
1112
+ assert Sum(-log(n)/n**3, (n, 1, oo)).is_convergent() is S.true
1113
+ # the following function has maxima located at (x, y) =
1114
+ # (1.2, 0.43), (3.0, -0.25) and (6.8, 0.050)
1115
+ eq = (x - 2)*(x**2 - 6*x + 4)*exp(-x)
1116
+ assert Sum(eq, (x, 1, oo)).is_convergent() is S.true
1117
+ assert Sum(eq, (x, 1, 2)).is_convergent() is S.true
1118
+ assert Sum(1/(x**3), (x, 1, oo)).is_convergent() is S.true
1119
+ assert Sum(1/(x**S.Half), (x, 1, oo)).is_convergent() is S.false
1120
+
1121
+ # issue 19545
1122
+ assert Sum(1/n - 3/(3*n +2), (n, 1, oo)).is_convergent() is S.true
1123
+
1124
+ # issue 19836
1125
+ assert Sum(4/(n + 2) - 5/(n + 1) + 1/n,(n, 7, oo)).is_convergent() is S.true
1126
+
1127
+
1128
+ def test_is_absolutely_convergent():
1129
+ assert Sum((-1)**n, (n, 1, oo)).is_absolutely_convergent() is S.false
1130
+ assert Sum((-1)**n/n**2, (n, 1, oo)).is_absolutely_convergent() is S.true
1131
+
1132
+
1133
+ @XFAIL
1134
+ def test_convergent_failing():
1135
+ # dirichlet tests
1136
+ assert Sum(sin(n)/n, (n, 1, oo)).is_convergent() is S.true
1137
+ assert Sum(sin(2*n)/n, (n, 1, oo)).is_convergent() is S.true
1138
+
1139
+
1140
+ def test_issue_6966():
1141
+ i, k, m = symbols('i k m', integer=True)
1142
+ z_i, q_i = symbols('z_i q_i')
1143
+ a_k = Sum(-q_i*z_i/k,(i,1,m))
1144
+ b_k = a_k.diff(z_i)
1145
+ assert isinstance(b_k, Sum)
1146
+ assert b_k == Sum(-q_i/k,(i,1,m))
1147
+
1148
+
1149
+ def test_issue_10156():
1150
+ cx = Sum(2*y**2*x, (x, 1,3))
1151
+ e = 2*y*Sum(2*cx*x**2, (x, 1, 9))
1152
+ assert e.factor() == \
1153
+ 8*y**3*Sum(x, (x, 1, 3))*Sum(x**2, (x, 1, 9))
1154
+
1155
+
1156
+ def test_issue_10973():
1157
+ assert Sum((-n + (n**3 + 1)**(S(1)/3))/log(n), (n, 1, oo)).is_convergent() is S.true
1158
+
1159
+
1160
+ def test_issue_14129():
1161
+ x = Symbol('x', zero=False)
1162
+ assert Sum( k*x**k, (k, 0, n-1)).doit() == \
1163
+ Piecewise((n**2/2 - n/2, Eq(x, 1)), ((n*x*x**n -
1164
+ n*x**n - x*x**n + x)/(x - 1)**2, True))
1165
+ assert Sum( x**k, (k, 0, n-1)).doit() == \
1166
+ Piecewise((n, Eq(x, 1)), ((-x**n + 1)/(-x + 1), True))
1167
+ assert Sum( k*(x/y+x)**k, (k, 0, n-1)).doit() == \
1168
+ Piecewise((n*(n - 1)/2, Eq(x, y/(y + 1))),
1169
+ (x*(y + 1)*(n*x*y*(x + x/y)**(n - 1) +
1170
+ n*x*(x + x/y)**(n - 1) - n*y*(x + x/y)**(n - 1) -
1171
+ x*y*(x + x/y)**(n - 1) - x*(x + x/y)**(n - 1) + y)/
1172
+ (x*y + x - y)**2, True))
1173
+
1174
+
1175
+ def test_issue_14112():
1176
+ assert Sum((-1)**n/sqrt(n), (n, 1, oo)).is_absolutely_convergent() is S.false
1177
+ assert Sum((-1)**(2*n)/n, (n, 1, oo)).is_convergent() is S.false
1178
+ assert Sum((-2)**n + (-3)**n, (n, 1, oo)).is_convergent() is S.false
1179
+
1180
+
1181
+ def test_issue_14219():
1182
+ A = diag(0, 2, -3)
1183
+ res = diag(1, 15, -20)
1184
+ assert Sum(A**n, (n, 0, 3)).doit() == res
1185
+
1186
+
1187
+ def test_sin_times_absolutely_convergent():
1188
+ assert Sum(sin(n) / n**3, (n, 1, oo)).is_convergent() is S.true
1189
+ assert Sum(sin(n) * log(n) / n**3, (n, 1, oo)).is_convergent() is S.true
1190
+
1191
+
1192
+ def test_issue_14111():
1193
+ assert Sum(1/log(log(n)), (n, 22, oo)).is_convergent() is S.false
1194
+
1195
+
1196
+ def test_issue_14484():
1197
+ assert Sum(sin(n)/log(log(n)), (n, 22, oo)).is_convergent() is S.false
1198
+
1199
+
1200
+ def test_issue_14640():
1201
+ i, n = symbols("i n", integer=True)
1202
+ a, b, c = symbols("a b c", zero=False)
1203
+
1204
+ assert Sum(a**-i/(a - b), (i, 0, n)).doit() == Sum(
1205
+ 1/(a*a**i - a**i*b), (i, 0, n)).doit() == Piecewise(
1206
+ (n + 1, Eq(1/a, 1)),
1207
+ ((-a**(-n - 1) + 1)/(1 - 1/a), True))/(a - b)
1208
+
1209
+ assert Sum((b*a**i - c*a**i)**-2, (i, 0, n)).doit() == Piecewise(
1210
+ (n + 1, Eq(a**(-2), 1)),
1211
+ ((-a**(-2*n - 2) + 1)/(1 - 1/a**2), True))/(b - c)**2
1212
+
1213
+ s = Sum(i*(a**(n - i) - b**(n - i))/(a - b), (i, 0, n)).doit()
1214
+ assert not s.has(Sum)
1215
+ assert s.subs({a: 2, b: 3, n: 5}) == 122
1216
+
1217
+
1218
+ def test_issue_15943():
1219
+ s = Sum(binomial(n, k)*factorial(n - k), (k, 0, n)).doit().rewrite(gamma)
1220
+ assert s == -E*(n + 1)*gamma(n + 1)*lowergamma(n + 1, 1)/gamma(n + 2
1221
+ ) + E*gamma(n + 1)
1222
+ assert s.simplify() == E*(factorial(n) - lowergamma(n + 1, 1))
1223
+
1224
+
1225
+ def test_Sum_dummy_eq():
1226
+ assert not Sum(x, (x, a, b)).dummy_eq(1)
1227
+ assert not Sum(x, (x, a, b)).dummy_eq(Sum(x, (x, a, b), (a, 1, 2)))
1228
+ assert not Sum(x, (x, a, b)).dummy_eq(Sum(x, (x, a, c)))
1229
+ assert Sum(x, (x, a, b)).dummy_eq(Sum(x, (x, a, b)))
1230
+ d = Dummy()
1231
+ assert Sum(x, (x, a, d)).dummy_eq(Sum(x, (x, a, c)), c)
1232
+ assert not Sum(x, (x, a, d)).dummy_eq(Sum(x, (x, a, c)))
1233
+ assert Sum(x, (x, a, c)).dummy_eq(Sum(y, (y, a, c)))
1234
+ assert Sum(x, (x, a, d)).dummy_eq(Sum(y, (y, a, c)), c)
1235
+ assert not Sum(x, (x, a, d)).dummy_eq(Sum(y, (y, a, c)))
1236
+
1237
+
1238
+ def test_issue_15852():
1239
+ assert summation(x**y*y, (y, -oo, oo)).doit() == Sum(x**y*y, (y, -oo, oo))
1240
+
1241
+
1242
+ def test_exceptions():
1243
+ S = Sum(x, (x, a, b))
1244
+ raises(ValueError, lambda: S.change_index(x, x**2, y))
1245
+ S = Sum(x, (x, a, b), (x, 1, 4))
1246
+ raises(ValueError, lambda: S.index(x))
1247
+ S = Sum(x, (x, a, b), (y, 1, 4))
1248
+ raises(ValueError, lambda: S.reorder([x]))
1249
+ S = Sum(x, (x, y, b), (y, 1, 4))
1250
+ raises(ReorderError, lambda: S.reorder_limit(0, 1))
1251
+ S = Sum(x*y, (x, a, b), (y, 1, 4))
1252
+ raises(NotImplementedError, lambda: S.is_convergent())
1253
+
1254
+
1255
+ def test_sumproducts_assumptions():
1256
+ M = Symbol('M', integer=True, positive=True)
1257
+
1258
+ m = Symbol('m', integer=True)
1259
+ for func in [Sum, Product]:
1260
+ assert func(m, (m, -M, M)).is_positive is None
1261
+ assert func(m, (m, -M, M)).is_nonpositive is None
1262
+ assert func(m, (m, -M, M)).is_negative is None
1263
+ assert func(m, (m, -M, M)).is_nonnegative is None
1264
+ assert func(m, (m, -M, M)).is_finite is True
1265
+
1266
+ m = Symbol('m', integer=True, nonnegative=True)
1267
+ for func in [Sum, Product]:
1268
+ assert func(m, (m, 0, M)).is_positive is None
1269
+ assert func(m, (m, 0, M)).is_nonpositive is None
1270
+ assert func(m, (m, 0, M)).is_negative is False
1271
+ assert func(m, (m, 0, M)).is_nonnegative is True
1272
+ assert func(m, (m, 0, M)).is_finite is True
1273
+
1274
+ m = Symbol('m', integer=True, positive=True)
1275
+ for func in [Sum, Product]:
1276
+ assert func(m, (m, 1, M)).is_positive is True
1277
+ assert func(m, (m, 1, M)).is_nonpositive is False
1278
+ assert func(m, (m, 1, M)).is_negative is False
1279
+ assert func(m, (m, 1, M)).is_nonnegative is True
1280
+ assert func(m, (m, 1, M)).is_finite is True
1281
+
1282
+ m = Symbol('m', integer=True, negative=True)
1283
+ assert Sum(m, (m, -M, -1)).is_positive is False
1284
+ assert Sum(m, (m, -M, -1)).is_nonpositive is True
1285
+ assert Sum(m, (m, -M, -1)).is_negative is True
1286
+ assert Sum(m, (m, -M, -1)).is_nonnegative is False
1287
+ assert Sum(m, (m, -M, -1)).is_finite is True
1288
+ assert Product(m, (m, -M, -1)).is_positive is None
1289
+ assert Product(m, (m, -M, -1)).is_nonpositive is None
1290
+ assert Product(m, (m, -M, -1)).is_negative is None
1291
+ assert Product(m, (m, -M, -1)).is_nonnegative is None
1292
+ assert Product(m, (m, -M, -1)).is_finite is True
1293
+
1294
+ m = Symbol('m', integer=True, nonpositive=True)
1295
+ assert Sum(m, (m, -M, 0)).is_positive is False
1296
+ assert Sum(m, (m, -M, 0)).is_nonpositive is True
1297
+ assert Sum(m, (m, -M, 0)).is_negative is None
1298
+ assert Sum(m, (m, -M, 0)).is_nonnegative is None
1299
+ assert Sum(m, (m, -M, 0)).is_finite is True
1300
+ assert Product(m, (m, -M, 0)).is_positive is None
1301
+ assert Product(m, (m, -M, 0)).is_nonpositive is None
1302
+ assert Product(m, (m, -M, 0)).is_negative is None
1303
+ assert Product(m, (m, -M, 0)).is_nonnegative is None
1304
+ assert Product(m, (m, -M, 0)).is_finite is True
1305
+
1306
+ m = Symbol('m', integer=True)
1307
+ assert Sum(2, (m, 0, oo)).is_positive is None
1308
+ assert Sum(2, (m, 0, oo)).is_nonpositive is None
1309
+ assert Sum(2, (m, 0, oo)).is_negative is None
1310
+ assert Sum(2, (m, 0, oo)).is_nonnegative is None
1311
+ assert Sum(2, (m, 0, oo)).is_finite is None
1312
+
1313
+ assert Product(2, (m, 0, oo)).is_positive is None
1314
+ assert Product(2, (m, 0, oo)).is_nonpositive is None
1315
+ assert Product(2, (m, 0, oo)).is_negative is False
1316
+ assert Product(2, (m, 0, oo)).is_nonnegative is None
1317
+ assert Product(2, (m, 0, oo)).is_finite is None
1318
+
1319
+ assert Product(0, (x, M, M-1)).is_positive is True
1320
+ assert Product(0, (x, M, M-1)).is_finite is True
1321
+
1322
+
1323
+ def test_expand_with_assumptions():
1324
+ M = Symbol('M', integer=True, positive=True)
1325
+ x = Symbol('x', positive=True)
1326
+ m = Symbol('m', nonnegative=True)
1327
+ assert log(Product(x**m, (m, 0, M))).expand() == Sum(m*log(x), (m, 0, M))
1328
+ assert log(Product(exp(x**m), (m, 0, M))).expand() == Sum(x**m, (m, 0, M))
1329
+ assert log(Product(x**m, (m, 0, M))).rewrite(Sum).expand() == Sum(m*log(x), (m, 0, M))
1330
+ assert log(Product(exp(x**m), (m, 0, M))).rewrite(Sum).expand() == Sum(x**m, (m, 0, M))
1331
+
1332
+ n = Symbol('n', nonnegative=True)
1333
+ i, j = symbols('i,j', positive=True, integer=True)
1334
+ x, y = symbols('x,y', positive=True)
1335
+ assert log(Product(x**i*y**j, (i, 1, n), (j, 1, m))).expand() \
1336
+ == Sum(i*log(x) + j*log(y), (i, 1, n), (j, 1, m))
1337
+
1338
+ m = Symbol('m', nonnegative=True, integer=True)
1339
+ s = Sum(x**m, (m, 0, M))
1340
+ s_as_product = s.rewrite(Product)
1341
+ assert s_as_product.has(Product)
1342
+ assert s_as_product == log(Product(exp(x**m), (m, 0, M)))
1343
+ assert s_as_product.expand() == s
1344
+ s5 = s.subs(M, 5)
1345
+ s5_as_product = s5.rewrite(Product)
1346
+ assert s5_as_product.has(Product)
1347
+ assert s5_as_product.doit().expand() == s5.doit()
1348
+
1349
+
1350
+ def test_has_finite_limits():
1351
+ x = Symbol('x')
1352
+ assert Sum(1, (x, 1, 9)).has_finite_limits is True
1353
+ assert Sum(1, (x, 1, oo)).has_finite_limits is False
1354
+ M = Symbol('M')
1355
+ assert Sum(1, (x, 1, M)).has_finite_limits is None
1356
+ M = Symbol('M', positive=True)
1357
+ assert Sum(1, (x, 1, M)).has_finite_limits is True
1358
+ x = Symbol('x', positive=True)
1359
+ M = Symbol('M')
1360
+ assert Sum(1, (x, 1, M)).has_finite_limits is True
1361
+
1362
+ assert Sum(1, (x, 1, M), (y, -oo, oo)).has_finite_limits is False
1363
+
1364
+ def test_has_reversed_limits():
1365
+ assert Sum(1, (x, 1, 1)).has_reversed_limits is False
1366
+ assert Sum(1, (x, 1, 9)).has_reversed_limits is False
1367
+ assert Sum(1, (x, 1, -9)).has_reversed_limits is True
1368
+ assert Sum(1, (x, 1, 0)).has_reversed_limits is True
1369
+ assert Sum(1, (x, 1, oo)).has_reversed_limits is False
1370
+ M = Symbol('M')
1371
+ assert Sum(1, (x, 1, M)).has_reversed_limits is None
1372
+ M = Symbol('M', positive=True, integer=True)
1373
+ assert Sum(1, (x, 1, M)).has_reversed_limits is False
1374
+ assert Sum(1, (x, 1, M), (y, -oo, oo)).has_reversed_limits is False
1375
+ M = Symbol('M', negative=True)
1376
+ assert Sum(1, (x, 1, M)).has_reversed_limits is True
1377
+
1378
+ assert Sum(1, (x, 1, M), (y, -oo, oo)).has_reversed_limits is True
1379
+ assert Sum(1, (x, oo, oo)).has_reversed_limits is None
1380
+
1381
+
1382
+ def test_has_empty_sequence():
1383
+ assert Sum(1, (x, 1, 1)).has_empty_sequence is False
1384
+ assert Sum(1, (x, 1, 9)).has_empty_sequence is False
1385
+ assert Sum(1, (x, 1, -9)).has_empty_sequence is False
1386
+ assert Sum(1, (x, 1, 0)).has_empty_sequence is True
1387
+ assert Sum(1, (x, y, y - 1)).has_empty_sequence is True
1388
+ assert Sum(1, (x, 3, 2), (y, -oo, oo)).has_empty_sequence is True
1389
+ assert Sum(1, (y, -oo, oo), (x, 3, 2)).has_empty_sequence is True
1390
+ assert Sum(1, (x, oo, oo)).has_empty_sequence is False
1391
+
1392
+
1393
+ def test_empty_sequence():
1394
+ assert Product(x*y, (x, -oo, oo), (y, 1, 0)).doit() == 1
1395
+ assert Product(x*y, (y, 1, 0), (x, -oo, oo)).doit() == 1
1396
+ assert Sum(x, (x, -oo, oo), (y, 1, 0)).doit() == 0
1397
+ assert Sum(x, (y, 1, 0), (x, -oo, oo)).doit() == 0
1398
+
1399
+
1400
+ def test_issue_8016():
1401
+ k = Symbol('k', integer=True)
1402
+ n, m = symbols('n, m', integer=True, positive=True)
1403
+ s = Sum(binomial(m, k)*binomial(m, n - k)*(-1)**k, (k, 0, n))
1404
+ assert s.doit().simplify() == \
1405
+ cos(pi*n/2)*gamma(m + 1)/gamma(n/2 + 1)/gamma(m - n/2 + 1)
1406
+
1407
+
1408
+ def test_issue_14313():
1409
+ assert Sum(S.Half**floor(n/2), (n, 1, oo)).is_convergent()
1410
+
1411
+
1412
+ def test_issue_14563():
1413
+ # The assertion was failing due to no assumptions methods in Sums and Product
1414
+ assert 1 % Sum(1, (x, 0, 1)) == 1
1415
+
1416
+
1417
+ def test_issue_16735():
1418
+ assert Sum(5**n/gamma(n+1), (n, 1, oo)).is_convergent() is S.true
1419
+
1420
+
1421
+ def test_issue_14871():
1422
+ assert Sum((Rational(1, 10))**n*rf(0, n)/factorial(n), (n, 0, oo)).rewrite(factorial).doit() == 1
1423
+
1424
+
1425
+ def test_issue_17165():
1426
+ n = symbols("n", integer=True)
1427
+ x = symbols('x')
1428
+ s = (x*Sum(x**n, (n, -1, oo)))
1429
+ ssimp = s.doit().simplify()
1430
+
1431
+ assert ssimp == Piecewise((-1/(x - 1), (x > -1) & (x < 1)),
1432
+ (x*Sum(x**n, (n, -1, oo)), True)), ssimp
1433
+ assert ssimp.simplify() == ssimp
1434
+
1435
+
1436
+ def test_issue_19379():
1437
+ assert Sum(factorial(n)/factorial(n + 2), (n, 1, oo)).is_convergent() is S.true
1438
+
1439
+
1440
+ def test_issue_20777():
1441
+ assert Sum(exp(x*sin(n/m)), (n, 1, m)).doit() == Sum(exp(x*sin(n/m)), (n, 1, m))
1442
+
1443
+
1444
+ def test__dummy_with_inherited_properties_concrete():
1445
+ x = Symbol('x')
1446
+
1447
+ from sympy.core.containers import Tuple
1448
+ d = _dummy_with_inherited_properties_concrete(Tuple(x, 0, 5))
1449
+ assert d.is_real
1450
+ assert d.is_integer
1451
+ assert d.is_nonnegative
1452
+ assert d.is_extended_nonnegative
1453
+
1454
+ d = _dummy_with_inherited_properties_concrete(Tuple(x, 1, 9))
1455
+ assert d.is_real
1456
+ assert d.is_integer
1457
+ assert d.is_positive
1458
+ assert d.is_odd is None
1459
+
1460
+ d = _dummy_with_inherited_properties_concrete(Tuple(x, -5, 5))
1461
+ assert d.is_real
1462
+ assert d.is_integer
1463
+ assert d.is_positive is None
1464
+ assert d.is_extended_nonnegative is None
1465
+ assert d.is_odd is None
1466
+
1467
+ d = _dummy_with_inherited_properties_concrete(Tuple(x, -1.5, 1.5))
1468
+ assert d.is_real
1469
+ assert d.is_integer is None
1470
+ assert d.is_positive is None
1471
+ assert d.is_extended_nonnegative is None
1472
+
1473
+ N = Symbol('N', integer=True, positive=True)
1474
+ d = _dummy_with_inherited_properties_concrete(Tuple(x, 2, N))
1475
+ assert d.is_real
1476
+ assert d.is_positive
1477
+ assert d.is_integer
1478
+
1479
+ # Return None if no assumptions are added
1480
+ N = Symbol('N', integer=True, positive=True)
1481
+ d = _dummy_with_inherited_properties_concrete(Tuple(N, 2, 4))
1482
+ assert d is None
1483
+
1484
+ x = Symbol('x', negative=True)
1485
+ raises(InconsistentAssumptions,
1486
+ lambda: _dummy_with_inherited_properties_concrete(Tuple(x, 1, 5)))
1487
+
1488
+
1489
+ def test_matrixsymbol_summation_numerical_limits():
1490
+ A = MatrixSymbol('A', 3, 3)
1491
+ n = Symbol('n', integer=True)
1492
+
1493
+ assert Sum(A**n, (n, 0, 2)).doit() == Identity(3) + A + A**2
1494
+ assert Sum(A, (n, 0, 2)).doit() == 3*A
1495
+ assert Sum(n*A, (n, 0, 2)).doit() == 3*A
1496
+
1497
+ B = Matrix([[0, n, 0], [-1, 0, 0], [0, 0, 2]])
1498
+ ans = Matrix([[0, 6, 0], [-4, 0, 0], [0, 0, 8]]) + 4*A
1499
+ assert Sum(A+B, (n, 0, 3)).doit() == ans
1500
+ ans = A*Matrix([[0, 6, 0], [-4, 0, 0], [0, 0, 8]])
1501
+ assert Sum(A*B, (n, 0, 3)).doit() == ans
1502
+
1503
+ ans = (A**2*Matrix([[-2, 0, 0], [0,-2, 0], [0, 0, 4]]) +
1504
+ A**3*Matrix([[0, -9, 0], [3, 0, 0], [0, 0, 8]]) +
1505
+ A*Matrix([[0, 1, 0], [-1, 0, 0], [0, 0, 2]]))
1506
+ assert Sum(A**n*B**n, (n, 1, 3)).doit() == ans
1507
+
1508
+
1509
+ def test_issue_21651():
1510
+ i = Symbol('i')
1511
+ a = Sum(floor(2*2**(-i)), (i, S.One, 2))
1512
+ assert a.doit() == S.One
1513
+
1514
+
1515
+ @XFAIL
1516
+ def test_matrixsymbol_summation_symbolic_limits():
1517
+ N = Symbol('N', integer=True, positive=True)
1518
+
1519
+ A = MatrixSymbol('A', 3, 3)
1520
+ n = Symbol('n', integer=True)
1521
+ assert Sum(A, (n, 0, N)).doit() == (N+1)*A
1522
+ assert Sum(n*A, (n, 0, N)).doit() == (N**2/2+N/2)*A
1523
+
1524
+
1525
+ def test_summation_by_residues():
1526
+ x = Symbol('x')
1527
+
1528
+ # Examples from Nakhle H. Asmar, Loukas Grafakos,
1529
+ # Complex Analysis with Applications
1530
+ assert eval_sum_residue(1 / (x**2 + 1), (x, -oo, oo)) == pi/tanh(pi)
1531
+ assert eval_sum_residue(1 / x**6, (x, S(1), oo)) == pi**6/945
1532
+ assert eval_sum_residue(1 / (x**2 + 9), (x, -oo, oo)) == pi/(3*tanh(3*pi))
1533
+ assert eval_sum_residue(1 / (x**2 + 1)**2, (x, -oo, oo)).cancel() == \
1534
+ (-pi**2*tanh(pi)**2 + pi*tanh(pi) + pi**2)/(2*tanh(pi)**2)
1535
+ assert eval_sum_residue(x**2 / (x**2 + 1)**2, (x, -oo, oo)).cancel() == \
1536
+ (-pi**2 + pi*tanh(pi) + pi**2*tanh(pi)**2)/(2*tanh(pi)**2)
1537
+ assert eval_sum_residue(1 / (4*x**2 - 1), (x, -oo, oo)) == 0
1538
+ assert eval_sum_residue(x**2 / (x**2 - S(1)/4)**2, (x, -oo, oo)) == pi**2/2
1539
+ assert eval_sum_residue(1 / (4*x**2 - 1)**2, (x, -oo, oo)) == pi**2/8
1540
+ assert eval_sum_residue(1 / ((x - S(1)/2)**2 + 1), (x, -oo, oo)) == pi*tanh(pi)
1541
+ assert eval_sum_residue(1 / x**2, (x, S(1), oo)) == pi**2/6
1542
+ assert eval_sum_residue(1 / x**4, (x, S(1), oo)) == pi**4/90
1543
+ assert eval_sum_residue(1 / x**2 / (x**2 + 4), (x, S(1), oo)) == \
1544
+ -pi*(-pi/12 - 1/(16*pi) + 1/(8*tanh(2*pi)))/2
1545
+
1546
+ # Some examples made from 1 / (x**2 + 1)
1547
+ assert eval_sum_residue(1 / (x**2 + 1), (x, S(0), oo)) == \
1548
+ S(1)/2 + pi/(2*tanh(pi))
1549
+ assert eval_sum_residue(1 / (x**2 + 1), (x, S(1), oo)) == \
1550
+ -S(1)/2 + pi/(2*tanh(pi))
1551
+ assert eval_sum_residue(1 / (x**2 + 1), (x, S(-1), oo)) == \
1552
+ 1 + pi/(2*tanh(pi))
1553
+ assert eval_sum_residue((-1)**x / (x**2 + 1), (x, -oo, oo)) == \
1554
+ pi/sinh(pi)
1555
+ assert eval_sum_residue((-1)**x / (x**2 + 1), (x, S(0), oo)) == \
1556
+ pi/(2*sinh(pi)) + S(1)/2
1557
+ assert eval_sum_residue((-1)**x / (x**2 + 1), (x, S(1), oo)) == \
1558
+ -S(1)/2 + pi/(2*sinh(pi))
1559
+ assert eval_sum_residue((-1)**x / (x**2 + 1), (x, S(-1), oo)) == \
1560
+ pi/(2*sinh(pi))
1561
+
1562
+ # Some examples made from shifting of 1 / (x**2 + 1)
1563
+ assert eval_sum_residue(1 / (x**2 + 2*x + 2), (x, S(-1), oo)) == S(1)/2 + pi/(2*tanh(pi))
1564
+ assert eval_sum_residue(1 / (x**2 + 4*x + 5), (x, S(-2), oo)) == S(1)/2 + pi/(2*tanh(pi))
1565
+ assert eval_sum_residue(1 / (x**2 - 2*x + 2), (x, S(1), oo)) == S(1)/2 + pi/(2*tanh(pi))
1566
+ assert eval_sum_residue(1 / (x**2 - 4*x + 5), (x, S(2), oo)) == S(1)/2 + pi/(2*tanh(pi))
1567
+ assert eval_sum_residue((-1)**x * -1 / (x**2 + 2*x + 2), (x, S(-1), oo)) == S(1)/2 + pi/(2*sinh(pi))
1568
+ assert eval_sum_residue((-1)**x * -1 / (x**2 -2*x + 2), (x, S(1), oo)) == S(1)/2 + pi/(2*sinh(pi))
1569
+
1570
+ # Some examples made from 1 / x**2
1571
+ assert eval_sum_residue(1 / x**2, (x, S(2), oo)) == -1 + pi**2/6
1572
+ assert eval_sum_residue(1 / x**2, (x, S(3), oo)) == -S(5)/4 + pi**2/6
1573
+ assert eval_sum_residue((-1)**x / x**2, (x, S(1), oo)) == -pi**2/12
1574
+ assert eval_sum_residue((-1)**x / x**2, (x, S(2), oo)) == 1 - pi**2/12
1575
+
1576
+
1577
+ @slow
1578
+ def test_summation_by_residues_failing():
1579
+ x = Symbol('x')
1580
+
1581
+ # Failing because of the bug in residue computation
1582
+ assert eval_sum_residue(x**2 / (x**4 + 1), (x, S(1), oo))
1583
+ assert eval_sum_residue(1 / ((x - 1)*(x - 2) + 1), (x, -oo, oo)) != 0
1584
+
1585
+
1586
+ def test_process_limits():
1587
+ from sympy.concrete.expr_with_limits import _process_limits
1588
+
1589
+ # these should be (x, Range(3)) not Range(3)
1590
+ raises(ValueError, lambda: _process_limits(
1591
+ Range(3), discrete=True))
1592
+ raises(ValueError, lambda: _process_limits(
1593
+ Range(3), discrete=False))
1594
+ # these should be (x, union) not union
1595
+ # (but then we would get a TypeError because we don't
1596
+ # handle non-contiguous sets: see below use of `union`)
1597
+ union = Or(x < 1, x > 3).as_set()
1598
+ raises(ValueError, lambda: _process_limits(
1599
+ union, discrete=True))
1600
+ raises(ValueError, lambda: _process_limits(
1601
+ union, discrete=False))
1602
+
1603
+ # error not triggered if not needed
1604
+ assert _process_limits((x, 1, 2)) == ([(x, 1, 2)], 1)
1605
+
1606
+ # this equivalence is used to detect Reals in _process_limits
1607
+ assert isinstance(S.Reals, Interval)
1608
+
1609
+ C = Integral # continuous limits
1610
+ assert C(x, x >= 5) == C(x, (x, 5, oo))
1611
+ assert C(x, x < 3) == C(x, (x, -oo, 3))
1612
+ ans = C(x, (x, 0, 3))
1613
+ assert C(x, And(x >= 0, x < 3)) == ans
1614
+ assert C(x, (x, Interval.Ropen(0, 3))) == ans
1615
+ raises(TypeError, lambda: C(x, (x, Range(3))))
1616
+
1617
+ # discrete limits
1618
+ for D in (Sum, Product):
1619
+ r, ans = Range(3, 10, 2), D(2*x + 3, (x, 0, 3))
1620
+ assert D(x, (x, r)) == ans
1621
+ assert D(x, (x, r.reversed)) == ans
1622
+ r, ans = Range(3, oo, 2), D(2*x + 3, (x, 0, oo))
1623
+ assert D(x, (x, r)) == ans
1624
+ assert D(x, (x, r.reversed)) == ans
1625
+ r, ans = Range(-oo, 5, 2), D(3 - 2*x, (x, 0, oo))
1626
+ assert D(x, (x, r)) == ans
1627
+ assert D(x, (x, r.reversed)) == ans
1628
+ raises(TypeError, lambda: D(x, x > 0))
1629
+ raises(ValueError, lambda: D(x, Interval(1, 3)))
1630
+ raises(NotImplementedError, lambda: D(x, (x, union)))
1631
+
1632
+
1633
+ def test_pr_22677():
1634
+ b = Symbol('b', integer=True, positive=True)
1635
+ assert Sum(1/x**2,(x, 0, b)).doit() == Sum(x**(-2), (x, 0, b))
1636
+ assert Sum(1/(x - b)**2,(x, 0, b-1)).doit() == Sum(
1637
+ (-b + x)**(-2), (x, 0, b - 1))
1638
+
1639
+
1640
+ def test_issue_23952():
1641
+ p, q = symbols("p q", real=True, nonnegative=True)
1642
+ k1, k2 = symbols("k1 k2", integer=True, nonnegative=True)
1643
+ n = Symbol("n", integer=True, positive=True)
1644
+ expr = Sum(abs(k1 - k2)*p**k1 *(1 - q)**(n - k2),
1645
+ (k1, 0, n), (k2, 0, n))
1646
+ assert expr.subs(p,0).subs(q,1).subs(n, 3).doit() == 3
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/common.py ADDED
@@ -0,0 +1,3263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A module contining deprecated matrix mixin classes.
3
+
4
+ The classes in this module are deprecated and will be removed in a future
5
+ release. They are kept here for backwards compatibility in case downstream
6
+ code was subclassing them.
7
+
8
+ Importing anything else from this module is deprecated so anything here
9
+ should either not be used or should be imported from somewhere else.
10
+ """
11
+
12
+ from collections import defaultdict
13
+ from collections.abc import Iterable
14
+ from inspect import isfunction
15
+ from functools import reduce
16
+
17
+ from sympy.assumptions.refine import refine
18
+ from sympy.core import SympifyError, Add
19
+ from sympy.core.basic import Atom
20
+ from sympy.core.decorators import call_highest_priority
21
+ from sympy.core.logic import fuzzy_and, FuzzyBool
22
+ from sympy.core.numbers import Integer
23
+ from sympy.core.mod import Mod
24
+ from sympy.core.singleton import S
25
+ from sympy.core.symbol import Symbol
26
+ from sympy.core.sympify import sympify
27
+ from sympy.functions.elementary.complexes import Abs, re, im
28
+ from sympy.utilities.exceptions import sympy_deprecation_warning
29
+ from .utilities import _dotprodsimp, _simplify
30
+ from sympy.polys.polytools import Poly
31
+ from sympy.utilities.iterables import flatten, is_sequence
32
+ from sympy.utilities.misc import as_int, filldedent
33
+ from sympy.tensor.array import NDimArray
34
+
35
+ from .utilities import _get_intermediate_simp_bool
36
+
37
+
38
+ # These exception types were previously defined in this module but were moved
39
+ # to exceptions.py. We reimport them here for backwards compatibility in case
40
+ # downstream code was importing them from here.
41
+ from .exceptions import ( # noqa: F401
42
+ MatrixError, ShapeError, NonSquareMatrixError, NonInvertibleMatrixError,
43
+ NonPositiveDefiniteMatrixError
44
+ )
45
+
46
+
47
+ _DEPRECATED_MIXINS = (
48
+ 'MatrixShaping',
49
+ 'MatrixSpecial',
50
+ 'MatrixProperties',
51
+ 'MatrixOperations',
52
+ 'MatrixArithmetic',
53
+ 'MatrixCommon',
54
+ 'MatrixDeterminant',
55
+ 'MatrixReductions',
56
+ 'MatrixSubspaces',
57
+ 'MatrixEigen',
58
+ 'MatrixCalculus',
59
+ 'MatrixDeprecated',
60
+ )
61
+
62
+
63
+ class _MatrixDeprecatedMeta(type):
64
+
65
+ #
66
+ # Override the default __instancecheck__ implementation to ensure that
67
+ # e.g. isinstance(M, MatrixCommon) still works when M is one of the
68
+ # matrix classes. Matrix no longer inherits from MatrixCommon so
69
+ # isinstance(M, MatrixCommon) would now return False by default.
70
+ #
71
+ # There were lots of places in the codebase where this was being done
72
+ # so it seems likely that downstream code may be doing it too. All use
73
+ # of these mixins is deprecated though so we give a deprecation warning
74
+ # unconditionally if they are being used with isinstance.
75
+ #
76
+ # Any code seeing this deprecation warning should be changed to use
77
+ # isinstance(M, MatrixBase) instead which also works in previous versions
78
+ # of SymPy.
79
+ #
80
+
81
+ def __instancecheck__(cls, instance):
82
+
83
+ sympy_deprecation_warning(
84
+ f"""
85
+ Checking whether an object is an instance of {cls.__name__} is
86
+ deprecated.
87
+
88
+ Use `isinstance(obj, Matrix)` instead of `isinstance(obj, {cls.__name__})`.
89
+ """,
90
+ deprecated_since_version="1.13",
91
+ active_deprecations_target="deprecated-matrix-mixins",
92
+ stacklevel=3,
93
+ )
94
+
95
+ from sympy.matrices.matrixbase import MatrixBase
96
+ from sympy.matrices.matrices import (
97
+ MatrixDeterminant,
98
+ MatrixReductions,
99
+ MatrixSubspaces,
100
+ MatrixEigen,
101
+ MatrixCalculus,
102
+ MatrixDeprecated
103
+ )
104
+
105
+ all_mixins = (
106
+ MatrixRequired,
107
+ MatrixShaping,
108
+ MatrixSpecial,
109
+ MatrixProperties,
110
+ MatrixOperations,
111
+ MatrixArithmetic,
112
+ MatrixCommon,
113
+ MatrixDeterminant,
114
+ MatrixReductions,
115
+ MatrixSubspaces,
116
+ MatrixEigen,
117
+ MatrixCalculus,
118
+ MatrixDeprecated
119
+ )
120
+
121
+ if cls in all_mixins and isinstance(instance, MatrixBase):
122
+ return True
123
+ else:
124
+ return super().__instancecheck__(instance)
125
+
126
+
127
+ class MatrixRequired(metaclass=_MatrixDeprecatedMeta):
128
+ """Deprecated mixin class for making matrix classes."""
129
+
130
+ rows = None # type: int
131
+ cols = None # type: int
132
+ _simplify = None
133
+
134
+ def __init_subclass__(cls, **kwargs):
135
+
136
+ # Warn if any downstream code is subclassing this class or any of the
137
+ # deprecated mixin classes that are all ultimately subclasses of this
138
+ # class.
139
+ #
140
+ # We don't want to warn about the deprecated mixins themselves being
141
+ # created, but only about them being used as mixins by downstream code.
142
+ # Otherwise just importing this module would trigger a warning.
143
+ # Ultimately the whole module should be deprecated and removed but for
144
+ # SymPy 1.13 it is premature to do that given that this module was the
145
+ # main way to import matrix exception types in all previous versions.
146
+
147
+ if cls.__name__ not in _DEPRECATED_MIXINS:
148
+ sympy_deprecation_warning(
149
+ f"""
150
+ Inheriting from the Matrix mixin classes is deprecated.
151
+
152
+ The class {cls.__name__} is subclassing a deprecated mixin.
153
+ """,
154
+ deprecated_since_version="1.13",
155
+ active_deprecations_target="deprecated-matrix-mixins",
156
+ stacklevel=3,
157
+ )
158
+
159
+ super().__init_subclass__(**kwargs)
160
+
161
+ @classmethod
162
+ def _new(cls, *args, **kwargs):
163
+ """`_new` must, at minimum, be callable as
164
+ `_new(rows, cols, mat) where mat is a flat list of the
165
+ elements of the matrix."""
166
+ raise NotImplementedError("Subclasses must implement this.")
167
+
168
+ def __eq__(self, other):
169
+ raise NotImplementedError("Subclasses must implement this.")
170
+
171
+ def __getitem__(self, key):
172
+ """Implementations of __getitem__ should accept ints, in which
173
+ case the matrix is indexed as a flat list, tuples (i,j) in which
174
+ case the (i,j) entry is returned, slices, or mixed tuples (a,b)
175
+ where a and b are any combination of slices and integers."""
176
+ raise NotImplementedError("Subclasses must implement this.")
177
+
178
+ def __len__(self):
179
+ """The total number of entries in the matrix."""
180
+ raise NotImplementedError("Subclasses must implement this.")
181
+
182
+ @property
183
+ def shape(self):
184
+ raise NotImplementedError("Subclasses must implement this.")
185
+
186
+
187
+ class MatrixShaping(MatrixRequired):
188
+ """Provides basic matrix shaping and extracting of submatrices"""
189
+
190
+ def _eval_col_del(self, col):
191
+ def entry(i, j):
192
+ return self[i, j] if j < col else self[i, j + 1]
193
+ return self._new(self.rows, self.cols - 1, entry)
194
+
195
+ def _eval_col_insert(self, pos, other):
196
+
197
+ def entry(i, j):
198
+ if j < pos:
199
+ return self[i, j]
200
+ elif pos <= j < pos + other.cols:
201
+ return other[i, j - pos]
202
+ return self[i, j - other.cols]
203
+
204
+ return self._new(self.rows, self.cols + other.cols, entry)
205
+
206
+ def _eval_col_join(self, other):
207
+ rows = self.rows
208
+
209
+ def entry(i, j):
210
+ if i < rows:
211
+ return self[i, j]
212
+ return other[i - rows, j]
213
+
214
+ return classof(self, other)._new(self.rows + other.rows, self.cols,
215
+ entry)
216
+
217
+ def _eval_extract(self, rowsList, colsList):
218
+ mat = list(self)
219
+ cols = self.cols
220
+ indices = (i * cols + j for i in rowsList for j in colsList)
221
+ return self._new(len(rowsList), len(colsList),
222
+ [mat[i] for i in indices])
223
+
224
+ def _eval_get_diag_blocks(self):
225
+ sub_blocks = []
226
+
227
+ def recurse_sub_blocks(M):
228
+ i = 1
229
+ while i <= M.shape[0]:
230
+ if i == 1:
231
+ to_the_right = M[0, i:]
232
+ to_the_bottom = M[i:, 0]
233
+ else:
234
+ to_the_right = M[:i, i:]
235
+ to_the_bottom = M[i:, :i]
236
+ if any(to_the_right) or any(to_the_bottom):
237
+ i += 1
238
+ continue
239
+ else:
240
+ sub_blocks.append(M[:i, :i])
241
+ if M.shape == M[:i, :i].shape:
242
+ return
243
+ else:
244
+ recurse_sub_blocks(M[i:, i:])
245
+ return
246
+
247
+ recurse_sub_blocks(self)
248
+ return sub_blocks
249
+
250
+ def _eval_row_del(self, row):
251
+ def entry(i, j):
252
+ return self[i, j] if i < row else self[i + 1, j]
253
+ return self._new(self.rows - 1, self.cols, entry)
254
+
255
+ def _eval_row_insert(self, pos, other):
256
+ entries = list(self)
257
+ insert_pos = pos * self.cols
258
+ entries[insert_pos:insert_pos] = list(other)
259
+ return self._new(self.rows + other.rows, self.cols, entries)
260
+
261
+ def _eval_row_join(self, other):
262
+ cols = self.cols
263
+
264
+ def entry(i, j):
265
+ if j < cols:
266
+ return self[i, j]
267
+ return other[i, j - cols]
268
+
269
+ return classof(self, other)._new(self.rows, self.cols + other.cols,
270
+ entry)
271
+
272
+ def _eval_tolist(self):
273
+ return [list(self[i,:]) for i in range(self.rows)]
274
+
275
+ def _eval_todok(self):
276
+ dok = {}
277
+ rows, cols = self.shape
278
+ for i in range(rows):
279
+ for j in range(cols):
280
+ val = self[i, j]
281
+ if val != self.zero:
282
+ dok[i, j] = val
283
+ return dok
284
+
285
+ def _eval_vec(self):
286
+ rows = self.rows
287
+
288
+ def entry(n, _):
289
+ # we want to read off the columns first
290
+ j = n // rows
291
+ i = n - j * rows
292
+ return self[i, j]
293
+
294
+ return self._new(len(self), 1, entry)
295
+
296
+ def _eval_vech(self, diagonal):
297
+ c = self.cols
298
+ v = []
299
+ if diagonal:
300
+ for j in range(c):
301
+ for i in range(j, c):
302
+ v.append(self[i, j])
303
+ else:
304
+ for j in range(c):
305
+ for i in range(j + 1, c):
306
+ v.append(self[i, j])
307
+ return self._new(len(v), 1, v)
308
+
309
+ def col_del(self, col):
310
+ """Delete the specified column."""
311
+ if col < 0:
312
+ col += self.cols
313
+ if not 0 <= col < self.cols:
314
+ raise IndexError("Column {} is out of range.".format(col))
315
+ return self._eval_col_del(col)
316
+
317
+ def col_insert(self, pos, other):
318
+ """Insert one or more columns at the given column position.
319
+
320
+ Examples
321
+ ========
322
+
323
+ >>> from sympy import zeros, ones
324
+ >>> M = zeros(3)
325
+ >>> V = ones(3, 1)
326
+ >>> M.col_insert(1, V)
327
+ Matrix([
328
+ [0, 1, 0, 0],
329
+ [0, 1, 0, 0],
330
+ [0, 1, 0, 0]])
331
+
332
+ See Also
333
+ ========
334
+
335
+ col
336
+ row_insert
337
+ """
338
+ # Allows you to build a matrix even if it is null matrix
339
+ if not self:
340
+ return type(self)(other)
341
+
342
+ pos = as_int(pos)
343
+
344
+ if pos < 0:
345
+ pos = self.cols + pos
346
+ if pos < 0:
347
+ pos = 0
348
+ elif pos > self.cols:
349
+ pos = self.cols
350
+
351
+ if self.rows != other.rows:
352
+ raise ShapeError(
353
+ "The matrices have incompatible number of rows ({} and {})"
354
+ .format(self.rows, other.rows))
355
+
356
+ return self._eval_col_insert(pos, other)
357
+
358
+ def col_join(self, other):
359
+ """Concatenates two matrices along self's last and other's first row.
360
+
361
+ Examples
362
+ ========
363
+
364
+ >>> from sympy import zeros, ones
365
+ >>> M = zeros(3)
366
+ >>> V = ones(1, 3)
367
+ >>> M.col_join(V)
368
+ Matrix([
369
+ [0, 0, 0],
370
+ [0, 0, 0],
371
+ [0, 0, 0],
372
+ [1, 1, 1]])
373
+
374
+ See Also
375
+ ========
376
+
377
+ col
378
+ row_join
379
+ """
380
+ # A null matrix can always be stacked (see #10770)
381
+ if self.rows == 0 and self.cols != other.cols:
382
+ return self._new(0, other.cols, []).col_join(other)
383
+
384
+ if self.cols != other.cols:
385
+ raise ShapeError(
386
+ "The matrices have incompatible number of columns ({} and {})"
387
+ .format(self.cols, other.cols))
388
+ return self._eval_col_join(other)
389
+
390
+ def col(self, j):
391
+ """Elementary column selector.
392
+
393
+ Examples
394
+ ========
395
+
396
+ >>> from sympy import eye
397
+ >>> eye(2).col(0)
398
+ Matrix([
399
+ [1],
400
+ [0]])
401
+
402
+ See Also
403
+ ========
404
+
405
+ row
406
+ col_del
407
+ col_join
408
+ col_insert
409
+ """
410
+ return self[:, j]
411
+
412
+ def extract(self, rowsList, colsList):
413
+ r"""Return a submatrix by specifying a list of rows and columns.
414
+ Negative indices can be given. All indices must be in the range
415
+ $-n \le i < n$ where $n$ is the number of rows or columns.
416
+
417
+ Examples
418
+ ========
419
+
420
+ >>> from sympy import Matrix
421
+ >>> m = Matrix(4, 3, range(12))
422
+ >>> m
423
+ Matrix([
424
+ [0, 1, 2],
425
+ [3, 4, 5],
426
+ [6, 7, 8],
427
+ [9, 10, 11]])
428
+ >>> m.extract([0, 1, 3], [0, 1])
429
+ Matrix([
430
+ [0, 1],
431
+ [3, 4],
432
+ [9, 10]])
433
+
434
+ Rows or columns can be repeated:
435
+
436
+ >>> m.extract([0, 0, 1], [-1])
437
+ Matrix([
438
+ [2],
439
+ [2],
440
+ [5]])
441
+
442
+ Every other row can be taken by using range to provide the indices:
443
+
444
+ >>> m.extract(range(0, m.rows, 2), [-1])
445
+ Matrix([
446
+ [2],
447
+ [8]])
448
+
449
+ RowsList or colsList can also be a list of booleans, in which case
450
+ the rows or columns corresponding to the True values will be selected:
451
+
452
+ >>> m.extract([0, 1, 2, 3], [True, False, True])
453
+ Matrix([
454
+ [0, 2],
455
+ [3, 5],
456
+ [6, 8],
457
+ [9, 11]])
458
+ """
459
+
460
+ if not is_sequence(rowsList) or not is_sequence(colsList):
461
+ raise TypeError("rowsList and colsList must be iterable")
462
+ # ensure rowsList and colsList are lists of integers
463
+ if rowsList and all(isinstance(i, bool) for i in rowsList):
464
+ rowsList = [index for index, item in enumerate(rowsList) if item]
465
+ if colsList and all(isinstance(i, bool) for i in colsList):
466
+ colsList = [index for index, item in enumerate(colsList) if item]
467
+
468
+ # ensure everything is in range
469
+ rowsList = [a2idx(k, self.rows) for k in rowsList]
470
+ colsList = [a2idx(k, self.cols) for k in colsList]
471
+
472
+ return self._eval_extract(rowsList, colsList)
473
+
474
+ def get_diag_blocks(self):
475
+ """Obtains the square sub-matrices on the main diagonal of a square matrix.
476
+
477
+ Useful for inverting symbolic matrices or solving systems of
478
+ linear equations which may be decoupled by having a block diagonal
479
+ structure.
480
+
481
+ Examples
482
+ ========
483
+
484
+ >>> from sympy import Matrix
485
+ >>> from sympy.abc import x, y, z
486
+ >>> A = Matrix([[1, 3, 0, 0], [y, z*z, 0, 0], [0, 0, x, 0], [0, 0, 0, 0]])
487
+ >>> a1, a2, a3 = A.get_diag_blocks()
488
+ >>> a1
489
+ Matrix([
490
+ [1, 3],
491
+ [y, z**2]])
492
+ >>> a2
493
+ Matrix([[x]])
494
+ >>> a3
495
+ Matrix([[0]])
496
+
497
+ """
498
+ return self._eval_get_diag_blocks()
499
+
500
+ @classmethod
501
+ def hstack(cls, *args):
502
+ """Return a matrix formed by joining args horizontally (i.e.
503
+ by repeated application of row_join).
504
+
505
+ Examples
506
+ ========
507
+
508
+ >>> from sympy import Matrix, eye
509
+ >>> Matrix.hstack(eye(2), 2*eye(2))
510
+ Matrix([
511
+ [1, 0, 2, 0],
512
+ [0, 1, 0, 2]])
513
+ """
514
+ if len(args) == 0:
515
+ return cls._new()
516
+
517
+ kls = type(args[0])
518
+ return reduce(kls.row_join, args)
519
+
520
+ def reshape(self, rows, cols):
521
+ """Reshape the matrix. Total number of elements must remain the same.
522
+
523
+ Examples
524
+ ========
525
+
526
+ >>> from sympy import Matrix
527
+ >>> m = Matrix(2, 3, lambda i, j: 1)
528
+ >>> m
529
+ Matrix([
530
+ [1, 1, 1],
531
+ [1, 1, 1]])
532
+ >>> m.reshape(1, 6)
533
+ Matrix([[1, 1, 1, 1, 1, 1]])
534
+ >>> m.reshape(3, 2)
535
+ Matrix([
536
+ [1, 1],
537
+ [1, 1],
538
+ [1, 1]])
539
+
540
+ """
541
+ if self.rows * self.cols != rows * cols:
542
+ raise ValueError("Invalid reshape parameters %d %d" % (rows, cols))
543
+ return self._new(rows, cols, lambda i, j: self[i * cols + j])
544
+
545
+ def row_del(self, row):
546
+ """Delete the specified row."""
547
+ if row < 0:
548
+ row += self.rows
549
+ if not 0 <= row < self.rows:
550
+ raise IndexError("Row {} is out of range.".format(row))
551
+
552
+ return self._eval_row_del(row)
553
+
554
+ def row_insert(self, pos, other):
555
+ """Insert one or more rows at the given row position.
556
+
557
+ Examples
558
+ ========
559
+
560
+ >>> from sympy import zeros, ones
561
+ >>> M = zeros(3)
562
+ >>> V = ones(1, 3)
563
+ >>> M.row_insert(1, V)
564
+ Matrix([
565
+ [0, 0, 0],
566
+ [1, 1, 1],
567
+ [0, 0, 0],
568
+ [0, 0, 0]])
569
+
570
+ See Also
571
+ ========
572
+
573
+ row
574
+ col_insert
575
+ """
576
+ # Allows you to build a matrix even if it is null matrix
577
+ if not self:
578
+ return self._new(other)
579
+
580
+ pos = as_int(pos)
581
+
582
+ if pos < 0:
583
+ pos = self.rows + pos
584
+ if pos < 0:
585
+ pos = 0
586
+ elif pos > self.rows:
587
+ pos = self.rows
588
+
589
+ if self.cols != other.cols:
590
+ raise ShapeError(
591
+ "The matrices have incompatible number of columns ({} and {})"
592
+ .format(self.cols, other.cols))
593
+
594
+ return self._eval_row_insert(pos, other)
595
+
596
+ def row_join(self, other):
597
+ """Concatenates two matrices along self's last and rhs's first column
598
+
599
+ Examples
600
+ ========
601
+
602
+ >>> from sympy import zeros, ones
603
+ >>> M = zeros(3)
604
+ >>> V = ones(3, 1)
605
+ >>> M.row_join(V)
606
+ Matrix([
607
+ [0, 0, 0, 1],
608
+ [0, 0, 0, 1],
609
+ [0, 0, 0, 1]])
610
+
611
+ See Also
612
+ ========
613
+
614
+ row
615
+ col_join
616
+ """
617
+ # A null matrix can always be stacked (see #10770)
618
+ if self.cols == 0 and self.rows != other.rows:
619
+ return self._new(other.rows, 0, []).row_join(other)
620
+
621
+ if self.rows != other.rows:
622
+ raise ShapeError(
623
+ "The matrices have incompatible number of rows ({} and {})"
624
+ .format(self.rows, other.rows))
625
+ return self._eval_row_join(other)
626
+
627
+ def diagonal(self, k=0):
628
+ """Returns the kth diagonal of self. The main diagonal
629
+ corresponds to `k=0`; diagonals above and below correspond to
630
+ `k > 0` and `k < 0`, respectively. The values of `self[i, j]`
631
+ for which `j - i = k`, are returned in order of increasing
632
+ `i + j`, starting with `i + j = |k|`.
633
+
634
+ Examples
635
+ ========
636
+
637
+ >>> from sympy import Matrix
638
+ >>> m = Matrix(3, 3, lambda i, j: j - i); m
639
+ Matrix([
640
+ [ 0, 1, 2],
641
+ [-1, 0, 1],
642
+ [-2, -1, 0]])
643
+ >>> _.diagonal()
644
+ Matrix([[0, 0, 0]])
645
+ >>> m.diagonal(1)
646
+ Matrix([[1, 1]])
647
+ >>> m.diagonal(-2)
648
+ Matrix([[-2]])
649
+
650
+ Even though the diagonal is returned as a Matrix, the element
651
+ retrieval can be done with a single index:
652
+
653
+ >>> Matrix.diag(1, 2, 3).diagonal()[1] # instead of [0, 1]
654
+ 2
655
+
656
+ See Also
657
+ ========
658
+
659
+ diag
660
+ """
661
+ rv = []
662
+ k = as_int(k)
663
+ r = 0 if k > 0 else -k
664
+ c = 0 if r else k
665
+ while True:
666
+ if r == self.rows or c == self.cols:
667
+ break
668
+ rv.append(self[r, c])
669
+ r += 1
670
+ c += 1
671
+ if not rv:
672
+ raise ValueError(filldedent('''
673
+ The %s diagonal is out of range [%s, %s]''' % (
674
+ k, 1 - self.rows, self.cols - 1)))
675
+ return self._new(1, len(rv), rv)
676
+
677
+ def row(self, i):
678
+ """Elementary row selector.
679
+
680
+ Examples
681
+ ========
682
+
683
+ >>> from sympy import eye
684
+ >>> eye(2).row(0)
685
+ Matrix([[1, 0]])
686
+
687
+ See Also
688
+ ========
689
+
690
+ col
691
+ row_del
692
+ row_join
693
+ row_insert
694
+ """
695
+ return self[i, :]
696
+
697
+ @property
698
+ def shape(self):
699
+ """The shape (dimensions) of the matrix as the 2-tuple (rows, cols).
700
+
701
+ Examples
702
+ ========
703
+
704
+ >>> from sympy import zeros
705
+ >>> M = zeros(2, 3)
706
+ >>> M.shape
707
+ (2, 3)
708
+ >>> M.rows
709
+ 2
710
+ >>> M.cols
711
+ 3
712
+ """
713
+ return (self.rows, self.cols)
714
+
715
+ def todok(self):
716
+ """Return the matrix as dictionary of keys.
717
+
718
+ Examples
719
+ ========
720
+
721
+ >>> from sympy import Matrix
722
+ >>> M = Matrix.eye(3)
723
+ >>> M.todok()
724
+ {(0, 0): 1, (1, 1): 1, (2, 2): 1}
725
+ """
726
+ return self._eval_todok()
727
+
728
+ def tolist(self):
729
+ """Return the Matrix as a nested Python list.
730
+
731
+ Examples
732
+ ========
733
+
734
+ >>> from sympy import Matrix, ones
735
+ >>> m = Matrix(3, 3, range(9))
736
+ >>> m
737
+ Matrix([
738
+ [0, 1, 2],
739
+ [3, 4, 5],
740
+ [6, 7, 8]])
741
+ >>> m.tolist()
742
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
743
+ >>> ones(3, 0).tolist()
744
+ [[], [], []]
745
+
746
+ When there are no rows then it will not be possible to tell how
747
+ many columns were in the original matrix:
748
+
749
+ >>> ones(0, 3).tolist()
750
+ []
751
+
752
+ """
753
+ if not self.rows:
754
+ return []
755
+ if not self.cols:
756
+ return [[] for i in range(self.rows)]
757
+ return self._eval_tolist()
758
+
759
+ def todod(M):
760
+ """Returns matrix as dict of dicts containing non-zero elements of the Matrix
761
+
762
+ Examples
763
+ ========
764
+
765
+ >>> from sympy import Matrix
766
+ >>> A = Matrix([[0, 1],[0, 3]])
767
+ >>> A
768
+ Matrix([
769
+ [0, 1],
770
+ [0, 3]])
771
+ >>> A.todod()
772
+ {0: {1: 1}, 1: {1: 3}}
773
+
774
+
775
+ """
776
+ rowsdict = {}
777
+ Mlol = M.tolist()
778
+ for i, Mi in enumerate(Mlol):
779
+ row = {j: Mij for j, Mij in enumerate(Mi) if Mij}
780
+ if row:
781
+ rowsdict[i] = row
782
+ return rowsdict
783
+
784
+ def vec(self):
785
+ """Return the Matrix converted into a one column matrix by stacking columns
786
+
787
+ Examples
788
+ ========
789
+
790
+ >>> from sympy import Matrix
791
+ >>> m=Matrix([[1, 3], [2, 4]])
792
+ >>> m
793
+ Matrix([
794
+ [1, 3],
795
+ [2, 4]])
796
+ >>> m.vec()
797
+ Matrix([
798
+ [1],
799
+ [2],
800
+ [3],
801
+ [4]])
802
+
803
+ See Also
804
+ ========
805
+
806
+ vech
807
+ """
808
+ return self._eval_vec()
809
+
810
+ def vech(self, diagonal=True, check_symmetry=True):
811
+ """Reshapes the matrix into a column vector by stacking the
812
+ elements in the lower triangle.
813
+
814
+ Parameters
815
+ ==========
816
+
817
+ diagonal : bool, optional
818
+ If ``True``, it includes the diagonal elements.
819
+
820
+ check_symmetry : bool, optional
821
+ If ``True``, it checks whether the matrix is symmetric.
822
+
823
+ Examples
824
+ ========
825
+
826
+ >>> from sympy import Matrix
827
+ >>> m=Matrix([[1, 2], [2, 3]])
828
+ >>> m
829
+ Matrix([
830
+ [1, 2],
831
+ [2, 3]])
832
+ >>> m.vech()
833
+ Matrix([
834
+ [1],
835
+ [2],
836
+ [3]])
837
+ >>> m.vech(diagonal=False)
838
+ Matrix([[2]])
839
+
840
+ Notes
841
+ =====
842
+
843
+ This should work for symmetric matrices and ``vech`` can
844
+ represent symmetric matrices in vector form with less size than
845
+ ``vec``.
846
+
847
+ See Also
848
+ ========
849
+
850
+ vec
851
+ """
852
+ if not self.is_square:
853
+ raise NonSquareMatrixError
854
+
855
+ if check_symmetry and not self.is_symmetric():
856
+ raise ValueError("The matrix is not symmetric.")
857
+
858
+ return self._eval_vech(diagonal)
859
+
860
+ @classmethod
861
+ def vstack(cls, *args):
862
+ """Return a matrix formed by joining args vertically (i.e.
863
+ by repeated application of col_join).
864
+
865
+ Examples
866
+ ========
867
+
868
+ >>> from sympy import Matrix, eye
869
+ >>> Matrix.vstack(eye(2), 2*eye(2))
870
+ Matrix([
871
+ [1, 0],
872
+ [0, 1],
873
+ [2, 0],
874
+ [0, 2]])
875
+ """
876
+ if len(args) == 0:
877
+ return cls._new()
878
+
879
+ kls = type(args[0])
880
+ return reduce(kls.col_join, args)
881
+
882
+
883
+ class MatrixSpecial(MatrixRequired):
884
+ """Construction of special matrices"""
885
+
886
+ @classmethod
887
+ def _eval_diag(cls, rows, cols, diag_dict):
888
+ """diag_dict is a defaultdict containing
889
+ all the entries of the diagonal matrix."""
890
+ def entry(i, j):
891
+ return diag_dict[(i, j)]
892
+ return cls._new(rows, cols, entry)
893
+
894
+ @classmethod
895
+ def _eval_eye(cls, rows, cols):
896
+ vals = [cls.zero]*(rows*cols)
897
+ vals[::cols+1] = [cls.one]*min(rows, cols)
898
+ return cls._new(rows, cols, vals, copy=False)
899
+
900
+ @classmethod
901
+ def _eval_jordan_block(cls, size: int, eigenvalue, band='upper'):
902
+ if band == 'lower':
903
+ def entry(i, j):
904
+ if i == j:
905
+ return eigenvalue
906
+ elif j + 1 == i:
907
+ return cls.one
908
+ return cls.zero
909
+ else:
910
+ def entry(i, j):
911
+ if i == j:
912
+ return eigenvalue
913
+ elif i + 1 == j:
914
+ return cls.one
915
+ return cls.zero
916
+ return cls._new(size, size, entry)
917
+
918
+ @classmethod
919
+ def _eval_ones(cls, rows, cols):
920
+ def entry(i, j):
921
+ return cls.one
922
+ return cls._new(rows, cols, entry)
923
+
924
+ @classmethod
925
+ def _eval_zeros(cls, rows, cols):
926
+ return cls._new(rows, cols, [cls.zero]*(rows*cols), copy=False)
927
+
928
+ @classmethod
929
+ def _eval_wilkinson(cls, n):
930
+ def entry(i, j):
931
+ return cls.one if i + 1 == j else cls.zero
932
+
933
+ D = cls._new(2*n + 1, 2*n + 1, entry)
934
+
935
+ wminus = cls.diag(list(range(-n, n + 1)), unpack=True) + D + D.T
936
+ wplus = abs(cls.diag(list(range(-n, n + 1)), unpack=True)) + D + D.T
937
+
938
+ return wminus, wplus
939
+
940
+ @classmethod
941
+ def diag(kls, *args, strict=False, unpack=True, rows=None, cols=None, **kwargs):
942
+ """Returns a matrix with the specified diagonal.
943
+ If matrices are passed, a block-diagonal matrix
944
+ is created (i.e. the "direct sum" of the matrices).
945
+
946
+ kwargs
947
+ ======
948
+
949
+ rows : rows of the resulting matrix; computed if
950
+ not given.
951
+
952
+ cols : columns of the resulting matrix; computed if
953
+ not given.
954
+
955
+ cls : class for the resulting matrix
956
+
957
+ unpack : bool which, when True (default), unpacks a single
958
+ sequence rather than interpreting it as a Matrix.
959
+
960
+ strict : bool which, when False (default), allows Matrices to
961
+ have variable-length rows.
962
+
963
+ Examples
964
+ ========
965
+
966
+ >>> from sympy import Matrix
967
+ >>> Matrix.diag(1, 2, 3)
968
+ Matrix([
969
+ [1, 0, 0],
970
+ [0, 2, 0],
971
+ [0, 0, 3]])
972
+
973
+ The current default is to unpack a single sequence. If this is
974
+ not desired, set `unpack=False` and it will be interpreted as
975
+ a matrix.
976
+
977
+ >>> Matrix.diag([1, 2, 3]) == Matrix.diag(1, 2, 3)
978
+ True
979
+
980
+ When more than one element is passed, each is interpreted as
981
+ something to put on the diagonal. Lists are converted to
982
+ matrices. Filling of the diagonal always continues from
983
+ the bottom right hand corner of the previous item: this
984
+ will create a block-diagonal matrix whether the matrices
985
+ are square or not.
986
+
987
+ >>> col = [1, 2, 3]
988
+ >>> row = [[4, 5]]
989
+ >>> Matrix.diag(col, row)
990
+ Matrix([
991
+ [1, 0, 0],
992
+ [2, 0, 0],
993
+ [3, 0, 0],
994
+ [0, 4, 5]])
995
+
996
+ When `unpack` is False, elements within a list need not all be
997
+ of the same length. Setting `strict` to True would raise a
998
+ ValueError for the following:
999
+
1000
+ >>> Matrix.diag([[1, 2, 3], [4, 5], [6]], unpack=False)
1001
+ Matrix([
1002
+ [1, 2, 3],
1003
+ [4, 5, 0],
1004
+ [6, 0, 0]])
1005
+
1006
+ The type of the returned matrix can be set with the ``cls``
1007
+ keyword.
1008
+
1009
+ >>> from sympy import ImmutableMatrix
1010
+ >>> from sympy.utilities.misc import func_name
1011
+ >>> func_name(Matrix.diag(1, cls=ImmutableMatrix))
1012
+ 'ImmutableDenseMatrix'
1013
+
1014
+ A zero dimension matrix can be used to position the start of
1015
+ the filling at the start of an arbitrary row or column:
1016
+
1017
+ >>> from sympy import ones
1018
+ >>> r2 = ones(0, 2)
1019
+ >>> Matrix.diag(r2, 1, 2)
1020
+ Matrix([
1021
+ [0, 0, 1, 0],
1022
+ [0, 0, 0, 2]])
1023
+
1024
+ See Also
1025
+ ========
1026
+ eye
1027
+ diagonal
1028
+ .dense.diag
1029
+ .expressions.blockmatrix.BlockMatrix
1030
+ .sparsetools.banded
1031
+ """
1032
+ from sympy.matrices.matrixbase import MatrixBase
1033
+ from sympy.matrices.dense import Matrix
1034
+ from sympy.matrices import SparseMatrix
1035
+ klass = kwargs.get('cls', kls)
1036
+ if unpack and len(args) == 1 and is_sequence(args[0]) and \
1037
+ not isinstance(args[0], MatrixBase):
1038
+ args = args[0]
1039
+
1040
+ # fill a default dict with the diagonal entries
1041
+ diag_entries = defaultdict(int)
1042
+ rmax = cmax = 0 # keep track of the biggest index seen
1043
+ for m in args:
1044
+ if isinstance(m, list):
1045
+ if strict:
1046
+ # if malformed, Matrix will raise an error
1047
+ _ = Matrix(m)
1048
+ r, c = _.shape
1049
+ m = _.tolist()
1050
+ else:
1051
+ r, c, smat = SparseMatrix._handle_creation_inputs(m)
1052
+ for (i, j), _ in smat.items():
1053
+ diag_entries[(i + rmax, j + cmax)] = _
1054
+ m = [] # to skip process below
1055
+ elif hasattr(m, 'shape'): # a Matrix
1056
+ # convert to list of lists
1057
+ r, c = m.shape
1058
+ m = m.tolist()
1059
+ else: # in this case, we're a single value
1060
+ diag_entries[(rmax, cmax)] = m
1061
+ rmax += 1
1062
+ cmax += 1
1063
+ continue
1064
+ # process list of lists
1065
+ for i, mi in enumerate(m):
1066
+ for j, _ in enumerate(mi):
1067
+ diag_entries[(i + rmax, j + cmax)] = _
1068
+ rmax += r
1069
+ cmax += c
1070
+ if rows is None:
1071
+ rows, cols = cols, rows
1072
+ if rows is None:
1073
+ rows, cols = rmax, cmax
1074
+ else:
1075
+ cols = rows if cols is None else cols
1076
+ if rows < rmax or cols < cmax:
1077
+ raise ValueError(filldedent('''
1078
+ The constructed matrix is {} x {} but a size of {} x {}
1079
+ was specified.'''.format(rmax, cmax, rows, cols)))
1080
+ return klass._eval_diag(rows, cols, diag_entries)
1081
+
1082
+ @classmethod
1083
+ def eye(kls, rows, cols=None, **kwargs):
1084
+ """Returns an identity matrix.
1085
+
1086
+ Parameters
1087
+ ==========
1088
+
1089
+ rows : rows of the matrix
1090
+ cols : cols of the matrix (if None, cols=rows)
1091
+
1092
+ kwargs
1093
+ ======
1094
+ cls : class of the returned matrix
1095
+ """
1096
+ if cols is None:
1097
+ cols = rows
1098
+ if rows < 0 or cols < 0:
1099
+ raise ValueError("Cannot create a {} x {} matrix. "
1100
+ "Both dimensions must be positive".format(rows, cols))
1101
+ klass = kwargs.get('cls', kls)
1102
+ rows, cols = as_int(rows), as_int(cols)
1103
+
1104
+ return klass._eval_eye(rows, cols)
1105
+
1106
+ @classmethod
1107
+ def jordan_block(kls, size=None, eigenvalue=None, *, band='upper', **kwargs):
1108
+ """Returns a Jordan block
1109
+
1110
+ Parameters
1111
+ ==========
1112
+
1113
+ size : Integer, optional
1114
+ Specifies the shape of the Jordan block matrix.
1115
+
1116
+ eigenvalue : Number or Symbol
1117
+ Specifies the value for the main diagonal of the matrix.
1118
+
1119
+ .. note::
1120
+ The keyword ``eigenval`` is also specified as an alias
1121
+ of this keyword, but it is not recommended to use.
1122
+
1123
+ We may deprecate the alias in later release.
1124
+
1125
+ band : 'upper' or 'lower', optional
1126
+ Specifies the position of the off-diagonal to put `1` s on.
1127
+
1128
+ cls : Matrix, optional
1129
+ Specifies the matrix class of the output form.
1130
+
1131
+ If it is not specified, the class type where the method is
1132
+ being executed on will be returned.
1133
+
1134
+ Returns
1135
+ =======
1136
+
1137
+ Matrix
1138
+ A Jordan block matrix.
1139
+
1140
+ Raises
1141
+ ======
1142
+
1143
+ ValueError
1144
+ If insufficient arguments are given for matrix size
1145
+ specification, or no eigenvalue is given.
1146
+
1147
+ Examples
1148
+ ========
1149
+
1150
+ Creating a default Jordan block:
1151
+
1152
+ >>> from sympy import Matrix
1153
+ >>> from sympy.abc import x
1154
+ >>> Matrix.jordan_block(4, x)
1155
+ Matrix([
1156
+ [x, 1, 0, 0],
1157
+ [0, x, 1, 0],
1158
+ [0, 0, x, 1],
1159
+ [0, 0, 0, x]])
1160
+
1161
+ Creating an alternative Jordan block matrix where `1` is on
1162
+ lower off-diagonal:
1163
+
1164
+ >>> Matrix.jordan_block(4, x, band='lower')
1165
+ Matrix([
1166
+ [x, 0, 0, 0],
1167
+ [1, x, 0, 0],
1168
+ [0, 1, x, 0],
1169
+ [0, 0, 1, x]])
1170
+
1171
+ Creating a Jordan block with keyword arguments
1172
+
1173
+ >>> Matrix.jordan_block(size=4, eigenvalue=x)
1174
+ Matrix([
1175
+ [x, 1, 0, 0],
1176
+ [0, x, 1, 0],
1177
+ [0, 0, x, 1],
1178
+ [0, 0, 0, x]])
1179
+
1180
+ References
1181
+ ==========
1182
+
1183
+ .. [1] https://en.wikipedia.org/wiki/Jordan_matrix
1184
+ """
1185
+ klass = kwargs.pop('cls', kls)
1186
+
1187
+ eigenval = kwargs.get('eigenval', None)
1188
+ if eigenvalue is None and eigenval is None:
1189
+ raise ValueError("Must supply an eigenvalue")
1190
+ elif eigenvalue != eigenval and None not in (eigenval, eigenvalue):
1191
+ raise ValueError(
1192
+ "Inconsistent values are given: 'eigenval'={}, "
1193
+ "'eigenvalue'={}".format(eigenval, eigenvalue))
1194
+ else:
1195
+ if eigenval is not None:
1196
+ eigenvalue = eigenval
1197
+
1198
+ if size is None:
1199
+ raise ValueError("Must supply a matrix size")
1200
+
1201
+ size = as_int(size)
1202
+ return klass._eval_jordan_block(size, eigenvalue, band)
1203
+
1204
+ @classmethod
1205
+ def ones(kls, rows, cols=None, **kwargs):
1206
+ """Returns a matrix of ones.
1207
+
1208
+ Parameters
1209
+ ==========
1210
+
1211
+ rows : rows of the matrix
1212
+ cols : cols of the matrix (if None, cols=rows)
1213
+
1214
+ kwargs
1215
+ ======
1216
+ cls : class of the returned matrix
1217
+ """
1218
+ if cols is None:
1219
+ cols = rows
1220
+ klass = kwargs.get('cls', kls)
1221
+ rows, cols = as_int(rows), as_int(cols)
1222
+
1223
+ return klass._eval_ones(rows, cols)
1224
+
1225
+ @classmethod
1226
+ def zeros(kls, rows, cols=None, **kwargs):
1227
+ """Returns a matrix of zeros.
1228
+
1229
+ Parameters
1230
+ ==========
1231
+
1232
+ rows : rows of the matrix
1233
+ cols : cols of the matrix (if None, cols=rows)
1234
+
1235
+ kwargs
1236
+ ======
1237
+ cls : class of the returned matrix
1238
+ """
1239
+ if cols is None:
1240
+ cols = rows
1241
+ if rows < 0 or cols < 0:
1242
+ raise ValueError("Cannot create a {} x {} matrix. "
1243
+ "Both dimensions must be positive".format(rows, cols))
1244
+ klass = kwargs.get('cls', kls)
1245
+ rows, cols = as_int(rows), as_int(cols)
1246
+
1247
+ return klass._eval_zeros(rows, cols)
1248
+
1249
+ @classmethod
1250
+ def companion(kls, poly):
1251
+ """Returns a companion matrix of a polynomial.
1252
+
1253
+ Examples
1254
+ ========
1255
+
1256
+ >>> from sympy import Matrix, Poly, Symbol, symbols
1257
+ >>> x = Symbol('x')
1258
+ >>> c0, c1, c2, c3, c4 = symbols('c0:5')
1259
+ >>> p = Poly(c0 + c1*x + c2*x**2 + c3*x**3 + c4*x**4 + x**5, x)
1260
+ >>> Matrix.companion(p)
1261
+ Matrix([
1262
+ [0, 0, 0, 0, -c0],
1263
+ [1, 0, 0, 0, -c1],
1264
+ [0, 1, 0, 0, -c2],
1265
+ [0, 0, 1, 0, -c3],
1266
+ [0, 0, 0, 1, -c4]])
1267
+ """
1268
+ poly = kls._sympify(poly)
1269
+ if not isinstance(poly, Poly):
1270
+ raise ValueError("{} must be a Poly instance.".format(poly))
1271
+ if not poly.is_monic:
1272
+ raise ValueError("{} must be a monic polynomial.".format(poly))
1273
+ if not poly.is_univariate:
1274
+ raise ValueError(
1275
+ "{} must be a univariate polynomial.".format(poly))
1276
+
1277
+ size = poly.degree()
1278
+ if not size >= 1:
1279
+ raise ValueError(
1280
+ "{} must have degree not less than 1.".format(poly))
1281
+
1282
+ coeffs = poly.all_coeffs()
1283
+ def entry(i, j):
1284
+ if j == size - 1:
1285
+ return -coeffs[-1 - i]
1286
+ elif i == j + 1:
1287
+ return kls.one
1288
+ return kls.zero
1289
+ return kls._new(size, size, entry)
1290
+
1291
+
1292
+ @classmethod
1293
+ def wilkinson(kls, n, **kwargs):
1294
+ """Returns two square Wilkinson Matrix of size 2*n + 1
1295
+ $W_{2n + 1}^-, W_{2n + 1}^+ =$ Wilkinson(n)
1296
+
1297
+ Examples
1298
+ ========
1299
+
1300
+ >>> from sympy import Matrix
1301
+ >>> wminus, wplus = Matrix.wilkinson(3)
1302
+ >>> wminus
1303
+ Matrix([
1304
+ [-3, 1, 0, 0, 0, 0, 0],
1305
+ [ 1, -2, 1, 0, 0, 0, 0],
1306
+ [ 0, 1, -1, 1, 0, 0, 0],
1307
+ [ 0, 0, 1, 0, 1, 0, 0],
1308
+ [ 0, 0, 0, 1, 1, 1, 0],
1309
+ [ 0, 0, 0, 0, 1, 2, 1],
1310
+ [ 0, 0, 0, 0, 0, 1, 3]])
1311
+ >>> wplus
1312
+ Matrix([
1313
+ [3, 1, 0, 0, 0, 0, 0],
1314
+ [1, 2, 1, 0, 0, 0, 0],
1315
+ [0, 1, 1, 1, 0, 0, 0],
1316
+ [0, 0, 1, 0, 1, 0, 0],
1317
+ [0, 0, 0, 1, 1, 1, 0],
1318
+ [0, 0, 0, 0, 1, 2, 1],
1319
+ [0, 0, 0, 0, 0, 1, 3]])
1320
+
1321
+ References
1322
+ ==========
1323
+
1324
+ .. [1] https://blogs.mathworks.com/cleve/2013/04/15/wilkinsons-matrices-2/
1325
+ .. [2] J. H. Wilkinson, The Algebraic Eigenvalue Problem, Claredon Press, Oxford, 1965, 662 pp.
1326
+
1327
+ """
1328
+ klass = kwargs.get('cls', kls)
1329
+ n = as_int(n)
1330
+ return klass._eval_wilkinson(n)
1331
+
1332
+ class MatrixProperties(MatrixRequired):
1333
+ """Provides basic properties of a matrix."""
1334
+
1335
+ def _eval_atoms(self, *types):
1336
+ result = set()
1337
+ for i in self:
1338
+ result.update(i.atoms(*types))
1339
+ return result
1340
+
1341
+ def _eval_free_symbols(self):
1342
+ return set().union(*(i.free_symbols for i in self if i))
1343
+
1344
+ def _eval_has(self, *patterns):
1345
+ return any(a.has(*patterns) for a in self)
1346
+
1347
+ def _eval_is_anti_symmetric(self, simpfunc):
1348
+ if not all(simpfunc(self[i, j] + self[j, i]).is_zero for i in range(self.rows) for j in range(self.cols)):
1349
+ return False
1350
+ return True
1351
+
1352
+ def _eval_is_diagonal(self):
1353
+ for i in range(self.rows):
1354
+ for j in range(self.cols):
1355
+ if i != j and self[i, j]:
1356
+ return False
1357
+ return True
1358
+
1359
+ # _eval_is_hermitian is called by some general SymPy
1360
+ # routines and has a different *args signature. Make
1361
+ # sure the names don't clash by adding `_matrix_` in name.
1362
+ def _eval_is_matrix_hermitian(self, simpfunc):
1363
+ mat = self._new(self.rows, self.cols, lambda i, j: simpfunc(self[i, j] - self[j, i].conjugate()))
1364
+ return mat.is_zero_matrix
1365
+
1366
+ def _eval_is_Identity(self) -> FuzzyBool:
1367
+ def dirac(i, j):
1368
+ if i == j:
1369
+ return 1
1370
+ return 0
1371
+
1372
+ return all(self[i, j] == dirac(i, j)
1373
+ for i in range(self.rows)
1374
+ for j in range(self.cols))
1375
+
1376
+ def _eval_is_lower_hessenberg(self):
1377
+ return all(self[i, j].is_zero
1378
+ for i in range(self.rows)
1379
+ for j in range(i + 2, self.cols))
1380
+
1381
+ def _eval_is_lower(self):
1382
+ return all(self[i, j].is_zero
1383
+ for i in range(self.rows)
1384
+ for j in range(i + 1, self.cols))
1385
+
1386
+ def _eval_is_symbolic(self):
1387
+ return self.has(Symbol)
1388
+
1389
+ def _eval_is_symmetric(self, simpfunc):
1390
+ mat = self._new(self.rows, self.cols, lambda i, j: simpfunc(self[i, j] - self[j, i]))
1391
+ return mat.is_zero_matrix
1392
+
1393
+ def _eval_is_zero_matrix(self):
1394
+ if any(i.is_zero == False for i in self):
1395
+ return False
1396
+ if any(i.is_zero is None for i in self):
1397
+ return None
1398
+ return True
1399
+
1400
+ def _eval_is_upper_hessenberg(self):
1401
+ return all(self[i, j].is_zero
1402
+ for i in range(2, self.rows)
1403
+ for j in range(min(self.cols, (i - 1))))
1404
+
1405
+ def _eval_values(self):
1406
+ return [i for i in self if not i.is_zero]
1407
+
1408
+ def _has_positive_diagonals(self):
1409
+ diagonal_entries = (self[i, i] for i in range(self.rows))
1410
+ return fuzzy_and(x.is_positive for x in diagonal_entries)
1411
+
1412
+ def _has_nonnegative_diagonals(self):
1413
+ diagonal_entries = (self[i, i] for i in range(self.rows))
1414
+ return fuzzy_and(x.is_nonnegative for x in diagonal_entries)
1415
+
1416
+ def atoms(self, *types):
1417
+ """Returns the atoms that form the current object.
1418
+
1419
+ Examples
1420
+ ========
1421
+
1422
+ >>> from sympy.abc import x, y
1423
+ >>> from sympy import Matrix
1424
+ >>> Matrix([[x]])
1425
+ Matrix([[x]])
1426
+ >>> _.atoms()
1427
+ {x}
1428
+ >>> Matrix([[x, y], [y, x]])
1429
+ Matrix([
1430
+ [x, y],
1431
+ [y, x]])
1432
+ >>> _.atoms()
1433
+ {x, y}
1434
+ """
1435
+
1436
+ types = tuple(t if isinstance(t, type) else type(t) for t in types)
1437
+ if not types:
1438
+ types = (Atom,)
1439
+ return self._eval_atoms(*types)
1440
+
1441
+ @property
1442
+ def free_symbols(self):
1443
+ """Returns the free symbols within the matrix.
1444
+
1445
+ Examples
1446
+ ========
1447
+
1448
+ >>> from sympy.abc import x
1449
+ >>> from sympy import Matrix
1450
+ >>> Matrix([[x], [1]]).free_symbols
1451
+ {x}
1452
+ """
1453
+ return self._eval_free_symbols()
1454
+
1455
+ def has(self, *patterns):
1456
+ """Test whether any subexpression matches any of the patterns.
1457
+
1458
+ Examples
1459
+ ========
1460
+
1461
+ >>> from sympy import Matrix, SparseMatrix, Float
1462
+ >>> from sympy.abc import x, y
1463
+ >>> A = Matrix(((1, x), (0.2, 3)))
1464
+ >>> B = SparseMatrix(((1, x), (0.2, 3)))
1465
+ >>> A.has(x)
1466
+ True
1467
+ >>> A.has(y)
1468
+ False
1469
+ >>> A.has(Float)
1470
+ True
1471
+ >>> B.has(x)
1472
+ True
1473
+ >>> B.has(y)
1474
+ False
1475
+ >>> B.has(Float)
1476
+ True
1477
+ """
1478
+ return self._eval_has(*patterns)
1479
+
1480
+ def is_anti_symmetric(self, simplify=True):
1481
+ """Check if matrix M is an antisymmetric matrix,
1482
+ that is, M is a square matrix with all M[i, j] == -M[j, i].
1483
+
1484
+ When ``simplify=True`` (default), the sum M[i, j] + M[j, i] is
1485
+ simplified before testing to see if it is zero. By default,
1486
+ the SymPy simplify function is used. To use a custom function
1487
+ set simplify to a function that accepts a single argument which
1488
+ returns a simplified expression. To skip simplification, set
1489
+ simplify to False but note that although this will be faster,
1490
+ it may induce false negatives.
1491
+
1492
+ Examples
1493
+ ========
1494
+
1495
+ >>> from sympy import Matrix, symbols
1496
+ >>> m = Matrix(2, 2, [0, 1, -1, 0])
1497
+ >>> m
1498
+ Matrix([
1499
+ [ 0, 1],
1500
+ [-1, 0]])
1501
+ >>> m.is_anti_symmetric()
1502
+ True
1503
+ >>> x, y = symbols('x y')
1504
+ >>> m = Matrix(2, 3, [0, 0, x, -y, 0, 0])
1505
+ >>> m
1506
+ Matrix([
1507
+ [ 0, 0, x],
1508
+ [-y, 0, 0]])
1509
+ >>> m.is_anti_symmetric()
1510
+ False
1511
+
1512
+ >>> from sympy.abc import x, y
1513
+ >>> m = Matrix(3, 3, [0, x**2 + 2*x + 1, y,
1514
+ ... -(x + 1)**2, 0, x*y,
1515
+ ... -y, -x*y, 0])
1516
+
1517
+ Simplification of matrix elements is done by default so even
1518
+ though two elements which should be equal and opposite would not
1519
+ pass an equality test, the matrix is still reported as
1520
+ anti-symmetric:
1521
+
1522
+ >>> m[0, 1] == -m[1, 0]
1523
+ False
1524
+ >>> m.is_anti_symmetric()
1525
+ True
1526
+
1527
+ If ``simplify=False`` is used for the case when a Matrix is already
1528
+ simplified, this will speed things up. Here, we see that without
1529
+ simplification the matrix does not appear anti-symmetric:
1530
+
1531
+ >>> print(m.is_anti_symmetric(simplify=False))
1532
+ None
1533
+
1534
+ But if the matrix were already expanded, then it would appear
1535
+ anti-symmetric and simplification in the is_anti_symmetric routine
1536
+ is not needed:
1537
+
1538
+ >>> m = m.expand()
1539
+ >>> m.is_anti_symmetric(simplify=False)
1540
+ True
1541
+ """
1542
+ # accept custom simplification
1543
+ simpfunc = simplify
1544
+ if not isfunction(simplify):
1545
+ simpfunc = _simplify if simplify else lambda x: x
1546
+
1547
+ if not self.is_square:
1548
+ return False
1549
+ return self._eval_is_anti_symmetric(simpfunc)
1550
+
1551
+ def is_diagonal(self):
1552
+ """Check if matrix is diagonal,
1553
+ that is matrix in which the entries outside the main diagonal are all zero.
1554
+
1555
+ Examples
1556
+ ========
1557
+
1558
+ >>> from sympy import Matrix, diag
1559
+ >>> m = Matrix(2, 2, [1, 0, 0, 2])
1560
+ >>> m
1561
+ Matrix([
1562
+ [1, 0],
1563
+ [0, 2]])
1564
+ >>> m.is_diagonal()
1565
+ True
1566
+
1567
+ >>> m = Matrix(2, 2, [1, 1, 0, 2])
1568
+ >>> m
1569
+ Matrix([
1570
+ [1, 1],
1571
+ [0, 2]])
1572
+ >>> m.is_diagonal()
1573
+ False
1574
+
1575
+ >>> m = diag(1, 2, 3)
1576
+ >>> m
1577
+ Matrix([
1578
+ [1, 0, 0],
1579
+ [0, 2, 0],
1580
+ [0, 0, 3]])
1581
+ >>> m.is_diagonal()
1582
+ True
1583
+
1584
+ See Also
1585
+ ========
1586
+
1587
+ is_lower
1588
+ is_upper
1589
+ sympy.matrices.matrixbase.MatrixCommon.is_diagonalizable
1590
+ diagonalize
1591
+ """
1592
+ return self._eval_is_diagonal()
1593
+
1594
+ @property
1595
+ def is_weakly_diagonally_dominant(self):
1596
+ r"""Tests if the matrix is row weakly diagonally dominant.
1597
+
1598
+ Explanation
1599
+ ===========
1600
+
1601
+ A $n, n$ matrix $A$ is row weakly diagonally dominant if
1602
+
1603
+ .. math::
1604
+ \left|A_{i, i}\right| \ge \sum_{j = 0, j \neq i}^{n-1}
1605
+ \left|A_{i, j}\right| \quad {\text{for all }}
1606
+ i \in \{ 0, ..., n-1 \}
1607
+
1608
+ Examples
1609
+ ========
1610
+
1611
+ >>> from sympy import Matrix
1612
+ >>> A = Matrix([[3, -2, 1], [1, -3, 2], [-1, 2, 4]])
1613
+ >>> A.is_weakly_diagonally_dominant
1614
+ True
1615
+
1616
+ >>> A = Matrix([[-2, 2, 1], [1, 3, 2], [1, -2, 0]])
1617
+ >>> A.is_weakly_diagonally_dominant
1618
+ False
1619
+
1620
+ >>> A = Matrix([[-4, 2, 1], [1, 6, 2], [1, -2, 5]])
1621
+ >>> A.is_weakly_diagonally_dominant
1622
+ True
1623
+
1624
+ Notes
1625
+ =====
1626
+
1627
+ If you want to test whether a matrix is column diagonally
1628
+ dominant, you can apply the test after transposing the matrix.
1629
+ """
1630
+ if not self.is_square:
1631
+ return False
1632
+
1633
+ rows, cols = self.shape
1634
+
1635
+ def test_row(i):
1636
+ summation = self.zero
1637
+ for j in range(cols):
1638
+ if i != j:
1639
+ summation += Abs(self[i, j])
1640
+ return (Abs(self[i, i]) - summation).is_nonnegative
1641
+
1642
+ return fuzzy_and(test_row(i) for i in range(rows))
1643
+
1644
+ @property
1645
+ def is_strongly_diagonally_dominant(self):
1646
+ r"""Tests if the matrix is row strongly diagonally dominant.
1647
+
1648
+ Explanation
1649
+ ===========
1650
+
1651
+ A $n, n$ matrix $A$ is row strongly diagonally dominant if
1652
+
1653
+ .. math::
1654
+ \left|A_{i, i}\right| > \sum_{j = 0, j \neq i}^{n-1}
1655
+ \left|A_{i, j}\right| \quad {\text{for all }}
1656
+ i \in \{ 0, ..., n-1 \}
1657
+
1658
+ Examples
1659
+ ========
1660
+
1661
+ >>> from sympy import Matrix
1662
+ >>> A = Matrix([[3, -2, 1], [1, -3, 2], [-1, 2, 4]])
1663
+ >>> A.is_strongly_diagonally_dominant
1664
+ False
1665
+
1666
+ >>> A = Matrix([[-2, 2, 1], [1, 3, 2], [1, -2, 0]])
1667
+ >>> A.is_strongly_diagonally_dominant
1668
+ False
1669
+
1670
+ >>> A = Matrix([[-4, 2, 1], [1, 6, 2], [1, -2, 5]])
1671
+ >>> A.is_strongly_diagonally_dominant
1672
+ True
1673
+
1674
+ Notes
1675
+ =====
1676
+
1677
+ If you want to test whether a matrix is column diagonally
1678
+ dominant, you can apply the test after transposing the matrix.
1679
+ """
1680
+ if not self.is_square:
1681
+ return False
1682
+
1683
+ rows, cols = self.shape
1684
+
1685
+ def test_row(i):
1686
+ summation = self.zero
1687
+ for j in range(cols):
1688
+ if i != j:
1689
+ summation += Abs(self[i, j])
1690
+ return (Abs(self[i, i]) - summation).is_positive
1691
+
1692
+ return fuzzy_and(test_row(i) for i in range(rows))
1693
+
1694
+ @property
1695
+ def is_hermitian(self):
1696
+ """Checks if the matrix is Hermitian.
1697
+
1698
+ In a Hermitian matrix element i,j is the complex conjugate of
1699
+ element j,i.
1700
+
1701
+ Examples
1702
+ ========
1703
+
1704
+ >>> from sympy import Matrix
1705
+ >>> from sympy import I
1706
+ >>> from sympy.abc import x
1707
+ >>> a = Matrix([[1, I], [-I, 1]])
1708
+ >>> a
1709
+ Matrix([
1710
+ [ 1, I],
1711
+ [-I, 1]])
1712
+ >>> a.is_hermitian
1713
+ True
1714
+ >>> a[0, 0] = 2*I
1715
+ >>> a.is_hermitian
1716
+ False
1717
+ >>> a[0, 0] = x
1718
+ >>> a.is_hermitian
1719
+ >>> a[0, 1] = a[1, 0]*I
1720
+ >>> a.is_hermitian
1721
+ False
1722
+ """
1723
+ if not self.is_square:
1724
+ return False
1725
+
1726
+ return self._eval_is_matrix_hermitian(_simplify)
1727
+
1728
+ @property
1729
+ def is_Identity(self) -> FuzzyBool:
1730
+ if not self.is_square:
1731
+ return False
1732
+ return self._eval_is_Identity()
1733
+
1734
+ @property
1735
+ def is_lower_hessenberg(self):
1736
+ r"""Checks if the matrix is in the lower-Hessenberg form.
1737
+
1738
+ The lower hessenberg matrix has zero entries
1739
+ above the first superdiagonal.
1740
+
1741
+ Examples
1742
+ ========
1743
+
1744
+ >>> from sympy import Matrix
1745
+ >>> a = Matrix([[1, 2, 0, 0], [5, 2, 3, 0], [3, 4, 3, 7], [5, 6, 1, 1]])
1746
+ >>> a
1747
+ Matrix([
1748
+ [1, 2, 0, 0],
1749
+ [5, 2, 3, 0],
1750
+ [3, 4, 3, 7],
1751
+ [5, 6, 1, 1]])
1752
+ >>> a.is_lower_hessenberg
1753
+ True
1754
+
1755
+ See Also
1756
+ ========
1757
+
1758
+ is_upper_hessenberg
1759
+ is_lower
1760
+ """
1761
+ return self._eval_is_lower_hessenberg()
1762
+
1763
+ @property
1764
+ def is_lower(self):
1765
+ """Check if matrix is a lower triangular matrix. True can be returned
1766
+ even if the matrix is not square.
1767
+
1768
+ Examples
1769
+ ========
1770
+
1771
+ >>> from sympy import Matrix
1772
+ >>> m = Matrix(2, 2, [1, 0, 0, 1])
1773
+ >>> m
1774
+ Matrix([
1775
+ [1, 0],
1776
+ [0, 1]])
1777
+ >>> m.is_lower
1778
+ True
1779
+
1780
+ >>> m = Matrix(4, 3, [0, 0, 0, 2, 0, 0, 1, 4, 0, 6, 6, 5])
1781
+ >>> m
1782
+ Matrix([
1783
+ [0, 0, 0],
1784
+ [2, 0, 0],
1785
+ [1, 4, 0],
1786
+ [6, 6, 5]])
1787
+ >>> m.is_lower
1788
+ True
1789
+
1790
+ >>> from sympy.abc import x, y
1791
+ >>> m = Matrix(2, 2, [x**2 + y, y**2 + x, 0, x + y])
1792
+ >>> m
1793
+ Matrix([
1794
+ [x**2 + y, x + y**2],
1795
+ [ 0, x + y]])
1796
+ >>> m.is_lower
1797
+ False
1798
+
1799
+ See Also
1800
+ ========
1801
+
1802
+ is_upper
1803
+ is_diagonal
1804
+ is_lower_hessenberg
1805
+ """
1806
+ return self._eval_is_lower()
1807
+
1808
+ @property
1809
+ def is_square(self):
1810
+ """Checks if a matrix is square.
1811
+
1812
+ A matrix is square if the number of rows equals the number of columns.
1813
+ The empty matrix is square by definition, since the number of rows and
1814
+ the number of columns are both zero.
1815
+
1816
+ Examples
1817
+ ========
1818
+
1819
+ >>> from sympy import Matrix
1820
+ >>> a = Matrix([[1, 2, 3], [4, 5, 6]])
1821
+ >>> b = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
1822
+ >>> c = Matrix([])
1823
+ >>> a.is_square
1824
+ False
1825
+ >>> b.is_square
1826
+ True
1827
+ >>> c.is_square
1828
+ True
1829
+ """
1830
+ return self.rows == self.cols
1831
+
1832
+ def is_symbolic(self):
1833
+ """Checks if any elements contain Symbols.
1834
+
1835
+ Examples
1836
+ ========
1837
+
1838
+ >>> from sympy import Matrix
1839
+ >>> from sympy.abc import x, y
1840
+ >>> M = Matrix([[x, y], [1, 0]])
1841
+ >>> M.is_symbolic()
1842
+ True
1843
+
1844
+ """
1845
+ return self._eval_is_symbolic()
1846
+
1847
+ def is_symmetric(self, simplify=True):
1848
+ """Check if matrix is symmetric matrix,
1849
+ that is square matrix and is equal to its transpose.
1850
+
1851
+ By default, simplifications occur before testing symmetry.
1852
+ They can be skipped using 'simplify=False'; while speeding things a bit,
1853
+ this may however induce false negatives.
1854
+
1855
+ Examples
1856
+ ========
1857
+
1858
+ >>> from sympy import Matrix
1859
+ >>> m = Matrix(2, 2, [0, 1, 1, 2])
1860
+ >>> m
1861
+ Matrix([
1862
+ [0, 1],
1863
+ [1, 2]])
1864
+ >>> m.is_symmetric()
1865
+ True
1866
+
1867
+ >>> m = Matrix(2, 2, [0, 1, 2, 0])
1868
+ >>> m
1869
+ Matrix([
1870
+ [0, 1],
1871
+ [2, 0]])
1872
+ >>> m.is_symmetric()
1873
+ False
1874
+
1875
+ >>> m = Matrix(2, 3, [0, 0, 0, 0, 0, 0])
1876
+ >>> m
1877
+ Matrix([
1878
+ [0, 0, 0],
1879
+ [0, 0, 0]])
1880
+ >>> m.is_symmetric()
1881
+ False
1882
+
1883
+ >>> from sympy.abc import x, y
1884
+ >>> m = Matrix(3, 3, [1, x**2 + 2*x + 1, y, (x + 1)**2, 2, 0, y, 0, 3])
1885
+ >>> m
1886
+ Matrix([
1887
+ [ 1, x**2 + 2*x + 1, y],
1888
+ [(x + 1)**2, 2, 0],
1889
+ [ y, 0, 3]])
1890
+ >>> m.is_symmetric()
1891
+ True
1892
+
1893
+ If the matrix is already simplified, you may speed-up is_symmetric()
1894
+ test by using 'simplify=False'.
1895
+
1896
+ >>> bool(m.is_symmetric(simplify=False))
1897
+ False
1898
+ >>> m1 = m.expand()
1899
+ >>> m1.is_symmetric(simplify=False)
1900
+ True
1901
+ """
1902
+ simpfunc = simplify
1903
+ if not isfunction(simplify):
1904
+ simpfunc = _simplify if simplify else lambda x: x
1905
+
1906
+ if not self.is_square:
1907
+ return False
1908
+
1909
+ return self._eval_is_symmetric(simpfunc)
1910
+
1911
+ @property
1912
+ def is_upper_hessenberg(self):
1913
+ """Checks if the matrix is the upper-Hessenberg form.
1914
+
1915
+ The upper hessenberg matrix has zero entries
1916
+ below the first subdiagonal.
1917
+
1918
+ Examples
1919
+ ========
1920
+
1921
+ >>> from sympy import Matrix
1922
+ >>> a = Matrix([[1, 4, 2, 3], [3, 4, 1, 7], [0, 2, 3, 4], [0, 0, 1, 3]])
1923
+ >>> a
1924
+ Matrix([
1925
+ [1, 4, 2, 3],
1926
+ [3, 4, 1, 7],
1927
+ [0, 2, 3, 4],
1928
+ [0, 0, 1, 3]])
1929
+ >>> a.is_upper_hessenberg
1930
+ True
1931
+
1932
+ See Also
1933
+ ========
1934
+
1935
+ is_lower_hessenberg
1936
+ is_upper
1937
+ """
1938
+ return self._eval_is_upper_hessenberg()
1939
+
1940
+ @property
1941
+ def is_upper(self):
1942
+ """Check if matrix is an upper triangular matrix. True can be returned
1943
+ even if the matrix is not square.
1944
+
1945
+ Examples
1946
+ ========
1947
+
1948
+ >>> from sympy import Matrix
1949
+ >>> m = Matrix(2, 2, [1, 0, 0, 1])
1950
+ >>> m
1951
+ Matrix([
1952
+ [1, 0],
1953
+ [0, 1]])
1954
+ >>> m.is_upper
1955
+ True
1956
+
1957
+ >>> m = Matrix(4, 3, [5, 1, 9, 0, 4, 6, 0, 0, 5, 0, 0, 0])
1958
+ >>> m
1959
+ Matrix([
1960
+ [5, 1, 9],
1961
+ [0, 4, 6],
1962
+ [0, 0, 5],
1963
+ [0, 0, 0]])
1964
+ >>> m.is_upper
1965
+ True
1966
+
1967
+ >>> m = Matrix(2, 3, [4, 2, 5, 6, 1, 1])
1968
+ >>> m
1969
+ Matrix([
1970
+ [4, 2, 5],
1971
+ [6, 1, 1]])
1972
+ >>> m.is_upper
1973
+ False
1974
+
1975
+ See Also
1976
+ ========
1977
+
1978
+ is_lower
1979
+ is_diagonal
1980
+ is_upper_hessenberg
1981
+ """
1982
+ return all(self[i, j].is_zero
1983
+ for i in range(1, self.rows)
1984
+ for j in range(min(i, self.cols)))
1985
+
1986
+ @property
1987
+ def is_zero_matrix(self):
1988
+ """Checks if a matrix is a zero matrix.
1989
+
1990
+ A matrix is zero if every element is zero. A matrix need not be square
1991
+ to be considered zero. The empty matrix is zero by the principle of
1992
+ vacuous truth. For a matrix that may or may not be zero (e.g.
1993
+ contains a symbol), this will be None
1994
+
1995
+ Examples
1996
+ ========
1997
+
1998
+ >>> from sympy import Matrix, zeros
1999
+ >>> from sympy.abc import x
2000
+ >>> a = Matrix([[0, 0], [0, 0]])
2001
+ >>> b = zeros(3, 4)
2002
+ >>> c = Matrix([[0, 1], [0, 0]])
2003
+ >>> d = Matrix([])
2004
+ >>> e = Matrix([[x, 0], [0, 0]])
2005
+ >>> a.is_zero_matrix
2006
+ True
2007
+ >>> b.is_zero_matrix
2008
+ True
2009
+ >>> c.is_zero_matrix
2010
+ False
2011
+ >>> d.is_zero_matrix
2012
+ True
2013
+ >>> e.is_zero_matrix
2014
+ """
2015
+ return self._eval_is_zero_matrix()
2016
+
2017
+ def values(self):
2018
+ """Return non-zero values of self."""
2019
+ return self._eval_values()
2020
+
2021
+
2022
+ class MatrixOperations(MatrixRequired):
2023
+ """Provides basic matrix shape and elementwise
2024
+ operations. Should not be instantiated directly."""
2025
+
2026
+ def _eval_adjoint(self):
2027
+ return self.transpose().conjugate()
2028
+
2029
+ def _eval_applyfunc(self, f):
2030
+ out = self._new(self.rows, self.cols, [f(x) for x in self])
2031
+ return out
2032
+
2033
+ def _eval_as_real_imag(self): # type: ignore
2034
+ return (self.applyfunc(re), self.applyfunc(im))
2035
+
2036
+ def _eval_conjugate(self):
2037
+ return self.applyfunc(lambda x: x.conjugate())
2038
+
2039
+ def _eval_permute_cols(self, perm):
2040
+ # apply the permutation to a list
2041
+ mapping = list(perm)
2042
+
2043
+ def entry(i, j):
2044
+ return self[i, mapping[j]]
2045
+
2046
+ return self._new(self.rows, self.cols, entry)
2047
+
2048
+ def _eval_permute_rows(self, perm):
2049
+ # apply the permutation to a list
2050
+ mapping = list(perm)
2051
+
2052
+ def entry(i, j):
2053
+ return self[mapping[i], j]
2054
+
2055
+ return self._new(self.rows, self.cols, entry)
2056
+
2057
+ def _eval_trace(self):
2058
+ return sum(self[i, i] for i in range(self.rows))
2059
+
2060
+ def _eval_transpose(self):
2061
+ return self._new(self.cols, self.rows, lambda i, j: self[j, i])
2062
+
2063
+ def adjoint(self):
2064
+ """Conjugate transpose or Hermitian conjugation."""
2065
+ return self._eval_adjoint()
2066
+
2067
+ def applyfunc(self, f):
2068
+ """Apply a function to each element of the matrix.
2069
+
2070
+ Examples
2071
+ ========
2072
+
2073
+ >>> from sympy import Matrix
2074
+ >>> m = Matrix(2, 2, lambda i, j: i*2+j)
2075
+ >>> m
2076
+ Matrix([
2077
+ [0, 1],
2078
+ [2, 3]])
2079
+ >>> m.applyfunc(lambda i: 2*i)
2080
+ Matrix([
2081
+ [0, 2],
2082
+ [4, 6]])
2083
+
2084
+ """
2085
+ if not callable(f):
2086
+ raise TypeError("`f` must be callable.")
2087
+
2088
+ return self._eval_applyfunc(f)
2089
+
2090
+ def as_real_imag(self, deep=True, **hints):
2091
+ """Returns a tuple containing the (real, imaginary) part of matrix."""
2092
+ # XXX: Ignoring deep and hints...
2093
+ return self._eval_as_real_imag()
2094
+
2095
+ def conjugate(self):
2096
+ """Return the by-element conjugation.
2097
+
2098
+ Examples
2099
+ ========
2100
+
2101
+ >>> from sympy import SparseMatrix, I
2102
+ >>> a = SparseMatrix(((1, 2 + I), (3, 4), (I, -I)))
2103
+ >>> a
2104
+ Matrix([
2105
+ [1, 2 + I],
2106
+ [3, 4],
2107
+ [I, -I]])
2108
+ >>> a.C
2109
+ Matrix([
2110
+ [ 1, 2 - I],
2111
+ [ 3, 4],
2112
+ [-I, I]])
2113
+
2114
+ See Also
2115
+ ========
2116
+
2117
+ transpose: Matrix transposition
2118
+ H: Hermite conjugation
2119
+ sympy.matrices.matrixbase.MatrixBase.D: Dirac conjugation
2120
+ """
2121
+ return self._eval_conjugate()
2122
+
2123
+ def doit(self, **hints):
2124
+ return self.applyfunc(lambda x: x.doit(**hints))
2125
+
2126
+ def evalf(self, n=15, subs=None, maxn=100, chop=False, strict=False, quad=None, verbose=False):
2127
+ """Apply evalf() to each element of self."""
2128
+ options = {'subs':subs, 'maxn':maxn, 'chop':chop, 'strict':strict,
2129
+ 'quad':quad, 'verbose':verbose}
2130
+ return self.applyfunc(lambda i: i.evalf(n, **options))
2131
+
2132
+ def expand(self, deep=True, modulus=None, power_base=True, power_exp=True,
2133
+ mul=True, log=True, multinomial=True, basic=True, **hints):
2134
+ """Apply core.function.expand to each entry of the matrix.
2135
+
2136
+ Examples
2137
+ ========
2138
+
2139
+ >>> from sympy.abc import x
2140
+ >>> from sympy import Matrix
2141
+ >>> Matrix(1, 1, [x*(x+1)])
2142
+ Matrix([[x*(x + 1)]])
2143
+ >>> _.expand()
2144
+ Matrix([[x**2 + x]])
2145
+
2146
+ """
2147
+ return self.applyfunc(lambda x: x.expand(
2148
+ deep, modulus, power_base, power_exp, mul, log, multinomial, basic,
2149
+ **hints))
2150
+
2151
+ @property
2152
+ def H(self):
2153
+ """Return Hermite conjugate.
2154
+
2155
+ Examples
2156
+ ========
2157
+
2158
+ >>> from sympy import Matrix, I
2159
+ >>> m = Matrix((0, 1 + I, 2, 3))
2160
+ >>> m
2161
+ Matrix([
2162
+ [ 0],
2163
+ [1 + I],
2164
+ [ 2],
2165
+ [ 3]])
2166
+ >>> m.H
2167
+ Matrix([[0, 1 - I, 2, 3]])
2168
+
2169
+ See Also
2170
+ ========
2171
+
2172
+ conjugate: By-element conjugation
2173
+ sympy.matrices.matrixbase.MatrixBase.D: Dirac conjugation
2174
+ """
2175
+ return self.T.C
2176
+
2177
+ def permute(self, perm, orientation='rows', direction='forward'):
2178
+ r"""Permute the rows or columns of a matrix by the given list of
2179
+ swaps.
2180
+
2181
+ Parameters
2182
+ ==========
2183
+
2184
+ perm : Permutation, list, or list of lists
2185
+ A representation for the permutation.
2186
+
2187
+ If it is ``Permutation``, it is used directly with some
2188
+ resizing with respect to the matrix size.
2189
+
2190
+ If it is specified as list of lists,
2191
+ (e.g., ``[[0, 1], [0, 2]]``), then the permutation is formed
2192
+ from applying the product of cycles. The direction how the
2193
+ cyclic product is applied is described in below.
2194
+
2195
+ If it is specified as a list, the list should represent
2196
+ an array form of a permutation. (e.g., ``[1, 2, 0]``) which
2197
+ would would form the swapping function
2198
+ `0 \mapsto 1, 1 \mapsto 2, 2\mapsto 0`.
2199
+
2200
+ orientation : 'rows', 'cols'
2201
+ A flag to control whether to permute the rows or the columns
2202
+
2203
+ direction : 'forward', 'backward'
2204
+ A flag to control whether to apply the permutations from
2205
+ the start of the list first, or from the back of the list
2206
+ first.
2207
+
2208
+ For example, if the permutation specification is
2209
+ ``[[0, 1], [0, 2]]``,
2210
+
2211
+ If the flag is set to ``'forward'``, the cycle would be
2212
+ formed as `0 \mapsto 2, 2 \mapsto 1, 1 \mapsto 0`.
2213
+
2214
+ If the flag is set to ``'backward'``, the cycle would be
2215
+ formed as `0 \mapsto 1, 1 \mapsto 2, 2 \mapsto 0`.
2216
+
2217
+ If the argument ``perm`` is not in a form of list of lists,
2218
+ this flag takes no effect.
2219
+
2220
+ Examples
2221
+ ========
2222
+
2223
+ >>> from sympy import eye
2224
+ >>> M = eye(3)
2225
+ >>> M.permute([[0, 1], [0, 2]], orientation='rows', direction='forward')
2226
+ Matrix([
2227
+ [0, 0, 1],
2228
+ [1, 0, 0],
2229
+ [0, 1, 0]])
2230
+
2231
+ >>> from sympy import eye
2232
+ >>> M = eye(3)
2233
+ >>> M.permute([[0, 1], [0, 2]], orientation='rows', direction='backward')
2234
+ Matrix([
2235
+ [0, 1, 0],
2236
+ [0, 0, 1],
2237
+ [1, 0, 0]])
2238
+
2239
+ Notes
2240
+ =====
2241
+
2242
+ If a bijective function
2243
+ `\sigma : \mathbb{N}_0 \rightarrow \mathbb{N}_0` denotes the
2244
+ permutation.
2245
+
2246
+ If the matrix `A` is the matrix to permute, represented as
2247
+ a horizontal or a vertical stack of vectors:
2248
+
2249
+ .. math::
2250
+ A =
2251
+ \begin{bmatrix}
2252
+ a_0 \\ a_1 \\ \vdots \\ a_{n-1}
2253
+ \end{bmatrix} =
2254
+ \begin{bmatrix}
2255
+ \alpha_0 & \alpha_1 & \cdots & \alpha_{n-1}
2256
+ \end{bmatrix}
2257
+
2258
+ If the matrix `B` is the result, the permutation of matrix rows
2259
+ is defined as:
2260
+
2261
+ .. math::
2262
+ B := \begin{bmatrix}
2263
+ a_{\sigma(0)} \\ a_{\sigma(1)} \\ \vdots \\ a_{\sigma(n-1)}
2264
+ \end{bmatrix}
2265
+
2266
+ And the permutation of matrix columns is defined as:
2267
+
2268
+ .. math::
2269
+ B := \begin{bmatrix}
2270
+ \alpha_{\sigma(0)} & \alpha_{\sigma(1)} &
2271
+ \cdots & \alpha_{\sigma(n-1)}
2272
+ \end{bmatrix}
2273
+ """
2274
+ from sympy.combinatorics import Permutation
2275
+
2276
+ # allow british variants and `columns`
2277
+ if direction == 'forwards':
2278
+ direction = 'forward'
2279
+ if direction == 'backwards':
2280
+ direction = 'backward'
2281
+ if orientation == 'columns':
2282
+ orientation = 'cols'
2283
+
2284
+ if direction not in ('forward', 'backward'):
2285
+ raise TypeError("direction='{}' is an invalid kwarg. "
2286
+ "Try 'forward' or 'backward'".format(direction))
2287
+ if orientation not in ('rows', 'cols'):
2288
+ raise TypeError("orientation='{}' is an invalid kwarg. "
2289
+ "Try 'rows' or 'cols'".format(orientation))
2290
+
2291
+ if not isinstance(perm, (Permutation, Iterable)):
2292
+ raise ValueError(
2293
+ "{} must be a list, a list of lists, "
2294
+ "or a SymPy permutation object.".format(perm))
2295
+
2296
+ # ensure all swaps are in range
2297
+ max_index = self.rows if orientation == 'rows' else self.cols
2298
+ if not all(0 <= t <= max_index for t in flatten(list(perm))):
2299
+ raise IndexError("`swap` indices out of range.")
2300
+
2301
+ if perm and not isinstance(perm, Permutation) and \
2302
+ isinstance(perm[0], Iterable):
2303
+ if direction == 'forward':
2304
+ perm = list(reversed(perm))
2305
+ perm = Permutation(perm, size=max_index+1)
2306
+ else:
2307
+ perm = Permutation(perm, size=max_index+1)
2308
+
2309
+ if orientation == 'rows':
2310
+ return self._eval_permute_rows(perm)
2311
+ if orientation == 'cols':
2312
+ return self._eval_permute_cols(perm)
2313
+
2314
+ def permute_cols(self, swaps, direction='forward'):
2315
+ """Alias for
2316
+ ``self.permute(swaps, orientation='cols', direction=direction)``
2317
+
2318
+ See Also
2319
+ ========
2320
+
2321
+ permute
2322
+ """
2323
+ return self.permute(swaps, orientation='cols', direction=direction)
2324
+
2325
+ def permute_rows(self, swaps, direction='forward'):
2326
+ """Alias for
2327
+ ``self.permute(swaps, orientation='rows', direction=direction)``
2328
+
2329
+ See Also
2330
+ ========
2331
+
2332
+ permute
2333
+ """
2334
+ return self.permute(swaps, orientation='rows', direction=direction)
2335
+
2336
+ def refine(self, assumptions=True):
2337
+ """Apply refine to each element of the matrix.
2338
+
2339
+ Examples
2340
+ ========
2341
+
2342
+ >>> from sympy import Symbol, Matrix, Abs, sqrt, Q
2343
+ >>> x = Symbol('x')
2344
+ >>> Matrix([[Abs(x)**2, sqrt(x**2)],[sqrt(x**2), Abs(x)**2]])
2345
+ Matrix([
2346
+ [ Abs(x)**2, sqrt(x**2)],
2347
+ [sqrt(x**2), Abs(x)**2]])
2348
+ >>> _.refine(Q.real(x))
2349
+ Matrix([
2350
+ [ x**2, Abs(x)],
2351
+ [Abs(x), x**2]])
2352
+
2353
+ """
2354
+ return self.applyfunc(lambda x: refine(x, assumptions))
2355
+
2356
+ def replace(self, F, G, map=False, simultaneous=True, exact=None):
2357
+ """Replaces Function F in Matrix entries with Function G.
2358
+
2359
+ Examples
2360
+ ========
2361
+
2362
+ >>> from sympy import symbols, Function, Matrix
2363
+ >>> F, G = symbols('F, G', cls=Function)
2364
+ >>> M = Matrix(2, 2, lambda i, j: F(i+j)) ; M
2365
+ Matrix([
2366
+ [F(0), F(1)],
2367
+ [F(1), F(2)]])
2368
+ >>> N = M.replace(F,G)
2369
+ >>> N
2370
+ Matrix([
2371
+ [G(0), G(1)],
2372
+ [G(1), G(2)]])
2373
+ """
2374
+ return self.applyfunc(
2375
+ lambda x: x.replace(F, G, map=map, simultaneous=simultaneous, exact=exact))
2376
+
2377
+ def rot90(self, k=1):
2378
+ """Rotates Matrix by 90 degrees
2379
+
2380
+ Parameters
2381
+ ==========
2382
+
2383
+ k : int
2384
+ Specifies how many times the matrix is rotated by 90 degrees
2385
+ (clockwise when positive, counter-clockwise when negative).
2386
+
2387
+ Examples
2388
+ ========
2389
+
2390
+ >>> from sympy import Matrix, symbols
2391
+ >>> A = Matrix(2, 2, symbols('a:d'))
2392
+ >>> A
2393
+ Matrix([
2394
+ [a, b],
2395
+ [c, d]])
2396
+
2397
+ Rotating the matrix clockwise one time:
2398
+
2399
+ >>> A.rot90(1)
2400
+ Matrix([
2401
+ [c, a],
2402
+ [d, b]])
2403
+
2404
+ Rotating the matrix anticlockwise two times:
2405
+
2406
+ >>> A.rot90(-2)
2407
+ Matrix([
2408
+ [d, c],
2409
+ [b, a]])
2410
+ """
2411
+
2412
+ mod = k%4
2413
+ if mod == 0:
2414
+ return self
2415
+ if mod == 1:
2416
+ return self[::-1, ::].T
2417
+ if mod == 2:
2418
+ return self[::-1, ::-1]
2419
+ if mod == 3:
2420
+ return self[::, ::-1].T
2421
+
2422
+ def simplify(self, **kwargs):
2423
+ """Apply simplify to each element of the matrix.
2424
+
2425
+ Examples
2426
+ ========
2427
+
2428
+ >>> from sympy.abc import x, y
2429
+ >>> from sympy import SparseMatrix, sin, cos
2430
+ >>> SparseMatrix(1, 1, [x*sin(y)**2 + x*cos(y)**2])
2431
+ Matrix([[x*sin(y)**2 + x*cos(y)**2]])
2432
+ >>> _.simplify()
2433
+ Matrix([[x]])
2434
+ """
2435
+ return self.applyfunc(lambda x: x.simplify(**kwargs))
2436
+
2437
+ def subs(self, *args, **kwargs): # should mirror core.basic.subs
2438
+ """Return a new matrix with subs applied to each entry.
2439
+
2440
+ Examples
2441
+ ========
2442
+
2443
+ >>> from sympy.abc import x, y
2444
+ >>> from sympy import SparseMatrix, Matrix
2445
+ >>> SparseMatrix(1, 1, [x])
2446
+ Matrix([[x]])
2447
+ >>> _.subs(x, y)
2448
+ Matrix([[y]])
2449
+ >>> Matrix(_).subs(y, x)
2450
+ Matrix([[x]])
2451
+ """
2452
+
2453
+ if len(args) == 1 and not isinstance(args[0], (dict, set)) and iter(args[0]) and not is_sequence(args[0]):
2454
+ args = (list(args[0]),)
2455
+
2456
+ return self.applyfunc(lambda x: x.subs(*args, **kwargs))
2457
+
2458
+ def trace(self):
2459
+ """
2460
+ Returns the trace of a square matrix i.e. the sum of the
2461
+ diagonal elements.
2462
+
2463
+ Examples
2464
+ ========
2465
+
2466
+ >>> from sympy import Matrix
2467
+ >>> A = Matrix(2, 2, [1, 2, 3, 4])
2468
+ >>> A.trace()
2469
+ 5
2470
+
2471
+ """
2472
+ if self.rows != self.cols:
2473
+ raise NonSquareMatrixError()
2474
+ return self._eval_trace()
2475
+
2476
+ def transpose(self):
2477
+ """
2478
+ Returns the transpose of the matrix.
2479
+
2480
+ Examples
2481
+ ========
2482
+
2483
+ >>> from sympy import Matrix
2484
+ >>> A = Matrix(2, 2, [1, 2, 3, 4])
2485
+ >>> A.transpose()
2486
+ Matrix([
2487
+ [1, 3],
2488
+ [2, 4]])
2489
+
2490
+ >>> from sympy import Matrix, I
2491
+ >>> m=Matrix(((1, 2+I), (3, 4)))
2492
+ >>> m
2493
+ Matrix([
2494
+ [1, 2 + I],
2495
+ [3, 4]])
2496
+ >>> m.transpose()
2497
+ Matrix([
2498
+ [ 1, 3],
2499
+ [2 + I, 4]])
2500
+ >>> m.T == m.transpose()
2501
+ True
2502
+
2503
+ See Also
2504
+ ========
2505
+
2506
+ conjugate: By-element conjugation
2507
+
2508
+ """
2509
+ return self._eval_transpose()
2510
+
2511
+ @property
2512
+ def T(self):
2513
+ '''Matrix transposition'''
2514
+ return self.transpose()
2515
+
2516
+ @property
2517
+ def C(self):
2518
+ '''By-element conjugation'''
2519
+ return self.conjugate()
2520
+
2521
+ def n(self, *args, **kwargs):
2522
+ """Apply evalf() to each element of self."""
2523
+ return self.evalf(*args, **kwargs)
2524
+
2525
+ def xreplace(self, rule): # should mirror core.basic.xreplace
2526
+ """Return a new matrix with xreplace applied to each entry.
2527
+
2528
+ Examples
2529
+ ========
2530
+
2531
+ >>> from sympy.abc import x, y
2532
+ >>> from sympy import SparseMatrix, Matrix
2533
+ >>> SparseMatrix(1, 1, [x])
2534
+ Matrix([[x]])
2535
+ >>> _.xreplace({x: y})
2536
+ Matrix([[y]])
2537
+ >>> Matrix(_).xreplace({y: x})
2538
+ Matrix([[x]])
2539
+ """
2540
+ return self.applyfunc(lambda x: x.xreplace(rule))
2541
+
2542
+ def _eval_simplify(self, **kwargs):
2543
+ # XXX: We can't use self.simplify here as mutable subclasses will
2544
+ # override simplify and have it return None
2545
+ return MatrixOperations.simplify(self, **kwargs)
2546
+
2547
+ def _eval_trigsimp(self, **opts):
2548
+ from sympy.simplify.trigsimp import trigsimp
2549
+ return self.applyfunc(lambda x: trigsimp(x, **opts))
2550
+
2551
+ def upper_triangular(self, k=0):
2552
+ """Return the elements on and above the kth diagonal of a matrix.
2553
+ If k is not specified then simply returns upper-triangular portion
2554
+ of a matrix
2555
+
2556
+ Examples
2557
+ ========
2558
+
2559
+ >>> from sympy import ones
2560
+ >>> A = ones(4)
2561
+ >>> A.upper_triangular()
2562
+ Matrix([
2563
+ [1, 1, 1, 1],
2564
+ [0, 1, 1, 1],
2565
+ [0, 0, 1, 1],
2566
+ [0, 0, 0, 1]])
2567
+
2568
+ >>> A.upper_triangular(2)
2569
+ Matrix([
2570
+ [0, 0, 1, 1],
2571
+ [0, 0, 0, 1],
2572
+ [0, 0, 0, 0],
2573
+ [0, 0, 0, 0]])
2574
+
2575
+ >>> A.upper_triangular(-1)
2576
+ Matrix([
2577
+ [1, 1, 1, 1],
2578
+ [1, 1, 1, 1],
2579
+ [0, 1, 1, 1],
2580
+ [0, 0, 1, 1]])
2581
+
2582
+ """
2583
+
2584
+ def entry(i, j):
2585
+ return self[i, j] if i + k <= j else self.zero
2586
+
2587
+ return self._new(self.rows, self.cols, entry)
2588
+
2589
+
2590
+ def lower_triangular(self, k=0):
2591
+ """Return the elements on and below the kth diagonal of a matrix.
2592
+ If k is not specified then simply returns lower-triangular portion
2593
+ of a matrix
2594
+
2595
+ Examples
2596
+ ========
2597
+
2598
+ >>> from sympy import ones
2599
+ >>> A = ones(4)
2600
+ >>> A.lower_triangular()
2601
+ Matrix([
2602
+ [1, 0, 0, 0],
2603
+ [1, 1, 0, 0],
2604
+ [1, 1, 1, 0],
2605
+ [1, 1, 1, 1]])
2606
+
2607
+ >>> A.lower_triangular(-2)
2608
+ Matrix([
2609
+ [0, 0, 0, 0],
2610
+ [0, 0, 0, 0],
2611
+ [1, 0, 0, 0],
2612
+ [1, 1, 0, 0]])
2613
+
2614
+ >>> A.lower_triangular(1)
2615
+ Matrix([
2616
+ [1, 1, 0, 0],
2617
+ [1, 1, 1, 0],
2618
+ [1, 1, 1, 1],
2619
+ [1, 1, 1, 1]])
2620
+
2621
+ """
2622
+
2623
+ def entry(i, j):
2624
+ return self[i, j] if i + k >= j else self.zero
2625
+
2626
+ return self._new(self.rows, self.cols, entry)
2627
+
2628
+
2629
+
2630
+ class MatrixArithmetic(MatrixRequired):
2631
+ """Provides basic matrix arithmetic operations.
2632
+ Should not be instantiated directly."""
2633
+
2634
+ _op_priority = 10.01
2635
+
2636
+ def _eval_Abs(self):
2637
+ return self._new(self.rows, self.cols, lambda i, j: Abs(self[i, j]))
2638
+
2639
+ def _eval_add(self, other):
2640
+ return self._new(self.rows, self.cols,
2641
+ lambda i, j: self[i, j] + other[i, j])
2642
+
2643
+ def _eval_matrix_mul(self, other):
2644
+ def entry(i, j):
2645
+ vec = [self[i,k]*other[k,j] for k in range(self.cols)]
2646
+ try:
2647
+ return Add(*vec)
2648
+ except (TypeError, SympifyError):
2649
+ # Some matrices don't work with `sum` or `Add`
2650
+ # They don't work with `sum` because `sum` tries to add `0`
2651
+ # Fall back to a safe way to multiply if the `Add` fails.
2652
+ return reduce(lambda a, b: a + b, vec)
2653
+
2654
+ return self._new(self.rows, other.cols, entry)
2655
+
2656
+ def _eval_matrix_mul_elementwise(self, other):
2657
+ return self._new(self.rows, self.cols, lambda i, j: self[i,j]*other[i,j])
2658
+
2659
+ def _eval_matrix_rmul(self, other):
2660
+ def entry(i, j):
2661
+ return sum(other[i,k]*self[k,j] for k in range(other.cols))
2662
+ return self._new(other.rows, self.cols, entry)
2663
+
2664
+ def _eval_pow_by_recursion(self, num):
2665
+ if num == 1:
2666
+ return self
2667
+
2668
+ if num % 2 == 1:
2669
+ a, b = self, self._eval_pow_by_recursion(num - 1)
2670
+ else:
2671
+ a = b = self._eval_pow_by_recursion(num // 2)
2672
+
2673
+ return a.multiply(b)
2674
+
2675
+ def _eval_pow_by_cayley(self, exp):
2676
+ from sympy.discrete.recurrences import linrec_coeffs
2677
+ row = self.shape[0]
2678
+ p = self.charpoly()
2679
+
2680
+ coeffs = (-p).all_coeffs()[1:]
2681
+ coeffs = linrec_coeffs(coeffs, exp)
2682
+ new_mat = self.eye(row)
2683
+ ans = self.zeros(row)
2684
+
2685
+ for i in range(row):
2686
+ ans += coeffs[i]*new_mat
2687
+ new_mat *= self
2688
+
2689
+ return ans
2690
+
2691
+ def _eval_pow_by_recursion_dotprodsimp(self, num, prevsimp=None):
2692
+ if prevsimp is None:
2693
+ prevsimp = [True]*len(self)
2694
+
2695
+ if num == 1:
2696
+ return self
2697
+
2698
+ if num % 2 == 1:
2699
+ a, b = self, self._eval_pow_by_recursion_dotprodsimp(num - 1,
2700
+ prevsimp=prevsimp)
2701
+ else:
2702
+ a = b = self._eval_pow_by_recursion_dotprodsimp(num // 2,
2703
+ prevsimp=prevsimp)
2704
+
2705
+ m = a.multiply(b, dotprodsimp=False)
2706
+ lenm = len(m)
2707
+ elems = [None]*lenm
2708
+
2709
+ for i in range(lenm):
2710
+ if prevsimp[i]:
2711
+ elems[i], prevsimp[i] = _dotprodsimp(m[i], withsimp=True)
2712
+ else:
2713
+ elems[i] = m[i]
2714
+
2715
+ return m._new(m.rows, m.cols, elems)
2716
+
2717
+ def _eval_scalar_mul(self, other):
2718
+ return self._new(self.rows, self.cols, lambda i, j: self[i,j]*other)
2719
+
2720
+ def _eval_scalar_rmul(self, other):
2721
+ return self._new(self.rows, self.cols, lambda i, j: other*self[i,j])
2722
+
2723
+ def _eval_Mod(self, other):
2724
+ return self._new(self.rows, self.cols, lambda i, j: Mod(self[i, j], other))
2725
+
2726
+ # Python arithmetic functions
2727
+ def __abs__(self):
2728
+ """Returns a new matrix with entry-wise absolute values."""
2729
+ return self._eval_Abs()
2730
+
2731
+ @call_highest_priority('__radd__')
2732
+ def __add__(self, other):
2733
+ """Return self + other, raising ShapeError if shapes do not match."""
2734
+ if isinstance(other, NDimArray): # Matrix and array addition is currently not implemented
2735
+ return NotImplemented
2736
+ other = _matrixify(other)
2737
+ # matrix-like objects can have shapes. This is
2738
+ # our first sanity check.
2739
+ if hasattr(other, 'shape'):
2740
+ if self.shape != other.shape:
2741
+ raise ShapeError("Matrix size mismatch: %s + %s" % (
2742
+ self.shape, other.shape))
2743
+
2744
+ # honest SymPy matrices defer to their class's routine
2745
+ if getattr(other, 'is_Matrix', False):
2746
+ # call the highest-priority class's _eval_add
2747
+ a, b = self, other
2748
+ if a.__class__ != classof(a, b):
2749
+ b, a = a, b
2750
+ return a._eval_add(b)
2751
+ # Matrix-like objects can be passed to CommonMatrix routines directly.
2752
+ if getattr(other, 'is_MatrixLike', False):
2753
+ return MatrixArithmetic._eval_add(self, other)
2754
+
2755
+ raise TypeError('cannot add %s and %s' % (type(self), type(other)))
2756
+
2757
+ @call_highest_priority('__rtruediv__')
2758
+ def __truediv__(self, other):
2759
+ return self * (self.one / other)
2760
+
2761
+ @call_highest_priority('__rmatmul__')
2762
+ def __matmul__(self, other):
2763
+ other = _matrixify(other)
2764
+ if not getattr(other, 'is_Matrix', False) and not getattr(other, 'is_MatrixLike', False):
2765
+ return NotImplemented
2766
+
2767
+ return self.__mul__(other)
2768
+
2769
+ def __mod__(self, other):
2770
+ return self.applyfunc(lambda x: x % other)
2771
+
2772
+ @call_highest_priority('__rmul__')
2773
+ def __mul__(self, other):
2774
+ """Return self*other where other is either a scalar or a matrix
2775
+ of compatible dimensions.
2776
+
2777
+ Examples
2778
+ ========
2779
+
2780
+ >>> from sympy import Matrix
2781
+ >>> A = Matrix([[1, 2, 3], [4, 5, 6]])
2782
+ >>> 2*A == A*2 == Matrix([[2, 4, 6], [8, 10, 12]])
2783
+ True
2784
+ >>> B = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
2785
+ >>> A*B
2786
+ Matrix([
2787
+ [30, 36, 42],
2788
+ [66, 81, 96]])
2789
+ >>> B*A
2790
+ Traceback (most recent call last):
2791
+ ...
2792
+ ShapeError: Matrices size mismatch.
2793
+ >>>
2794
+
2795
+ See Also
2796
+ ========
2797
+
2798
+ matrix_multiply_elementwise
2799
+ """
2800
+
2801
+ return self.multiply(other)
2802
+
2803
+ def multiply(self, other, dotprodsimp=None):
2804
+ """Same as __mul__() but with optional simplification.
2805
+
2806
+ Parameters
2807
+ ==========
2808
+
2809
+ dotprodsimp : bool, optional
2810
+ Specifies whether intermediate term algebraic simplification is used
2811
+ during matrix multiplications to control expression blowup and thus
2812
+ speed up calculation. Default is off.
2813
+ """
2814
+
2815
+ isimpbool = _get_intermediate_simp_bool(False, dotprodsimp)
2816
+ other = _matrixify(other)
2817
+ # matrix-like objects can have shapes. This is
2818
+ # our first sanity check. Double check other is not explicitly not a Matrix.
2819
+ if (hasattr(other, 'shape') and len(other.shape) == 2 and
2820
+ (getattr(other, 'is_Matrix', True) or
2821
+ getattr(other, 'is_MatrixLike', True))):
2822
+ if self.shape[1] != other.shape[0]:
2823
+ raise ShapeError("Matrix size mismatch: %s * %s." % (
2824
+ self.shape, other.shape))
2825
+
2826
+ # honest SymPy matrices defer to their class's routine
2827
+ if getattr(other, 'is_Matrix', False):
2828
+ m = self._eval_matrix_mul(other)
2829
+ if isimpbool:
2830
+ return m._new(m.rows, m.cols, [_dotprodsimp(e) for e in m])
2831
+ return m
2832
+
2833
+ # Matrix-like objects can be passed to CommonMatrix routines directly.
2834
+ if getattr(other, 'is_MatrixLike', False):
2835
+ return MatrixArithmetic._eval_matrix_mul(self, other)
2836
+
2837
+ # if 'other' is not iterable then scalar multiplication.
2838
+ if not isinstance(other, Iterable):
2839
+ try:
2840
+ return self._eval_scalar_mul(other)
2841
+ except TypeError:
2842
+ pass
2843
+
2844
+ return NotImplemented
2845
+
2846
+ def multiply_elementwise(self, other):
2847
+ """Return the Hadamard product (elementwise product) of A and B
2848
+
2849
+ Examples
2850
+ ========
2851
+
2852
+ >>> from sympy import Matrix
2853
+ >>> A = Matrix([[0, 1, 2], [3, 4, 5]])
2854
+ >>> B = Matrix([[1, 10, 100], [100, 10, 1]])
2855
+ >>> A.multiply_elementwise(B)
2856
+ Matrix([
2857
+ [ 0, 10, 200],
2858
+ [300, 40, 5]])
2859
+
2860
+ See Also
2861
+ ========
2862
+
2863
+ sympy.matrices.matrixbase.MatrixBase.cross
2864
+ sympy.matrices.matrixbase.MatrixBase.dot
2865
+ multiply
2866
+ """
2867
+ if self.shape != other.shape:
2868
+ raise ShapeError("Matrix shapes must agree {} != {}".format(self.shape, other.shape))
2869
+
2870
+ return self._eval_matrix_mul_elementwise(other)
2871
+
2872
+ def __neg__(self):
2873
+ return self._eval_scalar_mul(-1)
2874
+
2875
+ @call_highest_priority('__rpow__')
2876
+ def __pow__(self, exp):
2877
+ """Return self**exp a scalar or symbol."""
2878
+
2879
+ return self.pow(exp)
2880
+
2881
+
2882
+ def pow(self, exp, method=None):
2883
+ r"""Return self**exp a scalar or symbol.
2884
+
2885
+ Parameters
2886
+ ==========
2887
+
2888
+ method : multiply, mulsimp, jordan, cayley
2889
+ If multiply then it returns exponentiation using recursion.
2890
+ If jordan then Jordan form exponentiation will be used.
2891
+ If cayley then the exponentiation is done using Cayley-Hamilton
2892
+ theorem.
2893
+ If mulsimp then the exponentiation is done using recursion
2894
+ with dotprodsimp. This specifies whether intermediate term
2895
+ algebraic simplification is used during naive matrix power to
2896
+ control expression blowup and thus speed up calculation.
2897
+ If None, then it heuristically decides which method to use.
2898
+
2899
+ """
2900
+
2901
+ if method is not None and method not in ['multiply', 'mulsimp', 'jordan', 'cayley']:
2902
+ raise TypeError('No such method')
2903
+ if self.rows != self.cols:
2904
+ raise NonSquareMatrixError()
2905
+ a = self
2906
+ jordan_pow = getattr(a, '_matrix_pow_by_jordan_blocks', None)
2907
+ exp = sympify(exp)
2908
+
2909
+ if exp.is_zero:
2910
+ return a._new(a.rows, a.cols, lambda i, j: int(i == j))
2911
+ if exp == 1:
2912
+ return a
2913
+
2914
+ diagonal = getattr(a, 'is_diagonal', None)
2915
+ if diagonal is not None and diagonal():
2916
+ return a._new(a.rows, a.cols, lambda i, j: a[i,j]**exp if i == j else 0)
2917
+
2918
+ if exp.is_Number and exp % 1 == 0:
2919
+ if a.rows == 1:
2920
+ return a._new([[a[0]**exp]])
2921
+ if exp < 0:
2922
+ exp = -exp
2923
+ a = a.inv()
2924
+ # When certain conditions are met,
2925
+ # Jordan block algorithm is faster than
2926
+ # computation by recursion.
2927
+ if method == 'jordan':
2928
+ try:
2929
+ return jordan_pow(exp)
2930
+ except MatrixError:
2931
+ if method == 'jordan':
2932
+ raise
2933
+
2934
+ elif method == 'cayley':
2935
+ if not exp.is_Number or exp % 1 != 0:
2936
+ raise ValueError("cayley method is only valid for integer powers")
2937
+ return a._eval_pow_by_cayley(exp)
2938
+
2939
+ elif method == "mulsimp":
2940
+ if not exp.is_Number or exp % 1 != 0:
2941
+ raise ValueError("mulsimp method is only valid for integer powers")
2942
+ return a._eval_pow_by_recursion_dotprodsimp(exp)
2943
+
2944
+ elif method == "multiply":
2945
+ if not exp.is_Number or exp % 1 != 0:
2946
+ raise ValueError("multiply method is only valid for integer powers")
2947
+ return a._eval_pow_by_recursion(exp)
2948
+
2949
+ elif method is None and exp.is_Number and exp % 1 == 0:
2950
+ if exp.is_Float:
2951
+ exp = Integer(exp)
2952
+ # Decide heuristically which method to apply
2953
+ if a.rows == 2 and exp > 100000:
2954
+ return jordan_pow(exp)
2955
+ elif _get_intermediate_simp_bool(True, None):
2956
+ return a._eval_pow_by_recursion_dotprodsimp(exp)
2957
+ elif exp > 10000:
2958
+ return a._eval_pow_by_cayley(exp)
2959
+ else:
2960
+ return a._eval_pow_by_recursion(exp)
2961
+
2962
+ if jordan_pow:
2963
+ try:
2964
+ return jordan_pow(exp)
2965
+ except NonInvertibleMatrixError:
2966
+ # Raised by jordan_pow on zero determinant matrix unless exp is
2967
+ # definitely known to be a non-negative integer.
2968
+ # Here we raise if n is definitely not a non-negative integer
2969
+ # but otherwise we can leave this as an unevaluated MatPow.
2970
+ if exp.is_integer is False or exp.is_nonnegative is False:
2971
+ raise
2972
+
2973
+ from sympy.matrices.expressions import MatPow
2974
+ return MatPow(a, exp)
2975
+
2976
+ @call_highest_priority('__add__')
2977
+ def __radd__(self, other):
2978
+ return self + other
2979
+
2980
+ @call_highest_priority('__matmul__')
2981
+ def __rmatmul__(self, other):
2982
+ other = _matrixify(other)
2983
+ if not getattr(other, 'is_Matrix', False) and not getattr(other, 'is_MatrixLike', False):
2984
+ return NotImplemented
2985
+
2986
+ return self.__rmul__(other)
2987
+
2988
+ @call_highest_priority('__mul__')
2989
+ def __rmul__(self, other):
2990
+ return self.rmultiply(other)
2991
+
2992
+ def rmultiply(self, other, dotprodsimp=None):
2993
+ """Same as __rmul__() but with optional simplification.
2994
+
2995
+ Parameters
2996
+ ==========
2997
+
2998
+ dotprodsimp : bool, optional
2999
+ Specifies whether intermediate term algebraic simplification is used
3000
+ during matrix multiplications to control expression blowup and thus
3001
+ speed up calculation. Default is off.
3002
+ """
3003
+ isimpbool = _get_intermediate_simp_bool(False, dotprodsimp)
3004
+ other = _matrixify(other)
3005
+ # matrix-like objects can have shapes. This is
3006
+ # our first sanity check. Double check other is not explicitly not a Matrix.
3007
+ if (hasattr(other, 'shape') and len(other.shape) == 2 and
3008
+ (getattr(other, 'is_Matrix', True) or
3009
+ getattr(other, 'is_MatrixLike', True))):
3010
+ if self.shape[0] != other.shape[1]:
3011
+ raise ShapeError("Matrix size mismatch.")
3012
+
3013
+ # honest SymPy matrices defer to their class's routine
3014
+ if getattr(other, 'is_Matrix', False):
3015
+ m = self._eval_matrix_rmul(other)
3016
+ if isimpbool:
3017
+ return m._new(m.rows, m.cols, [_dotprodsimp(e) for e in m])
3018
+ return m
3019
+ # Matrix-like objects can be passed to CommonMatrix routines directly.
3020
+ if getattr(other, 'is_MatrixLike', False):
3021
+ return MatrixArithmetic._eval_matrix_rmul(self, other)
3022
+
3023
+ # if 'other' is not iterable then scalar multiplication.
3024
+ if not isinstance(other, Iterable):
3025
+ try:
3026
+ return self._eval_scalar_rmul(other)
3027
+ except TypeError:
3028
+ pass
3029
+
3030
+ return NotImplemented
3031
+
3032
+ @call_highest_priority('__sub__')
3033
+ def __rsub__(self, a):
3034
+ return (-self) + a
3035
+
3036
+ @call_highest_priority('__rsub__')
3037
+ def __sub__(self, a):
3038
+ return self + (-a)
3039
+
3040
+
3041
+ class MatrixCommon(MatrixArithmetic, MatrixOperations, MatrixProperties,
3042
+ MatrixSpecial, MatrixShaping):
3043
+ """All common matrix operations including basic arithmetic, shaping,
3044
+ and special matrices like `zeros`, and `eye`."""
3045
+ _diff_wrt = True # type: bool
3046
+
3047
+
3048
+ class _MinimalMatrix:
3049
+ """Class providing the minimum functionality
3050
+ for a matrix-like object and implementing every method
3051
+ required for a `MatrixRequired`. This class does not have everything
3052
+ needed to become a full-fledged SymPy object, but it will satisfy the
3053
+ requirements of anything inheriting from `MatrixRequired`. If you wish
3054
+ to make a specialized matrix type, make sure to implement these
3055
+ methods and properties with the exception of `__init__` and `__repr__`
3056
+ which are included for convenience."""
3057
+
3058
+ is_MatrixLike = True
3059
+ _sympify = staticmethod(sympify)
3060
+ _class_priority = 3
3061
+ zero = S.Zero
3062
+ one = S.One
3063
+
3064
+ is_Matrix = True
3065
+ is_MatrixExpr = False
3066
+
3067
+ @classmethod
3068
+ def _new(cls, *args, **kwargs):
3069
+ return cls(*args, **kwargs)
3070
+
3071
+ def __init__(self, rows, cols=None, mat=None, copy=False):
3072
+ if isfunction(mat):
3073
+ # if we passed in a function, use that to populate the indices
3074
+ mat = [mat(i, j) for i in range(rows) for j in range(cols)]
3075
+ if cols is None and mat is None:
3076
+ mat = rows
3077
+ rows, cols = getattr(mat, 'shape', (rows, cols))
3078
+ try:
3079
+ # if we passed in a list of lists, flatten it and set the size
3080
+ if cols is None and mat is None:
3081
+ mat = rows
3082
+ cols = len(mat[0])
3083
+ rows = len(mat)
3084
+ mat = [x for l in mat for x in l]
3085
+ except (IndexError, TypeError):
3086
+ pass
3087
+ self.mat = tuple(self._sympify(x) for x in mat)
3088
+ self.rows, self.cols = rows, cols
3089
+ if self.rows is None or self.cols is None:
3090
+ raise NotImplementedError("Cannot initialize matrix with given parameters")
3091
+
3092
+ def __getitem__(self, key):
3093
+ def _normalize_slices(row_slice, col_slice):
3094
+ """Ensure that row_slice and col_slice do not have
3095
+ `None` in their arguments. Any integers are converted
3096
+ to slices of length 1"""
3097
+ if not isinstance(row_slice, slice):
3098
+ row_slice = slice(row_slice, row_slice + 1, None)
3099
+ row_slice = slice(*row_slice.indices(self.rows))
3100
+
3101
+ if not isinstance(col_slice, slice):
3102
+ col_slice = slice(col_slice, col_slice + 1, None)
3103
+ col_slice = slice(*col_slice.indices(self.cols))
3104
+
3105
+ return (row_slice, col_slice)
3106
+
3107
+ def _coord_to_index(i, j):
3108
+ """Return the index in _mat corresponding
3109
+ to the (i,j) position in the matrix. """
3110
+ return i * self.cols + j
3111
+
3112
+ if isinstance(key, tuple):
3113
+ i, j = key
3114
+ if isinstance(i, slice) or isinstance(j, slice):
3115
+ # if the coordinates are not slices, make them so
3116
+ # and expand the slices so they don't contain `None`
3117
+ i, j = _normalize_slices(i, j)
3118
+
3119
+ rowsList, colsList = list(range(self.rows))[i], \
3120
+ list(range(self.cols))[j]
3121
+ indices = (i * self.cols + j for i in rowsList for j in
3122
+ colsList)
3123
+ return self._new(len(rowsList), len(colsList),
3124
+ [self.mat[i] for i in indices])
3125
+
3126
+ # if the key is a tuple of ints, change
3127
+ # it to an array index
3128
+ key = _coord_to_index(i, j)
3129
+ return self.mat[key]
3130
+
3131
+ def __eq__(self, other):
3132
+ try:
3133
+ classof(self, other)
3134
+ except TypeError:
3135
+ return False
3136
+ return (
3137
+ self.shape == other.shape and list(self) == list(other))
3138
+
3139
+ def __len__(self):
3140
+ return self.rows*self.cols
3141
+
3142
+ def __repr__(self):
3143
+ return "_MinimalMatrix({}, {}, {})".format(self.rows, self.cols,
3144
+ self.mat)
3145
+
3146
+ @property
3147
+ def shape(self):
3148
+ return (self.rows, self.cols)
3149
+
3150
+
3151
+ class _CastableMatrix: # this is needed here ONLY FOR TESTS.
3152
+ def as_mutable(self):
3153
+ return self
3154
+
3155
+ def as_immutable(self):
3156
+ return self
3157
+
3158
+
3159
+ class _MatrixWrapper:
3160
+ """Wrapper class providing the minimum functionality for a matrix-like
3161
+ object: .rows, .cols, .shape, indexability, and iterability. CommonMatrix
3162
+ math operations should work on matrix-like objects. This one is intended for
3163
+ matrix-like objects which use the same indexing format as SymPy with respect
3164
+ to returning matrix elements instead of rows for non-tuple indexes.
3165
+ """
3166
+
3167
+ is_Matrix = False # needs to be here because of __getattr__
3168
+ is_MatrixLike = True
3169
+
3170
+ def __init__(self, mat, shape):
3171
+ self.mat = mat
3172
+ self.shape = shape
3173
+ self.rows, self.cols = shape
3174
+
3175
+ def __getitem__(self, key):
3176
+ if isinstance(key, tuple):
3177
+ return sympify(self.mat.__getitem__(key))
3178
+
3179
+ return sympify(self.mat.__getitem__((key // self.rows, key % self.cols)))
3180
+
3181
+ def __iter__(self): # supports numpy.matrix and numpy.array
3182
+ mat = self.mat
3183
+ cols = self.cols
3184
+
3185
+ return iter(sympify(mat[r, c]) for r in range(self.rows) for c in range(cols))
3186
+
3187
+
3188
+ def _matrixify(mat):
3189
+ """If `mat` is a Matrix or is matrix-like,
3190
+ return a Matrix or MatrixWrapper object. Otherwise
3191
+ `mat` is passed through without modification."""
3192
+
3193
+ if getattr(mat, 'is_Matrix', False) or getattr(mat, 'is_MatrixLike', False):
3194
+ return mat
3195
+
3196
+ if not(getattr(mat, 'is_Matrix', True) or getattr(mat, 'is_MatrixLike', True)):
3197
+ return mat
3198
+
3199
+ shape = None
3200
+
3201
+ if hasattr(mat, 'shape'): # numpy, scipy.sparse
3202
+ if len(mat.shape) == 2:
3203
+ shape = mat.shape
3204
+ elif hasattr(mat, 'rows') and hasattr(mat, 'cols'): # mpmath
3205
+ shape = (mat.rows, mat.cols)
3206
+
3207
+ if shape:
3208
+ return _MatrixWrapper(mat, shape)
3209
+
3210
+ return mat
3211
+
3212
+
3213
+ def a2idx(j, n=None):
3214
+ """Return integer after making positive and validating against n."""
3215
+ if not isinstance(j, int):
3216
+ jindex = getattr(j, '__index__', None)
3217
+ if jindex is not None:
3218
+ j = jindex()
3219
+ else:
3220
+ raise IndexError("Invalid index a[%r]" % (j,))
3221
+ if n is not None:
3222
+ if j < 0:
3223
+ j += n
3224
+ if not (j >= 0 and j < n):
3225
+ raise IndexError("Index out of range: a[%s]" % (j,))
3226
+ return int(j)
3227
+
3228
+
3229
+ def classof(A, B):
3230
+ """
3231
+ Get the type of the result when combining matrices of different types.
3232
+
3233
+ Currently the strategy is that immutability is contagious.
3234
+
3235
+ Examples
3236
+ ========
3237
+
3238
+ >>> from sympy import Matrix, ImmutableMatrix
3239
+ >>> from sympy.matrices.matrixbase import classof
3240
+ >>> M = Matrix([[1, 2], [3, 4]]) # a Mutable Matrix
3241
+ >>> IM = ImmutableMatrix([[1, 2], [3, 4]])
3242
+ >>> classof(M, IM)
3243
+ <class 'sympy.matrices.immutable.ImmutableDenseMatrix'>
3244
+ """
3245
+ priority_A = getattr(A, '_class_priority', None)
3246
+ priority_B = getattr(B, '_class_priority', None)
3247
+ if None not in (priority_A, priority_B):
3248
+ if A._class_priority > B._class_priority:
3249
+ return A.__class__
3250
+ else:
3251
+ return B.__class__
3252
+
3253
+ try:
3254
+ import numpy
3255
+ except ImportError:
3256
+ pass
3257
+ else:
3258
+ if isinstance(A, numpy.ndarray):
3259
+ return B.__class__
3260
+ if isinstance(B, numpy.ndarray):
3261
+ return A.__class__
3262
+
3263
+ raise TypeError("Incompatible classes %s, %s" % (A.__class__, B.__class__))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/decompositions.py ADDED
@@ -0,0 +1,1621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ from sympy.core import S
4
+ from sympy.core.function import expand_mul
5
+ from sympy.functions.elementary.miscellaneous import Min, sqrt
6
+ from sympy.functions.elementary.complexes import sign
7
+
8
+ from .exceptions import NonSquareMatrixError, NonPositiveDefiniteMatrixError
9
+ from .utilities import _get_intermediate_simp, _iszero
10
+ from .determinant import _find_reasonable_pivot_naive
11
+
12
+
13
+ def _rank_decomposition(M, iszerofunc=_iszero, simplify=False):
14
+ r"""Returns a pair of matrices (`C`, `F`) with matching rank
15
+ such that `A = C F`.
16
+
17
+ Parameters
18
+ ==========
19
+
20
+ iszerofunc : Function, optional
21
+ A function used for detecting whether an element can
22
+ act as a pivot. ``lambda x: x.is_zero`` is used by default.
23
+
24
+ simplify : Bool or Function, optional
25
+ A function used to simplify elements when looking for a
26
+ pivot. By default SymPy's ``simplify`` is used.
27
+
28
+ Returns
29
+ =======
30
+
31
+ (C, F) : Matrices
32
+ `C` and `F` are full-rank matrices with rank as same as `A`,
33
+ whose product gives `A`.
34
+
35
+ See Notes for additional mathematical details.
36
+
37
+ Examples
38
+ ========
39
+
40
+ >>> from sympy import Matrix
41
+ >>> A = Matrix([
42
+ ... [1, 3, 1, 4],
43
+ ... [2, 7, 3, 9],
44
+ ... [1, 5, 3, 1],
45
+ ... [1, 2, 0, 8]
46
+ ... ])
47
+ >>> C, F = A.rank_decomposition()
48
+ >>> C
49
+ Matrix([
50
+ [1, 3, 4],
51
+ [2, 7, 9],
52
+ [1, 5, 1],
53
+ [1, 2, 8]])
54
+ >>> F
55
+ Matrix([
56
+ [1, 0, -2, 0],
57
+ [0, 1, 1, 0],
58
+ [0, 0, 0, 1]])
59
+ >>> C * F == A
60
+ True
61
+
62
+ Notes
63
+ =====
64
+
65
+ Obtaining `F`, an RREF of `A`, is equivalent to creating a
66
+ product
67
+
68
+ .. math::
69
+ E_n E_{n-1} ... E_1 A = F
70
+
71
+ where `E_n, E_{n-1}, \dots, E_1` are the elimination matrices or
72
+ permutation matrices equivalent to each row-reduction step.
73
+
74
+ The inverse of the same product of elimination matrices gives
75
+ `C`:
76
+
77
+ .. math::
78
+ C = \left(E_n E_{n-1} \dots E_1\right)^{-1}
79
+
80
+ It is not necessary, however, to actually compute the inverse:
81
+ the columns of `C` are those from the original matrix with the
82
+ same column indices as the indices of the pivot columns of `F`.
83
+
84
+ References
85
+ ==========
86
+
87
+ .. [1] https://en.wikipedia.org/wiki/Rank_factorization
88
+
89
+ .. [2] Piziak, R.; Odell, P. L. (1 June 1999).
90
+ "Full Rank Factorization of Matrices".
91
+ Mathematics Magazine. 72 (3): 193. doi:10.2307/2690882
92
+
93
+ See Also
94
+ ========
95
+
96
+ sympy.matrices.matrixbase.MatrixBase.rref
97
+ """
98
+
99
+ F, pivot_cols = M.rref(simplify=simplify, iszerofunc=iszerofunc,
100
+ pivots=True)
101
+ rank = len(pivot_cols)
102
+
103
+ C = M.extract(range(M.rows), pivot_cols)
104
+ F = F[:rank, :]
105
+
106
+ return C, F
107
+
108
+
109
+ def _liupc(M):
110
+ """Liu's algorithm, for pre-determination of the Elimination Tree of
111
+ the given matrix, used in row-based symbolic Cholesky factorization.
112
+
113
+ Examples
114
+ ========
115
+
116
+ >>> from sympy import SparseMatrix
117
+ >>> S = SparseMatrix([
118
+ ... [1, 0, 3, 2],
119
+ ... [0, 0, 1, 0],
120
+ ... [4, 0, 0, 5],
121
+ ... [0, 6, 7, 0]])
122
+ >>> S.liupc()
123
+ ([[0], [], [0], [1, 2]], [4, 3, 4, 4])
124
+
125
+ References
126
+ ==========
127
+
128
+ .. [1] Symbolic Sparse Cholesky Factorization using Elimination Trees,
129
+ Jeroen Van Grondelle (1999)
130
+ https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.39.7582
131
+ """
132
+ # Algorithm 2.4, p 17 of reference
133
+
134
+ # get the indices of the elements that are non-zero on or below diag
135
+ R = [[] for r in range(M.rows)]
136
+
137
+ for r, c, _ in M.row_list():
138
+ if c <= r:
139
+ R[r].append(c)
140
+
141
+ inf = len(R) # nothing will be this large
142
+ parent = [inf]*M.rows
143
+ virtual = [inf]*M.rows
144
+
145
+ for r in range(M.rows):
146
+ for c in R[r][:-1]:
147
+ while virtual[c] < r:
148
+ t = virtual[c]
149
+ virtual[c] = r
150
+ c = t
151
+
152
+ if virtual[c] == inf:
153
+ parent[c] = virtual[c] = r
154
+
155
+ return R, parent
156
+
157
+ def _row_structure_symbolic_cholesky(M):
158
+ """Symbolic cholesky factorization, for pre-determination of the
159
+ non-zero structure of the Cholesky factororization.
160
+
161
+ Examples
162
+ ========
163
+
164
+ >>> from sympy import SparseMatrix
165
+ >>> S = SparseMatrix([
166
+ ... [1, 0, 3, 2],
167
+ ... [0, 0, 1, 0],
168
+ ... [4, 0, 0, 5],
169
+ ... [0, 6, 7, 0]])
170
+ >>> S.row_structure_symbolic_cholesky()
171
+ [[0], [], [0], [1, 2]]
172
+
173
+ References
174
+ ==========
175
+
176
+ .. [1] Symbolic Sparse Cholesky Factorization using Elimination Trees,
177
+ Jeroen Van Grondelle (1999)
178
+ https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.39.7582
179
+ """
180
+
181
+ R, parent = M.liupc()
182
+ inf = len(R) # this acts as infinity
183
+ Lrow = copy.deepcopy(R)
184
+
185
+ for k in range(M.rows):
186
+ for j in R[k]:
187
+ while j != inf and j != k:
188
+ Lrow[k].append(j)
189
+ j = parent[j]
190
+
191
+ Lrow[k] = sorted(set(Lrow[k]))
192
+
193
+ return Lrow
194
+
195
+
196
+ def _cholesky(M, hermitian=True):
197
+ """Returns the Cholesky-type decomposition L of a matrix A
198
+ such that L * L.H == A if hermitian flag is True,
199
+ or L * L.T == A if hermitian is False.
200
+
201
+ A must be a Hermitian positive-definite matrix if hermitian is True,
202
+ or a symmetric matrix if it is False.
203
+
204
+ Examples
205
+ ========
206
+
207
+ >>> from sympy import Matrix
208
+ >>> A = Matrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11)))
209
+ >>> A.cholesky()
210
+ Matrix([
211
+ [ 5, 0, 0],
212
+ [ 3, 3, 0],
213
+ [-1, 1, 3]])
214
+ >>> A.cholesky() * A.cholesky().T
215
+ Matrix([
216
+ [25, 15, -5],
217
+ [15, 18, 0],
218
+ [-5, 0, 11]])
219
+
220
+ The matrix can have complex entries:
221
+
222
+ >>> from sympy import I
223
+ >>> A = Matrix(((9, 3*I), (-3*I, 5)))
224
+ >>> A.cholesky()
225
+ Matrix([
226
+ [ 3, 0],
227
+ [-I, 2]])
228
+ >>> A.cholesky() * A.cholesky().H
229
+ Matrix([
230
+ [ 9, 3*I],
231
+ [-3*I, 5]])
232
+
233
+ Non-hermitian Cholesky-type decomposition may be useful when the
234
+ matrix is not positive-definite.
235
+
236
+ >>> A = Matrix([[1, 2], [2, 1]])
237
+ >>> L = A.cholesky(hermitian=False)
238
+ >>> L
239
+ Matrix([
240
+ [1, 0],
241
+ [2, sqrt(3)*I]])
242
+ >>> L*L.T == A
243
+ True
244
+
245
+ See Also
246
+ ========
247
+
248
+ sympy.matrices.dense.DenseMatrix.LDLdecomposition
249
+ sympy.matrices.matrixbase.MatrixBase.LUdecomposition
250
+ QRdecomposition
251
+ """
252
+
253
+ from .dense import MutableDenseMatrix
254
+
255
+ if not M.is_square:
256
+ raise NonSquareMatrixError("Matrix must be square.")
257
+ if hermitian and not M.is_hermitian:
258
+ raise ValueError("Matrix must be Hermitian.")
259
+ if not hermitian and not M.is_symmetric():
260
+ raise ValueError("Matrix must be symmetric.")
261
+
262
+ L = MutableDenseMatrix.zeros(M.rows, M.rows)
263
+
264
+ if hermitian:
265
+ for i in range(M.rows):
266
+ for j in range(i):
267
+ L[i, j] = ((1 / L[j, j])*(M[i, j] -
268
+ sum(L[i, k]*L[j, k].conjugate() for k in range(j))))
269
+
270
+ Lii2 = (M[i, i] -
271
+ sum(L[i, k]*L[i, k].conjugate() for k in range(i)))
272
+
273
+ if Lii2.is_positive is False:
274
+ raise NonPositiveDefiniteMatrixError(
275
+ "Matrix must be positive-definite")
276
+
277
+ L[i, i] = sqrt(Lii2)
278
+
279
+ else:
280
+ for i in range(M.rows):
281
+ for j in range(i):
282
+ L[i, j] = ((1 / L[j, j])*(M[i, j] -
283
+ sum(L[i, k]*L[j, k] for k in range(j))))
284
+
285
+ L[i, i] = sqrt(M[i, i] -
286
+ sum(L[i, k]**2 for k in range(i)))
287
+
288
+ return M._new(L)
289
+
290
+ def _cholesky_sparse(M, hermitian=True):
291
+ """
292
+ Returns the Cholesky decomposition L of a matrix A
293
+ such that L * L.T = A
294
+
295
+ A must be a square, symmetric, positive-definite
296
+ and non-singular matrix
297
+
298
+ Examples
299
+ ========
300
+
301
+ >>> from sympy import SparseMatrix
302
+ >>> A = SparseMatrix(((25,15,-5),(15,18,0),(-5,0,11)))
303
+ >>> A.cholesky()
304
+ Matrix([
305
+ [ 5, 0, 0],
306
+ [ 3, 3, 0],
307
+ [-1, 1, 3]])
308
+ >>> A.cholesky() * A.cholesky().T == A
309
+ True
310
+
311
+ The matrix can have complex entries:
312
+
313
+ >>> from sympy import I
314
+ >>> A = SparseMatrix(((9, 3*I), (-3*I, 5)))
315
+ >>> A.cholesky()
316
+ Matrix([
317
+ [ 3, 0],
318
+ [-I, 2]])
319
+ >>> A.cholesky() * A.cholesky().H
320
+ Matrix([
321
+ [ 9, 3*I],
322
+ [-3*I, 5]])
323
+
324
+ Non-hermitian Cholesky-type decomposition may be useful when the
325
+ matrix is not positive-definite.
326
+
327
+ >>> A = SparseMatrix([[1, 2], [2, 1]])
328
+ >>> L = A.cholesky(hermitian=False)
329
+ >>> L
330
+ Matrix([
331
+ [1, 0],
332
+ [2, sqrt(3)*I]])
333
+ >>> L*L.T == A
334
+ True
335
+
336
+ See Also
337
+ ========
338
+
339
+ sympy.matrices.sparse.SparseMatrix.LDLdecomposition
340
+ sympy.matrices.matrixbase.MatrixBase.LUdecomposition
341
+ QRdecomposition
342
+ """
343
+
344
+ from .dense import MutableDenseMatrix
345
+
346
+ if not M.is_square:
347
+ raise NonSquareMatrixError("Matrix must be square.")
348
+ if hermitian and not M.is_hermitian:
349
+ raise ValueError("Matrix must be Hermitian.")
350
+ if not hermitian and not M.is_symmetric():
351
+ raise ValueError("Matrix must be symmetric.")
352
+
353
+ dps = _get_intermediate_simp(expand_mul, expand_mul)
354
+ Crowstruc = M.row_structure_symbolic_cholesky()
355
+ C = MutableDenseMatrix.zeros(M.rows)
356
+
357
+ for i in range(len(Crowstruc)):
358
+ for j in Crowstruc[i]:
359
+ if i != j:
360
+ C[i, j] = M[i, j]
361
+ summ = 0
362
+
363
+ for p1 in Crowstruc[i]:
364
+ if p1 < j:
365
+ for p2 in Crowstruc[j]:
366
+ if p2 < j:
367
+ if p1 == p2:
368
+ if hermitian:
369
+ summ += C[i, p1]*C[j, p1].conjugate()
370
+ else:
371
+ summ += C[i, p1]*C[j, p1]
372
+ else:
373
+ break
374
+ else:
375
+ break
376
+
377
+ C[i, j] = dps((C[i, j] - summ) / C[j, j])
378
+
379
+ else: # i == j
380
+ C[j, j] = M[j, j]
381
+ summ = 0
382
+
383
+ for k in Crowstruc[j]:
384
+ if k < j:
385
+ if hermitian:
386
+ summ += C[j, k]*C[j, k].conjugate()
387
+ else:
388
+ summ += C[j, k]**2
389
+ else:
390
+ break
391
+
392
+ Cjj2 = dps(C[j, j] - summ)
393
+
394
+ if hermitian and Cjj2.is_positive is False:
395
+ raise NonPositiveDefiniteMatrixError(
396
+ "Matrix must be positive-definite")
397
+
398
+ C[j, j] = sqrt(Cjj2)
399
+
400
+ return M._new(C)
401
+
402
+
403
+ def _LDLdecomposition(M, hermitian=True):
404
+ """Returns the LDL Decomposition (L, D) of matrix A,
405
+ such that L * D * L.H == A if hermitian flag is True, or
406
+ L * D * L.T == A if hermitian is False.
407
+ This method eliminates the use of square root.
408
+ Further this ensures that all the diagonal entries of L are 1.
409
+ A must be a Hermitian positive-definite matrix if hermitian is True,
410
+ or a symmetric matrix otherwise.
411
+
412
+ Examples
413
+ ========
414
+
415
+ >>> from sympy import Matrix, eye
416
+ >>> A = Matrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11)))
417
+ >>> L, D = A.LDLdecomposition()
418
+ >>> L
419
+ Matrix([
420
+ [ 1, 0, 0],
421
+ [ 3/5, 1, 0],
422
+ [-1/5, 1/3, 1]])
423
+ >>> D
424
+ Matrix([
425
+ [25, 0, 0],
426
+ [ 0, 9, 0],
427
+ [ 0, 0, 9]])
428
+ >>> L * D * L.T * A.inv() == eye(A.rows)
429
+ True
430
+
431
+ The matrix can have complex entries:
432
+
433
+ >>> from sympy import I
434
+ >>> A = Matrix(((9, 3*I), (-3*I, 5)))
435
+ >>> L, D = A.LDLdecomposition()
436
+ >>> L
437
+ Matrix([
438
+ [ 1, 0],
439
+ [-I/3, 1]])
440
+ >>> D
441
+ Matrix([
442
+ [9, 0],
443
+ [0, 4]])
444
+ >>> L*D*L.H == A
445
+ True
446
+
447
+ See Also
448
+ ========
449
+
450
+ sympy.matrices.dense.DenseMatrix.cholesky
451
+ sympy.matrices.matrixbase.MatrixBase.LUdecomposition
452
+ QRdecomposition
453
+ """
454
+
455
+ from .dense import MutableDenseMatrix
456
+
457
+ if not M.is_square:
458
+ raise NonSquareMatrixError("Matrix must be square.")
459
+ if hermitian and not M.is_hermitian:
460
+ raise ValueError("Matrix must be Hermitian.")
461
+ if not hermitian and not M.is_symmetric():
462
+ raise ValueError("Matrix must be symmetric.")
463
+
464
+ D = MutableDenseMatrix.zeros(M.rows, M.rows)
465
+ L = MutableDenseMatrix.eye(M.rows)
466
+
467
+ if hermitian:
468
+ for i in range(M.rows):
469
+ for j in range(i):
470
+ L[i, j] = (1 / D[j, j])*(M[i, j] - sum(
471
+ L[i, k]*L[j, k].conjugate()*D[k, k] for k in range(j)))
472
+
473
+ D[i, i] = (M[i, i] -
474
+ sum(L[i, k]*L[i, k].conjugate()*D[k, k] for k in range(i)))
475
+
476
+ if D[i, i].is_positive is False:
477
+ raise NonPositiveDefiniteMatrixError(
478
+ "Matrix must be positive-definite")
479
+
480
+ else:
481
+ for i in range(M.rows):
482
+ for j in range(i):
483
+ L[i, j] = (1 / D[j, j])*(M[i, j] - sum(
484
+ L[i, k]*L[j, k]*D[k, k] for k in range(j)))
485
+
486
+ D[i, i] = M[i, i] - sum(L[i, k]**2*D[k, k] for k in range(i))
487
+
488
+ return M._new(L), M._new(D)
489
+
490
+ def _LDLdecomposition_sparse(M, hermitian=True):
491
+ """
492
+ Returns the LDL Decomposition (matrices ``L`` and ``D``) of matrix
493
+ ``A``, such that ``L * D * L.T == A``. ``A`` must be a square,
494
+ symmetric, positive-definite and non-singular.
495
+
496
+ This method eliminates the use of square root and ensures that all
497
+ the diagonal entries of L are 1.
498
+
499
+ Examples
500
+ ========
501
+
502
+ >>> from sympy import SparseMatrix
503
+ >>> A = SparseMatrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11)))
504
+ >>> L, D = A.LDLdecomposition()
505
+ >>> L
506
+ Matrix([
507
+ [ 1, 0, 0],
508
+ [ 3/5, 1, 0],
509
+ [-1/5, 1/3, 1]])
510
+ >>> D
511
+ Matrix([
512
+ [25, 0, 0],
513
+ [ 0, 9, 0],
514
+ [ 0, 0, 9]])
515
+ >>> L * D * L.T == A
516
+ True
517
+
518
+ """
519
+
520
+ from .dense import MutableDenseMatrix
521
+
522
+ if not M.is_square:
523
+ raise NonSquareMatrixError("Matrix must be square.")
524
+ if hermitian and not M.is_hermitian:
525
+ raise ValueError("Matrix must be Hermitian.")
526
+ if not hermitian and not M.is_symmetric():
527
+ raise ValueError("Matrix must be symmetric.")
528
+
529
+ dps = _get_intermediate_simp(expand_mul, expand_mul)
530
+ Lrowstruc = M.row_structure_symbolic_cholesky()
531
+ L = MutableDenseMatrix.eye(M.rows)
532
+ D = MutableDenseMatrix.zeros(M.rows, M.cols)
533
+
534
+ for i in range(len(Lrowstruc)):
535
+ for j in Lrowstruc[i]:
536
+ if i != j:
537
+ L[i, j] = M[i, j]
538
+ summ = 0
539
+
540
+ for p1 in Lrowstruc[i]:
541
+ if p1 < j:
542
+ for p2 in Lrowstruc[j]:
543
+ if p2 < j:
544
+ if p1 == p2:
545
+ if hermitian:
546
+ summ += L[i, p1]*L[j, p1].conjugate()*D[p1, p1]
547
+ else:
548
+ summ += L[i, p1]*L[j, p1]*D[p1, p1]
549
+ else:
550
+ break
551
+ else:
552
+ break
553
+
554
+ L[i, j] = dps((L[i, j] - summ) / D[j, j])
555
+
556
+ else: # i == j
557
+ D[i, i] = M[i, i]
558
+ summ = 0
559
+
560
+ for k in Lrowstruc[i]:
561
+ if k < i:
562
+ if hermitian:
563
+ summ += L[i, k]*L[i, k].conjugate()*D[k, k]
564
+ else:
565
+ summ += L[i, k]**2*D[k, k]
566
+ else:
567
+ break
568
+
569
+ D[i, i] = dps(D[i, i] - summ)
570
+
571
+ if hermitian and D[i, i].is_positive is False:
572
+ raise NonPositiveDefiniteMatrixError(
573
+ "Matrix must be positive-definite")
574
+
575
+ return M._new(L), M._new(D)
576
+
577
+
578
+ def _LUdecomposition(M, iszerofunc=_iszero, simpfunc=None, rankcheck=False):
579
+ """Returns (L, U, perm) where L is a lower triangular matrix with unit
580
+ diagonal, U is an upper triangular matrix, and perm is a list of row
581
+ swap index pairs. If A is the original matrix, then
582
+ ``A = (L*U).permuteBkwd(perm)``, and the row permutation matrix P such
583
+ that $P A = L U$ can be computed by ``P = eye(A.rows).permuteFwd(perm)``.
584
+
585
+ See documentation for LUCombined for details about the keyword argument
586
+ rankcheck, iszerofunc, and simpfunc.
587
+
588
+ Parameters
589
+ ==========
590
+
591
+ rankcheck : bool, optional
592
+ Determines if this function should detect the rank
593
+ deficiency of the matrixis and should raise a
594
+ ``ValueError``.
595
+
596
+ iszerofunc : function, optional
597
+ A function which determines if a given expression is zero.
598
+
599
+ The function should be a callable that takes a single
600
+ SymPy expression and returns a 3-valued boolean value
601
+ ``True``, ``False``, or ``None``.
602
+
603
+ It is internally used by the pivot searching algorithm.
604
+ See the notes section for a more information about the
605
+ pivot searching algorithm.
606
+
607
+ simpfunc : function or None, optional
608
+ A function that simplifies the input.
609
+
610
+ If this is specified as a function, this function should be
611
+ a callable that takes a single SymPy expression and returns
612
+ an another SymPy expression that is algebraically
613
+ equivalent.
614
+
615
+ If ``None``, it indicates that the pivot search algorithm
616
+ should not attempt to simplify any candidate pivots.
617
+
618
+ It is internally used by the pivot searching algorithm.
619
+ See the notes section for a more information about the
620
+ pivot searching algorithm.
621
+
622
+ Examples
623
+ ========
624
+
625
+ >>> from sympy import Matrix
626
+ >>> a = Matrix([[4, 3], [6, 3]])
627
+ >>> L, U, _ = a.LUdecomposition()
628
+ >>> L
629
+ Matrix([
630
+ [ 1, 0],
631
+ [3/2, 1]])
632
+ >>> U
633
+ Matrix([
634
+ [4, 3],
635
+ [0, -3/2]])
636
+
637
+ See Also
638
+ ========
639
+
640
+ sympy.matrices.dense.DenseMatrix.cholesky
641
+ sympy.matrices.dense.DenseMatrix.LDLdecomposition
642
+ QRdecomposition
643
+ LUdecomposition_Simple
644
+ LUdecompositionFF
645
+ LUsolve
646
+ """
647
+
648
+ combined, p = M.LUdecomposition_Simple(iszerofunc=iszerofunc,
649
+ simpfunc=simpfunc, rankcheck=rankcheck)
650
+
651
+ # L is lower triangular ``M.rows x M.rows``
652
+ # U is upper triangular ``M.rows x M.cols``
653
+ # L has unit diagonal. For each column in combined, the subcolumn
654
+ # below the diagonal of combined is shared by L.
655
+ # If L has more columns than combined, then the remaining subcolumns
656
+ # below the diagonal of L are zero.
657
+ # The upper triangular portion of L and combined are equal.
658
+ def entry_L(i, j):
659
+ if i < j:
660
+ # Super diagonal entry
661
+ return M.zero
662
+ elif i == j:
663
+ return M.one
664
+ elif j < combined.cols:
665
+ return combined[i, j]
666
+
667
+ # Subdiagonal entry of L with no corresponding
668
+ # entry in combined
669
+ return M.zero
670
+
671
+ def entry_U(i, j):
672
+ return M.zero if i > j else combined[i, j]
673
+
674
+ L = M._new(combined.rows, combined.rows, entry_L)
675
+ U = M._new(combined.rows, combined.cols, entry_U)
676
+
677
+ return L, U, p
678
+
679
+ def _LUdecomposition_Simple(M, iszerofunc=_iszero, simpfunc=None,
680
+ rankcheck=False):
681
+ r"""Compute the PLU decomposition of the matrix.
682
+
683
+ Parameters
684
+ ==========
685
+
686
+ rankcheck : bool, optional
687
+ Determines if this function should detect the rank
688
+ deficiency of the matrixis and should raise a
689
+ ``ValueError``.
690
+
691
+ iszerofunc : function, optional
692
+ A function which determines if a given expression is zero.
693
+
694
+ The function should be a callable that takes a single
695
+ SymPy expression and returns a 3-valued boolean value
696
+ ``True``, ``False``, or ``None``.
697
+
698
+ It is internally used by the pivot searching algorithm.
699
+ See the notes section for a more information about the
700
+ pivot searching algorithm.
701
+
702
+ simpfunc : function or None, optional
703
+ A function that simplifies the input.
704
+
705
+ If this is specified as a function, this function should be
706
+ a callable that takes a single SymPy expression and returns
707
+ an another SymPy expression that is algebraically
708
+ equivalent.
709
+
710
+ If ``None``, it indicates that the pivot search algorithm
711
+ should not attempt to simplify any candidate pivots.
712
+
713
+ It is internally used by the pivot searching algorithm.
714
+ See the notes section for a more information about the
715
+ pivot searching algorithm.
716
+
717
+ Returns
718
+ =======
719
+
720
+ (lu, row_swaps) : (Matrix, list)
721
+ If the original matrix is a $m, n$ matrix:
722
+
723
+ *lu* is a $m, n$ matrix, which contains result of the
724
+ decomposition in a compressed form. See the notes section
725
+ to see how the matrix is compressed.
726
+
727
+ *row_swaps* is a $m$-element list where each element is a
728
+ pair of row exchange indices.
729
+
730
+ ``A = (L*U).permute_backward(perm)``, and the row
731
+ permutation matrix $P$ from the formula $P A = L U$ can be
732
+ computed by ``P=eye(A.row).permute_forward(perm)``.
733
+
734
+ Raises
735
+ ======
736
+
737
+ ValueError
738
+ Raised if ``rankcheck=True`` and the matrix is found to
739
+ be rank deficient during the computation.
740
+
741
+ Notes
742
+ =====
743
+
744
+ About the PLU decomposition:
745
+
746
+ PLU decomposition is a generalization of a LU decomposition
747
+ which can be extended for rank-deficient matrices.
748
+
749
+ It can further be generalized for non-square matrices, and this
750
+ is the notation that SymPy is using.
751
+
752
+ PLU decomposition is a decomposition of a $m, n$ matrix $A$ in
753
+ the form of $P A = L U$ where
754
+
755
+ * $L$ is a $m, m$ lower triangular matrix with unit diagonal
756
+ entries.
757
+ * $U$ is a $m, n$ upper triangular matrix.
758
+ * $P$ is a $m, m$ permutation matrix.
759
+
760
+ So, for a square matrix, the decomposition would look like:
761
+
762
+ .. math::
763
+ L = \begin{bmatrix}
764
+ 1 & 0 & 0 & \cdots & 0 \\
765
+ L_{1, 0} & 1 & 0 & \cdots & 0 \\
766
+ L_{2, 0} & L_{2, 1} & 1 & \cdots & 0 \\
767
+ \vdots & \vdots & \vdots & \ddots & \vdots \\
768
+ L_{n-1, 0} & L_{n-1, 1} & L_{n-1, 2} & \cdots & 1
769
+ \end{bmatrix}
770
+
771
+ .. math::
772
+ U = \begin{bmatrix}
773
+ U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, n-1} \\
774
+ 0 & U_{1, 1} & U_{1, 2} & \cdots & U_{1, n-1} \\
775
+ 0 & 0 & U_{2, 2} & \cdots & U_{2, n-1} \\
776
+ \vdots & \vdots & \vdots & \ddots & \vdots \\
777
+ 0 & 0 & 0 & \cdots & U_{n-1, n-1}
778
+ \end{bmatrix}
779
+
780
+ And for a matrix with more rows than the columns,
781
+ the decomposition would look like:
782
+
783
+ .. math::
784
+ L = \begin{bmatrix}
785
+ 1 & 0 & 0 & \cdots & 0 & 0 & \cdots & 0 \\
786
+ L_{1, 0} & 1 & 0 & \cdots & 0 & 0 & \cdots & 0 \\
787
+ L_{2, 0} & L_{2, 1} & 1 & \cdots & 0 & 0 & \cdots & 0 \\
788
+ \vdots & \vdots & \vdots & \ddots & \vdots & \vdots & \ddots
789
+ & \vdots \\
790
+ L_{n-1, 0} & L_{n-1, 1} & L_{n-1, 2} & \cdots & 1 & 0
791
+ & \cdots & 0 \\
792
+ L_{n, 0} & L_{n, 1} & L_{n, 2} & \cdots & L_{n, n-1} & 1
793
+ & \cdots & 0 \\
794
+ \vdots & \vdots & \vdots & \ddots & \vdots & \vdots
795
+ & \ddots & \vdots \\
796
+ L_{m-1, 0} & L_{m-1, 1} & L_{m-1, 2} & \cdots & L_{m-1, n-1}
797
+ & 0 & \cdots & 1 \\
798
+ \end{bmatrix}
799
+
800
+ .. math::
801
+ U = \begin{bmatrix}
802
+ U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, n-1} \\
803
+ 0 & U_{1, 1} & U_{1, 2} & \cdots & U_{1, n-1} \\
804
+ 0 & 0 & U_{2, 2} & \cdots & U_{2, n-1} \\
805
+ \vdots & \vdots & \vdots & \ddots & \vdots \\
806
+ 0 & 0 & 0 & \cdots & U_{n-1, n-1} \\
807
+ 0 & 0 & 0 & \cdots & 0 \\
808
+ \vdots & \vdots & \vdots & \ddots & \vdots \\
809
+ 0 & 0 & 0 & \cdots & 0
810
+ \end{bmatrix}
811
+
812
+ Finally, for a matrix with more columns than the rows, the
813
+ decomposition would look like:
814
+
815
+ .. math::
816
+ L = \begin{bmatrix}
817
+ 1 & 0 & 0 & \cdots & 0 \\
818
+ L_{1, 0} & 1 & 0 & \cdots & 0 \\
819
+ L_{2, 0} & L_{2, 1} & 1 & \cdots & 0 \\
820
+ \vdots & \vdots & \vdots & \ddots & \vdots \\
821
+ L_{m-1, 0} & L_{m-1, 1} & L_{m-1, 2} & \cdots & 1
822
+ \end{bmatrix}
823
+
824
+ .. math::
825
+ U = \begin{bmatrix}
826
+ U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, m-1}
827
+ & \cdots & U_{0, n-1} \\
828
+ 0 & U_{1, 1} & U_{1, 2} & \cdots & U_{1, m-1}
829
+ & \cdots & U_{1, n-1} \\
830
+ 0 & 0 & U_{2, 2} & \cdots & U_{2, m-1}
831
+ & \cdots & U_{2, n-1} \\
832
+ \vdots & \vdots & \vdots & \ddots & \vdots
833
+ & \cdots & \vdots \\
834
+ 0 & 0 & 0 & \cdots & U_{m-1, m-1}
835
+ & \cdots & U_{m-1, n-1} \\
836
+ \end{bmatrix}
837
+
838
+ About the compressed LU storage:
839
+
840
+ The results of the decomposition are often stored in compressed
841
+ forms rather than returning $L$ and $U$ matrices individually.
842
+
843
+ It may be less intiuitive, but it is commonly used for a lot of
844
+ numeric libraries because of the efficiency.
845
+
846
+ The storage matrix is defined as following for this specific
847
+ method:
848
+
849
+ * The subdiagonal elements of $L$ are stored in the subdiagonal
850
+ portion of $LU$, that is $LU_{i, j} = L_{i, j}$ whenever
851
+ $i > j$.
852
+ * The elements on the diagonal of $L$ are all 1, and are not
853
+ explicitly stored.
854
+ * $U$ is stored in the upper triangular portion of $LU$, that is
855
+ $LU_{i, j} = U_{i, j}$ whenever $i <= j$.
856
+ * For a case of $m > n$, the right side of the $L$ matrix is
857
+ trivial to store.
858
+ * For a case of $m < n$, the below side of the $U$ matrix is
859
+ trivial to store.
860
+
861
+ So, for a square matrix, the compressed output matrix would be:
862
+
863
+ .. math::
864
+ LU = \begin{bmatrix}
865
+ U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, n-1} \\
866
+ L_{1, 0} & U_{1, 1} & U_{1, 2} & \cdots & U_{1, n-1} \\
867
+ L_{2, 0} & L_{2, 1} & U_{2, 2} & \cdots & U_{2, n-1} \\
868
+ \vdots & \vdots & \vdots & \ddots & \vdots \\
869
+ L_{n-1, 0} & L_{n-1, 1} & L_{n-1, 2} & \cdots & U_{n-1, n-1}
870
+ \end{bmatrix}
871
+
872
+ For a matrix with more rows than the columns, the compressed
873
+ output matrix would be:
874
+
875
+ .. math::
876
+ LU = \begin{bmatrix}
877
+ U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, n-1} \\
878
+ L_{1, 0} & U_{1, 1} & U_{1, 2} & \cdots & U_{1, n-1} \\
879
+ L_{2, 0} & L_{2, 1} & U_{2, 2} & \cdots & U_{2, n-1} \\
880
+ \vdots & \vdots & \vdots & \ddots & \vdots \\
881
+ L_{n-1, 0} & L_{n-1, 1} & L_{n-1, 2} & \cdots
882
+ & U_{n-1, n-1} \\
883
+ \vdots & \vdots & \vdots & \ddots & \vdots \\
884
+ L_{m-1, 0} & L_{m-1, 1} & L_{m-1, 2} & \cdots
885
+ & L_{m-1, n-1} \\
886
+ \end{bmatrix}
887
+
888
+ For a matrix with more columns than the rows, the compressed
889
+ output matrix would be:
890
+
891
+ .. math::
892
+ LU = \begin{bmatrix}
893
+ U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, m-1}
894
+ & \cdots & U_{0, n-1} \\
895
+ L_{1, 0} & U_{1, 1} & U_{1, 2} & \cdots & U_{1, m-1}
896
+ & \cdots & U_{1, n-1} \\
897
+ L_{2, 0} & L_{2, 1} & U_{2, 2} & \cdots & U_{2, m-1}
898
+ & \cdots & U_{2, n-1} \\
899
+ \vdots & \vdots & \vdots & \ddots & \vdots
900
+ & \cdots & \vdots \\
901
+ L_{m-1, 0} & L_{m-1, 1} & L_{m-1, 2} & \cdots & U_{m-1, m-1}
902
+ & \cdots & U_{m-1, n-1} \\
903
+ \end{bmatrix}
904
+
905
+ About the pivot searching algorithm:
906
+
907
+ When a matrix contains symbolic entries, the pivot search algorithm
908
+ differs from the case where every entry can be categorized as zero or
909
+ nonzero.
910
+ The algorithm searches column by column through the submatrix whose
911
+ top left entry coincides with the pivot position.
912
+ If it exists, the pivot is the first entry in the current search
913
+ column that iszerofunc guarantees is nonzero.
914
+ If no such candidate exists, then each candidate pivot is simplified
915
+ if simpfunc is not None.
916
+ The search is repeated, with the difference that a candidate may be
917
+ the pivot if ``iszerofunc()`` cannot guarantee that it is nonzero.
918
+ In the second search the pivot is the first candidate that
919
+ iszerofunc can guarantee is nonzero.
920
+ If no such candidate exists, then the pivot is the first candidate
921
+ for which iszerofunc returns None.
922
+ If no such candidate exists, then the search is repeated in the next
923
+ column to the right.
924
+ The pivot search algorithm differs from the one in ``rref()``, which
925
+ relies on ``_find_reasonable_pivot()``.
926
+ Future versions of ``LUdecomposition_simple()`` may use
927
+ ``_find_reasonable_pivot()``.
928
+
929
+ See Also
930
+ ========
931
+
932
+ sympy.matrices.matrixbase.MatrixBase.LUdecomposition
933
+ LUdecompositionFF
934
+ LUsolve
935
+ """
936
+
937
+ if rankcheck:
938
+ # https://github.com/sympy/sympy/issues/9796
939
+ pass
940
+
941
+ if S.Zero in M.shape:
942
+ # Define LU decomposition of a matrix with no entries as a matrix
943
+ # of the same dimensions with all zero entries.
944
+ return M.zeros(M.rows, M.cols), []
945
+
946
+ dps = _get_intermediate_simp()
947
+ lu = M.as_mutable()
948
+ row_swaps = []
949
+
950
+ pivot_col = 0
951
+
952
+ for pivot_row in range(0, lu.rows - 1):
953
+ # Search for pivot. Prefer entry that iszeropivot determines
954
+ # is nonzero, over entry that iszeropivot cannot guarantee
955
+ # is zero.
956
+ # XXX ``_find_reasonable_pivot`` uses slow zero testing. Blocked by bug #10279
957
+ # Future versions of LUdecomposition_simple can pass iszerofunc and simpfunc
958
+ # to _find_reasonable_pivot().
959
+ # In pass 3 of _find_reasonable_pivot(), the predicate in ``if x.equals(S.Zero):``
960
+ # calls sympy.simplify(), and not the simplification function passed in via
961
+ # the keyword argument simpfunc.
962
+ iszeropivot = True
963
+
964
+ while pivot_col != M.cols and iszeropivot:
965
+ sub_col = (lu[r, pivot_col] for r in range(pivot_row, M.rows))
966
+
967
+ pivot_row_offset, pivot_value, is_assumed_non_zero, ind_simplified_pairs =\
968
+ _find_reasonable_pivot_naive(sub_col, iszerofunc, simpfunc)
969
+
970
+ iszeropivot = pivot_value is None
971
+
972
+ if iszeropivot:
973
+ # All candidate pivots in this column are zero.
974
+ # Proceed to next column.
975
+ pivot_col += 1
976
+
977
+ if rankcheck and pivot_col != pivot_row:
978
+ # All entries including and below the pivot position are
979
+ # zero, which indicates that the rank of the matrix is
980
+ # strictly less than min(num rows, num cols)
981
+ # Mimic behavior of previous implementation, by throwing a
982
+ # ValueError.
983
+ raise ValueError("Rank of matrix is strictly less than"
984
+ " number of rows or columns."
985
+ " Pass keyword argument"
986
+ " rankcheck=False to compute"
987
+ " the LU decomposition of this matrix.")
988
+
989
+ candidate_pivot_row = None if pivot_row_offset is None else pivot_row + pivot_row_offset
990
+
991
+ if candidate_pivot_row is None and iszeropivot:
992
+ # If candidate_pivot_row is None and iszeropivot is True
993
+ # after pivot search has completed, then the submatrix
994
+ # below and to the right of (pivot_row, pivot_col) is
995
+ # all zeros, indicating that Gaussian elimination is
996
+ # complete.
997
+ return lu, row_swaps
998
+
999
+ # Update entries simplified during pivot search.
1000
+ for offset, val in ind_simplified_pairs:
1001
+ lu[pivot_row + offset, pivot_col] = val
1002
+
1003
+ if pivot_row != candidate_pivot_row:
1004
+ # Row swap book keeping:
1005
+ # Record which rows were swapped.
1006
+ # Update stored portion of L factor by multiplying L on the
1007
+ # left and right with the current permutation.
1008
+ # Swap rows of U.
1009
+ row_swaps.append([pivot_row, candidate_pivot_row])
1010
+
1011
+ # Update L.
1012
+ lu[pivot_row, 0:pivot_row], lu[candidate_pivot_row, 0:pivot_row] = \
1013
+ lu[candidate_pivot_row, 0:pivot_row], lu[pivot_row, 0:pivot_row]
1014
+
1015
+ # Swap pivot row of U with candidate pivot row.
1016
+ lu[pivot_row, pivot_col:lu.cols], lu[candidate_pivot_row, pivot_col:lu.cols] = \
1017
+ lu[candidate_pivot_row, pivot_col:lu.cols], lu[pivot_row, pivot_col:lu.cols]
1018
+
1019
+ # Introduce zeros below the pivot by adding a multiple of the
1020
+ # pivot row to a row under it, and store the result in the
1021
+ # row under it.
1022
+ # Only entries in the target row whose index is greater than
1023
+ # start_col may be nonzero.
1024
+ start_col = pivot_col + 1
1025
+
1026
+ for row in range(pivot_row + 1, lu.rows):
1027
+ # Store factors of L in the subcolumn below
1028
+ # (pivot_row, pivot_row).
1029
+ lu[row, pivot_row] = \
1030
+ dps(lu[row, pivot_col]/lu[pivot_row, pivot_col])
1031
+
1032
+ # Form the linear combination of the pivot row and the current
1033
+ # row below the pivot row that zeros the entries below the pivot.
1034
+ # Employing slicing instead of a loop here raises
1035
+ # NotImplementedError: Cannot add Zero to MutableSparseMatrix
1036
+ # in sympy/matrices/tests/test_sparse.py.
1037
+ # c = pivot_row + 1 if pivot_row == pivot_col else pivot_col
1038
+ for c in range(start_col, lu.cols):
1039
+ lu[row, c] = dps(lu[row, c] - lu[row, pivot_row]*lu[pivot_row, c])
1040
+
1041
+ if pivot_row != pivot_col:
1042
+ # matrix rank < min(num rows, num cols),
1043
+ # so factors of L are not stored directly below the pivot.
1044
+ # These entries are zero by construction, so don't bother
1045
+ # computing them.
1046
+ for row in range(pivot_row + 1, lu.rows):
1047
+ lu[row, pivot_col] = M.zero
1048
+
1049
+ pivot_col += 1
1050
+
1051
+ if pivot_col == lu.cols:
1052
+ # All candidate pivots are zero implies that Gaussian
1053
+ # elimination is complete.
1054
+ return lu, row_swaps
1055
+
1056
+ if rankcheck:
1057
+ if iszerofunc(
1058
+ lu[Min(lu.rows, lu.cols) - 1, Min(lu.rows, lu.cols) - 1]):
1059
+ raise ValueError("Rank of matrix is strictly less than"
1060
+ " number of rows or columns."
1061
+ " Pass keyword argument"
1062
+ " rankcheck=False to compute"
1063
+ " the LU decomposition of this matrix.")
1064
+
1065
+ return lu, row_swaps
1066
+
1067
+ def _LUdecompositionFF(M):
1068
+ """Compute a fraction-free LU decomposition.
1069
+
1070
+ Returns 4 matrices P, L, D, U such that PA = L D**-1 U.
1071
+ If the elements of the matrix belong to some integral domain I, then all
1072
+ elements of L, D and U are guaranteed to belong to I.
1073
+
1074
+ See Also
1075
+ ========
1076
+
1077
+ sympy.matrices.matrixbase.MatrixBase.LUdecomposition
1078
+ LUdecomposition_Simple
1079
+ LUsolve
1080
+
1081
+ References
1082
+ ==========
1083
+
1084
+ .. [1] W. Zhou & D.J. Jeffrey, "Fraction-free matrix factors: new forms
1085
+ for LU and QR factors". Frontiers in Computer Science in China,
1086
+ Vol 2, no. 1, pp. 67-80, 2008.
1087
+ """
1088
+
1089
+ from sympy.matrices import SparseMatrix
1090
+
1091
+ zeros = SparseMatrix.zeros
1092
+ eye = SparseMatrix.eye
1093
+ n, m = M.rows, M.cols
1094
+ U, L, P = M.as_mutable(), eye(n), eye(n)
1095
+ DD = zeros(n, n)
1096
+ oldpivot = 1
1097
+
1098
+ for k in range(n - 1):
1099
+ if U[k, k] == 0:
1100
+ for kpivot in range(k + 1, n):
1101
+ if U[kpivot, k]:
1102
+ break
1103
+ else:
1104
+ raise ValueError("Matrix is not full rank")
1105
+
1106
+ U[k, k:], U[kpivot, k:] = U[kpivot, k:], U[k, k:]
1107
+ L[k, :k], L[kpivot, :k] = L[kpivot, :k], L[k, :k]
1108
+ P[k, :], P[kpivot, :] = P[kpivot, :], P[k, :]
1109
+
1110
+ L [k, k] = Ukk = U[k, k]
1111
+ DD[k, k] = oldpivot * Ukk
1112
+
1113
+ for i in range(k + 1, n):
1114
+ L[i, k] = Uik = U[i, k]
1115
+
1116
+ for j in range(k + 1, m):
1117
+ U[i, j] = (Ukk * U[i, j] - U[k, j] * Uik) / oldpivot
1118
+
1119
+ U[i, k] = 0
1120
+
1121
+ oldpivot = Ukk
1122
+
1123
+ DD[n - 1, n - 1] = oldpivot
1124
+
1125
+ return P, L, DD, U
1126
+
1127
+ def _singular_value_decomposition(A):
1128
+ r"""Returns a Condensed Singular Value decomposition.
1129
+
1130
+ Explanation
1131
+ ===========
1132
+
1133
+ A Singular Value decomposition is a decomposition in the form $A = U \Sigma V^H$
1134
+ where
1135
+
1136
+ - $U, V$ are column orthogonal matrix.
1137
+ - $\Sigma$ is a diagonal matrix, where the main diagonal contains singular
1138
+ values of matrix A.
1139
+
1140
+ A column orthogonal matrix satisfies
1141
+ $\mathbb{I} = U^H U$ while a full orthogonal matrix satisfies
1142
+ relation $\mathbb{I} = U U^H = U^H U$ where $\mathbb{I}$ is an identity
1143
+ matrix with matching dimensions.
1144
+
1145
+ For matrices which are not square or are rank-deficient, it is
1146
+ sufficient to return a column orthogonal matrix because augmenting
1147
+ them may introduce redundant computations.
1148
+ In condensed Singular Value Decomposition we only return column orthogonal
1149
+ matrices because of this reason
1150
+
1151
+ If you want to augment the results to return a full orthogonal
1152
+ decomposition, you should use the following procedures.
1153
+
1154
+ - Augment the $U, V$ matrices with columns that are orthogonal to every
1155
+ other columns and make it square.
1156
+ - Augment the $\Sigma$ matrix with zero rows to make it have the same
1157
+ shape as the original matrix.
1158
+
1159
+ The procedure will be illustrated in the examples section.
1160
+
1161
+ Examples
1162
+ ========
1163
+
1164
+ we take a full rank matrix first:
1165
+
1166
+ >>> from sympy import Matrix
1167
+ >>> A = Matrix([[1, 2],[2,1]])
1168
+ >>> U, S, V = A.singular_value_decomposition()
1169
+ >>> U
1170
+ Matrix([
1171
+ [ sqrt(2)/2, sqrt(2)/2],
1172
+ [-sqrt(2)/2, sqrt(2)/2]])
1173
+ >>> S
1174
+ Matrix([
1175
+ [1, 0],
1176
+ [0, 3]])
1177
+ >>> V
1178
+ Matrix([
1179
+ [-sqrt(2)/2, sqrt(2)/2],
1180
+ [ sqrt(2)/2, sqrt(2)/2]])
1181
+
1182
+ If a matrix if square and full rank both U, V
1183
+ are orthogonal in both directions
1184
+
1185
+ >>> U * U.H
1186
+ Matrix([
1187
+ [1, 0],
1188
+ [0, 1]])
1189
+ >>> U.H * U
1190
+ Matrix([
1191
+ [1, 0],
1192
+ [0, 1]])
1193
+
1194
+ >>> V * V.H
1195
+ Matrix([
1196
+ [1, 0],
1197
+ [0, 1]])
1198
+ >>> V.H * V
1199
+ Matrix([
1200
+ [1, 0],
1201
+ [0, 1]])
1202
+ >>> A == U * S * V.H
1203
+ True
1204
+
1205
+ >>> C = Matrix([
1206
+ ... [1, 0, 0, 0, 2],
1207
+ ... [0, 0, 3, 0, 0],
1208
+ ... [0, 0, 0, 0, 0],
1209
+ ... [0, 2, 0, 0, 0],
1210
+ ... ])
1211
+ >>> U, S, V = C.singular_value_decomposition()
1212
+
1213
+ >>> V.H * V
1214
+ Matrix([
1215
+ [1, 0, 0],
1216
+ [0, 1, 0],
1217
+ [0, 0, 1]])
1218
+ >>> V * V.H
1219
+ Matrix([
1220
+ [1/5, 0, 0, 0, 2/5],
1221
+ [ 0, 1, 0, 0, 0],
1222
+ [ 0, 0, 1, 0, 0],
1223
+ [ 0, 0, 0, 0, 0],
1224
+ [2/5, 0, 0, 0, 4/5]])
1225
+
1226
+ If you want to augment the results to be a full orthogonal
1227
+ decomposition, you should augment $V$ with an another orthogonal
1228
+ column.
1229
+
1230
+ You are able to append an arbitrary standard basis that are linearly
1231
+ independent to every other columns and you can run the Gram-Schmidt
1232
+ process to make them augmented as orthogonal basis.
1233
+
1234
+ >>> V_aug = V.row_join(Matrix([[0,0,0,0,1],
1235
+ ... [0,0,0,1,0]]).H)
1236
+ >>> V_aug = V_aug.QRdecomposition()[0]
1237
+ >>> V_aug
1238
+ Matrix([
1239
+ [0, sqrt(5)/5, 0, -2*sqrt(5)/5, 0],
1240
+ [1, 0, 0, 0, 0],
1241
+ [0, 0, 1, 0, 0],
1242
+ [0, 0, 0, 0, 1],
1243
+ [0, 2*sqrt(5)/5, 0, sqrt(5)/5, 0]])
1244
+ >>> V_aug.H * V_aug
1245
+ Matrix([
1246
+ [1, 0, 0, 0, 0],
1247
+ [0, 1, 0, 0, 0],
1248
+ [0, 0, 1, 0, 0],
1249
+ [0, 0, 0, 1, 0],
1250
+ [0, 0, 0, 0, 1]])
1251
+ >>> V_aug * V_aug.H
1252
+ Matrix([
1253
+ [1, 0, 0, 0, 0],
1254
+ [0, 1, 0, 0, 0],
1255
+ [0, 0, 1, 0, 0],
1256
+ [0, 0, 0, 1, 0],
1257
+ [0, 0, 0, 0, 1]])
1258
+
1259
+ Similarly we augment U
1260
+
1261
+ >>> U_aug = U.row_join(Matrix([0,0,1,0]))
1262
+ >>> U_aug = U_aug.QRdecomposition()[0]
1263
+ >>> U_aug
1264
+ Matrix([
1265
+ [0, 1, 0, 0],
1266
+ [0, 0, 1, 0],
1267
+ [0, 0, 0, 1],
1268
+ [1, 0, 0, 0]])
1269
+
1270
+ >>> U_aug.H * U_aug
1271
+ Matrix([
1272
+ [1, 0, 0, 0],
1273
+ [0, 1, 0, 0],
1274
+ [0, 0, 1, 0],
1275
+ [0, 0, 0, 1]])
1276
+ >>> U_aug * U_aug.H
1277
+ Matrix([
1278
+ [1, 0, 0, 0],
1279
+ [0, 1, 0, 0],
1280
+ [0, 0, 1, 0],
1281
+ [0, 0, 0, 1]])
1282
+
1283
+ We add 2 zero columns and one row to S
1284
+
1285
+ >>> S_aug = S.col_join(Matrix([[0,0,0]]))
1286
+ >>> S_aug = S_aug.row_join(Matrix([[0,0,0,0],
1287
+ ... [0,0,0,0]]).H)
1288
+ >>> S_aug
1289
+ Matrix([
1290
+ [2, 0, 0, 0, 0],
1291
+ [0, sqrt(5), 0, 0, 0],
1292
+ [0, 0, 3, 0, 0],
1293
+ [0, 0, 0, 0, 0]])
1294
+
1295
+
1296
+
1297
+ >>> U_aug * S_aug * V_aug.H == C
1298
+ True
1299
+
1300
+ """
1301
+
1302
+ AH = A.H
1303
+ m, n = A.shape
1304
+ if m >= n:
1305
+ V, S = (AH * A).diagonalize()
1306
+
1307
+ ranked = []
1308
+ for i, x in enumerate(S.diagonal()):
1309
+ if not x.is_zero:
1310
+ ranked.append(i)
1311
+
1312
+ V = V[:, ranked]
1313
+
1314
+ Singular_vals = [sqrt(S[i, i]) for i in range(S.rows) if i in ranked]
1315
+
1316
+ S = S.diag(*Singular_vals)
1317
+ V, _ = V.QRdecomposition()
1318
+ U = A * V * S.inv()
1319
+ else:
1320
+ U, S = (A * AH).diagonalize()
1321
+
1322
+ ranked = []
1323
+ for i, x in enumerate(S.diagonal()):
1324
+ if not x.is_zero:
1325
+ ranked.append(i)
1326
+
1327
+ U = U[:, ranked]
1328
+ Singular_vals = [sqrt(S[i, i]) for i in range(S.rows) if i in ranked]
1329
+
1330
+ S = S.diag(*Singular_vals)
1331
+ U, _ = U.QRdecomposition()
1332
+ V = AH * U * S.inv()
1333
+
1334
+ return U, S, V
1335
+
1336
+ def _QRdecomposition_optional(M, normalize=True):
1337
+ def dot(u, v):
1338
+ return u.dot(v, hermitian=True)
1339
+
1340
+ dps = _get_intermediate_simp(expand_mul, expand_mul)
1341
+
1342
+ A = M.as_mutable()
1343
+ ranked = []
1344
+
1345
+ Q = A
1346
+ R = A.zeros(A.cols)
1347
+
1348
+ for j in range(A.cols):
1349
+ for i in range(j):
1350
+ if Q[:, i].is_zero_matrix:
1351
+ continue
1352
+
1353
+ R[i, j] = dot(Q[:, i], Q[:, j]) / dot(Q[:, i], Q[:, i])
1354
+ R[i, j] = dps(R[i, j])
1355
+ Q[:, j] -= Q[:, i] * R[i, j]
1356
+
1357
+ Q[:, j] = dps(Q[:, j])
1358
+ if Q[:, j].is_zero_matrix is not True:
1359
+ ranked.append(j)
1360
+ R[j, j] = M.one
1361
+
1362
+ Q = Q.extract(range(Q.rows), ranked)
1363
+ R = R.extract(ranked, range(R.cols))
1364
+
1365
+ if normalize:
1366
+ # Normalization
1367
+ for i in range(Q.cols):
1368
+ norm = Q[:, i].norm()
1369
+ Q[:, i] /= norm
1370
+ R[i, :] *= norm
1371
+
1372
+ return M.__class__(Q), M.__class__(R)
1373
+
1374
+
1375
+ def _QRdecomposition(M):
1376
+ r"""Returns a QR decomposition.
1377
+
1378
+ Explanation
1379
+ ===========
1380
+
1381
+ A QR decomposition is a decomposition in the form $A = Q R$
1382
+ where
1383
+
1384
+ - $Q$ is a column orthogonal matrix.
1385
+ - $R$ is a upper triangular (trapezoidal) matrix.
1386
+
1387
+ A column orthogonal matrix satisfies
1388
+ $\mathbb{I} = Q^H Q$ while a full orthogonal matrix satisfies
1389
+ relation $\mathbb{I} = Q Q^H = Q^H Q$ where $I$ is an identity
1390
+ matrix with matching dimensions.
1391
+
1392
+ For matrices which are not square or are rank-deficient, it is
1393
+ sufficient to return a column orthogonal matrix because augmenting
1394
+ them may introduce redundant computations.
1395
+ And an another advantage of this is that you can easily inspect the
1396
+ matrix rank by counting the number of columns of $Q$.
1397
+
1398
+ If you want to augment the results to return a full orthogonal
1399
+ decomposition, you should use the following procedures.
1400
+
1401
+ - Augment the $Q$ matrix with columns that are orthogonal to every
1402
+ other columns and make it square.
1403
+ - Augment the $R$ matrix with zero rows to make it have the same
1404
+ shape as the original matrix.
1405
+
1406
+ The procedure will be illustrated in the examples section.
1407
+
1408
+ Examples
1409
+ ========
1410
+
1411
+ A full rank matrix example:
1412
+
1413
+ >>> from sympy import Matrix
1414
+ >>> A = Matrix([[12, -51, 4], [6, 167, -68], [-4, 24, -41]])
1415
+ >>> Q, R = A.QRdecomposition()
1416
+ >>> Q
1417
+ Matrix([
1418
+ [ 6/7, -69/175, -58/175],
1419
+ [ 3/7, 158/175, 6/175],
1420
+ [-2/7, 6/35, -33/35]])
1421
+ >>> R
1422
+ Matrix([
1423
+ [14, 21, -14],
1424
+ [ 0, 175, -70],
1425
+ [ 0, 0, 35]])
1426
+
1427
+ If the matrix is square and full rank, the $Q$ matrix becomes
1428
+ orthogonal in both directions, and needs no augmentation.
1429
+
1430
+ >>> Q * Q.H
1431
+ Matrix([
1432
+ [1, 0, 0],
1433
+ [0, 1, 0],
1434
+ [0, 0, 1]])
1435
+ >>> Q.H * Q
1436
+ Matrix([
1437
+ [1, 0, 0],
1438
+ [0, 1, 0],
1439
+ [0, 0, 1]])
1440
+
1441
+ >>> A == Q*R
1442
+ True
1443
+
1444
+ A rank deficient matrix example:
1445
+
1446
+ >>> A = Matrix([[12, -51, 0], [6, 167, 0], [-4, 24, 0]])
1447
+ >>> Q, R = A.QRdecomposition()
1448
+ >>> Q
1449
+ Matrix([
1450
+ [ 6/7, -69/175],
1451
+ [ 3/7, 158/175],
1452
+ [-2/7, 6/35]])
1453
+ >>> R
1454
+ Matrix([
1455
+ [14, 21, 0],
1456
+ [ 0, 175, 0]])
1457
+
1458
+ QRdecomposition might return a matrix Q that is rectangular.
1459
+ In this case the orthogonality condition might be satisfied as
1460
+ $\mathbb{I} = Q.H*Q$ but not in the reversed product
1461
+ $\mathbb{I} = Q * Q.H$.
1462
+
1463
+ >>> Q.H * Q
1464
+ Matrix([
1465
+ [1, 0],
1466
+ [0, 1]])
1467
+ >>> Q * Q.H
1468
+ Matrix([
1469
+ [27261/30625, 348/30625, -1914/6125],
1470
+ [ 348/30625, 30589/30625, 198/6125],
1471
+ [ -1914/6125, 198/6125, 136/1225]])
1472
+
1473
+ If you want to augment the results to be a full orthogonal
1474
+ decomposition, you should augment $Q$ with an another orthogonal
1475
+ column.
1476
+
1477
+ You are able to append an identity matrix,
1478
+ and you can run the Gram-Schmidt
1479
+ process to make them augmented as orthogonal basis.
1480
+
1481
+ >>> Q_aug = Q.row_join(Matrix.eye(3))
1482
+ >>> Q_aug = Q_aug.QRdecomposition()[0]
1483
+ >>> Q_aug
1484
+ Matrix([
1485
+ [ 6/7, -69/175, 58/175],
1486
+ [ 3/7, 158/175, -6/175],
1487
+ [-2/7, 6/35, 33/35]])
1488
+ >>> Q_aug.H * Q_aug
1489
+ Matrix([
1490
+ [1, 0, 0],
1491
+ [0, 1, 0],
1492
+ [0, 0, 1]])
1493
+ >>> Q_aug * Q_aug.H
1494
+ Matrix([
1495
+ [1, 0, 0],
1496
+ [0, 1, 0],
1497
+ [0, 0, 1]])
1498
+
1499
+ Augmenting the $R$ matrix with zero row is straightforward.
1500
+
1501
+ >>> R_aug = R.col_join(Matrix([[0, 0, 0]]))
1502
+ >>> R_aug
1503
+ Matrix([
1504
+ [14, 21, 0],
1505
+ [ 0, 175, 0],
1506
+ [ 0, 0, 0]])
1507
+ >>> Q_aug * R_aug == A
1508
+ True
1509
+
1510
+ A zero matrix example:
1511
+
1512
+ >>> from sympy import Matrix
1513
+ >>> A = Matrix.zeros(3, 4)
1514
+ >>> Q, R = A.QRdecomposition()
1515
+
1516
+ They may return matrices with zero rows and columns.
1517
+
1518
+ >>> Q
1519
+ Matrix(3, 0, [])
1520
+ >>> R
1521
+ Matrix(0, 4, [])
1522
+ >>> Q*R
1523
+ Matrix([
1524
+ [0, 0, 0, 0],
1525
+ [0, 0, 0, 0],
1526
+ [0, 0, 0, 0]])
1527
+
1528
+ As the same augmentation rule described above, $Q$ can be augmented
1529
+ with columns of an identity matrix and $R$ can be augmented with
1530
+ rows of a zero matrix.
1531
+
1532
+ >>> Q_aug = Q.row_join(Matrix.eye(3))
1533
+ >>> R_aug = R.col_join(Matrix.zeros(3, 4))
1534
+ >>> Q_aug * Q_aug.T
1535
+ Matrix([
1536
+ [1, 0, 0],
1537
+ [0, 1, 0],
1538
+ [0, 0, 1]])
1539
+ >>> R_aug
1540
+ Matrix([
1541
+ [0, 0, 0, 0],
1542
+ [0, 0, 0, 0],
1543
+ [0, 0, 0, 0]])
1544
+ >>> Q_aug * R_aug == A
1545
+ True
1546
+
1547
+ See Also
1548
+ ========
1549
+
1550
+ sympy.matrices.dense.DenseMatrix.cholesky
1551
+ sympy.matrices.dense.DenseMatrix.LDLdecomposition
1552
+ sympy.matrices.matrixbase.MatrixBase.LUdecomposition
1553
+ QRsolve
1554
+ """
1555
+ return _QRdecomposition_optional(M, normalize=True)
1556
+
1557
+ def _upper_hessenberg_decomposition(A):
1558
+ """Converts a matrix into Hessenberg matrix H.
1559
+
1560
+ Returns 2 matrices H, P s.t.
1561
+ $P H P^{T} = A$, where H is an upper hessenberg matrix
1562
+ and P is an orthogonal matrix
1563
+
1564
+ Examples
1565
+ ========
1566
+
1567
+ >>> from sympy import Matrix
1568
+ >>> A = Matrix([
1569
+ ... [1,2,3],
1570
+ ... [-3,5,6],
1571
+ ... [4,-8,9],
1572
+ ... ])
1573
+ >>> H, P = A.upper_hessenberg_decomposition()
1574
+ >>> H
1575
+ Matrix([
1576
+ [1, 6/5, 17/5],
1577
+ [5, 213/25, -134/25],
1578
+ [0, 216/25, 137/25]])
1579
+ >>> P
1580
+ Matrix([
1581
+ [1, 0, 0],
1582
+ [0, -3/5, 4/5],
1583
+ [0, 4/5, 3/5]])
1584
+ >>> P * H * P.H == A
1585
+ True
1586
+
1587
+
1588
+ References
1589
+ ==========
1590
+
1591
+ .. [#] https://mathworld.wolfram.com/HessenbergDecomposition.html
1592
+ """
1593
+
1594
+ M = A.as_mutable()
1595
+
1596
+ if not M.is_square:
1597
+ raise NonSquareMatrixError("Matrix must be square.")
1598
+
1599
+ n = M.cols
1600
+ P = M.eye(n)
1601
+ H = M
1602
+
1603
+ for j in range(n - 2):
1604
+
1605
+ u = H[j + 1:, j]
1606
+
1607
+ if u[1:, :].is_zero_matrix:
1608
+ continue
1609
+
1610
+ if sign(u[0]) != 0:
1611
+ u[0] = u[0] + sign(u[0]) * u.norm()
1612
+ else:
1613
+ u[0] = u[0] + u.norm()
1614
+
1615
+ v = u / u.norm()
1616
+
1617
+ H[j + 1:, :] = H[j + 1:, :] - 2 * v * (v.H * H[j + 1:, :])
1618
+ H[:, j + 1:] = H[:, j + 1:] - (H[:, j + 1:] * (2 * v)) * v.H
1619
+ P[:, j + 1:] = P[:, j + 1:] - (P[:, j + 1:] * (2 * v)) * v.H
1620
+
1621
+ return H, P
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_blockmatrix.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.matrices.expressions.trace import Trace
2
+ from sympy.testing.pytest import raises, slow
3
+ from sympy.matrices.expressions.blockmatrix import (
4
+ block_collapse, bc_matmul, bc_block_plus_ident, BlockDiagMatrix,
5
+ BlockMatrix, bc_dist, bc_matadd, bc_transpose, bc_inverse,
6
+ blockcut, reblock_2x2, deblock)
7
+ from sympy.matrices.expressions import (
8
+ MatrixSymbol, Identity, trace, det, ZeroMatrix, OneMatrix)
9
+ from sympy.matrices.expressions.inverse import Inverse
10
+ from sympy.matrices.expressions.matpow import MatPow
11
+ from sympy.matrices.expressions.transpose import Transpose
12
+ from sympy.matrices.exceptions import NonInvertibleMatrixError
13
+ from sympy.matrices import (
14
+ Matrix, ImmutableMatrix, ImmutableSparseMatrix, zeros)
15
+ from sympy.core import Tuple, Expr, S, Function
16
+ from sympy.core.symbol import Symbol, symbols
17
+ from sympy.functions import transpose, im, re
18
+
19
+ i, j, k, l, m, n, p = symbols('i:n, p', integer=True)
20
+ A = MatrixSymbol('A', n, n)
21
+ B = MatrixSymbol('B', n, n)
22
+ C = MatrixSymbol('C', n, n)
23
+ D = MatrixSymbol('D', n, n)
24
+ G = MatrixSymbol('G', n, n)
25
+ H = MatrixSymbol('H', n, n)
26
+ b1 = BlockMatrix([[G, H]])
27
+ b2 = BlockMatrix([[G], [H]])
28
+
29
+ def test_bc_matmul():
30
+ assert bc_matmul(H*b1*b2*G) == BlockMatrix([[(H*G*G + H*H*H)*G]])
31
+
32
+ def test_bc_matadd():
33
+ assert bc_matadd(BlockMatrix([[G, H]]) + BlockMatrix([[H, H]])) == \
34
+ BlockMatrix([[G+H, H+H]])
35
+
36
+ def test_bc_transpose():
37
+ assert bc_transpose(Transpose(BlockMatrix([[A, B], [C, D]]))) == \
38
+ BlockMatrix([[A.T, C.T], [B.T, D.T]])
39
+
40
+ def test_bc_dist_diag():
41
+ A = MatrixSymbol('A', n, n)
42
+ B = MatrixSymbol('B', m, m)
43
+ C = MatrixSymbol('C', l, l)
44
+ X = BlockDiagMatrix(A, B, C)
45
+
46
+ assert bc_dist(X+X).equals(BlockDiagMatrix(2*A, 2*B, 2*C))
47
+
48
+ def test_block_plus_ident():
49
+ A = MatrixSymbol('A', n, n)
50
+ B = MatrixSymbol('B', n, m)
51
+ C = MatrixSymbol('C', m, n)
52
+ D = MatrixSymbol('D', m, m)
53
+ X = BlockMatrix([[A, B], [C, D]])
54
+ Z = MatrixSymbol('Z', n + m, n + m)
55
+ assert bc_block_plus_ident(X + Identity(m + n) + Z) == \
56
+ BlockDiagMatrix(Identity(n), Identity(m)) + X + Z
57
+
58
+ def test_BlockMatrix():
59
+ A = MatrixSymbol('A', n, m)
60
+ B = MatrixSymbol('B', n, k)
61
+ C = MatrixSymbol('C', l, m)
62
+ D = MatrixSymbol('D', l, k)
63
+ M = MatrixSymbol('M', m + k, p)
64
+ N = MatrixSymbol('N', l + n, k + m)
65
+ X = BlockMatrix(Matrix([[A, B], [C, D]]))
66
+
67
+ assert X.__class__(*X.args) == X
68
+
69
+ # block_collapse does nothing on normal inputs
70
+ E = MatrixSymbol('E', n, m)
71
+ assert block_collapse(A + 2*E) == A + 2*E
72
+ F = MatrixSymbol('F', m, m)
73
+ assert block_collapse(E.T*A*F) == E.T*A*F
74
+
75
+ assert X.shape == (l + n, k + m)
76
+ assert X.blockshape == (2, 2)
77
+ assert transpose(X) == BlockMatrix(Matrix([[A.T, C.T], [B.T, D.T]]))
78
+ assert transpose(X).shape == X.shape[::-1]
79
+
80
+ # Test that BlockMatrices and MatrixSymbols can still mix
81
+ assert (X*M).is_MatMul
82
+ assert X._blockmul(M).is_MatMul
83
+ assert (X*M).shape == (n + l, p)
84
+ assert (X + N).is_MatAdd
85
+ assert X._blockadd(N).is_MatAdd
86
+ assert (X + N).shape == X.shape
87
+
88
+ E = MatrixSymbol('E', m, 1)
89
+ F = MatrixSymbol('F', k, 1)
90
+
91
+ Y = BlockMatrix(Matrix([[E], [F]]))
92
+
93
+ assert (X*Y).shape == (l + n, 1)
94
+ assert block_collapse(X*Y).blocks[0, 0] == A*E + B*F
95
+ assert block_collapse(X*Y).blocks[1, 0] == C*E + D*F
96
+
97
+ # block_collapse passes down into container objects, transposes, and inverse
98
+ assert block_collapse(transpose(X*Y)) == transpose(block_collapse(X*Y))
99
+ assert block_collapse(Tuple(X*Y, 2*X)) == (
100
+ block_collapse(X*Y), block_collapse(2*X))
101
+
102
+ # Make sure that MatrixSymbols will enter 1x1 BlockMatrix if it simplifies
103
+ Ab = BlockMatrix([[A]])
104
+ Z = MatrixSymbol('Z', *A.shape)
105
+ assert block_collapse(Ab + Z) == A + Z
106
+
107
+ def test_block_collapse_explicit_matrices():
108
+ A = Matrix([[1, 2], [3, 4]])
109
+ assert block_collapse(BlockMatrix([[A]])) == A
110
+
111
+ A = ImmutableSparseMatrix([[1, 2], [3, 4]])
112
+ assert block_collapse(BlockMatrix([[A]])) == A
113
+
114
+ def test_issue_17624():
115
+ a = MatrixSymbol("a", 2, 2)
116
+ z = ZeroMatrix(2, 2)
117
+ b = BlockMatrix([[a, z], [z, z]])
118
+ assert block_collapse(b * b) == BlockMatrix([[a**2, z], [z, z]])
119
+ assert block_collapse(b * b * b) == BlockMatrix([[a**3, z], [z, z]])
120
+
121
+ def test_issue_18618():
122
+ A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
123
+ assert A == Matrix(BlockDiagMatrix(A))
124
+
125
+ def test_BlockMatrix_trace():
126
+ A, B, C, D = [MatrixSymbol(s, 3, 3) for s in 'ABCD']
127
+ X = BlockMatrix([[A, B], [C, D]])
128
+ assert trace(X) == trace(A) + trace(D)
129
+ assert trace(BlockMatrix([ZeroMatrix(n, n)])) == 0
130
+
131
+ def test_BlockMatrix_Determinant():
132
+ A, B, C, D = [MatrixSymbol(s, 3, 3) for s in 'ABCD']
133
+ X = BlockMatrix([[A, B], [C, D]])
134
+ from sympy.assumptions.ask import Q
135
+ from sympy.assumptions.assume import assuming
136
+ with assuming(Q.invertible(A)):
137
+ assert det(X) == det(A) * det(X.schur('A'))
138
+
139
+ assert isinstance(det(X), Expr)
140
+ assert det(BlockMatrix([A])) == det(A)
141
+ assert det(BlockMatrix([ZeroMatrix(n, n)])) == 0
142
+
143
+ def test_squareBlockMatrix():
144
+ A = MatrixSymbol('A', n, n)
145
+ B = MatrixSymbol('B', n, m)
146
+ C = MatrixSymbol('C', m, n)
147
+ D = MatrixSymbol('D', m, m)
148
+ X = BlockMatrix([[A, B], [C, D]])
149
+ Y = BlockMatrix([[A]])
150
+
151
+ assert X.is_square
152
+
153
+ Q = X + Identity(m + n)
154
+ assert (block_collapse(Q) ==
155
+ BlockMatrix([[A + Identity(n), B], [C, D + Identity(m)]]))
156
+
157
+ assert (X + MatrixSymbol('Q', n + m, n + m)).is_MatAdd
158
+ assert (X * MatrixSymbol('Q', n + m, n + m)).is_MatMul
159
+
160
+ assert block_collapse(Y.I) == A.I
161
+
162
+ assert isinstance(X.inverse(), Inverse)
163
+
164
+ assert not X.is_Identity
165
+
166
+ Z = BlockMatrix([[Identity(n), B], [C, D]])
167
+ assert not Z.is_Identity
168
+
169
+
170
+ def test_BlockMatrix_2x2_inverse_symbolic():
171
+ A = MatrixSymbol('A', n, m)
172
+ B = MatrixSymbol('B', n, k - m)
173
+ C = MatrixSymbol('C', k - n, m)
174
+ D = MatrixSymbol('D', k - n, k - m)
175
+ X = BlockMatrix([[A, B], [C, D]])
176
+ assert X.is_square and X.shape == (k, k)
177
+ assert isinstance(block_collapse(X.I), Inverse) # Can't invert when none of the blocks is square
178
+
179
+ # test code path where only A is invertible
180
+ A = MatrixSymbol('A', n, n)
181
+ B = MatrixSymbol('B', n, m)
182
+ C = MatrixSymbol('C', m, n)
183
+ D = ZeroMatrix(m, m)
184
+ X = BlockMatrix([[A, B], [C, D]])
185
+ assert block_collapse(X.inverse()) == BlockMatrix([
186
+ [A.I + A.I * B * X.schur('A').I * C * A.I, -A.I * B * X.schur('A').I],
187
+ [-X.schur('A').I * C * A.I, X.schur('A').I],
188
+ ])
189
+
190
+ # test code path where only B is invertible
191
+ A = MatrixSymbol('A', n, m)
192
+ B = MatrixSymbol('B', n, n)
193
+ C = ZeroMatrix(m, m)
194
+ D = MatrixSymbol('D', m, n)
195
+ X = BlockMatrix([[A, B], [C, D]])
196
+ assert block_collapse(X.inverse()) == BlockMatrix([
197
+ [-X.schur('B').I * D * B.I, X.schur('B').I],
198
+ [B.I + B.I * A * X.schur('B').I * D * B.I, -B.I * A * X.schur('B').I],
199
+ ])
200
+
201
+ # test code path where only C is invertible
202
+ A = MatrixSymbol('A', n, m)
203
+ B = ZeroMatrix(n, n)
204
+ C = MatrixSymbol('C', m, m)
205
+ D = MatrixSymbol('D', m, n)
206
+ X = BlockMatrix([[A, B], [C, D]])
207
+ assert block_collapse(X.inverse()) == BlockMatrix([
208
+ [-C.I * D * X.schur('C').I, C.I + C.I * D * X.schur('C').I * A * C.I],
209
+ [X.schur('C').I, -X.schur('C').I * A * C.I],
210
+ ])
211
+
212
+ # test code path where only D is invertible
213
+ A = ZeroMatrix(n, n)
214
+ B = MatrixSymbol('B', n, m)
215
+ C = MatrixSymbol('C', m, n)
216
+ D = MatrixSymbol('D', m, m)
217
+ X = BlockMatrix([[A, B], [C, D]])
218
+ assert block_collapse(X.inverse()) == BlockMatrix([
219
+ [X.schur('D').I, -X.schur('D').I * B * D.I],
220
+ [-D.I * C * X.schur('D').I, D.I + D.I * C * X.schur('D').I * B * D.I],
221
+ ])
222
+
223
+
224
+ def test_BlockMatrix_2x2_inverse_numeric():
225
+ """Test 2x2 block matrix inversion numerically for all 4 formulas"""
226
+ M = Matrix([[1, 2], [3, 4]])
227
+ # rank deficient matrices that have full rank when two of them combined
228
+ D1 = Matrix([[1, 2], [2, 4]])
229
+ D2 = Matrix([[1, 3], [3, 9]])
230
+ D3 = Matrix([[1, 4], [4, 16]])
231
+ assert D1.rank() == D2.rank() == D3.rank() == 1
232
+ assert (D1 + D2).rank() == (D2 + D3).rank() == (D3 + D1).rank() == 2
233
+
234
+ # Only A is invertible
235
+ K = BlockMatrix([[M, D1], [D2, D3]])
236
+ assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv()
237
+ # Only B is invertible
238
+ K = BlockMatrix([[D1, M], [D2, D3]])
239
+ assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv()
240
+ # Only C is invertible
241
+ K = BlockMatrix([[D1, D2], [M, D3]])
242
+ assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv()
243
+ # Only D is invertible
244
+ K = BlockMatrix([[D1, D2], [D3, M]])
245
+ assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv()
246
+
247
+
248
+ @slow
249
+ def test_BlockMatrix_3x3_symbolic():
250
+ # Only test one of these, instead of all permutations, because it's slow
251
+ rowblocksizes = (n, m, k)
252
+ colblocksizes = (m, k, n)
253
+ K = BlockMatrix([
254
+ [MatrixSymbol('M%s%s' % (rows, cols), rows, cols) for cols in colblocksizes]
255
+ for rows in rowblocksizes
256
+ ])
257
+ collapse = block_collapse(K.I)
258
+ assert isinstance(collapse, BlockMatrix)
259
+
260
+
261
+ def test_BlockDiagMatrix():
262
+ A = MatrixSymbol('A', n, n)
263
+ B = MatrixSymbol('B', m, m)
264
+ C = MatrixSymbol('C', l, l)
265
+ M = MatrixSymbol('M', n + m + l, n + m + l)
266
+
267
+ X = BlockDiagMatrix(A, B, C)
268
+ Y = BlockDiagMatrix(A, 2*B, 3*C)
269
+
270
+ assert X.blocks[1, 1] == B
271
+ assert X.shape == (n + m + l, n + m + l)
272
+ assert all(X.blocks[i, j].is_ZeroMatrix if i != j else X.blocks[i, j] in [A, B, C]
273
+ for i in range(3) for j in range(3))
274
+ assert X.__class__(*X.args) == X
275
+ assert X.get_diag_blocks() == (A, B, C)
276
+
277
+ assert isinstance(block_collapse(X.I * X), Identity)
278
+
279
+ assert bc_matmul(X*X) == BlockDiagMatrix(A*A, B*B, C*C)
280
+ assert block_collapse(X*X) == BlockDiagMatrix(A*A, B*B, C*C)
281
+ #XXX: should be == ??
282
+ assert block_collapse(X + X).equals(BlockDiagMatrix(2*A, 2*B, 2*C))
283
+ assert block_collapse(X*Y) == BlockDiagMatrix(A*A, 2*B*B, 3*C*C)
284
+ assert block_collapse(X + Y) == BlockDiagMatrix(2*A, 3*B, 4*C)
285
+
286
+ # Ensure that BlockDiagMatrices can still interact with normal MatrixExprs
287
+ assert (X*(2*M)).is_MatMul
288
+ assert (X + (2*M)).is_MatAdd
289
+
290
+ assert (X._blockmul(M)).is_MatMul
291
+ assert (X._blockadd(M)).is_MatAdd
292
+
293
+ def test_BlockDiagMatrix_nonsquare():
294
+ A = MatrixSymbol('A', n, m)
295
+ B = MatrixSymbol('B', k, l)
296
+ X = BlockDiagMatrix(A, B)
297
+ assert X.shape == (n + k, m + l)
298
+ assert X.shape == (n + k, m + l)
299
+ assert X.rowblocksizes == [n, k]
300
+ assert X.colblocksizes == [m, l]
301
+ C = MatrixSymbol('C', n, m)
302
+ D = MatrixSymbol('D', k, l)
303
+ Y = BlockDiagMatrix(C, D)
304
+ assert block_collapse(X + Y) == BlockDiagMatrix(A + C, B + D)
305
+ assert block_collapse(X * Y.T) == BlockDiagMatrix(A * C.T, B * D.T)
306
+ raises(NonInvertibleMatrixError, lambda: BlockDiagMatrix(A, C.T).inverse())
307
+
308
+ def test_BlockDiagMatrix_determinant():
309
+ A = MatrixSymbol('A', n, n)
310
+ B = MatrixSymbol('B', m, m)
311
+ assert det(BlockDiagMatrix()) == 1
312
+ assert det(BlockDiagMatrix(A)) == det(A)
313
+ assert det(BlockDiagMatrix(A, B)) == det(A) * det(B)
314
+
315
+ # non-square blocks
316
+ C = MatrixSymbol('C', m, n)
317
+ D = MatrixSymbol('D', n, m)
318
+ assert det(BlockDiagMatrix(C, D)) == 0
319
+
320
+ def test_BlockDiagMatrix_trace():
321
+ assert trace(BlockDiagMatrix()) == 0
322
+ assert trace(BlockDiagMatrix(ZeroMatrix(n, n))) == 0
323
+ A = MatrixSymbol('A', n, n)
324
+ assert trace(BlockDiagMatrix(A)) == trace(A)
325
+ B = MatrixSymbol('B', m, m)
326
+ assert trace(BlockDiagMatrix(A, B)) == trace(A) + trace(B)
327
+
328
+ # non-square blocks
329
+ C = MatrixSymbol('C', m, n)
330
+ D = MatrixSymbol('D', n, m)
331
+ assert isinstance(trace(BlockDiagMatrix(C, D)), Trace)
332
+
333
+ def test_BlockDiagMatrix_transpose():
334
+ A = MatrixSymbol('A', n, m)
335
+ B = MatrixSymbol('B', k, l)
336
+ assert transpose(BlockDiagMatrix()) == BlockDiagMatrix()
337
+ assert transpose(BlockDiagMatrix(A)) == BlockDiagMatrix(A.T)
338
+ assert transpose(BlockDiagMatrix(A, B)) == BlockDiagMatrix(A.T, B.T)
339
+
340
+ def test_issue_2460():
341
+ bdm1 = BlockDiagMatrix(Matrix([i]), Matrix([j]))
342
+ bdm2 = BlockDiagMatrix(Matrix([k]), Matrix([l]))
343
+ assert block_collapse(bdm1 + bdm2) == BlockDiagMatrix(Matrix([i + k]), Matrix([j + l]))
344
+
345
+ def test_blockcut():
346
+ A = MatrixSymbol('A', n, m)
347
+ B = blockcut(A, (n/2, n/2), (m/2, m/2))
348
+ assert B == BlockMatrix([[A[:n/2, :m/2], A[:n/2, m/2:]],
349
+ [A[n/2:, :m/2], A[n/2:, m/2:]]])
350
+
351
+ M = ImmutableMatrix(4, 4, range(16))
352
+ B = blockcut(M, (2, 2), (2, 2))
353
+ assert M == ImmutableMatrix(B)
354
+
355
+ B = blockcut(M, (1, 3), (2, 2))
356
+ assert ImmutableMatrix(B.blocks[0, 1]) == ImmutableMatrix([[2, 3]])
357
+
358
+ def test_reblock_2x2():
359
+ B = BlockMatrix([[MatrixSymbol('A_%d%d'%(i,j), 2, 2)
360
+ for j in range(3)]
361
+ for i in range(3)])
362
+ assert B.blocks.shape == (3, 3)
363
+
364
+ BB = reblock_2x2(B)
365
+ assert BB.blocks.shape == (2, 2)
366
+
367
+ assert B.shape == BB.shape
368
+ assert B.as_explicit() == BB.as_explicit()
369
+
370
+ def test_deblock():
371
+ B = BlockMatrix([[MatrixSymbol('A_%d%d'%(i,j), n, n)
372
+ for j in range(4)]
373
+ for i in range(4)])
374
+
375
+ assert deblock(reblock_2x2(B)) == B
376
+
377
+ def test_block_collapse_type():
378
+ bm1 = BlockDiagMatrix(ImmutableMatrix([1]), ImmutableMatrix([2]))
379
+ bm2 = BlockDiagMatrix(ImmutableMatrix([3]), ImmutableMatrix([4]))
380
+
381
+ assert bm1.T.__class__ == BlockDiagMatrix
382
+ assert block_collapse(bm1 - bm2).__class__ == BlockDiagMatrix
383
+ assert block_collapse(Inverse(bm1)).__class__ == BlockDiagMatrix
384
+ assert block_collapse(Transpose(bm1)).__class__ == BlockDiagMatrix
385
+ assert bc_transpose(Transpose(bm1)).__class__ == BlockDiagMatrix
386
+ assert bc_inverse(Inverse(bm1)).__class__ == BlockDiagMatrix
387
+
388
+ def test_invalid_block_matrix():
389
+ raises(ValueError, lambda: BlockMatrix([
390
+ [Identity(2), Identity(5)],
391
+ ]))
392
+ raises(ValueError, lambda: BlockMatrix([
393
+ [Identity(n), Identity(m)],
394
+ ]))
395
+ raises(ValueError, lambda: BlockMatrix([
396
+ [ZeroMatrix(n, n), ZeroMatrix(n, n)],
397
+ [ZeroMatrix(n, n - 1), ZeroMatrix(n, n + 1)],
398
+ ]))
399
+ raises(ValueError, lambda: BlockMatrix([
400
+ [ZeroMatrix(n - 1, n), ZeroMatrix(n, n)],
401
+ [ZeroMatrix(n + 1, n), ZeroMatrix(n, n)],
402
+ ]))
403
+
404
+ def test_block_lu_decomposition():
405
+ A = MatrixSymbol('A', n, n)
406
+ B = MatrixSymbol('B', n, m)
407
+ C = MatrixSymbol('C', m, n)
408
+ D = MatrixSymbol('D', m, m)
409
+ X = BlockMatrix([[A, B], [C, D]])
410
+
411
+ #LDU decomposition
412
+ L, D, U = X.LDUdecomposition()
413
+ assert block_collapse(L*D*U) == X
414
+
415
+ #UDL decomposition
416
+ U, D, L = X.UDLdecomposition()
417
+ assert block_collapse(U*D*L) == X
418
+
419
+ #LU decomposition
420
+ L, U = X.LUdecomposition()
421
+ assert block_collapse(L*U) == X
422
+
423
+ def test_issue_21866():
424
+ n = 10
425
+ I = Identity(n)
426
+ O = ZeroMatrix(n, n)
427
+ A = BlockMatrix([[ I, O, O, O ],
428
+ [ O, I, O, O ],
429
+ [ O, O, I, O ],
430
+ [ I, O, O, I ]])
431
+ Ainv = block_collapse(A.inv())
432
+ AinvT = BlockMatrix([[ I, O, O, O ],
433
+ [ O, I, O, O ],
434
+ [ O, O, I, O ],
435
+ [ -I, O, O, I ]])
436
+ assert Ainv == AinvT
437
+
438
+
439
+ def test_adjoint_and_special_matrices():
440
+ A = Identity(3)
441
+ B = OneMatrix(3, 2)
442
+ C = ZeroMatrix(2, 3)
443
+ D = Identity(2)
444
+ X = BlockMatrix([[A, B], [C, D]])
445
+ X2 = BlockMatrix([[A, S.ImaginaryUnit*B], [C, D]])
446
+ assert X.adjoint() == BlockMatrix([[A, ZeroMatrix(3, 2)], [OneMatrix(2, 3), D]])
447
+ assert re(X) == X
448
+ assert X2.adjoint() == BlockMatrix([[A, ZeroMatrix(3, 2)], [-S.ImaginaryUnit*OneMatrix(2, 3), D]])
449
+ assert im(X2) == BlockMatrix([[ZeroMatrix(3, 3), OneMatrix(3, 2)], [ZeroMatrix(2, 3), ZeroMatrix(2, 2)]])
450
+
451
+
452
+ def test_block_matrix_derivative():
453
+ x = symbols('x')
454
+ A = Matrix(3, 3, [Function(f'a{i}')(x) for i in range(9)])
455
+ bc = BlockMatrix([[A[:2, :2], A[:2, 2]], [A[2, :2], A[2:, 2]]])
456
+ assert Matrix(bc.diff(x)) - A.diff(x) == zeros(3, 3)
457
+
458
+
459
+ def test_transpose_inverse_commute():
460
+ n = Symbol('n')
461
+ I = Identity(n)
462
+ Z = ZeroMatrix(n, n)
463
+ A = BlockMatrix([[I, Z], [Z, I]])
464
+
465
+ assert block_collapse(A.transpose().inverse()) == A
466
+ assert block_collapse(A.inverse().transpose()) == A
467
+
468
+ assert block_collapse(MatPow(A.transpose(), -2)) == MatPow(A, -2)
469
+ assert block_collapse(MatPow(A, -2).transpose()) == MatPow(A, -2)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_diagonal.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.matrices.expressions import MatrixSymbol
2
+ from sympy.matrices.expressions.diagonal import DiagonalMatrix, DiagonalOf, DiagMatrix, diagonalize_vector
3
+ from sympy.assumptions.ask import (Q, ask)
4
+ from sympy.core.symbol import Symbol
5
+ from sympy.functions.special.tensor_functions import KroneckerDelta
6
+ from sympy.matrices.dense import Matrix
7
+ from sympy.matrices.expressions.matmul import MatMul
8
+ from sympy.matrices.expressions.special import Identity
9
+ from sympy.testing.pytest import raises
10
+
11
+
12
+ n = Symbol('n')
13
+ m = Symbol('m')
14
+
15
+
16
+ def test_DiagonalMatrix():
17
+ x = MatrixSymbol('x', n, m)
18
+ D = DiagonalMatrix(x)
19
+ assert D.diagonal_length is None
20
+ assert D.shape == (n, m)
21
+
22
+ x = MatrixSymbol('x', n, n)
23
+ D = DiagonalMatrix(x)
24
+ assert D.diagonal_length == n
25
+ assert D.shape == (n, n)
26
+ assert D[1, 2] == 0
27
+ assert D[1, 1] == x[1, 1]
28
+ i = Symbol('i')
29
+ j = Symbol('j')
30
+ x = MatrixSymbol('x', 3, 3)
31
+ ij = DiagonalMatrix(x)[i, j]
32
+ assert ij != 0
33
+ assert ij.subs({i:0, j:0}) == x[0, 0]
34
+ assert ij.subs({i:0, j:1}) == 0
35
+ assert ij.subs({i:1, j:1}) == x[1, 1]
36
+ assert ask(Q.diagonal(D)) # affirm that D is diagonal
37
+
38
+ x = MatrixSymbol('x', n, 3)
39
+ D = DiagonalMatrix(x)
40
+ assert D.diagonal_length == 3
41
+ assert D.shape == (n, 3)
42
+ assert D[2, m] == KroneckerDelta(2, m)*x[2, m]
43
+ assert D[3, m] == 0
44
+ raises(IndexError, lambda: D[m, 3])
45
+
46
+ x = MatrixSymbol('x', 3, n)
47
+ D = DiagonalMatrix(x)
48
+ assert D.diagonal_length == 3
49
+ assert D.shape == (3, n)
50
+ assert D[m, 2] == KroneckerDelta(m, 2)*x[m, 2]
51
+ assert D[m, 3] == 0
52
+ raises(IndexError, lambda: D[3, m])
53
+
54
+ x = MatrixSymbol('x', n, m)
55
+ D = DiagonalMatrix(x)
56
+ assert D.diagonal_length is None
57
+ assert D.shape == (n, m)
58
+ assert D[m, 4] != 0
59
+
60
+ x = MatrixSymbol('x', 3, 4)
61
+ assert [DiagonalMatrix(x)[i] for i in range(12)] == [
62
+ x[0, 0], 0, 0, 0, 0, x[1, 1], 0, 0, 0, 0, x[2, 2], 0]
63
+
64
+ # shape is retained, issue 12427
65
+ assert (
66
+ DiagonalMatrix(MatrixSymbol('x', 3, 4))*
67
+ DiagonalMatrix(MatrixSymbol('x', 4, 2))).shape == (3, 2)
68
+
69
+
70
+ def test_DiagonalOf():
71
+ x = MatrixSymbol('x', n, n)
72
+ d = DiagonalOf(x)
73
+ assert d.shape == (n, 1)
74
+ assert d.diagonal_length == n
75
+ assert d[2, 0] == d[2] == x[2, 2]
76
+
77
+ x = MatrixSymbol('x', n, m)
78
+ d = DiagonalOf(x)
79
+ assert d.shape == (None, 1)
80
+ assert d.diagonal_length is None
81
+ assert d[2, 0] == d[2] == x[2, 2]
82
+
83
+ d = DiagonalOf(MatrixSymbol('x', 4, 3))
84
+ assert d.shape == (3, 1)
85
+ d = DiagonalOf(MatrixSymbol('x', n, 3))
86
+ assert d.shape == (3, 1)
87
+ d = DiagonalOf(MatrixSymbol('x', 3, n))
88
+ assert d.shape == (3, 1)
89
+ x = MatrixSymbol('x', n, m)
90
+ assert [DiagonalOf(x)[i] for i in range(4)] ==[
91
+ x[0, 0], x[1, 1], x[2, 2], x[3, 3]]
92
+
93
+
94
+ def test_DiagMatrix():
95
+ x = MatrixSymbol('x', n, 1)
96
+ d = DiagMatrix(x)
97
+ assert d.shape == (n, n)
98
+ assert d[0, 1] == 0
99
+ assert d[0, 0] == x[0, 0]
100
+
101
+ a = MatrixSymbol('a', 1, 1)
102
+ d = diagonalize_vector(a)
103
+ assert isinstance(d, MatrixSymbol)
104
+ assert a == d
105
+ assert diagonalize_vector(Identity(3)) == Identity(3)
106
+ assert DiagMatrix(Identity(3)).doit() == Identity(3)
107
+ assert isinstance(DiagMatrix(Identity(3)), DiagMatrix)
108
+
109
+ # A diagonal matrix is equal to its transpose:
110
+ assert DiagMatrix(x).T == DiagMatrix(x)
111
+ assert diagonalize_vector(x.T) == DiagMatrix(x)
112
+
113
+ dx = DiagMatrix(x)
114
+ assert dx[0, 0] == x[0, 0]
115
+ assert dx[1, 1] == x[1, 0]
116
+ assert dx[0, 1] == 0
117
+ assert dx[0, m] == x[0, 0]*KroneckerDelta(0, m)
118
+
119
+ z = MatrixSymbol('z', 1, n)
120
+ dz = DiagMatrix(z)
121
+ assert dz[0, 0] == z[0, 0]
122
+ assert dz[1, 1] == z[0, 1]
123
+ assert dz[0, 1] == 0
124
+ assert dz[0, m] == z[0, m]*KroneckerDelta(0, m)
125
+
126
+ v = MatrixSymbol('v', 3, 1)
127
+ dv = DiagMatrix(v)
128
+ assert dv.as_explicit() == Matrix([
129
+ [v[0, 0], 0, 0],
130
+ [0, v[1, 0], 0],
131
+ [0, 0, v[2, 0]],
132
+ ])
133
+
134
+ v = MatrixSymbol('v', 1, 3)
135
+ dv = DiagMatrix(v)
136
+ assert dv.as_explicit() == Matrix([
137
+ [v[0, 0], 0, 0],
138
+ [0, v[0, 1], 0],
139
+ [0, 0, v[0, 2]],
140
+ ])
141
+
142
+ dv = DiagMatrix(3*v)
143
+ assert dv.args == (3*v,)
144
+ assert dv.doit() == 3*DiagMatrix(v)
145
+ assert isinstance(dv.doit(), MatMul)
146
+
147
+ a = MatrixSymbol("a", 3, 1).as_explicit()
148
+ expr = DiagMatrix(a)
149
+ result = Matrix([
150
+ [a[0, 0], 0, 0],
151
+ [0, a[1, 0], 0],
152
+ [0, 0, a[2, 0]],
153
+ ])
154
+ assert expr.doit() == result
155
+ expr = DiagMatrix(a.T)
156
+ assert expr.doit() == result
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_funcmatrix.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core import symbols, Lambda
2
+ from sympy.core.sympify import SympifyError
3
+ from sympy.functions import KroneckerDelta
4
+ from sympy.matrices import Matrix
5
+ from sympy.matrices.expressions import FunctionMatrix, MatrixExpr, Identity
6
+ from sympy.testing.pytest import raises
7
+
8
+
9
+ def test_funcmatrix_creation():
10
+ i, j, k = symbols('i j k')
11
+ assert FunctionMatrix(2, 2, Lambda((i, j), 0))
12
+ assert FunctionMatrix(0, 0, Lambda((i, j), 0))
13
+
14
+ raises(ValueError, lambda: FunctionMatrix(-1, 0, Lambda((i, j), 0)))
15
+ raises(ValueError, lambda: FunctionMatrix(2.0, 0, Lambda((i, j), 0)))
16
+ raises(ValueError, lambda: FunctionMatrix(2j, 0, Lambda((i, j), 0)))
17
+ raises(ValueError, lambda: FunctionMatrix(0, -1, Lambda((i, j), 0)))
18
+ raises(ValueError, lambda: FunctionMatrix(0, 2.0, Lambda((i, j), 0)))
19
+ raises(ValueError, lambda: FunctionMatrix(0, 2j, Lambda((i, j), 0)))
20
+
21
+ raises(ValueError, lambda: FunctionMatrix(2, 2, Lambda(i, 0)))
22
+ raises(SympifyError, lambda: FunctionMatrix(2, 2, lambda i, j: 0))
23
+ raises(ValueError, lambda: FunctionMatrix(2, 2, Lambda((i,), 0)))
24
+ raises(ValueError, lambda: FunctionMatrix(2, 2, Lambda((i, j, k), 0)))
25
+ raises(ValueError, lambda: FunctionMatrix(2, 2, i+j))
26
+ assert FunctionMatrix(2, 2, "lambda i, j: 0") == \
27
+ FunctionMatrix(2, 2, Lambda((i, j), 0))
28
+
29
+ m = FunctionMatrix(2, 2, KroneckerDelta)
30
+ assert m.as_explicit() == Identity(2).as_explicit()
31
+ assert m.args[2].dummy_eq(Lambda((i, j), KroneckerDelta(i, j)))
32
+
33
+ n = symbols('n')
34
+ assert FunctionMatrix(n, n, Lambda((i, j), 0))
35
+ n = symbols('n', integer=False)
36
+ raises(ValueError, lambda: FunctionMatrix(n, n, Lambda((i, j), 0)))
37
+ n = symbols('n', negative=True)
38
+ raises(ValueError, lambda: FunctionMatrix(n, n, Lambda((i, j), 0)))
39
+
40
+
41
+ def test_funcmatrix():
42
+ i, j = symbols('i,j')
43
+ X = FunctionMatrix(3, 3, Lambda((i, j), i - j))
44
+ assert X[1, 1] == 0
45
+ assert X[1, 2] == -1
46
+ assert X.shape == (3, 3)
47
+ assert X.rows == X.cols == 3
48
+ assert Matrix(X) == Matrix(3, 3, lambda i, j: i - j)
49
+ assert isinstance(X*X + X, MatrixExpr)
50
+
51
+
52
+ def test_replace_issue():
53
+ X = FunctionMatrix(3, 3, KroneckerDelta)
54
+ assert X.replace(lambda x: True, lambda x: x) == X
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_indexing.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.concrete.summations import Sum
2
+ from sympy.core.symbol import symbols, Symbol, Dummy
3
+ from sympy.functions.elementary.miscellaneous import sqrt
4
+ from sympy.functions.special.tensor_functions import KroneckerDelta
5
+ from sympy.matrices.dense import eye
6
+ from sympy.matrices.expressions.blockmatrix import BlockMatrix
7
+ from sympy.matrices.expressions.hadamard import HadamardPower
8
+ from sympy.matrices.expressions.matexpr import (MatrixSymbol,
9
+ MatrixExpr, MatrixElement)
10
+ from sympy.matrices.expressions.matpow import MatPow
11
+ from sympy.matrices.expressions.special import (ZeroMatrix, Identity,
12
+ OneMatrix)
13
+ from sympy.matrices.expressions.trace import Trace, trace
14
+ from sympy.matrices.immutable import ImmutableMatrix
15
+ from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
16
+ from sympy.testing.pytest import XFAIL, raises
17
+
18
+ k, l, m, n = symbols('k l m n', integer=True)
19
+ i, j = symbols('i j', integer=True)
20
+
21
+ W = MatrixSymbol('W', k, l)
22
+ X = MatrixSymbol('X', l, m)
23
+ Y = MatrixSymbol('Y', l, m)
24
+ Z = MatrixSymbol('Z', m, n)
25
+
26
+ X1 = MatrixSymbol('X1', m, m)
27
+ X2 = MatrixSymbol('X2', m, m)
28
+ X3 = MatrixSymbol('X3', m, m)
29
+ X4 = MatrixSymbol('X4', m, m)
30
+
31
+ A = MatrixSymbol('A', 2, 2)
32
+ B = MatrixSymbol('B', 2, 2)
33
+ x = MatrixSymbol('x', 1, 2)
34
+ y = MatrixSymbol('x', 2, 1)
35
+
36
+
37
+ def test_symbolic_indexing():
38
+ x12 = X[1, 2]
39
+ assert all(s in str(x12) for s in ['1', '2', X.name])
40
+ # We don't care about the exact form of this. We do want to make sure
41
+ # that all of these features are present
42
+
43
+
44
+ def test_add_index():
45
+ assert (X + Y)[i, j] == X[i, j] + Y[i, j]
46
+
47
+
48
+ def test_mul_index():
49
+ assert (A*y)[0, 0] == A[0, 0]*y[0, 0] + A[0, 1]*y[1, 0]
50
+ assert (A*B).as_mutable() == (A.as_mutable() * B.as_mutable())
51
+ X = MatrixSymbol('X', n, m)
52
+ Y = MatrixSymbol('Y', m, k)
53
+
54
+ result = (X*Y)[4,2]
55
+ expected = Sum(X[4, i]*Y[i, 2], (i, 0, m - 1))
56
+ assert result.args[0].dummy_eq(expected.args[0], i)
57
+ assert result.args[1][1:] == expected.args[1][1:]
58
+
59
+
60
+ def test_pow_index():
61
+ Q = MatPow(A, 2)
62
+ assert Q[0, 0] == A[0, 0]**2 + A[0, 1]*A[1, 0]
63
+ n = symbols("n")
64
+ Q2 = A**n
65
+ assert Q2[0, 0] == 2*(
66
+ -sqrt((A[0, 0] + A[1, 1])**2 - 4*A[0, 0]*A[1, 1] +
67
+ 4*A[0, 1]*A[1, 0])/2 + A[0, 0]/2 + A[1, 1]/2
68
+ )**n * \
69
+ A[0, 1]*A[1, 0]/(
70
+ (sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] +
71
+ A[1, 1]**2) + A[0, 0] - A[1, 1])*
72
+ sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] + A[1, 1]**2)
73
+ ) - 2*(
74
+ sqrt((A[0, 0] + A[1, 1])**2 - 4*A[0, 0]*A[1, 1] +
75
+ 4*A[0, 1]*A[1, 0])/2 + A[0, 0]/2 + A[1, 1]/2
76
+ )**n * A[0, 1]*A[1, 0]/(
77
+ (-sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] +
78
+ A[1, 1]**2) + A[0, 0] - A[1, 1])*
79
+ sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] + A[1, 1]**2)
80
+ )
81
+
82
+
83
+ def test_transpose_index():
84
+ assert X.T[i, j] == X[j, i]
85
+
86
+
87
+ def test_Identity_index():
88
+ I = Identity(3)
89
+ assert I[0, 0] == I[1, 1] == I[2, 2] == 1
90
+ assert I[1, 0] == I[0, 1] == I[2, 1] == 0
91
+ assert I[i, 0].delta_range == (0, 2)
92
+ raises(IndexError, lambda: I[3, 3])
93
+
94
+
95
+ def test_block_index():
96
+ I = Identity(3)
97
+ Z = ZeroMatrix(3, 3)
98
+ B = BlockMatrix([[I, I], [I, I]])
99
+ e3 = ImmutableMatrix(eye(3))
100
+ BB = BlockMatrix([[e3, e3], [e3, e3]])
101
+ assert B[0, 0] == B[3, 0] == B[0, 3] == B[3, 3] == 1
102
+ assert B[4, 3] == B[5, 1] == 0
103
+
104
+ BB = BlockMatrix([[e3, e3], [e3, e3]])
105
+ assert B.as_explicit() == BB.as_explicit()
106
+
107
+ BI = BlockMatrix([[I, Z], [Z, I]])
108
+
109
+ assert BI.as_explicit().equals(eye(6))
110
+
111
+
112
+ def test_block_index_symbolic():
113
+ # Note that these matrices may be zero-sized and indices may be negative, which causes
114
+ # all naive simplifications given in the comments to be invalid
115
+ A1 = MatrixSymbol('A1', n, k)
116
+ A2 = MatrixSymbol('A2', n, l)
117
+ A3 = MatrixSymbol('A3', m, k)
118
+ A4 = MatrixSymbol('A4', m, l)
119
+ A = BlockMatrix([[A1, A2], [A3, A4]])
120
+ assert A[0, 0] == MatrixElement(A, 0, 0) # Cannot be A1[0, 0]
121
+ assert A[n - 1, k - 1] == A1[n - 1, k - 1]
122
+ assert A[n, k] == A4[0, 0]
123
+ assert A[n + m - 1, 0] == MatrixElement(A, n + m - 1, 0) # Cannot be A3[m - 1, 0]
124
+ assert A[0, k + l - 1] == MatrixElement(A, 0, k + l - 1) # Cannot be A2[0, l - 1]
125
+ assert A[n + m - 1, k + l - 1] == MatrixElement(A, n + m - 1, k + l - 1) # Cannot be A4[m - 1, l - 1]
126
+ assert A[i, j] == MatrixElement(A, i, j)
127
+ assert A[n + i, k + j] == MatrixElement(A, n + i, k + j) # Cannot be A4[i, j]
128
+ assert A[n - i - 1, k - j - 1] == MatrixElement(A, n - i - 1, k - j - 1) # Cannot be A1[n - i - 1, k - j - 1]
129
+
130
+
131
+ def test_block_index_symbolic_nonzero():
132
+ # All invalid simplifications from test_block_index_symbolic() that become valid if all
133
+ # matrices have nonzero size and all indices are nonnegative
134
+ k, l, m, n = symbols('k l m n', integer=True, positive=True)
135
+ i, j = symbols('i j', integer=True, nonnegative=True)
136
+ A1 = MatrixSymbol('A1', n, k)
137
+ A2 = MatrixSymbol('A2', n, l)
138
+ A3 = MatrixSymbol('A3', m, k)
139
+ A4 = MatrixSymbol('A4', m, l)
140
+ A = BlockMatrix([[A1, A2], [A3, A4]])
141
+ assert A[0, 0] == A1[0, 0]
142
+ assert A[n + m - 1, 0] == A3[m - 1, 0]
143
+ assert A[0, k + l - 1] == A2[0, l - 1]
144
+ assert A[n + m - 1, k + l - 1] == A4[m - 1, l - 1]
145
+ assert A[i, j] == MatrixElement(A, i, j)
146
+ assert A[n + i, k + j] == A4[i, j]
147
+ assert A[n - i - 1, k - j - 1] == A1[n - i - 1, k - j - 1]
148
+ assert A[2 * n, 2 * k] == A4[n, k]
149
+
150
+
151
+ def test_block_index_large():
152
+ n, m, k = symbols('n m k', integer=True, positive=True)
153
+ i = symbols('i', integer=True, nonnegative=True)
154
+ A1 = MatrixSymbol('A1', n, n)
155
+ A2 = MatrixSymbol('A2', n, m)
156
+ A3 = MatrixSymbol('A3', n, k)
157
+ A4 = MatrixSymbol('A4', m, n)
158
+ A5 = MatrixSymbol('A5', m, m)
159
+ A6 = MatrixSymbol('A6', m, k)
160
+ A7 = MatrixSymbol('A7', k, n)
161
+ A8 = MatrixSymbol('A8', k, m)
162
+ A9 = MatrixSymbol('A9', k, k)
163
+ A = BlockMatrix([[A1, A2, A3], [A4, A5, A6], [A7, A8, A9]])
164
+ assert A[n + i, n + i] == MatrixElement(A, n + i, n + i)
165
+
166
+
167
+ @XFAIL
168
+ def test_block_index_symbolic_fail():
169
+ # To make this work, symbolic matrix dimensions would need to be somehow assumed nonnegative
170
+ # even if the symbols aren't specified as such. Then 2 * n < n would correctly evaluate to
171
+ # False in BlockMatrix._entry()
172
+ A1 = MatrixSymbol('A1', n, 1)
173
+ A2 = MatrixSymbol('A2', m, 1)
174
+ A = BlockMatrix([[A1], [A2]])
175
+ assert A[2 * n, 0] == A2[n, 0]
176
+
177
+
178
+ def test_slicing():
179
+ A.as_explicit()[0, :] # does not raise an error
180
+
181
+
182
+ def test_errors():
183
+ raises(IndexError, lambda: Identity(2)[1, 2, 3, 4, 5])
184
+ raises(IndexError, lambda: Identity(2)[[1, 2, 3, 4, 5]])
185
+
186
+
187
+ def test_matrix_expression_to_indices():
188
+ i, j = symbols("i, j")
189
+ i1, i2, i3 = symbols("i_1:4")
190
+
191
+ def replace_dummies(expr):
192
+ repl = {i: Symbol(i.name) for i in expr.atoms(Dummy)}
193
+ return expr.xreplace(repl)
194
+
195
+ expr = W*X*Z
196
+ assert replace_dummies(expr._entry(i, j)) == \
197
+ Sum(W[i, i1]*X[i1, i2]*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1))
198
+ assert MatrixExpr.from_index_summation(expr._entry(i, j)) == expr
199
+
200
+ expr = Z.T*X.T*W.T
201
+ assert replace_dummies(expr._entry(i, j)) == \
202
+ Sum(W[j, i2]*X[i2, i1]*Z[i1, i], (i1, 0, m-1), (i2, 0, l-1))
203
+ assert MatrixExpr.from_index_summation(expr._entry(i, j), i) == expr
204
+
205
+ expr = W*X*Z + W*Y*Z
206
+ assert replace_dummies(expr._entry(i, j)) == \
207
+ Sum(W[i, i1]*X[i1, i2]*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1)) +\
208
+ Sum(W[i, i1]*Y[i1, i2]*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1))
209
+ assert MatrixExpr.from_index_summation(expr._entry(i, j)) == expr
210
+
211
+ expr = 2*W*X*Z + 3*W*Y*Z
212
+ assert replace_dummies(expr._entry(i, j)) == \
213
+ 2*Sum(W[i, i1]*X[i1, i2]*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1)) +\
214
+ 3*Sum(W[i, i1]*Y[i1, i2]*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1))
215
+ assert MatrixExpr.from_index_summation(expr._entry(i, j)) == expr
216
+
217
+ expr = W*(X + Y)*Z
218
+ assert replace_dummies(expr._entry(i, j)) == \
219
+ Sum(W[i, i1]*(X[i1, i2] + Y[i1, i2])*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1))
220
+ assert MatrixExpr.from_index_summation(expr._entry(i, j)) == expr
221
+
222
+ expr = A*B**2*A
223
+ #assert replace_dummies(expr._entry(i, j)) == \
224
+ # Sum(A[i, i1]*B[i1, i2]*B[i2, i3]*A[i3, j], (i1, 0, 1), (i2, 0, 1), (i3, 0, 1))
225
+
226
+ # Check that different dummies are used in sub-multiplications:
227
+ expr = (X1*X2 + X2*X1)*X3
228
+ assert replace_dummies(expr._entry(i, j)) == \
229
+ Sum((Sum(X1[i, i2] * X2[i2, i1], (i2, 0, m - 1)) + Sum(X1[i3, i1] * X2[i, i3], (i3, 0, m - 1))) * X3[
230
+ i1, j], (i1, 0, m - 1))
231
+
232
+
233
+ def test_matrix_expression_from_index_summation():
234
+ from sympy.abc import a,b,c,d
235
+ A = MatrixSymbol("A", k, k)
236
+ B = MatrixSymbol("B", k, k)
237
+ C = MatrixSymbol("C", k, k)
238
+ w1 = MatrixSymbol("w1", k, 1)
239
+
240
+ i0, i1, i2, i3, i4 = symbols("i0:5", cls=Dummy)
241
+
242
+ expr = Sum(W[a,b]*X[b,c]*Z[c,d], (b, 0, l-1), (c, 0, m-1))
243
+ assert MatrixExpr.from_index_summation(expr, a) == W*X*Z
244
+ expr = Sum(W.T[b,a]*X[b,c]*Z[c,d], (b, 0, l-1), (c, 0, m-1))
245
+ assert MatrixExpr.from_index_summation(expr, a) == W*X*Z
246
+ expr = Sum(A[b, a]*B[b, c]*C[c, d], (b, 0, k-1), (c, 0, k-1))
247
+ assert MatrixSymbol.from_index_summation(expr, a) == A.T*B*C
248
+ expr = Sum(A[b, a]*B[c, b]*C[c, d], (b, 0, k-1), (c, 0, k-1))
249
+ assert MatrixSymbol.from_index_summation(expr, a) == A.T*B.T*C
250
+ expr = Sum(C[c, d]*A[b, a]*B[c, b], (b, 0, k-1), (c, 0, k-1))
251
+ assert MatrixSymbol.from_index_summation(expr, a) == A.T*B.T*C
252
+ expr = Sum(A[a, b] + B[a, b], (a, 0, k-1), (b, 0, k-1))
253
+ assert MatrixExpr.from_index_summation(expr, a) == OneMatrix(1, k)*A*OneMatrix(k, 1) + OneMatrix(1, k)*B*OneMatrix(k, 1)
254
+ expr = Sum(A[a, b]**2, (a, 0, k - 1), (b, 0, k - 1))
255
+ assert MatrixExpr.from_index_summation(expr, a) == Trace(A * A.T)
256
+ expr = Sum(A[a, b]**3, (a, 0, k - 1), (b, 0, k - 1))
257
+ assert MatrixExpr.from_index_summation(expr, a) == Trace(HadamardPower(A.T, 2) * A)
258
+ expr = Sum((A[a, b] + B[a, b])*C[b, c], (b, 0, k-1))
259
+ assert MatrixExpr.from_index_summation(expr, a) == (A+B)*C
260
+ expr = Sum((A[a, b] + B[b, a])*C[b, c], (b, 0, k-1))
261
+ assert MatrixExpr.from_index_summation(expr, a) == (A+B.T)*C
262
+ expr = Sum(A[a, b]*A[b, c]*A[c, d], (b, 0, k-1), (c, 0, k-1))
263
+ assert MatrixExpr.from_index_summation(expr, a) == A**3
264
+ expr = Sum(A[a, b]*A[b, c]*B[c, d], (b, 0, k-1), (c, 0, k-1))
265
+ assert MatrixExpr.from_index_summation(expr, a) == A**2*B
266
+
267
+ # Parse the trace of a matrix:
268
+
269
+ expr = Sum(A[a, a], (a, 0, k-1))
270
+ assert MatrixExpr.from_index_summation(expr, None) == trace(A)
271
+ expr = Sum(A[a, a]*B[b, c]*C[c, d], (a, 0, k-1), (c, 0, k-1))
272
+ assert MatrixExpr.from_index_summation(expr, b) == trace(A)*B*C
273
+
274
+ # Check wrong sum ranges (should raise an exception):
275
+
276
+ ## Case 1: 0 to m instead of 0 to m-1
277
+ expr = Sum(W[a,b]*X[b,c]*Z[c,d], (b, 0, l-1), (c, 0, m))
278
+ raises(ValueError, lambda: MatrixExpr.from_index_summation(expr, a))
279
+ ## Case 2: 1 to m-1 instead of 0 to m-1
280
+ expr = Sum(W[a,b]*X[b,c]*Z[c,d], (b, 0, l-1), (c, 1, m-1))
281
+ raises(ValueError, lambda: MatrixExpr.from_index_summation(expr, a))
282
+
283
+ # Parse nested sums:
284
+ expr = Sum(A[a, b]*Sum(B[b, c]*C[c, d], (c, 0, k-1)), (b, 0, k-1))
285
+ assert MatrixExpr.from_index_summation(expr, a) == A*B*C
286
+
287
+ # Test Kronecker delta:
288
+ expr = Sum(A[a, b]*KroneckerDelta(b, c)*B[c, d], (b, 0, k-1), (c, 0, k-1))
289
+ assert MatrixExpr.from_index_summation(expr, a) == A*B
290
+
291
+ expr = Sum(KroneckerDelta(i1, m)*KroneckerDelta(i2, n)*A[i, i1]*A[j, i2], (i1, 0, k-1), (i2, 0, k-1))
292
+ assert MatrixExpr.from_index_summation(expr, m) == ArrayTensorProduct(A.T, A)
293
+
294
+ # Test numbered indices:
295
+ expr = Sum(A[i1, i2]*w1[i2, 0], (i2, 0, k-1))
296
+ assert MatrixExpr.from_index_summation(expr, i1) == MatrixElement(A*w1, i1, 0)
297
+
298
+ expr = Sum(A[i1, i2]*B[i2, 0], (i2, 0, k-1))
299
+ assert MatrixExpr.from_index_summation(expr, i1) == MatrixElement(A*B, i1, 0)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_matadd.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.matrices.expressions import MatrixSymbol, MatAdd, MatPow, MatMul
2
+ from sympy.matrices.expressions.special import GenericZeroMatrix, ZeroMatrix
3
+ from sympy.matrices.exceptions import ShapeError
4
+ from sympy.matrices import eye, ImmutableMatrix
5
+ from sympy.core import Add, Basic, S
6
+ from sympy.core.add import add
7
+ from sympy.testing.pytest import XFAIL, raises
8
+
9
+ X = MatrixSymbol('X', 2, 2)
10
+ Y = MatrixSymbol('Y', 2, 2)
11
+
12
+ def test_evaluate():
13
+ assert MatAdd(X, X, evaluate=True) == add(X, X, evaluate=True) == MatAdd(X, X).doit()
14
+
15
+ def test_sort_key():
16
+ assert MatAdd(Y, X).doit().args == add(Y, X).doit().args == (X, Y)
17
+
18
+
19
+ def test_matadd_sympify():
20
+ assert isinstance(MatAdd(eye(1), eye(1)).args[0], Basic)
21
+ assert isinstance(add(eye(1), eye(1)).args[0], Basic)
22
+
23
+
24
+ def test_matadd_of_matrices():
25
+ assert MatAdd(eye(2), 4*eye(2), eye(2)).doit() == ImmutableMatrix(6*eye(2))
26
+ assert add(eye(2), 4*eye(2), eye(2)).doit() == ImmutableMatrix(6*eye(2))
27
+
28
+
29
+ def test_doit_args():
30
+ A = ImmutableMatrix([[1, 2], [3, 4]])
31
+ B = ImmutableMatrix([[2, 3], [4, 5]])
32
+ assert MatAdd(A, MatPow(B, 2)).doit() == A + B**2
33
+ assert MatAdd(A, MatMul(A, B)).doit() == A + A*B
34
+ assert (MatAdd(A, X, MatMul(A, B), Y, MatAdd(2*A, B)).doit() ==
35
+ add(A, X, MatMul(A, B), Y, add(2*A, B)).doit() ==
36
+ MatAdd(3*A + A*B + B, X, Y))
37
+
38
+
39
+ def test_generic_identity():
40
+ assert MatAdd.identity == GenericZeroMatrix()
41
+ assert MatAdd.identity != S.Zero
42
+
43
+
44
+ def test_zero_matrix_add():
45
+ assert Add(ZeroMatrix(2, 2), ZeroMatrix(2, 2)) == ZeroMatrix(2, 2)
46
+
47
+ @XFAIL
48
+ def test_matrix_Add_with_scalar():
49
+ raises(TypeError, lambda: Add(0, ZeroMatrix(2, 2)))
50
+
51
+
52
+ def test_shape_error():
53
+ A = MatrixSymbol('A', 2, 3)
54
+ B = MatrixSymbol('B', 3, 3)
55
+ raises(ShapeError, lambda: MatAdd(A, B))
56
+
57
+ A = MatrixSymbol('A', 3, 2)
58
+ raises(ShapeError, lambda: MatAdd(A, B))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_matexpr.py ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.concrete.summations import Sum
2
+ from sympy.core.exprtools import gcd_terms
3
+ from sympy.core.function import (diff, expand)
4
+ from sympy.core.relational import Eq
5
+ from sympy.core.symbol import (Dummy, Symbol, Str)
6
+ from sympy.functions.special.tensor_functions import KroneckerDelta
7
+ from sympy.matrices.dense import zeros
8
+ from sympy.polys.polytools import factor
9
+
10
+ from sympy.core import (S, symbols, Add, Mul, SympifyError, Rational,
11
+ Function)
12
+ from sympy.functions import sin, cos, tan, sqrt, cbrt, exp
13
+ from sympy.simplify import simplify
14
+ from sympy.matrices import (ImmutableMatrix, Inverse, MatAdd, MatMul,
15
+ MatPow, Matrix, MatrixExpr, MatrixSymbol,
16
+ SparseMatrix, Transpose, Adjoint, MatrixSet)
17
+ from sympy.matrices.exceptions import NonSquareMatrixError
18
+ from sympy.matrices.expressions.determinant import Determinant, det
19
+ from sympy.matrices.expressions.matexpr import MatrixElement
20
+ from sympy.matrices.expressions.special import ZeroMatrix, Identity
21
+ from sympy.testing.pytest import raises, XFAIL, skip
22
+ from importlib.metadata import version
23
+
24
+ n, m, l, k, p = symbols('n m l k p', integer=True)
25
+ x = symbols('x')
26
+ A = MatrixSymbol('A', n, m)
27
+ B = MatrixSymbol('B', m, l)
28
+ C = MatrixSymbol('C', n, n)
29
+ D = MatrixSymbol('D', n, n)
30
+ E = MatrixSymbol('E', m, n)
31
+ w = MatrixSymbol('w', n, 1)
32
+
33
+
34
+ def test_matrix_symbol_creation():
35
+ assert MatrixSymbol('A', 2, 2)
36
+ assert MatrixSymbol('A', 0, 0)
37
+ raises(ValueError, lambda: MatrixSymbol('A', -1, 2))
38
+ raises(ValueError, lambda: MatrixSymbol('A', 2.0, 2))
39
+ raises(ValueError, lambda: MatrixSymbol('A', 2j, 2))
40
+ raises(ValueError, lambda: MatrixSymbol('A', 2, -1))
41
+ raises(ValueError, lambda: MatrixSymbol('A', 2, 2.0))
42
+ raises(ValueError, lambda: MatrixSymbol('A', 2, 2j))
43
+
44
+ n = symbols('n')
45
+ assert MatrixSymbol('A', n, n)
46
+ n = symbols('n', integer=False)
47
+ raises(ValueError, lambda: MatrixSymbol('A', n, n))
48
+ n = symbols('n', negative=True)
49
+ raises(ValueError, lambda: MatrixSymbol('A', n, n))
50
+
51
+
52
+ def test_matexpr_properties():
53
+ assert A.shape == (n, m)
54
+ assert (A * B).shape == (n, l)
55
+ assert A[0, 1].indices == (0, 1)
56
+ assert A[0, 0].symbol == A
57
+ assert A[0, 0].symbol.name == 'A'
58
+
59
+
60
+ def test_matexpr():
61
+ assert (x*A).shape == A.shape
62
+ assert (x*A).__class__ == MatMul
63
+ assert 2*A - A - A == ZeroMatrix(*A.shape)
64
+ assert (A*B).shape == (n, l)
65
+
66
+
67
+ def test_matexpr_subs():
68
+ A = MatrixSymbol('A', n, m)
69
+ B = MatrixSymbol('B', m, l)
70
+ C = MatrixSymbol('C', m, l)
71
+
72
+ assert A.subs(n, m).shape == (m, m)
73
+ assert (A*B).subs(B, C) == A*C
74
+ assert (A*B).subs(l, n).is_square
75
+
76
+ W = MatrixSymbol("W", 3, 3)
77
+ X = MatrixSymbol("X", 2, 2)
78
+ Y = MatrixSymbol("Y", 1, 2)
79
+ Z = MatrixSymbol("Z", n, 2)
80
+ # no restrictions on Symbol replacement
81
+ assert X.subs(X, Y) == Y
82
+ # it might be better to just change the name
83
+ y = Str('y')
84
+ assert X.subs(Str("X"), y).args == (y, 2, 2)
85
+ # it's ok to introduce a wider matrix
86
+ assert X[1, 1].subs(X, W) == W[1, 1]
87
+ # but for a given MatrixExpression, only change
88
+ # name if indexing on the new shape is valid.
89
+ # Here, X is 2,2; Y is 1,2 and Y[1, 1] is out
90
+ # of range so an error is raised
91
+ raises(IndexError, lambda: X[1, 1].subs(X, Y))
92
+ # here, [0, 1] is in range so the subs succeeds
93
+ assert X[0, 1].subs(X, Y) == Y[0, 1]
94
+ # and here the size of n will accept any index
95
+ # in the first position
96
+ assert W[2, 1].subs(W, Z) == Z[2, 1]
97
+ # but not in the second position
98
+ raises(IndexError, lambda: W[2, 2].subs(W, Z))
99
+ # any matrix should raise if invalid
100
+ raises(IndexError, lambda: W[2, 2].subs(W, zeros(2)))
101
+
102
+ A = SparseMatrix([[1, 2], [3, 4]])
103
+ B = Matrix([[1, 2], [3, 4]])
104
+ C, D = MatrixSymbol('C', 2, 2), MatrixSymbol('D', 2, 2)
105
+
106
+ assert (C*D).subs({C: A, D: B}) == MatMul(A, B)
107
+
108
+
109
+ def test_addition():
110
+ A = MatrixSymbol('A', n, m)
111
+ B = MatrixSymbol('B', n, m)
112
+
113
+ assert isinstance(A + B, MatAdd)
114
+ assert (A + B).shape == A.shape
115
+ assert isinstance(A - A + 2*B, MatMul)
116
+
117
+ raises(TypeError, lambda: A + 1)
118
+ raises(TypeError, lambda: 5 + A)
119
+ raises(TypeError, lambda: 5 - A)
120
+
121
+ assert A + ZeroMatrix(n, m) - A == ZeroMatrix(n, m)
122
+ raises(TypeError, lambda: ZeroMatrix(n, m) + S.Zero)
123
+
124
+
125
+ def test_multiplication():
126
+ A = MatrixSymbol('A', n, m)
127
+ B = MatrixSymbol('B', m, l)
128
+ C = MatrixSymbol('C', n, n)
129
+
130
+ assert (2*A*B).shape == (n, l)
131
+ assert (A*0*B) == ZeroMatrix(n, l)
132
+ assert (2*A).shape == A.shape
133
+
134
+ assert A * ZeroMatrix(m, m) * B == ZeroMatrix(n, l)
135
+
136
+ assert C * Identity(n) * C.I == Identity(n)
137
+
138
+ assert B/2 == S.Half*B
139
+ raises(NotImplementedError, lambda: 2/B)
140
+
141
+ A = MatrixSymbol('A', n, n)
142
+ B = MatrixSymbol('B', n, n)
143
+ assert Identity(n) * (A + B) == A + B
144
+
145
+ assert A**2*A == A**3
146
+ assert A**2*(A.I)**3 == A.I
147
+ assert A**3*(A.I)**2 == A
148
+
149
+
150
+ def test_MatPow():
151
+ A = MatrixSymbol('A', n, n)
152
+
153
+ AA = MatPow(A, 2)
154
+ assert AA.exp == 2
155
+ assert AA.base == A
156
+ assert (A**n).exp == n
157
+
158
+ assert A**0 == Identity(n)
159
+ assert A**1 == A
160
+ assert A**2 == AA
161
+ assert A**-1 == Inverse(A)
162
+ assert (A**-1)**-1 == A
163
+ assert (A**2)**3 == A**6
164
+ assert A**S.Half == sqrt(A)
165
+ assert A**Rational(1, 3) == cbrt(A)
166
+ raises(NonSquareMatrixError, lambda: MatrixSymbol('B', 3, 2)**2)
167
+
168
+
169
+ def test_MatrixSymbol():
170
+ n, m, t = symbols('n,m,t')
171
+ X = MatrixSymbol('X', n, m)
172
+ assert X.shape == (n, m)
173
+ raises(TypeError, lambda: MatrixSymbol('X', n, m)(t)) # issue 5855
174
+ assert X.doit() == X
175
+
176
+
177
+ def test_dense_conversion():
178
+ X = MatrixSymbol('X', 2, 2)
179
+ assert ImmutableMatrix(X) == ImmutableMatrix(2, 2, lambda i, j: X[i, j])
180
+ assert Matrix(X) == Matrix(2, 2, lambda i, j: X[i, j])
181
+
182
+
183
+ def test_free_symbols():
184
+ assert (C*D).free_symbols == {C, D}
185
+
186
+
187
+ def test_zero_matmul():
188
+ assert isinstance(S.Zero * MatrixSymbol('X', 2, 2), MatrixExpr)
189
+
190
+
191
+ def test_matadd_simplify():
192
+ A = MatrixSymbol('A', 1, 1)
193
+ assert simplify(MatAdd(A, ImmutableMatrix([[sin(x)**2 + cos(x)**2]]))) == \
194
+ MatAdd(A, Matrix([[1]]))
195
+
196
+
197
+ def test_matmul_simplify():
198
+ A = MatrixSymbol('A', 1, 1)
199
+ assert simplify(MatMul(A, ImmutableMatrix([[sin(x)**2 + cos(x)**2]]))) == \
200
+ MatMul(A, Matrix([[1]]))
201
+
202
+
203
+ def test_invariants():
204
+ A = MatrixSymbol('A', n, m)
205
+ B = MatrixSymbol('B', m, l)
206
+ X = MatrixSymbol('X', n, n)
207
+ objs = [Identity(n), ZeroMatrix(m, n), A, MatMul(A, B), MatAdd(A, A),
208
+ Transpose(A), Adjoint(A), Inverse(X), MatPow(X, 2), MatPow(X, -1),
209
+ MatPow(X, 0)]
210
+ for obj in objs:
211
+ assert obj == obj.__class__(*obj.args)
212
+
213
+
214
+ def test_matexpr_indexing():
215
+ A = MatrixSymbol('A', n, m)
216
+ A[1, 2]
217
+ A[l, k]
218
+ A[l + 1, k + 1]
219
+ A = MatrixSymbol('A', 2, 1)
220
+ for i in range(-2, 2):
221
+ for j in range(-1, 1):
222
+ A[i, j]
223
+
224
+
225
+ def test_single_indexing():
226
+ A = MatrixSymbol('A', 2, 3)
227
+ assert A[1] == A[0, 1]
228
+ assert A[int(1)] == A[0, 1]
229
+ assert A[3] == A[1, 0]
230
+ assert list(A[:2, :2]) == [A[0, 0], A[0, 1], A[1, 0], A[1, 1]]
231
+ raises(IndexError, lambda: A[6])
232
+ raises(IndexError, lambda: A[n])
233
+ B = MatrixSymbol('B', n, m)
234
+ raises(IndexError, lambda: B[1])
235
+ B = MatrixSymbol('B', n, 3)
236
+ assert B[3] == B[1, 0]
237
+
238
+
239
+ def test_MatrixElement_commutative():
240
+ assert A[0, 1]*A[1, 0] == A[1, 0]*A[0, 1]
241
+
242
+
243
+ def test_MatrixSymbol_determinant():
244
+ A = MatrixSymbol('A', 4, 4)
245
+ assert A.as_explicit().det() == A[0, 0]*A[1, 1]*A[2, 2]*A[3, 3] - \
246
+ A[0, 0]*A[1, 1]*A[2, 3]*A[3, 2] - A[0, 0]*A[1, 2]*A[2, 1]*A[3, 3] + \
247
+ A[0, 0]*A[1, 2]*A[2, 3]*A[3, 1] + A[0, 0]*A[1, 3]*A[2, 1]*A[3, 2] - \
248
+ A[0, 0]*A[1, 3]*A[2, 2]*A[3, 1] - A[0, 1]*A[1, 0]*A[2, 2]*A[3, 3] + \
249
+ A[0, 1]*A[1, 0]*A[2, 3]*A[3, 2] + A[0, 1]*A[1, 2]*A[2, 0]*A[3, 3] - \
250
+ A[0, 1]*A[1, 2]*A[2, 3]*A[3, 0] - A[0, 1]*A[1, 3]*A[2, 0]*A[3, 2] + \
251
+ A[0, 1]*A[1, 3]*A[2, 2]*A[3, 0] + A[0, 2]*A[1, 0]*A[2, 1]*A[3, 3] - \
252
+ A[0, 2]*A[1, 0]*A[2, 3]*A[3, 1] - A[0, 2]*A[1, 1]*A[2, 0]*A[3, 3] + \
253
+ A[0, 2]*A[1, 1]*A[2, 3]*A[3, 0] + A[0, 2]*A[1, 3]*A[2, 0]*A[3, 1] - \
254
+ A[0, 2]*A[1, 3]*A[2, 1]*A[3, 0] - A[0, 3]*A[1, 0]*A[2, 1]*A[3, 2] + \
255
+ A[0, 3]*A[1, 0]*A[2, 2]*A[3, 1] + A[0, 3]*A[1, 1]*A[2, 0]*A[3, 2] - \
256
+ A[0, 3]*A[1, 1]*A[2, 2]*A[3, 0] - A[0, 3]*A[1, 2]*A[2, 0]*A[3, 1] + \
257
+ A[0, 3]*A[1, 2]*A[2, 1]*A[3, 0]
258
+
259
+ B = MatrixSymbol('B', 4, 4)
260
+ assert Determinant(A + B).doit() == det(A + B) == (A + B).det()
261
+
262
+
263
+ def test_MatrixElement_diff():
264
+ assert (A[3, 0]*A[0, 0]).diff(A[0, 0]) == A[3, 0]
265
+
266
+
267
+ def test_MatrixElement_doit():
268
+ u = MatrixSymbol('u', 2, 1)
269
+ v = ImmutableMatrix([3, 5])
270
+ assert u[0, 0].subs(u, v).doit() == v[0, 0]
271
+
272
+
273
+ def test_identity_powers():
274
+ M = Identity(n)
275
+ assert MatPow(M, 3).doit() == M**3
276
+ assert M**n == M
277
+ assert MatPow(M, 0).doit() == M**2
278
+ assert M**-2 == M
279
+ assert MatPow(M, -2).doit() == M**0
280
+ N = Identity(3)
281
+ assert MatPow(N, 2).doit() == N**n
282
+ assert MatPow(N, 3).doit() == N
283
+ assert MatPow(N, -2).doit() == N**4
284
+ assert MatPow(N, 2).doit() == N**0
285
+
286
+
287
+ def test_Zero_power():
288
+ z1 = ZeroMatrix(n, n)
289
+ assert z1**4 == z1
290
+ raises(ValueError, lambda:z1**-2)
291
+ assert z1**0 == Identity(n)
292
+ assert MatPow(z1, 2).doit() == z1**2
293
+ raises(ValueError, lambda:MatPow(z1, -2).doit())
294
+ z2 = ZeroMatrix(3, 3)
295
+ assert MatPow(z2, 4).doit() == z2**4
296
+ raises(ValueError, lambda:z2**-3)
297
+ assert z2**3 == MatPow(z2, 3).doit()
298
+ assert z2**0 == Identity(3)
299
+ raises(ValueError, lambda:MatPow(z2, -1).doit())
300
+
301
+
302
+ def test_matrixelement_diff():
303
+ dexpr = diff((D*w)[k,0], w[p,0])
304
+
305
+ assert w[k, p].diff(w[k, p]) == 1
306
+ assert w[k, p].diff(w[0, 0]) == KroneckerDelta(0, k, (0, n-1))*KroneckerDelta(0, p, (0, 0))
307
+ _i_1 = Dummy("_i_1")
308
+ assert dexpr.dummy_eq(Sum(KroneckerDelta(_i_1, p, (0, n-1))*D[k, _i_1], (_i_1, 0, n - 1)))
309
+ assert dexpr.doit() == D[k, p]
310
+
311
+
312
+ def test_MatrixElement_with_values():
313
+ x, y, z, w = symbols("x y z w")
314
+ M = Matrix([[x, y], [z, w]])
315
+ i, j = symbols("i, j")
316
+ Mij = M[i, j]
317
+ assert isinstance(Mij, MatrixElement)
318
+ Ms = SparseMatrix([[2, 3], [4, 5]])
319
+ msij = Ms[i, j]
320
+ assert isinstance(msij, MatrixElement)
321
+ for oi, oj in [(0, 0), (0, 1), (1, 0), (1, 1)]:
322
+ assert Mij.subs({i: oi, j: oj}) == M[oi, oj]
323
+ assert msij.subs({i: oi, j: oj}) == Ms[oi, oj]
324
+ A = MatrixSymbol("A", 2, 2)
325
+ assert A[0, 0].subs(A, M) == x
326
+ assert A[i, j].subs(A, M) == M[i, j]
327
+ assert M[i, j].subs(M, A) == A[i, j]
328
+
329
+ assert isinstance(M[3*i - 2, j], MatrixElement)
330
+ assert M[3*i - 2, j].subs({i: 1, j: 0}) == M[1, 0]
331
+ assert isinstance(M[i, 0], MatrixElement)
332
+ assert M[i, 0].subs(i, 0) == M[0, 0]
333
+ assert M[0, i].subs(i, 1) == M[0, 1]
334
+
335
+ assert M[i, j].diff(x) == Matrix([[1, 0], [0, 0]])[i, j]
336
+
337
+ raises(ValueError, lambda: M[i, 2])
338
+ raises(ValueError, lambda: M[i, -1])
339
+ raises(ValueError, lambda: M[2, i])
340
+ raises(ValueError, lambda: M[-1, i])
341
+
342
+
343
+ def test_inv():
344
+ B = MatrixSymbol('B', 3, 3)
345
+ assert B.inv() == B**-1
346
+
347
+ # https://github.com/sympy/sympy/issues/19162
348
+ X = MatrixSymbol('X', 1, 1).as_explicit()
349
+ assert X.inv() == Matrix([[1/X[0, 0]]])
350
+
351
+ X = MatrixSymbol('X', 2, 2).as_explicit()
352
+ detX = X[0, 0]*X[1, 1] - X[0, 1]*X[1, 0]
353
+ invX = Matrix([[ X[1, 1], -X[0, 1]],
354
+ [-X[1, 0], X[0, 0]]]) / detX
355
+ assert X.inv() == invX
356
+
357
+
358
+ @XFAIL
359
+ def test_factor_expand():
360
+ A = MatrixSymbol("A", n, n)
361
+ B = MatrixSymbol("B", n, n)
362
+ expr1 = (A + B)*(C + D)
363
+ expr2 = A*C + B*C + A*D + B*D
364
+ assert expr1 != expr2
365
+ assert expand(expr1) == expr2
366
+ assert factor(expr2) == expr1
367
+
368
+ expr = B**(-1)*(A**(-1)*B**(-1) - A**(-1)*C*B**(-1))**(-1)*A**(-1)
369
+ I = Identity(n)
370
+ # Ideally we get the first, but we at least don't want a wrong answer
371
+ assert factor(expr) in [I - C, B**-1*(A**-1*(I - C)*B**-1)**-1*A**-1]
372
+
373
+ def test_numpy_conversion():
374
+ try:
375
+ from numpy import array, array_equal
376
+ except ImportError:
377
+ skip('NumPy must be available to test creating matrices from ndarrays')
378
+ A = MatrixSymbol('A', 2, 2)
379
+ np_array = array([[MatrixElement(A, 0, 0), MatrixElement(A, 0, 1)],
380
+ [MatrixElement(A, 1, 0), MatrixElement(A, 1, 1)]])
381
+ assert array_equal(array(A), np_array)
382
+ assert array_equal(array(A, copy=True), np_array)
383
+ if(int(version('numpy').split('.')[0]) >= 2): #run this test only if numpy is new enough that copy variable is passed properly.
384
+ raises(TypeError, lambda: array(A, copy=False))
385
+
386
+ def test_issue_2749():
387
+ A = MatrixSymbol("A", 5, 2)
388
+ assert (A.T * A).I.as_explicit() == Matrix([[(A.T * A).I[0, 0], (A.T * A).I[0, 1]], \
389
+ [(A.T * A).I[1, 0], (A.T * A).I[1, 1]]])
390
+
391
+
392
+ def test_issue_2750():
393
+ x = MatrixSymbol('x', 1, 1)
394
+ assert (x.T*x).as_explicit()**-1 == Matrix([[x[0, 0]**(-2)]])
395
+
396
+
397
+ def test_issue_7842():
398
+ A = MatrixSymbol('A', 3, 1)
399
+ B = MatrixSymbol('B', 2, 1)
400
+ assert Eq(A, B) == False
401
+ assert Eq(A[1,0], B[1, 0]).func is Eq
402
+ A = ZeroMatrix(2, 3)
403
+ B = ZeroMatrix(2, 3)
404
+ assert Eq(A, B) == True
405
+
406
+
407
+ def test_issue_21195():
408
+ t = symbols('t')
409
+ x = Function('x')(t)
410
+ dx = x.diff(t)
411
+ exp1 = cos(x) + cos(x)*dx
412
+ exp2 = sin(x) + tan(x)*(dx.diff(t))
413
+ exp3 = sin(x)*sin(t)*(dx.diff(t)).diff(t)
414
+ A = Matrix([[exp1], [exp2], [exp3]])
415
+ B = Matrix([[exp1.diff(x)], [exp2.diff(x)], [exp3.diff(x)]])
416
+ assert A.diff(x) == B
417
+
418
+
419
+ def test_issue_24859():
420
+ A = MatrixSymbol('A', 2, 3)
421
+ B = MatrixSymbol('B', 3, 2)
422
+ J = A*B
423
+ Jinv = Matrix(J).adjugate()
424
+ u = MatrixSymbol('u', 2, 3)
425
+ Jk = Jinv.subs(A, A + x*u)
426
+
427
+ expected = B[0, 1]*u[1, 0] + B[1, 1]*u[1, 1] + B[2, 1]*u[1, 2]
428
+ assert Jk[0, 0].diff(x) == expected
429
+ assert diff(Jk[0, 0], x).doit() == expected
430
+
431
+
432
+ def test_MatMul_postprocessor():
433
+ z = zeros(2)
434
+ z1 = ZeroMatrix(2, 2)
435
+ assert Mul(0, z) == Mul(z, 0) in [z, z1]
436
+
437
+ M = Matrix([[1, 2], [3, 4]])
438
+ Mx = Matrix([[x, 2*x], [3*x, 4*x]])
439
+ assert Mul(x, M) == Mul(M, x) == Mx
440
+
441
+ A = MatrixSymbol("A", 2, 2)
442
+ assert Mul(A, M) == MatMul(A, M)
443
+ assert Mul(M, A) == MatMul(M, A)
444
+ # Scalars should be absorbed into constant matrices
445
+ a = Mul(x, M, A)
446
+ b = Mul(M, x, A)
447
+ c = Mul(M, A, x)
448
+ assert a == b == c == MatMul(Mx, A)
449
+ a = Mul(x, A, M)
450
+ b = Mul(A, x, M)
451
+ c = Mul(A, M, x)
452
+ assert a == b == c == MatMul(A, Mx)
453
+ assert Mul(M, M) == M**2
454
+ assert Mul(A, M, M) == MatMul(A, M**2)
455
+ assert Mul(M, M, A) == MatMul(M**2, A)
456
+ assert Mul(M, A, M) == MatMul(M, A, M)
457
+
458
+ assert Mul(A, x, M, M, x) == MatMul(A, Mx**2)
459
+
460
+
461
+ @XFAIL
462
+ def test_MatAdd_postprocessor_xfail():
463
+ # This is difficult to get working because of the way that Add processes
464
+ # its args.
465
+ z = zeros(2)
466
+ assert Add(z, S.NaN) == Add(S.NaN, z)
467
+
468
+
469
+ def test_MatAdd_postprocessor():
470
+ # Some of these are nonsensical, but we do not raise errors for Add
471
+ # because that breaks algorithms that want to replace matrices with dummy
472
+ # symbols.
473
+
474
+ z = zeros(2)
475
+
476
+ assert Add(0, z) == Add(z, 0) == z
477
+
478
+ a = Add(S.Infinity, z)
479
+ assert a == Add(z, S.Infinity)
480
+ assert isinstance(a, Add)
481
+ assert a.args == (S.Infinity, z)
482
+
483
+ a = Add(S.ComplexInfinity, z)
484
+ assert a == Add(z, S.ComplexInfinity)
485
+ assert isinstance(a, Add)
486
+ assert a.args == (S.ComplexInfinity, z)
487
+
488
+ a = Add(z, S.NaN)
489
+ # assert a == Add(S.NaN, z) # See the XFAIL above
490
+ assert isinstance(a, Add)
491
+ assert a.args == (S.NaN, z)
492
+
493
+ M = Matrix([[1, 2], [3, 4]])
494
+ a = Add(x, M)
495
+ assert a == Add(M, x)
496
+ assert isinstance(a, Add)
497
+ assert a.args == (x, M)
498
+
499
+ A = MatrixSymbol("A", 2, 2)
500
+ assert Add(A, M) == Add(M, A) == A + M
501
+
502
+ # Scalars should be absorbed into constant matrices (producing an error)
503
+ a = Add(x, M, A)
504
+ assert a == Add(M, x, A) == Add(M, A, x) == Add(x, A, M) == Add(A, x, M) == Add(A, M, x)
505
+ assert isinstance(a, Add)
506
+ assert a.args == (x, A + M)
507
+
508
+ assert Add(M, M) == 2*M
509
+ assert Add(M, A, M) == Add(M, M, A) == Add(A, M, M) == A + 2*M
510
+
511
+ a = Add(A, x, M, M, x)
512
+ assert isinstance(a, Add)
513
+ assert a.args == (2*x, A + 2*M)
514
+
515
+
516
+ def test_simplify_matrix_expressions():
517
+ # Various simplification functions
518
+ assert type(gcd_terms(C*D + D*C)) == MatAdd
519
+ a = gcd_terms(2*C*D + 4*D*C)
520
+ assert type(a) == MatAdd
521
+ assert a.args == (2*C*D, 4*D*C)
522
+
523
+
524
+ def test_exp():
525
+ A = MatrixSymbol('A', 2, 2)
526
+ B = MatrixSymbol('B', 2, 2)
527
+ expr1 = exp(A)*exp(B)
528
+ expr2 = exp(B)*exp(A)
529
+ assert expr1 != expr2
530
+ assert expr1 - expr2 != 0
531
+ assert not isinstance(expr1, exp)
532
+ assert not isinstance(expr2, exp)
533
+
534
+
535
+ def test_invalid_args():
536
+ raises(SympifyError, lambda: MatrixSymbol(1, 2, 'A'))
537
+
538
+
539
+ def test_matrixsymbol_from_symbol():
540
+ # The label should be preserved during doit and subs
541
+ A_label = Symbol('A', complex=True)
542
+ A = MatrixSymbol(A_label, 2, 2)
543
+
544
+ A_1 = A.doit()
545
+ A_2 = A.subs(2, 3)
546
+ assert A_1.args == A.args
547
+ assert A_2.args[0] == A.args[0]
548
+
549
+
550
+ def test_as_explicit():
551
+ Z = MatrixSymbol('Z', 2, 3)
552
+ assert Z.as_explicit() == ImmutableMatrix([
553
+ [Z[0, 0], Z[0, 1], Z[0, 2]],
554
+ [Z[1, 0], Z[1, 1], Z[1, 2]],
555
+ ])
556
+ raises(ValueError, lambda: A.as_explicit())
557
+
558
+
559
+ def test_MatrixSet():
560
+ M = MatrixSet(2, 2, set=S.Reals)
561
+ assert M.shape == (2, 2)
562
+ assert M.set == S.Reals
563
+ X = Matrix([[1, 2], [3, 4]])
564
+ assert X in M
565
+ X = ZeroMatrix(2, 2)
566
+ assert X in M
567
+ raises(TypeError, lambda: A in M)
568
+ raises(TypeError, lambda: 1 in M)
569
+ M = MatrixSet(n, m, set=S.Reals)
570
+ assert A in M
571
+ raises(TypeError, lambda: C in M)
572
+ raises(TypeError, lambda: X in M)
573
+ M = MatrixSet(2, 2, set={1, 2, 3})
574
+ X = Matrix([[1, 2], [3, 4]])
575
+ Y = Matrix([[1, 2]])
576
+ assert (X in M) == S.false
577
+ assert (Y in M) == S.false
578
+ raises(ValueError, lambda: MatrixSet(2, -2, S.Reals))
579
+ raises(ValueError, lambda: MatrixSet(2.4, -1, S.Reals))
580
+ raises(TypeError, lambda: MatrixSet(2, 2, (1, 2, 3)))
581
+
582
+
583
+ def test_matrixsymbol_solving():
584
+ A = MatrixSymbol('A', 2, 2)
585
+ B = MatrixSymbol('B', 2, 2)
586
+ Z = ZeroMatrix(2, 2)
587
+ assert -(-A + B) - A + B == Z
588
+ assert (-(-A + B) - A + B).simplify() == Z
589
+ assert (-(-A + B) - A + B).expand() == Z
590
+ assert (-(-A + B) - A + B - Z).simplify() == Z
591
+ assert (-(-A + B) - A + B - Z).expand() == Z
592
+ assert (A*(A + B) + B*(A.T + B.T)).expand() == A**2 + A*B + B*A.T + B*B.T
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_matmul.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core import I, symbols, Basic, Mul, S
2
+ from sympy.core.mul import mul
3
+ from sympy.functions import adjoint, transpose
4
+ from sympy.matrices.exceptions import ShapeError
5
+ from sympy.matrices import (Identity, Inverse, Matrix, MatrixSymbol, ZeroMatrix,
6
+ eye, ImmutableMatrix)
7
+ from sympy.matrices.expressions import Adjoint, Transpose, det, MatPow
8
+ from sympy.matrices.expressions.special import GenericIdentity
9
+ from sympy.matrices.expressions.matmul import (factor_in_front, remove_ids,
10
+ MatMul, combine_powers, any_zeros, unpack, only_squares)
11
+ from sympy.strategies import null_safe
12
+ from sympy.assumptions.ask import Q
13
+ from sympy.assumptions.refine import refine
14
+ from sympy.core.symbol import Symbol
15
+
16
+ from sympy.testing.pytest import XFAIL, raises
17
+
18
+ n, m, l, k = symbols('n m l k', integer=True)
19
+ x = symbols('x')
20
+ A = MatrixSymbol('A', n, m)
21
+ B = MatrixSymbol('B', m, l)
22
+ C = MatrixSymbol('C', n, n)
23
+ D = MatrixSymbol('D', n, n)
24
+ E = MatrixSymbol('E', m, n)
25
+
26
+ def test_evaluate():
27
+ assert MatMul(C, C, evaluate=True) == MatMul(C, C).doit()
28
+
29
+ def test_adjoint():
30
+ assert adjoint(A*B) == Adjoint(B)*Adjoint(A)
31
+ assert adjoint(2*A*B) == 2*Adjoint(B)*Adjoint(A)
32
+ assert adjoint(2*I*C) == -2*I*Adjoint(C)
33
+
34
+ M = Matrix(2, 2, [1, 2 + I, 3, 4])
35
+ MA = Matrix(2, 2, [1, 3, 2 - I, 4])
36
+ assert adjoint(M) == MA
37
+ assert adjoint(2*M) == 2*MA
38
+ assert adjoint(MatMul(2, M)) == MatMul(2, MA).doit()
39
+
40
+
41
+ def test_transpose():
42
+ assert transpose(A*B) == Transpose(B)*Transpose(A)
43
+ assert transpose(2*A*B) == 2*Transpose(B)*Transpose(A)
44
+ assert transpose(2*I*C) == 2*I*Transpose(C)
45
+
46
+ M = Matrix(2, 2, [1, 2 + I, 3, 4])
47
+ MT = Matrix(2, 2, [1, 3, 2 + I, 4])
48
+ assert transpose(M) == MT
49
+ assert transpose(2*M) == 2*MT
50
+ assert transpose(x*M) == x*MT
51
+ assert transpose(MatMul(2, M)) == MatMul(2, MT).doit()
52
+
53
+
54
+ def test_factor_in_front():
55
+ assert factor_in_front(MatMul(A, 2, B, evaluate=False)) ==\
56
+ MatMul(2, A, B, evaluate=False)
57
+
58
+
59
+ def test_remove_ids():
60
+ assert remove_ids(MatMul(A, Identity(m), B, evaluate=False)) == \
61
+ MatMul(A, B, evaluate=False)
62
+ assert null_safe(remove_ids)(MatMul(Identity(n), evaluate=False)) == \
63
+ MatMul(Identity(n), evaluate=False)
64
+
65
+
66
+ def test_combine_powers():
67
+ assert combine_powers(MatMul(D, Inverse(D), D, evaluate=False)) == \
68
+ MatMul(Identity(n), D, evaluate=False)
69
+ assert combine_powers(MatMul(B.T, Inverse(E*A), E, A, B, evaluate=False)) == \
70
+ MatMul(B.T, Identity(m), B, evaluate=False)
71
+ assert combine_powers(MatMul(A, E, Inverse(A*E), D, evaluate=False)) == \
72
+ MatMul(Identity(n), D, evaluate=False)
73
+
74
+
75
+ def test_any_zeros():
76
+ assert any_zeros(MatMul(A, ZeroMatrix(m, k), evaluate=False)) == \
77
+ ZeroMatrix(n, k)
78
+
79
+
80
+ def test_unpack():
81
+ assert unpack(MatMul(A, evaluate=False)) == A
82
+ x = MatMul(A, B)
83
+ assert unpack(x) == x
84
+
85
+
86
+ def test_only_squares():
87
+ assert only_squares(C) == [C]
88
+ assert only_squares(C, D) == [C, D]
89
+ assert only_squares(C, A, A.T, D) == [C, A*A.T, D]
90
+
91
+
92
+ def test_determinant():
93
+ assert det(2*C) == 2**n*det(C)
94
+ assert det(2*C*D) == 2**n*det(C)*det(D)
95
+ assert det(3*C*A*A.T*D) == 3**n*det(C)*det(A*A.T)*det(D)
96
+
97
+
98
+ def test_doit():
99
+ assert MatMul(C, 2, D).args == (C, 2, D)
100
+ assert MatMul(C, 2, D).doit().args == (2, C, D)
101
+ assert MatMul(C, Transpose(D*C)).args == (C, Transpose(D*C))
102
+ assert MatMul(C, Transpose(D*C)).doit(deep=True).args == (C, C.T, D.T)
103
+
104
+
105
+ def test_doit_drills_down():
106
+ X = ImmutableMatrix([[1, 2], [3, 4]])
107
+ Y = ImmutableMatrix([[2, 3], [4, 5]])
108
+ assert MatMul(X, MatPow(Y, 2)).doit() == X*Y**2
109
+ assert MatMul(C, Transpose(D*C)).doit().args == (C, C.T, D.T)
110
+
111
+
112
+ def test_doit_deep_false_still_canonical():
113
+ assert (MatMul(C, Transpose(D*C), 2).doit(deep=False).args ==
114
+ (2, C, Transpose(D*C)))
115
+
116
+
117
+ def test_matmul_scalar_Matrix_doit():
118
+ # Issue 9053
119
+ X = Matrix([[1, 2], [3, 4]])
120
+ assert MatMul(2, X).doit() == 2*X
121
+
122
+
123
+ def test_matmul_sympify():
124
+ assert isinstance(MatMul(eye(1), eye(1)).args[0], Basic)
125
+
126
+
127
+ def test_collapse_MatrixBase():
128
+ A = Matrix([[1, 1], [1, 1]])
129
+ B = Matrix([[1, 2], [3, 4]])
130
+ assert MatMul(A, B).doit() == ImmutableMatrix([[4, 6], [4, 6]])
131
+
132
+
133
+ def test_refine():
134
+ assert refine(C*C.T*D, Q.orthogonal(C)).doit() == D
135
+
136
+ kC = k*C
137
+ assert refine(kC*C.T, Q.orthogonal(C)).doit() == k*Identity(n)
138
+ assert refine(kC* kC.T, Q.orthogonal(C)).doit() == (k**2)*Identity(n)
139
+
140
+ def test_matmul_no_matrices():
141
+ assert MatMul(1) == 1
142
+ assert MatMul(n, m) == n*m
143
+ assert not isinstance(MatMul(n, m), MatMul)
144
+
145
+ def test_matmul_args_cnc():
146
+ assert MatMul(n, A, A.T).args_cnc() == [[n], [A, A.T]]
147
+ assert MatMul(A, A.T).args_cnc() == [[], [A, A.T]]
148
+
149
+ @XFAIL
150
+ def test_matmul_args_cnc_symbols():
151
+ # Not currently supported
152
+ a, b = symbols('a b', commutative=False)
153
+ assert MatMul(n, a, b, A, A.T).args_cnc() == [[n], [a, b, A, A.T]]
154
+ assert MatMul(n, a, A, b, A.T).args_cnc() == [[n], [a, A, b, A.T]]
155
+
156
+ def test_issue_12950():
157
+ M = Matrix([[Symbol("x")]]) * MatrixSymbol("A", 1, 1)
158
+ assert MatrixSymbol("A", 1, 1).as_explicit()[0]*Symbol('x') == M.as_explicit()[0]
159
+
160
+ def test_construction_with_Mul():
161
+ assert Mul(C, D) == MatMul(C, D)
162
+ assert Mul(D, C) == MatMul(D, C)
163
+
164
+ def test_construction_with_mul():
165
+ assert mul(C, D) == MatMul(C, D)
166
+ assert mul(D, C) == MatMul(D, C)
167
+ assert mul(C, D) != MatMul(D, C)
168
+
169
+ def test_generic_identity():
170
+ assert MatMul.identity == GenericIdentity()
171
+ assert MatMul.identity != S.One
172
+
173
+
174
+ def test_issue_23519():
175
+ N = Symbol("N", integer=True)
176
+ M1 = MatrixSymbol("M1", N, N)
177
+ M2 = MatrixSymbol("M2", N, N)
178
+ I = Identity(N)
179
+ z = (M2 + 2 * (M2 + I) * M1 + I)
180
+ assert z.coeff(M1) == 2*I + 2*M2
181
+
182
+
183
+ def test_shape_error():
184
+ A = MatrixSymbol('A', 2, 2)
185
+ B = MatrixSymbol('B', 3, 3)
186
+ raises(ShapeError, lambda: MatMul(A, B))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/sympy/matrices/expressions/tests/test_permutation.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.combinatorics import Permutation
2
+ from sympy.core.expr import unchanged
3
+ from sympy.matrices import Matrix
4
+ from sympy.matrices.expressions import \
5
+ MatMul, BlockDiagMatrix, Determinant, Inverse
6
+ from sympy.matrices.expressions.matexpr import MatrixSymbol
7
+ from sympy.matrices.expressions.special import ZeroMatrix, OneMatrix, Identity
8
+ from sympy.matrices.expressions.permutation import \
9
+ MatrixPermute, PermutationMatrix
10
+ from sympy.testing.pytest import raises
11
+ from sympy.core.symbol import Symbol
12
+
13
+
14
+ def test_PermutationMatrix_basic():
15
+ p = Permutation([1, 0])
16
+ assert unchanged(PermutationMatrix, p)
17
+ raises(ValueError, lambda: PermutationMatrix((0, 1, 2)))
18
+ assert PermutationMatrix(p).as_explicit() == Matrix([[0, 1], [1, 0]])
19
+ assert isinstance(PermutationMatrix(p)*MatrixSymbol('A', 2, 2), MatMul)
20
+
21
+
22
+ def test_PermutationMatrix_matmul():
23
+ p = Permutation([1, 2, 0])
24
+ P = PermutationMatrix(p)
25
+ M = Matrix([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
26
+ assert (P*M).as_explicit() == P.as_explicit()*M
27
+ assert (M*P).as_explicit() == M*P.as_explicit()
28
+
29
+ P1 = PermutationMatrix(Permutation([1, 2, 0]))
30
+ P2 = PermutationMatrix(Permutation([2, 1, 0]))
31
+ P3 = PermutationMatrix(Permutation([1, 0, 2]))
32
+ assert P1*P2 == P3
33
+
34
+
35
+ def test_PermutationMatrix_matpow():
36
+ p1 = Permutation([1, 2, 0])
37
+ P1 = PermutationMatrix(p1)
38
+ p2 = Permutation([2, 0, 1])
39
+ P2 = PermutationMatrix(p2)
40
+ assert P1**2 == P2
41
+ assert P1**3 == Identity(3)
42
+
43
+
44
+ def test_PermutationMatrix_identity():
45
+ p = Permutation([0, 1])
46
+ assert PermutationMatrix(p).is_Identity
47
+
48
+ p = Permutation([1, 0])
49
+ assert not PermutationMatrix(p).is_Identity
50
+
51
+
52
+ def test_PermutationMatrix_determinant():
53
+ P = PermutationMatrix(Permutation([0, 1, 2]))
54
+ assert Determinant(P).doit() == 1
55
+ P = PermutationMatrix(Permutation([0, 2, 1]))
56
+ assert Determinant(P).doit() == -1
57
+ P = PermutationMatrix(Permutation([2, 0, 1]))
58
+ assert Determinant(P).doit() == 1
59
+
60
+
61
+ def test_PermutationMatrix_inverse():
62
+ P = PermutationMatrix(Permutation(0, 1, 2))
63
+ assert Inverse(P).doit() == PermutationMatrix(Permutation(0, 2, 1))
64
+
65
+
66
+ def test_PermutationMatrix_rewrite_BlockDiagMatrix():
67
+ P = PermutationMatrix(Permutation([0, 1, 2, 3, 4, 5]))
68
+ P0 = PermutationMatrix(Permutation([0]))
69
+ assert P.rewrite(BlockDiagMatrix) == \
70
+ BlockDiagMatrix(P0, P0, P0, P0, P0, P0)
71
+
72
+ P = PermutationMatrix(Permutation([0, 1, 3, 2, 4, 5]))
73
+ P10 = PermutationMatrix(Permutation(0, 1))
74
+ assert P.rewrite(BlockDiagMatrix) == \
75
+ BlockDiagMatrix(P0, P0, P10, P0, P0)
76
+
77
+ P = PermutationMatrix(Permutation([1, 0, 3, 2, 5, 4]))
78
+ assert P.rewrite(BlockDiagMatrix) == \
79
+ BlockDiagMatrix(P10, P10, P10)
80
+
81
+ P = PermutationMatrix(Permutation([0, 4, 3, 2, 1, 5]))
82
+ P3210 = PermutationMatrix(Permutation([3, 2, 1, 0]))
83
+ assert P.rewrite(BlockDiagMatrix) == \
84
+ BlockDiagMatrix(P0, P3210, P0)
85
+
86
+ P = PermutationMatrix(Permutation([0, 4, 2, 3, 1, 5]))
87
+ P3120 = PermutationMatrix(Permutation([3, 1, 2, 0]))
88
+ assert P.rewrite(BlockDiagMatrix) == \
89
+ BlockDiagMatrix(P0, P3120, P0)
90
+
91
+ P = PermutationMatrix(Permutation(0, 3)(1, 4)(2, 5))
92
+ assert P.rewrite(BlockDiagMatrix) == BlockDiagMatrix(P)
93
+
94
+
95
+ def test_MartrixPermute_basic():
96
+ p = Permutation(0, 1)
97
+ P = PermutationMatrix(p)
98
+ A = MatrixSymbol('A', 2, 2)
99
+
100
+ raises(ValueError, lambda: MatrixPermute(Symbol('x'), p))
101
+ raises(ValueError, lambda: MatrixPermute(A, Symbol('x')))
102
+
103
+ assert MatrixPermute(A, P) == MatrixPermute(A, p)
104
+ raises(ValueError, lambda: MatrixPermute(A, p, 2))
105
+
106
+ pp = Permutation(0, 1, size=3)
107
+ assert MatrixPermute(A, pp) == MatrixPermute(A, p)
108
+ pp = Permutation(0, 1, 2)
109
+ raises(ValueError, lambda: MatrixPermute(A, pp))
110
+
111
+
112
+ def test_MatrixPermute_shape():
113
+ p = Permutation(0, 1)
114
+ A = MatrixSymbol('A', 2, 3)
115
+ assert MatrixPermute(A, p).shape == (2, 3)
116
+
117
+
118
+ def test_MatrixPermute_explicit():
119
+ p = Permutation(0, 1, 2)
120
+ A = MatrixSymbol('A', 3, 3)
121
+ AA = A.as_explicit()
122
+ assert MatrixPermute(A, p, 0).as_explicit() == \
123
+ AA.permute(p, orientation='rows')
124
+ assert MatrixPermute(A, p, 1).as_explicit() == \
125
+ AA.permute(p, orientation='cols')
126
+
127
+
128
+ def test_MatrixPermute_rewrite_MatMul():
129
+ p = Permutation(0, 1, 2)
130
+ A = MatrixSymbol('A', 3, 3)
131
+
132
+ assert MatrixPermute(A, p, 0).rewrite(MatMul).as_explicit() == \
133
+ MatrixPermute(A, p, 0).as_explicit()
134
+ assert MatrixPermute(A, p, 1).rewrite(MatMul).as_explicit() == \
135
+ MatrixPermute(A, p, 1).as_explicit()
136
+
137
+
138
+ def test_MatrixPermute_doit():
139
+ p = Permutation(0, 1, 2)
140
+ A = MatrixSymbol('A', 3, 3)
141
+ assert MatrixPermute(A, p).doit() == MatrixPermute(A, p)
142
+
143
+ p = Permutation(0, size=3)
144
+ A = MatrixSymbol('A', 3, 3)
145
+ assert MatrixPermute(A, p).doit().as_explicit() == \
146
+ MatrixPermute(A, p).as_explicit()
147
+
148
+ p = Permutation(0, 1, 2)
149
+ A = Identity(3)
150
+ assert MatrixPermute(A, p, 0).doit().as_explicit() == \
151
+ MatrixPermute(A, p, 0).as_explicit()
152
+ assert MatrixPermute(A, p, 1).doit().as_explicit() == \
153
+ MatrixPermute(A, p, 1).as_explicit()
154
+
155
+ A = ZeroMatrix(3, 3)
156
+ assert MatrixPermute(A, p).doit() == A
157
+ A = OneMatrix(3, 3)
158
+ assert MatrixPermute(A, p).doit() == A
159
+
160
+ A = MatrixSymbol('A', 4, 4)
161
+ p1 = Permutation(0, 1, 2, 3)
162
+ p2 = Permutation(0, 2, 3, 1)
163
+ expr = MatrixPermute(MatrixPermute(A, p1, 0), p2, 0)
164
+ assert expr.as_explicit() == expr.doit().as_explicit()
165
+ expr = MatrixPermute(MatrixPermute(A, p1, 1), p2, 1)
166
+ assert expr.as_explicit() == expr.doit().as_explicit()