diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..b23554198468a191fc51460b3efd199238bd957d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,38 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/checks-passed.png filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/deploy/dark_deploy_python.png filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/deploy/dark_deploy_ts.png filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/deploy/light_deploy_python.png filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/deploy/light_deploy_ts.png filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/evals/evals.mp4 filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/0_dark.png filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/0_light.png filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/1_dark.png filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/1_light.png filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/3_dark.png filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/3_light.png filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/4_dark.png filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/4_light.png filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/5_dark.gif filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/5_light.gif filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/6_dark.mp4 filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/6_light.mp4 filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/7_dark.mp4 filter=lfs diff=lfs 
merge=lfs -text +pyspur/docs/images/example_walkthrough/7_light.mp4 filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/8_dark.mp4 filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/8_light.mp4 filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/9_dark.mp4 filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/example_walkthrough/9_light.mp4 filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/hero-dark.mp4 filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/hero-light.mp4 filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/hero.png filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/rag/rag1.mp4 filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/rag/rag2.mp4 filter=lfs diff=lfs merge=lfs -text +pyspur/docs/images/rag/rag3.mp4 filter=lfs diff=lfs merge=lfs -text +pyspur/frontend/public/images/firecrawl.png filter=lfs diff=lfs merge=lfs -text +pyspur/frontend/public/images/google_sheets.png filter=lfs diff=lfs merge=lfs -text +pyspur/frontend/public/images/meta.png filter=lfs diff=lfs merge=lfs -text +pyspur/frontend/public/images/slack.png filter=lfs diff=lfs merge=lfs -text +pyspur/frontend/public/pyspur-black.png filter=lfs diff=lfs merge=lfs -text diff --git a/pyspur/.cursor/rules/frontend-api-calls.mdc b/pyspur/.cursor/rules/frontend-api-calls.mdc new file mode 100644 index 0000000000000000000000000000000000000000..83f92b06375547684ffbec2d2fb8be6a6c786ab5 --- /dev/null +++ b/pyspur/.cursor/rules/frontend-api-calls.mdc @@ -0,0 +1,6 @@ +--- +description: API calls in frontend +globs: +alwaysApply: false +--- +API calls inside the frontend should always be stored inside [api.ts](mdc:frontend/src/utils/api.ts) and use the API_BASE_URL defined there \ No newline at end of file diff --git a/pyspur/.devcontainer/.bashrc b/pyspur/.devcontainer/.bashrc new file mode 100644 index 0000000000000000000000000000000000000000..2b47d530fd585bae288e7039ce30eb26988f00ac --- /dev/null +++ 
b/pyspur/.devcontainer/.bashrc @@ -0,0 +1,21 @@ +# Enable bash completion +if [ -f /etc/bash_completion ]; then + . /etc/bash_completion +fi + +# Docker compose aliases +alias dcup='docker compose -f docker-compose.dev.yml up --build -d' +alias dlogb='docker logs -f pyspur-backend-1 --since 5m' +alias dlogf='docker logs -f pyspur-frontend-1 --since 5m' +alias dlogn='docker logs -f pyspur-nginx-1 --since 5m' +alias dlogs='docker compose logs -f --since 5m' + +# Test frontend build in temporary container +alias tfeb='docker build --target production -f Dockerfile.frontend \ + --no-cache -t temp-frontend-build . && \ + echo "✅ Frontend build successful!" && \ + docker rmi temp-frontend-build || \ + echo "❌ Frontend build failed!"' + +# Add color to the terminal +export PS1='\[\033[01;32m\]\u@\h\[\033[00m\]:\[\033[01;34m\]\w\[\033[00m\]\$ ' \ No newline at end of file diff --git a/pyspur/.devcontainer/Dockerfile b/pyspur/.devcontainer/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..3aab23cef3d9fba63ca95e0d1ea417233f847262 --- /dev/null +++ b/pyspur/.devcontainer/Dockerfile @@ -0,0 +1,28 @@ +# Base stage +FROM python:3.12 AS base +WORKDIR /pyspur + +# Install bash completion and terminal editors (nano, vim) +RUN apt-get update && apt-get install -y --no-install-recommends \ + bash-completion \ + nano \ + vim \ + && rm -rf /var/lib/apt/lists/* + +RUN pip install uv + +COPY backend/ backend/ +RUN uv pip install --system -e "/pyspur/backend/[dev]" + +# Install Node.js for frontend development +RUN curl -fsSL https://deb.nodesource.com/setup_23.x | bash - \ + && apt-get install -y nodejs \ + && npm install -g npm@latest + +# Development stage +FROM base AS development +WORKDIR /pyspur/frontend +COPY frontend/package*.json ./ +RUN npm install + +WORKDIR /pyspur \ No newline at end of file diff --git a/pyspur/.devcontainer/README.md b/pyspur/.devcontainer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9f4147629a525d2d96ac8ce2ed5829b3c905daf9 --- /dev/null +++ 
b/pyspur/.devcontainer/README.md @@ -0,0 +1,130 @@ +# Development Container Configuration + +[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/pyspur-dev/pyspur) + +This directory contains configuration files for Visual Studio Code Dev Containers / GitHub Codespaces. Dev containers provide a consistent, isolated development environment for this project. + +## Contents + +- `devcontainer.json` - The main configuration file that defines the development container settings +- `Dockerfile` - Defines the container image and development environment + +## Usage + +### Prerequisites + +- Visual Studio Code +- Docker installation: + - Docker Desktop (Windows/macOS) + - Docker Engine (Linux) +- [Remote - Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) extension for VS Code + +### Getting Started + +1. Open this project in Visual Studio Code +2. When prompted, click "Reopen in Container" + - Alternatively, press `F1` and select "Remote-Containers: Reopen in Container" +3. Wait for the container to build and initialize +4. Launch the application using: + ```bash + dcup + ``` +5. Access the application (assuming the ports are forwarded as is to the host machine) + - Main application: http://localhost:6080 + - Frontend development server: http://localhost:3000 + - Backend API: http://localhost:8000 + +The development environment will be automatically configured with all necessary tools and extensions. + +### Viewing Logs + +You can monitor the application logs using these commands: + +- View all container logs: + ```bash + dlogs + ``` +- View backend logs only: + ```bash + dlogb + ``` +- View frontend logs only: + ```bash + dlogf + ``` +- View nginx logs only: + ```bash + dlogn + ``` + +All log commands show the last 5 minutes of logs and continue to tail new entries. + +### Modifying the database schemas + + +1. **Stop Containers** + ```bash + docker compose down + ``` + +2. 
**Generate a Migration** + ```bash + ./generate_migrations.sh 002 + ``` + - Migration file appears in `./backend/app/models/management/alembic/versions/` with prefix `002_...`. + +3. **Review the Generated Script** + - Open the file to ensure it has the intended changes. + +4. **Apply the Migration** + ```bash + docker compose down + docker compose up --build + ``` + - Alembic applies the new migration automatically on startup. + +5. **Test the App** + - Confirm new tables/columns work as expected. + +6. **Commit & Push** + ```bash + git add . + git commit -m "Add migration 002 <description>" + git push origin + ``` + +### Troubleshooting DB issues + +When modifying the DB models, be careful not to destroy the local DB because of missing migrations. + +Sometimes the local dev DB gets corrupted. In such cases, assuming it does not contain production data, the quickest fix is to simply delete it and let the backend rebuild it the next time you run `docker compose up` (or `dcup`). + +You can do so by running + +```bash +docker volume rm pyspur_postgres_data +``` + +## Customization + +You can customize the development environment by: + +- Modifying `devcontainer.json` to: + - Add VS Code extensions + - Set container-specific settings + - Configure environment variables +- Updating the `Dockerfile` to: + - Install additional packages + - Configure system settings + - Add development tools + +## Troubleshooting + +If you encounter issues: + +1. Rebuild the container: `F1` → "Remote-Containers: Rebuild Container" +2. Check Docker logs for build errors +3. Verify Docker Desktop is running +4. Ensure all prerequisites are installed + +For more information, see the [VS Code Remote Development documentation](https://code.visualstudio.com/docs/remote/containers). 
\ No newline at end of file diff --git a/pyspur/.devcontainer/devcontainer.json b/pyspur/.devcontainer/devcontainer.json new file mode 100644 index 0000000000000000000000000000000000000000..000fae0c65f587ee1bccba13e5b41e5e5b8fe4e3 --- /dev/null +++ b/pyspur/.devcontainer/devcontainer.json @@ -0,0 +1,146 @@ +{ + "name": "PySpur Development", + + "dockerComposeFile": [ + "./docker-compose.yml" + ], + + "service": "devdocker", + + "runServices": ["devdocker"], + + "workspaceFolder": "/pyspur", + + "features": { + "ghcr.io/devcontainers/features/docker-in-docker:2": { + "version": "latest", + "moby": true + } + }, + + "customizations": { + "vscode": { + "extensions": [ + "github.copilot", + "github.copilot-chat", + // Backend extensions + "ms-python.python", + "charliermarsh.ruff", + "tamasfe.even-better-toml", + // Frontend extensions + "dbaeumer.vscode-eslint", + "esbenp.prettier-vscode", + "ms-vscode.vscode-typescript-next" + ], + "settings": { + // Git settings + // bypass pre-commit hooks not allowed + "git.allowNoVerifyCommit": false, + + // Python analysis settings + "python.analysis.autoImportCompletions": true, + "python.analysis.autoImportUserSymbols": true, + "python.analysis.importFormat": "relative", + "python.analysis.typeCheckingMode": "strict", + "python.defaultInterpreterPath": "/usr/local/bin/python", + + // Python linting and formatting + "python.linting.enabled": true, + "python.linting.mypyEnabled": false, + "python.linting.ruffEnabled": true, + + // TypeScript settings + "typescript.tsdk": "/pyspur/frontend/node_modules/typescript/lib", + "typescript.preferences.importModuleSpecifier": "non-relative", + "typescript.preferences.projectRoot": "/pyspur/frontend", + "npm.packageManager": "npm", + + // Editor formatting settings + "editor.formatOnSave": true, + "editor.defaultFormatter": "esbenp.prettier-vscode", + + // Language specific editor settings + "[python]": { + "editor.formatOnType": true, + "editor.formatOnSave": true, + 
"editor.defaultFormatter": "charliermarsh.ruff", + "editor.codeActionsOnSave": { + "source.organizeImports": "always", + "source.fixAll.ruff": "always" + } + }, + "[typescript]": { + "editor.defaultFormatter": "esbenp.prettier-vscode", + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll.eslint": "explicit", + "source.organizeImports": "explicit" + } + }, + "[typescriptreact]": { + "editor.defaultFormatter": "esbenp.prettier-vscode", + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll.eslint": "explicit", + "source.organizeImports": "explicit" + } + }, + "[javascript]": { + "editor.defaultFormatter": "esbenp.prettier-vscode", + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll.eslint": "explicit", + "source.organizeImports": "explicit" + } + }, + "[javascriptreact]": { + "editor.defaultFormatter": "esbenp.prettier-vscode", + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll.eslint": "explicit", + "source.organizeImports": "explicit" + } + }, + "[json]": { + "editor.quickSuggestions": { + "strings": true + }, + "editor.suggest.insertMode": "replace", + "editor.formatOnSave": true, + "editor.defaultFormatter": "esbenp.prettier-vscode" + }, + "[shellscript]": { + "editor.formatOnSave": true, + "editor.defaultFormatter": "esbenp.prettier-vscode" + }, + "[yaml]": { + "editor.insertSpaces": true, + "editor.tabSize": 2, + "editor.autoIndent": "advanced", + "diffEditor.ignoreTrimWhitespace": false, + "editor.formatOnSave": true, + "editor.defaultFormatter": "esbenp.prettier-vscode" + }, + "prettier.configPath": "/pyspur/frontend/.prettierrc" + } + } + }, + "remoteUser": "root", + "shutdownAction": "none", + "forwardPorts": [6080, "backend:8000", "frontend:3000"], + "portsAttributes": { + "frontend:3000" :{ + "label": "frontend", + "onAutoForward": "silent" + }, + "backend:8000" :{ + "label": "backend", + "onAutoForward": "silent" + }, + "6080" :{ + "label": "app", 
+ "onAutoForward": "silent" + } + }, + "postCreateCommand": "chmod +x .devcontainer/post-create.sh && .devcontainer/post-create.sh" +} diff --git a/pyspur/.devcontainer/docker-compose.yml b/pyspur/.devcontainer/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..82ce8f215528b3f53468f693abb2b60e21b3ddc7 --- /dev/null +++ b/pyspur/.devcontainer/docker-compose.yml @@ -0,0 +1,14 @@ +services: + devdocker: + build: + context: .. + dockerfile: .devcontainer/Dockerfile + target: development + volumes: + # Project files + - ../:/pyspur:cached + - ../.env:/pyspur/backend/.env:cached + - /pyspur/frontend/node_modules + environment: + - PYTHONPATH=/pyspur/backend + command: sleep infinity \ No newline at end of file diff --git a/pyspur/.devcontainer/post-create.sh b/pyspur/.devcontainer/post-create.sh new file mode 100644 index 0000000000000000000000000000000000000000..35b1b34e73eccae28a0eba651d328e4a416a8e57 --- /dev/null +++ b/pyspur/.devcontainer/post-create.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +# Install pre-commit hooks +uv pip install --system pre-commit==4.1.0 +pre-commit install + +# Check if package.json has changed and reinstall if needed +if [ -f /pyspur/frontend/package.json ]; then + cd /pyspur/frontend && npm install +fi + +# Add source command to main bashrc +echo ' +# Source custom settings +# Source custom bashrc settings if the file exists +if [ -f /pyspur/.devcontainer/.bashrc ]; then + source /pyspur/.devcontainer/.bashrc +fi' >> ~/.bashrc \ No newline at end of file diff --git a/pyspur/.dockerignore b/pyspur/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..d2cd9c833041923c413b77bbc6f510891dc72a60 --- /dev/null +++ b/pyspur/.dockerignore @@ -0,0 +1,88 @@ +# Version control +.git +.gitignore + +# Dependencies +**/node_modules +**/__pycache__ +**/*.pyc +**/*.pyo +**/*.pyd +**/*.so +**/.Python +**/env +**/venv +**/.env +**/.env.local +**/.env.development.local +**/.env.test.local 
+**/.env.production.local + +# Python specific +**/develop-eggs +**/eggs +**/.eggs +**/parts +**/sdist +**/var +**/wheels +**/*.egg-info +**/.installed.cfg +**/*.egg + +# Build outputs +**/dist +**/build +**/.next +**/out +**/*.egg-info + +# Development/IDE files +**/.idea +**/.vscode +**/.DS_Store +**/*.swp +**/*.swo + +# Docker files +**/Dockerfile* +**/.dockerignore +docker-compose*.yml + +# Test files +**/__tests__ +**/test +**/*.test.js +**/*.spec.js +**/*.test.py +**/*.spec.py +**/coverage +**/htmlcov + +# Documentation +**/*.md +**/docs + +# Logs +**/logs +**/*.log +**/npm-debug.log* +**/yarn-debug.log* +**/yarn-error.log* + +# Cache +**/.cache +**/.npm +**/.eslintcache +**/.pytest_cache +**/__pycache__ +**/.coverage + +# Data directories +**/data +**/uploads +**/downloads + +# Databases +**/*.db +**/sqlite/*.db \ No newline at end of file diff --git a/pyspur/.env.example b/pyspur/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..6ee599ec9feef73f90138acc227ce7d1801aad55 --- /dev/null +++ b/pyspur/.env.example @@ -0,0 +1,127 @@ +# ====================== +# Core Configuration +# ====================== + +# Environment +# ENVIRONMENT=development +ENVIRONMENT=production +PYTHONUNBUFFERED=1 # This is to prevent Python from buffering stdout and stderr +OAUTHLIB_INSECURE_TRANSPORT=1 # This is to allow OAuth2 to work with http + +# Version tag for Docker images in production +VERSION=latest + +# GitHub repository (username/repo-name) +GITHUB_REPOSITORY=pyspur-dev/pyspur + + +# ====================== +# Application Configuration +# ====================== + +# Application Host Configuration +# These are the host and port that the application will be running on +# By default, the application listens on all interfaces (0.0.0.0) on port 6080 + +PYSPUR_HOST=0.0.0.0 +PYSPUR_PORT=6080 + + +# ====================== +# Database Settings +# ====================== +# PySpur uses PostgreSQL as the database. 
By default, the database is hosted in a separate container. +# If you want to use an external database, you can provide the connection details here. +# PostgreSQL Configuration +POSTGRES_DB=pyspur +POSTGRES_USER=pyspur +POSTGRES_PASSWORD=pyspur +POSTGRES_HOST=db +POSTGRES_PORT=5432 + + +# ====================== +# Model Provider API Keys +# ====================== + +# OPENAI_API_KEY=your_openai_api_key +# GEMINI_API_KEY=your_gemini_api_key +# ANTHROPIC_API_KEY=your_anthropic_api_key + +# ====================== +# OpenAI API URL Configuration +# ====================== +# In case you are using OpenAI-compatible API service, you can specify the base URL of the API here +# OPENAI_API_BASE=https://api.openai.com/v1 + +# ====================== +# Ollama Configuration +# ====================== + +# NOTE: +# if the ollama service is running on port 11434 of the host machine, +# then use http://host.docker.internal:11434 as the base url +# if the ollama service is running on a different host, use the ip address or domain name of the host + +# Also make sure the ollama service is configured to accept requests. +# This can be done setting OLLAMA_HOST=0.0.0.0 environment variable before launching the ollama service. + +# OLLAMA_BASE_URL=http://host.docker.internal:11434 + + +# ====================== +# Azure OpenAI Configuration +# ====================== + +# AZURE_OPENAI_API_KEY=your_azure_openai_api_key +# AZURE_OPENAI_API_BASE=https://your-resource-name.openai.azure.com +# AZURE_OPENAI_API_VERSION=your_azure_openai_api_version +# AZURE_OPENAI_DEPLOYMENT_NAME=your_azure_openai_deployment_name +# ====================== + +# ====================== +# Google configuration +# ====================== + +# NEXT_PUBLIC_GOOGLE_CLIENT_ID=your_google_client_id # Google OAuth Client ID +# # This environment variable is used to configure Google OAuth for your application. +# # It should be set to the client id obtained from the Google Developer Console. 
+# # The prefix 'NEXT_PUBLIC_' is used to expose this variable to the frontend, +# # allowing client-side code to access it. + +# ====================== + +# ====================== +# GitHub configuration +# ====================== + +# GITHUB_ACCESS_TOKEN=your_github_access_token # GitHub Personal Access Token +# # This environment variable is used to configure GitHub OAuth for your application. +# # It should be set to the personal access token obtained from the GitHub Developer Settings. + +# ====================== + +# ====================== +# Firecrawl configuration +# ====================== + +# FIRECRAWL_API_KEY=your_firecrawl_api_key # Firecrawl API Key +# # This environment variable is used to configure Firecrawl API for your application. +# # It should be set to the API key obtained from the Firecrawl Developer Console. + +# ====================== + +# Frontend Configuration +# ====================== +# Usage Data +# ====================== +# We use PostHog to collect anonymous usage data for the PySpur UI. +# This helps us understand how our users are interacting with the application +# and improve the user experience. +# If you want to disable usage data collection, uncomment the following line: +# DISABLE_ANONYMOUS_TELEMETRY=true +# ====================== diff --git a/pyspur/.github/dependabot.yml b/pyspur/.github/dependabot.yml new file mode 100644 index 0000000000000000000000000000000000000000..20cb428819c146bad994fa34eb4fbf1449398d40 --- /dev/null +++ b/pyspur/.github/dependabot.yml @@ -0,0 +1,12 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. 
+# Please see the documentation for more information: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +# https://containers.dev/guide/dependabot + +version: 2 +updates: + - package-ecosystem: "devcontainers" + directory: "/" + schedule: + interval: weekly diff --git a/pyspur/.github/workflows/release.yml b/pyspur/.github/workflows/release.yml new file mode 100644 index 0000000000000000000000000000000000000000..bd9adb2f53f742fd43ee3050cc53c68bdb41c465 --- /dev/null +++ b/pyspur/.github/workflows/release.yml @@ -0,0 +1,72 @@ +name: Release + +on: + release: + types: [created] + +env: + REGISTRY: ghcr.io + BACKEND_IMAGE_NAME: ${{ github.repository }}-backend + +jobs: + build-and-push-docker: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + id-token: write # needed for PyPI publishing + outputs: + image_name: ${{ steps.meta-backend.outputs.tags }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + ref: ${{ github.event.release.tag_name }} + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to the Container registry + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata (tags, labels) for Backend + id: meta-backend + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.BACKEND_IMAGE_NAME }} + tags: | + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + + - name: Build and push Backend image + uses: docker/build-push-action@v6 + with: + context: . 
+ file: ./Dockerfile.backend + push: true + platforms: linux/amd64,linux/arm64 + target: production + tags: ${{ steps.meta-backend.outputs.tags }} + labels: ${{ steps.meta-backend.outputs.labels }} + + - name: Build Python package + run: | + # Create dist directory + mkdir -p dist + + # Build package using the container we just built - use first tag + DOCKER_TAG=$(echo "${{ steps.meta-backend.outputs.tags }}" | head -n1) + docker run --rm -v "$(pwd)/dist:/dist" "$DOCKER_TAG" sh -c "cd /pyspur/backend && uv build && cp dist/* /dist/" + + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + packages-dir: dist/ \ No newline at end of file diff --git a/pyspur/.gitignore b/pyspur/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..670c7055aa207dfa95dbb74d4d0c04c1466db3a3 --- /dev/null +++ b/pyspur/.gitignore @@ -0,0 +1,178 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +.DS_Store +.vscode + +# Ruff cache +**/.ruff_cache/ + + +# node_modules +**/node_modules/ +**/node_modules + +prd/ + +# package* in docs +docs/package* diff --git a/pyspur/.pre-commit-config.yaml b/pyspur/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8412409793e870766bbc59974adc40d5d1d07b28 --- /dev/null +++ b/pyspur/.pre-commit-config.yaml @@ -0,0 +1,26 @@ +repos: + - repo: local + hooks: + - id: backend-hooks + name: Backend Hooks + entry: pre-commit run --config backend/.pre-commit-config.yaml + language: system + pass_filenames: false + always_run: true + files: ^backend/ + + - id: frontend-hooks + name: Frontend Hooks + entry: bash -c 'cd frontend && npx lint-staged' + language: system + pass_filenames: false + always_run: true + files: ^frontend/ + + - id: frontend-hooks-cleanup + name: Cleanup files created by frontend hooks + entry: bash -c 'cd frontend && rm -f tsconfig.*.tsbuildinfo' + language: system + pass_filenames: false + always_run: true + files: ^frontend/ \ No newline at end of file diff --git 
a/pyspur/Dockerfile.backend b/pyspur/Dockerfile.backend new file mode 100644 index 0000000000000000000000000000000000000000..6a108e2298f19dc559b424549367c5ac782d9b0a --- /dev/null +++ b/pyspur/Dockerfile.backend @@ -0,0 +1,38 @@ +FROM python:3.12-slim AS base +RUN apt-get update && apt-get install -y \ + libpq-dev \ + gcc \ + curl \ + && rm -rf /var/lib/apt/lists/* + +RUN pip install uv +WORKDIR /pyspur/backend +COPY backend/pyproject.toml . +RUN uv pip compile pyproject.toml > requirements.txt && \ + uv pip install --system --no-cache-dir -r requirements.txt && \ + rm requirements.txt + + +# Development stage +FROM base AS development +ENV PYTHONPATH=/pyspur/backend +# Development-specific instructions here + +# Frontend build stage +FROM node:23-slim AS frontend-builder +WORKDIR /pyspur/frontend +COPY frontend/package*.json ./ +RUN npm ci +COPY frontend/ . +RUN npm run build + +# Production stage +FROM base AS production +ENV PYTHONPATH=/pyspur/backend +COPY backend/ . +# Replace any static files copied from backend/ with the fresh frontend build +RUN mkdir -p /pyspur/backend/pyspur/static && \ + rm -rf /pyspur/backend/pyspur/static/* +COPY --from=frontend-builder /pyspur/frontend/out/ /pyspur/backend/pyspur/static/ +COPY .env.example /pyspur/backend/pyspur/templates/.env.example +# Production-specific instructions here diff --git a/pyspur/Dockerfile.frontend b/pyspur/Dockerfile.frontend new file mode 100644 index 0000000000000000000000000000000000000000..d7141b580e83c6a53b3cc5331548e73cadd5a909 --- /dev/null +++ b/pyspur/Dockerfile.frontend @@ -0,0 +1,15 @@ +FROM node:23-slim AS base +WORKDIR /pyspur/frontend +COPY frontend/package*.json ./ + +# Development stage +FROM base AS development +RUN npm install +# Development-specific instructions here + +# Production stage +FROM base AS production +RUN npm ci +COPY frontend/ . 
+RUN npm run build +# Production-specific instructions here \ No newline at end of file diff --git a/pyspur/LICENSE b/pyspur/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..29f81d812f3e768fa89638d1f72920dbfd1413a8 --- /dev/null +++ b/pyspur/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/pyspur/README.md b/pyspur/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7af4453d7fef5acd2b053ca49c34d8e4539c58b0 --- /dev/null +++ b/pyspur/README.md @@ -0,0 +1,187 @@ +![PySpur](./docs/images/hero.png) + +

Iterate over your agents 10x faster. AI engineers use PySpur to iterate over AI agents visually without reinventing the wheel.

+ +

+ README in English + 简体中文版自述文件 + 日本語のREADME + README in Korean + Deutsche Version der README +Version française du README +Versión en español del README +

+ +

+ + Docs + + + Meet us + + + Cloud + + + Join Our Discord + +

+ +https://github.com/user-attachments/assets/54d0619f-22fd-476c-bf19-9be083d7e710 + +# 🕸️ Why PySpur? + +## Problem: It takes 1,000 tiny paper cuts to make AI reliable + +AI engineers today face three problems when building agents: + +* **Prompt Hell**: Hours of prompt tweaking and trial-and-error frustration. +* **Workflow Blindspots**: Lack of visibility into step interactions causing hidden failures and confusion. +* **Terminal Testing Nightmare**: Squinting at raw outputs and manually parsing JSON. + +We've been there ourselves, too. We launched a graphic design agent in early 2024 and quickly reached thousands of users, yet struggled with its lack of reliability and with the existing debugging tools. + +## Solution: A playground for agents that saves time + +### Step 1: Define Test Cases + +https://github.com/user-attachments/assets/ed9ca45f-7346-463f-b8a4-205bf2c4588f + +### Step 2: Build the agent in Python code or via UI + +https://github.com/user-attachments/assets/7043aae4-fad1-42bd-953a-80c94fce8253 + +### Step 3: Iterate obsessively + +https://github.com/user-attachments/assets/72c9901d-a39c-4f80-85a5-f6f76e55f473 + +### Step 4: Deploy + +https://github.com/user-attachments/assets/b14f34b2-9f16-4bd0-8a0f-1c26e690af93 + +# ✨ Core features: + +- 👤 **Human in the Loop**: Persistent workflows that wait for human approval. +- 🔄 **Loops**: Iterative tool calling with memory. +- 📤 **File Upload**: Upload files or paste URLs to process documents. +- 📋 **Structured Outputs**: UI editor for JSON Schemas. +- 🗃️ **RAG**: Parse, Chunk, Embed, and Upsert Data into a Vector DB. +- 🖼️ **Multimodal**: Support for Video, Images, Audio, Texts, Code. +- 🧰 **Tools**: Slack, Firecrawl.dev, Google Sheets, GitHub, and more. +- 📊 **Traces**: Automatically capture execution traces of deployed agents. +- 🧪 **Evals**: Evaluate agents on real-world datasets. +- 🚀 **One-Click Deploy**: Publish as an API and integrate wherever you want. 
+- 🐍 **Python-Based**: Add new nodes by creating a single Python file. +- 🎛️ **Any-Vendor-Support**: >100 LLM providers, embedders, and vector DBs. + +# ⚡ Quick start + +This is the quickest way to get started. Python 3.11 or higher is required. + +1. **Install PySpur:** + ```sh + pip install pyspur + ``` + +2. **Initialize a new project:** + ```sh + pyspur init my-project + cd my-project + ``` + This will create a new directory with a `.env` file. + +3. **Start the server:** + ```sh + pyspur serve --sqlite + ``` + By default, this will start the PySpur app at `http://localhost:6080` using a SQLite database. + We recommend you configure a Postgres instance URL in the `.env` file to get a more stable experience. + +4. **[Optional] Configure Your Environment and Add API Keys:** + - **App UI**: Navigate to the API Keys tab to add provider keys (OpenAI, Anthropic, etc.) + - **Manual**: Edit the `.env` file (recommended: configure Postgres) and restart with `pyspur serve` + + +# 😎 Feature Reel + +## Human-in-the-loop breakpoints: + +These breakpoints pause the workflow when reached and resume it once a human approves. +They enable human oversight for workflows that require quality assurance: verify critical outputs before the workflow proceeds. + +https://github.com/user-attachments/assets/98cb2b4e-207c-4d97-965b-4fee47c94ce8 + +## Debug at Node Level: + +https://github.com/user-attachments/assets/6e82ad25-2a46-4c50-b030-415ea9994690 + +## Multimodal (Upload files or paste URLs) + +PDFs, Videos, Audio, Images, ... 
+ +https://github.com/user-attachments/assets/83ed9a22-1ec1-4d86-9dd6-5d945588fd0b + +## Loops + +Loops + +## RAG + +### Step 1) Create Document Collection (Chunking + Parsing) + +https://github.com/user-attachments/assets/c77723b1-c076-4a64-a01d-6d6677e9c60e + +### Step 2) Create Vector Index (Embedding + Vector DB Upsert) + +https://github.com/user-attachments/assets/50e5c711-dd01-4d92-bb23-181a1c5bba25 + +## Modular Building Blocks + +https://github.com/user-attachments/assets/6442f0ad-86d8-43d9-aa70-e5c01e55e876 + +## Evaluate Final Performance + +https://github.com/user-attachments/assets/4dc2abc3-c6e6-4d6d-a5c3-787d518de7ae + +## Coming soon: Self-improvement + +https://github.com/user-attachments/assets/5bef7a16-ef9f-4650-b385-4ea70fa54c8a + +# 🛠️ PySpur Development Setup +#### [ Instructions for development on Unix-like systems. Development on Windows/PC not supported ] + +We recommend using Cursor/VS Code with our dev container (`.devcontainer/devcontainer.json`) for: +- Consistent development environment with pre-configured tools and extensions +- Optimized settings for Python and TypeScript development +- Automatic hot-reloading and port forwarding + +**Option 1: Cursor/VS Code Dev Container (Recommended)** +1. Install [Cursor](https://www.cursor.com/)/[VS Code](https://code.visualstudio.com/) and the [Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) +2. Clone and open the repository +3. Click "Reopen in Container" when prompted + +**Option 2: Manual Setup** +1. **Clone the repository:** + ```sh + git clone https://github.com/PySpur-com/pyspur.git + cd pyspur + ``` + +2. **Launch using docker-compose.dev.yml:** + ```sh + docker compose -f docker-compose.dev.yml up --build -d + ``` + +3. **Customize your setup:** + Edit `.env` to configure your environment (e.g., PostgreSQL settings). + +Note: Manual setup requires additional configuration and may not include all dev container features. 
+ +# ⭐ Support us + +You can support us in our work by leaving a star! Thank you! + +![star](https://github.com/user-attachments/assets/71f65273-6755-469d-be44-087bb89d5e76) + +Your feedback will be massively appreciated. +Please [tell us](mailto:founders@pyspur.dev?subject=Feature%20Request&body=I%20want%20this%20feature%3Ai) which features on that list you'd like to see next, or request entirely new ones. diff --git a/pyspur/README_CN.md b/pyspur/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..f6618bb1fb02157acabe5ddef25213244fbe7503 --- /dev/null +++ b/pyspur/README_CN.md @@ -0,0 +1,156 @@ +![PySpur](./docs/images/hero.png) + +

PySpur 是一个基于 Python 编写的 AI 智能体构建器。AI 工程师使用它来构建智能体,逐步执行并检查过去的运行记录。

+ +

+ README in English + 简体中文版自述文件 + 日本語のREADME + README in Korean + Deutsche Version der README + Version française du README + Versión en español del README +

+ +

+ + Docs + + + Meet us + + + Cloud + + + Join Our Discord + +

+ +https://github.com/user-attachments/assets/1ebf78c9-94b2-468d-bbbb-566311df16fe + +# 🕸️ 为什么选择 PySpur? + +- ✅ **测试驱动**:构建工作流,运行测试用例,并进行迭代。 +- 👤 **人在环路中**:持久化工作流,等待人工批准或拒绝。 +- 🔄 **循环**:具有记忆功能的迭代工具调用。 +- 📤 **文件上传**:上传文件或粘贴 URL 来处理文档。 +- 📋 **结构化输出**:JSON Schema UI 编辑器。 +- 🗃️ **RAG**:解析、分块、嵌入并将数据更新到向量数据库。 +- 🖼️ **多模态**:支持视频、图像、音频、文本、代码。 +- 🧰 **工具**:Slack、Firecrawl.dev、Google Sheets、GitHub 等。 +- 🧪 **评估**:在真实数据集上评估代理。 +- 🚀 **一键部署**:发布为 API 并在任意地方集成。 +- 🐍 **基于 Python**:通过创建单个 Python 文件来添加新节点。 +- 🎛️ **供应商支持**:支持超过 100 个 LLM 供应商、嵌入器和向量数据库。 + +# ⚡ 快速开始 + +这是入门的最快方式。需要 Python 3.11 或更高版本。 + +1. **安装 PySpur:** + ```sh + pip install pyspur + ``` + +2. **初始化新项目:** + ```sh + pyspur init my-project + cd my-project + ``` + 这将创建一个包含 `.env` 文件的新目录。 + +3. **启动服务器:** + ```sh + pyspur serve --sqlite + ``` + 默认情况下,这将使用 SQLite 数据库在 `http://localhost:6080` 启动 PySpur 应用。 + 我们建议你在 `.env` 文件中配置 Postgres 实例的 URL,以获得更稳定的体验。 + +4. **[可选] 配置环境和添加 API 密钥:** + - **应用界面**: 导航至 API 密钥标签页添加供应商密钥(OpenAI、Anthropic 等) + - **手动配置**: 编辑 `.env` 文件(推荐:配置 postgres)并使用 `pyspur serve` 重启 + +# ✨ 核心优势 + +## 人在环路中断点: + +这些断点在达到时会暂停工作流,并在人工批准后恢复。 +它们为需要质量保证的工作流提供人工监督:在工作流继续之前验证关键输出。 + +https://github.com/user-attachments/assets/98cb2b4e-207c-4d97-965b-4fee47c94ce8 + +## 节点级调试: + +https://github.com/user-attachments/assets/6e82ad25-2a46-4c50-b030-415ea9994690 + +## 多模态(上传文件或粘贴 URL) + +支持 PDF、视频、音频、图像等…… + +https://github.com/user-attachments/assets/83ed9a22-1ec1-4d86-9dd6-5d945588fd0b + +## 循环 + +Loops + +## RAG + +### 步骤 1) 创建文档集合(分块 + 解析) + +https://github.com/user-attachments/assets/c77723b1-c076-4a64-a01d-6d6677e9c60e + +### 步骤 2) 创建向量索引(嵌入 + 向量数据库插入) + +https://github.com/user-attachments/assets/50e5c711-dd01-4d92-bb23-181a1c5bba25 + +## 模块化构建块 + +https://github.com/user-attachments/assets/6442f0ad-86d8-43d9-aa70-e5c01e55e876 + +## 评估最终性能 + +https://github.com/user-attachments/assets/4dc2abc3-c6e6-4d6d-a5c3-787d518de7ae + +## 即将推出:自我提升 + 
+https://github.com/user-attachments/assets/5bef7a16-ef9f-4650-b385-4ea70fa54c8a + +# 🛠️ PySpur 开发环境设置 +#### [ Unix 类系统开发指南。Windows/PC 开发不支持。 ] + +我们推荐使用 Cursor/VS Code 和我们的开发容器(`.devcontainer/devcontainer.json`),它提供: +- 预配置工具和扩展的一致开发环境 +- 针对 Python 和 TypeScript 开发的优化设置 +- 自动热重载和端口转发 + +**选项 1:Cursor/VS Code 开发容器(推荐)** +1. 安装 [Cursor](https://www.cursor.com/)/[VS Code](https://code.visualstudio.com/) 和 [Dev Containers 扩展](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) +2. 克隆并打开仓库 +3. 当提示时点击"在容器中重新打开" + +**选项 2:手动设置** +1. **克隆仓库:** + ```sh + git clone https://github.com/PySpur-com/pyspur.git + cd pyspur + ``` + +2. **使用 docker-compose.dev.yml 启动:** + ```sh + docker compose -f docker-compose.dev.yml up --build -d + ``` + +3. **自定义设置:** + 编辑 `.env` 配置环境(例如:PostgreSQL 设置)。 + +注意:手动设置需要额外配置,可能无法包含开发容器提供的所有功能。 + +# ⭐ 支持我们 + +你可以通过给我们项目 Star 来支持我们的工作!谢谢! + +![star](https://github.com/user-attachments/assets/71f65273-6755-469d-be44-087bb89d5e76) + +我们非常重视你的反馈。 +请 [告诉我们](mailto:founders@pyspur.dev?subject=Feature%20Request&body=I%20want%20this%20feature%3Ai) 你想在下一次看到列表中的哪些功能或全新的功能。 diff --git a/pyspur/README_DE.md b/pyspur/README_DE.md new file mode 100644 index 0000000000000000000000000000000000000000..e6d8b7350435d447ddbc52bab5143325dc9ffebb --- /dev/null +++ b/pyspur/README_DE.md @@ -0,0 +1,146 @@ +![PySpur](./docs/images/hero.png) + +

PySpur ist ein KI-Agenten-Builder in Python. KI-Entwickler nutzen ihn, um Agenten zu erstellen, sie Schritt für Schritt auszuführen und vergangene Durchläufe zu analysieren.

+ +

+ README auf Englisch + README auf vereinfachtem Chinesisch + README auf Japanisch + README auf Koreanisch + Deutsche Version der README + README auf Französisch + README auf Spanisch +

+ +

+ + Dokumentation + + + Treffen Sie uns + + + Cloud + + + Discord beitreten + +

+ +https://github.com/user-attachments/assets/1ebf78c9-94b2-468d-bbbb-566311df16fe + +# 🕸️ Warum PySpur? + +- ✅ **Testgetrieben**: Erstellen Sie Workflows, führen Sie Testfälle aus und iterieren Sie. +- 👤 **Human in the Loop**: Persistente Workflows, die auf Genehmigung oder Ablehnung des Users warten. +- 🔄 **Loops**: Wiederholte Toolaufrufe mit Zwischenspeicherung. +- 📤 **Datei-Upload**: Laden Sie Dateien hoch oder fügen Sie URLs ein, um Dokumente zu verarbeiten. +- 📋 **Strukturierte Outputs**: UI-Editor für JSON-Schemata. +- 🗃️ **RAG**: Daten parsen, in Abschnitte unterteilen, einbetten und in eine Vektor-Datenbank einfügen/aktualisieren. +- 🖼️ **Multimodal**: Unterstützung für Video, Bilder, Audio, Texte, Code. +- 🧰 **Tools**: Slack, Firecrawl.dev, Google Sheets, GitHub und mehr. +- 🧪 **Evaluierungen**: Bewerten Sie Agenten anhand von realen Datensätzen. +- 🚀 **One-Click Deploy**: Veröffentlichen Sie Ihre Lösung als API und integrieren Sie sie überall. +- 🐍 **Python-basiert**: Fügen Sie neue Knoten hinzu, indem Sie eine einzige Python-Datei erstellen. +- 🎛️ **Support für jeden Anbieter**: Über 100 LLM-Anbieter, Einbettungslösungen und Vektor-Datenbanken. + +# ⚡ Schnellstart + +Dies ist der schnellste Weg, um loszulegen. Python 3.11 oder höher wird benötigt. + +1. **PySpur installieren:** + ```sh + pip install pyspur + ``` + +2. **Ein neues Projekt initialisieren:** + ```sh + pyspur init my-project + cd my-project + ``` + Dadurch wird ein neues Verzeichnis mit einer `.env`-Datei erstellt. + +3. **Den Server starten:** + ```sh + pyspur serve --sqlite + ``` + Standardmäßig startet dies die PySpur-App unter `http://localhost:6080` mit einer SQLite-Datenbank. + Wir empfehlen, in der `.env`-Datei eine PostgreSQL-Instanz-URL zu konfigurieren, um eine stabilere Erfahrung zu gewährleisten. + +4. 
**[Optional] Umgebung konfigurieren und API-Schlüssel hinzufügen:** + - **App-Oberfläche**: Navigieren Sie zum Tab „API Keys", um Anbieter-Schlüssel hinzuzufügen (OpenAI, Anthropic usw.) + - **Manuelle Konfiguration**: Bearbeiten Sie die `.env`-Datei (empfohlen: PostgreSQL konfigurieren) und starten Sie mit `pyspur serve` neu + +# ✨ Kernvorteile + +## Mensch-im-Regelkreis-Haltepunkte: + +Diese Haltepunkte pausieren den Workflow, wenn sie erreicht werden, und setzen ihn fort, sobald ein Mensch ihn genehmigt. +Sie ermöglichen menschliche Aufsicht für Workflows, die Qualitätssicherung erfordern: Überprüfen Sie kritische Ausgaben, bevor der Workflow fortgesetzt wird. + +https://github.com/user-attachments/assets/98cb2b4e-207c-4d97-965b-4fee47c94ce8 + +## Debuggen auf Node-Ebene: + +https://github.com/user-attachments/assets/6e82ad25-2a46-4c50-b030-415ea9994690 + +## Multimodal (Dateien hochladen oder URLs einfügen) + +PDFs, Videos, Audio, Bilder, ... + +https://github.com/user-attachments/assets/83ed9a22-1ec1-4d86-9dd6-5d945588fd0b + +## Loops + +Loops + +## RAG + +### Schritt 1) Erstellen einer Dokumentensammlung (Chunking + Parsing) + +https://github.com/user-attachments/assets/c77723b1-c076-4a64-a01d-6d6677e9c60e + +### Schritt 2) Erstellen eines Vektorindex (Einbettung + Einfügen/Aktualisieren in der Vektor-Datenbank) + +https://github.com/user-attachments/assets/50e5c711-dd01-4d92-bb23-181a1c5bba25 + +## Modulare Bausteine + +https://github.com/user-attachments/assets/6442f0ad-86d8-43d9-aa70-e5c01e55e876 + +## Endgültige Leistung bewerten + +https://github.com/user-attachments/assets/4dc2abc3-c6e6-4d6d-a5c3-787d518de7ae + +## Demnächst: Selbstverbesserung + +https://github.com/user-attachments/assets/5bef7a16-ef9f-4650-b385-4ea70fa54c8a + +# 🛠️ PySpur Entwicklungs-Setup +#### [ Anweisungen für die Entwicklung auf Unix-ähnlichen Systemen. Entwicklung auf Windows/PC wird nicht unterstützt ] + +Für die Entwicklung folgen Sie diesen Schritten: + +1. 
**Das Repository klonen:** + ```sh + git clone https://github.com/PySpur-com/pyspur.git + cd pyspur + ``` + +2. **Mit docker-compose.dev.yml starten:** + ```sh + docker compose -f docker-compose.dev.yml up --build -d + ``` + Dadurch wird eine lokale Instanz von PySpur mit aktiviertem Hot-Reloading für die Entwicklung gestartet. + +3. **Ihre Einrichtung anpassen:** + Bearbeiten Sie die `.env`-Datei, um Ihre Umgebung zu konfigurieren. Standardmäßig verwendet PySpur eine lokale PostgreSQL-Datenbank. Um eine externe Datenbank zu nutzen, ändern Sie die `POSTGRES_*`-Variablen in der `.env`. + +# ⭐ Unterstützen Sie uns + +Sie können uns bei unserer Arbeit unterstützen, indem Sie einen Stern hinterlassen! Vielen Dank! + +![star](https://github.com/user-attachments/assets/71f65273-6755-469d-be44-087bb89d5e76) + +Ihr Feedback wird sehr geschätzt. +Bitte [sagen Sie uns](mailto:founders@pyspur.dev?subject=Feature%20Request&body=I%20want%20this%20feature%3Ai), welche Funktionen aus dieser Liste Sie als Nächstes sehen möchten oder schlagen Sie ganz neue vor. diff --git a/pyspur/README_ES.md b/pyspur/README_ES.md new file mode 100644 index 0000000000000000000000000000000000000000..64a7e7392707e13f5cb1919ddd744210c232f51c --- /dev/null +++ b/pyspur/README_ES.md @@ -0,0 +1,148 @@ +![PySpur](./docs/images/hero.png) + +

PySpur es un constructor de agentes de IA en Python. Los ingenieros de IA lo utilizan para crear agentes, ejecutarlos paso a paso e inspeccionar ejecuciones anteriores.

+ +

+ README en inglés + Versión en chino simplificado + README en japonés + README en coreano + Versión en alemán del README + Versión en francés del README + Versión en español del README +

+ +

+ + Docs + + + Conócenos + + + Cloud + + + Únete a nuestro Discord + +

+ +https://github.com/user-attachments/assets/1ebf78c9-94b2-468d-bbbb-566311df16fe + +# 🕸️ ¿Por qué PySpur? + +- ✅ **Desarrollo Guiado por Pruebas**: Construye flujos de trabajo, ejecuta casos de prueba e itera. +- 👤 **Humano en el Bucle**: Flujos de trabajo persistentes que esperan aprobación o rechazo humano. +- 🔄 **Bucles**: Llamadas iterativas a herramientas con memoria. +- 📤 **Carga de Archivos**: Sube archivos o pega URLs para procesar documentos. +- 📋 **Salidas Estructuradas**: Editor de interfaz para esquemas JSON. +- 🗃️ **RAG**: Analiza, segmenta, incrusta y actualiza datos en una base de datos vectorial. +- 🖼️ **Multimodal**: Soporte para video, imágenes, audio, textos y código. +- 🧰 **Herramientas**: Slack, Firecrawl.dev, Google Sheets, GitHub y más. +- 🧪 **Evaluaciones**: Evalúa agentes en conjuntos de datos del mundo real. +- 🚀 **Despliegue con un clic**: Publica como una API e intégrala donde desees. +- 🐍 **Basado en Python**: Agrega nuevos nodos creando un solo archivo Python. +- 🎛️ **Soporte para Cualquier Proveedor**: Más de 100 proveedores de LLM, embedders y bases de datos vectoriales. + +# ⚡ Inicio Rápido + +Esta es la forma más rápida de comenzar. Se requiere Python 3.11 o superior. + +1. **Instala PySpur:** + ```sh + pip install pyspur + ``` + +2. **Inicializa un nuevo proyecto:** + ```sh + pyspur init my-project + cd my-project + ``` + Esto creará un nuevo directorio con un archivo `.env`. + +3. **Inicia el servidor:** + ```sh + pyspur serve --sqlite + ``` + Por defecto, esto iniciará la aplicación PySpur en `http://localhost:6080` utilizando una base de datos SQLite. + Se recomienda configurar una URL de instancia de Postgres en el archivo `.env` para obtener una experiencia más estable. + +4. **[Opcional] Configura tu entorno y añade claves API:** + - **A través de la interfaz de la aplicación**: Navega a la pestaña de API Keys para añadir claves de proveedores (OpenAI, Anthropic, etc.) 
+ - **Configuración manual**: Edita el archivo `.env` (recomendado: configura postgres) y reinicia con `pyspur serve` + +¡Eso es todo! Haz clic en "New Spur" para crear un flujo de trabajo, o comienza con una de las plantillas predefinidas. + +# ✨ Beneficios Principales + +## Puntos de Interrupción con Humano en el Bucle: + +Estos puntos de interrupción pausan el flujo de trabajo cuando se alcanzan y lo reanudan tan pronto como un humano lo aprueba. +Permiten la supervisión humana para flujos de trabajo que requieren garantía de calidad: verifique las salidas críticas antes de que el flujo de trabajo continúe. + +https://github.com/user-attachments/assets/98cb2b4e-207c-4d97-965b-4fee47c94ce8 + +## Depuración a Nivel de Nodo: + +https://github.com/user-attachments/assets/6e82ad25-2a46-4c50-b030-415ea9994690 + +## Multimodal (Sube archivos o pega URLs) + +PDFs, Videos, Audio, Imágenes, ... + +https://github.com/user-attachments/assets/83ed9a22-1ec1-4d86-9dd6-5d945588fd0b + +## Bucles + +Bucles + +## RAG + +### Paso 1) Crear Colección de Documentos (Segmentación + Análisis) + +https://github.com/user-attachments/assets/c77723b1-c076-4a64-a01d-6d6677e9c60e + +### Paso 2) Crear Índice Vectorial (Incrustación + Actualización en DB Vectorial) + +https://github.com/user-attachments/assets/50e5c711-dd01-4d92-bb23-181a1c5bba25 + +## Bloques Modulares + +https://github.com/user-attachments/assets/6442f0ad-86d8-43d9-aa70-e5c01e55e876 + +## Evaluar el Rendimiento Final + +https://github.com/user-attachments/assets/4dc2abc3-c6e6-4d6d-a5c3-787d518de7ae + +## Próximamente: Auto-mejora + +https://github.com/user-attachments/assets/5bef7a16-ef9f-4650-b385-4ea70fa54c8a + +# 🛠️ Configuración de Desarrollo de PySpur +#### [ Instrucciones para el desarrollo en sistemas tipo Unix. Desarrollo en Windows/PC no es soportado ] + +Para el desarrollo, sigue estos pasos: + +1. **Clona el repositorio:** + ```sh + git clone https://github.com/PySpur-com/pyspur.git + cd pyspur + ``` + +2. 
**Inicia utilizando docker-compose.dev.yml:** + ```sh + docker compose -f docker-compose.dev.yml up --build -d + ``` + Esto iniciará una instancia local de PySpur con recarga en caliente habilitada para el desarrollo. + +3. **Personaliza tu configuración:** + Edita el archivo `.env` para configurar tu entorno. Por defecto, PySpur utiliza una base de datos PostgreSQL local. Para usar una base de datos externa, modifica las variables `POSTGRES_*` en el archivo `.env`. + +# ⭐ Apóyanos + +¡Puedes apoyarnos en nuestro trabajo dándonos una estrella! ¡Gracias! + +![star](https://github.com/user-attachments/assets/71f65273-6755-469d-be44-087bb89d5e76) + +Tu retroalimentación será enormemente apreciada. +Por favor [dinos](mailto:founders@pyspur.dev?subject=Feature%20Request&body=I%20want%20this%20feature%3Ai) qué características de esa lista te gustaría ver a continuación o solicita nuevas funcionalidades. diff --git a/pyspur/README_FR.md b/pyspur/README_FR.md new file mode 100644 index 0000000000000000000000000000000000000000..8621ecea859ab5e46252066a7d3f28647fbdfd00 --- /dev/null +++ b/pyspur/README_FR.md @@ -0,0 +1,148 @@ +![PySpur](./docs/images/hero.png) + +

PySpur est un créateur d'agents d'IA en Python. Les ingénieurs en IA l'utilisent pour créer des agents, les exécuter étape par étape et inspecter les exécutions passées.

+ +

+ README in English + 简体中文版自述文件 + 日本語のREADME + README in Korean + Deutsche Version der README + Version française du README + Versión en español del README +

+ +

+ + Documentation + + + Rencontrez-nous + + + Cloud + + + Rejoignez notre Discord + +

+ +https://github.com/user-attachments/assets/1ebf78c9-94b2-468d-bbbb-566311df16fe + +# 🕸️ Pourquoi PySpur ? + +- ✅ **Piloté par les tests** : Construisez des workflows, exécutez des cas de test et itérez. +- 👤 **Humain dans la boucle** : Workflows persistants qui attendent l'approbation ou le rejet humain. +- 🔄 **Boucles** : Appels d'outils itératifs avec mémoire. +- 📤 **Téléversement de fichiers** : Téléchargez des fichiers ou collez des URL pour traiter des documents. +- 📋 **Sorties structurées** : Éditeur d'interface utilisateur pour les schémas JSON. +- 🗃️ **RAG** : Analyser, découper, intégrer et insérer ou mettre à jour des données dans une base de données vectorielle. +- 🖼️ **Multimodal** : Support pour vidéos, images, audio, textes, code. +- 🧰 **Outils** : Slack, Firecrawl.dev, Google Sheets, GitHub, et plus encore. +- 🧪 **Évaluations** : Évaluez les agents sur des ensembles de données réelles. +- 🚀 **Déploiement en un clic** : Publiez en tant qu'API et intégrez-le où vous le souhaitez. +- 🐍 **Basé sur Python** : Ajoutez de nouveaux nœuds en créant un seul fichier Python. +- 🎛️ **Support multi-fournisseurs** : >100 fournisseurs de LLM, intégrateurs et bases de données vectorielles. + +# ⚡ Démarrage rapide + +C'est la manière la plus rapide de commencer. Python 3.11 ou une version supérieure est requis. + +1. **Installer PySpur :** + ```sh + pip install pyspur + ``` + +2. **Initialiser un nouveau projet :** + ```sh + pyspur init my-project + cd my-project + ``` + Cela va créer un nouveau répertoire avec un fichier `.env`. + +3. **Démarrer le serveur :** + ```sh + pyspur serve --sqlite + ``` + Par défaut, cela démarrera l'application PySpur sur `http://localhost:6080` en utilisant une base de données SQLite. + Nous vous recommandons de configurer une URL d'instance Postgres dans le fichier `.env` pour une expérience plus stable. + +4. 
**[Optionnel] Configurer votre environnement et ajouter des clés API :** + - **Via l'interface de l'application** : Naviguez vers l'onglet des clés API pour ajouter des clés de fournisseurs (OpenAI, Anthropic, etc.) + - **Configuration manuelle** : Éditez le fichier `.env` (recommandé : configurez postgres) et redémarrez avec `pyspur serve` + +C'est tout ! Cliquez sur « New Spur » pour créer un workflow, ou commencez avec l'un des modèles de base. + +# ✨ Avantages principaux + +## Points d'arrêt avec humain dans la boucle : + +Ces points d'arrêt mettent en pause le flux de travail lorsqu'ils sont atteints et le reprennent dès qu'un humain l'approuve. +Ils permettent une supervision humaine pour les flux de travail nécessitant une assurance qualité : vérifiez les sorties critiques avant que le flux de travail ne continue. + +https://github.com/user-attachments/assets/98cb2b4e-207c-4d97-965b-4fee47c94ce8 + +## Déboguer au niveau des nœuds : + +https://github.com/user-attachments/assets/6e82ad25-2a46-4c50-b030-415ea9994690 + +## Multimodal (téléverser des fichiers ou coller des URL) + +PDF, vidéos, audio, images, ... 
+ +https://github.com/user-attachments/assets/83ed9a22-1ec1-4d86-9dd6-5d945588fd0b + +## Boucles + +Loops + +## RAG + +### Étape 1) Créer une collection de documents (découpage + analyse) + +https://github.com/user-attachments/assets/c77723b1-c076-4a64-a01d-6d6677e9c60e + +### Étape 2) Créer un index vectoriel (intégration + insertion/mise à jour dans la base de données vectorielle) + +https://github.com/user-attachments/assets/50e5c711-dd01-4d92-bb23-181a1c5bba25 + +## Blocs modulaires + +https://github.com/user-attachments/assets/6442f0ad-86d8-43d9-aa70-e5c01e55e876 + +## Évaluer la performance finale + +https://github.com/user-attachments/assets/4dc2abc3-c6e6-4d6d-a5c3-787d518de7ae + +## Bientôt : Auto-amélioration + +https://github.com/user-attachments/assets/5bef7a16-ef9f-4650-b385-4ea70fa54c8a + +# 🛠️ Configuration de développement de PySpur +#### [ Instructions pour le développement sur des systèmes de type Unix. Le développement sur Windows/PC n'est pas supporté ] + +Pour le développement, suivez ces étapes : + +1. **Cloner le dépôt :** + ```sh + git clone https://github.com/PySpur-com/pyspur.git + cd pyspur + ``` + +2. **Lancer en utilisant docker-compose.dev.yml :** + ```sh + docker compose -f docker-compose.dev.yml up --build -d + ``` + Cela démarrera une instance locale de PySpur avec le rechargement à chaud activé pour le développement. + +3. **Personnaliser votre configuration :** + Modifiez le fichier `.env` pour configurer votre environnement. Par défaut, PySpur utilise une base de données PostgreSQL locale. Pour utiliser une base de données externe, modifiez les variables `POSTGRES_*` dans le fichier `.env`. + +# ⭐ Soutenez-nous + +Vous pouvez nous soutenir en laissant une étoile ! Merci ! + +![star](https://github.com/user-attachments/assets/71f65273-6755-469d-be44-087bb89d5e76) + +Vos retours seront grandement appréciés. 
+Veuillez nous [faire part](mailto:founders@pyspur.dev?subject=Feature%20Request&body=I%20want%20this%20feature%3Ai) des fonctionnalités de cette liste que vous souhaitez voir prochainement ou proposer de toutes nouvelles fonctionnalités. \ No newline at end of file diff --git a/pyspur/README_JA.md b/pyspur/README_JA.md new file mode 100644 index 0000000000000000000000000000000000000000..bbd7dddfa75597e7d42044944a80291e1ee00709 --- /dev/null +++ b/pyspur/README_JA.md @@ -0,0 +1,145 @@ +![PySpur](./docs/images/hero.png) + +

PySpurはPython製のAIエージェントビルダーです。AIエンジニアはこれを利用してエージェントを構築し、ステップバイステップで実行し、過去の実行結果を検証します。

+ +

+ 英語版README + 简体中文版自述文件 + 日本語のREADME + 韓国語版README + ドイツ語版README + フランス語版README + スペイン語版README +

+ +

+ + ドキュメント + + + お会いしましょう + + + クラウド + + + Discordに参加する + +

+ +https://github.com/user-attachments/assets/1ebf78c9-94b2-468d-bbbb-566311df16fe + +# 🕸️ なぜ PySpur なのか? + +- ✅ **テスト駆動型**: ワークフローを構築し、テストケースを実行し、反復します。 +- 👤 **ヒューマンインザループ**: 人間の承認または拒否を待つ永続的なワークフロー。 +- 🔄 **ループ**: メモリを活用した反復的なツール呼び出し。 +- 📤 **ファイルアップロード**: ファイルのアップロードやURLの貼り付けによりドキュメントを処理します。 +- 📋 **構造化された出力**: JSONスキーマ用のUIエディタ。 +- 🗃️ **RAG**: データを解析、分割、埋め込み、そしてVector DBにアップサートします。 +- 🖼️ **マルチモーダル**: ビデオ、画像、オーディオ、テキスト、コードに対応。 +- 🧰 **ツール**: Slack、Firecrawl.dev、Google Sheets、GitHubなど多数。 +- 🧪 **評価**: 実際のデータセットでエージェントを評価します。 +- 🚀 **ワンクリックデプロイ**: APIとして公開し、どこにでも統合可能。 +- 🐍 **Pythonベース**: 単一のPythonファイルを作成するだけで新しいノードを追加できます。 +- 🎛️ **どのベンダーにも対応**: 100以上のLLMプロバイダー、エンベッダー、Vector DBに対応。 + +# ⚡ クイックスタート + +これは最も迅速なスタート方法です。Python 3.11以上が必要です。 + +1. **PySpurのインストール:** + ```sh + pip install pyspur + ``` + +2. **新しいプロジェクトの初期化:** + ```sh + pyspur init my-project + cd my-project + ``` + これにより、`.env`ファイルを含む新しいディレクトリが作成されます。 + +3. **サーバーの起動:** + ```sh + pyspur serve --sqlite + ``` + デフォルトでは、SQLiteデータベースを使用して `http://localhost:6080` でPySpurアプリが起動します。より安定した動作を求める場合は、`.env`ファイルにPostgresのインスタンスURLを設定することを推奨します。 + +4. 
**[オプション] 環境設定とAPIキーの追加:** + - **アプリUI**: APIキータブに移動して各プロバイダーのキー(OpenAI、Anthropicなど)を追加 + - **手動設定**: `.env`ファイルを編集(推奨:postgresを設定)し、`pyspur serve`で再起動 + +# ✨ 主な利点 + +## ヒューマンインザループブレークポイント: + +これらのブレークポイントは到達時にワークフローを一時停止し、人間が承認するとすぐに再開します。 +品質保証が必要なワークフローに人間の監視を可能にします:ワークフローが進む前に重要な出力を検証します。 + +https://github.com/user-attachments/assets/98cb2b4e-207c-4d97-965b-4fee47c94ce8 + +## ノードレベルでのデバッグ: + +https://github.com/user-attachments/assets/6e82ad25-2a46-4c50-b030-415ea9994690 + +## マルチモーダル(ファイルアップロードまたはURL貼り付け) + +PDF、ビデオ、オーディオ、画像、… + +https://github.com/user-attachments/assets/83ed9a22-1ec1-4d86-9dd6-5d945588fd0b + +## ループ + +Loops + +## RAG + +### ステップ 1) ドキュメントコレクションの作成(チャンク分割+解析) + +https://github.com/user-attachments/assets/c77723b1-c076-4a64-a01d-6d6677e9c60e + +### ステップ 2) ベクターインデックスの作成(埋め込み+Vector DBアップサート) + +https://github.com/user-attachments/assets/50e5c711-dd01-4d92-bb23-181a1c5bba25 + +## モジュール式ビルディングブロック + +https://github.com/user-attachments/assets/6442f0ad-86d8-43d9-aa70-e5c01e55e876 + +## 最終パフォーマンスの評価 + +https://github.com/user-attachments/assets/4dc2abc3-c6e6-4d6d-a5c3-787d518de7ae + +## 近日公開予定:自己改善 + +https://github.com/user-attachments/assets/5bef7a16-ef9f-4650-b385-4ea70fa54c8a + +# 🛠️ PySpur 開発環境セットアップ +#### [ Unix系システムでの開発向けの手順です。Windows/PCでの開発はサポートされていません ] + +開発のためには、以下の手順に従ってください: + +1. **リポジトリのクローン:** + ```sh + git clone https://github.com/PySpur-com/pyspur.git + cd pyspur + ``` + +2. **docker-compose.dev.ymlを使用して起動:** + ```sh + docker compose -f docker-compose.dev.yml up --build -d + ``` + これにより、開発用にホットリロードが有効なPySpurのローカルインスタンスが起動します。 + +3. **セットアップのカスタマイズ:** + 環境設定のために `.env` ファイルを編集してください。デフォルトでは、PySpurはローカルのPostgreSQLデータベースを使用しています。外部データベースを使用する場合は、`.env` 内の `POSTGRES_*` 変数を変更してください. + +# ⭐ サポート + +スターを押していただくことで、私たちの活動をサポートしていただけます。ありがとうございます! 
+ +![star](https://github.com/user-attachments/assets/71f65273-6755-469d-be44-087bb89d5e76) + +皆様のフィードバックを大変ありがたく思います。 +次にどの機能を見たいか、または全く新しい機能のリクエストがあれば、ぜひ[お知らせください](mailto:founders@pyspur.dev?subject=Feature%20Request&body=I%20want%20this%20feature%3Ai)。 diff --git a/pyspur/README_KR.md b/pyspur/README_KR.md new file mode 100644 index 0000000000000000000000000000000000000000..2d7f7bd85d1f92a4b3bb6242567ca7857e31ad70 --- /dev/null +++ b/pyspur/README_KR.md @@ -0,0 +1,146 @@ +![PySpur](./docs/images/hero.png) + +

PySpur는 파이썬 기반의 AI 에이전트 빌더입니다. AI 엔지니어들은 이를 사용해 에이전트를 구축하고, 단계별로 실행하며 과거 실행 기록을 검토합니다.

+ +

+ 영문 README + 简体中文版自述文件 + 日本語のREADME + 한국어 README + 독일어 README + 프랑스어 README + 스페인어 README +

+ +

+ + 문서 + + + 만나기 + + + 클라우드 + + + 디스코드 참여 + +

+ +https://github.com/user-attachments/assets/1ebf78c9-94b2-468d-bbbb-566311df16fe + +# 🕸️ 왜 PySpur인가? + +- ✅ **테스트 주도**: 워크플로우를 구축하고, 테스트 케이스를 실행하며, 반복합니다. +- 👤 **인간 참여 루프**: 인간의 승인 또는 거부를 기다리는 지속적인 워크플로우. +- 🔄 **루프**: 메모리를 활용한 반복적 도구 호출. +- 📤 **파일 업로드**: 파일을 업로드하거나 URL을 붙여넣어 문서를 처리. +- 📋 **구조화된 출력**: JSON 스키마용 UI 편집기. +- 🗃️ **RAG**: 데이터를 파싱, 청킹, 임베딩 및 벡터 DB에 업서트. +- 🖼️ **멀티모달**: 비디오, 이미지, 오디오, 텍스트, 코드 지원. +- 🧰 **도구**: Slack, Firecrawl.dev, Google Sheets, GitHub 등. +- 🧪 **평가**: 실제 데이터셋에서 에이전트 평가. +- 🚀 **원클릭 배포**: API로 발행하여 원하는 곳에 통합. +- 🐍 **파이썬 기반**: 단일 파이썬 파일 생성으로 새 노드 추가. +- 🎛️ **모든 벤더 지원**: 100개 이상의 LLM 제공업체, 임베더, 벡터 DB 지원. + +# ⚡ 빠른 시작 + +시작하는 가장 빠른 방법입니다. 파이썬 3.11 이상이 필요합니다. + +1. **PySpur 설치:** + ```sh + pip install pyspur + ``` + +2. **새 프로젝트 초기화:** + ```sh + pyspur init my-project + cd my-project + ``` + 새 디렉토리와 함께 `.env` 파일이 생성됩니다. + +3. **서버 시작:** + ```sh + pyspur serve --sqlite + ``` + 기본적으로 SQLite 데이터베이스를 사용하여 `http://localhost:6080`에서 PySpur 앱이 시작됩니다. + 보다 안정적인 사용을 위해 `.env` 파일에 PostgreSQL 인스턴스 URL을 설정하는 것을 권장합니다. + +4. **[선택 사항] 환경 구성 및 API 키 추가:** + - **앱 UI**: API 키 탭으로 이동하여 공급자 키(OpenAI, Anthropic 등) 추가 + - **수동 구성**: `.env` 파일 편집(권장: postgres 구성) 후 `pyspur serve`로 재시작 + +# ✨ 핵심 이점 + +## 인간 참여 중단점: + +이러한 중단점은 도달했을 때 워크플로우를 일시 중지하고 인간이 승인하면 재개됩니다. +품질 보증이 필요한 워크플로우에 인간의 감독을 가능하게 합니다: 워크플로우가 진행되기 전에 중요한 출력을 검증합니다. + +https://github.com/user-attachments/assets/98cb2b4e-207c-4d97-965b-4fee47c94ce8 + +## 노드 레벨에서 디버그: + +https://github.com/user-attachments/assets/6e82ad25-2a46-4c50-b030-415ea9994690 + +## 멀티모달 (파일 업로드 또는 URL 붙여넣기) + +PDF, 비디오, 오디오, 이미지, ... 
+ +https://github.com/user-attachments/assets/83ed9a22-1ec1-4d86-9dd6-5d945588fd0b + +## 루프 + +Loops + +## RAG + +### 1단계) 문서 컬렉션 생성 (청킹 + 파싱) + +https://github.com/user-attachments/assets/c77723b1-c076-4a64-a01d-6d6677e9c60e + +### 2단계) 벡터 인덱스 생성 (임베딩 + 벡터 DB 업서트) + +https://github.com/user-attachments/assets/50e5c711-dd01-4d92-bb23-181a1c5bba25 + +## 모듈형 빌딩 블록 + +https://github.com/user-attachments/assets/6442f0ad-86d8-43d9-aa70-e5c01e55e876 + +## 최종 성능 평가 + +https://github.com/user-attachments/assets/4dc2abc3-c6e6-4d6d-a5c3-787d518de7ae + +## 곧 추가될 기능: 자기 개선 + +https://github.com/user-attachments/assets/5bef7a16-ef9f-4650-b385-4ea70fa54c8a + +# 🛠️ PySpur 개발 환경 설정 +#### [ 유닉스 계열 시스템 개발 지침. Windows/PC 개발은 지원되지 않음 ] + +개발을 위해 아래 단계를 따르세요: + +1. **리포지토리 클론:** + ```sh + git clone https://github.com/PySpur-com/pyspur.git + cd pyspur + ``` + +2. **docker-compose.dev.yml 사용하여 실행:** + ```sh + docker compose -f docker-compose.dev.yml up --build -d + ``` + 이 명령어는 개발용 핫 리로딩이 활성화된 로컬 PySpur 인스턴스를 시작합니다. + +3. **환경 설정 맞춤:** + 환경 구성을 위해 `.env` 파일을 수정합니다. 기본적으로 PySpur는 로컬 PostgreSQL 데이터베이스를 사용합니다. 외부 데이터베이스를 사용하려면 `.env` 파일의 `POSTGRES_*` 변수를 수정하세요. + +# ⭐ 지원해 주세요 + +별을 남겨 주셔서 저희의 작업을 지원하실 수 있습니다! 감사합니다! + +![star](https://github.com/user-attachments/assets/71f65273-6755-469d-be44-087bb89d5e76) + +여러분의 피드백은 큰 힘이 됩니다. +다음에 보고 싶은 기능이나 완전히 새로운 기능 요청이 있다면 [알려주세요](mailto:founders@pyspur.dev?subject=Feature%20Request&body=I%20want%20this%20feature%3Ai). 
\ No newline at end of file diff --git a/pyspur/__init__.py b/pyspur/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyspur/__pycache__/__init__.cpython-312.pyc b/pyspur/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acf5deb9a39695d1e2fbc0d56b88e5d38700e721 Binary files /dev/null and b/pyspur/__pycache__/__init__.cpython-312.pyc differ diff --git a/pyspur/backend/.gitignore b/pyspur/backend/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e8d923e34bfabeb87afee28265bb979a550f5c41 --- /dev/null +++ b/pyspur/backend/.gitignore @@ -0,0 +1,7 @@ +# ignore the test database file +test.db +/app/integrations/google/token.json +data/ +/secure_tokens/ +/.bolt-app-installation/ +pyspur/openapi_specs/ diff --git a/pyspur/backend/.pre-commit-config.yaml b/pyspur/backend/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2cf6e0383a5ff6adf4e8ec6a996213acf425c568 --- /dev/null +++ b/pyspur/backend/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.10 + hooks: + - id: ruff + name: ruff + entry: ruff check + args: [--fix, --exit-non-zero-on-fix, --quiet] + language: system + types_or: [python, pyi] + require_serial: true + - id: ruff-format + name: ruff-format + entry: ruff format + args: [--quiet] + language: system + types_or: [python, pyi] + require_serial: true diff --git a/pyspur/backend/__init__.py b/pyspur/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyspur/backend/__pycache__/__init__.cpython-312.pyc b/pyspur/backend/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ec2937d515b71e822a571753439be0ec1126d38 Binary files /dev/null and 
b/pyspur/backend/__pycache__/__init__.cpython-312.pyc differ diff --git a/pyspur/backend/alembic.ini b/pyspur/backend/alembic.ini new file mode 100644 index 0000000000000000000000000000000000000000..36c2bef4e7361bfc592a31d7ba3046b409b42f96 --- /dev/null +++ b/pyspur/backend/alembic.ini @@ -0,0 +1,117 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +# Use forward slashes (/) also on windows to provide an os agnostic path +script_location = pyspur/models/management/alembic/ + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library. +# Any required deps can be installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to app/models/management/alembic//versions. When using multiple version +# directories, initial revisions must be specified with --version-path. 
+# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:app/models/management/alembic//versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +# version_path_separator = newline +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = postgresql://%(POSTGRES_USER)s:%(POSTGRES_PASSWORD)s@%(POSTGRES_HOST)s:%(POSTGRES_PORT)s/%(POSTGRES_DB)s + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. 
See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/pyspur/backend/entrypoint.sh b/pyspur/backend/entrypoint.sh new file mode 100644 index 0000000000000000000000000000000000000000..bbe0e60da3327853a8b2bfcd43e508d82d3a6827 --- /dev/null +++ b/pyspur/backend/entrypoint.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# First test Ollama connection if URL is provided +if [ -f "test_ollama.sh" ]; then + chmod +x test_ollama.sh + ./test_ollama.sh +fi + +set -e +mkdir -p /pyspur/backend/pyspur/models/management/alembic/versions/ +start_server() { + cd /pyspur/backend + uvicorn "pyspur.api.main:app" --reload --reload-include ./log_conf.yaml --reload-include "**/*.py" --log-config=log_conf.yaml --host 0.0.0.0 --port 8000 +} + +main() { + alembic upgrade head + start_server +} + +main \ No newline at end of file diff --git a/pyspur/backend/llms-ctx.txt b/pyspur/backend/llms-ctx.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/pyspur/backend/log_conf.yaml b/pyspur/backend/log_conf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..739968a027da3a80cb25ca3c140d00dd83960a51 --- /dev/null +++ b/pyspur/backend/log_conf.yaml @@ -0,0 +1,54 @@ +version: 1 +disable_existing_loggers: True +formatters: + default: + # "()": uvicorn.logging.DefaultFormatter + format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + access: + # "()": uvicorn.logging.AccessFormatter + format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' +handlers: + default: + formatter: default + class: logging.StreamHandler + stream: ext://sys.stderr + access: + formatter: access + class: logging.StreamHandler + stream: ext://sys.stdout +loggers: + uvicorn.error: + level: INFO + handlers: + - default + propagate: no + uvicorn.access: + level: INFO + handlers: + - access + propagate: no + httpx: + level: ERROR + handlers: + - default + httpcore: + level: ERROR + handlers: + - default + watchfiles.main: + level: INFO + handlers: + - default + LiteLLM: + level: INFO + handlers: + - default + openai._base_client: + level: INFO + handlers: + - default +root: + level: DEBUG + handlers: + - default + propagate: no \ No newline at end of file diff --git a/pyspur/backend/output_files/.gitignore b/pyspur/backend/output_files/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..005717ead0bb8f920c00d76feb8207deb7946a57 --- /dev/null +++ b/pyspur/backend/output_files/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/pyspur/backend/pyproject.toml b/pyspur/backend/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..692e46499da8098ffea8e8ea59d53fcc80eae3bb --- /dev/null +++ b/pyspur/backend/pyproject.toml @@ -0,0 +1,142 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "pyspur" +version = "0.1.18" +description = "PySpur is a Graph UI for building AI Agents in Python" +requires-python = 
">=3.11" +license = "Apache-2.0" +classifiers = [ + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Operating System :: Unix", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +maintainers = [ + {name = "Srijan Patel", email = "srijan@pyspur.dev"}, + {name = "Jean Kaddour", email = "jean@pyspur.dev"}, + {name = "Parshva Bhadra", email = "parshva.bhadra@pyspur.dev"}, +] +dependencies = [ + "alembic==1.14.0", + "arrow==1.3.0", + "asyncio==3.4.3", + "attrs==24.3.0", + "backend==0.2.4.1", + "chromadb==0.6.2", + "datasets==3.2.0", + "docx2txt==0.8", + "docx2python==3.3.0", + "exa-py==1.9.0", + "fastapi==0.115.6", + "genanki==0.13.1", + "google-api-python-client==2.159.0", + "grpcio==1.69.0", + "Jinja2==3.1.6", + "litellm==1.61.15", + "loguru==0.7.3", + "numpy==2.2.1", + "ollama==0.4.5", + "pandas==2.2.3", + "pinecone==5.4.2", + "praw==7.8.1", + "psycopg2-binary==2.9.10", + "pydantic==2.10.5", + "pypdf==5.1.0", + "python-dotenv==1.0.1", + "python-multipart==0.0.20", + "python-pptx==1.0.2", + "PyYAML==6.0.2", + "py-zerox==0.0.7", + "qdrant_client==1.12.2", + "redis==5.2.1", + "regex==2024.11.6", + "requests==2.32.3", + "requests-file==2.1.0", + "requests-oauthlib==1.3.1", + "retrying==1.3.4", + "slack_sdk==3.35.0", + "slack_bolt==1.23.0", + "SQLAlchemy==2.0.36", + "supabase==2.11.0", + "six==1.17.0", + "tenacity==8.3.0", + "tiktoken==0.7.0", + "tqdm==4.67.1", + "weaviate_client==4.10.2", + "itsdangerous==2.2.0", + "phidata==2.7.8", + "youtube_transcript_api==0.6.3", + "PyGithub==2.5.0", + "firecrawl-py==1.10.2", + "httpx[http2]==0.27.2", + "sendgrid==6.11.0", + "resend==2.6.0", + "typer[all]==0.9.0", + "psutil>=7.0.0", +] + +[project.urls] +Repository = "https://github.com/pyspur-dev/pyspur" +Documentation = "https://docs.pyspur.dev" + +[project.scripts] +pyspur = "pyspur.cli:main" + +[project.optional-dependencies] +dev = [ 
+ "pytest>=7.0", + "pytest-cov>=4.0", + "ruff>=0.1.0", +] + +[tool.hatch.build.targets.wheel] +universal = false +packages = ["pyspur"] +zip-safe = false + +[tool.hatch.build.targets.wheel.force-include] +"pyspur/templates" = "pyspur/templates/" +"pyspur/static" = "pyspur/static/" + +[tool.ruff] +line-length = 100 +target-version = "py312" + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "W", "B", "C", "D", "PYI"] +ignore = [ + "B006", # Do not use mutable default arguments + "B008", # Do not perform function call `Depends` in argument defaults + "C901", # Function is too complex + "D100", # Missing docstring in public module + "D101", # Missing docstring in public class + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic method + "D106", # Missing docstring in public nested class + "D107", # Missing docstring in __init__ + "I001", # Import block is un-sorted or un-formatted + "E402", # Module level import not at top of file +] + +[tool.black] +line-length = 100 +target-version = ["py312"] + +[tool.mypy] +python_version = "3.12" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +check_untyped_defs = true + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] diff --git a/pyspur/backend/pyspur/__init__.py b/pyspur/backend/pyspur/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyspur/backend/pyspur/__pycache__/__init__.cpython-312.pyc b/pyspur/backend/pyspur/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90c6dce10f5bed7afda90f811b6f75d55663af7a Binary files /dev/null and b/pyspur/backend/pyspur/__pycache__/__init__.cpython-312.pyc differ diff --git a/pyspur/backend/pyspur/api/__init__.py b/pyspur/backend/pyspur/api/__init__.py new file 
mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyspur/backend/pyspur/api/ai_management.py b/pyspur/backend/pyspur/api/ai_management.py new file mode 100644 index 0000000000000000000000000000000000000000..47a21a20f1759e5347db63d1a6202efa40472edd --- /dev/null +++ b/pyspur/backend/pyspur/api/ai_management.py @@ -0,0 +1,352 @@ +import json +import re +from typing import Any, Dict, List, Literal, Optional, cast + +from fastapi import APIRouter, HTTPException +from loguru import logger +from pydantic import BaseModel + +from ..nodes.llm._utils import generate_text + +router = APIRouter() + + +class SchemaGenerationRequest(BaseModel): + description: str + existing_schema: Optional[str] = None + + +class MessageGenerationRequest(BaseModel): + description: str + message_type: Literal["system", "user"] # "system" or "user" + existing_message: Optional[str] = None + context: Optional[str] = None + available_variables: Optional[List[str]] = None + + +@router.post("/generate_schema/") +async def generate_schema(request: SchemaGenerationRequest) -> Dict[str, Any]: + response: str = "" + try: + # Prepare the system message + system_message = """You are a JSON Schema expert. Your task is to generate a JSON Schema + based on a text description. + The schema should: + 1. Follow JSON Schema standards + 2. Include appropriate types, required fields, and descriptions + 3. Be clear and well-structured + 4. Include type: "object" at the root + 5. Include a properties object + 6. Set appropriate required fields + 7. Include meaningful descriptions for each field + 8. 
Return ONLY the JSON schema without any markdown formatting or explanation + + Here are some examples: + + + Input: "Create a schema for a person with name, age and optional email" + Output: { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The person's full name" + }, + "age": { + "type": "integer", + "description": "The person's age in years", + "minimum": 0 + }, + "email": { + "type": "string", + "description": "The person's email address", + "format": "email" + } + }, + "required": ["name", "age"] + } + + + + Input: "Schema for a blog post with title, content, author details and tags" + Output: { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "The title of the blog post" + }, + "content": { + "type": "string", + "description": "The main content of the blog post" + }, + "author": { + "type": "object", + "description": "Details about the post author", + "properties": { + "name": { + "type": "string", + "description": "Author's full name" + }, + "bio": { + "type": "string", + "description": "Short biography of the author" + } + }, + "required": ["name"] + }, + "tags": { + "type": "array", + "description": "List of tags associated with the post", + "items": { + "type": "string" + } + } + }, + "required": ["title", "content", "author"] + } + + """ + + # Prepare the user message + user_message = ( + f"Generate a JSON Schema for the following description:\n{request.description}" + ) + + if request.existing_schema: + user_message += ( + f"\n\nPlease consider this existing schema as context:\n{request.existing_schema}" + ) + user_message += ( + "\nModify it based on the description while preserving any compatible parts." 
+ ) + + # Call the LLM + messages = [ + {"role": "system", "content": system_message}, + {"role": "user", "content": user_message}, + ] + + message_response = await generate_text( + messages=messages, model_name="openai/o3-mini", json_mode=True + ) + assert message_response.content, "No response from LLM" + response = message_response.content + + # Try to parse the response in different ways + try: + # First try: direct JSON parse + schema = json.loads(response) + if isinstance(schema, dict) and "output" in schema: + # If we got a wrapper object with an "output" key, extract the schema from it + schema_str = cast(str, schema["output"]) + # Extract JSON from potential markdown code blocks + json_match = re.search(r"```json\s*(.*?)\s*```", schema_str, re.DOTALL) + if json_match: + schema_str = json_match.group(1) + schema = json.loads(schema_str) + except json.JSONDecodeError as e: + # Second try: Look for JSON in markdown code blocks + json_match = re.search(r"```(?:json)?\s*(.*?)\s*```", response, re.DOTALL) + if json_match: + schema = json.loads(json_match.group(1)) + else: + raise ValueError("Could not extract valid JSON schema from response") from e + + # Validate the schema structure + if not isinstance(schema, dict) or "type" not in schema or "properties" not in schema: + raise ValueError("Generated schema is not valid - missing required fields") + + return cast(Dict[str, Any], schema) + + except Exception as e: + # Log the raw response if it exists and is not empty + if response: + truncated_response = response[:1000] + "..." if len(response) > 1000 else response + logger.error(f"Schema generation failed. 
response (truncated): {truncated_response}.") + raise HTTPException(status_code=400, detail=str(e)) from e + + +@router.post("/generate_message/") +async def generate_message(request: MessageGenerationRequest) -> Dict[str, str]: + response: str = "" + try: + # Prepare the system message based on the message type + if request.message_type == "system": + system_message = """You are an expert at crafting effective \ +system messages for AI assistants. + Your task is to generate a clear, concise, and effective system message based\ +on the provided description. + + # INSTRUCTIONS + A good system message should: + 1. Clearly define the AI's role and purpose + 2. Set appropriate boundaries and constraints + 3. Provide necessary context and background information + 4. Be concise but comprehensive + 5. Use clear, unambiguous language + 6. Use XML tags when appropriate to structure information: + e.g., ..., ... + + # FORMAT REQUIREMENTS + Your generated system message MUST include: + 1. An "# Instructions" section with clearly enumerated instructions (1., 2., 3., etc.) + 2. Clear organization with appropriate headings and structure + + # EXAMPLES + Example 1 (Simple role definition): + ``` + You are a helpful coding assistant that specializes in Python programming. + + # Instructions + 1. Provide accurate Python code examples when requested + 2. Explain coding concepts clearly and concisely + 3. Suggest best practices for Python development + ``` + + Example 2 (With XML tags): + ``` + You are a data analysis expert specialized in interpreting financial data. + + # Instructions + 1. Only provide analysis based on the data provided + 2. Present findings with supporting evidence + 3. Identify trends and patterns in the data + 4. Suggest actionable insights when appropriate + + Do not make assumptions about data you cannot see. + Present your analysis with clear sections for Summary, Details, \ +and Recommendations. 
+ ``` + + Return ONLY the system message text without any additional explanation or formatting. + """ + elif request.message_type == "user": + system_message = """You are an expert at crafting effective user prompts for AI \ +assistants. + Your task is to generate a clear, specific, and effective user prompt based on the \ +provided description. + + # INSTRUCTIONS + A good user prompt should: + 1. Clearly state what is being asked of the AI + 2. Provide necessary context and specific details + 3. Be structured in a way that guides the AI to produce the desired output + 4. Use clear, unambiguous language + 5. Include any relevant constraints or requirements + 6. Use XML tags when appropriate to structure information \ +(e.g., ..., ...) + + # FORMAT REQUIREMENTS + Your generated user prompt MUST include: + 1. An "# Instructions" section with clearly enumerated instructions (1., 2., 3., etc.) + 2. Clear organization with appropriate headings and structure + + # EXAMPLES + Example 1 (Simple request): + ``` + Explain how JavaScript promises work with code examples. + + # Instructions + 1. Explain the concept in simple terms first + 2. Provide practical code examples + 3. Include error handling patterns + ``` + + Example 2 (With XML tags): + ``` + I'm building a React application with a complex state management system.\ + + + Review the following code snippet and suggest improvements for performance \ +and readability: + + + // Code would go here + + + # Instructions + 1. Identify performance bottlenecks in the code + 2. Suggest specific refactoring approaches + 3. Explain the reasoning behind each recommendation + 4. Provide example code for key improvements + ``` + + Return ONLY the user prompt text without any additional explanation or formatting. 
+ """ + else: + raise ValueError(f"Unsupported message type: {request.message_type}") + + # Prepare the user message + user_message = f"Generate a {request.message_type} message based on the following \ +description:\n{request.description}" + + if request.existing_message: + user_message += f"\n\nPlease consider this existing message as a starting \ +point:\n{request.existing_message}" + + # Add context if provided + if request.context: + user_message += f"\n\nAdditional context:\n{request.context}" + + # Add information about available template variables if provided + if request.available_variables and len(request.available_variables) > 0: + variables_str = "\n".join([f"- {var}" for var in request.available_variables]) + + if request.message_type == "system": + user_message += f"\n\nThe message should appropriately incorporate the following \ +template variables that the user has specifically selected for this message:\n{variables_str}\n\n\ +These variables will be replaced with actual values at runtime. Use them in the appropriate places \ +to make the message dynamic and context-aware." + else: # user message + user_message += f"\n\nThe prompt should appropriately incorporate the following \ +template variables that the user has specifically selected for this message:\n{variables_str}\n\n\ +These variables will be replaced with actual values at runtime. Use them in the appropriate places \ +to make the prompt dynamic and personalized." + + # Additional guidance on template variable usage + user_message += "\n\nUse the variables in the format {{ variable_name }}. Only use the \ +variables listed above - do not invent new variables." 
+ + # Prepare messages for the LLM + messages = [ + {"role": "system", "content": system_message}, + {"role": "user", "content": user_message}, + ] + + # Generate the message using OpenAI + message_response = await generate_text( + messages=messages, + model_name="openai/o3-mini", + temperature=0.7, + max_tokens=1000, + ) + response = cast(str, message_response.content) + + # Process the response to extract the message + message: str = "" + if response.strip().startswith("{") and response.strip().endswith("}"): + try: + parsed_response = json.loads(response) + if isinstance(parsed_response, dict) and "output" in parsed_response: + message = cast(str, parsed_response["output"]) + else: + message = response + except json.JSONDecodeError: + message = response + else: + message = response + + # Remove any markdown code blocks if present + if "```" in message: + message = re.sub(r"```.*?```", "", message, flags=re.DOTALL).strip() + else: + # Fallback if response is not a string (shouldn't happen) + message = str(response) + + return {"message": message} + except Exception as e: + logger.error(f"Error generating message: {str(e)}") + if response: + logger.error(f"Raw response: {response}") + raise HTTPException(status_code=500) from e diff --git a/pyspur/backend/pyspur/api/api_app.py b/pyspur/backend/pyspur/api/api_app.py new file mode 100644 index 0000000000000000000000000000000000000000..0f553b1773a41b66ff08d85faf5364688cef658d --- /dev/null +++ b/pyspur/backend/pyspur/api/api_app.py @@ -0,0 +1,53 @@ +from fastapi import FastAPI + +from ..nodes.registry import NodeRegistry + +NodeRegistry.discover_nodes() + +from ..integrations.google.auth import router as google_auth_router +from .ai_management import router as ai_management_router +from .dataset_management import router as dataset_management_router +from .evals_management import router as evals_management_router +from .file_management import router as file_management_router +from .key_management import router as 
key_management_router +from .node_management import router as node_management_router +from .openai_compatible_api import router as openai_compatible_api_router +from .openapi_management import router as openapi_router +from .output_file_management import router as output_file_management_router +from .rag_management import router as rag_management_router +from .run_management import router as run_management_router +from .session_management import router as session_management_router +from .slack_management import router as slack_management_router +from .template_management import router as template_management_router +from .user_management import router as user_management_router +from .workflow_code_convert import router as workflow_code_router +from .workflow_management import router as workflow_management_router +from .workflow_run import router as workflow_run_router + +# Create a sub-application for API routes +api_app = FastAPI( + docs_url="/docs", + redoc_url="/redoc", + title="PySpur API", + version="1.0.0", +) + +api_app.include_router(node_management_router, prefix="/node", tags=["nodes"]) +api_app.include_router(workflow_management_router, prefix="/wf", tags=["workflows"]) +api_app.include_router(workflow_run_router, prefix="/wf", tags=["workflow runs"]) +api_app.include_router(workflow_code_router, prefix="/code_convert", tags=["workflow code (beta)"]) +api_app.include_router(dataset_management_router, prefix="/ds", tags=["datasets"]) +api_app.include_router(run_management_router, prefix="/run", tags=["runs"]) +api_app.include_router(output_file_management_router, prefix="/of", tags=["output files"]) +api_app.include_router(key_management_router, prefix="/env-mgmt", tags=["environment management"]) +api_app.include_router(template_management_router, prefix="/templates", tags=["templates"]) +api_app.include_router(openai_compatible_api_router, prefix="/api", tags=["openai compatible"]) +api_app.include_router(evals_management_router, prefix="/evals", 
tags=["evaluations"]) +api_app.include_router(google_auth_router, prefix="/google", tags=["google auth"]) +api_app.include_router(rag_management_router, prefix="/rag", tags=["rag"]) +api_app.include_router(file_management_router, prefix="/files", tags=["files"]) +api_app.include_router(ai_management_router, prefix="/ai", tags=["ai"]) +api_app.include_router(user_management_router, prefix="/user", tags=["users"]) +api_app.include_router(session_management_router, prefix="/session", tags=["sessions"]) +api_app.include_router(slack_management_router, prefix="/slack", tags=["slack integration"]) +api_app.include_router(openapi_router, prefix="/openapi", tags=["openapi"]) diff --git a/pyspur/backend/pyspur/api/dataset_management.py b/pyspur/backend/pyspur/api/dataset_management.py new file mode 100644 index 0000000000000000000000000000000000000000..675f0039cbf44dc71a5bcee9594e95fd6a7cb10b --- /dev/null +++ b/pyspur/backend/pyspur/api/dataset_management.py @@ -0,0 +1,121 @@ +import os +from datetime import datetime, timezone +from typing import List + +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile +from sqlalchemy.orm import Session + +from ..database import get_db +from ..models.dataset_model import DatasetModel +from ..models.run_model import RunModel +from ..schemas.dataset_schemas import DatasetResponseSchema +from ..schemas.run_schemas import RunResponseSchema + +router = APIRouter() + + +def save_file(file: UploadFile) -> str: + filename = file.filename + assert filename is not None + file_location = os.path.join(os.path.dirname(__file__), "..", "..", "datasets", filename) + with open(file_location, "wb+") as file_object: + file_object.write(file.file.read()) + return file_location + + +@router.post("/", description="Upload a new dataset") +def upload_dataset( + name: str, + description: str = "", + file: UploadFile = File(...), + db: Session = Depends(get_db), +) -> DatasetResponseSchema: + file_location = save_file(file) + new_dataset = 
DatasetModel( + name=name, + description=description, + file_path=file_location, + uploaded_at=datetime.now(timezone.utc), + ) + db.add(new_dataset) + db.commit() + db.refresh(new_dataset) + return DatasetResponseSchema( + id=new_dataset.id, + name=new_dataset.name, + description=new_dataset.description, + filename=new_dataset.file_path, + created_at=new_dataset.uploaded_at, + updated_at=new_dataset.uploaded_at, + ) + + +@router.get( + "/", + response_model=List[DatasetResponseSchema], + description="List all datasets", +) +def list_datasets(db: Session = Depends(get_db)) -> List[DatasetResponseSchema]: + datasets = db.query(DatasetModel).all() + dataset_list = [ + DatasetResponseSchema( + id=ds.id, + name=ds.name, + description=ds.description, + filename=ds.file_path, + created_at=ds.uploaded_at, + updated_at=ds.uploaded_at, + ) + for ds in datasets + ] + return dataset_list + + +@router.get( + "/{dataset_id}/", + response_model=DatasetResponseSchema, + description="Get a dataset by ID", +) +def get_dataset(dataset_id: str, db: Session = Depends(get_db)) -> DatasetResponseSchema: + dataset = db.query(DatasetModel).filter(DatasetModel.id == dataset_id).first() + if not dataset: + raise HTTPException(status_code=404, detail="Dataset not found") + return DatasetResponseSchema( + id=dataset.id, + name=dataset.name, + description=dataset.description, + filename=dataset.file_path, + created_at=dataset.uploaded_at, + updated_at=dataset.uploaded_at, + ) + + +@router.delete( + "/{dataset_id}/", + description="Delete a dataset by ID", +) +def delete_dataset(dataset_id: str, db: Session = Depends(get_db)): + dataset = db.query(DatasetModel).filter(DatasetModel.id == dataset_id).first() + if not dataset: + raise HTTPException(status_code=404, detail="Dataset not found") + db.delete(dataset) + db.commit() + return {"message": "Dataset deleted"} + + +@router.get( + "/{dataset_id}/list_runs/", + description="List all runs that used this dataset", + 
response_model=List[RunResponseSchema], +) +def list_dataset_runs(dataset_id: str, db: Session = Depends(get_db)): + dataset = db.query(DatasetModel).filter(DatasetModel.id == dataset_id).first() + if not dataset: + raise HTTPException(status_code=404, detail="Dataset not found") + runs = ( + db.query(RunModel) + .filter(RunModel.input_dataset_id == dataset_id) + .order_by(RunModel.created_at.desc()) + .all() + ) + return runs diff --git a/pyspur/backend/pyspur/api/evals_management.py b/pyspur/backend/pyspur/api/evals_management.py new file mode 100644 index 0000000000000000000000000000000000000000..a85a92ce0de007121539b8b3889574f40b8b15a8 --- /dev/null +++ b/pyspur/backend/pyspur/api/evals_management.py @@ -0,0 +1,197 @@ +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List + +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException +from sqlalchemy.orm import Session + +from ..database import get_db +from ..evals.evaluator import load_yaml_config, prepare_and_evaluate_dataset +from ..models.eval_run_model import EvalRunModel, EvalRunStatus +from ..models.workflow_model import WorkflowModel +from ..schemas.eval_schemas import ( + EvalRunRequest, + EvalRunResponse, + EvalRunStatusEnum, +) +from ..schemas.workflow_schemas import WorkflowDefinitionSchema +from .workflow_management import get_workflow_output_variables + +router = APIRouter() + +EVALS_DIR = Path(__file__).parent.parent / "evals" / "tasks" + + +@router.get("/", description="List all available evals") +def list_evals() -> List[Dict[str, Any]]: + """ + List all available evals by scanning the tasks directory for YAML files. 
+ """ + evals = [] + if not EVALS_DIR.exists(): + raise HTTPException(status_code=500, detail="Evals directory not found") + for eval_file in EVALS_DIR.glob("*.yaml"): + try: + eval_content = load_yaml_config(yaml_path=eval_file) + metadata = eval_content.get("metadata", {}) + evals.append( + { + "name": metadata.get("name", eval_file.stem), + "description": metadata.get("description", ""), + "type": metadata.get("type", "Unknown"), + "num_samples": metadata.get("num_samples", "N/A"), + "paper_link": metadata.get("paper_link", ""), + "file_name": eval_file.name, + } + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error parsing {eval_file.name}: {e}") + return evals + + +@router.post( + "/launch/", + response_model=EvalRunResponse, + description="Launch an eval job with detailed validation and workflow integration", +) +async def launch_eval( + request: EvalRunRequest, + background_tasks: BackgroundTasks, + db: Session = Depends(get_db), +) -> EvalRunResponse: + """ + Launch an eval job by triggering the evaluator with the specified eval configuration. 
+ """ + # Validate workflow ID + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == request.workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + workflow_definition = WorkflowDefinitionSchema.model_validate(workflow.definition) + + eval_file = EVALS_DIR / f"{request.eval_name}.yaml" + if not eval_file.exists(): + raise HTTPException(status_code=404, detail="Eval configuration not found") + + try: + # Load the eval configuration + eval_config = load_yaml_config(eval_file) + + # Validate the output variable + leaf_node_output_variables = get_workflow_output_variables( + workflow_id=request.workflow_id, db=db + ) + + print(f"Valid output variables: {leaf_node_output_variables}") + + # Extract the list of valid prefixed variables + valid_prefixed_variables = [var["prefixed_variable"] for var in leaf_node_output_variables] + + if request.output_variable not in valid_prefixed_variables: + raise HTTPException( + status_code=400, + detail=( + f"Invalid output variable '{request.output_variable}'. 
" + f"Must be one of: {leaf_node_output_variables}" + ), + ) + + # Create a new EvalRunModel instance + new_eval_run = EvalRunModel( + eval_name=request.eval_name, + workflow_id=request.workflow_id, + output_variable=request.output_variable, + num_samples=request.num_samples, + status=EvalRunStatus.PENDING, + start_time=datetime.now(timezone.utc), + ) + db.add(new_eval_run) + db.commit() + db.refresh(new_eval_run) + + async def run_eval_task(eval_run_id: str): + with next(get_db()) as session: + eval_run = ( + session.query(EvalRunModel).filter(EvalRunModel.id == eval_run_id).first() + ) + if not eval_run: + session.close() + return + + eval_run.status = EvalRunStatus.RUNNING + session.commit() + + try: + # Run the evaluation asynchronously + results = await prepare_and_evaluate_dataset( + eval_config, + workflow_definition=workflow_definition, + num_samples=eval_run.num_samples, + output_variable=eval_run.output_variable, + ) + eval_run.results = results + eval_run.status = EvalRunStatus.COMPLETED + eval_run.end_time = datetime.now(timezone.utc) + except Exception as e: + eval_run.status = EvalRunStatus.FAILED + eval_run.end_time = datetime.now(timezone.utc) + session.commit() + raise e + finally: + session.commit() + + background_tasks.add_task(run_eval_task, new_eval_run.id) + + # Return all required parameters + return EvalRunResponse( + run_id=new_eval_run.id, + eval_name=new_eval_run.eval_name, + workflow_id=new_eval_run.workflow_id, + status=EvalRunStatusEnum(new_eval_run.status.value), + start_time=new_eval_run.start_time, + end_time=new_eval_run.end_time, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error launching eval: {e}") + + +@router.get( + "/runs/{eval_run_id}", + response_model=EvalRunResponse, + description="Get the status of an eval run", +) +async def get_eval_run_status(eval_run_id: str, db: Session = Depends(get_db)) -> EvalRunResponse: + eval_run = db.query(EvalRunModel).filter(EvalRunModel.id == 
eval_run_id).first() + if not eval_run: + raise HTTPException(status_code=404, detail="Eval run not found") + return EvalRunResponse( + run_id=eval_run.id, + eval_name=eval_run.eval_name, + workflow_id=eval_run.workflow_id, + status=EvalRunStatusEnum(eval_run.status.value), + start_time=eval_run.start_time, + end_time=eval_run.end_time, + results=eval_run.results, + ) + + +@router.get( + "/runs/", + response_model=List[EvalRunResponse], + description="List all eval runs", +) +async def list_eval_runs( + db: Session = Depends(get_db), +) -> List[EvalRunResponse]: + eval_runs = db.query(EvalRunModel).order_by(EvalRunModel.start_time.desc()).all() + return [ + EvalRunResponse( + run_id=eval_run.id, + eval_name=eval_run.eval_name, + workflow_id=eval_run.workflow_id, + status=EvalRunStatusEnum(eval_run.status.value), + start_time=eval_run.start_time, + end_time=eval_run.end_time, + ) + for eval_run in eval_runs + ] diff --git a/pyspur/backend/pyspur/api/file_management.py b/pyspur/backend/pyspur/api/file_management.py new file mode 100644 index 0000000000000000000000000000000000000000..f8e6ed8425e5d6c1d0185a5a1f33f623872be596 --- /dev/null +++ b/pyspur/backend/pyspur/api/file_management.py @@ -0,0 +1,144 @@ +import os +import shutil +from datetime import datetime, timezone +from pathlib import Path +from typing import List + +from fastapi import APIRouter, HTTPException +from fastapi.responses import FileResponse + +from ..schemas.file_schemas import FileResponseSchema + +router = APIRouter() + +# Define base data directory +DATA_DIR = Path("data") + + +@router.get( + "/{workflow_id}", + response_model=List[FileResponseSchema], + description="List all files for a specific workflow", +) +async def list_workflow_files(workflow_id: str) -> List[FileResponseSchema]: + """ + List all files in the workflow's directory. + Returns a list of dictionaries containing file information. 
+ """ + workflow_dir = DATA_DIR / "run_files" / workflow_id + + if not workflow_dir.exists(): + return [] + + files: List[FileResponseSchema] = [] + for file_path in workflow_dir.glob("*"): + if file_path.is_file(): + files.append( + FileResponseSchema( + name=file_path.name, + path=str(file_path.relative_to(DATA_DIR)), + size=os.path.getsize(file_path), + created=datetime.fromtimestamp(os.path.getctime(file_path), tz=timezone.utc), + workflow_id=workflow_id, + ) + ) + + return files + + +@router.get( + "/", + response_model=List[FileResponseSchema], + description="List all files across all workflows", +) +async def list_all_files() -> List[FileResponseSchema]: + """ + List all files in the data directory across all workflows. + Returns a list of dictionaries containing file information. + """ + test_files_dir = DATA_DIR / "run_files" + + if not test_files_dir.exists(): + return [] + + files: List[FileResponseSchema] = [] + for workflow_dir in test_files_dir.glob("*"): + if workflow_dir.is_dir(): + workflow_id = workflow_dir.name + for file_path in workflow_dir.glob("*"): + if file_path.is_file(): + files.append( + FileResponseSchema( + name=file_path.name, + workflow_id=workflow_id, + path=str(file_path.relative_to(DATA_DIR)), + size=os.path.getsize(file_path), + created=datetime.fromtimestamp( + os.path.getctime(file_path), tz=timezone.utc + ), + ) + ) + + return files + + +@router.delete("/{workflow_id}/{filename}", description="Delete a specific file") +async def delete_file(workflow_id: str, filename: str): + """ + Delete a specific file from a workflow's directory. 
+ """ + file_path = DATA_DIR / "run_files" / workflow_id / filename + + if not file_path.exists(): + raise HTTPException(status_code=404, detail="File not found") + + try: + os.remove(file_path) + return {"message": "File deleted successfully"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error deleting file: {str(e)}") + + +@router.delete("/{workflow_id}", description="Delete all files for a workflow") +async def delete_workflow_files(workflow_id: str): + """ + Delete all files in a workflow's directory. + """ + workflow_dir = DATA_DIR / "run_files" / workflow_id + + if not workflow_dir.exists(): + raise HTTPException(status_code=404, detail="Workflow directory not found") + + try: + shutil.rmtree(workflow_dir) + return {"message": "All workflow files deleted successfully"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error deleting workflow files: {str(e)}") + + +@router.get( + "/{file_path:path}", + description="Get a specific file", + response_class=FileResponse, +) +async def get_file(file_path: str): + """ + Get a specific file from the data directory. + Validates file path to prevent path traversal attacks. + """ + # Validate that file_path doesn't contain path traversal patterns + if ".." 
in file_path or "~" in file_path: + raise HTTPException(status_code=400, detail="Invalid file path") + + # Resolve the full path and ensure it's within DATA_DIR + try: + full_path = (DATA_DIR / file_path).resolve() + if not str(full_path).startswith(str(DATA_DIR.resolve())): + raise HTTPException(status_code=403, detail="Access denied") + except Exception: + raise HTTPException(status_code=400, detail="Invalid file path") + + if not full_path.exists(): + raise HTTPException(status_code=404, detail="File not found") + + return FileResponse(str(full_path)) diff --git a/pyspur/backend/pyspur/api/key_management.py b/pyspur/backend/pyspur/api/key_management.py new file mode 100644 index 0000000000000000000000000000000000000000..24bedb5e881108b1e735eb765b6ec844f29bee52 --- /dev/null +++ b/pyspur/backend/pyspur/api/key_management.py @@ -0,0 +1,477 @@ +import os +from typing import Dict, List, Optional + +from dotenv import dotenv_values, load_dotenv, set_key, unset_key +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from ..rag.datastore.factory import VectorStoreConfig, get_vector_stores +from ..rag.embedder import EmbeddingModelConfig, EmbeddingModels + +# Load existing environment variables from the .env file +load_dotenv(".env") + +router = APIRouter() + + +class ProviderParameter(BaseModel): + name: str + description: str + required: bool = True + type: str = "password" # password, text, select + + +class ProviderConfig(BaseModel): + id: str + name: str + description: str + category: str # 'llm', 'embedding', 'vectorstore' + parameters: List[ProviderParameter] + icon: str = "database" # Default icon for vector stores + + +PROVIDER_CONFIGS = [ + # LLM Providers + ProviderConfig( + id="openai", + name="OpenAI", + description="OpenAI's GPT models", + category="llm", + icon="openai", + parameters=[ + ProviderParameter(name="OPENAI_API_KEY", description="OpenAI API Key"), + ], + ), + ProviderConfig( + id="azure-openai", + name="Azure 
OpenAI", + description="Azure-hosted OpenAI models", + category="llm", + icon="azure", + parameters=[ + ProviderParameter(name="AZURE_OPENAI_API_KEY", description="Azure OpenAI API Key"), + ProviderParameter( + name="AZURE_OPENAI_ENDPOINT", + description="Azure OpenAI Endpoint URL", + type="text", + ), + ProviderParameter( + name="AZURE_OPENAI_API_VERSION", + description="API Version (e.g. 2023-05-15)", + type="text", + ), + ], + ), + ProviderConfig( + id="anthropic", + name="Anthropic", + description="Anthropic's Claude models", + category="llm", + icon="anthropic", + parameters=[ + ProviderParameter(name="ANTHROPIC_API_KEY", description="Anthropic API Key"), + ], + ), + ProviderConfig( + id="gemini", + name="Google Gemini", + description="Google's Gemini models", + category="llm", + icon="google", + parameters=[ + ProviderParameter(name="GEMINI_API_KEY", description="Google AI API Key"), + ], + ), + ProviderConfig( + id="deepseek", + name="DeepSeek", + description="DeepSeek's code and chat models", + category="llm", + icon="deepseek", + parameters=[ + ProviderParameter(name="DEEPSEEK_API_KEY", description="DeepSeek API Key"), + ], + ), + ProviderConfig( + id="cohere", + name="Cohere", + description="Cohere's language models", + category="llm", + icon="cohere", + parameters=[ + ProviderParameter(name="COHERE_API_KEY", description="Cohere API Key"), + ], + ), + ProviderConfig( + id="voyage", + name="Voyage AI", + description="Voyage's language models", + category="llm", + icon="voyage", + parameters=[ + ProviderParameter(name="VOYAGE_API_KEY", description="Voyage AI API Key"), + ], + ), + ProviderConfig( + id="mistral", + name="Mistral AI", + description="Mistral's language models", + category="llm", + icon="mistral", + parameters=[ + ProviderParameter(name="MISTRAL_API_KEY", description="Mistral AI API Key"), + ], + ), + # Vector Store Providers + ProviderConfig( + id="pinecone", + name="Pinecone", + description="Production-ready vector database", + 
category="vectorstore", + icon="pinecone", + parameters=[ + ProviderParameter(name="PINECONE_API_KEY", description="Pinecone API Key"), + ProviderParameter( + name="PINECONE_ENVIRONMENT", + description="Pinecone Environment", + type="text", + ), + ProviderParameter( + name="PINECONE_INDEX", + description="Pinecone Index Name", + type="text", + ), + ], + ), + ProviderConfig( + id="weaviate", + name="Weaviate", + description="Multi-modal vector search engine", + category="vectorstore", + icon="weaviate", + parameters=[ + ProviderParameter(name="WEAVIATE_API_KEY", description="Weaviate API Key"), + ProviderParameter( + name="WEAVIATE_URL", + description="Weaviate Instance URL", + type="text", + ), + ], + ), + ProviderConfig( + id="qdrant", + name="Qdrant", + description="Vector database for production", + category="vectorstore", + icon="qdrant", + parameters=[ + ProviderParameter(name="QDRANT_API_KEY", description="Qdrant API Key"), + ProviderParameter( + name="QDRANT_URL", + description="Qdrant Instance URL", + type="text", + ), + ], + ), + ProviderConfig( + id="chroma", + name="Chroma", + description="Open-source embedding database", + category="vectorstore", + icon="chroma", + parameters=[ + ProviderParameter( + name="CHROMA_IN_MEMORY", + description="Run Chroma in memory", + type="text", + ), + ProviderParameter( + name="CHROMA_PERSISTENCE_DIR", + description="Directory for Chroma persistence", + type="text", + ), + ProviderParameter( + name="CHROMA_HOST", + description="Chroma server host", + type="text", + ), + ProviderParameter( + name="CHROMA_PORT", + description="Chroma server port", + type="text", + ), + ProviderParameter( + name="CHROMA_COLLECTION", + description="Chroma collection name", + type="text", + ), + ], + ), + ProviderConfig( + id="supabase", + name="Supabase", + description="Open-source vector database", + category="vectorstore", + icon="supabase", + parameters=[ + ProviderParameter( + name="SUPABASE_URL", + description="Supabase Project URL", + 
type="text", + ), + ProviderParameter( + name="SUPABASE_ANON_KEY", + description="Supabase Anonymous Key", + type="password", + required=False, + ), + ProviderParameter( + name="SUPABASE_SERVICE_ROLE_KEY", + description="Supabase Service Role Key", + type="password", + required=False, + ), + ], + ), + # Add Reddit Provider + ProviderConfig( + id="reddit", + name="Reddit", + description="Reddit API integration", + category="social", + icon="logos:reddit-icon", + parameters=[ + ProviderParameter(name="REDDIT_CLIENT_ID", description="Reddit API Client ID"), + ProviderParameter(name="REDDIT_CLIENT_SECRET", description="Reddit API Client Secret"), + ProviderParameter( + name="REDDIT_USERNAME", description="Reddit Username", type="text", required=False + ), + ProviderParameter( + name="REDDIT_PASSWORD", + description="Reddit Password", + type="password", + required=False, + ), + ProviderParameter( + name="REDDIT_USER_AGENT", + description="Reddit API User Agent", + type="text", + required=False, + ), + ], + ), + # Add Firecrawl Provider + ProviderConfig( + id="firecrawl", + name="Firecrawl", + description="Web scraping and crawling service", + category="scraping", + icon="solar:spider-bold", + parameters=[ + ProviderParameter(name="FIRECRAWL_API_KEY", description="Firecrawl API Key"), + ], + ), + # Add Slack Provider + ProviderConfig( + id="slack", + name="Slack", + description="Slack messaging and notification service", + category="messaging", + icon="logos:slack-icon", + parameters=[ + ProviderParameter( + name="SLACK_BOT_TOKEN", description="Slack Bot User OAuth Token (starts with xoxb-)" + ), + ProviderParameter( + name="SLACK_USER_TOKEN", + description="Slack User OAuth Token (starts with xoxp-)", + required=False, + ), + ], + ), + # Add Exa Provider + ProviderConfig( + id="exa", + name="Exa", + description="Exa web search API", + category="search", + icon="solar:search-bold", + parameters=[ + ProviderParameter(name="EXA_API_KEY", description="Exa API Key"), + ], + 
), +] + +# For backward compatibility, create a flat list of all parameter names +MODEL_PROVIDER_KEYS = [ + {"name": param.name, "value": ""} for config in PROVIDER_CONFIGS for param in config.parameters +] + + +class APIKey(BaseModel): + name: str + value: Optional[str] = None + + +def get_all_env_variables() -> Dict[str, str | None]: + return dotenv_values(".env") + + +def get_env_variable(name: str) -> Optional[str]: + return os.getenv(name) + + +def set_env_variable(name: str, value: str): + """Set an environment variable both in the .env file and in the current process. + + Also ensures the value is properly quoted if it contains special characters. + """ + # Ensure the value is properly quoted if it contains spaces or special characters + if any(c in value for c in " '\"$&()|<>"): + value = f'"{value}"' + + # Update the .env file using set_key + set_key(".env", name, value) + + # Update the os.environ dictionary + os.environ[name] = value + + # Force reload of environment variables + load_dotenv(".env", override=True) + + +def delete_env_variable(name: str): + # Remove the key from the .env file + unset_key(".env", name) + # Remove the key from os.environ + os.environ.pop(name, None) + + +def mask_key_value(value: str, param_type: str = "password") -> str: + """Mask the key value based on the parameter type. + + For password types, shows only the first and last few characters. + For other types, shows the full value. 
+ """ + if param_type != "password": + return value + + visible_chars = 4 # Number of characters to show at the start and end + min_masked_chars = 4 # Minimum number of masked characters + if len(value) <= visible_chars * 2 + min_masked_chars: + return "*" * len(value) + else: + return ( + value[:visible_chars] + "*" * (len(value) - visible_chars * 2) + value[-visible_chars:] + ) + + +@router.get("/providers", description="Get all provider configurations") +async def get_providers(): + """Return all provider configurations.""" + return PROVIDER_CONFIGS + + +@router.get("/", description="Get a list of all environment variable names") +async def list_api_keys(): + """Return a list of all model provider keys.""" + return [k["name"] for k in MODEL_PROVIDER_KEYS] + + +@router.get( + "/{name}", + description="Get the masked value of a specific environment variable", +) +async def get_api_key(name: str): + """Return the masked value of the specified environment variable. + + Requires authentication. + """ + # Find the parameter configuration + param_type = "password" + for config in PROVIDER_CONFIGS: + for param in config.parameters: + if param.name == name: + param_type = param.type + break + + if name not in [k["name"] for k in MODEL_PROVIDER_KEYS]: + raise HTTPException(status_code=404, detail="Key not found") + value = get_env_variable(name) + if value is None: + value = "" + masked_value = mask_key_value(value, param_type) + return APIKey(name=name, value=masked_value) + + +@router.post("/", description="Add or update an environment variable") +async def set_api_key(api_key: APIKey): + """Add a new environment variable or updates an existing one. + + Requires authentication. 
+ """ + if api_key.name not in [k["name"] for k in MODEL_PROVIDER_KEYS]: + raise HTTPException(status_code=404, detail="Key not found") + if not api_key.value: + raise HTTPException(status_code=400, detail="Value is required") + set_env_variable(api_key.name, api_key.value) + return {"message": f"Key '{api_key.name}' set successfully"} + + +@router.delete("/{name}", description="Delete an environment variable") +async def delete_api_key(name: str): + """Delete the specified environment variable. + + Requires authentication. + """ + if name not in [k["name"] for k in MODEL_PROVIDER_KEYS]: + raise HTTPException(status_code=404, detail="Key not found") + if get_env_variable(name) is None: + raise HTTPException(status_code=404, detail="Key not found") + delete_env_variable(name) + return {"message": f"Key '{name}' deleted successfully"} + + +@router.get("/embedding-models/", response_model=Dict[str, EmbeddingModelConfig]) +async def get_embedding_models() -> Dict[str, EmbeddingModelConfig]: + """Get all available embedding models and their configurations.""" + try: + models: Dict[str, EmbeddingModelConfig] = {} + for model in EmbeddingModels: + model_info = EmbeddingModels.get_model_info(model.value) + if model_info: + # Find the corresponding provider config + provider_config = next( + (p for p in PROVIDER_CONFIGS if p.id == model_info.provider.value.lower()), + None, + ) + if provider_config: + # Add required environment variables from the provider config + model_info.required_env_vars = [ + p.name for p in provider_config.parameters if p.required + ] + models[model.value] = model_info + return models + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get("/vector-stores/", response_model=Dict[str, VectorStoreConfig]) +async def get_vector_stores_endpoint() -> Dict[str, VectorStoreConfig]: + """Get all available vector stores and their configurations.""" + try: + stores = get_vector_stores() + # Add required environment 
@router.get("/anon-data/", description="Get the status of anonymous telemetry data")
async def get_anon_data_status() -> bool:
    """Report whether anonymous telemetry has been switched off.

    Returns:
        ``True`` when the ``DISABLE_ANONYMOUS_TELEMETRY`` environment variable
        equals ``"true"`` (compared case-insensitively); ``False`` when it is
        unset or holds any other value.
    """
    raw_flag = os.getenv("DISABLE_ANONYMOUS_TELEMETRY", "false")
    telemetry_disabled = raw_flag.lower() == "true"
    return telemetry_disabled
@app.get("/{full_path:path}", include_in_schema=False)
async def serve_frontend(full_path: str):
    """Serve the exported frontend for any path not claimed by another route.

    Resolution order for a non-empty ``full_path`` (surrounding slashes removed):

    1. ``<path>/index.html`` when ``<path>`` is a directory,
    2. ``<path>.html`` (pages such as ``dashboard.html``),
    3. ``<path>`` itself when it is a regular file,
    4. ``[id].html`` next to the requested leaf (dynamic route pages),
    5. the root ``index.html`` so client-side routing can take over.

    Every candidate is resolved and must stay inside the static directory.
    A request using ``..`` segments or an absolute path therefore never reaches
    files outside it and falls through to the root ``index.html``.

    Args:
        full_path: Request path captured by the catch-all route (untrusted).

    Returns:
        A ``FileResponse`` for the selected file.

    Raises:
        RuntimeError: If the static directory has not been initialised.
    """
    if not temporary_static_dir:
        raise RuntimeError("Static directory not initialized")

    static_root = temporary_static_dir.resolve()
    root_index = static_root.joinpath("index.html")

    def _inside_root(*segments: str):
        """Join ``segments`` under the static root; return None if the result escapes it."""
        resolved = static_root.joinpath(*segments).resolve()
        if resolved == static_root or static_root in resolved.parents:
            return resolved
        return None

    # Strip surrounding slashes: a leading "/" would make joinpath() discard the
    # static root entirely, and a trailing one is irrelevant for lookup.
    relative = full_path.strip("/")
    if relative == "":
        return FileResponse(root_index)

    candidate = _inside_root(relative)

    # If the candidate is a directory, try its index.html.
    if candidate is not None and candidate.is_dir():
        candidate_index = candidate.joinpath("index.html")
        if candidate_index.is_file():
            return FileResponse(candidate_index)

    # If no direct file, try appending ".html" (for files like dashboard.html).
    candidate_html = _inside_root(relative + ".html")
    if candidate_html is not None and candidate_html.is_file():
        return FileResponse(candidate_html)

    # Serve the candidate only when it is a regular file; a bare directory
    # cannot be sent by FileResponse.
    if candidate is not None and candidate.is_file():
        return FileResponse(candidate)

    # Check whether the parent directory contains a file named "[id].html".
    parts = relative.split("/")
    if len(parts) >= 2:
        dynamic_file = _inside_root(*parts[:-1], "[id].html")
        if dynamic_file is not None and dynamic_file.is_file():
            return FileResponse(dynamic_file)

    # Fallback: serve the main index.html for client-side routing.
    return FileResponse(root_index)
input_schema = {} + try: + output_schema = node_class.output_model.model_json_schema() + except AttributeError: + output_schema = {} + + # Get the config schema and update its title with the display name + config_schema = node_class.config_model.model_json_schema() + config_schema["title"] = node_type.display_name + has_fixed_output = node_class.config_model.model_fields["has_fixed_output"].default + + node_schema: Dict[str, Any] = { + "name": node_type.node_type_name, + "input": input_schema, + "output": output_schema, + "config": config_schema, + "visual_tag": node_class.get_default_visual_tag().model_dump(), + "has_fixed_output": has_fixed_output, + } + + # Add model constraints if this is an LLM node + if node_type.node_type_name in ["LLMNode", "SingleLLMCallNode"]: + model_constraints = {} + for model_enum in LLMModels: + model_info = LLMModels.get_model_info(model_enum.value) + if model_info: + model_constraints[model_enum.value] = model_info.constraints.model_dump() + node_schema["model_constraints"] = model_constraints + + # Add the logo if available + logo = node_type.logo + if logo: + node_schema["logo"] = logo + + category = node_type.category + if category: + node_schema["category"] = category + + node_schemas.append(node_schema) + response[group_name] = node_schemas + + return response diff --git a/pyspur/backend/pyspur/api/openai_compatible_api.py b/pyspur/backend/pyspur/api/openai_compatible_api.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5da47294cf3c9e65abc4e5e94d28faa9b8e223 --- /dev/null +++ b/pyspur/backend/pyspur/api/openai_compatible_api.py @@ -0,0 +1,107 @@ +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Union + +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from ..database import get_db +from ..models.workflow_model import WorkflowModel +from ..schemas.run_schemas import 
StartRunRequestSchema +from .workflow_run import run_workflow_blocking + +router = APIRouter() + + +# Define the request schema for OpenAI-compatible chat completions +class ChatCompletionRequest(BaseModel): + model: str + messages: List[Dict[str, Any]] + functions: Optional[List[Dict[str, Any]]] = None + function_call: Optional[Union[Dict[str, Any], str]] = None + temperature: float = 0.7 + top_p: float = 0.9 + n: int = 1 + stream: bool = False + stop: Optional[Union[str, List[str]]] = None + max_tokens: Optional[int] = None + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + logit_bias: Optional[Dict[str, float]] = None + user: Optional[str] = None + + +# Define the response schema for OpenAI-compatible chat completions +class ChatCompletionResponse(BaseModel): + id: str + object: str + created: int + model: str + choices: List[Dict[str, Any]] + usage: Dict[str, int] + + +@router.post( + "/v1/chat/completions", + response_model=ChatCompletionResponse, + description="OpenAI-compatible chat completions endpoint", +) +async def chat_completions( + request: ChatCompletionRequest, + background_tasks: BackgroundTasks, + db: Session = Depends(get_db), +) -> ChatCompletionResponse: + """ + Mimics OpenAI's /v1/chat/completions endpoint for chat-based workflows. 
+ """ + # Fetch the workflow (model maps to workflow_id) + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == request.model).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + # Get the latest user message + latest_user_message = next( + (message["content"] for message in reversed(request.messages) if message["role"] == "user"), + None, + ) + if not latest_user_message: + raise HTTPException(status_code=400, detail="No user message found in messages") + + # Prepare initial inputs with the latest user message + initial_inputs = {"message": {"value": latest_user_message}} + + # Start a blocking workflow run with the initial inputs + start_run_request = StartRunRequestSchema( + initial_inputs=initial_inputs, + parent_run_id=None, + ) + outputs = await run_workflow_blocking( + workflow_id=request.model, + request=start_run_request, + db=db, + run_type="openai", + ) + + # Format the response with outputs from the workflow + response = ChatCompletionResponse( + id=f"chatcmpl-{datetime.now(timezone.utc).timestamp()}", + object="chat.completion", + created=int(datetime.now(timezone.utc).timestamp()), + model=request.model, + choices=[ + { + "message": { + "role": "assistant", + "content": outputs.get("response", {}).get("value", ""), + }, + "index": 0, + "finish_reason": outputs.get("finish_reason", "stop"), + } + ], + usage={ + "prompt_tokens": outputs.get("prompt_tokens", 0), + "completion_tokens": outputs.get("completion_tokens", 0), + "total_tokens": outputs.get("total_tokens", 0), + }, + ) + return response diff --git a/pyspur/backend/pyspur/api/openapi_management.py b/pyspur/backend/pyspur/api/openapi_management.py new file mode 100644 index 0000000000000000000000000000000000000000..898e106948f8109b06bcec90d1c4bce35c149358 --- /dev/null +++ b/pyspur/backend/pyspur/api/openapi_management.py @@ -0,0 +1,180 @@ +import json +import os +from typing import Dict, List, Optional +from uuid import uuid4 + +from 
# Keys of an OpenAPI Path Item Object that denote operations. Every other key
# ("parameters", "summary", "description", "servers", "$ref", "x-*") is metadata
# and must not be treated as an endpoint.
_HTTP_METHODS = frozenset({"get", "put", "post", "delete", "options", "head", "patch", "trace"})

# OpenAPI parameter location -> property name used in the stored input schema.
_PARAMETER_GROUPS = {
    "path": "pathParameters",
    "query": "queryParameters",
    "header": "headerParameters",
}


def _build_input_schema(operation: Dict, shared_parameters: List) -> Dict:
    """Describe an operation's path/query/header parameters and request body."""
    # Path-level parameters apply to every operation; an operation-level
    # parameter with the same (name, location) overrides the shared one.
    merged: Dict = {}
    for param in list(shared_parameters) + list(operation.get("parameters") or []):
        if isinstance(param, dict) and "name" in param:
            merged[(param["name"], param.get("in"))] = param

    input_schema: Dict = {"properties": {}}
    for location, group_name in _PARAMETER_GROUPS.items():
        group = {
            param["name"]: param.get("schema", {})
            for param in merged.values()
            if param.get("in") == location
        }
        if group:
            input_schema["properties"][group_name] = {"type": "object", "properties": group}

    request_body = operation.get("requestBody")
    if request_body:
        content = request_body.get("content", {})
        if content:
            # Only the first declared media type is recorded.
            media_type = next(iter(content))
            input_schema["properties"]["requestBody"] = {
                "mediaType": media_type,
                "schema": content[media_type].get("schema", {}),
            }
    return input_schema


def _build_output_schema(operation: Dict) -> Dict:
    """Describe an operation's responses, keyed by status code."""
    output_schema: Dict = {"properties": {}}
    for status_code, response in (operation.get("responses") or {}).items():
        content = response.get("content")
        if content:
            media_type = next(iter(content))
            schema = content[media_type].get("schema", {})
        else:
            # Responses without a body are recorded as JSON with an empty schema.
            media_type = "application/json"
            schema = {}
        output_schema["properties"][status_code] = {
            "description": response.get("description", ""),
            "mediaType": media_type,
            "schema": schema,
        }
    return output_schema


@router.post("/specs/", response_model=OpenAPISpec)
async def create_openapi_spec(request: CreateSpecRequest) -> OpenAPISpec:
    """Parse an OpenAPI specification and store it as ``<id>.json``.

    Each HTTP operation under ``paths`` becomes one ``OpenAPIEndpoint``.
    Non-operation keys of a path item are skipped, and path-level parameters
    are merged into each operation's input schema.

    Args:
        request: Wrapper holding the raw OpenAPI document as a dict.

    Returns:
        The stored ``OpenAPISpec`` with a freshly generated ID.

    Raises:
        HTTPException: 500 if the document cannot be parsed or written.
    """
    try:
        # Generate a unique ID for this spec
        spec_id = str(uuid4())

        # Extract basic info from the spec
        info = request.spec.get("info", {})

        endpoints: List[OpenAPIEndpoint] = []
        for path, path_item in request.spec.get("paths", {}).items():
            if not isinstance(path_item, dict):
                continue
            shared_parameters = path_item.get("parameters") or []
            for method, operation in path_item.items():
                if method.lower() not in _HTTP_METHODS or not isinstance(operation, dict):
                    continue
                endpoints.append(
                    OpenAPIEndpoint(
                        path=path,
                        method=method.upper(),
                        summary=operation.get("summary"),
                        operationId=operation.get("operationId"),
                        description=operation.get("description"),
                        input_schema=_build_input_schema(operation, shared_parameters),
                        output_schema=_build_output_schema(operation),
                    )
                )

        spec_data = OpenAPISpec(
            id=spec_id,
            name=info.get("title", "Untitled API"),
            description=info.get("description", ""),
            version=info.get("version", "1.0.0"),
            endpoints=endpoints,
            raw_spec=request.spec,
        )

        # Save the spec to a file. model_dump() replaces the deprecated
        # pydantic v1 .dict(); the project already relies on the v2 API elsewhere.
        spec_path = os.path.join(OPENAPI_SPECS_DIR, f"{spec_id}.json")
        with open(spec_path, "w", encoding="utf-8") as f:
            json.dump(spec_data.model_dump(), f, indent=2)

        return spec_data
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.delete("/specs/{spec_id}")
async def delete_openapi_spec(spec_id: str) -> Dict[str, str]:
    """Delete a stored OpenAPI specification by ID.

    Args:
        spec_id: ID of the specification; its file is ``<spec_id>.json`` inside
            ``OPENAPI_SPECS_DIR``.

    Returns:
        A dict whose ``message`` entry confirms the deletion.

    Raises:
        HTTPException: 404 if no such specification exists (including IDs that
            carry a path component); 500 if the file cannot be removed.
    """
    # A spec ID is a bare file stem. Anything with a directory component could
    # address a file outside OPENAPI_SPECS_DIR, so treat it as unknown.
    if os.path.basename(spec_id) != spec_id:
        raise HTTPException(status_code=404, detail="Specification not found")

    spec_path = os.path.join(OPENAPI_SPECS_DIR, f"{spec_id}.json")
    try:
        # Remove directly rather than checking exists() first, so a concurrent
        # delete cannot turn a "not found" into a 500.
        os.remove(spec_path)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="Specification not found") from None
    except OSError as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "Specification deleted successfully"}
# download_output_file endpoint
@router.get(
    "/{output_file_id}/download/",
    description="Download an output file by ID",
)
def download_output_file(output_file_id: str, db: Session = Depends(get_db)):
    """Send a stored output file to the client as an attachment.

    Args:
        output_file_id: ID of the ``OutputFileModel`` row.
        db: Database session.

    Returns:
        A ``FileResponse`` streaming the file with a media type chosen from the
        file name's suffix (``application/octet-stream`` when unrecognised).

    Raises:
        HTTPException: 404 if the record does not exist or its file is no
            longer present on disk.
    """
    import os  # local import: nothing else in this module needs os

    output_file = db.query(OutputFileModel).filter(OutputFileModel.id == output_file_id).first()
    if not output_file:
        raise HTTPException(status_code=404, detail="Output file not found")

    # A database row whose file is gone should be a 404, not a server error
    # raised while the response is being sent.
    if not output_file.file_path or not os.path.isfile(output_file.file_path):
        raise HTTPException(status_code=404, detail="Output file not found")

    # Get the appropriate media type based on the file extension.
    media_types_by_suffix = {
        ".csv": "text/csv",
        ".json": "application/json",
        ".txt": "text/plain",
        ".jsonl": "application/x-ndjson",
    }
    media_type = next(
        (
            candidate
            for suffix, candidate in media_types_by_suffix.items()
            if output_file.file_name.endswith(suffix)
        ),
        "application/octet-stream",
    )

    # Let FileResponse build Content-Disposition from ``filename``. It quotes
    # and encodes the name, whereas a hand-written header takes precedence over
    # the generated one and breaks on spaces, quotes or non-ASCII characters.
    return FileResponse(
        output_file.file_path,
        media_type=media_type,
        filename=output_file.file_name,
        content_disposition_type="attachment",
    )
async def update_index_progress(
    index_id: str,
    status: Optional[str] = None,
    progress: Optional[float] = None,
    current_step: Optional[str] = None,
    total_chunks: Optional[int] = None,
    processed_chunks: Optional[int] = None,
    error_message: Optional[str] = None,
    db: Optional[Session] = None,
) -> None:
    """Persist vector index processing progress.

    Creates the ``DocumentProcessingProgressModel`` row for ``index_id`` on first
    use and updates it afterwards. Whenever ``status`` is given, the matching
    ``VectorIndexModel`` row is synchronised as well, whether or not the
    progress row already existed, so a failure reported before any progress
    callback still records its error message on the index.

    Args:
        index_id: ID shared by the progress row and the vector index row.
        status: New status; ``"completed"`` is stored on the index as ``"ready"``.
        progress: Progress value; stored as a float.
        current_step: Description of the step in progress.
        total_chunks: Total number of chunks to process.
        processed_chunks: Number of chunks processed so far.
        error_message: Error text to store on the progress row and, when
            ``status`` is given, on the index.
        db: Database session; the call does nothing when it is missing.
    """
    if not db:
        return

    now = datetime.now(timezone.utc)

    # Get or create the progress record
    progress_record = (
        db.query(DocumentProcessingProgressModel)
        .filter(DocumentProcessingProgressModel.id == index_id)
        .first()
    )

    if not progress_record:
        progress_record = DocumentProcessingProgressModel(
            id=index_id,
            created_at=now,
            updated_at=now,
            status=status or "processing",
            progress=float(progress or 0.0),
            current_step=current_step or "",
            total_chunks=int(total_chunks or 0),
            processed_chunks=int(processed_chunks or 0),
            error_message=error_message,
        )
        db.add(progress_record)
    else:
        # Only overwrite the fields that were actually supplied.
        if status:
            progress_record.status = status
        if progress is not None:
            progress_record.progress = float(progress)
        if current_step:
            progress_record.current_step = current_step
        if total_chunks is not None:
            progress_record.total_chunks = int(total_chunks)
        if processed_chunks is not None:
            progress_record.processed_chunks = int(processed_chunks)
        if error_message:
            progress_record.error_message = error_message
        progress_record.updated_at = now

    if status:
        # Mirror the status onto the index itself.
        index = db.query(VectorIndexModel).filter(VectorIndexModel.id == index_id).first()
        if index:
            # cast() is a typing aid only; the string value is stored as-is.
            index.status = cast(DocumentStatus, "ready" if status == "completed" else status)
            if error_message:
                index.error_message = error_message
            if processed_chunks:
                index.chunk_count = int(processed_chunks)

    db.commit()
def update_index_status(index_id: str, status: str, db: Session) -> None: + """Update vector index status in database.""" + try: + index = db.query(VectorIndexModel).filter(VectorIndexModel.id == index_id).first() + if index: + # Convert string status to DocumentStatus enum + new_status = cast( + DocumentStatus, + "ready" if status == "ready" else "failed" if status == "failed" else "processing", + ) + index.status = new_status + index.updated_at = datetime.now(timezone.utc) + db.commit() + except Exception as e: + logger.error(f"Error updating index status: {e}") + + +async def process_vector_index_creation( + index_id: str, + docs_with_chunks: List[DocumentWithChunksSchema], + config: Dict[str, Any], + db: Session, +) -> None: + """Process vector index creation in background.""" + try: + vector_index = VectorIndex(index_id) + await vector_index.create_from_document_collection( + docs_with_chunks, + config, + lambda p, s, pc, tc: update_index_progress( + index_id, + progress=p, + current_step=s, + processed_chunks=pc, + total_chunks=tc, + db=db, + ), + ) + # Update index status to ready on successful completion + await update_index_status(index_id, "ready", db) + except Exception as e: + logger.error(f"Error processing vector index: {e}") + await update_index_status(index_id, "failed", db) + await update_index_progress(index_id, status="failed", error_message=str(e), db=db) + + +async def update_collection_status(collection_id: str, status: str, db: Session) -> None: + """Update document collection status in database.""" + try: + collection = ( + db.query(DocumentCollectionModel) + .filter(DocumentCollectionModel.id == collection_id) + .first() + ) + if collection: + # Convert string status to DocumentStatus enum + new_status = cast( + DocumentStatus, + "ready" if status == "ready" else "failed" if status == "failed" else "processing", + ) + collection.status = new_status + collection.updated_at = datetime.now(timezone.utc) + db.commit() + except Exception as e: + 
logger.error(f"Error updating collection status: {e}") + + +async def process_document_collection( + collection_id: str, + file_infos: List[Dict[str, Any]], + config: Dict[str, Any], + db: Session, +) -> None: + """Process document collection in background.""" + try: + doc_store = DocumentStore(collection_id) + + # Create progress callback + async def progress_callback(progress: float, step: str, processed: int, total: int) -> None: + await update_collection_progress( + collection_id, + progress=progress, + current_step=step, + processed_files=processed if step == "parsing" else None, + processed_chunks=processed if step == "chunking" else None, + total_chunks=total if step == "chunking" else None, + db=db, + ) + + await doc_store.process_documents( + file_infos, + config, + progress_callback, + ) + # Update collection status to ready on successful completion + await update_collection_status(collection_id, "ready", db) + except Exception as e: + logger.error(f"Error processing document collection: {e}") + await update_collection_status(collection_id, "failed", db) + await update_collection_progress( + collection_id, status="failed", error_message=str(e), db=db + ) + + +router = APIRouter() + + +@router.post( + "/collections/", + response_model=DocumentCollectionResponseSchema, + description="Create a new document collection from uploaded files and metadata", +) +async def create_document_collection( + background_tasks: BackgroundTasks, + files: List[UploadFile] = File(None), + metadata: str = Form(...), + db: Session = Depends(get_db), +): + """Create a new document collection.""" + try: + # Parse metadata + metadata_dict = json.loads(metadata) + collection_config = DocumentCollectionCreateSchema(**metadata_dict) + + # Validate vision model configuration if enabled + if collection_config.text_processing.use_vision_model: + vision_config = collection_config.text_processing.get_vision_config() + if not vision_config: + raise HTTPException( + status_code=400, 
detail="Invalid vision model configuration" + ) from None + + # Get current timestamp + now = datetime.now(timezone.utc) + + # Create document collection record + collection = DocumentCollectionModel( + name=collection_config.name, + description=collection_config.description, + status="ready" if not files else "processing", + document_count=len(files) if files else 0, + chunk_count=0, + text_processing_config=collection_config.text_processing.model_dump(), + created_at=now, + updated_at=now, + ) + db.add(collection) + db.commit() + db.refresh(collection) + + # Process files if present + if files: + # Read files and prepare file info + file_infos: List[Dict[str, Any]] = [] + collection_dir = Path(f"data/knowledge_bases/{collection.id}") + collection_dir.mkdir(parents=True, exist_ok=True) + + for file in files: + if file.filename: + file_path = collection_dir / file.filename + content = await file.read() + with open(file_path, "wb") as f: + f.write(content) + file_infos.append( + { + "path": str(file_path), + "mime_type": file.content_type, + "name": file.filename, + } + ) + + # Start background processing with new function + background_tasks.add_task( + process_document_collection, + collection.id, + file_infos, + collection_config.text_processing.model_dump(), + db, + ) + + # Create response + return DocumentCollectionResponseSchema( + id=collection.id, + name=collection.name, + description=collection.description, + status=collection.status, + created_at=collection.created_at.isoformat(), + updated_at=collection.updated_at.isoformat(), + document_count=collection.document_count, + chunk_count=collection.chunk_count, + ) + + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) from e + + +@router.post( + "/indices/", + response_model=VectorIndexResponseSchema, + description="Create a new vector index from a document collection", +) +async def create_vector_index( + background_tasks: BackgroundTasks, + index_config: VectorIndexCreateSchema, + 
db: Session = Depends(get_db), +): + """Create a new vector index from a document collection.""" + try: + # Check if collection exists + collection = ( + db.query(DocumentCollectionModel) + .filter(DocumentCollectionModel.id == index_config.collection_id) + .first() + ) + if not collection: + raise HTTPException(status_code=404, detail="Document collection not found") from None + + # Create vector index record + now = datetime.now(timezone.utc) + index = VectorIndexModel( + name=index_config.name, + description=index_config.description, + status="processing", + document_count=collection.document_count, + chunk_count=collection.chunk_count, + embedding_config=index_config.embedding.model_dump(), + collection_id=collection.id, + created_at=now, + updated_at=now, + ) + db.add(index) + db.commit() + db.refresh(index) + + # Initialize progress tracking in database + progress_record = DocumentProcessingProgressModel( + id=index.id, + status="processing", + progress=0.0, + current_step="initializing", + total_files=int(collection.document_count), + processed_files=0, + total_chunks=int(collection.chunk_count), + processed_chunks=0, + created_at=now, + updated_at=now, + ) + db.add(progress_record) + db.commit() + logger.debug(f"Initialized progress tracking for index {index.id}") + + # Get documents with chunks + doc_store = DocumentStore(collection.id) + docs_with_chunks: List[DocumentWithChunksSchema] = [] + for doc_id in doc_store.list_documents(): + doc = doc_store.get_document(doc_id) + if doc: + docs_with_chunks.append(doc) + + # Start background processing with new function + background_tasks.add_task( + process_vector_index_creation, + index.id, + docs_with_chunks, + index_config.embedding.model_dump(), + db, + ) + + # Create response + return VectorIndexResponseSchema( + id=index.id, + name=index.name, + description=index.description, + collection_id=index.collection_id, + status=index.status, + created_at=index.created_at.isoformat(), + 
updated_at=index.updated_at.isoformat(), + document_count=index.document_count, + chunk_count=index.chunk_count, + embedding_model=index_config.embedding.model, + vector_db=index_config.embedding.vector_db, + ) + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.delete( + "/indices/{index_id}/", + description="Delete a vector index and its associated data", +) +async def delete_vector_index(index_id: str, db: Session = Depends(get_db)): + """Delete a vector index.""" + try: + # Get the vector index from the database + index = db.query(VectorIndexModel).filter(VectorIndexModel.id == index_id).first() + if not index: + raise HTTPException(status_code=404, detail="Vector index not found") from None + + # Delete from vector store and filesystem + vector_index = VectorIndex(index.id) + success = await vector_index.delete() + if not success: + raise HTTPException( + status_code=500, detail="Failed to delete vector index data" + ) from None + + # Remove from tracking database + db.delete(index) + db.commit() + + return {"message": "Vector index deleted successfully"} + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get( + "/collections/", + response_model=List[DocumentCollectionResponseSchema], + description="List all document collections", +) +async def list_document_collections(db: Session = Depends(get_db)): + """List all document collections.""" + try: + collections = db.query(DocumentCollectionModel).all() + return [ + DocumentCollectionResponseSchema( + id=collection.id, + name=collection.name, + description=collection.description, + status=collection.status, + created_at=collection.created_at.isoformat(), + updated_at=collection.updated_at.isoformat(), + document_count=collection.document_count, + chunk_count=collection.chunk_count, + error_message=collection.error_message, + ) + for collection in 
collections + ] + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get( + "/collections/{collection_id}/", + response_model=DocumentCollectionResponseSchema, +) +async def get_document_collection(collection_id: str, db: Session = Depends(get_db)): + """Get document collection details.""" + try: + collection = ( + db.query(DocumentCollectionModel) + .filter(DocumentCollectionModel.id == collection_id) + .first() + ) + if not collection: + raise HTTPException(status_code=404, detail="Document collection not found") from None + + return DocumentCollectionResponseSchema( + id=collection.id, + name=collection.name, + description=collection.description, + status=collection.status, + created_at=collection.created_at.isoformat(), + updated_at=collection.updated_at.isoformat(), + document_count=collection.document_count, + chunk_count=collection.chunk_count, + error_message=collection.error_message, + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.delete( + "/collections/{collection_id}/", + description="Delete a document collection and its associated data", +) +async def delete_document_collection(collection_id: str, db: Session = Depends(get_db)): + """Delete a document collection.""" + try: + # Get the document collection from the database + collection = ( + db.query(DocumentCollectionModel) + .filter(DocumentCollectionModel.id == collection_id) + .first() + ) + if not collection: + raise HTTPException(status_code=404, detail="Document collection not found") from None + + # Delete files from filesystem + collection_dir = Path(f"data/knowledge_bases/{collection_id}") + if collection_dir.exists(): + import shutil + + shutil.rmtree(collection_dir) + + # Remove from tracking database + db.delete(collection) + db.commit() + + return {"message": "Document collection deleted successfully"} + except HTTPException: + raise + except Exception as e: 
+ raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get( + "/indices/", + response_model=List[VectorIndexResponseSchema], + description="List all vector indices", +) +async def list_vector_indices(db: Session = Depends(get_db)): + """List all vector indices.""" + try: + indices = db.query(VectorIndexModel).all() + return [ + VectorIndexResponseSchema( + id=index.id, + name=index.name, + description=index.description, + collection_id=index.collection_id, + status=index.status, + created_at=index.created_at.isoformat(), + updated_at=index.updated_at.isoformat(), + document_count=index.document_count, + chunk_count=index.chunk_count, + error_message=index.error_message, + embedding_model=index.embedding_config["model"], + vector_db=index.embedding_config["vector_db"], + ) + for index in indices + ] + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get( + "/indices/{index_id}/", + response_model=VectorIndexResponseSchema, + description="Get details of a specific vector index", +) +async def get_vector_index(index_id: str, db: Session = Depends(get_db)): + """Get vector index details.""" + try: + index = db.query(VectorIndexModel).filter(VectorIndexModel.id == index_id).first() + if not index: + raise HTTPException(status_code=404, detail="Vector index not found") from None + + return VectorIndexResponseSchema( + id=index.id, + name=index.name, + description=index.description, + collection_id=index.collection_id, + status=index.status, + created_at=index.created_at.isoformat(), + updated_at=index.updated_at.isoformat(), + document_count=index.document_count, + chunk_count=index.chunk_count, + error_message=index.error_message, + embedding_model=index.embedding_config["model"], + vector_db=index.embedding_config["vector_db"], + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +# Add progress tracking endpoints +@router.get( + 
"/collections/{collection_id}/progress/", + response_model=ProcessingProgressSchema, +) +async def get_collection_progress(collection_id: str): + """Get document collection processing progress.""" + if collection_id not in collection_progress: + raise HTTPException(status_code=404, detail="No progress information found") from None + return collection_progress[collection_id] + + +@router.get( + "/indices/{index_id}/progress/", + response_model=ProcessingProgressSchema, + description="Get the processing progress of a vector index", +) +async def get_index_progress(index_id: str, db: Session = Depends(get_db)): + """Get vector index processing progress.""" + logger.debug(f"Getting progress for index {index_id}") + + progress_record = ( + db.query(DocumentProcessingProgressModel) + .filter(DocumentProcessingProgressModel.id == index_id) + .first() + ) + + if not progress_record: + raise HTTPException(status_code=404, detail="No progress information found") from None + + logger.debug(f"Progress data for index {index_id}: {progress_record.__dict__}") + + return ProcessingProgressSchema( + id=str(progress_record.id), + status=str(progress_record.status), + progress=float(progress_record.progress), + current_step=str(progress_record.current_step), + total_files=int(progress_record.total_files), + processed_files=int(progress_record.processed_files), + total_chunks=int(progress_record.total_chunks), + processed_chunks=int(progress_record.processed_chunks), + error_message=( + str(progress_record.error_message) if progress_record.error_message else None + ), + created_at=progress_record.created_at.isoformat(), + updated_at=progress_record.updated_at.isoformat(), + ) + + +@router.post( + "/collections/{collection_id}/documents/", + response_model=DocumentCollectionResponseSchema, +) +async def add_documents_to_collection( + collection_id: str, + background_tasks: BackgroundTasks, + files: List[UploadFile] = File(...), + db: Session = Depends(get_db), +): + """Add documents to 
an existing collection.""" + try: + # Get the document collection + collection = ( + db.query(DocumentCollectionModel) + .filter(DocumentCollectionModel.id == collection_id) + .first() + ) + if not collection: + raise HTTPException(status_code=404, detail="Document collection not found") from None + + # Read files and prepare file info + file_infos: List[Dict[str, Any]] = [] + collection_dir = Path(f"data/knowledge_bases/{collection.id}") + collection_dir.mkdir(parents=True, exist_ok=True) + + for file in files: + if file.filename: + file_path = collection_dir / file.filename + content = await file.read() + with open(file_path, "wb") as f: + f.write(content) + file_infos.append( + { + "path": str(file_path), + "mime_type": file.content_type, + "name": file.filename, + } + ) + + # Update collection status + collection.status = "processing" + collection.document_count += len(files) + db.commit() + db.refresh(collection) + + # Start background processing + if file_infos: + doc_store = DocumentStore(collection.id) + + # Create progress callback + async def progress_callback( + progress: float, step: str, processed: int, total: int + ) -> None: + await update_collection_progress( + collection.id, + progress=progress, + current_step=step, + processed_files=processed if step == "parsing" else None, + processed_chunks=processed if step == "chunking" else None, + total_chunks=total if step == "chunking" else None, + db=db, + ) + + background_tasks.add_task( + doc_store.process_documents, + file_infos, + collection.text_processing_config, + progress_callback, + ) + + return DocumentCollectionResponseSchema( + id=collection.id, + name=collection.name, + description=collection.description, + status=collection.status, + created_at=collection.created_at.isoformat(), + updated_at=collection.updated_at.isoformat(), + document_count=collection.document_count, + chunk_count=collection.chunk_count, + error_message=collection.error_message, + ) + + except Exception as e: + raise 
HTTPException(status_code=400, detail=str(e)) from e + + +@router.delete("/collections/{collection_id}/documents/{document_id}/") +async def delete_document_from_collection( + collection_id: str, + document_id: str, + db: Session = Depends(get_db), +): + """Delete a document from a collection.""" + try: + # Get the document collection + collection = ( + db.query(DocumentCollectionModel) + .filter(DocumentCollectionModel.id == collection_id) + .first() + ) + if not collection: + raise HTTPException(status_code=404, detail="Document collection not found") from None + + # Initialize document store + doc_store = DocumentStore(collection.id) + + # Check if document exists + doc = doc_store.get_document(document_id) + if not doc: + raise HTTPException( + status_code=404, detail="Document not found in collection" + ) from None + + # Delete document + success = doc_store.delete_document(document_id) + if not success: + raise HTTPException(status_code=500, detail="Failed to delete document") from None + + # Update collection stats + collection.document_count -= 1 + if doc.chunks: + collection.chunk_count -= len(doc.chunks) + db.commit() + + return {"message": "Document deleted successfully"} + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get( + "/collections/{collection_id}/documents/", + response_model=List[DocumentWithChunksSchema], +) +async def get_collection_documents( + collection_id: str, +) -> List[DocumentWithChunksSchema]: + """Get all documents and their chunks for a collection.""" + try: + doc_store = DocumentStore(collection_id) + documents: List[DocumentWithChunksSchema] = [] + for doc_id in doc_store.list_documents(): + doc = doc_store.get_document(doc_id) + if doc: + documents.append(doc) + return documents + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.post( + "/collections/preview_chunk/", + description="Preview how a 
document would be chunked with given configuration", +) +async def preview_chunk( + file: UploadFile = File(...), + chunking_config: str = Form(...), +) -> Dict[str, Any]: + """Preview how a file will be chunked and formatted with templates.""" + try: + # Parse chunking config + config = ChunkingConfigSchema(**json.loads(chunking_config)) + + if not file.filename: + raise HTTPException(status_code=400, detail="Filename is required") from None + + # Get preview using chunker module + preview_chunks, total_chunks = await preview_document_chunk( + file.file, file.filename, file.content_type or "text/plain", config + ) + + return {"chunks": preview_chunks, "total_chunks": total_chunks} + + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + logger.error(f"Error previewing chunk: {e}") + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.post( + "/indices/{index_id}/retrieve/", + response_model=RetrievalResponseSchema, + description="Retrieve relevant chunks from a vector index based on a query", +) +async def retrieve_from_index( + index_id: str, + request: RetrievalRequestSchema, + db: Session = Depends(get_db), +) -> RetrievalResponseSchema: + """Retrieve relevant documents from a vector index.""" + try: + # Get the vector index from the database + index = db.query(VectorIndexModel).filter(VectorIndexModel.id == index_id).first() + if not index: + raise HTTPException(status_code=404, detail="Vector index not found") from None + + # Check if index is ready + if index.status != "ready": + raise HTTPException( + status_code=400, + detail=f"Vector index is not ready (current status: {index.status})", + ) from None + + # Initialize vector index + vector_index = VectorIndex(index.id) + + # Retrieve from vector index with default top_k if not specified + results = await vector_index.retrieve( + query=request.query, + top_k=request.top_k if request.top_k is not None else 5, + 
score_threshold=request.score_threshold, + semantic_weight=request.semantic_weight, + keyword_weight=request.keyword_weight, + ) + + # Format results + formatted_results: List[RetrievalResultSchema] = [] + for result in results: + chunk = result["chunk"] + metadata = result["metadata"] + formatted_results.append( + RetrievalResultSchema( + text=chunk.text, + score=result["score"], + metadata=ChunkMetadataSchema( + document_id=metadata.get("document_id", ""), + chunk_id=metadata.get("chunk_id", ""), + document_title=metadata.get("document_title"), + page_number=metadata.get("page_number"), + chunk_number=metadata.get("chunk_number"), + ), + ) + ) + + return RetrievalResponseSchema( + results=formatted_results, total_results=len(formatted_results) + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error retrieving from vector index: {e}") + raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/pyspur/backend/pyspur/api/run_management.py b/pyspur/backend/pyspur/api/run_management.py new file mode 100644 index 0000000000000000000000000000000000000000..a34089cbc29b8ad55412ed6f044918a7d521c718 --- /dev/null +++ b/pyspur/backend/pyspur/api/run_management.py @@ -0,0 +1,62 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.orm import Session + +from ..database import get_db +from ..models.run_model import RunModel, RunStatus +from ..models.task_model import TaskStatus +from ..schemas.run_schemas import RunResponseSchema + +router = APIRouter() + + +@router.get( + "/", + response_model=List[RunResponseSchema], + description="List all runs", +) +def list_runs( + page: int = Query(default=1, ge=1), + page_size: int = Query(default=10, ge=1, le=100), + parent_only: bool = True, + run_type: Optional[str] = None, + db: Session = Depends(get_db), +): + offset = (page - 1) * page_size + query = db.query(RunModel) + + if parent_only: + query = 
query.filter(RunModel.parent_run_id.is_(None)) + if run_type: + query = query.filter(RunModel.run_type == run_type) + + runs = query.order_by(RunModel.start_time.desc()).offset(offset).limit(page_size).all() + return runs + + +@router.get("/{run_id}/", response_model=RunResponseSchema) +def get_run(run_id: str, db: Session = Depends(get_db)): + run = db.query(RunModel).filter(RunModel.id == run_id).first() + if not run: + raise HTTPException(status_code=404, detail="Run not found") + return run + + +@router.get("/{run_id}/status/", response_model=RunResponseSchema) +def get_run_status(run_id: str, db: Session = Depends(get_db)): + run = db.query(RunModel).filter(RunModel.id == run_id).first() + if not run: + raise HTTPException(status_code=404, detail="Run not found") + if run.status != RunStatus.FAILED: + failed_tasks = [task for task in run.tasks if task.status == TaskStatus.FAILED] + running_and_pending_tasks = [ + task for task in run.tasks if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING] + ] + if failed_tasks and len(running_and_pending_tasks) == 0: + run.status = RunStatus.FAILED + db.commit() + db.refresh(run) + if not run: + raise HTTPException(status_code=404, detail="Run not found") + return run diff --git a/pyspur/backend/pyspur/api/secure_token_store.py b/pyspur/backend/pyspur/api/secure_token_store.py new file mode 100644 index 0000000000000000000000000000000000000000..25c922a405eb73cefe2d8ccc80a389d73bc8eee6 --- /dev/null +++ b/pyspur/backend/pyspur/api/secure_token_store.py @@ -0,0 +1,242 @@ +import json +import os +import shutil +import time +import uuid +from pathlib import Path +from typing import Dict, Optional + +from cryptography.fernet import Fernet +from fastapi import HTTPException + + +class SecureTokenStore: + """A secure storage for agent-specific tokens. + + This class provides an interface to securely store and retrieve + tokens associated with specific agents, using encryption. 
+ """ + + def __init__(self): + # Ensure the token storage directory exists + storage_dir = Path("./secure_tokens") + try: + storage_dir.mkdir(exist_ok=True) + print(f"Ensured token storage directory exists: {storage_dir}") + except Exception as e: + print(f"Error creating token storage directory: {str(e)}") + # Use a fallback directory in the current working directory + storage_dir = Path.cwd() / "secure_tokens" + storage_dir.mkdir(exist_ok=True) + print(f"Using fallback token storage directory: {storage_dir}") + + self.storage_path = storage_dir / "agent_tokens.enc" + key_path = storage_dir / "encryption_key.txt" + print(f"Token storage path: {self.storage_path}") + + # Try to load encryption key from environment first + self.encryption_key = os.getenv("TOKEN_ENCRYPTION_KEY") + + # If not in environment, try to load from file + if not self.encryption_key and key_path.exists(): + try: + print(f"Loading encryption key from file: {key_path}") + self.encryption_key = key_path.read_text().strip() + print("Successfully loaded encryption key from file") + except Exception as e: + print(f"Error loading encryption key from file: {str(e)}") + self.encryption_key = None + + # If still no key, generate a new one and save it + if not self.encryption_key: + print("No encryption key found, generating a new one") + self.encryption_key = Fernet.generate_key().decode() + + # Save the key to environment + os.environ["TOKEN_ENCRYPTION_KEY"] = self.encryption_key + + # Also save to file for persistence between restarts + try: + print(f"Saving encryption key to file: {key_path}") + key_path.write_text(self.encryption_key) + print("Successfully saved encryption key to file") + except Exception as e: + print(f"Error saving encryption key to file: {str(e)}") + else: + print("Using existing encryption key") + + try: + # Initialize Fernet cipher for encryption/decryption + self.cipher = Fernet(self.encryption_key.encode()) + except Exception as e: + print(f"Error initializing Fernet 
cipher: {str(e)}") + # Generate a new key as fallback + print("Generating new encryption key as fallback") + self.encryption_key = Fernet.generate_key().decode() + os.environ["TOKEN_ENCRYPTION_KEY"] = self.encryption_key + # Save to file + try: + key_path.write_text(self.encryption_key) + except Exception as e: + print(f"Error saving fallback encryption key to file: {str(e)}") + self.cipher = Fernet(self.encryption_key.encode()) + + # Initialize or load the token store + self.tokens: Dict[str, Dict[str, str]] = {} + self._load_tokens() + + def _load_tokens(self): + """Load tokens from encrypted storage file.""" + if not self.storage_path.exists(): + print(f"Loading tokens from {self.storage_path}") + print("Token file does not exist yet, starting with empty store") + return + + print(f"Loading tokens from {self.storage_path}") + + try: + # Read and decrypt the token data + encrypted_data = self.storage_path.read_bytes() + print(f"Read {len(encrypted_data)} bytes of encrypted data") + + if len(encrypted_data) == 0: + print("Token file is empty, starting with empty store") + return + + try: + # Try to decrypt with current key + decrypted_data = self.cipher.decrypt(encrypted_data) + self.tokens = json.loads(decrypted_data.decode("utf-8")) + print(f"Successfully loaded tokens for {len(self.tokens)} agents") + except Exception as e: + # If decryption fails, back up the file and log the error + print(f"Error decrypting token data: {str(e)}") + print("File may have been encrypted with a different key or be corrupted.") + + backup_path = f"{self.storage_path}.bak.{int(time.time())}" + try: + shutil.copy(self.storage_path, backup_path) + print(f"Moved potentially corrupted token file to {backup_path}") + except Exception as be: + print(f"Error backing up token file: {str(be)}") + + # Keep the in-memory tokens intact (empty or previously loaded) + print("Continuing with current in-memory tokens") + + except Exception as e: + print(f"Error loading token file: {str(e)}") + # 
Keep the in-memory tokens intact (empty or previously loaded) + + def _save_tokens(self): + """Save encrypted tokens to storage.""" + try: + import json + + data = json.dumps(self.tokens) + encrypted_data = self.cipher.encrypt(data.encode()) + self.storage_path.write_bytes(encrypted_data) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to save tokens: {str(e)}") + + def create_token_id(self, agent_id: int) -> str: + """Generate a unique token ID for an agent.""" + token_id = str(uuid.uuid4()) + if str(agent_id) not in self.tokens: + self.tokens[str(agent_id)] = {} + self.tokens[str(agent_id)]["token_id"] = token_id + self._save_tokens() + return token_id + + def store_token(self, agent_id: int, token_type: str, token: str) -> str: + """Store a token for an agent""" + key = f"{agent_id}" + if key not in self.tokens: + self.tokens[key] = {} + + token_key = token_type + self.tokens[key][token_key] = token + + try: + # Try to save tokens + self._save_tokens() + except Exception as e: + print(f"Error saving token: {str(e)}") + # If saving fails, try to reset the tokens file and save again + try: + print("Attempting to reset token file and save again") + self.reset_tokens_file() + self._save_tokens() + except Exception as e2: + print(f"Error after reset attempt: {str(e2)}") + + return self._mask_token(token) + + def reset_tokens_file(self): + """Reset the tokens file, for cases where the encryption key has changed""" + # Backup current file if it exists + if self.storage_path.exists(): + backup_path = f"{self.storage_path}.bak.{int(time.time())}" + try: + shutil.copy(self.storage_path, backup_path) + print(f"Backed up token file to {backup_path}") + except Exception as e: + print(f"Error backing up token file: {str(e)}") + + # Delete the current file + try: + if self.storage_path.exists(): + self.storage_path.unlink() + print(f"Deleted token file: {self.storage_path}") + except Exception as e: + print(f"Error deleting token file: 
{str(e)}") + + # We keep the in-memory tokens that we've loaded or newly added + + def get_token(self, agent_id: int, token_type: str) -> Optional[str]: + """Retrieve a token for a specific agent and token type.""" + print(f"Retrieving {token_type} for agent {agent_id}") + + agent_id_str = str(agent_id) + if agent_id_str not in self.tokens: + print(f"No tokens found for agent {agent_id} (agent not found in token store)") + return None + + if token_type not in self.tokens[agent_id_str]: + print(f"No {token_type} found for agent {agent_id}") + available_types = ", ".join(self.tokens[agent_id_str].keys()) + print(f"Available token types for agent {agent_id}: {available_types}") + return None + + token = self.tokens[agent_id_str][token_type] + masked_token = self._mask_token(token) + print(f"Retrieved {token_type} for agent {agent_id}: {masked_token}") + return token + + def delete_token(self, agent_id: int, token_type: str) -> bool: + """Delete a token for a specific agent and token type.""" + agent_id_str = str(agent_id) + if agent_id_str in self.tokens and token_type in self.tokens[agent_id_str]: + del self.tokens[agent_id_str][token_type] + if not self.tokens[agent_id_str]: # If no tokens left for this agent + del self.tokens[agent_id_str] + self._save_tokens() + return True + return False + + def _mask_token(self, token: str) -> str: + """Create a masked version of the token for display.""" + if len(token) <= 8: + return "*" * len(token) + return token[:4] + "*" * (len(token) - 8) + token[-4:] + + +# Singleton instance +_token_store = None + + +def get_token_store() -> SecureTokenStore: + """Get the singleton token store instance.""" + global _token_store + if _token_store is None: + _token_store = SecureTokenStore() + return _token_store diff --git a/pyspur/backend/pyspur/api/session_management.py b/pyspur/backend/pyspur/api/session_management.py new file mode 100644 index 0000000000000000000000000000000000000000..eb3fd43ce2552f69829859d845ae7a442f3902e5 --- 
/dev/null +++ b/pyspur/backend/pyspur/api/session_management.py @@ -0,0 +1,192 @@ +from typing import cast + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy import func, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from ..database import get_db +from ..models.user_session_model import SessionModel, UserModel +from ..models.workflow_model import WorkflowModel +from ..schemas.session_schemas import ( + SessionCreate, + SessionListResponse, + SessionResponse, +) + +router = APIRouter() + +TEST_USER_EXTERNAL_ID = "test_user" +TEST_USER_METADATA = {"is_test": True} + + +@router.post("/", response_model=SessionResponse) +async def create_session( + session_create: SessionCreate, + db: Session = Depends(get_db), +) -> SessionResponse: + """Create a new session.""" + # Check if session already exists with the given external_id + if session_create.external_id: + existing_session = ( + db.query(SessionModel) + .execution_options(join_depth=2) # Include messages in response + .filter(SessionModel.external_id == session_create.external_id) + .first() + ) + if existing_session: + return SessionResponse.model_validate(existing_session) + + # Verify user exists + user = db.query(UserModel).filter(UserModel.id == session_create.user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # Verify workflow exists + workflow = ( + db.query(WorkflowModel).filter(WorkflowModel.id == session_create.workflow_id).first() + ) + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + # Create session + session = SessionModel( + user_id=session_create.user_id, + workflow_id=session_create.workflow_id, + external_id=session_create.external_id, + ) + + try: + db.add(session) + db.commit() + db.refresh(session) + return SessionResponse.model_validate(session) + except IntegrityError: + db.rollback() + raise HTTPException( + status_code=400, + 
detail="Could not create session", + ) from None + + +@router.get("/", response_model=SessionListResponse) +async def list_sessions( + skip: int = Query(0, ge=0), + limit: int = Query(10, ge=1, le=100), + user_id: str | None = None, + db: Session = Depends(get_db), +) -> SessionListResponse: + """List sessions with pagination and optional user filtering.""" + query = select(SessionModel) + + if user_id: + query = query.where(SessionModel.user_id == user_id) + + # Get total count + total_count = cast(int, db.scalar(select(func.count()).select_from(query.subquery()))) + + # Get paginated sessions + sessions = ( + db.query(SessionModel) + .execution_options(join_depth=2) # Include messages in response + .order_by(SessionModel.created_at.desc()) + ) + + if user_id: + sessions = sessions.filter(SessionModel.user_id == user_id) + + sessions = sessions.offset(skip).limit(limit).all() + + # Convert models to response schemas + session_responses = [SessionResponse.model_validate(session) for session in sessions] + return SessionListResponse(sessions=session_responses, total=total_count) + + +@router.get("/{session_id}/", response_model=SessionResponse) +async def get_session( + session_id: str, + db: Session = Depends(get_db), +) -> SessionResponse: + """Get a specific session by ID.""" + session = ( + db.query(SessionModel) + .execution_options(join_depth=2) # Include messages in response + .filter(SessionModel.id == session_id) + .first() + ) + + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + return SessionResponse.model_validate(session) + + +@router.delete("/{session_id}/", status_code=204) +async def delete_session( + session_id: str, + db: Session = Depends(get_db), +) -> None: + """Delete a session.""" + session = db.query(SessionModel).filter(SessionModel.id == session_id).first() + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + db.delete(session) + db.commit() + + 
+@router.post("/test/", response_model=SessionResponse) +async def create_test_session( + workflow_id: str, + db: Session = Depends(get_db), +) -> SessionResponse: + """Create or reuse a test user and session. + + If a test user exists, it will be reused. + If an empty test session exists for the same workflow, it will be reused. + Otherwise, a new session will be created. + """ + # First verify workflow exists + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + # Get or create test user + test_user = db.query(UserModel).filter(UserModel.external_id == TEST_USER_EXTERNAL_ID).first() + + if not test_user: + test_user = UserModel( + external_id=TEST_USER_EXTERNAL_ID, + user_metadata=TEST_USER_METADATA, + ) + db.add(test_user) + db.commit() + db.refresh(test_user) + + # Look for an existing empty session for this workflow + existing_session = ( + db.query(SessionModel) + .filter( + SessionModel.user_id == test_user.id, + SessionModel.workflow_id == workflow_id, + ) + .execution_options(join_depth=2) # Include messages in response + .order_by(SessionModel.created_at.desc()) + .first() + ) + if existing_session and len(existing_session.messages) == 0: + return SessionResponse.model_validate(existing_session) + + # Create new session + session = SessionModel(user_id=test_user.id, workflow_id=workflow_id) + try: + db.add(session) + db.commit() + db.refresh(session) + return SessionResponse.model_validate(session) + except IntegrityError: + db.rollback() + raise HTTPException( + status_code=400, + detail="Could not create test session", + ) from None diff --git a/pyspur/backend/pyspur/api/slack_management.py b/pyspur/backend/pyspur/api/slack_management.py new file mode 100644 index 0000000000000000000000000000000000000000..92714f68c07ca362479bff88b08f3c45eb02d60a --- /dev/null +++ b/pyspur/backend/pyspur/api/slack_management.py @@ -0,0 +1,1529 @@ 
+import asyncio +import json +import os +import traceback +from contextlib import asynccontextmanager +from datetime import UTC, datetime +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, cast + +import psutil +from fastapi import APIRouter, BackgroundTasks, Depends, FastAPI, HTTPException, Request +from loguru import logger +from slack_sdk import WebClient +from slack_sdk.errors import SlackApiError +from slack_sdk.web.async_client import AsyncWebClient +from slack_sdk.web.async_slack_response import AsyncSlackResponse +from sqlalchemy.orm import Session + +from ..database import get_db +from ..integrations.slack.socket_client import get_socket_mode_client +from ..models.run_model import RunModel, RunStatus +from ..models.slack_agent_model import SlackAgentModel +from ..models.task_model import TaskStatus +from ..models.workflow_model import WorkflowModel +from ..schemas.run_schemas import StartRunRequestSchema +from ..schemas.slack_schemas import ( + AgentTokenRequest, + AgentTokenResponse, + SlackAgentCreate, + SlackAgentResponse, + SlackAgentUpdate, + SlackMessage, + SlackMessageResponse, + SlackSocketModeResponse, + SlackTriggerConfig, + WorkflowAssociation, + WorkflowTriggerRequest, + WorkflowTriggerResult, + WorkflowTriggersResponse, +) +from . 
import key_management +from .secure_token_store import get_token_store +from .workflow_run import run_workflow_non_blocking + +router = APIRouter() + +# API Endpoints +SLACK_API_URL = "https://slack.com/api" +SLACK_POST_MESSAGE_URL = "https://slack.com/api/chat.postMessage" + +# Request timeout (in seconds) +REQUEST_TIMEOUT = 10 + +# Initialize the socket mode client and set up the workflow trigger callback +socket_mode_client = get_socket_mode_client() + +# Define a type variable for the response objects +T = TypeVar("T") + +# Add these type annotations to better handle slack_sdk method calls + + +def _validate_agent_socket_mode( + db: Session, agent_id: int, say_callback: Optional[Callable[..., Any]] = None +) -> bool: + """Validate if an agent should be processing socket mode events. + + Args: + db: Database session + agent_id: Agent ID to validate + say_callback: Optional callback to send message if agent is disabled + + Returns: + bool: True if agent should process events, False otherwise + + """ + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + + if agent is None or not bool(agent.socket_mode_enabled): + logger.warning(f"Rejecting event for agent {agent_id} - socket_mode_enabled=False") + + # Send response to avoid hanging the client + if say_callback: + try: + say_callback(text="This Slack integration is currently disabled.") + except Exception: + pass + return False + + return True + + +# Callback function for socket mode events to trigger workflows +async def handle_socket_mode_event( + trigger_request: WorkflowTriggerRequest, + agent_id: int, + say: Callable[..., Any], + client: Optional[WebClient] = None, +): + """Handle a socket mode event by triggering the associated workflow""" + logger.info(f"Handling socket mode event for agent {agent_id}") + + # This function can be called from a thread or directly, so we need to handle both cases + # Get a database session + db = next(get_db()) + + try: + # First, explicitly 
check if this agent should actually be processing events + if not _validate_agent_socket_mode(db, agent_id, say): + return + + agent = await _get_active_agent(db, agent_id) + if not agent: + return + + if await _should_trigger_workflow(agent, trigger_request): + await _trigger_workflow(db, agent, trigger_request, say, client) + + except Exception as e: + logger.error(f"Error in handle_socket_mode_event: {e}") + logger.error(f"Error details: {traceback.format_exc()}") + finally: + db.close() + + +# Create a synchronous version of the handler that can be called from threads +def handle_socket_mode_event_sync( + trigger_request: WorkflowTriggerRequest, + agent_id: int, + say: Callable[..., Any], + client: Optional[WebClient] = None, +): + """Synchronous wrapper for handle_socket_mode_event to be used in threaded contexts.""" + # Return the coroutine object without awaiting it + # The socket client will handle awaiting it appropriately + return handle_socket_mode_event(trigger_request, agent_id, say, client) + + +async def _get_active_agent(db: Session, agent_id: int) -> Optional[SlackAgentModel]: + """Get an active agent with workflow configured.""" + agent = ( + db.query(SlackAgentModel) + .filter( + SlackAgentModel.id == agent_id, + SlackAgentModel.is_active.is_(True), + SlackAgentModel.trigger_enabled.is_(True), + SlackAgentModel.workflow_id.isnot(None), + ) + .first() + ) + + if not agent: + logger.warning(f"Agent {agent_id} not found, not active, or has no workflow") + return agent + + +# Handle the item typing issues in the keywords list +async def _should_trigger_workflow( + agent: SlackAgentModel, trigger_request: WorkflowTriggerRequest +) -> bool: + """Determine if a Slack message should trigger a workflow.""" + # Only proceed if triggering is enabled for this agent + try: + # Use explicit conversion for all SQLAlchemy Column boolean fields + trigger_enabled = bool(agent.trigger_enabled) + if not trigger_enabled: + return False + + should_trigger = False + + 
# Check mention trigger - convert SQLAlchemy Column to bool for comparison + trigger_on_mention = bool(agent.trigger_on_mention) + if trigger_on_mention and trigger_request.event_type == "app_mention": + should_trigger = True + # Check direct message trigger + elif ( + bool(agent.trigger_on_direct_message) + and trigger_request.event_type == "message" + and trigger_request.event_data.get("channel_type") == "im" + ): + should_trigger = True + # Check channel message trigger + elif ( + bool(agent.trigger_on_channel_message) + and trigger_request.event_type == "message" + and trigger_request.event_data.get("channel_type") != "im" + ): + # For channel messages, we need to check for keywords + keywords = getattr(agent, "trigger_keywords", []) or [] + if isinstance(keywords, list): + str_keywords: List[str] = [] + for item in cast(List[Union[str, None]], keywords): + if item is not None: + str_keywords.append(str(item)) + if str_keywords: + message_text = trigger_request.text.lower() + for keyword in str_keywords: + if keyword.lower() in message_text: + return True + return False + else: + return False + + return should_trigger + except Exception as e: + logger.error(f"Error in _should_trigger_workflow: {str(e)}") + return False + + +async def _trigger_workflow( + db: Session, + agent: SlackAgentModel, + trigger_request: WorkflowTriggerRequest, + say: Callable[..., Any], + client: Optional[WebClient] = None, +): + """Trigger the workflow and handle the response""" + try: + # Prepare the run input + run_input = { + "message": trigger_request.text, + "channel_id": trigger_request.channel_id, + "user_id": trigger_request.user_id, + "event_type": trigger_request.event_type, + "event_data": trigger_request.event_data, + "timestamp": datetime.now(UTC).isoformat(), + } + + # Start the workflow run + background_tasks = BackgroundTasks() + run = await start_workflow_run( + db=db, + workflow_id=str(getattr(agent, "workflow_id", "") or ""), + run_input=run_input, + 
background_tasks=background_tasks, + ) + + # Let the user know we're processing their request + # Use client if available, otherwise fall back to say function + if client and trigger_request.channel_id: + try: + client.chat_postMessage( # type: ignore + channel=trigger_request.channel_id, + text=f"Processing your request... (Run ID: {run.id})", + ) + except Exception as e: + logger.error(f"Error using client to respond: {e}") + say(text=f"Processing your request... (Run ID: {run.id})") + else: + say(text=f"Processing your request... (Run ID: {run.id})") + + logger.info(f"Started workflow run {run.id} for agent {agent.id}") + + # Manually execute background tasks since we're not in a FastAPI endpoint + logger.info(f"Manually executing background tasks for workflow run {run.id}") + for task in background_tasks.tasks: + logger.info(f"Executing task: {str(task)}") + await task() + logger.info(f"Background tasks execution completed for workflow run {run.id}") + + # Wait for workflow to complete and return results to Slack + await _send_workflow_results_to_slack(run.id, trigger_request.channel_id, client, say) + + except Exception as e: + logger.error(f"Error triggering workflow for agent {agent.id}: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") + say(text=f"Sorry, I encountered an error: {str(e)}") + + +# Handle chat_postMessage and auth_test type issues by using proper type annotations +async def _send_workflow_results_to_slack( + run_id: str, + channel_id: str, + client: Optional[WebClient] = None, + say: Optional[Callable[..., Any]] = None, + db: Optional[Session] = None, +): + """Get workflow results and send them back to Slack.""" + own_db_session = db is None + + try: + # Create a new database session if one wasn't provided + if own_db_session: + db = next(get_db()) + + # Wait for the workflow to complete (poll status) + # Poll for up to 2 minutes (24 attempts, 5 seconds apart) + max_attempts = 24 + attempts = 0 + run_complete = False + run = None + 
+ while attempts < max_attempts and not run_complete: + attempts += 1 + # Get the run model from database + run = db.query(RunModel).filter(RunModel.id == run_id).first() + + if not run: + logger.error(f"Run {run_id} not found in database") + break + + # Check if the run has completed or failed + if run.status in [ + RunStatus.COMPLETED, + RunStatus.FAILED, + RunStatus.PAUSED, + RunStatus.CANCELED, + ]: + run_complete = True + break + + # Wait before polling again + await asyncio.sleep(5) + + # Prepare the message to send back + if not run: + message = f"⚠️ Could not find workflow run {run_id}" + elif not run_complete: + message = f"⏱️ Workflow run {run_id} is still in progress (status: {run.status})" + elif run.status == RunStatus.COMPLETED: + # Format the output message + if run.outputs: + # Find the output node + output_content: Optional[str] = None + # Typically outputs are stored with node IDs as keys + for _, output in run.outputs.items(): + # Look for output node or content that has relevant output data + if isinstance(output, dict) and ( + "result" in output or "content" in output or "text" in output + ): + output_dict = cast(Dict[str, Any], output) + output_content = ( + output_dict.get("result") + or output_dict.get("content") + or output_dict.get("text") + ) + break + + if output_content: + message = ( + f"✅ Workflow completed successfully!\n\n*Output:*\n```{output_content}```" + ) + else: + # If we can't find a specific output format, just return the full output as JSON + message = f"✅ Workflow completed successfully!\n\n*Output:*\n```{json.dumps(run.outputs, indent=2)}```" + else: + message = "✅ Workflow completed successfully! 
(No output data available)" + elif run.status == RunStatus.FAILED: + message = "❌ Workflow run failed" + # Look for error messages in tasks + error_messages: List[str] = [] + if run.tasks: + for task in run.tasks: + if task.status == TaskStatus.FAILED and task.error: + error_messages.append(f"- Task {task.node_id}: {task.error}") + + if error_messages: + message += "\n\n*Errors:*\n" + "\n".join(error_messages) + else: + message = f"⚠️ Workflow run {run_id} ended with status: {run.status}" + + # Send the message back to Slack + if client and channel_id: + try: + logger.info(f"Sending workflow result to Slack channel {channel_id}") + client.chat_postMessage(channel=channel_id, text=message) # type: ignore + logger.info("Successfully sent workflow result to Slack") + except Exception as e: + logger.error(f"Error sending workflow results to Slack: {e}") + if say: + say(text=message) + elif say: + logger.info("Using say function to send workflow result") + say(text=message) + else: + logger.error("Cannot send workflow results: No Slack client or say function available") + + except Exception as e: + logger.error(f"Error sending workflow results to Slack: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") + + # Try to send error message + error_msg = f"Error retrieving workflow results: {str(e)}" + if client and channel_id: + try: + client.chat_postMessage(channel=channel_id, text=error_msg) # type: ignore + except Exception: + if say: + say(text=error_msg) + elif say: + say(text=error_msg) + finally: + # Clean up if we created our own session + if own_db_session and db: + db.close() + + +# Set the callback for the socket mode client - use the sync version +socket_mode_client.set_workflow_trigger_callback(handle_socket_mode_event_sync) # type: ignore + + +@router.get("/agents", response_model=List[SlackAgentResponse]) +async def get_agents(db: Session = Depends(get_db)) -> List[SlackAgentResponse]: + """Get all configured Slack agents.""" + agents = 
db.query(SlackAgentModel).all() + agent_responses: List[SlackAgentResponse] = [] + + for agent in agents: + # Convert the agent to a SlackAgentResponse with proper type handling + agent_response = _agent_to_response_model(agent) + agent_responses.append(agent_response) + + return agent_responses + + +def _get_nullable_str(value: Any) -> Optional[str]: + """Helper to safely convert nullable SQLAlchemy column to string.""" + return str(value) if value is not None else None + + +# Helper function to convert a SlackAgentModel to a SlackAgentResponse +def _agent_to_response_model(agent: SlackAgentModel) -> SlackAgentResponse: + """Convert a SlackAgentModel to a SlackAgentResponse with proper type handling.""" + try: + agent_id = int(str(agent.id)) + except (TypeError, ValueError): + agent_id = 0 + + # Build dictionary with careful conversion for SQLAlchemy types + agent_dict = { + "id": agent_id, + "name": str(agent.name), + "slack_team_id": _get_nullable_str(agent.slack_team_id), + "slack_team_name": _get_nullable_str(agent.slack_team_name), + "slack_channel_id": _get_nullable_str(agent.slack_channel_id), + "slack_channel_name": _get_nullable_str(agent.slack_channel_name), + "is_active": bool(agent.is_active), + "workflow_id": _get_nullable_str(agent.workflow_id), + "trigger_on_mention": bool(agent.trigger_on_mention), + "trigger_on_direct_message": bool(agent.trigger_on_direct_message), + "trigger_on_channel_message": bool(agent.trigger_on_channel_message), + "trigger_keywords": [str(k) for k in getattr(agent, "trigger_keywords", []) or []], + "trigger_enabled": bool(agent.trigger_enabled), + "has_bot_token": bool(agent.has_bot_token), + "has_user_token": bool(agent.has_user_token), + "has_app_token": bool(agent.has_app_token), + "last_token_update": _get_nullable_str(agent.last_token_update), + "spur_type": str(getattr(agent, "spur_type", "workflow") or "workflow"), + "created_at": str(getattr(agent, "created_at", "") or ""), + } + return 
SlackAgentResponse.model_validate(agent_dict) + + +@router.post("/agents", response_model=SlackAgentResponse) +async def create_agent(agent_create: SlackAgentCreate, db: Session = Depends(get_db)): + """Create a new Slack agent configuration.""" + # Ensure workflow_id is provided + if not agent_create.workflow_id: + raise HTTPException( + status_code=400, + detail="workflow_id is required - every agent must be associated with a workflow", + ) + + # Create a new agent from the agent_create fields + new_agent = SlackAgentModel( + name=agent_create.name, + slack_team_id=agent_create.slack_team_id, + slack_team_name=agent_create.slack_team_name, + slack_channel_id=agent_create.slack_channel_id, + slack_channel_name=agent_create.slack_channel_name, + is_active=bool(agent_create.is_active), + workflow_id=agent_create.workflow_id, + trigger_on_mention=bool(agent_create.trigger_on_mention), + trigger_on_direct_message=bool(agent_create.trigger_on_direct_message), + trigger_on_channel_message=bool(agent_create.trigger_on_channel_message), + trigger_keywords=agent_create.trigger_keywords, + trigger_enabled=bool(agent_create.trigger_enabled), + has_bot_token=bool(agent_create.has_bot_token), + has_user_token=bool(agent_create.has_user_token), + has_app_token=bool(agent_create.has_app_token), + last_token_update=agent_create.last_token_update, + spur_type=agent_create.spur_type or "workflow", + ) + + db.add(new_agent) + db.commit() + db.refresh(new_agent) + + # Convert the agent to a Pydantic model + return _agent_to_response_model(new_agent) + + +@router.get("/agents/{agent_id}", response_model=SlackAgentResponse) +async def get_agent(agent_id: int, db: Session = Depends(get_db)): + """Get a Slack agent configuration.""" + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + + # Convert the agent to a Pydantic model + return _agent_to_response_model(agent) + + 
+@router.post("/agents/{agent_id}/send-message", response_model=SlackMessageResponse) +async def send_agent_message(agent_id: int, message: SlackMessage, db: Session = Depends(get_db)): + """Send a message to a channel using the Slack agent.""" + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + + # Create a client using the agent's token + token = None + + # If agent has a bot token, retrieve it + has_bot_token = bool(agent.has_bot_token) + if has_bot_token: + logger.info(f"Retrieving bot token for agent {agent_id}") + token_store = get_token_store() + token = token_store.get_token(agent_id, "bot_token") + + if not token: + raise HTTPException(status_code=400, detail="This agent has no bot token configured") + + # Send the message to Slack + try: + client = WebClient(token=token) + logger.info(f"Sending message to channel '{message.channel}'") + + # Call Slack API - ignore type issues with chat_postMessage + slack_response = client.chat_postMessage( # type: ignore + channel=message.channel, + text=message.text, + ) + + # Extract data from the response + response_data = { + "ts": slack_response.get("ts", ""), + "channel": slack_response.get("channel", ""), + "message": "Message sent successfully", + "success": True, + } + + logger.info(f"Message sent successfully: {slack_response.get('ok', False)}") + return response_data + except SlackApiError as e: + logger.error(f"Error sending message to Slack: {str(e)}") + # If there was an error, raise an HTTPException + raise HTTPException( + status_code=500, + detail={ + "message": f"Error sending message to Slack: {str(e)}", + "success": False, + }, + ) from e + + +@router.post("/send-message", response_model=SlackMessageResponse) +async def send_message( + channel: str, text: str, agent_id: Optional[int] = None, db: Session = Depends(get_db) +) -> Dict[str, Any]: + """Send a message to a Slack channel.""" 
+ logger.info(f"Attempting to send message to channel '{channel}' with agent_id: {agent_id}") + + # Initialize WebClient with the bot token + token = None + + # If agent_id is provided, try to get the agent-specific token + if agent_id is not None: + # Try to get the agent + logger.info(f"Searching for agent with ID: {agent_id}") + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + + if agent: + has_bot_token = bool(agent.has_bot_token) + logger.info(f"Found agent '{agent.name}' with has_bot_token={has_bot_token}") + + # If agent has a bot token, retrieve it + if has_bot_token: + logger.info(f"Retrieving bot token for agent {agent_id}") + token_store = get_token_store() + token = token_store.get_token(agent_id, "bot_token") + logger.info(f"Token retrieved: {'Yes' if token else 'No'}") + else: + logger.error(f"Agent with ID {agent_id} not found") + + # If no agent-specific token, try to use the default token from environment + if not token: + logger.info("No agent-specific token found, using environment variables") + token = os.getenv("SLACK_BOT_TOKEN") + + if not token: + logger.error("No Slack bot token configured") + raise HTTPException( + status_code=500, + detail={ + "message": "No Slack bot token configured", + "success": False, + }, + ) + + try: + client = WebClient(token=token) + logger.info(f"Sending message to channel '{channel}'") + + # Call Slack API - ignore type issues with chat_postMessage + slack_response = client.chat_postMessage( # type: ignore + channel=channel, + text=text, + ) + + # Extract data from the response + response_data = { + "ts": slack_response.get("ts", ""), + "channel": slack_response.get("channel", ""), + "message": "Message sent successfully", + "success": True, + } + + logger.info(f"Message sent successfully: {slack_response.get('ok', False)}") + return response_data + except SlackApiError as e: + logger.error(f"Error sending message to Slack: {str(e)}") + # If there was an error, raise an 
HTTPException + raise HTTPException( + status_code=500, + detail={ + "message": f"Error sending message to Slack: {str(e)}", + "success": False, + }, + ) from e + + +@router.post("/test-message", response_model=SlackMessageResponse) +async def test_message( + channel: str, + text: str = "Hello from PySpur! This is a test message.", + agent_id: Optional[int] = None, + db: Session = Depends(get_db), +) -> Dict[str, Any]: + """Test sending a message to a Slack channel.""" + try: + # Attempt to send the test message using the Slack client + response = await send_message(channel=channel, text=text, agent_id=agent_id, db=db) + return response + except Exception as e: + logger.error(f"Error sending test message: {e}") + return {"success": False, "message": f"Error sending test message: {e}"} + + +@router.put("/agents/{agent_id}/workflow", response_model=SlackAgentResponse) +async def associate_workflow( + agent_id: int, association: WorkflowAssociation, db: Session = Depends(get_db) +): + """Associate a workflow with a Slack agent.""" + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + + agent.set_field("workflow_id", association.workflow_id) + db.commit() + db.refresh(agent) + + # Convert the agent to a Pydantic model + return _agent_to_response_model(agent) + + +@router.put("/agents/{agent_id}/trigger-config", response_model=SlackAgentResponse) +async def update_trigger_config( + agent_id: int, config: SlackTriggerConfig, db: Session = Depends(get_db) +): + """Update the trigger configuration for a Slack agent.""" + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + + # Update agent using set_field for type safety + agent.set_field("trigger_on_mention", config.trigger_on_mention) + agent.set_field("trigger_on_direct_message", 
config.trigger_on_direct_message) + agent.set_field("trigger_on_channel_message", config.trigger_on_channel_message) + agent.set_field("trigger_keywords", config.trigger_keywords) + agent.set_field("trigger_enabled", config.trigger_enabled) + + db.commit() + db.refresh(agent) + + # Convert the agent to a Pydantic model + return _agent_to_response_model(agent) + + +@router.put("/agents/{agent_id}", response_model=SlackAgentResponse) +async def update_agent( + agent_id: int, agent_update: SlackAgentUpdate, db: Session = Depends(get_db) +): + """Update a Slack agent configuration""" + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + + # Update agent from the agent_update fields + update_data = agent_update.model_dump(exclude_unset=True) + + # Ensure boolean fields are properly converted + for field in update_data: + setattr(agent, field, update_data[field]) + + db.commit() + db.refresh(agent) + + # Convert the agent to a Pydantic model + return _agent_to_response_model(agent) + + +# Helper function to start a workflow run using the existing workflow execution system +async def start_workflow_run( + db: Session, workflow_id: str, run_input: Dict[str, Any], background_tasks: BackgroundTasks +) -> RunModel: + """Start a workflow run with the given input data using the standard workflow execution system.""" + # First, check if the workflow exists + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + # Get the input node ID from the workflow + workflow_definition = workflow.definition + input_node = next( + (node for node in workflow_definition["nodes"] if node["node_type"] == "InputNode"), None + ) + + if not input_node: + raise HTTPException(status_code=400, detail="Workflow has no input node") + + # Create the request payload with 
the input data + initial_inputs = {input_node["id"]: run_input} + start_run_request = StartRunRequestSchema(initial_inputs=initial_inputs) + + # Use the standard workflow execution system to run the workflow + run_response = await run_workflow_non_blocking( + workflow_id=workflow_id, + start_run_request=start_run_request, + background_tasks=background_tasks, + db=db, + run_type="slack_triggered", + ) + + # Extract the RunModel from the response + return db.query(RunModel).filter(RunModel.id == run_response.id).first() + + +@router.post("/trigger-workflow", response_model=WorkflowTriggersResponse) +async def trigger_workflow( + request: WorkflowTriggerRequest, + background_tasks: BackgroundTasks, + db: Session = Depends(get_db), +): + """Trigger workflows based on a Slack event.""" + result = WorkflowTriggersResponse(triggered_workflows=[]) + + # Find agents that match the team ID + agents = ( + db.query(SlackAgentModel) + .filter( + SlackAgentModel.slack_team_id == request.team_id, + SlackAgentModel.is_active.is_(True), + SlackAgentModel.trigger_enabled.is_(True), + SlackAgentModel.workflow_id.isnot(None), + ) + .all() + ) + + if not agents: + return result + + # Process each agent to see if it should be triggered + for agent in agents: + # Cast SQLAlchemy Column types to Python types for the constructor + agent_id = agent.get_id() + workflow_id = agent.get_workflow_id() + + trigger_result = WorkflowTriggerResult( + agent_id=agent_id, + workflow_id=workflow_id, + status="skipped", + ) + + # Check if this agent should be triggered based on the event type + should_trigger = False + + # Check mention trigger - convert SQLAlchemy Column to bool for comparison + if bool(agent.trigger_on_mention) and request.event_type == "app_mention": + should_trigger = True + + # Check direct message trigger + elif ( + bool(agent.trigger_on_direct_message) + and request.event_type == "message" + and request.event_data.get("channel_type") == "im" + ): + should_trigger = True + + # 
Check channel message trigger + elif ( + bool(agent.trigger_on_channel_message) + and request.event_type == "message" + and request.event_data.get("channel_type") in ["channel", "group"] + ): + # If keywords are specified, check if any are in the message + keywords = getattr(agent, "trigger_keywords", None) + if keywords and isinstance(keywords, list): + # Make sure keywords is a list of strings before using len + str_keywords: List[str] = [] + for item in cast(List[Union[str, None]], keywords): + if item is not None: + str_keywords.append(str(item)) + if len(str_keywords) > 0: + message_text = request.text.lower() + for keyword in str_keywords: + if keyword.lower() in message_text: + should_trigger = True + break + else: + should_trigger = True + + if should_trigger: + try: + # Prepare the run input + run_input = { + "message": request.text, + "channel_id": request.channel_id, + "user_id": request.user_id, + "event_type": request.event_type, + "event_data": request.event_data, + "timestamp": datetime.now(UTC).isoformat(), # Use timezone-aware datetime + } + + # Start the workflow run using the standard workflow execution system + run = await start_workflow_run( + db=db, + workflow_id=str(getattr(agent, "workflow_id", "") or ""), + run_input=run_input, + background_tasks=background_tasks, + ) + + trigger_result.status = "triggered" + trigger_result.run_id = run.id + + except Exception as e: + trigger_result.status = "error" + trigger_result.error = str(e) + + result.triggered_workflows.append(trigger_result) + + return result + + +@router.post("/events", status_code=200) +async def slack_events( + request: Request, + background_tasks: BackgroundTasks, + db: Session = Depends(get_db), +): + """Handle Slack events from the Events API.""" + data = await request.json() + + # Handle Slack URL verification challenge + if data.get("type") == "url_verification": + return {"challenge": data.get("challenge")} + + # Process other events + event = data.get("event", {}) + 
event_type = event.get("type") + + if not event_type: + return {"ok": True} + + # Extract relevant data + team_id = data.get("team_id") + user_id = event.get("user") + channel_id = event.get("channel") + text = event.get("text", "") + + # Skip bot messages to avoid loops + if event.get("bot_id") or user_id == "USLACKBOT": + return {"ok": True} + + # Create a trigger request + trigger_request = WorkflowTriggerRequest( + text=text, + channel_id=channel_id, + user_id=user_id, + team_id=team_id, + event_type=event_type, + event_data=event, + ) + + # Process asynchronously to respond to Slack quickly + background_tasks.add_task(trigger_workflow, trigger_request, background_tasks, db) + + return {"ok": True} + + +@router.post("/agents/{agent_id}/tokens/{token_type}", response_model=AgentTokenResponse) +async def set_agent_token( + agent_id: int, token_type: str, token_request: AgentTokenRequest, db: Session = Depends(get_db) +): + # Override token_request.token_type with the type from the URL + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + + token_store = get_token_store() + token_store.store_token(agent_id, token_type, token_request.token) + + # Update agent token flag + if token_type == "bot_token": + agent.set_field("has_bot_token", True) + + # Try to get team information when setting a bot token + try: + client = AsyncWebClient(token=token_request.token) + response: AsyncSlackResponse = await client.auth_test() # type: ignore + response_data: Dict[str, Any] = response.data if isinstance(response.data, dict) else {} # type: ignore + if response_data.get("ok"): + team_id = str(response_data.get("team_id", "")) + team_name = str(response_data.get("team", "")) + if team_id: + agent.set_field("slack_team_id", team_id) + agent.set_field("slack_team_name", team_name) + logger.info( + f"Updated agent {agent_id} with team information: {team_name} 
({team_id})" + ) + except Exception as e: + logger.warning(f"Failed to retrieve team information: {str(e)}") + # Don't fail the token setting if we can't get team info + pass + + elif token_type == "user_token": + agent.set_field("has_user_token", True) + elif token_type == "app_token": + agent.set_field("has_app_token", True) + + current_timestamp = datetime.now(UTC).isoformat() + agent.set_field("last_token_update", current_timestamp) + db.commit() + db.refresh(agent) + + # Mask the token for the response + masked_token = mask_token(token_request.token) + return AgentTokenResponse( + agent_id=agent_id, + token_type=token_type, + masked_token=masked_token, + updated_at=current_timestamp, + ) + + +# Helper function to mask tokens for the API response +def mask_token(token: str) -> str: + """Create a masked version of the token for display.""" + if len(token) <= 8: + return "*" * len(token) + return token[:4] + "*" * (len(token) - 8) + token[-4:] + + +@router.get("/agents/{agent_id}/tokens/{token_type}", response_model=AgentTokenResponse) +async def get_agent_token(agent_id: int, token_type: str, db: Session = Depends(get_db)): + """Get a masked token for a Slack agent.""" + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + + if token_type not in ["bot_token", "user_token", "app_token"]: + raise HTTPException(status_code=400, detail="Invalid token type") + + token_store = get_token_store() + + # Get the token + token = token_store.get_token(agent_id, token_type) + if not token: + raise HTTPException(status_code=404, detail=f"No {token_type} found for this agent") + + # Mask the token for the response + masked_token = mask_token(token) + + # Get the last update timestamp - handle the value directly + last_token_update = _get_nullable_str(agent.last_token_update) + + return AgentTokenResponse( + agent_id=agent_id, + token_type=token_type, + 
masked_token=masked_token, + updated_at=last_token_update, + ) + + +@router.delete("/agents/{agent_id}/tokens/{token_type}", status_code=204) +async def delete_agent_token(agent_id: int, token_type: str, db: Session = Depends(get_db)): + """Delete a token for a Slack agent.""" + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + + if token_type not in ["bot_token", "user_token", "app_token"]: + raise HTTPException(status_code=400, detail="Invalid token type") + + token_store = get_token_store() + + # Delete the token + token_store.delete_token(agent_id, token_type) + + # Update agent token flag + if token_type == "bot_token": + agent.set_field("has_bot_token", False) + elif token_type == "user_token": + agent.set_field("has_user_token", False) + elif token_type == "app_token": + agent.set_field("has_app_token", False) + + current_timestamp = datetime.now(UTC).isoformat() + agent.set_field("last_token_update", current_timestamp) + db.commit() + + return None + + +@router.delete("/agents/{agent_id}", status_code=204) +async def delete_agent(agent_id: int, db: Session = Depends(get_db)): + """Delete a Slack agent by ID.""" + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + + # Delete any associated tokens + token_store = get_token_store() + token_store.delete_token(agent_id, "bot_token") + token_store.delete_token(agent_id, "user_token") + token_store.delete_token(agent_id, "app_token") + + # Delete the agent + db.delete(agent) + db.commit() + + return None + + +@router.post("/set-token", response_model=dict) +async def set_slack_token(request: Request): + """Directly set the Slack bot token.""" + try: + data = await request.json() + token = data.get("token") + + if not token: + raise HTTPException(status_code=400, detail="Token is 
required") + + # Store the token + try: + key_management.set_env_variable("SLACK_BOT_TOKEN", token) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to store token: {str(e)}") from e + + return {"success": True, "message": "Slack token has been set successfully"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error setting Slack token: {str(e)}") from e + + +# Define a lifespan context manager for FastAPI +@asynccontextmanager +async def lifespan(app: FastAPI): + """Define startup and shutdown events for the FastAPI application.""" + # Run startup tasks + db = next(get_db()) + try: + logger.info("Running startup tasks for Slack API...") + # Look for and recover orphaned workers + await recover_orphaned_workers(db) + logger.info("Slack API startup tasks completed.") + except Exception as e: + logger.error(f"Error in startup tasks: {e}") + finally: + db.close() + + # Yield control back to FastAPI + yield + + # Cleanup on shutdown if needed + logger.info("Shutting down Slack API...") + + +# Fix socket mode assignments for starting socket mode +@router.post("/agents/{agent_id}/socket-mode/start", response_model=SlackSocketModeResponse) +async def start_socket_mode(agent_id: int, db: Session = Depends(get_db)): + """Start Socket Mode for a Slack agent.""" + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + + # Check if the agent has an app token or if there's one in the environment + has_app_token = bool(getattr(agent, "has_app_token", False)) + if not has_app_token and not os.getenv("SLACK_APP_TOKEN"): + raise HTTPException( + status_code=400, + detail="Socket Mode requires an app token. 
Please configure an app token for this agent or set the SLACK_APP_TOKEN environment variable.", + ) + + # Check if the signing secret is configured + signing_secret = os.getenv("SLACK_SIGNING_SECRET", "") + if not signing_secret: + raise HTTPException( + status_code=400, detail="SLACK_SIGNING_SECRET environment variable not configured" + ) + + # Update the agent's socket_mode_enabled field - use set_field to avoid Column typing issues + agent.set_field("socket_mode_enabled", True) + db.commit() + db.refresh(agent) + + # Import socket manager lazily to avoid circular imports + from ..integrations.slack.socket_manager import SocketManager + + # Initialize socket manager and start worker + socket_manager = SocketManager() + success = socket_manager.start_worker(agent_id) + + if not success: + # If worker failed to start, revert the socket_mode_enabled flag + agent.set_field("socket_mode_enabled", False) + db.commit() + db.refresh(agent) + raise HTTPException(status_code=500, detail="Failed to start Socket Mode worker") + + return SlackSocketModeResponse( + agent_id=agent_id, + socket_mode_active=True, + message="Socket Mode worker started successfully.", + ) + + +@router.post("/agents/{agent_id}/socket-mode/stop", response_model=SlackSocketModeResponse) +async def stop_socket_mode(agent_id: int, db: Session = Depends(get_db)): + """Stop Socket Mode for a Slack agent.""" + try: + # Get agent + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + if not agent: + raise HTTPException(status_code=404, detail="Agent not found") + + # Check if socket mode is already disabled + socket_mode_enabled = bool(getattr(agent, "socket_mode_enabled", False)) + if not socket_mode_enabled: + return SlackSocketModeResponse( + agent_id=agent_id, + socket_mode_active=False, + message="Socket Mode already disabled", + ) + + # Disable socket mode for the agent + agent.set_field("socket_mode_enabled", False) + db.commit() + db.refresh(agent) + + # Import socket 
manager lazily to avoid circular imports + from ..integrations.slack.socket_manager import SocketManager + + # Initialize socket manager and stop worker + socket_manager = SocketManager() + success = socket_manager.stop_worker(agent_id) + + message = ( + "Socket Mode worker stopped successfully." + if success + else "Failed to stop Socket Mode worker." + ) + return SlackSocketModeResponse( + agent_id=agent_id, + socket_mode_active=False, + message=message, + ) + + except Exception as e: + logger.error(f"Error stopping socket mode: {e}") + raise HTTPException(status_code=500, detail=f"Failed to stop Socket Mode: {str(e)}") from e + + +@router.get("/agents/{agent_id}/socket-mode/status", response_model=SlackSocketModeResponse) +async def get_socket_mode_status(agent_id: int, db: Session = Depends(get_db)): + """Get Socket Mode status for a Slack agent.""" + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + + # Import socket manager lazily to avoid circular imports + from ..integrations.slack.socket_manager import SocketManager + from ..integrations.slack.worker_status import find_running_worker_process, get_worker_status + + # Initialize socket manager + socket_manager = SocketManager() + + # Check worker status through multiple methods to be robust + logger.info(f"Checking socket mode status for agent {agent_id}") + + # Method 1: Check worker status through SocketManager + worker_exists = agent_id in socket_manager.workers + logger.info(f"Agent {agent_id} worker exists in manager: {worker_exists}") + + worker_running = False + if worker_exists: + worker_process = socket_manager.workers[agent_id] + worker_running = worker_process.is_alive() + logger.info(f"Agent {agent_id} worker process is alive: {worker_running}") + + # Method 2: Check marker files and processes + worker_status = get_worker_status(agent_id) + logger.info(f"Agent {agent_id} worker 
status from marker files: {worker_status}") + + # Method 3: Directly search for processes + is_process_running, process_pid = find_running_worker_process(agent_id) + logger.info( + f"Agent {agent_id} process search result: running={is_process_running}, pid={process_pid}" + ) + + # Combine all results - if any method finds a running worker, consider it active + worker_running = worker_running or worker_status["process_running"] or is_process_running + logger.info(f"Agent {agent_id} combined worker running status: {worker_running}") + + # If we found a running process but it's not tracked in the socket manager, register it + if (worker_status["process_running"] or is_process_running) and not worker_exists: + pid = worker_status["pid"] if worker_status["process_running"] else process_pid + if pid: + logger.info( + f"Registering previously untracked worker process for agent {agent_id}: pid={pid}" + ) + process = psutil.Process(pid) + # Store the psutil Process object directly + socket_manager.workers[agent_id] = process + worker_exists = True + + # Convert SQLAlchemy Columns to bool + is_active = bool(getattr(agent, "is_active", True)) + trigger_enabled = bool(getattr(agent, "trigger_enabled", True)) + has_bot_token = bool(getattr(agent, "has_bot_token", False)) + socket_mode_enabled = bool(getattr(agent, "socket_mode_enabled", False)) + workflow_id = getattr(agent, "workflow_id", None) + + # Check if the agent should be active + agent_is_active = ( + is_active + and trigger_enabled + and has_bot_token + and workflow_id is not None + and socket_mode_enabled + ) + + logger.info(f"Agent {agent_id} database state: socket_mode_enabled={socket_mode_enabled}") + logger.info(f"Agent {agent_id} actual state: worker_running={worker_running}") + + # If worker is running but agent isn't marked as enabled in DB, update the DB + if worker_running and not socket_mode_enabled: + logger.info(f"Updating agent {agent_id} socket_mode_enabled to True to match actual state") + 
agent.set_field("socket_mode_enabled", True) + db.commit() + db.refresh(agent) + socket_mode_enabled = True + agent_is_active = ( + is_active + and trigger_enabled + and has_bot_token + and workflow_id is not None + and socket_mode_enabled + ) + + # If agent is marked as enabled but worker isn't running, try to restart it + elif socket_mode_enabled and not worker_running: + # Wait a moment to ensure we're not catching a worker in the process of starting + import asyncio + + await asyncio.sleep(1) + + # Check again after the delay + if worker_exists: + worker_running = socket_manager.workers[agent_id].is_alive() + logger.info(f"After delay, agent {agent_id} worker is alive: {worker_running}") + + # Also check process finder again + if not worker_running: + is_process_running, _ = find_running_worker_process(agent_id) + worker_running = is_process_running + logger.info(f"After delay, process search result: running={is_process_running}") + + # If still not running, try to restart it + if not worker_running: + try: + # Try to start the socket mode if the DB says it should be enabled + logger.info( + f"Agent {agent_id} is marked as enabled but no worker is running - attempting to start socket mode" + ) + success = socket_manager.start_worker(agent_id) + if success: + logger.info(f"Successfully started worker for agent {agent_id}") + worker_running = True + else: + # If we couldn't start it, update the DB to reflect reality + logger.info( + f"Failed to start worker for agent {agent_id}, updating DB state to match" + ) + agent.set_field("socket_mode_enabled", False) + db.commit() + db.refresh(agent) + socket_mode_enabled = False + agent_is_active = False + except Exception as e: + logger.error(f"Error attempting to restore socket mode for agent {agent_id}: {e}") + agent.set_field("socket_mode_enabled", False) + db.commit() + db.refresh(agent) + socket_mode_enabled = False + agent_is_active = False + + socket_mode_active = worker_running and agent_is_active + 
logger.info(f"Final status for agent {agent_id}: socket_mode_active={socket_mode_active}") + + return SlackSocketModeResponse( + agent_id=agent_id, + socket_mode_active=socket_mode_active, + message=f"Socket Mode worker is {'active' if socket_mode_active else 'inactive'}", + ) + + +@router.get("/agents/{agent_id}/debug-tokens") +async def debug_agent_tokens(agent_id: int, db: Session = Depends(get_db)): + """Debug endpoint to check token storage for an agent.""" + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + + token_store = get_token_store() + + # Check each token type + token_info = { + "agent_id": agent_id, + "agent_name": agent.name, + "has_bot_token_flag": agent.has_bot_token, + "has_user_token_flag": agent.has_user_token, + "has_app_token_flag": agent.has_app_token, + "last_token_update": _get_nullable_str(agent.last_token_update), + } + + # Try to get each token and check if it exists + try: + bot_token = token_store.get_token(agent_id, "bot_token") + token_info["bot_token_exists"] = bool(bot_token) + if bot_token: + token_info["bot_token_starts_with"] = bot_token[:5] + "..." + except Exception as e: + token_info["bot_token_error"] = str(e) + + try: + user_token = token_store.get_token(agent_id, "user_token") + token_info["user_token_exists"] = bool(user_token) + if user_token: + token_info["user_token_starts_with"] = user_token[:5] + "..." + except Exception as e: + token_info["user_token_error"] = str(e) + + try: + app_token = token_store.get_token(agent_id, "app_token") + token_info["app_token_exists"] = bool(app_token) + if app_token: + token_info["app_token_starts_with"] = app_token[:5] + "..." 
+ except Exception as e: + token_info["app_token_error"] = str(e) + + return token_info + + +@router.post("/agents/{agent_id}/test-connection", response_model=dict) +async def test_connection(agent_id: int, db: Session = Depends(get_db)): + """Test if the Slack connection for an agent works properly.""" + try: + agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + + if not bool(agent.has_bot_token): + return {"success": False, "message": "Agent doesn't have a bot token configured"} + + # Get the bot token from the token store + token_store = get_token_store() + bot_token = token_store.get_token(agent_id, "bot_token") + + if not bot_token: + return {"success": False, "message": "Could not retrieve bot token"} + + # Test the token by calling auth.test + client = AsyncWebClient(token=bot_token) + + try: + response: AsyncSlackResponse = await client.auth_test() # type: ignore + response_data: Dict[str, Any] = response.data if isinstance(response.data, dict) else {} # type: ignore + if response_data.get("ok"): + team = str(response_data.get("team", "Unknown workspace")) + team_id = str(response_data.get("team_id", "")) + user = str(response_data.get("user", "Unknown bot")) + + # Update the agent's team information + if team_id: + agent.set_field("slack_team_id", team_id) + agent.set_field("slack_team_name", team) + db.commit() + logger.info( + f"Updated agent {agent_id} with team information: {team} ({team_id})" + ) + + return { + "success": True, + "message": f"Successfully connected to {team} as {user}", + "team_id": team_id, + "bot_id": response_data["bot_id"], + "user_id": response_data["user_id"], + } + else: + return { + "success": False, + "message": ( + f"API call succeeded but returned not OK: " + f"{response_data.get('error', 'Unknown error')}" + ), + } + except SlackApiError as e: + error_response = cast(Dict[str, Any], getattr(e, "response", 
{})) + error_message = str(error_response.get("error", str(e))) + return {"success": False, "message": f"API Error: {error_message}"} + except Exception as e: + logger.error(f"Error testing Slack connection: {e}") + return {"success": False, "message": f"Error: {str(e)}", "error": str(e)} + + +@router.post("/socket-workers/recover", status_code=200) +async def recover_orphaned_workers(db: Session = Depends(get_db)) -> Dict[str, Any]: + """Find and recover any orphaned socket workers during backend startup. + + This checks for worker marker files and ensures the database state + reflects any running workers. + """ + try: + # Import socket manager lazily to avoid circular imports + from ..integrations.slack.socket_manager import SocketManager + + socket_manager = SocketManager() + + # Check for marker files indicating running workers + marker_dir = "/tmp/pyspur_socket_workers" + if not os.path.exists(marker_dir): + logger.info("No socket worker marker directory found, skipping recovery") + return {"recovered": 0, "message": "No marker directory found"} + + recovered = 0 + markers: List[int] = [] + + # Look for marker files matching agent_*.pid pattern + for filename in os.listdir(marker_dir): + if filename.startswith("agent_") and filename.endswith(".pid"): + try: + # Extract agent ID from filename + agent_id_str = filename[6:-4] # Remove "agent_" prefix and ".pid" suffix + agent_id = int(agent_id_str) + markers.append(agent_id) + + # Check if process is running + pid_file = os.path.join(marker_dir, filename) + with open(pid_file, "r") as f: + pid = int(f.read().strip()) + + # Check if process is still running + is_running = False + + process = psutil.Process(pid) + # Check if the command line contains socket_worker.py + cmdline = process.cmdline() + cmdline_str = " ".join(cmdline) + if ( + "socket_worker.py" in cmdline_str + and f"SLACK_AGENT_ID={agent_id}" in cmdline_str + ): + is_running = True + logger.info(f"Found running worker process for agent {agent_id}: 
pid={pid}") + + if is_running: + # Update the agent record to reflect the running worker + agent = ( + db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + ) + if agent: + if not bool(getattr(agent, "socket_mode_enabled", False)): + logger.info( + f"Recovering agent {agent_id} - worker is running but DB state was socket_mode_enabled=False" + ) + agent.set_field("socket_mode_enabled", True) + db.commit() + + # Try to register the process with the socket manager + try: + process = psutil.Process(pid) + # Store the psutil Process object directly + socket_manager.workers[agent_id] = process + recovered += 1 + except Exception as e: + logger.error(f"Error registering worker process: {e}") + else: + logger.warning( + f"Found marker for agent {agent_id}, but agent does not exist in database" + ) + else: + # Process is not running - clean up the marker file + logger.warning(f"Found stale marker file for agent {agent_id}, removing") + try: + os.remove(pid_file) + except Exception as e: + logger.error(f"Error removing stale marker file: {e}") + except Exception as e: + logger.error(f"Error processing marker file {filename}: {e}") + + return { + "recovered": recovered, + "markers": markers, + "message": f"Recovered {recovered} worker(s)", + } + except Exception as e: + logger.error(f"Error in recover_orphaned_workers: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") + return {"recovered": 0, "error": str(e)} diff --git a/pyspur/backend/pyspur/api/template_management.py b/pyspur/backend/pyspur/api/template_management.py new file mode 100644 index 0000000000000000000000000000000000000000..02068bb2896cecfcb499ed2c4850ca58781fb6c8 --- /dev/null +++ b/pyspur/backend/pyspur/api/template_management.py @@ -0,0 +1,96 @@ +import json +from typing import List +from importlib.resources import files, as_file +import contextlib + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel +from sqlalchemy.orm import Session + 
+from ..database import get_db +from ..schemas.workflow_schemas import ( + WorkflowCreateRequestSchema, + WorkflowResponseSchema, +) +from .workflow_management import create_workflow + + +class TemplateSchema(BaseModel): + """Template schema.""" + + name: str + description: str + features: List[str] + file_name: str + + +router = APIRouter() + +TEMPLATES_RESOURCE = files("pyspur").joinpath("templates") + + +@router.get( + "/", + description="List all available templates", + response_model=List[TemplateSchema], +) +def list_templates() -> List[TemplateSchema]: + """List all available templates.""" + with contextlib.ExitStack() as stack: + templates_dir = stack.enter_context(as_file(TEMPLATES_RESOURCE)) + if not templates_dir.exists(): + raise HTTPException(status_code=500, detail="Templates directory not found") + + # Sort by creation time in descending (most recent first) + sorted_template_files = sorted( + templates_dir.glob("*.json"), + key=lambda p: p.stat().st_ctime, + reverse=True, + ) + + templates: List[TemplateSchema] = [] + for template_file in sorted_template_files: + with open(template_file, "r") as f: + template_content = json.load(f) + metadata = template_content.get("metadata", {}) + templates.append( + TemplateSchema.model_validate( + { + "name": metadata.get("name", template_file.stem), + "description": metadata.get("description", ""), + "features": metadata.get("features", []), + "file_name": template_file.name, + } + ) + ) + return templates + + +@router.post( + "/instantiate/", + description="Instantiate a new workflow from a template", + response_model=WorkflowResponseSchema, +) +def instantiate_template(template: TemplateSchema, db: Session = Depends(get_db)): + """Instantiate a new workflow from a template.""" + template_file_name = template.file_name + with contextlib.ExitStack() as stack: + templates_dir = stack.enter_context(as_file(TEMPLATES_RESOURCE)) + template_path = templates_dir / template_file_name + print(f"Requested template: 
{template_file_name}") + print(f"Resolved template path: {template_path}") + if not template_path.exists(): + raise HTTPException(status_code=404, detail="Template not found") + with open(template_path, "r") as f: + template_content = json.load(f) + metadata = template_content.get("metadata", {}) + workflow_definition = template_content.get("definition", {}) + new_workflow = create_workflow( + WorkflowCreateRequestSchema( + name=metadata.get("name", "Untitled Workflow"), + description=metadata.get("description", ""), + definition=workflow_definition, + ), + db, + ) + return new_workflow diff --git a/pyspur/backend/pyspur/api/user_management.py b/pyspur/backend/pyspur/api/user_management.py new file mode 100644 index 0000000000000000000000000000000000000000..51a62325ab50be6bbaa22e7c36ab5749848e37f8 --- /dev/null +++ b/pyspur/backend/pyspur/api/user_management.py @@ -0,0 +1,118 @@ +from typing import cast + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy import func, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from ..database import get_db +from ..models.user_session_model import UserModel +from ..schemas.user_schemas import ( + UserCreate, + UserListResponse, + UserResponse, + UserUpdate, +) + +router = APIRouter() + + +@router.post("/", response_model=UserResponse) +async def create_user( + user: UserCreate, + db: Session = Depends(get_db), +) -> UserResponse: + """Create a new user.""" + # Check if user already exists with the given external_id + existing_user = db.query(UserModel).filter(UserModel.external_id == user.external_id).first() + if existing_user: + return UserResponse.model_validate(existing_user) + + db_user = UserModel( + external_id=user.external_id, + metadata=user.user_metadata, + ) + + try: + db.add(db_user) + db.commit() + db.refresh(db_user) + return UserResponse.model_validate(db_user) + except IntegrityError: + db.rollback() + raise HTTPException( + status_code=409, + 
detail=f"User with external_id {user.external_id} already exists", + ) from None + + +@router.get("/", response_model=UserListResponse) +async def list_users( + skip: int = Query(0, ge=0), + limit: int = Query(10, ge=1, le=100), + db: Session = Depends(get_db), +) -> UserListResponse: + """List users with pagination.""" + # Get total count + total_count = cast(int, db.scalar(select(func.count()).select_from(UserModel))) + + # Get paginated users + users = db.query(UserModel).order_by(UserModel.id).offset(skip).limit(limit).all() + + # Convert models to response schemas + user_responses = [UserResponse.model_validate(user) for user in users] + return UserListResponse(users=user_responses, total=total_count) + + +@router.get("/{user_id}/", response_model=UserResponse) +async def get_user( + user_id: str, + db: Session = Depends(get_db), +) -> UserResponse: + """Get a specific user by ID.""" + user = db.get(UserModel, int(user_id.lstrip("U"))) + if not user: + raise HTTPException(status_code=404, detail="User not found") + return UserResponse.model_validate(user) + + +@router.patch("/{user_id}/", response_model=UserResponse) +async def update_user( + user_id: str, + user_update: UserUpdate, + db: Session = Depends(get_db), +) -> UserResponse: + """Update a user.""" + db_user = db.get(UserModel, int(user_id.lstrip("U"))) + if not db_user: + raise HTTPException(status_code=404, detail="User not found") + + update_data = user_update.model_dump(exclude_unset=True) + + try: + for field, value in update_data.items(): + setattr(db_user, field, value) + + db.commit() + db.refresh(db_user) + return UserResponse.model_validate(db_user) + except IntegrityError: + db.rollback() + raise HTTPException( + status_code=409, + detail=f"User with external_id {user_update.external_id} already exists", + ) from None + + +@router.delete("/{user_id}/", status_code=204) +async def delete_user( + user_id: str, + db: Session = Depends(get_db), +) -> None: + """Delete a user.""" + db_user = 
db.get(UserModel, int(user_id.lstrip("U"))) + if not db_user: + raise HTTPException(status_code=404, detail="User not found") + + db.delete(db_user) + db.commit() diff --git a/pyspur/backend/pyspur/api/workflow_code_convert.py b/pyspur/backend/pyspur/api/workflow_code_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..dfe6bdcb2112e9f403569922dfda80f0450c4733 --- /dev/null +++ b/pyspur/backend/pyspur/api/workflow_code_convert.py @@ -0,0 +1,320 @@ +from datetime import datetime, timezone +from typing import Optional + +from fastapi import APIRouter, Body, Depends, HTTPException, Query +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from ..database import get_db +from ..models.workflow_model import WorkflowModel +from ..schemas.workflow_schemas import WorkflowDefinitionSchema, WorkflowResponseSchema +from ..workflow_code_handler import WorkflowCodeHandler + + +class WorkflowCodeRequest(BaseModel): + """Request to generate code from a workflow or create a workflow from code.""" + + code: Optional[str] = None + workflow_id: Optional[str] = None + preserve_coordinates: bool = True + preserve_dimensions: bool = True + + +class WorkflowCodeResponse(BaseModel): + """Response containing generated workflow code.""" + + code: str + + +router = APIRouter() + + +@router.get( + "/{workflow_id}", + response_model=WorkflowCodeResponse, + description="Generate Python code from a workflow definition", +) +def get_workflow_code( + workflow_id: str, + preserve_coordinates: bool = Query( + True, description="Whether to include node coordinates in the code" + ), + preserve_dimensions: bool = Query( + True, description="Whether to include node dimensions in the code" + ), + db: Session = Depends(get_db), +) -> WorkflowCodeResponse: + """Generate Python code from a workflow definition. 
+ + Args: + workflow_id: The ID of the workflow to generate code for + preserve_coordinates: Whether to include node coordinates in the code + preserve_dimensions: Whether to include node dimensions in the code + db: Database session + + Returns: + The generated Python code for the workflow + + """ + # Fetch the workflow + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + # Parse the workflow definition + try: + workflow_def = WorkflowDefinitionSchema.model_validate(workflow.definition) + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Failed to parse workflow definition: {str(e)}" + ) from e + + # Generate code from the workflow definition + try: + code = WorkflowCodeHandler.generate_code( + workflow_def, + workflow_name=workflow.name, + workflow_description=workflow.description or "", + preserve_coordinates=preserve_coordinates, + preserve_dimensions=preserve_dimensions, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to generate code: {str(e)}") from e + + return WorkflowCodeResponse(code=code) + + +@router.post( + "/create_from_code", + response_model=WorkflowResponseSchema, + description="Create a new workflow from Python code", +) +def create_workflow_from_code( + request: WorkflowCodeRequest = Body(...), + db: Session = Depends(get_db), +) -> WorkflowResponseSchema: + """Create a new workflow from Python code. 
+ + Args: + request: The request containing the workflow code + db: Database session + + Returns: + The created workflow + + """ + if not request.code: + raise HTTPException(status_code=400, detail="Code is required") + + # Parse the code to get a workflow definition + try: + workflow_def = WorkflowCodeHandler.parse_code(request.code) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Failed to parse code: {str(e)}") from e + + # Extract name from the code if possible + workflow_name = "Code Workflow" + workflow_description = "" + + # Try to find the name from the code + try: + import ast + + tree = ast.parse(request.code) + for node in ast.walk(tree): + # Look for WorkflowBuilder constructor calls + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id == "WorkflowBuilder" + ): + if len(node.args) > 0 and isinstance(node.args[0], ast.Constant): + workflow_name = node.args[0].value + if len(node.args) > 1 and isinstance(node.args[1], ast.Constant): + workflow_description = node.args[1].value + break + except Exception: + # If we can't parse the name, just use the default + pass + + # Create a new workflow record + try: + new_workflow = WorkflowModel( + name=workflow_name, + description=workflow_description, + definition=workflow_def.model_dump(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + + db.add(new_workflow) + db.commit() + db.refresh(new_workflow) + + return new_workflow + except Exception as e: + db.rollback() + raise HTTPException(status_code=500, detail=f"Failed to create workflow: {str(e)}") from e + + +@router.put( + "/{workflow_id}", + response_model=WorkflowResponseSchema, + description="Update a workflow from Python code", +) +def update_workflow_from_code( + workflow_id: str, + request: WorkflowCodeRequest = Body(...), + db: Session = Depends(get_db), +) -> WorkflowResponseSchema: + """Update an existing workflow from Python code. 
+ + Args: + workflow_id: The ID of the workflow to update + request: The request containing the workflow code + db: Database session + + Returns: + The updated workflow + + """ + if not request.code: + raise HTTPException(status_code=400, detail="Code is required") + + # Fetch the workflow + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + # Parse the existing workflow for metadata preservation + try: + existing_workflow = WorkflowDefinitionSchema.model_validate(workflow.definition) + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Failed to parse existing workflow definition: {str(e)}" + ) from e + + # Parse the code to get a workflow definition, preserving UI metadata + try: + workflow_def = WorkflowCodeHandler.parse_code( + request.code, existing_workflow=existing_workflow + ) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Failed to parse code: {str(e)}") from e + + # Update the workflow + try: + workflow.definition = workflow_def.model_dump() + workflow.updated_at = datetime.now(timezone.utc) + + # Extract name from the code if possible + try: + import ast + + tree = ast.parse(request.code) + for node in ast.walk(tree): + # Look for WorkflowBuilder constructor calls + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id == "WorkflowBuilder" + ): + if len(node.args) > 0 and isinstance(node.args[0], ast.Constant): + workflow.name = node.args[0].value + if len(node.args) > 1 and isinstance(node.args[1], ast.Constant): + workflow.description = node.args[1].value + break + except Exception: + # If we can't parse the name, keep the existing name + pass + + db.commit() + db.refresh(workflow) + + return workflow + except Exception as e: + db.rollback() + raise HTTPException(status_code=500, detail=f"Failed to update workflow: {str(e)}") from e + + 
+@router.post( + "/code_to_definition", + response_model=WorkflowDefinitionSchema, + description="Convert Python code to a workflow definition without saving", +) +def code_to_definition( + request: WorkflowCodeRequest = Body(...), + db: Session = Depends(get_db), +) -> WorkflowDefinitionSchema: + """Convert Python code to a workflow definition without saving to the database. + + Args: + request: The request containing the workflow code + db: Database session (unused but required by FastAPI) + + Returns: + The converted workflow definition + + """ + if not request.code: + raise HTTPException(status_code=400, detail="Code is required") + + # Parse the code to get a workflow definition + try: + # If a workflow_id is provided, use it to preserve UI metadata + existing_workflow = None + if request.workflow_id: + workflow = ( + db.query(WorkflowModel).filter(WorkflowModel.id == request.workflow_id).first() + ) + if workflow: + try: + existing_workflow = WorkflowDefinitionSchema.model_validate(workflow.definition) + except Exception: + # If we can't parse the existing workflow, continue without it + pass + + workflow_def = WorkflowCodeHandler.parse_code( + request.code, existing_workflow=existing_workflow + ) + return workflow_def + except Exception as e: + raise HTTPException(status_code=400, detail=f"Failed to parse code: {str(e)}") from e + + +@router.post( + "/definition_to_code", + response_model=WorkflowCodeResponse, + description="Convert a workflow definition to Python code without saving", +) +def definition_to_code( + workflow_def: WorkflowDefinitionSchema = Body(...), + preserve_coordinates: bool = Query( + True, description="Whether to include node coordinates in the code" + ), + preserve_dimensions: bool = Query( + True, description="Whether to include node dimensions in the code" + ), +) -> WorkflowCodeResponse: + """Convert a workflow definition to Python code without saving to the database. 
+ + Args: + workflow_def: The workflow definition to convert + preserve_coordinates: Whether to include node coordinates in the code + preserve_dimensions: Whether to include node dimensions in the code + + Returns: + The generated Python code + + """ + try: + code = WorkflowCodeHandler.generate_code( + workflow_def, + workflow_name="Workflow", + workflow_description="", + preserve_coordinates=preserve_coordinates, + preserve_dimensions=preserve_dimensions, + ) + return WorkflowCodeResponse(code=code) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Failed to generate code: {str(e)}") from e diff --git a/pyspur/backend/pyspur/api/workflow_management.py b/pyspur/backend/pyspur/api/workflow_management.py new file mode 100644 index 0000000000000000000000000000000000000000..632f9c5509550029073d7e1afbc78fccb2525258 --- /dev/null +++ b/pyspur/backend/pyspur/api/workflow_management.py @@ -0,0 +1,523 @@ +import json +import shutil +from datetime import datetime, timezone +from pathlib import Path +from typing import Dict, List + +from fastapi import ( + APIRouter, + BackgroundTasks, + Depends, + File, + Form, + HTTPException, + Query, + UploadFile, + status, +) +from sqlalchemy.orm import Session + +from ..database import get_db +from ..models.workflow_model import WorkflowModel as WorkflowModel +from ..models.workflow_version_model import WorkflowVersionModel as WorkflowVersionModel +from ..nodes.primitives.input import InputNodeConfig +from ..schemas.pause_schemas import ( + PausedWorkflowResponseSchema, + PauseHistoryResponseSchema, +) +from ..schemas.run_schemas import ( + ResumeRunRequestSchema, + RunResponseSchema, +) +from ..schemas.workflow_schemas import ( + SpurType, + WorkflowCreateRequestSchema, + WorkflowDefinitionSchema, + WorkflowNodeSchema, + WorkflowResponseSchema, + WorkflowVersionResponseSchema, +) +from .workflow_run import get_paused_workflows, get_run_pause_history, process_pause_action + +# Main router for workflow management 
+router = APIRouter() + + +# Paused workflow endpoints +@router.get( + "/paused_workflows/", + response_model=List[PausedWorkflowResponseSchema], + description="List all paused workflows", + tags=["workflows"], +) +def list_paused_workflows( + page: int = Query(default=1, ge=1), + page_size: int = Query(default=10, ge=1, le=100), + db: Session = Depends(get_db), +) -> List[PausedWorkflowResponseSchema]: + return get_paused_workflows(db, page, page_size) + + +@router.get( + "/pause_history/{run_id}/", + response_model=List[PauseHistoryResponseSchema], + description="Get pause history for a run", + tags=["workflows"], +) +def get_pause_history( + run_id: str, db: Session = Depends(get_db) +) -> List[PauseHistoryResponseSchema]: + return get_run_pause_history(db, run_id) + + +@router.post( + "/process_pause_action/{run_id}/", + response_model=RunResponseSchema, + description="Take action on a paused workflow", + tags=["workflows"], +) +def take_pause_action( + run_id: str, + action_request: ResumeRunRequestSchema, + background_tasks: BackgroundTasks, + db: Session = Depends(get_db), +) -> RunResponseSchema: + """Process an action on a paused workflow. + + It allows approving, declining, or overriding a workflow that has been paused + for human intervention. 
+ + Args: + run_id: The ID of the paused run + action_request: The details of the action to take + background_tasks: FastAPI background tasks handler to resume the workflow asynchronously + db: Database session + + Returns: + Information about the resumed run + + """ + return process_pause_action(db, run_id, action_request, background_tasks) + + +def create_a_new_workflow_definition( + spur_type: SpurType = SpurType.WORKFLOW, +) -> WorkflowDefinitionSchema: + if spur_type == SpurType.CHATBOT: + # Create input node with required chatbot fields + input_node_config = InputNodeConfig().model_dump() + input_node_config["output_json_schema"] = json.dumps( + { + "type": "object", + "properties": { + "user_message": {"type": "string"}, + "session_id": {"type": "string"}, + "message_history": { + "type": "array", + "items": { + "type": "object", + "properties": { + "role": {"type": "string"}, + "content": {"type": "string"}, + }, + }, + }, + }, + "required": ["user_message", "session_id"], + } + ) + input_node_config["output_schema"] = { + "user_message": "string", + "session_id": "string", + "message_history": "List[Dict[str, str]]", + } + # Create output node with required chatbot fields + output_node_config = { + "output_schema": {"assistant_message": "string"}, + "output_json_schema": json.dumps( + { + "type": "object", + "properties": {"assistant_message": {"type": "string"}}, + "required": ["assistant_message"], + } + ), + } + + return WorkflowDefinitionSchema( + nodes=[ + WorkflowNodeSchema.model_validate( + { + "id": "input_node", + "node_type": "InputNode", + "coordinates": {"x": 100, "y": 100}, + "config": input_node_config, + } + ), + WorkflowNodeSchema.model_validate( + { + "id": "output_node", + "node_type": "OutputNode", + "coordinates": {"x": 300, "y": 100}, + "config": output_node_config, + } + ), + ], + links=[], + spur_type=spur_type, + ) + else: + return WorkflowDefinitionSchema( + nodes=[ + WorkflowNodeSchema.model_validate( + { + "id": "input_node", + 
"node_type": "InputNode", + "coordinates": {"x": 100, "y": 100}, + "config": InputNodeConfig().model_dump(), + } + ) + ], + links=[], + spur_type=spur_type, + ) + + +def generate_unique_workflow_name(db: Session, base_name: str) -> str: + existing_workflow = db.query(WorkflowModel).filter(WorkflowModel.name == base_name).first() + if existing_workflow: + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + return f"{base_name} {timestamp}" + return base_name + + +@router.post( + "/", + response_model=WorkflowResponseSchema, + description="Create a new workflow", +) +def create_workflow( + workflow_request: WorkflowCreateRequestSchema, db: Session = Depends(get_db) +) -> WorkflowResponseSchema: + print(workflow_request) + if not workflow_request.definition: + # If no definition is provided, create a new one with default WORKFLOW type + workflow_request.definition = create_a_new_workflow_definition(spur_type=SpurType.WORKFLOW) + elif ( + workflow_request.definition.spur_type == SpurType.CHATBOT + and len(workflow_request.definition.nodes) == 0 + ): + # If the workflow type is CHATBOT, create a new definition with required fields + workflow_request.definition = create_a_new_workflow_definition(spur_type=SpurType.CHATBOT) + elif len(workflow_request.definition.nodes) == 0: + # If the workflow type is not CHATBOT, create a new definition with default WORKFLOW type + workflow_request.definition = create_a_new_workflow_definition(spur_type=SpurType.WORKFLOW) + + # Generate a unique name for the workflow + workflow_name = generate_unique_workflow_name(db, workflow_request.name or "Untitled Workflow") + new_workflow = WorkflowModel( + name=workflow_name, + description=workflow_request.description, + definition=workflow_request.definition.model_dump(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db.add(new_workflow) + db.commit() + db.refresh(new_workflow) + + return new_workflow + + +@router.put( + "/{workflow_id}/", + 
response_model=WorkflowResponseSchema, + description="Update a workflow", +) +def update_workflow( + workflow_id: str, + workflow_request: WorkflowCreateRequestSchema, + db: Session = Depends(get_db), +) -> WorkflowResponseSchema: + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + if not workflow_request.definition: + raise HTTPException( + status_code=400, + detail="Workflow definition is required to update a workflow", + ) + + workflow.definition = workflow_request.definition.model_dump() + workflow.name = workflow_request.name + workflow.description = workflow_request.description + workflow.updated_at = datetime.now(timezone.utc) + db.commit() + db.refresh(workflow) + + return workflow + + +@router.get( + "/", + response_model=List[WorkflowResponseSchema], + description="List all workflows", +) +def list_workflows( + page: int = Query(default=1, ge=1), + page_size: int = Query(default=10, ge=1, le=100), + db: Session = Depends(get_db), +): + offset = (page - 1) * page_size + workflows = ( + db.query(WorkflowModel) + .order_by(WorkflowModel.created_at.desc()) + .offset(offset) + .limit(page_size) + .all() + ) + valid_workflows: List[WorkflowModel] = [] + for workflow in workflows: + try: + WorkflowResponseSchema.model_validate(workflow) + valid_workflows.append(workflow) + except Exception: + continue + return valid_workflows + + +@router.get( + "/{workflow_id}/", + response_model=WorkflowResponseSchema, + description="Get a workflow by ID", +) +def get_workflow(workflow_id: str, db: Session = Depends(get_db)) -> WorkflowResponseSchema: + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + return workflow + + +@router.put( + "/{workflow_id}/reset/", + response_model=WorkflowResponseSchema, + description="Reset a workflow 
to its initial state", +) +def reset_workflow(workflow_id: str, db: Session = Depends(get_db)) -> WorkflowResponseSchema: + # Fetch the workflow by ID + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + + # If workflow not found, raise 404 error + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + # Reset the workflow definition to a new one + workflow.definition = create_a_new_workflow_definition().model_dump() + + # Update the updated_at timestamp + workflow.updated_at = datetime.now(timezone.utc) + + # Commit the changes to the database + db.commit() + db.refresh(workflow) + + # Return the updated workflow + return workflow + + +@router.delete( + "/{workflow_id}/", + status_code=status.HTTP_204_NO_CONTENT, + description="Delete a workflow by ID", +) +def delete_workflow(workflow_id: str, db: Session = Depends(get_db)): + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + try: + # Delete associated test files + test_files_dir = Path("data/test_files") / workflow_id + if test_files_dir.exists(): + shutil.rmtree(test_files_dir) + + # Delete the workflow (cascading will handle related records) + db.delete(workflow) + db.commit() + except Exception as e: + db.rollback() + raise HTTPException( + status_code=500, + detail=f"Error deleting workflow: {str(e)}", + ) from e + + return None + + +@router.post( + "/{workflow_id}/duplicate/", + response_model=WorkflowResponseSchema, + description="Duplicate a workflow by ID", +) +def duplicate_workflow(workflow_id: str, db: Session = Depends(get_db)) -> WorkflowResponseSchema: + # Fetch the workflow by ID + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + + # If workflow not found, raise 404 error + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + # Create a 
new WorkflowModel instance by copying fields + new_workflow_name = generate_unique_workflow_name(db, f"{workflow.name} (Copy)") + + new_workflow = WorkflowModel( + name=new_workflow_name, + description=workflow.description, + definition=workflow.definition.model_dump(), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + + # Add and commit the new workflow + db.add(new_workflow) + db.commit() + db.refresh(new_workflow) + + # Return the duplicated workflow + return new_workflow + + +@router.get( + "/{workflow_id}/output_variables/", + response_model=List[Dict[str, str]], + description="Get the output variables (leaf nodes) of a workflow", +) +def get_workflow_output_variables( + workflow_id: str, db: Session = Depends(get_db) +) -> List[Dict[str, str]]: + """Fetch the output variables (leaf nodes) of a workflow.""" + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + workflow_definition = WorkflowDefinitionSchema.model_validate(workflow.definition) + + # Find leaf nodes (nodes without outgoing links) + all_source_ids = {link.source_id for link in workflow_definition.links} + all_node_ids = {node.id for node in workflow_definition.nodes} + leaf_nodes = all_node_ids - all_source_ids + + # Collect output variables as a list of dictionaries + output_variables: List[Dict[str, str]] = [] + for node in workflow_definition.nodes: + if node.id in leaf_nodes: + try: + # Try to get output_schema from the node config + output_schema: Dict[str, str] = {} + output_schema = node.config.get("output_schema", {}) + + # If no output schema is found, skip this node + if not output_schema: + continue + + for var_name in output_schema.keys(): + output_variables.append( + { + "node_id": node.id, + "variable_name": var_name, + "prefixed_variable": f"{node.id}-{var_name}", + } + ) + except Exception: + # If there's any error 
processing this node, skip it + continue + + return output_variables + + +@router.post( + "/upload_test_files/", + description="Upload test files for a specific node in a workflow", +) +async def upload_test_files( + workflow_id: str = Form(...), + files: List[UploadFile] = File(...), + node_id: str = Form(...), + db: Session = Depends(get_db), +) -> Dict[str, List[str]]: + """Upload files for test inputs and return their paths.""" + try: + # Get the workflow + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + # Create workflow-specific directory for test files + test_files_dir = Path("data/test_files") / workflow_id + test_files_dir.mkdir(parents=True, exist_ok=True) + + # Save files and collect paths + saved_paths: List[str] = [] + for file in files: + # Generate unique filename + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + safe_filename = f"{timestamp}_{file.filename}" + file_path = test_files_dir / safe_filename + + # Save file + content = await file.read() + with open(file_path, "wb") as f: + f.write(content) + + # Store relative path + saved_paths.append(f"test_files/{workflow_id}/{safe_filename}") + + return {node_id: saved_paths} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get( + "/{workflow_id}/versions/", + response_model=List[WorkflowVersionResponseSchema], + description="Get all versions of a workflow", + tags=["workflows"], +) +def get_workflow_versions( + workflow_id: str, + page: int = Query(default=1, ge=1), + page_size: int = Query(default=10, ge=1, le=100), + db: Session = Depends(get_db), +) -> List[WorkflowVersionResponseSchema]: + """Retrieve all versions of a workflow, ordered by version number descending. 
+ + Args: + workflow_id: The ID of the workflow + page: Page number for pagination + page_size: Number of items per page + db: Database session + + Returns: + List of workflow versions + + """ + # Check if workflow exists + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + # Calculate offset for pagination + offset = (page - 1) * page_size + + # Query workflow versions + versions = ( + db.query(WorkflowVersionModel) + .filter(WorkflowVersionModel.workflow_id == workflow_id) + .order_by(WorkflowVersionModel.version.desc()) + .offset(offset) + .limit(page_size) + .all() + ) + + # Convert models to response schemas + return [WorkflowVersionResponseSchema.model_validate(version) for version in versions] diff --git a/pyspur/backend/pyspur/api/workflow_run.py b/pyspur/backend/pyspur/api/workflow_run.py new file mode 100644 index 0000000000000000000000000000000000000000..0a04abfe1736aea5d5d920b2cc0cdd5f432b87cf --- /dev/null +++ b/pyspur/backend/pyspur/api/workflow_run.py @@ -0,0 +1,1358 @@ +import asyncio +import base64 +import hashlib +import json +import os +import re +from datetime import datetime, timezone +from pathlib import Path # Import Path for directory handling +from typing import Any, Awaitable, Callable, Coroutine, Dict, List, Optional, Set, Tuple, Union + +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query +from loguru import logger +from sqlalchemy.orm import Session + +from ..database import get_db +from ..dataset.ds_util import get_ds_column_names, get_ds_iterator +from ..execution.task_recorder import TaskRecorder +from ..execution.workflow_execution_context import WorkflowExecutionContext +from ..execution.workflow_executor import WorkflowExecutor +from ..models.dataset_model import DatasetModel +from ..models.output_file_model import OutputFileModel +from ..models.run_model import RunModel, RunStatus 
+from ..models.task_model import TaskModel, TaskStatus +from ..models.workflow_model import WorkflowModel +from ..nodes.base import BaseNodeOutput +from ..nodes.factory import NodeFactory +from ..nodes.logic.human_intervention import HumanInterventionNodeOutput, PauseError +from ..schemas.pause_schemas import ( + PausedWorkflowResponseSchema, + PauseHistoryResponseSchema, +) +from ..schemas.run_schemas import ( + BatchRunRequestSchema, + PartialRunRequestSchema, + ResumeRunRequestSchema, + RunResponseSchema, + StartRunRequestSchema, +) +from ..schemas.workflow_schemas import WorkflowDefinitionSchema, WorkflowNodeSchema +from ..utils.workflow_version_utils import fetch_workflow_version + +router = APIRouter() + + +async def create_run_model( + workflow_id: str, + workflow_version_id: str, + initial_inputs: Dict[str, Dict[str, Any]], + parent_run_id: Optional[str], + run_type: str, + db: Session, +) -> RunModel: + new_run = RunModel( + workflow_id=workflow_id, + workflow_version_id=workflow_version_id, + status=RunStatus.PENDING, + initial_inputs=initial_inputs, + start_time=datetime.now(timezone.utc), + parent_run_id=parent_run_id, + run_type=run_type, + ) + db.add(new_run) + db.commit() + db.refresh(new_run) + return new_run + + +def process_embedded_files( + workflow_id: str, + initial_inputs: Dict[str, Dict[str, Any]], +) -> Dict[str, Dict[str, Any]]: + """Process any embedded files in the initial inputs and save them to disk. + + Returns updated inputs with file paths instead of data URIs. 
+ """ + processed_inputs = initial_inputs.copy() + + # Iterate through the values to find data URIs recursively + def find_and_replace_data_uris(data: Any) -> Any: + if isinstance(data, dict): + return {str(k): find_and_replace_data_uris(v) for k, v in data.items()} # type: ignore + elif isinstance(data, list): + return [find_and_replace_data_uris(item) for item in data] # type: ignore + elif isinstance(data, str) and data.startswith("data:"): + return save_embedded_file(data, workflow_id) + else: + return data + + processed_inputs = find_and_replace_data_uris(processed_inputs) + return processed_inputs + + +def get_node_title_output_map( + nodes: List[WorkflowNodeSchema], + outputs: Dict[str, BaseNodeOutput], +) -> Dict[str, Dict[str, Any]]: + """Create a dictionary of node titles to outputs.""" + title_output_dict: Dict[str, Dict[str, Any]] = {} + for node_id, node_output in outputs.items(): + # Find the node with this ID to get its title + node = next((n for n in nodes if n.id == node_id), None) + if node and hasattr(node, "title") and node.title and node_output: + # Use the node's title as the key + title_output_dict[node.title] = node_output.model_dump() + return title_output_dict + + +@router.post( + "/{workflow_id}/runv2/", + response_model=RunResponseSchema, + description="Run a workflow and return the run details with outputs", +) +async def run_workflow_blocking_v2( # noqa: C901 + workflow_id: str, + request: StartRunRequestSchema, + db: Session = Depends(get_db), + run_type: str = "interactive", +) -> RunResponseSchema: + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + workflow_version = fetch_workflow_version(workflow_id, workflow, db) + workflow_definition = WorkflowDefinitionSchema.model_validate(workflow_version.definition) + + initial_inputs = request.initial_inputs or {} + + # Process any embedded files in the inputs + 
initial_inputs = process_embedded_files(workflow_id, initial_inputs) + + # Handle file paths if present + if request.files: + for node_id, file_paths in request.files.items(): + if node_id in initial_inputs: + initial_inputs[node_id]["files"] = file_paths + + new_run = await create_run_model( + workflow_id, + workflow_version.id, + initial_inputs, + request.parent_run_id, + run_type, + db, + ) + task_recorder = TaskRecorder(db, new_run.id) + context = WorkflowExecutionContext( + workflow_id=workflow.id, + run_id=new_run.id, + parent_run_id=request.parent_run_id, + run_type=run_type, + db_session=db, + workflow_definition=workflow_version.definition, + ) + executor = WorkflowExecutor( + workflow=workflow_definition, + task_recorder=task_recorder, + context=context, + ) + input_node = next(node for node in workflow_definition.nodes if node.node_type == "InputNode") + + try: + outputs = await executor(initial_inputs[input_node.id]) + + # Check if any tasks were paused + has_paused_tasks = False + paused_node_ids: List[str] = [] + for task in new_run.tasks: + if task.status == TaskStatus.PAUSED: + has_paused_tasks = True + paused_node_ids.append(task.node_id) + + if has_paused_tasks: + # If we have paused tasks, ensure the run is in a PAUSED state + new_run.status = RunStatus.PAUSED + + # Get all blocked nodes from paused nodes + all_blocked_nodes: Set[str] = set() + for paused_node_id in paused_node_ids: + blocked_nodes = executor.get_blocked_nodes(paused_node_id) + all_blocked_nodes.update(blocked_nodes) + + # Make sure all downstream nodes are in PENDING status + for task in new_run.tasks: + if task.status == TaskStatus.CANCELED and task.node_id in all_blocked_nodes: + # Update any CANCELED tasks that should be PENDING + task_recorder.update_task( + node_id=task.node_id, + status=TaskStatus.PENDING, + end_time=datetime.now(), + is_downstream_of_pause=True, + ) + else: + new_run.status = RunStatus.COMPLETED + + new_run.end_time = datetime.now(timezone.utc) + nodes = 
workflow_version.definition["nodes"] + nodes = [WorkflowNodeSchema.model_validate(node) for node in nodes] + # Create outputs dictionary using node titles as keys instead of node IDs + new_run.outputs = get_node_title_output_map(nodes, outputs) + db.commit() + + # Refresh the run to get the updated tasks + db.refresh(new_run) + response = RunResponseSchema.model_validate(new_run) + response.message = "Workflow execution completed successfully." + return response + + except PauseError as e: + # Make sure the run status is set to PAUSED + new_run.status = RunStatus.PAUSED + new_run.outputs = get_node_title_output_map( + workflow_definition.nodes, {k: v for k, v in executor.outputs.items() if v is not None} + ) + + # Get all blocked nodes from paused nodes + paused_node_ids = [ + task.node_id for task in new_run.tasks if task.status == TaskStatus.PAUSED + ] + all_blocked_nodes: Set[str] = set() + for paused_node_id in paused_node_ids: + blocked_nodes = executor.get_blocked_nodes(paused_node_id) + all_blocked_nodes.update(blocked_nodes) + + # Make sure all downstream nodes are in PENDING status + for task in new_run.tasks: + if task.status == TaskStatus.CANCELED and task.node_id in all_blocked_nodes: + # Update any CANCELED tasks that should be PENDING + task_recorder.update_task( + node_id=task.node_id, + status=TaskStatus.PENDING, + end_time=datetime.now(), + is_downstream_of_pause=True, + ) + + db.commit() + # Refresh the run to get the updated tasks + db.refresh(new_run) + response = RunResponseSchema.model_validate(new_run) + response.message = "Workflow execution paused for human intervention." 
+ raise HTTPException( + status_code=202, + detail=response.model_dump(), + ) from e + except Exception as e: + new_run.status = RunStatus.FAILED + new_run.end_time = datetime.now(timezone.utc) + db.commit() + response = RunResponseSchema.model_validate(new_run) + response.message = f"Workflow execution failed: {str(e)}" + raise HTTPException( + status_code=500, + detail=response.model_dump(), + ) from e + + +@router.post( + "/{workflow_id}/run/", + response_model=Dict[str, Any], + description="Run a workflow and return the outputs", +) +async def run_workflow_blocking( # noqa: C901 + workflow_id: str, + request: StartRunRequestSchema, + db: Session = Depends(get_db), + run_type: str = "interactive", +) -> Dict[str, Any]: + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + workflow_version = fetch_workflow_version(workflow_id, workflow, db) + workflow_definition = WorkflowDefinitionSchema.model_validate(workflow_version.definition) + + initial_inputs = request.initial_inputs or {} + + # Process any embedded files in the inputs + initial_inputs = process_embedded_files(workflow_id, initial_inputs) + + # Handle file paths if present + if request.files: + for node_id, file_paths in request.files.items(): + if node_id in initial_inputs: + initial_inputs[node_id]["files"] = file_paths + + new_run = await create_run_model( + workflow_id, + workflow_version.id, + initial_inputs, + request.parent_run_id, + run_type, + db, + ) + task_recorder = TaskRecorder(db, new_run.id) + context = WorkflowExecutionContext( + workflow_id=workflow.id, + run_id=new_run.id, + parent_run_id=request.parent_run_id, + run_type=run_type, + db_session=db, + workflow_definition=workflow_version.definition, + ) + executor = WorkflowExecutor( + workflow=workflow_definition, + task_recorder=task_recorder, + context=context, + ) + input_node = next(node for node in 
workflow_definition.nodes if node.node_type == "InputNode") + + try: + outputs = await executor(initial_inputs[input_node.id]) + + # Check if any tasks were paused + has_paused_tasks = False + paused_node_ids: List[str] = [] + for task in new_run.tasks: + if task.status == TaskStatus.PAUSED: + has_paused_tasks = True + paused_node_ids.append(task.node_id) + + if has_paused_tasks: + # If we have paused tasks, ensure the run is in a PAUSED state + new_run.status = RunStatus.PAUSED + + # Get all blocked nodes from paused nodes + all_blocked_nodes: Set[str] = set() + for paused_node_id in paused_node_ids: + blocked_nodes = executor.get_blocked_nodes(paused_node_id) + all_blocked_nodes.update(blocked_nodes) + + # Make sure all downstream nodes are in PENDING status + for task in new_run.tasks: + if task.status == TaskStatus.CANCELED and task.node_id in all_blocked_nodes: + # Update any CANCELED tasks that should be PENDING + task_recorder.update_task( + node_id=task.node_id, + status=TaskStatus.PENDING, + end_time=datetime.now(), + is_downstream_of_pause=True, + ) + else: + new_run.status = RunStatus.COMPLETED + + new_run.end_time = datetime.now(timezone.utc) + new_run.outputs = get_node_title_output_map( + workflow_definition.nodes, {k: v for k, v in executor.outputs.items() if v is not None} + ) + db.commit() + + # Refresh the run to get the updated tasks + db.refresh(new_run) + return outputs + except PauseError as e: + # Make sure the run status is set to PAUSED + new_run.status = RunStatus.PAUSED + new_run.outputs = {k: v.model_dump() for k, v in executor.outputs.items() if v is not None} + + # Get all blocked nodes from paused nodes + paused_node_ids = [ + task.node_id for task in new_run.tasks if task.status == TaskStatus.PAUSED + ] + all_blocked_nodes: Set[str] = set() + for paused_node_id in paused_node_ids: + blocked_nodes = executor.get_blocked_nodes(paused_node_id) + all_blocked_nodes.update(blocked_nodes) + + # Make sure all downstream nodes are in PENDING 
status + for task in new_run.tasks: + if task.status == TaskStatus.CANCELED and task.node_id in all_blocked_nodes: + # Update any CANCELED tasks that should be PENDING + task_recorder.update_task( + node_id=task.node_id, + status=TaskStatus.PENDING, + end_time=datetime.now(), + is_downstream_of_pause=True, + ) + + db.commit() + # Refresh the run to get the updated tasks + db.refresh(new_run) + raise e + + +@router.post( + "/{workflow_id}/start_run/", + response_model=RunResponseSchema, + description="Start a non-blocking workflow run and return the run details", +) +async def run_workflow_non_blocking( # noqa: C901 + workflow_id: str, + start_run_request: StartRunRequestSchema, + background_tasks: BackgroundTasks, + db: Session = Depends(get_db), + run_type: str = "interactive", +) -> RunResponseSchema: + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + workflow_version = fetch_workflow_version(workflow_id, workflow, db) + workflow_definition = WorkflowDefinitionSchema.model_validate(workflow_version.definition) + + initial_inputs = start_run_request.initial_inputs or {} + + # Process any embedded files in the inputs + initial_inputs = process_embedded_files(workflow_id, initial_inputs) + + new_run = await create_run_model( + workflow_id, + workflow_version.id, + initial_inputs, + start_run_request.parent_run_id, + run_type, + db, + ) + + async def run_workflow_task(run_id: str, workflow_definition: WorkflowDefinitionSchema): + with next(get_db()) as session: + run = session.query(RunModel).filter(RunModel.id == run_id).first() + if not run: + session.close() + return + + # Initialize workflow execution + run.status = RunStatus.RUNNING + session.commit() + task_recorder, context, executor = _setup_workflow_execution( + session, + run, + run_id, + start_run_request.parent_run_id, + run_type, + workflow_version, + workflow_definition, + ) + + # 
Store context for debugging or audit purposes + run.execution_context = context.model_dump() if hasattr(context, "model_dump") else None + + try: + # Execute workflow + assert run.initial_inputs + input_node = next( + node for node in workflow_definition.nodes if node.node_type == "InputNode" + ) + outputs = await executor(run.initial_inputs[input_node.id]) + run.outputs = get_node_title_output_map(workflow_definition.nodes, outputs) + + # Handle paused tasks if any + has_paused_tasks = _check_for_paused_tasks(run) + + if has_paused_tasks: + _handle_paused_workflow(run, executor, task_recorder, workflow_version) + else: + run.status = RunStatus.COMPLETED + + run.end_time = datetime.now(timezone.utc) + except PauseError: + _handle_pause_exception(run, executor, task_recorder, workflow_version) + session.commit() + # Refresh the run to get the updated tasks + session.refresh(run) + return # Don't raise the exception so the background task can complete + except Exception as e: + run.status = RunStatus.FAILED + run.end_time = datetime.now(timezone.utc) + session.commit() + raise e + session.commit() + + def _setup_workflow_execution( + session: Session, + run: RunModel, + run_id: str, + parent_run_id: Optional[str], + run_type: str, + workflow_version: Any, + workflow_definition: WorkflowDefinitionSchema, + ) -> Tuple[TaskRecorder, WorkflowExecutionContext, WorkflowExecutor]: + """Set up the execution environment for a workflow.""" + task_recorder = TaskRecorder(session, run_id) + context = WorkflowExecutionContext( + workflow_id=run.workflow_id, + run_id=run_id, + parent_run_id=parent_run_id, + run_type=run_type, + db_session=session, + workflow_definition=workflow_version.definition, + ) + executor = WorkflowExecutor( + workflow=workflow_definition, + task_recorder=task_recorder, + context=context, + ) + return task_recorder, context, executor + + def _check_for_paused_tasks(run: RunModel) -> bool: + """Check if any tasks in the run are paused.""" + for task in 
run.tasks: + if task.status == TaskStatus.PAUSED: + return True + return False + + def _handle_paused_workflow( + run: RunModel, + executor: WorkflowExecutor, + task_recorder: TaskRecorder, + workflow_version: Any, + ) -> None: + """Handle case when workflow has paused tasks.""" + run.status = RunStatus.PAUSED + + # Get all paused node IDs + paused_node_ids = [task.node_id for task in run.tasks if task.status == TaskStatus.PAUSED] + + # Update downstream tasks of paused nodes + _update_downstream_tasks(paused_node_ids, executor, workflow_version, run, task_recorder) + + def _handle_pause_exception( + run: RunModel, + executor: WorkflowExecutor, + task_recorder: TaskRecorder, + workflow_version: Any, + ) -> None: + """Handle PauseException during workflow execution.""" + run.status = RunStatus.PAUSED + run.outputs = get_node_title_output_map( + workflow_version.nodes, {k: v for k, v in executor.outputs.items() if v is not None} + ) + + # Get all paused node IDs + paused_node_ids = [task.node_id for task in run.tasks if task.status == TaskStatus.PAUSED] + + # Update downstream tasks of paused nodes + _update_downstream_tasks(paused_node_ids, executor, workflow_version, run, task_recorder) + + def _update_downstream_tasks( + paused_node_ids: List[str], + executor: WorkflowExecutor, + workflow_version: Any, + run: RunModel, + task_recorder: TaskRecorder, + ) -> None: + """Update status of tasks that depend on paused nodes.""" + all_blocked_nodes: Set[str] = set() + for paused_node_id in paused_node_ids: + blocked_nodes = executor.get_blocked_nodes(paused_node_id) + all_blocked_nodes.update(blocked_nodes) + + # Make sure all downstream nodes are in PENDING status + for task in run.tasks: + if task.status == TaskStatus.CANCELED and task.node_id in all_blocked_nodes: + # Update any CANCELED tasks that should be PENDING + task_recorder.update_task( + node_id=task.node_id, + status=TaskStatus.PENDING, + end_time=datetime.now(), + is_downstream_of_pause=True, + ) + + 
background_tasks.add_task(run_workflow_task, new_run.id, workflow_definition) + + return new_run + + +@router.post( + "/{workflow_id}/run_partial/", + response_model=Dict[str, Any], + description="Run a partial workflow and return the outputs", +) +async def run_partial_workflow( + workflow_id: str, + request: PartialRunRequestSchema, + db: Session = Depends(get_db), +) -> Dict[str, Any]: + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + workflow_definition = WorkflowDefinitionSchema.model_validate(workflow.definition) + executor = WorkflowExecutor(workflow_definition) + input_node = next(node for node in workflow_definition.nodes if node.node_type == "InputNode") + initial_inputs = request.initial_inputs or {} + try: + outputs = await executor.run( + input=initial_inputs.get(input_node.id, {}), + node_ids=[request.node_id], + precomputed_outputs=request.partial_outputs or {}, + ) + return outputs + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) from e + + +@router.post( + "/{workflow_id}/start_batch_run/", + response_model=RunResponseSchema, + description="Start a batch run of a workflow over a dataset and return the run details", +) +async def batch_run_workflow_non_blocking( # noqa: C901 + workflow_id: str, + request: BatchRunRequestSchema, + background_tasks: BackgroundTasks, + db: Session = Depends(get_db), +) -> RunResponseSchema: + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + workflow_version = fetch_workflow_version(workflow_id, workflow, db) + + dataset_id = request.dataset_id + new_run = await create_run_model(workflow_id, workflow_version.id, {}, None, "batch", db) + + # parse the dataset + dataset = db.query(DatasetModel).filter(DatasetModel.id == dataset_id).first() + if not 
dataset: + raise HTTPException(status_code=404, detail="Dataset not found") + + # ensure ds columns match workflow inputs + dataset_columns = get_ds_column_names(dataset.file_path) + workflow_definition = WorkflowDefinitionSchema.model_validate(workflow_version.definition) + input_node = next(node for node in workflow_definition.nodes if node.node_type == "InputNode") + input_node_id = input_node.id + workflow_input_schema: Dict[str, str] = input_node.config["input_schema"] + for col in workflow_input_schema.keys(): + if col not in dataset_columns: + raise HTTPException( + status_code=400, + detail=f"Input field '{col}' in input schema not found in the dataset", + ) + + # create output file + output_file_name = f"output_{new_run.id}.jsonl" + output_file_path = os.path.join( + os.path.dirname(__file__), "..", "..", "output_files", output_file_name + ) + output_file = OutputFileModel( + file_name=output_file_name, + file_path=output_file_path, + ) + db.add(output_file) + db.commit() + + file_path = dataset.file_path + + mini_batch_size = request.mini_batch_size + + async def start_mini_batch_runs( + file_path: str, + workflow_id: str, + workflow_input_schema: Dict[str, str], + input_node_id: str, + parent_run_id: str, + background_tasks: BackgroundTasks, + db: Session, + mini_batch_size: int, + output_file_path: str, + ): + ds_iter = get_ds_iterator(file_path) + current_batch: List[Awaitable[Dict[str, Any]]] = [] + batch_count = 0 + for inputs in ds_iter: + initial_inputs = { + input_node_id: {k: v for k, v in inputs.items() if k in workflow_input_schema} + } + single_input_run_task = run_workflow_blocking( + workflow_id=workflow_id, + request=StartRunRequestSchema( + initial_inputs=initial_inputs, parent_run_id=parent_run_id + ), + db=db, + run_type="batch", + ) + current_batch.append(single_input_run_task) + if len(current_batch) == mini_batch_size: + minibatch_results = await asyncio.gather(*current_batch) + current_batch = [] + batch_count += 1 + with 
open(output_file_path, "a") as output_file: + for output in minibatch_results: + output = { + node_id: output.model_dump() for node_id, output in output.items() + } + output_file.write(json.dumps(output) + "\n") + + if current_batch: + results = await asyncio.gather(*current_batch) + with open(output_file_path, "a") as output_file: + for output in results: + output = {node_id: output.model_dump() for node_id, output in output.items()} + output_file.write(json.dumps(output) + "\n") + + with next(get_db()) as session: + run = session.query(RunModel).filter(RunModel.id == parent_run_id).first() + if not run: + session.close() + return + run.status = RunStatus.COMPLETED + run.end_time = datetime.now(timezone.utc) + session.commit() + + background_tasks.add_task( + start_mini_batch_runs, + file_path, + workflow_id, + workflow_input_schema, + input_node_id, + new_run.id, + background_tasks, + db, + mini_batch_size, + output_file_path, + ) + new_run.output_file_id = output_file.id + db.commit() + return new_run + + +@router.get( + "/{workflow_id}/runs/", + response_model=List[RunResponseSchema], + description="List all runs of a workflow", +) +def list_runs( + workflow_id: str, + page: int = Query(default=1, ge=1), + page_size: int = Query(default=10, ge=1, le=100), + start_date: Optional[datetime] = Query( + default=None, description="Filter runs after this date (inclusive)" + ), + end_date: Optional[datetime] = Query( + default=None, description="Filter runs before this date (inclusive)" + ), + status: Optional[RunStatus] = Query(default=None, description="Filter runs by status"), + db: Session = Depends(get_db), +): + offset = (page - 1) * page_size + query = db.query(RunModel).filter(RunModel.workflow_id == workflow_id) + + # Apply date filters if provided + if start_date: + query = query.filter(RunModel.start_time >= start_date) + if end_date: + query = query.filter(RunModel.start_time <= end_date) + + # Apply status filter if provided + if status: + query = 
query.filter(RunModel.status == status) + + # Order by start time descending and apply pagination + runs = query.order_by(RunModel.start_time.desc()).offset(offset).limit(page_size).all() + + # Update run status based on task status + for run in runs: + if run.status != RunStatus.FAILED: + failed_tasks = [task for task in run.tasks if task.status == TaskStatus.FAILED] + running_and_pending_tasks = [ + task + for task in run.tasks + if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING] + ] + if failed_tasks and len(running_and_pending_tasks) == 0: + run.status = RunStatus.FAILED + db.commit() + db.refresh(run) + + return runs + + +def save_embedded_file(data_uri: str, workflow_id: str) -> str: + """Save a file from a data URI and return its relative path. + + Uses file content hash for the filename to avoid duplicates. + """ + # Extract the base64 data from the data URI + match = re.match(r"data:([^;]+);base64,(.+)", data_uri) + if not match: + raise ValueError("Invalid data URI format") + + mime_type, base64_data = match.groups() + file_data = base64.b64decode(base64_data) + + # Generate hash from file content + file_hash = hashlib.sha256(file_data).hexdigest()[:16] # Use first 16 chars of hash + + # Determine file extension from mime type + ext_map = { + "image/jpeg": ".jpg", + "image/png": ".png", + "image/gif": ".gif", + "application/pdf": ".pdf", + "video/mp4": ".mp4", + "text/plain": ".txt", + "text/csv": ".csv", + } + extension = ext_map.get(mime_type, "") + + # Create filename and ensure directory exists + filename = f"{file_hash}{extension}" + upload_dir = Path("data/run_files") / workflow_id + upload_dir.mkdir(parents=True, exist_ok=True) + + # Save the file + file_path = upload_dir / filename + with open(file_path, "wb") as f: + f.write(file_data) + + return f"run_files/{workflow_id}/{filename}" + + +def get_paused_workflows( + db: Session, + page: int = 1, + page_size: int = 10, +) -> List[PausedWorkflowResponseSchema]: + """Get all currently paused 
workflows.""" + # First get runs with paused tasks + paused_task_runs = ( + db.query(TaskModel.run_id).filter(TaskModel.status == TaskStatus.PAUSED).distinct() + ) + + # Then get runs with running tasks + running_task_runs = ( + db.query(TaskModel.run_id).filter(TaskModel.status == TaskStatus.RUNNING).distinct() + ) + + # Main query to get paused runs + paused_runs = ( + db.query(RunModel) + .filter( + # Either the run is marked as paused + (RunModel.status == RunStatus.PAUSED) + | + # Or has paused tasks but no running tasks + ( + RunModel.id.in_(paused_task_runs.scalar_subquery()) + & ~RunModel.id.in_(running_task_runs.scalar_subquery()) + ) + ) + .order_by(RunModel.start_time.desc()) + .offset((page - 1) * page_size) + .limit(page_size) + .all() + ) + + # Build response with workflow definitions + result: List[PausedWorkflowResponseSchema] = [] + for run in paused_runs: + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == run.workflow_id).first() + if not workflow: + continue + + workflow_definition = WorkflowDefinitionSchema.model_validate(workflow.definition) + + # Find the current pause information from tasks + current_pause = None + if run.tasks: + # Find the most recently paused task + paused_tasks = [task for task in run.tasks if task.status == TaskStatus.PAUSED] + if paused_tasks: + # Sort by end_time descending to get the most recent pause + paused_tasks.sort( + key=lambda x: (x.end_time or x.start_time or datetime.min).replace( + tzinfo=timezone.utc + ), + reverse=True, + ) + latest_paused_task = paused_tasks[0] + + # Only create pause history if we have a pause time + pause_time = latest_paused_task.end_time or latest_paused_task.start_time + if pause_time: + # Ensure timezone is set + if pause_time.tzinfo is None: + pause_time = pause_time.replace(tzinfo=timezone.utc) + + current_pause = PauseHistoryResponseSchema( + id=f"PH_{run.id}_{latest_paused_task.node_id}", + run_id=run.id, + node_id=latest_paused_task.node_id, + 
pause_message=latest_paused_task.error or "Human intervention required", + pause_time=pause_time, + resume_time=latest_paused_task.end_time.replace(tzinfo=timezone.utc) + if latest_paused_task.end_time + else None, + resume_user_id=None, # This would come from task metadata if needed + resume_action=None, # This would come from task metadata if needed + input_data=latest_paused_task.inputs or {}, + comments=None, # This would come from task metadata if needed + ) + + if current_pause: + result.append( + PausedWorkflowResponseSchema( + run=RunResponseSchema.model_validate(run), + current_pause=current_pause, + workflow=workflow_definition, + ) + ) + + return result + + +def get_run_pause_history(db: Session, run_id: str) -> List[PauseHistoryResponseSchema]: + """Get the pause history for a specific run.""" + run = db.query(RunModel).filter(RunModel.id == run_id).first() + if not run: + raise HTTPException(status_code=404, detail="Run not found") + + # Build pause history from tasks + history: List[PauseHistoryResponseSchema] = [] + + if run.tasks: + # Get all tasks that were ever paused + paused_tasks = [task for task in run.tasks if task.status == TaskStatus.PAUSED] + for task in paused_tasks: + # Skip if no pause time + pause_time = task.end_time or task.start_time + if not pause_time: + continue + + # Ensure timezone is set + if pause_time.tzinfo is None: + pause_time = pause_time.replace(tzinfo=timezone.utc) + + history.append( + PauseHistoryResponseSchema( + id=f"PH_{run.id}_{task.node_id}", + run_id=run.id, + node_id=task.node_id, + pause_message=task.error or "Human intervention required", + pause_time=pause_time, + resume_time=task.end_time.replace(tzinfo=timezone.utc) + if task.end_time + else None, + resume_user_id=None, # This would come from task metadata if needed + resume_action=None, # This would come from task metadata if needed + input_data=task.inputs or {}, + comments=None, # This would come from task metadata if needed + ) + ) + + return 
sorted(history, key=lambda x: x.pause_time, reverse=True) + + +def _get_and_validate_paused_run( + db: Session, + run_id: str, +) -> Tuple[RunModel, TaskModel]: + """Get the paused run and validate its state. + + Args: + db: Database session + run_id: The ID of the paused run + + Returns: + Tuple of (run, paused_task) + + Raises: + HTTPException: If the run is not found or not in a paused state + + """ + # Get the run + run = db.query(RunModel).filter(RunModel.id == run_id).first() + if not run: + raise HTTPException(status_code=404, detail="Run not found") + + if run.status != RunStatus.PAUSED: + # Check if there are any paused tasks + has_paused_tasks = any(task.status == TaskStatus.PAUSED for task in run.tasks) + if not has_paused_tasks: + raise HTTPException(status_code=400, detail="Run is not in a paused state") + + # Find the paused task + paused_task = None + for task in run.tasks: + if task.status == TaskStatus.PAUSED: + paused_task = task + break + + if not paused_task: + raise HTTPException(status_code=400, detail="No paused task found for this run") + + return run, paused_task + + +def _update_paused_task( + db: Session, + run: RunModel, + paused_task: TaskModel, + action_request: ResumeRunRequestSchema, +) -> None: + """Update the paused task with the action. 
+ + Args: + db: Database session + run: The run model + paused_task: The paused task to update + action_request: The action request + + """ + # Update the task with the action + paused_task.end_time = datetime.now(timezone.utc) + paused_task.status = TaskStatus.COMPLETED # Mark as COMPLETED instead of RUNNING + paused_task.error = None # Clear any error message + paused_task.outputs = action_request.inputs # Store new inputs as outputs + + # Delete any pending tasks for the same node + # This prevents duplicate tasks when the workflow is resumed + pending_tasks = ( + db.query(TaskModel) + .filter( + TaskModel.run_id == run.id, + TaskModel.node_id == paused_task.node_id, + TaskModel.status == TaskStatus.PENDING, + ) + .all() + ) + + for pending_task in pending_tasks: + db.delete(pending_task) + + db.commit() + db.refresh(run) + + +def _setup_workflow_executor( + db: Session, + run: RunModel, + paused_task: TaskModel, +) -> Tuple[WorkflowExecutor, WorkflowDefinitionSchema, WorkflowExecutionContext]: + """Set up the workflow executor for resuming the workflow. 
+ + Args: + db: Database session + run: The run model + paused_task: The paused task + + Returns: + Tuple of (executor, workflow_definition, context) + + Raises: + HTTPException: If the workflow is not found + + """ + # Get the workflow + workflow = db.query(WorkflowModel).filter(WorkflowModel.id == run.workflow_id).first() + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + + workflow_version = fetch_workflow_version(run.workflow_id, workflow, db) + workflow_definition = WorkflowDefinitionSchema.model_validate(workflow_version.definition) + + # Update run status to RUNNING + run.status = RunStatus.RUNNING + db.commit() + + # Create a new task recorder and context + task_recorder = TaskRecorder(db, run.id) + context = WorkflowExecutionContext( + workflow_id=workflow.id, + run_id=run.id, + parent_run_id=run.parent_run_id, + run_type=run.run_type, + db_session=db, + workflow_definition=workflow_version.definition, + ) + + # Create executor with the existing workflow definition - pass the paused node ID as resumed + executor = WorkflowExecutor( + workflow=workflow_definition, + task_recorder=task_recorder, + context=context, + resumed_node_ids=[paused_task.node_id], # Tell executor which node was resumed + ) + + return executor, workflow_definition, context + + +def _update_executor_outputs( + executor: WorkflowExecutor, + run: RunModel, + paused_task: TaskModel, + action_request: ResumeRunRequestSchema, + workflow_definition: WorkflowDefinitionSchema, +) -> None: + """Update the executor outputs with existing outputs and resume information. 
+ + Args: + executor: The workflow executor + run: The run model + paused_task: The paused task + action_request: The action request + workflow_definition: The workflow definition + + """ + # Update the outputs with existing outputs + if run.outputs: + executor.outputs = { + k: NodeFactory.create_node( + node_name=node.title, + node_type_name=node.node_type, + config=node.config, + ).output_model.model_validate(v) + for k, v in run.outputs.items() + for node in workflow_definition.nodes + if node.id == k + } + + # Update the paused node's output with resume information + if paused_task.node_id and paused_task.node_id in executor.outputs: + node_output = executor.outputs[paused_task.node_id] + if isinstance(node_output, HumanInterventionNodeOutput): + # Create a properly structured output for the HumanInterventionNode + # First, gather the action request inputs + inputs_data = {} + + # If we have task inputs, include them in the structure + if paused_task.inputs and isinstance(paused_task.inputs, dict): + inputs_data.update(paused_task.inputs) # type: ignore + + # Add the new inputs from the action request + # This ensures downstream nodes can access values via HumanInterventionNode_1.input_1 + if action_request.inputs: + inputs_data.update(action_request.inputs) # type: ignore + + # Create the output with the proper structure - don't nest under input_node + # This makes fields directly accessible + # in templates like {{HumanInterventionNode_1.input_1}} + updated_output = HumanInterventionNodeOutput(**inputs_data) + + # For debugging + print(f"Updated HumanInterventionNodeOutput structure: {updated_output}") + + executor.outputs[paused_task.node_id] = updated_output + + +def process_pause_action( + db: Session, + run_id: str, + action_request: ResumeRunRequestSchema, + bg_tasks: Optional[BackgroundTasks] = None, +) -> RunResponseSchema: + """Process an action on a paused workflow. + + This is the common function used by the take_pause_action endpoint. 
+ It handles the core logic for processing human intervention in paused workflows. + + The workflow_id is retrieved from the run object, + so it doesn't need to be passed as a parameter. + + Args: + db: Database session + run_id: The ID of the paused run + action_request: The details of the action to take + bg_tasks: Optional background tasks handler to resume the workflow asynchronously + + Returns: + Information about the resumed run + + Raises: + HTTPException: If the run is not found or not in a paused state + + """ + # Get and validate the run and paused task + run, paused_task = _get_and_validate_paused_run(db, run_id) + + # Update the paused task + _update_paused_task(db, run, paused_task, action_request) + + # If background_tasks is provided, automatically resume the workflow + if bg_tasks: + # Setup the workflow executor + executor, workflow_definition, context = _setup_workflow_executor(db, run, paused_task) + + # Update executor outputs + _update_executor_outputs(executor, run, paused_task, action_request, workflow_definition) + + # Define the async workflow task and add it to background tasks + bg_tasks.add_task( + _create_resume_workflow_task(executor, run, paused_task, context, action_request, db) + ) + + response = RunResponseSchema.model_validate(run) + response.message = "Task completed and workflow execution resumed automatically." + else: + # If no background_tasks, just return as before + response = RunResponseSchema.model_validate(run) + response.message = ( + "Task marked as completed. Please call the resume" + "endpoint to continue workflow execution." + ) + + return response + + +def _create_resume_workflow_task( # noqa: C901 + executor: WorkflowExecutor, + run: RunModel, + paused_task: TaskModel, + context: WorkflowExecutionContext, + action_request: ResumeRunRequestSchema, + db: Session, +) -> Callable[[], Coroutine[Any, Any, None]]: + """Create the async function for resuming the workflow. 
+ + Args: + executor: The workflow executor + run: The run model + paused_task: The paused task + context: The workflow execution context + action_request: The action request + db: Database session + + Returns: + Async function to resume the workflow + + """ + + async def resume_workflow_task(): # noqa: C901 + try: + # Find any PENDING tasks that were blocked by the paused node + blocked_node_ids: set[str] = set() + if _workflow_definition := getattr(context, "workflow_definition", None): + blocked_node_ids = executor.get_blocked_nodes(paused_task.node_id) + + # Update their status to RUNNING + for task in run.tasks: + if task.status == TaskStatus.PENDING and task.node_id in blocked_node_ids: + task.status = TaskStatus.RUNNING + task.start_time = datetime.now(timezone.utc) + db.commit() + + # Convert outputs to dict format for precomputed_outputs + precomputed: Dict[str, Union[Dict[str, Any], List[Dict[str, Any]]]] = {} + for k, v in executor.outputs.items(): + if v is not None: + try: + precomputed[k] = v.model_dump() + except Exception: + continue + + # Get all nodes including blocked nodes and the resumed node + # We specifically don't include the paused node + # in nodes_to_run since it's already been completed + nodes_to_run: set[str] = blocked_node_ids + + # IMPORTANT: Add the paused node ID to the executor's resumed_node_ids set + # This prevents the executor from creating a new task for this node + executor.add_resumed_node_id(paused_task.node_id) + + # Also add any node IDs that already have COMPLETED tasks + # This prevents the executor from creating new tasks for these nodes + completed_node_ids: set[str] = set() + for task in run.tasks: + if task.status == TaskStatus.COMPLETED: + completed_node_ids.add(task.node_id) + executor.add_resumed_node_id(task.node_id) + + # Make sure we include any necessary node inputs in the precomputed outputs + if action_request.inputs and paused_task.node_id: + # Set the action_request.inputs as the output for the paused 
task + # When we updated the paused node's output above, we already + # created the proper HumanInterventionNodeOutput structure + # So we just need to make sure it's formatted properly for precomputed_outputs + node_output = executor.outputs.get(paused_task.node_id) + if node_output: + # Use model_dump to get the flat structure of fields + precomputed[paused_task.node_id] = node_output.model_dump() + else: + # Fallback - use the inputs directly + precomputed[paused_task.node_id] = action_request.inputs + + # Run the workflow with the precomputed outputs + outputs = await executor.run( + input={}, # Input already provided in initial run + node_ids=list(nodes_to_run), # Run the blocked nodes + precomputed_outputs=precomputed, # Use existing outputs plus our human input + ) + + # Create a dictionary of outputs - keep existing outputs and add new ones + if run.outputs: + combined_outputs = run.outputs + for k, v in outputs.items(): + combined_outputs[k] = v.model_dump() + run.outputs = combined_outputs + else: + run.outputs = {k: v.model_dump() for k, v in outputs.items()} + + run.status = RunStatus.COMPLETED + run.end_time = datetime.now(timezone.utc) + except Exception as e: + run.status = RunStatus.FAILED + run.end_time = datetime.now(timezone.utc) + logger.error(f"Error resuming workflow: {e}") + db.commit() + + return resume_workflow_task + + +@router.post( + "/cancel_workflow/{run_id}/", + response_model=RunResponseSchema, + description="Cancel a workflow that is awaiting human approval", +) +def cancel_workflow( + run_id: str, + db: Session = Depends(get_db), +) -> RunResponseSchema: + """Cancel a workflow that is currently paused or awaiting human approval. + + This will mark the run as CANCELED in the database and update all pending tasks + to CANCELED as well. 
+ + Args: + run_id: The ID of the run to cancel + db: Database session dependency + + Returns: + Information about the canceled run + + Raises: + HTTPException: If the run is not found or not in a state that can be canceled + + """ + # Get the run + run = db.query(RunModel).filter(RunModel.id == run_id).first() + if not run: + raise HTTPException(status_code=404, detail="Run not found") + + # Check if the run is in a state that can be canceled + if run.status not in [RunStatus.PAUSED, RunStatus.RUNNING]: + raise HTTPException( + status_code=400, + detail=( + f"Run is in state {run.status} and cannot be canceled." + "Only PAUSED or RUNNING runs can be canceled." + ), + ) + + # Update the run status + run.status = RunStatus.CANCELED + run.end_time = datetime.now(timezone.utc) + + # Update all pending and running tasks to canceled + for task in run.tasks: + if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING, TaskStatus.PAUSED]: + task.status = TaskStatus.CANCELED + if not task.end_time: + task.end_time = datetime.now(timezone.utc) + + # Commit the changes + db.commit() + db.refresh(run) + + # Return the updated run + response = RunResponseSchema.model_validate(run) + response.message = "Workflow has been canceled successfully." 
+ return response diff --git a/pyspur/backend/pyspur/cli/__init__.py b/pyspur/backend/pyspur/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e783c42bb6d25a9a21c97519218e813231a799f --- /dev/null +++ b/pyspur/backend/pyspur/cli/__init__.py @@ -0,0 +1,5 @@ +"""PySpur CLI package.""" + +from .main import main + +__all__ = ["main"] diff --git a/pyspur/backend/pyspur/cli/main.py b/pyspur/backend/pyspur/cli/main.py new file mode 100644 index 0000000000000000000000000000000000000000..0825ad0cf9d677bf8e5c524879adc05787480b5f --- /dev/null +++ b/pyspur/backend/pyspur/cli/main.py @@ -0,0 +1,187 @@ +"""Main module for the PySpur CLI.""" + +import os +import shutil +from importlib.metadata import version as get_version +from pathlib import Path +from typing import Optional + +import typer +import uvicorn +from rich import print +from rich.console import Console + +from .utils import copy_template_file, load_environment, run_migrations + +app = typer.Typer( + name="pyspur", + help="PySpur CLI - A tool for building and deploying AI Agents", + add_completion=False, +) + +console = Console() + + +@app.command(name="version") +def show_version() -> None: + """Display the current version of PySpur.""" + try: + ver = get_version("pyspur") + print(f"PySpur version: [bold green]{ver}[/bold green]") + except ImportError: + print("[yellow]PySpur version: unknown (package not installed)[/yellow]") + + +@app.command() +def init( + path: Optional[str] = typer.Argument( + None, + help="Path where to initialize PySpur project. 
Defaults to current directory.", + ), +) -> None: + """Initialize a new PySpur project in the specified directory.""" + target_dir = Path(path) if path else Path.cwd() + + if not target_dir.exists(): + target_dir.mkdir(parents=True) + + # Copy .env.example + try: + copy_template_file(".env.example", target_dir / ".env.example") + print("[green]✓[/green] Created .env.example") + + # Create .env if it doesn't exist + env_path = target_dir / ".env" + if not env_path.exists(): + shutil.copy2(target_dir / ".env.example", env_path) + print("[green]✓[/green] Created .env from template") + + # add PROJECT_ROOT to .env + # Check if PROJECT_ROOT is already defined in .env + with open(env_path, "r") as f: + if "PROJECT_ROOT=" not in f.read(): + with open(env_path, "a") as f: + f.write("\n# ================================") + f.write("\n# PROJECT_ROOT: DO NOT CHANGE THIS VALUE") + f.write("\n# ================================") + f.write("\nPROJECT_ROOT=" + str(target_dir) + "\n") + + # add __init__.py to the project directory + init_file_path = target_dir / "__init__.py" + if not init_file_path.exists(): + with open(init_file_path, "w") as f: + f.write("# This is an empty __init__.py file") + print("[green]✓[/green] Created __init__.py") + + custom_dirs = { + "data": target_dir / "data", + "tools": target_dir / "tools", + "spurs": target_dir / "spurs", + } + # Create custom directories + for dir_name, dir_path in custom_dirs.items(): + if not dir_path.exists(): + dir_path.mkdir() + print(f"[green]✓[/green] Created {dir_name} directory") + + # add __init__.py to the tools and spurs directories + for dir_name, dir_path in custom_dirs.items(): + if dir_name in ["tools", "spurs"]: + init_file_path = dir_path / "__init__.py" + if not init_file_path.exists(): + with open(init_file_path, "w") as f: + f.write("# This is an empty __init__.py file") + print(f"[green]✓[/green] Created {dir_name}/__init__.py") + + # add .gitignore to the project, if it doesn't exist + # if it exists, 
add data/ and .env to it + gitignore_path = target_dir / ".gitignore" + if not gitignore_path.exists(): + with open(gitignore_path, "w") as f: + f.write("# PySpur project\n") + f.write("data/\n") + f.write(".env\n") + else: + with open(gitignore_path, "a") as f: + f.write("data/\n") + f.write(".env\n") + print("[green]✓[/green] Created a .gitignore file") + + print("\n[bold green]PySpur project initialized successfully! 🚀[/bold green]") + print("\nNext steps:") + print("1. Review and update the .env file with your configuration") + print("2. For quick protoype: start the PySpur server with 'pyspur serve --sqlite'") + print( + "3. For production:\n" + " a. Provide a PostgreSQL database details in the .env file\n" + " b. Start the server with 'pyspur serve'" + ) + + print( + "[yellow]Note: We collect anonymous telemetry data that helps us improve PySpur." + " You can disable this by setting DISABLE_ANONYMOUS_TELEMETRY=true in .env[/yellow]" + ) + + except Exception as e: + print(f"[red]Error initializing project: {str(e)}[/red]") + raise typer.Exit(1) from e + + +@app.command() +def serve( + host: str = typer.Option( + None, + help="Host to bind the server to. Defaults to PYSPUR_HOST from environment or 0.0.0.0", + ), + port: int = typer.Option( + None, + help="Port to bind the server to. Defaults to PYSPUR_PORT from environment or 6080", + ), + sqlite: bool = typer.Option( + False, + help="Use SQLite database instead of PostgreSQL. 
Useful for local development.", + ), +) -> None: + """Start the PySpur server.""" + try: + # Load environment variables + load_environment() + + # Use environment variables as defaults if not provided via CLI + host = host or os.getenv("PYSPUR_HOST", "0.0.0.0") + port = port or int(os.getenv("PYSPUR_PORT", "6080")) + + if sqlite: + print("[yellow]Using SQLite database for local development...[/yellow]") + os.environ["SQLITE_OVERRIDE_DATABASE_URL"] = "sqlite:///./pyspur.db" + + # Run database migrations + print("[yellow]Running database migrations...[/yellow]") + run_migrations() + + if os.getenv("DISABLE_ANONYMOUS_TELEMETRY", "false").lower() != "true": + print( + "[yellow]Note: We collect anonymous telemetry data that helps us improve PySpur." + " You can disable this by setting DISABLE_ANONYMOUS_TELEMETRY=true in .env[/yellow]" + ) + + # Start the server + print(f"\n[green]Starting PySpur server at http://{host}:{port} 🚀[/green]") + uvicorn.run( + "pyspur.api.main:app", + host=host, + port=port, + ) + + except Exception as e: + print(f"[red]Error starting server: {str(e)}[/red]") + raise typer.Exit(1) from e + + +def main() -> None: + """PySpur CLI.""" + app() + + +if __name__ == "__main__": + main() diff --git a/pyspur/backend/pyspur/cli/utils.py b/pyspur/backend/pyspur/cli/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..77b5f129388781fb1ffda3451325e8a8087b54ba --- /dev/null +++ b/pyspur/backend/pyspur/cli/utils.py @@ -0,0 +1,141 @@ +"""Utility functions for the PySpur CLI.""" + +import shutil +import tempfile +from importlib import resources +from pathlib import Path + +import typer +from alembic import command +from alembic.config import Config +from alembic.runtime.migration import MigrationContext +from dotenv import load_dotenv +from rich import print +from sqlalchemy import text + + +def copy_template_file(template_name: str, dest_path: Path) -> None: + """Copy a template file from the package templates directory to the 
destination.""" + with resources.files("pyspur.templates").joinpath(template_name).open("rb") as src: + with open(dest_path, "wb") as dst: + shutil.copyfileobj(src, dst) + + +def load_environment() -> None: + """Load environment variables from .env file with fallback to .env.example.""" + env_path = Path.cwd() / ".env" + if env_path.exists(): + load_dotenv(env_path) + print("[green]✓[/green] Loaded configuration from .env") + else: + with resources.files("pyspur.templates").joinpath(".env.example").open() as f: + load_dotenv(stream=f) + print( + "[yellow]![/yellow] No .env file found," + " using default configuration from .env.example" + ) + print("[yellow]![/yellow] Run 'pyspur init' to create a customizable .env file") + + +def run_migrations() -> None: + """Run database migrations using SQLAlchemy.""" + try: + # ruff: noqa: F401 + from ..database import database_url, engine + from ..models.base_model import BaseModel + + # Import models + from ..models.dataset_model import DatasetModel # type: ignore + from ..models.dc_and_vi_model import ( + DocumentCollectionModel, # type: ignore + VectorIndexModel, # type: ignore + ) + from ..models.eval_run_model import EvalRunModel # type: ignore + from ..models.output_file_model import OutputFileModel # type: ignore + from ..models.run_model import RunModel # type: ignore + from ..models.slack_agent_model import SlackAgentModel # type: ignore + from ..models.task_model import TaskModel # type: ignore + from ..models.user_session_model import ( + MessageModel, # type: ignore + SessionModel, # type: ignore + UserModel, # type: ignore + ) + from ..models.workflow_model import WorkflowModel # type: ignore + from ..models.workflow_version_model import WorkflowVersionModel # type: ignore + # Import all models to ensure they're registered with SQLAlchemy + + # Test connection + with engine.connect() as conn: + conn.execute(text("SELECT 1")) + print("[green]✓[/green] Connected to database") + + # If using SQLite, create the 
database file if it doesn't exist + if database_url.startswith("sqlite"): + try: + BaseModel.metadata.create_all(engine) + print("[green]✓[/green] Created SQLite database") + print(f"[green]✓[/green] Database URL: {database_url}") + # Print all tables in the database + tables = BaseModel.metadata.tables + if tables: + print("\n[green]✓[/green] Successfully initialized SQLite database") + else: + print("[red]![/red] SQLite database is empty") + raise typer.Exit(1) + print("[yellow]![/yellow] SQLite database is not recommended for production") + print("[yellow]![/yellow] Please use a postgres instance instead") + return + except Exception: + print("[yellow]![/yellow] SQLite database out of sync, recreating from scratch") + # Ask for confirmation before dropping all tables + confirm = input( + "This will delete all data in the SQLite database. Are you sure? (y/N): " + ) + if confirm.lower() != "y": + print("[yellow]![/yellow] Database recreation cancelled") + print( + "[yellow]![/yellow] Please revert pyspur to the original" + " version that was used to create the database" + ) + print("[yellow]![/yellow] OR use a postgres instance to support migrations") + return + BaseModel.metadata.drop_all(engine) + BaseModel.metadata.create_all(engine) + print("[green]✓[/green] Created SQLite database from scratch") + return + + # For other databases, use Alembic migrations + # Get migration context + context = MigrationContext.configure(conn) + + # Get current revision + current_rev = context.get_current_revision() + + if current_rev is None: + print("[yellow]![/yellow] No previous migrations found, initializing database") + else: + print(f"[green]✓[/green] Current database version: {current_rev}") + + # Get migration scripts directory using importlib.resources + script_location = resources.files("pyspur.models.management.alembic") + if not script_location.is_dir(): + raise FileNotFoundError("Migration scripts not found in package") + + # extract migration scripts directory to a 
temporary location + with ( + tempfile.TemporaryDirectory() as script_temp_dir, + resources.as_file(script_location) as script_location_path, + ): + shutil.copytree(script_location_path, Path(script_temp_dir), dirs_exist_ok=True) + # Create Alembic config programmatically + config = Config() + config.set_main_option("script_location", str(script_temp_dir)) + config.set_main_option("sqlalchemy.url", database_url) + + # Run upgrade to head + command.upgrade(config, "head") + print("[green]✓[/green] Database schema is up to date") + + except Exception as e: + print(f"[red]Error running migrations: {str(e)}[/red]") + raise typer.Exit(1) from e diff --git a/pyspur/backend/pyspur/database.py b/pyspur/backend/pyspur/database.py new file mode 100644 index 0000000000000000000000000000000000000000..32b443277a5acd512546de9b8eeb2df5984b27df --- /dev/null +++ b/pyspur/backend/pyspur/database.py @@ -0,0 +1,42 @@ +import os +from typing import Iterator + +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker + +# Get the database URL from the environment +POSTGRES_USER = os.getenv("POSTGRES_USER") +POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD") +POSTGRES_HOST = os.getenv("POSTGRES_HOST") +POSTGRES_PORT = os.getenv("POSTGRES_PORT") +POSTGRES_DB = os.getenv("POSTGRES_DB") + +database_url = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}" + +sqlite_override_database_url = os.getenv("SQLITE_OVERRIDE_DATABASE_URL") +if sqlite_override_database_url: + database_url = sqlite_override_database_url + +# Create the SQLAlchemy engine +engine = create_engine(database_url) + +# Create a configured "Session" class +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +def get_db() -> Iterator[Session]: + """Get a database connection.""" + db = SessionLocal() + try: + yield db + finally: + db.close() + + +def is_db_connected() -> bool: + """Check if the database is connected.""" + 
try: + engine.connect() + return True + except Exception: + return False diff --git a/pyspur/backend/pyspur/dataset/ds_util.py b/pyspur/backend/pyspur/dataset/ds_util.py new file mode 100644 index 0000000000000000000000000000000000000000..e284b4e677fc9445562a21ed18484eb24018b4f3 --- /dev/null +++ b/pyspur/backend/pyspur/dataset/ds_util.py @@ -0,0 +1,46 @@ +from typing import Any, Dict, Iterator, Set + +import pandas as pd + + +def get_ds_column_names( + file_path: str, +) -> Set[str]: + """ + Returns the column names of a pandas compatible dataset file. + """ + if file_path.endswith(".csv"): + df: pd.DataFrame = pd.read_csv(file_path) # type: ignore + elif file_path.endswith(".parquet"): + df: pd.DataFrame = pd.read_parquet(file_path) + elif file_path.endswith(".jsonl"): + df: pd.DataFrame = pd.read_json(file_path, lines=True) # type: ignore + else: + raise ValueError(f"Unsupported file format: {file_path}") + + # make sure each column name is a string + df.columns = [str(col) for col in df.columns] + + return set(df.columns) + + +def get_ds_iterator( + file_path: str, +) -> Iterator[Dict[str, Any]]: + """ + Returns an iterator over the rows of a pandas compatible dataset file. 
+ """ + if file_path.endswith(".csv"): + df: pd.DataFrame = pd.read_csv(file_path) # type: ignore + elif file_path.endswith(".parquet"): + df: pd.DataFrame = pd.read_parquet(file_path) + elif file_path.endswith(".jsonl"): + df: pd.DataFrame = pd.read_json(file_path, lines=True) # type: ignore + else: + raise ValueError(f"Unsupported file format: {file_path}") + + # make sure each column name is a string + df.columns = [str(col) for col in df.columns] + + for _, row in df.iterrows(): # type: ignore + yield row.to_dict() # type: ignore diff --git a/pyspur/backend/pyspur/evals/README.MD b/pyspur/backend/pyspur/evals/README.MD new file mode 100644 index 0000000000000000000000000000000000000000..f336dfbc5fd6f23d83f1e96598617579d6b6f151 --- /dev/null +++ b/pyspur/backend/pyspur/evals/README.MD @@ -0,0 +1,19 @@ +# Supported data formats + +* HuggingFace Datasets +* Link to CSV +* Blobfile + +# Example tasks + +- [X] GSM8K +- [X] GPQA +- [X] MATH +- [X] MMLU + +# Tips + +* When specifying the regex in a task yaml file, make sure you escape backlashes. + * Option 1: Double quotes with escaped backslashes: + * Option 2: Single quotes (which automatically escapes special characters) + * Do not use a raw string literal `r` like in Python \ No newline at end of file diff --git a/pyspur/backend/pyspur/evals/common.py b/pyspur/backend/pyspur/evals/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2ee8b9cac25be49a8b143e400e97c518bd0c9a65 --- /dev/null +++ b/pyspur/backend/pyspur/evals/common.py @@ -0,0 +1,151 @@ +import re +from typing import Optional + +import numpy as np + +QUERY_TEMPLATE_MULTICHOICE = """ +Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. 
+ +{Question} + +A) {A} +B) {B} +C) {C} +D) {D} +""".strip() + +ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])" + + +EQUALITY_TEMPLATE = r""" +Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications + +Examples: + + Expression 1: $2x+3$ + Expression 2: $3+2x$ + +Yes + + Expression 1: 3/2 + Expression 2: 1.5 + +Yes + + Expression 1: $x^2+2x+1$ + Expression 2: $y^2+2y+1$ + +No + + Expression 1: $x^2+2x+1$ + Expression 2: $(x+1)^2$ + +Yes + + Expression 1: 3245/5 + Expression 2: 649 + +No +(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications) + + Expression 1: 2/(-3) + Expression 2: -2/3 + +Yes +(trivial simplifications are allowed) + + Expression 1: 72 degrees + Expression 2: 72 + +Yes +(give benefit of the doubt to units) + + Expression 1: 64 + Expression 2: 64 square feet + +Yes +(give benefit of the doubt to units) + +--- + +YOUR TASK + + +Respond with only "Yes" or "No" (without quotes). Do not include a rationale. + + Expression 1: %(expression1)s + Expression 2: %(expression2)s +""".strip() + + +def format_multichoice_question(row): + return QUERY_TEMPLATE_MULTICHOICE.format(**row) + + +def _compute_stat(values: list, stat: str): + if stat == "mean": + return np.mean(values) + elif stat == "std": + return np.std(values) + elif stat == "min": + return np.min(values) + elif stat == "max": + return np.max(values) + else: + raise ValueError(f"Unknown {stat =}") + + +def normalize_response(response: str) -> str: + """ + Normalize the response by removing markdown and LaTeX formatting that may prevent a match. 
+ """ + + return ( + response.replace("**", "") + .replace("$\\boxed{", "") + .replace("}$", "") + .replace("\\$", "") + .replace("$\\text{", "") + .replace("$", "") + .replace("\\mathrm{", "") + .replace("\\{", "") + .replace("\\text", "") + .replace("\\(", "") + .replace("\\mathbf{", "") + .replace("{", "") + .replace("\\boxed", "") + ) + + +def normalize_extracted_answer(extracted_answer: str) -> str: + return ( + # In arabic these are the letters used for A-D in multiple choice questions + extracted_answer.replace("أ", " A") + .replace("ب", " B") + .replace("ج", " C") + .replace("د", " D") + # In Bengali these are the letters used for A-D in multiple choice questions + .replace("অ", " A") + .replace("ব", " B") + .replace("ড", " C") + .replace("ঢ", " D") + # In Japanese these are the letters sometimes used for A-D in multiple choice questions + .replace("A", " A") + .replace("B", " B") + .replace("C", " C") + .replace("D", " D") + .strip() + ) + + +def extract_answer_with_regex(text: str, regexes: Optional[list[str]] = None) -> str: + """ + Extracts the answer from the text using a regex search. 
+ """ + regexes = regexes or [] + for regex in regexes: + match = re.search(regex, text) + if match: + extracted_answer = match.group(1) + return extracted_answer + return text diff --git a/pyspur/backend/pyspur/evals/evaluator.py b/pyspur/backend/pyspur/evals/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..d1a5e501fb5f551b496e47c2b7ab77df691a7de1 --- /dev/null +++ b/pyspur/backend/pyspur/evals/evaluator.py @@ -0,0 +1,642 @@ +# inspired by https://github.com/google-deepmind/gemma/blob/main/colabs/gsm8k_eval.ipynb +import argparse +import asyncio +import importlib.util +import os +import re +from typing import Any, Callable, Dict, List, Optional, Union + +import pandas as pd +import yaml +from datasets import Dataset, load_dataset +from jinja2 import Template + +from ..evals.common import EQUALITY_TEMPLATE, normalize_extracted_answer +from ..execution.workflow_executor import WorkflowExecutor +from ..schemas.workflow_schemas import WorkflowDefinitionSchema + +# Precompiled regular expressions +NUMBER_REGEX = re.compile(r"-?[\d,]*\.?\d+", re.MULTILINE | re.DOTALL | re.IGNORECASE) + + +def find_numbers(text: str) -> List[str]: + """Find all numbers in a string.""" + return NUMBER_REGEX.findall(text) + + +def find_number(text: str, answer_delimiter: str = "The answer is") -> str: + """Find the most relevant number in a string.""" + if answer_delimiter in text: + answer = text.split(answer_delimiter)[-1] + numbers = find_numbers(answer) + if numbers: + return numbers[0] + numbers = find_numbers(text) + return numbers[-1] if numbers else "" + + +def maybe_remove_comma(text: str) -> str: + """Remove commas from numbers in a string.""" + return text.replace(",", "") + + +def load_dataset_by_name( + dataset_name: str, + split: str = "test", + subset: Optional[str] = None, + process_docs: Optional[Callable[[Dataset], Dataset]] = None, +) -> Dataset: + """Load a dataset by name or from a CSV file and return the specified split.""" + if 
dataset_name.endswith(".csv"): + dataset = pd.read_csv(dataset_name) + dataset = Dataset.from_pandas(dataset) + else: + dataset_args = {"cache_dir": "/tmp"} + if subset: + dataset = load_dataset(dataset_name, subset, **dataset_args) + else: + dataset = load_dataset(dataset_name, **dataset_args) + dataset = dataset[split] + if process_docs: + dataset = process_docs(dataset) + return dataset + + +# https://github.com/EleutherAI/lm-evaluation-harness/blob/1185e89a044618b5adc6f0b9363b629a19fffdc4/lm_eval/utils.py#L402 +def ignore_constructor(loader, node): + return node + + +# https://github.com/EleutherAI/lm-evaluation-harness/blob/1185e89a044618b5adc6f0b9363b629a19fffdc4/lm_eval/utils.py#L406 +def import_function(loader, node): + function_name = loader.construct_scalar(node) + yaml_path = os.path.dirname(loader.name) + + *module_name, function_name = function_name.split(".") + if isinstance(module_name, list): + module_name = ".".join(module_name) + module_path = os.path.normpath(os.path.join(yaml_path, "{}.py".format(module_name))) + + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + function = getattr(module, function_name) + return function + + +# https://github.com/EleutherAI/lm-evaluation-harness/blob/1185e89a044618b5adc6f0b9363b629a19fffdc4/lm_eval/utils.py#L423 +def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"): + if mode == "simple": + constructor_fn = ignore_constructor + elif mode == "full": + constructor_fn = import_function + + # Add the import_function constructor to the YAML loader + yaml.add_constructor("!function", constructor_fn) + if yaml_config is None: + with open(yaml_path, "rb") as file: + yaml_config = yaml.full_load(file) + + if yaml_dir is None: + yaml_dir = os.path.dirname(yaml_path) + + assert yaml_dir is not None + + if "include" in yaml_config: + include_path = yaml_config["include"] + del 
yaml_config["include"] + + if isinstance(include_path, str): + include_path = [include_path] + + # Load from the last one first + include_path.reverse() + final_yaml_config = {} + for path in include_path: + # Assumes that path is a full path. + # If not found, assume the included yaml + # is in the same dir as the original yaml + if not os.path.isfile(path): + path = os.path.join(yaml_dir, path) + + try: + included_yaml_config = load_yaml_config(yaml_path=path, mode=mode) + final_yaml_config.update(included_yaml_config) + except Exception as ex: + # If failed to load, ignore + raise ex + + final_yaml_config.update(yaml_config) + return final_yaml_config + return yaml_config + + +def generate_input_prompt(problem: dict, doc_to_text: str, preamble: str) -> str: + """Generate the input prompt for the model.""" + question_text = Template(doc_to_text).render(**problem) + full_prompt = f"{preamble}\n\n{question_text}" + return full_prompt.strip() + + +async def check_equality(expr1: str, expr2: str) -> bool: + """ + Check if two expressions are equal by using the call_model function. + + Args: + expr1 (str): The first expression. + expr2 (str): The second expression. + + Returns: + bool: True if expressions are equal, False otherwise. + """ + prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2} + # response = await call_model(prompt) + # return response.lower().strip() == "yes" + # TODO (jean): Implement equality check using simple LLM + return True + + +def extract_output_variable(outputs: dict, workflow_output_variable: str) -> Any: + """ + Extract the output value from the outputs dictionary based on the workflow_output_variable. + + Args: + outputs (dict): The dictionary containing workflow outputs. + workflow_output_variable (str): The variable to extract, potentially nested. + + Returns: + Any: The extracted output value. + + Raises: + ValueError: If the specified variable is not found or the structure is invalid. 
+ """ + current_level = outputs + remaining_variable = workflow_output_variable + + while remaining_variable: + # Find the longest matching key in the current level + matched_key = None + for key in sorted(current_level.keys(), key=len, reverse=True): + if remaining_variable.startswith(key): + matched_key = key + break + + if not matched_key: + # Debugging: Log available keys for better error diagnosis + print(f"Current level keys: {list(current_level.keys())}") + raise ValueError( + f"Key '{remaining_variable}' not found in the current level of outputs" + ) + + # Descend into the matched key + current_level = current_level[matched_key] + + # Remove the matched key and the delimiter from the remaining variable + remaining_variable = remaining_variable[len(matched_key) :] + if remaining_variable.startswith("-"): + remaining_variable = remaining_variable[1:] + + # If the remaining variable is empty, return the current level + if not remaining_variable: + return current_level + + # If the remaining variable is a single key and the current level is a dictionary, + # directly return the value if it exists + if isinstance(current_level, dict) and remaining_variable in current_level: + return current_level[remaining_variable] + + return current_level + + +async def execute_workflow( + full_prompt: str, + workflow_definition: Optional[WorkflowDefinitionSchema] = None, + workflow_output_variable: Optional[str] = None, +) -> str: + """ + Executes an LLM workflow. 
+ + Args: + full_prompt: The prompt to send + workflow: Optional workflow definition to execute + workflow_output_variable: The output variable to extract from workflow results + + Returns: + str: The model's response text + """ + if workflow_definition is None: + raise ValueError("Workflow definition is required") + + # Find input node - we know workflows must have exactly one InputNode + input_node = next( + (node for node in workflow_definition.nodes if node.node_type == "InputNode"), + None, + ) + if input_node is None: + raise ValueError("Workflow must have an InputNode") + + # Extract input schema from the InputNode + input_schema = input_node.config.get("input_schema", {}) + initial_inputs = {input_node.id: {key: full_prompt for key in input_schema.keys()}} + + # Execute workflow with error handling + executor = WorkflowExecutor(workflow_definition) + try: + outputs = await executor(initial_inputs) + outputs = {k: v.model_dump() for k, v in outputs.items()} + except Exception as e: + # Log the error for debugging purposes + print(f"Workflow execution failed: {e}") + # Use an empty output to indicate failure + outputs = {} + + # Debugging: Log the outputs dictionary and workflow_output_variable + print(f"Workflow Output Variable: {workflow_output_variable}") + + # Extract output from specified variable using the new function + outputs = extract_output_variable(outputs, workflow_output_variable) + + print(f"Extracted outputs: {outputs}") + return str(outputs) + + +def extract_answer( + text: str, + answer_extraction: Dict[str, Any], +) -> str: + """ + Extracts the answer from text based on extraction logic specified in the configuration. + + Args: + text (str): The text to extract the answer from. + answer_extraction (Dict[str, Any]): Configuration for answer extraction, including functions and regexes. + + Returns: + str: The extracted answer. 
+ """ + if text is None: + return "" + + # Extract regexes and functions from the extraction configuration + regexes = answer_extraction.get("regexes", []) + functions = answer_extraction.get("functions", []) + + extracted_answer = text + + # Dynamically apply the specified string processing functions in order + for func_name in functions: + # Retrieve the function object from the globals() dictionary + func = globals().get(func_name) + if func and callable(func): + extracted_answer = func(extracted_answer) + else: + raise ValueError(f"Function '{func_name}' is not defined or not callable.") + + # Apply regex patterns to extract the relevant portion of the response + for regex in regexes: + match = re.search(regex, extracted_answer, re.IGNORECASE) + if match: + extracted_answer = match.group(1) + break # Stop after the first successful match + + return extracted_answer.strip() + + +async def evaluate_answer(predicted_answer, ground_truth_answer, evaluation: Dict[str, Any]): + """Evaluates if the predicted answer matches the ground truth based on evaluation logic.""" + if predicted_answer is None or ground_truth_answer is None: + return False + + evaluation_method = evaluation.get("method", "default").lower() + if evaluation_method == "numeric": + try: + correct = float(predicted_answer) == float(ground_truth_answer) + except: + correct = predicted_answer == ground_truth_answer + return correct + elif evaluation_method == "exact_match": + return predicted_answer.strip().lower() == ground_truth_answer.strip().lower() + elif evaluation_method == "mcq": + # Normalize both answers before comparison + return ( + normalize_extracted_answer(predicted_answer).strip().upper() + == normalize_extracted_answer(ground_truth_answer).strip().upper() + ) + elif evaluation_method == "math": + print(f"Checking equality between {predicted_answer} and {ground_truth_answer}") + return await check_equality(predicted_answer, ground_truth_answer) + else: + # Default evaluation method + 
return predicted_answer == ground_truth_answer + + +def get_ground_truth_answer(problem, doc_to_target): + """Extracts the ground truth answer using the doc_to_target template.""" + doc_to_target_template = Template(doc_to_target) + ground_truth = doc_to_target_template.render(**problem) + return ground_truth.strip() + + +async def evaluate_dataset_batch( + dataset: Dataset, + eval_config: Dict[str, Any], + workflow_definition: WorkflowDefinitionSchema, + batch_size: int = 10, + subject: Optional[str] = None, + subject_category_mapping: Optional[Dict[str, str]] = None, + category_correct: Optional[Dict[str, int]] = None, + category_total: Optional[Dict[str, int]] = None, + output_variable: Optional[str] = None, +) -> dict: + """ + Evaluate the model on a dataset in batches. + + This function performs the core evaluation logic, including generating prompts, + calling the model, and comparing predictions with ground truth answers. + + Args: + dataset: The dataset to evaluate on. + eval_config: Configuration for the evaluation task. + workflow: Workflow definition to execute. + batch_size: Size of batches for processing. + subject: Optional subject name for categorization. + subject_category_mapping: Optional mapping of subjects to categories. + category_correct: Optional dict tracking correct predictions per category. + category_total: Optional dict tracking total samples per category. + output_variable: Optional output variable name from workflow output. + + Returns: + dict: Evaluation metrics, including accuracy and category-wise performance. 
+ """ + # Extract task configuration + preamble = eval_config.get("preamble", "") + doc_to_text = eval_config.get("doc_to_text", "") + doc_to_target = eval_config.get("doc_to_target", "") + ground_truth_answer_extraction = eval_config.get("ground_truth_answer_extraction", {}) + predicted_answer_extraction = eval_config.get("predicted_answer_extraction", {}) + evaluation = eval_config.get("evaluation", {}) + + # Initialize tracking variables + all_responses = {} + short_responses = {} + total = len(dataset) + correct = 0 + example_id = 0 + per_example_results = [] + + # Initialize category tracking if needed + if subject_category_mapping and category_correct is None and category_total is None: + categories = set(subject_category_mapping.values()) + category_correct = {category: 0 for category in categories} + category_total = {category: 0 for category in categories} + + # Process dataset in batches + for batch in dataset.iter(batch_size=batch_size): + transformed_batch = [dict(zip(batch.keys(), values)) for values in zip(*batch.values())] + full_prompts = [ + generate_input_prompt(problem, doc_to_text, preamble) for problem in transformed_batch + ] + + # Call the model on all prompts in the batch concurrently + responses = await asyncio.gather( + *[ + execute_workflow(prompt, workflow_definition, output_variable) + for prompt in full_prompts + ] + ) + + # Process responses and update metrics + for idx, (problem, full_prompt, response_text) in enumerate( + zip(transformed_batch, full_prompts, responses) + ): + predicted_answer = extract_answer(response_text, predicted_answer_extraction) + ground_truth_answer = extract_answer( + get_ground_truth_answer(problem, doc_to_target), + ground_truth_answer_extraction, + ) + + # Store responses + all_responses[example_id] = response_text + short_responses[example_id] = predicted_answer + + # Evaluate correctness + is_correct = await evaluate_answer(predicted_answer, ground_truth_answer, evaluation) + correct += int(is_correct) 
+ + # Add per-example results + per_example_results.append( + { + "example_id": example_id, + "prompt": full_prompt, + "predicted_answer": predicted_answer, + "ground_truth_answer": ground_truth_answer, + "is_correct": is_correct, + } + ) + + # Update category metrics if needed + if subject_category_mapping: + subject_value = subject or problem.get("subject") or problem.get("Subject") + category = subject_category_mapping.get(subject_value, "other") + category_total[category] += 1 + if is_correct: + category_correct[category] += 1 + + # Log results + print(f"Example ID {example_id}") + print(f"Predicted answer: {predicted_answer}") + print(f"Ground truth answer: {ground_truth_answer}") + print(f"Correct: {is_correct}") + print("=" * 40) + example_id += 1 + + # Calculate final metrics + metrics = { + "total_samples": total, + "correct_predictions": correct, + "accuracy": correct / total, + "all_responses": all_responses, + "short_responses": short_responses, + "per_example_results": per_example_results, + } + + if subject_category_mapping: + category_accuracy = { + category: ( + category_correct[category] / category_total[category] + if category_total[category] > 0 + else 0 + ) + for category in category_correct + } + metrics.update( + { + "category_correct": category_correct, + "category_total": category_total, + "category_accuracy": category_accuracy, + } + ) + + return metrics + + +def calculate_metrics( + total_correct: int, + total_samples: int, + category_correct: Optional[Dict[str, int]] = None, + category_total: Optional[Dict[str, int]] = None, +) -> Dict[str, Any]: + """ + Calculate overall and category-wise metrics. + + Args: + total_correct (int): Total number of correct predictions. + total_samples (int): Total number of samples evaluated. + category_correct (Optional[Dict[str, int]]): Correct predictions per category. + category_total (Optional[Dict[str, int]]): Total samples per category. 
+ + Returns: + Dict[str, Any]: A dictionary containing overall accuracy and category-wise accuracy. + """ + overall_accuracy = total_correct / total_samples if total_samples > 0 else 0 + metrics = { + "total_samples": total_samples, + "correct_predictions": total_correct, + "accuracy": overall_accuracy, + } + + if category_correct and category_total: + category_accuracy = { + category: ( + category_correct[category] / category_total[category] + if category_total[category] > 0 + else 0 + ) + for category in category_correct + } + metrics["category_accuracy"] = category_accuracy + + return metrics + + +async def prepare_and_evaluate_dataset( + eval_config: Dict[str, Any], + workflow_definition: WorkflowDefinitionSchema, + batch_size: int = 10, + num_samples: Optional[int] = None, + output_variable: Optional[str] = None, +) -> Dict[str, Any]: + """ + Prepare the dataset and evaluate the model on it. + + This function handles dataset loading, preprocessing, and category-level metrics. + It supports evaluating multiple subsets of a dataset and aggregates the results. + + Args: + eval_config: Configuration for the evaluation task. + workflow: Workflow definition to execute. + batch_size: Size of batches for processing. + num_samples: Optional number of samples to evaluate. + output_variable: Optional output variable name from workflow output. + + Returns: + Dict[str, Any]: Evaluation metrics, including accuracy and category-wise metrics. 
+ """ + # Extract dataset config + dataset_name = eval_config.get("dataset_name") + if not dataset_name: + raise ValueError("dataset_name must be provided in eval_config.") + + dataset_split = eval_config.get("dataset_split", "test") + dataset_subsets = eval_config.get("dataset_subsets", None) # Subsets to evaluate + process_docs = eval_config.get("process_docs") + + # Initialize metrics for aggregation + subset_metrics = {} + total_correct = 0 + total_samples = 0 + category_correct = None + category_total = None + + # Handle multiple subsets if specified + if dataset_subsets and isinstance(dataset_subsets, list): + for subset in dataset_subsets: + # Load the subset of the dataset + dataset = load_dataset_by_name(dataset_name, dataset_split, subset, process_docs) + if num_samples is not None: + dataset = dataset.shuffle(seed=42).select(range(min(num_samples, len(dataset)))) + + # Evaluate the subset + metrics = await evaluate_dataset_batch( + dataset=dataset, + eval_config=eval_config, + workflow_definition=workflow_definition, + batch_size=batch_size, + subject=subset, + subject_category_mapping=eval_config.get("subject_category_mapping"), + output_variable=output_variable, # Pass only the variable name + ) + + # Aggregate metrics + subset_metrics[subset] = metrics + total_correct += metrics["correct_predictions"] + total_samples += metrics["total_samples"] + + # Update category-level metrics if applicable + if "category_correct" in metrics and "category_total" in metrics: + if category_correct is None: + category_correct = metrics["category_correct"] + category_total = metrics["category_total"] + else: + for category in metrics["category_correct"]: + category_correct[category] += metrics["category_correct"][category] + category_total[category] += metrics["category_total"][category] + else: + # Single dataset evaluation + dataset = load_dataset_by_name(dataset_name, dataset_split, None, process_docs) + if num_samples is not None: + dataset = 
dataset.shuffle(seed=42).select(range(min(num_samples, len(dataset)))) + + metrics = await evaluate_dataset_batch( + dataset=dataset, + eval_config=eval_config, + workflow_definition=workflow_definition, + batch_size=batch_size, + output_variable=output_variable, # Pass only the variable name + ) + + # Aggregate metrics + subset_metrics["default"] = metrics + total_correct = metrics["correct_predictions"] + total_samples = metrics["total_samples"] + category_correct = metrics.get("category_correct", None) + category_total = metrics.get("category_total", None) + + # Calculate overall metrics + results = calculate_metrics( + total_correct=total_correct, + total_samples=total_samples, + category_correct=category_correct, + category_total=category_total, + ) + results["subset_metrics"] = subset_metrics + + return results + + +if __name__ == "__main__": + test_dict = { + "input_node": "hello", + "node_1732234981946": "hello", + "node_1732236371719": "hello", + "node_1732236371719-bd9e35ad-829c-4618-a828-546c4a3e65f6": "hello", + "node_1732236371719-bd9e35ad-829c-4618-a828-546c4a3e65f6-50292c67-7a0e-4f11-97fc-ba2cd8ee45ef": "hello", + "node_1732236371719-bd9e35ad-829c-4618-a828-546c4a3e65f6-50292c67-7a0e-4f11-97fc-ba2cd8ee45ef-5ecee81a-746d-4847-abf6-0e6a6d241de3": "hello", + } + print( + extract_output_variable( + test_dict, + "bd9e35ad-829c-4618-a828-546c4a3e65f6-50292c67-7a0e-4f11-97fc-ba2cd8ee45ef-5ecee81a-746d-4847-abf6-0e6a6d241de3-correct_option", + ) + ) diff --git a/pyspur/backend/pyspur/evals/tasks/gpqa.py b/pyspur/backend/pyspur/evals/tasks/gpqa.py new file mode 100644 index 0000000000000000000000000000000000000000..2f9520c10a4faf10de8c2e7a84bb71a27a84c8e3 --- /dev/null +++ b/pyspur/backend/pyspur/evals/tasks/gpqa.py @@ -0,0 +1,38 @@ +import random +import re + +import datasets + + +def preprocess(text): + if text is None: + return " " + text = text.strip() + text = text.replace(" [title]", ". 
") + text = re.sub("\\[.*?\\]", "", text) + text = text.replace(" ", " ") + return text + + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + def _process_doc(doc): + choices = [ + preprocess(doc["Incorrect Answer 1"]), + preprocess(doc["Incorrect Answer 2"]), + preprocess(doc["Incorrect Answer 3"]), + preprocess(doc["Correct Answer"]), + ] + + random.shuffle(choices) + correct_answer_index = choices.index(preprocess(doc["Correct Answer"])) + + out_doc = { + "choice1": choices[0], + "choice2": choices[1], + "choice3": choices[2], + "choice4": choices[3], + "answer": f"{chr(65 + correct_answer_index)}", + } + return out_doc + + return dataset.map(_process_doc) diff --git a/pyspur/backend/pyspur/evals/tasks/gpqa.yaml b/pyspur/backend/pyspur/evals/tasks/gpqa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..58d7eadb544a319b33067162461ef49320c032bf --- /dev/null +++ b/pyspur/backend/pyspur/evals/tasks/gpqa.yaml @@ -0,0 +1,24 @@ +dataset_name: 'https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv' +dataset_split: None # Not needed since we're loading from a CSV +dataset_subsets: None + +process_docs: !function gpqa.process_docs + +preamble: | + Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD, nothing else. Think step by step before answering. 
+doc_to_text: "What is the correct answer to this question:{{Question}}\nChoices:\n(A) {{choice1}}\n(B) {{choice2}}\n(C) {{choice3}}\n(D) {{choice4}}\nAnswer:" +doc_to_target: "{{answer}}" + +predicted_answer_extraction: + regexes: + - "(?i)Answer\\s*:\\s*([A-D])" + +evaluation: + method: mcq + +metadata: + name: "GPQA" + description: "Google-Proof Questions Answering" + type: "Reasoning" + num_samples: 1000 + paper_link: "https://arxiv.org/abs/2305.17100" diff --git a/pyspur/backend/pyspur/evals/tasks/gsm8k.yaml b/pyspur/backend/pyspur/evals/tasks/gsm8k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..11b400b9c38be9d1cdda43d6b0a718b436ae676f --- /dev/null +++ b/pyspur/backend/pyspur/evals/tasks/gsm8k.yaml @@ -0,0 +1,49 @@ +preamble: | + As an expert problem solver solve step by step the following mathematical questions. + + Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today? + A: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6. + + Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot? + A: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5. + + Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total? + A: Leah had 32 chocolates and Leah's sister had 42. That means there were originally 32 + 42 = 74 chocolates. 35 have been eaten. So in total they still have 74 - 35 = 39 chocolates. The answer is 39. + + Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny? + A: Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. 
The number of lollipops he has given to Denny must have been 20 - 12 = 8 lollipops. The answer is 8. + + Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now? + A: He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so in total he has 7 + 2 = 9 toys. The answer is 9. + + Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room? + A: There are 4 days from monday to thursday. 5 computers were added each day. That means in total 4 * 5 = 20 computers were added. There were 9 computers in the beginning, so now there are 9 + 20 = 29 computers. The answer is 29. + + Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday? + A: Michael initially had 58 balls. He lost 23 on Tuesday, so after that he has 58 - 23 = 35 balls. On Wednesday he lost 2 more so now he has 35 - 2 = 33 balls. The answer is 33. + + Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? + A: She bought 5 bagels for $3 each. This means she spent 5 * $3 = $15 on the bagels. She had $23 in beginning, so now she has $23 - $15 = $8. The answer is 8. + +doc_to_text: "Q: {{question}}\nA:" +doc_to_target: "{{answer}}" +dataset_name: gsm8k +dataset_split: test +dataset_subsets: + - main +ground_truth_answer_extraction: + functions: + - find_number + - maybe_remove_comma +predicted_answer_extraction: + functions: + - find_number + - maybe_remove_comma +evaluation: + method: "numeric" +metadata: + name: "GSM8K" + description: "GSM8K is a dataset of 8,000+ elementary math word problems with answers." 
+ type: "Reasoning" + num_samples: 1000 + paper_link: "https://example.com/original-paper" \ No newline at end of file diff --git a/pyspur/backend/pyspur/evals/tasks/math.yaml b/pyspur/backend/pyspur/evals/tasks/math.yaml new file mode 100644 index 0000000000000000000000000000000000000000..16f664fbad7b01aa6a11c169e000632c1eea9b77 --- /dev/null +++ b/pyspur/backend/pyspur/evals/tasks/math.yaml @@ -0,0 +1,34 @@ +dataset_name: https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv + +preamble: | + Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.\n\n Remember to put your answer on its own line after "Answer:", and you do not need to use a \boxed command. + +doc_to_text: | + {{ Question }} + +doc_to_target: | + {{ Answer }} + +ground_truth_answer_extraction: + functions: + - find_number + - maybe_remove_comma + regexes: + - "(?i)Answer\\s*:\\s*([^\n]+)" + +predicted_answer_extraction: + functions: + - find_number + - maybe_remove_comma + regexes: + - "(?i)Answer\\s*:\\s*([^\n]+)" + +evaluation: + method: math + +metadata: + name: "Math" + description: "Math is a dataset of 1000+ math word problems with answers." + type: "Reasoning" + num_samples: 1000 + paper_link: "https://example.com/original-paper" \ No newline at end of file diff --git a/pyspur/backend/pyspur/evals/tasks/mmlu.yaml b/pyspur/backend/pyspur/evals/tasks/mmlu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9c2fd877bed7785ff8797cf44fe7af119333451d --- /dev/null +++ b/pyspur/backend/pyspur/evals/tasks/mmlu.yaml @@ -0,0 +1,139 @@ +preamble: | + Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. 
+doc_to_text: | + {{question}} + A) {{choices[0]}} + B) {{choices[1]}} + C) {{choices[2]}} + D) {{choices[3]}} +doc_to_target: | + {% set index_to_letter = ['A', 'B', 'C', 'D'] %}{{index_to_letter[answer]}} +predicted_answer_extraction: + regexes: + - "(?i)Answer\\s*:\\s*([A-D])" +dataset_name: "cais/mmlu" +dataset_split: "test" +dataset_subsets: + - abstract_algebra + - anatomy + - astronomy + - business_ethics + - clinical_knowledge + - college_biology + - college_chemistry + - college_computer_science + - college_mathematics + - college_medicine + - college_physics + - computer_security + - conceptual_physics + - econometrics + - electrical_engineering + - elementary_mathematics + - formal_logic + - global_facts + - high_school_biology + - high_school_chemistry + - high_school_computer_science + - high_school_european_history + - high_school_geography + - high_school_government_and_politics + - high_school_macroeconomics + - high_school_mathematics + - high_school_microeconomics + - high_school_physics + - high_school_psychology + - high_school_statistics + - high_school_us_history + - high_school_world_history + - human_aging + - human_sexuality + - international_law + - jurisprudence + - logical_fallacies + - machine_learning + - management + - marketing + - medical_genetics + - miscellaneous + - moral_disputes + - moral_scenarios + - nutrition + - philosophy + - prehistory + - professional_accounting + - professional_law + - professional_medicine + - professional_psychology + - public_relations + - security_studies + - sociology + - us_foreign_policy + - virology + - world_religions + +subject_category_mapping: + abstract_algebra: "stem" + anatomy: "other" + astronomy: "stem" + business_ethics: "other" + clinical_knowledge: "other" + college_biology: "stem" + college_chemistry: "stem" + college_computer_science: "stem" + college_mathematics: "stem" + college_medicine: "other" + college_physics: "stem" + computer_security: "stem" + conceptual_physics: "stem" + 
econometrics: "social_sciences" + electrical_engineering: "stem" + elementary_mathematics: "stem" + formal_logic: "humanities" + global_facts: "other" + high_school_biology: "stem" + high_school_chemistry: "stem" + high_school_computer_science: "stem" + high_school_european_history: "humanities" + high_school_geography: "social_sciences" + high_school_government_and_politics: "social_sciences" + high_school_macroeconomics: "social_sciences" + high_school_mathematics: "stem" + high_school_microeconomics: "social_sciences" + high_school_physics: "stem" + high_school_psychology: "social_sciences" + high_school_statistics: "stem" + high_school_us_history: "humanities" + high_school_world_history: "humanities" + human_aging: "other" + human_sexuality: "social_sciences" + international_law: "humanities" + jurisprudence: "humanities" + logical_fallacies: "humanities" + machine_learning: "stem" + management: "other" + marketing: "other" + medical_genetics: "other" + miscellaneous: "other" + moral_disputes: "humanities" + moral_scenarios: "humanities" + nutrition: "other" + philosophy: "humanities" + prehistory: "humanities" + professional_accounting: "other" + professional_law: "humanities" + professional_medicine: "other" + professional_psychology: "social_sciences" + public_relations: "social_sciences" + security_studies: "social_sciences" + sociology: "social_sciences" + us_foreign_policy: "social_sciences" + virology: "other" + world_religions: "humanities" + +metadata: + name: "MMLU" + description: "Multiple Choice Questions on 57 Subjects" + type: "Reasoning" + num_samples: 1000 + paper_link: "https://example.com/original-paper" diff --git a/pyspur/backend/pyspur/examples/tool_function_example.py b/pyspur/backend/pyspur/examples/tool_function_example.py new file mode 100644 index 0000000000000000000000000000000000000000..de42eb21f9dbedb915dc4d678dc789aa7bb01074 --- /dev/null +++ b/pyspur/backend/pyspur/examples/tool_function_example.py @@ -0,0 +1,323 @@ +"""Example 
script demonstrating how to create custom tools using the @tool_function decorator. + +This script shows different ways to create and use custom tool functions in PySpur. +""" + +import asyncio +import sys +from pathlib import Path +from typing import Any, Dict, List + +# Add the parent directory to the path so we can import pyspur +script_dir = Path(__file__).parent.parent.parent +sys.path.append(str(script_dir)) + +from pydantic import BaseModel, Field + +from pyspur.nodes.base import BaseNodeOutput +from pyspur.nodes.decorator import tool_function + + +# Basic tool function example +@tool_function( + name="string_manipulation", + display_name="String Manipulation Tool", + description="A tool that performs various string manipulations", + category="Text Processing", +) +def string_manipulator(text: str, operation: str = "uppercase") -> str: + """Manipulate a string based on the specified operation. + + Args: + text: The input text to manipulate + operation: The operation to perform (uppercase, lowercase, capitalize, reverse) + + Returns: + The manipulated string + + """ + if operation == "uppercase": + return text.upper() + elif operation == "lowercase": + return text.lower() + elif operation == "capitalize": + return text.capitalize() + elif operation == "reverse": + return text[::-1] + else: + return f"Unknown operation: {operation}. Try uppercase, lowercase, capitalize, or reverse." 
+ + +# Tool with a custom output model +class MathResult(BaseNodeOutput): + """Custom output model for math operations.""" + + result: float = Field(..., description="The result of the math operation") + operation: str = Field(..., description="The operation that was performed") + inputs: List[float] = Field(..., description="The inputs that were used in the operation") + + +@tool_function( + name="math_operations", + display_name="Math Operations Tool", + description="A tool that performs basic math operations", + category="Math", + output_model=MathResult, # Specify a custom output model +) +def math_tool(numbers: List[float], operation: str = "sum") -> Dict[str, Any]: + """Perform a mathematical operation on a list of numbers. + + Args: + numbers: A list of numbers to operate on + operation: The operation to perform (sum, average, min, max, product) + + Returns: + A dictionary containing the result, operation name, and input numbers + + """ + if not numbers: + return { + "result": 0.0, + "operation": operation, + "inputs": numbers, + } + + if operation == "sum": + result = sum(numbers) + elif operation == "average": + result = sum(numbers) / len(numbers) + elif operation == "min": + result = min(numbers) + elif operation == "max": + result = max(numbers) + elif operation == "product": + result = 1.0 + for num in numbers: + result *= num + else: + result = 0.0 + operation = f"Unknown operation: {operation}" + + return { + "result": result, + "operation": operation, + "inputs": numbers, + } + + +# Tool with templated parameters that can access input values +@tool_function( + name="template_example", + display_name="Template Example Tool", + description="A tool that demonstrates using Jinja2 templates in tool config", + category="Examples", + # Additional configuration parameters can be added here + example_param="This is an example parameter", +) +def template_example(greeting: str, name: str) -> str: + """Create a greeting message. 
+ + Args: + greeting: The greeting to use (e.g., "Hello", "Hi", "Hey") + name: The name to greet + + Returns: + A formatted greeting message + + """ + return f"{greeting}, {name}!" + + +# Tool that returns a Pydantic model converted to a dictionary +class WeatherData(BaseModel): + """Weather data model.""" + + temperature: float = Field(..., description="Temperature in Celsius") + humidity: float = Field(..., description="Humidity percentage") + conditions: str = Field(..., description="Weather conditions (e.g., sunny, rainy)") + + +@tool_function( + name="weather_tool", + display_name="Weather Tool", + description="A tool that returns weather data for a location", + category="Weather", + # When returning a Pydantic model, convert it to a dictionary for the tool function +) +def weather_tool(location: str, units: str = "metric") -> Dict[str, Any]: + """Get weather data for a location. + + Args: + location: The location to get weather data for + units: The units to use for temperature (metric or imperial) + + Returns: + Weather data including temperature, humidity, and conditions + + """ + # In a real tool, this would make an API call to a weather service + # This is a simplified example that returns mock data + mock_data = { + "New York": {"temperature": 22.5, "humidity": 65.0, "conditions": "Partly Cloudy"}, + "London": {"temperature": 18.0, "humidity": 80.0, "conditions": "Rainy"}, + "Tokyo": {"temperature": 28.0, "humidity": 70.0, "conditions": "Sunny"}, + } + + # Get data for the location or use default values + location_data: dict[str, Any] = mock_data.get( + location, {"temperature": 20.0, "humidity": 50.0, "conditions": "Unknown"} + ) + + # Apply temperature conversion if imperial units requested + if units == "imperial": + location_data["temperature"] = location_data["temperature"] * 9 / 5 + 32 + + # Create the model and convert to dictionary + weather = WeatherData(**location_data) + return weather.model_dump() + + +def test_tools_directly(): + """Test the 
tool functions directly.""" + print("\n" + "=" * 50) + print("TESTING TOOLS AS DIRECT FUNCTION CALLS") + print("=" * 50) + + # Test string manipulator + print("\nString Manipulator Tool:") + result = string_manipulator("Hello, world!", "uppercase") + print(f" Function call with 'uppercase': {result}") + result = string_manipulator("Hello, world!", "reverse") + print(f" Function call with 'reverse': {result}") + + # Test math tool + print("\nMath Tool:") + result = math_tool([1, 2, 3, 4, 5], "average") + print(f" Function call with 'average': {result}") + result = math_tool([1, 2, 3, 4, 5], "product") + print(f" Function call with 'product': {result}") + + # Test template example + print("\nTemplate Tool:") + result = template_example("Hello", "PySpur User") + print(f" Function call: {result}") + + # Test weather tool + print("\nWeather Tool:") + result = weather_tool("Tokyo") + print(f" Function call for Tokyo (metric): {result}") + result = weather_tool("London", "imperial") + print(f" Function call for London (imperial): {result}") + + +def test_tools_as_nodes(): + """Test the tool functions as nodes.""" + print("\n" + "=" * 50) + print("TESTING TOOLS AS NODES") + print("=" * 50) + + # Test string manipulator node + print("\nString Manipulator Node:") + # Create a configuration for the node + config = string_manipulator.config_model() + # Set parameters directly on the config object + config = config.model_validate({"text": "Hello, world!", "operation": "uppercase"}) + + # Create a node with the configuration + node = string_manipulator.create_node(name="string_node", config=config) + # Run the node + result = asyncio.run(node(input={})) + print(f" Node execution result: {result.output}") # type: ignore + print(f" Node class: {node.__class__.__name__}") + print(f" Node display name: {node.display_name}") + print(f" Node category: {node.category}") + + # Test math tool node with a custom output model + print("\nMath Node (with custom output model):") + config = 
math_tool.config_model() + config = config.model_validate( + {"numbers": [1, 2, 3, 4, 5], "operation": "product", "has_fixed_output": True} + ) + + node = math_tool.create_node(name="math_node", config=config) + result = asyncio.run(node(input={})) + print(f" Node execution result: result={result.result}, operation={result.operation}") # type: ignore + print(f" Inputs used: {result.inputs}") # type: ignore + + # Test template tool with Jinja2 template rendering + print("\nTemplate Node (with Jinja2 rendering):") + config = template_example.config_model() + config = config.model_validate( + {"greeting": "Hello", "name": "{{ input.user_name }}", "has_fixed_output": True} + ) + config.has_fixed_output = True + + node = template_example.create_node(name="template_node", config=config) + # Provide input data that will be used to render the template + result = asyncio.run(node(input={"user_name": "Template User"})) + print(f" Node execution result: {result.output}") # type: ignore + + # Test weather tool node + print("\nWeather Node:") + config = weather_tool.config_model() + config = config.model_validate( + {"location": "New York", "units": "imperial", "has_fixed_output": True} + ) + config.has_fixed_output = True + + node = weather_tool.create_node(name="weather_node", config=config) + result = asyncio.run(node(input={})) + print(" Node execution result:") + for key, value in result.model_dump().items(): + if key != "output_json_schema": + print(f" {key}: {value}") + + +def examine_tool_metadata(): + """Examine metadata from the decorated tools.""" + print("\n" + "=" * 50) + print("EXAMINING TOOL METADATA") + print("=" * 50) + + # Examine string manipulator tool + print("\nString Manipulator Tool Metadata:") + print(f" Tool function name: {string_manipulator.__name__}") # type: ignore + print(f" Node class: {string_manipulator.node_class.__name__}") + print(f" Config model: {string_manipulator.config_model.__name__}") + print(f" Output model: 
{string_manipulator.output_model.__name__}") + + # Print config schema for math tool + print("\nMath Tool Config Schema:") + config_schema = math_tool.config_model.model_json_schema() + print(f" Required: {config_schema.get('required', [])}") + print(f" Properties: {list(config_schema.get('properties', {}).keys())}") + + # Print output schema for math tool + print("\nMath Tool Output Schema:") + output_schema = math_tool.output_model.model_json_schema() + print(f" Required: {output_schema.get('required', [])}") + print(f" Properties: {list(output_schema.get('properties', {}).keys())}") + + +def main(): + """Run the example demonstrations.""" + print("=" * 80) + print("TOOL FUNCTION DECORATOR EXAMPLES") + print("=" * 80) + + # Test tools directly as functions + test_tools_directly() + + # Test tools as nodes + test_tools_as_nodes() + + # Examine tool metadata + examine_tool_metadata() + + print("\n" + "=" * 80) + print("This example demonstrates how to create tools using the @tool_function decorator.") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/pyspur/backend/pyspur/examples/workflow_as_code_example.py b/pyspur/backend/pyspur/examples/workflow_as_code_example.py new file mode 100644 index 0000000000000000000000000000000000000000..3e46dc0a182d6b5a446d7543df35189bd19fa2a5 --- /dev/null +++ b/pyspur/backend/pyspur/examples/workflow_as_code_example.py @@ -0,0 +1,329 @@ +"""Example script demonstrating workflow-as-code capabilities. + +This script shows how to create workflows programmatically using the WorkflowBuilder. +It includes examples of common workflow patterns. 
+""" + +import json +import sys +from pathlib import Path + +# Add the parent directory to the path so we can import pyspur +script_dir = Path(__file__).parent.parent.parent +sys.path.append(str(script_dir)) + +from pyspur.schemas.workflow_schemas import SpurType +from pyspur.workflow_builder import WorkflowBuilder + + +def create_simple_qa_workflow(): + """Create a simple question answering workflow with an LLM node.""" + builder = WorkflowBuilder( + name="Simple QA Workflow", + description="A simple workflow that takes a question and answers it using an LLM", + ) + + # Add nodes + input_node = builder.add_node( + node_type="InputNode", + config={ + "output_schema": {"question": "string"}, + "output_json_schema": json.dumps( + {"type": "object", "properties": {"question": {"type": "string"}}} + ), + }, + ) + + llm_node = builder.add_node( + node_type="SingleLLMCallNode", + config={ + "llm_info": { + "model": "openai/gpt-4o", + "temperature": 0.7, + "max_tokens": 1000, + }, + "system_message": ( + "You are a helpful assistant who answers questions concisely and accurately." 
+ ), + "user_message": "{{input_node.question}}", + }, + ) + + output_node = builder.add_node( + node_type="OutputNode", + config={ + "output_schema": {"answer": "string"}, + "output_json_schema": json.dumps( + {"type": "object", "properties": {"answer": {"type": "string"}}} + ), + "output_map": {"answer": "llm_node.response"}, + }, + ) + + # Connect nodes + builder.add_link(input_node, llm_node) + builder.add_link(llm_node, output_node) + + # Add test inputs + builder.add_test_input({"question": "What is the capital of France?"}) + + return builder.build() + + +def create_chatbot_workflow(): + """Create a chatbot workflow that maintains conversation history.""" + builder = WorkflowBuilder( + name="Simple Chatbot", + description=( + "A chatbot that responds to user messages while maintaining conversation history" + ), + ) + + # Set workflow type to chatbot + builder.set_spur_type(SpurType.CHATBOT) + + # Add nodes + input_node = builder.add_node( + node_type="InputNode", + config={ + "output_schema": { + "user_message": "string", + "session_id": "string", + "message_history": "List[Dict[str, str]]", + }, + "output_json_schema": json.dumps( + { + "type": "object", + "properties": { + "user_message": {"type": "string"}, + "session_id": {"type": "string"}, + "message_history": { + "type": "array", + "items": { + "type": "object", + "properties": { + "role": {"type": "string"}, + "content": {"type": "string"}, + }, + }, + }, + }, + "required": ["user_message", "session_id"], + } + ), + }, + ) + + llm_node = builder.add_node( + node_type="SingleLLMCallNode", + config={ + "llm_info": { + "model": "anthropic/claude-3-haiku", + "temperature": 0.7, + "max_tokens": 1000, + }, + "system_message": ( + "You are a helpful assistant. Respond in a friendly and concise manner." 
+ ), + "user_message": "{{input_node.user_message}}", + "message_history": "{{input_node.message_history}}", + }, + ) + + output_node = builder.add_node( + node_type="OutputNode", + config={ + "output_schema": {"assistant_message": "string"}, + "output_json_schema": json.dumps( + { + "type": "object", + "properties": {"assistant_message": {"type": "string"}}, + "required": ["assistant_message"], + } + ), + "output_map": {"assistant_message": "llm_node.response"}, + }, + ) + + # Connect nodes + builder.add_link(input_node, llm_node) + builder.add_link(llm_node, output_node) + + return builder.build() + + +def create_complex_routing_workflow(): + """Create a workflow with conditional routing based on a classifier.""" + builder = WorkflowBuilder( + name="Content Classifier", + description=( + "A workflow that routes content to different processors based on classification" + ), + ) + + # Add nodes + input_node = builder.add_node( + node_type="InputNode", + config={ + "output_schema": {"content": "string"}, + "output_json_schema": json.dumps( + {"type": "object", "properties": {"content": {"type": "string"}}} + ), + }, + ) + + classifier_node = builder.add_node( + node_type="SingleLLMCallNode", + config={ + "llm_info": { + "model": "openai/gpt-4o", + "temperature": 0.2, + }, + "system_message": ( + "You are a content classifier. Categorize the input content " + "into one of these categories: question, statement, or request." + ), + "user_message": ( + "Classify the following text into one category (question, statement," + " or request):\n\n{{input_node.content}}\n\nProvide only the category" + " name without any explanation." 
+ ), + }, + ) + + router_node = builder.add_node( + node_type="RouterNode", + config={ + "routes": [ + {"id": "question", "condition": "{{classifier_node.response == 'question'}}"}, + {"id": "statement", "condition": "{{classifier_node.response == 'statement'}}"}, + {"id": "request", "condition": "{{classifier_node.response == 'request'}}"}, + {"id": "default", "condition": "true"}, + ] + }, + ) + + question_handler = builder.add_node( + node_type="SingleLLMCallNode", + config={ + "llm_info": { + "model": "openai/gpt-4o", + "temperature": 0.7, + }, + "system_message": "You are an expert at answering questions.", + "user_message": "Here's a question I need you to answer:\n\n{{input_node.content}}", + }, + ) + + statement_handler = builder.add_node( + node_type="SingleLLMCallNode", + config={ + "llm_info": { + "model": "openai/gpt-4o", + "temperature": 0.7, + }, + "system_message": "You are an expert at evaluating statements.", + "user_message": ( + "Here's a statement. Please evaluate if it's true," + " false, or needs clarification:\n\n{{input_node.content}}" + ), + }, + ) + + request_handler = builder.add_node( + node_type="SingleLLMCallNode", + config={ + "llm_info": { + "model": "openai/gpt-4o", + "temperature": 0.7, + }, + "system_message": "You are an expert at handling requests.", + "user_message": ( + "Here's a request. Please explain how I might fulfill it:\n\n{{input_node.content}}" + ), + }, + ) + + default_handler = builder.add_node( + node_type="SingleLLMCallNode", + config={ + "llm_info": { + "model": "openai/gpt-4o", + "temperature": 0.7, + }, + "system_message": "You are a general-purpose assistant.", + "user_message": ( + "I couldn't determine the type of this content." 
+ " Please respond appropriately:\n\n{{input_node.content}}" + ), + }, + ) + + coalesce_node = builder.add_node( + node_type="CoalesceNode", config={"output_schema": {"response": "string"}} + ) + + output_node = builder.add_node( + node_type="OutputNode", + config={ + "output_schema": {"content": "string", "content_type": "string", "response": "string"}, + "output_map": { + "content": "input_node.content", + "content_type": "classifier_node.response", + "response": "coalesce_node.response", + }, + }, + ) + + # Connect nodes + builder.add_link(input_node, classifier_node) + builder.add_link(classifier_node, router_node) + + # Connect router to handlers + builder.add_link(source_id=router_node, target_id=question_handler, source_handle="question") + builder.add_link(source_id=router_node, target_id=statement_handler, source_handle="statement") + builder.add_link(source_id=router_node, target_id=request_handler, source_handle="request") + builder.add_link(source_id=router_node, target_id=default_handler, source_handle="default") + + # Connect handlers to coalesce node + builder.add_link(question_handler, coalesce_node) + builder.add_link(statement_handler, coalesce_node) + builder.add_link(request_handler, coalesce_node) + builder.add_link(default_handler, coalesce_node) + + # Connect coalesce node to output + builder.add_link(coalesce_node, output_node) + + # Add test inputs + builder.add_test_input({"content": "What is the capital of France?"}) + builder.add_test_input({"content": "The sky is blue."}) + builder.add_test_input({"content": "Please find me a good restaurant."}) + + return builder.build() + + +def main(): + """Create and save example workflows.""" + examples_dir = Path(__file__).parent / "workflow_examples" + examples_dir.mkdir(exist_ok=True) + + # Create and save the simple QA workflow + qa_workflow = create_simple_qa_workflow() + with open(examples_dir / "simple_qa_workflow.json", "w") as f: + json.dump(qa_workflow.model_dump(), f, indent=2) + + # 
Create and save the chatbot workflow + chatbot_workflow = create_chatbot_workflow() + with open(examples_dir / "chatbot_workflow.json", "w") as f: + json.dump(chatbot_workflow.model_dump(), f, indent=2) + + # Create and save the complex routing workflow + routing_workflow = create_complex_routing_workflow() + with open(examples_dir / "complex_routing_workflow.json", "w") as f: + json.dump(routing_workflow.model_dump(), f, indent=2) + + print(f"Example workflows saved to {examples_dir}") + + +if __name__ == "__main__": + main() diff --git a/pyspur/backend/pyspur/execution/__init__.py b/pyspur/backend/pyspur/execution/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyspur/backend/pyspur/execution/__pycache__/__init__.cpython-312.pyc b/pyspur/backend/pyspur/execution/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e482d037011f4f0f6d4423d12e90c86a451e44f Binary files /dev/null and b/pyspur/backend/pyspur/execution/__pycache__/__init__.cpython-312.pyc differ diff --git a/pyspur/backend/pyspur/execution/__pycache__/workflow_execution_context.cpython-312.pyc b/pyspur/backend/pyspur/execution/__pycache__/workflow_execution_context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5bc95501bb1247405a4e7ce3f3a7586d821acd9 Binary files /dev/null and b/pyspur/backend/pyspur/execution/__pycache__/workflow_execution_context.cpython-312.pyc differ diff --git a/pyspur/backend/pyspur/execution/task_recorder.py b/pyspur/backend/pyspur/execution/task_recorder.py new file mode 100644 index 0000000000000000000000000000000000000000..b872a5c979567f0eb0813e70b6a602c69b5d6e74 --- /dev/null +++ b/pyspur/backend/pyspur/execution/task_recorder.py @@ -0,0 +1,198 @@ +from datetime import datetime +from typing import Any, Dict, Optional, List + +from pydantic import BaseModel +from sqlalchemy.orm import Session + 
+from ..models.task_model import TaskModel, TaskStatus +from ..schemas.workflow_schemas import WorkflowDefinitionSchema + + +class TaskRecorder: + def __init__(self, db: Session, run_id: str): + self.db = db + self.run_id = run_id + self.tasks: Dict[str, TaskModel] = {} + + # Load existing tasks from the database + existing_tasks = db.query(TaskModel).filter(TaskModel.run_id == run_id).all() + + # Group tasks by node_id + node_tasks: Dict[str, List[TaskModel]] = {} + for task in existing_tasks: + if task.node_id not in node_tasks: + node_tasks[task.node_id] = [] + node_tasks[task.node_id].append(task) + + # For each node_id, select the most relevant task + for node_id, tasks in node_tasks.items(): + # If there's only one task, use it + if len(tasks) == 1: + self.tasks[node_id] = tasks[0] + continue + + # If there are multiple tasks, prioritize: + # 1. COMPLETED tasks + # 2. PAUSED tasks + # 3. RUNNING tasks + # 4. PENDING tasks + # 5. FAILED tasks + # 6. CANCELED tasks + + # First, try to find a COMPLETED task + completed_tasks = [t for t in tasks if t.status == TaskStatus.COMPLETED] + if completed_tasks: + # Use the most recently completed task + self.tasks[node_id] = max(completed_tasks, key=lambda t: t.end_time or datetime.min) + continue + + # Next, try to find a PAUSED task + paused_tasks = [t for t in tasks if t.status == TaskStatus.PAUSED] + if paused_tasks: + self.tasks[node_id] = paused_tasks[0] + continue + + # Next, try to find a RUNNING task + running_tasks = [t for t in tasks if t.status == TaskStatus.RUNNING] + if running_tasks: + self.tasks[node_id] = running_tasks[0] + continue + + # Next, try to find a PENDING task + pending_tasks = [t for t in tasks if t.status == TaskStatus.PENDING] + if pending_tasks: + self.tasks[node_id] = pending_tasks[0] + continue + + # Next, try to find a FAILED task + failed_tasks = [t for t in tasks if t.status == TaskStatus.FAILED] + if failed_tasks: + self.tasks[node_id] = failed_tasks[0] + continue + + # Finally, use 
a CANCELED task + canceled_tasks = [t for t in tasks if t.status == TaskStatus.CANCELED] + if canceled_tasks: + self.tasks[node_id] = canceled_tasks[0] + continue + + # If we get here, just use the first task + self.tasks[node_id] = tasks[0] + + def create_task( + self, + node_id: str, + inputs: Dict[str, Any], + ): + # First check if there's already a task for this node in our in-memory cache + if node_id in self.tasks: + existing_task = self.tasks[node_id] + + # If the existing task is COMPLETED, PAUSED, or RUNNING, don't create a new one + if existing_task.status in [ + TaskStatus.COMPLETED, + TaskStatus.PAUSED, + TaskStatus.RUNNING, + ]: + # Just update the inputs if needed + if inputs and not existing_task.inputs: + existing_task.inputs = inputs + self.db.add(existing_task) + self.db.commit() + return + + # For other statuses (PENDING, FAILED, CANCELED), update the existing task + existing_task.inputs = inputs + existing_task.status = TaskStatus.RUNNING + existing_task.start_time = datetime.now() + existing_task.end_time = None + existing_task.error = None + self.db.add(existing_task) + self.db.commit() + self.db.refresh(existing_task) + return + + # If we don't have a task in memory, check the database for any existing tasks + existing_task = ( + self.db.query(TaskModel) + .filter(TaskModel.run_id == self.run_id, TaskModel.node_id == node_id) + .order_by(TaskModel.end_time.desc().nullslast()) + .first() + ) + + if existing_task: + # If there's an existing task in the database, use it + if existing_task.status in [TaskStatus.COMPLETED, TaskStatus.PAUSED]: + # Don't modify COMPLETED or PAUSED tasks + self.tasks[node_id] = existing_task + return + + # Update the existing task for other statuses + existing_task.inputs = inputs + existing_task.status = TaskStatus.RUNNING + existing_task.start_time = datetime.now() + existing_task.end_time = None + existing_task.error = None + self.db.add(existing_task) + self.db.commit() + self.db.refresh(existing_task) + 
self.tasks[node_id] = existing_task + return + + # If no existing task was found, create a new one + task = TaskModel( + run_id=self.run_id, + node_id=node_id, + inputs=inputs, + ) + self.db.add(task) + self.db.commit() + self.db.refresh(task) + self.tasks[node_id] = task + return + + def update_task( + self, + node_id: str, + status: TaskStatus, + inputs: Optional[Dict[str, Any]] = None, + outputs: Optional[Dict[str, Any]] = None, + error: Optional[str] = None, + subworkflow: Optional[WorkflowDefinitionSchema] = None, + subworkflow_output: Optional[Dict[str, BaseModel]] = None, + end_time: Optional[datetime] = None, + is_downstream_of_pause: bool = False, + ): + task = self.tasks.get(node_id) + if not task: + self.create_task(node_id, inputs={}) + task = self.tasks[node_id] + + # If task is downstream of a paused node, mark it as pending instead of failed/canceled + if is_downstream_of_pause and status in [TaskStatus.FAILED, TaskStatus.CANCELED]: + status = TaskStatus.PENDING + error = None # Clear any error message + + task.status = status + if inputs: + task.inputs = inputs + if outputs: + task.outputs = outputs + if error: + task.error = error + if end_time: + task.end_time = end_time + if subworkflow: + task.subworkflow = subworkflow.model_dump() + if subworkflow_output: + task.subworkflow_output = { + k: ( + [x.model_dump() if isinstance(x, BaseModel) else x for x in v] + if isinstance(v, list) + else v.model_dump() + ) + for k, v in subworkflow_output.items() + } + self.db.add(task) + self.db.commit() + return diff --git a/pyspur/backend/pyspur/execution/workflow_execution_context.py b/pyspur/backend/pyspur/execution/workflow_execution_context.py new file mode 100644 index 0000000000000000000000000000000000000000..b7cd00e16f5812247acdb92be6717133a5eb372b --- /dev/null +++ b/pyspur/backend/pyspur/execution/workflow_execution_context.py @@ -0,0 +1,20 @@ +from typing import Optional, Dict, Any + +from pydantic import BaseModel +from sqlalchemy.orm import 
Session + + +class WorkflowExecutionContext(BaseModel): + """ + Contains the context of a workflow execution. + """ + + workflow_id: str + run_id: str + parent_run_id: Optional[str] + run_type: str + db_session: Optional[Session] = None + workflow_definition: Optional[Dict[str, Any]] = None + + class Config: + arbitrary_types_allowed = True diff --git a/pyspur/backend/pyspur/execution/workflow_executor.py b/pyspur/backend/pyspur/execution/workflow_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..4a76853b96b781be92a7b5772f692cdbb0f4975d --- /dev/null +++ b/pyspur/backend/pyspur/execution/workflow_executor.py @@ -0,0 +1,1026 @@ +import asyncio +import traceback +from collections import defaultdict, deque +from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple, Union + +from pydantic import ValidationError + +from ..models.run_model import RunModel, RunStatus +from ..models.task_model import TaskStatus +from ..models.user_session_model import MessageModel, SessionModel +from ..models.workflow_model import WorkflowModel +from ..nodes.base import BaseNode, BaseNodeOutput +from ..nodes.factory import NodeFactory +from ..nodes.logic.human_intervention import PauseError +from ..schemas.workflow_schemas import ( + SpurType, + WorkflowDefinitionSchema, + WorkflowNodeSchema, +) +from .task_recorder import TaskRecorder +from .workflow_execution_context import WorkflowExecutionContext + +if TYPE_CHECKING: + from .task_recorder import TaskRecorder + + +class UpstreamFailureError(Exception): + pass + + +class UnconnectedNodeError(Exception): + pass + + +class WorkflowExecutor: + """Handles the execution of a workflow.""" + + def __init__( + self, + workflow: Union[WorkflowModel, WorkflowDefinitionSchema], + initial_inputs: Optional[Dict[str, Dict[str, Any]]] = None, + task_recorder: Optional["TaskRecorder"] = None, + context: Optional[WorkflowExecutionContext] = None, + resumed_node_ids: 
Optional[List[str]] = None, + ): + # Convert WorkflowModel to WorkflowDefinitionSchema if needed + if isinstance(workflow, WorkflowModel): + self.workflow = WorkflowDefinitionSchema.model_validate(workflow.definition) + else: + self.workflow = self._process_subworkflows(workflow) + self._initial_inputs = initial_inputs or {} + if task_recorder: + self.task_recorder = task_recorder + elif context and context.run_id and context.db_session: + print("Creating task recorder from context") + self.task_recorder = TaskRecorder(context.db_session, context.run_id) + else: + self.task_recorder = None + self.context = context + self._node_dict: Dict[str, WorkflowNodeSchema] = {} + self.node_instances: Dict[str, BaseNode] = {} + self._dependencies: Dict[str, Set[str]] = {} + self._node_tasks: Dict[str, asyncio.Task[Optional[BaseNodeOutput]]] = {} + self._outputs: Dict[str, Optional[BaseNodeOutput]] = {} + self._failed_nodes: Set[str] = set() + self._resumed_node_ids: Set[str] = set(resumed_node_ids or []) + self._build_node_dict() + self._build_dependencies() + + @property + def outputs(self) -> Dict[str, Optional[BaseNodeOutput]]: + """Get the current outputs of the workflow execution.""" + return self._outputs + + @outputs.setter + def outputs(self, value: Dict[str, Optional[BaseNodeOutput]]): + """Set the outputs of the workflow execution.""" + self._outputs = value + + def _process_subworkflows(self, workflow: WorkflowDefinitionSchema) -> WorkflowDefinitionSchema: + # Group nodes by parent_id + nodes_by_parent: Dict[Optional[str], List[WorkflowNodeSchema]] = {} + for node in workflow.nodes: + parent_id = node.parent_id + if parent_id not in nodes_by_parent: + nodes_by_parent[parent_id] = [] + node_copy = node.model_copy(update={"parent_id": None}) + nodes_by_parent[parent_id].append(node_copy) + + # Get root level nodes (no parent) + root_nodes = nodes_by_parent.get(None, []) + + # Process each parent node's children into subworkflows + for parent_id, child_nodes in 
nodes_by_parent.items(): + if parent_id is None: + continue + + # Find the parent node in root nodes + parent_node = next((node for node in root_nodes if node.id == parent_id), None) + if not parent_node: + continue + + # Get links between child nodes + child_node_ids = {node.id for node in child_nodes} + subworkflow_links = [ + link + for link in workflow.links + if link.source_id in child_node_ids and link.target_id in child_node_ids + ] + + # Create subworkflow + subworkflow = WorkflowDefinitionSchema(nodes=child_nodes, links=subworkflow_links) + + # Update parent node's config with subworkflow + parent_node.config = { + **parent_node.config, + "subworkflow": subworkflow.model_dump(), + } + + # Return new workflow with only root nodes + return WorkflowDefinitionSchema( + nodes=root_nodes, + links=[ + link + for link in workflow.links + if not any( + node.parent_id + for node in workflow.nodes + if node.id in (link.source_id, link.target_id) + ) + ], + test_inputs=workflow.test_inputs, + spur_type=workflow.spur_type, + ) + + def _build_node_dict(self): + self._node_dict = {node.id: node for node in self.workflow.nodes} + + def _build_dependencies(self): + dependencies: Dict[str, Set[str]] = {node.id: set() for node in self.workflow.nodes} + for link in self.workflow.links: + dependencies[link.target_id].add(link.source_id) + self._dependencies = dependencies + + def _get_source_handles(self) -> Dict[Tuple[str, str], str]: + """Build a mapping of (source_id, target_id) -> source_handle for router nodes only.""" + source_handles: Dict[Tuple[str, str], str] = {} + for link in self.workflow.links: + source_node = self._node_dict[link.source_id] + if source_node.node_type == "RouterNode": + if not link.source_handle: + raise ValueError( + f"Missing source_handle in link from router node " + f"{link.source_id} to {link.target_id}" + ) + source_handles[(link.source_id, link.target_id)] = link.source_handle + return source_handles + + def 
_get_async_task_for_node_execution( + self, node_id: str + ) -> asyncio.Task[Optional[BaseNodeOutput]]: + if node_id in self._node_tasks: + return self._node_tasks[node_id] + # Start task for the node + task = asyncio.create_task(self._execute_node(node_id)) + self._node_tasks[node_id] = task + + # Record task + if self.task_recorder: + self.task_recorder.create_task(node_id, {}) + return task + + def get_blocked_nodes(self, paused_node_id: str) -> Set[str]: + """Find all nodes that are blocked by the paused node. + + These are nodes that directly or indirectly depend on the paused node. + + Args: + paused_node_id: The ID of the node that is paused + + Returns: + Set of node IDs that are blocked by the paused node + + """ + blocked_nodes: Set[str] = set() + + # Build a dependency graph (which nodes depend on which) + dependents: Dict[str, Set[str]] = defaultdict(set) + for node_id, deps in self._dependencies.items(): + for dep_id in deps: + dependents[dep_id].add(node_id) + + # Start with the paused node and find all nodes that depend on it + queue: deque[str] = deque([paused_node_id]) + visited: Set[str] = set() + + while queue: + current_node_id: str = queue.popleft() + visited.add(current_node_id) + + # Get all nodes that depend on this node + for dependent in dependents.get(current_node_id, set()): + if dependent not in visited: + blocked_nodes.add(dependent) + queue.append(dependent) + + return blocked_nodes + + def is_downstream_of_pause(self, node_id: str) -> bool: + """Check if a node is downstream of any paused node. 
+ + Args: + node_id: The ID of the node to check + + Returns: + True if the node is downstream of a paused node, False otherwise + + """ + # If this node is being resumed, it's not considered downstream of a pause + if node_id in self._resumed_node_ids: + return False + + # Check if we have paused nodes in the workflow + paused_nodes: Set[str] = set() + if self.task_recorder: + # Find paused nodes from tasks + for task in self.task_recorder.tasks.values(): + # Only consider nodes that are still paused and not being resumed + if task.status == TaskStatus.PAUSED and task.node_id not in self._resumed_node_ids: + paused_nodes.add(task.node_id) + + if not paused_nodes: + return False + + # Check if this node is downstream of any paused node + for paused_node_id in paused_nodes: + if _workflow_definition := getattr(self.context, "workflow_definition", None): + blocked_nodes = self.get_blocked_nodes(paused_node_id) + if node_id in blocked_nodes: + return True + + return False + + def _get_workflow_definition(self) -> Dict[str, Any]: + """Get workflow definition from context.""" + return getattr(self.context, "workflow_definition", {}) or {} + + def _get_message_history(self, session_id: str) -> List[Dict[str, str]]: + """Extract message history from a session. + + For chatbot workflows, this extracts the history of user and assistant messages + from the session's message history. 
+ """ + if not self.context or not self.context.db_session: + return [] + + # Query the session and its messages + session = ( + self.context.db_session.query(SessionModel) + .filter(SessionModel.id == session_id) + .first() + ) + + if not session: + return [] + + history: List[Dict[str, Any]] = [] + for message in session.messages: + history.append(message.content) + + return history + + def _store_message_history( + self, session_id: str, user_message: str, assistant_message: str + ) -> None: + """Store the current turn's messages in the session history.""" + if not self.context or not self.context.db_session: + return + + # Create user message + user_msg = MessageModel( + session_id=session_id, + run_id=self.context.run_id if self.context else None, + content={"role": "user", "content": user_message}, + ) + self.context.db_session.add(user_msg) + + # Create assistant message + assistant_msg = MessageModel( + session_id=session_id, + run_id=self.context.run_id if self.context else None, + content={"role": "assistant", "content": assistant_message}, + ) + self.context.db_session.add(assistant_msg) + self.context.db_session.commit() + + def _mark_node_as_paused( + self, node_id: str, pause_output: Optional[BaseNodeOutput] = None + ) -> None: + """Mark a node as paused and store its output.""" + # Store the output + self._outputs[node_id] = pause_output + + # Update the task recorder if available + if self.task_recorder: + self.task_recorder.update_task( + node_id=node_id, + status=TaskStatus.PAUSED, + end_time=datetime.now(), + outputs=self._serialize_output(pause_output) if pause_output else None, + ) + + def _mark_downstream_nodes_as_pending(self, paused_node_id: str) -> Set[str]: + """Mark all downstream nodes of a paused node as pending.""" + # Use explicit typing to satisfy the linter + blocked_nodes: Set[str] = self.get_blocked_nodes(paused_node_id) + + # Record for the return value + all_updated_nodes = set(blocked_nodes) + + # Update tasks if we have a 
recorder + if self.task_recorder: + current_time = datetime.now() + for blocked_node_id in blocked_nodes: + self.task_recorder.update_task( + node_id=blocked_node_id, + status=TaskStatus.PENDING, + end_time=current_time, + is_downstream_of_pause=True, + ) + + # Remove from failed nodes if necessary + if blocked_node_id in self._failed_nodes: + self._failed_nodes.remove(blocked_node_id) + + return all_updated_nodes + + def _update_run_status_to_paused(self) -> None: + """Update the run status to paused in the database.""" + if self.context is None: + return + + if self.context.db_session is None: + return + + if not hasattr(self.context, "run_id"): + return + + run = ( + self.context.db_session.query(RunModel) + .filter(RunModel.id == self.context.run_id) + .first() + ) + + if run: + run.status = RunStatus.PAUSED + # Note: We don't commit immediately - caller should commit when all updates are done + + def _handle_pause_exception(self, node_id: str, pause_exception: PauseError) -> None: + """Handle a pause exception for a node.""" + # Mark the node as paused + self._mark_node_as_paused(node_id, pause_exception.output) + + # Mark downstream nodes as pending + self._mark_downstream_nodes_as_pending(node_id) + + # Update run status + self._update_run_status_to_paused() + + # Commit all changes at once + if ( + self.context is not None + and hasattr(self.context, "db_session") + and self.context.db_session is not None + ): + self.context.db_session.commit() + + def _get_tasks_to_update(self, run: RunModel) -> List[str]: + """Get list of task IDs that need to be updated from CANCELED to PENDING.""" + # Find all downstream nodes of any paused node + all_blocked_nodes: Set[str] = set() + for task in run.tasks: + if task.status == TaskStatus.PAUSED: + blocked_nodes = self.get_blocked_nodes(task.node_id) + all_blocked_nodes.update(blocked_nodes) + + # Return tasks that are CANCELED but should be PENDING + return [ + task.node_id + for task in run.tasks + if task.status == 
TaskStatus.CANCELED and task.node_id in all_blocked_nodes + ] + + def _fix_canceled_tasks_after_pause(self, paused_node_id: str) -> None: + """Fix any tasks that were incorrectly marked as CANCELED.""" + if not all([self.task_recorder, self.context, hasattr(self.context, "run_id")]): + return + + assert self.context is not None + + if not hasattr(self.context, "db_session") or self.context.db_session is None: + return + + run = ( + self.context.db_session.query(RunModel) + .filter(RunModel.id == self.context.run_id) + .first() + ) + if not run: + return + + tasks_to_update = self._get_tasks_to_update(run) + if tasks_to_update: + current_time = datetime.now() + for node_id in tasks_to_update: + if self.task_recorder: + self.task_recorder.update_task( + node_id=node_id, + status=TaskStatus.PENDING, + end_time=current_time, + is_downstream_of_pause=True, + ) + self.context.db_session.commit() + + async def _execute_node(self, node_id: str) -> Optional[BaseNodeOutput]: # noqa: C901 + node = self._node_dict[node_id] + node_input = {} + try: + if node_id in self._outputs: + return self._outputs[node_id] + + # Check if this node already has a completed task + if self.task_recorder and node_id in self.task_recorder.tasks: + task = self.task_recorder.tasks[node_id] + if task.status == TaskStatus.COMPLETED and task.outputs: + # If the node already has a completed task, use its outputs + try: + # Create a node instance to get the output model + node_instance = NodeFactory.create_node( + node_name=node.title, + node_type_name=node.node_type, + config=node.config, + ) + node_output = node_instance.output_model.model_validate(task.outputs) + self._outputs[node_id] = node_output + return node_output + except Exception as e: + print(f"Error validating outputs for completed task {node_id}: {e}") + # Continue with normal execution if validation fails + + # Check if this node is downstream of any paused nodes + if self.is_downstream_of_pause(node_id): + if self.task_recorder: + 
self.task_recorder.update_task( + node_id=node_id, status=TaskStatus.PENDING, end_time=datetime.now() + ) + return None + + # Check if any predecessor nodes failed + dependency_ids = self._dependencies.get(node_id, set()) + + # Wait for dependencies + predecessor_outputs: List[Optional[BaseNodeOutput]] = [] + if dependency_ids: + try: + predecessor_outputs = await asyncio.gather( + *( + self._get_async_task_for_node_execution(dep_id) + for dep_id in dependency_ids + ), + ) + except Exception as e: + raise UpstreamFailureError( + f"Node {node_id} skipped due to upstream failure" + ) from e + + if any(dep_id in self._failed_nodes for dep_id in dependency_ids): + print(f"Node {node_id} skipped due to upstream failure") + self._failed_nodes.add(node_id) + raise UpstreamFailureError(f"Node {node_id} skipped due to upstream failure") + + # Before checking for None outputs, check if any dependencies are paused + has_paused_dependencies = False + if self.task_recorder: + for dep_id in dependency_ids: + task = self.task_recorder.tasks.get(dep_id) + if task and task.status == TaskStatus.PAUSED: + has_paused_dependencies = True + break + + # If a dependency is paused, mark this node as PENDING instead of CANCELED + if has_paused_dependencies: + self._outputs[node_id] = None + if self.task_recorder: + self.task_recorder.update_task( + node_id=node_id, + status=TaskStatus.PENDING, + end_time=datetime.now(), + is_downstream_of_pause=True, + ) + return None + + if node.node_type != "CoalesceNode" and any( + output is None for output in predecessor_outputs + ): + self._outputs[node_id] = None + if self.task_recorder: + # Check if any dependencies are paused before marking as CANCELED + has_paused_dependencies = False + for dep_id in dependency_ids: + task = self.task_recorder.tasks.get(dep_id) + if task and task.status == TaskStatus.PAUSED: + has_paused_dependencies = True + break + + if has_paused_dependencies: + self.task_recorder.update_task( + node_id=node_id, + 
status=TaskStatus.PENDING, + end_time=datetime.now(), + is_downstream_of_pause=True, + ) + else: + self.task_recorder.update_task( + node_id=node_id, + status=TaskStatus.CANCELED, + end_time=datetime.now(), + ) + return None + + # Get source handles mapping + source_handles = self._get_source_handles() + + # Build node input, handling router outputs specially + for dep_id, output in zip(dependency_ids, predecessor_outputs, strict=False): + if output is None: + continue + predecessor_node = self._node_dict[dep_id] + if predecessor_node.node_type == "RouterNode": + # For router nodes, we must have a source handle + source_handle = source_handles.get((dep_id, node_id)) + if not source_handle: + raise ValueError( + f"Missing source_handle in link from router node {dep_id} to {node_id}" + ) + # Get the specific route's output from the router + route_output = getattr(output, source_handle, None) + if route_output is not None: + node_input[predecessor_node.title] = route_output + else: + self._outputs[node_id] = None + if self.task_recorder: + self.task_recorder.update_task( + node_id=node_id, + status=TaskStatus.CANCELED, + end_time=datetime.now(), + ) + return None + elif predecessor_node.node_type == "HumanInterventionNode": + # Ensure the output is stored with the correct node ID + if hasattr(output, "model_dump"): + # Get a dictionary representation of the output to examine its structure + output_dict = output.model_dump() + # Special transformation for + # HumanInterventionNode - modify node_input directly + # + # This ensures downstream nodes can access by node ID + # like {{HumanInterventionNode_1.input_node.input_1}} + # + # Store the raw output data directly in the node_input + # using dep_id as the key + node_input[predecessor_node.title] = output_dict + else: + node_input[predecessor_node.title] = output + + # Special handling for InputNode - use initial inputs + if node.node_type == "InputNode": + node_input = self._initial_inputs.get(node_id, {}) + + # Only 
fail early for None inputs if it is NOT a CoalesceNode + if node.node_type != "CoalesceNode" and any(v is None for v in node_input.values()): + self._outputs[node_id] = None + return None + elif node.node_type == "CoalesceNode" and all(v is None for v in node_input.values()): + self._outputs[node_id] = None + return None + + # Remove None values from input + node_input = {k: v for k, v in node_input.items() if v is not None} + + # update task recorder with inputs + if self.task_recorder: + self.task_recorder.update_task( + node_id=node_id, + status=TaskStatus.RUNNING, + inputs={ + dep_id: output.model_dump() if hasattr(output, "model_dump") else output + for dep_id, output in node_input.items() + if node.node_type != "InputNode" + }, + ) + + # If node_input is empty, return None + if not node_input: + self._outputs[node_id] = None + raise UnconnectedNodeError(f"Node {node_id} has no input") + + node_instance = NodeFactory.create_node( + node_name=node.title, + node_type_name=node.node_type, + config=node.config, + ) + self.node_instances[node_id] = node_instance + + # Set workflow definition in node context if available + if hasattr(node_instance, "context"): + node_instance.context = WorkflowExecutionContext( + workflow_id=self.context.workflow_id if self.context else "", + run_id=self.context.run_id if self.context else "", + parent_run_id=self.context.parent_run_id if self.context else None, + run_type=self.context.run_type if self.context else "interactive", + db_session=self.context.db_session if self.context else None, + workflow_definition=self.workflow.model_dump(), + ) + + try: + output = await node_instance(node_input) + + # Update task recorder + if self.task_recorder: + self.task_recorder.update_task( + node_id=node_id, + status=TaskStatus.COMPLETED, + outputs=self._serialize_output(output), + end_time=datetime.now(), + subworkflow=node_instance.subworkflow, + subworkflow_output=node_instance.subworkflow_output, + ) + + # Store output + 
self._outputs[node_id] = output + return output + except PauseError as e: + self._handle_pause_exception(node_id, e) + # Return None to prevent downstream execution + return None + + except UpstreamFailureError as e: + self._failed_nodes.add(node_id) + self._outputs[node_id] = None + if self.task_recorder: + current_time = datetime.now() + + # Check if this node is downstream of a paused node + has_paused_upstream = False + if hasattr(self, "context") and self.context: + # Find any paused nodes + paused_node_ids: List[str] = [] + for _, task in self.task_recorder.tasks.items(): + if task.status == TaskStatus.PAUSED: + paused_node_ids.append(task.node_id) + + # Check if this node is blocked by any paused node + for paused_node_id in paused_node_ids: + blocked_nodes = self.get_blocked_nodes(paused_node_id) + if node_id in blocked_nodes: + has_paused_upstream = True + break + + if has_paused_upstream: + self.task_recorder.update_task( + node_id=node_id, + status=TaskStatus.PENDING, + end_time=current_time, + error=None, + is_downstream_of_pause=True, + ) + else: + self.task_recorder.update_task( + node_id=node_id, + status=TaskStatus.CANCELED, + end_time=current_time, + error="Upstream failure", + ) + raise e + except Exception as e: + error_msg = ( + f"Node execution failed:\n" + f"Node ID: {node_id}\n" + f"Node Type: {node.node_type}\n" + f"Node Title: {node.title}\n" + f"Inputs: {node_input}\n" + f"Error: {traceback.format_exc()}" + ) + print(error_msg) + self._failed_nodes.add(node_id) + if self.task_recorder: + current_time = datetime.now() + self.task_recorder.update_task( + node_id=node_id, + status=TaskStatus.FAILED, + end_time=current_time, + error=traceback.format_exc(limit=5), + ) + raise e + + def _serialize_output(self, output: Optional[BaseNodeOutput]) -> Optional[Dict[str, Any]]: + """Serialize node outputs, handling datetime objects.""" + if output is None: + return None + + data = output.model_dump() + + def _serialize_value(val: Any) -> Any: + 
"""Recursively serialize values, handling datetime objects and sets.""" + if isinstance(val, datetime): + return val.isoformat() + elif isinstance(val, set): + return list(val) # type: ignore # Convert sets to lists + elif isinstance(val, dict): + return {str(key): _serialize_value(value) for key, value in val.items()} # type: ignore + elif isinstance(val, list): + return [_serialize_value(item) for item in val] # type: ignore + return val + + return {str(key): _serialize_value(value) for key, value in data.items()} + + async def _execute_workflow( # noqa: C901 + self, + input: Dict[str, Any] = {}, + node_ids: List[str] = [], + precomputed_outputs: Dict[str, Dict[str, Any] | List[Dict[str, Any]]] = {}, + ) -> Dict[str, BaseNodeOutput]: + # Handle precomputed outputs first + if precomputed_outputs: + for node_id, output in precomputed_outputs.items(): + try: + if isinstance(output, dict): + self._outputs[node_id] = NodeFactory.create_node( + node_name=self._node_dict[node_id].title, + node_type_name=self._node_dict[node_id].node_type, + config=self._node_dict[node_id].config, + ).output_model.model_validate(output) + else: + # If output is a list of dicts, do not validate the output + # these are outputs of loop nodes, + # their precomputed outputs are not supported yet + continue + + except ValidationError as e: + print( + f"[WARNING]: Precomputed output validation failed for node {node_id}: " + f"{e}\n skipping precomputed output" + ) + except AttributeError as e: + print( + f"[WARNING]: Node {node_id} does not have an output_model defined: " + f"{e}\n skipping precomputed output" + ) + except KeyError as e: + print( + f"[WARNING]: Node {node_id} not found in the predecessor workflow: " + f"{e}\n skipping precomputed output" + ) + + # Store input in initial inputs to be used by InputNode + input_node = next( + ( + node + for node in self.workflow.nodes + if node.node_type == "InputNode" and not node.parent_id + ), + ) + self._initial_inputs[input_node.id] = input 
+ # also update outputs for input node + input_node_obj = NodeFactory.create_node( + node_name=input_node.title, + node_type_name=input_node.node_type, + config=input_node.config, + ) + self._outputs[input_node.id] = await input_node_obj(input) + + nodes_to_run = set(self._node_dict.keys()) + if node_ids: + nodes_to_run = set(node_ids) + + # skip nodes that have parent nodes, as they will be executed as part of their parent node + for node in self.workflow.nodes: + if node.parent_id: + nodes_to_run.discard(node.id) + + # drop outputs for nodes that need to be run + for node_id in nodes_to_run: + self._outputs.pop(node_id, None) + + # Start tasks for all nodes + for node_id in nodes_to_run: + self._get_async_task_for_node_execution(node_id) + + # Wait for all tasks to complete, but don't propagate exceptions + results = await asyncio.gather(*self._node_tasks.values(), return_exceptions=True) + + # Process results to handle any exceptions + paused_node_id: Optional[str] = None + paused_exception: Optional[PauseError] = None + for node_id, result in zip(self._node_tasks.keys(), results, strict=False): + if isinstance(result, PauseError): + # Handle pause state - don't mark as failed + paused_node_id = result.node_id + paused_exception = result + print(f"Node {node_id} paused: {str(result)}") + # Don't add to failed nodes since this is a pause state + continue + elif isinstance(result, Exception): + print(f"Node {node_id} failed with error: {str(result)}") + if paused_node_id and self.task_recorder: + # Check if this node is downstream of the paused node + is_downstream = False + current_node = node_id + while current_node in self._dependencies: + deps = self._dependencies[current_node] + if paused_node_id in deps: + is_downstream = True + break + # Check next level of dependencies + if not deps: + break + current_node = next(iter(deps)) + + if is_downstream: + # Update task status without marking as failed + self.task_recorder.update_task( + node_id=node_id, + 
status=TaskStatus.PENDING, + end_time=datetime.now(), + is_downstream_of_pause=True, + ) + continue + + self._failed_nodes.add(node_id) + self._outputs[node_id] = None + + # Handle any downstream nodes of paused nodes that might not have been processed yet + if paused_node_id is not None and self.task_recorder: + self._mark_downstream_nodes_as_pending(paused_node_id) + + # Final pass: fix any CANCELED tasks that should be PENDING + if paused_node_id is not None: + self._fix_canceled_tasks_after_pause(paused_node_id) + + # Ensure workflow status is updated to PAUSED if any node is paused + if paused_node_id is not None: + self._update_run_status_to_paused() + # Commit all database changes + if ( + self.context is not None + and hasattr(self.context, "db_session") + and self.context.db_session is not None + ): + self.context.db_session.commit() + + # If we have a paused node, re-raise the pause exception + if paused_exception is not None: + # This must be raised for API endpoints to catch it + raise paused_exception + + # return the non-None outputs + return {node_id: output for node_id, output in self._outputs.items() if output is not None} + + async def run( + self, + input: Dict[str, Any] = {}, + node_ids: List[str] = [], + precomputed_outputs: Dict[str, Dict[str, Any] | List[Dict[str, Any]]] = {}, + ) -> Dict[str, BaseNodeOutput]: + # For chatbot workflows, extract and inject message history + if self.workflow.spur_type == SpurType.CHATBOT: + session_id = input.get("session_id") + user_message = input.get("user_message") + message_history = input.get("message_history", []) + + if session_id and user_message: + if len(message_history) == 0: + # Get message history from the database + message_history = self._get_message_history(session_id) + + # Add message_history to input + input["message_history"] = message_history + + # Run the workflow + outputs = await self._execute_workflow(input, node_ids, precomputed_outputs) + + # For chatbot workflows, store the new 
messages + if self.workflow.spur_type == SpurType.CHATBOT: + session_id = input.get("session_id") + user_message = input.get("user_message") + + # Find the output node to get assistant's message + output_node = next( + (node for node in self.workflow.nodes if node.node_type == "OutputNode"), None + ) + + if output_node and session_id and user_message: + # Get assistant's message from outputs + assistant_message = None + if output_node.id in outputs: + output = outputs[output_node.id] + # Get the output as a dict to safely access fields + output_dict = output.model_dump() + assistant_message = str(output_dict.get("assistant_message", "")) + + if assistant_message: + # Store the messages + self._store_message_history(session_id, user_message, assistant_message) + + return outputs + + async def __call__( + self, + input: Dict[str, Any] = {}, + node_ids: List[str] = [], + precomputed_outputs: Dict[str, Dict[str, Any] | List[Dict[str, Any]]] = {}, + ) -> Dict[str, BaseNodeOutput]: + """Execute the workflow with the given input data. + + input: input for the input node of the workflow. Dict[: ] + node_ids: list of node_ids to run. If empty, run all nodes. + precomputed_outputs: precomputed outputs for the nodes. + These nodes will not be executed again. 
+ """ + return await self.run(input, node_ids, precomputed_outputs) + + async def run_batch( + self, input_iterator: Iterator[Dict[str, Any]], batch_size: int = 100 + ) -> List[Dict[str, BaseNodeOutput]]: + """Run the workflow on a batch of inputs.""" + results: List[Dict[str, BaseNodeOutput]] = [] + batch_tasks: List[asyncio.Task[Dict[str, BaseNodeOutput]]] = [] + for input in input_iterator: + batch_tasks.append(asyncio.create_task(self.run(input))) + if len(batch_tasks) == batch_size: + results.extend(await asyncio.gather(*batch_tasks)) + batch_tasks = [] + if batch_tasks: + results.extend(await asyncio.gather(*batch_tasks)) + return results + + def add_resumed_node_id(self, node_id: str) -> None: + """Add a node ID to the set of resumed node IDs.""" + self._resumed_node_ids.add(node_id) + + +if __name__ == "__main__": + workflow = WorkflowDefinitionSchema.model_validate( + { + "nodes": [ + { + "id": "input_node", + "title": "", + "node_type": "InputNode", + "config": {"output_schema": {"question": "string"}}, + "coordinates": {"x": 281.25, "y": 128.75}, + }, + { + "id": "bon_node", + "title": "", + "node_type": "BestOfNNode", + "config": { + "samples": 1, + "output_schema": { + "response": "string", + "next_potential_question": "string", + }, + "llm_info": { + "model": "gpt-4o", + "max_tokens": 16384, + "temperature": 0.7, + "top_p": 0.9, + }, + "system_message": "You are a helpful assistant.", + "user_message": "", + }, + "coordinates": {"x": 722.5, "y": 228.75}, + }, + { + "id": "output_node", + "title": "", + "node_type": "OutputNode", + "config": { + "title": "OutputNodeConfig", + "type": "object", + "output_schema": { + "question": "string", + "response": "string", + }, + "output_map": { + "question": "bon_node.next_potential_question", + "response": "bon_node.response", + }, + }, + "coordinates": {"x": 1187.5, "y": 203.75}, + }, + ], + "links": [ + { + "source_id": "input_node", + "target_id": "bon_node", + }, + { + "source_id": "bon_node", + "target_id": 
"output_node", + }, + ], + "test_inputs": [ + { + "id": 1733466671014, + "question": "

Is altruism inherently selfish?

", + } + ], + } + ) + executor = WorkflowExecutor(workflow) + input = {"question": "Is altruism inherently selfish?"} + outputs = asyncio.run(executor(input)) + print(outputs) diff --git a/pyspur/backend/pyspur/integrations/google/auth.py b/pyspur/backend/pyspur/integrations/google/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..7ed06519ee046c97183762a8e201486af5f3b7ac --- /dev/null +++ b/pyspur/backend/pyspur/integrations/google/auth.py @@ -0,0 +1,67 @@ +import json +import os +import time + +# Import the logger +from logging import getLogger +from pathlib import Path + +from fastapi import APIRouter +from pydantic import BaseModel + +logger = getLogger(__name__) + +# Define a router for Google OAuth +router = APIRouter() + +PROJECT_ROOT = os.getenv("PROJECT_ROOT", os.getcwd()) +BASE_DIR = Path(PROJECT_ROOT) / "credentials" / "google" + +# Default file paths for credentials and tokens. +TOKEN_FILE_PATH = BASE_DIR / "token.json" + + +class TokenInput(BaseModel): + access_token: str + expires_in: int + + +@router.post("/store_token/") +async def store_token(token: TokenInput): + try: + TOKEN_FILE_PATH.parent.mkdir(parents=True, exist_ok=True) + with open(TOKEN_FILE_PATH, "w") as token_file: + current_time = time.time() + token_data = { + "access_token": token.access_token, + "expires_at": current_time + token.expires_in, + } + json.dump(token_data, token_file) + return {"message": "Token stored successfully!"} + except Exception as e: + logger.error(f"Error storing token: {e}") + return {"message": "Error storing token!"} + + +@router.get("/validate_token/") +async def validate_token(): + try: + if not TOKEN_FILE_PATH.exists(): + # If the token file does not exist, return False + return {"is_valid": False} + + with open(TOKEN_FILE_PATH, "r") as token_file: + token_data = json.load(token_file) + expires_at = token_data.get("expires_at") + if expires_at is None: + return {"is_valid": False} + + # Check if the token has expired + if 
expires_at <= time.time(): + return {"is_valid": False} + + return {"is_valid": True} + except Exception as e: + logger.error(f"Error checking token: {e}") + # In case of an exception, assume the token is invalid + return {"is_valid": False} diff --git a/pyspur/backend/pyspur/integrations/google/client.py b/pyspur/backend/pyspur/integrations/google/client.py new file mode 100644 index 0000000000000000000000000000000000000000..6ad62f5d5558d0378f393dea910563daaab0d315 --- /dev/null +++ b/pyspur/backend/pyspur/integrations/google/client.py @@ -0,0 +1,76 @@ +import json +import os +from pathlib import Path +from typing import Tuple + +from google.oauth2.credentials import Credentials +from googleapiclient.discovery import build # type: ignore +from googleapiclient.errors import HttpError # type: ignore + +PROJECT_ROOT = os.getenv("PROJECT_ROOT", os.getcwd()) +BASE_DIR = Path(PROJECT_ROOT) / "credentials" / "google" + +# Default file paths for credentials and tokens. +TOKEN_FILE = BASE_DIR / "token.json" + + +class GoogleSheetsClient: + token_path = Path(TOKEN_FILE) + + def get_credentials(self) -> Credentials: + if self.token_path.exists(): + with self.token_path.open("r") as token_file: + token_data = json.load(token_file) + access_token = token_data.get("access_token") + if not access_token: + raise RuntimeError("No access token found in token file.") + else: + raise FileNotFoundError("Token file does not exist.") + + # Load existing credentials from token file if it exists. + creds = Credentials(token=access_token) # type: ignore + return creds + + def read_sheet(self, spreadsheet_id: str, range_name: str) -> Tuple[bool, str]: + """ + Fetches data from the specified spreadsheet range. + + :param spreadsheet_id: The unique ID of the Google Spreadsheet (found in its URL). + :param range_name: The A1 notation specifying which cells to retrieve (e.g., "Sheet1!A1:C10"). + :return: A tuple (success, data_or_error). 
+ success = True if data was retrieved successfully, else False. + data_or_error = stringified list of values or an error message. + """ + try: + creds = self.get_credentials() + service = build("sheets", "v4", credentials=creds) # type: ignore + sheet = service.spreadsheets() # type: ignore + + result = ( + sheet.values() + .get(spreadsheetId=spreadsheet_id, range=range_name) # type: ignore + .execute() + ) + + values = result.get("values", []) # type: ignore + if not values: + return False, "No data found." + return True, str(values) # type: ignore + + except HttpError as http_err: + return False, f"HTTP Error: {http_err}" + except Exception as e: + return False, f"An error occurred: {e}" + + +# Example usage (run as script): +if __name__ == "__main__": + SAMPLE_SPREADSHEET_ID = "REPLACE_WITH_YOUR_SPREADSHEET_ID" + SAMPLE_RANGE_NAME = "Sheet1!A1:C6" + + client = GoogleSheetsClient() + success, data = client.read_sheet(SAMPLE_SPREADSHEET_ID, SAMPLE_RANGE_NAME) + if success: + print("Data from sheet:", data) + else: + print("Error:", data) diff --git a/pyspur/backend/pyspur/integrations/slack/client.py b/pyspur/backend/pyspur/integrations/slack/client.py new file mode 100644 index 0000000000000000000000000000000000000000..13095d7809c7bf8fcb6afe2ac3192ae545584a02 --- /dev/null +++ b/pyspur/backend/pyspur/integrations/slack/client.py @@ -0,0 +1,88 @@ +import os + +from dotenv import load_dotenv +from slack_sdk import WebClient +from slack_sdk.errors import SlackApiError + + +SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN") +SLACK_USER_TOKEN = os.getenv("SLACK_USER_TOKEN") + + +class SlackClient: + def __init__(self): + self.bot_token = SLACK_BOT_TOKEN + self.user_token = SLACK_USER_TOKEN + + self.bot_client = WebClient(token=self.bot_token) + self.user_client = WebClient(token=self.user_token) + + def send_message_as_bot(self, channel: str, text: str) -> tuple[bool, str]: + """ + Sends a message to the specified Slack channel. 
+ + Returns: + bool: True if successful, False otherwise. + str: The status message. + """ + + if not self.bot_token: + raise ValueError("Slack bot token not found in environment variables.") + + try: + response = self.bot_client.chat_postMessage(channel=channel, text=text) # type: ignore + return response.get("ok", False), "success" + except SlackApiError as e: + print(f"Error sending message: {e}") + return False, str(e) + + def send_message_as_user(self, channel: str, text: str) -> tuple[bool, str]: + """ + Sends a message to the specified Slack channel as a user. + + Returns: + bool: True if successful, False otherwise. + str: The status message. + """ + + if not self.user_token: + raise ValueError("Slack user token not found in environment variables.") + + try: + response = self.user_client.chat_postMessage(channel=channel, text=text) # type: ignore + return response.get("ok", False), "success" + except SlackApiError as e: + print(f"Error sending message: {e}") + return False, str(e) + + def send_message(self, channel: str, text: str, mode: str = "bot") -> tuple[bool, str]: + """ + Sends a message to the specified Slack channel. + + Args: + channel (str): The channel ID to send the message to. + text (str): The message to send to the Slack channel. + mode (str): The mode to send the message in. Can be 'bot' or 'user'. + + Returns: + bool: True if successful, False otherwise. + str: The status message. 
+ """ + if mode == "bot": + return self.send_message_as_bot(channel, text) + elif mode == "user": + return self.send_message_as_user(channel, text) + else: + raise ValueError(f"Invalid mode: {mode}") + + +if __name__ == "__main__": + client = SlackClient() + client.send_message_as_bot(channel="#integrations", text="Hello from the SlackClient!") + client.send_message_as_user(channel="#integrations", text="Hello from the Slack Client!") + client.send_message(channel="#integrations", text="Hello from the Slack Client!", mode="bot") + client.send_message( + channel="#integrations", + text="Hello from the Slack Client!", + mode="user", + ) diff --git a/pyspur/backend/pyspur/integrations/slack/socket_client.py b/pyspur/backend/pyspur/integrations/slack/socket_client.py new file mode 100644 index 0000000000000000000000000000000000000000..14e8b69101b3ed3b2817b00f0afa838e389421fe --- /dev/null +++ b/pyspur/backend/pyspur/integrations/slack/socket_client.py @@ -0,0 +1,701 @@ +# type: ignore +import asyncio +import logging +import os +import threading +import traceback +from datetime import datetime +from typing import Any, Callable, Dict, Optional, Set + +from slack_bolt import App +from slack_bolt.adapter.socket_mode import SocketModeHandler +from slack_bolt.oauth.oauth_settings import OAuthSettings +from slack_sdk.oauth.installation_store import FileInstallationStore +from slack_sdk.oauth.installation_store.models.installation import Installation +from slack_sdk.oauth.state_store import FileOAuthStateStore + +from ...database import get_db +from ...models.slack_agent_model import SlackAgentModel +from ...schemas.slack_schemas import WorkflowTriggerRequest + +logger = logging.getLogger("pyspur") + + +class SocketModeClient: + """Client for handling Slack Socket Mode connections. + + This manages real-time event processing from Slack. 
+ """ + + _instance = None + _lock = threading.Lock() + + def __new__(cls): + with cls._lock: + if cls._instance is None: + cls._instance = super(SocketModeClient, cls).__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + + # Initialize state + self._socket_mode_handlers: Dict[int, SocketModeHandler] = {} + self._apps: Dict[int, App] = {} + self._initialized = True + self._workflow_trigger_callback: Optional[Callable[..., Any]] = None + # Blacklist for agents that should be ignored even if events are received + self._blacklisted_agents: Set[int] = set() + # Track running background tasks/threads for proper cleanup + self._socket_tasks: Dict[int, Any] = {} + # Track session tokens to forcibly disconnect Slack sessions + self._session_tokens: Dict[int, str] = {} + + # Create a base directory for installation data + os.makedirs("/tmp/slack-installation-store", exist_ok=True) + + logger.info("SocketModeClient initialized") + + def set_workflow_trigger_callback(self, callback: Callable[..., Any]): + """Set the callback function to be called when a workflow should be triggered. + + The callback can be either a regular function or an async coroutine function. + If it's a coroutine function, it will be properly awaited when called. + """ + self._workflow_trigger_callback = callback + # Log whether the callback is a coroutine function + is_async = asyncio.iscoroutinefunction(callback) + logger.info(f"Setting workflow trigger callback. 
Is async: {is_async}") + + def _register_event_handlers(self, app: App, agent_id: int): + """Register event handlers for the Slack app.""" + + @app.event("app_mention") + def handle_app_mention( + event: Dict[str, Any], + say: Callable, + body: Dict[str, Any], + logger: logging.Logger, + client, + ): + """Handle app mention events from Slack""" + logger.info(f"Agent {agent_id} received mention: {event}") + self._process_event(agent_id, "app_mention", event, body, say, client) + + @app.event("message") + def handle_message( + event: Dict[str, Any], + say: Callable, + body: Dict[str, Any], + logger: logging.Logger, + client, + ): + """Handle message events from Slack""" + # Skip bot messages to avoid loops + if event.get("bot_id") or event.get("user") == "USLACKBOT": + return + + logger.info(f"Agent {agent_id} received message: {event}") + self._process_event(agent_id, "message", event, body, say, client) + + # Add error handler for app + @app.error + def handle_errors(error: Exception, logger: logging.Logger): + """Handle any errors that occur during event processing""" + logger.error(f"Error in Slack app for agent {agent_id}: {error}") + + # Log detailed error information for debugging + logger.error(f"Error details: {traceback.format_exc()}") + + def _process_event( + self, + agent_id: int, + event_type: str, + event: Dict[str, Any], + body: Dict[str, Any], + say: Callable, + client=None, + ): + """Process a Slack event and trigger workflows if appropriate.""" + # Add diagnostics about the event + logger.info(f"Received {event_type} event for agent {agent_id}") + logger.info(f"Current blacklist: {self._blacklisted_agents}") + + # Check if this agent is blacklisted - if so, ignore this event + if agent_id in self._blacklisted_agents: + logger.warning(f"Received event for blacklisted agent {agent_id}, ignoring") + return + + if not self._workflow_trigger_callback: + logger.error("No workflow trigger callback set") + return + + # Get a database session + db = 
next(get_db()) + + try: + # Get the agent to verify it exists and is active + agent = ( + db.query(SlackAgentModel) + .filter( + SlackAgentModel.id == agent_id, + SlackAgentModel.is_active.is_(True), + SlackAgentModel.trigger_enabled.is_(True), + SlackAgentModel.socket_mode_enabled.is_(True), + ) + .first() + ) + + if not agent: + # Check specifically for socket_mode_enabled to provide better logging + socket_check = ( + db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first() + ) + if socket_check and not getattr(socket_check, "socket_mode_enabled", False): + logger.warning( + f"Received event for agent {agent_id} but socket_mode_enabled is False. " + f"This event should not have been received. Ignoring." + ) + # Add to blacklist to prevent future events + self._blacklisted_agents.add(agent_id) + logger.info(f"Updated blacklist to: {self._blacklisted_agents}") + else: + logger.warning(f"Agent {agent_id} not found or not active") + return + + # Extract relevant information from the event + text = event.get("text", "") + channel_id = event.get("channel", "") + user_id = event.get("user", "") + team_id = body.get("team_id", "") + + # Create a trigger request + trigger_request = WorkflowTriggerRequest( + text=text, + channel_id=channel_id, + user_id=user_id, + team_id=team_id, + event_type=event_type, + event_data=event, + ) + + # Call the workflow trigger callback + # Check if the callback is a coroutine function + callback_result = self._workflow_trigger_callback( + trigger_request, agent_id, say, client + ) + + # If it's a coroutine, we need to run it in an event loop + if asyncio.iscoroutine(callback_result): + try: + # Try to get the current event loop + try: + loop = asyncio.get_event_loop() + except RuntimeError: + # If there's no event loop in this thread, create a new one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Run the coroutine + if loop.is_running(): + # We're in an event loop but in a non-async context + # 
def start_socket_mode(self, agent_id: int) -> bool:
    """Open a Socket Mode connection for one Slack agent.

    Any handler already registered for the agent is stopped first, and the
    agent is removed from the blacklist that ``stop_socket_mode`` adds it to.
    Tokens are read from the agent record and, when absent there, from the
    secure token store.

    Args:
        agent_id: Primary key of the ``SlackAgentModel`` row to connect.

    Returns:
        bool: True once the handler is connected and registered on this
        client; False when the agent is missing, a token is absent or has the
        wrong prefix, or any step of the Slack setup raises.

    """
    logger.info(f"Starting socket mode for agent {agent_id}")

    # First make sure any existing socket is stopped
    if agent_id in self._socket_mode_handlers:
        logger.info(f"Stopping existing socket for agent {agent_id} before restart")
        self.stop_socket_mode(agent_id)

    # stop_socket_mode() blacklists the agent, so the removal must come after it.
    if agent_id in self._blacklisted_agents:
        self._blacklisted_agents.remove(agent_id)
        logger.info(f"Removed agent {agent_id} from blacklist for restart")

    db = next(get_db())

    try:
        agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first()

        if not agent:
            logger.error(f"Agent {agent_id} not found")
            return False

        bot_token = getattr(agent, "slack_bot_token", None)
        app_token = getattr(agent, "slack_app_token", None)

        from ...api.secure_token_store import get_token_store

        token_store = get_token_store()

        # Fall back to the secure token store for anything the record lacks.
        if not bot_token:
            bot_token = token_store.get_token(agent_id, "bot_token")

        if not app_token:
            app_token = token_store.get_token(agent_id, "app_token")

        # For Socket Mode, signing secret is optional.  A getattr() default only
        # applies when the attribute is missing, so use `or` to also cover an
        # attribute that exists but is None/empty.
        signing_secret = getattr(agent, "slack_signing_secret", None) or os.environ.get(
            "SLACK_SIGNING_SECRET"
        )

        if not bot_token or not bot_token.startswith("xoxb-"):
            logger.error(f"Invalid bot token format for agent {agent_id}")
            return False

        if not app_token or not app_token.startswith("xapp-"):
            logger.error(f"Invalid app token format for agent {agent_id}")
            return False

        # Single source of truth for the scopes used by both the OAuth settings
        # and the stored installation record.
        bot_scopes = ["chat:write", "app_mentions:read", "channels:history", "channels:read"]

        socket_handler = None
        try:
            # Create unique installation store path for this agent
            installation_store_path = f"/tmp/slack-installation-store/{agent_id}"
            os.makedirs(installation_store_path, exist_ok=True)

            installation_store = FileInstallationStore(base_dir=installation_store_path)

            oauth_settings = OAuthSettings(
                client_id=os.environ.get("SLACK_CLIENT_ID", ""),
                client_secret=os.environ.get("SLACK_CLIENT_SECRET", ""),
                scopes=list(bot_scopes),
                installation_store=installation_store,
                state_store=FileOAuthStateStore(
                    base_dir=installation_store_path, expiration_seconds=600
                ),
            )

            app = App(
                token=bot_token, signing_secret=signing_secret, oauth_settings=oauth_settings
            )

            # auth.test supplies the bot_id, bot user id and team id needed to
            # hand-build the installation record for this workspace.
            bot_info_response = app.client.auth_test()  # type: ignore
            if not bot_info_response["ok"]:
                logger.error(f"Failed to get bot info: {bot_info_response['error']}")
                return False

            team_id = str(bot_info_response["team_id"])  # type: ignore
            bot_user_id = str(bot_info_response["user_id"])  # type: ignore
            bot_id = bot_info_response.get("bot_id", "")

            installation = Installation(
                app_id=os.environ.get("SLACK_APP_ID", bot_id),
                enterprise_id=None,
                team_id=team_id,
                user_id=bot_user_id,
                bot_token=bot_token,
                bot_id=bot_id,
                bot_user_id=bot_user_id,
                bot_scopes=list(bot_scopes),
                installed_at=datetime.now().timestamp(),
            )

            installation_store.save(installation)
            logger.info(f"Stored installation data for team_id: {team_id}")

            self._register_event_handlers(app, agent_id)

            logger.info(f"Starting socket mode with app token for agent {agent_id}")
            socket_handler = SocketModeHandler(app=app, app_token=app_token)

            # NOTE(review): assumes this is slack_bolt's SocketModeHandler, whose
            # start() connects and then blocks the calling thread indefinitely,
            # which would make everything below unreachable.  connect() only opens
            # the connection.  Falls back to start() if no connect() exists.
            connect = getattr(socket_handler, "connect", None)
            if callable(connect):
                connect()
            else:
                socket_handler.start()

            # Keep references so stop_socket_mode() can tear the connection down.
            self._socket_mode_handlers[agent_id] = socket_handler
            self._apps[agent_id] = app

            # Store any background tasks/threads the handler has created
            if hasattr(socket_handler, "thread") and socket_handler.thread:
                self._socket_tasks[agent_id] = socket_handler.thread
                logger.info(f"Stored background thread for agent {agent_id}")

            logger.info(f"Socket mode started successfully for agent {agent_id}")
            return True

        except Exception as e:
            logger.error(f"Error starting socket mode for agent {agent_id}: {e}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            # Do not leave a half-open handler behind when setup fails after
            # the handler object was created.
            if socket_handler is not None and agent_id not in self._socket_mode_handlers:
                try:
                    socket_handler.close()
                except Exception as close_err:
                    logger.error(f"Error closing failed socket handler: {close_err}")
            return False
    except Exception as e:
        logger.error(f"Error in socket mode setup for agent {agent_id}: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return False
    finally:
        db.close()
+ + # First check if we have a task/thread to terminate + if agent_id in self._socket_tasks: + thread = self._socket_tasks[agent_id] + logger.info(f"Found background thread for agent {agent_id}: {thread}") + try: + # Check thread type and try to terminate it + if hasattr(thread, "is_alive") and thread.is_alive(): + logger.info("Thread is alive, attempting to terminate") + if hasattr(thread, "_stop"): + thread._stop() + if hasattr(thread, "_terminate"): + thread._terminate() + # Wait a moment for the thread to terminate + import time + + time.sleep(0.5) + logger.info(f"Thread alive after terminate: {thread.is_alive()}") + except Exception as e: + logger.error(f"Error terminating thread for agent {agent_id}: {e}") + finally: + # Remove from our tracking regardless of success + del self._socket_tasks[agent_id] + + if agent_id not in self._socket_mode_handlers: + logger.warning(f"No Socket Mode handler found for agent {agent_id}") + return False + + # Try to close any of the various components that might be keeping the connection alive + try: + # Close the socket mode handler + handler = self._socket_mode_handlers[agent_id] + + # Log handler details for debugging + logger.info(f"Handler type: {type(handler)}") + logger.info(f"Handler attributes: {dir(handler)}") + + # Try to examine the WebSocket connection if possible + if hasattr(handler, "client") and handler.client: + logger.info(f"Client type: {type(handler.client)}") + logger.info(f"Client attributes: {dir(handler.client)}") + if hasattr(handler.client, "is_connected") and callable( + getattr(handler.client, "is_connected", None) + ): + logger.info(f"Client connected status: {handler.client.is_connected()}") + + # First, attempt to kill the WebSocket connection + try: + # Access underlying WebSocket client (may vary based on slack_bolt implementation) + if hasattr(handler, "client") and handler.client: + logger.info(f"Shutting down WebSocket client for agent {agent_id}") + # Force a disconnect if possible + if 
hasattr(handler.client, "disconnect"): + handler.client.disconnect() + if hasattr(handler.client, "close"): + handler.client.close() + + # Access the app connection + if hasattr(handler, "app") and handler.app: + if hasattr(handler.app, "stop"): + logger.info(f"Stopping app for agent {agent_id}") + handler.app.stop() + except Exception as e: + logger.error(f"Error during WebSocket cleanup: {e}") + + # Try to close the app first to stop any running listeners/callbacks + if agent_id in self._apps: + try: + # Get the app and attempt to shutdown any active listeners + app = self._apps[agent_id] + # Disconnect all listeners and callbacks + if hasattr(app, "client") and app.client: + logger.info(f"Disconnecting client for agent {agent_id}") + if hasattr(app.client, "close"): + # Try to close the client's connection + app.client.close() + except Exception as e: + logger.error(f"Error shutting down app for agent {agent_id}: {e}") + # Continue with handler close even if app shutdown fails + + # Now close the socket handler + try: + logger.info(f"Closing socket handler for agent {agent_id}") + try: + # Try to stop the handler's background processor + if hasattr(handler, "processor") and handler.processor: + if hasattr(handler.processor, "stop"): + handler.processor.stop() + except Exception as e: + logger.error(f"Error stopping processor: {e}") + + # Close the main handler + handler.close() + + # Force close any remaining connections + if hasattr(handler, "client") and handler.client: + logger.info(f"Force closing handler client for agent {agent_id}") + if hasattr(handler.client, "close"): + handler.client.close() + + # If there's a WebSocket connection still open, try to close it + if hasattr(handler, "web_socket_client") and handler.web_socket_client: + logger.info(f"Force closing WebSocket for agent {agent_id}") + if hasattr(handler.web_socket_client, "close"): + handler.web_socket_client.close() + # Even more aggressive - if there's a socket object + if ( + 
hasattr(handler.web_socket_client, "sock") + and handler.web_socket_client.sock + ): + try: + logger.info(f"Force closing raw socket for agent {agent_id}") + handler.web_socket_client.sock.close() + except Exception as e: + logger.error(f"Error closing raw socket: {e}") + + logger.info(f"Socket handler closed for agent {agent_id}") + except Exception as e: + logger.error(f"Error closing socket handler: {e}") + logger.error(f"Socket close error details: {traceback.format_exc()}") + + # Try more aggressive thread termination approach + self._try_aggressive_thread_termination(agent_id, handler) + + # Remove from dictionaries + if agent_id in self._socket_mode_handlers: + del self._socket_mode_handlers[agent_id] + if agent_id in self._apps: + del self._apps[agent_id] + + # Make sure we're blacklisted + if agent_id not in self._blacklisted_agents: + self._blacklisted_agents.add(agent_id) + logger.info(f"Added agent {agent_id} to blacklist") + + # Verify the socket is actually stopped + if self.is_running(agent_id): + logger.error( + f"Socket for agent {agent_id} is still reported as running after stop attempt" + ) + return False + + # Add an extra check for the socket worker to verify any lingering connections + logger.info(f"Socket Mode stopped successfully for agent {agent_id}") + + # Reset any in-memory connection caches that Slack SDK might be maintaining + try: + # Slack SDK might cache connections somewhere globally - try to clear them + import importlib + + try: + slack_sdk = importlib.import_module("slack_sdk") + if hasattr(slack_sdk, "WebClient") and hasattr(slack_sdk.WebClient, "_reset"): + slack_sdk.WebClient._reset() + logger.info("Reset Slack SDK WebClient") + except Exception: + pass + + try: + socket_mode = importlib.import_module("slack_bolt.adapter.socket_mode") + # Try to get any cached handlers and close them + if hasattr(socket_mode, "_connections"): + socket_mode._connections = {} + except Exception: + pass + except Exception as e: + 
logger.error(f"Error resetting SDK connections: {e}") + + return True + except Exception as e: + logger.error(f"Error stopping Socket Mode for agent {agent_id}: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") + return False + + def _try_aggressive_thread_termination(self, agent_id: int, handler: Any) -> None: + """Attempt to aggressively terminate any threads or tasks associated with the socket handler.""" + try: + # See if the socket handler has a thread running and try to terminate it + if hasattr(handler, "thread") and handler.thread: + thread = handler.thread + logger.info(f"Found SocketModeHandler thread: {thread}") + if hasattr(thread, "is_alive") and thread.is_alive(): + logger.info("Socket thread is still alive, trying aggressive termination") + # Try different thread termination methods + if hasattr(thread, "_stop"): + thread._stop() + if hasattr(thread, "_terminate"): + thread._terminate() + if hasattr(thread, "cancel"): + thread.cancel() + if hasattr(thread, "kill"): + thread.kill() + + # If there's a running task, try to cancel it + if hasattr(handler, "task") and handler.task: + logger.info("Cancelling socket handler task") + if hasattr(handler.task, "cancel"): + handler.task.cancel() + + # If the handler has a loop running, try to stop it + if hasattr(handler, "loop") and handler.loop: + logger.info("Stopping socket handler event loop") + if hasattr(handler.loop, "stop"): + handler.loop.stop() + if hasattr(handler.loop, "close"): + handler.loop.close() + + # Look for any WebSocketApp connections + if hasattr(handler, "wss_client") and handler.wss_client: + logger.info("Found WebSocketApp, forcefully closing") + if hasattr(handler.wss_client, "close"): + handler.wss_client.close() + if hasattr(handler.wss_client, "sock") and handler.wss_client.sock: + if hasattr(handler.wss_client.sock, "shutdown"): + handler.wss_client.sock.shutdown() + if hasattr(handler.wss_client.sock, "close"): + handler.wss_client.sock.close() + + except Exception as 
def _forcibly_disconnect_slack_sessions(self, agent_id: int) -> None:
    """Best-effort teardown of Slack activity for an agent that is being stopped.

    Steps, each isolated so a failure in one does not prevent the others:

    1. Call ``auth.test`` with the stored app token and ignore ``SlackApiError``.
       NOTE(review): the original author expected this to reset the session on
       Slack's side; that effect is not verifiable from this code.
    2. With the stored bot token, post a notice to the agent's configured
       channel (if any).
    3. Empty ``app.listeners`` on the tracked handler, when that attribute is a
       dict, so no listener is invoked for further events.

    Args:
        agent_id: Primary key of the agent being disconnected.

    Returns:
        None.  All errors are logged, never raised.

    """
    logger.info(f"Forcibly disconnecting Slack sessions for agent {agent_id}")

    try:
        from ...api.secure_token_store import get_token_store

        token_store = get_token_store()

        app_token = token_store.get_token(agent_id, "app_token")
        if app_token:
            logger.info(f"Using app token to forcibly disconnect sessions for agent {agent_id}")
            try:
                from slack_sdk import WebClient
                from slack_sdk.errors import SlackApiError

                client = WebClient(token=app_token)
                try:
                    # Expected to be rejected by Slack; only the API error is ignored.
                    client.auth_test()
                except SlackApiError:
                    pass

                # Guarded because not every client version exposes close().
                if hasattr(client, "close"):
                    client.close()
            except Exception as e:
                logger.error(f"Error disconnecting app token session: {e}")

        bot_token = token_store.get_token(agent_id, "bot_token")
        if bot_token:
            logger.info(f"Using bot token to forcibly disconnect sessions for agent {agent_id}")
            try:
                from slack_sdk import WebClient

                client = WebClient(token=bot_token)

                # Post a system message to help debug (comment out in production)
                try:
                    db = next(get_db())
                    try:
                        agent = (
                            db.query(SlackAgentModel)
                            .filter(SlackAgentModel.id == agent_id)
                            .first()
                        )
                        if agent and agent.slack_channel_id:
                            client.chat_postMessage(
                                channel=agent.slack_channel_id,
                                text=f"⚠️ Socket Mode disabled for agent {agent_id}. Disconnecting active sessions...",
                            )
                    finally:
                        # The session was previously never closed, leaking one
                        # session per stop request.
                        db.close()
                except Exception as notify_err:
                    logger.error(f"Error notifying Slack channel: {notify_err}")

                if hasattr(client, "close"):
                    client.close()
            except Exception as e:
                logger.error(f"Error disconnecting bot token session: {e}")

        # Last resort - detach listeners on the tracked handler's app directly
        handler = self._socket_mode_handlers.get(agent_id)
        app = getattr(handler, "app", None) if handler is not None else None
        if app and hasattr(app, "listeners") and isinstance(app.listeners, dict):
            for event_type in list(app.listeners.keys()):
                try:
                    app.listeners[event_type] = []
                except Exception as clear_err:
                    # Still best-effort, but no longer silent.
                    logger.debug(f"Could not clear listeners for {event_type}: {clear_err}")

    except Exception as e:
        logger.error(f"Error forcibly disconnecting sessions: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
class SocketManager:
    """Process-wide singleton that supervises one worker process per Slack agent.

    ``workers`` maps an agent id to a process-like object exposing ``pid``,
    ``is_alive()``, ``terminate()``, ``join(timeout)`` and ``kill()``.  Entries
    are either ``multiprocessing.Process`` objects started by this manager or
    adopted wrappers around a PID that was already running (see
    ``_adopt_process``).
    """

    _instance: Optional["SocketManager"] = None

    def __new__(cls, *args: Any, **kwargs: Any) -> "SocketManager":
        # Always hand back the same instance.
        if cls._instance is None:
            cls._instance = super(SocketManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every SocketManager() call; only initialise once.
        if hasattr(self, "_initialized") and self._initialized:
            return
        self.workers: Dict[int, Any] = {}
        self.stopping = False
        self.setup_signal_handlers()
        self._initialized = True

    def setup_signal_handlers(self):
        """Install SIGTERM/SIGINT handlers that trigger ``handle_shutdown``."""
        signal.signal(signal.SIGTERM, lambda signum, frame: self.handle_shutdown(signum, frame))
        signal.signal(signal.SIGINT, lambda signum, frame: self.handle_shutdown(signum, frame))

    def handle_shutdown(self, signum: int, frame: Optional[FrameType] = None) -> None:
        """Mark the manager as stopping and stop every tracked worker.

        Args:
            signum: Number of the signal that was received.
            frame: Interrupted stack frame (unused).

        """
        logger.info(f"Received signal {signum}, shutting down all workers...")
        self.stopping = True
        self.stop_all_workers()

    @staticmethod
    def _adopt_process(pid: int) -> Any:
        """Wrap an already-running PID in a process-like tracking object.

        The wrapper mirrors the subset of ``multiprocessing.Process`` that this
        class relies on, so ``stop_worker`` can treat adopted and self-started
        workers alike.  Previously adopted workers only had ``pid`` and
        ``is_alive``, which made ``stop_worker`` fail with ``AttributeError``.
        """
        from types import SimpleNamespace

        def is_alive() -> bool:
            try:
                return psutil.pid_exists(pid) and "socket_worker.py" in " ".join(
                    psutil.Process(pid).cmdline()
                )
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                return False

        def terminate() -> None:
            try:
                psutil.Process(pid).terminate()
            except psutil.NoSuchProcess:
                pass  # already gone

        def kill() -> None:
            try:
                psutil.Process(pid).kill()
            except psutil.NoSuchProcess:
                pass  # already gone

        def join(timeout: Optional[float] = None) -> None:
            # Like Process.join(): return quietly when the timeout elapses.
            try:
                psutil.Process(pid).wait(timeout=timeout)
            except (psutil.NoSuchProcess, psutil.TimeoutExpired):
                pass

        return SimpleNamespace(
            pid=pid, is_alive=is_alive, terminate=terminate, kill=kill, join=join
        )

    def start_worker(self, agent_id: int) -> bool:
        """Ensure a worker is running for ``agent_id``, starting one if needed.

        Lookup order: the in-memory ``workers`` table, then the agent's PID
        marker file, then a scan of running processes.  A worker found by the
        latter two is adopted rather than duplicated.

        Args:
            agent_id: The ID of the Slack agent to handle

        Returns:
            bool: True if a worker is running (found or newly started), False
            if a new process could not be started or died immediately.

        """
        agent_id = int(agent_id)

        # 1. Already tracked?
        if agent_id in self.workers:
            existing_worker = self.workers[agent_id]
            if existing_worker.is_alive():
                logger.info(
                    f"Worker for agent {agent_id} is already running (PID: {getattr(existing_worker, 'pid', 'unknown')})"
                )
                return True
            # Tracked but dead: drop it before looking further.
            logger.warning(f"Found non-running worker for agent {agent_id} - cleaning up")
            try:
                if hasattr(existing_worker, "terminate"):
                    existing_worker.terminate()
                del self.workers[agent_id]
            except Exception as e:
                logger.error(f"Error cleaning up dead worker for agent {agent_id}: {e}")

        # 2. Marker file left by a worker from before this manager started.
        marker_file = f"{MARKER_DIR}/agent_{agent_id}.pid"
        pid = None
        if os.path.exists(marker_file):
            try:
                with open(marker_file, "r") as f:
                    pid_str = f.read().strip()
                if pid_str:
                    pid = int(pid_str)
                    logger.info(f"Found existing marker file for agent {agent_id} with PID {pid}")
            except Exception as e:
                logger.error(f"Error reading PID from marker file for agent {agent_id}: {e}")

        if pid is not None:
            try:
                if psutil.pid_exists(pid):
                    # PIDs are recycled: confirm this really is the agent's worker.
                    cmdline = " ".join(psutil.Process(pid).cmdline())
                    if "socket_worker.py" in cmdline and f"SLACK_AGENT_ID={agent_id}" in cmdline:
                        logger.info(
                            f"Found running socket worker for agent {agent_id} with PID {pid}"
                        )
                        self.workers[agent_id] = self._adopt_process(pid)
                        return True
            except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess) as e:
                logger.warning(f"Error checking process {pid} for agent {agent_id}: {e}")

        # 3. No usable marker: scan the process table.
        is_running, found_pid = find_running_worker_process(agent_id)
        if is_running and found_pid:
            logger.info(
                f"Found running worker for agent {agent_id} using process search: PID {found_pid}"
            )
            self.workers[agent_id] = self._adopt_process(found_pid)
            return True

        # 4. Nothing running: start a fresh worker process.
        try:
            logger.info(f"Starting new worker process for agent {agent_id}")

            # The agent id is passed as a positional argument; the environment
            # copy built here previously was never handed to the child.
            process = multiprocessing.Process(
                target=worker_main, args=(agent_id,), name=f"socket_worker_{agent_id}"
            )
            process.daemon = True  # Make sure process is daemonized
            process.start()
            self.workers[agent_id] = process
            logger.info(f"Started worker process for agent {agent_id} (PID: {process.pid})")

            # Give the child a moment so an immediate crash is caught here.
            time.sleep(0.5)
            if not process.is_alive():
                logger.error(f"Worker process for agent {agent_id} failed to start")
                return False

            return True
        except Exception as e:
            logger.error(f"Failed to start worker for agent {agent_id}: {e}")
            return False

    def stop_worker(self, agent_id: int) -> bool:
        """Stop the worker for ``agent_id`` and stop tracking it.

        Sends a terminate request, waits up to 5 seconds, then force-kills.

        Args:
            agent_id: The ID of the Slack agent whose worker should be stopped

        Returns:
            bool: True if the worker was stopped or was not tracked, False if
            stopping raised (the worker then stays tracked).

        """
        agent_id = int(agent_id)
        if agent_id not in self.workers:
            return True

        process = self.workers[agent_id]
        try:
            if process.is_alive():
                # Send SIGTERM to allow graceful shutdown
                process.terminate()
                process.join(timeout=5)

                if process.is_alive():
                    logger.warning(f"Worker {agent_id} did not stop gracefully, force killing...")
                    process.kill()
                    process.join(timeout=1)

            del self.workers[agent_id]
            logger.info(f"Stopped worker for agent {agent_id}")
            return True
        except Exception as e:
            logger.error(f"Error stopping worker for agent {agent_id}: {e}")
            return False

    def stop_all_workers(self):
        """Stop all running worker processes."""
        # Copy the keys: stop_worker() mutates self.workers.
        for agent_id in list(self.workers.keys()):
            self.stop_worker(agent_id)

    def check_and_restart_workers(self, db: Session) -> None:
        """Reconcile tracked workers with the agents that should be running.

        Args:
            db: Database session to use for queries

        """
        try:
            active_agents = get_active_agents(db)
            active_agent_ids = {int(cast(int, agent.id)) for agent in active_agents}

            # Make each tracked agent's socket_mode_enabled flag reflect whether
            # its worker is actually alive.
            for agent_id in list(self.workers.keys()):
                agent = db.query(SlackAgentModel).filter(SlackAgentModel.id == agent_id).first()
                if agent:
                    is_alive = self.workers[agent_id].is_alive()

                    if bool(getattr(agent, "socket_mode_enabled", False)) != is_alive:
                        logger.info(
                            f"Updating agent {agent_id} socket_mode_enabled to {is_alive} to match actual state"
                        )
                        agent.socket_mode_enabled = is_alive
                        db.commit()
                        db.refresh(agent)

            # Stop workers for agents that should no longer be running
            for agent_id in list(self.workers.keys()):
                if agent_id not in active_agent_ids:
                    logger.info(f"Stopping worker for inactive agent {agent_id}")
                    self.stop_worker(agent_id)

            # Start workers for agents that need them
            for agent_id in active_agent_ids:
                if agent_id not in self.workers or not self.workers[agent_id].is_alive():
                    logger.info(f"Starting worker for agent {agent_id}")
                    self.start_worker(agent_id)

            running_workers = sum(1 for p in self.workers.values() if p.is_alive())
            logger.info(f"Worker status: {running_workers}/{len(active_agent_ids)} workers running")

        except Exception as e:
            logger.error(f"Error in worker management: {e}")
def get_active_agents(db: Session) -> list[SlackAgentModel]:
    """Return the agents that are flagged active with socket mode on and a workflow set.

    Args:
        db: Database session used for the query.

    Returns:
        list[SlackAgentModel]: Rows with ``is_active``, ``trigger_enabled``,
        ``has_bot_token`` and ``socket_mode_enabled`` all true and a non-null
        ``workflow_id``.

    """
    required_flags = {
        "is_active": True,
        "trigger_enabled": True,
        "has_bot_token": True,
        "socket_mode_enabled": True,
    }
    query = db.query(SlackAgentModel).filter_by(**required_flags)
    # Only agents bound to a workflow qualify.
    query = query.filter(SlackAgentModel.workflow_id.isnot(None))
    return query.all()
async def run_worker(agent_id: int):
    """Run the Socket Mode worker loop for a single agent until it should stop.

    Writes a PID marker file and a JSON status file under
    ``/tmp/pyspur_socket_workers``, starts Socket Mode, then checks every five
    seconds that the agent is still enabled and that its socket is running,
    restarting the socket when it is not.  On exit the socket is stopped, the
    status file is rewritten as ``stopped`` and the marker file is removed.

    Args:
        agent_id: The ID of the Slack agent to handle

    """
    import atexit
    import json

    socket_client = get_socket_mode_client()
    socket_client.set_workflow_trigger_callback(handle_socket_mode_event_sync)

    setup_shutdown_handler(socket_client, agent_id)

    worker_id = os.environ.get("HOSTNAME", "unknown")
    logger.info(f"Socket worker {worker_id} started for agent {agent_id}")

    # Create a marker file to indicate this worker is running
    # This helps with tracking workers even if the API restarts
    marker_dir = "/tmp/pyspur_socket_workers"
    os.makedirs(marker_dir, exist_ok=True)
    marker_file = f"{marker_dir}/agent_{agent_id}.pid"
    with open(marker_file, "w") as f:
        f.write(str(os.getpid()))

    status_file = f"{marker_dir}/agent_{agent_id}.status"

    def cleanup_marker():
        # Safe to call more than once (atexit and the finally block both do).
        try:
            if os.path.exists(marker_file):
                os.remove(marker_file)
                logger.info(f"Removed marker file {marker_file}")
        except Exception as e:
            logger.error(f"Error removing marker file: {e}")

    def write_status(status: str, time_key: str, error_prefix: str) -> None:
        # Overwrite the status file with the current state; failures are logged only.
        try:
            with open(status_file, "w") as f:
                status_info = {
                    "agent_id": agent_id,
                    time_key: datetime.now().isoformat(),
                    "pid": os.getpid(),
                    "hostname": worker_id,
                    "status": status,
                }
                f.write(json.dumps(status_info))
        except Exception as status_err:
            logger.error(f"{error_prefix}: {status_err}")

    atexit.register(cleanup_marker)

    # Add a brief delay to ensure database is ready
    await asyncio.sleep(5)

    try:
        success = socket_client.start_socket_mode(agent_id)
        if not success:
            logger.error(f"Failed to start socket mode for agent {agent_id}")
            return

        logger.info(f"Socket mode started for agent {agent_id}")
        write_status("running", "started_at", "Error writing status file")

        max_retries = 3
        # Loop errors and failed restarts are counted separately.  With one shared
        # counter, the reset after each successful status check wiped out the
        # restart-failure count, so a socket that could never be restarted was
        # retried forever instead of giving up after max_retries.
        retry_count = 0
        restart_failures = 0
        while True:
            try:
                db = next(get_db())
                try:
                    should_run = await check_agent_status(db, agent_id)
                    # Reset retry counter on successful check
                    retry_count = 0

                    if not should_run:
                        logger.info(f"Agent {agent_id} is no longer active, shutting down")
                        break

                    if socket_client.is_running(agent_id):
                        restart_failures = 0
                    else:
                        logger.warning(
                            f"Socket for agent {agent_id} is not running, attempting restart"
                        )
                        if socket_client.start_socket_mode(agent_id):
                            restart_failures = 0
                        else:
                            logger.error(f"Failed to restart socket for agent {agent_id}")
                            restart_failures += 1
                            if restart_failures >= max_retries:
                                logger.error(
                                    f"Reached max retries ({max_retries}) for agent {agent_id}, shutting down"
                                )
                                break
                finally:
                    db.close()
            except asyncio.CancelledError:
                logger.info(f"Worker {agent_id} received cancellation, shutting down gracefully")
                break
            except Exception as e:
                logger.error(f"Error in worker loop for agent {agent_id}: {e}")
                retry_count += 1
                if retry_count >= max_retries:
                    logger.error(
                        f"Reached max retries ({max_retries}) for agent {agent_id}, shutting down"
                    )
                    break
            await asyncio.sleep(5)

    except Exception as e:
        logger.error(f"Critical error in worker for agent {agent_id}: {e}")
    finally:
        # Ensure socket is stopped and cleanup is performed
        try:
            socket_client.stop_socket_mode(agent_id)
        except Exception as stop_err:
            logger.error(f"Error stopping socket mode: {stop_err}")

        write_status("stopped", "shutdown_at", "Error updating status file on shutdown")

        cleanup_marker()
def get_worker_status(agent_id: int) -> WorkerStatus:
    """Report what the marker and status files say about one agent's worker.

    Args:
        agent_id: The ID of the agent to check

    Returns:
        WorkerStatus: ``marker_exists``/``pid`` come from the ``.pid`` file,
        ``process_running`` is True only when that PID's command line contains
        ``socket_worker.py`` and ``SLACK_AGENT_ID=<agent_id>`` for exactly this
        agent, and ``status``/``details`` come from the ``.status`` JSON file.
        All fields keep their defaults when the marker directory is missing.

    """
    import re

    result = WorkerStatus(
        agent_id=agent_id,
        marker_exists=False,
        process_running=False,
        pid=None,
        status_file_exists=False,
        status="unknown",
        details={},
    )

    # Nothing to inspect if no worker has ever created the directory.
    if not os.path.exists(MARKER_DIR):
        return result

    marker_file = f"{MARKER_DIR}/agent_{agent_id}.pid"
    if os.path.exists(marker_file):
        result["marker_exists"] = True

        try:
            with open(marker_file, "r") as f:
                pid = int(f.read().strip())
            result["pid"] = pid

            # The id must not be followed by another digit.  A plain substring
            # test let agent 1 match a worker started with SLACK_AGENT_ID=12.
            agent_marker = re.compile(rf"SLACK_AGENT_ID={re.escape(str(agent_id))}(?!\d)")
            try:
                cmdline_str = " ".join(psutil.Process(pid).cmdline())
                if "socket_worker.py" in cmdline_str and agent_marker.search(cmdline_str):
                    result["process_running"] = True
            except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
                pass  # process gone or not inspectable: leave process_running False
        except Exception as e:
            logger.error(f"Error reading PID from marker file for agent {agent_id}: {e}")

    status_file = f"{MARKER_DIR}/agent_{agent_id}.status"
    if os.path.exists(status_file):
        result["status_file_exists"] = True

        try:
            with open(status_file, "r") as f:
                status_data = json.load(f)
            result["status"] = status_data.get("status", "unknown")
            result["details"] = status_data
        except Exception as e:
            logger.error(f"Error reading status file for agent {agent_id}: {e}")

    return result
+ + Returns: + List[Dict[str, Any]]: A list of worker status dictionaries + + """ + results: List[Dict[str, Any]] = [] + + # Ensure the marker directory exists + if not os.path.exists(MARKER_DIR): + return results + + # Find all marker files + for filename in os.listdir(MARKER_DIR): + if filename.startswith("agent_") and filename.endswith(".pid"): + try: + # Extract agent ID + agent_id_str = filename[6:-4] # Remove "agent_" prefix and ".pid" suffix + agent_id = int(agent_id_str) + + # Get status for this agent + status = get_worker_status(agent_id) + results.append(dict(status)) + except Exception as e: + logger.error(f"Error processing marker file {filename}: {e}") + + return results + + +def find_running_worker_process(agent_id: int) -> Tuple[bool, Optional[int]]: + """Find a running worker process for the given agent ID. + + Args: + agent_id: The agent ID to look for + + Returns: + Tuple[bool, Optional[int]]: A tuple of (is_running, pid) + + """ + for proc in psutil.process_iter(["pid", "cmdline"]): + try: + cmdline = proc.info["cmdline"] + if cmdline: + cmdline_str = " ".join(cmdline) + if "socket_worker.py" in cmdline_str and ( + f"SLACK_AGENT_ID={agent_id}" in cmdline_str + or f"--agent-id={agent_id}" in cmdline_str + ): + return True, proc.info["pid"] + except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + continue + return False, None diff --git a/pyspur/backend/pyspur/models/base_model.py b/pyspur/backend/pyspur/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c2a0b018b537a67aeb2941182114e79031ffc0c4 --- /dev/null +++ b/pyspur/backend/pyspur/models/base_model.py @@ -0,0 +1,3 @@ +from sqlalchemy.orm import declarative_base + +BaseModel = declarative_base() diff --git a/pyspur/backend/pyspur/models/dataset_model.py b/pyspur/backend/pyspur/models/dataset_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d188e4f01ad8a035cd3c93cfa7229fe4d1fce860 --- /dev/null +++ 
b/pyspur/backend/pyspur/models/dataset_model.py @@ -0,0 +1,20 @@ +from datetime import datetime, timezone +from typing import Optional + +from sqlalchemy import Computed, DateTime, Integer, String +from sqlalchemy.orm import Mapped, mapped_column + +from .base_model import BaseModel + + +class DatasetModel(BaseModel): + __tablename__ = "datasets" + + _intid: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement="auto") + id: Mapped[str] = mapped_column(String, Computed("'DS' || _intid"), nullable=False, unique=True) + name: Mapped[str] = mapped_column(String, unique=True, nullable=False) + description: Mapped[Optional[str]] = mapped_column(String) + file_path: Mapped[str] = mapped_column(String, nullable=False) + uploaded_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.now(timezone.utc), index=True + ) diff --git a/pyspur/backend/pyspur/models/dc_and_vi_model.py b/pyspur/backend/pyspur/models/dc_and_vi_model.py new file mode 100644 index 0000000000000000000000000000000000000000..8459957f22a0818ee4d017533cf6d299d487b16a --- /dev/null +++ b/pyspur/backend/pyspur/models/dc_and_vi_model.py @@ -0,0 +1,125 @@ +from datetime import datetime, timezone +from typing import Any, Dict, Literal, Optional + +from sqlalchemy import ( + JSON, + Computed, + DateTime, + Float, + ForeignKey, + Integer, + String, +) +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .base_model import BaseModel + +# Define valid status values +DocumentStatus = Literal["processing", "ready", "error", "deleted"] + + +class DocumentCollectionModel(BaseModel): + """Model for document collections.""" + + __tablename__ = "document_collections" + + _intid: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement="auto") + id: Mapped[str] = mapped_column(String, Computed("'DC' || _intid"), nullable=False, unique=True) + name: Mapped[str] = mapped_column(String, unique=True, nullable=False) + description: Mapped[Optional[str]] = 
mapped_column(String) + status: Mapped[DocumentStatus] = mapped_column(String, nullable=False, default="processing") + document_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + chunk_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + error_message: Mapped[Optional[str]] = mapped_column(String) + + # Store configuration + text_processing_config: Mapped[Dict[str, Any]] = mapped_column( + JSON, + nullable=False, + comment=( + "Configuration for text processing including:" + " chunk_token_size, min_chunk_size_chars, etc." + ), + ) + + # Relationships + vector_indices: Mapped[list["VectorIndexModel"]] = relationship( + "VectorIndexModel", + back_populates="document_collection", + cascade="all, delete-orphan", + ) + + created_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.now(timezone.utc), index=True + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + default=datetime.now(timezone.utc), + onupdate=datetime.now(timezone.utc), + index=True, + ) + + +class VectorIndexModel(BaseModel): + """Model for vector indices.""" + + __tablename__ = "vector_indices" + + _intid: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement="auto") + id: Mapped[str] = mapped_column(String, Computed("'VI' || _intid"), nullable=False, unique=True) + name: Mapped[str] = mapped_column(String, unique=True, nullable=False) + description: Mapped[Optional[str]] = mapped_column(String) + status: Mapped[DocumentStatus] = mapped_column(String, nullable=False, default="processing") + document_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + chunk_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + error_message: Mapped[Optional[str]] = mapped_column(String) + + # Store configuration + embedding_config: Mapped[Dict[str, Any]] = mapped_column( + JSON, + nullable=False, + comment="Configuration for embeddings including: model, dimensions, batch_size, etc.", + ) + + # 
Foreign key to document collection + collection_id: Mapped[str] = mapped_column( + String, ForeignKey("document_collections.id"), nullable=False + ) + document_collection: Mapped[DocumentCollectionModel] = relationship( + "DocumentCollectionModel", back_populates="vector_indices" + ) + + created_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.now(timezone.utc), index=True + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + default=datetime.now(timezone.utc), + onupdate=datetime.now(timezone.utc), + index=True, + ) + + +class DocumentProcessingProgressModel(BaseModel): + """Model for tracking processing progress.""" + + __tablename__ = "document_processing_progress" + + id: Mapped[str] = mapped_column(String, primary_key=True) + status: Mapped[str] = mapped_column(String, nullable=False, default="pending") + progress: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) + current_step: Mapped[str] = mapped_column(String, nullable=False, default="initializing") + total_files: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + processed_files: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + total_chunks: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + processed_chunks: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + error_message: Mapped[Optional[str]] = mapped_column(String, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.now(timezone.utc), index=True + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + default=datetime.now(timezone.utc), + onupdate=datetime.now(timezone.utc), + index=True, + ) diff --git a/pyspur/backend/pyspur/models/eval_run_model.py b/pyspur/backend/pyspur/models/eval_run_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3221347b298669f0c703f3de7bc77f013c56e2ce --- /dev/null +++ b/pyspur/backend/pyspur/models/eval_run_model.py @@ -0,0 +1,41 @@ +from 
datetime import datetime, timezone +from enum import Enum as PyEnum +from typing import Any, Optional + +from sqlalchemy import ( + JSON, + Computed, + DateTime, + Enum, + Integer, + String, +) +from sqlalchemy.orm import Mapped, mapped_column + +from .base_model import BaseModel + + +class EvalRunStatus(PyEnum): + PENDING = "PENDING" + RUNNING = "RUNNING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + +class EvalRunModel(BaseModel): + __tablename__ = "eval_runs" + + _intid: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement="auto") + id: Mapped[str] = mapped_column(String, Computed("'ER' || _intid"), nullable=False, unique=True) + eval_name: Mapped[str] = mapped_column(String, nullable=False) + workflow_id: Mapped[str] = mapped_column(String, nullable=False) + status: Mapped[EvalRunStatus] = mapped_column( + Enum(EvalRunStatus), default=EvalRunStatus.PENDING, nullable=False + ) + output_variable: Mapped[str] = mapped_column(String, nullable=False) + num_samples: Mapped[int] = mapped_column(Integer, default=10) + start_time: Mapped[Optional[datetime]] = mapped_column( + DateTime, default=datetime.now(timezone.utc), index=True + ) + end_time: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True, index=True) + results: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True) diff --git a/pyspur/backend/pyspur/models/management/alembic/README b/pyspur/backend/pyspur/models/management/alembic/README new file mode 100644 index 0000000000000000000000000000000000000000..98e4f9c44effe479ed38c66ba922e7bcc672916f --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. 
\ No newline at end of file diff --git a/pyspur/backend/pyspur/models/management/alembic/env.py b/pyspur/backend/pyspur/models/management/alembic/env.py new file mode 100644 index 0000000000000000000000000000000000000000..254b25ee4203c08f4d259cf8798b091e3dcfd5cc --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/env.py @@ -0,0 +1,100 @@ +# ruff: noqa: F401 +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import engine_from_config, pool + +# Import database URL +from pyspur.database import database_url +from pyspur.models.base_model import BaseModel +from pyspur.models.dataset_model import DatasetModel # type: ignore +from pyspur.models.dc_and_vi_model import ( + DocumentCollectionModel, # type: ignore + VectorIndexModel, # type: ignore +) +from pyspur.models.eval_run_model import EvalRunModel # type: ignore +from pyspur.models.output_file_model import OutputFileModel # type: ignore +from pyspur.models.run_model import RunModel # type: ignore +from pyspur.models.slack_agent_model import SlackAgentModel # type: ignore +from pyspur.models.task_model import TaskModel # type: ignore +from pyspur.models.user_session_model import MessageModel, SessionModel, UserModel # type: ignore +from pyspur.models.workflow_model import WorkflowModel # type: ignore +from pyspur.models.workflow_version_model import WorkflowVersionModel # type: ignore + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# Set the database URL in the config +config.set_main_option("sqlalchemy.url", database_url) + +# add your model's MetaData object here +target_metadata = BaseModel.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. 
+ + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + # use render_as_batch=True for SQLite + url = config.get_main_option("sqlalchemy.url") + if url is not None and url.startswith("sqlite"): + render_as_batch = True + else: + render_as_batch = False + print("#" * 50) + print(f"Using render_as_batch={render_as_batch}, url={url}") + print("#" * 50) + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=target_metadata, + render_as_batch=render_as_batch, + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/pyspur/backend/pyspur/models/management/alembic/script.py.mako b/pyspur/backend/pyspur/models/management/alembic/script.py.mako new file mode 100644 index 0000000000000000000000000000000000000000..aa5053c91cc21a9a90f9ce5aa986eab1610f05de --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from 
alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/000_init_db.py b/pyspur/backend/pyspur/models/management/alembic/versions/000_init_db.py new file mode 100644 index 0000000000000000000000000000000000000000..90ab0d52a0cb51786eeaabb86f2e14b3e2064fdc --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/versions/000_init_db.py @@ -0,0 +1,268 @@ +"""init_db + +Revision ID: 000 +Revises: +Create Date: 2025-01-06 00:42:14.253167 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "000" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "datasets", + sa.Column("_intid", sa.Integer(), nullable=False), + sa.Column( + "id", + sa.String(), + sa.Computed( + "'DS' || _intid", + ), + nullable=False, + ), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("file_path", sa.String(), nullable=False), + sa.Column("uploaded_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("_intid"), + sa.UniqueConstraint("_intid"), + sa.UniqueConstraint("id"), + sa.UniqueConstraint("name"), + ) + op.create_table( + "eval_runs", + sa.Column("_intid", sa.Integer(), nullable=False), + sa.Column( + "id", + sa.String(), + sa.Computed( + "'ER' || _intid", + ), + nullable=False, + ), + sa.Column("eval_name", sa.String(), nullable=False), + sa.Column("workflow_id", sa.String(), nullable=False), + sa.Column( + "status", + sa.Enum( + "PENDING", + "RUNNING", + "COMPLETED", + "FAILED", + name="evalrunstatus", + ), + nullable=False, + ), + sa.Column("output_variable", sa.String(), nullable=False), + sa.Column("num_samples", sa.Integer(), nullable=False), + sa.Column("start_time", sa.DateTime(), nullable=True), + sa.Column("end_time", sa.DateTime(), nullable=True), + sa.Column("results", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("_intid"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "output_files", + sa.Column("_intid", sa.Integer(), nullable=False), + sa.Column( + "id", + sa.String(), + sa.Computed( + "'OF' || _intid", + ), + nullable=False, + ), + sa.Column("file_name", sa.String(), nullable=False), + sa.Column("file_path", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("_intid"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "workflows", + sa.Column("_intid", sa.Integer(), nullable=False), + sa.Column( + "id", + sa.String(), + sa.Computed( + "'S' || _intid", + ), + nullable=False, + 
), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("definition", sa.JSON(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("_intid"), + sa.UniqueConstraint("id"), + sa.UniqueConstraint("name"), + ) + op.create_table( + "workflow_versions", + sa.Column("_intid", sa.Integer(), nullable=False), + sa.Column( + "id", + sa.String(), + sa.Computed( + "'SV' || _intid", + ), + nullable=False, + ), + sa.Column("version", sa.Integer(), nullable=False), + sa.Column("workflow_id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("definition", sa.JSON(), nullable=False), + sa.Column("definition_hash", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["workflow_id"], + ["workflows.id"], + ), + sa.PrimaryKeyConstraint("_intid"), + sa.UniqueConstraint("id"), + ) + op.create_index( + op.f("ix_workflow_versions_version"), + "workflow_versions", + ["version"], + unique=True, + ) + op.create_index( + op.f("ix_workflow_versions_workflow_id"), + "workflow_versions", + ["workflow_id"], + unique=False, + ) + op.create_table( + "runs", + sa.Column("_intid", sa.Integer(), nullable=False), + sa.Column( + "id", + sa.String(), + sa.Computed( + "'R' || _intid", + ), + nullable=False, + ), + sa.Column("workflow_id", sa.String(), nullable=False), + sa.Column("workflow_version_id", sa.String(), nullable=False), + sa.Column("parent_run_id", sa.String(), nullable=True), + sa.Column( + "status", + sa.Enum("PENDING", "RUNNING", "COMPLETED", "FAILED", name="runstatus"), + nullable=False, + ), + sa.Column("run_type", sa.String(), nullable=False), + sa.Column("initial_inputs", sa.JSON(), nullable=True), + 
sa.Column("input_dataset_id", sa.String(), nullable=True), + sa.Column("start_time", sa.DateTime(), nullable=True), + sa.Column("end_time", sa.DateTime(), nullable=True), + sa.Column("outputs", sa.JSON(), nullable=True), + sa.Column("output_file_id", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["input_dataset_id"], + ["datasets.id"], + ), + sa.ForeignKeyConstraint( + ["output_file_id"], + ["output_files.id"], + ), + sa.ForeignKeyConstraint( + ["parent_run_id"], + ["runs.id"], + ), + sa.ForeignKeyConstraint( + ["workflow_id"], + ["workflows.id"], + ), + sa.ForeignKeyConstraint( + ["workflow_version_id"], + ["workflow_versions.id"], + ), + sa.PrimaryKeyConstraint("_intid"), + sa.UniqueConstraint("id"), + ) + op.create_index( + op.f("ix_runs_input_dataset_id"), + "runs", + ["input_dataset_id"], + unique=False, + ) + op.create_index(op.f("ix_runs_parent_run_id"), "runs", ["parent_run_id"], unique=False) + op.create_index(op.f("ix_runs_workflow_id"), "runs", ["workflow_id"], unique=False) + op.create_index( + op.f("ix_runs_workflow_version_id"), + "runs", + ["workflow_version_id"], + unique=False, + ) + op.create_table( + "tasks", + sa.Column("_intid", sa.Integer(), nullable=False), + sa.Column( + "id", + sa.String(), + sa.Computed( + "'T' || _intid", + ), + nullable=False, + ), + sa.Column("run_id", sa.String(), nullable=False), + sa.Column("node_id", sa.String(), nullable=False), + sa.Column("parent_task_id", sa.String(), nullable=True), + sa.Column( + "status", + sa.Enum("PENDING", "RUNNING", "COMPLETED", "FAILED", name="taskstatus"), + nullable=False, + ), + sa.Column("inputs", sa.JSON(), nullable=True), + sa.Column("outputs", sa.JSON(), nullable=True), + sa.Column("start_time", sa.DateTime(), nullable=True), + sa.Column("end_time", sa.DateTime(), nullable=True), + sa.Column("subworkflow", sa.JSON(), nullable=True), + sa.Column("subworkflow_output", sa.JSON(), nullable=True), + sa.ForeignKeyConstraint( + ["parent_task_id"], + ["tasks.id"], + ), + 
sa.ForeignKeyConstraint( + ["run_id"], + ["runs.id"], + ), + sa.PrimaryKeyConstraint("_intid"), + sa.UniqueConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("tasks") + op.drop_index(op.f("ix_runs_workflow_version_id"), table_name="runs") + op.drop_index(op.f("ix_runs_workflow_id"), table_name="runs") + op.drop_index(op.f("ix_runs_parent_run_id"), table_name="runs") + op.drop_index(op.f("ix_runs_input_dataset_id"), table_name="runs") + op.drop_table("runs") + op.drop_index(op.f("ix_workflow_versions_workflow_id"), table_name="workflow_versions") + op.drop_index(op.f("ix_workflow_versions_version"), table_name="workflow_versions") + op.drop_table("workflow_versions") + op.drop_table("workflows") + op.drop_table("output_files") + op.drop_table("eval_runs") + op.drop_table("datasets") + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/001_fix_workflow_version_idx.py b/pyspur/backend/pyspur/models/management/alembic/versions/001_fix_workflow_version_idx.py new file mode 100644 index 0000000000000000000000000000000000000000..ec0acd2e75a7d8792f5517f64f304aebf16c4f42 --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/versions/001_fix_workflow_version_idx.py @@ -0,0 +1,42 @@ +"""fix_workflow_version_idx + +Revision ID: 001 +Revises: 000 +Create Date: 2025-01-06 20:47:53.181743 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "001" +down_revision: Union[str, None] = "000" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index("ix_workflow_versions_version", table_name="workflow_versions") + op.create_index( + op.f("ix_workflow_versions_version"), + "workflow_versions", + ["version"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_workflow_versions_version"), table_name="workflow_versions") + op.create_index( + "ix_workflow_versions_version", + "workflow_versions", + ["version"], + unique=True, + ) + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/002_add_knowledge_base_model.py b/pyspur/backend/pyspur/models/management/alembic/versions/002_add_knowledge_base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..80f08d3d694f4a655c1f822fbb8e259f26415990 --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/versions/002_add_knowledge_base_model.py @@ -0,0 +1,54 @@ +"""add_knowledge_base_model + +Revision ID: 002 +Revises: 001 +Create Date: 2025-01-10 15:34:29.890929 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "002" +down_revision: Union[str, None] = "001" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "knowledge_bases", + sa.Column("_intid", sa.Integer(), nullable=False), + sa.Column( + "id", + sa.String(), + sa.Computed( + "'KB' || _intid", + ), + nullable=False, + ), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("status", sa.String(), nullable=False), + sa.Column("document_count", sa.Integer(), nullable=False), + sa.Column("chunk_count", sa.Integer(), nullable=False), + sa.Column("error_message", sa.String(), nullable=True), + sa.Column("text_processing_config", sa.JSON(), nullable=False), + sa.Column("embedding_config", sa.JSON(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("_intid"), + sa.UniqueConstraint("id"), + sa.UniqueConstraint("name"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("knowledge_bases") + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/003_split_knowledge_base_into_dc_and_vi.py b/pyspur/backend/pyspur/models/management/alembic/versions/003_split_knowledge_base_into_dc_and_vi.py new file mode 100644 index 0000000000000000000000000000000000000000..e7c68cf3873be2f95fd9798daac9a8610e7ed28f --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/versions/003_split_knowledge_base_into_dc_and_vi.py @@ -0,0 +1,129 @@ +"""split_knowledge_base_into_dc_and_vi + +Revision ID: 003 +Revises: 002 +Create Date: 2025-01-13 19:42:54.414404 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision: str = "003" +down_revision: Union[str, None] = "002" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "document_collections", + sa.Column("_intid", sa.Integer(), nullable=False), + sa.Column( + "id", + sa.String(), + sa.Computed( + "'DC' || _intid", + ), + nullable=False, + ), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("status", sa.String(), nullable=False), + sa.Column("document_count", sa.Integer(), nullable=False), + sa.Column("chunk_count", sa.Integer(), nullable=False), + sa.Column("error_message", sa.String(), nullable=True), + sa.Column("text_processing_config", sa.JSON(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("_intid"), + sa.UniqueConstraint("id"), + sa.UniqueConstraint("name"), + ) + op.create_table( + "vector_indices", + sa.Column("_intid", sa.Integer(), nullable=False), + sa.Column( + "id", + sa.String(), + sa.Computed( + "'VI' || _intid", + ), + nullable=False, + ), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("status", sa.String(), nullable=False), + sa.Column("document_count", sa.Integer(), nullable=False), + sa.Column("chunk_count", sa.Integer(), nullable=False), + sa.Column("error_message", sa.String(), nullable=True), + sa.Column("embedding_config", sa.JSON(), nullable=False), + sa.Column("collection_id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["collection_id"], + ["document_collections.id"], + ), + sa.PrimaryKeyConstraint("_intid"), + sa.UniqueConstraint("id"), + 
sa.UniqueConstraint("name"), + ) + op.drop_table("knowledge_bases") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "knowledge_bases", + sa.Column("_intid", sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column( + "id", + sa.VARCHAR(), + sa.Computed("('KB'::text || _intid)", persisted=True), + autoincrement=False, + nullable=False, + ), + sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("description", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("status", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("document_count", sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column("chunk_count", sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column("error_message", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column( + "text_processing_config", + postgresql.JSON(astext_type=sa.Text()), + autoincrement=False, + nullable=False, + ), + sa.Column( + "embedding_config", + postgresql.JSON(astext_type=sa.Text()), + autoincrement=False, + nullable=False, + ), + sa.Column( + "created_at", + postgresql.TIMESTAMP(), + autoincrement=False, + nullable=False, + ), + sa.Column( + "updated_at", + postgresql.TIMESTAMP(), + autoincrement=False, + nullable=False, + ), + sa.PrimaryKeyConstraint("_intid", name="knowledge_bases_pkey"), + sa.UniqueConstraint("id", name="knowledge_bases_id_key"), + sa.UniqueConstraint("name", name="knowledge_bases_name_key"), + ) + op.drop_table("vector_indices") + op.drop_table("document_collections") + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/004_add_progress_status.py b/pyspur/backend/pyspur/models/management/alembic/versions/004_add_progress_status.py new file mode 100644 index 0000000000000000000000000000000000000000..c33f660e94041cedcfaacac799290d41ea76c1e9 --- /dev/null +++ 
b/pyspur/backend/pyspur/models/management/alembic/versions/004_add_progress_status.py @@ -0,0 +1,44 @@ +"""add_progress_status + +Revision ID: 004 +Revises: 003 +Create Date: 2025-01-15 00:31:27.898484 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "004" +down_revision: Union[str, None] = "003" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "processing_progress", + sa.Column("id", sa.String(), nullable=False), + sa.Column("status", sa.String(), nullable=False), + sa.Column("progress", sa.Float(), nullable=False), + sa.Column("current_step", sa.String(), nullable=False), + sa.Column("total_files", sa.Integer(), nullable=False), + sa.Column("processed_files", sa.Integer(), nullable=False), + sa.Column("total_chunks", sa.Integer(), nullable=False), + sa.Column("processed_chunks", sa.Integer(), nullable=False), + sa.Column("error_message", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table("processing_progress") + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/005_rename_processing_progress_model.py b/pyspur/backend/pyspur/models/management/alembic/versions/005_rename_processing_progress_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b4634236c202f25f0d07ae323065a71157f8855f --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/versions/005_rename_processing_progress_model.py @@ -0,0 +1,29 @@ +"""rename-processing-progress-model + +Revision ID: 005 +Revises: 004 +Create Date: 2025-01-17 20:29:55.674674 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "005" +down_revision: Union[str, None] = "004" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.rename_table("processing_progress", "document_processing_progress") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.rename_table("document_processing_progress", "processing_progress") + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/006_track_error_and_task_cancellation.py b/pyspur/backend/pyspur/models/management/alembic/versions/006_track_error_and_task_cancellation.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa0d44a3ab942498a894d6b1ec5201b6324e7d9 --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/versions/006_track_error_and_task_cancellation.py @@ -0,0 +1,64 @@ +"""track-error-and-task-cancellation + +Revision ID: 006 +Revises: 005 +Create Date: 2025-01-17 21:35:57.061233 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "006" +down_revision: Union[str, None] = "005" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "document_collections", + "text_processing_config", + existing_type=postgresql.JSON(astext_type=sa.Text()), + comment="Configuration for text processing including: chunk_token_size, min_chunk_size_chars, etc.", + existing_nullable=False, + ) + op.add_column("tasks", sa.Column("error", sa.String(), nullable=True)) + op.alter_column( + "vector_indices", + "embedding_config", + existing_type=postgresql.JSON(astext_type=sa.Text()), + comment="Configuration for embeddings including: model, dimensions, batch_size, etc.", + existing_nullable=False, + ) + + # Update TaskStatus enum type + op.execute("ALTER TYPE taskstatus ADD VALUE IF NOT EXISTS 'CANCELED'") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.alter_column( + "vector_indices", + "embedding_config", + existing_type=postgresql.JSON(astext_type=sa.Text()), + comment=None, + existing_comment="Configuration for embeddings including: model, dimensions, batch_size, etc.", + existing_nullable=False, + ) + op.drop_column("tasks", "error") + op.alter_column( + "document_collections", + "text_processing_config", + existing_type=postgresql.JSON(astext_type=sa.Text()), + comment=None, + existing_comment="Configuration for text processing including: chunk_token_size, min_chunk_size_chars, etc.", + existing_nullable=False, + ) + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/007_add_paused_status.py b/pyspur/backend/pyspur/models/management/alembic/versions/007_add_paused_status.py new file mode 100644 index 0000000000000000000000000000000000000000..798c4a3047afa57a659ed5178409c3d055108630 --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/versions/007_add_paused_status.py @@ -0,0 +1,28 @@ +"""add_paused_status. + +Revision ID: 007 +Revises: 006 +Create Date: 2025-02-23 20:25:36.729391 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "007" +down_revision: Union[str, None] = "006" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute("ALTER TYPE runstatus ADD VALUE IF NOT EXISTS 'PAUSED'") + op.execute("ALTER TYPE taskstatus ADD VALUE IF NOT EXISTS 'PAUSED'") + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + pass + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/008_add_user_session_message_models.py b/pyspur/backend/pyspur/models/management/alembic/versions/008_add_user_session_message_models.py new file mode 100644 index 0000000000000000000000000000000000000000..c2e27a5d75ef0cfce382a61b993954f60224be7e --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/versions/008_add_user_session_message_models.py @@ -0,0 +1,111 @@ +"""add_user_session_message_models. + +Revision ID: 008 +Revises: 007 +Create Date: 2025-03-11 07:02:13.799898 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "008" +down_revision: Union[str, None] = "007" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "users", + sa.Column("_intid", sa.Integer(), nullable=False), + sa.Column( + "id", + sa.String(), + sa.Computed( + "'U' || _intid", + ), + nullable=False, + ), + sa.Column("external_id", sa.String(), nullable=True), + sa.Column("user_metadata", sa.JSON(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("_intid"), + sa.UniqueConstraint("external_id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "sessions", + sa.Column("_intid", sa.Integer(), nullable=False), + sa.Column( + "id", + sa.String(), + sa.Computed( + "'SN' || _intid", + ), + nullable=False, + ), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("workflow_id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["workflow_id"], + ["workflows.id"], + ), + sa.PrimaryKeyConstraint("_intid"), + sa.UniqueConstraint("id"), + ) + op.create_index(op.f("ix_sessions_user_id"), "sessions", ["user_id"], unique=False) + op.create_index(op.f("ix_sessions_workflow_id"), "sessions", ["workflow_id"], unique=False) + op.create_table( + "messages", + sa.Column("_intid", sa.Integer(), nullable=False), + sa.Column( + "id", + sa.String(), + sa.Computed( + "'M' || _intid", + ), + nullable=False, + ), + sa.Column("session_id", sa.String(), nullable=False), + sa.Column("run_id", sa.String(), nullable=True), + sa.Column("content", sa.JSON(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["run_id"], + ["runs.id"], + ), + sa.ForeignKeyConstraint( + ["session_id"], + ["sessions.id"], + ), + sa.PrimaryKeyConstraint("_intid"), + sa.UniqueConstraint("id"), + ) + 
op.create_index(op.f("ix_messages_run_id"), "messages", ["run_id"], unique=False) + op.create_index(op.f("ix_messages_session_id"), "messages", ["session_id"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_messages_session_id"), table_name="messages") + op.drop_index(op.f("ix_messages_run_id"), table_name="messages") + op.drop_table("messages") + op.drop_index(op.f("ix_sessions_workflow_id"), table_name="sessions") + op.drop_index(op.f("ix_sessions_user_id"), table_name="sessions") + op.drop_table("sessions") + op.drop_table("users") + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/009_add_external_id_to_session.py b/pyspur/backend/pyspur/models/management/alembic/versions/009_add_external_id_to_session.py new file mode 100644 index 0000000000000000000000000000000000000000..2a0033ea0f19fb46dbf0361826a5b615f8892767 --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/versions/009_add_external_id_to_session.py @@ -0,0 +1,32 @@ +"""add_external_id_to_session. + +Revision ID: 009 +Revises: 008 +Create Date: 2025-03-12 04:32:15.946036 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "009" +down_revision: Union[str, None] = "008" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("sessions", sa.Column("external_id", sa.String(), nullable=True)) + op.create_unique_constraint("sessions_external_id_key", "sessions", ["external_id"]) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_constraint("sessions_external_id_key", "sessions", type_="unique") + op.drop_column("sessions", "external_id") + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/010_add_idx_to_time_cols.py b/pyspur/backend/pyspur/models/management/alembic/versions/010_add_idx_to_time_cols.py new file mode 100644 index 0000000000000000000000000000000000000000..887305a75de89b7e5f71ad27f6e55a718307cecd --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/versions/010_add_idx_to_time_cols.py @@ -0,0 +1,177 @@ +"""add_idx_to_time_cols. + +Revision ID: 010 +Revises: 009 +Create Date: 2025-03-18 05:49:31.113627 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "010" +down_revision: Union[str, None] = "009" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_index( + op.f("ix_datasets_uploaded_at"), + "datasets", + ["uploaded_at"], + unique=False, + if_not_exists=True, + ) + op.create_index( + op.f("ix_document_collections_created_at"), + "document_collections", + ["created_at"], + unique=False, + if_not_exists=True, + ) + op.create_index( + op.f("ix_document_collections_updated_at"), + "document_collections", + ["updated_at"], + unique=False, + if_not_exists=True, + ) + op.create_index( + op.f("ix_document_processing_progress_created_at"), + "document_processing_progress", + ["created_at"], + unique=False, + if_not_exists=True, + ) + op.create_index( + op.f("ix_document_processing_progress_updated_at"), + "document_processing_progress", + ["updated_at"], + unique=False, + if_not_exists=True, + ) + op.create_index( + op.f("ix_eval_runs_end_time"), "eval_runs", ["end_time"], unique=False, if_not_exists=True + ) + op.create_index( + op.f("ix_eval_runs_start_time"), + "eval_runs", + ["start_time"], + unique=False, + if_not_exists=True, + ) + op.create_index( + op.f("ix_messages_created_at"), "messages", ["created_at"], unique=False, if_not_exists=True + ) + op.create_index( + op.f("ix_messages_updated_at"), "messages", ["updated_at"], unique=False, if_not_exists=True + ) + op.create_index( + op.f("ix_output_files_created_at"), + "output_files", + ["created_at"], + unique=False, + if_not_exists=True, + ) + op.create_index( + op.f("ix_output_files_updated_at"), + "output_files", + ["updated_at"], + unique=False, + if_not_exists=True, + ) + op.create_index( + op.f("ix_runs_start_time"), "runs", ["start_time"], unique=False, if_not_exists=True + ) + op.create_index( + op.f("ix_sessions_created_at"), "sessions", ["created_at"], unique=False, if_not_exists=True + ) + op.create_index( + op.f("ix_sessions_updated_at"), "sessions", ["updated_at"], unique=False, if_not_exists=True + ) + op.create_index( + op.f("ix_users_created_at"), "users", ["created_at"], unique=False, if_not_exists=True + ) + op.create_index( + 
op.f("ix_users_updated_at"), "users", ["updated_at"], unique=False, if_not_exists=True + ) + op.create_index( + op.f("ix_vector_indices_created_at"), + "vector_indices", + ["created_at"], + unique=False, + if_not_exists=True, + ) + op.create_index( + op.f("ix_vector_indices_updated_at"), + "vector_indices", + ["updated_at"], + unique=False, + if_not_exists=True, + ) + op.create_index( + op.f("ix_workflow_versions_created_at"), + "workflow_versions", + ["created_at"], + unique=False, + if_not_exists=True, + ) + op.create_index( + op.f("ix_workflow_versions_updated_at"), + "workflow_versions", + ["updated_at"], + unique=False, + if_not_exists=True, + ) + op.create_index( + op.f("ix_workflows_created_at"), + "workflows", + ["created_at"], + unique=False, + if_not_exists=True, + ) + op.create_index( + op.f("ix_workflows_updated_at"), + "workflows", + ["updated_at"], + unique=False, + if_not_exists=True, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index(op.f("ix_workflows_updated_at"), table_name="workflows") + op.drop_index(op.f("ix_workflows_created_at"), table_name="workflows") + op.drop_index(op.f("ix_workflow_versions_updated_at"), table_name="workflow_versions") + op.drop_index(op.f("ix_workflow_versions_created_at"), table_name="workflow_versions") + op.drop_index(op.f("ix_vector_indices_updated_at"), table_name="vector_indices") + op.drop_index(op.f("ix_vector_indices_created_at"), table_name="vector_indices") + op.drop_index(op.f("ix_users_updated_at"), table_name="users") + op.drop_index(op.f("ix_users_created_at"), table_name="users") + op.drop_index(op.f("ix_sessions_updated_at"), table_name="sessions") + op.drop_index(op.f("ix_sessions_created_at"), table_name="sessions") + op.drop_index(op.f("ix_runs_start_time"), table_name="runs") + op.drop_index(op.f("ix_output_files_updated_at"), table_name="output_files") + op.drop_index(op.f("ix_output_files_created_at"), table_name="output_files") + op.drop_index(op.f("ix_messages_updated_at"), table_name="messages") + op.drop_index(op.f("ix_messages_created_at"), table_name="messages") + op.drop_index(op.f("ix_eval_runs_start_time"), table_name="eval_runs") + op.drop_index(op.f("ix_eval_runs_end_time"), table_name="eval_runs") + op.drop_index( + op.f("ix_document_processing_progress_updated_at"), + table_name="document_processing_progress", + ) + op.drop_index( + op.f("ix_document_processing_progress_created_at"), + table_name="document_processing_progress", + ) + op.drop_index(op.f("ix_document_collections_updated_at"), table_name="document_collections") + op.drop_index(op.f("ix_document_collections_created_at"), table_name="document_collections") + op.drop_index(op.f("ix_datasets_uploaded_at"), table_name="datasets") + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/011_slack_agent.py b/pyspur/backend/pyspur/models/management/alembic/versions/011_slack_agent.py new file mode 100644 
index 0000000000000000000000000000000000000000..4426c26d3f73be9faaa0886826837d4e899bda3d --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/versions/011_slack_agent.py @@ -0,0 +1,58 @@ +"""slack_agent. + +Revision ID: 011 +Revises: 010 +Create Date: 2025-03-16 15:09:40.938378 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "011" +down_revision: Union[str, None] = "010" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "slack_agents", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("slack_team_id", sa.String(), nullable=True), + sa.Column("slack_team_name", sa.String(), nullable=True), + sa.Column("slack_channel_id", sa.String(), nullable=True), + sa.Column("slack_channel_name", sa.String(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=True), + sa.Column("workflow_id", sa.String(), nullable=True), + sa.Column("trigger_on_mention", sa.Boolean(), nullable=True), + sa.Column("trigger_on_direct_message", sa.Boolean(), nullable=True), + sa.Column("trigger_on_channel_message", sa.Boolean(), nullable=True), + sa.Column("trigger_keywords", sa.JSON(), nullable=True), + sa.Column("trigger_enabled", sa.Boolean(), nullable=True), + sa.ForeignKeyConstraint( + ["workflow_id"], + ["workflows.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_slack_agents_id"), "slack_agents", ["id"], unique=False) + op.create_index(op.f("ix_slack_agents_name"), "slack_agents", ["name"], unique=False) + op.create_index( + op.f("ix_slack_agents_slack_team_id"), "slack_agents", ["slack_team_id"], unique=False + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated 
by Alembic - please adjust! ### + op.drop_index(op.f("ix_slack_agents_slack_team_id"), table_name="slack_agents") + op.drop_index(op.f("ix_slack_agents_name"), table_name="slack_agents") + op.drop_index(op.f("ix_slack_agents_id"), table_name="slack_agents") + op.drop_table("slack_agents") + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/012_add_slack_tokens.py b/pyspur/backend/pyspur/models/management/alembic/versions/012_add_slack_tokens.py new file mode 100644 index 0000000000000000000000000000000000000000..1baa159738c1380cbacd15c689fd366a0d91f6b3 --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/versions/012_add_slack_tokens.py @@ -0,0 +1,34 @@ +"""add_slack_tokens + +Revision ID: 012 +Revises: 011 +Create Date: 2025-03-18 18:10:53.774095 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '012' +down_revision: Union[str, None] = '011' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('slack_agents', sa.Column('has_bot_token', sa.Boolean(), nullable=True)) + op.add_column('slack_agents', sa.Column('has_user_token', sa.Boolean(), nullable=True)) + op.add_column('slack_agents', sa.Column('last_token_update', sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('slack_agents', 'last_token_update') + op.drop_column('slack_agents', 'has_user_token') + op.drop_column('slack_agents', 'has_bot_token') + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/013_add_date_to_slack_agent.py b/pyspur/backend/pyspur/models/management/alembic/versions/013_add_date_to_slack_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..8d0e9bb1b9b8d3a3888ab8768a7b954533eb90a9 --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/versions/013_add_date_to_slack_agent.py @@ -0,0 +1,32 @@ +"""add_date_to_slack_agent + +Revision ID: 013 +Revises: 012 +Create Date: 2025-03-18 18:27:04.191108 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '013' +down_revision: Union[str, None] = '012' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('slack_agents', sa.Column('spur_type', sa.String(), nullable=True)) + op.add_column('slack_agents', sa.Column('created_at', sa.DateTime(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('slack_agents', 'created_at') + op.drop_column('slack_agents', 'spur_type') + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/014_add_slack_app_token_flag.py b/pyspur/backend/pyspur/models/management/alembic/versions/014_add_slack_app_token_flag.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc3f548c081fd825dc407f47a6159ca0487286e --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/versions/014_add_slack_app_token_flag.py @@ -0,0 +1,30 @@ +"""add_slack_app_token_flag + +Revision ID: 014 +Revises: 013 +Create Date: 2025-03-20 00:07:50.709065 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '014' +down_revision: Union[str, None] = '013' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('slack_agents', sa.Column('has_app_token', sa.Boolean(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('slack_agents', 'has_app_token') + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/management/alembic/versions/015_add_persistent_socket_mode.py b/pyspur/backend/pyspur/models/management/alembic/versions/015_add_persistent_socket_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..e4333ea49f4a76a8cdbb80b1032695b6948610cb --- /dev/null +++ b/pyspur/backend/pyspur/models/management/alembic/versions/015_add_persistent_socket_mode.py @@ -0,0 +1,41 @@ +"""add_persistent_socket_mode + +Revision ID: 015 +Revises: 014 +Create Date: 2025-03-22 11:32:40.009239 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "015" +down_revision: Union[str, None] = "014" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "slack_agents", + sa.Column("socket_mode_enabled", sa.Boolean(), nullable=True, server_default="false"), + ) + + # Set default value for existing records + op.execute( + "UPDATE slack_agents SET socket_mode_enabled = false WHERE socket_mode_enabled IS NULL" + ) + + # Make the column not nullable after setting default values + op.alter_column("slack_agents", "socket_mode_enabled", nullable=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("slack_agents", "socket_mode_enabled") + # ### end Alembic commands ### diff --git a/pyspur/backend/pyspur/models/output_file_model.py b/pyspur/backend/pyspur/models/output_file_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6aa8d9bbd08ce3f188b7e872d393b171df57ecef --- /dev/null +++ b/pyspur/backend/pyspur/models/output_file_model.py @@ -0,0 +1,31 @@ +from datetime import datetime, timezone + +from sqlalchemy import Computed, DateTime, Integer, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .base_model import BaseModel + + +class OutputFileModel(BaseModel): + __tablename__ = "output_files" + + _intid: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement="auto") + id: Mapped[str] = mapped_column(String, Computed("'OF' || _intid"), nullable=False, unique=True) + file_name: Mapped[str] = mapped_column(String, nullable=False) + file_path: Mapped[str] = mapped_column(String, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, default=lambda: datetime.now(timezone.utc), index=True + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + default=lambda: datetime.now(timezone.utc), # callable: evaluated per INSERT + onupdate=lambda: datetime.now(timezone.utc), # callable: evaluated per UPDATE + index=True, + ) + + run = relationship( + "RunModel", + back_populates="output_file", + single_parent=True, + cascade="all, delete-orphan", + ) diff --git a/pyspur/backend/pyspur/models/run_model.py b/pyspur/backend/pyspur/models/run_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3983a4c6954dd62923bf359e7353a72775622a80 --- /dev/null +++ b/pyspur/backend/pyspur/models/run_model.py @@ -0,0 +1,91 @@ +from datetime import datetime, timezone +from enum import Enum as PyEnum +from typing import Any, Dict, List, Optional + +from sqlalchemy import ( + JSON, + Computed, + DateTime, + Enum, + ForeignKey, + Integer, + String, +) +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .base_model import 
BaseModel +from .output_file_model import OutputFileModel +from .task_model import TaskModel +from .workflow_model import WorkflowModel + + +class RunStatus(PyEnum): + PENDING = "PENDING" + RUNNING = "RUNNING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + PAUSED = "PAUSED" # Added for human intervention nodes + CANCELED = "CANCELED" # Added for canceling workflows awaiting human approval + + +class RunModel(BaseModel): + __tablename__ = "runs" + + _intid: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement="auto") + id: Mapped[str] = mapped_column(String, Computed("'R' || _intid"), nullable=False, unique=True) + workflow_id: Mapped[str] = mapped_column( + String, ForeignKey("workflows.id"), nullable=False, index=True + ) + workflow_version_id: Mapped[int] = mapped_column( + String, ForeignKey("workflow_versions.id"), nullable=False, index=True + ) + parent_run_id: Mapped[Optional[str]] = mapped_column( + String, ForeignKey("runs.id"), nullable=True, index=True + ) + status: Mapped[RunStatus] = mapped_column( + Enum(RunStatus), default=RunStatus.PENDING, nullable=False + ) + run_type: Mapped[str] = mapped_column(String, nullable=False) + initial_inputs: Mapped[Optional[Dict[str, Dict[str, Any]]]] = mapped_column(JSON, nullable=True) + input_dataset_id: Mapped[Optional[str]] = mapped_column( + String, ForeignKey("datasets.id"), nullable=True, index=True + ) + start_time: Mapped[Optional[datetime]] = mapped_column( + DateTime, default=datetime.now(timezone.utc), index=True + ) + end_time: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + outputs: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True) + output_file_id: Mapped[Optional[str]] = mapped_column( + String, ForeignKey("output_files.id"), nullable=True + ) + tasks: Mapped[List["TaskModel"]] = relationship("TaskModel", cascade="all, delete-orphan") + parent_run: Mapped[Optional["RunModel"]] = relationship( + "RunModel", + remote_side=[id], + back_populates="subruns", + 
) + subruns: Mapped[List["RunModel"]] = relationship( + "RunModel", back_populates="parent_run", cascade="all, delete-orphan" + ) + output_file: Mapped[Optional["OutputFileModel"]] = relationship( + "OutputFileModel", back_populates="run" + ) + workflow: Mapped["WorkflowModel"] = relationship("WorkflowModel", foreign_keys=[workflow_id]) + + @property + def percentage_complete(self) -> Optional[float]: + if self.status == RunStatus.PENDING: + return 0.0 + elif self.status == RunStatus.COMPLETED: + return 1.0 + elif self.status == RunStatus.FAILED: + return 0.0 + elif self.initial_inputs: + return 0.5 + elif self.input_dataset_id: + # return percentage of subruns completed + return ( + 1.0 + * len([subrun for subrun in self.subruns if subrun.status == RunStatus.COMPLETED]) + / (1.0 * len(self.subruns)) + ) diff --git a/pyspur/backend/pyspur/models/slack_agent_model.py b/pyspur/backend/pyspur/models/slack_agent_model.py new file mode 100644 index 0000000000000000000000000000000000000000..006937aeee01c2ba6a3c346b189e64f8319b25ca --- /dev/null +++ b/pyspur/backend/pyspur/models/slack_agent_model.py @@ -0,0 +1,108 @@ +import os +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy import JSON, Boolean, Column, DateTime, ForeignKey, Integer, String +from sqlalchemy.orm import relationship + +from .base_model import BaseModel + + +class SlackAgentModel(BaseModel): + """Model for storing Slack agent configurations.""" + + __tablename__ = "slack_agents" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, index=True) + slack_team_id = Column(String, index=True) + slack_team_name = Column(String) + slack_channel_id = Column(String) + slack_channel_name = Column(String) + is_active = Column(Boolean, default=True) + + # Type of Spur agent + spur_type = Column(String, default="workflow") + + # Workflow association + workflow_id = Column(String, ForeignKey("workflows.id"), nullable=True) + workflow = relationship("WorkflowModel", 
backref="slack_agents") + + # Token reference - we don't store actual tokens here + has_bot_token = Column(Boolean, default=False) + has_user_token = Column(Boolean, default=False) + has_app_token = Column(Boolean, default=False) + last_token_update = Column(String, nullable=True) + + # Trigger configuration + trigger_on_mention = Column(Boolean, default=True) + trigger_on_direct_message = Column(Boolean, default=True) + trigger_on_channel_message = Column(Boolean, default=False) + trigger_keywords = Column(JSON, default=list) + trigger_enabled = Column(Boolean, default=True) + + # Socket Mode configuration + socket_mode_enabled = Column(Boolean, default=False) + + # Creation timestamp + created_at = Column(DateTime, default=lambda: datetime.now(UTC)) + + @property + def has_required_tokens(self) -> bool: + """Check if the agent has the required tokens for basic operation.""" + return bool(getattr(self, "has_bot_token", False)) + + @property + def has_socket_mode_tokens(self) -> bool: + """Check if the agent has the tokens required for Socket Mode.""" + return bool(getattr(self, "has_bot_token", False)) and ( + bool(getattr(self, "has_app_token", False)) or bool(os.getenv("SLACK_APP_TOKEN")) + ) + + def update_token_flags(self, token_type: str, has_token: bool) -> None: + """Update token flags based on token type.""" + if token_type == "bot_token": + self.has_bot_token = has_token + elif token_type == "user_token": + self.has_user_token = has_token + elif token_type == "app_token": + self.has_app_token = has_token + self.last_token_update = datetime.now(UTC).isoformat() + + def set_field(self, field: str, value: Any) -> None: + """Set a field value in a type-safe way. 
+ + Args: + field: The name of the field to set + value: The value to set the field to + + """ + if not hasattr(self, field): + raise ValueError(f"Invalid field name: {field}") + + # Get the SQLAlchemy Column type + column = self.__table__.columns.get(field) + if column is None: + raise ValueError(f"Field {field} is not a database column") + + # Convert the value to the correct type based on the column type + if isinstance(column.type, Boolean): + value = bool(value) + elif isinstance(column.type, String): + value = str(value) if value is not None else None + elif isinstance(column.type, Integer): + value = int(value) if value is not None else None + elif isinstance(column.type, JSON): + # JSON fields can accept any JSON-serializable value + pass + + # Use the internal SQLAlchemy setter + setattr(self, field, value) + + def get_id(self) -> int: + """Get the agent ID as a Python int.""" + return 0 if getattr(self, "id", None) is None else int(str(self.id)) + + def get_workflow_id(self) -> str: + """Get the workflow ID as a Python string.""" + return "" if getattr(self, "workflow_id", None) is None else str(self.workflow_id) diff --git a/pyspur/backend/pyspur/models/task_model.py b/pyspur/backend/pyspur/models/task_model.py new file mode 100644 index 0000000000000000000000000000000000000000..877b4f5cd76bc1d6f5ad645c69c3da7b6f73a5ac --- /dev/null +++ b/pyspur/backend/pyspur/models/task_model.py @@ -0,0 +1,62 @@ +from datetime import datetime, timezone +from enum import Enum as PyEnum +from typing import Any, Optional + +from sqlalchemy import ( + JSON, + Computed, + DateTime, + Enum, + ForeignKey, + Integer, + String, +) +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .base_model import BaseModel + + +class TaskStatus(PyEnum): + PENDING = "PENDING" + RUNNING = "RUNNING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + CANCELED = "CANCELED" + PAUSED = "PAUSED" + + +class TaskModel(BaseModel): + __tablename__ = "tasks" + + _intid: Mapped[int] = 
mapped_column(Integer, primary_key=True, autoincrement="auto") + id: Mapped[str] = mapped_column(String, Computed("'T' || _intid"), nullable=False, unique=True) + run_id: Mapped[str] = mapped_column(String, ForeignKey("runs.id"), nullable=False) + node_id: Mapped[str] = mapped_column(String, nullable=False) + parent_task_id: Mapped[Optional[str]] = mapped_column( + String, ForeignKey("tasks.id"), nullable=True + ) + status: Mapped[TaskStatus] = mapped_column( + Enum(TaskStatus), default=TaskStatus.PENDING, nullable=False + ) + inputs: Mapped[Any] = mapped_column(JSON, nullable=True) + outputs: Mapped[Any] = mapped_column(JSON, nullable=True) + error: Mapped[Optional[str]] = mapped_column(String, nullable=True) + start_time: Mapped[Optional[datetime]] = mapped_column( + DateTime, default=datetime.now(timezone.utc) + ) + end_time: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + subworkflow: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True) + subworkflow_output: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True) + + # Relationships + parent_task = relationship("TaskModel", remote_side=[id], back_populates="subtasks") + subtasks = relationship("TaskModel", back_populates="parent_task", cascade="all, delete-orphan") + + @property + def run_time(self) -> Optional[float]: + if self.start_time and self.end_time: + return (self.end_time - self.start_time).total_seconds() + elif self.start_time: + return (datetime.now() - self.start_time).total_seconds() + else: + return None diff --git a/pyspur/backend/pyspur/models/user_session_model.py b/pyspur/backend/pyspur/models/user_session_model.py new file mode 100644 index 0000000000000000000000000000000000000000..205b7b5bdd63548851ef91a12c30c7ee07916b19 --- /dev/null +++ b/pyspur/backend/pyspur/models/user_session_model.py @@ -0,0 +1,98 @@ +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from sqlalchemy import JSON, Computed, DateTime, ForeignKey, 
Integer, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .base_model import BaseModel +from .workflow_model import WorkflowModel + + +class UserModel(BaseModel): + __tablename__ = "users" + + _intid: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement="auto") + id: Mapped[str] = mapped_column(String, Computed("'U' || _intid"), nullable=False, unique=True) + external_id: Mapped[str] = mapped_column(String, nullable=True, unique=True) + user_metadata: Mapped[Dict[str, Any]] = mapped_column(JSON, nullable=False, default=dict) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, default=datetime.now(timezone.utc), index=True + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + default=datetime.now(timezone.utc), + onupdate=datetime.now(timezone.utc), + index=True, + ) + + # Relationship to sessions, ordered by most recent first + sessions: Mapped[List["SessionModel"]] = relationship( + "SessionModel", + back_populates="user", + order_by="desc(SessionModel.created_at)", + cascade="all, delete-orphan", + ) + + +class SessionModel(BaseModel): + __tablename__ = "sessions" + + _intid: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement="auto") + id: Mapped[str] = mapped_column(String, Computed("'SN' || _intid"), nullable=False, unique=True) + external_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, unique=True) + user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"), nullable=False, index=True) + workflow_id: Mapped[str] = mapped_column( + String, ForeignKey("workflows.id"), nullable=False, index=True + ) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, default=datetime.now(timezone.utc), index=True + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + default=datetime.now(timezone.utc), + onupdate=datetime.now(timezone.utc), + index=True, + ) + + # Relationship to user 
+ user: Mapped["UserModel"] = relationship("UserModel", back_populates="sessions") + + # Relationship to workflow + workflow: Mapped["WorkflowModel"] = relationship("WorkflowModel") + + # Relationship to messages, ordered chronologically + messages: Mapped[List["MessageModel"]] = relationship( + "MessageModel", + back_populates="session", + order_by="MessageModel.created_at", + cascade="all, delete-orphan", + ) + + +class MessageModel(BaseModel): + __tablename__ = "messages" + + _intid: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement="auto") + id: Mapped[str] = mapped_column(String, Computed("'M' || _intid"), nullable=False, unique=True) + session_id: Mapped[str] = mapped_column( + String, ForeignKey("sessions.id"), nullable=False, index=True + ) + run_id: Mapped[Optional[str]] = mapped_column( + String, ForeignKey("runs.id"), nullable=True, index=True + ) + content: Mapped[Dict[str, Any]] = mapped_column(JSON, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, default=datetime.now(timezone.utc), index=True + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + default=datetime.now(timezone.utc), + onupdate=datetime.now(timezone.utc), + index=True, + ) + + # Relationship to session + session: Mapped["SessionModel"] = relationship("SessionModel", back_populates="messages") diff --git a/pyspur/backend/pyspur/models/workflow_model.py b/pyspur/backend/pyspur/models/workflow_model.py new file mode 100644 index 0000000000000000000000000000000000000000..faa62091ca7561ec0e37de2cd23e2b2b575cf0c2 --- /dev/null +++ b/pyspur/backend/pyspur/models/workflow_model.py @@ -0,0 +1,39 @@ +from datetime import datetime, timezone +from typing import Any, Optional + +from sqlalchemy import JSON, Computed, DateTime, Integer, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .base_model import BaseModel + + +class WorkflowModel(BaseModel): + """Represents a workflow in 
the system. + + A version of the workflow is created only when the workflow is run. + The latest or current version of the workflow is always stored in the + WorkflowModel itself, while specific versions are managed separately. + """ + + __tablename__ = "workflows" + + _intid: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement="auto") + id: Mapped[str] = mapped_column(String, Computed("'S' || _intid"), nullable=False, unique=True) + name: Mapped[str] = mapped_column(String, unique=True, nullable=False) + description: Mapped[Optional[str]] = mapped_column(String) + definition: Mapped[Any] = mapped_column(JSON, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.now(timezone.utc), index=True + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + default=datetime.now(timezone.utc), + onupdate=datetime.now(timezone.utc), + index=True, + ) + + versions = relationship( + "WorkflowVersionModel", + back_populates="workflow", + cascade="all, delete-orphan", + ) diff --git a/pyspur/backend/pyspur/models/workflow_version_model.py b/pyspur/backend/pyspur/models/workflow_version_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e4cb7c886bf7fedc53668462d946888bbb073e09 --- /dev/null +++ b/pyspur/backend/pyspur/models/workflow_version_model.py @@ -0,0 +1,37 @@ +from datetime import datetime, timezone +from typing import Any, List, Optional + +from sqlalchemy import JSON, Computed, DateTime, ForeignKey, Integer, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .base_model import BaseModel +from .run_model import RunModel + + +class WorkflowVersionModel(BaseModel): + __tablename__ = "workflow_versions" + + _intid: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement="auto") + id: Mapped[str] = mapped_column(String, Computed("'SV' || _intid"), nullable=False, unique=True) + version: Mapped[int] = mapped_column(Integer, nullable=False, 
index=True) + workflow_id: Mapped[int] = mapped_column(ForeignKey("workflows.id"), nullable=False, index=True) + name: Mapped[str] = mapped_column(String, nullable=False) + description: Mapped[Optional[str]] = mapped_column(String) + definition: Mapped[Any] = mapped_column(JSON, nullable=False) + definition_hash: Mapped[str] = mapped_column(String, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.now(timezone.utc), index=True + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + default=datetime.now(timezone.utc), + onupdate=datetime.now(timezone.utc), + index=True, + ) + + # Relationships + workflow = relationship("WorkflowModel", back_populates="versions") + + runs: Mapped[Optional[List["RunModel"]]] = relationship( + "RunModel", backref="workflow_version", cascade="all, delete-orphan" + ) diff --git a/pyspur/backend/pyspur/nodes/__init__.py b/pyspur/backend/pyspur/nodes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyspur/backend/pyspur/nodes/__pycache__/__init__.cpython-312.pyc b/pyspur/backend/pyspur/nodes/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2b54e19a6eff738154d405c69c7fdcc239a6725 Binary files /dev/null and b/pyspur/backend/pyspur/nodes/__pycache__/__init__.cpython-312.pyc differ diff --git a/pyspur/backend/pyspur/nodes/__pycache__/base.cpython-312.pyc b/pyspur/backend/pyspur/nodes/__pycache__/base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55a717e88fe6f1ffc7d783af928c1e6df80142df Binary files /dev/null and b/pyspur/backend/pyspur/nodes/__pycache__/base.cpython-312.pyc differ diff --git a/pyspur/backend/pyspur/nodes/__pycache__/decorator.cpython-312.pyc b/pyspur/backend/pyspur/nodes/__pycache__/decorator.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..90436f35a5096b613595eb3d3dca267641de4a2e Binary files /dev/null and b/pyspur/backend/pyspur/nodes/__pycache__/decorator.cpython-312.pyc differ diff --git a/pyspur/backend/pyspur/nodes/base.py b/pyspur/backend/pyspur/nodes/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e80bb303bc178678adad2b92b652015aaff1337f --- /dev/null +++ b/pyspur/backend/pyspur/nodes/base.py @@ -0,0 +1,379 @@ +import json +from abc import ABC, abstractmethod +from hashlib import md5 +from typing import Any, Dict, List, Optional, Type, cast + +from pydantic import BaseModel, Field, create_model + +from ..execution.workflow_execution_context import WorkflowExecutionContext +from ..schemas.workflow_schemas import WorkflowDefinitionSchema +from ..utils import pydantic_utils + + +class VisualTag(BaseModel): + """Pydantic model for visual tag properties.""" + + acronym: str = Field(...) + color: str = Field( + ..., pattern=r"^#(?:[0-9a-fA-F]{3}){1,2}$" + ) # Hex color code validation using regex + + +class BaseNodeConfig(BaseModel): + """Base class for node configuration models. + + Each node must define its output_schema. + """ + + output_schema: Dict[str, str] = Field( + default={"output": "string"}, + title="Output schema", + description="The schema for the output of the node", + ) + output_json_schema: str = Field( + default='{"type": "object", "properties": {"output": {"type": "string"} } }', + title="Output JSON schema", + description="The JSON schema for the output of the node", + ) + has_fixed_output: bool = Field( + default=False, + description="Whether the node has a fixed output schema defined in config", + ) + model_config = { + "extra": "allow", + } + + +class BaseNodeOutput(BaseModel): + """Base class for all node outputs. + + Each node type will define its own output model that inherits from this. + """ + + pass + + +class BaseNodeInput(BaseModel): + """Base class for node inputs. 
+ + Each node's input model will be dynamically created based on its predecessor nodes, + with fields named after node IDs and types being the corresponding NodeOutputModels. + """ + + pass + + +class BaseNode(ABC): + """Base class for all nodes. + + Each node receives inputs as a Pydantic model where: + - Field names are predecessor node IDs + - Field types are the corresponding NodeOutputModels + """ + + name: str = "" + display_name: str = "" + category: Optional[str] = None + subcategory: Optional[str] = None + logo: Optional[str] = None + config_model: Type[BaseNodeConfig] + output_model: Type[BaseNodeOutput] + input_model: Type[BaseNodeInput] + _config: BaseNodeConfig + _input: BaseNodeInput + _output: BaseNodeOutput + visual_tag: VisualTag + subworkflow: Optional[WorkflowDefinitionSchema] + subworkflow_output: Optional[Dict[str, Any]] + + def __init__( + self, + name: str, + config: BaseNodeConfig, + context: Optional[WorkflowExecutionContext] = None, + ) -> None: + self.name = name + self._config = config + self.context = context + self.subworkflow = None + self.subworkflow_output = None + if not hasattr(self, "visual_tag"): + self.visual_tag = self.get_default_visual_tag() + self.setup() + + def setup(self) -> None: + """Define output_model and any other initialization. + + For dynamic schema nodes, these can be created based on self.config. 
+ """ + if self._config.has_fixed_output: + schema = json.loads(self._config.output_json_schema) + model = pydantic_utils.json_schema_to_model( + schema, model_class_name=self.name, base_class=BaseNodeOutput + ) + self.output_model = model # type: ignore + + def create_output_model_class(self, output_schema: Dict[str, str]) -> Type[BaseNodeOutput]: + """Dynamically creates an output model based on the node's output schema.""" + field_type_to_python_type = { + "string": str, + "str": str, + "integer": int, + "int": int, + "number": float, + "float": float, + "boolean": bool, + "bool": bool, + "list": list, + "dict": dict, + "array": list, + "object": dict, + } + return create_model( + f"{self.name}", + **{ + field_name: ( + (field_type_to_python_type[field_type], ...) + if field_type in field_type_to_python_type + else (field_type, ...) # try as is + ) + for field_name, field_type in output_schema.items() + }, + __base__=BaseNodeOutput, + __config__=None, + __doc__=f"Output model for {self.name} node", + __module__=self.__module__, + __validators__=None, + __cls_kwargs__=None, + ) + + def create_composite_model_instance( + self, model_name: str, instances: Dict[str, BaseModel] + ) -> Type[BaseNodeInput]: + """Create a new Pydantic model that combines all the given models based on their instances. + + Args: + model_name: The name of the new model. + instances: A dictionary of Pydantic model instances. + + Returns: + A new Pydantic model with fields named after the keys of the dictionary. + + """ + # Create the new model class + return create_model( + model_name, + **{key: (instance.__class__, ...) 
for key, instance in instances.items()}, + __base__=BaseNodeInput, + __config__=None, + __doc__=f"Input model for {self.name} node", + __module__=self.__module__, + __validators__=None, + __cls_kwargs__=None, + ) + + async def __call__( + self, + input: ( + Dict[str, str | int | bool | float | Dict[str, Any] | List[Any]] + | Dict[str, BaseNodeOutput] + | Dict[str, BaseNodeInput] + | BaseNodeInput + ), + ) -> BaseNodeOutput: + """Validate inputs and run the node's logic. + + Args: + input: Pydantic model containing predecessor + outputs or a Dict[str, NodeOutputModel] + + Returns: + The node's output model + + """ + if isinstance(input, dict): + if all(isinstance(value, BaseNodeOutput) for value in input.values()) or all( + isinstance(value, BaseNodeInput) for value in input.values() + ): + # Input is a dictionary of BaseNodeOutput or BaseNodeInput instances, + # creating a composite model + composite_inputs: Dict[str, BaseModel] = cast(Dict[str, BaseModel], input) + self.input_model = self.create_composite_model_instance( + model_name=self.input_model.__name__, + instances=composite_inputs, # preserve original keys + ) + data: Dict[str, Any] = {} + for key, value in composite_inputs.items(): + data[key] = value.model_dump() + input = self.input_model.model_validate(data) + else: + # Input is a dictionary of primitive types + self.input_model = pydantic_utils.create_model( + f"{self.name}Input", + **{field_name: (type(value), value) for field_name, value in input.items()}, + __base__=BaseNodeInput, + __config__=None, + __doc__=f"Input model for {self.name} node", + __module__=self.__module__, + __validators__=None, + __cls_kwargs__=None, + ) + input = self.input_model.model_validate(input) + + self._input = input + result = await self.run(input) + + try: + output_validated = self.output_model.model_validate(result.model_dump()) + except AttributeError: + output_validated = self.output_model.model_validate(result) + except Exception as e: + # Print the result for 
better debuggability + try: + result_dump = result.model_dump() if hasattr(result, "model_dump") else result + print(f"Validation failed for node {self.name}. Result: {result_dump}") + except Exception as dump_error: + print( + f"Validation failed for node {self.name}. Could not dump result: {dump_error}" + ) + print(f"Result type: {type(result)}") + raise ValueError(f"Output validation error in {self.name}: {e}") from e + + self._output = output_validated + return output_validated + + @abstractmethod + async def run(self, input: BaseModel) -> BaseModel: + """Abstract method where the node's core logic is implemented. + + Args: + input: Pydantic model containing predecessor outputs + + Returns: + An instance compatible with output_model + + """ + pass + + @property + def config(self) -> Any: + """Return the node's configuration.""" + return self.config_model.model_validate(self._config.model_dump()) + + @property + def function_schema(self) -> Dict[str, Any]: + """Return the node's function schema. + + Converts the config model's schema into a function schema format where + config fields become function parameters. If has_fixed_output is true, + both it and output_json_schema are excluded from the parameters. 
+ """ + config_schema = self.config_model.model_json_schema() + + # Get description from the node's docstring if available + description = self.__class__.__doc__ or config_schema.get( + "description", f"Function schema for {self.name}" + ) + # Clean up the docstring by removing extra whitespace and newlines + description = " ".join(line.strip() for line in description.split("\n")).strip() + + # if has_fixed_output is true then no need to include it in the function schema + # and also remove output_json_schema from the parameters + properties = config_schema.get("properties", {}) + if properties.get("has_fixed_output", {}).get("default", False): + properties = { + k: v + for k, v in properties.items() + if k not in ["has_fixed_output", "output_json_schema"] + } + # Also remove from required if present + required = [ + r + for r in config_schema.get("required", []) + if r not in ["has_fixed_output", "output_json_schema"] + ] + else: + required = config_schema.get("required", []) + + # Create function schema + function_schema = { + "type": "function", + "function": { + "name": self.name, + "description": description, + "parameters": { + "type": "object", + "properties": properties, + "required": required, + }, + }, + } + + # If there are any definitions in the original schema, preserve them + if "definitions" in config_schema: + function_schema["definitions"] = config_schema["definitions"] + + return function_schema + + async def call_as_tool(self, arguments: Dict[str, Any]) -> Any: + """Call the node as a tool with the given arguments. 
+ + Args: + arguments: The arguments to pass to the node + + """ + # generate the config model from the arguments + config_model = self.config_model.model_validate(arguments) + + # create a new instance of the node with the config model + node_instance = self.__class__(self.name, config_model, self.context) + # run the node with the input model + input_model = self.input_model.model_validate(arguments) + node_instance._input = input_model + return await node_instance.run(input_model) + + def update_config(self, config: BaseNodeConfig) -> None: + """Update the node's configuration.""" + self._config = config + + @property + def input(self) -> Any: + """Return the node's input.""" + return self.input_model.model_validate(self._input.model_dump()) + + @property + def output(self) -> Any: + """Return the node's output.""" + return self.output_model.model_validate(self._output.model_dump()) + + @classmethod + def get_default_visual_tag(cls) -> VisualTag: + """Set a default visual tag for the node.""" + # default acronym is the first letter of each word in the node name + acronym = "".join([word[0] for word in cls.name.split("_")]).upper() + + # default color is randomly picked from a list of pastel colors + colors = [ + "#007BFF", # Electric Blue + "#28A745", # Emerald Green + "#FFC107", # Sunflower Yellow + "#DC3545", # Crimson Red + "#6F42C1", # Royal Purple + "#FD7E14", # Bright Orange + "#20C997", # Teal + "#E83E8C", # Hot Pink + "#17A2B8", # Cyan + "#6610F2", # Indigo + "#8CC63F", # Lime Green + "#FF00FF", # Magenta + "#FFD700", # Gold + "#FF7F50", # Coral + "#40E0D0", # Turquoise + "#00BFFF", # Deep Sky Blue + "#FF5522", # Orange + "#FA8072", # Salmon + "#8A2BE2", # Violet + ] + color = colors[int(md5(cls.__name__.encode()).hexdigest(), 16) % len(colors)] + + return VisualTag(acronym=acronym, color=color) diff --git a/pyspur/backend/pyspur/nodes/decorator.py b/pyspur/backend/pyspur/nodes/decorator.py new file mode 100644 index 
0000000000000000000000000000000000000000..0482d26f2bb63b8ddf15b6fc0ef499f85b9ee851 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/decorator.py @@ -0,0 +1,426 @@ +import inspect +import json +from typing import ( + Any, + Callable, + Dict, + Optional, + Protocol, + Set, + Type, + cast, + get_type_hints, + runtime_checkable, +) + +from jinja2 import Template +from pydantic import BaseModel, Field, create_model + +from ..execution.workflow_execution_context import WorkflowExecutionContext +from ..utils.pydantic_utils import json_schema_to_model +from .base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput, VisualTag + + +class FunctionToolNode(BaseNode): + """Node class for function-based tools. + + This class is used to wrap Python functions as PySpur nodes. It handles parameter extraction, + template rendering, and function execution. + """ + + name: str + display_name: str + config_model: Type[BaseNodeConfig] + output_model: Type[BaseNodeOutput] + input_model: Type[BaseNodeInput] + function_param_names: Set[str] + is_output_model_defined: bool + _func: Callable[..., Any] + _visual_tag: Optional[Dict[str, str]] + + def __init__( + self, + name: str, + config: Optional[BaseNodeConfig] = None, + context: Optional[WorkflowExecutionContext] = None, + func: Optional[Callable[..., Any]] = None, + visual_tag: Optional[Dict[str, str]] = None, + ): + # Create default config if none provided + if config is None: + config = self.config_model() + + # Call parent init first + super().__init__(name=name, config=config, context=context) + + # Store the function and visual tag + if func is not None: + self._func = func + if visual_tag: + self.visual_tag = VisualTag(**visual_tag) + self._visual_tag = visual_tag + + async def run(self, input: BaseModel) -> BaseModel: + # Extract parameters from config directly using the stored parameter names + # This is more efficient than checking sig.parameters each time + kwargs: Dict[str, Any] = {} + + for param_name in 
self.function_param_names: + if hasattr(self.config, param_name): + kwargs[param_name] = getattr(self.config, param_name) + + # config values can be jinja2 templates so we need to render them + for param_name, param_value in kwargs.items(): + if isinstance(param_value, str): + template = Template(param_value) + kwargs[param_name] = template.render(input=input) + + # Call the original function + result = self._func(**kwargs) + + # Handle async functions + if hasattr(result, "__await__"): + result = await result + + if self.is_output_model_defined: + return self.output_model.model_validate(result) + else: + return self.output_model.model_validate({"output": result}) + + +@runtime_checkable +class ToolFunction(Protocol): + """Protocol for functions decorated with @tool.""" + + node_class: Type[BaseNode] + config_model: Type[BaseNodeConfig] + output_model: Type[BaseNodeOutput] + func_name: str + + def create_node( + self, + name: str = ..., + config: Optional[BaseNodeConfig] = None, + context: Optional[WorkflowExecutionContext] = None, + ) -> BaseNode: ... + + def __call__(self, *args: Any, **kwargs: Any) -> Any: ... + + +def tool_function( + name: Optional[str] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + category: Optional[str] = None, + visual_tag: Optional[Dict[str, str]] = None, + has_fixed_output: bool = True, + output_model: Optional[Type[BaseNodeOutput]] = None, + **tool_config: Any, +) -> Callable[[Callable[..., Any]], ToolFunction]: + """Register a function as a Tool. 
+ + Args: + name: Optional name for the tool (defaults to function name) + display_name: Optional display name for the UI + description: Optional description (defaults to function docstring) + category: Optional category for organizing tools in the UI + visual_tag: Optional visual styling for the UI + has_fixed_output: Whether the tool has fixed output schema + output_model: Optional custom output model to use instead of generating one + **tool_config: Additional configuration parameters for the tool + + Returns: + Decorated function that can still be called normally + + """ + + def decorator(func: Callable[..., Any]) -> ToolFunction: + # Get function metadata + func_name = name or func.__name__ + func_display_name = display_name or func_name.replace("_", " ").title() + func_doc = description or func.__doc__ or "" + + # Get type hints for input/output + type_hints = get_type_hints(func) + return_type = type_hints.get("return", Any) + + # Get function signature for default values + sig = inspect.signature(func) + + # Create a simple input model - inputs will be determined by workflow position + input_model: Type[BaseNodeInput] = create_model( + f"{func_name}Input", + __base__=BaseNodeInput, + __module__=func.__module__, + __validators__={}, + __cls_kwargs__={}, + __config__=None, + __doc__=f"Input model for {func_name}", + ) + + is_output_model_defined: bool = False + # Use provided output_model if available, otherwise create one + if output_model is not None: + # Use the provided output model + _output_model = output_model + is_output_model_defined = True + else: + # Create output model from function signature + if isinstance(return_type, type) and issubclass(return_type, BaseNodeOutput): + # If return type is already a Pydantic model, use it as base + _output_model = return_type + is_output_model_defined = True + elif isinstance(return_type, type) and issubclass(return_type, BaseModel): + _output_model = json_schema_to_model( + 
json_schema=return_type.model_json_schema(), + model_class_name=f"{func_name}", + base_class=BaseNodeOutput, + ) + is_output_model_defined = True + else: + # For primitive return types, create a model with a single field named "value" + _output_model = create_model( + f"{func_name}", + output=(return_type, ...), + __base__=BaseNodeOutput, + __module__=func.__module__, + __validators__={}, + __cls_kwargs__={}, + __config__=None, + __doc__=f"Output model for {func_name}", + ) + + # Determine if the function is a method, class method, or static/regular method + is_method = False + is_class_method = False + + # Check if the function is a method by looking at the first parameter + if sig.parameters and list(sig.parameters.keys())[0] in ("self", "cls"): + first_param = list(sig.parameters.keys())[0] + is_method = first_param == "self" + is_class_method = first_param == "cls" + + # Create config fields from function parameters + function_param_fields: Dict[str, Any] = {} + # Keep track of function parameter names for later use + function_param_names: Set[str] = set() + + for param_name, param in sig.parameters.items(): + # Skip self for instance methods and cls for class methods + # These will be handled specially during execution + if (is_method and param_name == "self") or (is_class_method and param_name == "cls"): + continue + + param_type = type_hints.get(param_name, Any) + default = ... if param.default is inspect.Parameter.empty else param.default + + # Add parameter as a config field + function_param_fields[param_name] = ( + param_type, + Field( + default=default if default is not ... 
else None, + description=f"Parameter '{param_name}' for function {func_name}", + ), + ) + function_param_names.add(param_name) + + # Add other config fields + decorator_param_fields: Dict[str, Any] = {k: (type(v), v) for k, v in tool_config.items()} + + # Merge function and decorator config fields, decorator fields take precedence + # This allows the decorator to override function parameters if needed + config_fields = {**function_param_fields, **decorator_param_fields} + + # Create the config model + config_model: Type[BaseNodeConfig] = create_model( + f"{func_name}Config", + output_json_schema=(str, json.dumps(_output_model.model_json_schema())), + has_fixed_output=(bool, has_fixed_output), + **config_fields, + __base__=BaseNodeConfig, + __module__=func.__module__, + __validators__={}, + __cls_kwargs__={}, + __config__=None, + __doc__=f"Config model for {func_name}", + ) + + # Store these for use in the class definition + nonlocal category + _category = category + _config_model = config_model + _input_model = input_model + _function_param_names = function_param_names + _is_output_model_defined = is_output_model_defined + + # Create a Node class for this function + class CustomFunctionToolNode(FunctionToolNode): + name = func_name + display_name = func_display_name + category = _category or "FunctionTools" + config_model = _config_model + output_model = _output_model # type: ignore + input_model = _input_model + function_param_names = _function_param_names + is_output_model_defined = _is_output_model_defined + __doc__ = func_doc + + def __init__( + self, + name: str = func_name, + config: Optional[BaseNodeConfig] = None, + context: Optional[WorkflowExecutionContext] = None, + ): + super().__init__( + name=name, + config=config, + context=context, + func=func, + visual_tag=visual_tag, + ) + + # Change the name of the class to the function name and bind it to the module + new_class_name = type( + f"{func_name}", + (CustomFunctionToolNode,), + { + "__module__": 
func.__module__ # Set the module to match the decorated func's module + }, + ) + + # Set NodeClass attribute to the function + func.node_class = new_class_name # type: ignore + + # Set the config model to the config_model + func.config_model = config_model # type: ignore + + # Set the output model to the output_model + func.output_model = _output_model # type: ignore + + # Set the func_name attribute to the function name + func.func_name = func.__name__ # type: ignore + + # Set the create_node function to the func + def create_node( + name: str = func_name, + config: Optional[BaseNodeConfig] = None, + context: Optional[WorkflowExecutionContext] = None, + ) -> FunctionToolNode: + return new_class_name(name=name, config=config, context=context) + + func.create_node = create_node # type: ignore + + return cast(ToolFunction, func) + + return decorator + + +if __name__ == "__main__": + import asyncio + + from pydantic import Field + + # Example usage + @tool_function(name="example_tool", description="An example tool", category="Example") + def example_function(param1: str, param2: int = 42) -> Dict[str, Any]: + """Return a dictionary.""" + return {"param1": param1, "param2": param2} + + # Create a node from the function + node_config = example_function.config_model.model_validate( + {"param1": "test", "param2": 100, "has_fixed_output": True} + ) + node = example_function.create_node(name="example_node", config=node_config) + print("=" * 50) + print("PLAIN FUNCTION EXECUTION:") + print(example_function("test", 100)) # Output: {'param1': 'test', 'param2': 100} + + print("=" * 50) + print("NODE NAME:") + print(node.name) # Output: example_tool + + print("=" * 50) + print("DISPLAY NAME:") + print(node.display_name) # Output: Example Tool + + print("=" * 50) + print("CATEGORY:") + print(node.category) # Output: Example + + print("=" * 50) + print("CONFIG MODEL SCHEMA:") + print(node.config_model.model_json_schema()) # Output: JSON schema of the config model + + print("=" * 
50) + print("OUTPUT MODEL SCHEMA:") + print(node.output_model.model_json_schema()) # Output: JSON schema of the output model + + print("=" * 50) + print("NODE EXECUTION RESULT:") + print(asyncio.run(node(input={"input_data": "test"}))) # Output: Result of the function + + # Example with custom config_model and output_model + print("\n" + "=" * 50) + print("EXAMPLE WITH CUSTOM OUTPUT MODEL:") + + # Define custom output model + class CustomOutputModel(BaseNodeOutput): + result: str = Field(..., description="The result of the operation") + status: str = Field("success", description="Status of the operation") + + # Function that will use custom output model + @tool_function( + name="custom_model_tool", + description="Tool with custom output model", + category="Example", + output_model=CustomOutputModel, + ) + def custom_model_function(message: str, prefix: str = "Result: ") -> Dict[str, str]: + """Use custom output model.""" + return {"result": f"{prefix}{message}", "status": "success"} + + # Create a node from the function with custom output model + custom_config = custom_model_function.config_model.model_validate( + {"message": "Hello World", "prefix": "Custom: ", "has_fixed_output": True} + ) + custom_node = custom_model_function.create_node(name="custom_node", config=custom_config) + + print("CUSTOM NODE CONFIG MODEL:") + print(custom_node.config_model.model_json_schema()) + + print("\nCUSTOM NODE OUTPUT MODEL:") + print(custom_node.output_model.model_json_schema()) + + print("\nCUSTOM NODE EXECUTION RESULT:") + print(asyncio.run(custom_node(input={}))) + + # Example with jinja2 template in config rendered using input + print("\n" + "=" * 50) + print("EXAMPLE WITH JINJA2 TEMPLATE IN CONFIG:") + + @tool_function( + name="jinja_tool", + description="Tool with jinja2 template in config rendered using input", + category="Example", + ) + def jinja_function(template: str, suffix: str = "World") -> str: + """Use jinja2 template in config.""" + return template + suffix + + 
# Create a node from the function with jinja2 template in config + jinja_config = jinja_function.config_model.model_validate( + {"template": "Hello ", "suffix": "{{ input.input_data }}", "has_fixed_output": True} + ) + jinja_node = jinja_function.create_node(name="jinja_node", config=jinja_config) + # Render the template using input + input_data: Dict[str, Any] = {"input_data": "Jinja2!"} + + jinja_result = asyncio.run(jinja_node(input=input_data)) + print("JINJA NODE CONFIG MODEL:") + print(jinja_node.config_model.model_json_schema()) + + print("\nJINJA NODE OUTPUT MODEL:") + print(jinja_node.output_model.model_json_schema()) + + print("\nJINJA NODE EXECUTION RESULT:") + print(jinja_result) # Output: Hello Jinja2! diff --git a/pyspur/backend/pyspur/nodes/email/providers/base.py b/pyspur/backend/pyspur/nodes/email/providers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..3540960c724b3097af4c22467afcd24bbd03583c --- /dev/null +++ b/pyspur/backend/pyspur/nodes/email/providers/base.py @@ -0,0 +1,43 @@ +from enum import Enum +from typing import List, Protocol + +from pydantic import BaseModel, Field + + +class EmailProvider(str, Enum): + RESEND = "resend" + SENDGRID = "sendgrid" + + +class EmailProviderConfig(BaseModel): + """Configuration for an email provider""" + + pass + + +class EmailMessage(BaseModel): + """Common email message format across providers""" + + from_email: str + to_emails: List[str] + subject: str + content: str + + +class EmailResponse(BaseModel): + """Common response format across providers""" + + provider: EmailProvider + message_id: str + status: str + raw_response: str = Field(..., description="JSON string containing the raw provider response") + + +class EmailProviderProtocol(Protocol): + """Protocol that all email providers must implement""" + + def __init__(self, config: EmailProviderConfig): ... 
+ + async def send_email(self, message: EmailMessage) -> EmailResponse: + """Send an email using this provider""" + ... diff --git a/pyspur/backend/pyspur/nodes/email/providers/registry.py b/pyspur/backend/pyspur/nodes/email/providers/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bf7dbad2aa25f6dbea30aa55b63ad4a3803e05 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/email/providers/registry.py @@ -0,0 +1,23 @@ +from typing import Dict, Type + +from .base import EmailProvider, EmailProviderConfig, EmailProviderProtocol +from .resend_provider import ResendProvider +from .sendgrid_provider import SendGridProvider + + +class EmailProviderRegistry: + _providers: Dict[EmailProvider, Type[EmailProviderProtocol]] = { + EmailProvider.RESEND: ResendProvider, + EmailProvider.SENDGRID: SendGridProvider, + } + + @classmethod + def get_provider( + cls, provider_type: EmailProvider, config: EmailProviderConfig + ) -> EmailProviderProtocol: + """Get an instance of the specified email provider""" + if provider_type not in cls._providers: + raise ValueError(f"Unknown email provider: {provider_type}") + + provider_class = cls._providers[provider_type] + return provider_class(config) diff --git a/pyspur/backend/pyspur/nodes/email/providers/resend_provider.py b/pyspur/backend/pyspur/nodes/email/providers/resend_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..2f0ff0daa52bfacb2169cb7e449bbc8723aa0d12 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/email/providers/resend_provider.py @@ -0,0 +1,58 @@ +import json +import os + +import resend + +from .base import ( + EmailMessage, + EmailProvider, + EmailProviderConfig, + EmailResponse, +) + + +class ResendProvider: + def __init__(self, config: EmailProviderConfig): + self.config = config + api_key = os.getenv("RESEND_API_KEY") + if not api_key: + raise ValueError("RESEND_API_KEY environment variable is not set") + resend.api_key = api_key + + async def send_email(self, 
message: EmailMessage) -> EmailResponse: + params: resend.Emails.SendParams = { + "from": message.from_email, + "to": message.to_emails, + "subject": message.subject, + "text": message.content, + } + + try: + response = resend.Emails.send(params) + # Convert response to a clean dictionary format and then to JSON string + response_dict = { + "id": getattr(response, "id", ""), + "from": message.from_email, + "to": message.to_emails, + "subject": message.subject, + } + + return EmailResponse( + provider=EmailProvider.RESEND, + message_id=str(response_dict["id"]), + status="success", + raw_response=json.dumps(response_dict), + ) + except Exception as e: + error_dict = { + "error": str(e), + "from": message.from_email, + "to": message.to_emails, + "subject": message.subject, + } + return EmailResponse( + provider=EmailProvider.RESEND, + message_id="", + status="error", + raw_response=json.dumps(error_dict), + ) diff --git a/pyspur/backend/pyspur/nodes/email/providers/sendgrid_provider.py b/pyspur/backend/pyspur/nodes/email/providers/sendgrid_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..1088da0e360cc914044b694bbc5a558850c32108 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/email/providers/sendgrid_provider.py @@ -0,0 +1,66 @@ +import json +import os +from typing import Any + +from sendgrid import SendGridAPIClient # type: ignore +from sendgrid.helpers.mail import Content, Email, Mail, To # type: ignore + +from .base import ( + EmailMessage, + EmailProvider, + EmailProviderConfig, + EmailResponse, +) + + +class SendGridProvider: + def __init__(self, config: EmailProviderConfig): + self.config = config + api_key = os.getenv("SENDGRID_API_KEY") + if not api_key: + raise ValueError("SENDGRID_API_KEY environment variable is not set") + self.client = SendGridAPIClient(api_key) + + async def send_email(self, message: EmailMessage) -> EmailResponse: + from_email = Email(message.from_email) + to_emails = [To(email=to_email) for to_email in 
message.to_emails] + subject = str(message.subject) + content = Content("text/plain", str(message.content)) + + # Create personalization for each recipient + email = Mail() + email.from_email = from_email + email.subject = subject + email.content = [content] + email.to = to_emails + + try: + response: Any = self.client.send(email) # type: ignore + print("response: ", response) + response_dict = { + "id": response.headers.get("X-Message-Id", ""), + "status_code": response.status_code, + "from": message.from_email, + "to": message.to_emails, + "subject": message.subject, + } + + return EmailResponse( + provider=EmailProvider.SENDGRID, + message_id=str(response_dict["id"]), + status="success" if response.status_code == 202 else "error", + raw_response=json.dumps(response_dict), + ) + except Exception as e: + error_dict = { + "error": str(e), + "from": message.from_email, + "to": message.to_emails, + "subject": message.subject, + } + return EmailResponse( + provider=EmailProvider.SENDGRID, + message_id="", + status="error", + raw_response=json.dumps(error_dict), + ) diff --git a/pyspur/backend/pyspur/nodes/email/send_email.py b/pyspur/backend/pyspur/nodes/email/send_email.py new file mode 100644 index 0000000000000000000000000000000000000000..d5e4b834a5f4e51edfe867f48be1bb859879d2f4 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/email/send_email.py @@ -0,0 +1,135 @@ +import json +from typing import Dict, List + +from pydantic import BaseModel, Field + +from ..base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput +from ..utils.template_utils import render_template_or_get_first_string +from .providers.base import ( + EmailMessage, + EmailProvider, + EmailProviderConfig, + EmailResponse, +) +from .providers.registry import EmailProviderRegistry + + +def parse_email_addresses(email_str: str) -> List[str]: + """ + Parse a string containing one or more email addresses. 
+ + Args: + email_str: A string that can be either a single email or a list of emails in the format "['email1', 'email2']" + + Returns: + List[str]: A list of cleaned email addresses + + Example: + >>> parse_email_addresses("test@example.com") + ["test@example.com"] + >>> parse_email_addresses("['test1@example.com', 'test2@example.com']") + ["test1@example.com", "test2@example.com"] + """ + email_str = email_str.strip() + if email_str.startswith("[") and email_str.endswith("]"): + # Remove brackets and split by comma + email_str = email_str[1:-1] + # Split by comma and clean each email + emails = [email.strip().strip("'").strip('"') for email in email_str.split(",")] + # Remove any empty strings + emails = [email for email in emails if email] + if not emails: + raise ValueError("No valid email addresses found in the list") + return emails + return [email_str] + + +class SendEmailNodeOutput(BaseNodeOutput): + provider: EmailProvider = Field(..., description="The email provider used") + message_id: str = Field(..., description="The message ID from the provider") + status: str = Field(..., description="The status of the email send operation") + raw_response: str = Field(..., description="The raw response from the provider as JSON string") + + +class SendEmailNodeConfig(BaseNodeConfig): + provider: EmailProvider = Field( + EmailProvider.RESEND, + description="The email provider to use", + ) + from_template: str = Field("", description="Email address to send from") + to_template: str = Field("", description="Email address to send to") + subject_template: str = Field("", description="Email subject") + content_template: str = Field("", description="Email content (plain text)") + output_schema: Dict[str, str] = Field( + default={ + "provider": "string", + "message_id": "string", + "status": "string", + "raw_response": "string", + }, + description="The schema for the output of the node", + ) + has_fixed_output: bool = True + output_json_schema: str = 
json.dumps(SendEmailNodeOutput.model_json_schema()) + + +class SendEmailNodeInput(BaseNodeInput): + """Input for the email node""" + + class Config: + extra = "allow" + + +class SendEmailNode(BaseNode): + """Node for sending an email""" + + name = "send_email_node" + display_name = "Send Email" + + config_model = SendEmailNodeConfig + input_model = SendEmailNodeInput + output_model = SendEmailNodeOutput + + async def run(self, input: BaseModel) -> BaseModel: + # Create provider config + provider_config = EmailProviderConfig() + + # Get the appropriate provider instance + provider = EmailProviderRegistry.get_provider(self.config.provider, provider_config) + + # Render the templates + raw_input_dict = input.model_dump() + from_email = render_template_or_get_first_string( + self.config.from_template, raw_input_dict, self.name + ) + to_emails_str = render_template_or_get_first_string( + self.config.to_template, raw_input_dict, self.name + ) + + to_emails = parse_email_addresses(to_emails_str) + + subject = render_template_or_get_first_string( + self.config.subject_template, raw_input_dict, self.name + ) + content = render_template_or_get_first_string( + self.config.content_template, raw_input_dict, self.name + ) + + # Create the email message + message = EmailMessage( + from_email=from_email, + to_emails=to_emails, + subject=subject, + content=content, + ) + + # Send the email + response: EmailResponse = await provider.send_email(message) + + # Return the response + return SendEmailNodeOutput( + provider=response.provider, + message_id=response.message_id, + status=response.status, + raw_response=response.raw_response, + ) diff --git a/pyspur/backend/pyspur/nodes/example.py b/pyspur/backend/pyspur/nodes/example.py new file mode 100644 index 0000000000000000000000000000000000000000..5f7267ee906bec5cfce5824cdb60f3fa53abee34 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/example.py @@ -0,0 +1,53 @@ +from pydantic import BaseModel + +from .base import BaseNode + + +class 
ExampleNodeConfig(BaseModel): + """ + Configuration parameters for the ExampleNode. + """ + + pass + + +class ExampleNodeInput(BaseModel): + """ + Input parameters for the ExampleNode. + """ + + name: str + + +class ExampleNodeOutput(BaseModel): + """ + Output parameters for the ExampleNode. + """ + + greeting: str + + +class ExampleNode(BaseNode): + """ + Example node that takes a name and returns a greeting. + """ + + name = "example" + config_model = ExampleNodeConfig + input_model = ExampleNodeInput + output_model = ExampleNodeOutput + + def setup(self) -> None: + self.input_model = ExampleNodeInput + self.output_model = ExampleNodeOutput + + async def run(self, input_data: ExampleNodeInput) -> ExampleNodeOutput: + return ExampleNodeOutput(greeting=f"Hello, {input_data.name}!") + + +if __name__ == "__main__": + import asyncio + + example_node = ExampleNode(ExampleNodeConfig()) + output = asyncio.run(example_node(ExampleNodeInput(name="Alice"))) + print(output) diff --git a/pyspur/backend/pyspur/nodes/factory.py b/pyspur/backend/pyspur/nodes/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..83f779cb43462e96d8a805e7f6c7f1c7948bcd36 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/factory.py @@ -0,0 +1,105 @@ +import importlib +from typing import Any, Dict, List + +from ..schemas.node_type_schemas import NodeTypeSchema +from .base import BaseNode +from .node_types import ( + SUPPORTED_NODE_TYPES, + get_all_node_types, + is_valid_node_type, +) +from .registry import NodeRegistry + + +class NodeFactory: + """Create node instances from a configuration. + + Supports both decorator-based registration and legacy configured registration. + + Conventions: + - The node class should be named Node + - The config model should be named NodeConfig + - The input model should be named NodeInput + - The output model should be named NodeOutput + - There should be only one node type class per module + + Nodes can be registered in two ways: + 1. 
Using the @NodeRegistry.register decorator (recommended) + 2. Through the legacy configured SUPPORTED_NODE_TYPES in node_types.py + """ + + @staticmethod + def get_all_node_types() -> Dict[str, List[NodeTypeSchema]]: + """Return a dictionary of all available node types grouped by category. + + Combines both decorator-registered and configured nodes. + """ + # Get nodes from both sources + configured_nodes = get_all_node_types() + registered_nodes = NodeRegistry.get_registered_nodes() + + # Convert registered nodes to NodeTypeSchema + converted_nodes: Dict[str, List[NodeTypeSchema]] = {} + for category, nodes in registered_nodes.items(): + if category not in converted_nodes: + converted_nodes[category] = [] + for node in nodes: + schema = NodeTypeSchema( + node_type_name=node.node_type_name, + module=node.module, + class_name=node.class_name, + ) + converted_nodes[category].append(schema) + + # Merge nodes, giving priority to configured ones + result = configured_nodes.copy() + for category, nodes in converted_nodes.items(): + if category not in result: + result[category] = [] + # Only add nodes that aren't already present + for node in nodes: + if not any(n.node_type_name == node.node_type_name for n in result[category]): + result[category].append(node) + + return result + + @staticmethod + def create_node(node_name: str, node_type_name: str, config: Any) -> BaseNode: + """Create a node instance from a configuration. + + Checks both registration methods for the node type. 
+ """ + if not is_valid_node_type(node_type_name): + raise ValueError(f"Node type '{node_type_name}' is not valid.") + + module_name = None + class_name = None + + # First check configured nodes + for node_group in SUPPORTED_NODE_TYPES.values(): + for node_type in node_group: + if node_type["node_type_name"] == node_type_name: + module_name = node_type["module"] + class_name = node_type["class_name"] + break + if module_name and class_name: + break + + # If not found, check registry + if not module_name or not class_name: + registered_nodes = NodeRegistry.get_registered_nodes() + for nodes in registered_nodes.values(): + for node in nodes: + if node.node_type_name == node_type_name: + module_name = node.module + class_name = node.class_name + break + if module_name and class_name: + break + + if not module_name or not class_name: + raise ValueError(f"Node type '{node_type_name}' not found.") + + module = importlib.import_module(module_name, package="pyspur") + node_class = getattr(module, class_name) + return node_class(name=node_name, config=node_class.config_model(**config)) diff --git a/pyspur/backend/pyspur/nodes/integrations/anki/anki_basic_node_deck_generator.py b/pyspur/backend/pyspur/nodes/integrations/anki/anki_basic_node_deck_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..03b4b81b61b9e61410254a1cbc7e85a70902eff2 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/integrations/anki/anki_basic_node_deck_generator.py @@ -0,0 +1,190 @@ +import json +import os +import random +from typing import Any, Dict, List, Optional + +import genanki +from jinja2 import Template +from pydantic import BaseModel, Field +from typing_extensions import TypeAlias + +from ...base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput + +# Type aliases for genanki types to help with type checking +GenkaniModel: TypeAlias = Any # genanki.Model +GenkaniDeck: TypeAlias = Any # genanki.Deck +GenkaniNote: TypeAlias = Any # genanki.Note 
+GenkaniPackage: TypeAlias = Any # genanki.Package + + +class AnkiBasicNodeInput(BaseNodeInput): + """Input for the AnkiBasic node - creates cards with just front and back sides.""" + + front: List[str] = Field( + ..., description="List of front (question) sides of the cards. Supports Jinja templating." + ) + back: List[str] = Field( + ..., description="List of back (answer) sides of the cards. Supports Jinja templating." + ) + + class Config: + extra = "allow" + + +class AnkiBasicNodeOutput(BaseNodeOutput): + """Output for the AnkiBasic node.""" + + deck_path: str = Field(..., description="Path to the generated Anki deck file") + card_count: int = Field(..., description="Number of cards in the generated deck") + + +class AnkiBasicNodeConfig(BaseNodeConfig): + """Configuration for the AnkiBasic node.""" + + deck_name: str = Field("Generated Basic Deck", description="Name of the Anki deck") + output_dir: str = Field( + "data/anki_decks", description="Directory where the deck file will be saved" + ) + model_id: Optional[int] = Field( + None, + description="Optional unique model ID for the Anki model or a random ID will be generated.", + ) + deck_id: Optional[int] = Field( + None, + description="Optional unique deck ID for the Anki deck or a random ID will be generated.", + ) + has_fixed_output: bool = True + output_json_schema: str = Field( + default=json.dumps(AnkiBasicNodeOutput.model_json_schema()), + description="The JSON schema for the output of the node", + ) + + +# @NodeRegistry.register( +# category="Integrations", +# display_name="Anki Basic Card Generator", +# logo="/images/anki.png", +# subcategory="Flashcards", +# ) +class AnkiBasicNode(BaseNode): + """Generate Anki decks with basic cards (front/back only).""" + + name = "anki_basic_node" + display_name = "AnkiBasicNode" + logo = "/images/anki.png" + category = "Anki" + + config_model = AnkiBasicNodeConfig + input_model = AnkiBasicNodeInput + output_model = AnkiBasicNodeOutput + + def __init__( + self, 
name: str, config: AnkiBasicNodeConfig, context: Optional[Any] = None + ) -> None: + super().__init__(name=name, config=config, context=context) + # Create output directory if it doesn't exist + os.makedirs(self.config.output_dir, exist_ok=True) + + def _render_template(self, template_str: str, data: Dict[str, Any]) -> str: + """Render a Jinja template string with the provided data.""" + try: + return Template(template_str).render(**data) + except Exception as e: + print(f"[ERROR] Failed to render template in {self.name}") + print(f"[ERROR] Template: {template_str}") + raise e + + async def run(self, input: BaseModel) -> BaseModel: + """Generate an Anki deck with basic cards from the provided front and back templates.""" + input_typed = AnkiBasicNodeInput.model_validate(input) + + if len(input_typed.front) != len(input_typed.back): + raise ValueError("Number of front and back entries must match") + + # Generate random IDs if not provided + model_id = self.config.model_id or random.randrange(1 << 30, 1 << 31) + deck_id = self.config.deck_id or random.randrange(1 << 30, 1 << 31) + + # Create the basic Anki model + model: GenkaniModel = genanki.Model( + model_id, + "Basic", + fields=[ + {"name": "Front"}, + {"name": "Back"}, + ], + templates=[ + { + "name": "Card 1", + "qfmt": "{{Front}}", + "afmt": '{{FrontSide}}
{{Back}}', + }, + ], + ) + + # Create the deck + deck: GenkaniDeck = genanki.Deck(deck_id, self.config.deck_name) + + # Get input data for template rendering + input_data = input_typed.model_dump() + + # Add cards to the deck + for front_template, back_template in zip(input_typed.front, input_typed.back, strict=False): + # Render templates + front_rendered = self._render_template(front_template, input_data) + back_rendered = self._render_template(back_template, input_data) + + note: GenkaniNote = genanki.Note(model=model, fields=[front_rendered, back_rendered]) + deck.add_note(note) + + # Generate unique filename + output_path = os.path.join( + self.config.output_dir, + f"{self.config.deck_name.lower().replace(' ', '_')}_{deck_id}.apkg", + ) + + # Save the deck + package: GenkaniPackage = genanki.Package(deck) + package.write_to_file(output_path) + + output = AnkiBasicNodeOutput(deck_path=output_path, card_count=len(input_typed.front)) + return output + + +if __name__ == "__main__": + # Example usage of the AnkiBasic node + import asyncio + + async def example() -> None: + # Create node configuration + config = AnkiBasicNodeConfig(deck_name="Programming Concepts", output_dir="data/anki_decks") + + # Create node instance + node = AnkiBasicNode(name="example_node", config=config) + + # Prepare input data with templates + input_data = AnkiBasicNodeInput( + front=[ + "What is {{concept}}?", + "Explain the difference between {{thing1}} and {{thing2}}?", + ], + back=["{{concept}} is {{definition}}", "The key differences are:\n{{differences}}"], + ) + + # Add template variables to the input + input_data.concept = "recursion" # type: ignore + input_data.definition = "a programming concept where a function calls itself" # type: ignore + input_data.thing1 = "list" # type: ignore + input_data.thing2 = "tuple" # type: ignore + input_data.differences = "Lists are mutable, tuples are immutable" # type: ignore + + # Run the node + result = await node.run(input_data) + 
result_typed = AnkiBasicNodeOutput.model_validate(result) + + # Print results + print(f"Generated Anki deck at: {result_typed.deck_path}") + print(f"Number of cards: {result_typed.card_count}") + + # Run the example + asyncio.run(example()) diff --git a/pyspur/backend/pyspur/nodes/integrations/firecrawl/__init__.py b/pyspur/backend/pyspur/nodes/integrations/firecrawl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyspur/backend/pyspur/nodes/integrations/firecrawl/firecrawl_crawl.py b/pyspur/backend/pyspur/nodes/integrations/firecrawl/firecrawl_crawl.py new file mode 100644 index 0000000000000000000000000000000000000000..785cd7bc498014ae99ac9dcb23311ba06776e1cf --- /dev/null +++ b/pyspur/backend/pyspur/nodes/integrations/firecrawl/firecrawl_crawl.py @@ -0,0 +1,118 @@ +import asyncio +import json +import logging +from typing import Optional + +from pydantic import BaseModel, Field # type: ignore + +from firecrawl import FirecrawlApp # type: ignore + +from ...base import ( + BaseNode, + BaseNodeConfig, + BaseNodeInput, + BaseNodeOutput, +) +from ...registry import NodeRegistry +from ...utils.template_utils import render_template_or_get_first_string + + +class FirecrawlCrawlNodeInput(BaseNodeInput): + """Input for the FirecrawlCrawl node.""" + + class Config: + """Config for the FirecrawlCrawl node input.""" + + extra = "allow" + + +class FirecrawlCrawlNodeOutput(BaseNodeOutput): + """Output for the FirecrawlCrawl node.""" + + crawl_result: str = Field(..., description="The crawled data in markdown or structured format.") + + +class FirecrawlCrawlNodeConfig(BaseNodeConfig): + """Configuration for the FirecrawlCrawl node.""" + + url_template: str = Field( + "", + description="The URL to crawl and convert into clean markdown or structured data.", + ) + limit: Optional[int] = Field(None, description="The maximum number of pages to crawl.") + has_fixed_output: bool = True + 
output_json_schema: str = Field( + default=json.dumps(FirecrawlCrawlNodeOutput.model_json_schema()), + description="The JSON schema for the output of the node", + ) + + +@NodeRegistry.register( + category="Integrations", + display_name="Firecrawl Crawl", + logo="/images/firecrawl.png", + subcategory="Web Scraping", + position="before:FirecrawlScrapeNode", +) +class FirecrawlCrawlNode(BaseNode): + """Crawl a URL and return the content in markdown or structured format.""" + + name = "firecrawl_crawl_node" + config_model = FirecrawlCrawlNodeConfig + input_model = FirecrawlCrawlNodeInput + output_model = FirecrawlCrawlNodeOutput + category = "Firecrawl" # This will be used by the frontend for subcategory grouping + + async def run(self, input: BaseModel) -> BaseModel: + """Run the FirecrawlCrawl node.""" + try: + # Grab the entire dictionary from the input + raw_input_dict = input.model_dump() + + # Render url_template + url_template = render_template_or_get_first_string( + self.config.url_template, raw_input_dict, self.name + ) + + app = FirecrawlApp() # type: ignore + + # Start the asynchronous crawl + crawl_obj = app.async_crawl_url( # type: ignore + url_template, + params={ + "limit": self.config.limit, + "scrapeOptions": {"formats": ["markdown", "html"]}, + }, + ) + + # Get the crawl ID from the response + crawl_id = crawl_obj.get("id") + if not crawl_id: + raise ValueError("No crawl ID received from async crawl request") + + # Poll for completion with exponential backoff + max_attempts = 30 # Maximum number of attempts + base_delay = 2 # Base delay in seconds + + for attempt in range(max_attempts): + # Check the crawl status + status_response = app.check_crawl_status(crawl_id) # type: ignore + + if status_response.get("status") == "completed": + crawl_result = status_response.get("data", {}) + return FirecrawlCrawlNodeOutput(crawl_result=json.dumps(crawl_result)) + + if status_response.get("status") == "failed": + raise ValueError( + f"Crawl failed: 
{status_response.get('error', 'Unknown error')}" + ) + + # Calculate delay with exponential backoff (2^attempt seconds) + delay = min(base_delay * (2**attempt), 60) # Cap at 60 seconds + await asyncio.sleep(delay) + + raise TimeoutError("Crawl did not complete within the maximum allowed time") + + except Exception as e: + logging.error(f"Failed to crawl URL: {e}") + return FirecrawlCrawlNodeOutput(crawl_result="") diff --git a/pyspur/backend/pyspur/nodes/integrations/firecrawl/firecrawl_scrape.py b/pyspur/backend/pyspur/nodes/integrations/firecrawl/firecrawl_scrape.py new file mode 100644 index 0000000000000000000000000000000000000000..89b4ac6755c6b2ca8e0f653aef65665d71bb089b --- /dev/null +++ b/pyspur/backend/pyspur/nodes/integrations/firecrawl/firecrawl_scrape.py @@ -0,0 +1,79 @@ +import json +import logging + +from pydantic import BaseModel, Field # type: ignore + +from firecrawl import FirecrawlApp # type: ignore + +from ...base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput +from ...registry import NodeRegistry +from ...utils.template_utils import render_template_or_get_first_string + + +class FirecrawlScrapeNodeInput(BaseNodeInput): + """Input for the FirecrawlScrape node.""" + + class Config: + """Config for the FirecrawlScrape node input.""" + + extra = "allow" + + +class FirecrawlScrapeNodeOutput(BaseNodeOutput): + """Output for the FirecrawlScrape node.""" + + markdown: str = Field(..., description="The scraped data in markdown format.") + + +class FirecrawlScrapeNodeConfig(BaseNodeConfig): + """Configuration for the FirecrawlScrape node.""" + + url_template: str = Field( + "", + description="The URL to scrape and convert into clean markdown or structured data.", + ) + has_fixed_output: bool = True + output_json_schema: str = Field( + default=json.dumps(FirecrawlScrapeNodeOutput.model_json_schema()), + description="The JSON schema for the output of the node", + ) + + +@NodeRegistry.register( + category="Integrations", + 
display_name="Firecrawl Scrape", + logo="/images/firecrawl.png", + subcategory="Web Scraping", + position="after:FirecrawlCrawlNode", +) +class FirecrawlScrapeNode(BaseNode): + """Scrapes a URL and returns the content in markdown or structured format.""" + + name = "firecrawl_scrape_node" + config_model = FirecrawlScrapeNodeConfig + input_model = FirecrawlScrapeNodeInput + output_model = FirecrawlScrapeNodeOutput + category = "Firecrawl" # This will be used by the frontend for subcategory grouping + + async def run(self, input: BaseModel) -> BaseModel: + """Scrapes a URL and returns the content in markdown or structured format.""" + try: + # Grab the entire dictionary from the input + raw_input_dict = input.model_dump() + + # Render url_template + url_template = render_template_or_get_first_string( + self.config.url_template, raw_input_dict, self.name + ) + + app = FirecrawlApp() # type: ignore + scrape_result = app.scrape_url( # type: ignore + url_template, + params={ + "formats": ["markdown"], + }, + ) + return FirecrawlScrapeNodeOutput(markdown=scrape_result["markdown"]) + except Exception as e: + logging.error(f"Failed to scrape URL: {e}") + return FirecrawlScrapeNodeOutput(markdown="") diff --git a/pyspur/backend/pyspur/nodes/integrations/github/__init__.py b/pyspur/backend/pyspur/nodes/integrations/github/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyspur/backend/pyspur/nodes/integrations/github/github_create_issue.py b/pyspur/backend/pyspur/nodes/integrations/github/github_create_issue.py new file mode 100644 index 0000000000000000000000000000000000000000..ff12b73d7e539020aaf06b1adf71b2fd8835596e --- /dev/null +++ b/pyspur/backend/pyspur/nodes/integrations/github/github_create_issue.py @@ -0,0 +1,54 @@ +import json +import logging +from typing import Optional + +from phi.tools.github import GithubTools +from pydantic import BaseModel, Field # type: ignore + +from 
# File: pyspur/backend/pyspur/nodes/integrations/github/github_create_issue.py


class GitHubCreateIssueNodeInput(BaseNodeInput):
    """Input for the GitHubCreateIssue node"""

    class Config:
        # Accept fields beyond those declared on the model.
        extra = "allow"


# Output model: the created issue, as the string returned by GithubTools.
class GitHubCreateIssueNodeOutput(BaseNodeOutput):
    issue: str = Field(..., description="The created issue details in JSON format.")


# Configuration model: target repository, issue title and optional body.
class GitHubCreateIssueNodeConfig(BaseNodeConfig):
    repo_name: str = Field(
        default="",
        description="The full name of the repository (e.g. 'owner/repo').",
    )
    issue_title: str = Field(default="", description="The title of the issue.")
    body: Optional[str] = Field(default=None, description="The body content of the issue.")
    has_fixed_output: bool = True
    # Defaults to the serialized JSON schema of GitHubCreateIssueNodeOutput.
    output_json_schema: str = Field(
        default=json.dumps(GitHubCreateIssueNodeOutput.model_json_schema()),
        description="The JSON schema for the output of the node",
    )


# Node that creates a GitHub issue from values held in its configuration.
class GitHubCreateIssueNode(BaseNode):
    name = "github_create_issue_node"
    display_name = "GitHubCreateIssue"
    logo = "/images/github.png"
    category = "GitHub"

    config_model = GitHubCreateIssueNodeConfig
    input_model = GitHubCreateIssueNodeInput
    output_model = GitHubCreateIssueNodeOutput

    async def run(self, input: BaseModel) -> BaseModel:
        """Create the configured issue and return its details.

        The ``input`` argument is not read; all values come from ``self.config``.
        Any exception is logged and reported as an output with an empty string.
        """
        settings = self.config
        try:
            github_client = GithubTools()
            created = github_client.create_issue(
                repo_name=settings.repo_name,
                title=settings.issue_title,
                body=settings.body,
            )
            outcome = GitHubCreateIssueNodeOutput(issue=created)
        except Exception as e:
            logging.error(f"Failed to create issue: {e}")
            outcome = GitHubCreateIssueNodeOutput(issue="")
        return outcome
# File: pyspur/backend/pyspur/nodes/integrations/github/github_get_pull_request.py


class GitHubGetPullRequestNodeInput(BaseNodeInput):
    """Input for the GitHubGetPullRequest node"""

    class Config:
        # Accept fields beyond those declared on the model.
        extra = "allow"


# Output model: the pull request details, as the string returned by GithubTools.
class GitHubGetPullRequestNodeOutput(BaseNodeOutput):
    pull_request: str = Field(
        ...,
        description="Details of the requested pull request in JSON format.",
    )


# Configuration model: repository name and pull request number (kept as a string).
class GitHubGetPullRequestNodeConfig(BaseNodeConfig):
    repo_name: str = Field(
        default="",
        description="The full name of the repository (e.g. 'owner/repo').",
    )
    pr_number: str = Field(default="", description="The pull request number.")
    has_fixed_output: bool = True
    # Defaults to the serialized JSON schema of GitHubGetPullRequestNodeOutput.
    output_json_schema: str = Field(
        default=json.dumps(GitHubGetPullRequestNodeOutput.model_json_schema()),
        description="The JSON schema for the output of the node",
    )


# Node that fetches one pull request identified by its configuration.
class GitHubGetPullRequestNode(BaseNode):
    name = "github_get_pull_request_node"
    display_name = "GitHubGetPullRequest"
    logo = "/images/github.png"
    category = "GitHub"

    config_model = GitHubGetPullRequestNodeConfig
    input_model = GitHubGetPullRequestNodeInput
    output_model = GitHubGetPullRequestNodeOutput

    async def run(self, input: BaseModel) -> BaseModel:
        """Fetch the configured pull request and return its details.

        ``pr_number`` is converted with ``int()`` inside the try block, so a
        non-numeric value is logged and reported as an empty result like any
        other failure. The ``input`` argument is not read.
        """
        settings = self.config
        try:
            github_client = GithubTools()
            details = github_client.get_pull_request(
                repo_name=settings.repo_name,
                pr_number=int(settings.pr_number),
            )
            outcome = GitHubGetPullRequestNodeOutput(pull_request=details)
        except Exception as e:
            logging.error(f"Failed to get pull request details: {e}")
            outcome = GitHubGetPullRequestNodeOutput(pull_request="")
        return outcome
# File: pyspur/backend/pyspur/nodes/integrations/github/github_get_pull_request_changes.py


class GitHubGetPullRequestChangesNodeInput(BaseNodeInput):
    """Input for the GitHubGetPullRequestChanges node"""

    class Config:
        # Accept fields beyond those declared on the model.
        extra = "allow"


# Output model: the changed files, as the string returned by GithubTools.
class GitHubGetPullRequestChangesNodeOutput(BaseNodeOutput):
    pull_request_changes: str = Field(
        ...,
        description="The list of changed files in the pull request in JSON format.",
    )


# Configuration model: repository name and an optional integer pull request number.
class GitHubGetPullRequestChangesNodeConfig(BaseNodeConfig):
    repo_name: str = Field(
        default="",
        description="The full name of the repository (e.g. 'owner/repo').",
    )
    pr_number: Optional[int] = Field(default=None, description="The pull request number.")
    has_fixed_output: bool = True
    # Defaults to the serialized JSON schema of GitHubGetPullRequestChangesNodeOutput.
    output_json_schema: str = Field(
        default=json.dumps(GitHubGetPullRequestChangesNodeOutput.model_json_schema()),
        description="The JSON schema for the output of the node",
    )


# Node that lists the files changed by one pull request.
class GitHubGetPullRequestChangesNode(BaseNode):
    name = "github_get_pull_request_changes_node"
    display_name = "GitHubGetPullRequestChanges"
    logo = "/images/github.png"
    category = "GitHub"

    config_model = GitHubGetPullRequestChangesNodeConfig
    input_model = GitHubGetPullRequestChangesNodeInput
    output_model = GitHubGetPullRequestChangesNodeOutput

    async def run(self, input: BaseModel) -> BaseModel:
        """Fetch the changes of the configured pull request.

        ``pr_number`` is forwarded exactly as configured (it may be ``None``).
        Any exception is logged and reported as an output with an empty string.
        The ``input`` argument is not read.
        """
        settings = self.config
        try:
            github_client = GithubTools()
            changes = github_client.get_pull_request_changes(
                repo_name=settings.repo_name,
                pr_number=settings.pr_number,
            )
            outcome = GitHubGetPullRequestChangesNodeOutput(pull_request_changes=changes)
        except Exception as e:
            logging.error(f"Failed to get pull request changes: {e}")
            outcome = GitHubGetPullRequestChangesNodeOutput(pull_request_changes="")
        return outcome
# File: pyspur/backend/pyspur/nodes/integrations/github/github_get_repository.py


class GitHubGetRepositoryNodeInput(BaseNodeInput):
    """Input for the GitHubGetRepository node"""

    class Config:
        # Accept fields beyond those declared on the model.
        extra = "allow"


# Output model: the repository details, as the string returned by GithubTools.
class GitHubGetRepositoryNodeOutput(BaseNodeOutput):
    repository_details: str = Field(
        ..., description="Details of the requested repository in JSON format."
    )


# Configuration model: the repository to look up.
class GitHubGetRepositoryNodeConfig(BaseNodeConfig):
    repo_name: str = Field("", description="The full name of the repository (e.g. 'owner/repo').")
    # Defaults to the serialized JSON schema of GitHubGetRepositoryNodeOutput.
    output_json_schema: str = Field(
        default=json.dumps(GitHubGetRepositoryNodeOutput.model_json_schema()),
        description="The JSON schema for the output of the node",
    )
    has_fixed_output: bool = True


# Node that fetches the details of one repository.
class GitHubGetRepositoryNode(BaseNode):
    name = "github_get_repository_node"
    display_name = "GitHubGetRepository"
    logo = "/images/github.png"
    category = "GitHub"

    config_model = GitHubGetRepositoryNodeConfig
    input_model = GitHubGetRepositoryNodeInput
    output_model = GitHubGetRepositoryNodeOutput

    async def run(self, input: BaseModel) -> BaseModel:
        """Fetch the configured repository and return its details.

        Any exception is logged and reported as an output with an empty string.
        The ``input`` argument is not read.
        """
        try:
            gh = GithubTools()
            repo_details = gh.get_repository(repo_name=self.config.repo_name)
            return GitHubGetRepositoryNodeOutput(repository_details=repo_details)
        except Exception as e:
            logging.error(f"Failed to get repository details: {e}")
            return GitHubGetRepositoryNodeOutput(repository_details="")


# File: pyspur/backend/pyspur/nodes/integrations/github/github_list_pull_requests.py


class GitHubListPullRequestsNodeInput(BaseNodeInput):
    """Input for the GitHubListPullRequests node"""

    class Config:
        # Accept fields beyond those declared on the model.
        extra = "allow"


# Output model: the pull requests, as the string returned by GithubTools.
class GitHubListPullRequestsNodeOutput(BaseNodeOutput):
    pull_requests: str = Field(..., description="The pull requests for the repository.")


# Configuration model: repository name and pull request state filter.
class GitHubListPullRequestsNodeConfig(BaseNodeConfig):
    # The value is passed to GithubTools as ``repo_name``; the example in this
    # file and the sibling GitHub nodes use the 'owner/repo' form, not a URL.
    repo_name: str = Field(
        "",
        description="The full name of the repository (e.g. 'owner/repo').",
    )
    state: str = Field(
        "open",
        description="The state of the pull requests to fetch. Can be 'open', 'closed', or 'all'.",
    )
    has_fixed_output: bool = True
    # Defaults to the serialized JSON schema of GitHubListPullRequestsNodeOutput.
    output_json_schema: str = Field(
        default=json.dumps(GitHubListPullRequestsNodeOutput.model_json_schema()),
        description="The JSON schema for the output of the node",
    )


# Node that lists the pull requests of one repository.
class GitHubListPullRequestsNode(BaseNode):
    name = "github_list_pull_requests_node"
    display_name = "GitHubListPullRequests"
    logo = "/images/github.png"
    category = "GitHub"

    config_model = GitHubListPullRequestsNodeConfig
    input_model = GitHubListPullRequestsNodeInput
    output_model = GitHubListPullRequestsNodeOutput

    async def run(self, input: BaseModel) -> BaseModel:
        """
        Fetches the pull requests for the configured repository name and state.

        Any exception is logged and reported as an output with an empty string.
        The ``input`` argument is not read.
        """
        try:
            gh = GithubTools()
            pull_requests = gh.list_pull_requests(
                repo_name=self.config.repo_name, state=self.config.state
            )
            return GitHubListPullRequestsNodeOutput(pull_requests=pull_requests)
        except Exception as e:
            logging.error(f"Failed to get pull requests: {e}")
            return GitHubListPullRequestsNodeOutput(pull_requests="")


if __name__ == "__main__":
    import asyncio

    async def main():
        # Example usage
        node = GitHubListPullRequestsNode(
            name="github_list_pull_requests_node",
            config=GitHubListPullRequestsNodeConfig(
                repo_name="parshva-bhadra/pyspur", state="closed"
            ),
        )
        input_data = GitHubListPullRequestsNodeInput()
        output = await node.run(input_data)
        print(output)

    asyncio.run(main())
# File: pyspur/backend/pyspur/nodes/integrations/github/github_list_repositories.py


class GitHubListRepositoriesNodeInput(BaseNodeInput):
    """Input for the GitHubListRepositories node"""

    class Config:
        # Accept fields beyond those declared on the model.
        extra = "allow"


# Output model: the repositories, as the string returned by GithubTools.
class GitHubListRepositoriesNodeOutput(BaseNodeOutput):
    repositories: str = Field(..., description="A JSON string of the repositories for the user.")


# Configuration model: this node has no settings beyond the output schema.
class GitHubListRepositoriesNodeConfig(BaseNodeConfig):
    has_fixed_output: bool = True
    # Defaults to the serialized JSON schema of GitHubListRepositoriesNodeOutput.
    output_json_schema: str = Field(
        default=json.dumps(GitHubListRepositoriesNodeOutput.model_json_schema()),
        description="The JSON schema for the output of the node",
    )


# Node that lists repositories through GithubTools.list_repositories().
class GitHubListRepositoriesNode(BaseNode):
    name = "github_list_repositories_node"
    display_name = "GitHubListRepositories"
    logo = "/images/github.png"
    category = "GitHub"

    config_model = GitHubListRepositoriesNodeConfig
    input_model = GitHubListRepositoriesNodeInput
    output_model = GitHubListRepositoriesNodeOutput

    async def run(self, input: BaseModel) -> BaseModel:
        """List repositories and return them as a string.

        Any exception is logged and reported as an output with an empty string.
        The ``input`` argument is not read.
        """
        try:
            github_client = GithubTools()
            listing = github_client.list_repositories()
            outcome = GitHubListRepositoriesNodeOutput(repositories=listing)
        except Exception as e:
            logging.error(f"Failed to list repositories: {e}")
            outcome = GitHubListRepositoriesNodeOutput(repositories="")
        return outcome


# File: pyspur/backend/pyspur/nodes/integrations/github/github_search_repositories.py


class GitHubSearchRepositoriesNodeInput(BaseNodeInput):
    """Input for the GitHubSearchRepositories node"""

    class Config:
        # Accept fields beyond those declared on the model.
        extra = "allow"


# Output model: the matching repositories, as the string returned by GithubTools.
class GitHubSearchRepositoriesNodeOutput(BaseNodeOutput):
    repositories: str = Field(
        ...,
        description="A JSON string of repositories matching the search query.",
    )


# Configuration model: search keywords (required), sorting and page size.
class GitHubSearchRepositoriesNodeConfig(BaseNodeConfig):
    query: str = Field(..., description="The search query keywords (e.g. 'machine learning').")
    sort: str = Field(
        default="stars",
        description="The field to sort results by. Can be 'stars', 'forks', or 'updated'.",
    )
    order: str = Field(
        default="desc",
        description="The order of results. Can be 'asc' or 'desc'.",
    )
    per_page: int = Field(default=5, description="Number of results per page.")
    has_fixed_output: bool = True
    # Defaults to the serialized JSON schema of GitHubSearchRepositoriesNodeOutput.
    output_json_schema: str = Field(
        default=json.dumps(GitHubSearchRepositoriesNodeOutput.model_json_schema()),
        description="The JSON schema for the output of the node",
    )


# Node that searches repositories using the values held in its configuration.
class GitHubSearchRepositoriesNode(BaseNode):
    name = "github_search_repositories_node"
    display_name = "GitHubSearchRepositories"
    logo = "/images/github.png"
    category = "GitHub"

    config_model = GitHubSearchRepositoriesNodeConfig
    input_model = GitHubSearchRepositoriesNodeInput
    output_model = GitHubSearchRepositoriesNodeOutput

    async def run(self, input: BaseModel) -> BaseModel:
        """Run the configured repository search and return the matches.

        Any exception is logged and reported as an output with an empty string.
        The ``input`` argument is not read.
        """
        settings = self.config
        try:
            github_client = GithubTools()
            matches = github_client.search_repositories(
                query=settings.query,
                sort=settings.sort,
                order=settings.order,
                per_page=settings.per_page,
            )
            outcome = GitHubSearchRepositoriesNodeOutput(repositories=matches)
        except Exception as e:
            logging.error(f"Failed to search repositories: {e}")
            outcome = GitHubSearchRepositoriesNodeOutput(repositories="")
        return outcome
--git a/pyspur/backend/pyspur/nodes/integrations/google/google_sheets_read.py b/pyspur/backend/pyspur/nodes/integrations/google/google_sheets_read.py new file mode 100644 index 0000000000000000000000000000000000000000..f80653fea088af84286b400d99b799f7d1f40076 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/integrations/google/google_sheets_read.py @@ -0,0 +1,66 @@ +import json + +from pydantic import BaseModel, Field + +from ....integrations.google.client import GoogleSheetsClient +from ...base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput + + +class GoogleSheetsReadNodeInput(BaseNodeInput): + """Input for the GoogleSheetsRead node""" + + class Config: + extra = "allow" + + +class GoogleSheetsReadNodeOutput(BaseNodeOutput): + data: str = Field(..., description="The data from the Google Sheet.") + + +class GoogleSheetsReadNodeConfig(BaseNodeConfig): + spreadsheet_id: str = Field("", description="The ID of the Google Sheet to read from.") + range: str = Field( + "", + description="The range of cells to read from (e.g. 'Sheet1!A1:B10').", + ) + has_fixed_output: bool = True + output_json_schema: str = Field( + default=json.dumps(GoogleSheetsReadNodeOutput.model_json_schema()), + description="The JSON schema for the output of the node", + ) + + +class GoogleSheetsReadNode(BaseNode): + """ + Node that reads data from a specified range in a Google Sheet. + """ + + name = "google_sheets_read_node" + display_name = "GoogleSheetsRead" + logo = "/images/google.png" + category = "Google" + + config_model = GoogleSheetsReadNodeConfig + input_model = GoogleSheetsReadNodeInput + output_model = GoogleSheetsReadNodeOutput + + async def run(self, input: BaseModel) -> BaseModel: + """ + Runs the node, uses GoogleSheetsClient to read from the specified + sheet and range, and returns the data in the output model. 
+ """ + sheets_client = GoogleSheetsClient() + + try: + success, result = sheets_client.read_sheet( + spreadsheet_id=self.config.spreadsheet_id, + range_name=self.config.range, + ) + + if success: + return GoogleSheetsReadNodeOutput(data=result) + else: + return GoogleSheetsReadNodeOutput(data=f"Error: {result}") + + except Exception as e: + return GoogleSheetsReadNodeOutput(data=f"Exception occurred: {str(e)}") diff --git a/pyspur/backend/pyspur/nodes/integrations/jina/__init__.py b/pyspur/backend/pyspur/nodes/integrations/jina/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyspur/backend/pyspur/nodes/integrations/jina/jina_reader.py b/pyspur/backend/pyspur/nodes/integrations/jina/jina_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..f1baaae809ff48b86545585a4d3cb246f4da8e2c --- /dev/null +++ b/pyspur/backend/pyspur/nodes/integrations/jina/jina_reader.py @@ -0,0 +1,72 @@ +import json +import logging + +import httpx +from pydantic import BaseModel, Field # type: ignore + +from ...base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput +from ...utils.template_utils import render_template_or_get_first_string + + +class JinaReaderNodeInput(BaseNodeInput): + """Input for the JinaReader node""" + + class Config: + extra = "allow" + + +class JinaReaderNodeOutput(BaseNodeOutput): + title: str = Field("", description="The title of scraped page") + content: str = Field("", description="The content of the scraped page in markdown format") + + +class JinaReaderNodeConfig(BaseNodeConfig): + url_template: str = Field( + "https://r.jina.ai/{url}", + description="The URL to crawl and convert into clean markdown.", + ) + use_readerlm_v2: bool = Field(True, description="Use the Reader LM v2 model to process the URL") + has_fixed_output: bool = True + output_json_schema: str = Field( + default=json.dumps(JinaReaderNodeOutput.model_json_schema()), + 
description="The JSON schema for the output of the node", + ) + + +class JinaReaderNode(BaseNode): + name = "jina_reader_node" + display_name = "Reader" + logo = "/images/jina.png" + category = "Jina.AI" + + config_model = JinaReaderNodeConfig + input_model = JinaReaderNodeInput + output_model = JinaReaderNodeOutput + + async def run(self, input: BaseModel) -> BaseModel: + try: + headers = { + "Accept": "application/json", + } + if self.config.use_readerlm_v2: + headers["X-Respond-With"] = "readerlm-v2" + + # Grab the entire dictionary from the input + raw_input_dict = input.model_dump() + + # Render url_template + reader_url = render_template_or_get_first_string( + self.config.url_template, raw_input_dict, self.name + ) + + async with httpx.AsyncClient() as client: + response = await client.get(reader_url, headers=headers, timeout=None) + logging.debug("Fetched from Jina: {text}".format(text=response.text)) + output = JinaReaderNodeOutput.model_validate(response.json()["data"]) + if output.content.startswith("```markdown"): + # remove the backticks/code format indicators in the output + output.content = output.content[12:-4] + return output + except Exception as e: + logging.error(f"Failed to convert URL: {e}") + return JinaReaderNodeOutput(title="", content="") diff --git a/pyspur/backend/pyspur/nodes/integrations/mathpix/pdf_to_latex.py b/pyspur/backend/pyspur/nodes/integrations/mathpix/pdf_to_latex.py new file mode 100644 index 0000000000000000000000000000000000000000..ce976f774a9665241ba969a6763572176f98caa2 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/integrations/mathpix/pdf_to_latex.py @@ -0,0 +1,89 @@ +import json +import logging +import os + +import requests +from pydantic import BaseModel, Field # type: ignore + +from ...base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput +from ...utils.template_utils import render_template_or_get_first_string + + +class MathpixPdfToLatexNodeInput(BaseNodeInput): + """Input for the MathpixPdfToLatex 
node""" + + class Config: + extra = "allow" + + +class MathpixPdfToLatexNodeOutput(BaseNodeOutput): + latex_result: str = Field( + ..., + description="The converted LaTeX content or conversion status JSON.", + ) + + +class MathpixPdfToLatexNodeConfig(BaseNodeConfig): + url_template: str = Field("", description="The URL of the PDF to convert to LaTeX.") + app_id: str = Field( + default="", + description="Mathpix API app_id. Can be set via environment variable MATHPIX_APP_ID.", + ) + app_key: str = Field( + default="", + description="Mathpix API app_key. Can be set via environment variable MATHPIX_APP_KEY.", + ) + has_fixed_output: bool = True + output_json_schema: str = Field( + default=json.dumps(MathpixPdfToLatexNodeOutput.model_json_schema()), + description="The JSON schema for the output of the node", + ) + + +class MathpixPdfToLatexNode(BaseNode): + name = "mathpix_pdf_to_latex_node" + display_name = "Mathpix PDF to LaTeX" + logo = "/images/mathpix.png" + category = "Mathpix" + + config_model = MathpixPdfToLatexNodeConfig + input_model = MathpixPdfToLatexNodeInput + output_model = MathpixPdfToLatexNodeOutput + + async def run(self, input: BaseModel) -> BaseModel: + """ + Converts a PDF to LaTeX using the Mathpix API. 
+ """ + try: + # Get input dictionary + raw_input_dict = input.model_dump() + + # Render URL template + url = render_template_or_get_first_string( + self.config.url_template, raw_input_dict, self.name + ) + + # Get credentials from config or environment + app_id = self.config.app_id or os.environ.get("MATHPIX_APP_ID") + app_key = self.config.app_key or os.environ.get("MATHPIX_APP_KEY") + + if not app_id or not app_key: + raise ValueError("Mathpix API credentials not provided") + + # Make API request + response = requests.post( + "https://api.mathpix.com/v3/pdf", + json={"url": url, "conversion_formats": {"tex.zip": True}}, + headers={ + "app_id": app_id, + "app_key": app_key, + "Content-type": "application/json", + }, + ) + + # Return the conversion result + return MathpixPdfToLatexNodeOutput(latex_result=json.dumps(response.json(), indent=2)) + + except Exception as e: + logging.error(f"Failed to convert PDF to LaTeX: {e}") + return MathpixPdfToLatexNodeOutput(latex_result=str(e)) diff --git a/pyspur/backend/pyspur/nodes/integrations/meta/ad_library.py b/pyspur/backend/pyspur/nodes/integrations/meta/ad_library.py new file mode 100644 index 0000000000000000000000000000000000000000..89df31ac18123cccd47d35c4c30898891877a413 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/integrations/meta/ad_library.py @@ -0,0 +1,120 @@ +import json +import logging +from typing import List, Optional + +import requests +from pydantic import BaseModel, Field + +from ...base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput + + +class FacebookAdLibraryNodeInput(BaseNodeInput): + """Input for the FacebookAdLibrary node""" + + class Config: + extra = "allow" + + +class FacebookAdLibraryNodeOutput(BaseNodeOutput): + ads: str = Field(..., description="JSON string containing the retrieved ad data") + + +class FacebookAdLibraryNodeConfig(BaseNodeConfig): + access_token: str = Field("", description="Meta API access token for authentication") + profile_url: str = Field("", 
description="Facebook profile URL to search ads for") + country_code: str = Field( + "US", + description="Two-letter country code for ad search (e.g., US, GB, DE)", + ) + media_type: str = Field("ALL", description="Type of media to search for (ALL, IMAGE, VIDEO)") + platforms: List[str] = Field( + default=["FACEBOOK", "INSTAGRAM"], + description="Platforms to search ads on (FACEBOOK, INSTAGRAM)", + ) + max_ads: int = Field(100, description="Maximum number of ads to retrieve (max 500)") + ad_active_status: str = Field( + "ACTIVE", description="Filter by ad status (ACTIVE, INACTIVE, ALL)" + ) + fields: List[str] = Field( + default=[ + "ad_creation_time", + "ad_creative_body", + "ad_creative_link_caption", + "ad_creative_link_description", + "ad_creative_link_title", + "ad_snapshot_url", + "page_id", + "page_name", + "publisher_platforms", + "spend", + "impressions", + ], + description="Fields to retrieve from the Ad Library API", + ) + has_fixed_output: bool = True + output_json_schema: str = Field( + default=json.dumps(FacebookAdLibraryNodeOutput.model_json_schema()), + description="The JSON schema for the output of the node", + ) + + def __init__(self, **data): + super().__init__(**data) + + +class FacebookAdLibraryNode(BaseNode): + name = "facebook_ad_library_node" + display_name = "FacebookAdLibrary" + logo = "/images/meta.png" + category = "Meta" + + config_model = FacebookAdLibraryNodeConfig + input_model = FacebookAdLibraryNodeInput + output_model = FacebookAdLibraryNodeOutput + + def _extract_page_id(self, profile_url: str) -> Optional[str]: + """Extract page ID from Facebook profile URL""" + # This is a simplified version - you may need to enhance this based on URL formats + try: + if "facebook.com" not in profile_url: + return None + parts = profile_url.rstrip("/").split("/") + return parts[-1] + except Exception: + return None + + async def run(self, input: BaseModel) -> BaseModel: + try: + page_id = self._extract_page_id(self.config.profile_url) + if not 
page_id: + raise ValueError("Invalid Facebook profile URL") + + api_version = "v18.0" # Using latest stable version + base_url = f"https://graph.facebook.com/{api_version}/ads_archive" + + params = { + "access_token": self.config.access_token, + "search_page_ids": page_id, + "ad_reached_countries": self.config.country_code, + "ad_active_status": self.config.ad_active_status, + "limit": min(500, self.config.max_ads), # API limit is 500 + "fields": ",".join(self.config.fields), + "ad_type": self.config.media_type, + "publisher_platforms": ",".join(self.config.platforms), + } + + response = requests.get(base_url, params=params) + if response.status_code != 200: + raise Exception(f"API request failed: {response.text}") + + data = response.json() + + # Format and return the results + ads_data = data.get("data", []) + if len(ads_data) > self.config.max_ads: + ads_data = ads_data[: self.config.max_ads] + + return FacebookAdLibraryNodeOutput(ads=json.dumps(ads_data)) + + except Exception as e: + logging.error(f"Failed to retrieve Facebook ads: {str(e)}") + return FacebookAdLibraryNodeOutput(ads="[]") diff --git a/pyspur/backend/pyspur/nodes/integrations/reddit/__init__.py b/pyspur/backend/pyspur/nodes/integrations/reddit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyspur/backend/pyspur/nodes/integrations/reddit/reddit_create_post.py b/pyspur/backend/pyspur/nodes/integrations/reddit/reddit_create_post.py new file mode 100644 index 0000000000000000000000000000000000000000..9481f3ef659d9b08fcf8cb66dccd241891178cf5 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/integrations/reddit/reddit_create_post.py @@ -0,0 +1,142 @@ +import json +import logging +import os +from typing import Optional + +import praw +from pydantic import BaseModel, Field # type: ignore + +from ...base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput + + +class RedditCreatePostNodeInput(BaseNodeInput): + 
"""Input for the RedditCreatePost node.""" + + class Config: + extra = "allow" + + +class CreatedPostInfo(BaseModel): + id: str = Field(..., description="ID of the created post") + title: str = Field(..., description="Title of the created post") + url: str = Field(..., description="URL of the created post") + permalink: str = Field(..., description="Reddit permalink to the post") + created_utc: float = Field(..., description="Creation timestamp in UTC") + author: str = Field(..., description="Username of the post author") + flair: Optional[str] = Field(None, description="Flair text of the post if any") + + +class RedditCreatePostError(BaseModel): + error: str = Field(..., description="Error message explaining what went wrong") + + +class RedditCreatePostNodeOutput(BaseNodeOutput): + post_info: CreatedPostInfo | RedditCreatePostError = Field( + ..., description="Information about the created post or error details" + ) + + +class RedditCreatePostNodeConfig(BaseNodeConfig): + subreddit: str = Field("", description="The subreddit to post in.") + title: str = Field("", description="The title of the post.") + content: str = Field( + "", description="The content of the post (text for self posts, URL for link posts)." + ) + flair: Optional[str] = Field(None, description="Optional flair to add to the post.") + is_self: bool = Field(True, description="Whether this is a self (text) post or link post.") + username: str = Field( + "", description="Reddit username. Can also be set via REDDIT_USERNAME environment variable." + ) + password: str = Field( + "", description="Reddit password. Can also be set via REDDIT_PASSWORD environment variable." 
+ ) + has_fixed_output: bool = True + output_json_schema: str = Field( + default=json.dumps(RedditCreatePostNodeOutput.model_json_schema()), + description="The JSON schema for the output of the node", + ) + + +class RedditCreatePostNode(BaseNode): + name = "reddit_create_post_node" + display_name = "RedditCreatePost" + logo = "/images/reddit.png" + category = "Reddit" + + config_model = RedditCreatePostNodeConfig + input_model = RedditCreatePostNodeInput + output_model = RedditCreatePostNodeOutput + + async def run(self, input: BaseModel) -> BaseModel: + try: + client_id = os.getenv("REDDIT_CLIENT_ID") + client_secret = os.getenv("REDDIT_CLIENT_SECRET") + user_agent = os.getenv("REDDIT_USER_AGENT", "RedditTools v1.0") + username = os.getenv("REDDIT_USERNAME") or self.config.username + password = os.getenv("REDDIT_PASSWORD") or self.config.password + + if not client_id or not client_secret: + raise ValueError("Reddit API credentials not found in environment variables") + + if not username or not password: + raise ValueError("Reddit username and password are required for creating posts") + + reddit = praw.Reddit( + client_id=client_id, + client_secret=client_secret, + user_agent=user_agent, + username=username, + password=password, + ) + + # Verify authentication + try: + reddit.user.me() + except Exception as e: + logging.error(f"Authentication error: {e}") + return RedditCreatePostNodeOutput( + post_info=RedditCreatePostError(error="Failed to authenticate with Reddit") + ) + + subreddit = reddit.subreddit(self.config.subreddit) + + # Check flair if provided + if self.config.flair: + available_flairs = [f["text"] for f in subreddit.flair.link_templates] + if self.config.flair not in available_flairs: + return RedditCreatePostNodeOutput( + post_info=RedditCreatePostError( + error=f"Invalid flair. 
Available flairs: {', '.join(available_flairs)}" + ) + ) + + # Create the post + if self.config.is_self: + submission = subreddit.submit( + title=self.config.title, + selftext=self.config.content, + flair_id=self.config.flair, + ) + else: + submission = subreddit.submit( + title=self.config.title, + url=self.config.content, + flair_id=self.config.flair, + ) + + post_info = CreatedPostInfo( + id=submission.id, + title=submission.title, + url=submission.url, + permalink=submission.permalink, + created_utc=submission.created_utc, + author=str(submission.author), + flair=submission.link_flair_text, + ) + + return RedditCreatePostNodeOutput(post_info=post_info) + except Exception as e: + logging.error(f"Failed to create post: {e}") + return RedditCreatePostNodeOutput( + post_info=RedditCreatePostError(error=f"Failed to create post: {str(e)}") + ) diff --git a/pyspur/backend/pyspur/nodes/integrations/reddit/reddit_get_subreddit_info.py b/pyspur/backend/pyspur/nodes/integrations/reddit/reddit_get_subreddit_info.py new file mode 100644 index 0000000000000000000000000000000000000000..1856bc89167839f776e0f144fd824d13a8b25876 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/integrations/reddit/reddit_get_subreddit_info.py @@ -0,0 +1,109 @@ +import json +import logging +import os +from typing import List + +import praw +from pydantic import BaseModel, Field # type: ignore + +from ...base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput + + +class RedditGetSubredditInfoNodeInput(BaseNodeInput): + """Input for the RedditGetSubredditInfo node""" + + class Config: + extra = "allow" + + +class SubredditInfo(BaseModel): + display_name: str = Field(..., description="Display name of the subreddit") + title: str = Field(..., description="Title of the subreddit") + description: str = Field(..., description="Full description of the subreddit") + subscribers: int = Field(..., description="Number of subscribers") + created_utc: float = Field(..., description="Creation 
class RedditGetSubredditInfoNode(BaseNode):
    """Workflow node that fetches descriptive metadata for a single subreddit.

    Uses application-only PRAW credentials taken from ``REDDIT_CLIENT_ID`` and
    ``REDDIT_CLIENT_SECRET`` (``REDDIT_USER_AGENT`` is optional).
    """

    name = "reddit_get_subreddit_info_node"
    display_name = "RedditGetSubredditInfo"
    logo = "/images/reddit.png"
    category = "Reddit"

    config_model = RedditGetSubredditInfoNodeConfig
    input_model = RedditGetSubredditInfoNodeInput
    output_model = RedditGetSubredditInfoNodeOutput

    def setup(self) -> None:
        """Run the base setup, tolerating only the known schema-type error.

        Raises:
            ValueError: Any ``ValueError`` from the base ``setup`` other than
                the "Unsupported JSON schema type" one. Previously every
                ``ValueError`` was swallowed without even being logged.
        """
        try:
            super().setup()
        except ValueError as e:
            if "Unsupported JSON schema type" not in str(e):
                # Not the schema problem this override exists for: surface it.
                raise
            # If we hit schema issues, use a very basic setup
            logging.warning(f"Schema error: {e}, using simplified approach")

    async def run(self, input: BaseModel) -> BaseModel:
        """Look up ``self.config.subreddit_name`` and return its metadata.

        Args:
            input: Upstream node output; not read by this node.

        Returns:
            ``RedditGetSubredditInfoNodeOutput`` wrapping a ``SubredditInfo``.

        Raises:
            ValueError: If the client id or secret is missing from the environment.
            Exception: Any error raised by PRAW is logged and re-raised.
        """
        try:
            client_id = os.getenv("REDDIT_CLIENT_ID")
            client_secret = os.getenv("REDDIT_CLIENT_SECRET")
            user_agent = os.getenv("REDDIT_USER_AGENT", "RedditTools v1.0")

            if not client_id or not client_secret:
                raise ValueError("Reddit API credentials not found in environment variables")

            reddit = praw.Reddit(
                client_id=client_id,
                client_secret=client_secret,
                user_agent=user_agent,
            )

            subreddit = reddit.subreddit(self.config.subreddit_name)
            # Only the display text of each link flair template is exposed.
            flairs = [flair["text"] for flair in subreddit.flair.link_templates]
            info = SubredditInfo(
                display_name=subreddit.display_name,
                title=subreddit.title,
                description=subreddit.description,
                subscribers=subreddit.subscribers,
                created_utc=subreddit.created_utc,
                over18=subreddit.over18,
                available_flairs=flairs,
                public_description=subreddit.public_description,
                url=subreddit.url,
            )
            return RedditGetSubredditInfoNodeOutput(subreddit_info=info)
        except Exception as e:
            logging.error(f"Failed to get subreddit info: {e}")
            # Bare raise keeps the original traceback intact.
            raise
class RedditGetSubredditStatsNode(BaseNode):
    """Workflow node that reports subscriber counts and recent-activity metrics.

    Activity metrics are computed over up to 100 of the subreddit's newest
    posts. Uses application-only PRAW credentials from ``REDDIT_CLIENT_ID``
    and ``REDDIT_CLIENT_SECRET`` (``REDDIT_USER_AGENT`` is optional).
    """

    name = "reddit_get_subreddit_stats_node"
    display_name = "RedditGetSubredditStats"
    logo = "/images/reddit.png"
    category = "Reddit"

    config_model = RedditGetSubredditStatsNodeConfig
    input_model = RedditGetSubredditStatsNodeInput
    output_model = RedditGetSubredditStatsNodeOutput

    def setup(self) -> None:
        """Run the base setup, tolerating only the known schema-type error.

        Raises:
            ValueError: Any ``ValueError`` from the base ``setup`` other than
                the "Unsupported JSON schema type" one. Previously every
                ``ValueError`` was swallowed without even being logged.
        """
        try:
            super().setup()
        except ValueError as e:
            if "Unsupported JSON schema type" not in str(e):
                # Not the schema problem this override exists for: surface it.
                raise
            # If we hit schema issues, use a very basic setup
            logging.warning(f"Schema error: {e}, using simplified approach")

    async def run(self, input: BaseModel) -> BaseModel:
        """Collect statistics for ``self.config.subreddit``.

        Args:
            input: Upstream node output; not read by this node.

        Returns:
            ``RedditGetSubredditStatsNodeOutput`` wrapping a ``SubredditStats``.

        Raises:
            ValueError: If the client id or secret is missing from the environment.
            Exception: Any error raised while querying Reddit is logged and re-raised.
        """
        try:
            client_id = os.getenv("REDDIT_CLIENT_ID")
            client_secret = os.getenv("REDDIT_CLIENT_SECRET")
            user_agent = os.getenv("REDDIT_USER_AGENT", "RedditTools v1.0")

            if not client_id or not client_secret:
                raise ValueError("Reddit API credentials not found in environment variables")

            reddit = praw.Reddit(
                client_id=client_id,
                client_secret=client_secret,
                user_agent=user_agent,
            )

            subreddit = reddit.subreddit(self.config.subreddit)

            # Get recent posts for activity metrics
            recent_posts = list(subreddit.new(limit=100))
            post_count = len(recent_posts)
            total_comments = sum(post.num_comments for post in recent_posts)
            total_score = sum(post.score for post in recent_posts)

            stats = SubredditStats(
                display_name=subreddit.display_name,
                subscribers=subreddit.subscribers,
                active_users=subreddit.active_user_count,
                description=subreddit.description,
                created_utc=subreddit.created_utc,
                over18=subreddit.over18,
                public_description=subreddit.public_description,
                recent_activity=RecentActivity(
                    total_comments_last_100_posts=total_comments,
                    total_score_last_100_posts=total_score,
                    # A subreddit with no posts yields averages of 0 rather
                    # than a division by zero.
                    average_comments_per_post=total_comments / post_count if post_count else 0,
                    average_score_per_post=total_score / post_count if post_count else 0,
                ),
            )

            return RedditGetSubredditStatsNodeOutput(subreddit_stats=stats)
        except Exception as e:
            logging.error(f"Failed to get subreddit stats: {e}")
            # Bare raise keeps the original traceback intact.
            raise
b/pyspur/backend/pyspur/nodes/integrations/reddit/reddit_get_top_posts.py @@ -0,0 +1,125 @@ +import json +import logging +import os +from typing import List, Optional + +import praw +from pydantic import BaseModel, Field # type: ignore + +from ...base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput + + +class RedditGetTopPostsNodeInput(BaseNodeInput): + """Input for the RedditGetTopPosts node.""" + + class Config: + extra = "allow" + + +class RedditPost(BaseModel): + title: str = Field(..., description="Title of the post") + score: int = Field(..., description="Score (upvotes) of the post") + url: str = Field(..., description="URL of the post") + author: str = Field(..., description="Username of the post author") + created_utc: float = Field(..., description="Creation timestamp in UTC") + num_comments: int = Field(..., description="Number of comments on the post") + permalink: str = Field(..., description="Reddit permalink to the post") + is_self: bool = Field(..., description="Whether this is a self (text) post") + selftext: Optional[str] = Field(None, description="Text content if this is a self post") + + +class RedditGetTopPostsNodeOutput(BaseNodeOutput): + top_posts: List[RedditPost] = Field(..., description="List of top posts from the subreddit") + + +# Define a simple schema without complex nested structures +SIMPLE_OUTPUT_SCHEMA = { + "title": "RedditGetTopPostsNodeOutput", + "type": "object", + "properties": { + "top_posts": { + "title": "Top Posts", + "type": "array", + "description": "List of top posts from the subreddit", + "items": {"type": "object"}, + } + }, + "required": ["top_posts"], +} + + +class RedditGetTopPostsNodeConfig(BaseNodeConfig): + subreddit: str = Field("", description="The name of the subreddit to get posts from.") + time_filter: str = Field( + "week", description="Time period to filter posts (hour, day, week, month, year, all)." 
class RedditGetTopPostsNode(BaseNode):
    """Workflow node that lists the top posts of a subreddit for a time window.

    Uses application-only PRAW credentials from ``REDDIT_CLIENT_ID`` and
    ``REDDIT_CLIENT_SECRET`` (``REDDIT_USER_AGENT`` is optional).
    """

    name = "reddit_get_top_posts_node"
    display_name = "RedditGetTopPosts"
    logo = "/images/reddit.png"
    category = "Reddit"

    config_model = RedditGetTopPostsNodeConfig
    input_model = RedditGetTopPostsNodeInput
    output_model = RedditGetTopPostsNodeOutput

    def setup(self) -> None:
        """Run the base setup, tolerating only the known schema-type error.

        Raises:
            ValueError: Any ``ValueError`` from the base ``setup`` other than
                the "Unsupported JSON schema type" one. Previously every
                ``ValueError`` was swallowed without even being logged.
        """
        try:
            super().setup()
        except ValueError as e:
            if "Unsupported JSON schema type" not in str(e):
                # Not the schema problem this override exists for: surface it.
                raise
            # If we hit schema issues, use a very basic setup
            logging.warning(f"Schema error: {e}, using simplified approach")

    async def run(self, input: BaseModel) -> BaseModel:
        """Fetch top posts from ``self.config.subreddit``.

        The request is capped at 100 posts. When ``only_self_posts`` is set,
        link posts are filtered out *after* fetching, so fewer than ``limit``
        posts may be returned.

        Args:
            input: Upstream node output; not read by this node.

        Returns:
            ``RedditGetTopPostsNodeOutput`` with a list of ``RedditPost``.

        Raises:
            ValueError: If the client id or secret is missing from the environment.
            Exception: Any error raised while querying Reddit is logged and re-raised.
        """
        try:
            client_id = os.getenv("REDDIT_CLIENT_ID")
            client_secret = os.getenv("REDDIT_CLIENT_SECRET")
            user_agent = os.getenv("REDDIT_USER_AGENT", "RedditTools v1.0")

            if not client_id or not client_secret:
                raise ValueError("Reddit API credentials not found in environment variables")

            reddit = praw.Reddit(
                client_id=client_id,
                client_secret=client_secret,
                user_agent=user_agent,
            )

            posts = reddit.subreddit(self.config.subreddit).top(
                time_filter=self.config.time_filter,
                limit=min(self.config.limit, 100),  # Cap at 100 posts
            )

            top_posts = [
                RedditPost(
                    title=post.title,
                    score=post.score,
                    url=post.url,
                    author=str(post.author),
                    created_utc=post.created_utc,
                    num_comments=post.num_comments,
                    permalink=post.permalink,
                    is_self=post.is_self,
                    # Body text is only included for self (text) posts.
                    selftext=post.selftext if post.is_self else None,
                )
                for post in posts
                if not self.config.only_self_posts or post.is_self
            ]

            return RedditGetTopPostsNodeOutput(top_posts=top_posts)
        except Exception as e:
            logging.error(f"Failed to get top posts: {e}")
            # Bare raise keeps the original traceback intact.
            raise
class RedditGetTrendingSubredditsNode(BaseNode):
    """Workflow node that lists subreddits from Reddit's "popular" listing.

    Uses application-only PRAW credentials from ``REDDIT_CLIENT_ID`` and
    ``REDDIT_CLIENT_SECRET`` (``REDDIT_USER_AGENT`` is optional).
    """

    name = "reddit_get_trending_subreddits_node"
    display_name = "RedditGetTrendingSubreddits"
    logo = "/images/reddit.png"
    category = "Reddit"

    config_model = RedditGetTrendingSubredditsNodeConfig
    input_model = RedditGetTrendingSubredditsNodeInput
    output_model = RedditGetTrendingSubredditsNodeOutput

    def setup(self) -> None:
        """Run the base setup, tolerating only the known schema-type error.

        Raises:
            ValueError: Any ``ValueError`` from the base ``setup`` other than
                the "Unsupported JSON schema type" one. Previously every
                ``ValueError`` was swallowed without even being logged.
        """
        try:
            super().setup()
        except ValueError as e:
            if "Unsupported JSON schema type" not in str(e):
                # Not the schema problem this override exists for: surface it.
                raise
            # If we hit schema issues, use a very basic setup
            logging.warning(f"Schema error: {e}, using simplified approach")

    async def run(self, input: BaseModel) -> BaseModel:
        """Fetch up to ``self.config.limit`` (capped at 100) popular subreddits.

        Args:
            input: Upstream node output; not read by this node.

        Returns:
            ``RedditGetTrendingSubredditsNodeOutput`` with a list of
            ``TrendingSubreddit``.

        Raises:
            ValueError: If the client id or secret is missing from the environment.
            Exception: Any error raised while querying Reddit is logged and re-raised.
        """
        try:
            client_id = os.getenv("REDDIT_CLIENT_ID")
            client_secret = os.getenv("REDDIT_CLIENT_SECRET")
            user_agent = os.getenv("REDDIT_USER_AGENT", "RedditTools v1.0")

            if not client_id or not client_secret:
                raise ValueError("Reddit API credentials not found in environment variables")

            reddit = praw.Reddit(
                client_id=client_id,
                client_secret=client_secret,
                user_agent=user_agent,
            )

            popular_subreddits = reddit.subreddits.popular(limit=min(self.config.limit, 100))
            trending = [
                TrendingSubreddit(
                    name=subreddit.display_name,
                    title=subreddit.title,
                    # The short public description is used, not the full one.
                    description=subreddit.public_description,
                    subscribers=subreddit.subscribers,
                    url=subreddit.url,
                    over18=subreddit.over18,
                )
                for subreddit in popular_subreddits
            ]

            return RedditGetTrendingSubredditsNodeOutput(trending_subreddits=trending)
        except Exception as e:
            logging.error(f"Failed to get trending subreddits: {e}")
            # Bare raise keeps the original traceback intact.
            raise
a/pyspur/backend/pyspur/nodes/integrations/reddit/reddit_get_user_info.py b/pyspur/backend/pyspur/nodes/integrations/reddit/reddit_get_user_info.py new file mode 100644 index 0000000000000000000000000000000000000000..de823a251654c2bd5597f83c337c3c87eb6840fc --- /dev/null +++ b/pyspur/backend/pyspur/nodes/integrations/reddit/reddit_get_user_info.py @@ -0,0 +1,103 @@ +import json +import logging +import os + +import praw +from pydantic import BaseModel, Field # type: ignore + +from ...base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput + + +class RedditGetUserInfoNodeInput(BaseNodeInput): + """Input for the RedditGetUserInfo node""" + + class Config: + extra = "allow" + + +class RedditUserInfo(BaseModel): + name: str = Field(..., description="Username of the Reddit user") + comment_karma: int = Field(..., description="Total karma from comments") + link_karma: int = Field(..., description="Total karma from posts/links") + is_mod: bool = Field(..., description="Whether the user is a moderator") + is_gold: bool = Field(..., description="Whether the user has Reddit gold") + is_employee: bool = Field(..., description="Whether the user is a Reddit employee") + created_utc: float = Field(..., description="Account creation timestamp in UTC") + + +class RedditGetUserInfoNodeOutput(BaseNodeOutput): + user_info: RedditUserInfo = Field(..., description="Information about the Reddit user") + + +# Define a simple schema without complex nested structures +SIMPLE_OUTPUT_SCHEMA = { + "title": "RedditGetUserInfoNodeOutput", + "type": "object", + "properties": { + "user_info": { + "title": "User Info", + "type": "object", + "description": "Information about the Reddit user", + } + }, + "required": ["user_info"], +} + + +class RedditGetUserInfoNodeConfig(BaseNodeConfig): + username: str = Field("", description="The Reddit username to get information for.") + has_fixed_output: bool = True + output_json_schema: str = Field( + default=json.dumps(SIMPLE_OUTPUT_SCHEMA), + 
class RedditGetUserInfoNode(BaseNode):
    """Workflow node that fetches public profile data for a Reddit user.

    Uses application-only PRAW credentials from ``REDDIT_CLIENT_ID`` and
    ``REDDIT_CLIENT_SECRET`` (``REDDIT_USER_AGENT`` is optional).
    """

    name = "reddit_get_user_info_node"
    display_name = "RedditGetUserInfo"
    logo = "/images/reddit.png"
    category = "Reddit"

    config_model = RedditGetUserInfoNodeConfig
    input_model = RedditGetUserInfoNodeInput
    output_model = RedditGetUserInfoNodeOutput

    def setup(self) -> None:
        """Run the base setup, tolerating only the known schema-type error.

        Raises:
            ValueError: Any ``ValueError`` from the base ``setup`` other than
                the "Unsupported JSON schema type" one. Previously every
                ``ValueError`` was swallowed without even being logged.
        """
        try:
            super().setup()
        except ValueError as e:
            if "Unsupported JSON schema type" not in str(e):
                # Not the schema problem this override exists for: surface it.
                raise
            # If we hit schema issues, use a very basic setup
            logging.warning(f"Schema error: {e}, using simplified approach")

    async def run(self, input: BaseModel) -> BaseModel:
        """Look up ``self.config.username`` and return karma and account flags.

        Args:
            input: Upstream node output; not read by this node.

        Returns:
            ``RedditGetUserInfoNodeOutput`` wrapping a ``RedditUserInfo``.

        Raises:
            ValueError: If the client id or secret is missing from the environment.
            Exception: Any error raised while querying Reddit is logged and re-raised.
        """
        try:
            client_id = os.getenv("REDDIT_CLIENT_ID")
            client_secret = os.getenv("REDDIT_CLIENT_SECRET")
            user_agent = os.getenv("REDDIT_USER_AGENT", "RedditTools v1.0")

            if not client_id or not client_secret:
                raise ValueError("Reddit API credentials not found in environment variables")

            reddit = praw.Reddit(
                client_id=client_id,
                client_secret=client_secret,
                user_agent=user_agent,
            )

            user = reddit.redditor(self.config.username)
            info = RedditUserInfo(
                name=user.name,
                comment_karma=user.comment_karma,
                link_karma=user.link_karma,
                is_mod=user.is_mod,
                is_gold=user.is_gold,
                is_employee=user.is_employee,
                created_utc=user.created_utc,
            )
            return RedditGetUserInfoNodeOutput(user_info=info)
        except Exception as e:
            logging.error(f"Failed to get user info: {e}")
            # Bare raise keeps the original traceback intact.
            raise
class SlackNotifyNode(BaseNode):
    """Workflow node that posts a (optionally templated) message to a Slack channel."""

    name = "slack_notify_node"
    display_name = "SlackNotify"
    logo = "/images/slack.png"
    category = "Slack"

    config_model = SlackNotifyNodeConfig
    input_model = SlackNotifyNodeInput
    output_model = SlackNotifyNodeOutput

    async def run(self, input: BaseModel) -> BaseModel:
        """Send a message to the configured Slack channel.

        If ``self.config.message`` is blank, the whole input is sent as
        pretty-printed JSON. Otherwise the message is rendered as a Jinja2
        template with the input fields as variables.

        Args:
            input: Upstream node output; its fields feed the message template.

        Returns:
            ``SlackNotifyNodeOutput`` carrying the status string returned by
            ``SlackClient.send_message``.

        Raises:
            Exception: Template rendering errors are logged and re-raised.
        """
        # Imported locally because this module has no top-level ``logging``
        # import; sibling integration nodes report errors through ``logging``.
        import logging

        payload = input.model_dump()

        if not self.config.message.strip():
            # If no template is provided, dump the entire input as JSON
            message = json.dumps(payload, indent=2)
        else:
            # Render the message template with input variables
            try:
                message = Template(self.config.message).render(**payload)
            except Exception:
                logging.error(f"Failed to render message template in {self.name}")
                logging.error(f"Template: {self.config.message} with input: {payload}")
                raise

        client = SlackClient()
        ok, status = client.send_message(
            channel=self.config.channel, text=message, mode=self.config.mode
        )  # type: ignore
        # NOTE(review): assumes a falsy ``ok`` means the send failed - confirm
        # against SlackClient.send_message. The status is still returned to the
        # caller unchanged; the failure is additionally recorded in the log.
        if not ok:
            logging.error(
                f"Slack message to channel {self.config.channel} was not sent: {status}"
            )
        return SlackNotifyNodeOutput(status=status)
class YouTubeTranscriptNode(BaseNode):
    """Workflow node that returns the captions of a YouTube video."""

    name = "youtube_transcript_node"
    display_name = "YouTubeTranscript"
    logo = "/images/youtube.png"
    category = "YouTube"

    config_model = YouTubeTranscriptNodeConfig
    input_model = YouTubeTranscriptNodeInput
    output_model = YouTubeTranscriptNodeOutput

    async def run(self, input: BaseModel) -> BaseModel:
        """Resolve the video URL from the configured template and fetch its captions.

        The URL is produced by ``render_template_or_get_first_string`` from
        ``self.config.video_url_template`` and the dumped input. Any exception
        along the way is logged and an empty transcript is returned instead
        of propagating the error.
        """
        try:
            # Resolve the target URL from the template and the full input dict.
            input_fields = input.model_dump()
            resolved_url = render_template_or_get_first_string(
                self.config.video_url_template, input_fields, self.name
            )

            captions: str = YouTubeTools().get_youtube_video_captions(url=resolved_url)
            return YouTubeTranscriptNodeOutput(transcript=captions)
        except Exception as e:
            logging.error(f"Failed to get transcript: {e}")
            return YouTubeTranscriptNodeOutput(transcript="")
class ModelConstraints(BaseModel):
    """Per-model limits and capability flags consulted when calling an LLM."""

    max_tokens: int
    min_temperature: float = 0.0
    max_temperature: float = 1.0
    supports_JSON_output: bool = True
    supports_max_tokens: bool = True
    supports_temperature: bool = True
    supported_mime_types: Set[RecognisedMimeType] = set()  # Empty set means no multimodal support
    supports_reasoning: bool = False
    reasoning_separator: str = r".*?"
    supports_thinking: bool = False
    thinking_budget_tokens: Optional[int] = None

    def add_mime_categories(self, categories: Set[MimeCategory]) -> "ModelConstraints":
        """Enable every MIME type belonging to the given categories.

        Args:
            categories: Categories whose MIME types (looked up in
                ``MIME_TYPES_BY_CATEGORY``) are added to the supported set.

        Returns:
            This instance, mutated in place, so calls can be chained.

        """
        supported = self.supported_mime_types
        for mime_category in categories:
            supported.update(MIME_TYPES_BY_CATEGORY[mime_category])
        return self

    def add_mime_types(self, mime_types: Set[RecognisedMimeType]) -> "ModelConstraints":
        """Enable the given individual MIME types.

        Args:
            mime_types: MIME types to add to the supported set.

        Returns:
            This instance, mutated in place, so calls can be chained.

        """
        supported = self.supported_mime_types
        supported.update(mime_types)
        return self

    def is_mime_type_supported(self, mime_type: RecognisedMimeType) -> bool:
        """Report whether ``mime_type`` is in the supported set.

        Args:
            mime_type: The MIME type to look up.

        Returns:
            True when the type has been added to ``supported_mime_types``.

        """
        supported = self.supported_mime_types
        return mime_type in supported
OLLAMA_LLAMA3_8B = "ollama/llama3:8b" + OLLAMA_LLAMA3_70B = "ollama/llama3:70b" + OLLAMA_GEMMA_3_1B = "ollama/gemma3:1b" + OLLAMA_GEMMA_3_4B = "ollama/gemma3:4b" + OLLAMA_GEMMA_3_12B = "ollama/gemma3:12b" + OLLAMA_GEMMA_3_27B = "ollama/gemma3:27b" + OLLAMA_GEMMA_2 = "ollama/gemma2" + OLLAMA_GEMMA_2_2B = "ollama/gemma2:2b" + OLLAMA_MISTRAL = "ollama/mistral" + OLLAMA_CODELLAMA = "ollama/codellama" + OLLAMA_MIXTRAL = "ollama/mixtral-8x7b-instruct-v0.1" + + XAI_GROK_2 = "xai/grok-2-latest" + + @classmethod + def get_model_info(cls, model_id: str) -> LLMModel | None: + model_registry = { + cls.O3_MINI.value: LLMModel( + id=cls.O3_MINI.value, + provider=LLMProvider.OPENAI, + name="O3 Mini", + constraints=ModelConstraints( + max_tokens=100000, + max_temperature=2.0, + supports_max_tokens=False, + supports_temperature=False, + ).add_mime_categories({MimeCategory.IMAGES}), + ), + cls.O3_MINI_2025_01_31.value: LLMModel( + id=cls.O3_MINI_2025_01_31.value, + provider=LLMProvider.OPENAI, + name="O3 Mini (2025-01-31)", + constraints=ModelConstraints( + max_tokens=100000, + max_temperature=2.0, + supports_max_tokens=False, + supports_temperature=False, + ).add_mime_categories({MimeCategory.IMAGES}), + ), + cls.GPT_4O_MINI.value: LLMModel( + id=cls.GPT_4O_MINI.value, + provider=LLMProvider.OPENAI, + name="GPT-4O Mini", + constraints=ModelConstraints( + max_tokens=16384, max_temperature=2.0 + ).add_mime_categories({MimeCategory.IMAGES}), + ), + cls.GPT_4O.value: LLMModel( + id=cls.GPT_4O.value, + provider=LLMProvider.OPENAI, + name="GPT-4O", + constraints=ModelConstraints( + max_tokens=16384, max_temperature=2.0 + ).add_mime_categories({MimeCategory.IMAGES}), + ), + cls.GPT_4O_MINI_SEARCH_PREVIEW.value: LLMModel( + id=cls.GPT_4O_MINI_SEARCH_PREVIEW.value, + provider=LLMProvider.OPENAI, + name="GPT-4O Mini Search Preview", + constraints=ModelConstraints( + max_tokens=16384, + supports_temperature=False + ).add_mime_categories({MimeCategory.IMAGES}), + ), + 
cls.GPT_4O_SEARCH_PREVIEW.value: LLMModel( + id=cls.GPT_4O_SEARCH_PREVIEW.value, + provider=LLMProvider.OPENAI, + name="GPT-4O Search Preview", + constraints=ModelConstraints( + max_tokens=16384, + supports_temperature=False + ).add_mime_categories({MimeCategory.IMAGES}), + ), + cls.O1_PREVIEW.value: LLMModel( + id=cls.O1_PREVIEW.value, + provider=LLMProvider.OPENAI, + name="O1 Preview", + constraints=ModelConstraints( + max_tokens=32768, + max_temperature=2.0, + supports_max_tokens=False, + supports_temperature=False, + ).add_mime_categories({MimeCategory.IMAGES}), + ), + cls.O1_MINI.value: LLMModel( + id=cls.O1_MINI.value, + provider=LLMProvider.OPENAI, + name="O1 Mini", + constraints=ModelConstraints( + max_tokens=65536, + max_temperature=2.0, + supports_max_tokens=False, + supports_temperature=False, + ).add_mime_categories({MimeCategory.IMAGES}), + ), + cls.O1.value: LLMModel( + id=cls.O1.value, + provider=LLMProvider.OPENAI, + name="O1", + constraints=ModelConstraints( + max_tokens=100000, + max_temperature=2.0, + supports_max_tokens=False, + supports_temperature=False, + ).add_mime_categories({MimeCategory.IMAGES}), + ), + cls.O1_PRO.value: LLMModel( + id=cls.O1_PRO.value, + provider=LLMProvider.OPENAI, + name="O1 Pro", + constraints=ModelConstraints( + max_tokens=100000, + max_temperature=2.0, + supports_max_tokens=False, + supports_temperature=False, + supports_JSON_output=True, + ).add_mime_categories({MimeCategory.IMAGES}), + ), + cls.O1_2024_12_17.value: LLMModel( + id=cls.O1_2024_12_17.value, + provider=LLMProvider.OPENAI, + name="O1 (2024-12-17)", + constraints=ModelConstraints( + max_tokens=100000, + max_temperature=2.0, + supports_max_tokens=False, + supports_temperature=False, + ).add_mime_categories({MimeCategory.IMAGES}), + ), + cls.O1_MINI_2024_09_12.value: LLMModel( + id=cls.O1_MINI_2024_09_12.value, + provider=LLMProvider.OPENAI, + name="O1 Mini (2024-09-12)", + constraints=ModelConstraints( + max_tokens=65536, + max_temperature=2.0, + 
supports_max_tokens=False, + supports_temperature=False, + ).add_mime_categories({MimeCategory.IMAGES}), + ), + cls.O1_PREVIEW_2024_09_12.value: LLMModel( + id=cls.O1_PREVIEW_2024_09_12.value, + provider=LLMProvider.OPENAI, + name="O1 Preview (2024-09-12)", + constraints=ModelConstraints( + max_tokens=32768, + max_temperature=2.0, + supports_max_tokens=False, + supports_temperature=False, + ).add_mime_categories({MimeCategory.IMAGES}), + ), + cls.CHATGPT_4O_LATEST.value: LLMModel( + id=cls.CHATGPT_4O_LATEST.value, + provider=LLMProvider.OPENAI, + name="ChatGPT-4 Optimized Latest", + constraints=ModelConstraints( + max_tokens=4096, max_temperature=2.0 + ).add_mime_categories({MimeCategory.IMAGES}), + ), + # Azure OpenAI Models + cls.AZURE_GPT_4.value: LLMModel( + id=cls.AZURE_GPT_4.value, + provider=LLMProvider.AZURE_OPENAI, + name="Azure GPT-4", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0), + ), + cls.AZURE_GPT_35_TURBO.value: LLMModel( + id=cls.AZURE_GPT_35_TURBO.value, + provider=LLMProvider.AZURE_OPENAI, + name="Azure GPT-3.5 Turbo", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0), + ), + # Anthropic Models + cls.CLAUDE_3_5_SONNET_LATEST.value: LLMModel( + id=cls.CLAUDE_3_5_SONNET_LATEST.value, + provider=LLMProvider.ANTHROPIC, + name="Claude 3.5 Sonnet Latest", + constraints=ModelConstraints( + max_tokens=8192, + max_temperature=1.0, + ).add_mime_categories({MimeCategory.IMAGES, MimeCategory.DOCUMENTS}), + ), + cls.CLAUDE_3_5_HAIKU_LATEST.value: LLMModel( + id=cls.CLAUDE_3_5_HAIKU_LATEST.value, + provider=LLMProvider.ANTHROPIC, + name="Claude 3.5 Haiku Latest", + constraints=ModelConstraints( + max_tokens=8192, + max_temperature=1.0, + ), + ), + cls.CLAUDE_3_OPUS_LATEST.value: LLMModel( + id=cls.CLAUDE_3_OPUS_LATEST.value, + provider=LLMProvider.ANTHROPIC, + name="Claude 3 Opus Latest", + constraints=ModelConstraints( + max_tokens=4096, + max_temperature=1.0, + ).add_mime_categories({MimeCategory.IMAGES}), + ), + 
cls.CLAUDE_3_7_SONNET_LATEST.value: LLMModel( + id=cls.CLAUDE_3_7_SONNET_LATEST.value, + provider=LLMProvider.ANTHROPIC, + name="Claude 3.7 Sonnet Latest", + constraints=ModelConstraints( + max_tokens=8192, + max_temperature=1.0, + supports_thinking=True, + thinking_budget_tokens=1024, + ).add_mime_categories({MimeCategory.IMAGES, MimeCategory.DOCUMENTS}), + ), + # Google Models + cls.GEMINI_1_5_PRO.value: LLMModel( + id=cls.GEMINI_1_5_PRO.value, + provider=LLMProvider.GEMINI, + name="Gemini 1.5 Pro", + constraints=ModelConstraints( + max_tokens=8192, + max_temperature=1.0, + ).add_mime_categories({MimeCategory.IMAGES, MimeCategory.AUDIO}), + ), + cls.GEMINI_1_5_FLASH.value: LLMModel( + id=cls.GEMINI_1_5_FLASH.value, + provider=LLMProvider.GEMINI, + name="Gemini 1.5 Flash", + constraints=ModelConstraints( + max_tokens=8192, max_temperature=1.0 + ).add_mime_categories( + { + MimeCategory.IMAGES, + MimeCategory.AUDIO, + MimeCategory.VIDEO, + MimeCategory.DOCUMENTS, + MimeCategory.TEXT, + } + ), + ), + cls.GEMINI_1_5_PRO_LATEST.value: LLMModel( + id=cls.GEMINI_1_5_PRO_LATEST.value, + provider=LLMProvider.GEMINI, + name="Gemini 1.5 Pro Latest", + constraints=ModelConstraints( + max_tokens=8192, max_temperature=1.0 + ).add_mime_categories( + { + MimeCategory.IMAGES, + MimeCategory.AUDIO, + MimeCategory.VIDEO, + MimeCategory.DOCUMENTS, + MimeCategory.TEXT, + } + ), + ), + cls.GEMINI_1_5_FLASH_LATEST.value: LLMModel( + id=cls.GEMINI_1_5_FLASH_LATEST.value, + provider=LLMProvider.GEMINI, + name="Gemini 1.5 Flash Latest", + constraints=ModelConstraints( + max_tokens=8192, max_temperature=1.0 + ).add_mime_categories( + { + MimeCategory.IMAGES, + MimeCategory.AUDIO, + MimeCategory.VIDEO, + MimeCategory.DOCUMENTS, + MimeCategory.TEXT, + } + ), + ), + cls.GEMINI_2_0_FLASH_EXP.value: LLMModel( + id=cls.GEMINI_2_0_FLASH_EXP.value, + provider=LLMProvider.GEMINI, + name="Gemini 2.0 Flash Exp", + constraints=ModelConstraints( + max_tokens=8192, max_temperature=1.0 + 
).add_mime_categories( + { + MimeCategory.IMAGES, + MimeCategory.AUDIO, + MimeCategory.VIDEO, + MimeCategory.DOCUMENTS, + MimeCategory.TEXT, + } + ), + ), + cls.GEMINI_2_0_FLASH.value: LLMModel( + id=cls.GEMINI_2_0_FLASH.value, + provider=LLMProvider.GEMINI, + name="Gemini 2.0 Flash", + constraints=ModelConstraints( + max_tokens=8192, + max_temperature=2.0, + ).add_mime_categories( + { + MimeCategory.IMAGES, + MimeCategory.AUDIO, + MimeCategory.VIDEO, + MimeCategory.DOCUMENTS, + MimeCategory.TEXT, + } + ), + ), + # Deepseek Models + cls.DEEPSEEK_CHAT.value: LLMModel( + id=cls.DEEPSEEK_CHAT.value, + provider=LLMProvider.DEEPSEEK, + name="Deepseek Chat", + constraints=ModelConstraints( + max_tokens=8192, + max_temperature=2.0, + supports_JSON_output=False, + ), + ), + cls.DEEPSEEK_REASONER.value: LLMModel( + id=cls.DEEPSEEK_REASONER.value, + provider=LLMProvider.DEEPSEEK, + name="Deepseek Reasoner", + constraints=ModelConstraints( + max_tokens=8192, + max_temperature=2.0, + supports_JSON_output=False, + supports_max_tokens=False, + ), + ), + # Ollama Models + cls.OLLAMA_PHI4.value: LLMModel( + id=cls.OLLAMA_PHI4.value, + provider=LLMProvider.OLLAMA, + name="Phi 4", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_LLAMA3_3_70B.value: LLMModel( + id=cls.OLLAMA_LLAMA3_3_70B.value, + provider=LLMProvider.OLLAMA, + name="Llama 3.3 (70B)", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0), + ), + cls.OLLAMA_LLAMA3_2_3B.value: LLMModel( + id=cls.OLLAMA_LLAMA3_2_3B.value, + provider=LLMProvider.OLLAMA, + name="Llama 3.2 (3B)", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_LLAMA3_2_1B.value: LLMModel( + id=cls.OLLAMA_LLAMA3_2_1B.value, + provider=LLMProvider.OLLAMA, + name="Llama 3.2 (1B)", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_LLAMA3_1_8B.value: 
LLMModel( + id=cls.OLLAMA_LLAMA3_1_8B.value, + provider=LLMProvider.OLLAMA, + name="Llama 3.1 (8B)", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_LLAMA3_1_70B.value: LLMModel( + id=cls.OLLAMA_LLAMA3_1_70B.value, + provider=LLMProvider.OLLAMA, + name="Llama 3.1 (70B)", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_LLAMA3_8B.value: LLMModel( + id=cls.OLLAMA_LLAMA3_8B.value, + provider=LLMProvider.OLLAMA, + name="Llama 3 (8B)", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_LLAMA3_70B.value: LLMModel( + id=cls.OLLAMA_LLAMA3_70B.value, + provider=LLMProvider.OLLAMA, + name="Llama 3 (70B)", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_GEMMA_3_1B.value: LLMModel( + id=cls.OLLAMA_GEMMA_3_1B.value, + provider=LLMProvider.OLLAMA, + name="Gemma 3 (1B)", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_GEMMA_3_4B.value: LLMModel( + id=cls.OLLAMA_GEMMA_3_4B.value, + provider=LLMProvider.OLLAMA, + name="Gemma 3 (4B)", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_GEMMA_3_12B.value: LLMModel( + id=cls.OLLAMA_GEMMA_3_12B.value, + provider=LLMProvider.OLLAMA, + name="Gemma 3 (12B)", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_GEMMA_3_27B.value: LLMModel( + id=cls.OLLAMA_GEMMA_3_27B.value, + provider=LLMProvider.OLLAMA, + name="Gemma 3 (27B)", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_GEMMA_2.value: LLMModel( + id=cls.OLLAMA_GEMMA_2.value, + provider=LLMProvider.OLLAMA, + name="Gemma 2", + 
constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_GEMMA_2_2B.value: LLMModel( + id=cls.OLLAMA_GEMMA_2_2B.value, + provider=LLMProvider.OLLAMA, + name="Gemma 2 (2B)", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_MISTRAL.value: LLMModel( + id=cls.OLLAMA_MISTRAL.value, + provider=LLMProvider.OLLAMA, + name="Mistral", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_CODELLAMA.value: LLMModel( + id=cls.OLLAMA_CODELLAMA.value, + provider=LLMProvider.OLLAMA, + name="CodeLlama", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_MIXTRAL.value: LLMModel( + id=cls.OLLAMA_MIXTRAL.value, + provider=LLMProvider.OLLAMA, + name="Mixtral 8x7B Instruct", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.OLLAMA_DEEPSEEK_R1.value: LLMModel( + id=cls.OLLAMA_DEEPSEEK_R1.value, + provider=LLMProvider.OLLAMA, + name="Deepseek R1", + constraints=ModelConstraints( + max_tokens=8192, + max_temperature=2.0, + supports_JSON_output=False, + supports_max_tokens=False, + supports_reasoning=True, + ), + ), + cls.OLLAMA_MISTRAL_SMALL.value: LLMModel( + id=cls.OLLAMA_MISTRAL_SMALL.value, + provider=LLMProvider.OLLAMA, + name="Mistral Small 24B", + constraints=ModelConstraints(max_tokens=4096, max_temperature=2.0, supports_JSON_output=False), + ), + cls.XAI_GROK_2.value: LLMModel( + id=cls.XAI_GROK_2.value, + provider=LLMProvider.XAI, + name="xAI Grok 2 Latest", + constraints=ModelConstraints( + max_tokens=131072, + max_temperature=1.0, + ), + ), + } + return model_registry.get(model_id) diff --git a/pyspur/backend/pyspur/nodes/llm/_providers.py b/pyspur/backend/pyspur/nodes/llm/_providers.py new file mode 100644 index 
def setup_azure_configuration(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Build the keyword arguments for an Azure OpenAI completion call.

    The caller's ``kwargs`` is left untouched; a shallow copy is returned in which

    * a leading ``azure/`` is removed from ``model``,
    * any ``response_format`` entry is dropped, and
    * ``api_key``, ``api_base``, ``api_version`` and ``deployment_id`` are filled
      from the ``AZURE_OPENAI_API_KEY``, ``AZURE_OPENAI_API_BASE``,
      ``AZURE_OPENAI_API_VERSION`` and ``AZURE_OPENAI_DEPLOYMENT_NAME``
      environment variables.

    Args:
        kwargs: Completion arguments; must contain a string ``"model"`` entry.

    Returns:
        A new dict with the Azure settings applied.

    Raises:
        KeyError: If ``kwargs`` has no ``"model"`` entry.
        ValueError: If any of the four Azure settings is unset or empty.
    """
    azure_settings = {
        "api_key": os.getenv("AZURE_OPENAI_API_KEY"),
        "api_base": os.getenv("AZURE_OPENAI_API_BASE"),
        "api_version": os.getenv("AZURE_OPENAI_API_VERSION"),
        "deployment_id": os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"),
    }

    azure_kwargs = kwargs.copy()
    azure_kwargs.pop("response_format", None)
    # removeprefix() only touches a leading "azure/"; replace() would also delete
    # any later occurrence inside the deployment/model name.
    azure_kwargs["model"] = kwargs["model"].removeprefix("azure/")
    azure_kwargs.update(azure_settings)

    missing_config = [key for key, value in azure_settings.items() if not value]
    if missing_config:
        raise ValueError(f"Missing Azure configuration for: {', '.join(missing_config)}")
    return azure_kwargs
litellm.drop_params = True

# Clean up the Azure API base URL: drop trailing slashes and a single trailing
# "/openai" path segment. str.removesuffix() removes exactly that suffix, whereas
# str.rstrip("/openai") strips any trailing run of the characters "/", "o", "p",
# "e", "n", "a", "i" and can therefore eat into the host name
# (e.g. "https://example.io/openai" -> "https://example.").
azure_api_base = os.getenv("AZURE_OPENAI_API_BASE", "").rstrip("/")
azure_api_base = azure_api_base.removesuffix("/openai").rstrip("/")
# The normalised value is written back even when the variable was unset (as "").
os.environ["AZURE_OPENAI_API_BASE"] = azure_api_base

# Set OpenAI base URL if provided
openai_base_url = os.getenv("OPENAI_API_BASE")
if openai_base_url:
    litellm.api_base = openai_base_url

# If Azure OpenAI is configured, use its key as litellm's default API key
if os.getenv("AZURE_OPENAI_API_KEY"):
    litellm.api_key = os.getenv("AZURE_OPENAI_API_KEY")
@async_retry(
    wait=wait_random_exponential(min=30, max=120),
    stop=stop_after_attempt(3),
    # tenacity calls ``retry`` with the RetryCallState of the finished attempt, not
    # with the exception itself, so the exception has to be read from ``outcome``.
    # Authentication, value and rate-limit errors are not retried.
    retry=lambda retry_state: (
        retry_state.outcome is not None
        and retry_state.outcome.failed
        and not isinstance(
            retry_state.outcome.exception(),
            (
                litellm.exceptions.AuthenticationError,
                ValueError,
                litellm.exceptions.RateLimitError,
            ),
        )
    ),
)
async def completion_with_backoff(**kwargs) -> Message:
    """Call the LLM completion endpoint with backoff.

    Routing is decided from ``kwargs["model"]``:

    * Azure OpenAI when the model starts with ``azure/``, or when
      ``AZURE_OPENAI_API_KEY`` is set and the model is not an ``ollama/`` model;
    * otherwise the arguments are passed to ``litellm.acompletion`` unchanged.

    Args:
        **kwargs: Arguments forwarded to ``litellm.acompletion``.

    Returns:
        The ``message`` of the first choice of the completion response. Every
        branch returns the message object (not just its text) so callers can read
        both ``content`` and ``tool_calls``.

    Raises:
        Exception: Whatever the completion call or the Azure configuration raises;
            the error is logged (with the API key masked) and re-raised.
    """
    try:
        model = kwargs.get("model", "")
        logging.info("=== LLM Request Configuration ===")
        logging.info(f"Requested Model: {model}")

        # Use Azure if either 'azure/' is prefixed or if an Azure API key
        # is provided and not using Ollama
        if model.startswith("azure/") or (
            os.getenv("AZURE_OPENAI_API_KEY") and not model.startswith("ollama/")
        ):
            azure_kwargs = setup_azure_configuration(kwargs)
            logging.info(f"Using Azure config for model: {azure_kwargs['model']}")
            try:
                response = await acompletion(**azure_kwargs, drop_params=True)
            except Exception as e:
                logging.error(f"Error calling Azure OpenAI: {e}")
                raise
            return response.choices[0].message

        if model.startswith("ollama/"):
            logging.info("=== Ollama Configuration ===")
        else:
            logging.info("=== Standard Configuration ===")
        response = await acompletion(**kwargs, drop_params=True)
        return response.choices[0].message

    except Exception as e:
        logging.error("=== LLM Request Error ===")
        # Log a copy of kwargs with the API key masked; nothing is added when no
        # key was passed.
        save_config = kwargs.copy()
        if "api_key" in save_config:
            save_config["api_key"] = "********"
        logging.error(f"Error occurred with configuration: {save_config}")
        logging.error(f"Error type: {type(e).__name__}")
        logging.error(f"Error message: {str(e)}")
        if hasattr(e, "response"):
            logging.error(f"Response status: {getattr(e.response, 'status_code', 'N/A')}")
            logging.error(f"Response body: {getattr(e.response, 'text', 'N/A')}")
        raise
def _wrap_output_as_json(output_text: str, raw_response: Any) -> str:
    """Serialise ``output_text`` as a ``{"output": ...}`` JSON object string.

    ``json.dumps`` does the escaping, so quotes, backslashes and control
    characters in the model output always yield valid JSON.
    """
    payload: Dict[str, Any] = {"output": output_text}
    # Carry provider-specific fields along when the raw response exposes them
    choices = getattr(raw_response, "choices", None)
    if choices and hasattr(choices[0].message, "provider_specific_fields"):
        payload["provider_specific_fields"] = choices[0].message.provider_specific_fields
    return json.dumps(payload, ensure_ascii=False)


def _attach_url_variables(
    messages: List[Dict[str, Any]], url_variables: Dict[str, str]
) -> List[Dict[str, Any]]:
    """Add every non-empty URL variable to each user message as an ``image_url`` part.

    External and ``data:`` URLs are passed through; anything else is treated as a
    file path and embedded as a base64 data URL (DOCX files are converted with
    ``convert_docx_to_xml`` first). User messages are rewritten in place.
    """
    transformed_messages = []
    for msg in messages:
        if msg["role"] == "user":
            content = [{"type": "text", "text": msg["content"]}]
            for url in url_variables.values():
                if not url:  # Only add if URL is provided
                    continue
                if is_external_url(url) or url.startswith("data:"):
                    data_url = url
                else:
                    # For file paths, encode the file with appropriate MIME type
                    try:
                        file_path = resolve_file_path(url)
                        logging.info(f"Reading file from: {file_path}")
                        if str(file_path).lower().endswith(".docx"):
                            xml_content = convert_docx_to_xml(str(file_path))
                            data_url = (
                                f"data:text/xml;base64,"
                                f"{base64.b64encode(xml_content.encode()).decode()}"
                            )
                        else:
                            data_url = encode_file_to_base64_data_url(str(file_path))
                    except Exception as e:
                        logging.error(f"Error reading file {url}: {str(e)}")
                        raise
                content.append({"type": "image_url", "image_url": {"url": data_url}})
            msg["content"] = content
        transformed_messages.append(msg)
    return transformed_messages


async def generate_text(
    messages: List[Dict[str, str]],
    model_name: str,
    temperature: float = 0.5,
    json_mode: bool = False,
    max_tokens: int = 16384,
    api_base: Optional[str] = None,
    url_variables: Optional[Dict[str, str]] = None,
    output_json_schema: Optional[str] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
    thinking: Optional[Dict[str, Any]] = None,
) -> Message:
    """Generate text using the specified LLM model.

    The returned message's ``content`` is a JSON document unless the model made
    tool calls: output that is not valid JSON (or comes from a model without JSON
    support) is wrapped as ``{"output": "<text>"}``.

    Note: ``messages`` may be modified in place (the JSON-schema instruction is
    appended to the system message, and user messages are rewritten when
    ``url_variables`` are attached).

    Args:
        messages: List of message dictionaries with 'role' and 'content'
        model_name: Name of the LLM model to use
        temperature: Sampling temperature; dropped for models that do not support it
        json_mode: Flag to indicate if JSON output is required
        max_tokens: Maximum number of tokens the model can generate; dropped for
            models that do not support it
        api_base: Base URL for the API (for Ollama, defaults to ``OLLAMA_BASE_URL``)
        url_variables: Dictionary of URL variables for file inputs; only used when
            ``json_mode`` is set and the model supports JSON output
        output_json_schema: JSON schema (as a JSON string) for the output format
        tools: List of function schemas for function calling
        tool_choice: By default the model will determine when and how many tools to use.
            You can force specific behavior with the tool_choice parameter.
            auto: (Default) Call zero, one, or multiple functions. tool_choice: "auto"
            required: Call one or more functions. tool_choice: "required"
            Forced Function: Call exactly one specific function.
            tool_choice: {"type": "function", "function": {"name": "get_weather"}}
        thinking: Thinking parameters; only forwarded to models that support them

    Returns:
        The model's message, with ``content`` normalised as described above.

    Raises:
        ValueError: If ``output_json_schema`` is blank, or a file in
            ``url_variables`` has a MIME type the model does not support.
    """
    kwargs = {
        "model": model_name,
        "max_tokens": max_tokens,
        "messages": messages,
        "temperature": temperature,
    }

    # Add function calling parameters if provided
    if tools:
        kwargs["tools"] = tools
        if tool_choice:
            kwargs["tool_choice"] = tool_choice

    # Get model info to check capabilities
    model_info = LLMModels.get_model_info(model_name)

    # Only add thinking parameters if explicitly requested and supported by the model
    if thinking and model_info and model_info.constraints.supports_thinking:
        kwargs["thinking"] = thinking

    if model_name == "deepseek/deepseek-reasoner":
        kwargs.pop("temperature", None)

    if model_info and not model_info.constraints.supports_temperature:
        kwargs.pop("temperature", None)
    if model_info and not model_info.constraints.supports_max_tokens:
        kwargs.pop("max_tokens", None)
    supports_json = model_info and model_info.constraints.supports_JSON_output

    # Only process JSON schema if the model supports it
    if supports_json:
        if output_json_schema is None:
            output_json_schema = {
                "type": "object",
                "properties": {"output": {"type": "string"}},
                "required": ["output"],
            }
        elif output_json_schema.strip() != "":
            output_json_schema = json.loads(output_json_schema)
            output_json_schema = sanitize_json_schema(output_json_schema)
        else:
            raise ValueError("Invalid output schema", output_json_schema)
        output_json_schema["additionalProperties"] = False

        # check if the model supports response format
        # (the lookup may yield None for unknown models, hence the fallback)
        supported_params = (
            litellm.get_supported_openai_params(
                model=model_name, custom_llm_provider=model_info.provider
            )
            or []
        )
        if "response_format" in supported_params:
            if litellm.supports_response_schema(
                model=model_name, custom_llm_provider=model_info.provider
            ) or model_name.startswith("anthropic"):
                if "name" not in output_json_schema and "schema" not in output_json_schema:
                    output_json_schema = {
                        "schema": output_json_schema,
                        "strict": True,
                        "name": "output",
                    }
                kwargs["response_format"] = {
                    "type": "json_schema",
                    "json_schema": output_json_schema,
                }
            else:
                # No schema support: ask for a JSON object and describe the schema
                # in the system prompt instead.
                kwargs["response_format"] = {"type": "json_object"}
                schema_for_prompt = json.dumps(output_json_schema)
                system_message = next(
                    (message for message in messages if message["role"] == "system"),
                    None,
                )
                if system_message is not None:
                    system_message["content"] += (
                        "\nYou must respond with valid JSON only."
                        + " No other text before or after the JSON Object."
                        + "The JSON Object must adhere to this schema: "
                        + schema_for_prompt
                    )

    if model_name.startswith("ollama"):
        # Ollama models are called through the Ollama client directly, always with
        # format="json", whether or not JSON mode was requested.
        if api_base is None:
            api_base = os.getenv("OLLAMA_BASE_URL")
        options = OllamaOptions(temperature=temperature, max_tokens=max_tokens)
        raw_response = await ollama_with_backoff(
            model=model_name,
            options=options,
            messages=messages,
            format="json",
            api_base=api_base,
        )
        response = raw_response
        # NOTE(review): raw_response is already a string, so json.dumps() stores it
        # as a JSON string literal; kept as-is - confirm this is what consumers expect.
        message_response = Message(
            content=json.dumps(raw_response),
            tool_calls=[],
        )
    else:
        # Handle inputs with URL variables
        if json_mode and supports_json and url_variables:
            # check if the mime type is supported
            mime_type = get_mime_type_for_url(url_variables["image"])
            if not model_info.constraints.is_mime_type_supported(mime_type):
                supported_types = [
                    mime.value for mime in model_info.constraints.supported_mime_types
                ]
                raise ValueError(
                    f"""Unsupported file type: "{mime_type.value}" for model {model_name}."""
                    f""" Supported types: {supported_types}"""
                )
            kwargs["messages"] = _attach_url_variables(messages, url_variables)
        message_response: Message = await completion_with_backoff(**kwargs)
        response = message_response.content
        # Always bound, so the wrapping code below can inspect it on every path
        raw_response = response

    # For models that don't support JSON output, wrap the response in a JSON structure
    if not supports_json:
        output_text = response if response is not None else ""
        if model_info and model_info.constraints.supports_reasoning:
            separator = model_info.constraints.reasoning_separator
            output_text = re.sub(separator, "", output_text, flags=re.DOTALL)
        message_response.content = _wrap_output_as_json(output_text, raw_response)
        return message_response

    # If the model made tool calls, return the raw response
    if message_response.tool_calls and len(message_response.tool_calls) > 0:
        return message_response

    # Ensure response is valid JSON for models that support it
    if response is None:
        response = ""
    try:
        json.loads(response)
        return message_response
    except json.JSONDecodeError:
        logging.error(f"Response is not valid JSON: {response}")

    # Try to fix common json issues: extract the JSON object if there is extra text
    if not response.startswith("{"):
        json_match = re.search(r"\{.*\}", response, re.DOTALL)
        if json_match:
            response = json_match.group(0)
            try:
                json.loads(response)
            except json.JSONDecodeError:
                pass
            else:
                message_response.content = response
                return message_response

    # If all attempts to parse JSON fail, wrap the response in a JSON structure
    message_response.content = _wrap_output_as_json(response, raw_response)
    return message_response
def convert_docx_to_xml(file_path: str) -> str:
    """Render a DOCX file's properties and text into a single string.

    The file is opened with ``docx2python``. Every non-empty document property
    is written as one ``<key>value`` line, then every non-empty item yielded by
    iterating the document's ``text`` is written on its own line.

    NOTE(review): the string literals as they stand contain no enclosing or
    closing element tags - confirm against the upstream file that nothing was
    lost, since the result is meant to be XML.

    Args:
        file_path: Path to the DOCX file
    Returns:
        The assembled string
    Raises:
        Exception: Any error raised while reading the document is logged and
            re-raised unchanged.

    """
    try:
        with docx2python(file_path) as document:
            pieces = ["\n\n"]

            # Metadata section
            pieces.append("\n")
            pieces.extend(
                f"<{key}>{value}\n"
                for key, value in document.properties.items()
                if value  # Only add non-empty properties
            )
            pieces.append("\n")

            # Document content section
            pieces.append("\n")
            pieces.extend(
                f"{paragraph}\n"
                for paragraph in document.text
                if paragraph  # Skip empty paragraphs
            )
            pieces.append("\n")
            pieces.append("")

            return "".join(pieces)
    except Exception as e:
        logging.error(f"Error converting DOCX to XML: {str(e)}")
        raise
class AgentNodeConfig(SingleLLMCallNodeConfig):
    """Configuration for the AgentNode.

    Extends SingleLLMCallNodeConfig with support for tools.
    """

    # Default response schema: an object with a string ``output`` and a
    # ``tool_calls`` array of strings, both required. AgentNode.run overwrites
    # ``tool_calls`` in the parsed response before validating it, so the
    # model-provided value for that key is not what ends up in the output.
    output_json_schema: str = Field(
        default=(
            '{"type": "object", "properties": {"output": {"type": "string"},'
            ' "tool_calls": {"type": "array", "items": {"type": "string"} } },'
            ' "required": ["output", "tool_calls"] }'
        ),
        title="Output JSON schema",
        description="The JSON schema for the output of the node",
    )

    # Optional workflow whose ``nodes`` AgentNode.setup uses as the tool list
    # when no tools were handed to the AgentNode constructor. Defaults to None.
    subworkflow: Optional[WorkflowDefinitionSchema] = Field(
        None, description="Subworkflow containing tool nodes"
    )
    # Upper bound for the loop in AgentNode.run. The counter there advances once
    # per LLM response, and a single response may carry several tool calls, so
    # this caps LLM round-trips rather than individual tool invocations.
    max_iterations: int = Field(
        10, description="Maximum number of tool calls the agent can make in a single run"
    )
def setup(self) -> None:
    """Resolve the agent's tools and build the lookup tables used at run time.

    The tool list comes from ``self.subworkflow`` (set by ``__init__`` when
    tools are passed to the constructor) if present, otherwise from
    ``self.config.subworkflow``. Either source may be absent, in which case
    the agent starts with no tools (they can still be added via
    ``add_tools``).

    Populates:
        tools_dict: lower-cased tool title -> ``WorkflowNodeSchema``.
        tools_instances: lower-cased tool title -> node instance created by
            ``NodeFactory``.
        tools_schemas: the ``function_schema`` of each tool, in tool order;
            this is the list passed to the LLM.
    """
    super().setup()

    # ``self.subworkflow`` is only assigned in ``__init__`` when tools were
    # supplied, and ``config.subworkflow`` defaults to None, so neither may be
    # dereferenced unconditionally.
    # NOTE(review): if a base class calls setup() from its __init__, the
    # constructor-provided tools are not visible yet at that point — confirm.
    instance_subworkflow = getattr(self, "subworkflow", None)
    config_subworkflow = self.config.subworkflow
    tools: List[WorkflowNodeSchema]
    if instance_subworkflow is not None:
        tools = instance_subworkflow.nodes or []
    elif config_subworkflow is not None:
        tools = config_subworkflow.nodes or []
    else:
        tools = []

    # Create a dictionary of tool nodes for easy access
    self.tools_dict: Dict[str, WorkflowNodeSchema] = {}
    for tool in tools:
        self.tools_dict[tool.title.lower()] = tool

    # Create instances of the tools
    self.tools_instances: Dict[str, BaseNode] = {}
    for tool in tools:
        tool_node_instance = NodeFactory.create_node(
            node_name=tool.title,
            node_type_name=tool.node_type,
            config=tool.config,
        )
        self.tools_instances[tool.title.lower()] = tool_node_instance

    # Create list of tool schemas to pass to the LLM
    self.tools_schemas: List[Dict[str, Any]] = []
    for tool in tools:
        tool_node_instance = self.tools_instances[tool.title.lower()]
        self.tools_schemas.append(tool_node_instance.function_schema)
def add_tools(self, tools: List[WorkflowNodeSchema] | List[BaseNode]) -> None:
    """Add tools to the agent node after construction.

    Every tool's ``function_schema`` is appended to ``self.tools_schemas`` so
    the LLM can request it. Tools given as ``WorkflowNodeSchema`` are also
    registered in ``self.tools_dict`` and ``self.tools_instances`` under
    their lower-cased title — the same key ``setup`` uses — so that
    ``_call_tool`` can resolve them when the LLM calls them. An existing
    entry with the same title is replaced.

    Args:
        tools: Tool definitions, either as workflow node schemas or as
            already-instantiated nodes.
    """
    for tool in tools:
        if isinstance(tool, BaseNode):
            # NOTE(review): a ready-made node has no WorkflowNodeSchema, so it
            # cannot be put into ``tools_dict``; ``_call_tool`` looks tools up
            # there, so such a tool is advertised but not resolvable — confirm
            # the intended behaviour.
            tool_schema = tool.function_schema
        else:
            tool_node_instance = NodeFactory.create_node(
                node_name=tool.title,
                node_type_name=tool.node_type,
                config=tool.config,
            )
            key = tool.title.lower()
            self.tools_dict[key] = tool
            self.tools_instances[key] = tool_node_instance
            tool_schema = tool_node_instance.function_schema
        self.tools_schemas.append(tool_schema)
async def run(self, input: BaseModel) -> BaseModel:
    """Run the agent loop: call the LLM, execute requested tools, repeat.

    Steps:
        1. Render the system and user messages from the input.
        2. Optionally extract a message history from the input.
        3. Call the LLM with the tool schemas. While a response requests
           tools, execute them, append the results to the conversation and
           call the LLM again, for at most ``config.max_iterations``
           LLM responses.
        4. Parse the final assistant message as JSON (with a repair attempt),
           attach the tool-related messages under ``tool_calls`` and validate
           the result against ``self.output_model``.

    Args:
        input: Node input; dumped to a dict for template rendering.

    Returns:
        The validated output model instance.

    Raises:
        AssertionError: If no tool schemas are registered.
        Exception: For errors whose text mentions "litellm", an ``Exception``
            carrying a JSON payload of type ``model_provider_error``; for an
            unparseable final message, one of type ``parsing_error``. Other
            errors propagate unchanged.
    """
    # Get the raw input dictionary
    raw_input_dict = input.model_dump()

    # Render the system message with the input data
    system_message = self._render_template(self.config.system_message, raw_input_dict)
    try:
        # If user_message is empty, dump the entire raw dictionary
        if not self.config.user_message.strip():
            user_message = json.dumps(raw_input_dict, indent=2)
        else:
            user_message = Template(self.config.user_message).render(**raw_input_dict)
    except Exception:
        print(f"[ERROR] Failed to render user_message {self.name}")
        print(f"[ERROR] user_message: {self.config.user_message} with input: {raw_input_dict}")
        raise

    # Extract message history from input if enabled
    history: Optional[List[Dict[str, str]]] = None
    if self.config.enable_message_history and self.config.message_history_variable:
        try:
            # Try to get history from the specified variable
            history_var = self.config.message_history_variable
            if "." in history_var:
                # Handle nested fields (e.g., "input_node.message_history")
                history = get_nested_field(history_var, input)
            else:
                # Direct field access
                history = raw_input_dict.get(history_var)

            assert isinstance(history, list) or history is None, (
                f"Expected message history to be a list or None, got {type(history)}"
            )
        except Exception as e:
            # History is best-effort: fall back to no history on any failure.
            print(f"[ERROR] Failed to extract message history: {e}")
            history = None

    messages: List[Dict[str, Any]] = create_messages(
        system_message=system_message,
        user_message=user_message,
        few_shot_examples=self.config.few_shot_examples,
        history=history,
    )

    model_name = LLMModels(self.config.llm_info.model).value

    url_vars: Optional[Dict[str, str]] = None
    # Process URL variables if any are configured (no model check is done here)
    if self.config.url_variables:
        url_vars = {}
        if "file" in self.config.url_variables:
            # Resolve the input variable reference (e.g. "input_node.video_url")
            # to its value using the helper function
            file_value = get_nested_field(self.config.url_variables["file"], input)
            # Always passed under the "image" key regardless of file type
            url_vars["image"] = file_value

    # Prepare thinking parameters if enabled and supported by the model
    thinking_params = None
    if self.config.enable_thinking:
        model_info = LLMModels.get_model_info(model_name)
        if model_info and model_info.constraints.supports_thinking:
            thinking_params = {
                "type": "enabled",
                # First non-zero of: node config, model default, 1024.
                "budget_tokens": self.config.thinking_budget_tokens
                or model_info.constraints.thinking_budget_tokens
                or 1024,
            }
    assert len(self.tools_schemas) > 0, "No tools found in the agent node"

    try:
        num_iterations = 0
        model_response = ""
        # Loop until either the maximum number of iterations is reached
        # or the model responds with a final assistant message.
        # NOTE(review): if the limit is hit first, ``model_response`` stays ""
        # and the failure surfaces below as a parsing error.
        while num_iterations < self.config.max_iterations:
            message_response = await generate_text(
                messages=messages,
                model_name=model_name,
                temperature=self.config.llm_info.temperature,
                max_tokens=self.config.llm_info.max_tokens,
                json_mode=True,
                url_variables=url_vars,
                output_json_schema=self.config.output_json_schema,
                thinking=thinking_params,
                tools=self.tools_schemas,
            )
            print(f"[DEBUG] Iteration {num_iterations + 1} response: {message_response}")
            # add the response to the messages
            messages.append(
                {
                    "role": message_response.role,
                    "content": str(message_response.content),
                    "tool_calls": message_response.tool_calls,
                }
            )
            num_iterations += 1

            if message_response.tool_calls:
                # The model asked for tools: run them, feed the results back
                # and call the LLM for the next turn.
                tool_responses = await self.execute_parallel_tool_calls(
                    message_response.tool_calls
                )
                messages.extend(cast(List[Dict[str, str]], tool_responses))
                pprint.pprint(messages)
                continue

            # No tool calls requested. ``tool_calls`` may be None or an empty
            # list here; both mean "no tools", so an assistant message with
            # content is the final answer.
            if message_response.role == "assistant" and message_response.content is not None:
                model_response = str(message_response.content)
                break
            # Neither a tool request nor a usable assistant message: ask again.
    except Exception as e:
        error_str = str(e)

        # Handle all LiteLLM errors
        if "litellm" in error_str.lower():
            error_message = "An error occurred with the LLM service"
            error_type = "unknown"

            # Extract provider from model name
            provider = model_name.split("/")[0] if "/" in model_name else "unknown"

            # Handle specific known error cases
            if "VertexAIError" in error_str and "The model is overloaded" in error_str:
                error_type = "overloaded"
                error_message = "The model is currently overloaded. Please try again later."
            elif "rate limit" in error_str.lower():
                error_type = "rate_limit"
                error_message = "Rate limit exceeded. Please try again in a few minutes."
            elif "context length" in error_str.lower() or "maximum token" in error_str.lower():
                error_type = "context_length"
                error_message = (
                    "Input is too long for the model's context window."
                    " Please reduce the input length."
                )
            elif (
                "invalid api key" in error_str.lower() or "authentication" in error_str.lower()
            ):
                error_type = "auth"
                error_message = (
                    "Authentication error with the LLM service. Please check your API key."
                )
            elif "bad gateway" in error_str.lower() or "503" in error_str:
                error_type = "service_unavailable"
                error_message = (
                    "The LLM service is temporarily unavailable. Please try again later."
                )

            raise Exception(
                json.dumps(
                    {
                        "type": "model_provider_error",
                        "provider": provider,
                        "error_type": error_type,
                        "message": error_message,
                        "original_error": error_str,
                    }
                )
            ) from e
        raise

    try:
        assistant_message_dict = json.loads(model_response)
    except Exception:
        # Second chance: try to repair malformed JSON before giving up.
        try:
            repaired_str = repair_json(model_response)
            assistant_message_dict = json.loads(repaired_str)
        except Exception as inner_e:
            error_str = str(inner_e)
            error_message = (
                "An error occurred while parsing and repairing the assistant message"
            )
            error_type = "json_parse_error"
            raise Exception(
                json.dumps(
                    {
                        "type": "parsing_error",
                        "error_type": error_type,
                        "message": error_message,
                        "original_error": error_str,
                        "assistant_message_str": model_response,
                    }
                )
            ) from inner_e

    # Report the tool-related turns: tool results, and assistant turns that
    # actually requested tools. Every assistant turn appended above has a
    # "tool_calls" key (possibly None), so the value is checked, not the key.
    assistant_message_dict["tool_calls"] = [
        str(m) for m in messages if m["role"] == "tool" or m.get("tool_calls")
    ]

    # Validate and return
    assistant_message = self.output_model.model_validate(assistant_message_dict)
    return assistant_message
{"result": {"type": "number"} }, "required": ["result"] }', + "llm_info": { + "model": "gpt-4o", + "temperature": 0.7, + "max_tokens": 1000, + }, + "system_message": "You are a calculator.", + "user_message": "{{ expression }}", + }, + ) + + # Create agent node + agent_node = AgentNode( + name="MathHelper", + config=AgentNodeConfig( + llm_info=ModelInfo(model=LLMModels.GPT_4O, temperature=0.7, max_tokens=1000), + system_message=( + "You are a helpful assistant that can use tools to solve math problems." + ), + user_message="I need help with this math problem: {{ problem }}", + max_iterations=5, + url_variables=None, + enable_thinking=False, + thinking_budget_tokens=None, + enable_message_history=True, + message_history_variable="message_history", + output_json_schema=json.dumps( + { + "type": "object", + "properties": { + "answer": {"type": "string"}, + "explanation": {"type": "string"}, + }, + "required": ["answer", "explanation"], + } + ), + subworkflow=None, + ), + ) + agent_node.add_tools([tool_schema]) + + # Create input with message history + test_input = AgentNodeInput.model_validate( + { + "problem": "What is 25 × 13?", + "message_history": [ + {"role": "user", "content": "Can you help me with some math problems?"}, + { + "role": "assistant", + "content": ( + "Of course! I'd be happy to help you solve math problems." + " What would you like to calculate?" 
+ ), + }, + ], + } + ) + + # Run the agent + print("[DEBUG] Testing agent_node...") + output = await agent_node(test_input) + print("[DEBUG] Agent output:", output) + + asyncio.run(test_agent_node()) diff --git a/pyspur/backend/pyspur/nodes/llm/generative/best_of_n.py b/pyspur/backend/pyspur/nodes/llm/generative/best_of_n.py new file mode 100644 index 0000000000000000000000000000000000000000..e954ba77aa5bd38ad1162432e90afe040c232a97 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/llm/generative/best_of_n.py @@ -0,0 +1,220 @@ +import json +from typing import Dict, List + +from pydantic import Field + +from ....nodes.base import BaseNodeInput, BaseNodeOutput +from ....schemas.workflow_schemas import ( + WorkflowDefinitionSchema, + WorkflowLinkSchema, + WorkflowNodeSchema, +) +from ....utils.pydantic_utils import json_schema_to_simple_schema +from ...subworkflow.base_subworkflow_node import ( + BaseSubworkflowNode, + BaseSubworkflowNodeConfig, +) +from .._utils import LLMModels, ModelInfo +from ..single_llm_call import SingleLLMCallNodeConfig + + +class BestOfNNodeConfig(SingleLLMCallNodeConfig, BaseSubworkflowNodeConfig): + samples: int = Field(default=3, ge=1, le=10, description="Number of samples to generate") + rating_prompt: str = Field( + default=( + "Rate the following response on a scale from 0 to 10, where 0 is poor " + "and 10 is excellent. Consider factors such as relevance, coherence, " + "and helpfulness. Respond with only a number." 
def setup_subworkflow(self) -> None:
    """Assemble the best-of-N graph and hand it to the base class.

    For ``config.samples`` samples the graph is:
    input -> ``generation_node_i`` -> ``rating_node_i``, with every generation
    and rating node also feeding ``pick_one_node`` (a ``PythonFuncNode`` that
    returns the generation whose rating is highest), which feeds
    ``output_node``.
    """
    cfg = self.config
    graph_nodes: List[WorkflowNodeSchema] = []
    graph_links: List[WorkflowLinkSchema] = []

    def connect(source: str, target: str) -> None:
        # Append a directed edge source -> target.
        graph_links.append(WorkflowLinkSchema(source_id=source, target_id=target))

    # Entry point of the subworkflow
    entry_id = "best_of_n_input_node"
    graph_nodes.append(
        WorkflowNodeSchema(
            id=entry_id,
            node_type="InputNode",
            config={"enforce_schema": False},
        )
    )

    generator_ids: List[str] = []
    rater_ids: List[str] = []
    for sample_index in range(cfg.samples):
        # One generator per sample, fed by the input node
        generator_id = f"generation_node_{sample_index}"
        graph_nodes.append(
            WorkflowNodeSchema(
                id=generator_id,
                node_type="SingleLLMCallNode",
                config={
                    "llm_info": self.config.llm_info.model_dump(),
                    "system_message": self.config.system_message,
                    "user_message": self.config.user_message,
                    "output_schema": self.config.output_schema,
                    "output_json_schema": self.config.output_json_schema,
                },
            )
        )
        generator_ids.append(generator_id)
        connect(entry_id, generator_id)

        # One rater per generator, fed by that generator
        rater_id = f"rating_node_{sample_index}"
        graph_nodes.append(
            WorkflowNodeSchema(
                id=rater_id,
                node_type="SingleLLMCallNode",
                config={
                    "llm_info": self.config.llm_info.model_dump(),
                    "system_message": self.config.rating_prompt,
                    "user_message": "",
                    "output_schema": {"rating": "number"},
                    "output_json_schema": '{"type": "object", "properties": {"rating": {"type": "number"} }, "required": ["rating"]}',
                },
            )
        )
        rater_ids.append(rater_id)
        connect(generator_id, rater_id)

    # The selector and the output node share one simple schema: derived from
    # the JSON schema when one is configured, else the configured simple schema.
    selector_id = "pick_one_node"
    output_schema = (
        json_schema_to_simple_schema(json.loads(self.config.output_json_schema))
        if self.config.output_json_schema
        else self.config.output_schema
    )
    # Code run by the selector: picks the key with the highest rating and
    # returns the output of the matching generation node.
    selector_code = (
        """gen_and_ratings = input_model.model_dump()\n"""
        """print(gen_and_ratings)\n"""
        """ratings = {k:v['rating'] for k,v in gen_and_ratings.items() if 'rating_node' in k}\n"""
        """highest_rating_key = max(ratings, key=ratings.get)\n"""
        """print(highest_rating_key)\n"""
        """corresponding_gen_key = highest_rating_key.replace('rating_node', 'generation_node')\n"""
        """return gen_and_ratings[corresponding_gen_key]\n"""
    )
    graph_nodes.append(
        WorkflowNodeSchema(
            id=selector_id,
            node_type="PythonFuncNode",
            config={
                "output_schema": output_schema,
                "output_json_schema": self.config.output_json_schema,
                "code": selector_code,
            },
        )
    )

    # Fan-in: all generators first, then all raters, into the selector
    for upstream_id in generator_ids:
        connect(upstream_id, selector_id)
    for upstream_id in rater_ids:
        connect(upstream_id, selector_id)

    # Output node exposes each selector field under the same name
    sink_id = "output_node"
    graph_nodes.append(
        WorkflowNodeSchema(
            id=sink_id,
            node_type="OutputNode",
            config={
                "output_map": {f"{k}": f"pick_one_node.{k}" for k in output_schema.keys()},
                "output_schema": output_schema,
                "output_json_schema": self.config.output_json_schema,
            },
        )
    )
    connect(selector_id, sink_id)

    self.subworkflow = WorkflowDefinitionSchema(
        nodes=graph_nodes,
        links=graph_links,
    )
    super().setup_subworkflow()
class BranchSolveMergeNodeConfig(BaseSubworkflowNodeConfig):
    # Model settings; BranchSolveMergeNode passes ``llm_info.model_dump()`` to
    # the branch, solve and merge LLM nodes it builds. Declared exactly once.
    llm_info: ModelInfo = Field(
        default_factory=lambda: ModelInfo(
            model=LLMModels.GPT_4O, max_tokens=16384, temperature=0.7
        ),
        description="The default LLM model to use",
    )
    # System prompt of the branch node, which produces the ``subtasks`` list.
    branch_system_message: str = Field(
        default=(
            "Please decompose the following task into logical subtasks "
            "that make solving the overall task easier."
        ),
        description="The prompt for the branch LLM",
    )
    # System prompt shared by every per-subtask solve node.
    solve_system_message: str = Field(
        default="Please provide a detailed solution for the following subtask:",
        description="The prompt for the solve LLM",
    )
    # System prompt of the merge node that combines the solve outputs.
    merge_system_message: str = Field(
        default="Please combine the following solutions into a coherent and comprehensive final answer.",
        description="The prompt for the merge LLM",
    )
    # Simple ``{field: type}`` schemas; ``output_schema`` is used for the merge
    # node and the final output node of the subworkflow.
    input_schema: Dict[str, str] = Field(default={"task": "string"})
    output_schema: Dict[str, str] = Field(default={"response": "string"})
def setup_branch_subworkflow(self) -> None:
    """
    Build the step-1 subworkflow: input -> branch -> output.

    The branch node is an LLM call that returns ``subtasks``; the output node
    exposes that list. The ids of the three nodes are stored on the instance
    (``input_node_id``, ``branch_node_id``, ``output_node_id``) before the
    subworkflow is installed via ``setup_subworkflow``.
    """
    # Input node
    entry_id = "branch_solve_merge_input_node"
    self.input_node_id = entry_id
    entry_node = WorkflowNodeSchema(
        id=entry_id,
        node_type="InputNode",
        config={
            "output_schema": {"task": "string"},
            "enforce_schema": False,
        },
    )

    # Branch node: decomposes the task into subtasks
    splitter_id = "branch_node"
    self.branch_node_id = splitter_id
    splitter_node = WorkflowNodeSchema(
        id=splitter_id,
        node_type="SingleLLMCallNode",
        config={
            "llm_info": self.config.llm_info.model_dump(),
            "system_message": self.config.branch_system_message,
            "user_message": "",
            "output_schema": {"subtasks": "List[str]"},  # Expecting list of subtasks
        },
    )
    entry_to_splitter = WorkflowLinkSchema(
        source_id=entry_id,
        target_id=splitter_id,
    )

    # Output node: forwards the branch node's subtasks
    sink_id = "output_node"
    self.output_node_id = sink_id
    sink_node = WorkflowNodeSchema(
        id=sink_id,
        node_type="OutputNode",
        config={
            "output_schema": {"subtasks": "List[str]"},
            "output_map": {"subtasks": f"{splitter_id}.subtasks"},
        },
    )
    splitter_to_sink = WorkflowLinkSchema(
        source_id=splitter_id,
        target_id=sink_id,
    )

    self.subworkflow = WorkflowDefinitionSchema(
        nodes=[entry_node, splitter_node, sink_node],
        links=[entry_to_splitter, splitter_to_sink],
    )
    self.setup_subworkflow()
+ """ + nodes: List[WorkflowNodeSchema] = [] + links: List[WorkflowLinkSchema] = [] + + # Input node (same as before) + input_node_id = self.input_node_id + input_node = WorkflowNodeSchema( + id=input_node_id, + node_type="InputNode", + config={"echo_mode": True}, + ) + nodes.append(input_node) + + # Branch node (same as before) + branch_node_id = self.branch_node_id + branch_node = WorkflowNodeSchema( + id=branch_node_id, + node_type="SingleLLMCallNode", + config={ + "llm_info": self.config.llm_info.model_dump(), + "system_message": self.config.branch_system_message, + "user_message": "", + "output_schema": {"subtasks": "List[str]"}, + }, + ) + nodes.append(branch_node) + + # Link input node to branch node + links.append( + WorkflowLinkSchema( + source_id=input_node_id, + target_id=branch_node_id, + ) + ) + + # For each subtask, create a solve node + solve_node_ids: List[str] = [] + for idx, _subtask in enumerate(subtasks): + solve_node_id = f"solve_node_{idx}" + solve_node_ids.append(solve_node_id) + solve_node = WorkflowNodeSchema( + id=solve_node_id, + node_type="SingleLLMCallNode", + config={ + "llm_info": self.config.llm_info.model_dump(), + "system_message": self.config.solve_system_message, + "user_message": f"{{{{branch_node.subtasks[{idx}]}}}}", + "output_schema": {f"solution_{idx}": "string"}, + }, + ) + nodes.append(solve_node) + + # Link branch node to solve node + links.append( + WorkflowLinkSchema( + source_id=branch_node_id, + target_id=solve_node_id, + ) + ) + + # Merge node: Combine solutions + merge_node_id = "merge_node" + merge_node = WorkflowNodeSchema( + id=merge_node_id, + node_type="SingleLLMCallNode", + config={ + "llm_info": self.config.llm_info.model_dump(), + "system_message": self.config.merge_system_message, + "user_message": "\n".join( + [f"{{{{solve_node_{i}.solution_{i}}}}}" for i in range(len(subtasks))] + ), + "output_schema": self.config.output_schema, + }, + ) + nodes.append(merge_node) + + # Link solve nodes to merge node + 
async def run_subworkflow(self, input: BaseModel) -> Dict[str, Any]:
    """
    Execute ``self.subworkflow`` and return the OutputNode's result as a dict.

    Outputs already stored in ``self.subworkflow_output`` are handed to the
    executor as precomputed outputs, and the outputs of this execution are
    stored back there (replacing ``None`` or merged into the existing dict).
    """
    assert self.subworkflow is not None

    def first_node_of_type(type_name: str):
        # First node of the given type in the current subworkflow.
        return next(
            (candidate for candidate in self.subworkflow.nodes if candidate.node_type == type_name)
        )

    # The mapped input is keyed by the id of the subworkflow's input node
    mapped = self._map_input(input)
    entry = first_node_of_type("InputNode")
    executor_input = {entry.id: mapped}

    # Earlier results, if any, let the executor reuse them
    already_computed = self.subworkflow_output or {}

    executor = WorkflowExecutor(workflow=self.subworkflow, context=self.context)
    results = await executor.run(executor_input, precomputed_outputs=already_computed)

    # Remember the results for later executions
    if self.subworkflow_output is not None:
        self.subworkflow_output.update(results)
    else:
        self.subworkflow_output = results

    sink = first_node_of_type("OutputNode")
    return results[sink.id].model_dump()
class RetrieverNodeOutput(BaseNodeOutput):
    """Output from the retriever node"""

    # This model's JSON schema (``model_json_schema()``) is the default value
    # of ``RetrieverNodeConfig.output_json_schema``.
    # NOTE(review): the docstring above presumably flows into that schema as
    # its description, so it is left unchanged — confirm before rewording.

    # One entry per retrieved chunk, as built by ``RetrieverNode.run``.
    results: List[RetrievalResultSchema] = Field(..., description="List of retrieved results")
    # ``RetrieverNode.run`` sets this to ``len(results)``.
    total_results: int = Field(..., description="Total number of results found")
class RetrieverNode(BaseNode):
    """Node for retrieving relevant documents from a vector index.

    The node looks up the configured vector index in the database, renders the
    query template with the node input, queries the index and returns the
    matches as a ``RetrieverNodeOutput``.
    """

    name = "retriever_node"
    display_name = "Retriever"
    config_model = RetrieverNodeConfig
    input_model = RetrieverNodeInput
    output_model = RetrieverNodeOutput

    async def validate_index(self, db: Session) -> None:
        """Validate that the vector index exists and is ready.

        Args:
            db: Open database session used to look up the index row.

        Raises:
            ValueError: If no row matches ``config.vector_index_id`` or the row's
                status is not ``"ready"``.
        """
        index = (
            db.query(VectorIndexModel)
            .filter(VectorIndexModel.id == self.config.vector_index_id)
            .first()
        )
        if not index:
            raise ValueError(f"Vector index {self.config.vector_index_id} not found")
        if index.status != "ready":
            raise ValueError(
                f"Vector index {self.config.vector_index_id} is not ready (status: {index.status})"
            )

    async def run(self, input: BaseModel) -> BaseModel:
        """Query the configured vector index with the rendered query template.

        Args:
            input: Node input; its dumped fields are the template variables.

        Returns:
            A ``RetrieverNodeOutput`` holding the formatted results and their count.

        Raises:
            ValueError: On any failure. The original exception is attached as
                ``__cause__`` so the real traceback is not lost.
        """
        # Get database session
        db = next(get_db())

        try:
            # Validate index exists and is ready
            await self.validate_index(db)

            # Get vector index configuration from database
            vector_index_model = (
                db.query(VectorIndexModel)
                .filter(VectorIndexModel.id == self.config.vector_index_id)
                .first()
            )
            if not vector_index_model:
                raise ValueError(f"Vector index {self.config.vector_index_id} not found")

            logger.info(
                f"[DEBUG] Vector index configuration: {vector_index_model.embedding_config}"
            )

            # Get embedding model from vector index configuration
            embedding_model = vector_index_model.embedding_config.get("model")
            if not embedding_model:
                raise ValueError("No embedding model specified in vector index configuration")

            logger.info(f"[DEBUG] Using embedding model: {embedding_model}")

            # Initialize vector index and set its configuration
            vector_index = VectorIndex(self.config.vector_index_id)
            embedding_model_info = EmbeddingModels.get_model_info(embedding_model)
            # Explicit check instead of `assert`: asserts vanish under `python -O`
            # and produce an empty message when they do fire.
            if embedding_model_info is None:
                raise ValueError(f"Unknown embedding model: {embedding_model}")
            vector_index.update_config(
                {
                    "embedding_config": {
                        "model": embedding_model,
                        "dimensions": embedding_model_info.dimensions,
                    },
                    "vector_db": vector_index_model.embedding_config.get("vector_db", "pinecone"),
                }
            )

            # Render query template with input variables
            raw_input_dict = input.model_dump()
            query = Template(self.config.query_template).render(**raw_input_dict)

            # Create retrieval request
            results = await vector_index.retrieve(
                query=query,
                top_k=self.config.top_k,
            )

            # Format results
            formatted_results: List[RetrievalResultSchema] = []
            for result in results:
                chunk = result["chunk"]
                metadata = result["metadata"]
                formatted_results.append(
                    RetrievalResultSchema(
                        text=chunk.text,
                        score=result["score"],
                        metadata=ChunkMetadataSchema(
                            document_id=metadata.get("document_id", ""),
                            chunk_id=metadata.get("chunk_id", ""),
                            document_title=metadata.get("document_title"),
                            page_number=metadata.get("page_number"),
                            chunk_number=metadata.get("chunk_number"),
                        ),
                    )
                )

            return RetrieverNodeOutput(
                results=formatted_results, total_results=len(formatted_results)
            )
        except Exception as e:
            # Keep the public exception type (ValueError) but chain the cause.
            raise ValueError(f"Error retrieving from vector index: {str(e)}") from e
        finally:
            db.close()
def repair_json(broken_json_str: str) -> str:
    """Best-effort repair of almost-JSON text, returning a JSON object string.

    The steps are applied in order: strip markdown code fences, keep only the
    outermost ``{...}`` span, turn single quotes that lie outside double-quoted
    strings into double quotes, drop trailing commas, insert commas between
    adjacent elements, quote bare keys, tighten colons, unwrap an outer pair of
    quotes, and collapse whitespace.

    Note that the clean-up passes after the quote conversion run over the whole
    text, string contents included.

    Args:
        broken_json_str: Raw text that should contain a JSON object.

    Returns:
        The repaired text, or ``"{}"`` when the input is empty/blank or no
        ``{...}`` span can be found. The result is not guaranteed to parse.
    """
    import re
    from re import Match

    # Handle empty or blank input
    if not broken_json_str or not broken_json_str.strip():
        return "{}"

    repaired = broken_json_str

    # Remove common LLM artifacts like XML/markdown tags that might be mixed in with JSON
    # NOTE(review): the pattern here is empty, so this call changes nothing; the
    # intended tag pattern is not visible in this file - confirm and restore it.
    repaired = re.sub(r"", "", repaired)

    # Remove markdown code block markers if present
    repaired = re.sub(r"^```(json)?|```$", "", repaired, flags=re.MULTILINE)

    # Try to extract just the JSON part if it's mixed with other text
    json_match = re.search(r"(\{[\s\S]*\})", repaired)
    if json_match:
        repaired = json_match.group(1)

    # Convert single quotes to double quotes, but not within already
    # double-quoted strings. This is done in ONE left-to-right pass: the regex
    # either consumes a complete double-quoted string (returned untouched) or a
    # lone single quote (converted). The previous stash-and-restore scheme used
    # keys such as PLACEHOLDER1 / PLACEHOLDER10 that share a prefix, so restoring
    # with str.replace corrupted any input holding more than ten strings.
    double_quoted = r'"[^"\\]*(?:\\.[^"\\]*)*"'

    def keep_or_convert(match: Match[str]) -> str:
        token = match.group(0)
        return '"' if token == "'" else token

    repaired = re.sub(double_quoted + "|'", keep_or_convert, repaired)

    # Remove trailing commas before closing brackets/braces
    repaired = re.sub(r",\s*([}\]])", r"\1", repaired)

    # Add missing commas between elements
    repaired = re.sub(r"([}\"])\s*([{\[])", r"\1,\2", repaired)

    # Quote bare (unquoted) object keys
    repaired = re.sub(r"([{,]\s*)(\w+)(\s*:)", r'\1"\2"\3', repaired)

    # Remove any extra whitespace around colons
    repaired = re.sub(r"\s*:\s*", ":", repaired)

    # If the string is wrapped in extra quotes, remove them
    if repaired.startswith('"') and repaired.endswith('"'):
        repaired = repaired[1:-1]

    # Extract the substring from the first { to the last }
    start = repaired.find("{")
    end = repaired.rfind("}")
    if start != -1 and end != -1:
        repaired = repaired[start : end + 1]
    else:
        # If no valid JSON object found, return empty object
        return "{}"

    # Final cleanup of whitespace
    repaired = re.sub(r"\s+", " ", repaired)

    return repaired
to input schema variables for Gemini models" + ), + ) + enable_thinking: bool = Field( + False, + description="Whether to enable thinking mode for supported models", + ) + thinking_budget_tokens: Optional[int] = Field( + None, + description="Budget tokens for thinking mode when enabled", + ) + enable_message_history: bool = Field( + False, + description="Whether to include message history from input in the LLM request", + ) + message_history_variable: Optional[str] = Field( + None, + description="Input variable containing message history (e.g., 'message_history')", + ) + + +class SingleLLMCallNodeInput(BaseNodeInput): + pass + + class Config: + extra = "allow" + + +class SingleLLMCallNodeOutput(BaseNodeOutput): + pass + + +class SingleLLMCallNode(BaseNode): + """Node for making a single LLM call with structured input/output. + + Features: + - Supports variable substitution in system and user messages + - Handles JSON schema validation for outputs + - Supports message history for conversational contexts + - Compatible with various LLM providers through configuration + """ + + name = "single_llm_call_node" + display_name = "Single LLM Call" + config_model = SingleLLMCallNodeConfig + input_model = SingleLLMCallNodeInput + output_model = SingleLLMCallNodeOutput + + def setup(self) -> None: + super().setup() + if self.config.output_json_schema: + self.output_model = json_schema_to_model( + json.loads(self.config.output_json_schema), + self.name, + SingleLLMCallNodeOutput, + ) # type: ignore + + async def run(self, input: BaseModel) -> BaseModel: + # Grab the entire dictionary from the input + raw_input_dict = input.model_dump() + + # Render system_message + system_message = Template(self.config.system_message).render(raw_input_dict) + + try: + # If user_message is empty, dump the entire raw dictionary + if not self.config.user_message.strip(): + user_message = json.dumps(raw_input_dict, indent=2) + else: + user_message = 
Template(self.config.user_message).render(**raw_input_dict) + except Exception as e: + print(f"[ERROR] Failed to render user_message {self.name}") + print(f"[ERROR] user_message: {self.config.user_message} with input: {raw_input_dict}") + raise e + + # Extract message history from input if enabled + history: Optional[List[Dict[str, str]]] = None + if self.config.enable_message_history and self.config.message_history_variable: + try: + # Try to get history from the specified variable + history_var = self.config.message_history_variable + if "." in history_var: + # Handle nested fields (e.g., "input_node.message_history") + history = get_nested_field(history_var, input) + else: + # Direct field access + history = raw_input_dict.get(history_var) + + assert isinstance(history, list) or history is None, ( + f"Expected message history to be a list or None, got {type(history)}" + ) + except Exception as e: + print(f"[ERROR] Failed to extract message history: {e}") + history = None + + messages = create_messages( + system_message=system_message, + user_message=user_message, + few_shot_examples=self.config.few_shot_examples, + history=history, + ) + + model_name = LLMModels(self.config.llm_info.model).value + + url_vars: Optional[Dict[str, str]] = None + # Process URL variables if they exist and we're using a Gemini model + if self.config.url_variables: + url_vars = {} + if "file" in self.config.url_variables: + # Split the input variable reference (e.g. 
"input_node.video_url") + # Get the nested field value using the helper function + file_value = get_nested_field(self.config.url_variables["file"], input) + # Always use image_url format regardless of file type + url_vars["image"] = file_value + + # Prepare thinking parameters if enabled + thinking_params = None + if self.config.enable_thinking: + model_info = LLMModels.get_model_info(model_name) + if model_info and model_info.constraints.supports_thinking: + thinking_params = { + "type": "enabled", + "budget_tokens": self.config.thinking_budget_tokens + or model_info.constraints.thinking_budget_tokens + or 1024, + } + + try: + message_response = await generate_text( + messages=messages, + model_name=model_name, + temperature=self.config.llm_info.temperature, + max_tokens=self.config.llm_info.max_tokens, + json_mode=True, + url_variables=url_vars, + output_json_schema=self.config.output_json_schema, + thinking=thinking_params, + ) + + # Extract content from Message object + assistant_message_content = message_response.content + if assistant_message_content is None: + raise ValueError("Assistant message content is None") + + try: + assistant_message_dict = json.loads(assistant_message_content) + except Exception: + try: + repaired_str = repair_json(assistant_message_content) + assistant_message_dict = json.loads(repaired_str) + except Exception as inner_e: + error_str = str(inner_e) + error_message = ( + "An error occurred while parsing and repairing the assistant message" + ) + error_type = "json_parse_error" + raise Exception( + json.dumps( + { + "type": "parsing_error", + "error_type": error_type, + "message": error_message, + "original_error": error_str, + "assistant_message_str": assistant_message_content, + } + ) + ) from inner_e + except Exception as e: + error_str = str(e) + + # Handle all LiteLLM errors + if "litellm" in error_str.lower(): + error_message = "An error occurred with the LLM service" + error_type = "unknown" + + # Extract provider from model 
name + provider = model_name.split("/")[0] if "/" in model_name else "unknown" + + # Handle specific known error cases + if "VertexAIError" in error_str and "The model is overloaded" in error_str: + error_type = "overloaded" + error_message = "The model is currently overloaded. Please try again later." + elif "rate limit" in error_str.lower(): + error_type = "rate_limit" + error_message = "Rate limit exceeded. Please try again in a few minutes." + elif "context length" in error_str.lower() or "maximum token" in error_str.lower(): + error_type = "context_length" + error_message = ( + "Input is too long for the model's context window." + " Please reduce the input length." + ) + elif ( + "invalid api key" in error_str.lower() or "authentication" in error_str.lower() + ): + error_type = "auth" + error_message = ( + "Authentication error with the LLM service. Please check your API key." + ) + elif "bad gateway" in error_str.lower() or "503" in error_str: + error_type = "service_unavailable" + error_message = ( + "The LLM service is temporarily unavailable. Please try again later." 
+ ) + + raise Exception( + json.dumps( + { + "type": "model_provider_error", + "provider": provider, + "error_type": error_type, + "message": error_message, + "original_error": error_str, + } + ) + ) from e + raise e + + # Validate and return + try: + assistant_message = self.output_model.model_validate(assistant_message_dict) + return assistant_message + except Exception as e: + # For better debugging, include the raw response + raw_response = assistant_message_content + + # Also include what we attempted to validate + validation_input = json.dumps(assistant_message_dict, default=str) + + error_message = ( + f"The LLM did not return valid JSON that matches the expected schema.\n\n" + f"Raw LLM response:\n{raw_response}\n\n" + f"Attempted to validate:\n{validation_input}\n\n" + f"Validation error: {str(e)}" + ) + + raise Exception( + json.dumps( + { + "type": "invalid_json_format", + "message": error_message, + "original_response": raw_response, + "validation_input": validation_input, + "validation_error": str(e), + } + ) + ) from e + + +if __name__ == "__main__": + import asyncio + + from pydantic import create_model + + async def test_llm_nodes(): + # Example 1: Simple test case with a basic user message + simple_llm_node = SingleLLMCallNode( + name="WeatherBot", + config=SingleLLMCallNodeConfig( + llm_info=ModelInfo(model=LLMModels.GPT_4O, temperature=0.4, max_tokens=100), + system_message="You are a helpful assistant.", + user_message="Hello, my name is {{ name }}. 
I want to ask: {{ question }}", + url_variables=None, + enable_thinking=False, + thinking_budget_tokens=None, + enable_message_history=False, + message_history_variable=None, + output_json_schema=json.dumps( + { + "type": "object", + "properties": { + "answer": {"type": "string"}, + "name_of_user": {"type": "string"}, + }, + "required": ["answer", "name_of_user"], + } + ), + ), + ) + + simple_input = create_model( + "SimpleInput", + name=(str, ...), + question=(str, ...), + __base__=BaseNodeInput, + ).model_validate( + { + "name": "Alice", + "question": "What is the weather like in New York in January?", + } + ) + + print("[DEBUG] Testing simple_llm_node now...") + simple_output = await simple_llm_node(simple_input) + print("[DEBUG] Test Output from single_llm_call:", simple_output) + + # Example 2: Using message history + chat_llm_node = SingleLLMCallNode( + name="ChatBot", + config=SingleLLMCallNodeConfig( + llm_info=ModelInfo(model=LLMModels.GPT_4O, temperature=0.7, max_tokens=100), + system_message=( + "You are a helpful and friendly assistant. Maintain conversation context." + ), + user_message="{{ user_message }}", + url_variables=None, + enable_thinking=False, + thinking_budget_tokens=None, + enable_message_history=True, + message_history_variable="message_history", + output_json_schema=json.dumps( + { + "type": "object", + "properties": {"assistant_message": {"type": "string"}}, + "required": ["assistant_message"], + } + ), + ), + ) + + # Create input with message history + chat_input = create_model( + "ChatInput", + user_message=(str, ...), + message_history=(list, ...), + __base__=BaseNodeInput, + ).model_validate( + { + "user_message": "What's the capital of France?", + "message_history": [ + {"role": "user", "content": "Hello, can you help me with geography questions?"}, + { + "role": "assistant", + "content": ( + "Of course! I'd be happy to help with geography questions." + " What would you like to know?" 
class CoalesceNode(BaseNode):
    """
    A Coalesce node that takes multiple incoming branches and outputs
    the first non-null branch's value as its result.
    """

    name = "coalesce_node"
    display_name = "Coalesce"
    input_model = CoalesceNodeInput
    config_model = CoalesceNodeConfig

    def _emit(self, payload: Dict[str, object]) -> BaseModel:
        """Build an output model shaped like ``payload``, install it, and instantiate it.

        One required field is declared per key of ``payload``, typed with the
        runtime type of the corresponding value.
        """
        self.output_model = create_model(
            f"{self.name}",
            **{k: (type(v), ...) for k, v in payload.items()},
            __base__=CoalesceNodeOutput,
            __config__=None,
            __module__=self.__module__,
            __doc__=f"Output model for {self.name} node",
            __validators__=None,
            __cls_kwargs__=None,
        )
        return self.output_model(**payload)

    async def run(self, input: BaseModel) -> BaseModel:
        """
        Return the first non-None field of ``input``.

        Fields named in ``config.preferences`` are tried first, in that order;
        after that the remaining dumped fields are tried in their own order.
        The chosen value is expected to be a mapping (it is expanded into the
        output model's fields). When every field is None the method returns None.
        """
        self.output_model = CoalesceNodeOutput

        data = input.model_dump()

        # Preferred branches win, in the configured order.
        for key in self.config.preferences:
            candidate = data.get(key)
            if candidate is not None:
                return self._emit(candidate)

        # Otherwise fall back to the first non-None branch of any name.
        for value in data.values():
            if value is not None:
                return self._emit(value)

        # If all values are None, there is nothing to forward
        return None  # type: ignore
+ ), + ) + + +class HumanInterventionNodeInput(BaseNodeInput): + """Input model for the human intervention node.""" + + class Config: + extra = "allow" + + +class HumanInterventionNodeOutput(BaseNodeOutput): + class Config: + extra = "allow" # Allow extra fields from the input to pass through + + +@NodeRegistry.register( + category="Logic", + display_name="HumanIntervention", + # logo="/images/human_intervention.png", + position="after:RouterNode", +) +class HumanInterventionNode(BaseNode): + """A node that pauses workflow execution and waits for human input. + + When this node is executed, it pauses the workflow until human intervention + occurs. All input data is passed through to the output after approval. + """ + + name = "human_intervention_node" + config_model = HumanInterventionNodeConfig + input_model = HumanInterventionNodeInput + output_model = HumanInterventionNodeOutput + + def setup(self) -> None: + """Human intervention node setup.""" + super().setup() + + @property + def node_id(self) -> str: + # Return the node id from the instance dict if available, otherwise fallback to self.name + return str(self.__dict__.get("id", self.name)) + + async def run(self, input: BaseModel) -> BaseNodeOutput: + """Process input and pause the workflow execution. + + preserving the nested structure so that downstream nodes can access + outputs as {{HumanInterventionNode_1.input_node.input_1}}. 
+ """ + # Pass through the input data to preserve the nested structure + output_dict = input.model_dump() + output = HumanInterventionNodeOutput(**output_dict) + raise PauseError(str(self.node_id), self.config.message, output) diff --git a/pyspur/backend/pyspur/nodes/logic/merge.py b/pyspur/backend/pyspur/nodes/logic/merge.py new file mode 100644 index 0000000000000000000000000000000000000000..0ec4987b20b417d9fe27e36184ff47e47790028d --- /dev/null +++ b/pyspur/backend/pyspur/nodes/logic/merge.py @@ -0,0 +1,47 @@ +import logging +from typing import Optional + +from pydantic import BaseModel, create_model + +from ..base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput + +logger = logging.getLogger(__name__) + + +class MergeNodeConfig(BaseNodeConfig): + has_fixed_output: bool = False + + +class MergeNodeInput(BaseNodeInput): + pass + + +class MergeNodeOutput(BaseNodeOutput): + class Config: + arbitrary_types_allowed = True + + pass + + +class MergeNode(BaseNode): + """Takes all its inputs and merge them into one output.""" + + name = "merge_node" + display_name = "Merge" + input_model = MergeNodeInput + config_model = MergeNodeConfig + + async def run(self, input: BaseModel) -> BaseModel: + data = input.model_dump() + + self.output_model = create_model( + f"{self.name}", + **{k: (Optional[type(v)], ...) 
for k, v in data.items()}, + __base__=MergeNodeOutput, + __module__=self.__module__, + __doc__=f"Output model for {self.name} node", + __cls_kwargs__=None, + __config__=None, + __validators__=None, + ) + return self.output_model(**data) diff --git a/pyspur/backend/pyspur/nodes/logic/router.py b/pyspur/backend/pyspur/nodes/logic/router.py new file mode 100644 index 0000000000000000000000000000000000000000..15c0d2ed0d71888bf786d9e8f23a77f51463fff6 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/logic/router.py @@ -0,0 +1,220 @@ +from typing import Any, Dict, Optional + +from pydantic import BaseModel, create_model + +from ...schemas.router_schemas import ( + ComparisonOperator, + RouteConditionGroupSchema, + RouteConditionRuleSchema, +) +from ..base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput + + +class RouterNodeConfig(BaseNodeConfig): + """Configuration for the router node.""" + + route_map: Dict[str, RouteConditionGroupSchema] = { + "route1": RouteConditionGroupSchema( + conditions=[ + RouteConditionRuleSchema( + variable="", operator=ComparisonOperator.CONTAINS, value="" + ) + ] + ) + } + + +class RouterNodeInput(BaseNodeInput): + """Input model for the router node.""" + + pass + + +class RouterNodeOutput(BaseNodeOutput): + """Output model for the router node.""" + + class Config: + arbitrary_types_allowed = True + + pass + + +class RouterNode(BaseNode): + """ + A routing node that directs input data to different routes + based on the evaluation of conditions. The first route acts as the default + if no other conditions match. 
+ """ + + name = "router_node" + display_name = "Router" + input_model = RouterNodeInput + config_model = RouterNodeConfig + + def _evaluate_single_condition( + self, input: BaseModel, condition: RouteConditionRuleSchema + ) -> bool: + """Evaluate a single condition against a specific input variable""" + + def get_nested_value(data: Dict[str, Any], target_key: str) -> Any: + """Get value from nested dictionary using dot notation path.""" + keys = target_key.split(".") + current = data + + for key in keys: + if not isinstance(current, dict): + return None + if key not in current: + return None + current = current[key] + + return current + + try: + if not condition.variable: + return False + + # Retrieve the variable value, including support for nested paths + variable_value = get_nested_value(input.model_dump(), condition.variable) + + if variable_value is None: + if condition.operator != ComparisonOperator.IS_EMPTY: + return False + else: + return True + if condition.operator == ComparisonOperator.CONTAINS: + return str(condition.value) in str(variable_value) + elif condition.operator == ComparisonOperator.EQUALS: + return str(variable_value) == str(condition.value) + elif condition.operator == ComparisonOperator.NUMBER_EQUALS: + return float(variable_value) == float(condition.value) + elif condition.operator == ComparisonOperator.GREATER_THAN: + return float(variable_value) > float(condition.value) + elif condition.operator == ComparisonOperator.LESS_THAN: + return float(variable_value) < float(condition.value) + elif condition.operator == ComparisonOperator.STARTS_WITH: + return str(variable_value).startswith(str(condition.value)) + elif condition.operator == ComparisonOperator.NOT_STARTS_WITH: + return not str(variable_value).startswith(str(condition.value)) + elif condition.operator == ComparisonOperator.IS_EMPTY: + return not bool(variable_value) + elif condition.operator == ComparisonOperator.IS_NOT_EMPTY: + return bool(variable_value) + else: + return False 
+ except (ValueError, TypeError, AttributeError): + return False + + def _evaluate_route_conditions( + self, input: BaseModel, route: RouteConditionGroupSchema + ) -> bool: + """Evaluate all conditions in a route with AND/OR logic""" + if not route.conditions: + # If no conditions, consider it always matches + return True + + result = self._evaluate_single_condition(input, route.conditions[0]) + + for i in range(1, len(route.conditions)): + condition = route.conditions[i] + current_result = self._evaluate_single_condition(input, condition) + + if condition.logicalOperator == "OR": + result = result or current_result + else: # AND is default + result = result and current_result + + return result + + async def run(self, input: BaseModel) -> BaseModel: + """ + Evaluates conditions for each route in order. The first route that matches + gets the input data. If no routes match, the first route acts as a default. + """ + output_model = create_model( + f"{self.name}", + __config__=None, + __base__=RouterNodeOutput, + __doc__=f"Output model for {self.name} node", + __module__=self.__module__, + __validators__=None, + __cls_kwargs__=None, + **{ + field_name: (field_type, None) + for field_name, field_type in input.model_fields.items() + }, + ) + # Create fields for each route with Optional[input type] + route_fields = { + route_name: (Optional[output_model], None) + for route_name in self.config.route_map.keys() + } + new_output_model = create_model( + f"{self.name}CompositeOutput", + __base__=RouterNodeOutput, + __config__=None, + __doc__=f"Composite output model for {self.name} node", + __module__=self.__module__, + __validators__=None, + __cls_kwargs__=None, + **route_fields, + ) + self.output_model = new_output_model + + output: Dict[str, Optional[BaseModel]] = {} + + for route_name, route in self.config.route_map.items(): + if self._evaluate_route_conditions(input, route): + output[route_name] = output_model(**input.model_dump()) + + return self.output_model(**output) 
+ + +if __name__ == "__main__": + # Test the RouterNode + import asyncio + + from pydantic import BaseModel + + class TestInput(RouterNodeInput): + name: str + age: int + is_student: bool + grade: str + + config = RouterNodeConfig( + route_map={ + "route1": RouteConditionGroupSchema( + conditions=[ + RouteConditionRuleSchema( + variable="age", + operator=ComparisonOperator.GREATER_THAN, + value=18, + ), + RouteConditionRuleSchema( + variable="is_student", + operator=ComparisonOperator.EQUALS, + value=True, + ), + ] + ), + "route2": RouteConditionGroupSchema( + conditions=[ + RouteConditionRuleSchema( + variable="grade", + operator=ComparisonOperator.EQUALS, + value="A", + ), + ] + ), + } + ) + + node = RouterNode(config=config, name="router_node") + + input_data = TestInput(name="Alice", age=20, is_student=True, grade="B") + output = asyncio.run(node(input_data)) + import json + + print(json.dumps(output.model_json_schema(), indent=2)) + print(json.dumps(output.model_dump(), indent=2)) diff --git a/pyspur/backend/pyspur/nodes/loops/base_loop_subworkflow_node.py b/pyspur/backend/pyspur/nodes/loops/base_loop_subworkflow_node.py new file mode 100644 index 0000000000000000000000000000000000000000..e4cb7e4ac81725753838f1ed5ff2f2f8f555bf70 --- /dev/null +++ b/pyspur/backend/pyspur/nodes/loops/base_loop_subworkflow_node.py @@ -0,0 +1,111 @@ +from abc import abstractmethod +from typing import Any, Dict, List + +from pydantic import BaseModel, create_model + +from ...execution.workflow_executor import WorkflowExecutor +from ...schemas.workflow_schemas import WorkflowDefinitionSchema +from ..base import BaseNodeInput, BaseNodeOutput +from ..primitives.output import OutputNode +from ..subworkflow.base_subworkflow_node import ( + BaseSubworkflowNode, + BaseSubworkflowNodeConfig, +) + + +class BaseLoopSubworkflowNodeConfig(BaseSubworkflowNodeConfig): + subworkflow: WorkflowDefinitionSchema + + +class BaseLoopSubworkflowNodeInput(BaseNodeInput): + pass + + +class 
# Base class for loop nodes: runs the configured subworkflow repeatedly until the
# subclass-provided stopping condition holds, recording each node's output per iteration.
class BaseLoopSubworkflowNode(BaseSubworkflowNode):
    name = "loop_subworkflow_node"
    config_model = BaseLoopSubworkflowNodeConfig
    iteration: int
    loop_outputs: Dict[str, List[Dict[str, Any]]]

    def setup(self) -> None:
        """Run the parent setup, then reset the per-node history and iteration counter."""
        super().setup()
        self.loop_outputs = {}
        self.iteration = 0

    def _update_loop_outputs(self, iteration_output: Dict[str, Dict[str, Any]]) -> None:
        """Append one iteration's per-node outputs to ``self.loop_outputs``.

        The injected ``loop_history`` key is dropped before storing, so the
        history never contains itself.
        """
        for node_id, payload in iteration_output.items():
            if "loop_history" in payload:
                payload = {key: value for key, value in payload.items() if key != "loop_history"}
            self.loop_outputs.setdefault(node_id, []).append(payload)

    @abstractmethod
    async def stopping_condition(self, input: Dict[str, Any]) -> bool:
        """Return True when the loop should stop, given the current input state."""
        pass

    async def run_iteration(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the subworkflow once and return its output node's result.

        The accumulated ``loop_outputs`` are passed to the subworkflow under
        the ``loop_history`` key, and every node's dumped output is recorded
        afterwards.
        """
        self.subworkflow = self.config.subworkflow
        assert self.subworkflow is not None

        # Expose the history gathered so far to the subworkflow.
        iteration_input = {**input, "loop_history": self.loop_outputs}

        # A fresh executor is created for every iteration and kept on the instance.
        self._executor = WorkflowExecutor(workflow=self.config.subworkflow, context=self.context)
        raw_outputs = await self._executor.run(iteration_input)

        iteration_outputs = {
            node_id: node_output.model_dump() for node_id, node_output in raw_outputs.items()
        }
        self._update_loop_outputs(iteration_outputs)

        # The result of the iteration is whatever the first OutputNode produced.
        output_node = next(
            candidate
            for candidate in self.subworkflow.nodes
            if candidate.node_type == "OutputNode"
        )
        return iteration_outputs[output_node.id]

    async def run(self, input: BaseModel) -> BaseModel:
        """Iterate until ``stopping_condition`` holds and return the final state.

        Each iteration's output is merged into the running input dict, so later
        iterations (and the stopping condition) see the latest values.
        """
        current_input = self._map_input(input)

        while not await self.stopping_condition(current_input):
            current_input.update(await self.run_iteration(current_input))
            self.iteration += 1

        self.subworkflow_output = self.loop_outputs

        # Derive the loop's output model from the subworkflow's OutputNode instance.
        # NOTE(review): self._executor is assigned in run_iteration; if the stopping
        # condition already holds before the first iteration, this lookup relies on
        # the attribute having been set elsewhere — confirm against the base class.
        output_node = next(
            instance
            for instance in self._executor.node_instances.values()
            if issubclass(instance.__class__, OutputNode)
        )
        loop_fields = {
            field_name: (field, ...)
            for field_name, field in output_node.output_model.model_fields.items()
        }
        self.output_model = create_model(
            f"{self.name}",
            **loop_fields,
            __base__=BaseLoopSubworkflowNodeOutput,
            __config__=None,
            __module__=self.__module__,
            __cls_kwargs__={"arbitrary_types_allowed": True},
            __doc__=None,
            __validators__=None,
        )

        # The accumulated state, validated against the derived model, is the result.
        return self.output_model.model_validate(current_input)  # type: ignore
# Loop node that repeats its subworkflow a fixed, configured number of times.
class ForLoopNode(BaseLoopSubworkflowNode):
    name = "for_loop"
    config_model = ForLoopNodeConfig
    input_model = ForLoopNodeInput

    async def stopping_condition(self, input: Dict[str, Any]) -> bool:
        """Report whether the configured iteration budget has been used up.

        The ``input`` state is not consulted; only the iteration counter and
        ``config.num_iterations`` are compared.
        """
        completed = self.iteration
        budget = self.config.num_iterations
        return completed >= budget
# Catalog of built-in node types. SUPPORTED_NODE_TYPES maps a category label to its
# ordered entries; DEPRECATED_NODE_TYPES is a flat list of names that
# is_valid_node_type still accepts but get_all_node_types does not return.


def _node_entry(node_type_name: str, module: str, class_name: str = "") -> Dict[str, str]:
    """Build one catalog record; ``class_name`` defaults to ``node_type_name``."""
    return {
        "node_type_name": node_type_name,
        "module": module,
        "class_name": class_name or node_type_name,
    }


SUPPORTED_NODE_TYPES = {
    "Input/Output": [
        _node_entry("InputNode", ".nodes.primitives.input"),
        _node_entry("OutputNode", ".nodes.primitives.output"),
    ],
    "AI": [
        _node_entry("SingleLLMCallNode", ".nodes.llm.single_llm_call"),
        _node_entry("AgentNode", ".nodes.llm.agent"),
        _node_entry("RetrieverNode", ".nodes.llm.retriever"),
        _node_entry("BestOfNNode", ".nodes.llm.generative.best_of_n"),
    ],
    "Code Execution": [
        _node_entry("PythonFuncNode", ".nodes.python.python_func"),
    ],
    "Logic": [
        _node_entry("RouterNode", ".nodes.logic.router"),
        _node_entry("CoalesceNode", ".nodes.logic.coalesce"),
        _node_entry("MergeNode", ".nodes.logic.merge"),
        _node_entry("StaticValueNode", ".nodes.primitives.static_value"),
    ],
    "Experimental": [
        _node_entry("ForLoopNode", ".nodes.loops.for_loop_node"),
    ],
    "Integrations": [
        _node_entry("SlackNotifyNode", ".nodes.integrations.slack.slack_notify"),
        _node_entry("GoogleSheetsReadNode", ".nodes.integrations.google.google_sheets_read"),
        _node_entry("YouTubeTranscriptNode", ".nodes.integrations.youtube.youtube_transcript"),
        _node_entry(
            "GitHubListPullRequestsNode", ".nodes.integrations.github.github_list_pull_requests"
        ),
        _node_entry(
            "GitHubListRepositoriesNode", ".nodes.integrations.github.github_list_repositories"
        ),
        _node_entry("GitHubGetRepositoryNode", ".nodes.integrations.github.github_get_repository"),
        _node_entry(
            "GitHubSearchRepositoriesNode", ".nodes.integrations.github.github_search_repositories"
        ),
        _node_entry(
            "GitHubGetPullRequestNode", ".nodes.integrations.github.github_get_pull_request"
        ),
        _node_entry(
            "GitHubGetPullRequestChangesNode",
            ".nodes.integrations.github.github_get_pull_request_changes",
        ),
        _node_entry("GitHubCreateIssueNode", ".nodes.integrations.github.github_create_issue"),
        # Disabled entries, kept for reference:
        # _node_entry("FirecrawlCrawlNode", ".nodes.integrations.firecrawl.firecrawl_crawl"),
        # _node_entry("FirecrawlScrapeNode", ".nodes.integrations.firecrawl.firecrawl_scrape"),
        _node_entry("JinaReaderNode", ".nodes.integrations.jina.jina_reader"),
        # Reddit nodes
        _node_entry("RedditCreatePostNode", ".nodes.integrations.reddit.reddit_create_post"),
        _node_entry("RedditGetTopPostsNode", ".nodes.integrations.reddit.reddit_get_top_posts"),
        _node_entry("RedditGetUserInfoNode", ".nodes.integrations.reddit.reddit_get_user_info"),
        _node_entry(
            "RedditGetSubredditInfoNode", ".nodes.integrations.reddit.reddit_get_subreddit_info"
        ),
        _node_entry(
            "RedditGetSubredditStatsNode", ".nodes.integrations.reddit.reddit_get_subreddit_stats"
        ),
        _node_entry(
            "RedditGetTrendingSubredditsNode",
            ".nodes.integrations.reddit.reddit_get_trending_subreddits",
        ),
    ],
    "Search": [
        _node_entry("ExaSearchNode", ".nodes.search.exa.search"),
    ],
    "Tools": [
        _node_entry("SendEmailNode", ".nodes.email.send_email"),
    ],
}

DEPRECATED_NODE_TYPES = [
    _node_entry("MCTSNode", ".nodes.llm.mcts"),
    _node_entry("MixtureOfAgentsNode", ".nodes.llm.mixture_of_agents"),
    _node_entry("SelfConsistencyNode", ".nodes.llm.self_consistency"),
    _node_entry("TreeOfThoughtsNode", ".nodes.llm.tree_of_thoughts"),
    _node_entry("StringOutputLLMNode", ".nodes.llm.string_output_llm"),
    _node_entry("StructuredOutputNode", ".nodes.llm.structured_output"),
    # Legacy name that resolves to the current single-call implementation.
    _node_entry("AdvancedLLMNode", ".nodes.llm.single_llm_call", "SingleLLMCallNode"),
    _node_entry("SubworkflowNode", ".nodes.subworkflow.subworkflow_node"),
    _node_entry("BranchSolveMergeNode", ".nodes.llm.generative.branch_solve_merge"),
]


def get_all_node_types() -> Dict[str, List[NodeTypeSchema]]:
    """Return the supported catalog as validated schema objects, keyed by category.

    Category order and per-category entry order follow SUPPORTED_NODE_TYPES.
    """
    return {
        group_name: [NodeTypeSchema.model_validate(entry) for entry in entries]
        for group_name, entries in SUPPORTED_NODE_TYPES.items()
    }


def is_valid_node_type(node_type_name: str) -> bool:
    """Tell whether a name is supported, deprecated, or registered via decorator.

    The static catalogs are consulted first; the registry is only queried when
    neither of them knows the name.
    """
    if any(
        entry["node_type_name"] == node_type_name
        for entries in SUPPORTED_NODE_TYPES.values()
        for entry in entries
    ):
        return True

    if any(entry["node_type_name"] == node_type_name for entry in DEPRECATED_NODE_TYPES):
        return True

    # Fall back to nodes that registered themselves through the decorator.
    return any(
        info.node_type_name == node_type_name
        for infos in NodeRegistry.get_registered_nodes().values()
        for info in infos
    )
class InputNode(BaseNode):
    """
    Node for defining dataset schema and using the output as input for other nodes.
    """

    name = "input_node"
    display_name = "Input"
    config_model = InputNodeConfig
    input_model = InputNodeInput
    output_model = InputNodeOutput

    async def __call__(
        self,
        input: (
            Dict[str, str | int | bool | float | Dict[str, Any] | List[Any]]
            | Dict[str, BaseNodeOutput]
            | Dict[str, BaseNodeInput]
            | BaseNodeInput
        ),
    ) -> BaseNodeOutput:
        """Accept either raw values or upstream node outputs.

        A dict that contains no ``BaseNodeOutput`` value is treated as raw
        data: an output model is generated from the runtime type of each value
        and the dict is validated against it directly. Anything else is handed
        to the regular ``BaseNode`` call path.
        """
        is_raw_payload = isinstance(input, dict) and not any(
            isinstance(item, BaseNodeOutput) for item in input.values()
        )
        if not is_raw_payload:
            return await super().__call__(input)

        # One required field per key, typed after the value that was supplied.
        inferred_fields = {key: (type(item), ...) for key, item in input.items()}
        self.output_model = create_model(  # type: ignore
            self.name,
            __base__=BaseNodeOutput,
            **inferred_fields,  # type: ignore
        )
        return self.output_model.model_validate(input)  # type: ignore

    async def run(self, input: BaseModel) -> BaseModel:
        """Return the input as-is, or re-wrap it in a generated output model.

        With ``config.enforce_schema`` set the input object is returned
        untouched; otherwise a model mirroring the input's fields is created,
        stored on ``self.output_model`` and populated from the input's dump.
        """
        # NOTE(review): InputNodeConfig's docstring describes enforce_schema the other
        # way round ("otherwise the output will be the same as the input") — confirm
        # which of the two is intended.
        if self.config.enforce_schema:
            return input

        mirrored_fields = {key: (info, ...) for key, info in input.model_fields.items()}
        self.output_model = create_model(
            "InputNodeOutput",
            __base__=InputNodeOutput,
            __config__=None,
            __module__=self.__module__,
            __doc__=f"Output model for {self.name} node",
            __validators__=None,
            __cls_kwargs__=None,
            **mirrored_fields,
        )
        return self.output_model.model_validate(input.model_dump())  # type: ignore
+ """ + + name = "output_node" + display_name = "Output" + config_model = OutputNodeConfig + input_model = BaseNodeInput + output_model = BaseNodeOutput + + async def run(self, input: BaseModel) -> BaseModel: + """ + Maps the incoming input fields (possibly nested) to the node's output + fields according to self.config.output_map. If no output_map is set, + returns the entire input as output. + + Args: + input (BaseModel): The input model (from predecessor nodes). + + Returns: + BaseModel: The node's typed output model instance. + """ + if self.config.output_map: + # If user provided mappings, create a new model with the mapped fields + model_fields: Dict[str, Any] = {} + for output_key, input_key in self.config.output_map.items(): + model_fields[output_key] = ( + type(get_nested_field(field_name_with_dots=input_key, model=input)), + ..., + ) + self.output_model = create_model( + f"{self.name}", + **model_fields, + __base__=BaseNodeOutput, + __config__=None, + __module__=self.__module__, + ) + else: + # If user provided no mappings, just return everything + model_fields = {k: (type(v), ...) 
for k, v in input.model_dump().items()} + self.output_model = create_model( + f"{self.name}", + **model_fields, + __base__=BaseNodeOutput, + __config__=None, + __module__=self.__module__, + ) + + output_dict: Dict[str, Any] = {} + for output_key, input_key in self.config.output_map.items(): + output_dict[output_key] = get_nested_field(input_key, input) + + return self.output_model(**output_dict) diff --git a/pyspur/backend/pyspur/nodes/primitives/static_value.py b/pyspur/backend/pyspur/nodes/primitives/static_value.py new file mode 100644 index 0000000000000000000000000000000000000000..921b1842a5c879dc0195ad632d1ac146695a0afa --- /dev/null +++ b/pyspur/backend/pyspur/nodes/primitives/static_value.py @@ -0,0 +1,52 @@ +from typing import Any, Dict + +from pydantic import BaseModel + +from ..base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput + + +class StaticValueNodeConfig(BaseNodeConfig): + values: Dict[str, Any] + + +class StaticValueNodeInput(BaseNodeInput): + pass + + +class StaticValueNodeOutput(BaseNodeOutput): + pass + + +class StaticValueNode(BaseNode): + """Node type for producing constant values declared in the config.""" + + name = "constant_value_node" + display_name = "Static Value" + config_model = StaticValueNodeConfig + input_model = StaticValueNodeInput + output_model = StaticValueNodeOutput + + def setup(self) -> None: + """Create a dynamic output model based on the values in the config.""" + # Convert the values dict to an output schema format + output_schema = {key: type(value).__name__ for key, value in self.config.values.items()} + + # If there are no values, use a default empty output + if not output_schema: + return + + # Create a dynamic output model based on the schema + self.output_model = self.create_output_model_class(output_schema) + + async def run(self, input: BaseModel) -> BaseModel: + return self.output_model(**self.config.values) + + +if __name__ == "__main__": + import asyncio + + # Create a proper config with the 
# Configuration for PythonFuncNode. `code` is the body of a function that the node
# wraps as `def user_function(input_model): ...`, so the template must refer to the
# input by that parameter name (it previously said "input", which is not bound
# inside the generated function).
class PythonFuncNodeConfig(BaseNodeConfig):
    code: str = "\n".join(
        [
            "# Write your Python code here.",
            '# The input data is available as "input_model" pydantic model.',
            "# Return a dictionary of variables that you would like to see in the node output.",
        ]
    )
async def run(self, input: BaseModel) -> BaseModel:
    """Run the configured snippet as the body of ``user_function(input_model)``.

    The snippet is indented, wrapped in a function definition, compiled with
    ``exec`` and called with the node input; its return value is validated
    against the model built from ``config.output_schema``.

    A snippet made only of comments or blank lines (such as the default
    template) used to produce a function with an empty suite and failed with a
    SyntaxError. A trailing ``return None`` is now appended so it always
    compiles, and a snippet that returns nothing is treated as an empty output
    dict, leaving it to the output model to report missing fields.

    Args:
        input: The node input, exposed to the snippet as ``input_model``.

    Returns:
        The validated output model instance.
    """
    self.output_model = self.create_output_model_class(self.config.output_schema)

    # NOTE(review): exec() runs the configured code with full interpreter privileges;
    # this is only acceptable if the code comes from a trusted author.
    exec_globals: Dict[str, Any] = {}
    exec_locals: Dict[str, Any] = {}

    # Indent the user code so it becomes the function body.
    body_lines = ["    " + line for line in self.config.code.split("\n")]
    # Guarantees a non-empty suite; unreachable when the snippet returns on its own.
    body_lines.append("    return None")
    function_code = "def user_function(input_model):\n" + "\n".join(body_lines) + "\n"

    exec(function_code, exec_globals, exec_locals)

    output_data = exec_locals["user_function"](input)
    if output_data is None:
        output_data = {}
    return self.output_model.model_validate(output_data)
# Process-wide catalog of node classes. `_nodes` maps a category label to an ordered
# list of NodeInfo records; entries come from the `register` decorator and from tool
# functions discovered under PROJECT_ROOT/tools.
class NodeRegistry:
    _nodes: Dict[str, List[NodeInfo]] = {}
    # Track classes registered via decorator
    _decorator_registered_classes: Set[Type[BaseNode]] = set()

    @staticmethod
    def _resolve_insert_index(
        nodes_list: List[NodeInfo], position: Optional[Union[int, str]]
    ) -> int:
        """Translate a ``position`` specifier into a list index for ``nodes_list``.

        Integers are clamped to the list length; "after:Name" / "before:Name"
        resolve relative to the first entry with that node type name. A missing
        target, an unrecognized string, or ``None`` all mean "append".
        """
        end = len(nodes_list)
        if position is None:
            return end
        if isinstance(position, int):
            return min(position, end)

        for prefix, offset in (("after:", 1), ("before:", 0)):
            if position.startswith(prefix):
                target_node = position[len(prefix) :]
                for index, existing in enumerate(nodes_list):
                    if existing.node_type_name == target_node:
                        return index + offset
                return end
        return end

    @classmethod
    def register(
        cls,
        category: str = "Uncategorized",
        display_name: Optional[str] = None,
        logo: Optional[str] = None,
        subcategory: Optional[str] = None,
        position: Optional[Union[int, str]] = None,
    ):
        """Register a node class with metadata.

        Registering the same node type twice in a category is a no-op for the
        catalog, whether or not a ``position`` is given (positioned
        registrations previously skipped the duplicate check and inserted a
        second entry).

        Args:
            category: The category this node belongs to
            display_name: Optional display name for the node
            logo: Optional path to the node's logo
            subcategory: Optional subcategory for finer-grained organization
            position: Optional position specifier. Can be:
                - Integer for absolute position
                - "after:NodeName" for relative position after a node
                - "before:NodeName" for relative position before a node

        Returns:
            A decorator that registers the node class with the specified metadata

        """

        def decorator(node_class: Type[BaseNode]) -> Type[BaseNode]:
            # Set metadata on the class; an existing `category` attribute wins.
            if not hasattr(node_class, "category"):
                node_class.category = category
            if display_name:
                node_class.display_name = display_name
            if logo:
                node_class.logo = logo

            # Store subcategory as class attribute without type checking
            if subcategory:
                node_class.subcategory = subcategory

            nodes_list = cls._nodes.setdefault(category, [])

            # Module paths are stored relative to the package: drop a leading
            # 'pyspur.' and prefix the remainder with a dot.
            module_path = node_class.__module__
            if module_path.startswith("pyspur."):
                module_path = module_path.replace("pyspur.", "", 1)

            node_info = NodeInfo(
                node_type_name=node_class.__name__,
                module=f".{module_path}",
                class_name=node_class.__name__,
                subcategory=subcategory,
            )

            already_listed = any(
                existing.node_type_name == node_class.__name__ for existing in nodes_list
            )
            if not already_listed:
                nodes_list.insert(cls._resolve_insert_index(nodes_list, position), node_info)
                logger.debug(f"Registered node {node_class.__name__} in category {category}")
            cls._decorator_registered_classes.add(node_class)

            return node_class

        return decorator

    @classmethod
    def get_registered_nodes(
        cls,
    ) -> Dict[str, List[NodeInfo]]:
        """Get all registered nodes.

        Discovery is run on every call before the catalog is returned; the
        returned dict is the registry's own mapping, not a copy.
        """
        cls.discover_nodes()
        return cls._nodes

    @classmethod
    def _discover_in_directory(cls, base_path: Path, package_prefix: str) -> None:
        """Recursively discover nodes in a directory and its subdirectories.

        Only registers nodes that explicitly use the @NodeRegistry.register decorator.
        Files and directories whose name starts with an underscore are skipped,
        and a module that fails to import is logged and ignored.
        """
        for item in base_path.iterdir():
            if item.is_file() and item.suffix == ".py" and not item.name.startswith("_"):
                # Construct module name from package prefix and file name
                module_name = f"{package_prefix}.{item.stem}"

                try:
                    # Import module but don't register nodes - they'll self-register if decorated
                    importlib.import_module(module_name)
                except Exception as e:
                    logger.error(f"Failed to load module {module_name}: {e}")

            # Recursively process subdirectories
            elif item.is_dir() and not item.name.startswith("_"):
                subpackage = f"{package_prefix}.{item.name}"
                cls._discover_in_directory(item, subpackage)

    @classmethod
    def discover_nodes(cls, package_path: str = "pyspur.nodes") -> None:
        """Automatically discover and register nodes from the package.

        Only nodes with the @NodeRegistry.register decorator will be registered.
        Tool functions are discovered as part of the same pass. A failure to
        import the base package is logged rather than raised.

        Args:
            package_path: The base package path to search for nodes

        """
        try:
            package = importlib.import_module(package_path)
            if not hasattr(package, "__file__") or package.__file__ is None:
                raise ImportError(f"Cannot find package {package_path}")

            base_path = Path(package.__file__).resolve().parent
            logger.info(f"Discovering nodes in: {base_path}")

            # Start recursive discovery
            cls._discover_in_directory(base_path, package_path)

            # Also discover tool function nodes
            cls.discover_tool_functions()

            logger.info(
                "Node discovery complete."
                f" Found {len(cls._decorator_registered_classes)} decorated nodes."
            )

        except ImportError as e:
            logger.error(f"Failed to import base package {package_path}: {e}")

    @classmethod
    def discover_tool_functions(cls) -> None:
        """Discover and register tool functions from the tools directory.

        This method searches recursively through Python files in the PROJECT_ROOT/tools directory
        for functions decorated with @tool_function and registers their node classes.
        Only works with proper Python packages (directories with __init__.py).
        Missing configuration (no PROJECT_ROOT, no tools directory) is logged
        and discovery is skipped.
        """
        # Get PROJECT_ROOT from environment variable
        project_root = os.getenv("PROJECT_ROOT")
        if not project_root:
            logger.error("PROJECT_ROOT environment variable not set")
            return

        # Get the tools directory path
        tools_dir = Path(project_root) / "tools"
        if not tools_dir.exists():
            logger.error(f"Tools directory does not exist: {tools_dir}")
            return

        logger.info(f"Discovering tool functions in: {tools_dir}")
        registered_tools = 0

        def _is_package_dir(path: Path) -> bool:
            """Check if a directory is a Python package (has __init__.py)."""
            return (path / "__init__.py").exists()

        def _register_tool_function_node(func: ToolFunction, category: str) -> None:
            """Register a tool function node in the NodeRegistry."""
            nonlocal registered_tools

            node_class = func.node_class
            # NOTE(review): the category passed by the caller is overridden here, so
            # every tool function is listed under "Custom Tools" — confirm this is
            # intended.
            category = "Custom Tools"
            if category not in cls._nodes:
                cls._nodes[category] = []

            node_info = NodeInfo(
                node_type_name=node_class.__name__,
                module=node_class.__module__,
                # Using dot notation for nested attribute
                class_name=f"{func.func_name}.node_class",
                subcategory=getattr(node_class, "subcategory", None),
            )

            if not any(n.node_type_name == node_class.__name__ for n in cls._nodes[category]):
                cls._nodes[category].append(node_info)
                registered_tools += 1
                logger.debug(
                    f"Registered tool function {node_class.__name__} in category {category}"
                )

        def _is_valid_tool_function(attr: Any) -> bool:
            """Check if an attribute is a properly decorated tool function."""
            if not isinstance(attr, ToolFunction):
                return False
            if not issubclass(attr.node_class, FunctionToolNode):
                return False  # Skip regular functions
            # Must have all required node attributes
            required_attrs = {"display_name", "config_model", "input_model", "output_model"}
            return all(hasattr(attr.node_class, attr_name) for attr_name in required_attrs)

        def _discover_tools_in_directory(path: Path, base_package: str = "tools") -> None:
            """Recursively discover tool functions in package directories."""
            # Skip if not a package directory
            if not _is_package_dir(path):
                return

            for item in path.iterdir():
                if item.is_file() and item.suffix == ".py" and not item.name.startswith("_"):
                    try:
                        # Get the module path relative to project root
                        module_path = f"{base_package}.{item.stem}"

                        # Import the module using standard import_module
                        module = importlib.import_module(module_path)

                        # Register any valid tool functions found in the module
                        for attr_name in dir(module):
                            attr = getattr(module, attr_name)
                            if _is_valid_tool_function(attr):
                                node_class = attr.node_class
                                category = getattr(node_class, "category", "Uncategorized")
                                _register_tool_function_node(attr, category)

                    except Exception as e:
                        logger.error(f"Failed to load module {item}: {e}")
                        logger.error(traceback.format_exc())

                # Recursively process subdirectories
                elif item.is_dir() and not item.name.startswith("_"):
                    # Update the base package for the subdirectory
                    subpackage = f"{base_package}.{item.name}"
                    _discover_tools_in_directory(item, subpackage)

        # Start recursive discovery from tools directory
        _discover_tools_in_directory(tools_dir)
        logger.info(f"Tool function discovery complete. Found {registered_tools} tool functions.")
if available") + + +class ExaSearchNodeOutput(BaseNodeOutput): + results: List[ExaSearchResult] = Field(..., description="List of search results from Exa") + + +# Define a simple schema without complex nested structures +SIMPLE_OUTPUT_SCHEMA = { + "title": "ExaSearchNodeOutput", + "type": "object", + "properties": { + "results": { + "title": "Search Results", + "type": "array", + "description": "List of search results from Exa", + "items": {"type": "object"}, + } + }, + "required": ["results"], +} + + +class ExaSearchNodeConfig(BaseNodeConfig): + max_results: int = Field( + 10, description="Maximum number of search results to return (max 100)." + ) + include_content: bool = Field( + True, description="When True, fetch and include text content of search results." + ) + max_characters: int = Field( + 1000, + description="Maximum characters to fetch for each result's content (when include_content is True).", + ) + query_template: str = Field( + "{{input_1}}", + description="Template for the query string. 
def setup(self) -> None:
    """Run the base node setup, tolerating only the known JSON-schema failure.

    ``super().setup()`` may raise a ``ValueError`` whose message contains
    "Unsupported JSON schema type". That specific failure is downgraded to a
    warning so the node can still be used. Every other ``ValueError`` is
    re-raised instead of being dropped without a trace.

    Raises:
        ValueError: Any setup error other than the unsupported-schema case.
    """
    try:
        super().setup()
    except ValueError as e:
        if "Unsupported JSON schema type" not in str(e):
            # Not the schema problem this override exists for: do not hide it.
            raise
        # If we hit schema issues, continue with whatever setup completed.
        logging.warning(f"Schema error: {e}, using simplified approach")
def setup_subworkflow(self) -> None:
    """Index the subworkflow's nodes and locate its output node.

    Populates:
        * ``self._node_dict`` - node id -> node, for every node in
          ``self.subworkflow.nodes``.
        * ``self._dependencies`` - result of ``self._build_dependencies()``.
        * ``self._subworkflow_output_node`` - the first node whose
          ``node_type`` is ``"OutputNode"``.

    Raises:
        AssertionError: If ``self.subworkflow`` is ``None``.
        ValueError: If the subworkflow contains no ``OutputNode``.
    """
    assert self.subworkflow is not None
    self._node_dict: Dict[str, WorkflowNodeSchema] = {
        node.id: node for node in self.subworkflow.nodes
    }
    self._dependencies: Dict[str, Set[str]] = self._build_dependencies()

    # Pass an explicit default to next(): a bare StopIteration escaping into the
    # async run() would be turned into an unhelpful RuntimeError.
    output_node = next(
        (node for node in self.subworkflow.nodes if node.node_type == "OutputNode"),
        None,
    )
    if output_node is None:
        raise ValueError("Subworkflow has no OutputNode; cannot determine its output")
    self._subworkflow_output_node = output_node
def apply_templates_to_config(
    self, model: BaseSubworkflowNodeConfig, input_data: Dict[str, Any]
) -> BaseSubworkflowNodeConfig:
    """Render Jinja templates held in the config's ``*_message`` string fields.

    Only fields whose name ends with ``_message`` and whose dumped value is a
    ``str`` are rendered, using ``input_data`` as the template variables.

    Returns:
        A copy of ``model`` with the rendered values applied, or ``model``
        itself when no field qualified.
    """
    rendered: Dict[str, str] = {
        name: Template(raw).render(**input_data)
        for name, raw in model.model_dump().items()
        if isinstance(raw, str) and name.endswith("_message")
    }
    return model.model_copy(update=rendered) if rendered else model
def render_template_or_get_first_string(
    template_str: str, input_dict: Dict[Any, Any], node_name: str
) -> str:
    """
    Renders a template string with the given input dictionary.
    If template is empty, returns the first string value found in the input dictionary.

    The empty-template case is decided before any rendering, so an input
    dictionary that cannot be expanded into keyword arguments (for example one
    with non-string keys) still works for the "first string" fallback.

    Args:
        template_str: The template string to render
        input_dict: Dictionary containing values for template rendering
        node_name: Name of the node (for error logging)

    Returns:
        Rendered template string or first string value from input

    Raises:
        ValueError: If no string value is found in input when template is empty
    """
    try:
        # Empty (or whitespace-only) template: fall back to the first string value.
        if not template_str.strip():
            for value in input_dict.values():
                if isinstance(value, str):
                    return value
            raise ValueError(f"No string type found in the input dictionary: {input_dict}")

        return Template(template_str).render(**input_dict)

    except Exception:
        logging.error(f"Failed to render template in {node_name}")
        logging.error(f"template: {template_str} with input: {input_dict}")
        # Bare raise keeps the original traceback intact.
        raise
def apply_template(
    text: str, template: str, metadata_template: Dict[str, str]
) -> Tuple[str, Dict[str, str]]:
    """Apply Jinja template to chunk text and metadata.

    Args:
        text: The raw chunk text; exposed to the templates as ``text``.
        template: Jinja template producing the processed chunk text.
        metadata_template: Mapping of metadata key -> Jinja template.

    Returns:
        ``(processed_text, processed_metadata)``. If rendering fails for any
        reason the error is logged and the original ``text`` is returned with
        ``{"type": "text_chunk", "error": <message>}`` as metadata.
    """
    # Local import: this module has no top-level logging import.
    import logging

    try:
        # Create template context
        context = {
            "text": text,
            # Add more context variables as needed
        }

        processed_text = Template(template).render(**context)
        processed_metadata: Dict[str, str] = {
            key: Template(template_str).render(**context)
            for key, template_str in metadata_template.items()
        }
        return processed_text, processed_metadata
    except Exception as e:
        # Best effort: report via logging (not stdout) and fall back to the raw text.
        logging.error(f"Error applying template: {e}")
        return text, {"type": "text_chunk", "error": str(e)}
+ """ + if not text or text.isspace(): + return [] + + tokens = tokenizer.encode(text, disallowed_special=()) + chunks: List[str] = [] + num_chunks = 0 + + while tokens and num_chunks < config.max_num_chunks: + chunk = tokens[: config.chunk_token_size] + chunk_text = tokenizer.decode(chunk) + + if not chunk_text or chunk_text.isspace(): + tokens = tokens[len(chunk) :] + continue + + last_punctuation = max( + chunk_text.rfind("."), + chunk_text.rfind("?"), + chunk_text.rfind("!"), + chunk_text.rfind("\n"), + ) + + if last_punctuation != -1 and last_punctuation > config.min_chunk_size_chars: + chunk_text = chunk_text[: last_punctuation + 1] + + chunk_text_to_append = chunk_text.replace("\n", " ").strip() + + if len(chunk_text_to_append) > config.min_chunk_length_to_embed: + chunks.append(chunk_text_to_append) + + tokens = tokens[len(tokenizer.encode(chunk_text, disallowed_special=())) :] + num_chunks += 1 + + if tokens: + remaining_text = tokenizer.decode(tokens).replace("\n", " ").strip() + if len(remaining_text) > config.min_chunk_length_to_embed: + chunks.append(remaining_text) + + return chunks + + +def create_document_chunks( + doc: DocumentSchema, config: ChunkingConfigSchema +) -> Tuple[List[DocumentChunkSchema], str]: + """ + Create a list of document chunks from a document object. + + Args: + doc: The document object to create chunks from. + config: ChunkingConfig containing the chunking parameters. + + Returns: + A tuple of (doc_chunks, doc_id). 
async def preview_document_chunk(
    file: BinaryIO,
    filename: str,
    mime_type: str,
    config: ChunkingConfigSchema,
) -> Tuple[List[Dict[str, str]], int]:
    """
    Preview how a document will be chunked and formatted.

    The upload is written to a temporary file, parsed with
    ``extract_text_from_file`` and chunked with ``create_document_chunks``.
    The temporary file is always removed, including when parsing fails.

    Args:
        file: The file object to process
        filename: Name of the file (only its suffix is used, for the temp file)
        mime_type: MIME type of the file; falls back to "text/plain" when falsy
        config: Chunking configuration

    Returns:
        Tuple containing:
        - List of preview chunks (first, middle and last when there are three or
          more), each containing original_text, processed_text, metadata and a
          1-based chunk_index
        - Total number of chunks

    Raises:
        ValueError: If no chunks are produced or any step fails; the underlying
            exception is attached as ``__cause__``.
    """
    try:
        # delete=False so the file can be reopened by name after it is closed.
        temp_file = NamedTemporaryFile(delete=False, suffix=Path(filename).suffix)
        try:
            with temp_file:
                temp_file.write(file.read())

            # Extract text using document processing logic
            with open(temp_file.name, "rb") as f:
                extracted_text = extract_text_from_file(f, mime_type or "text/plain", None)
        finally:
            # Clean up temp file even when writing or extraction raised.
            os.unlink(temp_file.name)

        # Create a temporary Document object to use create_document_chunks
        temp_doc = DocumentSchema(text=extracted_text)
        doc_chunks, _ = create_document_chunks(temp_doc, config)

        if not doc_chunks:
            raise ValueError("No chunks could be generated with the provided configuration")

        # Take up to 3 chunks for preview: beginning, middle, and end.
        # The set collapses duplicates for documents with one or two chunks.
        total = len(doc_chunks)
        preview_indices = sorted({0, total // 2, total - 1})

        preview_chunks = []
        for idx in preview_indices:
            chunk = doc_chunks[idx]
            preview_chunks.append(
                {
                    "original_text": chunk.text,  # This will already be processed if template is enabled
                    "processed_text": chunk.text,
                    "metadata": chunk.metadata.custom_metadata
                    if chunk.metadata
                    else {"type": "text_chunk"},
                    "chunk_index": idx + 1,  # 1-based index for display
                }
            )

        return preview_chunks, total

    except Exception as e:
        raise ValueError(f"Error previewing chunk: {str(e)}") from e
async def upsert(
    self,
    documents: List[DocumentSchema],
    chunk_token_size: Optional[int] = None,
) -> List[str]:
    """
    Takes in a list of documents and inserts them into the database.
    First deletes all the existing vectors with the document id (if necessary, depends on the vector db), then inserts the new ones.
    Return a list of document ids.

    Args:
        documents: Documents to store. Documents that already carry ``chunks``
            are stored as-is; the others are chunked here.
        chunk_token_size: Optional override for the chunking configuration's
            ``chunk_token_size``; the schema default is used when falsy.
    """
    # Delete any existing vectors for documents with the input document ids
    await asyncio.gather(
        *[
            self.delete(
                filter=DocumentMetadataFilterSchema(
                    document_id=document.id,
                ),
                delete_all=False,
            )
            for document in documents
            if document.id
        ]
    )

    # The chunking config does not depend on the document: build it once, and
    # only if at least one document actually needs chunking.
    config: Optional[ChunkingConfigSchema] = None
    chunks = {}
    for doc in documents:
        # If the document already has chunks with embeddings, use those
        existing_chunks = getattr(doc, "chunks", None)
        if existing_chunks:
            # NOTE(review): doc.id is used as the key unchecked; if a pre-chunked
            # document can arrive without an id this stores a None key - confirm.
            chunks[doc.id] = existing_chunks
            continue

        # Only create new chunks if the document doesn't have them
        if config is None:
            config = (
                ChunkingConfigSchema(chunk_token_size=chunk_token_size)
                if chunk_token_size
                else ChunkingConfigSchema()
            )
        doc_chunks, doc_id = create_document_chunks(doc, config)
        chunks[doc_id] = doc_chunks

    return await self._upsert(chunks)
def get_vector_stores() -> Dict[str, VectorStoreConfig]:
    """Return the catalogue of supported vector stores, keyed by store id.

    Each entry lists the environment variables the store reads and, where
    applicable, which of them holds its API key.
    """
    catalogue = (
        VectorStoreConfig(
            id="chroma",
            name="Chroma",
            description="Open-source embedding database",
            required_env_vars=[
                "CHROMA_IN_MEMORY",
                "CHROMA_PERSISTENCE_DIR",
                "CHROMA_HOST",
                "CHROMA_PORT",
                "CHROMA_COLLECTION",
            ],
        ),
        VectorStoreConfig(
            id="pinecone",
            name="Pinecone",
            description="Production-ready vector database",
            api_key_env_var="PINECONE_API_KEY",
            required_env_vars=[
                "PINECONE_API_KEY",
                "PINECONE_INDEX",
                "PINECONE_CLOUD",
                "PINECONE_REGION",
            ],
        ),
        VectorStoreConfig(
            id="weaviate",
            name="Weaviate",
            description="Multi-modal vector search engine",
            api_key_env_var="WEAVIATE_API_KEY",
            required_env_vars=[
                "WEAVIATE_API_KEY",
                "WEAVIATE_URL",
                "WEAVIATE_CLASS",
            ],
        ),
        VectorStoreConfig(
            id="supabase",
            name="Supabase",
            description="Open-source vector database",
            required_env_vars=[
                "SUPABASE_URL",
                "SUPABASE_ANON_KEY",
                "SUPABASE_SERVICE_ROLE_KEY",
            ],
        ),
        VectorStoreConfig(
            id="qdrant",
            name="Qdrant",
            description="Vector database for production",
            api_key_env_var="QDRANT_API_KEY",
            required_env_vars=[
                "QDRANT_API_KEY",
                "QDRANT_URL",
                "QDRANT_COLLECTION",
                "QDRANT_PORT",
                "QDRANT_GRPC_PORT",
            ],
        ),
    )
    # Key each config by its own id so the two can never drift apart.
    return {store.id: store for store in catalogue}
def __init__(
    self,
    embedding_dimension: Optional[int] = None,
    in_memory: bool = CHROMA_IN_MEMORY,  # type: ignore
    persistence_dir: Optional[str] = CHROMA_PERSISTENCE_DIR,
    collection_name: str = CHROMA_COLLECTION,
    host: str = CHROMA_HOST,
    port: str = CHROMA_PORT,
    client: Optional[chromadb.Client] = None,
):
    """Create the Chroma client (unless one is supplied) and open the collection.

    Args:
        embedding_dimension: Forwarded to ``DataStore.__init__``.
        in_memory: Use a local client when true, the REST client otherwise.
            The default comes from the ``CHROMA_IN_MEMORY`` environment
            variable and is therefore a string; string values are parsed so
            that "false", "0", "no", "off" and "" mean ``False``.
        persistence_dir: Persistence directory for the local client; when
            falsy, default ``Settings()`` are used.
        collection_name: Collection to get or create.
        host: Server host for the REST client.
        port: Server HTTP port for the REST client.
        client: Pre-built client; when given, all connection options above
            are ignored.
    """
    super().__init__(embedding_dimension=embedding_dimension)

    # A non-empty string such as "False" is truthy; interpret env-style flags
    # explicitly so CHROMA_IN_MEMORY=False really selects the REST client.
    if isinstance(in_memory, str):
        in_memory = in_memory.strip().lower() not in {"", "0", "false", "no", "off"}

    if client:
        self._client = client
    elif in_memory:
        settings = (
            chromadb.config.Settings(
                chroma_db_impl="duckdb+parquet",
                persist_directory=persistence_dir,
            )
            if persistence_dir
            else chromadb.config.Settings()
        )
        self._client = chromadb.Client(settings=settings)
    else:
        self._client = chromadb.Client(
            settings=chromadb.config.Settings(
                chroma_api_impl="rest",
                chroma_server_host=host,
                chroma_server_http_port=port,
            )
        )

    # embedding_function=None: embeddings are passed in explicitly by _upsert/_query.
    self._collection = self._client.get_or_create_collection(
        name=collection_name,
        embedding_function=None,
    )
+ """ + + self._collection.upsert( + ids=[chunk.id for chunk_list in chunks.values() for chunk in chunk_list], + embeddings=[chunk.embedding for chunk_list in chunks.values() for chunk in chunk_list], + documents=[chunk.text for chunk_list in chunks.values() for chunk in chunk_list], + metadatas=[ + self._process_metadata_for_storage(chunk.metadata) + for chunk_list in chunks.values() + for chunk in chunk_list + ], + ) + return list(chunks.keys()) + + def _where_from_query_filter(self, query_filter: DocumentMetadataFilterSchema) -> Dict: + output = { + k: v + for (k, v) in query_filter.dict().items() + if v is not None and k != "start_date" and k != "end_date" and k != "source" + } + if query_filter.source: + output["source"] = query_filter.source.value + if query_filter.start_date and query_filter.end_date: + output["$and"] = [ + { + "created_at": { + "$gte": int(datetime.fromisoformat(query_filter.start_date).timestamp()) + } + }, + { + "created_at": { + "$lte": int(datetime.fromisoformat(query_filter.end_date).timestamp()) + } + }, + ] + elif query_filter.start_date: + output["created_at"] = { + "$gte": int(datetime.fromisoformat(query_filter.start_date).timestamp()) + } + elif query_filter.end_date: + output["created_at"] = { + "$lte": int(datetime.fromisoformat(query_filter.end_date).timestamp()) + } + + return output + + def _process_metadata_for_storage(self, metadata: DocumentChunkMetadataSchema) -> Dict: + stored_metadata = {} + if metadata.source: + stored_metadata["source"] = metadata.source.value + if metadata.source_id: + stored_metadata["source_id"] = metadata.source_id + if metadata.url: + stored_metadata["url"] = metadata.url + if metadata.created_at: + stored_metadata["created_at"] = int( + datetime.fromisoformat(metadata.created_at).timestamp() + ) + if metadata.author: + stored_metadata["author"] = metadata.author + if metadata.document_id: + stored_metadata["document_id"] = metadata.document_id + + return stored_metadata + + def 
async def delete(
    self,
    ids: Optional[List[str]] = None,
    filter: Optional[DocumentMetadataFilterSchema] = None,
    delete_all: Optional[bool] = None,
) -> bool:
    """
    Removes vectors by ids, filter, or everything in the datastore.
    Multiple parameters can be used at once.
    Returns whether the operation was successful.

    Args:
        ids: Document ids; matched against the ``document_id`` metadata field.
        filter: Metadata filter, combined with ``ids`` via ``$and`` when both
            are given.
        delete_all: When true, ``ids`` and ``filter`` are ignored and the
            collection's ``delete()`` is called without conditions.

    Returns:
        ``True`` after a delete call was issued; ``False`` when no criteria
        were supplied and therefore nothing was deleted.
    """
    if delete_all:
        self._collection.delete()
        return True

    where_clause = None
    if ids:
        if len(ids) > 1:
            where_clause = {"$or": [{"document_id": id_} for id_ in ids]}
        else:
            (id_,) = ids
            where_clause = {"document_id": id_}

        if filter:
            where_clause = {
                "$and": [
                    self._where_from_query_filter(filter),
                    where_clause,
                ]
            }
    elif filter:
        where_clause = self._where_from_query_filter(filter)

    if where_clause is None:
        # No ids, no filter, no delete_all: previously an UnboundLocalError.
        # Treat it as a no-op rather than guessing what to delete.
        return False

    self._collection.delete(where=where_clause)
    return True
async def _upsert(self, chunks: Dict[str, List[DocumentChunkSchema]]) -> List[str]:
    """
    Write every chunk of every document to the ``documents`` table.

    One ``client.upsert`` call is awaited per chunk, in iteration order of
    ``chunks``. Returns the document ids, i.e. the keys of ``chunks``.
    """
    for doc_id, doc_chunks in chunks.items():
        for piece in doc_chunks:
            meta = piece.metadata
            row = dict(
                id=piece.id,
                content=piece.text,
                embedding=piece.embedding,
                document_id=doc_id,
                source=meta.source,
                source_id=meta.source_id,
                url=meta.url,
                author=meta.author,
            )
            if meta.created_at:
                parsed = datetime.fromtimestamp(to_unix_timestamp(meta.created_at))
                # Stored as a 1-tuple on purpose: SupabaseClient.upsert reads
                # json["created_at"][0] before serialising it.
                row["created_at"] = (parsed,)
            await self.client.upsert("documents", row)

    return [*chunks.keys()]
async def delete(
    self,
    ids: Optional[List[str]] = None,
    filter: Optional[DocumentMetadataFilterSchema] = None,
    delete_all: Optional[bool] = None,
) -> bool:
    """
    Removes vectors by ids, filter, or everything in the datastore.

    Exactly one selector is honoured, in this order of precedence:
    ``delete_all``, then ``ids`` (matched against ``document_id``), then
    ``filter``.

    Returns:
        False when the client raised while deleting, True otherwise
        (including the case where no selector was supplied and nothing was
        attempted).
    """
    try:
        if delete_all:
            # LIKE '%' matches every document_id, i.e. every row.
            await self.client.delete_like("documents", "document_id", "%")
        elif ids:
            await self.client.delete_in("documents", "document_id", ids)
        elif filter:
            await self.client.delete_by_filters("documents", filter)
    except Exception as e:
        # `except Exception` (not a bare `except`) so that task cancellation
        # and interpreter-exit signals still propagate; the failure is logged
        # instead of being dropped silently.
        logger.error(f"Error deleting from documents table: {e}")
        return False
    return True
def __init__(self, embedding_dimension: Optional[int] = None):
    """
    Connect to the Pinecone index named by ``PINECONE_INDEX``, creating it
    first when it does not exist yet.

    Args:
        embedding_dimension: Vector size used when a new index has to be
            created; 1536 is used when it is not specified.

    Raises:
        Exception: whatever the Pinecone client raised while creating or
            connecting to the index (logged, then re-raised).
    """
    super().__init__(embedding_dimension=embedding_dimension)
    if not PINECONE_INDEX:
        # Nothing to connect to; module import already rejects this case.
        return

    # List the indexes once so "create" and "connect" are decided from the
    # same snapshot (and one API round trip is saved).
    index_exists = PINECONE_INDEX in pc.list_indexes().names()

    if not index_exists:
        # Get all fields in the metadata object in a list
        fields_to_index = list(DocumentChunkMetadataSchema.model_fields.keys())

        # Create a new index with the specified name, dimension, and metadata configuration
        try:
            logger.info(
                f"Creating index {PINECONE_INDEX} with metadata config {fields_to_index}"
            )
            pc.create_index(
                name=PINECONE_INDEX,
                dimension=self.embedding_dimension or 1536,  # Default to 1536 if not specified
                spec=ServerlessSpec(cloud=PINECONE_CLOUD, region=PINECONE_REGION),
                metadata_config={"indexed": fields_to_index},
            )
            self.index = pc.Index(name=PINECONE_INDEX)
            logger.info(f"Index {PINECONE_INDEX} created successfully")
        except Exception as e:
            logger.error(f"Error creating index {PINECONE_INDEX}: {e}")
            raise  # bare raise keeps the original traceback
    else:
        # Connect to an existing index with the specified name
        try:
            logger.info(f"Connecting to existing index {PINECONE_INDEX}")
            self.index = pc.Index(name=PINECONE_INDEX)
            logger.info(f"Connected to index {PINECONE_INDEX} successfully")
        except Exception as e:
            logger.error(f"Error connecting to index {PINECONE_INDEX}: {e}")
            raise
vectors[i : i + UPSERT_BATCH_SIZE] for i in range(0, len(vectors), UPSERT_BATCH_SIZE) + ] + # Upsert each batch to Pinecone + for batch in batches: + try: + logger.info(f"Upserting batch of size {len(batch)}") + self.index.upsert(vectors=batch) + logger.info(f"Upserted batch successfully") + except Exception as e: + logger.error(f"Error upserting batch: {e}") + raise e + + return doc_ids + + @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(3)) + async def _query( + self, + queries: List[QueryWithEmbeddingSchema], + ) -> List[QueryResultSchema]: + """ + Takes in a list of queries with embeddings and filters and returns a list of query results with matching document chunks and scores. + """ + + # Define a helper coroutine that performs a single query and returns a QueryResult + async def _single_query( + query: QueryWithEmbeddingSchema, + ) -> QueryResultSchema: + logger.debug(f"Query: {query.query}") + + # Convert the metadata filter object to a dict with pinecone filter expressions + pinecone_filter = self._get_pinecone_filter(query.filter) + + try: + # Query the index with the query embedding, filter, and top_k + query_response = self.index.query( + # namespace=namespace, + top_k=query.top_k or 10, # Default to 10 if top_k is None + vector=query.embedding, + filter=pinecone_filter, + include_metadata=True, + ) + except Exception as e: + logger.error(f"Error querying index: {e}") + raise e + + query_results: List[DocumentChunkWithScoreSchema] = [] + for result in query_response.matches: + score = result.score + metadata = result.metadata + # Remove document id and text from metadata and store it in a new variable + metadata_without_text = ( + {key: value for key, value in metadata.items() if key != "text"} + if metadata + else None + ) + + # If the source is not a valid Source in the Source enum, set it to None + if ( + metadata_without_text + and "source" in metadata_without_text + and metadata_without_text["source"] not in 
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(3))
async def delete(
    self,
    ids: Optional[List[str]] = None,
    filter: Optional[DocumentMetadataFilterSchema] = None,
    delete_all: Optional[bool] = None,
) -> bool:
    """
    Removes vectors by ids, filter, or everything in the datastore.

    Only the first selector that is set is honoured, in this order:
    ``delete_all``, then ``ids`` (document ids), then ``filter``.

    Every failure is caught, logged and reported as ``False``, so the
    ``@retry`` decorator never sees an exception from this method.

    Returns:
        True when the selected delete completed (even if nothing matched),
        False on error or when no selector was supplied.
    """
    if delete_all:
        try:
            logger.info("Deleting all vectors from index")
            self.index.delete(delete_all=True)
            logger.info("Deleted all vectors successfully")
            return True
        except Exception as e:
            logger.error(f"Error deleting all vectors: {e}")
            return False

    if ids:
        try:
            # Chunks carry their parent document id in metadata.
            if self._delete_chunks_matching({"document_id": {"$in": ids}}):
                logger.info("Deleted vectors with ids successfully")
            return True
        except Exception as e:
            logger.error(f"Error deleting vectors with ids {ids}: {e}")
            return False

    if filter:
        try:
            if self._delete_chunks_matching(self._get_pinecone_filter(filter)):
                logger.info("Deleted vectors with filter successfully")
            return True
        except Exception as e:
            logger.error(f"Error deleting vectors with filter: {e}")
            return False

    return False

def _delete_chunks_matching(self, pinecone_filter: Dict[str, Any]) -> int:
    """Look up the chunk ids matching ``pinecone_filter``, delete them, and return how many were found."""
    # Dummy vector: the query is only used to resolve the metadata filter.
    dummy_vector: List[float] = [0.0] * (
        self.embedding_dimension or 1536
    )  # Default to 1536 if not specified
    query_response = self.index.query(
        vector=dummy_vector,
        filter=pinecone_filter,
        top_k=10000,  # At most this many chunks are removed per call
        # Only the ids are read below, so metadata is not requested.
        # NOTE(review): Pinecone documents a lower top_k ceiling when
        # metadata/values are returned -- confirm for the deployed API version.
        include_metadata=False,
    )

    matches = getattr(query_response, "matches", None) or []
    chunk_ids: List[str] = [str(match.id) for match in matches]

    if chunk_ids:
        logger.info(f"Deleting vectors with chunk ids {chunk_ids}")
        # Delete in slices to stay within per-request id limits.
        delete_batch_size = 1000
        for start in range(0, len(chunk_ids), delete_batch_size):
            self.index.delete(ids=chunk_ids[start : start + delete_batch_size])

    return len(chunk_ids)
else: + # Convert other types to strings if possible + try: + pinecone_metadata[field] = str(value) + except: + # Skip values that can't be converted to string + continue + + return pinecone_metadata diff --git a/pyspur/backend/pyspur/rag/datastore/providers/qdrant_datastore.py b/pyspur/backend/pyspur/rag/datastore/providers/qdrant_datastore.py new file mode 100644 index 0000000000000000000000000000000000000000..bfdb9f99d3799c58cf1263d789565ed717310ff5 --- /dev/null +++ b/pyspur/backend/pyspur/rag/datastore/providers/qdrant_datastore.py @@ -0,0 +1,278 @@ +import os +import uuid +from typing import Dict, List, Optional + +import qdrant_client +from grpc._channel import _InactiveRpcError +from qdrant_client.http import models as rest +from qdrant_client.http.exceptions import UnexpectedResponse +from qdrant_client.http.models import PayloadSchemaType + +from ...schemas.document_schemas import ( + DocumentChunkSchema, + DocumentChunkWithScoreSchema, + DocumentMetadataFilterSchema, + QueryResultSchema, + QueryWithEmbeddingSchema, +) +from ..datastore import DataStore +from ..services.date import to_unix_timestamp + +QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost") +QDRANT_PORT = os.environ.get("QDRANT_PORT", "6333") +QDRANT_GRPC_PORT = os.environ.get("QDRANT_GRPC_PORT", "6334") +QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY") +QDRANT_COLLECTION = os.environ.get("QDRANT_COLLECTION", "document_chunks") + +EMBEDDING_DIMENSION = int(os.environ.get("EMBEDDING_DIMENSION", 256)) + + +class QdrantDataStore(DataStore): + UUID_NAMESPACE = uuid.UUID("3896d314-1e95-4a3a-b45a-945f9f0b541d") + + def __init__( + self, + embedding_dimension: Optional[int] = None, + collection_name: Optional[str] = None, + distance: str = "Cosine", + recreate_collection: bool = False, + ): + super().__init__(embedding_dimension=embedding_dimension) + self.collection_name = collection_name or QDRANT_COLLECTION + self._client = qdrant_client.QdrantClient( + url=QDRANT_URL, + 
async def _upsert(self, chunks: Dict[str, List[DocumentChunkSchema]]) -> List[str]:
    """
    Convert every chunk of every document into a Qdrant point and upsert
    them all in a single request, waiting for it to be applied.

    Returns the document ids, i.e. the keys of ``chunks``.
    """
    points = []
    for doc_chunks in chunks.values():
        for doc_chunk in doc_chunks:
            points.append(self._convert_document_chunk_to_point(doc_chunk))

    self._client.upsert(
        collection_name=self.collection_name,
        points=points,  # type: ignore
        wait=True,
    )
    return [*chunks.keys()]
+ """ + if ids is None and filter is None and delete_all is None: + raise ValueError("Please provide one of the parameters: ids, filter or delete_all.") + + if delete_all: + points_selector = rest.Filter() + else: + points_selector = self._convert_metadata_filter_to_qdrant_filter(filter, ids) + + response = self._client.delete( + collection_name=self.collection_name, + points_selector=points_selector, # type: ignore + ) + return "COMPLETED" == response.status + + def _convert_document_chunk_to_point( + self, document_chunk: DocumentChunkSchema + ) -> rest.PointStruct: + created_at = ( + to_unix_timestamp(document_chunk.metadata.created_at) + if document_chunk.metadata.created_at is not None + else None + ) + return rest.PointStruct( + id=self._create_document_chunk_id(document_chunk.id), + vector=document_chunk.embedding, # type: ignore + payload={ + "id": document_chunk.id, + "text": document_chunk.text, + "metadata": document_chunk.metadata.dict(), + "created_at": created_at, + }, + ) + + def _create_document_chunk_id(self, external_id: Optional[str]) -> str: + if external_id is None: + return uuid.uuid4().hex + return uuid.uuid5(self.UUID_NAMESPACE, external_id).hex + + def _convert_query_to_search_request( + self, query: QueryWithEmbeddingSchema + ) -> rest.SearchRequest: + return rest.SearchRequest( + vector=query.embedding, + filter=self._convert_metadata_filter_to_qdrant_filter(query.filter), + limit=query.top_k, # type: ignore + with_payload=True, + with_vector=False, + ) + + def _convert_metadata_filter_to_qdrant_filter( + self, + metadata_filter: Optional[DocumentMetadataFilterSchema] = None, + ids: Optional[List[str]] = None, + ) -> Optional[rest.Filter]: + if metadata_filter is None and ids is None: + return None + + must_conditions, should_conditions = [], [] + + # Filtering by document ids + if ids and len(ids) > 0: + for document_id in ids: + should_conditions.append( + rest.FieldCondition( + key="metadata.document_id", + 
def _convert_scored_point_to_document_chunk_with_score(
    self, scored_point: rest.ScoredPoint
) -> DocumentChunkWithScoreSchema:
    """
    Build a scored document chunk from a Qdrant search hit.

    ``id``, ``text`` and ``metadata`` are read from the point payload; the
    vector and score are taken from the point itself.
    """
    # A hit may come back without a payload. Every field is read through this
    # fallback dict so a None payload no longer raises AttributeError on
    # `.get` (the original only used the fallback for `id`).
    payload = scored_point.payload or {}
    return DocumentChunkWithScoreSchema(
        id=payload.get("id"),
        text=payload.get("text"),  # type: ignore
        metadata=payload.get("metadata"),  # type: ignore
        embedding=scored_point.vector,  # type: ignore
        score=scored_point.score,
    )
def _recreate_collection(self, distance: rest.Distance, vector_size: int):
    """
    (Re)create ``self.collection_name`` with the given vector configuration
    and create the payload indexes the datastore filters on.

    Args:
        distance: Similarity metric for the collection's vectors.
        vector_size: Dimensionality of the stored vectors.
    """
    self._client.recreate_collection(
        self.collection_name,
        vectors_config=rest.VectorParams(
            size=vector_size,
            distance=distance,
        ),
    )

    # Create the payload index for the document_id metadata attribute, as it is
    # used to delete the document related entries.
    # `field_schema` is used for both indexes; the first call previously passed
    # the same value through the older `field_type` keyword.
    self._client.create_payload_index(
        self.collection_name,
        field_name="metadata.document_id",
        field_schema=PayloadSchemaType.KEYWORD,
    )

    # Create the payload index for the created_at attribute, to make the lookup
    # by range filters faster
    self._client.create_payload_index(
        self.collection_name,
        field_name="created_at",
        field_schema=PayloadSchemaType.INTEGER,
    )
async def upsert(self, table: str, json: dict[str, Any]):
    """
    Upsert a single row into ``table``.

    When the row has a ``created_at`` entry, its first element is replaced
    in place (on the caller's dict) by its ISO-8601 string before sending.
    """
    if "created_at" in json:
        wrapped_timestamp = json["created_at"]
        # The value arrives wrapped in a sequence; only element 0 is used.
        json.update(created_at=wrapped_timestamp[0].isoformat())

    request = self.client.table(table).upsert(json)
    request.execute()
async def delete_by_filters(self, table: str, filter: DocumentMetadataFilterSchema):
    """
    Deletes rows in the table that match the filter.

    Equality conditions are added for ``document_id``, ``source``,
    ``source_id`` and ``author`` when set; ``start_date`` / ``end_date``
    become inclusive lower / upper bounds on ``created_at``.
    """

    def _as_iso(value: Any) -> str:
        # Accept every shape a date bound may arrive in: a one-element
        # sequence holding a datetime (the only shape the previous
        # `value[0].isoformat()` handled), a bare date/datetime, or an
        # already formatted string. Indexing a string with [0] yielded a
        # single character and raised AttributeError on `.isoformat()`.
        if isinstance(value, (tuple, list)):
            value = value[0]
        return value.isoformat() if hasattr(value, "isoformat") else str(value)

    builder = self.client.table(table).delete()
    if filter.document_id:
        builder = builder.eq(
            "document_id",
            filter.document_id,
        )
    if filter.source:
        builder = builder.eq("source", filter.source)
    if filter.source_id:
        builder = builder.eq("source_id", filter.source_id)
    if filter.author:
        builder = builder.eq("author", filter.author)
    if filter.start_date:
        builder = builder.gte(
            "created_at",
            _as_iso(filter.start_date),
        )
    if filter.end_date:
        builder = builder.lte(
            "created_at",
            _as_iso(filter.end_date),
        )
    builder.execute()
def __init__(self, embedding_dimension: Optional[int] = None):
    """Connect to Weaviate and make sure the ``WEAVIATE_CLASS`` collection exists.

    Args:
        embedding_dimension: Forwarded unchanged to ``DataStore.__init__``.

    Raises:
        ValueError: If the resolved Weaviate URL is empty, or (from
            ``_build_auth_credentials``) if the URL is a WCS domain and
            ``WEAVIATE_API_KEY`` is not set.
    """
    super().__init__(embedding_dimension=embedding_dimension)
    auth_credentials = self._build_auth_credentials()

    # Resolve the URL the same way _build_auth_credentials does. The module
    # defines no WEAVIATE_URL constant, so reading one here raised NameError.
    url = os.environ.get("WEAVIATE_URL", WEAVIATE_URL_DEFAULT)
    if not url:
        raise ValueError("WEAVIATE_URL is not set")

    logger.debug(
        f"Connecting to weaviate instance at {url} with credential type {type(auth_credentials).__name__}"
    )
    self.client = Client(url, auth_client_secret=auth_credentials)

    # WEAVIATE_BATCH_DYNAMIC is a str whenever it comes from the environment,
    # and any non-empty string (including "False") is truthy - parse it.
    dynamic_batching = WEAVIATE_BATCH_DYNAMIC
    if not isinstance(dynamic_batching, bool):
        dynamic_batching = str(dynamic_batching).strip().lower() in {"1", "true", "yes", "on"}

    self.client.batch.configure(
        batch_size=WEAVIATE_BATCH_SIZE,
        dynamic=dynamic_batching,
        callback=self.handle_errors,  # type: ignore
    )

    if self.client.schema.contains(SCHEMA):
        current_schema = self.client.schema.get(WEAVIATE_CLASS)
        current_schema_properties = extract_schema_properties(current_schema)

        logger.debug(
            f"Found index {WEAVIATE_CLASS} with properties {current_schema_properties}"
        )
    else:
        new_schema_properties = extract_schema_properties(SCHEMA)
        logger.debug(
            f"Creating collection {WEAVIATE_CLASS} with properties {new_schema_properties}"
        )
        self.client.schema.create_class(SCHEMA)
async def _query(
    self,
    queries: List[QueryWithEmbeddingSchema],
) -> List[QueryResultSchema]:
    """
    Runs each query as a Weaviate hybrid search (alpha 0.5) and returns one
    QueryResultSchema per query, holding the matching chunks with their scores.
    A query that carries a truthy ``filter`` is additionally restricted by the
    where-clause produced by ``build_filters``.
    """
    requested_properties = (
        "chunk_id",
        "document_id",
        "text",
        "source",
        "source_id",
        "url",
        "created_at",
        "author",
    )

    async def _single_query(
        query: QueryWithEmbeddingSchema,
    ) -> QueryResultSchema:
        logger.debug(f"Query: {query.query}")

        # Build the where-clause first (as before), then assemble one builder
        # chain instead of two copies that differ only by with_where().
        has_filter = hasattr(query, "filter") and bool(query.filter)
        where_clause = self.build_filters(query.filter) if has_filter else None

        search = self.client.query.get(
            WEAVIATE_CLASS,
            list(requested_properties),
        ).with_hybrid(query=query.query, alpha=0.5, vector=query.embedding)
        if where_clause is not None:
            search = search.with_where(where_clause)
        raw = (
            search.with_limit(query.top_k)  # type: ignore
            .with_additional(["score", "vector"])
            .do()
        )

        hits = raw["data"]["Get"][WEAVIATE_CLASS]
        # The returned vector (hit["_additional"]["vector"]) is deliberately
        # not copied onto the chunk; only the score is.
        scored_chunks: List[DocumentChunkWithScoreSchema] = [
            DocumentChunkWithScoreSchema(
                id=hit["chunk_id"],
                text=hit["text"],
                score=hit["_additional"]["score"],
                metadata=DocumentChunkMetadataSchema(
                    document_id=hit["document_id"] or "",
                    source=Source(hit["source"]) if hit["source"] else None,
                    source_id=hit["source_id"],
                    url=hit["url"],
                    created_at=hit["created_at"],
                    author=hit["author"],
                ),
            )
            for hit in hits
        ]
        return QueryResultSchema(query=query.query, results=scored_chunks)

    return await asyncio.gather(*(_single_query(item) for item in queries))
+ Returns whether the operation was successful. + """ + if delete_all: + logger.debug(f"Deleting all vectors in index {WEAVIATE_CLASS}") + self.client.schema.delete_all() + return True + + if ids: + operands = [ + { + "path": ["document_id"], + "operator": "Equal", + "valueString": id, + } + for id in ids + ] + + where_clause = {"operator": "Or", "operands": operands} + + logger.debug(f"Deleting vectors from index {WEAVIATE_CLASS} with ids {ids}") + result = self.client.batch.delete_objects( + class_name=WEAVIATE_CLASS, where=where_clause, output="verbose" + ) + + if not bool(result["results"]["successful"]): + logger.debug( + f"Failed to delete the following objects: {result['results']['objects']}" + ) + + if filter: + where_clause = self.build_filters(filter) + + logger.debug(f"Deleting vectors from index {WEAVIATE_CLASS} with filter {where_clause}") + result = self.client.batch.delete_objects(class_name=WEAVIATE_CLASS, where=where_clause) + + if not bool(result["results"]["successful"]): + logger.debug( + f"Failed to delete the following objects: {result['results']['objects']}" + ) + + return True + + @staticmethod + def build_filters(filter): + if filter.source: + filter.source = filter.source.value + + operands = [] + filter_conditions = { + "source": { + "operator": "Equal", + "value": "query.filter.source.value", + "value_key": "valueString", + }, + "start_date": { + "operator": "GreaterThanEqual", + "value_key": "valueDate", + }, + "end_date": {"operator": "LessThanEqual", "value_key": "valueDate"}, + "default": {"operator": "Equal", "value_key": "valueString"}, + } + + for attr, value in filter.__dict__.items(): + if value is not None: + filter_condition = filter_conditions.get(attr, filter_conditions["default"]) + value_key = filter_condition["value_key"] + + operand = { + "path": [ + (attr if not (attr == "start_date" or attr == "end_date") else "created_at") + ], + "operator": filter_condition["operator"], + value_key: value, + } + + 
operands.append(operand) + + return {"operator": "And", "operands": operands} + + @staticmethod + def _is_valid_weaviate_id(candidate_id: str) -> bool: + """ + Check if candidate_id is a valid UUID for weaviate's use + + Weaviate supports UUIDs of version 3, 4 and 5. This function checks if the candidate_id is a valid UUID of one of these versions. + See https://weaviate.io/developers/weaviate/more-resources/faq#q-are-there-restrictions-on-uuid-formatting-do-i-have-to-adhere-to-any-standards + for more information. + """ + acceptable_version = [3, 4, 5] + + try: + result = uuid.UUID(candidate_id) + if result.version not in acceptable_version: + return False + else: + return True + except ValueError: + return False + + @staticmethod + def _is_wcs_domain(url: str) -> bool: + """ + Check if the given URL ends with ".weaviate.network" or ".weaviate.network/". + + Args: + url (str): The URL to check. + + Returns: + bool: True if the URL ends with the specified strings, False otherwise. + """ + pattern = r"\.(weaviate\.cloud|weaviate\.network)(/)?$" + return bool(re.search(pattern, url)) diff --git a/pyspur/backend/pyspur/rag/datastore/services/date.py b/pyspur/backend/pyspur/rag/datastore/services/date.py new file mode 100644 index 0000000000000000000000000000000000000000..fc2d541e252bc707a25d0d9d4622d9e6cdec6c8e --- /dev/null +++ b/pyspur/backend/pyspur/rag/datastore/services/date.py @@ -0,0 +1,24 @@ +import arrow +from loguru import logger + + +def to_unix_timestamp(date_str: str) -> int: + """ + Convert a date string to a unix timestamp (seconds since epoch). + + Args: + date_str: The date string to convert. + + Returns: + The unix timestamp corresponding to the date string. + + If the date string cannot be parsed as a valid date format, returns the current unix timestamp and prints a warning. 
+ """ + # Try to parse the date string using arrow, which supports many common date formats + try: + date_obj = arrow.get(date_str) + return int(date_obj.timestamp()) + except arrow.parser.ParserError: + # If the parsing fails, return the current unix timestamp and print a warning + logger.info(f"Invalid date format: {date_str}") + return int(arrow.now().timestamp()) diff --git a/pyspur/backend/pyspur/rag/document_collection.py b/pyspur/backend/pyspur/rag/document_collection.py new file mode 100644 index 0000000000000000000000000000000000000000..15a7872dfde109d8ac3a0b2db32b7af78e7feac8 --- /dev/null +++ b/pyspur/backend/pyspur/rag/document_collection.py @@ -0,0 +1,201 @@ +import json +import uuid +from pathlib import Path +from typing import Any, Callable, Coroutine, Dict, List, Optional + +import arrow +from loguru import logger + +from .chunker import ChunkingConfigSchema, create_document_chunks +from .parser import extract_text_from_file +from .schemas.document_schemas import ( + DocumentChunkSchema, + DocumentMetadataSchema, + DocumentSchema, + DocumentWithChunksSchema, + Source, +) + + +class DocumentStore: + """Manages document storage, parsing and chunking.""" + + def __init__(self, kb_id: str): + """Initialize document store for a knowledge base.""" + self.kb_id = kb_id + self.base_dir = Path(f"data/knowledge_bases/{kb_id}") + self.raw_dir = self.base_dir / "raw" + self.chunks_dir = self.base_dir / "chunks" + + # Create directory structure + self.base_dir.mkdir(parents=True, exist_ok=True) + self.raw_dir.mkdir(exist_ok=True) + self.chunks_dir.mkdir(exist_ok=True) + + async def process_documents( + self, + files: List[Dict[str, Any]], + config: Dict[str, Any], + on_progress: Optional[Callable[[float, str, int, int], Coroutine[Any, Any, None]]] = None, + ) -> List[DocumentWithChunksSchema]: + """ + Process documents through parsing and chunking. + + Args: + files: List of file information (path, type, etc.) 
+ config: Configuration for processing + on_progress: Async callback for progress updates + + Returns: + List[DocumentWithChunks]: Processed documents with their chunks + """ + try: + # Initialize progress + if on_progress: + await on_progress(0.0, "parsing", 0, len(files)) + + # Get vision configuration if enabled + vision_config = None + if config.get("use_vision_model", False): + vision_config = { + "model": config.get("vision_model"), + "provider": config.get("vision_provider"), + "api_key": config.get("api_key"), + } + + # 1. Parse documents + documents: List[DocumentSchema] = [] + for i, file_info in enumerate(files): + logger.debug(f"Parsing file {i + 1}/{len(files)}: {file_info.get('path')}") + file_path = Path(file_info["path"]) + + # Create document metadata + metadata = DocumentMetadataSchema( + source=Source.file, + source_id=file_path.name, + created_at=arrow.utcnow().isoformat(), + author=file_info.get("author"), + ) + + # Extract text with vision model if enabled and file is PDF + with open(file_path, "rb") as f: + text = extract_text_from_file( + f, + file_info["mime_type"], + vision_config if file_info["mime_type"] == "application/pdf" else None, + ) + + # Save raw text + doc_id = str(uuid.uuid4()) + raw_path = self.raw_dir / f"{doc_id}.txt" + raw_path.write_text(text) + + # Create document + doc = DocumentSchema(id=doc_id, text=text, metadata=metadata) + documents.append(doc) + + if on_progress: + await on_progress( + (i + 1) / len(files) * 0.5, # First 50% for parsing + "parsing", + i + 1, + len(files), + ) + + # 2. 
Create chunks + chunking_config = ChunkingConfigSchema( + chunk_token_size=config.get("chunk_token_size", 200), + min_chunk_size_chars=config.get("min_chunk_size_chars", 350), + min_chunk_length_to_embed=config.get("min_chunk_length_to_embed", 5), + embeddings_batch_size=config.get("embeddings_batch_size", 128), + max_num_chunks=config.get("max_num_chunks", 10000), + template=config.get("template", {}), + ) + + docs_with_chunks: List[DocumentWithChunksSchema] = [] + + for i, doc in enumerate(documents): + # Create chunks + doc_chunks, doc_id = create_document_chunks(doc, chunking_config) + + # Save chunks + chunks_path = self.chunks_dir / f"{doc_id}.json" + with open(chunks_path, "w") as f: + json.dump( + [chunk.model_dump() for chunk in doc_chunks], + f, + indent=2, + ) + + # Create DocumentWithChunks + doc_with_chunks = DocumentWithChunksSchema( + id=doc_id, + text=doc.text, + metadata=doc.metadata, + chunks=doc_chunks, + ) + docs_with_chunks.append(doc_with_chunks) + + if on_progress: + await on_progress( + 0.5 + (i + 1) / len(documents) * 0.5, # Last 50% for chunking + "chunking", + i + 1, + len(documents), + ) + + return docs_with_chunks + + except Exception as e: + logger.error(f"Error processing documents: {e}") + raise + + def get_document(self, doc_id: str) -> Optional[DocumentWithChunksSchema]: + """Retrieve a document and its chunks from storage.""" + try: + # Read raw text + raw_path = self.raw_dir / f"{doc_id}.txt" + if not raw_path.exists(): + return None + + text = raw_path.read_text() + + # Read chunks + chunks_path = self.chunks_dir / f"{doc_id}.json" + if not chunks_path.exists(): + return None + + with open(chunks_path) as f: + chunks_data = json.load(f) + chunks = [DocumentChunkSchema(**chunk_data) for chunk_data in chunks_data] + + # Create DocumentWithChunks + return DocumentWithChunksSchema(id=doc_id, text=text, chunks=chunks) + + except Exception as e: + logger.error(f"Error retrieving document {doc_id}: {e}") + return None + + def 
list_documents(self) -> List[str]: + """List all document IDs in the store.""" + try: + return [p.stem for p in self.raw_dir.glob("*.txt")] + except Exception as e: + logger.error(f"Error listing documents: {e}") + return [] + + def delete_document(self, doc_id: str) -> bool: + """Delete a document and its chunks from storage.""" + try: + raw_path = self.raw_dir / f"{doc_id}.txt" + chunks_path = self.chunks_dir / f"{doc_id}.json" + + if raw_path.exists(): + raw_path.unlink() + if chunks_path.exists(): + chunks_path.unlink() + + return True + except Exception as e: + logger.error(f"Error deleting document {doc_id}: {e}") + return False diff --git a/pyspur/backend/pyspur/rag/embedder.py b/pyspur/backend/pyspur/rag/embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..dbeadffe09424112e99f8abd84469dce11c07539 --- /dev/null +++ b/pyspur/backend/pyspur/rag/embedder.py @@ -0,0 +1,410 @@ +import logging +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, TypeAlias, Union + +import numpy as np +import numpy.typing as npt +from litellm import aembedding +from litellm.types.utils import EmbeddingResponse +from pydantic import BaseModel, Field +from tenacity import stop_after_attempt, wait_random_exponential + +from ..nodes.llm._utils import async_retry + +EmbeddingArray: TypeAlias = npt.NDArray[np.float32] +EmbeddingData = Dict[str, List[float]] + + +class EmbeddingProvider(str, Enum): + OPENAI = "OpenAI" + AZURE_OPENAI = "AzureOpenAI" + COHERE = "Cohere" + VOYAGE = "Voyage" + MISTRAL = "Mistral" + GEMINI = "Gemini" + + +class CohereEncodingFormat(str, Enum): + FLOAT = "float" + INT8 = "int8" + UINT8 = "uint8" + BINARY = "binary" + UBINARY = "ubinary" + + +class EmbeddingModelConfig(BaseModel): + id: str + provider: EmbeddingProvider + name: str + dimensions: int = Field(default=1536) + max_input_length: int = Field(default=8191) + supported_encoding_formats: Optional[List[CohereEncodingFormat]] = None + 
required_env_vars: List[str] = Field(default_factory=list) + + +class EmbeddingModels(str, Enum): + # OpenAI Models + TEXT_EMBEDDING_3_SMALL = "openai/text-embedding-3-small" + TEXT_EMBEDDING_3_LARGE = "openai/text-embedding-3-large" + + # Azure OpenAI Models + AZURE_TEXT_EMBEDDING_3_SMALL = "azure/text-embedding-3-small" + AZURE_TEXT_EMBEDDING_3_LARGE = "azure/text-embedding-3-large" + + # Cohere Models + COHERE_EMBED_ENGLISH = "cohere/embed-english-v3.0" + COHERE_EMBED_ENGLISH_LIGHT = "cohere/embed-english-light-v3.0" + COHERE_EMBED_MULTILINGUAL = "cohere/embed-multilingual-v3.0" + COHERE_EMBED_MULTILINGUAL_LIGHT = "cohere/embed-multilingual-light-v3.0" + + # Voyage Models + VOYAGE_3_LARGE = "voyage/voyage-3-large" + VOYAGE_3 = "voyage/voyage-3" + VOYAGE_3_LITE = "voyage/voyage-3-lite" + VOYAGE_CODE_3 = "voyage/voyage-code-3" + VOYAGE_FINANCE_2 = "voyage/voyage-finance-2" + VOYAGE_LAW_2 = "voyage/voyage-law-2" + + # Mistral Models + MISTRAL_EMBED = "mistral/mistral-embed" + + # Gemini Models + GEMINI_TEXT_EMBEDDING = "gemini/text-embedding-004" + + @classmethod + def get_model_info(cls, model_id: str) -> Optional[EmbeddingModelConfig]: + model_registry = { + # OpenAI Models + cls.TEXT_EMBEDDING_3_SMALL.value: EmbeddingModelConfig( + id=cls.TEXT_EMBEDDING_3_SMALL.value, + provider=EmbeddingProvider.OPENAI, + name="Text Embedding 3 Small", + dimensions=1536, + max_input_length=8191, + ), + cls.TEXT_EMBEDDING_3_LARGE.value: EmbeddingModelConfig( + id=cls.TEXT_EMBEDDING_3_LARGE.value, + provider=EmbeddingProvider.OPENAI, + name="Text Embedding 3 Large", + dimensions=3072, + max_input_length=8191, + ), + # Azure OpenAI Models + cls.AZURE_TEXT_EMBEDDING_3_SMALL.value: EmbeddingModelConfig( + id=cls.AZURE_TEXT_EMBEDDING_3_SMALL.value, + provider=EmbeddingProvider.AZURE_OPENAI, + name="Azure Text Embedding 3 Small", + dimensions=1536, + max_input_length=8191, + ), + cls.AZURE_TEXT_EMBEDDING_3_LARGE.value: EmbeddingModelConfig( + id=cls.AZURE_TEXT_EMBEDDING_3_LARGE.value, 
+ provider=EmbeddingProvider.AZURE_OPENAI, + name="Azure Text Embedding 3 Large", + dimensions=3072, + max_input_length=8191, + ), + # Cohere Models + cls.COHERE_EMBED_ENGLISH.value: EmbeddingModelConfig( + id=cls.COHERE_EMBED_ENGLISH.value, + provider=EmbeddingProvider.COHERE, + name="Cohere Embed English V3", + dimensions=1024, + max_input_length=8191, + supported_encoding_formats=[ + CohereEncodingFormat.FLOAT, + CohereEncodingFormat.INT8, + CohereEncodingFormat.UINT8, + CohereEncodingFormat.BINARY, + CohereEncodingFormat.UBINARY, + ], + ), + cls.COHERE_EMBED_ENGLISH_LIGHT.value: EmbeddingModelConfig( + id=cls.COHERE_EMBED_ENGLISH_LIGHT.value, + provider=EmbeddingProvider.COHERE, + name="Cohere Embed English Light V3", + dimensions=384, + max_input_length=8191, + supported_encoding_formats=[ + CohereEncodingFormat.FLOAT, + CohereEncodingFormat.INT8, + CohereEncodingFormat.UINT8, + CohereEncodingFormat.BINARY, + CohereEncodingFormat.UBINARY, + ], + ), + cls.COHERE_EMBED_MULTILINGUAL.value: EmbeddingModelConfig( + id=cls.COHERE_EMBED_MULTILINGUAL.value, + provider=EmbeddingProvider.COHERE, + name="Cohere Embed Multilingual V3", + dimensions=1024, + max_input_length=8191, + supported_encoding_formats=[ + CohereEncodingFormat.FLOAT, + CohereEncodingFormat.INT8, + CohereEncodingFormat.UINT8, + CohereEncodingFormat.BINARY, + CohereEncodingFormat.UBINARY, + ], + ), + cls.COHERE_EMBED_MULTILINGUAL_LIGHT.value: EmbeddingModelConfig( + id=cls.COHERE_EMBED_MULTILINGUAL_LIGHT.value, + provider=EmbeddingProvider.COHERE, + name="Cohere Embed Multilingual Light V3", + dimensions=384, + max_input_length=8191, + supported_encoding_formats=[ + CohereEncodingFormat.FLOAT, + CohereEncodingFormat.INT8, + CohereEncodingFormat.UINT8, + CohereEncodingFormat.BINARY, + CohereEncodingFormat.UBINARY, + ], + ), + # Voyage Models + cls.VOYAGE_3_LARGE.value: EmbeddingModelConfig( + id=cls.VOYAGE_3_LARGE.value, + provider=EmbeddingProvider.VOYAGE, + name="Voyage 3 Large", + dimensions=1024, + 
@async_retry(
    wait=wait_random_exponential(min=30, max=120),
    stop=stop_after_attempt(3),
)
async def get_single_text_embedding(
    text: str,
    model: str,
    dimensions: Optional[int] = None,
    api_key: Optional[str] = None,
    encoding_format: Optional[CohereEncodingFormat] = None,
) -> List[float]:
    """Get the embedding of a single text using the specified model.

    Args:
        text: Text to embed. It is cut to ``max_input_length`` characters of
            the model's registry entry when longer.
        model: Model id; must be known to ``EmbeddingModels.get_model_info``.
        dimensions: Optional output dimensionality, forwarded when truthy.
        api_key: Optional provider API key, forwarded when truthy.
        encoding_format: Optional encoding format; only applied for Cohere
            models and validated against the model's supported formats.

    Returns:
        The ``embedding`` entry of the first item in the response data.

    Raises:
        ValueError: For an unknown model or an unsupported encoding format.
        Exception: Any error from the embedding call is logged and re-raised.
    """
    try:
        model_info = EmbeddingModels.get_model_info(model)
        if not model_info:
            raise ValueError(f"Unknown model: {model}")

        # Truncate text if needed
        if len(text) > model_info.max_input_length:
            text = text[: model_info.max_input_length]

        # Prepare kwargs for litellm
        kwargs: Dict[str, Any] = {
            "model": model,
            "input": text,
        }

        # Add optional parameters
        if dimensions:
            kwargs["dimensions"] = dimensions
        if api_key:
            kwargs["api_key"] = api_key
        if encoding_format and model_info.provider == EmbeddingProvider.COHERE:
            if (
                not model_info.supported_encoding_formats
                or encoding_format not in model_info.supported_encoding_formats
            ):
                raise ValueError(
                    f"Encoding format {encoding_format} not supported for model {model}"
                )
            # Send the plain string value, matching get_multiple_text_embeddings;
            # a caller-supplied plain string is passed through unchanged.
            kwargs["encoding_format"] = getattr(encoding_format, "value", encoding_format)

        response = await aembedding(**kwargs)
        return response.data[0]["embedding"]

    except Exception as e:
        logging.error(f"Error getting embedding: {str(e)}")
        raise
+ ) + return np.array([], dtype=np.float32) + + # Process in batches + all_embeddings: List[List[float]] = [] + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + try: + # Prepare kwargs for litellm + kwargs: Dict[str, Union[str, List[str], int]] = { + "model": model, + "input": batch, + } + # Add optional parameters + if dimensions: + kwargs["dimensions"] = dimensions + if api_key: + kwargs["api_key"] = api_key + if encoding_format: + model_info = EmbeddingModels.get_model_info(model) + if model_info and model_info.provider == EmbeddingProvider.COHERE: + if ( + not model_info.supported_encoding_formats + or encoding_format not in model_info.supported_encoding_formats + ): + raise ValueError( + f"Encoding format {encoding_format} not supported for model {model}" + ) + kwargs["encoding_format"] = encoding_format.value + + # Log the request details + logging.debug(f"[DEBUG] Requesting embeddings for batch of size {len(batch)}") + logging.debug(f"[DEBUG] First text in batch (truncated): {batch[0][:100]}...") + logging.debug(f"[DEBUG] Using model: {model}") + + response: EmbeddingResponse = await aembedding(**kwargs) + batch_embeddings: List[List[float]] = [item["embedding"] for item in response.data] + all_embeddings.extend(batch_embeddings) + logging.debug(f"[DEBUG] Batch embeddings length: {len(batch_embeddings)}") + logging.debug(f"[DEBUG] First embedding sample: {batch_embeddings[0][:5]}") + # Validate embeddings + for i, emb in enumerate(batch_embeddings): + if not emb or len(emb) == 0: + raise ValueError(f"Empty embedding received for text at index {i}") + if all(v == 0 for v in emb): + raise ValueError(f"All-zero embedding received for text at index {i}") + + # Log success + logging.debug(f"Successfully processed batch of {len(batch)} texts") + + except Exception as e: + logging.error(f"Error obtaining embeddings for batch: {str(e)}") + logging.error("Batch details:") + logging.error(f"- Batch size: {len(batch)}") + 
logging.error(f"- Model: {model}") + logging.error(f"- First text (truncated): {batch[0][:100]}...") + raise # Re-raise the exception to be handled by the retry decorator + + return np.array(all_embeddings, dtype=np.float32) + + +def cosine_similarity(a: EmbeddingArray, b: EmbeddingArray) -> EmbeddingArray: + """Compute cosine similarity between two sets of vectors.""" + norm_a = np.linalg.norm(a, axis=1) + norm_b = np.linalg.norm(b, axis=1) + return np.dot(a, b.T) / np.outer(norm_a, norm_b) + + +async def find_top_k_similar_documents( + query_docs: List[Any], + candidate_docs: List[Any], + model: str, + k: int = 5, + dimensions: Optional[int] = None, + text_extractor: Optional[Callable[[Any], str]] = None, + id_extractor: Optional[Callable[[Any], Any]] = None, + api_key: Optional[str] = None, + encoding_format: Optional[CohereEncodingFormat] = None, +) -> Dict[Any, List[Dict[str, Any]]]: + """Find top k similar documents from candidate_docs for each query doc.""" + query_embeddings: EmbeddingArray = await get_multiple_text_embeddings( + query_docs, + model=model, + dimensions=dimensions, + text_extractor=text_extractor, + api_key=api_key, + encoding_format=encoding_format, + ) + candidate_embeddings: EmbeddingArray = await get_multiple_text_embeddings( + candidate_docs, + model=model, + dimensions=dimensions, + text_extractor=text_extractor, + api_key=api_key, + encoding_format=encoding_format, + ) + + similarity_matrix = cosine_similarity(query_embeddings, candidate_embeddings) + top_k_indices = np.argsort(-similarity_matrix, axis=1)[:, :k] + + top_k_similar_docs: Dict[Any, List[Dict[str, Any]]] = {} + for i, query_doc in enumerate(query_docs): + similar_docs = [ + { + "document": candidate_docs[idx], + "similarity_score": float(similarity_matrix[i][idx]), + } + for idx in top_k_indices[i] + ] + key = id_extractor(query_doc) if id_extractor else i + top_k_similar_docs[key] = similar_docs + return top_k_similar_docs diff --git a/pyspur/backend/pyspur/rag/parser.py 
def extract_text_from_file(
    file: BufferedReader,
    mimetype: str,
    vision_config: Optional[Dict[str, Any]] = None,
) -> str:
    """Extract the text content of an open binary file based on its mimetype.

    Args:
        file: Binary file object positioned at the start of the content.
        mimetype: One of ``application/pdf``, ``text/plain``,
            ``text/markdown``, ``text/csv`` or the OOXML docx/pptx types.
        vision_config: When truthy and the file is a PDF, the text is produced
            by ``extract_text_with_vision_model`` using the ``model``,
            ``api_key``, ``provider`` and ``system_prompt`` entries.

    Returns:
        The extracted text.

    Raises:
        ValueError: If the mimetype is not supported.
    """
    if vision_config and mimetype == "application/pdf":
        import concurrent.futures
        import tempfile

        # A unique temp file per call: a fixed path would be overwritten (and
        # deleted) by any other extraction running at the same time.
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_file:
            temp_file.write(file.read())
            temp_file_path = temp_file.name

        def _run_vision_model() -> str:
            # asyncio.run creates and closes its own event loop.
            return asyncio.run(
                extract_text_with_vision_model(
                    file_path=temp_file_path,
                    model=vision_config.get("model", "gpt-4o-mini"),
                    api_key=vision_config.get("api_key"),
                    provider=vision_config.get("provider"),
                    system_prompt=vision_config.get("system_prompt"),
                )
            )

        try:
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No loop is running in this thread: run the coroutine here.
                extracted_text = _run_vision_model()
            else:
                # This sync function is also called from coroutines; asyncio.run
                # refuses to start inside a running loop, so use a worker thread.
                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                    extracted_text = pool.submit(_run_vision_model).result()
        finally:
            # Clean up temporary file
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)

        return extracted_text

    # Plain (non-vision) text extraction
    if mimetype == "application/pdf":
        # Extract text from pdf using pypdf's PdfReader; a page that yields no
        # text contributes an empty string.
        reader = PdfReader(file)
        extracted_text = " ".join([page.extract_text() or "" for page in reader.pages])
    elif mimetype == "text/plain" or mimetype == "text/markdown":
        # Read text from plain text file
        extracted_text = file.read().decode("utf-8")
    elif mimetype == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
        # Extract text from docx using docx2txt
        extracted_text = docx2txt.process(file)
    elif mimetype == "text/csv":
        # One output line per CSV row, cells separated by single spaces
        decoded_lines = (line.decode("utf-8") for line in file)
        extracted_text = "".join(" ".join(row) + "\n" for row in csv.reader(decoded_lines))
    elif mimetype == "application/vnd.openxmlformats-officedocument.presentationml.presentation":
        # Extract text from pptx using python-pptx: every text run followed by
        # a space, and a newline after each slide
        extracted_text = ""
        presentation = pptx.Presentation(file)
        for slide in presentation.slides:
            for shape in slide.shapes:
                if shape.has_text_frame:
                    for paragraph in shape.text_frame.paragraphs:
                        for run in paragraph.runs:
                            extracted_text += run.text + " "
            extracted_text += "\n"
    else:
        # Unsupported file type
        raise ValueError("Unsupported file type: {}".format(mimetype))

    return extracted_text
file.read() + + temp_file_path = "/tmp/temp_file" + + # write the file to a temporary location + with open(temp_file_path, "wb") as f: + f.write(file_stream) + + try: + extracted_text = extract_text_from_filepath(temp_file_path, mimetype) + except Exception as e: + logger.error(e) + os.remove(temp_file_path) + raise e + + # remove file from temp location + os.remove(temp_file_path) + + return extracted_text + + +async def extract_text_with_vision_model( + file_path: str, + model: str = "gpt-4o-mini", + api_key: Optional[str] = None, + provider: Optional[str] = None, + system_prompt: Optional[str] = None, +) -> str: + """Extract text from a document using vision models via pyzerox.""" + kwargs: Dict[str, Any] = {} + + # Set up environment variables based on provider + if provider == "openai" and api_key: + os.environ["OPENAI_API_KEY"] = api_key + elif provider == "azure" and api_key: + os.environ["AZURE_API_KEY"] = api_key + elif provider == "gemini" and api_key: + os.environ["GEMINI_API_KEY"] = api_key + elif provider == "anthropic" and api_key: + os.environ["ANTHROPIC_API_KEY"] = api_key + elif provider == "vertex_ai" and api_key: + kwargs = {"vertex_credentials": api_key} + + try: + # Process the document with zerox + result = await zerox( + file_path=file_path, + model=model, + output_dir="/tmp/zerox_output", # Temporary output directory + custom_system_prompt=system_prompt, + cleanup=True, # Clean up temporary files + **kwargs, + ) + return str(result) + except Exception as e: + logger.error(f"Error in vision model processing: {e}") + raise e diff --git a/pyspur/backend/pyspur/rag/reranker.py b/pyspur/backend/pyspur/rag/reranker.py new file mode 100644 index 0000000000000000000000000000000000000000..e785044726d6e2c8e285bc8677093f83b5e0939f --- /dev/null +++ b/pyspur/backend/pyspur/rag/reranker.py @@ -0,0 +1,142 @@ +import logging +from enum import Enum +from typing import Any, Callable, Dict, List, Optional + +from litellm import arerank +from pydantic import 
class RerankerModels(str, Enum):
    """Model ids of the rerankers registered in this module.

    Members are str-valued; ``get_model_info`` maps a member's value to its
    ``RerankerModelConfig``.
    """

    # Cohere Models
    COHERE_RERANK_ENGLISH = "cohere/rerank-english-v3.0"
    COHERE_RERANK_MULTILINGUAL = "cohere/rerank-multilingual-v3.0"

    @classmethod
    def get_model_info(cls, model_id: str) -> Optional[RerankerModelConfig]:
        """Look up the configuration registered for ``model_id``.

        Args:
            model_id: Model id string, e.g. ``"cohere/rerank-english-v3.0"``.

        Returns:
            The matching ``RerankerModelConfig``, or ``None`` when the id is not
            in the registry (the lookup uses ``dict.get``).
        """
        # The registry is rebuilt on every call; it only holds the models above.
        model_registry = {
            # Cohere Models
            cls.COHERE_RERANK_ENGLISH.value: RerankerModelConfig(
                id=cls.COHERE_RERANK_ENGLISH.value,
                provider=RerankerProvider.COHERE,
                name="Cohere Rerank English V3",
                max_input_length=8191,
            ),
            cls.COHERE_RERANK_MULTILINGUAL.value: RerankerModelConfig(
                id=cls.COHERE_RERANK_MULTILINGUAL.value,
                provider=RerankerProvider.COHERE,
                name="Cohere Rerank Multilingual V3",
                max_input_length=8191,
            ),
        }
        return model_registry.get(model_id)
async def get_top_n_relevant_documents(
    query: str,
    documents: List[Any],
    model: str,
    top_n: int = 3,
    text_extractor: Optional[Callable[[Any], str]] = None,
    id_extractor: Optional[Callable[[Any], Any]] = None,
    api_key: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Rerank ``documents`` against ``query`` and return the best matches.

    Args:
        query: Query the documents are ranked against.
        documents: Candidate documents.
        model: Reranker model id.
        top_n: Maximum number of results to return.
        text_extractor: Optional callable turning a document into its text.
        id_extractor: Optional callable turning a document into an identifier;
            when given, each result gets an ``"id"`` entry.
        api_key: Optional API key forwarded to the reranking call.

    Returns:
        The result dicts produced by ``rerank_documents_by_query``, each
        extended with ``"id"`` when ``id_extractor`` is provided.
    """
    ranked = await rerank_documents_by_query(
        query=query,
        documents=documents,
        model=model,
        top_n=top_n,
        text_extractor=text_extractor,
        api_key=api_key,
    )

    if not id_extractor:
        return ranked

    # Annotate every hit in place with the identifier of its source document.
    for entry in ranked:
        entry["id"] = id_extractor(entry["document"])
    return ranked
class DocumentChunkMetadataSchema(DocumentMetadataSchema):
    """Metadata for a document chunk.

    Inherits every field of ``DocumentMetadataSchema`` and adds the fields that
    tie a chunk to its parent document.
    """

    # Identifier of the document the chunk belongs to.
    document_id: Optional[str] = None
    # presumably the chunk's position within its parent document — TODO confirm
    chunk_index: Optional[int] = None
    # Redeclared so the default is a fresh empty dict per instance rather than
    # the parent's default of None.
    custom_metadata: Optional[Dict[str, str]] = Field(default_factory=dict)
async def _call_progress(
    on_progress: Optional[Callable[[float, str, int, int], Coroutine[Any, Any, None]]],
    progress: float,
    stage: str,
    processed_chunks: int,
    total_chunks: int,
) -> None:
    """Await the progress callback if one was supplied; otherwise do nothing.

    Args:
        on_progress: Async callback, or ``None`` to skip reporting.
        progress: Progress value forwarded to the callback.
        stage: Stage name forwarded to the callback.
        processed_chunks: Processed-chunk count forwarded to the callback.
        total_chunks: Total-chunk count forwarded to the callback.
    """
    if not on_progress:
        return
    await on_progress(progress, stage, processed_chunks, total_chunks)
Path(f"data/vector_indices/{index_id}") + self.embeddings_dir = self.base_dir / "embeddings" + self.config_path = self.base_dir / "config.json" + + # Create base directory + self.base_dir.mkdir(parents=True, exist_ok=True) + self.embeddings_dir.mkdir(exist_ok=True) + + # Load or create config + self.config = self._load_config() + + def _load_config(self) -> Dict[str, Any]: + """Load vector index configuration.""" + if self.config_path.exists(): + with open(self.config_path) as f: + return json.load(f) + return {} + + def _save_config(self) -> None: + """Save vector index configuration.""" + with open(self.config_path, "w") as f: + json.dump(self.config, f, indent=2) + + def update_config(self, config: Dict[str, Any]) -> None: + """Update vector index configuration.""" + self.config.update(config) + self._save_config() + + async def create_from_document_collection( + self, + docs_with_chunks: List[DocumentWithChunksSchema], + config: Dict[str, Any], + on_progress: Optional[Callable[[float, str, int, int], Coroutine[Any, Any, None]]] = None, + ) -> str: + """Create a vector index from a document collection. 
+ + Args: + docs_with_chunks: List of documents with their chunks + config: Configuration for processing + on_progress: Async callback for progress updates + + Returns: + str: Vector index ID + + """ + try: + # Update config + self.update_config(config) + + # Get all chunks + all_chunks: List[DocumentChunkSchema] = [] + for doc in docs_with_chunks: + all_chunks.extend(doc.chunks) + + if not all_chunks: + logger.warning("No chunks found to process") + return self.index_id + + # Initialize progress + await _call_progress(on_progress, 0.0, "embedding", 0, len(all_chunks)) + + # Get chunk texts + chunk_texts = [chunk.text for chunk in all_chunks] + + try: + # Use OpenAI's text-embedding-3-small by default + embedding_model = config.get( + "model", + EmbeddingModels.TEXT_EMBEDDING_3_SMALL.value, + ) + model_info = EmbeddingModels.get_model_info(embedding_model) + if not model_info: + raise ValueError(f"Unknown embedding model: {embedding_model}") + + logger.debug( + f"Using embedding model: {embedding_model} with {model_info.dimensions} dimensions" + ) + + # Report starting embeddings phase + await _call_progress( + on_progress, + 0.0, + "embedding", + 0, # processed_chunks + len(all_chunks), # total_chunks + ) + + embeddings: Sequence[ + Union[List[float], np.ndarray] + ] = await get_multiple_text_embeddings( + docs=chunk_texts, + model=embedding_model, + dimensions=model_info.dimensions, + batch_size=config.get("embeddings_batch_size", 128), + api_key=config.get("openai_api_key"), + ) + + logger.debug(f"[DEBUG] Embeddings generated: {embeddings}.") + except Exception as e: + logger.error(f"Error generating embeddings: {str(e)}") + raise ProcessingError(f"Failed to generate embeddings: {str(e)}") + + # Update chunks with embeddings + processed_chunks = 0 + for i, chunk in enumerate(all_chunks): + if embeddings[i] is None: + logger.error(f"No embedding generated for chunk {i}") + continue + + # Convert embedding to list of floats + try: + embedding_list = ( + 
embeddings[i].tolist() + if hasattr(embeddings[i], "tolist") + else embeddings[i] + ) + embedding_list = [float(x) for x in embedding_list] + chunk.embedding = embedding_list + + # Save embeddings + doc_id = chunk.metadata.document_id + if doc_id is not None: + emb_path = self.embeddings_dir / f"{doc_id}_{i}.json" + with open(emb_path, "w") as f: + json.dump( + { + "chunk_id": chunk.id, + "embedding": embedding_list, + }, + f, + ) + processed_chunks += 1 + except Exception as e: + logger.error(f"Error converting embedding: {str(e)}") + continue + + # Update progress for embedding phase (0-70%) + await _call_progress( + on_progress, + (i + 1) / len(all_chunks) * 0.7, + "embedding", + processed_chunks, # processed_chunks + len(all_chunks), # total_chunks + ) + + # Report starting vector store upload + await _call_progress( + on_progress, + 0.7, + "uploading", + processed_chunks, # processed_chunks + len(all_chunks), # total_chunks + ) + + # Initialize datastore + datastore = await get_datastore(config["vector_db"], embedding_model=embedding_model) + logger.debug("Datastore initialized, starting to upsert chunks.") + + # Insert chunks into datastore + await datastore.upsert( + cast(List[DocumentSchema], docs_with_chunks), + chunk_token_size=config.get("chunk_token_size", 200), + ) + logger.debug("All chunks successfully upserted into datastore.") + + # Update progress for completion + await _call_progress( + on_progress, + 1.0, + "completed", + processed_chunks, # processed_chunks + len(all_chunks), # total_chunks + ) + + return self.index_id + + except Exception as e: + logger.error(f"Error occurred during processing: {e}") + raise ProcessingError(f"Error processing documents: {str(e)}") + + def get_config(self) -> Dict[str, Any]: + """Get the current vector index configuration.""" + return self.config.copy() + + def get_status(self) -> Dict[str, Any]: + """Get the current status of the vector index.""" + return { + "id": self.index_id, + "has_embeddings": 
self.embeddings_dir.exists() and any(self.embeddings_dir.iterdir()), + "config": self.get_config(), + } + + async def delete(self) -> bool: + """Delete the vector index and its data.""" + try: + # Initialize datastore + datastore = await get_datastore( + self.config["vector_db"], + embedding_model=self.config.get("model"), + ) + + # Delete vectors from vector database + await datastore.delete( + filter=DocumentMetadataFilterSchema( + document_id=self.index_id, + ), + delete_all=False, + ) + + # Delete files from filesystem + if self.base_dir.exists(): + import shutil + + shutil.rmtree(self.base_dir) + + return True + except Exception as e: + logger.error(f"Error deleting vector index: {e}") + return False + + async def retrieve( + self, + query: str, + top_k: int = 5, + score_threshold: Optional[float] = None, + semantic_weight: Optional[float] = 1.0, + keyword_weight: Optional[float] = None, + ) -> List[Dict[str, Any]]: + """Retrieve relevant documents from the vector index. + + Args: + query: The search query + top_k: Number of results to return + score_threshold: Minimum similarity score threshold + semantic_weight: Weight for semantic search (0 to 1) + keyword_weight: Weight for keyword search (0 to 1) + + Returns: + List of documents with their similarity scores + + """ + try: + # Get embedding model from config + embedding_model = self.config.get("embedding_config", {}).get("model") + if not embedding_model: + raise ValueError("No embedding model specified in vector index configuration") + + # Initialize datastore + datastore = await get_datastore( + self.config["vector_db"], embedding_model=embedding_model + ) + + # Get embedding for query + query_embedding = await get_single_text_embedding( + text=query, + model=embedding_model, + api_key=self.config.get("openai_api_key"), + ) + + # Create query with embedding + query_with_embedding = QueryWithEmbeddingSchema( + query=query, + embedding=query_embedding, + top_k=top_k, + ) + + # Query the datastore + results 
= await datastore.query([query_with_embedding]) + + if not results or not results[0].results: + return [] + + # Format results + formatted_results = [] + for result in results[0].results: + formatted_results.append( + { + "chunk": result, + "score": result.score, + "metadata": result.metadata.model_dump() if result.metadata else {}, + } + ) + + return formatted_results + + except Exception as e: + logger.error(f"Error retrieving from vector index: {e}") + raise diff --git a/pyspur/backend/pyspur/schemas/__pycache__/workflow_schemas.cpython-312.pyc b/pyspur/backend/pyspur/schemas/__pycache__/workflow_schemas.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4257d534987555981ada5f2b1094731e3d49600f Binary files /dev/null and b/pyspur/backend/pyspur/schemas/__pycache__/workflow_schemas.cpython-312.pyc differ diff --git a/pyspur/backend/pyspur/schemas/dataset_schemas.py b/pyspur/backend/pyspur/schemas/dataset_schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..af15d95962d79130ef9e9792e87f27f6c64f45ce --- /dev/null +++ b/pyspur/backend/pyspur/schemas/dataset_schemas.py @@ -0,0 +1,17 @@ +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel + + +class DatasetResponseSchema(BaseModel): + id: str + name: str + description: Optional[str] + filename: str + created_at: datetime + updated_at: datetime + + +class DatasetListResponseSchema(BaseModel): + datasets: list[DatasetResponseSchema] diff --git a/pyspur/backend/pyspur/schemas/eval_schemas.py b/pyspur/backend/pyspur/schemas/eval_schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..3d5095eb28172056e832fd1da4ef3658f0141701 --- /dev/null +++ b/pyspur/backend/pyspur/schemas/eval_schemas.py @@ -0,0 +1,29 @@ +from datetime import datetime +from enum import Enum +from typing import Any, Dict, Optional + +from pydantic import BaseModel + + +class EvalRunRequest(BaseModel): + workflow_id: str + eval_name: 
class NodeTypeSchema(BaseModel):
    """Describes a node type by name and by where its class can be imported."""

    node_type_name: str
    class_name: str
    module: str

    @property
    def node_class(self):
        """Import ``module`` (relative to ``pyspur``) and resolve ``class_name`` in it.

        ``class_name`` may be a dotted path; each part is resolved with
        ``getattr`` starting from the imported module.
        """
        target = importlib.import_module(name=self.module, package="pyspur")
        for attribute in self.class_name.split("."):
            target = getattr(target, attribute)
        return target

    @property
    def input_model(self):
        """Return the ``input_model`` attribute of the node class."""
        return self.node_class.input_model

    @property
    def display_name(self) -> str:
        """Return the node class's display name, or its class name when unset."""
        resolved = self.node_class
        return resolved.display_name or resolved.__name__

    @property
    def logo(self) -> str:
        """Return the node class's logo, or an empty string when unset."""
        return self.node_class.logo or ""

    @property
    def category(self) -> str:
        """Return the node class's category, or an empty string when unset."""
        return self.node_class.category or ""

    @property
    def config_title(self) -> str:
        """Return the title used for the config, which is the display name."""
        return self.display_name
class TextProcessingConfigSchema(BaseModel):
    """Text-processing options for a document collection (chunking and parsing)."""

    chunk_token_size: int = 200  # Default value from original chunker
    min_chunk_size_chars: int = 350  # Default value from original chunker
    min_chunk_length_to_embed: int = 5  # Default value from original chunker
    embeddings_batch_size: int = 128  # Default value from original chunker
    max_num_chunks: int = 10000  # Default value from original chunker
    use_vision_model: bool = False  # Whether to use vision model for PDF parsing
    vision_model: Optional[str] = None  # Model to use for vision-based parsing
    vision_provider: Optional[str] = None  # Provider for vision model
    template: Optional[TemplateSchema] = TemplateSchema()

    def get_vision_config(self) -> Optional[Dict[str, Any]]:
        """Get vision configuration with API key if vision model is enabled.

        Returns:
            ``None`` when vision parsing is disabled or the model/provider is not
            set; otherwise a dict with the keys ``model``, ``provider`` and
            ``api_key``.

        Raises:
            HTTPException: 400 when no API key can be resolved for the provider
                (unknown provider, or its environment variable is unset/empty).
        """
        if not self.use_vision_model or not self.vision_model or not self.vision_provider:
            return None

        # Environment variable holding the key for each provider. The azure and
        # gemini names are the ones pyspur/rag/parser.py exports for those
        # providers; previously only openai and anthropic were resolved, so the
        # other two always failed with "Missing API key".
        env_var_by_provider = {
            "openai": "OPENAI_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "azure": "AZURE_API_KEY",
            "gemini": "GEMINI_API_KEY",
        }
        env_var = env_var_by_provider.get(self.vision_provider)
        api_key = os.getenv(env_var) if env_var else None

        if not api_key:
            raise HTTPException(
                status_code=400,
                detail=f"Missing API key for vision provider {self.vision_provider}",
            )

        return {
            "model": self.vision_model,
            "provider": self.vision_provider,
            "api_key": api_key,
        }
class ComparisonOperator(str, Enum):
    """Operators available to a route condition rule.

    Members are str-valued, so each compares equal to its snake_case string.
    How each operator is evaluated is defined by the code that consumes the
    rules, not here.
    """

    CONTAINS = "contains"
    EQUALS = "equals"
    GREATER_THAN = "greater_than"
    LESS_THAN = "less_than"
    STARTS_WITH = "starts_with"
    NOT_STARTS_WITH = "not_starts_with"
    IS_EMPTY = "is_empty"
    IS_NOT_EMPTY = "is_not_empty"
    # NOTE(review): presumably a numeric comparison as opposed to EQUALS —
    # confirm against the router node implementation.
    NUMBER_EQUALS = "number_equals"
class RunResponseSchema(BaseModel):
    """Response schema for a workflow run, with derived duration and progress."""

    id: str
    workflow_id: str
    workflow_version_id: Optional[str] = None
    workflow_version: Optional[WorkflowVersionResponseSchema] = None
    status: RunStatus
    start_time: datetime
    end_time: Optional[datetime] = None
    initial_inputs: Optional[Dict[str, Dict[str, Any]]] = None
    outputs: Optional[Dict[str, Dict[str, Any]]] = None
    tasks: List[TaskResponseSchema] = []
    parent_run_id: Optional[str] = None
    run_type: str = "interactive"
    output_file_id: Optional[str] = None
    input_dataset_id: Optional[str] = None
    message: Optional[str] = None  # Free-form additional info about the run

    @computed_field
    def duration(self) -> Optional[float]:
        """Seconds elapsed from ``start_time`` to ``end_time``.

        While the run has no ``end_time`` the current time (in the start time's
        timezone) is used instead. Returns ``None`` without a start time.
        """
        started = self.start_time
        if not started:
            return None
        finished = self.end_time or datetime.now(started.tzinfo)
        return (finished - started).total_seconds()

    @computed_field(return_type=float)
    def percentage_complete(self):
        """Share of tasks whose status is COMPLETED, as a percentage (0 if no tasks)."""
        task_count = len(self.tasks)
        if task_count == 0:
            return 0
        completed = [task for task in self.tasks if task.status == TaskStatus.COMPLETED]
        return len(completed) / task_count * 100

    class Config:
        from_attributes = True
Optional[str] = None + + +class SessionUpdate(SessionBase): + pass + + +class SessionResponse(SessionBase): + id: str + user_id: str + workflow_id: str + created_at: datetime + updated_at: datetime + messages: List[MessageResponse] + model_config = { + "from_attributes": True, + } + + +class SessionListResponse(BaseModel): + sessions: List[SessionResponse] + total: int diff --git a/pyspur/backend/pyspur/schemas/slack_schemas.py b/pyspur/backend/pyspur/schemas/slack_schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..2e2ae9432697764bfea1db5b23327c26676165f2 --- /dev/null +++ b/pyspur/backend/pyspur/schemas/slack_schemas.py @@ -0,0 +1,207 @@ +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +class SlackAgentBase(BaseModel): + name: str + slack_team_id: Optional[str] = None # Will be populated upon Slack connection + slack_team_name: Optional[str] = None # Will be populated upon Slack connection + slack_channel_id: Optional[str] = None + slack_channel_name: Optional[str] = None + is_active: bool = True + spur_type: str = "workflow" # "spur-web", "spur-chat", etc. + + def model_dump(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: + return super().model_dump(*args, **kwargs) + + +class SlackAgentCreate(SlackAgentBase): + """Request schema for creating a Slack agent. + Note: slack_team_id and slack_team_name will be set automatically based on the + currently configured Slack token if not provided. 
+ """ + + workflow_id: str # Workflow to associate with this agent + trigger_on_mention: bool = True + trigger_on_direct_message: bool = True + trigger_on_channel_message: bool = False + trigger_keywords: Optional[List[str]] = None + trigger_enabled: bool = True + has_bot_token: bool = False + has_user_token: bool = False + has_app_token: bool = False + last_token_update: Optional[str] = None + + +class SlackAgentUpdate(BaseModel): + """Request schema for updating a Slack agent.""" + + name: Optional[str] = None + slack_team_id: Optional[str] = None + slack_team_name: Optional[str] = None + slack_channel_id: Optional[str] = None + slack_channel_name: Optional[str] = None + is_active: Optional[bool] = None + workflow_id: Optional[str] = None + trigger_on_mention: Optional[bool] = None + trigger_on_direct_message: Optional[bool] = None + trigger_on_channel_message: Optional[bool] = None + trigger_keywords: Optional[List[str]] = None + trigger_enabled: Optional[bool] = None + has_bot_token: Optional[bool] = None + has_user_token: Optional[bool] = None + has_app_token: Optional[bool] = None + spur_type: Optional[str] = None + + def model_dump(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: + return super().model_dump(*args, **kwargs) + + +class SlackAgentResponse(SlackAgentBase): + """Response schema for Slack agent information.""" + + id: int + workflow_id: Optional[str] = None + trigger_on_mention: bool + trigger_on_direct_message: bool + trigger_on_channel_message: bool + trigger_keywords: Optional[List[str]] = None + trigger_enabled: bool + has_bot_token: bool = False + has_user_token: bool = False + has_app_token: bool = False + last_token_update: Optional[str] = None + + model_config = {"from_attributes": True} + + +class SlackDirectTokenConfig(BaseModel): + """Request schema for configuring Slack with a direct token.""" + + bot_token: str + description: Optional[str] = "Manually configured Slack bot token" + + +class SlackMessage(BaseModel): + """Schema for 
Slack messages.""" + + channel: str + text: str + + +class WorkflowAssociation(BaseModel): + """Schema for associating a workflow with a Slack agent.""" + + workflow_id: str + + +class SlackTriggerConfig(BaseModel): + """Configuration schema for Slack triggers.""" + + trigger_on_mention: bool + trigger_on_direct_message: bool + trigger_on_channel_message: bool + trigger_keywords: List[str] + trigger_enabled: bool + + +class WorkflowTriggerRequest(BaseModel): + """Request schema for triggering a workflow from Slack.""" + + text: str + channel_id: str + user_id: str + team_id: str + event_type: str + event_data: Dict[str, Any] + + +class SlackOAuthResponse(BaseModel): + """Response schema for Slack OAuth callback.""" + + success: bool + message: str + team_name: Optional[str] = None + + +class SlackMessageResponse(BaseModel): + """Response schema for sending a message to Slack.""" + + success: bool + message: str + ts: Optional[str] = None + + +class WorkflowTriggerResult(BaseModel): + """Single result for a triggered workflow.""" + + agent_id: int + workflow_id: str + status: str # "triggered", "skipped", "error" + run_id: Optional[str] = None + error: Optional[str] = None + + +class WorkflowTriggersResponse(BaseModel): + """Response schema for triggering workflows from Slack.""" + + triggered_workflows: List[WorkflowTriggerResult] + + +class TemplateWorkflowResponse(BaseModel): + """Response schema for creating a template Slack workflow.""" + + id: str + name: str + description: str + message: str + + +# New schemas for agent token management +class AgentTokenRequest(BaseModel): + """Request schema for token management of Slack agents. + Supports bot_token, user_token, and app_token types. + The app_token is required for Socket Mode connections. + """ + + token: str + token_type: Optional[str] = Field( + default=None, description="Optional. Token type provided via URL." 
+ ) + + +class AgentTokenResponse(BaseModel): + """Response schema for token management operations. + Supports bot_token, user_token, and app_token types. + """ + + agent_id: int + token_type: str + masked_token: str + updated_at: Optional[str] = None + + +class SlackSocketModeResponse(BaseModel): + """Response schema for Socket Mode operations. + + Socket Mode requires both a bot token and an app token to function. + It enables real-time message handling without exposing a public HTTP endpoint. + """ + + agent_id: int + socket_mode_active: bool + message: str + + +class SlackSocketModeConfig(BaseModel): + """Configuration schema for Socket Mode settings. + + Socket Mode is a Slack feature that allows your app to receive events through a WebSocket connection + instead of using HTTP endpoints, which is useful for local development or environments + behind firewalls. It requires both a bot token and an app-level token. + """ + + enabled: bool = True + app_token: Optional[str] = None + use_global_app_token: bool = False # Whether to use the global SLACK_APP_TOKEN env variable diff --git a/pyspur/backend/pyspur/schemas/task_schemas.py b/pyspur/backend/pyspur/schemas/task_schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..1dd8f44373456ab7160f427ed241eeb97d643c76 --- /dev/null +++ b/pyspur/backend/pyspur/schemas/task_schemas.py @@ -0,0 +1,25 @@ +from datetime import datetime +from typing import Any, Dict, Optional + +from pydantic import BaseModel + +from ..models.task_model import TaskStatus +from .workflow_schemas import WorkflowDefinitionSchema + + +class TaskResponseSchema(BaseModel): + id: str + run_id: str + node_id: str + parent_task_id: Optional[str] + status: TaskStatus + inputs: Optional[Any] + outputs: Optional[Any] + error: Optional[str] + start_time: Optional[datetime] + end_time: Optional[datetime] + subworkflow: Optional[WorkflowDefinitionSchema] + subworkflow_output: Optional[Dict[str, Any]] + + class Config: + from_attributes 
= True # Enable ORM mode diff --git a/pyspur/backend/pyspur/schemas/user_schemas.py b/pyspur/backend/pyspur/schemas/user_schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1776ad5a63f16585fabe98bcbe45e4bb37330d --- /dev/null +++ b/pyspur/backend/pyspur/schemas/user_schemas.py @@ -0,0 +1,34 @@ +from datetime import datetime +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +class UserBase(BaseModel): + external_id: str = Field(..., description="External ID for the user") + user_metadata: Dict[str, Any] = Field( + default_factory=dict, description="Additional user metadata" + ) + + +class UserCreate(UserBase): + pass + + +class UserUpdate(BaseModel): + external_id: Optional[str] = Field(None, description="External ID for the user") + user_metadata: Optional[Dict[str, Any]] = Field(None, description="Additional user metadata") + + +class UserResponse(UserBase): + id: str = Field(..., description="Internal ID with prefix (e.g. U1)") + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True + + +class UserListResponse(BaseModel): + users: List[UserResponse] + total: int = Field(..., description="Total number of users") diff --git a/pyspur/backend/pyspur/schemas/workflow_schemas.py b/pyspur/backend/pyspur/schemas/workflow_schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..4d769f52b40f345aaedb1cf449e5c8f8f1a43917 --- /dev/null +++ b/pyspur/backend/pyspur/schemas/workflow_schemas.py @@ -0,0 +1,273 @@ +import json +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, field_validator, model_validator +from typing_extensions import Self + +from ..utils.pydantic_utils import json_schema_to_model + + +class SpurType(str, Enum): + """Enum representing the type of spur. 
+ + Workflow: Standard workflow with nodes and edges + Chatbot: Essentially a workflow with chat compatible IO and session management + Agent: Autonomous agent node that calls tools, also has chat compatible IO + and session management + """ + + WORKFLOW = "workflow" + CHATBOT = "chatbot" + AGENT = "agent" + + +class WorkflowNodeCoordinatesSchema(BaseModel): + """Coordinates for a node in a workflow.""" + + x: float + y: float + + +class WorkflowNodeDimensionsSchema(BaseModel): + """Dimensions for a node in a workflow.""" + + width: float + height: float + + +class WorkflowNodeSchema(BaseModel): + """A single step in a workflow. + + Each node receives a dictionary mapping predecessor node IDs to their outputs. + For dynamic schema nodes, the output schema is defined in the config dictionary. + For static schema nodes, the output schema is defined in the node class implementation. + """ + + id: str # ID in the workflow + title: str = "" # Display name + parent_id: Optional[str] = None # ID of the parent node + node_type: str # Name of the node type + config: Dict[ + str, Any + ] = {} # Configuration parameters including dynamic output schema if needed + coordinates: Optional[WorkflowNodeCoordinatesSchema] = ( + None # Position of the node in the workflow + ) + dimensions: Optional[WorkflowNodeDimensionsSchema] = ( + None # Dimensions of the node in the workflow + ) + subworkflow: Optional["WorkflowDefinitionSchema"] = None # Sub-workflow definition + + @model_validator(mode="after") + def default_title_to_id(self) -> Self: + if self.title.strip() == "": + self.title = self.id + return self + + @model_validator(mode="after") + def prefix_model_name_with_provider(self) -> Self: + # We need this to handle spurs created earlier than the prefixing change + if self.node_type in ("SingleLLMCallNode", "BestOfNNode"): + llm_info = self.config.get("llm_info") + assert llm_info is not None + if ( + llm_info["model"].startswith("gpt") + or 
llm_info["model"].startswith("chatgpt") + or llm_info["model"].startswith("o1") + ): + llm_info["model"] = f"openai/{llm_info['model']}" + if llm_info["model"].startswith("claude"): + llm_info["model"] = f"anthropic/{llm_info['model']}" + return self + + +class WorkflowLinkSchema(BaseModel): + """Connect a source node to a target node. + + The target node will receive the source node's output in its input dictionary. + """ + + source_id: str + target_id: str + source_handle: Optional[str] = None # The output handle from the source node + target_handle: Optional[str] = None # The input handle on the target node + + +class WorkflowDefinitionSchema(BaseModel): + """A workflow is a DAG of nodes.""" + + nodes: List[WorkflowNodeSchema] + links: List[WorkflowLinkSchema] + test_inputs: List[Dict[str, Any]] = [] + spur_type: SpurType = SpurType.WORKFLOW + + @classmethod + @field_validator("nodes") + def nodes_must_have_unique_ids(cls, v: List[WorkflowNodeSchema]): + node_ids = [node.id for node in v] + if len(node_ids) != len(set(node_ids)): + raise ValueError("Node IDs must be unique.") + return v + + @classmethod + @field_validator("nodes") + def must_have_one_and_only_one_input_node(cls, v: List[WorkflowNodeSchema]): + input_nodes = [ + node for node in v if node.node_type == "InputNode" and node.parent_id is None + ] + if len(input_nodes) != 1: + raise ValueError("Workflow must have exactly one input node.") + return v + + @classmethod + @field_validator("nodes") + def must_have_at_most_one_output_node(cls, v: List[WorkflowNodeSchema]): + output_nodes = [ + node for node in v if node.node_type == "OutputNode" and node.parent_id is None + ] + if len(output_nodes) > 1: + raise ValueError("Workflow must have at most one output node.") + return v + + @model_validator(mode="after") + def validate_router_node_links(self) -> Self: + """Validate links connected to RouterNodes. + + They must have correctly formatted target handles. 
+ For RouterNodes, the target handle should match the format: source_node_id.handle_id + """ + for link in self.links: + source_node = next((node for node in self.nodes if node.id == link.source_id), None) + if source_node and source_node.node_type == "RouterNode": + target_handle = link.target_handle or link.source_id + + # If target_handle contains a dot, take only what's after the dot + if target_handle.find(".") != -1: + target_handle = target_handle.split(".")[-1] + + # Ensure it has the correct prefix + if not target_handle.startswith(f"{link.source_id}."): + link.target_handle = f"{link.source_id}.{target_handle}" + + return self + + @model_validator(mode="after") + def validate_chatbot_input_node(self) -> Self: # noqa: C901 + """Validate that chatbot workflows have the required input fields. + + For chatbot workflows, the input node must have user_message and session_id fields. + """ + if self.spur_type == SpurType.CHATBOT: + # chatbot workflows must have input node with the following fields: + # user_message, session_id + input_node = next( + ( + node + for node in self.nodes + if node.node_type == "InputNode" and node.parent_id is None + ), + None, + ) + if input_node: + try: + json_schema = json.loads(input_node.config.get("output_json_schema", "{}")) + model = json_schema_to_model(json_schema) + output_schema = model.model_fields + missing_fields: List[str] = [] + + if "user_message" not in output_schema: + missing_fields.append("user_message") + elif output_schema["user_message"].annotation is not str: + missing_fields.append("user_message (must be of type str)") + + if "session_id" not in output_schema: + missing_fields.append("session_id") + elif output_schema["session_id"].annotation is not str: + missing_fields.append("session_id (must be of type str)") + + if missing_fields: + raise ValueError( + f"Chatbot input node must have the following mandatory fields: " + f"{', '.join(missing_fields)}" + ) + except json.JSONDecodeError: + raise ValueError( 
+ "Invalid JSON schema in input node output_json_schema" + ) from None + return self + + @model_validator(mode="after") + def validate_chatbot_output_node(self) -> Self: + """Validate that chatbot workflows have the required output fields.""" + if self.spur_type == SpurType.CHATBOT: + # chatbot workflows must have output node with assistant_message field + output_node = next( + ( + node + for node in self.nodes + if node.node_type == "OutputNode" and node.parent_id is None + ), + None, + ) + if output_node: + try: + json_schema = json.loads(output_node.config.get("output_json_schema", "{}")) + model = json_schema_to_model(json_schema) + output_schema = model.model_fields + missing_fields: List[str] = [] + + if "assistant_message" not in output_schema: + missing_fields.append("assistant_message") + elif output_schema["assistant_message"].annotation is not str: + missing_fields.append("assistant_message (must be of type str)") + + if missing_fields: + raise ValueError( + f"Chatbot output node must have the following mandatory fields: " + f"{', '.join(missing_fields)}" + ) + except json.JSONDecodeError: + raise ValueError( + "Invalid JSON schema in output node output_json_schema" + ) from None + + return self + + model_config = {"from_attributes": True} + + +class WorkflowCreateRequestSchema(BaseModel): + """A request to create a new workflow.""" + + name: str + description: str = "" + definition: Optional[WorkflowDefinitionSchema] = None + + +class WorkflowResponseSchema(BaseModel): + """A response containing the details of a workflow.""" + + id: str + name: str + description: Optional[str] + definition: WorkflowDefinitionSchema + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} + + +class WorkflowVersionResponseSchema(BaseModel): + """A response containing the details of a workflow version.""" + + version: int + name: str + description: Optional[str] + definition: Any + definition_hash: str + created_at: datetime + updated_at: 
datetime + + model_config = {"from_attributes": True} diff --git a/pyspur/backend/pyspur/schemas/workflow_validation.py b/pyspur/backend/pyspur/schemas/workflow_validation.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca071274b7c98ad291e78274c3c69142801b3b8 --- /dev/null +++ b/pyspur/backend/pyspur/schemas/workflow_validation.py @@ -0,0 +1,20 @@ +from ..nodes.node_types import is_valid_node_type +from .workflow_schemas import WorkflowDefinitionSchema, WorkflowNodeSchema + + +def validate_node_type(node: WorkflowNodeSchema) -> bool: + """Validate that a node's type is supported.""" + return is_valid_node_type(node.node_type) + + +def validate_workflow_definition(workflow: WorkflowDefinitionSchema) -> bool: + """ + Validate a workflow definition. + Returns True if valid, raises ValueError if invalid. + """ + # Validate all node types are supported + for node in workflow.nodes: + if not validate_node_type(node): + raise ValueError(f"Node type '{node.node_type}' is not valid.") + + return True diff --git a/pyspur/backend/pyspur/static/.gitignore b/pyspur/backend/pyspur/static/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b4fae888df9f7eadbe7e9af16e385161fca2ff55 --- /dev/null +++ b/pyspur/backend/pyspur/static/.gitignore @@ -0,0 +1,6 @@ +# Ignore all files in this directory +* + +# Except the .gitignore file itself +!.gitignore +!robots.txt \ No newline at end of file diff --git a/pyspur/backend/pyspur/static/robots.txt b/pyspur/backend/pyspur/static/robots.txt new file mode 100644 index 0000000000000000000000000000000000000000..70c2374d7b6ba4aaca3d91abbe29d86659bdd4e5 --- /dev/null +++ b/pyspur/backend/pyspur/static/robots.txt @@ -0,0 +1,2 @@ +User-agent: * +Disallow: / diff --git a/pyspur/backend/pyspur/templates/Slack_Summarizer.json b/pyspur/backend/pyspur/templates/Slack_Summarizer.json new file mode 100644 index 0000000000000000000000000000000000000000..3cf6cdd7fccba43b48f6f84198f037ae6132f0d3 --- 
/dev/null +++ b/pyspur/backend/pyspur/templates/Slack_Summarizer.json @@ -0,0 +1,301 @@ +{ + "name": "Slack Summarizer", + "metadata": { + "name": "Slack Summarizer", + "description": "Summarize technical blog posts or research papers and share them in Slack.", + "features": ["Blog post summarization", "PDF paper analysis", "Slack integration"] + }, + "definition": { + "nodes": [ + { + "id": "input_node", + "title": "input_node", + "parent_id": null, + "node_type": "InputNode", + "config": { + "output_schema": { + "blogpost_url": "string", + "paper_pdf_file": "string" + }, + "output_json_schema": "{\n \"type\": \"object\",\n \"properties\": {\n \"blogpost_url\": {\n \"type\": \"string\"\n },\n \"paper_pdf_file\": {\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"blogpost_url\",\n \"paper_pdf_file\"\n ]\n}", + "has_fixed_output": false, + "enforce_schema": false + }, + "coordinates": { + "x": 0, + "y": 432 + }, + "dimensions": null, + "subworkflow": null + }, + { + "id": "RouterNode_1", + "title": "RouterNode_1", + "parent_id": null, + "node_type": "RouterNode", + "config": { + "title": "RouterNode_1", + "type": "object", + "output_schema": { + "output": "string" + }, + "output_json_schema": "{\"type\":\"object\",\"properties\":{\"input_node\":{\"type\":\"object\",\"properties\":{\"blogpost_url\":{\"type\":\"string\"},\"paper_pdf_file\":{\"type\":\"string\"}},\"required\":[\"blogpost_url\",\"paper_pdf_file\"]}},\"required\":[\"input_node\"],\"additionalProperties\":false}", + "has_fixed_output": false, + "route_map": { + "route1": { + "conditions": [ + { + "logicalOperator": "AND", + "operator": "is_not_empty", + "value": "", + "variable": "input_node.blogpost_url" + } + ] + }, + "route2": { + "conditions": [ + { + "variable": "input_node.paper_pdf_file", + "operator": "is_not_empty", + "value": "" + } + ] + } + } + }, + "coordinates": { + "x": 438, + "y": 0 + }, + "dimensions": { + "width": 428, + "height": 1077 + }, + "subworkflow": null + }, + { + "id": 
"FirecrawlScrapeNode_1", + "title": "FirecrawlScrapeNode_1", + "parent_id": null, + "node_type": "FirecrawlScrapeNode", + "config": { + "title": "FirecrawlScrapeNode_1", + "type": "object", + "output_schema": { + "output": "string" + }, + "output_json_schema": "{\"properties\": {\"markdown\": {\"description\": \"The scraped data in markdown format.\", \"title\": \"Markdown\", \"type\": \"string\"}}, \"required\": [\"markdown\"], \"title\": \"FirecrawlScrapeNodeOutput\", \"type\": \"object\"}", + "has_fixed_output": true, + "url_template": "{{RouterNode_1.input_node.blogpost_url}}" + }, + "coordinates": { + "x": 1004, + "y": 463.5 + }, + "dimensions": null, + "subworkflow": null + }, + { + "id": "SingleLLMCallNode_1", + "title": "KeyPointsSummarizer", + "parent_id": null, + "node_type": "SingleLLMCallNode", + "config": { + "title": "KeyPointsSummarizer", + "type": "object", + "output_schema": { + "output": "string" + }, + "output_json_schema": "{\"type\": \"object\", \"properties\": {\"output\": {\"type\": \"string\"} } }", + "has_fixed_output": false, + "llm_info": { + "model": "openai/chatgpt-4o-latest", + "max_tokens": 4096, + "temperature": 0.7, + "top_p": 0.9 + }, + "system_message": "You are a software engineer who breaks down a technical article for colleagues to read.\n\n- Use bullet points to summarize key concepts\n- If appropriate, add some humour sparingly but never force it\n- Your audience are technical software engineers or researchers. You do not need to explain basic SWE concepts to them, you can assume familiarity.\n- Your colleagues work on an AI workflow builder. 
If appropriate, you can draw a connection between the provided article and how it may inform opportunities, technical decisions or the product roadmap of the AI workflow builder.", + "user_message": "{{FirecrawlScrapeNode_1.markdown}}", + "few_shot_examples": null, + "url_variables": null + }, + "coordinates": { + "x": 1554, + "y": 463.5 + }, + "dimensions": null, + "subworkflow": null + }, + { + "id": "SingleLLMCallNode_2", + "title": "MarkdownExtractor", + "parent_id": null, + "node_type": "SingleLLMCallNode", + "config": { + "title": "MarkdownExtractor", + "type": "object", + "output_schema": { + "output": "string" + }, + "output_json_schema": "{\n \"$schema\": \"http://json-schema.org/draft-07/schema#\",\n \"type\": \"object\",\n \"properties\": {\n \"markdown\": {\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"markdown\"\n ],\n \"additionalProperties\": false\n}", + "has_fixed_output": false, + "llm_info": { + "model": "gemini/gemini-2.0-flash", + "max_tokens": 8192, + "temperature": 0.7, + "top_p": 0.9 + }, + "system_message": "You are an AI / ML researcher who converts a recent arxiv ML paper pdf into markdown as part of a JSON.\n\n- Translate any math equations into mathjax\n\n- You can skip references and appendix, they are not relevant\n\n- Return valid JSON with \"markdown\" key", + "user_message": "", + "few_shot_examples": null, + "url_variables": { + "file": "RouterNode_1.input_node.paper_pdf_file" + } + }, + "coordinates": { + "x": 1017, + "y": 741.5 + }, + "dimensions": null, + "subworkflow": null + }, + { + "id": "SingleLLMCallNode_3", + "title": "PaperSummarizer", + "parent_id": null, + "node_type": "SingleLLMCallNode", + "config": { + "title": "PaperSummarizer", + "type": "object", + "output_schema": { + "output": "string" + }, + "output_json_schema": "{\"type\": \"object\", \"properties\": {\"output\": {\"type\": \"string\"} } }", + "has_fixed_output": false, + "llm_info": { + "model": "openai/chatgpt-4o-latest", + "max_tokens": 4096, + 
"temperature": 0.7, + "top_p": 0.9 + }, + "system_message": "You are provided with markdown that summarizes a paper. I want you to summarize in the following way:\n\n- Extract exactly three main ideas and for each idea, explain the what, why, and so what\n\n- Focus on the key concepts and not on insignificant details\n\n- If the paper introduces novel methodology, put it into context to what previous methods tried and why this new method is superior\n\n- If the paper includes surprising experimental observations, explain why they are surprising\n\n- Add a little humour but only in places where it's appropriate, never forced", + "user_message": "{{MarkdownExtractor.markdown}}", + "few_shot_examples": null, + "url_variables": null + }, + "coordinates": { + "x": 1583.5, + "y": 741.5 + }, + "dimensions": null, + "subworkflow": null + }, + { + "id": "CoalesceNode_1", + "title": "CoalesceNode_1", + "parent_id": null, + "node_type": "CoalesceNode", + "config": { + "title": "CoalesceNode_1", + "type": "object", + "output_schema": { + "output": "string" + }, + "output_json_schema": "{\"type\":\"object\",\"properties\":{\"output\":{\"type\":\"string\"}},\"required\":[\"output\"],\"additionalProperties\":false}", + "has_fixed_output": false, + "preferences": [ + "SingleLLMCallNode_1|KeyPointsSummarizer", + "SingleLLMCallNode_3|PaperSummarizer" + ] + }, + "coordinates": { + "x": 2370, + "y": 675 + }, + "dimensions": null, + "subworkflow": null + }, + { + "id": "SlackNotifyNode_1", + "title": "SlackNotifyNode_1", + "parent_id": null, + "node_type": "SlackNotifyNode", + "config": { + "title": "SlackNotifyNode_1", + "type": "object", + "output_schema": { + "output": "string" + }, + "output_json_schema": "{\"properties\": {\"status\": {\"description\": \"Error message if the message was not sent successfully.\", \"title\": \"Status\", \"type\": \"string\"}}, \"required\": [\"status\"], \"title\": \"SlackNotifyNodeOutput\", \"type\": \"object\"}", + "has_fixed_output": true, + 
"channel": "learning", + "mode": "bot", + "message": "Here is your summary\n\n{{CoalesceNode_1.output}}\n\nNow back to work!!!" + }, + "coordinates": { + "x": 2890, + "y": 678.75 + }, + "dimensions": null, + "subworkflow": null + } + ], + "links": [ + { + "source_id": "input_node", + "target_id": "RouterNode_1", + "source_handle": null, + "target_handle": null + }, + { + "source_id": "RouterNode_1", + "target_id": "FirecrawlScrapeNode_1", + "source_handle": "route1", + "target_handle": "RouterNode_1.route1" + }, + { + "source_id": "FirecrawlScrapeNode_1", + "target_id": "SingleLLMCallNode_1", + "source_handle": null, + "target_handle": null + }, + { + "source_id": "RouterNode_1", + "target_id": "SingleLLMCallNode_2", + "source_handle": "route2", + "target_handle": "RouterNode_1.route2" + }, + { + "source_id": "SingleLLMCallNode_2", + "target_id": "SingleLLMCallNode_3", + "source_handle": null, + "target_handle": null + }, + { + "source_id": "SingleLLMCallNode_1", + "target_id": "CoalesceNode_1", + "source_handle": null, + "target_handle": null + }, + { + "source_id": "SingleLLMCallNode_3", + "target_id": "CoalesceNode_1", + "source_handle": null, + "target_handle": null + }, + { + "source_id": "CoalesceNode_1", + "target_id": "SlackNotifyNode_1", + "source_handle": null, + "target_handle": null + } + ], + "test_inputs": [ + { + "id": 1, + "blogpost_url": "https://blog.samaltman.com/three-observations" + } + ] + }, + "description": "" +} \ No newline at end of file diff --git a/pyspur/backend/pyspur/templates/joke_generator.json b/pyspur/backend/pyspur/templates/joke_generator.json new file mode 100644 index 0000000000000000000000000000000000000000..20ee6a6e13c318b1e46b956adfd8f3cad8e24f15 --- /dev/null +++ b/pyspur/backend/pyspur/templates/joke_generator.json @@ -0,0 +1,129 @@ +{ + "name": "Joke Generator using BoN Sampling", + "metadata": { + "name": "Joke Generator", + "description": "Generate and refine jokes using Best-of-N sampling.", + "features": ["Dark 
humor", "Audience-specific jokes", "Refinement options"] + }, + "definition": { + "nodes": [ + { + "id": "input_node", + "title": "input_node", + "node_type": "InputNode", + "config": { + "output_schema": { + "topic": "string", + "audience": "string" + } + }, + "coordinates": { + "x": 0, + "y": 0 + } + }, + { + "id": "JokeDrafter", + "title": "JokeDrafter", + "node_type": "BestOfNNode", + "config": { + "title": "JokeDrafter", + "output_schema": { + "initial_joke": "string" + }, + "llm_info": { + "model": "gpt-4o", + "max_tokens": 16384, + "temperature": 0.7, + "top_p": 0.9 + }, + "system_message": "You are a stand-up comedian who uses dark humor like Ricky Gervais or Jimmy Carr.\n\nThe user will provide you with a topic and audience, and you have to devise a short joke for that.\n\nYou can roast the person if a person is mentioned, it's only among friends.", + "user_message": "Your audience is: {{input_node.audience}}\nThe topic should be about {{input_node.topic}}", + "few_shot_examples": null, + "samples": 10, + "rating_prompt": "Rate the following joke on a scale from 0 to 10, where 0 is poor and 10 is excellent. \nConsider factors such as surprise, relatability, and punchiness. Respond with only a number.", + "rating_temperature": 0.1, + "rating_max_tokens": 16 + }, + "coordinates": { + "x": 374, + "y": 29.5 + } + }, + { + "id": "JokeRefiner", + "title": "JokeRefiner", + "node_type": "BestOfNNode", + "config": { + "title": "JokeRefiner", + "output_schema": { + "final_joke": "string" + }, + "llm_info": { + "model": "gpt-4o", + "max_tokens": 16384, + "temperature": 0.7, + "top_p": 0.9 + }, + "system_message": "Your goal is to refine a joke to make it more vulgar and concise. 
It's just among friends, so you can get roasty.\n\n- Be mean\n- Have dark humour\n- Be very punchy", + "user_message": "{{JokeDrafter.initial_joke}}", + "few_shot_examples": null, + "samples": 3, + "rating_prompt": "Rate the following response on a scale from 0 to 10, where 0 is poor and 10 is excellent. Consider factors such as relevance, coherence, and helpfulness. Respond with only a number.", + "rating_temperature": 0.1, + "rating_max_tokens": 16 + }, + "coordinates": { + "x": 750, + "y": 30 + } + }, + { + "id": "SingleShotJoke", + "title": "SingleShotJoke", + "node_type": "SingleLLMCallNode", + "config": { + "title": "SingleShotJoke", + "output_schema": { + "final_joke": "string" + }, + "llm_info": { + "model": "gpt-4o", + "max_tokens": 16384, + "temperature": 0.7, + "top_p": 0.9 + }, + "system_message": "You are a stand-up comedian who uses dark humor like Ricky Gervais or Jimmy Carr.\n\nThe user will provide you with a topic and audience, and you have to devise a short joke for that.\n\nYou can roast the person if a person is mentioned, it's only among friends.", + "user_message": "Your audience is: {{input_node.audience}}\nThe topic should be about {{input_node.topic}}", + "few_shot_examples": null + }, + "coordinates": { + "x": 374, + "y": 204.5 + } + } + ], + "links": [ + { + "source_id": "input_node", + "target_id": "JokeDrafter" + }, + { + "source_id": "JokeDrafter", + "target_id": "JokeRefiner" + }, + { + "source_id": "input_node", + "target_id": "SingleShotJoke" + } + ], + "test_inputs": [ + { + "id": 1732123761259, + "topic": "Emacs vs. 
Vim", + "audience": "Software Engineers" + } + ] + }, + "description": "" +} \ No newline at end of file diff --git a/pyspur/backend/pyspur/templates/ollama_model_comparison.json b/pyspur/backend/pyspur/templates/ollama_model_comparison.json new file mode 100644 index 0000000000000000000000000000000000000000..4742f44863a44c064ee9adc906030988f6236c16 --- /dev/null +++ b/pyspur/backend/pyspur/templates/ollama_model_comparison.json @@ -0,0 +1,128 @@ +{ + "name": "Ollama Model Comparison", + "metadata": { + "name": "Ollama Model Comparison", + "description": "Compare the performance of different Ollama models. Make sure you have Ollama installed and running.", + "features": ["Llama 3.2", "Gemma 2", "Mistral v0.3"] + }, + "definition": { + "nodes": [ + { + "id": "input_prompt", + "title": "input_prompt", + "node_type": "InputNode", + "config": { + "output_schema": { + "input_1": "str" + }, + "enforce_schema": false, + "title": "input_prompt" + }, + "coordinates": { + "x": 0, + "y": 35 + }, + "subworkflow": null + }, + { + "id": "Llama3_2", + "title": "Llama3_2", + "node_type": "SingleLLMCallNode", + "config": { + "title": "Llama3_2", + "type": "object", + "output_schema": { + "output": "str" + }, + "llm_info": { + "model": "ollama/llama3.2", + "max_tokens": 16384, + "temperature": 0.7, + "top_p": 0.9 + }, + "system_message": "You are a helpful assistant.", + "user_message": "

{{input_prompt.input_1}}

", + "few_shot_examples": null + }, + "coordinates": { + "x": 438, + "y": 0 + }, + "subworkflow": null + }, + { + "id": "Gemma2", + "title": "Gemma2", + "node_type": "SingleLLMCallNode", + "config": { + "title": "Gemma2", + "type": "object", + "output_schema": { + "output": "str" + }, + "llm_info": { + "model": "ollama/gemma2", + "max_tokens": 16384, + "temperature": 0.7, + "top_p": 0.9 + }, + "system_message": "You are a helpful assistant.", + "user_message": "

{{input_prompt.input_1}}

", + "few_shot_examples": null + }, + "coordinates": { + "x": 438, + "y": 369 + }, + "subworkflow": null + }, + { + "id": "Mistral_03", + "title": "Mistral_03", + "node_type": "SingleLLMCallNode", + "config": { + "title": "Mistral_03", + "type": "object", + "output_schema": { + "output": "str" + }, + "llm_info": { + "model": "ollama/mistral", + "max_tokens": 16384, + "temperature": 0.7, + "top_p": 0.9 + }, + "system_message": "You are a helpful assistant.", + "user_message": "

{{input_prompt.input_1}}

", + "few_shot_examples": null + }, + "coordinates": { + "x": 438, + "y": 738 + }, + "subworkflow": null + } + ], + "links": [ + { + "source_id": "input_prompt", + "target_id": "Llama3_2" + }, + { + "source_id": "input_prompt", + "target_id": "Gemma2" + }, + { + "source_id": "input_prompt", + "target_id": "Mistral_03" + } + ], + "test_inputs": [ + { + "id": 1734714471358, + "input_1": "Make a joke about peanuts." + } + ] + }, + "description": "" +} \ No newline at end of file diff --git a/pyspur/backend/pyspur/utils/__init__.py b/pyspur/backend/pyspur/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyspur/backend/pyspur/utils/__pycache__/__init__.cpython-312.pyc b/pyspur/backend/pyspur/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b583a828f788dbe9ad4dd5bd269aed06dfc0e9f6 Binary files /dev/null and b/pyspur/backend/pyspur/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/pyspur/backend/pyspur/utils/__pycache__/pydantic_utils.cpython-312.pyc b/pyspur/backend/pyspur/utils/__pycache__/pydantic_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d3c492e6884d5dd50c1274b36c4cd5876e890f7 Binary files /dev/null and b/pyspur/backend/pyspur/utils/__pycache__/pydantic_utils.cpython-312.pyc differ diff --git a/pyspur/backend/pyspur/utils/file_utils.py b/pyspur/backend/pyspur/utils/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3234a4976827ddadf44d0d9d3089d941e3d2531d --- /dev/null +++ b/pyspur/backend/pyspur/utils/file_utils.py @@ -0,0 +1,38 @@ +import base64 +import mimetypes +from pathlib import Path + + +def encode_file_to_base64_data_url(file_path: str) -> str: + """ + Read a file and encode it as a base64 data URL with the appropriate MIME type. 
+ """ + path = Path(file_path) + mime_type = mimetypes.guess_type(path)[0] or "application/octet-stream" + + with open(path, "rb") as f: + file_content = f.read() + base64_data = base64.b64encode(file_content).decode("utf-8") + return f"data:{mime_type};base64,{base64_data}" + + +def get_file_mime_type(file_path: str) -> str: + """ + Get the MIME type for a file based on its extension. + """ + mime_type = mimetypes.guess_type(file_path)[0] + if mime_type is None: + # Default MIME types for common file types + ext = Path(file_path).suffix.lower() + mime_map = { + ".pdf": "application/pdf", + ".txt": "text/plain", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".png": "image/png", + ".gif": "image/gif", + ".mp3": "audio/mpeg", + ".mp4": "video/mp4", + } + mime_type = mime_map.get(ext, "application/octet-stream") + return mime_type diff --git a/pyspur/backend/pyspur/utils/mime_types_utils.py b/pyspur/backend/pyspur/utils/mime_types_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..09b85ba5ad6cbfcdbb1c7d56996577a7d7c80d10 --- /dev/null +++ b/pyspur/backend/pyspur/utils/mime_types_utils.py @@ -0,0 +1,125 @@ +import mimetypes +from enum import Enum +from typing import Dict, List + + +class RecognisedMimeType(str, Enum): + """Recognized MIME types that LLMs may support.""" + + # Images + JPEG = "image/jpeg" + PNG = "image/png" + GIF = "image/gif" + WEBP = "image/webp" + SVG = "image/svg+xml" + + # Audio + MP3 = "audio/mpeg" + WAV = "audio/wav" + OGG_AUDIO = "audio/ogg" + WEBM_AUDIO = "audio/webm" + + # Video + MP4 = "video/mp4" + WEBM_VIDEO = "video/webm" + OGG_VIDEO = "video/ogg" + + # Documents + PDF = "application/pdf" + DOC = "application/msword" + DOCX = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + XLS = "application/vnd.ms-excel" + XLSX = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + PPT = "application/vnd.ms-powerpoint" + PPTX = 
"application/vnd.openxmlformats-officedocument.presentationml.presentation" + + # Text + PLAIN = "text/plain" + HTML = "text/html" + MARKDOWN = "text/markdown" + CSV = "text/csv" + XML = "text/xml" + JSON = "application/json" + + +class MimeCategory(str, Enum): + """Categories of MIME types that LLMs may support.""" + + IMAGES = "images" + AUDIO = "audio" + VIDEO = "video" + DOCUMENTS = "documents" + TEXT = "text" + + +# Common MIME types by category +MIME_TYPES_BY_CATEGORY: Dict[MimeCategory, List[RecognisedMimeType]] = { + MimeCategory.IMAGES: [ + RecognisedMimeType.JPEG, + RecognisedMimeType.PNG, + RecognisedMimeType.GIF, + RecognisedMimeType.WEBP, + RecognisedMimeType.SVG, + ], + MimeCategory.AUDIO: [ + RecognisedMimeType.MP3, + RecognisedMimeType.WAV, + RecognisedMimeType.OGG_AUDIO, + RecognisedMimeType.WEBM_AUDIO, + ], + MimeCategory.VIDEO: [ + RecognisedMimeType.MP4, + RecognisedMimeType.WEBM_VIDEO, + RecognisedMimeType.OGG_VIDEO, + ], + MimeCategory.DOCUMENTS: [ + RecognisedMimeType.PDF, + RecognisedMimeType.DOC, + RecognisedMimeType.DOCX, + RecognisedMimeType.XLS, + RecognisedMimeType.XLSX, + RecognisedMimeType.PPT, + RecognisedMimeType.PPTX, + ], + MimeCategory.TEXT: [ + RecognisedMimeType.PLAIN, + RecognisedMimeType.HTML, + RecognisedMimeType.MARKDOWN, + RecognisedMimeType.CSV, + RecognisedMimeType.XML, + RecognisedMimeType.JSON, + ], +} + + +class UnsupportedFileTypeError(Exception): + """Exception raised when a file type is not supported.""" + + pass + + +def get_mime_type_for_url(url: str) -> RecognisedMimeType: + """ + Get the MIME type for a given URL. + + Args: + url (str): The URL to get the MIME type for. This can be a file path, a URL, or a data URI. + + Returns: + RecognisedMimeType: The MIME type for the URL. 
+ """ + # Data URI + if url.startswith("data:"): + # Data URI + mime_type = url.split(";")[0].split(":")[1] + try: + return RecognisedMimeType(mime_type) + except ValueError: + raise UnsupportedFileTypeError(f"Unsupported data URI: {url.split(';')[0]}") + + # File path or URL + mime_type, _ = mimetypes.guess_type(url) + if mime_type: + return RecognisedMimeType(mime_type) + else: + raise UnsupportedFileTypeError(f"Unsupported file type: {url}") diff --git a/pyspur/backend/pyspur/utils/path_utils.py b/pyspur/backend/pyspur/utils/path_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e18a480124cc64033c4597be75b640fab4c03360 --- /dev/null +++ b/pyspur/backend/pyspur/utils/path_utils.py @@ -0,0 +1,31 @@ +from pathlib import Path + +PROJECT_ROOT = Path.cwd() + + +def is_external_url(url: str) -> bool: + return url.startswith(("http://", "https://", "gs://")) + + +def get_test_files_dir() -> Path: + """Get the directory for test file uploads.""" + test_files_dir = Path.joinpath(PROJECT_ROOT, "data", "test_files") + test_files_dir.mkdir(parents=True, exist_ok=True) + return test_files_dir + + +def resolve_file_path(file_path: str) -> Path | str: + """ + Resolve a file path relative to the project root. + Expects paths in format 'data/test_files/S9/20250120_121759_aialy.pdf' and resolves them to + 'data/test_files/S9/20250120_121759_aialy.pdf' + If the path is an external URL (starts with http:// or https://), returns it as is. 
+ """ + # Handle external URLs + if is_external_url(file_path): + return file_path + + path = Path.joinpath(PROJECT_ROOT, "data", Path(file_path)) + if not path.exists(): + raise FileNotFoundError(f"File not found at expected location: {file_path}") + return path diff --git a/pyspur/backend/pyspur/utils/pydantic_utils.py b/pyspur/backend/pyspur/utils/pydantic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8a4433c59eaa42e89848fae4c803e07fcfa5936e --- /dev/null +++ b/pyspur/backend/pyspur/utils/pydantic_utils.py @@ -0,0 +1,183 @@ +from typing import Any, Dict, List, Optional, Type + +from pydantic import BaseModel, Field, create_model + + +def get_nested_field(field_name_with_dots: str, model: BaseModel) -> Any: + """Get the value of a nested field from a Pydantic model.""" + field_names = field_name_with_dots.split(".") + value = model + for field_name in field_names: + if isinstance(value, dict): + return value.get(field_name, None) # type: ignore + else: + value = getattr(value, field_name) + return value + + +def get_jinja_template_for_model(model: BaseModel) -> str: + """Generate a Jinja template for a Pydantic model.""" + template = "{\n" + for field_name, _field in model.model_fields.items(): + template += f'"{field_name}": {{{{field_name}}}},\n' + template += "}" + return template + + +def json_schema_to_model( + json_schema: Dict[str, Any], + model_class_name: str = "Output", + base_class: Type[BaseModel] = BaseModel, +) -> Type[BaseModel]: + """Convert a JSON schema to a Pydantic BaseModel class. + + Args: + json_schema: The JSON schema to convert. + model_class_name: The name of the model class to create. + base_class: The base class for the model (default is BaseModel). + + Returns: + A Pydantic BaseModel class. + + """ + # Extract the model name from the schema title. 
+ model_name = model_class_name + + schema_defs = json_schema.get('$defs', {}) + defs = {name: json_schema_to_model(definition, name) for name, definition in schema_defs.items()} + + # Extract the field definitions from the schema properties. + field_definitions = { + name: json_schema_to_pydantic_field(name, prop, json_schema.get("required", []), defs) + for name, prop in json_schema.get("properties", {}).items() + } + + # Create the BaseModel class using create_model(). + return create_model(model_name, **field_definitions, __base__=base_class) + + +def json_schema_to_pydantic_field( + name: str, json_schema: Dict[str, Any], required: List[str], definitions: Dict[str, Type[BaseModel]] +) -> Any: + """Convert a JSON schema property to a Pydantic field definition. + + Args: + name: The field name. + json_schema: The JSON schema property. + required: A list of required fields. + + Returns: + A Pydantic field definition. + + """ + # Get the field type. + type_ = json_schema_to_pydantic_type(json_schema, definitions) + + # Get the field description. + description = json_schema.get("description") + + # Get the field examples. + examples = json_schema.get("examples") + + # Create a Field object with the type, description, and examples. + # The 'required' flag will be set later when creating the model. + return ( + type_, + Field( + description=description, + examples=examples, + default=... if name in required else None, + ), + ) + + +def json_schema_to_pydantic_type(json_schema: Dict[str, Any], definitions: Dict[str, Type[BaseModel]]) -> Any: + """Convert a JSON schema type to a Pydantic type. + + Args: + json_schema: The JSON schema to convert. + + Returns: + A Pydantic type. 
+ + """ + type_ = json_schema.get("type") + + if ref := json_schema.get('$ref'): + return definitions[ref.split("/")[-1]] + + if type_ == "string": + return str + elif type_ == "integer": + return int + elif type_ == "number": + return float + elif type_ == "boolean": + return bool + elif type_ == "array": + items_schema = json_schema.get("items") + if items_schema: + item_type = json_schema_to_pydantic_type(items_schema, definitions) + return List[item_type] + else: + return List + elif type_ == "object": + # Handle nested models. + properties = json_schema.get("properties") + if properties: + nested_model = json_schema_to_model(json_schema) + return nested_model + else: + return Dict + elif type_ == "null": + return Optional[Any] # Use Optional[Any] for nullable fields + elif type_ == None: + return Optional[Any] + else: + raise ValueError(f"Unsupported JSON schema type: {type_}") + + +def json_schema_to_simple_schema(json_schema: Dict[str, Any]) -> Dict[str, str]: + """Convert a JSON schema to a simple schema. + + Args: + json_schema: The JSON schema to convert. + + Returns: + A simple schema. 
+ + """ + simple_schema: Dict[str, str] = {} + + for prop, prop_details in json_schema.get("properties", {}).items(): + prop_type = prop_details.get("type") + if prop_type == "object": + simple_schema[prop] = "dict" + elif prop_type == "array": + simple_schema[prop] = "list" + elif prop_type == "integer": + simple_schema[prop] = "int" + elif prop_type == "number": + simple_schema[prop] = "float" + elif prop_type == "boolean": + simple_schema[prop] = "bool" + elif prop_type == "string": + simple_schema[prop] = "string" + else: + simple_schema[prop] = "Any" + return simple_schema + + +if __name__ == "__main__": + + class TestModel(BaseModel): + name: str + age: int + + json_schema_to_model(TestModel.model_json_schema()) + + + class TestSubModel(BaseModel): + students: List[TestModel] + + json_schema_to_model(TestSubModel.model_json_schema()) \ No newline at end of file diff --git a/pyspur/backend/pyspur/utils/redis_cache_wrapper.py b/pyspur/backend/pyspur/utils/redis_cache_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..91c94d0caabfd3008e69c7731b4817c458f8b720 --- /dev/null +++ b/pyspur/backend/pyspur/utils/redis_cache_wrapper.py @@ -0,0 +1,272 @@ +import asyncio +import atexit +import hashlib +import json +import os +import random +from collections import deque +from typing import Optional + +import attrs +import redis.asyncio as redis +from tqdm.asyncio import tqdm_asyncio + + +def get_digest(data): + """ + Compute MD5 digest of the JSON-serialized data. + """ + return hashlib.md5(json.dumps(data).encode()).hexdigest() + + +def cache_key(input_data, func_key): + """ + Generate a cache key by concatenating the function key with the MD5 hash of the input data. + """ + return f"{func_key}:{get_digest(input_data)}" + + +def get_client_reader(reader_port: int, writer_port: int): + """ + Get Redis client instances for reader and writer. + + If the reader and writer ports are the same, return the same client for both. 
+ """ + if reader_port == writer_port: + redis_client = redis.Redis(host="localhost", port=reader_port) + return redis_client, redis_client + else: + redis_writer = redis.Redis(host="localhost", port=writer_port) + redis_reader = redis.Redis(host="localhost", port=reader_port) + return redis_writer, redis_reader + + +def get_default_port(): + """ + Get the default Redis reader port from environment variables or use 6377. + """ + return int(os.environ.get("REDIS_READER_PORT", 6377)) + + +def get_event_loop(): + """ + Get the current event loop or create a new one if it doesn't exist. + """ + try: + return asyncio.get_event_loop() + except RuntimeError: + return asyncio.new_event_loop() + + +@attrs.define +class RedisWrapper: + """ + A wrapper class for Redis operations with batching capabilities. + + Usage: + # Connect to the remote server Redis instance with + ssh -i ~/.ssh/rr_dev.pem exx@64.255.46.66 -fNT -L 6377:localhost:6377 + + # Set up a local read replica with + redis-server --port 6380 --slaveof 0.0.0.0 6377 + """ + + port: int = attrs.field(default=get_default_port()) + batch_size: int = attrs.field(default=2000) + batch_time: float = attrs.field(default=0.2) + queue: deque = attrs.field(init=False, factory=deque) + client: redis.Redis = attrs.field(init=False) + loop: asyncio.AbstractEventLoop = attrs.field(init=False) + has_items: asyncio.Event = attrs.field(init=False, factory=asyncio.Event) + lock: asyncio.Lock = attrs.field(init=False, factory=asyncio.Lock) + maximum_run_per_pipeline: int = 256 # Max commands per pipeline execution + + def __attrs_post_init__(self): + """ + Post-initialization to set up Redis client and event loop. + """ + self.client = redis.Redis(port=self.port) + self.loop = get_event_loop() + + @classmethod + def singleton(cls, port: Optional[int] = None): + """ + Get a singleton instance of RedisWrapper. 
+ """ + if not hasattr(cls, "_instance"): + if port is None: + port = get_default_port() + cls._instance = cls(port=port) + return cls._instance + + async def enqueue(self, operation, *args): + """ + Enqueue a Redis operation to be executed. + + Args: + operation (str): The Redis operation (e.g., 'GET', 'SET'). + *args: Arguments for the Redis operation. + + Returns: + The result of the Redis operation. + """ + future = self.loop.create_future() + self.queue.append((operation, args, future)) + + async with self.lock: + if future.done(): + return future.result() + await self.flush() + + assert future.done() + return future.result() + + async def read(self, key_str, converter=None): + """ + Asynchronously read a value from Redis. + + Args: + key_str (str): The key to read. + converter (callable, optional): A function to convert the value. + + Returns: + The value from Redis, converted if a converter is provided. + """ + key = f"json_{key_str}" + value = await self.enqueue("GET", key) + if value: + value = value.decode("utf-8") + if converter: + return converter(value) + else: + return json.loads(value) + else: + return None + + async def lrange(self, idx, start, end): + """ + Get a range of elements from a Redis list. + + Args: + idx (str): Index identifier for the list. + start (int): Starting index. + end (int): Ending index. + + Returns: + A list of elements from the Redis list. + """ + key = f"list_{idx}" + values = await self.enqueue("LRANGE", key, start, end) + return [json.loads(value) for value in values] if values else [] + + async def rpush(self, idx, *values): + """ + Append values to the end of a Redis list. + + Args: + idx (str): Index identifier for the list. + *values: Values to append. + """ + key = f"list_{idx}" + json_values = [json.dumps(value) for value in values] + await self.enqueue("RPUSH", key, *json_values) + + async def write(self, key_str, value): + """ + Asynchronously write a value to Redis. + + Args: + key_str (str): The key to write. 
+ value: The value to write. + """ + key = f"json_{key_str}" + return await self.enqueue("MSET", {key: json.dumps(value)}) + + async def clear(self, idx): + """ + Clear a Redis list. + + Args: + idx (str): Index identifier for the list. + """ + key = f"list_{idx}" + await self.enqueue("DELETE", key) + + async def flush(self): + """ + Flush the queued Redis operations using a pipeline. + """ + if not self.queue: + return + + pipeline = self.client.pipeline() + futures = [] + mset_futures = [] + mset_dict = {} + + # Collect operations up to maximum_run_per_pipeline + while self.queue and len(futures) < self.maximum_run_per_pipeline: + operation, args, future = self.queue.popleft() + if operation == "MSET": + mset_futures.append(future) + (arg_dict,) = args + mset_dict.update(arg_dict) + else: + getattr(pipeline, operation.lower())(*args) + futures.append(future) + + if mset_dict: + pipeline.mset(mset_dict) + + results = await pipeline.execute() + assert len(results) == len(futures) + (1 if mset_dict else 0) + + # Handle MSET results + if mset_dict: + mset_result = results[-1] + results = results[:-1] + for future in mset_futures: + future.set_result(mset_result) + + # Set results for other operations + for future, result in zip(futures, results): + future.set_result(result) + + +async def test_all_funcs(i): + """ + Test all RedisWrapper functions. + + Args: + i (int): Test identifier. 
+ """ + client = RedisWrapper.singleton() + + await asyncio.sleep(random.random()) + + # Test read/write operations + key = f"readwrite_{i}" + write_read_val = {"hello": f"world_{i}"} + await client.write(key, write_read_val) + read_val = await client.read(key) + assert read_val == write_read_val, f"read_val: {read_val}, write_read_val: {write_read_val}" + + await asyncio.sleep(random.random()) + + # Test rpush and lrange + key = f"rpushlrange_{i}" + await client.rpush(key, "one", "two", "three") + lrange_result = await client.lrange(key, 0, -1) + assert set(lrange_result) == {"one", "two", "three"} + print(f"Test {i} passed!") + + +async def run_big_tests(): + """ + Run a large number of tests concurrently. + """ + await tqdm_asyncio.gather(*[test_all_funcs(i) for i in range(50000)]) + + +if __name__ == "__main__": + asyncio.run(run_big_tests()) diff --git a/pyspur/backend/pyspur/utils/timing.py b/pyspur/backend/pyspur/utils/timing.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyspur/backend/pyspur/utils/workflow_version_utils.py b/pyspur/backend/pyspur/utils/workflow_version_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4ba5bb227270f1a8289211165ae92fa0defe8e08 --- /dev/null +++ b/pyspur/backend/pyspur/utils/workflow_version_utils.py @@ -0,0 +1,66 @@ +import hashlib +import json + +from sqlalchemy.orm import Session + +from ..models.workflow_version_model import WorkflowVersionModel +from ..schemas.workflow_schemas import ( + WorkflowDefinitionSchema, + WorkflowResponseSchema, +) + + +def get_latest_workflow_version(workflow_id: str, db: Session) -> int: + """ + Retrieve the latest version number of a workflow. + Returns the latest version number if it exists, otherwise 0. 
+ """ + latest_version = ( + db.query(WorkflowVersionModel) + .filter(WorkflowVersionModel.workflow_id == workflow_id) + .order_by(WorkflowVersionModel.version.desc()) + .first() + ) + + return latest_version.version if latest_version else 0 + + +def hash_workflow_definition(definition: WorkflowDefinitionSchema) -> str: + """ + Create a hash of the workflow definition for comparison. + """ + definition_str = json.dumps(definition, sort_keys=True) + return hashlib.sha256(definition_str.encode("utf-8")).hexdigest() + + +def fetch_workflow_version( + workflow_id: str, workflow: WorkflowResponseSchema, db: Session +) -> WorkflowVersionModel: + """ + Retrieve an existing workflow version with the same definition or create a new one. + """ + definition_hash = hash_workflow_definition(workflow.definition) + existing_version = ( + db.query(WorkflowVersionModel) + .filter( + WorkflowVersionModel.workflow_id == workflow_id, + WorkflowVersionModel.definition_hash == definition_hash, + ) + .first() + ) + + if existing_version: + return existing_version + + latest_version_number = get_latest_workflow_version(workflow_id, db) + new_version = WorkflowVersionModel( + workflow_id=workflow_id, + version=latest_version_number + 1, + name=workflow.name, + description=workflow.description, + definition=workflow.definition, + definition_hash=definition_hash, + ) + db.add(new_version) + db.commit() + return new_version diff --git a/pyspur/backend/pyspur/workflow_builder.py b/pyspur/backend/pyspur/workflow_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..ee72d81cab8603abded5f224cba364e64df76deb --- /dev/null +++ b/pyspur/backend/pyspur/workflow_builder.py @@ -0,0 +1,297 @@ +import json +from typing import Any, Dict, List, Optional, Tuple, Union + +from .schemas.workflow_schemas import ( + SpurType, + WorkflowDefinitionSchema, + WorkflowLinkSchema, + WorkflowNodeCoordinatesSchema, + WorkflowNodeDimensionsSchema, + WorkflowNodeSchema, +) + + +class 
WorkflowBuilder: + """Builder class for creating workflows programmatically. + + This class allows users to define workflows in code, providing a cleaner + alternative to manually defining the JSON structure. + + Example: + ```python + # Create a workflow builder + builder = WorkflowBuilder("My Workflow", "This is a workflow created with code") + + # Add nodes + input_node = builder.add_node( + id="input_node", + node_type="InputNode", + config={"output_schema": {"question": "string"}} + ) + + llm_node = builder.add_node( + id="llm_node", + node_type="SingleLLMCallNode", + config={ + "llm_info": { + "model": "openai/gpt-4o", + "temperature": 0.7, + }, + "system_message": "You are a helpful assistant." + } + ) + + output_node = builder.add_node( + id="output_node", + node_type="OutputNode", + config={ + "output_schema": {"answer": "string"}, + "output_map": {"answer": "llm_node.response"} + } + ) + + # Connect nodes + builder.add_link(input_node, llm_node) + builder.add_link(llm_node, output_node) + + # Get the workflow definition + workflow_def = builder.build() + ``` + + """ + + def __init__(self, name: str, description: str = ""): + """Initialize a workflow builder. 
+ + Args: + name: The name of the workflow + description: Optional description for the workflow + + """ + self.name = name + self.description = description + self.nodes: List[WorkflowNodeSchema] = [] + self.links: List[WorkflowLinkSchema] = [] + self.test_inputs: List[Dict[str, Any]] = [] + self.spur_type: SpurType = SpurType.WORKFLOW + self._node_counter: Dict[str, int] = {} # Track counts for auto-generated IDs + + # Let's add some default positioning logic to make visualizing nicer + self._next_x = 100 + self._next_y = 100 + self._horizontal_spacing = 250 # Default horizontal spacing between nodes + self._vertical_spacing = 150 # Default vertical spacing between rows + self._max_x_per_row: Dict[int, int] = {0: self._next_x} # Track max x per row + self._current_row = 0 + + def add_node( + self, + node_type: str, + config: Dict[str, Any], + id: Optional[str] = None, + title: str = "", + parent_id: Optional[str] = None, + coordinates: Optional[Tuple[float, float]] = None, + dimensions: Optional[Tuple[float, float]] = None, + subworkflow: Optional[WorkflowDefinitionSchema] = None, + row: Optional[int] = None, + ) -> str: + """Add a node to the workflow. + + Args: + node_type: The type of node to add (e.g., "InputNode", "SingleLLMCallNode") + config: Configuration for the node + id: Optional node ID. 
If not provided, one will be generated + title: Optional display title for the node + parent_id: Optional parent node ID for hierarchical workflows + coordinates: Optional tuple of (x, y) coordinates for UI positioning + dimensions: Optional tuple of (width, height) for UI sizing + subworkflow: Optional sub-workflow definition for composite nodes + row: Optional row number for positioning (used for auto-layout) + + Returns: + The ID of the added node + + """ + # If no ID is provided, generate one based on the node type + if id is None: + id = self._generate_id(node_type) + + # If no title is provided, use the ID + if not title: + title = id + + # Handle coordinates for UI layout + node_coordinates = None + if coordinates: + node_coordinates = WorkflowNodeCoordinatesSchema(x=coordinates[0], y=coordinates[1]) + else: + # Auto-position the node + if row is not None: + self._current_row = row + + # Calculate coordinates based on current row and position + x = self._max_x_per_row.get(self._current_row, self._next_x) + y = self._current_row * self._vertical_spacing + 100 + + node_coordinates = WorkflowNodeCoordinatesSchema(x=x, y=y) + + # Update the max x for this row + self._max_x_per_row[self._current_row] = x + self._horizontal_spacing + + # Handle dimensions for UI sizing + node_dimensions = None + if dimensions: + node_dimensions = WorkflowNodeDimensionsSchema( + width=dimensions[0], height=dimensions[1] + ) + + # Create the node schema + node = WorkflowNodeSchema( + id=id, + title=title, + parent_id=parent_id, + node_type=node_type, + config=config, + coordinates=node_coordinates, + dimensions=node_dimensions, + subworkflow=subworkflow, + ) + + # Add the node to the workflow + self.nodes.append(node) + + return id + + def add_link( + self, + source_id: str, + target_id: str, + source_handle: Optional[str] = None, + target_handle: Optional[str] = None, + ) -> None: + """Add a link between two nodes. 
+ + Args: + source_id: The ID of the source node + target_id: The ID of the target node + source_handle: Optional source handle for routers and complex nodes + target_handle: Optional target handle + + """ + link = WorkflowLinkSchema( + source_id=source_id, + target_id=target_id, + source_handle=source_handle, + target_handle=target_handle, + ) + self.links.append(link) + + def add_test_input(self, input_data: Dict[str, Any]) -> None: + """Add test input data for the workflow. + + Args: + input_data: A dictionary containing test input data + + """ + self.test_inputs.append(input_data) + + def set_spur_type(self, spur_type: Union[SpurType, str]) -> None: + """Set the type of the workflow. + + Args: + spur_type: The workflow type (workflow, chatbot, or agent) + + """ + self.spur_type = SpurType(spur_type) + + def build(self) -> WorkflowDefinitionSchema: + """Build and return the workflow definition schema. + + Returns: + A WorkflowDefinitionSchema instance representing the complete workflow + + """ + workflow = WorkflowDefinitionSchema( + nodes=self.nodes, + links=self.links, + test_inputs=self.test_inputs, + spur_type=self.spur_type, + ) + return workflow + + def to_dict(self) -> Dict[str, Any]: + """Convert the workflow to a dictionary. + + Returns: + A dictionary representation of the workflow + + """ + workflow = self.build() + return workflow.model_dump() + + def to_json(self, indent: int = 2) -> str: + """Convert the workflow to a JSON string. + + Args: + indent: Number of spaces for indentation in JSON output + + Returns: + A JSON string representation of the workflow + + """ + workflow = self.build() + return json.dumps(workflow.model_dump(), indent=indent) + + @classmethod + def from_workflow_definition( + cls, workflow_def: WorkflowDefinitionSchema, name: str = "", description: str = "" + ) -> "WorkflowBuilder": + """Create a WorkflowBuilder from an existing WorkflowDefinitionSchema. 
+ + Args: + workflow_def: The workflow definition to convert + name: Optional name for the workflow (if not provided, will use "Imported Workflow") + description: Optional description for the workflow + + Returns: + A WorkflowBuilder instance with the nodes and links from the workflow definition + + """ + builder = cls(name or "Imported Workflow", description) + + # Copy nodes + builder.nodes = workflow_def.nodes + + # Copy links + builder.links = workflow_def.links + + # Copy test inputs + builder.test_inputs = workflow_def.test_inputs + + # Copy spur type + builder.spur_type = workflow_def.spur_type + + return builder + + def _generate_id(self, node_type: str) -> str: + """Generate a unique ID for a node based on its type. + + Args: + node_type: The type of node + + Returns: + A unique ID for the node + + """ + # Remove "Node" suffix if present + base_name = node_type + if base_name.endswith("Node"): + base_name = base_name[:-4] + + # Get the counter for this node type + counter = self._node_counter.get(base_name, 0) + 1 + self._node_counter[base_name] = counter + + # Generate the ID + return f"{base_name}_{counter}" diff --git a/pyspur/backend/pyspur/workflow_code_handler.py b/pyspur/backend/pyspur/workflow_code_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ff5f866d42717c8ea69a3a3949b4aa5eca11ac --- /dev/null +++ b/pyspur/backend/pyspur/workflow_code_handler.py @@ -0,0 +1,313 @@ +import re +from typing import Any, Dict, Optional + +from .schemas.workflow_schemas import WorkflowDefinitionSchema +from .workflow_builder import WorkflowBuilder + + +class WorkflowCodeHandler: + """Utility class for handling workflow-as-code functionality. + + This class provides methods to: + 1. Generate Python code from a workflow definition + 2. Parse Python code to create a workflow definition + 3. 
Reconcile UI-driven changes with code-driven workflows + """ + + @classmethod + def generate_code( + cls, + workflow_def: WorkflowDefinitionSchema, + workflow_name: str = "My Workflow", + workflow_description: str = "", + preserve_coordinates: bool = True, + preserve_dimensions: bool = True, + ) -> str: + """Generate Python code from a workflow definition. + + Args: + workflow_def: The workflow definition to convert to code + workflow_name: The name of the workflow + workflow_description: Optional description for the workflow + preserve_coordinates: Whether to include node coordinates in the code + preserve_dimensions: Whether to include node dimensions in the code + + Returns: + Python code representation of the workflow + + """ + code = [ + "from pyspur.workflow_builder import WorkflowBuilder", + "", + "# Create a workflow builder", + f'builder = WorkflowBuilder("{workflow_name}", "{workflow_description}")', + "", + ] + + code.append("# Set the workflow type") + code.append(f'builder.set_spur_type("{workflow_def.spur_type}")') + code.append("") + + # Generate node creation code + code.append("# Add nodes") + # Map node IDs to variable names + node_vars: Dict[str, str] = {} + + for node in workflow_def.nodes: + # Create a valid Python variable name from the node ID + var_name = cls._create_variable_name(node.id) + node_vars[node.id] = var_name + + # Format the config dictionary as Python code + config_str = cls._format_dict(node.config) + + # Build the add_node arguments + args = [ + f'node_type="{node.node_type}"', + f"config={config_str}", + f'id="{node.id}"', + ] + + if node.title and node.title != node.id: + args.append(f'title="{node.title}"') + + if node.parent_id: + args.append(f'parent_id="{node.parent_id}"') + + if preserve_coordinates and node.coordinates: + args.append(f"coordinates=({node.coordinates.x}, {node.coordinates.y})") + + if preserve_dimensions and node.dimensions: + args.append(f"dimensions=({node.dimensions.width}, 
{node.dimensions.height})") + + if node.subworkflow: + # This would need recursive handling for subworkflows + args.append("# subworkflow is not included in this code generation") + + # Join all arguments with newlines and proper indentation for readability + formatted_args = ",\n ".join(args) + + # Add the node creation code + code.append(f"{var_name} = builder.add_node(") + code.append(f" {formatted_args}") + code.append(")") + code.append("") + + # Generate link creation code + if workflow_def.links: + code.append("# Add links between nodes") + for link in workflow_def.links: + source_var = node_vars.get(link.source_id, f'"{link.source_id}"') + target_var = node_vars.get(link.target_id, f'"{link.target_id}"') + + if link.source_handle or link.target_handle: + args = [ + f"source_id={source_var}", + f"target_id={target_var}", + ] + + if link.source_handle: + args.append(f'source_handle="{link.source_handle}"') + + if link.target_handle: + args.append(f'target_handle="{link.target_handle}"') + + # Join all arguments with commas + formatted_args = ", ".join(args) + + code.append(f"builder.add_link({formatted_args})") + else: + code.append(f"builder.add_link({source_var}, {target_var})") + + code.append("") + + # Generate test input code + if workflow_def.test_inputs: + code.append("# Add test inputs") + for _i, test_input in enumerate(workflow_def.test_inputs): + input_str = cls._format_dict(test_input) + code.append(f"builder.add_test_input({input_str})") + code.append("") + + # Build the workflow + code.append("# Build the workflow definition") + code.append("workflow_def = builder.build()") + + return "\n".join(code) + + @classmethod + def parse_code( + cls, code: str, existing_workflow: Optional[WorkflowDefinitionSchema] = None + ) -> WorkflowDefinitionSchema: + """Parse Python code and convert it to a workflow definition. + + The code is expected to use the WorkflowBuilder API to define a workflow. 
+ + Args: + code: Python code defining a workflow + existing_workflow: Optional existing workflow to preserve UI metadata from + + Returns: + A WorkflowDefinitionSchema instance + + Raises: + ValueError: If the code cannot be parsed or does not define a workflow + + """ + try: + # Create a local namespace to execute the code + local_vars: Dict[str, Any] = {} + + # Execute the code in a restricted environment + exec(code, {"WorkflowBuilder": WorkflowBuilder}, local_vars) + + # Look for a workflow_def variable that is a WorkflowDefinitionSchema + workflow_def = None + for _var_name, var_value in local_vars.items(): + if isinstance(var_value, WorkflowDefinitionSchema): + workflow_def = var_value + break + + if not workflow_def: + # If we didn't find a WorkflowDefinitionSchema directly, look for a builder + builder = None + for _var_name, var_value in local_vars.items(): + if isinstance(var_value, WorkflowBuilder): + builder = var_value + break + + if builder: + workflow_def = builder.build() + + if not workflow_def: + raise ValueError("No workflow definition found in the code") + + # If we have an existing workflow, preserve UI metadata + if existing_workflow: + workflow_def = cls._reconcile_workflow_with_existing( + workflow_def, existing_workflow + ) + + return workflow_def + + except Exception as e: + raise ValueError(f"Failed to parse workflow code: {str(e)}") from e + + @classmethod + def _reconcile_workflow_with_existing( + cls, new_workflow: WorkflowDefinitionSchema, existing_workflow: WorkflowDefinitionSchema + ) -> WorkflowDefinitionSchema: + """Reconcile a new workflow with an existing one, preserving UI metadata. 
+ + Args: + new_workflow: The new workflow generated from code + existing_workflow: The existing workflow with UI metadata + + Returns: + A workflow definition with code structure and UI metadata + + """ + # Create a mapping of node IDs to nodes in the existing workflow + existing_nodes = {node.id: node for node in existing_workflow.nodes} + + # For each node in the new workflow, copy UI metadata from the existing workflow + for node in new_workflow.nodes: + if node.id in existing_nodes: + existing_node = existing_nodes[node.id] + + # Preserve coordinates if they exist + if existing_node.coordinates and not node.coordinates: + node.coordinates = existing_node.coordinates + + # Preserve dimensions if they exist + if existing_node.dimensions and not node.dimensions: + node.dimensions = existing_node.dimensions + + return new_workflow + + @classmethod + def _format_dict(cls, d: Dict[str, Any], indent: int = 0) -> str: + """Format a dictionary as a Python code string. + + Args: + d: The dictionary to format + indent: The current indentation level + + Returns: + A formatted string representing the dictionary as Python code + + """ + if not d: + return "{}" + + # Handle special cases for formatting + lines = ["{"] + indent_str = " " * (indent + 1) + + for key, value in d.items(): + formatted_value = cls._format_value(value, indent + 1) + lines.append(f'{indent_str}"{key}": {formatted_value},') + + lines.append(" " * indent + "}") + + return "\n".join(lines) + + @classmethod + def _format_value(cls, value: Any, indent: int = 0) -> str: + """Format a value as a Python code string. 
+ + Args: + value: The value to format + indent: The current indentation level + + Returns: + A formatted string representing the value as Python code + + """ + if value is None: + return "None" + elif isinstance(value, (int, float, bool)): + return str(value) + elif isinstance(value, str): + # Escape quotes and special characters + escaped = value.replace('"', '\\"').replace("\n", "\\n") + return f'"{escaped}"' + elif isinstance(value, (list, tuple)): + if not value: + return "[]" + + items = [cls._format_value(item, indent) for item in value] # type: ignore + if len(items) <= 3 and all(len(item) < 40 for item in items): + # Format as a single line if it's short + return f"[{', '.join(items)}]" + else: + # Format as multiple lines + indent_str = " " * indent + next_indent = " " * (indent + 1) + items_str = f",\n{next_indent}".join(items) + return f"[\n{next_indent}{items_str}\n{indent_str}]" + elif isinstance(value, dict): + return cls._format_dict(value, indent) # type: ignore + else: + # For other types, use repr + return repr(value) + + @classmethod + def _create_variable_name(cls, node_id: str) -> str: + """Create a valid Python variable name from a node ID. + + Args: + node_id: The node ID to convert + + Returns: + A valid Python variable name + + """ + # Replace non-alphanumeric characters with underscores + var_name = re.sub(r"[^a-zA-Z0-9_]", "_", node_id) + + # Ensure it starts with a letter or underscore + if var_name and not var_name[0].isalpha() and var_name[0] != "_": + var_name = "node_" + var_name + + return var_name diff --git a/pyspur/backend/scripts/repair_socket_workers.py b/pyspur/backend/scripts/repair_socket_workers.py new file mode 100644 index 0000000000000000000000000000000000000000..1d771ad563ad7efaf544e214d57e31d16be1244d --- /dev/null +++ b/pyspur/backend/scripts/repair_socket_workers.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python + +"""Repair Socket Mode Workers + +This script helps diagnose and repair socket mode worker issues. 
+It can: +1. List all socket mode workers +2. Clean up stale marker files +3. Restart workers +4. Fix database state to match reality + +Usage: +python repair_socket_workers.py [--list] [--clean] [--restart] [--fix-db] + +""" + +import argparse +import os +import sys +from pathlib import Path + +# Add the parent directory to sys.path +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +try: + from pyspur.api.slack_management import recover_orphaned_workers + from pyspur.database import get_db + from pyspur.integrations.slack.worker_status import ( + MARKER_DIR, + find_running_worker_process, + get_worker_status, + list_workers, + ) + from pyspur.models.slack_agent_model import SlackAgentModel + + # Try to import psutil + try: + import psutil + + PSUTIL_AVAILABLE = True + except ImportError: + PSUTIL_AVAILABLE = False + print("Warning: psutil not available, some features will be limited") +except ImportError as e: + print(f"Error importing required modules: {e}") + print( + "Make sure to run this script from the project root or add the project root to PYTHONPATH" + ) + sys.exit(1) + + +def list_all_workers() -> None: + """List all socket mode workers.""" + try: + # List all workers in the marker directory + print(f"Checking marker directory: {MARKER_DIR}") + if not os.path.exists(MARKER_DIR): + print("Marker directory does not exist. 
No workers found.") + return + + # Get all worker status + workers = list_workers() + + if not workers: + print("No worker marker files found.") + return + + print(f"Found {len(workers)} worker marker files:") + for worker in workers: + # Print key information about each worker + print(f"Agent ID: {worker['agent_id']}") + print(f" PID: {worker['pid']}") + print(f" Running: {worker['process_running']}") + print(f" Status: {worker['status']}") + if worker["details"]: + started_at = worker["details"].get("started_at", "unknown") + last_check = worker["details"].get("last_check", "unknown") + print(f" Started: {started_at}") + print(f" Last check: {last_check}") + print("") + + except Exception as e: + print(f"Error listing workers: {e}") + + +def clean_stale_markers() -> None: + """Clean up stale worker marker files.""" + try: + # Check if the marker directory exists + if not os.path.exists(MARKER_DIR): + print("Marker directory does not exist. Nothing to clean.") + return + + # Get all worker status + workers = list_workers() + + if not workers: + print("No worker marker files found.") + return + + # Count of cleaned files + cleaned = 0 + + # Check each worker + for worker in workers: + agent_id = worker["agent_id"] + pid = worker["pid"] + + # Check if the process is running + if not worker["process_running"]: + # Process is not running, clean up the marker files + print(f"Cleaning up marker files for agent {agent_id} (pid {pid})") + + # Remove pid file + pid_file = f"{MARKER_DIR}/agent_{agent_id}.pid" + if os.path.exists(pid_file): + os.remove(pid_file) + cleaned += 1 + print(f" Removed pid file: {pid_file}") + + # Remove status file + status_file = f"{MARKER_DIR}/agent_{agent_id}.status" + if os.path.exists(status_file): + os.remove(status_file) + cleaned += 1 + print(f" Removed status file: {status_file}") + + print(f"Cleaned up {cleaned} stale marker files.") + + except Exception as e: + print(f"Error cleaning markers: {e}") + + +def restart_workers() -> None: + 
"""Restart socket mode workers.""" + try: + # First, get a database session + db = next(get_db()) + + # Get all agents with socket mode enabled + agents = ( + db.query(SlackAgentModel).filter(SlackAgentModel.socket_mode_enabled.is_(True)).all() + ) + + if not agents: + print("No agents with socket_mode_enabled=True found in the database.") + return + + print(f"Found {len(agents)} agents with socket_mode_enabled=True in the database:") + for agent in agents: + agent_id = agent.id + print(f"Agent ID: {agent_id}, Name: {agent.name}") + + # Check if a worker is already running for this agent + worker_status = get_worker_status(agent_id) + if worker_status["process_running"]: + print(f" Worker already running for agent {agent_id}, skipping restart") + continue + + # If nothing is running, call the recover endpoint + print(f" Restarting worker for agent {agent_id}...") + + # Import the socket manager + from pyspur.integrations.slack.socket_manager import SocketManager + + socket_manager = SocketManager() + + # Start the worker + success = socket_manager.start_worker(agent_id) + if success: + print(f" Successfully started worker for agent {agent_id}") + else: + print(f" Failed to start worker for agent {agent_id}") + + except Exception as e: + print(f"Error restarting workers: {e}") + finally: + db.close() + + +def fix_database_state() -> None: + """Fix database state to match reality.""" + try: + # First, get a database session + db = next(get_db()) + + # Get all agents + all_agents = db.query(SlackAgentModel).all() + if not all_agents: + print("No agents found in the database.") + return + + # Check each agent's status + for agent in all_agents: + agent_id = agent.id + + # Check if there's a worker running for this agent + is_running, _ = find_running_worker_process(agent_id) + db_enabled = bool(agent.socket_mode_enabled) + + # If there's a mismatch between the database and reality, fix it + if is_running and not db_enabled: + print( + f"Agent {agent_id}: Worker is 
running but socket_mode_enabled is False, fixing..." + ) + agent.socket_mode_enabled = True + db.commit() + db.refresh(agent) + print(f" Updated database state for agent {agent_id}") + elif not is_running and db_enabled: + print( + f"Agent {agent_id}: Worker is not running but socket_mode_enabled is True, fixing..." + ) + agent.socket_mode_enabled = False + db.commit() + db.refresh(agent) + print(f" Updated database state for agent {agent_id}") + else: + print( + f"Agent {agent_id}: Database state matches reality (socket_mode_enabled={db_enabled})" + ) + + except Exception as e: + print(f"Error fixing database state: {e}") + finally: + db.close() + + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser(description="Repair Socket Mode Workers") + parser.add_argument("--list", action="store_true", help="List all socket mode workers") + parser.add_argument("--clean", action="store_true", help="Clean up stale marker files") + parser.add_argument("--restart", action="store_true", help="Restart workers") + parser.add_argument("--fix-db", action="store_true", help="Fix database state to match reality") + + args = parser.parse_args() + + # If no arguments provided, show help + if not (args.list or args.clean or args.restart or args.fix_db): + parser.print_help() + return + + # Run the requested actions + if args.list: + print("\n=== Listing Socket Mode Workers ===") + list_all_workers() + + if args.clean: + print("\n=== Cleaning Stale Marker Files ===") + clean_stale_markers() + + if args.restart: + print("\n=== Restarting Socket Mode Workers ===") + restart_workers() + + if args.fix_db: + print("\n=== Fixing Database State ===") + fix_database_state() + + +if __name__ == "__main__": + main() diff --git a/pyspur/backend/sqlite/.gitignore b/pyspur/backend/sqlite/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a3a0c8b5f48c0260a4cb43aa577f9b18896ee280 --- /dev/null +++ b/pyspur/backend/sqlite/.gitignore @@ -0,0 +1,2 @@ +* 
+!.gitignore \ No newline at end of file diff --git a/pyspur/backend/test_ollama.sh b/pyspur/backend/test_ollama.sh new file mode 100644 index 0000000000000000000000000000000000000000..8331b8df7842c666a404c6acfbde93007de4f7d6 --- /dev/null +++ b/pyspur/backend/test_ollama.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +# Function for fancy error printing +print_error() { + echo " +╔════════════════════════════════════════════════════════════════╗ +║ 🚫 ERROR 🚫 ║ +╠════════════════════════════════════════════════════════════════╣ +║ ║ +║ Cannot connect to Ollama at: $OLLAMA_BASE_URL ║ +║ ║ +║ Please check: ║ +║ 1. Ollama is running ║ +║ 2. The OLLAMA_BASE_URL is correct ║ +║ 3. The network connection is working ║ +║ ║ +║ Error details: ║ +║ $1 +║ ║ +╚════════════════════════════════════════════════════════════════╝ +" + exit 1 +} + +# Check if OLLAMA_BASE_URL is set +if [ -n "$OLLAMA_BASE_URL" ]; then + echo "Testing Ollama connection at: $OLLAMA_BASE_URL" + + # Try to fetch the model list from Ollama + response=$(curl -s -w "\n%{http_code}" "$OLLAMA_BASE_URL/api/tags" \ + -H "Content-Type: application/json" 2>&1) + + # Get the HTTP status code + http_code=$(echo "$response" | tail -n1) + # Get the response body + body=$(echo "$response" | sed '$d') + + # Check if curl command was successful + if [ $? -ne 0 ]; then + print_error "Connection failed: $body" + fi + + # Check if we got a successful response + if [ "$http_code" -ne 200 ]; then + print_error "HTTP Error $http_code: $body" + fi + + echo "✅ Successfully connected to Ollama" +fi \ No newline at end of file diff --git a/pyspur/backend/tests/README.md b/pyspur/backend/tests/README.md new file mode 100644 index 0000000000000000000000000000000000000000..eb097370829e191f866d5bf7b973032db15c2707 --- /dev/null +++ b/pyspur/backend/tests/README.md @@ -0,0 +1,55 @@ +# PySpur Backend Tests + +This directory contains tests for the PySpur backend application. 
+ +## Directory Structure + +- `cli/`: Tests for the CLI module +- `nodes/`: Tests for the nodes module +- `conftest.py`: Common test fixtures + +## Running Tests + +To run all tests: +```bash +python -m pytest +``` + +To run tests for a specific module: +```bash +python -m pytest tests/cli/ +``` + +To run a specific test file: +```bash +python -m pytest tests/cli/test_main.py +``` + +## Coverage + +To run tests with coverage: +```bash +python -m pytest --cov=pyspur +``` + +To generate a coverage report: +```bash +python -m pytest --cov=pyspur --cov-report=html +``` + +This will create an HTML report in the `htmlcov` directory. + +## Adding New Tests + +When adding new tests: + +1. Create a test file with the prefix `test_` (e.g., `test_utils.py`) +2. Group related tests in the appropriate directory (e.g., `cli/`, `nodes/`) +3. Use fixtures from `conftest.py` where possible to avoid duplicating setup code +4. Use mocks to avoid external dependencies + +## Test Naming Conventions + +- Test files: `test_<module_name>.py` +- Test functions: `test_<function_name>_<scenario>` +- Test classes: `Test<ClassName>` \ No newline at end of file diff --git a/pyspur/backend/tests/__init__.py b/pyspur/backend/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3b1fafd925421b6de830f2460adffb2a9c4a3ce3 --- /dev/null +++ b/pyspur/backend/tests/__init__.py @@ -0,0 +1 @@ +"""PySpur tests package.""" diff --git a/pyspur/backend/tests/cli/__init__.py b/pyspur/backend/tests/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..260e232d491e9e9b6528d8f0c1f86102d0793926 --- /dev/null +++ b/pyspur/backend/tests/cli/__init__.py @@ -0,0 +1 @@ +"""CLI tests package.""" diff --git a/pyspur/backend/tests/cli/test_main.py b/pyspur/backend/tests/cli/test_main.py new file mode 100644 index 0000000000000000000000000000000000000000..37d4fa09f5f1bc3ca1e7ade9819dfb8104195f21 --- /dev/null +++ b/pyspur/backend/tests/cli/test_main.py @@ -0,0 +1,156 
@@ +"""Tests for the main.py module in 
the PySpur CLI.""" + +import os +from pathlib import Path +from typing import Optional +from unittest.mock import MagicMock, patch + +import pytest +from click.testing import Result +from typer.testing import CliRunner + +from pyspur.cli.main import app + + +@pytest.fixture +def runner() -> CliRunner: + """Fixture for creating a CLI runner.""" + return CliRunner() + + +def test_version_command(runner: CliRunner) -> None: + """Test the version command outputs the correct version.""" + with patch("pyspur.cli.main.get_version", return_value="0.1.18"): + result: Result = runner.invoke(app, ["version"]) + assert result.exit_code == 0 + assert "PySpur version: " in result.stdout + assert "0.1.18" in result.stdout + + +def test_version_command_import_error(runner: CliRunner) -> None: + """Test the version command handles ImportError gracefully.""" + with patch("pyspur.cli.main.get_version", side_effect=ImportError): + result: Result = runner.invoke(app, ["version"]) + assert result.exit_code == 0 + assert "unknown" in result.stdout + + +@pytest.mark.parametrize( + "sqlite_flag,expected_env_var", + [ + (True, "sqlite:///./pyspur.db"), + (False, None), + ], +) +def test_serve_command_sqlite_flag( + runner: CliRunner, sqlite_flag: bool, expected_env_var: Optional[str] +) -> None: + """Test the serve command with and without the sqlite flag.""" + cmd = ["serve"] + if sqlite_flag: + cmd.append("--sqlite") + + with ( + patch("pyspur.cli.main.uvicorn.run") as mock_run, + patch("pyspur.cli.main.run_migrations") as mock_migrations, + patch("pyspur.cli.main.load_environment") as _, + patch.dict(os.environ, {}, clear=True), + ): + result: Result = runner.invoke(app, cmd) + + # Check that the command ran successfully + assert result.exit_code == 0 + + # Verify migrations were run + mock_migrations.assert_called_once() + + # Verify server was started + mock_run.assert_called_once() + + # Check if SQLite environment variable was set correctly + if expected_env_var: + assert 
os.environ.get("SQLITE_OVERRIDE_DATABASE_URL") == expected_env_var + else: + assert "SQLITE_OVERRIDE_DATABASE_URL" not in os.environ + + +@patch("pyspur.cli.main.copy_template_file") +def test_init_command_simplified( + mock_copy_template: MagicMock, runner: CliRunner, tmp_path: Path +) -> None: + """Test a simplified version of the init command that focuses on directory creation.""" + # Create a patched version of the init function that skips problematic operations + with ( + patch("pyspur.cli.main.Path.exists", return_value=True), + patch("builtins.open"), + patch.object(Path, "cwd", return_value=tmp_path), + ): + # Create the expected directories and files manually + # to avoid relying on the actual implementation + (tmp_path / "data").mkdir(exist_ok=True) + (tmp_path / "tools").mkdir(exist_ok=True) + (tmp_path / "spurs").mkdir(exist_ok=True) + (tmp_path / "__init__.py").touch() + (tmp_path / "tools" / "__init__.py").touch() + (tmp_path / "spurs" / "__init__.py").touch() + (tmp_path / ".gitignore").touch() + (tmp_path / ".env").touch() + (tmp_path / ".env.example").touch() + + # Run the init command, which should now succeed + result: Result = runner.invoke(app, ["init"]) + + # Verify the command executed successfully + assert result.exit_code == 0 + + # Verify the expected directories and files exist + assert (tmp_path / "data").exists() + assert (tmp_path / "tools").exists() + assert (tmp_path / "spurs").exists() + assert (tmp_path / "__init__.py").exists() + assert (tmp_path / "tools" / "__init__.py").exists() + assert (tmp_path / "spurs" / "__init__.py").exists() + assert (tmp_path / ".gitignore").exists() + + +@patch("pyspur.cli.main.copy_template_file") +def test_init_command_with_path_simplified( + mock_copy_template: MagicMock, runner: CliRunner, tmp_path: Path +) -> None: + """Test the init command with a specified path in a simplified manner.""" + target_dir = tmp_path / "new_project" + target_dir.mkdir(exist_ok=True) + + # Create a patched version 
that avoids file operations + with patch("pyspur.cli.main.Path.exists", return_value=True), patch("builtins.open"): + # Pre-create the expected directories and files + (target_dir / "data").mkdir(exist_ok=True) + (target_dir / "tools").mkdir(exist_ok=True) + (target_dir / "spurs").mkdir(exist_ok=True) + (target_dir / "__init__.py").touch() + (target_dir / "tools" / "__init__.py").touch() + (target_dir / "spurs" / "__init__.py").touch() + (target_dir / ".gitignore").touch() + (target_dir / ".env").touch() + (target_dir / ".env.example").touch() + + # Run the init command + result: Result = runner.invoke(app, ["init", str(target_dir)]) + + # Verify the command executed successfully + assert result.exit_code == 0 + + # Verify the expected directories and files exist + assert target_dir.exists() + assert (target_dir / "data").exists() + assert (target_dir / "tools").exists() + assert (target_dir / "spurs").exists() + + +@patch("pyspur.cli.main.copy_template_file", side_effect=Exception("Test error")) +def test_init_command_error_handling(mock_copy_template: MagicMock, runner: CliRunner) -> None: + """Test the init command handles errors gracefully.""" + result: Result = runner.invoke(app, ["init"]) + + assert result.exit_code == 1 + assert "Error initializing project: Test error" in result.stdout diff --git a/pyspur/backend/tests/cli/test_utils.py b/pyspur/backend/tests/cli/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d30f6fef8d18727505d96b850fe808414663b9ab --- /dev/null +++ b/pyspur/backend/tests/cli/test_utils.py @@ -0,0 +1,62 @@ +"""Tests for the utils.py module in the PySpur CLI.""" + +import os +import tempfile +from pathlib import Path +from typing import Generator +from unittest.mock import MagicMock, patch + +import pytest + +from pyspur.cli.utils import copy_template_file, load_environment + + +@pytest.fixture +def mock_template_file() -> Generator[Path, None, None]: + """Create a temp file to use as a template.""" + with 
tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp: + temp.write("template content") + temp.flush() + yield Path(temp.name) + # Clean up + os.unlink(temp.name) + + +def test_copy_template_file(mock_template_file: Path, tmp_path: Path) -> None: + """Test copying a template file to a destination.""" + dest_path: Path = tmp_path / "destination.txt" + + # Mock the resources functionality + mock_resources = MagicMock() + mock_file_context = MagicMock() + mock_resources.files.return_value = mock_file_context + mock_file_context.joinpath.return_value = mock_template_file + + with patch("pyspur.cli.utils.resources", mock_resources): + copy_template_file("test_template.txt", dest_path) + + # Verify the file was copied + assert dest_path.exists() + with open(dest_path, "r") as f: + assert f.read() == "template content" + + +def test_load_environment_with_env_file(tmp_path: Path) -> None: + """Test loading environment variables from .env file.""" + # Create a mock .env file + env_path = tmp_path / ".env" + with open(env_path, "w") as f: + f.write("TEST_VAR=test_value") + + with ( + patch("pyspur.cli.utils.Path.cwd", return_value=tmp_path), + patch("pyspur.cli.utils.load_dotenv") as mock_load_dotenv, + patch("pyspur.cli.utils.print") as mock_print, + ): + load_environment() + + # Verify that load_dotenv was called with the .env file + mock_load_dotenv.assert_called_once_with(env_path) + + # Verify the success message was printed + mock_print.assert_called_with("[green]✓[/green] Loaded configuration from .env") diff --git a/pyspur/backend/tests/conftest.py b/pyspur/backend/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..9d2ddf8aa359d0b0dddf86e9ce16178f400e2e30 --- /dev/null +++ b/pyspur/backend/tests/conftest.py @@ -0,0 +1,10 @@ +"""Common test fixtures for PySpur backend tests.""" + +import pytest +from typer.testing import CliRunner + + +@pytest.fixture +def cli_runner(): + """Fixture for creating a CLI runner for testing Typer 
applications.""" + return CliRunner() diff --git a/pyspur/docker-compose.dev.yml b/pyspur/docker-compose.dev.yml new file mode 100644 index 0000000000000000000000000000000000000000..e755d54e58f51aac7f3e9317077bf011323c5ff4 --- /dev/null +++ b/pyspur/docker-compose.dev.yml @@ -0,0 +1,61 @@ +services: + nginx: + image: nginx:latest + ports: + - "${PYSPUR_PORT:-6080}:80" + volumes: + - ./nginx/conf.d:/etc/nginx/conf.d + depends_on: + - backend + - frontend + restart: on-failure + + backend: + build: + context: . + dockerfile: Dockerfile.backend + target: ${ENVIRONMENT:-development} + env_file: + - ./.env.example + - ./.env + command: bash /pyspur/backend/entrypoint.sh + volumes: + - .:/pyspur + - pyspur_data:/pyspur/backend/data + extra_hosts: + - "host.docker.internal:host-gateway" + depends_on: + db: + condition: service_healthy + + frontend: + build: + context: . + dockerfile: Dockerfile.frontend + target: ${ENVIRONMENT:-development} + env_file: + - ./.env.example + - ./.env + command: npm run dev + volumes: + - .:/pyspur + - /pyspur/frontend/node_modules + depends_on: + - backend + + db: + image: postgres:17-alpine + restart: on-failure + env_file: + - ./.env.example + - ./.env + volumes: + - postgres_data:/var/lib/postgresql/data + healthcheck: + test: ['CMD-SHELL', 'pg_isready -U pyspur'] + interval: 5s + timeout: 5s + +volumes: + postgres_data: + pyspur_data: # Used to persist data like uploaded files, eval outputs, datasets diff --git a/pyspur/docker-compose.staging.yml b/pyspur/docker-compose.staging.yml new file mode 100644 index 0000000000000000000000000000000000000000..9bcf8eb81e89600549031a54b5ed741fef1f5f16 --- /dev/null +++ b/pyspur/docker-compose.staging.yml @@ -0,0 +1,40 @@ +services: + backend: + build: + context: . 
+ dockerfile: Dockerfile.backend + target: production + command: bash /pyspur/backend/entrypoint.sh + ports: + - "${PYSPUR_PORT:-6080}:8000" + env_file: + - ./.env.example + - ./.env + environment: + - ENVIRONMENT=staging + volumes: + - .:/pyspur:ro + - /pyspur/backend/pyspur/static/ + - pyspur_data:/pyspur/backend/data + extra_hosts: + - "host.docker.internal:host-gateway" + depends_on: + db: + condition: service_healthy + + db: + image: postgres:17-alpine + restart: on-failure + env_file: + - ./.env.example + - ./.env + volumes: + - postgres_data:/var/lib/postgresql/data + healthcheck: + test: ['CMD-SHELL', 'pg_isready -U pyspur'] + interval: 5s + timeout: 5s + +volumes: + postgres_data: + pyspur_data: # Used to persist data like uploaded files, eval outputs, datasets \ No newline at end of file diff --git a/pyspur/docker-compose.yml b/pyspur/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..d885590b0c5d8c8043ec8f67ce89a7ad0d345008 --- /dev/null +++ b/pyspur/docker-compose.yml @@ -0,0 +1,33 @@ +services: + backend: + image: ghcr.io/${GITHUB_REPOSITORY:-pyspur-dev/pyspur}-backend:${VERSION:-latest} + command: bash /pyspur/backend/entrypoint.sh + ports: + - "${PYSPUR_PORT:-6080}:8000" + env_file: + - ./.env.example + - ./.env + volumes: + - ./.env:/pyspur/backend/.env + - pyspur_data:/pyspur/backend/data + extra_hosts: + - "host.docker.internal:host-gateway" + depends_on: + db: + condition: service_healthy + + db: + image: postgres:17-alpine + restart: on-failure + env_file: + - ./.env.example + - ./.env + volumes: + - postgres_data:/var/lib/postgresql/data + healthcheck: + test: ['CMD-SHELL', 'pg_isready -U pyspur'] + interval: 5s + timeout: 5s +volumes: + postgres_data: + pyspur_data: # Used to persist data like uploaded files, eval outputs, datasets \ No newline at end of file diff --git a/pyspur/docs/README.md b/pyspur/docs/README.md new file mode 100644 index 
0000000000000000000000000000000000000000..97f53d62b388cfec188be5c134851dd0c0a4c42e --- /dev/null +++ b/pyspur/docs/README.md @@ -0,0 +1,32 @@ +# Mintlify Starter Kit + +Click on `Use this template` to copy the Mintlify starter kit. The starter kit contains examples including: + +- Guide pages +- Navigation +- Customizations +- API Reference pages +- Use of popular components + +### Development + +Install the [Mintlify CLI](https://www.npmjs.com/package/mintlify) to preview the documentation changes locally. To install, use the following command: + +``` +npm i -g mintlify +``` + +Run the following command at the root of your documentation (where `mint.json` is): + +``` +mintlify dev +``` + +### Publishing Changes + +Install our GitHub App to automatically propagate changes from your repo to your deployment. Changes will be deployed to production automatically after pushing to the default branch. Find the link to install on your dashboard. + +#### Troubleshooting + +- Mintlify dev isn't running - Run `mintlify install` to re-install dependencies. 
+- Page loads as a 404 - Make sure you are running in a folder with `mint.json` diff --git a/pyspur/docs/api-reference/endpoint/create.mdx b/pyspur/docs/api-reference/endpoint/create.mdx new file mode 100644 index 0000000000000000000000000000000000000000..d0a1af5028bf040f9d0f02c62aff5268bb5e2834 --- /dev/null +++ b/pyspur/docs/api-reference/endpoint/create.mdx @@ -0,0 +1,4 @@ +--- +title: 'Create Plant' +openapi: 'POST /plants' +--- diff --git a/pyspur/docs/api-reference/endpoint/delete.mdx b/pyspur/docs/api-reference/endpoint/delete.mdx new file mode 100644 index 0000000000000000000000000000000000000000..8947630f44d1a699a20ebf02ab2da2e53417a27d --- /dev/null +++ b/pyspur/docs/api-reference/endpoint/delete.mdx @@ -0,0 +1,4 @@ +--- +title: 'Delete Plant' +openapi: 'DELETE /plants/{id}' +--- diff --git a/pyspur/docs/api-reference/endpoint/get.mdx b/pyspur/docs/api-reference/endpoint/get.mdx new file mode 100644 index 0000000000000000000000000000000000000000..ee745892a6510d250cb7ef77f5406174eecb8d72 --- /dev/null +++ b/pyspur/docs/api-reference/endpoint/get.mdx @@ -0,0 +1,4 @@ +--- +title: 'Get Plants' +openapi: 'GET /plants' +--- diff --git a/pyspur/docs/api-reference/evaluations.mdx b/pyspur/docs/api-reference/evaluations.mdx new file mode 100644 index 0000000000000000000000000000000000000000..12e32f4356151aba259fb5023d6e7ce72779ee4c --- /dev/null +++ b/pyspur/docs/api-reference/evaluations.mdx @@ -0,0 +1,107 @@ +# Evaluations API + +This document outlines the API endpoints for managing evaluations in PySpur. + +## List Available Evaluations + +**Description**: Lists all available evaluations by scanning the tasks directory for YAML files. Returns metadata about each evaluation including name, description, type, and number of samples. 
+ +**URL**: `/evals/` + +**Method**: GET + +**Response Schema**: +```python +List[Dict[str, Any]] +``` + +Each dictionary in the list contains: +```python +{ + "name": str, # Name of the evaluation + "description": str, # Description of the evaluation + "type": str, # Type of evaluation + "num_samples": str, # Number of samples in the evaluation + "paper_link": str, # Link to the paper describing the evaluation + "file_name": str # Name of the YAML file +} +``` + +## Launch Evaluation + +**Description**: Launches an evaluation job by triggering the evaluator with the specified evaluation configuration. The evaluation is run asynchronously in the background. + +**URL**: `/evals/launch/` + +**Method**: POST + +**Request Payload**: +```python +class EvalRunRequest: + eval_name: str # Name of the evaluation to run + workflow_id: str # ID of the workflow to evaluate + output_variable: str # Output variable to evaluate + num_samples: int = 100 # Number of random samples to evaluate +``` + +**Response Schema**: +```python +class EvalRunResponse: + run_id: str # ID of the evaluation run + eval_name: str # Name of the evaluation + workflow_id: str # ID of the workflow being evaluated + status: EvalRunStatusEnum # Status of the evaluation run + start_time: datetime # When the evaluation started + end_time: Optional[datetime] # When the evaluation ended (if completed) + results: Optional[Dict[str, Any]] # Results of the evaluation (if completed) +``` + +## Get Evaluation Run Status + +**Description**: Gets the status of a specific evaluation run, including results if the evaluation has completed. 
+ +**URL**: `/evals/runs/{eval_run_id}` + +**Method**: GET + +**Parameters**: +```python +eval_run_id: str # ID of the evaluation run +``` + +**Response Schema**: +```python +class EvalRunResponse: + run_id: str # ID of the evaluation run + eval_name: str # Name of the evaluation + workflow_id: str # ID of the workflow being evaluated + status: EvalRunStatusEnum # Status of the evaluation run + start_time: datetime # When the evaluation started + end_time: Optional[datetime] # When the evaluation ended (if completed) + results: Optional[Dict[str, Any]] # Results of the evaluation (if completed) +``` + +## List Evaluation Runs + +**Description**: Lists all evaluation runs, ordered by start time descending. + +**URL**: `/evals/runs/` + +**Method**: GET + +**Response Schema**: +```python +List[EvalRunResponse] +``` + +Where `EvalRunResponse` contains: +```python +class EvalRunResponse: + run_id: str # ID of the evaluation run + eval_name: str # Name of the evaluation + workflow_id: str # ID of the workflow being evaluated + status: EvalRunStatusEnum # Status of the evaluation run + start_time: datetime # When the evaluation started + end_time: Optional[datetime] # When the evaluation ended (if completed) + results: Optional[Dict[str, Any]] # Results of the evaluation (if completed) +``` \ No newline at end of file diff --git a/pyspur/docs/api-reference/introduction.mdx b/pyspur/docs/api-reference/introduction.mdx new file mode 100644 index 0000000000000000000000000000000000000000..153edfc005ede0f51cdbe8b9186726cfd9449cd6 --- /dev/null +++ b/pyspur/docs/api-reference/introduction.mdx @@ -0,0 +1,16 @@ +--- +title: 'Introduction' +description: 'Welcome to the Pyspur API Reference' +--- + +## Overview + +The Pyspur API provides a comprehensive set of endpoints for managing AI workflows, datasets, and various AI-related operations. This API is built using FastAPI and follows RESTful principles. 
+ +## Base URL + +All API endpoints are relative to your Pyspur instance base URL: + +```bash +https://your-pyspur-instance/api +``` \ No newline at end of file diff --git a/pyspur/docs/api-reference/openapi.json b/pyspur/docs/api-reference/openapi.json new file mode 100644 index 0000000000000000000000000000000000000000..d3dfb5af38edde6891f66d0a6cfb488b143d4bd8 --- /dev/null +++ b/pyspur/docs/api-reference/openapi.json @@ -0,0 +1,5874 @@ +{ + "openapi": "3.1.0", + "info": { + "title": "PySpur API", + "version": "1.0.0" + }, + "servers": [ + { + "url": "/api" + } + ], + "paths": { + "/node/supported_types/": { + "get": { + "tags": ["nodes"], + "summary": "Get Node Types", + "description": "Get the schemas for all available node types", + "operationId": "get_node_types_node_supported_types__get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "additionalProperties": { + "items": { + "type": "object" + }, + "type": "array" + }, + "type": "object", + "title": "Response Get Node Types Node Supported Types Get" + } + } + } + } + } + } + }, + "/wf/paused_workflows/": { + "get": { + "tags": ["workflows", "workflows"], + "summary": "List Paused Workflows", + "description": "List all paused workflows", + "operationId": "list_paused_workflows_wf_paused_workflows__get", + "parameters": [ + { + "name": "page", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "minimum": 1, + "default": 1, + "title": "Page" + } + }, + { + "name": "page_size", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "maximum": 100, + "minimum": 1, + "default": 10, + "title": "Page Size" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/PausedWorkflowResponseSchema" + }, + "title": "Response List Paused Workflows Wf Paused Workflows 
Get" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/wf/pause_history/{run_id}/": { + "get": { + "tags": ["workflows", "workflows"], + "summary": "Get Pause History", + "description": "Get pause history for a run", + "operationId": "get_pause_history_wf_pause_history__run_id___get", + "parameters": [ + { + "name": "run_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Run Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/PauseHistoryResponseSchema" + }, + "title": "Response Get Pause History Wf Pause History Run Id Get" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/wf/process_pause_action/{run_id}/": { + "post": { + "tags": ["workflows", "workflows"], + "summary": "Take Pause Action", + "description": "Take action on a paused workflow", + "operationId": "take_pause_action_wf_process_pause_action__run_id___post", + "parameters": [ + { + "name": "run_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Run Id" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ResumeRunRequestSchema" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RunResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } 
+ } + }, + "/wf/": { + "post": { + "tags": ["workflows"], + "summary": "Create Workflow", + "description": "Create a new workflow", + "operationId": "create_workflow_wf__post", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WorkflowCreateRequestSchema" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WorkflowResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "get": { + "tags": ["workflows"], + "summary": "List Workflows", + "description": "List all workflows", + "operationId": "list_workflows_wf__get", + "parameters": [ + { + "name": "page", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "minimum": 1, + "default": 1, + "title": "Page" + } + }, + { + "name": "page_size", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "maximum": 100, + "minimum": 1, + "default": 10, + "title": "Page Size" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/WorkflowResponseSchema" + }, + "title": "Response List Workflows Wf Get" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/wf/{workflow_id}/": { + "put": { + "tags": ["workflows"], + "summary": "Update Workflow", + "description": "Update a workflow", + "operationId": "update_workflow_wf__workflow_id___put", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + 
"title": "Workflow Id" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WorkflowCreateRequestSchema" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WorkflowResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "get": { + "tags": ["workflows"], + "summary": "Get Workflow", + "description": "Get a workflow by ID", + "operationId": "get_workflow_wf__workflow_id___get", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WorkflowResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "delete": { + "tags": ["workflows"], + "summary": "Delete Workflow", + "description": "Delete a workflow by ID", + "operationId": "delete_workflow_wf__workflow_id___delete", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + } + } + ], + "responses": { + "204": { + "description": "Successful Response" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/wf/{workflow_id}/reset/": { + "put": { + "tags": ["workflows"], + "summary": "Reset Workflow", + "description": "Reset a workflow to its 
initial state", + "operationId": "reset_workflow_wf__workflow_id__reset__put", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WorkflowResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/wf/{workflow_id}/duplicate/": { + "post": { + "tags": ["workflows"], + "summary": "Duplicate Workflow", + "description": "Duplicate a workflow by ID", + "operationId": "duplicate_workflow_wf__workflow_id__duplicate__post", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WorkflowResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/wf/{workflow_id}/output_variables/": { + "get": { + "tags": ["workflows"], + "summary": "Get Workflow Output Variables", + "description": "Get the output variables (leaf nodes) of a workflow", + "operationId": "get_workflow_output_variables_wf__workflow_id__output_variables__get", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "type": "object", + 
"additionalProperties": { + "type": "string" + } + }, + "title": "Response Get Workflow Output Variables Wf Workflow Id Output Variables Get" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/wf/upload_test_files/": { + "post": { + "tags": ["workflows"], + "summary": "Upload Test Files", + "description": "Upload test files for a specific node in a workflow", + "operationId": "upload_test_files_wf_upload_test_files__post", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_upload_test_files_wf_upload_test_files__post" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "additionalProperties": { + "items": { + "type": "string" + }, + "type": "array" + }, + "type": "object", + "title": "Response Upload Test Files Wf Upload Test Files Post" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/wf/{workflow_id}/versions/": { + "get": { + "tags": ["workflows", "workflows"], + "summary": "Get Workflow Versions", + "description": "Get all versions of a workflow", + "operationId": "get_workflow_versions_wf__workflow_id__versions__get", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + } + }, + { + "name": "page", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "minimum": 1, + "default": 1, + "title": "Page" + } + }, + { + "name": "page_size", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "maximum": 100, + "minimum": 1, + "default": 10, + "title": "Page Size" + } 
+ } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/WorkflowVersionResponseSchema" + }, + "title": "Response Get Workflow Versions Wf Workflow Id Versions Get" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/wf/{workflow_id}/run/": { + "post": { + "tags": ["workflow runs"], + "summary": "Run Workflow Blocking", + "description": "Run a workflow and return the outputs", + "operationId": "run_workflow_blocking_wf__workflow_id__run__post", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + } + }, + { + "name": "run_type", + "in": "query", + "required": false, + "schema": { + "type": "string", + "default": "interactive", + "title": "Run Type" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/StartRunRequestSchema" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "object", + "title": "Response Run Workflow Blocking Wf Workflow Id Run Post" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/wf/{workflow_id}/start_run/": { + "post": { + "tags": ["workflow runs"], + "summary": "Run Workflow Non Blocking", + "description": "Start a non-blocking workflow run and return the run details", + "operationId": "run_workflow_non_blocking_wf__workflow_id__start_run__post", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { 
+ "type": "string", + "title": "Workflow Id" + } + }, + { + "name": "run_type", + "in": "query", + "required": false, + "schema": { + "type": "string", + "default": "interactive", + "title": "Run Type" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/StartRunRequestSchema" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RunResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/wf/{workflow_id}/run_partial/": { + "post": { + "tags": ["workflow runs"], + "summary": "Run Partial Workflow", + "description": "Run a partial workflow and return the outputs", + "operationId": "run_partial_workflow_wf__workflow_id__run_partial__post", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PartialRunRequestSchema" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "object", + "title": "Response Run Partial Workflow Wf Workflow Id Run Partial Post" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/wf/{workflow_id}/start_batch_run/": { + "post": { + "tags": ["workflow runs"], + "summary": "Batch Run Workflow Non Blocking", + "description": "Start a batch run of a workflow over a dataset and return the run details", + "operationId": 
"batch_run_workflow_non_blocking_wf__workflow_id__start_batch_run__post", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BatchRunRequestSchema" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RunResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/wf/{workflow_id}/runs/": { + "get": { + "tags": ["workflow runs"], + "summary": "List Runs", + "description": "List all runs of a workflow", + "operationId": "list_runs_wf__workflow_id__runs__get", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + } + }, + { + "name": "page", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "minimum": 1, + "default": 1, + "title": "Page" + } + }, + { + "name": "page_size", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "maximum": 100, + "minimum": 1, + "default": 10, + "title": "Page Size" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/RunResponseSchema" + }, + "title": "Response List Runs Wf Workflow Id Runs Get" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/wf/cancel_workflow/{run_id}/": { + "post": { + "tags": ["workflow 
runs"], + "summary": "Cancel Workflow", + "description": "Cancel a workflow that is awaiting human approval", + "operationId": "cancel_workflow_wf_cancel_workflow__run_id___post", + "parameters": [ + { + "name": "run_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Run Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RunResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/ds/": { + "post": { + "tags": ["datasets"], + "summary": "Upload Dataset", + "description": "Upload a new dataset", + "operationId": "upload_dataset_ds__post", + "parameters": [ + { + "name": "name", + "in": "query", + "required": true, + "schema": { + "type": "string", + "title": "Name" + } + }, + { + "name": "description", + "in": "query", + "required": false, + "schema": { + "type": "string", + "default": "", + "title": "Description" + } + } + ], + "requestBody": { + "required": true, + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_upload_dataset_ds__post" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DatasetResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "get": { + "tags": ["datasets"], + "summary": "List Datasets", + "description": "List all datasets", + "operationId": "list_datasets_ds__get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { 
+ "$ref": "#/components/schemas/DatasetResponseSchema" + }, + "title": "Response List Datasets Ds Get" + } + } + } + } + } + } + }, + "/ds/{dataset_id}/": { + "get": { + "tags": ["datasets"], + "summary": "Get Dataset", + "description": "Get a dataset by ID", + "operationId": "get_dataset_ds__dataset_id___get", + "parameters": [ + { + "name": "dataset_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Dataset Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DatasetResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "delete": { + "tags": ["datasets"], + "summary": "Delete Dataset", + "description": "Delete a dataset by ID", + "operationId": "delete_dataset_ds__dataset_id___delete", + "parameters": [ + { + "name": "dataset_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Dataset Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/ds/{dataset_id}/list_runs/": { + "get": { + "tags": ["datasets"], + "summary": "List Dataset Runs", + "description": "List all runs that used this dataset", + "operationId": "list_dataset_runs_ds__dataset_id__list_runs__get", + "parameters": [ + { + "name": "dataset_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Dataset Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": 
"array", + "items": { + "$ref": "#/components/schemas/RunResponseSchema" + }, + "title": "Response List Dataset Runs Ds Dataset Id List Runs Get" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/run/": { + "get": { + "tags": ["runs"], + "summary": "List Runs", + "description": "List all runs", + "operationId": "list_runs_run__get", + "parameters": [ + { + "name": "page", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "minimum": 1, + "default": 1, + "title": "Page" + } + }, + { + "name": "page_size", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "maximum": 100, + "minimum": 1, + "default": 10, + "title": "Page Size" + } + }, + { + "name": "parent_only", + "in": "query", + "required": false, + "schema": { + "type": "boolean", + "default": true, + "title": "Parent Only" + } + }, + { + "name": "run_type", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Run Type" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/RunResponseSchema" + }, + "title": "Response List Runs Run Get" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/run/{run_id}/": { + "get": { + "tags": ["runs"], + "summary": "Get Run", + "operationId": "get_run_run__run_id___get", + "parameters": [ + { + "name": "run_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Run Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + 
"application/json": { + "schema": { + "$ref": "#/components/schemas/RunResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/run/{run_id}/status/": { + "get": { + "tags": ["runs"], + "summary": "Get Run Status", + "operationId": "get_run_status_run__run_id__status__get", + "parameters": [ + { + "name": "run_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Run Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RunResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/of/": { + "get": { + "tags": ["output files"], + "summary": "List Output Files", + "description": "List all output files", + "operationId": "list_output_files_of__get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/OutputFileResponseSchema" + }, + "type": "array", + "title": "Response List Output Files Of Get" + } + } + } + } + } + } + }, + "/of/{output_file_id}/": { + "get": { + "tags": ["output files"], + "summary": "Get Output File", + "description": "Get an output file by ID", + "operationId": "get_output_file_of__output_file_id___get", + "parameters": [ + { + "name": "output_file_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Output File Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OutputFileResponseSchema" + } + } + } + }, + "422": { + 
"description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "delete": { + "tags": ["output files"], + "summary": "Delete Output File", + "description": "Delete an output file by ID", + "operationId": "delete_output_file_of__output_file_id___delete", + "parameters": [ + { + "name": "output_file_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Output File Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/of/{output_file_id}/download/": { + "get": { + "tags": ["output files"], + "summary": "Download Output File", + "description": "Download an output file by ID", + "operationId": "download_output_file_of__output_file_id__download__get", + "parameters": [ + { + "name": "output_file_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Output File Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/env-mgmt/providers": { + "get": { + "tags": ["environment management"], + "summary": "Get Providers", + "description": "Get all provider configurations", + "operationId": "get_providers_env_mgmt_providers_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + } + } + } + }, + "/env-mgmt/": { + "get": { + "tags": ["environment management"], + "summary": 
"List Api Keys", + "description": "Get a list of all environment variable names", + "operationId": "list_api_keys_env_mgmt__get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + } + } + }, + "post": { + "tags": ["environment management"], + "summary": "Set Api Key", + "description": "Add or update an environment variable", + "operationId": "set_api_key_env_mgmt__post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/APIKey" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/env-mgmt/{name}": { + "get": { + "tags": ["environment management"], + "summary": "Get Api Key", + "description": "Get the masked value of a specific environment variable", + "operationId": "get_api_key_env_mgmt__name__get", + "parameters": [ + { + "name": "name", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Name" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "delete": { + "tags": ["environment management"], + "summary": "Delete Api Key", + "description": "Delete an environment variable", + "operationId": "delete_api_key_env_mgmt__name__delete", + "parameters": [ + { + "name": "name", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Name" + } + } + ], + "responses": { + "200": { + "description": "Successful 
Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/env-mgmt/embedding-models/": { + "get": { + "tags": ["environment management"], + "summary": "Get Embedding Models", + "description": "Get all available embedding models and their configurations.", + "operationId": "get_embedding_models_env_mgmt_embedding_models__get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "additionalProperties": { + "$ref": "#/components/schemas/EmbeddingModelConfig" + }, + "type": "object", + "title": "Response Get Embedding Models Env Mgmt Embedding Models Get" + } + } + } + } + } + } + }, + "/env-mgmt/vector-stores/": { + "get": { + "tags": ["environment management"], + "summary": "Get Vector Stores Endpoint", + "description": "Get all available vector stores and their configurations.", + "operationId": "get_vector_stores_endpoint_env_mgmt_vector_stores__get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "additionalProperties": { + "$ref": "#/components/schemas/VectorStoreConfig" + }, + "type": "object", + "title": "Response Get Vector Stores Endpoint Env Mgmt Vector Stores Get" + } + } + } + } + } + } + }, + "/templates/": { + "get": { + "tags": ["templates"], + "summary": "List Templates", + "description": "List all available templates", + "operationId": "list_templates_templates__get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/TemplateSchema" + }, + "type": "array", + "title": "Response List Templates Templates Get" + } + } + } + } + } + } + }, + "/templates/instantiate/": { + "post": { + "tags": 
["templates"], + "summary": "Instantiate Template", + "description": "Instantiate a new workflow from a template", + "operationId": "instantiate_template_templates_instantiate__post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TemplateSchema" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WorkflowResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/api/v1/chat/completions": { + "post": { + "tags": ["openai compatible"], + "summary": "Chat Completions", + "description": "OpenAI-compatible chat completions endpoint", + "operationId": "chat_completions_api_v1_chat_completions_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ChatCompletionRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ChatCompletionResponse" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/evals/": { + "get": { + "tags": ["evaluations"], + "summary": "List Evals", + "description": "List all available evals", + "operationId": "list_evals_evals__get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { + "type": "object" + }, + "type": "array", + "title": "Response List Evals Evals Get" + } + } + } + } + } + } + }, + "/evals/launch/": { + "post": { + "tags": ["evaluations"], + "summary": "Launch 
Eval", + "description": "Launch an eval job with detailed validation and workflow integration", + "operationId": "launch_eval_evals_launch__post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/EvalRunRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/EvalRunResponse" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/evals/runs/{eval_run_id}": { + "get": { + "tags": ["evaluations"], + "summary": "Get Eval Run Status", + "description": "Get the status of an eval run", + "operationId": "get_eval_run_status_evals_runs__eval_run_id__get", + "parameters": [ + { + "name": "eval_run_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Eval Run Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/EvalRunResponse" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/evals/runs/": { + "get": { + "tags": ["evaluations"], + "summary": "List Eval Runs", + "description": "List all eval runs", + "operationId": "list_eval_runs_evals_runs__get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/EvalRunResponse" + }, + "type": "array", + "title": "Response List Eval Runs Evals Runs Get" + } + } + } + } + } + } + }, + "/google/store_token/": { + "post": { + "tags": ["google auth"], + "summary": "Store Token", + 
"operationId": "store_token_google_store_token__post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TokenInput" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/google/validate_token/": { + "get": { + "tags": ["google auth"], + "summary": "Validate Token", + "operationId": "validate_token_google_validate_token__get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + } + } + } + }, + "/rag/collections/": { + "get": { + "tags": ["rag"], + "summary": "List Document Collections", + "description": "List all document collections", + "operationId": "list_document_collections_rag_collections__get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/DocumentCollectionResponseSchema" + }, + "type": "array", + "title": "Response List Document Collections Rag Collections Get" + } + } + } + } + } + }, + "post": { + "tags": ["rag"], + "summary": "Create Document Collection", + "description": "Create a new document collection from uploaded files and metadata", + "operationId": "create_document_collection_rag_collections__post", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_create_document_collection_rag_collections__post" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DocumentCollectionResponseSchema" + } + } + } + }, + 
"422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/rag/indices/": { + "get": { + "tags": ["rag"], + "summary": "List Vector Indices", + "description": "List all vector indices", + "operationId": "list_vector_indices_rag_indices__get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/VectorIndexResponseSchema" + }, + "type": "array", + "title": "Response List Vector Indices Rag Indices Get" + } + } + } + } + } + }, + "post": { + "tags": ["rag"], + "summary": "Create Vector Index", + "description": "Create a new vector index from a document collection", + "operationId": "create_vector_index_rag_indices__post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/VectorIndexCreateSchema" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/VectorIndexResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/rag/indices/{index_id}/": { + "delete": { + "tags": ["rag"], + "summary": "Delete Vector Index", + "description": "Delete a vector index and its associated data", + "operationId": "delete_vector_index_rag_indices__index_id___delete", + "parameters": [ + { + "name": "index_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Index Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": "Validation Error", + "content": { + 
"application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "get": { + "tags": ["rag"], + "summary": "Get Vector Index", + "description": "Get details of a specific vector index", + "operationId": "get_vector_index_rag_indices__index_id___get", + "parameters": [ + { + "name": "index_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Index Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/VectorIndexResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/rag/collections/{collection_id}/": { + "get": { + "tags": ["rag"], + "summary": "Get Document Collection", + "description": "Get document collection details.", + "operationId": "get_document_collection_rag_collections__collection_id___get", + "parameters": [ + { + "name": "collection_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Collection Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DocumentCollectionResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "delete": { + "tags": ["rag"], + "summary": "Delete Document Collection", + "description": "Delete a document collection and its associated data", + "operationId": "delete_document_collection_rag_collections__collection_id___delete", + "parameters": [ + { + "name": "collection_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Collection Id" 
+ } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/rag/collections/{collection_id}/progress/": { + "get": { + "tags": ["rag"], + "summary": "Get Collection Progress", + "description": "Get document collection processing progress.", + "operationId": "get_collection_progress_rag_collections__collection_id__progress__get", + "parameters": [ + { + "name": "collection_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Collection Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProcessingProgressSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/rag/indices/{index_id}/progress/": { + "get": { + "tags": ["rag"], + "summary": "Get Index Progress", + "description": "Get the processing progress of a vector index", + "operationId": "get_index_progress_rag_indices__index_id__progress__get", + "parameters": [ + { + "name": "index_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Index Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProcessingProgressSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/rag/collections/{collection_id}/documents/": { + "post": { + "tags": 
["rag"], + "summary": "Add Documents To Collection", + "description": "Add documents to an existing collection.", + "operationId": "add_documents_to_collection_rag_collections__collection_id__documents__post", + "parameters": [ + { + "name": "collection_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Collection Id" + } + } + ], + "requestBody": { + "required": true, + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_add_documents_to_collection_rag_collections__collection_id__documents__post" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DocumentCollectionResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "get": { + "tags": ["rag"], + "summary": "Get Collection Documents", + "description": "Get all documents and their chunks for a collection.", + "operationId": "get_collection_documents_rag_collections__collection_id__documents__get", + "parameters": [ + { + "name": "collection_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Collection Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/DocumentWithChunksSchema" + }, + "title": "Response Get Collection Documents Rag Collections Collection Id Documents Get" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/rag/collections/{collection_id}/documents/{document_id}/": { + "delete": { + "tags": ["rag"], + "summary": "Delete 
Document From Collection", + "description": "Delete a document from a collection.", + "operationId": "delete_document_from_collection_rag_collections__collection_id__documents__document_id___delete", + "parameters": [ + { + "name": "collection_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Collection Id" + } + }, + { + "name": "document_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Document Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/rag/collections/preview_chunk/": { + "post": { + "tags": ["rag"], + "summary": "Preview Chunk", + "description": "Preview how a document would be chunked with given configuration", + "operationId": "preview_chunk_rag_collections_preview_chunk__post", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_preview_chunk_rag_collections_preview_chunk__post" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "object", + "title": "Response Preview Chunk Rag Collections Preview Chunk Post" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/rag/indices/{index_id}/retrieve/": { + "post": { + "tags": ["rag"], + "summary": "Retrieve From Index", + "description": "Retrieve relevant chunks from a vector index based on a query", + "operationId": "retrieve_from_index_rag_indices__index_id__retrieve__post", + "parameters": [ + { + "name": "index_id", + "in": 
"path", + "required": true, + "schema": { + "type": "string", + "title": "Index Id" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RetrievalRequestSchema" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RetrievalResponseSchema" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/files/{workflow_id}": { + "get": { + "tags": ["files"], + "summary": "List Workflow Files", + "description": "List all files for a specific workflow", + "operationId": "list_workflow_files_files__workflow_id__get", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/FileResponseSchema" + }, + "title": "Response List Workflow Files Files Workflow Id Get" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "delete": { + "tags": ["files"], + "summary": "Delete Workflow Files", + "description": "Delete all files for a workflow", + "operationId": "delete_workflow_files_files__workflow_id__delete", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": 
"Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/files/": { + "get": { + "tags": ["files"], + "summary": "List All Files", + "description": "List all files across all workflows", + "operationId": "list_all_files_files__get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/FileResponseSchema" + }, + "type": "array", + "title": "Response List All Files Files Get" + } + } + } + } + } + } + }, + "/files/{workflow_id}/{filename}": { + "delete": { + "tags": ["files"], + "summary": "Delete File", + "description": "Delete a specific file", + "operationId": "delete_file_files__workflow_id___filename__delete", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + } + }, + { + "name": "filename", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Filename" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/files/{file_path}": { + "get": { + "tags": ["files"], + "summary": "Get File", + "description": "Get a specific file", + "operationId": "get_file_files__file_path__get", + "parameters": [ + { + "name": "file_path", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "File Path" + } + } + ], + "responses": { + "200": { + "description": "Successful Response" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } 
+ } + } + } + }, + "/ai/generate_schema/": { + "post": { + "tags": ["ai"], + "summary": "Generate Schema", + "operationId": "generate_schema_ai_generate_schema__post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SchemaGenerationRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "object", + "title": "Response Generate Schema Ai Generate Schema Post" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/ai/generate_message/": { + "post": { + "tags": ["ai"], + "summary": "Generate Message", + "operationId": "generate_message_ai_generate_message__post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/MessageGenerationRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "additionalProperties": { + "type": "string" + }, + "type": "object", + "title": "Response Generate Message Ai Generate Message Post" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/user/": { + "post": { + "tags": ["users"], + "summary": "Create User", + "description": "Create a new user.", + "operationId": "create_user_user__post", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UserCreate" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UserResponse" + } + 
} + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "get": { + "tags": ["users"], + "summary": "List Users", + "description": "List users with pagination.", + "operationId": "list_users_user__get", + "parameters": [ + { + "name": "skip", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "minimum": 0, + "default": 0, + "title": "Skip" + } + }, + { + "name": "limit", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "maximum": 100, + "minimum": 1, + "default": 10, + "title": "Limit" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UserListResponse" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/user/{user_id}/": { + "get": { + "tags": ["users"], + "summary": "Get User", + "description": "Get a specific user by ID.", + "operationId": "get_user_user__user_id___get", + "parameters": [ + { + "name": "user_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "User Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UserResponse" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "patch": { + "tags": ["users"], + "summary": "Update User", + "description": "Update a user.", + "operationId": "update_user_user__user_id___patch", + "parameters": [ + { + "name": "user_id", + "in": "path", + "required": true, + "schema": { + 
"type": "string", + "title": "User Id" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UserUpdate" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UserResponse" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "delete": { + "tags": ["users"], + "summary": "Delete User", + "description": "Delete a user.", + "operationId": "delete_user_user__user_id___delete", + "parameters": [ + { + "name": "user_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "User Id" + } + } + ], + "responses": { + "204": { + "description": "Successful Response" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/session/": { + "post": { + "tags": ["sessions"], + "summary": "Create Session", + "description": "Create a new session.", + "operationId": "create_session_session__post", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SessionCreate" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SessionResponse" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "get": { + "tags": ["sessions"], + "summary": "List Sessions", + "description": "List sessions with pagination and optional user filtering.", + "operationId": 
"list_sessions_session__get", + "parameters": [ + { + "name": "skip", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "minimum": 0, + "default": 0, + "title": "Skip" + } + }, + { + "name": "limit", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "maximum": 100, + "minimum": 1, + "default": 10, + "title": "Limit" + } + }, + { + "name": "user_id", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "User Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SessionListResponse" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/session/{session_id}/": { + "get": { + "tags": ["sessions"], + "summary": "Get Session", + "description": "Get a specific session by ID.", + "operationId": "get_session_session__session_id___get", + "parameters": [ + { + "name": "session_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Session Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SessionResponse" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "delete": { + "tags": ["sessions"], + "summary": "Delete Session", + "description": "Delete a session.", + "operationId": "delete_session_session__session_id___delete", + "parameters": [ + { + "name": "session_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Session Id" + } + } + ], + "responses": { + 
"204": { + "description": "Successful Response" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/session/test/": { + "post": { + "tags": ["sessions"], + "summary": "Create Test Session", + "description": "Create or reuse a test user and session.\n\nIf a test user exists, it will be reused.\nIf an empty test session exists for the same workflow, it will be reused.\nOtherwise, a new session will be created.", + "operationId": "create_test_session_session_test__post", + "parameters": [ + { + "name": "workflow_id", + "in": "query", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SessionResponse" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "APIKey": { + "properties": { + "name": { + "type": "string", + "title": "Name" + }, + "value": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Value" + } + }, + "type": "object", + "required": ["name"], + "title": "APIKey" + }, + "BatchRunRequestSchema": { + "properties": { + "dataset_id": { + "type": "string", + "title": "Dataset Id" + }, + "mini_batch_size": { + "type": "integer", + "title": "Mini Batch Size", + "default": 10 + } + }, + "type": "object", + "required": ["dataset_id"], + "title": "BatchRunRequestSchema" + }, + "Body_add_documents_to_collection_rag_collections__collection_id__documents__post": { + "properties": { + "files": { + "items": { + "type": "string", + "format": "binary" + }, + "type": "array", + "title": "Files" + } + }, + "type": "object", + 
"required": ["files"], + "title": "Body_add_documents_to_collection_rag_collections__collection_id__documents__post" + }, + "Body_create_document_collection_rag_collections__post": { + "properties": { + "files": { + "items": { + "type": "string", + "format": "binary" + }, + "type": "array", + "title": "Files" + }, + "metadata": { + "type": "string", + "title": "Metadata" + } + }, + "type": "object", + "required": ["metadata"], + "title": "Body_create_document_collection_rag_collections__post" + }, + "Body_preview_chunk_rag_collections_preview_chunk__post": { + "properties": { + "file": { + "type": "string", + "format": "binary", + "title": "File" + }, + "chunking_config": { + "type": "string", + "title": "Chunking Config" + } + }, + "type": "object", + "required": ["file", "chunking_config"], + "title": "Body_preview_chunk_rag_collections_preview_chunk__post" + }, + "Body_upload_dataset_ds__post": { + "properties": { + "file": { + "type": "string", + "format": "binary", + "title": "File" + } + }, + "type": "object", + "required": ["file"], + "title": "Body_upload_dataset_ds__post" + }, + "Body_upload_test_files_wf_upload_test_files__post": { + "properties": { + "workflow_id": { + "type": "string", + "title": "Workflow Id" + }, + "files": { + "items": { + "type": "string", + "format": "binary" + }, + "type": "array", + "title": "Files" + }, + "node_id": { + "type": "string", + "title": "Node Id" + } + }, + "type": "object", + "required": ["workflow_id", "files", "node_id"], + "title": "Body_upload_test_files_wf_upload_test_files__post" + }, + "ChatCompletionRequest": { + "properties": { + "model": { + "type": "string", + "title": "Model" + }, + "messages": { + "items": { + "type": "object" + }, + "type": "array", + "title": "Messages" + }, + "functions": { + "anyOf": [ + { + "items": { + "type": "object" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Functions" + }, + "function_call": { + "anyOf": [ + { + "type": "object" + }, + { + "type": 
"string" + }, + { + "type": "null" + } + ], + "title": "Function Call" + }, + "temperature": { + "type": "number", + "title": "Temperature", + "default": 0.7 + }, + "top_p": { + "type": "number", + "title": "Top P", + "default": 0.9 + }, + "n": { + "type": "integer", + "title": "N", + "default": 1 + }, + "stream": { + "type": "boolean", + "title": "Stream", + "default": false + }, + "stop": { + "anyOf": [ + { + "type": "string" + }, + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Stop" + }, + "max_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Max Tokens" + }, + "presence_penalty": { + "type": "number", + "title": "Presence Penalty", + "default": 0 + }, + "frequency_penalty": { + "type": "number", + "title": "Frequency Penalty", + "default": 0 + }, + "logit_bias": { + "anyOf": [ + { + "additionalProperties": { + "type": "number" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Logit Bias" + }, + "user": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "User" + } + }, + "type": "object", + "required": ["model", "messages"], + "title": "ChatCompletionRequest" + }, + "ChatCompletionResponse": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "object": { + "type": "string", + "title": "Object" + }, + "created": { + "type": "integer", + "title": "Created" + }, + "model": { + "type": "string", + "title": "Model" + }, + "choices": { + "items": { + "type": "object" + }, + "type": "array", + "title": "Choices" + }, + "usage": { + "additionalProperties": { + "type": "integer" + }, + "type": "object", + "title": "Usage" + } + }, + "type": "object", + "required": ["id", "object", "created", "model", "choices", "usage"], + "title": "ChatCompletionResponse" + }, + "ChunkMetadataSchema": { + "properties": { + "document_id": { + "type": "string", + "title": "Document Id" + }, + "chunk_id": { + 
"type": "string", + "title": "Chunk Id" + }, + "document_title": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Document Title" + }, + "page_number": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Page Number" + }, + "chunk_number": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Chunk Number" + } + }, + "type": "object", + "required": ["document_id", "chunk_id"], + "title": "ChunkMetadataSchema", + "description": "Schema for chunk metadata in retrieval response" + }, + "CohereEncodingFormat": { + "type": "string", + "enum": ["float", "int8", "uint8", "binary", "ubinary"], + "title": "CohereEncodingFormat" + }, + "DatasetResponseSchema": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "name": { + "type": "string", + "title": "Name" + }, + "description": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Description" + }, + "filename": { + "type": "string", + "title": "Filename" + }, + "created_at": { + "type": "string", + "format": "date-time", + "title": "Created At" + }, + "updated_at": { + "type": "string", + "format": "date-time", + "title": "Updated At" + } + }, + "type": "object", + "required": ["id", "name", "description", "filename", "created_at", "updated_at"], + "title": "DatasetResponseSchema" + }, + "DocumentChunkMetadataSchema": { + "properties": { + "source": { + "$ref": "#/components/schemas/Source", + "default": "text" + }, + "source_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Source Id" + }, + "created_at": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Created At" + }, + "author": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Author" + }, + "title": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Title" + }, + 
"custom_metadata": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Custom Metadata" + }, + "document_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Document Id" + }, + "chunk_index": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Chunk Index" + } + }, + "type": "object", + "title": "DocumentChunkMetadataSchema", + "description": "Metadata for a document chunk." + }, + "DocumentChunkSchema": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "text": { + "type": "string", + "title": "Text" + }, + "metadata": { + "$ref": "#/components/schemas/DocumentChunkMetadataSchema" + }, + "embedding": { + "anyOf": [ + { + "items": { + "type": "number" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Embedding" + } + }, + "type": "object", + "required": ["id", "text", "metadata"], + "title": "DocumentChunkSchema", + "description": "A chunk of a document with its metadata and embedding." 
+ }, + "DocumentCollectionResponseSchema": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "name": { + "type": "string", + "title": "Name" + }, + "description": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Description" + }, + "status": { + "type": "string", + "title": "Status" + }, + "created_at": { + "type": "string", + "title": "Created At" + }, + "updated_at": { + "type": "string", + "title": "Updated At" + }, + "document_count": { + "type": "integer", + "title": "Document Count" + }, + "chunk_count": { + "type": "integer", + "title": "Chunk Count" + }, + "error_message": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Error Message" + } + }, + "type": "object", + "required": ["id", "name", "status", "created_at", "updated_at", "document_count", "chunk_count"], + "title": "DocumentCollectionResponseSchema", + "description": "Response model for document collection operations" + }, + "DocumentMetadataSchema": { + "properties": { + "source": { + "$ref": "#/components/schemas/Source", + "default": "text" + }, + "source_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Source Id" + }, + "created_at": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Created At" + }, + "author": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Author" + }, + "title": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Title" + }, + "custom_metadata": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Custom Metadata" + } + }, + "type": "object", + "title": "DocumentMetadataSchema", + "description": "Metadata for a document." 
+ }, + "DocumentWithChunksSchema": { + "properties": { + "id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Id" + }, + "text": { + "type": "string", + "title": "Text" + }, + "metadata": { + "anyOf": [ + { + "$ref": "#/components/schemas/DocumentMetadataSchema" + }, + { + "type": "null" + } + ] + }, + "chunks": { + "items": { + "$ref": "#/components/schemas/DocumentChunkSchema" + }, + "type": "array", + "title": "Chunks" + } + }, + "type": "object", + "required": ["text"], + "title": "DocumentWithChunksSchema", + "description": "A document with its chunks." + }, + "EmbeddingConfigSchema": { + "properties": { + "model": { + "type": "string", + "title": "Model" + }, + "vector_db": { + "type": "string", + "title": "Vector Db" + }, + "search_strategy": { + "type": "string", + "title": "Search Strategy" + }, + "semantic_weight": { + "anyOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ], + "title": "Semantic Weight" + }, + "keyword_weight": { + "anyOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ], + "title": "Keyword Weight" + }, + "top_k": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Top K" + }, + "score_threshold": { + "anyOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ], + "title": "Score Threshold" + } + }, + "type": "object", + "required": ["model", "vector_db", "search_strategy"], + "title": "EmbeddingConfigSchema" + }, + "EmbeddingModelConfig": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "provider": { + "$ref": "#/components/schemas/EmbeddingProvider" + }, + "name": { + "type": "string", + "title": "Name" + }, + "dimensions": { + "type": "integer", + "title": "Dimensions", + "default": 1536 + }, + "max_input_length": { + "type": "integer", + "title": "Max Input Length", + "default": 8191 + }, + "supported_encoding_formats": { + "anyOf": [ + { + "items": { + "$ref": "#/components/schemas/CohereEncodingFormat" + }, 
+ "type": "array" + }, + { + "type": "null" + } + ], + "title": "Supported Encoding Formats" + }, + "required_env_vars": { + "items": { + "type": "string" + }, + "type": "array", + "title": "Required Env Vars" + } + }, + "type": "object", + "required": ["id", "provider", "name"], + "title": "EmbeddingModelConfig" + }, + "EmbeddingProvider": { + "type": "string", + "enum": ["OpenAI", "AzureOpenAI", "Cohere", "Voyage", "Mistral", "Gemini"], + "title": "EmbeddingProvider" + }, + "EvalRunRequest": { + "properties": { + "workflow_id": { + "type": "string", + "title": "Workflow Id" + }, + "eval_name": { + "type": "string", + "title": "Eval Name" + }, + "output_variable": { + "type": "string", + "title": "Output Variable" + }, + "num_samples": { + "type": "integer", + "title": "Num Samples", + "default": 10 + } + }, + "type": "object", + "required": ["workflow_id", "eval_name", "output_variable"], + "title": "EvalRunRequest" + }, + "EvalRunResponse": { + "properties": { + "run_id": { + "type": "string", + "title": "Run Id" + }, + "eval_name": { + "type": "string", + "title": "Eval Name" + }, + "workflow_id": { + "type": "string", + "title": "Workflow Id" + }, + "status": { + "$ref": "#/components/schemas/EvalRunStatusEnum" + }, + "start_time": { + "anyOf": [ + { + "type": "string", + "format": "date-time" + }, + { + "type": "null" + } + ], + "title": "Start Time" + }, + "end_time": { + "anyOf": [ + { + "type": "string", + "format": "date-time" + }, + { + "type": "null" + } + ], + "title": "End Time" + }, + "results": { + "anyOf": [ + { + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Results" + } + }, + "type": "object", + "required": ["run_id", "eval_name", "workflow_id", "status", "start_time", "end_time"], + "title": "EvalRunResponse" + }, + "EvalRunStatusEnum": { + "type": "string", + "enum": ["PENDING", "RUNNING", "COMPLETED", "FAILED"], + "title": "EvalRunStatusEnum" + }, + "FileResponseSchema": { + "properties": { + "name": { + "type": "string", + 
"title": "Name" + }, + "path": { + "type": "string", + "title": "Path" + }, + "size": { + "type": "integer", + "title": "Size" + }, + "created": { + "type": "string", + "format": "date-time", + "title": "Created" + }, + "workflow_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Workflow Id" + } + }, + "type": "object", + "required": ["name", "path", "size", "created"], + "title": "FileResponseSchema" + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "type": "array", + "title": "Detail" + } + }, + "type": "object", + "title": "HTTPValidationError" + }, + "MessageGenerationRequest": { + "properties": { + "description": { + "type": "string", + "title": "Description" + }, + "message_type": { + "type": "string", + "enum": ["system", "user"], + "title": "Message Type" + }, + "existing_message": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Existing Message" + }, + "context": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Context" + }, + "available_variables": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Available Variables" + } + }, + "type": "object", + "required": ["description", "message_type"], + "title": "MessageGenerationRequest" + }, + "MessageResponse": { + "properties": { + "content": { + "type": "object", + "title": "Content" + }, + "id": { + "type": "string", + "title": "Id" + }, + "session_id": { + "type": "string", + "title": "Session Id" + }, + "run_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Run Id" + }, + "created_at": { + "type": "string", + "format": "date-time", + "title": "Created At" + }, + "updated_at": { + "type": "string", + "format": "date-time", + "title": "Updated At" + } + }, + "type": "object", + "required": ["content", "id", 
"session_id", "run_id", "created_at", "updated_at"], + "title": "MessageResponse" + }, + "OutputFileResponseSchema": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "file_name": { + "type": "string", + "title": "File Name" + }, + "created_at": { + "type": "string", + "format": "date-time", + "title": "Created At" + }, + "updated_at": { + "type": "string", + "format": "date-time", + "title": "Updated At" + } + }, + "type": "object", + "required": ["id", "file_name", "created_at", "updated_at"], + "title": "OutputFileResponseSchema" + }, + "PartialRunRequestSchema": { + "properties": { + "node_id": { + "type": "string", + "title": "Node Id" + }, + "rerun_predecessors": { + "type": "boolean", + "title": "Rerun Predecessors", + "default": false + }, + "initial_inputs": { + "anyOf": [ + { + "additionalProperties": { + "type": "object" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Initial Inputs" + }, + "partial_outputs": { + "anyOf": [ + { + "additionalProperties": { + "anyOf": [ + { + "type": "object" + }, + { + "items": { + "type": "object" + }, + "type": "array" + } + ] + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Partial Outputs" + } + }, + "type": "object", + "required": ["node_id"], + "title": "PartialRunRequestSchema" + }, + "PauseAction": { + "type": "string", + "enum": ["APPROVE", "DECLINE", "OVERRIDE"], + "title": "PauseAction", + "description": "Actions that can be taken on a paused workflow." 
+ }, + "PauseHistoryResponseSchema": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "run_id": { + "type": "string", + "title": "Run Id" + }, + "node_id": { + "type": "string", + "title": "Node Id" + }, + "pause_message": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Pause Message" + }, + "pause_time": { + "type": "string", + "format": "date-time", + "title": "Pause Time" + }, + "resume_time": { + "anyOf": [ + { + "type": "string", + "format": "date-time" + }, + { + "type": "null" + } + ], + "title": "Resume Time" + }, + "resume_user_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Resume User Id" + }, + "resume_action": { + "anyOf": [ + { + "$ref": "#/components/schemas/PauseAction" + }, + { + "type": "null" + } + ] + }, + "input_data": { + "anyOf": [ + { + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Input Data" + }, + "comments": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Comments" + } + }, + "type": "object", + "required": [ + "id", + "run_id", + "node_id", + "pause_message", + "pause_time", + "resume_time", + "resume_user_id", + "resume_action", + "input_data", + "comments" + ], + "title": "PauseHistoryResponseSchema", + "description": "Schema for pause information from a node's output." + }, + "PausedWorkflowResponseSchema": { + "properties": { + "run": { + "$ref": "#/components/schemas/RunResponseSchema" + }, + "current_pause": { + "$ref": "#/components/schemas/PauseHistoryResponseSchema" + }, + "workflow": { + "$ref": "#/components/schemas/WorkflowDefinitionSchema-Output" + } + }, + "type": "object", + "required": ["run", "current_pause", "workflow"], + "title": "PausedWorkflowResponseSchema", + "description": "Schema for a paused workflow, including its current pause state." 
+ }, + "ProcessingProgressSchema": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "status": { + "type": "string", + "title": "Status", + "default": "pending" + }, + "progress": { + "type": "number", + "title": "Progress", + "default": 0 + }, + "current_step": { + "type": "string", + "title": "Current Step", + "default": "initializing" + }, + "total_files": { + "type": "integer", + "title": "Total Files", + "default": 0 + }, + "processed_files": { + "type": "integer", + "title": "Processed Files", + "default": 0 + }, + "total_chunks": { + "type": "integer", + "title": "Total Chunks", + "default": 0 + }, + "processed_chunks": { + "type": "integer", + "title": "Processed Chunks", + "default": 0 + }, + "error_message": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Error Message" + }, + "created_at": { + "type": "string", + "title": "Created At" + }, + "updated_at": { + "type": "string", + "title": "Updated At" + } + }, + "type": "object", + "required": ["id", "created_at", "updated_at"], + "title": "ProcessingProgressSchema", + "description": "Base model for tracking processing progress" + }, + "ResumeRunRequestSchema": { + "properties": { + "inputs": { + "type": "object", + "title": "Inputs" + }, + "user_id": { + "type": "string", + "title": "User Id" + }, + "action": { + "$ref": "#/components/schemas/PauseAction" + }, + "comments": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Comments" + } + }, + "type": "object", + "required": ["inputs", "user_id", "action"], + "title": "ResumeRunRequestSchema", + "description": "Schema for resuming a paused workflow run." 
+ }, + "RetrievalRequestSchema": { + "properties": { + "query": { + "type": "string", + "title": "Query" + }, + "top_k": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Top K", + "default": 5 + }, + "score_threshold": { + "anyOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ], + "title": "Score Threshold" + }, + "semantic_weight": { + "anyOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ], + "title": "Semantic Weight", + "default": 1 + }, + "keyword_weight": { + "anyOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ], + "title": "Keyword Weight" + } + }, + "type": "object", + "required": ["query"], + "title": "RetrievalRequestSchema", + "description": "Request model for retrieving from vector index" + }, + "RetrievalResponseSchema": { + "properties": { + "results": { + "items": { + "$ref": "#/components/schemas/RetrievalResultSchema" + }, + "type": "array", + "title": "Results" + }, + "total_results": { + "type": "integer", + "title": "Total Results" + } + }, + "type": "object", + "required": ["results", "total_results"], + "title": "RetrievalResponseSchema", + "description": "Response model for retrieval operations" + }, + "RetrievalResultSchema": { + "properties": { + "text": { + "type": "string", + "title": "Text" + }, + "score": { + "type": "number", + "title": "Score" + }, + "metadata": { + "$ref": "#/components/schemas/ChunkMetadataSchema" + } + }, + "type": "object", + "required": ["text", "score", "metadata"], + "title": "RetrievalResultSchema", + "description": "Schema for a single retrieval result" + }, + "RunResponseSchema": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "workflow_id": { + "type": "string", + "title": "Workflow Id" + }, + "workflow_version_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Workflow Version Id" + }, + "workflow_version": { + "anyOf": [ + { + "$ref": 
"#/components/schemas/WorkflowVersionResponseSchema" + }, + { + "type": "null" + } + ] + }, + "status": { + "$ref": "#/components/schemas/RunStatus" + }, + "start_time": { + "type": "string", + "format": "date-time", + "title": "Start Time" + }, + "end_time": { + "anyOf": [ + { + "type": "string", + "format": "date-time" + }, + { + "type": "null" + } + ], + "title": "End Time" + }, + "initial_inputs": { + "anyOf": [ + { + "additionalProperties": { + "type": "object" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Initial Inputs" + }, + "outputs": { + "anyOf": [ + { + "additionalProperties": { + "type": "object" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Outputs" + }, + "tasks": { + "items": { + "$ref": "#/components/schemas/TaskResponseSchema" + }, + "type": "array", + "title": "Tasks", + "default": [] + }, + "parent_run_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Parent Run Id" + }, + "run_type": { + "type": "string", + "title": "Run Type", + "default": "interactive" + }, + "output_file_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Output File Id" + }, + "input_dataset_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Input Dataset Id" + }, + "message": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Message" + }, + "duration": { + "anyOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ], + "title": "Duration", + "readOnly": true + }, + "percentage_complete": { + "type": "number", + "title": "Percentage Complete", + "readOnly": true + } + }, + "type": "object", + "required": ["id", "workflow_id", "status", "start_time", "duration", "percentage_complete"], + "title": "RunResponseSchema" + }, + "RunStatus": { + "type": "string", + "enum": ["PENDING", "RUNNING", "COMPLETED", "FAILED", "PAUSED", "CANCELED"], + "title": "RunStatus" + }, + 
"SchemaGenerationRequest": { + "properties": { + "description": { + "type": "string", + "title": "Description" + }, + "existing_schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Existing Schema" + } + }, + "type": "object", + "required": ["description"], + "title": "SchemaGenerationRequest" + }, + "SessionCreate": { + "properties": { + "workflow_id": { + "type": "string", + "title": "Workflow Id" + }, + "user_id": { + "type": "string", + "title": "User Id" + }, + "external_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "External Id" + } + }, + "type": "object", + "required": ["workflow_id", "user_id"], + "title": "SessionCreate" + }, + "SessionListResponse": { + "properties": { + "sessions": { + "items": { + "$ref": "#/components/schemas/SessionResponse" + }, + "type": "array", + "title": "Sessions" + }, + "total": { + "type": "integer", + "title": "Total" + } + }, + "type": "object", + "required": ["sessions", "total"], + "title": "SessionListResponse" + }, + "SessionResponse": { + "properties": { + "workflow_id": { + "type": "string", + "title": "Workflow Id" + }, + "id": { + "type": "string", + "title": "Id" + }, + "user_id": { + "type": "string", + "title": "User Id" + }, + "created_at": { + "type": "string", + "format": "date-time", + "title": "Created At" + }, + "updated_at": { + "type": "string", + "format": "date-time", + "title": "Updated At" + }, + "messages": { + "items": { + "$ref": "#/components/schemas/MessageResponse" + }, + "type": "array", + "title": "Messages" + } + }, + "type": "object", + "required": ["workflow_id", "id", "user_id", "created_at", "updated_at", "messages"], + "title": "SessionResponse" + }, + "Source": { + "type": "string", + "enum": ["file", "url", "text"], + "title": "Source" + }, + "SpurType": { + "type": "string", + "enum": ["workflow", "chatbot", "agent"], + "title": "SpurType", + "description": "Enum representing the type of 
spur.\n\nWorkflow: Standard workflow with nodes and edges\nChatbot: Essentially a workflow with chat compatible IO and session management\nAgent: Autonomous agent node that calls tools, also has chat compatible IO\n and session management" + }, + "StartRunRequestSchema": { + "properties": { + "initial_inputs": { + "anyOf": [ + { + "additionalProperties": { + "type": "object" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Initial Inputs" + }, + "parent_run_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Parent Run Id" + }, + "files": { + "anyOf": [ + { + "additionalProperties": { + "items": { + "type": "string" + }, + "type": "array" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Files" + } + }, + "type": "object", + "title": "StartRunRequestSchema" + }, + "TaskResponseSchema": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "run_id": { + "type": "string", + "title": "Run Id" + }, + "node_id": { + "type": "string", + "title": "Node Id" + }, + "parent_task_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Parent Task Id" + }, + "status": { + "$ref": "#/components/schemas/TaskStatus" + }, + "inputs": { + "anyOf": [ + {}, + { + "type": "null" + } + ], + "title": "Inputs" + }, + "outputs": { + "anyOf": [ + {}, + { + "type": "null" + } + ], + "title": "Outputs" + }, + "error": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Error" + }, + "start_time": { + "anyOf": [ + { + "type": "string", + "format": "date-time" + }, + { + "type": "null" + } + ], + "title": "Start Time" + }, + "end_time": { + "anyOf": [ + { + "type": "string", + "format": "date-time" + }, + { + "type": "null" + } + ], + "title": "End Time" + }, + "subworkflow": { + "anyOf": [ + { + "$ref": "#/components/schemas/WorkflowDefinitionSchema-Output" + }, + { + "type": "null" + } + ] + }, + "subworkflow_output": { + 
"anyOf": [ + { + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Subworkflow Output" + } + }, + "type": "object", + "required": [ + "id", + "run_id", + "node_id", + "parent_task_id", + "status", + "inputs", + "outputs", + "error", + "start_time", + "end_time", + "subworkflow", + "subworkflow_output" + ], + "title": "TaskResponseSchema" + }, + "TaskStatus": { + "type": "string", + "enum": ["PENDING", "RUNNING", "COMPLETED", "FAILED", "CANCELED", "PAUSED"], + "title": "TaskStatus" + }, + "TemplateSchema": { + "properties": { + "name": { + "type": "string", + "title": "Name" + }, + "description": { + "type": "string", + "title": "Description" + }, + "features": { + "items": { + "type": "string" + }, + "type": "array", + "title": "Features" + }, + "file_name": { + "type": "string", + "title": "File Name" + } + }, + "type": "object", + "required": ["name", "description", "features", "file_name"], + "title": "TemplateSchema", + "description": "Template schema." + }, + "TokenInput": { + "properties": { + "access_token": { + "type": "string", + "title": "Access Token" + }, + "expires_in": { + "type": "integer", + "title": "Expires In" + } + }, + "type": "object", + "required": ["access_token", "expires_in"], + "title": "TokenInput" + }, + "UserCreate": { + "properties": { + "external_id": { + "type": "string", + "title": "External Id", + "description": "External ID for the user" + }, + "user_metadata": { + "type": "object", + "title": "User Metadata", + "description": "Additional user metadata" + } + }, + "type": "object", + "required": ["external_id"], + "title": "UserCreate" + }, + "UserListResponse": { + "properties": { + "users": { + "items": { + "$ref": "#/components/schemas/UserResponse" + }, + "type": "array", + "title": "Users" + }, + "total": { + "type": "integer", + "title": "Total", + "description": "Total number of users" + } + }, + "type": "object", + "required": ["users", "total"], + "title": "UserListResponse" + }, + "UserResponse": { + 
"properties": { + "external_id": { + "type": "string", + "title": "External Id", + "description": "External ID for the user" + }, + "user_metadata": { + "type": "object", + "title": "User Metadata", + "description": "Additional user metadata" + }, + "id": { + "type": "string", + "title": "Id", + "description": "Internal ID with prefix (e.g. U1)" + }, + "created_at": { + "type": "string", + "format": "date-time", + "title": "Created At" + }, + "updated_at": { + "type": "string", + "format": "date-time", + "title": "Updated At" + } + }, + "type": "object", + "required": ["external_id", "id", "created_at", "updated_at"], + "title": "UserResponse" + }, + "UserUpdate": { + "properties": { + "external_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "External Id", + "description": "External ID for the user" + }, + "user_metadata": { + "anyOf": [ + { + "type": "object" + }, + { + "type": "null" + } + ], + "title": "User Metadata", + "description": "Additional user metadata" + } + }, + "type": "object", + "title": "UserUpdate" + }, + "ValidationError": { + "properties": { + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "type": "array", + "title": "Location" + }, + "msg": { + "type": "string", + "title": "Message" + }, + "type": { + "type": "string", + "title": "Error Type" + } + }, + "type": "object", + "required": ["loc", "msg", "type"], + "title": "ValidationError" + }, + "VectorIndexCreateSchema": { + "properties": { + "name": { + "type": "string", + "title": "Name" + }, + "description": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Description" + }, + "collection_id": { + "type": "string", + "title": "Collection Id" + }, + "embedding": { + "$ref": "#/components/schemas/EmbeddingConfigSchema" + } + }, + "type": "object", + "required": ["name", "collection_id", "embedding"], + "title": "VectorIndexCreateSchema", + "description": "Request 
model for creating a vector index" + }, + "VectorIndexResponseSchema": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "name": { + "type": "string", + "title": "Name" + }, + "description": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Description" + }, + "collection_id": { + "type": "string", + "title": "Collection Id" + }, + "status": { + "type": "string", + "title": "Status" + }, + "created_at": { + "type": "string", + "title": "Created At" + }, + "updated_at": { + "type": "string", + "title": "Updated At" + }, + "document_count": { + "type": "integer", + "title": "Document Count" + }, + "chunk_count": { + "type": "integer", + "title": "Chunk Count" + }, + "error_message": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Error Message" + }, + "embedding_model": { + "type": "string", + "title": "Embedding Model" + }, + "vector_db": { + "type": "string", + "title": "Vector Db" + } + }, + "type": "object", + "required": [ + "id", + "name", + "collection_id", + "status", + "created_at", + "updated_at", + "document_count", + "chunk_count", + "embedding_model", + "vector_db" + ], + "title": "VectorIndexResponseSchema", + "description": "Response model for vector index operations" + }, + "VectorStoreConfig": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "name": { + "type": "string", + "title": "Name" + }, + "description": { + "type": "string", + "title": "Description" + }, + "requires_api_key": { + "type": "boolean", + "title": "Requires Api Key", + "default": false + }, + "api_key_env_var": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Api Key Env Var" + }, + "required_env_vars": { + "items": { + "type": "string" + }, + "type": "array", + "title": "Required Env Vars" + } + }, + "type": "object", + "required": ["id", "name", "description"], + "title": "VectorStoreConfig" + }, + 
"WorkflowCreateRequestSchema": { + "properties": { + "name": { + "type": "string", + "title": "Name" + }, + "description": { + "type": "string", + "title": "Description", + "default": "" + }, + "definition": { + "anyOf": [ + { + "$ref": "#/components/schemas/WorkflowDefinitionSchema-Input" + }, + { + "type": "null" + } + ] + } + }, + "type": "object", + "required": ["name"], + "title": "WorkflowCreateRequestSchema", + "description": "A request to create a new workflow." + }, + "WorkflowDefinitionSchema-Input": { + "properties": { + "nodes": { + "items": { + "$ref": "#/components/schemas/WorkflowNodeSchema-Input" + }, + "type": "array", + "title": "Nodes" + }, + "links": { + "items": { + "$ref": "#/components/schemas/WorkflowLinkSchema" + }, + "type": "array", + "title": "Links" + }, + "test_inputs": { + "items": { + "type": "object" + }, + "type": "array", + "title": "Test Inputs", + "default": [] + }, + "spur_type": { + "$ref": "#/components/schemas/SpurType", + "default": "workflow" + } + }, + "type": "object", + "required": ["nodes", "links"], + "title": "WorkflowDefinitionSchema", + "description": "A workflow is a DAG of nodes." + }, + "WorkflowDefinitionSchema-Output": { + "properties": { + "nodes": { + "items": { + "$ref": "#/components/schemas/WorkflowNodeSchema-Output" + }, + "type": "array", + "title": "Nodes" + }, + "links": { + "items": { + "$ref": "#/components/schemas/WorkflowLinkSchema" + }, + "type": "array", + "title": "Links" + }, + "test_inputs": { + "items": { + "type": "object" + }, + "type": "array", + "title": "Test Inputs", + "default": [] + }, + "spur_type": { + "$ref": "#/components/schemas/SpurType", + "default": "workflow" + } + }, + "type": "object", + "required": ["nodes", "links"], + "title": "WorkflowDefinitionSchema", + "description": "A workflow is a DAG of nodes." 
+ }, + "WorkflowLinkSchema": { + "properties": { + "source_id": { + "type": "string", + "title": "Source Id" + }, + "target_id": { + "type": "string", + "title": "Target Id" + }, + "source_handle": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Source Handle" + }, + "target_handle": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Target Handle" + } + }, + "type": "object", + "required": ["source_id", "target_id"], + "title": "WorkflowLinkSchema", + "description": "Connect a source node to a target node.\n\nThe target node will receive the source node's output in its input dictionary." + }, + "WorkflowNodeCoordinatesSchema": { + "properties": { + "x": { + "type": "number", + "title": "X" + }, + "y": { + "type": "number", + "title": "Y" + } + }, + "type": "object", + "required": ["x", "y"], + "title": "WorkflowNodeCoordinatesSchema", + "description": "Coordinates for a node in a workflow." + }, + "WorkflowNodeDimensionsSchema": { + "properties": { + "width": { + "type": "number", + "title": "Width" + }, + "height": { + "type": "number", + "title": "Height" + } + }, + "type": "object", + "required": ["width", "height"], + "title": "WorkflowNodeDimensionsSchema", + "description": "Dimensions for a node in a workflow." 
+ }, + "WorkflowNodeSchema-Input": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "title": { + "type": "string", + "title": "Title", + "default": "" + }, + "parent_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Parent Id" + }, + "node_type": { + "type": "string", + "title": "Node Type" + }, + "config": { + "type": "object", + "title": "Config", + "default": {} + }, + "coordinates": { + "anyOf": [ + { + "$ref": "#/components/schemas/WorkflowNodeCoordinatesSchema" + }, + { + "type": "null" + } + ] + }, + "dimensions": { + "anyOf": [ + { + "$ref": "#/components/schemas/WorkflowNodeDimensionsSchema" + }, + { + "type": "null" + } + ] + }, + "subworkflow": { + "anyOf": [ + { + "$ref": "#/components/schemas/WorkflowDefinitionSchema-Input" + }, + { + "type": "null" + } + ] + } + }, + "type": "object", + "required": ["id", "node_type"], + "title": "WorkflowNodeSchema", + "description": "A single step in a workflow.\n\nEach node receives a dictionary mapping predecessor node IDs to their outputs.\nFor dynamic schema nodes, the output schema is defined in the config dictionary.\nFor static schema nodes, the output schema is defined in the node class implementation." 
+ }, + "WorkflowNodeSchema-Output": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "title": { + "type": "string", + "title": "Title", + "default": "" + }, + "parent_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Parent Id" + }, + "node_type": { + "type": "string", + "title": "Node Type" + }, + "config": { + "type": "object", + "title": "Config", + "default": {} + }, + "coordinates": { + "anyOf": [ + { + "$ref": "#/components/schemas/WorkflowNodeCoordinatesSchema" + }, + { + "type": "null" + } + ] + }, + "dimensions": { + "anyOf": [ + { + "$ref": "#/components/schemas/WorkflowNodeDimensionsSchema" + }, + { + "type": "null" + } + ] + }, + "subworkflow": { + "anyOf": [ + { + "$ref": "#/components/schemas/WorkflowDefinitionSchema-Output" + }, + { + "type": "null" + } + ] + } + }, + "type": "object", + "required": ["id", "node_type"], + "title": "WorkflowNodeSchema", + "description": "A single step in a workflow.\n\nEach node receives a dictionary mapping predecessor node IDs to their outputs.\nFor dynamic schema nodes, the output schema is defined in the config dictionary.\nFor static schema nodes, the output schema is defined in the node class implementation." 
+ }, + "WorkflowResponseSchema": { + "properties": { + "id": { + "type": "string", + "title": "Id" + }, + "name": { + "type": "string", + "title": "Name" + }, + "description": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Description" + }, + "definition": { + "$ref": "#/components/schemas/WorkflowDefinitionSchema-Output" + }, + "created_at": { + "type": "string", + "format": "date-time", + "title": "Created At" + }, + "updated_at": { + "type": "string", + "format": "date-time", + "title": "Updated At" + } + }, + "type": "object", + "required": ["id", "name", "description", "definition", "created_at", "updated_at"], + "title": "WorkflowResponseSchema", + "description": "A response containing the details of a workflow." + }, + "WorkflowVersionResponseSchema": { + "properties": { + "version": { + "type": "integer", + "title": "Version" + }, + "name": { + "type": "string", + "title": "Name" + }, + "description": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Description" + }, + "definition": { + "title": "Definition" + }, + "definition_hash": { + "type": "string", + "title": "Definition Hash" + }, + "created_at": { + "type": "string", + "format": "date-time", + "title": "Created At" + }, + "updated_at": { + "type": "string", + "format": "date-time", + "title": "Updated At" + } + }, + "type": "object", + "required": [ + "version", + "name", + "description", + "definition", + "definition_hash", + "created_at", + "updated_at" + ], + "title": "WorkflowVersionResponseSchema", + "description": "A response containing the details of a workflow version." 
+ } + } + } +} diff --git a/pyspur/docs/api-reference/rag.mdx b/pyspur/docs/api-reference/rag.mdx new file mode 100644 index 0000000000000000000000000000000000000000..5d5bc678778657124598f70016ef556e0ccf98b0 --- /dev/null +++ b/pyspur/docs/api-reference/rag.mdx @@ -0,0 +1,354 @@ +# RAG API + +This document outlines the API endpoints for managing Retrieval-Augmented Generation (RAG) components in PySpur. + +## Document Collections + +### Create Document Collection + +**Description**: Creates a new document collection from uploaded files and metadata. The files are processed asynchronously in the background. + +**URL**: `/rag/collections/` + +**Method**: POST + +**Form Data**: +```python +files: List[UploadFile] # List of files to upload (optional) +metadata: str # JSON string containing collection configuration +``` + +Where `metadata` is a JSON string representing: +```python +class DocumentCollectionCreateSchema: + name: str # Name of the collection + description: str # Description of the collection + text_processing: ChunkingConfigSchema # Configuration for text processing +``` + +**Response Schema**: +```python +class DocumentCollectionResponseSchema: + id: str # ID of the document collection + name: str # Name of the collection + description: str # Description of the collection + status: str # Status of the collection (processing, ready, failed) + created_at: str # When the collection was created (ISO format) + updated_at: str # When the collection was last updated (ISO format) + document_count: int # Number of documents in the collection + chunk_count: int # Number of chunks in the collection + error_message: Optional[str] # Error message if processing failed +``` + +### List Document Collections + +**Description**: Lists all document collections. 
+ +**URL**: `/rag/collections/` + +**Method**: GET + +**Response Schema**: +```python +List[DocumentCollectionResponseSchema] +``` + +### Get Document Collection + +**Description**: Gets details of a specific document collection. + +**URL**: `/rag/collections/{collection_id}/` + +**Method**: GET + +**Parameters**: +```python +collection_id: str # ID of the document collection +``` + +**Response Schema**: +```python +class DocumentCollectionResponseSchema: + id: str # ID of the document collection + name: str # Name of the collection + description: str # Description of the collection + status: str # Status of the collection (processing, ready, failed) + created_at: str # When the collection was created (ISO format) + updated_at: str # When the collection was last updated (ISO format) + document_count: int # Number of documents in the collection + chunk_count: int # Number of chunks in the collection + error_message: Optional[str] # Error message if processing failed +``` + +### Delete Document Collection + +**Description**: Deletes a document collection and its associated data. + +**URL**: `/rag/collections/{collection_id}/` + +**Method**: DELETE + +**Parameters**: +```python +collection_id: str # ID of the document collection +``` + +**Response**: 200 OK with message + +### Get Collection Progress + +**Description**: Gets the processing progress of a document collection. 
+ +**URL**: `/rag/collections/{collection_id}/progress/` + +**Method**: GET + +**Parameters**: +```python +collection_id: str # ID of the document collection +``` + +**Response Schema**: +```python +class ProcessingProgressSchema: + id: str # ID of the collection + status: str # Status of processing + progress: float # Progress percentage (0-100) + current_step: Optional[str] # Current processing step + total_files: Optional[int] # Total number of files + processed_files: Optional[int] # Number of processed files + total_chunks: Optional[int] # Total number of chunks + processed_chunks: Optional[int] # Number of processed chunks + error_message: Optional[str] # Error message if processing failed + created_at: str # When processing started (ISO format) + updated_at: str # When processing was last updated (ISO format) +``` + +### Add Documents to Collection + +**Description**: Adds documents to an existing collection. The documents are processed asynchronously in the background. + +**URL**: `/rag/collections/{collection_id}/documents/` + +**Method**: POST + +**Parameters**: +```python +collection_id: str # ID of the document collection +``` + +**Form Data**: +```python +files: List[UploadFile] # List of files to upload +``` + +**Response Schema**: +```python +class DocumentCollectionResponseSchema: + # Same as Get Document Collection +``` + +### Get Collection Documents + +**Description**: Gets all documents and their chunks for a collection. 
+ +**URL**: `/rag/collections/{collection_id}/documents/` + +**Method**: GET + +**Parameters**: +```python +collection_id: str # ID of the document collection +``` + +**Response Schema**: +```python +List[DocumentWithChunksSchema] +``` + +Where `DocumentWithChunksSchema` contains: +```python +class DocumentWithChunksSchema: + id: str # ID of the document + title: str # Title of the document + metadata: Dict[str, Any] # Metadata about the document + chunks: List[DocumentChunkSchema] # List of chunks in the document +``` + +### Delete Document from Collection + +**Description**: Deletes a document from a collection. + +**URL**: `/rag/collections/{collection_id}/documents/{document_id}/` + +**Method**: DELETE + +**Parameters**: +```python +collection_id: str # ID of the document collection +document_id: str # ID of the document to delete +``` + +**Response**: 200 OK with message + +### Preview Chunk + +**Description**: Previews how a document would be chunked with a given configuration. + +**URL**: `/rag/collections/preview_chunk/` + +**Method**: POST + +**Form Data**: +```python +file: UploadFile # File to preview +chunking_config: str # JSON string containing chunking configuration +``` + +**Response Schema**: +```python +{ + "chunks": List[Dict[str, Any]], # Preview of chunks + "total_chunks": int # Total number of chunks +} +``` + +## Vector Indices + +### Create Vector Index + +**Description**: Creates a new vector index from a document collection. The index is created asynchronously in the background. 
+ +**URL**: `/rag/indices/` + +**Method**: POST + +**Request Payload**: +```python +class VectorIndexCreateSchema: + name: str # Name of the index + description: Optional[str] # Description of the index (optional) + collection_id: str # ID of the document collection + embedding: EmbeddingConfigSchema # Configuration for embedding +``` + +**Response Schema**: +```python +class VectorIndexResponseSchema: + id: str # ID of the vector index + name: str # Name of the index + description: Optional[str] # Description of the index (may be null) + collection_id: str # ID of the document collection + status: str # Status of the index (processing, ready, failed) + created_at: str # When the index was created (ISO format) + updated_at: str # When the index was last updated (ISO format) + document_count: int # Number of documents in the index + chunk_count: int # Number of chunks in the index + embedding_model: str # Name of the embedding model + vector_db: str # Name of the vector database + error_message: Optional[str] # Error message if processing failed +``` + +### List Vector Indices + +**Description**: Lists all vector indices. + +**URL**: `/rag/indices/` + +**Method**: GET + +**Response Schema**: +```python +List[VectorIndexResponseSchema] +``` + +### Get Vector Index + +**Description**: Gets details of a specific vector index. + +**URL**: `/rag/indices/{index_id}/` + +**Method**: GET + +**Parameters**: +```python +index_id: str # ID of the vector index +``` + +**Response Schema**: +```python +class VectorIndexResponseSchema: + # Same as Create Vector Index response +``` + +### Delete Vector Index + +**Description**: Deletes a vector index and its associated data. + +**URL**: `/rag/indices/{index_id}/` + +**Method**: DELETE + +**Parameters**: +```python +index_id: str # ID of the vector index +``` + +**Response**: 200 OK with message + +### Get Index Progress + +**Description**: Gets the processing progress of a vector index. 
+ +**URL**: `/rag/indices/{index_id}/progress/` + +**Method**: GET + +**Parameters**: +```python +index_id: str # ID of the vector index +``` + +**Response Schema**: +```python +class ProcessingProgressSchema: + # Same as Get Collection Progress response +``` + +### Retrieve from Index + +**Description**: Retrieves relevant chunks from a vector index based on a query. + +**URL**: `/rag/indices/{index_id}/retrieve/` + +**Method**: POST + +**Parameters**: +```python +index_id: str # ID of the vector index +``` + +**Request Payload**: +```python +class RetrievalRequestSchema: + query: str # Query to search for + top_k: Optional[int] = 5 # Number of results to return + score_threshold: Optional[float] = None # Minimum score threshold + semantic_weight: Optional[float] = 1.0 # Weight for semantic search + keyword_weight: Optional[float] = 0.0 # Weight for keyword search +``` + +**Response Schema**: +```python +class RetrievalResponseSchema: + results: List[RetrievalResultSchema] # List of retrieval results + total_results: int # Total number of results +``` + +Where `RetrievalResultSchema` contains: +```python +class RetrievalResultSchema: + text: str # Text of the chunk + score: float # Relevance score + metadata: ChunkMetadataSchema # Metadata about the chunk +``` \ No newline at end of file diff --git a/pyspur/docs/api-reference/run-management.mdx b/pyspur/docs/api-reference/run-management.mdx new file mode 100644 index 0000000000000000000000000000000000000000..512090701468794d37b1593cb4f321b826623bb5 --- /dev/null +++ b/pyspur/docs/api-reference/run-management.mdx @@ -0,0 +1,124 @@ +# Runs API + +This document outlines the API endpoints for managing individual run instances in PySpur. + +## Get Run + +**Description**: Retrieves detailed information about a specific run, including its status, inputs, outputs, and associated tasks. 
+ +**URL**: `/run/{run_id}/` + +**Method**: GET + +**Parameters**: +```python +run_id: str # ID of the run to retrieve +``` + +**Response Schema**: +```python +class RunResponseSchema(BaseModel): + id: str # Run ID + workflow_id: str # ID of the workflow + workflow_version_id: Optional[str] # ID of the workflow version + workflow_version: Optional[WorkflowVersionResponseSchema] # Details of the workflow version + status: RunStatus # Current status of the run + start_time: datetime # When the run started + end_time: Optional[datetime] # When the run ended (if completed) + initial_inputs: Optional[Dict[str, Dict[str, Any]]] # Initial inputs to the workflow + outputs: Optional[Dict[str, Dict[str, Any]]] # Outputs from the workflow + tasks: List[TaskResponseSchema] # List of tasks in the run + parent_run_id: Optional[str] # ID of the parent run (if applicable) + run_type: str # Type of run (e.g., "interactive", "batch") + output_file_id: Optional[str] # ID of the output file + input_dataset_id: Optional[str] # ID of the input dataset + message: Optional[str] # Additional information about the run + duration: Optional[float] # Duration of the run in seconds + percentage_complete: float # Percentage of tasks completed +``` + +## List All Runs + +**Description**: Lists all runs across all workflows with pagination support, ordered by start time descending. This provides a global view of all workflow executions in the system. + +**URL**: `/run/` + +**Method**: GET + +**Query Parameters**: +```python +page: int # Page number (default: 1, min: 1) +page_size: int # Number of items per page (default: 10, min: 1, max: 100) +``` + +**Response Schema**: +```python +List[RunResponseSchema] # List of run details +``` + +## Get Run Tasks + +**Description**: Retrieves all tasks associated with a specific run, showing the execution details of each node in the workflow. 
+ +**URL**: `/run/{run_id}/tasks/` + +**Method**: GET + +**Parameters**: +```python +run_id: str # ID of the run +``` + +**Response Schema**: +```python +List[TaskResponseSchema] +``` + +Where `TaskResponseSchema` contains: +```python +class TaskResponseSchema(BaseModel): + id: str # Task ID + run_id: str # ID of the run + node_id: str # ID of the node + node_type: str # Type of the node + status: TaskStatus # Status of the task (PENDING, RUNNING, COMPLETED, FAILED, PAUSED, CANCELED) + start_time: Optional[datetime] # When the task started + end_time: Optional[datetime] # When the task ended + inputs: Optional[Dict[str, Any]] # Inputs to the task + outputs: Optional[Dict[str, Any]] # Outputs from the task + error: Optional[str] # Error message if the task failed + is_downstream_of_pause: bool # Whether this task is downstream of a paused node +``` + +## Delete Run + +**Description**: Permanently deletes a run and all its associated tasks. This operation cannot be undone. + +**URL**: `/run/{run_id}/` + +**Method**: DELETE + +**Parameters**: +```python +run_id: str # ID of the run to delete +``` + +**Response**: 204 No Content + +## Get Child Runs + +**Description**: Retrieves all child runs of a parent run, useful for tracking nested workflow executions. + +**URL**: `/run/{run_id}/children/` + +**Method**: GET + +**Parameters**: +```python +run_id: str # ID of the parent run +``` + +**Response Schema**: +```python +List[RunResponseSchema] # List of child run details +``` \ No newline at end of file diff --git a/pyspur/docs/api-reference/sessions.mdx b/pyspur/docs/api-reference/sessions.mdx new file mode 100644 index 0000000000000000000000000000000000000000..d122fb5b3fe942e8a65d05d8a9a9aae9eddfa987 --- /dev/null +++ b/pyspur/docs/api-reference/sessions.mdx @@ -0,0 +1,121 @@ +# Sessions API + +This document outlines the API endpoints for managing user sessions in PySpur. +Sessions are used to maintain conversation history in agent spurs. 
+Each session is tied to a user and a spur. +For quick testing purposes, use the Create Test Session endpoint. It also creates a default test user if one doesn't exist. + +## Create Session + +**Description**: Creates a new session. If a session with the given external ID already exists, returns the existing session. + +**URL**: `/session/` + +**Method**: POST + +**Request Payload**: +```python +class SessionCreate: + user_id: str # User ID + workflow_id: str # Workflow ID + external_id: Optional[str] = None # External identifier for the session +``` + +**Response Schema**: +```python +class SessionResponse: + id: str # Session ID + user_id: str # User ID + workflow_id: str # Workflow ID + external_id: Optional[str] # External identifier for the session + created_at: datetime # When the session was created + updated_at: datetime # When the session was last updated + messages: List[MessageResponse] # List of messages in the session +``` + +## List Sessions + +**Description**: Lists sessions with pagination and optional user filtering. + +**URL**: `/session/` + +**Method**: GET + +**Query Parameters**: +```python +skip: int = 0 # Number of sessions to skip (min: 0) +limit: int = 10 # Number of sessions to return (min: 1, max: 100) +user_id: Optional[str] = None # Filter sessions by user ID +``` + +**Response Schema**: +```python +class SessionListResponse: + sessions: List[SessionResponse] # List of sessions + total: int # Total number of sessions +``` + +## Get Session + +**Description**: Gets a specific session by ID, including all messages. 
+ +**URL**: `/session/{session_id}/` + +**Method**: GET + +**Parameters**: +```python +session_id: str # Session ID +``` + +**Response Schema**: +```python +class SessionResponse: + id: str # Session ID + user_id: str # User ID + workflow_id: str # Workflow ID + external_id: Optional[str] # External identifier for the session + created_at: datetime # When the session was created + updated_at: datetime # When the session was last updated + messages: List[MessageResponse] # List of messages in the session +``` + +## Delete Session + +**Description**: Deletes a session. + +**URL**: `/session/{session_id}/` + +**Method**: DELETE + +**Parameters**: +```python +session_id: str # Session ID +``` + +**Response**: 204 No Content + +## Create Test Session + +**Description**: Creates or reuses a test user and session. If a test user exists, it will be reused. If an empty test session exists for the same workflow, it will be reused. Otherwise, a new session will be created. + +**URL**: `/session/test/` + +**Method**: POST + +**Query Parameters**: +```python +workflow_id: str # Workflow ID +``` + +**Response Schema**: +```python +class SessionResponse: + id: str # Session ID + user_id: str # User ID + workflow_id: str # Workflow ID + external_id: Optional[str] # External identifier for the session + created_at: datetime # When the session was created + updated_at: datetime # When the session was last updated + messages: List[MessageResponse] # List of messages in the session +``` \ No newline at end of file diff --git a/pyspur/docs/api-reference/users.mdx b/pyspur/docs/api-reference/users.mdx new file mode 100644 index 0000000000000000000000000000000000000000..2e975f647d275bb4bc0ae895710353c225aceddb --- /dev/null +++ b/pyspur/docs/api-reference/users.mdx @@ -0,0 +1,118 @@ +# Users API + +This document outlines the API endpoints for managing users in PySpur. +Users and sessions are required for deploying agents and chatbots that maintain message history. 
+ +## Create User + +**Description**: Creates a new user. If a user with the given external ID already exists, returns the existing user. + +**URL**: `/user/` + +**Method**: POST + +**Request Payload**: +```python +class UserCreate: + external_id: str # External identifier for the user + user_metadata: Optional[Dict[str, Any]] = None # Additional metadata about the user +``` + +**Response Schema**: +```python +class UserResponse: + id: str # User ID (prefixed with 'U') + external_id: str # External identifier for the user + user_metadata: Optional[Dict[str, Any]] # Additional metadata about the user + created_at: datetime # When the user was created + updated_at: datetime # When the user was last updated +``` + +## List Users + +**Description**: Lists users with pagination. + +**URL**: `/user/` + +**Method**: GET + +**Query Parameters**: +```python +skip: int = 0 # Number of users to skip (min: 0) +limit: int = 10 # Number of users to return (min: 1, max: 100) +``` + +**Response Schema**: +```python +class UserListResponse: + users: List[UserResponse] # List of users + total: int # Total number of users +``` + +## Get User + +**Description**: Gets a specific user by ID. + +**URL**: `/user/{user_id}/` + +**Method**: GET + +**Parameters**: +```python +user_id: str # User ID (prefixed with 'U') +``` + +**Response Schema**: +```python +class UserResponse: + id: str # User ID (prefixed with 'U') + external_id: str # External identifier for the user + user_metadata: Optional[Dict[str, Any]] # Additional metadata about the user + created_at: datetime # When the user was created + updated_at: datetime # When the user was last updated +``` + +## Update User + +**Description**: Updates a user. 
+ +**URL**: `/user/{user_id}/` + +**Method**: PATCH + +**Parameters**: +```python +user_id: str # User ID (prefixed with 'U') +``` + +**Request Payload**: +```python +class UserUpdate: + external_id: Optional[str] = None # External identifier for the user + user_metadata: Optional[Dict[str, Any]] = None # Additional metadata about the user +``` + +**Response Schema**: +```python +class UserResponse: + id: str # User ID (prefixed with 'U') + external_id: str # External identifier for the user + user_metadata: Optional[Dict[str, Any]] # Additional metadata about the user + created_at: datetime # When the user was created + updated_at: datetime # When the user was last updated +``` + +## Delete User + +**Description**: Deletes a user. + +**URL**: `/user/{user_id}/` + +**Method**: DELETE + +**Parameters**: +```python +user_id: str # User ID (prefixed with 'U') +``` + +**Response**: 204 No Content \ No newline at end of file diff --git a/pyspur/docs/api-reference/workflow-execution.mdx b/pyspur/docs/api-reference/workflow-execution.mdx new file mode 100644 index 0000000000000000000000000000000000000000..3627b64d48169824b3480f337b77a7e79d1a61b6 --- /dev/null +++ b/pyspur/docs/api-reference/workflow-execution.mdx @@ -0,0 +1,233 @@ +# Workflow Runs API + +This document outlines the API endpoints for running and managing workflow executions in PySpur. + +## Run Workflow (Blocking) + +**Description**: Executes a workflow synchronously and returns the outputs. This is a blocking call that waits for the workflow to complete before returning a response. If the workflow contains a human intervention node, it may pause execution and return a pause exception. 
+ +**URL**: `/wf/{workflow_id}/run/` + +**Method**: POST + +**Parameters**: +```python +workflow_id: str # ID of the workflow to run +``` + +**Request Payload**: +```python +class StartRunRequestSchema(BaseModel): + initial_inputs: Optional[Dict[str, Dict[str, Any]]] = None # Initial inputs for the workflow + parent_run_id: Optional[str] = None # ID of the parent run (for nested workflows) + files: Optional[Dict[str, List[str]]] = None # Files to use in the workflow +``` + +**Response Schema**: +```python +Dict[str, Any] # Dictionary of node outputs +``` + +## Start Run (Non-Blocking) + +**Description**: Starts a workflow execution asynchronously and returns immediately with the run details. The workflow continues execution in the background. This is useful for long-running workflows where you don't want to wait for completion. + +**URL**: `/wf/{workflow_id}/start_run/` + +**Method**: POST + +**Parameters**: +```python +workflow_id: str # ID of the workflow to run +``` + +**Request Payload**: Same as Run Workflow (Blocking) + +**Response Schema**: +```python +class RunResponseSchema(BaseModel): + id: str # Run ID + workflow_id: str # ID of the workflow + workflow_version_id: Optional[str] # ID of the workflow version + workflow_version: Optional[WorkflowVersionResponseSchema] # Details of the workflow version + status: RunStatus # Current status of the run + start_time: datetime # When the run started + end_time: Optional[datetime] # When the run ended (if completed) + initial_inputs: Optional[Dict[str, Dict[str, Any]]] # Initial inputs to the workflow + outputs: Optional[Dict[str, Dict[str, Any]]] # Outputs from the workflow + tasks: List[TaskResponseSchema] # List of tasks in the run + parent_run_id: Optional[str] # ID of the parent run (if applicable) + run_type: str # Type of run (e.g., "interactive") + output_file_id: Optional[str] # ID of the output file + input_dataset_id: Optional[str] # ID of the input dataset + message: Optional[str] # Additional 
information about the run + duration: Optional[float] # Duration of the run in seconds + percentage_complete: float # Percentage of tasks completed +``` + +## Run Partial Workflow + +**Description**: Executes a partial workflow starting from a specific node, using precomputed outputs for upstream nodes. This is useful for testing specific parts of a workflow without running the entire workflow. + +**URL**: `/wf/{workflow_id}/run_partial/` + +**Method**: POST + +**Parameters**: +```python +workflow_id: str # ID of the workflow to run +``` + +**Request Payload**: +```python +class PartialRunRequestSchema(BaseModel): + node_id: str # ID of the node to start execution from + initial_inputs: Optional[Dict[str, Dict[str, Any]]] = None # Initial inputs for the workflow + partial_outputs: Optional[Dict[str, Dict[str, Any]]] = None # Precomputed outputs for upstream nodes +``` + +**Response Schema**: +```python +Dict[str, Any] # Dictionary of node outputs +``` + +## Start Batch Run + +**Description**: Starts a batch execution of a workflow over a dataset. The workflow is run once for each row in the dataset, with dataset columns mapped to workflow inputs. Results are written to an output file. + +**URL**: `/wf/{workflow_id}/start_batch_run/` + +**Method**: POST + +**Parameters**: +```python +workflow_id: str # ID of the workflow to run +``` + +**Request Payload**: +```python +class BatchRunRequestSchema(BaseModel): + dataset_id: str # ID of the dataset to use + mini_batch_size: int = 10 # Number of rows to process in each mini-batch +``` + +**Response Schema**: Same as Start Run (Non-Blocking) + +## List Runs + +**Description**: Lists all runs for a specific workflow with pagination support, ordered by start time descending. This endpoint also updates run status based on task status. 
+ +**URL**: `/wf/{workflow_id}/runs/` + +**Method**: GET + +**Parameters**: +```python +workflow_id: str # ID of the workflow +page: int # Page number (default: 1, min: 1) +page_size: int # Number of items per page (default: 10, min: 1, max: 100) +``` + +**Response Schema**: +```python +List[RunResponseSchema] # List of run details +``` + +## List Paused Workflows + +**Description**: Lists all workflows that are currently in a paused state, with pagination support. This endpoint is useful for monitoring workflows that require human intervention. + +**URL**: `/wf/paused_workflows/` + +**Method**: GET + +**Query Parameters**: +```python +page: int # Page number (default: 1, min: 1) +page_size: int # Number of items per page (default: 10, min: 1, max: 100) +``` + +**Response Schema**: +```python +List[PausedWorkflowResponseSchema] +``` + +Where `PausedWorkflowResponseSchema` contains: +```python +class PausedWorkflowResponseSchema(BaseModel): + run: RunResponseSchema # Information about the workflow run + current_pause: PauseHistoryResponseSchema # Details about the current pause state + workflow: WorkflowDefinitionSchema # The workflow definition +``` + +## Get Pause History + +**Description**: Retrieves the pause history for a specific workflow run, showing when and why the workflow was paused, and any actions taken to resume it. 
+ +**URL**: `/wf/pause_history/{run_id}/` + +**Method**: GET + +**Parameters**: +```python +run_id: str # ID of the workflow run +``` + +**Response Schema**: +```python +List[PauseHistoryResponseSchema] +``` + +Where `PauseHistoryResponseSchema` contains: +```python +class PauseHistoryResponseSchema(BaseModel): + id: str # Synthetic ID for API compatibility + run_id: str # ID of the run + node_id: str # ID of the node where the pause occurred + pause_message: Optional[str] # Message explaining the pause reason + pause_time: datetime # When the workflow was paused + resume_time: Optional[datetime] # When the workflow was resumed (if applicable) + resume_user_id: Optional[str] # ID of the user who resumed the workflow + resume_action: Optional[PauseAction] # Action taken (APPROVE/DECLINE/OVERRIDE) + input_data: Optional[Dict[str, Any]] # Input data at the time of pause + comments: Optional[str] # Additional comments about the pause/resume +``` + +## Process Pause Action + +**Description**: Processes an action on a paused workflow, allowing for approval, decline, or override of a workflow that has been paused for human intervention. The workflow will resume execution based on the action taken. + +**URL**: `/wf/process_pause_action/{run_id}/` + +**Method**: POST + +**Parameters**: +```python +run_id: str # ID of the workflow run +``` + +**Request Payload**: +```python +class ResumeRunRequestSchema(BaseModel): + inputs: Dict[str, Any] # Human-provided inputs for the paused node + user_id: str # ID of the user resuming the workflow + action: PauseAction # Action taken (APPROVE/DECLINE/OVERRIDE) + comments: Optional[str] = None # Optional comments about the decision +``` + +**Response Schema**: Same as Start Run (Non-Blocking) + +## Cancel Workflow + +**Description**: Cancels a workflow that is currently paused or running. This will mark the run as CANCELED in the database and update all pending, running, and paused tasks to CANCELED as well. 
+ +**URL**: `/wf/cancel_workflow/{run_id}/` + +**Method**: POST + +**Parameters**: +```python +run_id: str # ID of the run to cancel +``` + +**Response Schema**: Same as Start Run (Non-Blocking) with a message indicating the workflow has been canceled successfully. diff --git a/pyspur/docs/api-reference/workflow-management.mdx b/pyspur/docs/api-reference/workflow-management.mdx new file mode 100644 index 0000000000000000000000000000000000000000..7f7519a74cfd6839b1974f0258ff74cfb72375d2 --- /dev/null +++ b/pyspur/docs/api-reference/workflow-management.mdx @@ -0,0 +1,327 @@ +# Workflow Management API + +This document outlines the API endpoints for managing workflows in PySpur. + + +## Create Workflow + +**Description**: Creates a new workflow. If no definition is provided, creates a default workflow with an input node. For chatbots, creates a workflow with required input/output fields for handling chat interactions. The workflow name will be made unique if a workflow with the same name already exists. 
+ +**URL**: `/wf/` + +**Method**: POST + +**Request Payload**: +```python +class WorkflowCreateRequestSchema(BaseModel): + name: str # Name of the workflow + description: str = "" # Description of the workflow + definition: Optional[WorkflowDefinitionSchema] = None # Definition of the workflow +``` + +Where `WorkflowDefinitionSchema` contains: +```python +class WorkflowDefinitionSchema(BaseModel): + nodes: List[WorkflowNodeSchema] # List of nodes in the workflow + links: List[WorkflowLinkSchema] # List of links between nodes + test_inputs: List[Dict[str, Any]] = [] # Test inputs for the workflow + spur_type: SpurType = SpurType.WORKFLOW # Type of workflow (WORKFLOW, CHATBOT, AGENT) +``` + +**Response Schema**: +```python +class WorkflowResponseSchema(BaseModel): + id: str # Workflow ID + name: str # Name of the workflow + description: Optional[str] # Description of the workflow + definition: WorkflowDefinitionSchema # Definition of the workflow + created_at: datetime # When the workflow was created + updated_at: datetime # When the workflow was last updated +``` + + +## Update Workflow + +**Description**: Updates an existing workflow's definition, name, and description. The workflow definition is required for updates. This endpoint allows for modifying the structure and behavior of a workflow. + +**URL**: `/wf/{workflow_id}/` + +**Method**: PUT + +**Parameters**: +```python +workflow_id: str # ID of the workflow to update +``` + +**Request Payload**: Same as Create Workflow + +**Response Schema**: Same as Create Workflow + +## List Workflows + +**Description**: Lists all workflows with pagination support, ordered by creation date descending. Only valid workflows that can be properly validated are included in the response. 
+ +**URL**: `/wf/` + +**Method**: GET + +**Query Parameters**: +```python +page: int # Page number (default: 1, min: 1) +page_size: int # Number of items per page (default: 10, min: 1, max: 100) +``` + +**Response Schema**: +```python +List[WorkflowResponseSchema] +``` + +## Get Workflow + +**Description**: Retrieves a specific workflow by its ID, including its complete definition, metadata, and timestamps. + +**URL**: `/wf/{workflow_id}/` + +**Method**: GET + +**Parameters**: +```python +workflow_id: str # ID of the workflow to retrieve +``` + +**Response Schema**: Same as Create Workflow + +## Reset Workflow + +**Description**: Resets a workflow to its initial state with just an input node. This is useful when you want to start over with a workflow design without deleting and recreating it. + +**URL**: `/wf/{workflow_id}/reset/` + +**Method**: PUT + +**Parameters**: +```python +workflow_id: str # ID of the workflow to reset +``` + +**Response Schema**: Same as Create Workflow + +## Delete Workflow + +**Description**: Deletes a workflow and its associated test files. This operation is permanent and will remove all data related to the workflow, including test files stored in the file system. + +**URL**: `/wf/{workflow_id}/` + +**Method**: DELETE + +**Parameters**: +```python +workflow_id: str # ID of the workflow to delete +``` + +**Response**: 204 No Content + +## Duplicate Workflow + +**Description**: Creates a copy of an existing workflow with "(Copy)" appended to its name. This is useful for creating variations of a workflow without modifying the original. + +**URL**: `/wf/{workflow_id}/duplicate/` + +**Method**: POST + +**Parameters**: +```python +workflow_id: str # ID of the workflow to duplicate +``` + +**Response Schema**: Same as Create Workflow + +## Get Workflow Output Variables + +**Description**: Retrieves the output variables (leaf nodes) of a workflow, including their node IDs and variable names. 
This is useful for understanding what outputs are available from a workflow. + +**URL**: `/wf/{workflow_id}/output_variables/` + +**Method**: GET + +**Parameters**: +```python +workflow_id: str # ID of the workflow +``` + +**Response Schema**: +```python +List[Dict[str, str]] # List of output variables with node IDs and variable names +``` + +Each dictionary in the list contains: +```python +{ + "node_id": str, # ID of the node + "variable_name": str, # Name of the output variable + "prefixed_variable": str # Variable name prefixed with node ID (node_id-variable_name) +} +``` + +## Upload Test Files + +**Description**: Uploads test files for a specific node in a workflow and returns their paths. The files are stored in a workflow-specific directory and can be used as inputs for testing the workflow. + +**URL**: `/wf/upload_test_files/` + +**Method**: POST + +**Form Data**: +```python +workflow_id: str # ID of the workflow +files: List[UploadFile] # List of files to upload +node_id: str # ID of the node to associate files with +``` + +**Response Schema**: +```python +Dict[str, List[str]] # Dictionary mapping node ID to list of file paths +``` + +Example: +```python +{ + "node_id": ["test_files/workflow_id/timestamp_filename.ext", ...] +} +``` + +## Get Workflow Versions + +**Description**: Retrieves all versions of a workflow, ordered by version number descending, with pagination support. This allows tracking the evolution of a workflow over time and reverting to previous versions if needed. 
+ +**URL**: `/wf/{workflow_id}/versions/` + +**Method**: GET + +**Parameters**: +```python +workflow_id: str # ID of the workflow +page: int # Page number (default: 1, min: 1) +page_size: int # Number of items per page (default: 10, min: 1, max: 100) +``` + +**Response Schema**: +```python +List[WorkflowVersionResponseSchema] +``` + +Where `WorkflowVersionResponseSchema` contains: +```python +class WorkflowVersionResponseSchema(BaseModel): + version: int # Version number + name: str # Name of the workflow version + description: Optional[str] # Description of the workflow version + definition: Any # Definition of the workflow version + definition_hash: str # Hash of the definition for tracking changes + created_at: datetime # When the version was created + updated_at: datetime # When the version was last updated +``` + +## List Paused Workflows + +**Description**: Lists all workflows that are currently in a paused state, with pagination support. This endpoint is useful for monitoring workflows that require human intervention. + +**URL**: `/wf/paused_workflows/` + +**Method**: GET + +**Query Parameters**: +```python +page: int # Page number (default: 1, min: 1) +page_size: int # Number of items per page (default: 10, min: 1, max: 100) +``` + +**Response Schema**: +```python +List[PausedWorkflowResponseSchema] +``` + +Where `PausedWorkflowResponseSchema` contains: +```python +class PausedWorkflowResponseSchema(BaseModel): + run: RunResponseSchema # Information about the workflow run + current_pause: PauseHistoryResponseSchema # Details about the current pause state + workflow: WorkflowDefinitionSchema # The workflow definition +``` + +## Get Pause History + +**Description**: Retrieves the pause history for a specific workflow run, showing when and why the workflow was paused, and any actions taken to resume it. 
+ +**URL**: `/wf/pause_history/{run_id}/` + +**Method**: GET + +**Parameters**: +```python +run_id: str # ID of the workflow run +``` + +**Response Schema**: +```python +List[PauseHistoryResponseSchema] +``` + +Where `PauseHistoryResponseSchema` contains: +```python +class PauseHistoryResponseSchema(BaseModel): + id: str # Synthetic ID for API compatibility + run_id: str # ID of the run + node_id: str # ID of the node where the pause occurred + pause_message: Optional[str] # Message explaining the pause reason + pause_time: datetime # When the workflow was paused + resume_time: Optional[datetime] # When the workflow was resumed (if applicable) + resume_user_id: Optional[str] # ID of the user who resumed the workflow + resume_action: Optional[PauseAction] # Action taken (APPROVE/DECLINE/OVERRIDE) + input_data: Optional[Dict[str, Any]] # Input data at the time of pause + comments: Optional[str] # Additional comments about the pause/resume +``` + +## Process Pause Action + +**Description**: Processes an action on a paused workflow, allowing for approval, decline, or override of a workflow that has been paused for human intervention. The workflow will resume execution based on the action taken. 
+ +**URL**: `/wf/process_pause_action/{run_id}/` + +**Method**: POST + +**Parameters**: +```python +run_id: str # ID of the workflow run +``` + +**Request Payload**: +```python +class ResumeRunRequestSchema(BaseModel): + inputs: Dict[str, Any] # Human-provided inputs for the paused node + user_id: str # ID of the user resuming the workflow + action: PauseAction # Action taken (APPROVE/DECLINE/OVERRIDE) + comments: Optional[str] = None # Optional comments about the decision +``` + +**Response Schema**: +```python +class RunResponseSchema(BaseModel): + id: str # Run ID + workflow_id: str # ID of the workflow + workflow_version_id: Optional[str] # ID of the workflow version + workflow_version: Optional[WorkflowVersionResponseSchema] # Details of the workflow version + status: RunStatus # Current status of the run + start_time: datetime # When the run started + end_time: Optional[datetime] # When the run ended (if completed) + initial_inputs: Optional[Dict[str, Dict[str, Any]]] # Initial inputs to the workflow + outputs: Optional[Dict[str, Dict[str, Any]]] # Outputs from the workflow + tasks: List[TaskResponseSchema] # List of tasks in the run + parent_run_id: Optional[str] # ID of the parent run (if applicable) + run_type: str # Type of run (e.g., "interactive") + output_file_id: Optional[str] # ID of the output file + input_dataset_id: Optional[str] # ID of the input dataset + message: Optional[str] # Additional information about the run + duration: Optional[float] # Duration of the run in seconds + percentage_complete: float # Percentage of tasks completed +``` diff --git a/pyspur/docs/chatbots/concepts.mdx b/pyspur/docs/chatbots/concepts.mdx new file mode 100644 index 0000000000000000000000000000000000000000..3fbb8fc7567ceb4c6dd7e2e8a09503374203b568 --- /dev/null +++ b/pyspur/docs/chatbots/concepts.mdx @@ -0,0 +1,55 @@ +--- +title: 'Concepts' +description: 'Understanding Chatbots in PySpur' +--- + +# Chatbot Concepts + +PySpur allows you to create two types of 
Spurs: standard workflows and chatbots. This guide explains what chatbots are in PySpur, how they differ from standard workflows, and why you might want to use them. + +## What Are Chatbots in PySpur? + +In PySpur, a chatbot is a special type of workflow designed to handle conversational interactions. Unlike standard workflows that process data in a one-time execution flow, chatbots: + +- Maintain conversation history across multiple interactions +- Process user messages and generate assistant responses +- Handle user sessions to keep conversations separate +- Support conversational context and state management + +## How Chatbots Differ from Standard Workflows + +| Feature | Standard Workflow | Chatbot | +| ------- | ---------------- | ------- | +| Input/Output Structure | Flexible, user-defined | Fixed structure with specific fields | +| Session Management | Not built-in | Automatic session tracking | +| Message History | Not available | Automatically maintained | +| Execution Model | One-time processing | Conversational, multi-turn | +| Primary Use Case | Data processing, automation | User interactions, conversations | + +### Required Input/Output Fields + +Chatbots in PySpur have a predefined structure to support conversations: + +**Required Input Fields:** +- `user_message` (string): The message from the user +- `session_id` (string): A unique identifier for the conversation session +- `message_history` (array): Previous messages in the conversation (automatically managed) + +**Required Output Fields:** +- `assistant_message` (string): The response message from the chatbot + +## When to Use Chatbots + +Choose a chatbot Spur when you need to: + +- Create conversational interfaces for your users +- Build customer support or information retrieval systems +- Develop virtual assistants that remember context +- Design interactive Q&A systems + +Choose a standard workflow when you need to: +- Process data in a one-time operation +- Build automation pipelines without 
conversation +- Create custom data transformations with flexible inputs/outputs + +In the next section, we'll walk through how to create and configure a chatbot in PySpur. diff --git a/pyspur/docs/chatbots/example.mdx b/pyspur/docs/chatbots/example.mdx new file mode 100644 index 0000000000000000000000000000000000000000..2628493e113ada35893bee468ae7e4e69d125b60 --- /dev/null +++ b/pyspur/docs/chatbots/example.mdx @@ -0,0 +1,171 @@ +--- +title: 'Creating Chatbots' +description: 'Step-by-step guide to building chatbots in PySpur' +--- + +# Creating Chatbots in PySpur + +This guide will walk you through the process of creating, configuring, and deploying a chatbot in PySpur. We'll build a simple customer support chatbot that can answer questions about a product. + +## Creating a New Chatbot + + + + Start by navigating to the PySpur dashboard. This is your central hub for managing all your Spurs. + + + + Click the **New Spur** button in the top right corner of the dashboard. + + ![New Spur Button](/images/chatbots/new-spur-button.png) + + + + In the modal that appears, select **Chatbot** as the spur type. + + + PySpur automatically sets up the required input and output nodes with the correct fields for a chatbot. + + + + + Give your chatbot a descriptive name, such as "Product Support Bot" or "Customer Help Assistant". + + + +## Understanding the Default Chatbot Structure + +When you create a new chatbot, PySpur automatically creates two nodes: + +1. **Input Node**: Contains the required fields: + - `user_message`: The message from the user + - `session_id`: A unique identifier for the conversation + - `message_history`: An array of previous messages (automatically populated) + +2. **Output Node**: Contains the required field: + - `assistant_message`: The response from your chatbot + +## Adding Intelligence to Your Chatbot + + + + Drag an LLM (Large Language Model) node from the sidebar onto the canvas. This will be the "brain" of your chatbot. 
+ + Connect the Input node to your LLM node, and the LLM node to the Output node. + + + + Click on the LLM node to open its configuration panel. Here you can: + + - Select the model provider (OpenAI, Anthropic, etc.) + - Choose the specific model (GPT-4, Claude, etc.) + - Customize the prompt template to guide the model's responses + + + Structure your prompt to mention that this is a customer support bot for your specific product or service. Include instructions about tone, personality, and knowledge limits. + + + + + In the LLM node configuration, make sure your prompt references the incoming data correctly: + + ``` + You are a helpful customer support agent for [Your Product]. + + Previous conversation: + {{message_history}} + + User: {{user_message}} + + Assistant: + ``` + + This template ensures the LLM sees the conversation history and the latest user message. + + + +## Enhancing Your Chatbot with Additional Nodes + +You can make your chatbot more powerful by adding other nodes: + +- **Vector Database Node**: Connect to a vector database to retrieve product information +- **Tool Node**: Enable the chatbot to perform actions like looking up orders or tracking shipments +- **Branching Node**: Direct the conversation flow based on user intent +- **API Node**: Connect to external services to fetch real-time data + +## Testing Your Chatbot + + + + Click the "Test" tab in the right sidebar to open the testing interface. + + + + Enter test values for the required input fields: + - `user_message`: "What features does your product have?" + - `session_id`: "test-session-123" + + You don't need to provide `message_history` as it's automatically managed. + + + + Click the "Run" button to execute the workflow and see the chatbot's response. + + + + To test multi-turn conversation, enter a follow-up message like "How much does it cost?" using the same session ID. + + Notice that the chatbot now has access to the previous messages and can maintain context. 
+ + + +## Deploying Your Chatbot + +Once you're satisfied with your chatbot: + + + + Click the Save button to save your chatbot configuration. + + + + Click the "Deploy" button in the top bar. This creates an API endpoint for your chatbot. + + + + You'll receive API details that you can use to integrate the chatbot: + + - API URL: The endpoint to call + - Request format: How to structure calls to your chatbot + - Authentication details: How to securely access your chatbot + + + + Using the API details, you can now integrate the chatbot with: + + - Your website using JavaScript + - Mobile apps + - Custom UIs + - Other backend systems + + + +## Example: Simple Customer Support Chatbot + +Here's a simple example of a customer support chatbot for a fictional product: + +1. **Input Node**: Receives user questions +2. **Vector Database Node**: Searches product documentation for relevant information +3. **LLM Node**: Formulates a helpful response using the retrieved information +4. **Output Node**: Returns the response to the user + +This chatbot can answer questions about product features, troubleshooting steps, and basic account information. + +## Next Steps + +- **Customize Node Configurations**: Fine-tune your chatbot's behavior +- **Add Authentication**: Secure your chatbot API +- **Implement Analytics**: Track and analyze chatbot usage +- **Add Specialized Nodes**: Enhance capabilities with custom functionality + +By following this guide, you've created a basic but functional chatbot in PySpur. As you become more familiar with the platform, you can create increasingly sophisticated chatbots for various use cases. 
diff --git a/pyspur/docs/chatbots/slack-socket-mode.md b/pyspur/docs/chatbots/slack-socket-mode.md new file mode 100644 index 0000000000000000000000000000000000000000..62cb7ed23cf906ad0077e7b033f29b785663a7e0 --- /dev/null +++ b/pyspur/docs/chatbots/slack-socket-mode.md @@ -0,0 +1,97 @@ +# Slack Socket Mode Integration + +PySpur provides integration with Slack's Socket Mode, allowing you to receive events and trigger workflows in real-time without exposing a public URL. + +## What is Socket Mode? + +Socket Mode establishes a WebSocket connection between your Slack app and PySpur, allowing Slack to send events directly over this connection rather than via HTTP webhooks. This is especially useful for: + +- Local development environments +- Environments behind firewalls +- Testing Slack integrations without a public URL +- Avoiding the need to set up and manage an HTTPS endpoint + +## Requirements + +To use Socket Mode with PySpur, you need: + +1. A Slack app with Socket Mode enabled +2. The Slack app's signing secret +3. A bot token, plus an app-level token with the `connections:write` scope +4. PySpur configured with the required environment variables + +## Setup Instructions + +### 1. Create a Slack App + +If you haven't already, create a Slack app at [api.slack.com/apps](https://api.slack.com/apps). + +### 2. Enable Socket Mode + +1. In your Slack app settings, navigate to **Socket Mode** +2. Toggle "Enable Socket Mode" to on +3. Click "Save Changes" + +### 3. Add Required Scopes + +Go to **OAuth & Permissions** and add the following scopes: + +- `connections:write` (required for Socket Mode; this is an app-level token scope, so grant it under **Basic Information** → **App-Level Tokens** rather than as a bot scope) +- `chat:write` (to send messages) +- `channels:history` (to read channel messages) +- `im:history` (to read direct messages) +- `app_mentions:read` (to detect mentions) + +### 4. Subscribe to Events + +Go to **Event Subscriptions** and subscribe to the bot events: + +- `message.im` (for direct messages) +- `message.channels` (for channel messages) +- `app_mention` (for @mentions) + +### 5. 
Get Your Signing Secret + +Go to **Basic Information** and copy your **Signing Secret**. + +### 6. Configure PySpur + +Add the following environment variables to your PySpur instance: + +``` +SLACK_BOT_TOKEN=xoxb-your-bot-token-here +SLACK_SIGNING_SECRET=your-signing-secret-here +``` + +## Using Socket Mode in PySpur + +1. Create a Slack agent in PySpur +2. Configure the agent with your bot token +3. Go to the Socket Mode tab in the agent details +4. Click "Start" to initiate the Socket Mode connection +5. The connection status will show as "Active" when connected + +When events occur in Slack (mentions, messages, etc.), they will be forwarded to PySpur through the WebSocket connection and trigger your configured workflows. + +## Troubleshooting + +### Connection Issues + +- Verify that `SLACK_SIGNING_SECRET` is correctly configured +- Ensure your app-level token has the `connections:write` scope (Slack grants this scope to app-level tokens, not bot tokens) +- Check that Socket Mode is enabled in your Slack app settings + +### Events Not Triggering + +- Verify the agent's trigger settings (mentions, direct messages, etc.) +- Ensure the agent is associated with a workflow +- Check that the agent's trigger is enabled + +## Limitations + +- Socket Mode connections may occasionally disconnect and need to be restarted +- For production environments with high reliability requirements, the HTTP Events API may be more appropriate + +## Recommended Practice + +Use Socket Mode for development and testing, and switch to the HTTP Events API for production deployments where possible.
\ No newline at end of file diff --git a/pyspur/docs/chatbots/slack.mdx b/pyspur/docs/chatbots/slack.mdx new file mode 100644 index 0000000000000000000000000000000000000000..3aff3bb9803f79ed404e1caffaa76be4e68c1eff --- /dev/null +++ b/pyspur/docs/chatbots/slack.mdx @@ -0,0 +1,307 @@ +--- +title: 'Using Chatbots with Slack' +description: 'How to integrate your PySpur chatbots with Slack' +--- + +# Integrating PySpur Chatbots with Slack + +This guide explains how to connect your PySpur chatbots to Slack, allowing users to interact with your chatbot directly through Slack channels and threads. + +## Understanding Slack Integration Options + +PySpur offers two distinct ways to work with Slack: + +1. **Workflow Output to Slack** - Using the `SlackNotifyNode` to send one-time results from a workflow to a Slack channel +2. **Interactive Chatbot in Slack** - Creating a fully interactive chatbot that users can converse with through Slack + +It's important to understand the difference between these approaches: + +| Feature | SlackNotifyNode | Interactive Chatbot | +|---------|-----------------|---------------------| +| Interaction | One-way (workflow → Slack) | Two-way conversation | +| Session Management | None | Full conversation history | +| Use Case | Sending alerts, summaries, notifications | Q&A, support, interactive assistance | +| Implementation | Simple workflow node | Custom Slack app with API integration | + +## Setting Up an Interactive Chatbot in Slack + +Follow these steps to integrate your PySpur chatbot with Slack for interactive conversations: + +### 1. Create a Slack App + +1. Go to [api.slack.com/apps](https://api.slack.com/apps) and click "Create New App" +2. Choose "From scratch" and provide a name and workspace +3. 
In the app settings, enable the following: + - Socket Mode (under "Socket Mode") + - Event Subscriptions (under "Event Subscriptions"), subscribing to the `app_mention`, `message.channels`, `message.groups`, `message.im`, and `message.mpim` bot events that the script below handles + - Bot Token Scopes (under "OAuth & Permissions"): + - `app_mentions:read` + - `channels:history` + - `chat:write` + - `groups:history` + - `im:history` + - `mpim:history` + +### 2. Collect Required Tokens + +You'll need three credentials from your Slack app (two tokens and the signing secret): +- **Bot Token** (`SLACK_BOT_TOKEN`): Found under "OAuth & Permissions" → "Bot User OAuth Token" +- **Signing Secret** (`SLACK_SIGNING_SECRET`): Found under "Basic Information" → "App Credentials" +- **App Token** (`SLACK_APP_TOKEN`): Generate under "Basic Information" → "App-Level Tokens" (with `connections:write` scope) + +### 3. Create Your Integration Script + +Create a Python script similar to the example below. This script: +- Initializes a Slack Bolt app +- Listens for mentions and thread replies +- Creates PySpur sessions for users +- Forwards messages to your PySpur chatbot workflow +- Returns responses back to Slack + +```python +# slack_integration.py +import logging +import os +import sys +from logging import getLogger + +import requests +from dotenv import load_dotenv +from slack_bolt import App +from slack_bolt.adapter.socket_mode import SocketModeHandler + +# Load environment variables from .env file +load_dotenv() + +# Replace with your bot's user ID and workflow ID +BOT_USER_ID = "U08HMJ15AHF" # Your bot's user ID in Slack +WORKFLOW_ID = "S58" # Your chatbot workflow ID in PySpur +PYSPUR_API_URL = "http://localhost:6080/api" # Change to your PySpur API URL + +# Configure logger +logger = getLogger(__name__) +logger.setLevel(logging.INFO) +handler = logging.StreamHandler(sys.stderr) +handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +) +logger.addHandler(handler) + +# Initialize Slack app +app = App( + token=os.environ.get("SLACK_BOT_TOKEN"), + signing_secret=os.environ.get("SLACK_SIGNING_SECRET"), + logger=logger, +) 
+ +# Handler for @mentions +@app.event("app_mention") +def handle_app_mention(event, say, logger): + logger.info(f"Received mention: {event}") + thread_ts = event.get("thread_ts") or event.get("ts") + user_external_id = event.get("user") + + try: + # Create or get user + user_data = { + "external_id": user_external_id, + "user_metadata": {"platform": "slack"} + } + user_response = requests.post(f"{PYSPUR_API_URL}/user/", json=user_data) + if user_response.status_code not in [200, 409]: # 409 means user already exists + logger.error(f"Failed to create user: {user_response.text}") + say(text="Sorry, I encountered an error processing your request.", thread_ts=thread_ts) + return + + user_id = user_response.json().get("id") + + # Create a new session or get existing one + session_data = { + "workflow_id": WORKFLOW_ID, + "user_id": user_id, + "external_id": thread_ts + } + session_response = requests.post(f"{PYSPUR_API_URL}/session/", json=session_data) + if session_response.status_code != 200: + logger.error(f"Failed to create session: {session_response.text}") + say(text="Sorry, I encountered an error processing your request.", thread_ts=thread_ts) + return + + # Get the message text + message = app.client.conversations_history( + channel=event["channel"], ts=thread_ts + ) + + # Call the workflow API + url = f"{PYSPUR_API_URL}/wf/{WORKFLOW_ID}/run/?run_type=blocking" + data = { + "initial_inputs": { + "input_node": { + "user_message": message["messages"][0]["text"], + "session_id": session_response.json()["id"], + "message_history": [] + } + } + } + + response = requests.post(url, json=data) + response_data = response.json() + + # Get the assistant's message from the output node + assistant_message = response_data.get("output_node", {}).get("assistant_message", "") + + # Send the response back to Slack + if assistant_message: + say(text=assistant_message, thread_ts=thread_ts) + else: + say(text="I encountered an issue processing your request.", thread_ts=thread_ts) + 
+ except Exception as e: + logger.error(f"Error processing request: {e}") + say(text="Sorry, I encountered an error processing your request.", thread_ts=thread_ts) + +# Handler for thread replies +@app.event("message") +def handle_thread_replies(event, say, logger): + # Only process thread replies (not the first message) and ignore bot messages + if ( + "thread_ts" in event + and event.get("ts") != event.get("thread_ts") + and not event.get("bot_id") + ): + thread_ts = event["thread_ts"] + channel_id = event["channel"] + user_external_id = event.get("user") + + try: + # Create or get user + user_data = { + "external_id": user_external_id, + "user_metadata": {"platform": "slack"} + } + user_response = requests.post(f"{PYSPUR_API_URL}/user/", json=user_data) + if user_response.status_code not in [200, 409]: + logger.error(f"Failed to create user: {user_response.text}") + say(text="Sorry, I encountered an error processing your request.", thread_ts=thread_ts) + return + + user_id = user_response.json().get("id") + + # Create a new session or get existing one + session_data = { + "workflow_id": WORKFLOW_ID, + "user_id": user_id, + "external_id": thread_ts + } + session_response = requests.post(f"{PYSPUR_API_URL}/session/", json=session_data) + if session_response.status_code != 200: + logger.error(f"Failed to create session: {session_response.text}") + say(text="Sorry, I encountered an error processing your request.", thread_ts=thread_ts) + return + + # Get all replies in the thread + result = app.client.conversations_replies(channel=channel_id, ts=thread_ts) + + # Format messages as a conversation history + chat_messages = [] + for message in result["messages"]: + role = "assistant" if message.get("user") == BOT_USER_ID else "user" + chat_messages.append({"role": role, "content": message.get("text", "")}) + + # Get message history and current message + message_history = chat_messages[:-1] if len(chat_messages) > 1 else [] + user_message = chat_messages[-1]["content"] 
if chat_messages else "" + + # Call the workflow API + url = f"{PYSPUR_API_URL}/wf/{WORKFLOW_ID}/run/?run_type=blocking" + data = { + "initial_inputs": { + "input_node": { + "user_message": user_message, + "session_id": session_response.json()["id"], + "message_history": message_history, + } + } + } + + response = requests.post(url, json=data) + response_data = response.json() + + # Get the assistant's message + assistant_message = response_data.get("output_node", {}).get("assistant_message", "") + + # Send the response back to Slack + if assistant_message: + say(text=assistant_message, thread_ts=thread_ts) + else: + say(text="I encountered an issue processing your request.", thread_ts=thread_ts) + + except Exception as e: + logger.error(f"Error processing thread reply: {e}") + say(text="Sorry, I had trouble processing your message.", thread_ts=thread_ts) + +if __name__ == "__main__": + # Start the app using Socket Mode + handler = SocketModeHandler(app, os.environ.get("SLACK_APP_TOKEN")) + handler.start() +``` + +### 4. Set Up Environment Variables + +Create a `.env` file with your tokens: + +``` +SLACK_BOT_TOKEN=xoxb-your-token +SLACK_SIGNING_SECRET=your-signing-secret +SLACK_APP_TOKEN=xapp-your-app-token +``` + +### 5. Run Your Integration + +1. Install required packages: + ``` + pip install slack-bolt python-dotenv requests + ``` + +2. Make sure your PySpur server is running + +3. Start your Slack integration: + ``` + python slack_integration.py + ``` + +## How It Works + +This integration creates a bidirectional connection between Slack and your PySpur chatbot: + +1. **User Input**: When a user mentions your bot or replies in a thread, the Slack app captures this input. + +2. **Session Management**: The script creates or retrieves a PySpur session for the user, using the thread timestamp as the external ID to track conversations. + +3. 
**Message History**: For thread replies, the script retrieves the entire conversation history and formats it in the proper structure for your chatbot. + +4. **PySpur API Call**: The user message and conversation history are sent to your PySpur workflow through the API. + +5. **Response**: The assistant's response from your workflow is posted back to the Slack thread. + +## Key Differences from SlackNotifyNode + +This approach differs significantly from the `SlackNotifyNode`: + +- **SlackNotifyNode** is a one-way communication channel where your workflow sends a message to Slack when it completes. It's ideal for notifications, alerts, or sharing results of automated processes. + +- **Interactive Slack Integration** creates a two-way communication channel where users can have ongoing conversations with your chatbot. The integration manages sessions, tracks conversation history, and maintains context across multiple interactions. + +## Example Use Cases + +- **Support Bot**: Create a support chatbot that answers user questions about your product +- **Data Query Assistant**: Allow users to query company data through natural language in Slack +- **Task Manager**: Build a bot that helps users create and track tasks through conversation +- **Knowledge Base**: Develop a bot that retrieves and explains information from your documentation + +## Troubleshooting + +- **Bot not responding**: Ensure your bot has been invited to the channel and has the necessary permissions +- **Missing messages**: Check your log output for any API errors +- **Session errors**: Verify that your PySpur server is running and accessible +- **Token issues**: Double-check that all environment variables are set correctly diff --git a/pyspur/docs/essentials/code.mdx b/pyspur/docs/essentials/code.mdx new file mode 100644 index 0000000000000000000000000000000000000000..918a98d9866b5bd7b9d32078cc8cc24acc1062f6 --- /dev/null +++ b/pyspur/docs/essentials/code.mdx @@ -0,0 +1,37 @@ +--- +title: 'Code Blocks' 
+description: 'Display inline code and code blocks' +icon: 'code' +--- + +## Basic + +### Inline Code + +To denote a `word` or `phrase` as code, enclose it in backticks (`). + +``` +To denote a `word` or `phrase` as code, enclose it in backticks (`). +``` + +### Code Block + +Use [fenced code blocks](https://www.markdownguide.org/extended-syntax/#fenced-code-blocks) by enclosing code in three backticks and follow the leading ticks with the programming language of your snippet to get syntax highlighting. Optionally, you can also write the name of your code after the programming language. + +```java HelloWorld.java +class HelloWorld { + public static void main(String[] args) { + System.out.println("Hello, World!"); + } +} +``` + +````md +```java HelloWorld.java +class HelloWorld { + public static void main(String[] args) { + System.out.println("Hello, World!"); + } +} +``` +```` diff --git a/pyspur/docs/essentials/deployment.mdx b/pyspur/docs/essentials/deployment.mdx new file mode 100644 index 0000000000000000000000000000000000000000..95099be327d5abfec8cd9392b30789c5a8c7018c --- /dev/null +++ b/pyspur/docs/essentials/deployment.mdx @@ -0,0 +1,223 @@ +--- +title: 'Deploying Spurs as APIs' +description: 'Turn your Spur workflows into production-ready APIs with one click' +--- + +## One-Click Deployment + +PySpur makes it incredibly easy to deploy your workflows as production-ready APIs with a single click. + + + + Navigate to any workflow you've created and want to deploy. + + + Click the "Deploy" button in the top navigation bar to open the deployment modal. 
+ Deploy Modal Light Mode + Deploy Modal Dark Mode + + + In the modal that appears, you can configure: + - **API call type**: Choose between blocking (synchronous) or non-blocking (asynchronous) calls + - **Programming language**: Select your preferred language for the code example + + For example, in Python: + Deploy Modal Light Mode + Deploy Modal Dark Mode + Or in TypeScript: + Deploy Modal Light Mode + Deploy Modal Dark Mode + + + Copy the generated code example to integrate with your application. + + + +## API Call Types + +PySpur supports two types of API calls when deploying your workflows: + + + + Use blocking calls when: + - You need immediate results + - The workflow completes quickly + - You want to process the response in the same request + + The API will wait for the workflow to complete before returning a response. + + ```bash + # Endpoint structure + POST /api/wf/{workflow_id}/run/?run_type=blocking + ``` + + + Use non-blocking calls when: + - Workflows may take longer to complete + - You want to decouple request and response + - You need better scalability for long-running tasks + + The API will immediately return a run ID, and you can check the status later. 
+ + ```bash + # Start endpoint + POST /api/wf/{workflow_id}/start_run/?run_type=non_blocking + + # Status check endpoint + GET /api/runs/{run_id}/status/ + ``` + + + +## Code Examples + +The deployment modal provides ready-to-use code examples in various programming languages: + + + + ```python + import requests + + # For blocking calls + url = 'https://your-pyspur-instance.com/api/wf/{workflow_id}/run/?run_type=blocking' + data = { + "initial_inputs": { + "InputNode_1": { + "input_field_1": "example_value", + "input_field_2": 123 + } + } + } + + response = requests.post(url, json=data) + print(response.status_code) + print(response.json()) + ``` + + For non-blocking calls: + + ```python + # Step 1: Start the workflow + url = 'https://your-pyspur-instance.com/api/wf/{workflow_id}/start_run/?run_type=non_blocking' + response = requests.post(url, json=data) + run_id = response.json()['id'] + + # Step 2: Check status later + status_url = f'https://your-pyspur-instance.com/api/runs/{run_id}/status/' + status_response = requests.get(status_url) + print(status_response.json()) + ``` + + + ```javascript + // For blocking calls + fetch('https://your-pyspur-instance.com/api/wf/{workflow_id}/run/?run_type=blocking', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + initial_inputs: { + InputNode_1: { + input_field_1: "example_value", + input_field_2: 123 + } + } + }) + }) + .then(response => response.json()) + .then(data => console.log(data)) + .catch(error => console.error('Error:', error)); + ``` + + + ```bash + # For blocking calls + curl -X POST 'https://your-pyspur-instance.com/api/wf/{workflow_id}/run/?run_type=blocking' \ + -H "Content-Type: application/json" \ + -d '{ + "initial_inputs": { + "InputNode_1": { + "input_field_1": "example_value", + "input_field_2": 123 + } + } + }' + ``` + + + +## Advanced Deployment Options + + + + Run your workflow over a dataset with the batch processing API + + ``` + POST 
/api/wf/{workflow_id}/start_batch_run/ + ``` + + Provide a dataset ID and mini-batch size to process large datasets efficiently. + + + Cancel in-progress workflows when needed + + ``` + POST /api/cancel_workflow/{run_id}/ + ``` + + This is useful for stopping long-running or paused workflows. + + + PySpur provides full control over your deployed workflows with APIs for: + - Listing all runs of a workflow + - Retrieving run status + - Handling human-in-the-loop interventions + + + +## Security Considerations + +When deploying workflows as APIs, consider: + +1. **API Authentication**: Add appropriate authentication to your PySpur instance +2. **Input Validation**: Ensure workflows validate inputs properly +3. **Error Handling**: Implement robust error handling in your client code + +## Next Steps + + + + Learn how to secure your deployed APIs + + + Track usage and performance of your deployed Spurs + + + Explore additional deployment configuration options + + diff --git a/pyspur/docs/essentials/images.mdx b/pyspur/docs/essentials/images.mdx new file mode 100644 index 0000000000000000000000000000000000000000..11650b7b63c07fe028f2b5939163ef9cd7d11649 --- /dev/null +++ b/pyspur/docs/essentials/images.mdx @@ -0,0 +1,59 @@ +--- +title: 'Images and Embeds' +description: 'Add image, video, and other HTML elements' +icon: 'image' +--- + + + +## Image + +### Using Markdown + +The [markdown syntax](https://www.markdownguide.org/basic-syntax/#images) lets you add images using the following code + +```md +![title](/path/image.jpg) +``` + +Note that the image file size must be less than 5MB. Otherwise, we recommend hosting on a service like [Cloudinary](https://cloudinary.com/) or [S3](https://aws.amazon.com/s3/). You can then use that URL and embed. + +### Using Embeds + +To get more customizability with images, you can also use [embeds](/writing-content/embed) to add images + +```html + +``` + +## Embeds and HTML elements + + + +
+ + + +Mintlify supports [HTML tags in Markdown](https://www.markdownguide.org/basic-syntax/#html). This is helpful if you prefer HTML tags to Markdown syntax, and lets you create documentation with infinite flexibility. + + + +### iFrames + +Loads another HTML page within the document. Most commonly used for embedding videos. + +```html + +``` diff --git a/pyspur/docs/essentials/markdown.mdx b/pyspur/docs/essentials/markdown.mdx new file mode 100644 index 0000000000000000000000000000000000000000..bd19e3065cacffa303a8e1f84c6b2e9f1d142bdb --- /dev/null +++ b/pyspur/docs/essentials/markdown.mdx @@ -0,0 +1,88 @@ +--- +title: 'Markdown Syntax' +description: 'Text, title, and styling in standard markdown' +icon: 'text-size' +--- + +## Titles + +Best used for section headers. + +```md +## Titles +``` + +### Subtitles + +Best used for subsection headers. + +```md +### Subtitles +``` + + + +Each **title** and **subtitle** creates an anchor and also shows up on the table of contents on the right. + + + +## Text Formatting + +We support most markdown formatting. Simply add `**`, `_`, or `~` around text to format it. + +| Style | How to write it | Result | +| ------------- | ----------------- | --------------- | +| Bold | `**bold**` | **bold** | +| Italic | `_italic_` | _italic_ | +| Strikethrough | `~strikethrough~` | ~strikethrough~ | + +You can combine these. For example, write `**_bold and italic_**` to get **_bold and italic_** text. + +You need to use HTML to write superscript and subscript text. That is, add `<sup>` or `<sub>` around your text. + +| Text Size | How to write it | Result | +| ----------- | ------------------------ | ---------------------- | +| Superscript | `<sup>superscript</sup>` | <sup>superscript</sup> | +| Subscript | `<sub>subscript</sub>` | <sub>subscript</sub> | + +## Linking to Pages + +You can add a link by wrapping text in `[]()`. You would write `[link to google](https://google.com)` to [link to google](https://google.com). + +Links to pages in your docs need to be root-relative.
Basically, you should include the entire folder path. For example, `[link to text](/writing-content/text)` links to the page "Text" in our components section. + +Relative links like `[link to text](../text)` will open slower because we cannot optimize them as easily. + +## Blockquotes + +### Singleline + +To create a blockquote, add a `>` in front of a paragraph. + +> Dorothy followed her through many of the beautiful rooms in her castle. + +```md +> Dorothy followed her through many of the beautiful rooms in her castle. +``` + +### Multiline + +> Dorothy followed her through many of the beautiful rooms in her castle. +> +> The Witch bade her clean the pots and kettles and sweep the floor and keep the fire fed with wood. + +```md +> Dorothy followed her through many of the beautiful rooms in her castle. +> +> The Witch bade her clean the pots and kettles and sweep the floor and keep the fire fed with wood. +``` + +### LaTeX + +Mintlify supports [LaTeX](https://www.latex-project.org) through the Latex component. + +8 x (vk x H1 - H2) = (0,1) + +```md +8 x (vk x H1 - H2) = (0,1) +``` diff --git a/pyspur/docs/essentials/navigation.mdx b/pyspur/docs/essentials/navigation.mdx new file mode 100644 index 0000000000000000000000000000000000000000..3b956c730f156b4d448b815ae9fd523e78380afc --- /dev/null +++ b/pyspur/docs/essentials/navigation.mdx @@ -0,0 +1,66 @@ +--- +title: 'Navigation' +description: 'The navigation field in mint.json defines the pages that go in the navigation menu' +icon: 'map' +--- + +The navigation menu is the list of links on every website. + +You will likely update `mint.json` every time you add a new page. Pages do not show up automatically. + +## Navigation syntax + +Our navigation syntax is recursive which means you can make nested navigation groups. You don't need to include `.mdx` in page names. 
+ + + +```json Regular Navigation +"navigation": [ + { + "group": "Getting Started", + "pages": ["quickstart"] + } +] +``` + +```json Nested Navigation +"navigation": [ + { + "group": "Getting Started", + "pages": [ + "quickstart", + { + "group": "Nested Reference Pages", + "pages": ["nested-reference-page"] + } + ] + } +] +``` + + + +## Folders + +Simply put your MDX files in folders and update the paths in `mint.json`. + +For example, to have a page at `https://yoursite.com/your-folder/your-page` you would make a folder called `your-folder` containing an MDX file called `your-page.mdx`. + + + +You cannot use `api` for the name of a folder unless you nest it inside another folder. Mintlify uses Next.js which reserves the top-level `api` folder for internal server calls. A folder name such as `api-reference` would be accepted. + + + +```json Navigation With Folder +"navigation": [ + { + "group": "Group Name", + "pages": ["your-folder/your-page"] + } +] +``` + +## Hidden Pages + +MDX files not included in `mint.json` will not show up in the sidebar but are accessible through the search bar and by linking directly to them. diff --git a/pyspur/docs/essentials/nodes.mdx b/pyspur/docs/essentials/nodes.mdx new file mode 100644 index 0000000000000000000000000000000000000000..b88165de17be6b1584dad267ba6cc0889f13efda --- /dev/null +++ b/pyspur/docs/essentials/nodes.mdx @@ -0,0 +1,235 @@ +--- +title: Nodes +description: The building blocks of any spur +icon: 'cube' +--- + +import { Callout } from 'nextra/components' + +## What are Nodes? + +Nodes are typed functions that serve as the building blocks of workflows and agent tools. 
Each node is a self-contained unit that: +- Has a defined schema for inputs and outputs +- Performs a specific task or operation +- Can be connected to other nodes in a workflow +- Can be used as tools by agents + +### Visual Overview + +```mermaid +graph TD + A[Input Node] --> B[Processing Node] + B --> C[Output Node] + + subgraph "Node Structure" + D[Configuration] --> E[Node] + F[Input Schema] --> E + G[Output Schema] --> E + E --> H[Execution Logic] + end +``` + +## Using Nodes in Your Project + +### As Workflow Components + +Nodes can be connected together to create workflows. The workflow executor handles: +- Type validation between connected nodes +- Dependency resolution and execution order +- Data flow between nodes + +```mermaid +graph LR + A[Input Node] -->|question| B[LLM Node] + B -->|response| C[Output Node] + + subgraph "Data Flow" + D[question: string] -->|validates| E[response: string] + end +``` + +```python +# Example of nodes in a workflow +workflow = WorkflowDefinitionSchema( + nodes=[ + { + "id": "input_node", + "title": "User Input", + "node_type": "InputNode", + "config": {"output_schema": {"question": "string"}} + }, + { + "id": "llm_node", + "title": "LLM Processing", + "node_type": "SingleLLMCallNode", + "config": { + "system_message": "You are a helpful assistant.", + "user_message": "{{ question }}" + } + } + ], + links=[ + { + "source_id": "input_node", + "target_id": "llm_node" + } + ] +) +``` + +### As Agent Tools + +Nodes can be used directly by agents as tools. Each node: +- Has built-in parameter validation +- Returns structured outputs +- Can access configured integrations + +## Node Components + +### Configuration + +Each node has a configuration class that defines its behavior: + +```python +class SlackNotifyNodeConfig(BaseNodeConfig): + channel: str = Field("", description="The channel ID to send the message to.") + message: str = Field( + default="", + description="The message template to send to Slack." 
+ ) +``` + +### Input/Output Schema + +Nodes use Pydantic models to define their input and output schemas: + +```python +class SlackNotifyNodeInput(BaseNodeInput): + message: str + channel: str + +class SlackNotifyNodeOutput(BaseNodeOutput): + status: str +``` + +### Execution + +The node's core logic is implemented in its `run` method: + +```python +async def run(self, input: BaseModel) -> BaseModel: + # Node implementation + result = await self.process(input) + return self.output_model.model_validate(result) +``` + +## Node Types and Examples + +### Core Node Types + +```mermaid +graph TD + A[Base Node] --> B[AI Nodes] + A --> C[Integration Nodes] + A --> D[Logic Nodes] +``` + +### 1. AI Nodes + +AI nodes handle interactions with language models and other AI services. Located in `@llm` directory. + +```python +@NodeRegistry.register(category="LLM") +class SingleLLMCallNode(BaseNode): + name = "single_llm_call_node" + + async def run(self, input: BaseModel) -> BaseModel: + messages = create_messages( + system_message=self.config.system_message, + user_message=Template(self.config.user_message).render(**input.model_dump()) + ) + result = await generate_text(messages=messages, **self.config.llm_info) + return self.output_model.model_validate(result) +``` + +### 2. Integration Nodes + +Integration nodes connect with external services and APIs. Located in `@integrations` directory. + +```python +@NodeRegistry.register(category="Integrations") +class SlackNotifyNode(BaseNode): + name = "slack_notify_node" + + async def run(self, input: BaseModel) -> BaseModel: + message = Template(self.config.message).render(**input.model_dump()) + client = SlackClient() + ok, status = client.send_message( + channel=self.config.channel, + text=message + ) + return SlackNotifyNodeOutput(status=status) +``` + +### 3. Logic Nodes + +Logic nodes control workflow execution paths and data flow. Located in `@logic` directory. 
+ +```python +@NodeRegistry.register(category="Logic") +class RouterNode(BaseNode): + name = "router_node" + + async def run(self, input: BaseModel) -> BaseModel: + condition = Template(self.config.condition).render(**input.model_dump()) + route = "true_path" if eval(condition) else "false_path" + return RouterNodeOutput(**{route: input.model_dump()}) +``` + + +The node categories reflect the actual codebase organization: +- **AI Nodes** (`@llm`): Language models and AI services +- **Integration Nodes** (`@integrations`): External service connections +- **Logic Nodes** (`@logic`): Flow control and data routing + + +## Creating Custom Nodes + +### Basic Structure + +To create a custom node: + +1. Define the configuration, input, and output models +2. Implement the node class +3. Register the node + +```python +@NodeRegistry.register(category="Integrations") +class CustomNode(BaseNode): + name = "custom_node" + config_model = CustomNodeConfig + input_model = CustomNodeInput + output_model = CustomNodeOutput + + async def run(self, input: BaseModel) -> BaseModel: + # Implement node logic here + pass +``` + +### Node Registration + +Nodes can be registered using the decorator pattern: + +```python +@NodeRegistry.register( + category="Integrations", + display_name="Slack Notify", + logo="/images/slack.png" +) +class SlackNotifyNode(BaseNode): + # Node implementation +``` + + +Remember to properly handle errors and validate inputs/outputs in your custom nodes. The base node class provides built-in validation, but you should add domain-specific validation in your `run` method. 
+ diff --git a/pyspur/docs/essentials/reusable-snippets.mdx b/pyspur/docs/essentials/reusable-snippets.mdx new file mode 100644 index 0000000000000000000000000000000000000000..c1b1ca35b322f4a7278ba4751befdbd79f10bf98 --- /dev/null +++ b/pyspur/docs/essentials/reusable-snippets.mdx @@ -0,0 +1,110 @@ +--- +title: Reusable Snippets +description: Reusable, custom snippets to keep content in sync +icon: 'recycle' +--- + +import SnippetIntro from '/snippets/snippet-intro.mdx'; + + + +## Creating a custom snippet + +**Pre-condition**: You must create your snippet file in the `snippets` directory. + + + Any page in the `snippets` directory will be treated as a snippet and will not + be rendered into a standalone page. If you want to create a standalone page + from the snippet, import the snippet into another file and call it as a + component. + + +### Default export + +1. Add content to your snippet file that you want to re-use across multiple + locations. Optionally, you can add variables that can be filled in via props + when you import the snippet. + +```mdx snippets/my-snippet.mdx +Hello world! This is my content I want to reuse across pages. My keyword of the +day is {word}. +``` + + + The content that you want to reuse must be inside the `snippets` directory in + order for the import to work. + + +2. Import the snippet into your destination file. + +```mdx destination-file.mdx +--- +title: My title +description: My Description +--- + +import MySnippet from '/snippets/path/to/my-snippet.mdx'; + +## Header + +Lorem ipsum dolor sit amet. + + +``` + +### Reusable variables + +1. Export a variable from your snippet file: + +```mdx snippets/path/to/custom-variables.mdx +export const myName = 'my name'; + +export const myObject = { fruit: 'strawberries' }; +``` + +2.
Import the snippet from your destination file and use the variable: + +```mdx destination-file.mdx +--- +title: My title +description: My Description +--- + +import { myName, myObject } from '/snippets/path/to/custom-variables.mdx'; + +Hello, my name is {myName} and I like {myObject.fruit}. +``` + +### Reusable components + +1. Inside your snippet file, create a component that takes in props by exporting + your component in the form of an arrow function. + +```mdx snippets/custom-component.mdx +export const MyComponent = ({ title }) => ( +
+

{title}

+

... snippet content ...

+
+); +``` + + + MDX does not compile inside the body of an arrow function. Stick to HTML + syntax when you can or use a default export if you need to use MDX. + + +2. Import the snippet into your destination file and pass in the props + +```mdx destination-file.mdx +--- +title: My title +description: My Description +--- + +import { MyComponent } from '/snippets/custom-component.mdx'; + +Lorem ipsum dolor sit amet. + + +``` diff --git a/pyspur/docs/essentials/settings.mdx b/pyspur/docs/essentials/settings.mdx new file mode 100644 index 0000000000000000000000000000000000000000..1d25a776ed378f928b1f5c88277367703523db0e --- /dev/null +++ b/pyspur/docs/essentials/settings.mdx @@ -0,0 +1,318 @@ +--- +title: 'Global Settings' +description: 'Mintlify gives you complete control over the look and feel of your documentation using the mint.json file' +icon: 'gear' +--- + +Every Mintlify site needs a `mint.json` file with the core configuration settings. Learn more about the [properties](#properties) below. + +## Properties + + +Name of your project. Used for the global title. + +Example: `mintlify` + + + + + An array of groups with all the pages within that group + + + The name of the group. + + Example: `Settings` + + + + The relative paths to the markdown files that will serve as pages. + + Example: `["customization", "page"]` + + + + + + + + Path to logo image or object with path to "light" and "dark" mode logo images + + + Path to the logo in light mode + + + Path to the logo in dark mode + + + Where clicking on the logo links you to + + + + + + Path to the favicon image + + + + Hex color codes for your global theme + + + The primary color. Used most often for highlighted content, section + headers, accents, in light mode + + + The primary color for dark mode.
Used most often for highlighted + content, section headers, accents, in dark mode + + + The primary color for important buttons + + + The color of the background in both light and dark mode + + + The hex color code of the background in light mode + + + The hex color code of the background in dark mode + + + + + + + + Array of `name`s and `url`s of links you want to include in the topbar + + + The name of the button. + + Example: `Contact us` + + + The url once you click on the button. Example: `https://mintlify.com/docs` + + + + + + + + + Link shows a button. GitHub shows the repo information at the url provided including the number of GitHub stars. + + + If `link`: What the button links to. + + If `github`: Link to the repository to load GitHub information from. + + + Text inside the button. Only required if `type` is a `link`. + + + + + + + Array of version names. Only use this if you want to show different versions + of docs with a dropdown in the navigation bar. + + + + An array of the anchors, includes the `icon`, `color`, and `url`. + + + The [Font Awesome](https://fontawesome.com/search?q=heart) icon used to feature the anchor. + + Example: `comments` + + + The name of the anchor label. + + Example: `Community` + + + The start of the URL that marks what pages go in the anchor. Generally, this is the name of the folder you put your pages in. + + + The hex color of the anchor icon background. Can also be a gradient if you pass an object with the properties `from` and `to` that are each a hex color. + + + Used if you want to hide an anchor until the correct docs version is selected. + + + Pass `true` if you want to hide the anchor until you directly link someone to docs inside it. + + + One of: "brands", "duotone", "light", "sharp-solid", "solid", or "thin" + + + + + + + Override the default configurations for the top-most anchor. + + + The name of the top-most anchor + + + Font Awesome icon.
+ + + One of: "brands", "duotone", "light", "sharp-solid", "solid", or "thin" + + + + + + An array of navigational tabs. + + + The name of the tab label. + + + The start of the URL that marks what pages go in the tab. Generally, this + is the name of the folder you put your pages in. + + + + + + Configuration for API settings. Learn more about API pages at [API Components](/api-playground/demo). + + + The base URL for all API endpoints. If `baseUrl` is an array, it will enable multiple base URL + options that the user can toggle. + + + + + + The authentication strategy used for all API endpoints. + + + The name of the authentication parameter used in the API playground. + + If method is `basic`, the format should be `[usernameName]:[passwordName]` + + + The default value that's designed to be a prefix for the authentication input field. + + E.g. an `inputPrefix` of `AuthKey` sets the default value of the authentication input field to `AuthKey`. + + + + + + Configurations for the API playground + + + + Whether the playground is showing, hidden, or only displaying the endpoint with no added user interactivity `simple` + + Learn more at the [playground guides](/api-playground/demo) + + + + + + Enabling this flag ensures that key ordering in OpenAPI pages matches the key ordering defined in the OpenAPI file. + + This behavior will soon be enabled by default, at which point this field will be deprecated. + + + + + + + A string or an array of strings of URL(s) or relative path(s) pointing to your + OpenAPI file. + + Examples: + + ```json Absolute + "openapi": "https://example.com/openapi.json" + ``` + ```json Relative + "openapi": "/openapi.json" + ``` + ```json Multiple + "openapi": ["https://example.com/openapi1.json", "/openapi2.json", "/openapi3.json"] + ``` + + + + + + An object of social media accounts where the key:property pair represents the social media platform and the account url.
+ + Example: + ```json + { + "x": "https://x.com/mintlify", + "website": "https://mintlify.com" + } + ``` + + + One of the following values `website`, `facebook`, `x`, `discord`, `slack`, `github`, `linkedin`, `instagram`, `hacker-news` + + Example: `x` + + + The URL to the social platform. + + Example: `https://x.com/mintlify` + + + + + + Configurations to enable feedback buttons + + + + Enables a button to allow users to suggest edits via pull requests + + + Enables a button to allow users to raise an issue about the documentation + + + + + + Customize the dark mode toggle. + + + Set if you always want to show light or dark mode for new users. When not + set, we default to the same mode as the user's operating system. + + + Set to true to hide the dark/light mode toggle. You can combine `isHidden` with `default` to force your docs to only use light or dark mode. For example: + + + ```json Only Dark Mode + "modeToggle": { + "default": "dark", + "isHidden": true + } + ``` + + ```json Only Light Mode + "modeToggle": { + "default": "light", + "isHidden": true + } + ``` + + + + + + + + + A background image to be displayed behind every page. See example with + [Infisical](https://infisical.com/docs) and [FRPC](https://frpc.io). + diff --git a/pyspur/docs/essentials/spurs.mdx b/pyspur/docs/essentials/spurs.mdx new file mode 100644 index 0000000000000000000000000000000000000000..829879821bd53718032facda792b9ac8360e223d --- /dev/null +++ b/pyspur/docs/essentials/spurs.mdx @@ -0,0 +1,142 @@ +--- +title: 'Spurs' +description: 'AI-native dynamic workflow graphs' +icon: 'bolt' +--- + +import { Callout } from 'nextra/components' + +## What are Spurs? + +Spurs are AI-native, dynamic workflow graphs that bridge the gap between traditional workflows and autonomous agents. Think of them as the cool middle child in the automation family - more flexible than their rigid DAG siblings, but more structured than their free-spirited agent cousins. 
+ +```mermaid +graph TD + A[Traditional DAGs] --> B[Spurs] + C[Autonomous Agents] --> B + style B fill:#f9f,stroke:#333,stroke-width:4px +``` + +### Visual Overview + +```mermaid +graph TD + A[Input Node] --> B[LLM Node] + B --> C[Tool Node] + C --> B + B --> D[Output Node] + + style B fill:#f9f,stroke:#333 +``` + +## Why Spurs? + +### Beyond Traditional DAGs + +Traditional workflow systems are like trains on tracks - they follow predetermined paths and can't deviate. But real-world problems often require more flexibility. What if your workflow needs to: +- Decide its next step based on AI analysis? +- Loop through a process until a condition is met? +- Dynamically choose which tools to use? + +This is where Spurs shine. They combine the reliability of structured workflows with the adaptability of AI. + +### The Power of AI-Native Workflows + +Spurs are built from the ground up with AI in mind. Unlike traditional DAGs that can only flow in one direction, Spurs can: +- Create cycles where LLM agents call themselves repeatedly +- Dynamically decide execution paths +- Integrate seamlessly with AI tools and services + + +While traditional workflows say "do this, then that", Spurs say "analyze this, decide what to do next, and keep going until you achieve the goal." + + +## Building Blocks + +### Nodes: The Foundation + +Spurs are composed of nodes - modular, typed functions that serve as building blocks. Each node can: +- Process inputs and produce outputs +- Make decisions about workflow continuation +- Interact with external services +- Execute AI operations + +For a detailed explanation of nodes, check out our [Nodes documentation](/essentials/nodes). 
+ +### Execution Model + +The execution engine behind Spurs is asynchronous and intelligent: +- Nodes run as soon as their dependencies are satisfied +- Multiple nodes can execute in parallel +- The execution path can change dynamically based on LLM decisions +- Cycles are handled gracefully, allowing for iterative processing + +```mermaid +graph LR + A[Input] --> B[Process] + B --> C{LLM Decision} + C -->|Continue| D[More Processing] + C -->|Complete| E[Output] + D --> C +``` + +## Creating Your First Spur + +You have three paths to create Spurs; choose the one that best fits your workflow: + +### 1. Visual UI Builder +Perfect for visual thinkers and rapid prototyping: +```python +# No code needed! Just drag, drop, and connect nodes +``` + +### 2. Python Package +For those who prefer programmatic control: +```python +from pyspur import Spur, Node + +spur = Spur() +spur.add_node(Node.input("user_query")) +spur.add_node(Node.llm("analysis")) +spur.add_node(Node.tool("web_search")) +``` + +### 3. JSON Structure +For infrastructure-as-code and version control: +```json +{ + "nodes": [ + {"type": "input", "id": "query"}, + {"type": "llm", "id": "analysis"}, + {"type": "tool", "id": "search"} + ], + "edges": [...] +} +``` + + +Remember: With great power comes great responsibility. Spurs can contain cycles, so make sure every cycle has a clear termination condition to avoid infinite loops! + + +## Advanced Concepts + +### Cyclic Workflows + +Unlike traditional DAGs, Spurs embrace cycles. This enables powerful patterns like: +- Iterative refinement of results +- Recursive problem solving +- Dynamic tool selection and usage + +### Dynamic Execution + +The async execution engine ensures optimal performance: +```python +# Behind the scenes in workflow_executor.py +async def execute_node(self, node_id: str): + # Run nodes as soon as dependencies are satisfied + # No waiting for the entire level to complete! 
+``` + + +Think of Spurs as a jazz band rather than an orchestra - each node can improvise and play its part when ready, while still maintaining harmony with the overall piece. + diff --git a/pyspur/docs/evals/concepts.mdx b/pyspur/docs/evals/concepts.mdx new file mode 100644 index 0000000000000000000000000000000000000000..b65c574d8a7fb016a6f992951e00610ff55deffd --- /dev/null +++ b/pyspur/docs/evals/concepts.mdx @@ -0,0 +1,120 @@ +--- +title: 'Concepts' +description: 'Learn how PySpur helps you measure the performance of your AI workflows' +--- + +# Understanding Evaluations in PySpur + +Evaluation is the process of measuring how well your AI workflows perform against objective benchmarks. Instead of guessing if your workflow is doing a good job, evaluations provide quantitative metrics so you can: + +- Measure the accuracy of your workflow's outputs +- Compare different versions of your workflows +- Identify areas for improvement +- Build trust in your AI systems + +## Why Evaluate? + +Without evaluation, it's difficult to know if your AI systems are performing as expected. Evaluations help you: + +- **Verify accuracy**: Ensure your workflows produce correct answers +- **Track improvement**: Measure progress as you refine your workflows +- **Compare approaches**: Determine which techniques work best +- **Build confidence**: Provide evidence of your system's capabilities + +## How Evaluations Work in PySpur + +The evaluation process in PySpur has three main components: + +### 1. Evaluation Benchmarks + +PySpur includes pre-built benchmarks from academic and industry standards. 
Each benchmark: + +- Contains a dataset of problems with known correct answers +- Specifies how to format inputs for your workflow +- Defines how to extract and evaluate outputs from your workflow + +For demonstration purposes, we provide some stock benchmarks for: +- Mathematical reasoning (GSM8K) +- Graduate-level question answering + +However, the real power of evals is unlocked when you run them on data that matches your own use cases. + +### 2. Your Workflow + +You connect your existing PySpur workflow to the evaluation system. The workflow: +- Receives inputs from the evaluation dataset +- Processes them through your custom logic and AI components +- Returns outputs that will be compared against the ground truth + +### 3. Results and Metrics + +After running an evaluation, PySpur provides detailed metrics: + +- **Accuracy**: The percentage of correct answers +- **Per-category breakdowns**: How performance varies across problem types +- **Example-level results**: Which specific examples succeeded or failed +- **Visualizations**: Charts and graphs to help interpret results + +## The Evaluation Workflow in PySpur + +Here's how to run an evaluation in PySpur: + +1. **Choose an Evaluation Benchmark** + - Browse the available evaluation benchmarks + - Review the description, problem type, and sample size + +2. **Select a Workflow to Evaluate** + - Choose which of your workflows to test + - Select the specific output variable to evaluate + +3. **Configure the Evaluation** + - Choose how many samples to evaluate (up to the max available) + - Launch the evaluation job + +4. **Review Results** + - Monitor the evaluation progress in real time + - Once completed, view detailed accuracy metrics + - Analyze per-example results to identify patterns in errors + +## Example Evaluation Results + +Here's what evaluation results typically look like: + +