Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- backend/.github/workflows/ci.yml +0 -0
- backend/README.md +204 -0
- backend/README_backend.md +71 -0
- backend/README_demo.md +25 -0
- backend/__init__.py +6 -0
- backend/adapters/graph_adapter.py +81 -0
- backend/adapters/vector_adapter.py +33 -0
- backend/adapters/vector_adapter_full.py +55 -0
- backend/api/analysis_routes.py +27 -0
- backend/api/auth.py +13 -0
- backend/api/crud.py +85 -0
- backend/api/example_payloads.md +50 -0
- backend/api/neuro_symbolic_routes.py +31 -0
- backend/api/quantum_routes.py +18 -0
- backend/api/query_routes.py +50 -0
- backend/api/routes.py +96 -0
- backend/api/schemas.py +143 -0
- backend/api/vector_routes.py +106 -0
- backend/api/visualization_routes.py +65 -0
- backend/core/.rustup/settings.toml +4 -0
- backend/core/ag4masses/.gitignore +27 -0
- backend/core/ag4masses/CONTRIBUTING.md +13 -0
- backend/core/ag4masses/LICENSE +202 -0
- backend/core/ag4masses/README.md +346 -0
- backend/core/ag4masses/alphageometry/CONTRIBUTING.md +25 -0
- backend/core/ag4masses/alphageometry/alphageometry.py +778 -0
- backend/core/ag4masses/alphageometry/alphageometry_test.py +103 -0
- backend/core/ag4masses/alphageometry/ar.py +752 -0
- backend/core/ag4masses/alphageometry/ar_test.py +204 -0
- backend/core/ag4masses/alphageometry/beam_search.py +463 -0
- backend/core/ag4masses/alphageometry/dd.py +1156 -0
- backend/core/ag4masses/alphageometry/dd_test.py +79 -0
- backend/core/ag4masses/alphageometry/ddar.py +159 -0
- backend/core/ag4masses/alphageometry/ddar_test.py +65 -0
- backend/core/ag4masses/alphageometry/decoder_stack.py +55 -0
- backend/core/ag4masses/alphageometry/defs.txt +407 -0
- backend/core/ag4masses/alphageometry/download.sh +17 -0
- backend/core/ag4masses/alphageometry/examples.txt +8 -0
- backend/core/ag4masses/alphageometry/fig1.svg +0 -0
- backend/core/ag4masses/alphageometry/geometry.py +578 -0
- backend/core/ag4masses/alphageometry/geometry_150M_generate.gin +47 -0
- backend/core/ag4masses/alphageometry/geometry_test.py +80 -0
- backend/core/ag4masses/alphageometry/graph.py +3057 -0
- backend/core/ag4masses/alphageometry/graph_test.py +164 -0
- backend/core/alphageometry_adapter.py +118 -0
- backend/core/alphageometry_runner.py +0 -0
- backend/core/captum.py +21 -0
- backend/core/coq_adapter.py +20 -0
- backend/core/cross_universe_analysis.py +599 -0
- backend/core/ddar.py +24 -0
backend/.github/workflows/ci.yml
ADDED
|
File without changes
|
backend/README.md
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Project V1 Backend
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
This backend implements the Genesis Engine for mathematical universe simulation, axiom evolution, theorem derivation, and persistent history tracking. It is built with FastAPI, SQLAlchemy, and SymPy.
|
| 5 |
+
|
| 6 |
+
## Core Logic & Algorithms
|
| 7 |
+
- **AlphaGeometry, Lean 4, and Coq Integration:**
|
| 8 |
+
- The backend can now run external proof engines for advanced theorem proving and verification.
|
| 9 |
+
- Adapters: `backend/core/alphageometry_adapter.py`, `lean_adapter.py`, `coq_adapter.py`.
|
| 10 |
+
- Usage example:
|
| 11 |
+
```python
|
| 12 |
+
from backend.core.alphageometry_adapter import run_alphageometry
|
| 13 |
+
output = run_alphageometry("path/to/input_file")
|
| 14 |
+
```
|
| 15 |
+
- Similar usage for `run_lean4` and `run_coq`.
|
| 16 |
+
- Make sure the external tools are downloaded and paths are correct (see adapters for details).
|
| 17 |
+
- **Universe Generation:** Create universes with custom types and axioms.
|
| 18 |
+
- **Axiom Evolution:** Add, evolve, and track axioms with lineage and versioning.
|
| 19 |
+
- **Theorem Derivation:** Use symbolic logic (SymPy) to derive theorems from axioms and store proofs.
|
| 20 |
+
- **History Tracking:** All changes to universes and axioms are versioned and timestamped.
|
| 21 |
+
- **Neuro-Symbolic Network:** Train and use a neural network (PyTorch) to guide proof search and theory growth.
|
| 22 |
+
- **Quantum-Inspired Algorithms:** Classical simulation of Grover’s search and other quantum algorithms for proof exploration.
|
| 23 |
+
- **Cross-Universe Analysis:** Compare multiple universes to find shared axioms, theorems, and patterns. Results are stored in the database.
|
| 24 |
+
- Advanced analysis algorithms can be added to detect invariants, patterns, and relationships across universes. Results are queryable via the API and stored for further research.
|
| 25 |
+
- **3D Visualization & Query Interface:** Backend endpoints provide graph data for universes, axioms, and theorems to support interactive frontend visualization.
|
| 26 |
+
- **Query Engine:** API endpoints answer complex mathematical questions and generate universe/theorem summaries for research and visualization.
|
| 27 |
+
|
| 28 |
+
## API Endpoints
|
| 29 |
+
- `POST /universes` — Create a new universe (with optional axioms and type)
|
| 30 |
+
- `GET /universes` — List all universes
|
| 31 |
+
- `GET /universes/{universe_id}/history` — Retrieve universe history and axiom lineage
|
| 32 |
+
- `GET /axioms/{universe_id}` — List axioms for a universe
|
| 33 |
+
- `POST /axioms` — Add a new axiom
|
| 34 |
+
- `POST /axioms/evolve` — Evolve an axiom (with lineage)
|
| 35 |
+
- `POST /theorems/derive` — Derive a theorem from axioms
|
| 36 |
+
- `GET /theorems/{universe_id}` — List theorems for a universe
|
| 37 |
+
- `POST /neuro/train` — Train the neuro-symbolic network
|
| 38 |
+
- `POST /neuro/predict` — Predict with the neuro-symbolic network
|
| 39 |
+
- `POST /neuro/guide` — Guide proof search using the neuro-symbolic network
|
| 40 |
+
- `POST /quantum/grover` — Run Grover’s search algorithm simulation
|
| 41 |
+
- `POST /analysis/cross_universe` — Run cross-universe analysis and retrieve shared axioms/theorems
|
| 42 |
+
- `GET /visualization/universe/{universe_id}` — Get graph data for a single universe
|
| 43 |
+
- `GET /visualization/universes` — Get graph data for all universes
|
| 44 |
+
- `GET /query/universe_summary/{universe_id}` — Get a summary of a universe (axioms, theorems, counts)
|
| 45 |
+
- `GET /query/axiom_usage/{axiom_id}` — Get usage of an axiom in theorems
|
| 46 |
+
|
| 47 |
+
## Usage Example
|
| 48 |
+
```python
|
| 49 |
+
# Create a universe
|
| 50 |
+
POST /universes
|
| 51 |
+
{
|
| 52 |
+
"name": "Group Theory",
|
| 53 |
+
"description": "Universe for group theory",
|
| 54 |
+
"universe_type": "group_theory",
|
| 55 |
+
"axioms": ["Closure", "Associativity", "Identity", "Inverse"]
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
# Add an axiom
|
| 59 |
+
POST /axioms
|
| 60 |
+
{
|
| 61 |
+
"universe_id": 1,
|
| 62 |
+
"statement": "Commutativity"
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
# Evolve an axiom
|
| 66 |
+
POST /axioms/evolve
|
| 67 |
+
{
|
| 68 |
+
"axiom_id": 2,
|
| 69 |
+
"new_statement": "Commutativity (strong)"
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
# Derive a theorem
|
| 73 |
+
POST /theorems/derive
|
| 74 |
+
{
|
| 75 |
+
"universe_id": 1,
|
| 76 |
+
"axiom_ids": [1, 2],
|
| 77 |
+
"statement": "Closure Commutativity"
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
# Train the neuro-symbolic network
|
| 81 |
+
POST /neuro/train
|
| 82 |
+
{
|
| 83 |
+
"training_data": [[0.1, 0.2, ...], [0.3, 0.4, ...]],
|
| 84 |
+
"labels": [0, 1],
|
| 85 |
+
"epochs": 10
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
# Predict with the neuro-symbolic network
|
| 89 |
+
POST /neuro/predict
|
| 90 |
+
{
|
| 91 |
+
"input_data": [[0.1, 0.2, ...], [0.3, 0.4, ...]]
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
# Guide proof search
|
| 95 |
+
POST /neuro/guide
|
| 96 |
+
{
|
| 97 |
+
"universe_id": 1,
|
| 98 |
+
"axiom_ids": [1, 2, 3]
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
# Run Grover’s search
|
| 102 |
+
POST /quantum/grover
|
| 103 |
+
{
|
| 104 |
+
"database_size": 16,
|
| 105 |
+
"target_idx": 5,
|
| 106 |
+
"iterations": 3
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
# Run cross-universe analysis
|
| 110 |
+
POST /analysis/cross_universe
|
| 111 |
+
{
|
| 112 |
+
"universe_ids": [1, 2, 3]
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
# Get graph data for a universe
|
| 116 |
+
GET /visualization/universe/1
|
| 117 |
+
|
| 118 |
+
# Get graph data for all universes
|
| 119 |
+
GET /visualization/universes
|
| 120 |
+
|
| 121 |
+
# Get a universe summary
|
| 122 |
+
GET /query/universe_summary/1
|
| 123 |
+
|
| 124 |
+
# Get axiom usage
|
| 125 |
+
GET /query/axiom_usage/2
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
## Developer Guide
|
| 129 |
+
- All core logic is in `backend/core/`.
|
| 130 |
+
- Database models are in `backend/db/models.py`.
|
| 131 |
+
- API endpoints are in `backend/api/routes.py`.
|
| 132 |
+
- Cross-universe analysis logic is in `backend/core/cross_universe_analysis.py`.
|
| 133 |
+
- API endpoint for analysis is in `backend/api/analysis_routes.py`.
|
| 134 |
+
- Tests are in `backend/tests/`.
|
| 135 |
+
- Tests for analysis are in `backend/tests/test_analysis.py`.
|
| 136 |
+
- Environment variables are set in `.env`.
|
| 137 |
+
|
| 138 |
+
## Running & Testing
|
| 139 |
+
1. Install dependencies: `pip install -r requirements.txt`
|
| 140 |
+
2. Start server: `uvicorn backend.app:app --reload`
|
| 141 |
+
3. Run tests: `pytest backend/tests/`
|
| 142 |
+
|
| 143 |
+
## Deployment & Maintenance
|
| 144 |
+
|
| 145 |
+
### Docker
|
| 146 |
+
Build and run the backend in a container:
|
| 147 |
+
```sh
|
| 148 |
+
docker build -t projectv1-backend .
|
| 149 |
+
docker run -p 8000:8000 --env-file backend/.env projectv1-backend
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
### CI/CD
|
| 153 |
+
GitHub Actions workflow is set up in `.github/workflows/ci.yml` to run tests on every push and pull request.
|
| 154 |
+
|
| 155 |
+
### Maintenance
|
| 156 |
+
- Monitor logs and errors for performance issues.
|
| 157 |
+
- Regularly update dependencies and security patches.
|
| 158 |
+
- Scale with Docker and orchestration tools as needed.
|
| 159 |
+
|
| 160 |
+
## Contributing
|
| 161 |
+
- Follow code style and add tests for new features.
|
| 162 |
+
|
| 163 |
+
---
|
| 164 |
+
|
| 165 |
+
## Production Monitoring & Logging
|
| 166 |
+
|
| 167 |
+
- **Sentry Integration:**
|
| 168 |
+
- Sentry is integrated for error monitoring. To enable, set the `SENTRY_DSN` environment variable in your `.env` file.
|
| 169 |
+
- Install Sentry with `pip install sentry-sdk` (already included in requirements).
|
| 170 |
+
- Adjust `traces_sample_rate` in `backend/core/logging_config.py` for your needs.
|
| 171 |
+
|
| 172 |
+
- **Prometheus/Grafana:**
|
| 173 |
+
- For advanced metrics, consider adding [Prometheus FastAPI Instrumentator](https://github.com/trallard/fastapi_prometheus) and exporting metrics to Grafana.
|
| 174 |
+
- Example: `pip install prometheus-fastapi-instrumentator`
|
| 175 |
+
|
| 176 |
+
## Database Optimization
|
| 177 |
+
- All major foreign keys and frequently queried fields are indexed (see `backend/db/models.py`).
|
| 178 |
+
- For large-scale deployments, consider query profiling and further index tuning based on real-world usage.
|
| 179 |
+
|
| 180 |
+
## Security Best Practices
|
| 181 |
+
- API key authentication is required for all endpoints (see `backend/api/auth.py`).
|
| 182 |
+
- Store secrets in `.env` and never commit them to version control.
|
| 183 |
+
- Regularly update dependencies for security patches.
|
| 184 |
+
- Use HTTPS in production.
|
| 185 |
+
- Limit database and API access by IP/firewall as needed.
|
| 186 |
+
|
| 187 |
+
## Troubleshooting
|
| 188 |
+
- **Common Issues:**
|
| 189 |
+
- *Database connection errors*: Check your DB URL and credentials in `.env`.
|
| 190 |
+
- *Missing dependencies*: Run `pip install -r requirements.txt`.
|
| 191 |
+
- *Sentry not reporting*: Ensure `SENTRY_DSN` is set and `sentry-sdk` is installed.
|
| 192 |
+
- *API key errors*: Make sure your request includes the correct API key header.
|
| 193 |
+
- **Logs:**
|
| 194 |
+
- All errors and important events are logged. Check your server logs for details.
|
| 195 |
+
|
| 196 |
+
## External Resources
|
| 197 |
+
- [FastAPI Documentation](https://fastapi.tiangolo.com/)
|
| 198 |
+
- [SQLAlchemy Documentation](https://docs.sqlalchemy.org/)
|
| 199 |
+
- [Sentry for Python](https://docs.sentry.io/platforms/python/)
|
| 200 |
+
- [Prometheus FastAPI Instrumentator](https://github.com/trallard/fastapi_prometheus)
|
| 201 |
+
- [PyTorch](https://pytorch.org/)
|
| 202 |
+
- [SymPy](https://www.sympy.org/)
|
| 203 |
+
- [Docker](https://docs.docker.com/)
|
| 204 |
+
- [GitHub Actions](https://docs.github.com/en/actions)
|
backend/README_backend.md
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Backend quickstart (development)
|
| 2 |
+
|
| 3 |
+
This document contains quick instructions to get the backend running for local development and testing.
|
| 4 |
+
|
| 5 |
+
Prerequisites
|
| 6 |
+
- Python 3.10+ (recommended)
|
| 7 |
+
- A virtual environment (venv, conda, etc.)
|
| 8 |
+
|
| 9 |
+
Install dependencies (recommended in a venv):
|
| 10 |
+
|
| 11 |
+
```powershell
|
| 12 |
+
python -m venv .venv
|
| 13 |
+
.\.venv\Scripts\Activate.ps1
|
| 14 |
+
pip install -r requirements.txt
|
| 15 |
+
pip install -r requirements-dev.txt
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
Environment
|
| 19 |
+
- Copy `.env.example` to `.env` and edit values as needed. By default the code will use an in-memory SQLite DB.
|
| 20 |
+
|
| 21 |
+
Run the app (development):
|
| 22 |
+
|
| 23 |
+
```powershell
|
| 24 |
+
# From repository root
|
| 25 |
+
uvicorn backend.app:app --reload --host 127.0.0.1 --port 8000
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
Run tests:
|
| 29 |
+
|
| 30 |
+
```powershell
|
| 31 |
+
# activate venv first
|
| 32 |
+
pytest -q backend/tests
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
Notes
|
| 36 |
+
- The repository includes defensive fallbacks for some optional heavy dependencies; for full functionality you should install the optional packages listed in `requirements-dev.txt`.
|
| 37 |
+
- The DB defaults to `sqlite:///:memory:` when no `DB_URL` is set in `.env` for easy local testing.
|
| 38 |
+
|
| 39 |
+
## Example API Payloads
|
| 40 |
+
|
| 41 |
+
See `backend/api/example_payloads.md` for sample requests.
|
| 42 |
+
|
| 43 |
+
### Create Universe
|
| 44 |
+
POST /universes
|
| 45 |
+
```json
|
| 46 |
+
{
|
| 47 |
+
"name": "Group Theory",
|
| 48 |
+
"description": "Universe for group theory",
|
| 49 |
+
"universe_type": "group_theory",
|
| 50 |
+
"axioms": ["Closure", "Associativity", "Identity", "Inverse"]
|
| 51 |
+
}
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
### Add Axiom
|
| 55 |
+
POST /axioms
|
| 56 |
+
```json
|
| 57 |
+
{
|
| 58 |
+
"universe_id": 1,
|
| 59 |
+
"statement": "Commutativity"
|
| 60 |
+
}
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
### Derive Theorem
|
| 64 |
+
POST /theorems/derive
|
| 65 |
+
```json
|
| 66 |
+
{
|
| 67 |
+
"universe_id": 1,
|
| 68 |
+
"axiom_ids": [1, 2],
|
| 69 |
+
"statement": "Closure Commutativity"
|
| 70 |
+
}
|
| 71 |
+
```
|
backend/README_demo.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
AlphaGeometry demo backend
|
| 2 |
+
|
| 3 |
+
Quick start (uses pure-Python fallbacks, no Docker required):
|
| 4 |
+
|
| 5 |
+
1. Create a virtualenv and install minimal deps:
|
| 6 |
+
|
| 7 |
+
```powershell
|
| 8 |
+
python -m venv .venv
|
| 9 |
+
.\.venv\Scripts\Activate.ps1
|
| 10 |
+
pip install -r requirements-merged.txt
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
2. Run the demo app:
|
| 14 |
+
|
| 15 |
+
```powershell
|
| 16 |
+
python -m backend.run_demo
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
Notes:
|
| 20 |
+
- The repo contains a top-level folder named `fastapi/` which may shadow the installed
|
| 21 |
+
`fastapi` package. If you see errors when starting the app, run inside a clean virtualenv
|
| 22 |
+
where `fastapi` is installed, or rename the repo-local `fastapi/` folder.
|
| 23 |
+
- Neo4j and FAISS are optional; the demo uses `networkx` and an in-memory vector index.
|
| 24 |
+
- To wire real Neo4j, install Docker and the `neo4j` / `py2neo` python packages and configure
|
| 25 |
+
`backend/adapters/graph_adapter.py` with the connection URI.
|
backend/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Backend package for demo API and adapters."""
|
| 2 |
+
|
| 3 |
+
__all__ = ["api", "universe", "adapters", "prover_adapter"]
|
| 4 |
+
# Backend package initializer
|
| 5 |
+
|
| 6 |
+
# This file makes `backend` a Python package so tests can import it.
|
backend/adapters/graph_adapter.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Graph adapter: NetworkX fallback and a full Neo4j adapter (lazy imports).
|
| 2 |
+
|
| 3 |
+
This file provides a production-ready adapter implementation that will use the
|
| 4 |
+
`neo4j` python driver when available and fall back to an in-memory NetworkX
|
| 5 |
+
graph otherwise.
|
| 6 |
+
"""
|
| 7 |
+
from typing import Any, Dict, List, Optional
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
import networkx as nx
|
| 14 |
+
except Exception:
|
| 15 |
+
nx = None
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class NetworkXGraph:
|
| 19 |
+
def __init__(self):
|
| 20 |
+
if nx is None:
|
| 21 |
+
raise RuntimeError("networkx is not available")
|
| 22 |
+
self.g = nx.MultiDiGraph()
|
| 23 |
+
|
| 24 |
+
def add_node(self, node_id: str, **props: Any):
|
| 25 |
+
self.g.add_node(node_id, **props)
|
| 26 |
+
|
| 27 |
+
def add_edge(self, a: str, b: str, **props: Any):
|
| 28 |
+
self.g.add_edge(a, b, **props)
|
| 29 |
+
|
| 30 |
+
def find_nodes(self, key: str, value: str) -> List[str]:
|
| 31 |
+
return [n for n, d in self.g.nodes(data=True) if d.get(key) == value]
|
| 32 |
+
|
| 33 |
+
def run_cypher(self, query: str, **params: Any):
|
| 34 |
+
# Not applicable for NetworkX; provide simple pattern matcher if needed
|
| 35 |
+
raise NotImplementedError("Cypher not supported for NetworkX fallback")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Neo4jAdapter:
|
| 39 |
+
def __init__(self, uri: Optional[str] = None, user: Optional[str] = None, password: Optional[str] = None):
|
| 40 |
+
self._driver = None
|
| 41 |
+
self._connected = False
|
| 42 |
+
self._uri = uri or "bolt://localhost:7687"
|
| 43 |
+
self._user = user or "neo4j"
|
| 44 |
+
self._password = password or "testpassword"
|
| 45 |
+
try:
|
| 46 |
+
# lazy import to avoid importing heavy driver at module import time
|
| 47 |
+
from neo4j import GraphDatabase
|
| 48 |
+
self._driver = GraphDatabase.driver(self._uri, auth=(self._user, self._password))
|
| 49 |
+
self._connected = True
|
| 50 |
+
except Exception as e:
|
| 51 |
+
logger.info("Neo4j driver not available or connection failed: %s", e)
|
| 52 |
+
self._driver = None
|
| 53 |
+
|
| 54 |
+
def is_available(self) -> bool:
|
| 55 |
+
return self._driver is not None
|
| 56 |
+
|
| 57 |
+
def close(self):
|
| 58 |
+
if self._driver:
|
| 59 |
+
try:
|
| 60 |
+
self._driver.close()
|
| 61 |
+
except Exception:
|
| 62 |
+
pass
|
| 63 |
+
|
| 64 |
+
def run(self, cypher: str, **params: Any) -> List[Dict[str, Any]]:
|
| 65 |
+
if not self._driver:
|
| 66 |
+
raise RuntimeError("Neo4j driver not available")
|
| 67 |
+
with self._driver.session() as session:
|
| 68 |
+
res = session.run(cypher, **params)
|
| 69 |
+
return [dict(record) for record in res]
|
| 70 |
+
|
| 71 |
+
def create_node(self, labels: List[str], props: Dict[str, Any]) -> Dict[str, Any]:
|
| 72 |
+
lbl = ":".join(labels) if labels else ""
|
| 73 |
+
cypher = f"CREATE (n:{lbl} $props) RETURN id(n) as id"
|
| 74 |
+
rows = self.run(cypher, props=props)
|
| 75 |
+
return rows[0] if rows else {}
|
| 76 |
+
|
| 77 |
+
def create_relationship(self, a_id: int, b_id: int, rel_type: str, props: Dict[str, Any] = None) -> Dict[str, Any]:
|
| 78 |
+
props = props or {}
|
| 79 |
+
cypher = "MATCH (a),(b) WHERE id(a)=$aid AND id(b)=$bid CREATE (a)-[r:%s $props]->(b) RETURN id(r) as id" % rel_type
|
| 80 |
+
rows = self.run(cypher, aid=a_id, bid=b_id, props=props)
|
| 81 |
+
return rows[0] if rows else {}
|
backend/adapters/vector_adapter.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility wrapper for vector adapters.
|
| 2 |
+
|
| 3 |
+
Provides `InMemoryVectorIndex` to keep older imports working and attempts to
|
| 4 |
+
use FAISS adapter if present.
|
| 5 |
+
"""
|
| 6 |
+
from typing import List, Tuple
|
| 7 |
+
try:
|
| 8 |
+
# try to import full adapter
|
| 9 |
+
from .vector_adapter_full import FaissIndex, HostedVectorAdapter
|
| 10 |
+
FAISS_AVAILABLE = True
|
| 11 |
+
except Exception:
|
| 12 |
+
FAISS_AVAILABLE = False
|
| 13 |
+
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class InMemoryVectorIndex:
|
| 18 |
+
def __init__(self):
|
| 19 |
+
self.data: List[Tuple[str, List[float]]] = []
|
| 20 |
+
|
| 21 |
+
def upsert(self, id: str, vector: List[float]):
|
| 22 |
+
self.data.append((id, vector))
|
| 23 |
+
|
| 24 |
+
def search(self, vector: List[float], top_k: int = 10):
|
| 25 |
+
def score(a, b):
|
| 26 |
+
dot = sum(x * y for x, y in zip(a, b))
|
| 27 |
+
na = math.sqrt(sum(x * x for x in a))
|
| 28 |
+
nb = math.sqrt(sum(x * x for x in b))
|
| 29 |
+
return dot / (na * nb) if na and nb else 0.0
|
| 30 |
+
|
| 31 |
+
scored = [(id, score(vec, vector)) for id, vec in self.data]
|
| 32 |
+
scored.sort(key=lambda x: x[1], reverse=True)
|
| 33 |
+
return scored[:top_k]
|
backend/adapters/vector_adapter_full.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Vector adapters: FAISS-backed index and hosted HTTP adapter.
|
| 2 |
+
|
| 3 |
+
These adapters attempt to use faiss if available; otherwise they expose the
|
| 4 |
+
interfaces and raise clear errors when not installed.
|
| 5 |
+
"""
|
| 6 |
+
from typing import List, Tuple, Optional, Dict
|
| 7 |
+
import logging
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
import numpy as np
|
| 12 |
+
except Exception:
|
| 13 |
+
np = None
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import faiss
|
| 17 |
+
except Exception:
|
| 18 |
+
faiss = None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class FaissIndex:
|
| 22 |
+
def __init__(self, dim: int):
|
| 23 |
+
if faiss is None or np is None:
|
| 24 |
+
raise RuntimeError("faiss or numpy is not installed")
|
| 25 |
+
self.dim = dim
|
| 26 |
+
self.index = faiss.IndexFlatIP(dim)
|
| 27 |
+
self.ids: List[str] = []
|
| 28 |
+
|
| 29 |
+
def upsert(self, id: str, vector: List[float]):
|
| 30 |
+
v = np.array([vector], dtype='float32')
|
| 31 |
+
self.index.add(v)
|
| 32 |
+
self.ids.append(id)
|
| 33 |
+
|
| 34 |
+
def search(self, vector: List[float], top_k: int = 10) -> List[Tuple[str, float]]:
|
| 35 |
+
v = np.array([vector], dtype='float32')
|
| 36 |
+
D, I = self.index.search(v, top_k)
|
| 37 |
+
results = []
|
| 38 |
+
for score, idx in zip(D[0], I[0]):
|
| 39 |
+
if idx < 0:
|
| 40 |
+
continue
|
| 41 |
+
results.append((self.ids[idx], float(score)))
|
| 42 |
+
return results
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class HostedVectorAdapter:
|
| 46 |
+
def __init__(self, endpoint: str):
|
| 47 |
+
self.endpoint = endpoint
|
| 48 |
+
|
| 49 |
+
def upsert(self, id: str, vector: List[float]):
|
| 50 |
+
# placeholder: send HTTP request to hosted service
|
| 51 |
+
logger.info("Would upsert to hosted vector DB at %s", self.endpoint)
|
| 52 |
+
|
| 53 |
+
def search(self, vector: List[float], top_k: int = 10):
|
| 54 |
+
logger.info("Would query hosted vector DB at %s", self.endpoint)
|
| 55 |
+
return []
|
backend/api/analysis_routes.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 2 |
+
from sqlalchemy.orm import Session
|
| 3 |
+
from backend.db.session import SessionLocal
|
| 4 |
+
from backend.core.cross_universe_analysis import CrossUniverseAnalyzer
|
| 5 |
+
from backend.api.auth import get_api_key
|
| 6 |
+
from backend.core.logging_config import get_logger
|
| 7 |
+
|
| 8 |
+
router = APIRouter()
|
| 9 |
+
logger = get_logger("analysis_routes")
|
| 10 |
+
|
| 11 |
+
def get_db():
|
| 12 |
+
db = SessionLocal()
|
| 13 |
+
try:
|
| 14 |
+
yield db
|
| 15 |
+
finally:
|
| 16 |
+
db.close()
|
| 17 |
+
|
| 18 |
+
@router.post("/analysis/cross_universe")
|
| 19 |
+
def cross_universe_analysis(universe_ids: list[int], db: Session = Depends(get_db), api_key: str = Depends(get_api_key)):
|
| 20 |
+
try:
|
| 21 |
+
analyzer = CrossUniverseAnalyzer(db)
|
| 22 |
+
result = analyzer.analyze(universe_ids)
|
| 23 |
+
logger.info(f"Cross-universe analysis: universes={universe_ids}, result={result}")
|
| 24 |
+
return result
|
| 25 |
+
except Exception as e:
|
| 26 |
+
logger.error(f"Analysis error: {e}")
|
| 27 |
+
raise HTTPException(status_code=500, detail=str(e))
|
backend/api/auth.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Depends, HTTPException, status
|
| 2 |
+
from fastapi.security import APIKeyHeader
|
| 3 |
+
|
| 4 |
+
API_KEY = "your_api_key_here" # Replace with a secure key or load from env
|
| 5 |
+
api_key_header = APIKeyHeader(name="X-API-Key")
|
| 6 |
+
|
| 7 |
+
def get_api_key(api_key: str = Depends(api_key_header)):
|
| 8 |
+
if api_key != API_KEY:
|
| 9 |
+
raise HTTPException(
|
| 10 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 11 |
+
detail="Invalid or missing API Key",
|
| 12 |
+
)
|
| 13 |
+
return api_key
|
backend/api/crud.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from backend.db.models import Universe, Axiom, Theorem, Proof
|
| 2 |
+
from backend.db.session import SessionLocal
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
def create_universe(db, name, description, universe_type):
|
| 6 |
+
universe = Universe(name=name, description=description, universe_type=universe_type, version=1, created_at=str(datetime.utcnow()))
|
| 7 |
+
db.add(universe)
|
| 8 |
+
db.commit()
|
| 9 |
+
db.refresh(universe)
|
| 10 |
+
return universe
|
| 11 |
+
|
| 12 |
+
def update_universe(db, universe_id, **kwargs):
|
| 13 |
+
universe = db.query(Universe).filter(Universe.id == universe_id).first()
|
| 14 |
+
for key, value in kwargs.items():
|
| 15 |
+
setattr(universe, key, value)
|
| 16 |
+
universe.version += 1
|
| 17 |
+
db.commit()
|
| 18 |
+
db.refresh(universe)
|
| 19 |
+
return universe
|
| 20 |
+
|
| 21 |
+
def delete_universe(db, universe_id):
|
| 22 |
+
universe = db.query(Universe).filter(Universe.id == universe_id).first()
|
| 23 |
+
db.delete(universe)
|
| 24 |
+
db.commit()
|
| 25 |
+
|
| 26 |
+
def create_axiom(db, universe_id, statement, parent_axiom_id=None):
|
| 27 |
+
axiom = Axiom(universe_id=universe_id, statement=statement, parent_axiom_id=parent_axiom_id, version=1, created_at=str(datetime.utcnow()))
|
| 28 |
+
db.add(axiom)
|
| 29 |
+
db.commit()
|
| 30 |
+
db.refresh(axiom)
|
| 31 |
+
return axiom
|
| 32 |
+
|
| 33 |
+
def update_axiom(db, axiom_id, **kwargs):
|
| 34 |
+
axiom = db.query(Axiom).filter(Axiom.id == axiom_id).first()
|
| 35 |
+
for key, value in kwargs.items():
|
| 36 |
+
setattr(axiom, key, value)
|
| 37 |
+
axiom.version += 1
|
| 38 |
+
db.commit()
|
| 39 |
+
db.refresh(axiom)
|
| 40 |
+
return axiom
|
| 41 |
+
|
| 42 |
+
def delete_axiom(db, axiom_id):
|
| 43 |
+
axiom = db.query(Axiom).filter(Axiom.id == axiom_id).first()
|
| 44 |
+
db.delete(axiom)
|
| 45 |
+
db.commit()
|
| 46 |
+
|
| 47 |
+
def create_theorem(db, universe_id, statement, proof):
|
| 48 |
+
theorem = Theorem(universe_id=universe_id, statement=statement, proof=proof)
|
| 49 |
+
db.add(theorem)
|
| 50 |
+
db.commit()
|
| 51 |
+
db.refresh(theorem)
|
| 52 |
+
return theorem
|
| 53 |
+
|
| 54 |
+
def update_theorem(db, theorem_id, **kwargs):
|
| 55 |
+
theorem = db.query(Theorem).filter(Theorem.id == theorem_id).first()
|
| 56 |
+
for key, value in kwargs.items():
|
| 57 |
+
setattr(theorem, key, value)
|
| 58 |
+
db.commit()
|
| 59 |
+
db.refresh(theorem)
|
| 60 |
+
return theorem
|
| 61 |
+
|
| 62 |
+
def delete_theorem(db, theorem_id):
|
| 63 |
+
theorem = db.query(Theorem).filter(Theorem.id == theorem_id).first()
|
| 64 |
+
db.delete(theorem)
|
| 65 |
+
db.commit()
|
| 66 |
+
|
| 67 |
+
def create_proof(db, axiom_id, content):
|
| 68 |
+
proof = Proof(axiom_id=axiom_id, content=content)
|
| 69 |
+
db.add(proof)
|
| 70 |
+
db.commit()
|
| 71 |
+
db.refresh(proof)
|
| 72 |
+
return proof
|
| 73 |
+
|
| 74 |
+
def update_proof(db, proof_id, **kwargs):
|
| 75 |
+
proof = db.query(Proof).filter(Proof.id == proof_id).first()
|
| 76 |
+
for key, value in kwargs.items():
|
| 77 |
+
setattr(proof, key, value)
|
| 78 |
+
db.commit()
|
| 79 |
+
db.refresh(proof)
|
| 80 |
+
return proof
|
| 81 |
+
|
| 82 |
+
def delete_proof(db, proof_id):
|
| 83 |
+
proof = db.query(Proof).filter(Proof.id == proof_id).first()
|
| 84 |
+
db.delete(proof)
|
| 85 |
+
db.commit()
|
backend/api/example_payloads.md
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Example API Payloads
|
| 2 |
+
|
| 3 |
+
### Create Universe
|
| 4 |
+
POST /universes
|
| 5 |
+
```json
|
| 6 |
+
{
|
| 7 |
+
"name": "Group Theory",
|
| 8 |
+
"description": "Universe for group theory",
|
| 9 |
+
"universe_type": "group_theory",
|
| 10 |
+
"axioms": ["Closure", "Associativity", "Identity", "Inverse"]
|
| 11 |
+
}
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
### Add Axiom
|
| 15 |
+
POST /axioms
|
| 16 |
+
```json
|
| 17 |
+
{
|
| 18 |
+
"universe_id": 1,
|
| 19 |
+
"statement": "Commutativity"
|
| 20 |
+
}
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
### Evolve Axiom
|
| 24 |
+
POST /axioms/evolve
|
| 25 |
+
Form data or JSON:
|
| 26 |
+
```json
|
| 27 |
+
{
|
| 28 |
+
"axiom_id": 2,
|
| 29 |
+
"new_statement": "Commutativity (strong)"
|
| 30 |
+
}
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
### Derive Theorem
|
| 34 |
+
POST /theorems/derive
|
| 35 |
+
```json
|
| 36 |
+
{
|
| 37 |
+
"universe_id": 1,
|
| 38 |
+
"axiom_ids": [1, 2],
|
| 39 |
+
"statement": "Closure Commutativity"
|
| 40 |
+
}
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
### Create Proof
|
| 44 |
+
POST /proofs
|
| 45 |
+
```json
|
| 46 |
+
{
|
| 47 |
+
"axiom_id": 1,
|
| 48 |
+
"content": "Proof details here."
|
| 49 |
+
}
|
| 50 |
+
```
|
backend/api/neuro_symbolic_routes.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends
|
| 2 |
+
from sqlalchemy.orm import Session
|
| 3 |
+
from backend.db.session import SessionLocal
|
| 4 |
+
from backend.core.neuro_symbolic import NeuroSymbolicNetwork
|
| 5 |
+
|
| 6 |
+
router = APIRouter()
|
| 7 |
+
|
| 8 |
+
def get_db():
|
| 9 |
+
db = SessionLocal()
|
| 10 |
+
try:
|
| 11 |
+
yield db
|
| 12 |
+
finally:
|
| 13 |
+
db.close()
|
| 14 |
+
|
| 15 |
+
@router.post("/neuro/train")
|
| 16 |
+
def train_neuro(training_data: list[list[float]], labels: list[int], epochs: int = 10, db: Session = Depends(get_db)):
|
| 17 |
+
nsn = NeuroSymbolicNetwork(db)
|
| 18 |
+
loss = nsn.train(training_data, labels, epochs)
|
| 19 |
+
return {"final_loss": loss}
|
| 20 |
+
|
| 21 |
+
@router.post("/neuro/predict")
|
| 22 |
+
def predict_neuro(input_data: list[list[float]], db: Session = Depends(get_db)):
|
| 23 |
+
nsn = NeuroSymbolicNetwork(db)
|
| 24 |
+
predictions = nsn.predict(input_data)
|
| 25 |
+
return {"predictions": predictions}
|
| 26 |
+
|
| 27 |
+
@router.post("/neuro/guide")
|
| 28 |
+
def guide_proof_search(universe_id: int, axiom_ids: list[int], db: Session = Depends(get_db)):
|
| 29 |
+
nsn = NeuroSymbolicNetwork(db)
|
| 30 |
+
suggestion = nsn.guide_proof_search(universe_id, axiom_ids)
|
| 31 |
+
return suggestion
|
backend/api/quantum_routes.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 2 |
+
from backend.core.quantum_search import GroverSearch
|
| 3 |
+
from backend.api.auth import get_api_key
|
| 4 |
+
from backend.core.logging_config import get_logger
|
| 5 |
+
|
| 6 |
+
router = APIRouter()
|
| 7 |
+
logger = get_logger("quantum_routes")
|
| 8 |
+
|
| 9 |
+
@router.post("/quantum/grover")
|
| 10 |
+
def run_grover(database_size: int, target_idx: int, iterations: int = None, api_key: str = Depends(get_api_key)):
|
| 11 |
+
try:
|
| 12 |
+
search = GroverSearch(database_size)
|
| 13 |
+
result_idx = search.run(target_idx, iterations)
|
| 14 |
+
logger.info(f"Grover search: db_size={database_size}, target={target_idx}, result={result_idx}")
|
| 15 |
+
return {"found_index": result_idx}
|
| 16 |
+
except Exception as e:
|
| 17 |
+
logger.error(f"Grover search error: {e}")
|
| 18 |
+
raise HTTPException(status_code=500, detail=str(e))
|
backend/api/query_routes.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 2 |
+
from sqlalchemy.orm import Session
|
| 3 |
+
from backend.db.session import SessionLocal
|
| 4 |
+
from backend.db.models import Universe, Axiom, Theorem
|
| 5 |
+
from backend.api.auth import get_api_key
|
| 6 |
+
from backend.core.logging_config import get_logger
|
| 7 |
+
|
| 8 |
+
router = APIRouter()
|
| 9 |
+
logger = get_logger("query_routes")
|
| 10 |
+
|
| 11 |
+
def get_db():
|
| 12 |
+
db = SessionLocal()
|
| 13 |
+
try:
|
| 14 |
+
yield db
|
| 15 |
+
finally:
|
| 16 |
+
db.close()
|
| 17 |
+
|
| 18 |
+
@router.get("/query/universe_summary/{universe_id}")
|
| 19 |
+
def get_universe_summary(universe_id: int, db: Session = Depends(get_db), api_key: str = Depends(get_api_key)):
|
| 20 |
+
try:
|
| 21 |
+
universe = db.query(Universe).filter(Universe.id == universe_id).first()
|
| 22 |
+
axioms = db.query(Axiom).filter(Axiom.universe_id == universe_id).all()
|
| 23 |
+
theorems = db.query(Theorem).filter(Theorem.universe_id == universe_id).all()
|
| 24 |
+
logger.info(f"Universe summary for {universe_id} generated.")
|
| 25 |
+
return {
|
| 26 |
+
"universe": {"id": universe.id, "name": universe.name, "type": universe.universe_type},
|
| 27 |
+
"axioms": [ax.statement for ax in axioms],
|
| 28 |
+
"theorems": [th.statement for th in theorems],
|
| 29 |
+
"axiom_count": len(axioms),
|
| 30 |
+
"theorem_count": len(theorems)
|
| 31 |
+
}
|
| 32 |
+
except Exception as e:
|
| 33 |
+
logger.error(f"Query error: {e}")
|
| 34 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 35 |
+
|
| 36 |
+
@router.get("/query/axiom_usage/{axiom_id}")
|
| 37 |
+
def get_axiom_usage(axiom_id: int, db: Session = Depends(get_db), api_key: str = Depends(get_api_key)):
|
| 38 |
+
try:
|
| 39 |
+
axiom = db.query(Axiom).filter(Axiom.id == axiom_id).first()
|
| 40 |
+
theorems = db.query(Theorem).filter(Theorem.universe_id == axiom.universe_id).all()
|
| 41 |
+
used_in = [th.statement for th in theorems if axiom.statement in th.proof]
|
| 42 |
+
logger.info(f"Axiom usage for {axiom_id} generated.")
|
| 43 |
+
return {
|
| 44 |
+
"axiom": axiom.statement,
|
| 45 |
+
"used_in_theorems": used_in,
|
| 46 |
+
"usage_count": len(used_in)
|
| 47 |
+
}
|
| 48 |
+
except Exception as e:
|
| 49 |
+
logger.error(f"Query error: {e}")
|
| 50 |
+
raise HTTPException(status_code=500, detail=str(e))
|
backend/api/routes.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
"""
|
| 3 |
+
API route definitions for the backend.
|
| 4 |
+
Includes endpoints for universes, axioms, theorems, proofs, and analysis.
|
| 5 |
+
All endpoints use Pydantic schemas for request/response validation.
|
| 6 |
+
"""
|
| 7 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 8 |
+
from sqlalchemy.orm import Session
|
| 9 |
+
from backend.db.session import get_db
|
| 10 |
+
from backend.db.models import Universe, Axiom, Proof, AnalysisResult
|
| 11 |
+
from backend.core.universe_generator import UniverseGenerator
|
| 12 |
+
from backend.core.theorem_engine import TheoremEngine
|
| 13 |
+
from backend.api.schemas import UniverseCreate, AxiomCreate, ProofCreate, TheoremCreate, TheoremOut, UniverseOut, AxiomOut, ProofOut, AnalysisResultOut
|
| 14 |
+
from typing import List
|
| 15 |
+
|
| 16 |
+
router = APIRouter()
|
| 17 |
+
|
| 18 |
+
@router.post("/theorems/derive", response_model=TheoremOut, summary="Derive a theorem from axioms")
|
| 19 |
+
def api_derive_theorem(payload: TheoremCreate, db: Session = Depends(get_db)) -> TheoremOut:
|
| 20 |
+
"""Derive a theorem from a set of axioms in a universe."""
|
| 21 |
+
engine = TheoremEngine(db)
|
| 22 |
+
try:
|
| 23 |
+
theorem = engine.derive_theorem(payload.universe_id, payload.axiom_ids, payload.statement)
|
| 24 |
+
return theorem
|
| 25 |
+
except ValueError as e:
|
| 26 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 27 |
+
|
| 28 |
+
@router.get("/universes", response_model=List[UniverseOut], summary="List all universes")
|
| 29 |
+
def list_universes(db: Session = Depends(get_db)) -> List[UniverseOut]:
|
| 30 |
+
"""List all universes in the database."""
|
| 31 |
+
return db.query(Universe).all()
|
| 32 |
+
|
| 33 |
+
@router.post("/universes", response_model=UniverseOut, summary="Create a new universe")
|
| 34 |
+
def api_create_universe(payload: UniverseCreate, db: Session = Depends(get_db)) -> UniverseOut:
|
| 35 |
+
"""Create a new universe with optional axioms and type."""
|
| 36 |
+
generator = UniverseGenerator(db)
|
| 37 |
+
universe = generator.create_universe(payload.name, payload.description, payload.universe_type, payload.axioms)
|
| 38 |
+
return universe
|
| 39 |
+
|
| 40 |
+
@router.get("/universes/{universe_id}/history", summary="Get universe history and axiom lineage")
|
| 41 |
+
def get_universe_history(universe_id: int, db: Session = Depends(get_db)):
|
| 42 |
+
"""Get the history and axiom lineage for a universe."""
|
| 43 |
+
universe = db.query(Universe).filter(Universe.id == universe_id).first()
|
| 44 |
+
if not universe:
|
| 45 |
+
raise HTTPException(status_code=404, detail="Universe not found")
|
| 46 |
+
axioms = db.query(Axiom).filter(Axiom.universe_id == universe_id).all()
|
| 47 |
+
return {
|
| 48 |
+
"universe": universe,
|
| 49 |
+
"axioms": axioms
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
@router.get("/axioms/{universe_id}", response_model=List[AxiomOut], summary="List axioms for a universe")
|
| 53 |
+
def list_axioms(universe_id: int, db: Session = Depends(get_db)) -> List[AxiomOut]:
|
| 54 |
+
"""List all axioms for a given universe."""
|
| 55 |
+
axioms = db.query(Axiom).filter(Axiom.universe_id == universe_id).all()
|
| 56 |
+
return axioms
|
| 57 |
+
|
| 58 |
+
@router.post("/axioms", response_model=AxiomOut, summary="Add a new axiom")
|
| 59 |
+
def api_create_axiom(payload: AxiomCreate, db: Session = Depends(get_db)) -> AxiomOut:
|
| 60 |
+
"""Add a new axiom to a universe."""
|
| 61 |
+
generator = UniverseGenerator(db)
|
| 62 |
+
try:
|
| 63 |
+
axiom = generator.add_axiom(payload.universe_id, payload.statement)
|
| 64 |
+
return axiom
|
| 65 |
+
except ValueError as e:
|
| 66 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 67 |
+
|
| 68 |
+
@router.post("/axioms/evolve", response_model=AxiomOut, summary="Evolve an axiom")
|
| 69 |
+
def api_evolve_axiom(axiom_id: int, new_statement: str, db: Session = Depends(get_db)) -> AxiomOut:
|
| 70 |
+
"""Evolve an axiom to a new statement."""
|
| 71 |
+
generator = UniverseGenerator(db)
|
| 72 |
+
try:
|
| 73 |
+
new_axiom = generator.evolve_axiom(axiom_id, new_statement)
|
| 74 |
+
return new_axiom
|
| 75 |
+
except ValueError as e:
|
| 76 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 77 |
+
|
| 78 |
+
@router.get("/theorems/{universe_id}", response_model=List[TheoremOut], summary="List theorems for a universe")
|
| 79 |
+
def list_theorems(universe_id: int, db: Session = Depends(get_db)) -> List[TheoremOut]:
|
| 80 |
+
"""List all theorems for a given universe."""
|
| 81 |
+
engine = TheoremEngine(db)
|
| 82 |
+
return engine.list_theorems(universe_id)
|
| 83 |
+
|
| 84 |
+
@router.post("/proofs", response_model=ProofOut, summary="Create a proof for an axiom")
|
| 85 |
+
def create_proof(payload: ProofCreate, db: Session = Depends(get_db)) -> ProofOut:
|
| 86 |
+
"""Create a proof for an axiom."""
|
| 87 |
+
proof = Proof(axiom_id=payload.axiom_id, content=payload.content)
|
| 88 |
+
db.add(proof)
|
| 89 |
+
db.commit()
|
| 90 |
+
db.refresh(proof)
|
| 91 |
+
return proof
|
| 92 |
+
|
| 93 |
+
@router.get("/analysis/{universe_id}", response_model=List[AnalysisResultOut], summary="Get analysis results for a universe")
|
| 94 |
+
def get_analysis(universe_id: int, db: Session = Depends(get_db)) -> List[AnalysisResultOut]:
|
| 95 |
+
"""Get analysis results for a universe."""
|
| 96 |
+
return db.query(AnalysisResult).filter(AnalysisResult.universe_id == universe_id).all()
|
backend/api/schemas.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class AxiomCreate(BaseModel):
|
| 6 |
+
"""Schema for creating a new axiom."""
|
| 7 |
+
universe_id: int
|
| 8 |
+
statement: str = Field(..., min_length=3, description="Axiom statement (min 3 chars)")
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class AxiomOut(BaseModel):
|
| 12 |
+
"""Schema for axiom output."""
|
| 13 |
+
id: int
|
| 14 |
+
universe_id: int
|
| 15 |
+
statement: str
|
| 16 |
+
is_active: int
|
| 17 |
+
parent_axiom_id: Optional[int]
|
| 18 |
+
version: int
|
| 19 |
+
created_at: Optional[str]
|
| 20 |
+
updated_at: Optional[str]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class UniverseCreate(BaseModel):
|
| 24 |
+
"""Schema for creating a new universe."""
|
| 25 |
+
name: str = Field(..., min_length=1, description="Universe name")
|
| 26 |
+
description: Optional[str] = ""
|
| 27 |
+
universe_type: Optional[str] = "generic"
|
| 28 |
+
axioms: Optional[List[str]] = []
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class UniverseOut(BaseModel):
|
| 32 |
+
"""Schema for universe output."""
|
| 33 |
+
id: int
|
| 34 |
+
name: str
|
| 35 |
+
description: Optional[str]
|
| 36 |
+
universe_type: str
|
| 37 |
+
version: int
|
| 38 |
+
created_at: Optional[str]
|
| 39 |
+
updated_at: Optional[str]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class TheoremCreate(BaseModel):
|
| 43 |
+
"""Schema for creating a theorem."""
|
| 44 |
+
universe_id: int
|
| 45 |
+
axiom_ids: List[int]
|
| 46 |
+
statement: str = Field(..., min_length=3, description="Theorem statement")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class TheoremOut(BaseModel):
|
| 50 |
+
"""Schema for theorem output."""
|
| 51 |
+
id: int
|
| 52 |
+
universe_id: int
|
| 53 |
+
statement: str
|
| 54 |
+
proof: Optional[str]
|
| 55 |
+
created_at: Optional[str]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class ProofCreate(BaseModel):
|
| 59 |
+
"""Schema for creating a proof."""
|
| 60 |
+
axiom_id: int
|
| 61 |
+
content: str = Field(..., min_length=1, description="Proof content")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class ProofOut(BaseModel):
|
| 65 |
+
"""Schema for proof output."""
|
| 66 |
+
id: int
|
| 67 |
+
axiom_id: int
|
| 68 |
+
content: str
|
| 69 |
+
created_at: Optional[str]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class AnalysisRequest(BaseModel):
|
| 73 |
+
"""Schema for requesting analysis on universes."""
|
| 74 |
+
universe_ids: List[int]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class AnalysisResultOut(BaseModel):
|
| 78 |
+
"""Schema for analysis result output."""
|
| 79 |
+
id: int
|
| 80 |
+
universe_id: int
|
| 81 |
+
result: str
|
| 82 |
+
created_at: Optional[str]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# --- Vector store schemas ---
|
| 86 |
+
class VectorAddRequest(BaseModel):
|
| 87 |
+
id: str
|
| 88 |
+
text: str
|
| 89 |
+
metadata: Optional[dict] = {}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class VectorQueryRequest(BaseModel):
|
| 93 |
+
text: str
|
| 94 |
+
k: Optional[int] = 5
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class VectorResultItem(BaseModel):
|
| 98 |
+
id: str
|
| 99 |
+
distance: float
|
| 100 |
+
metadata: Optional[dict]
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class VectorQueryResponse(BaseModel):
|
| 104 |
+
results: List[VectorResultItem]
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# --- vector store related schemas (small convenience types) ---
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class VectorAddRequest(BaseModel):
|
| 111 |
+
ids: List[str]
|
| 112 |
+
vectors: List[List[float]]
|
| 113 |
+
metas: Optional[List[dict]] = None
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class VectorSearchRequest(BaseModel):
|
| 117 |
+
query: List[float]
|
| 118 |
+
top_k: int = 5
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class VectorSearchResult(BaseModel):
|
| 122 |
+
id: str
|
| 123 |
+
score: float
|
| 124 |
+
meta: Optional[dict]
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# Vector store schemas
|
| 128 |
+
class VectorUpsert(BaseModel):
|
| 129 |
+
id: str
|
| 130 |
+
vector: List[float]
|
| 131 |
+
metadata: Optional[dict] = None
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class VectorQuery(BaseModel):
|
| 135 |
+
vector: List[float]
|
| 136 |
+
k: Optional[int] = 5
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class VectorOut(BaseModel):
|
| 140 |
+
id: str
|
| 141 |
+
score: Optional[float] = None
|
| 142 |
+
metadata: Optional[dict] = None
|
| 143 |
+
vector: Optional[List[float]] = None
|
backend/api/vector_routes.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException, Depends
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
from backend.core.vector_store import get_global_vector_store
|
| 5 |
+
|
| 6 |
+
router = APIRouter()
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AddTextPayload(BaseModel):
|
| 10 |
+
id: str
|
| 11 |
+
text: str
|
| 12 |
+
metadata: Optional[dict] = None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class QueryPayload(BaseModel):
|
| 16 |
+
text: str
|
| 17 |
+
k: Optional[int] = 5
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@router.post("/vectors/add", summary="Add text as vector")
|
| 21 |
+
def add_text(payload: AddTextPayload):
|
| 22 |
+
store = get_global_vector_store()
|
| 23 |
+
try:
|
| 24 |
+
store.add_text(payload.id, payload.text, payload.metadata)
|
| 25 |
+
return {"status": "ok", "id": payload.id}
|
| 26 |
+
except Exception as e:
|
| 27 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@router.post("/vectors/query", summary="Query nearest vectors by text")
|
| 31 |
+
def query_text(payload: QueryPayload):
|
| 32 |
+
store = get_global_vector_store()
|
| 33 |
+
try:
|
| 34 |
+
results = store.query_text(payload.text, k=payload.k or 5)
|
| 35 |
+
# convert numpy arrays to lists for JSON
|
| 36 |
+
out = [{"id": r[0], "distance": r[1], "metadata": r[2]} for r in results]
|
| 37 |
+
return {"results": out}
|
| 38 |
+
except Exception as e:
|
| 39 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 40 |
+
"""API routes for vector store operations (add/search)."""
|
| 41 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 42 |
+
from typing import List, Optional
|
| 43 |
+
from pydantic import BaseModel, Field
|
| 44 |
+
import numpy as np
|
| 45 |
+
|
| 46 |
+
from backend.core.vector_store import get_default_store, VectorStore
|
| 47 |
+
|
| 48 |
+
router = APIRouter(prefix="/vector", tags=["vector-store"])
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class VectorAddRequest(BaseModel):
|
| 52 |
+
ids: List[str]
|
| 53 |
+
vectors: List[List[float]]
|
| 54 |
+
metas: Optional[List[dict]] = None
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class VectorSearchRequest(BaseModel):
|
| 58 |
+
query: List[float] = Field(..., min_items=1)
|
| 59 |
+
top_k: int = 5
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class VectorSearchResult(BaseModel):
|
| 63 |
+
id: str
|
| 64 |
+
score: float
|
| 65 |
+
meta: Optional[dict]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@router.post("/add")
|
| 69 |
+
def add_vectors(payload: VectorAddRequest):
|
| 70 |
+
store = get_default_store(dim=len(payload.vectors[0]) if payload.vectors else 128)
|
| 71 |
+
try:
|
| 72 |
+
vecs = np.array(payload.vectors, dtype=np.float32)
|
| 73 |
+
except Exception as e:
|
| 74 |
+
raise HTTPException(status_code=400, detail=f"invalid vectors: {e}")
|
| 75 |
+
count = store.add(payload.ids, vecs, payload.metas)
|
| 76 |
+
return {"indexed": count}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@router.post("/search", response_model=List[VectorSearchResult])
|
| 80 |
+
def search_vectors(payload: VectorSearchRequest):
|
| 81 |
+
store = get_default_store(dim=len(payload.query))
|
| 82 |
+
q = np.array(payload.query, dtype=np.float32)
|
| 83 |
+
results = store.search(q, top_k=payload.top_k)
|
| 84 |
+
return results
|
| 85 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 86 |
+
from typing import List, Optional
|
| 87 |
+
from backend.api.schemas import VectorUpsert, VectorQuery, VectorOut
|
| 88 |
+
from backend.core.vector_store import default_store, VectorStore
|
| 89 |
+
|
| 90 |
+
router = APIRouter(prefix="/vectors", tags=["vectors"])
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@router.post("/upsert", response_model=VectorOut, summary="Upsert a single vector")
|
| 94 |
+
def upsert_vector(payload: VectorUpsert):
|
| 95 |
+
"""Add or update a single vector in the default store."""
|
| 96 |
+
try:
|
| 97 |
+
default_store.add(payload.id, payload.vector, metadata=payload.metadata or {})
|
| 98 |
+
return {"id": payload.id, "vector": payload.vector, "metadata": payload.metadata or {}}
|
| 99 |
+
except Exception as e:
|
| 100 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@router.post("/query", response_model=List[VectorOut], summary="Query nearest vectors")
|
| 104 |
+
def query_vectors(payload: VectorQuery):
|
| 105 |
+
results = default_store.search(payload.vector, k=payload.k or 5)
|
| 106 |
+
return [{"id": r[0], "score": r[1], "metadata": r[2], "vector": None} for r in results]
|
backend/api/visualization_routes.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 2 |
+
from sqlalchemy.orm import Session
|
| 3 |
+
from backend.db.session import SessionLocal
|
| 4 |
+
from backend.db.models import Universe, Axiom, Theorem
|
| 5 |
+
from backend.api.auth import get_api_key
|
| 6 |
+
from backend.core.logging_config import get_logger
|
| 7 |
+
|
| 8 |
+
router = APIRouter()
|
| 9 |
+
logger = get_logger("visualization_routes")
|
| 10 |
+
|
| 11 |
+
def get_db():
|
| 12 |
+
db = SessionLocal()
|
| 13 |
+
try:
|
| 14 |
+
yield db
|
| 15 |
+
finally:
|
| 16 |
+
db.close()
|
| 17 |
+
|
| 18 |
+
@router.get("/visualization/universe/{universe_id}")
|
| 19 |
+
def get_universe_graph(universe_id: int, db: Session = Depends(get_db), api_key: str = Depends(get_api_key)):
|
| 20 |
+
try:
|
| 21 |
+
universe = db.query(Universe).filter(Universe.id == universe_id).first()
|
| 22 |
+
axioms = db.query(Axiom).filter(Axiom.universe_id == universe_id).all()
|
| 23 |
+
theorems = db.query(Theorem).filter(Theorem.universe_id == universe_id).all()
|
| 24 |
+
nodes = [{"id": ax.id, "type": "axiom", "label": ax.statement} for ax in axioms] + \
|
| 25 |
+
[{"id": th.id, "type": "theorem", "label": th.statement} for th in theorems]
|
| 26 |
+
edges = []
|
| 27 |
+
for th in theorems:
|
| 28 |
+
for ax in axioms:
|
| 29 |
+
if ax.statement in th.proof:
|
| 30 |
+
edges.append({"source": ax.id, "target": th.id, "type": "proof"})
|
| 31 |
+
logger.info(f"Visualization graph for universe {universe_id} generated.")
|
| 32 |
+
return {
|
| 33 |
+
"universe": {"id": universe.id, "name": universe.name, "type": universe.universe_type},
|
| 34 |
+
"nodes": nodes,
|
| 35 |
+
"edges": edges
|
| 36 |
+
}
|
| 37 |
+
except Exception as e:
|
| 38 |
+
logger.error(f"Visualization error: {e}")
|
| 39 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 40 |
+
|
| 41 |
+
@router.get("/visualization/universes")
|
| 42 |
+
def get_all_universe_graphs(db: Session = Depends(get_db), api_key: str = Depends(get_api_key)):
|
| 43 |
+
try:
|
| 44 |
+
universes = db.query(Universe).all()
|
| 45 |
+
result = []
|
| 46 |
+
for universe in universes:
|
| 47 |
+
axioms = db.query(Axiom).filter(Axiom.universe_id == universe.id).all()
|
| 48 |
+
theorems = db.query(Theorem).filter(Theorem.universe_id == universe.id).all()
|
| 49 |
+
nodes = [{"id": ax.id, "type": "axiom", "label": ax.statement} for ax in axioms] + \
|
| 50 |
+
[{"id": th.id, "type": "theorem", "label": th.statement} for th in theorems]
|
| 51 |
+
edges = []
|
| 52 |
+
for th in theorems:
|
| 53 |
+
for ax in axioms:
|
| 54 |
+
if ax.statement in th.proof:
|
| 55 |
+
edges.append({"source": ax.id, "target": th.id, "type": "proof"})
|
| 56 |
+
result.append({
|
| 57 |
+
"universe": {"id": universe.id, "name": universe.name, "type": universe.universe_type},
|
| 58 |
+
"nodes": nodes,
|
| 59 |
+
"edges": edges
|
| 60 |
+
})
|
| 61 |
+
logger.info("Visualization graphs for all universes generated.")
|
| 62 |
+
return result
|
| 63 |
+
except Exception as e:
|
| 64 |
+
logger.error(f"Visualization error: {e}")
|
| 65 |
+
raise HTTPException(status_code=500, detail=str(e))
|
backend/core/.rustup/settings.toml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version = "12"
|
| 2 |
+
profile = "default"
|
| 3 |
+
|
| 4 |
+
[overrides]
|
backend/core/ag4masses/.gitignore
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# Distribution / packaging
|
| 7 |
+
.Python
|
| 8 |
+
build/
|
| 9 |
+
develop-eggs/
|
| 10 |
+
dist/
|
| 11 |
+
downloads/
|
| 12 |
+
eggs/
|
| 13 |
+
.eggs/
|
| 14 |
+
lib/
|
| 15 |
+
lib64/
|
| 16 |
+
parts/
|
| 17 |
+
sdist/
|
| 18 |
+
var/
|
| 19 |
+
wheels/
|
| 20 |
+
share/python-wheels/
|
| 21 |
+
*.egg-info/
|
| 22 |
+
.installed.cfg
|
| 23 |
+
*.egg
|
| 24 |
+
MANIFEST
|
| 25 |
+
ag_ckpt_vocab/
|
| 26 |
+
.vscode
|
| 27 |
+
.env
|
backend/core/ag4masses/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# How to Contribute
|
| 2 |
+
|
| 3 |
+
## Contributor License Agreement
|
| 4 |
+
|
| 5 |
+
Contributed code or data will become part of the AG4Masses project and be subject to the same Licence Agreement as the AG4Masses project.
|
| 6 |
+
|
| 7 |
+
## Code reviews
|
| 8 |
+
|
| 9 |
+
All submissions, including submissions by project members, require review. We
|
| 10 |
+
use GitHub pull requests for this purpose. Consult
|
| 11 |
+
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
|
| 12 |
+
information on using pull requests.
|
| 13 |
+
|
backend/core/ag4masses/LICENSE
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Apache License
|
| 3 |
+
Version 2.0, January 2004
|
| 4 |
+
http://www.apache.org/licenses/
|
| 5 |
+
|
| 6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 7 |
+
|
| 8 |
+
1. Definitions.
|
| 9 |
+
|
| 10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 12 |
+
|
| 13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 14 |
+
the copyright owner that is granting the License.
|
| 15 |
+
|
| 16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 17 |
+
other entities that control, are controlled by, or are under common
|
| 18 |
+
control with that entity. For the purposes of this definition,
|
| 19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 20 |
+
direction or management of such entity, whether by contract or
|
| 21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 23 |
+
|
| 24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 25 |
+
exercising permissions granted by this License.
|
| 26 |
+
|
| 27 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 28 |
+
including but not limited to software source code, documentation
|
| 29 |
+
source, and configuration files.
|
| 30 |
+
|
| 31 |
+
"Object" form shall mean any form resulting from mechanical
|
| 32 |
+
transformation or translation of a Source form, including but
|
| 33 |
+
not limited to compiled object code, generated documentation,
|
| 34 |
+
and conversions to other media types.
|
| 35 |
+
|
| 36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 37 |
+
Object form, made available under the License, as indicated by a
|
| 38 |
+
copyright notice that is included in or attached to the work
|
| 39 |
+
(an example is provided in the Appendix below).
|
| 40 |
+
|
| 41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 42 |
+
form, that is based on (or derived from) the Work and for which the
|
| 43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 45 |
+
of this License, Derivative Works shall not include works that remain
|
| 46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 47 |
+
the Work and Derivative Works thereof.
|
| 48 |
+
|
| 49 |
+
"Contribution" shall mean any work of authorship, including
|
| 50 |
+
the original version of the Work and any modifications or additions
|
| 51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 55 |
+
means any form of electronic, verbal, or written communication sent
|
| 56 |
+
to the Licensor or its representatives, including but not limited to
|
| 57 |
+
communication on electronic mailing lists, source code control systems,
|
| 58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 60 |
+
excluding communication that is conspicuously marked or otherwise
|
| 61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 62 |
+
|
| 63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 65 |
+
subsequently incorporated within the Work.
|
| 66 |
+
|
| 67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 72 |
+
Work and such Derivative Works in Source or Object form.
|
| 73 |
+
|
| 74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 77 |
+
(except as stated in this section) patent license to make, have made,
|
| 78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 79 |
+
where such license applies only to those patent claims licensable
|
| 80 |
+
by such Contributor that are necessarily infringed by their
|
| 81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 83 |
+
institute patent litigation against any entity (including a
|
| 84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 85 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 86 |
+
or contributory patent infringement, then any patent licenses
|
| 87 |
+
granted to You under this License for that Work shall terminate
|
| 88 |
+
as of the date such litigation is filed.
|
| 89 |
+
|
| 90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 91 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 92 |
+
modifications, and in Source or Object form, provided that You
|
| 93 |
+
meet the following conditions:
|
| 94 |
+
|
| 95 |
+
(a) You must give any other recipients of the Work or
|
| 96 |
+
Derivative Works a copy of this License; and
|
| 97 |
+
|
| 98 |
+
(b) You must cause any modified files to carry prominent notices
|
| 99 |
+
stating that You changed the files; and
|
| 100 |
+
|
| 101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 102 |
+
that You distribute, all copyright, patent, trademark, and
|
| 103 |
+
attribution notices from the Source form of the Work,
|
| 104 |
+
excluding those notices that do not pertain to any part of
|
| 105 |
+
the Derivative Works; and
|
| 106 |
+
|
| 107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 108 |
+
distribution, then any Derivative Works that You distribute must
|
| 109 |
+
include a readable copy of the attribution notices contained
|
| 110 |
+
within such NOTICE file, excluding those notices that do not
|
| 111 |
+
pertain to any part of the Derivative Works, in at least one
|
| 112 |
+
of the following places: within a NOTICE text file distributed
|
| 113 |
+
as part of the Derivative Works; within the Source form or
|
| 114 |
+
documentation, if provided along with the Derivative Works; or,
|
| 115 |
+
within a display generated by the Derivative Works, if and
|
| 116 |
+
wherever such third-party notices normally appear. The contents
|
| 117 |
+
of the NOTICE file are for informational purposes only and
|
| 118 |
+
do not modify the License. You may add Your own attribution
|
| 119 |
+
notices within Derivative Works that You distribute, alongside
|
| 120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 121 |
+
that such additional attribution notices cannot be construed
|
| 122 |
+
as modifying the License.
|
| 123 |
+
|
| 124 |
+
You may add Your own copyright statement to Your modifications and
|
| 125 |
+
may provide additional or different license terms and conditions
|
| 126 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 127 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 128 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 129 |
+
the conditions stated in this License.
|
| 130 |
+
|
| 131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 133 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 134 |
+
this License, without any additional terms or conditions.
|
| 135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 136 |
+
the terms of any separate license agreement you may have executed
|
| 137 |
+
with Licensor regarding such Contributions.
|
| 138 |
+
|
| 139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 141 |
+
except as required for reasonable and customary use in describing the
|
| 142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 143 |
+
|
| 144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 145 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 148 |
+
implied, including, without limitation, any warranties or conditions
|
| 149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 151 |
+
appropriateness of using or redistributing the Work and assume any
|
| 152 |
+
risks associated with Your exercise of permissions under this License.
|
| 153 |
+
|
| 154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 155 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 156 |
+
unless required by applicable law (such as deliberate and grossly
|
| 157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 158 |
+
liable to You for damages, including any direct, indirect, special,
|
| 159 |
+
incidental, or consequential damages of any character arising as a
|
| 160 |
+
result of this License or out of the use or inability to use the
|
| 161 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 162 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 163 |
+
other commercial damages or losses), even if such Contributor
|
| 164 |
+
has been advised of the possibility of such damages.
|
| 165 |
+
|
| 166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 169 |
+
or other liability obligations and/or rights consistent with this
|
| 170 |
+
License. However, in accepting such obligations, You may act only
|
| 171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 172 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 173 |
+
defend, and hold each Contributor harmless for any liability
|
| 174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 175 |
+
of your accepting any such warranty or additional liability.
|
| 176 |
+
|
| 177 |
+
END OF TERMS AND CONDITIONS
|
| 178 |
+
|
| 179 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 180 |
+
|
| 181 |
+
To apply the Apache License to your work, attach the following
|
| 182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 183 |
+
replaced with your own identifying information. (Don't include
|
| 184 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 185 |
+
comment syntax for the file format. We also recommend that a
|
| 186 |
+
file or class name and description of purpose be included on the
|
| 187 |
+
same "printed page" as the copyright notice for easier
|
| 188 |
+
identification within third-party archives.
|
| 189 |
+
|
| 190 |
+
Copyright [yyyy] [name of copyright owner]
|
| 191 |
+
|
| 192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 193 |
+
you may not use this file except in compliance with the License.
|
| 194 |
+
You may obtain a copy of the License at
|
| 195 |
+
|
| 196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 197 |
+
|
| 198 |
+
Unless required by applicable law or agreed to in writing, software
|
| 199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 201 |
+
See the License for the specific language governing permissions and
|
| 202 |
+
limitations under the License.
|
backend/core/ag4masses/README.md
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AG4Masses: AlphaGeometry for the Masses
|
| 2 |
+
|
| 3 |
+
An exciting recent development in AI with rigorous logical reasoning ability is the [AlphaGeometry](https://www.nature.com/articles/s41586-023-06747-5) system developed by Google Deepmind. Google made the source code for running AlphaGeometry available on GitHub at [google-deepmind/alphageometry](https://github.com/google-deepmind/alphageometry). However, the AlphaGeometry system as Google released, with the Language Model trained by Google, still requires a tremendous amount of computing power to run when solving problems. As Google's paper mentioned, in order to solve IMO level problems in about 1.5 hour, it needed 4 GPU V100 and up to 250 CPUs. These are not the kind of hardware casual users and hobbyists have access to.
|
| 4 |
+
|
| 5 |
+
AlphaGeometry includes a powerful deductive engine DD+AR that can solve virtually any plane geometry problem that does not require auxiliary points within a few minutes with household hardware. The ultimate performance of the system hinges on the ability to add auxiliary points that lead to a solution. In AlphaGeometry, this is done by the Language Model. My tests shown that, for many classic problems, AlphaGeometry failed to solve them after trying more than ~8000 figures with added auxiliary points. For humans, the number of figures attempted is typically under 100. This indicates that there is still vast room to improve the performance of AlphaGeometry.
|
| 6 |
+
|
| 7 |
+
Since the initial open-sourcing in January 2024, as of April 2024, there has been no update to the AlphaGeometry repository. It is unclear whether Google plans to continue developing AlphaGeometry. The AG4Masses project is a fork of [google-deepmind/alphageometry](https://github.com/google-deepmind/alphageometry). I hope to build on the wonderful foundation AlphaGeometry has laid, continue to improve it, bring its powers to everyday users and hobbyists, and provide useful insights to future developments of AI with rigorous logical reasoning ability.
|
| 8 |
+
|
| 9 |
+
# The Goal of AG4Masses
|
| 10 |
+
|
| 11 |
+
The goal of this project is to **improve the performance of AlphaGeometry by a factor of ~100** to enable it to **solve IMO level (hence vast majority of) plane geometry problems with household hardware (as of 2024, 4-8 logical CPU, 16-32G RAM, no high-end GPU) within a day**.
|
| 12 |
+
|
| 13 |
+
If you are interested, you are welcome to join the community, contribute your ideas and code, or join the discussion on the [Discussions](https://github.com/tpgh24/ag4masses/discussions) page.
|
| 14 |
+
|
| 15 |
+
# Release Notes
|
| 16 |
+
* January 2025:
|
| 17 |
+
* Added a Kaggle Notebook enabling AG4Masses to be run on Kaggle to leverage free resources provided by Kaggle, including 2 Nvidia T4 GPUs, 4 virtual CPUs and 29G RAM
|
| 18 |
+
* Various minor improvements of the robustness and user-friendliness, including [Pull request #12](https://github.com/tpgh24/ag4masses/pull/12) by [pgmthar](https://github.com/pgmthar)
|
| 19 |
+
* Some additional problems and outputs, including [IMO 2024 Question 4](https://artofproblemsolving.com/wiki/index.php/2024_IMO_Problems/Problem_4). See [`outputs/solved`](https://github.com/tpgh24/ag4masses/tree/main/outputs/solved)
|
| 20 |
+
* April 2024
|
| 21 |
+
* Initial release
|
| 22 |
+
|
| 23 |
+
# Table of Contents
|
| 24 |
+
|
| 25 |
+
* [What's Provided in AG4Masses](#whats-provided-in-ag4masses-as-of-april-2024)
|
| 26 |
+
* [(New January 2025) Kaggle Notebook for running AG4Masses](#new-january-2025-kaggle-notebook-for-running-ag4masses)
|
| 27 |
+
* [Code Improvements over AlphaGeometry](#code-improvements-over-alphageometry)
|
| 28 |
+
* [Additional Problems and Test Results](#additional-problems-and-test-results)
|
| 29 |
+
* [Plan for Future Developments](#plan-for-future-developments)
|
| 30 |
+
* [Improve the Language Model that Adds Auxiliary Points](#improve-the-language-model-that-adds-auxiliary-points)
|
| 31 |
+
* [Improve Problem Solving Strategy and Algorithm](#improve-problem-solving-strategy-and-algorithm)
|
| 32 |
+
* [Enhance the Range of Geometry Problems Handled by the System](#enhance-the-range-of-geometry-problems-handled-by-the-system)
|
| 33 |
+
* [Improve the User Friendliness and Robustness of the System](#improve-the-user-friendliness-and-robustness-of-the-system)
|
| 34 |
+
* [Some Tips and Experiences about the AlphaGeometry System](#some-tips-and-experiences-about-the-alphageometry-system)
|
| 35 |
+
* [The Problem Definition Language](#the-problem-definition-language)
|
| 36 |
+
* [Some Tips](#some-tips)
|
| 37 |
+
* [Setup](#setup)
|
| 38 |
+
* [System and Python version](#system-and-python-version)
|
| 39 |
+
* [Choose file locations](#choose-file-locations)
|
| 40 |
+
* [Download source and data files](#download-source-and-data-files)
|
| 41 |
+
* [Install necessary Linux packages](#install-necessary-linux-packages)
|
| 42 |
+
* [Install Python module dependencies](#install-python-module-dependencies)
|
| 43 |
+
* [Run tests](#run-tests)
|
| 44 |
+
* [Run AG4Masses](#run-ag4masses)
|
| 45 |
+
* [Directory Layout](#directory-layout)
|
| 46 |
+
|
| 47 |
+
# What's Provided in AG4Masses (as of January 2025)
|
| 48 |
+
|
| 49 |
+
## (New January 2025) Kaggle Notebook for running AG4Masses
|
| 50 |
+
* The Notebook can be accessed at [AG4Masses-Public](https://www.kaggle.com/code/pengtong/ag4masses-public). It's also included in `ag4masses/utils/` in the ag4masses code base.
|
| 51 |
+
* The Notebook enables running AG4Masses on [Kaggle](https://www.kaggle.com/). As of January 2025, the free version of Kaggle provides 2 Nvidia T4 GPUs, 4 virtual CPUs and 29G RAM. These allow AG4Masses to process about 200 figures per hour for a typical problem (obviously this depends on the complexity of the problem and as more auxilary points are added the progress will slow down)
|
| 52 |
+
* Because Kaggle does not provide persistent storage, everytime a new Kaggle session for the Notebook is started, Python and Linux packages need to be installed, taking about 10 minutes. If anyone knows a way to avoid this, please let me know
|
| 53 |
+
|
| 54 |
+
## Code Improvements over AlphaGeometry
|
| 55 |
+
* Added the ability to use multiple CPUs on a symmetric multiprocessor machine to improve speed
|
| 56 |
+
* Fixed some bugs
|
| 57 |
+
* Improved robustness by handling many error conditions that would have caused AlphaGeometry to abort
|
| 58 |
+
* Improved logging
|
| 59 |
+
* Utility scripts for running AG4Masses, analyzing run-time log, monitoring progress of a run, etc.
|
| 60 |
+
|
| 61 |
+
## Additional Problems and Test Results
|
| 62 |
+
|
| 63 |
+
Additional geometry problems are provided by the AG4Masses project, including some classic problems such as the 5-circles problem, Napoleon problem, Butterfly problem, Ceva Theorem etc. in the `data/ag4m_problems.txt` file.
|
| 64 |
+
|
| 65 |
+
The `outputs` directory contains log files of many test cases. The `solved` subdir are problems solved, most of the problems also come with image files showing the diagrams of the problems. Most of the diagrams are generated by AlphaGeometry automatically, sometimes such diagrams are not very easy to read. For some problems I manually created more readable images, file names of the manually generated diagrams are tagged with '-manual'. The `unsolved` subdir are problems that I have not been able to solve with hardware available to me, after attempting 7500-9500 figures. The auxiliary points added by AlphaGeometry can be found by searching lines like:
|
| 66 |
+
|
| 67 |
+
`I0304 22:44:12.423360 140094168801280 alphageometry.py:548] Worker 0: Translation: "i = on_line i b c, on_bline i c b"`
|
| 68 |
+
|
| 69 |
+
Note that there are some small differences in the format of the log files for different problems because of code changes over time.
|
| 70 |
+
|
| 71 |
+
The naming convention of the log files is: for problems that can be solved by ddar (no auxiliary point needed), the file name contains 'ddar-ok'; for problems that need AlphaGeometry (need auxiliary points) and solved, the file name contains 'ag-ok'.
|
| 72 |
+
|
| 73 |
+
Below are a few examples:
|
| 74 |
+
|
| 75 |
+
### The 5-Circles Problem (`outputs/solved/5circles-ddar-ok.log`):
|
| 76 |
+
|
| 77 |
+
`A, B, C, D, E` are vertices of a pentagon. `F, G, H, I, J` are intersections of their diagonals. 5 circumcircles of triangles `AJF, BFG` *etc.* intersect at 5 points `P, Q, R, S, T`, in addition to `F, G, H, I, J`. Prove that `P, Q, R, S, T` are concyclic.
|
| 78 |
+
|
| 79 |
+
<center>
|
| 80 |
+
<img alt="5circles-manual" width="800px" src="outputs/solved/5circles-manual.jpg">
|
| 81 |
+
</center>
|
| 82 |
+
|
| 83 |
+
It turns out no auxiliary point is needed for this problem, it can be solved by DD+AR, taking 6 minutes with 1 CPU in use. This problem is not easy for humans given there are many points on the diagram and it's not easy to see all the relationships between them. This shows the power of the DD+AR engine.
|
| 84 |
+
|
| 85 |
+
### The 15-Degree-Line-in-Square Problem (`outputs/solved/square_angle15-ag-ok.log`):
|
| 86 |
+
|
| 87 |
+
`A, B, C, D` is a square. `E` is inside the square and `CDE = ECD = 15-degree`. Prove that `ABE` is an equilateral triangle.
|
| 88 |
+
|
| 89 |
+
<center>
|
| 90 |
+
<img alt="square_angle15.jpg" width="800px" src="outputs/solved/square_angle15.jpg">
|
| 91 |
+
</center>
|
| 92 |
+
|
| 93 |
+
This needs an auxiliary point and AlphaGeometry found it very quickly (13 minutes, about 1 CPU in use, no GPU), on the 3rd try (and the first valid figure).
|
| 94 |
+
|
| 95 |
+
I remember I first encountered this problem in the middle school, a few months after learning geometry. An obvious solution was an indirect one: construct an equilateral triangle `ABE` with `AB` as one side and `E` inside the square, show that `CDE = ECD = 15-degree`, then argue that there is only one point that can satisfy this condition. But I and several other classmates were not satisfied with the indirect solution and wanted to find a direct one. 5-6 of us spend 1-2 hours before one student solved it. In that exercise, it took about 10 hours of intense execution by enthusiastic and lightly trained young human brains. Even on very basic hardware, AlphaGeometry is already better than a novice human problem solver.
|
| 96 |
+
|
| 97 |
+
### The Napoleon Problem (`outputs/solved/napoleon-ddar-ok.log`, `outputs/solved/napoleon2-mp-4-solutions-ag-ok.log`)
|
| 98 |
+
|
| 99 |
+
For any triangle `ABC`, construct equilateral triangles with one of the sides as a side (the 3 equilaterals must be in the same direction relative to `ABC`, either all "going out" or all "going in"). The centers of the 3 equilateral triangles - `D, E, F` - form an equilateral triangle.
|
| 100 |
+
|
| 101 |
+
If the problem is stated this way, no additional auxiliary point is needed, it can be solved by DD+AR, see `outputs/solved/napoleon-ddar-ok.log`.
|
| 102 |
+
|
| 103 |
+
<center>
|
| 104 |
+
<img alt="napoleon.jpg" width="800px" src="outputs/solved/napoleon.jpg">
|
| 105 |
+
</center>
|
| 106 |
+
|
| 107 |
+
A more challenging version is to give points `D, E, F` through the conditions that angles `DAB, ABD, EBC, BCE`, *etc.* all equal 30-degree. This will need auxiliary points. In my run AlphaGeometry found 4 solutions, they require 4 auxiliary points. AlphaGeometry found the first after trying around 360 figures. See `outputs/solved/napoleon2-mp-4-solutions-ag-ok.log`.
|
| 108 |
+
|
| 109 |
+
<center>
|
| 110 |
+
<img alt="napoleon2-mp-2.jpg" width="800px" src="outputs/solved/napoleon2-mp-2.jpg">
|
| 111 |
+
</center>
|
| 112 |
+
|
| 113 |
+
### Ceva's Theorem (`outputs/unsolved/ceva-mp-16-crash.log`)
|
| 114 |
+
|
| 115 |
+
For any triangle `ABC` and point `D`, points `E` is the interception of `AD` and `BC`, and so on for `F, G`. Prove that `AG/GB * BE/EC * CF/FA = 1` (a more general way to state the theorem considers sign of the segments and rhs is -1). Here we run into a limitation of AlphaGeometry: it does not support complex conclusions (goals to be proved) like the one in the Ceva's Theorem, only equality of two ratios. To work around this, I added an auxiliary point `H` on `AC` with `BH // EF`, and transformed the conclusion to `FH/FA = GB/GA`.
|
| 116 |
+
|
| 117 |
+
<center>
|
| 118 |
+
<img alt="ceva-manual.jpg" width="800px" src="outputs/unsolved/ceva-manual.jpg">
|
| 119 |
+
</center>
|
| 120 |
+
|
| 121 |
+
In my test this problem was not solved by AlphaGeometry after over 10k figures, see `outputs/unsolved/ceva-mp-16-crash.log`. The machine I used eventually ran out of memory as the figures got more complex. It's interesting to look at the auxiliary points AlphaGeometry attempted to add. To a human, observing that the problem is very general, there are very few relationships given, and the conclusion is about ratio of segments, it will be very natural to try to add parallel lines to construct similar triangles. Indeed, a typical solution only requires two auxiliary points, *e.g.* draw a line over `A` parallel to `BC`, extend `CD` and `BD` to meet this line. But only about 10% of AlphaGeometry's auxiliary points for this problem involve parallel lines. For this and other problems I tried, I find AlphaGeometry to prefer adding midpoints and mirror points around another point or a line. AlphaGeometry also seems to perform worse for problems like this one whose premises are simple with few relationships given.
|
| 122 |
+
|
| 123 |
+
# Plan for Future Developments
|
| 124 |
+
|
| 125 |
+
## Improve the Language Model that Adds Auxiliary Points
|
| 126 |
+
|
| 127 |
+
The DD+AR deduction engine can solve virtually any problem in a few minutes with household hardware. The performance of the system all hinges on the LM's ability to add auxiliary points effectively. As Google's paper mentions, the current model is trained on 100 million randomly generated problems, with nearly 10 million involving auxiliary points. Yet as we observed in the [Additional Problems and Test Results](#additional-problems-and-test-results) section above, the performance still has vast room to improve. Humans typically cannot try more than ~100 figures, but top human problem solvers perform better than what the current version of AlphaGeometry can do with thousands of times more attempts.
|
| 128 |
+
|
| 129 |
+
I believe this requires tuning the LM using data based on **human designed** problems. Although many strategic search type of problems have been solved very successfully by approaches based on first principles without requiring human inputs, such as Google Deepmind's AlphaZero for many challenging board and video games, math and scientific research in general and plane geometry in particular are different. Unlike the board and video games that have simple and clearly defined goals, other than a few areas such as proof of Riemann's Hypothesis, math and science research have no such simple and clearly defined final goals. The active research areas are defined by collective activities and interests of researchers in the fields. Even major breakthroughs such as calculus, theory of relativity and quantum mechanics were still pretty close to the frontier of human knowledge at their times. Looking at plane geometry in particular, it is not an active area of continued mathematical discovery any more, the interest in it is main for education, recreation and as test cases for AI research. So the performance of a problem solving system is measured by its ability to solve human designed problems. A system like the current version of AlphaGeometry trained on randomly generated problems may be strong in solving random problems, but not particularly strong in solving the kind of problems commonly of interest to humans, which are mostly **designed by humans** (instead of arising naturally in some way).
|
| 130 |
+
|
| 131 |
+
As Google's paper mentions, the challenge in training a model to solve plane geometry problem is the scarcity of data, that was one reason the authors used randomly generated problems. However, with the advent of the AlphaGeometry system, we can use AlphaGeometry itself as a platform to collect data. There are already some quite large plane geometry problem sets available in electronic form, such as [FormalGeo](https://github.com/FormalGeo/Datasets) with 7k problems. What's missing is for problems that require auxiliary points, knowing the auxiliary points that lead to the solution of the problem. This can be obtained either manually (if one knows the solution) or by successful solution by the latest version of AlphaGeometry or one of its improved versions such as AG4Masses. To estimate the number of data points needed, we again use human as reference. A top human problem solver is probably trained on less than 1k problems. If we can collect 10k problems with auxiliary points, I believe they can significantly improve the performance of the LM. The specific tasks include:
|
| 132 |
+
|
| 133 |
+
* Define a format to record problems and auxiliary points, enhance the AG4Masses code so when a problem is successfully solved, record the problem and auxiliary points in the standard format. Automatically submit the results to the AG4Masses project, with the user's consent. [Effort Level: low]
|
| 134 |
+
* Investigate ways to tune the LM. Google has not published the code and details for the training and tuning of the LM. The [Meliad](https://github.com/google-research/meliad) project AlphaGeometry uses does not have much documentation (other than several related published papers), so this may be challenging. [Effort Level: high]
|
| 135 |
+
* Tune the model once a meaningful amount of data are collected. I am not sure about the amount of computing power needed for this, need further investigation. [Effort Level: potentially high]
|
| 136 |
+
|
| 137 |
+
## Improve Problem Solving Strategy and Algorithm
|
| 138 |
+
|
| 139 |
+
When searching for auxiliary points, the current version of AlphaGeometry simply does a beam (breadth-first with pruning) search from the premises of the problem. A strategy commonly used by humans is to also look from the conclusion backwards: find sufficient conditions of the conclusion, and attempt to prove one of the sufficient conditions. Intuitively, this enlarges the goal we are searching for.
|
| 140 |
+
|
| 141 |
+
One way to look for sufficient conditions is to look for necessary conditions of the conclusion, i.e. what can be deduced from the problem's premises **and the conclusion**, then test whether the necessary conditions are also sufficient. This is especially effective for human designed problems because the authors of the problems usually have already made the problems as general as possible, i.e. there is usually no sufficient but not necessary conditions provable from the premises. The specific tasks are, at each step of the auxiliary point searching process:
|
| 142 |
+
|
| 143 |
+
* Add the conclusion of the problem into the premises (including the auxiliary points already added), use the DD+AR engine to find all necessary conditions (what can be deduced), and use DD+AR to verify whether each of them is a sufficient condition
|
| 144 |
+
* For each sufficient condition found, when running the LM to search for the next auxiliary point, change the conclusion to the sufficient condition
|
| 145 |
+
|
| 146 |
+
This should hopefully improve the effectiveness of the auxiliary points, but it needs to be balanced with the runtime cost incurred.
|
| 147 |
+
|
| 148 |
+
There may be other ways to improve the problem-solving strategy, such as combining hand-crafted heuristics with the LM model.
|
| 149 |
+
|
| 150 |
+
Effort Level: high, but more certain since it does not require changes to the LM itself
|
| 151 |
+
|
| 152 |
+
## Enhance the Range of Geometry Problems Handled by the System
|
| 153 |
+
|
| 154 |
+
AlphaGeometry's problem definition language is restrictive, for example:
|
| 155 |
+
|
| 156 |
+
* The premise specification does not allow construction of points based on ratio of segment lengths
|
| 157 |
+
* The conclusion specification does not allow complex conditions involving arithmetic, such as sum of length of 2 segments equaling length of another segment, or product of 3 segment length ratios, like in Ceva's Theorem
|
| 158 |
+
|
| 159 |
+
These limits the scope of problems that can be handled by the system. At least for the two examples mentioned above, it should not be too difficult to add them into the DD+AR part of the system, but the LM's performance for problems involving these new constructs may be degraded, since the LM model's training dataset does not contain such constructs. To maintain the performance of the LM model, we may need to wait for Google to publish the code and data set for LM model training. Even with the code and data, the computing power needed for retaining the model may be beyond the reach of an online community. Another possibility is to develop a way to transform such constructs to the ones AlphaGeometry already handles.
|
| 160 |
+
|
| 161 |
+
Effort Level: medium for extending DD+AR, high for ensuring performance of the LM for the new constructs
|
| 162 |
+
|
| 163 |
+
## Improve the User Friendliness and Robustness of the System
|
| 164 |
+
|
| 165 |
+
The AlphaGeometry system is not very user friendly, and not very robust. For example:
|
| 166 |
+
|
| 167 |
+
* The problem definition language syntax is very strict, it's sensitive to white spaces
|
| 168 |
+
* The code does not do a very good job checking correctness of problem definition. When a problem definition has errors or the proposition is false, the code often just freezes. When it catches a error, the error message is often hard to understand
|
| 169 |
+
* The LM does not always return valid auxiliary point construction. The code captures most of these, but there are still some uncaught ones that will cause the execution to abort
|
| 170 |
+
|
| 171 |
+
I already made some improvements in AG4Masses in these aspects, but more can be done.
|
| 172 |
+
|
| 173 |
+
Effort Level: low to medium
|
| 174 |
+
|
| 175 |
+
# Some Tips and Experiences about the AlphaGeometry System
|
| 176 |
+
|
| 177 |
+
Below are based on my testing and reading of the source code.
|
| 178 |
+
|
| 179 |
+
## The Problem Definition Language
|
| 180 |
+
|
| 181 |
+
Below is a problem from `alphageometry/examples.txt`:
|
| 182 |
+
|
| 183 |
+
```
|
| 184 |
+
orthocenter
|
| 185 |
+
a b c = triangle; h = on_tline b a c, on_tline c a b ? perp a h b c
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
* A problem consists of 2 lines, the first line is the name of the problem, the second line is the definition
|
| 189 |
+
* The problem definition is **sensitive to white spaces, including trailing ones**
|
| 190 |
+
* The problem definition consists of premises and a conclusion, separated by `' ? '`
|
| 191 |
+
* The premises consist of multiple clauses for constructing points, the best way to understand them is to think of the process of drawing the points one by one
|
| 192 |
+
* Multiple point-construction clauses are separated by `' ; '`. Note that the last one should **not** end with `' ; '`, before the `' ? '` separating the premises and the conclusion
|
| 193 |
+
* Some point-construction clauses can construct multiple points, such as `'a b c = triangle'`
|
| 194 |
+
* A point-construction clause consists of point names (separated by a single space), followed by `' = '`, and 1 or 2 "actions" (the term used in the Google paper), separated by `' , '`. See in the above example: `h = on_tline b a c, on_tline c a b`
|
| 195 |
+
* Actions are defined in the `alphageometry/defs.txt` file. They are also listed in the Google paper in *"Extended Data Table 1 | List of actions to construct the random premises"* (reproduced [here](data/ag_defs.jpg)). Each action is a constraint on the position of the point. Constructing a point using actions is similar to constructing it using straight edge and compass, *e.g.* find the point through intersection of 2 lines
|
| 196 |
+
* An action is similar to a function call, with other points being inputs and the point to be constructed being output
|
| 197 |
+
* Output point names can be optionally repeated in the beginning of the inputs (arguments) of the actions. For example, `h = on_tline b a c, on_tline c a b` can also be `h = on_tline h b a c, on_tline h c a b`. In `alphageometry/defs.txt` the output point names are repeated in front of the input point names. This sometimes makes the action clearer to read
|
| 198 |
+
* It's possible to add actions but it's not enough to just add into the `defs.txt` file. In `defs.txt`, each action is defined by 5 lines. The last line invoves functions needed for numerical checking that need to be implemented in Python
|
| 199 |
+
* The conclusion (goal) part of the problem can have one of the following statements:
|
| 200 |
+
* `coll a b c` : points `a b c` are collinear
|
| 201 |
+
* `cong a b c e` : segments `ab` and `cd` are congruent (length equal)
|
| 202 |
+
* `contri a b c p q r` : triangles `abc` and `pqr` are congruent
|
| 203 |
+
* `cyclic a b c d` : 4 points `a b c d` are cocyclic
|
| 204 |
+
* `eqangle a b c d p q r s` : the angles between lines `ab-cd` and `pq-rs` are equal. **Note that angles have directions (signs)** so the order between `a b` and `c d` matters. `eqangle a b c d c d a b` is false. The way to think about it is, angle `ab-cd` is the angle to turn line `ab` **clockwise** so it is parallel with the line `cd`. You can use counter-clockwise as the convention too, as long as for all angles the same convention is used
|
| 205 |
+
* `eqratio a b c d p q r s` : segment length `ab/cd = pq/rs`
|
| 206 |
+
* `midp m a b` : point `m` is the midpoint of `a` and `b`
|
| 207 |
+
* `para a b c d` : segments `ab` and `cd` are parallel
|
| 208 |
+
* `perp a b c d` : segments `ab` and `cd` are perpendicular to each other
|
| 209 |
+
* `simtri a b c p q r` : triangles `abc` and `pqr` are similar
|
| 210 |
+
|
| 211 |
+
## Some Tips
|
| 212 |
+
|
| 213 |
+
* **Angles have directions (signs)**. See the note for `eqangle` above. Attention needs to be paid both in the premise (point construction) part and the conclusion part of a problem
|
| 214 |
+
|
| 215 |
+
* AlphaGeometry does not do robust error checking of the problem or the proposition. If the problem has syntax errors or the proposition is false, it often freezes. To detect this, look at the log on stderr. AlphaGeometry will first try to solve the problem using DD+AR, and on stderr, you should see logs like this:
|
| 216 |
+
|
| 217 |
+
```
|
| 218 |
+
I0324 19:53:37.293019 123295230480384 graph.py:498] pascal
|
| 219 |
+
I0324 19:53:37.293379 123295230480384 graph.py:499] a = free a; b = free b; c = on_circle c a b; d = on_circle d a b; e = on_circle e a b; f = on_circle f a b; g = on_circle g a b; h = intersection_ll h b c e f; i = intersection_ll i c d f g; j = intersection_ll j d e g b ? coll h i j
|
| 220 |
+
I0324 19:53:38.638956 123295230480384 ddar.py:60] Depth 1/1000 time = 1.2907805442810059
|
| 221 |
+
I0324 19:53:42.962377 123295230480384 ddar.py:60] Depth 2/1000 time = 4.3230626583099365
|
| 222 |
+
I0324 19:53:47.302527 123295230480384 ddar.py:60] Depth 3/1000 time = 4.3398051261901855
|
| 223 |
+
```
|
| 224 |
+
|
| 225 |
+
Using the AG4Masses code, this should happen right away. Using the original AlphaGeometry code, when the model is `alphageometry`, it will take several minutes to get there because the original AlphaGeometry code loads the LM first. In any case, if you do not see this after several minutes, chances are there is an error in the syntax of the problem or the proposition is false.
|
| 226 |
+
|
| 227 |
+
One trick to error-check a problem's syntax and generate the diagram for the problem is to first use a trivial conclusion such as `cong a b a b`. If the rest of the problem is correct, it will be proven right away, and you will get a diagram generated by the code.
|
| 228 |
+
|
| 229 |
+
# Setup
|
| 230 |
+
|
| 231 |
+
The installation and setup process is similar to those for [alphageometry](https://github.com/google-deepmind/alphageometry) with some refinements.
|
| 232 |
+
|
| 233 |
+
## System and Python version
|
| 234 |
+
|
| 235 |
+
As of April 2024, AlphaGeometry seems to only run on Linux using Python 3.10. I had difficulties making Python module dependencies work on other versions of Python such as 3.11. It's also difficult to install different versions of Python on Linux, so the simplest approach is to use a version of Linux that comes with Python 3.10 installed. Ubuntu 22.04 and Mint 21.3 are two such Linux versions that worked for me.
|
| 236 |
+
|
| 237 |
+
If you don't have a dedicated computer for Linux, one solution is to run a virtual machine using [VirtualBox](https://www.virtualbox.org/). One way to get more computing power is to leverage the $300 free trial credit offered by [Google Cloud Platform](https://cloud.google.com/free?hl=en). A 16 vCPU 128 GB RAM Virtual Machine (machine type e2-himem-16) costs about $0.8/hour. Google Cloud also offers a much cheaper but unreliable type of 'Spot' machine ('VM provisioning model' = 'Spot' instead of 'Standard'), but they get preempted (shut down) every few hours. They may be useful for testing small problems but not suitable for runs lasting a long time.
|
| 238 |
+
|
| 239 |
+
## Choose file locations
|
| 240 |
+
|
| 241 |
+
It's cleaner to put source code, external library (not installed directly in Python virtual environment) and outputs in separate directories. In the `utils/run.sh` script, they are stored in several env vars. In this instruction we will use the same env vars to refer to them
|
| 242 |
+
```
|
| 243 |
+
# Directory where output files go
|
| 244 |
+
TESTDIR=$HOME/ag4mtest
|
| 245 |
+
# Directory containing AG4Masses source files
|
| 246 |
+
AG4MDIR=$HOME/ag4masses
|
| 247 |
+
# Directory containing external libraries including ag_ckpt_vocab and meliad
|
| 248 |
+
AGLIB=$HOME/aglib
|
| 249 |
+
```
|
| 250 |
+
|
| 251 |
+
Instructions below assume you want to put these directories in `$HOME`. If you want to put them somewhere else, just replace `$HOME` with the directory you want to use, and they don't need to be the same for the 3 directories.
|
| 252 |
+
|
| 253 |
+
## Download source and data files
|
| 254 |
+
```
|
| 255 |
+
cd $HOME
|
| 256 |
+
git clone https://github.com/tpgh24/ag4masses.git
|
| 257 |
+
|
| 258 |
+
mkdir $AGLIB
|
| 259 |
+
cd $AGLIB
|
| 260 |
+
git clone https://github.com/google-research/meliad
|
| 261 |
+
|
| 262 |
+
mkdir $AGLIB/ag_ckpt_vocab
|
| 263 |
+
```
|
| 264 |
+
|
| 265 |
+
Download the following files from https://bit.ly/alphageometry into `$AGLIB/ag_ckpt_vocab` . They are weights and vocabulary for the LM. They are on Google Drive, `alphageomrtry/download.sh` provided by Google uses `gdown` to download them, but it did not work for me. You can just download them using a web browser.
|
| 266 |
+
* checkpoint_10999999
|
| 267 |
+
* geometry.757.model
|
| 268 |
+
* geometry.757.vocab
|
| 269 |
+
|
| 270 |
+
## Install necessary Linux packages
|
| 271 |
+
|
| 272 |
+
Depending on the exact Linux distribution/version, you may need to install these packages if they are not already installed.
|
| 273 |
+
```
|
| 274 |
+
sudo apt update
|
| 275 |
+
sudo apt install python3-virtualenv
|
| 276 |
+
sudo apt install python3-tk
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
## Install Python module dependencies
|
| 280 |
+
|
| 281 |
+
For AG4Masses, Python is run in a virtual env. Instructions below assume the virtual env is located in `$HOME/pyve`.
|
| 282 |
+
|
| 283 |
+
```
|
| 284 |
+
virtualenv -p python3 $HOME/pyve
|
| 285 |
+
. $HOME/pyve/bin/activate
|
| 286 |
+
cd $AG4MDIR/alphageometry
|
| 287 |
+
pip install --require-hashes --no-deps -r requirements.txt
|
| 288 |
+
```
|
| 289 |
+
**Note** that the original instruction in AlphaGeometry does not include the `--no-deps` flag. Without it, I was not able to run the command line above successfully.
|
| 290 |
+
|
| 291 |
+
## Run tests
|
| 292 |
+
|
| 293 |
+
Edit `utils/run_test.sh`, update env vars `TESTDIR, AG4MDIR, AGLIB` to match the locations you have chosen, as mentioned in [Choose file locations](#choose-file-locations) above. Then
|
| 294 |
+
|
| 295 |
+
```
|
| 296 |
+
cd $TESTDIR
|
| 297 |
+
$AG4MDIR/utils/run_tests.sh
|
| 298 |
+
```
|
| 299 |
+
This will write logs both to the terminal and file `$TESTDIR/test.log`. All tests except the last one `LmInferenceTest.test_lm_score_may_fail_numerically_for_external_meliad` should pass. The last test may fail because the Meliad library is not numerically stable, as noted in [AlphaGeometry Issues#14](https://github.com/google-deepmind/alphageometry/issues/14).
|
| 300 |
+
|
| 301 |
+
## Run AG4Masses
|
| 302 |
+
|
| 303 |
+
Use the wrapper script `utils/run.sh` to run AG4Masses. Edit it to adjust settings.
|
| 304 |
+
|
| 305 |
+
Update env vars `TESTDIR, AG4MDIR, AGLIB` to match the locations you have chosen, as mentioned in [Choose file locations](#choose-file-locations) above.
|
| 306 |
+
|
| 307 |
+
Update env vars `PROB_FILE, PROB` to point to the problem you want to solve. There are several problem sets provided:
|
| 308 |
+
|
| 309 |
+
* `$AG4MDIR/data/ag4m_problems.txt` : Additional problems provided by the AG4Masses project, including some classic problems described in the [Additional Problems and Test Results](#additional-problems-and-test-results) section above, such as the 5-circles problem, Napoleon problem, Butterfly problem, Ceva Theorem, *etc.*
|
| 310 |
+
* `$AG4MDIR/alphageometry/examples.txt` : from AlphaGeometry, a few test examples
|
| 311 |
+
* `$AG4MDIR/alphageometry/imo_ag_30.txt` : from AlphaGeometry, 30 IMO problems as described in the Google paper
|
| 312 |
+
* `$AG4MDIR/alphageometry/jgex_ag_231.txt` : from AlphaGeometry, 231 problems originally from the [Java-Geometry-Expert](https://github.com/yezheng1981/Java-Geometry-Expert) project as described in the Google paper
|
| 313 |
+
|
| 314 |
+
Set the model you want to run through env var `MODEL`:
|
| 315 |
+
* `ddar` : DD+AR only
|
| 316 |
+
* `alphageometry` : AlphaGeometry/AG4Masses, with LM assisted auxiliary point addition
|
| 317 |
+
|
| 318 |
+
There are several other parameters you can set to control the behavior of the model, see comments in `run.sh`:
|
| 319 |
+
|
| 320 |
+
```
|
| 321 |
+
# BATCH_SIZE: number of outputs for each LM query
|
| 322 |
+
# BEAM_SIZE: size of the breadth-first search queue
|
| 323 |
+
# DEPTH: search depth (number of auxiliary points to add)
|
| 324 |
+
# NWORKERS: number of parallel run worker processes. Rule of thumb: on a 128G machine with 16 logical CPUs,
|
| 325 |
+
# use NWORKERS=8, BATCH_SIZE=24.
|
| 326 |
+
#
|
| 327 |
+
# Memory usage is affected by BATCH_SIZE, NWORKER and complexity of the problem.
|
| 328 |
+
# Larger NWORKER and BATCH_SIZE tends to cause out of memory issue
|
| 329 |
+
|
| 330 |
+
BATCH_SIZE=8
|
| 331 |
+
BEAM_SIZE=32
|
| 332 |
+
DEPTH=8
|
| 333 |
+
NWORKERS=1
|
| 334 |
+
```
|
| 335 |
+
|
| 336 |
+
The stdout and stderr are written to both the terminal and the file `$TESTDIR/ag.err`. If a problem is solved, the solution is written to `$TESTDIR/ag.out`. You can edit env var `ERRFILE, OUTFILE` to change the file names.
|
| 337 |
+
|
| 338 |
+
# Directory Layout
|
| 339 |
+
* `alphageometry` : alphageometry source code
|
| 340 |
+
* `data` : data files such as problem sets
|
| 341 |
+
* `outputs` : test results, logs from ag4masses runs
|
| 342 |
+
* `utils` : utility scripts
|
| 343 |
+
* `checkprog.sh` : when AG4Masses is running, show progress based on information written to stderr
|
| 344 |
+
* `mklog.py` : process AG4Masses stderr output files to create cleaner log files
|
| 345 |
+
* `run.sh` : wrapper to run AG4Masses with proper settings
|
| 346 |
+
* `run_test.sh` : run tests to check that AG4Masses is installed correctly
|
backend/core/ag4masses/alphageometry/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# How to Contribute
|
| 2 |
+
|
| 3 |
+
## Contributor License Agreement
|
| 4 |
+
|
| 5 |
+
Contributions to this project must be accompanied by a Contributor License
|
| 6 |
+
Agreement. You (or your employer) retain the copyright to your contribution,
|
| 7 |
+
this simply gives us permission to use and redistribute your contributions as
|
| 8 |
+
part of the project. Head over to <https://cla.developers.google.com/> to see
|
| 9 |
+
your current agreements on file or to sign a new one.
|
| 10 |
+
|
| 11 |
+
You generally only need to submit a CLA once, so if you've already submitted one
|
| 12 |
+
(even if it was for a different project), you probably don't need to do it
|
| 13 |
+
again.
|
| 14 |
+
|
| 15 |
+
## Code reviews
|
| 16 |
+
|
| 17 |
+
All submissions, including submissions by project members, require review. We
|
| 18 |
+
use GitHub pull requests for this purpose. Consult
|
| 19 |
+
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
|
| 20 |
+
information on using pull requests.
|
| 21 |
+
|
| 22 |
+
## Community Guidelines
|
| 23 |
+
|
| 24 |
+
This project follows [Google's Open Source Community
|
| 25 |
+
Guidelines](https://opensource.google/conduct/).
|
backend/core/ag4masses/alphageometry/alphageometry.py
ADDED
|
@@ -0,0 +1,778 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Run DD+AR or AlphaGeometry solver.
|
| 17 |
+
|
| 18 |
+
Please refer to README.md for detailed instructions.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import time
|
| 22 |
+
import traceback
|
| 23 |
+
|
| 24 |
+
from absl import app
|
| 25 |
+
from absl import flags
|
| 26 |
+
from absl import logging
|
| 27 |
+
import ddar
|
| 28 |
+
import graph as gh
|
| 29 |
+
import lm_inference as lm
|
| 30 |
+
import pretty as pt
|
| 31 |
+
import problem as pr
|
| 32 |
+
|
| 33 |
+
#=============
|
| 34 |
+
import sys, os, math, re
|
| 35 |
+
import multiprocessing
|
| 36 |
+
model = None # global variable used in multi-processing workers
|
| 37 |
+
|
| 38 |
+
_GIN_SEARCH_PATHS = flags.DEFINE_list(
|
| 39 |
+
'gin_search_paths',
|
| 40 |
+
['third_party/py/meliad/transformer/configs'],
|
| 41 |
+
'List of paths where the Gin config files are located.',
|
| 42 |
+
)
|
| 43 |
+
_GIN_FILE = flags.DEFINE_multi_string(
|
| 44 |
+
'gin_file', ['base_htrans.gin'], 'List of Gin config files.'
|
| 45 |
+
)
|
| 46 |
+
_GIN_PARAM = flags.DEFINE_multi_string(
|
| 47 |
+
'gin_param', None, 'Newline separated list of Gin parameter bindings.'
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
_PROBLEMS_FILE = flags.DEFINE_string(
|
| 51 |
+
'problems_file',
|
| 52 |
+
'imo_ag_30.txt',
|
| 53 |
+
'text file contains the problem strings. See imo_ag_30.txt for example.',
|
| 54 |
+
)
|
| 55 |
+
_PROBLEM_NAME = flags.DEFINE_string(
|
| 56 |
+
'problem_name',
|
| 57 |
+
'imo_2000_p1',
|
| 58 |
+
'name of the problem to solve, must be in the problem_file.',
|
| 59 |
+
)
|
| 60 |
+
_MODE = flags.DEFINE_string(
|
| 61 |
+
'mode', 'ddar', 'either `ddar` (DD+AR) or `alphageometry`')
|
| 62 |
+
_DEFS_FILE = flags.DEFINE_string(
|
| 63 |
+
'defs_file',
|
| 64 |
+
'defs.txt',
|
| 65 |
+
'definitions of available constructions to state a problem.',
|
| 66 |
+
)
|
| 67 |
+
_RULES_FILE = flags.DEFINE_string(
|
| 68 |
+
'rules_file', 'rules.txt', 'list of deduction rules used by DD.'
|
| 69 |
+
)
|
| 70 |
+
_CKPT_PATH = flags.DEFINE_string('ckpt_path', '', 'checkpoint of the LM model.')
|
| 71 |
+
_VOCAB_PATH = flags.DEFINE_string(
|
| 72 |
+
'vocab_path', '', 'path to the LM vocab file.'
|
| 73 |
+
)
|
| 74 |
+
_OUT_FILE = flags.DEFINE_string(
|
| 75 |
+
'out_file', '', 'path to the solution output file.'
|
| 76 |
+
) # pylint: disable=line-too-long
|
| 77 |
+
_BEAM_SIZE = flags.DEFINE_integer(
|
| 78 |
+
'beam_size', 1, 'beam size of the proof search.'
|
| 79 |
+
) # pylint: disable=line-too-long
|
| 80 |
+
_SEARCH_DEPTH = flags.DEFINE_integer(
|
| 81 |
+
'search_depth', 1, 'search depth of the proof search.'
|
| 82 |
+
) # pylint: disable=line-too-long
|
| 83 |
+
|
| 84 |
+
#===================================
|
| 85 |
+
_N_WORKSERS = flags.DEFINE_integer(
|
| 86 |
+
'n_workers', 1, 'number of workers'
|
| 87 |
+
)# pylint: disable=line-too-long
|
| 88 |
+
|
| 89 |
+
DEFINITIONS = None # contains definitions of construction actions
|
| 90 |
+
RULES = None # contains rules of deductions
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def natural_language_statement(logical_statement: pr.Dependency) -> str:
|
| 94 |
+
"""Convert logical_statement to natural language.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
logical_statement: pr.Dependency with .name and .args
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
a string of (pseudo) natural language of the predicate for human reader.
|
| 101 |
+
"""
|
| 102 |
+
names = [a.name.upper() for a in logical_statement.args]
|
| 103 |
+
names = [(n[0] + '_' + n[1:]) if len(n) > 1 else n for n in names]
|
| 104 |
+
return pt.pretty_nl(logical_statement.name, names)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def proof_step_string(
|
| 108 |
+
proof_step: pr.Dependency, refs: dict[tuple[str, ...], int], last_step: bool
|
| 109 |
+
) -> str:
|
| 110 |
+
"""Translate proof to natural language.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
proof_step: pr.Dependency with .name and .args
|
| 114 |
+
refs: dict(hash: int) to keep track of derived predicates
|
| 115 |
+
last_step: boolean to keep track whether this is the last step.
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
a string of (pseudo) natural language of the proof step for human reader.
|
| 119 |
+
"""
|
| 120 |
+
premises, [conclusion] = proof_step
|
| 121 |
+
|
| 122 |
+
premises_nl = ' & '.join(
|
| 123 |
+
[
|
| 124 |
+
natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
|
| 125 |
+
for p in premises
|
| 126 |
+
]
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
if not premises:
|
| 130 |
+
premises_nl = 'similarly'
|
| 131 |
+
|
| 132 |
+
refs[conclusion.hashed()] = len(refs)
|
| 133 |
+
|
| 134 |
+
conclusion_nl = natural_language_statement(conclusion)
|
| 135 |
+
if not last_step:
|
| 136 |
+
conclusion_nl += ' [{:02}]'.format(refs[conclusion.hashed()])
|
| 137 |
+
|
| 138 |
+
return f'{premises_nl} \u21d2 {conclusion_nl}'
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def write_solution(g: gh.Graph, p: pr.Problem, out_file: str) -> None:
|
| 142 |
+
"""Output the solution to out_file.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
g: gh.Graph object, containing the proof state.
|
| 146 |
+
p: pr.Problem object, containing the theorem.
|
| 147 |
+
out_file: file to write to, empty string to skip writing to file.
|
| 148 |
+
"""
|
| 149 |
+
setup, aux, proof_steps, refs = ddar.get_proof_steps(
|
| 150 |
+
g, p.goal, merge_trivials=False
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
solution = '\n=========================='
|
| 154 |
+
solution += '\n * From theorem premises:\n'
|
| 155 |
+
premises_nl = []
|
| 156 |
+
for premises, [points] in setup:
|
| 157 |
+
solution += ' '.join([p.name.upper() for p in points]) + ' '
|
| 158 |
+
if not premises:
|
| 159 |
+
continue
|
| 160 |
+
premises_nl += [
|
| 161 |
+
natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
|
| 162 |
+
for p in premises
|
| 163 |
+
]
|
| 164 |
+
solution += ': Points\n' + '\n'.join(premises_nl)
|
| 165 |
+
|
| 166 |
+
solution += '\n\n * Auxiliary Constructions:\n'
|
| 167 |
+
aux_premises_nl = []
|
| 168 |
+
for premises, [points] in aux:
|
| 169 |
+
solution += ' '.join([p.name.upper() for p in points]) + ' '
|
| 170 |
+
aux_premises_nl += [
|
| 171 |
+
natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
|
| 172 |
+
for p in premises
|
| 173 |
+
]
|
| 174 |
+
solution += ': Points\n' + '\n'.join(aux_premises_nl)
|
| 175 |
+
|
| 176 |
+
# some special case where the deduction rule has a well known name.
|
| 177 |
+
r2name = {
|
| 178 |
+
'r32': '(SSS)',
|
| 179 |
+
'r33': '(SAS)',
|
| 180 |
+
'r34': '(Similar Triangles)',
|
| 181 |
+
'r35': '(Similar Triangles)',
|
| 182 |
+
'r36': '(ASA)',
|
| 183 |
+
'r37': '(ASA)',
|
| 184 |
+
'r38': '(Similar Triangles)',
|
| 185 |
+
'r39': '(Similar Triangles)',
|
| 186 |
+
'r40': '(Congruent Triangles)',
|
| 187 |
+
'a00': '(Distance chase)',
|
| 188 |
+
'a01': '(Ratio chase)',
|
| 189 |
+
'a02': '(Angle chase)',
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
solution += '\n\n * Proof steps:\n'
|
| 193 |
+
for i, step in enumerate(proof_steps):
|
| 194 |
+
_, [con] = step
|
| 195 |
+
nl = proof_step_string(step, refs, last_step=i == len(proof_steps) - 1)
|
| 196 |
+
rule_name = r2name.get(con.rule_name, '')
|
| 197 |
+
nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ')
|
| 198 |
+
solution += '{:03}. '.format(i + 1) + nl + '\n'
|
| 199 |
+
|
| 200 |
+
solution += '==========================\n'
|
| 201 |
+
logging.info(solution)
|
| 202 |
+
if out_file:
|
| 203 |
+
with open(out_file, 'w') as f:
|
| 204 |
+
f.write(solution)
|
| 205 |
+
logging.info('Solution written to %s.', out_file)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference:
|
| 209 |
+
lm.parse_gin_configuration(
|
| 210 |
+
_GIN_FILE.value, _GIN_PARAM.value, gin_paths=_GIN_SEARCH_PATHS.value
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
return lm.LanguageModelInference(vocab_path, ckpt_init, mode='beam_search')
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def run_ddar(g: gh.Graph, p: pr.Problem, out_file: str) -> bool:
|
| 217 |
+
"""Run DD+AR.
|
| 218 |
+
|
| 219 |
+
Args:
|
| 220 |
+
g: gh.Graph object, containing the proof state.
|
| 221 |
+
p: pr.Problem object, containing the problem statement.
|
| 222 |
+
out_file: path to output file if solution is found.
|
| 223 |
+
|
| 224 |
+
Returns:
|
| 225 |
+
Boolean, whether DD+AR finishes successfully.
|
| 226 |
+
"""
|
| 227 |
+
ddar.solve(g, RULES, p, max_level=1000)
|
| 228 |
+
|
| 229 |
+
goal_args = g.names2nodes(p.goal.args)
|
| 230 |
+
if not g.check(p.goal.name, goal_args):
|
| 231 |
+
logging.info('DD+AR failed to solve the problem.')
|
| 232 |
+
return False
|
| 233 |
+
|
| 234 |
+
write_solution(g, p, out_file)
|
| 235 |
+
|
| 236 |
+
gh.nm.draw(
|
| 237 |
+
g.type2nodes[gh.Point],
|
| 238 |
+
g.type2nodes[gh.Line],
|
| 239 |
+
g.type2nodes[gh.Circle],
|
| 240 |
+
g.type2nodes[gh.Segment])
|
| 241 |
+
return True
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def translate_constrained_to_constructive(
|
| 245 |
+
point: str, name: str, args: list[str]
|
| 246 |
+
) -> tuple[str, list[str]]:
|
| 247 |
+
"""Translate a predicate from constraint-based to construction-based.
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
point: str: name of the new point
|
| 251 |
+
name: str: name of the predicate, e.g., perp, para, etc.
|
| 252 |
+
args: list[str]: list of predicate args.
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
(name, args): translated to constructive predicate.
|
| 256 |
+
"""
|
| 257 |
+
if name in ['T', 'perp']:
|
| 258 |
+
a, b, c, d = args
|
| 259 |
+
if point in [c, d]:
|
| 260 |
+
a, b, c, d = c, d, a, b
|
| 261 |
+
if point == b:
|
| 262 |
+
a, b = b, a
|
| 263 |
+
if point == d:
|
| 264 |
+
c, d = d, c
|
| 265 |
+
if a == c and a == point:
|
| 266 |
+
return 'on_dia', [a, b, d]
|
| 267 |
+
return 'on_tline', [a, b, c, d]
|
| 268 |
+
|
| 269 |
+
elif name in ['P', 'para']:
|
| 270 |
+
a, b, c, d = args
|
| 271 |
+
if point in [c, d]:
|
| 272 |
+
a, b, c, d = c, d, a, b
|
| 273 |
+
if point == b:
|
| 274 |
+
a, b = b, a
|
| 275 |
+
return 'on_pline', [a, b, c, d]
|
| 276 |
+
|
| 277 |
+
elif name in ['D', 'cong']:
|
| 278 |
+
a, b, c, d = args
|
| 279 |
+
if point in [c, d]:
|
| 280 |
+
a, b, c, d = c, d, a, b
|
| 281 |
+
if point == b:
|
| 282 |
+
a, b = b, a
|
| 283 |
+
if point == d:
|
| 284 |
+
c, d = d, c
|
| 285 |
+
if a == c and a == point:
|
| 286 |
+
return 'on_bline', [a, b, d]
|
| 287 |
+
if b in [c, d]:
|
| 288 |
+
if b == d:
|
| 289 |
+
c, d = d, c # pylint: disable=unused-variable
|
| 290 |
+
return 'on_circle', [a, b, d]
|
| 291 |
+
return 'eqdistance', [a, b, c, d]
|
| 292 |
+
|
| 293 |
+
elif name in ['C', 'coll']:
|
| 294 |
+
a, b, c = args
|
| 295 |
+
if point == b:
|
| 296 |
+
a, b = b, a
|
| 297 |
+
if point == c:
|
| 298 |
+
a, b, c = c, a, b
|
| 299 |
+
return 'on_line', [a, b, c]
|
| 300 |
+
|
| 301 |
+
elif name in ['^', 'eqangle']:
|
| 302 |
+
a, b, c, d, e, f = args
|
| 303 |
+
|
| 304 |
+
if point in [d, e, f]:
|
| 305 |
+
a, b, c, d, e, f = d, e, f, a, b, c
|
| 306 |
+
|
| 307 |
+
x, b, y, c, d = b, c, e, d, f
|
| 308 |
+
if point == b:
|
| 309 |
+
a, b, c, d = b, a, d, c
|
| 310 |
+
|
| 311 |
+
if point == d and x == y: # x p x b = x c x p
|
| 312 |
+
return 'angle_bisector', [point, b, x, c]
|
| 313 |
+
|
| 314 |
+
if point == x:
|
| 315 |
+
return 'eqangle3', [x, a, b, y, c, d]
|
| 316 |
+
|
| 317 |
+
return 'on_aline', [a, x, b, c, y, d]
|
| 318 |
+
|
| 319 |
+
elif name in ['cyclic', 'O']:
|
| 320 |
+
a, b, c = [x for x in args if x != point]
|
| 321 |
+
return 'on_circum', [point, a, b, c]
|
| 322 |
+
|
| 323 |
+
return name, args
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def check_valid_args(name: str, args: list[str]) -> bool:
|
| 327 |
+
"""Check whether a predicate is grammarically correct.
|
| 328 |
+
|
| 329 |
+
Args:
|
| 330 |
+
name: str: name of the predicate
|
| 331 |
+
args: list[str]: args of the predicate
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
bool: whether the predicate arg count is valid.
|
| 335 |
+
"""
|
| 336 |
+
if name == 'perp':
|
| 337 |
+
if len(args) != 4:
|
| 338 |
+
return False
|
| 339 |
+
a, b, c, d = args
|
| 340 |
+
if len({a, b}) < 2:
|
| 341 |
+
return False
|
| 342 |
+
if len({c, d}) < 2:
|
| 343 |
+
return False
|
| 344 |
+
elif name == 'para':
|
| 345 |
+
if len(args) != 4:
|
| 346 |
+
return False
|
| 347 |
+
a, b, c, d = args
|
| 348 |
+
if len({a, b, c, d}) < 4:
|
| 349 |
+
return False
|
| 350 |
+
elif name == 'cong':
|
| 351 |
+
if len(args) != 4:
|
| 352 |
+
return False
|
| 353 |
+
a, b, c, d = args
|
| 354 |
+
if len({a, b}) < 2:
|
| 355 |
+
return False
|
| 356 |
+
if len({c, d}) < 2:
|
| 357 |
+
return False
|
| 358 |
+
elif name == 'coll':
|
| 359 |
+
if len(args) != 3:
|
| 360 |
+
return False
|
| 361 |
+
a, b, c = args
|
| 362 |
+
if len({a, b, c}) < 3:
|
| 363 |
+
return False
|
| 364 |
+
elif name == 'cyclic':
|
| 365 |
+
if len(args) != 4:
|
| 366 |
+
return False
|
| 367 |
+
a, b, c, d = args
|
| 368 |
+
if len({a, b, c, d}) < 4:
|
| 369 |
+
return False
|
| 370 |
+
elif name == 'eqangle':
|
| 371 |
+
if len(args) != 8:
|
| 372 |
+
return False
|
| 373 |
+
a, b, c, d, e, f, g, h = args
|
| 374 |
+
if len({a, b, c, d}) < 3:
|
| 375 |
+
return False
|
| 376 |
+
if len({e, f, g, h}) < 3:
|
| 377 |
+
return False
|
| 378 |
+
return True
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def try_translate_constrained_to_construct(string: str, g: gh.Graph) -> str:
|
| 382 |
+
"""Whether a string of aux construction can be constructed.
|
| 383 |
+
|
| 384 |
+
Args:
|
| 385 |
+
string: str: the string describing aux construction.
|
| 386 |
+
g: gh.Graph: the current proof state.
|
| 387 |
+
|
| 388 |
+
Returns:
|
| 389 |
+
str: whether this construction is valid. If not, starts with "ERROR:".
|
| 390 |
+
"""
|
| 391 |
+
if string[-1] != ';':
|
| 392 |
+
return 'ERROR: must end with ;'
|
| 393 |
+
|
| 394 |
+
logging.info(f'PID={os.getpid()}: !! try_translate_constrained_to_construct: string=%s', string)
|
| 395 |
+
|
| 396 |
+
# sometimes the LM may return ill-formed result with multiple colons.
|
| 397 |
+
# example:
|
| 398 |
+
#
|
| 399 |
+
# napoleon2
|
| 400 |
+
# a1 a2 a3 = triangle; c3 = s_angle a1 a2 c3 30, s_angle a2 a1 c3 150; c1 = s_angle a2 a3 c1 30, s_angle a3 a2 c1 150; c2 = s_angle a3 a1 c2 30, s_angle a1 a3 c2 150 ? cong c1 c2 c1 c3
|
| 401 |
+
#
|
| 402 |
+
# in the process,
|
| 403 |
+
# I0210 17:58:01.513668 140016515833856 alphageometry.py:550] Decoding from {S} a : ; b : ; c : ; d : ^ a d a b 5. pi / 6. 00 ^ b d b a 1. pi / 6. 01 ; e : ^ b e b c 5. pi / 6. 02 ^ c e c b 1. pi / 6. 03 ; f : ^ a f a c 1. pi / 6. 04 ^ c f c a 5. pi / 6. 05 ? D e f e d {F1} x00 g : C a b g 06 D a g b g 07 ; x00 h : C c b h 08 D c h b h 09 ; x00
|
| 404 |
+
# I0210 18:01:38.182158 140016515833856 alphageometry.py:384] !! try_translate_constrained_to_construct: string=i : C a c i 10 D a i c i 11 ? V d f {F1} x00 j : D g j h j 12 D h j i j 13 ;
|
| 405 |
+
|
| 406 |
+
#XXX
|
| 407 |
+
# str_parts = string.split(' : ')
|
| 408 |
+
# if len(str_parts) != 2:
|
| 409 |
+
# return f'ERROR: string has multiple colons: |{string}|'
|
| 410 |
+
mch = re.match('(.*?)( \? | \. \{)', string)
|
| 411 |
+
if mch :
|
| 412 |
+
strFixed = mch.group(1) + ';'
|
| 413 |
+
logging.info(f'ID={os.getpid()}: Bad LM output: {string}. Changed to {strFixed}')
|
| 414 |
+
string = strFixed
|
| 415 |
+
|
| 416 |
+
# sometimes the constraint in string is empty:
|
| 417 |
+
# 0407 17:11:35.470240 126383800963072 alphageometry.py:394] !! try_translate_constrained_to_construct: string=j : ;
|
| 418 |
+
hdprem = string.split(' : ')
|
| 419 |
+
if len(hdprem) !=2 or hdprem[1].strip()==';' :
|
| 420 |
+
logging.info(f'ID={os.getpid()}: Bad LM output: {string}. ERROR')
|
| 421 |
+
return f'ERROR: Bad LM output: {string}'
|
| 422 |
+
head, prem_str = hdprem
|
| 423 |
+
point = head.strip()
|
| 424 |
+
|
| 425 |
+
if len(point) != 1 or point == ' ':
|
| 426 |
+
return f'ERROR: invalid point name {point}'
|
| 427 |
+
|
| 428 |
+
existing_points = [p.name for p in g.all_points()]
|
| 429 |
+
if point in existing_points:
|
| 430 |
+
return f'ERROR: point {point} already exists.'
|
| 431 |
+
|
| 432 |
+
prem_toks = prem_str.split()[:-1] # remove the EOS ' ;'
|
| 433 |
+
prems = [[]]
|
| 434 |
+
|
| 435 |
+
for i, tok in enumerate(prem_toks):
|
| 436 |
+
if tok.isdigit():
|
| 437 |
+
if i < len(prem_toks) - 1:
|
| 438 |
+
prems.append([])
|
| 439 |
+
else:
|
| 440 |
+
prems[-1].append(tok)
|
| 441 |
+
|
| 442 |
+
if len(prems) > 2:
|
| 443 |
+
return 'ERROR: there cannot be more than two predicates.'
|
| 444 |
+
|
| 445 |
+
clause_txt = point + ' = '
|
| 446 |
+
constructions = []
|
| 447 |
+
|
| 448 |
+
for prem in prems:
|
| 449 |
+
name, *args = prem
|
| 450 |
+
|
| 451 |
+
if point not in args:
|
| 452 |
+
return f'ERROR: {point} not found in predicate args.'
|
| 453 |
+
|
| 454 |
+
if not check_valid_args(pt.map_symbol(name), args):
|
| 455 |
+
return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args)
|
| 456 |
+
|
| 457 |
+
for a in args:
|
| 458 |
+
if a != point and a not in existing_points:
|
| 459 |
+
return f'ERROR: point {a} does not exist.'
|
| 460 |
+
|
| 461 |
+
try:
|
| 462 |
+
name, args = translate_constrained_to_constructive(point, name, args)
|
| 463 |
+
except: # pylint: disable=bare-except
|
| 464 |
+
return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args)
|
| 465 |
+
|
| 466 |
+
if name == 'on_aline':
|
| 467 |
+
if args.count(point) > 1:
|
| 468 |
+
return f'ERROR: on_aline involves twice {point}'
|
| 469 |
+
|
| 470 |
+
constructions += [name + ' ' + ' '.join(args)]
|
| 471 |
+
|
| 472 |
+
clause_txt += ', '.join(constructions)
|
| 473 |
+
clause = pr.Clause.from_txt(clause_txt)
|
| 474 |
+
|
| 475 |
+
try:
|
| 476 |
+
g.copy().add_clause(clause, 0, DEFINITIONS)
|
| 477 |
+
except: # pylint: disable=bare-except
|
| 478 |
+
return 'ERROR: ' + traceback.format_exc()
|
| 479 |
+
|
| 480 |
+
return clause_txt
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def insert_aux_to_premise(pstring: str, auxstring: str) -> str:
|
| 484 |
+
"""Insert auxiliary constructs from proof to premise.
|
| 485 |
+
|
| 486 |
+
Args:
|
| 487 |
+
pstring: str: describing the problem to solve.
|
| 488 |
+
auxstring: str: describing the auxiliar construction.
|
| 489 |
+
|
| 490 |
+
Returns:
|
| 491 |
+
str: new pstring with auxstring inserted before the conclusion.
|
| 492 |
+
"""
|
| 493 |
+
setup, goal = pstring.split(' ? ')
|
| 494 |
+
return setup + '; ' + auxstring + ' ? ' + goal
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class BeamQueue:
|
| 498 |
+
"""Keep only the top k objects according to their values."""
|
| 499 |
+
|
| 500 |
+
def __init__(self, max_size: int = 512):
|
| 501 |
+
self.queue = []
|
| 502 |
+
self.max_size = max_size
|
| 503 |
+
|
| 504 |
+
def add(self, node: object, val: float) -> None:
|
| 505 |
+
"""Add a new node to this queue."""
|
| 506 |
+
|
| 507 |
+
if len(self.queue) < self.max_size:
|
| 508 |
+
self.queue.append((val, node))
|
| 509 |
+
return
|
| 510 |
+
|
| 511 |
+
# Find the minimum node:
|
| 512 |
+
min_idx, (min_val, _) = min(enumerate(self.queue), key=lambda x: x[1])
|
| 513 |
+
|
| 514 |
+
# replace it if the new node has higher value.
|
| 515 |
+
if val > min_val:
|
| 516 |
+
self.queue[min_idx] = (val, node)
|
| 517 |
+
|
| 518 |
+
def __iter__(self):
|
| 519 |
+
for val, node in self.queue:
|
| 520 |
+
yield val, node
|
| 521 |
+
|
| 522 |
+
def __len__(self) -> int:
|
| 523 |
+
return len(self.queue)
|
| 524 |
+
|
| 525 |
+
def bqsearch_init(worker_id):
|
| 526 |
+
# When using spawn or forkserver start method for multiprocessing.Pool, need to re-initialize
|
| 527 |
+
flags.FLAGS(sys.argv)
|
| 528 |
+
logging.use_absl_handler()
|
| 529 |
+
logging.set_verbosity(logging.INFO)
|
| 530 |
+
sys.setrecursionlimit(10000)
|
| 531 |
+
|
| 532 |
+
# Global variables initialized in main(). Need to re-initialize
|
| 533 |
+
#
|
| 534 |
+
# definitions of terms used in our domain-specific language.
|
| 535 |
+
global DEFINITIONS, RULES
|
| 536 |
+
DEFINITIONS = pr.Definition.from_txt_file(_DEFS_FILE.value, to_dict=True)
|
| 537 |
+
# load inference rules used in DD.
|
| 538 |
+
RULES = pr.Theorem.from_txt_file(_RULES_FILE.value, to_dict=True)
|
| 539 |
+
|
| 540 |
+
wkrpid = os.getpid()
|
| 541 |
+
logging.info('Worker %d initializing. PID=%d', worker_id, wkrpid)
|
| 542 |
+
|
| 543 |
+
if 'CUDA_VISIBLE_DEVICES' in os.environ and os.environ['CUDA_VISIBLE_DEVICES'].strip():
|
| 544 |
+
os.environ['CUDA_VISIBLE_DEVICES']=f"{worker_id}"
|
| 545 |
+
logging.info('Worker %d: CUDA_VISIBLE_DEVICES=%s', worker_id, os.environ['CUDA_VISIBLE_DEVICES'])
|
| 546 |
+
|
| 547 |
+
global model
|
| 548 |
+
model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value)
|
| 549 |
+
return wkrpid
|
| 550 |
+
|
| 551 |
+
def bqsearch(i_nd, srch_inputs, out_file) -> tuple[int, bool, list]: # ( iNode, solved, [ (node, score) ] )
|
| 552 |
+
pid = os.getpid()
|
| 553 |
+
logging.info(f'Worker PID={pid} called for beam search node {i_nd}')
|
| 554 |
+
|
| 555 |
+
prev_score, (g, string, pstring) = srch_inputs
|
| 556 |
+
logging.info(f'Worker PID={pid}: Beam-searching and Decoding from {string}')
|
| 557 |
+
outputs = model.beam_decode(string, eos_tokens=[';'])
|
| 558 |
+
|
| 559 |
+
# translate lm output to the constructive language.
|
| 560 |
+
# so that we can update the graph representing proof states:
|
| 561 |
+
translations = [
|
| 562 |
+
try_translate_constrained_to_construct(o, g)
|
| 563 |
+
for o in outputs['seqs_str']
|
| 564 |
+
]
|
| 565 |
+
|
| 566 |
+
# couple the lm outputs with its translations
|
| 567 |
+
candidates = zip(outputs['seqs_str'], translations, outputs['scores'])
|
| 568 |
+
|
| 569 |
+
# bring the highest scoring candidate first
|
| 570 |
+
candidates = reversed(list(candidates))
|
| 571 |
+
|
| 572 |
+
ret = []
|
| 573 |
+
for lm_out, translation, score in candidates:
|
| 574 |
+
logging.info(f'Worker PID={pid}: LM output (score={score}): "{lm_out}"')
|
| 575 |
+
logging.info(f'Worker PID={pid}: Translation: "{translation}"')
|
| 576 |
+
|
| 577 |
+
if translation.startswith('ERROR:'):
|
| 578 |
+
# the construction is invalid.
|
| 579 |
+
continue
|
| 580 |
+
|
| 581 |
+
# Update the constructive statement of the problem with the aux point:
|
| 582 |
+
candidate_pstring = insert_aux_to_premise(pstring, translation)
|
| 583 |
+
|
| 584 |
+
#XXX
|
| 585 |
+
logging.info(f'Worker PID={pid}: string=|{string}| lm_out=|{lm_out}|')
|
| 586 |
+
logging.info(f'Worker PID={pid}: Solving: "{candidate_pstring}"')
|
| 587 |
+
p_new = pr.Problem.from_txt(candidate_pstring)
|
| 588 |
+
|
| 589 |
+
# This is the new proof state graph representation:
|
| 590 |
+
g_new, _ = gh.Graph.build_problem(p_new, DEFINITIONS)
|
| 591 |
+
|
| 592 |
+
try:
|
| 593 |
+
if run_ddar(g_new, p_new, out_file):
|
| 594 |
+
logging.info(f'Worker PID={pid}: Solved.')
|
| 595 |
+
return (i_nd, True, None)
|
| 596 |
+
except Exception as e:
|
| 597 |
+
logging.info(f'Worker PID={pid}: Error in run_ddar: {e}')
|
| 598 |
+
|
| 599 |
+
# Add the candidate to the beam queue.
|
| 600 |
+
ret.append( [
|
| 601 |
+
# The string for the new node is old_string + lm output +
|
| 602 |
+
# the special token asking for a new auxiliary point ' x00':
|
| 603 |
+
# node
|
| 604 |
+
(g_new, string + ' ' + lm_out + ' x00', candidate_pstring),
|
| 605 |
+
# the score of each node is sum of score of all nodes
|
| 606 |
+
# on the path to itself. For beam search, there is no need to
|
| 607 |
+
# normalize according to path length because all nodes in beam
|
| 608 |
+
# is of the same path length.
|
| 609 |
+
# val
|
| 610 |
+
prev_score + score ]
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
logging.info(f'Worker PID={pid} beam search node {i_nd}: returning')
|
| 614 |
+
return (i_nd, False, ret)
|
| 615 |
+
|
| 616 |
+
def run_alphageometry(
|
| 617 |
+
#XX model: lm.LanguageModelInference,
|
| 618 |
+
p: pr.Problem,
|
| 619 |
+
search_depth: int,
|
| 620 |
+
beam_size: int,
|
| 621 |
+
out_file: str,
|
| 622 |
+
) -> bool:
|
| 623 |
+
"""Simplified code to run AlphaGeometry proof search.
|
| 624 |
+
|
| 625 |
+
We removed all optimizations that are infrastructure-dependent, e.g.
|
| 626 |
+
parallelized model inference on multi GPUs,
|
| 627 |
+
parallelized DD+AR on multiple CPUs,
|
| 628 |
+
parallel execution of LM and DD+AR,
|
| 629 |
+
shared pool of CPU workers across different problems, etc.
|
| 630 |
+
|
| 631 |
+
Many other speed optimizations and abstractions are also removed to
|
| 632 |
+
better present the core structure of the proof search.
|
| 633 |
+
|
| 634 |
+
Args:
|
| 635 |
+
model: Interface with inference-related endpoints to JAX's model.
|
| 636 |
+
p: pr.Problem object describing the problem to solve.
|
| 637 |
+
search_depth: max proof search depth.
|
| 638 |
+
beam_size: beam size of the proof search.
|
| 639 |
+
out_file: path to output file if solution is found.
|
| 640 |
+
|
| 641 |
+
Returns:
|
| 642 |
+
boolean of whether this is solved.
|
| 643 |
+
"""
|
| 644 |
+
# translate the problem to a string of grammar that the LM is trained on.
|
| 645 |
+
string = p.setup_str_from_problem(DEFINITIONS)
|
| 646 |
+
# special tokens prompting the LM to generate auxiliary points.
|
| 647 |
+
string += ' {F1} x00'
|
| 648 |
+
# the graph to represent the proof state.
|
| 649 |
+
g, _ = gh.Graph.build_problem(p, DEFINITIONS)
|
| 650 |
+
|
| 651 |
+
# First we run the symbolic engine DD+AR:
|
| 652 |
+
if run_ddar(g, p, out_file):
|
| 653 |
+
return True
|
| 654 |
+
|
| 655 |
+
# ?? when pickling graph for some problems, the default recursion limit 1000 is not enough,
|
| 656 |
+
# got 'maximum recursion depth exceeded while pickling an object' error
|
| 657 |
+
sys.setrecursionlimit(10000)
|
| 658 |
+
|
| 659 |
+
# beam search for the proof
|
| 660 |
+
# each node in the search tree is a 3-tuple:
|
| 661 |
+
# (<graph representation of proof state>,
|
| 662 |
+
# <string for LM to decode from>,
|
| 663 |
+
# <original problem string>)
|
| 664 |
+
beam_queue = BeamQueue(max_size=beam_size)
|
| 665 |
+
# originally the beam search tree starts with a single node (a 3-tuple):
|
| 666 |
+
beam_queue.add(
|
| 667 |
+
node=(g, string, p.txt()), val=0.0 # value of the root node is simply 0.
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
pool = None
|
| 671 |
+
if _N_WORKSERS.value == 1:
|
| 672 |
+
bqsearch_init(0)
|
| 673 |
+
else:
|
| 674 |
+
# Default is 'fork' on Linux, does not work with CUDA. Need to use 'spawn' or 'forkserver'
|
| 675 |
+
multiprocessing.set_start_method('spawn')
|
| 676 |
+
pool = multiprocessing.Pool(_N_WORKSERS.value)
|
| 677 |
+
|
| 678 |
+
logging.info("Initializing workers")
|
| 679 |
+
wkrpids = pool.map(bqsearch_init, range(_N_WORKSERS.value))
|
| 680 |
+
logging.info("Worker PIDs: " + str(wkrpids))
|
| 681 |
+
|
| 682 |
+
for depth in range(search_depth):
|
| 683 |
+
logging.info(
|
| 684 |
+
'Depth %s. There are %i nodes to expand:', depth, len(beam_queue)
|
| 685 |
+
)
|
| 686 |
+
for _, (_, string, _) in beam_queue:
|
| 687 |
+
logging.info(string)
|
| 688 |
+
|
| 689 |
+
new_queue = BeamQueue(max_size=beam_size) # to replace beam_queue.
|
| 690 |
+
if _N_WORKSERS.value==1:
|
| 691 |
+
for i, srch_inputs in enumerate(beam_queue):
|
| 692 |
+
_, solved, res = bqsearch(i, srch_inputs, out_file)
|
| 693 |
+
if solved:
|
| 694 |
+
return True
|
| 695 |
+
for node, val in res:
|
| 696 |
+
# Add the candidate to the beam queue.
|
| 697 |
+
new_queue.add(node, val)
|
| 698 |
+
# Note that the queue only maintain at most beam_size nodes
|
| 699 |
+
# so this new node might possibly be dropped depending on its value.
|
| 700 |
+
else:
|
| 701 |
+
jobs = [pool.apply_async(bqsearch, (i, srch_inputs, out_file)) for i, srch_inputs in enumerate(beam_queue)]
|
| 702 |
+
|
| 703 |
+
n_done = 0
|
| 704 |
+
while n_done < len(beam_queue):
|
| 705 |
+
for i, jobres in enumerate(jobs):
|
| 706 |
+
if jobres and jobres.ready():
|
| 707 |
+
n_done += 1
|
| 708 |
+
jobs[i] = None
|
| 709 |
+
_, solved, res = jobres.get()
|
| 710 |
+
if solved:
|
| 711 |
+
# Clean up resources
|
| 712 |
+
pool.terminate()
|
| 713 |
+
pool.join()
|
| 714 |
+
return True
|
| 715 |
+
for node, val in res:
|
| 716 |
+
# Add the candidate to the beam queue.
|
| 717 |
+
new_queue.add(node, val)
|
| 718 |
+
# Note that the queue only maintain at most beam_size nodes
|
| 719 |
+
# so this new node might possibly be dropped depending on its value.
|
| 720 |
+
time.sleep(1) # Adjust wait time as needed
|
| 721 |
+
|
| 722 |
+
# replace the old queue with new queue before the new proof search depth.
|
| 723 |
+
beam_queue = new_queue
|
| 724 |
+
|
| 725 |
+
# Clean up resources
|
| 726 |
+
if pool:
|
| 727 |
+
pool.terminate()
|
| 728 |
+
pool.join()
|
| 729 |
+
return False
|
| 730 |
+
|
| 731 |
+
def main(_):
|
| 732 |
+
global DEFINITIONS
|
| 733 |
+
global RULES
|
| 734 |
+
|
| 735 |
+
# definitions of terms used in our domain-specific language.
|
| 736 |
+
DEFINITIONS = pr.Definition.from_txt_file(_DEFS_FILE.value, to_dict=True)
|
| 737 |
+
# load inference rules used in DD.
|
| 738 |
+
RULES = pr.Theorem.from_txt_file(_RULES_FILE.value, to_dict=True)
|
| 739 |
+
|
| 740 |
+
# when using the language model,
|
| 741 |
+
# point names will be renamed to alphabetical a, b, c, d, e, ...
|
| 742 |
+
# instead of staying with their original names,
|
| 743 |
+
# in order to match the synthetic training data generation.
|
| 744 |
+
need_rename = _MODE.value != 'ddar'
|
| 745 |
+
|
| 746 |
+
# load problems from the problems_file,
|
| 747 |
+
problems = pr.Problem.from_txt_file(
|
| 748 |
+
_PROBLEMS_FILE.value, to_dict=True, translate=need_rename
|
| 749 |
+
)
|
| 750 |
+
|
| 751 |
+
if _PROBLEM_NAME.value not in problems:
|
| 752 |
+
raise ValueError(
|
| 753 |
+
f'Problem name `{_PROBLEM_NAME.value}` '
|
| 754 |
+
+ f'not found in `{_PROBLEMS_FILE.value}`'
|
| 755 |
+
)
|
| 756 |
+
|
| 757 |
+
this_problem = problems[_PROBLEM_NAME.value]
|
| 758 |
+
|
| 759 |
+
if _MODE.value == 'ddar':
|
| 760 |
+
g, _ = gh.Graph.build_problem(this_problem, DEFINITIONS)
|
| 761 |
+
run_ddar(g, this_problem, _OUT_FILE.value)
|
| 762 |
+
|
| 763 |
+
elif _MODE.value == 'alphageometry':
|
| 764 |
+
#XX model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value)
|
| 765 |
+
run_alphageometry(
|
| 766 |
+
#XX model,
|
| 767 |
+
this_problem,
|
| 768 |
+
_SEARCH_DEPTH.value,
|
| 769 |
+
_BEAM_SIZE.value,
|
| 770 |
+
_OUT_FILE.value,
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
else:
|
| 774 |
+
raise ValueError(f'Unknown FLAGS.mode: {_MODE.value}')
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
if __name__ == '__main__':
|
| 778 |
+
app.run(main)
|
backend/core/ag4masses/alphageometry/alphageometry_test.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Unit tests for alphageometry.py."""
|
| 17 |
+
|
| 18 |
+
import unittest
|
| 19 |
+
|
| 20 |
+
from absl.testing import absltest
|
| 21 |
+
import alphageometry
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class AlphaGeometryTest(unittest.TestCase):
|
| 25 |
+
|
| 26 |
+
def test_translate_constrained_to_constructive(self):
|
| 27 |
+
self.assertEqual(
|
| 28 |
+
alphageometry.translate_constrained_to_constructive(
|
| 29 |
+
'd', 'T', list('addb')
|
| 30 |
+
),
|
| 31 |
+
('on_dia', ['d', 'b', 'a']),
|
| 32 |
+
)
|
| 33 |
+
self.assertEqual(
|
| 34 |
+
alphageometry.translate_constrained_to_constructive(
|
| 35 |
+
'd', 'T', list('adbc')
|
| 36 |
+
),
|
| 37 |
+
('on_tline', ['d', 'a', 'b', 'c']),
|
| 38 |
+
)
|
| 39 |
+
self.assertEqual(
|
| 40 |
+
alphageometry.translate_constrained_to_constructive(
|
| 41 |
+
'd', 'P', list('bcda')
|
| 42 |
+
),
|
| 43 |
+
('on_pline', ['d', 'a', 'b', 'c']),
|
| 44 |
+
)
|
| 45 |
+
self.assertEqual(
|
| 46 |
+
alphageometry.translate_constrained_to_constructive(
|
| 47 |
+
'd', 'D', list('bdcd')
|
| 48 |
+
),
|
| 49 |
+
('on_bline', ['d', 'c', 'b']),
|
| 50 |
+
)
|
| 51 |
+
self.assertEqual(
|
| 52 |
+
alphageometry.translate_constrained_to_constructive(
|
| 53 |
+
'd', 'D', list('bdcb')
|
| 54 |
+
),
|
| 55 |
+
('on_circle', ['d', 'b', 'c']),
|
| 56 |
+
)
|
| 57 |
+
self.assertEqual(
|
| 58 |
+
alphageometry.translate_constrained_to_constructive(
|
| 59 |
+
'd', 'D', list('bacd')
|
| 60 |
+
),
|
| 61 |
+
('eqdistance', ['d', 'c', 'b', 'a']),
|
| 62 |
+
)
|
| 63 |
+
self.assertEqual(
|
| 64 |
+
alphageometry.translate_constrained_to_constructive(
|
| 65 |
+
'd', 'C', list('bad')
|
| 66 |
+
),
|
| 67 |
+
('on_line', ['d', 'b', 'a']),
|
| 68 |
+
)
|
| 69 |
+
self.assertEqual(
|
| 70 |
+
alphageometry.translate_constrained_to_constructive(
|
| 71 |
+
'd', 'C', list('bad')
|
| 72 |
+
),
|
| 73 |
+
('on_line', ['d', 'b', 'a']),
|
| 74 |
+
)
|
| 75 |
+
self.assertEqual(
|
| 76 |
+
alphageometry.translate_constrained_to_constructive(
|
| 77 |
+
'd', 'O', list('abcd')
|
| 78 |
+
),
|
| 79 |
+
('on_circum', ['d', 'a', 'b', 'c']),
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
def test_insert_aux_to_premise(self):
|
| 83 |
+
pstring = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c' # pylint: disable=line-too-long
|
| 84 |
+
auxstring = 'e = on_line e a c, on_line e b d'
|
| 85 |
+
|
| 86 |
+
target = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c' # pylint: disable=line-too-long
|
| 87 |
+
self.assertEqual(
|
| 88 |
+
alphageometry.insert_aux_to_premise(pstring, auxstring), target
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
def test_beam_queue(self):
|
| 92 |
+
beam_queue = alphageometry.BeamQueue(max_size=2)
|
| 93 |
+
|
| 94 |
+
beam_queue.add('a', 1)
|
| 95 |
+
beam_queue.add('b', 2)
|
| 96 |
+
beam_queue.add('c', 3)
|
| 97 |
+
|
| 98 |
+
beam_queue = list(beam_queue)
|
| 99 |
+
self.assertEqual(beam_queue, [(3, 'c'), (2, 'b')])
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == '__main__':
|
| 103 |
+
absltest.main()
|
backend/core/ag4masses/alphageometry/ar.py
ADDED
|
@@ -0,0 +1,752 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Implementing Algebraic Reasoning (AR)."""
|
| 17 |
+
|
| 18 |
+
from collections import defaultdict # pylint: disable=g-importing-member
|
| 19 |
+
from fractions import Fraction as frac # pylint: disable=g-importing-member
|
| 20 |
+
from typing import Any, Generator
|
| 21 |
+
|
| 22 |
+
import geometry as gm
|
| 23 |
+
import numpy as np
|
| 24 |
+
import problem as pr
|
| 25 |
+
from scipy import optimize
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class InfQuotientError(Exception):
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _gcd(x: int, y: int) -> int:
|
| 33 |
+
while y:
|
| 34 |
+
x, y = y, x % y
|
| 35 |
+
return x
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def simplify(n: int, d: int) -> tuple[int, int]:
|
| 39 |
+
g = _gcd(n, d)
|
| 40 |
+
return (n // g, d // g)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# maximum denominator for a fraction.
|
| 44 |
+
MAX_DENOMINATOR = 1000000
|
| 45 |
+
|
| 46 |
+
# tolerance for fraction approximation
|
| 47 |
+
TOL = 1e-15
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_quotient(v: float) -> tuple[int, int]:
|
| 51 |
+
n = v
|
| 52 |
+
d = 1
|
| 53 |
+
while abs(n - round(n)) > TOL:
|
| 54 |
+
d += 1
|
| 55 |
+
n += v
|
| 56 |
+
if d > MAX_DENOMINATOR:
|
| 57 |
+
e = InfQuotientError(v)
|
| 58 |
+
raise e
|
| 59 |
+
|
| 60 |
+
n = int(round(n))
|
| 61 |
+
return simplify(n, d)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def fix_v(v: float) -> float:
|
| 65 |
+
n, d = get_quotient(v)
|
| 66 |
+
return n / d
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def fix(e: dict[str, float]) -> dict[str, float]:
|
| 70 |
+
return {k: fix_v(v) for k, v in e.items()}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def frac_string(f: frac) -> str:
|
| 74 |
+
n, d = get_quotient(f)
|
| 75 |
+
return f'{n}/{d}'
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def hashed(e: dict[str, float]) -> tuple[tuple[str, float], ...]:
|
| 79 |
+
return tuple(sorted(list(e.items())))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def is_zero(e: dict[str, float]) -> bool:
|
| 83 |
+
return len(strip(e)) == 0 # pylint: disable=g-explicit-length-test
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def strip(e: dict[str, float]) -> dict[str, float]:
|
| 87 |
+
return {v: c for v, c in e.items() if c != 0}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def plus(e1: dict[str, float], e2: dict[str, float]) -> dict[str, float]:
|
| 91 |
+
e = dict(e1)
|
| 92 |
+
for v, c in e2.items():
|
| 93 |
+
if v in e:
|
| 94 |
+
e[v] += c
|
| 95 |
+
else:
|
| 96 |
+
e[v] = c
|
| 97 |
+
return strip(e)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def plus_all(*es: list[dict[str, float]]) -> dict[str, float]:
|
| 101 |
+
result = {}
|
| 102 |
+
for e in es:
|
| 103 |
+
result = plus(result, e)
|
| 104 |
+
return result
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def mult(e: dict[str, float], m: float) -> dict[str, float]:
|
| 108 |
+
return {v: m * c for v, c in e.items()}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def minus(e1: dict[str, float], e2: dict[str, float]) -> dict[str, float]:
|
| 112 |
+
return plus(e1, mult(e2, -1))
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def div(e1: dict[str, float], e2: dict[str, float]) -> float:
|
| 116 |
+
"""Divide e1 by e2."""
|
| 117 |
+
e1 = strip(e1)
|
| 118 |
+
e2 = strip(e2)
|
| 119 |
+
if set(e1.keys()) != set(e2.keys()):
|
| 120 |
+
return None
|
| 121 |
+
|
| 122 |
+
n, d = None, None
|
| 123 |
+
|
| 124 |
+
for v, c1 in e1.items():
|
| 125 |
+
c2 = e2[v] # we want c1/c2 = n/d => c1*d=c2*n
|
| 126 |
+
if n is not None and c1 * d != c2 * n:
|
| 127 |
+
return None
|
| 128 |
+
n, d = c1, c2
|
| 129 |
+
return frac(n) / frac(d)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def recon(e: dict[str, float], const: str) -> tuple[str, dict[str, float]]:
|
| 133 |
+
"""Reconcile one variable in the expression e=0, given const."""
|
| 134 |
+
e = strip(e)
|
| 135 |
+
if len(e) == 0: # pylint: disable=g-explicit-length-test
|
| 136 |
+
return None
|
| 137 |
+
|
| 138 |
+
v0 = None
|
| 139 |
+
for v in e:
|
| 140 |
+
if v != const:
|
| 141 |
+
v0 = v
|
| 142 |
+
break
|
| 143 |
+
if v0 is None:
|
| 144 |
+
return v0
|
| 145 |
+
|
| 146 |
+
c0 = e.pop(v0)
|
| 147 |
+
return v0, {v: -c / c0 for v, c in e.items()}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def replace(
|
| 151 |
+
e: dict[str, float], v0: str, e0: dict[str, float]
|
| 152 |
+
) -> dict[str, float]:
|
| 153 |
+
if v0 not in e:
|
| 154 |
+
return e
|
| 155 |
+
e = dict(e)
|
| 156 |
+
m = e.pop(v0)
|
| 157 |
+
return plus(e, mult(e0, m))
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def comb2(elems: list[Any]) -> Generator[tuple[Any, Any], None, None]:
|
| 161 |
+
if len(elems) < 1:
|
| 162 |
+
return
|
| 163 |
+
for i, e1 in enumerate(elems[:-1]):
|
| 164 |
+
for e2 in elems[i + 1 :]:
|
| 165 |
+
yield e1, e2
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def perm2(elems: list[Any]) -> Generator[tuple[Any, Any], None, None]:
|
| 169 |
+
for e1, e2 in comb2(elems):
|
| 170 |
+
yield e1, e2
|
| 171 |
+
yield e2, e1
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def chain2(elems: list[Any]) -> Generator[tuple[Any, Any], None, None]:
|
| 175 |
+
if len(elems) < 2:
|
| 176 |
+
return
|
| 177 |
+
for i, e1 in enumerate(elems[:-1]):
|
| 178 |
+
yield e1, elems[i + 1]
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def update_groups(
|
| 182 |
+
groups1: list[Any], groups2: list[Any]
|
| 183 |
+
) -> tuple[list[Any], list[tuple[Any, Any]], list[list[Any]]]:
|
| 184 |
+
"""Update groups of equivalent elements.
|
| 185 |
+
|
| 186 |
+
Given groups1 = [set1, set2, set3, ..]
|
| 187 |
+
where all elems within each set_i is defined to be "equivalent" to each other.
|
| 188 |
+
(but not across the sets)
|
| 189 |
+
|
| 190 |
+
Incoming groups2 = [set1, set2, ...] similar to set1 - it is the
|
| 191 |
+
additional equivalent information on elements in groups1.
|
| 192 |
+
|
| 193 |
+
Return the new updated groups1 and the set of links
|
| 194 |
+
that make it that way.
|
| 195 |
+
|
| 196 |
+
Example:
|
| 197 |
+
groups1 = [{1, 2}, {3, 4, 5}, {6, 7}]
|
| 198 |
+
groups2 = [{2, 3, 8}, {9, 10, 11}]
|
| 199 |
+
|
| 200 |
+
=> new groups1 and links:
|
| 201 |
+
groups1 = [{1, 2, 3, 4, 5, 8}, {6, 7}, {9, 10, 11}]
|
| 202 |
+
links = (2, 3), (3, 8), (9, 10), (10, 11)
|
| 203 |
+
|
| 204 |
+
Explain: since groups2 says 2 and 3 are equivalent (with {2, 3, 8}),
|
| 205 |
+
then {1, 2} and {3, 4, 5} in groups1 will be merged,
|
| 206 |
+
because 2 and 3 each belong to those 2 groups.
|
| 207 |
+
Additionally 8 also belong to this same group.
|
| 208 |
+
{3, 4, 5} is left alone, while {9, 10, 11} is a completely new set.
|
| 209 |
+
|
| 210 |
+
The links to make this all happens is:
|
| 211 |
+
(2, 3): to merge {1, 2} and {3, 4, 5}
|
| 212 |
+
(3, 8): to link 8 into the merged({1, 2, 3, 4, 5})
|
| 213 |
+
(9, 10) and (10, 11): to make the new group {9, 10, 11}
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
groups1: a list of sets.
|
| 217 |
+
groups2: a list of sets.
|
| 218 |
+
|
| 219 |
+
Returns:
|
| 220 |
+
groups1, links, history: result of the update.
|
| 221 |
+
"""
|
| 222 |
+
history = []
|
| 223 |
+
links = []
|
| 224 |
+
for g2 in groups2:
|
| 225 |
+
joins = [None] * len(groups1) # mark which one in groups1 is merged
|
| 226 |
+
merged_g1 = set() # merge them into this.
|
| 227 |
+
old = None # any elem in g2 that belong to any set in groups1 (old)
|
| 228 |
+
new = [] # all elem in g2 that is new
|
| 229 |
+
|
| 230 |
+
for e in g2:
|
| 231 |
+
found = False
|
| 232 |
+
for i, g1 in enumerate(groups1):
|
| 233 |
+
if e not in g1:
|
| 234 |
+
continue
|
| 235 |
+
|
| 236 |
+
found = True
|
| 237 |
+
if joins[i]:
|
| 238 |
+
continue
|
| 239 |
+
|
| 240 |
+
joins[i] = True
|
| 241 |
+
merged_g1.update(g1)
|
| 242 |
+
|
| 243 |
+
if old is not None:
|
| 244 |
+
links.append((old, e)) # link to make merging happen.
|
| 245 |
+
old = e
|
| 246 |
+
|
| 247 |
+
if not found: # e is new!
|
| 248 |
+
new.append(e)
|
| 249 |
+
|
| 250 |
+
# now chain elems in new together.
|
| 251 |
+
if old is not None and new:
|
| 252 |
+
links.append((old, new[0]))
|
| 253 |
+
merged_g1.update(new)
|
| 254 |
+
|
| 255 |
+
links += chain2(new)
|
| 256 |
+
|
| 257 |
+
new_groups1 = []
|
| 258 |
+
if merged_g1: # put the merged_g1 in first
|
| 259 |
+
new_groups1.append(merged_g1)
|
| 260 |
+
|
| 261 |
+
# put the remaining (unjoined) groups in
|
| 262 |
+
new_groups1 += [g1 for j, g1 in zip(joins, groups1) if not j]
|
| 263 |
+
|
| 264 |
+
if old is None and new:
|
| 265 |
+
new_groups1 += [set(new)]
|
| 266 |
+
|
| 267 |
+
groups1 = new_groups1
|
| 268 |
+
history.append(groups1)
|
| 269 |
+
|
| 270 |
+
return groups1, links, history
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class Table:
|
| 274 |
+
"""The coefficient matrix."""
|
| 275 |
+
|
| 276 |
+
def __init__(self, const: str = '1'):
|
| 277 |
+
self.const = const
|
| 278 |
+
self.v2e = {}
|
| 279 |
+
self.add_free(const) # the table {var: expression}
|
| 280 |
+
|
| 281 |
+
# to cache what is already derived/inputted
|
| 282 |
+
self.eqs = set()
|
| 283 |
+
self.groups = [] # groups of equal pairs.
|
| 284 |
+
|
| 285 |
+
# for why (linprog)
|
| 286 |
+
self.c = []
|
| 287 |
+
self.v2i = {} # v -> index of row in A.
|
| 288 |
+
self.deps = [] # equal number of columns.
|
| 289 |
+
self.A = np.zeros([0, 0]) # pylint: disable=invalid-name
|
| 290 |
+
self.do_why = True
|
| 291 |
+
|
| 292 |
+
def add_free(self, v: str) -> None:
|
| 293 |
+
self.v2e[v] = {v: frac(1)}
|
| 294 |
+
|
| 295 |
+
def replace(self, v0: str, e0: dict[str, float]) -> None:
|
| 296 |
+
for v, e in list(self.v2e.items()):
|
| 297 |
+
self.v2e[v] = replace(e, v0, e0)
|
| 298 |
+
|
| 299 |
+
def add_expr(self, vc: list[tuple[str, float]]) -> bool:
|
| 300 |
+
"""Add a new equality, represented by the list of tuples vc=[(v, c), ..]."""
|
| 301 |
+
result = {}
|
| 302 |
+
free = []
|
| 303 |
+
|
| 304 |
+
for v, c in vc:
|
| 305 |
+
c = frac(c)
|
| 306 |
+
if v in self.v2e:
|
| 307 |
+
result = plus(result, mult(self.v2e[v], c))
|
| 308 |
+
else:
|
| 309 |
+
free += [(v, c)]
|
| 310 |
+
|
| 311 |
+
if free == []: # pylint: disable=g-explicit-bool-comparison
|
| 312 |
+
if is_zero(self.modulo(result)):
|
| 313 |
+
return False
|
| 314 |
+
result = recon(result, self.const)
|
| 315 |
+
if result is None:
|
| 316 |
+
return False
|
| 317 |
+
v, e = result
|
| 318 |
+
self.replace(v, e)
|
| 319 |
+
|
| 320 |
+
elif len(free) == 1:
|
| 321 |
+
v, m = free[0]
|
| 322 |
+
self.v2e[v] = mult(result, frac(-1, m))
|
| 323 |
+
|
| 324 |
+
else:
|
| 325 |
+
dependent_v = None
|
| 326 |
+
for v, m in free:
|
| 327 |
+
if dependent_v is None and v != self.const:
|
| 328 |
+
dependent_v = (v, m)
|
| 329 |
+
continue
|
| 330 |
+
|
| 331 |
+
self.add_free(v)
|
| 332 |
+
result = plus(result, {v: m})
|
| 333 |
+
|
| 334 |
+
v, m = dependent_v
|
| 335 |
+
self.v2e[v] = mult(result, frac(-1, m))
|
| 336 |
+
|
| 337 |
+
return True
|
| 338 |
+
|
| 339 |
+
def register(self, vc: list[tuple[str, float]], dep: pr.Dependency) -> None:
|
| 340 |
+
"""Register a new equality vc=[(v, c), ..] with traceback dependency dep."""
|
| 341 |
+
result = plus_all(*[{v: c} for v, c in vc])
|
| 342 |
+
if is_zero(result):
|
| 343 |
+
return
|
| 344 |
+
|
| 345 |
+
vs, _ = zip(*vc)
|
| 346 |
+
for v in vs:
|
| 347 |
+
if v not in self.v2i:
|
| 348 |
+
self.v2i[v] = len(self.v2i)
|
| 349 |
+
|
| 350 |
+
(m, n), l = self.A.shape, len(self.v2i)
|
| 351 |
+
if l > m:
|
| 352 |
+
self.A = np.concatenate([self.A, np.zeros([l - m, n])], 0)
|
| 353 |
+
|
| 354 |
+
new_column = np.zeros([len(self.v2i), 2]) # N, 2
|
| 355 |
+
for v, c in vc:
|
| 356 |
+
new_column[self.v2i[v], 0] += float(c)
|
| 357 |
+
new_column[self.v2i[v], 1] -= float(c)
|
| 358 |
+
|
| 359 |
+
self.A = np.concatenate([self.A, new_column], 1)
|
| 360 |
+
self.c += [1.0, -1.0]
|
| 361 |
+
self.deps += [dep]
|
| 362 |
+
|
| 363 |
+
def register2(
|
| 364 |
+
self, a: str, b: str, m: float, n: float, dep: pr.Dependency
|
| 365 |
+
) -> None:
|
| 366 |
+
self.register([(a, m), (b, -n)], dep)
|
| 367 |
+
|
| 368 |
+
def register3(self, a: str, b: str, f: float, dep: pr.Dependency) -> None:
|
| 369 |
+
self.register([(a, 1), (b, -1), (self.const, -f)], dep)
|
| 370 |
+
|
| 371 |
+
def register4(
|
| 372 |
+
self, a: str, b: str, c: str, d: str, dep: pr.Dependency
|
| 373 |
+
) -> None:
|
| 374 |
+
self.register([(a, 1), (b, -1), (c, -1), (d, 1)], dep)
|
| 375 |
+
|
| 376 |
+
def why(self, e: dict[str, float]) -> list[Any]:
|
| 377 |
+
"""AR traceback == MILP."""
|
| 378 |
+
if not self.do_why:
|
| 379 |
+
return []
|
| 380 |
+
# why expr == 0?
|
| 381 |
+
# Solve min(c^Tx) s.t. A_eq * x = b_eq, x >= 0
|
| 382 |
+
e = strip(e)
|
| 383 |
+
if not e:
|
| 384 |
+
return []
|
| 385 |
+
|
| 386 |
+
b_eq = [0] * len(self.v2i)
|
| 387 |
+
for v, c in e.items():
|
| 388 |
+
b_eq[self.v2i[v]] += float(c)
|
| 389 |
+
|
| 390 |
+
try:
|
| 391 |
+
x = optimize.linprog(c=self.c, A_eq=self.A, b_eq=b_eq, method='highs')[
|
| 392 |
+
'x'
|
| 393 |
+
]
|
| 394 |
+
except: # pylint: disable=bare-except
|
| 395 |
+
x = optimize.linprog(
|
| 396 |
+
c=self.c,
|
| 397 |
+
A_eq=self.A,
|
| 398 |
+
b_eq=b_eq,
|
| 399 |
+
)['x']
|
| 400 |
+
|
| 401 |
+
deps = []
|
| 402 |
+
for i, dep in enumerate(self.deps):
|
| 403 |
+
if x[2 * i] > 1e-12 or x[2 * i + 1] > 1e-12:
|
| 404 |
+
if dep not in deps:
|
| 405 |
+
deps.append(dep)
|
| 406 |
+
return deps
|
| 407 |
+
|
| 408 |
+
def record_eq(self, v1: str, v2: str, v3: str, v4: str) -> None:
|
| 409 |
+
self.eqs.add((v1, v2, v3, v4))
|
| 410 |
+
self.eqs.add((v2, v1, v4, v3))
|
| 411 |
+
self.eqs.add((v3, v4, v1, v2))
|
| 412 |
+
self.eqs.add((v4, v3, v2, v1))
|
| 413 |
+
|
| 414 |
+
def check_record_eq(self, v1: str, v2: str, v3: str, v4: str) -> bool:
|
| 415 |
+
if (v1, v2, v3, v4) in self.eqs:
|
| 416 |
+
return True
|
| 417 |
+
if (v2, v1, v4, v3) in self.eqs:
|
| 418 |
+
return True
|
| 419 |
+
if (v3, v4, v1, v2) in self.eqs:
|
| 420 |
+
return True
|
| 421 |
+
if (v4, v3, v2, v1) in self.eqs:
|
| 422 |
+
return True
|
| 423 |
+
return False
|
| 424 |
+
|
| 425 |
+
def add_eq2(
|
| 426 |
+
self, a: str, b: str, m: float, n: float, dep: pr.Dependency
|
| 427 |
+
) -> None:
|
| 428 |
+
# a/b = m/n
|
| 429 |
+
if not self.add_expr([(a, n), (b, -m)]):
|
| 430 |
+
return []
|
| 431 |
+
self.register2(a, b, m, n, dep)
|
| 432 |
+
|
| 433 |
+
def add_eq3(self, a: str, b: str, f: float, dep: pr.Dependency) -> None:
|
| 434 |
+
# a - b = f * constant
|
| 435 |
+
self.eqs.add((a, b, frac(f)))
|
| 436 |
+
self.eqs.add((b, a, frac(1 - f)))
|
| 437 |
+
|
| 438 |
+
if not self.add_expr([(a, 1), (b, -1), (self.const, -f)]):
|
| 439 |
+
return []
|
| 440 |
+
|
| 441 |
+
self.register3(a, b, f, dep)
|
| 442 |
+
|
| 443 |
+
def add_eq4(self, a: str, b: str, c: str, d: str, dep: pr.Dependency) -> None:
|
| 444 |
+
# a - b = c - d
|
| 445 |
+
self.record_eq(a, b, c, d)
|
| 446 |
+
self.record_eq(a, c, b, d)
|
| 447 |
+
|
| 448 |
+
expr = list(minus({a: 1, b: -1}, {c: 1, d: -1}).items())
|
| 449 |
+
|
| 450 |
+
if not self.add_expr(expr):
|
| 451 |
+
return []
|
| 452 |
+
|
| 453 |
+
self.register4(a, b, c, d, dep)
|
| 454 |
+
self.groups, _, _ = update_groups(
|
| 455 |
+
self.groups, [{(a, b), (c, d)}, {(b, a), (d, c)}]
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
def pairs(self) -> Generator[list[tuple[str, str]], None, None]:
|
| 459 |
+
for v1, v2 in perm2(list(self.v2e.keys())): # pylint: disable=g-builtin-op
|
| 460 |
+
if v1 == self.const or v2 == self.const:
|
| 461 |
+
continue
|
| 462 |
+
yield v1, v2
|
| 463 |
+
|
| 464 |
+
def modulo(self, e: dict[str, float]) -> dict[str, float]:
|
| 465 |
+
return strip(e)
|
| 466 |
+
|
| 467 |
+
def get_all_eqs(
|
| 468 |
+
self,
|
| 469 |
+
) -> dict[tuple[tuple[str, float], ...], list[tuple[str, str]]]:
|
| 470 |
+
h2pairs = defaultdict(list)
|
| 471 |
+
for v1, v2 in self.pairs():
|
| 472 |
+
e1, e2 = self.v2e[v1], self.v2e[v2]
|
| 473 |
+
e12 = minus(e1, e2)
|
| 474 |
+
h12 = hashed(self.modulo(e12))
|
| 475 |
+
h2pairs[h12].append((v1, v2))
|
| 476 |
+
return h2pairs
|
| 477 |
+
|
| 478 |
+
def get_all_eqs_and_why(
|
| 479 |
+
self, return_quads: bool = True
|
| 480 |
+
) -> Generator[Any, None, None]:
|
| 481 |
+
"""Check all 4/3/2-permutations for new equalities."""
|
| 482 |
+
groups = []
|
| 483 |
+
|
| 484 |
+
for h, vv in self.get_all_eqs().items():
|
| 485 |
+
if h == (): # pylint: disable=g-explicit-bool-comparison
|
| 486 |
+
for v1, v2 in vv:
|
| 487 |
+
if (v1, v2) in self.eqs or (v2, v1) in self.eqs:
|
| 488 |
+
continue
|
| 489 |
+
self.eqs.add((v1, v2))
|
| 490 |
+
# why v1 - v2 = e12 ? (note modulo(e12) == 0)
|
| 491 |
+
why_dict = minus({v1: 1, v2: -1}, minus(self.v2e[v1], self.v2e[v2]))
|
| 492 |
+
yield v1, v2, self.why(why_dict)
|
| 493 |
+
continue
|
| 494 |
+
|
| 495 |
+
if len(h) == 1 and h[0][0] == self.const:
|
| 496 |
+
for v1, v2 in vv:
|
| 497 |
+
frac = h[0][1] # pylint: disable=redefined-outer-name
|
| 498 |
+
if (v1, v2, frac) in self.eqs:
|
| 499 |
+
continue
|
| 500 |
+
self.eqs.add((v1, v2, frac))
|
| 501 |
+
# why v1 - v2 = e12 ? (note modulo(e12) == 0)
|
| 502 |
+
why_dict = minus({v1: 1, v2: -1}, minus(self.v2e[v1], self.v2e[v2]))
|
| 503 |
+
value = simplify(frac.numerator, frac.denominator)
|
| 504 |
+
yield v1, v2, value, self.why(why_dict)
|
| 505 |
+
continue
|
| 506 |
+
|
| 507 |
+
groups.append(vv)
|
| 508 |
+
|
| 509 |
+
if not return_quads:
|
| 510 |
+
return
|
| 511 |
+
|
| 512 |
+
self.groups, links, _ = update_groups(self.groups, groups)
|
| 513 |
+
for (v1, v2), (v3, v4) in links:
|
| 514 |
+
if self.check_record_eq(v1, v2, v3, v4):
|
| 515 |
+
continue
|
| 516 |
+
e12 = minus(self.v2e[v1], self.v2e[v2])
|
| 517 |
+
e34 = minus(self.v2e[v3], self.v2e[v4])
|
| 518 |
+
|
| 519 |
+
why_dict = minus( # why (v1-v2)-(v3-v4)=e12-e34?
|
| 520 |
+
minus({v1: 1, v2: -1}, {v3: 1, v4: -1}), minus(e12, e34)
|
| 521 |
+
)
|
| 522 |
+
self.record_eq(v1, v2, v3, v4)
|
| 523 |
+
yield v1, v2, v3, v4, self.why(why_dict)
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
class GeometricTable(Table):
|
| 527 |
+
"""Abstract class representing the coefficient matrix (table) A."""
|
| 528 |
+
|
| 529 |
+
def __init__(self, name: str = ''):
|
| 530 |
+
super().__init__(name)
|
| 531 |
+
self.v2obj = {}
|
| 532 |
+
|
| 533 |
+
def get_name(self, objs: list[Any]) -> list[str]:
|
| 534 |
+
self.v2obj.update({o.name: o for o in objs})
|
| 535 |
+
return [o.name for o in objs]
|
| 536 |
+
|
| 537 |
+
def map2obj(self, names: list[str]) -> list[Any]:
|
| 538 |
+
return [self.v2obj[n] for n in names]
|
| 539 |
+
|
| 540 |
+
def get_all_eqs_and_why(
|
| 541 |
+
self, return_quads: bool
|
| 542 |
+
) -> Generator[Any, None, None]:
|
| 543 |
+
for out in super().get_all_eqs_and_why(return_quads):
|
| 544 |
+
if len(out) == 3:
|
| 545 |
+
x, y, why = out
|
| 546 |
+
x, y = self.map2obj([x, y])
|
| 547 |
+
yield x, y, why
|
| 548 |
+
if len(out) == 4:
|
| 549 |
+
x, y, f, why = out
|
| 550 |
+
x, y = self.map2obj([x, y])
|
| 551 |
+
yield x, y, f, why
|
| 552 |
+
if len(out) == 5:
|
| 553 |
+
a, b, x, y, why = out
|
| 554 |
+
a, b, x, y = self.map2obj([a, b, x, y])
|
| 555 |
+
yield a, b, x, y, why
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
class RatioTable(GeometricTable):
|
| 559 |
+
"""Coefficient matrix A for log(distance)."""
|
| 560 |
+
|
| 561 |
+
def __init__(self, name: str = ''):
|
| 562 |
+
name = name or '1'
|
| 563 |
+
super().__init__(name)
|
| 564 |
+
self.one = self.const
|
| 565 |
+
|
| 566 |
+
def add_eq(self, l1: gm.Length, l2: gm.Length, dep: pr.Dependency) -> None:
|
| 567 |
+
l1, l2 = self.get_name([l1, l2])
|
| 568 |
+
return super().add_eq3(l1, l2, 0.0, dep)
|
| 569 |
+
|
| 570 |
+
def add_const_ratio(
|
| 571 |
+
self, l1: gm.Length, l2: gm.Length, m: float, n: float, dep: pr.Dependency
|
| 572 |
+
) -> None:
|
| 573 |
+
l1, l2 = self.get_name([l1, l2])
|
| 574 |
+
return super().add_eq2(l1, l2, m, n, dep)
|
| 575 |
+
|
| 576 |
+
def add_eqratio(
|
| 577 |
+
self,
|
| 578 |
+
l1: gm.Length,
|
| 579 |
+
l2: gm.Length,
|
| 580 |
+
l3: gm.Length,
|
| 581 |
+
l4: gm.Length,
|
| 582 |
+
dep: pr.Dependency,
|
| 583 |
+
) -> None:
|
| 584 |
+
l1, l2, l3, l4 = self.get_name([l1, l2, l3, l4])
|
| 585 |
+
return self.add_eq4(l1, l2, l3, l4, dep)
|
| 586 |
+
|
| 587 |
+
def get_all_eqs_and_why(self) -> Generator[Any, None, None]:
|
| 588 |
+
return super().get_all_eqs_and_why(True)
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
class AngleTable(GeometricTable):
|
| 592 |
+
"""Coefficient matrix A for slope(direction)."""
|
| 593 |
+
|
| 594 |
+
def __init__(self, name: str = ''):
|
| 595 |
+
name = name or 'pi'
|
| 596 |
+
super().__init__(name)
|
| 597 |
+
self.pi = self.const
|
| 598 |
+
|
| 599 |
+
def modulo(self, e: dict[str, float]) -> dict[str, float]:
|
| 600 |
+
e = strip(e)
|
| 601 |
+
if self.pi not in e:
|
| 602 |
+
return super().modulo(e)
|
| 603 |
+
|
| 604 |
+
e[self.pi] = e[self.pi] % 1
|
| 605 |
+
return strip(e)
|
| 606 |
+
|
| 607 |
+
def add_para(
|
| 608 |
+
self, d1: gm.Direction, d2: gm.Direction, dep: pr.Dependency
|
| 609 |
+
) -> None:
|
| 610 |
+
return self.add_const_angle(d1, d2, 0, dep)
|
| 611 |
+
|
| 612 |
+
def add_const_angle(
|
| 613 |
+
self, d1: gm.Direction, d2: gm.Direction, ang: float, dep: pr.Dependency
|
| 614 |
+
) -> None:
|
| 615 |
+
if ang and d2._obj.num > d1._obj.num: # pylint: disable=protected-access
|
| 616 |
+
d1, d2 = d2, d1
|
| 617 |
+
ang = 180 - ang
|
| 618 |
+
|
| 619 |
+
d1, d2 = self.get_name([d1, d2])
|
| 620 |
+
|
| 621 |
+
num, den = simplify(ang, 180)
|
| 622 |
+
ang = frac(int(num), int(den))
|
| 623 |
+
return super().add_eq3(d1, d2, ang, dep)
|
| 624 |
+
|
| 625 |
+
def add_eqangle(
|
| 626 |
+
self,
|
| 627 |
+
d1: gm.Direction,
|
| 628 |
+
d2: gm.Direction,
|
| 629 |
+
d3: gm.Direction,
|
| 630 |
+
d4: gm.Direction,
|
| 631 |
+
dep: pr.Dependency,
|
| 632 |
+
) -> None:
|
| 633 |
+
"""Add the inequality d1-d2=d3-d4."""
|
| 634 |
+
# Use string as variables.
|
| 635 |
+
l1, l2, l3, l4 = [d._obj.num for d in [d1, d2, d3, d4]] # pylint: disable=protected-access
|
| 636 |
+
d1, d2, d3, d4 = self.get_name([d1, d2, d3, d4])
|
| 637 |
+
ang1 = {d1: 1, d2: -1}
|
| 638 |
+
ang2 = {d3: 1, d4: -1}
|
| 639 |
+
|
| 640 |
+
if l2 > l1:
|
| 641 |
+
ang1 = plus({self.pi: 1}, ang1)
|
| 642 |
+
if l4 > l3:
|
| 643 |
+
ang2 = plus({self.pi: 1}, ang2)
|
| 644 |
+
|
| 645 |
+
ang12 = minus(ang1, ang2)
|
| 646 |
+
self.record_eq(d1, d2, d3, d4)
|
| 647 |
+
self.record_eq(d1, d3, d2, d4)
|
| 648 |
+
|
| 649 |
+
expr = list(ang12.items())
|
| 650 |
+
if not self.add_expr(expr):
|
| 651 |
+
return []
|
| 652 |
+
|
| 653 |
+
self.register(expr, dep)
|
| 654 |
+
|
| 655 |
+
def get_all_eqs_and_why(self) -> Generator[Any, None, None]:
|
| 656 |
+
return super().get_all_eqs_and_why(True)
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
class DistanceTable(GeometricTable):
|
| 660 |
+
"""Coefficient matrix A for position(point, line)."""
|
| 661 |
+
|
| 662 |
+
def __init__(self, name: str = ''):
|
| 663 |
+
name = name or '1:1'
|
| 664 |
+
self.merged = {}
|
| 665 |
+
self.ratios = set()
|
| 666 |
+
super().__init__(name)
|
| 667 |
+
|
| 668 |
+
def pairs(self) -> Generator[tuple[str, str], None, None]:
|
| 669 |
+
l2vs = defaultdict(list)
|
| 670 |
+
for v in list(self.v2e.keys()): # pylint: disable=g-builtin-op
|
| 671 |
+
if v == self.const:
|
| 672 |
+
continue
|
| 673 |
+
l, p = v.split(':')
|
| 674 |
+
l2vs[l].append(p)
|
| 675 |
+
|
| 676 |
+
for l, ps in l2vs.items():
|
| 677 |
+
for p1, p2 in perm2(ps):
|
| 678 |
+
yield l + ':' + p1, l + ':' + p2
|
| 679 |
+
|
| 680 |
+
def name(self, l: gm.Line, p: gm.Point) -> str:
|
| 681 |
+
v = l.name + ':' + p.name
|
| 682 |
+
self.v2obj[v] = (l, p)
|
| 683 |
+
return v
|
| 684 |
+
|
| 685 |
+
def map2obj(self, names: list[str]) -> list[gm.Point]:
|
| 686 |
+
return [self.v2obj[n][1] for n in names]
|
| 687 |
+
|
| 688 |
+
def add_cong(
|
| 689 |
+
self,
|
| 690 |
+
l12: gm.Line,
|
| 691 |
+
l34: gm.Line,
|
| 692 |
+
p1: gm.Point,
|
| 693 |
+
p2: gm.Point,
|
| 694 |
+
p3: gm.Point,
|
| 695 |
+
p4: gm.Point,
|
| 696 |
+
dep: pr.Dependency,
|
| 697 |
+
) -> None:
|
| 698 |
+
"""Add that distance between p1 and p2 (on l12) == p3 and p4 (on l34)."""
|
| 699 |
+
if p2.num > p1.num:
|
| 700 |
+
p1, p2 = p2, p1
|
| 701 |
+
if p4.num > p3.num:
|
| 702 |
+
p3, p4 = p4, p3
|
| 703 |
+
|
| 704 |
+
p1 = self.name(l12, p1)
|
| 705 |
+
p2 = self.name(l12, p2)
|
| 706 |
+
p3 = self.name(l34, p3)
|
| 707 |
+
p4 = self.name(l34, p4)
|
| 708 |
+
return super().add_eq4(p1, p2, p3, p4, dep)
|
| 709 |
+
|
| 710 |
+
def get_all_eqs_and_why(self) -> Generator[Any, None, None]:
|
| 711 |
+
for x in super().get_all_eqs_and_why(True):
|
| 712 |
+
yield x
|
| 713 |
+
|
| 714 |
+
# Now we figure out all the const ratios.
|
| 715 |
+
h2pairs = defaultdict(list)
|
| 716 |
+
for v1, v2 in self.pairs():
|
| 717 |
+
if (v1, v2) in self.merged:
|
| 718 |
+
continue
|
| 719 |
+
e1, e2 = self.v2e[v1], self.v2e[v2]
|
| 720 |
+
e12 = minus(e1, e2)
|
| 721 |
+
h12 = hashed(e12)
|
| 722 |
+
h2pairs[h12].append((v1, v2, e12))
|
| 723 |
+
|
| 724 |
+
for (_, vves1), (_, vves2) in perm2(list(h2pairs.items())):
|
| 725 |
+
v1, v2, e12 = vves1[0]
|
| 726 |
+
for v1_, v2_, _ in vves1[1:]:
|
| 727 |
+
self.merged[(v1_, v2_)] = (v1, v2)
|
| 728 |
+
|
| 729 |
+
v3, v4, e34 = vves2[0]
|
| 730 |
+
for v3_, v4_, _ in vves2[1:]:
|
| 731 |
+
self.merged[(v3_, v4_)] = (v3, v4)
|
| 732 |
+
|
| 733 |
+
if (v1, v2, v3, v4) in self.ratios:
|
| 734 |
+
continue
|
| 735 |
+
|
| 736 |
+
d12 = div(e12, e34)
|
| 737 |
+
if d12 is None or d12 > 1 or d12 < 0:
|
| 738 |
+
continue
|
| 739 |
+
|
| 740 |
+
self.ratios.add((v1, v2, v3, v4))
|
| 741 |
+
self.ratios.add((v2, v1, v4, v3))
|
| 742 |
+
|
| 743 |
+
n, d = d12.numerator, d12.denominator
|
| 744 |
+
|
| 745 |
+
# (v1 - v2) * d = (v3 - v4) * n
|
| 746 |
+
why_dict = minus(
|
| 747 |
+
minus({v1: d, v2: -d}, {v3: n, v4: -n}),
|
| 748 |
+
minus(mult(e12, d), mult(e34, n)), # there is no modulo, so this is 0
|
| 749 |
+
)
|
| 750 |
+
|
| 751 |
+
v1, v2, v3, v4 = self.map2obj([v1, v2, v3, v4])
|
| 752 |
+
yield v1, v2, v3, v4, abs(n), abs(d), self.why(why_dict)
|
backend/core/ag4masses/alphageometry/ar_test.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Unit tests for ar.py."""
|
| 17 |
+
import unittest
|
| 18 |
+
|
| 19 |
+
from absl.testing import absltest
|
| 20 |
+
import ar
|
| 21 |
+
import graph as gh
|
| 22 |
+
import problem as pr
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ARTest(unittest.TestCase):
|
| 26 |
+
|
| 27 |
+
@classmethod
|
| 28 |
+
def setUpClass(cls):
|
| 29 |
+
super().setUpClass()
|
| 30 |
+
cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)
|
| 31 |
+
cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True)
|
| 32 |
+
|
| 33 |
+
def test_update_groups(self):
|
| 34 |
+
"""Test for update_groups."""
|
| 35 |
+
groups1 = [{1, 2}, {3, 4, 5}, {6, 7}]
|
| 36 |
+
groups2 = [{2, 3, 8}, {9, 10, 11}]
|
| 37 |
+
|
| 38 |
+
_, links, history = ar.update_groups(groups1, groups2)
|
| 39 |
+
self.assertEqual(
|
| 40 |
+
history,
|
| 41 |
+
[
|
| 42 |
+
[{1, 2, 3, 4, 5, 8}, {6, 7}],
|
| 43 |
+
[{1, 2, 3, 4, 5, 8}, {6, 7}, {9, 10, 11}],
|
| 44 |
+
],
|
| 45 |
+
)
|
| 46 |
+
self.assertEqual(links, [(2, 3), (3, 8), (9, 10), (10, 11)])
|
| 47 |
+
|
| 48 |
+
groups1 = [{1, 2}, {3, 4}, {5, 6}, {7, 8}]
|
| 49 |
+
groups2 = [{2, 3, 8, 9, 10}, {3, 6, 11}]
|
| 50 |
+
|
| 51 |
+
_, links, history = ar.update_groups(groups1, groups2)
|
| 52 |
+
self.assertEqual(
|
| 53 |
+
history,
|
| 54 |
+
[
|
| 55 |
+
[{1, 2, 3, 4, 7, 8, 9, 10}, {5, 6}],
|
| 56 |
+
[{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}],
|
| 57 |
+
],
|
| 58 |
+
)
|
| 59 |
+
self.assertEqual(links, [(2, 3), (3, 8), (8, 9), (9, 10), (3, 6), (6, 11)])
|
| 60 |
+
|
| 61 |
+
groups1 = []
|
| 62 |
+
groups2 = [{1, 2}, {3, 4}, {5, 6}, {2, 3}]
|
| 63 |
+
|
| 64 |
+
_, links, history = ar.update_groups(groups1, groups2)
|
| 65 |
+
self.assertEqual(
|
| 66 |
+
history,
|
| 67 |
+
[
|
| 68 |
+
[{1, 2}],
|
| 69 |
+
[{1, 2}, {3, 4}],
|
| 70 |
+
[{1, 2}, {3, 4}, {5, 6}],
|
| 71 |
+
[{1, 2, 3, 4}, {5, 6}],
|
| 72 |
+
],
|
| 73 |
+
)
|
| 74 |
+
self.assertEqual(links, [(1, 2), (3, 4), (5, 6), (2, 3)])
|
| 75 |
+
|
| 76 |
+
def test_generic_table_simple(self):
|
| 77 |
+
tb = ar.Table()
|
| 78 |
+
|
| 79 |
+
# If a-b = b-c & d-a = c-d
|
| 80 |
+
tb.add_eq4('a', 'b', 'b', 'c', 'fact1')
|
| 81 |
+
tb.add_eq4('d', 'a', 'c', 'd', 'fact2')
|
| 82 |
+
tb.add_eq4('x', 'y', 'z', 't', 'fact3') # distractor fact
|
| 83 |
+
|
| 84 |
+
# Then b=d, because {fact1, fact2} but not fact3.
|
| 85 |
+
result = list(tb.get_all_eqs_and_why())
|
| 86 |
+
self.assertIn(('b', 'd', ['fact1', 'fact2']), result)
|
| 87 |
+
|
| 88 |
+
def test_angle_table_inbisector_exbisector(self):
|
| 89 |
+
"""Test that AR can figure out bisector & ex-bisector are perpendicular."""
|
| 90 |
+
# Load the scenario that we have cd is bisector of acb and
|
| 91 |
+
# ce is the ex-bisector of acb.
|
| 92 |
+
p = pr.Problem.from_txt(
|
| 93 |
+
'a b c = triangle a b c; d = incenter d a b c; e = excenter e a b c ?'
|
| 94 |
+
' perp d c c e'
|
| 95 |
+
)
|
| 96 |
+
g, _ = gh.Graph.build_problem(p, ARTest.defs)
|
| 97 |
+
|
| 98 |
+
# Create an external angle table:
|
| 99 |
+
tb = ar.AngleTable('pi')
|
| 100 |
+
|
| 101 |
+
# Add bisector & ex-bisector facts into the table:
|
| 102 |
+
ca, cd, cb, ce = g.names2nodes(['d(ac)', 'd(cd)', 'd(bc)', 'd(ce)'])
|
| 103 |
+
tb.add_eqangle(ca, cd, cd, cb, 'fact1')
|
| 104 |
+
tb.add_eqangle(ce, ca, cb, ce, 'fact2')
|
| 105 |
+
|
| 106 |
+
# Add a distractor fact to make sure traceback does not include this fact
|
| 107 |
+
ab = g.names2nodes(['d(ab)'])[0]
|
| 108 |
+
tb.add_eqangle(ab, cb, cb, ca, 'fact3')
|
| 109 |
+
|
| 110 |
+
# Check for all new equalities
|
| 111 |
+
result = list(tb.get_all_eqs_and_why())
|
| 112 |
+
|
| 113 |
+
# halfpi is represented as a tuple (1, 2)
|
| 114 |
+
halfpi = (1, 2)
|
| 115 |
+
|
| 116 |
+
# check that cd-ce == halfpi and this is because fact1 & fact2, not fact3
|
| 117 |
+
self.assertCountEqual(
|
| 118 |
+
result,
|
| 119 |
+
[
|
| 120 |
+
(cd, ce, halfpi, ['fact1', 'fact2']),
|
| 121 |
+
(ce, cd, halfpi, ['fact1', 'fact2']),
|
| 122 |
+
],
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
def test_angle_table_equilateral_triangle(self):
|
| 126 |
+
"""Test that AR can figure out triangles with 3 equal angles => each is pi/3."""
|
| 127 |
+
# Load an equaliteral scenario
|
| 128 |
+
p = pr.Problem.from_txt('a b c = ieq_triangle ? cong a b a c')
|
| 129 |
+
g, _ = gh.Graph.build_problem(p, ARTest.defs)
|
| 130 |
+
|
| 131 |
+
# Add two eqangles facts because ieq_triangle only add congruent sides
|
| 132 |
+
a, b, c = g.names2nodes('abc')
|
| 133 |
+
g.add_eqangle([a, b, b, c, b, c, c, a], pr.EmptyDependency(0, None))
|
| 134 |
+
g.add_eqangle([b, c, c, a, c, a, a, b], pr.EmptyDependency(0, None))
|
| 135 |
+
|
| 136 |
+
# Create an external angle table:
|
| 137 |
+
tb = ar.AngleTable('pi')
|
| 138 |
+
|
| 139 |
+
# Add the fact that there are three equal angles
|
| 140 |
+
ab, bc, ca = g.names2nodes(['d(ab)', 'd(bc)', 'd(ac)'])
|
| 141 |
+
tb.add_eqangle(ab, bc, bc, ca, 'fact1')
|
| 142 |
+
tb.add_eqangle(bc, ca, ca, ab, 'fact2')
|
| 143 |
+
|
| 144 |
+
# Now check for all new equalities
|
| 145 |
+
result = list(tb.get_all_eqs_and_why())
|
| 146 |
+
result = [(x.name, y.name, z, t) for x, y, z, t in result]
|
| 147 |
+
|
| 148 |
+
# 1/3 pi is represented as a tuple angle_60
|
| 149 |
+
angle_60 = (1, 3)
|
| 150 |
+
angle_120 = (2, 3)
|
| 151 |
+
|
| 152 |
+
# check that angles constants are created and figured out:
|
| 153 |
+
self.assertCountEqual(
|
| 154 |
+
result,
|
| 155 |
+
[
|
| 156 |
+
('d(bc)', 'd(ac)', angle_120, ['fact1', 'fact2']),
|
| 157 |
+
('d(ab)', 'd(bc)', angle_120, ['fact1', 'fact2']),
|
| 158 |
+
('d(ac)', 'd(ab)', angle_120, ['fact1', 'fact2']),
|
| 159 |
+
('d(ac)', 'd(bc)', angle_60, ['fact1', 'fact2']),
|
| 160 |
+
('d(bc)', 'd(ab)', angle_60, ['fact1', 'fact2']),
|
| 161 |
+
('d(ab)', 'd(ac)', angle_60, ['fact1', 'fact2']),
|
| 162 |
+
],
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
def test_incenter_excenter_touchpoints(self):
|
| 166 |
+
"""Test that AR can figure out incenter/excenter touchpoints are equidistant to midpoint."""
|
| 167 |
+
|
| 168 |
+
p = pr.Problem.from_txt(
|
| 169 |
+
'a b c = triangle a b c; d1 d2 d3 d = incenter2 a b c; e1 e2 e3 e ='
|
| 170 |
+
' excenter2 a b c ? perp d c c e',
|
| 171 |
+
translate=False,
|
| 172 |
+
)
|
| 173 |
+
g, _ = gh.Graph.build_problem(p, ARTest.defs)
|
| 174 |
+
|
| 175 |
+
a, b, c, ab, bc, ca, d1, d2, d3, e1, e2, e3 = g.names2nodes(
|
| 176 |
+
['a', 'b', 'c', 'ab', 'bc', 'ac', 'd1', 'd2', 'd3', 'e1', 'e2', 'e3']
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
# Create an external distance table:
|
| 180 |
+
tb = ar.DistanceTable()
|
| 181 |
+
|
| 182 |
+
# DD can figure out the following facts,
|
| 183 |
+
# we manually add them to AR.
|
| 184 |
+
tb.add_cong(ab, ca, a, d3, a, d2, 'fact1')
|
| 185 |
+
tb.add_cong(ab, ca, a, e3, a, e2, 'fact2')
|
| 186 |
+
tb.add_cong(ca, bc, c, d2, c, d1, 'fact5')
|
| 187 |
+
tb.add_cong(ca, bc, c, e2, c, e1, 'fact6')
|
| 188 |
+
tb.add_cong(bc, ab, b, d1, b, d3, 'fact3')
|
| 189 |
+
tb.add_cong(bc, ab, b, e1, b, e3, 'fact4')
|
| 190 |
+
|
| 191 |
+
# Now we check whether tb has figured out that
|
| 192 |
+
# distance(b, d1) == distance(e1, c)
|
| 193 |
+
|
| 194 |
+
# linear comb exprssion of each variables:
|
| 195 |
+
b = tb.v2e['bc:b']
|
| 196 |
+
c = tb.v2e['bc:c']
|
| 197 |
+
d1 = tb.v2e['bc:d1']
|
| 198 |
+
e1 = tb.v2e['bc:e1']
|
| 199 |
+
|
| 200 |
+
self.assertEqual(ar.minus(d1, b), ar.minus(c, e1))
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
if __name__ == '__main__':
|
| 204 |
+
absltest.main()
|
backend/core/ag4masses/alphageometry/beam_search.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Fast decoding routines for inference from a trained model.
|
| 17 |
+
|
| 18 |
+
Modified https://github.com/google/flax/blob/main/examples/wmt/decode.py
|
| 19 |
+
to acommodate
|
| 20 |
+
|
| 21 |
+
(a) continued decoding from a previous beam cache.
|
| 22 |
+
(b) init with with a single beam and then expand into beam_size beams.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from typing import Any
|
| 26 |
+
|
| 27 |
+
import flax
|
| 28 |
+
import jax
|
| 29 |
+
from jax import lax
|
| 30 |
+
import jax.numpy as jnp
|
| 31 |
+
import numpy as np
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Constants
|
| 35 |
+
# "Effective negative infinity" constant for masking in beam search.
|
| 36 |
+
NEG_INF = np.array(-1.0e7)
|
| 37 |
+
|
| 38 |
+
# Beam search parameters
|
| 39 |
+
BEAM_SEARCH_DEFAULT_ALPHA = 0.6
|
| 40 |
+
MAX_DECODE_LEN = 32
|
| 41 |
+
|
| 42 |
+
# Brevity penalty parameters
|
| 43 |
+
BREVITY_LEN_BIAS_NUMERATOR = 5.0
|
| 44 |
+
BREVITY_LEN_BIAS_DENOMINATOR = 6.0
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def brevity_penalty(alpha: float, length: int):
|
| 48 |
+
"""Brevity penalty function for beam search penalizing short sequences.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
alpha: float: brevity-penalty scaling parameter.
|
| 52 |
+
length: int: length of considered sequence.
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
Brevity penalty score as jax scalar.
|
| 56 |
+
"""
|
| 57 |
+
return jnp.power(
|
| 58 |
+
((BREVITY_LEN_BIAS_NUMERATOR + length) / BREVITY_LEN_BIAS_DENOMINATOR),
|
| 59 |
+
alpha,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Beam handling utility functions:
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def add_beam_dim(x: jnp.ndarray, beam_size: int) -> jnp.ndarray:
|
| 67 |
+
"""Creates new beam dimension in non-scalar array and tiles into it."""
|
| 68 |
+
if x.ndim == 0: # ignore scalars (e.g. cache index)
|
| 69 |
+
return x
|
| 70 |
+
x = jnp.expand_dims(x, axis=1)
|
| 71 |
+
tile_dims = [1] * x.ndim
|
| 72 |
+
tile_dims[1] = beam_size
|
| 73 |
+
return jnp.tile(x, tile_dims)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def add_beam_dim_cache(
|
| 77 |
+
cache: tuple[dict[str, jnp.ndarray], ...], beam_size: int
|
| 78 |
+
) -> tuple[dict[str, jnp.ndarray], ...]:
|
| 79 |
+
"""Creates new beam dimension in non-scalar array and tiles into it."""
|
| 80 |
+
new_cache = []
|
| 81 |
+
|
| 82 |
+
for layer in cache:
|
| 83 |
+
new_layer = {}
|
| 84 |
+
for key, x in layer.items():
|
| 85 |
+
if key in ['keys', 'vals']:
|
| 86 |
+
x = add_beam_dim(x, beam_size)
|
| 87 |
+
new_layer[key] = x
|
| 88 |
+
new_cache.append(new_layer)
|
| 89 |
+
|
| 90 |
+
return tuple(new_cache)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def flatten_beam_dim(x):
|
| 94 |
+
"""Flattens the first two dimensions of a non-scalar array."""
|
| 95 |
+
if x.ndim < 2: # ignore scalars (e.g. cache index)
|
| 96 |
+
return x
|
| 97 |
+
return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def unflatten_beam_dim(x, batch_size, beam_size):
|
| 101 |
+
"""Unflattens the first, flat batch*beam dimension of a non-scalar array."""
|
| 102 |
+
if x.ndim == 0: # ignore scalars (e.g. cache index)
|
| 103 |
+
return x
|
| 104 |
+
assert batch_size * beam_size == x.shape[0]
|
| 105 |
+
return x.reshape((batch_size, beam_size) + x.shape[1:])
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def flat_batch_beam_expand(x, beam_size):
|
| 109 |
+
"""Expands the each batch item by beam_size in batch_dimension."""
|
| 110 |
+
return flatten_beam_dim(add_beam_dim(x, beam_size))
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def gather_beams(nested, beam_indices, batch_size, new_beam_size):
|
| 114 |
+
"""Gathers the beam slices indexed by beam_indices into new beam array.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
nested: pytree of arrays or scalars (the latter ignored).
|
| 118 |
+
beam_indices: array of beam_indices
|
| 119 |
+
batch_size: int: size of batch.
|
| 120 |
+
new_beam_size: int: size of _new_ beam dimension.
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
New pytree with new beam arrays.
|
| 124 |
+
[batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
|
| 125 |
+
"""
|
| 126 |
+
batch_indices = jnp.reshape(
|
| 127 |
+
jnp.arange(batch_size * new_beam_size) // new_beam_size,
|
| 128 |
+
(batch_size, new_beam_size),
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
def gather_fn(x):
|
| 132 |
+
if x.ndim == 0: # ignore scalars (e.g. cache index)
|
| 133 |
+
return x
|
| 134 |
+
else:
|
| 135 |
+
return x[batch_indices, beam_indices]
|
| 136 |
+
|
| 137 |
+
return jax.tree_util.tree_map(gather_fn, nested)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size):
|
| 141 |
+
"""Gathers the top-k beam slices given by score_or_log_prob array.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
nested: pytree of arrays or scalars (the latter ignored).
|
| 145 |
+
score_or_log_prob: [batch_size, old_beam_size] array of values to sort by
|
| 146 |
+
for top-k selection of beam slices.
|
| 147 |
+
batch_size: int: size of batch.
|
| 148 |
+
new_beam_size: int: size of _new_ top-k selected beam dimension
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
New pytree with new beam arrays containing top k new_beam_size slices.
|
| 152 |
+
[batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
|
| 153 |
+
"""
|
| 154 |
+
_, topk_indices = lax.top_k(score_or_log_prob, k=new_beam_size)
|
| 155 |
+
topk_indices = jnp.flip(topk_indices, axis=1)
|
| 156 |
+
return gather_beams(nested, topk_indices, batch_size, new_beam_size)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def apply_on_cache(fn, cache, *args, **kwargs):
|
| 160 |
+
"""Apply fn(val) only when key is 'keys' or 'val'."""
|
| 161 |
+
new_cache = []
|
| 162 |
+
for layer in cache:
|
| 163 |
+
new_layer = {}
|
| 164 |
+
for key, val in layer.items():
|
| 165 |
+
if key in ['keys', 'values', 'current_index', 'relative_position_bias']:
|
| 166 |
+
val = fn(val, *args, **kwargs)
|
| 167 |
+
new_layer[key] = val
|
| 168 |
+
new_cache.append(new_layer)
|
| 169 |
+
return tuple(new_cache)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# Beam search state:
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
@flax.struct.dataclass
|
| 176 |
+
class BeamState:
|
| 177 |
+
"""Holds beam search state data."""
|
| 178 |
+
|
| 179 |
+
# The position of the decoding loop in the length dimension.
|
| 180 |
+
cur_index: jax.Array # scalar int32: current decoded length index
|
| 181 |
+
# The active sequence log probabilities and finished sequence scores.
|
| 182 |
+
live_logprobs: jax.Array # float32: [batch_size, beam_size]
|
| 183 |
+
finished_scores: jax.Array # float32: [batch_size, beam_size]
|
| 184 |
+
# The current active-beam-searching and finished sequences.
|
| 185 |
+
live_seqs: jax.Array # int32: [batch_size, beam_size, max_decode_len]
|
| 186 |
+
finished_seqs: jax.Array # int32: [batch_size, beam_size,
|
| 187 |
+
# max_decode_len]
|
| 188 |
+
# Records which of the 'finished_seqs' is occupied and not a filler slot.
|
| 189 |
+
finished_flags: jax.Array # bool: [batch_size, beam_size]
|
| 190 |
+
# The current state of the autoregressive decoding caches.
|
| 191 |
+
cache: Any # Any pytree of arrays, e.g. flax attention Cache object
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def beam_init(seed_token, batch_size, beam_size, max_decode_len, cache):
|
| 195 |
+
"""Initializes the beam search state data structure."""
|
| 196 |
+
cur_index0 = jnp.array(0)
|
| 197 |
+
live_logprobs0 = jnp.tile(
|
| 198 |
+
jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1]
|
| 199 |
+
)
|
| 200 |
+
finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
|
| 201 |
+
|
| 202 |
+
live_seqs0 = jnp.concatenate(
|
| 203 |
+
[
|
| 204 |
+
jnp.reshape(seed_token, (batch_size, beam_size, 1)),
|
| 205 |
+
jnp.zeros((batch_size, beam_size, max_decode_len - 1), jnp.int32),
|
| 206 |
+
],
|
| 207 |
+
axis=-1,
|
| 208 |
+
) # (batch, beam, max_decode_len)
|
| 209 |
+
|
| 210 |
+
finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
|
| 211 |
+
finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
|
| 212 |
+
beam_cache0 = apply_on_cache(lambda x: jnp.expand_dims(x, axis=0), cache)
|
| 213 |
+
return BeamState(
|
| 214 |
+
cur_index=cur_index0,
|
| 215 |
+
live_logprobs=live_logprobs0,
|
| 216 |
+
finished_scores=finished_scores0,
|
| 217 |
+
live_seqs=live_seqs0,
|
| 218 |
+
finished_seqs=finished_seqs0,
|
| 219 |
+
finished_flags=finished_flags0,
|
| 220 |
+
cache=beam_cache0,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# Beam search routine:
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def beam_search_flat(
|
| 228 |
+
seed_token,
|
| 229 |
+
cache,
|
| 230 |
+
tokens_to_logits,
|
| 231 |
+
alpha=BEAM_SEARCH_DEFAULT_ALPHA,
|
| 232 |
+
eos=None,
|
| 233 |
+
max_decode_len=MAX_DECODE_LEN,
|
| 234 |
+
mask=None,
|
| 235 |
+
):
|
| 236 |
+
"""Beam search for LM.
|
| 237 |
+
|
| 238 |
+
inputs and cache is already flat! i.e. first dimention == batch*beam.
|
| 239 |
+
|
| 240 |
+
Args:
|
| 241 |
+
seed_token: array: [beam_size, 1] int32 sequence of tokens.
|
| 242 |
+
cache: flax attention cache.
|
| 243 |
+
tokens_to_logits: fast autoregressive decoder function taking single token
|
| 244 |
+
slices and cache and returning next-token logits and updated cache.
|
| 245 |
+
alpha: float: scaling factor for brevity penalty.
|
| 246 |
+
eos: array: [vocab] 1 for end-of-sentence tokens, 0 for not.
|
| 247 |
+
max_decode_len: int: maximum length of decoded translations.
|
| 248 |
+
mask: array: [vocab] binary mask for vocab. 1 to keep the prob, 0 to set the
|
| 249 |
+
prob := 0.
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
Tuple of:
|
| 253 |
+
[beam_size, max_decode_len] top-scoring sequences
|
| 254 |
+
[beam_size] beam-search scores.
|
| 255 |
+
"""
|
| 256 |
+
# We liberally annotate shape information for clarity below.
|
| 257 |
+
batch_size, beam_size = 1, seed_token.shape[0]
|
| 258 |
+
mask = mask.reshape((1, 1, -1))
|
| 259 |
+
eos = eos.reshape((1, 1, -1))
|
| 260 |
+
mask_bias = (1 - mask) * NEG_INF
|
| 261 |
+
|
| 262 |
+
# initialize beam search state
|
| 263 |
+
beam_search_init_state = beam_init(
|
| 264 |
+
seed_token, batch_size, beam_size, max_decode_len, cache
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
def beam_search_loop_cond_fn(state):
|
| 268 |
+
"""Beam search loop termination condition."""
|
| 269 |
+
# Have we reached max decoding length?
|
| 270 |
+
not_at_end = state.cur_index < max_decode_len - 1
|
| 271 |
+
|
| 272 |
+
# Is no further progress in the beam search possible?
|
| 273 |
+
# Get the best possible scores from alive sequences.
|
| 274 |
+
min_brevity_penalty = brevity_penalty(alpha, max_decode_len)
|
| 275 |
+
best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty
|
| 276 |
+
# Get the worst scores from finished sequences.
|
| 277 |
+
worst_finished_scores = jnp.min(
|
| 278 |
+
state.finished_scores, axis=1, keepdims=True
|
| 279 |
+
)
|
| 280 |
+
# Mask out scores from slots without any actual finished sequences.
|
| 281 |
+
worst_finished_scores = jnp.where(
|
| 282 |
+
state.finished_flags, worst_finished_scores, NEG_INF
|
| 283 |
+
)
|
| 284 |
+
# If no best possible live score is better than current worst finished
|
| 285 |
+
# scores, the search cannot improve the finished set further.
|
| 286 |
+
search_terminated = jnp.all(worst_finished_scores > best_live_scores)
|
| 287 |
+
|
| 288 |
+
# If we're not at the max decode length, and the search hasn't terminated,
|
| 289 |
+
# continue looping.
|
| 290 |
+
return not_at_end & (~search_terminated)
|
| 291 |
+
|
| 292 |
+
def beam_search_loop_body_fn(state):
|
| 293 |
+
"""Beam search loop state update function."""
|
| 294 |
+
# Collect the current position slice along length to feed the fast
|
| 295 |
+
# autoregressive decoder model. Flatten the beam dimension into batch
|
| 296 |
+
# dimension for feeding into the model.
|
| 297 |
+
# --> [batch * beam, 1]
|
| 298 |
+
flat_ids = flatten_beam_dim(
|
| 299 |
+
lax.dynamic_slice(
|
| 300 |
+
state.live_seqs, (0, 0, state.cur_index), (batch_size, beam_size, 1)
|
| 301 |
+
)
|
| 302 |
+
)
|
| 303 |
+
# Flatten beam dimension into batch to be compatible with model.
|
| 304 |
+
# {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
|
| 305 |
+
flat_cache = apply_on_cache(flatten_beam_dim, state.cache)
|
| 306 |
+
|
| 307 |
+
# Call fast-decoder model on current tokens to get next-position logits.
|
| 308 |
+
# --> [batch * beam, vocab]
|
| 309 |
+
flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache)
|
| 310 |
+
|
| 311 |
+
# unflatten beam dimension
|
| 312 |
+
# [batch * beam, vocab] --> [batch, beam, vocab]
|
| 313 |
+
logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
|
| 314 |
+
|
| 315 |
+
# Unflatten beam dimension in attention cache arrays
|
| 316 |
+
# {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
|
| 317 |
+
new_cache = apply_on_cache(
|
| 318 |
+
unflatten_beam_dim, new_flat_cache, batch_size, beam_size
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
# Gather log probabilities from logits
|
| 322 |
+
candidate_log_probs = jax.nn.log_softmax(logits)
|
| 323 |
+
# Add new logprobs to existing prefix logprobs.
|
| 324 |
+
# --> [batch, beam, vocab]
|
| 325 |
+
log_probs = candidate_log_probs + jnp.expand_dims(
|
| 326 |
+
state.live_logprobs, axis=2
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
# We'll need the vocab size, gather it from the log probability dimension.
|
| 330 |
+
vocab_size = log_probs.shape[2]
|
| 331 |
+
|
| 332 |
+
# mask away some tokens.
|
| 333 |
+
log_probs += mask_bias # [batch,beam,vocab]+[1,1,vocab]
|
| 334 |
+
|
| 335 |
+
# Each item in batch has beam_size * vocab_size candidate sequences.
|
| 336 |
+
# For each item, get the top 2*k candidates with the highest log-
|
| 337 |
+
# probabilities. We gather the top 2*K beams here so that even if the best
|
| 338 |
+
# K sequences reach EOS simultaneously, we have another K sequences
|
| 339 |
+
# remaining to continue the live beam search.
|
| 340 |
+
beams_to_keep = 2 * beam_size
|
| 341 |
+
# Flatten beam and vocab dimensions.
|
| 342 |
+
flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size))
|
| 343 |
+
# Gather the top 2*K scores from _all_ beams.
|
| 344 |
+
# --> [batch, 2*beams], [batch, 2*beams]
|
| 345 |
+
topk_log_probs, topk_indices = lax.top_k(flat_log_probs, k=beams_to_keep)
|
| 346 |
+
# Recover the beam index by floor division.
|
| 347 |
+
topk_beam_indices = topk_indices // vocab_size
|
| 348 |
+
# Gather 2*k top beams.
|
| 349 |
+
# --> [batch, 2*beams, length]
|
| 350 |
+
topk_seq = gather_beams(
|
| 351 |
+
state.live_seqs, topk_beam_indices, batch_size, beams_to_keep
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
# Append the most probable 2*K token IDs to the top 2*K sequences
|
| 355 |
+
# Recover token id by modulo division and expand Id array for broadcasting.
|
| 356 |
+
# --> [batch, 2*beams, 1]
|
| 357 |
+
topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
|
| 358 |
+
# Update sequences for the 2*K top-k new sequences.
|
| 359 |
+
# --> [batch, 2*beams, length]
|
| 360 |
+
topk_seq = lax.dynamic_update_slice(
|
| 361 |
+
topk_seq, topk_ids, (0, 0, state.cur_index + 1)
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
# Update LIVE (in-progress) sequences:
|
| 365 |
+
# Did any of these sequences reach an end marker?
|
| 366 |
+
# --> [batch, 2*beams]
|
| 367 |
+
last_token = topk_seq[:, :, state.cur_index + 1]
|
| 368 |
+
last_token = jax.nn.one_hot(last_token, vocab_size, dtype=jnp.bfloat16)
|
| 369 |
+
|
| 370 |
+
# any([batch, 2b, vocab] * [1, 1, vocab], axis=-1) == [batch, 2b]
|
| 371 |
+
newly_finished = jnp.any(last_token * eos, axis=-1)
|
| 372 |
+
|
| 373 |
+
# To prevent these newly finished sequences from being added to the LIVE
|
| 374 |
+
# set of active beam search sequences, set their log probs to a very large
|
| 375 |
+
# negative value.
|
| 376 |
+
new_log_probs = topk_log_probs + newly_finished * NEG_INF
|
| 377 |
+
# Determine the top k beam indices (from top 2*k beams) from log probs.
|
| 378 |
+
# --> [batch, beams]
|
| 379 |
+
_, new_topk_indices = lax.top_k(new_log_probs, k=beam_size)
|
| 380 |
+
new_topk_indices = jnp.flip(new_topk_indices, axis=1)
|
| 381 |
+
# Gather the top k beams (from top 2*k beams).
|
| 382 |
+
# --> [batch, beams, length], [batch, beams]
|
| 383 |
+
top_alive_seq, top_alive_log_probs = gather_beams(
|
| 384 |
+
[topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
# Determine the top k beam indices from the original set of all beams.
|
| 388 |
+
# --> [batch, beams]
|
| 389 |
+
top_alive_indices = gather_beams(
|
| 390 |
+
topk_beam_indices, new_topk_indices, batch_size, beam_size
|
| 391 |
+
)
|
| 392 |
+
# With these, gather the top k beam-associated caches.
|
| 393 |
+
# --> {[batch, beams, ...], ...}
|
| 394 |
+
top_alive_cache = apply_on_cache(
|
| 395 |
+
gather_beams, new_cache, top_alive_indices, batch_size, beam_size
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
# Update FINISHED (reached end of sentence) sequences:
|
| 399 |
+
# Calculate new seq scores from log probabilities.
|
| 400 |
+
new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1)
|
| 401 |
+
# Mask out the still unfinished sequences by adding large negative value.
|
| 402 |
+
# --> [batch, 2*beams]
|
| 403 |
+
new_scores += (~newly_finished) * NEG_INF
|
| 404 |
+
|
| 405 |
+
# Combine sequences, scores, and flags along the beam dimension and compare
|
| 406 |
+
# new finished sequence scores to existing finished scores and select the
|
| 407 |
+
# best from the new set of beams.
|
| 408 |
+
finished_seqs = jnp.concatenate( # --> [batch, 3*beams, length]
|
| 409 |
+
[state.finished_seqs, topk_seq], axis=1
|
| 410 |
+
)
|
| 411 |
+
finished_scores = jnp.concatenate( # --> [batch, 3*beams]
|
| 412 |
+
[state.finished_scores, new_scores], axis=1
|
| 413 |
+
)
|
| 414 |
+
finished_flags = jnp.concatenate( # --> [batch, 3*beams]
|
| 415 |
+
[state.finished_flags, newly_finished], axis=1
|
| 416 |
+
)
|
| 417 |
+
# --> [batch, beams, length], [batch, beams], [batch, beams]
|
| 418 |
+
top_finished_seq, top_finished_scores, top_finished_flags = (
|
| 419 |
+
gather_topk_beams(
|
| 420 |
+
[finished_seqs, finished_scores, finished_flags],
|
| 421 |
+
finished_scores,
|
| 422 |
+
batch_size,
|
| 423 |
+
beam_size,
|
| 424 |
+
)
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
return BeamState(
|
| 428 |
+
cur_index=state.cur_index + 1,
|
| 429 |
+
live_logprobs=top_alive_log_probs,
|
| 430 |
+
finished_scores=top_finished_scores,
|
| 431 |
+
live_seqs=top_alive_seq,
|
| 432 |
+
finished_seqs=top_finished_seq,
|
| 433 |
+
finished_flags=top_finished_flags,
|
| 434 |
+
cache=top_alive_cache,
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# Run while loop and get final beam search state.
|
| 438 |
+
final_state = lax.while_loop(
|
| 439 |
+
beam_search_loop_cond_fn, beam_search_loop_body_fn, beam_search_init_state
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
# Account for the edge-case where there are no finished sequences for a
|
| 443 |
+
# particular batch item. If so, return live sequences for that batch item.
|
| 444 |
+
# --> [batch]
|
| 445 |
+
none_finished = jnp.any(final_state.finished_flags, axis=1)
|
| 446 |
+
# --> [batch, beams, length]
|
| 447 |
+
finished_seqs = jnp.where(
|
| 448 |
+
none_finished[:, None, None],
|
| 449 |
+
final_state.finished_seqs,
|
| 450 |
+
final_state.live_seqs,
|
| 451 |
+
)
|
| 452 |
+
# --> [batch, beams]
|
| 453 |
+
finished_scores = jnp.where(
|
| 454 |
+
none_finished[:, None],
|
| 455 |
+
final_state.finished_scores,
|
| 456 |
+
final_state.live_logprobs,
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
finished_seqs = jnp.reshape(finished_seqs, (beam_size, max_decode_len))
|
| 460 |
+
finished_scores = jnp.reshape(finished_scores, (beam_size,))
|
| 461 |
+
|
| 462 |
+
final_cache = apply_on_cache(flatten_beam_dim, final_state.cache)
|
| 463 |
+
return finished_seqs, finished_scores, final_cache
|
backend/core/ag4masses/alphageometry/dd.py
ADDED
|
@@ -0,0 +1,1156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Implements Deductive Database (DD)."""
|
| 17 |
+
|
| 18 |
+
# pylint: disable=g-multiple-import,g-importing-member
|
| 19 |
+
from collections import defaultdict
|
| 20 |
+
import time
|
| 21 |
+
from typing import Any, Callable, Generator
|
| 22 |
+
|
| 23 |
+
import geometry as gm
|
| 24 |
+
import graph as gh
|
| 25 |
+
import graph_utils as utils
|
| 26 |
+
import numericals as nm
|
| 27 |
+
import problem as pr
|
| 28 |
+
from problem import Dependency, EmptyDependency
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def intersect1(set1: set[Any], set2: set[Any]) -> Any:
|
| 32 |
+
for x in set1:
|
| 33 |
+
if x in set2:
|
| 34 |
+
return x
|
| 35 |
+
return None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def diff_point(l: gm.Line, a: gm.Point) -> gm.Point:
|
| 39 |
+
for x in l.neighbors(gm.Point):
|
| 40 |
+
if x != a:
|
| 41 |
+
return x
|
| 42 |
+
return None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# pylint: disable=protected-access
|
| 46 |
+
# pylint: disable=unused-argument
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def match_eqratio_eqratio_eqratio(
|
| 50 |
+
g: gh.Graph,
|
| 51 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 52 |
+
theorem: pr.Theorem,
|
| 53 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 54 |
+
"""Match eqratio a b c d m n p q, eqratio c d e f p q r u => eqratio a b e f m n r u."""
|
| 55 |
+
for m1 in g.type2nodes[gm.Value]:
|
| 56 |
+
for m2 in g.type2nodes[gm.Value]:
|
| 57 |
+
rats1 = []
|
| 58 |
+
for rat in m1.neighbors(gm.Ratio):
|
| 59 |
+
l1, l2 = rat.lengths
|
| 60 |
+
if l1 is None or l2 is None:
|
| 61 |
+
continue
|
| 62 |
+
rats1.append((l1, l2))
|
| 63 |
+
|
| 64 |
+
rats2 = []
|
| 65 |
+
for rat in m2.neighbors(gm.Ratio):
|
| 66 |
+
l1, l2 = rat.lengths
|
| 67 |
+
if l1 is None or l2 is None:
|
| 68 |
+
continue
|
| 69 |
+
rats2.append((l1, l2))
|
| 70 |
+
|
| 71 |
+
pairs = []
|
| 72 |
+
for (l1, l2), (l3, l4) in utils.cross(rats1, rats2):
|
| 73 |
+
if l2 == l3:
|
| 74 |
+
pairs.append((l1, l2, l4))
|
| 75 |
+
|
| 76 |
+
for (l1, l12, l2), (l3, l34, l4) in utils.comb2(pairs):
|
| 77 |
+
if (l1, l12, l2) == (l3, l34, l4):
|
| 78 |
+
continue
|
| 79 |
+
if l1 == l2 or l3 == l4:
|
| 80 |
+
continue
|
| 81 |
+
if l1 == l12 or l12 == l2 or l3 == l34 or l4 == l34:
|
| 82 |
+
continue
|
| 83 |
+
# d12 - d1 = d34 - d3 = m1
|
| 84 |
+
# d2 - d12 = d4 - d34 = m2
|
| 85 |
+
# => d2 - d1 = d4 - d3 (= m1+m2)
|
| 86 |
+
a, b = g.two_points_of_length(l1)
|
| 87 |
+
c, d = g.two_points_of_length(l12)
|
| 88 |
+
m, n = g.two_points_of_length(l3)
|
| 89 |
+
p, q = g.two_points_of_length(l34)
|
| 90 |
+
# eqangle a b c d m n p q
|
| 91 |
+
e, f = g.two_points_of_length(l2)
|
| 92 |
+
r, u = g.two_points_of_length(l4)
|
| 93 |
+
yield dict(zip('abcdefmnpqru', [a, b, c, d, e, f, m, n, p, q, r, u]))
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def match_eqangle_eqangle_eqangle(
|
| 97 |
+
g: gh.Graph,
|
| 98 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 99 |
+
theorem: pr.Theorem,
|
| 100 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 101 |
+
"""Match eqangle a b c d m n p q, eqangle c d e f p q r u => eqangle a b e f m n r u."""
|
| 102 |
+
for m1 in g.type2nodes[gm.Measure]:
|
| 103 |
+
for m2 in g.type2nodes[gm.Measure]:
|
| 104 |
+
angs1 = []
|
| 105 |
+
for ang in m1.neighbors(gm.Angle):
|
| 106 |
+
d1, d2 = ang.directions
|
| 107 |
+
if d1 is None or d2 is None:
|
| 108 |
+
continue
|
| 109 |
+
angs1.append((d1, d2))
|
| 110 |
+
|
| 111 |
+
angs2 = []
|
| 112 |
+
for ang in m2.neighbors(gm.Angle):
|
| 113 |
+
d1, d2 = ang.directions
|
| 114 |
+
if d1 is None or d2 is None:
|
| 115 |
+
continue
|
| 116 |
+
angs2.append((d1, d2))
|
| 117 |
+
|
| 118 |
+
pairs = []
|
| 119 |
+
for (d1, d2), (d3, d4) in utils.cross(angs1, angs2):
|
| 120 |
+
if d2 == d3:
|
| 121 |
+
pairs.append((d1, d2, d4))
|
| 122 |
+
|
| 123 |
+
for (d1, d12, d2), (d3, d34, d4) in utils.comb2(pairs):
|
| 124 |
+
if (d1, d12, d2) == (d3, d34, d4):
|
| 125 |
+
continue
|
| 126 |
+
if d1 == d2 or d3 == d4:
|
| 127 |
+
continue
|
| 128 |
+
if d1 == d12 or d12 == d2 or d3 == d34 or d4 == d34:
|
| 129 |
+
continue
|
| 130 |
+
# d12 - d1 = d34 - d3 = m1
|
| 131 |
+
# d2 - d12 = d4 - d34 = m2
|
| 132 |
+
# => d2 - d1 = d4 - d3
|
| 133 |
+
a, b = g.two_points_on_direction(d1)
|
| 134 |
+
c, d = g.two_points_on_direction(d12)
|
| 135 |
+
m, n = g.two_points_on_direction(d3)
|
| 136 |
+
p, q = g.two_points_on_direction(d34)
|
| 137 |
+
# eqangle a b c d m n p q
|
| 138 |
+
e, f = g.two_points_on_direction(d2)
|
| 139 |
+
r, u = g.two_points_on_direction(d4)
|
| 140 |
+
yield dict(zip('abcdefmnpqru', [a, b, c, d, e, f, m, n, p, q, r, u]))
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def match_perp_perp_npara_eqangle(
|
| 144 |
+
g: gh.Graph,
|
| 145 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 146 |
+
theorem: pr.Theorem,
|
| 147 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 148 |
+
"""Match perp A B C D, perp E F G H, npara A B E F => eqangle A B E F C D G H."""
|
| 149 |
+
dpairs = []
|
| 150 |
+
for ang in g.vhalfpi.neighbors(gm.Angle):
|
| 151 |
+
d1, d2 = ang.directions
|
| 152 |
+
if d1 is None or d2 is None:
|
| 153 |
+
continue
|
| 154 |
+
dpairs.append((d1, d2))
|
| 155 |
+
|
| 156 |
+
for (d1, d2), (d3, d4) in utils.comb2(dpairs):
|
| 157 |
+
a, b = g.two_points_on_direction(d1)
|
| 158 |
+
c, d = g.two_points_on_direction(d2)
|
| 159 |
+
m, n = g.two_points_on_direction(d3)
|
| 160 |
+
p, q = g.two_points_on_direction(d4)
|
| 161 |
+
if g.check_npara([a, b, m, n]):
|
| 162 |
+
if ({a, b}, {c, d}) == ({m, n}, {p, q}):
|
| 163 |
+
continue
|
| 164 |
+
if ({a, b}, {c, d}) == ({p, q}, {m, n}):
|
| 165 |
+
continue
|
| 166 |
+
|
| 167 |
+
yield dict(zip('ABCDEFGH', [a, b, c, d, m, n, p, q]))
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def match_circle_coll_eqangle_midp(
|
| 171 |
+
g: gh.Graph,
|
| 172 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 173 |
+
theorem: pr.Theorem,
|
| 174 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 175 |
+
"""Match circle O A B C, coll M B C, eqangle A B A C O B O M => midp M B C."""
|
| 176 |
+
for p, a, b, c in g.all_circles():
|
| 177 |
+
ab = g._get_line(a, b)
|
| 178 |
+
if ab is None:
|
| 179 |
+
continue
|
| 180 |
+
if ab.val is None:
|
| 181 |
+
continue
|
| 182 |
+
ac = g._get_line(a, c)
|
| 183 |
+
if ac is None:
|
| 184 |
+
continue
|
| 185 |
+
if ac.val is None:
|
| 186 |
+
continue
|
| 187 |
+
pb = g._get_line(p, b)
|
| 188 |
+
if pb is None:
|
| 189 |
+
continue
|
| 190 |
+
if pb.val is None:
|
| 191 |
+
continue
|
| 192 |
+
|
| 193 |
+
bc = g._get_line(b, c)
|
| 194 |
+
if bc is None:
|
| 195 |
+
continue
|
| 196 |
+
bc_points = bc.neighbors(gm.Point, return_set=True)
|
| 197 |
+
|
| 198 |
+
anga, _ = g._get_angle(ab.val, ac.val)
|
| 199 |
+
|
| 200 |
+
for angp in pb.val.neighbors(gm.Angle):
|
| 201 |
+
if not g.is_equal(anga, angp):
|
| 202 |
+
continue
|
| 203 |
+
|
| 204 |
+
_, d = angp.directions
|
| 205 |
+
for l in d.neighbors(gm.Line):
|
| 206 |
+
l_points = l.neighbors(gm.Point, return_set=True)
|
| 207 |
+
m = intersect1(bc_points, l_points)
|
| 208 |
+
if m is not None:
|
| 209 |
+
yield dict(zip('ABCMO', [a, b, c, m, p]))
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def match_midp_perp_cong(
|
| 213 |
+
g: gh.Graph,
|
| 214 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 215 |
+
theorem: pr.Theorem,
|
| 216 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 217 |
+
"""Match midp M A B, perp O M A B => cong O A O B."""
|
| 218 |
+
for m, a, b in g.all_midps():
|
| 219 |
+
ab = g._get_line(a, b)
|
| 220 |
+
for l in m.neighbors(gm.Line):
|
| 221 |
+
if g.check_perpl(l, ab):
|
| 222 |
+
for o in l.neighbors(gm.Point):
|
| 223 |
+
if o != m:
|
| 224 |
+
yield dict(zip('ABMO', [a, b, m, o]))
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def match_cyclic_eqangle_cong(
|
| 228 |
+
g: gh.Graph,
|
| 229 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 230 |
+
theorem: pr.Theorem,
|
| 231 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 232 |
+
"""Match cyclic A B C P Q R, eqangle C A C B R P R Q => cong A B P Q."""
|
| 233 |
+
for c in g.type2nodes[gm.Circle]:
|
| 234 |
+
ps = c.neighbors(gm.Point)
|
| 235 |
+
for (a, b, c), (x, y, z) in utils.comb2(list(utils.perm3(ps))):
|
| 236 |
+
if {a, b, c} == {x, y, z}:
|
| 237 |
+
continue
|
| 238 |
+
if g.check_eqangle([c, a, c, b, z, x, z, y]):
|
| 239 |
+
yield dict(zip('ABCPQR', [a, b, c, x, y, z]))
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def match_circle_eqangle_perp(
|
| 243 |
+
g: gh.Graph,
|
| 244 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 245 |
+
theorem: pr.Theorem,
|
| 246 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 247 |
+
"""Match circle O A B C, eqangle A X A B C A C B => perp O A A X."""
|
| 248 |
+
for p, a, b, c in g.all_circles():
|
| 249 |
+
ca = g._get_line(c, a)
|
| 250 |
+
if ca is None:
|
| 251 |
+
continue
|
| 252 |
+
cb = g._get_line(c, b)
|
| 253 |
+
if cb is None:
|
| 254 |
+
continue
|
| 255 |
+
ab = g._get_line(a, b)
|
| 256 |
+
if ab is None:
|
| 257 |
+
continue
|
| 258 |
+
|
| 259 |
+
if ca.val is None:
|
| 260 |
+
continue
|
| 261 |
+
if cb.val is None:
|
| 262 |
+
continue
|
| 263 |
+
if ab.val is None:
|
| 264 |
+
continue
|
| 265 |
+
|
| 266 |
+
c_ang, _ = g._get_angle(cb.val, ca.val)
|
| 267 |
+
if c_ang is None:
|
| 268 |
+
continue
|
| 269 |
+
|
| 270 |
+
for ang in ab.val.neighbors(gm.Angle):
|
| 271 |
+
if g.is_equal(ang, c_ang):
|
| 272 |
+
_, d = ang.directions
|
| 273 |
+
for l in d.neighbors(gm.Line):
|
| 274 |
+
if a not in l.neighbors(gm.Point):
|
| 275 |
+
continue
|
| 276 |
+
x = diff_point(l, a)
|
| 277 |
+
if x is None:
|
| 278 |
+
continue
|
| 279 |
+
yield dict(zip('OABCX', [p, a, b, c, x]))
|
| 280 |
+
break
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def match_circle_perp_eqangle(
|
| 284 |
+
g: gh.Graph,
|
| 285 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 286 |
+
theorem: pr.Theorem,
|
| 287 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 288 |
+
"""Match circle O A B C, perp O A A X => eqangle A X A B C A C B."""
|
| 289 |
+
for p, a, b, c in g.all_circles():
|
| 290 |
+
pa = g._get_line(p, a)
|
| 291 |
+
if pa is None:
|
| 292 |
+
continue
|
| 293 |
+
if pa.val is None:
|
| 294 |
+
continue
|
| 295 |
+
for l in a.neighbors(gm.Line):
|
| 296 |
+
if g.check_perpl(pa, l):
|
| 297 |
+
x = diff_point(l, a)
|
| 298 |
+
if x is not None:
|
| 299 |
+
yield dict(zip('OABCX', [p, a, b, c, x]))
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def match_perp_perp_ncoll_para(
|
| 303 |
+
g: gh.Graph,
|
| 304 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 305 |
+
theorem: pr.Theorem,
|
| 306 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 307 |
+
"""Match perp A B C D, perp C D E F, ncoll A B E => para A B E F."""
|
| 308 |
+
d2d = defaultdict(list)
|
| 309 |
+
for ang in g.vhalfpi.neighbors(gm.Angle):
|
| 310 |
+
d1, d2 = ang.directions
|
| 311 |
+
if d1 is None or d2 is None:
|
| 312 |
+
continue
|
| 313 |
+
d2d[d1] += [d2]
|
| 314 |
+
d2d[d2] += [d1]
|
| 315 |
+
|
| 316 |
+
for x, ys in d2d.items():
|
| 317 |
+
if len(ys) < 2:
|
| 318 |
+
continue
|
| 319 |
+
c, d = g.two_points_on_direction(x)
|
| 320 |
+
for y1, y2 in utils.comb2(ys):
|
| 321 |
+
a, b = g.two_points_on_direction(y1)
|
| 322 |
+
e, f = g.two_points_on_direction(y2)
|
| 323 |
+
if nm.check_ncoll([a.num, b.num, e.num]):
|
| 324 |
+
yield dict(zip('ABCDEF', [a, b, c, d, e, f]))
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def match_eqangle6_ncoll_cong(
|
| 328 |
+
g: gh.Graph,
|
| 329 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 330 |
+
theorem: pr.Theorem,
|
| 331 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 332 |
+
"""Match eqangle6 A O A B B A B O, ncoll O A B => cong O A O B."""
|
| 333 |
+
for a in g.type2nodes[gm.Point]:
|
| 334 |
+
for b, c in utils.comb2(g.type2nodes[gm.Point]):
|
| 335 |
+
if a == b or a == c:
|
| 336 |
+
continue
|
| 337 |
+
if g.check_eqangle([b, a, b, c, c, b, c, a]):
|
| 338 |
+
if g.check_ncoll([a, b, c]):
|
| 339 |
+
yield dict(zip('OAB', [a, b, c]))
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def match_eqangle_perp_perp(
|
| 343 |
+
g: gh.Graph,
|
| 344 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 345 |
+
theorem: pr.Theorem,
|
| 346 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 347 |
+
"""Match eqangle A B P Q C D U V, perp P Q U V => perp A B C D."""
|
| 348 |
+
for ang in g.vhalfpi.neighbors(gm.Angle):
|
| 349 |
+
# d1 perp d2
|
| 350 |
+
d1, d2 = ang.directions
|
| 351 |
+
if d1 is None or d2 is None:
|
| 352 |
+
continue
|
| 353 |
+
for d3, d4 in utils.comb2(g.type2nodes[gm.Direction]):
|
| 354 |
+
if d1 == d3 or d2 == d4:
|
| 355 |
+
continue
|
| 356 |
+
# if d1 - d3 = d2 - d4 => d3 perp d4
|
| 357 |
+
a13, a31 = g._get_angle(d1, d3)
|
| 358 |
+
a24, a42 = g._get_angle(d2, d4)
|
| 359 |
+
if a13 is None or a31 is None or a24 is None or a42 is None:
|
| 360 |
+
continue
|
| 361 |
+
if g.is_equal(a13, a24) and g.is_equal(a31, a42):
|
| 362 |
+
a, b = g.two_points_on_direction(d1)
|
| 363 |
+
c, d = g.two_points_on_direction(d2)
|
| 364 |
+
m, n = g.two_points_on_direction(d3)
|
| 365 |
+
p, q = g.two_points_on_direction(d4)
|
| 366 |
+
yield dict(zip('ABCDPQUV', [m, n, p, q, a, b, c, d]))
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def match_eqangle_ncoll_cyclic(
|
| 370 |
+
g: gh.Graph,
|
| 371 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 372 |
+
theorem: pr.Theorem,
|
| 373 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 374 |
+
"""Match eqangle6 P A P B Q A Q B, ncoll P Q A B => cyclic A B P Q."""
|
| 375 |
+
for l1, l2, l3, l4 in g.all_eqangles_distinct_linepairss():
|
| 376 |
+
if len(set([l1, l2, l3, l4])) < 4:
|
| 377 |
+
continue # they all must be distinct.
|
| 378 |
+
|
| 379 |
+
p1s = l1.neighbors(gm.Point, return_set=True)
|
| 380 |
+
p2s = l2.neighbors(gm.Point, return_set=True)
|
| 381 |
+
p3s = l3.neighbors(gm.Point, return_set=True)
|
| 382 |
+
p4s = l4.neighbors(gm.Point, return_set=True)
|
| 383 |
+
|
| 384 |
+
p = intersect1(p1s, p2s)
|
| 385 |
+
if not p:
|
| 386 |
+
continue
|
| 387 |
+
q = intersect1(p3s, p4s)
|
| 388 |
+
if not q:
|
| 389 |
+
continue
|
| 390 |
+
a = intersect1(p1s, p3s)
|
| 391 |
+
if not a:
|
| 392 |
+
continue
|
| 393 |
+
b = intersect1(p2s, p4s)
|
| 394 |
+
if not b:
|
| 395 |
+
continue
|
| 396 |
+
if len(set([a, b, p, q])) < 4:
|
| 397 |
+
continue
|
| 398 |
+
|
| 399 |
+
if not g.check_ncoll([a, b, p, q]):
|
| 400 |
+
continue
|
| 401 |
+
|
| 402 |
+
yield dict(zip('ABPQ', [a, b, p, q]))
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def match_eqangle_para(
|
| 406 |
+
g: gh.Graph,
|
| 407 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 408 |
+
theorem: pr.Theorem,
|
| 409 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 410 |
+
"""Match eqangle A B P Q C D P Q => para A B C D."""
|
| 411 |
+
for measure in g.type2nodes[gm.Measure]:
|
| 412 |
+
angs = measure.neighbors(gm.Angle)
|
| 413 |
+
d12, d21 = defaultdict(list), defaultdict(list)
|
| 414 |
+
for ang in angs:
|
| 415 |
+
d1, d2 = ang.directions
|
| 416 |
+
if d1 is None or d2 is None:
|
| 417 |
+
continue
|
| 418 |
+
d12[d1].append(d2)
|
| 419 |
+
d21[d2].append(d1)
|
| 420 |
+
|
| 421 |
+
for d1, d2s in d12.items():
|
| 422 |
+
a, b = g.two_points_on_direction(d1)
|
| 423 |
+
for d2, d3 in utils.comb2(d2s):
|
| 424 |
+
c, d = g.two_points_on_direction(d2)
|
| 425 |
+
e, f = g.two_points_on_direction(d3)
|
| 426 |
+
yield dict(zip('ABCDPQ', [c, d, e, f, a, b]))
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def match_cyclic_eqangle(
|
| 430 |
+
g: gh.Graph,
|
| 431 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 432 |
+
theorem: pr.Theorem,
|
| 433 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 434 |
+
"""Match cyclic A B P Q => eqangle P A P B Q A Q B."""
|
| 435 |
+
record = set()
|
| 436 |
+
for a, b, c, d in g_matcher('cyclic'):
|
| 437 |
+
if (a, b, c, d) in record:
|
| 438 |
+
continue
|
| 439 |
+
record.add((a, b, c, d))
|
| 440 |
+
record.add((a, b, d, c))
|
| 441 |
+
record.add((b, a, c, d))
|
| 442 |
+
record.add((b, a, d, c))
|
| 443 |
+
yield dict(zip('ABPQ', [a, b, c, d]))
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def rotate_simtri(
|
| 447 |
+
a: gm.Point, b: gm.Point, c: gm.Point, x: gm.Point, y: gm.Point, z: gm.Point
|
| 448 |
+
) -> Generator[tuple[gm.Point, ...], None, None]:
|
| 449 |
+
"""Rotate points around for similar triangle predicates."""
|
| 450 |
+
yield (z, y, x, c, b, a)
|
| 451 |
+
for p in [
|
| 452 |
+
(b, c, a, y, z, x),
|
| 453 |
+
(c, a, b, z, x, y),
|
| 454 |
+
(x, y, z, a, b, c),
|
| 455 |
+
(y, z, x, b, c, a),
|
| 456 |
+
(z, x, y, c, a, b),
|
| 457 |
+
]:
|
| 458 |
+
yield p
|
| 459 |
+
yield p[::-1]
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def match_cong_cong_cong_cyclic(
|
| 463 |
+
g: gh.Graph,
|
| 464 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 465 |
+
theorem: pr.Theorem,
|
| 466 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 467 |
+
"""Match cong O A O B, cong O B O C, cong O C O D => cyclic A B C D."""
|
| 468 |
+
for l in g.type2nodes[gm.Length]:
|
| 469 |
+
p2p = defaultdict(list)
|
| 470 |
+
for s in l.neighbors(gm.Segment):
|
| 471 |
+
a, b = s.points
|
| 472 |
+
p2p[a].append(b)
|
| 473 |
+
p2p[b].append(a)
|
| 474 |
+
|
| 475 |
+
for p, ps in p2p.items():
|
| 476 |
+
if len(ps) >= 4:
|
| 477 |
+
for a, b, c, d in utils.comb4(ps):
|
| 478 |
+
yield dict(zip('OABCD', [p, a, b, c, d]))
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def match_cong_cong_cong_ncoll_contri(
|
| 482 |
+
g: gh.Graph,
|
| 483 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 484 |
+
theorem: pr.Theorem,
|
| 485 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 486 |
+
"""Match cong A B P Q, cong B C Q R, cong C A R P, ncoll A B C => contri* A B C P Q R."""
|
| 487 |
+
record = set()
|
| 488 |
+
for a, b, p, q in g_matcher('cong'):
|
| 489 |
+
for c in g.type2nodes[gm.Point]:
|
| 490 |
+
for r in g.type2nodes[gm.Point]:
|
| 491 |
+
if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]):
|
| 492 |
+
continue
|
| 493 |
+
if not g.check_ncoll([a, b, c]):
|
| 494 |
+
continue
|
| 495 |
+
if g.check_cong([b, c, q, r]) and g.check_cong([c, a, r, p]):
|
| 496 |
+
record.add((a, b, c, p, q, r))
|
| 497 |
+
yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def match_cong_cong_eqangle6_ncoll_contri(
|
| 501 |
+
g: gh.Graph,
|
| 502 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 503 |
+
theorem: pr.Theorem,
|
| 504 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 505 |
+
"""Match cong A B P Q, cong B C Q R, eqangle6 B A B C Q P Q R, ncoll A B C => contri* A B C P Q R."""
|
| 506 |
+
record = set()
|
| 507 |
+
for a, b, p, q in g_matcher('cong'):
|
| 508 |
+
for c in g.type2nodes[gm.Point]:
|
| 509 |
+
if c in (a, b):
|
| 510 |
+
continue
|
| 511 |
+
for r in g.type2nodes[gm.Point]:
|
| 512 |
+
if r in (p, q):
|
| 513 |
+
continue
|
| 514 |
+
|
| 515 |
+
in_record = False
|
| 516 |
+
for x in [
|
| 517 |
+
(c, b, a, r, q, p),
|
| 518 |
+
(p, q, r, a, b, c),
|
| 519 |
+
(r, q, p, c, b, a),
|
| 520 |
+
]:
|
| 521 |
+
if x in record:
|
| 522 |
+
in_record = True
|
| 523 |
+
break
|
| 524 |
+
|
| 525 |
+
if in_record:
|
| 526 |
+
continue
|
| 527 |
+
|
| 528 |
+
if not g.check_cong([b, c, q, r]):
|
| 529 |
+
continue
|
| 530 |
+
if not g.check_ncoll([a, b, c]):
|
| 531 |
+
continue
|
| 532 |
+
|
| 533 |
+
if nm.same_clock(a.num, b.num, c.num, p.num, q.num, r.num):
|
| 534 |
+
if g.check_eqangle([b, a, b, c, q, p, q, r]):
|
| 535 |
+
record.add((a, b, c, p, q, r))
|
| 536 |
+
yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
|
| 537 |
+
else:
|
| 538 |
+
if g.check_eqangle([b, a, b, c, q, r, q, p]):
|
| 539 |
+
record.add((a, b, c, p, q, r))
|
| 540 |
+
yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def match_eqratio6_eqangle6_ncoll_simtri(
|
| 544 |
+
g: gh.Graph,
|
| 545 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 546 |
+
theorem: pr.Theorem,
|
| 547 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 548 |
+
"""Match eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C => simtri* A B C P Q R."""
|
| 549 |
+
enums = g_matcher('eqratio6')
|
| 550 |
+
|
| 551 |
+
record = set()
|
| 552 |
+
for b, a, b, c, q, p, q, r in enums: # pylint: disable=redeclared-assigned-name,unused-variable
|
| 553 |
+
if (a, b, c) == (p, q, r):
|
| 554 |
+
continue
|
| 555 |
+
if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]):
|
| 556 |
+
continue
|
| 557 |
+
if not g.check_ncoll([a, b, c]):
|
| 558 |
+
continue
|
| 559 |
+
|
| 560 |
+
if nm.same_clock(a.num, b.num, c.num, p.num, q.num, r.num):
|
| 561 |
+
if g.check_eqangle([b, a, b, c, q, p, q, r]):
|
| 562 |
+
record.add((a, b, c, p, q, r))
|
| 563 |
+
yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
|
| 564 |
+
elif g.check_eqangle([b, a, b, c, q, r, q, p]):
|
| 565 |
+
record.add((a, b, c, p, q, r))
|
| 566 |
+
yield dict(zip('ABCPQR', [a, b, c, p, q, r]))
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
def match_eqangle6_eqangle6_ncoll_simtri(
|
| 570 |
+
g: gh.Graph,
|
| 571 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 572 |
+
theorem: pr.Theorem,
|
| 573 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 574 |
+
"""Match eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C => simtri A B C P Q R."""
|
| 575 |
+
enums = g_matcher('eqangle6')
|
| 576 |
+
|
| 577 |
+
record = set()
|
| 578 |
+
for b, a, b, c, q, p, q, r in enums: # pylint: disable=redeclared-assigned-name,unused-variable
|
| 579 |
+
if (a, b, c) == (p, q, r):
|
| 580 |
+
continue
|
| 581 |
+
if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]):
|
| 582 |
+
continue
|
| 583 |
+
if not g.check_eqangle([c, a, c, b, r, p, r, q]):
|
| 584 |
+
continue
|
| 585 |
+
if not g.check_ncoll([a, b, c]):
|
| 586 |
+
continue
|
| 587 |
+
|
| 588 |
+
mapping = dict(zip('ABCPQR', [a, b, c, p, q, r]))
|
| 589 |
+
record.add((a, b, c, p, q, r))
|
| 590 |
+
yield mapping
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def match_eqratio6_eqratio6_ncoll_simtri(
|
| 594 |
+
g: gh.Graph,
|
| 595 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 596 |
+
theorem: pr.Theorem,
|
| 597 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 598 |
+
"""Match eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C => simtri* A B C P Q R."""
|
| 599 |
+
enums = g_matcher('eqratio6')
|
| 600 |
+
|
| 601 |
+
record = set()
|
| 602 |
+
for b, a, b, c, q, p, q, r in enums: # pylint: disable=redeclared-assigned-name,unused-variable
|
| 603 |
+
if (a, b, c) == (p, q, r):
|
| 604 |
+
continue
|
| 605 |
+
if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]):
|
| 606 |
+
continue
|
| 607 |
+
if not g.check_eqratio([c, a, c, b, r, p, r, q]):
|
| 608 |
+
continue
|
| 609 |
+
if not g.check_ncoll([a, b, c]):
|
| 610 |
+
continue
|
| 611 |
+
|
| 612 |
+
mapping = dict(zip('ABCPQR', [a, b, c, p, q, r]))
|
| 613 |
+
record.add((a, b, c, p, q, r))
|
| 614 |
+
yield mapping
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
def match_eqangle6_eqangle6_ncoll_simtri2(
|
| 618 |
+
g: gh.Graph,
|
| 619 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 620 |
+
theorem: pr.Theorem,
|
| 621 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 622 |
+
"""Match eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C => simtri2 A B C P Q R."""
|
| 623 |
+
enums = g_matcher('eqangle6')
|
| 624 |
+
|
| 625 |
+
record = set()
|
| 626 |
+
for b, a, b, c, q, r, q, p in enums: # pylint: disable=redeclared-assigned-name,unused-variable
|
| 627 |
+
if (a, b, c) == (p, q, r):
|
| 628 |
+
continue
|
| 629 |
+
if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]):
|
| 630 |
+
continue
|
| 631 |
+
if not g.check_eqangle([c, a, c, b, r, q, r, p]):
|
| 632 |
+
continue
|
| 633 |
+
if not g.check_ncoll([a, b, c]):
|
| 634 |
+
continue
|
| 635 |
+
|
| 636 |
+
mapping = dict(zip('ABCPQR', [a, b, c, p, q, r]))
|
| 637 |
+
record.add((a, b, c, p, q, r))
|
| 638 |
+
yield mapping
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def rotate_contri(
|
| 642 |
+
a: gm.Point, b: gm.Point, c: gm.Point, x: gm.Point, y: gm.Point, z: gm.Point
|
| 643 |
+
) -> Generator[tuple[gm.Point, ...], None, None]:
|
| 644 |
+
for p in [(b, a, c, y, x, z), (x, y, z, a, b, c), (y, x, z, b, a, c)]:
|
| 645 |
+
yield p
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
def match_eqangle6_eqangle6_ncoll_cong_contri(
|
| 649 |
+
g: gh.Graph,
|
| 650 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 651 |
+
theorem: pr.Theorem,
|
| 652 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 653 |
+
"""Match eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri A B C P Q R."""
|
| 654 |
+
enums = g_matcher('eqangle6')
|
| 655 |
+
|
| 656 |
+
record = set()
|
| 657 |
+
for b, a, b, c, q, p, q, r in enums: # pylint: disable=redeclared-assigned-name,unused-variable
|
| 658 |
+
if not g.check_cong([a, b, p, q]):
|
| 659 |
+
continue
|
| 660 |
+
if (a, b, c) == (p, q, r):
|
| 661 |
+
continue
|
| 662 |
+
if any([x in record for x in rotate_contri(a, b, c, p, q, r)]):
|
| 663 |
+
continue
|
| 664 |
+
if not g.check_eqangle([c, a, c, b, r, p, r, q]):
|
| 665 |
+
continue
|
| 666 |
+
|
| 667 |
+
if not g.check_ncoll([a, b, c]):
|
| 668 |
+
continue
|
| 669 |
+
|
| 670 |
+
mapping = dict(zip('ABCPQR', [a, b, c, p, q, r]))
|
| 671 |
+
record.add((a, b, c, p, q, r))
|
| 672 |
+
yield mapping
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
def match_eqratio6_eqratio6_ncoll_cong_contri(
|
| 676 |
+
g: gh.Graph,
|
| 677 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 678 |
+
theorem: pr.Theorem,
|
| 679 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 680 |
+
"""Match eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri* A B C P Q R."""
|
| 681 |
+
enums = g_matcher('eqratio6')
|
| 682 |
+
|
| 683 |
+
record = set()
|
| 684 |
+
for b, a, b, c, q, p, q, r in enums: # pylint: disable=redeclared-assigned-name,unused-variable
|
| 685 |
+
if not g.check_cong([a, b, p, q]):
|
| 686 |
+
continue
|
| 687 |
+
if (a, b, c) == (p, q, r):
|
| 688 |
+
continue
|
| 689 |
+
if any([x in record for x in rotate_contri(a, b, c, p, q, r)]):
|
| 690 |
+
continue
|
| 691 |
+
if not g.check_eqratio([c, a, c, b, r, p, r, q]):
|
| 692 |
+
continue
|
| 693 |
+
|
| 694 |
+
if not g.check_ncoll([a, b, c]):
|
| 695 |
+
continue
|
| 696 |
+
|
| 697 |
+
mapping = dict(zip('ABCPQR', [a, b, c, p, q, r]))
|
| 698 |
+
record.add((a, b, c, p, q, r))
|
| 699 |
+
yield mapping
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
def match_eqangle6_eqangle6_ncoll_cong_contri2(
|
| 703 |
+
g: gh.Graph,
|
| 704 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 705 |
+
theorem: pr.Theorem,
|
| 706 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 707 |
+
"""Match eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C, cong A B P Q => contri2 A B C P Q R."""
|
| 708 |
+
enums = g_matcher('eqangle6')
|
| 709 |
+
|
| 710 |
+
record = set()
|
| 711 |
+
for b, a, b, c, q, r, q, p in enums: # pylint: disable=redeclared-assigned-name,unused-variable
|
| 712 |
+
if not g.check_cong([a, b, p, q]):
|
| 713 |
+
continue
|
| 714 |
+
if (a, b, c) == (p, q, r):
|
| 715 |
+
continue
|
| 716 |
+
if any([x in record for x in rotate_contri(a, b, c, p, q, r)]):
|
| 717 |
+
continue
|
| 718 |
+
if not g.check_eqangle([c, a, c, b, r, q, r, p]):
|
| 719 |
+
continue
|
| 720 |
+
if not g.check_ncoll([a, b, c]):
|
| 721 |
+
continue
|
| 722 |
+
|
| 723 |
+
mapping = dict(zip('ABCPQR', [a, b, c, p, q, r]))
|
| 724 |
+
record.add((a, b, c, p, q, r))
|
| 725 |
+
yield mapping
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def match_eqratio6_coll_ncoll_eqangle6(
|
| 729 |
+
g: gh.Graph,
|
| 730 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 731 |
+
theorem: pr.Theorem,
|
| 732 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 733 |
+
"""Match eqratio6 d b d c a b a c, coll d b c, ncoll a b c => eqangle6 a b a d a d a c."""
|
| 734 |
+
records = set()
|
| 735 |
+
for b, d, c in g_matcher('coll'):
|
| 736 |
+
for a in g.all_points():
|
| 737 |
+
if not g.check_ncoll([a, b, c]):
|
| 738 |
+
continue
|
| 739 |
+
if (a, b, d, c) in records or (a, c, d, b) in records:
|
| 740 |
+
continue
|
| 741 |
+
records.add((a, b, d, c))
|
| 742 |
+
|
| 743 |
+
if g.check_eqratio([d, b, d, c, a, b, a, c]):
|
| 744 |
+
yield dict(zip('abcd', [a, b, c, d]))
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
def match_eqangle6_coll_ncoll_eqratio6(
|
| 748 |
+
g: gh.Graph,
|
| 749 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 750 |
+
theorem: pr.Theorem,
|
| 751 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 752 |
+
"""Match eqangle6 a b a d a d a c, coll d b c, ncoll a b c => eqratio6 d b d c a b a c."""
|
| 753 |
+
records = set()
|
| 754 |
+
for b, d, c in g_matcher('coll'):
|
| 755 |
+
for a in g.all_points():
|
| 756 |
+
if not g.check_ncoll([a, b, c]):
|
| 757 |
+
continue
|
| 758 |
+
if (a, b, d, c) in records or (a, c, d, b) in records:
|
| 759 |
+
continue
|
| 760 |
+
records.add((a, b, d, c))
|
| 761 |
+
|
| 762 |
+
if g.check_eqangle([a, b, a, d, a, d, a, c]):
|
| 763 |
+
yield dict(zip('abcd', [a, b, c, d]))
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
def match_eqangle6_ncoll_cyclic(
|
| 767 |
+
g: gh.Graph,
|
| 768 |
+
g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
|
| 769 |
+
theorem: pr.Theorem,
|
| 770 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 771 |
+
"""Match eqangle6 P A P B Q A Q B, ncoll P Q A B => cyclic A B P Q."""
|
| 772 |
+
for a, b, a, c, x, y, x, z in g_matcher('eqangle6'): # pylint: disable=redeclared-assigned-name,unused-variable
|
| 773 |
+
if (b, c) != (y, z) or a == x:
|
| 774 |
+
continue
|
| 775 |
+
if nm.check_ncoll([x.num for x in [a, b, c, x]]):
|
| 776 |
+
yield dict(zip('ABPQ', [b, c, a, x]))
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
def match_all(
|
| 780 |
+
name: str, g: gh.Graph
|
| 781 |
+
) -> Generator[tuple[gm.Point, ...], None, None]:
|
| 782 |
+
"""Match all instances of a certain relation."""
|
| 783 |
+
if name in ['ncoll', 'npara', 'nperp']:
|
| 784 |
+
return []
|
| 785 |
+
if name == 'coll':
|
| 786 |
+
return g.all_colls()
|
| 787 |
+
if name == 'para':
|
| 788 |
+
return g.all_paras()
|
| 789 |
+
if name == 'perp':
|
| 790 |
+
return g.all_perps()
|
| 791 |
+
if name == 'cong':
|
| 792 |
+
return g.all_congs()
|
| 793 |
+
if name == 'eqangle':
|
| 794 |
+
return g.all_eqangles_8points()
|
| 795 |
+
if name == 'eqangle6':
|
| 796 |
+
return g.all_eqangles_6points()
|
| 797 |
+
if name == 'eqratio':
|
| 798 |
+
return g.all_eqratios_8points()
|
| 799 |
+
if name == 'eqratio6':
|
| 800 |
+
return g.all_eqratios_6points()
|
| 801 |
+
if name == 'cyclic':
|
| 802 |
+
return g.all_cyclics()
|
| 803 |
+
if name == 'midp':
|
| 804 |
+
return g.all_midps()
|
| 805 |
+
if name == 'circle':
|
| 806 |
+
return g.all_circles()
|
| 807 |
+
raise ValueError(f'Unrecognize {name}')
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
def cache_match(
|
| 811 |
+
graph: gh.Graph,
|
| 812 |
+
) -> Callable[str, list[tuple[gm.Point, ...]]]:
|
| 813 |
+
"""Cache throughout one single BFS level."""
|
| 814 |
+
cache = {}
|
| 815 |
+
|
| 816 |
+
def match_fn(name: str) -> list[tuple[gm.Point, ...]]:
|
| 817 |
+
if name in cache:
|
| 818 |
+
return cache[name]
|
| 819 |
+
|
| 820 |
+
result = list(match_all(name, graph))
|
| 821 |
+
cache[name] = result
|
| 822 |
+
return result
|
| 823 |
+
|
| 824 |
+
return match_fn
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
def try_to_map(
|
| 828 |
+
clause_enum: list[tuple[pr.Clause, list[tuple[gm.Point, ...]]]],
|
| 829 |
+
mapping: dict[str, gm.Point],
|
| 830 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 831 |
+
"""Recursively try to match the remaining points given current mapping."""
|
| 832 |
+
if not clause_enum:
|
| 833 |
+
yield mapping
|
| 834 |
+
return
|
| 835 |
+
|
| 836 |
+
clause, enum = clause_enum[0]
|
| 837 |
+
for points in enum:
|
| 838 |
+
mpcpy = dict(mapping)
|
| 839 |
+
|
| 840 |
+
fail = False
|
| 841 |
+
for p, a in zip(points, clause.args):
|
| 842 |
+
if a in mpcpy and mpcpy[a] != p or p in mpcpy and mpcpy[p] != a:
|
| 843 |
+
fail = True
|
| 844 |
+
break
|
| 845 |
+
mpcpy[a] = p
|
| 846 |
+
mpcpy[p] = a
|
| 847 |
+
|
| 848 |
+
if fail:
|
| 849 |
+
continue
|
| 850 |
+
|
| 851 |
+
for m in try_to_map(clause_enum[1:], mpcpy):
|
| 852 |
+
yield m
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
def match_generic(
|
| 856 |
+
g: gh.Graph,
|
| 857 |
+
cache: Callable[str, list[tuple[gm.Point, ...]]],
|
| 858 |
+
theorem: pr.Theorem
|
| 859 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 860 |
+
"""Match any generic rule that is not one of the above match_*() rules."""
|
| 861 |
+
clause2enum = {}
|
| 862 |
+
|
| 863 |
+
clauses = []
|
| 864 |
+
numerical_checks = []
|
| 865 |
+
for clause in theorem.premise:
|
| 866 |
+
if clause.name in ['ncoll', 'npara', 'nperp', 'sameside']:
|
| 867 |
+
numerical_checks.append(clause)
|
| 868 |
+
continue
|
| 869 |
+
|
| 870 |
+
enum = cache(clause.name)
|
| 871 |
+
if len(enum) == 0: # pylint: disable=g-explicit-length-test
|
| 872 |
+
return 0
|
| 873 |
+
|
| 874 |
+
clause2enum[clause] = enum
|
| 875 |
+
clauses.append((len(set(clause.args)), clause))
|
| 876 |
+
|
| 877 |
+
clauses = sorted(clauses, key=lambda x: x[0], reverse=True)
|
| 878 |
+
_, clauses = zip(*clauses)
|
| 879 |
+
|
| 880 |
+
for mapping in try_to_map([(c, clause2enum[c]) for c in clauses], {}):
|
| 881 |
+
if not mapping:
|
| 882 |
+
continue
|
| 883 |
+
|
| 884 |
+
checks_ok = True
|
| 885 |
+
for check in numerical_checks:
|
| 886 |
+
args = [mapping[a] for a in check.args]
|
| 887 |
+
if check.name == 'ncoll':
|
| 888 |
+
checks_ok = g.check_ncoll(args)
|
| 889 |
+
elif check.name == 'npara':
|
| 890 |
+
checks_ok = g.check_npara(args)
|
| 891 |
+
elif check.name == 'nperp':
|
| 892 |
+
checks_ok = g.check_nperp(args)
|
| 893 |
+
elif check.name == 'sameside':
|
| 894 |
+
checks_ok = g.check_sameside(args)
|
| 895 |
+
if not checks_ok:
|
| 896 |
+
break
|
| 897 |
+
if not checks_ok:
|
| 898 |
+
continue
|
| 899 |
+
|
| 900 |
+
yield mapping
|
| 901 |
+
|
| 902 |
+
|
| 903 |
+
BUILT_IN_FNS = {
|
| 904 |
+
'cong_cong_cong_cyclic': match_cong_cong_cong_cyclic,
|
| 905 |
+
'cong_cong_cong_ncoll_contri*': match_cong_cong_cong_ncoll_contri,
|
| 906 |
+
'cong_cong_eqangle6_ncoll_contri*': match_cong_cong_eqangle6_ncoll_contri,
|
| 907 |
+
'eqangle6_eqangle6_ncoll_simtri': match_eqangle6_eqangle6_ncoll_simtri,
|
| 908 |
+
'eqangle6_eqangle6_ncoll_cong_contri': (
|
| 909 |
+
match_eqangle6_eqangle6_ncoll_cong_contri
|
| 910 |
+
), # pylint: disable=line-too-long
|
| 911 |
+
'eqangle6_eqangle6_ncoll_simtri2': match_eqangle6_eqangle6_ncoll_simtri2,
|
| 912 |
+
'eqangle6_eqangle6_ncoll_cong_contri2': (
|
| 913 |
+
match_eqangle6_eqangle6_ncoll_cong_contri2
|
| 914 |
+
), # pylint: disable=line-too-long
|
| 915 |
+
'eqratio6_eqratio6_ncoll_simtri*': match_eqratio6_eqratio6_ncoll_simtri,
|
| 916 |
+
'eqratio6_eqratio6_ncoll_cong_contri*': (
|
| 917 |
+
match_eqratio6_eqratio6_ncoll_cong_contri
|
| 918 |
+
), # pylint: disable=line-too-long
|
| 919 |
+
'eqangle_para': match_eqangle_para,
|
| 920 |
+
'eqangle_ncoll_cyclic': match_eqangle_ncoll_cyclic,
|
| 921 |
+
'eqratio6_eqangle6_ncoll_simtri*': match_eqratio6_eqangle6_ncoll_simtri,
|
| 922 |
+
'eqangle_perp_perp': match_eqangle_perp_perp,
|
| 923 |
+
'eqangle6_ncoll_cong': match_eqangle6_ncoll_cong,
|
| 924 |
+
'perp_perp_ncoll_para': match_perp_perp_ncoll_para,
|
| 925 |
+
'circle_perp_eqangle': match_circle_perp_eqangle,
|
| 926 |
+
'circle_eqangle_perp': match_circle_eqangle_perp,
|
| 927 |
+
'cyclic_eqangle_cong': match_cyclic_eqangle_cong,
|
| 928 |
+
'midp_perp_cong': match_midp_perp_cong,
|
| 929 |
+
'perp_perp_npara_eqangle': match_perp_perp_npara_eqangle,
|
| 930 |
+
'cyclic_eqangle': match_cyclic_eqangle,
|
| 931 |
+
'eqangle_eqangle_eqangle': match_eqangle_eqangle_eqangle,
|
| 932 |
+
'eqratio_eqratio_eqratio': match_eqratio_eqratio_eqratio,
|
| 933 |
+
'eqratio6_coll_ncoll_eqangle6': match_eqratio6_coll_ncoll_eqangle6,
|
| 934 |
+
'eqangle6_coll_ncoll_eqratio6': match_eqangle6_coll_ncoll_eqratio6,
|
| 935 |
+
'eqangle6_ncoll_cyclic': match_eqangle6_ncoll_cyclic,
|
| 936 |
+
}
|
| 937 |
+
|
| 938 |
+
|
| 939 |
+
SKIP_THEOREMS = set()
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
def set_skip_theorems(theorems: set[str]) -> None:
|
| 943 |
+
SKIP_THEOREMS.update(theorems)
|
| 944 |
+
|
| 945 |
+
|
| 946 |
+
MAX_BRANCH = 50_000
|
| 947 |
+
|
| 948 |
+
|
| 949 |
+
def match_one_theorem(
|
| 950 |
+
g: gh.Graph,
|
| 951 |
+
cache: Callable[str, list[tuple[gm.Point, ...]]],
|
| 952 |
+
theorem: pr.Theorem
|
| 953 |
+
) -> Generator[dict[str, gm.Point], None, None]:
|
| 954 |
+
"""Match all instances of a single theorem (rule)."""
|
| 955 |
+
if cache is None:
|
| 956 |
+
cache = cache_match(g)
|
| 957 |
+
|
| 958 |
+
if theorem.name in SKIP_THEOREMS:
|
| 959 |
+
return []
|
| 960 |
+
|
| 961 |
+
if theorem.name.split('_')[-1] in SKIP_THEOREMS:
|
| 962 |
+
return []
|
| 963 |
+
|
| 964 |
+
if theorem.name in BUILT_IN_FNS:
|
| 965 |
+
mps = BUILT_IN_FNS[theorem.name](g, cache, theorem)
|
| 966 |
+
else:
|
| 967 |
+
mps = match_generic(g, cache, theorem)
|
| 968 |
+
|
| 969 |
+
mappings = []
|
| 970 |
+
for mp in mps:
|
| 971 |
+
mappings.append(mp)
|
| 972 |
+
if len(mappings) > MAX_BRANCH: # cap branching at this number.
|
| 973 |
+
break
|
| 974 |
+
|
| 975 |
+
return mappings
|
| 976 |
+
|
| 977 |
+
|
| 978 |
+
def match_all_theorems(
|
| 979 |
+
g: gh.Graph, theorems: list[pr.Theorem], goal: pr.Clause
|
| 980 |
+
) -> dict[pr.Theorem, dict[pr.Theorem, dict[str, gm.Point]]]:
|
| 981 |
+
"""Match all instances of all theorems (rules)."""
|
| 982 |
+
cache = cache_match(g)
|
| 983 |
+
# for BFS, collect all potential matches
|
| 984 |
+
# and then do it at the same time
|
| 985 |
+
theorem2mappings = {}
|
| 986 |
+
|
| 987 |
+
# Step 1: list all matches
|
| 988 |
+
for _, theorem in theorems.items():
|
| 989 |
+
name = theorem.name
|
| 990 |
+
if name.split('_')[-1] in [
|
| 991 |
+
'acompute',
|
| 992 |
+
'rcompute',
|
| 993 |
+
'fixl',
|
| 994 |
+
'fixc',
|
| 995 |
+
'fixb',
|
| 996 |
+
'fixt',
|
| 997 |
+
'fixp',
|
| 998 |
+
]:
|
| 999 |
+
if goal and goal.name != name:
|
| 1000 |
+
continue
|
| 1001 |
+
|
| 1002 |
+
mappings = match_one_theorem(g, cache, theorem)
|
| 1003 |
+
if len(mappings): # pylint: disable=g-explicit-length-test
|
| 1004 |
+
theorem2mappings[theorem] = list(mappings)
|
| 1005 |
+
return theorem2mappings
|
| 1006 |
+
|
| 1007 |
+
|
| 1008 |
+
def bfs_one_level(
|
| 1009 |
+
g: gh.Graph,
|
| 1010 |
+
theorems: list[pr.Theorem],
|
| 1011 |
+
level: int,
|
| 1012 |
+
controller: pr.Problem,
|
| 1013 |
+
verbose: bool = False,
|
| 1014 |
+
nm_check: bool = False,
|
| 1015 |
+
timeout: int = 600,
|
| 1016 |
+
) -> tuple[
|
| 1017 |
+
list[pr.Dependency],
|
| 1018 |
+
dict[str, list[tuple[gm.Point, ...]]],
|
| 1019 |
+
dict[str, list[tuple[gm.Point, ...]]],
|
| 1020 |
+
int,
|
| 1021 |
+
]:
|
| 1022 |
+
"""Forward deduce one breadth-first level."""
|
| 1023 |
+
|
| 1024 |
+
# Step 1: match all theorems:
|
| 1025 |
+
theorem2mappings = match_all_theorems(g, theorems, controller.goal)
|
| 1026 |
+
|
| 1027 |
+
# Step 2: traceback for each deduce:
|
| 1028 |
+
theorem2deps = {}
|
| 1029 |
+
t0 = time.time()
|
| 1030 |
+
for theorem, mappings in theorem2mappings.items():
|
| 1031 |
+
if time.time() - t0 > timeout:
|
| 1032 |
+
break
|
| 1033 |
+
mp_deps = []
|
| 1034 |
+
for mp in mappings:
|
| 1035 |
+
deps = EmptyDependency(level=level, rule_name=theorem.rule_name)
|
| 1036 |
+
fail = False # finding why deps might fail.
|
| 1037 |
+
|
| 1038 |
+
for p in theorem.premise:
|
| 1039 |
+
p_args = [mp[a] for a in p.args]
|
| 1040 |
+
# Trivial deps.
|
| 1041 |
+
if p.name == 'cong':
|
| 1042 |
+
a, b, c, d = p_args
|
| 1043 |
+
if {a, b} == {c, d}:
|
| 1044 |
+
continue
|
| 1045 |
+
if p.name == 'para':
|
| 1046 |
+
a, b, c, d = p_args
|
| 1047 |
+
if {a, b} == {c, d}:
|
| 1048 |
+
continue
|
| 1049 |
+
|
| 1050 |
+
if theorem.name in [
|
| 1051 |
+
'cong_cong_eqangle6_ncoll_contri*',
|
| 1052 |
+
'eqratio6_eqangle6_ncoll_simtri*',
|
| 1053 |
+
]:
|
| 1054 |
+
if p.name in ['eqangle', 'eqangle6']: # SAS or RAR
|
| 1055 |
+
b, a, b, c, y, x, y, z = ( # pylint: disable=redeclared-assigned-name,unused-variable
|
| 1056 |
+
p_args
|
| 1057 |
+
)
|
| 1058 |
+
if not nm.same_clock(a.num, b.num, c.num, x.num, y.num, z.num):
|
| 1059 |
+
p_args = b, a, b, c, y, z, y, x
|
| 1060 |
+
|
| 1061 |
+
dep = Dependency(p.name, p_args, rule_name='', level=level)
|
| 1062 |
+
try:
|
| 1063 |
+
dep = dep.why_me_or_cache(g, level)
|
| 1064 |
+
except: # pylint: disable=bare-except
|
| 1065 |
+
fail = True
|
| 1066 |
+
break
|
| 1067 |
+
|
| 1068 |
+
if dep.why is None:
|
| 1069 |
+
fail = True
|
| 1070 |
+
break
|
| 1071 |
+
g.cache_dep(p.name, p_args, dep)
|
| 1072 |
+
deps.why.append(dep)
|
| 1073 |
+
|
| 1074 |
+
if fail:
|
| 1075 |
+
continue
|
| 1076 |
+
|
| 1077 |
+
mp_deps.append((mp, deps))
|
| 1078 |
+
theorem2deps[theorem] = mp_deps
|
| 1079 |
+
|
| 1080 |
+
theorem2deps = list(theorem2deps.items())
|
| 1081 |
+
|
| 1082 |
+
# Step 3: add conclusions to graph.
|
| 1083 |
+
# Note that we do NOT mix step 2 and 3, strictly going for BFS.
|
| 1084 |
+
added = []
|
| 1085 |
+
for theorem, mp_deps in theorem2deps:
|
| 1086 |
+
for mp, deps in mp_deps:
|
| 1087 |
+
if time.time() - t0 > timeout:
|
| 1088 |
+
break
|
| 1089 |
+
name, args = theorem.conclusion_name_args(mp)
|
| 1090 |
+
hash_conclusion = pr.hashed(name, args)
|
| 1091 |
+
if hash_conclusion in g.cache:
|
| 1092 |
+
continue
|
| 1093 |
+
|
| 1094 |
+
add = g.add_piece(name, args, deps=deps)
|
| 1095 |
+
added += add
|
| 1096 |
+
|
| 1097 |
+
branching = len(added)
|
| 1098 |
+
|
| 1099 |
+
# Check if goal is found
|
| 1100 |
+
if controller.goal:
|
| 1101 |
+
args = []
|
| 1102 |
+
|
| 1103 |
+
for a in controller.goal.args:
|
| 1104 |
+
if a in g._name2node:
|
| 1105 |
+
a = g._name2node[a]
|
| 1106 |
+
elif '/' in a:
|
| 1107 |
+
a = create_consts_str(g, a)
|
| 1108 |
+
elif a.isdigit():
|
| 1109 |
+
a = int(a)
|
| 1110 |
+
args.append(a)
|
| 1111 |
+
|
| 1112 |
+
if g.check(controller.goal.name, args):
|
| 1113 |
+
return added, {}, {}, branching
|
| 1114 |
+
|
| 1115 |
+
# Run AR, but do NOT apply to the proof state (yet).
|
| 1116 |
+
for dep in added:
|
| 1117 |
+
g.add_algebra(dep, level)
|
| 1118 |
+
derives, eq4s = g.derive_algebra(level, verbose=False)
|
| 1119 |
+
|
| 1120 |
+
branching += sum([len(x) for x in derives.values()])
|
| 1121 |
+
branching += sum([len(x) for x in eq4s.values()])
|
| 1122 |
+
|
| 1123 |
+
return added, derives, eq4s, branching
|
| 1124 |
+
|
| 1125 |
+
|
| 1126 |
+
def create_consts_str(g: gh.Graph, s: str) -> gm.Angle | gm.Ratio:
|
| 1127 |
+
if 'pi/' in s:
|
| 1128 |
+
n, d = s.split('pi/')
|
| 1129 |
+
n, d = int(n), int(d)
|
| 1130 |
+
p0, _ = g.get_or_create_const_ang(n, d)
|
| 1131 |
+
else:
|
| 1132 |
+
n, d = s.split('/')
|
| 1133 |
+
n, d = int(n), int(d)
|
| 1134 |
+
p0, _ = g.get_or_create_const_rat(n, d)
|
| 1135 |
+
return p0
|
| 1136 |
+
|
| 1137 |
+
|
| 1138 |
+
def do_algebra(
|
| 1139 |
+
g: gh.Graph, added: list[pr.Dependency], verbose: bool = False
|
| 1140 |
+
) -> None:
|
| 1141 |
+
for add in added:
|
| 1142 |
+
g.add_algebra(add, None)
|
| 1143 |
+
derives, eq4s = g.derive_algebra(level=None, verbose=verbose)
|
| 1144 |
+
apply_derivations(g, derives)
|
| 1145 |
+
apply_derivations(g, eq4s)
|
| 1146 |
+
|
| 1147 |
+
|
| 1148 |
+
def apply_derivations(
|
| 1149 |
+
g: gh.Graph, derives: dict[str, list[tuple[gm.Point, ...]]]
|
| 1150 |
+
) -> list[pr.Dependency]:
|
| 1151 |
+
applied = []
|
| 1152 |
+
all_derives = list(derives.items())
|
| 1153 |
+
for name, args in all_derives:
|
| 1154 |
+
for arg in args:
|
| 1155 |
+
applied += g.do_algebra(name, arg)
|
| 1156 |
+
return applied
|
backend/core/ag4masses/alphageometry/dd_test.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Unit tests for dd."""
|
| 17 |
+
import unittest
|
| 18 |
+
|
| 19 |
+
from absl.testing import absltest
|
| 20 |
+
import dd
|
| 21 |
+
import graph as gh
|
| 22 |
+
import problem as pr
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
MAX_LEVEL = 1000
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DDTest(unittest.TestCase):
|
| 29 |
+
|
| 30 |
+
@classmethod
|
| 31 |
+
def setUpClass(cls):
|
| 32 |
+
super().setUpClass()
|
| 33 |
+
cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)
|
| 34 |
+
cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True)
|
| 35 |
+
|
| 36 |
+
def test_imo_2022_p4_should_succeed(self):
|
| 37 |
+
p = pr.Problem.from_txt(
|
| 38 |
+
'a b = segment a b; g1 = on_tline g1 a a b; g2 = on_tline g2 b b a; m ='
|
| 39 |
+
' on_circle m g1 a, on_circle m g2 b; n = on_circle n g1 a, on_circle n'
|
| 40 |
+
' g2 b; c = on_pline c m a b, on_circle c g1 a; d = on_pline d m a b,'
|
| 41 |
+
' on_circle d g2 b; e = on_line e a c, on_line e b d; p = on_line p a'
|
| 42 |
+
' n, on_line p c d; q = on_line q b n, on_line q c d ? cong e p e q'
|
| 43 |
+
)
|
| 44 |
+
g, _ = gh.Graph.build_problem(p, DDTest.defs)
|
| 45 |
+
goal_args = g.names2nodes(p.goal.args)
|
| 46 |
+
|
| 47 |
+
success = False
|
| 48 |
+
for level in range(MAX_LEVEL):
|
| 49 |
+
added, _, _, _ = dd.bfs_one_level(g, DDTest.rules, level, p)
|
| 50 |
+
if g.check(p.goal.name, goal_args):
|
| 51 |
+
success = True
|
| 52 |
+
break
|
| 53 |
+
if not added: # saturated
|
| 54 |
+
break
|
| 55 |
+
|
| 56 |
+
self.assertTrue(success)
|
| 57 |
+
|
| 58 |
+
def test_incenter_excenter_should_fail(self):
|
| 59 |
+
p = pr.Problem.from_txt(
|
| 60 |
+
'a b c = triangle a b c; d = incenter d a b c; e = excenter e a b c ?'
|
| 61 |
+
' perp d c c e'
|
| 62 |
+
)
|
| 63 |
+
g, _ = gh.Graph.build_problem(p, DDTest.defs)
|
| 64 |
+
goal_args = g.names2nodes(p.goal.args)
|
| 65 |
+
|
| 66 |
+
success = False
|
| 67 |
+
for level in range(MAX_LEVEL):
|
| 68 |
+
added, _, _, _ = dd.bfs_one_level(g, DDTest.rules, level, p)
|
| 69 |
+
if g.check(p.goal.name, goal_args):
|
| 70 |
+
success = True
|
| 71 |
+
break
|
| 72 |
+
if not added: # saturated
|
| 73 |
+
break
|
| 74 |
+
|
| 75 |
+
self.assertFalse(success)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
if __name__ == '__main__':
|
| 79 |
+
absltest.main()
|
backend/core/ag4masses/alphageometry/ddar.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Implements the combination DD+AR."""
|
| 17 |
+
import time
|
| 18 |
+
|
| 19 |
+
from absl import logging
|
| 20 |
+
import dd
|
| 21 |
+
import graph as gh
|
| 22 |
+
import problem as pr
|
| 23 |
+
from problem import Dependency # pylint: disable=g-importing-member
|
| 24 |
+
import trace_back
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def saturate_or_goal(
|
| 28 |
+
g: gh.Graph,
|
| 29 |
+
theorems: list[pr.Theorem],
|
| 30 |
+
level_times: list[float],
|
| 31 |
+
p: pr.Problem,
|
| 32 |
+
max_level: int = 100,
|
| 33 |
+
timeout: int = 600,
|
| 34 |
+
) -> tuple[
|
| 35 |
+
list[dict[str, list[tuple[gh.Point, ...]]]],
|
| 36 |
+
list[dict[str, list[tuple[gh.Point, ...]]]],
|
| 37 |
+
list[int],
|
| 38 |
+
list[pr.Dependency],
|
| 39 |
+
]:
|
| 40 |
+
"""Run DD until saturation or goal found."""
|
| 41 |
+
derives = []
|
| 42 |
+
eq4s = []
|
| 43 |
+
branching = []
|
| 44 |
+
all_added = []
|
| 45 |
+
|
| 46 |
+
while len(level_times) < max_level:
|
| 47 |
+
level = len(level_times) + 1
|
| 48 |
+
|
| 49 |
+
t = time.time()
|
| 50 |
+
added, derv, eq4, n_branching = dd.bfs_one_level(
|
| 51 |
+
g, theorems, level, p, verbose=False, nm_check=True, timeout=timeout
|
| 52 |
+
)
|
| 53 |
+
all_added += added
|
| 54 |
+
branching.append(n_branching)
|
| 55 |
+
|
| 56 |
+
derives.append(derv)
|
| 57 |
+
eq4s.append(eq4)
|
| 58 |
+
level_time = time.time() - t
|
| 59 |
+
|
| 60 |
+
logging.info(f'Depth {level}/{max_level} time = {level_time}') # pylint: disable=logging-fstring-interpolation
|
| 61 |
+
level_times.append(level_time)
|
| 62 |
+
|
| 63 |
+
if p.goal is not None:
|
| 64 |
+
goal_args = list(map(lambda x: g.get(x, lambda: int(x)), p.goal.args))
|
| 65 |
+
if g.check(p.goal.name, goal_args): # found goal
|
| 66 |
+
break
|
| 67 |
+
|
| 68 |
+
if not added: # saturated
|
| 69 |
+
break
|
| 70 |
+
|
| 71 |
+
if level_time > timeout:
|
| 72 |
+
break
|
| 73 |
+
|
| 74 |
+
return derives, eq4s, branching, all_added
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def solve(
|
| 78 |
+
g: gh.Graph,
|
| 79 |
+
theorems: list[pr.Problem],
|
| 80 |
+
controller: pr.Problem,
|
| 81 |
+
max_level: int = 1000,
|
| 82 |
+
timeout: int = 600,
|
| 83 |
+
) -> tuple[gh.Graph, list[float], str, list[int], list[pr.Dependency]]:
|
| 84 |
+
"""Alternate between DD and AR until goal is found."""
|
| 85 |
+
status = 'saturated'
|
| 86 |
+
level_times = []
|
| 87 |
+
|
| 88 |
+
dervs, eq4 = g.derive_algebra(level=0, verbose=False)
|
| 89 |
+
derives = [dervs]
|
| 90 |
+
eq4s = [eq4]
|
| 91 |
+
branches = []
|
| 92 |
+
all_added = []
|
| 93 |
+
|
| 94 |
+
while len(level_times) < max_level:
|
| 95 |
+
dervs, eq4, next_branches, added = saturate_or_goal(
|
| 96 |
+
g, theorems, level_times, controller, max_level, timeout=timeout
|
| 97 |
+
)
|
| 98 |
+
all_added += added
|
| 99 |
+
|
| 100 |
+
derives += dervs
|
| 101 |
+
eq4s += eq4
|
| 102 |
+
branches += next_branches
|
| 103 |
+
|
| 104 |
+
# Now, it is either goal or saturated
|
| 105 |
+
if controller.goal is not None:
|
| 106 |
+
goal_args = g.names2points(controller.goal.args)
|
| 107 |
+
if g.check(controller.goal.name, goal_args): # found goal
|
| 108 |
+
status = 'solved'
|
| 109 |
+
break
|
| 110 |
+
|
| 111 |
+
if not derives: # officially saturated.
|
| 112 |
+
logging.info("derives empty, breaking")
|
| 113 |
+
break
|
| 114 |
+
|
| 115 |
+
# Now we resort to algebra derivations.
|
| 116 |
+
added = []
|
| 117 |
+
while derives and not added:
|
| 118 |
+
added += dd.apply_derivations(g, derives.pop(0))
|
| 119 |
+
|
| 120 |
+
if added:
|
| 121 |
+
continue
|
| 122 |
+
|
| 123 |
+
# Final help from AR.
|
| 124 |
+
while eq4s and not added:
|
| 125 |
+
added += dd.apply_derivations(g, eq4s.pop(0))
|
| 126 |
+
|
| 127 |
+
all_added += added
|
| 128 |
+
|
| 129 |
+
if not added: # Nothing left. saturated.
|
| 130 |
+
logging.info("Nothing added, breaking")
|
| 131 |
+
break
|
| 132 |
+
|
| 133 |
+
return g, level_times, status, branches, all_added
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def get_proof_steps(
|
| 137 |
+
g: gh.Graph, goal: pr.Clause, merge_trivials: bool = False
|
| 138 |
+
) -> tuple[
|
| 139 |
+
list[pr.Dependency],
|
| 140 |
+
list[pr.Dependency],
|
| 141 |
+
list[tuple[list[pr.Dependency], list[pr.Dependency]]],
|
| 142 |
+
dict[tuple[str, ...], int],
|
| 143 |
+
]:
|
| 144 |
+
"""Extract proof steps from the built DAG."""
|
| 145 |
+
goal_args = g.names2nodes(goal.args)
|
| 146 |
+
query = Dependency(goal.name, goal_args, None, None)
|
| 147 |
+
|
| 148 |
+
setup, aux, log, setup_points = trace_back.get_logs(
|
| 149 |
+
query, g, merge_trivials=merge_trivials
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
refs = {}
|
| 153 |
+
setup = trace_back.point_log(setup, refs, set())
|
| 154 |
+
aux = trace_back.point_log(aux, refs, setup_points)
|
| 155 |
+
|
| 156 |
+
setup = [(prems, [tuple(p)]) for p, prems in setup]
|
| 157 |
+
aux = [(prems, [tuple(p)]) for p, prems in aux]
|
| 158 |
+
|
| 159 |
+
return setup, aux, log, refs
|
backend/core/ag4masses/alphageometry/ddar_test.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Unit tests for ddar.py."""
|
| 17 |
+
import unittest
|
| 18 |
+
|
| 19 |
+
from absl.testing import absltest
|
| 20 |
+
import ddar
|
| 21 |
+
import graph as gh
|
| 22 |
+
import problem as pr
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class DDARTest(unittest.TestCase):
|
| 26 |
+
|
| 27 |
+
@classmethod
|
| 28 |
+
def setUpClass(cls):
|
| 29 |
+
super().setUpClass()
|
| 30 |
+
cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)
|
| 31 |
+
cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True)
|
| 32 |
+
|
| 33 |
+
def test_orthocenter_should_fail(self):
|
| 34 |
+
txt = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c' # pylint: disable=line-too-long
|
| 35 |
+
p = pr.Problem.from_txt(txt)
|
| 36 |
+
g, _ = gh.Graph.build_problem(p, DDARTest.defs)
|
| 37 |
+
|
| 38 |
+
ddar.solve(g, DDARTest.rules, p, max_level=1000)
|
| 39 |
+
goal_args = g.names2nodes(p.goal.args)
|
| 40 |
+
self.assertFalse(g.check(p.goal.name, goal_args))
|
| 41 |
+
|
| 42 |
+
def test_orthocenter_aux_should_succeed(self):
|
| 43 |
+
txt = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c' # pylint: disable=line-too-long
|
| 44 |
+
p = pr.Problem.from_txt(txt)
|
| 45 |
+
g, _ = gh.Graph.build_problem(p, DDARTest.defs)
|
| 46 |
+
|
| 47 |
+
ddar.solve(g, DDARTest.rules, p, max_level=1000)
|
| 48 |
+
goal_args = g.names2nodes(p.goal.args)
|
| 49 |
+
self.assertTrue(g.check(p.goal.name, goal_args))
|
| 50 |
+
|
| 51 |
+
def test_incenter_excenter_should_succeed(self):
|
| 52 |
+
# Note that this same problem should fail in dd_test.py
|
| 53 |
+
p = pr.Problem.from_txt(
|
| 54 |
+
'a b c = triangle a b c; d = incenter d a b c; e = excenter e a b c ?'
|
| 55 |
+
' perp d c c e'
|
| 56 |
+
) # pylint: disable=line-too-long
|
| 57 |
+
g, _ = gh.Graph.build_problem(p, DDARTest.defs)
|
| 58 |
+
|
| 59 |
+
ddar.solve(g, DDARTest.rules, p, max_level=1000)
|
| 60 |
+
goal_args = g.names2nodes(p.goal.args)
|
| 61 |
+
self.assertTrue(g.check(p.goal.name, goal_args))
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if __name__ == '__main__':
|
| 65 |
+
absltest.main()
|
backend/core/ag4masses/alphageometry/decoder_stack.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""The decoder stack in inference mode."""
|
| 17 |
+
|
| 18 |
+
from typing import Any, Tuple
|
| 19 |
+
|
| 20 |
+
import gin
|
| 21 |
+
from transformer import decoder_stack
|
| 22 |
+
import transformer_layer as tl
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
struct = decoder_stack.struct
|
| 26 |
+
nn_components = decoder_stack.nn_components
|
| 27 |
+
position = decoder_stack.position
|
| 28 |
+
jnp = decoder_stack.jnp
|
| 29 |
+
attention = decoder_stack.attention
|
| 30 |
+
|
| 31 |
+
DStackWindowState = decoder_stack.DStackWindowState
|
| 32 |
+
|
| 33 |
+
Array = Any
|
| 34 |
+
|
| 35 |
+
TransformerTaskConfig = decoder_stack.TransformerTaskConfig
|
| 36 |
+
|
| 37 |
+
DStackDecoderState = Tuple[tl.DecoderState, ...]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@gin.configurable
|
| 41 |
+
class DecoderStackGenerate(decoder_stack.DecoderStack):
|
| 42 |
+
"""Stack of transformer decoder layers."""
|
| 43 |
+
|
| 44 |
+
layer_factory = tl.TransformerLayerGenerate
|
| 45 |
+
|
| 46 |
+
def init_decoder_state_vanilla(
|
| 47 |
+
self, sequence_length: int, start_of_sequence: Array
|
| 48 |
+
) -> DStackDecoderState:
|
| 49 |
+
"""Return initial state for autoregressive generation."""
|
| 50 |
+
return tuple(
|
| 51 |
+
[
|
| 52 |
+
layer.init_decoder_state_vanilla(sequence_length, start_of_sequence)
|
| 53 |
+
for layer in self.transformer_layers
|
| 54 |
+
]
|
| 55 |
+
)
|
backend/core/ag4masses/alphageometry/defs.txt
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
angle_bisector x a b c
|
| 2 |
+
x : a b c x
|
| 3 |
+
a b c = ncoll a b c
|
| 4 |
+
x : eqangle b a b x b x b c
|
| 5 |
+
bisect a b c
|
| 6 |
+
|
| 7 |
+
angle_mirror x a b c
|
| 8 |
+
x : a b c x
|
| 9 |
+
a b c = ncoll a b c
|
| 10 |
+
x : eqangle b a b c b c b x
|
| 11 |
+
amirror a b c
|
| 12 |
+
|
| 13 |
+
circle x a b c
|
| 14 |
+
x : a b c
|
| 15 |
+
a b c = ncoll a b c
|
| 16 |
+
x : cong x a x b, cong x b x c
|
| 17 |
+
bline a b, bline a c
|
| 18 |
+
|
| 19 |
+
circumcenter x a b c
|
| 20 |
+
x : a b c
|
| 21 |
+
a b c = ncoll a b c
|
| 22 |
+
x : cong x a x b, cong x b x c
|
| 23 |
+
bline a b, bline a c
|
| 24 |
+
|
| 25 |
+
eq_quadrangle a b c d
|
| 26 |
+
d : a b c d
|
| 27 |
+
=
|
| 28 |
+
a : ; b : ; c : ; d : cong d a b c
|
| 29 |
+
eq_quadrangle
|
| 30 |
+
|
| 31 |
+
eq_trapezoid a b c d
|
| 32 |
+
d : a b c
|
| 33 |
+
=
|
| 34 |
+
a : ; b : ; c : ; d : para d c a b, cong d a b c
|
| 35 |
+
eq_trapezoid
|
| 36 |
+
|
| 37 |
+
eq_triangle x b c
|
| 38 |
+
x : b c
|
| 39 |
+
b c = diff b c
|
| 40 |
+
x : cong x b b c, cong b c c x; eqangle b x b c c b c x, eqangle x c x b b x b c
|
| 41 |
+
circle b b c, circle c b c
|
| 42 |
+
|
| 43 |
+
eqangle2 x a b c
|
| 44 |
+
x : a b c x
|
| 45 |
+
a b c = ncoll a b c
|
| 46 |
+
x : eqangle a b a x c x c b
|
| 47 |
+
eqangle2 a b c
|
| 48 |
+
|
| 49 |
+
eqdia_quadrangle a b c d
|
| 50 |
+
d : a b c d
|
| 51 |
+
=
|
| 52 |
+
a : ; b : ; c : ; d : cong d b a c
|
| 53 |
+
eqdia_quadrangle
|
| 54 |
+
|
| 55 |
+
eqdistance x a b c
|
| 56 |
+
x : a b c x
|
| 57 |
+
a b c = diff b c
|
| 58 |
+
x : cong x a b c
|
| 59 |
+
circle a b c
|
| 60 |
+
|
| 61 |
+
foot x a b c
|
| 62 |
+
x : a b c
|
| 63 |
+
a b c = ncoll a b c
|
| 64 |
+
x : perp x a b c, coll x b c
|
| 65 |
+
tline a b c, line b c
|
| 66 |
+
|
| 67 |
+
free a
|
| 68 |
+
a : a
|
| 69 |
+
=
|
| 70 |
+
a :
|
| 71 |
+
free
|
| 72 |
+
|
| 73 |
+
incenter x a b c
|
| 74 |
+
x : a b c
|
| 75 |
+
a b c = ncoll a b c
|
| 76 |
+
x : eqangle a b a x a x a c, eqangle c a c x c x c b; eqangle b c b x b x b a
|
| 77 |
+
bisect a b c, bisect b c a
|
| 78 |
+
|
| 79 |
+
incenter2 x y z i a b c
|
| 80 |
+
i : a b c, x : i b c, y : i c a, z : i a b
|
| 81 |
+
a b c = ncoll a b c
|
| 82 |
+
i : eqangle a b a i a i a c, eqangle c a c i c i c b; eqangle b c b i b i b a; x : coll x b c, perp i x b c; y : coll y c a, perp i y c a; z : coll z a b, perp i z a b; cong i x i y, cong i y i z
|
| 83 |
+
incenter2 a b c
|
| 84 |
+
|
| 85 |
+
excenter x a b c
|
| 86 |
+
x : a b c
|
| 87 |
+
a b c = ncoll a b c
|
| 88 |
+
x : eqangle a b a x a x a c, eqangle c a c x c x c b; eqangle b c b x b x b a
|
| 89 |
+
bisect b a c, exbisect b c a
|
| 90 |
+
|
| 91 |
+
excenter2 x y z i a b c
|
| 92 |
+
i : a b c, x : i b c, y : i c a, z : i a b
|
| 93 |
+
a b c = ncoll a b c
|
| 94 |
+
i : eqangle a b a i a i a c, eqangle c a c i c i c b; eqangle b c b i b i b a; x : coll x b c, perp i x b c; y : coll y c a, perp i y c a; z : coll z a b, perp i z a b; cong i x i y, cong i y i z
|
| 95 |
+
excenter2 a b c
|
| 96 |
+
|
| 97 |
+
centroid x y z i a b c
|
| 98 |
+
x : b c, y : c a, z : a b, i : a x b y
|
| 99 |
+
a b c = ncoll a b c
|
| 100 |
+
x : coll x b c, cong x b x c; y : coll y c a, cong y c y a; z : coll z a b, cong z a z b; i : coll a x i, coll b y i; coll c z i
|
| 101 |
+
centroid a b c
|
| 102 |
+
|
| 103 |
+
ninepoints x y z i a b c
|
| 104 |
+
x : b c, y : c a, z : a b, i : x y z
|
| 105 |
+
a b c = ncoll a b c
|
| 106 |
+
x : coll x b c, cong x b x c; y : coll y c a, cong y c y a; z : coll z a b, cong z a z b; i : cong i x i y, cong i y i z
|
| 107 |
+
ninepoints a b c
|
| 108 |
+
|
| 109 |
+
intersection_cc x o w a
|
| 110 |
+
x : o w a
|
| 111 |
+
o w a = ncoll o w a
|
| 112 |
+
x : cong o a o x, cong w a w x
|
| 113 |
+
circle o o a, circle w w a
|
| 114 |
+
|
| 115 |
+
intersection_lc x a o b
|
| 116 |
+
x : a o b
|
| 117 |
+
a o b = diff a b, diff o b, nperp b o b a
|
| 118 |
+
x : coll x a b, cong o b o x
|
| 119 |
+
line b a, circle o o b
|
| 120 |
+
|
| 121 |
+
intersection_ll x a b c d
|
| 122 |
+
x : a b c d
|
| 123 |
+
a b c d = npara a b c d, ncoll a b c d
|
| 124 |
+
x : coll x a b, coll x c d
|
| 125 |
+
line a b, line c d
|
| 126 |
+
|
| 127 |
+
intersection_lp x a b c m n
|
| 128 |
+
x : a b c m n
|
| 129 |
+
a b c m n = npara m n a b, ncoll a b c, ncoll c m n
|
| 130 |
+
x : coll x a b, para c x m n
|
| 131 |
+
line a b, pline c m n
|
| 132 |
+
|
| 133 |
+
intersection_lt x a b c d e
|
| 134 |
+
x : a b c d e
|
| 135 |
+
a b c d e = ncoll a b c, nperp a b d e
|
| 136 |
+
x : coll x a b, perp x c d e
|
| 137 |
+
line a b, tline c d e
|
| 138 |
+
|
| 139 |
+
intersection_pp x a b c d e f
|
| 140 |
+
x : a b c d e f
|
| 141 |
+
a b c d e f = diff a d, npara b c e f
|
| 142 |
+
x : para x a b c, para x d e f
|
| 143 |
+
pline a b c, pline d e f
|
| 144 |
+
|
| 145 |
+
intersection_tt x a b c d e f
|
| 146 |
+
x : a b c d e f
|
| 147 |
+
a b c d e f = diff a d, npara b c e f
|
| 148 |
+
x : perp x a b c, perp x d e f
|
| 149 |
+
tline a b c, tline d e f
|
| 150 |
+
|
| 151 |
+
iso_triangle a b c
|
| 152 |
+
c : a b c
|
| 153 |
+
=
|
| 154 |
+
a : ; b : ; c : eqangle b a b c c b c a, cong a b a c
|
| 155 |
+
isos
|
| 156 |
+
|
| 157 |
+
lc_tangent x a o
|
| 158 |
+
x : x a o
|
| 159 |
+
a o = diff a o
|
| 160 |
+
x : perp a x a o
|
| 161 |
+
tline a a o
|
| 162 |
+
|
| 163 |
+
midpoint x a b
|
| 164 |
+
x : a b
|
| 165 |
+
a b = diff a b
|
| 166 |
+
x : coll x a b, cong x a x b
|
| 167 |
+
midp a b
|
| 168 |
+
|
| 169 |
+
mirror x a b
|
| 170 |
+
x : a b
|
| 171 |
+
a b = diff a b
|
| 172 |
+
x : coll x a b, cong b a b x
|
| 173 |
+
pmirror a b
|
| 174 |
+
|
| 175 |
+
nsquare x a b
|
| 176 |
+
x : a b
|
| 177 |
+
a b = diff a b
|
| 178 |
+
x : cong x a a b, perp x a a b
|
| 179 |
+
rotaten90 a b
|
| 180 |
+
|
| 181 |
+
on_aline x a b c d e
|
| 182 |
+
x : x a b c d e
|
| 183 |
+
a b c d e = ncoll c d e
|
| 184 |
+
x : eqangle a x a b d c d e
|
| 185 |
+
aline e d c b a
|
| 186 |
+
|
| 187 |
+
on_aline2 x a b c d e
|
| 188 |
+
x : x a b c d e
|
| 189 |
+
a b c d e = ncoll c d e
|
| 190 |
+
x : eqangle x a x b d c d e
|
| 191 |
+
aline2 e d c b a
|
| 192 |
+
|
| 193 |
+
on_bline x a b
|
| 194 |
+
x : x a b
|
| 195 |
+
a b = diff a b
|
| 196 |
+
x : cong x a x b, eqangle a x a b b a b x
|
| 197 |
+
bline a b
|
| 198 |
+
|
| 199 |
+
on_circle x o a
|
| 200 |
+
x : x o a
|
| 201 |
+
o a = diff o a
|
| 202 |
+
x : cong o x o a
|
| 203 |
+
circle o o a
|
| 204 |
+
|
| 205 |
+
on_line x a b
|
| 206 |
+
x : x a b
|
| 207 |
+
a b = diff a b
|
| 208 |
+
x : coll x a b
|
| 209 |
+
line a b
|
| 210 |
+
|
| 211 |
+
on_pline x a b c
|
| 212 |
+
x : x a b c
|
| 213 |
+
a b c = diff b c, ncoll a b c
|
| 214 |
+
x : para x a b c
|
| 215 |
+
pline a b c
|
| 216 |
+
|
| 217 |
+
on_tline x a b c
|
| 218 |
+
x : x a b c
|
| 219 |
+
a b c = diff b c
|
| 220 |
+
x : perp x a b c
|
| 221 |
+
tline a b c
|
| 222 |
+
|
| 223 |
+
orthocenter x a b c
|
| 224 |
+
x : a b c
|
| 225 |
+
a b c = ncoll a b c
|
| 226 |
+
x : perp x a b c, perp x b c a; perp x c a b
|
| 227 |
+
tline a b c, tline b c a
|
| 228 |
+
|
| 229 |
+
parallelogram a b c x
|
| 230 |
+
x : a b c
|
| 231 |
+
a b c = ncoll a b c
|
| 232 |
+
x : para a b c x, para a x b c; cong a b c x, cong a x b c
|
| 233 |
+
pline a b c, pline c a b
|
| 234 |
+
|
| 235 |
+
pentagon a b c d e
|
| 236 |
+
|
| 237 |
+
=
|
| 238 |
+
a : ; b : ; c : ; d : ; e :
|
| 239 |
+
pentagon
|
| 240 |
+
|
| 241 |
+
psquare x a b
|
| 242 |
+
x : a b
|
| 243 |
+
a b = diff a b
|
| 244 |
+
x : cong x a a b, perp x a a b
|
| 245 |
+
rotatep90 a b
|
| 246 |
+
|
| 247 |
+
quadrangle a b c d
|
| 248 |
+
|
| 249 |
+
=
|
| 250 |
+
a : ; b : ; c : ; d :
|
| 251 |
+
quadrangle
|
| 252 |
+
|
| 253 |
+
r_trapezoid a b c d
|
| 254 |
+
d : a b c
|
| 255 |
+
=
|
| 256 |
+
a : ; b : ; c : ; d : para a b c d, perp a b a d
|
| 257 |
+
r_trapezoid
|
| 258 |
+
|
| 259 |
+
r_triangle a b c
|
| 260 |
+
c : a b c
|
| 261 |
+
=
|
| 262 |
+
a : ; b : ; c : perp a b a c
|
| 263 |
+
r_triangle
|
| 264 |
+
|
| 265 |
+
rectangle a b c d
|
| 266 |
+
c : a b c , d : a b c
|
| 267 |
+
=
|
| 268 |
+
a : ; b : ; c : perp a b b c ; d : para a b c d, para a d b c; perp a b a d, cong a b c d, cong a d b c, cong a c b d
|
| 269 |
+
rectangle
|
| 270 |
+
|
| 271 |
+
reflect x a b c
|
| 272 |
+
x : a b c
|
| 273 |
+
a b c = diff b c, ncoll a b c
|
| 274 |
+
x : cong b a b x, cong c a c x; perp b c a x
|
| 275 |
+
reflect a b c
|
| 276 |
+
|
| 277 |
+
risos a b c
|
| 278 |
+
c : a b
|
| 279 |
+
=
|
| 280 |
+
a : ; b : ; c : perp a b a c, cong a b a c; eqangle b a b c c b c a
|
| 281 |
+
risos
|
| 282 |
+
|
| 283 |
+
s_angle a b x y
|
| 284 |
+
x : a b x
|
| 285 |
+
a b = diff a b
|
| 286 |
+
x : s_angle a b x y
|
| 287 |
+
s_angle a b y
|
| 288 |
+
|
| 289 |
+
segment a b
|
| 290 |
+
|
| 291 |
+
=
|
| 292 |
+
a : ; b :
|
| 293 |
+
segment
|
| 294 |
+
|
| 295 |
+
shift x b c d
|
| 296 |
+
x : b c d
|
| 297 |
+
b c d = diff d b
|
| 298 |
+
x : cong x b c d, cong x c b d
|
| 299 |
+
shift d c b
|
| 300 |
+
|
| 301 |
+
square a b x y
|
| 302 |
+
x : a b, y : a b x
|
| 303 |
+
a b = diff a b
|
| 304 |
+
x : perp a b b x, cong a b b x; y : para a b x y, para a y b x; perp a y y x, cong b x x y, cong x y y a, perp a x b y, cong a x b y
|
| 305 |
+
square a b
|
| 306 |
+
|
| 307 |
+
isquare a b c d
|
| 308 |
+
c : a b , d : a b c
|
| 309 |
+
=
|
| 310 |
+
a : ; b : ; c : perp a b b c, cong a b b c; d : para a b c d, para a d b c; perp a d d c, cong b c c d, cong c d d a, perp a c b d, cong a c b d
|
| 311 |
+
isquare
|
| 312 |
+
|
| 313 |
+
trapezoid a b c d
|
| 314 |
+
d : a b c d
|
| 315 |
+
=
|
| 316 |
+
a : ; b : ; c : ; d : para a b c d
|
| 317 |
+
trapezoid
|
| 318 |
+
|
| 319 |
+
triangle a b c
|
| 320 |
+
|
| 321 |
+
=
|
| 322 |
+
a : ; b : ; c :
|
| 323 |
+
triangle
|
| 324 |
+
|
| 325 |
+
triangle12 a b c
|
| 326 |
+
c : a b c
|
| 327 |
+
=
|
| 328 |
+
a : ; b : ; c : rconst a b a c 1 2
|
| 329 |
+
triangle12
|
| 330 |
+
|
| 331 |
+
2l1c x y z i a b c o
|
| 332 |
+
x : a b c o y z i, y : a b c o x z i, z : a b c o x y i, i : a b c o x y z
|
| 333 |
+
a b c o = cong o a o b, ncoll a b c
|
| 334 |
+
x y z i : coll x a c, coll y b c, cong o a o z, coll i o z, cong i x i y, cong i y i z, perp i x a c, perp i y b c
|
| 335 |
+
2l1c a b c o
|
| 336 |
+
|
| 337 |
+
e5128 x y a b c d
|
| 338 |
+
x : a b c d y, y : a b c d x
|
| 339 |
+
a b c d = cong c b c d, perp b c b a
|
| 340 |
+
x y : cong c b c x, coll y a b, coll x y d, eqangle a b a d x a x y
|
| 341 |
+
e5128 a b c d
|
| 342 |
+
|
| 343 |
+
3peq x y z a b c
|
| 344 |
+
z : b c z , x : a b c z y, y : a b c z x
|
| 345 |
+
a b c = ncoll a b c
|
| 346 |
+
z : coll z b c ; x y : coll x a b, coll y a c, coll x y z, cong z x z y
|
| 347 |
+
3peq a b c
|
| 348 |
+
|
| 349 |
+
trisect x y a b c
|
| 350 |
+
x : a b c y, y : a b c x
|
| 351 |
+
a b c = ncoll a b c
|
| 352 |
+
x y : coll x a c, coll y a c, eqangle b a b x b x b y, eqangle b x b y b y b c
|
| 353 |
+
trisect a b c
|
| 354 |
+
|
| 355 |
+
trisegment x y a b
|
| 356 |
+
x : a b y, y : a b x
|
| 357 |
+
a b = diff a b
|
| 358 |
+
x y : coll x a b, coll y a b, cong x a x y, cong y x y b
|
| 359 |
+
trisegment a b
|
| 360 |
+
|
| 361 |
+
on_dia x a b
|
| 362 |
+
x : x a b
|
| 363 |
+
a b = diff a b
|
| 364 |
+
x : perp x a x b
|
| 365 |
+
dia a b
|
| 366 |
+
|
| 367 |
+
ieq_triangle a b c
|
| 368 |
+
c : a b
|
| 369 |
+
=
|
| 370 |
+
a : ; b : ; c : cong a b b c, cong b c c a; eqangle a b a c c a c b, eqangle c a c b b c b a
|
| 371 |
+
ieq_triangle
|
| 372 |
+
|
| 373 |
+
on_opline x a b
|
| 374 |
+
x : x a b
|
| 375 |
+
a b = diff a b
|
| 376 |
+
x : coll x a b
|
| 377 |
+
on_opline a b
|
| 378 |
+
|
| 379 |
+
cc_tangent0 x y o a w b
|
| 380 |
+
x : o a w b y, y : o a w b x
|
| 381 |
+
o a w b = diff o a, diff w b, diff o w
|
| 382 |
+
x y : cong o x o a, cong w y w b, perp x o x y, perp y w y x
|
| 383 |
+
cc_tangent0 o a w b
|
| 384 |
+
|
| 385 |
+
cc_tangent x y z i o a w b
|
| 386 |
+
x : o a w b y, y : o a w b x, z : o a w b i, i : o a w b z
|
| 387 |
+
o a w b = diff o a, diff w b, diff o w
|
| 388 |
+
x y : cong o x o a, cong w y w b, perp x o x y, perp y w y x; z i : cong o z o a, cong w i w b, perp z o z i, perp i w i z
|
| 389 |
+
cc_tangent o a w b
|
| 390 |
+
|
| 391 |
+
eqangle3 x a b d e f
|
| 392 |
+
x : x a b d e f
|
| 393 |
+
a b d e f = ncoll d e f, diff a b, diff d e, diff e f
|
| 394 |
+
x : eqangle x a x b d e d f
|
| 395 |
+
eqangle3 a b d e f
|
| 396 |
+
|
| 397 |
+
tangent x y a o b
|
| 398 |
+
x y : o a b
|
| 399 |
+
a o b = diff o a, diff o b, diff a b
|
| 400 |
+
x : cong o x o b, perp a x o x; y : cong o y o b, perp a y o y
|
| 401 |
+
tangent a o b
|
| 402 |
+
|
| 403 |
+
on_circum x a b c
|
| 404 |
+
x : a b c
|
| 405 |
+
a b c = ncoll a b c
|
| 406 |
+
x : cyclic a b c x
|
| 407 |
+
cyclic a b c
|
backend/core/ag4masses/alphageometry/download.sh
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
gdown --folder https://bit.ly/alphageometry
|
| 17 |
+
export DATA=ag_ckpt_vocab
|
backend/core/ag4masses/alphageometry/examples.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
orthocenter
|
| 2 |
+
a b c = triangle; h = on_tline b a c, on_tline c a b ? perp a h b c
|
| 3 |
+
orthocenter_aux
|
| 4 |
+
a b c = triangle; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c
|
| 5 |
+
incenter_excenter
|
| 6 |
+
a b c = triangle a b c; d1 d2 d3 d = incenter2 a b c; e1 e2 e3 e = excenter2 a b c ? perp d c c e
|
| 7 |
+
euler
|
| 8 |
+
a b c = triangle a b c; h = orthocenter a b c; h1 = foot a b c; h2 = foot b c a; h3 = foot c a b; g1 g2 g3 g = centroid g1 g2 g3 g a b c; o = circle a b c ? coll h g o
|
backend/core/ag4masses/alphageometry/fig1.svg
ADDED
|
|
backend/core/ag4masses/alphageometry/geometry.py
ADDED
|
@@ -0,0 +1,578 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Implements geometric objects used in the graph representation."""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
from collections import defaultdict # pylint: disable=g-importing-member
|
| 19 |
+
from typing import Any, Type
|
| 20 |
+
|
| 21 |
+
# pylint: disable=protected-access
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Node:
|
| 25 |
+
r"""Node in the proof state graph.
|
| 26 |
+
|
| 27 |
+
Can be Point, Line, Circle, etc.
|
| 28 |
+
|
| 29 |
+
Each node maintains a merge history to
|
| 30 |
+
other nodes if they are (found out to be) equivalent
|
| 31 |
+
|
| 32 |
+
a -> b -
|
| 33 |
+
\
|
| 34 |
+
c -> d -> e -> f -> g
|
| 35 |
+
|
| 36 |
+
d.merged_to = e
|
| 37 |
+
d.rep = g
|
| 38 |
+
d.merged_from = {a, b, c, d}
|
| 39 |
+
d.equivs = {a, b, c, d, e, f, g}
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(self, name: str = '', graph: Any = None):
|
| 43 |
+
self.name = name or str(self)
|
| 44 |
+
self.graph = graph
|
| 45 |
+
|
| 46 |
+
self.edge_graph = {}
|
| 47 |
+
# Edge graph: what other nodes is connected to this node.
|
| 48 |
+
# edge graph = {
|
| 49 |
+
# other1: {self1: deps, self2: deps},
|
| 50 |
+
# other2: {self2: deps, self3: deps}
|
| 51 |
+
# }
|
| 52 |
+
|
| 53 |
+
self.merge_graph = {}
|
| 54 |
+
# Merge graph: history of merges with other nodes.
|
| 55 |
+
# merge_graph = {self1: {self2: deps1, self3: deps2}}
|
| 56 |
+
|
| 57 |
+
self.rep_by = None # represented by.
|
| 58 |
+
self.members = {self}
|
| 59 |
+
|
| 60 |
+
self._val = None
|
| 61 |
+
self._obj = None
|
| 62 |
+
|
| 63 |
+
self.deps = []
|
| 64 |
+
|
| 65 |
+
# numerical representation.
|
| 66 |
+
self.num = None
|
| 67 |
+
self.change = set() # what other nodes' num rely on this node?
|
| 68 |
+
|
| 69 |
+
def set_rep(self, node: Node) -> None:
|
| 70 |
+
if node == self:
|
| 71 |
+
return
|
| 72 |
+
self.rep_by = node
|
| 73 |
+
node.merge_edge_graph(self.edge_graph)
|
| 74 |
+
node.members.update(self.members)
|
| 75 |
+
|
| 76 |
+
def rep(self) -> Node:
|
| 77 |
+
x = self
|
| 78 |
+
while x.rep_by:
|
| 79 |
+
x = x.rep_by
|
| 80 |
+
return x
|
| 81 |
+
|
| 82 |
+
def why_rep(self) -> list[Any]:
|
| 83 |
+
return self.why_equal([self.rep()], None)
|
| 84 |
+
|
| 85 |
+
def rep_and_why(self) -> tuple[Node, list[Any]]:
|
| 86 |
+
rep = self.rep()
|
| 87 |
+
return rep, self.why_equal([rep], None)
|
| 88 |
+
|
| 89 |
+
def neighbors(
|
| 90 |
+
self, oftype: Type[Node], return_set: bool = False, do_rep: bool = True
|
| 91 |
+
) -> list[Node]:
|
| 92 |
+
"""Neighbors of this node in the proof state graph."""
|
| 93 |
+
if do_rep:
|
| 94 |
+
rep = self.rep()
|
| 95 |
+
else:
|
| 96 |
+
rep = self
|
| 97 |
+
result = set()
|
| 98 |
+
|
| 99 |
+
for n in rep.edge_graph:
|
| 100 |
+
if oftype is None or oftype and isinstance(n, oftype):
|
| 101 |
+
if do_rep:
|
| 102 |
+
result.add(n.rep())
|
| 103 |
+
else:
|
| 104 |
+
result.add(n)
|
| 105 |
+
|
| 106 |
+
if return_set:
|
| 107 |
+
return result
|
| 108 |
+
return list(result)
|
| 109 |
+
|
| 110 |
+
def merge_edge_graph(
|
| 111 |
+
self, new_edge_graph: dict[Node, dict[Node, list[Node]]]
|
| 112 |
+
) -> None:
|
| 113 |
+
for x, xdict in new_edge_graph.items():
|
| 114 |
+
if x in self.edge_graph:
|
| 115 |
+
self.edge_graph[x].update(dict(xdict))
|
| 116 |
+
else:
|
| 117 |
+
self.edge_graph[x] = dict(xdict)
|
| 118 |
+
|
| 119 |
+
def merge(self, nodes: list[Node], deps: list[Any]) -> None:
|
| 120 |
+
for node in nodes:
|
| 121 |
+
self.merge_one(node, deps)
|
| 122 |
+
|
| 123 |
+
def merge_one(self, node: Node, deps: list[Any]) -> None:
|
| 124 |
+
node.rep().set_rep(self.rep())
|
| 125 |
+
|
| 126 |
+
if node in self.merge_graph:
|
| 127 |
+
return
|
| 128 |
+
|
| 129 |
+
self.merge_graph[node] = deps
|
| 130 |
+
node.merge_graph[self] = deps
|
| 131 |
+
|
| 132 |
+
def is_val(self, node: Node) -> bool:
|
| 133 |
+
return (
|
| 134 |
+
isinstance(self, Line)
|
| 135 |
+
and isinstance(node, Direction)
|
| 136 |
+
or isinstance(self, Segment)
|
| 137 |
+
and isinstance(node, Length)
|
| 138 |
+
or isinstance(self, Angle)
|
| 139 |
+
and isinstance(node, Measure)
|
| 140 |
+
or isinstance(self, Ratio)
|
| 141 |
+
and isinstance(node, Value)
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
def set_val(self, node: Node) -> None:
|
| 145 |
+
self._val = node
|
| 146 |
+
|
| 147 |
+
def set_obj(self, node: Node) -> None:
|
| 148 |
+
self._obj = node
|
| 149 |
+
|
| 150 |
+
@property
|
| 151 |
+
def val(self) -> Node:
|
| 152 |
+
if self._val is None:
|
| 153 |
+
return None
|
| 154 |
+
return self._val.rep()
|
| 155 |
+
|
| 156 |
+
@property
|
| 157 |
+
def obj(self) -> Node:
|
| 158 |
+
if self._obj is None:
|
| 159 |
+
return None
|
| 160 |
+
return self._obj.rep()
|
| 161 |
+
|
| 162 |
+
def equivs(self) -> set[Node]:
|
| 163 |
+
return self.rep().members
|
| 164 |
+
|
| 165 |
+
def connect_to(self, node: Node, deps: list[Any] = None) -> None:
|
| 166 |
+
rep = self.rep()
|
| 167 |
+
|
| 168 |
+
if node in rep.edge_graph:
|
| 169 |
+
rep.edge_graph[node].update({self: deps})
|
| 170 |
+
else:
|
| 171 |
+
rep.edge_graph[node] = {self: deps}
|
| 172 |
+
|
| 173 |
+
if self.is_val(node):
|
| 174 |
+
self.set_val(node)
|
| 175 |
+
node.set_obj(self)
|
| 176 |
+
|
| 177 |
+
def equivs_upto(self, level: int) -> dict[Node, Node]:
|
| 178 |
+
"""What are the equivalent nodes up to a certain level."""
|
| 179 |
+
parent = {self: None}
|
| 180 |
+
visited = set()
|
| 181 |
+
queue = [self]
|
| 182 |
+
i = 0
|
| 183 |
+
|
| 184 |
+
while i < len(queue):
|
| 185 |
+
current = queue[i]
|
| 186 |
+
i += 1
|
| 187 |
+
visited.add(current)
|
| 188 |
+
|
| 189 |
+
for neighbor in current.merge_graph:
|
| 190 |
+
if (
|
| 191 |
+
level is not None
|
| 192 |
+
and current.merge_graph[neighbor].level is not None
|
| 193 |
+
and current.merge_graph[neighbor].level >= level
|
| 194 |
+
):
|
| 195 |
+
continue
|
| 196 |
+
if neighbor not in visited:
|
| 197 |
+
queue.append(neighbor)
|
| 198 |
+
parent[neighbor] = current
|
| 199 |
+
|
| 200 |
+
return parent
|
| 201 |
+
|
| 202 |
+
def why_equal(self, others: list[Node], level: int) -> list[Any]:
|
| 203 |
+
"""BFS why this node is equal to other nodes."""
|
| 204 |
+
others = set(others)
|
| 205 |
+
found = 0
|
| 206 |
+
|
| 207 |
+
parent = {}
|
| 208 |
+
queue = [self]
|
| 209 |
+
i = 0
|
| 210 |
+
|
| 211 |
+
while i < len(queue):
|
| 212 |
+
current = queue[i]
|
| 213 |
+
if current in others:
|
| 214 |
+
found += 1
|
| 215 |
+
if found == len(others):
|
| 216 |
+
break
|
| 217 |
+
|
| 218 |
+
i += 1
|
| 219 |
+
|
| 220 |
+
for neighbor in current.merge_graph:
|
| 221 |
+
if (
|
| 222 |
+
level is not None
|
| 223 |
+
and current.merge_graph[neighbor].level is not None
|
| 224 |
+
and current.merge_graph[neighbor].level >= level
|
| 225 |
+
):
|
| 226 |
+
continue
|
| 227 |
+
if neighbor not in parent:
|
| 228 |
+
queue.append(neighbor)
|
| 229 |
+
parent[neighbor] = current
|
| 230 |
+
|
| 231 |
+
return bfs_backtrack(self, others, parent)
|
| 232 |
+
|
| 233 |
+
def why_equal_groups(
|
| 234 |
+
self, groups: list[list[Node]], level: int
|
| 235 |
+
) -> tuple[list[Any], list[Node]]:
|
| 236 |
+
"""BFS for why self is equal to at least one member of each group."""
|
| 237 |
+
others = [None for _ in groups]
|
| 238 |
+
found = 0
|
| 239 |
+
|
| 240 |
+
parent = {}
|
| 241 |
+
queue = [self]
|
| 242 |
+
i = 0
|
| 243 |
+
|
| 244 |
+
while i < len(queue):
|
| 245 |
+
current = queue[i]
|
| 246 |
+
|
| 247 |
+
for j, grp in enumerate(groups):
|
| 248 |
+
if others[j] is None and current in grp:
|
| 249 |
+
others[j] = current
|
| 250 |
+
found += 1
|
| 251 |
+
|
| 252 |
+
if found == len(others):
|
| 253 |
+
break
|
| 254 |
+
|
| 255 |
+
i += 1
|
| 256 |
+
|
| 257 |
+
for neighbor in current.merge_graph:
|
| 258 |
+
if (
|
| 259 |
+
level is not None
|
| 260 |
+
and current.merge_graph[neighbor].level is not None
|
| 261 |
+
and current.merge_graph[neighbor].level >= level
|
| 262 |
+
):
|
| 263 |
+
continue
|
| 264 |
+
if neighbor not in parent:
|
| 265 |
+
queue.append(neighbor)
|
| 266 |
+
parent[neighbor] = current
|
| 267 |
+
|
| 268 |
+
return bfs_backtrack(self, others, parent), others
|
| 269 |
+
|
| 270 |
+
def why_val(self, level: int) -> list[Any]:
|
| 271 |
+
return self._val.why_equal([self.val], level)
|
| 272 |
+
|
| 273 |
+
def why_connect(self, node: Node, level: int = None) -> list[Any]:
|
| 274 |
+
rep = self.rep()
|
| 275 |
+
equivs = list(rep.edge_graph[node].keys())
|
| 276 |
+
if not equivs:
|
| 277 |
+
return None
|
| 278 |
+
equiv = equivs[0]
|
| 279 |
+
dep = rep.edge_graph[node][equiv]
|
| 280 |
+
return [dep] + self.why_equal(equiv, level)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def why_connect(*pairs: list[tuple[Node, Node]]) -> list[Any]:
|
| 284 |
+
result = []
|
| 285 |
+
for node1, node2 in pairs:
|
| 286 |
+
result += node1.why_connect(node2)
|
| 287 |
+
return result
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def is_equiv(x: Node, y: Node, level: int = None) -> bool:
|
| 291 |
+
level = level or float('inf')
|
| 292 |
+
return x.why_equal([y], level) is not None
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def is_equal(x: Node, y: Node, level: int = None) -> bool:
|
| 296 |
+
if x == y:
|
| 297 |
+
return True
|
| 298 |
+
if x._val is None or y._val is None:
|
| 299 |
+
return False
|
| 300 |
+
if x.val != y.val:
|
| 301 |
+
return False
|
| 302 |
+
return is_equiv(x._val, y._val, level)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def bfs_backtrack(
|
| 306 |
+
root: Node, leafs: list[Node], parent: dict[Node, Node]
|
| 307 |
+
) -> list[Any]:
|
| 308 |
+
"""Return the path given BFS trace of parent nodes."""
|
| 309 |
+
backtracked = {root} # no need to backtrack further when touching this set.
|
| 310 |
+
deps = []
|
| 311 |
+
for node in leafs:
|
| 312 |
+
if node is None:
|
| 313 |
+
return None
|
| 314 |
+
if node in backtracked:
|
| 315 |
+
continue
|
| 316 |
+
if node not in parent:
|
| 317 |
+
return None
|
| 318 |
+
while node not in backtracked:
|
| 319 |
+
backtracked.add(node)
|
| 320 |
+
deps.append(node.merge_graph[parent[node]])
|
| 321 |
+
node = parent[node]
|
| 322 |
+
|
| 323 |
+
return deps
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class Point(Node):
|
| 327 |
+
pass
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class Line(Node):
|
| 331 |
+
"""Node of type Line."""
|
| 332 |
+
|
| 333 |
+
def new_val(self) -> Direction:
|
| 334 |
+
return Direction()
|
| 335 |
+
|
| 336 |
+
def why_coll(self, points: list[Point], level: int = None) -> list[Any]:
|
| 337 |
+
"""Why points are connected to self."""
|
| 338 |
+
level = level or float('inf')
|
| 339 |
+
|
| 340 |
+
groups = []
|
| 341 |
+
for p in points:
|
| 342 |
+
group = [
|
| 343 |
+
l
|
| 344 |
+
for l, d in self.edge_graph[p].items()
|
| 345 |
+
if d is None or d.level < level
|
| 346 |
+
]
|
| 347 |
+
if not group:
|
| 348 |
+
return None
|
| 349 |
+
groups.append(group)
|
| 350 |
+
|
| 351 |
+
min_deps = None
|
| 352 |
+
for line in groups[0]:
|
| 353 |
+
deps, others = line.why_equal_groups(groups[1:], level)
|
| 354 |
+
if deps is None:
|
| 355 |
+
continue
|
| 356 |
+
for p, o in zip(points, [line] + others):
|
| 357 |
+
deps.append(self.edge_graph[p][o])
|
| 358 |
+
if min_deps is None or len(deps) < len(min_deps):
|
| 359 |
+
min_deps = deps
|
| 360 |
+
|
| 361 |
+
if min_deps is None:
|
| 362 |
+
return None
|
| 363 |
+
return [d for d in min_deps if d is not None]
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class Segment(Node):
|
| 367 |
+
|
| 368 |
+
def new_val(self) -> Length:
|
| 369 |
+
return Length()
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class Circle(Node):
|
| 373 |
+
"""Node of type Circle."""
|
| 374 |
+
|
| 375 |
+
def why_cyclic(self, points: list[Point], level: int = None) -> list[Any]:
|
| 376 |
+
"""Why points are connected to self."""
|
| 377 |
+
level = level or float('inf')
|
| 378 |
+
|
| 379 |
+
groups = []
|
| 380 |
+
for p in points:
|
| 381 |
+
group = [
|
| 382 |
+
c
|
| 383 |
+
for c, d in self.edge_graph[p].items()
|
| 384 |
+
if d is None or d.level < level
|
| 385 |
+
]
|
| 386 |
+
if not group:
|
| 387 |
+
return None
|
| 388 |
+
groups.append(group)
|
| 389 |
+
|
| 390 |
+
min_deps = None
|
| 391 |
+
for circle in groups[0]:
|
| 392 |
+
deps, others = circle.why_equal_groups(groups[1:], level)
|
| 393 |
+
if deps is None:
|
| 394 |
+
continue
|
| 395 |
+
for p, o in zip(points, [circle] + others):
|
| 396 |
+
deps.append(self.edge_graph[p][o])
|
| 397 |
+
|
| 398 |
+
if min_deps is None or len(deps) < len(min_deps):
|
| 399 |
+
min_deps = deps
|
| 400 |
+
|
| 401 |
+
if min_deps is None:
|
| 402 |
+
return None
|
| 403 |
+
return [d for d in min_deps if d is not None]
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def why_equal(x: Node, y: Node, level: int = None) -> list[Any]:
|
| 407 |
+
if x == y:
|
| 408 |
+
return []
|
| 409 |
+
if not x._val or not y._val:
|
| 410 |
+
return None
|
| 411 |
+
if x._val == y._val:
|
| 412 |
+
return []
|
| 413 |
+
return x._val.why_equal([y._val], level)
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
class Direction(Node):
|
| 417 |
+
pass
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def get_lines_thru_all(*points: list[Point]) -> list[Line]:
|
| 421 |
+
line2count = defaultdict(lambda: 0)
|
| 422 |
+
points = set(points)
|
| 423 |
+
for p in points:
|
| 424 |
+
for l in p.neighbors(Line):
|
| 425 |
+
line2count[l] += 1
|
| 426 |
+
return [l for l, count in line2count.items() if count == len(points)]
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def line_of_and_why(
|
| 430 |
+
points: list[Point], level: int = None
|
| 431 |
+
) -> tuple[Line, list[Any]]:
|
| 432 |
+
"""Why points are collinear."""
|
| 433 |
+
for l0 in get_lines_thru_all(*points):
|
| 434 |
+
for l in l0.equivs():
|
| 435 |
+
if all([p in l.edge_graph for p in points]):
|
| 436 |
+
x, y = l.points
|
| 437 |
+
colls = list({x, y} | set(points))
|
| 438 |
+
# if len(colls) < 3:
|
| 439 |
+
# return l, []
|
| 440 |
+
why = l.why_coll(colls, level)
|
| 441 |
+
if why is not None:
|
| 442 |
+
return l, why
|
| 443 |
+
|
| 444 |
+
return None, None
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def get_circles_thru_all(*points: list[Point]) -> list[Circle]:
|
| 448 |
+
circle2count = defaultdict(lambda: 0)
|
| 449 |
+
points = set(points)
|
| 450 |
+
for p in points:
|
| 451 |
+
for c in p.neighbors(Circle):
|
| 452 |
+
circle2count[c] += 1
|
| 453 |
+
return [c for c, count in circle2count.items() if count == len(points)]
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def circle_of_and_why(
|
| 457 |
+
points: list[Point], level: int = None
|
| 458 |
+
) -> tuple[Circle, list[Any]]:
|
| 459 |
+
"""Why points are concyclic."""
|
| 460 |
+
for c0 in get_circles_thru_all(*points):
|
| 461 |
+
for c in c0.equivs():
|
| 462 |
+
if all([p in c.edge_graph for p in points]):
|
| 463 |
+
cycls = list(set(points))
|
| 464 |
+
why = c.why_cyclic(cycls, level)
|
| 465 |
+
if why is not None:
|
| 466 |
+
return c, why
|
| 467 |
+
|
| 468 |
+
return None, None
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def name_map(struct: Any) -> Any:
|
| 472 |
+
if isinstance(struct, list):
|
| 473 |
+
return [name_map(x) for x in struct]
|
| 474 |
+
elif isinstance(struct, tuple):
|
| 475 |
+
return tuple([name_map(x) for x in struct])
|
| 476 |
+
elif isinstance(struct, set):
|
| 477 |
+
return set([name_map(x) for x in struct])
|
| 478 |
+
elif isinstance(struct, dict):
|
| 479 |
+
return {name_map(x): name_map(y) for x, y in struct.items()}
|
| 480 |
+
else:
|
| 481 |
+
return getattr(struct, 'name', '')
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
class Angle(Node):
|
| 485 |
+
"""Node of type Angle."""
|
| 486 |
+
|
| 487 |
+
def new_val(self) -> Measure:
|
| 488 |
+
return Measure()
|
| 489 |
+
|
| 490 |
+
def set_directions(self, d1: Direction, d2: Direction) -> None:
|
| 491 |
+
self._d = d1, d2
|
| 492 |
+
|
| 493 |
+
@property
|
| 494 |
+
def directions(self) -> tuple[Direction, Direction]:
|
| 495 |
+
d1, d2 = self._d
|
| 496 |
+
if d1 is None or d2 is None:
|
| 497 |
+
return d1, d2
|
| 498 |
+
return d1.rep(), d2.rep()
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
class Measure(Node):
|
| 502 |
+
pass
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
class Length(Node):
|
| 506 |
+
pass
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
class Ratio(Node):
|
| 510 |
+
"""Node of type Ratio."""
|
| 511 |
+
|
| 512 |
+
def new_val(self) -> Value:
|
| 513 |
+
return Value()
|
| 514 |
+
|
| 515 |
+
def set_lengths(self, l1: Length, l2: Length) -> None:
|
| 516 |
+
self._l = l1, l2
|
| 517 |
+
|
| 518 |
+
@property
|
| 519 |
+
def lengths(self) -> tuple[Length, Length]:
|
| 520 |
+
l1, l2 = self._l
|
| 521 |
+
if l1 is None or l2 is None:
|
| 522 |
+
return l1, l2
|
| 523 |
+
return l1.rep(), l2.rep()
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
class Value(Node):
|
| 527 |
+
pass
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def all_angles(
|
| 531 |
+
d1: Direction, d2: Direction, level: int = None
|
| 532 |
+
) -> tuple[Angle, list[Direction], list[Direction]]:
|
| 533 |
+
level = level or float('inf')
|
| 534 |
+
d1s = d1.equivs_upto(level)
|
| 535 |
+
d2s = d2.equivs_upto(level)
|
| 536 |
+
|
| 537 |
+
for ang in d1.rep().neighbors(Angle):
|
| 538 |
+
d1_, d2_ = ang._d
|
| 539 |
+
if d1_ in d1s and d2_ in d2s:
|
| 540 |
+
yield ang, d1s, d2s
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def all_ratios(
|
| 544 |
+
d1, d2, level=None
|
| 545 |
+
) -> tuple[Angle, list[Direction], list[Direction]]:
|
| 546 |
+
level = level or float('inf')
|
| 547 |
+
d1s = d1.equivs_upto(level)
|
| 548 |
+
d2s = d2.equivs_upto(level)
|
| 549 |
+
|
| 550 |
+
for ang in d1.rep().neighbors(Ratio):
|
| 551 |
+
d1_, d2_ = ang._l
|
| 552 |
+
if d1_ in d1s and d2_ in d2s:
|
| 553 |
+
yield ang, d1s, d2s
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
RANKING = {
|
| 557 |
+
Point: 0,
|
| 558 |
+
Line: 1,
|
| 559 |
+
Segment: 2,
|
| 560 |
+
Circle: 3,
|
| 561 |
+
Direction: 4,
|
| 562 |
+
Length: 5,
|
| 563 |
+
Angle: 6,
|
| 564 |
+
Ratio: 7,
|
| 565 |
+
Measure: 8,
|
| 566 |
+
Value: 9,
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def val_type(x: Node) -> Type[Node]:
|
| 571 |
+
if isinstance(x, Line):
|
| 572 |
+
return Direction
|
| 573 |
+
if isinstance(x, Segment):
|
| 574 |
+
return Length
|
| 575 |
+
if isinstance(x, Angle):
|
| 576 |
+
return Measure
|
| 577 |
+
if isinstance(x, Ratio):
|
| 578 |
+
return Value
|
backend/core/ag4masses/alphageometry/geometry_150M_generate.gin
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
NUM_EMBEDDINGS = 1024
|
| 2 |
+
|
| 3 |
+
# Number of parameters = 152M
|
| 4 |
+
NUM_LAYERS = 12
|
| 5 |
+
EMBED_DIM = 1024
|
| 6 |
+
NUM_HEADS = 8
|
| 7 |
+
HEAD_DIM = 128
|
| 8 |
+
MLP_DIM = 4096
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
transformer_layer.TransformerLayerGenerate:
|
| 12 |
+
num_heads = %NUM_HEADS
|
| 13 |
+
head_size = %HEAD_DIM
|
| 14 |
+
window_length = 1024
|
| 15 |
+
use_long_xl_architecture = False
|
| 16 |
+
max_unrolled_windows = -1 # Always unroll.
|
| 17 |
+
relative_position_type = "t5" # Can be "fourier", "t5", or None.
|
| 18 |
+
use_causal_mask = True
|
| 19 |
+
attn_dropout_rate = %ATTN_DROPOUT_RATE # Attention matrix dropout.
|
| 20 |
+
memory_num_neighbors = 0
|
| 21 |
+
dtype = %DTYPE
|
| 22 |
+
|
| 23 |
+
decoder_stack.DecoderStackGenerate:
|
| 24 |
+
num_layers = %NUM_LAYERS
|
| 25 |
+
embedding_size = %EMBED_DIM
|
| 26 |
+
embedding_stddev = 1.0
|
| 27 |
+
layer_factory = @transformer_layer.TransformerLayerGenerate
|
| 28 |
+
dstack_window_length = 0
|
| 29 |
+
use_absolute_positions = False
|
| 30 |
+
use_final_layernorm = True # Final layernorm before token lookup.
|
| 31 |
+
final_dropout_rate = %DROPOUT_RATE # Dropout before token lookup.
|
| 32 |
+
final_mlp_factory = None # Final MLP to predict target tokens.
|
| 33 |
+
recurrent_layer_indices = ()
|
| 34 |
+
memory_factory = None # e.g. @memory_factory.memory_on_tpu_factory
|
| 35 |
+
memory_layer_indices = ()
|
| 36 |
+
dtype = %DTYPE
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
models.DecoderOnlyLanguageModelGenerate:
|
| 40 |
+
num_heads = %NUM_HEADS
|
| 41 |
+
head_size = %HEAD_DIM
|
| 42 |
+
task_config = @decoder_stack.TransformerTaskConfig()
|
| 43 |
+
decoder_factory = @decoder_stack.DecoderStackGenerate
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
training_loop.Trainer:
|
| 47 |
+
model_definition = @models.DecoderOnlyLanguageModelGenerate
|
backend/core/ag4masses/alphageometry/geometry_test.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Unit tests for geometry.py."""
|
| 17 |
+
import unittest
|
| 18 |
+
|
| 19 |
+
from absl.testing import absltest
|
| 20 |
+
import geometry as gm
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class GeometryTest(unittest.TestCase):
|
| 24 |
+
|
| 25 |
+
def _setup_equality_example(self):
|
| 26 |
+
# Create 4 nodes a, b, c, d
|
| 27 |
+
# and their lengths
|
| 28 |
+
a = gm.Segment('a')
|
| 29 |
+
la = gm.Length('l(a)')
|
| 30 |
+
a.connect_to(la)
|
| 31 |
+
la.connect_to(a)
|
| 32 |
+
|
| 33 |
+
b = gm.Segment('b')
|
| 34 |
+
lb = gm.Length('l(b)')
|
| 35 |
+
b.connect_to(lb)
|
| 36 |
+
lb.connect_to(b)
|
| 37 |
+
|
| 38 |
+
c = gm.Segment('c')
|
| 39 |
+
lc = gm.Length('l(c)')
|
| 40 |
+
c.connect_to(lc)
|
| 41 |
+
lc.connect_to(c)
|
| 42 |
+
|
| 43 |
+
d = gm.Segment('d')
|
| 44 |
+
ld = gm.Length('l(d)')
|
| 45 |
+
d.connect_to(ld)
|
| 46 |
+
ld.connect_to(d)
|
| 47 |
+
|
| 48 |
+
# Now let a=b, b=c, a=c, c=d
|
| 49 |
+
la.merge([lb], 'fact1')
|
| 50 |
+
lb.merge([lc], 'fact2')
|
| 51 |
+
la.merge([lc], 'fact3')
|
| 52 |
+
lc.merge([ld], 'fact4')
|
| 53 |
+
return a, b, c, d, la, lb, lc, ld
|
| 54 |
+
|
| 55 |
+
def test_merged_node_representative(self):
|
| 56 |
+
_, _, _, _, la, lb, lc, ld = self._setup_equality_example()
|
| 57 |
+
|
| 58 |
+
# all nodes are now represented by la.
|
| 59 |
+
self.assertEqual(la.rep(), la)
|
| 60 |
+
self.assertEqual(lb.rep(), la)
|
| 61 |
+
self.assertEqual(lc.rep(), la)
|
| 62 |
+
self.assertEqual(ld.rep(), la)
|
| 63 |
+
|
| 64 |
+
def test_merged_node_equivalence(self):
|
| 65 |
+
_, _, _, _, la, lb, lc, ld = self._setup_equality_example()
|
| 66 |
+
# all la, lb, lc, ld are equivalent
|
| 67 |
+
self.assertCountEqual(la.equivs(), [la, lb, lc, ld])
|
| 68 |
+
self.assertCountEqual(lb.equivs(), [la, lb, lc, ld])
|
| 69 |
+
self.assertCountEqual(lc.equivs(), [la, lb, lc, ld])
|
| 70 |
+
self.assertCountEqual(ld.equivs(), [la, lb, lc, ld])
|
| 71 |
+
|
| 72 |
+
def test_bfs_for_equality_transitivity(self):
|
| 73 |
+
a, _, _, d, _, _, _, _ = self._setup_equality_example()
|
| 74 |
+
|
| 75 |
+
# check that a==d because fact3 & fact4, not fact1 & fact2
|
| 76 |
+
self.assertCountEqual(gm.why_equal(a, d), ['fact3', 'fact4'])
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
if __name__ == '__main__':
|
| 80 |
+
absltest.main()
|
backend/core/ag4masses/alphageometry/graph.py
ADDED
|
@@ -0,0 +1,3057 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Implements the graph representation of the proof state."""
|
| 17 |
+
|
| 18 |
+
# pylint: disable=g-multiple-import
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
from collections import defaultdict # pylint: disable=g-importing-member
|
| 22 |
+
from typing import Callable, Generator, Optional, Type, Union
|
| 23 |
+
|
| 24 |
+
from absl import logging
|
| 25 |
+
import ar
|
| 26 |
+
import geometry as gm
|
| 27 |
+
from geometry import Angle, Direction, Length, Ratio
|
| 28 |
+
from geometry import Circle, Line, Point, Segment
|
| 29 |
+
from geometry import Measure, Value
|
| 30 |
+
import graph_utils as utils
|
| 31 |
+
import numericals as nm
|
| 32 |
+
import problem
|
| 33 |
+
from problem import Dependency, EmptyDependency
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
np = nm.np
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
FREE = [
|
| 40 |
+
'free',
|
| 41 |
+
'segment',
|
| 42 |
+
'r_triangle',
|
| 43 |
+
'risos',
|
| 44 |
+
'triangle',
|
| 45 |
+
'triangle12',
|
| 46 |
+
'ieq_triangle',
|
| 47 |
+
'eq_quadrangle',
|
| 48 |
+
'eq_trapezoid',
|
| 49 |
+
'eqdia_quadrangle',
|
| 50 |
+
'quadrangle',
|
| 51 |
+
'r_trapezoid',
|
| 52 |
+
'rectangle',
|
| 53 |
+
'isquare',
|
| 54 |
+
'trapezoid',
|
| 55 |
+
'pentagon',
|
| 56 |
+
'iso_triangle',
|
| 57 |
+
]
|
| 58 |
+
|
| 59 |
+
INTERSECT = [
|
| 60 |
+
'angle_bisector',
|
| 61 |
+
'angle_mirror',
|
| 62 |
+
'eqdistance',
|
| 63 |
+
'lc_tangent',
|
| 64 |
+
'on_aline',
|
| 65 |
+
'on_bline',
|
| 66 |
+
'on_circle',
|
| 67 |
+
'on_line',
|
| 68 |
+
'on_pline',
|
| 69 |
+
'on_tline',
|
| 70 |
+
'on_dia',
|
| 71 |
+
's_angle',
|
| 72 |
+
'on_opline',
|
| 73 |
+
'eqangle3',
|
| 74 |
+
]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# pylint: disable=protected-access
|
| 78 |
+
# pylint: disable=unused-argument
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class DepCheckFailError(Exception):
|
| 82 |
+
pass
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class PointTooCloseError(Exception):
|
| 86 |
+
pass
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class PointTooFarError(Exception):
|
| 90 |
+
pass
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class Graph:
|
| 94 |
+
"""Graph data structure representing proof state."""
|
| 95 |
+
|
| 96 |
+
def __init__(self):
|
| 97 |
+
self.type2nodes = {
|
| 98 |
+
Point: [],
|
| 99 |
+
Line: [],
|
| 100 |
+
Segment: [],
|
| 101 |
+
Circle: [],
|
| 102 |
+
Direction: [],
|
| 103 |
+
Length: [],
|
| 104 |
+
Angle: [],
|
| 105 |
+
Ratio: [],
|
| 106 |
+
Measure: [],
|
| 107 |
+
Value: [],
|
| 108 |
+
}
|
| 109 |
+
self._name2point = {}
|
| 110 |
+
self._name2node = {}
|
| 111 |
+
|
| 112 |
+
self.rconst = {} # contains all constant ratios
|
| 113 |
+
self.aconst = {} # contains all constant angles.
|
| 114 |
+
|
| 115 |
+
self.halfpi, _ = self.get_or_create_const_ang(1, 2)
|
| 116 |
+
self.vhalfpi = self.halfpi.val
|
| 117 |
+
|
| 118 |
+
self.atable = ar.AngleTable()
|
| 119 |
+
self.dtable = ar.DistanceTable()
|
| 120 |
+
self.rtable = ar.RatioTable()
|
| 121 |
+
|
| 122 |
+
# to quick access deps.
|
| 123 |
+
self.cache = {}
|
| 124 |
+
|
| 125 |
+
self._pair2line = {}
|
| 126 |
+
self._triplet2circle = {}
|
| 127 |
+
|
| 128 |
+
def copy(self) -> Graph:
|
| 129 |
+
"""Make a copy of self."""
|
| 130 |
+
p, definitions = self.build_def
|
| 131 |
+
|
| 132 |
+
p = p.copy()
|
| 133 |
+
for clause in p.clauses:
|
| 134 |
+
clause.nums = []
|
| 135 |
+
for pname in clause.points:
|
| 136 |
+
clause.nums.append(self._name2node[pname].num)
|
| 137 |
+
|
| 138 |
+
g, _ = Graph.build_problem(p, definitions, verbose=False, init_copy=False)
|
| 139 |
+
|
| 140 |
+
g.build_clauses = list(getattr(self, 'build_clauses', []))
|
| 141 |
+
return g
|
| 142 |
+
|
| 143 |
+
def _create_const_ang(self, n: int, d: int) -> None:
|
| 144 |
+
n, d = ar.simplify(n, d)
|
| 145 |
+
ang = self.aconst[(n, d)] = self.new_node(Angle, f'{n}pi/{d}')
|
| 146 |
+
ang.set_directions(None, None)
|
| 147 |
+
self.connect_val(ang, deps=None)
|
| 148 |
+
|
| 149 |
+
def _create_const_rat(self, n: int, d: int) -> None:
|
| 150 |
+
n, d = ar.simplify(n, d)
|
| 151 |
+
rat = self.rconst[(n, d)] = self.new_node(Ratio, f'{n}/{d}')
|
| 152 |
+
rat.set_lengths(None, None)
|
| 153 |
+
self.connect_val(rat, deps=None)
|
| 154 |
+
|
| 155 |
+
def get_or_create_const_ang(self, n: int, d: int) -> None:
|
| 156 |
+
n, d = ar.simplify(n, d)
|
| 157 |
+
if (n, d) not in self.aconst:
|
| 158 |
+
self._create_const_ang(n, d)
|
| 159 |
+
ang1 = self.aconst[(n, d)]
|
| 160 |
+
|
| 161 |
+
n, d = ar.simplify(d - n, d)
|
| 162 |
+
if (n, d) not in self.aconst:
|
| 163 |
+
self._create_const_ang(n, d)
|
| 164 |
+
ang2 = self.aconst[(n, d)]
|
| 165 |
+
return ang1, ang2
|
| 166 |
+
|
| 167 |
+
def get_or_create_const_rat(self, n: int, d: int) -> None:
|
| 168 |
+
n, d = ar.simplify(n, d)
|
| 169 |
+
if (n, d) not in self.rconst:
|
| 170 |
+
self._create_const_rat(n, d)
|
| 171 |
+
rat1 = self.rconst[(n, d)]
|
| 172 |
+
|
| 173 |
+
if (d, n) not in self.rconst:
|
| 174 |
+
self._create_const_rat(d, n) # pylint: disable=arguments-out-of-order
|
| 175 |
+
rat2 = self.rconst[(d, n)]
|
| 176 |
+
return rat1, rat2
|
| 177 |
+
|
| 178 |
+
def add_algebra(self, dep: Dependency, level: int) -> None:
|
| 179 |
+
"""Add new algebraic predicates."""
|
| 180 |
+
_ = level
|
| 181 |
+
if dep.name not in [
|
| 182 |
+
'para',
|
| 183 |
+
'perp',
|
| 184 |
+
'eqangle',
|
| 185 |
+
'eqratio',
|
| 186 |
+
'aconst',
|
| 187 |
+
'rconst',
|
| 188 |
+
'cong',
|
| 189 |
+
]:
|
| 190 |
+
return
|
| 191 |
+
|
| 192 |
+
name, args = dep.name, dep.args
|
| 193 |
+
|
| 194 |
+
if name == 'para':
|
| 195 |
+
ab, cd = dep.algebra
|
| 196 |
+
self.atable.add_para(ab, cd, dep)
|
| 197 |
+
|
| 198 |
+
if name == 'perp':
|
| 199 |
+
ab, cd = dep.algebra
|
| 200 |
+
self.atable.add_const_angle(ab, cd, 90, dep)
|
| 201 |
+
|
| 202 |
+
if name == 'eqangle':
|
| 203 |
+
ab, cd, mn, pq = dep.algebra
|
| 204 |
+
if (ab, cd) == (pq, mn):
|
| 205 |
+
self.atable.add_const_angle(ab, cd, 90, dep)
|
| 206 |
+
else:
|
| 207 |
+
self.atable.add_eqangle(ab, cd, mn, pq, dep)
|
| 208 |
+
|
| 209 |
+
if name == 'eqratio':
|
| 210 |
+
ab, cd, mn, pq = dep.algebra
|
| 211 |
+
if (ab, cd) == (pq, mn):
|
| 212 |
+
self.rtable.add_eq(ab, cd, dep)
|
| 213 |
+
else:
|
| 214 |
+
self.rtable.add_eqratio(ab, cd, mn, pq, dep)
|
| 215 |
+
|
| 216 |
+
if name == 'aconst':
|
| 217 |
+
bx, ab, y = dep.algebra
|
| 218 |
+
self.atable.add_const_angle(bx, ab, y, dep)
|
| 219 |
+
|
| 220 |
+
if name == 'rconst':
|
| 221 |
+
l1, l2, m, n = dep.algebra
|
| 222 |
+
self.rtable.add_const_ratio(l1, l2, m, n, dep)
|
| 223 |
+
|
| 224 |
+
if name == 'cong':
|
| 225 |
+
a, b, c, d = args
|
| 226 |
+
ab, _ = self.get_line_thru_pair_why(a, b)
|
| 227 |
+
cd, _ = self.get_line_thru_pair_why(c, d)
|
| 228 |
+
self.dtable.add_cong(ab, cd, a, b, c, d, dep)
|
| 229 |
+
|
| 230 |
+
ab, cd = dep.algebra
|
| 231 |
+
self.rtable.add_eq(ab, cd, dep)
|
| 232 |
+
|
| 233 |
+
def add_eqrat_const(
|
| 234 |
+
self, args: list[Point], deps: EmptyDependency
|
| 235 |
+
) -> list[Dependency]:
|
| 236 |
+
"""Add new algebraic predicates of type eqratio-constant."""
|
| 237 |
+
a, b, c, d, num, den = args
|
| 238 |
+
nd, dn = self.get_or_create_const_rat(num, den)
|
| 239 |
+
|
| 240 |
+
if num == den:
|
| 241 |
+
return self.add_cong([a, b, c, d], deps)
|
| 242 |
+
|
| 243 |
+
ab = self._get_or_create_segment(a, b, deps=None)
|
| 244 |
+
cd = self._get_or_create_segment(c, d, deps=None)
|
| 245 |
+
|
| 246 |
+
self.connect_val(ab, deps=None)
|
| 247 |
+
self.connect_val(cd, deps=None)
|
| 248 |
+
|
| 249 |
+
if ab.val == cd.val:
|
| 250 |
+
raise ValueError(f'{ab.name} and {cd.name} cannot be equal')
|
| 251 |
+
|
| 252 |
+
args = [a, b, c, d, nd]
|
| 253 |
+
i = 0
|
| 254 |
+
for x, y, xy in [(a, b, ab), (c, d, cd)]:
|
| 255 |
+
i += 1
|
| 256 |
+
x_, y_ = list(xy._val._obj.points)
|
| 257 |
+
if {x, y} == {x_, y_}:
|
| 258 |
+
continue
|
| 259 |
+
if deps:
|
| 260 |
+
deps = deps.extend(self, 'rconst', list(args), 'cong', [x, y, x_, y_])
|
| 261 |
+
args[2 * i - 2] = x_
|
| 262 |
+
args[2 * i - 1] = y_
|
| 263 |
+
|
| 264 |
+
ab_cd, cd_ab, why = self._get_or_create_ratio(ab, cd, deps=None)
|
| 265 |
+
if why:
|
| 266 |
+
dep0 = deps.populate('rconst', [a, b, c, d, nd])
|
| 267 |
+
deps = EmptyDependency(level=deps.level, rule_name=None)
|
| 268 |
+
deps.why = [dep0] + why
|
| 269 |
+
|
| 270 |
+
lab, lcd = ab_cd._l
|
| 271 |
+
a, b = list(lab._obj.points)
|
| 272 |
+
c, d = list(lcd._obj.points)
|
| 273 |
+
|
| 274 |
+
add = []
|
| 275 |
+
if not self.is_equal(ab_cd, nd):
|
| 276 |
+
args = [a, b, c, d, nd]
|
| 277 |
+
dep1 = deps.populate('rconst', args)
|
| 278 |
+
dep1.algebra = ab._val, cd._val, num, den
|
| 279 |
+
self.make_equal(nd, ab_cd, deps=dep1)
|
| 280 |
+
self.cache_dep('rconst', [a, b, c, d, nd], dep1)
|
| 281 |
+
add += [dep1]
|
| 282 |
+
|
| 283 |
+
if not self.is_equal(cd_ab, dn):
|
| 284 |
+
args = [c, d, a, b, dn]
|
| 285 |
+
dep2 = deps.populate('rconst', args)
|
| 286 |
+
dep2.algebra = cd._val, ab._val, den, num
|
| 287 |
+
self.make_equal(dn, cd_ab, deps=dep2)
|
| 288 |
+
self.cache_dep('rconst', [c, d, a, b, dn], dep2)
|
| 289 |
+
add += [dep2]
|
| 290 |
+
|
| 291 |
+
return add
|
| 292 |
+
|
| 293 |
+
def do_algebra(self, name: str, args: list[Point]) -> list[Dependency]:
|
| 294 |
+
"""Derive (but not add) new algebraic predicates."""
|
| 295 |
+
if name == 'para':
|
| 296 |
+
a, b, dep = args
|
| 297 |
+
if gm.is_equiv(a, b):
|
| 298 |
+
return []
|
| 299 |
+
(x, y), (m, n) = a._obj.points, b._obj.points
|
| 300 |
+
return self.add_para([x, y, m, n], dep)
|
| 301 |
+
|
| 302 |
+
if name == 'aconst':
|
| 303 |
+
a, b, n, d, dep = args
|
| 304 |
+
ab, ba, why = self.get_or_create_angle_d(a, b, deps=None)
|
| 305 |
+
nd, dn = self.get_or_create_const_ang(n, d)
|
| 306 |
+
|
| 307 |
+
(x, y), (m, n) = a._obj.points, b._obj.points
|
| 308 |
+
|
| 309 |
+
if why:
|
| 310 |
+
dep0 = dep.populate('aconst', [x, y, m, n, nd])
|
| 311 |
+
dep = EmptyDependency(level=dep.level, rule_name=None)
|
| 312 |
+
dep.why = [dep0] + why
|
| 313 |
+
|
| 314 |
+
a, b = ab._d
|
| 315 |
+
(x, y), (m, n) = a._obj.points, b._obj.points
|
| 316 |
+
|
| 317 |
+
added = []
|
| 318 |
+
if not self.is_equal(ab, nd):
|
| 319 |
+
if nd == self.halfpi:
|
| 320 |
+
added += self.add_perp([x, y, m, n], dep)
|
| 321 |
+
# else:
|
| 322 |
+
name = 'aconst'
|
| 323 |
+
args = [x, y, m, n, nd]
|
| 324 |
+
dep1 = dep.populate(name, args)
|
| 325 |
+
self.cache_dep(name, args, dep1)
|
| 326 |
+
self.make_equal(nd, ab, deps=dep1)
|
| 327 |
+
added += [dep1]
|
| 328 |
+
|
| 329 |
+
if not self.is_equal(ba, dn):
|
| 330 |
+
if dn == self.halfpi:
|
| 331 |
+
added += self.add_perp([m, n, x, y], dep)
|
| 332 |
+
name = 'aconst'
|
| 333 |
+
args = [m, n, x, y, dn]
|
| 334 |
+
dep2 = dep.populate(name, args)
|
| 335 |
+
self.cache_dep(name, args, dep2)
|
| 336 |
+
self.make_equal(dn, ba, deps=dep2)
|
| 337 |
+
added += [dep2]
|
| 338 |
+
return added
|
| 339 |
+
|
| 340 |
+
if name == 'rconst':
|
| 341 |
+
a, b, c, d, num, den, dep = args
|
| 342 |
+
return self.add_eqrat_const([a, b, c, d, num, den], dep)
|
| 343 |
+
|
| 344 |
+
if name == 'eqangle':
|
| 345 |
+
d1, d2, d3, d4, dep = args
|
| 346 |
+
a, b = d1._obj.points
|
| 347 |
+
c, d = d2._obj.points
|
| 348 |
+
e, f = d3._obj.points
|
| 349 |
+
g, h = d4._obj.points
|
| 350 |
+
|
| 351 |
+
return self.add_eqangle([a, b, c, d, e, f, g, h], dep)
|
| 352 |
+
|
| 353 |
+
if name == 'eqratio':
|
| 354 |
+
d1, d2, d3, d4, dep = args
|
| 355 |
+
a, b = d1._obj.points
|
| 356 |
+
c, d = d2._obj.points
|
| 357 |
+
e, f = d3._obj.points
|
| 358 |
+
g, h = d4._obj.points
|
| 359 |
+
|
| 360 |
+
return self.add_eqratio([a, b, c, d, e, f, g, h], dep)
|
| 361 |
+
|
| 362 |
+
if name in ['cong', 'cong2']:
|
| 363 |
+
a, b, c, d, dep = args
|
| 364 |
+
if not (a != b and c != d and (a != c or b != d)):
|
| 365 |
+
return []
|
| 366 |
+
return self.add_cong([a, b, c, d], dep)
|
| 367 |
+
|
| 368 |
+
return []
|
| 369 |
+
|
| 370 |
+
def derive_algebra(
|
| 371 |
+
self, level: int, verbose: bool = False
|
| 372 |
+
) -> tuple[
|
| 373 |
+
dict[str, list[tuple[Point, ...]]], dict[str, [tuple[Point, ...]]]
|
| 374 |
+
]:
|
| 375 |
+
"""Derive new algebraic predicates."""
|
| 376 |
+
derives = {}
|
| 377 |
+
ang_derives = self.derive_angle_algebra(level, verbose=verbose)
|
| 378 |
+
dist_derives = self.derive_distance_algebra(level, verbose=verbose)
|
| 379 |
+
rat_derives = self.derive_ratio_algebra(level, verbose=verbose)
|
| 380 |
+
|
| 381 |
+
derives.update(ang_derives)
|
| 382 |
+
derives.update(dist_derives)
|
| 383 |
+
derives.update(rat_derives)
|
| 384 |
+
|
| 385 |
+
# Separate eqangle and eqratio derivations
|
| 386 |
+
# As they are too numerous => slow down DD+AR.
|
| 387 |
+
# & reserve them only for last effort.
|
| 388 |
+
eqs = {'eqangle': derives.pop('eqangle'), 'eqratio': derives.pop('eqratio')}
|
| 389 |
+
return derives, eqs
|
| 390 |
+
|
| 391 |
+
def derive_ratio_algebra(
|
| 392 |
+
self, level: int, verbose: bool = False
|
| 393 |
+
) -> dict[str, list[tuple[Point, ...]]]:
|
| 394 |
+
"""Derive new eqratio predicates."""
|
| 395 |
+
added = {'cong2': [], 'eqratio': []}
|
| 396 |
+
|
| 397 |
+
for x in self.rtable.get_all_eqs_and_why():
|
| 398 |
+
x, why = x[:-1], x[-1]
|
| 399 |
+
dep = EmptyDependency(level=level, rule_name='a01')
|
| 400 |
+
dep.why = why
|
| 401 |
+
|
| 402 |
+
if len(x) == 2:
|
| 403 |
+
a, b = x
|
| 404 |
+
if gm.is_equiv(a, b):
|
| 405 |
+
continue
|
| 406 |
+
|
| 407 |
+
(m, n), (p, q) = a._obj.points, b._obj.points
|
| 408 |
+
added['cong2'].append((m, n, p, q, dep))
|
| 409 |
+
|
| 410 |
+
if len(x) == 4:
|
| 411 |
+
a, b, c, d = x
|
| 412 |
+
added['eqratio'].append((a, b, c, d, dep))
|
| 413 |
+
|
| 414 |
+
return added
|
| 415 |
+
|
| 416 |
+
def derive_angle_algebra(
|
| 417 |
+
self, level: int, verbose: bool = False
|
| 418 |
+
) -> dict[str, list[tuple[Point, ...]]]:
|
| 419 |
+
"""Derive new eqangles predicates."""
|
| 420 |
+
added = {'eqangle': [], 'aconst': [], 'para': []}
|
| 421 |
+
|
| 422 |
+
for x in self.atable.get_all_eqs_and_why():
|
| 423 |
+
x, why = x[:-1], x[-1]
|
| 424 |
+
dep = EmptyDependency(level=level, rule_name='a02')
|
| 425 |
+
dep.why = why
|
| 426 |
+
|
| 427 |
+
if len(x) == 2:
|
| 428 |
+
a, b = x
|
| 429 |
+
if gm.is_equiv(a, b):
|
| 430 |
+
continue
|
| 431 |
+
|
| 432 |
+
(e, f), (p, q) = a._obj.points, b._obj.points
|
| 433 |
+
if not nm.check('para', [e, f, p, q]):
|
| 434 |
+
continue
|
| 435 |
+
|
| 436 |
+
added['para'].append((a, b, dep))
|
| 437 |
+
|
| 438 |
+
if len(x) == 3:
|
| 439 |
+
a, b, (n, d) = x
|
| 440 |
+
|
| 441 |
+
(e, f), (p, q) = a._obj.points, b._obj.points
|
| 442 |
+
if not nm.check('aconst', [e, f, p, q, n, d]):
|
| 443 |
+
continue
|
| 444 |
+
|
| 445 |
+
added['aconst'].append((a, b, n, d, dep))
|
| 446 |
+
|
| 447 |
+
if len(x) == 4:
|
| 448 |
+
a, b, c, d = x
|
| 449 |
+
added['eqangle'].append((a, b, c, d, dep))
|
| 450 |
+
|
| 451 |
+
return added
|
| 452 |
+
|
| 453 |
+
def derive_distance_algebra(
|
| 454 |
+
self, level: int, verbose: bool = False
|
| 455 |
+
) -> dict[str, list[tuple[Point, ...]]]:
|
| 456 |
+
"""Derive new cong predicates."""
|
| 457 |
+
added = {'inci': [], 'cong': [], 'rconst': []}
|
| 458 |
+
for x in self.dtable.get_all_eqs_and_why():
|
| 459 |
+
x, why = x[:-1], x[-1]
|
| 460 |
+
dep = EmptyDependency(level=level, rule_name='a00')
|
| 461 |
+
dep.why = why
|
| 462 |
+
|
| 463 |
+
if len(x) == 2:
|
| 464 |
+
a, b = x
|
| 465 |
+
if a == b:
|
| 466 |
+
continue
|
| 467 |
+
|
| 468 |
+
dep.name = f'inci {a.name} {b.name}'
|
| 469 |
+
added['inci'].append((x, dep))
|
| 470 |
+
|
| 471 |
+
if len(x) == 4:
|
| 472 |
+
a, b, c, d = x
|
| 473 |
+
if not (a != b and c != d and (a != c or b != d)):
|
| 474 |
+
continue
|
| 475 |
+
added['cong'].append((a, b, c, d, dep))
|
| 476 |
+
|
| 477 |
+
if len(x) == 6:
|
| 478 |
+
a, b, c, d, num, den = x
|
| 479 |
+
if not (a != b and c != d and (a != c or b != d)):
|
| 480 |
+
continue
|
| 481 |
+
added['rconst'].append((a, b, c, d, num, den, dep))
|
| 482 |
+
|
| 483 |
+
return added
|
| 484 |
+
|
| 485 |
+
@classmethod
|
| 486 |
+
def build_problem(
|
| 487 |
+
cls,
|
| 488 |
+
pr: problem.Problem,
|
| 489 |
+
definitions: dict[str, problem.Definition],
|
| 490 |
+
verbose: bool = True,
|
| 491 |
+
init_copy: bool = True,
|
| 492 |
+
) -> tuple[Graph, list[Dependency]]:
|
| 493 |
+
"""Build a problem into a gr.Graph object."""
|
| 494 |
+
check = False
|
| 495 |
+
g = None
|
| 496 |
+
added = None
|
| 497 |
+
if verbose:
|
| 498 |
+
logging.info(pr.url)
|
| 499 |
+
logging.info(pr.txt())
|
| 500 |
+
while not check:
|
| 501 |
+
try:
|
| 502 |
+
g = Graph()
|
| 503 |
+
added = []
|
| 504 |
+
plevel = 0
|
| 505 |
+
for clause in pr.clauses:
|
| 506 |
+
adds, plevel = g.add_clause(
|
| 507 |
+
clause, plevel, definitions, verbose=verbose
|
| 508 |
+
)
|
| 509 |
+
added += adds
|
| 510 |
+
g.plevel = plevel
|
| 511 |
+
|
| 512 |
+
except (nm.InvalidLineIntersectError, nm.InvalidQuadSolveError):
|
| 513 |
+
continue
|
| 514 |
+
except DepCheckFailError:
|
| 515 |
+
continue
|
| 516 |
+
except (PointTooCloseError, PointTooFarError):
|
| 517 |
+
continue
|
| 518 |
+
|
| 519 |
+
if not pr.goal:
|
| 520 |
+
break
|
| 521 |
+
|
| 522 |
+
args = list(map(lambda x: g.get(x, lambda: int(x)), pr.goal.args))
|
| 523 |
+
check = nm.check(pr.goal.name, args)
|
| 524 |
+
|
| 525 |
+
g.url = pr.url
|
| 526 |
+
g.build_def = (pr, definitions)
|
| 527 |
+
for add in added:
|
| 528 |
+
g.add_algebra(add, level=0)
|
| 529 |
+
|
| 530 |
+
return g, added
|
| 531 |
+
|
| 532 |
+
def all_points(self) -> list[Point]:
|
| 533 |
+
"""Return all nodes of type Point."""
|
| 534 |
+
return list(self.type2nodes[Point])
|
| 535 |
+
|
| 536 |
+
def all_nodes(self) -> list[gm.Node]:
|
| 537 |
+
"""Return all nodes."""
|
| 538 |
+
return list(self._name2node.values())
|
| 539 |
+
|
| 540 |
+
def add_points(self, pnames: list[str]) -> list[Point]:
|
| 541 |
+
"""Add new points with given names in list pnames."""
|
| 542 |
+
result = [self.new_node(Point, name) for name in pnames]
|
| 543 |
+
self._name2point.update(zip(pnames, result))
|
| 544 |
+
return result
|
| 545 |
+
|
| 546 |
+
def names2nodes(self, pnames: list[str]) -> list[gm.Node]:
|
| 547 |
+
return [self._name2node[name] for name in pnames]
|
| 548 |
+
|
| 549 |
+
def names2points(
|
| 550 |
+
self, pnames: list[str], create_new_point: bool = False
|
| 551 |
+
) -> list[Point]:
|
| 552 |
+
"""Return Point objects given names."""
|
| 553 |
+
result = []
|
| 554 |
+
for name in pnames:
|
| 555 |
+
if name not in self._name2node and not create_new_point:
|
| 556 |
+
raise ValueError(f'Cannot find point {name} in graph')
|
| 557 |
+
elif name in self._name2node:
|
| 558 |
+
obj = self._name2node[name]
|
| 559 |
+
else:
|
| 560 |
+
obj = self.new_node(Point, name)
|
| 561 |
+
result.append(obj)
|
| 562 |
+
|
| 563 |
+
return result
|
| 564 |
+
|
| 565 |
+
def names2points_or_int(self, pnames: list[str]) -> list[Point]:
|
| 566 |
+
"""Return Point objects given names."""
|
| 567 |
+
result = []
|
| 568 |
+
for name in pnames:
|
| 569 |
+
if name.isdigit():
|
| 570 |
+
result += [int(name)]
|
| 571 |
+
elif 'pi/' in name:
|
| 572 |
+
n, d = name.split('pi/')
|
| 573 |
+
ang, _ = self.get_or_create_const_ang(int(n), int(d))
|
| 574 |
+
result += [ang]
|
| 575 |
+
elif '/' in name:
|
| 576 |
+
n, d = name.split('/')
|
| 577 |
+
rat, _ = self.get_or_create_const_rat(int(n), int(d))
|
| 578 |
+
result += [rat]
|
| 579 |
+
else:
|
| 580 |
+
result += [self._name2point[name]]
|
| 581 |
+
|
| 582 |
+
return result
|
| 583 |
+
|
| 584 |
+
def get(self, pointname: str, default_fn: Callable[str, Point]) -> Point:
|
| 585 |
+
if pointname in self._name2point:
|
| 586 |
+
return self._name2point[pointname]
|
| 587 |
+
if pointname in self._name2node:
|
| 588 |
+
return self._name2node[pointname]
|
| 589 |
+
return default_fn()
|
| 590 |
+
|
| 591 |
+
def new_node(self, oftype: Type[gm.Node], name: str = '') -> gm.Node:
|
| 592 |
+
node = oftype(name, self)
|
| 593 |
+
|
| 594 |
+
self.type2nodes[oftype].append(node)
|
| 595 |
+
self._name2node[name] = node
|
| 596 |
+
|
| 597 |
+
if isinstance(node, Point):
|
| 598 |
+
self._name2point[name] = node
|
| 599 |
+
|
| 600 |
+
return node
|
| 601 |
+
|
| 602 |
+
def merge(self, nodes: list[gm.Node], deps: Dependency) -> gm.Node:
|
| 603 |
+
"""Merge all nodes."""
|
| 604 |
+
if len(nodes) < 2:
|
| 605 |
+
return
|
| 606 |
+
|
| 607 |
+
node0, *nodes1 = nodes
|
| 608 |
+
all_nodes = self.type2nodes[type(node0)]
|
| 609 |
+
|
| 610 |
+
# find node0 that exists in all_nodes to be the rep
|
| 611 |
+
# and merge all other nodes into node0
|
| 612 |
+
for node in nodes:
|
| 613 |
+
if node in all_nodes:
|
| 614 |
+
node0 = node
|
| 615 |
+
nodes1 = [n for n in nodes if n != node0]
|
| 616 |
+
break
|
| 617 |
+
return self.merge_into(node0, nodes1, deps)
|
| 618 |
+
|
| 619 |
+
def merge_into(
|
| 620 |
+
self, node0: gm.Node, nodes1: list[gm.Node], deps: Dependency
|
| 621 |
+
) -> gm.Node:
|
| 622 |
+
"""Merge nodes1 into a single node0."""
|
| 623 |
+
node0.merge(nodes1, deps)
|
| 624 |
+
for n in nodes1:
|
| 625 |
+
if n.rep() != n:
|
| 626 |
+
self.remove([n])
|
| 627 |
+
|
| 628 |
+
nodes = [node0] + nodes1
|
| 629 |
+
if any([node._val for node in nodes]):
|
| 630 |
+
for node in nodes:
|
| 631 |
+
self.connect_val(node, deps=None)
|
| 632 |
+
|
| 633 |
+
vals1 = [n._val for n in nodes1]
|
| 634 |
+
node0._val.merge(vals1, deps)
|
| 635 |
+
|
| 636 |
+
for v in vals1:
|
| 637 |
+
if v.rep() != v:
|
| 638 |
+
self.remove([v])
|
| 639 |
+
|
| 640 |
+
return node0
|
| 641 |
+
|
| 642 |
+
def remove(self, nodes: list[gm.Node]) -> None:
|
| 643 |
+
"""Remove nodes out of self because they are merged."""
|
| 644 |
+
if not nodes:
|
| 645 |
+
return
|
| 646 |
+
|
| 647 |
+
for node in nodes:
|
| 648 |
+
all_nodes = self.type2nodes[type(nodes[0])]
|
| 649 |
+
|
| 650 |
+
if node in all_nodes:
|
| 651 |
+
all_nodes.remove(node)
|
| 652 |
+
|
| 653 |
+
if node.name in self._name2node.values():
|
| 654 |
+
self._name2node.pop(node.name)
|
| 655 |
+
|
| 656 |
+
def connect(self, a: gm.Node, b: gm.Node, deps: Dependency) -> None:
|
| 657 |
+
a.connect_to(b, deps)
|
| 658 |
+
b.connect_to(a, deps)
|
| 659 |
+
|
| 660 |
+
def connect_val(self, node: gm.Node, deps: Dependency) -> gm.Node:
|
| 661 |
+
"""Connect a node into its value (equality) node."""
|
| 662 |
+
if node._val:
|
| 663 |
+
return node._val
|
| 664 |
+
name = None
|
| 665 |
+
if isinstance(node, Line):
|
| 666 |
+
name = 'd(' + node.name + ')'
|
| 667 |
+
if isinstance(node, Angle):
|
| 668 |
+
name = 'm(' + node.name + ')'
|
| 669 |
+
if isinstance(node, Segment):
|
| 670 |
+
name = 'l(' + node.name + ')'
|
| 671 |
+
if isinstance(node, Ratio):
|
| 672 |
+
name = 'r(' + node.name + ')'
|
| 673 |
+
v = self.new_node(gm.val_type(node), name)
|
| 674 |
+
self.connect(node, v, deps=deps)
|
| 675 |
+
return v
|
| 676 |
+
|
| 677 |
+
def is_equal(self, x: gm.Node, y: gm.Node, level: int = None) -> bool:
|
| 678 |
+
return gm.is_equal(x, y, level)
|
| 679 |
+
|
| 680 |
+
def add_piece(
|
| 681 |
+
self, name: str, args: list[Point], deps: EmptyDependency
|
| 682 |
+
) -> list[Dependency]:
|
| 683 |
+
"""Add a new predicate."""
|
| 684 |
+
if name in ['coll', 'collx']:
|
| 685 |
+
return self.add_coll(args, deps)
|
| 686 |
+
elif name == 'para':
|
| 687 |
+
return self.add_para(args, deps)
|
| 688 |
+
elif name == 'perp':
|
| 689 |
+
return self.add_perp(args, deps)
|
| 690 |
+
elif name == 'midp':
|
| 691 |
+
return self.add_midp(args, deps)
|
| 692 |
+
elif name == 'cong':
|
| 693 |
+
return self.add_cong(args, deps)
|
| 694 |
+
elif name == 'circle':
|
| 695 |
+
return self.add_circle(args, deps)
|
| 696 |
+
elif name == 'cyclic':
|
| 697 |
+
return self.add_cyclic(args, deps)
|
| 698 |
+
elif name in ['eqangle', 'eqangle6']:
|
| 699 |
+
return self.add_eqangle(args, deps)
|
| 700 |
+
elif name in ['eqratio', 'eqratio6']:
|
| 701 |
+
return self.add_eqratio(args, deps)
|
| 702 |
+
# numerical!
|
| 703 |
+
elif name == 's_angle':
|
| 704 |
+
return self.add_s_angle(args, deps)
|
| 705 |
+
elif name == 'aconst':
|
| 706 |
+
a, b, c, d, ang = args
|
| 707 |
+
|
| 708 |
+
if isinstance(ang, str):
|
| 709 |
+
name = ang
|
| 710 |
+
else:
|
| 711 |
+
name = ang.name
|
| 712 |
+
|
| 713 |
+
num, den = name.split('pi/')
|
| 714 |
+
num, den = int(num), int(den)
|
| 715 |
+
return self.add_aconst([a, b, c, d, num, den], deps)
|
| 716 |
+
elif name == 's_angle':
|
| 717 |
+
b, x, a, b, ang = ( # pylint: disable=redeclared-assigned-name,unused-variable
|
| 718 |
+
args
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
if isinstance(ang, str):
|
| 722 |
+
name = ang
|
| 723 |
+
else:
|
| 724 |
+
name = ang.name
|
| 725 |
+
|
| 726 |
+
n, d = name.split('pi/')
|
| 727 |
+
ang = int(n) * 180 / int(d)
|
| 728 |
+
return self.add_s_angle([a, b, x, ang], deps)
|
| 729 |
+
elif name == 'rconst':
|
| 730 |
+
a, b, c, d, rat = args
|
| 731 |
+
|
| 732 |
+
if isinstance(rat, str):
|
| 733 |
+
name = rat
|
| 734 |
+
else:
|
| 735 |
+
name = rat.name
|
| 736 |
+
|
| 737 |
+
num, den = name.split('/')
|
| 738 |
+
num, den = int(num), int(den)
|
| 739 |
+
return self.add_eqrat_const([a, b, c, d, num, den], deps)
|
| 740 |
+
|
| 741 |
+
# composite pieces:
|
| 742 |
+
elif name == 'cong2':
|
| 743 |
+
return self.add_cong2(args, deps)
|
| 744 |
+
elif name == 'eqratio3':
|
| 745 |
+
return self.add_eqratio3(args, deps)
|
| 746 |
+
elif name == 'eqratio4':
|
| 747 |
+
return self.add_eqratio4(args, deps)
|
| 748 |
+
elif name == 'simtri':
|
| 749 |
+
return self.add_simtri(args, deps)
|
| 750 |
+
elif name == 'contri':
|
| 751 |
+
return self.add_contri(args, deps)
|
| 752 |
+
elif name == 'simtri2':
|
| 753 |
+
return self.add_simtri2(args, deps)
|
| 754 |
+
elif name == 'contri2':
|
| 755 |
+
return self.add_contri2(args, deps)
|
| 756 |
+
elif name == 'simtri*':
|
| 757 |
+
return self.add_simtri_check(args, deps)
|
| 758 |
+
elif name == 'contri*':
|
| 759 |
+
return self.add_contri_check(args, deps)
|
| 760 |
+
elif name in ['acompute', 'rcompute']:
|
| 761 |
+
dep = deps.populate(name, args)
|
| 762 |
+
self.cache_dep(name, args, dep)
|
| 763 |
+
return [dep]
|
| 764 |
+
elif name in ['fixl', 'fixc', 'fixb', 'fixt', 'fixp']:
|
| 765 |
+
dep = deps.populate(name, args)
|
| 766 |
+
self.cache_dep(name, args, dep)
|
| 767 |
+
return [dep]
|
| 768 |
+
elif name in ['ind']:
|
| 769 |
+
return []
|
| 770 |
+
raise ValueError(f'Not recognize {name}')
|
| 771 |
+
|
| 772 |
+
def check(self, name: str, args: list[Point]) -> bool:
|
| 773 |
+
"""Symbolically check if a predicate is True."""
|
| 774 |
+
if name == 'ncoll':
|
| 775 |
+
return self.check_ncoll(args)
|
| 776 |
+
if name == 'npara':
|
| 777 |
+
return self.check_npara(args)
|
| 778 |
+
if name == 'nperp':
|
| 779 |
+
return self.check_nperp(args)
|
| 780 |
+
if name == 'midp':
|
| 781 |
+
return self.check_midp(args)
|
| 782 |
+
if name == 'cong':
|
| 783 |
+
return self.check_cong(args)
|
| 784 |
+
if name == 'perp':
|
| 785 |
+
return self.check_perp(args)
|
| 786 |
+
if name == 'para':
|
| 787 |
+
return self.check_para(args)
|
| 788 |
+
if name == 'coll':
|
| 789 |
+
return self.check_coll(args)
|
| 790 |
+
if name == 'cyclic':
|
| 791 |
+
return self.check_cyclic(args)
|
| 792 |
+
if name == 'circle':
|
| 793 |
+
return self.check_circle(args)
|
| 794 |
+
if name == 'aconst':
|
| 795 |
+
return self.check_aconst(args)
|
| 796 |
+
if name == 'rconst':
|
| 797 |
+
return self.check_rconst(args)
|
| 798 |
+
if name == 'acompute':
|
| 799 |
+
return self.check_acompute(args)
|
| 800 |
+
if name == 'rcompute':
|
| 801 |
+
return self.check_rcompute(args)
|
| 802 |
+
if name in ['eqangle', 'eqangle6']:
|
| 803 |
+
if len(args) == 5:
|
| 804 |
+
return self.check_aconst(args)
|
| 805 |
+
return self.check_eqangle(args)
|
| 806 |
+
if name in ['eqratio', 'eqratio6']:
|
| 807 |
+
if len(args) == 5:
|
| 808 |
+
return self.check_rconst(args)
|
| 809 |
+
return self.check_eqratio(args)
|
| 810 |
+
if name in ['simtri', 'simtri2', 'simtri*']:
|
| 811 |
+
return self.check_simtri(args)
|
| 812 |
+
if name in ['contri', 'contri2', 'contri*']:
|
| 813 |
+
return self.check_contri(args)
|
| 814 |
+
if name == 'sameside':
|
| 815 |
+
return self.check_sameside(args)
|
| 816 |
+
if name in 'diff':
|
| 817 |
+
a, b = args
|
| 818 |
+
return not a.num.close(b.num)
|
| 819 |
+
if name in ['fixl', 'fixc', 'fixb', 'fixt', 'fixp']:
|
| 820 |
+
return self.in_cache(name, args)
|
| 821 |
+
if name in ['ind']:
|
| 822 |
+
return True
|
| 823 |
+
raise ValueError(f'Not recognize {name}')
|
| 824 |
+
|
| 825 |
+
def get_lines_thru_all(self, *points: list[gm.Point]) -> list[Line]:
|
| 826 |
+
line2count = defaultdict(lambda: 0)
|
| 827 |
+
points = set(points)
|
| 828 |
+
for p in points:
|
| 829 |
+
for l in p.neighbors(Line):
|
| 830 |
+
line2count[l] += 1
|
| 831 |
+
return [l for l, count in line2count.items() if count == len(points)]
|
| 832 |
+
|
| 833 |
+
def _get_line(self, a: Point, b: Point) -> Optional[Line]:
|
| 834 |
+
linesa = a.neighbors(Line)
|
| 835 |
+
for l in b.neighbors(Line):
|
| 836 |
+
if l in linesa:
|
| 837 |
+
return l
|
| 838 |
+
return None
|
| 839 |
+
|
| 840 |
+
def _get_line_all(self, a: Point, b: Point) -> Generator[Line, None, None]:
|
| 841 |
+
linesa = a.neighbors(Line, do_rep=False)
|
| 842 |
+
linesb = b.neighbors(Line, do_rep=False)
|
| 843 |
+
for l in linesb:
|
| 844 |
+
if l in linesa:
|
| 845 |
+
yield l
|
| 846 |
+
|
| 847 |
+
def _get_lines(self, *points: list[Point]) -> list[Line]:
|
| 848 |
+
"""Return all lines that connect to >= 2 points."""
|
| 849 |
+
line2count = defaultdict(lambda: 0)
|
| 850 |
+
for p in points:
|
| 851 |
+
for l in p.neighbors(Line):
|
| 852 |
+
line2count[l] += 1
|
| 853 |
+
return [l for l, count in line2count.items() if count >= 2]
|
| 854 |
+
|
| 855 |
+
def get_circle_thru_triplet(self, p1: Point, p2: Point, p3: Point) -> Circle:
|
| 856 |
+
p1, p2, p3 = sorted([p1, p2, p3], key=lambda x: x.name)
|
| 857 |
+
if (p1, p2, p3) in self._triplet2circle:
|
| 858 |
+
return self._triplet2circle[(p1, p2, p3)]
|
| 859 |
+
return self.get_new_circle_thru_triplet(p1, p2, p3)
|
| 860 |
+
|
| 861 |
+
def get_new_circle_thru_triplet(
|
| 862 |
+
self, p1: Point, p2: Point, p3: Point
|
| 863 |
+
) -> Circle:
|
| 864 |
+
"""Get a new Circle that goes thru three given Points."""
|
| 865 |
+
p1, p2, p3 = sorted([p1, p2, p3], key=lambda x: x.name)
|
| 866 |
+
name = p1.name.lower() + p2.name.lower() + p3.name.lower()
|
| 867 |
+
circle = self.new_node(Circle, f'({name})')
|
| 868 |
+
circle.num = nm.Circle(p1=p1.num, p2=p2.num, p3=p3.num)
|
| 869 |
+
circle.points = p1, p2, p3
|
| 870 |
+
|
| 871 |
+
self.connect(p1, circle, deps=None)
|
| 872 |
+
self.connect(p2, circle, deps=None)
|
| 873 |
+
self.connect(p3, circle, deps=None)
|
| 874 |
+
self._triplet2circle[(p1, p2, p3)] = circle
|
| 875 |
+
return circle
|
| 876 |
+
|
| 877 |
+
def get_line_thru_pair(self, p1: Point, p2: Point) -> Line:
|
| 878 |
+
if (p1, p2) in self._pair2line:
|
| 879 |
+
return self._pair2line[(p1, p2)]
|
| 880 |
+
if (p2, p1) in self._pair2line:
|
| 881 |
+
return self._pair2line[(p2, p1)]
|
| 882 |
+
return self.get_new_line_thru_pair(p1, p2)
|
| 883 |
+
|
| 884 |
+
def get_new_line_thru_pair(self, p1: Point, p2: Point) -> Line:
|
| 885 |
+
if p1.name.lower() > p2.name.lower():
|
| 886 |
+
p1, p2 = p2, p1
|
| 887 |
+
name = p1.name.lower() + p2.name.lower()
|
| 888 |
+
line = self.new_node(Line, name)
|
| 889 |
+
line.num = nm.Line(p1.num, p2.num)
|
| 890 |
+
line.points = p1, p2
|
| 891 |
+
|
| 892 |
+
self.connect(p1, line, deps=None)
|
| 893 |
+
self.connect(p2, line, deps=None)
|
| 894 |
+
self._pair2line[(p1, p2)] = line
|
| 895 |
+
return line
|
| 896 |
+
|
| 897 |
+
def get_line_thru_pair_why(
|
| 898 |
+
self, p1: Point, p2: Point
|
| 899 |
+
) -> tuple[Line, list[Dependency]]:
|
| 900 |
+
"""Get one line thru two given points and the corresponding dependency list."""
|
| 901 |
+
if p1.name.lower() > p2.name.lower():
|
| 902 |
+
p1, p2 = p2, p1
|
| 903 |
+
if (p1, p2) in self._pair2line:
|
| 904 |
+
return self._pair2line[(p1, p2)].rep_and_why()
|
| 905 |
+
|
| 906 |
+
l, why = gm.line_of_and_why([p1, p2])
|
| 907 |
+
if l is None:
|
| 908 |
+
l = self.get_new_line_thru_pair(p1, p2)
|
| 909 |
+
why = []
|
| 910 |
+
return l, why
|
| 911 |
+
|
| 912 |
+
def coll_dep(self, points: list[Point], p: Point) -> list[Dependency]:
|
| 913 |
+
"""Return the dep(.why) explaining why p is coll with points."""
|
| 914 |
+
for p1, p2 in utils.comb2(points):
|
| 915 |
+
if self.check_coll([p1, p2, p]):
|
| 916 |
+
dep = Dependency('coll', [p1, p2, p], None, None)
|
| 917 |
+
return dep.why_me_or_cache(self, None)
|
| 918 |
+
|
| 919 |
+
def add_coll(
|
| 920 |
+
self, points: list[Point], deps: EmptyDependency
|
| 921 |
+
) -> list[Dependency]:
|
| 922 |
+
"""Add a predicate that `points` are collinear."""
|
| 923 |
+
points = list(set(points))
|
| 924 |
+
og_points = list(points)
|
| 925 |
+
|
| 926 |
+
all_lines = []
|
| 927 |
+
for p1, p2 in utils.comb2(points):
|
| 928 |
+
all_lines.append(self.get_line_thru_pair(p1, p2))
|
| 929 |
+
points = sum([l.neighbors(Point) for l in all_lines], [])
|
| 930 |
+
points = list(set(points))
|
| 931 |
+
|
| 932 |
+
existed = set()
|
| 933 |
+
new = set()
|
| 934 |
+
for p1, p2 in utils.comb2(points):
|
| 935 |
+
if p1.name > p2.name:
|
| 936 |
+
p1, p2 = p2, p1
|
| 937 |
+
if (p1, p2) in self._pair2line:
|
| 938 |
+
line = self._pair2line[(p1, p2)]
|
| 939 |
+
existed.add(line)
|
| 940 |
+
else:
|
| 941 |
+
line = self.get_new_line_thru_pair(p1, p2)
|
| 942 |
+
new.add(line)
|
| 943 |
+
|
| 944 |
+
existed = sorted(existed, key=lambda l: l.name)
|
| 945 |
+
new = sorted(new, key=lambda l: l.name)
|
| 946 |
+
|
| 947 |
+
existed, new = list(existed), list(new)
|
| 948 |
+
if not existed:
|
| 949 |
+
line0, *lines = new
|
| 950 |
+
else:
|
| 951 |
+
line0, lines = existed[0], existed[1:] + new
|
| 952 |
+
|
| 953 |
+
add = []
|
| 954 |
+
line0, why0 = line0.rep_and_why()
|
| 955 |
+
a, b = line0.points
|
| 956 |
+
for line in lines:
|
| 957 |
+
c, d = line.points
|
| 958 |
+
args = list({a, b, c, d})
|
| 959 |
+
if len(args) < 3:
|
| 960 |
+
continue
|
| 961 |
+
|
| 962 |
+
whys = []
|
| 963 |
+
for x in args:
|
| 964 |
+
if x not in og_points:
|
| 965 |
+
whys.append(self.coll_dep(og_points, x))
|
| 966 |
+
|
| 967 |
+
abcd_deps = deps
|
| 968 |
+
if whys + why0:
|
| 969 |
+
dep0 = deps.populate('coll', og_points)
|
| 970 |
+
abcd_deps = EmptyDependency(level=deps.level, rule_name=None)
|
| 971 |
+
abcd_deps.why = [dep0] + whys
|
| 972 |
+
|
| 973 |
+
is_coll = self.check_coll(args)
|
| 974 |
+
dep = abcd_deps.populate('coll', args)
|
| 975 |
+
self.cache_dep('coll', args, dep)
|
| 976 |
+
self.merge_into(line0, [line], dep)
|
| 977 |
+
|
| 978 |
+
if not is_coll:
|
| 979 |
+
add += [dep]
|
| 980 |
+
|
| 981 |
+
return add
|
| 982 |
+
|
| 983 |
+
def check_coll(self, points: list[Point]) -> bool:
|
| 984 |
+
points = list(set(points))
|
| 985 |
+
if len(points) < 3:
|
| 986 |
+
return True
|
| 987 |
+
line2count = defaultdict(lambda: 0)
|
| 988 |
+
for p in points:
|
| 989 |
+
for l in p.neighbors(Line):
|
| 990 |
+
line2count[l] += 1
|
| 991 |
+
return any([count == len(points) for _, count in line2count.items()])
|
| 992 |
+
|
| 993 |
+
def why_coll(self, args: tuple[Line, list[Point]]) -> list[Dependency]:
|
| 994 |
+
line, points = args
|
| 995 |
+
return line.why_coll(points)
|
| 996 |
+
|
| 997 |
+
def check_ncoll(self, points: list[Point]) -> bool:
|
| 998 |
+
if self.check_coll(points):
|
| 999 |
+
return False
|
| 1000 |
+
return not nm.check_coll([p.num for p in points])
|
| 1001 |
+
|
| 1002 |
+
def check_sameside(self, points: list[Point]) -> bool:
|
| 1003 |
+
return nm.check_sameside([p.num for p in points])
|
| 1004 |
+
|
| 1005 |
+
def make_equal(self, x: gm.Node, y: gm.Node, deps: Dependency) -> None:
|
| 1006 |
+
"""Make that two nodes x and y are equal, i.e. merge their value node."""
|
| 1007 |
+
if x.val is None:
|
| 1008 |
+
x, y = y, x
|
| 1009 |
+
|
| 1010 |
+
self.connect_val(x, deps=None)
|
| 1011 |
+
self.connect_val(y, deps=None)
|
| 1012 |
+
vx = x._val
|
| 1013 |
+
vy = y._val
|
| 1014 |
+
|
| 1015 |
+
if vx == vy:
|
| 1016 |
+
return
|
| 1017 |
+
|
| 1018 |
+
merges = [vx, vy]
|
| 1019 |
+
|
| 1020 |
+
if (
|
| 1021 |
+
isinstance(x, Angle)
|
| 1022 |
+
and x not in self.aconst.values()
|
| 1023 |
+
and y not in self.aconst.values()
|
| 1024 |
+
and x.directions == y.directions[::-1]
|
| 1025 |
+
and x.directions[0] != x.directions[1]
|
| 1026 |
+
):
|
| 1027 |
+
merges = [self.vhalfpi, vx, vy]
|
| 1028 |
+
|
| 1029 |
+
self.merge(merges, deps)
|
| 1030 |
+
|
| 1031 |
+
def merge_vals(self, vx: gm.Node, vy: gm.Node, deps: Dependency) -> None:
|
| 1032 |
+
if vx == vy:
|
| 1033 |
+
return
|
| 1034 |
+
merges = [vx, vy]
|
| 1035 |
+
self.merge(merges, deps)
|
| 1036 |
+
|
| 1037 |
+
def why_equal(self, x: gm.Node, y: gm.Node, level: int) -> list[Dependency]:
|
| 1038 |
+
return gm.why_equal(x, y, level)
|
| 1039 |
+
|
| 1040 |
+
def _why_coll4(
|
| 1041 |
+
self,
|
| 1042 |
+
a: Point,
|
| 1043 |
+
b: Point,
|
| 1044 |
+
ab: Line,
|
| 1045 |
+
c: Point,
|
| 1046 |
+
d: Point,
|
| 1047 |
+
cd: Line,
|
| 1048 |
+
level: int,
|
| 1049 |
+
) -> list[Dependency]:
|
| 1050 |
+
return self._why_coll2(a, b, ab, level) + self._why_coll2(c, d, cd, level)
|
| 1051 |
+
|
| 1052 |
+
def _why_coll8(
|
| 1053 |
+
self,
|
| 1054 |
+
a: Point,
|
| 1055 |
+
b: Point,
|
| 1056 |
+
ab: Line,
|
| 1057 |
+
c: Point,
|
| 1058 |
+
d: Point,
|
| 1059 |
+
cd: Line,
|
| 1060 |
+
m: Point,
|
| 1061 |
+
n: Point,
|
| 1062 |
+
mn: Line,
|
| 1063 |
+
p: Point,
|
| 1064 |
+
q: Point,
|
| 1065 |
+
pq: Line,
|
| 1066 |
+
level: int,
|
| 1067 |
+
) -> list[Dependency]:
|
| 1068 |
+
"""Dependency list of why 8 points are collinear."""
|
| 1069 |
+
why8 = self._why_coll4(a, b, ab, c, d, cd, level)
|
| 1070 |
+
why8 += self._why_coll4(m, n, mn, p, q, pq, level)
|
| 1071 |
+
return why8
|
| 1072 |
+
|
| 1073 |
+
def add_para(
|
| 1074 |
+
self, points: list[Point], deps: EmptyDependency
|
| 1075 |
+
) -> list[Dependency]:
|
| 1076 |
+
"""Add a new predicate that 4 points (2 lines) are parallel."""
|
| 1077 |
+
a, b, c, d = points
|
| 1078 |
+
ab, why1 = self.get_line_thru_pair_why(a, b)
|
| 1079 |
+
cd, why2 = self.get_line_thru_pair_why(c, d)
|
| 1080 |
+
|
| 1081 |
+
is_equal = self.is_equal(ab, cd)
|
| 1082 |
+
|
| 1083 |
+
(a, b), (c, d) = ab.points, cd.points
|
| 1084 |
+
|
| 1085 |
+
dep0 = deps.populate('para', points)
|
| 1086 |
+
deps = EmptyDependency(level=deps.level, rule_name=None)
|
| 1087 |
+
|
| 1088 |
+
deps = deps.populate('para', [a, b, c, d])
|
| 1089 |
+
deps.why = [dep0] + why1 + why2
|
| 1090 |
+
|
| 1091 |
+
self.make_equal(ab, cd, deps)
|
| 1092 |
+
deps.algebra = ab._val, cd._val
|
| 1093 |
+
|
| 1094 |
+
self.cache_dep('para', [a, b, c, d], deps)
|
| 1095 |
+
if not is_equal:
|
| 1096 |
+
return [deps]
|
| 1097 |
+
return []
|
| 1098 |
+
|
| 1099 |
+
def why_para(self, args: list[Point]) -> list[Dependency]:
|
| 1100 |
+
ab, cd, lvl = args
|
| 1101 |
+
return self.why_equal(ab, cd, lvl)
|
| 1102 |
+
|
| 1103 |
+
def check_para_or_coll(self, points: list[Point]) -> bool:
|
| 1104 |
+
return self.check_para(points) or self.check_coll(points)
|
| 1105 |
+
|
| 1106 |
+
def check_para(self, points: list[Point]) -> bool:
|
| 1107 |
+
a, b, c, d = points
|
| 1108 |
+
if (a == b) or (c == d):
|
| 1109 |
+
return False
|
| 1110 |
+
ab = self._get_line(a, b)
|
| 1111 |
+
cd = self._get_line(c, d)
|
| 1112 |
+
if not ab or not cd:
|
| 1113 |
+
return False
|
| 1114 |
+
|
| 1115 |
+
return self.is_equal(ab, cd)
|
| 1116 |
+
|
| 1117 |
+
def check_npara(self, points: list[Point]) -> bool:
|
| 1118 |
+
if self.check_para(points):
|
| 1119 |
+
return False
|
| 1120 |
+
return not nm.check_para([p.num for p in points])
|
| 1121 |
+
|
| 1122 |
+
def _get_angle(
|
| 1123 |
+
self, d1: Direction, d2: Direction
|
| 1124 |
+
) -> tuple[Angle, Optional[Angle]]:
|
| 1125 |
+
for a in self.type2nodes[Angle]:
|
| 1126 |
+
if a.directions == (d1, d2):
|
| 1127 |
+
return a, a.opposite
|
| 1128 |
+
return None, None
|
| 1129 |
+
|
| 1130 |
+
def get_first_angle(
|
| 1131 |
+
self, l1: Line, l2: Line
|
| 1132 |
+
) -> tuple[Angle, list[Dependency]]:
|
| 1133 |
+
"""Get a first angle between line l1 and line l2."""
|
| 1134 |
+
d1, d2 = l1._val, l2._val
|
| 1135 |
+
|
| 1136 |
+
d1s = d1.all_reps()
|
| 1137 |
+
d2s = d2.all_reps()
|
| 1138 |
+
|
| 1139 |
+
found = d1.first_angle(d2s)
|
| 1140 |
+
if found is None:
|
| 1141 |
+
found = d2.first_angle(d1s)
|
| 1142 |
+
if found is None:
|
| 1143 |
+
return None, []
|
| 1144 |
+
ang, x2, x1 = found
|
| 1145 |
+
found = ang.opposite, x1, x2
|
| 1146 |
+
|
| 1147 |
+
ang, x1, x2 = found
|
| 1148 |
+
return ang, d1.deps_upto(x1) + d2.deps_upto(x2)
|
| 1149 |
+
|
| 1150 |
+
def _get_or_create_angle(
|
| 1151 |
+
self, l1: Line, l2: Line, deps: Dependency
|
| 1152 |
+
) -> tuple[Angle, Angle, list[Dependency]]:
|
| 1153 |
+
return self.get_or_create_angle_d(l1._val, l2._val, deps)
|
| 1154 |
+
|
| 1155 |
+
def get_or_create_angle_d(
|
| 1156 |
+
self, d1: Direction, d2: Direction, deps: Dependency
|
| 1157 |
+
) -> tuple[Angle, Angle, list[Dependency]]:
|
| 1158 |
+
"""Get or create an angle between two Direction d1 and d2."""
|
| 1159 |
+
for a in self.type2nodes[Angle]:
|
| 1160 |
+
if a.directions == (d1.rep(), d2.rep()): # directions = _d.rep()
|
| 1161 |
+
d1_, d2_ = a._d
|
| 1162 |
+
why1 = d1.why_equal([d1_], None) + d1_.why_rep()
|
| 1163 |
+
why2 = d2.why_equal([d2_], None) + d2_.why_rep()
|
| 1164 |
+
return a, a.opposite, why1 + why2
|
| 1165 |
+
|
| 1166 |
+
d1, why1 = d1.rep_and_why()
|
| 1167 |
+
d2, why2 = d2.rep_and_why()
|
| 1168 |
+
a12 = self.new_node(Angle, f'{d1.name}-{d2.name}')
|
| 1169 |
+
a21 = self.new_node(Angle, f'{d2.name}-{d1.name}')
|
| 1170 |
+
self.connect(d1, a12, deps)
|
| 1171 |
+
self.connect(d2, a21, deps)
|
| 1172 |
+
self.connect(a12, a21, deps)
|
| 1173 |
+
a12.set_directions(d1, d2)
|
| 1174 |
+
a21.set_directions(d2, d1)
|
| 1175 |
+
a12.opposite = a21
|
| 1176 |
+
a21.opposite = a12
|
| 1177 |
+
return a12, a21, why1 + why2
|
| 1178 |
+
|
| 1179 |
+
def _add_para_or_coll(
|
| 1180 |
+
self,
|
| 1181 |
+
a: Point,
|
| 1182 |
+
b: Point,
|
| 1183 |
+
c: Point,
|
| 1184 |
+
d: Point,
|
| 1185 |
+
x: Point,
|
| 1186 |
+
y: Point,
|
| 1187 |
+
m: Point,
|
| 1188 |
+
n: Point,
|
| 1189 |
+
deps: EmptyDependency,
|
| 1190 |
+
) -> list[Dependency]:
|
| 1191 |
+
"""Add a new parallel or collinear predicate."""
|
| 1192 |
+
extends = [('perp', [x, y, m, n])]
|
| 1193 |
+
if {a, b} == {x, y}:
|
| 1194 |
+
pass
|
| 1195 |
+
elif self.check_para([a, b, x, y]):
|
| 1196 |
+
extends.append(('para', [a, b, x, y]))
|
| 1197 |
+
elif self.check_coll([a, b, x, y]):
|
| 1198 |
+
extends.append(('coll', set(list([a, b, x, y]))))
|
| 1199 |
+
else:
|
| 1200 |
+
return None
|
| 1201 |
+
|
| 1202 |
+
if m in [c, d] or n in [c, d] or c in [m, n] or d in [m, n]:
|
| 1203 |
+
pass
|
| 1204 |
+
elif self.check_coll([c, d, m]):
|
| 1205 |
+
extends.append(('coll', [c, d, m]))
|
| 1206 |
+
elif self.check_coll([c, d, n]):
|
| 1207 |
+
extends.append(('coll', [c, d, n]))
|
| 1208 |
+
elif self.check_coll([c, m, n]):
|
| 1209 |
+
extends.append(('coll', [c, m, n]))
|
| 1210 |
+
elif self.check_coll([d, m, n]):
|
| 1211 |
+
extends.append(('coll', [d, m, n]))
|
| 1212 |
+
else:
|
| 1213 |
+
deps = deps.extend_many(self, 'perp', [a, b, c, d], extends)
|
| 1214 |
+
return self.add_para([c, d, m, n], deps)
|
| 1215 |
+
|
| 1216 |
+
deps = deps.extend_many(self, 'perp', [a, b, c, d], extends)
|
| 1217 |
+
return self.add_coll(list(set([c, d, m, n])), deps)
|
| 1218 |
+
|
| 1219 |
+
def maybe_make_para_from_perp(
|
| 1220 |
+
self, points: list[Point], deps: EmptyDependency
|
| 1221 |
+
) -> Optional[list[Dependency]]:
|
| 1222 |
+
"""Maybe add a new parallel predicate from perp predicate."""
|
| 1223 |
+
a, b, c, d = points
|
| 1224 |
+
halfpi = self.aconst[(1, 2)]
|
| 1225 |
+
for ang in halfpi.val.neighbors(Angle):
|
| 1226 |
+
if ang == halfpi:
|
| 1227 |
+
continue
|
| 1228 |
+
d1, d2 = ang.directions
|
| 1229 |
+
x, y = d1._obj.points
|
| 1230 |
+
m, n = d2._obj.points
|
| 1231 |
+
|
| 1232 |
+
for args in [
|
| 1233 |
+
(a, b, c, d, x, y, m, n),
|
| 1234 |
+
(a, b, c, d, m, n, x, y),
|
| 1235 |
+
(c, d, a, b, x, y, m, n),
|
| 1236 |
+
(c, d, a, b, m, n, x, y),
|
| 1237 |
+
]:
|
| 1238 |
+
args = args + (deps,)
|
| 1239 |
+
add = self._add_para_or_coll(*args)
|
| 1240 |
+
if add:
|
| 1241 |
+
return add
|
| 1242 |
+
|
| 1243 |
+
return None
|
| 1244 |
+
|
| 1245 |
+
def add_perp(
|
| 1246 |
+
self, points: list[Point], deps: EmptyDependency
|
| 1247 |
+
) -> list[Dependency]:
|
| 1248 |
+
"""Add a new perpendicular predicate from 4 points (2 lines)."""
|
| 1249 |
+
add = self.maybe_make_para_from_perp(points, deps)
|
| 1250 |
+
if add is not None:
|
| 1251 |
+
return add
|
| 1252 |
+
|
| 1253 |
+
a, b, c, d = points
|
| 1254 |
+
ab, why1 = self.get_line_thru_pair_why(a, b)
|
| 1255 |
+
cd, why2 = self.get_line_thru_pair_why(c, d)
|
| 1256 |
+
|
| 1257 |
+
(a, b), (c, d) = ab.points, cd.points
|
| 1258 |
+
|
| 1259 |
+
if why1 + why2:
|
| 1260 |
+
dep0 = deps.populate('perp', points)
|
| 1261 |
+
deps = EmptyDependency(level=deps.level, rule_name=None)
|
| 1262 |
+
deps.why = [dep0] + why1 + why2
|
| 1263 |
+
|
| 1264 |
+
self.connect_val(ab, deps=None)
|
| 1265 |
+
self.connect_val(cd, deps=None)
|
| 1266 |
+
|
| 1267 |
+
if ab.val == cd.val:
|
| 1268 |
+
raise ValueError(f'{ab.name} and {cd.name} Cannot be perp.')
|
| 1269 |
+
|
| 1270 |
+
args = [a, b, c, d]
|
| 1271 |
+
i = 0
|
| 1272 |
+
for x, y, xy in [(a, b, ab), (c, d, cd)]:
|
| 1273 |
+
i += 1
|
| 1274 |
+
x_, y_ = xy._val._obj.points
|
| 1275 |
+
if {x, y} == {x_, y_}:
|
| 1276 |
+
continue
|
| 1277 |
+
if deps:
|
| 1278 |
+
deps = deps.extend(self, 'perp', list(args), 'para', [x, y, x_, y_])
|
| 1279 |
+
args[2 * i - 2] = x_
|
| 1280 |
+
args[2 * i - 1] = y_
|
| 1281 |
+
|
| 1282 |
+
a12, a21, why = self._get_or_create_angle(ab, cd, deps=None)
|
| 1283 |
+
|
| 1284 |
+
if why:
|
| 1285 |
+
dep0 = deps.populate('perp', [a, b, c, d])
|
| 1286 |
+
deps = EmptyDependency(level=deps.level, rule_name=None)
|
| 1287 |
+
deps.why = [dep0] + why
|
| 1288 |
+
|
| 1289 |
+
dab, dcd = a12._d
|
| 1290 |
+
a, b = dab._obj.points
|
| 1291 |
+
c, d = dcd._obj.points
|
| 1292 |
+
|
| 1293 |
+
is_equal = self.is_equal(a12, a21)
|
| 1294 |
+
deps = deps.populate('perp', [a, b, c, d])
|
| 1295 |
+
deps.algebra = [dab, dcd]
|
| 1296 |
+
self.make_equal(a12, a21, deps=deps)
|
| 1297 |
+
|
| 1298 |
+
self.cache_dep('perp', [a, b, c, d], deps)
|
| 1299 |
+
self.cache_dep('eqangle', [a, b, c, d, c, d, a, b], deps)
|
| 1300 |
+
|
| 1301 |
+
if not is_equal:
|
| 1302 |
+
return [deps]
|
| 1303 |
+
return []
|
| 1304 |
+
|
| 1305 |
+
def why_perp(
|
| 1306 |
+
self, args: list[Union[Point, list[Dependency]]]
|
| 1307 |
+
) -> list[Dependency]:
|
| 1308 |
+
a, b, deps = args
|
| 1309 |
+
return deps + self.why_equal(a, b, None)
|
| 1310 |
+
|
| 1311 |
+
def check_perpl(self, ab: Line, cd: Line) -> bool:
|
| 1312 |
+
if ab.val is None or cd.val is None:
|
| 1313 |
+
return False
|
| 1314 |
+
if ab.val == cd.val:
|
| 1315 |
+
return False
|
| 1316 |
+
a12, a21 = self._get_angle(ab.val, cd.val)
|
| 1317 |
+
if a12 is None or a21 is None:
|
| 1318 |
+
return False
|
| 1319 |
+
return self.is_equal(a12, a21)
|
| 1320 |
+
|
| 1321 |
+
def check_perp(self, points: list[Point]) -> bool:
|
| 1322 |
+
a, b, c, d = points
|
| 1323 |
+
ab = self._get_line(a, b)
|
| 1324 |
+
cd = self._get_line(c, d)
|
| 1325 |
+
if not ab or not cd:
|
| 1326 |
+
return False
|
| 1327 |
+
return self.check_perpl(ab, cd)
|
| 1328 |
+
|
| 1329 |
+
def check_nperp(self, points: list[Point]) -> bool:
|
| 1330 |
+
if self.check_perp(points):
|
| 1331 |
+
return False
|
| 1332 |
+
return not nm.check_perp([p.num for p in points])
|
| 1333 |
+
|
| 1334 |
+
def _get_segment(self, p1: Point, p2: Point) -> Optional[Segment]:
|
| 1335 |
+
for s in self.type2nodes[Segment]:
|
| 1336 |
+
if s.points == {p1, p2}:
|
| 1337 |
+
return s
|
| 1338 |
+
return None
|
| 1339 |
+
|
| 1340 |
+
def _get_or_create_segment(
|
| 1341 |
+
self, p1: Point, p2: Point, deps: Dependency
|
| 1342 |
+
) -> Segment:
|
| 1343 |
+
"""Get or create a Segment object between two Points p1 and p2."""
|
| 1344 |
+
if p1 == p2:
|
| 1345 |
+
raise ValueError(f'Creating same 0-length segment {p1.name}')
|
| 1346 |
+
|
| 1347 |
+
for s in self.type2nodes[Segment]:
|
| 1348 |
+
if s.points == {p1, p2}:
|
| 1349 |
+
return s
|
| 1350 |
+
|
| 1351 |
+
if p1.name > p2.name:
|
| 1352 |
+
p1, p2 = p2, p1
|
| 1353 |
+
s = self.new_node(Segment, name=f'{p1.name.upper()}{p2.name.upper()}')
|
| 1354 |
+
self.connect(p1, s, deps=deps)
|
| 1355 |
+
self.connect(p2, s, deps=deps)
|
| 1356 |
+
s.points = {p1, p2}
|
| 1357 |
+
return s
|
| 1358 |
+
|
| 1359 |
+
def add_cong(
|
| 1360 |
+
self, points: list[Point], deps: EmptyDependency
|
| 1361 |
+
) -> list[Dependency]:
|
| 1362 |
+
"""Add that two segments (4 points) are congruent."""
|
| 1363 |
+
a, b, c, d = points
|
| 1364 |
+
ab = self._get_or_create_segment(a, b, deps=None)
|
| 1365 |
+
cd = self._get_or_create_segment(c, d, deps=None)
|
| 1366 |
+
|
| 1367 |
+
is_equal = self.is_equal(ab, cd)
|
| 1368 |
+
|
| 1369 |
+
dep = deps.populate('cong', [a, b, c, d])
|
| 1370 |
+
self.make_equal(ab, cd, deps=dep)
|
| 1371 |
+
dep.algebra = ab._val, cd._val
|
| 1372 |
+
|
| 1373 |
+
self.cache_dep('cong', [a, b, c, d], dep)
|
| 1374 |
+
|
| 1375 |
+
result = []
|
| 1376 |
+
|
| 1377 |
+
if not is_equal:
|
| 1378 |
+
result += [dep]
|
| 1379 |
+
|
| 1380 |
+
if a not in [c, d] and b not in [c, d]:
|
| 1381 |
+
return result
|
| 1382 |
+
|
| 1383 |
+
if b in [c, d]:
|
| 1384 |
+
a, b = b, a
|
| 1385 |
+
if a == d:
|
| 1386 |
+
c, d = d, c # pylint: disable=unused-variable
|
| 1387 |
+
|
| 1388 |
+
result += self._maybe_add_cyclic_from_cong(a, b, d, dep)
|
| 1389 |
+
return result
|
| 1390 |
+
|
| 1391 |
+
def _maybe_add_cyclic_from_cong(
|
| 1392 |
+
self, a: Point, b: Point, c: Point, cong_ab_ac: Dependency
|
| 1393 |
+
) -> list[Dependency]:
|
| 1394 |
+
"""Maybe add a new cyclic predicate from given congruent segments."""
|
| 1395 |
+
ab = self._get_or_create_segment(a, b, deps=None)
|
| 1396 |
+
|
| 1397 |
+
# all eq segs with one end being a.
|
| 1398 |
+
segs = [s for s in ab.val.neighbors(Segment) if a in s.points]
|
| 1399 |
+
|
| 1400 |
+
# all points on circle (a, b)
|
| 1401 |
+
points = []
|
| 1402 |
+
for s in segs:
|
| 1403 |
+
x, y = list(s.points)
|
| 1404 |
+
points.append(x if y == a else y)
|
| 1405 |
+
|
| 1406 |
+
# for sure both b and c are in points
|
| 1407 |
+
points = [p for p in points if p not in [b, c]]
|
| 1408 |
+
|
| 1409 |
+
if len(points) < 2:
|
| 1410 |
+
return []
|
| 1411 |
+
|
| 1412 |
+
x, y = points[:2]
|
| 1413 |
+
|
| 1414 |
+
if self.check_cyclic([b, c, x, y]):
|
| 1415 |
+
return []
|
| 1416 |
+
|
| 1417 |
+
ax = self._get_or_create_segment(a, x, deps=None)
|
| 1418 |
+
ay = self._get_or_create_segment(a, y, deps=None)
|
| 1419 |
+
why = ab._val.why_equal([ax._val, ay._val], level=None)
|
| 1420 |
+
why += [cong_ab_ac]
|
| 1421 |
+
|
| 1422 |
+
deps = EmptyDependency(cong_ab_ac.level, '')
|
| 1423 |
+
deps.why = why
|
| 1424 |
+
|
| 1425 |
+
return self.add_cyclic([b, c, x, y], deps)
|
| 1426 |
+
|
| 1427 |
+
def check_cong(self, points: list[Point]) -> bool:
|
| 1428 |
+
a, b, c, d = points
|
| 1429 |
+
if {a, b} == {c, d}:
|
| 1430 |
+
return True
|
| 1431 |
+
|
| 1432 |
+
ab = self._get_segment(a, b)
|
| 1433 |
+
cd = self._get_segment(c, d)
|
| 1434 |
+
if ab is None or cd is None:
|
| 1435 |
+
return False
|
| 1436 |
+
return self.is_equal(ab, cd)
|
| 1437 |
+
|
| 1438 |
+
def why_cong(self, args: tuple[Segment, Segment]) -> list[Dependency]:
|
| 1439 |
+
ab, cd = args
|
| 1440 |
+
return self.why_equal(ab, cd, None)
|
| 1441 |
+
|
| 1442 |
+
def add_midp(
|
| 1443 |
+
self, points: list[Point], deps: EmptyDependency
|
| 1444 |
+
) -> list[Dependency]:
|
| 1445 |
+
m, a, b = points
|
| 1446 |
+
add = self.add_coll(points, deps=deps)
|
| 1447 |
+
add += self.add_cong([m, a, m, b], deps)
|
| 1448 |
+
return add
|
| 1449 |
+
|
| 1450 |
+
def why_midp(
|
| 1451 |
+
self, args: tuple[Line, list[Point], Segment, Segment]
|
| 1452 |
+
) -> list[Dependency]:
|
| 1453 |
+
line, points, ma, mb = args
|
| 1454 |
+
return self.why_coll([line, points]) + self.why_cong([ma, mb])
|
| 1455 |
+
|
| 1456 |
+
def check_midp(self, points: list[Point]) -> bool:
|
| 1457 |
+
if not self.check_coll(points):
|
| 1458 |
+
return False
|
| 1459 |
+
m, a, b = points
|
| 1460 |
+
return self.check_cong([m, a, m, b])
|
| 1461 |
+
|
| 1462 |
+
def add_circle(
|
| 1463 |
+
self, points: list[Point], deps: EmptyDependency
|
| 1464 |
+
) -> list[Dependency]:
|
| 1465 |
+
o, a, b, c = points
|
| 1466 |
+
add = self.add_cong([o, a, o, b], deps=deps)
|
| 1467 |
+
add += self.add_cong([o, a, o, c], deps=deps)
|
| 1468 |
+
return add
|
| 1469 |
+
|
| 1470 |
+
def why_circle(
|
| 1471 |
+
self, args: tuple[Segment, Segment, Segment]
|
| 1472 |
+
) -> list[Dependency]:
|
| 1473 |
+
oa, ob, oc = args
|
| 1474 |
+
return self.why_equal(oa, ob, None) and self.why_equal(oa, oc, None)
|
| 1475 |
+
|
| 1476 |
+
def check_circle(self, points: list[Point]) -> bool:
|
| 1477 |
+
o, a, b, c = points
|
| 1478 |
+
return self.check_cong([o, a, o, b]) and self.check_cong([o, a, o, c])
|
| 1479 |
+
|
| 1480 |
+
def get_circles_thru_all(self, *points: list[Point]) -> list[Circle]:
|
| 1481 |
+
circle2count = defaultdict(lambda: 0)
|
| 1482 |
+
points = set(points)
|
| 1483 |
+
for p in points:
|
| 1484 |
+
for c in p.neighbors(Circle):
|
| 1485 |
+
circle2count[c] += 1
|
| 1486 |
+
return [c for c, count in circle2count.items() if count == len(points)]
|
| 1487 |
+
|
| 1488 |
+
def _get_circles(self, *points: list[Point]) -> list[Circle]:
|
| 1489 |
+
circle2count = defaultdict(lambda: 0)
|
| 1490 |
+
for p in points:
|
| 1491 |
+
for c in p.neighbors(Circle):
|
| 1492 |
+
circle2count[c] += 1
|
| 1493 |
+
return [c for c, count in circle2count.items() if count >= 3]
|
| 1494 |
+
|
| 1495 |
+
def cyclic_dep(self, points: list[Point], p: Point) -> list[Dependency]:
|
| 1496 |
+
for p1, p2, p3 in utils.comb3(points):
|
| 1497 |
+
if self.check_cyclic([p1, p2, p3, p]):
|
| 1498 |
+
dep = Dependency('cyclic', [p1, p2, p3, p], None, None)
|
| 1499 |
+
return dep.why_me_or_cache(self, None)
|
| 1500 |
+
|
| 1501 |
+
def add_cyclic(
|
| 1502 |
+
self, points: list[Point], deps: EmptyDependency
|
| 1503 |
+
) -> list[Dependency]:
|
| 1504 |
+
"""Add a new cyclic predicate that 4 points are concyclic."""
|
| 1505 |
+
points = list(set(points))
|
| 1506 |
+
og_points = list(points)
|
| 1507 |
+
|
| 1508 |
+
all_circles = []
|
| 1509 |
+
for p1, p2, p3 in utils.comb3(points):
|
| 1510 |
+
all_circles.append(self.get_circle_thru_triplet(p1, p2, p3))
|
| 1511 |
+
points = sum([c.neighbors(Point) for c in all_circles], [])
|
| 1512 |
+
points = list(set(points))
|
| 1513 |
+
|
| 1514 |
+
existed = set()
|
| 1515 |
+
new = set()
|
| 1516 |
+
for p1, p2, p3 in utils.comb3(points):
|
| 1517 |
+
p1, p2, p3 = sorted([p1, p2, p3], key=lambda x: x.name)
|
| 1518 |
+
|
| 1519 |
+
if (p1, p2, p3) in self._triplet2circle:
|
| 1520 |
+
circle = self._triplet2circle[(p1, p2, p3)]
|
| 1521 |
+
existed.add(circle)
|
| 1522 |
+
else:
|
| 1523 |
+
circle = self.get_new_circle_thru_triplet(p1, p2, p3)
|
| 1524 |
+
new.add(circle)
|
| 1525 |
+
|
| 1526 |
+
existed = sorted(existed, key=lambda l: l.name)
|
| 1527 |
+
new = sorted(new, key=lambda l: l.name)
|
| 1528 |
+
|
| 1529 |
+
existed, new = list(existed), list(new)
|
| 1530 |
+
if not existed:
|
| 1531 |
+
circle0, *circles = new
|
| 1532 |
+
else:
|
| 1533 |
+
circle0, circles = existed[0], existed[1:] + new
|
| 1534 |
+
|
| 1535 |
+
add = []
|
| 1536 |
+
circle0, why0 = circle0.rep_and_why()
|
| 1537 |
+
a, b, c = circle0.points
|
| 1538 |
+
for circle in circles:
|
| 1539 |
+
d, e, f = circle.points
|
| 1540 |
+
args = list({a, b, c, d, e, f})
|
| 1541 |
+
if len(args) < 4:
|
| 1542 |
+
continue
|
| 1543 |
+
whys = []
|
| 1544 |
+
for x in [a, b, c, d, e, f]:
|
| 1545 |
+
if x not in og_points:
|
| 1546 |
+
whys.append(self.cyclic_dep(og_points, x))
|
| 1547 |
+
abcdef_deps = deps
|
| 1548 |
+
if whys + why0:
|
| 1549 |
+
dep0 = deps.populate('cyclic', og_points)
|
| 1550 |
+
abcdef_deps = EmptyDependency(level=deps.level, rule_name=None)
|
| 1551 |
+
abcdef_deps.why = [dep0] + whys
|
| 1552 |
+
|
| 1553 |
+
is_cyclic = self.check_cyclic(args)
|
| 1554 |
+
|
| 1555 |
+
dep = abcdef_deps.populate('cyclic', args)
|
| 1556 |
+
self.cache_dep('cyclic', args, dep)
|
| 1557 |
+
self.merge_into(circle0, [circle], dep)
|
| 1558 |
+
if not is_cyclic:
|
| 1559 |
+
add += [dep]
|
| 1560 |
+
|
| 1561 |
+
return add
|
| 1562 |
+
|
| 1563 |
+
def check_cyclic(self, points: list[Point]) -> bool:
|
| 1564 |
+
points = list(set(points))
|
| 1565 |
+
if len(points) < 4:
|
| 1566 |
+
return True
|
| 1567 |
+
circle2count = defaultdict(lambda: 0)
|
| 1568 |
+
for p in points:
|
| 1569 |
+
for c in p.neighbors(Circle):
|
| 1570 |
+
circle2count[c] += 1
|
| 1571 |
+
return any([count == len(points) for _, count in circle2count.items()])
|
| 1572 |
+
|
| 1573 |
+
def make_equal_pairs(
|
| 1574 |
+
self,
|
| 1575 |
+
a: Point,
|
| 1576 |
+
b: Point,
|
| 1577 |
+
c: Point,
|
| 1578 |
+
d: Point,
|
| 1579 |
+
m: Point,
|
| 1580 |
+
n: Point,
|
| 1581 |
+
p: Point,
|
| 1582 |
+
q: Point,
|
| 1583 |
+
ab: Line,
|
| 1584 |
+
cd: Line,
|
| 1585 |
+
mn: Line,
|
| 1586 |
+
pq: Line,
|
| 1587 |
+
deps: EmptyDependency,
|
| 1588 |
+
) -> list[Dependency]:
|
| 1589 |
+
"""Add ab/cd = mn/pq in case either two of (ab,cd,mn,pq) are equal."""
|
| 1590 |
+
depname = 'eqratio' if isinstance(ab, Segment) else 'eqangle'
|
| 1591 |
+
eqname = 'cong' if isinstance(ab, Segment) else 'para'
|
| 1592 |
+
|
| 1593 |
+
is_equal = self.is_equal(mn, pq)
|
| 1594 |
+
|
| 1595 |
+
if ab != cd:
|
| 1596 |
+
dep0 = deps.populate(depname, [a, b, c, d, m, n, p, q])
|
| 1597 |
+
deps = EmptyDependency(level=deps.level, rule_name=None)
|
| 1598 |
+
|
| 1599 |
+
dep = Dependency(eqname, [a, b, c, d], None, deps.level)
|
| 1600 |
+
deps.why = [dep0, dep.why_me_or_cache(self, None)]
|
| 1601 |
+
|
| 1602 |
+
elif eqname == 'para': # ab == cd.
|
| 1603 |
+
colls = [a, b, c, d]
|
| 1604 |
+
if len(set(colls)) > 2:
|
| 1605 |
+
dep0 = deps.populate(depname, [a, b, c, d, m, n, p, q])
|
| 1606 |
+
deps = EmptyDependency(level=deps.level, rule_name=None)
|
| 1607 |
+
|
| 1608 |
+
dep = Dependency('collx', colls, None, deps.level)
|
| 1609 |
+
deps.why = [dep0, dep.why_me_or_cache(self, None)]
|
| 1610 |
+
|
| 1611 |
+
deps = deps.populate(eqname, [m, n, p, q])
|
| 1612 |
+
self.make_equal(mn, pq, deps=deps)
|
| 1613 |
+
|
| 1614 |
+
deps.algebra = mn._val, pq._val
|
| 1615 |
+
self.cache_dep(eqname, [m, n, p, q], deps)
|
| 1616 |
+
|
| 1617 |
+
if is_equal:
|
| 1618 |
+
return []
|
| 1619 |
+
return [deps]
|
| 1620 |
+
|
| 1621 |
+
def maybe_make_equal_pairs(
|
| 1622 |
+
self,
|
| 1623 |
+
a: Point,
|
| 1624 |
+
b: Point,
|
| 1625 |
+
c: Point,
|
| 1626 |
+
d: Point,
|
| 1627 |
+
m: Point,
|
| 1628 |
+
n: Point,
|
| 1629 |
+
p: Point,
|
| 1630 |
+
q: Point,
|
| 1631 |
+
ab: Line,
|
| 1632 |
+
cd: Line,
|
| 1633 |
+
mn: Line,
|
| 1634 |
+
pq: Line,
|
| 1635 |
+
deps: EmptyDependency,
|
| 1636 |
+
) -> Optional[list[Dependency]]:
|
| 1637 |
+
"""Add ab/cd = mn/pq in case maybe either two of (ab,cd,mn,pq) are equal."""
|
| 1638 |
+
level = deps.level
|
| 1639 |
+
if self.is_equal(ab, cd, level):
|
| 1640 |
+
return self.make_equal_pairs(a, b, c, d, m, n, p, q, ab, cd, mn, pq, deps)
|
| 1641 |
+
elif self.is_equal(mn, pq, level):
|
| 1642 |
+
return self.make_equal_pairs( # pylint: disable=arguments-out-of-order
|
| 1643 |
+
m,
|
| 1644 |
+
n,
|
| 1645 |
+
p,
|
| 1646 |
+
q,
|
| 1647 |
+
a,
|
| 1648 |
+
b,
|
| 1649 |
+
c,
|
| 1650 |
+
d,
|
| 1651 |
+
mn,
|
| 1652 |
+
pq,
|
| 1653 |
+
ab,
|
| 1654 |
+
cd,
|
| 1655 |
+
deps,
|
| 1656 |
+
)
|
| 1657 |
+
elif self.is_equal(ab, mn, level):
|
| 1658 |
+
return self.make_equal_pairs( # pylint: disable=arguments-out-of-order
|
| 1659 |
+
a,
|
| 1660 |
+
b,
|
| 1661 |
+
m,
|
| 1662 |
+
n,
|
| 1663 |
+
c,
|
| 1664 |
+
d,
|
| 1665 |
+
p,
|
| 1666 |
+
q,
|
| 1667 |
+
ab,
|
| 1668 |
+
mn,
|
| 1669 |
+
cd,
|
| 1670 |
+
pq,
|
| 1671 |
+
deps,
|
| 1672 |
+
)
|
| 1673 |
+
elif self.is_equal(cd, pq, level):
|
| 1674 |
+
return self.make_equal_pairs( # pylint: disable=arguments-out-of-order
|
| 1675 |
+
c,
|
| 1676 |
+
d,
|
| 1677 |
+
p,
|
| 1678 |
+
q,
|
| 1679 |
+
a,
|
| 1680 |
+
b,
|
| 1681 |
+
m,
|
| 1682 |
+
n,
|
| 1683 |
+
cd,
|
| 1684 |
+
pq,
|
| 1685 |
+
ab,
|
| 1686 |
+
mn,
|
| 1687 |
+
deps,
|
| 1688 |
+
)
|
| 1689 |
+
else:
|
| 1690 |
+
return None
|
| 1691 |
+
|
| 1692 |
+
def _add_eqangle(
|
| 1693 |
+
self,
|
| 1694 |
+
a: Point,
|
| 1695 |
+
b: Point,
|
| 1696 |
+
c: Point,
|
| 1697 |
+
d: Point,
|
| 1698 |
+
m: Point,
|
| 1699 |
+
n: Point,
|
| 1700 |
+
p: Point,
|
| 1701 |
+
q: Point,
|
| 1702 |
+
ab: Line,
|
| 1703 |
+
cd: Line,
|
| 1704 |
+
mn: Line,
|
| 1705 |
+
pq: Line,
|
| 1706 |
+
deps: EmptyDependency,
|
| 1707 |
+
) -> list[Dependency]:
|
| 1708 |
+
"""Add eqangle core."""
|
| 1709 |
+
if deps:
|
| 1710 |
+
deps = deps.copy()
|
| 1711 |
+
|
| 1712 |
+
args = [a, b, c, d, m, n, p, q]
|
| 1713 |
+
i = 0
|
| 1714 |
+
for x, y, xy in [(a, b, ab), (c, d, cd), (m, n, mn), (p, q, pq)]:
|
| 1715 |
+
i += 1
|
| 1716 |
+
x_, y_ = xy._val._obj.points
|
| 1717 |
+
if {x, y} == {x_, y_}:
|
| 1718 |
+
continue
|
| 1719 |
+
if deps:
|
| 1720 |
+
deps = deps.extend(self, 'eqangle', list(args), 'para', [x, y, x_, y_])
|
| 1721 |
+
|
| 1722 |
+
args[2 * i - 2] = x_
|
| 1723 |
+
args[2 * i - 1] = y_
|
| 1724 |
+
|
| 1725 |
+
add = []
|
| 1726 |
+
ab_cd, cd_ab, why1 = self._get_or_create_angle(ab, cd, deps=None)
|
| 1727 |
+
mn_pq, pq_mn, why2 = self._get_or_create_angle(mn, pq, deps=None)
|
| 1728 |
+
|
| 1729 |
+
why = why1 + why2
|
| 1730 |
+
if why:
|
| 1731 |
+
dep0 = deps.populate('eqangle', args)
|
| 1732 |
+
deps = EmptyDependency(level=deps.level, rule_name=None)
|
| 1733 |
+
deps.why = [dep0] + why
|
| 1734 |
+
|
| 1735 |
+
dab, dcd = ab_cd._d
|
| 1736 |
+
dmn, dpq = mn_pq._d
|
| 1737 |
+
|
| 1738 |
+
a, b = dab._obj.points
|
| 1739 |
+
c, d = dcd._obj.points
|
| 1740 |
+
m, n = dmn._obj.points
|
| 1741 |
+
p, q = dpq._obj.points
|
| 1742 |
+
|
| 1743 |
+
is_eq1 = self.is_equal(ab_cd, mn_pq)
|
| 1744 |
+
deps1 = None
|
| 1745 |
+
if deps:
|
| 1746 |
+
deps1 = deps.populate('eqangle', [a, b, c, d, m, n, p, q])
|
| 1747 |
+
deps1.algebra = [dab, dcd, dmn, dpq]
|
| 1748 |
+
if not is_eq1:
|
| 1749 |
+
add += [deps1]
|
| 1750 |
+
self.cache_dep('eqangle', [a, b, c, d, m, n, p, q], deps1)
|
| 1751 |
+
self.make_equal(ab_cd, mn_pq, deps=deps1)
|
| 1752 |
+
|
| 1753 |
+
is_eq2 = self.is_equal(cd_ab, pq_mn)
|
| 1754 |
+
deps2 = None
|
| 1755 |
+
if deps:
|
| 1756 |
+
deps2 = deps.populate('eqangle', [c, d, a, b, p, q, m, n])
|
| 1757 |
+
deps2.algebra = [dcd, dab, dpq, dmn]
|
| 1758 |
+
if not is_eq2:
|
| 1759 |
+
add += [deps2]
|
| 1760 |
+
self.cache_dep('eqangle', [c, d, a, b, p, q, m, n], deps2)
|
| 1761 |
+
self.make_equal(cd_ab, pq_mn, deps=deps2)
|
| 1762 |
+
|
| 1763 |
+
return add
|
| 1764 |
+
|
| 1765 |
+
def add_eqangle(
|
| 1766 |
+
self, points: list[Point], deps: EmptyDependency
|
| 1767 |
+
) -> list[Dependency]:
|
| 1768 |
+
"""Add eqangle made by 8 points in `points`."""
|
| 1769 |
+
if deps:
|
| 1770 |
+
deps = deps.copy()
|
| 1771 |
+
a, b, c, d, m, n, p, q = points
|
| 1772 |
+
ab, why1 = self.get_line_thru_pair_why(a, b)
|
| 1773 |
+
cd, why2 = self.get_line_thru_pair_why(c, d)
|
| 1774 |
+
mn, why3 = self.get_line_thru_pair_why(m, n)
|
| 1775 |
+
pq, why4 = self.get_line_thru_pair_why(p, q)
|
| 1776 |
+
|
| 1777 |
+
a, b = ab.points
|
| 1778 |
+
c, d = cd.points
|
| 1779 |
+
m, n = mn.points
|
| 1780 |
+
p, q = pq.points
|
| 1781 |
+
|
| 1782 |
+
if deps and why1 + why2 + why3 + why4:
|
| 1783 |
+
dep0 = deps.populate('eqangle', points)
|
| 1784 |
+
deps = EmptyDependency(level=deps.level, rule_name=None)
|
| 1785 |
+
deps.why = [dep0] + why1 + why2 + why3 + why4
|
| 1786 |
+
|
| 1787 |
+
add = self.maybe_make_equal_pairs(
|
| 1788 |
+
a, b, c, d, m, n, p, q, ab, cd, mn, pq, deps
|
| 1789 |
+
)
|
| 1790 |
+
|
| 1791 |
+
if add is not None:
|
| 1792 |
+
return add
|
| 1793 |
+
|
| 1794 |
+
self.connect_val(ab, deps=None)
|
| 1795 |
+
self.connect_val(cd, deps=None)
|
| 1796 |
+
self.connect_val(mn, deps=None)
|
| 1797 |
+
self.connect_val(pq, deps=None)
|
| 1798 |
+
|
| 1799 |
+
add = []
|
| 1800 |
+
if (
|
| 1801 |
+
ab.val != cd.val
|
| 1802 |
+
and mn.val != pq.val
|
| 1803 |
+
and (ab.val != mn.val or cd.val != pq.val)
|
| 1804 |
+
):
|
| 1805 |
+
add += self._add_eqangle(a, b, c, d, m, n, p, q, ab, cd, mn, pq, deps)
|
| 1806 |
+
|
| 1807 |
+
if (
|
| 1808 |
+
ab.val != mn.val
|
| 1809 |
+
and cd.val != pq.val
|
| 1810 |
+
and (ab.val != cd.val or mn.val != pq.val)
|
| 1811 |
+
):
|
| 1812 |
+
add += self._add_eqangle( # pylint: disable=arguments-out-of-order
|
| 1813 |
+
a,
|
| 1814 |
+
b,
|
| 1815 |
+
m,
|
| 1816 |
+
n,
|
| 1817 |
+
c,
|
| 1818 |
+
d,
|
| 1819 |
+
p,
|
| 1820 |
+
q,
|
| 1821 |
+
ab,
|
| 1822 |
+
mn,
|
| 1823 |
+
cd,
|
| 1824 |
+
pq,
|
| 1825 |
+
deps,
|
| 1826 |
+
)
|
| 1827 |
+
|
| 1828 |
+
return add
|
| 1829 |
+
|
| 1830 |
+
def add_aconst(
|
| 1831 |
+
self, points: list[Point], deps: EmptyDependency
|
| 1832 |
+
) -> list[Dependency]:
|
| 1833 |
+
"""Add that an angle is equal to some constant."""
|
| 1834 |
+
a, b, c, d, num, den = points
|
| 1835 |
+
nd, dn = self.get_or_create_const_ang(num, den)
|
| 1836 |
+
|
| 1837 |
+
if nd == self.halfpi:
|
| 1838 |
+
return self.add_perp([a, b, c, d], deps)
|
| 1839 |
+
|
| 1840 |
+
ab, why1 = self.get_line_thru_pair_why(a, b)
|
| 1841 |
+
cd, why2 = self.get_line_thru_pair_why(c, d)
|
| 1842 |
+
|
| 1843 |
+
(a, b), (c, d) = ab.points, cd.points
|
| 1844 |
+
if why1 + why2:
|
| 1845 |
+
args = points[:-2] + [nd]
|
| 1846 |
+
dep0 = deps.populate('aconst', args)
|
| 1847 |
+
deps = EmptyDependency(level=deps.level, rule_name=None)
|
| 1848 |
+
deps.why = [dep0] + why1 + why2
|
| 1849 |
+
|
| 1850 |
+
self.connect_val(ab, deps=None)
|
| 1851 |
+
self.connect_val(cd, deps=None)
|
| 1852 |
+
|
| 1853 |
+
if ab.val == cd.val:
|
| 1854 |
+
raise ValueError(f'{ab.name} - {cd.name} cannot be {nd.name}')
|
| 1855 |
+
|
| 1856 |
+
args = [a, b, c, d, nd]
|
| 1857 |
+
i = 0
|
| 1858 |
+
for x, y, xy in [(a, b, ab), (c, d, cd)]:
|
| 1859 |
+
i += 1
|
| 1860 |
+
x_, y_ = xy._val._obj.points
|
| 1861 |
+
if {x, y} == {x_, y_}:
|
| 1862 |
+
continue
|
| 1863 |
+
if deps:
|
| 1864 |
+
deps = deps.extend(self, 'aconst', list(args), 'para', [x, y, x_, y_])
|
| 1865 |
+
args[2 * i - 2] = x_
|
| 1866 |
+
args[2 * i - 1] = y_
|
| 1867 |
+
|
| 1868 |
+
ab_cd, cd_ab, why = self._get_or_create_angle(ab, cd, deps=None)
|
| 1869 |
+
if why:
|
| 1870 |
+
dep0 = deps.populate('aconst', [a, b, c, d, nd])
|
| 1871 |
+
deps = EmptyDependency(level=deps.level, rule_name=None)
|
| 1872 |
+
deps.why = [dep0] + why
|
| 1873 |
+
|
| 1874 |
+
dab, dcd = ab_cd._d
|
| 1875 |
+
a, b = dab._obj.points
|
| 1876 |
+
c, d = dcd._obj.points
|
| 1877 |
+
|
| 1878 |
+
ang = int(num) * 180 / int(den)
|
| 1879 |
+
add = []
|
| 1880 |
+
if not self.is_equal(ab_cd, nd):
|
| 1881 |
+
deps1 = deps.populate('aconst', [a, b, c, d, nd])
|
| 1882 |
+
deps1.algebra = dab, dcd, ang % 180
|
| 1883 |
+
self.make_equal(ab_cd, nd, deps=deps1)
|
| 1884 |
+
self.cache_dep('aconst', [a, b, c, d, nd], deps1)
|
| 1885 |
+
add += [deps1]
|
| 1886 |
+
|
| 1887 |
+
if not self.is_equal(cd_ab, dn):
|
| 1888 |
+
deps2 = deps.populate('aconst', [c, d, a, b, dn])
|
| 1889 |
+
deps2.algebra = dcd, dab, 180 - ang % 180
|
| 1890 |
+
self.make_equal(cd_ab, dn, deps=deps2)
|
| 1891 |
+
self.cache_dep('aconst', [c, d, a, b, dn], deps2)
|
| 1892 |
+
add += [deps2]
|
| 1893 |
+
return add
|
| 1894 |
+
|
| 1895 |
+
def add_s_angle(
|
| 1896 |
+
self, points: list[Point], deps: EmptyDependency
|
| 1897 |
+
) -> list[Dependency]:
|
| 1898 |
+
"""Add that an angle abx is equal to constant y."""
|
| 1899 |
+
a, b, x, y = points
|
| 1900 |
+
|
| 1901 |
+
n, d = ar.simplify(y % 180, 180)
|
| 1902 |
+
nd, dn = self.get_or_create_const_ang(n, d)
|
| 1903 |
+
|
| 1904 |
+
if nd == self.halfpi:
|
| 1905 |
+
return self.add_perp([a, b, b, x], deps)
|
| 1906 |
+
|
| 1907 |
+
ab, why1 = self.get_line_thru_pair_why(a, b)
|
| 1908 |
+
bx, why2 = self.get_line_thru_pair_why(b, x)
|
| 1909 |
+
|
| 1910 |
+
self.connect_val(ab, deps=None)
|
| 1911 |
+
self.connect_val(bx, deps=None)
|
| 1912 |
+
add = []
|
| 1913 |
+
|
| 1914 |
+
if ab.val == bx.val:
|
| 1915 |
+
return add
|
| 1916 |
+
|
| 1917 |
+
deps.why += why1 + why2
|
| 1918 |
+
|
| 1919 |
+
for p, q, pq in [(a, b, ab), (b, x, bx)]:
|
| 1920 |
+
p_, q_ = pq.val._obj.points
|
| 1921 |
+
if {p, q} == {p_, q_}:
|
| 1922 |
+
continue
|
| 1923 |
+
dep = Dependency('para', [p, q, p_, q_], None, deps.level)
|
| 1924 |
+
deps.why += [dep.why_me_or_cache(self, None)]
|
| 1925 |
+
|
| 1926 |
+
xba, abx, why = self._get_or_create_angle(bx, ab, deps=None)
|
| 1927 |
+
if why:
|
| 1928 |
+
dep0 = deps.populate('aconst', [b, x, a, b, nd])
|
| 1929 |
+
deps = EmptyDependency(level=deps.level, rule_name=None)
|
| 1930 |
+
deps.why = [dep0] + why
|
| 1931 |
+
|
| 1932 |
+
dab, dbx = abx._d
|
| 1933 |
+
a, b = dab._obj.points
|
| 1934 |
+
c, x = dbx._obj.points
|
| 1935 |
+
|
| 1936 |
+
if not self.is_equal(xba, nd):
|
| 1937 |
+
deps1 = deps.populate('aconst', [c, x, a, b, nd])
|
| 1938 |
+
deps1.algebra = dbx, dab, y % 180
|
| 1939 |
+
|
| 1940 |
+
self.make_equal(xba, nd, deps=deps1)
|
| 1941 |
+
self.cache_dep('aconst', [c, x, a, b, nd], deps1)
|
| 1942 |
+
add += [deps1]
|
| 1943 |
+
|
| 1944 |
+
if not self.is_equal(abx, dn):
|
| 1945 |
+
deps2 = deps.populate('aconst', [a, b, c, x, dn])
|
| 1946 |
+
deps2.algebra = dab, dbx, 180 - (y % 180)
|
| 1947 |
+
|
| 1948 |
+
self.make_equal(abx, dn, deps=deps2)
|
| 1949 |
+
self.cache_dep('s_angle', [a, b, c, x, dn], deps2)
|
| 1950 |
+
add += [deps2]
|
| 1951 |
+
return add
|
| 1952 |
+
|
| 1953 |
+
def check_aconst(self, points: list[Point], verbose: bool = False) -> bool:
|
| 1954 |
+
"""Check if the angle is equal to a certain constant."""
|
| 1955 |
+
a, b, c, d, nd = points
|
| 1956 |
+
_ = verbose
|
| 1957 |
+
if isinstance(nd, str):
|
| 1958 |
+
name = nd
|
| 1959 |
+
else:
|
| 1960 |
+
name = nd.name
|
| 1961 |
+
num, den = name.split('pi/')
|
| 1962 |
+
ang, _ = self.get_or_create_const_ang(int(num), int(den))
|
| 1963 |
+
|
| 1964 |
+
ab = self._get_line(a, b)
|
| 1965 |
+
cd = self._get_line(c, d)
|
| 1966 |
+
if not ab or not cd:
|
| 1967 |
+
return False
|
| 1968 |
+
|
| 1969 |
+
if not (ab.val and cd.val):
|
| 1970 |
+
return False
|
| 1971 |
+
|
| 1972 |
+
for ang1, _, _ in gm.all_angles(ab._val, cd._val):
|
| 1973 |
+
if self.is_equal(ang1, ang):
|
| 1974 |
+
return True
|
| 1975 |
+
return False
|
| 1976 |
+
|
| 1977 |
+
def check_acompute(self, points: list[Point]) -> bool:
|
| 1978 |
+
"""Check if an angle has a constant value."""
|
| 1979 |
+
a, b, c, d = points
|
| 1980 |
+
ab = self._get_line(a, b)
|
| 1981 |
+
cd = self._get_line(c, d)
|
| 1982 |
+
if not ab or not cd:
|
| 1983 |
+
return False
|
| 1984 |
+
|
| 1985 |
+
if not (ab.val and cd.val):
|
| 1986 |
+
return False
|
| 1987 |
+
|
| 1988 |
+
for ang0 in self.aconst.values():
|
| 1989 |
+
for ang in ang0.val.neighbors(Angle):
|
| 1990 |
+
d1, d2 = ang.directions
|
| 1991 |
+
if ab.val == d1 and cd.val == d2:
|
| 1992 |
+
return True
|
| 1993 |
+
return False
|
| 1994 |
+
|
| 1995 |
+
def check_eqangle(self, points: list[Point]) -> bool:
|
| 1996 |
+
"""Check if two angles are equal."""
|
| 1997 |
+
a, b, c, d, m, n, p, q = points
|
| 1998 |
+
|
| 1999 |
+
if {a, b} == {c, d} and {m, n} == {p, q}:
|
| 2000 |
+
return True
|
| 2001 |
+
if {a, b} == {m, n} and {c, d} == {p, q}:
|
| 2002 |
+
return True
|
| 2003 |
+
|
| 2004 |
+
if (a == b) or (c == d) or (m == n) or (p == q):
|
| 2005 |
+
return False
|
| 2006 |
+
ab = self._get_line(a, b)
|
| 2007 |
+
cd = self._get_line(c, d)
|
| 2008 |
+
mn = self._get_line(m, n)
|
| 2009 |
+
pq = self._get_line(p, q)
|
| 2010 |
+
|
| 2011 |
+
if {a, b} == {c, d} and mn and pq and self.is_equal(mn, pq):
|
| 2012 |
+
return True
|
| 2013 |
+
if {a, b} == {m, n} and cd and pq and self.is_equal(cd, pq):
|
| 2014 |
+
return True
|
| 2015 |
+
if {p, q} == {m, n} and ab and cd and self.is_equal(ab, cd):
|
| 2016 |
+
return True
|
| 2017 |
+
if {p, q} == {c, d} and ab and mn and self.is_equal(ab, mn):
|
| 2018 |
+
return True
|
| 2019 |
+
|
| 2020 |
+
if not ab or not cd or not mn or not pq:
|
| 2021 |
+
return False
|
| 2022 |
+
|
| 2023 |
+
if self.is_equal(ab, cd) and self.is_equal(mn, pq):
|
| 2024 |
+
return True
|
| 2025 |
+
if self.is_equal(ab, mn) and self.is_equal(cd, pq):
|
| 2026 |
+
return True
|
| 2027 |
+
|
| 2028 |
+
if not (ab.val and cd.val and mn.val and pq.val):
|
| 2029 |
+
return False
|
| 2030 |
+
|
| 2031 |
+
if (ab.val, cd.val) == (mn.val, pq.val) or (ab.val, mn.val) == (
|
| 2032 |
+
cd.val,
|
| 2033 |
+
pq.val,
|
| 2034 |
+
):
|
| 2035 |
+
return True
|
| 2036 |
+
|
| 2037 |
+
for ang1, _, _ in gm.all_angles(ab._val, cd._val):
|
| 2038 |
+
for ang2, _, _ in gm.all_angles(mn._val, pq._val):
|
| 2039 |
+
if self.is_equal(ang1, ang2):
|
| 2040 |
+
return True
|
| 2041 |
+
|
| 2042 |
+
if self.check_perp([a, b, m, n]) and self.check_perp([c, d, p, q]):
|
| 2043 |
+
return True
|
| 2044 |
+
if self.check_perp([a, b, p, q]) and self.check_perp([c, d, m, n]):
|
| 2045 |
+
return True
|
| 2046 |
+
|
| 2047 |
+
return False
|
| 2048 |
+
|
| 2049 |
+
def _get_ratio(self, l1: Length, l2: Length) -> tuple[Ratio, Ratio]:
|
| 2050 |
+
for r in self.type2nodes[Ratio]:
|
| 2051 |
+
if r.lengths == (l1, l2):
|
| 2052 |
+
return r, r.opposite
|
| 2053 |
+
return None, None
|
| 2054 |
+
|
| 2055 |
+
def _get_or_create_ratio(
|
| 2056 |
+
self, s1: Segment, s2: Segment, deps: Dependency
|
| 2057 |
+
) -> tuple[Ratio, Ratio, list[Dependency]]:
|
| 2058 |
+
return self._get_or_create_ratio_l(s1._val, s2._val, deps)
|
| 2059 |
+
|
| 2060 |
+
def _get_or_create_ratio_l(
|
| 2061 |
+
self, l1: Length, l2: Length, deps: Dependency
|
| 2062 |
+
) -> tuple[Ratio, Ratio, list[Dependency]]:
|
| 2063 |
+
"""Get or create a new Ratio from two Lenghts l1 and l2."""
|
| 2064 |
+
for r in self.type2nodes[Ratio]:
|
| 2065 |
+
if r.lengths == (l1.rep(), l2.rep()):
|
| 2066 |
+
l1_, l2_ = r._l
|
| 2067 |
+
why1 = l1.why_equal([l1_], None) + l1_.why_rep()
|
| 2068 |
+
why2 = l2.why_equal([l2_], None) + l2_.why_rep()
|
| 2069 |
+
return r, r.opposite, why1 + why2
|
| 2070 |
+
|
| 2071 |
+
l1, why1 = l1.rep_and_why()
|
| 2072 |
+
l2, why2 = l2.rep_and_why()
|
| 2073 |
+
r12 = self.new_node(Ratio, f'{l1.name}/{l2.name}')
|
| 2074 |
+
r21 = self.new_node(Ratio, f'{l2.name}/{l1.name}')
|
| 2075 |
+
self.connect(l1, r12, deps)
|
| 2076 |
+
self.connect(l2, r21, deps)
|
| 2077 |
+
self.connect(r12, r21, deps)
|
| 2078 |
+
r12.set_lengths(l1, l2)
|
| 2079 |
+
r21.set_lengths(l2, l1)
|
| 2080 |
+
r12.opposite = r21
|
| 2081 |
+
r21.opposite = r12
|
| 2082 |
+
return r12, r21, why1 + why2
|
| 2083 |
+
|
| 2084 |
+
def add_cong2(
|
| 2085 |
+
self, points: list[Point], deps: EmptyDependency
|
| 2086 |
+
) -> list[Dependency]:
|
| 2087 |
+
m, n, a, b = points
|
| 2088 |
+
add = []
|
| 2089 |
+
add += self.add_cong([m, a, n, a], deps)
|
| 2090 |
+
add += self.add_cong([m, b, n, b], deps)
|
| 2091 |
+
return add
|
| 2092 |
+
|
| 2093 |
+
def add_eqratio3(
|
| 2094 |
+
self, points: list[Point], deps: EmptyDependency
|
| 2095 |
+
) -> list[Dependency]:
|
| 2096 |
+
"""Add three eqratios through a list of 6 points (due to parallel lines)."""
|
| 2097 |
+
a, b, c, d, m, n = points
|
| 2098 |
+
# a -- b
|
| 2099 |
+
# m -- n
|
| 2100 |
+
# c -- d
|
| 2101 |
+
add = []
|
| 2102 |
+
add += self.add_eqratio([m, a, m, c, n, b, n, d], deps)
|
| 2103 |
+
add += self.add_eqratio([a, m, a, c, b, n, b, d], deps)
|
| 2104 |
+
add += self.add_eqratio([c, m, c, a, d, n, d, b], deps)
|
| 2105 |
+
if m == n:
|
| 2106 |
+
add += self.add_eqratio([m, a, m, c, a, b, c, d], deps)
|
| 2107 |
+
return add
|
| 2108 |
+
|
| 2109 |
+
def add_eqratio4(
|
| 2110 |
+
self, points: list[Point], deps: EmptyDependency
|
| 2111 |
+
) -> list[Dependency]:
|
| 2112 |
+
o, a, b, c, d = points
|
| 2113 |
+
# o
|
| 2114 |
+
# a b
|
| 2115 |
+
# c d
|
| 2116 |
+
add = self.add_eqratio3([a, b, c, d, o, o], deps)
|
| 2117 |
+
add += self.add_eqratio([o, a, o, c, a, b, c, d], deps)
|
| 2118 |
+
return add
|
| 2119 |
+
|
| 2120 |
+
def _add_eqratio(
|
| 2121 |
+
self,
|
| 2122 |
+
a: Point,
|
| 2123 |
+
b: Point,
|
| 2124 |
+
c: Point,
|
| 2125 |
+
d: Point,
|
| 2126 |
+
m: Point,
|
| 2127 |
+
n: Point,
|
| 2128 |
+
p: Point,
|
| 2129 |
+
q: Point,
|
| 2130 |
+
ab: Segment,
|
| 2131 |
+
cd: Segment,
|
| 2132 |
+
mn: Segment,
|
| 2133 |
+
pq: Segment,
|
| 2134 |
+
deps: EmptyDependency,
|
| 2135 |
+
) -> list[Dependency]:
|
| 2136 |
+
"""Add a new eqratio from 8 points (core)."""
|
| 2137 |
+
if deps:
|
| 2138 |
+
deps = deps.copy()
|
| 2139 |
+
|
| 2140 |
+
args = [a, b, c, d, m, n, p, q]
|
| 2141 |
+
i = 0
|
| 2142 |
+
for x, y, xy in [(a, b, ab), (c, d, cd), (m, n, mn), (p, q, pq)]:
|
| 2143 |
+
if {x, y} == set(xy.points):
|
| 2144 |
+
continue
|
| 2145 |
+
x_, y_ = list(xy.points)
|
| 2146 |
+
if deps:
|
| 2147 |
+
deps = deps.extend(self, 'eqratio', list(args), 'cong', [x, y, x_, y_])
|
| 2148 |
+
args[2 * i - 2] = x_
|
| 2149 |
+
args[2 * i - 1] = y_
|
| 2150 |
+
|
| 2151 |
+
add = []
|
| 2152 |
+
ab_cd, cd_ab, why1 = self._get_or_create_ratio(ab, cd, deps=None)
|
| 2153 |
+
mn_pq, pq_mn, why2 = self._get_or_create_ratio(mn, pq, deps=None)
|
| 2154 |
+
|
| 2155 |
+
why = why1 + why2
|
| 2156 |
+
if why:
|
| 2157 |
+
dep0 = deps.populate('eqratio', args)
|
| 2158 |
+
deps = EmptyDependency(level=deps.level, rule_name=None)
|
| 2159 |
+
deps.why = [dep0] + why
|
| 2160 |
+
|
| 2161 |
+
lab, lcd = ab_cd._l
|
| 2162 |
+
lmn, lpq = mn_pq._l
|
| 2163 |
+
|
| 2164 |
+
a, b = lab._obj.points
|
| 2165 |
+
c, d = lcd._obj.points
|
| 2166 |
+
m, n = lmn._obj.points
|
| 2167 |
+
p, q = lpq._obj.points
|
| 2168 |
+
|
| 2169 |
+
is_eq1 = self.is_equal(ab_cd, mn_pq)
|
| 2170 |
+
deps1 = None
|
| 2171 |
+
if deps:
|
| 2172 |
+
deps1 = deps.populate('eqratio', [a, b, c, d, m, n, p, q])
|
| 2173 |
+
deps1.algebra = [ab._val, cd._val, mn._val, pq._val]
|
| 2174 |
+
if not is_eq1:
|
| 2175 |
+
add += [deps1]
|
| 2176 |
+
self.cache_dep('eqratio', [a, b, c, d, m, n, p, q], deps1)
|
| 2177 |
+
self.make_equal(ab_cd, mn_pq, deps=deps1)
|
| 2178 |
+
|
| 2179 |
+
is_eq2 = self.is_equal(cd_ab, pq_mn)
|
| 2180 |
+
deps2 = None
|
| 2181 |
+
if deps:
|
| 2182 |
+
deps2 = deps.populate('eqratio', [c, d, a, b, p, q, m, n])
|
| 2183 |
+
deps2.algebra = [cd._val, ab._val, pq._val, mn._val]
|
| 2184 |
+
if not is_eq2:
|
| 2185 |
+
add += [deps2]
|
| 2186 |
+
self.cache_dep('eqratio', [c, d, a, b, p, q, m, n], deps2)
|
| 2187 |
+
self.make_equal(cd_ab, pq_mn, deps=deps2)
|
| 2188 |
+
return add
|
| 2189 |
+
|
| 2190 |
+
def add_eqratio(
|
| 2191 |
+
self, points: list[Point], deps: EmptyDependency
|
| 2192 |
+
) -> list[Dependency]:
|
| 2193 |
+
"""Add a new eqratio from 8 points."""
|
| 2194 |
+
if deps:
|
| 2195 |
+
deps = deps.copy()
|
| 2196 |
+
a, b, c, d, m, n, p, q = points
|
| 2197 |
+
ab = self._get_or_create_segment(a, b, deps=None)
|
| 2198 |
+
cd = self._get_or_create_segment(c, d, deps=None)
|
| 2199 |
+
mn = self._get_or_create_segment(m, n, deps=None)
|
| 2200 |
+
pq = self._get_or_create_segment(p, q, deps=None)
|
| 2201 |
+
|
| 2202 |
+
add = self.maybe_make_equal_pairs(
|
| 2203 |
+
a, b, c, d, m, n, p, q, ab, cd, mn, pq, deps
|
| 2204 |
+
)
|
| 2205 |
+
|
| 2206 |
+
if add is not None:
|
| 2207 |
+
return add
|
| 2208 |
+
|
| 2209 |
+
self.connect_val(ab, deps=None)
|
| 2210 |
+
self.connect_val(cd, deps=None)
|
| 2211 |
+
self.connect_val(mn, deps=None)
|
| 2212 |
+
self.connect_val(pq, deps=None)
|
| 2213 |
+
|
| 2214 |
+
add = []
|
| 2215 |
+
if (
|
| 2216 |
+
ab.val != cd.val
|
| 2217 |
+
and mn.val != pq.val
|
| 2218 |
+
and (ab.val != mn.val or cd.val != pq.val)
|
| 2219 |
+
):
|
| 2220 |
+
add += self._add_eqratio(a, b, c, d, m, n, p, q, ab, cd, mn, pq, deps)
|
| 2221 |
+
|
| 2222 |
+
if (
|
| 2223 |
+
ab.val != mn.val
|
| 2224 |
+
and cd.val != pq.val
|
| 2225 |
+
and (ab.val != cd.val or mn.val != pq.val)
|
| 2226 |
+
):
|
| 2227 |
+
add += self._add_eqratio( # pylint: disable=arguments-out-of-order
|
| 2228 |
+
a,
|
| 2229 |
+
b,
|
| 2230 |
+
m,
|
| 2231 |
+
n,
|
| 2232 |
+
c,
|
| 2233 |
+
d,
|
| 2234 |
+
p,
|
| 2235 |
+
q,
|
| 2236 |
+
ab,
|
| 2237 |
+
mn,
|
| 2238 |
+
cd,
|
| 2239 |
+
pq,
|
| 2240 |
+
deps,
|
| 2241 |
+
)
|
| 2242 |
+
return add
|
| 2243 |
+
|
| 2244 |
+
def check_rconst(self, points: list[Point], verbose: bool = False) -> bool:
|
| 2245 |
+
"""Check whether a ratio is equal to some given constant."""
|
| 2246 |
+
_ = verbose
|
| 2247 |
+
a, b, c, d, nd = points
|
| 2248 |
+
if isinstance(nd, str):
|
| 2249 |
+
name = nd
|
| 2250 |
+
else:
|
| 2251 |
+
name = nd.name
|
| 2252 |
+
num, den = name.split('/')
|
| 2253 |
+
rat, _ = self.get_or_create_const_rat(int(num), int(den))
|
| 2254 |
+
|
| 2255 |
+
ab = self._get_segment(a, b)
|
| 2256 |
+
cd = self._get_segment(c, d)
|
| 2257 |
+
|
| 2258 |
+
if not ab or not cd:
|
| 2259 |
+
return False
|
| 2260 |
+
|
| 2261 |
+
if not (ab.val and cd.val):
|
| 2262 |
+
return False
|
| 2263 |
+
|
| 2264 |
+
for rat1, _, _ in gm.all_ratios(ab._val, cd._val):
|
| 2265 |
+
if self.is_equal(rat1, rat):
|
| 2266 |
+
return True
|
| 2267 |
+
return False
|
| 2268 |
+
|
| 2269 |
+
def check_rcompute(self, points: list[Point]) -> bool:
|
| 2270 |
+
"""Check whether a ratio is equal to some constant."""
|
| 2271 |
+
a, b, c, d = points
|
| 2272 |
+
ab = self._get_segment(a, b)
|
| 2273 |
+
cd = self._get_segment(c, d)
|
| 2274 |
+
|
| 2275 |
+
if not ab or not cd:
|
| 2276 |
+
return False
|
| 2277 |
+
|
| 2278 |
+
if not (ab.val and cd.val):
|
| 2279 |
+
return False
|
| 2280 |
+
|
| 2281 |
+
for rat0 in self.rconst.values():
|
| 2282 |
+
for rat in rat0.val.neighbors(Ratio):
|
| 2283 |
+
l1, l2 = rat.lengths
|
| 2284 |
+
if ab.val == l1 and cd.val == l2:
|
| 2285 |
+
return True
|
| 2286 |
+
return False
|
| 2287 |
+
|
| 2288 |
+
def check_eqratio(self, points: list[Point]) -> bool:
|
| 2289 |
+
"""Check if 8 points make an eqratio predicate."""
|
| 2290 |
+
a, b, c, d, m, n, p, q = points
|
| 2291 |
+
|
| 2292 |
+
if {a, b} == {c, d} and {m, n} == {p, q}:
|
| 2293 |
+
return True
|
| 2294 |
+
if {a, b} == {m, n} and {c, d} == {p, q}:
|
| 2295 |
+
return True
|
| 2296 |
+
|
| 2297 |
+
ab = self._get_segment(a, b)
|
| 2298 |
+
cd = self._get_segment(c, d)
|
| 2299 |
+
mn = self._get_segment(m, n)
|
| 2300 |
+
pq = self._get_segment(p, q)
|
| 2301 |
+
|
| 2302 |
+
if {a, b} == {c, d} and mn and pq and self.is_equal(mn, pq):
|
| 2303 |
+
return True
|
| 2304 |
+
if {a, b} == {m, n} and cd and pq and self.is_equal(cd, pq):
|
| 2305 |
+
return True
|
| 2306 |
+
if {p, q} == {m, n} and ab and cd and self.is_equal(ab, cd):
|
| 2307 |
+
return True
|
| 2308 |
+
if {p, q} == {c, d} and ab and mn and self.is_equal(ab, mn):
|
| 2309 |
+
return True
|
| 2310 |
+
|
| 2311 |
+
if not ab or not cd or not mn or not pq:
|
| 2312 |
+
return False
|
| 2313 |
+
|
| 2314 |
+
if self.is_equal(ab, cd) and self.is_equal(mn, pq):
|
| 2315 |
+
return True
|
| 2316 |
+
if self.is_equal(ab, mn) and self.is_equal(cd, pq):
|
| 2317 |
+
return True
|
| 2318 |
+
|
| 2319 |
+
if not (ab.val and cd.val and mn.val and pq.val):
|
| 2320 |
+
return False
|
| 2321 |
+
|
| 2322 |
+
if (ab.val, cd.val) == (mn.val, pq.val) or (ab.val, mn.val) == (
|
| 2323 |
+
cd.val,
|
| 2324 |
+
pq.val,
|
| 2325 |
+
):
|
| 2326 |
+
return True
|
| 2327 |
+
|
| 2328 |
+
for rat1, _, _ in gm.all_ratios(ab._val, cd._val):
|
| 2329 |
+
for rat2, _, _ in gm.all_ratios(mn._val, pq._val):
|
| 2330 |
+
if self.is_equal(rat1, rat2):
|
| 2331 |
+
return True
|
| 2332 |
+
return False
|
| 2333 |
+
|
| 2334 |
+
def add_simtri_check(
|
| 2335 |
+
self, points: list[Point], deps: EmptyDependency
|
| 2336 |
+
) -> list[Dependency]:
|
| 2337 |
+
if nm.same_clock(*[p.num for p in points]):
|
| 2338 |
+
return self.add_simtri(points, deps)
|
| 2339 |
+
return self.add_simtri2(points, deps)
|
| 2340 |
+
|
| 2341 |
+
def add_contri_check(
|
| 2342 |
+
self, points: list[Point], deps: EmptyDependency
|
| 2343 |
+
) -> list[Dependency]:
|
| 2344 |
+
if nm.same_clock(*[p.num for p in points]):
|
| 2345 |
+
return self.add_contri(points, deps)
|
| 2346 |
+
return self.add_contri2(points, deps)
|
| 2347 |
+
|
| 2348 |
+
def enum_sides(
|
| 2349 |
+
self, points: list[Point]
|
| 2350 |
+
) -> Generator[list[Point], None, None]:
|
| 2351 |
+
a, b, c, x, y, z = points
|
| 2352 |
+
yield [a, b, x, y]
|
| 2353 |
+
yield [b, c, y, z]
|
| 2354 |
+
yield [c, a, z, x]
|
| 2355 |
+
|
| 2356 |
+
def enum_triangle(
|
| 2357 |
+
self, points: list[Point]
|
| 2358 |
+
) -> Generator[list[Point], None, None]:
|
| 2359 |
+
a, b, c, x, y, z = points
|
| 2360 |
+
yield [a, b, a, c, x, y, x, z]
|
| 2361 |
+
yield [b, a, b, c, y, x, y, z]
|
| 2362 |
+
yield [c, a, c, b, z, x, z, y]
|
| 2363 |
+
|
| 2364 |
+
def enum_triangle2(
|
| 2365 |
+
self, points: list[Point]
|
| 2366 |
+
) -> Generator[list[Point], None, None]:
|
| 2367 |
+
a, b, c, x, y, z = points
|
| 2368 |
+
yield [a, b, a, c, x, z, x, y]
|
| 2369 |
+
yield [b, a, b, c, y, z, y, x]
|
| 2370 |
+
yield [c, a, c, b, z, y, z, x]
|
| 2371 |
+
|
| 2372 |
+
def add_simtri(
|
| 2373 |
+
self, points: list[Point], deps: EmptyDependency
|
| 2374 |
+
) -> list[Dependency]:
|
| 2375 |
+
"""Add two similar triangles."""
|
| 2376 |
+
add = []
|
| 2377 |
+
hashs = [d.hashed() for d in deps.why]
|
| 2378 |
+
|
| 2379 |
+
for args in self.enum_triangle(points):
|
| 2380 |
+
if problem.hashed('eqangle6', args) in hashs:
|
| 2381 |
+
continue
|
| 2382 |
+
add += self.add_eqangle(args, deps=deps)
|
| 2383 |
+
|
| 2384 |
+
for args in self.enum_triangle(points):
|
| 2385 |
+
if problem.hashed('eqratio6', args) in hashs:
|
| 2386 |
+
continue
|
| 2387 |
+
add += self.add_eqratio(args, deps=deps)
|
| 2388 |
+
|
| 2389 |
+
return add
|
| 2390 |
+
|
| 2391 |
+
def check_simtri(self, points: list[Point]) -> bool:
|
| 2392 |
+
a, b, c, x, y, z = points
|
| 2393 |
+
return self.check_eqangle([a, b, a, c, x, y, x, z]) and self.check_eqangle(
|
| 2394 |
+
[b, a, b, c, y, x, y, z]
|
| 2395 |
+
)
|
| 2396 |
+
|
| 2397 |
+
def add_simtri2(
|
| 2398 |
+
self, points: list[Point], deps: EmptyDependency
|
| 2399 |
+
) -> list[Dependency]:
|
| 2400 |
+
"""Add two similar reflected triangles."""
|
| 2401 |
+
add = []
|
| 2402 |
+
hashs = [d.hashed() for d in deps.why]
|
| 2403 |
+
for args in self.enum_triangle2(points):
|
| 2404 |
+
if problem.hashed('eqangle6', args) in hashs:
|
| 2405 |
+
continue
|
| 2406 |
+
add += self.add_eqangle(args, deps=deps)
|
| 2407 |
+
|
| 2408 |
+
for args in self.enum_triangle(points):
|
| 2409 |
+
if problem.hashed('eqratio6', args) in hashs:
|
| 2410 |
+
continue
|
| 2411 |
+
add += self.add_eqratio(args, deps=deps)
|
| 2412 |
+
|
| 2413 |
+
return add
|
| 2414 |
+
|
| 2415 |
+
def add_contri(
|
| 2416 |
+
self, points: list[Point], deps: EmptyDependency
|
| 2417 |
+
) -> list[Dependency]:
|
| 2418 |
+
"""Add two congruent triangles."""
|
| 2419 |
+
add = []
|
| 2420 |
+
hashs = [d.hashed() for d in deps.why]
|
| 2421 |
+
for args in self.enum_triangle(points):
|
| 2422 |
+
if problem.hashed('eqangle6', args) in hashs:
|
| 2423 |
+
continue
|
| 2424 |
+
add += self.add_eqangle(args, deps=deps)
|
| 2425 |
+
|
| 2426 |
+
for args in self.enum_sides(points):
|
| 2427 |
+
if problem.hashed('cong', args) in hashs:
|
| 2428 |
+
continue
|
| 2429 |
+
add += self.add_cong(args, deps=deps)
|
| 2430 |
+
return add
|
| 2431 |
+
|
| 2432 |
+
def check_contri(self, points: list[Point]) -> bool:
|
| 2433 |
+
a, b, c, x, y, z = points
|
| 2434 |
+
return (
|
| 2435 |
+
self.check_cong([a, b, x, y])
|
| 2436 |
+
and self.check_cong([b, c, y, z])
|
| 2437 |
+
and self.check_cong([c, a, z, x])
|
| 2438 |
+
)
|
| 2439 |
+
|
| 2440 |
+
def add_contri2(
|
| 2441 |
+
self, points: list[Point], deps: EmptyDependency
|
| 2442 |
+
) -> list[Dependency]:
|
| 2443 |
+
"""Add two congruent reflected triangles."""
|
| 2444 |
+
add = []
|
| 2445 |
+
hashs = [d.hashed() for d in deps.why]
|
| 2446 |
+
for args in self.enum_triangle2(points):
|
| 2447 |
+
if problem.hashed('eqangle6', args) in hashs:
|
| 2448 |
+
continue
|
| 2449 |
+
add += self.add_eqangle(args, deps=deps)
|
| 2450 |
+
|
| 2451 |
+
for args in self.enum_sides(points):
|
| 2452 |
+
if problem.hashed('cong', args) in hashs:
|
| 2453 |
+
continue
|
| 2454 |
+
add += self.add_cong(args, deps=deps)
|
| 2455 |
+
|
| 2456 |
+
return add
|
| 2457 |
+
|
| 2458 |
+
def in_cache(self, name: str, args: list[Point]) -> bool:
|
| 2459 |
+
return problem.hashed(name, args) in self.cache
|
| 2460 |
+
|
| 2461 |
+
def cache_dep(
|
| 2462 |
+
self, name: str, args: list[Point], premises: list[Dependency]
|
| 2463 |
+
) -> None:
|
| 2464 |
+
hashed = problem.hashed(name, args)
|
| 2465 |
+
if hashed in self.cache:
|
| 2466 |
+
return
|
| 2467 |
+
self.cache[hashed] = premises
|
| 2468 |
+
|
| 2469 |
+
def all_same_line(
|
| 2470 |
+
self, a: Point, b: Point
|
| 2471 |
+
) -> Generator[tuple[Point, Point], None, None]:
|
| 2472 |
+
ab = self._get_line(a, b)
|
| 2473 |
+
if ab is None:
|
| 2474 |
+
return
|
| 2475 |
+
for p1, p2 in utils.comb2(ab.neighbors(Point)):
|
| 2476 |
+
if {p1, p2} != {a, b}:
|
| 2477 |
+
yield p1, p2
|
| 2478 |
+
|
| 2479 |
+
def all_same_angle(
|
| 2480 |
+
self, a: Point, b: Point, c: Point, d: Point
|
| 2481 |
+
) -> Generator[tuple[Point, Point, Point, Point], None, None]:
|
| 2482 |
+
for x, y in self.all_same_line(a, b):
|
| 2483 |
+
for m, n in self.all_same_line(c, d):
|
| 2484 |
+
yield x, y, m, n
|
| 2485 |
+
|
| 2486 |
+
def additionally_draw(self, name: str, args: list[Point]) -> None:
|
| 2487 |
+
"""Draw some extra line/circles for illustration purpose."""
|
| 2488 |
+
|
| 2489 |
+
if name in ['circle']:
|
| 2490 |
+
center, point = args[:2]
|
| 2491 |
+
circle = self.new_node(Circle, f'({center.name},{point.name})')
|
| 2492 |
+
circle.num = nm.Circle(center.num, p1=point.num)
|
| 2493 |
+
circle.points = center, point
|
| 2494 |
+
|
| 2495 |
+
if name in ['on_circle', 'tangent']:
|
| 2496 |
+
center, point = args[-2:]
|
| 2497 |
+
circle = self.new_node(Circle, f'({center.name},{point.name})')
|
| 2498 |
+
circle.num = nm.Circle(center.num, p1=point.num)
|
| 2499 |
+
circle.points = center, point
|
| 2500 |
+
|
| 2501 |
+
if name in ['incenter', 'excenter', 'incenter2', 'excenter2']:
|
| 2502 |
+
d, a, b, c = [x for x in args[-4:]]
|
| 2503 |
+
a, b, c = sorted([a, b, c], key=lambda x: x.name.lower())
|
| 2504 |
+
circle = self.new_node(Circle, f'({d.name},h.{a.name}{b.name})')
|
| 2505 |
+
p = d.num.foot(nm.Line(a.num, b.num))
|
| 2506 |
+
circle.num = nm.Circle(d.num, p1=p)
|
| 2507 |
+
circle.points = d, a, b, c
|
| 2508 |
+
|
| 2509 |
+
if name in ['cc_tangent']:
|
| 2510 |
+
o, a, w, b = args[-4:]
|
| 2511 |
+
c1 = self.new_node(Circle, f'({o.name},{a.name})')
|
| 2512 |
+
c1.num = nm.Circle(o.num, p1=a.num)
|
| 2513 |
+
c1.points = o, a
|
| 2514 |
+
|
| 2515 |
+
c2 = self.new_node(Circle, f'({w.name},{b.name})')
|
| 2516 |
+
c2.num = nm.Circle(w.num, p1=b.num)
|
| 2517 |
+
c2.points = w, b
|
| 2518 |
+
|
| 2519 |
+
if name in ['ninepoints']:
|
| 2520 |
+
a, b, c = args[-3:]
|
| 2521 |
+
a, b, c = sorted([a, b, c], key=lambda x: x.name.lower())
|
| 2522 |
+
circle = self.new_node(Circle, f'(,m.{a.name}{b.name}{c.name})')
|
| 2523 |
+
p1 = (b.num + c.num) * 0.5
|
| 2524 |
+
p2 = (c.num + a.num) * 0.5
|
| 2525 |
+
p3 = (a.num + b.num) * 0.5
|
| 2526 |
+
circle.num = nm.Circle(p1=p1, p2=p2, p3=p3)
|
| 2527 |
+
circle.points = (None, None, a, b, c)
|
| 2528 |
+
|
| 2529 |
+
if name in ['2l1c']:
|
| 2530 |
+
a, b, c, o = args[:4]
|
| 2531 |
+
a, b, c = sorted([a, b, c], key=lambda x: x.name.lower())
|
| 2532 |
+
circle = self.new_node(Circle, f'({o.name},{a.name}{b.name}{c.name})')
|
| 2533 |
+
circle.num = nm.Circle(p1=a.num, p2=b.num, p3=c.num)
|
| 2534 |
+
circle.points = (a, b, c)
|
| 2535 |
+
|
| 2536 |
+
def add_clause(
|
| 2537 |
+
self,
|
| 2538 |
+
clause: problem.Clause,
|
| 2539 |
+
plevel: int,
|
| 2540 |
+
definitions: dict[str, problem.Definition],
|
| 2541 |
+
verbose: int = False,
|
| 2542 |
+
) -> tuple[list[Dependency], int]:
|
| 2543 |
+
"""Add a new clause of construction, e.g. a new excenter."""
|
| 2544 |
+
existing_points = self.all_points()
|
| 2545 |
+
new_points = [Point(name) for name in clause.points]
|
| 2546 |
+
|
| 2547 |
+
new_points_dep_points = set()
|
| 2548 |
+
new_points_dep = []
|
| 2549 |
+
|
| 2550 |
+
# Step 1: check for all deps.
|
| 2551 |
+
for c in clause.constructions:
|
| 2552 |
+
cdef = definitions[c.name]
|
| 2553 |
+
|
| 2554 |
+
if len(cdef.construction.args) != len(c.args):
|
| 2555 |
+
if len(cdef.construction.args) - len(c.args) == len(clause.points):
|
| 2556 |
+
c.args = clause.points + c.args
|
| 2557 |
+
else:
|
| 2558 |
+
correct_form = ' '.join(cdef.points + ['=', c.name] + cdef.args)
|
| 2559 |
+
raise ValueError('Argument mismatch. ' + correct_form)
|
| 2560 |
+
|
| 2561 |
+
mapping = dict(zip(cdef.construction.args, c.args))
|
| 2562 |
+
c_name = 'midp' if c.name == 'midpoint' else c.name
|
| 2563 |
+
deps = EmptyDependency(level=0, rule_name=problem.CONSTRUCTION_RULE)
|
| 2564 |
+
deps.construction = Dependency(c_name, c.args, rule_name=None, level=0)
|
| 2565 |
+
|
| 2566 |
+
for d in cdef.deps.constructions:
|
| 2567 |
+
args = self.names2points([mapping[a] for a in d.args])
|
| 2568 |
+
new_points_dep_points.update(args)
|
| 2569 |
+
if not self.check(d.name, args):
|
| 2570 |
+
raise DepCheckFailError(
|
| 2571 |
+
d.name + ' ' + ' '.join([x.name for x in args])
|
| 2572 |
+
)
|
| 2573 |
+
deps.why += [
|
| 2574 |
+
Dependency(
|
| 2575 |
+
d.name, args, rule_name=problem.CONSTRUCTION_RULE, level=0
|
| 2576 |
+
)
|
| 2577 |
+
]
|
| 2578 |
+
|
| 2579 |
+
new_points_dep += [deps]
|
| 2580 |
+
|
| 2581 |
+
# Step 2: draw.
|
| 2582 |
+
def range_fn() -> (
|
| 2583 |
+
list[Union[nm.Point, nm.Line, nm.Circle, nm.HalfLine, nm.HoleCircle]]
|
| 2584 |
+
):
|
| 2585 |
+
to_be_intersected = []
|
| 2586 |
+
for c in clause.constructions:
|
| 2587 |
+
cdef = definitions[c.name]
|
| 2588 |
+
mapping = dict(zip(cdef.construction.args, c.args))
|
| 2589 |
+
for n in cdef.numerics:
|
| 2590 |
+
args = [mapping[a] for a in n.args]
|
| 2591 |
+
args = list(map(lambda x: self.get(x, lambda: int(x)), args))
|
| 2592 |
+
to_be_intersected += nm.sketch(n.name, args)
|
| 2593 |
+
|
| 2594 |
+
return to_be_intersected
|
| 2595 |
+
|
| 2596 |
+
is_total_free = (
|
| 2597 |
+
len(clause.constructions) == 1 and clause.constructions[0].name in FREE
|
| 2598 |
+
)
|
| 2599 |
+
is_semi_free = (
|
| 2600 |
+
len(clause.constructions) == 1
|
| 2601 |
+
and clause.constructions[0].name in INTERSECT
|
| 2602 |
+
)
|
| 2603 |
+
|
| 2604 |
+
existing_points = [p.num for p in existing_points]
|
| 2605 |
+
|
| 2606 |
+
def draw_fn() -> list[nm.Point]:
|
| 2607 |
+
to_be_intersected = range_fn()
|
| 2608 |
+
return nm.reduce(to_be_intersected, existing_points)
|
| 2609 |
+
|
| 2610 |
+
rely_on = set()
|
| 2611 |
+
for c in clause.constructions:
|
| 2612 |
+
cdef = definitions[c.name]
|
| 2613 |
+
mapping = dict(zip(cdef.construction.args, c.args))
|
| 2614 |
+
for n in cdef.numerics:
|
| 2615 |
+
args = [mapping[a] for a in n.args]
|
| 2616 |
+
args = list(map(lambda x: self.get(x, lambda: int(x)), args))
|
| 2617 |
+
rely_on.update([a for a in args if isinstance(a, Point)])
|
| 2618 |
+
|
| 2619 |
+
for p in rely_on:
|
| 2620 |
+
p.change.update(new_points)
|
| 2621 |
+
|
| 2622 |
+
nums = draw_fn()
|
| 2623 |
+
for p, num, num0 in zip(new_points, nums, clause.nums):
|
| 2624 |
+
p.co_change = new_points
|
| 2625 |
+
if isinstance(num0, nm.Point):
|
| 2626 |
+
num = num0
|
| 2627 |
+
elif isinstance(num0, (tuple, list)):
|
| 2628 |
+
x, y = num0
|
| 2629 |
+
num = nm.Point(x, y)
|
| 2630 |
+
|
| 2631 |
+
p.num = num
|
| 2632 |
+
|
| 2633 |
+
# check two things.
|
| 2634 |
+
if nm.check_too_close(nums, existing_points):
|
| 2635 |
+
raise PointTooCloseError()
|
| 2636 |
+
if nm.check_too_far(nums, existing_points):
|
| 2637 |
+
raise PointTooFarError()
|
| 2638 |
+
|
| 2639 |
+
# Commit: now that all conditions are passed.
|
| 2640 |
+
# add these points to current graph.
|
| 2641 |
+
for p in new_points:
|
| 2642 |
+
self._name2point[p.name] = p
|
| 2643 |
+
self._name2node[p.name] = p
|
| 2644 |
+
self.type2nodes[Point].append(p)
|
| 2645 |
+
|
| 2646 |
+
for p in new_points:
|
| 2647 |
+
p.why = sum([d.why for d in new_points_dep], []) # to generate txt logs.
|
| 2648 |
+
p.group = new_points
|
| 2649 |
+
p.dep_points = new_points_dep_points
|
| 2650 |
+
p.dep_points.update(new_points)
|
| 2651 |
+
p.plevel = plevel
|
| 2652 |
+
|
| 2653 |
+
# movement dependency:
|
| 2654 |
+
rely_dict_0 = defaultdict(lambda: [])
|
| 2655 |
+
|
| 2656 |
+
for c in clause.constructions:
|
| 2657 |
+
cdef = definitions[c.name]
|
| 2658 |
+
mapping = dict(zip(cdef.construction.args, c.args))
|
| 2659 |
+
for p, ps in cdef.rely.items():
|
| 2660 |
+
p = mapping[p]
|
| 2661 |
+
ps = [mapping[x] for x in ps]
|
| 2662 |
+
rely_dict_0[p].append(ps)
|
| 2663 |
+
|
| 2664 |
+
rely_dict = {}
|
| 2665 |
+
for p, pss in rely_dict_0.items():
|
| 2666 |
+
ps = sum(pss, [])
|
| 2667 |
+
if len(pss) > 1:
|
| 2668 |
+
ps = [x for x in ps if x != p]
|
| 2669 |
+
|
| 2670 |
+
p = self._name2point[p]
|
| 2671 |
+
ps = self.names2nodes(ps)
|
| 2672 |
+
rely_dict[p] = ps
|
| 2673 |
+
|
| 2674 |
+
for p in new_points:
|
| 2675 |
+
p.rely_on = set(rely_dict.get(p, []))
|
| 2676 |
+
for x in p.rely_on:
|
| 2677 |
+
if not hasattr(x, 'base_rely_on'):
|
| 2678 |
+
x.base_rely_on = set()
|
| 2679 |
+
p.base_rely_on = set.union(*[x.base_rely_on for x in p.rely_on] + [set()])
|
| 2680 |
+
if is_total_free or is_semi_free:
|
| 2681 |
+
p.rely_on.add(p)
|
| 2682 |
+
p.base_rely_on.add(p)
|
| 2683 |
+
|
| 2684 |
+
plevel_done = set()
|
| 2685 |
+
added = []
|
| 2686 |
+
basics = []
|
| 2687 |
+
# Step 3: build the basics.
|
| 2688 |
+
for c, deps in zip(clause.constructions, new_points_dep):
|
| 2689 |
+
cdef = definitions[c.name]
|
| 2690 |
+
mapping = dict(zip(cdef.construction.args, c.args))
|
| 2691 |
+
|
| 2692 |
+
# not necessary for proofing, but for visualization.
|
| 2693 |
+
c_args = list(map(lambda x: self.get(x, lambda: int(x)), c.args))
|
| 2694 |
+
self.additionally_draw(c.name, c_args)
|
| 2695 |
+
|
| 2696 |
+
for points, bs in cdef.basics:
|
| 2697 |
+
if points:
|
| 2698 |
+
points = self.names2nodes([mapping[p] for p in points])
|
| 2699 |
+
points = [p for p in points if p not in plevel_done]
|
| 2700 |
+
for p in points:
|
| 2701 |
+
p.plevel = plevel
|
| 2702 |
+
plevel_done.update(points)
|
| 2703 |
+
plevel += 1
|
| 2704 |
+
else:
|
| 2705 |
+
continue
|
| 2706 |
+
|
| 2707 |
+
for b in bs:
|
| 2708 |
+
if b.name != 'rconst':
|
| 2709 |
+
args = [mapping[a] for a in b.args]
|
| 2710 |
+
else:
|
| 2711 |
+
num, den = map(int, b.args[-2:])
|
| 2712 |
+
rat, _ = self.get_or_create_const_rat(num, den)
|
| 2713 |
+
args = [mapping[a] for a in b.args[:-2]] + [rat.name]
|
| 2714 |
+
|
| 2715 |
+
args = list(map(lambda x: self.get(x, lambda: int(x)), args))
|
| 2716 |
+
|
| 2717 |
+
adds = self.add_piece(name=b.name, args=args, deps=deps)
|
| 2718 |
+
basics.append((b.name, args, deps))
|
| 2719 |
+
if adds:
|
| 2720 |
+
added += adds
|
| 2721 |
+
for add in adds:
|
| 2722 |
+
self.cache_dep(add.name, add.args, add)
|
| 2723 |
+
|
| 2724 |
+
assert len(plevel_done) == len(new_points)
|
| 2725 |
+
for p in new_points:
|
| 2726 |
+
p.basics = basics
|
| 2727 |
+
|
| 2728 |
+
return added, plevel
|
| 2729 |
+
|
| 2730 |
+
def all_eqangle_same_lines(self) -> Generator[tuple[Point, ...], None, None]:
|
| 2731 |
+
for l1, l2 in utils.perm2(self.type2nodes[Line]):
|
| 2732 |
+
for a, b, c, d, e, f, g, h in utils.all_8points(l1, l2, l1, l2):
|
| 2733 |
+
if (a, b, c, d) != (e, f, g, h):
|
| 2734 |
+
yield a, b, c, d, e, f, g, h
|
| 2735 |
+
|
| 2736 |
+
def all_eqangles_distinct_linepairss(
|
| 2737 |
+
self,
|
| 2738 |
+
) -> Generator[tuple[Line, ...], None, None]:
|
| 2739 |
+
"""No eqangles betcause para-para, or para-corresponding, or same."""
|
| 2740 |
+
|
| 2741 |
+
for measure in self.type2nodes[Measure]:
|
| 2742 |
+
angs = measure.neighbors(Angle)
|
| 2743 |
+
line_pairss = []
|
| 2744 |
+
for ang in angs:
|
| 2745 |
+
d1, d2 = ang.directions
|
| 2746 |
+
if d1 is None or d2 is None:
|
| 2747 |
+
continue
|
| 2748 |
+
l1s = d1.neighbors(Line)
|
| 2749 |
+
l2s = d2.neighbors(Line)
|
| 2750 |
+
# Any pair in this is para-para.
|
| 2751 |
+
para_para = list(utils.cross(l1s, l2s))
|
| 2752 |
+
line_pairss.append(para_para)
|
| 2753 |
+
|
| 2754 |
+
for pairs1, pairs2 in utils.comb2(line_pairss):
|
| 2755 |
+
for pair1, pair2 in utils.cross(pairs1, pairs2):
|
| 2756 |
+
(l1, l2), (l3, l4) = pair1, pair2
|
| 2757 |
+
yield l1, l2, l3, l4
|
| 2758 |
+
|
| 2759 |
+
def all_eqangles_8points(self) -> Generator[tuple[Point, ...], None, None]:
|
| 2760 |
+
"""List all sets of 8 points that make two equal angles."""
|
| 2761 |
+
# Case 1: (l1-l2) = (l3-l4), including because l1//l3, l2//l4 (para-para)
|
| 2762 |
+
angss = []
|
| 2763 |
+
for measure in self.type2nodes[Measure]:
|
| 2764 |
+
angs = measure.neighbors(Angle)
|
| 2765 |
+
angss.append(angs)
|
| 2766 |
+
|
| 2767 |
+
# include the angs that do not have any measure.
|
| 2768 |
+
angss.extend([[ang] for ang in self.type2nodes[Angle] if ang.val is None])
|
| 2769 |
+
|
| 2770 |
+
line_pairss = []
|
| 2771 |
+
for angs in angss:
|
| 2772 |
+
line_pairs = set()
|
| 2773 |
+
for ang in angs:
|
| 2774 |
+
d1, d2 = ang.directions
|
| 2775 |
+
if d1 is None or d2 is None:
|
| 2776 |
+
continue
|
| 2777 |
+
l1s = d1.neighbors(Line)
|
| 2778 |
+
l2s = d2.neighbors(Line)
|
| 2779 |
+
line_pairs.update(set(utils.cross(l1s, l2s)))
|
| 2780 |
+
line_pairss.append(line_pairs)
|
| 2781 |
+
|
| 2782 |
+
# include (d1, d2) in which d1 does not have any angles.
|
| 2783 |
+
noang_ds = [d for d in self.type2nodes[Direction] if not d.neighbors(Angle)]
|
| 2784 |
+
|
| 2785 |
+
for d1 in noang_ds:
|
| 2786 |
+
for d2 in self.type2nodes[Direction]:
|
| 2787 |
+
if d1 == d2:
|
| 2788 |
+
continue
|
| 2789 |
+
l1s = d1.neighbors(Line)
|
| 2790 |
+
l2s = d2.neighbors(Line)
|
| 2791 |
+
if len(l1s) < 2 and len(l2s) < 2:
|
| 2792 |
+
continue
|
| 2793 |
+
line_pairss.append(set(utils.cross(l1s, l2s)))
|
| 2794 |
+
line_pairss.append(set(utils.cross(l2s, l1s)))
|
| 2795 |
+
|
| 2796 |
+
# Case 2: d1 // d2 => (d1-d3) = (d2-d3)
|
| 2797 |
+
# include lines that does not have any direction.
|
| 2798 |
+
nodir_ls = [l for l in self.type2nodes[Line] if l.val is None]
|
| 2799 |
+
|
| 2800 |
+
for line in nodir_ls:
|
| 2801 |
+
for d in self.type2nodes[Direction]:
|
| 2802 |
+
l1s = d.neighbors(Line)
|
| 2803 |
+
if len(l1s) < 2:
|
| 2804 |
+
continue
|
| 2805 |
+
l2s = [line]
|
| 2806 |
+
line_pairss.append(set(utils.cross(l1s, l2s)))
|
| 2807 |
+
line_pairss.append(set(utils.cross(l2s, l1s)))
|
| 2808 |
+
|
| 2809 |
+
record = set()
|
| 2810 |
+
for line_pairs in line_pairss:
|
| 2811 |
+
for pair1, pair2 in utils.perm2(list(line_pairs)):
|
| 2812 |
+
(l1, l2), (l3, l4) = pair1, pair2
|
| 2813 |
+
if l1 == l2 or l3 == l4:
|
| 2814 |
+
continue
|
| 2815 |
+
if (l1, l2) == (l3, l4):
|
| 2816 |
+
continue
|
| 2817 |
+
if (l1, l2, l3, l4) in record:
|
| 2818 |
+
continue
|
| 2819 |
+
record.add((l1, l2, l3, l4))
|
| 2820 |
+
for a, b, c, d, e, f, g, h in utils.all_8points(l1, l2, l3, l4):
|
| 2821 |
+
yield (a, b, c, d, e, f, g, h)
|
| 2822 |
+
|
| 2823 |
+
for a, b, c, d, e, f, g, h in self.all_eqangle_same_lines():
|
| 2824 |
+
yield a, b, c, d, e, f, g, h
|
| 2825 |
+
|
| 2826 |
+
def all_eqangles_6points(self) -> Generator[tuple[Point, ...], None, None]:
|
| 2827 |
+
"""List all sets of 6 points that make two equal angles."""
|
| 2828 |
+
record = set()
|
| 2829 |
+
for a, b, c, d, e, f, g, h in self.all_eqangles_8points():
|
| 2830 |
+
if (
|
| 2831 |
+
a not in (c, d)
|
| 2832 |
+
and b not in (c, d)
|
| 2833 |
+
or e not in (g, h)
|
| 2834 |
+
and f not in (g, h)
|
| 2835 |
+
):
|
| 2836 |
+
continue
|
| 2837 |
+
|
| 2838 |
+
if b in (c, d):
|
| 2839 |
+
a, b = b, a # now a in c, d
|
| 2840 |
+
if f in (g, h):
|
| 2841 |
+
e, f = f, e # now e in g, h
|
| 2842 |
+
if a == d:
|
| 2843 |
+
c, d = d, c # now a == c
|
| 2844 |
+
if e == h:
|
| 2845 |
+
g, h = h, g # now e == g
|
| 2846 |
+
if (a, b, c, d, e, f, g, h) in record:
|
| 2847 |
+
continue
|
| 2848 |
+
record.add((a, b, c, d, e, f, g, h))
|
| 2849 |
+
yield a, b, c, d, e, f, g, h # where a==c, e==g
|
| 2850 |
+
|
| 2851 |
+
def all_paras(self) -> Generator[tuple[Point, ...], None, None]:
|
| 2852 |
+
for d in self.type2nodes[Direction]:
|
| 2853 |
+
for l1, l2 in utils.perm2(d.neighbors(Line)):
|
| 2854 |
+
for a, b, c, d in utils.all_4points(l1, l2):
|
| 2855 |
+
yield a, b, c, d
|
| 2856 |
+
|
| 2857 |
+
def all_perps(self) -> Generator[tuple[Point, ...], None, None]:
|
| 2858 |
+
for ang in self.vhalfpi.neighbors(Angle):
|
| 2859 |
+
d1, d2 = ang.directions
|
| 2860 |
+
if d1 is None or d2 is None:
|
| 2861 |
+
continue
|
| 2862 |
+
if d1 == d2:
|
| 2863 |
+
continue
|
| 2864 |
+
for l1, l2 in utils.cross(d1.neighbors(Line), d2.neighbors(Line)):
|
| 2865 |
+
for a, b, c, d in utils.all_4points(l1, l2):
|
| 2866 |
+
yield a, b, c, d
|
| 2867 |
+
|
| 2868 |
+
def all_congs(self) -> Generator[tuple[Point, ...], None, None]:
|
| 2869 |
+
for l in self.type2nodes[Length]:
|
| 2870 |
+
for s1, s2 in utils.perm2(l.neighbors(Segment)):
|
| 2871 |
+
(a, b), (c, d) = s1.points, s2.points
|
| 2872 |
+
for x, y in [(a, b), (b, a)]:
|
| 2873 |
+
for m, n in [(c, d), (d, c)]:
|
| 2874 |
+
yield x, y, m, n
|
| 2875 |
+
|
| 2876 |
+
def all_eqratios_8points(self) -> Generator[tuple[Point, ...], None, None]:
|
| 2877 |
+
"""List all sets of 8 points that make two equal ratios."""
|
| 2878 |
+
ratss = []
|
| 2879 |
+
for value in self.type2nodes[Value]:
|
| 2880 |
+
rats = value.neighbors(Ratio)
|
| 2881 |
+
ratss.append(rats)
|
| 2882 |
+
|
| 2883 |
+
# include the rats that do not have any val.
|
| 2884 |
+
ratss.extend([[rat] for rat in self.type2nodes[Ratio] if rat.val is None])
|
| 2885 |
+
|
| 2886 |
+
seg_pairss = []
|
| 2887 |
+
for rats in ratss:
|
| 2888 |
+
seg_pairs = set()
|
| 2889 |
+
for rat in rats:
|
| 2890 |
+
l1, l2 = rat.lengths
|
| 2891 |
+
if l1 is None or l2 is None:
|
| 2892 |
+
continue
|
| 2893 |
+
s1s = l1.neighbors(Segment)
|
| 2894 |
+
s2s = l2.neighbors(Segment)
|
| 2895 |
+
seg_pairs.update(utils.cross(s1s, s2s))
|
| 2896 |
+
seg_pairss.append(seg_pairs)
|
| 2897 |
+
|
| 2898 |
+
# include (l1, l2) in which l1 does not have any ratio.
|
| 2899 |
+
norat_ls = [l for l in self.type2nodes[Length] if not l.neighbors(Ratio)]
|
| 2900 |
+
|
| 2901 |
+
for l1 in norat_ls:
|
| 2902 |
+
for l2 in self.type2nodes[Length]:
|
| 2903 |
+
if l1 == l2:
|
| 2904 |
+
continue
|
| 2905 |
+
s1s = l1.neighbors(Segment)
|
| 2906 |
+
s2s = l2.neighbors(Segment)
|
| 2907 |
+
if len(s1s) < 2 and len(s2s) < 2:
|
| 2908 |
+
continue
|
| 2909 |
+
seg_pairss.append(set(utils.cross(s1s, s2s)))
|
| 2910 |
+
seg_pairss.append(set(utils.cross(s2s, s1s)))
|
| 2911 |
+
|
| 2912 |
+
# include Seg that does not have any Length.
|
| 2913 |
+
nolen_ss = [s for s in self.type2nodes[Segment] if s.val is None]
|
| 2914 |
+
|
| 2915 |
+
for seg in nolen_ss:
|
| 2916 |
+
for l in self.type2nodes[Length]:
|
| 2917 |
+
s1s = l.neighbors(Segment)
|
| 2918 |
+
if len(s1s) == 1:
|
| 2919 |
+
continue
|
| 2920 |
+
s2s = [seg]
|
| 2921 |
+
seg_pairss.append(set(utils.cross(s1s, s2s)))
|
| 2922 |
+
seg_pairss.append(set(utils.cross(s2s, s1s)))
|
| 2923 |
+
|
| 2924 |
+
record = set()
|
| 2925 |
+
for seg_pairs in seg_pairss:
|
| 2926 |
+
for pair1, pair2 in utils.perm2(list(seg_pairs)):
|
| 2927 |
+
(s1, s2), (s3, s4) = pair1, pair2
|
| 2928 |
+
if s1 == s2 or s3 == s4:
|
| 2929 |
+
continue
|
| 2930 |
+
if (s1, s2) == (s3, s4):
|
| 2931 |
+
continue
|
| 2932 |
+
if (s1, s2, s3, s4) in record:
|
| 2933 |
+
continue
|
| 2934 |
+
record.add((s1, s2, s3, s4))
|
| 2935 |
+
a, b = s1.points
|
| 2936 |
+
c, d = s2.points
|
| 2937 |
+
e, f = s3.points
|
| 2938 |
+
g, h = s4.points
|
| 2939 |
+
|
| 2940 |
+
for x, y in [(a, b), (b, a)]:
|
| 2941 |
+
for z, t in [(c, d), (d, c)]:
|
| 2942 |
+
for m, n in [(e, f), (f, e)]:
|
| 2943 |
+
for p, q in [(g, h), (h, g)]:
|
| 2944 |
+
yield (x, y, z, t, m, n, p, q)
|
| 2945 |
+
|
| 2946 |
+
segss = []
|
| 2947 |
+
# finally the list of ratios that is equal to 1.0
|
| 2948 |
+
for length in self.type2nodes[Length]:
|
| 2949 |
+
segs = length.neighbors(Segment)
|
| 2950 |
+
segss.append(segs)
|
| 2951 |
+
|
| 2952 |
+
segs_pair = list(utils.perm2(list(segss)))
|
| 2953 |
+
segs_pair += list(zip(segss, segss))
|
| 2954 |
+
for segs1, segs2 in segs_pair:
|
| 2955 |
+
for s1, s2 in utils.perm2(list(segs1)):
|
| 2956 |
+
for s3, s4 in utils.perm2(list(segs2)):
|
| 2957 |
+
if (s1, s2) == (s3, s4) or (s1, s3) == (s2, s4):
|
| 2958 |
+
continue
|
| 2959 |
+
if (s1, s2, s3, s4) in record:
|
| 2960 |
+
continue
|
| 2961 |
+
record.add((s1, s2, s3, s4))
|
| 2962 |
+
a, b = s1.points
|
| 2963 |
+
c, d = s2.points
|
| 2964 |
+
e, f = s3.points
|
| 2965 |
+
g, h = s4.points
|
| 2966 |
+
|
| 2967 |
+
for x, y in [(a, b), (b, a)]:
|
| 2968 |
+
for z, t in [(c, d), (d, c)]:
|
| 2969 |
+
for m, n in [(e, f), (f, e)]:
|
| 2970 |
+
for p, q in [(g, h), (h, g)]:
|
| 2971 |
+
yield (x, y, z, t, m, n, p, q)
|
| 2972 |
+
|
| 2973 |
+
def all_eqratios_6points(self) -> Generator[tuple[Point, ...], None, None]:
|
| 2974 |
+
"""List all sets of 6 points that make two equal angles."""
|
| 2975 |
+
record = set()
|
| 2976 |
+
for a, b, c, d, e, f, g, h in self.all_eqratios_8points():
|
| 2977 |
+
if (
|
| 2978 |
+
a not in (c, d)
|
| 2979 |
+
and b not in (c, d)
|
| 2980 |
+
or e not in (g, h)
|
| 2981 |
+
and f not in (g, h)
|
| 2982 |
+
):
|
| 2983 |
+
continue
|
| 2984 |
+
if b in (c, d):
|
| 2985 |
+
a, b = b, a
|
| 2986 |
+
if f in (g, h):
|
| 2987 |
+
e, f = f, e
|
| 2988 |
+
if a == d:
|
| 2989 |
+
c, d = d, c
|
| 2990 |
+
if e == h:
|
| 2991 |
+
g, h = h, g
|
| 2992 |
+
if (a, b, c, d, e, f, g, h) in record:
|
| 2993 |
+
continue
|
| 2994 |
+
record.add((a, b, c, d, e, f, g, h))
|
| 2995 |
+
yield a, b, c, d, e, f, g, h # now a==c, e==g
|
| 2996 |
+
|
| 2997 |
+
def all_cyclics(self) -> Generator[tuple[Point, ...], None, None]:
|
| 2998 |
+
for c in self.type2nodes[Circle]:
|
| 2999 |
+
for x, y, z, t in utils.perm4(c.neighbors(Point)):
|
| 3000 |
+
yield x, y, z, t
|
| 3001 |
+
|
| 3002 |
+
def all_colls(self) -> Generator[tuple[Point, ...], None, None]:
|
| 3003 |
+
for l in self.type2nodes[Line]:
|
| 3004 |
+
for x, y, z in utils.perm3(l.neighbors(Point)):
|
| 3005 |
+
yield x, y, z
|
| 3006 |
+
|
| 3007 |
+
def all_midps(self) -> Generator[tuple[Point, ...], None, None]:
|
| 3008 |
+
for l in self.type2nodes[Line]:
|
| 3009 |
+
for a, b, c in utils.perm3(l.neighbors(Point)):
|
| 3010 |
+
if self.check_cong([a, b, a, c]):
|
| 3011 |
+
yield a, b, c
|
| 3012 |
+
|
| 3013 |
+
def all_circles(self) -> Generator[tuple[Point, ...], None, None]:
|
| 3014 |
+
for l in self.type2nodes[Length]:
|
| 3015 |
+
p2p = defaultdict(list)
|
| 3016 |
+
for s in l.neighbors(Segment):
|
| 3017 |
+
a, b = s.points
|
| 3018 |
+
p2p[a].append(b)
|
| 3019 |
+
p2p[b].append(a)
|
| 3020 |
+
for p, ps in p2p.items():
|
| 3021 |
+
if len(ps) >= 3:
|
| 3022 |
+
for a, b, c in utils.perm3(ps):
|
| 3023 |
+
yield p, a, b, c
|
| 3024 |
+
|
| 3025 |
+
def two_points_on_direction(self, d: Direction) -> tuple[Point, Point]:
|
| 3026 |
+
l = d.neighbors(Line)[0]
|
| 3027 |
+
p1, p2 = l.neighbors(Point)[:2]
|
| 3028 |
+
return p1, p2
|
| 3029 |
+
|
| 3030 |
+
def two_points_of_length(self, l: Length) -> tuple[Point, Point]:
|
| 3031 |
+
s = l.neighbors(Segment)[0]
|
| 3032 |
+
p1, p2 = s.points
|
| 3033 |
+
return p1, p2
|
| 3034 |
+
|
| 3035 |
+
|
| 3036 |
+
def create_consts_str(g: Graph, s: str) -> Union[Ratio, Angle]:
|
| 3037 |
+
if 'pi/' in s:
|
| 3038 |
+
n, d = s.split('pi/')
|
| 3039 |
+
n, d = int(n), int(d)
|
| 3040 |
+
p0, _ = g.get_or_create_const_ang(n, d)
|
| 3041 |
+
else:
|
| 3042 |
+
n, d = s.split('/')
|
| 3043 |
+
n, d = int(n), int(d)
|
| 3044 |
+
p0, _ = g.get_or_create_const_rat(n, d)
|
| 3045 |
+
return p0
|
| 3046 |
+
|
| 3047 |
+
|
| 3048 |
+
def create_consts(g: Graph, p: gm.Node) -> Union[Ratio, Angle]:
|
| 3049 |
+
if isinstance(p, Angle):
|
| 3050 |
+
n, d = p.name.split('pi/')
|
| 3051 |
+
n, d = int(n), int(d)
|
| 3052 |
+
p0, _ = g.get_or_create_const_ang(n, d)
|
| 3053 |
+
if isinstance(p, Ratio):
|
| 3054 |
+
n, d = p.name.split('/')
|
| 3055 |
+
n, d = int(n), int(d)
|
| 3056 |
+
p0, _ = g.get_or_create_const_rat(n, d)
|
| 3057 |
+
return p0 # pylint: disable=undefined-variable
|
backend/core/ag4masses/alphageometry/graph_test.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Unit tests for graph.py."""
|
| 17 |
+
import unittest
|
| 18 |
+
|
| 19 |
+
from absl.testing import absltest
|
| 20 |
+
import graph as gh
|
| 21 |
+
import numericals as nm
|
| 22 |
+
import problem as pr
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
MAX_LEVEL = 1000
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class GraphTest(unittest.TestCase):
|
| 29 |
+
|
| 30 |
+
@classmethod
|
| 31 |
+
def setUpClass(cls):
|
| 32 |
+
super().setUpClass()
|
| 33 |
+
|
| 34 |
+
cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)
|
| 35 |
+
cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True)
|
| 36 |
+
|
| 37 |
+
# load a complex setup:
|
| 38 |
+
txt = 'a b c = triangle a b c; h = orthocenter a b c; h1 = foot a b c; h2 = foot b c a; h3 = foot c a b; g1 g2 g3 g = centroid g1 g2 g3 g a b c; o = circle a b c ? coll h g o' # pylint: disable=line-too-long
|
| 39 |
+
p = pr.Problem.from_txt(txt, translate=False)
|
| 40 |
+
cls.g, _ = gh.Graph.build_problem(p, GraphTest.defs)
|
| 41 |
+
|
| 42 |
+
def test_build_graph_points(self):
|
| 43 |
+
g = GraphTest.g
|
| 44 |
+
|
| 45 |
+
all_points = g.all_points()
|
| 46 |
+
all_names = [p.name for p in all_points]
|
| 47 |
+
self.assertCountEqual(
|
| 48 |
+
all_names,
|
| 49 |
+
['a', 'b', 'c', 'g', 'h', 'o', 'g1', 'g2', 'g3', 'h1', 'h2', 'h3'],
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def test_build_graph_predicates(self):
|
| 53 |
+
gr = GraphTest.g
|
| 54 |
+
|
| 55 |
+
a, b, c, g, h, o, g1, g2, g3, h1, h2, h3 = gr.names2points(
|
| 56 |
+
['a', 'b', 'c', 'g', 'h', 'o', 'g1', 'g2', 'g3', 'h1', 'h2', 'h3']
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# Explicit statements:
|
| 60 |
+
self.assertTrue(gr.check_cong([b, g1, g1, c]))
|
| 61 |
+
self.assertTrue(gr.check_cong([c, g2, g2, a]))
|
| 62 |
+
self.assertTrue(gr.check_cong([a, g3, g3, b]))
|
| 63 |
+
self.assertTrue(gr.check_perp([a, h1, b, c]))
|
| 64 |
+
self.assertTrue(gr.check_perp([b, h2, c, a]))
|
| 65 |
+
self.assertTrue(gr.check_perp([c, h3, a, b]))
|
| 66 |
+
self.assertTrue(gr.check_cong([o, a, o, b]))
|
| 67 |
+
self.assertTrue(gr.check_cong([o, b, o, c]))
|
| 68 |
+
self.assertTrue(gr.check_cong([o, a, o, c]))
|
| 69 |
+
self.assertTrue(gr.check_coll([a, g, g1]))
|
| 70 |
+
self.assertTrue(gr.check_coll([b, g, g2]))
|
| 71 |
+
self.assertTrue(gr.check_coll([g1, b, c]))
|
| 72 |
+
self.assertTrue(gr.check_coll([g2, c, a]))
|
| 73 |
+
self.assertTrue(gr.check_coll([g3, a, b]))
|
| 74 |
+
self.assertTrue(gr.check_perp([a, h, b, c]))
|
| 75 |
+
self.assertTrue(gr.check_perp([b, h, c, a]))
|
| 76 |
+
|
| 77 |
+
# These are NOT part of the premises:
|
| 78 |
+
self.assertFalse(gr.check_perp([c, h, a, b]))
|
| 79 |
+
self.assertFalse(gr.check_coll([c, g, g3]))
|
| 80 |
+
|
| 81 |
+
# These are automatically inferred by the graph datastructure:
|
| 82 |
+
self.assertTrue(gr.check_eqangle([a, h1, b, c, b, h2, c, a]))
|
| 83 |
+
self.assertTrue(gr.check_eqangle([a, h1, b, h2, b, c, c, a]))
|
| 84 |
+
self.assertTrue(gr.check_eqratio([b, g1, g1, c, c, g2, g2, a]))
|
| 85 |
+
self.assertTrue(gr.check_eqratio([b, g1, g1, c, o, a, o, b]))
|
| 86 |
+
self.assertTrue(gr.check_para([a, h, a, h1]))
|
| 87 |
+
self.assertTrue(gr.check_para([b, h, b, h2]))
|
| 88 |
+
self.assertTrue(gr.check_coll([a, h, h1]))
|
| 89 |
+
self.assertTrue(gr.check_coll([b, h, h2]))
|
| 90 |
+
|
| 91 |
+
def test_enumerate_colls(self):
|
| 92 |
+
g = GraphTest.g
|
| 93 |
+
|
| 94 |
+
for a, b, c in g.all_colls():
|
| 95 |
+
self.assertTrue(g.check_coll([a, b, c]))
|
| 96 |
+
self.assertTrue(nm.check_coll([a.num, b.num, c.num]))
|
| 97 |
+
|
| 98 |
+
def test_enumerate_paras(self):
|
| 99 |
+
g = GraphTest.g
|
| 100 |
+
|
| 101 |
+
for a, b, c, d in g.all_paras():
|
| 102 |
+
self.assertTrue(g.check_para([a, b, c, d]))
|
| 103 |
+
self.assertTrue(nm.check_para([a.num, b.num, c.num, d.num]))
|
| 104 |
+
|
| 105 |
+
def test_enumerate_perps(self):
|
| 106 |
+
g = GraphTest.g
|
| 107 |
+
|
| 108 |
+
for a, b, c, d in g.all_perps():
|
| 109 |
+
self.assertTrue(g.check_perp([a, b, c, d]))
|
| 110 |
+
self.assertTrue(nm.check_perp([a.num, b.num, c.num, d.num]))
|
| 111 |
+
|
| 112 |
+
def test_enumerate_congs(self):
|
| 113 |
+
g = GraphTest.g
|
| 114 |
+
|
| 115 |
+
for a, b, c, d in g.all_congs():
|
| 116 |
+
self.assertTrue(g.check_cong([a, b, c, d]))
|
| 117 |
+
self.assertTrue(nm.check_cong([a.num, b.num, c.num, d.num]))
|
| 118 |
+
|
| 119 |
+
def test_enumerate_eqangles(self):
|
| 120 |
+
g = GraphTest.g
|
| 121 |
+
|
| 122 |
+
for a, b, c, d, x, y, z, t in g.all_eqangles_8points():
|
| 123 |
+
self.assertTrue(g.check_eqangle([a, b, c, d, x, y, z, t]))
|
| 124 |
+
self.assertTrue(
|
| 125 |
+
nm.check_eqangle(
|
| 126 |
+
[a.num, b.num, c.num, d.num, x.num, y.num, z.num, t.num]
|
| 127 |
+
)
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
def test_enumerate_eqratios(self):
|
| 131 |
+
g = GraphTest.g
|
| 132 |
+
|
| 133 |
+
for a, b, c, d, x, y, z, t in g.all_eqratios_8points():
|
| 134 |
+
self.assertTrue(g.check_eqratio([a, b, c, d, x, y, z, t]))
|
| 135 |
+
self.assertTrue(
|
| 136 |
+
nm.check_eqratio(
|
| 137 |
+
[a.num, b.num, c.num, d.num, x.num, y.num, z.num, t.num]
|
| 138 |
+
)
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
def test_enumerate_cyclics(self):
|
| 142 |
+
g = GraphTest.g
|
| 143 |
+
|
| 144 |
+
for a, b, c, d, x, y, z, t in g.all_cyclics():
|
| 145 |
+
self.assertTrue(g.check_cyclic([a, b, c, d, x, y, z, t]))
|
| 146 |
+
self.assertTrue(nm.check_cyclic([a.num, b.num, c.num, d.num]))
|
| 147 |
+
|
| 148 |
+
def test_enumerate_midps(self):
|
| 149 |
+
g = GraphTest.g
|
| 150 |
+
|
| 151 |
+
for a, b, c in g.all_midps():
|
| 152 |
+
self.assertTrue(g.check_midp([a, b, c]))
|
| 153 |
+
self.assertTrue(nm.check_midp([a.num, b.num, c.num]))
|
| 154 |
+
|
| 155 |
+
def test_enumerate_circles(self):
|
| 156 |
+
g = GraphTest.g
|
| 157 |
+
|
| 158 |
+
for a, b, c, d in g.all_circles():
|
| 159 |
+
self.assertTrue(g.check_circle([a, b, c, d]))
|
| 160 |
+
self.assertTrue(nm.check_circle([a.num, b.num, c.num, d.num]))
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
if __name__ == '__main__':
|
| 164 |
+
absltest.main()
|
backend/core/alphageometry_adapter.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AlphaGeometry Adapter: Run AlphaGeometry proofs from Python with advanced features.
|
| 3 |
+
Features:
|
| 4 |
+
- Async execution, timeouts, resource limits
|
| 5 |
+
- Logging, error handling, compliance
|
| 6 |
+
- Batch/parallel runs, result parsing, provenance
|
| 7 |
+
- Plugin system, benchmarking, test harness
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import subprocess
|
| 11 |
+
import os
|
| 12 |
+
import asyncio
|
| 13 |
+
import concurrent.futures
|
| 14 |
+
import logging
|
| 15 |
+
import time
|
| 16 |
+
from typing import List, Optional, Callable, Dict, Any
|
| 17 |
+
|
| 18 |
+
class AlphaGeometryResult:
|
| 19 |
+
def __init__(self, output: str, success: bool, elapsed: float, provenance: Optional[Dict[str, Any]] = None):
|
| 20 |
+
self.output: str = output
|
| 21 |
+
self.success: bool = success
|
| 22 |
+
self.elapsed: float = elapsed
|
| 23 |
+
self.provenance: Dict[str, Any] = provenance or {}
|
| 24 |
+
|
| 25 |
+
def parse(self) -> Dict[str, Any]:
|
| 26 |
+
# Example: parse output for key results (stub)
|
| 27 |
+
lines: List[str] = self.output.splitlines()
|
| 28 |
+
result: Dict[str, Any] = {"lines": lines, "success": self.success, "elapsed": self.elapsed}
|
| 29 |
+
if any("QED" in l for l in lines):
|
| 30 |
+
result["proved"] = True
|
| 31 |
+
return result
|
| 32 |
+
|
| 33 |
+
def run_alphageometry(
|
| 34 |
+
input_file: str,
|
| 35 |
+
alphageometry_dir: str = "external/alphageometry",
|
| 36 |
+
timeout: int = 60,
|
| 37 |
+
plugins: Optional[List[Callable[[AlphaGeometryResult], None]]] = None
|
| 38 |
+
) -> AlphaGeometryResult:
|
| 39 |
+
"""
|
| 40 |
+
Runs AlphaGeometry on the given input file and returns a structured result.
|
| 41 |
+
"""
|
| 42 |
+
exe_path = os.path.join(alphageometry_dir, "main.py")
|
| 43 |
+
if not os.path.exists(exe_path):
|
| 44 |
+
raise FileNotFoundError(f"AlphaGeometry not found at {exe_path}")
|
| 45 |
+
start = time.time()
|
| 46 |
+
try:
|
| 47 |
+
result = subprocess.run([
|
| 48 |
+
"python", exe_path, input_file
|
| 49 |
+
], capture_output=True, text=True, check=True, timeout=timeout)
|
| 50 |
+
elapsed = time.time() - start
|
| 51 |
+
ag_result = AlphaGeometryResult(result.stdout, True, elapsed)
|
| 52 |
+
except subprocess.TimeoutExpired as e:
|
| 53 |
+
logging.error(f"AlphaGeometry timeout: {e}")
|
| 54 |
+
ag_result = AlphaGeometryResult(f"Timeout: {e}", False, timeout)
|
| 55 |
+
except Exception as e:
|
| 56 |
+
logging.error(f"AlphaGeometry error: {e}", exc_info=True)
|
| 57 |
+
ag_result = AlphaGeometryResult(f"AlphaGeometry error: {e}", False, time.time() - start)
|
| 58 |
+
# Plugin post-processing
|
| 59 |
+
if plugins:
|
| 60 |
+
for plugin in plugins:
|
| 61 |
+
plugin(ag_result)
|
| 62 |
+
return ag_result
|
| 63 |
+
|
| 64 |
+
async def run_alphageometry_async(
|
| 65 |
+
input_file: str,
|
| 66 |
+
alphageometry_dir: str = "external/alphageometry",
|
| 67 |
+
timeout: int = 60
|
| 68 |
+
) -> AlphaGeometryResult:
|
| 69 |
+
loop = asyncio.get_event_loop()
|
| 70 |
+
with concurrent.futures.ThreadPoolExecutor() as pool:
|
| 71 |
+
return await loop.run_in_executor(pool, run_alphageometry, input_file, alphageometry_dir, timeout)
|
| 72 |
+
|
| 73 |
+
def run_alphageometry_batch(
|
| 74 |
+
input_files: List[str],
|
| 75 |
+
alphageometry_dir: str = "external/alphageometry",
|
| 76 |
+
timeout: int = 60,
|
| 77 |
+
parallel: int = 4
|
| 78 |
+
) -> List[AlphaGeometryResult]:
|
| 79 |
+
"""Run AlphaGeometry on a batch of input files in parallel."""
|
| 80 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=parallel) as executor:
|
| 81 |
+
futures: List[concurrent.futures.Future[AlphaGeometryResult]] = [executor.submit(run_alphageometry, f, alphageometry_dir, timeout) for f in input_files]
|
| 82 |
+
return [f.result() for f in futures]
|
| 83 |
+
|
| 84 |
+
def benchmark_alphageometry(
|
| 85 |
+
input_file: str,
|
| 86 |
+
alphageometry_dir: str = "external/alphageometry",
|
| 87 |
+
n_iter: int = 5
|
| 88 |
+
) -> None:
|
| 89 |
+
times: List[float] = []
|
| 90 |
+
for _ in range(n_iter):
|
| 91 |
+
start = time.time()
|
| 92 |
+
_ = run_alphageometry(input_file, alphageometry_dir)
|
| 93 |
+
times.append(float(time.time() - start))
|
| 94 |
+
if times:
|
| 95 |
+
mean: float = sum(times) / len(times)
|
| 96 |
+
std: float = float((sum((t - mean) ** 2 for t in times) / len(times)) ** 0.5)
|
| 97 |
+
print(f"[Benchmark] Mean: {mean:.4f}s, Std: {std:.4f}s")
|
| 98 |
+
else:
|
| 99 |
+
print("[Benchmark] No runs completed.")
|
| 100 |
+
|
| 101 |
+
# --- Plugin Example ---
|
| 102 |
+
class QEDPlugin:
|
| 103 |
+
def __call__(self, result: AlphaGeometryResult) -> None:
|
| 104 |
+
if "QED" in result.output:
|
| 105 |
+
result.provenance["proved"] = True
|
| 106 |
+
|
| 107 |
+
# --- Test Harness ---
|
| 108 |
+
def test_alphageometry_adapter() -> None:
|
| 109 |
+
# Dummy test: expects a dummy input file and AlphaGeometry stub
|
| 110 |
+
input_file = "dummy_input.txt"
|
| 111 |
+
with open(input_file, "w") as f:
|
| 112 |
+
f.write("A B C = triangle A B C\n")
|
| 113 |
+
result = run_alphageometry(input_file, timeout=2, plugins=[QEDPlugin()])
|
| 114 |
+
print("Result:", result.parse())
|
| 115 |
+
os.remove(input_file)
|
| 116 |
+
|
| 117 |
+
if __name__ == "__main__":
|
| 118 |
+
test_alphageometry_adapter()
|
backend/core/alphageometry_runner.py
ADDED
|
File without changes
|
backend/core/captum.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Minimal shim for `captum.attr.IntegratedGradients` used in neuro_symbolic explainability.
|
| 3 |
+
This avoids requiring the real Captum package during test collection while still allowing
|
| 4 |
+
code that imports `IntegratedGradients` to run (as a no-op shim).
|
| 5 |
+
"""
|
| 6 |
+
from typing import Any, Tuple
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class IntegratedGradients:
|
| 10 |
+
def __init__(self, model: Any):
|
| 11 |
+
self.model = model
|
| 12 |
+
|
| 13 |
+
def attribute(self, inputs: Any, target: int = 0, return_convergence_delta: bool = False) -> Tuple[Any, Any]:
|
| 14 |
+
# Return zero-attribution and zero delta
|
| 15 |
+
import numpy as np
|
| 16 |
+
attr = np.zeros_like(inputs)
|
| 17 |
+
delta = 0.0
|
| 18 |
+
return attr, delta
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
__all__ = ["IntegratedGradients"]
|
backend/core/coq_adapter.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapter for running Coq proofs from Python.
|
| 3 |
+
"""
|
| 4 |
+
import subprocess
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
def run_coq(input_file: str, coq_dir: str = "external/coq-platform") -> str:
|
| 8 |
+
"""
|
| 9 |
+
Runs Coq on the given input file and returns the output as a string.
|
| 10 |
+
"""
|
| 11 |
+
exe_path = os.path.join(coq_dir, "bin", "coqc")
|
| 12 |
+
if not os.path.exists(exe_path):
|
| 13 |
+
raise FileNotFoundError(f"Coq not found at {exe_path}")
|
| 14 |
+
try:
|
| 15 |
+
result = subprocess.run([
|
| 16 |
+
exe_path, input_file
|
| 17 |
+
], capture_output=True, text=True, check=True)
|
| 18 |
+
return result.stdout
|
| 19 |
+
except Exception as e:
|
| 20 |
+
return f"Coq error: {e}"
|
backend/core/cross_universe_analysis.py
ADDED
|
@@ -0,0 +1,599 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
# --- Real Graph Analytics ---
|
| 4 |
+
try:
|
| 5 |
+
import numpy as np
|
| 6 |
+
except Exception:
|
| 7 |
+
class _np_stub:
|
| 8 |
+
def zeros(self, *a, **k):
|
| 9 |
+
return []
|
| 10 |
+
|
| 11 |
+
def mean(self, *a, **k):
|
| 12 |
+
return 0.0
|
| 13 |
+
|
| 14 |
+
def median(self, *a, **k):
|
| 15 |
+
return 0.0
|
| 16 |
+
|
| 17 |
+
np = _np_stub()
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
import pandas as pd
|
| 21 |
+
except Exception:
|
| 22 |
+
pd = None
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
import matplotlib.pyplot as plt
|
| 26 |
+
except Exception:
|
| 27 |
+
plt = None
|
| 28 |
+
|
| 29 |
+
def theorem_graph_centrality(analyzer: 'CrossUniverseAnalyzer', universe_ids: List[int]) -> Dict[int, float]:
|
| 30 |
+
G = nx.DiGraph()
|
| 31 |
+
for uid in universe_ids:
|
| 32 |
+
theorems = analyzer.db.query(Theorem).filter(Theorem.universe_id == uid).all()
|
| 33 |
+
for thm in theorems:
|
| 34 |
+
G.add_node(thm.id)
|
| 35 |
+
deps = getattr(thm, 'dependencies', [])
|
| 36 |
+
for dep in deps:
|
| 37 |
+
G.add_edge(dep, thm.id)
|
| 38 |
+
centrality = nx.degree_centrality(G)
|
| 39 |
+
return centrality
|
| 40 |
+
|
| 41 |
+
def theorem_graph_communities(analyzer: 'CrossUniverseAnalyzer', universe_ids: List[int]) -> Dict[int, int]:
|
| 42 |
+
G = nx.Graph()
|
| 43 |
+
for uid in universe_ids:
|
| 44 |
+
theorems = analyzer.db.query(Theorem).filter(Theorem.universe_id == uid).all()
|
| 45 |
+
for thm in theorems:
|
| 46 |
+
G.add_node(thm.id)
|
| 47 |
+
deps = getattr(thm, 'dependencies', [])
|
| 48 |
+
for dep in deps:
|
| 49 |
+
G.add_edge(dep, thm.id)
|
| 50 |
+
from networkx.algorithms.community import greedy_modularity_communities
|
| 51 |
+
comms = list(greedy_modularity_communities(G))
|
| 52 |
+
comm_map = {}
|
| 53 |
+
for i, comm in enumerate(comms):
|
| 54 |
+
for node in comm:
|
| 55 |
+
comm_map[node] = i
|
| 56 |
+
return comm_map
|
| 57 |
+
|
| 58 |
+
def shortest_path_between_theorems(analyzer: 'CrossUniverseAnalyzer', universe_ids: List[int], thm_id1: int, thm_id2: int) -> List[int]:
|
| 59 |
+
G = nx.DiGraph()
|
| 60 |
+
for uid in universe_ids:
|
| 61 |
+
theorems = analyzer.db.query(Theorem).filter(Theorem.universe_id == uid).all()
|
| 62 |
+
for thm in theorems:
|
| 63 |
+
G.add_node(thm.id)
|
| 64 |
+
deps = getattr(thm, 'dependencies', [])
|
| 65 |
+
for dep in deps:
|
| 66 |
+
G.add_edge(dep, thm.id)
|
| 67 |
+
try:
|
| 68 |
+
path = nx.shortest_path(G, source=thm_id1, target=thm_id2)
|
| 69 |
+
return path
|
| 70 |
+
except nx.NetworkXNoPath:
|
| 71 |
+
return []
|
| 72 |
+
|
| 73 |
+
# --- Real Transfer Learning (Axiom Embeddings/Theorem Models) ---
|
| 74 |
+
try:
|
| 75 |
+
from sklearn.decomposition import TruncatedSVD
|
| 76 |
+
from sklearn.linear_model import LogisticRegression
|
| 77 |
+
except Exception:
|
| 78 |
+
TruncatedSVD = None
|
| 79 |
+
LogisticRegression = None
|
| 80 |
+
|
| 81 |
+
try:
|
| 82 |
+
import torch
|
| 83 |
+
import torch.nn as nn
|
| 84 |
+
import torch.optim as optim
|
| 85 |
+
except Exception:
|
| 86 |
+
torch = None
|
| 87 |
+
nn = None
|
| 88 |
+
optim = None
|
| 89 |
+
|
| 90 |
+
def transfer_axiom_embeddings(analyzer: 'CrossUniverseAnalyzer', source_universe: int, target_universe: int) -> np.ndarray:
|
| 91 |
+
# Build axiom embedding matrix for source, transfer to target
|
| 92 |
+
axioms_src = analyzer.db.query(Axiom).filter(Axiom.universe_id == source_universe).all()
|
| 93 |
+
axioms_tgt = analyzer.db.query(Axiom).filter(Axiom.universe_id == target_universe).all()
|
| 94 |
+
all_axioms = list({ax.statement for ax in axioms_src + axioms_tgt})
|
| 95 |
+
X_src = np.array([[1 if ax.statement == a else 0 for a in all_axioms] for ax in axioms_src])
|
| 96 |
+
svd = TruncatedSVD(n_components=2)
|
| 97 |
+
emb_src = svd.fit_transform(X_src)
|
| 98 |
+
# Transfer: project target axioms into source embedding space
|
| 99 |
+
X_tgt = np.array([[1 if ax.statement == a else 0 for a in all_axioms] for ax in axioms_tgt])
|
| 100 |
+
emb_tgt = svd.transform(X_tgt)
|
| 101 |
+
return emb_tgt
|
| 102 |
+
|
| 103 |
+
def transfer_theorem_model(analyzer: 'CrossUniverseAnalyzer', source_universe: int, target_universe: int):
|
| 104 |
+
# Train a simple model on source, transfer to target
|
| 105 |
+
theorems_src = analyzer.db.query(Theorem).filter(Theorem.universe_id == source_universe).all()
|
| 106 |
+
theorems_tgt = analyzer.db.query(Theorem).filter(Theorem.universe_id == target_universe).all()
|
| 107 |
+
all_thms = list({thm.statement for thm in theorems_src + theorems_tgt})
|
| 108 |
+
X_src = np.array([[1 if thm.statement == t else 0 for t in all_thms] for thm in theorems_src])
|
| 109 |
+
y_src = [1]*len(theorems_src)
|
| 110 |
+
model = LogisticRegression().fit(X_src, y_src)
|
| 111 |
+
X_tgt = np.array([[1 if thm.statement == t else 0 for t in all_thms] for thm in theorems_tgt])
|
| 112 |
+
preds = model.predict(X_tgt)
|
| 113 |
+
return preds
|
| 114 |
+
|
| 115 |
+
# --- Real-Time Interactive Visualization (Plotly/Bokeh) ---
|
| 116 |
+
try:
|
| 117 |
+
import plotly.graph_objs as go
|
| 118 |
+
import plotly.offline as py
|
| 119 |
+
except Exception:
|
| 120 |
+
go = None
|
| 121 |
+
py = None
|
| 122 |
+
def plotly_universe_similarity(analyzer: 'CrossUniverseAnalyzer', universe_ids: List[int]):
|
| 123 |
+
sim_matrix = analyzer.universe_similarity(universe_ids)
|
| 124 |
+
fig = go.Figure(data=go.Heatmap(z=sim_matrix, x=universe_ids, y=universe_ids, colorscale='Viridis'))
|
| 125 |
+
fig.update_layout(title="Universe Similarity (Plotly)")
|
| 126 |
+
py.plot(fig, filename='universe_similarity.html')
|
| 127 |
+
|
| 128 |
+
# --- PDF/HTML Reporting ---
|
| 129 |
+
from matplotlib.backends.backend_pdf import PdfPages
|
| 130 |
+
|
| 131 |
+
# Use pandas if available, otherwise fall back to CSV-based reporting
|
| 132 |
+
if pd is not None:
|
| 133 |
+
def generate_pdf_report(analyzer: 'CrossUniverseAnalyzer', universe_ids: List[int], path: str):
|
| 134 |
+
sim_matrix = analyzer.universe_similarity(universe_ids)
|
| 135 |
+
with PdfPages(path) as pdf:
|
| 136 |
+
plt.figure()
|
| 137 |
+
plt.imshow(sim_matrix, cmap='viridis')
|
| 138 |
+
plt.title("Universe Similarity Matrix")
|
| 139 |
+
pdf.savefig()
|
| 140 |
+
plt.close()
|
| 141 |
+
# Add tabular summary
|
| 142 |
+
df = pd.DataFrame(sim_matrix, index=universe_ids, columns=universe_ids)
|
| 143 |
+
fig, ax = plt.subplots()
|
| 144 |
+
ax.axis('off')
|
| 145 |
+
# Convert values/labels to plain Python lists/strings to satisfy static typing
|
| 146 |
+
cell_text = df.values.tolist()
|
| 147 |
+
col_labels = [str(c) for c in df.columns.tolist()]
|
| 148 |
+
row_labels = [str(r) for r in df.index.tolist()]
|
| 149 |
+
tbl = ax.table(cellText=cell_text, colLabels=col_labels, rowLabels=row_labels, loc='center')
|
| 150 |
+
pdf.savefig(fig)
|
| 151 |
+
plt.close(fig)
|
| 152 |
+
|
| 153 |
+
def generate_html_report(analyzer: 'CrossUniverseAnalyzer', universe_ids: List[int], path: str):
|
| 154 |
+
sim_matrix = analyzer.universe_similarity(universe_ids)
|
| 155 |
+
df = pd.DataFrame(sim_matrix, index=universe_ids, columns=universe_ids)
|
| 156 |
+
html = df.to_html()
|
| 157 |
+
with open(path, 'w') as f:
|
| 158 |
+
f.write(f"<h1>Universe Similarity Matrix</h1>{html}")
|
| 159 |
+
else:
|
| 160 |
+
import csv
|
| 161 |
+
def generate_pdf_report(analyzer: 'CrossUniverseAnalyzer', universe_ids: List[int], path: str):
|
| 162 |
+
# Minimal fallback: write CSV with similarity matrix and create a tiny PDF with a single page
|
| 163 |
+
sim_matrix = analyzer.universe_similarity(universe_ids)
|
| 164 |
+
csv_path = path + '.csv'
|
| 165 |
+
with open(csv_path, 'w', newline='') as f:
|
| 166 |
+
writer = csv.writer(f)
|
| 167 |
+
writer.writerow([''] + [str(u) for u in universe_ids])
|
| 168 |
+
for i, u in enumerate(universe_ids):
|
| 169 |
+
writer.writerow([str(u)] + list(sim_matrix[i]))
|
| 170 |
+
# Create a tiny PDF with matplotlib if available
|
| 171 |
+
try:
|
| 172 |
+
plt.figure()
|
| 173 |
+
plt.imshow(sim_matrix, cmap='viridis')
|
| 174 |
+
plt.title("Universe Similarity Matrix")
|
| 175 |
+
plt.savefig(path)
|
| 176 |
+
plt.close()
|
| 177 |
+
except Exception:
|
| 178 |
+
# If matplotlib isn't available, write the CSV only
|
| 179 |
+
pass
|
| 180 |
+
|
| 181 |
+
def generate_html_report(analyzer: 'CrossUniverseAnalyzer', universe_ids: List[int], path: str):
|
| 182 |
+
sim_matrix = analyzer.universe_similarity(universe_ids)
|
| 183 |
+
csv_path = path + '.csv'
|
| 184 |
+
with open(csv_path, 'w', newline='') as f:
|
| 185 |
+
writer = csv.writer(f)
|
| 186 |
+
writer.writerow([''] + [str(u) for u in universe_ids])
|
| 187 |
+
for i, u in enumerate(universe_ids):
|
| 188 |
+
writer.writerow([str(u)] + list(sim_matrix[i]))
|
| 189 |
+
# Also generate a minimal HTML table
|
| 190 |
+
try:
|
| 191 |
+
html_rows = ['<tr><th></th>' + ''.join(f'<th>{u}</th>' for u in universe_ids) + '</tr>']
|
| 192 |
+
for i, u in enumerate(universe_ids):
|
| 193 |
+
row = '<tr>' + f'<td>{u}</td>' + ''.join(f'<td>{val}</td>' for val in sim_matrix[i]) + '</tr>'
|
| 194 |
+
html_rows.append(row)
|
| 195 |
+
with open(path, 'w') as f:
|
| 196 |
+
f.write('<table>' + ''.join(html_rows) + '</table>')
|
| 197 |
+
except Exception:
|
| 198 |
+
pass
|
| 199 |
+
|
| 200 |
+
# --- Real Data Ingestion (CSV/JSON/API) ---
|
| 201 |
+
import requests
|
| 202 |
+
def ingest_universe_data_from_csv(path: str) -> List[Dict[str, Any]]:
|
| 203 |
+
df = pd.read_csv(path)
|
| 204 |
+
# Ensure return type matches List[Dict[str, Any]]
|
| 205 |
+
records = [dict((str(k), v) for k, v in r.items()) for r in df.to_dict(orient='records')]
|
| 206 |
+
return records
|
| 207 |
+
|
| 208 |
+
def ingest_universe_data_from_json(path: str) -> List[Dict[str, Any]]:
|
| 209 |
+
import json
|
| 210 |
+
with open(path, 'r') as f:
|
| 211 |
+
return json.load(f)
|
| 212 |
+
|
| 213 |
+
def ingest_universe_data_from_api(url: str) -> List[Dict[str, Any]]:
|
| 214 |
+
resp = requests.get(url)
|
| 215 |
+
return resp.json()
|
| 216 |
+
|
| 217 |
+
# --- Expanded Test Harness with Real Analytics/Reporting ---
|
| 218 |
+
def test_fully_real_cross_universe_analysis():
|
| 219 |
+
logging.basicConfig(level=logging.INFO)
|
| 220 |
+
analyzer = CrossUniverseAnalyzer()
|
| 221 |
+
universe_ids = [1, 2, 3, 4]
|
| 222 |
+
# Graph analytics
|
| 223 |
+
print("Centrality:", theorem_graph_centrality(analyzer, universe_ids))
|
| 224 |
+
print("Communities:", theorem_graph_communities(analyzer, universe_ids))
|
| 225 |
+
print("Shortest path:", shortest_path_between_theorems(analyzer, universe_ids, 1, 2))
|
| 226 |
+
# Transfer learning
|
| 227 |
+
print("Axiom embedding transfer:", transfer_axiom_embeddings(analyzer, 1, 2))
|
| 228 |
+
print("Theorem model transfer:", transfer_theorem_model(analyzer, 1, 2))
|
| 229 |
+
# Interactive visualization
|
| 230 |
+
plotly_universe_similarity(analyzer, universe_ids)
|
| 231 |
+
# PDF/HTML reporting
|
| 232 |
+
generate_pdf_report(analyzer, universe_ids, "universe_report.pdf")
|
| 233 |
+
generate_html_report(analyzer, universe_ids, "universe_report.html")
|
| 234 |
+
# Data ingestion
|
| 235 |
+
print("Ingested CSV:", ingest_universe_data_from_csv("analysis.csv"))
|
| 236 |
+
# Performance profiling
|
| 237 |
+
import time
|
| 238 |
+
start = time.time()
|
| 239 |
+
analyzer.analyze(universe_ids)
|
| 240 |
+
print("Analysis time:", time.time() - start)
|
| 241 |
+
|
| 242 |
+
if __name__ == "__main__":
|
| 243 |
+
test_fully_real_cross_universe_analysis()
|
| 244 |
+
# --- Advanced ML/Statistical Analysis ---
|
| 245 |
+
try:
|
| 246 |
+
from sklearn.decomposition import PCA
|
| 247 |
+
from sklearn.manifold import TSNE
|
| 248 |
+
from sklearn.ensemble import IsolationForest
|
| 249 |
+
except Exception:
|
| 250 |
+
PCA = None
|
| 251 |
+
TSNE = None
|
| 252 |
+
IsolationForest = None
|
| 253 |
+
|
| 254 |
+
try:
|
| 255 |
+
import shap
|
| 256 |
+
except Exception:
|
| 257 |
+
shap = None
|
| 258 |
+
|
| 259 |
+
try:
|
| 260 |
+
import lime.lime_tabular
|
| 261 |
+
except Exception:
|
| 262 |
+
lime = None
|
| 263 |
+
|
| 264 |
+
try:
|
| 265 |
+
import matplotlib.pyplot as plt
|
| 266 |
+
except Exception:
|
| 267 |
+
plt = None
|
| 268 |
+
|
| 269 |
+
try:
|
| 270 |
+
import networkx as nx
|
| 271 |
+
except Exception:
|
| 272 |
+
nx = None
|
| 273 |
+
|
| 274 |
+
import multiprocessing
|
| 275 |
+
try:
|
| 276 |
+
import dask
|
| 277 |
+
import dask.dataframe as dd
|
| 278 |
+
except Exception:
|
| 279 |
+
dask = None
|
| 280 |
+
dd = None
|
| 281 |
+
|
| 282 |
+
def pca_universe_features(analyzer: 'CrossUniverseAnalyzer', universe_ids: List[int]) -> np.ndarray:
|
| 283 |
+
# Build feature matrix: each row = axiom vector for a universe
|
| 284 |
+
all_axioms = list({ax for uid in universe_ids for ax in analyzer.shared_axioms([uid])})
|
| 285 |
+
X = []
|
| 286 |
+
for uid in universe_ids:
|
| 287 |
+
axioms = analyzer.shared_axioms([uid])
|
| 288 |
+
X.append([1 if ax in axioms else 0 for ax in all_axioms])
|
| 289 |
+
pca = PCA(n_components=2)
|
| 290 |
+
arr = np.array(X)
|
| 291 |
+
return pca.fit_transform(arr)
|
| 292 |
+
|
| 293 |
+
def tsne_universe_features(analyzer: 'CrossUniverseAnalyzer', universe_ids: List[int]) -> np.ndarray:
|
| 294 |
+
all_axioms = list({ax for uid in universe_ids for ax in analyzer.shared_axioms([uid])})
|
| 295 |
+
X = []
|
| 296 |
+
for uid in universe_ids:
|
| 297 |
+
axioms = analyzer.shared_axioms([uid])
|
| 298 |
+
X.append([1 if ax in axioms else 0 for ax in all_axioms])
|
| 299 |
+
tsne = TSNE(n_components=2)
|
| 300 |
+
arr = np.array(X)
|
| 301 |
+
return tsne.fit_transform(arr)
|
| 302 |
+
|
| 303 |
+
def isolation_forest_anomaly(analyzer: 'CrossUniverseAnalyzer', universe_ids: List[int]) -> List[int]:
|
| 304 |
+
all_axioms = list({ax for uid in universe_ids for ax in analyzer.shared_axioms([uid])})
|
| 305 |
+
X = []
|
| 306 |
+
for uid in universe_ids:
|
| 307 |
+
axioms = analyzer.shared_axioms([uid])
|
| 308 |
+
X.append([1 if ax in axioms else 0 for ax in all_axioms])
|
| 309 |
+
clf = IsolationForest()
|
| 310 |
+
preds = clf.fit_predict(X)
|
| 311 |
+
return [uid for uid, pred in zip(universe_ids, preds) if pred == -1]
|
| 312 |
+
|
| 313 |
+
# --- Distributed/Batch Analysis ---
|
| 314 |
+
def distributed_batch_analyze(analyze_fn: Callable, universe_batches: List[List[int]], num_workers: int = 4) -> List[Any]:
|
| 315 |
+
with multiprocessing.Pool(num_workers) as pool:
|
| 316 |
+
results = pool.map(analyze_fn, universe_batches)
|
| 317 |
+
return results
|
| 318 |
+
|
| 319 |
+
def dask_batch_analyze(analyze_fn: Callable, universe_ids: List[int], batch_size: int = 10) -> List[Any]:
|
| 320 |
+
batches = [universe_ids[i:i+batch_size] for i in range(0, len(universe_ids), batch_size)]
|
| 321 |
+
ddf = dd.from_pandas(dd.DataFrame({'batch': batches}), npartitions=len(batches))
|
| 322 |
+
return list(ddf['batch'].map(analyze_fn).compute())
|
| 323 |
+
|
| 324 |
+
# --- SHAP/LIME Explainability ---
|
| 325 |
+
def explain_universe_similarity_shap(analyzer: 'CrossUniverseAnalyzer', universe_ids: List[int]):
|
| 326 |
+
all_axioms = list({ax for uid in universe_ids for ax in analyzer.shared_axioms([uid])})
|
| 327 |
+
X = []
|
| 328 |
+
for uid in universe_ids:
|
| 329 |
+
axioms = analyzer.shared_axioms([uid])
|
| 330 |
+
X.append([1 if ax in axioms else 0 for ax in all_axioms])
|
| 331 |
+
model = IsolationForest().fit(X)
|
| 332 |
+
explainer = shap.TreeExplainer(model)
|
| 333 |
+
shap_values = explainer.shap_values(X)
|
| 334 |
+
shap.summary_plot(shap_values, X, feature_names=all_axioms)
|
| 335 |
+
|
| 336 |
+
def explain_universe_similarity_lime(analyzer: 'CrossUniverseAnalyzer', universe_ids: List[int]):
|
| 337 |
+
all_axioms = list({ax for uid in universe_ids for ax in analyzer.shared_axioms([uid])})
|
| 338 |
+
X = []
|
| 339 |
+
for uid in universe_ids:
|
| 340 |
+
axioms = analyzer.shared_axioms([uid])
|
| 341 |
+
X.append([1 if ax in axioms else 0 for ax in all_axioms])
|
| 342 |
+
model = IsolationForest().fit(X)
|
| 343 |
+
explainer = lime.lime_tabular.LimeTabularExplainer(X)
|
| 344 |
+
exp = explainer.explain_instance(X[0], model.predict)
|
| 345 |
+
exp.show_in_notebook()
|
| 346 |
+
|
| 347 |
+
# --- Data Export/Import, Reporting ---
|
| 348 |
+
def export_analysis_to_csv(results: List[Dict[str, Any]], path: str):
|
| 349 |
+
df = pd.DataFrame(results)
|
| 350 |
+
df.to_csv(path, index=False)
|
| 351 |
+
|
| 352 |
+
def import_analysis_from_csv(path: str) -> List[Dict[str, Any]]:
|
| 353 |
+
df = pd.read_csv(path)
|
| 354 |
+
records = [dict((str(k), v) for k, v in r.items()) for r in df.to_dict(orient='records')]
|
| 355 |
+
return records
|
| 356 |
+
|
| 357 |
+
# --- Advanced Visualization ---
|
| 358 |
+
def plot_universe_network(analyzer: 'CrossUniverseAnalyzer', universe_ids: List[int]):
|
| 359 |
+
G = nx.Graph()
|
| 360 |
+
for uid in universe_ids:
|
| 361 |
+
G.add_node(uid)
|
| 362 |
+
sim_matrix = analyzer.universe_similarity(universe_ids)
|
| 363 |
+
for i, uid1 in enumerate(universe_ids):
|
| 364 |
+
for j, uid2 in enumerate(universe_ids):
|
| 365 |
+
if i < j and sim_matrix[i, j] > 0.5:
|
| 366 |
+
G.add_edge(uid1, uid2, weight=sim_matrix[i, j])
|
| 367 |
+
pos = nx.spring_layout(G)
|
| 368 |
+
nx.draw(G, pos, with_labels=True, node_color='lightblue', edge_color='gray')
|
| 369 |
+
plt.title("Universe Network (Similarity > 0.5)")
|
| 370 |
+
plt.show()
|
| 371 |
+
|
| 372 |
+
# --- Integration Hooks (Expanded) ---
|
| 373 |
+
def integrate_with_theorem_engine(theorem_engine: Any, analyzer: Any):
|
| 374 |
+
analyzer.logger.info("Integrating with theorem engine.")
|
| 375 |
+
pass
|
| 376 |
+
|
| 377 |
+
def integrate_with_neuro_symbolic(neuro_module: Any, analyzer: Any):
|
| 378 |
+
analyzer.logger.info("Integrating with neuro-symbolic module.")
|
| 379 |
+
pass
|
| 380 |
+
|
| 381 |
+
def integrate_with_quantum(quantum_module: Any, analyzer: Any):
|
| 382 |
+
analyzer.logger.info("Integrating with quantum module.")
|
| 383 |
+
pass
|
| 384 |
+
|
| 385 |
+
# --- Expanded Test Harness ---
|
| 386 |
+
def test_real_cross_universe_analysis():
|
| 387 |
+
logging.basicConfig(level=logging.INFO)
|
| 388 |
+
analyzer = CrossUniverseAnalyzer()
|
| 389 |
+
universe_ids = [1, 2, 3, 4]
|
| 390 |
+
# PCA/t-SNE
|
| 391 |
+
print("PCA features:", pca_universe_features(analyzer, universe_ids))
|
| 392 |
+
print("t-SNE features:", tsne_universe_features(analyzer, universe_ids))
|
| 393 |
+
# Isolation Forest anomaly
|
| 394 |
+
print("Isolation Forest anomalies:", isolation_forest_anomaly(analyzer, universe_ids))
|
| 395 |
+
# Distributed/batch
|
| 396 |
+
print("Distributed batch analyze:", distributed_batch_analyze(analyzer.analyze, [universe_ids]*2))
|
| 397 |
+
print("Dask batch analyze:", dask_batch_analyze(analyzer.analyze, universe_ids))
|
| 398 |
+
# SHAP/LIME explainability
|
| 399 |
+
explain_universe_similarity_shap(analyzer, universe_ids)
|
| 400 |
+
explain_universe_similarity_lime(analyzer, universe_ids)
|
| 401 |
+
# Export/import
|
| 402 |
+
results = [analyzer.analyze(universe_ids)]
|
| 403 |
+
export_analysis_to_csv(results, "analysis.csv")
|
| 404 |
+
print("Imported analysis:", import_analysis_from_csv("analysis.csv"))
|
| 405 |
+
# Visualization
|
| 406 |
+
plot_universe_network(analyzer, universe_ids)
|
| 407 |
+
|
| 408 |
+
if __name__ == "__main__":
|
| 409 |
+
test_real_cross_universe_analysis()
|
| 410 |
+
|
| 411 |
+
import logging
|
| 412 |
+
from typing import List, Dict, Any, Optional, Set, Callable
|
| 413 |
+
from collections import Counter, defaultdict
|
| 414 |
+
import numpy as np
|
| 415 |
+
from backend.db.models import Universe, Axiom, Theorem, AnalysisResult
|
| 416 |
+
from backend.db.session import SessionLocal
|
| 417 |
+
|
| 418 |
+
class CrossUniverseAnalyzer:
|
| 419 |
+
"""
|
| 420 |
+
Advanced cross-universe analysis for mathematical universes, axioms, and theorems.
|
| 421 |
+
Provides lineage, influence, clustering, anomaly detection, transfer learning, and more.
|
| 422 |
+
Extensible for integration with neuro-symbolic, quantum, and external provers.
|
| 423 |
+
"""
|
| 424 |
+
def __init__(self, db_session=None, logger=None):
|
| 425 |
+
self.db = db_session or SessionLocal()
|
| 426 |
+
self.logger = logger or logging.getLogger("CrossUniverseAnalyzer")
|
| 427 |
+
|
| 428 |
+
def shared_axioms(self, universe_ids: List[int]) -> List[str]:
|
| 429 |
+
axiom_sets = []
|
| 430 |
+
for uid in universe_ids:
|
| 431 |
+
axioms = self.db.query(Axiom).filter(Axiom.universe_id == uid, Axiom.is_active == 1).all()
|
| 432 |
+
axiom_sets.append(set(ax.statement for ax in axioms))
|
| 433 |
+
shared = set.intersection(*axiom_sets) if axiom_sets else set()
|
| 434 |
+
self.logger.info(f"Shared axioms for universes {universe_ids}: {shared}")
|
| 435 |
+
return list(shared)
|
| 436 |
+
|
| 437 |
+
def shared_theorems(self, universe_ids: List[int]) -> List[str]:
|
| 438 |
+
thm_sets = []
|
| 439 |
+
for uid in universe_ids:
|
| 440 |
+
theorems = self.db.query(Theorem).filter(Theorem.universe_id == uid).all()
|
| 441 |
+
thm_sets.append(set(thm.statement for thm in theorems))
|
| 442 |
+
shared = set.intersection(*thm_sets) if thm_sets else set()
|
| 443 |
+
self.logger.info(f"Shared theorems for universes {universe_ids}: {shared}")
|
| 444 |
+
return list(shared)
|
| 445 |
+
|
| 446 |
+
def axiom_lineage(self, axiom_id: int) -> List[int]:
|
| 447 |
+
# Trace the lineage of an axiom across universes
|
| 448 |
+
lineage = []
|
| 449 |
+
axiom = self.db.query(Axiom).get(axiom_id)
|
| 450 |
+
while axiom:
|
| 451 |
+
lineage.append(axiom.id)
|
| 452 |
+
axiom = self.db.query(Axiom).get(getattr(axiom, 'parent_id', None)) if getattr(axiom, 'parent_id', None) else None
|
| 453 |
+
self.logger.info(f"Axiom lineage for {axiom_id}: {lineage}")
|
| 454 |
+
return lineage
|
| 455 |
+
|
| 456 |
+
def theorem_influence_graph(self, universe_ids: List[int]) -> Dict[int, Set[int]]:
|
| 457 |
+
# Build a graph of theorem dependencies across universes
|
| 458 |
+
graph = defaultdict(set)
|
| 459 |
+
for uid in universe_ids:
|
| 460 |
+
theorems = self.db.query(Theorem).filter(Theorem.universe_id == uid).all()
|
| 461 |
+
for thm in theorems:
|
| 462 |
+
deps = getattr(thm, 'dependencies', [])
|
| 463 |
+
for dep in deps:
|
| 464 |
+
graph[thm.id].add(dep)
|
| 465 |
+
self.logger.info(f"Theorem influence graph: {dict(graph)}")
|
| 466 |
+
return dict(graph)
|
| 467 |
+
|
| 468 |
+
def universe_similarity(self, universe_ids: List[int], metric: str = 'jaccard') -> np.ndarray:
|
| 469 |
+
# Compute pairwise similarity between universes
|
| 470 |
+
axioms_by_universe = []
|
| 471 |
+
for uid in universe_ids:
|
| 472 |
+
axioms = self.db.query(Axiom).filter(Axiom.universe_id == uid, Axiom.is_active == 1).all()
|
| 473 |
+
axioms_by_universe.append(set(ax.statement for ax in axioms))
|
| 474 |
+
n = len(universe_ids)
|
| 475 |
+
sim_matrix = np.zeros((n, n))
|
| 476 |
+
for i in range(n):
|
| 477 |
+
for j in range(n):
|
| 478 |
+
if metric == 'jaccard':
|
| 479 |
+
inter = len(axioms_by_universe[i] & axioms_by_universe[j])
|
| 480 |
+
union = len(axioms_by_universe[i] | axioms_by_universe[j])
|
| 481 |
+
sim_matrix[i, j] = inter / union if union else 0.0
|
| 482 |
+
self.logger.info(f"Universe similarity matrix: {sim_matrix}")
|
| 483 |
+
return sim_matrix
|
| 484 |
+
|
| 485 |
+
def cluster_universes(self, universe_ids: List[int], n_clusters: int = 2) -> Dict[int, int]:
|
| 486 |
+
# Cluster universes by axiom similarity
|
| 487 |
+
sim_matrix = self.universe_similarity(universe_ids)
|
| 488 |
+
from sklearn.cluster import KMeans
|
| 489 |
+
kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(sim_matrix)
|
| 490 |
+
labels = {uid: int(label) for uid, label in zip(universe_ids, kmeans.labels_)}
|
| 491 |
+
self.logger.info(f"Universe clusters: {labels}")
|
| 492 |
+
return labels
|
| 493 |
+
|
| 494 |
+
def detect_anomalies(self, universe_ids: List[int]) -> List[int]:
|
| 495 |
+
# Detect universes with anomalous axiom sets
|
| 496 |
+
sim_matrix = self.universe_similarity(universe_ids)
|
| 497 |
+
mean_sim = np.mean(sim_matrix, axis=1)
|
| 498 |
+
threshold = np.mean(mean_sim) - 2 * np.std(mean_sim)
|
| 499 |
+
anomalies = [uid for uid, sim in zip(universe_ids, mean_sim) if sim < threshold]
|
| 500 |
+
self.logger.info(f"Anomalous universes: {anomalies}")
|
| 501 |
+
return anomalies
|
| 502 |
+
|
| 503 |
+
def transfer_axioms(self, source_universe: int, target_universe: int) -> int:
|
| 504 |
+
# Transfer axioms from one universe to another
|
| 505 |
+
axioms = self.db.query(Axiom).filter(Axiom.universe_id == source_universe, Axiom.is_active == 1).all()
|
| 506 |
+
count = 0
|
| 507 |
+
for ax in axioms:
|
| 508 |
+
new_ax = Axiom(statement=ax.statement, universe_id=target_universe, is_active=1)
|
| 509 |
+
self.db.add(new_ax)
|
| 510 |
+
count += 1
|
| 511 |
+
self.db.commit()
|
| 512 |
+
self.logger.info(f"Transferred {count} axioms from {source_universe} to {target_universe}")
|
| 513 |
+
return count
|
| 514 |
+
|
| 515 |
+
def batch_analyze(self, universe_batches: List[List[int]]) -> List[Dict[str, Any]]:
|
| 516 |
+
results = []
|
| 517 |
+
for batch in universe_batches:
|
| 518 |
+
result = self.analyze(batch)
|
| 519 |
+
results.append(result)
|
| 520 |
+
self.logger.info(f"Batch analysis results: {results}")
|
| 521 |
+
return results
|
| 522 |
+
|
| 523 |
+
def distributed_analyze(self, universe_ids: List[int], num_workers: int = 4) -> List[Dict[str, Any]]:
|
| 524 |
+
# Placeholder for distributed analysis
|
| 525 |
+
self.logger.info(f"Distributed analysis with {num_workers} workers.")
|
| 526 |
+
chunk_size = max(1, len(universe_ids) // num_workers)
|
| 527 |
+
batches = [universe_ids[i:i+chunk_size] for i in range(0, len(universe_ids), chunk_size)]
|
| 528 |
+
return self.batch_analyze(batches)
|
| 529 |
+
|
| 530 |
+
def visualize_similarity(self, universe_ids: List[int]):
|
| 531 |
+
sim_matrix = self.universe_similarity(universe_ids)
|
| 532 |
+
import matplotlib.pyplot as plt
|
| 533 |
+
plt.imshow(sim_matrix, cmap='viridis')
|
| 534 |
+
plt.colorbar()
|
| 535 |
+
plt.title("Universe Similarity Matrix")
|
| 536 |
+
plt.xlabel("Universe Index")
|
| 537 |
+
plt.ylabel("Universe Index")
|
| 538 |
+
plt.show()
|
| 539 |
+
|
| 540 |
+
def explain_analysis(self, universe_ids: List[int]) -> Dict[str, Any]:
|
| 541 |
+
# Placeholder for explainability (e.g., feature importance, lineage)
|
| 542 |
+
return {"universes": universe_ids, "explanation": "Analysis explainability not implemented."}
|
| 543 |
+
|
| 544 |
+
def integrate_with_neuro_symbolic(self, *args, **kwargs):
|
| 545 |
+
self.logger.info("Integrating with neuro-symbolic module.")
|
| 546 |
+
pass
|
| 547 |
+
|
| 548 |
+
def integrate_with_quantum(self, *args, **kwargs):
|
| 549 |
+
self.logger.info("Integrating with quantum module.")
|
| 550 |
+
pass
|
| 551 |
+
|
| 552 |
+
def integrate_with_external_prover(self, *args, **kwargs):
|
| 553 |
+
self.logger.info("Integrating with external prover.")
|
| 554 |
+
pass
|
| 555 |
+
|
| 556 |
+
def analyze(self, universe_ids: List[int]) -> Dict[str, Any]:
|
| 557 |
+
shared_axioms = self.shared_axioms(universe_ids)
|
| 558 |
+
shared_theorems = self.shared_theorems(universe_ids)
|
| 559 |
+
result = {
|
| 560 |
+
"shared_axioms": shared_axioms,
|
| 561 |
+
"shared_theorems": shared_theorems,
|
| 562 |
+
"universes": universe_ids
|
| 563 |
+
}
|
| 564 |
+
# Store result in DB
|
| 565 |
+
for uid in universe_ids:
|
| 566 |
+
analysis = AnalysisResult(universe_id=uid, result=str(result))
|
| 567 |
+
self.db.add(analysis)
|
| 568 |
+
self.db.commit()
|
| 569 |
+
self.logger.info(f"Analysis result stored for universes {universe_ids}")
|
| 570 |
+
return result
|
| 571 |
+
|
| 572 |
+
# --- Research/Test Utilities ---
|
| 573 |
+
def benchmark_analysis(analyze_fn: Callable, universe_ids: List[int], repeats: int = 5) -> Dict[str, Any]:
|
| 574 |
+
import time
|
| 575 |
+
times = []
|
| 576 |
+
for _ in range(repeats):
|
| 577 |
+
start = time.time()
|
| 578 |
+
analyze_fn(universe_ids)
|
| 579 |
+
times.append(time.time() - start)
|
| 580 |
+
return {"mean_time": np.mean(times), "std_time": np.std(times), "runs": repeats}
|
| 581 |
+
|
| 582 |
+
def test_cross_universe_analysis():
|
| 583 |
+
logging.basicConfig(level=logging.INFO)
|
| 584 |
+
analyzer = CrossUniverseAnalyzer()
|
| 585 |
+
# Example universe IDs (replace with real IDs in production)
|
| 586 |
+
universe_ids = [1, 2, 3, 4]
|
| 587 |
+
print("Shared axioms:", analyzer.shared_axioms(universe_ids))
|
| 588 |
+
print("Shared theorems:", analyzer.shared_theorems(universe_ids))
|
| 589 |
+
print("Axiom lineage:", analyzer.axiom_lineage(1))
|
| 590 |
+
print("Theorem influence graph:", analyzer.theorem_influence_graph(universe_ids))
|
| 591 |
+
print("Universe similarity matrix:\n", analyzer.universe_similarity(universe_ids))
|
| 592 |
+
print("Universe clusters:", analyzer.cluster_universes(universe_ids, n_clusters=2))
|
| 593 |
+
print("Anomalous universes:", analyzer.detect_anomalies(universe_ids))
|
| 594 |
+
print("Transferred axioms:", analyzer.transfer_axioms(1, 2))
|
| 595 |
+
analyzer.visualize_similarity(universe_ids)
|
| 596 |
+
print("Explain analysis:", analyzer.explain_analysis(universe_ids))
|
| 597 |
+
|
| 598 |
+
if __name__ == "__main__":
|
| 599 |
+
test_cross_universe_analysis()
|
backend/core/ddar.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Lightweight shim for `ddar` used by several tests. When a real `ddar` implementation
|
| 3 |
+
is available in other project submodules, Python's import system will prefer that.
|
| 4 |
+
This shim provides minimal safe implementations so tests that import `ddar` during
|
| 5 |
+
collection won't fail.
|
| 6 |
+
"""
|
| 7 |
+
from typing import Any, List
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def solve(graph: Any, rules: Any, problem: Any) -> Any:
|
| 11 |
+
# Minimal stub: pretend to solve by returning None
|
| 12 |
+
return None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Solver:
|
| 16 |
+
def __init__(self):
|
| 17 |
+
pass
|
| 18 |
+
|
| 19 |
+
def run(self, *args, **kwargs):
|
| 20 |
+
return None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Export common names used in tests
|
| 24 |
+
__all__ = ["solve", "Solver"]
|