Spaces:
Running
Running
Upload 83 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .dockerignore +60 -0
- .github/workflows/core-tests.yml +23 -0
- .github/workflows/models-tests.yml +40 -0
- .github/workflows/workflow.yml +38 -0
- .gitignore +175 -0
- DEMO.md +203 -0
- Dockerfile +52 -0
- LICENSE +21 -0
- README.md +642 -12
- app.py +378 -0
- docker-compose.yml +55 -0
- docs/api_guide.md +37 -0
- docs/index.md +22 -0
- docs/installation.md +87 -0
- docs/metadata_schemas.md +57 -0
- mkdocs.yml +11 -0
- pages/1_Dashboard.py +125 -0
- pages/2_Control_Panel.py +138 -0
- pages/3_NQL_Chatbot.py +94 -0
- pages/4_Data_Explorer.py +167 -0
- pages/5_API_Playground.py +57 -0
- pages/6_Financial_Forecast_Demo.py +236 -0
- pages/api_playground_v2.py +171 -0
- pages/control_panel_v2.py +300 -0
- pages/data_explorer_v2.py +224 -0
- pages/nql_chatbot_v2.py +168 -0
- pages/pages_shared_utils.py +547 -0
- pages/ui_utils.py +267 -0
- pyproject.toml +41 -0
- requirements-test.txt +17 -0
- requirements.txt +54 -0
- setup.sh +12 -0
- tensorus/__init__.py +18 -0
- tensorus/api.py +0 -0
- tensorus/api/__init__.py +13 -0
- tensorus/api/dependencies.py +36 -0
- tensorus/api/endpoints.py +601 -0
- tensorus/api/main.py +78 -0
- tensorus/api/security.py +133 -0
- tensorus/audit.py +68 -0
- tensorus/automl_agent.py +356 -0
- tensorus/config.py +95 -0
- tensorus/dummy_env.py +91 -0
- tensorus/financial_data_generator.py +122 -0
- tensorus/ingestion_agent.py +409 -0
- tensorus/mcp_client.py +128 -0
- tensorus/mcp_server.py +228 -0
- tensorus/metadata/__init__.py +101 -0
- tensorus/metadata/postgres_storage.py +741 -0
- tensorus/metadata/schemas.py +250 -0
.dockerignore
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Git files
|
| 2 |
+
.git/
|
| 3 |
+
.gitignore
|
| 4 |
+
|
| 5 |
+
# Python cache and compiled files
|
| 6 |
+
__pycache__/
|
| 7 |
+
*.pyc
|
| 8 |
+
*.pyo
|
| 9 |
+
*.pyd
|
| 10 |
+
*.egg-info/
|
| 11 |
+
dist/
|
| 12 |
+
build/
|
| 13 |
+
|
| 14 |
+
# Virtual environments
|
| 15 |
+
env/
|
| 16 |
+
venv/
|
| 17 |
+
.venv/
|
| 18 |
+
|
| 19 |
+
# IDE and editor specific files
|
| 20 |
+
.vscode/
|
| 21 |
+
.idea/
|
| 22 |
+
*.swp
|
| 23 |
+
*.swo
|
| 24 |
+
|
| 25 |
+
# Secrets or local configuration (should be passed via environment variables)
|
| 26 |
+
.env
|
| 27 |
+
secrets/
|
| 28 |
+
*.pem
|
| 29 |
+
*.key
|
| 30 |
+
|
| 31 |
+
# Log files
|
| 32 |
+
*.log
|
| 33 |
+
tensorus_audit.log # Specifically exclude this if it's generated locally
|
| 34 |
+
|
| 35 |
+
# Test files and reports (unless you want to run tests in the image)
|
| 36 |
+
tests/
|
| 37 |
+
htmlcov/
|
| 38 |
+
.pytest_cache/
|
| 39 |
+
.coverage
|
| 40 |
+
|
| 41 |
+
# OS-specific files
|
| 42 |
+
.DS_Store
|
| 43 |
+
Thumbs.db
|
| 44 |
+
|
| 45 |
+
# Docker files themselves (no need to copy into the image)
|
| 46 |
+
Dockerfile
|
| 47 |
+
docker-compose.yml
|
| 48 |
+
|
| 49 |
+
# Other project specific files/folders to exclude
|
| 50 |
+
# E.g., notebooks/, docs/, etc.
|
| 51 |
+
# data/ # If you have a local data folder not needed in the image
|
| 52 |
+
# setup.sh # Exclude if not part of the Python API image
|
| 53 |
+
# DEMO.md # Exclude if not part of the Python API image
|
| 54 |
+
# LICENSE # Usually good to include, but can be excluded if desired
|
| 55 |
+
# README.md # Usually good to include, but can be excluded if desired
|
| 56 |
+
# pages/ # If these are Streamlit pages not served by this Docker image.
|
| 57 |
+
|
| 58 |
+
# If `app.py` at root is a Streamlit launcher, and not part of this specific API service, exclude it too.
|
| 59 |
+
# Assuming `app.py` from the root is not part of this specific service being Dockerized.
|
| 60 |
+
app.py
|
.github/workflows/core-tests.yml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Core Tests
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches: [main]
|
| 6 |
+
pull_request:
|
| 7 |
+
|
| 8 |
+
jobs:
|
| 9 |
+
test:
|
| 10 |
+
runs-on: ubuntu-latest
|
| 11 |
+
steps:
|
| 12 |
+
- name: Checkout code
|
| 13 |
+
uses: actions/checkout@v4
|
| 14 |
+
- name: Set up Python
|
| 15 |
+
uses: actions/setup-python@v5
|
| 16 |
+
with:
|
| 17 |
+
python-version: '3.11'
|
| 18 |
+
- name: Install dependencies
|
| 19 |
+
run: |
|
| 20 |
+
python -m pip install -r requirements.txt
|
| 21 |
+
python -m pip install -r requirements-test.txt
|
| 22 |
+
- name: Run tests
|
| 23 |
+
run: pytest -v
|
.github/workflows/models-tests.yml
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Models Repository Tests
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
workflow_dispatch:
|
| 5 |
+
schedule:
|
| 6 |
+
- cron: '0 0 * * 0'
|
| 7 |
+
|
| 8 |
+
jobs:
|
| 9 |
+
test-models:
|
| 10 |
+
runs-on: ubuntu-latest
|
| 11 |
+
steps:
|
| 12 |
+
- name: Checkout core repo
|
| 13 |
+
uses: actions/checkout@v4
|
| 14 |
+
with:
|
| 15 |
+
path: core
|
| 16 |
+
- name: Checkout models repo
|
| 17 |
+
uses: actions/checkout@v4
|
| 18 |
+
with:
|
| 19 |
+
repository: tensorus/models
|
| 20 |
+
path: models
|
| 21 |
+
- name: Set up Python
|
| 22 |
+
uses: actions/setup-python@v5
|
| 23 |
+
with:
|
| 24 |
+
python-version: '3.11'
|
| 25 |
+
- name: Install dependencies
|
| 26 |
+
run: |
|
| 27 |
+
python -m pip install -r core/requirements.txt
|
| 28 |
+
python -m pip install -r core/requirements-test.txt
|
| 29 |
+
- name: Prepare models package
|
| 30 |
+
run: |
|
| 31 |
+
mkdir -p models/tensorus
|
| 32 |
+
cp -r core/tensorus models/tensorus
|
| 33 |
+
mkdir -p models/tensorus/models
|
| 34 |
+
cp models/*.py models/tensorus/models/
|
| 35 |
+
cp models/__init__.py models/tensorus/models/__init__.py
|
| 36 |
+
- name: Run models tests
|
| 37 |
+
working-directory: models
|
| 38 |
+
env:
|
| 39 |
+
PYTHONPATH: ${{ github.workspace }}/models
|
| 40 |
+
run: pytest -v tests
|
.github/workflows/workflow.yml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Publish Python 🐍 distribution 📦 to PyPI
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
release:
|
| 8 |
+
types: [published]
|
| 9 |
+
|
| 10 |
+
permissions:
|
| 11 |
+
id-token: write
|
| 12 |
+
contents: read
|
| 13 |
+
|
| 14 |
+
jobs:
|
| 15 |
+
build-and-publish:
|
| 16 |
+
name: Build and publish to PyPI
|
| 17 |
+
environment: pypi
|
| 18 |
+
runs-on: ubuntu-latest
|
| 19 |
+
|
| 20 |
+
steps:
|
| 21 |
+
- name: Checkout code
|
| 22 |
+
uses: actions/checkout@v4
|
| 23 |
+
|
| 24 |
+
- name: Set up Python
|
| 25 |
+
uses: actions/setup-python@v5
|
| 26 |
+
with:
|
| 27 |
+
python-version: '3.9'
|
| 28 |
+
|
| 29 |
+
- name: Install build dependencies
|
| 30 |
+
run: python -m pip install --upgrade build
|
| 31 |
+
|
| 32 |
+
- name: Build package
|
| 33 |
+
run: python -m build
|
| 34 |
+
|
| 35 |
+
- name: Publish to PyPI
|
| 36 |
+
uses: pypa/gh-action-pypi-publish@release/v1
|
| 37 |
+
with:
|
| 38 |
+
skip-existing: true
|
.gitignore
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
#uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
|
| 110 |
+
# pdm
|
| 111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 112 |
+
#pdm.lock
|
| 113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 114 |
+
# in version control.
|
| 115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 116 |
+
.pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 121 |
+
__pypackages__/
|
| 122 |
+
|
| 123 |
+
# Celery stuff
|
| 124 |
+
celerybeat-schedule
|
| 125 |
+
celerybeat.pid
|
| 126 |
+
|
| 127 |
+
# SageMath parsed files
|
| 128 |
+
*.sage.py
|
| 129 |
+
|
| 130 |
+
# Environments
|
| 131 |
+
.env
|
| 132 |
+
.venv
|
| 133 |
+
env/
|
| 134 |
+
venv/
|
| 135 |
+
ENV/
|
| 136 |
+
env.bak/
|
| 137 |
+
venv.bak/
|
| 138 |
+
|
| 139 |
+
# Spyder project settings
|
| 140 |
+
.spyderproject
|
| 141 |
+
.spyproject
|
| 142 |
+
|
| 143 |
+
# Rope project settings
|
| 144 |
+
.ropeproject
|
| 145 |
+
|
| 146 |
+
# mkdocs documentation
|
| 147 |
+
/site
|
| 148 |
+
|
| 149 |
+
# mypy
|
| 150 |
+
.mypy_cache/
|
| 151 |
+
.dmypy.json
|
| 152 |
+
dmypy.json
|
| 153 |
+
|
| 154 |
+
# Pyre type checker
|
| 155 |
+
.pyre/
|
| 156 |
+
|
| 157 |
+
# pytype static type analyzer
|
| 158 |
+
.pytype/
|
| 159 |
+
|
| 160 |
+
# Cython debug symbols
|
| 161 |
+
cython_debug/
|
| 162 |
+
|
| 163 |
+
# PyCharm
|
| 164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 168 |
+
#.idea/
|
| 169 |
+
|
| 170 |
+
# Ruff stuff:
|
| 171 |
+
.ruff_cache/
|
| 172 |
+
|
| 173 |
+
# PyPI configuration file
|
| 174 |
+
.pypirc
|
| 175 |
+
|
DEMO.md
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Tensorus Demo Script: Showcasing Agentic Tensor Management
|
| 2 |
+
|
| 3 |
+
## Introduction
|
| 4 |
+
|
| 5 |
+
This demo showcases the key capabilities of Tensorus, an agentic tensor database/data lake. We'll walk through data ingestion, storage, querying, tensor operations, and how to interact with the system via its UI and APIs.
|
| 6 |
+
|
| 7 |
+
## Prerequisites
|
| 8 |
+
|
| 9 |
+
* Tensorus backend API running (`uvicorn api:app --reload --host 127.0.0.1 --port 7860`)
|
| 10 |
+
* Tensorus Streamlit UI running (`streamlit run app.py`)
|
| 11 |
+
* Tensorus MCP Server running (`python -m tensorus.mcp_server`)
|
| 12 |
+
* A terminal for API calls (e.g., using `curl`) or a tool like Postman.
|
| 13 |
+
* A web browser.
|
| 14 |
+
|
| 15 |
+
## Demo Scenarios
|
| 16 |
+
|
| 17 |
+
### Scenario 1: Automated Image Data Ingestion and Exploration
|
| 18 |
+
|
| 19 |
+
**Goal:** Demonstrate how new image data is automatically ingested, stored, and can be explored.
|
| 20 |
+
|
| 21 |
+
1. **Prepare for Ingestion:**
|
| 22 |
+
* Open the Streamlit UI in your browser (usually `http://localhost:8501`).
|
| 23 |
+
* Navigate to the "Agents" or "Control Panel" page.
|
| 24 |
+
* Ensure the "Data Ingestion Agent" is targeting the dataset `ingested_data_api` and monitoring the directory `temp_ingestion_source_api`. (You might need to configure this in `tensorus/api.py` if not dynamically configurable via UI).
|
| 25 |
+
* Start the "Data Ingestion Agent" if it's not already running. View its logs in the UI to see it polling.
|
| 26 |
+
* Verify the `temp_ingestion_source_api` directory (relative to where the API is run) is initially empty or clean it out.
|
| 27 |
+
|
| 28 |
+
2. **Simulate New Data Arrival:**
|
| 29 |
+
* Download a sample image (e.g., a picture of a cat, `cat.jpg`) into the `temp_ingestion_source_api` directory on your local filesystem.
|
| 30 |
+
* **Attractive Element:** Use a visually distinct and interesting image.
|
| 31 |
+
|
| 32 |
+
3. **Observe Ingestion:**
|
| 33 |
+
* In the Streamlit UI, watch the Ingestion Agent's logs. You should see messages indicating it detected `cat.jpg`, processed it, and ingested it into `ingested_data_api`.
|
| 34 |
+
|
| 35 |
+
4. **Explore Ingested Data (UI):**
|
| 36 |
+
* Navigate to the "Data Explorer" page in the Streamlit UI.
|
| 37 |
+
* Select the `ingested_data_api` dataset from the dropdown.
|
| 38 |
+
* You should see the newly ingested `cat.jpg` (or its metadata) listed.
|
| 39 |
+
* Click to view its details (metadata, tensor shape, perhaps a preview if the UI supports it).
|
| 40 |
+
* **Attractive Element:** The UI showing the image tensor's information clearly, perhaps even a thumbnail if the UI is advanced.
|
| 41 |
+
|
| 42 |
+
5. **Verify via API (Optional):**
|
| 43 |
+
* Use `curl` or Postman to fetch the tensor details. First, list datasets to find `ingested_data_api`, then fetch its records, identify the `record_id` for `cat.jpg` from its metadata.
|
| 44 |
+
```bash
|
| 45 |
+
# Example: List records in the dataset to find the ID (paged)
|
| 46 |
+
curl "http://127.0.0.1:7860/datasets/ingested_data_api/records?offset=0&limit=100"
|
| 47 |
+
# Then use the ID:
|
| 48 |
+
# curl http://127.0.0.1:7860/datasets/ingested_data_api/tensors/{record_id_of_cat_jpg}
|
| 49 |
+
```
|
| 50 |
+
* Confirm the API returns the tensor data and metadata.
|
| 51 |
+
|
| 52 |
+
### Scenario 2: Natural Language Querying for Specific Data
|
| 53 |
+
|
| 54 |
+
**Goal:** Show how NQL can be used to find specific tensors.
|
| 55 |
+
|
| 56 |
+
1. **Ensure Data Exists:**
|
| 57 |
+
* Make sure the `cat.jpg` from Scenario 1 is present in `ingested_data_api`.
|
| 58 |
+
* Optionally, add another image, e.g., `dog.png`, to `temp_ingestion_source_api` and let it be ingested. Give it distinctive metadata if possible, e.g., by modifying the `ingestion_agent.py` to add `{"animal_type": "dog"}` or by using the API to update metadata post-ingestion.
|
| 59 |
+
|
| 60 |
+
2. **Use NQL Chatbot (UI):**
|
| 61 |
+
* Navigate to the "NQL Chatbot" or "Query Hub" page in the Streamlit UI.
|
| 62 |
+
* Enter a query. Given the basic NQL, a query targeting the filename is most reliable:
|
| 63 |
+
`show all data from ingested_data_api where source_file contains "cat.jpg"`
|
| 64 |
+
* If you added custom metadata like `{"animal_type": "cat"}` for the cat image, you could try:
|
| 65 |
+
`find records from ingested_data_api where animal_type = 'cat'`
|
| 66 |
+
* **Attractive Element:** The chatbot interface and the directness of the natural language query yielding correct results.
|
| 67 |
+
|
| 68 |
+
3. **Observe Results:**
|
| 69 |
+
* The UI should display the tensor(s) matching your query.
|
| 70 |
+
|
| 71 |
+
### Scenario 3: Performing Tensor Operations via API
|
| 72 |
+
|
| 73 |
+
**Goal:** Demonstrate applying a tensor operation to a stored tensor.
|
| 74 |
+
|
| 75 |
+
1. **Identify a Target Tensor:**
|
| 76 |
+
* From Scenario 1 or 2, obtain the `record_id` of the `cat.jpg` tensor within the `ingested_data_api` dataset. Let's assume its `record_id` is `xyz123`.
|
| 77 |
+
|
| 78 |
+
2. **Perform a Tensor Operation (e.g., Transpose API):**
|
| 79 |
+
* Image tensors are often (C, H, W) or (H, W, C). Let's assume it's (C, H, W) and we want to transpose H and W, which would be `dim0=1, dim1=2`.
|
| 80 |
+
* Use `curl` or Postman:
|
| 81 |
+
```bash
|
| 82 |
+
curl -X POST -H "Content-Type: application/json" -d '{
|
| 83 |
+
"input_tensor": {
|
| 84 |
+
"dataset_name": "ingested_data_api",
|
| 85 |
+
"record_id": "xyz123"
|
| 86 |
+
},
|
| 87 |
+
"params": {
|
| 88 |
+
"dim0": 1,
|
| 89 |
+
"dim1": 2
|
| 90 |
+
},
|
| 91 |
+
"output_dataset_name": "ops_results",
|
| 92 |
+
"output_metadata": {"original_id": "xyz123", "operation": "transpose_height_width_demo"}
|
| 93 |
+
}' http://127.0.0.1:7860/ops/transpose
|
| 94 |
+
```
|
| 95 |
+
* **Attractive Element:** Showing the API call and the structured JSON response indicating success and the new transposed tensor's ID and details.
|
| 96 |
+
|
| 97 |
+
3. **Verify Result:**
|
| 98 |
+
* The API response will contain the `record_id` and details of the new (transposed) tensor in the `ops_results` dataset.
|
| 99 |
+
* Note the new shape in the response. If the original was (3, 128, 128), the new shape should be (3, 128, 128) after transposing height and width (assuming they were the same). If they were different, e.g. (3, 128, 200), the new shape would be (3, 200, 128).
|
| 100 |
+
* Optionally, use the "Data Explorer" in the UI or another API call to fetch and inspect this new tensor.
|
| 101 |
+
|
| 102 |
+
### Scenario 4: Interacting with the MCP Server (Conceptual)
|
| 103 |
+
|
| 104 |
+
**Goal:** Explain how an external AI agent could leverage Tensorus via MCP.
|
| 105 |
+
|
| 106 |
+
1. **Show MCP Server Running:**
|
| 107 |
+
* Point to the terminal where `python -m tensorus.mcp_server` is running and show its log output (e.g., "Tensorus MCP Server connected via stdio and ready.").
|
| 108 |
+
|
| 109 |
+
2. **Explain Available Tools (Conceptual):**
|
| 110 |
+
* Briefly show the tool definitions in `tensorus/mcp_server.py` or refer to the README's "Available Tools" under "MCP Server Details".
|
| 111 |
+
* Highlight a few tools like `tensorus_list_datasets`, `tensorus_ingest_tensor`, and `tensorus_apply_binary_operation`.
|
| 112 |
+
|
| 113 |
+
3. **Conceptual Client Interaction (Show code snippet from README):**
|
| 114 |
+
* Show the example client-side JavaScript snippet from the `README.md`:
|
| 115 |
+
```javascript
|
| 116 |
+
// Conceptual MCP client-side JavaScript
|
| 117 |
+
// Assuming 'client' is an initialized MCP client connected to the Tensorus MCP Server
|
| 118 |
+
|
| 119 |
+
async function example() {
|
| 120 |
+
// List available tools
|
| 121 |
+
const { tools } = await client.request({ method: 'tools/list' }, {});
|
| 122 |
+
console.log("Available Tensorus Tools:", tools.map(t => t.name));
|
| 123 |
+
|
| 124 |
+
// Create a new dataset
|
| 125 |
+
const createResponse = await client.request({ method: 'tools/call' }, {
|
| 126 |
+
name: 'tensorus_create_dataset',
|
| 127 |
+
arguments: { dataset_name: 'my_mcp_dataset_demo' }
|
| 128 |
+
});
|
| 129 |
+
console.log(JSON.parse(createResponse.content[0].text).message); // MCP server often returns JSON string in text
|
| 130 |
+
}
|
| 131 |
+
```
|
| 132 |
+
* **Attractive Element:** Emphasize that this allows *other* AI agents or LLMs to programmatically use Tensorus as a modular component in a larger intelligent system, promoting interoperability.
|
| 133 |
+
|
| 134 |
+
### Scenario 5: Financial Time Series Forecasting with ARIMA
|
| 135 |
+
|
| 136 |
+
**Goal:** Demonstrate end-to-end time series forecasting using generated financial data, Tensorus for storage, an ARIMA model for prediction, and visualization within a dedicated UI page.
|
| 137 |
+
|
| 138 |
+
**Prerequisites Specific to this Demo:**
|
| 139 |
+
* Ensure `statsmodels` is installed. If you used the standard setup, install it via the optional models extras:
|
| 140 |
+
```bash
|
| 141 |
+
pip install -e .[models]
|
| 142 |
+
```
|
| 143 |
+
or install the `tensorus-models` package which includes it.
|
| 144 |
+
|
| 145 |
+
**Steps:**
|
| 146 |
+
|
| 147 |
+
1. **Navigate to the Demo Page:**
|
| 148 |
+
* Open the Streamlit UI (e.g., `http://localhost:8501`).
|
| 149 |
+
* From the top navigation bar (or sidebar if the UI structure varies), find and click on the "Financial Forecast Demo" page (it might be titled "📈 Financial Forecast Demo" or similar).
|
| 150 |
+
|
| 151 |
+
2. **Generate and Ingest Data:**
|
| 152 |
+
* On the "Financial Forecast Demo" page, locate the section "1. Data Generation & Ingestion."
|
| 153 |
+
* Click the button labeled **"Generate & Ingest Sample Financial Data"**.
|
| 154 |
+
* Wait for the spinner to complete. You should see:
|
| 155 |
+
* A success message indicating data ingestion into a dataset like `financial_raw_data` in Tensorus.
|
| 156 |
+
* A sample DataFrame (head) of the generated data (Date, Close, Volume).
|
| 157 |
+
* A Plotly chart displaying the historical 'Close' prices that were just generated and ingested.
|
| 158 |
+
* **Attractive Element:** Observe the immediate visualization of the generated time series.
|
| 159 |
+
|
| 160 |
+
3. **Configure and Run ARIMA Prediction:**
|
| 161 |
+
* Go to the "3. ARIMA Model Prediction" section on the page.
|
| 162 |
+
* You can adjust the ARIMA order (p, d, q) and the number of future predictions if you wish. Default values (e.g., p=5, d=1, q=0, predictions=30) are provided.
|
| 163 |
+
* Click the button labeled **"Run ARIMA Prediction"**.
|
| 164 |
+
* Wait for the spinner to complete. This step involves:
|
| 165 |
+
* Loading the historical data from Tensorus.
|
| 166 |
+
* Training/fitting the ARIMA model.
|
| 167 |
+
* Generating future predictions.
|
| 168 |
+
* Storing these predictions back into Tensorus (e.g., into `financial_predictions` dataset).
|
| 169 |
+
|
| 170 |
+
4. **View Prediction Results:**
|
| 171 |
+
* Once the prediction is complete, scroll to the "4. Prediction Results" section.
|
| 172 |
+
* You should see:
|
| 173 |
+
* A Plotly chart displaying the original historical data with the ARIMA predictions plotted alongside/extending from it.
|
| 174 |
+
* A table or list showing the actual predicted values for future dates.
|
| 175 |
+
* **Attractive Element:** The clear visual comparison of the forecast against the historical data, showcasing the model's predictive attempt. The interactivity of Plotly charts (zoom, pan) enhances this.
|
| 176 |
+
|
| 177 |
+
5. **Interpretation (What this demonstrates):**
|
| 178 |
+
* **Data Flow:** Generation -> Tensorus Storage (Raw Data) -> Retrieval for Modeling -> Prediction -> Tensorus Storage (Predictions) -> UI Visualization.
|
| 179 |
+
* **Ease of Use:** A user-friendly interface to perform a complex task like time series forecasting.
|
| 180 |
+
* **Modularity:** Integration of data generation, storage (Tensorus), modeling (statsmodels), and UI (Streamlit) components.
|
| 181 |
+
* **Revised UI:** Notice the potentially improved layout and charting capabilities on this dedicated demo page.
|
| 182 |
+
|
| 183 |
+
### Scenario 6: Dashboard Overview
|
| 184 |
+
|
| 185 |
+
**Goal:** Show the main dashboard providing a system overview.
|
| 186 |
+
|
| 187 |
+
1. **Navigate to Dashboard:**
|
| 188 |
+
* In the Streamlit UI, go to the main "Nexus Dashboard" page (usually the default page when you run `streamlit run app.py`).
|
| 189 |
+
2. **Review Metrics:**
|
| 190 |
+
* Point out the key metrics displayed:
|
| 191 |
+
* Total Tensors / Active Datasets (these might be placeholders or simple counts).
|
| 192 |
+
* Agents Online / Status (showing the Ingestion Agent as 'running' if you started it).
|
| 193 |
+
* API Status (should be 'Connected').
|
| 194 |
+
* Simulated metrics like data ingestion rate, query latency, RL rewards, AutoML progress. Explain these are illustrative.
|
| 195 |
+
* **Attractive Element:** A visually appealing dashboard. If any metrics are updating (even if simulated based on time), it adds to the dynamic feel.
|
| 196 |
+
|
| 197 |
+
3. **Activity Feed (if populated):**
|
| 198 |
+
* Show the "Recent Agent Activity" feed. If the Ingestion Agent is running, it might populate this feed, or it might be placeholder data.
|
| 199 |
+
* Explain that in a fully operational system, this feed would show real-time updates from all active agents.
|
| 200 |
+
|
| 201 |
+
## Conclusion
|
| 202 |
+
|
| 203 |
+
This demo has provided a glimpse into Tensorus's capabilities, including automated data handling by agents, flexible data storage, natural language querying, powerful tensor operations, and a user-friendly interface. The MCP server further extends its reach, allowing programmatic interaction from other AI systems, paving the way for more complex, collaborative AI workflows.
|
Dockerfile
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use an official Python runtime as a parent image
|
| 2 |
+
FROM python:3.9-slim
|
| 3 |
+
|
| 4 |
+
# Set environment variables for Python
|
| 5 |
+
ENV PYTHONDONTWRITEBYTECODE 1
|
| 6 |
+
ENV PYTHONUNBUFFERED 1
|
| 7 |
+
|
| 8 |
+
# Set the working directory in the container
|
| 9 |
+
WORKDIR /app
|
| 10 |
+
|
| 11 |
+
# Install system dependencies that might be needed by Python packages
|
| 12 |
+
# Example: build-essential for some packages, libpq-dev for psycopg2 from source (though -binary avoids this)
|
| 13 |
+
# For psycopg2-binary, typically no extra system deps are needed for common platforms if using a compatible wheel.
|
| 14 |
+
# RUN apt-get update && apt-get install -y --no-install-recommends build-essential libpq-dev && rm -rf /var/lib/apt/lists/*
|
| 15 |
+
|
| 16 |
+
# Copy the requirements file into the container
|
| 17 |
+
COPY requirements.txt .
|
| 18 |
+
|
| 19 |
+
# Install Python dependencies
|
| 20 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 21 |
+
|
| 22 |
+
# Copy the application source code into the container
|
| 23 |
+
# This structure assumes your main application code is within the 'tensorus' directory.
|
| 24 |
+
COPY ./tensorus ./tensorus
|
| 25 |
+
# If your app.py or main.py is at the root alongside Dockerfile, you'd copy it too:
|
| 26 |
+
# COPY app.py . # Or specific main file if it's at the root.
|
| 27 |
+
# Based on `CMD ["uvicorn", "tensorus.api.main:app"...]`, main:app is in tensorus/api/main.py.
|
| 28 |
+
|
| 29 |
+
# Set default environment variables for the application
|
| 30 |
+
# These can be overridden when running the container (e.g., via docker run -e or docker-compose.yml)
|
| 31 |
+
ENV TENSORUS_STORAGE_BACKEND="in_memory"
|
| 32 |
+
ENV TENSORUS_POSTGRES_HOST="db" # Default for docker-compose setup
|
| 33 |
+
ENV TENSORUS_POSTGRES_PORT="5432"
|
| 34 |
+
ENV TENSORUS_POSTGRES_USER="tensorus_user_dockerfile"
|
| 35 |
+
ENV TENSORUS_POSTGRES_PASSWORD="tensorus_pass_dockerfile"
|
| 36 |
+
ENV TENSORUS_POSTGRES_DB="tensorus_db_dockerfile"
|
| 37 |
+
ENV TENSORUS_POSTGRES_DSN=""
|
| 38 |
+
ENV TENSORUS_API_KEY_HEADER_NAME="X-API-KEY"
|
| 39 |
+
ENV TENSORUS_VALID_API_KEYS="" # Example: "key1,key2" - Must be set at runtime for security
|
| 40 |
+
ENV TENSORUS_AUTH_JWT_ENABLED="False" # Default JWT auth to disabled
|
| 41 |
+
ENV TENSORUS_AUTH_JWT_ISSUER=""
|
| 42 |
+
ENV TENSORUS_AUTH_JWT_AUDIENCE=""
|
| 43 |
+
ENV TENSORUS_AUTH_JWT_ALGORITHM="RS256"
|
| 44 |
+
ENV TENSORUS_AUTH_JWT_JWKS_URI=""
|
| 45 |
+
ENV TENSORUS_AUTH_DEV_MODE_ALLOW_DUMMY_JWT="False"
|
| 46 |
+
|
| 47 |
+
# Expose the port the app runs on
|
| 48 |
+
EXPOSE 7860
|
| 49 |
+
|
| 50 |
+
# Define the command to run the application
|
| 51 |
+
# This assumes your FastAPI app instance is named 'app' in 'tensorus.api.main'
|
| 52 |
+
CMD ["uvicorn", "tensorus.api.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Tensorus
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,12 +1,642 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Tensorus: Agentic Tensor Database/Data Lake
|
| 2 |
+
|
| 3 |
+
Tensorus is a specialized data platform focused on the management and agent-driven manipulation of tensor data. It offers a streamlined environment for storing, retrieving, and operating on tensors, laying the groundwork for advanced AI and machine learning workflows.
|
| 4 |
+
|
| 5 |
+
The core purpose of Tensorus is to simplify and enhance how developers and AI agents interact with tensor datasets. By providing dedicated tools for tensor operations and a framework for agentic integration, Tensorus aims to accelerate tasks like automated data ingestion, reinforcement learning from stored experiences, and AutoML processes, ultimately enabling more efficient and intelligent data utilization in AI projects.
|
| 6 |
+
|
| 7 |
+
## Key Features
|
| 8 |
+
|
| 9 |
+
* **Tensor Storage:** Efficiently store and retrieve PyTorch tensors with associated metadata.
|
| 10 |
+
* **Dataset Schemas:** Optional per-dataset schemas enforce required metadata fields and tensor shape/dtype.
|
| 11 |
+
* **Natural Query Language (NQL):** Query your tensor data using a simple, natural language-like syntax.
|
| 12 |
+
* **Agent Framework:** A foundation for building and integrating intelligent agents that interact with your data.
|
| 13 |
+
* **Data Ingestion Agent:** Automatically monitors a directory for new files and ingests them as tensors.
|
| 14 |
+
* **RL Agent:** A Deep Q-Network (DQN) agent that can learn from experiences stored in TensorStorage.
|
| 15 |
+
* **AutoML Agent:** Performs hyperparameter optimization for a dummy model using random search.
|
| 16 |
+
* **API-Driven:** A FastAPI backend provides a RESTful API for interacting with Tensorus.
|
| 17 |
+
* **Streamlit UI:** A user-friendly Streamlit frontend for exploring data and controlling agents.
|
| 18 |
+
* **Tensor Operations:** A comprehensive library of robust tensor operations for common manipulations. See [Basic Tensor Operations](#basic-tensor-operations) for details.
|
| 19 |
+
* **Model System:** Optional model registry with example models provided in a
|
| 20 |
+
separate package. See [Tensorus Models](https://github.com/tensorus/models).
|
| 21 |
+
* **Metadata System:** Rich Pydantic schemas and storage backends for semantic, lineage, computational, quality, relational, and usage metadata.
|
| 22 |
+
* **Extensible:** Designed to be extended with more advanced agents, storage backends, and query capabilities.
|
| 23 |
+
* **Model Context Protocol (MCP) Server:** Provides a standardized interface for AI agents and LLMs to interact with Tensorus capabilities—including dataset management, tensor storage, and operations—using the Model Context Protocol. (See [MCP Server Details](#mcp-server-details) below).
|
| 24 |
+
|
| 25 |
+
## Project Structure
|
| 26 |
+
|
| 27 |
+
* `app.py`: The main Streamlit frontend application (located at the project root).
|
| 28 |
+
* `pages/`: Directory containing individual Streamlit page scripts and shared UI utilities for the dashboard.
|
| 29 |
+
* `pages/ui_utils.py`: Utility functions specifically for the Streamlit UI.
|
| 30 |
+
* *(Other page scripts like `01_dashboard.py`, `02_control_panel.py`, etc., define the different views of the dashboard)*
|
| 31 |
+
* `tensorus/`: Directory containing the core `tensorus` library modules (this is the main installable package).
|
| 32 |
+
* `tensorus/__init__.py`: Makes `tensorus` a Python package.
|
| 33 |
+
* `tensorus/api.py`: The FastAPI application providing the backend API for Tensorus.
|
| 34 |
+
* `tensorus/tensor_storage.py`: Core TensorStorage implementation for managing tensor data.
|
| 35 |
+
* `tensorus/tensor_ops.py`: Library of functions for tensor manipulations.
|
| 36 |
+
* `tensorus/nql_agent.py`: Agent for processing Natural Query Language queries.
|
| 37 |
+
* `tensorus/ingestion_agent.py`: Agent for ingesting data from various sources.
|
| 38 |
+
* `tensorus/rl_agent.py`: Agent for Reinforcement Learning tasks.
|
| 39 |
+
* `tensorus/automl_agent.py`: Agent for AutoML processes.
|
| 40 |
+
* `tensorus/dummy_env.py`: A simple environment for the RL agent demonstration.
|
| 41 |
+
* *(Other Python files within `tensorus/` are part of the core library.)*
|
| 42 |
+
* `requirements.txt`: Lists the project's Python dependencies for development and local execution.
|
| 43 |
+
* `pyproject.toml`: Project metadata, dependencies for distribution, and build system configuration (e.g., for PyPI).
|
| 44 |
+
* `tensorus/mcp_server.py`: Python implementation of the Model Context Protocol (MCP) server using `fastmcp`.
|
| 45 |
+
* `README.md`: This file.
|
| 46 |
+
* `LICENSE`: Project license file.
|
| 47 |
+
* `.gitignore`: Specifies intentionally untracked files that Git should ignore.
|
| 48 |
+
|
| 49 |
+
## Huggingface Demo
|
| 50 |
+
|
| 51 |
+
You can try Tensorus online via Huggingface Spaces:
|
| 52 |
+
|
| 53 |
+
* **API Documentation:** [Swagger UI](https://tensorus-api.hf.space/docs) | [ReDoc](https://tensorus-api.hf.space/redoc)
|
| 54 |
+
* **Dashboard UI:** [Streamlit Dashboard](https://tensorus-dashboard.hf.space)
|
| 55 |
+
|
| 56 |
+
## Tensorus Execution Cycle
|
| 57 |
+
|
| 58 |
+
```mermaid
|
| 59 |
+
graph TD
|
| 60 |
+
%% User Interface Layer
|
| 61 |
+
subgraph UI_Layer ["User Interaction"]
|
| 62 |
+
UI[Streamlit UI]
|
| 63 |
+
end
|
| 64 |
+
|
| 65 |
+
%% API Gateway Layer
|
| 66 |
+
subgraph API_Layer ["Backend Services"]
|
| 67 |
+
API[FastAPI Backend]
|
| 68 |
+
MCP["MCP Server (FastMCP Python)"]
|
| 69 |
+
end
|
| 70 |
+
|
| 71 |
+
%% Core Storage with Method Interface
|
| 72 |
+
subgraph Storage_Layer ["Core Storage - TensorStorage"]
|
| 73 |
+
TS[TensorStorage Core]
|
| 74 |
+
subgraph Storage_Methods ["Storage Interface"]
|
| 75 |
+
TS_insert[insert data metadata]
|
| 76 |
+
TS_query[query query_fn]
|
| 77 |
+
TS_get[get_by_id id]
|
| 78 |
+
TS_sample[sample n]
|
| 79 |
+
TS_update[update_metadata]
|
| 80 |
+
end
|
| 81 |
+
TS --- Storage_Methods
|
| 82 |
+
end
|
| 83 |
+
|
| 84 |
+
%% Agent Processing Layer
|
| 85 |
+
subgraph Agent_Layer ["Processing Agents"]
|
| 86 |
+
IA[Ingestion Agent]
|
| 87 |
+
NQLA[NQL Agent]
|
| 88 |
+
RLA[RL Agent]
|
| 89 |
+
AutoMLA[AutoML Agent]
|
| 90 |
+
end
|
| 91 |
+
|
| 92 |
+
%% Model System
|
| 93 |
+
subgraph Model_Layer ["Model System"]
|
| 94 |
+
Registry[Model Registry]
|
| 95 |
+
ModelsPkg[Models Package]
|
| 96 |
+
end
|
| 97 |
+
|
| 98 |
+
%% Tensor Operations Library
|
| 99 |
+
subgraph Ops_Layer ["Tensor Operations"]
|
| 100 |
+
TOps[TensorOps Library]
|
| 101 |
+
end
|
| 102 |
+
|
| 103 |
+
%% Primary UI & MCP Flow
|
| 104 |
+
UI -->|HTTP Requests| API
|
| 105 |
+
MCP -->|MCP Calls| API
|
| 106 |
+
|
| 107 |
+
%% API Orchestration
|
| 108 |
+
API -->|Command Dispatch| IA
|
| 109 |
+
API -->|Command Dispatch| NQLA
|
| 110 |
+
API -->|Command Dispatch| RLA
|
| 111 |
+
API -->|Command Dispatch| AutoMLA
|
| 112 |
+
API -->|Model Training| Registry
|
| 113 |
+
API -->|Direct Query| TS_query
|
| 114 |
+
|
| 115 |
+
%% Model System Interactions
|
| 116 |
+
Registry -->|Uses Models| ModelsPkg
|
| 117 |
+
Registry -->|Load/Save| TS
|
| 118 |
+
ModelsPkg -->|Tensor Ops| TOps
|
| 119 |
+
|
| 120 |
+
%% Agent Storage Interactions
|
| 121 |
+
IA -->|Data Ingestion| TS_insert
|
| 122 |
+
|
| 123 |
+
NQLA -->|Query Execution| TS_query
|
| 124 |
+
NQLA -->|Record Retrieval| TS_get
|
| 125 |
+
|
| 126 |
+
RLA -->|State Persistence| TS_insert
|
| 127 |
+
RLA -->|Experience Sampling| TS_sample
|
| 128 |
+
RLA -->|State Retrieval| TS_get
|
| 129 |
+
|
| 130 |
+
AutoMLA -->|Trial Storage| TS_insert
|
| 131 |
+
AutoMLA -->|Data Retrieval| TS_query
|
| 132 |
+
|
| 133 |
+
%% Computational Operations
|
| 134 |
+
NQLA -->|Vector Operations| TOps
|
| 135 |
+
RLA -->|Policy Evaluation| TOps
|
| 136 |
+
AutoMLA -->|Model Optimization| TOps
|
| 137 |
+
|
| 138 |
+
%% Indirect Storage Write-back
|
| 139 |
+
TOps -.->|Intermediate Results| TS_insert
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
## Getting Started
|
| 143 |
+
|
| 144 |
+
### Prerequisites
|
| 145 |
+
|
| 146 |
+
* Python 3.9+
|
| 147 |
+
* PyTorch
|
| 148 |
+
* FastAPI
|
| 149 |
+
* Uvicorn
|
| 150 |
+
* Streamlit
|
| 151 |
+
* Pydantic v2
|
| 152 |
+
* Requests
|
| 153 |
+
* Pillow (for image preprocessing)
|
| 154 |
+
* Matplotlib (optional, for plotting RL rewards)
|
| 155 |
+
|
| 156 |
+
### Installation
|
| 157 |
+
|
| 158 |
+
1. Clone the repository:
|
| 159 |
+
|
| 160 |
+
```bash
|
| 161 |
+
git clone https://github.com/tensorus/tensorus.git
|
| 162 |
+
cd tensorus
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
2. Create a virtual environment (recommended):
|
| 166 |
+
|
| 167 |
+
```bash
|
| 168 |
+
python3 -m venv venv
|
| 169 |
+
source venv/bin/activate # On Linux/macOS
|
| 170 |
+
venv\Scripts\activate # On Windows
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
+
3. Install dependencies using the provided setup script:
|
| 174 |
+
|
| 175 |
+
```bash
|
| 176 |
+
./setup.sh
|
| 177 |
+
```
|
| 178 |
+
This installs Python requirements from `requirements.txt` and
|
| 179 |
+
`requirements-test.txt`, using CPU wheels for PyTorch and
|
| 180 |
+
pinning `httpx` to a compatible version. The test requirements
|
| 181 |
+
also install `fastapi>=0.110` for compatibility with Pydantic v2.
|
| 182 |
+
These test requirements are needed to run the Python test suite.
|
| 183 |
+
Heavy machine-learning libraries (e.g. `xgboost`, `lightgbm`, `catboost`,
|
| 184 |
+
`statsmodels`, `torch-geometric`) are not installed by default. Install
|
| 185 |
+
them separately using `pip install tensorus[models]` or by installing the
|
| 186 |
+
`tensorus-models` package if you need the built-in models.
|
| 187 |
+
The audit logger writes to `tensorus_audit.log` by default. Override the
|
| 188 |
+
path with the `TENSORUS_AUDIT_LOG_PATH` environment variable if desired.
|
| 189 |
+
|
| 190 |
+
### Running the API Server
|
| 191 |
+
|
| 192 |
+
1. Navigate to the project root directory (the directory containing the `tensorus` folder and `pyproject.toml`).
|
| 193 |
+
2. Ensure your virtual environment is activated if you are using one.
|
| 194 |
+
3. Start the FastAPI backend server using:
|
| 195 |
+
|
| 196 |
+
```bash
|
| 197 |
+
python -m uvicorn tensorus.api:app --reload --host 127.0.0.1 --port 7860
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
+
* The `python -m uvicorn` command ensures that Python runs Uvicorn as a module, and `tensorus.api:app` correctly points to the `app` instance within your `tensorus/api.py` file.
|
| 201 |
+
* `--reload` enables auto-reload for development.
|
| 202 |
+
* Access the API documentation at `http://127.0.0.1:7860/docs` or `http://127.0.0.1:7860/redoc`.
|
| 203 |
+
|
| 204 |
+
### Running the Streamlit UI
|
| 205 |
+
|
| 206 |
+
1. In a separate terminal (with the virtual environment activated), navigate to the project root.
|
| 207 |
+
2. Start the Streamlit frontend:
|
| 208 |
+
|
| 209 |
+
```bash
|
| 210 |
+
streamlit run app.py
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
* Access the UI in your browser at the URL provided by Streamlit (usually `http://localhost:8501`).
|
| 214 |
+
|
| 215 |
+
### Running the MCP Server
|
| 216 |
+
|
| 217 |
+
Tensorus provides a lightweight Python implementation of the Model Context Protocol server using `fastmcp`. It exposes the FastAPI endpoints as tools so you can run an MCP server without Node.js.
|
| 218 |
+
|
| 219 |
+
**Starting the MCP Server:**
|
| 220 |
+
|
| 221 |
+
1. Install dependencies (includes `fastmcp`):
|
| 222 |
+
```bash
|
| 223 |
+
pip install -r requirements.txt
|
| 224 |
+
```
|
| 225 |
+
2. Ensure the FastAPI backend is running.
|
| 226 |
+
3. Start the server from the repository root:
|
| 227 |
+
```bash
|
| 228 |
+
python -m tensorus.mcp_server
|
| 229 |
+
```
|
| 230 |
+
Add `--transport sse` to use SSE transport.
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
### Running the Agents (Examples)
|
| 234 |
+
|
| 235 |
+
You can run the example agents directly from their respective files:
|
| 236 |
+
|
| 237 |
+
* **RL Agent:**
|
| 238 |
+
|
| 239 |
+
```bash
|
| 240 |
+
python tensorus/rl_agent.py
|
| 241 |
+
```
|
| 242 |
+
|
| 243 |
+
* **AutoML Agent:**
|
| 244 |
+
|
| 245 |
+
```bash
|
| 246 |
+
python tensorus/automl_agent.py
|
| 247 |
+
```
|
| 248 |
+
|
| 249 |
+
* **Ingestion Agent:**
|
| 250 |
+
|
| 251 |
+
```bash
|
| 252 |
+
python tensorus/ingestion_agent.py
|
| 253 |
+
```
|
| 254 |
+
|
| 255 |
+
* Note: The Ingestion Agent will monitor the `temp_ingestion_source` directory (created automatically if it doesn't exist in the project root) for new files.
|
| 256 |
+
|
| 257 |
+
### Docker Usage
|
| 258 |
+
|
| 259 |
+
Tensorus can also be run inside a Docker container. Build the image from the project root:
|
| 260 |
+
|
| 261 |
+
```bash
|
| 262 |
+
docker build -t tensorus .
|
| 263 |
+
```
|
| 264 |
+
|
| 265 |
+
Run the container and expose the API server on port `7860`:
|
| 266 |
+
|
| 267 |
+
```bash
|
| 268 |
+
docker run -p 7860:7860 tensorus
|
| 269 |
+
```
|
| 270 |
+
|
| 271 |
+
The FastAPI documentation will then be available at `http://localhost:7860/docs`.
|
| 272 |
+
|
| 273 |
+
If your system has NVIDIA GPUs and the [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-docker) installed, you can pass `--gpus all` to `docker run` and modify `setup.sh` to install CUDA-enabled PyTorch wheels for GPU acceleration.
|
| 274 |
+
|
| 275 |
+
### Running Tests
|
| 276 |
+
|
| 277 |
+
Tensorus includes Python unit tests. To set up the environment and run them:
|
| 278 |
+
|
| 279 |
+
1. Install all dependencies using:
|
| 280 |
+
|
| 281 |
+
```bash
|
| 282 |
+
./setup.sh
|
| 283 |
+
```
|
| 284 |
+
|
| 285 |
+
This script installs packages from `requirements.txt` and `requirements-test.txt` (which pins `fastapi>=0.110` for Pydantic v2 support).
|
| 286 |
+
|
| 287 |
+
2. Run the Python test suite:
|
| 288 |
+
|
| 289 |
+
```bash
|
| 290 |
+
pytest
|
| 291 |
+
```
|
| 292 |
+
|
| 293 |
+
To specifically verify the Model Context Protocol components, run the MCP
|
| 294 |
+
server and client tests:
|
| 295 |
+
|
| 296 |
+
```bash
|
| 297 |
+
pytest tests/test_mcp_server.py tests/test_mcp_client.py
|
| 298 |
+
```
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
## Using Tensorus
|
| 302 |
+
|
| 303 |
+
### API Endpoints
|
| 304 |
+
|
| 305 |
+
The API provides the following main endpoints:
|
| 306 |
+
|
| 307 |
+
* **Datasets:**
|
| 308 |
+
* `POST /datasets/create`: Create a new dataset.
|
| 309 |
+
* `POST /datasets/{name}/ingest`: Ingest a tensor into a dataset.
|
| 310 |
+
* `GET /datasets/{name}/fetch`: Retrieve all records from a dataset.
|
| 311 |
+
* `GET /datasets/{name}/records`: Retrieve a page of records. Supports `offset` (start index, default `0`) and `limit` (max results, default `100`).
|
| 312 |
+
* `GET /datasets`: List all available datasets.
|
| 313 |
+
* **Querying:**
|
| 314 |
+
* `POST /query`: Execute an NQL query.
|
| 315 |
+
* **Agents:**
|
| 316 |
+
* `GET /agents`: List all registered agents.
|
| 317 |
+
* `GET /agents/{agent_id}/status`: Get the status of a specific agent.
|
| 318 |
+
* `POST /agents/{agent_id}/start`: Start an agent.
|
| 319 |
+
* `POST /agents/{agent_id}/stop`: Stop an agent.
|
| 320 |
+
* `GET /agents/{agent_id}/logs`: Get recent logs for an agent.
|
| 321 |
+
* **Metrics & Monitoring:**
|
| 322 |
+
* `GET /metrics/dashboard`: Get aggregated dashboard metrics.
|
| 323 |
+
|
| 324 |
+
### Dataset Schemas
|
| 325 |
+
|
| 326 |
+
Datasets can optionally include a schema when created. The schema defines
|
| 327 |
+
required metadata fields and expected tensor `shape` and `dtype`. Inserts that
|
| 328 |
+
violate the schema will raise a validation error.
|
| 329 |
+
|
| 330 |
+
Example:
|
| 331 |
+
|
| 332 |
+
```python
|
| 333 |
+
schema = {
|
| 334 |
+
"shape": [3, 10],
|
| 335 |
+
"dtype": "float32",
|
| 336 |
+
"metadata": {"source": "str", "value": "int"}
|
| 337 |
+
}
|
| 338 |
+
storage.create_dataset("my_ds", schema=schema)
|
| 339 |
+
storage.insert("my_ds", torch.rand(3, 10), {"source": "sensor", "value": 5})
|
| 340 |
+
```
|
| 341 |
+
|
| 342 |
+
## Metadata System
|
| 343 |
+
|
| 344 |
+
Tensorus includes a detailed metadata subsystem for describing tensors beyond their raw data. Each tensor has a `TensorDescriptor` and can be associated with optional semantic, lineage, computational, quality, relational, and usage metadata. The metadata storage backend is pluggable, supporting in-memory storage for quick testing or PostgreSQL for persistence. Search and aggregation utilities allow querying across these metadata fields. See [metadata_schemas.md](docs/metadata_schemas.md) for schema details.
|
| 345 |
+
|
| 346 |
+
### Streamlit UI
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
The Streamlit UI provides a user-friendly interface for:
|
| 351 |
+
|
| 352 |
+
* **Dashboard:** View basic system metrics and agent status.
|
| 353 |
+
* **Agent Control:** Start, stop, and view logs for agents.
|
| 354 |
+
* **NQL Chat:** Enter natural language queries and view results.
|
| 355 |
+
* **Data Explorer:** Browse datasets, preview data, and perform tensor operations.
|
| 356 |
+
|
| 357 |
+
## Natural Query Language (NQL)
|
| 358 |
+
|
| 359 |
+
Tensorus ships with a simple regex‑based Natural Query Language for retrieving
|
| 360 |
+
tensors by metadata. You can issue NQL queries via the API or from the "NQL
|
| 361 |
+
Chat" page in the Streamlit UI.
|
| 362 |
+
|
| 363 |
+
### Enabling LLM rewriting
|
| 364 |
+
|
| 365 |
+
Set `NQL_USE_LLM=true` before starting the API server or Streamlit UI to enable
|
| 366 |
+
experimental LLM rewriting of natural language queries. Optionally specify a
|
| 367 |
+
model with `NQL_LLM_MODEL=<model-name>` (e.g., `google/flan-t5-base`). This
|
| 368 |
+
feature relies on the heavy `transformers` dependency and may trigger a model
|
| 369 |
+
download the first time it runs, which can take some time.
|
| 370 |
+
|
| 371 |
+
## Agent Details
|
| 372 |
+
|
| 373 |
+
### Data Ingestion Agent
|
| 374 |
+
|
| 375 |
+
* **Functionality:** Monitors a source directory for new files, preprocesses them into tensors, and inserts them into TensorStorage.
|
| 376 |
+
* **Supported File Types:** CSV, PNG, JPG, JPEG, TIF, TIFF (can be extended).
|
| 377 |
+
* **Preprocessing:** Uses default functions for CSV and images (resize, normalize).
|
| 378 |
+
* **Configuration:**
|
| 379 |
+
* `source_directory`: The directory to monitor.
|
| 380 |
+
* `polling_interval_sec`: How often to check for new files.
|
| 381 |
+
* `preprocessing_rules`: A dictionary mapping file extensions to custom preprocessing functions.
|
| 382 |
+
|
| 383 |
+
### RL Agent
|
| 384 |
+
|
| 385 |
+
* **Functionality:** A Deep Q-Network (DQN) agent that learns from experiences stored in TensorStorage.
|
| 386 |
+
* **Environment:** Uses a `DummyEnv` for demonstration.
|
| 387 |
+
* **Experience Storage:** Stores experiences (state, action, reward, next_state, done) in TensorStorage.
|
| 388 |
+
* **Training:** Implements epsilon-greedy exploration and target network updates.
|
| 389 |
+
* **Configuration:**
|
| 390 |
+
* `state_dim`: Dimensionality of the environment state.
|
| 391 |
+
* `action_dim`: Number of discrete actions.
|
| 392 |
+
* `hidden_size`: Hidden layer size for the DQN.
|
| 393 |
+
* `lr`: Learning rate.
|
| 394 |
+
* `gamma`: Discount factor.
|
| 395 |
+
* `epsilon_*`: Epsilon-greedy parameters.
|
| 396 |
+
* `target_update_freq`: Target network update frequency.
|
| 397 |
+
* `batch_size`: Experience batch size.
|
| 398 |
+
* `experience_dataset`: Dataset name for experiences.
|
| 399 |
+
* `state_dataset`: Dataset name for state tensors.
|
| 400 |
+
|
| 401 |
+
### AutoML Agent
|
| 402 |
+
|
| 403 |
+
* **Functionality:** Performs hyperparameter optimization using random search.
|
| 404 |
+
* **Model:** Trains a simple `DummyMLP` model.
|
| 405 |
+
* **Search Space:** Configurable hyperparameter search space (learning rate, hidden size, activation).
|
| 406 |
+
* **Evaluation:** Trains and evaluates models on synthetic data.
|
| 407 |
+
* **Results:** Stores trial results (parameters, score) in TensorStorage.
|
| 408 |
+
* **Configuration:**
|
| 409 |
+
* `search_space`: Dictionary defining the hyperparameter search space.
|
| 410 |
+
* `input_dim`: Input dimension for the model.
|
| 411 |
+
* `output_dim`: Output dimension for the model.
|
| 412 |
+
* `task_type`: Type of task ('regression' or 'classification').
|
| 413 |
+
* `results_dataset`: Dataset name for storing results.
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
### Tensorus Models
|
| 417 |
+
|
| 418 |
+
The collection of example models previously bundled with Tensorus now lives in
|
| 419 |
+
a separate repository: [tensorus/models](https://github.com/tensorus/models).
|
| 420 |
+
Install it with:
|
| 421 |
+
|
| 422 |
+
```bash
|
| 423 |
+
pip install tensorus-models
|
| 424 |
+
```
|
| 425 |
+
|
| 426 |
+
When the package is installed, Tensorus will automatically import it. Set the
|
| 427 |
+
environment variable `TENSORUS_MINIMAL_IMPORT=1` before importing Tensorus to
|
| 428 |
+
skip this optional dependency and keep startup lightweight.
|
| 429 |
+
|
| 430 |
+
## Basic Tensor Operations
|
| 431 |
+
|
| 432 |
+
This section details the core tensor manipulation functionalities provided by `tensor_ops.py`. These operations are designed to be robust, with built-in type and shape checking where appropriate.
|
| 433 |
+
|
| 434 |
+
#### Arithmetic Operations
|
| 435 |
+
|
| 436 |
+
* `add(t1, t2)`: Element-wise addition of two tensors, or a tensor and a scalar.
|
| 437 |
+
* `subtract(t1, t2)`: Element-wise subtraction of two tensors, or a tensor and a scalar.
|
| 438 |
+
* `multiply(t1, t2)`: Element-wise multiplication of two tensors, or a tensor and a scalar.
|
| 439 |
+
* `divide(t1, t2)`: Element-wise division of two tensors, or a tensor and a scalar. Includes checks for division by zero.
|
| 440 |
+
* `power(t1, t2)`: Raises each element in `t1` to the power of `t2`. Supports tensor or scalar exponents.
|
| 441 |
+
* `log(tensor)`: Element-wise natural logarithm with warnings for non-positive values.
|
| 442 |
+
|
| 443 |
+
#### Matrix and Dot Operations
|
| 444 |
+
|
| 445 |
+
* `matmul(t1, t2)`: Matrix multiplication of two tensors, supporting various dimensionalities (e.g., 2D matrices, batched matrix multiplication).
|
| 446 |
+
* `dot(t1, t2)`: Computes the dot product of two 1D tensors.
|
| 447 |
+
* `outer(t1, t2)`: Computes the outer product of two 1‑D tensors.
|
| 448 |
+
* `cross(t1, t2, dim=-1)`: Computes the cross product along the specified dimension (size must be 3).
|
| 449 |
+
* `matrix_eigendecomposition(matrix_A)`: Returns eigenvalues and eigenvectors of a square matrix.
|
| 450 |
+
* `matrix_trace(matrix_A)`: Computes the trace of a 2-D matrix.
|
| 451 |
+
* `tensor_trace(tensor_A, axis1=0, axis2=1)`: Trace of a tensor along two axes.
|
| 452 |
+
* `svd(matrix)`: Singular value decomposition of a matrix, returns `U`, `S`, and `Vh`.
|
| 453 |
+
* `qr_decomposition(matrix)`: QR decomposition returning `Q` and `R`.
|
| 454 |
+
* `lu_decomposition(matrix)`: LU decomposition returning permutation `P`, lower `L`, and upper `U` matrices.
|
| 455 |
+
* `cholesky_decomposition(matrix)`: Cholesky factor of a symmetric positive-definite matrix.
|
| 456 |
+
* `matrix_inverse(matrix)`: Inverse of a square matrix.
|
| 457 |
+
* `matrix_determinant(matrix)`: Determinant of a square matrix.
|
| 458 |
+
* `matrix_rank(matrix)`: Rank of a matrix.
|
| 459 |
+
|
| 460 |
+
#### Reduction Operations
|
| 461 |
+
|
| 462 |
+
* `sum(tensor, dim=None, keepdim=False)`: Computes the sum of tensor elements over specified dimensions.
|
| 463 |
+
* `mean(tensor, dim=None, keepdim=False)`: Computes the mean of tensor elements over specified dimensions. Tensor is cast to float for calculation.
|
| 464 |
+
* `min(tensor, dim=None, keepdim=False)`: Finds the minimum value in a tensor, optionally along a dimension. Returns values and indices if `dim` is specified.
|
| 465 |
+
* `max(tensor, dim=None, keepdim=False)`: Finds the maximum value in a tensor, optionally along a dimension. Returns values and indices if `dim` is specified.
|
| 466 |
+
* `variance(tensor, dim=None, unbiased=False, keepdim=False)`: Variance of tensor elements.
|
| 467 |
+
* `covariance(matrix_X, matrix_Y=None, rowvar=True, bias=False, ddof=None)`: Covariance matrix estimation.
|
| 468 |
+
* `correlation(matrix_X, matrix_Y=None, rowvar=True)`: Correlation coefficient matrix.
|
| 469 |
+
|
| 470 |
+
#### Reshaping and Slicing
|
| 471 |
+
|
| 472 |
+
* `reshape(tensor, shape)`: Changes the shape of a tensor without changing its data.
|
| 473 |
+
* `transpose(tensor, dim0, dim1)`: Swaps two dimensions of a tensor.
|
| 474 |
+
* `permute(tensor, dims)`: Permutes the dimensions of a tensor according to the specified order.
|
| 475 |
+
* `flatten(tensor, start_dim=0, end_dim=-1)`: Flattens a range of dimensions into a single dimension.
|
| 476 |
+
* `squeeze(tensor, dim=None)`: Removes dimensions of size 1, or a specific dimension if provided.
|
| 477 |
+
* `unsqueeze(tensor, dim)`: Inserts a dimension of size 1 at the given position.
|
| 478 |
+
|
| 479 |
+
#### Concatenation and Splitting
|
| 480 |
+
|
| 481 |
+
* `concatenate(tensors, dim=0)`: Joins a sequence of tensors along an existing dimension.
|
| 482 |
+
* `stack(tensors, dim=0)`: Joins a sequence of tensors along a new dimension.
|
| 483 |
+
|
| 484 |
+
#### Advanced Operations
|
| 485 |
+
|
| 486 |
+
* `einsum(equation, *tensors)`: Applies Einstein summation convention to the input tensors based on the provided equation string.
|
| 487 |
+
* `compute_gradient(func, tensor)`: Returns the gradient of a scalar `func` with respect to `tensor`.
|
| 488 |
+
* `compute_jacobian(func, tensor)`: Computes the Jacobian matrix of a vector function.
|
| 489 |
+
* `convolve_1d(signal_x, kernel_w, mode='valid')`: 1‑D convolution using `torch.nn.functional.conv1d`.
|
| 490 |
+
* `convolve_2d(image_I, kernel_K, mode='valid')`: 2‑D convolution using `torch.nn.functional.conv2d`.
|
| 491 |
+
* `frobenius_norm(tensor)`: Calculates the Frobenius norm.
|
| 492 |
+
* `l1_norm(tensor)`: Calculates the L1 norm (sum of absolute values).
|
| 493 |
+
|
| 494 |
+
## Tensor Decomposition Operations
|
| 495 |
+
|
| 496 |
+
Tensorus includes a library of higher‑order tensor factorizations in
|
| 497 |
+
`tensor_decompositions.py`. These operations mirror the algorithms
|
| 498 |
+
available in TensorLy and related libraries.
|
| 499 |
+
|
| 500 |
+
* **CP Decomposition** – Canonical Polyadic factorization returning
|
| 501 |
+
weights and factor matrices.
|
| 502 |
+
* **NTF‑CP Decomposition** – Non‑negative CP using
|
| 503 |
+
`non_negative_parafac`.
|
| 504 |
+
* **Tucker Decomposition** – Standard Tucker factorization for specified
|
| 505 |
+
ranks.
|
| 506 |
+
* **Non‑negative Tucker / Partial Tucker** – Variants with HOOI and
|
| 507 |
+
non‑negative constraints.
|
| 508 |
+
* **HOSVD** – Higher‑order SVD (Tucker with full ranks).
|
| 509 |
+
* **Tensor Train (TT)** – Sequence of TT cores representing the tensor.
|
| 510 |
+
* **TT‑SVD** – TT factorization via SVD initialization.
|
| 511 |
+
* **Tensor Ring (TR)** – Circular variant of TT.
|
| 512 |
+
* **Hierarchical Tucker (HT)** – Decomposition using a dimension tree.
|
| 513 |
+
* **Block Term Decomposition (BTD)** – Sum of Tucker‑1 terms for 3‑way
|
| 514 |
+
tensors.
|
| 515 |
+
* **t‑SVD** – Tensor singular value decomposition based on the
|
| 516 |
+
t‑product.
|
| 517 |
+
|
| 518 |
+
Examples of how to call these methods are provided in
|
| 519 |
+
[`tensorus/tensor_decompositions.py`](tensorus/tensor_decompositions.py).
|
| 520 |
+
|
| 521 |
+
## MCP Server Details
|
| 522 |
+
|
| 523 |
+
The Tensorus Model Context Protocol (MCP) Server allows external AI agents, LLM-based applications, and other MCP-compatible clients to interact with Tensorus functionalities in a standardized way. It acts as a bridge, translating MCP requests into calls to the Tensorus Python API.
|
| 524 |
+
|
| 525 |
+
### Overview
|
| 526 |
+
|
| 527 |
+
* **Protocol:** Implements the [Model Context Protocol](https://modelcontextprotocol.io/introduction).
|
| 528 |
+
* **Language:** Python, using the `fastmcp` library.
|
| 529 |
+
* **Communication:** Typically uses stdio for communication with a single client.
|
| 530 |
+
* **Interface:** Exposes Tensorus capabilities as a set of "tools" that an MCP client can list and call.
|
| 531 |
+
|
| 532 |
+
### Available Tools
|
| 533 |
+
|
| 534 |
+
The MCP server provides tools for various Tensorus functionalities. Below is an overview. For detailed input schemas and descriptions, an MCP client can call the standard `tools/list` method on the server, or you can inspect the tool definitions in `tensorus/mcp_server.py`.
|
| 535 |
+
|
| 536 |
+
* **Dataset Management:**
|
| 537 |
+
* `tensorus_list_datasets`: Lists all available datasets.
|
| 538 |
+
* `tensorus_create_dataset`: Creates a new dataset.
|
| 539 |
+
* `tensorus_delete_dataset`: Deletes an existing dataset.
|
| 540 |
+
* **Tensor Management:**
|
| 541 |
+
* `tensorus_ingest_tensor`: Ingests a new tensor (with data provided as JSON) into a dataset.
|
| 542 |
+
* `tensorus_get_tensor_details`: Retrieves the data and metadata for a specific tensor.
|
| 543 |
+
* `tensorus_delete_tensor`: Deletes a specific tensor from a dataset.
|
| 544 |
+
* `tensorus_update_tensor_metadata`: Updates the metadata of a specific tensor.
|
| 545 |
+
* **Tensor Operations:** These tools allow applying operations from the `TensorOps` library to tensors stored in Tensorus.
|
| 546 |
+
* `tensorus_apply_unary_operation`: Applies operations like `log`, `reshape`, `transpose`, `permute`, `sum`, `mean`, `min`, `max`.
|
| 547 |
+
* `tensorus_apply_binary_operation`: Applies operations like `add`, `subtract`, `multiply`, `divide`, `power`, `matmul`, `dot`.
|
| 548 |
+
* `tensorus_apply_list_operation`: Applies operations like `concatenate` and `stack` that take a list of input tensors.
|
| 549 |
+
* `tensorus_apply_einsum`: Applies Einstein summation.
|
| 550 |
+
|
| 551 |
+
*Note on Tensor Operations via MCP:* Input tensors are referenced by their `dataset_name` and `record_id`. The result is typically stored as a new tensor, and the MCP tool returns details of this new result tensor (like its `record_id`).
|
| 552 |
+
|
| 553 |
+
### Example Client Interaction (Conceptual)
|
| 554 |
+
|
| 555 |
+
```javascript
|
| 556 |
+
// Conceptual MCP client-side JavaScript
|
| 557 |
+
// Assuming 'client' is an initialized MCP client connected to the Tensorus MCP Server
|
| 558 |
+
|
| 559 |
+
async function example() {
|
| 560 |
+
// List available tools
|
| 561 |
+
const { tools } = await client.request({ method: 'tools/list' }, {});
|
| 562 |
+
console.log("Available Tensorus Tools:", tools.map(t => t.name)); // Log only names for brevity
|
| 563 |
+
|
| 564 |
+
// Create a new dataset
|
| 565 |
+
const createResponse = await client.request({ method: 'tools/call' }, {
|
| 566 |
+
name: 'tensorus_create_dataset',
|
| 567 |
+
arguments: { dataset_name: 'my_mcp_dataset' }
|
| 568 |
+
});
|
| 569 |
+
console.log(JSON.parse(createResponse.content[0].text).message);
|
| 570 |
+
|
| 571 |
+
// Ingest a tensor
|
| 572 |
+
const ingestResponse = await client.request({ method: 'tools/call' }, {
|
| 573 |
+
name: 'tensorus_ingest_tensor',
|
| 574 |
+
arguments: {
|
| 575 |
+
dataset_name: 'my_mcp_dataset',
|
| 576 |
+
tensor_shape: [2, 2],
|
| 577 |
+
tensor_dtype: 'float32',
|
| 578 |
+
tensor_data: [[1.0, 2.0], [3.0, 4.0]],
|
| 579 |
+
metadata: { source: 'mcp_client_example' }
|
| 580 |
+
}
|
| 581 |
+
});
|
| 582 |
+
// Assuming the Python API returns { success, message, data: { record_id, ... } }
|
| 583 |
+
// And MCP server stringifies this whole object in the text content
|
| 584 |
+
const ingestData = JSON.parse(ingestResponse.content[0].text);
|
| 585 |
+
console.log("Ingest success:", ingestData.success, "Record ID:", ingestData.data.record_id);
|
| 586 |
+
}
|
| 587 |
+
```
|
| 588 |
+
|
| 589 |
+
You can also interact with the server using the included Python helper:
|
| 590 |
+
|
| 591 |
+
```python
|
| 592 |
+
from tensorus.mcp_client import TensorusMCPClient
|
| 593 |
+
|
| 594 |
+
async def example_py():
|
| 595 |
+
async with TensorusMCPClient("http://localhost:7860/sse") as client:
|
| 596 |
+
datasets = await client.list_datasets()
print(datasets)
|
| 598 |
+
```
|
| 599 |
+
|
| 600 |
+
## Completed Features
|
| 601 |
+
|
| 602 |
+
* **Tensor Storage:** Efficiently stores and retrieves PyTorch tensors with associated metadata, including in-memory and optional file-based persistence. Supports dataset creation, tensor ingestion, querying, sampling, and metadata updates.
|
| 603 |
+
* **Natural Query Language (NQL):** Provides a basic regex-based natural language interface for querying tensor data, supporting retrieval and simple filtering.
|
| 604 |
+
* **Agent Framework:** Includes several operational agents:
|
| 605 |
+
* **Data Ingestion Agent:** Monitors local directories for CSV and image files, preprocesses them, and ingests them into TensorStorage.
|
| 606 |
+
* **RL Agent:** Implements a DQN agent that learns from experiences (stored in TensorStorage) in a dummy environment.
|
| 607 |
+
* **AutoML Agent:** Performs random search hyperparameter optimization for a dummy MLP model, storing trial results in TensorStorage.
|
| 608 |
+
* **API-Driven:** A comprehensive FastAPI backend offers RESTful endpoints for dataset management, NQL querying, tensor operations, and agent control (live for Ingestion Agent, simulated for RL/AutoML).
|
| 609 |
+
* **Streamlit UI:** A multi-page user interface for dashboard overview, agent control, NQL interaction, data exploration, and API interaction.
|
| 610 |
+
* **Tensor Operations:** A library of robust tensor operations (arithmetic, matrix ops, reductions, reshaping, etc.) accessible via the API.
|
| 611 |
+
* **Model Context Protocol (MCP) Server:** A Python server built with `fastmcp` exposes Tensorus capabilities (storage and operations) via the Model Context Protocol.
|
| 612 |
+
* **Extensible Design:** The project is structured with modular components, facilitating future extensions.
|
| 613 |
+
|
| 614 |
+
## Future Implementation
|
| 615 |
+
|
| 616 |
+
* **Enhanced NQL:** Integrate a local or remote LLM for more robust natural language understanding.
|
| 617 |
+
* **Advanced Agents:** Develop more sophisticated agents for specific tasks (e.g., anomaly detection, forecasting).
|
| 618 |
+
* **Persistent Storage Backend:** Replace/augment current file-based persistence with more robust database or cloud storage solutions (e.g., PostgreSQL, S3, MinIO).
|
| 619 |
+
* **Scalability & Performance:**
|
| 620 |
+
* Implement tensor chunking for very large tensors.
|
| 621 |
+
* Optimize query performance with indexing.
|
| 622 |
+
* Asynchronous operations for agents and API calls.
|
| 623 |
+
* **Security:** Implement authentication and authorization mechanisms for the API and UI.
|
| 624 |
+
* **Real-World Integration:**
|
| 625 |
+
* Connect Ingestion Agent to more data sources (e.g., cloud storage, databases, APIs).
|
| 626 |
+
* Integrate RL Agent with real-world environments or more complex simulations.
|
| 627 |
+
* **Advanced AutoML:**
|
| 628 |
+
* Implement sophisticated search algorithms (e.g., Bayesian Optimization, Hyperband).
|
| 629 |
+
* Support for diverse model architectures and custom models.
|
| 630 |
+
* **Model Management:** Add capabilities for saving, loading, versioning, and deploying trained models (from RL/AutoML).
|
| 631 |
+
* **Streaming Data Support:** Enhance Ingestion Agent to handle real-time streaming data.
|
| 632 |
+
* **Resource Management:** Add tools and controls for monitoring and managing the resource consumption (CPU, memory) of agents.
|
| 633 |
+
* **Improved UI/UX:** Continuously refine the Streamlit UI for better usability and richer visualizations.
|
| 634 |
+
* **Comprehensive Testing:** Expand unit, integration, and end-to-end tests.
|
| 635 |
+
|
| 636 |
+
## Contributing
|
| 637 |
+
|
| 638 |
+
Contributions are welcome! Please feel free to open issues or submit pull requests.
|
| 639 |
+
|
| 640 |
+
## License
|
| 641 |
+
|
| 642 |
+
MIT License
|
app.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
"""
|
| 3 |
+
Streamlit frontend application for the Tensorus platform.
|
| 4 |
+
New UI structure with top navigation and Nexus Dashboard.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import streamlit as st
|
| 8 |
+
import json
|
| 9 |
+
import time
|
| 10 |
+
import requests # Needed for ui_utils functions if integrated
|
| 11 |
+
import logging # Needed for ui_utils functions if integrated
|
| 12 |
+
import torch # Needed for integrated tensor utils
|
| 13 |
+
from typing import List, Dict, Any, Optional, Union, Tuple # Needed for integrated tensor utils
|
| 14 |
+
from pages.pages_shared_utils import get_api_status, get_agent_status, get_datasets # Updated imports
|
| 15 |
+
|
| 16 |
+
# Work around a Streamlit bug where inspecting `torch.classes` during module
|
| 17 |
+
# watching can raise a `RuntimeError`. Removing the module from `sys.modules`
|
| 18 |
+
# prevents Streamlit's watcher from trying to access it.
|
| 19 |
+
import sys
|
| 20 |
+
if "torch.classes" in sys.modules:
|
| 21 |
+
del sys.modules["torch.classes"]
|
| 22 |
+
|
| 23 |
+
# --- Page Configuration ---
|
| 24 |
+
st.set_page_config(
|
| 25 |
+
page_title="Tensorus Platform",
|
| 26 |
+
page_icon="🧊",
|
| 27 |
+
layout="wide",
|
| 28 |
+
initial_sidebar_state="collapsed" # Collapse sidebar as nav is now at top
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
# --- Configure Logging (Optional but good practice) ---
|
| 32 |
+
logging.basicConfig(level=logging.INFO)
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
# --- Integrated Tensor Utilities (Preserved) ---
|
| 36 |
+
|
| 37 |
+
def _validate_tensor_data(data: List[Any], shape: List[int]):
|
| 38 |
+
"""
|
| 39 |
+
Validates if the nested list structure of 'data' matches the 'shape'.
|
| 40 |
+
Raises ValueError on mismatch. (Optional validation)
|
| 41 |
+
"""
|
| 42 |
+
if not shape:
|
| 43 |
+
if not isinstance(data, (int, float)): raise ValueError("Scalar tensor data must be a single number.")
|
| 44 |
+
return True
|
| 45 |
+
if not isinstance(data, list): raise ValueError(f"Data for shape {shape} must be a list.")
|
| 46 |
+
expected_len = shape[0]
|
| 47 |
+
if len(data) != expected_len: raise ValueError(f"Dimension 0 mismatch: Expected {expected_len}, got {len(data)} for shape {shape}.")
|
| 48 |
+
if len(shape) > 1:
|
| 49 |
+
for item in data: _validate_tensor_data(item, shape[1:])
|
| 50 |
+
elif len(shape) == 1:
|
| 51 |
+
if not all(isinstance(x, (int, float)) for x in data): raise ValueError("Innermost list elements must be numbers.")
|
| 52 |
+
return True
|
| 53 |
+
|
| 54 |
+
def list_to_tensor(shape: List[int], dtype_str: str, data: Union[List[Any], int, float]) -> torch.Tensor:
    """Build a torch.Tensor with the requested shape and dtype from a
    (possibly nested) Python list or a scalar.

    Raises ValueError for unknown dtype strings, malformed data, or when
    the resulting tensor cannot be reshaped to ``shape``.
    """
    try:
        # Accepted dtype spellings -> torch dtypes.
        aliases = {
            'float32': torch.float32, 'float': torch.float,
            'float64': torch.float64, 'double': torch.double,
            'int32': torch.int32, 'int': torch.int,
            'int64': torch.int64, 'long': torch.long,
            'bool': torch.bool,
        }
        requested_dtype = aliases.get(dtype_str.lower())
        if requested_dtype is None:
            raise ValueError(f"Unsupported dtype string: {dtype_str}")

        result = torch.tensor(data, dtype=requested_dtype)

        # torch may infer a different shape than requested; try to reshape.
        if list(result.shape) != shape:
            logger.debug(f"Initial tensor shape {list(result.shape)} differs from target {shape}. Attempting reshape.")
            try:
                result = result.reshape(shape)
            except RuntimeError as reshape_err:
                raise ValueError(f"Created tensor shape {list(result.shape)} != requested {shape} and reshape failed: {reshape_err}") from reshape_err

        return result
    except (TypeError, ValueError) as e:
        logger.error(f"Error converting list to tensor: {e}. Shape: {shape}, Dtype: {dtype_str}")
        raise ValueError(f"Failed tensor conversion: {e}") from e
    except Exception as e:
        logger.exception(f"Unexpected error during list_to_tensor: {e}")
        raise ValueError(f"Unexpected tensor conversion error: {e}") from e
|
| 86 |
+
|
| 87 |
+
def tensor_to_list(tensor: torch.Tensor) -> Tuple[List[int], str, List[Any]]:
    """Decompose a tensor into (shape, dtype name, nested-list data).

    Raises TypeError when given anything other than a torch.Tensor.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError("Input must be a torch.Tensor")
    # str(torch.float32) -> "torch.float32"; keep only the part after the dot.
    dtype_name = str(tensor.dtype).split('.')[-1]
    return list(tensor.shape), dtype_name, tensor.tolist()
|
| 97 |
+
|
| 98 |
+
# --- Helper functions for dashboard (can be expanded) ---
|
| 99 |
+
def get_total_tensors_placeholder():
    """Placeholder metric for the dashboard's "Total Tensors" card.

    No backend endpoint exposes this figure yet, so "N/A" is shown.
    """
    return "N/A"
|
| 102 |
+
|
| 103 |
+
@st.cache_data(ttl=300)
def get_active_datasets_placeholder():
    """Return the dataset count as a display string, or "Error".

    NOTE(review): get_datasets() reportedly returns [] both on API failure
    and when no datasets exist, so a legitimately empty deployment is also
    rendered as "Error" here — confirm whether that conflation is intended.
    """
    datasets = get_datasets()
    if not datasets:
        return "Error"
    return str(len(datasets))
|
| 109 |
+
|
| 110 |
+
@st.cache_data(ttl=60)
def get_agents_online_placeholder():
    """Summarise agent availability as "<online>/<total> Online".

    Returns "N/A" when no agent status is available and "Error" when the
    status payload cannot be interpreted.
    """
    agent_data = get_agent_status()
    if not agent_data:
        return "N/A"  # No status payload at all.
    try:
        def _is_online(info):
            # Agents may report either a boolean 'running' flag or a
            # textual 'status' field — accept both conventions.
            return info.get('running') is True or str(info.get('status', '')).lower() == 'running'

        online = sum(1 for info in agent_data.values() if _is_online(info))
        return f"{online}/{len(agent_data)} Online"
    except Exception as e:
        logger.error(f"Error processing agent data for dashboard: {e}")
        return "Error"
|
| 125 |
+
|
| 126 |
+
# --- CSS Styles ---
|
| 127 |
+
# Renaming app.py's specific CSS loader to avoid confusion with the shared one.
|
| 128 |
+
def load_app_specific_css():
    """Inject the Nexus-Dashboard-only CSS into the Streamlit page.

    Only dashboard-specific rules live here (metric cards, API-status icon
    colors, activity feed). General styles (body, .stApp, nav, common-card,
    etc.) come from pages_shared_utils.load_shared_css(), which must be
    loaded first so these rules can build on them.
    """
    # This function now only loads styles specific to the Nexus Dashboard content in app.py
    # General styles (body, .stApp, nav, common-card, etc.) are in pages_shared_utils.load_shared_css()
    st.markdown("""
    <style>
    /* Nexus Dashboard Specific Styles */
    .dashboard-title { /* Custom title for the dashboard */
        color: #e0e0ff; /* Light purple/blue, matching shared h1 */
        text-align: center;
        margin-top: 1rem; /* Standardized margin */
        margin-bottom: 2rem;
        font-size: 2.8em; /* Slightly larger for main dashboard title */
        font-weight: bold;
    }

    /* Metric Cards Container for Dashboard */
    .metric-card-container {
        display: flex;
        flex-wrap: wrap;
        justify-content: space-around;
        gap: 20px;
        padding: 0 1rem;
    }

    /* Individual Metric Card - inherits from .common-card (defined in shared_utils) */
    /* This specific .metric-card class is for dashboard cards if they need further specialization */
    .metric-card {
        /* Inherits background, border, padding, shadow, transition from .common-card */
        flex: 1 1 220px; /* Adjusted flex basis */
        min-width: 200px;
        max-width: 300px;
        text-align: center;
    }
    /* .metric-card:hover is inherited from .common-card:hover */

    .metric-card .icon { /* Specific styling for icons within dashboard metric cards */
        font-size: 2.8em; /* Slightly larger icon for dashboard */
        margin-bottom: 0.5rem; /* Tighter spacing */
        /* color is inherited from .common-card .icon or can be overridden here */
    }
    .metric-card h3 { /* Metric card titles */
        /* color, font-size, margin-bottom, font-weight inherited from .common-card h3 */
        /* No specific overrides here unless needed for dashboard metric cards */
    }
    .metric-card p.metric-value { /* Specific class for the main value display */
        font-size: 2em; /* Larger font for the metric value */
        font-weight: bold;
        color: #ffffff; /* White color for emphasis */
        margin-top: 0.25rem; /* Adjust as needed */
        margin-bottom: 0;
    }

    /* Specific status icon colors for API status card in dashboard */
    .metric-card.api-status-connected .icon { color: #50C878; } /* Emerald Green */
    .metric-card.api-status-disconnected .icon { color: #FF6961; } /* Pastel Red */

    /* Activity Feed Styles for Dashboard */
    .activity-feed-container { /* Container for the activity feed section */
        margin-top: 2.5rem;
        padding: 0 1.5rem;
    }
    /* .activity-feed-container h2 is covered by shared h2 styles */

    .activity-item { /* Individual item in the feed */
        background-color: #1e2a47; /* Slightly lighter than common-card, for variety */
        padding: 0.85rem 1.25rem; /* Adjusted padding */
        border-radius: 6px;
        margin-bottom: 0.6rem; /* Slightly more space */
        font-size: 0.95em;
        border-left: 4px solid #3a6fbf; /* Accent Blue border */
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }
    .activity-item .timestamp { /* Timestamp within an activity item */
        color: #8080af; /* Muted purple for timestamp */
        font-weight: bold;
        font-size: 0.85em; /* Smaller timestamp */
        margin-right: 0.75em; /* More space after timestamp */
    }
    .activity-item strong { /* Agent name or key part of activity */
        color: #b0b0df; /* Tertiary heading color for emphasis */
    }
    </style>
    """, unsafe_allow_html=True)
|
| 211 |
+
|
| 212 |
+
# --- Page Functions ---
|
| 213 |
+
|
| 214 |
+
def nexus_dashboard_content():
    """Render the Nexus Dashboard: key-metric cards plus an activity feed.

    Relies on CSS classes from load_shared_css() (.common-card) and
    load_app_specific_css() (.metric-card, .activity-*), so both must be
    loaded before this is called.
    """
    # Uses .dashboard-title for its main heading
    st.markdown('<h1 class="dashboard-title">Tensorus Nexus</h1>', unsafe_allow_html=True)

    # System Health & Key Metrics
    # Uses .metric-card-container for the overall layout
    st.markdown('<div class="metric-card-container">', unsafe_allow_html=True)

    # Card 1: Total Tensors (currently always "N/A" — no backing endpoint yet)
    total_tensors_val = get_total_tensors_placeholder()
    st.markdown(f"""
    <div class="common-card metric-card">
        <div class="icon">⚙️</div>
        <h3>Total Tensors</h3>
        <p class="metric-value">{total_tensors_val}</p>
    </div>
    """, unsafe_allow_html=True)

    # Card 2: Active Datasets
    active_datasets_val = get_active_datasets_placeholder()
    st.markdown(f"""
    <div class="common-card metric-card">
        <div class="icon">📚</div>
        <h3>Active Datasets</h3>
        <p class="metric-value">{active_datasets_val}</p>
    </div>
    """, unsafe_allow_html=True)

    # Card 3: Agents Online
    agents_online_val = get_agents_online_placeholder()
    st.markdown(f"""
    <div class="common-card metric-card">
        <div class="icon">🤖</div>
        <h3>Agents Online</h3>
        <p class="metric-value">{agents_online_val}</p>
    </div>
    """, unsafe_allow_html=True)

    # Card 4: API Status
    # Cached wrapper so repeated reruns within 30s don't re-ping the API.
    @st.cache_data(ttl=30)
    def cached_get_api_status():
        return get_api_status()

    api_ok, _ = cached_get_api_status()
    api_status_text_val = "Connected" if api_ok else "Disconnected"
    # Add specific class for API status icon coloring based on shared status styles
    api_status_icon_class = "api-status-connected" if api_ok else "api-status-disconnected"
    api_icon_char = "✔️" if api_ok else "❌"
    st.markdown(f"""
    <div class="common-card metric-card {api_status_icon_class}">
        <div class="icon">{api_icon_char}</div>
        <h3>API Status</h3>
        <p class="metric-value">{api_status_text_val}</p>
    </div>
    """, unsafe_allow_html=True)

    st.markdown('</div>', unsafe_allow_html=True) # Close metric-card-container

    # Agent Activity Feed
    # Uses .activity-feed-container and h2 (which is styled by shared CSS)
    st.markdown('<div class="activity-feed-container">', unsafe_allow_html=True)
    st.markdown('<h2>Recent Agent Activity</h2>', unsafe_allow_html=True)

    # Placeholder activity items — static demo data, not read from the API.
    activity_items = [
        {"timestamp": "2023-10-27 10:05:15", "agent": "IngestionAgent", "action": "added 'img_new.png' to 'raw_images'"},
        {"timestamp": "2023-10-27 10:02:30", "agent": "RLAgent", "action": "completed training cycle, reward: 75.2"},
        {"timestamp": "2023-10-27 09:55:48", "agent": "MonitoringAgent", "action": "detected high CPU usage on node 'compute-01'"},
        {"timestamp": "2023-10-27 09:45:10", "agent": "IngestionAgent", "action": "processed batch of 100 sensor readings"},
    ]

    for item in activity_items:
        st.markdown(f"""
        <div class="activity-item">
            <span class="timestamp">[{item['timestamp']}]</span>
            <strong>{item['agent']}:</strong> {item['action']}
        </div>
        """, unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True) # Close activity-feed-container
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# --- Main Application ---
|
| 296 |
+
# Import the shared CSS loader
|
| 297 |
+
# Fall back to a no-op so the app still renders (unstyled) when the shared
# utilities module is missing or fails to import.
try:
    from pages.pages_shared_utils import load_css as load_shared_css
except ImportError:
    st.error("Failed to import shared CSS loader. Page styling will be incomplete.")
    def load_shared_css(): pass # Dummy function
|
| 302 |
+
|
| 303 |
+
def main():
    """Render the top navigation bar and route to the selected page.

    Navigation state lives in ``st.session_state.current_page`` and can be
    overridden by a ``?page=...`` URL query parameter (enables deep links).
    """
    load_shared_css()        # Shared styles first so page-specific rules can override them.
    load_app_specific_css()  # Then load app-specific styles (for dashboard).

    # Initialize session state for current page if not set
    if 'current_page' not in st.session_state:
        st.session_state.current_page = "Nexus Dashboard"

    # --- Top Navigation Bar ---
    # Maps internal page id -> label shown in the nav bar.
    nav_items = {
        "Nexus Dashboard": "Nexus Dashboard",
        "Agents": "Agents",
        "Explorer": "Explorer",
        "Query Hub": "Query Hub",
        "API Docs": "API Docs"
    }

    # Resolve the query-parameter override BEFORE rendering the nav links.
    # (Previously this ran inside the render loop, so links emitted before the
    # override computed their "active" class from stale session state, and the
    # query dict was rebuilt once per nav item.)
    # The param is left in place deliberately so deep links survive a refresh.
    query_params = st.query_params.to_dict()
    if "page" in query_params and query_params["page"] in nav_items:
        st.session_state.current_page = query_params["page"]

    nav_html_parts = ['<div class="topnav-container"><span class="logo">🧊 Tensorus</span><nav>']
    for page_id, page_name in nav_items.items():
        active_class = "active" if st.session_state.current_page == page_id else ""
        nav_html_parts.append(
            f'<a href="?page={page_id}" class="{active_class}" target="_self">{page_name}</a>'
        )
    nav_html_parts.append('</nav></div>')
    st.markdown("".join(nav_html_parts), unsafe_allow_html=True)

    # --- Content Area ---
    # app.py acts as a router: the dashboard renders inline; every other
    # selection switches to a dedicated Streamlit page.
    current = st.session_state.current_page
    if current == "Nexus Dashboard":
        nexus_dashboard_content()
    elif current == "Agents":
        st.switch_page("pages/control_panel_v2.py")
    elif current == "Explorer":
        st.switch_page("pages/data_explorer_v2.py")
    elif current == "Query Hub":
        st.switch_page("pages/nql_chatbot_v2.py")
    elif current == "API Docs":
        st.switch_page("pages/api_playground_v2.py")
    else:
        # Unrecognized state: reset to the dashboard, render it, and rerun so
        # the nav bar reflects the corrected selection.
        st.session_state.current_page = "Nexus Dashboard"
        nexus_dashboard_content()
        st.rerun()
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
if __name__ == "__main__":
    # --- Initialize Old Session State Keys (to avoid errors if they are still used by preserved code) ---
    # This should be phased out as those sections are rebuilt.
    # Each key is seeded with a harmless default so legacy pages that read
    # them before writing don't raise AttributeError/KeyError.
    if 'agent_status' not in st.session_state: st.session_state.agent_status = None
    if 'datasets' not in st.session_state: st.session_state.datasets = []
    if 'selected_dataset' not in st.session_state: st.session_state.selected_dataset = None
    if 'dataset_preview' not in st.session_state: st.session_state.dataset_preview = None
    if 'explorer_result' not in st.session_state: st.session_state.explorer_result = None
    if 'nql_response' not in st.session_state: st.session_state.nql_response = None

    main()
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: '3.8'
|
| 2 |
+
|
| 3 |
+
services:
|
| 4 |
+
app:
|
| 5 |
+
build: .
|
| 6 |
+
ports:
|
| 7 |
+
- "7860:7860"
|
| 8 |
+
volumes:
|
| 9 |
+
# Persist audit logs generated by the app container to the host
|
| 10 |
+
# The path inside the container is /app/tensorus_audit.log as defined in tensorus/audit.py (implicitly if WORKDIR is /app)
|
| 11 |
+
- ./tensorus_audit.log:/app/tensorus_audit.log
|
| 12 |
+
environment:
|
| 13 |
+
# Python settings
|
| 14 |
+
PYTHONUNBUFFERED: 1
|
| 15 |
+
# Tensorus App Settings
|
| 16 |
+
TENSORUS_STORAGE_BACKEND: postgres
|
| 17 |
+
TENSORUS_POSTGRES_HOST: db # Service name of the postgres container
|
| 18 |
+
TENSORUS_POSTGRES_PORT: 5432 # Default postgres port, container to container
|
| 19 |
+
TENSORUS_POSTGRES_USER: tensorus_user_compose
|
| 20 |
+
TENSORUS_POSTGRES_PASSWORD: tensorus_password_compose
|
| 21 |
+
TENSORUS_POSTGRES_DB: tensorus_db_compose
|
| 22 |
+
# TENSORUS_POSTGRES_DSN: "" # Can be set if preferred over individual params
|
| 23 |
+
TENSORUS_API_KEY_HEADER_NAME: "X-API-KEY"
|
| 24 |
+
TENSORUS_VALID_API_KEYS: "compose_key1,another_secure_key" # Example keys
|
| 25 |
+
TENSORUS_AUTH_JWT_ENABLED: "False"
|
| 26 |
+
# TENSORUS_AUTH_JWT_ISSUER: "your_issuer_here"
|
| 27 |
+
# TENSORUS_AUTH_JWT_AUDIENCE: "tensorus_api_audience"
|
| 28 |
+
# TENSORUS_AUTH_JWT_ALGORITHM: "RS256"
|
| 29 |
+
# TENSORUS_AUTH_JWT_JWKS_URI: "your_jwks_uri_here"
|
| 30 |
+
TENSORUS_AUTH_DEV_MODE_ALLOW_DUMMY_JWT: "False"
|
| 31 |
+
depends_on:
|
| 32 |
+
db:
|
| 33 |
+
condition: service_healthy # Wait for db to be healthy (Postgres specific healthcheck needed in db service)
|
| 34 |
+
# For simpler startup without healthcheck, just `depends_on: - db` is fine,
|
| 35 |
+
# but app might start before DB is ready. Entrypoint script in app can handle retries then.
|
| 36 |
+
|
| 37 |
+
db:
|
| 38 |
+
image: postgres:15-alpine
|
| 39 |
+
volumes:
|
| 40 |
+
- postgres_data:/var/lib/postgresql/data/ # Persist data
|
| 41 |
+
environment:
|
| 42 |
+
POSTGRES_USER: tensorus_user_compose # Must match app's TENSORUS_POSTGRES_USER
|
| 43 |
+
POSTGRES_PASSWORD: tensorus_password_compose # Must match app's TENSORUS_POSTGRES_PASSWORD
|
| 44 |
+
POSTGRES_DB: tensorus_db_compose # Must match app's TENSORUS_POSTGRES_DB
|
| 45 |
+
ports:
|
| 46 |
+
- "5433:5432" # Expose Postgres on host port 5433 to avoid conflict if local PG runs on 5432
|
| 47 |
+
healthcheck: # Basic healthcheck for Postgres
|
| 48 |
+
test: ["CMD-SHELL", "pg_isready -U tensorus_user_compose -d tensorus_db_compose"]
|
| 49 |
+
interval: 10s
|
| 50 |
+
timeout: 5s
|
| 51 |
+
retries: 5
|
| 52 |
+
|
| 53 |
+
volumes:
|
| 54 |
+
postgres_data: # Defines the named volume for data persistence
|
| 55 |
+
driver: local
|
docs/api_guide.md
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# API Guide
|
| 2 |
+
|
| 3 |
+
The Tensorus Metadata System provides a comprehensive RESTful API for managing and interacting with tensor metadata.
|
| 4 |
+
|
| 5 |
+
## Interactive API Documentation (Swagger UI)
|
| 6 |
+
|
| 7 |
+
The API is self-documenting using OpenAPI. Once the application is running (e.g., locally at `http://localhost:7860`), you can access the interactive Swagger UI at:
|
| 8 |
+
|
| 9 |
+
* **`/docs`**: [http://localhost:7860/docs](http://localhost:7860/docs)
|
| 10 |
+
|
| 11 |
+
This interface allows you to explore all available endpoints, view their request and response schemas, and even try out API calls directly from your browser.
|
| 12 |
+
|
| 13 |
+
## Alternative API Documentation (ReDoc)
|
| 14 |
+
|
| 15 |
+
An alternative ReDoc interface is also available at:
|
| 16 |
+
|
| 17 |
+
* **`/redoc`**: [http://localhost:7860/redoc](http://localhost:7860/redoc)
|
| 18 |
+
|
| 19 |
+
## Main API Categories
|
| 20 |
+
|
| 21 |
+
The API is organized into several categories based on functionality:
|
| 22 |
+
|
| 23 |
+
* **Tensor Descriptors:** Core operations for creating, reading, updating, deleting, and listing tensor descriptors.
|
| 24 |
+
* **Semantic Metadata (Per Tensor):** Managing human-readable names, descriptions, etc., associated with specific tensors, nested under `/tensor_descriptors/{tensor_id}/semantic/`.
|
| 25 |
+
* **Extended Metadata (Per Tensor):** CRUD operations for detailed metadata types, nested under `/tensor_descriptors/{tensor_id}/`:
|
| 26 |
+
* Lineage Metadata (`/lineage`)
|
| 27 |
+
* Computational Metadata (`/computational`)
|
| 28 |
+
* Quality Metadata (`/quality`)
|
| 29 |
+
* Relational Metadata (`/relational`)
|
| 30 |
+
* Usage Metadata (`/usage`)
|
| 31 |
+
* **Versioning & Lineage:** Endpoints for creating tensor versions and managing lineage relationships at a higher level.
|
| 32 |
+
* **Search & Aggregation:** Advanced querying, text-based search across metadata, and metadata aggregation.
|
| 33 |
+
* **Import/Export:** Endpoints for exporting and importing tensor metadata in JSON format.
|
| 34 |
+
* **Management:** Health checks and system metrics.
|
| 35 |
+
* **Authentication:** Write operations (POST, PUT, PATCH, DELETE) are protected by API keys. See [Installation and Configuration](./installation.md) for details on setting API keys. The API key should be passed in the HTTP header specified by `TENSORUS_API_KEY_HEADER_NAME` (default: `X-API-KEY`).
|
| 36 |
+
|
| 37 |
+
Please refer to the interactive `/docs` for detailed information on each endpoint, including request parameters, request bodies, and response structures.
|
docs/index.md
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Welcome to the Tensorus Metadata System
|
| 2 |
+
|
| 3 |
+
The Tensorus Metadata System is a comprehensive solution for managing metadata
|
| 4 |
+
associated with tensor data in an Agentic Tensor Database. It enables advanced
|
| 5 |
+
capabilities such as semantic search, lineage tracking, and version control.
|
| 6 |
+
|
| 7 |
+
This documentation provides an overview of the system, how to install and
|
| 8 |
+
configure it, and details about its API and metadata schemas.
|
| 9 |
+
|
| 10 |
+
## Key Features
|
| 11 |
+
|
| 12 |
+
* **Rich Metadata Schemas:** Describe tensors with detailed structural, semantic, lineage, computational, quality, relational, and usage metadata.
|
| 13 |
+
* **Flexible Storage:** Supports in-memory storage for quick testing and a PostgreSQL backend for persistent, scalable storage.
|
| 14 |
+
* **Powerful API:** A comprehensive RESTful API for CRUD operations, advanced querying, search, aggregation, versioning, and lineage tracking.
|
| 15 |
+
* **Data Export/Import:** Utilities to export and import metadata in a standard JSON format.
|
| 16 |
+
* **Security:** Basic API key authentication for write operations and audit logging.
|
| 17 |
+
* **Monitoring:** Health check and basic metrics endpoints.
|
| 18 |
+
* **Enterprise Ready (Conceptual):** Designed with enterprise features like JWT authentication in mind.
|
| 19 |
+
* **Analytics:** Example APIs for deriving insights from metadata (e.g., stale tensors, complex tensors, co-occurring tags).
|
| 20 |
+
* **Dockerized:** Easy to deploy using Docker and Docker Compose.
|
| 21 |
+
|
| 22 |
+
Navigate through the documentation using the sidebar to learn more.
|
docs/installation.md
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Installation and Configuration
|
| 2 |
+
|
| 3 |
+
## Installation
|
| 4 |
+
|
| 5 |
+
### Using Docker (Recommended for Local Development/Testing)
|
| 6 |
+
|
| 7 |
+
The easiest way to get Tensorus and its PostgreSQL backend running locally is by using Docker Compose:
|
| 8 |
+
|
| 9 |
+
1. Ensure you have Docker and Docker Compose installed.
|
| 10 |
+
2. Clone the repository (if applicable) or ensure you have the `docker-compose.yml` and application files.
|
| 11 |
+
3. From the project root, run:
|
| 12 |
+
```bash
|
| 13 |
+
docker-compose up --build
|
| 14 |
+
```
|
| 15 |
+
4. The API will typically be available at `http://localhost:7860`.
|
| 16 |
+
|
| 17 |
+
### Manual Installation (From Source)
|
| 18 |
+
|
| 19 |
+
1. Ensure you have Python 3.9+ installed.
|
| 20 |
+
2. Clone the repository.
|
| 21 |
+
3. Create and activate a virtual environment:
|
| 22 |
+
```bash
|
| 23 |
+
python -m venv venv
|
| 24 |
+
source venv/bin/activate # On Windows: venv\Scripts\activate
|
| 25 |
+
```
|
| 26 |
+
4. Install dependencies:
|
| 27 |
+
```bash
|
| 28 |
+
pip install -r requirements.txt
|
| 29 |
+
```
|
| 30 |
+
Heavy machine-learning libraries used by the optional models are not
|
| 31 |
+
installed by default. Install them with the `[models]` extra when needed:
|
| 32 |
+
```bash
|
| 33 |
+
pip install -e .[models]
|
| 34 |
+
```
|
| 35 |
+
If you intend to run the test suite, also install the test requirements:
|
| 36 |
+
```bash
|
| 37 |
+
pip install -r requirements-test.txt
|
| 38 |
+
```
|
| 39 |
+
This ensures `fastapi>=0.110` is installed so the API is compatible with
|
| 40 |
+
Pydantic v2.
|
| 41 |
+
5. *(Optional)* Install the example models package. The built-in models that
|
| 42 |
+
were previously part of this repository now live at
|
| 43 |
+
[https://github.com/tensorus/models](https://github.com/tensorus/models):
|
| 44 |
+
```bash
|
| 45 |
+
pip install tensorus-models
|
| 46 |
+
```
|
| 47 |
+
6. Set up the necessary environment variables (see Configuration below).
|
| 48 |
+
7. Run the application using Uvicorn:
|
| 49 |
+
```bash
|
| 50 |
+
uvicorn tensorus.api.main:app --host 0.0.0.0 --port 7860 --reload # for development
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## Configuration
|
| 54 |
+
|
| 55 |
+
Tensorus is configured via environment variables. Key variables include:
|
| 56 |
+
|
| 57 |
+
* `TENSORUS_STORAGE_BACKEND`: Specifies the storage backend.
|
| 58 |
+
* `in_memory` (default): Uses in-memory storage (data is not persisted).
|
| 59 |
+
* `postgres`: Uses PostgreSQL.
|
| 60 |
+
* `TENSORUS_POSTGRES_HOST`: Hostname for PostgreSQL (e.g., `localhost` or `db` if using Docker Compose).
|
| 61 |
+
* `TENSORUS_POSTGRES_PORT`: Port for PostgreSQL (e.g., `5432`).
|
| 62 |
+
* `TENSORUS_POSTGRES_USER`: Username for PostgreSQL.
|
| 63 |
+
* `TENSORUS_POSTGRES_PASSWORD`: Password for PostgreSQL.
|
| 64 |
+
* `TENSORUS_POSTGRES_DB`: Database name for PostgreSQL.
|
| 65 |
+
* `TENSORUS_POSTGRES_DSN`: Alternative DSN connection string for PostgreSQL.
|
| 66 |
+
* `TENSORUS_VALID_API_KEYS`: Comma-separated list of valid API keys for write operations (e.g., `key1,key2,anotherkey`). If empty or not set, write operations requiring API keys will be inaccessible.
|
| 67 |
+
* `TENSORUS_API_KEY_HEADER_NAME`: HTTP header name for the API key (default: `X-API-KEY`).
|
| 68 |
+
* `TENSORUS_MINIMAL_IMPORT`: Set to any value to skip importing the optional
|
| 69 |
+
`tensorus-models` package for a lightweight installation.
|
| 70 |
+
|
| 71 |
+
### JWT Authentication (Conceptual - For Future Use)
|
| 72 |
+
* `TENSORUS_AUTH_JWT_ENABLED`: `True` or `False` (default `False`).
|
| 73 |
+
* `TENSORUS_AUTH_JWT_ISSUER`: URL of the JWT issuer.
|
| 74 |
+
* `TENSORUS_AUTH_JWT_AUDIENCE`: Expected audience for JWTs.
|
| 75 |
+
* `TENSORUS_AUTH_JWT_ALGORITHM`: Algorithm (default `RS256`).
|
| 76 |
+
* `TENSORUS_AUTH_JWT_JWKS_URI`: URI to fetch JWKS.
|
| 77 |
+
* `TENSORUS_AUTH_DEV_MODE_ALLOW_DUMMY_JWT`: `True` to allow dummy JWTs for development if JWT auth is enabled (default `False`).
|
| 78 |
+
|
| 79 |
+
Refer to the `docker-compose.yml` for example environment variable settings when running with Docker. For manual setup, export these variables in your shell or use a `.env` file (if your setup supports it, though direct environment variables are primary).
|
| 80 |
+
|
| 81 |
+
## Running Tests
|
| 82 |
+
|
| 83 |
+
Tensorus includes Python unit tests. After installing the dependencies you can run them with:
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
pytest
|
| 87 |
+
```
|
docs/metadata_schemas.md
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Metadata Schemas Overview
|
| 2 |
+
|
| 3 |
+
The Tensorus Metadata System employs a rich set of Pydantic schemas to define the structure and validation rules for metadata. These schemas ensure data consistency and provide a clear contract for API interactions.
|
| 4 |
+
|
| 5 |
+
## Core Schema: TensorDescriptor
|
| 6 |
+
|
| 7 |
+
The `TensorDescriptor` is the fundamental metadata unit. It captures essential information about a tensor's structure, storage, and identity. Key fields include:
|
| 8 |
+
|
| 9 |
+
* `tensor_id`: Unique UUID for the tensor.
|
| 10 |
+
* `dimensionality`: Number of dimensions.
|
| 11 |
+
* `shape`: Size of each dimension.
|
| 12 |
+
* `data_type`: Data type of tensor elements (e.g., `float32`, `int64`).
|
| 13 |
+
* `storage_format`: How the tensor is physically stored (e.g., `raw`, `numpy_npz`).
|
| 14 |
+
* `creation_timestamp`, `last_modified_timestamp`.
|
| 15 |
+
* `owner`, `access_control` (which itself is a nested model detailing read/write permissions).
|
| 16 |
+
* `byte_size`, `checksum` (e.g., MD5, SHA256 of the tensor data).
|
| 17 |
+
* `compression_info` (details if tensor data is compressed).
|
| 18 |
+
* `tags`: A list of arbitrary string tags for categorization.
|
| 19 |
+
* `metadata`: A flexible dictionary for any other custom key-value metadata.
|
| 20 |
+
|
| 21 |
+
## Semantic Metadata
|
| 22 |
+
|
| 23 |
+
Associated with a `TensorDescriptor` (one-to-many, identified by `name` per tensor), this schema describes the meaning and context of the tensor data. Key fields:
|
| 24 |
+
|
| 25 |
+
* `tensor_id`: Links to the parent `TensorDescriptor`.
|
| 26 |
+
* `name`: The specific name of this semantic annotation (e.g., "primary_class_label", "object_bounding_boxes", "feature_description_set").
|
| 27 |
+
* `description`: Detailed explanation of this semantic annotation.
|
| 28 |
+
* *(Note: The original broader concept of a single SemanticMetadata object per tensor with fields like `domain`, `purpose`, etc., has evolved. These broader concepts might be captured in `TensorDescriptor.tags`, `TensorDescriptor.metadata`, or specific named `SemanticMetadata` entries.)*
|
| 29 |
+
|
| 30 |
+
## Extended Metadata Schemas
|
| 31 |
+
|
| 32 |
+
These provide more detailed and specialized information, typically one-to-one with a `TensorDescriptor`:
|
| 33 |
+
|
| 34 |
+
* **`LineageMetadata`**: Tracks origin (`source`), parent tensors (`parent_tensors`), a history of transformations (`transformation_history`), version string (`version`), version control details (`version_control`), and other provenance information (`provenance`).
|
| 35 |
+
* `LineageSource`: Details the origin (e.g., file, API, computation).
|
| 36 |
+
* `ParentTensorLink`: Links to parent tensors and describes the relationship.
|
| 37 |
+
* `TransformationStep`: Describes an operation in the transformation history.
|
| 38 |
+
* `VersionControlInfo`: Git-like versioning details for the source or tensor itself.
|
| 39 |
+
|
| 40 |
+
* **`ComputationalMetadata`**: Describes how the tensor was computed, including the `algorithm` used, `parameters`, reference to a `computational_graph_ref`, `execution_environment` details, `computation_time_seconds`, and `hardware_info`.
|
| 41 |
+
|
| 42 |
+
* **`QualityMetadata`**: Captures information about data quality. Includes:
|
| 43 |
+
* `statistics` (`QualityStatistics` model: min, max, mean, std_dev, median, variance, percentiles, histogram).
|
| 44 |
+
* `missing_values` (`MissingValuesInfo` model: count, percentage, imputation strategy).
|
| 45 |
+
* `outliers` (`OutlierInfo` model: count, percentage, detection method).
|
| 46 |
+
- `noise_level`, `confidence_score`, `validation_results` (custom checks), `drift_score`.
|
| 47 |
+
|
| 48 |
+
* **`RelationalMetadata`**: Describes relationships to other tensors (`related_tensors` via `RelatedTensorLink`), membership in `collections`, explicit `dependencies` on other tensors, and `dataset_context`.
|
| 49 |
+
|
| 50 |
+
* **`UsageMetadata`**: Tracks how the tensor is used:
|
| 51 |
+
* `access_history` (list of `UsageAccessRecord` detailing who/what accessed it, when, and how).
|
| 52 |
+
* `usage_frequency` (auto-calculated from history or can be set).
|
| 53 |
+
* `last_accessed_at` (auto-calculated from history or can be set).
|
| 54 |
+
* `application_references` (list of applications or models that use this tensor).
|
| 55 |
+
* `purpose` (dictionary describing purposes, e.g. for specific model training).
|
| 56 |
+
|
| 57 |
+
For the exact field definitions, types, optionality, default values, and validation rules for each schema, please refer to the source code in the `tensorus/metadata/schemas.py` module or consult the schemas provided in the interactive API documentation at `/docs`. The I/O schemas for export/import are defined in `tensorus/metadata/schemas_iodata.py`.
|
mkdocs.yml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
site_name: Tensorus Metadata System
|
| 2 |
+
site_description: Documentation for the Tensorus Metadata System.
|
| 3 |
+
nav:
|
| 4 |
+
- Home: index.md
|
| 5 |
+
- Installation: installation.md
|
| 6 |
+
- API Guide: api_guide.md
|
| 7 |
+
- Metadata Schemas: metadata_schemas.md
|
| 8 |
+
theme: readthedocs
|
| 9 |
+
# Optional: Add repo_url, edit_uri for "Edit on GitHub" links if relevant
|
| 10 |
+
# repo_url: https://github.com/your_username/tensorus
|
| 11 |
+
# edit_uri: edit/main/docs/
|
pages/1_Dashboard.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pages/1_Dashboard.py (Modifications for Step 3)
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
import plotly.express as px
|
| 7 |
+
import time
|
| 8 |
+
# Updated imports to use API-backed functions
|
| 9 |
+
from .ui_utils import get_dashboard_metrics, list_all_agents, get_agent_status # MODIFIED
|
| 10 |
+
|
| 11 |
+
st.set_page_config(page_title="Tensorus Dashboard", layout="wide")
|
| 12 |
+
|
| 13 |
+
st.title("📊 Operations Dashboard")
|
| 14 |
+
st.caption("Overview of Tensorus datasets and agent activity from API.")
|
| 15 |
+
|
| 16 |
+
# --- Fetch Data ---
|
| 17 |
+
# Use st.cache_data for API calls that don't need constant updates
|
| 18 |
+
# or manage refresh manually. For simplicity, call directly on rerun/button click.
|
| 19 |
+
metrics_data = None
|
| 20 |
+
agent_list = None # Fetch full agent list for detailed status display
|
| 21 |
+
|
| 22 |
+
# Button to force refresh
|
| 23 |
+
if st.button("🔄 Refresh Dashboard Data"):
|
| 24 |
+
# Clear previous cache if any or just proceed to refetch
|
| 25 |
+
metrics_data = get_dashboard_metrics()
|
| 26 |
+
agent_list = list_all_agents()
|
| 27 |
+
st.session_state['dashboard_metrics'] = metrics_data # Store in session state
|
| 28 |
+
st.session_state['dashboard_agents'] = agent_list
|
| 29 |
+
st.rerun() # Rerun the script to reflect fetched data
|
| 30 |
+
else:
|
| 31 |
+
# Try to load from session state or fetch if not present
|
| 32 |
+
if 'dashboard_metrics' not in st.session_state:
|
| 33 |
+
st.session_state['dashboard_metrics'] = get_dashboard_metrics()
|
| 34 |
+
if 'dashboard_agents' not in st.session_state:
|
| 35 |
+
st.session_state['dashboard_agents'] = list_all_agents()
|
| 36 |
+
|
| 37 |
+
metrics_data = st.session_state['dashboard_metrics']
|
| 38 |
+
agent_list = st.session_state['dashboard_agents']
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# --- Display Metrics ---
|
| 42 |
+
st.subheader("System Metrics")
|
| 43 |
+
if metrics_data:
|
| 44 |
+
col1, col2, col3 = st.columns(3)
|
| 45 |
+
col1.metric("Total Datasets", metrics_data.get('dataset_count', 'N/A'))
|
| 46 |
+
col2.metric("Total Records (Est.)", f"{metrics_data.get('total_records_est', 0):,}")
|
| 47 |
+
# Agent status summary from metrics
|
| 48 |
+
agent_summary = metrics_data.get('agent_status_summary', {})
|
| 49 |
+
running_agents = agent_summary.get('running', 0) + agent_summary.get('starting', 0)
|
| 50 |
+
col3.metric("Running Agents", running_agents)
|
| 51 |
+
|
| 52 |
+
st.divider()
|
| 53 |
+
|
| 54 |
+
# --- Performance Metrics Row ---
|
| 55 |
+
st.subheader("Performance Indicators (Simulated)")
|
| 56 |
+
pcol1, pcol2, pcol3, pcol4 = st.columns(4)
|
| 57 |
+
pcol1.metric("Ingestion Rate (rec/s)", f"{metrics_data.get('data_ingestion_rate', 0.0):.1f}")
|
| 58 |
+
pcol2.metric("Avg Query Latency (ms)", f"{metrics_data.get('avg_query_latency_ms', 0.0):.1f}")
|
| 59 |
+
pcol3.metric("Latest RL Reward", f"{metrics_data.get('rl_latest_reward', 'N/A')}")
|
| 60 |
+
pcol4.metric("Best AutoML Score", f"{metrics_data.get('automl_best_score', 'N/A')}")
|
| 61 |
+
|
| 62 |
+
else:
|
| 63 |
+
st.warning("Could not fetch dashboard metrics from the API.")
|
| 64 |
+
|
| 65 |
+
st.divider()
|
| 66 |
+
|
| 67 |
+
# --- Agent Status Details ---
|
| 68 |
+
st.subheader("Agent Status")
|
| 69 |
+
if agent_list:
|
| 70 |
+
num_agents = len(agent_list)
|
| 71 |
+
cols = st.columns(max(1, num_agents)) # Create columns for agents
|
| 72 |
+
|
| 73 |
+
for i, agent_info in enumerate(agent_list):
|
| 74 |
+
agent_id = agent_info.get('id')
|
| 75 |
+
with cols[i % len(cols)]: # Distribute agents into columns
|
| 76 |
+
with st.container(border=True):
|
| 77 |
+
st.markdown(f"**{agent_info.get('name', 'Unknown Agent')}** (`{agent_id}`)")
|
| 78 |
+
# Fetch detailed status for more info if needed, or use basic status from list
|
| 79 |
+
# status_details = get_agent_status(agent_id) # Can make page slower
|
| 80 |
+
status = agent_info.get('status', 'unknown')
|
| 81 |
+
status_color = "green" if status in ["running", "starting"] else ("orange" if status in ["stopping"] else ("red" if status in ["error"] else "grey"))
|
| 82 |
+
st.markdown(f"Status: :{status_color}[**{status.upper()}**]")
|
| 83 |
+
|
| 84 |
+
# Display config from the list info
|
| 85 |
+
with st.expander("Config"):
|
| 86 |
+
st.json(agent_info.get('config', {}), expanded=False)
|
| 87 |
+
else:
|
| 88 |
+
st.warning("Could not fetch agent list from the API.")
|
| 89 |
+
|
| 90 |
+
st.divider()
|
| 91 |
+
|
| 92 |
+
# --- Performance Monitoring Chart (Using simulated data from metrics for now) ---
|
| 93 |
+
st.subheader("Performance Monitoring (Placeholder Graph)")
|
| 94 |
+
if metrics_data:
|
| 95 |
+
# Create some fake historical data for plotting based on current metrics
|
| 96 |
+
history_len = 20
|
| 97 |
+
# Use session state to persist some history for smoother simulation
|
| 98 |
+
if 'sim_history' not in st.session_state:
|
| 99 |
+
st.session_state['sim_history'] = pd.DataFrame({
|
| 100 |
+
'Ingestion Rate': np.random.rand(history_len) * metrics_data.get('data_ingestion_rate', 10),
|
| 101 |
+
'Query Latency': np.random.rand(history_len) * metrics_data.get('avg_query_latency_ms', 100),
|
| 102 |
+
'RL Reward': np.random.randn(history_len) * 5 + (metrics_data.get('rl_latest_reward', 0) or 0)
|
| 103 |
+
})
|
| 104 |
+
|
| 105 |
+
# Update history with latest point
|
| 106 |
+
latest_data = pd.DataFrame({
|
| 107 |
+
'Ingestion Rate': [metrics_data.get('data_ingestion_rate', 0.0)],
|
| 108 |
+
'Query Latency': [metrics_data.get('avg_query_latency_ms', 0.0)],
|
| 109 |
+
'RL Reward': [metrics_data.get('rl_latest_reward', 0) or 0] # Handle None
|
| 110 |
+
})
|
| 111 |
+
st.session_state['sim_history'] = pd.concat([st.session_state['sim_history'].iloc[1:], latest_data], ignore_index=True)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# Use Plotly for better interactivity
|
| 115 |
+
try:
|
| 116 |
+
fig = px.line(st.session_state['sim_history'], title="Simulated Performance Metrics Over Time")
|
| 117 |
+
fig.update_layout(legend_title_text='Metrics')
|
| 118 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 119 |
+
except Exception as e:
|
| 120 |
+
st.warning(f"Could not display performance chart: {e}")
|
| 121 |
+
|
| 122 |
+
else:
|
| 123 |
+
st.info("Performance metrics unavailable.")
|
| 124 |
+
|
| 125 |
+
st.caption(f"Dashboard data timestamp: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(metrics_data.get('timestamp', time.time())) if metrics_data else time.time())}")
|
pages/2_Control_Panel.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pages/2_Control_Panel.py (Modifications for Step 3)
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import time
|
| 5 |
+
import json
|
| 6 |
+
# Use the updated API-backed functions
|
| 7 |
+
from .ui_utils import (
|
| 8 |
+
list_all_agents,
|
| 9 |
+
get_agent_status,
|
| 10 |
+
get_agent_logs,
|
| 11 |
+
start_agent,
|
| 12 |
+
stop_agent,
|
| 13 |
+
update_agent_config,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
st.set_page_config(page_title="Agent Control Panel", layout="wide")
|
| 17 |
+
|
| 18 |
+
st.title("🕹️ Multi-Agent Control Panel")
|
| 19 |
+
st.caption("Manage and monitor Tensorus agents via API.")
|
| 20 |
+
|
| 21 |
+
# Fetch agent list from API
|
| 22 |
+
agent_list = list_all_agents()
|
| 23 |
+
|
| 24 |
+
if not agent_list:
|
| 25 |
+
st.error("Could not fetch agent list from API. Please ensure the backend is running and reachable.")
|
| 26 |
+
st.stop()
|
| 27 |
+
|
| 28 |
+
# Create a mapping from name to ID for easier selection
|
| 29 |
+
# Handle potential duplicate names if necessary, though IDs should be unique
|
| 30 |
+
agent_options = {agent['name']: agent['id'] for agent in agent_list}
|
| 31 |
+
# Add ID to name if names aren't unique (optional robustness)
|
| 32 |
+
# agent_options = {f"{agent['name']} ({agent['id']})": agent['id'] for agent in agent_list}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
selected_agent_name = st.selectbox("Select Agent:", options=agent_options.keys())
|
| 36 |
+
|
| 37 |
+
if selected_agent_name:
|
| 38 |
+
selected_agent_id = agent_options[selected_agent_name]
|
| 39 |
+
st.divider()
|
| 40 |
+
st.subheader(f"Control: {selected_agent_name} (`{selected_agent_id}`)")
|
| 41 |
+
|
| 42 |
+
# Use session state to store fetched status and logs for the selected agent
|
| 43 |
+
# This avoids refetching constantly unless a refresh is triggered
|
| 44 |
+
agent_state_key = f"agent_status_{selected_agent_id}"
|
| 45 |
+
agent_logs_key = f"agent_logs_{selected_agent_id}"
|
| 46 |
+
|
| 47 |
+
# Button to force refresh status and logs
|
| 48 |
+
if st.button(f"🔄 Refresh Status & Logs##{selected_agent_id}"): # Unique key per agent
|
| 49 |
+
st.session_state[agent_state_key] = get_agent_status(selected_agent_id)
|
| 50 |
+
st.session_state[agent_logs_key] = get_agent_logs(selected_agent_id)
|
| 51 |
+
st.rerun() # Rerun to display refreshed data
|
| 52 |
+
|
| 53 |
+
# Fetch status if not in session state or refresh button wasn't just clicked
|
| 54 |
+
if agent_state_key not in st.session_state:
|
| 55 |
+
st.session_state[agent_state_key] = get_agent_status(selected_agent_id)
|
| 56 |
+
|
| 57 |
+
status_info = st.session_state[agent_state_key]
|
| 58 |
+
|
| 59 |
+
if status_info:
|
| 60 |
+
status = status_info.get('status', 'unknown')
|
| 61 |
+
status_color = "green" if status in ["running", "starting"] else ("orange" if status in ["stopping"] else ("red" if status in ["error"] else "grey"))
|
| 62 |
+
st.markdown(f"Current Status: :{status_color}[**{status.upper()}**]")
|
| 63 |
+
last_log_ts = status_info.get('last_log_timestamp')
|
| 64 |
+
if last_log_ts:
|
| 65 |
+
st.caption(f"Last Log Entry (approx.): {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(last_log_ts))}")
|
| 66 |
+
else:
|
| 67 |
+
st.error(f"Could not retrieve status for agent '{selected_agent_name}'.")
|
| 68 |
+
|
| 69 |
+
# Control Buttons (Now call API functions)
|
| 70 |
+
col1, col2, col3 = st.columns([1, 1, 5])
|
| 71 |
+
is_running = status_info and status_info.get('status') == 'running'
|
| 72 |
+
is_stopped = status_info and status_info.get('status') == 'stopped'
|
| 73 |
+
|
| 74 |
+
with col1:
|
| 75 |
+
start_disabled = not is_stopped # Disable if not stopped
|
| 76 |
+
if st.button("▶️ Start", key=f"start_{selected_agent_id}", disabled=start_disabled):
|
| 77 |
+
if start_agent(selected_agent_id): # API call returns success/fail
|
| 78 |
+
# Trigger refresh after short delay to allow backend state change (optimistic)
|
| 79 |
+
time.sleep(1.0)
|
| 80 |
+
# Clear state cache and rerun
|
| 81 |
+
if agent_state_key in st.session_state: del st.session_state[agent_state_key]
|
| 82 |
+
if agent_logs_key in st.session_state: del st.session_state[agent_logs_key]
|
| 83 |
+
st.rerun()
|
| 84 |
+
with col2:
|
| 85 |
+
stop_disabled = not is_running # Disable if not running
|
| 86 |
+
if st.button("⏹️ Stop", key=f"stop_{selected_agent_id}", disabled=stop_disabled):
|
| 87 |
+
if stop_agent(selected_agent_id): # API call returns success/fail
|
| 88 |
+
time.sleep(1.0)
|
| 89 |
+
if agent_state_key in st.session_state: del st.session_state[agent_state_key]
|
| 90 |
+
if agent_logs_key in st.session_state: del st.session_state[agent_logs_key]
|
| 91 |
+
st.rerun()
|
| 92 |
+
|
| 93 |
+
st.divider()
|
| 94 |
+
|
| 95 |
+
# Configuration & Logs
|
| 96 |
+
tab1, tab2 = st.tabs(["Configuration", "Logs"])
|
| 97 |
+
|
| 98 |
+
with tab1:
|
| 99 |
+
if status_info and 'config' in status_info:
|
| 100 |
+
current_config = status_info['config']
|
| 101 |
+
st.write("Current configuration:")
|
| 102 |
+
st.json(current_config)
|
| 103 |
+
|
| 104 |
+
with st.expander("Edit configuration"):
|
| 105 |
+
form = st.form(key=f"cfg_form_{selected_agent_id}")
|
| 106 |
+
config_text = form.text_area(
|
| 107 |
+
"Configuration JSON",
|
| 108 |
+
value=json.dumps(current_config, indent=2),
|
| 109 |
+
height=200,
|
| 110 |
+
)
|
| 111 |
+
submitted = form.form_submit_button("Update")
|
| 112 |
+
if submitted:
|
| 113 |
+
try:
|
| 114 |
+
new_cfg = json.loads(config_text)
|
| 115 |
+
if update_agent_config(selected_agent_id, new_cfg):
|
| 116 |
+
if agent_state_key in st.session_state:
|
| 117 |
+
del st.session_state[agent_state_key]
|
| 118 |
+
time.sleep(0.5)
|
| 119 |
+
st.rerun()
|
| 120 |
+
except json.JSONDecodeError as e:
|
| 121 |
+
form.error(f"Invalid JSON: {e}")
|
| 122 |
+
else:
|
| 123 |
+
st.warning("Configuration not available.")
|
| 124 |
+
|
| 125 |
+
with tab2:
|
| 126 |
+
st.write("Recent logs (fetched from API):")
|
| 127 |
+
# Fetch logs if not in session state
|
| 128 |
+
if agent_logs_key not in st.session_state:
|
| 129 |
+
st.session_state[agent_logs_key] = get_agent_logs(selected_agent_id)
|
| 130 |
+
|
| 131 |
+
logs = st.session_state[agent_logs_key]
|
| 132 |
+
if logs is not None:
|
| 133 |
+
st.code("\n".join(logs), language="log")
|
| 134 |
+
else:
|
| 135 |
+
st.error("Could not retrieve logs.")
|
| 136 |
+
|
| 137 |
+
else:
|
| 138 |
+
st.info("Select an agent from the dropdown above.")
|
pages/3_NQL_Chatbot.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pages/3_NQL_Chatbot.py
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
from .ui_utils import execute_nql_query # MODIFIED
|
| 5 |
+
import pandas as pd # Added import for pandas
|
| 6 |
+
|
| 7 |
+
st.set_page_config(page_title="NQL Chatbot", layout="wide")
|
| 8 |
+
|
| 9 |
+
st.title("💬 Natural Query Language (NQL) Chatbot")
|
| 10 |
+
st.caption("Query Tensorus datasets using natural language.")
|
| 11 |
+
st.info("Backend uses Regex-based NQL Agent. LLM integration is future work.")
|
| 12 |
+
|
| 13 |
+
# Initialize chat history
|
| 14 |
+
if "messages" not in st.session_state:
|
| 15 |
+
st.session_state.messages = []
|
| 16 |
+
|
| 17 |
+
# Display chat messages from history on app rerun
|
| 18 |
+
for message in st.session_state.messages:
|
| 19 |
+
with st.chat_message(message["role"]):
|
| 20 |
+
st.markdown(message["content"])
|
| 21 |
+
if "results" in message and message["results"] is not None and not message["results"].empty: # Check if DataFrame is not None and not empty
|
| 22 |
+
st.dataframe(message["results"], use_container_width=True) # Display results as dataframe
|
| 23 |
+
elif "error" in message:
|
| 24 |
+
st.error(message["error"])
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# React to user input
|
| 28 |
+
if prompt := st.chat_input("Enter your query (e.g., 'get all data from my_dataset')"):
|
| 29 |
+
# Display user message in chat message container
|
| 30 |
+
st.chat_message("user").markdown(prompt)
|
| 31 |
+
# Add user message to chat history
|
| 32 |
+
st.session_state.messages.append({"role": "user", "content": prompt})
|
| 33 |
+
|
| 34 |
+
# Get assistant response from NQL Agent API
|
| 35 |
+
with st.spinner("Processing query..."):
|
| 36 |
+
nql_response = execute_nql_query(prompt)
|
| 37 |
+
|
| 38 |
+
response_content = ""
|
| 39 |
+
results_df = None
|
| 40 |
+
error_msg = None
|
| 41 |
+
|
| 42 |
+
if nql_response:
|
| 43 |
+
response_content = nql_response.get("message", "Error processing response.")
|
| 44 |
+
if nql_response.get("success"):
|
| 45 |
+
results_list = nql_response.get("results")
|
| 46 |
+
if results_list:
|
| 47 |
+
# Convert results list (containing dicts with 'metadata', 'shape', etc.) to DataFrame
|
| 48 |
+
# Extract relevant fields for display
|
| 49 |
+
display_data = []
|
| 50 |
+
for res in results_list:
|
| 51 |
+
row = {
|
| 52 |
+
"record_id": res["metadata"].get("record_id"),
|
| 53 |
+
"shape": str(res.get("shape")), # Convert shape list to string
|
| 54 |
+
"dtype": res.get("dtype"),
|
| 55 |
+
**res["metadata"] # Flatten metadata into columns
|
| 56 |
+
}
|
| 57 |
+
# Remove potentially large 'tensor' data from direct display
|
| 58 |
+
row.pop('tensor', None)
|
| 59 |
+
# Avoid duplicate metadata keys if also present at top level
|
| 60 |
+
row.pop('shape', None)
|
| 61 |
+
row.pop('dtype', None)
|
| 62 |
+
row.pop('record_id', None)
|
| 63 |
+
display_data.append(row)
|
| 64 |
+
|
| 65 |
+
if display_data:
|
| 66 |
+
results_df = pd.DataFrame(display_data)
|
| 67 |
+
|
| 68 |
+
# Augment message if results found
|
| 69 |
+
count = nql_response.get("count")
|
| 70 |
+
if count is not None:
|
| 71 |
+
response_content += f" Found {count} record(s)."
|
| 72 |
+
|
| 73 |
+
else:
|
| 74 |
+
# NQL agent indicated failure (parsing or execution)
|
| 75 |
+
error_msg = response_content # Use the message as the error
|
| 76 |
+
|
| 77 |
+
else:
|
| 78 |
+
# API call itself failed (connection error, etc.)
|
| 79 |
+
response_content = "Failed to get response from the NQL agent."
|
| 80 |
+
error_msg = response_content
|
| 81 |
+
|
| 82 |
+
# Display assistant response in chat message container
|
| 83 |
+
message_data = {"role": "assistant", "content": response_content}
|
| 84 |
+
with st.chat_message("assistant"):
|
| 85 |
+
st.markdown(response_content)
|
| 86 |
+
if results_df is not None and not results_df.empty: # Check if DataFrame is not None and not empty
|
| 87 |
+
st.dataframe(results_df, use_container_width=True)
|
| 88 |
+
message_data["results"] = results_df # Store for history display if needed (might be large)
|
| 89 |
+
elif error_msg:
|
| 90 |
+
st.error(error_msg)
|
| 91 |
+
message_data["error"] = error_msg
|
| 92 |
+
|
| 93 |
+
# Add assistant response to chat history
|
| 94 |
+
st.session_state.messages.append(message_data)
|
pages/4_Data_Explorer.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pages/4_Data_Explorer.py
#
# Streamlit page for browsing, filtering, and visualizing Tensorus datasets.
# Records are fetched page-by-page from the backend via ui_utils helpers;
# metadata is filterable in the sidebar and individual tensors can be previewed.

import streamlit as st
import pandas as pd
import plotly.express as px
from .ui_utils import list_datasets, fetch_dataset_data  # MODIFIED
import torch  # Needed if we want to recreate tensors for inspection/plotting


def _rerun() -> None:
    """Trigger a script rerun on both old and new Streamlit versions.

    ``st.experimental_rerun`` was removed in Streamlit >= 1.30 in favour of
    ``st.rerun``; calling whichever attribute exists keeps the pagination
    buttons working on either version instead of raising AttributeError.
    """
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


st.set_page_config(page_title="Data Explorer", layout="wide")

st.title("🔍 Interactive Data Explorer")
st.caption("Browse, filter, and visualize Tensorus datasets.")

# --- Dataset Selection ---
datasets = list_datasets()
if not datasets:
    st.warning("No datasets found or API connection failed. Cannot explore data.")
    st.stop()  # Stop execution if no datasets

selected_dataset = st.selectbox("Select Dataset:", datasets)

# --- Data Fetching & Filtering ---
if selected_dataset:
    st.subheader(f"Exploring: {selected_dataset}")

    # Pagination state is kept per-dataset so switching datasets resets nothing.
    PAGE_SIZE = 20
    page_key = f"page_{selected_dataset}"
    if page_key not in st.session_state:
        st.session_state[page_key] = 0

    page = st.session_state[page_key]
    offset = page * PAGE_SIZE
    records = fetch_dataset_data(selected_dataset, offset=offset, limit=PAGE_SIZE)

    # A short page (or a failed fetch) means there is no next page.
    prev_disabled = page == 0
    next_disabled = records is None or len(records) < PAGE_SIZE

    col_prev, col_next = st.columns(2)
    with col_prev:
        if st.button("Previous", disabled=prev_disabled, key="prev_btn"):
            st.session_state[page_key] = max(0, page - 1)
            _rerun()
    with col_next:
        if st.button("Next", disabled=next_disabled, key="next_btn"):
            st.session_state[page_key] = page + 1
            _rerun()

    if records is None:
        st.error("Failed to fetch data for the selected dataset.")
        st.stop()
    elif not records:
        st.info("Selected dataset is empty.")
        st.stop()

    start_idx = offset + 1
    end_idx = offset + len(records)
    st.info(f"Displaying records {start_idx} - {end_idx} (page {page + 1})")

    # Create DataFrame from metadata for filtering/display.
    # NOTE: filters apply only to the currently fetched page, not the whole dataset.
    metadata_list = [r['metadata'] for r in records]
    df_meta = pd.DataFrame(metadata_list)

    # --- Metadata Filtering UI ---
    st.sidebar.header("Filter by Metadata")
    filter_cols = st.sidebar.multiselect("Select metadata columns to filter:", options=df_meta.columns.tolist())

    filtered_df = df_meta.copy()
    for col in filter_cols:
        unique_values = filtered_df[col].dropna().unique().tolist()
        if pd.api.types.is_numeric_dtype(filtered_df[col]):
            # Numeric filter (slider); degenerate single-value columns get no slider.
            min_val, max_val = float(filtered_df[col].min()), float(filtered_df[col].max())
            if min_val < max_val:
                selected_range = st.sidebar.slider(f"Filter {col}:", min_val, max_val, (min_val, max_val))
                filtered_df = filtered_df[filtered_df[col].between(selected_range[0], selected_range[1])]
            else:
                st.sidebar.caption(f"{col}: Single numeric value ({min_val}), no range filter.")

        elif len(unique_values) > 0 and len(unique_values) <= 20:  # Limit dropdown options
            # Categorical filter (multiselect); deselecting everything hides all rows.
            selected_values = st.sidebar.multiselect(f"Filter {col}:", options=unique_values, default=unique_values)
            if selected_values:  # Only filter if some values are selected
                filtered_df = filtered_df[filtered_df[col].isin(selected_values)]
            else:  # If user deselects everything, show nothing
                filtered_df = filtered_df[filtered_df[col].isnull()]  # Hack to get empty DF matching columns

        else:
            # Free-text substring filter for high-cardinality / non-numeric columns.
            st.sidebar.text_input(f"Filter {col} (Text contains):", key=f"text_{col}")
            search_term = st.session_state.get(f"text_{col}", "").lower()
            if search_term:
                # Ensure column is string type before using .str.contains
                filtered_df = filtered_df[filtered_df[col].astype(str).str.lower().str.contains(search_term, na=False)]

    st.divider()
    st.subheader("Filtered Data View")
    st.write(f"{len(filtered_df)} records matching filters.")
    st.dataframe(filtered_df, use_container_width=True)

    # --- Tensor Preview & Visualization ---
    st.divider()
    st.subheader("Tensor Preview")

    if not filtered_df.empty:
        # Allow selecting a record ID from the filtered results.
        # Guard against metadata without a 'record_id' field (previously a KeyError crash).
        if 'record_id' in filtered_df.columns:
            record_ids = filtered_df['record_id'].dropna().tolist()
        else:
            record_ids = []
            st.warning("Metadata has no 'record_id' field; tensor preview unavailable.")
        selected_record_id = st.selectbox("Select Record ID to Preview Tensor:", record_ids)

        if selected_record_id:
            # Find the full record data corresponding to the selected ID
            selected_record = next((r for r in records if r['metadata'].get('record_id') == selected_record_id), None)

            if selected_record:
                st.write("Metadata:")
                st.json(selected_record['metadata'])

                shape = selected_record.get("shape")
                dtype = selected_record.get("dtype")
                data_list = selected_record.get("data")

                st.write(f"Tensor Info: Shape={shape}, Dtype={dtype}")

                try:
                    # Recreate tensor for potential plotting/display.
                    # Be careful with large tensors in Streamlit UI!
                    # We might only want to show info or small slices.
                    if shape and dtype and data_list is not None:
                        # assumes dtype is a bare torch name like "float32" — TODO confirm API format
                        tensor = torch.tensor(data_list, dtype=getattr(torch, dtype, torch.float32))
                        st.write("Tensor Data (first few elements):")
                        st.code(f"{tensor.flatten()[:10].numpy()}...")  # Show flattened start

                        # --- Simple Visualizations ---
                        if tensor.ndim == 1 and tensor.numel() > 1:
                            st.line_chart(tensor.numpy())
                        elif tensor.ndim == 2 and tensor.shape[0] > 1 and tensor.shape[1] > 1:
                            # Simple heatmap using plotly (requires plotly)
                            try:
                                fig = px.imshow(tensor.numpy(), title="Tensor Heatmap", aspect="auto")
                                st.plotly_chart(fig, use_container_width=True)
                            except Exception as plot_err:
                                st.warning(f"Could not generate heatmap: {plot_err}")
                        elif tensor.ndim == 3 and tensor.shape[0] in [1, 3]:  # Basic image check (C, H, W) or (1, H, W)
                            try:
                                # Permute if needed (e.g., C, H, W -> H, W, C for display)
                                if tensor.shape[0] in [1, 3]:
                                    display_tensor = tensor.permute(1, 2, 0).squeeze()  # H, W, C or H, W
                                    # Clamp/normalize data to display range [0, 1] - basic attempt
                                    display_tensor = (display_tensor - display_tensor.min()) / (display_tensor.max() - display_tensor.min() + 1e-6)
                                    # NOTE(review): use_column_width is deprecated in newer Streamlit
                                    # (use_container_width replaces it) — kept for compatibility.
                                    st.image(display_tensor.numpy(), caption="Tensor as Image (Attempted)", use_column_width=True)
                            except ImportError:
                                st.warning("Pillow needed for image display (`pip install Pillow`)")
                            except Exception as img_err:
                                st.warning(f"Could not display tensor as image: {img_err}")
                        else:
                            st.info("No specific visualization available for this tensor shape/dimension.")

                    else:
                        st.warning("Tensor data, shape, or dtype missing in the record.")

                except Exception as tensor_err:
                    st.error(f"Error processing tensor data for preview: {tensor_err}")
            else:
                st.warning("Selected record details not found (this shouldn't happen).")
        else:
            st.info("Select a record ID above to preview its tensor.")
    else:
        st.info("No records match the current filters.")
|
pages/5_API_Playground.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pages/5_API_Playground.py
#
# Embeds the Tensorus API's interactive OpenAPI documentation
# (Swagger UI and ReDoc) inside the Streamlit app via iframes.

import streamlit as st
import streamlit.components.v1 as components
from .ui_utils import TENSORUS_API_URL, get_api_status  # MODIFIED

st.set_page_config(page_title="API Playground", layout="wide")

st.title("🚀 API Playground & Documentation Hub")
st.caption("Explore and interact with the Tensorus REST API.")

# The playground is useless without a live backend, so probe it first.
api_running = get_api_status()

if api_running:
    st.success(f"Connected to API backend at {TENSORUS_API_URL}")
else:
    st.error(
        f"The Tensorus API backend does not seem to be running at {TENSORUS_API_URL}. "
        "Please start the backend (`uvicorn api:app --reload`) to use the API Playground."
    )
    st.stop()  # Nothing below is usable without the API

st.markdown(
    f"""
    This section provides live, interactive documentation for the Tensorus API,
    powered by FastAPI's OpenAPI integration. You can explore endpoints,
    view schemas, and even try out API calls directly in your browser.

    * **Swagger UI:** A graphical interface for exploring and testing API endpoints.
    * **ReDoc:** Alternative documentation format, often preferred for reading.

    Select a view below:
    """
)

# One (label, docs URL, blurb) entry per documentation flavour.
doc_views = [
    ("Swagger UI", f"{TENSORUS_API_URL}/docs", "Explore the API interactively."),
    ("ReDoc", f"{TENSORUS_API_URL}/redoc", "View the API documentation."),
]

# Render each documentation view in its own tab as an embedded iframe.
for tab, (label, doc_url, blurb) in zip(st.tabs([v[0] for v in doc_views]), doc_views):
    with tab:
        st.subheader(label)
        st.markdown(f"{blurb} [Open in new tab]({doc_url})")
        components.iframe(doc_url, height=800, scrolling=True)

st.divider()
st.caption("Note: Ensure the Tensorus API backend is running to interact with the playground.")
|
pages/6_Financial_Forecast_Demo.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import plotly.graph_objects as go # Using graph_objects for more control
|
| 5 |
+
import time
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
# Assuming tensorus is in PYTHONPATH or installed
|
| 9 |
+
from tensorus.tensor_storage import TensorStorage
|
| 10 |
+
from tensorus.financial_data_generator import generate_financial_data
|
| 11 |
+
from tensorus.time_series_predictor import load_time_series_from_tensorus, train_arima_and_predict, store_predictions_to_tensorus
|
| 12 |
+
|
| 13 |
+
# Attempt to load shared CSS; fall back to a no-op stub so the page still
# renders when the shared utilities module is absent.
try:
    from pages.pages_shared_utils import load_css
    LOAD_CSS_AVAILABLE = True
except ImportError:
    LOAD_CSS_AVAILABLE = False
    def load_css():  # Dummy function used when shared utils are unavailable
        st.markdown("<style>/* No shared CSS found */</style>", unsafe_allow_html=True)
    # st.warning("Could not load shared CSS from pages.pages_shared_utils.") # Keep UI cleaner

# Page Configuration
st.set_page_config(page_title="Financial Forecast Demo", layout="wide")
if LOAD_CSS_AVAILABLE:
    load_css()

# --- Constants ---
RAW_DATASET_NAME = "financial_raw_data"            # Tensorus dataset for generated raw series
PREDICTION_DATASET_NAME = "financial_predictions"  # Tensorus dataset for ARIMA outputs
TIME_SERIES_NAME = "synthetic_stock_close"         # logical name stored in record metadata
# Use a demo-specific storage path to avoid conflicts with main app
TENSOR_STORAGE_PATH = "tensor_data_financial_demo"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_tensor_storage_instance():
    """Create a TensorStorage rooted at the demo-specific storage directory."""
    storage = TensorStorage(storage_path=TENSOR_STORAGE_PATH)
    return storage
|
| 39 |
+
|
| 40 |
+
def show_financial_demo_page():
    """Render the financial forecasting demo page.

    Workflow (each section is a header on the page):
      1. Generate synthetic price data and ingest it into Tensorus.
      2. Load the stored series back and plot it.
      3. Fit an ARIMA(p, d, q) model and forecast future values,
         storing the predictions back into Tensorus.
      4. Plot historical vs. predicted values and tabulate the forecast.

    All cross-rerun state (raw data, loaded series, predictions, ARIMA
    parameters) lives in ``st.session_state``.
    """
    st.title("📈 Financial Time Series Forecasting Demo")
    st.markdown("""
    This demo showcases time series forecasting using an ARIMA model on synthetically generated
    financial data. Data is generated, stored in Tensorus, loaded for model training,
    and predictions are then stored back and visualized.
    """)

    # Initialize session state variables (widget keys below write into these directly)
    if 'raw_financial_df' not in st.session_state:
        st.session_state.raw_financial_df = None
    if 'loaded_historical_series' not in st.session_state:
        st.session_state.loaded_historical_series = None
    if 'predictions_series' not in st.session_state:
        st.session_state.predictions_series = None
    if 'arima_p' not in st.session_state:  # For number_input persistence
        st.session_state.arima_p = 5
    if 'arima_d' not in st.session_state:
        st.session_state.arima_d = 1
    if 'arima_q' not in st.session_state:
        st.session_state.arima_q = 0
    if 'n_predictions' not in st.session_state:
        st.session_state.n_predictions = 30

    # --- Section 1: Data Generation & Ingestion ---
    st.header("1. Data Generation & Ingestion")
    if st.button("Generate & Ingest Sample Financial Data", key="generate_data_button"):
        with st.spinner("Generating and ingesting data..."):
            try:
                storage = get_tensor_storage_instance()

                # Generate sample financial data (one year of daily prices)
                df = generate_financial_data(days=365, initial_price=150, trend_slope=0.1, seasonality_amplitude=20)
                st.session_state.raw_financial_df = df  # Store for immediate plotting

                # Ensure dataset exists
                if RAW_DATASET_NAME not in storage.list_datasets():
                    storage.create_dataset(RAW_DATASET_NAME)
                    st.info(f"Dataset '{RAW_DATASET_NAME}' created.")
                else:
                    st.info(f"Dataset '{RAW_DATASET_NAME}' already exists. Using existing.")

                # Prepare data for TensorStorage: values become the tensor,
                # dates are carried alongside in the record metadata.
                series_to_store = torch.tensor(df['Close'].values, dtype=torch.float32)
                dates_for_metadata = df['Date'].dt.strftime('%Y-%m-%d').tolist()
                metadata = {
                    "name": TIME_SERIES_NAME,
                    "dates": dates_for_metadata,
                    "source": "financial_demo_generator",
                    "description": f"Daily closing prices for {TIME_SERIES_NAME}"
                }

                record_id = storage.insert(RAW_DATASET_NAME, series_to_store, metadata)
                if record_id:
                    st.success(f"Data for '{TIME_SERIES_NAME}' ingested into '{RAW_DATASET_NAME}' with Record ID: {record_id}")
                    # Also load it into loaded_historical_series for consistency in the workflow
                    st.session_state.loaded_historical_series = pd.Series(df['Close'].values, index=pd.to_datetime(df['Date']), name=TIME_SERIES_NAME)
                    st.session_state.predictions_series = None  # Clear previous predictions
                else:
                    st.error("Failed to ingest data into Tensorus.")
            except Exception as e:
                st.error(f"Error during data generation/ingestion: {e}")

    # --- Section 2: Historical Data Visualization ---
    st.header("2. Historical Data Visualization")
    if st.button("Load and View Historical Data from Tensorus", key="load_historical_button"):
        with st.spinner(f"Loading '{TIME_SERIES_NAME}' from Tensorus..."):
            try:
                storage = get_tensor_storage_instance()
                # Look up the record whose metadata "name" matches TIME_SERIES_NAME
                # and rebuild a date-indexed pandas Series from its "dates" metadata.
                loaded_series = load_time_series_from_tensorus(
                    storage,
                    RAW_DATASET_NAME,
                    series_metadata_field="name",
                    series_name=TIME_SERIES_NAME,
                    date_metadata_field="dates"
                )
                if loaded_series is not None:
                    st.session_state.loaded_historical_series = loaded_series
                    st.session_state.predictions_series = None  # Clear previous predictions
                    st.success(f"Successfully loaded '{TIME_SERIES_NAME}' from Tensorus.")
                else:
                    st.warning(f"Could not find or load '{TIME_SERIES_NAME}' from Tensorus dataset '{RAW_DATASET_NAME}'. Generate data first if it's not there.")
            except Exception as e:
                st.error(f"Error loading historical data: {e}")

    if st.session_state.loaded_historical_series is not None:
        st.subheader(f"Historical Data: {st.session_state.loaded_historical_series.name}")
        fig_hist = go.Figure()
        fig_hist.add_trace(go.Scatter(
            x=st.session_state.loaded_historical_series.index,
            y=st.session_state.loaded_historical_series.values,
            mode='lines',
            name='Historical Close'
        ))
        fig_hist.update_layout(title="Historical Stock Prices", xaxis_title="Date", yaxis_title="Price")
        st.plotly_chart(fig_hist, use_container_width=True)
    else:
        st.info("Generate or load data to view historical prices.")

    # --- Section 3: ARIMA Model Prediction ---
    st.header("3. ARIMA Model Prediction")

    # The keys bind these widgets to the session-state values initialized above.
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.number_input("ARIMA Order (p)", min_value=0, max_value=10, key="arima_p", help="Autoregressive order.")
    with col2:
        st.number_input("ARIMA Order (d)", min_value=0, max_value=3, key="arima_d", help="Differencing order.")
    with col3:
        st.number_input("ARIMA Order (q)", min_value=0, max_value=10, key="arima_q", help="Moving average order.")
    with col4:
        st.number_input("Number of Future Predictions", min_value=1, max_value=180, key="n_predictions", help="Number of future days to predict.")

    if st.button("Run ARIMA Prediction", key="run_arima_button"):
        if st.session_state.loaded_historical_series is None:
            st.error("No historical data loaded. Please generate or load data first (Sections 1 or 2).")
        else:
            with st.spinner("Running ARIMA prediction... This might take a moment."):
                try:
                    p, d, q = st.session_state.arima_p, st.session_state.arima_d, st.session_state.arima_q
                    n_preds = st.session_state.n_predictions

                    predictions = train_arima_and_predict(
                        series=st.session_state.loaded_historical_series,
                        arima_order=(p, d, q),
                        n_predictions=n_preds
                    )

                    if predictions is not None:
                        st.session_state.predictions_series = predictions
                        st.success(f"ARIMA prediction complete for {n_preds} steps.")

                        # Store predictions back into Tensorus for later retrieval
                        storage = get_tensor_storage_instance()
                        model_details_dict = {"type": "ARIMA", "order": (p,d,q), "parameters_estimated": True}  # Example details

                        pred_record_id = store_predictions_to_tensorus(
                            storage,
                            predictions,
                            PREDICTION_DATASET_NAME,
                            original_series_name=st.session_state.loaded_historical_series.name,
                            model_details=model_details_dict
                        )
                        if pred_record_id:
                            st.info(f"Predictions stored in Tensorus dataset '{PREDICTION_DATASET_NAME}' with Record ID: {pred_record_id}")
                        else:
                            st.warning("Failed to store predictions in Tensorus.")
                    else:
                        st.error("Failed to generate ARIMA predictions. Check model parameters, data, or logs.")
                except Exception as e:
                    st.error(f"Error during ARIMA prediction: {e}")

    # --- Section 4: Prediction Results & Visualization ---
    st.header("4. Prediction Results")
    if st.session_state.predictions_series is not None and st.session_state.loaded_historical_series is not None:
        st.subheader("Historical Data and Predictions")

        fig_combined = go.Figure()
        fig_combined.add_trace(go.Scatter(
            x=st.session_state.loaded_historical_series.index,
            y=st.session_state.loaded_historical_series.values,
            mode='lines',
            name='Historical Close',
            line=dict(color='blue')
        ))
        fig_combined.add_trace(go.Scatter(
            x=st.session_state.predictions_series.index,
            y=st.session_state.predictions_series.values,
            mode='lines',
            name='Predicted Close',
            line=dict(color='red', dash='dash')
        ))
        fig_combined.update_layout(
            title=f"{st.session_state.loaded_historical_series.name}: Historical vs. Predicted",
            xaxis_title="Date",
            yaxis_title="Price"
        )
        st.plotly_chart(fig_combined, use_container_width=True)

        st.subheader("Predicted Values (Next {} Days)".format(len(st.session_state.predictions_series)))
        # Display predictions as a DataFrame, ensure index is named 'Date' if it's a DatetimeIndex
        pred_df_display = st.session_state.predictions_series.to_frame()
        if isinstance(pred_df_display.index, pd.DatetimeIndex):
            pred_df_display.index.name = "Date"
        st.dataframe(pred_df_display)

    elif st.session_state.loaded_historical_series is not None:
        st.info("Run ARIMA Prediction in Section 3 to see forecasted results.")
    else:
        st.info("Generate or load data and run predictions to see results here.")
|
| 234 |
+
|
| 235 |
+
# Entry point when the module is executed as a script (Streamlit runs pages this way).
if __name__ == "__main__":
    show_financial_demo_page()
|
pages/api_playground_v2.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pages/api_playground_v2.py
#
# V2 API Playground: embeds the Tensorus API's Swagger UI and ReDoc
# documentation in styled tabs, using the shared page utilities for CSS
# loading and API status checks.

import streamlit as st
import streamlit.components.v1 as components

# Import from the shared utils for pages; if they are missing, register
# no-op fallbacks so the error message renders, then stop the page.
try:
    from pages.pages_shared_utils import (
        load_css as load_shared_css,
        API_BASE_URL,  # Use this instead of TENSORUS_API_URL
        get_api_status
    )
except ImportError:
    st.error("Critical Error: Could not import `pages_shared_utils`. Page cannot function.")
    def load_shared_css(): pass
    API_BASE_URL = "http://127.0.0.1:7860"  # Fallback
    def get_api_status():
        st.error("`get_api_status` unavailable.")
        return False, {"error": "Setup issue"}  # Simulate API down
    st.stop()

st.set_page_config(page_title="API Playground (V2)", layout="wide")
load_shared_css()  # Load common CSS from shared utilities

# Custom CSS for API Playground page
# These styles are specific to this page and enhance the shared theme.
st.markdown("""
<style>
    /* API Playground specific styles */
    /* Using classes for titles and captions allows for more specific styling if needed,
       while still inheriting base styles from shared CSS (h1, .stCaption etc.) */
    .main-title {
        color: #e0e0ff !important; /* Ensure it uses the desired dashboard title color */
        text-align: center;
        margin-bottom: 0.5rem;
        font-weight: bold;
    }
    .main-caption {
        color: #c0c0ff !important; /* Consistent caption color */
        text-align: center;
        margin-bottom: 2rem;
    }

    /* Status message styling - these could be potentially moved to shared_utils if used elsewhere */
    /* For now, keeping them page-specific as they are styled slightly differently from .status-indicator */
    .status-message-shared { /* Base for status messages on this page */
        padding: 0.75rem;
        border-radius: 5px;
        margin-bottom: 1rem;
        text-align: center; /* Center align text in status messages */
    }
    .status-success { /* Success status message, inherits from .status-message-shared */
        color: #4CAF50; /* Green text */
        background-color: rgba(76, 175, 80, 0.1); /* Light green background */
        border: 1px solid rgba(76, 175, 80, 0.3); /* Light green border */
    }
    .status-error { /* Error status message, inherits from .status-message-shared */
        color: #F44336; /* Red text */
        background-color: rgba(244, 67, 54, 0.1); /* Light red background */
        border: 1px solid rgba(244, 67, 54, 0.3); /* Light red border */
    }
    .status-error code, .status-success code { /* Style <code> tags within status messages */
        background-color: rgba(0,0,0,0.1);
        padding: 0.1em 0.3em;
        border-radius: 3px;
    }

    /* Styling for the Streamlit Tabs component */
    .tab-container .stTabs [data-baseweb="tab-list"] {
        background-color: transparent; /* Remove default Streamlit tab bar background if any */
        gap: 5px; /* Increased gap between tab headers */
        border-bottom: 1px solid #2a3f5c; /* Match other dividers */
    }
    .tab-container .stTabs [data-baseweb="tab"] {
        background-color: #18223f; /* Match common-card background for inactive tabs */
        color: #c0c0ff; /* Standard text color for inactive tabs */
        border-radius: 5px 5px 0 0; /* Rounded top corners */
        padding: 0.75rem 1.5rem;
        border: 1px solid #2a3f5c; /* Border for inactive tabs */
        border-bottom: none; /* Remove bottom border as it's handled by tab-list */
        margin-bottom: -1px; /* Overlap with tab-list border */
    }
    .tab-container .stTabs [data-baseweb="tab"][aria-selected="true"] {
        background-color: #2a2f4c; /* Slightly lighter for active tab, similar to nav hover */
        color: #ffffff; /* White text for active tab */
        font-weight: bold;
        border-color: #3a6fbf; /* Accent color for active tab border */
    }

    /* Container for the iframes displaying Swagger/ReDoc */
    .iframe-container {
        border: 1px solid #2a3f5c; /* Consistent border */
        border-radius: 0 5px 5px 5px; /* Rounded corners, except top-left to align with tabs */
        overflow: hidden; /* Ensures iframe respects border radius */
        margin-top: -1px; /* Align with tab-list border */
        height: 800px; /* Default height */
    }
</style>
""", unsafe_allow_html=True)

# Page Title and Caption, using custom classes for styling
st.markdown('<h1 class="main-title">🚀 API Playground & Documentation Hub</h1>', unsafe_allow_html=True)
st.markdown('<p class="main-caption">Explore and interact with the Tensorus REST API directly.</p>', unsafe_allow_html=True)


# Check if API is running using the shared utility function.
# get_api_status returns (bool reachable, dict info).
api_ok, api_info = get_api_status()

# Display API status message.
if not api_ok:
    # If API is not reachable, display an error message and stop page execution.
    st.markdown(
        f"""
        <div class="status-message-shared status-error">
            <strong>API Connection Error:</strong> The Tensorus API backend does not seem to be running or reachable at <code>{API_BASE_URL}</code>.
            <br>Please ensure the backend (<code>uvicorn api:app --reload</code>) is active to use the API Playground.
        </div>
        """, unsafe_allow_html=True
    )
    st.stop()  # Halt further rendering of the page.
else:
    # If API is reachable, display a success message with API version.
    api_version = api_info.get("version", "N/A")  # Get API version from status info.
    st.markdown(
        f"""
        <div class="status-message-shared status-success">
            Successfully connected to Tensorus API v{api_version} at <code>{API_BASE_URL}</code>.
        </div>
        """, unsafe_allow_html=True
    )

# Introductory text for the API Playground.
st.markdown(
    """
    This section provides live, interactive documentation for the Tensorus API,
    powered by FastAPI's OpenAPI integration. You can explore endpoints,
    view schemas, and even try out API calls directly in your browser.
    """
)

# Use tabs to embed Swagger UI and ReDoc for API documentation.
# Wrapping tabs in a div to apply specific tab styling if needed.
# NOTE(review): st.tabs() output is not actually nested inside this raw-HTML
# div (Streamlit renders it as a separate element), so the ".tab-container"
# CSS selectors above may never match — verify the tab styling applies.
st.markdown('<div class="tab-container">', unsafe_allow_html=True)
tab1, tab2 = st.tabs(["Swagger UI", "ReDoc"])
st.markdown('</div>', unsafe_allow_html=True)


# Construct the URLs for Swagger and ReDoc based on the API_BASE_URL from shared utils.
swagger_url = f"{API_BASE_URL}/docs"
redoc_url = f"{API_BASE_URL}/redoc"

# Swagger UI Tab
with tab1:
    st.subheader("Swagger UI")  # Styled by shared CSS
    st.markdown(f"Explore the API interactively. [Open Swagger UI in new tab]({swagger_url})")
    # Embed Swagger UI using an iframe within a styled container.
    st.markdown('<div class="iframe-container">', unsafe_allow_html=True)
    components.iframe(swagger_url, height=800, scrolling=True)
    st.markdown('</div>', unsafe_allow_html=True)

# ReDoc Tab
with tab2:
    st.subheader("ReDoc")  # Styled by shared CSS
    st.markdown(f"View the API documentation. [Open ReDoc in new tab]({redoc_url})")
    # Embed ReDoc using an iframe within a styled container.
    st.markdown('<div class="iframe-container">', unsafe_allow_html=True)
    components.iframe(redoc_url, height=800, scrolling=True)
    st.markdown('</div>', unsafe_allow_html=True)

st.divider()  # Visual separator.
st.caption("Note: The API backend must be running and accessible to fully utilize the interactive documentation.")
|
pages/control_panel_v2.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pages/control_panel_v2.py
#
# Streamlit page: Agent Control Tower (V2).
# Shows a live roster of known Tensorus agents (status, metrics, start/stop
# actions) and a conceptual Graphviz diagram of agent/data-store interactions.

import streamlit as st
import time

# Import from the newly created shared utils for pages
try:
    from pages.pages_shared_utils import (
        get_agent_status,
        start_agent,
        stop_agent,
        load_css as load_shared_css
    )
except ImportError:
    # Define no-op stand-ins so the names exist, then halt the page.
    st.error("Critical Error: Could not import `pages_shared_utils`. Page cannot function.")
    def get_agent_status(): st.error("`get_agent_status` unavailable."); return None
    def start_agent(agent_id): st.error(f"Start action for {agent_id} unavailable."); return False
    def stop_agent(agent_id): st.error(f"Stop action for {agent_id} unavailable."); return False
    def load_shared_css(): pass
    st.stop()


st.set_page_config(page_title="Agent Control Tower (V2)", layout="wide")
load_shared_css()  # Load shared styles first. This is crucial for base styling.

# Custom CSS for Agent Cards on this page.
# These styles complement or override the .common-card style from shared_utils.
# Note: .common-card styles (background, border, padding, shadow) are inherited.
st.markdown("""
<style>
    /* Agent Card specific styling */
    .agent-card { /* This class is applied along with .common-card */
        /* Example: if agent cards needed a slightly different background or padding than common-card */
        /* background-color: #1c2849; */ /* Slightly different background if needed */
    }
    .agent-card-header { /* Header section within an agent card */
        display: flex;
        align-items: center;
        margin-bottom: 1rem;
        padding-bottom: 0.75rem;
        border-bottom: 1px solid #2a3f5c; /* Consistent border color */
    }
    .agent-card-header .icon { /* Icon in the header */
        font-size: 2em;
        margin-right: 0.75rem;
        color: #9090ff; /* Specific icon color for agent cards */
    }
    .agent-card-header h3 { /* Agent name in the header */
        /* Inherits .common-card h3 styles primarily, ensuring consistency */
        color: #d0d0ff !important; /* Override if a brighter color is needed than common-card h3 */
        font-size: 1.4em !important;
        margin: 0 !important; /* Remove any default margins */
    }
    .agent-card .status-text {
        font-weight: bold !important; /* Ensure status text is bold */
        display: inline-block; /* Allows padding and consistent look */
        padding: 0.2em 0.5em;
        border-radius: 4px;
        font-size: 0.9em;
    }
    /* Specific status colors for agent cards, using shared .status-indicator naming convention for consistency */
    .agent-card .status-running { color: #ffffff; background-color: #4CAF50;}
    .agent-card .status-stopped { color: #e0e0e0; background-color: #525252;}
    .agent-card .status-error { color: #ffffff; background-color: #F44336;}
    .agent-card .status-unknown { color: #333333; background-color: #BDBDBD;}

    .agent-card p.description { /* Agent description text */
        font-size: 0.95em;
        color: #c0c0e0;
        margin-bottom: 0.75rem;
    }
    .agent-card .metrics { /* Styling for the metrics block within an agent card */
        font-size: 0.9em;
        color: #b0b0d0;
        margin-top: 0.75rem;
        padding: 0.6rem; /* Increased padding */
        background-color: rgba(0,0,0,0.15); /* Slightly darker background for metrics section */
        border-radius: 4px;
        border-left: 3px solid #3a6fbf; /* Accent blue border */
    }
    .agent-card .metrics p { margin-bottom: 0.3rem; } /* Spacing for lines within metrics block */

    .agent-card .actions { /* Container for action buttons */
        margin-top: 1.25rem;
        display: flex;
        gap: 0.75rem; /* Space between buttons */
        justify-content: flex-start; /* Align buttons to the start */
    }
    /* Buttons within .actions will inherit general .stButton styling from shared_css */
</style>
""", unsafe_allow_html=True)

st.title("🕹️ Agent Control Tower")
st.caption("Manage and monitor your Tensorus intelligent agents.")

# AGENT_DETAILS_MAP provides static information (name, icon, description) for each agent type.
# This map is used to render the UI elements for agents known to the frontend.
# Metrics template helps in displaying placeholder metrics consistently.
AGENT_DETAILS_MAP = {
    "ingestion": {"name": "Ingestion Agent", "icon": "📥", "description": "Monitors file systems and ingests new data tensors into the platform.", "metrics_template": "Files Processed: {files_processed}"},
    "rl_trainer": {"name": "RL Training Agent", "icon": "🧠", "description": "Trains reinforcement learning models using available experiences and data.", "metrics_template": "Training Cycles: {episodes_trained}, Avg. Reward: {avg_reward}"},
    "automl_search": {"name": "AutoML Search Agent", "icon": "✨", "description": "Performs automated machine learning model searches and hyperparameter tuning.", "metrics_template": "Search Trials: {trials_completed}, Best Model Score: {best_score}"},
    "nql_query": {"name": "NQL Query Agent", "icon": "🗣️", "description": "Processes Natural Query Language (NQL) requests against tensor data.", "metrics_template": "Queries Handled: {queries_processed}"}
}

# Create two tabs: one for the roster of agents, another for visualizing interactions.
tab1, tab2 = st.tabs(["📊 Agent Roster", "🔗 Agent Interaction Visualizer"])

# --- Agent Roster Tab ---
with tab1:
    st.header("Live Agent Roster")  # Standard header, styled by shared CSS

    # Button to manually refresh agent statuses from the backend.
    if st.button("🔄 Refresh Agent Statuses"):
        # Update session state cache with fresh data from API.
        # Using a page-specific cache key to avoid conflicts.
        st.session_state.agent_statuses_cache_control_panel_v2 = get_agent_status()
        # Streamlit automatically reruns the script on button press.

    # Cache for agent statuses to avoid redundant API calls on every interaction.
    # Initializes if key is not present or if cache is None (e.g., after an error).
    cache_key = 'agent_statuses_cache_control_panel_v2'
    if cache_key not in st.session_state or st.session_state[cache_key] is None:
        st.session_state[cache_key] = get_agent_status()  # Initial fetch if not in cache

    agents_api_data = st.session_state[cache_key]  # Retrieve cached data.

    # Check if agent data was successfully fetched.
    if not agents_api_data:
        st.error("Could not fetch agent statuses. Ensure the backend API is running and `pages_shared_utils.py` is correctly configured.")
    else:
        num_columns = 2  # Define number of columns for agent card layout.
        cols = st.columns(num_columns)

        # Iterate through a predefined list of agent IDs expected to be available.
        # This ensures consistent ordering and display of known agents.
        agent_ids_to_display = list(AGENT_DETAILS_MAP.keys())

        for i, agent_id in enumerate(agent_ids_to_display):
            # Retrieve static details (icon, name, description) from the predefined map.
            # Provides a fallback if an agent_id from the API isn't in the map (e.g., new agent not yet in UI map).
            agent_static_details = AGENT_DETAILS_MAP.get(agent_id, {
                "name": agent_id.replace("_", " ").title() + " Agent (Unknown)",  # Default name formatting
                "icon": "❓",  # Default icon for unknown agent types
                "description": "Details not defined in AGENT_DETAILS_MAP.",
                "metrics_template": "Status: {status}"  # Basic fallback metric
            })
            # Retrieve live runtime information (status, specific metrics) from the API data.
            agent_runtime_info = agents_api_data.get(agent_id, {})  # Default to empty dict if agent not in API response

            with cols[i % num_columns]:  # Distribute agents into the defined columns.
                # Each agent is displayed in a styled card.
                # The card uses 'common-card' class from shared CSS and page-specific 'agent-card'.
                st.markdown(f'<div class="common-card agent-card" id="agent-card-{agent_id}">', unsafe_allow_html=True)

                # Agent card header: Icon and Name.
                st.markdown(f"""
                <div class="agent-card-header">
                    <span class="icon">{agent_static_details['icon']}</span>
                    <h3>{agent_static_details['name']}</h3>
                </div>
                """, unsafe_allow_html=True)

                # Agent status: Fetched live, styled based on status string.
                # Prioritize 'status' field, fallback to 'running' boolean if 'status' is missing.
                current_status = agent_runtime_info.get('status', agent_runtime_info.get('running', 'unknown'))
                if isinstance(current_status, bool):  # Convert boolean status to string representation
                    current_status = "running" if current_status else "stopped"
                status_class = f"status-{current_status.lower()}"  # CSS class for styling the status text
                st.markdown(f'<p>Status: <span class="status-text {status_class}">{current_status.upper()}</span></p>', unsafe_allow_html=True)

                # Agent description from the static map.
                st.markdown(f"<p class='description'>{agent_static_details['description']}</p>", unsafe_allow_html=True)

                # Placeholder for agent-specific metrics.
                # Actual metrics would be populated from `agent_runtime_info` if the backend provides them.
                # Fallback to "N/A" if a metric key is not present in runtime_info.
                metrics_values = {
                    "files_processed": agent_runtime_info.get("files_processed", "N/A"),
                    "episodes_trained": agent_runtime_info.get("episodes_trained", "N/A"),
                    "avg_reward": agent_runtime_info.get("avg_reward", "N/A"),
                    "trials_completed": agent_runtime_info.get("trials_completed", "N/A"),
                    "best_score": agent_runtime_info.get("best_score", "N/A"),
                    "queries_processed": agent_runtime_info.get("queries_processed", "N/A"),
                    "status": current_status  # Fallback for generic status display in metrics
                }
                metrics_str = agent_static_details['metrics_template'].format(**metrics_values)
                st.markdown(f'<div class="metrics"><p>{metrics_str}</p></div>', unsafe_allow_html=True)

                # Action buttons container.
                st.markdown('<div class="actions">', unsafe_allow_html=True)
                actions_cols = st.columns([1, 1, 1])  # Columns for button layout within the actions div.

                with actions_cols[0]:  # Details button (currently a placeholder).
                    st.button("Details", key=f"details_{agent_id}_v2", help="View detailed agent information (coming soon).", disabled=True, use_container_width=True)

                with actions_cols[1]:  # Start/Stop button, conditional on agent status.
                    if current_status == "running":
                        if st.button("Stop", key=f"stop_{agent_id}_v2", type="secondary", use_container_width=True):
                            result = stop_agent(agent_id)  # API call from shared_utils
                            # FIX: the helper may return None/False on failure (the
                            # ImportError fallback above returns False); calling .get
                            # on a non-dict would raise AttributeError.
                            message = result.get("message") if isinstance(result, dict) else None
                            st.toast(message or f"Stop request sent for {agent_id}.")
                            time.sleep(1.0)  # Brief pause to allow backend to process the request.
                            st.session_state[cache_key] = get_agent_status()  # Refresh status cache.
                            st.rerun()  # Rerun page to reflect updated status.
                    else:  # Agent is stopped, errored, or in an unknown state.
                        if st.button("Start", key=f"start_{agent_id}_v2", type="primary", use_container_width=True):
                            result = start_agent(agent_id)  # API call from shared_utils
                            # FIX: same non-dict guard as the Stop handler above.
                            message = result.get("message") if isinstance(result, dict) else None
                            st.toast(message or f"Start request sent for {agent_id}.")
                            time.sleep(1.0)
                            st.session_state[cache_key] = get_agent_status()  # Refresh.
                            st.rerun()

                with actions_cols[2]:  # Configure button (currently a placeholder).
                    st.button("Configure", key=f"config_{agent_id}_v2", help="Configure agent settings (coming soon).", disabled=True, use_container_width=True)

                st.markdown('</div>', unsafe_allow_html=True)  # Close actions div.
                st.markdown('</div>', unsafe_allow_html=True)  # Close agent-card div.

# --- Agent Interaction Visualizer Tab ---
with tab2:  # Agent Interaction Visualizer Tab
    st.header("Agent Interaction Visualizer")  # Standard header
    st.subheader("Conceptual Agent-Data Flow")

    # DOT language definition for the Graphviz chart.
    # This describes the nodes (agents, data stores, sources) and edges (data flow) of the system.
    graphviz_code = """
    digraph AgentDataFlow {
        // General graph attributes
        rankdir=LR; /* Layout direction: Left to Right */
        bgcolor="#0a0f2c"; /* Background color to match the app theme */
        node [shape=record, style="filled,rounded", fillcolor="#18223f", /* Default node style */
              fontname="Arial", fontsize=11, fontcolor="#e0e0e0", /* Default font attributes */
              color="#3a6fbf", penwidth=1.5]; /* Default border color and width */
        edge [fontname="Arial", fontsize=10, fontcolor="#c0c0ff", /* Default edge style */
              color="#7080ff", penwidth=1.2]; /* Default edge color and width */

        // Cluster for Data Sources
        subgraph cluster_data_sources {
            label="External Sources"; /* Cluster title */
            style="rounded";
            color="#4A5C85"; /* Cluster border color */
            bgcolor="#101828"; /* Cluster background (slightly different from main) */
            fontcolor="#D0D0FF"; /* Cluster title color */

            fs [label="FileSystem | (e.g., S3, Local Disk)", shape=folder, fillcolor="#3E5F8A"]; // File system node
            user_queries [label="User Queries | (via UI/API)", shape=ellipse, fillcolor="#5DADE2"]; // User queries node
        }

        // Cluster for Intelligent Agents
        subgraph cluster_agents_group {
            label="Intelligent Agents";
            style="rounded"; color="#4A5C85"; bgcolor="#101828"; fontcolor="#D0D0FF";

            ingestion_agent [label="{IngestionAgent | 📥 | Ingests raw data}", shape=Mrecord, fillcolor="#E74C3C"]; // Red
            nql_agent [label="{NQLAgent | 🗣️ | Processes natural language queries}", shape=Mrecord, fillcolor="#9B59B6"]; // Purple
            rl_agent [label="{RLAgent | 🧠 | Trains RL models, generates experiences}", shape=Mrecord, fillcolor="#2ECC71"]; // Green
            automl_agent [label="{AutoMLAgent | ✨ | Conducts AutoML searches}", shape=Mrecord, fillcolor="#F1C40F"]; // Yellow
        }

        // Cluster for Data Stores
        subgraph cluster_data_stores {
            label="Tensorus Data Stores";
            style="rounded"; color="#4A5C85"; bgcolor="#101828"; fontcolor="#D0D0FF";

            ingested_data_api [label="Ingested Data Store | (Primary Tensor Collection)", shape=cylinder, fillcolor="#7F8C8D", height=1.5]; // Grey
            rl_states [label="RL States | (Tensor Collection)", shape=cylinder, fillcolor="#95A5A6"]; // Lighter Grey
            rl_experiences [label="RL Experiences | (Metadata/Tensor Collection)", shape=cylinder, fillcolor="#95A5A6"];
            automl_results [label="AutoML Results | (Tensor/Metadata Collection)", shape=cylinder, fillcolor="#95A5A6"];
        }

        // Edges defining data flow between nodes
        fs -> ingestion_agent [label=" new files/streams"];
        ingestion_agent -> ingested_data_api [label=" stores tensors & metadata"];

        user_queries -> nql_agent [label=" NQL query"];
        nql_agent -> ingested_data_api [label=" reads data"];
        nql_agent -> rl_experiences [label=" (can also query specific stores like experiences)", style=dashed]; /* Dashed for optional/secondary path */

        ingested_data_api -> rl_agent [label=" reads training data/states"];
        rl_agent -> rl_states [label=" stores/updates policy states"];
        rl_agent -> rl_experiences [label=" stores new experiences"];
        rl_experiences -> rl_agent [label=" reads past experiences for training"]; // Loop back for learning

        ingested_data_api -> automl_agent [label=" reads datasets for AutoML"];
        automl_agent -> automl_results [label=" stores trial results & models"];
    }
    """
    try:
        # Render the Graphviz chart.
        st.graphviz_chart(graphviz_code)
        st.caption("This is a conceptual representation of data flows. Actual interactions can be more complex and configurable.")
    except Exception as e:
        # Handle potential errors if Graphviz is not installed or there's an issue with the DOT code.
        st.error(f"Could not render Graphviz chart: {e}")
        st.markdown("Please ensure Graphviz is installed and accessible in your environment's PATH (e.g., `sudo apt-get install graphviz` on Debian/Ubuntu).")

# Initialize session state key if not already present for this page.
# This helps prevent errors if the page is loaded before the cache is populated.
if 'agent_statuses_cache_control_panel_v2' not in st.session_state:
    st.session_state.agent_statuses_cache_control_panel_v2 = None
|
pages/data_explorer_v2.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pages/data_explorer_v2.py
#
# Streamlit page: Tensor Explorer (V2).
# Lets the user pick a dataset, filter its records by metadata (including the
# agent that created each tensor), and preview/visualize individual tensors.

import streamlit as st
import pandas as pd
import plotly.express as px
import torch

try:
    from pages.pages_shared_utils import (
        load_css as load_shared_css,
        get_datasets,
        get_dataset_preview
    )
except ImportError:
    # Define no-op stand-ins so the names exist, then halt the page.
    st.error("Critical Error: Could not import `pages_shared_utils`. Page cannot function.")
    def load_shared_css(): pass
    def get_datasets(): st.error("`get_datasets` unavailable."); return []
    def get_dataset_preview(dataset_name, limit=5): st.error("`get_dataset_preview` unavailable."); return None
    st.stop()

# Configure page settings
st.set_page_config(page_title="Tensor Explorer V2", layout="wide")
load_shared_css()  # Apply shared Nexus theme styles

st.title("🔍 Tensor Explorer (V2)")
st.caption("Browse, filter, and visualize Tensorus datasets with agent context.")

# --- Dataset Selection ---
# Fetch the list of available datasets from the backend using shared utility.
available_datasets = get_datasets()
if not available_datasets:
    st.warning("No datasets found. Ensure the backend API is running and accessible.")
    st.stop()  # Halt execution if no datasets are available for exploration.

# Dropdown for user to select a dataset.
selected_dataset_name = st.selectbox("Select Dataset:", available_datasets)

# --- Data Fetching & Initial Processing ---
# Proceed only if a dataset is selected.
if selected_dataset_name:
    st.subheader(f"Exploring: {selected_dataset_name}")  # Styled by shared CSS

    MAX_RECORDS_DISPLAY = 100  # Define max records to fetch for the initial preview.
    # Fetch dataset preview information (includes sample records) using shared utility.
    dataset_info = get_dataset_preview(selected_dataset_name, limit=MAX_RECORDS_DISPLAY)

    # Validate fetched data.
    if dataset_info is None or "preview" not in dataset_info:
        st.error(f"Failed to fetch data preview for '{selected_dataset_name}'. The backend might be down or the dataset is invalid.")
        st.stop()

    records_preview = dataset_info.get("preview", [])  # List of records (tensor data + metadata).
    total_records_in_dataset = dataset_info.get("record_count", len(records_preview))  # Total records in the dataset.

    if not records_preview:
        st.info(f"Dataset '{selected_dataset_name}' is empty or no preview data available.")
        st.stop()

    st.info(f"Displaying {len(records_preview)} of {total_records_in_dataset} records from '{selected_dataset_name}'.")

    # Prepare data for DataFrame display: extract and flatten metadata.
    # This makes metadata fields directly accessible as columns in the DataFrame for filtering and display.
    processed_records_for_df = []
    for record in records_preview:
        meta = record.get('metadata', {}).copy()  # Make a copy to avoid modifying original record.
        meta['tensor_id'] = record.get('id', 'N/A')  # Use 'id' from record as 'tensor_id'.
        meta['shape'] = str(record.get('shape', 'N/A'))  # Store shape as string for display.
        meta['dtype'] = record.get('dtype', 'N/A')
        # Ensure 'created_by' field exists, defaulting to 'Unknown' if not present. This is for filtering.
        meta['created_by'] = meta.get('created_by', 'Unknown')
        processed_records_for_df.append(meta)

    df_display_initial = pd.DataFrame(processed_records_for_df)  # DataFrame for filtering and display.

    # --- Sidebar Filtering UI ---
    st.sidebar.header("Filter Options")  # Styled by shared CSS

    # Filter by Source Agent ('created_by' field).
    st.sidebar.subheader("Filter by Source Agent")  # Styled by shared CSS
    # Predefined agent sources, augmented with any unique sources found in the data for comprehensive filtering.
    agent_sources_default = ["IngestionAgent", "RLAgent", "AutoMLAgent", "Unknown"]
    if 'created_by' in df_display_initial.columns:
        available_agents_in_data = df_display_initial['created_by'].unique().tolist()
        filter_options_agents = sorted(list(set(agent_sources_default + available_agents_in_data)))
    else:
        filter_options_agents = agent_sources_default  # Fallback if 'created_by' column is missing.

    selected_agent_filters = st.sidebar.multiselect(
        "Show data created by:",
        options=filter_options_agents,
        default=[]  # No agents selected by default; shows all data initially.
    )

    # Filter by Other Metadata Fields.
    st.sidebar.subheader("Filter by Other Metadata Fields")  # Styled by shared CSS
    # Allow filtering on any metadata column except those already specifically handled or less useful for direct filtering.
    potential_filter_cols = [col for col in df_display_initial.columns if col not in ['tensor_id', 'created_by', 'shape', 'dtype']]

    filter_cols_metadata = st.sidebar.multiselect(
        "Select metadata fields to filter:",
        options=potential_filter_cols
    )

    # Apply filters to the DataFrame. Start with a copy of the initial DataFrame.
    filtered_df_final = df_display_initial.copy()

    # Apply agent source filter if any agents are selected.
    if selected_agent_filters:
        if 'created_by' in filtered_df_final.columns:
            filtered_df_final = filtered_df_final[filtered_df_final['created_by'].isin(selected_agent_filters)]
        else:
            # This case should ideally not be reached if 'created_by' is always added or defaulted.
            st.sidebar.warning("'created_by' field not found for agent filtering.")
            filtered_df_final = filtered_df_final.iloc[0:0]  # Show no results if filter is active but field is missing.

    # Apply other metadata filters based on user selections.
    for col in filter_cols_metadata:
        unique_values = filtered_df_final[col].dropna().unique().tolist()
        # Numeric filter: use a range slider if there are multiple unique numeric values.
        if pd.api.types.is_numeric_dtype(filtered_df_final[col]) and filtered_df_final[col].nunique() > 1:
            min_val, max_val = float(filtered_df_final[col].min()), float(filtered_df_final[col].max())
            selected_range = st.sidebar.slider(f"Filter {col}:", min_val, max_val, (min_val, max_val), key=f"slider_{col}")
            filtered_df_final = filtered_df_final[filtered_df_final[col].between(selected_range[0], selected_range[1])]
        # Categorical filter (with limited unique values for dropdown): use multiselect.
        elif len(unique_values) > 0 and len(unique_values) <= 25:  # Threshold for using multiselect.
            default_selection = []  # Default to no specific selection (i.e., don't filter unless user picks).
            selected_values = st.sidebar.multiselect(f"Filter {col}:", options=unique_values, default=default_selection, key=f"multi_{col}")
            if selected_values:  # Only apply filter if user has made selections.
                filtered_df_final = filtered_df_final[filtered_df_final[col].isin(selected_values)]
        # Text search for other columns or columns with too many unique values.
        else:
            search_term = st.sidebar.text_input(f"Search in {col} (contains):", key=f"text_{col}").lower()
            if search_term:
                filtered_df_final = filtered_df_final[filtered_df_final[col].astype(str).str.lower().str.contains(search_term, na=False)]

    # --- Display Filtered Data ---
    st.divider()  # Visual separator.
    st.subheader("Filtered Data View")  # Styled by shared CSS
    st.write(f"{len(filtered_df_final)} records matching filters.")  # Display count of filtered records.

    # Define preferred column order for the displayed DataFrame.
    cols_to_display_order = ['tensor_id', 'created_by', 'shape', 'dtype']
    remaining_cols = [col for col in df_display_initial.columns if col not in cols_to_display_order]
    # Ensure only existing columns are included in the final display order.
    final_display_columns = [col for col in cols_to_display_order if col in filtered_df_final.columns] + \
                            [col for col in remaining_cols if col in filtered_df_final.columns]

    st.dataframe(filtered_df_final[final_display_columns], use_container_width=True, hide_index=True)

    # --- Tensor Preview & Visualization ---
    st.divider()
    st.subheader("Tensor Preview")  # Styled by shared CSS

    if not filtered_df_final.empty:
        available_tensor_ids = filtered_df_final['tensor_id'].tolist()
        # Dropdown to select a tensor ID from the filtered results for preview.
        selected_tensor_id = st.selectbox("Select Tensor ID to Preview:", available_tensor_ids, key="tensor_preview_select")

        if selected_tensor_id:
            # Retrieve the full record (including raw tensor data list) from the original preview list.
            selected_full_record = next((r for r in records_preview if r.get('id') == selected_tensor_id), None)

            if selected_full_record:
                st.write("**Full Metadata:**")
                st.json(selected_full_record.get('metadata', {}))  # Display all metadata for the selected tensor.

                shape = selected_full_record.get("shape")
                dtype_str = selected_full_record.get("dtype")
                data_list = selected_full_record.get("data")  # Raw list representation of tensor data.

                st.write(f"**Tensor Info:** Shape=`{shape}`, Dtype=`{dtype_str}`")
                source_agent = selected_full_record.get('metadata', {}).get('created_by', 'Unknown')
                st.write(f"**Source Agent:** `{source_agent}`")  # Display the 'created_by' agent.

                try:
                    if shape and dtype_str and data_list is not None:
                        # Reconstruct the tensor from its list representation and metadata.
                        torch_dtype = getattr(torch, dtype_str, None)  # Get torch.dtype from string.
                        if torch_dtype is None:
                            st.error(f"Unsupported dtype: {dtype_str}. Cannot reconstruct tensor.")
                        else:
                            tensor = torch.tensor(data_list, dtype=torch_dtype)
                            st.write("**Tensor Data (first 10 elements flattened):**")
                            st.code(f"{tensor.flatten()[:10].cpu().numpy()}...")  # Display a snippet of tensor data.

                            # --- Simple Visualizations based on tensor dimensions ---
                            if tensor.ndim == 1 and tensor.numel() > 1:  # 1D tensor: line chart.
                                st.line_chart(tensor.cpu().numpy())
                            elif tensor.ndim == 2 and tensor.shape[0] > 1 and tensor.shape[1] > 1:  # 2D tensor: heatmap.
                                try:
                                    fig = px.imshow(tensor.cpu().numpy(), title="Tensor Heatmap", aspect="auto")
                                    st.plotly_chart(fig, use_container_width=True)
                                except Exception as plot_err:
                                    st.warning(f"Could not generate heatmap: {plot_err}")
                            elif tensor.ndim == 3 and tensor.shape[0] in [1, 3]:  # 3D tensor (potential image): try to display as image.
                                try:
                                    display_tensor = tensor.cpu()
                                    if display_tensor.shape[0] == 1:  # Grayscale image (C, H, W) -> (H, W)
                                        display_tensor = display_tensor.squeeze(0)
                                    elif display_tensor.shape[0] == 3:  # RGB image (C, H, W) -> (H, W, C)
                                        display_tensor = display_tensor.permute(1, 2, 0)

                                    # Normalize for display if not in typical image range [0,1] or [0,255].
                                    if display_tensor.max() > 1.0 or display_tensor.min() < 0.0:  # Basic check
                                        display_tensor = (display_tensor - display_tensor.min()) / (display_tensor.max() - display_tensor.min() + 1e-6)  # Normalize to [0,1]

                                    # FIX: use_column_width is deprecated; use use_container_width
                                    # for consistency with the other widgets in this file.
                                    st.image(display_tensor.numpy(), caption="Tensor as Image (Attempted)", use_container_width=True)
                                except Exception as img_err:
                                    st.warning(f"Could not display tensor as image: {img_err}")
                            else:
                                st.info("No specific visualization available for this tensor's shape/dimension.")
                    else:
                        st.warning("Tensor data, shape, or dtype missing in the selected record.")
                except Exception as tensor_err:
                    st.error(f"Error processing tensor data for preview: {tensor_err}")
            else:
                st.warning("Selected tensor ID details not found in the fetched preview data.")
        else:
            st.info("Select a Tensor ID from the filtered table above to preview its details.")
    else:
        st.info("No records match the current filters to allow preview.")
else:
    st.info("Select a dataset to start exploring.")
pages/nql_chatbot_v2.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pages/nql_chatbot_v2.py
# Streamlit page: natural-language (NQL) chat interface over Tensorus datasets.

import streamlit as st
import pandas as pd

# Import from the shared utils for pages.
# If the shared-utils module cannot be imported (e.g. the page is launched
# outside the app package), show a fatal error, define stub fallbacks so the
# names exist, and abort this script run with st.stop().
try:
    from pages.pages_shared_utils import (
        load_css as load_shared_css,
        post_nql_query
    )
except ImportError:
    st.error("Critical Error: Could not import `pages_shared_utils`. Page cannot function.")
    # Fallback stubs keep the imported names bound; st.stop() below ends this
    # script run, so in practice they are never reached.
    def load_shared_css(): pass
    def post_nql_query(query: str):
        # Stub that mirrors the real response schema so downstream code
        # (were it ever to run) would not KeyError.
        st.error("`post_nql_query` unavailable.")
        return {"query": query, "response_text": "Error: NQL processing unavailable.", "error": "Setup issue.", "results": None}
    st.stop()
| 19 |
+
|
| 20 |
+
st.set_page_config(page_title="NQL Chatbot (V2)", layout="wide")
load_shared_css()  # Load common CSS from shared utilities

# Custom CSS for NQL Chatbot page.
# These styles are specific to the chat interface elements (message bubbles,
# avatars, the bottom chat-input bar, and DataFrames embedded in messages).
# NOTE(review): selectors that target Streamlit's internal test-ids/DOM
# (e.g. [data-testid="stChatInput"]) may break on Streamlit upgrades.
st.markdown("""
    <style>
    /* Chatbot specific styles */
    .stChatMessage { /* Base style for chat messages */
        border-radius: 10px; /* Rounded corners */
        padding: 0.85rem 1.15rem; /* Comfortable padding */
        margin-bottom: 0.75rem; /* Space between messages */
        box-shadow: 0 2px 5px rgba(0,0,0,0.15); /* Subtle shadow */
        border: 1px solid #2a3f5c; /* Consistent border with other elements */
    }
    /* User messages styling */
    .stChatMessage[data-testid="stChatMessageContent"]:has(.user-avatar) {
        background-color: #2a2f4c; /* Darker blue, similar to nav hover, for user messages */
        border-left: 5px solid #3a6fbf; /* Accent blue border, consistent with active nav link */
    }
    /* Assistant messages styling */
    .stChatMessage[data-testid="stChatMessageContent"]:has(.assistant-avatar) {
        background-color: #18223f; /* Slightly lighter blue-gray, similar to common-card */
        border-left: 5px solid #7070ff; /* Muted purple accent for assistant messages */
    }
    .user-avatar, .assistant-avatar { /* Styling for user/assistant avatars */
        font-size: 1.5rem; /* Avatar size */
        margin-right: 0.5rem; /* Space between avatar and message content */
    }

    /* Styling for the chat input area at the bottom */
    /* The .chat-input-container class is a convention, Streamlit might not add this by default.
       Targeting based on Streamlit's internal structure is more robust. */
    div[data-testid="stChatInput"] { /* Main container for chat input */
        background-color: #0a0f2c; /* Match page background */
        border-top: 1px solid #2a3f5c; /* Separator line */
        padding: 0.75rem;
    }
    div[data-testid="stChatInput"] textarea { /* The actual text input field */
        border: 1px solid #3a3f5c !important;
        background-color: #1a1f3c !important; /* Dark input background */
        color: #e0e0e0 !important; /* Light text color */
        border-radius: 5px !important;
    }
    div[data-testid="stChatInput"] button { /* The send button */
        border: none !important;
        background-color: #3a6fbf !important; /* Accent blue, consistent with primary buttons */
        color: white !important;
        border-radius: 5px !important;
    }
    div[data-testid="stChatInput"] button:hover {
        background-color: #4a7fdc !important; /* Lighter blue on hover */
    }

    /* Styling for DataFrames displayed within chat messages */
    .stDataFrame { /* General DataFrame styling from shared_utils will apply */
        margin-top: 0.5rem; /* Space above DataFrame in chat */
    }
    /* .results-dataframe and .results-dataframe .col-header are less reliable than generic .stDataFrame */
    </style>
""", unsafe_allow_html=True)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
st.title("💬 Natural Query Language (NQL) Chatbot")
st.caption("Query Tensorus datasets using natural language. Powered by Tensorus NQL Agent.")

# Page-scoped conversation history: the unique key avoids clashing with any
# other chat interface in the app, and survives Streamlit reruns.
st.session_state.setdefault("nql_messages_v2", [])

# Replay the stored conversation on every rerun so it persists for the session.
for msg in st.session_state.nql_messages_v2:
    is_user = msg["role"] == "user"
    with st.chat_message(msg["role"], avatar="👤" if is_user else "🤖"):
        # Text content of the message.
        st.markdown(msg["content"])
        # A message may also carry a DataFrame of NQL results, or an error.
        df = msg.get("results_df")
        if df is not None and not df.empty:
            st.dataframe(df, use_container_width=True, hide_index=True)
        elif msg.get("error"):
            st.error(msg["error"])
+
|
| 104 |
+
|
| 105 |
+
# React to user input from the chat interface.
|
| 106 |
+
# React to a freshly submitted user query; with no submission, fall through
# to the welcome hint at the bottom.
if prompt := st.chat_input("Enter your query (e.g., 'show 5 tensors from my_dataset')"):
    # Record and echo the user's message.
    st.session_state.nql_messages_v2.append({"role": "user", "content": prompt})
    with st.chat_message("user", avatar="👤"):
        st.markdown(prompt)

    # Ask the backend NQL agent for an answer (spinner while waiting).
    with st.spinner("Tensorus NQL Agent is thinking..."):
        api_reply = post_nql_query(prompt)

    # Pull the pieces of interest out of the API response.
    answer_text = api_reply.get("response_text", "Sorry, I encountered an issue processing your request.")
    records = api_reply.get("results")       # expected: list of record dicts
    backend_error = api_reply.get("error")   # backend-supplied error string, if any

    # This dict becomes the assistant entry in the session history.
    history_entry = {"role": "assistant", "content": answer_text}

    with st.chat_message("assistant", avatar="🤖"):
        st.markdown(answer_text)

        if records:
            try:
                # Flatten each record (id + metadata + shape/dtype) into one
                # table row; shape is stringified for display.
                table_rows = []
                for rec in records:
                    row = {"tensor_id": rec.get("id")}
                    meta = rec.get("metadata")
                    if isinstance(meta, dict):
                        row.update(meta)
                    row["shape"] = str(rec.get("shape"))
                    row["dtype"] = rec.get("dtype")
                    table_rows.append(row)

                if table_rows:
                    results_df = pd.DataFrame(table_rows)
                    st.dataframe(results_df, use_container_width=True, hide_index=True)
                    history_entry["results_df"] = results_df  # keep for replay
                elif not backend_error:
                    # A successful query that simply returns no records.
                    st.caption("Query processed, no specific records returned.")
            except Exception as e:
                st.error(f"Error formatting results for display: {e}")
                history_entry["error"] = f"Error formatting results: {e}"
        elif backend_error:
            st.error(backend_error)
            history_entry["error"] = backend_error
        # Otherwise the textual answer alone is the whole response.

    # Persist the assistant turn (text + optional DataFrame/error).
    st.session_state.nql_messages_v2.append(history_entry)

else:
    # First load (or empty submission): show usage examples until the user chats.
    if not st.session_state.nql_messages_v2:
        st.info("Ask me anything about your data! For example: 'list datasets' or 'show tensors from dataset XYZ limit 5'.")
|
pages/pages_shared_utils.py
ADDED
|
@@ -0,0 +1,547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pages/pages_shared_utils.py
|
| 2 |
+
"""
|
| 3 |
+
Shared utility functions for Streamlit pages.
|
| 4 |
+
Copied/adapted from app.py to avoid complex import issues.
|
| 5 |
+
"""
|
| 6 |
+
import streamlit as st
|
| 7 |
+
import requests
|
| 8 |
+
import logging
|
| 9 |
+
import os # Added import
|
| 10 |
+
from typing import Optional, List, Dict, Any # Added for new functions
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "http://127.0.0.1:7860") # Changed API_BASE_URL
|
| 15 |
+
|
| 16 |
+
def load_css():
    """Inject the shared Tensorus "Nexus" theme CSS into the current page.

    Copied from app.py so every page can style itself identically without
    importing app.py. Covers base typography, the custom top-nav bar, the
    .common-card component, status pills, and Streamlit input/button/dataframe
    overrides. NOTE(review): a few selectors target Streamlit-generated class
    names (e.g. st-emotion-cache-*) and may break across Streamlit versions.
    """
    st.markdown("""
    <style>
    /* --- Shared Base Styles for Tensorus Platform (Nexus Theme) --- */

    /* General Page Styles */
    body {
        font-family: 'Arial', 'Helvetica Neue', 'Helvetica', sans-serif;
        line-height: 1.6;
    }
    .stApp { /* Main Streamlit app container */
        background-color: #0a0f2c; /* Primary Background: Dark blue/purple */
        color: #e0e0e0; /* Default Text Color: Light grey */
    }

    /* Headings & Titles */
    h1, .stTitle { /* Main page titles */
        color: #d0d0ff !important; /* Primary Heading Color: Light purple/blue */
        font-weight: bold !important;
    }
    h2, .stSubheader { /* Section headers */
        color: #c0c0ef !important; /* Secondary Heading Color */
        font-weight: bold !important;
        border-bottom: 1px solid #3a3f5c; /* Accent Border for separation */
        padding-bottom: 0.3rem;
        margin-top: 1.5rem;
        margin-bottom: 1rem;
    }
    h3 { /* General h3, often used in st.markdown */
        color: #b0b0df !important; /* Tertiary Heading Color */
        font-weight: bold !important;
    }
    .stMarkdown p, .stText, .stListItem { /* General text elements */
        color: #c0c0dd; /* Softer light text */
        font-size: 1rem;
    }
    .stCaption, caption { /* Streamlit captions and HTML captions */
        font-size: 0.85rem !important;
        color: #a0a0c0 !important; /* Muted color for captions */
    }

    /* Custom Top Navigation Bar */
    .topnav-container {
        background-color: #1a1f3c; /* Nav Background: Slightly darker than page */
        padding: 0.5rem 1rem;
        border-bottom: 1px solid #3a3f5c; /* Accent Border */
        display: flex;
        justify-content: flex-start;
        align-items: center;
        position: sticky; top: 0; z-index: 1000; /* Ensure it's on top */
        width: 100%;
        box-sizing: border-box;
    }
    .topnav-container .logo {
        font-size: 1.5em;
        font-weight: bold;
        color: #d0d0ff; /* Primary Heading Color for logo */
        margin-right: 2rem;
    }
    .topnav-container nav a {
        color: #c0c0ff; /* Lighter Text for Nav Links */
        padding: 0.75rem 1rem;
        text-decoration: none;
        font-weight: 500;
        margin-right: 0.5rem;
        border-radius: 4px;
        transition: background-color 0.2s ease, color 0.2s ease;
    }
    .topnav-container nav a:hover {
        background-color: #2a2f4c; /* Nav Link Hover Background */
        color: #ffffff; /* Nav Link Hover Text */
    }
    .topnav-container nav a.active {
        background-color: #3a6fbf; /* Active Nav Link Background (Accent Blue) */
        color: #ffffff; /* Active Nav Link Text */
        font-weight: bold;
    }

    /* Common Card Style (base for metric cards, agent cards, etc.) */
    .common-card {
        background-color: #18223f; /* Card Background: Darker than nav, but lighter than page */
        border: 1px solid #2a3f5c; /* Card Border: Accent Border color */
        border-radius: 10px;
        padding: 1.5rem;
        margin-bottom: 1rem; /* Space below cards */
        box-shadow: 0 4px 12px rgba(0, 0, 0, 0.2);
        transition: transform 0.2s ease-in-out, box-shadow 0.2s ease-in-out;
        color: #e0e0e0; /* Default text color within cards */
    }
    .common-card:hover {
        transform: translateY(-5px);
        box-shadow: 0 8px 16px rgba(0,0,0,0.3);
    }
    .common-card h3 { /* Titles within cards */
        color: #b0b0df !important; /* Card Title Color: Slightly lighter than main headings */
        font-size: 1.2em !important; /* Slightly larger for card titles */
        margin-top: 0 !important; /* Remove default top margin for h3 in card */
        margin-bottom: 0.75rem !important;
        font-weight: bold !important;
        border-bottom: none !important; /* Override general h2 border for card h3 */
    }
    .common-card p { /* Paragraphs within cards */
        font-size: 0.95em !important; /* Slightly smaller for card content */
        color: #c0c0dd !important;
        margin-bottom: 0.5rem !important;
    }
    .common-card .icon { /* For icons within cards, like dashboard metric cards */
        font-size: 2.5em;
        margin-bottom: 0.75rem;
        color: #7070ff; /* Icon Color: Muted accent */
    }


    /* Status Indicators (can be used with <span> or <p> or custom divs) */
    .status-indicator {
        padding: 0.4rem 0.8rem !important; /* Slightly more padding */
        border-radius: 15px !important; /* Pill shape */
        font-weight: bold !important;
        font-size: 0.85em !important;
        display: inline-block !important;
        text-align: center !important;
    }
    .status-success, .status-running { color: #ffffff !important; background-color: #4CAF50 !important; } /* Green */
    .status-error { color: #ffffff !important; background-color: #F44336 !important; } /* Red */
    .status-warning { color: #000000 !important; background-color: #FFC107 !important; } /* Amber */
    .status-info { color: #ffffff !important; background-color: #2196F3 !important; } /* Blue */
    .status-stopped { color: #e0e0e0 !important; background-color: #525252 !important; } /* Darker Grey for stopped */
    .status-unknown { color: #333333 !important; background-color: #BDBDBD !important; } /* Lighter grey for unknown */


    /* Standardized Streamlit Input Styling */
    .stTextInput > div > div > input,
    .stTextArea > div > div > textarea,
    .stSelectbox > div > div,
    .stNumberInput > div > div > input {
        border: 1px solid #3a3f5c !important; /* Accent Border */
        background-color: #1a1f3c !important; /* Nav Background color for inputs */
        color: #e0e0e0 !important; /* Default Text Color */
        border-radius: 5px !important;
    }
    .stMultiSelect > div > div > div { /* Multiselect options container */
        border: 1px solid #3a3f5c !important;
        background-color: #1a1f3c !important;
    }
    .stMultiSelect span[data-baseweb="tag"] { /* Selected items in multiselect */
        background-color: #3a6fbf !important; /* Active Nav Link Background */
    }


    /* Standardized Streamlit Button Styling */
    .stButton > button {
        border: 1px solid #3a6fbf !important; /* Accent Blue for border */
        background-color: #3a6fbf !important; /* Accent Blue for background */
        color: white !important;
        border-radius: 5px !important;
        padding: 0.5rem 1rem !important;
        transition: background-color 0.2s ease, border-color 0.2s ease;
    }
    .stButton > button:hover {
        background-color: #4a7fdc !important; /* Lighter Accent Blue for hover */
        border-color: #4a7fdc !important;
    }
    .stButton > button:disabled {
        background-color: #2a2f4c !important;
        border-color: #2a2f4c !important;
        color: #777777 !important;
    }
    /* For secondary buttons, Streamlit uses a 'kind' attribute in HTML we can't directly target via pure CSS.
       Instead, use st.button(..., type="secondary") and rely on Streamlit's handling,
       or use st.markdown for fully custom buttons if default secondary is not enough.
       The below attempts to style based on common Streamlit secondary button appearance.
       Note: This specific selector for secondary buttons might be unstable if Streamlit changes its internal class names.
    */
    .stButton button.st-emotion-cache-LPTKCI { /* Example selector for a secondary button, MAY BE UNSTABLE */
        background-color: #2a2f4c !important;
        border: 1px solid #2a2f4c !important;
        color: #c0c0ff !important;
    }
    .stButton button.st-emotion-cache-LPTKCI:hover {
        background-color: #3a3f5c !important;
        border-color: #3a3f5c !important;
        color: #ffffff !important;
    }
    .stButton button.st-emotion-cache-LPTKCI:disabled { /* Disabled secondary button */
        background-color: #1e2a47 !important;
        border-color: #1e2a47 !important;
        color: #555555 !important;
    }


    /* Dataframe styling */
    .stDataFrame { /* Main container for dataframes */
        border: 1px solid #2a3f5c !important; /* Accent Border */
        border-radius: 5px !important;
        background-color: #1a1f3c !important; /* Nav Background for dataframe background */
    }
    .stDataFrame th { /* Headers */
        background-color: #2a2f4c !important; /* Nav Link Hover Background for headers */
        color: #d0d0ff !important; /* Primary Heading Color for header text */
        font-weight: bold;
    }
    .stDataFrame td { /* Cells */
        color: #c0c0dd !important; /* Softer light text for cell data */
        border-bottom-color: #2a3f5c !important; /* Accent border for cell lines */
        border-top-color: #2a3f5c !important;
    }

    </style>
    """, unsafe_allow_html=True)
|
| 226 |
+
|
| 227 |
+
def get_api_status() -> tuple[bool, dict]:
    """Ping the backend root endpoint and report whether it is reachable.

    Uses the module-level ``API_BASE_URL``.

    Returns:
        tuple[bool, dict]: ``(True, payload)`` when the API answered with a
        2xx status, where ``payload`` is the decoded JSON body; otherwise
        ``(False, {"error": <description>})``.
    """
    try:
        resp = requests.get(f"{API_BASE_URL}/", timeout=3)  # modest timeout for a health check
        resp.raise_for_status()
        return True, resp.json()
    except requests.exceptions.RequestException as exc:
        # Covers connection failures, timeouts and HTTP error statuses.
        logger.error(f"API connection error in get_api_status (shared_utils): {exc}")
        return False, {"error": f"API connection failed: {str(exc)}"}
    except Exception as exc:
        logger.exception(f"Unexpected error in get_api_status (shared_utils): {exc}")
        return False, {"error": f"An unexpected error occurred: {str(exc)}"}
|
| 250 |
+
|
| 251 |
+
def get_agent_status() -> Optional[dict]:
    """Fetch the status of all registered agents from ``/agents/status``.

    Uses the module-level ``API_BASE_URL``.

    Returns:
        Optional[dict]: Mapping of agent id -> status/config dict on success,
        or None when the request fails for any reason.
    """
    try:
        resp = requests.get(f"{API_BASE_URL}/agents/status", timeout=5)
        resp.raise_for_status()
        return resp.json()
    except requests.exceptions.RequestException as exc:
        logger.error(f"API error fetching agent status (shared_utils): {exc}")
        return None
    except Exception as exc:
        logger.exception(f"Unexpected error in get_agent_status (shared_utils): {exc}")
        return None
|
| 273 |
+
|
| 274 |
+
def start_agent(agent_id: str) -> dict:
    """Ask the backend to start one agent via ``POST /agents/{agent_id}/start``.

    Uses the module-level ``API_BASE_URL``.

    Args:
        agent_id: Unique identifier of the agent to start.

    Returns:
        dict: The backend's JSON reply (normally a 'success' flag plus a
        'message'). On any failure a dict of the same shape with
        ``success=False`` and an error message is returned instead of raising.
    """
    endpoint = f"{API_BASE_URL}/agents/{agent_id}/start"
    try:
        resp = requests.post(endpoint, timeout=7)  # generous timeout: starting may be slow
        resp.raise_for_status()
        return resp.json()
    except requests.exceptions.RequestException as exc:
        logger.error(f"API error starting agent {agent_id} (shared_utils): {exc}")
        return {"success": False, "message": f"Failed to start agent {agent_id}: {str(exc)}"}
    except Exception as exc:
        logger.exception(f"Unexpected error in start_agent (shared_utils) for {agent_id}: {exc}")
        return {"success": False, "message": f"An unexpected error occurred: {str(exc)}"}
|
| 299 |
+
|
| 300 |
+
def stop_agent(agent_id: str) -> dict:
    """Ask the backend to stop one agent via ``POST /agents/{agent_id}/stop``.

    Uses the module-level ``API_BASE_URL``.

    Args:
        agent_id: Unique identifier of the agent to stop.

    Returns:
        dict: The backend's JSON reply (normally a 'success' flag plus a
        'message'). On any failure a dict of the same shape with
        ``success=False`` and an error message is returned instead of raising.
    """
    endpoint = f"{API_BASE_URL}/agents/{agent_id}/stop"
    try:
        resp = requests.post(endpoint, timeout=7)  # generous timeout: stopping may be slow
        resp.raise_for_status()
        return resp.json()
    except requests.exceptions.RequestException as exc:
        logger.error(f"API error stopping agent {agent_id} (shared_utils): {exc}")
        return {"success": False, "message": f"Failed to stop agent {agent_id}: {str(exc)}"}
    except Exception as exc:
        logger.exception(f"Unexpected error in stop_agent (shared_utils) for {agent_id}: {exc}")
        return {"success": False, "message": f"An unexpected error occurred: {str(exc)}"}
|
| 325 |
+
|
| 326 |
+
def get_datasets() -> list[str]:
    """List the dataset names exposed by ``/explorer/datasets``.

    Uses the module-level ``API_BASE_URL``.

    Returns:
        list[str]: Dataset names on success. An empty list when the call
        fails, the response lacks a 'datasets' key, or any exception occurs.
    """
    try:
        resp = requests.get(f"{API_BASE_URL}/explorer/datasets", timeout=5)
        resp.raise_for_status()
        payload = resp.json()
        return payload.get("datasets", [])
    except requests.exceptions.RequestException as exc:
        logger.error(f"API error fetching datasets (shared_utils): {exc}")
    except Exception as exc:
        logger.exception(f"Unexpected error in get_datasets (shared_utils): {exc}")
    # Single failure exit: all error paths fall through to an empty list.
    return []
|
| 349 |
+
|
| 350 |
+
def get_dataset_preview(dataset_name: str, limit: int = 10) -> Optional[dict]:
    """Fetch preview data for a specific dataset from the backend.

    Uses the module-level ``API_BASE_URL``.
    Targets the ``/explorer/dataset/{dataset_name}/preview`` endpoint. The
    ``limit`` value is passed through ``requests``' ``params`` argument so it
    is URL-encoded properly instead of being interpolated into the query
    string by hand.

    Args:
        dataset_name (str): The name of the dataset to preview.
        limit (int): The maximum number of records to fetch for the preview. Defaults to 10.

    Returns:
        Optional[dict]: A dictionary containing dataset information (e.g. 'dataset',
                        'record_count', 'preview' list of records) if successful.
                        Each record in the 'preview' list is a dictionary typically
                        containing 'id', 'shape', 'dtype', 'metadata', and 'data' (raw list).
                        Returns None if the API call fails or an exception occurs.
    """
    try:
        response = requests.get(
            f"{API_BASE_URL}/explorer/dataset/{dataset_name}/preview",
            params={"limit": limit},  # let requests build and encode the query string
            timeout=10,
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        logger.error(f"API error fetching dataset preview for {dataset_name} (shared_utils): {e}")
        return None
    except Exception as e:
        logger.exception(f"Unexpected error in get_dataset_preview (shared_utils) for {dataset_name}: {e}")
        return None
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def get_tensor_metadata(dataset_name: str, tensor_id: str) -> Optional[dict]:
    """Fetch metadata for a specific tensor via the Explorer API."""
    metadata_url = (
        f"{API_BASE_URL}/explorer/dataset/{dataset_name}/tensor/{tensor_id}/metadata"
    )
    try:
        resp = requests.get(metadata_url, timeout=5)
        resp.raise_for_status()
        payload = resp.json()
        # Prefer a nested "metadata" key; fall back to the whole payload.
        return payload.get("metadata", payload)
    except requests.exceptions.RequestException as exc:
        logger.error(
            f"API error fetching tensor metadata for {dataset_name}/{tensor_id} (shared_utils): {exc}"
        )
        return None
    except Exception as exc:
        logger.exception(
            f"Unexpected error in get_tensor_metadata (shared_utils) for {dataset_name}/{tensor_id}: {exc}"
        )
        return None
|
| 404 |
+
|
| 405 |
+
def list_all_agents() -> list[dict[str, str]]:
    """Return a list of agent details based on `get_agent_status`.

    Each entry is a dict with 'id', 'name', and 'status' keys. When the
    status payload supplies no explicit name, a default is derived from the
    agent id (underscores replaced by spaces, title-cased).

    Convenience wrapper around `get_agent_status()` for callers that prefer
    list-shaped agent information.

    Returns:
        list[dict[str, str]]: One dict per agent with 'id', 'name', and
            'status'. Empty list when statuses cannot be fetched.
    """
    statuses = get_agent_status()
    if not statuses:
        return []
    agents: list[dict[str, str]] = []
    for agent_id, info in statuses.items():
        default_name = agent_id.replace("_", " ").title()  # fallback display name
        agents.append(
            {
                "id": agent_id,
                "name": info.get("name", default_name),
                "status": info.get("status", "unknown"),
            }
        )
    return agents
|
| 432 |
+
|
| 433 |
+
def post_nql_query(query: str) -> dict:
    """Send an NQL query to the backend for processing.

    POSTs the user's query to ``/chat/query`` on ``API_BASE_URL``.

    Args:
        query (str): The Natural Query Language query string.

    Returns:
        dict: The API response. On success it typically carries 'query',
            'response_text', a 'results' list of records (each potentially
            holding 'id', 'shape', 'dtype', 'metadata', 'data'), and 'count'.
            On failure it carries 'query', an error 'response_text', a
            detailed 'error' string, and 'results' of None.
    """
    try:
        resp = requests.post(
            f"{API_BASE_URL}/chat/query",
            json={"query": query},
            timeout=15,
        )
        resp.raise_for_status()
        return resp.json()
    except requests.exceptions.RequestException as exc:
        logger.error(f"Error posting NQL query from pages_shared_utils: {exc}")
        # Let the caller handle UI error display.
        return {
            "query": query,
            "response_text": "Error connecting to backend or processing query.",
            "error": str(exc),
            "results": None,
        }
    except Exception as exc:
        logger.exception(f"Unexpected error in post_nql_query (pages_shared_utils): {exc}")
        return {
            "query": query,
            "response_text": "An unexpected error occurred.",
            "error": str(exc),
            "results": None,
        }
|
| 473 |
+
|
| 474 |
+
# --- Functions to be added from app.py ---
|
| 475 |
+
|
| 476 |
+
def configure_agent(agent_id: str, config: dict) -> dict:
    """Ask the backend to apply a configuration to a specific agent.

    POSTs to ``/agents/{agent_id}/configure`` on ``API_BASE_URL``.

    Args:
        agent_id (str): Unique identifier of the agent to configure.
        config (dict): Configuration dictionary for the agent.

    Returns:
        dict: The API response, typically with a 'success' boolean and a
            'message' string. On connection or unexpected errors, a dict with
            'success': False and an error 'message' is returned instead.
    """
    endpoint = f"{API_BASE_URL}/agents/{agent_id}/configure"
    try:
        # Timeout matches the start/stop agent calls.
        resp = requests.post(endpoint, json={"config": config}, timeout=7)
        resp.raise_for_status()
        return resp.json()
    except requests.exceptions.RequestException as exc:
        logger.error(f"API error configuring agent {agent_id} (shared_utils): {exc}")
        return {"success": False, "message": f"Failed to configure agent {agent_id}: {str(exc)}"}
    except Exception as exc:
        logger.exception(f"Unexpected error in configure_agent (shared_utils) for {agent_id}: {exc}")
        return {"success": False, "message": f"An unexpected error occurred: {str(exc)}"}
|
| 506 |
+
|
| 507 |
+
def operate_explorer(dataset: str, operation: str, index: int, params: dict) -> dict:
    """Request a data-explorer operation on a specific tensor.

    POSTs to ``/explorer/operate`` on ``API_BASE_URL``.

    Args:
        dataset (str): Name of the dataset containing the tensor.
        operation (str): Operation to perform (e.g. 'view', 'transform').
        index (int): Index of the tensor within the dataset.
        params (dict): Additional parameters required by the operation.

    Returns:
        dict: The API response, typically carrying:
            - 'success': whether the operation was accepted.
            - 'metadata': details about the operation or resulting tensor.
            - 'result_data': the resulting tensor's data, or None.
            On connection or server-side errors, returns 'success': False with
            the error message under metadata['error'] and 'result_data': None.
    """
    request_body = {
        "dataset": dataset,
        "operation": operation,
        "tensor_index": index,
        "params": params,
    }
    try:
        resp = requests.post(
            f"{API_BASE_URL}/explorer/operate",
            json=request_body,
            timeout=15,  # operations may take a while
        )
        resp.raise_for_status()
        return resp.json()
    except requests.exceptions.RequestException as exc:
        logger.error(f"API error in operate_explorer for {dataset} (shared_utils): {exc}")
        return {"success": False, "metadata": {"error": str(exc)}, "result_data": None}
    except Exception as exc:
        logger.exception(f"Unexpected error in operate_explorer (shared_utils) for {dataset}: {exc}")
        return {"success": False, "metadata": {"error": str(exc)}, "result_data": None}
|
pages/ui_utils.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ui_utils.py (Modifications for Step 3)
|
| 2 |
+
"""Utility functions for the Tensorus Streamlit UI, now using API calls."""
|
| 3 |
+
|
| 4 |
+
import requests
|
| 5 |
+
import streamlit as st
|
| 6 |
+
import logging
|
| 7 |
+
from typing import List, Dict, Any, Optional
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
# --- Configuration ---
|
| 12 |
+
TENSORUS_API_URL = "http://127.0.0.1:7860" # Ensure FastAPI runs here
|
| 13 |
+
|
| 14 |
+
# --- API Interaction Functions ---
|
| 15 |
+
|
| 16 |
+
def get_api_status() -> bool:
    """Checks if the Tensorus API is reachable."""
    try:
        # Short timeout: this is only a liveness probe.
        resp = requests.get(f"{TENSORUS_API_URL}/", timeout=2)
    except requests.exceptions.ConnectionError:
        return False
    except Exception as e:
        logger.error(f"Error checking API status: {e}")
        return False
    return resp.status_code == 200
|
| 26 |
+
|
| 27 |
+
def list_datasets() -> Optional[List[str]]:
    """Fetch the list of dataset names from the API.

    Returns:
        Optional[List[str]]: Dataset names on success, or None on any error
            (an error message is surfaced in the Streamlit UI via st.error).
    """
    try:
        # Bounded timeout so a stalled backend cannot hang the Streamlit UI.
        response = requests.get(f"{TENSORUS_API_URL}/datasets", timeout=10)
        response.raise_for_status()
        data = response.json()
        if data.get("success"):
            return data.get("data", [])
        st.error(f"API Error listing datasets: {data.get('message')}")
        return None
    except requests.exceptions.RequestException as e:
        st.error(f"Connection Error listing datasets: {e}")
        return None
    except Exception as e:
        st.error(f"Unexpected error listing datasets: {e}")
        logger.exception("Unexpected error in list_datasets")
        return None
|
| 45 |
+
|
| 46 |
+
def fetch_dataset_data(dataset_name: str, offset: int = 0, limit: int = 50) -> Optional[List[Dict[str, Any]]]:
    """Fetch a page of records from a dataset via the API.

    Args:
        dataset_name: Name of the dataset to read from.
        offset: Zero-based record offset of the page.
        limit: Maximum number of records to return.

    Returns:
        Optional[List[Dict[str, Any]]]: The page of records on success, or
            None on any error (error shown in the Streamlit UI).
    """
    try:
        params = {"offset": offset, "limit": limit}
        # Bounded timeout so a stalled backend cannot hang the Streamlit UI.
        response = requests.get(
            f"{TENSORUS_API_URL}/datasets/{dataset_name}/records",
            params=params,
            timeout=15,
        )
        response.raise_for_status()
        data = response.json()
        if data.get("success"):
            return data.get("data", [])
        st.error(f"API Error fetching '{dataset_name}': {data.get('message')}")
        return None
    except requests.exceptions.RequestException as e:
        st.error(f"Connection Error fetching '{dataset_name}': {e}")
        return None
    except Exception as e:
        st.error(f"Unexpected error fetching '{dataset_name}': {e}")
        logger.exception(f"Unexpected error in fetch_dataset_data for {dataset_name}")
        return None
|
| 65 |
+
|
| 66 |
+
def execute_nql_query(query: str) -> Optional[Dict[str, Any]]:
    """Send an NQL query to the API.

    Args:
        query: The Natural Query Language query string.

    Returns:
        Optional[Dict[str, Any]]: The full NQLResponse structure on success,
            or a dict with 'success': False and a 'message' on NQL (400),
            connection, or unexpected errors.
    """
    try:
        payload = {"query": query}
        # Bounded timeout so a stalled backend cannot hang the Streamlit UI.
        response = requests.post(f"{TENSORUS_API_URL}/query", json=payload, timeout=15)
        # Handle specific NQL errors (400) vs other errors.
        if response.status_code == 400:
            error_detail = response.json().get("detail", "Unknown NQL processing error")
            return {"success": False, "message": error_detail, "results": None, "count": None}
        response.raise_for_status()  # Raise for 5xx etc.
        return response.json()  # Return the full NQLResponse structure
    except requests.exceptions.RequestException as e:
        st.error(f"Connection Error executing NQL query: {e}")
        return {"success": False, "message": f"Connection Error: {e}", "results": None, "count": None}
    except Exception as e:
        st.error(f"Unexpected error executing NQL query: {e}")
        logger.exception("Unexpected error in execute_nql_query")
        return {"success": False, "message": f"Unexpected Error: {e}", "results": None, "count": None}
|
| 84 |
+
|
| 85 |
+
# --- NEW/UPDATED Agent and Metrics Functions ---
|
| 86 |
+
|
| 87 |
+
def list_all_agents() -> Optional[List[Dict[str, Any]]]:
    """Fetch the list of all registered agents from the API.

    Returns:
        Optional[List[Dict[str, Any]]]: The list of AgentInfo dicts on
            success, or None on any error (error shown in the Streamlit UI).
    """
    try:
        # Bounded timeout so a stalled backend cannot hang the Streamlit UI.
        response = requests.get(f"{TENSORUS_API_URL}/agents", timeout=10)
        response.raise_for_status()
        # The response is directly the list of AgentInfo objects.
        return response.json()
    except requests.exceptions.RequestException as e:
        st.error(f"Connection Error listing agents: {e}")
        return None
    except Exception as e:
        st.error(f"Unexpected error listing agents: {e}")
        logger.exception("Unexpected error in list_all_agents")
        return None
|
| 101 |
+
|
| 102 |
+
def get_agent_status(agent_id: str) -> Optional[Dict[str, Any]]:
    """Fetch status for a specific agent from the API.

    Args:
        agent_id: Identifier of the agent to query.

    Returns:
        Optional[Dict[str, Any]]: The AgentStatus model dict on success, or
            None when the agent is unknown (404) or any error occurs
            (errors are shown in the Streamlit UI).
    """
    try:
        # Bounded timeout so a stalled backend cannot hang the Streamlit UI.
        response = requests.get(f"{TENSORUS_API_URL}/agents/{agent_id}/status", timeout=10)
        if response.status_code == 404:
            st.error(f"Agent '{agent_id}' not found via API.")
            return None
        response.raise_for_status()
        # Returns AgentStatus model dict.
        return response.json()
    except requests.exceptions.RequestException as e:
        st.error(f"Connection Error getting status for agent '{agent_id}': {e}")
        return None
    except Exception as e:
        st.error(f"Unexpected error getting status for agent '{agent_id}': {e}")
        logger.exception(f"Unexpected error in get_agent_status for {agent_id}")
        return None
|
| 119 |
+
|
| 120 |
+
def get_agent_logs(agent_id: str, lines: int = 20) -> Optional[List[str]]:
    """Fetch recent logs for a specific agent from the API.

    Args:
        agent_id: Identifier of the agent whose logs to fetch.
        lines: Maximum number of recent log lines to request.

    Returns:
        Optional[List[str]]: The log lines on success, or None when the
            agent is unknown (404) or any error occurs (errors are shown in
            the Streamlit UI).
    """
    try:
        # Bounded timeout so a stalled backend cannot hang the Streamlit UI.
        response = requests.get(
            f"{TENSORUS_API_URL}/agents/{agent_id}/logs",
            params={"lines": lines},
            timeout=10,
        )
        if response.status_code == 404:
            st.error(f"Agent '{agent_id}' not found via API for logs.")
            return None
        response.raise_for_status()
        data = response.json()
        # Returns AgentLogResponse model dict.
        return data.get("logs", [])
    except requests.exceptions.RequestException as e:
        st.error(f"Connection Error getting logs for agent '{agent_id}': {e}")
        return None
    except Exception as e:
        st.error(f"Unexpected error getting logs for agent '{agent_id}': {e}")
        logger.exception(f"Unexpected error in get_agent_logs for {agent_id}")
        return None
|
| 138 |
+
|
| 139 |
+
def start_agent(agent_id: str) -> bool:
    """Send a start signal to an agent via the API.

    Args:
        agent_id: Identifier of the agent to start.

    Returns:
        bool: True when the API confirms the start signal was accepted;
            False on 404, a logical failure (e.g. already running), or any
            request error. Outcomes are surfaced in the Streamlit UI.
    """
    try:
        # Bounded timeout so a stalled backend cannot hang the Streamlit UI.
        response = requests.post(f"{TENSORUS_API_URL}/agents/{agent_id}/start", timeout=10)
        if response.status_code == 404:
            st.error(f"Agent '{agent_id}' not found via API.")
            return False
        # 202 Accepted is success, other 2xx might be okay too (e.g. already
        # running if handled gracefully); 4xx errors indicate failure.
        if 200 <= response.status_code < 300:
            api_response = response.json()
            if api_response.get("success"):
                st.success(f"API: {api_response.get('message', 'Start signal sent.')}")
                return True
            # API indicated logical failure (e.g., already running).
            st.warning(f"API: {api_response.get('message', 'Agent might already be running.')}")
            return False
        # Handle other potential errors reported by API.
        error_detail = "Unknown error"
        try:
            error_detail = response.json().get("detail", error_detail)
        except Exception:
            # Body was not JSON; keep the fallback message. (Was a bare
            # `except:` which also swallowed KeyboardInterrupt/SystemExit.)
            pass
        st.error(f"API Error starting agent '{agent_id}': {error_detail} (Status: {response.status_code})")
        return False
    except requests.exceptions.RequestException as e:
        st.error(f"Connection Error starting agent '{agent_id}': {e}")
        return False
    except Exception as e:
        st.error(f"Unexpected error starting agent '{agent_id}': {e}")
        logger.exception(f"Unexpected error in start_agent for {agent_id}")
        return False
|
| 171 |
+
|
| 172 |
+
def stop_agent(agent_id: str) -> bool:
    """Send a stop signal to an agent via the API.

    Args:
        agent_id: Identifier of the agent to stop.

    Returns:
        bool: True when the API confirms the stop signal was accepted;
            False on 404, a logical failure (e.g. already stopped), or any
            request error. Outcomes are surfaced in the Streamlit UI.
    """
    try:
        # Bounded timeout so a stalled backend cannot hang the Streamlit UI.
        response = requests.post(f"{TENSORUS_API_URL}/agents/{agent_id}/stop", timeout=10)
        if response.status_code == 404:
            st.error(f"Agent '{agent_id}' not found via API.")
            return False
        if 200 <= response.status_code < 300:
            api_response = response.json()
            if api_response.get("success"):
                st.success(f"API: {api_response.get('message', 'Stop signal sent.')}")
                return True
            st.warning(f"API: {api_response.get('message', 'Agent might already be stopped.')}")
            return False
        error_detail = "Unknown error"
        try:
            error_detail = response.json().get("detail", error_detail)
        except Exception:
            # Body was not JSON; keep the fallback message. (Was a bare
            # `except:` which also swallowed KeyboardInterrupt/SystemExit.)
            pass
        st.error(f"API Error stopping agent '{agent_id}': {error_detail} (Status: {response.status_code})")
        return False
    except requests.exceptions.RequestException as e:
        st.error(f"Connection Error stopping agent '{agent_id}': {e}")
        return False
    except Exception as e:
        st.error(f"Unexpected error stopping agent '{agent_id}': {e}")
        logger.exception(f"Unexpected error in stop_agent for {agent_id}")
        return False
|
| 200 |
+
|
| 201 |
+
def get_dashboard_metrics() -> Optional[Dict[str, Any]]:
    """Fetch dashboard metrics from the API.

    Returns:
        Optional[Dict[str, Any]]: The DashboardMetrics model dict on
            success, or None on any error (errors are shown in the UI).
    """
    try:
        # Bounded timeout so a stalled backend cannot hang the Streamlit UI.
        response = requests.get(f"{TENSORUS_API_URL}/metrics/dashboard", timeout=10)
        response.raise_for_status()
        # Returns DashboardMetrics model dict.
        return response.json()
    except requests.exceptions.RequestException as e:
        st.error(f"Connection Error fetching dashboard metrics: {e}")
        return None
    except Exception as e:
        st.error(f"Unexpected error fetching dashboard metrics: {e}")
        logger.exception("Unexpected error in get_dashboard_metrics")
        return None
|
| 215 |
+
|
| 216 |
+
def get_agent_config(agent_id: str) -> Optional[Dict[str, Any]]:
    """Fetch an agent's configuration from the API.

    Args:
        agent_id: Identifier of the agent whose configuration to fetch.

    Returns:
        Optional[Dict[str, Any]]: The configuration dict on success, or
            None when the agent is unknown (404) or any error occurs
            (errors are shown in the Streamlit UI).
    """
    try:
        # Bounded timeout so a stalled backend cannot hang the Streamlit UI.
        response = requests.get(f"{TENSORUS_API_URL}/agents/{agent_id}/config", timeout=10)
        if response.status_code == 404:
            st.error(f"Agent '{agent_id}' not found via API for configuration.")
            return None
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        st.error(f"Connection Error fetching config for agent '{agent_id}': {e}")
        return None
    except Exception as e:
        st.error(f"Unexpected error fetching config for agent '{agent_id}': {e}")
        logger.exception(f"Unexpected error in get_agent_config for {agent_id}")
        return None
|
| 232 |
+
|
| 233 |
+
def update_agent_config(agent_id: str, config: Dict[str, Any]) -> bool:
    """Send updated configuration for an agent to the API.

    Args:
        agent_id: Identifier of the agent to reconfigure.
        config: New configuration dictionary for the agent.

    Returns:
        bool: True when the API reports the update succeeded; False on 404,
            a logical failure, or any request error. Outcomes are surfaced
            in the Streamlit UI.
    """
    try:
        # Bounded timeout so a stalled backend cannot hang the Streamlit UI.
        response = requests.post(
            f"{TENSORUS_API_URL}/agents/{agent_id}/configure",
            json={"config": config},
            timeout=10,
        )
        if response.status_code == 404:
            st.error(f"Agent '{agent_id}' not found via API for configuration.")
            return False
        if 200 <= response.status_code < 300:
            api_response = response.json()
            if api_response.get("success"):
                st.success(api_response.get("message", "Configuration updated."))
                return True
            st.error(api_response.get("message", "Failed to update configuration."))
            return False
        error_detail = "Unknown error"
        try:
            # Body may not be JSON; keep the fallback message in that case.
            error_detail = response.json().get("detail", error_detail)
        except Exception:
            pass
        st.error(
            f"API Error updating config for '{agent_id}': {error_detail} (Status: {response.status_code})"
        )
        return False
    except requests.exceptions.RequestException as e:
        st.error(f"Connection Error updating config for agent '{agent_id}': {e}")
        return False
    except Exception as e:
        st.error(f"Unexpected error updating config for agent '{agent_id}': {e}")
        logger.exception(f"Unexpected error in update_agent_config for {agent_id}")
        return False
|
pyproject.toml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "tensorus"
|
| 3 |
+
version = "0.0.3"
|
| 4 |
+
authors = [
|
| 5 |
+
{ name = "Tensorus Team", email = "ai@tensorus.com" }
|
| 6 |
+
]
|
| 7 |
+
description = "An agentic tensor database and agent framework for managing, querying, and automating tensor data workflows."
|
| 8 |
+
readme = "README.md"
|
| 9 |
+
requires-python = ">=3.9"
|
| 10 |
+
license = { file = "LICENSE" }
|
| 11 |
+
keywords = [
|
| 12 |
+
"tensor", "database", "agent", "ai", "pytorch", "fastapi", "streamlit", "automl", "reinforcement-learning", "data-ingestion"
|
| 13 |
+
]
|
| 14 |
+
classifiers = [
|
| 15 |
+
"Development Status :: 3 - Alpha",
|
| 16 |
+
"Intended Audience :: Developers",
|
| 17 |
+
"Intended Audience :: Science/Research",
|
| 18 |
+
"License :: OSI Approved :: MIT License",
|
| 19 |
+
"Operating System :: OS Independent",
|
| 20 |
+
"Programming Language :: Python :: 3.9",
|
| 21 |
+
"Programming Language :: Python :: 3.10",
|
| 22 |
+
"Programming Language :: Python :: 3.11",
|
| 23 |
+
"Topic :: Database",
|
| 24 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 25 |
+
"Topic :: Software Development :: Libraries"
|
| 26 |
+
]
|
| 27 |
+
[project.urls]
|
| 28 |
+
Homepage = "https://tensorus.com"
|
| 29 |
+
Repository = "https://github.com/tensorus/tensorus"
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
models = ["tensorus-models>=0.0.3"]
|
| 33 |
+
|
| 34 |
+
[build-system]
|
| 35 |
+
requires = ["setuptools>=61.0", "wheel"]
|
| 36 |
+
build-backend = "setuptools.build_meta"
|
| 37 |
+
|
| 38 |
+
[tool.setuptools.packages.find]
|
| 39 |
+
where = ["."]
|
| 40 |
+
include = ["tensorus*"]
|
| 41 |
+
namespaces = false
|
requirements-test.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# requirements-test.txt
|
| 2 |
+
# Python dependencies needed to run the test suite
|
| 3 |
+
|
| 4 |
+
pytest
|
| 5 |
+
pytest-asyncio
|
| 6 |
+
httpx==0.23.0
|
| 7 |
+
scipy
|
| 8 |
+
semopy>=2.3
|
| 9 |
+
pydantic-settings>=2.0
|
| 10 |
+
# FastAPI version supporting the Pydantic v2 API
|
| 11 |
+
fastapi>=0.110
|
| 12 |
+
numpy>=1.21.0
|
| 13 |
+
torch
|
| 14 |
+
tensorly
|
| 15 |
+
transformers
|
| 16 |
+
psycopg2-binary
|
| 17 |
+
python-jose[cryptography]
|
requirements.txt
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# requirements.txt
|
| 2 |
+
# Updated: 2024-07-26 # (Date updated)
|
| 3 |
+
|
| 4 |
+
# --- Core Tensor and Numerics ---
|
| 5 |
+
torch>=1.13.0
|
| 6 |
+
torchvision>=0.14.0
|
| 7 |
+
segmentation-models-pytorch
|
| 8 |
+
transformers
|
| 9 |
+
numpy>=1.21.0
|
| 10 |
+
tensorly
|
| 11 |
+
|
| 12 |
+
# --- Agent Specific Dependencies ---
|
| 13 |
+
# For Ingestion Agent image processing example
|
| 14 |
+
Pillow>=9.0.0
|
| 15 |
+
|
| 16 |
+
# --- API Layer Dependencies ---
|
| 17 |
+
# FastAPI with Pydantic v2 support
|
| 18 |
+
fastapi>=0.110.0
|
| 19 |
+
pydantic>=2.0.0
|
| 20 |
+
pydantic-settings>=2.0
|
| 21 |
+
# ASGI Server (standard includes extras like watchfiles for reload)
|
| 22 |
+
uvicorn[standard]>=0.20.0
|
| 23 |
+
fastmcp>=0.2.0
|
| 24 |
+
# For PostgreSQL connectivity (used by PostgresMetadataStorage)
|
| 25 |
+
psycopg2-binary>=2.9.0
|
| 26 |
+
# Optional: Needed if using FastAPI file uploads via forms
|
| 27 |
+
# python-multipart>=0.0.5
|
| 28 |
+
|
| 29 |
+
# --- Streamlit UI Dependencies ---
|
| 30 |
+
streamlit>=1.25.0
|
| 31 |
+
# For calling the FastAPI backend from the Streamlit UI
|
| 32 |
+
requests>=2.28.0
|
| 33 |
+
# For JWT validation
|
| 34 |
+
python-jose[cryptography]
|
| 35 |
+
# For plotting in the Streamlit UI (Dashboard, Data Explorer)
|
| 36 |
+
plotly>=5.10.0
|
| 37 |
+
|
| 38 |
+
# --- Testing Dependencies ---
|
| 39 |
+
pytest>=7.0.0
|
| 40 |
+
httpx>=0.23.0 # For FastAPI TestClient
|
| 41 |
+
|
| 42 |
+
# --- Data Analysis & Modeling ---
|
| 43 |
+
# Optional: For plotting example in rl_agent.py
|
| 44 |
+
matplotlib>=3.5.0
|
| 45 |
+
# For Time Series Analysis (ARIMA model)
|
| 46 |
+
scikit-learn>=1.3.0
|
| 47 |
+
umap-learn
|
| 48 |
+
pandas>=1.5.0
|
| 49 |
+
arch>=5.7
|
| 50 |
+
lifelines>=0.28
|
| 51 |
+
semopy>=2.3
|
| 52 |
+
gensim
|
| 53 |
+
joblib
|
| 54 |
+
opencv-python
|
setup.sh
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Setup script for Tensorus development environment
# Exit immediately if any command fails.
set -e

# Install Python dependencies
# CPU-only PyTorch wheels are installed first from the dedicated index to
# avoid pulling large CUDA builds; the remaining requirements follow.
pip install --index-url https://download.pytorch.org/whl/cpu torch torchvision
pip install -r requirements.txt
pip install -r requirements-test.txt

# Optionally install the separate tensorus-models package (editable install)
# when INSTALL_MODELS=1 is set in the environment.
if [[ "$INSTALL_MODELS" == "1" ]]; then
    pip install -e .[models]
fi
|
tensorus/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tensorus core package."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
# The repository previously exposed a large collection of models under
|
| 6 |
+
# ``tensorus.models``. These models have been moved to a separate package.
|
| 7 |
+
# Tensorus now only attempts to import them if available.
|
| 8 |
+
|
| 9 |
+
if not os.environ.get("TENSORUS_MINIMAL_IMPORT"):
|
| 10 |
+
try:
|
| 11 |
+
import importlib
|
| 12 |
+
|
| 13 |
+
models = importlib.import_module("tensorus.models")
|
| 14 |
+
__all__ = ["models"]
|
| 15 |
+
except ModuleNotFoundError:
|
| 16 |
+
__all__ = []
|
| 17 |
+
else:
|
| 18 |
+
__all__ = []
|
tensorus/api.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tensorus/api/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tensorus API Package.
|
| 3 |
+
|
| 4 |
+
This package contains the FastAPI application and endpoints for interacting
|
| 5 |
+
with Tensor Descriptors and Semantic Metadata.
|
| 6 |
+
|
| 7 |
+
To run the API (assuming Uvicorn is installed and you are in the project root):
|
| 8 |
+
`uvicorn tensorus.api.main:app --reload --port 7860`
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from .main import app
|
| 12 |
+
|
| 13 |
+
__all__ = ["app"]
|
tensorus/api/dependencies.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tensorus.metadata import storage_instance as globally_configured_storage_instance
|
| 2 |
+
from tensorus.metadata.storage_abc import MetadataStorage
|
| 3 |
+
|
| 4 |
+
# Note: The `storage_instance` imported here is already configured (InMemory or Postgres)
|
| 5 |
+
# based on the logic in `tensorus/metadata/__init__.py` which reads `tensorus.config.settings`.
|
| 6 |
+
|
| 7 |
+
def get_storage_instance() -> MetadataStorage:
    """
    FastAPI dependency to get the currently configured metadata storage instance.

    Returns:
        MetadataStorage: The process-wide storage instance created once at
        import time by ``tensorus.metadata`` (in-memory or Postgres, chosen
        from ``tensorus.config.settings``). Every request receives this same
        global object; it is not request-scoped.
    """
    return globally_configured_storage_instance
|
| 12 |
+
|
| 13 |
+
# If we needed to re-initialize or pass settings directly to the storage for each request
|
| 14 |
+
# (e.g. for request-scoped sessions or dynamic configuration per request, which is not the case here),
|
| 15 |
+
# this function would be more complex. For now, it just returns the global instance.
|
| 16 |
+
#
|
| 17 |
+
# Example of re-initializing if storage_instance could change or needs request context:
|
| 18 |
+
# from tensorus.metadata import get_configured_storage_instance, ConfigurationError
|
| 19 |
+
# from tensorus.config import settings
|
| 20 |
+
#
|
| 21 |
+
# def get_storage_instance_dynamic() -> MetadataStorage:
|
| 22 |
+
# try:
|
| 23 |
+
# # This would re-evaluate settings and re-create the instance per request if needed,
|
| 24 |
+
# # or could access request-specific config.
|
| 25 |
+
# # For our current setup, the global instance is fine.
|
| 26 |
+
# return get_configured_storage_instance()
|
| 27 |
+
# except ConfigurationError as e:
|
| 28 |
+
# # This would ideally be caught by a global exception handler in FastAPI
|
| 29 |
+
# # to return a 500 error if configuration is bad during a request.
|
| 30 |
+
# # However, configuration should typically be validated at startup.
|
| 31 |
+
# raise RuntimeError(f"Storage configuration error: {e}")
|
| 32 |
+
|
| 33 |
+
# The current `storage_instance` is initialized once when `tensorus.metadata` is first imported.
|
| 34 |
+
# This is generally fine for many applications unless the configuration needs to change without restarting.
|
| 35 |
+
# The FastAPI `Depends` system will call `get_storage_instance` for each request that uses it,
|
| 36 |
+
# but this function will always return the same globally initialized `storage_instance`.
|
tensorus/api/endpoints.py
ADDED
|
@@ -0,0 +1,601 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Optional, Any, Annotated, Literal
|
| 2 |
+
from uuid import UUID
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter, HTTPException, Body, Query, Path
|
| 6 |
+
from pydantic import BaseModel, ValidationError
|
| 7 |
+
|
| 8 |
+
from tensorus.metadata.schemas import (
|
| 9 |
+
TensorDescriptor, SemanticMetadata, DataType,
|
| 10 |
+
LineageSourceType, LineageMetadata, ParentTensorLink,
|
| 11 |
+
ComputationalMetadata, QualityMetadata, RelationalMetadata, UsageMetadata,
|
| 12 |
+
)
|
| 13 |
+
from tensorus.metadata.storage_abc import MetadataStorage
|
| 14 |
+
from tensorus.api.dependencies import get_storage_instance
|
| 15 |
+
from tensorus.storage.connectors import mock_tensor_connector_instance
|
| 16 |
+
from tensorus.metadata.schemas_iodata import TensorusExportData
|
| 17 |
+
|
| 18 |
+
from pydantic import BaseModel as PydanticBaseModel
|
| 19 |
+
from typing import TypeVar, Generic
|
| 20 |
+
|
| 21 |
+
from fastapi import Depends, Security, status # Added status for HTTPException
|
| 22 |
+
from fastapi.responses import JSONResponse
|
| 23 |
+
from .security import verify_api_key, api_key_header_auth
|
| 24 |
+
from tensorus.audit import log_audit_event
|
| 25 |
+
|
| 26 |
+
import copy
|
| 27 |
+
import uuid
|
| 28 |
+
|
| 29 |
+
# Router for TensorDescriptor
|
| 30 |
+
router_tensor_descriptor = APIRouter(
|
| 31 |
+
prefix="/tensor_descriptors",
|
| 32 |
+
tags=["Tensor Descriptors"],
|
| 33 |
+
responses={404: {"description": "Not found"}},
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Router for SemanticMetadata
|
| 37 |
+
router_semantic_metadata = APIRouter(
|
| 38 |
+
prefix="/tensor_descriptors/{tensor_id}/semantic", # Corrected prefix for consistency
|
| 39 |
+
tags=["Semantic Metadata (Per Tensor)"],
|
| 40 |
+
responses={404: {"description": "Not found"}},
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# --- TensorDescriptor Endpoints ---
|
| 44 |
+
@router_tensor_descriptor.post("/", response_model=TensorDescriptor, status_code=status.HTTP_201_CREATED)
async def create_tensor_descriptor(
    descriptor_data: Dict[str, Any],
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key)
):
    """Create a TensorDescriptor, backfilling missing structural fields.

    If ``shape``/``data_type``/``byte_size``/``dimensionality`` are absent from
    the payload, they are looked up from the mock tensor connector; a missing
    ``dimensionality`` is derived from ``shape``. On success a placeholder
    tensor payload is stored in the connector and an audit event is logged.

    Raises:
        HTTPException 422: payload fails TensorDescriptor validation, or the
            supplied ``tensor_id`` is not a valid UUID string.
        HTTPException 400: storage rejects the descriptor (e.g. duplicate id).
    """
    tensor_id_str = descriptor_data.get("tensor_id")
    # BUG FIX: a malformed tensor_id previously raised an uncaught ValueError
    # from UUID(), surfacing as an HTTP 500. Reject it explicitly as 422.
    if tensor_id_str:
        try:
            temp_id_for_lookup = UUID(tensor_id_str)
        except (ValueError, AttributeError, TypeError):
            log_audit_event(action="CREATE_TENSOR_DESCRIPTOR_FAILED_VALIDATION", user=api_key,
                            details={"error": "invalid tensor_id format", "input_tensor_id": str(tensor_id_str)})
            raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                                detail=f"Invalid tensor_id: {tensor_id_str!r} is not a valid UUID.")
    else:
        temp_id_for_lookup = uuid.uuid4()

    fields_to_fetch = ["shape", "data_type", "byte_size", "dimensionality"]
    missing_fields = [field for field in fields_to_fetch if field not in descriptor_data or descriptor_data[field] is None]

    if missing_fields:
        # Best-effort backfill from the (mock) tensor storage layer.
        storage_details = mock_tensor_connector_instance.get_tensor_details(temp_id_for_lookup)
        if storage_details:
            for field in missing_fields:
                if field in storage_details and (field not in descriptor_data or descriptor_data[field] is None):
                    descriptor_data[field] = storage_details[field]
    # Derive dimensionality from shape when only shape was provided/backfilled.
    if "shape" in descriptor_data and descriptor_data["shape"] is not None \
            and ("dimensionality" not in descriptor_data or descriptor_data["dimensionality"] is None):
        descriptor_data["dimensionality"] = len(descriptor_data["shape"])
    try:
        final_descriptor = TensorDescriptor(**descriptor_data)
        storage.add_tensor_descriptor(final_descriptor)
        # Ensure the connector has *some* payload for this descriptor so later
        # retrievals don't come back empty.
        if mock_tensor_connector_instance.retrieve_tensor(final_descriptor.tensor_id) is None:
            mock_tensor_data_payload = {
                "shape": final_descriptor.shape, "data_type": final_descriptor.data_type.value,
                "byte_size": final_descriptor.byte_size, "info": "Placeholder data by create_tensor_descriptor"
            }
            mock_tensor_connector_instance.store_tensor(final_descriptor.tensor_id, mock_tensor_data_payload)
        log_audit_event(action="CREATE_TENSOR_DESCRIPTOR", user=api_key, tensor_id=str(final_descriptor.tensor_id),
                        details={"owner": final_descriptor.owner, "data_type": final_descriptor.data_type.value})
        return final_descriptor
    except ValidationError as e:
        log_audit_event(action="CREATE_TENSOR_DESCRIPTOR_FAILED_VALIDATION", user=api_key, details={"error": str(e), "input_tensor_id": tensor_id_str})
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=e.errors())
    except ValueError as e:
        log_audit_event(action="CREATE_TENSOR_DESCRIPTOR_FAILED_STORAGE", user=api_key, details={"error": str(e), "input_tensor_id": tensor_id_str})
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
| 83 |
+
|
| 84 |
+
@router_tensor_descriptor.get("/", response_model=List[TensorDescriptor], summary="List Tensor Descriptors with Advanced Filtering")
async def list_tensor_descriptors(
    owner: Optional[str] = Query(None), data_type: Optional[DataType] = Query(None),
    tags_contain: Optional[List[str]] = Query(None), lineage_version: Optional[str] = Query(None, alias="lineage.version"),
    lineage_source_type: Optional[LineageSourceType] = Query(None, alias="lineage.source.type"),
    comp_algorithm: Optional[str] = Query(None, alias="computational.algorithm"),
    comp_gpu_model: Optional[str] = Query(None, alias="computational.hardware_info.gpu_model"),
    quality_confidence_gt: Optional[float] = Query(None, alias="quality.confidence_score_gt"),
    quality_noise_lt: Optional[float] = Query(None, alias="quality.noise_level_lt"),
    rel_collection: Optional[str] = Query(None, alias="relational.collection"),
    rel_has_related_tensor_id: Optional[UUID] = Query(None, alias="relational.has_related_tensor_id"),
    usage_last_accessed_before: Optional[datetime] = Query(None, alias="usage.last_accessed_before"),
    usage_used_by_app: Optional[str] = Query(None, alias="usage.used_by_app"),
    storage: MetadataStorage = Depends(get_storage_instance)
):
    """List TensorDescriptors, applying extended-metadata filters in memory.

    Common filters (owner, data_type, tags, lineage version) are pushed down to
    the storage backend; the remaining extended-metadata filters are evaluated
    here per descriptor.
    """
    all_descriptors = storage.list_tensor_descriptors(
        owner=owner, data_type=data_type, tags_contain=tags_contain, lineage_version=lineage_version
    )  # Pass some common filters down to storage.
    filtered_descriptors = []
    for desc in all_descriptors:  # Apply remaining filters in memory.
        if lineage_source_type is not None and (not (lm := storage.get_lineage_metadata(desc.tensor_id)) or not lm.source or lm.source.type != lineage_source_type):
            continue
        if comp_algorithm is not None or comp_gpu_model is not None:
            cm = storage.get_computational_metadata(desc.tensor_id)
            if not cm: continue
            if comp_algorithm is not None and cm.algorithm != comp_algorithm: continue
            if comp_gpu_model is not None and (not cm.hardware_info or cm.hardware_info.get("gpu_model") != comp_gpu_model): continue
        # BUG FIX: these thresholds were tested for truthiness, so a legitimate
        # threshold of 0.0 silently disabled the filter. Compare against None.
        if quality_confidence_gt is not None or quality_noise_lt is not None:
            qm = storage.get_quality_metadata(desc.tensor_id)
            if not qm: continue
            if quality_confidence_gt is not None and (qm.confidence_score is None or qm.confidence_score <= quality_confidence_gt): continue
            if quality_noise_lt is not None and (qm.noise_level is None or qm.noise_level >= quality_noise_lt): continue
        if rel_collection is not None or rel_has_related_tensor_id is not None:
            rm = storage.get_relational_metadata(desc.tensor_id)
            if not rm: continue
            if rel_collection is not None and rel_collection not in rm.collections: continue
            if rel_has_related_tensor_id is not None and not any(rtl.related_tensor_id == rel_has_related_tensor_id for rtl in rm.related_tensors): continue
        if usage_last_accessed_before is not None or usage_used_by_app is not None:
            um = storage.get_usage_metadata(desc.tensor_id)
            if not um: continue
            if usage_last_accessed_before is not None and (not um.last_accessed_at or um.last_accessed_at >= usage_last_accessed_before): continue
            if usage_used_by_app is not None and usage_used_by_app not in um.application_references: continue
        filtered_descriptors.append(desc)
    return filtered_descriptors
|
| 127 |
+
|
| 128 |
+
@router_tensor_descriptor.get("/{tensor_id}", response_model=TensorDescriptor)
async def get_tensor_descriptor(tensor_id: UUID = Path(...), storage: MetadataStorage = Depends(get_storage_instance)):
    """Fetch a single TensorDescriptor by UUID; 404 when it does not exist."""
    found = storage.get_tensor_descriptor(tensor_id)
    if found is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"TensorDescriptor ID {tensor_id} not found.",
        )
    return found
|
| 133 |
+
|
| 134 |
+
class TensorDescriptorUpdate(BaseModel):
    # Partial-update payload for PUT /tensor_descriptors/{tensor_id};
    # every field is optional and only explicitly-set fields are applied.
    dimensionality: Optional[int] = None
    shape: Optional[List[int]] = None
    data_type: Optional[DataType] = None
    storage_format: Optional[str] = None
    owner: Optional[str] = None
    access_control: Optional[Dict[str, List[str]]] = None
    byte_size: Optional[int] = None
    compression_info: Optional[Dict[str, Any]] = None
    tags: Optional[List[str]] = None
    metadata: Optional[Dict[str, Any]] = None
|
| 139 |
+
|
| 140 |
+
@router_tensor_descriptor.put("/{tensor_id}", response_model=TensorDescriptor)
async def update_tensor_descriptor(
    tensor_id: UUID = Path(...), updates: TensorDescriptorUpdate = Body(...),
    storage: MetadataStorage = Depends(get_storage_instance), api_key: str = Depends(verify_api_key)
):
    """Apply a partial update to an existing TensorDescriptor (audited)."""
    changes = updates.model_dump(exclude_unset=True)
    if not changes:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No update data provided.")
    existing = storage.get_tensor_descriptor(tensor_id)
    if not existing:
        log_audit_event("UPDATE_TENSOR_DESCRIPTOR_FAILED_NOT_FOUND", api_key, str(tensor_id))
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"TensorDescriptor ID {tensor_id} not found.")
    try:
        result = storage.update_tensor_descriptor(tensor_id, **changes)
    except (ValidationError, ValueError) as e:
        log_audit_event("UPDATE_TENSOR_DESCRIPTOR_FAILED_VALIDATION", api_key, str(tensor_id), {"error": str(e)})
        # ValidationError -> 422; any other ValueError from storage -> 400.
        code = status.HTTP_422_UNPROCESSABLE_ENTITY if isinstance(e, ValidationError) else status.HTTP_400_BAD_REQUEST
        raise HTTPException(status_code=code, detail=str(e))
    log_audit_event("UPDATE_TENSOR_DESCRIPTOR", api_key, str(tensor_id), {"updated_fields": list(changes.keys())})
    return result
|
| 158 |
+
|
| 159 |
+
@router_tensor_descriptor.delete("/{tensor_id}", status_code=status.HTTP_200_OK)
async def delete_tensor_descriptor(
    tensor_id: UUID = Path(...), storage: MetadataStorage = Depends(get_storage_instance), api_key: str = Depends(verify_api_key)
):
    """Delete a TensorDescriptor plus its tensor payload in the connector."""
    descriptor = storage.get_tensor_descriptor(tensor_id)
    if descriptor is None:
        log_audit_event("DELETE_TENSOR_DESCRIPTOR_FAILED_NOT_FOUND", api_key, str(tensor_id))
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"TensorDescriptor ID {tensor_id} not found.")
    storage.delete_tensor_descriptor(tensor_id)
    # Also drop the (mock) tensor payload associated with this descriptor.
    mock_tensor_connector_instance.delete_tensor(tensor_id)
    log_audit_event("DELETE_TENSOR_DESCRIPTOR", api_key, str(tensor_id))
    return {"message": f"TensorDescriptor {tensor_id} and associated data deleted."}
|
| 170 |
+
|
| 171 |
+
# --- SemanticMetadata Endpoints ---
|
| 172 |
+
def _check_td_exists_for_semantic(tensor_id: UUID, storage: MetadataStorage, api_key: Optional[str] = None, action_prefix: str = ""):
    """Raise 404 (and audit, when caller identity is known) if the parent TensorDescriptor is missing."""
    if storage.get_tensor_descriptor(tensor_id) is not None:
        return
    # Only audit when both a caller identity and an action context were given.
    if api_key and action_prefix:
        log_audit_event(f"{action_prefix}_SEMANTIC_METADATA_FAILED_TD_NOT_FOUND", user=api_key, tensor_id=str(tensor_id))
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Parent TensorDescriptor ID {tensor_id} not found.")
|
| 177 |
+
|
| 178 |
+
@router_semantic_metadata.post("/", response_model=SemanticMetadata, status_code=status.HTTP_201_CREATED)
async def create_semantic_metadata_for_tensor(
    tensor_id: UUID = Path(...), metadata_in: SemanticMetadata = Body(...),
    storage: MetadataStorage = Depends(get_storage_instance), api_key: str = Depends(verify_api_key)
):
    """Attach a named SemanticMetadata entry to an existing TensorDescriptor.

    Raises:
        HTTPException 404: parent TensorDescriptor does not exist.
        HTTPException 400: path/body tensor_id mismatch, or storage ValueError.
        HTTPException 422: payload fails validation.
    """
    _check_td_exists_for_semantic(tensor_id, storage, api_key, "CREATE")
    if metadata_in.tensor_id != tensor_id:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="tensor_id in path and body must match.")
    try:
        storage.add_semantic_metadata(metadata_in)
        log_audit_event("CREATE_SEMANTIC_METADATA", api_key, str(tensor_id), {"name": metadata_in.name})
        return metadata_in
    except (ValidationError, ValueError) as e:
        log_audit_event("CREATE_SEMANTIC_METADATA_FAILED", api_key, str(tensor_id), {"name": metadata_in.name, "error": str(e)})
        # BUG FIX: pydantic ValidationError subclasses ValueError, so checking
        # `isinstance(e, ValueError)` first made the 422 branch unreachable.
        # Check the more specific ValidationError first.
        status_code = status.HTTP_422_UNPROCESSABLE_ENTITY if isinstance(e, ValidationError) else status.HTTP_400_BAD_REQUEST
        raise HTTPException(status_code=status_code, detail=str(e))
|
| 193 |
+
|
| 194 |
+
@router_semantic_metadata.get("/", response_model=List[SemanticMetadata])
async def get_all_semantic_metadata_for_tensor(tensor_id: UUID = Path(...), storage: MetadataStorage = Depends(get_storage_instance)):
    """Return every SemanticMetadata entry for the given tensor (404 if the tensor is unknown)."""
    _check_td_exists_for_semantic(tensor_id, storage)
    entries = storage.get_semantic_metadata(tensor_id)
    return entries
|
| 198 |
+
|
| 199 |
+
class SemanticMetadataUpdate(BaseModel):
    # Partial-update payload for a named SemanticMetadata entry.
    name: Optional[str] = None
    description: Optional[str] = None
|
| 200 |
+
|
| 201 |
+
@router_semantic_metadata.put("/{current_name}", response_model=SemanticMetadata)
async def update_named_semantic_metadata_for_tensor(
    tensor_id: UUID = Path(...), current_name: str = Path(...), updates: SemanticMetadataUpdate = Body(...),
    storage: MetadataStorage = Depends(get_storage_instance), api_key: str = Depends(verify_api_key)
):
    """Rename and/or re-describe a named SemanticMetadata entry (audited).

    Raises:
        HTTPException 404: parent tensor or named entry missing.
        HTTPException 400: empty update payload, or storage ValueError.
        HTTPException 422: payload fails validation.
    """
    _check_td_exists_for_semantic(tensor_id, storage, api_key, "UPDATE")
    if not updates.model_dump(exclude_unset=True):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No update data provided.")
    if not storage.get_semantic_metadata_by_name(tensor_id, current_name):
        log_audit_event("UPDATE_SEMANTIC_METADATA_FAILED_NOT_FOUND", api_key, str(tensor_id), {"name": current_name})
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"SemanticMetadata '{current_name}' not found for tensor {tensor_id}.")
    try:
        updated = storage.update_semantic_metadata(tensor_id, current_name, new_description=updates.description, new_name=updates.name)
        log_audit_event(
            "UPDATE_SEMANTIC_METADATA",
            api_key,
            str(tensor_id),
            {"original_name": current_name, "updated_fields": updates.model_dump(exclude_unset=True)},
        )
        return updated
    except (ValidationError, ValueError) as e:
        log_audit_event("UPDATE_SEMANTIC_METADATA_FAILED", api_key, str(tensor_id), {"name": current_name, "error": str(e)})
        # BUG FIX: ValidationError subclasses ValueError; test the more
        # specific type first so validation failures map to 422, not 400.
        status_code = status.HTTP_422_UNPROCESSABLE_ENTITY if isinstance(e, ValidationError) else status.HTTP_400_BAD_REQUEST
        raise HTTPException(status_code=status_code, detail=str(e))
|
| 224 |
+
|
| 225 |
+
@router_semantic_metadata.delete("/{name}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_named_semantic_metadata_for_tensor(
    tensor_id: UUID = Path(...), name: str = Path(...),
    storage: MetadataStorage = Depends(get_storage_instance), api_key: str = Depends(verify_api_key)
):
    """Remove one named SemanticMetadata entry; 404 if tensor or entry is missing."""
    _check_td_exists_for_semantic(tensor_id, storage, api_key, "DELETE")
    entry = storage.get_semantic_metadata_by_name(tensor_id, name)
    if entry is None or entry == [] or not entry:
        log_audit_event("DELETE_SEMANTIC_METADATA_FAILED_NOT_FOUND", api_key, str(tensor_id), {"name": name})
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"SemanticMetadata '{name}' not found for tensor {tensor_id}.")
    storage.delete_semantic_metadata(tensor_id, name)
    log_audit_event("DELETE_SEMANTIC_METADATA", api_key, str(tensor_id), {"name": name})
    # 204 No Content: body intentionally empty.
    return None
|
| 237 |
+
|
| 238 |
+
# --- Search and Aggregation Routers (GET - No Auth/Audit) ---
|
| 239 |
+
router_search_aggregate = APIRouter(tags=["Search & Aggregate"])


@router_search_aggregate.get("/search/tensors/", response_model=List[TensorDescriptor])
async def search_tensors(
    text_query: str = Query(..., min_length=1), fields_to_search: Optional[List[str]] = Query(None),
    storage: MetadataStorage = Depends(get_storage_instance)
):
    """Free-text search over tensor descriptors and related metadata fields.

    When the caller does not restrict the searched fields, a broad default set
    spanning descriptor, semantic, lineage and computational fields is used.
    """
    default_fields = ["tensor_id", "owner", "tags", "metadata", "semantic.name", "semantic.description", "lineage.source.identifier", "lineage.version", "computational.algorithm"]
    search_fields = fields_to_search or default_fields
    try:
        return storage.search_tensor_descriptors(text_query, search_fields)
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
| 248 |
+
|
| 249 |
+
@router_search_aggregate.get("/aggregate/tensors/", response_model=Dict[str, Any])
async def aggregate_tensors(
    group_by_field: str = Query(...), agg_function: str = Query(...), agg_field: Optional[str] = Query(None),
    storage: MetadataStorage = Depends(get_storage_instance)
):
    """Group tensor descriptors by a field and aggregate another field over each group."""
    # Numeric aggregations need a field to aggregate over; "count" does not.
    needs_field = agg_function in ["avg", "sum", "min", "max"]
    if needs_field and not agg_field:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"'{agg_function}' requires 'agg_field'.")
    try:
        return storage.aggregate_tensor_descriptors(group_by_field, agg_function, agg_field)
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except NotImplementedError as e:
        raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=str(e))
|
| 259 |
+
|
| 260 |
+
# --- Versioning and Lineage Router ---
|
| 261 |
+
router_version_lineage = APIRouter(tags=["Versioning & Lineage"])


class NewTensorVersionRequest(BaseModel):
    # Request body for creating a new version of an existing tensor.
    # Only new_version_string is required; all other fields override the
    # values copied from the parent descriptor when provided.
    new_version_string: str
    dimensionality: Optional[int] = None
    shape: Optional[List[int]] = None
    data_type: Optional[DataType] = None
    storage_format: Optional[str] = None
    owner: Optional[str] = None
    access_control: Optional[Dict[str, List[str]]] = None
    byte_size: Optional[int] = None
    checksum: Optional[str] = None
    compression_info: Optional[Dict[str, Any]] = None
    tags: Optional[List[str]] = None
    metadata: Optional[Dict[str, Any]] = None
    # Optional lineage source recorded on the new version.
    lineage_source_identifier: Optional[str] = None
    lineage_source_type: Optional[LineageSourceType] = None
|
| 269 |
+
@router_version_lineage.post("/tensors/{tensor_id}/versions", response_model=TensorDescriptor, status_code=status.HTTP_201_CREATED)
async def create_tensor_version(
    tensor_id: UUID = Path(...), version_request: NewTensorVersionRequest = Body(...),
    storage: MetadataStorage = Depends(get_storage_instance), api_key: str = Depends(verify_api_key)
):
    """Create a new version of an existing tensor.

    The new descriptor copies the parent's fields, applies any overrides from
    the request, gets a fresh UUID/timestamps, and is linked to the parent via
    a ``new_version_of`` lineage relationship.

    Raises:
        HTTPException 404: parent TensorDescriptor does not exist.
        HTTPException 422/400: the assembled descriptor fails validation /
            storage rejects it.
    """
    # BUG FIX: LineageSource was referenced below but never imported at module
    # level (only LineageSourceType is), producing a NameError at runtime
    # whenever a lineage source was supplied. Import it locally.
    from tensorus.metadata.schemas import LineageSource

    parent_td = storage.get_tensor_descriptor(tensor_id)
    if not parent_td:
        log_audit_event("CREATE_TENSOR_VERSION_FAILED_PARENT_NOT_FOUND", api_key, str(tensor_id), {"new_version": version_request.new_version_string})
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Parent TensorDescriptor ID {tensor_id} not found.")
    new_version_id = uuid.uuid4()
    # Start from the parent's fields, minus identity/timestamps.
    new_td_data = parent_td.model_dump(
        exclude={"tensor_id", "creation_timestamp", "last_modified_timestamp"}
    )
    # Apply request overrides; unknown fields land in the free-form metadata dict.
    for field, value in version_request.model_dump(exclude_unset=True).items():
        if field in TensorDescriptor.model_fields and value is not None:
            new_td_data[field] = value
        elif field not in ['new_version_string', 'lineage_source_identifier', 'lineage_source_type']:
            if new_td_data.get("metadata") is None: new_td_data["metadata"] = {}
            new_td_data["metadata"][field] = value
    new_td_data.update({
        "tensor_id": new_version_id, "creation_timestamp": datetime.utcnow(),
        "last_modified_timestamp": datetime.utcnow(),
        "owner": new_td_data.get('owner', parent_td.owner),
        "byte_size": new_td_data.get('byte_size', parent_td.byte_size)
    })
    try:
        new_td = TensorDescriptor(**new_td_data)
        storage.add_tensor_descriptor(new_td)
    except (ValidationError, ValueError) as e:
        log_audit_event("CREATE_TENSOR_VERSION_FAILED_VALIDATION", api_key, str(new_version_id), {"parent_id": str(tensor_id), "error": str(e)})
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY if isinstance(e, ValidationError) else status.HTTP_400_BAD_REQUEST, detail=str(e))

    # Record lineage linking the new version back to its parent.
    lineage_details = {"tensor_id": new_version_id, "parent_tensors": [ParentTensorLink(tensor_id=tensor_id, relationship="new_version_of")], "version": version_request.new_version_string}
    if version_request.lineage_source_identifier and version_request.lineage_source_type:
        lineage_details["source"] = LineageSource(type=version_request.lineage_source_type, identifier=version_request.lineage_source_identifier)
    storage.add_lineage_metadata(LineageMetadata(**lineage_details))
    log_audit_event("CREATE_TENSOR_VERSION", api_key, str(new_version_id), {"parent_id": str(tensor_id), "version": version_request.new_version_string})
    return new_td
|
| 306 |
+
|
| 307 |
+
@router_version_lineage.get("/tensors/{tensor_id}/versions", response_model=List[TensorDescriptor])
async def list_tensor_versions(tensor_id: UUID = Path(...), storage: MetadataStorage = Depends(get_storage_instance)):
    """Return the given tensor followed by its direct 'new_version_of' children."""
    base_td = storage.get_tensor_descriptor(tensor_id)
    if base_td is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"TensorDescriptor ID {tensor_id} not found.")
    versions = [base_td]
    for candidate_id in storage.get_child_tensor_ids(tensor_id):
        lineage = storage.get_lineage_metadata(candidate_id)
        if not lineage:
            continue
        # Only children explicitly linked as new versions of this tensor count.
        is_version = any(
            link.tensor_id == tensor_id and link.relationship == "new_version_of"
            for link in lineage.parent_tensors
        )
        if is_version:
            candidate_td = storage.get_tensor_descriptor(candidate_id)
            if candidate_td:
                versions.append(candidate_td)
    return versions
|
| 320 |
+
|
| 321 |
+
class LineageRelationshipRequest(BaseModel):
    # Request body declaring a lineage edge: source -> target with a named relationship.
    source_tensor_id: UUID
    target_tensor_id: UUID
    relationship_type: str
    details: Optional[Dict[str, Any]] = None
|
| 322 |
+
|
| 323 |
+
@router_version_lineage.post("/lineage/relationships/", status_code=status.HTTP_201_CREATED)
async def create_lineage_relationship(
    req: LineageRelationshipRequest, storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key)
):
    """Record a parent link on the target tensor's lineage metadata (idempotent)."""
    audit_details = req.model_dump()
    # Both endpoints of the relationship must already exist.
    if not storage.get_tensor_descriptor(req.source_tensor_id):
        log_audit_event("CREATE_LINEAGE_REL_FAILED_SRC_NOT_FOUND", api_key, details=audit_details)
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Source TD {req.source_tensor_id} not found.")
    if not storage.get_tensor_descriptor(req.target_tensor_id):
        log_audit_event("CREATE_LINEAGE_REL_FAILED_TGT_NOT_FOUND", api_key, details=audit_details)
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Target TD {req.target_tensor_id} not found.")

    existing_lineage = storage.get_lineage_metadata(req.target_tensor_id)
    target_lineage = existing_lineage or LineageMetadata(tensor_id=req.target_tensor_id)
    duplicate = any(
        link.tensor_id == req.source_tensor_id and link.relationship == req.relationship_type
        for link in target_lineage.parent_tensors
    )
    if duplicate:
        # Idempotent: re-declaring an existing edge is not an error.
        return {"message": "Relationship already exists.", "lineage": target_lineage}

    target_lineage.parent_tensors.append(ParentTensorLink(tensor_id=req.source_tensor_id, relationship=req.relationship_type))
    try:
        storage.add_lineage_metadata(target_lineage)
    except (ValidationError, ValueError) as e:
        log_audit_event("CREATE_LINEAGE_REL_FAILED_VALIDATION", api_key, str(req.target_tensor_id), {**audit_details, "error": str(e)})
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    log_audit_event("CREATE_LINEAGE_RELATIONSHIP", api_key, str(req.target_tensor_id), details=audit_details)
    return {"message": "Lineage relationship created/updated.", "lineage": storage.get_lineage_metadata(req.target_tensor_id)}
|
| 348 |
+
|
| 349 |
+
@router_version_lineage.get("/tensors/{tensor_id}/lineage/parents", response_model=List[TensorDescriptor])
async def get_parent_tensors(tensor_id: UUID = Path(...), storage: MetadataStorage = Depends(get_storage_instance)):
    """Resolve the tensor's lineage parents to full TensorDescriptors."""
    if not storage.get_tensor_descriptor(tensor_id):
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"TD {tensor_id} not found.")
    parents = []
    for parent_id in storage.get_parent_tensor_ids(tensor_id):
        parent_td = storage.get_tensor_descriptor(parent_id)
        # Skip dangling lineage references whose descriptor no longer exists.
        if parent_td is not None:
            parents.append(parent_td)
    return parents
|
| 354 |
+
|
| 355 |
+
@router_version_lineage.get("/tensors/{tensor_id}/lineage/children", response_model=List[TensorDescriptor])
async def get_child_tensors(tensor_id: UUID = Path(...), storage: MetadataStorage = Depends(get_storage_instance)):
    """Resolve the tensor's lineage children to full TensorDescriptors."""
    if not storage.get_tensor_descriptor(tensor_id):
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"TD {tensor_id} not found.")
    children = []
    for child_id in storage.get_child_tensor_ids(tensor_id):
        child_td = storage.get_tensor_descriptor(child_id)
        # Skip dangling lineage references whose descriptor no longer exists.
        if child_td is not None:
            children.append(child_td)
    return children
|
| 360 |
+
|
| 361 |
+
# --- Router for Extended Metadata (CRUD per type) ---
# Mounted beneath a specific tensor; serves the extended metadata families
# (computational / quality / relational / usage) via the shared helpers below.
router_extended_metadata = APIRouter(prefix="/tensor_descriptors/{tensor_id}", tags=["Extended Metadata (Per Tensor)"])
|
| 363 |
+
|
| 364 |
+
def _get_td_or_404_for_extended_meta(tensor_id: UUID, storage: MetadataStorage, api_key: Optional[str]=None, action_prefix: str = ""):
    """Return the parent TensorDescriptor or raise 404 (auditing when context is supplied)."""
    descriptor = storage.get_tensor_descriptor(tensor_id)
    if descriptor is not None:
        return descriptor
    # Audit only when the caller provided both an identity and an action context.
    if api_key and action_prefix:
        log_audit_event(f"{action_prefix}_METADATA_FAILED_TD_NOT_FOUND", user=api_key, tensor_id=str(tensor_id))
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Parent TensorDescriptor ID {tensor_id} not found.")
|
| 370 |
+
|
| 371 |
+
async def _upsert_extended_metadata(tensor_id: UUID, metadata_name_cap: str, metadata_in: Any, storage: MetadataStorage, api_key: str):
    """Shared upsert for one extended-metadata family (e.g. 'Computational').

    Dispatches to ``storage.add_<family>_metadata`` by name, after verifying
    the parent tensor exists and the body's tensor_id matches the path.

    Raises:
        HTTPException 404: parent TensorDescriptor missing.
        HTTPException 400: tensor_id mismatch, or storage ValueError.
        HTTPException 422: payload fails validation.
    """
    _get_td_or_404_for_extended_meta(tensor_id, storage, api_key, f"UPSERT_{metadata_name_cap}")
    if metadata_in.tensor_id != tensor_id:
        log_audit_event(f"UPSERT_{metadata_name_cap}_FAILED_ID_MISMATCH", api_key, str(tensor_id), {"body_id": str(metadata_in.tensor_id)})
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Path tensor_id and body tensor_id must match.")
    try:
        add_method = getattr(storage, f"add_{metadata_name_cap.lower()}_metadata")
        add_method(metadata_in)
        log_audit_event(f"UPSERT_{metadata_name_cap}_METADATA", api_key, str(tensor_id))
        return metadata_in
    except (ValidationError, ValueError) as e:
        log_audit_event(f"UPSERT_{metadata_name_cap}_FAILED", api_key, str(tensor_id), {"error": str(e)})
        # BUG FIX: ValidationError subclasses ValueError, so the original
        # `isinstance(e, ValueError)`-first check always chose 400 and the
        # 422 branch was unreachable. Test ValidationError first.
        status_code = status.HTTP_422_UNPROCESSABLE_ENTITY if isinstance(e, ValidationError) else status.HTTP_400_BAD_REQUEST
        raise HTTPException(status_code=status_code, detail=str(e))
|
| 384 |
+
|
| 385 |
+
async def _get_extended_metadata_ep(tensor_id: UUID, metadata_name_cap: str, storage: MetadataStorage):
    """Return the requested extended-metadata record, or 404 when either the
    parent tensor or the record itself is missing."""
    _get_td_or_404_for_extended_meta(tensor_id, storage)
    getter = getattr(storage, f"get_{metadata_name_cap.lower()}_metadata")
    record = getter(tensor_id)
    if record:
        return record
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"{metadata_name_cap}Metadata not found for tensor {tensor_id}.",
    )
| 391 |
+
|
| 392 |
+
async def _patch_extended_metadata(tensor_id: UUID, metadata_name_cap: str, updates: Dict[str, Any], storage: MetadataStorage, api_key: str):
    """Apply a partial update to an existing extended-metadata record.

    404 when the parent tensor or the record is absent; 422 on validation
    failure; 400 on other value errors from the storage layer.
    """
    _get_td_or_404_for_extended_meta(tensor_id, storage, api_key, f"PATCH_{metadata_name_cap}")
    name_lower = metadata_name_cap.lower()
    if not getattr(storage, f"get_{name_lower}_metadata")(tensor_id):
        log_audit_event(f"PATCH_{metadata_name_cap}_FAILED_NOT_FOUND", api_key, str(tensor_id))
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"{metadata_name_cap}Metadata not found for update.",
        )
    try:
        patched = getattr(storage, f"update_{name_lower}_metadata")(tensor_id, **updates)
        log_audit_event(f"PATCH_{metadata_name_cap}_METADATA", api_key, str(tensor_id), {"updated_fields": list(updates.keys())})
        return patched
    except (ValidationError, ValueError) as e:
        log_audit_event(f"PATCH_{metadata_name_cap}_FAILED_VALIDATION", api_key, str(tensor_id), {"error": str(e)})
        code = status.HTTP_422_UNPROCESSABLE_ENTITY if isinstance(e, ValidationError) else status.HTTP_400_BAD_REQUEST
        raise HTTPException(status_code=code, detail=str(e))
| 406 |
+
|
| 407 |
+
async def _delete_extended_metadata(tensor_id: UUID, metadata_name_cap: str, storage: MetadataStorage, api_key: str):
    """Delete one extended-metadata record; raise 404 when nothing was deleted."""
    _get_td_or_404_for_extended_meta(tensor_id, storage, api_key, f"DELETE_{metadata_name_cap}")
    deleter = getattr(storage, f"delete_{metadata_name_cap.lower()}_metadata")
    if deleter(tensor_id):
        log_audit_event(f"DELETE_{metadata_name_cap}_METADATA", api_key, str(tensor_id))
        # 204 No Content: the route handler returns nothing on success.
        return None
    log_audit_event(f"DELETE_{metadata_name_cap}_FAILED_NOT_FOUND", api_key, str(tensor_id))
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"{metadata_name_cap}Metadata not found for delete.",
    )
|
| 415 |
+
|
| 416 |
+
# Explicit CRUD endpoints for each extended metadata type.
# Each type gets the same four routes (POST/GET/PATCH/DELETE) which delegate
# to the shared _upsert/_get/_patch/_delete helpers above.

# ---- Lineage ----
@router_extended_metadata.post("/lineage", response_model=LineageMetadata, status_code=status.HTTP_201_CREATED)
async def upsert_lineage_metadata_ep(
    tensor_id: UUID = Path(...),
    lineage_in: LineageMetadata = Body(...),
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key),
):
    """Create or replace the LineageMetadata attached to a tensor."""
    return await _upsert_extended_metadata(tensor_id, "Lineage", lineage_in, storage, api_key)


@router_extended_metadata.get("/lineage", response_model=LineageMetadata)
async def get_lineage_metadata_ep(
    tensor_id: UUID = Path(...),
    storage: MetadataStorage = Depends(get_storage_instance),
):
    """Fetch the LineageMetadata attached to a tensor."""
    return await _get_extended_metadata_ep(tensor_id, "Lineage", storage)


@router_extended_metadata.patch("/lineage", response_model=LineageMetadata)
async def patch_lineage_metadata_ep(
    tensor_id: UUID = Path(...),
    updates: Dict[str, Any] = Body(...),
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key),
):
    """Partially update the LineageMetadata attached to a tensor."""
    return await _patch_extended_metadata(tensor_id, "Lineage", updates, storage, api_key)


@router_extended_metadata.delete("/lineage", status_code=status.HTTP_204_NO_CONTENT)
async def delete_lineage_metadata_ep(
    tensor_id: UUID = Path(...),
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key),
):
    """Delete the LineageMetadata attached to a tensor."""
    return await _delete_extended_metadata(tensor_id, "Lineage", storage, api_key)


# ---- Computational ----
@router_extended_metadata.post("/computational", response_model=ComputationalMetadata, status_code=status.HTTP_201_CREATED)
async def upsert_computational_metadata_ep(
    tensor_id: UUID = Path(...),
    computational_in: ComputationalMetadata = Body(...),
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key),
):
    """Create or replace the ComputationalMetadata attached to a tensor."""
    return await _upsert_extended_metadata(tensor_id, "Computational", computational_in, storage, api_key)


@router_extended_metadata.get("/computational", response_model=ComputationalMetadata)
async def get_computational_metadata_ep(
    tensor_id: UUID = Path(...),
    storage: MetadataStorage = Depends(get_storage_instance),
):
    """Fetch the ComputationalMetadata attached to a tensor."""
    return await _get_extended_metadata_ep(tensor_id, "Computational", storage)


@router_extended_metadata.patch("/computational", response_model=ComputationalMetadata)
async def patch_computational_metadata_ep(
    tensor_id: UUID = Path(...),
    updates: Dict[str, Any] = Body(...),
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key),
):
    """Partially update the ComputationalMetadata attached to a tensor."""
    return await _patch_extended_metadata(tensor_id, "Computational", updates, storage, api_key)


@router_extended_metadata.delete("/computational", status_code=status.HTTP_204_NO_CONTENT)
async def delete_computational_metadata_ep(
    tensor_id: UUID = Path(...),
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key),
):
    """Delete the ComputationalMetadata attached to a tensor."""
    return await _delete_extended_metadata(tensor_id, "Computational", storage, api_key)


# ---- Quality ----
@router_extended_metadata.post("/quality", response_model=QualityMetadata, status_code=status.HTTP_201_CREATED)
async def upsert_quality_metadata_ep(
    tensor_id: UUID = Path(...),
    quality_in: QualityMetadata = Body(...),
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key),
):
    """Create or replace the QualityMetadata attached to a tensor."""
    return await _upsert_extended_metadata(tensor_id, "Quality", quality_in, storage, api_key)


@router_extended_metadata.get("/quality", response_model=QualityMetadata)
async def get_quality_metadata_ep(
    tensor_id: UUID = Path(...),
    storage: MetadataStorage = Depends(get_storage_instance),
):
    """Fetch the QualityMetadata attached to a tensor."""
    return await _get_extended_metadata_ep(tensor_id, "Quality", storage)


@router_extended_metadata.patch("/quality", response_model=QualityMetadata)
async def patch_quality_metadata_ep(
    tensor_id: UUID = Path(...),
    updates: Dict[str, Any] = Body(...),
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key),
):
    """Partially update the QualityMetadata attached to a tensor."""
    return await _patch_extended_metadata(tensor_id, "Quality", updates, storage, api_key)


@router_extended_metadata.delete("/quality", status_code=status.HTTP_204_NO_CONTENT)
async def delete_quality_metadata_ep(
    tensor_id: UUID = Path(...),
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key),
):
    """Delete the QualityMetadata attached to a tensor."""
    return await _delete_extended_metadata(tensor_id, "Quality", storage, api_key)


# ---- Relational ----
@router_extended_metadata.post("/relational", response_model=RelationalMetadata, status_code=status.HTTP_201_CREATED)
async def upsert_relational_metadata_ep(
    tensor_id: UUID = Path(...),
    relational_in: RelationalMetadata = Body(...),
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key),
):
    """Create or replace the RelationalMetadata attached to a tensor."""
    return await _upsert_extended_metadata(tensor_id, "Relational", relational_in, storage, api_key)


@router_extended_metadata.get("/relational", response_model=RelationalMetadata)
async def get_relational_metadata_ep(
    tensor_id: UUID = Path(...),
    storage: MetadataStorage = Depends(get_storage_instance),
):
    """Fetch the RelationalMetadata attached to a tensor."""
    return await _get_extended_metadata_ep(tensor_id, "Relational", storage)


@router_extended_metadata.patch("/relational", response_model=RelationalMetadata)
async def patch_relational_metadata_ep(
    tensor_id: UUID = Path(...),
    updates: Dict[str, Any] = Body(...),
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key),
):
    """Partially update the RelationalMetadata attached to a tensor."""
    return await _patch_extended_metadata(tensor_id, "Relational", updates, storage, api_key)


@router_extended_metadata.delete("/relational", status_code=status.HTTP_204_NO_CONTENT)
async def delete_relational_metadata_ep(
    tensor_id: UUID = Path(...),
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key),
):
    """Delete the RelationalMetadata attached to a tensor."""
    return await _delete_extended_metadata(tensor_id, "Relational", storage, api_key)


# ---- Usage ----
@router_extended_metadata.post("/usage", response_model=UsageMetadata, status_code=status.HTTP_201_CREATED)
async def upsert_usage_metadata_ep(
    tensor_id: UUID = Path(...),
    usage_in: UsageMetadata = Body(...),
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key),
):
    """Create or replace the UsageMetadata attached to a tensor."""
    return await _upsert_extended_metadata(tensor_id, "Usage", usage_in, storage, api_key)


@router_extended_metadata.get("/usage", response_model=UsageMetadata)
async def get_usage_metadata_ep(
    tensor_id: UUID = Path(...),
    storage: MetadataStorage = Depends(get_storage_instance),
):
    """Fetch the UsageMetadata attached to a tensor."""
    return await _get_extended_metadata_ep(tensor_id, "Usage", storage)


@router_extended_metadata.patch("/usage", response_model=UsageMetadata)
async def patch_usage_metadata_ep(
    tensor_id: UUID = Path(...),
    updates: Dict[str, Any] = Body(...),
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key),
):
    """Partially update the UsageMetadata attached to a tensor."""
    return await _patch_extended_metadata(tensor_id, "Usage", updates, storage, api_key)


@router_extended_metadata.delete("/usage", status_code=status.HTTP_204_NO_CONTENT)
async def delete_usage_metadata_ep(
    tensor_id: UUID = Path(...),
    storage: MetadataStorage = Depends(get_storage_instance),
    api_key: str = Depends(verify_api_key),
):
    """Delete the UsageMetadata attached to a tensor."""
    return await _delete_extended_metadata(tensor_id, "Usage", storage, api_key)
|
| 481 |
+
|
| 482 |
+
# --- I/O Router for Export/Import ---
router_io = APIRouter(prefix="/tensors", tags=["Import/Export"])


@router_io.get("/export", response_model=TensorusExportData)
async def export_tensor_metadata(
    tensor_ids_str: Optional[str] = Query(None, alias="tensor_ids"),  # comma-separated UUID list
    storage: MetadataStorage = Depends(get_storage_instance)
):
    """Export tensor metadata as a downloadable JSON attachment.

    Args:
        tensor_ids_str: Optional comma-separated UUIDs; when omitted, the
            storage backend decides what to export (all tensors).

    Raises:
        HTTPException: 400 when any supplied id is not a valid UUID.
    """
    parsed_tensor_ids: Optional[List[UUID]] = None
    if tensor_ids_str:
        try:
            parsed_tensor_ids = [UUID(tid.strip()) for tid in tensor_ids_str.split(',')]
        except ValueError:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid UUID format in tensor_ids.")
    export_data = storage.get_export_data(tensor_ids=parsed_tensor_ids)
    filename = f"tensorus_export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.json"
    # Bug fix: the timestamped filename was computed but never used — the header
    # previously carried a literal placeholder instead of {filename}.
    headers = {"Content-Disposition": f'attachment; filename="{filename}"'}
    return JSONResponse(content=export_data.model_dump(mode="json"), headers=headers)
|
| 498 |
+
|
| 499 |
+
@router_io.post("/import", summary="Import Tensor Metadata")
|
| 500 |
+
async def import_tensor_metadata(
|
| 501 |
+
import_data_payload: TensorusExportData,
|
| 502 |
+
conflict_strategy: Annotated[Literal["skip", "overwrite"], Query()] = "skip",
|
| 503 |
+
storage: MetadataStorage = Depends(get_storage_instance),
|
| 504 |
+
api_key: str = Depends(verify_api_key)
|
| 505 |
+
):
|
| 506 |
+
try:
|
| 507 |
+
result_summary = storage.import_data(import_data_payload, conflict_strategy=conflict_strategy)
|
| 508 |
+
log_audit_event("IMPORT_DATA", api_key, details={"strategy": conflict_strategy, "summary": result_summary})
|
| 509 |
+
return result_summary
|
| 510 |
+
except NotImplementedError:
|
| 511 |
+
log_audit_event("IMPORT_DATA_FAILED_NOT_IMPLEMENTED", api_key, details={"strategy": conflict_strategy})
|
| 512 |
+
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="Import functionality is not implemented")
|
| 513 |
+
except Exception as e:
|
| 514 |
+
log_audit_event("IMPORT_DATA_FAILED_UNEXPECTED", api_key, details={"strategy": conflict_strategy, "error": str(e)})
|
| 515 |
+
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Unexpected error during import: {e}")
|
| 516 |
+
|
| 517 |
+
# --- Management Router for Health and Metrics ---
router_management = APIRouter(tags=["Management"])


class HealthResponse(PydanticBaseModel):
    # Service status ("ok"/"error"), active storage backend, optional failure detail.
    status: str
    backend: str
    detail: Optional[str] = None


@router_management.get("/health", response_model=HealthResponse)
async def health_check(storage: MetadataStorage = Depends(get_storage_instance)):
    """Report whether the configured storage backend is reachable (503 when not)."""
    is_healthy, backend_type = storage.check_health()
    if not is_healthy:
        failure = HealthResponse(status="error", backend=backend_type, detail="Storage backend connection failed.")
        return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content=failure.model_dump())
    return HealthResponse(status="ok", backend=backend_type)


class MetricsResponse(PydanticBaseModel):
    # Record counts for descriptors plus each extended-metadata table.
    total_tensor_descriptors: int
    semantic_metadata_count: int
    lineage_metadata_count: int
    computational_metadata_count: int
    quality_metadata_count: int
    relational_metadata_count: int
    usage_metadata_count: int


@router_management.get("/metrics", response_model=MetricsResponse)
async def get_metrics(storage: MetadataStorage = Depends(get_storage_instance)):
    """Return simple count-based metrics from the storage backend."""
    count = storage.get_extended_metadata_count
    return MetricsResponse(
        total_tensor_descriptors=storage.get_tensor_descriptors_count(),
        semantic_metadata_count=count("SemanticMetadata"),
        lineage_metadata_count=count("LineageMetadata"),
        computational_metadata_count=count("ComputationalMetadata"),
        quality_metadata_count=count("QualityMetadata"),
        relational_metadata_count=count("RelationalMetadata"),
        usage_metadata_count=count("UsageMetadata"),
    )
|
| 540 |
+
|
| 541 |
+
# --- Analytics Router ---
router_analytics = APIRouter(
    prefix="/analytics",
    tags=["Analytics"]
)


@router_analytics.get("/co_occurring_tags", response_model=Dict[str, List[Dict[str, Any]]],
                      summary="Get Co-occurring Tags",
                      description="Finds tags that frequently co-occur with other tags on tensor descriptors.")
async def api_get_co_occurring_tags(
    min_co_occurrence: int = Query(2, ge=1, description="Minimum number of times tags must appear together."),
    limit: int = Query(10, ge=1, le=100, description="Maximum number of co-occurring tags to return for each primary tag."),
    storage: MetadataStorage = Depends(get_storage_instance)
):
    """Delegate tag co-occurrence analytics to the storage backend.

    501 when the backend does not support the query; 500 on any other failure.
    """
    try:
        result = storage.get_co_occurring_tags(min_co_occurrence=min_co_occurrence, limit=limit)
    except NotImplementedError:
        raise HTTPException(
            status_code=status.HTTP_501_NOT_IMPLEMENTED,
            detail="Co-occurring tags analytics not implemented for the current storage backend.",
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error calculating co-occurring tags: {e}",
        )
    return result
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
@router_analytics.get("/stale_tensors", response_model=List[TensorDescriptor],
|
| 564 |
+
summary="Get Stale Tensors",
|
| 565 |
+
description="Finds tensors that have not been accessed or modified for a given number of days.")
|
| 566 |
+
async def api_get_stale_tensors(
|
| 567 |
+
threshold_days: int = Query(90, ge=1, description="Number of days to consider a tensor stale."),
|
| 568 |
+
limit: int = Query(100, ge=1, le=1000, description="Maximum number of stale tensors to return."),
|
| 569 |
+
storage: MetadataStorage = Depends(get_storage_instance)
|
| 570 |
+
):
|
| 571 |
+
try:
|
| 572 |
+
return storage.get_stale_tensors(threshold_days=threshold_days, limit=limit)
|
| 573 |
+
except NotImplementedError:
|
| 574 |
+
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="Stale tensor analytics not implemented for the current storage backend.")
|
| 575 |
+
except Exception as e:
|
| 576 |
+
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error fetching stale tensors: {e}")
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
@router_analytics.get("/complex_tensors", response_model=List[TensorDescriptor],
|
| 580 |
+
summary="Get Complex Tensors",
|
| 581 |
+
description="Finds tensors considered complex based on lineage (number of parents or transformation steps).")
|
| 582 |
+
async def api_get_complex_tensors(
|
| 583 |
+
min_parent_count: Optional[int] = Query(None, ge=0, description="Minimum number of parent tensors."),
|
| 584 |
+
min_transformation_steps: Optional[int] = Query(None, ge=0, description="Minimum number of transformation steps."),
|
| 585 |
+
limit: int = Query(100, ge=1, le=1000, description="Maximum number of complex tensors to return."),
|
| 586 |
+
storage: MetadataStorage = Depends(get_storage_instance)
|
| 587 |
+
):
|
| 588 |
+
if min_parent_count is None and min_transformation_steps is None:
|
| 589 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="At least one criterion (min_parent_count or min_transformation_steps) must be provided.")
|
| 590 |
+
try:
|
| 591 |
+
return storage.get_complex_tensors(
|
| 592 |
+
min_parent_count=min_parent_count,
|
| 593 |
+
min_transformation_steps=min_transformation_steps,
|
| 594 |
+
limit=limit
|
| 595 |
+
)
|
| 596 |
+
except ValueError as e:
|
| 597 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
| 598 |
+
except NotImplementedError:
|
| 599 |
+
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="Complex tensor analytics not implemented for the current storage backend.")
|
| 600 |
+
except Exception as e:
|
| 601 |
+
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error fetching complex tensors: {e}")
|
tensorus/api/main.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import asynccontextmanager
|
| 2 |
+
import logging
|
| 3 |
+
from fastapi import FastAPI
|
| 4 |
+
|
| 5 |
+
from .endpoints import (
|
| 6 |
+
router_tensor_descriptor,
|
| 7 |
+
router_semantic_metadata,
|
| 8 |
+
router_search_aggregate,
|
| 9 |
+
router_version_lineage,
|
| 10 |
+
router_extended_metadata,
|
| 11 |
+
router_io,
|
| 12 |
+
router_management,
|
| 13 |
+
router_analytics # Import the new analytics router
|
| 14 |
+
)
|
| 15 |
+
# Import storage_instance and PostgresMetadataStorage for shutdown event
|
| 16 |
+
from tensorus.metadata import storage_instance
|
| 17 |
+
from tensorus.metadata.postgres_storage import PostgresMetadataStorage
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: code before ``yield`` runs at startup, code
    after it runs at shutdown (closing the Postgres pool when one exists)."""
    # Code to run on startup
    logging.info("Application startup: Lifespan event started.")
    yield
    # Code to run on shutdown
    logging.info("Application shutdown: Attempting to close database connection pool via lifespan.")
    if isinstance(storage_instance, PostgresMetadataStorage):
        logging.info("Closing PostgreSQL connection pool.")
        storage_instance.close_pool()  # Assuming close_pool() is synchronous as per existing code
        logging.info("PostgreSQL connection pool closed.")
    else:
        # Other backends (e.g. in-memory) have no pool to tear down.
        logging.info("No PostgreSQL pool instance found or it's not of expected type.")
    logging.info("Application shutdown: Lifespan event finished.")
|
| 34 |
+
|
| 35 |
+
# Application instance.  NOTE(review): contact/license URLs are placeholders —
# confirm real values before publishing the OpenAPI document.
app = FastAPI(
    title="Tensorus API",
    version="0.1.0",  # Consider updating version if features are added/changed significantly
    description="API for managing Tensor Descriptors and Semantic Metadata.",
    contact={
        "name": "Tensorus Development Team",
        "url": "http://example.com/contact",  # Replace with actual contact/repo URL
        "email": "dev@example.com",  # Replace with actual email
    },
    license_info={
        "name": "Apache 2.0",  # Or your chosen license
        "url": "https://www.apache.org/licenses/LICENSE-2.0.html",  # Replace with actual license URL
    },
    lifespan=lifespan  # startup/shutdown handling defined above
)
|
| 50 |
+
|
| 51 |
+
# Include the routers — wire every feature router into the application.
app.include_router(router_tensor_descriptor)
app.include_router(router_semantic_metadata)
app.include_router(router_search_aggregate)
app.include_router(router_version_lineage)
app.include_router(router_extended_metadata)
app.include_router(router_io)
app.include_router(router_management)
app.include_router(router_analytics)  # Register the analytics router
|
| 60 |
+
|
| 61 |
+
@app.get("/", tags=["Root"], summary="Root Endpoint", description="Returns a welcome message for the Tensorus API.")
|
| 62 |
+
async def read_root():
|
| 63 |
+
return {"message": "Welcome to the Tensorus API"}
|
| 64 |
+
|
| 65 |
+
# Old shutdown event handler removed. New handling is in lifespan context manager.
|
| 66 |
+
|
| 67 |
+
# To run this application (for development):
|
| 68 |
+
# uvicorn tensorus.api.main:app --reload --port 7860
|
| 69 |
+
#
|
| 70 |
+
# You would typically have a `__main__.py` or a run script for this.
|
| 71 |
+
# For now, this structure allows importing `app` elsewhere if needed.
|
| 72 |
+
|
| 73 |
+
# Example of how to clear storage for testing (not a production endpoint)
|
| 74 |
+
# from tensorus.metadata.storage import storage_instance
|
| 75 |
+
# @app.post("/debug/clear_storage", tags=["Debug"], include_in_schema=False)
|
| 76 |
+
# async def debug_clear_storage():
|
| 77 |
+
# storage_instance.clear_all_data()
|
| 78 |
+
# return {"message": "All in-memory data cleared."}
|
tensorus/api/security.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Security, HTTPException, status, Depends
|
| 2 |
+
from fastapi.security.api_key import APIKeyHeader
|
| 3 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 4 |
+
from typing import Optional, Dict, Any # Added Dict, Any for JWT payload
|
| 5 |
+
|
| 6 |
+
from tensorus.config import settings
|
| 7 |
+
from tensorus.audit import log_audit_event
|
| 8 |
+
from jose import jwt, JWTError
|
| 9 |
+
import requests
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MutableAPIKeyHeader(APIKeyHeader):
    """APIKeyHeader that allows updating the header name for testing.

    FastAPI's APIKeyHeader keeps its configured header name on ``self.model``;
    this subclass re-exposes it as a read/write property so tests can repoint
    the dependency at a different header without rebuilding it.
    """

    @property
    def name(self) -> str:  # type: ignore[override]
        # Delegate to the underlying security model's header name.
        return self.model.name

    @name.setter
    def name(self, value: str) -> None:  # type: ignore[override]
        self.model.name = value
|
| 22 |
+
|
| 23 |
+
# --- API Key Authentication ---
# Shared header-extraction dependency.  auto_error=False so verify_api_key can
# raise its own 401 with a specific "Missing API Key" message.
api_key_header_auth = MutableAPIKeyHeader(name=settings.API_KEY_HEADER_NAME, auto_error=False)
|
| 25 |
+
|
| 26 |
+
async def verify_api_key(api_key: Optional[str] = Security(api_key_header_auth)) -> str:
    """Verify the API key provided in the request header.

    Fails closed: when ``settings.VALID_API_KEYS`` is empty, every supplied key
    is rejected, so endpoints guarded by this dependency are inaccessible until
    keys are configured.

    Raises:
        HTTPException: 401 when the key is missing or not in the configured list.

    Returns:
        The validated API key string.
    """
    # (Cleanup: a no-op `if not settings.VALID_API_KEYS: pass` branch was removed;
    # the combined check below already covers the empty-configuration case.)
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API Key"
        )
    if not settings.VALID_API_KEYS or api_key not in settings.VALID_API_KEYS:
        # Empty list OR key not present — reject either way.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API Key"
        )
    return api_key
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# --- JWT Token Authentication (Conceptual) ---
# auto_error=False means it won't raise an error if the token is missing;
# verify_jwt_token below handles the missing-token case itself.
oauth2_scheme = HTTPBearer(auto_error=False)
|
| 52 |
+
|
| 53 |
+
async def verify_jwt_token(token: Optional[HTTPAuthorizationCredentials] = Security(oauth2_scheme)) -> Dict[str, Any]:
    """Validate a Bearer JWT and return its claims.

    Behavior:
    - JWT auth disabled: rejects with 503 — unless dev dummy mode is on and a
      token was supplied, in which case synthetic claims are returned.
    - Dev dummy mode (auth enabled): any bearer token value is accepted.
    - Otherwise: fetches the JWKS document, selects the signing key by ``kid``,
      and verifies signature, issuer and audience with python-jose.

    Raises:
        HTTPException: 503 when auth/JWKS is unavailable or misconfigured,
            401 for missing, malformed, or invalid tokens.
    """
    if not settings.AUTH_JWT_ENABLED:
        # Dev convenience: dummy mode overrides the disabled flag when a token is present.
        if settings.AUTH_DEV_MODE_ALLOW_DUMMY_JWT and token:
            return {"sub": "dummy_jwt_user_jwt_disabled_but_dev_mode", "username": "dummy_dev_jwt", "token_type": "dummy_bearer_dev"}

        # Standard behavior: an endpoint requiring JWT fails when JWT is not enabled.
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="JWT authentication is not enabled or configured for this service."
        )

    # JWT Auth is enabled, proceed.
    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated via JWT (No token provided)"
        )

    if settings.AUTH_DEV_MODE_ALLOW_DUMMY_JWT:
        # In dev dummy mode, allow any token value.
        return {"sub": "dummy_jwt_user", "username": "dummy_jwt_user", "token_type": "dummy_bearer", "token_value": token.credentials}

    if not settings.AUTH_JWT_JWKS_URI:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="JWKS URI not configured")

    try:
        # Fix: timeout added — an unbounded requests.get() can hang the worker
        # indefinitely on a stalled identity provider.
        jwks_data = requests.get(settings.AUTH_JWT_JWKS_URI, timeout=10).json()
    except Exception as e:  # pragma: no cover - network issues
        log_audit_event("JWT_VALIDATION_FAILED", details={"error": f"Failed fetching JWKS: {e}"})
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Unable to fetch JWKS")

    try:
        unverified_header = jwt.get_unverified_header(token.credentials)
    except JWTError as e:
        # Fix: a malformed token previously escaped this call as an unhandled 500.
        log_audit_event("JWT_VALIDATION_FAILED", details={"error": str(e)})
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token header")

    rsa_key = None
    for key in jwks_data.get("keys", []):
        if key.get("kid") == unverified_header.get("kid"):
            rsa_key = key
            break

    if rsa_key is None:
        log_audit_event("JWT_VALIDATION_FAILED", details={"error": "kid not found"})
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token header")

    try:
        claims = jwt.decode(
            token.credentials,
            rsa_key,
            algorithms=[settings.AUTH_JWT_ALGORITHM],
            issuer=settings.AUTH_JWT_ISSUER,
            audience=settings.AUTH_JWT_AUDIENCE,
        )
    except JWTError as e:
        log_audit_event("JWT_VALIDATION_FAILED", details={"error": str(e)})
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid JWT token")

    log_audit_event("JWT_VALIDATION_SUCCESS", user=claims.get("sub"), details={"issuer": claims.get("iss"), "aud": claims.get("aud")})
    return claims
|
| 120 |
+
|
| 121 |
+
# Example of how to use it in an endpoint:
|
| 122 |
+
# from fastapi import Depends
|
| 123 |
+
# from .security import verify_api_key
|
| 124 |
+
#
|
| 125 |
+
# @router.post("/some_protected_route", dependencies=[Depends(verify_api_key)])
|
| 126 |
+
# async def protected_route_function():
|
| 127 |
+
# # ...
|
| 128 |
+
#
|
| 129 |
+
# Or if you need the key value:
|
| 130 |
+
# @router.post("/another_route")
|
| 131 |
+
# async def another_route_function(api_key: str = Depends(verify_api_key)):
|
| 132 |
+
# # api_key variable now holds the validated key
|
| 133 |
+
# # ...
|
tensorus/audit.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Optional, Dict, Any
|
| 3 |
+
import sys
|
| 4 |
+
from tensorus.config import settings
|
| 5 |
+
|
| 6 |
+
# Configure basic logger
# In a real app, this might be more complex (e.g., JSON logging, log rotation, external service)
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
LOG_DEFAULT_HANDLERS = [logging.StreamHandler(sys.stdout)]  # Log to stdout by default

# Attempt to create a log file handler.
# This is a simple file logger; in production, consider more robust solutions.
try:
    file_handler = logging.FileHandler(settings.AUDIT_LOG_PATH)
    file_handler.setFormatter(logging.Formatter(LOG_FORMAT))
    LOG_DEFAULT_HANDLERS.append(file_handler)
except IOError:
    # Handle cases where file cannot be opened (e.g. permissions)
    print(
        f"Warning: Could not open {settings.AUDIT_LOG_PATH} for writing. Audit logs will go to stdout only.",
        file=sys.stderr,
    )


# NOTE(review): basicConfig runs at import time and configures the ROOT logger,
# so merely importing this module changes logging for the whole process —
# confirm that is intentional.
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, handlers=LOG_DEFAULT_HANDLERS)
audit_logger = logging.getLogger("tensorus.audit")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def log_audit_event(
    action: str,
    user: Optional[str] = "anonymous",
    tensor_id: Optional[str] = None,  # Common, optional parameter across audit calls
    details: Optional[Dict[str, Any]] = None
):
    """
    Logs an audit event as a single pipe-delimited line via ``audit_logger``.

    Args:
        action: A string describing the action performed (e.g., "CREATE_TENSOR_DESCRIPTOR").
        user: The user or API key performing the action. Defaults to "anonymous".
        tensor_id: The primary tensor_id involved in the action, if applicable.
        details: A dictionary of additional relevant information about the event.
    """
    # Assemble the message segments in a fixed order: Action, User, then optionals.
    segments = [
        f"Action: {action}",
        f"User: {user if user else 'unknown'}",
    ]

    if tensor_id:
        segments.append(f"TensorID: {tensor_id}")

    if details:
        # Render the details dict as "key1=value1, key2=value2".
        rendered = ", ".join(f"{k}={v}" for k, v in details.items())
        segments.append(f"Details: [{rendered}]")

    audit_logger.info(" | ".join(segments))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Example Usage (not part of the library code, just for demonstration):
if __name__ == "__main__":
    # This will only run if the script is executed directly.
    # Emits one structured audit event plus raw logger calls to exercise handlers.
    log_audit_event(action="TEST_EVENT", user="test_user", tensor_id="dummy-uuid-123", details={"param1": "value1", "status": "success"})
    audit_logger.info("This is a direct info log from audit_logger for testing handlers.")
    audit_logger.warning("This is a warning log for testing handlers.")
    print(f"Audit logger '{audit_logger.name}' has handlers: {audit_logger.handlers}")
    if not audit_logger.handlers:
        print("Warning: Audit logger has no handlers configured if run outside main app context without direct basicConfig call here.")
    # In the application, basicConfig in this file should set up the handlers globally once.
|
tensorus/automl_agent.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# automl_agent.py
|
| 2 |
+
"""
|
| 3 |
+
Implements the AutoML Agent for Tensorus.
|
| 4 |
+
|
| 5 |
+
This agent performs basic hyperparameter optimization using random search.
|
| 6 |
+
It trains a simple dummy model on synthetic data, evaluates its performance,
|
| 7 |
+
and logs the results, including storing trial results in TensorStorage.
|
| 8 |
+
|
| 9 |
+
Future Enhancements:
|
| 10 |
+
- Implement more advanced search strategies (Bayesian Optimization, Hyperband).
|
| 11 |
+
- Allow configuration of different model architectures.
|
| 12 |
+
- Integrate with real datasets from TensorStorage.
|
| 13 |
+
- Implement early stopping and other training optimizations.
|
| 14 |
+
- Store best model state_dict (requires serialization strategy).
|
| 15 |
+
- Parallelize trials for faster search.
|
| 16 |
+
- Use dedicated hyperparameter optimization libraries (Optuna, Ray Tune).
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.optim as optim
|
| 22 |
+
import numpy as np
|
| 23 |
+
import random
|
| 24 |
+
import logging
|
| 25 |
+
import time
|
| 26 |
+
from typing import Dict, Any, Callable, Tuple, Optional
|
| 27 |
+
|
| 28 |
+
from .tensor_storage import TensorStorage # Import our storage module
|
| 29 |
+
|
| 30 |
+
# Configure logging for this module.
# NOTE(review): basicConfig at import time configures the process-wide root
# logger; a library module usually leaves configuration to the application
# entry point — confirm this is intended.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# --- Dummy Model Definition ---
class DummyMLP(nn.Module):
    """A simple Multi-Layer Perceptron for regression/classification.

    Two fully connected layers with a configurable nonlinearity between them.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_size: int = 64, activation_fn: Callable = nn.ReLU):
        super().__init__()
        # Attribute names are kept stable so state_dict keys do not change.
        self.layer_1 = nn.Linear(input_dim, hidden_size)
        self.activation = activation_fn()
        self.layer_2 = nn.Linear(hidden_size, output_dim)

    def forward(self, x):
        # hidden projection -> nonlinearity -> output projection, as one expression
        return self.layer_2(self.activation(self.layer_1(x)))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# --- AutoML Agent Class ---
class AutoMLAgent:
    """Performs random search hyperparameter optimization.

    Each trial samples a configuration from ``search_space``, trains a
    ``DummyMLP`` on synthetic data, and records the resulting score (or NaN
    for failed trials) in ``TensorStorage``.
    """

    def __init__(self,
                 tensor_storage: TensorStorage,
                 search_space: Dict[str, Callable[[], Any]],
                 input_dim: int,
                 output_dim: int,
                 task_type: str = 'regression',  # 'regression' or 'classification'
                 results_dataset: str = "automl_results"):
        """
        Initializes the AutoML Agent.

        Args:
            tensor_storage: An instance of TensorStorage.
            search_space: Dictionary defining the hyperparameter search space.
                          Keys are param names (e.g., 'lr', 'hidden_size').
                          Values are zero-argument functions that sample a value
                          for that param (e.g., lambda: 10**random.uniform(-4,-2)).
            input_dim: Input dimension for the dummy model.
            output_dim: Output dimension for the dummy model.
            task_type: Type of task ('regression' or 'classification'), influences loss and data generation.
            results_dataset: Name of the dataset in TensorStorage to store trial results.

        Raises:
            TypeError: If tensor_storage is not a TensorStorage instance.
            ValueError: If task_type is not 'regression' or 'classification'.
        """
        if not isinstance(tensor_storage, TensorStorage):
            raise TypeError("tensor_storage must be an instance of TensorStorage")
        if task_type not in ['regression', 'classification']:
            raise ValueError("task_type must be 'regression' or 'classification'")

        self.tensor_storage = tensor_storage
        self.search_space = search_space
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.task_type = task_type
        self.results_dataset = results_dataset

        # Ensure results dataset exists (get_dataset raises ValueError when missing).
        try:
            self.tensor_storage.get_dataset(self.results_dataset)
        except ValueError:
            logger.info(f"Dataset '{self.results_dataset}' not found. Creating it.")
            self.tensor_storage.create_dataset(self.results_dataset)

        # Track best results found during the search.
        self.best_score: Optional[float] = None
        self.best_params: Optional[Dict[str, Any]] = None
        # Regression minimizes validation loss; classification maximizes accuracy.
        self.higher_score_is_better = False if task_type == 'regression' else True

        # Device configuration
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"AutoML Agent using device: {self.device}")
        logger.info(f"AutoML Agent initialized for {task_type} task. Results stored in '{results_dataset}'.")

    def _generate_synthetic_data(self, n_samples=500, batch_size=32) -> Tuple[Any, Any, nn.Module]:
        """Generates synthetic data loaders for training and validation.

        Returns:
            (train_loader, val_loader, loss_fn): DataLoaders over an 80/20
            split of a freshly generated dataset, plus the task-appropriate
            loss function (MSELoss for regression, CrossEntropyLoss otherwise).
        """
        X = torch.randn(n_samples, self.input_dim, device=self.device)

        if self.task_type == 'regression':
            # Simple linear relationship with noise
            true_weight = torch.randn(self.input_dim, self.output_dim, device=self.device) * 2
            true_bias = torch.randn(self.output_dim, device=self.device)
            y = X @ true_weight + true_bias + torch.randn(n_samples, self.output_dim, device=self.device) * 0.5
            loss_fn = nn.MSELoss()
        else:  # classification
            # Simple linear separation + softmax for multi-class
            if self.output_dim <= 1:
                raise ValueError("Output dimension must be > 1 for classification task example.")
            true_weight = torch.randn(self.input_dim, self.output_dim, device=self.device)
            logits = X @ true_weight
            y = torch.softmax(logits, dim=1).argmax(dim=1)  # Get class labels
            loss_fn = nn.CrossEntropyLoss()

        # Simple 80/20 train/validation split (no shuffling before the split).
        split_idx = int(n_samples * 0.8)
        X_train, X_val = X[:split_idx], X[split_idx:]
        y_train, y_val = y[:split_idx], y[split_idx:]

        train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
        val_dataset = torch.utils.data.TensorDataset(X_val, y_val)

        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size)

        return train_loader, val_loader, loss_fn

    def _build_dummy_model(self, params: Dict[str, Any]) -> nn.Module:
        """Builds the dummy MLP model based on hyperparameters.

        Unknown activation names silently fall back to ReLU.
        """
        hidden_size = params.get('hidden_size', 64)  # Default if not in params
        activation_name = params.get('activation', 'relu')  # Default activation

        act_fn_map = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid}
        activation_fn = act_fn_map.get(activation_name.lower(), nn.ReLU)  # Default to ReLU if unknown

        model = DummyMLP(
            input_dim=self.input_dim,
            output_dim=self.output_dim,
            hidden_size=hidden_size,
            activation_fn=activation_fn
        ).to(self.device)
        return model

    def _train_and_evaluate(self, params: Dict[str, Any], num_epochs: int = 5) -> Optional[float]:
        """Trains and evaluates a model with given hyperparameters.

        Returns:
            The validation score (loss for regression, accuracy for
            classification), or ``None`` if the trial failed (non-finite loss
            or an exception during training/evaluation).
        """
        logger.debug(f"Training trial with params: {params}")
        start_time = time.time()

        try:
            # 1. Build Model
            model = self._build_dummy_model(params)

            # 2. Get Data and Loss Function
            train_loader, val_loader, loss_fn = self._generate_synthetic_data()

            # 3. Setup Optimizer
            lr = params.get('lr', 1e-3)  # Default LR
            optimizer = optim.Adam(model.parameters(), lr=lr)

            # 4. Training Loop
            model.train()
            for epoch in range(num_epochs):
                epoch_loss = 0
                for batch_X, batch_y in train_loader:
                    optimizer.zero_grad()
                    outputs = model(batch_X)
                    loss = loss_fn(outputs, batch_y)

                    # Abort the trial early on NaN/inf loss (e.g. divergent LR).
                    if not torch.isfinite(loss):
                        logger.warning(f"Trial failed: Non-finite loss detected during training epoch {epoch}. Params: {params}")
                        return None  # Indicate failure

                    loss.backward()
                    optimizer.step()
                    epoch_loss += loss.item()
                # logger.debug(f"  Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss/len(train_loader):.4f}")

            # 5. Evaluation Loop
            model.eval()
            total_val_loss = 0
            total_correct = 0
            total_samples = 0
            with torch.no_grad():
                for batch_X, batch_y in val_loader:
                    outputs = model(batch_X)
                    loss = loss_fn(outputs, batch_y)
                    total_val_loss += loss.item()

                    if self.task_type == 'classification':
                        predicted = outputs.argmax(dim=1)
                        total_correct += (predicted == batch_y).sum().item()
                        total_samples += batch_y.size(0)

            avg_val_loss = total_val_loss / len(val_loader)
            duration = time.time() - start_time
            logger.debug(f"Trial completed in {duration:.2f}s. Val Loss: {avg_val_loss:.4f}")

            # 6. Return Score
            if self.task_type == 'regression':
                score = avg_val_loss  # Lower is better
            else:  # classification
                accuracy = total_correct / total_samples if total_samples > 0 else 0
                score = accuracy  # Higher is better
                logger.debug(f"  Trial Val Accuracy: {accuracy:.4f}")

            return score

        except Exception as e:
            logger.error(f"Trial failed with exception for params {params}: {e}", exc_info=True)
            return None  # Indicate failure

    def hyperparameter_search(self, trials: int, num_epochs_per_trial: int = 5) -> Optional[Dict[str, Any]]:
        """
        Performs random search for the specified number of trials.

        Each trial's score (NaN for failures) and sampled params are persisted
        to ``self.results_dataset``; ``best_score``/``best_params`` track the
        running optimum.

        Args:
            trials: The number of hyperparameter configurations to try.
            num_epochs_per_trial: Number of epochs to train each model configuration.

        Returns:
            The dictionary of hyperparameters that achieved the best score, or None if no trial succeeded.
        """
        logger.info(f"--- Starting Hyperparameter Search ({trials} trials) ---")
        self.best_score = None
        self.best_params = None

        for i in range(trials):
            # 1. Sample hyperparameters
            current_params = {name: sampler() for name, sampler in self.search_space.items()}
            logger.info(f"Trial {i+1}/{trials}: Testing params: {current_params}")

            # 2. Train and evaluate
            score = self._train_and_evaluate(current_params, num_epochs=num_epochs_per_trial)

            # 3. Store results in TensorStorage (even if trial failed, record params and score=None)
            score_tensor = torch.tensor(float('nan') if score is None else score)  # Store NaN for failed trials
            trial_metadata = {
                "trial_id": i + 1,
                "params": current_params,  # Store params dict directly in metadata
                "score": score,  # Store score also in metadata for easier querying
                "task_type": self.task_type,
                "search_timestamp_utc": time.time(),
                "created_by": "AutoMLAgent"  # Add agent source
            }
            try:
                record_id = self.tensor_storage.insert(
                    self.results_dataset,
                    score_tensor,
                    trial_metadata
                )
                logger.debug(f"Stored trial {i+1} results (Score: {score}) with record ID: {record_id}")
            except Exception as e:
                # Storage failure is logged but does not abort the search.
                logger.error(f"Failed to store trial {i+1} results in TensorStorage: {e}")

            # 4. Update best score if trial succeeded and is better
            if score is not None:
                is_better = False
                if self.best_score is None:
                    is_better = True
                elif self.higher_score_is_better and score > self.best_score:
                    is_better = True
                elif not self.higher_score_is_better and score < self.best_score:
                    is_better = True

                if is_better:
                    self.best_score = score
                    self.best_params = current_params
                    logger.info(f"*** New best score found! Trial {i+1}: Score={score:.4f}, Params={current_params} ***")

        logger.info(f"--- Hyperparameter Search Finished ---")
        if self.best_params:
            logger.info(f"Best score overall: {self.best_score:.4f}")
            logger.info(f"Best hyperparameters found: {self.best_params}")
            # Optional: Here you could trigger saving the best model's state_dict
            # e.g., self._save_best_model(self.best_params)
        else:
            logger.warning("No successful trials completed. Could not determine best parameters.")

        return self.best_params
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
# --- Example Usage ---
if __name__ == "__main__":
    # End-to-end demo: random-search a small MLP regression problem and
    # inspect the trial records written to TensorStorage.
    print("--- Starting AutoML Agent Example ---")

    # 1. Setup TensorStorage
    storage = TensorStorage()

    # 2. Define Search Space
    # Simple example for MLP regression
    search_space_reg = {
        'lr': lambda: 10**random.uniform(-5, -2),  # Log uniform for learning rate
        'hidden_size': lambda: random.choice([32, 64, 128, 256]),
        'activation': lambda: random.choice(['relu', 'tanh'])
    }

    # 3. Create the AutoML Agent
    input_dim = 10
    output_dim = 1  # Regression task
    automl_agent = AutoMLAgent(
        tensor_storage=storage,
        search_space=search_space_reg,
        input_dim=input_dim,
        output_dim=output_dim,
        task_type='regression',
        results_dataset="automl_regression_results"
    )

    # 4. Run the hyperparameter search
    num_trials = 20  # Number of random configurations to test
    num_epochs = 10  # Epochs per trial (keep low for speed)
    best_hyperparams = automl_agent.hyperparameter_search(trials=num_trials, num_epochs_per_trial=num_epochs)

    # 5. Optional: Check TensorStorage for results
    print("\n--- Checking TensorStorage contents (Sample) ---")
    try:
        results_count = len(storage.get_dataset(automl_agent.results_dataset))
        print(f"Found {results_count} trial records in '{automl_agent.results_dataset}'.")

        if results_count > 0:
            print("\nExample trial record (metadata):")
            # Sample one record to show structure
            sample_trial = storage.sample_dataset(automl_agent.results_dataset, 1)
            if sample_trial:
                print(f" Metadata: {sample_trial[0]['metadata']}")
                print(f" Score Tensor: {sample_trial[0]['tensor']}")  # Should contain the score or NaN

        # You could also query for the best score using NQLAgent if needed
        # (e.g., find record where score = best_score) - requires parsing results first.

    except ValueError as e:
        print(f"Could not retrieve dataset '{automl_agent.results_dataset}': {e}")
    except Exception as e:
        print(f"An error occurred checking storage: {e}")

    print("\n--- AutoML Agent Example Finished ---")
|
tensorus/config.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 2 |
+
from typing import Optional
|
| 3 |
+
from pydantic import field_validator
|
| 4 |
+
|
| 5 |
+
class Settings(BaseSettings):
    """Storage configuration model (pydantic-settings).

    NOTE(review): this class does not appear to be instantiated in this module —
    the module-level ``settings`` singleton below is a ``SettingsV1``. Confirm
    whether this class is still needed or can be removed.
    """
    # Configuration for environment variable prefix, .env file, etc.
    # For Pydantic-Settings v2, environment variables are loaded by default.
    # To specify a prefix for env vars (e.g. TENSORUS_STORAGE_BACKEND):
    # model_config = SettingsConfigDict(env_prefix='TENSORUS_') # Pydantic v2
    # For Pydantic v1, it was `Config.env_prefix`.

    STORAGE_BACKEND: str = "in_memory"

    POSTGRES_HOST: Optional[str] = None
    POSTGRES_PORT: Optional[int] = 5432  # Default PostgreSQL port
    POSTGRES_USER: Optional[str] = None
    POSTGRES_PASSWORD: Optional[str] = None
    POSTGRES_DB: Optional[str] = None
    POSTGRES_DSN: Optional[str] = None  # Alternative to individual params

    # Example of how to load from a .env file if needed (not strictly required by subtask)
    # model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', extra='ignore')
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Global instance of the settings
|
| 26 |
+
# The environment variables will be loaded when this instance is created.
|
| 27 |
+
# e.g. TENSORUS_STORAGE_BACKEND=postgres will override the default.
|
| 28 |
+
# Note: For pydantic-settings, env var names are case-insensitive by default for matching.
|
| 29 |
+
# If env_prefix is set, it would be TENSORUS_STORAGE_BACKEND. Without it, it's just STORAGE_BACKEND.
|
| 30 |
+
# Let's assume no prefix for now, so environment variables should be STORAGE_BACKEND, POSTGRES_HOST etc.
|
| 31 |
+
# OR, more commonly, one would use the env_prefix.
|
| 32 |
+
# For this exercise, I will assume the user will set environment variables like:
|
| 33 |
+
# export STORAGE_BACKEND="postgres"
|
| 34 |
+
# export POSTGRES_USER="myuser"
|
| 35 |
+
# ...etc.
|
| 36 |
+
# Or, if using an .env file:
|
| 37 |
+
# STORAGE_BACKEND="postgres"
|
| 38 |
+
# POSTGRES_USER="myuser"
|
| 39 |
+
# ...
|
| 40 |
+
#
|
| 41 |
+
# For Pydantic V1 BaseSettings, it would be:
|
| 42 |
+
# class Settings(BaseSettings):
|
| 43 |
+
# STORAGE_BACKEND: str = "in_memory"
|
| 44 |
+
# # ... other fields
|
| 45 |
+
# class Config:
|
| 46 |
+
# env_prefix = "TENSORUS_" # e.g. TENSORUS_STORAGE_BACKEND
|
| 47 |
+
# # case_sensitive = False # for Pydantic V1
|
| 48 |
+
#
|
| 49 |
+
# NOTE: although earlier notes targeted pydantic 1.10, the class below actually uses the pydantic-settings v2 API (SettingsConfigDict / field_validator); the "SettingsV1" name is historical.
|
| 50 |
+
|
| 51 |
+
class SettingsV1(BaseSettings):
    """Tensorus runtime settings, loaded from ``TENSORUS_``-prefixed env vars.

    Despite the "V1" in the name, this class uses the pydantic-settings v2 API
    (``SettingsConfigDict`` and ``field_validator``). It is the class
    instantiated as the module-level ``settings`` singleton below.
    """

    # Storage backend selection and PostgreSQL connection parameters.
    STORAGE_BACKEND: str = "in_memory"
    POSTGRES_HOST: Optional[str] = None
    POSTGRES_PORT: Optional[int] = 5432
    POSTGRES_USER: Optional[str] = None
    POSTGRES_PASSWORD: Optional[str] = None
    POSTGRES_DB: Optional[str] = None
    POSTGRES_DSN: Optional[str] = None

    # API Security
    VALID_API_KEYS: list[str] = []  # Comma-separated string in env, e.g., "key1,key2,key3"
    API_KEY_HEADER_NAME: str = "X-API-KEY"
    AUDIT_LOG_PATH: str = "tensorus_audit.log"




    model_config = SettingsConfigDict(env_prefix="TENSORUS_", case_sensitive=False)

    @field_validator("VALID_API_KEYS", mode="before")
    def split_valid_api_keys(cls, v):
        # Accept a plain comma-separated string in addition to a real list;
        # blank segments (e.g. trailing commas) are dropped.
        # NOTE(review): pydantic-settings v2 may JSON-decode complex (list)
        # env values *before* this validator runs, so a bare "k1,k2" value may
        # fail earlier — the module-level os.getenv fallback below suggests
        # that case was observed; confirm.
        if isinstance(v, str):
            return [key.strip() for key in v.split(',') if key.strip()]
        return v

    # JWT Authentication (Conceptual Settings)
    AUTH_JWT_ENABLED: bool = False
    AUTH_JWT_ISSUER: Optional[str] = None
    AUTH_JWT_AUDIENCE: Optional[str] = None
    AUTH_JWT_ALGORITHM: str = "RS256"
    AUTH_JWT_JWKS_URI: Optional[str] = None
    AUTH_DEV_MODE_ALLOW_DUMMY_JWT: bool = False
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# Instantiate the module-level settings singleton (environment is read here).
settings = SettingsV1()

# Manual parsing for VALID_API_KEYS if it's a comma-separated string from env.
# This is a workaround for the case where the env var is not a JSON list.
import os
raw_keys = os.getenv("TENSORUS_VALID_API_KEYS")
if raw_keys:
    # NOTE(review): unlike the field validator, this does not drop empty
    # entries produced by trailing commas — confirm the difference is intended.
    settings.VALID_API_KEYS = [key.strip() for key in raw_keys.split(',')]
elif not settings.VALID_API_KEYS:  # Ensure it's an empty list if env var is not set and default_factory wasn't used
    settings.VALID_API_KEYS = []
|
tensorus/dummy_env.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# dummy_env.py
|
| 2 |
+
"""
|
| 3 |
+
A simple dummy environment for testing the RL agent.
|
| 4 |
+
State: Position (1D)
|
| 5 |
+
Action: Move left (-1), Stay (0), Move right (+1) (Discrete actions)
|
| 6 |
+
Goal: Reach position 0
|
| 7 |
+
Reward: -abs(position), +10 if at goal
|
| 8 |
+
"""
|
| 9 |
+
from typing import Tuple, Dict
|
| 10 |
+
import torch
|
| 11 |
+
import random
|
| 12 |
+
import numpy as np # Use numpy for state representation convenience
|
| 13 |
+
|
| 14 |
+
class DummyEnv:
    """A 1-D toy environment for smoke-testing RL agents.

    The agent's position is the entire state and the goal is position 0.
    Actions: 0 = move left, 1 = stay, 2 = move right (each move is 0.5).
    Reward: a small penalty proportional to the distance from the goal,
    plus a +10 bonus when the goal is reached (within 0.1 tolerance).
    """

    def __init__(self, max_steps=50):
        self.state_dim = 1       # position is the only state variable
        self.action_dim = 3      # actions: 0 (left), 1 (stay), 2 (right)
        self.max_steps = max_steps
        self.current_pos = 0.0
        self.steps_taken = 0
        self.goal_pos = 0.0
        self.max_pos = 5.0       # position is clipped to [-max_pos, max_pos]

    def reset(self) -> torch.Tensor:
        """Reset to a uniformly random starting position; return the state tensor."""
        self.steps_taken = 0
        self.current_pos = random.uniform(-self.max_pos, self.max_pos)
        return torch.tensor([self.current_pos], dtype=torch.float32)

    def step(self, action: int) -> Tuple[torch.Tensor, float, bool, Dict]:
        """Apply ``action``, advance one step, and return (state, reward, done, info)."""
        if not isinstance(action, int) or action not in [0, 1, 2]:
            raise ValueError(f"Invalid action: {action}. Must be 0, 1, or 2.")

        # Translate the discrete action into a position delta and apply it.
        self.current_pos += {0: -0.5, 1: 0.0, 2: 0.5}[action]

        # Keep the position inside the environment boundaries.
        self.current_pos = np.clip(self.current_pos, -self.max_pos, self.max_pos)

        self.steps_taken += 1

        # Distance-based shaping penalty: small cost for being away from the goal.
        reward = -abs(self.current_pos - self.goal_pos) * 0.1
        done = False

        # Goal reached within tolerance: grant the bonus and terminate.
        if abs(self.current_pos - self.goal_pos) < 0.1:
            reward += 10.0
            done = True

        # Episode also terminates when the step budget is exhausted.
        if self.steps_taken >= self.max_steps:
            done = True
            # Optional: small penalty for running out of time
            # reward -= 1.0

        observation = torch.tensor([self.current_pos], dtype=torch.float32)
        return observation, float(reward), done, {}
|
| 69 |
+
|
| 70 |
+
# Example Usage
if __name__ == "__main__":
    # Smoke test: run one episode with a random policy and print each transition.
    env = DummyEnv()
    state = env.reset()
    print(f"Initial state: {state.item()}")
    done = False
    total_reward = 0
    steps = 0

    while not done:
        action = random.choice([0, 1, 2])  # Take random action
        next_state, reward, done, _ = env.step(action)
        print(f"Step {steps+1}: Action={action}, Next State={next_state.item():.2f}, Reward={reward:.2f}, Done={done}")
        state = next_state
        total_reward += reward
        steps += 1
        if steps > env.max_steps + 5:  # Safety break in case 'done' is never set
            print("Exceeded max steps significantly, breaking.")
            break


    print(f"\nEpisode finished after {steps} steps. Total reward: {total_reward:.2f}")
|
tensorus/financial_data_generator.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import datetime
|
| 4 |
+
|
| 5 |
+
def generate_financial_data(
    days: int = 500,
    initial_price: float = 100.0,
    trend_slope: float = 0.05,
    seasonality_amplitude: float = 10.0,
    seasonality_period: int = 90,
    noise_level: float = 2.0,
    base_volume: int = 100000,
    volume_volatility: float = 0.3,
    start_date_str: str = "2022-01-01",
    seed: Optional[int] = None,
) -> pd.DataFrame:
    """
    Generates synthetic time series data resembling daily stock prices.

    The close price is the sum of a linear trend, a sinusoidal seasonal
    component and Gaussian noise, floored at 1.0. Volume is the base volume
    perturbed by a uniform random factor and floored at 0.

    Args:
        days (int): Number of trading days to generate data for.
        initial_price (float): Starting price for the stock.
        trend_slope (float): Slope for the linear trend component (price change per day).
        seasonality_amplitude (float): Amplitude of the seasonal (sine wave) component.
        seasonality_period (int): Period in days for the seasonality.
        noise_level (float): Standard deviation of the random noise added to the price.
        base_volume (int): Base daily trading volume.
        volume_volatility (float): Percentage volatility for volume (e.g., 0.3 for +/-30%).
        start_date_str (str): Start date for the data in 'YYYY-MM-DD' format.
        seed (Optional[int]): Optional seed for reproducible output. The default
            of None preserves the previous non-deterministic behaviour.

    Returns:
        pd.DataFrame: A DataFrame with columns ['Date', 'Close', 'Volume'].
                      'Date' is datetime, 'Close' is float, 'Volume' is int.

    Raises:
        ValueError: If start_date_str is not in 'YYYY-MM-DD' format.
    """

    # 1. Parse the start date, giving a clear error for bad input.
    try:
        start_date = datetime.datetime.strptime(start_date_str, "%Y-%m-%d")
    except ValueError:
        raise ValueError("start_date_str must be in 'YYYY-MM-DD' format")

    # Consecutive calendar days for simplicity; use pd.bdate_range for
    # actual trading days.
    dates = pd.date_range(start_date, periods=days, freq='D')

    # Local Generator: enables seeding without mutating NumPy's global RNG state.
    rng = np.random.default_rng(seed)

    # Day index shared by the trend and the seasonality.
    time_component = np.arange(days)

    # 2. Linear trend component.
    trend = trend_slope * time_component

    # 3. Seasonal component (sine wave with the requested amplitude/period).
    seasonal = seasonality_amplitude * np.sin(2 * np.pi * time_component / seasonality_period)

    # 4. Gaussian noise.
    noise = rng.normal(loc=0, scale=noise_level, size=days)

    # 5. Close price, floored so it never drops below 1.0.
    close_prices = np.maximum(initial_price + trend + seasonal + noise, 1.0)

    # 6. Volume: base volume +/- up to volume_volatility percent, kept
    # non-negative and cast to int.
    volume_random_factor = rng.uniform(-volume_volatility, volume_volatility, size=days)
    volumes = np.maximum(base_volume * (1 + volume_random_factor), 0).astype(int)

    # 7. Assemble the result. OHLC columns (Open/High/Low) are deliberately
    # omitted for now; they could be derived from Close later if needed.
    return pd.DataFrame({
        'Date': dates,
        'Close': close_prices,
        'Volume': volumes,
    })
|
| 82 |
+
|
| 83 |
+
if __name__ == "__main__":
    # Demo 1: default parameters.
    print("Generating synthetic financial data with default parameters...")
    demo_df = generate_financial_data()

    # Show a quick preview of the generated series.
    print("\nDataFrame Head:")
    print(demo_df.head())
    print("\nDataFrame Tail:")
    print(demo_df.tail())
    print(f"\nGenerated {len(demo_df)} days of data.")

    # Persist the default dataset to CSV.
    default_csv = "synthetic_financial_data.csv"
    try:
        demo_df.to_csv(default_csv, index=False)
        print(f"\nSuccessfully saved data to {default_csv}")
    except Exception as exc:
        print(f"\nError saving data to CSV: {exc}")

    # Demo 2: custom parameters, including a downward trend.
    print("\nGenerating synthetic financial data with custom parameters...")
    alt_df = generate_financial_data(
        days=100,
        initial_price=50.0,
        trend_slope=-0.1,  # Downward trend
        seasonality_amplitude=5.0,
        seasonality_period=30,
        noise_level=1.0,
        base_volume=50000,
        start_date_str="2023-01-01",
    )
    print("\nCustom DataFrame Head:")
    print(alt_df.head())

    # Persist the custom dataset to CSV.
    alt_csv = "custom_synthetic_financial_data.csv"
    try:
        alt_df.to_csv(alt_csv, index=False)
        print(f"\nSuccessfully saved custom data to {alt_csv}")
    except Exception as exc:
        print(f"\nError saving custom data to CSV: {exc}")
|
tensorus/ingestion_agent.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ingestion_agent.py
|
| 2 |
+
"""
|
| 3 |
+
Implements the Autonomous Data Ingestion Agent for Tensorus.
|
| 4 |
+
|
| 5 |
+
This agent monitors a source directory for new data files (e.g., CSV, images),
|
| 6 |
+
preprocesses them into tensors using configurable functions, performs basic
|
| 7 |
+
validation, and inserts them into a specified dataset in TensorStorage.
|
| 8 |
+
|
| 9 |
+
Future Enhancements:
|
| 10 |
+
- Monitor cloud storage (S3, GCS) and APIs.
|
| 11 |
+
- More robust error handling for malformed files.
|
| 12 |
+
- More sophisticated duplicate detection (e.g., file hashing).
|
| 13 |
+
- Support for streaming data sources.
|
| 14 |
+
- Asynchronous processing for higher throughput.
|
| 15 |
+
- More complex and configurable preprocessing pipelines.
|
| 16 |
+
- Schema validation against predefined dataset schemas.
|
| 17 |
+
- Resource management controls.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
import time
|
| 22 |
+
import glob
|
| 23 |
+
import logging
|
| 24 |
+
import threading
|
| 25 |
+
import csv
|
| 26 |
+
from PIL import Image
|
| 27 |
+
import torch
|
| 28 |
+
import torchvision.transforms as T # Use torchvision for image transforms
|
| 29 |
+
import collections # Added import
|
| 30 |
+
|
| 31 |
+
from typing import Dict, Callable, Optional, Tuple, List, Any
|
| 32 |
+
from .tensor_storage import TensorStorage # Import our storage module
|
| 33 |
+
|
| 34 |
+
# Configure basic logging (can be customized further)
|
| 35 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# --- Custom Log Handler ---
|
| 40 |
+
class AgentMemoryLogHandler(logging.Handler):
    """Logging handler that keeps formatted records in an in-memory deque.

    The deque is supplied by the caller, so passing a bounded
    ``collections.deque(maxlen=...)`` yields a rolling window of the most
    recent log lines.
    """

    def __init__(self, deque: collections.deque):
        super().__init__()
        self.deque = deque

    def emit(self, record: logging.LogRecord) -> None:
        """Format *record* and append the resulting line to the deque."""
        formatted_line = self.format(record)
        self.deque.append(formatted_line)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# --- Default Preprocessing Functions ---
|
| 53 |
+
|
| 54 |
+
def preprocess_csv(file_path: str) -> Tuple[Optional[torch.Tensor], Dict[str, Any]]:
    """
    Basic CSV preprocessor for purely numeric data.

    Reads the file, treats the first row as a header, converts every fully
    numeric remaining row to floats, and stacks them into a single float32
    tensor for the whole file. Rows containing any non-numeric cell are
    skipped with a warning. Returns (tensor, metadata); tensor is None on
    failure or when no numeric rows were found.
    """
    metadata: Dict[str, Any] = {"source_file": file_path, "type": "csv"}
    numeric_rows: List[List[float]] = []
    try:
        with open(file_path, 'r', newline='') as csvfile:
            reader = csv.reader(csvfile)
            metadata["header"] = next(reader, None)  # first row assumed to be a header
            for raw_row in reader:
                try:
                    numeric_rows.append([float(cell) for cell in raw_row])
                except ValueError:
                    # A single non-numeric cell disqualifies the whole row.
                    logger.warning(f"Skipping non-numeric row in {file_path}: {raw_row}")

        if not numeric_rows:
            logger.warning(f"No numeric data found or processed in CSV file: {file_path}")
            return None, metadata

        tensor = torch.tensor(numeric_rows, dtype=torch.float32)
        logger.debug(f"Successfully processed {file_path} into tensor shape {tensor.shape}")
        return tensor, metadata
    except Exception as e:
        logger.error(f"Failed to process CSV file {file_path}: {e}")
        return None, metadata
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def preprocess_image(file_path: str) -> Tuple[Optional[torch.Tensor], Dict[str, Any]]:
    """
    Basic image preprocessor using Pillow and torchvision transforms.

    Loads the file, forces RGB, resizes to 128x128, converts to a (C, H, W)
    float tensor in [0, 1], and normalizes with ImageNet statistics.
    Returns (tensor, metadata); tensor is None when loading or transforming
    fails.
    """
    metadata = {"source_file": file_path, "type": "image"}
    try:
        image = Image.open(file_path).convert('RGB')  # guarantee 3 channels (RGB)

        # NOTE: the target size and normalization stats are fixed here;
        # ideally these would be configurable.
        pipeline = T.Compose([
            T.Resize((128, 128)),
            T.ToTensor(),  # PIL (H, W, C) [0, 255] -> Tensor (C, H, W) [0, 1]
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet stats
        ])

        image_tensor = pipeline(image)
        metadata["original_size"] = image.size  # (width, height) before resizing
        logger.debug(f"Successfully processed {file_path} into tensor shape {image_tensor.shape}")
        return image_tensor, metadata
    except FileNotFoundError:
        logger.error(f"Image file not found: {file_path}")
        return None, metadata
    except Exception as e:
        logger.error(f"Failed to process image file {file_path}: {e}")
        return None, metadata
|
| 116 |
+
|
| 117 |
+
# --- Data Ingestion Agent Class ---
|
| 118 |
+
|
| 119 |
+
class DataIngestionAgent:
    """
    Monitors a directory for new files and ingests them into TensorStorage.

    A background daemon thread polls `source_directory` every
    `polling_interval_sec` seconds; files whose (lowercased) extension has a
    registered preprocessor are converted to tensors and inserted into
    `dataset_name`. Processed files are tracked only in an in-memory set, so
    a restarted agent re-ingests files still present in the directory.
    """

    def __init__(self,
                 tensor_storage: TensorStorage,
                 dataset_name: str,
                 source_directory: str,
                 polling_interval_sec: int = 10,
                 preprocessing_rules: Optional[Dict[str, Callable[[str], Tuple[Optional[torch.Tensor], Dict[str, Any]]]]] = None):
        """
        Initializes the DataIngestionAgent.

        Args:
            tensor_storage: An instance of the TensorStorage class.
            dataset_name: The name of the dataset in TensorStorage to ingest into.
            source_directory: The path to the local directory to monitor.
            polling_interval_sec: How often (in seconds) to check the directory.
            preprocessing_rules: A dictionary mapping lowercase file extensions
                                 (e.g., '.csv', '.png') to preprocessing functions.
                                 Each function takes a file path and returns a
                                 Tuple containing the processed Tensor (or None on failure)
                                 and a metadata dictionary. If None, default rules are used.

        Raises:
            TypeError: If tensor_storage is not a TensorStorage instance.
            ValueError: If source_directory does not exist or is not a directory.
        """
        if not isinstance(tensor_storage, TensorStorage):
            raise TypeError("tensor_storage must be an instance of TensorStorage")
        if not os.path.isdir(source_directory):
            raise ValueError(f"Source directory '{source_directory}' does not exist or is not a directory.")

        self.tensor_storage = tensor_storage
        self.dataset_name = dataset_name
        self.source_directory = source_directory
        self.polling_interval = polling_interval_sec

        # Ensure dataset exists. get_dataset raising ValueError is treated as
        # "dataset missing" and triggers creation.
        try:
            self.tensor_storage.get_dataset(self.dataset_name)
            logger.info(f"Agent targeting existing dataset '{self.dataset_name}'.")
        except ValueError:
            logger.info(f"Dataset '{self.dataset_name}' not found. Creating it.")
            self.tensor_storage.create_dataset(self.dataset_name)

        # Default preprocessing rules if none provided
        if preprocessing_rules is None:
            self.preprocessing_rules = {
                '.csv': preprocess_csv,
                '.png': preprocess_image,
                '.jpg': preprocess_image,
                '.jpeg': preprocess_image,
                '.tif': preprocess_image,
                '.tiff': preprocess_image,
            }
            logger.info("Using default preprocessing rules for CSV and common image formats.")
        else:
            self.preprocessing_rules = preprocessing_rules
            logger.info(f"Using custom preprocessing rules for extensions: {list(preprocessing_rules.keys())}")

        self.processed_files = set()  # Keep track of files already processed in this session (in-memory only)
        self._stop_event = threading.Event()  # signals the monitor loop to exit
        self._monitor_thread = None  # background thread; None while stopped
        self.status = "stopped"  # Status reporting: "stopped" | "running" | "error"
        self.logs = collections.deque(maxlen=100)  # Log capturing (rolling window of last 100 lines)

        # Setup custom log handler
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        memory_handler = AgentMemoryLogHandler(self.logs)
        memory_handler.setFormatter(formatter)
        # NOTE(review): this attaches a new handler to the module-level logger on
        # every instantiation and never removes it — with multiple agents each
        # log line is duplicated into every agent's deque. Confirm intended.
        logger.addHandler(memory_handler)  # Add handler to the specific logger instance

        logger.info(f"DataIngestionAgent initialized for dataset '{self.dataset_name}' monitoring '{self.source_directory}'.")

    def get_status(self) -> str:
        """Returns the current status of the agent ("stopped", "running" or "error")."""
        return self.status

    def get_logs(self, max_lines: Optional[int] = None) -> List[str]:
        """Returns recent log messages from the agent.

        ``max_lines=None`` returns everything currently buffered (at most the
        deque's maxlen of 100); otherwise only the newest ``max_lines`` entries.
        """
        if max_lines is None:
            return list(self.logs)
        else:
            return list(self.logs)[-max_lines:]

    def _validate_data(self, tensor: Optional[torch.Tensor], metadata: Dict[str, Any]) -> bool:
        """
        Performs basic validation on the preprocessed tensor.
        Returns True if valid, False otherwise.
        """
        if tensor is None:
            logger.warning(f"Validation failed: Tensor is None for {metadata.get('source_file', 'N/A')}")
            return False
        if not isinstance(tensor, torch.Tensor):
            logger.warning(f"Validation failed: Output is not a tensor for {metadata.get('source_file', 'N/A')}")
            return False
        # Add more specific checks if needed (e.g., tensor.numel() > 0)
        return True

    def _process_file(self, file_path: str) -> None:
        """Processes a single detected file: preprocess, validate, insert.

        Files are added to ``processed_files`` only after successful insertion,
        so failed files are retried on the next scan.
        """
        logger.info(f"Detected new file: {file_path}")
        _, file_extension = os.path.splitext(file_path)
        file_extension = file_extension.lower()

        preprocessor = self.preprocessing_rules.get(file_extension)

        if preprocessor:
            logger.debug(f"Applying preprocessor for '{file_extension}' to {file_path}")
            try:
                tensor, metadata = preprocessor(file_path)

                if self._validate_data(tensor, metadata):
                    # Ensure tensor is not None before insertion
                    if tensor is not None:
                        metadata["created_by"] = "IngestionAgent"  # Add agent source
                        record_id = self.tensor_storage.insert(self.dataset_name, tensor, metadata)
                        logger.info(f"Successfully ingested '{file_path}' into dataset '{self.dataset_name}' with record ID: {record_id} (created_by: IngestionAgent)")
                        self.processed_files.add(file_path)  # Mark as processed only on success
                    else:
                        # Should have been caught by validation, but as safeguard:
                        logger.error(f"Validation passed but tensor is None for {file_path}. Skipping insertion.")

                else:
                    logger.warning(f"Data validation failed for {file_path}. Skipping insertion.")

            except Exception as e:
                logger.error(f"Unhandled error during preprocessing or insertion for {file_path}: {e}", exc_info=True)
        else:
            logger.debug(f"No preprocessor configured for file extension '{file_extension}'. Skipping file: {file_path}")


    def _scan_source_directory(self) -> None:
        """Scans the source directory (non-recursively) for new files matching the rules."""
        logger.debug(f"Scanning directory: {self.source_directory}")
        supported_extensions = self.preprocessing_rules.keys()

        try:
            # Use glob to find all files, then filter
            # This might be inefficient for huge directories, consider os.scandir
            all_files = glob.glob(os.path.join(self.source_directory, '*'), recursive=False)  # Non-recursive

            for file_path in all_files:
                if not os.path.isfile(file_path):
                    continue  # Skip directories

                _, file_extension = os.path.splitext(file_path)
                file_extension = file_extension.lower()

                if file_extension in supported_extensions and file_path not in self.processed_files:
                    self._process_file(file_path)

        except Exception as e:
            logger.error(f"Error scanning source directory '{self.source_directory}': {e}", exc_info=True)


    def _monitor_loop(self) -> None:
        """The main loop executed by the background thread."""
        logger.info(f"Starting monitoring loop for '{self.source_directory}'. Polling interval: {self.polling_interval} seconds.")
        while not self._stop_event.is_set():
            self._scan_source_directory()
            # Wait for the specified interval or until stop event is set
            # (Event.wait returns early when the event is set, so stop() is responsive).
            self._stop_event.wait(self.polling_interval)
        logger.info("Monitoring loop stopped.")


    def start(self) -> None:
        """Starts the monitoring process in a background (daemon) thread.

        Idempotent: a warning is logged and nothing happens if already running.
        """
        if self._monitor_thread is not None and self._monitor_thread.is_alive():
            logger.warning("Agent monitoring is already running.")
            return

        self._stop_event.clear()
        self._monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self._monitor_thread.start()
        if self._monitor_thread.is_alive():  # Check if thread actually started
            self.status = "running"
            logger.info("Data Ingestion Agent started monitoring.")
        else:
            self.status = "error"  # Or some other appropriate error state
            logger.error("Data Ingestion Agent failed to start monitoring thread.")


    def stop(self) -> None:
        """Signals the monitoring thread to stop and waits for it to finish.

        Blocks for at most polling_interval + 5 seconds; the status is only
        set back to "stopped" if the thread actually terminated in time.
        """
        if self._monitor_thread is None or not self._monitor_thread.is_alive():
            logger.info("Agent monitoring is not running.")
            return

        logger.info("Stopping Data Ingestion Agent monitoring...")
        self._stop_event.set()
        self._monitor_thread.join(timeout=self.polling_interval + 5)  # Wait for thread to finish

        if self._monitor_thread.is_alive():
            logger.warning("Monitoring thread did not stop gracefully after timeout.")
            # self.status remains "running" or could be set to "stopping_error"
        else:
            logger.info("Data Ingestion Agent monitoring stopped successfully.")
            self.status = "stopped"
        # _monitor_thread is cleared even if the join timed out; the stop event
        # stays set, so a lingering thread still exits after its current wait.
        self._monitor_thread = None
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
# --- Example Usage ---
if __name__ == "__main__":
    # Gate the demo behind an env var so importing/running this module stays
    # side-effect free by default.
    run_example = os.getenv("RUN_INGESTION_AGENT_EXAMPLE", "False").lower() == "true"

    if run_example:
        logger.info("--- Starting Ingestion Agent Example (RUN_INGESTION_AGENT_EXAMPLE=True) ---")

        # 1. Setup TensorStorage
        storage = TensorStorage()

        # 2. Setup a temporary directory for the agent to monitor
        source_dir = "temp_ingestion_source"
        if not os.path.exists(source_dir):
            os.makedirs(source_dir)
            logger.info(f"Created temporary source directory: {source_dir}")

        # 3. Create the Ingestion Agent
        # We'll use a short polling interval for demonstration
        agent = DataIngestionAgent(
            tensor_storage=storage,
            dataset_name="raw_data",
            source_directory=source_dir,
            polling_interval_sec=5
        )

        # 4. Start the agent (runs in the background)
        agent.start()

        # 5. Simulate adding files to the source directory
        print("\nSimulating file creation...")
        time.sleep(2)  # Give agent time to start initial scan

        # Create a dummy CSV file (includes one deliberately bad row to
        # exercise the preprocessor's non-numeric-row handling).
        csv_path = os.path.join(source_dir, "data_1.csv")
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(["Timestamp", "Value1", "Value2"])
            writer.writerow(["1678886400", "10.5", "20.1"])
            writer.writerow(["1678886460", "11.2", "20.5"])
            writer.writerow(["1678886520", "invalid", "20.9"])  # Test non-numeric row
            writer.writerow(["1678886580", "10.9", "21.0"])
        print(f"Created CSV: {csv_path}")

        # Create a dummy image file (requires Pillow)
        try:
            img_path = os.path.join(source_dir, "image_1.png")
            dummy_img = Image.new('RGB', (60, 30), color = 'red')
            dummy_img.save(img_path)
            print(f"Created Image: {img_path}")
        except ImportError:
            print("Pillow not installed, skipping image creation. Install with: pip install Pillow")
        except Exception as e:
            print(f"Could not create dummy image: {e}")


        # Create an unsupported file type (no preprocessor registered for .txt)
        txt_path = os.path.join(source_dir, "notes.txt")
        with open(txt_path, 'w') as f:
            f.write("This is a test file.")
        print(f"Created TXT: {txt_path} (should be skipped)")


        # 6. Let the agent run for a couple of polling cycles
        print(f"\nWaiting for agent to process files (polling interval {agent.polling_interval}s)...")
        time.sleep(agent.polling_interval * 2 + 1)  # Wait for 2 cycles + buffer


        # 7. Check the contents of TensorStorage
        print("\n--- Checking TensorStorage contents ---")
        try:
            ingested_data = storage.get_dataset_with_metadata("raw_data")
            print(f"Found {len(ingested_data)} items in dataset 'raw_data':")
            for item in ingested_data:
                print(f"  Record ID: {item['metadata'].get('record_id')}, Source: {item['metadata'].get('source_file')}, Shape: {item['tensor'].shape}, Dtype: {item['tensor'].dtype}")
                # print(f"  Tensor: {item['tensor']}")  # Can be verbose
        except ValueError as e:
            print(f"Could not retrieve dataset 'raw_data': {e}")


        # 8. Stop the agent
        print("\n--- Stopping Agent ---")
        agent.stop()

        # 9. Clean up the temporary directory (optional)
        # print(f"\nCleaning up temporary directory: {source_dir}")
        # import shutil
        # shutil.rmtree(source_dir)

        logger.info("--- Ingestion Agent Example Finished ---")
    else:
        logger.info("--- Ingestion Agent Example SKIPPED (RUN_INGESTION_AGENT_EXAMPLE not set to 'true') ---")
|
tensorus/mcp_client.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tensorus MCP client built on fastmcp.Client."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Optional, Sequence
|
| 8 |
+
|
| 9 |
+
from fastmcp.client import Client as FastMCPClient
|
| 10 |
+
# Prefer the real TextContent type from fastmcp; fall back to a minimal
# structural stand-in so this module stays importable when fastmcp.tools
# is missing or incompatible.
try:
    from fastmcp.tools import TextContent
except Exception:  # pragma: no cover - minimal fallback
    @dataclass
    class TextContent:  # type: ignore
        # Mirrors the two fields _call_json relies on: the content-type tag
        # and the raw text payload.
        type: str
        text: str
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TensorusMCPClient:
    """High level client for the Tensorus MCP server.

    Wraps a ``fastmcp`` ``Client`` and exposes the Tensorus MCP tools as
    async methods. Each call decodes the first ``TextContent`` item of the
    tool result via ``json.loads``; use the client as an async context
    manager so the underlying transport is opened and closed correctly.
    """

    def __init__(self, transport: Any) -> None:
        # `transport` is passed through to fastmcp's Client unchanged.
        self._client = FastMCPClient(transport)

    async def __aenter__(self) -> "TensorusMCPClient":
        """Enter the underlying fastmcp client's context and return self."""
        await self._client.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Delegate context exit (connection teardown) to the underlying client."""
        await self._client.__aexit__(exc_type, exc, tb)

    async def _call_json(self, name: str, arguments: Optional[dict] = None) -> Any:
        """Call MCP tool ``name`` and decode its first text content as JSON.

        Returns None when the tool produced no content.

        Raises:
            TypeError: If the first content item is not ``TextContent``.
        """
        result = await self._client.call_tool(name, arguments or {})
        if not result:
            return None
        content = result[0]
        if isinstance(content, TextContent):
            return json.loads(content.text)
        raise TypeError("Unexpected content type")

    # --- Dataset management ---
    async def list_datasets(self) -> Any:
        """Return the server's list of datasets (decoded JSON)."""
        return await self._call_json("tensorus_list_datasets")

    async def create_dataset(self, dataset_name: str) -> Any:
        """Create a dataset named ``dataset_name`` on the server."""
        return await self._call_json("tensorus_create_dataset", {"dataset_name": dataset_name})

    async def delete_dataset(self, dataset_name: str) -> Any:
        """Delete the dataset named ``dataset_name``."""
        return await self._call_json("tensorus_delete_dataset", {"dataset_name": dataset_name})

    # --- Tensor management ---
    async def ingest_tensor(
        self,
        dataset_name: str,
        tensor_shape: Sequence[int],
        tensor_dtype: str,
        tensor_data: Any,
        metadata: Optional[dict] = None,
    ) -> Any:
        """Ingest a tensor (shape, dtype, data, optional metadata) into a dataset."""
        payload = {
            "dataset_name": dataset_name,
            "tensor_shape": list(tensor_shape),
            "tensor_dtype": tensor_dtype,
            "tensor_data": tensor_data,
            "metadata": metadata,
        }
        return await self._call_json("tensorus_ingest_tensor", payload)

    async def get_tensor_details(self, dataset_name: str, record_id: str) -> Any:
        """Fetch details for a single tensor record by dataset and record id."""
        return await self._call_json(
            "tensorus_get_tensor_details",
            {"dataset_name": dataset_name, "record_id": record_id},
        )

    async def delete_tensor(self, dataset_name: str, record_id: str) -> Any:
        """Delete a single tensor record from a dataset."""
        return await self._call_json(
            "tensorus_delete_tensor",
            {"dataset_name": dataset_name, "record_id": record_id},
        )

    async def update_tensor_metadata(
        self, dataset_name: str, record_id: str, new_metadata: dict
    ) -> Any:
        """Send new metadata for an existing tensor record.

        Whether the server replaces or merges metadata is defined server-side.
        """
        return await self._call_json(
            "tensorus_update_tensor_metadata",
            {
                "dataset_name": dataset_name,
                "record_id": record_id,
                "new_metadata": new_metadata,
            },
        )

    # --- Tensor operations ---
    async def apply_unary_operation(self, operation: str, payload: dict) -> Any:
        """Apply the named unary tensor operation; ``payload`` supplies operands."""
        return await self._call_json("tensorus_apply_unary_operation", {"operation": operation, **payload})

    async def apply_binary_operation(self, operation: str, payload: dict) -> Any:
        """Apply the named binary tensor operation; ``payload`` supplies operands."""
        return await self._call_json("tensorus_apply_binary_operation", {"operation": operation, **payload})

    async def apply_list_operation(self, operation: str, payload: dict) -> Any:
        """Apply the named list tensor operation; ``payload`` supplies operands."""
        return await self._call_json("tensorus_apply_list_operation", {"operation": operation, **payload})

    async def apply_einsum(self, payload: dict) -> Any:
        """Apply an einsum operation described entirely by ``payload``."""
        return await self._call_json("tensorus_apply_einsum", payload)

    # --- Misc ---
    async def save_tensor(
        self,
        dataset_name: str,
        tensor_shape: Sequence[int],
        tensor_dtype: str,
        tensor_data: Any,
        metadata: Optional[dict] = None,
    ) -> Any:
        """Save a tensor via the ``save_tensor`` tool (same payload shape as ingest_tensor)."""
        payload = {
            "dataset_name": dataset_name,
            "tensor_shape": list(tensor_shape),
            "tensor_dtype": tensor_dtype,
            "tensor_data": tensor_data,
            "metadata": metadata,
        }
        return await self._call_json("save_tensor", payload)

    async def get_tensor(self, dataset_name: str, record_id: str) -> Any:
        """Retrieve a tensor record via the ``get_tensor`` tool."""
        return await self._call_json("get_tensor", {"dataset_name": dataset_name, "record_id": record_id})

    async def execute_nql_query(self, query: str) -> Any:
        """Execute a natural-query-language (NQL) query string on the server."""
        return await self._call_json("execute_nql_query", {"query": query})
|
tensorus/mcp_server.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastMCP server exposing Tensorus API endpoints as tools.
|
| 2 |
+
|
| 3 |
+
This module registers a set of MCP tools that proxy to the Tensorus FastAPI
|
| 4 |
+
backend. Tools mirror the ones documented in the README under "Available
|
| 5 |
+
Tools" and return results as :class:`TextContent` objects.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
from typing import Any, Optional, Sequence
|
| 11 |
+
|
| 12 |
+
import httpx
|
| 13 |
+
from fastmcp import FastMCP
|
| 14 |
+
try:
|
| 15 |
+
from fastmcp.tools import TextContent
|
| 16 |
+
except ImportError: # pragma: no cover - support older fastmcp versions
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class TextContent: # minimal fallback for tests
|
| 21 |
+
type: str
|
| 22 |
+
text: str
|
| 23 |
+
|
| 24 |
+
API_BASE_URL = "http://127.0.0.1:7860"
|
| 25 |
+
|
| 26 |
+
server = FastMCP(name="Tensorus FastMCP")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
async def _post(path: str, payload: dict) -> dict:
|
| 30 |
+
try:
|
| 31 |
+
async with httpx.AsyncClient() as client:
|
| 32 |
+
response = await client.post(f"{API_BASE_URL}{path}", json=payload)
|
| 33 |
+
response.raise_for_status()
|
| 34 |
+
return response.json()
|
| 35 |
+
except httpx.HTTPError as exc: # pragma: no cover - network failures
|
| 36 |
+
return {"error": str(exc)}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
async def _get(path: str) -> dict:
|
| 40 |
+
try:
|
| 41 |
+
async with httpx.AsyncClient() as client:
|
| 42 |
+
response = await client.get(f"{API_BASE_URL}{path}")
|
| 43 |
+
response.raise_for_status()
|
| 44 |
+
return response.json()
|
| 45 |
+
except httpx.HTTPError as exc: # pragma: no cover - network failures
|
| 46 |
+
return {"error": str(exc)}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
async def _put(path: str, payload: dict) -> dict:
|
| 50 |
+
try:
|
| 51 |
+
async with httpx.AsyncClient() as client:
|
| 52 |
+
response = await client.put(f"{API_BASE_URL}{path}", json=payload)
|
| 53 |
+
response.raise_for_status()
|
| 54 |
+
return response.json()
|
| 55 |
+
except httpx.HTTPError as exc: # pragma: no cover - network failures
|
| 56 |
+
return {"error": str(exc)}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
async def _delete(path: str) -> dict:
|
| 60 |
+
try:
|
| 61 |
+
async with httpx.AsyncClient() as client:
|
| 62 |
+
response = await client.delete(f"{API_BASE_URL}{path}")
|
| 63 |
+
response.raise_for_status()
|
| 64 |
+
return response.json()
|
| 65 |
+
except httpx.HTTPError as exc: # pragma: no cover - network failures
|
| 66 |
+
return {"error": str(exc)}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@server.tool()
|
| 70 |
+
async def save_tensor(
|
| 71 |
+
dataset_name: str,
|
| 72 |
+
tensor_shape: Sequence[int],
|
| 73 |
+
tensor_dtype: str,
|
| 74 |
+
tensor_data: Any,
|
| 75 |
+
metadata: Optional[dict] = None,
|
| 76 |
+
) -> TextContent:
|
| 77 |
+
"""Save a tensor to a dataset."""
|
| 78 |
+
payload = {
|
| 79 |
+
"shape": list(tensor_shape),
|
| 80 |
+
"dtype": tensor_dtype,
|
| 81 |
+
"data": tensor_data,
|
| 82 |
+
"metadata": metadata,
|
| 83 |
+
}
|
| 84 |
+
result = await _post(f"/datasets/{dataset_name}/ingest", payload)
|
| 85 |
+
return TextContent(type="text", text=json.dumps(result))
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@server.tool()
|
| 89 |
+
async def get_tensor(dataset_name: str, record_id: str) -> TextContent:
|
| 90 |
+
"""Retrieve a tensor by record ID."""
|
| 91 |
+
result = await _get(f"/datasets/{dataset_name}/tensors/{record_id}")
|
| 92 |
+
return TextContent(type="text", text=json.dumps(result))
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@server.tool()
|
| 96 |
+
async def execute_nql_query(query: str) -> TextContent:
|
| 97 |
+
"""Execute a Natural Query Language query."""
|
| 98 |
+
result = await _post("/query", {"query": query})
|
| 99 |
+
return TextContent(type="text", text=json.dumps(result))
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# --- Dataset Management Tools ---
|
| 103 |
+
|
| 104 |
+
@server.tool(name="tensorus_list_datasets")
|
| 105 |
+
async def tensorus_list_datasets() -> TextContent:
|
| 106 |
+
"""List all available datasets."""
|
| 107 |
+
result = await _get("/datasets")
|
| 108 |
+
return TextContent(type="text", text=json.dumps(result))
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@server.tool(name="tensorus_create_dataset")
|
| 112 |
+
async def tensorus_create_dataset(dataset_name: str) -> TextContent:
|
| 113 |
+
"""Create a new dataset."""
|
| 114 |
+
result = await _post("/datasets/create", {"name": dataset_name})
|
| 115 |
+
return TextContent(type="text", text=json.dumps(result))
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@server.tool(name="tensorus_delete_dataset")
|
| 119 |
+
async def tensorus_delete_dataset(dataset_name: str) -> TextContent:
|
| 120 |
+
"""Delete an existing dataset."""
|
| 121 |
+
result = await _delete(f"/datasets/{dataset_name}")
|
| 122 |
+
return TextContent(type="text", text=json.dumps(result))
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# --- Tensor Management Tools ---
|
| 126 |
+
|
| 127 |
+
@server.tool(name="tensorus_ingest_tensor")
|
| 128 |
+
async def tensorus_ingest_tensor(
|
| 129 |
+
dataset_name: str,
|
| 130 |
+
tensor_shape: Sequence[int],
|
| 131 |
+
tensor_dtype: str,
|
| 132 |
+
tensor_data: Any,
|
| 133 |
+
metadata: Optional[dict] = None,
|
| 134 |
+
) -> TextContent:
|
| 135 |
+
"""Ingest a new tensor into a dataset."""
|
| 136 |
+
payload = {
|
| 137 |
+
"shape": list(tensor_shape),
|
| 138 |
+
"dtype": tensor_dtype,
|
| 139 |
+
"data": tensor_data,
|
| 140 |
+
"metadata": metadata,
|
| 141 |
+
}
|
| 142 |
+
result = await _post(f"/datasets/{dataset_name}/ingest", payload)
|
| 143 |
+
return TextContent(type="text", text=json.dumps(result))
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@server.tool(name="tensorus_get_tensor_details")
|
| 147 |
+
async def tensorus_get_tensor_details(dataset_name: str, record_id: str) -> TextContent:
|
| 148 |
+
"""Retrieve tensor data and metadata."""
|
| 149 |
+
result = await _get(f"/datasets/{dataset_name}/tensors/{record_id}")
|
| 150 |
+
return TextContent(type="text", text=json.dumps(result))
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@server.tool(name="tensorus_delete_tensor")
|
| 154 |
+
async def tensorus_delete_tensor(dataset_name: str, record_id: str) -> TextContent:
|
| 155 |
+
"""Delete a tensor from a dataset."""
|
| 156 |
+
result = await _delete(f"/datasets/{dataset_name}/tensors/{record_id}")
|
| 157 |
+
return TextContent(type="text", text=json.dumps(result))
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@server.tool(name="tensorus_update_tensor_metadata")
|
| 161 |
+
async def tensorus_update_tensor_metadata(
|
| 162 |
+
dataset_name: str,
|
| 163 |
+
record_id: str,
|
| 164 |
+
new_metadata: dict,
|
| 165 |
+
) -> TextContent:
|
| 166 |
+
"""Replace metadata for a specific tensor."""
|
| 167 |
+
payload = {"new_metadata": new_metadata}
|
| 168 |
+
result = await _put(f"/datasets/{dataset_name}/tensors/{record_id}/metadata", payload)
|
| 169 |
+
return TextContent(type="text", text=json.dumps(result))
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# --- Tensor Operation Tools ---
|
| 173 |
+
|
| 174 |
+
@server.tool(name="tensorus_apply_unary_operation")
|
| 175 |
+
async def tensorus_apply_unary_operation(operation: str, request_payload: dict) -> TextContent:
|
| 176 |
+
"""Apply a unary TensorOps operation (e.g., log, reshape)."""
|
| 177 |
+
result = await _post(f"/ops/{operation}", request_payload)
|
| 178 |
+
return TextContent(type="text", text=json.dumps(result))
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
@server.tool(name="tensorus_apply_binary_operation")
|
| 182 |
+
async def tensorus_apply_binary_operation(operation: str, request_payload: dict) -> TextContent:
|
| 183 |
+
"""Apply a binary TensorOps operation (e.g., add, subtract)."""
|
| 184 |
+
result = await _post(f"/ops/{operation}", request_payload)
|
| 185 |
+
return TextContent(type="text", text=json.dumps(result))
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
@server.tool(name="tensorus_apply_list_operation")
|
| 189 |
+
async def tensorus_apply_list_operation(operation: str, request_payload: dict) -> TextContent:
|
| 190 |
+
"""Apply a TensorOps list operation such as concatenate or stack."""
|
| 191 |
+
result = await _post(f"/ops/{operation}", request_payload)
|
| 192 |
+
return TextContent(type="text", text=json.dumps(result))
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@server.tool(name="tensorus_apply_einsum")
|
| 196 |
+
async def tensorus_apply_einsum(request_payload: dict) -> TextContent:
|
| 197 |
+
"""Apply an einsum operation."""
|
| 198 |
+
result = await _post("/ops/einsum", request_payload)
|
| 199 |
+
return TextContent(type="text", text=json.dumps(result))
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@server.resource("resource://datasets", name="datasets", description="List of datasets")
|
| 203 |
+
async def datasets_resource() -> str:
|
| 204 |
+
data = await _get("/datasets")
|
| 205 |
+
return json.dumps(data.get("data", []))
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def main() -> None:
|
| 209 |
+
global API_BASE_URL
|
| 210 |
+
|
| 211 |
+
parser = argparse.ArgumentParser(
|
| 212 |
+
description="Run the Tensorus FastMCP server exposing dataset and tensor tools"
|
| 213 |
+
)
|
| 214 |
+
parser.add_argument(
|
| 215 |
+
"--transport", choices=["stdio", "sse"], default="stdio", help="Transport protocol"
|
| 216 |
+
)
|
| 217 |
+
parser.add_argument(
|
| 218 |
+
"--api-url", default=API_BASE_URL, help="Base URL of the running FastAPI backend"
|
| 219 |
+
)
|
| 220 |
+
args = parser.parse_args()
|
| 221 |
+
|
| 222 |
+
API_BASE_URL = args.api_url.rstrip("/")
|
| 223 |
+
|
| 224 |
+
server.run(args.transport)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
if __name__ == "__main__": # pragma: no cover - manual execution
|
| 228 |
+
main()
|
tensorus/metadata/__init__.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tensorus Metadata Package.
|
| 3 |
+
|
| 4 |
+
This package provides schemas for describing tensors and their semantic meaning,
|
| 5 |
+
as well as a basic in-memory storage mechanism for managing these metadata objects.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .schemas import (
|
| 9 |
+
TensorDescriptor,
|
| 10 |
+
SemanticMetadata,
|
| 11 |
+
DataType,
|
| 12 |
+
StorageFormat,
|
| 13 |
+
AccessControl,
|
| 14 |
+
CompressionInfo,
|
| 15 |
+
# Extended Schemas
|
| 16 |
+
LineageSourceType,
|
| 17 |
+
LineageSource,
|
| 18 |
+
ParentTensorLink,
|
| 19 |
+
TransformationStep,
|
| 20 |
+
VersionControlInfo,
|
| 21 |
+
LineageMetadata,
|
| 22 |
+
ComputationalMetadata,
|
| 23 |
+
QualityStatistics,
|
| 24 |
+
MissingValuesInfo,
|
| 25 |
+
OutlierInfo,
|
| 26 |
+
QualityMetadata,
|
| 27 |
+
RelatedTensorLink,
|
| 28 |
+
RelationalMetadata,
|
| 29 |
+
UsageAccessRecord,
|
| 30 |
+
UsageMetadata
|
| 31 |
+
)
|
| 32 |
+
from tensorus.config import settings
|
| 33 |
+
from .storage_abc import MetadataStorage
|
| 34 |
+
from .storage import InMemoryStorage
|
| 35 |
+
from .postgres_storage import PostgresMetadataStorage
|
| 36 |
+
from .schemas_iodata import TensorusExportData, TensorusExportEntry # Import I/O schemas
|
| 37 |
+
|
| 38 |
+
class ConfigurationError(Exception):
|
| 39 |
+
pass
|
| 40 |
+
|
| 41 |
+
def get_configured_storage_instance() -> MetadataStorage:
|
| 42 |
+
if settings.STORAGE_BACKEND == "postgres":
|
| 43 |
+
if settings.POSTGRES_DSN:
|
| 44 |
+
return PostgresMetadataStorage(dsn=settings.POSTGRES_DSN)
|
| 45 |
+
elif settings.POSTGRES_HOST and settings.POSTGRES_USER and settings.POSTGRES_DB:
|
| 46 |
+
return PostgresMetadataStorage(
|
| 47 |
+
host=settings.POSTGRES_HOST,
|
| 48 |
+
port=settings.POSTGRES_PORT or 5432,
|
| 49 |
+
user=settings.POSTGRES_USER,
|
| 50 |
+
password=settings.POSTGRES_PASSWORD,
|
| 51 |
+
database=settings.POSTGRES_DB
|
| 52 |
+
)
|
| 53 |
+
else:
|
| 54 |
+
raise ConfigurationError(
|
| 55 |
+
"PostgreSQL backend selected, but required connection details "
|
| 56 |
+
"(DSN or Host/User/DB) are missing. Please set TENSORUS_POSTGRES_DSN "
|
| 57 |
+
"or TENSORUS_POSTGRES_HOST, TENSORUS_POSTGRES_USER, TENSORUS_POSTGRES_DB."
|
| 58 |
+
)
|
| 59 |
+
elif settings.STORAGE_BACKEND == "in_memory":
|
| 60 |
+
return InMemoryStorage()
|
| 61 |
+
else:
|
| 62 |
+
raise ConfigurationError(f"Unsupported storage backend: {settings.STORAGE_BACKEND}")
|
| 63 |
+
|
| 64 |
+
storage_instance: MetadataStorage = get_configured_storage_instance()
|
| 65 |
+
|
| 66 |
+
__all__ = [
|
| 67 |
+
# Core Schemas
|
| 68 |
+
"TensorDescriptor",
|
| 69 |
+
"SemanticMetadata",
|
| 70 |
+
"DataType",
|
| 71 |
+
"StorageFormat",
|
| 72 |
+
"AccessControl",
|
| 73 |
+
"CompressionInfo",
|
| 74 |
+
# Extended Schemas - Main Classes
|
| 75 |
+
"LineageMetadata",
|
| 76 |
+
"ComputationalMetadata",
|
| 77 |
+
"QualityMetadata",
|
| 78 |
+
"RelationalMetadata",
|
| 79 |
+
"UsageMetadata",
|
| 80 |
+
# Extended Schemas - Helper Classes & Enums
|
| 81 |
+
"LineageSourceType",
|
| 82 |
+
"LineageSource",
|
| 83 |
+
"ParentTensorLink",
|
| 84 |
+
"TransformationStep",
|
| 85 |
+
"VersionControlInfo",
|
| 86 |
+
"QualityStatistics",
|
| 87 |
+
"MissingValuesInfo",
|
| 88 |
+
"OutlierInfo",
|
| 89 |
+
"RelatedTensorLink",
|
| 90 |
+
"UsageAccessRecord",
|
| 91 |
+
# I/O Schemas
|
| 92 |
+
"TensorusExportData",
|
| 93 |
+
"TensorusExportEntry",
|
| 94 |
+
# Storage Abstraction & Implementations
|
| 95 |
+
"MetadataStorage",
|
| 96 |
+
"InMemoryStorage",
|
| 97 |
+
"PostgresMetadataStorage",
|
| 98 |
+
"storage_instance",
|
| 99 |
+
"ConfigurationError",
|
| 100 |
+
"get_configured_storage_instance"
|
| 101 |
+
]
|
tensorus/metadata/postgres_storage.py
ADDED
|
@@ -0,0 +1,741 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import psycopg2
|
| 2 |
+
import psycopg2.pool
|
| 3 |
+
import psycopg2.extras # For dict cursor
|
| 4 |
+
import json
|
| 5 |
+
from typing import List, Optional, Dict, Any
|
| 6 |
+
from uuid import UUID
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
from .storage_abc import MetadataStorage
|
| 10 |
+
from .schemas import (
|
| 11 |
+
TensorDescriptor, SemanticMetadata, DataType, StorageFormat, AccessControl, CompressionInfo,
|
| 12 |
+
LineageMetadata, ComputationalMetadata, QualityMetadata,
|
| 13 |
+
RelationalMetadata, UsageMetadata # Import other extended types for method signatures
|
| 14 |
+
)
|
| 15 |
+
import copy # Ensure copy is imported, as it was added to InMemoryStorage and might be useful here too.
|
| 16 |
+
|
| 17 |
+
# Configure module level logging
|
| 18 |
+
logging.basicConfig(level=logging.INFO,
|
| 19 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
__all__ = ["PostgresMetadataStorage"]
|
| 23 |
+
|
| 24 |
+
# --- DDL Comments ---
|
| 25 |
+
#
|
| 26 |
+
# CREATE TYPE data_type_enum AS ENUM (
|
| 27 |
+
# 'float32', 'float64', 'float16', 'int32', 'int64', 'int16', 'int8', 'uint8',
|
| 28 |
+
# 'boolean', 'string', 'complex64', 'complex128', 'other'
|
| 29 |
+
# );
|
| 30 |
+
#
|
| 31 |
+
# CREATE TYPE storage_format_enum AS ENUM (
|
| 32 |
+
# 'raw', 'numpy_npz', 'hdf5', 'compressed_zlib', 'compressed_gzip', 'custom'
|
| 33 |
+
# );
|
| 34 |
+
#
|
| 35 |
+
# CREATE TABLE IF NOT EXISTS tensor_descriptors (
|
| 36 |
+
# tensor_id UUID PRIMARY KEY,
|
| 37 |
+
# dimensionality INTEGER NOT NULL,
|
| 38 |
+
# shape INTEGER[] NOT NULL,
|
| 39 |
+
# data_type TEXT NOT NULL, -- Could use data_type_enum
|
| 40 |
+
# storage_format TEXT NOT NULL, -- Could use storage_format_enum
|
| 41 |
+
# creation_timestamp TIMESTAMPTZ NOT NULL,
|
| 42 |
+
# last_modified_timestamp TIMESTAMPTZ NOT NULL,
|
| 43 |
+
# owner TEXT,
|
| 44 |
+
# access_control JSONB,
|
| 45 |
+
# byte_size BIGINT,
|
| 46 |
+
# checksum TEXT,
|
| 47 |
+
# compression_info JSONB,
|
| 48 |
+
# tags TEXT[],
|
| 49 |
+
# metadata JSONB,
|
| 50 |
+
# CONSTRAINT shape_dimensionality_check CHECK (array_length(shape, 1) = dimensionality OR dimensionality = 0 AND shape IS NULL OR array_length(shape,1) = 0)
|
| 51 |
+
# );
|
| 52 |
+
# CREATE INDEX IF NOT EXISTS idx_td_owner ON tensor_descriptors(owner);
|
| 53 |
+
# CREATE INDEX IF NOT EXISTS idx_td_data_type ON tensor_descriptors(data_type);
|
| 54 |
+
# CREATE INDEX IF NOT EXISTS idx_td_tags ON tensor_descriptors USING GIN(tags);
|
| 55 |
+
#
|
| 56 |
+
# -- Table for SemanticMetadata (example, one-to-many with TensorDescriptor)
|
| 57 |
+
# CREATE TABLE IF NOT EXISTS semantic_metadata_entries (
|
| 58 |
+
# id SERIAL PRIMARY KEY, -- Or UUID primary key if preferred
|
| 59 |
+
# tensor_id UUID NOT NULL REFERENCES tensor_descriptors(tensor_id) ON DELETE CASCADE,
|
| 60 |
+
# name TEXT NOT NULL,
|
| 61 |
+
# description TEXT,
|
| 62 |
+
# -- other fields from SemanticMetadata schema ...
|
| 63 |
+
# UNIQUE (tensor_id, name) -- Ensure name is unique per tensor_id
|
| 64 |
+
# );
|
| 65 |
+
#
|
| 66 |
+
# -- Generic table structure for 1-to-1 extended metadata (Lineage, Computational, etc.)
|
| 67 |
+
# -- Replace <metadata_name> with lineage, computational, quality, relational, usage
|
| 68 |
+
# CREATE TABLE IF NOT EXISTS <metadata_name>_metadata (
|
| 69 |
+
# tensor_id UUID PRIMARY KEY REFERENCES tensor_descriptors(tensor_id) ON DELETE CASCADE,
|
| 70 |
+
# data JSONB NOT NULL -- Store the entire Pydantic model as JSONB
|
| 71 |
+
# );
|
| 72 |
+
#
|
| 73 |
+
|
| 74 |
+
class PostgresMetadataStorage(MetadataStorage):
|
| 75 |
+
def __init__(self, dsn: Optional[str] = None, min_conn: int = 1, max_conn: int = 5, **kwargs):
|
| 76 |
+
self.dsn = dsn
|
| 77 |
+
self.pool = None
|
| 78 |
+
if dsn: # Allow DSN or individual params
|
| 79 |
+
self.pool = psycopg2.pool.SimpleConnectionPool(min_conn, max_conn, dsn=dsn)
|
| 80 |
+
elif kwargs.get('database') and kwargs.get('user'):
|
| 81 |
+
self.pool = psycopg2.pool.SimpleConnectionPool(min_conn, max_conn, **kwargs)
|
| 82 |
+
else:
|
| 83 |
+
raise ValueError("PostgreSQL connection parameters (DSN or host/db/user etc.) not provided.")
|
| 84 |
+
|
| 85 |
+
def _execute_query(self, query: str, params: tuple = None, fetch: str = None):
|
| 86 |
+
"""Helper to execute queries with connection pooling."""
|
| 87 |
+
if not self.pool:
|
| 88 |
+
raise ConnectionError("Connection pool not initialized.")
|
| 89 |
+
conn = None
|
| 90 |
+
try:
|
| 91 |
+
conn = self.pool.getconn()
|
| 92 |
+
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
| 93 |
+
cur.execute(query, params)
|
| 94 |
+
conn.commit()
|
| 95 |
+
if fetch == "one":
|
| 96 |
+
return cur.fetchone()
|
| 97 |
+
if fetch == "all":
|
| 98 |
+
return cur.fetchall()
|
| 99 |
+
# For INSERT/UPDATE/DELETE, rowcount might be useful
|
| 100 |
+
return cur.rowcount
|
| 101 |
+
except Exception as e:
|
| 102 |
+
if conn: conn.rollback()
|
| 103 |
+
# Log error e
|
| 104 |
+
raise # Re-raise after logging or wrap in custom exception
|
| 105 |
+
finally:
|
| 106 |
+
if conn and self.pool:
|
| 107 |
+
self.pool.putconn(conn)
|
| 108 |
+
|
| 109 |
+
# --- TensorDescriptor Methods ---
|
| 110 |
+
def add_tensor_descriptor(self, descriptor: TensorDescriptor) -> None:
|
| 111 |
+
query = """
|
| 112 |
+
INSERT INTO tensor_descriptors (
|
| 113 |
+
tensor_id, dimensionality, shape, data_type, storage_format,
|
| 114 |
+
creation_timestamp, last_modified_timestamp, owner, access_control,
|
| 115 |
+
byte_size, checksum, compression_info, tags, metadata
|
| 116 |
+
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
| 117 |
+
ON CONFLICT (tensor_id) DO UPDATE SET
|
| 118 |
+
dimensionality = EXCLUDED.dimensionality,
|
| 119 |
+
shape = EXCLUDED.shape,
|
| 120 |
+
data_type = EXCLUDED.data_type,
|
| 121 |
+
storage_format = EXCLUDED.storage_format,
|
| 122 |
+
last_modified_timestamp = EXCLUDED.last_modified_timestamp,
|
| 123 |
+
owner = EXCLUDED.owner,
|
| 124 |
+
access_control = EXCLUDED.access_control,
|
| 125 |
+
byte_size = EXCLUDED.byte_size,
|
| 126 |
+
checksum = EXCLUDED.checksum,
|
| 127 |
+
compression_info = EXCLUDED.compression_info,
|
| 128 |
+
tags = EXCLUDED.tags,
|
| 129 |
+
metadata = EXCLUDED.metadata;
|
| 130 |
+
"""
|
| 131 |
+
params = (
|
| 132 |
+
descriptor.tensor_id,
|
| 133 |
+
descriptor.dimensionality,
|
| 134 |
+
descriptor.shape,
|
| 135 |
+
descriptor.data_type.value,
|
| 136 |
+
descriptor.storage_format.value,
|
| 137 |
+
descriptor.creation_timestamp,
|
| 138 |
+
descriptor.last_modified_timestamp,
|
| 139 |
+
descriptor.owner,
|
| 140 |
+
descriptor.access_control.model_dump_json() if descriptor.access_control else None, # Pydantic v2
|
| 141 |
+
# json.dumps(descriptor.access_control.dict()) if descriptor.access_control else None, # Pydantic v1
|
| 142 |
+
descriptor.byte_size,
|
| 143 |
+
descriptor.checksum,
|
| 144 |
+
descriptor.compression_info.model_dump_json() if descriptor.compression_info else None, # Pydantic v2
|
| 145 |
+
# json.dumps(descriptor.compression_info.dict()) if descriptor.compression_info else None, # Pydantic v1
|
| 146 |
+
descriptor.tags,
|
| 147 |
+
json.dumps(descriptor.metadata) if descriptor.metadata else None,
|
| 148 |
+
)
|
| 149 |
+
self._execute_query(query, params)
|
| 150 |
+
|
| 151 |
+
def get_tensor_descriptor(self, tensor_id: UUID) -> Optional[TensorDescriptor]:
|
| 152 |
+
query = "SELECT * FROM tensor_descriptors WHERE tensor_id = %s;"
|
| 153 |
+
row = self._execute_query(query, (tensor_id,), fetch="one")
|
| 154 |
+
if row:
|
| 155 |
+
# Pydantic models expect enums, not their string values directly from DB for some fields
|
| 156 |
+
# Need to handle this transformation carefully.
|
| 157 |
+
# Also, JSONB fields need to be parsed back.
|
| 158 |
+
data = dict(row)
|
| 159 |
+
data['data_type'] = DataType(data['data_type'])
|
| 160 |
+
data['storage_format'] = StorageFormat(data['storage_format'])
|
| 161 |
+
if data.get('access_control'): # JSONB field
|
| 162 |
+
data['access_control'] = AccessControl(**data['access_control'])
|
| 163 |
+
if data.get('compression_info'): # JSONB field
|
| 164 |
+
data['compression_info'] = CompressionInfo(**data['compression_info'])
|
| 165 |
+
# metadata is also JSONB but can be any dict, so direct assignment is fine
|
| 166 |
+
return TensorDescriptor(**data)
|
| 167 |
+
return None
|
| 168 |
+
|
| 169 |
+
def delete_tensor_descriptor(self, tensor_id: UUID) -> bool:
|
| 170 |
+
# ON DELETE CASCADE should handle related metadata in other tables
|
| 171 |
+
query = "DELETE FROM tensor_descriptors WHERE tensor_id = %s;"
|
| 172 |
+
rowcount = self._execute_query(query, (tensor_id,))
|
| 173 |
+
return rowcount > 0
|
| 174 |
+
|
| 175 |
+
# --- SemanticMetadata Methods (Example for one-to-many) ---
|
| 176 |
+
# Assuming semantic_metadata_entries table as defined in DDL comments
|
| 177 |
+
def add_semantic_metadata(self, metadata: SemanticMetadata) -> None:
|
| 178 |
+
# Check if TD exists
|
| 179 |
+
if not self.get_tensor_descriptor(metadata.tensor_id):
|
| 180 |
+
raise ValueError(f"TensorDescriptor with ID {metadata.tensor_id} not found.")
|
| 181 |
+
|
| 182 |
+
# Upsert based on (tensor_id, name)
|
| 183 |
+
query = """
|
| 184 |
+
INSERT INTO semantic_metadata_entries (tensor_id, name, description)
|
| 185 |
+
VALUES (%s, %s, %s)
|
| 186 |
+
ON CONFLICT (tensor_id, name) DO UPDATE SET
|
| 187 |
+
description = EXCLUDED.description;
|
| 188 |
+
"""
|
| 189 |
+
# Add other fields from SemanticMetadata schema to query and params as needed
|
| 190 |
+
params = (metadata.tensor_id, metadata.name, metadata.description)
|
| 191 |
+
self._execute_query(query, params)
|
| 192 |
+
|
| 193 |
+
def get_semantic_metadata(self, tensor_id: UUID) -> List[SemanticMetadata]:
|
| 194 |
+
query = "SELECT tensor_id, name, description FROM semantic_metadata_entries WHERE tensor_id = %s;"
|
| 195 |
+
rows = self._execute_query(query, (tensor_id,), fetch="all")
|
| 196 |
+
return [SemanticMetadata(**dict(row)) for row in rows]
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# --- Placeholder for other methods ---
|
| 200 |
+
def update_tensor_descriptor(self, tensor_id: UUID, **kwargs) -> Optional[TensorDescriptor]:
    """Apply a partial update to an existing TensorDescriptor.

    Only keys that already exist on the descriptor are applied; unknown
    kwargs are silently ignored.  The merged state is re-validated by the
    Pydantic model before being written back via the upserting
    ``add_tensor_descriptor``.

    Returns:
        The updated descriptor, or None when *tensor_id* is unknown.

    Raises:
        pydantic.ValidationError: when the merged state fails validation.
    """
    existing = self.get_tensor_descriptor(tensor_id)
    if existing is None:
        return None

    merged = existing.model_dump()  # Pydantic v2
    for field_name, new_value in kwargs.items():
        # Ignore keys that are not descriptor fields (partial update).
        if field_name in merged:
            merged[field_name] = new_value

    try:
        refreshed = TensorDescriptor(**merged)  # re-validate merged state
        refreshed.last_modified_timestamp = datetime.utcnow()  # bump explicitly
        self.add_tensor_descriptor(refreshed)  # add() performs an upsert
        return refreshed
    except Exception:  # pydantic.ValidationError in practice
        # Propagate so callers can surface the validation failure.
        raise
|
| 224 |
+
|
| 225 |
+
def list_tensor_descriptors(
    self,
    owner: Optional[str] = None,
    data_type: Optional[DataType] = None,
    tags_contain: Optional[List[str]] = None,
    lineage_version: Optional[str] = None,
    # Add other filter params as needed, matching the API layer.
) -> List[TensorDescriptor]:
    """List descriptors matching the given filters (AND-combined).

    Supports direct-column filters (owner, data_type, tags) and one
    example of filtering through a joined JSONB table (lineage version).
    """
    base_query = "SELECT DISTINCT td.* FROM tensor_descriptors td"
    joins: List[str] = []
    joined_aliases = {"td"}  # aliases already joined, to avoid duplicates
    conditions: List[str] = []
    params: Dict[str, Any] = {}  # named %(name)s-style parameters

    if owner:
        conditions.append("td.owner = %(owner)s")
        params["owner"] = owner
    if data_type:
        conditions.append("td.data_type = %(data_type)s")
        params["data_type"] = data_type.value
    if tags_contain:
        conditions.append("td.tags @> %(tags_contain)s")  # array-contains operator
        params["tags_contain"] = tags_contain

    if lineage_version:
        # BUG FIX: the previous duplicate-join guard inspected
        # j.split()[1], which for "LEFT JOIN ..." is always "JOIN", so it
        # could never detect an existing alias.  Track aliases in a set.
        if "lm" not in joined_aliases:
            joins.append("LEFT JOIN lineage_metadata lm ON td.tensor_id = lm.tensor_id")
            joined_aliases.add("lm")
        conditions.append("lm.data->>'version' = %(lineage_version)s")
        params["lineage_version"] = lineage_version

    # Construct the final query.
    if joins:
        base_query += " " + " ".join(joins)
    if conditions:
        base_query += " WHERE " + " AND ".join(conditions)
    base_query += ";"

    rows = self._execute_query(base_query, params, fetch="all")  # type: ignore # psycopg2 params can be dict
    results = []
    for row in rows:
        data = dict(row)
        data['data_type'] = DataType(data['data_type'])
        data['storage_format'] = StorageFormat(data['storage_format'])
        if data.get('access_control'):
            data['access_control'] = AccessControl(**data['access_control'])
        if data.get('compression_info'):
            data['compression_info'] = CompressionInfo(**data['compression_info'])
        results.append(TensorDescriptor(**data))
    return results
|
| 274 |
+
|
| 275 |
+
def get_semantic_metadata_by_name(self, tensor_id: UUID, name: str) -> Optional[SemanticMetadata]:
    """Return the single semantic entry (tensor_id, name), or None."""
    entry = self._execute_query(
        "SELECT tensor_id, name, description FROM semantic_metadata_entries WHERE tensor_id = %s AND name = %s;",
        (tensor_id, name),
        fetch="one",
    )
    if entry is None:
        return None
    return SemanticMetadata(**dict(entry))
|
| 279 |
+
|
| 280 |
+
def update_semantic_metadata(self, tensor_id: UUID, name: str, new_description: Optional[str] = None, new_name: Optional[str] = None) -> Optional[SemanticMetadata]:
    """Rename and/or re-describe one semantic entry.

    Returns:
        The refreshed entry after the update, or None when no entry
        named *name* exists for *tensor_id*.

    Raises:
        ValueError: when renaming would collide with an existing entry.
    """
    existing = self.get_semantic_metadata_by_name(tensor_id, name)
    if existing is None:
        return None

    target_description = existing.description if new_description is None else new_description
    target_name = existing.name if new_name is None else new_name

    # Guard against clobbering another entry when renaming.
    if new_name and new_name != name:
        if self.get_semantic_metadata_by_name(tensor_id, new_name):
            raise ValueError(f"SemanticMetadata with name '{new_name}' already exists for tensor {tensor_id}.")

    self._execute_query(
        "UPDATE semantic_metadata_entries SET name = %s, description = %s WHERE tensor_id = %s AND name = %s;",
        (target_name, target_description, tensor_id, name),
    )

    # Re-fetch so callers see exactly what was persisted.
    return self.get_semantic_metadata_by_name(tensor_id, target_name)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def delete_semantic_metadata(self, tensor_id: UUID, name: str) -> bool:
    """Delete one named semantic entry; True when a row was removed."""
    affected = self._execute_query(
        "DELETE FROM semantic_metadata_entries WHERE tensor_id = %s AND name = %s;",
        (tensor_id, name),
    )  # type: ignore # rowcount is int
    return affected > 0
|
| 305 |
+
|
| 306 |
+
# --- Implementations for other extended metadata (using JSONB 'data' column pattern) ---
|
| 307 |
+
def _add_jsonb_metadata(self, table_name: str, metadata_obj: Any) -> None:
|
| 308 |
+
# Check if the parent TensorDescriptor exists
|
| 309 |
+
# This check is good practice, though add_tensor_descriptor also checks.
|
| 310 |
+
# Redundant if _add_jsonb_metadata is only called internally after a TD check.
|
| 311 |
+
# For direct calls or future refactoring, it's safer.
|
| 312 |
+
parent_td = self.get_tensor_descriptor(metadata_obj.tensor_id)
|
| 313 |
+
if not parent_td:
|
| 314 |
+
raise ValueError(f"TensorDescriptor with ID {metadata_obj.tensor_id} not found. Cannot add {metadata_obj.__class__.__name__}.")
|
| 315 |
+
|
| 316 |
+
query = f"""
|
| 317 |
+
INSERT INTO {table_name} (tensor_id, data) VALUES (%(tensor_id)s, %(data)s)
|
| 318 |
+
ON CONFLICT (tensor_id) DO UPDATE SET data = EXCLUDED.data;
|
| 319 |
+
"""
|
| 320 |
+
params = {
|
| 321 |
+
"tensor_id": metadata_obj.tensor_id,
|
| 322 |
+
"data": metadata_obj.model_dump_json() # Pydantic v2
|
| 323 |
+
# "data": metadata_obj.json() # Pydantic v1
|
| 324 |
+
}
|
| 325 |
+
self._execute_query(query, params)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def _get_jsonb_metadata(self, table_name: str, tensor_id: UUID, model_class: type) -> Optional[Any]:
|
| 329 |
+
query = f"SELECT data FROM {table_name} WHERE tensor_id = %(tensor_id)s;"
|
| 330 |
+
params = {"tensor_id": tensor_id}
|
| 331 |
+
row = self._execute_query(query, params, fetch="one")
|
| 332 |
+
if row and row['data']:
|
| 333 |
+
return model_class.model_validate_json(row['data']) # Pydantic v2
|
| 334 |
+
# return model_class.parse_raw(row['data']) # Pydantic v1
|
| 335 |
+
return None
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def _update_jsonb_metadata(self, table_name: str, tensor_id: UUID, model_class: type, **kwargs) -> Optional[Any]:
|
| 339 |
+
current_obj = self._get_jsonb_metadata(table_name, tensor_id, model_class)
|
| 340 |
+
if not current_obj:
|
| 341 |
+
return None
|
| 342 |
+
|
| 343 |
+
current_data = current_obj.model_dump() # Pydantic v2
|
| 344 |
+
# current_data = current_obj.dict() # Pydantic v1
|
| 345 |
+
|
| 346 |
+
# Perform a deep update for nested dictionaries if necessary
|
| 347 |
+
# For simple top-level field updates, direct assignment is fine.
|
| 348 |
+
# This example does a simple top-level merge.
|
| 349 |
+
for key, value in kwargs.items():
|
| 350 |
+
current_data[key] = value
|
| 351 |
+
# updated_data = {**current_data, **kwargs} # This does a shallow merge
|
| 352 |
+
|
| 353 |
+
try:
|
| 354 |
+
new_obj = model_class.model_validate(current_data) # Pydantic v2
|
| 355 |
+
# new_obj = model_class(**current_data) # Pydantic v1
|
| 356 |
+
self._add_jsonb_metadata(table_name, new_obj) # Use add for upsert
|
| 357 |
+
return new_obj
|
| 358 |
+
except Exception as e: # Should be Pydantic ValidationError
|
| 359 |
+
# Log e
|
| 360 |
+
raise ValueError(f"Update for {model_class.__name__} failed validation: {e}")
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def _delete_jsonb_metadata(self, table_name: str, tensor_id: UUID) -> bool:
|
| 364 |
+
query = f"DELETE FROM {table_name} WHERE tensor_id = %(tensor_id)s;"
|
| 365 |
+
params = {"tensor_id": tensor_id}
|
| 366 |
+
rowcount = self._execute_query(query, params) # type: ignore # rowcount is int
|
| 367 |
+
return rowcount > 0
|
| 368 |
+
|
| 369 |
+
# Thin per-table wrappers over the generic JSONB helpers above.  Each
# extended-metadata model gets the same add/get/update/delete quartet.

def add_lineage_metadata(self, m: LineageMetadata):
    self._add_jsonb_metadata("lineage_metadata", m)

def get_lineage_metadata(self, tid: UUID):
    return self._get_jsonb_metadata("lineage_metadata", tid, LineageMetadata)

def update_lineage_metadata(self, tid: UUID, **kw):
    return self._update_jsonb_metadata("lineage_metadata", tid, LineageMetadata, **kw)

def delete_lineage_metadata(self, tid: UUID):
    return self._delete_jsonb_metadata("lineage_metadata", tid)

def add_computational_metadata(self, m: ComputationalMetadata):
    self._add_jsonb_metadata("computational_metadata", m)

def get_computational_metadata(self, tid: UUID):
    return self._get_jsonb_metadata("computational_metadata", tid, ComputationalMetadata)

def update_computational_metadata(self, tid: UUID, **kw):
    return self._update_jsonb_metadata("computational_metadata", tid, ComputationalMetadata, **kw)

def delete_computational_metadata(self, tid: UUID):
    return self._delete_jsonb_metadata("computational_metadata", tid)

def add_quality_metadata(self, m: QualityMetadata):
    self._add_jsonb_metadata("quality_metadata", m)

def get_quality_metadata(self, tid: UUID):
    return self._get_jsonb_metadata("quality_metadata", tid, QualityMetadata)

def update_quality_metadata(self, tid: UUID, **kw):
    return self._update_jsonb_metadata("quality_metadata", tid, QualityMetadata, **kw)

def delete_quality_metadata(self, tid: UUID):
    return self._delete_jsonb_metadata("quality_metadata", tid)

def add_relational_metadata(self, m: RelationalMetadata):
    self._add_jsonb_metadata("relational_metadata", m)

def get_relational_metadata(self, tid: UUID):
    return self._get_jsonb_metadata("relational_metadata", tid, RelationalMetadata)

def update_relational_metadata(self, tid: UUID, **kw):
    return self._update_jsonb_metadata("relational_metadata", tid, RelationalMetadata, **kw)

def delete_relational_metadata(self, tid: UUID):
    return self._delete_jsonb_metadata("relational_metadata", tid)

def add_usage_metadata(self, m: UsageMetadata):
    self._add_jsonb_metadata("usage_metadata", m)

def get_usage_metadata(self, tid: UUID):
    return self._get_jsonb_metadata("usage_metadata", tid, UsageMetadata)

def update_usage_metadata(self, tid: UUID, **kw):
    return self._update_jsonb_metadata("usage_metadata", tid, UsageMetadata, **kw)

def delete_usage_metadata(self, tid: UUID):
    return self._delete_jsonb_metadata("usage_metadata", tid)
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def get_parent_tensor_ids(self, tensor_id: UUID) -> List[UUID]:
    """Return the UUIDs of this tensor's lineage parents.

    Assumes lineage_metadata.data is JSONB holding a 'parent_tensors'
    list of objects shaped like {"tensor_id": "<uuid>", ...}.
    """
    query = """
    SELECT parent.value->>'tensor_id' AS parent_id
    FROM lineage_metadata lm, jsonb_array_elements(lm.data->'parent_tensors') AS parent
    WHERE lm.tensor_id = %(tensor_id)s;
    """
    rows = self._execute_query(query, {"tensor_id": tensor_id}, fetch="all")
    # Null/missing tensor_id entries are skipped.
    return [UUID(entry['parent_id']) for entry in rows if entry['parent_id']]
|
| 406 |
+
|
| 407 |
+
def get_child_tensor_ids(self, tensor_id: UUID) -> List[UUID]:
    """Return the ids of tensors that list *tensor_id* as a parent.

    Scans every lineage_metadata row's 'parent_tensors' JSONB array.
    """
    query = """
    SELECT lm.tensor_id
    FROM lineage_metadata lm, jsonb_array_elements(lm.data->'parent_tensors') AS parent
    WHERE parent.value->>'tensor_id' = %(target_parent_id)s;
    """
    # JSONB text comparison requires the UUID as a string.
    rows = self._execute_query(query, {"target_parent_id": str(tensor_id)}, fetch="all")
    return [entry['tensor_id'] for entry in rows]
|
| 417 |
+
|
| 418 |
+
def search_tensor_descriptors(self, text_query: str, fields: List[str]) -> List[TensorDescriptor]:
    """Case-insensitive substring search across descriptor fields.

    Each entry of *fields* selects where to search:
      - a bare column name ("owner", "tags") on tensor_descriptors;
      - "semantic" or "semantic.<column>" on semantic_metadata_entries;
      - "<kind>" or "<kind>.<json.path>" for the extended JSONB tables
        (lineage, computational, quality, relational, usage).

    Conditions are OR-combined; matches are returned as re-hydrated
    TensorDescriptor models.  Returns [] when *fields* is empty.
    """
    if not fields:
        return []

    # Base query
    select_clause = "SELECT DISTINCT td.* FROM tensor_descriptors td"
    joins: List[str] = []
    where_conditions: List[str] = []
    query_params: Dict[str, Any] = {"text_query": f"%{text_query}%"}  # For ILIKE

    # Helper to add joins only once
    joined_aliases = {"td"}

    for field_path in fields:
        # Split off an optional table prefix: "lineage.source.identifier"
        # -> prefix "lineage", suffix "source.identifier".
        parts = field_path.split('.', 1)
        field_prefix = parts[0]
        field_suffix = parts[1] if len(parts) > 1 else None

        # Default to tensor_descriptors table if no prefix or prefix is 'tensor_descriptor'
        table_alias = "td"
        column_or_json_path = field_prefix  # If no suffix, field_prefix is the column name

        if field_prefix == "semantic":
            if "sm" not in joined_aliases:
                joins.append("LEFT JOIN semantic_metadata_entries sm ON td.tensor_id = sm.tensor_id")
                joined_aliases.add("sm")
            table_alias = "sm"
            column_or_json_path = field_suffix if field_suffix else "name"  # Default search semantic name
        elif field_prefix in ["lineage", "computational", "quality", "relational", "usage"]:
            # Assumes extended metadata tables are named e.g. "lineage_metadata"
            # and have a JSONB 'data' column.
            ext_table_name = f"{field_prefix}_metadata"
            # Alias is first letter + "m" (lm, cm, qm, rm, um) — distinct
            # for the five prefixes above.
            ext_alias = f"{field_prefix[0]}m"
            if ext_alias not in joined_aliases:
                joins.append(f"LEFT JOIN {ext_table_name} {ext_alias} ON td.tensor_id = {ext_alias}.tensor_id")
                joined_aliases.add(ext_alias)
            table_alias = ext_alias
            # Construct JSON path, e.g., data->'source'->>'identifier'.
            # e.g. if field_suffix = "source.identifier", this becomes data->'source'->>'identifier'
            if field_suffix:
                json_path_parts = field_suffix.split('.')
                json_op_path = "->".join([f"'{p}'" for p in json_path_parts[:-1]])
                if json_op_path:
                    column_or_json_path = f"{table_alias}.data->{json_op_path}->>'{json_path_parts[-1]}'"
                else:  # Top level key in JSON
                    column_or_json_path = f"{table_alias}.data->>'{json_path_parts[-1]}'"
            else:  # Searching the whole JSON blob (less efficient, but possible)
                column_or_json_path = f"{table_alias}.data::text"  # Cast JSONB to text to search
        else:  # Assumed to be a direct column on tensor_descriptors
            column_or_json_path = field_path  # e.g. "owner" or "tags"

        # Add condition for this field.
        # For array fields like 'tags', use a different operator or unnest.
        if table_alias == "td" and column_or_json_path == "tags":
            # This is a basic way; for tags, often unnesting or specific array ops are better
            where_conditions.append(f"array_to_string({table_alias}.{column_or_json_path}, ' ') ILIKE %(text_query)s")
        else:
            where_conditions.append(f"{column_or_json_path} ILIKE %(text_query)s")

    if not where_conditions:
        return []  # No valid fields to search

    query = f"{select_clause} {' '.join(joins)} WHERE {' OR '.join(where_conditions)};"

    rows = self._execute_query(query, query_params, fetch="all")
    results = []
    for row in rows:
        # Re-hydrate enum/nested-model columns before building the model.
        data = dict(row)
        data['data_type'] = DataType(data['data_type'])
        data['storage_format'] = StorageFormat(data['storage_format'])
        if data.get('access_control'): data['access_control'] = AccessControl(**data['access_control'])
        if data.get('compression_info'): data['compression_info'] = CompressionInfo(**data['compression_info'])
        results.append(TensorDescriptor(**data))
    return results
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def aggregate_tensor_descriptors(self, group_by_field: str, agg_function: str, agg_field: Optional[str] = None) -> Dict[Any, Any]:
    """Group tensor descriptors by a direct column and aggregate.

    Simplified initial implementation: only COUNT over direct
    tensor_descriptors columns is supported.  A full implementation needs
    dynamic JOINs and path resolution similar to search.

    Raises:
        ValueError: if group_by_field is not a plain descriptor field.
        NotImplementedError: for any agg_function other than "count".
    """
    # SECURITY FIX: group_by_field is interpolated into the SQL below, so
    # it must be a single plain identifier.  The previous check only
    # validated the text before the first '.', letting strings such as
    # "owner; DROP TABLE ..." reach the query.
    if not group_by_field or not group_by_field.isidentifier():
        raise ValueError(f"Invalid group_by_field: {group_by_field}")
    if not hasattr(TensorDescriptor, group_by_field):
        raise ValueError(f"Invalid group_by_field: {group_by_field}")

    # Only supporting count for now.
    if agg_function.lower() != "count":
        raise NotImplementedError(f"Aggregation function '{agg_function}' not yet fully implemented for all fields.")

    # Safe to interpolate: validated as an identifier and a model field.
    sql_group_by_field = group_by_field
    query = f"SELECT {sql_group_by_field}, COUNT(*) as count FROM tensor_descriptors GROUP BY {sql_group_by_field};"

    rows = self._execute_query(query, {}, fetch="all")  # type: ignore
    return {row[sql_group_by_field]: row['count'] for row in rows}
|
| 514 |
+
|
| 515 |
+
# --- Export/Import Methods ---
|
| 516 |
+
def get_export_data(self, tensor_ids: Optional[List[UUID]] = None) -> 'TensorusExportData':  # type: ignore
    """Assemble a full export bundle (not yet implemented for Postgres).

    A complete implementation would: determine the tensor_ids to export
    (all when None); for each id fetch the TensorDescriptor, its
    SemanticMetadata list, and every extended JSONB document (lineage,
    computational, quality, relational, usage); build one
    TensorusExportEntry per tensor; and assemble a TensorusExportData.
    """
    raise NotImplementedError("get_export_data is not yet fully implemented for PostgresMetadataStorage.")
|
| 530 |
+
|
| 531 |
+
def import_data(self, data: 'TensorusExportData', conflict_strategy: str = "skip") -> Dict[str, int]:  # type: ignore
    """Import an export bundle (not yet implemented for Postgres).

    A complete implementation must, per entry and ideally inside one
    transaction:
      1. TensorDescriptor — "skip": INSERT ... ON CONFLICT DO NOTHING;
         "overwrite": DELETE (CASCADE clears children) then INSERT, or a
         full-field UPSERT.
      2. SemanticMetadata (one-to-many) — delete existing entries first
         when overwriting, then insert, handling (tensor_id, name)
         conflicts.
      3. Extended JSONB tables (one-to-one) — INSERT ... ON CONFLICT
         (tensor_id) DO UPDATE SET data = EXCLUDED.data, or DO NOTHING
         when skipping.
    """
    raise NotImplementedError("import_data is not yet fully implemented for PostgresMetadataStorage.")
|
| 552 |
+
|
| 553 |
+
# --- Analytics Methods (Postgres Implementations) ---
|
| 554 |
+
def get_co_occurring_tags(self, min_co_occurrence: int = 2, limit: int = 10) -> Dict[str, List[Dict[str, Any]]]:
    """Return, for each tag, the tags it most often co-occurs with.

    Unnests tags per tensor, forms unordered tag pairs
    (LEAST/GREATEST keeps each pair canonical), counts pairs in SQL, then
    builds the nested mapping and applies *limit* in Python.

    Note: can be expensive on large datasets without extra optimization.
    """
    # FIX: a second, never-executed multi-CTE query string used to be
    # built here as dead code; only the query below was ever run.
    query = """
    WITH tensor_tags AS (
        SELECT tensor_id, unnest(tags) AS tag FROM tensor_descriptors WHERE cardinality(tags) >= 2
    ),
    tag_pairs AS (
        SELECT LEAST(t1.tag, t2.tag) AS tag_a, GREATEST(t1.tag, t2.tag) AS tag_b
        FROM tensor_tags t1 JOIN tensor_tags t2 ON t1.tensor_id = t2.tensor_id AND t1.tag < t2.tag
    )
    SELECT tag_a, tag_b, COUNT(*) AS co_occurrence_count
    FROM tag_pairs GROUP BY tag_a, tag_b
    HAVING COUNT(*) >= %(min_co_occurrence)s
    ORDER BY tag_a, co_occurrence_count DESC;
    """
    params = {"min_co_occurrence": min_co_occurrence, "limit": limit}  # limit applied in Python

    rows = self._execute_query(query, params, fetch="all")  # type: ignore

    co_occurrence_map: Dict[str, List[Dict[str, Any]]] = {}
    if rows:
        for row in rows:
            tag_a, tag_b, count = row['tag_a'], row['tag_b'], row['co_occurrence_count']
            # Record the pair under both tags, capped at *limit* entries each.
            if tag_a not in co_occurrence_map:
                co_occurrence_map[tag_a] = []
            if len(co_occurrence_map[tag_a]) < limit:
                co_occurrence_map[tag_a].append({"tag": tag_b, "count": count})
            if tag_b not in co_occurrence_map:
                co_occurrence_map[tag_b] = []
            if len(co_occurrence_map[tag_b]) < limit:
                co_occurrence_map[tag_b].append({"tag": tag_a, "count": count})

        # tag_a lists arrive sorted from SQL, but tag_b lists do not.
        for tag_key in co_occurrence_map:
            co_occurrence_map[tag_key].sort(key=lambda x: x["count"], reverse=True)

    # Drop tags that ended up with no qualifying co-occurrences.
    return {k: v for k, v in co_occurrence_map.items() if v}
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def get_stale_tensors(self, threshold_days: int, limit: int = 100) -> List[TensorDescriptor]:
    """Return descriptors not touched within *threshold_days*, oldest first.

    "Touched" means usage_metadata's last_accessed_at when present,
    otherwise the descriptor's last_modified_timestamp.
    """
    query = """
    SELECT td.*
    FROM tensor_descriptors td
    LEFT JOIN usage_metadata um ON td.tensor_id = um.tensor_id
    WHERE COALESCE( (um.data->>'last_accessed_at')::TIMESTAMPTZ, td.last_modified_timestamp ) < (NOW() - INTERVAL '1 day' * %(threshold_days)s)
    ORDER BY COALESCE( (um.data->>'last_accessed_at')::TIMESTAMPTZ, td.last_modified_timestamp ) ASC
    LIMIT %(limit)s;
    """
    rows = self._execute_query(query, {"threshold_days": threshold_days, "limit": limit}, fetch="all")  # type: ignore

    stale: List[TensorDescriptor] = []
    for record in rows or []:
        fields = dict(record)
        fields['data_type'] = DataType(fields['data_type'])
        fields['storage_format'] = StorageFormat(fields['storage_format'])
        if fields.get('access_control'):
            fields['access_control'] = AccessControl(**fields['access_control'])
        if fields.get('compression_info'):
            fields['compression_info'] = CompressionInfo(**fields['compression_info'])
        stale.append(TensorDescriptor(**fields))
    return stale
|
| 656 |
+
|
| 657 |
+
def get_complex_tensors(self, min_parent_count: Optional[int] = None, min_transformation_steps: Optional[int] = None, limit: int = 100) -> List[TensorDescriptor]:
    """Return descriptors whose lineage meets either complexity threshold.

    Criteria are OR-combined; at least one must be supplied.

    Raises:
        ValueError: when both criteria are None.
    """
    if min_parent_count is None and min_transformation_steps is None:
        raise ValueError("At least one criterion (min_parent_count or min_transformation_steps) must be provided.")

    params: Dict[str, Any] = {"limit": limit}
    conditions = []
    if min_parent_count is not None:
        conditions.append("jsonb_array_length(lm.data->'parent_tensors') >= %(min_parent_count)s")
        params["min_parent_count"] = min_parent_count
    if min_transformation_steps is not None:
        conditions.append("jsonb_array_length(lm.data->'transformation_history') >= %(min_transformation_steps)s")
        params["min_transformation_steps"] = min_transformation_steps

    base_query = "SELECT td.* FROM tensor_descriptors td LEFT JOIN lineage_metadata lm ON td.tensor_id = lm.tensor_id"
    query = f"{base_query} WHERE ({' OR '.join(conditions)}) LIMIT %(limit)s;"

    rows = self._execute_query(query, params, fetch="all")  # type: ignore
    complex_tds: List[TensorDescriptor] = []
    for record in rows or []:
        fields = dict(record)
        fields['data_type'] = DataType(fields['data_type'])
        fields['storage_format'] = StorageFormat(fields['storage_format'])
        if fields.get('access_control'):
            fields['access_control'] = AccessControl(**fields['access_control'])
        if fields.get('compression_info'):
            fields['compression_info'] = CompressionInfo(**fields['compression_info'])
        complex_tds.append(TensorDescriptor(**fields))
    return complex_tds
|
| 687 |
+
|
| 688 |
+
# --- Health and Metrics Methods (Postgres Implementations) ---
|
| 689 |
+
def check_health(self) -> tuple[bool, str]:
    """Probe the database with a trivial query.

    Returns (ok, "postgres"); never raises — failures are logged and
    reported as ok=False.
    """
    try:
        self._execute_query("SELECT 1;", fetch=None)
    except Exception as e:
        logger.error(f"Postgres health check failed: {e}")
        return False, "postgres"
    return True, "postgres"
|
| 696 |
+
|
| 697 |
+
def get_tensor_descriptors_count(self) -> int:
    """Return the total number of tensor descriptor rows."""
    result = self._execute_query("SELECT COUNT(*) as count FROM tensor_descriptors;", fetch="one")
    if not result:
        return 0
    return result['count']
|
| 701 |
+
|
| 702 |
+
def get_extended_metadata_count(self, metadata_model_name: str) -> int:
    """Count rows in the table backing the named metadata model.

    Returns 0 (and logs a warning) when the model name has no mapped
    table.
    """
    # Pydantic model name -> backing table (snake_case convention).
    table_name_map = {
        "LineageMetadata": "lineage_metadata",
        "ComputationalMetadata": "computational_metadata",
        "QualityMetadata": "quality_metadata",
        "RelationalMetadata": "relational_metadata",
        "UsageMetadata": "usage_metadata",
        "SemanticMetadata": "semantic_metadata_entries",  # one-to-many table
    }
    table_name = table_name_map.get(metadata_model_name)
    if table_name is None:
        logger.warning(
            f"get_extended_metadata_count called for unmapped model name '{metadata_model_name}' in Postgres."
        )
        return 0

    # table_name comes only from the fixed map above, never user input.
    result = self._execute_query(f"SELECT COUNT(*) as count FROM {table_name};", fetch="one")
    return result['count'] if result else 0
|
| 725 |
+
|
| 726 |
+
def clear_all_data(self) -> None:
    """Delete every row from all metadata tables (test/reset helper).

    Child tables are emptied before tensor_descriptors to respect
    foreign keys.  A real deployment might TRUNCATE or use a test DB.
    """
    for table in (
        "semantic_metadata_entries",
        "lineage_metadata",
        "computational_metadata",
        "quality_metadata",
        "relational_metadata",
        "usage_metadata",
        "tensor_descriptors",
    ):
        self._execute_query(f"DELETE FROM {table};")
    logger.info("Postgres tables cleared (conceptually).")
|
| 736 |
+
|
| 737 |
+
def close_pool(self):
    """Close all pooled connections and drop the pool reference."""
    if not self.pool:
        return
    self.pool.closeall()
    self.pool = None
    logger.info("PostgreSQL connection pool closed.")
|
tensorus/metadata/schemas.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
from typing import List, Dict, Optional, Any
|
| 3 |
+
from uuid import UUID, uuid4
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel, Field, field_validator, model_validator, ValidationInfo
|
| 7 |
+
|
| 8 |
+
class DataType(str, Enum):
    """Element data types supported for stored tensors.

    Inherits from ``str`` so enum values serialize directly to JSON.
    """
    FLOAT32 = "float32"
    FLOAT64 = "float64"
    FLOAT16 = "float16"
    INT32 = "int32"
    INT64 = "int64"
    INT16 = "int16"
    INT8 = "int8"
    UINT8 = "uint8"
    BOOLEAN = "boolean"
    STRING = "string"
    COMPLEX64 = "complex64"
    COMPLEX128 = "complex128"
    OTHER = "other"  # catch-all for dtypes not enumerated above
|
| 22 |
+
|
| 23 |
+
class StorageFormat(str, Enum):
    """Serialized/on-disk representation of a tensor's payload.

    Inherits from ``str`` so enum values serialize directly to JSON.
    """
    RAW = "raw"
    NUMPY_NPZ = "numpy_npz"
    HDF5 = "hdf5"
    COMPRESSED_ZLIB = "compressed_zlib"
    COMPRESSED_GZIP = "compressed_gzip"
    CUSTOM = "custom"  # escape hatch for formats not listed here
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class AccessControl(BaseModel):
    """Per-tensor access-control lists.

    Each list holds the user or service identifiers granted that
    operation on the tensor.
    """
    read: List[str] = Field(default_factory=list)
    write: List[str] = Field(default_factory=list)
    delete: List[str] = Field(default_factory=list)
    # Permission string for the owner, e.g. "rwd"; semantics are caller-defined.
    owner_permissions: Optional[str] = None
    # Per-group permission strings, e.g. {"group_name": "rw"}.
    group_permissions: Optional[Dict[str, str]] = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class CompressionInfo(BaseModel):
    """Describes how a tensor's payload is compressed."""
    algorithm: str  # compression algorithm identifier
    level: Optional[int] = None  # algorithm-specific compression level
    settings: Optional[Dict[str, Any]] = None  # any other algorithm settings
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TensorDescriptor(BaseModel):
    """Core metadata record describing a stored tensor.

    Captures structural facts (shape, dtype, storage format), ownership
    and access control, payload bookkeeping (size, checksum, compression)
    and timestamps. The extended metadata families (lineage, quality,
    usage, ...) reference this record via ``tensor_id``.
    """
    tensor_id: UUID = Field(default_factory=uuid4)
    dimensionality: int = Field(..., ge=0)  # number of axes; must equal len(shape)
    shape: List[int]
    data_type: DataType
    storage_format: StorageFormat = StorageFormat.RAW
    # NOTE(review): datetime.utcnow returns *naive* datetimes and is deprecated
    # in Python 3.12+; consider datetime.now(timezone.utc) — verify stored data
    # and comparisons elsewhere before changing.
    creation_timestamp: datetime = Field(default_factory=datetime.utcnow)
    last_modified_timestamp: datetime = Field(default_factory=datetime.utcnow)
    owner: str  # user or service ID that owns the tensor
    access_control: AccessControl = Field(default_factory=AccessControl)
    byte_size: int = Field(..., ge=0)  # payload size in bytes
    checksum: Optional[str] = None  # e.g. md5 or sha256 digest
    compression_info: Optional[CompressionInfo] = None
    tags: Optional[List[str]] = Field(default_factory=list)
    # Generic free-form metadata; allows richer values than plain strings.
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)

    @field_validator('shape', mode='before')
    def validate_shape(cls, v, info: ValidationInfo):
        """Check shape length against dimensionality and dims are non-negative ints.

        Runs in 'before' mode; ``dimensionality`` is present in ``info.data``
        because it is declared (hence validated) before ``shape``. The
        comparison is skipped when dimensionality is absent/None.
        """
        dimensionality = info.data.get('dimensionality')
        if dimensionality is not None and len(v) != dimensionality:
            raise ValueError('Shape must have a length equal to dimensionality')
        if not all(isinstance(dim, int) and dim >= 0 for dim in v):
            raise ValueError('All dimensions in shape must be non-negative integers')
        return v

    @field_validator('last_modified_timestamp')  # defaults to mode='after'
    def validate_last_modified(cls, v: datetime, info: ValidationInfo):
        """Reject a last-modified time earlier than the creation time.

        Runs after coercion, so ``v`` is a datetime. ``creation_timestamp``
        is read from ``info.data``; the isinstance guard deliberately skips
        the comparison if that field is not (yet) a datetime, since comparing
        across types would fail. Inter-field validation like this is safest
        in 'after' mode where field types are already settled.
        """
        creation_timestamp_from_data = info.data.get('creation_timestamp')
        if isinstance(creation_timestamp_from_data, datetime) and v < creation_timestamp_from_data:
            raise ValueError('Last modified timestamp cannot be before creation timestamp')
        return v

    def update_last_modified(self):
        """Stamp the record with the current UTC time (naive datetime)."""
        self.last_modified_timestamp = datetime.utcnow()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class SemanticMetadata(BaseModel):
    """A named, human-readable semantic annotation attached to a tensor.

    The link to the owning TensorDescriptor is explicit via ``tensor_id``
    so an instance can be validated and understood on its own.
    """
    tensor_id: UUID
    # Name of this specific annotation, e.g. "primary_class_label", "bounding_boxes".
    name: str
    description: str

    @field_validator('name', 'description')
    def check_not_empty(cls, v: str) -> str:
        """Reject values that are empty or whitespace-only."""
        stripped = v.strip() if v else ''
        if not stripped:
            raise ValueError('Name and description fields cannot be empty or just whitespace.')
        return v
|
| 104 |
+
|
| 105 |
+
# --- Extended Schemas ---
|
| 106 |
+
|
| 107 |
+
# Part of LineageMetadata
|
| 108 |
+
class LineageSourceType(str, Enum):
    """Broad category of where a tensor's data originated (part of LineageMetadata)."""
    FILE = "file"
    API = "api"
    COMPUTATION = "computation"
    DATABASE = "database"
    STREAM = "stream"
    USER_UPLOAD = "user_upload"
    SYNTHETIC = "synthetic"
    OTHER = "other"  # catch-all for unlisted source kinds
|
| 117 |
+
|
| 118 |
+
class LineageSource(BaseModel):
    """Origin of a tensor's data (part of LineageMetadata)."""
    type: LineageSourceType
    identifier: str  # e.g. file path, API endpoint URL, query string, stream topic
    # Extra source context, e.g. API request params or version of the source data.
    details: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
| 122 |
+
|
| 123 |
+
# Part of LineageMetadata
|
| 124 |
+
class ParentTensorLink(BaseModel):
    """Link to a parent tensor this one derives from (part of LineageMetadata)."""
    tensor_id: UUID
    relationship: Optional[str] = None  # e.g. "transformed_from", "derived_from", "aggregated_from"
|
| 127 |
+
|
| 128 |
+
# Part of LineageMetadata
|
| 129 |
+
class TransformationStep(BaseModel):
    """One operation in a tensor's transformation history (part of LineageMetadata)."""
    operation: str  # e.g. "normalize", "resize", "fft", "model_inference"
    parameters: Optional[Dict[str, Any]] = Field(default_factory=dict)
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    operator: Optional[str] = None  # user or service that performed the operation
    software_version: Optional[str] = None  # e.g. library version used for the operation
|
| 135 |
+
|
| 136 |
+
# Part of LineageMetadata
|
| 137 |
+
class VersionControlInfo(BaseModel):
    """Optional VCS coordinates associated with a tensor (part of LineageMetadata)."""
    repository: Optional[str] = None  # URL of the repository
    commit_hash: Optional[str] = None
    branch: Optional[str] = None
    tag: Optional[str] = None
    path_in_repo: Optional[str] = None  # path within the repository, if applicable
|
| 143 |
+
|
| 144 |
+
class LineageMetadata(BaseModel):
    """Provenance record: where a tensor came from and how it evolved."""
    tensor_id: UUID  # links to the TensorDescriptor
    source: Optional[LineageSource] = None
    parent_tensors: List[ParentTensorLink] = Field(default_factory=list)
    transformation_history: List[TransformationStep] = Field(default_factory=list)
    version: Optional[str] = None  # version string for this tensor instance
    version_control: Optional[VersionControlInfo] = None
    # Catch-all for other, unstructured provenance information.
    provenance: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
| 152 |
+
|
| 153 |
+
class ComputationalMetadata(BaseModel):
    """Describes how a tensor was computed: algorithm, environment, and cost."""
    tensor_id: UUID
    algorithm: Optional[str] = None  # e.g. "ResNet50", "PCA", "ARIMA"
    parameters: Optional[Dict[str, Any]] = Field(default_factory=dict)  # algorithm parameters
    # Reference to a stored computational graph (e.g. ONNX model path, DVC stage).
    computational_graph_ref: Optional[str] = None
    # e.g. OS, Python version, library versions.
    execution_environment: Optional[Dict[str, Any]] = Field(default_factory=dict)
    computation_time_seconds: Optional[float] = None
    hardware_info: Optional[Dict[str, Any]] = Field(default_factory=dict)  # e.g. CPU, GPU, RAM

    @field_validator('computation_time_seconds')
    def check_non_negative_time(cls, v):
        """Disallow negative durations; None (unset) is accepted."""
        if v is None:
            return v
        if v < 0:
            raise ValueError('Computation time cannot be negative')
        return v
|
| 167 |
+
|
| 168 |
+
# Part of QualityMetadata
|
| 169 |
+
class QualityStatistics(BaseModel):
    """Summary statistics of a tensor's values (part of QualityMetadata)."""
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    mean: Optional[float] = None
    std_dev: Optional[float] = None
    median: Optional[float] = None
    variance: Optional[float] = None
    percentiles: Optional[Dict[float, float]] = None  # e.g. {25: val, 50: val, 75: val}
    histogram: Optional[Dict[str, Any]] = None  # e.g. {"bins": [...], "counts": [...]}
|
| 178 |
+
|
| 179 |
+
# Part of QualityMetadata
|
| 180 |
+
class MissingValuesInfo(BaseModel):
    """Missing-value accounting for a tensor (part of QualityMetadata)."""
    count: int = Field(..., ge=0)
    percentage: float = Field(..., ge=0.0, le=100.0)  # expressed as 0–100, not 0–1
    strategy: Optional[str] = None  # e.g. "imputed_mean", "removed_rows"
|
| 184 |
+
|
| 185 |
+
# Part of QualityMetadata
|
| 186 |
+
class OutlierInfo(BaseModel):
    """Outlier accounting for a tensor (part of QualityMetadata)."""
    count: int = Field(..., ge=0)
    percentage: float = Field(..., ge=0.0, le=100.0)  # expressed as 0–100, not 0–1
    method_used: Optional[str] = None  # e.g. "IQR", "Z-score"
    severity: Optional[Dict[str, int]] = None  # e.g. {"mild": 10, "severe": 2}
|
| 191 |
+
|
| 192 |
+
class QualityMetadata(BaseModel):
    """Data-quality assessment attached to a tensor."""
    tensor_id: UUID
    statistics: Optional[QualityStatistics] = None
    missing_values: Optional[MissingValuesInfo] = None
    outliers: Optional[OutlierInfo] = None
    noise_level: Optional[float] = None  # could be an SNR or a qualitative score
    confidence_score: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    # e.g. {"schema_conformity": True, "range_checks": "passed"}
    validation_results: Optional[Dict[str, Any]] = Field(default_factory=dict)
    drift_score: Optional[float] = None  # data drift relative to a reference
|
| 201 |
+
|
| 202 |
+
# Part of RelationalMetadata
|
| 203 |
+
class RelatedTensorLink(BaseModel):
    """Non-lineage relationship to another tensor (part of RelationalMetadata)."""
    related_tensor_id: UUID
    # e.g. "augmentation_of", "projection_of", "component_of", "alternative_view"
    relationship_type: str
    details: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
| 207 |
+
|
| 208 |
+
class RelationalMetadata(BaseModel):
    """How a tensor relates to other tensors, collections, and datasets."""
    tensor_id: UUID
    related_tensors: List[RelatedTensorLink] = Field(default_factory=list)
    collections: List[str] = Field(default_factory=list)  # collection names/IDs this tensor belongs to
    # Tensors this one directly depends on (not necessarily lineage parents).
    dependencies: List[UUID] = Field(default_factory=list)
    dataset_context: Optional[str] = None  # name or ID of the dataset this tensor is part of
|
| 214 |
+
|
| 215 |
+
# Part of UsageMetadata
|
| 216 |
+
class UsageAccessRecord(BaseModel):
    """A single access event for a tensor (part of UsageMetadata)."""
    accessed_at: datetime = Field(default_factory=datetime.utcnow)
    user_or_service: str
    operation_type: str  # e.g. "read", "write", "query", "transform", "visualize"
    details: Optional[Dict[str, Any]] = Field(default_factory=dict)  # e.g. query params, sub-selection info
    status: Optional[str] = "success"  # "success" or "failure"
|
| 222 |
+
|
| 223 |
+
class UsageMetadata(BaseModel):
    """Tracks how, how often, and by whom a tensor is accessed."""
    tensor_id: UUID
    access_history: List[UsageAccessRecord] = Field(default_factory=list)
    # Total recorded accesses; kept in sync with access_history (see validator).
    usage_frequency: Optional[int] = Field(default=0, ge=0)
    # Explicitly tracked, or derived from access_history by the validator.
    last_accessed_at: Optional[datetime] = None
    application_references: List[str] = Field(default_factory=list)  # apps/models using this tensor
    purpose: Optional[Dict[str, str]] = Field(default_factory=dict)  # e.g. {"training_model_X": "feature_set_A"}

    @model_validator(mode='after')
    def sync_derived_usage_fields(self) -> 'UsageMetadata':
        """Keep ``usage_frequency`` and ``last_accessed_at`` consistent with ``access_history``.

        Runs after field validation, so the history is a list of fully
        validated UsageAccessRecord objects.
        """
        history = self.access_history
        if not history:
            # No recorded accesses: frequency is forced to 0 (also handles an
            # explicitly supplied empty list); last_accessed_at keeps whatever
            # value (or None) the caller provided.
            self.usage_frequency = 0
            return self
        self.usage_frequency = len(history)
        newest = max(record.accessed_at for record in history)
        # Only move last_accessed_at forward, never backwards.
        if self.last_accessed_at is None or newest > self.last_accessed_at:
            self.last_accessed_at = newest
        return self
|
| 250 |
+
|