diff --git a/.air.toml b/.air.toml
new file mode 100644
index 0000000000000000000000000000000000000000..1e8442249dc8102955fd99b516e8abb241f9f87a
--- /dev/null
+++ b/.air.toml
@@ -0,0 +1,8 @@
+# .air.toml
+[build]
+cmd = "make build"
+bin = "./local-ai"
+args_bin = [ "--debug" ]
+include_ext = ["go", "html", "yaml", "toml", "json", "txt", "md"]
+exclude_dir = ["pkg/grpc/proto"]
+delay = 1000
diff --git a/.devcontainer-scripts/postcreate.sh b/.devcontainer-scripts/postcreate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3f9035090a355a63c10b2590cb2b2fdd88ee04ac
--- /dev/null
+++ b/.devcontainer-scripts/postcreate.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+cd /workspace
+
+# Get the files into the volume without a bind mount
+if [ ! -d ".git" ]; then
+ git clone https://github.com/mudler/LocalAI.git .
+else
+ git fetch
+fi
+
+echo "Standard Post-Create script completed."
+
+if [ -f "/devcontainer-customization/postcreate.sh" ]; then
+ echo "Launching customization postcreate.sh"
+ bash "/devcontainer-customization/postcreate.sh"
+fi
\ No newline at end of file
diff --git a/.devcontainer-scripts/poststart.sh b/.devcontainer-scripts/poststart.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7e65b4c7ff20dcd508b328390376434a65fa0437
--- /dev/null
+++ b/.devcontainer-scripts/poststart.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+cd /workspace
+
+# Ensures generated source files are present upon load
+make prepare
+
+echo "Standard Post-Start script completed."
+
+if [ -f "/devcontainer-customization/poststart.sh" ]; then
+ echo "Launching customization poststart.sh"
+ bash "/devcontainer-customization/poststart.sh"
+fi
\ No newline at end of file
diff --git a/.devcontainer-scripts/utils.sh b/.devcontainer-scripts/utils.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8416d43d5789a98fab714214022b2c747ff8ab23
--- /dev/null
+++ b/.devcontainer-scripts/utils.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+
+# This file contains some really simple functions that are useful when building up customization scripts.
+
+
+# Checks if the git config has a user registered - and sets it up if not.
+#
+# Param 1: name
+# Param 2: email
+#
+config_user() {
+ echo "Configuring git for $1 <$2>"
+ local gcn=$(git config --global user.name)
+ if [ -z "${gcn}" ]; then
+ echo "Setting up git user / remote"
+ git config --global user.name "$1"
+ git config --global user.email "$2"
+
+ fi
+}
+
+# Checks if the git remote is configured - and sets it up if not. Fetches either way.
+#
+# Param 1: remote name
+# Param 2: remote url
+#
+config_remote() {
+ echo "Adding git remote and fetching $2 as $1"
+ local gr=$(git remote -v | grep $1)
+ if [ -z "${gr}" ]; then
+ git remote add $1 $2
+ fi
+ git fetch $1
+}
+
+# Setup special .ssh files
+# Prints out lines of text to make things pretty
+# Param 1: bash array, filenames relative to the customization directory that should be copied to ~/.ssh
+setup_ssh() {
+ echo "starting ~/.ssh directory setup..."
+ mkdir -p "${HOME}.ssh"
+ chmod 0700 "${HOME}/.ssh"
+ echo "-----"
+ local files=("$@")
+ for file in "${files[@]}" ; do
+ local cfile="/devcontainer-customization/${file}"
+ local hfile="${HOME}/.ssh/${file}"
+ if [ ! -f "${hfile}" ]; then
+ echo "copying \"${file}\""
+ cp "${cfile}" "${hfile}"
+ chmod 600 "${hfile}"
+ fi
+ done
+ echo "~/.ssh directory setup complete!"
+}
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 0000000000000000000000000000000000000000..37c81ffc41da7915188ccf05e8db74badfed27ad
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,24 @@
+{
+ "$schema": "https://raw.githubusercontent.com/devcontainers/spec/main/schemas/devContainer.schema.json",
+ "name": "LocalAI",
+ "workspaceFolder": "/workspace",
+ "dockerComposeFile": [ "./docker-compose-devcontainer.yml" ],
+ "service": "api",
+ "shutdownAction": "stopCompose",
+ "customizations": {
+ "vscode": {
+ "extensions": [
+ "golang.go",
+ "ms-vscode.makefile-tools",
+ "ms-azuretools.vscode-docker",
+ "ms-python.python",
+ "ms-python.debugpy",
+ "wayou.vscode-todo-highlight",
+ "waderyan.gitblame"
+ ]
+ }
+ },
+ "forwardPorts": [8080, 3000],
+ "postCreateCommand": "bash /.devcontainer-scripts/postcreate.sh",
+ "postStartCommand": "bash /.devcontainer-scripts/poststart.sh"
+}
\ No newline at end of file
diff --git a/.devcontainer/docker-compose-devcontainer.yml b/.devcontainer/docker-compose-devcontainer.yml
new file mode 100644
index 0000000000000000000000000000000000000000..81610ade5f946e12c656b5d8a0cc98a13a2f0236
--- /dev/null
+++ b/.devcontainer/docker-compose-devcontainer.yml
@@ -0,0 +1,44 @@
+services:
+ api:
+ build:
+ context: ..
+ dockerfile: Dockerfile
+ target: devcontainer
+ env_file:
+ - ../.env
+ ports:
+ - 8080:8080
+ volumes:
+ - localai_workspace:/workspace
+ - ../models:/host-models
+ - ./customization:/devcontainer-customization
+ command: /bin/sh -c "while sleep 1000; do :; done"
+ cap_add:
+ - SYS_PTRACE
+ security_opt:
+ - seccomp:unconfined
+ prometheus:
+ image: prom/prometheus
+ container_name: prometheus
+ command:
+ - '--config.file=/etc/prometheus/prometheus.yml'
+ ports:
+ - 9090:9090
+ restart: unless-stopped
+ volumes:
+ - ./prometheus:/etc/prometheus
+ - prom_data:/prometheus
+ grafana:
+ image: grafana/grafana
+ container_name: grafana
+ ports:
+ - 3000:3000
+ restart: unless-stopped
+ environment:
+ - GF_SECURITY_ADMIN_USER=admin
+ - GF_SECURITY_ADMIN_PASSWORD=grafana
+ volumes:
+ - ./grafana:/etc/grafana/provisioning/datasources
+volumes:
+ prom_data:
+ localai_workspace:
\ No newline at end of file
diff --git a/.devcontainer/grafana/datasource.yml b/.devcontainer/grafana/datasource.yml
new file mode 100644
index 0000000000000000000000000000000000000000..1ed2fa3c2a28cc7193a341842bacbe40953a7c1d
--- /dev/null
+++ b/.devcontainer/grafana/datasource.yml
@@ -0,0 +1,10 @@
+
+apiVersion: 1
+
+datasources:
+- name: Prometheus
+ type: prometheus
+ url: http://prometheus:9090
+ isDefault: true
+ access: proxy
+ editable: true
diff --git a/.devcontainer/prometheus/prometheus.yml b/.devcontainer/prometheus/prometheus.yml
new file mode 100644
index 0000000000000000000000000000000000000000..18c44da71447ad87832496ee321c88d84c6e5be0
--- /dev/null
+++ b/.devcontainer/prometheus/prometheus.yml
@@ -0,0 +1,21 @@
+global:
+ scrape_interval: 15s
+ scrape_timeout: 10s
+ evaluation_interval: 15s
+alerting:
+ alertmanagers:
+ - static_configs:
+ - targets: []
+ scheme: http
+ timeout: 10s
+ api_version: v1
+scrape_configs:
+- job_name: prometheus
+ honor_timestamps: true
+ scrape_interval: 15s
+ scrape_timeout: 10s
+ metrics_path: /metrics
+ scheme: http
+ static_configs:
+ - targets:
+ - localhost:9090
\ No newline at end of file
diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000000000000000000000000000000000000..5b62e5f31f07150b74728a2b0525560ded34c359
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,23 @@
+.idea
+.github
+.vscode
+.devcontainer
+models
+backends
+examples/chatbot-ui/models
+backend/go/image/stablediffusion-ggml/build/
+backend/go/*/build
+backend/go/*/.cache
+backend/go/*/sources
+backend/go/*/package
+examples/rwkv/models
+examples/**/models
+Dockerfile*
+__pycache__
+
+# SonarQube
+.scannerwork
+
+# backend virtual environments
+**/venv
+backend/python/**/source
diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 0000000000000000000000000000000000000000..b66f364572606f053889dfa2bf79feeaf041dd20
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,31 @@
+
+root = true
+
+[*]
+indent_style = space
+indent_size = 2
+end_of_line = lf
+charset = utf-8
+trim_trailing_whitespace = true
+insert_final_newline = true
+
+[*.go]
+indent_style = tab
+
+[Makefile]
+indent_style = tab
+
+[*.proto]
+indent_size = 2
+
+[*.py]
+indent_size = 4
+
+[*.js]
+indent_size = 2
+
+[*.yaml]
+indent_size = 2
+
+[*.md]
+trim_trailing_whitespace = false
diff --git a/.env b/.env
new file mode 100644
index 0000000000000000000000000000000000000000..852d3dac63bd07e21af9d9dbc3c97d2118caa95c
--- /dev/null
+++ b/.env
@@ -0,0 +1,93 @@
+## Set number of threads.
+## Note: prefer the number of physical cores. Overbooking the CPU degrades performance notably.
+# LOCALAI_THREADS=14
+
+## Specify a different bind address (defaults to ":8080")
+# LOCALAI_ADDRESS=127.0.0.1:8080
+
+## Default models context size
+# LOCALAI_CONTEXT_SIZE=512
+#
+## Define galleries.
+## models will to install will be visible in `/models/available`
+# LOCALAI_GALLERIES=[{"name":"localai", "url":"github:mudler/LocalAI/gallery/index.yaml@master"}]
+
+## CORS settings
+# LOCALAI_CORS=true
+# LOCALAI_CORS_ALLOW_ORIGINS=*
+
+## Default path for models
+#
+# LOCALAI_MODELS_PATH=/models
+
+## Enable debug mode
+# LOCALAI_LOG_LEVEL=debug
+
+## Disables COMPEL (Diffusers)
+# COMPEL=0
+
+## Enable/Disable single backend (useful if only one GPU is available)
+# LOCALAI_SINGLE_ACTIVE_BACKEND=true
+
+# Forces shutdown of the backends if busy (only if LOCALAI_SINGLE_ACTIVE_BACKEND is set)
+# LOCALAI_FORCE_BACKEND_SHUTDOWN=true
+
+## Path where to store generated images
+# LOCALAI_IMAGE_PATH=/tmp/generated/images
+
+## Specify a default upload limit in MB (whisper)
+# LOCALAI_UPLOAD_LIMIT=15
+
+## List of external GRPC backends (note on the container image this variable is already set to use extra backends available in extra/)
+# LOCALAI_EXTERNAL_GRPC_BACKENDS=my-backend:127.0.0.1:9000,my-backend2:/usr/bin/backend.py
+
+### Advanced settings ###
+### Those are not really used by LocalAI, but from components in the stack ###
+##
+### Preload libraries
+# LD_PRELOAD=
+
+### Huggingface cache for models
+# HUGGINGFACE_HUB_CACHE=/usr/local/huggingface
+
+### Python backends GRPC max workers
+### Default number of workers for GRPC Python backends.
+### This actually controls wether a backend can process multiple requests or not.
+# PYTHON_GRPC_MAX_WORKERS=1
+
+### Define the number of parallel LLAMA.cpp workers (Defaults to 1)
+# LLAMACPP_PARALLEL=1
+
+### Define a list of GRPC Servers for llama-cpp workers to distribute the load
+# https://github.com/ggerganov/llama.cpp/pull/6829
+# https://github.com/ggerganov/llama.cpp/blob/master/tools/rpc/README.md
+# LLAMACPP_GRPC_SERVERS=""
+
+### Enable to run parallel requests
+# LOCALAI_PARALLEL_REQUESTS=true
+
+# Enable to allow p2p mode
+# LOCALAI_P2P=true
+
+# Enable to use federated mode
+# LOCALAI_FEDERATED=true
+
+# Enable to start federation server
+# FEDERATED_SERVER=true
+
+# Define to use federation token
+# TOKEN=""
+
+### Watchdog settings
+###
+# Enables watchdog to kill backends that are inactive for too much time
+# LOCALAI_WATCHDOG_IDLE=true
+#
+# Time in duration format (e.g. 1h30m) after which a backend is considered idle
+# LOCALAI_WATCHDOG_IDLE_TIMEOUT=5m
+#
+# Enables watchdog to kill backends that are busy for too much time
+# LOCALAI_WATCHDOG_BUSY=true
+#
+# Time in duration format (e.g. 1h30m) after which a backend is considered busy
+# LOCALAI_WATCHDOG_BUSY_TIMEOUT=5m
diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..8d07ffab8d9d41df9de2ea12ff710bb94637e198 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,35 +1,29 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.sh text eol=lf
+backend/cpp/llama/*.hpp linguist-vendoredcore/http/static/assets/KFOlCnqEu92Fr1MmEU9vAw.ttf filter=lfs diff=lfs merge=lfs -text
+core/http/static/assets/KFOmCnqEu92Fr1Me5Q.ttf filter=lfs diff=lfs merge=lfs -text
+core/http/static/assets/UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuFuYMZg.ttf filter=lfs diff=lfs merge=lfs -text
+core/http/static/assets/UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuGKYMZg.ttf filter=lfs diff=lfs merge=lfs -text
+core/http/static/assets/UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuLyfMZg.ttf filter=lfs diff=lfs merge=lfs -text
+core/http/static/assets/fontawesome/webfonts/fa-brands-400.ttf filter=lfs diff=lfs merge=lfs -text
+core/http/static/assets/fontawesome/webfonts/fa-brands-400.woff2 filter=lfs diff=lfs merge=lfs -text
+core/http/static/assets/fontawesome/webfonts/fa-solid-900.ttf filter=lfs diff=lfs merge=lfs -text
+core/http/static/assets/fontawesome/webfonts/fa-solid-900.woff2 filter=lfs diff=lfs merge=lfs -text
+core/http/static/assets/jetbrains-mono-medium.ttf filter=lfs diff=lfs merge=lfs -text
+core/http/static/assets/jetbrains-mono-regular.ttf filter=lfs diff=lfs merge=lfs -text
+core/http/static/assets/jetbrains-mono-semibold.ttf filter=lfs diff=lfs merge=lfs -text
+core/http/static/assets/playfair-display-bold.ttf filter=lfs diff=lfs merge=lfs -text
+core/http/static/assets/playfair-display-regular.ttf filter=lfs diff=lfs merge=lfs -text
+core/http/static/assets/playfair-display-semibold.ttf filter=lfs diff=lfs merge=lfs -text
+core/http/static/logo.png filter=lfs diff=lfs merge=lfs -text
+core/http/static/logo_horizontal.png filter=lfs diff=lfs merge=lfs -text
+docs/assets/images/imagen.png filter=lfs diff=lfs merge=lfs -text
+docs/assets/images/localai_screenshot.png filter=lfs diff=lfs merge=lfs -text
+docs/assets/images/logos/logo.png filter=lfs diff=lfs merge=lfs -text
+docs/assets/images/screenshots/screenshot_chat.png filter=lfs diff=lfs merge=lfs -text
+docs/assets/images/screenshots/screenshot_gallery.png filter=lfs diff=lfs merge=lfs -text
+docs/assets/images/screenshots/screenshot_home.png filter=lfs diff=lfs merge=lfs -text
+docs/assets/images/screenshots/screenshot_image.png filter=lfs diff=lfs merge=lfs -text
+docs/assets/images/screenshots/screenshot_login.png filter=lfs diff=lfs merge=lfs -text
+docs/assets/images/screenshots/screenshot_p2p.png filter=lfs diff=lfs merge=lfs -text
+docs/assets/images/screenshots/screenshot_talk.png filter=lfs diff=lfs merge=lfs -text
+docs/assets/images/screenshots/screenshot_tts.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0fc33f328016f5bab2b86ce98ff9d7a2067484ad
--- /dev/null
+++ b/.github/FUNDING.yml
@@ -0,0 +1,5 @@
+# These are supported funding model platforms
+
+github: [mudler]
+custom:
+- https://www.buymeacoffee.com/mudler
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 0000000000000000000000000000000000000000..36e22ced2a6345c527d8808d7710a3f887339b86
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,29 @@
+---
+name: Bug report
+about: Create a report to help us improve
+title: ''
+labels: bug, unconfirmed, up-for-grabs
+---
+
+
+
+**LocalAI version:**
+
+
+**Environment, CPU architecture, OS, and Version:**
+
+
+**Describe the bug**
+
+
+**To Reproduce**
+
+
+**Expected behavior**
+
+
+**Logs**
+
+
+**Additional context**
+
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..acc65c80ddc9dbd7a3cd15738f8b13028b156e32
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,8 @@
+blank_issues_enabled: false
+contact_links:
+ - name: Community Support
+ url: https://github.com/go-skynet/LocalAI/discussions
+ about: Please ask and answer questions here.
+ - name: Discord
+ url: https://discord.gg/uJAeKSAGDy
+ about: Join our community on Discord!
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 0000000000000000000000000000000000000000..d3b2873b2c5e6660af541a22ac88dace41ae7bcc
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,20 @@
+---
+name: Feature request
+about: Suggest an idea for this project
+title: ''
+labels: enhancement, up-for-grabs
+---
+
+
+
+**Is your feature request related to a problem? Please describe.**
+
+
+**Describe the solution you'd like**
+
+
+**Describe alternatives you've considered**
+
+
+**Additional context**
+
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 0000000000000000000000000000000000000000..ec5e354c5740852dab276e85d0b53d04541d0923
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,31 @@
+**Description**
+
+This PR fixes #
+
+**Notes for Reviewers**
+
+
+**[Signed commits](../CONTRIBUTING.md#signing-off-on-commits-developer-certificate-of-origin)**
+- [ ] Yes, I signed my commits.
+
+
\ No newline at end of file
diff --git a/.github/bump_deps.sh b/.github/bump_deps.sh
new file mode 100644
index 0000000000000000000000000000000000000000..28485ca922bd0e1826e3a4e24bdb900d8e7081aa
--- /dev/null
+++ b/.github/bump_deps.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+set -xe
+REPO=$1
+BRANCH=$2
+VAR=$3
+FILE=$4
+
+if [ -z "$FILE" ]; then
+ FILE="Makefile"
+fi
+
+LAST_COMMIT=$(curl -s -H "Accept: application/vnd.github.VERSION.sha" "https://api.github.com/repos/$REPO/commits/$BRANCH")
+
+# Read $VAR from Makefile (only first match)
+set +e
+CURRENT_COMMIT="$(grep -m1 "^$VAR?=" $FILE | cut -d'=' -f2)"
+set -e
+
+sed -i $FILE -e "s/$VAR?=.*/$VAR?=$LAST_COMMIT/"
+
+if [ -z "$CURRENT_COMMIT" ]; then
+ echo "Could not find $VAR in Makefile."
+ exit 0
+fi
+
+echo "Changes: https://github.com/$REPO/compare/${CURRENT_COMMIT}..${LAST_COMMIT}" >> "${VAR}_message.txt"
+echo "${LAST_COMMIT}" >> "${VAR}_commit.txt"
\ No newline at end of file
diff --git a/.github/bump_docs.sh b/.github/bump_docs.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e69d3824d27fcf4592cb095ab418c9115aae79f6
--- /dev/null
+++ b/.github/bump_docs.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+set -xe
+REPO=$1
+
+LATEST_TAG=$(curl -s "https://api.github.com/repos/$REPO/releases/latest" | jq -r '.tag_name')
+
+cat <<< $(jq ".version = \"$LATEST_TAG\"" docs/data/version.json) > docs/data/version.json
diff --git a/.github/check_and_update.py b/.github/check_and_update.py
new file mode 100644
index 0000000000000000000000000000000000000000..704b658e67bf3caa60728ef6866b126e90831d28
--- /dev/null
+++ b/.github/check_and_update.py
@@ -0,0 +1,85 @@
+import hashlib
+from huggingface_hub import hf_hub_download, get_paths_info
+import requests
+import sys
+import os
+
+uri = sys.argv[1]
+file_name = uri.split('/')[-1]
+
+# Function to parse the URI and determine download method
+def parse_uri(uri):
+ if uri.startswith('huggingface://'):
+ repo_id = uri.split('://')[1]
+ return 'huggingface', repo_id.rsplit('/', 1)[0]
+ elif 'huggingface.co' in uri:
+ parts = uri.split('/resolve/')
+ if len(parts) > 1:
+ repo_path = parts[0].split('https://huggingface.co/')[-1]
+ return 'huggingface', repo_path
+ return 'direct', uri
+
+def calculate_sha256(file_path):
+ sha256_hash = hashlib.sha256()
+ with open(file_path, 'rb') as f:
+ for byte_block in iter(lambda: f.read(4096), b''):
+ sha256_hash.update(byte_block)
+ return sha256_hash.hexdigest()
+
+def manual_safety_check_hf(repo_id):
+ scanResponse = requests.get('https://huggingface.co/api/models/' + repo_id + "/scan")
+ scan = scanResponse.json()
+ # Check if 'hasUnsafeFile' exists in the response
+ if 'hasUnsafeFile' in scan:
+ if scan['hasUnsafeFile']:
+ return scan
+ else:
+ return None
+ else:
+ return None
+
+download_type, repo_id_or_url = parse_uri(uri)
+
+new_checksum = None
+file_path = None
+
+# Decide download method based on URI type
+if download_type == 'huggingface':
+ # Check if the repo is flagged as dangerous by HF
+ hazard = manual_safety_check_hf(repo_id_or_url)
+ if hazard != None:
+ print(f'Error: HuggingFace has detected security problems for {repo_id_or_url}: {str(hazard)}', filename=file_name)
+ sys.exit(5)
+ # Use HF API to pull sha
+ for file in get_paths_info(repo_id_or_url, [file_name], repo_type='model'):
+ try:
+ new_checksum = file.lfs.sha256
+ break
+ except Exception as e:
+ print(f'Error from Hugging Face Hub: {str(e)}', file=sys.stderr)
+ sys.exit(2)
+ if new_checksum is None:
+ try:
+ file_path = hf_hub_download(repo_id=repo_id_or_url, filename=file_name)
+ except Exception as e:
+ print(f'Error from Hugging Face Hub: {str(e)}', file=sys.stderr)
+ sys.exit(2)
+else:
+ response = requests.get(repo_id_or_url)
+ if response.status_code == 200:
+ with open(file_name, 'wb') as f:
+ f.write(response.content)
+ file_path = file_name
+ elif response.status_code == 404:
+ print(f'File not found: {response.status_code}', file=sys.stderr)
+ sys.exit(2)
+ else:
+ print(f'Error downloading file: {response.status_code}', file=sys.stderr)
+ sys.exit(1)
+
+if new_checksum is None:
+ new_checksum = calculate_sha256(file_path)
+ print(new_checksum)
+ os.remove(file_path)
+else:
+ print(new_checksum)
diff --git a/.github/checksum_checker.sh b/.github/checksum_checker.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5cbd57f4a33b4093c61b2ebfab1248273f588b1b
--- /dev/null
+++ b/.github/checksum_checker.sh
@@ -0,0 +1,63 @@
+#!/bin/bash
+# This scripts needs yq and huggingface_hub to be installed
+# to install hugingface_hub run pip install huggingface_hub
+
+# Path to the input YAML file
+input_yaml=$1
+
+# Function to download file and check checksum using Python
+function check_and_update_checksum() {
+ model_name="$1"
+ file_name="$2"
+ uri="$3"
+ old_checksum="$4"
+ idx="$5"
+
+ # Download the file and calculate new checksum using Python
+ new_checksum=$(python3 ./.github/check_and_update.py $uri)
+ result=$?
+
+ if [[ $result -eq 5 ]]; then
+ echo "Contaminated entry detected, deleting entry for $model_name..."
+ yq eval -i "del([$idx])" "$input_yaml"
+ return
+ fi
+
+ if [[ "$new_checksum" == "" ]]; then
+ echo "Error calculating checksum for $file_name. Skipping..."
+ return
+ fi
+
+ echo "Checksum for $file_name: $new_checksum"
+
+ # Compare and update the YAML file if checksums do not match
+
+ if [[ $result -eq 2 ]]; then
+ echo "File not found, deleting entry for $file_name..."
+ # yq eval -i "del(.[$idx].files[] | select(.filename == \"$file_name\"))" "$input_yaml"
+ elif [[ "$old_checksum" != "$new_checksum" ]]; then
+ echo "Checksum mismatch for $file_name. Updating..."
+ yq eval -i "del(.[$idx].files[] | select(.filename == \"$file_name\").sha256)" "$input_yaml"
+ yq eval -i "(.[$idx].files[] | select(.filename == \"$file_name\")).sha256 = \"$new_checksum\"" "$input_yaml"
+ elif [[ $result -ne 0 ]]; then
+ echo "Error downloading file $file_name. Skipping..."
+ else
+ echo "Checksum match for $file_name. No update needed."
+ fi
+}
+
+# Read the YAML and process each file
+len=$(yq eval '. | length' "$input_yaml")
+for ((i=0; i<$len; i++))
+do
+ name=$(yq eval ".[$i].name" "$input_yaml")
+ files_len=$(yq eval ".[$i].files | length" "$input_yaml")
+ for ((j=0; j<$files_len; j++))
+ do
+ filename=$(yq eval ".[$i].files[$j].filename" "$input_yaml")
+ uri=$(yq eval ".[$i].files[$j].uri" "$input_yaml")
+ checksum=$(yq eval ".[$i].files[$j].sha256" "$input_yaml")
+ echo "Checking model $name, file $filename. URI = $uri, Checksum = $checksum"
+ check_and_update_checksum "$name" "$filename" "$uri" "$checksum" "$i"
+ done
+done
diff --git a/.github/ci/modelslist.go b/.github/ci/modelslist.go
new file mode 100644
index 0000000000000000000000000000000000000000..719cd094ae9dfeca130f84543c88dee215409798
--- /dev/null
+++ b/.github/ci/modelslist.go
@@ -0,0 +1,304 @@
+package main
+
+import (
+ "fmt"
+ "html/template"
+ "io/ioutil"
+ "os"
+
+ "github.com/microcosm-cc/bluemonday"
+ "gopkg.in/yaml.v3"
+)
+
+var modelPageTemplate string = `
+
+
+
+
+
+ LocalAI models
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ LocalAI model gallery list
+
+
+
+ 🖼️ Available {{.AvailableModels}} models
+
+
+
+
+ Refer to the Model gallery for more information on how to use the models with LocalAI.
+
+ You can install models with the CLI command local-ai models install . or by using the WebUI.
+
+
+
+
+ {{ range $_, $model := .Models }}
+
+
+ {{ $icon := "https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg" }}
+ {{ if $model.Icon }}
+ {{ $icon = $model.Icon }}
+ {{ end }}
+
+
+
+
+
{{$model.Name}}
+
+
+
{{ $model.Description }}
+
+
+
+
+
+
+ More info
+
+
+
+
+
+
+
+
+
+
+ {{ $model.Name}}
+
+
+
+
+
+ Close modal
+
+
+
+
+
+
+
+
+
+ {{ $model.Description }}
+
+
+
+
+ To install the model with the CLI, run:
+ local-ai models install {{$model.Name}}
+
+
+ See also
+ Installation
+ to see how to install models with the REST API.
+
+
+
+
+ {{ range $_, $u := $model.URLs }}
+ {{ $u }}
+ {{ end }}
+
+
+
+
+
+ Close
+
+
+
+
+
+
+
+
+
+ {{ end }}
+
+
+
+
+
+
+
+
+
+
+
+
+`
+
+type GalleryModel struct {
+ Name string `json:"name" yaml:"name"`
+ URLs []string `json:"urls" yaml:"urls"`
+ Icon string `json:"icon" yaml:"icon"`
+ Description string `json:"description" yaml:"description"`
+}
+
+func main() {
+ // read the YAML file which contains the models
+
+ f, err := ioutil.ReadFile(os.Args[1])
+ if err != nil {
+ fmt.Println("Error reading file:", err)
+ return
+ }
+
+ models := []*GalleryModel{}
+ err = yaml.Unmarshal(f, &models)
+ if err != nil {
+ // write to stderr
+ os.Stderr.WriteString("Error unmarshaling YAML: " + err.Error() + "\n")
+ return
+ }
+
+ // Ensure that all arbitrary text content is sanitized before display
+ for i, m := range models {
+ models[i].Name = bluemonday.StrictPolicy().Sanitize(m.Name)
+ models[i].Description = bluemonday.StrictPolicy().Sanitize(m.Description)
+ }
+
+ // render the template
+ data := struct {
+ Models []*GalleryModel
+ AvailableModels int
+ }{
+ Models: models,
+ AvailableModels: len(models),
+ }
+ tmpl := template.Must(template.New("modelPage").Parse(modelPageTemplate))
+
+ err = tmpl.Execute(os.Stdout, data)
+ if err != nil {
+ fmt.Println("Error executing template:", err)
+ return
+ }
+}
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 0000000000000000000000000000000000000000..cf3a252b0bca9a6aa7661edeb25d1f131bf0a793
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,119 @@
+# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
+version: 2
+updates:
+ - package-ecosystem: "gitsubmodule"
+ directory: "/"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "gomod"
+ directory: "/"
+ schedule:
+ interval: "weekly"
+ ignore:
+ - dependency-name: "github.com/mudler/LocalAI/pkg/grpc/proto"
+ - package-ecosystem: "github-actions"
+ # Workflow files stored in the default location of `.github/workflows`. (You don't need to specify `/.github/workflows` for `directory`. You can use `directory: "/"`.)
+ directory: "/"
+ schedule:
+ # Check for updates to GitHub Actions every weekday
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ # Workflow files stored in the default location of `.github/workflows`. (You don't need to specify `/.github/workflows` for `directory`. You can use `directory: "/"`.)
+ directory: "/"
+ schedule:
+ # Check for updates to GitHub Actions every weekday
+ interval: "weekly"
+ - package-ecosystem: "docker"
+ # Workflow files stored in the default location of `.github/workflows`. (You don't need to specify `/.github/workflows` for `directory`. You can use `directory: "/"`.)
+ directory: "/"
+ schedule:
+ # Check for updates to GitHub Actions every weekday
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/backend/python/bark"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/backend/python/common/template"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/backend/python/coqui"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/backend/python/diffusers"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/backend/python/exllama"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/backend/python/exllama2"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/backend/python/mamba"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/backend/python/openvoice"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/backend/python/rerankers"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/backend/python/sentencetransformers"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/backend/python/transformers"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/backend/python/vllm"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/examples/chainlit"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/examples/functions"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/examples/langchain/langchainpy-localai-example"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/examples/langchain-chroma"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/examples/streamlit-bot"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "docker"
+ directory: "/examples/k8sgpt"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "docker"
+ directory: "/examples/kubernetes"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "docker"
+ directory: "/examples/langchain"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "gomod"
+ directory: "/examples/semantic-todo"
+ schedule:
+ interval: "weekly"
+ - package-ecosystem: "docker"
+ directory: "/examples/telegram-bot"
+ schedule:
+ interval: "weekly"
diff --git a/.github/gallery-agent/agent.go b/.github/gallery-agent/agent.go
new file mode 100644
index 0000000000000000000000000000000000000000..7a40f717ba753163746021c71c9b70cfc29492e9
--- /dev/null
+++ b/.github/gallery-agent/agent.go
@@ -0,0 +1,445 @@
+package main
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "regexp"
+ "slices"
+ "strings"
+
+ "github.com/ghodss/yaml"
+ hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
+ cogito "github.com/mudler/cogito"
+
+ "github.com/mudler/cogito/structures"
+ "github.com/sashabaranov/go-openai/jsonschema"
+)
+
+var (
+ openAIModel = os.Getenv("OPENAI_MODEL")
+ openAIKey = os.Getenv("OPENAI_KEY")
+ openAIBaseURL = os.Getenv("OPENAI_BASE_URL")
+ galleryIndexPath = os.Getenv("GALLERY_INDEX_PATH")
+ //defaultclient
+ llm = cogito.NewOpenAILLM(openAIModel, openAIKey, openAIBaseURL)
+)
+
+// cleanTextContent removes trailing spaces, tabs, and normalizes line endings
+// to prevent YAML linting issues like trailing spaces and multiple empty lines
+func cleanTextContent(text string) string {
+ lines := strings.Split(text, "\n")
+ var cleanedLines []string
+ var prevEmpty bool
+ for _, line := range lines {
+ // Remove all trailing whitespace (spaces, tabs, etc.)
+ trimmed := strings.TrimRight(line, " \t\r")
+ // Avoid multiple consecutive empty lines
+ if trimmed == "" {
+ if !prevEmpty {
+ cleanedLines = append(cleanedLines, "")
+ }
+ prevEmpty = true
+ } else {
+ cleanedLines = append(cleanedLines, trimmed)
+ prevEmpty = false
+ }
+ }
+ // Remove trailing empty lines from the result
+ result := strings.Join(cleanedLines, "\n")
+ return stripThinkingTags(strings.TrimRight(result, "\n"))
+}
+
+type galleryModel struct {
+ Name string `yaml:"name"`
+ Urls []string `yaml:"urls"`
+}
+
+// isModelExisting checks if a specific model ID exists in the gallery using text search
+func isModelExisting(modelID string) (bool, error) {
+ indexPath := getGalleryIndexPath()
+ content, err := os.ReadFile(indexPath)
+ if err != nil {
+ return false, fmt.Errorf("failed to read %s: %w", indexPath, err)
+ }
+
+ var galleryModels []galleryModel
+
+ err = yaml.Unmarshal(content, &galleryModels)
+ if err != nil {
+ return false, fmt.Errorf("failed to unmarshal %s: %w", indexPath, err)
+ }
+
+ for _, galleryModel := range galleryModels {
+ if slices.Contains(galleryModel.Urls, modelID) {
+ return true, nil
+ }
+ }
+
+ return false, nil
+}
+
+// filterExistingModels removes models that already exist in the gallery
+func filterExistingModels(models []ProcessedModel) ([]ProcessedModel, error) {
+ var filteredModels []ProcessedModel
+ for _, model := range models {
+ exists, err := isModelExisting(model.ModelID)
+ if err != nil {
+ fmt.Printf("Error checking if model %s exists: %v, skipping\n", model.ModelID, err)
+ continue
+ }
+
+ if !exists {
+ filteredModels = append(filteredModels, model)
+ } else {
+ fmt.Printf("Skipping existing model: %s\n", model.ModelID)
+ }
+ }
+
+ fmt.Printf("Filtered out %d existing models, %d new models remaining\n",
+ len(models)-len(filteredModels), len(filteredModels))
+
+ return filteredModels, nil
+}
+
+// getGalleryIndexPath returns the gallery index file path, with a default fallback
+func getGalleryIndexPath() string {
+ if galleryIndexPath != "" {
+ return galleryIndexPath
+ }
+ return "gallery/index.yaml"
+}
+
+func stripThinkingTags(content string) string {
+ // Remove content between and (including multi-line)
+ content = regexp.MustCompile(`(?s).*? `).ReplaceAllString(content, "")
+ // Remove content between and (including multi-line)
+ content = regexp.MustCompile(`(?s).*? `).ReplaceAllString(content, "")
+ // Clean up any extra whitespace
+ content = strings.TrimSpace(content)
+ return content
+}
+
+func getRealReadme(ctx context.Context, repository string) (string, error) {
+ // Create a conversation fragment
+ fragment := cogito.NewEmptyFragment().
+ AddMessage("user",
+ `Your task is to get a clear description of a large language model from huggingface by using the provided tool. I will share with you a repository that might be quantized, and as such probably not by the original model author. We need to get the real description of the model, and not the one that might be quantized. You will have to call the tool to get the readme more than once by figuring out from the quantized readme which is the base model readme. This is the repository: `+repository)
+
+ // Execute with tools
+ result, err := cogito.ExecuteTools(llm, fragment,
+ cogito.WithIterations(3),
+ cogito.WithMaxAttempts(3),
+ cogito.WithTools(&HFReadmeTool{client: hfapi.NewClient()}))
+ if err != nil {
+ return "", err
+ }
+
+ result = result.AddMessage("user", "Describe the model in a clear and concise way that can be shared in a model gallery.")
+
+ // Get a response
+ newFragment, err := llm.Ask(ctx, result)
+ if err != nil {
+ return "", err
+ }
+
+ content := newFragment.LastMessage().Content
+ return cleanTextContent(content), nil
+}
+
+func selectMostInterestingModels(ctx context.Context, searchResult *SearchResult) ([]ProcessedModel, error) {
+
+ if len(searchResult.Models) == 1 {
+ return searchResult.Models, nil
+ }
+
+ // Create a conversation fragment
+ fragment := cogito.NewEmptyFragment().
+ AddMessage("user",
+ `Your task is to analyze a list of AI models and select the most interesting ones for a model gallery. You will be given detailed information about multiple models including their metadata, file information, and README content.
+
+Consider the following criteria when selecting models:
+1. Model popularity (download count)
+2. Model recency (last modified date)
+3. Model completeness (has preferred model file, README, etc.)
+4. Model uniqueness (not duplicates or very similar models)
+5. Model quality (based on README content and description)
+6. Model utility (practical applications)
+
+You should select models that would be most valuable for users browsing a model gallery. Prioritize models that are:
+- Well-documented with clear READMEs
+- Recently updated
+- Popular (high download count)
+- Have the preferred quantization format available
+- Offer unique capabilities or are from reputable authors
+
+Return your analysis and selection reasoning.`)
+
+ // Add the search results as context
+ modelsInfo := fmt.Sprintf("Found %d models matching '%s' with quantization preference '%s':\n\n",
+ searchResult.TotalModelsFound, searchResult.SearchTerm, searchResult.Quantization)
+
+ for i, model := range searchResult.Models {
+ modelsInfo += fmt.Sprintf("Model %d:\n", i+1)
+ modelsInfo += fmt.Sprintf(" ID: %s\n", model.ModelID)
+ modelsInfo += fmt.Sprintf(" Author: %s\n", model.Author)
+ modelsInfo += fmt.Sprintf(" Downloads: %d\n", model.Downloads)
+ modelsInfo += fmt.Sprintf(" Last Modified: %s\n", model.LastModified)
+ modelsInfo += fmt.Sprintf(" Files: %d files\n", len(model.Files))
+
+ if model.PreferredModelFile != nil {
+ modelsInfo += fmt.Sprintf(" Preferred Model File: %s (%d bytes)\n",
+ model.PreferredModelFile.Path, model.PreferredModelFile.Size)
+ } else {
+ modelsInfo += " No preferred model file found\n"
+ }
+
+ if model.ReadmeContent != "" {
+ modelsInfo += fmt.Sprintf(" README: %s\n", model.ReadmeContent)
+ }
+
+ if model.ProcessingError != "" {
+ modelsInfo += fmt.Sprintf(" Processing Error: %s\n", model.ProcessingError)
+ }
+
+ modelsInfo += "\n"
+ }
+
+ fragment = fragment.AddMessage("user", modelsInfo)
+
+ fragment = fragment.AddMessage("user", "Based on your analysis, select the top 5 most interesting models and provide a brief explanation for each selection. Also, create a filtered SearchResult with only the selected models. Return just a list of repositories IDs, you will later be asked to output it as a JSON array with the json tool.")
+
+ // Get a response
+ newFragment, err := llm.Ask(ctx, fragment)
+ if err != nil {
+ return nil, err
+ }
+
+ fmt.Println(newFragment.LastMessage().Content)
+ repositories := struct {
+ Repositories []string `json:"repositories"`
+ }{}
+
+ s := structures.Structure{
+ Schema: jsonschema.Definition{
+ Type: jsonschema.Object,
+ AdditionalProperties: false,
+ Properties: map[string]jsonschema.Definition{
+ "repositories": {
+ Type: jsonschema.Array,
+ Items: &jsonschema.Definition{Type: jsonschema.String},
+ Description: "The trending repositories IDs",
+ },
+ },
+ Required: []string{"repositories"},
+ },
+ Object: &repositories,
+ }
+
+ err = newFragment.ExtractStructure(ctx, llm, s)
+ if err != nil {
+ return nil, err
+ }
+
+ filteredModels := []ProcessedModel{}
+ for _, m := range searchResult.Models {
+ if slices.Contains(repositories.Repositories, m.ModelID) {
+ filteredModels = append(filteredModels, m)
+ }
+ }
+
+ return filteredModels, nil
+}
+
+// ModelMetadata represents extracted metadata from a model
+type ModelMetadata struct {
+ Tags []string `json:"tags"`
+ License string `json:"license"`
+}
+
+// extractModelMetadata extracts tags and license from model README and documentation
+func extractModelMetadata(ctx context.Context, model ProcessedModel) ([]string, string, error) {
+ // Create a conversation fragment
+ fragment := cogito.NewEmptyFragment().
+ AddMessage("user",
+ `Your task is to extract metadata from an AI model's README and documentation. You will be provided with:
+1. Model information (ID, author, description)
+2. README content
+
+You need to extract:
+1. **Tags**: An array of relevant tags that describe the model. Use common tags from the gallery such as:
+ - llm, gguf, gpu, cpu, multimodal, image-to-text, text-to-text, text-to-speech, tts
+ - thinking, reasoning, chat, instruction-tuned, code, vision
+ - Model family names (e.g., llama, qwen, mistral, gemma) if applicable
+ - Any other relevant descriptive tags
+ Select 3-8 most relevant tags.
+
+2. **License**: The license identifier (e.g., "apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", "cc-by-4.0").
+ If no license is found, return an empty string.
+
+Return the extracted metadata in a structured format.`)
+
+ // Add model information
+ modelInfo := "Model Information:\n"
+ modelInfo += fmt.Sprintf(" ID: %s\n", model.ModelID)
+ modelInfo += fmt.Sprintf(" Author: %s\n", model.Author)
+ modelInfo += fmt.Sprintf(" Downloads: %d\n", model.Downloads)
+ if model.ReadmeContent != "" {
+ modelInfo += fmt.Sprintf(" README Content:\n%s\n", model.ReadmeContent)
+ } else if model.ReadmeContentPreview != "" {
+ modelInfo += fmt.Sprintf(" README Preview: %s\n", model.ReadmeContentPreview)
+ }
+
+ fragment = fragment.AddMessage("user", modelInfo)
+ fragment = fragment.AddMessage("user", "Extract the tags and license from the model information. Return the metadata as a JSON object with 'tags' (array of strings) and 'license' (string).")
+
+ // Get a response
+ newFragment, err := llm.Ask(ctx, fragment)
+ if err != nil {
+ return nil, "", err
+ }
+
+ // Extract structured metadata
+ metadata := ModelMetadata{}
+
+ s := structures.Structure{
+ Schema: jsonschema.Definition{
+ Type: jsonschema.Object,
+ AdditionalProperties: false,
+ Properties: map[string]jsonschema.Definition{
+ "tags": {
+ Type: jsonschema.Array,
+ Items: &jsonschema.Definition{Type: jsonschema.String},
+ Description: "Array of relevant tags describing the model",
+ },
+ "license": {
+ Type: jsonschema.String,
+ Description: "License identifier (e.g., apache-2.0, mit, llama2). Empty string if not found.",
+ },
+ },
+ Required: []string{"tags", "license"},
+ },
+ Object: &metadata,
+ }
+
+ err = newFragment.ExtractStructure(ctx, llm, s)
+ if err != nil {
+ return nil, "", err
+ }
+
+ return metadata.Tags, metadata.License, nil
+}
+
+// extractIconFromReadme scans the README content for image URLs and returns the first suitable icon URL found
+func extractIconFromReadme(readmeContent string) string {
+ if readmeContent == "" {
+ return ""
+ }
+
+ // Regular expressions to match image URLs in various formats (case-insensitive)
+ // Match markdown image syntax:  - case insensitive extensions
+ markdownImageRegex := regexp.MustCompile(`(?i)!\[[^\]]*\]\(([^)]+\.(png|jpg|jpeg|svg|webp|gif))\)`)
+ // Match HTML img tags:
+ htmlImageRegex := regexp.MustCompile(`(?i) ]+src=["']([^"']+\.(png|jpg|jpeg|svg|webp|gif))["']`)
+ // Match plain URLs ending with image extensions
+ plainImageRegex := regexp.MustCompile(`(?i)https?://[^\s<>"']+\.(png|jpg|jpeg|svg|webp|gif)`)
+
+ // Try markdown format first
+ matches := markdownImageRegex.FindStringSubmatch(readmeContent)
+ if len(matches) > 1 && matches[1] != "" {
+ url := strings.TrimSpace(matches[1])
+ // Prefer HuggingFace CDN URLs or absolute URLs
+ if strings.HasPrefix(strings.ToLower(url), "http") {
+ return url
+ }
+ }
+
+ // Try HTML img tags
+ matches = htmlImageRegex.FindStringSubmatch(readmeContent)
+ if len(matches) > 1 && matches[1] != "" {
+ url := strings.TrimSpace(matches[1])
+ if strings.HasPrefix(strings.ToLower(url), "http") {
+ return url
+ }
+ }
+
+ // Try plain URLs
+ matches = plainImageRegex.FindStringSubmatch(readmeContent)
+ if len(matches) > 0 {
+ url := strings.TrimSpace(matches[0])
+ if strings.HasPrefix(strings.ToLower(url), "http") {
+ return url
+ }
+ }
+
+ return ""
+}
+
+// getHuggingFaceAvatarURL attempts to get the HuggingFace avatar URL for a user
+func getHuggingFaceAvatarURL(author string) string {
+ if author == "" {
+ return ""
+ }
+
+ // Try to fetch user info from HuggingFace API
+ // HuggingFace API endpoint: https://huggingface.co/api/users/{username}
+ baseURL := "https://huggingface.co"
+ userURL := fmt.Sprintf("%s/api/users/%s", baseURL, author)
+
+ req, err := http.NewRequest("GET", userURL, nil)
+ if err != nil {
+ return ""
+ }
+
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ return ""
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return ""
+ }
+
+ // Parse the response to get avatar URL
+ var userInfo map[string]interface{}
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return ""
+ }
+
+ if err := json.Unmarshal(body, &userInfo); err != nil {
+ return ""
+ }
+
+ // Try to extract avatar URL from response
+ if avatar, ok := userInfo["avatarUrl"].(string); ok && avatar != "" {
+ return avatar
+ }
+ if avatar, ok := userInfo["avatar"].(string); ok && avatar != "" {
+ return avatar
+ }
+
+ return ""
+}
+
+// extractModelIcon extracts icon URL from README or falls back to HuggingFace avatar
+func extractModelIcon(model ProcessedModel) string {
+ // First, try to extract icon from README
+ if icon := extractIconFromReadme(model.ReadmeContent); icon != "" {
+ return icon
+ }
+
+ // Fallback: Try to get HuggingFace user avatar
+ if model.Author != "" {
+ if avatar := getHuggingFaceAvatarURL(model.Author); avatar != "" {
+ return avatar
+ }
+ }
+
+ return ""
+}
diff --git a/.github/gallery-agent/gallery.go b/.github/gallery-agent/gallery.go
new file mode 100644
index 0000000000000000000000000000000000000000..749001a79f42048f1532ceb68c053e8be0751826
--- /dev/null
+++ b/.github/gallery-agent/gallery.go
@@ -0,0 +1,200 @@
+package main
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "strings"
+
+ "github.com/ghodss/yaml"
+ "github.com/mudler/LocalAI/core/gallery/importers"
+)
+
+func formatTextContent(text string) string {
+ return formatTextContentWithIndent(text, 4, 6)
+}
+
+// formatTextContentWithIndent formats text content with specified base and list item indentation
+func formatTextContentWithIndent(text string, baseIndent int, listItemIndent int) string {
+ var formattedLines []string
+ lines := strings.Split(text, "\n")
+ for _, line := range lines {
+ trimmed := strings.TrimRight(line, " \t\r")
+ if trimmed == "" {
+ // Keep empty lines as empty (no indentation)
+ formattedLines = append(formattedLines, "")
+ } else {
+ // Preserve relative indentation from yaml.Marshal output
+ // Count existing leading spaces to preserve relative structure
+ leadingSpaces := len(trimmed) - len(strings.TrimLeft(trimmed, " \t"))
+ trimmedStripped := strings.TrimLeft(trimmed, " \t")
+
+ var totalIndent int
+ if strings.HasPrefix(trimmedStripped, "-") {
+ // List items: use listItemIndent (ignore existing leading spaces)
+ totalIndent = listItemIndent
+ } else {
+ // Regular lines: use baseIndent + preserve relative indentation
+ // This handles both top-level keys (leadingSpaces=0) and nested properties (leadingSpaces>0)
+ totalIndent = baseIndent + leadingSpaces
+ }
+
+ indentStr := strings.Repeat(" ", totalIndent)
+ formattedLines = append(formattedLines, indentStr+trimmedStripped)
+ }
+ }
+ formattedText := strings.Join(formattedLines, "\n")
+ // Remove any trailing spaces from the formatted description
+ formattedText = strings.TrimRight(formattedText, " \t")
+ return formattedText
+}
+
+// generateYAMLEntry generates a YAML entry for a model using the specified anchor
+func generateYAMLEntry(model ProcessedModel, quantization string) string {
+ modelConfig, err := importers.DiscoverModelConfig("https://huggingface.co/"+model.ModelID, json.RawMessage(`{ "quantization": "`+quantization+`"}`))
+ if err != nil {
+ panic(err)
+ }
+
+ // Extract model name from ModelID
+ parts := strings.Split(model.ModelID, "/")
+ modelName := model.ModelID
+ if len(parts) > 0 {
+ modelName = strings.ToLower(parts[len(parts)-1])
+ }
+ // Remove common suffixes
+ modelName = strings.ReplaceAll(modelName, "-gguf", "")
+ modelName = strings.ReplaceAll(modelName, "-q4_k_m", "")
+ modelName = strings.ReplaceAll(modelName, "-q4_k_s", "")
+ modelName = strings.ReplaceAll(modelName, "-q3_k_m", "")
+ modelName = strings.ReplaceAll(modelName, "-q2_k", "")
+
+ description := model.ReadmeContent
+ if description == "" {
+ description = fmt.Sprintf("AI model: %s", modelName)
+ }
+
+ // Clean up description to prevent YAML linting issues
+ description = cleanTextContent(description)
+ formattedDescription := formatTextContent(description)
+
+ configFile := formatTextContent(modelConfig.ConfigFile)
+
+ filesYAML, _ := yaml.Marshal(modelConfig.Files)
+
+ // Files section: list items need 4 spaces (not 6), since files: is at 2 spaces
+ files := formatTextContentWithIndent(string(filesYAML), 4, 4)
+
+ // Build metadata sections
+ var metadataSections []string
+
+ // Add license if present
+ if model.License != "" {
+ metadataSections = append(metadataSections, fmt.Sprintf(` license: "%s"`, model.License))
+ }
+
+ // Add tags if present
+ if len(model.Tags) > 0 {
+ tagsYAML, _ := yaml.Marshal(model.Tags)
+ tagsFormatted := formatTextContentWithIndent(string(tagsYAML), 4, 4)
+ tagsFormatted = strings.TrimRight(tagsFormatted, "\n")
+ metadataSections = append(metadataSections, fmt.Sprintf(" tags:\n%s", tagsFormatted))
+ }
+
+ // Add icon if present
+ if model.Icon != "" {
+ metadataSections = append(metadataSections, fmt.Sprintf(` icon: %s`, model.Icon))
+ }
+
+ // Build the metadata block
+ metadataBlock := ""
+ if len(metadataSections) > 0 {
+ metadataBlock = strings.Join(metadataSections, "\n") + "\n"
+ }
+
+ yamlTemplate := ""
+ yamlTemplate = `- name: "%s"
+ url: "github:mudler/LocalAI/gallery/virtual.yaml@master"
+ urls:
+ - https://huggingface.co/%s
+ description: |
+%s%s
+ overrides:
+%s
+ files:
+%s`
+ // Trim trailing newlines from formatted sections to prevent extra blank lines
+ formattedDescription = strings.TrimRight(formattedDescription, "\n")
+ configFile = strings.TrimRight(configFile, "\n")
+ files = strings.TrimRight(files, "\n")
+ // Add newline before metadata block if present
+ if metadataBlock != "" {
+ metadataBlock = "\n" + strings.TrimRight(metadataBlock, "\n")
+ }
+ return fmt.Sprintf(yamlTemplate,
+ modelName,
+ model.ModelID,
+ formattedDescription,
+ metadataBlock,
+ configFile,
+ files,
+ )
+}
+
+// generateYAMLForModels generates YAML entries for selected models and appends to index.yaml
+func generateYAMLForModels(ctx context.Context, models []ProcessedModel, quantization string) error {
+
+ // Generate YAML entries for each model
+ var yamlEntries []string
+ for _, model := range models {
+ fmt.Printf("Generating YAML entry for model: %s\n", model.ModelID)
+
+ // Generate YAML entry
+ yamlEntry := generateYAMLEntry(model, quantization)
+ yamlEntries = append(yamlEntries, yamlEntry)
+ }
+
+ // Prepend to index.yaml (write at the top)
+ if len(yamlEntries) > 0 {
+ indexPath := getGalleryIndexPath()
+ fmt.Printf("Prepending YAML entries to %s...\n", indexPath)
+
+ // Read current content
+ content, err := os.ReadFile(indexPath)
+ if err != nil {
+ return fmt.Errorf("failed to read %s: %w", indexPath, err)
+ }
+
+ existingContent := string(content)
+ yamlBlock := strings.Join(yamlEntries, "\n")
+
+ // Check if file starts with "---"
+ var newContent string
+ if strings.HasPrefix(existingContent, "---\n") {
+ // File starts with "---", prepend new entries after it
+ restOfContent := strings.TrimPrefix(existingContent, "---\n")
+ // Ensure proper spacing: "---\n" + new entries + "\n" + rest of content
+ newContent = "---\n" + yamlBlock + "\n" + restOfContent
+ } else if strings.HasPrefix(existingContent, "---") {
+ // File starts with "---" but no newline after
+ restOfContent := strings.TrimPrefix(existingContent, "---")
+ newContent = "---\n" + yamlBlock + "\n" + strings.TrimPrefix(restOfContent, "\n")
+ } else {
+ // No "---" at start, prepend new entries at the very beginning
+ // Trim leading whitespace from existing content
+ existingContent = strings.TrimLeft(existingContent, " \t\n\r")
+ newContent = yamlBlock + "\n" + existingContent
+ }
+
+ // Write back to file
+ err = os.WriteFile(indexPath, []byte(newContent), 0644)
+ if err != nil {
+ return fmt.Errorf("failed to write %s: %w", indexPath, err)
+ }
+
+ fmt.Printf("Successfully prepended %d models to %s\n", len(yamlEntries), indexPath)
+ }
+
+ return nil
+}
diff --git a/.github/gallery-agent/main.go b/.github/gallery-agent/main.go
new file mode 100644
index 0000000000000000000000000000000000000000..1aa58a0eef1429d879512a2e9307900c6f3fd0a8
--- /dev/null
+++ b/.github/gallery-agent/main.go
@@ -0,0 +1,383 @@
+package main
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "strconv"
+ "strings"
+ "time"
+
+ hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
+)
+
+// ProcessedModelFile represents a processed model file with additional metadata
+type ProcessedModelFile struct {
+ Path string `json:"path"`
+ Size int64 `json:"size"`
+ SHA256 string `json:"sha256"`
+ IsReadme bool `json:"is_readme"`
+ FileType string `json:"file_type"` // "model", "readme", "other"
+}
+
+// ProcessedModel represents a processed model with all gathered metadata
+type ProcessedModel struct {
+ ModelID string `json:"model_id"`
+ Author string `json:"author"`
+ Downloads int `json:"downloads"`
+ LastModified string `json:"last_modified"`
+ Files []ProcessedModelFile `json:"files"`
+ PreferredModelFile *ProcessedModelFile `json:"preferred_model_file,omitempty"`
+ ReadmeFile *ProcessedModelFile `json:"readme_file,omitempty"`
+ ReadmeContent string `json:"readme_content,omitempty"`
+ ReadmeContentPreview string `json:"readme_content_preview,omitempty"`
+ QuantizationPreferences []string `json:"quantization_preferences"`
+ ProcessingError string `json:"processing_error,omitempty"`
+ Tags []string `json:"tags,omitempty"`
+ License string `json:"license,omitempty"`
+ Icon string `json:"icon,omitempty"`
+}
+
+// SearchResult represents the complete result of searching and processing models
+type SearchResult struct {
+ SearchTerm string `json:"search_term"`
+ Limit int `json:"limit"`
+ Quantization string `json:"quantization"`
+ TotalModelsFound int `json:"total_models_found"`
+ Models []ProcessedModel `json:"models"`
+ FormattedOutput string `json:"formatted_output"`
+}
+
+// AddedModelSummary represents a summary of models added to the gallery
+type AddedModelSummary struct {
+ SearchTerm string `json:"search_term"`
+ TotalFound int `json:"total_found"`
+ ModelsAdded int `json:"models_added"`
+ AddedModelIDs []string `json:"added_model_ids"`
+ AddedModelURLs []string `json:"added_model_urls"`
+ Quantization string `json:"quantization"`
+ ProcessingTime string `json:"processing_time"`
+}
+
+func main() {
+ startTime := time.Now()
+
+ // Check for synthetic mode
+ syntheticMode := os.Getenv("SYNTHETIC_MODE")
+ if syntheticMode == "true" || syntheticMode == "1" {
+ fmt.Println("Running in SYNTHETIC MODE - generating random test data")
+ err := runSyntheticMode()
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error in synthetic mode: %v\n", err)
+ os.Exit(1)
+ }
+ return
+ }
+
+ // Get configuration from environment variables
+ searchTerm := os.Getenv("SEARCH_TERM")
+ if searchTerm == "" {
+ searchTerm = "GGUF"
+ }
+
+ limitStr := os.Getenv("LIMIT")
+ if limitStr == "" {
+ limitStr = "5"
+ }
+ limit, err := strconv.Atoi(limitStr)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error parsing LIMIT: %v\n", err)
+ os.Exit(1)
+ }
+
+ quantization := os.Getenv("QUANTIZATION")
+
+ maxModels := os.Getenv("MAX_MODELS")
+ if maxModels == "" {
+ maxModels = "1"
+ }
+ maxModelsInt, err := strconv.Atoi(maxModels)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error parsing MAX_MODELS: %v\n", err)
+ os.Exit(1)
+ }
+
+ // Print configuration
+ fmt.Printf("Gallery Agent Configuration:\n")
+ fmt.Printf(" Search Term: %s\n", searchTerm)
+ fmt.Printf(" Limit: %d\n", limit)
+ fmt.Printf(" Quantization: %s\n", quantization)
+ fmt.Printf(" Max Models to Add: %d\n", maxModelsInt)
+ fmt.Printf(" Gallery Index Path: %s\n", os.Getenv("GALLERY_INDEX_PATH"))
+ fmt.Println()
+
+ result, err := searchAndProcessModels(searchTerm, limit, quantization)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error: %v\n", err)
+ os.Exit(1)
+ }
+
+ fmt.Println(result.FormattedOutput)
+ var models []ProcessedModel
+
+ if len(result.Models) > 1 {
+ fmt.Println("More than one model found (", len(result.Models), "), using AI agent to select the most interesting models")
+ for _, model := range result.Models {
+ fmt.Println("Model: ", model.ModelID)
+ }
+ // Use AI agent to select the most interesting models
+ fmt.Println("Using AI agent to select the most interesting models...")
+ models, err = selectMostInterestingModels(context.Background(), result)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error in model selection: %v\n", err)
+ // Continue with original result if selection fails
+ models = result.Models
+ }
+ } else if len(result.Models) == 1 {
+ models = result.Models
+ fmt.Println("Only one model found, using it directly")
+ }
+
+ fmt.Print(models)
+
+ // Filter out models that already exist in the gallery
+ fmt.Println("Filtering out existing models...")
+ models, err = filterExistingModels(models)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error filtering existing models: %v\n", err)
+ os.Exit(1)
+ }
+
+ // Limit to maxModelsInt after filtering
+ if len(models) > maxModelsInt {
+ models = models[:maxModelsInt]
+ }
+
+ // Track added models for summary
+ var addedModelIDs []string
+ var addedModelURLs []string
+
+ // Generate YAML entries and append to gallery/index.yaml
+ if len(models) > 0 {
+ for _, model := range models {
+ addedModelIDs = append(addedModelIDs, model.ModelID)
+ // Generate Hugging Face URL for the model
+ modelURL := fmt.Sprintf("https://huggingface.co/%s", model.ModelID)
+ addedModelURLs = append(addedModelURLs, modelURL)
+ }
+ fmt.Println("Generating YAML entries for selected models...")
+ err = generateYAMLForModels(context.Background(), models, quantization)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error generating YAML entries: %v\n", err)
+ os.Exit(1)
+ }
+ } else {
+ fmt.Println("No new models to add to the gallery.")
+ }
+
+ // Create and write summary
+ processingTime := time.Since(startTime).String()
+ summary := AddedModelSummary{
+ SearchTerm: searchTerm,
+ TotalFound: result.TotalModelsFound,
+ ModelsAdded: len(addedModelIDs),
+ AddedModelIDs: addedModelIDs,
+ AddedModelURLs: addedModelURLs,
+ Quantization: quantization,
+ ProcessingTime: processingTime,
+ }
+
+ // Write summary to file
+ summaryData, err := json.MarshalIndent(summary, "", " ")
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error marshaling summary: %v\n", err)
+ } else {
+ err = os.WriteFile("gallery-agent-summary.json", summaryData, 0644)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error writing summary file: %v\n", err)
+ } else {
+ fmt.Printf("Summary written to gallery-agent-summary.json\n")
+ }
+ }
+}
+
+func searchAndProcessModels(searchTerm string, limit int, quantization string) (*SearchResult, error) {
+ client := hfapi.NewClient()
+ var outputBuilder strings.Builder
+
+ fmt.Println("Searching for models...")
+ // Initialize the result struct
+ result := &SearchResult{
+ SearchTerm: searchTerm,
+ Limit: limit,
+ Quantization: quantization,
+ Models: []ProcessedModel{},
+ }
+
+ models, err := client.GetLatest(searchTerm, limit)
+ if err != nil {
+ return nil, fmt.Errorf("failed to fetch models: %w", err)
+ }
+
+ fmt.Println("Models found:", len(models))
+ result.TotalModelsFound = len(models)
+
+ if len(models) == 0 {
+ outputBuilder.WriteString("No models found.\n")
+ result.FormattedOutput = outputBuilder.String()
+ return result, nil
+ }
+
+ outputBuilder.WriteString(fmt.Sprintf("Found %d models matching '%s':\n\n", len(models), searchTerm))
+
+ // Process each model
+ for i, model := range models {
+ outputBuilder.WriteString(fmt.Sprintf("%d. Processing Model: %s\n", i+1, model.ModelID))
+ outputBuilder.WriteString(fmt.Sprintf(" Author: %s\n", model.Author))
+ outputBuilder.WriteString(fmt.Sprintf(" Downloads: %d\n", model.Downloads))
+ outputBuilder.WriteString(fmt.Sprintf(" Last Modified: %s\n", model.LastModified))
+
+ // Initialize processed model struct
+ processedModel := ProcessedModel{
+ ModelID: model.ModelID,
+ Author: model.Author,
+ Downloads: model.Downloads,
+ LastModified: model.LastModified,
+ QuantizationPreferences: []string{quantization, "Q4_K_M", "Q4_K_S", "Q3_K_M", "Q2_K"},
+ }
+
+ // Get detailed model information
+ details, err := client.GetModelDetails(model.ModelID)
+ if err != nil {
+ errorMsg := fmt.Sprintf(" Error getting model details: %v\n", err)
+ outputBuilder.WriteString(errorMsg)
+ processedModel.ProcessingError = err.Error()
+ result.Models = append(result.Models, processedModel)
+ continue
+ }
+
+ // Define quantization preferences (in order of preference)
+ quantizationPreferences := []string{quantization, "Q4_K_M", "Q4_K_S", "Q3_K_M", "Q2_K"}
+
+ // Find preferred model file
+ preferredModelFile := hfapi.FindPreferredModelFile(details.Files, quantizationPreferences)
+
+ // Process files
+ processedFiles := make([]ProcessedModelFile, len(details.Files))
+ for j, file := range details.Files {
+ fileType := "other"
+ if file.IsReadme {
+ fileType = "readme"
+ } else if preferredModelFile != nil && file.Path == preferredModelFile.Path {
+ fileType = "model"
+ }
+
+ processedFiles[j] = ProcessedModelFile{
+ Path: file.Path,
+ Size: file.Size,
+ SHA256: file.SHA256,
+ IsReadme: file.IsReadme,
+ FileType: fileType,
+ }
+ }
+
+ processedModel.Files = processedFiles
+
+ // Set preferred model file
+ if preferredModelFile != nil {
+ for _, file := range processedFiles {
+ if file.Path == preferredModelFile.Path {
+ processedModel.PreferredModelFile = &file
+ break
+ }
+ }
+ }
+
+ // Print file information
+ outputBuilder.WriteString(fmt.Sprintf(" Files found: %d\n", len(details.Files)))
+
+ if preferredModelFile != nil {
+ outputBuilder.WriteString(fmt.Sprintf(" Preferred Model File: %s (SHA256: %s)\n",
+ preferredModelFile.Path,
+ preferredModelFile.SHA256))
+ } else {
+ outputBuilder.WriteString(fmt.Sprintf(" No model file found with quantization preferences: %v\n", quantizationPreferences))
+ }
+
+ if details.ReadmeFile != nil {
+ outputBuilder.WriteString(fmt.Sprintf(" README File: %s\n", details.ReadmeFile.Path))
+
+ // Find and set readme file
+ for _, file := range processedFiles {
+ if file.IsReadme {
+ processedModel.ReadmeFile = &file
+ break
+ }
+ }
+
+ fmt.Println("Getting real readme for", model.ModelID, "waiting...")
+ // Use agent to get the real readme and prepare the model description
+ readmeContent, err := getRealReadme(context.Background(), model.ModelID)
+ if err == nil {
+ processedModel.ReadmeContent = readmeContent
+ processedModel.ReadmeContentPreview = truncateString(readmeContent, 200)
+ outputBuilder.WriteString(fmt.Sprintf(" README Content Preview: %s\n",
+ processedModel.ReadmeContentPreview))
+ } else {
+ fmt.Printf(" Warning: Failed to get real readme: %v\n", err)
+ }
+ fmt.Println("Real readme got", readmeContent)
+
+ // Extract metadata (tags, license) from README using LLM
+ fmt.Println("Extracting metadata for", model.ModelID, "waiting...")
+ tags, license, err := extractModelMetadata(context.Background(), processedModel)
+ if err == nil {
+ processedModel.Tags = tags
+ processedModel.License = license
+ outputBuilder.WriteString(fmt.Sprintf(" Tags: %v\n", tags))
+ outputBuilder.WriteString(fmt.Sprintf(" License: %s\n", license))
+ } else {
+ fmt.Printf(" Warning: Failed to extract metadata: %v\n", err)
+ }
+
+ // Extract icon from README or use HuggingFace avatar
+ icon := extractModelIcon(processedModel)
+ if icon != "" {
+ processedModel.Icon = icon
+ outputBuilder.WriteString(fmt.Sprintf(" Icon: %s\n", icon))
+ }
+ // Get README content
+ // readmeContent, err := client.GetReadmeContent(model.ModelID, details.ReadmeFile.Path)
+ // if err == nil {
+ // processedModel.ReadmeContent = readmeContent
+ // processedModel.ReadmeContentPreview = truncateString(readmeContent, 200)
+ // outputBuilder.WriteString(fmt.Sprintf(" README Content Preview: %s\n",
+ // processedModel.ReadmeContentPreview))
+ // }
+ }
+
+ // Print all files with their checksums
+ outputBuilder.WriteString(" All Files:\n")
+ for _, file := range processedFiles {
+ outputBuilder.WriteString(fmt.Sprintf(" - %s (%s, %d bytes", file.Path, file.FileType, file.Size))
+ if file.SHA256 != "" {
+ outputBuilder.WriteString(fmt.Sprintf(", SHA256: %s", file.SHA256))
+ }
+ outputBuilder.WriteString(")\n")
+ }
+
+ outputBuilder.WriteString("\n")
+ result.Models = append(result.Models, processedModel)
+ }
+
+ result.FormattedOutput = outputBuilder.String()
+ return result, nil
+}
+
+func truncateString(s string, maxLen int) string {
+ if len(s) <= maxLen {
+ return s
+ }
+ return s[:maxLen] + "..."
+}
diff --git a/.github/gallery-agent/testing.go b/.github/gallery-agent/testing.go
new file mode 100644
index 0000000000000000000000000000000000000000..c7960a9f2ba45e7491ed4ece738533eb3b85a0f0
--- /dev/null
+++ b/.github/gallery-agent/testing.go
@@ -0,0 +1,224 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "math/rand"
+ "strings"
+ "time"
+)
+
+// runSyntheticMode generates synthetic test data and appends it to the gallery
+func runSyntheticMode() error {
+ generator := NewSyntheticDataGenerator()
+
+ // Generate a random number of synthetic models (1-3)
+ numModels := generator.rand.Intn(3) + 1
+ fmt.Printf("Generating %d synthetic models for testing...\n", numModels)
+
+ var models []ProcessedModel
+ for i := 0; i < numModels; i++ {
+ model := generator.GenerateProcessedModel()
+ models = append(models, model)
+ fmt.Printf("Generated synthetic model: %s\n", model.ModelID)
+ }
+
+ // Generate YAML entries and append to gallery/index.yaml
+ fmt.Println("Generating YAML entries for synthetic models...")
+ err := generateYAMLForModels(context.Background(), models, "Q4_K_M")
+ if err != nil {
+ return fmt.Errorf("error generating YAML entries: %w", err)
+ }
+
+ fmt.Printf("Successfully added %d synthetic models to the gallery for testing!\n", len(models))
+ return nil
+}
+
+// SyntheticDataGenerator provides methods to generate synthetic test data
+type SyntheticDataGenerator struct {
+ rand *rand.Rand
+}
+
+// NewSyntheticDataGenerator creates a new synthetic data generator
+func NewSyntheticDataGenerator() *SyntheticDataGenerator {
+ return &SyntheticDataGenerator{
+ rand: rand.New(rand.NewSource(time.Now().UnixNano())),
+ }
+}
+
+// GenerateProcessedModelFile creates a synthetic ProcessedModelFile
+func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile {
+ fileTypes := []string{"model", "readme", "other"}
+ fileType := fileTypes[g.rand.Intn(len(fileTypes))]
+
+ var path string
+ var isReadme bool
+
+ switch fileType {
+ case "model":
+ path = fmt.Sprintf("model-%s.gguf", g.randomString(8))
+ isReadme = false
+ case "readme":
+ path = "README.md"
+ isReadme = true
+ default:
+ path = fmt.Sprintf("file-%s.txt", g.randomString(6))
+ isReadme = false
+ }
+
+ return ProcessedModelFile{
+ Path: path,
+ Size: int64(g.rand.Intn(1000000000) + 1000000), // 1MB to 1GB
+ SHA256: g.randomSHA256(),
+ IsReadme: isReadme,
+ FileType: fileType,
+ }
+}
+
+// GenerateProcessedModel creates a synthetic ProcessedModel
+func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
+ authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"}
+ modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"}
+
+ author := authors[g.rand.Intn(len(authors))]
+ modelName := modelNames[g.rand.Intn(len(modelNames))]
+ modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6))
+
+ // Generate files
+ numFiles := g.rand.Intn(5) + 2 // 2-6 files
+ files := make([]ProcessedModelFile, numFiles)
+
+ // Ensure at least one model file and one readme
+ hasModelFile := false
+ hasReadme := false
+
+ for i := 0; i < numFiles; i++ {
+ files[i] = g.GenerateProcessedModelFile()
+ if files[i].FileType == "model" {
+ hasModelFile = true
+ }
+ if files[i].FileType == "readme" {
+ hasReadme = true
+ }
+ }
+
+ // Add required files if missing
+ if !hasModelFile {
+ modelFile := g.GenerateProcessedModelFile()
+ modelFile.FileType = "model"
+ modelFile.Path = fmt.Sprintf("%s-Q4_K_M.gguf", modelName)
+ files = append(files, modelFile)
+ }
+
+ if !hasReadme {
+ readmeFile := g.GenerateProcessedModelFile()
+ readmeFile.FileType = "readme"
+ readmeFile.Path = "README.md"
+ readmeFile.IsReadme = true
+ files = append(files, readmeFile)
+ }
+
+ // Find preferred model file
+ var preferredModelFile *ProcessedModelFile
+ for i := range files {
+ if files[i].FileType == "model" {
+ preferredModelFile = &files[i]
+ break
+ }
+ }
+
+ // Find readme file
+ var readmeFile *ProcessedModelFile
+ for i := range files {
+ if files[i].FileType == "readme" {
+ readmeFile = &files[i]
+ break
+ }
+ }
+
+ readmeContent := g.generateReadmeContent(modelName, author)
+
+ // Generate sample metadata
+ licenses := []string{"apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", ""}
+ license := licenses[g.rand.Intn(len(licenses))]
+
+ sampleTags := []string{"llm", "gguf", "gpu", "cpu", "text-to-text", "chat", "instruction-tuned"}
+ numTags := g.rand.Intn(4) + 3 // 3-6 tags
+ tags := make([]string, numTags)
+ for i := 0; i < numTags; i++ {
+ tags[i] = sampleTags[g.rand.Intn(len(sampleTags))]
+ }
+ // Remove duplicates
+ tags = g.removeDuplicates(tags)
+
+ // Optionally include icon (50% chance)
+ icon := ""
+ if g.rand.Intn(2) == 0 {
+ icon = fmt.Sprintf("https://cdn-avatars.huggingface.co/v1/production/uploads/%s.png", g.randomString(24))
+ }
+
+ return ProcessedModel{
+ ModelID: modelID,
+ Author: author,
+ Downloads: g.rand.Intn(1000000) + 1000,
+ LastModified: g.randomDate(),
+ Files: files,
+ PreferredModelFile: preferredModelFile,
+ ReadmeFile: readmeFile,
+ ReadmeContent: readmeContent,
+ ReadmeContentPreview: truncateString(readmeContent, 200),
+ QuantizationPreferences: []string{"Q4_K_M", "Q4_K_S", "Q3_K_M", "Q2_K"},
+ ProcessingError: "",
+ Tags: tags,
+ License: license,
+ Icon: icon,
+ }
+}
+
+// Helper methods for synthetic data generation
+func (g *SyntheticDataGenerator) randomString(length int) string {
+ const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
+ b := make([]byte, length)
+ for i := range b {
+ b[i] = charset[g.rand.Intn(len(charset))]
+ }
+ return string(b)
+}
+
+func (g *SyntheticDataGenerator) randomSHA256() string {
+ const charset = "0123456789abcdef"
+ b := make([]byte, 64)
+ for i := range b {
+ b[i] = charset[g.rand.Intn(len(charset))]
+ }
+ return string(b)
+}
+
+func (g *SyntheticDataGenerator) randomDate() string {
+ now := time.Now()
+ daysAgo := g.rand.Intn(365) // Random date within last year
+ pastDate := now.AddDate(0, 0, -daysAgo)
+ return pastDate.Format("2006-01-02T15:04:05.000Z")
+}
+
+func (g *SyntheticDataGenerator) removeDuplicates(slice []string) []string {
+ keys := make(map[string]bool)
+ result := []string{}
+ for _, item := range slice {
+ if !keys[item] {
+ keys[item] = true
+ result = append(result, item)
+ }
+ }
+ return result
+}
+
+func (g *SyntheticDataGenerator) generateReadmeContent(modelName, author string) string {
+ templates := []string{
+ fmt.Sprintf("# %s Model\n\nThis is a %s model developed by %s. It's designed for various natural language processing tasks including text generation, question answering, and conversation.\n\n## Features\n\n- High-quality text generation\n- Efficient inference\n- Multiple quantization options\n- Easy to use with LocalAI\n\n## Usage\n\nUse this model with LocalAI for various AI tasks.", strings.Title(modelName), modelName, author),
+ fmt.Sprintf("# %s\n\nA powerful language model from %s. This model excels at understanding and generating human-like text across multiple domains.\n\n## Capabilities\n\n- Text completion\n- Code generation\n- Creative writing\n- Technical documentation\n\n## Model Details\n\n- Architecture: Transformer-based\n- Training: Large-scale supervised learning\n- Quantization: Available in multiple formats", strings.Title(modelName), author),
+ fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author),
+ }
+
+ return templates[g.rand.Intn(len(templates))]
+}
diff --git a/.github/gallery-agent/tools.go b/.github/gallery-agent/tools.go
new file mode 100644
index 0000000000000000000000000000000000000000..3e2fc2f3a17c7028de4b5c0c6fc695fa80118cef
--- /dev/null
+++ b/.github/gallery-agent/tools.go
@@ -0,0 +1,46 @@
+package main
+
+import (
+ "fmt"
+
+ hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
+ openai "github.com/sashabaranov/go-openai"
+ jsonschema "github.com/sashabaranov/go-openai/jsonschema"
+)
+
+// Get repository README from HF
+type HFReadmeTool struct {
+ client *hfapi.Client
+}
+
+func (s *HFReadmeTool) Execute(args map[string]any) (string, error) {
+ q, ok := args["repository"].(string)
+ if !ok {
+ return "", fmt.Errorf("no query")
+ }
+ readme, err := s.client.GetReadmeContent(q, "README.md")
+ if err != nil {
+ return "", err
+ }
+ return readme, nil
+}
+
+func (s *HFReadmeTool) Tool() openai.Tool {
+ return openai.Tool{
+ Type: openai.ToolTypeFunction,
+ Function: &openai.FunctionDefinition{
+ Name: "hf_readme",
+ Description: "A tool to get the README content of a huggingface repository",
+ Parameters: jsonschema.Definition{
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "repository": {
+ Type: jsonschema.String,
+ Description: "The huggingface repository to get the README content of",
+ },
+ },
+ Required: []string{"repository"},
+ },
+ },
+ }
+}
diff --git a/.github/labeler.yml b/.github/labeler.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ce4b0290358bdb5e02a94546cad2e6a36f726a31
--- /dev/null
+++ b/.github/labeler.yml
@@ -0,0 +1,33 @@
+enhancement:
+ - head-branch: ['^feature', 'feature']
+
+dependencies:
+- any:
+ - changed-files:
+ - any-glob-to-any-file: 'Makefile'
+ - changed-files:
+ - any-glob-to-any-file: '*.mod'
+ - changed-files:
+ - any-glob-to-any-file: '*.sum'
+
+kind/documentation:
+- any:
+ - changed-files:
+ - any-glob-to-any-file: 'docs/*'
+ - changed-files:
+ - any-glob-to-any-file: '*.md'
+
+area/ai-model:
+- any:
+ - changed-files:
+ - any-glob-to-any-file: 'gallery/*'
+
+examples:
+- any:
+ - changed-files:
+ - any-glob-to-any-file: 'examples/*'
+
+ci:
+- any:
+ - changed-files:
+ - any-glob-to-any-file: '.github/*'
diff --git a/.github/release.yml b/.github/release.yml
new file mode 100644
index 0000000000000000000000000000000000000000..eee7f6ec3d9a7f448d4f76ac8ebab367a7e73762
--- /dev/null
+++ b/.github/release.yml
@@ -0,0 +1,37 @@
+# .github/release.yml
+
+changelog:
+ exclude:
+ labels:
+ - ignore-for-release
+ categories:
+ - title: Breaking Changes 🛠
+ labels:
+ - Semver-Major
+ - breaking-change
+ - title: "Bug fixes :bug:"
+ labels:
+ - bug
+ - regression
+ - title: "🖧 P2P area"
+ labels:
+ - area/p2p
+ - title: Exciting New Features 🎉
+ labels:
+ - Semver-Minor
+ - enhancement
+ - ux
+ - roadmap
+ - title: 🧠 Models
+ labels:
+ - area/ai-model
+ - title: 📖 Documentation and examples
+ labels:
+ - kind/documentation
+ - examples
+ - title: 👒 Dependencies
+ labels:
+ - dependencies
+ - title: Other Changes
+ labels:
+ - "*"
diff --git a/.github/stale.yml b/.github/stale.yml
new file mode 100644
index 0000000000000000000000000000000000000000..af48badee058c6367b1692e09bee00868032e590
--- /dev/null
+++ b/.github/stale.yml
@@ -0,0 +1,18 @@
+# Number of days of inactivity before an issue becomes stale
+daysUntilStale: 45
+# Number of days of inactivity before a stale issue is closed
+daysUntilClose: 10
+# Issues with these labels will never be considered stale
+exemptLabels:
+ - issue/willfix
+# Label to use when marking an issue as stale
+staleLabel: issue/stale
+# Comment to post when marking an issue as stale. Set to `false` to disable
+markComment: >
+ This issue has been automatically marked as stale because it has not had
+ recent activity. It will be closed if no further activity occurs. Thank you
+ for your contributions.
+# Comment to post when closing a stale issue. Set to `false` to disable
+closeComment: >
+ This issue is being automatically closed due to inactivity.
+ However, you may choose to reopen this issue.
\ No newline at end of file
diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d56cabde3662f0d79fb86e02c5fd6b1a1d21891c
--- /dev/null
+++ b/.github/workflows/backend.yml
@@ -0,0 +1,1498 @@
+---
+name: 'build backend container images'
+
+on:
+ push:
+ branches:
+ - master
+ tags:
+ - '*'
+
+concurrency:
+ group: ci-backends-${{ github.head_ref || github.ref }}-${{ github.repository }}
+ cancel-in-progress: true
+
+jobs:
+ backend-jobs:
+ uses: ./.github/workflows/backend_build.yml
+ with:
+ tag-latest: ${{ matrix.tag-latest }}
+ tag-suffix: ${{ matrix.tag-suffix }}
+ build-type: ${{ matrix.build-type }}
+ cuda-major-version: ${{ matrix.cuda-major-version }}
+ cuda-minor-version: ${{ matrix.cuda-minor-version }}
+ platforms: ${{ matrix.platforms }}
+ runs-on: ${{ matrix.runs-on }}
+ base-image: ${{ matrix.base-image }}
+ backend: ${{ matrix.backend }}
+ dockerfile: ${{ matrix.dockerfile }}
+ skip-drivers: ${{ matrix.skip-drivers }}
+ context: ${{ matrix.context }}
+ ubuntu-version: ${{ matrix.ubuntu-version }}
+ secrets:
+ dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
+ dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
+ quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
+ quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
+ strategy:
+ fail-fast: false
+ #max-parallel: ${{ github.event_name != 'pull_request' && 6 || 4 }}
+ matrix:
+ include:
+ - build-type: 'l4t'
+ cuda-major-version: "12"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-diffusers'
+ runs-on: 'ubuntu-24.04-arm'
+ base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
+ skip-drivers: 'true'
+ backend: "diffusers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2204'
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-cpu-diffusers'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'true'
+ backend: "diffusers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-cpu-chatterbox'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'true'
+ backend: "chatterbox"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-cpu-moonshine'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'true'
+ backend: "moonshine"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ # CUDA 12 builds
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-vibevoice'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "vibevoice"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-pocket-tts'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "pocket-tts"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-rerankers'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "rerankers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-llama-cpp'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "llama-cpp"
+ dockerfile: "./backend/Dockerfile.llama-cpp"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-vllm'
+ runs-on: 'arc-runner-set'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "vllm"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-transformers'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "transformers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-diffusers'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "diffusers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-kokoro'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "kokoro"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-faster-whisper'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "faster-whisper"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-coqui'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "coqui"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-bark'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "bark"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-chatterbox'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "chatterbox"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-moonshine'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "moonshine"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-stablediffusion-ggml'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "stablediffusion-ggml"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-whisper'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "whisper"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-rfdetr'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "rfdetr"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-exllama2'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "exllama2"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12-neutts'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "neutts"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ # cuda 13
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-13-rerankers'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "rerankers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-13-vibevoice'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "vibevoice"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-13-pocket-tts'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "pocket-tts"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-13-llama-cpp'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "llama-cpp"
+ dockerfile: "./backend/Dockerfile.llama-cpp"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ skip-drivers: 'false'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-cuda-13-arm64-llama-cpp'
+ base-image: "ubuntu:24.04"
+ runs-on: 'ubuntu-24.04-arm'
+ ubuntu-version: '2404'
+ backend: "llama-cpp"
+ dockerfile: "./backend/Dockerfile.llama-cpp"
+ context: "./"
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-13-transformers'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "transformers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-13-diffusers'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "diffusers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'l4t'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-cuda-13-arm64-vibevoice'
+ runs-on: 'ubuntu-24.04-arm'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ ubuntu-version: '2404'
+ backend: "vibevoice"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ - build-type: 'l4t'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-cuda-13-arm64-pocket-tts'
+ runs-on: 'ubuntu-24.04-arm'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ ubuntu-version: '2404'
+ backend: "pocket-tts"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ - build-type: 'l4t'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-cuda-13-arm64-diffusers'
+ runs-on: 'ubuntu-24.04-arm'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ ubuntu-version: '2404'
+ backend: "diffusers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-13-kokoro'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "kokoro"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-13-faster-whisper'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "faster-whisper"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-13-bark'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "bark"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-13-chatterbox'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "chatterbox"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-13-moonshine'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "moonshine"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-13-stablediffusion-ggml'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "stablediffusion-ggml"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ skip-drivers: 'false'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-cuda-13-arm64-stablediffusion-ggml'
+ base-image: "ubuntu:24.04"
+ ubuntu-version: '2404'
+ runs-on: 'ubuntu-24.04-arm'
+ backend: "stablediffusion-ggml"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-13-whisper'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "whisper"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ skip-drivers: 'false'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-cuda-13-arm64-whisper'
+ base-image: "ubuntu:24.04"
+ ubuntu-version: '2404'
+ runs-on: 'ubuntu-24.04-arm'
+ backend: "whisper"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-13-rfdetr'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "rfdetr"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ # hipblas builds
+ - build-type: 'hipblas'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-rocm-hipblas-rerankers'
+ runs-on: 'ubuntu-latest'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ skip-drivers: 'false'
+ backend: "rerankers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'hipblas'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-rocm-hipblas-llama-cpp'
+ runs-on: 'ubuntu-latest'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ skip-drivers: 'false'
+ backend: "llama-cpp"
+ dockerfile: "./backend/Dockerfile.llama-cpp"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'hipblas'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-rocm-hipblas-vllm'
+ runs-on: 'arc-runner-set'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ skip-drivers: 'false'
+ backend: "vllm"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'hipblas'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-rocm-hipblas-transformers'
+ runs-on: 'arc-runner-set'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ skip-drivers: 'false'
+ backend: "transformers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'hipblas'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-rocm-hipblas-diffusers'
+ runs-on: 'arc-runner-set'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ skip-drivers: 'false'
+ backend: "diffusers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ # ROCm additional backends
+ - build-type: 'hipblas'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-rocm-hipblas-kokoro'
+ runs-on: 'arc-runner-set'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ skip-drivers: 'false'
+ backend: "kokoro"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'hipblas'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-rocm-hipblas-vibevoice'
+ runs-on: 'arc-runner-set'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ skip-drivers: 'false'
+ backend: "vibevoice"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'hipblas'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-rocm-hipblas-pocket-tts'
+ runs-on: 'arc-runner-set'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ skip-drivers: 'false'
+ backend: "pocket-tts"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'hipblas'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-rocm-hipblas-faster-whisper'
+ runs-on: 'ubuntu-latest'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ skip-drivers: 'false'
+ backend: "faster-whisper"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'hipblas'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-rocm-hipblas-coqui'
+ runs-on: 'ubuntu-latest'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ skip-drivers: 'false'
+ backend: "coqui"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'hipblas'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-rocm-hipblas-bark'
+ runs-on: 'arc-runner-set'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ skip-drivers: 'false'
+ backend: "bark"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ # sycl builds
+ - build-type: 'intel'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-rerankers'
+ runs-on: 'ubuntu-latest'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "rerankers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'sycl_f32'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-sycl-f32-llama-cpp'
+ runs-on: 'ubuntu-latest'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "llama-cpp"
+ dockerfile: "./backend/Dockerfile.llama-cpp"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'sycl_f16'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-sycl-f16-llama-cpp'
+ runs-on: 'ubuntu-latest'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "llama-cpp"
+ dockerfile: "./backend/Dockerfile.llama-cpp"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'intel'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-vllm'
+ runs-on: 'arc-runner-set'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "vllm"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'intel'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-transformers'
+ runs-on: 'ubuntu-latest'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "transformers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'intel'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-diffusers'
+ runs-on: 'ubuntu-latest'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "diffusers"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'l4t'
+ cuda-major-version: "12"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-vibevoice'
+ runs-on: 'ubuntu-24.04-arm'
+ base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
+ skip-drivers: 'true'
+ backend: "vibevoice"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2204'
+ - build-type: 'l4t'
+ cuda-major-version: "12"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-pocket-tts'
+ runs-on: 'ubuntu-24.04-arm'
+ base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
+ skip-drivers: 'true'
+ backend: "pocket-tts"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2204'
+ - build-type: 'l4t'
+ cuda-major-version: "12"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-kokoro'
+ runs-on: 'ubuntu-24.04-arm'
+ base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
+ skip-drivers: 'true'
+ backend: "kokoro"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2204'
+ # SYCL additional backends
+ - build-type: 'intel'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-kokoro'
+ runs-on: 'ubuntu-latest'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "kokoro"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'intel'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-faster-whisper'
+ runs-on: 'ubuntu-latest'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "faster-whisper"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'intel'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-vibevoice'
+ runs-on: 'arc-runner-set'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "vibevoice"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'intel'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-pocket-tts'
+ runs-on: 'arc-runner-set'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "pocket-tts"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'intel'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-coqui'
+ runs-on: 'ubuntu-latest'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "coqui"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'intel'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-bark'
+ runs-on: 'ubuntu-latest'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "bark"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ # piper
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-piper'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "piper"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ # bark-cpp
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-bark-cpp'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "bark-cpp"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-cpu-llama-cpp'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "llama-cpp"
+ dockerfile: "./backend/Dockerfile.llama-cpp"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ skip-drivers: 'false'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-arm64-llama-cpp'
+ base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
+ runs-on: 'ubuntu-24.04-arm'
+ backend: "llama-cpp"
+ dockerfile: "./backend/Dockerfile.llama-cpp"
+ context: "./"
+ ubuntu-version: '2204'
+ - build-type: 'vulkan'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-vulkan-llama-cpp'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "llama-cpp"
+ dockerfile: "./backend/Dockerfile.llama-cpp"
+ context: "./"
+ ubuntu-version: '2404'
+ # Stablediffusion-ggml
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-cpu-stablediffusion-ggml'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "stablediffusion-ggml"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'sycl_f32'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-sycl-f32-stablediffusion-ggml'
+ runs-on: 'ubuntu-latest'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "stablediffusion-ggml"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'sycl_f16'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-sycl-f16-stablediffusion-ggml'
+ runs-on: 'ubuntu-latest'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "stablediffusion-ggml"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'vulkan'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-vulkan-stablediffusion-ggml'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "stablediffusion-ggml"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ skip-drivers: 'false'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-arm64-stablediffusion-ggml'
+ base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
+ runs-on: 'ubuntu-24.04-arm'
+ backend: "stablediffusion-ggml"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2204'
+ # whisper
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-cpu-whisper'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "whisper"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'sycl_f32'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-sycl-f32-whisper'
+ runs-on: 'ubuntu-latest'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "whisper"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'sycl_f16'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-sycl-f16-whisper'
+ runs-on: 'ubuntu-latest'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "whisper"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'vulkan'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-vulkan-whisper'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "whisper"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ skip-drivers: 'false'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-arm64-whisper'
+ base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
+ runs-on: 'ubuntu-24.04-arm'
+ backend: "whisper"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2204'
+ - build-type: 'hipblas'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-rocm-hipblas-whisper'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ runs-on: 'ubuntu-latest'
+ skip-drivers: 'false'
+ backend: "whisper"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ #silero-vad
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-cpu-silero-vad'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "silero-vad"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ # local-store
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-cpu-local-store'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "local-store"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ # huggingface
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-huggingface'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "huggingface"
+ dockerfile: "./backend/Dockerfile.golang"
+ context: "./"
+ ubuntu-version: '2404'
+ # rfdetr
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-cpu-rfdetr'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "rfdetr"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'intel'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-rfdetr'
+ runs-on: 'ubuntu-latest'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "rfdetr"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'l4t'
+ cuda-major-version: "12"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ skip-drivers: 'true'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-arm64-rfdetr'
+ base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
+ runs-on: 'ubuntu-24.04-arm'
+ backend: "rfdetr"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2204'
+ # exllama2
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-cpu-exllama2'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "exllama2"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'intel'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-intel-exllama2'
+ runs-on: 'ubuntu-latest'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ skip-drivers: 'false'
+ backend: "exllama2"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'hipblas'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ skip-drivers: 'true'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-hipblas-exllama2'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ runs-on: 'ubuntu-latest'
+ backend: "exllama2"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'l4t'
+ cuda-major-version: "12"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ skip-drivers: 'true'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-arm64-chatterbox'
+ base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
+ runs-on: 'ubuntu-24.04-arm'
+ backend: "chatterbox"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2204'
+ # runs out of space on the runner
+ # - build-type: 'hipblas'
+ # cuda-major-version: ""
+ # cuda-minor-version: ""
+ # platforms: 'linux/amd64'
+ # tag-latest: 'auto'
+ # tag-suffix: '-gpu-hipblas-rfdetr'
+ # base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ # runs-on: 'ubuntu-latest'
+ # skip-drivers: 'false'
+ # backend: "rfdetr"
+ # dockerfile: "./backend/Dockerfile.python"
+ # context: "./"
+ # kitten-tts
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-kitten-tts'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "kitten-tts"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ # neutts
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-cpu-neutts'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "neutts"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'hipblas'
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-rocm-hipblas-neutts'
+ runs-on: 'arc-runner-set'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ skip-drivers: 'false'
+ backend: "neutts"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: 'l4t'
+ cuda-major-version: "12"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ skip-drivers: 'true'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-arm64-neutts'
+ base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
+ runs-on: 'ubuntu-24.04-arm'
+ backend: "neutts"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2204'
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-cpu-vibevoice'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "vibevoice"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-cpu-pocket-tts'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ backend: "pocket-tts"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
+ backend-jobs-darwin:
+ uses: ./.github/workflows/backend_build_darwin.yml
+ strategy:
+ matrix:
+ include:
+ - backend: "diffusers"
+ tag-suffix: "-metal-darwin-arm64-diffusers"
+ build-type: "mps"
+ - backend: "mlx"
+ tag-suffix: "-metal-darwin-arm64-mlx"
+ build-type: "mps"
+ - backend: "chatterbox"
+ tag-suffix: "-metal-darwin-arm64-chatterbox"
+ build-type: "mps"
+ - backend: "mlx-vlm"
+ tag-suffix: "-metal-darwin-arm64-mlx-vlm"
+ build-type: "mps"
+ - backend: "mlx-audio"
+ tag-suffix: "-metal-darwin-arm64-mlx-audio"
+ build-type: "mps"
+ - backend: "stablediffusion-ggml"
+ tag-suffix: "-metal-darwin-arm64-stablediffusion-ggml"
+ build-type: "metal"
+ lang: "go"
+ - backend: "whisper"
+ tag-suffix: "-metal-darwin-arm64-whisper"
+ build-type: "metal"
+ lang: "go"
+ with:
+ backend: ${{ matrix.backend }}
+ build-type: ${{ matrix.build-type }}
+ go-version: "1.24.x"
+ tag-suffix: ${{ matrix.tag-suffix }}
+ lang: ${{ matrix.lang || 'python' }}
+ use-pip: ${{ matrix.backend == 'diffusers' }}
+ runs-on: "macos-latest"
+ secrets:
+ dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
+ dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
+ quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
+ quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
+ llama-cpp-darwin:
+ runs-on: macos-latest
+ strategy:
+ matrix:
+ go-version: ['1.25.x']
+ steps:
+ - name: Clone
+ uses: actions/checkout@v6
+ with:
+ submodules: true
+ - name: Setup Go ${{ matrix.go-version }}
+ uses: actions/setup-go@v5
+ with:
+ go-version: ${{ matrix.go-version }}
+ cache: false
+ # You can test your matrix by printing the current Go version
+ - name: Display Go version
+ run: go version
+ - name: Dependencies
+ run: |
+ brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm
+ - name: Build llama-cpp-darwin
+ run: |
+ make protogen-go
+ make backends/llama-cpp-darwin
+ - name: Upload llama-cpp.tar
+ uses: actions/upload-artifact@v6
+ with:
+ name: llama-cpp-tar
+ path: backend-images/llama-cpp.tar
+ llama-cpp-darwin-publish:
+ needs: llama-cpp-darwin
+ if: github.event_name != 'pull_request'
+ runs-on: ubuntu-latest
+ steps:
+ - name: Download llama-cpp.tar
+ uses: actions/download-artifact@v7
+ with:
+ name: llama-cpp-tar
+ path: .
+ - name: Install crane
+ run: |
+ curl -L https://github.com/google/go-containerregistry/releases/latest/download/go-containerregistry_Linux_x86_64.tar.gz | tar -xz
+ sudo mv crane /usr/local/bin/
+ - name: Log in to DockerHub
+ run: |
+ echo "${{ secrets.DOCKERHUB_PASSWORD }}" | crane auth login docker.io -u "${{ secrets.DOCKERHUB_USERNAME }}" --password-stdin
+ - name: Log in to quay.io
+ run: |
+ echo "${{ secrets.LOCALAI_REGISTRY_PASSWORD }}" | crane auth login quay.io -u "${{ secrets.LOCALAI_REGISTRY_USERNAME }}" --password-stdin
+ - name: Docker meta
+ id: meta
+ uses: docker/metadata-action@v5
+ with:
+ images: |
+ localai/localai-backends
+ tags: |
+ type=ref,event=branch
+ type=semver,pattern={{raw}}
+ type=sha
+ flavor: |
+ latest=auto
+ suffix=-metal-darwin-arm64-llama-cpp,onlatest=true
+ - name: Docker meta
+ id: quaymeta
+ uses: docker/metadata-action@v5
+ with:
+ images: |
+ quay.io/go-skynet/local-ai-backends
+ tags: |
+ type=ref,event=branch
+ type=semver,pattern={{raw}}
+ type=sha
+ flavor: |
+ latest=auto
+ suffix=-metal-darwin-arm64-llama-cpp,onlatest=true
+ - name: Push Docker image (DockerHub)
+ run: |
+ for tag in $(echo "${{ steps.meta.outputs.tags }}" | tr ',' '\n'); do
+ crane push llama-cpp.tar $tag
+ done
+ - name: Push Docker image (Quay)
+ run: |
+ for tag in $(echo "${{ steps.quaymeta.outputs.tags }}" | tr ',' '\n'); do
+ crane push llama-cpp.tar $tag
+ done
diff --git a/.github/workflows/backend_build.yml b/.github/workflows/backend_build.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e458dc3cb6b2788fbc14f9180a6ab9c508283842
--- /dev/null
+++ b/.github/workflows/backend_build.yml
@@ -0,0 +1,250 @@
+---
+name: 'build backend container images (reusable)'
+
+on:
+ workflow_call:
+ inputs:
+ base-image:
+ description: 'Base image'
+ required: true
+ type: string
+ build-type:
+ description: 'Build type'
+ default: ''
+ type: string
+ cuda-major-version:
+ description: 'CUDA major version'
+ default: "12"
+ type: string
+ cuda-minor-version:
+ description: 'CUDA minor version'
+ default: "1"
+ type: string
+ platforms:
+ description: 'Platforms'
+ default: ''
+ type: string
+ tag-latest:
+ description: 'Tag latest'
+ default: ''
+ type: string
+ tag-suffix:
+ description: 'Tag suffix'
+ default: ''
+ type: string
+ runs-on:
+ description: 'Runs on'
+ required: true
+ default: ''
+ type: string
+ backend:
+ description: 'Backend to build'
+ required: true
+ type: string
+ context:
+ description: 'Build context'
+ required: true
+ type: string
+ dockerfile:
+ description: 'Build Dockerfile'
+ required: true
+ type: string
+ skip-drivers:
+ description: 'Skip drivers'
+ default: 'false'
+ type: string
+ ubuntu-version:
+ description: 'Ubuntu version'
+ required: false
+ default: '2204'
+ type: string
+ secrets:
+ dockerUsername:
+ required: false
+ dockerPassword:
+ required: false
+ quayUsername:
+ required: true
+ quayPassword:
+ required: true
+
+jobs:
+ backend-build:
+ runs-on: ${{ inputs.runs-on }}
+ env:
+ quay_username: ${{ secrets.quayUsername }}
+ steps:
+
+
+ - name: Free Disk Space (Ubuntu)
+ if: inputs.runs-on == 'ubuntu-latest'
+ uses: jlumbroso/free-disk-space@main
+ with:
+ # this might remove tools that are actually needed,
+ # if set to "true" but frees about 6 GB
+ tool-cache: true
+ # all of these default to true, but feel free to set to
+ # "false" if necessary for your workflow
+ android: true
+ dotnet: true
+ haskell: true
+ large-packages: true
+ docker-images: true
+ swap-storage: true
+
+ - name: Force Install GIT latest
+ run: |
+ sudo apt-get update \
+ && sudo apt-get install -y software-properties-common \
+ && sudo apt-get update \
+ && sudo add-apt-repository -y ppa:git-core/ppa \
+ && sudo apt-get update \
+ && sudo apt-get install -y git
+
+ - name: Checkout
+ uses: actions/checkout@v6
+
+ - name: Release space from worker
+ if: inputs.runs-on == 'ubuntu-latest'
+ run: |
+ echo "Listing top largest packages"
+ pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+ head -n 30 <<< "${pkgs}"
+ echo
+ df -h
+ echo
+ sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true
+ sudo apt-get remove --auto-remove android-sdk-platform-tools snapd || true
+ sudo apt-get purge --auto-remove android-sdk-platform-tools snapd || true
+ sudo rm -rf /usr/local/lib/android
+ sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true
+ sudo rm -rf /usr/share/dotnet
+ sudo apt-get remove -y '^mono-.*' || true
+ sudo apt-get remove -y '^ghc-.*' || true
+ sudo apt-get remove -y '.*jdk.*|.*jre.*' || true
+ sudo apt-get remove -y 'php.*' || true
+ sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true
+ sudo apt-get remove -y '^google-.*' || true
+ sudo apt-get remove -y azure-cli || true
+ sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true
+ sudo apt-get remove -y '^gfortran-.*' || true
+ sudo apt-get remove -y microsoft-edge-stable || true
+ sudo apt-get remove -y firefox || true
+ sudo apt-get remove -y powershell || true
+ sudo apt-get remove -y r-base-core || true
+ sudo apt-get autoremove -y
+ sudo apt-get clean
+ echo
+ echo "Listing top largest packages"
+ pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+ head -n 30 <<< "${pkgs}"
+ echo
+ sudo rm -rfv build || true
+ sudo rm -rf /usr/share/dotnet || true
+ sudo rm -rf /opt/ghc || true
+ sudo rm -rf "/usr/local/share/boost" || true
+ sudo rm -rf "$AGENT_TOOLSDIRECTORY" || true
+ df -h
+
+ - name: Docker meta
+ id: meta
+ if: github.event_name != 'pull_request'
+ uses: docker/metadata-action@v5
+ with:
+ images: |
+ quay.io/go-skynet/local-ai-backends
+ localai/localai-backends
+ tags: |
+ type=ref,event=branch
+ type=semver,pattern={{raw}}
+ type=sha
+ flavor: |
+ latest=${{ inputs.tag-latest }}
+ suffix=${{ inputs.tag-suffix }},onlatest=true
+
+ - name: Docker meta for PR
+ id: meta_pull_request
+ if: github.event_name == 'pull_request'
+ uses: docker/metadata-action@v5
+ with:
+ images: |
+ quay.io/go-skynet/ci-tests
+ tags: |
+ type=ref,event=branch,suffix=${{ github.event.number }}-${{ inputs.backend }}-${{ inputs.build-type }}-${{ inputs.cuda-major-version }}-${{ inputs.cuda-minor-version }}
+ type=semver,pattern={{raw}},suffix=${{ github.event.number }}-${{ inputs.backend }}-${{ inputs.build-type }}-${{ inputs.cuda-major-version }}-${{ inputs.cuda-minor-version }}
+ type=sha,suffix=${{ github.event.number }}-${{ inputs.backend }}-${{ inputs.build-type }}-${{ inputs.cuda-major-version }}-${{ inputs.cuda-minor-version }}
+ flavor: |
+ latest=${{ inputs.tag-latest }}
+ suffix=${{ inputs.tag-suffix }},onlatest=true
+## End testing image
+ - name: Set up QEMU
+ uses: docker/setup-qemu-action@master
+ with:
+ platforms: all
+
+ - name: Set up Docker Buildx
+ id: buildx
+ uses: docker/setup-buildx-action@master
+
+ - name: Login to DockerHub
+ if: github.event_name != 'pull_request'
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.dockerUsername }}
+ password: ${{ secrets.dockerPassword }}
+
+ - name: Login to Quay.io
+ if: ${{ env.quay_username != '' }}
+ uses: docker/login-action@v3
+ with:
+ registry: quay.io
+ username: ${{ secrets.quayUsername }}
+ password: ${{ secrets.quayPassword }}
+
+ - name: Build and push
+ uses: docker/build-push-action@v6
+ if: github.event_name != 'pull_request'
+ with:
+ builder: ${{ steps.buildx.outputs.name }}
+ build-args: |
+ BUILD_TYPE=${{ inputs.build-type }}
+ SKIP_DRIVERS=${{ inputs.skip-drivers }}
+ CUDA_MAJOR_VERSION=${{ inputs.cuda-major-version }}
+ CUDA_MINOR_VERSION=${{ inputs.cuda-minor-version }}
+ BASE_IMAGE=${{ inputs.base-image }}
+ BACKEND=${{ inputs.backend }}
+ UBUNTU_VERSION=${{ inputs.ubuntu-version }}
+ context: ${{ inputs.context }}
+ file: ${{ inputs.dockerfile }}
+ cache-from: type=gha
+ platforms: ${{ inputs.platforms }}
+ push: ${{ github.event_name != 'pull_request' }}
+ tags: ${{ steps.meta.outputs.tags }}
+ labels: ${{ steps.meta.outputs.labels }}
+
+ - name: Build and push (PR)
+ uses: docker/build-push-action@v6
+ if: github.event_name == 'pull_request'
+ with:
+ builder: ${{ steps.buildx.outputs.name }}
+ build-args: |
+ BUILD_TYPE=${{ inputs.build-type }}
+ SKIP_DRIVERS=${{ inputs.skip-drivers }}
+ CUDA_MAJOR_VERSION=${{ inputs.cuda-major-version }}
+ CUDA_MINOR_VERSION=${{ inputs.cuda-minor-version }}
+ BASE_IMAGE=${{ inputs.base-image }}
+ BACKEND=${{ inputs.backend }}
+ UBUNTU_VERSION=${{ inputs.ubuntu-version }}
+ context: ${{ inputs.context }}
+ file: ${{ inputs.dockerfile }}
+ cache-from: type=gha
+ platforms: ${{ inputs.platforms }}
+ push: ${{ env.quay_username != '' }}
+ tags: ${{ steps.meta_pull_request.outputs.tags }}
+ labels: ${{ steps.meta_pull_request.outputs.labels }}
+
+
+
+ - name: job summary
+ run: |
+ echo "Built image: ${{ steps.meta.outputs.labels }}" >> $GITHUB_STEP_SUMMARY
diff --git a/.github/workflows/backend_build_darwin.yml b/.github/workflows/backend_build_darwin.yml
new file mode 100644
index 0000000000000000000000000000000000000000..438301b52e1f112e338bb00c2787a97c7746b249
--- /dev/null
+++ b/.github/workflows/backend_build_darwin.yml
@@ -0,0 +1,144 @@
+---
+name: 'build darwin python backend container images (reusable)'
+
+on:
+ workflow_call:
+ inputs:
+ backend:
+ description: 'Backend to build'
+ required: true
+ type: string
+ build-type:
+ description: 'Build type (e.g., mps)'
+ default: ''
+ type: string
+ use-pip:
+ description: 'Use pip to install dependencies'
+ default: false
+ type: boolean
+ lang:
+ description: 'Programming language (e.g. go)'
+ default: 'python'
+ type: string
+ go-version:
+ description: 'Go version to use'
+ default: '1.24.x'
+ type: string
+ tag-suffix:
+ description: 'Tag suffix for the built image'
+ required: true
+ type: string
+ runs-on:
+ description: 'Runner to use'
+ default: 'macOS-14'
+ type: string
+ secrets:
+ dockerUsername:
+ required: false
+ dockerPassword:
+ required: false
+ quayUsername:
+ required: true
+ quayPassword:
+ required: true
+
+jobs:
+ darwin-backend-build:
+ runs-on: ${{ inputs.runs-on }}
+ strategy:
+ matrix:
+ go-version: ['${{ inputs.go-version }}']
+ steps:
+ - name: Clone
+ uses: actions/checkout@v6
+ with:
+ submodules: true
+
+ - name: Setup Go ${{ matrix.go-version }}
+ uses: actions/setup-go@v5
+ with:
+ go-version: ${{ matrix.go-version }}
+ cache: false
+
+ # You can test your matrix by printing the current Go version
+ - name: Display Go version
+ run: go version
+
+ - name: Dependencies
+ run: |
+ brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm
+
+ - name: Build ${{ inputs.backend }}-darwin
+ run: |
+ make protogen-go
+ BACKEND=${{ inputs.backend }} BUILD_TYPE=${{ inputs.build-type }} USE_PIP=${{ inputs.use-pip }} make build-darwin-${{ inputs.lang }}-backend
+
+ - name: Upload ${{ inputs.backend }}.tar
+ uses: actions/upload-artifact@v6
+ with:
+ name: ${{ inputs.backend }}-tar
+ path: backend-images/${{ inputs.backend }}.tar
+
+ darwin-backend-publish:
+ needs: darwin-backend-build
+ if: github.event_name != 'pull_request'
+ runs-on: ubuntu-latest
+ steps:
+ - name: Download ${{ inputs.backend }}.tar
+ uses: actions/download-artifact@v7
+ with:
+ name: ${{ inputs.backend }}-tar
+ path: .
+
+ - name: Install crane
+ run: |
+ curl -L https://github.com/google/go-containerregistry/releases/latest/download/go-containerregistry_Linux_x86_64.tar.gz | tar -xz
+ sudo mv crane /usr/local/bin/
+
+ - name: Log in to DockerHub
+ run: |
+ echo "${{ secrets.dockerPassword }}" | crane auth login docker.io -u "${{ secrets.dockerUsername }}" --password-stdin
+
+ - name: Log in to quay.io
+ run: |
+ echo "${{ secrets.quayPassword }}" | crane auth login quay.io -u "${{ secrets.quayUsername }}" --password-stdin
+
+ - name: Docker meta
+ id: meta
+ uses: docker/metadata-action@v5
+ with:
+ images: |
+ localai/localai-backends
+ tags: |
+ type=ref,event=branch
+ type=semver,pattern={{raw}}
+ type=sha
+ flavor: |
+ latest=auto
+ suffix=${{ inputs.tag-suffix }},onlatest=true
+
+ - name: Docker meta
+ id: quaymeta
+ uses: docker/metadata-action@v5
+ with:
+ images: |
+ quay.io/go-skynet/local-ai-backends
+ tags: |
+ type=ref,event=branch
+ type=semver,pattern={{raw}}
+ type=sha
+ flavor: |
+ latest=auto
+ suffix=${{ inputs.tag-suffix }},onlatest=true
+
+ - name: Push Docker image (DockerHub)
+ run: |
+ for tag in $(echo "${{ steps.meta.outputs.tags }}" | tr ',' '\n'); do
+ crane push ${{ inputs.backend }}.tar $tag
+ done
+
+ - name: Push Docker image (Quay)
+ run: |
+ for tag in $(echo "${{ steps.quaymeta.outputs.tags }}" | tr ',' '\n'); do
+ crane push ${{ inputs.backend }}.tar $tag
+ done
diff --git a/.github/workflows/backend_pr.yml b/.github/workflows/backend_pr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f345448b2cff44ba1414a09e22af690883db1f6b
--- /dev/null
+++ b/.github/workflows/backend_pr.yml
@@ -0,0 +1,79 @@
+name: 'build backend container images (PR-filtered)'
+
+on:
+ pull_request:
+
+concurrency:
+ group: ci-backends-pr-${{ github.head_ref || github.ref }}-${{ github.repository }}
+ cancel-in-progress: true
+
+jobs:
+ generate-matrix:
+ runs-on: ubuntu-latest
+ outputs:
+ matrix: ${{ steps.set-matrix.outputs.matrix }}
+ matrix-darwin: ${{ steps.set-matrix.outputs.matrix-darwin }}
+ has-backends: ${{ steps.set-matrix.outputs.has-backends }}
+ has-backends-darwin: ${{ steps.set-matrix.outputs.has-backends-darwin }}
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v6
+
+ - name: Setup Bun
+ uses: oven-sh/setup-bun@v2
+
+ - name: Install dependencies
+ run: |
+ bun add js-yaml
+ bun add @octokit/core
+
+ # filters the matrix in backend.yml
+ - name: Filter matrix for changed backends
+ id: set-matrix
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ GITHUB_EVENT_PATH: ${{ github.event_path }}
+ run: bun run scripts/changed-backends.js
+
+ backend-jobs:
+ needs: generate-matrix
+ uses: ./.github/workflows/backend_build.yml
+ if: needs.generate-matrix.outputs.has-backends == 'true'
+ with:
+ tag-latest: ${{ matrix.tag-latest }}
+ tag-suffix: ${{ matrix.tag-suffix }}
+ build-type: ${{ matrix.build-type }}
+ cuda-major-version: ${{ matrix.cuda-major-version }}
+ cuda-minor-version: ${{ matrix.cuda-minor-version }}
+ platforms: ${{ matrix.platforms }}
+ runs-on: ${{ matrix.runs-on }}
+ base-image: ${{ matrix.base-image }}
+ backend: ${{ matrix.backend }}
+ dockerfile: ${{ matrix.dockerfile }}
+ skip-drivers: ${{ matrix.skip-drivers }}
+ context: ${{ matrix.context }}
+ ubuntu-version: ${{ matrix.ubuntu-version }}
+ secrets:
+ quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
+ quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
+ strategy:
+ fail-fast: true
+ matrix: ${{ fromJson(needs.generate-matrix.outputs.matrix) }}
+ backend-jobs-darwin:
+ needs: generate-matrix
+ uses: ./.github/workflows/backend_build_darwin.yml
+ if: needs.generate-matrix.outputs.has-backends-darwin == 'true'
+ with:
+ backend: ${{ matrix.backend }}
+ build-type: ${{ matrix.build-type }}
+ go-version: "1.24.x"
+ tag-suffix: ${{ matrix.tag-suffix }}
+ lang: ${{ matrix.lang || 'python' }}
+ use-pip: ${{ matrix.backend == 'diffusers' }}
+ runs-on: "macos-latest"
+ secrets:
+ quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
+ quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
+ strategy:
+ fail-fast: true
+ matrix: ${{ fromJson(needs.generate-matrix.outputs.matrix-darwin) }}
diff --git a/.github/workflows/build-test.yaml b/.github/workflows/build-test.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..36474a778b462f54df51c1f79238965395e0d0bd
--- /dev/null
+++ b/.github/workflows/build-test.yaml
@@ -0,0 +1,67 @@
+name: Build test
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+
+jobs:
+ build-test:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v6
+ with:
+ fetch-depth: 0
+ - name: Set up Go
+ uses: actions/setup-go@v5
+ with:
+ go-version: 1.25
+ - name: Run GoReleaser
+ run: |
+ make dev-dist
+ launcher-build-darwin:
+ runs-on: macos-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v6
+ with:
+ fetch-depth: 0
+ - name: Set up Go
+ uses: actions/setup-go@v5
+ with:
+ go-version: 1.25
+ - name: Build launcher for macOS ARM64
+ run: |
+ make build-launcher-darwin
+ ls -liah dist
+ - name: Upload macOS launcher artifacts
+ uses: actions/upload-artifact@v6
+ with:
+ name: launcher-macos
+ path: dist/
+ retention-days: 30
+
+ launcher-build-linux:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v6
+ with:
+ fetch-depth: 0
+ - name: Set up Go
+ uses: actions/setup-go@v5
+ with:
+ go-version: 1.25
+ - name: Build launcher for Linux
+ run: |
+ sudo apt-get update
+ sudo apt-get install golang gcc libgl1-mesa-dev xorg-dev libxkbcommon-dev
+ make build-launcher-linux
+ - name: Upload Linux launcher artifacts
+ uses: actions/upload-artifact@v6
+ with:
+ name: launcher-linux
+ path: local-ai-launcher-linux.tar.xz
+ retention-days: 30
\ No newline at end of file
diff --git a/.github/workflows/bump_deps.yaml b/.github/workflows/bump_deps.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..942bdb989dc664dbf10c0247c899159f328fc9ee
--- /dev/null
+++ b/.github/workflows/bump_deps.yaml
@@ -0,0 +1,63 @@
+name: Bump Backend dependencies
+on:
+ schedule:
+ - cron: 0 20 * * *
+ workflow_dispatch:
+jobs:
+ bump-backends:
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - repository: "ggml-org/llama.cpp"
+ variable: "LLAMA_VERSION"
+ branch: "master"
+ file: "backend/cpp/llama-cpp/Makefile"
+ - repository: "ggml-org/whisper.cpp"
+ variable: "WHISPER_CPP_VERSION"
+ branch: "master"
+ file: "backend/go/whisper/Makefile"
+ - repository: "PABannier/bark.cpp"
+ variable: "BARKCPP_VERSION"
+ branch: "main"
+ file: "Makefile"
+ - repository: "leejet/stable-diffusion.cpp"
+ variable: "STABLEDIFFUSION_GGML_VERSION"
+ branch: "master"
+ file: "backend/go/stablediffusion-ggml/Makefile"
+ - repository: "mudler/go-piper"
+ variable: "PIPER_VERSION"
+ branch: "master"
+ file: "backend/go/piper/Makefile"
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v6
+ - name: Bump dependencies 🔧
+ id: bump
+ run: |
+ bash .github/bump_deps.sh ${{ matrix.repository }} ${{ matrix.branch }} ${{ matrix.variable }} ${{ matrix.file }}
+ {
+ echo 'message<> "$GITHUB_OUTPUT"
+ {
+ echo 'commit<> "$GITHUB_OUTPUT"
+ rm -rfv ${{ matrix.variable }}_message.txt
+ rm -rfv ${{ matrix.variable }}_commit.txt
+ - name: Create Pull Request
+ uses: peter-evans/create-pull-request@v8
+ with:
+ token: ${{ secrets.UPDATE_BOT_TOKEN }}
+ push-to-fork: ci-forks/LocalAI
+ commit-message: ':arrow_up: Update ${{ matrix.repository }}'
+ title: 'chore: :arrow_up: Update ${{ matrix.repository }} to `${{ steps.bump.outputs.commit }}`'
+ branch: "update/${{ matrix.variable }}"
+ body: ${{ steps.bump.outputs.message }}
+ signoff: true
+
+
+
diff --git a/.github/workflows/bump_docs.yaml b/.github/workflows/bump_docs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0437d084c03f29cc4a07f9c5d7bf9e0f9a16cfc7
--- /dev/null
+++ b/.github/workflows/bump_docs.yaml
@@ -0,0 +1,31 @@
+name: Bump Documentation
+on:
+ schedule:
+ - cron: 0 20 * * *
+ workflow_dispatch:
+jobs:
+ bump-docs:
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - repository: "mudler/LocalAI"
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v6
+ - name: Bump dependencies 🔧
+ run: |
+ bash .github/bump_docs.sh ${{ matrix.repository }}
+ - name: Create Pull Request
+ uses: peter-evans/create-pull-request@v8
+ with:
+ token: ${{ secrets.UPDATE_BOT_TOKEN }}
+ push-to-fork: ci-forks/LocalAI
+ commit-message: ':arrow_up: Update docs version ${{ matrix.repository }}'
+ title: 'docs: :arrow_up: update docs version ${{ matrix.repository }}'
+ branch: "update/docs"
+ body: Bump of ${{ matrix.repository }} version inside docs
+ signoff: true
+
+
+
diff --git a/.github/workflows/checksum_checker.yaml b/.github/workflows/checksum_checker.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78ea956907c240cf9157b4218e002d4c03d155d5
--- /dev/null
+++ b/.github/workflows/checksum_checker.yaml
@@ -0,0 +1,46 @@
+name: Check if checksums are up-to-date
+on:
+ schedule:
+ - cron: 0 20 * * *
+ workflow_dispatch:
+jobs:
+ checksum_check:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Force Install GIT latest
+ run: |
+ sudo apt-get update \
+ && sudo apt-get install -y software-properties-common \
+ && sudo apt-get update \
+ && sudo add-apt-repository -y ppa:git-core/ppa \
+ && sudo apt-get update \
+ && sudo apt-get install -y git
+ - uses: actions/checkout@v6
+ - name: Install dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install -y pip wget
+ pip install huggingface_hub
+ - name: 'Setup yq'
+ uses: dcarbone/install-yq-action@v1.3.1
+ with:
+ version: 'v4.44.2'
+ download-compressed: true
+ force: true
+
+ - name: Checksum checker 🔧
+ run: |
+ export HF_HOME=/hf_cache
+ sudo mkdir /hf_cache
+ sudo chmod 777 /hf_cache
+ bash .github/checksum_checker.sh gallery/index.yaml
+ - name: Create Pull Request
+ uses: peter-evans/create-pull-request@v8
+ with:
+ token: ${{ secrets.UPDATE_BOT_TOKEN }}
+ push-to-fork: ci-forks/LocalAI
+ commit-message: ':arrow_up: Checksum updates in gallery/index.yaml'
+ title: 'chore(model-gallery): :arrow_up: update checksum'
+ branch: "update/checksum"
+ body: Updating checksums in gallery/index.yaml
+ signoff: true
diff --git a/.github/workflows/dependabot_auto.yml b/.github/workflows/dependabot_auto.yml
new file mode 100644
index 0000000000000000000000000000000000000000..873016ee172969ef76a04b1d784859212f25b39e
--- /dev/null
+++ b/.github/workflows/dependabot_auto.yml
@@ -0,0 +1,43 @@
+name: Dependabot auto-merge
+on:
+- pull_request_target
+
+permissions:
+ contents: write
+ pull-requests: write
+ packages: read
+
+jobs:
+ dependabot:
+ runs-on: ubuntu-latest
+ if: ${{ github.actor == 'dependabot[bot]' }}
+ steps:
+ - name: Dependabot metadata
+ id: metadata
+ uses: dependabot/fetch-metadata@v2.5.0
+ with:
+ github-token: "${{ secrets.GITHUB_TOKEN }}"
+ skip-commit-verification: true
+
+ - name: Checkout repository
+ uses: actions/checkout@v6
+
+ - name: Approve a PR if not already approved
+ run: |
+ gh pr checkout "$PR_URL"
+ if [ "$(gh pr status --json reviewDecision -q .currentBranch.reviewDecision)" != "APPROVED" ];
+ then
+ gh pr review --approve "$PR_URL"
+ else
+ echo "PR already approved.";
+ fi
+ env:
+ PR_URL: ${{github.event.pull_request.html_url}}
+ GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
+
+ - name: Enable auto-merge for Dependabot PRs
+ if: ${{ contains(github.event.pull_request.title, 'bump')}}
+ run: gh pr merge --auto --squash "$PR_URL"
+ env:
+ PR_URL: ${{github.event.pull_request.html_url}}
+ GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
diff --git a/.github/workflows/deploy-explorer.yaml b/.github/workflows/deploy-explorer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aa17f162c6a40c3af09be04a0025a4c936649c1c
--- /dev/null
+++ b/.github/workflows/deploy-explorer.yaml
@@ -0,0 +1,64 @@
+name: Explorer deployment
+
+on:
+ push:
+ branches:
+ - master
+ tags:
+ - 'v*'
+
+concurrency:
+ group: ci-deploy-${{ github.head_ref || github.ref }}-${{ github.repository }}
+
+jobs:
+ build-linux:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Clone
+ uses: actions/checkout@v6
+ with:
+ submodules: true
+ - uses: actions/setup-go@v5
+ with:
+ go-version: '1.21.x'
+ cache: false
+ - name: Dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install -y wget curl build-essential ffmpeg protobuf-compiler ccache upx-ucl gawk cmake libgmock-dev
+ go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
+ go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
+ make protogen-go
+ - name: Build api
+ run: |
+ CGO_ENABLED=0 make build
+ - name: rm
+ uses: appleboy/ssh-action@v1.2.4
+ with:
+ host: ${{ secrets.EXPLORER_SSH_HOST }}
+ username: ${{ secrets.EXPLORER_SSH_USERNAME }}
+ key: ${{ secrets.EXPLORER_SSH_KEY }}
+ port: ${{ secrets.EXPLORER_SSH_PORT }}
+ script: |
+ sudo rm -rf local-ai/ || true
+ - name: copy file via ssh
+ uses: appleboy/scp-action@v1.0.0
+ with:
+ host: ${{ secrets.EXPLORER_SSH_HOST }}
+ username: ${{ secrets.EXPLORER_SSH_USERNAME }}
+ key: ${{ secrets.EXPLORER_SSH_KEY }}
+ port: ${{ secrets.EXPLORER_SSH_PORT }}
+ source: "local-ai"
+ overwrite: true
+ rm: true
+ target: ./local-ai
+ - name: restarting
+ uses: appleboy/ssh-action@v1.2.4
+ with:
+ host: ${{ secrets.EXPLORER_SSH_HOST }}
+ username: ${{ secrets.EXPLORER_SSH_USERNAME }}
+ key: ${{ secrets.EXPLORER_SSH_KEY }}
+ port: ${{ secrets.EXPLORER_SSH_PORT }}
+ script: |
+ sudo cp -rfv local-ai/local-ai /usr/bin/local-ai
+ sudo systemctl restart local-ai
diff --git a/.github/workflows/disabled/comment-pr.yaml b/.github/workflows/disabled/comment-pr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bb1012f2a7c02e08664ec394b71b1c6cd4fdbab3
--- /dev/null
+++ b/.github/workflows/disabled/comment-pr.yaml
@@ -0,0 +1,83 @@
+name: Comment PRs
+on:
+ pull_request_target:
+
+jobs:
+ comment-pr:
+ env:
+ MODEL_NAME: hermes-2-theta-llama-3-8b
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v3
+ with:
+ ref: "${{ github.event.pull_request.merge_commit_sha }}"
+ fetch-depth: 0 # needed to checkout all branches for this Action to work
+ - uses: mudler/localai-github-action@v1
+ with:
+ model: 'hermes-2-theta-llama-3-8b' # Any from models.localai.io, or from huggingface.com with: "huggingface:///file"
+ # Check the PR diff using the current branch and the base branch of the PR
+ - uses: GrantBirki/git-diff-action@v2.7.0
+ id: git-diff-action
+ with:
+ json_diff_file_output: diff.json
+ raw_diff_file_output: diff.txt
+ file_output_only: "true"
+ base_branch: ${{ github.event.pull_request.base.sha }}
+ - name: Show diff
+ env:
+ DIFF: ${{ steps.git-diff-action.outputs.raw-diff-path }}
+ run: |
+ cat $DIFF
+ - name: Summarize
+ env:
+ DIFF: ${{ steps.git-diff-action.outputs.raw-diff-path }}
+ id: summarize
+ run: |
+ input="$(cat $DIFF)"
+
+ # Define the LocalAI API endpoint
+ API_URL="http://localhost:8080/chat/completions"
+
+ # Create a JSON payload using jq to handle special characters
+ json_payload=$(jq -n --arg input "$input" '{
+ model: "'$MODEL_NAME'",
+ messages: [
+ {
+ role: "system",
+ content: "You are LocalAI-bot in Github that helps understanding PRs and assess complexity. Explain what has changed in this PR diff and why"
+ },
+ {
+ role: "user",
+ content: $input
+ }
+ ]
+ }')
+
+ # Send the request to LocalAI
+ response=$(curl -s -X POST $API_URL \
+ -H "Content-Type: application/json" \
+ -d "$json_payload")
+
+ # Extract the summary from the response
+ summary="$(echo $response | jq -r '.choices[0].message.content')"
+
+ # Print the summary
+ # -H "Authorization: Bearer $API_KEY" \
+ echo "Summary:"
+ echo "$summary"
+ echo "payload sent"
+ echo "$json_payload"
+ {
+ echo 'message<> "$GITHUB_OUTPUT"
+ docker logs --tail 10 local-ai
+ - uses: mshick/add-pr-comment@v2
+ if: always()
+ with:
+ repo-token: ${{ secrets.UPDATE_BOT_TOKEN }}
+ message: ${{ steps.summarize.outputs.message }}
+ message-failure: |
+ Uh oh! Could not analyze this PR, maybe it's too big?
diff --git a/.github/workflows/disabled/test-gpu.yml b/.github/workflows/disabled/test-gpu.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ea1de749487a55c42f0f52fa1c01e09891fce007
--- /dev/null
+++ b/.github/workflows/disabled/test-gpu.yml
@@ -0,0 +1,63 @@
+---
+name: 'GPU tests'
+
+on:
+ pull_request:
+ push:
+ branches:
+ - master
+ tags:
+ - '*'
+
+concurrency:
+ group: ci-gpu-tests-${{ github.head_ref || github.ref }}-${{ github.repository }}
+ cancel-in-progress: true
+
+jobs:
+ ubuntu-latest:
+ runs-on: gpu
+ strategy:
+ matrix:
+ go-version: ['1.21.x']
+ steps:
+ - name: Clone
+ uses: actions/checkout@v4
+ with:
+ submodules: true
+ - name: Setup Go ${{ matrix.go-version }}
+ uses: actions/setup-go@v4
+ with:
+ go-version: ${{ matrix.go-version }}
+ # You can test your matrix by printing the current Go version
+ - name: Display Go version
+ run: go version
+ - name: Dependencies
+ run: |
+ sudo apt-get update
+ sudo DEBIAN_FRONTEND=noninteractive apt-get install -y make wget
+ - name: Build
+ run: |
+ if [ ! -e /run/systemd/system ]; then
+ sudo mkdir /run/systemd/system
+ fi
+ sudo mkdir -p /host/tests/${{ github.head_ref || github.ref }}
+ sudo chmod -R 777 /host/tests/${{ github.head_ref || github.ref }}
+ make \
+ TEST_DIR="/host/tests/${{ github.head_ref || github.ref }}" \
+ BUILD_TYPE=cublas \
+ prepare-e2e run-e2e-image test-e2e
+ - name: Release space from worker ♻
+ if: always()
+ run: |
+ sudo rm -rf build || true
+ sudo rm -rf bin || true
+ sudo rm -rf dist || true
+ sudo docker logs $(sudo docker ps -q --filter ancestor=localai-tests) > logs.txt
+ sudo cat logs.txt || true
+ sudo rm -rf logs.txt
+ make clean || true
+ make \
+ TEST_DIR="/host/tests/${{ github.head_ref || github.ref }}" \
+ teardown-e2e || true
+ sudo rm -rf /host/tests/${{ github.head_ref || github.ref }} || true
+ docker system prune -f -a --volumes || true
diff --git a/.github/workflows/gallery-agent.yaml b/.github/workflows/gallery-agent.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a78d1f436629124641353025f932295234dddea1
--- /dev/null
+++ b/.github/workflows/gallery-agent.yaml
@@ -0,0 +1,132 @@
+name: Gallery Agent
+on:
+
+ schedule:
+ - cron: '0 */3 * * *' # Run every 4 hours
+ workflow_dispatch:
+ inputs:
+ search_term:
+ description: 'Search term for models'
+ required: false
+ default: 'GGUF'
+ type: string
+ limit:
+ description: 'Maximum number of models to process'
+ required: false
+ default: '15'
+ type: string
+ quantization:
+ description: 'Preferred quantization format'
+ required: false
+ default: 'Q4_K_M'
+ type: string
+ max_models:
+ description: 'Maximum number of models to add to the gallery'
+ required: false
+ default: '1'
+ type: string
+jobs:
+ gallery-agent:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v6
+ with:
+ token: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Set up Go
+ uses: actions/setup-go@v5
+ with:
+ go-version: '1.21'
+ - name: Proto Dependencies
+ run: |
+ # Install protoc
+ curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
+ unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
+ rm protoc.zip
+ go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
+ go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
+ PATH="$PATH:$HOME/go/bin" make protogen-go
+ - uses: mudler/localai-github-action@v1.1
+ with:
+ model: 'https://huggingface.co/bartowski/Qwen_Qwen3-1.7B-GGUF'
+
+ - name: Run gallery agent
+ env:
+ #OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }}
+ OPENAI_MODE: Qwen_Qwen3-1.7B-GGUF
+ OPENAI_BASE_URL: "http://localhost:8080"
+ OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
+ #OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }}
+ SEARCH_TERM: ${{ github.event.inputs.search_term || 'GGUF' }}
+ LIMIT: ${{ github.event.inputs.limit || '15' }}
+ QUANTIZATION: ${{ github.event.inputs.quantization || 'Q4_K_M' }}
+ MAX_MODELS: ${{ github.event.inputs.max_models || '1' }}
+ run: |
+ export GALLERY_INDEX_PATH=$PWD/gallery/index.yaml
+ go run ./.github/gallery-agent
+
+ - name: Check for changes
+ id: check_changes
+ run: |
+ if git diff --quiet gallery/index.yaml; then
+ echo "changes=false" >> $GITHUB_OUTPUT
+ echo "No changes detected in gallery/index.yaml"
+ else
+ echo "changes=true" >> $GITHUB_OUTPUT
+ echo "Changes detected in gallery/index.yaml"
+ git diff gallery/index.yaml
+ fi
+
+ - name: Read gallery agent summary
+ id: read_summary
+ if: steps.check_changes.outputs.changes == 'true'
+ run: |
+ if [ -f "./gallery-agent-summary.json" ]; then
+ echo "summary_exists=true" >> $GITHUB_OUTPUT
+ # Extract summary data using jq
+ echo "search_term=$(jq -r '.search_term' ./gallery-agent-summary.json)" >> $GITHUB_OUTPUT
+ echo "total_found=$(jq -r '.total_found' ./gallery-agent-summary.json)" >> $GITHUB_OUTPUT
+ echo "models_added=$(jq -r '.models_added' ./gallery-agent-summary.json)" >> $GITHUB_OUTPUT
+ echo "quantization=$(jq -r '.quantization' ./gallery-agent-summary.json)" >> $GITHUB_OUTPUT
+ echo "processing_time=$(jq -r '.processing_time' ./gallery-agent-summary.json)" >> $GITHUB_OUTPUT
+
+ # Create a formatted list of added models with URLs
+ added_models=$(jq -r 'range(0; .added_model_ids | length) as $i | "- [\(.added_model_ids[$i])](\(.added_model_urls[$i]))"' ./gallery-agent-summary.json | tr '\n' '\n')
+ echo "added_models<> $GITHUB_OUTPUT
+ echo "$added_models" >> $GITHUB_OUTPUT
+ echo "EOF" >> $GITHUB_OUTPUT
+ rm -f ./gallery-agent-summary.json
+ else
+ echo "summary_exists=false" >> $GITHUB_OUTPUT
+ fi
+
+ - name: Create Pull Request
+ if: steps.check_changes.outputs.changes == 'true'
+ uses: peter-evans/create-pull-request@v8
+ with:
+ token: ${{ secrets.UPDATE_BOT_TOKEN }}
+ push-to-fork: ci-forks/LocalAI
+ commit-message: 'chore(model gallery): :robot: add new models via gallery agent'
+ title: 'chore(model gallery): :robot: add ${{ steps.read_summary.outputs.models_added || 0 }} new models via gallery agent'
+ # Branch has to be unique so PRs are not overriding each other
+ branch-suffix: timestamp
+ body: |
+ This PR was automatically created by the gallery agent workflow.
+
+ **Summary:**
+ - **Search Term:** ${{ steps.read_summary.outputs.search_term || github.event.inputs.search_term || 'GGUF' }}
+ - **Models Found:** ${{ steps.read_summary.outputs.total_found || 'N/A' }}
+ - **Models Added:** ${{ steps.read_summary.outputs.models_added || '0' }}
+ - **Quantization:** ${{ steps.read_summary.outputs.quantization || github.event.inputs.quantization || 'Q4_K_M' }}
+ - **Processing Time:** ${{ steps.read_summary.outputs.processing_time || 'N/A' }}
+
+ **Added Models:**
+ ${{ steps.read_summary.outputs.added_models || '- No models added' }}
+
+ **Workflow Details:**
+ - Triggered by: `${{ github.event_name }}`
+ - Run ID: `${{ github.run_id }}`
+ - Commit: `${{ github.sha }}`
+ signoff: true
+ delete-branch: true
diff --git a/.github/workflows/generate_grpc_cache.yaml b/.github/workflows/generate_grpc_cache.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..72a2b306741d934ee1506eb3fadf20d5da430238
--- /dev/null
+++ b/.github/workflows/generate_grpc_cache.yaml
@@ -0,0 +1,95 @@
+name: 'generate and publish GRPC docker caches'
+
+on:
+ workflow_dispatch:
+
+ schedule:
+ # daily at midnight
+ - cron: '0 0 * * *'
+
+concurrency:
+ group: grpc-cache-${{ github.head_ref || github.ref }}-${{ github.repository }}
+ cancel-in-progress: true
+
+jobs:
+ generate_caches:
+ strategy:
+ matrix:
+ include:
+ - grpc-base-image: ubuntu:24.04
+ runs-on: 'ubuntu-latest'
+ platforms: 'linux/amd64,linux/arm64'
+ runs-on: ${{matrix.runs-on}}
+ steps:
+ - name: Release space from worker
+ if: matrix.runs-on == 'ubuntu-latest'
+ run: |
+ echo "Listing top largest packages"
+ pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+ head -n 30 <<< "${pkgs}"
+ echo
+ df -h
+ echo
+ sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true
+ sudo apt-get remove --auto-remove android-sdk-platform-tools || true
+ sudo apt-get purge --auto-remove android-sdk-platform-tools || true
+ sudo rm -rf /usr/local/lib/android
+ sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true
+ sudo rm -rf /usr/share/dotnet
+ sudo apt-get remove -y '^mono-.*' || true
+ sudo apt-get remove -y '^ghc-.*' || true
+ sudo apt-get remove -y '.*jdk.*|.*jre.*' || true
+ sudo apt-get remove -y 'php.*' || true
+ sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true
+ sudo apt-get remove -y '^google-.*' || true
+ sudo apt-get remove -y azure-cli || true
+ sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true
+ sudo apt-get remove -y '^gfortran-.*' || true
+ sudo apt-get remove -y microsoft-edge-stable || true
+ sudo apt-get remove -y firefox || true
+ sudo apt-get remove -y powershell || true
+ sudo apt-get remove -y r-base-core || true
+ sudo apt-get autoremove -y
+ sudo apt-get clean
+ echo
+ echo "Listing top largest packages"
+ pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+ head -n 30 <<< "${pkgs}"
+ echo
+ sudo rm -rfv build || true
+ sudo rm -rf /usr/share/dotnet || true
+ sudo rm -rf /opt/ghc || true
+ sudo rm -rf "/usr/local/share/boost" || true
+ sudo rm -rf "$AGENT_TOOLSDIRECTORY" || true
+ df -h
+
+ - name: Set up QEMU
+ uses: docker/setup-qemu-action@master
+ with:
+ platforms: all
+
+ - name: Set up Docker Buildx
+ id: buildx
+ uses: docker/setup-buildx-action@master
+
+ - name: Checkout
+ uses: actions/checkout@v6
+
+ - name: Cache GRPC
+ uses: docker/build-push-action@v6
+ with:
+ builder: ${{ steps.buildx.outputs.name }}
+ # The build-args MUST be an EXACT match between the image cache and other workflow steps that want to use that cache.
+ # This means that even the MAKEFLAGS have to be an EXACT match.
+ # If the build-args are not an EXACT match, it will result in a cache miss, which will require GRPC to be built from scratch.
+ build-args: |
+ GRPC_BASE_IMAGE=${{ matrix.grpc-base-image }}
+ GRPC_MAKEFLAGS=--jobs=4 --output-sync=target
+ GRPC_VERSION=v1.65.0
+ context: .
+ file: ./Dockerfile
+ cache-to: type=gha,ignore-error=true
+ cache-from: type=gha
+ target: grpc
+ platforms: ${{ matrix.platforms }}
+ push: false
diff --git a/.github/workflows/generate_intel_image.yaml b/.github/workflows/generate_intel_image.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c417ceeb8dbde952d4d79a699c4d5afc4e2696db
--- /dev/null
+++ b/.github/workflows/generate_intel_image.yaml
@@ -0,0 +1,59 @@
+name: 'generate and publish intel docker caches'
+
+on:
+ workflow_dispatch:
+ push:
+ branches:
+ - master
+
+concurrency:
+ group: intel-cache-${{ github.head_ref || github.ref }}-${{ github.repository }}
+ cancel-in-progress: true
+
+jobs:
+ generate_caches:
+ strategy:
+ matrix:
+ include:
+ - base-image: intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04
+ runs-on: 'arc-runner-set'
+ platforms: 'linux/amd64'
+ runs-on: ${{matrix.runs-on}}
+ steps:
+ - name: Set up QEMU
+ uses: docker/setup-qemu-action@master
+ with:
+ platforms: all
+ - name: Login to DockerHub
+ if: github.event_name != 'pull_request'
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_PASSWORD }}
+
+ - name: Login to quay
+ if: github.event_name != 'pull_request'
+ uses: docker/login-action@v3
+ with:
+ registry: quay.io
+ username: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
+ password: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
+ - name: Set up Docker Buildx
+ id: buildx
+ uses: docker/setup-buildx-action@master
+
+ - name: Checkout
+ uses: actions/checkout@v6
+
+ - name: Cache Intel images
+ uses: docker/build-push-action@v6
+ with:
+ builder: ${{ steps.buildx.outputs.name }}
+ build-args: |
+ BASE_IMAGE=${{ matrix.base-image }}
+ context: .
+ file: ./Dockerfile
+ tags: quay.io/go-skynet/intel-oneapi-base:24.04
+ push: true
+ target: intel
+ platforms: ${{ matrix.platforms }}
diff --git a/.github/workflows/image-pr.yml b/.github/workflows/image-pr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..fe5236f1699a4986a63c6513bd7bef179a1e8e80
--- /dev/null
+++ b/.github/workflows/image-pr.yml
@@ -0,0 +1,95 @@
+---
+ name: 'build container images tests'
+
+ on:
+ pull_request:
+
+ concurrency:
+ group: ci-${{ github.head_ref || github.ref }}-${{ github.repository }}
+ cancel-in-progress: true
+
+ jobs:
+ image-build:
+ uses: ./.github/workflows/image_build.yml
+ with:
+ tag-latest: ${{ matrix.tag-latest }}
+ tag-suffix: ${{ matrix.tag-suffix }}
+ build-type: ${{ matrix.build-type }}
+ cuda-major-version: ${{ matrix.cuda-major-version }}
+ cuda-minor-version: ${{ matrix.cuda-minor-version }}
+ platforms: ${{ matrix.platforms }}
+ runs-on: ${{ matrix.runs-on }}
+ base-image: ${{ matrix.base-image }}
+ grpc-base-image: ${{ matrix.grpc-base-image }}
+ makeflags: ${{ matrix.makeflags }}
+ ubuntu-version: ${{ matrix.ubuntu-version }}
+ secrets:
+ dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
+ dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
+ quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
+ quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
+ strategy:
+ # Pushing with all jobs in parallel
+ # eats the bandwidth of all the nodes
+ max-parallel: ${{ github.event_name != 'pull_request' && 4 || 8 }}
+ fail-fast: false
+ matrix:
+ include:
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'false'
+ tag-suffix: '-gpu-nvidia-cuda-12'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ makeflags: "--jobs=3 --output-sync=target"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'false'
+ tag-suffix: '-gpu-nvidia-cuda-13'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:22.04"
+ makeflags: "--jobs=3 --output-sync=target"
+ ubuntu-version: '2404'
+ - build-type: 'hipblas'
+ platforms: 'linux/amd64'
+ tag-latest: 'false'
+ tag-suffix: '-hipblas'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ grpc-base-image: "ubuntu:24.04"
+ runs-on: 'ubuntu-latest'
+ makeflags: "--jobs=3 --output-sync=target"
+ ubuntu-version: '2404'
+ - build-type: 'sycl'
+ platforms: 'linux/amd64'
+ tag-latest: 'false'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ grpc-base-image: "ubuntu:24.04"
+ tag-suffix: 'sycl'
+ runs-on: 'ubuntu-latest'
+ makeflags: "--jobs=3 --output-sync=target"
+ ubuntu-version: '2404'
+ - build-type: 'vulkan'
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'false'
+ tag-suffix: '-vulkan-core'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ makeflags: "--jobs=4 --output-sync=target"
+ ubuntu-version: '2404'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ tag-latest: 'false'
+ tag-suffix: '-nvidia-l4t-arm64-cuda-13'
+ base-image: "ubuntu:24.04"
+ runs-on: 'ubuntu-24.04-arm'
+ makeflags: "--jobs=4 --output-sync=target"
+ skip-drivers: 'false'
+ ubuntu-version: '2404'
+
\ No newline at end of file
diff --git a/.github/workflows/image.yml b/.github/workflows/image.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ce571006e510e87b031c69d73403c80ba77ae406
--- /dev/null
+++ b/.github/workflows/image.yml
@@ -0,0 +1,187 @@
+---
+ name: 'build container images'
+
+ on:
+ push:
+ branches:
+ - master
+ tags:
+ - '*'
+
+ concurrency:
+ group: ci-${{ github.head_ref || github.ref }}-${{ github.repository }}
+ cancel-in-progress: true
+
+ jobs:
+ hipblas-jobs:
+ uses: ./.github/workflows/image_build.yml
+ with:
+ tag-latest: ${{ matrix.tag-latest }}
+ tag-suffix: ${{ matrix.tag-suffix }}
+ build-type: ${{ matrix.build-type }}
+ cuda-major-version: ${{ matrix.cuda-major-version }}
+ cuda-minor-version: ${{ matrix.cuda-minor-version }}
+ platforms: ${{ matrix.platforms }}
+ runs-on: ${{ matrix.runs-on }}
+ base-image: ${{ matrix.base-image }}
+ grpc-base-image: ${{ matrix.grpc-base-image }}
+ aio: ${{ matrix.aio }}
+ makeflags: ${{ matrix.makeflags }}
+ ubuntu-version: ${{ matrix.ubuntu-version }}
+ ubuntu-codename: ${{ matrix.ubuntu-codename }}
+ secrets:
+ dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
+ dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
+ quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
+ quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
+ strategy:
+ matrix:
+ include:
+ - build-type: 'hipblas'
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-hipblas'
+ base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+ grpc-base-image: "ubuntu:24.04"
+ runs-on: 'ubuntu-latest'
+ makeflags: "--jobs=3 --output-sync=target"
+ aio: "-aio-gpu-hipblas"
+ ubuntu-version: '2404'
+ ubuntu-codename: 'noble'
+
+ core-image-build:
+ uses: ./.github/workflows/image_build.yml
+ with:
+ tag-latest: ${{ matrix.tag-latest }}
+ tag-suffix: ${{ matrix.tag-suffix }}
+ build-type: ${{ matrix.build-type }}
+ cuda-major-version: ${{ matrix.cuda-major-version }}
+ cuda-minor-version: ${{ matrix.cuda-minor-version }}
+ platforms: ${{ matrix.platforms }}
+ runs-on: ${{ matrix.runs-on }}
+ aio: ${{ matrix.aio }}
+ base-image: ${{ matrix.base-image }}
+ grpc-base-image: ${{ matrix.grpc-base-image }}
+ makeflags: ${{ matrix.makeflags }}
+ skip-drivers: ${{ matrix.skip-drivers }}
+ ubuntu-version: ${{ matrix.ubuntu-version }}
+ ubuntu-codename: ${{ matrix.ubuntu-codename }}
+ secrets:
+ dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
+ dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
+ quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
+ quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
+ strategy:
+ #max-parallel: ${{ github.event_name != 'pull_request' && 2 || 4 }}
+ matrix:
+ include:
+ - build-type: ''
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: ''
+ base-image: "ubuntu:24.04"
+ runs-on: 'ubuntu-latest'
+ aio: "-aio-cpu"
+ makeflags: "--jobs=4 --output-sync=target"
+ skip-drivers: 'false'
+ ubuntu-version: '2404'
+ ubuntu-codename: 'noble'
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "9"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-12'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ makeflags: "--jobs=4 --output-sync=target"
+ aio: "-aio-gpu-nvidia-cuda-12"
+ ubuntu-version: '2404'
+ ubuntu-codename: 'noble'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-nvidia-cuda-13'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:22.04"
+ skip-drivers: 'false'
+ makeflags: "--jobs=4 --output-sync=target"
+ aio: "-aio-gpu-nvidia-cuda-13"
+ ubuntu-version: '2404'
+ ubuntu-codename: 'noble'
+ - build-type: 'vulkan'
+ platforms: 'linux/amd64,linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-gpu-vulkan'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'false'
+ makeflags: "--jobs=4 --output-sync=target"
+ aio: "-aio-gpu-vulkan"
+ ubuntu-version: '2404'
+ ubuntu-codename: 'noble'
+ - build-type: 'intel'
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
+ grpc-base-image: "ubuntu:24.04"
+ tag-suffix: '-gpu-intel'
+ runs-on: 'ubuntu-latest'
+ makeflags: "--jobs=3 --output-sync=target"
+ aio: "-aio-gpu-intel"
+ ubuntu-version: '2404'
+ ubuntu-codename: 'noble'
+
+ gh-runner:
+ uses: ./.github/workflows/image_build.yml
+ with:
+ tag-latest: ${{ matrix.tag-latest }}
+ tag-suffix: ${{ matrix.tag-suffix }}
+ build-type: ${{ matrix.build-type }}
+ cuda-major-version: ${{ matrix.cuda-major-version }}
+ cuda-minor-version: ${{ matrix.cuda-minor-version }}
+ platforms: ${{ matrix.platforms }}
+ runs-on: ${{ matrix.runs-on }}
+ aio: ${{ matrix.aio }}
+ base-image: ${{ matrix.base-image }}
+ grpc-base-image: ${{ matrix.grpc-base-image }}
+ makeflags: ${{ matrix.makeflags }}
+ skip-drivers: ${{ matrix.skip-drivers }}
+ ubuntu-version: ${{ matrix.ubuntu-version }}
+ ubuntu-codename: ${{ matrix.ubuntu-codename }}
+ secrets:
+ dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
+ dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
+ quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
+ quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
+ strategy:
+ matrix:
+ include:
+ - build-type: 'cublas'
+ cuda-major-version: "12"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-arm64'
+ base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
+ runs-on: 'ubuntu-24.04-arm'
+ makeflags: "--jobs=4 --output-sync=target"
+ skip-drivers: 'true'
+ ubuntu-version: "2204"
+ ubuntu-codename: 'jammy'
+ - build-type: 'cublas'
+ cuda-major-version: "13"
+ cuda-minor-version: "0"
+ platforms: 'linux/arm64'
+ tag-latest: 'auto'
+ tag-suffix: '-nvidia-l4t-arm64-cuda-13'
+ base-image: "ubuntu:24.04"
+ runs-on: 'ubuntu-24.04-arm'
+ makeflags: "--jobs=4 --output-sync=target"
+ skip-drivers: 'false'
+ ubuntu-version: '2404'
+ ubuntu-codename: 'noble'
+
\ No newline at end of file
diff --git a/.github/workflows/image_build.yml b/.github/workflows/image_build.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d72da8af03a4a75aee1841581c5fdd2847fe360f
--- /dev/null
+++ b/.github/workflows/image_build.yml
@@ -0,0 +1,327 @@
+---
+name: 'build container images (reusable)'
+
+on:
+ workflow_call:
+ inputs:
+ base-image:
+ description: 'Base image'
+ required: true
+ type: string
+ grpc-base-image:
+ description: 'GRPC Base image, must be a compatible image with base-image'
+ required: false
+ default: ''
+ type: string
+ build-type:
+ description: 'Build type'
+ default: ''
+ type: string
+ cuda-major-version:
+ description: 'CUDA major version'
+ default: "12"
+ type: string
+ cuda-minor-version:
+ description: 'CUDA minor version'
+ default: "9"
+ type: string
+ platforms:
+ description: 'Platforms'
+ default: ''
+ type: string
+ tag-latest:
+ description: 'Tag latest'
+ default: ''
+ type: string
+ tag-suffix:
+ description: 'Tag suffix'
+ default: ''
+ type: string
+ skip-drivers:
+ description: 'Skip drivers by default'
+ default: 'false'
+ type: string
+ runs-on:
+ description: 'Runs on'
+ required: true
+ default: ''
+ type: string
+ makeflags:
+ description: 'Make Flags'
+ required: false
+ default: '--jobs=4 --output-sync=target'
+ type: string
+ aio:
+ description: 'AIO Image Name'
+ required: false
+ default: ''
+ type: string
+ ubuntu-version:
+ description: 'Ubuntu version'
+ required: false
+ default: '2204'
+ type: string
+ ubuntu-codename:
+ description: 'Ubuntu codename'
+ required: false
+ default: 'noble'
+ type: string
+ secrets:
+ dockerUsername:
+ required: true
+ dockerPassword:
+ required: true
+ quayUsername:
+ required: true
+ quayPassword:
+ required: true
+jobs:
+ reusable_image-build:
+ runs-on: ${{ inputs.runs-on }}
+ steps:
+
+ - name: Free Disk Space (Ubuntu)
+ if: inputs.runs-on == 'ubuntu-latest'
+ uses: jlumbroso/free-disk-space@main
+ with:
+ # this might remove tools that are actually needed,
+ # if set to "true" but frees about 6 GB
+ tool-cache: true
+ # all of these default to true, but feel free to set to
+ # "false" if necessary for your workflow
+ android: true
+ dotnet: true
+ haskell: true
+ large-packages: true
+ docker-images: true
+ swap-storage: true
+ - name: Force Install GIT latest
+ run: |
+ sudo apt-get update \
+ && sudo apt-get install -y software-properties-common \
+ && sudo apt-get update \
+ && sudo add-apt-repository -y ppa:git-core/ppa \
+ && sudo apt-get update \
+ && sudo apt-get install -y git
+ - name: Checkout
+ uses: actions/checkout@v6
+
+ - name: Release space from worker
+ if: inputs.runs-on == 'ubuntu-latest'
+ run: |
+ echo "Listing top largest packages"
+ pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+ head -n 30 <<< "${pkgs}"
+ echo
+ df -h
+ echo
+ sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true
+ sudo apt-get remove --auto-remove android-sdk-platform-tools snapd || true
+ sudo apt-get purge --auto-remove android-sdk-platform-tools snapd || true
+ sudo rm -rf /usr/local/lib/android
+ sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true
+ sudo rm -rf /usr/share/dotnet
+ sudo apt-get remove -y '^mono-.*' || true
+ sudo apt-get remove -y '^ghc-.*' || true
+ sudo apt-get remove -y '.*jdk.*|.*jre.*' || true
+ sudo apt-get remove -y 'php.*' || true
+ sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true
+ sudo apt-get remove -y '^google-.*' || true
+ sudo apt-get remove -y azure-cli || true
+ sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true
+ sudo apt-get remove -y '^gfortran-.*' || true
+ sudo apt-get remove -y microsoft-edge-stable || true
+ sudo apt-get remove -y firefox || true
+ sudo apt-get remove -y powershell || true
+ sudo apt-get remove -y r-base-core || true
+ sudo apt-get autoremove -y
+ sudo apt-get clean
+ echo
+ echo "Listing top largest packages"
+ pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+ head -n 30 <<< "${pkgs}"
+ echo
+ sudo rm -rfv build || true
+ sudo rm -rf /usr/share/dotnet || true
+ sudo rm -rf /opt/ghc || true
+ sudo rm -rf "/usr/local/share/boost" || true
+ sudo rm -rf "$AGENT_TOOLSDIRECTORY" || true
+ df -h
+
+ - name: Docker meta
+ id: meta
+ if: github.event_name != 'pull_request'
+ uses: docker/metadata-action@v5
+ with:
+ images: |
+ quay.io/go-skynet/local-ai
+ localai/localai
+ tags: |
+ type=ref,event=branch
+ type=semver,pattern={{raw}}
+ type=sha
+ flavor: |
+ latest=${{ inputs.tag-latest }}
+ suffix=${{ inputs.tag-suffix }},onlatest=true
+ - name: Docker meta for PR
+ id: meta_pull_request
+ if: github.event_name == 'pull_request'
+ uses: docker/metadata-action@v5
+ with:
+ images: |
+ quay.io/go-skynet/ci-tests
+ tags: |
+ type=ref,event=branch,suffix=localai${{ github.event.number }}-${{ inputs.build-type }}-${{ inputs.cuda-major-version }}-${{ inputs.cuda-minor-version }}
+ type=semver,pattern={{raw}},suffix=localai${{ github.event.number }}-${{ inputs.build-type }}-${{ inputs.cuda-major-version }}-${{ inputs.cuda-minor-version }}
+ type=sha,suffix=localai${{ github.event.number }}-${{ inputs.build-type }}-${{ inputs.cuda-major-version }}-${{ inputs.cuda-minor-version }}
+ flavor: |
+ latest=${{ inputs.tag-latest }}
+ suffix=${{ inputs.tag-suffix }}
+ - name: Docker meta AIO (quay.io)
+ if: inputs.aio != ''
+ id: meta_aio
+ uses: docker/metadata-action@v5
+ with:
+ images: |
+ quay.io/go-skynet/local-ai
+ tags: |
+ type=ref,event=branch
+ type=semver,pattern={{raw}}
+ flavor: |
+ latest=${{ inputs.tag-latest }}
+ suffix=${{ inputs.aio }},onlatest=true
+
+ - name: Docker meta AIO (dockerhub)
+ if: inputs.aio != ''
+ id: meta_aio_dockerhub
+ uses: docker/metadata-action@v5
+ with:
+ images: |
+ localai/localai
+ tags: |
+ type=ref,event=branch
+ type=semver,pattern={{raw}}
+ flavor: |
+ latest=${{ inputs.tag-latest }}
+ suffix=${{ inputs.aio }},onlatest=true
+
+ - name: Set up QEMU
+ uses: docker/setup-qemu-action@master
+ with:
+ platforms: all
+
+ - name: Set up Docker Buildx
+ id: buildx
+ uses: docker/setup-buildx-action@master
+
+ - name: Login to DockerHub
+ if: github.event_name != 'pull_request'
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.dockerUsername }}
+ password: ${{ secrets.dockerPassword }}
+
+ - name: Login to DockerHub
+ if: github.event_name != 'pull_request'
+ uses: docker/login-action@v3
+ with:
+ registry: quay.io
+ username: ${{ secrets.quayUsername }}
+ password: ${{ secrets.quayPassword }}
+
+ - name: Build and push
+ uses: docker/build-push-action@v6
+ if: github.event_name != 'pull_request'
+ with:
+ builder: ${{ steps.buildx.outputs.name }}
+ # The build-args MUST be an EXACT match between the image cache and other workflow steps that want to use that cache.
+ # This means that even the MAKEFLAGS have to be an EXACT match.
+ # If the build-args are not an EXACT match, it will result in a cache miss, which will require GRPC to be built from scratch.
+ # This is why some build args like GRPC_VERSION and MAKEFLAGS are hardcoded
+ build-args: |
+ BUILD_TYPE=${{ inputs.build-type }}
+ CUDA_MAJOR_VERSION=${{ inputs.cuda-major-version }}
+ CUDA_MINOR_VERSION=${{ inputs.cuda-minor-version }}
+ BASE_IMAGE=${{ inputs.base-image }}
+ GRPC_BASE_IMAGE=${{ inputs.grpc-base-image || inputs.base-image }}
+ GRPC_MAKEFLAGS=--jobs=4 --output-sync=target
+ GRPC_VERSION=v1.65.0
+ MAKEFLAGS=${{ inputs.makeflags }}
+ SKIP_DRIVERS=${{ inputs.skip-drivers }}
+ UBUNTU_VERSION=${{ inputs.ubuntu-version }}
+ UBUNTU_CODENAME=${{ inputs.ubuntu-codename }}
+ context: .
+ file: ./Dockerfile
+ cache-from: type=gha
+ platforms: ${{ inputs.platforms }}
+ push: ${{ github.event_name != 'pull_request' }}
+ tags: ${{ steps.meta.outputs.tags }}
+ labels: ${{ steps.meta.outputs.labels }}
+### Start testing image
+ - name: Build and push
+ uses: docker/build-push-action@v6
+ if: github.event_name == 'pull_request'
+ with:
+ builder: ${{ steps.buildx.outputs.name }}
+ # The build-args MUST be an EXACT match between the image cache and other workflow steps that want to use that cache.
+ # This means that even the MAKEFLAGS have to be an EXACT match.
+ # If the build-args are not an EXACT match, it will result in a cache miss, which will require GRPC to be built from scratch.
+ # This is why some build args like GRPC_VERSION and MAKEFLAGS are hardcoded
+ build-args: |
+ BUILD_TYPE=${{ inputs.build-type }}
+ CUDA_MAJOR_VERSION=${{ inputs.cuda-major-version }}
+ CUDA_MINOR_VERSION=${{ inputs.cuda-minor-version }}
+ BASE_IMAGE=${{ inputs.base-image }}
+ GRPC_BASE_IMAGE=${{ inputs.grpc-base-image || inputs.base-image }}
+ GRPC_MAKEFLAGS=--jobs=4 --output-sync=target
+ GRPC_VERSION=v1.65.0
+ MAKEFLAGS=${{ inputs.makeflags }}
+ SKIP_DRIVERS=${{ inputs.skip-drivers }}
+ UBUNTU_VERSION=${{ inputs.ubuntu-version }}
+ UBUNTU_CODENAME=${{ inputs.ubuntu-codename }}
+ context: .
+ file: ./Dockerfile
+ cache-from: type=gha
+ platforms: ${{ inputs.platforms }}
+ #push: true
+ tags: ${{ steps.meta_pull_request.outputs.tags }}
+ labels: ${{ steps.meta_pull_request.outputs.labels }}
+## End testing image
+ - name: Build and push AIO image
+ if: inputs.aio != ''
+ uses: docker/build-push-action@v6
+ with:
+ builder: ${{ steps.buildx.outputs.name }}
+ build-args: |
+ BASE_IMAGE=quay.io/go-skynet/local-ai:${{ steps.meta.outputs.version }}
+ MAKEFLAGS=${{ inputs.makeflags }}
+ context: .
+ file: ./Dockerfile.aio
+ platforms: ${{ inputs.platforms }}
+ push: ${{ github.event_name != 'pull_request' }}
+ tags: ${{ steps.meta_aio.outputs.tags }}
+ labels: ${{ steps.meta_aio.outputs.labels }}
+
+ - name: Build and push AIO image (dockerhub)
+ if: inputs.aio != ''
+ uses: docker/build-push-action@v6
+ with:
+ builder: ${{ steps.buildx.outputs.name }}
+ build-args: |
+ BASE_IMAGE=localai/localai:${{ steps.meta.outputs.version }}
+ MAKEFLAGS=${{ inputs.makeflags }}
+ context: .
+ file: ./Dockerfile.aio
+ platforms: ${{ inputs.platforms }}
+ push: ${{ github.event_name != 'pull_request' }}
+ tags: ${{ steps.meta_aio_dockerhub.outputs.tags }}
+ labels: ${{ steps.meta_aio_dockerhub.outputs.labels }}
+
+ - name: job summary
+ run: |
+ echo "Built image: ${{ steps.meta.outputs.labels }}" >> $GITHUB_STEP_SUMMARY
+
+ - name: job summary(AIO)
+ if: inputs.aio != ''
+ run: |
+ echo "Built image: ${{ steps.meta_aio.outputs.labels }}" >> $GITHUB_STEP_SUMMARY
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3b787810fad2096296e0fcc16d6cc348b3af4cd4
--- /dev/null
+++ b/.github/workflows/labeler.yml
@@ -0,0 +1,12 @@
+name: "Pull Request Labeler"
+on:
+- pull_request_target
+
+jobs:
+ labeler:
+ permissions:
+ contents: read
+ pull-requests: write
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/labeler@v6
\ No newline at end of file
diff --git a/.github/workflows/localaibot_automerge.yml b/.github/workflows/localaibot_automerge.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a1513802b2f372b742c05ea38bb66003b8bb6895
--- /dev/null
+++ b/.github/workflows/localaibot_automerge.yml
@@ -0,0 +1,36 @@
+name: LocalAI-bot auto-merge
+on:
+- pull_request_target
+
+permissions:
+ contents: write
+ pull-requests: write
+ packages: read
+ issues: write # for Homebrew/actions/post-comment
+ actions: write # to dispatch publish workflow
+jobs:
+ dependabot:
+ runs-on: ubuntu-latest
+ if: ${{ github.actor == 'localai-bot' && !contains(github.event.pull_request.title, 'chore(model gallery):') }}
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v6
+
+ - name: Approve a PR if not already approved
+ run: |
+ gh pr checkout "$PR_URL"
+ if [ "$(gh pr status --json reviewDecision -q .currentBranch.reviewDecision)" != "APPROVED" ];
+ then
+ gh pr review --approve "$PR_URL"
+ else
+ echo "PR already approved.";
+ fi
+ env:
+ PR_URL: ${{github.event.pull_request.html_url}}
+ GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
+
+ - name: Enable auto-merge for LocalAIBot PRs
+ run: gh pr merge --auto --squash "$PR_URL"
+ env:
+ PR_URL: ${{github.event.pull_request.html_url}}
+ GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
diff --git a/.github/workflows/notify-models.yaml b/.github/workflows/notify-models.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2928fdaf4cc28cc22620329258a5152d85c1f3f6
--- /dev/null
+++ b/.github/workflows/notify-models.yaml
@@ -0,0 +1,174 @@
+name: Notifications for new models
+on:
+ pull_request_target:
+ types:
+ - closed
+
+permissions:
+ contents: read
+ pull-requests: read
+
+jobs:
+ notify-discord:
+ if: ${{ (github.event.pull_request.merged == true) && (contains(github.event.pull_request.labels.*.name, 'area/ai-model')) }}
+ env:
+ MODEL_NAME: gemma-3-12b-it-qat
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v6
+ with:
+ fetch-depth: 0 # needed to checkout all branches for this Action to work
+ ref: ${{ github.event.pull_request.head.sha }} # Checkout the PR head to get the actual changes
+ - uses: mudler/localai-github-action@v1
+ with:
+ model: 'gemma-3-12b-it-qat' # Any from models.localai.io, or from huggingface.com with: "huggingface:///file"
+ # Check the PR diff using the current branch and the base branch of the PR
+ - uses: GrantBirki/git-diff-action@v2.8.1
+ id: git-diff-action
+ with:
+ json_diff_file_output: diff.json
+ raw_diff_file_output: diff.txt
+ file_output_only: "true"
+ - name: Summarize
+ env:
+ DIFF: ${{ steps.git-diff-action.outputs.raw-diff-path }}
+ id: summarize
+ run: |
+ input="$(cat $DIFF)"
+
+ # Define the LocalAI API endpoint
+ API_URL="http://localhost:8080/chat/completions"
+
+ # Create a JSON payload using jq to handle special characters
+ json_payload=$(jq -n --arg input "$input" '{
+ model: "'$MODEL_NAME'",
+ messages: [
+ {
+ role: "system",
+ content: "You are LocalAI-bot. Write a discord message to notify everyone about the new model from the git diff. Make it informal. An example can include: the URL of the model, the name, and a brief description of the model if exists. Also add an hint on how to install it in LocalAI and that can be browsed over https://models.localai.io. For example: local-ai run model_name_here"
+ },
+ {
+ role: "user",
+ content: $input
+ }
+ ]
+ }')
+
+ # Send the request to LocalAI
+ response=$(curl -s -X POST $API_URL \
+ -H "Content-Type: application/json" \
+ -d "$json_payload")
+
+ # Extract the summary from the response
+ summary="$(echo $response | jq -r '.choices[0].message.content')"
+
+ # Print the summary
+ # -H "Authorization: Bearer $API_KEY" \
+ echo "Summary:"
+ echo "$summary"
+ echo "payload sent"
+ echo "$json_payload"
+ {
+ echo 'message<> "$GITHUB_OUTPUT"
+ docker logs --tail 10 local-ai
+ - name: Discord notification
+ env:
+ DISCORD_WEBHOOK: ${{ secrets.DISCORD_WEBHOOK_URL }}
+ DISCORD_USERNAME: "LocalAI-Bot"
+ DISCORD_AVATAR: "https://avatars.githubusercontent.com/u/139863280?v=4"
+ uses: Ilshidur/action-discord@master
+ with:
+ args: ${{ steps.summarize.outputs.message }}
+ - name: Setup tmate session if fails
+ if: ${{ failure() }}
+ uses: mxschmitt/action-tmate@v3.23
+ with:
+ detached: true
+ connect-timeout-seconds: 180
+ limit-access-to-actor: true
+ notify-twitter:
+ if: ${{ (github.event.pull_request.merged == true) && (contains(github.event.pull_request.labels.*.name, 'area/ai-model')) }}
+ env:
+ MODEL_NAME: gemma-3-12b-it-qat
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v6
+ with:
+ fetch-depth: 0 # needed to checkout all branches for this Action to work
+ ref: ${{ github.event.pull_request.head.sha }} # Checkout the PR head to get the actual changes
+ - name: Start LocalAI
+ run: |
+ echo "Starting LocalAI..."
+ docker run -e -ti -d --name local-ai -p 8080:8080 localai/localai:master run --debug $MODEL_NAME
+ until [ "`docker inspect -f {{.State.Health.Status}} local-ai`" == "healthy" ]; do echo "Waiting for container to be ready"; docker logs --tail 10 local-ai; sleep 2; done
+ # Check the PR diff using the current branch and the base branch of the PR
+ - uses: GrantBirki/git-diff-action@v2.8.1
+ id: git-diff-action
+ with:
+ json_diff_file_output: diff.json
+ raw_diff_file_output: diff.txt
+ file_output_only: "true"
+ - name: Summarize
+ env:
+ DIFF: ${{ steps.git-diff-action.outputs.raw-diff-path }}
+ id: summarize
+ run: |
+ input="$(cat $DIFF)"
+
+ # Define the LocalAI API endpoint
+ API_URL="http://localhost:8080/chat/completions"
+
+ # Create a JSON payload using jq to handle special characters
+ json_payload=$(jq -n --arg input "$input" '{
+ model: "'$MODEL_NAME'",
+ messages: [
+ {
+ role: "system",
+ content: "You are LocalAI-bot. Write a twitter message to notify everyone about the new model from the git diff. Make it informal and really short. An example can include: the name, and a brief description of the model if exists. Also add an hint on how to install it in LocalAI. For example: local-ai run model_name_here"
+ },
+ {
+ role: "user",
+ content: $input
+ }
+ ]
+ }')
+
+ # Send the request to LocalAI
+ response=$(curl -s -X POST $API_URL \
+ -H "Content-Type: application/json" \
+ -d "$json_payload")
+
+ # Extract the summary from the response
+ summary="$(echo $response | jq -r '.choices[0].message.content')"
+
+ # Print the summary
+ # -H "Authorization: Bearer $API_KEY" \
+ echo "Summary:"
+ echo "$summary"
+ echo "payload sent"
+ echo "$json_payload"
+ {
+ echo 'message<> "$GITHUB_OUTPUT"
+ docker logs --tail 10 local-ai
+ - uses: Eomm/why-don-t-you-tweet@v2
+ with:
+ tweet-message: ${{ steps.summarize.outputs.message }}
+ env:
+ # Get your tokens from https://developer.twitter.com/apps
+ TWITTER_CONSUMER_API_KEY: ${{ secrets.TWITTER_APP_KEY }}
+ TWITTER_CONSUMER_API_SECRET: ${{ secrets.TWITTER_APP_SECRET }}
+ TWITTER_ACCESS_TOKEN: ${{ secrets.TWITTER_ACCESS_TOKEN }}
+ TWITTER_ACCESS_TOKEN_SECRET: ${{ secrets.TWITTER_ACCESS_TOKEN_SECRET }}
+ - name: Setup tmate session if fails
+ if: ${{ failure() }}
+ uses: mxschmitt/action-tmate@v3.23
+ with:
+ detached: true
+ connect-timeout-seconds: 180
+ limit-access-to-actor: true
diff --git a/.github/workflows/notify-releases.yaml b/.github/workflows/notify-releases.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b7c6bf847f49e1f46189fea289d5e80cb5ba699d
--- /dev/null
+++ b/.github/workflows/notify-releases.yaml
@@ -0,0 +1,64 @@
+name: Release notifications
+on:
+ release:
+ types:
+ - published
+
+jobs:
+ notify-discord:
+ runs-on: ubuntu-latest
+ env:
+ RELEASE_BODY: ${{ github.event.release.body }}
+ RELEASE_TITLE: ${{ github.event.release.name }}
+ RELEASE_TAG_NAME: ${{ github.event.release.tag_name }}
+ MODEL_NAME: gemma-3-12b-it-qat
+ steps:
+ - uses: mudler/localai-github-action@v1
+ with:
+ model: 'gemma-3-12b-it-qat' # Any from models.localai.io, or from huggingface.com with: "huggingface:///file"
+ - name: Summarize
+ id: summarize
+ run: |
+ input="$RELEASE_TITLE\b$RELEASE_BODY"
+
+ # Define the LocalAI API endpoint
+ API_URL="http://localhost:8080/chat/completions"
+
+ # Create a JSON payload using jq to handle special characters
+ json_payload=$(jq -n --arg input "$input" '{
+ model: "'$MODEL_NAME'",
+ messages: [
+ {
+ role: "system",
+ content: "Write a discord message with a bullet point summary of the release notes."
+ },
+ {
+ role: "user",
+ content: $input
+ }
+ ]
+ }')
+
+ # Send the request to LocalAI API
+ response=$(curl -s -X POST $API_URL \
+ -H "Content-Type: application/json" \
+ -d "$json_payload")
+
+ # Extract the summary from the response
+ summary=$(echo $response | jq -r '.choices[0].message.content')
+
+ # Print the summary
+ # -H "Authorization: Bearer $API_KEY" \
+ {
+ echo 'message<> "$GITHUB_OUTPUT"
+ - name: Discord notification
+ env:
+ DISCORD_WEBHOOK: ${{ secrets.DISCORD_WEBHOOK_URL_RELEASE }}
+ DISCORD_USERNAME: "LocalAI-Bot"
+ DISCORD_AVATAR: "https://avatars.githubusercontent.com/u/139863280?v=4"
+ uses: Ilshidur/action-discord@master
+ with:
+ args: ${{ steps.summarize.outputs.message }}
diff --git a/.github/workflows/prlint.yaml b/.github/workflows/prlint.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..66f338e4778b9b88182f0b2d10b52927a5a048c0
--- /dev/null
+++ b/.github/workflows/prlint.yaml
@@ -0,0 +1,28 @@
+name: Check PR style
+
+on:
+ pull_request_target:
+ types:
+ - opened
+ - reopened
+ - edited
+ - synchronize
+
+jobs:
+ title-lint:
+ runs-on: ubuntu-latest
+ permissions:
+ statuses: write
+ steps:
+ - uses: aslafy-z/conventional-pr-title-action@v3
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+# check-pr-description:
+# runs-on: ubuntu-latest
+# steps:
+# - uses: actions/checkout@v2
+# - uses: jadrol/pr-description-checker-action@v1.0.0
+# id: description-checker
+# with:
+# repo-token: ${{ secrets.GITHUB_TOKEN }}
+# exempt-labels: no qa
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..104b1beb96a7e693b206190cec34d4b59d16c97a
--- /dev/null
+++ b/.github/workflows/release.yaml
@@ -0,0 +1,64 @@
+name: goreleaser
+
+on:
+ push:
+ tags:
+ - 'v*'
+
+jobs:
+ goreleaser:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v6
+ with:
+ fetch-depth: 0
+ - name: Set up Go
+ uses: actions/setup-go@v5
+ with:
+ go-version: 1.23
+ - name: Run GoReleaser
+ uses: goreleaser/goreleaser-action@v6
+ with:
+ version: v2.11.0
+ args: release --clean
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ launcher-build-darwin:
+ runs-on: macos-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v6
+ with:
+ fetch-depth: 0
+ - name: Set up Go
+ uses: actions/setup-go@v5
+ with:
+ go-version: 1.23
+ - name: Build launcher for macOS ARM64
+ run: |
+ make build-launcher-darwin
+ - name: Upload DMG to Release
+ uses: softprops/action-gh-release@v2
+ with:
+ files: ./dist/LocalAI.dmg
+ launcher-build-linux:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v6
+ with:
+ fetch-depth: 0
+ - name: Set up Go
+ uses: actions/setup-go@v5
+ with:
+ go-version: 1.23
+ - name: Build launcher for Linux
+ run: |
+ sudo apt-get update
+ sudo apt-get install golang gcc libgl1-mesa-dev xorg-dev libxkbcommon-dev
+ make build-launcher-linux
+ - name: Upload Linux launcher artifacts
+ uses: softprops/action-gh-release@v2
+ with:
+ files: ./local-ai-launcher-linux.tar.xz
diff --git a/.github/workflows/secscan.yaml b/.github/workflows/secscan.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2df9190e018b6f22d8f9a9e404600970baad0230
--- /dev/null
+++ b/.github/workflows/secscan.yaml
@@ -0,0 +1,30 @@
+name: "Security Scan"
+
+# Run workflow each time code is pushed to your repository and on a schedule.
+# The scheduled workflow runs every at 00:00 on Sunday UTC time.
+on:
+ push:
+ schedule:
+ - cron: '0 0 * * 0'
+
+jobs:
+ tests:
+ runs-on: ubuntu-latest
+ env:
+ GO111MODULE: on
+ steps:
+ - name: Checkout Source
+ uses: actions/checkout@v6
+ if: ${{ github.actor != 'dependabot[bot]' }}
+ - name: Run Gosec Security Scanner
+ if: ${{ github.actor != 'dependabot[bot]' }}
+ uses: securego/gosec@v2.22.9
+ with:
+ # we let the report trigger content trigger a failure using the GitHub Security features.
+ args: '-no-fail -fmt sarif -out results.sarif ./...'
+ - name: Upload SARIF file
+ if: ${{ github.actor != 'dependabot[bot]' }}
+ uses: github/codeql-action/upload-sarif@v4
+ with:
+ # Path to SARIF file relative to the root of the repository
+ sarif_file: results.sarif
diff --git a/.github/workflows/stalebot.yml b/.github/workflows/stalebot.yml
new file mode 100644
index 0000000000000000000000000000000000000000..07407fbb00901b03e96d56138c24571ff87c1935
--- /dev/null
+++ b/.github/workflows/stalebot.yml
@@ -0,0 +1,24 @@
+name: 'Close stale issues and PRs'
+permissions:
+ issues: write
+ pull-requests: write
+on:
+ schedule:
+ - cron: '30 1 * * *'
+
+jobs:
+ stale:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v9
+ with:
+ stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 5 days.'
+ stale-pr-message: 'This PR is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 10 days.'
+ close-issue-message: 'This issue was closed because it has been stalled for 5 days with no activity.'
+ close-pr-message: 'This PR was closed because it has been stalled for 10 days with no activity.'
+ days-before-issue-stale: 90
+ days-before-pr-stale: 90
+ days-before-issue-close: 5
+ days-before-pr-close: 10
+ exempt-issue-labels: 'roadmap'
+ exempt-pr-labels: 'roadmap'
diff --git a/.github/workflows/test-extra.yml b/.github/workflows/test-extra.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0d01cde73e37c9cbd8e46c6a2c5701bd6c329f19
--- /dev/null
+++ b/.github/workflows/test-extra.yml
@@ -0,0 +1,287 @@
+---
+name: 'Tests extras backends'
+
+on:
+ pull_request:
+ push:
+ branches:
+ - master
+ tags:
+ - '*'
+
+concurrency:
+ group: ci-tests-extra-${{ github.head_ref || github.ref }}-${{ github.repository }}
+ cancel-in-progress: true
+
+jobs:
+ # Requires CUDA
+ # tests-chatterbox-tts:
+ # runs-on: ubuntu-latest
+ # steps:
+ # - name: Clone
+ # uses: actions/checkout@v6
+ # with:
+ # submodules: true
+ # - name: Dependencies
+ # run: |
+ # sudo apt-get update
+ # sudo apt-get install build-essential ffmpeg
+ # # Install UV
+ # curl -LsSf https://astral.sh/uv/install.sh | sh
+ # sudo apt-get install -y ca-certificates cmake curl patch python3-pip
+ # sudo apt-get install -y libopencv-dev
+ # pip install --user --no-cache-dir grpcio-tools==1.64.1
+
+ # - name: Test chatterbox-tts
+ # run: |
+ # make --jobs=5 --output-sync=target -C backend/python/chatterbox
+ # make --jobs=5 --output-sync=target -C backend/python/chatterbox test
+ tests-transformers:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Clone
+ uses: actions/checkout@v6
+ with:
+ submodules: true
+ - name: Dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install build-essential ffmpeg
+ # Install UV
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ sudo apt-get install -y ca-certificates cmake curl patch python3-pip
+ sudo apt-get install -y libopencv-dev
+ pip install --user --no-cache-dir grpcio-tools==1.64.1
+
+ - name: Test transformers
+ run: |
+ make --jobs=5 --output-sync=target -C backend/python/transformers
+ make --jobs=5 --output-sync=target -C backend/python/transformers test
+ tests-rerankers:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Clone
+ uses: actions/checkout@v6
+ with:
+ submodules: true
+ - name: Dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install build-essential ffmpeg
+ # Install UV
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ sudo apt-get install -y ca-certificates cmake curl patch python3-pip
+ sudo apt-get install -y libopencv-dev
+ pip install --user --no-cache-dir grpcio-tools==1.64.1
+
+ - name: Test rerankers
+ run: |
+ make --jobs=5 --output-sync=target -C backend/python/rerankers
+ make --jobs=5 --output-sync=target -C backend/python/rerankers test
+
+ tests-diffusers:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Clone
+ uses: actions/checkout@v6
+ with:
+ submodules: true
+ - name: Dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install -y build-essential ffmpeg
+ sudo apt-get install -y ca-certificates cmake curl patch python3-pip
+ sudo apt-get install -y libopencv-dev
+ # Install UV
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ pip install --user --no-cache-dir grpcio-tools==1.64.1
+ - name: Test diffusers
+ run: |
+ make --jobs=5 --output-sync=target -C backend/python/diffusers
+ make --jobs=5 --output-sync=target -C backend/python/diffusers test
+
+ #tests-vllm:
+ # runs-on: ubuntu-latest
+ # steps:
+ # - name: Clone
+ # uses: actions/checkout@v6
+ # with:
+ # submodules: true
+ # - name: Dependencies
+ # run: |
+ # sudo apt-get update
+ # sudo apt-get install -y build-essential ffmpeg
+ # sudo apt-get install -y ca-certificates cmake curl patch python3-pip
+ # sudo apt-get install -y libopencv-dev
+ # # Install UV
+ # curl -LsSf https://astral.sh/uv/install.sh | sh
+ # pip install --user --no-cache-dir grpcio-tools==1.64.1
+ # - name: Test vllm backend
+ # run: |
+ # make --jobs=5 --output-sync=target -C backend/python/vllm
+ # make --jobs=5 --output-sync=target -C backend/python/vllm test
+ # tests-transformers-musicgen:
+ # runs-on: ubuntu-latest
+ # steps:
+ # - name: Clone
+ # uses: actions/checkout@v6
+ # with:
+ # submodules: true
+ # - name: Dependencies
+ # run: |
+ # sudo apt-get update
+ # sudo apt-get install build-essential ffmpeg
+ # # Install UV
+ # curl -LsSf https://astral.sh/uv/install.sh | sh
+ # sudo apt-get install -y ca-certificates cmake curl patch python3-pip
+ # sudo apt-get install -y libopencv-dev
+ # pip install --user --no-cache-dir grpcio-tools==1.64.1
+
+ # - name: Test transformers-musicgen
+ # run: |
+ # make --jobs=5 --output-sync=target -C backend/python/transformers-musicgen
+ # make --jobs=5 --output-sync=target -C backend/python/transformers-musicgen test
+
+ # tests-bark:
+ # runs-on: ubuntu-latest
+ # steps:
+ # - name: Release space from worker
+ # run: |
+ # echo "Listing top largest packages"
+ # pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+ # head -n 30 <<< "${pkgs}"
+ # echo
+ # df -h
+ # echo
+ # sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true
+ # sudo apt-get remove --auto-remove android-sdk-platform-tools || true
+ # sudo apt-get purge --auto-remove android-sdk-platform-tools || true
+ # sudo rm -rf /usr/local/lib/android
+ # sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true
+ # sudo rm -rf /usr/share/dotnet
+ # sudo apt-get remove -y '^mono-.*' || true
+ # sudo apt-get remove -y '^ghc-.*' || true
+ # sudo apt-get remove -y '.*jdk.*|.*jre.*' || true
+ # sudo apt-get remove -y 'php.*' || true
+ # sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true
+ # sudo apt-get remove -y '^google-.*' || true
+ # sudo apt-get remove -y azure-cli || true
+ # sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true
+ # sudo apt-get remove -y '^gfortran-.*' || true
+ # sudo apt-get remove -y microsoft-edge-stable || true
+ # sudo apt-get remove -y firefox || true
+ # sudo apt-get remove -y powershell || true
+ # sudo apt-get remove -y r-base-core || true
+ # sudo apt-get autoremove -y
+ # sudo apt-get clean
+ # echo
+ # echo "Listing top largest packages"
+ # pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+ # head -n 30 <<< "${pkgs}"
+ # echo
+ # sudo rm -rfv build || true
+ # sudo rm -rf /usr/share/dotnet || true
+ # sudo rm -rf /opt/ghc || true
+ # sudo rm -rf "/usr/local/share/boost" || true
+ # sudo rm -rf "$AGENT_TOOLSDIRECTORY" || true
+ # df -h
+ # - name: Clone
+ # uses: actions/checkout@v6
+ # with:
+ # submodules: true
+ # - name: Dependencies
+ # run: |
+ # sudo apt-get update
+ # sudo apt-get install build-essential ffmpeg
+ # # Install UV
+ # curl -LsSf https://astral.sh/uv/install.sh | sh
+ # sudo apt-get install -y ca-certificates cmake curl patch python3-pip
+ # sudo apt-get install -y libopencv-dev
+ # pip install --user --no-cache-dir grpcio-tools==1.64.1
+
+ # - name: Test bark
+ # run: |
+ # make --jobs=5 --output-sync=target -C backend/python/bark
+ # make --jobs=5 --output-sync=target -C backend/python/bark test
+
+
+ # Below tests needs GPU. Commented out for now
+ # TODO: Re-enable as soon as we have GPU nodes
+ # tests-vllm:
+ # runs-on: ubuntu-latest
+ # steps:
+ # - name: Clone
+ # uses: actions/checkout@v6
+ # with:
+ # submodules: true
+ # - name: Dependencies
+ # run: |
+ # sudo apt-get update
+ # sudo apt-get install build-essential ffmpeg
+ # # Install UV
+ # curl -LsSf https://astral.sh/uv/install.sh | sh
+ # sudo apt-get install -y ca-certificates cmake curl patch python3-pip
+ # sudo apt-get install -y libopencv-dev
+ # pip install --user --no-cache-dir grpcio-tools==1.64.1
+ # - name: Test vllm
+ # run: |
+ # make --jobs=5 --output-sync=target -C backend/python/vllm
+ # make --jobs=5 --output-sync=target -C backend/python/vllm test
+
+ tests-coqui:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Clone
+ uses: actions/checkout@v6
+ with:
+ submodules: true
+ - name: Dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install build-essential ffmpeg
+ sudo apt-get install -y ca-certificates cmake curl patch espeak espeak-ng python3-pip
+ # Install UV
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ pip install --user --no-cache-dir grpcio-tools==1.64.1
+ - name: Test coqui
+ run: |
+ make --jobs=5 --output-sync=target -C backend/python/coqui
+ make --jobs=5 --output-sync=target -C backend/python/coqui test
+ tests-moonshine:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Clone
+ uses: actions/checkout@v6
+ with:
+ submodules: true
+ - name: Dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install build-essential ffmpeg
+ sudo apt-get install -y ca-certificates cmake curl patch python3-pip
+ # Install UV
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ pip install --user --no-cache-dir grpcio-tools==1.64.1
+ - name: Test moonshine
+ run: |
+ make --jobs=5 --output-sync=target -C backend/python/moonshine
+ make --jobs=5 --output-sync=target -C backend/python/moonshine test
+ tests-pocket-tts:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Clone
+ uses: actions/checkout@v6
+ with:
+ submodules: true
+ - name: Dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install build-essential ffmpeg
+ sudo apt-get install -y ca-certificates cmake curl patch python3-pip
+ # Install UV
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ pip install --user --no-cache-dir grpcio-tools==1.64.1
+ - name: Test pocket-tts
+ run: |
+ make --jobs=5 --output-sync=target -C backend/python/pocket-tts
+ make --jobs=5 --output-sync=target -C backend/python/pocket-tts test
\ No newline at end of file
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e54f3003e940c80c796242492f7860c5714c90da
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,228 @@
+---
+name: 'tests'
+
+on:
+ pull_request:
+ push:
+ branches:
+ - master
+ tags:
+ - '*'
+
+env:
+ GRPC_VERSION: v1.65.0
+
+concurrency:
+ group: ci-tests-${{ github.head_ref || github.ref }}-${{ github.repository }}
+ cancel-in-progress: true
+
+jobs:
+ tests-linux:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ go-version: ['1.25.x']
+ steps:
+ - name: Free Disk Space (Ubuntu)
+ uses: jlumbroso/free-disk-space@main
+ with:
+ # this might remove tools that are actually needed,
+ # if set to "true" but frees about 6 GB
+ tool-cache: true
+ # all of these default to true, but feel free to set to
+ # "false" if necessary for your workflow
+ android: true
+ dotnet: true
+ haskell: true
+ large-packages: true
+ docker-images: true
+ swap-storage: true
+ - name: Release space from worker
+ run: |
+ echo "Listing top largest packages"
+ pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+ head -n 30 <<< "${pkgs}"
+ echo
+ df -h
+ echo
+ sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true
+ sudo apt-get remove --auto-remove android-sdk-platform-tools || true
+ sudo apt-get purge --auto-remove android-sdk-platform-tools || true
+ sudo rm -rf /usr/local/lib/android
+ sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true
+ sudo rm -rf /usr/share/dotnet
+ sudo apt-get remove -y '^mono-.*' || true
+ sudo apt-get remove -y '^ghc-.*' || true
+ sudo apt-get remove -y '.*jdk.*|.*jre.*' || true
+ sudo apt-get remove -y 'php.*' || true
+ sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true
+ sudo apt-get remove -y '^google-.*' || true
+ sudo apt-get remove -y azure-cli || true
+ sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true
+ sudo apt-get remove -y '^gfortran-.*' || true
+ sudo apt-get autoremove -y
+ sudo apt-get clean
+ echo
+ echo "Listing top largest packages"
+ pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+ head -n 30 <<< "${pkgs}"
+ echo
+ sudo rm -rfv build || true
+ df -h
+ - name: Clone
+ uses: actions/checkout@v6
+ with:
+ submodules: true
+ - name: Setup Go ${{ matrix.go-version }}
+ uses: actions/setup-go@v5
+ with:
+ go-version: ${{ matrix.go-version }}
+ cache: false
+ # You can test your matrix by printing the current Go version
+ - name: Display Go version
+ run: go version
+ - name: Proto Dependencies
+ run: |
+ # Install protoc
+ curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
+ unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
+ rm protoc.zip
+ go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
+ go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
+ PATH="$PATH:$HOME/go/bin" make protogen-go
+ - name: Dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install build-essential ccache upx-ucl curl ffmpeg
+ sudo apt-get install -y libgmock-dev clang
+ # Install UV
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ sudo apt-get install -y ca-certificates cmake patch python3-pip unzip
+ sudo apt-get install -y libopencv-dev
+
+ curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
+ unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
+ rm protoc.zip
+
+ curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
+ sudo dpkg -i cuda-keyring_1.1-1_all.deb
+ sudo apt-get update
+ sudo apt-get install -y cuda-nvcc-${CUDA_VERSION} libcublas-dev-${CUDA_VERSION}
+ export CUDACXX=/usr/local/cuda/bin/nvcc
+ make -C backend/python/transformers
+
+ make backends/huggingface backends/llama-cpp backends/local-store backends/silero-vad backends/piper backends/whisper backends/stablediffusion-ggml
+ env:
+ CUDA_VERSION: 12-4
+ - name: Test
+ run: |
+ PATH="$PATH:/root/go/bin" GO_TAGS="tts" make --jobs 5 --output-sync=target test
+ - name: Setup tmate session if tests fail
+ if: ${{ failure() }}
+ uses: mxschmitt/action-tmate@v3.23
+ with:
+ detached: true
+ connect-timeout-seconds: 180
+ limit-access-to-actor: true
+
+ tests-aio-container:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Release space from worker
+ run: |
+ echo "Listing top largest packages"
+ pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+ head -n 30 <<< "${pkgs}"
+ echo
+ df -h
+ echo
+ sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true
+ sudo apt-get remove --auto-remove android-sdk-platform-tools || true
+ sudo apt-get purge --auto-remove android-sdk-platform-tools || true
+ sudo rm -rf /usr/local/lib/android
+ sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true
+ sudo rm -rf /usr/share/dotnet
+ sudo apt-get remove -y '^mono-.*' || true
+ sudo apt-get remove -y '^ghc-.*' || true
+ sudo apt-get remove -y '.*jdk.*|.*jre.*' || true
+ sudo apt-get remove -y 'php.*' || true
+ sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true
+ sudo apt-get remove -y '^google-.*' || true
+ sudo apt-get remove -y azure-cli || true
+ sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true
+ sudo apt-get remove -y '^gfortran-.*' || true
+ sudo apt-get autoremove -y
+ sudo apt-get clean
+ echo
+ echo "Listing top largest packages"
+ pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+ head -n 30 <<< "${pkgs}"
+ echo
+ sudo rm -rfv build || true
+ df -h
+ - name: Clone
+ uses: actions/checkout@v6
+ with:
+ submodules: true
+ - name: Dependencies
+ run: |
+ # Install protoc
+ curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
+ unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
+ rm protoc.zip
+ go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
+ go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
+ PATH="$PATH:$HOME/go/bin" make protogen-go
+ - name: Test
+ run: |
+ PATH="$PATH:$HOME/go/bin" make backends/local-store backends/silero-vad backends/llama-cpp backends/whisper backends/piper backends/stablediffusion-ggml docker-build-aio e2e-aio
+ - name: Setup tmate session if tests fail
+ if: ${{ failure() }}
+ uses: mxschmitt/action-tmate@v3.23
+ with:
+ detached: true
+ connect-timeout-seconds: 180
+ limit-access-to-actor: true
+
+ tests-apple:
+ runs-on: macos-latest
+ strategy:
+ matrix:
+ go-version: ['1.25.x']
+ steps:
+ - name: Clone
+ uses: actions/checkout@v6
+ with:
+ submodules: true
+ - name: Setup Go ${{ matrix.go-version }}
+ uses: actions/setup-go@v5
+ with:
+ go-version: ${{ matrix.go-version }}
+ cache: false
+ # You can test your matrix by printing the current Go version
+ - name: Display Go version
+ run: go version
+ - name: Dependencies
+ run: |
+ brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm
+ pip install --user --no-cache-dir grpcio-tools grpcio
+ - name: Build llama-cpp-darwin
+ run: |
+ make protogen-go
+ make backends/llama-cpp-darwin
+ - name: Test
+ run: |
+ export C_INCLUDE_PATH=/usr/local/include
+ export CPLUS_INCLUDE_PATH=/usr/local/include
+ export CC=/opt/homebrew/opt/llvm/bin/clang
+ # Used to run the newer GNUMake version from brew that supports --output-sync
+ export PATH="/opt/homebrew/opt/make/libexec/gnubin:$PATH"
+ PATH="$PATH:$HOME/go/bin" make protogen-go
+ PATH="$PATH:$HOME/go/bin" BUILD_TYPE="GITHUB_CI_HAS_BROKEN_METAL" CMAKE_ARGS="-DGGML_F16C=OFF -DGGML_AVX512=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF" make --jobs 4 --output-sync=target test
+ - name: Setup tmate session if tests fail
+ if: ${{ failure() }}
+ uses: mxschmitt/action-tmate@v3.23
+ with:
+ detached: true
+ connect-timeout-seconds: 180
+ limit-access-to-actor: true
diff --git a/.github/workflows/update_swagger.yaml b/.github/workflows/update_swagger.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b0ca6455afc179cd7f5a2ece34ed8f3eb464e971
--- /dev/null
+++ b/.github/workflows/update_swagger.yaml
@@ -0,0 +1,37 @@
+name: Update swagger
+on:
+ schedule:
+ - cron: 0 20 * * *
+ workflow_dispatch:
+jobs:
+ swagger:
+ strategy:
+ fail-fast: false
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v6
+ - uses: actions/setup-go@v5
+ with:
+ go-version: 'stable'
+ - name: Dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install protobuf-compiler
+ - run: |
+ go install github.com/swaggo/swag/cmd/swag@latest
+ go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
+ go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
+ - name: Bump swagger 🔧
+ run: |
+ make protogen-go swagger
+ - name: Create Pull Request
+ uses: peter-evans/create-pull-request@v8
+ with:
+ token: ${{ secrets.UPDATE_BOT_TOKEN }}
+ push-to-fork: ci-forks/LocalAI
+ commit-message: 'feat(swagger): update swagger'
+ title: 'feat(swagger): update swagger'
+ branch: "update/swagger"
+ body: Update swagger
+ signoff: true
+
diff --git a/.github/workflows/yaml-check.yml b/.github/workflows/yaml-check.yml
new file mode 100644
index 0000000000000000000000000000000000000000..4a5689e2c6b729be3891ebd46688cae34fa3ac16
--- /dev/null
+++ b/.github/workflows/yaml-check.yml
@@ -0,0 +1,26 @@
+name: 'Yamllint GitHub Actions'
+on:
+ - pull_request
+jobs:
+ yamllint:
+ name: 'Yamllint'
+ runs-on: ubuntu-latest
+ steps:
+ - name: 'Checkout'
+ uses: actions/checkout@master
+ - name: 'Yamllint model gallery'
+ uses: karancode/yamllint-github-action@master
+ with:
+ yamllint_file_or_dir: 'gallery'
+ yamllint_strict: false
+ yamllint_comment: true
+ env:
+ GITHUB_ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ - name: 'Yamllint Backend gallery'
+ uses: karancode/yamllint-github-action@master
+ with:
+ yamllint_file_or_dir: 'backend'
+ yamllint_strict: false
+ yamllint_comment: true
+ env:
+ GITHUB_ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..2ee2ab8588b1b98e88f9399b995cc0ba94615f26
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,65 @@
+# go-llama build artifacts
+/sources/
+__pycache__/
+*.a
+*.o
+get-sources
+prepare-sources
+/backend/cpp/llama-cpp/grpc-server
+/backend/cpp/llama-cpp/llama.cpp
+/backend/cpp/llama-*
+!backend/cpp/llama-cpp
+/backends
+/backend-images
+/result.yaml
+protoc
+
+*.log
+
+go-ggml-transformers
+go-gpt2
+whisper.cpp
+/bloomz
+go-bert
+
+# LocalAI build binary
+LocalAI
+/local-ai
+/local-ai-launcher
+# prevent above rules from omitting the helm chart
+!charts/*
+# prevent above rules from omitting the api/localai folder
+!api/localai
+!core/**/localai
+
+# Ignore models
+models/*
+test-models/
+test-dir/
+
+release/
+
+# just in case
+.DS_Store
+.idea
+
+# Generated during build
+backend-assets/*
+!backend-assets/.keep
+prepare
+/ggml-metal.metal
+docs/static/gallery.html
+
+# Protobuf generated files
+*.pb.go
+*pb2.py
+*pb2_grpc.py
+
+# SonarQube
+.scannerwork
+
+# backend virtual environments
+**/venv
+
+# per-developer customization files for the development container
+.devcontainer/customization/*
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..c263dbe06f80c58dd849541d88743ca91f0bd7e2
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "docs/themes/hugo-theme-relearn"]
+ path = docs/themes/hugo-theme-relearn
+ url = https://github.com/McShelby/hugo-theme-relearn.git
diff --git a/.goreleaser.yaml b/.goreleaser.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5c5bf9987fe8909289b995eac2febf275b58c793
--- /dev/null
+++ b/.goreleaser.yaml
@@ -0,0 +1,36 @@
+version: 2
+before:
+ hooks:
+ - make protogen-go
+ - go mod tidy
+dist: release
+source:
+ enabled: true
+ name_template: '{{ .ProjectName }}-{{ .Tag }}-source'
+builds:
+ - main: ./cmd/local-ai
+ env:
+ - CGO_ENABLED=0
+ ldflags:
+ - -s -w
+ - -X "github.com/mudler/LocalAI/internal.Version={{ .Tag }}"
+ - -X "github.com/mudler/LocalAI/internal.Commit={{ .FullCommit }}"
+ goos:
+ - linux
+ - darwin
+ #- windows
+ goarch:
+ - amd64
+ - arm64
+ ignore:
+ - goos: darwin
+ goarch: amd64
+archives:
+ - formats: [ 'binary' ] # this removes the tar of the archives, leaving the binaries alone
+ name_template: local-ai-{{ .Tag }}-{{ .Os }}-{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}
+checksum:
+ name_template: '{{ .ProjectName }}-{{ .Tag }}-checksums.txt'
+snapshot:
+ version_template: "{{ .Tag }}-next"
+changelog:
+ use: github-native
diff --git a/.vscode/extensions.json b/.vscode/extensions.json
new file mode 100644
index 0000000000000000000000000000000000000000..7203cb3f17c6038d175f2276f1fa943df4ad6034
--- /dev/null
+++ b/.vscode/extensions.json
@@ -0,0 +1,5 @@
+{
+ "recommendations": [
+ "golang.go"
+ ]
+}
\ No newline at end of file
diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000000000000000000000000000000000000..55da767b41a6d6fd6742badcae735cd048bfa5c5
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,34 @@
+{
+ "version": "0.2.0",
+ "configurations": [
+ {
+ "name": "Python: Current File",
+ "type": "debugpy",
+ "request": "launch",
+ "program": "${file}",
+ "console": "integratedTerminal",
+ "justMyCode": false,
+ "cwd": "${fileDirname}",
+ "env": {
+ "OPENAI_API_BASE": "http://localhost:8080/v1",
+ "OPENAI_API_KEY": "abc"
+ }
+ },
+ {
+ "name": "Launch LocalAI API",
+ "type": "go",
+ "request": "launch",
+ "mode": "debug",
+ "program": "${workspaceRoot}",
+ "args": [],
+ "env": {
+ "LOCALAI_LOG_LEVEL": "debug",
+ "LOCALAI_P2P": "true",
+ "LOCALAI_FEDERATED": "true"
+ },
+ "buildFlags": ["-tags", "", "-v"],
+ "envFile": "${workspaceFolder}/.env",
+ "cwd": "${workspaceRoot}"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/.yamllint b/.yamllint
new file mode 100644
index 0000000000000000000000000000000000000000..8b8a89eb47d06fb5a2c3b09098ae1be82f60184c
--- /dev/null
+++ b/.yamllint
@@ -0,0 +1,4 @@
+extends: default
+
+rules:
+ line-length: disable
\ No newline at end of file
diff --git a/AGENTS.md b/AGENTS.md
new file mode 100644
index 0000000000000000000000000000000000000000..bc8b966d15c22fcf94c01ac9d7ded0c5b504c566
--- /dev/null
+++ b/AGENTS.md
@@ -0,0 +1,282 @@
+# Build and testing
+
+Building and testing the project depends on the components involved and the platform where development is taking place. Due to the amount of context required it's usually best not to try building or testing the project unless the user requests it. If you must build the project then inspect the Makefile in the project root and the Makefiles of any backends that are effected by changes you are making. In addition the workflows in .github/workflows can be used as a reference when it is unclear how to build or test a component. The primary Makefile contains targets for building inside or outside Docker, if the user has not previously specified a preference then ask which they would like to use.
+
+## Building a specified backend
+
+Let's say the user wants to build a particular backend for a given platform. For example let's say they want to build bark for ROCM/hipblas
+
+- The Makefile has targets like `docker-build-bark` created with `generate-docker-build-target` at the time of writing. Recently added backends may require a new target.
+- At a minimum we need to set the BUILD_TYPE, BASE_IMAGE build-args
+ - Use .github/workflows/backend.yml as a reference it lists the needed args in the `include` job strategy matrix
+ - l4t and cublas also requires the CUDA major and minor version
+- You can pretty print a command like `DOCKER_MAKEFLAGS=-j$(nproc --ignore=1) BUILD_TYPE=hipblas BASE_IMAGE=rocm/dev-ubuntu-24.04:6.4.4 make docker-build-bark`
+- Unless the user specifies that they want you to run the command, then just print it because not all agent frontends handle long running jobs well and the output may overflow your context
+- The user may say they want to build AMD or ROCM instead of hipblas, or Intel instead of SYCL or NVIDIA insted of l4t or cublas. Ask for confirmation if there is ambiguity.
+- Sometimes the user may need extra parameters to be added to `docker build` (e.g. `--platform` for cross-platform builds or `--progress` to view the full logs), in which case you can generate the `docker build` command directly.
+
+## Adding a New Backend
+
+When adding a new backend to LocalAI, you need to update several files to ensure the backend is properly built, tested, and registered. Here's a step-by-step guide based on the pattern used for adding backends like `moonshine`:
+
+### 1. Create Backend Directory Structure
+
+Create the backend directory under the appropriate location:
+- **Python backends**: `backend/python//`
+- **Go backends**: `backend/go//`
+- **C++ backends**: `backend/cpp//`
+
+For Python backends, you'll typically need:
+- `backend.py` - Main gRPC server implementation
+- `Makefile` - Build configuration
+- `install.sh` - Installation script for dependencies
+- `protogen.sh` - Protocol buffer generation script
+- `requirements.txt` - Python dependencies
+- `run.sh` - Runtime script
+- `test.py` / `test.sh` - Test files
+
+### 2. Add Build Configurations to `.github/workflows/backend.yml`
+
+Add build matrix entries for each platform/GPU type you want to support. Look at similar backends (e.g., `chatterbox`, `faster-whisper`) for reference.
+
+**Placement in file:**
+- CPU builds: Add after other CPU builds (e.g., after `cpu-chatterbox`)
+- CUDA 12 builds: Add after other CUDA 12 builds (e.g., after `gpu-nvidia-cuda-12-chatterbox`)
+- CUDA 13 builds: Add after other CUDA 13 builds (e.g., after `gpu-nvidia-cuda-13-chatterbox`)
+
+**Additional build types you may need:**
+- ROCm/HIP: Use `build-type: 'hipblas'` with `base-image: "rocm/dev-ubuntu-24.04:6.4.4"`
+- Intel/SYCL: Use `build-type: 'intel'` or `build-type: 'sycl_f16'`/`sycl_f32` with `base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"`
+- L4T (ARM): Use `build-type: 'l4t'` with `platforms: 'linux/arm64'` and `runs-on: 'ubuntu-24.04-arm'`
+
+### 3. Add Backend Metadata to `backend/index.yaml`
+
+**Step 3a: Add Meta Definition**
+
+Add a YAML anchor definition in the `## metas` section (around line 2-300). Look for similar backends to use as a template such as `diffusers` or `chatterbox`
+
+**Step 3b: Add Image Entries**
+
+Add image entries at the end of the file, following the pattern of similar backends such as `diffusers` or `chatterbox`. Include both `latest` (production) and `master` (development) tags.
+
+### 4. Update the Makefile
+
+The Makefile needs to be updated in several places to support building and testing the new backend:
+
+**Step 4a: Add to `.NOTPARALLEL`**
+
+Add `backends/` to the `.NOTPARALLEL` line (around line 2) to prevent parallel execution conflicts:
+
+```makefile
+.NOTPARALLEL: ... backends/
+```
+
+**Step 4b: Add to `prepare-test-extra`**
+
+Add the backend to the `prepare-test-extra` target (around line 312) to prepare it for testing:
+
+```makefile
+prepare-test-extra: protogen-python
+ ...
+ $(MAKE) -C backend/python/
+```
+
+**Step 4c: Add to `test-extra`**
+
+Add the backend to the `test-extra` target (around line 319) to run its tests:
+
+```makefile
+test-extra: prepare-test-extra
+ ...
+ $(MAKE) -C backend/python/ test
+```
+
+**Step 4d: Add Backend Definition**
+
+Add a backend definition variable in the backend definitions section (around line 428-457). The format depends on the backend type:
+
+**For Python backends with root context** (like `faster-whisper`, `bark`):
+```makefile
+BACKEND_ = |python|.|false|true
+```
+
+**For Python backends with `./backend` context** (like `chatterbox`, `moonshine`):
+```makefile
+BACKEND_ = |python|./backend|false|true
+```
+
+**For Go backends**:
+```makefile
+BACKEND_ = |golang|.|false|true
+```
+
+**Step 4e: Generate Docker Build Target**
+
+Add an eval call to generate the docker-build target (around line 480-501):
+
+```makefile
+$(eval $(call generate-docker-build-target,$(BACKEND_)))
+```
+
+**Step 4f: Add to `docker-build-backends`**
+
+Add `docker-build-` to the `docker-build-backends` target (around line 507):
+
+```makefile
+docker-build-backends: ... docker-build-
+```
+
+**Determining the Context:**
+
+- If the backend is in `backend/python//` and uses `./backend` as context in the workflow file, use `./backend` context
+- If the backend is in `backend/python//` but uses `.` as context in the workflow file, use `.` context
+- Check similar backends to determine the correct context
+
+### 5. Verification Checklist
+
+After adding a new backend, verify:
+
+- [ ] Backend directory structure is complete with all necessary files
+- [ ] Build configurations added to `.github/workflows/backend.yml` for all desired platforms
+- [ ] Meta definition added to `backend/index.yaml` in the `## metas` section
+- [ ] Image entries added to `backend/index.yaml` for all build variants (latest + development)
+- [ ] Tag suffixes match between workflow file and index.yaml
+- [ ] Makefile updated with all 6 required changes (`.NOTPARALLEL`, `prepare-test-extra`, `test-extra`, backend definition, docker-build target eval, `docker-build-backends`)
+- [ ] No YAML syntax errors (check with linter)
+- [ ] No Makefile syntax errors (check with linter)
+- [ ] Follows the same pattern as similar backends (e.g., if it's a transcription backend, follow `faster-whisper` pattern)
+
+### 6. Example: Adding a Python Backend
+
+For reference, when `moonshine` was added:
+- **Files created**: `backend/python/moonshine/{backend.py, Makefile, install.sh, protogen.sh, requirements.txt, run.sh, test.py, test.sh}`
+- **Workflow entries**: 3 build configurations (CPU, CUDA 12, CUDA 13)
+- **Index entries**: 1 meta definition + 6 image entries (cpu, cuda12, cuda13 × latest/development)
+- **Makefile updates**:
+ - Added to `.NOTPARALLEL` line
+ - Added to `prepare-test-extra` and `test-extra` targets
+ - Added `BACKEND_MOONSHINE = moonshine|python|./backend|false|true`
+ - Added eval for docker-build target generation
+ - Added `docker-build-moonshine` to `docker-build-backends`
+
+# Coding style
+
+- The project has the following .editorconfig
+
+```
+root = true
+
+[*]
+indent_style = space
+indent_size = 2
+end_of_line = lf
+charset = utf-8
+trim_trailing_whitespace = true
+insert_final_newline = true
+
+[*.go]
+indent_style = tab
+
+[Makefile]
+indent_style = tab
+
+[*.proto]
+indent_size = 2
+
+[*.py]
+indent_size = 4
+
+[*.js]
+indent_size = 2
+
+[*.yaml]
+indent_size = 2
+
+[*.md]
+trim_trailing_whitespace = false
+```
+
+- Use comments sparingly to explain why code does something, not what it does. Comments are there to add context that would be difficult to deduce from reading the code.
+- Prefer modern Go e.g. use `any` not `interface{}`
+
+# Logging
+
+Use `github.com/mudler/xlog` for logging which has the same API as slog.
+
+# llama.cpp Backend
+
+The llama.cpp backend (`backend/cpp/llama-cpp/grpc-server.cpp`) is a gRPC adaptation of the upstream HTTP server (`llama.cpp/tools/server/server.cpp`). It uses the same underlying server infrastructure from `llama.cpp/tools/server/server-context.cpp`.
+
+## Building and Testing
+
+- Test llama.cpp backend compilation: `make backends/llama-cpp`
+- The backend is built as part of the main build process
+- Check `backend/cpp/llama-cpp/Makefile` for build configuration
+
+## Architecture
+
+- **grpc-server.cpp**: gRPC server implementation, adapts HTTP server patterns to gRPC
+- Uses shared server infrastructure: `server-context.cpp`, `server-task.cpp`, `server-queue.cpp`, `server-common.cpp`
+- The gRPC server mirrors the HTTP server's functionality but uses gRPC instead of HTTP
+
+## Common Issues When Updating llama.cpp
+
+When fixing compilation errors after upstream changes:
+1. Check how `server.cpp` (HTTP server) handles the same change
+2. Look for new public APIs or getter methods
+3. Store copies of needed data instead of accessing private members
+4. Update function calls to match new signatures
+5. Test with `make backends/llama-cpp`
+
+## Key Differences from HTTP Server
+
+- gRPC uses `BackendServiceImpl` class with gRPC service methods
+- HTTP server uses `server_routes` with HTTP handlers
+- Both use the same `server_context` and task queue infrastructure
+- gRPC methods: `LoadModel`, `Predict`, `PredictStream`, `Embedding`, `Rerank`, `TokenizeString`, `GetMetrics`, `Health`
+
+## Tool Call Parsing Maintenance
+
+When working on JSON/XML tool call parsing functionality, always check llama.cpp for reference implementation and updates:
+
+### Checking for XML Parsing Changes
+
+1. **Review XML Format Definitions**: Check `llama.cpp/common/chat-parser-xml-toolcall.h` for `xml_tool_call_format` struct changes
+2. **Review Parsing Logic**: Check `llama.cpp/common/chat-parser-xml-toolcall.cpp` for parsing algorithm updates
+3. **Review Format Presets**: Check `llama.cpp/common/chat-parser.cpp` for new XML format presets (search for `xml_tool_call_format form`)
+4. **Review Model Lists**: Check `llama.cpp/common/chat.h` for `COMMON_CHAT_FORMAT_*` enum values that use XML parsing:
+ - `COMMON_CHAT_FORMAT_GLM_4_5`
+ - `COMMON_CHAT_FORMAT_MINIMAX_M2`
+ - `COMMON_CHAT_FORMAT_KIMI_K2`
+ - `COMMON_CHAT_FORMAT_QWEN3_CODER_XML`
+ - `COMMON_CHAT_FORMAT_APRIEL_1_5`
+ - `COMMON_CHAT_FORMAT_XIAOMI_MIMO`
+ - Any new formats added
+
+### Model Configuration Options
+
+Always check `llama.cpp` for new model configuration options that should be supported in LocalAI:
+
+1. **Check Server Context**: Review `llama.cpp/tools/server/server-context.cpp` for new parameters
+2. **Check Chat Params**: Review `llama.cpp/common/chat.h` for `common_chat_params` struct changes
+3. **Check Server Options**: Review `llama.cpp/tools/server/server.cpp` for command-line argument changes
+4. **Examples of options to check**:
+ - `ctx_shift` - Context shifting support
+ - `parallel_tool_calls` - Parallel tool calling
+ - `reasoning_format` - Reasoning format options
+ - Any new flags or parameters
+
+### Implementation Guidelines
+
+1. **Feature Parity**: Always aim for feature parity with llama.cpp's implementation
+2. **Test Coverage**: Add tests for new features matching llama.cpp's behavior
+3. **Documentation**: Update relevant documentation when adding new formats or options
+4. **Backward Compatibility**: Ensure changes don't break existing functionality
+
+### Files to Monitor
+
+- `llama.cpp/common/chat-parser-xml-toolcall.h` - Format definitions
+- `llama.cpp/common/chat-parser-xml-toolcall.cpp` - Parsing logic
+- `llama.cpp/common/chat-parser.cpp` - Format presets and model-specific handlers
+- `llama.cpp/common/chat.h` - Format enums and parameter structures
+- `llama.cpp/tools/server/server-context.cpp` - Server configuration options
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..87d7edbfc96bad00f3d11118a3290476f7bce3f0
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,99 @@
+# Contributing to LocalAI
+
+Thank you for your interest in contributing to LocalAI! We appreciate your time and effort in helping to improve our project. Before you get started, please take a moment to review these guidelines.
+
+## Table of Contents
+
+- [Getting Started](#getting-started)
+ - [Prerequisites](#prerequisites)
+ - [Setting up the Development Environment](#setting-up-the-development-environment)
+- [Contributing](#contributing)
+ - [Submitting an Issue](#submitting-an-issue)
+ - [Creating a Pull Request (PR)](#creating-a-pull-request-pr)
+- [Coding Guidelines](#coding-guidelines)
+- [Testing](#testing)
+- [Documentation](#documentation)
+- [Community and Communication](#community-and-communication)
+
+## Getting Started
+
+### Prerequisites
+
+- Golang [1.21]
+- Git
+- macOS/Linux
+
+### Setting up the Development Environment and running localAI in the local environment
+
+1. Clone the repository: `git clone https://github.com/go-skynet/LocalAI.git`
+2. Navigate to the project directory: `cd LocalAI`
+3. Install the required dependencies ( see https://localai.io/basics/build/#build-localai-locally )
+4. Build LocalAI: `make build`
+5. Run LocalAI: `./local-ai`
+6. To Build and live reload: `make build-dev`
+
+## Contributing
+
+We welcome contributions from everyone! To get started, follow these steps:
+
+### Submitting an Issue
+
+If you find a bug, have a feature request, or encounter any issues, please check the [issue tracker](https://github.com/go-skynet/LocalAI/issues) to see if a similar issue has already been reported. If not, feel free to [create a new issue](https://github.com/go-skynet/LocalAI/issues/new) and provide as much detail as possible.
+
+### Creating a Pull Request (PR)
+
+1. Fork the repository.
+2. Create a new branch with a descriptive name: `git checkout -b [branch name]`
+3. Make your changes and commit them.
+4. Push the changes to your fork: `git push origin [branch name]`
+5. Create a new pull request from your branch to the main project's `main` or `master` branch.
+6. Provide a clear description of your changes in the pull request.
+7. Make any requested changes during the review process.
+8. Once your PR is approved, it will be merged into the main project.
+
+## Coding Guidelines
+
+- No specific coding guidelines at the moment. Please make sure the code can be tested. The most popular lint tools like [`golangci-lint`](https://golangci-lint.run) can help you here.
+
+## Testing
+
+`make test` cannot handle all the model now. Please be sure to add a test case for the new features or the part was changed.
+
+### Running AIO tests
+
+All-In-One images has a set of tests that automatically verifies that most of the endpoints works correctly, a flow can be :
+
+```bash
+# Build the LocalAI docker image
+make DOCKER_IMAGE=local-ai docker
+
+# Build the corresponding AIO image
+BASE_IMAGE=local-ai DOCKER_AIO_IMAGE=local-ai-aio:test make docker-aio
+
+# Run the AIO e2e tests
+LOCALAI_IMAGE_TAG=test LOCALAI_IMAGE=local-ai-aio make run-e2e-aio
+```
+
+## Documentation
+
+We are welcome the contribution of the documents, please open new PR or create a new issue. The documentation is available under `docs/` https://github.com/mudler/LocalAI/tree/master/docs
+
+### Gallery YAML Schema
+
+LocalAI provides a JSON Schema for gallery model YAML files at:
+
+`core/schema/gallery-model.schema.json`
+
+This schema mirrors the internal gallery model configuration and can be used by editors (such as VS Code) to enable autocomplete, validation, and inline documentation when creating or modifying gallery files.
+
+To use it with the YAML language server, add the following comment at the top of a gallery YAML file:
+
+```yaml
+# yaml-language-server: $schema=../core/schema/gallery-model.schema.json
+```
+
+## Community and Communication
+
+- You can reach out via the Github issue tracker.
+- Open a new discussion at [Discussion](https://github.com/go-skynet/LocalAI/discussions)
+- Join the Discord channel [Discord](https://discord.gg/uJAeKSAGDy)
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..29570be8b699b9c53c4f1a3561725475669ee6c0
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,377 @@
+ARG BASE_IMAGE=ubuntu:24.04
+ARG GRPC_BASE_IMAGE=${BASE_IMAGE}
+ARG INTEL_BASE_IMAGE=${BASE_IMAGE}
+ARG UBUNTU_CODENAME=noble
+
+FROM ${BASE_IMAGE} AS requirements
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get update && \
+ apt-get install -y --no-install-recommends \
+ ca-certificates curl wget espeak-ng libgomp1 \
+ ffmpeg libopenblas0 libopenblas-dev && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+# The requirements-drivers target is for BUILD_TYPE specific items. If you need to install something specific to CUDA, or specific to ROCM, it goes here.
+FROM requirements AS requirements-drivers
+
+ARG BUILD_TYPE
+ARG CUDA_MAJOR_VERSION=12
+ARG CUDA_MINOR_VERSION=0
+ARG SKIP_DRIVERS=false
+ARG TARGETARCH
+ARG TARGETVARIANT
+ENV BUILD_TYPE=${BUILD_TYPE}
+ARG UBUNTU_VERSION=2404
+
+RUN mkdir -p /run/localai
+RUN echo "default" > /run/localai/capability
+
+# Vulkan requirements
+RUN < /run/localai/capability
+ fi
+EOT
+
+# CuBLAS requirements
+RUN < /run/localai/capability
+ fi
+EOT
+
+RUN < /run/localai/capability
+ fi
+EOT
+
+# https://github.com/NVIDIA/Isaac-GR00T/issues/343
+RUN < /run/localai/capability && \
+ # I have no idea why, but the ROCM lib packages don't trigger ldconfig after they install, which results in local-ai and others not being able
+ # to locate the libraries. We run ldconfig ourselves to work around this packaging deficiency
+ ldconfig \
+ ; fi
+
+RUN if [ "${BUILD_TYPE}" = "hipblas" ]; then \
+ ln -s /opt/rocm-**/lib/llvm/lib/libomp.so /usr/lib/libomp.so \
+ ; fi
+
+RUN expr "${BUILD_TYPE}" = intel && echo "intel" > /run/localai/capability || echo "not intel"
+
+# Cuda
+ENV PATH=/usr/local/cuda/bin:${PATH}
+
+# HipBLAS requirements
+ENV PATH=/opt/rocm/bin:${PATH}
+
+###################################
+###################################
+
+# The requirements-core target is common to all images. It should not be placed in requirements-core unless every single build will use it.
+FROM requirements-drivers AS build-requirements
+
+ARG GO_VERSION=1.25.4
+ARG CMAKE_VERSION=3.31.10
+ARG CMAKE_FROM_SOURCE=false
+ARG TARGETARCH
+ARG TARGETVARIANT
+
+RUN apt-get update && \
+ apt-get install -y --no-install-recommends \
+ build-essential \
+ ccache \
+ ca-certificates espeak-ng \
+ curl libssl-dev \
+ git \
+ git-lfs \
+ unzip upx-ucl python3 python-is-python3 && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+# Install CMake (the version in 22.04 is too old)
+RUN < /etc/apt/sources.list.d/intel-graphics.list
+RUN apt-get update && \
+ apt-get install -y --no-install-recommends \
+ intel-oneapi-runtime-libs && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+###################################
+###################################
+
+# The builder-base target has the arguments, variables, and copies shared between full builder images and the uncompiled devcontainer
+
+FROM build-requirements AS builder-base
+
+ARG GO_TAGS=""
+ARG GRPC_BACKENDS
+ARG MAKEFLAGS
+ARG LD_FLAGS="-s -w"
+ARG TARGETARCH
+ARG TARGETVARIANT
+ENV GRPC_BACKENDS=${GRPC_BACKENDS}
+ENV GO_TAGS=${GO_TAGS}
+ENV MAKEFLAGS=${MAKEFLAGS}
+ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
+ENV NVIDIA_REQUIRE_CUDA="cuda>=${CUDA_MAJOR_VERSION}.0"
+ENV NVIDIA_VISIBLE_DEVICES=all
+ENV LD_FLAGS=${LD_FLAGS}
+
+RUN echo "GO_TAGS: $GO_TAGS" && echo "TARGETARCH: $TARGETARCH"
+
+WORKDIR /build
+
+
+# We need protoc installed, and the version in 22.04 is too old.
+RUN <
+
+
+
+ com.apple.security.network.client
+
+ com.apple.security.network.server
+
+
+
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..65ebf26018f7d8eefefe1741b81009daed03ff93
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023-2025 Ettore Di Giacinto (mudler@localai.io)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..9bc95063e4d98626c5321fd08c9f28ca31935f9b
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,559 @@
+# Disable parallel execution for backend builds
+.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/moonshine backends/pocket-tts
+
+GOCMD=go
+GOTEST=$(GOCMD) test
+GOVET=$(GOCMD) vet
+BINARY_NAME=local-ai
+LAUNCHER_BINARY_NAME=local-ai-launcher
+
+CUDA_MAJOR_VERSION?=13
+CUDA_MINOR_VERSION?=0
+UBUNTU_VERSION?=2404
+UBUNTU_CODENAME?=noble
+
+GORELEASER?=
+
+export BUILD_TYPE?=
+export CUDA_MAJOR_VERSION?=12
+export CUDA_MINOR_VERSION?=9
+
+GO_TAGS?=
+BUILD_ID?=
+NATIVE?=false
+
+TEST_DIR=/tmp/test
+
+TEST_FLAKES?=5
+
+RANDOM := $(shell bash -c 'echo $$RANDOM')
+
+VERSION?=$(shell git describe --always --tags || echo "dev" )
+# go tool nm ./local-ai | grep Commit
+LD_FLAGS?=-s -w
+override LD_FLAGS += -X "github.com/mudler/LocalAI/internal.Version=$(VERSION)"
+override LD_FLAGS += -X "github.com/mudler/LocalAI/internal.Commit=$(shell git rev-parse HEAD)"
+
+OPTIONAL_TARGETS?=
+
+export OS := $(shell uname -s)
+ARCH := $(shell uname -m)
+GREEN := $(shell tput -Txterm setaf 2)
+YELLOW := $(shell tput -Txterm setaf 3)
+WHITE := $(shell tput -Txterm setaf 7)
+CYAN := $(shell tput -Txterm setaf 6)
+RESET := $(shell tput -Txterm sgr0)
+
+# Default Docker bridge IP
+E2E_BRIDGE_IP?=172.17.0.1
+
+ifndef UNAME_S
+UNAME_S := $(shell uname -s)
+endif
+
+ifeq ($(OS),Darwin)
+ ifeq ($(OSX_SIGNING_IDENTITY),)
+ OSX_SIGNING_IDENTITY := $(shell security find-identity -v -p codesigning | grep '"' | head -n 1 | sed -E 's/.*"(.*)"/\1/')
+ endif
+endif
+
+# check if goreleaser exists
+ifeq (, $(shell which goreleaser))
+ GORELEASER=curl -sfL https://goreleaser.com/static/run | bash -s --
+else
+ GORELEASER=$(shell which goreleaser)
+endif
+
+TEST_PATHS?=./api/... ./pkg/... ./core/...
+
+
+.PHONY: all test build vendor
+
+all: help
+
+## GENERIC
+rebuild: ## Rebuilds the project
+ $(GOCMD) clean -cache
+ $(MAKE) build
+
+clean: ## Remove build related file
+ $(GOCMD) clean -cache
+ rm -f prepare
+ rm -rf $(BINARY_NAME)
+ rm -rf release/
+ $(MAKE) protogen-clean
+ rmdir pkg/grpc/proto || true
+
+clean-tests:
+ rm -rf test-models
+ rm -rf test-dir
+
+## Install Go tools
+install-go-tools:
+ go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
+ go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
+
+## Build:
+build: protogen-go install-go-tools ## Build the project
+ $(info ${GREEN}I local-ai build info:${RESET})
+ $(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
+ $(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET})
+ $(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET})
+ $(info ${GREEN}I UPX: ${YELLOW}$(UPX)${RESET})
+ rm -rf $(BINARY_NAME) || true
+ CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./cmd/local-ai
+
+build-launcher: ## Build the launcher application
+ $(info ${GREEN}I local-ai launcher build info:${RESET})
+ $(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
+ $(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET})
+ $(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET})
+ rm -rf $(LAUNCHER_BINARY_NAME) || true
+ CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(LAUNCHER_BINARY_NAME) ./cmd/launcher
+
+build-all: build build-launcher ## Build both server and launcher
+
+build-dev: ## Run LocalAI in dev mode with live reload
+ @command -v air >/dev/null 2>&1 || go install github.com/air-verse/air@latest
+ air -c .air.toml
+
+dev-dist:
+ $(GORELEASER) build --snapshot --clean
+
+dist:
+ $(GORELEASER) build --clean
+
+osx-signed: build
+ codesign --deep --force --sign "$(OSX_SIGNING_IDENTITY)" --entitlements "./Entitlements.plist" "./$(BINARY_NAME)"
+
+## Run
+run: ## run local-ai
+ CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
+
+test-models/testmodel.ggml:
+ mkdir -p test-models
+ mkdir -p test-dir
+ wget -q https://huggingface.co/mradermacher/gpt2-alpaca-gpt4-GGUF/resolve/main/gpt2-alpaca-gpt4.Q4_K_M.gguf -O test-models/testmodel.ggml
+ wget -q https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
+ wget -q https://huggingface.co/mudler/all-MiniLM-L6-v2/resolve/main/ggml-model-q4_0.bin -O test-models/bert
+ wget -q https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav
+ cp tests/models_fixtures/* test-models
+
+prepare-test: protogen-go
+ cp tests/models_fixtures/* test-models
+
+########################################################
+## Tests
+########################################################
+
+## Test targets
+test: test-models/testmodel.ggml protogen-go
+ @echo 'Running tests'
+ export GO_TAGS="debug"
+ $(MAKE) prepare-test
+ HUGGINGFACE_GRPC=$(abspath ./)/backend/python/transformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models BACKENDS_PATH=$(abspath ./)/backends \
+ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!llama-gguf" --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS)
+ $(MAKE) test-llama-gguf
+ $(MAKE) test-tts
+ $(MAKE) test-stablediffusion
+
+########################################################
+## AIO tests
+########################################################
+
+docker-build-aio:
+ docker build \
+ --build-arg MAKEFLAGS="--jobs=5 --output-sync=target" \
+ --build-arg BASE_IMAGE=$(BASE_IMAGE) \
+ --build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
+ --build-arg BUILD_TYPE=$(BUILD_TYPE) \
+ --build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) \
+ --build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) \
+ --build-arg UBUNTU_VERSION=$(UBUNTU_VERSION) \
+ --build-arg UBUNTU_CODENAME=$(UBUNTU_CODENAME) \
+ --build-arg GO_TAGS="$(GO_TAGS)" \
+ -t local-ai:tests -f Dockerfile .
+ BASE_IMAGE=local-ai:tests DOCKER_AIO_IMAGE=local-ai-aio:test $(MAKE) docker-aio
+
+e2e-aio:
+ LOCALAI_BACKEND_DIR=$(abspath ./backends) \
+ LOCALAI_MODELS_DIR=$(abspath ./models) \
+ LOCALAI_IMAGE_TAG=test \
+ LOCALAI_IMAGE=local-ai-aio \
+ $(MAKE) run-e2e-aio
+
+run-e2e-aio: protogen-go
+ @echo 'Running e2e AIO tests'
+ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio
+
+########################################################
+## E2E tests
+########################################################
+
+prepare-e2e:
+ mkdir -p $(TEST_DIR)
+ cp -rfv $(abspath ./tests/e2e-fixtures)/gpu.yaml $(TEST_DIR)/gpu.yaml
+ test -e $(TEST_DIR)/ggllm-test-model.bin || wget -q https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GGUF/resolve/main/codellama-7b-instruct.Q2_K.gguf -O $(TEST_DIR)/ggllm-test-model.bin
+ docker build \
+ --build-arg IMAGE_TYPE=core \
+ --build-arg BUILD_TYPE=$(BUILD_TYPE) \
+ --build-arg BASE_IMAGE=$(BASE_IMAGE) \
+ --build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) \
+ --build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) \
+ --build-arg UBUNTU_VERSION=$(UBUNTU_VERSION) \
+ --build-arg UBUNTU_CODENAME=$(UBUNTU_CODENAME) \
+ --build-arg GO_TAGS="$(GO_TAGS)" \
+ --build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \
+ -t localai-tests .
+
+run-e2e-image:
+ ls -liah $(abspath ./tests/e2e-fixtures)
+ docker run -p 5390:8080 -e MODELS_PATH=/models -e THREADS=1 -e DEBUG=true -d --rm -v $(TEST_DIR):/models --gpus all --name e2e-tests-$(RANDOM) localai-tests
+
+test-e2e:
+ @echo 'Running e2e tests'
+ BUILD_TYPE=$(BUILD_TYPE) \
+ LOCALAI_API=http://$(E2E_BRIDGE_IP):5390/v1 \
+ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e
+
+teardown-e2e:
+ rm -rf $(TEST_DIR) || true
+ docker stop $$(docker ps -q --filter ancestor=localai-tests)
+
+########################################################
+## Integration and unit tests
+########################################################
+
+test-llama-gguf: prepare-test
+ TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models BACKENDS_PATH=$(abspath ./)/backends \
+ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS)
+
+test-tts: prepare-test
+ TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models BACKENDS_PATH=$(abspath ./)/backends \
+ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS)
+
+test-stablediffusion: prepare-test
+ TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models BACKENDS_PATH=$(abspath ./)/backends \
+ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS)
+
+test-stores:
+ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stores" --flake-attempts $(TEST_FLAKES) -v -r tests/integration
+
+test-container:
+ docker build --target requirements -t local-ai-test-container .
+ docker run -ti --rm --entrypoint /bin/bash -ti -v $(abspath ./):/build local-ai-test-container
+
+########################################################
+## Help
+########################################################
+
+## Help:
+help: ## Show this help.
+ @echo ''
+ @echo 'Usage:'
+ @echo ' ${YELLOW}make${RESET} ${GREEN}${RESET}'
+ @echo ''
+ @echo 'Targets:'
+ @awk 'BEGIN {FS = ":.*?## "} { \
+ if (/^[a-zA-Z_-]+:.*?##.*$$/) {printf " ${YELLOW}%-20s${GREEN}%s${RESET}\n", $$1, $$2} \
+ else if (/^## .*$$/) {printf " ${CYAN}%s${RESET}\n", substr($$1,4)} \
+ }' $(MAKEFILE_LIST)
+
+########################################################
+## Backends
+########################################################
+
+.PHONY: protogen
+protogen: protogen-go
+
+protoc:
+ @OS_NAME=$$(uname -s | tr '[:upper:]' '[:lower:]'); \
+ ARCH_NAME=$$(uname -m); \
+ if [ "$$OS_NAME" = "darwin" ]; then \
+ if [ "$$ARCH_NAME" = "arm64" ]; then \
+ FILE=protoc-31.1-osx-aarch_64.zip; \
+ elif [ "$$ARCH_NAME" = "x86_64" ]; then \
+ FILE=protoc-31.1-osx-x86_64.zip; \
+ else \
+ echo "Unsupported macOS architecture: $$ARCH_NAME"; exit 1; \
+ fi; \
+ elif [ "$$OS_NAME" = "linux" ]; then \
+ if [ "$$ARCH_NAME" = "x86_64" ]; then \
+ FILE=protoc-31.1-linux-x86_64.zip; \
+ elif [ "$$ARCH_NAME" = "aarch64" ] || [ "$$ARCH_NAME" = "arm64" ]; then \
+ FILE=protoc-31.1-linux-aarch_64.zip; \
+ elif [ "$$ARCH_NAME" = "ppc64le" ]; then \
+ FILE=protoc-31.1-linux-ppcle_64.zip; \
+ elif [ "$$ARCH_NAME" = "s390x" ]; then \
+ FILE=protoc-31.1-linux-s390_64.zip; \
+ elif [ "$$ARCH_NAME" = "i386" ] || [ "$$ARCH_NAME" = "x86" ]; then \
+ FILE=protoc-31.1-linux-x86_32.zip; \
+ else \
+ echo "Unsupported Linux architecture: $$ARCH_NAME"; exit 1; \
+ fi; \
+ else \
+ echo "Unsupported OS: $$OS_NAME"; exit 1; \
+ fi; \
+ URL=https://github.com/protocolbuffers/protobuf/releases/download/v31.1/$$FILE; \
+ curl -L $$URL -o protoc.zip && \
+ unzip -j -d $(CURDIR) protoc.zip bin/protoc && rm protoc.zip
+
+.PHONY: protogen-go
+protogen-go: protoc install-go-tools
+ mkdir -p pkg/grpc/proto
+ ./protoc --experimental_allow_proto3_optional -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \
+ backend/backend.proto
+
+.PHONY: protogen-go-clean
+protogen-go-clean:
+ $(RM) pkg/grpc/proto/backend.pb.go pkg/grpc/proto/backend_grpc.pb.go
+ $(RM) bin/*
+
+prepare-test-extra: protogen-python
+ $(MAKE) -C backend/python/transformers
+ $(MAKE) -C backend/python/diffusers
+ $(MAKE) -C backend/python/chatterbox
+ $(MAKE) -C backend/python/vllm
+ $(MAKE) -C backend/python/vibevoice
+ $(MAKE) -C backend/python/moonshine
+ $(MAKE) -C backend/python/pocket-tts
+
+test-extra: prepare-test-extra
+ $(MAKE) -C backend/python/transformers test
+ $(MAKE) -C backend/python/diffusers test
+ $(MAKE) -C backend/python/chatterbox test
+ $(MAKE) -C backend/python/vllm test
+ $(MAKE) -C backend/python/vibevoice test
+ $(MAKE) -C backend/python/moonshine test
+ $(MAKE) -C backend/python/pocket-tts test
+
+DOCKER_IMAGE?=local-ai
+DOCKER_AIO_IMAGE?=local-ai-aio
+IMAGE_TYPE?=core
+BASE_IMAGE?=ubuntu:24.04
+
+docker:
+ docker build \
+ --build-arg BASE_IMAGE=$(BASE_IMAGE) \
+ --build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
+ --build-arg GO_TAGS="$(GO_TAGS)" \
+ --build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \
+ --build-arg BUILD_TYPE=$(BUILD_TYPE) \
+ --build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) \
+ --build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) \
+ --build-arg UBUNTU_VERSION=$(UBUNTU_VERSION) \
+ --build-arg UBUNTU_CODENAME=$(UBUNTU_CODENAME) \
+ -t $(DOCKER_IMAGE) .
+
+docker-cuda12:
+ docker build \
+ --build-arg CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION} \
+ --build-arg CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION} \
+ --build-arg BASE_IMAGE=$(BASE_IMAGE) \
+ --build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
+ --build-arg GO_TAGS="$(GO_TAGS)" \
+ --build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \
+ --build-arg BUILD_TYPE=$(BUILD_TYPE) \
+ --build-arg UBUNTU_VERSION=$(UBUNTU_VERSION) \
+ --build-arg UBUNTU_CODENAME=$(UBUNTU_CODENAME) \
+ -t $(DOCKER_IMAGE)-cuda-12 .
+
+docker-aio:
+ @echo "Building AIO image with base $(BASE_IMAGE) as $(DOCKER_AIO_IMAGE)"
+ docker build \
+ --build-arg BASE_IMAGE=$(BASE_IMAGE) \
+ --build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \
+ --build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) \
+ --build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) \
+ --build-arg UBUNTU_VERSION=$(UBUNTU_VERSION) \
+ --build-arg UBUNTU_CODENAME=$(UBUNTU_CODENAME) \
+ -t $(DOCKER_AIO_IMAGE) -f Dockerfile.aio .
+
+docker-aio-all:
+ $(MAKE) docker-aio DOCKER_AIO_SIZE=cpu
+ $(MAKE) docker-aio DOCKER_AIO_SIZE=cpu
+
+docker-image-intel:
+ docker build \
+ --build-arg BASE_IMAGE=intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04 \
+ --build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
+ --build-arg GO_TAGS="$(GO_TAGS)" \
+ --build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \
+ --build-arg BUILD_TYPE=intel \
+ --build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) \
+ --build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) \
+ --build-arg UBUNTU_VERSION=$(UBUNTU_VERSION) \
+ --build-arg UBUNTU_CODENAME=$(UBUNTU_CODENAME) \
+ -t $(DOCKER_IMAGE) .
+
+########################################################
+## Backends
+########################################################
+
+# Pattern rule for standard backends (docker-based)
+# This matches all backends that use docker-build-* and docker-save-*
+backends/%: docker-build-% docker-save-% build
+ ./local-ai backends install "ocifile://$(abspath ./backend-images/$*.tar)"
+
+# Darwin-specific backends (keep as explicit targets since they have special build logic)
+backends/llama-cpp-darwin: build
+ bash ./scripts/build/llama-cpp-darwin.sh
+ ./local-ai backends install "ocifile://$(abspath ./backend-images/llama-cpp.tar)"
+
+build-darwin-python-backend: build
+ bash ./scripts/build/python-darwin.sh
+
+build-darwin-go-backend: build
+ bash ./scripts/build/golang-darwin.sh
+
+backends/mlx:
+ BACKEND=mlx $(MAKE) build-darwin-python-backend
+ ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx.tar)"
+
+backends/diffuser-darwin:
+ BACKEND=diffusers $(MAKE) build-darwin-python-backend
+ ./local-ai backends install "ocifile://$(abspath ./backend-images/diffusers.tar)"
+
+backends/mlx-vlm:
+ BACKEND=mlx-vlm $(MAKE) build-darwin-python-backend
+ ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-vlm.tar)"
+
+backends/mlx-audio:
+ BACKEND=mlx-audio $(MAKE) build-darwin-python-backend
+ ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-audio.tar)"
+
+backends/stablediffusion-ggml-darwin:
+ BACKEND=stablediffusion-ggml BUILD_TYPE=metal $(MAKE) build-darwin-go-backend
+ ./local-ai backends install "ocifile://$(abspath ./backend-images/stablediffusion-ggml.tar)"
+
+backend-images:
+ mkdir -p backend-images
+
+# Backend metadata: BACKEND_NAME | DOCKERFILE_TYPE | BUILD_CONTEXT | PROGRESS_FLAG | NEEDS_BACKEND_ARG
+# llama-cpp is special - uses llama-cpp Dockerfile and doesn't need BACKEND arg
+BACKEND_LLAMA_CPP = llama-cpp|llama-cpp|.|false|false
+
+# Golang backends
+BACKEND_BARK_CPP = bark-cpp|golang|.|false|true
+BACKEND_PIPER = piper|golang|.|false|true
+BACKEND_LOCAL_STORE = local-store|golang|.|false|true
+BACKEND_HUGGINGFACE = huggingface|golang|.|false|true
+BACKEND_SILERO_VAD = silero-vad|golang|.|false|true
+BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|true
+BACKEND_WHISPER = whisper|golang|.|false|true
+
+# Python backends with root context
+BACKEND_RERANKERS = rerankers|python|.|false|true
+BACKEND_TRANSFORMERS = transformers|python|.|false|true
+BACKEND_FASTER_WHISPER = faster-whisper|python|.|false|true
+BACKEND_COQUI = coqui|python|.|false|true
+BACKEND_BARK = bark|python|.|false|true
+BACKEND_EXLLAMA2 = exllama2|python|.|false|true
+BACKEND_RFDETR = rfdetr|python|.|false|true
+BACKEND_KITTEN_TTS = kitten-tts|python|.|false|true
+BACKEND_NEUTTS = neutts|python|.|false|true
+BACKEND_KOKORO = kokoro|python|.|false|true
+BACKEND_VLLM = vllm|python|.|false|true
+BACKEND_DIFFUSERS = diffusers|python|.|--progress=plain|true
+BACKEND_CHATTERBOX = chatterbox|python|.|false|true
+BACKEND_VIBEVOICE = vibevoice|python|.|--progress=plain|true
+BACKEND_MOONSHINE = moonshine|python|.|false|true
+BACKEND_POCKET_TTS = pocket-tts|python|.|false|true
+
+# Helper function to build docker image for a backend
+# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
+define docker-build-backend
+ docker build $(if $(filter-out false,$(4)),$(4)) \
+ --build-arg BUILD_TYPE=$(BUILD_TYPE) \
+ --build-arg BASE_IMAGE=$(BASE_IMAGE) \
+ --build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) \
+ --build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) \
+ --build-arg UBUNTU_VERSION=$(UBUNTU_VERSION) \
+ --build-arg UBUNTU_CODENAME=$(UBUNTU_CODENAME) \
+ $(if $(filter true,$(5)),--build-arg BACKEND=$(1)) \
+ -t local-ai-backend:$(1) -f backend/Dockerfile.$(2) $(3)
+endef
+
+# Generate docker-build targets from backend definitions
+define generate-docker-build-target
+docker-build-$(word 1,$(subst |, ,$(1))):
+ $$(call docker-build-backend,$(word 1,$(subst |, ,$(1))),$(word 2,$(subst |, ,$(1))),$(word 3,$(subst |, ,$(1))),$(word 4,$(subst |, ,$(1))),$(word 5,$(subst |, ,$(1))))
+endef
+
+# Generate all docker-build targets
+$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP)))
+$(eval $(call generate-docker-build-target,$(BACKEND_BARK_CPP)))
+$(eval $(call generate-docker-build-target,$(BACKEND_PIPER)))
+$(eval $(call generate-docker-build-target,$(BACKEND_LOCAL_STORE)))
+$(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
+$(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD)))
+$(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
+$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
+$(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS)))
+$(eval $(call generate-docker-build-target,$(BACKEND_TRANSFORMERS)))
+$(eval $(call generate-docker-build-target,$(BACKEND_FASTER_WHISPER)))
+$(eval $(call generate-docker-build-target,$(BACKEND_COQUI)))
+$(eval $(call generate-docker-build-target,$(BACKEND_BARK)))
+$(eval $(call generate-docker-build-target,$(BACKEND_EXLLAMA2)))
+$(eval $(call generate-docker-build-target,$(BACKEND_RFDETR)))
+$(eval $(call generate-docker-build-target,$(BACKEND_KITTEN_TTS)))
+$(eval $(call generate-docker-build-target,$(BACKEND_NEUTTS)))
+$(eval $(call generate-docker-build-target,$(BACKEND_KOKORO)))
+$(eval $(call generate-docker-build-target,$(BACKEND_VLLM)))
+$(eval $(call generate-docker-build-target,$(BACKEND_DIFFUSERS)))
+$(eval $(call generate-docker-build-target,$(BACKEND_CHATTERBOX)))
+$(eval $(call generate-docker-build-target,$(BACKEND_VIBEVOICE)))
+$(eval $(call generate-docker-build-target,$(BACKEND_MOONSHINE)))
+$(eval $(call generate-docker-build-target,$(BACKEND_POCKET_TTS)))
+
+# Pattern rule for docker-save targets
+docker-save-%: backend-images
+ docker save local-ai-backend:$* -o backend-images/$*.tar
+
+docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-transformers docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-bark docker-build-chatterbox docker-build-vibevoice docker-build-exllama2 docker-build-moonshine docker-build-pocket-tts
+
+########################################################
+### END Backends
+########################################################
+
+.PHONY: swagger
+swagger:
+ swag init -g core/http/app.go --output swagger
+
+.PHONY: gen-assets
+gen-assets:
+ $(GOCMD) run core/dependencies_manager/manager.go webui_static.yaml core/http/static/assets
+
+## Documentation
+docs/layouts/_default:
+ mkdir -p docs/layouts/_default
+
+docs/static/gallery.html: docs/layouts/_default
+ $(GOCMD) run ./.github/ci/modelslist.go ./gallery/index.yaml > docs/static/gallery.html
+
+docs/public: docs/layouts/_default docs/static/gallery.html
+ cd docs && hugo --minify
+
+docs-clean:
+ rm -rf docs/public
+ rm -rf docs/static/gallery.html
+
+.PHONY: docs
+docs: docs/static/gallery.html
+ cd docs && hugo serve
+
+########################################################
+## Platform-specific builds
+########################################################
+
+## fyne cross-platform build
+build-launcher-darwin: build-launcher
+ go run github.com/tiagomelo/macos-dmg-creator/cmd/createdmg@latest \
+ --appName "LocalAI" \
+ --appBinaryPath "$(LAUNCHER_BINARY_NAME)" \
+ --bundleIdentifier "com.localai.launcher" \
+ --iconPath "core/http/static/logo.png" \
+ --outputDir "dist/"
+
+build-launcher-linux:
+ cd cmd/launcher && go run fyne.io/tools/cmd/fyne@latest package -os linux -icon ../../core/http/static/logo.png --executable $(LAUNCHER_BINARY_NAME)-linux && mv launcher.tar.xz ../../$(LAUNCHER_BINARY_NAME)-linux.tar.xz
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..743f9f15c23484553550237d5a0e049d638cf17d
--- /dev/null
+++ b/README.md
@@ -0,0 +1,451 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+> :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
+>
+> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
+[](https://t.me/localaiofficial_bot)
+
+[](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[](https://artifacthub.io/packages/search?repo=localai)
+
+**LocalAI** is the free, Open Source OpenAI alternative. LocalAI act as a drop-in replacement REST API that's compatible with OpenAI (Elevenlabs, Anthropic... ) API specifications for local AI inferencing. It allows you to run LLMs, generate images, audio (and not only) locally or on-prem with consumer grade hardware, supporting multiple model families. Does not require GPU. It is created and maintained by [Ettore Di Giacinto](https://github.com/mudler).
+
+
+## 📚🆕 Local Stack Family
+
+🆕 LocalAI is now part of a comprehensive suite of AI tools designed to work together:
+
+
+
+
+
+
+
+
+
+
+ A powerful Local AI agent management platform that serves as a drop-in replacement for OpenAI's Responses API, enhanced with advanced agentic capabilities.
+
+
+
+
+
+
+
+
+
+
+ A REST-ful API and knowledge base management system that provides persistent memory and storage capabilities for AI agents.
+
+
+
+
+## Screenshots / Video
+
+### Youtube video
+
+
+
+
+
+
+
+
+### Screenshots
+
+| Talk Interface | Generate Audio |
+| --- | --- |
+|  |  |
+
+| Models Overview | Generate Images |
+| --- | --- |
+|  |  |
+
+| Chat Interface | Home |
+| --- | --- |
+|  |  |
+
+| Login | Swarm |
+| --- | --- |
+| |  |
+
+## 💻 Quickstart
+
+> ⚠️ **Note:** The `install.sh` script is currently experiencing issues due to the heavy changes currently undergoing in LocalAI and may produce broken or misconfigured installations. Please use Docker installation (see below) or manual binary installation until [issue #8032](https://github.com/mudler/LocalAI/issues/8032) is resolved.
+
+Run the installer script:
+
+```bash
+# Basic installation
+curl https://localai.io/install.sh | sh
+```
+
+For more installation options, see [Installer Options](https://localai.io/installation/).
+
+### macOS Download:
+
+
+
+
+
+> Note: the DMGs are not signed by Apple as quarantined. See https://github.com/mudler/LocalAI/issues/6268 for a workaround, fix is tracked here: https://github.com/mudler/LocalAI/issues/6244
+
+### Containers (Docker, podman, ...)
+
+> **💡 Docker Run vs Docker Start**
+>
+> - `docker run` creates and starts a new container. If a container with the same name already exists, this command will fail.
+> - `docker start` starts an existing container that was previously created with `docker run`.
+>
+> If you've already run LocalAI before and want to start it again, use: `docker start -i local-ai`
+
+#### CPU only image:
+
+```bash
+docker run -ti --name local-ai -p 8080:8080 localai/localai:latest
+```
+
+#### NVIDIA GPU Images:
+
+```bash
+# CUDA 13.0
+docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-13
+
+# CUDA 12.0
+docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-12
+
+# NVIDIA Jetson (L4T) ARM64
+# CUDA 12 (for Nvidia AGX Orin and similar platforms)
+docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64
+
+# CUDA 13 (for Nvidia DGX Spark)
+docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64-cuda-13
+```
+
+#### AMD GPU Images (ROCm):
+
+```bash
+docker run -ti --name local-ai -p 8080:8080 --device=/dev/kfd --device=/dev/dri --group-add=video localai/localai:latest-gpu-hipblas
+```
+
+#### Intel GPU Images (oneAPI):
+
+```bash
+docker run -ti --name local-ai -p 8080:8080 --device=/dev/dri/card1 --device=/dev/dri/renderD128 localai/localai:latest-gpu-intel
+```
+
+#### Vulkan GPU Images:
+
+```bash
+docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-gpu-vulkan
+```
+
+#### AIO Images (pre-downloaded models):
+
+```bash
+# CPU version
+docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-aio-cpu
+
+# NVIDIA CUDA 13 version
+docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-aio-gpu-nvidia-cuda-13
+
+# NVIDIA CUDA 12 version
+docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-aio-gpu-nvidia-cuda-12
+
+# Intel GPU version
+docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-aio-gpu-intel
+
+# AMD GPU version
+docker run -ti --name local-ai -p 8080:8080 --device=/dev/kfd --device=/dev/dri --group-add=video localai/localai:latest-aio-gpu-hipblas
+```
+
+For more information about the AIO images and pre-downloaded models, see [Container Documentation](https://localai.io/basics/container/).
+
+To load models:
+
+```bash
+# From the model gallery (see available models with `local-ai models list`, in the WebUI from the model tab, or visiting https://models.localai.io)
+local-ai run llama-3.2-1b-instruct:q4_k_m
+# Start LocalAI with the phi-2 model directly from huggingface
+local-ai run huggingface://TheBloke/phi-2-GGUF/phi-2.Q8_0.gguf
+# Install and run a model from the Ollama OCI registry
+local-ai run ollama://gemma:2b
+# Run a model from a configuration file
+local-ai run https://gist.githubusercontent.com/.../phi-2.yaml
+# Install and run a model from a standard OCI registry (e.g., Docker Hub)
+local-ai run oci://localai/phi-2:latest
+```
+
+> ⚡ **Automatic Backend Detection**: When you install models from the gallery or YAML files, LocalAI automatically detects your system's GPU capabilities (NVIDIA, AMD, Intel) and downloads the appropriate backend. For advanced configuration options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/#automatic-backend-detection).
+
+For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html), if you are interested in our roadmap items and future enhancements, you can see the [Issues labeled as Roadmap here](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
+
+## 📰 Latest project news
+
+- December 2025: [Dynamic Memory Resource reclaimer](https://github.com/mudler/LocalAI/pull/7583), [Automatic fitting of models to multiple GPUS(llama.cpp)](https://github.com/mudler/LocalAI/pull/7584), [Added Vibevoice backend](https://github.com/mudler/LocalAI/pull/7494)
+- November 2025: Major improvements to the UX. Among these: [Import models via URL](https://github.com/mudler/LocalAI/pull/7245) and [Multiple chats and history](https://github.com/mudler/LocalAI/pull/7325)
+- October 2025: 🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) support added for agentic capabilities with external tools
+- September 2025: New Launcher application for MacOS and Linux, extended support to many backends for Mac and Nvidia L4T devices. Models: Added MLX-Audio, WAN 2.2. WebUI improvements and Python-based backends now ships portable python environments.
+- August 2025: MLX, MLX-VLM, Diffusers and llama.cpp are now supported on Mac M1/M2/M3+ chips ( with `development` suffix in the gallery ): https://github.com/mudler/LocalAI/pull/6049 https://github.com/mudler/LocalAI/pull/6119 https://github.com/mudler/LocalAI/pull/6121 https://github.com/mudler/LocalAI/pull/6060
+- July/August 2025: 🔍 [Object Detection](https://localai.io/features/object-detection/) added to the API featuring [rf-detr](https://github.com/roboflow/rf-detr)
+- July 2025: All backends migrated outside of the main binary. LocalAI is now more lightweight, small, and automatically downloads the required backend to run the model. [Read the release notes](https://github.com/mudler/LocalAI/releases/tag/v3.2.0)
+- June 2025: [Backend management](https://github.com/mudler/LocalAI/pull/5607) has been added. Attention: extras images are going to be deprecated from the next release! Read [the backend management PR](https://github.com/mudler/LocalAI/pull/5607).
+- May 2025: [Audio input](https://github.com/mudler/LocalAI/pull/5466) and [Reranking](https://github.com/mudler/LocalAI/pull/5396) in llama.cpp backend, [Realtime API](https://github.com/mudler/LocalAI/pull/5392), Support to Gemma, SmollVLM, and more multimodal models (available in the gallery).
+- May 2025: Important: image name changes [See release](https://github.com/mudler/LocalAI/releases/tag/v2.29.0)
+- Apr 2025: Rebrand, WebUI enhancements
+- Apr 2025: [LocalAGI](https://github.com/mudler/LocalAGI) and [LocalRecall](https://github.com/mudler/LocalRecall) join the LocalAI family stack.
+- Apr 2025: WebUI overhaul, AIO images updates
+- Feb 2025: Backend cleanup, Breaking changes, new backends (kokoro, OutelTTS, faster-whisper), Nvidia L4T images
+- Jan 2025: LocalAI model release: https://huggingface.co/mudler/LocalAI-functioncall-phi-4-v0.3, SANA support in diffusers: https://github.com/mudler/LocalAI/pull/4603
+- Dec 2024: stablediffusion.cpp backend (ggml) added ( https://github.com/mudler/LocalAI/pull/4289 )
+- Nov 2024: Bark.cpp backend added ( https://github.com/mudler/LocalAI/pull/4287 )
+- Nov 2024: Voice activity detection models (**VAD**) added to the API: https://github.com/mudler/LocalAI/pull/4204
+- Oct 2024: examples moved to [LocalAI-examples](https://github.com/mudler/LocalAI-examples)
+- Aug 2024: 🆕 FLUX-1, [P2P Explorer](https://explorer.localai.io)
+- July 2024: 🔥🔥 🆕 P2P Dashboard, LocalAI Federated mode and AI Swarms: https://github.com/mudler/LocalAI/pull/2723. P2P Global community pools: https://github.com/mudler/LocalAI/issues/3113
+- May 2024: 🔥🔥 Decentralized P2P llama.cpp: https://github.com/mudler/LocalAI/pull/2343 (peer2peer llama.cpp!) 👉 Docs https://localai.io/features/distribute/
+- May 2024: 🔥🔥 Distributed inferencing: https://github.com/mudler/LocalAI/pull/2324
+- April 2024: Reranker API: https://github.com/mudler/LocalAI/pull/2121
+
+Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
+
+## 🚀 [Features](https://localai.io/features/)
+
+- 🧩 [Backend Gallery](https://localai.io/backends/): Install/remove backends on the fly, powered by OCI images — fully customizable and API-driven.
+- 📖 [Text generation with GPTs](https://localai.io/features/text-generation/) (`llama.cpp`, `transformers`, `vllm` ... [:book: and more](https://localai.io/model-compatibility/index.html#model-compatibility-table))
+- 🗣 [Text to Audio](https://localai.io/features/text-to-audio/)
+- 🔈 [Audio to Text](https://localai.io/features/audio-to-text/) (Audio transcription with `whisper.cpp`)
+- 🎨 [Image generation](https://localai.io/features/image-generation)
+- 🔥 [OpenAI-alike tools API](https://localai.io/features/openai-functions/)
+- 🧠 [Embeddings generation for vector databases](https://localai.io/features/embeddings/)
+- ✍️ [Constrained grammars](https://localai.io/features/constrained_grammars/)
+- 🖼️ [Download Models directly from Huggingface ](https://localai.io/models/)
+- 🥽 [Vision API](https://localai.io/features/gpt-vision/)
+- 🔍 [Object Detection](https://localai.io/features/object-detection/)
+- 📈 [Reranker API](https://localai.io/features/reranker/)
+- 🆕🖧 [P2P Inferencing](https://localai.io/features/distribute/)
+- 🆕🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) - Agentic capabilities with external tools and [LocalAGI's Agentic capabilities](https://github.com/mudler/LocalAGI)
+- 🔊 Voice activity detection (Silero-VAD support)
+- 🌍 Integrated WebUI!
+
+## 🧩 Supported Backends & Acceleration
+
+LocalAI supports a comprehensive range of AI backends with multiple acceleration options:
+
+### Text Generation & Language Models
+| Backend | Description | Acceleration Support |
+|---------|-------------|---------------------|
+| **llama.cpp** | LLM inference in C/C++ | CUDA 12/13, ROCm, Intel SYCL, Vulkan, Metal, CPU |
+| **vLLM** | Fast LLM inference with PagedAttention | CUDA 12/13, ROCm, Intel |
+| **transformers** | HuggingFace transformers framework | CUDA 12/13, ROCm, Intel, CPU |
+| **exllama2** | GPTQ inference library | CUDA 12/13 |
+| **MLX** | Apple Silicon LLM inference | Metal (M1/M2/M3+) |
+| **MLX-VLM** | Apple Silicon Vision-Language Models | Metal (M1/M2/M3+) |
+
+### Audio & Speech Processing
+| Backend | Description | Acceleration Support |
+|---------|-------------|---------------------|
+| **whisper.cpp** | OpenAI Whisper in C/C++ | CUDA 12/13, ROCm, Intel SYCL, Vulkan, CPU |
+| **faster-whisper** | Fast Whisper with CTranslate2 | CUDA 12/13, ROCm, Intel, CPU |
+| **bark** | Text-to-audio generation | CUDA 12/13, ROCm, Intel |
+| **bark-cpp** | C++ implementation of Bark | CUDA, Metal, CPU |
+| **coqui** | Advanced TTS with 1100+ languages | CUDA 12/13, ROCm, Intel, CPU |
+| **kokoro** | Lightweight TTS model | CUDA 12/13, ROCm, Intel, CPU |
+| **chatterbox** | Production-grade TTS | CUDA 12/13, CPU |
+| **piper** | Fast neural TTS system | CPU |
+| **kitten-tts** | Kitten TTS models | CPU |
+| **silero-vad** | Voice Activity Detection | CPU |
+| **neutts** | Text-to-speech with voice cloning | CUDA 12/13, ROCm, CPU |
+| **vibevoice** | Real-time TTS with voice cloning | CUDA 12/13, ROCm, Intel, CPU |
+| **pocket-tts** | Lightweight CPU-based TTS | CUDA 12/13, ROCm, Intel, CPU |
+
+### Image & Video Generation
+| Backend | Description | Acceleration Support |
+|---------|-------------|---------------------|
+| **stablediffusion.cpp** | Stable Diffusion in C/C++ | CUDA 12/13, Intel SYCL, Vulkan, CPU |
+| **diffusers** | HuggingFace diffusion models | CUDA 12/13, ROCm, Intel, Metal, CPU |
+
+### Specialized AI Tasks
+| Backend | Description | Acceleration Support |
+|---------|-------------|---------------------|
+| **rfdetr** | Real-time object detection | CUDA 12/13, Intel, CPU |
+| **rerankers** | Document reranking API | CUDA 12/13, ROCm, Intel, CPU |
+| **local-store** | Vector database | CPU |
+| **huggingface** | HuggingFace API integration | API-based |
+
+### Hardware Acceleration Matrix
+
+| Acceleration Type | Supported Backends | Hardware Support |
+|-------------------|-------------------|------------------|
+| **NVIDIA CUDA 12** | All CUDA-compatible backends | Nvidia hardware |
+| **NVIDIA CUDA 13** | All CUDA-compatible backends | Nvidia hardware |
+| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, bark, neutts, vibevoice, pocket-tts | AMD Graphics |
+| **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, exllama2, coqui, kokoro, bark, vibevoice, pocket-tts | Intel Arc, Intel iGPUs |
+| **Apple Metal** | llama.cpp, whisper, diffusers, MLX, MLX-VLM, bark-cpp | Apple M1/M2/M3+ |
+| **Vulkan** | llama.cpp, whisper, stablediffusion | Cross-platform GPUs |
+| **NVIDIA Jetson (CUDA 12)** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr | ARM64 embedded AI (AGX Orin, etc.) |
+| **NVIDIA Jetson (CUDA 13)** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr | ARM64 embedded AI (DGX Spark) |
+| **CPU Optimized** | All backends | AVX/AVX2/AVX512, quantization support |
+
+### 🔗 Community and integrations
+
+Build and deploy custom containers:
+- https://github.com/sozercan/aikit
+
+WebUIs:
+- https://github.com/Jirubizu/localai-admin
+- https://github.com/go-skynet/LocalAI-frontend
+- QA-Pilot(An interactive chat project that leverages LocalAI LLMs for rapid understanding and navigation of GitHub code repository) https://github.com/reid41/QA-Pilot
+
+Agentic Libraries:
+- https://github.com/mudler/cogito
+
+MCPs:
+- https://github.com/mudler/MCPs
+
+Model galleries
+- https://github.com/go-skynet/model-gallery
+
+Voice:
+- https://github.com/richiejp/VoxInput
+
+Other:
+- Helm chart https://github.com/go-skynet/helm-charts
+- VSCode extension https://github.com/badgooooor/localai-vscode-plugin
+- Langchain: https://python.langchain.com/docs/integrations/providers/localai/
+- Terminal utility https://github.com/djcopley/ShellOracle
+- Local Smart assistant https://github.com/mudler/LocalAGI
+- Home Assistant https://github.com/sammcj/homeassistant-localai / https://github.com/drndos/hass-openai-custom-conversation / https://github.com/valentinfrlch/ha-gpt4vision
+- Discord bot https://github.com/mudler/LocalAGI/tree/main/examples/discord
+- Slack bot https://github.com/mudler/LocalAGI/tree/main/examples/slack
+- Shell-Pilot(Interact with LLM using LocalAI models via pure shell scripts on your Linux or MacOS system) https://github.com/reid41/shell-pilot
+- Telegram bot https://github.com/mudler/LocalAI/tree/master/examples/telegram-bot
+- Another Telegram Bot https://github.com/JackBekket/Hellper
+- Auto-documentation https://github.com/JackBekket/Reflexia
+- Github bot which answer on issues, with code and documentation as context https://github.com/JackBekket/GitHelper
+- Github Actions: https://github.com/marketplace/actions/start-localai
+- Examples: https://github.com/mudler/LocalAI/tree/master/examples/
+
+
+### 🔗 Resources
+
+- [LLM finetuning guide](https://localai.io/docs/advanced/fine-tuning/)
+- [How to build locally](https://localai.io/basics/build/index.html)
+- [How to install in Kubernetes](https://localai.io/basics/getting_started/index.html#run-localai-in-kubernetes)
+- [Projects integrating LocalAI](https://localai.io/docs/integrations/)
+- [How tos section](https://io.midori-ai.xyz/howtos/) (curated by our community)
+
+## :book: 🎥 [Media, Blogs, Social](https://localai.io/basics/news/#media-blogs-social)
+
+- [Run Visual studio code with LocalAI (SUSE)](https://www.suse.com/c/running-ai-locally/)
+- 🆕 [Run LocalAI on Jetson Nano Devkit](https://mudler.pm/posts/local-ai-jetson-nano-devkit/)
+- [Run LocalAI on AWS EKS with Pulumi](https://www.pulumi.com/blog/low-code-llm-apps-with-local-ai-flowise-and-pulumi/)
+- [Run LocalAI on AWS](https://staleks.hashnode.dev/installing-localai-on-aws-ec2-instance)
+- [Create a slackbot for teams and OSS projects that answer to documentation](https://mudler.pm/posts/smart-slackbot-for-teams/)
+- [LocalAI meets k8sgpt](https://www.youtube.com/watch?v=PKrDNuJ_dfE)
+- [Question Answering on Documents locally with LangChain, LocalAI, Chroma, and GPT4All](https://mudler.pm/posts/localai-question-answering/)
+- [Tutorial to use k8sgpt with LocalAI](https://medium.com/@tyler_97636/k8sgpt-localai-unlock-kubernetes-superpowers-for-free-584790de9b65)
+
+## Citation
+
+If you utilize this repository, data in a downstream project, please consider citing it with:
+
+```
+@misc{localai,
+ author = {Ettore Di Giacinto},
+ title = {LocalAI: The free, Open source OpenAI alternative},
+ year = {2023},
+ publisher = {GitHub},
+ journal = {GitHub repository},
+ howpublished = {\url{https://github.com/go-skynet/LocalAI}},
+```
+
+## ❤️ Sponsors
+
+> Do you find LocalAI useful?
+
+Support the project by becoming [a backer or sponsor](https://github.com/sponsors/mudler). Your logo will show up here with a link to your website.
+
+A huge thank you to our generous sponsors who support this project covering CI expenses, and our [Sponsor list](https://github.com/sponsors/mudler):
+
+
+
+
+
+
+
+
+
+
+### Individual sponsors
+
+A special thanks to individual sponsors that contributed to the project, a full list is in [Github](https://github.com/sponsors/mudler) and [buymeacoffee](https://buymeacoffee.com/mudler), a special shout out goes to [drikster80](https://github.com/drikster80) for being generous. Thank you everyone!
+
+## 🌟 Star history
+
+[](https://star-history.com/#go-skynet/LocalAI&Date)
+
+## 📖 License
+
+LocalAI is a community-driven project created by [Ettore Di Giacinto](https://github.com/mudler/).
+
+MIT - Author Ettore Di Giacinto
+
+## 🙇 Acknowledgements
+
+LocalAI couldn't have been built without the help of great software already available from the community. Thank you!
+
+- [llama.cpp](https://github.com/ggerganov/llama.cpp)
+- https://github.com/tatsu-lab/stanford_alpaca
+- https://github.com/cornelk/llama-go for the initial ideas
+- https://github.com/antimatter15/alpaca.cpp
+- https://github.com/EdVince/Stable-Diffusion-NCNN
+- https://github.com/ggerganov/whisper.cpp
+- https://github.com/rhasspy/piper
+
+## 🤗 Contributors
+
+This is a community project, a special thanks to our contributors! 🤗
+
+
+
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 0000000000000000000000000000000000000000..9c39f823203df0d53bb9051d399eeb082ab4d286
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,42 @@
+# Security Policy
+
+## Introduction
+
+At LocalAI, we take the security of our software seriously. We understand the importance of protecting our community from vulnerabilities and are committed to ensuring the safety and security of our users.
+
+## Supported Versions
+
+We provide support and updates for certain versions of our software. The following table outlines which versions are currently supported with security updates:
+
+| Version | Supported |
+| ------- | ------------------ |
+| > 2.0 | :white_check_mark: |
+| < 2.0 | :x: |
+
+Please ensure that you are using a supported version to receive the latest security updates.
+
+## Reporting a Vulnerability
+
+We encourage the responsible disclosure of any security vulnerabilities. If you believe you've found a security issue in our software, we kindly ask you to follow the steps below to report it to us:
+
+1. **Email Us:** Send an email to [security@localai.io](mailto:security@localai.io) with a detailed report. Please do not disclose the vulnerability publicly or to any third parties before it has been addressed by us.
+
+2. **Expect a Response:** We aim to acknowledge receipt of vulnerability reports within 48 hours. Our security team will review your report and work closely with you to understand the impact and ensure a thorough investigation.
+
+3. **Collaboration:** If the vulnerability is accepted, we will work with you and our community to address the issue promptly. We'll keep you informed throughout the resolution process and may request additional information or collaboration.
+
+4. **Disclosure:** Once the vulnerability has been resolved, we encourage a coordinated disclosure. We believe in transparency and will work with you to ensure that our community is informed in a responsible manner.
+
+## Use of Third-Party Platforms
+
+As a Free and Open Source Software (FOSS) organization, we do not offer monetary bounties. However, researchers who wish to report vulnerabilities can also do so via [Huntr](https://huntr.dev/bounties), a platform that recognizes contributions to open source security.
+
+## Contact
+
+For any security-related inquiries beyond vulnerability reporting, please contact us at [security@localai.io](mailto:security@localai.io).
+
+## Acknowledgments
+
+We appreciate the efforts of those who contribute to the security of our project. Your responsible disclosure is invaluable to the safety and integrity of LocalAI.
+
+Thank you for helping us keep LocalAI secure.
diff --git a/aio/cpu/README.md b/aio/cpu/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8b0b1086dbc15e48d97f4c44537914ba3c573965
--- /dev/null
+++ b/aio/cpu/README.md
@@ -0,0 +1,5 @@
+## AIO CPU size
+
+Use this image with CPU-only.
+
+Please keep using only C++ backends so the base image is as small as possible (without CUDA, cuDNN, python, etc).
\ No newline at end of file
diff --git a/aio/cpu/embeddings.yaml b/aio/cpu/embeddings.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0f88f4511ba50c624a4e57aa62e984260f6387c9
--- /dev/null
+++ b/aio/cpu/embeddings.yaml
@@ -0,0 +1,13 @@
+embeddings: true
+name: text-embedding-ada-002
+backend: llama-cpp
+parameters:
+ model: huggingface://bartowski/granite-embedding-107m-multilingual-GGUF/granite-embedding-107m-multilingual-f16.gguf
+
+usage: |
+ You can test this model with curl like this:
+
+ curl http://localhost:8080/embeddings -X POST -H "Content-Type: application/json" -d '{
+ "input": "Your text string goes here",
+ "model": "text-embedding-ada-002"
+ }'
\ No newline at end of file
diff --git a/aio/cpu/image-gen.yaml b/aio/cpu/image-gen.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ef3745726e3d492013d2fe122b2db3777a51491c
--- /dev/null
+++ b/aio/cpu/image-gen.yaml
@@ -0,0 +1,23 @@
+name: stablediffusion
+backend: stablediffusion-ggml
+cfg_scale: 4.5
+
+options:
+- sampler:euler
+parameters:
+ model: stable-diffusion-v1-5-pruned-emaonly-Q4_0.gguf
+step: 25
+
+download_files:
+- filename: "stable-diffusion-v1-5-pruned-emaonly-Q4_0.gguf"
+ sha256: "b8944e9fe0b69b36ae1b5bb0185b3a7b8ef14347fe0fa9af6c64c4829022261f"
+ uri: "huggingface://second-state/stable-diffusion-v1-5-GGUF/stable-diffusion-v1-5-pruned-emaonly-Q4_0.gguf"
+
+usage: |
+ curl http://localhost:8080/v1/images/generations \
+ -H "Content-Type: application/json" \
+ -d '{
+ "prompt": "|",
+ "step": 25,
+ "size": "512x512"
+ }'
\ No newline at end of file
diff --git a/aio/cpu/rerank.yaml b/aio/cpu/rerank.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..70d386b2b6c47a010bbb54cd70b4542fc1e00262
--- /dev/null
+++ b/aio/cpu/rerank.yaml
@@ -0,0 +1,33 @@
+name: jina-reranker-v1-base-en
+reranking: true
+f16: true
+parameters:
+ model: jina-reranker-v1-tiny-en.f16.gguf
+backend: llama-cpp
+download_files:
+ - filename: jina-reranker-v1-tiny-en.f16.gguf
+ sha256: 5f696cf0d0f3d347c4a279eee8270e5918554cdac0ed1f632f2619e4e8341407
+ uri: huggingface://mradermacher/jina-reranker-v1-tiny-en-GGUF/jina-reranker-v1-tiny-en.f16.gguf
+
+usage: |
+ You can test this model with curl like this:
+
+ curl http://localhost:8080/v1/rerank \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "jina-reranker-v1-base-en",
+ "query": "Organic skincare products for sensitive skin",
+ "documents": [
+ "Eco-friendly kitchenware for modern homes",
+ "Biodegradable cleaning supplies for eco-conscious consumers",
+ "Organic cotton baby clothes for sensitive skin",
+ "Natural organic skincare range for sensitive skin",
+ "Tech gadgets for smart homes: 2024 edition",
+ "Sustainable gardening tools and compost solutions",
+ "Sensitive skin-friendly facial cleansers and toners",
+ "Organic food wraps and storage solutions",
+ "All-natural pet food for dogs with allergies",
+ "Yoga mats made from recycled materials"
+ ],
+ "top_n": 3
+ }'
diff --git a/aio/cpu/speech-to-text.yaml b/aio/cpu/speech-to-text.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..77850d79155439eb5b185a9d322f49910b1fe8c8
--- /dev/null
+++ b/aio/cpu/speech-to-text.yaml
@@ -0,0 +1,18 @@
+name: whisper-1
+backend: whisper
+parameters:
+ model: ggml-whisper-base.bin
+
+usage: |
+ ## example audio file
+ wget --quiet --show-progress -O gb1.ogg https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg
+
+ ## Send the example audio file to the transcriptions endpoint
+ curl http://localhost:8080/v1/audio/transcriptions \
+ -H "Content-Type: multipart/form-data" \
+ -F file="@$PWD/gb1.ogg" -F model="whisper-1"
+
+download_files:
+- filename: "ggml-whisper-base.bin"
+ sha256: "60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe"
+ uri: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.bin"
\ No newline at end of file
diff --git a/aio/cpu/text-to-speech.yaml b/aio/cpu/text-to-speech.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4009c3f77ba8aa9cc95fc53ffd1145fef154089a
--- /dev/null
+++ b/aio/cpu/text-to-speech.yaml
@@ -0,0 +1,15 @@
+name: tts-1
+download_files:
+ - filename: voice-en-us-amy-low.tar.gz
+ uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-us-amy-low.tar.gz
+backend: piper
+parameters:
+ model: en-us-amy-low.onnx
+
+usage: |
+ To test if this model works as expected, you can use the following curl command:
+
+ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
+ "model":"voice-en-us-amy-low",
+ "input": "Hi, this is a test."
+ }'
\ No newline at end of file
diff --git a/aio/cpu/text-to-text.yaml b/aio/cpu/text-to-text.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..19ed1f4403dbad318a498e70e3b213cb2c5cdf77
--- /dev/null
+++ b/aio/cpu/text-to-text.yaml
@@ -0,0 +1,58 @@
+context_size: 8192
+f16: true
+backend: llama-cpp
+function:
+ grammar:
+ no_mixed_free_string: true
+ schema_type: llama3.1 # or JSON is supported too (json)
+ response_regex:
+ - \w+)>(?P.*)
+mmap: true
+name: gpt-4
+parameters:
+ model: Hermes-3-Llama-3.2-3B-Q4_K_M.gguf
+stopwords:
+- <|im_end|>
+-
+- <|eot_id|>
+- <|end_of_text|>
+template:
+ chat: |
+ <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+ You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
+ {{.Input }}
+ <|start_header_id|>assistant<|end_header_id|>
+ chat_message: |
+ <|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|>
+ {{ if .FunctionCall -}}
+ {{ else if eq .RoleName "tool" -}}
+ The Function was executed and the response was:
+ {{ end -}}
+ {{ if .Content -}}
+ {{.Content -}}
+ {{ else if .FunctionCall -}}
+ {{ range .FunctionCall }}
+ [{{.FunctionCall.Name}}({{.FunctionCall.Arguments}})]
+ {{ end }}
+ {{ end -}}
+ <|eot_id|>
+ completion: |
+ {{.Input}}
+ function: |
+ <|start_header_id|>system<|end_header_id|>
+ You are an expert in composing functions. You are given a question and a set of possible functions.
+ Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
+ If none of the functions can be used, point it out. If the given question lacks the parameters required by the function, also point it out. You should only return the function call in tools call sections.
+ If you decide to invoke any of the function(s), you MUST put it in the format as follows:
+ [func_name1(params_name1=params_value1,params_name2=params_value2,...),func_name2(params_name1=params_value1,params_name2=params_value2,...)]
+ You SHOULD NOT include any other text in the response.
+ Here is a list of functions in JSON format that you can invoke.
+ {{toJson .Functions}}
+ <|eot_id|><|start_header_id|>user<|end_header_id|>
+ {{.Input}}
+ <|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+download_files:
+- filename: Hermes-3-Llama-3.2-3B-Q4_K_M.gguf
+ sha256: 2e220a14ba4328fee38cf36c2c068261560f999fadb5725ce5c6d977cb5126b5
+ uri: huggingface://bartowski/Hermes-3-Llama-3.2-3B-GGUF/Hermes-3-Llama-3.2-3B-Q4_K_M.gguf
\ No newline at end of file
diff --git a/aio/cpu/vad.yaml b/aio/cpu/vad.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b0dc70d75ed17e011a41c293d0fbd30c9e8d24aa
--- /dev/null
+++ b/aio/cpu/vad.yaml
@@ -0,0 +1,8 @@
+backend: silero-vad
+name: silero-vad
+parameters:
+ model: silero-vad.onnx
+download_files:
+- filename: silero-vad.onnx
+ uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
+ sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808
\ No newline at end of file
diff --git a/aio/cpu/vision.yaml b/aio/cpu/vision.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..37852da059a278b8f106c2738d34dfb631a5086c
--- /dev/null
+++ b/aio/cpu/vision.yaml
@@ -0,0 +1,50 @@
+context_size: 4096
+f16: true
+backend: llama-cpp
+mmap: true
+mmproj: minicpm-v-4_5-mmproj-f16.gguf
+name: gpt-4o
+parameters:
+ model: minicpm-v-4_5-Q4_K_M.gguf
+stopwords:
+- <|im_end|>
+-
+-
+- <|endoftext|>
+template:
+ chat: |
+ {{.Input -}}
+ <|im_start|>assistant
+ chat_message: |
+ <|im_start|>{{ .RoleName }}
+ {{ if .FunctionCall -}}
+ Function call:
+ {{ else if eq .RoleName "tool" -}}
+ Function response:
+ {{ end -}}
+ {{ if .Content -}}
+ {{.Content }}
+ {{ end -}}
+ {{ if .FunctionCall -}}
+ {{toJson .FunctionCall}}
+ {{ end -}}<|im_end|>
+ completion: |
+ {{.Input}}
+ function: |
+ <|im_start|>system
+ You are a function calling AI model. You are provided with functions to execute. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:
+ {{range .Functions}}
+ {'type': 'function', 'function': {'name': '{{.Name}}', 'description': '{{.Description}}', 'parameters': {{toJson .Parameters}} }}
+ {{end}}
+ For each function call return a json object with function name and arguments
+ <|im_end|>
+ {{.Input -}}
+ <|im_start|>assistant
+
+download_files:
+- filename: minicpm-v-4_5-Q4_K_M.gguf
+ sha256: c1c3c33100b15b4caf7319acce4e23c0eb0ce1cbd12f70e8d24f05aa67b7512f
+ uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/ggml-model-Q4_K_M.gguf
+- filename: minicpm-v-4_5-mmproj-f16.gguf
+ uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/mmproj-model-f16.gguf
+ sha256: 7a7225a32e8d453aaa3d22d8c579b5bf833c253f784cdb05c99c9a76fd616df8
\ No newline at end of file
diff --git a/aio/entrypoint.sh b/aio/entrypoint.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a4b83a9daccc63703d14fa508e04d1d99451649f
--- /dev/null
+++ b/aio/entrypoint.sh
@@ -0,0 +1,138 @@
+#!/bin/bash
+
+echo "===> LocalAI All-in-One (AIO) container starting..."
+
+GPU_ACCELERATION=false
+GPU_VENDOR=""
+
+function check_intel() {
+ if lspci | grep -E 'VGA|3D' | grep -iq intel; then
+ echo "Intel GPU detected"
+ if [ -d /opt/intel ]; then
+ GPU_ACCELERATION=true
+ GPU_VENDOR=intel
+ else
+ echo "Intel GPU detected, but Intel GPU drivers are not installed. GPU acceleration will not be available."
+ fi
+ fi
+}
+
+function check_nvidia_wsl() {
+ if lspci | grep -E 'VGA|3D' | grep -iq "Microsoft Corporation Device 008e"; then
+ # We make the assumption this WSL2 cars is NVIDIA, then check for nvidia-smi
+ # Make sure the container was run with `--gpus all` as the only required parameter
+ echo "NVIDIA GPU detected via WSL2"
+ # nvidia-smi should be installed in the container
+ if nvidia-smi; then
+ GPU_ACCELERATION=true
+ GPU_VENDOR=nvidia
+ else
+ echo "NVIDIA GPU detected via WSL2, but nvidia-smi is not installed. GPU acceleration will not be available."
+ fi
+ fi
+}
+
+function check_amd() {
+ if lspci | grep -E 'VGA|3D' | grep -iq amd; then
+ echo "AMD GPU detected"
+ # Check if ROCm is installed
+ if [ -d /opt/rocm ]; then
+ GPU_ACCELERATION=true
+ GPU_VENDOR=amd
+ else
+ echo "AMD GPU detected, but ROCm is not installed. GPU acceleration will not be available."
+ fi
+ fi
+}
+
+function check_nvidia() {
+ if lspci | grep -E 'VGA|3D' | grep -iq nvidia; then
+ echo "NVIDIA GPU detected"
+ # nvidia-smi should be installed in the container
+ if nvidia-smi; then
+ GPU_ACCELERATION=true
+ GPU_VENDOR=nvidia
+ else
+ echo "NVIDIA GPU detected, but nvidia-smi is not installed. GPU acceleration will not be available."
+ fi
+ fi
+}
+
+function check_metal() {
+ if system_profiler SPDisplaysDataType | grep -iq 'Metal'; then
+ echo "Apple Metal supported GPU detected"
+ GPU_ACCELERATION=true
+ GPU_VENDOR=apple
+ fi
+}
+
+function detect_gpu() {
+ case "$(uname -s)" in
+ Linux)
+ check_nvidia
+ check_amd
+ check_intel
+ check_nvidia_wsl
+ ;;
+ Darwin)
+ check_metal
+ ;;
+ esac
+}
+
+function detect_gpu_size() {
+ # Attempting to find GPU memory size for NVIDIA GPUs
+ if [ "$GPU_ACCELERATION" = true ] && [ "$GPU_VENDOR" = "nvidia" ]; then
+ echo "NVIDIA GPU detected. Attempting to find memory size..."
+ # Using head -n 1 to get the total memory of the 1st NVIDIA GPU detected.
+ # If handling multiple GPUs is required in the future, this is the place to do it
+ nvidia_sm=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | head -n 1)
+ if [ ! -z "$nvidia_sm" ]; then
+ echo "Total GPU Memory: $nvidia_sm MiB"
+ # if bigger than 8GB, use 16GB
+ #if [ "$nvidia_sm" -gt 8192 ]; then
+ # GPU_SIZE=gpu-16g
+ #else
+ GPU_SIZE=gpu-8g
+ #fi
+ else
+ echo "Unable to determine NVIDIA GPU memory size. Falling back to CPU."
+ GPU_SIZE=gpu-8g
+ fi
+ elif [ "$GPU_ACCELERATION" = true ] && [ "$GPU_VENDOR" = "intel" ]; then
+ GPU_SIZE=intel
+ # Default to a generic GPU size until we implement GPU size detection for non NVIDIA GPUs
+ elif [ "$GPU_ACCELERATION" = true ]; then
+ echo "Non-NVIDIA GPU detected. Specific GPU memory size detection is not implemented."
+ GPU_SIZE=gpu-8g
+
+ # default to cpu if GPU_SIZE is not set
+ else
+ echo "GPU acceleration is not enabled or supported. Defaulting to CPU."
+ GPU_SIZE=cpu
+ fi
+}
+
+function check_vars() {
+ if [ -z "$MODELS" ]; then
+ echo "MODELS environment variable is not set. Please set it to a comma-separated list of model YAML files to load."
+ exit 1
+ fi
+
+ if [ -z "$PROFILE" ]; then
+ echo "PROFILE environment variable is not set. Please set it to one of the following: cpu, gpu-8g, gpu-16g, apple"
+ exit 1
+ fi
+}
+
+detect_gpu
+detect_gpu_size
+
+PROFILE="${PROFILE:-$GPU_SIZE}" # default to cpu
+export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/rerank.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vad.yaml,/aio/${PROFILE}/vision.yaml}"
+
+check_vars
+
+echo "===> Starting LocalAI[$PROFILE] with the following models: $MODELS"
+
+exec /entrypoint.sh "$@"
diff --git a/aio/gpu-8g/embeddings.yaml b/aio/gpu-8g/embeddings.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0f88f4511ba50c624a4e57aa62e984260f6387c9
--- /dev/null
+++ b/aio/gpu-8g/embeddings.yaml
@@ -0,0 +1,13 @@
+embeddings: true
+name: text-embedding-ada-002
+backend: llama-cpp
+parameters:
+ model: huggingface://bartowski/granite-embedding-107m-multilingual-GGUF/granite-embedding-107m-multilingual-f16.gguf
+
+usage: |
+ You can test this model with curl like this:
+
+ curl http://localhost:8080/embeddings -X POST -H "Content-Type: application/json" -d '{
+ "input": "Your text string goes here",
+ "model": "text-embedding-ada-002"
+ }'
\ No newline at end of file
diff --git a/aio/gpu-8g/image-gen.yaml b/aio/gpu-8g/image-gen.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0074aaf0e043bf0a528741af100ab7bfa5759007
--- /dev/null
+++ b/aio/gpu-8g/image-gen.yaml
@@ -0,0 +1,25 @@
+name: stablediffusion
+parameters:
+ model: DreamShaper_8_pruned.safetensors
+backend: diffusers
+step: 25
+f16: true
+
+diffusers:
+ pipeline_type: StableDiffusionPipeline
+ cuda: true
+ enable_parameters: "negative_prompt,num_inference_steps"
+ scheduler_type: "k_dpmpp_2m"
+
+download_files:
+- filename: DreamShaper_8_pruned.safetensors
+ uri: huggingface://Lykon/DreamShaper/DreamShaper_8_pruned.safetensors
+
+usage: |
+ curl http://localhost:8080/v1/images/generations \
+ -H "Content-Type: application/json" \
+ -d '{
+ "prompt": "|",
+ "step": 25,
+ "size": "512x512"
+ }'
\ No newline at end of file
diff --git a/aio/gpu-8g/rerank.yaml b/aio/gpu-8g/rerank.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..70d386b2b6c47a010bbb54cd70b4542fc1e00262
--- /dev/null
+++ b/aio/gpu-8g/rerank.yaml
@@ -0,0 +1,33 @@
+name: jina-reranker-v1-base-en
+reranking: true
+f16: true
+parameters:
+ model: jina-reranker-v1-tiny-en.f16.gguf
+backend: llama-cpp
+download_files:
+ - filename: jina-reranker-v1-tiny-en.f16.gguf
+ sha256: 5f696cf0d0f3d347c4a279eee8270e5918554cdac0ed1f632f2619e4e8341407
+ uri: huggingface://mradermacher/jina-reranker-v1-tiny-en-GGUF/jina-reranker-v1-tiny-en.f16.gguf
+
+usage: |
+ You can test this model with curl like this:
+
+ curl http://localhost:8080/v1/rerank \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "jina-reranker-v1-base-en",
+ "query": "Organic skincare products for sensitive skin",
+ "documents": [
+ "Eco-friendly kitchenware for modern homes",
+ "Biodegradable cleaning supplies for eco-conscious consumers",
+ "Organic cotton baby clothes for sensitive skin",
+ "Natural organic skincare range for sensitive skin",
+ "Tech gadgets for smart homes: 2024 edition",
+ "Sustainable gardening tools and compost solutions",
+ "Sensitive skin-friendly facial cleansers and toners",
+ "Organic food wraps and storage solutions",
+ "All-natural pet food for dogs with allergies",
+ "Yoga mats made from recycled materials"
+ ],
+ "top_n": 3
+ }'
diff --git a/aio/gpu-8g/speech-to-text.yaml b/aio/gpu-8g/speech-to-text.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..77850d79155439eb5b185a9d322f49910b1fe8c8
--- /dev/null
+++ b/aio/gpu-8g/speech-to-text.yaml
@@ -0,0 +1,18 @@
+name: whisper-1
+backend: whisper
+parameters:
+ model: ggml-whisper-base.bin
+
+usage: |
+ ## example audio file
+ wget --quiet --show-progress -O gb1.ogg https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg
+
+ ## Send the example audio file to the transcriptions endpoint
+ curl http://localhost:8080/v1/audio/transcriptions \
+ -H "Content-Type: multipart/form-data" \
+ -F file="@$PWD/gb1.ogg" -F model="whisper-1"
+
+download_files:
+- filename: "ggml-whisper-base.bin"
+ sha256: "60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe"
+ uri: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.bin"
\ No newline at end of file
diff --git a/aio/gpu-8g/text-to-speech.yaml b/aio/gpu-8g/text-to-speech.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..782f8624a032ca9ca657c4d441a3c521dc58f794
--- /dev/null
+++ b/aio/gpu-8g/text-to-speech.yaml
@@ -0,0 +1,15 @@
+name: tts-1
+download_files:
+ - filename: voice-en-us-amy-low.tar.gz
+ uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-us-amy-low.tar.gz
+backend: piper
+parameters:
+ model: en-us-amy-low.onnx
+
+usage: |
+ To test if this model works as expected, you can use the following curl command:
+
+ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
+ "model":"tts-1",
+ "input": "Hi, this is a test."
+ }'
\ No newline at end of file
diff --git a/aio/gpu-8g/text-to-text.yaml b/aio/gpu-8g/text-to-text.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7d5c991c9ec9ad9468e857819d49211a9aa071f8
--- /dev/null
+++ b/aio/gpu-8g/text-to-text.yaml
@@ -0,0 +1,54 @@
+context_size: 4096
+f16: true
+backend: llama-cpp
+function:
+ capture_llm_results:
+ - (?s)(.*?)
+ grammar:
+ properties_order: name,arguments
+ json_regex_match:
+ - (?s)(.*?)
+ replace_llm_results:
+ - key: (?s)(.*?)
+ value: ""
+mmap: true
+name: gpt-4
+parameters:
+ model: localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf
+stopwords:
+- <|im_end|>
+-
+-
+template:
+ chat: |
+ {{.Input -}}
+ <|im_start|>assistant
+ chat_message: |
+ <|im_start|>{{ .RoleName }}
+ {{ if .FunctionCall -}}
+ Function call:
+ {{ else if eq .RoleName "tool" -}}
+ Function response:
+ {{ end -}}
+ {{ if .Content -}}
+ {{.Content }}
+ {{ end -}}
+ {{ if .FunctionCall -}}
+ {{toJson .FunctionCall}}
+ {{ end -}}<|im_end|>
+ completion: |
+ {{.Input}}
+ function: |
+ <|im_start|>system
+ You are an AI assistant that executes function calls, and these are the tools at your disposal:
+ {{range .Functions}}
+ {'type': 'function', 'function': {'name': '{{.Name}}', 'description': '{{.Description}}', 'parameters': {{toJson .Parameters}} }}
+ {{end}}
+ <|im_end|>
+ {{.Input -}}
+ <|im_start|>assistant
+
+download_files:
+- filename: localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf
+ sha256: 4e7b7fe1d54b881f1ef90799219dc6cc285d29db24f559c8998d1addb35713d4
+ uri: huggingface://mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf
diff --git a/aio/gpu-8g/vad.yaml b/aio/gpu-8g/vad.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b0dc70d75ed17e011a41c293d0fbd30c9e8d24aa
--- /dev/null
+++ b/aio/gpu-8g/vad.yaml
@@ -0,0 +1,8 @@
+backend: silero-vad
+name: silero-vad
+parameters:
+ model: silero-vad.onnx
+download_files:
+- filename: silero-vad.onnx
+ uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
+ sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808
\ No newline at end of file
diff --git a/aio/gpu-8g/vision.yaml b/aio/gpu-8g/vision.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5c2d9930c5d29ade99400f98a48ba4ee0458185e
--- /dev/null
+++ b/aio/gpu-8g/vision.yaml
@@ -0,0 +1,50 @@
+context_size: 4096
+backend: llama-cpp
+f16: true
+mmap: true
+mmproj: minicpm-v-4_5-mmproj-f16.gguf
+name: gpt-4o
+parameters:
+ model: minicpm-v-4_5-Q4_K_M.gguf
+stopwords:
+- <|im_end|>
+-
+-
+- <|endoftext|>
+template:
+ chat: |
+ {{.Input -}}
+ <|im_start|>assistant
+ chat_message: |
+ <|im_start|>{{ .RoleName }}
+ {{ if .FunctionCall -}}
+ Function call:
+ {{ else if eq .RoleName "tool" -}}
+ Function response:
+ {{ end -}}
+ {{ if .Content -}}
+ {{.Content }}
+ {{ end -}}
+ {{ if .FunctionCall -}}
+ {{toJson .FunctionCall}}
+ {{ end -}}<|im_end|>
+ completion: |
+ {{.Input}}
+ function: |
+ <|im_start|>system
+ You are a function calling AI model. You are provided with functions to execute. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:
+ {{range .Functions}}
+ {'type': 'function', 'function': {'name': '{{.Name}}', 'description': '{{.Description}}', 'parameters': {{toJson .Parameters}} }}
+ {{end}}
+ For each function call return a json object with function name and arguments
+ <|im_end|>
+ {{.Input -}}
+ <|im_start|>assistant
+
+download_files:
+- filename: minicpm-v-4_5-Q4_K_M.gguf
+ sha256: c1c3c33100b15b4caf7319acce4e23c0eb0ce1cbd12f70e8d24f05aa67b7512f
+ uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/ggml-model-Q4_K_M.gguf
+- filename: minicpm-v-4_5-mmproj-f16.gguf
+ uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/mmproj-model-f16.gguf
+ sha256: 7a7225a32e8d453aaa3d22d8c579b5bf833c253f784cdb05c99c9a76fd616df8
\ No newline at end of file
diff --git a/aio/intel/embeddings.yaml b/aio/intel/embeddings.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0f88f4511ba50c624a4e57aa62e984260f6387c9
--- /dev/null
+++ b/aio/intel/embeddings.yaml
@@ -0,0 +1,13 @@
+embeddings: true
+name: text-embedding-ada-002
+backend: llama-cpp
+parameters:
+ model: huggingface://bartowski/granite-embedding-107m-multilingual-GGUF/granite-embedding-107m-multilingual-f16.gguf
+
+usage: |
+ You can test this model with curl like this:
+
+ curl http://localhost:8080/embeddings -X POST -H "Content-Type: application/json" -d '{
+ "input": "Your text string goes here",
+ "model": "text-embedding-ada-002"
+ }'
\ No newline at end of file
diff --git a/aio/intel/image-gen.yaml b/aio/intel/image-gen.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..45fe6b62d616a1f2fca7caa69785ab5fd5fac983
--- /dev/null
+++ b/aio/intel/image-gen.yaml
@@ -0,0 +1,20 @@
+name: stablediffusion
+parameters:
+ model: Lykon/dreamshaper-8
+backend: diffusers
+step: 25
+f16: true
+diffusers:
+ pipeline_type: StableDiffusionPipeline
+ cuda: true
+ enable_parameters: "negative_prompt,num_inference_steps"
+ scheduler_type: "k_dpmpp_2m"
+
+usage: |
+ curl http://localhost:8080/v1/images/generations \
+ -H "Content-Type: application/json" \
+ -d '{
+ "prompt": "|",
+ "step": 25,
+ "size": "512x512"
+ }'
\ No newline at end of file
diff --git a/aio/intel/rerank.yaml b/aio/intel/rerank.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..70d386b2b6c47a010bbb54cd70b4542fc1e00262
--- /dev/null
+++ b/aio/intel/rerank.yaml
@@ -0,0 +1,33 @@
+name: jina-reranker-v1-base-en
+reranking: true
+f16: true
+parameters:
+ model: jina-reranker-v1-tiny-en.f16.gguf
+backend: llama-cpp
+download_files:
+ - filename: jina-reranker-v1-tiny-en.f16.gguf
+ sha256: 5f696cf0d0f3d347c4a279eee8270e5918554cdac0ed1f632f2619e4e8341407
+ uri: huggingface://mradermacher/jina-reranker-v1-tiny-en-GGUF/jina-reranker-v1-tiny-en.f16.gguf
+
+usage: |
+ You can test this model with curl like this:
+
+ curl http://localhost:8080/v1/rerank \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "jina-reranker-v1-base-en",
+ "query": "Organic skincare products for sensitive skin",
+ "documents": [
+ "Eco-friendly kitchenware for modern homes",
+ "Biodegradable cleaning supplies for eco-conscious consumers",
+ "Organic cotton baby clothes for sensitive skin",
+ "Natural organic skincare range for sensitive skin",
+ "Tech gadgets for smart homes: 2024 edition",
+ "Sustainable gardening tools and compost solutions",
+ "Sensitive skin-friendly facial cleansers and toners",
+ "Organic food wraps and storage solutions",
+ "All-natural pet food for dogs with allergies",
+ "Yoga mats made from recycled materials"
+ ],
+ "top_n": 3
+ }'
diff --git a/aio/intel/speech-to-text.yaml b/aio/intel/speech-to-text.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..77850d79155439eb5b185a9d322f49910b1fe8c8
--- /dev/null
+++ b/aio/intel/speech-to-text.yaml
@@ -0,0 +1,18 @@
+name: whisper-1
+backend: whisper
+parameters:
+ model: ggml-whisper-base.bin
+
+usage: |
+ ## example audio file
+ wget --quiet --show-progress -O gb1.ogg https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg
+
+ ## Send the example audio file to the transcriptions endpoint
+ curl http://localhost:8080/v1/audio/transcriptions \
+ -H "Content-Type: multipart/form-data" \
+ -F file="@$PWD/gb1.ogg" -F model="whisper-1"
+
+download_files:
+- filename: "ggml-whisper-base.bin"
+ sha256: "60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe"
+ uri: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.bin"
\ No newline at end of file
diff --git a/aio/intel/text-to-speech.yaml b/aio/intel/text-to-speech.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..782f8624a032ca9ca657c4d441a3c521dc58f794
--- /dev/null
+++ b/aio/intel/text-to-speech.yaml
@@ -0,0 +1,15 @@
+name: tts-1
+download_files:
+ - filename: voice-en-us-amy-low.tar.gz
+ uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-us-amy-low.tar.gz
+backend: piper
+parameters:
+ model: en-us-amy-low.onnx
+
+usage: |
+ To test if this model works as expected, you can use the following curl command:
+
+ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
+ "model":"tts-1",
+ "input": "Hi, this is a test."
+ }'
\ No newline at end of file
diff --git a/aio/intel/text-to-text.yaml b/aio/intel/text-to-text.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9fe7c11436e4a90954190c6f32b4bbb709518b8c
--- /dev/null
+++ b/aio/intel/text-to-text.yaml
@@ -0,0 +1,54 @@
+context_size: 4096
+f16: true
+backend: llama-cpp
+function:
+ capture_llm_results:
+ - (?s)(.*?)
+ grammar:
+ properties_order: name,arguments
+ json_regex_match:
+ - (?s)(.*?)
+ replace_llm_results:
+ - key: (?s)(.*?)
+ value: ""
+mmap: true
+name: gpt-4
+parameters:
+ model: localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf
+stopwords:
+- <|im_end|>
+-
+-
+template:
+ chat: |
+ {{.Input -}}
+ <|im_start|>assistant
+ chat_message: |
+ <|im_start|>{{ .RoleName }}
+ {{ if .FunctionCall -}}
+ Function call:
+ {{ else if eq .RoleName "tool" -}}
+ Function response:
+ {{ end -}}
+ {{ if .Content -}}
+ {{.Content }}
+ {{ end -}}
+ {{ if .FunctionCall -}}
+ {{toJson .FunctionCall}}
+ {{ end -}}<|im_end|>
+ completion: |
+ {{.Input}}
+ function: |
+ <|im_start|>system
+ You are an AI assistant that executes function calls, and these are the tools at your disposal:
+ {{range .Functions}}
+ {'type': 'function', 'function': {'name': '{{.Name}}', 'description': '{{.Description}}', 'parameters': {{toJson .Parameters}} }}
+ {{end}}
+ <|im_end|>
+ {{.Input -}}
+ <|im_start|>assistant
+
+download_files:
+- filename: localai-functioncall-phi-4-v0.3-q4_k_m.gguf
+ sha256: 23fee048ded2a6e2e1a7b6bbefa6cbf83068f194caa9552aecbaa00fec8a16d5
+ uri: huggingface://mudler/LocalAI-functioncall-phi-4-v0.3-Q4_K_M-GGUF/localai-functioncall-phi-4-v0.3-q4_k_m.gguf
\ No newline at end of file
diff --git a/aio/intel/vad.yaml b/aio/intel/vad.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b0dc70d75ed17e011a41c293d0fbd30c9e8d24aa
--- /dev/null
+++ b/aio/intel/vad.yaml
@@ -0,0 +1,8 @@
+backend: silero-vad
+name: silero-vad
+parameters:
+ model: silero-vad.onnx
+download_files:
+- filename: silero-vad.onnx
+ uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
+ sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808
\ No newline at end of file
diff --git a/aio/intel/vision.yaml b/aio/intel/vision.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..00b8c0680048f11d832c57e8f9cea6356438e67d
--- /dev/null
+++ b/aio/intel/vision.yaml
@@ -0,0 +1,51 @@
+context_size: 4096
+backend: llama-cpp
+f16: true
+mmap: true
+mmproj: minicpm-v-4_5-mmproj-f16.gguf
+name: gpt-4o
+parameters:
+ model: minicpm-v-4_5-Q4_K_M.gguf
+stopwords:
+- <|im_end|>
+-
+-
+- <|endoftext|>
+template:
+ chat: |
+ {{.Input -}}
+ <|im_start|>assistant
+ chat_message: |
+ <|im_start|>{{ .RoleName }}
+ {{ if .FunctionCall -}}
+ Function call:
+ {{ else if eq .RoleName "tool" -}}
+ Function response:
+ {{ end -}}
+ {{ if .Content -}}
+ {{.Content }}
+ {{ end -}}
+ {{ if .FunctionCall -}}
+ {{toJson .FunctionCall}}
+ {{ end -}}<|im_end|>
+ completion: |
+ {{.Input}}
+ function: |
+ <|im_start|>system
+ You are a function calling AI model. You are provided with functions to execute. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:
+ {{range .Functions}}
+ {'type': 'function', 'function': {'name': '{{.Name}}', 'description': '{{.Description}}', 'parameters': {{toJson .Parameters}} }}
+ {{end}}
+ For each function call return a json object with function name and arguments
+ <|im_end|>
+ {{.Input -}}
+ <|im_start|>assistant
+
+
+download_files:
+- filename: minicpm-v-4_5-Q4_K_M.gguf
+ sha256: c1c3c33100b15b4caf7319acce4e23c0eb0ce1cbd12f70e8d24f05aa67b7512f
+ uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/ggml-model-Q4_K_M.gguf
+- filename: minicpm-v-4_5-mmproj-f16.gguf
+ uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/mmproj-model-f16.gguf
+ sha256: 7a7225a32e8d453aaa3d22d8c579b5bf833c253f784cdb05c99c9a76fd616df8
\ No newline at end of file
diff --git a/backend/Dockerfile.golang b/backend/Dockerfile.golang
new file mode 100644
index 0000000000000000000000000000000000000000..5c7f33caf2f06980e524be132881ec43e18532dd
--- /dev/null
+++ b/backend/Dockerfile.golang
@@ -0,0 +1,192 @@
+ARG BASE_IMAGE=ubuntu:24.04
+
+FROM ${BASE_IMAGE} AS builder
+ARG BACKEND=rerankers
+ARG BUILD_TYPE
+ENV BUILD_TYPE=${BUILD_TYPE}
+ARG CUDA_MAJOR_VERSION
+ARG CUDA_MINOR_VERSION
+ARG SKIP_DRIVERS=false
+ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
+ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
+ENV DEBIAN_FRONTEND=noninteractive
+ARG TARGETARCH
+ARG TARGETVARIANT
+ARG GO_VERSION=1.25.4
+ARG UBUNTU_VERSION=2404
+
+RUN apt-get update && \
+ apt-get install -y --no-install-recommends \
+ build-essential \
+ git ccache \
+ ca-certificates \
+ make cmake wget \
+ curl unzip \
+ libssl-dev && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+
+# Cuda
+ENV PATH=/usr/local/cuda/bin:${PATH}
+
+# HipBLAS requirements
+ENV PATH=/opt/rocm/bin:${PATH}
+
+
+# Vulkan requirements
+RUN < breakdown = 2;
+}
+
+message StatusResponse {
+ enum State {
+ UNINITIALIZED = 0;
+ BUSY = 1;
+ READY = 2;
+ ERROR = -1;
+ }
+ State state = 1;
+ MemoryUsageData memory = 2;
+}
+
+message Message {
+ string role = 1;
+ string content = 2;
+ // Optional fields for OpenAI-compatible message format
+ string name = 3; // Tool name (for tool messages)
+ string tool_call_id = 4; // Tool call ID (for tool messages)
+ string reasoning_content = 5; // Reasoning content (for thinking models)
+ string tool_calls = 6; // Tool calls as JSON string (for assistant messages with tool calls)
+}
+
+message DetectOptions {
+ string src = 1;
+}
+
+message Detection {
+ float x = 1;
+ float y = 2;
+ float width = 3;
+ float height = 4;
+ float confidence = 5;
+ string class_name = 6;
+}
+
+message DetectResponse {
+ repeated Detection Detections = 1;
+}
diff --git a/backend/cpp/grpc/.gitignore b/backend/cpp/grpc/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e533db2e1aa15153259182350374af930808d9c0
--- /dev/null
+++ b/backend/cpp/grpc/.gitignore
@@ -0,0 +1,3 @@
+installed_packages/
+grpc_build/
+grpc_repo/
diff --git a/backend/cpp/grpc/Makefile b/backend/cpp/grpc/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..9189b69ad62a025e79953fec095f5a7c444f5496
--- /dev/null
+++ b/backend/cpp/grpc/Makefile
@@ -0,0 +1,70 @@
+# Basic platform detection
+HOST_SYSTEM = $(shell uname | cut -f 1 -d_)
+SYSTEM ?= $(HOST_SYSTEM)
+
+TAG_LIB_GRPC?=v1.59.0
+GIT_REPO_LIB_GRPC?=https://github.com/grpc/grpc.git
+GIT_CLONE_DEPTH?=1
+
+INSTALLED_PACKAGES=installed_packages
+GRPC_REPO=grpc_repo
+GRPC_BUILD=grpc_build
+
+export CMAKE_ARGS?=
+CMAKE_ARGS+=-DCMAKE_BUILD_TYPE=Release
+CMAKE_ARGS+=-DgRPC_INSTALL=ON
+CMAKE_ARGS+=-DEXECUTABLE_OUTPUT_PATH=../$(INSTALLED_PACKAGES)/grpc/bin
+CMAKE_ARGS+=-DLIBRARY_OUTPUT_PATH=../$(INSTALLED_PACKAGES)/grpc/lib
+CMAKE_ARGS+=-DgRPC_BUILD_TESTS=OFF
+CMAKE_ARGS+=-DgRPC_BUILD_CSHARP_EXT=OFF
+CMAKE_ARGS+=-DgRPC_BUILD_GRPC_CPP_PLUGIN=ON
+CMAKE_ARGS+=-DgRPC_BUILD_GRPC_CSHARP_PLUGIN=OFF
+CMAKE_ARGS+=-DgRPC_BUILD_GRPC_NODE_PLUGIN=OFF
+CMAKE_ARGS+=-DgRPC_BUILD_GRPC_OBJECTIVE_C_PLUGIN=OFF
+CMAKE_ARGS+=-DgRPC_BUILD_GRPC_PHP_PLUGIN=OFF
+CMAKE_ARGS+=-DgRPC_BUILD_GRPC_PYTHON_PLUGIN=ON
+CMAKE_ARGS+=-DgRPC_BUILD_GRPC_RUBY_PLUGIN=OFF
+CMAKE_ARGS+=-Dprotobuf_WITH_ZLIB=ON
+CMAKE_ARGS+=-DRE2_BUILD_TESTING=OFF
+CMAKE_ARGS+=-DCMAKE_INSTALL_PREFIX=../$(INSTALLED_PACKAGES)
+
+# windows need to set OPENSSL_NO_ASM. Results in slower crypto performance but doesn't build otherwise.
+# May be resolvable, but for now its set. More info: https://stackoverflow.com/a/75240504/480673
+ifeq ($(SYSTEM),MSYS)
+CMAKE_ARGS+=-DOPENSSL_NO_ASM=ON
+endif
+ifeq ($(SYSTEM),MINGW64)
+CMAKE_ARGS+=-DOPENSSL_NO_ASM=ON
+endif
+ifeq ($(SYSTEM),MINGW32)
+CMAKE_ARGS+=-DOPENSSL_NO_ASM=ON
+endif
+ifeq ($(SYSTEM),CYGWIN)
+CMAKE_ARGS+=-DOPENSSL_NO_ASM=ON
+endif
+
+$(INSTALLED_PACKAGES): grpc_build
+
+$(GRPC_REPO):
+ mkdir -p $(GRPC_REPO)/grpc
+ cd $(GRPC_REPO)/grpc && \
+ git init && \
+ git remote add origin $(GIT_REPO_LIB_GRPC) && \
+ git fetch origin && \
+ git checkout $(TAG_LIB_GRPC) && \
+ git submodule update --init --recursive --depth 1 --single-branch
+
+$(GRPC_BUILD): $(GRPC_REPO)
+ mkdir -p $(GRPC_BUILD)
+ cd $(GRPC_BUILD) && cmake $(CMAKE_ARGS) ../$(GRPC_REPO)/grpc && cmake --build . && cmake --build . --target install
+
+build: $(INSTALLED_PACKAGES)
+
+rebuild:
+ rm -rf grpc_build
+ $(MAKE) grpc_build
+
+clean:
+ rm -rf grpc_build
+ rm -rf grpc_repo
+ rm -rf installed_packages
diff --git a/backend/cpp/llama-cpp/CMakeLists.txt b/backend/cpp/llama-cpp/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5984619755323c13d115441fa01ad75d586faf20
--- /dev/null
+++ b/backend/cpp/llama-cpp/CMakeLists.txt
@@ -0,0 +1,73 @@
+set(TARGET grpc-server)
+set(CMAKE_CXX_STANDARD 17)
+cmake_minimum_required(VERSION 3.15)
+set(TARGET grpc-server)
+set(_PROTOBUF_LIBPROTOBUF libprotobuf)
+set(_REFLECTION grpc++_reflection)
+
+if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+ # Set correct Homebrew install folder for Apple Silicon and Intel Macs
+ if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "arm64")
+ set(HOMEBREW_DEFAULT_PREFIX "/opt/homebrew")
+ else()
+ set(HOMEBREW_DEFAULT_PREFIX "/usr/local")
+ endif()
+
+ link_directories("${HOMEBREW_DEFAULT_PREFIX}/lib")
+ include_directories("${HOMEBREW_DEFAULT_PREFIX}/include")
+endif()
+
+find_package(absl CONFIG REQUIRED)
+find_package(Protobuf CONFIG REQUIRED)
+find_package(gRPC CONFIG REQUIRED)
+
+find_program(_PROTOBUF_PROTOC protoc)
+set(_GRPC_GRPCPP grpc++)
+find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
+
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
+include_directories(${Protobuf_INCLUDE_DIRS})
+
+message(STATUS "Using protobuf version ${Protobuf_VERSION} | Protobuf_INCLUDE_DIRS: ${Protobuf_INCLUDE_DIRS} | CMAKE_CURRENT_BINARY_DIR: ${CMAKE_CURRENT_BINARY_DIR}")
+
+# Proto file
+get_filename_component(hw_proto "../../../../../../backend/backend.proto" ABSOLUTE)
+get_filename_component(hw_proto_path "${hw_proto}" PATH)
+
+# Generated sources
+set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/backend.pb.cc")
+set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/backend.pb.h")
+set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/backend.grpc.pb.cc")
+set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/backend.grpc.pb.h")
+
+add_custom_command(
+ OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}"
+ COMMAND ${_PROTOBUF_PROTOC}
+ ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
+ --cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
+ -I "${hw_proto_path}"
+ --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
+ "${hw_proto}"
+ DEPENDS "${hw_proto}")
+
+# hw_grpc_proto
+add_library(hw_grpc_proto
+ ${hw_grpc_srcs}
+ ${hw_grpc_hdrs}
+ ${hw_proto_srcs}
+ ${hw_proto_hdrs} )
+
+add_executable(${TARGET} grpc-server.cpp json.hpp httplib.h)
+
+target_include_directories(${TARGET} PRIVATE ../llava)
+target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
+
+target_link_libraries(${TARGET} PRIVATE common llama mtmd ${CMAKE_THREAD_LIBS_INIT} absl::flags hw_grpc_proto
+ absl::flags_parse
+ gRPC::${_REFLECTION}
+ gRPC::${_GRPC_GRPCPP}
+ protobuf::${_PROTOBUF_LIBPROTOBUF})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+if(TARGET BUILD_INFO)
+ add_dependencies(${TARGET} BUILD_INFO)
+endif()
diff --git a/backend/cpp/llama-cpp/Makefile b/backend/cpp/llama-cpp/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..fdfb9e0017124d568a7fe9ec74ff10a42c71c18a
--- /dev/null
+++ b/backend/cpp/llama-cpp/Makefile
@@ -0,0 +1,167 @@
+
+LLAMA_VERSION?=d98b548120eecf98f0f6eaa1ba7e29b3afda9f2e
+LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
+
+CMAKE_ARGS?=
+BUILD_TYPE?=
+NATIVE?=false
+ONEAPI_VARS?=/opt/intel/oneapi/setvars.sh
+TARGET?=--target grpc-server
+JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 1)
+ARCH?=$(shell uname -m)
+
+# Disable Shared libs as we are linking on static gRPC and we can't mix shared and static
+CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF -DLLAMA_CURL=OFF
+
+CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
+ifeq ($(NATIVE),false)
+ CMAKE_ARGS+=-DGGML_NATIVE=OFF -DLLAMA_OPENSSL=OFF
+endif
+# If build type is cublas, then we set -DGGML_CUDA=ON to CMAKE_ARGS automatically
+ifeq ($(BUILD_TYPE),cublas)
+ CMAKE_ARGS+=-DGGML_CUDA=ON
+# If build type is openblas then we set -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
+# to CMAKE_ARGS automatically
+else ifeq ($(BUILD_TYPE),openblas)
+ CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
+# If build type is clblas (openCL) we set -DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
+else ifeq ($(BUILD_TYPE),clblas)
+ CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
+# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
+else ifeq ($(BUILD_TYPE),hipblas)
+ ROCM_HOME ?= /opt/rocm
+ ROCM_PATH ?= /opt/rocm
+ export CXX=$(ROCM_HOME)/llvm/bin/clang++
+ export CC=$(ROCM_HOME)/llvm/bin/clang
+ AMDGPU_TARGETS?=gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201
+ CMAKE_ARGS+=-DGGML_HIP=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
+else ifeq ($(BUILD_TYPE),vulkan)
+ CMAKE_ARGS+=-DGGML_VULKAN=1
+else ifeq ($(OS),Darwin)
+ ifeq ($(BUILD_TYPE),)
+ BUILD_TYPE=metal
+ endif
+ ifneq ($(BUILD_TYPE),metal)
+ CMAKE_ARGS+=-DGGML_METAL=OFF
+ else
+ CMAKE_ARGS+=-DGGML_METAL=ON
+ CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
+ CMAKE_ARGS+=-DGGML_METAL_USE_BF16=ON
+ CMAKE_ARGS+=-DGGML_OPENMP=OFF
+ endif
+ TARGET+=--target ggml-metal
+endif
+
+ifeq ($(BUILD_TYPE),sycl_f16)
+ CMAKE_ARGS+=-DGGML_SYCL=ON \
+ -DCMAKE_C_COMPILER=icx \
+ -DCMAKE_CXX_COMPILER=icpx \
+ -DCMAKE_CXX_FLAGS="-fsycl" \
+ -DGGML_SYCL_F16=ON
+endif
+
+ifeq ($(BUILD_TYPE),sycl_f32)
+ CMAKE_ARGS+=-DGGML_SYCL=ON \
+ -DCMAKE_C_COMPILER=icx \
+ -DCMAKE_CXX_COMPILER=icpx \
+ -DCMAKE_CXX_FLAGS="-fsycl"
+endif
+
+INSTALLED_PACKAGES=$(CURDIR)/../grpc/installed_packages
+INSTALLED_LIB_CMAKE=$(INSTALLED_PACKAGES)/lib/cmake
+ADDED_CMAKE_ARGS=-Dabsl_DIR=${INSTALLED_LIB_CMAKE}/absl \
+ -DProtobuf_DIR=${INSTALLED_LIB_CMAKE}/protobuf \
+ -Dutf8_range_DIR=${INSTALLED_LIB_CMAKE}/utf8_range \
+ -DgRPC_DIR=${INSTALLED_LIB_CMAKE}/grpc \
+ -DCMAKE_CXX_STANDARD_INCLUDE_DIRECTORIES=${INSTALLED_PACKAGES}/include
+build-llama-cpp-grpc-server:
+# Conditionally build grpc for the llama backend to use if needed
+ifdef BUILD_GRPC_FOR_BACKEND_LLAMA
+ $(MAKE) -C ../../grpc build
+ _PROTOBUF_PROTOC=${INSTALLED_PACKAGES}/bin/proto \
+ _GRPC_CPP_PLUGIN_EXECUTABLE=${INSTALLED_PACKAGES}/bin/grpc_cpp_plugin \
+ PATH="${INSTALLED_PACKAGES}/bin:${PATH}" \
+ CMAKE_ARGS="${CMAKE_ARGS} ${ADDED_CMAKE_ARGS}" \
+ LLAMA_VERSION=$(LLAMA_VERSION) \
+ $(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../$(VARIANT) grpc-server
+else
+ echo "BUILD_GRPC_FOR_BACKEND_LLAMA is not defined."
+ LLAMA_VERSION=$(LLAMA_VERSION) $(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../$(VARIANT) grpc-server
+endif
+
+llama-cpp-avx2: llama.cpp
+ cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx2-build
+ $(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx2-build purge
+ $(info ${GREEN}I llama-cpp build info:avx2${RESET})
+ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) VARIANT="llama-cpp-avx2-build" build-llama-cpp-grpc-server
+ cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx2-build/grpc-server llama-cpp-avx2
+
+llama-cpp-avx512: llama.cpp
+ cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx512-build
+ $(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx512-build purge
+ $(info ${GREEN}I llama-cpp build info:avx512${RESET})
+ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) VARIANT="llama-cpp-avx512-build" build-llama-cpp-grpc-server
+ cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx512-build/grpc-server llama-cpp-avx512
+
+llama-cpp-avx: llama.cpp
+ cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx-build
+ $(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx-build purge
+ $(info ${GREEN}I llama-cpp build info:avx${RESET})
+ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) VARIANT="llama-cpp-avx-build" build-llama-cpp-grpc-server
+ cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx-build/grpc-server llama-cpp-avx
+
+llama-cpp-fallback: llama.cpp
+ cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-fallback-build
+ $(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-fallback-build purge
+ $(info ${GREEN}I llama-cpp build info:fallback${RESET})
+ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) VARIANT="llama-cpp-fallback-build" build-llama-cpp-grpc-server
+ cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-fallback-build/grpc-server llama-cpp-fallback
+
+llama-cpp-grpc: llama.cpp
+ cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build
+ $(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build purge
+ $(info ${GREEN}I llama-cpp build info:grpc${RESET})
+ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_RPC=ON -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" TARGET="--target grpc-server --target rpc-server" $(MAKE) VARIANT="llama-cpp-grpc-build" build-llama-cpp-grpc-server
+ cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build/grpc-server llama-cpp-grpc
+
+llama-cpp-rpc-server: llama-cpp-grpc
+ cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build/llama.cpp/build/bin/rpc-server llama-cpp-rpc-server
+
+llama.cpp:
+ mkdir -p llama.cpp
+ cd llama.cpp && \
+ git init && \
+ git remote add origin $(LLAMA_REPO) && \
+ git fetch origin && \
+ git checkout -b build $(LLAMA_VERSION) && \
+ git submodule update --init --recursive --depth 1 --single-branch
+
+llama.cpp/tools/grpc-server: llama.cpp
+ mkdir -p llama.cpp/tools/grpc-server
+ bash prepare.sh
+
+rebuild:
+ bash prepare.sh
+ rm -rf grpc-server
+ $(MAKE) grpc-server
+
+package:
+ bash package.sh
+
+purge:
+ rm -rf llama.cpp/build
+ rm -rf llama.cpp/tools/grpc-server
+ rm -rf grpc-server
+
+clean: purge
+ rm -rf llama.cpp
+
+grpc-server: llama.cpp llama.cpp/tools/grpc-server
+ @echo "Building grpc-server with $(BUILD_TYPE) build type and $(CMAKE_ARGS)"
+ifneq (,$(findstring sycl,$(BUILD_TYPE)))
+ +bash -c "source $(ONEAPI_VARS); \
+ cd llama.cpp && mkdir -p build && cd build && cmake .. $(CMAKE_ARGS) && cmake --build . --config Release -j $(JOBS) $(TARGET)"
+else
+ +cd llama.cpp && mkdir -p build && cd build && cmake .. $(CMAKE_ARGS) && cmake --build . --config Release -j $(JOBS) $(TARGET)
+endif
+ cp llama.cpp/build/bin/grpc-server .
diff --git a/backend/cpp/llama-cpp/grpc-server.cpp b/backend/cpp/llama-cpp/grpc-server.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..116454ccd2316be17cf484c2a94db4c9b27035b3
--- /dev/null
+++ b/backend/cpp/llama-cpp/grpc-server.cpp
@@ -0,0 +1,2553 @@
+// llama.cpp gRPC C++ backend server
+//
+// Ettore Di Giacinto and llama.cpp authors
+//
+// This is a gRPC server for llama.cpp compatible with the LocalAI proto
+// Note: this is a re-adaptation of the original llama.cpp example/server.cpp for HTTP (https://github.com/ggerganov/llama.cpp/tree/master/examples/server),
+// but modified to work with gRPC
+//
+
+#include "server-task.cpp"
+#include "server-queue.cpp"
+#include "server-common.cpp"
+#include "server-context.cpp"
+
+// LocalAI
+
+#include "backend.pb.h"
+#include "backend.grpc.pb.h"
+#include "common.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#if defined(_WIN32)
+#include
+#endif
+
+
+using grpc::Server;
+using grpc::ServerBuilder;
+using grpc::ServerContext;
+using grpc::Status;
+// END LocalAI
+
+
+/////////////////////////////////
+////////////////////////////////
+//////// LOCALAI code starts below here
+/////////////////////////////////
+////////////////////////////////
+
+bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model
+
+static std::function shutdown_handler;
+static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
+
+static inline void signal_handler(int signal) {
+ if (is_terminating.test_and_set()) {
+ // in case it hangs, we can force terminate the server by hitting Ctrl+C twice
+ // this is for better developer experience, we can remove when the server is stable enough
+ fprintf(stderr, "Received second interrupt, terminating immediately.\n");
+ exit(1);
+ }
+
+ shutdown_handler(signal);
+}
+
+// Forward declarations
+static void start_llama_server(server_context& ctx_server);
+static json parse_options(bool streaming, const backend::PredictOptions* predict, const common_params& params_base, llama_context* ctx);
+static ggml_type kv_cache_type_from_str(const std::string & s);
+static std::string get_all_kv_cache_types();
+static void add_rpc_devices(std::string servers);
+static void params_parse(server_context& ctx_server, const backend::ModelOptions* request, common_params & params);
+
+static void start_llama_server(server_context& ctx_server) {
+
+ LOG_INF("%s: starting llama server\n", __func__);
+
+ LOG_INF("%s: waiting for model to be loaded\n", __func__);
+ // Wait for model to be loaded first
+ while (!loaded_model) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+
+ LOG_INF("%s: model loaded\n", __func__);
+
+ // print sample chat example to make it clear which template is used
+ // LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
+ // common_chat_templates_source(ctx_server.impl->chat_templates.get()),
+ // common_chat_format_example(ctx_server.impl->chat_templates.get(), ctx_server.impl->params_base.use_jinja).c_str(), ctx_server.impl->params_base.default_template_kwargs);
+
+ // Keep the chat templates initialized in load_model() so they can be used when UseTokenizerTemplate is enabled
+ // Templates will only be used conditionally in Predict/PredictStream when UseTokenizerTemplate is true and Messages are provided
+
+ shutdown_handler = [&](int) {
+ // this will unblock start_loop()
+ ctx_server.terminate();
+ };
+
+ // TODO: refactor in common/console
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+ struct sigaction sigint_action;
+ sigint_action.sa_handler = signal_handler;
+ sigemptyset (&sigint_action.sa_mask);
+ sigint_action.sa_flags = 0;
+ sigaction(SIGINT, &sigint_action, NULL);
+ sigaction(SIGTERM, &sigint_action, NULL);
+#elif defined (_WIN32)
+ auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
+ return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
+ };
+ SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true);
+#endif
+
+ // this call blocks the main thread until ctx_server.terminate() is called
+ ctx_server.start_loop();
+}
+
+json parse_options(bool streaming, const backend::PredictOptions* predict, const common_params& params_base, llama_context* ctx)
+{
+
+ // Create now a json data from the prediction options instead
+ //
+ json data;
+ data["stream"] = streaming;
+ data["cache_prompt"] = predict->promptcacheall();
+ data["n_predict"] = predict->tokens() == 0 ? -1 : predict->tokens();
+ data["top_k"] = predict->topk();
+ data["top_p"] = predict->topp();
+ data["typical_p"] = predict->typicalp();
+ data["temperature"] = predict->temperature();
+ data["repeat_last_n"] = predict->repeat();
+ data["repeat_penalty"] = predict->penalty();
+ data["frequency_penalty"] = predict->frequencypenalty();
+ data["presence_penalty"] = predict->presencepenalty();
+ data["mirostat"] = predict->mirostat();
+ data["mirostat_tau"] = predict->mirostattau();
+ data["mirostat_eta"] = predict->mirostateta();
+ data["n_keep"] = predict->nkeep();
+ data["seed"] = predict->seed();
+
+
+ std::string grammar_str = predict->grammar();
+
+
+
+ if (!grammar_str.empty()) {
+ data["grammar"] = grammar_str;
+ SRV_INF("Using grammar: %s\n", grammar_str.c_str());
+ }
+
+ // Only set prompt if UseTokenizerTemplate is false or if no Messages are provided
+ // When UseTokenizerTemplate is true and Messages are provided, prompt will be set via chat templates in Predict/PredictStream
+ if (!predict->usetokenizertemplate() || predict->messages_size() == 0) {
+ data["prompt"] = predict->prompt();
+ }
+
+ // Extract tools and tool_choice from proto and add to data JSON
+ SRV_INF("[TOOLS DEBUG] parse_options: Checking for tools in proto, tools().empty()=%d, tools().size()=%zu\n",
+ predict->tools().empty() ? 1 : 0, predict->tools().size());
+ if (!predict->tools().empty()) {
+ SRV_INF("[TOOLS DEBUG] parse_options: Tools string from proto (first 500 chars): %s\n",
+ predict->tools().substr(0, std::min(500, predict->tools().size())).c_str());
+ try {
+ // Parse tools JSON string and add to data
+ json tools_json = json::parse(predict->tools());
+ data["tools"] = tools_json;
+ SRV_INF("Extracted tools from proto: %s\n", predict->tools().c_str());
+ // Debug: Log tools count and names
+ if (tools_json.is_array()) {
+ SRV_INF("[TOOLS DEBUG] parse_options: Successfully parsed %zu tools from Go layer\n", tools_json.size());
+ for (size_t i = 0; i < tools_json.size(); i++) {
+ if (tools_json[i].contains("function") && tools_json[i]["function"].contains("name")) {
+ SRV_INF("[TOOLS DEBUG] parse_options: Tool %zu: %s\n", i, tools_json[i]["function"]["name"].get().c_str());
+ } else if (tools_json[i].contains("name")) {
+ SRV_INF("[TOOLS DEBUG] parse_options: Tool %zu: %s\n", i, tools_json[i]["name"].get().c_str());
+ }
+ }
+ } else {
+ SRV_WRN("[TOOLS DEBUG] parse_options: Parsed tools JSON is not an array: %s\n", tools_json.dump().c_str());
+ }
+ } catch (const json::parse_error& e) {
+ SRV_WRN("Failed to parse tools JSON from proto: %s\n", e.what());
+ SRV_WRN("[TOOLS DEBUG] parse_options: Tools string that failed to parse: %s\n", predict->tools().c_str());
+ }
+ } else {
+ SRV_INF("%s", "[TOOLS DEBUG] parse_options: No tools received from Go layer (predict->tools() is empty)\n");
+ }
+
+ // Debug: Verify tools are in data after extraction
+ if (data.contains("tools")) {
+ SRV_INF("[TOOLS DEBUG] parse_options: Tools successfully added to data, count: %zu\n",
+ data["tools"].is_array() ? data["tools"].size() : 0);
+ } else {
+ SRV_INF("%s", "[TOOLS DEBUG] parse_options: WARNING - Tools NOT in data after extraction!\n");
+ }
+ if (!predict->toolchoice().empty()) {
+ try {
+ // Parse tool_choice JSON string
+ json tool_choice_json = json::parse(predict->toolchoice());
+ // tool_choice can be a string ("auto", "none", "required") or an object
+ // Store it as-is (string or object) so we can convert object to "required" later when adding to body_json
+ if (tool_choice_json.is_string()) {
+ data["tool_choice"] = tool_choice_json.get();
+ SRV_DBG("[TOOLS DEBUG] Received tool_choice from Go layer: %s\n", tool_choice_json.get().c_str());
+ } else {
+ // Store object as-is so we can detect it later and convert to "required"
+ data["tool_choice"] = tool_choice_json;
+ SRV_DBG("[TOOLS DEBUG] Received tool_choice object from Go layer: %s\n", tool_choice_json.dump().c_str());
+ }
+ SRV_INF("Extracted tool_choice from proto: %s\n", predict->toolchoice().c_str());
+ } catch (const json::parse_error& e) {
+ // If parsing fails, treat as string
+ data["tool_choice"] = predict->toolchoice();
+ SRV_INF("Extracted tool_choice as string: %s\n", predict->toolchoice().c_str());
+ }
+ }
+
+ // Extract logprobs and top_logprobs from proto and add to JSON data
+ // Following server.cpp pattern: logprobs maps to n_probs when provided
+ if (predict->logprobs() > 0) {
+ data["logprobs"] = predict->logprobs();
+ // Map logprobs to n_probs (following server.cpp line 369 pattern)
+ // n_probs will be set by params_from_json_cmpl if logprobs is provided
+ data["n_probs"] = predict->logprobs();
+ SRV_INF("Using logprobs: %d\n", predict->logprobs());
+ }
+ if (predict->toplogprobs() > 0) {
+ data["top_logprobs"] = predict->toplogprobs();
+ SRV_INF("Using top_logprobs: %d\n", predict->toplogprobs());
+ }
+
+ // Extract logit_bias from proto and add to JSON data
+ if (!predict->logitbias().empty()) {
+ try {
+ // Parse logit_bias JSON string from proto
+ json logit_bias_json = json::parse(predict->logitbias());
+ // Add to data - llama.cpp server expects it as an object (map)
+ data["logit_bias"] = logit_bias_json;
+ SRV_INF("Using logit_bias: %s\n", predict->logitbias().c_str());
+ } catch (const json::parse_error& e) {
+ SRV_ERR("Failed to parse logit_bias JSON from proto: %s\n", e.what());
+ }
+ }
+
+ data["ignore_eos"] = predict->ignoreeos();
+ data["embeddings"] = predict->embeddings();
+
+ // Add the correlationid to json data
+ data["correlation_id"] = predict->correlationid();
+
+ // for each image in the request, add the image data
+ //
+ for (int i = 0; i < predict->images_size(); i++) {
+ data["image_data"].push_back(json
+ {
+ {"id", i},
+ {"data", predict->images(i)},
+ });
+ }
+
+ // for each audio in the request, add the audio data
+ for (int i = 0; i < predict->audios_size(); i++) {
+ data["audio_data"].push_back(json
+ {
+ {"id", i},
+ {"data", predict->audios(i)},
+ });
+ }
+
+ data["stop"] = predict->stopprompts();
+ // data["n_probs"] = predict->nprobs();
+ //TODO: images,
+
+ // Serialize grammar triggers from server context to JSON array
+ if (!params_base.sampling.grammar_triggers.empty()) {
+ json grammar_triggers = json::array();
+ for (const auto& trigger : params_base.sampling.grammar_triggers) {
+ json trigger_json;
+ trigger_json["value"] = trigger.value;
+ // Always serialize as WORD type since upstream converts WORD to TOKEN internally
+ trigger_json["type"] = static_cast(COMMON_GRAMMAR_TRIGGER_TYPE_WORD);
+ grammar_triggers.push_back(trigger_json);
+ }
+ data["grammar_triggers"] = grammar_triggers;
+ }
+
+ // Serialize preserved tokens from server context to JSON array
+ if (!params_base.sampling.preserved_tokens.empty()) {
+ json preserved_tokens = json::array();
+ for (const auto& token : params_base.sampling.preserved_tokens) {
+ preserved_tokens.push_back(common_token_to_piece(ctx, token));
+ }
+ data["preserved_tokens"] = preserved_tokens;
+ }
+
+ return data;
+}
+
+
+const std::vector kv_cache_types = {
+ GGML_TYPE_F32,
+ GGML_TYPE_F16,
+ GGML_TYPE_BF16,
+ GGML_TYPE_Q8_0,
+ GGML_TYPE_Q4_0,
+ GGML_TYPE_Q4_1,
+ GGML_TYPE_IQ4_NL,
+ GGML_TYPE_Q5_0,
+ GGML_TYPE_Q5_1,
+};
+
+static ggml_type kv_cache_type_from_str(const std::string & s) {
+ for (const auto & type : kv_cache_types) {
+ if (ggml_type_name(type) == s) {
+ return type;
+ }
+ }
+ throw std::runtime_error("Unsupported cache type: " + s);
+}
+
+static std::string get_all_kv_cache_types() {
+ std::ostringstream msg;
+ for (const auto & type : kv_cache_types) {
+ msg << ggml_type_name(type) << (&type == &kv_cache_types.back() ? "" : ", ");
+ }
+ return msg.str();
+}
+
+// Adds an RPC server
+// Description here: https://github.com/ggml-org/llama.cpp/blob/master/tools/rpc/README.md
+static void add_rpc_devices(std::string servers) {
+ auto rpc_servers = string_split(servers, ',');
+ // Trim whitespace to allow more flexible configurations, such as having entries on separate lines.
+ for (std::string & server : rpc_servers)
+ {
+ server.erase(0, server.find_first_not_of(" \t\n\r"));
+ server.erase(server.find_last_not_of(" \t\n\r") + 1);
+ }
+ if (rpc_servers.empty()) {
+ throw std::invalid_argument("no RPC servers specified");
+ }
+ ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
+ if (!rpc_reg) {
+ throw std::invalid_argument("failed to find RPC backend");
+ }
+ typedef ggml_backend_reg_t (*ggml_backend_rpc_add_server_t)(const char * endpoint);
+ ggml_backend_rpc_add_server_t ggml_backend_rpc_add_server_fn = (ggml_backend_rpc_add_server_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server");
+ if (!ggml_backend_rpc_add_server_fn) {
+ throw std::invalid_argument("failed to find RPC add server function");
+ }
+ for (const auto & server : rpc_servers) {
+ ggml_backend_reg_t reg = ggml_backend_rpc_add_server_fn(server.c_str());
+ ggml_backend_register(reg);
+ }
+}
+
+static void params_parse(server_context& /*ctx_server*/, const backend::ModelOptions* request,
+ common_params & params) {
+
+ // this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809
+
+ params.model.path = request->modelfile();
+ if (!request->mmproj().empty()) {
+ params.mmproj.path = request->mmproj();
+ }
+ // params.model_alias ??
+ params.model_alias = request->modelfile();
+ if (!request->cachetypekey().empty()) {
+ params.cache_type_k = kv_cache_type_from_str(request->cachetypekey());
+ }
+ if (!request->cachetypevalue().empty()) {
+ params.cache_type_v = kv_cache_type_from_str(request->cachetypevalue());
+ }
+ params.n_ctx = request->contextsize();
+ //params.memory_f16 = request->f16memory();
+ params.cpuparams.n_threads = request->threads();
+ params.n_gpu_layers = request->ngpulayers();
+ params.n_batch = request->nbatch();
+ //params.verbosity = INT_MAX;
+ // Enable all debug logs by setting verbosity threshold to maximum
+ //common_log_set_verbosity_thold(INT_MAX);
+ params.n_ubatch = request->nbatch(); // fixes issue with reranking models being limited to 512 tokens (the default n_ubatch size); allows for setting the maximum input amount of tokens thereby avoiding this error "input is too large to process. increase the physical batch size"
+
+ // Initialize ctx_shift to false by default (can be overridden by options)
+ params.ctx_shift = false;
+ // Initialize cache_ram_mib to -1 by default (no limit, can be overridden by options)
+ params.cache_ram_mib = -1;
+ // Initialize n_parallel to 1 by default (can be overridden by options)
+ params.n_parallel = 1;
+ // Initialize grpc_servers to empty (can be overridden by options)
+ std::string grpc_servers_option = "";
+
+ // Initialize fit_params options (can be overridden by options)
+ // fit_params: whether to auto-adjust params to fit device memory (default: true as in llama.cpp)
+ params.fit_params = true;
+ // fit_params_target: target margin per device in bytes (default: 1GB per device)
+ // Initialize as vector with default value for all devices
+ params.fit_params_target = std::vector(llama_max_devices(), 1024 * 1024 * 1024);
+ // fit_params_min_ctx: minimum context size for fit (default: 4096)
+ params.fit_params_min_ctx = 4096;
+
+ // Initialize additional server options (can be overridden by options)
+ // n_cache_reuse: min chunk size for KV cache reuse via shifting (default: 0 = disabled)
+ params.n_cache_reuse = 0;
+ // slot_prompt_similarity: threshold for slot prompt matching (default: 0.1)
+ params.slot_prompt_similarity = 0.1f;
+ // swa_full: use full-size SWA cache (default: false)
+ params.swa_full = false;
+ // cont_batching: continuous batching (default: true, auto-enabled when n_parallel > 1)
+ params.cont_batching = true;
+ // check_tensors: validate tensor data (default: false)
+ params.check_tensors = false;
+ // warmup: enable warmup run (default: true)
+ params.warmup = true;
+ // no_op_offload: disable host tensor op offload (default: false)
+ params.no_op_offload = false;
+ // kv_unified: enable unified KV cache (default: false)
+ params.kv_unified = false;
+ // n_ctx_checkpoints: max context checkpoints per slot (default: 8)
+ params.n_ctx_checkpoints = 8;
+
+ // decode options. Options are in form optname:optvale, or if booleans only optname.
+ for (int i = 0; i < request->options_size(); i++) {
+ std::string opt = request->options(i);
+ std::vector opt_buf(opt.begin(), opt.end());
+ opt_buf.push_back('\0');
+ char *optname = strtok(opt_buf.data(), ":");
+ char *optval = strtok(NULL, ":");
+ std::string optval_str = (optval == NULL) ? "true" : optval;
+
+ if (!strcmp(optname, "context_shift")) {
+ if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
+ params.ctx_shift = true;
+ } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
+ params.ctx_shift = false;
+ }
+ } else if (!strcmp(optname, "use_jinja") || !strcmp(optname, "jinja")) {
+ if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
+ params.use_jinja = true;
+ } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
+ params.use_jinja = false;
+ }
+ } else if (!strcmp(optname, "cache_ram")) {
+ if (optval != NULL) {
+ try {
+ params.cache_ram_mib = std::stoi(optval_str);
+ } catch (const std::exception& e) {
+ // If conversion fails, keep default value (-1)
+ }
+ }
+ } else if (!strcmp(optname, "parallel") || !strcmp(optname, "n_parallel")) {
+ if (optval != NULL) {
+ try {
+ params.n_parallel = std::stoi(optval_str);
+ if (params.n_parallel > 1) {
+ params.cont_batching = true;
+ }
+ } catch (const std::exception& e) {
+ // If conversion fails, keep default value (1)
+ }
+ }
+ } else if (!strcmp(optname, "grpc_servers") || !strcmp(optname, "rpc_servers")) {
+ if (optval != NULL) {
+ grpc_servers_option = optval_str;
+ }
+ } else if (!strcmp(optname, "fit_params") || !strcmp(optname, "fit")) {
+ if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
+ params.fit_params = true;
+ } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
+ params.fit_params = false;
+ }
+ } else if (!strcmp(optname, "fit_params_target") || !strcmp(optname, "fit_target")) {
+ if (optval != NULL) {
+ try {
+ // Value is in MiB, can be comma-separated list for multiple devices
+ // Single value is broadcast across all devices
+ std::string arg_next = optval_str;
+ const std::regex regex{ R"([,/]+)" };
+ std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
+ std::vector split_arg{ it, {} };
+ if (split_arg.size() >= llama_max_devices()) {
+ // Too many values provided
+ continue;
+ }
+ if (split_arg.size() == 1) {
+ // Single value: broadcast to all devices
+ size_t value_mib = std::stoul(split_arg[0]);
+ std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), value_mib * 1024 * 1024);
+ } else {
+ // Multiple values: set per device
+ for (size_t i = 0; i < split_arg.size() && i < params.fit_params_target.size(); i++) {
+ params.fit_params_target[i] = std::stoul(split_arg[i]) * 1024 * 1024;
+ }
+ }
+ } catch (const std::exception& e) {
+ // If conversion fails, keep default value (1GB per device)
+ }
+ }
+ } else if (!strcmp(optname, "fit_params_min_ctx") || !strcmp(optname, "fit_ctx")) {
+ if (optval != NULL) {
+ try {
+ params.fit_params_min_ctx = std::stoi(optval_str);
+ } catch (const std::exception& e) {
+ // If conversion fails, keep default value (4096)
+ }
+ }
+ } else if (!strcmp(optname, "n_cache_reuse") || !strcmp(optname, "cache_reuse")) {
+ if (optval != NULL) {
+ try {
+ params.n_cache_reuse = std::stoi(optval_str);
+ } catch (const std::exception& e) {
+ // If conversion fails, keep default value (0)
+ }
+ }
+ } else if (!strcmp(optname, "slot_prompt_similarity") || !strcmp(optname, "sps")) {
+ if (optval != NULL) {
+ try {
+ params.slot_prompt_similarity = std::stof(optval_str);
+ } catch (const std::exception& e) {
+ // If conversion fails, keep default value (0.1)
+ }
+ }
+ } else if (!strcmp(optname, "swa_full")) {
+ if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
+ params.swa_full = true;
+ } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
+ params.swa_full = false;
+ }
+ } else if (!strcmp(optname, "cont_batching") || !strcmp(optname, "continuous_batching")) {
+ if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
+ params.cont_batching = true;
+ } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
+ params.cont_batching = false;
+ }
+ } else if (!strcmp(optname, "check_tensors")) {
+ if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
+ params.check_tensors = true;
+ } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
+ params.check_tensors = false;
+ }
+ } else if (!strcmp(optname, "warmup")) {
+ if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
+ params.warmup = true;
+ } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
+ params.warmup = false;
+ }
+ } else if (!strcmp(optname, "no_op_offload")) {
+ if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
+ params.no_op_offload = true;
+ } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
+ params.no_op_offload = false;
+ }
+ } else if (!strcmp(optname, "kv_unified") || !strcmp(optname, "unified_kv")) {
+ if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
+ params.kv_unified = true;
+ } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
+ params.kv_unified = false;
+ }
+ } else if (!strcmp(optname, "n_ctx_checkpoints") || !strcmp(optname, "ctx_checkpoints")) {
+ if (optval != NULL) {
+ try {
+ params.n_ctx_checkpoints = std::stoi(optval_str);
+ } catch (const std::exception& e) {
+ // If conversion fails, keep default value (8)
+ }
+ }
+ }
+ }
+
+ // Set params.n_parallel from environment variable if not set via options (fallback)
+ if (params.n_parallel == 1) {
+ const char *env_parallel = std::getenv("LLAMACPP_PARALLEL");
+ if (env_parallel != NULL) {
+ try {
+ params.n_parallel = std::stoi(env_parallel);
+ if (params.n_parallel > 1) {
+ params.cont_batching = true;
+ }
+ } catch (const std::exception& e) {
+ // If conversion fails, keep default value (1)
+ }
+ }
+ }
+
+ // Add RPC devices from option or environment variable (fallback)
+ if (!grpc_servers_option.empty()) {
+ add_rpc_devices(grpc_servers_option);
+ } else {
+ const char *llama_grpc_servers = std::getenv("LLAMACPP_GRPC_SERVERS");
+ if (llama_grpc_servers != NULL) {
+ add_rpc_devices(std::string(llama_grpc_servers));
+ }
+ }
+
+ // Add kv_overrides
+ if (request->overrides_size() > 0) {
+ for (int i = 0; i < request->overrides_size(); i++) {
+ string_parse_kv_override(request->overrides(i).c_str(), params.kv_overrides);
+ }
+ }
+
+ if (!params.kv_overrides.empty()) {
+ params.kv_overrides.emplace_back();
+ params.kv_overrides.back().key[0] = 0;
+ }
+
+ // TODO: Add yarn
+
+ if (!request->tensorsplit().empty()) {
+ std::string arg_next = request->tensorsplit();
+
+ // split string by , and /
+ const std::regex regex{ R"([,/]+)" };
+ std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
+ std::vector split_arg{ it, {} };
+
+ GGML_ASSERT(split_arg.size() <= llama_max_devices());
+
+ for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device) {
+ if (i_device < split_arg.size()) {
+ params.tensor_split[i_device] = std::stof(split_arg[i_device]);
+ }
+ else {
+ params.tensor_split[i_device] = 0.0f;
+ }
+ }
+ }
+
+ if (!request->maingpu().empty()) {
+ params.main_gpu = std::stoi(request->maingpu());
+ }
+ if (!request->loraadapter().empty() && !request->lorabase().empty()) {
+ float scale_factor = 1.0f;
+ if (request->lorascale() != 0.0f) {
+ scale_factor = request->lorascale();
+ }
+ // get the directory of modelfile
+ std::string model_dir = params.model.path.substr(0, params.model.path.find_last_of("/\\"));
+ common_adapter_lora_info lora_info;
+ lora_info.path = model_dir + "/" + request->loraadapter();
+ lora_info.scale = scale_factor;
+ lora_info.task_name = "";
+ lora_info.prompt_prefix = "";
+ lora_info.ptr = nullptr;
+ params.lora_adapters.push_back(std::move(lora_info));
+ }
+ params.use_mlock = request->mlock();
+ params.use_mmap = request->mmap();
+
+ if (request->flashattention() == "on" || request->flashattention() == "enabled") {
+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
+ } else if (request->flashattention() == "off" || request->flashattention() == "disabled") {
+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+ } else if (request->flashattention() == "auto") {
+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
+ }
+
+ params.no_kv_offload = request->nokvoffload();
+ params.embedding = request->embeddings() || request->reranking();
+ if (request->reranking()) {
+ params.pooling_type = LLAMA_POOLING_TYPE_RANK;
+ }
+
+
+ if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
+ else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
+ else if (request->ropescaling() == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
+
+ if ( request->yarnextfactor() != 0.0f ) {
+ params.yarn_ext_factor = request->yarnextfactor();
+ }
+ if ( request->yarnattnfactor() != 0.0f ) {
+ params.yarn_attn_factor = request->yarnattnfactor();
+ }
+ if ( request->yarnbetafast() != 0.0f ) {
+ params.yarn_beta_fast = request->yarnbetafast();
+ }
+ if ( request->yarnbetaslow() != 0.0f ) {
+ params.yarn_beta_slow = request->yarnbetaslow();
+ }
+ if ( request->ropefreqbase() != 0.0f ) {
+ params.rope_freq_base = request->ropefreqbase();
+ }
+ if ( request->ropefreqscale() != 0.0f ) {
+ params.rope_freq_scale = request->ropefreqscale();
+ }
+
+ if (request->grammartriggers_size() > 0) {
+ //params.sampling.grammar_lazy = true;
+ // Store grammar trigger words for processing after model is loaded
+ for (int i = 0; i < request->grammartriggers_size(); i++) {
+ const auto & word = request->grammartriggers(i).word();
+ common_grammar_trigger trigger;
+ trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
+ trigger.value = word;
+ params.sampling.grammar_triggers.push_back(std::move(trigger));
+ }
+ }
+}
+
+
+// GRPC Server start
+class BackendServiceImpl final : public backend::Backend::Service {
+private:
+ server_context& ctx_server;
+ common_params params_base; // Store copy of params_base, set after model load
+
+public:
+ BackendServiceImpl(server_context& ctx) : ctx_server(ctx) {}
+
+ grpc::Status Health(ServerContext* /*context*/, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
+ // Implement Health RPC
+ reply->set_message("OK");
+ return Status::OK;
+ }
+
+ grpc::Status LoadModel(ServerContext* /*context*/, const backend::ModelOptions* request, backend::Result* result) override {
+ // Implement LoadModel RPC
+ common_params params;
+ params_parse(ctx_server, request, params);
+
+ common_init();
+ // Ensure debug logs are enabled after common_init() sets up logging
+ common_log_set_verbosity_thold(params.verbosity);
+
+ llama_backend_init();
+ llama_numa_init(params.numa);
+
+
+ LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
+ LOG_INF("\n");
+ LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+ LOG_INF("\n");
+
+ // Capture error messages during model loading
+ struct error_capture {
+ std::string captured_error;
+ std::mutex error_mutex;
+ ggml_log_callback original_callback;
+ void* original_user_data;
+ } error_capture_data;
+
+ // Get original log callback
+ llama_log_get(&error_capture_data.original_callback, &error_capture_data.original_user_data);
+
+ // Set custom callback to capture errors
+ llama_log_set([](ggml_log_level level, const char * text, void * user_data) {
+ auto* capture = static_cast(user_data);
+
+ // Capture error messages
+ if (level == GGML_LOG_LEVEL_ERROR) {
+ std::lock_guard lock(capture->error_mutex);
+ // Append error message, removing trailing newlines
+ std::string msg(text);
+ while (!msg.empty() && (msg.back() == '\n' || msg.back() == '\r')) {
+ msg.pop_back();
+ }
+ if (!msg.empty()) {
+ if (!capture->captured_error.empty()) {
+ capture->captured_error.append("; ");
+ }
+ capture->captured_error.append(msg);
+ }
+ }
+
+ // Also call original callback to preserve logging
+ if (capture->original_callback) {
+ capture->original_callback(level, text, capture->original_user_data);
+ }
+ }, &error_capture_data);
+
+ // load the model
+ bool load_success = ctx_server.load_model(params);
+
+ // Restore original log callback
+ llama_log_set(error_capture_data.original_callback, error_capture_data.original_user_data);
+
+ if (!load_success) {
+ std::string error_msg = "Failed to load model: " + params.model.path;
+ if (!params.mmproj.path.empty()) {
+ error_msg += " (with mmproj: " + params.mmproj.path + ")";
+ }
+ if (params.has_speculative() && !params.speculative.model.path.empty()) {
+ error_msg += " (with draft model: " + params.speculative.model.path + ")";
+ }
+
+ // Add captured error details if available
+ {
+ std::lock_guard lock(error_capture_data.error_mutex);
+ if (!error_capture_data.captured_error.empty()) {
+ error_msg += ". Error: " + error_capture_data.captured_error;
+ } else {
+ error_msg += ". Model file may not exist or be invalid.";
+ }
+ }
+
+ result->set_message(error_msg);
+ result->set_success(false);
+ return grpc::Status(grpc::StatusCode::INTERNAL, error_msg);
+ }
+
+ // Process grammar triggers now that vocab is available
+ if (!params.sampling.grammar_triggers.empty()) {
+ std::vector processed_triggers;
+ for (const auto& trigger : params.sampling.grammar_triggers) {
+ if (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
+ auto ids = common_tokenize(ctx_server.impl->vocab, trigger.value, /* add_special= */ false, /* parse_special= */ true);
+ if (ids.size() == 1) {
+ auto token = ids[0];
+ // Add the token to preserved_tokens if not already present
+ if (params.sampling.preserved_tokens.find(token) == params.sampling.preserved_tokens.end()) {
+ params.sampling.preserved_tokens.insert(token);
+ LOG_INF("Added grammar trigger token to preserved tokens: %d (`%s`)\n", token, trigger.value.c_str());
+ }
+ LOG_INF("Grammar trigger token: %d (`%s`)\n", token, trigger.value.c_str());
+ common_grammar_trigger processed_trigger;
+ processed_trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
+ processed_trigger.value = trigger.value;
+ processed_trigger.token = token;
+ processed_triggers.push_back(std::move(processed_trigger));
+ } else {
+ LOG_INF("Grammar trigger word: `%s`\n", trigger.value.c_str());
+ processed_triggers.push_back(trigger);
+ }
+ } else {
+ processed_triggers.push_back(trigger);
+ }
+ }
+ // Update the grammar triggers in params
+ params.sampling.grammar_triggers = std::move(processed_triggers);
+ }
+
+ //ctx_server.init();
+ result->set_message("Loading succeeded");
+ result->set_success(true);
+ loaded_model = true;
+ // Store copy of params_base for use in parse_options and other methods
+ params_base = params;
+
+ return Status::OK;
+ }
+
+ // Helper function to extract logprobs from JSON response
+ static json extract_logprobs_from_json(const json& res_json) {
+ json logprobs_json = json::object();
+
+ // Check for OAI-compatible format: choices[0].logprobs
+ if (res_json.contains("choices") && res_json["choices"].is_array() &&
+ res_json["choices"].size() > 0 && res_json["choices"][0].contains("logprobs")) {
+ logprobs_json = res_json["choices"][0]["logprobs"];
+ }
+ // Check for non-OAI format: completion_probabilities
+ else if (res_json.contains("completion_probabilities")) {
+ // Convert completion_probabilities to OAI format
+ logprobs_json["content"] = res_json["completion_probabilities"];
+ }
+ // Check for direct logprobs field
+ else if (res_json.contains("logprobs")) {
+ logprobs_json = res_json["logprobs"];
+ }
+
+ return logprobs_json;
+ }
+
+ grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter* writer) override {
+ if (params_base.model.path.empty()) {
+ return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
+ }
+ json data = parse_options(true, request, params_base, ctx_server.get_llama_context());
+
+
+ //Raise error if embeddings is set to true
+ if (params_base.embedding) {
+ return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in streaming mode");
+ }
+
+
+ auto completion_id = gen_chatcmplid();
+ // get response reader - it contains references to the queues and will stay valid
+ auto rd = ctx_server.get_response_reader();
+ try {
+ std::vector tasks;
+
+ std::string prompt_str;
+ std::vector files; // Declare files early so it's accessible in both branches
+ // Handle chat templates when UseTokenizerTemplate is enabled and Messages are provided
+ if (request->usetokenizertemplate() && request->messages_size() > 0 && ctx_server.impl->chat_templates != nullptr) {
+ // Convert proto Messages to JSON format compatible with oaicompat_chat_params_parse
+ json body_json;
+ json messages_json = json::array();
+
+ // Find the last user message index to attach images/audio to
+ int last_user_msg_idx = -1;
+ for (int i = request->messages_size() - 1; i >= 0; i--) {
+ if (request->messages(i).role() == "user") {
+ last_user_msg_idx = i;
+ break;
+ }
+ }
+
+ for (int i = 0; i < request->messages_size(); i++) {
+ const auto& msg = request->messages(i);
+ json msg_json;
+ msg_json["role"] = msg.role();
+
+ bool is_last_user_msg = (i == last_user_msg_idx);
+ bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);
+
+ // Handle content - can be string, null, or array
+ // For multimodal content, we'll embed images/audio from separate fields
+ if (!msg.content().empty()) {
+ // Try to parse content as JSON to see if it's already an array
+ json content_val;
+ try {
+ content_val = json::parse(msg.content());
+ // Handle null values - convert to empty string to avoid template errors
+ if (content_val.is_null()) {
+ content_val = "";
+ }
+ } catch (const json::parse_error&) {
+ // Not JSON, treat as plain string
+ content_val = msg.content();
+ }
+
+ // If content is an object (e.g., from tool call failures), convert to string
+ if (content_val.is_object()) {
+ content_val = content_val.dump();
+ }
+
+ // If content is a string and this is the last user message with images/audio, combine them
+ if (content_val.is_string() && is_last_user_msg && has_images_or_audio) {
+ json content_array = json::array();
+ // Add text first
+ content_array.push_back({{"type", "text"}, {"text", content_val.get()}});
+ // Add images
+ if (request->images_size() > 0) {
+ for (int j = 0; j < request->images_size(); j++) {
+ json image_chunk;
+ image_chunk["type"] = "image_url";
+ json image_url;
+ image_url["url"] = "data:image/jpeg;base64," + request->images(j);
+ image_chunk["image_url"] = image_url;
+ content_array.push_back(image_chunk);
+ }
+ }
+ // Add audios
+ if (request->audios_size() > 0) {
+ for (int j = 0; j < request->audios_size(); j++) {
+ json audio_chunk;
+ audio_chunk["type"] = "input_audio";
+ json input_audio;
+ input_audio["data"] = request->audios(j);
+ input_audio["format"] = "wav"; // default, could be made configurable
+ audio_chunk["input_audio"] = input_audio;
+ content_array.push_back(audio_chunk);
+ }
+ }
+ msg_json["content"] = content_array;
+ } else {
+ // Use content as-is (already array or not last user message)
+ // Ensure null values are converted to empty string
+ if (content_val.is_null()) {
+ msg_json["content"] = "";
+ } else {
+ msg_json["content"] = content_val;
+ }
+ }
+ } else if (is_last_user_msg && has_images_or_audio) {
+ // If no content but this is the last user message with images/audio, create content array
+ json content_array = json::array();
+ if (request->images_size() > 0) {
+ for (int j = 0; j < request->images_size(); j++) {
+ json image_chunk;
+ image_chunk["type"] = "image_url";
+ json image_url;
+ image_url["url"] = "data:image/jpeg;base64," + request->images(j);
+ image_chunk["image_url"] = image_url;
+ content_array.push_back(image_chunk);
+ }
+ }
+ if (request->audios_size() > 0) {
+ for (int j = 0; j < request->audios_size(); j++) {
+ json audio_chunk;
+ audio_chunk["type"] = "input_audio";
+ json input_audio;
+ input_audio["data"] = request->audios(j);
+ input_audio["format"] = "wav"; // default, could be made configurable
+ audio_chunk["input_audio"] = input_audio;
+ content_array.push_back(audio_chunk);
+ }
+ }
+ msg_json["content"] = content_array;
+ } else if (msg.role() == "tool") {
+ // Tool role messages must have content field set, even if empty
+ // Jinja templates expect content to be a string, not null or object
+ SRV_INF("[CONTENT DEBUG] PredictStream: Message %d is tool role, content_empty=%d\n", i, msg.content().empty() ? 1 : 0);
+ if (msg.content().empty()) {
+ msg_json["content"] = "";
+ SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): empty content, set to empty string\n", i);
+ } else {
+ SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): content exists: %s\n",
+ i, msg.content().substr(0, std::min(200, msg.content().size())).c_str());
+ // Content exists, parse and ensure it's a string
+ json content_val;
+ try {
+ content_val = json::parse(msg.content());
+ SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): parsed JSON, type=%s\n",
+ i, content_val.is_null() ? "null" :
+ content_val.is_object() ? "object" :
+ content_val.is_string() ? "string" :
+ content_val.is_array() ? "array" : "other");
+ // Handle null values - Jinja templates expect content to be a string, not null
+ if (content_val.is_null()) {
+ msg_json["content"] = "";
+ SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): null content, converted to empty string\n", i);
+ } else if (content_val.is_object()) {
+ // If content is an object (e.g., from tool call failures/errors), convert to string
+ msg_json["content"] = content_val.dump();
+ SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): object content, converted to string: %s\n",
+ i, content_val.dump().substr(0, std::min(200, content_val.dump().size())).c_str());
+ } else if (content_val.is_string()) {
+ msg_json["content"] = content_val.get();
+ SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): string content, using as-is\n", i);
+ } else {
+ // For arrays or other types, convert to string
+ msg_json["content"] = content_val.dump();
+ SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): %s content, converted to string\n",
+ i, content_val.is_array() ? "array" : "other type");
+ }
+ } catch (const json::parse_error&) {
+ // Not JSON, treat as plain string
+ msg_json["content"] = msg.content();
+ SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): not JSON, using as string\n", i);
+ }
+ }
+ } else {
+ // Ensure all messages have content set (fallback for any unhandled cases)
+ // Jinja templates expect content to be present, default to empty string if not set
+ if (!msg_json.contains("content")) {
+ SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (role=%s): no content field, adding empty string\n",
+ i, msg.role().c_str());
+ msg_json["content"] = "";
+ }
+ }
+
+ // Add optional fields for OpenAI-compatible message format
+ if (!msg.name().empty()) {
+ msg_json["name"] = msg.name();
+ }
+ if (!msg.tool_call_id().empty()) {
+ msg_json["tool_call_id"] = msg.tool_call_id();
+ }
+ if (!msg.reasoning_content().empty()) {
+ msg_json["reasoning_content"] = msg.reasoning_content();
+ }
+ if (!msg.tool_calls().empty()) {
+ // Parse tool_calls JSON string and add to message
+ try {
+ json tool_calls = json::parse(msg.tool_calls());
+ msg_json["tool_calls"] = tool_calls;
+ SRV_INF("[TOOL CALLS DEBUG] PredictStream: Message %d has tool_calls: %s\n", i, tool_calls.dump().c_str());
+ // IMPORTANT: If message has tool_calls but content is empty or not set,
+ // set content to space " " instead of empty string "", because llama.cpp's
+ // common_chat_msgs_to_json_oaicompat converts empty strings to null (line 312),
+ // which causes template errors when accessing message.content[:tool_start_length]
+ if (!msg_json.contains("content") || (msg_json.contains("content") && msg_json["content"].is_string() && msg_json["content"].get().empty())) {
+ SRV_INF("[CONTENT DEBUG] PredictStream: Message %d has tool_calls but empty content, setting to space\n", i);
+ msg_json["content"] = " ";
+ }
+ // Log each tool call with name and arguments
+ if (tool_calls.is_array()) {
+ for (size_t tc_idx = 0; tc_idx < tool_calls.size(); tc_idx++) {
+ const auto& tc = tool_calls[tc_idx];
+ std::string tool_name = "unknown";
+ std::string tool_args = "{}";
+ if (tc.contains("function")) {
+ const auto& func = tc["function"];
+ if (func.contains("name")) {
+ tool_name = func["name"].get();
+ }
+ if (func.contains("arguments")) {
+ tool_args = func["arguments"].is_string() ?
+ func["arguments"].get() :
+ func["arguments"].dump();
+ }
+ } else if (tc.contains("name")) {
+ tool_name = tc["name"].get();
+ if (tc.contains("arguments")) {
+ tool_args = tc["arguments"].is_string() ?
+ tc["arguments"].get() :
+ tc["arguments"].dump();
+ }
+ }
+ SRV_INF("[TOOL CALLS DEBUG] PredictStream: Message %d, tool_call %zu: name=%s, arguments=%s\n",
+ i, tc_idx, tool_name.c_str(), tool_args.c_str());
+ }
+ }
+ } catch (const json::parse_error& e) {
+ SRV_WRN("Failed to parse tool_calls JSON: %s\n", e.what());
+ }
+ }
+
+ // Debug: Log final content state before adding to array
+ if (msg_json.contains("content")) {
+ if (msg_json["content"].is_null()) {
+ SRV_INF("[CONTENT DEBUG] PredictStream: Message %d FINAL STATE: content is NULL - THIS WILL CAUSE ERROR!\n", i);
+ } else {
+ SRV_INF("[CONTENT DEBUG] PredictStream: Message %d FINAL STATE: content type=%s, has_value=%d\n",
+ i, msg_json["content"].is_string() ? "string" :
+ msg_json["content"].is_array() ? "array" :
+ msg_json["content"].is_object() ? "object" : "other",
+ msg_json["content"].is_null() ? 0 : 1);
+ }
+ } else {
+ SRV_INF("[CONTENT DEBUG] PredictStream: Message %d FINAL STATE: NO CONTENT FIELD - THIS WILL CAUSE ERROR!\n", i);
+ }
+
+ messages_json.push_back(msg_json);
+ }
+
+ // Final safety check: Ensure no message has null content (Jinja templates require strings)
+ SRV_INF("[CONTENT DEBUG] PredictStream: Running final safety check on %zu messages\n", messages_json.size());
+ for (size_t idx = 0; idx < messages_json.size(); idx++) {
+ auto& msg = messages_json[idx];
+ if (msg.contains("content") && msg["content"].is_null()) {
+ SRV_INF("[CONTENT DEBUG] PredictStream: Safety check found message %zu with NULL content, converting to empty string\n", idx);
+ msg["content"] = "";
+ } else if (!msg.contains("content")) {
+ SRV_INF("[CONTENT DEBUG] PredictStream: Safety check found message %zu without content field, adding empty string\n", idx);
+ msg["content"] = "";
+ } else {
+ SRV_INF("[CONTENT DEBUG] PredictStream: Safety check message %zu: content OK, type=%s\n",
+ idx, msg["content"].is_string() ? "string" :
+ msg["content"].is_array() ? "array" :
+ msg["content"].is_object() ? "object" : "other");
+ }
+ }
+
+ // Debug: Count tool messages
+ int tool_msg_count = 0;
+ for (const auto& msg : messages_json) {
+ if (msg.contains("role") && msg["role"] == "tool") {
+ tool_msg_count++;
+ }
+ }
+ SRV_DBG("[TOOLS DEBUG] PredictStream: Built %d tool messages out of %zu total messages\n", tool_msg_count, messages_json.size());
+
+ // Debug: Print full conversation (messages)
+ SRV_DBG("[CONVERSATION DEBUG] PredictStream: Full messages array:\n%s\n", messages_json.dump(2).c_str());
+
+ body_json["messages"] = messages_json;
+ body_json["stream"] = true; // PredictStream is always streaming
+
+ // Check if grammar is provided from Go layer (NoGrammar=false)
+ // If grammar is provided, we must use it and NOT let template generate grammar from tools
+ // oaicompat_chat_params_parse throws an error if both grammar and tools are provided
+ bool has_grammar_from_go = data.contains("grammar") &&
+ data["grammar"].is_string() &&
+ !data["grammar"].get().empty();
+
+ SRV_INF("[TOOLS DEBUG] PredictStream: has_grammar_from_go=%d, data.contains(\"tools\")=%d, data.contains(\"grammar\")=%d\n",
+ has_grammar_from_go ? 1 : 0,
+ data.contains("tools") ? 1 : 0,
+ data.contains("grammar") ? 1 : 0);
+ if (data.contains("grammar")) {
+ SRV_INF("[TOOLS DEBUG] PredictStream: grammar type=%s, empty=%d\n",
+ data["grammar"].is_string() ? "string" : "other",
+ data["grammar"].is_string() && data["grammar"].get().empty() ? 1 : 0);
+ }
+
+ // Copy other relevant fields from data that oaicompat_chat_params_parse expects
+ // Tools and tool_choice are only passed when NoGrammar is true (grammar not provided)
+ // When grammar is provided from Go layer, we use it instead of template-generated grammar
+ if (!has_grammar_from_go) {
+ // NoGrammar=true: pass tools and let template generate grammar
+ if (data.contains("tools")) {
+ body_json["tools"] = data["tools"];
+ std::string tools_str = data["tools"].dump();
+ SRV_INF("Using tools from data (NoGrammar=true): %s\n", tools_str.c_str());
+ // Debug: Log tools count and details before template processing
+ if (data["tools"].is_array()) {
+ SRV_INF("[TOOLS DEBUG] PredictStream: Passing %zu tools to oaicompat_chat_params_parse\n", data["tools"].size());
+ for (size_t t_idx = 0; t_idx < data["tools"].size(); t_idx++) {
+ const auto& tool = data["tools"][t_idx];
+ std::string tool_name = "unknown";
+ std::string tool_desc = "";
+ if (tool.contains("function")) {
+ const auto& func = tool["function"];
+ if (func.contains("name")) {
+ tool_name = func["name"].get();
+ }
+ if (func.contains("description")) {
+ tool_desc = func["description"].is_string() ?
+ func["description"].get() : "";
+ }
+ } else if (tool.contains("name")) {
+ tool_name = tool["name"].get();
+ if (tool.contains("description")) {
+ tool_desc = tool["description"].is_string() ?
+ tool["description"].get() : "";
+ }
+ }
+ SRV_INF("[TOOLS DEBUG] PredictStream: Tool %zu: name=%s, description=%s\n",
+ t_idx, tool_name.c_str(), tool_desc.substr(0, 100).c_str());
+ }
+ }
+ } else {
+ SRV_WRN("%s", "No tools found in data - tool calls will not work without tools field\n");
+ SRV_DBG("[TOOLS DEBUG] PredictStream: No tools in data, tool_choice=%s\n", data.contains("tool_choice") ? data["tool_choice"].dump().c_str() : "not set");
+ }
+ if (data.contains("tool_choice")) {
+ // tool_choice can be a string or object, but oaicompat_chat_params_parse expects a string
+ // Convert object tool_choice to "required" (since a specific function is requested)
+ if (data["tool_choice"].is_string()) {
+ body_json["tool_choice"] = data["tool_choice"].get();
+ } else if (data["tool_choice"].is_object()) {
+ // Object tool_choice means a specific function is requested, use "required"
+ body_json["tool_choice"] = "required";
+ std::string tool_choice_obj_str = data["tool_choice"].dump();
+ SRV_INF("Converted object tool_choice to 'required': %s\n", tool_choice_obj_str.c_str());
+ } else {
+ // Fallback: convert to string
+ body_json["tool_choice"] = data["tool_choice"].dump();
+ }
+ std::string tool_choice_str = body_json["tool_choice"].get();
+ SRV_INF("Using tool_choice: %s\n", tool_choice_str.c_str());
+ } else {
+ // Default to "auto" if not specified
+ body_json["tool_choice"] = "auto";
+ }
+ } else {
+ // Grammar is provided from Go layer (NoGrammar=false) - use it, don't pass tools
+ SRV_INF("%s", "Grammar provided from Go layer - using it instead of template-generated grammar\n");
+ // Grammar will be copied from data after parsing (it's already in data)
+ }
+
+ if (data.contains("json_schema")) {
+ body_json["json_schema"] = data["json_schema"];
+ }
+ // If grammar is provided from Go layer, copy it to body_json so it's preserved
+ // (though oaicompat_chat_params_parse may not use it if tools are present)
+ if (has_grammar_from_go) {
+ body_json["grammar"] = data["grammar"];
+ }
+ if (data.contains("response_format")) {
+ body_json["response_format"] = data["response_format"];
+ }
+ if (data.contains("chat_template_kwargs")) {
+ body_json["chat_template_kwargs"] = data["chat_template_kwargs"];
+ }
+ // Pass parallel_tool_calls if present (used by oaicompat_chat_params_parse)
+ if (data.contains("parallel_tool_calls")) {
+ body_json["parallel_tool_calls"] = data["parallel_tool_calls"];
+ }
+ // Pass add_generation_prompt if present (used by oaicompat_chat_params_parse)
+ if (data.contains("add_generation_prompt")) {
+ body_json["add_generation_prompt"] = data["add_generation_prompt"];
+ }
+
+ // Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
+ SRV_DBG("[CONVERSATION DEBUG] PredictStream: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
+
+ // Use the same approach as server.cpp: call oaicompat_chat_params_parse
+ // This handles all template application, grammar merging, etc. automatically
+ // Files extracted from multimodal content in messages will be added to the files vector
+ // Create parser options with current chat_templates to ensure tmpls is not null
+ oaicompat_parser_options parser_opt = ctx_server.impl->oai_parser_opt;
+ parser_opt.tmpls = ctx_server.impl->chat_templates.get(); // Ensure tmpls is set to current chat_templates
+ // Update allow_image and allow_audio based on current mctx state
+ parser_opt.allow_image = ctx_server.impl->mctx ? mtmd_support_vision(ctx_server.impl->mctx) : false;
+ parser_opt.allow_audio = ctx_server.impl->mctx ? mtmd_support_audio(ctx_server.impl->mctx) : false;
+
+ // Debug: Log tools before template processing
+ if (body_json.contains("tools")) {
+ SRV_DBG("[TOOLS DEBUG] PredictStream: Before oaicompat_chat_params_parse - tools count: %zu\n",
+ body_json["tools"].is_array() ? body_json["tools"].size() : 0);
+ }
+
+ // Debug: Verify messages content before template processing
+ // Also ensure ALL messages have content set to string (not null) - templates expect strings
+ if (body_json.contains("messages") && body_json["messages"].is_array()) {
+ SRV_INF("[CONTENT DEBUG] PredictStream: Before oaicompat_chat_params_parse - checking %zu messages\n", body_json["messages"].size());
+ for (size_t idx = 0; idx < body_json["messages"].size(); idx++) {
+ auto& msg = body_json["messages"][idx];
+ std::string role_str = msg.contains("role") ? msg["role"].get() : "unknown";
+ if (msg.contains("content")) {
+ if (msg["content"].is_null()) {
+ SRV_INF("[CONTENT DEBUG] PredictStream: BEFORE TEMPLATE - Message %zu (role=%s) has NULL content - FIXING!\n", idx, role_str.c_str());
+ msg["content"] = ""; // Fix null content
+ } else if (role_str == "tool" && msg["content"].is_array()) {
+ // Tool messages must have string content, not array
+ // oaicompat_chat_params_parse expects tool messages to have string content
+ SRV_INF("[CONTENT DEBUG] PredictStream: BEFORE TEMPLATE - Message %zu (role=tool) has array content, converting to string\n", idx);
+ msg["content"] = msg["content"].dump();
+ } else if (!msg["content"].is_string() && !msg["content"].is_array()) {
+ // If content is object or other non-string type, convert to string for templates
+ SRV_INF("[CONTENT DEBUG] PredictStream: BEFORE TEMPLATE - Message %zu (role=%s) content is not string/array, converting\n", idx, role_str.c_str());
+ if (msg["content"].is_object()) {
+ msg["content"] = msg["content"].dump();
+ } else {
+ msg["content"] = "";
+ }
+ } else {
+ SRV_INF("[CONTENT DEBUG] PredictStream: BEFORE TEMPLATE - Message %zu (role=%s): content type=%s\n",
+ idx, role_str.c_str(),
+ msg["content"].is_string() ? "string" :
+ msg["content"].is_array() ? "array" :
+ msg["content"].is_object() ? "object" : "other");
+ }
+ } else {
+ SRV_INF("[CONTENT DEBUG] PredictStream: BEFORE TEMPLATE - Message %zu (role=%s) MISSING content field - ADDING!\n", idx, role_str.c_str());
+ msg["content"] = ""; // Add missing content
+ }
+ }
+ }
+
+ json parsed_data = oaicompat_chat_params_parse(body_json, parser_opt, files);
+
+ // Debug: Log tools after template processing
+ if (parsed_data.contains("tools")) {
+ SRV_DBG("[TOOLS DEBUG] PredictStream: After oaicompat_chat_params_parse - tools count: %zu\n",
+ parsed_data["tools"].is_array() ? parsed_data["tools"].size() : 0);
+ } else {
+ SRV_DBG("%s", "[TOOLS DEBUG] PredictStream: After oaicompat_chat_params_parse - no tools in parsed_data\n");
+ }
+
+ // Extract the prompt from parsed data
+ prompt_str = parsed_data.at("prompt").get();
+
+ // Preserve grammar from Go layer if it was provided (NoGrammar=false)
+ // Otherwise, use grammar from parsed_data (template-generated when NoGrammar=true)
+ json preserved_grammar;
+ if (has_grammar_from_go && data.contains("grammar")) {
+ preserved_grammar = data["grammar"];
+ }
+
+ // Merge all fields from parsed_data into data (grammar, grammar_triggers, preserved_tokens, parse_tool_calls, etc.)
+ // This ensures all template-generated fields are included
+ // parse_tool_calls is set by oaicompat_chat_params_parse when tools are present
+ for (const auto& item : parsed_data.items()) {
+ if (item.key() != "prompt") { // Don't overwrite prompt_str, we already extracted it
+ // If grammar was provided from Go layer, preserve it instead of template-generated grammar
+ if (item.key() == "grammar" && has_grammar_from_go && !preserved_grammar.is_null()) {
+ data["grammar"] = preserved_grammar;
+ } else {
+ data[item.key()] = item.value();
+ }
+ }
+ }
+
+ // Debug: Log parse_tool_calls if present (set by oaicompat_chat_params_parse when tools are present)
+ if (data.contains("parse_tool_calls")) {
+ SRV_DBG("[TOOLS DEBUG] PredictStream: parse_tool_calls=%s\n", data["parse_tool_calls"].get() ? "true" : "false");
+ }
+ } else {
+ // Use prompt directly from data
+ if (data.contains("prompt") && data["prompt"].is_string()) {
+ prompt_str = data["prompt"].get();
+ } else {
+ prompt_str = request->prompt();
+ }
+ }
+
+ const auto type = SERVER_TASK_TYPE_COMPLETION;
+ // TODO: this log can become very long, put it behind a flag or think about a more compact format
+ //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str());
+
+ // If not using chat templates, extract files from image_data/audio_data fields
+ // (If using chat templates, files were already extracted by oaicompat_chat_params_parse)
+ if (!request->usetokenizertemplate() || request->messages_size() == 0 || ctx_server.impl->chat_templates == nullptr) {
+ const auto &images_data = data.find("image_data");
+ if (images_data != data.end() && images_data->is_array())
+ {
+ for (const auto &img : *images_data)
+ {
+ auto decoded_data = base64_decode(img["data"].get());
+ files.push_back(decoded_data);
+ }
+ }
+
+ const auto &audio_data = data.find("audio_data");
+ if (audio_data != data.end() && audio_data->is_array())
+ {
+ for (const auto &audio : *audio_data)
+ {
+ auto decoded_data = base64_decode(audio["data"].get());
+ files.push_back(decoded_data);
+ }
+ }
+ }
+
+ const bool has_mtmd = ctx_server.impl->mctx != nullptr;
+
+ // process prompt
+ std::vector inputs;
+ if (has_mtmd) {
+ // multimodal
+ inputs.push_back(process_mtmd_prompt(ctx_server.impl->mctx, prompt_str, files));
+ } else {
+ // Everything else, including multimodal completions.
+ inputs = tokenize_input_prompts(ctx_server.impl->vocab, ctx_server.impl->mctx, prompt_str, true, true);
+ }
+
+ tasks.reserve(inputs.size());
+ for (size_t i = 0; i < inputs.size(); i++) {
+ server_task task = server_task(type);
+
+ task.id = rd.queue_tasks.get_new_id();
+ task.index = i;
+
+ task.tokens = std::move(inputs[i]);
+ task.params = server_task::params_from_json_cmpl(
+ ctx_server.impl->vocab,
+ params_base,
+ ctx_server.get_meta().slot_n_ctx,
+ data);
+ task.id_slot = json_value(data, "id_slot", -1);
+
+ // OAI-compat
+ task.params.res_type = TASK_RESPONSE_TYPE_NONE;
+ task.params.oaicompat_cmpl_id = completion_id;
+ // oaicompat_model is already populated by params_from_json_cmpl
+
+ tasks.push_back(std::move(task));
+ }
+
+ rd.post_tasks(std::move(tasks));
+ } catch (const std::exception & e) {
+ return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
+ }
+
+ // Get first result for error checking (following server.cpp pattern)
+ server_task_result_ptr first_result = rd.next([&context]() { return context->IsCancelled(); });
+ if (first_result == nullptr) {
+ // connection is closed
+ return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+ } else if (first_result->is_error()) {
+ json error_json = first_result->to_json();
+ backend::Reply reply;
+ reply.set_message(error_json.value("message", ""));
+ writer->Write(reply);
+ return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
+ }
+
+ // Process first result
+ json first_res_json = first_result->to_json();
+ if (first_res_json.is_array()) {
+ for (const auto & res : first_res_json) {
+ std::string completion_text = res.value("content", "");
+
+ backend::Reply reply;
+ reply.set_message(completion_text);
+ int32_t tokens_predicted = res.value("tokens_predicted", 0);
+ reply.set_tokens(tokens_predicted);
+ int32_t tokens_evaluated = res.value("tokens_evaluated", 0);
+ reply.set_prompt_tokens(tokens_evaluated);
+
+ if (res.contains("timings")) {
+ double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0);
+ reply.set_timing_prompt_processing(timing_prompt_processing);
+ double timing_token_generation = res.at("timings").value("predicted_ms", 0.0);
+ reply.set_timing_token_generation(timing_token_generation);
+ }
+
+ // Extract and set logprobs if present
+ json logprobs_json = extract_logprobs_from_json(res);
+ if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+ std::string logprobs_str = logprobs_json.dump();
+ reply.set_logprobs(logprobs_str);
+ }
+
+ writer->Write(reply);
+ }
+ } else {
+ std::string completion_text = first_res_json.value("content", "");
+
+ backend::Reply reply;
+ reply.set_message(completion_text);
+ int32_t tokens_predicted = first_res_json.value("tokens_predicted", 0);
+ reply.set_tokens(tokens_predicted);
+ int32_t tokens_evaluated = first_res_json.value("tokens_evaluated", 0);
+ reply.set_prompt_tokens(tokens_evaluated);
+
+ if (first_res_json.contains("timings")) {
+ double timing_prompt_processing = first_res_json.at("timings").value("prompt_ms", 0.0);
+ reply.set_timing_prompt_processing(timing_prompt_processing);
+ double timing_token_generation = first_res_json.at("timings").value("predicted_ms", 0.0);
+ reply.set_timing_token_generation(timing_token_generation);
+ }
+
+ // Extract and set logprobs if present
+ json logprobs_json = extract_logprobs_from_json(first_res_json);
+ if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+ std::string logprobs_str = logprobs_json.dump();
+ reply.set_logprobs(logprobs_str);
+ }
+
+ writer->Write(reply);
+ }
+
+ // Process subsequent results
+ while (rd.has_next()) {
+ // Check if context is cancelled before processing result
+ if (context->IsCancelled()) {
+ break;
+ }
+
+ auto result = rd.next([&context]() { return context->IsCancelled(); });
+ if (result == nullptr) {
+ // connection is closed
+ break;
+ }
+
+ json res_json = result->to_json();
+ if (res_json.is_array()) {
+ for (const auto & res : res_json) {
+ std::string completion_text = res.value("content", "");
+
+ backend::Reply reply;
+ reply.set_message(completion_text);
+ int32_t tokens_predicted = res.value("tokens_predicted", 0);
+ reply.set_tokens(tokens_predicted);
+ int32_t tokens_evaluated = res.value("tokens_evaluated", 0);
+ reply.set_prompt_tokens(tokens_evaluated);
+
+ if (res.contains("timings")) {
+ double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0);
+ reply.set_timing_prompt_processing(timing_prompt_processing);
+ double timing_token_generation = res.at("timings").value("predicted_ms", 0.0);
+ reply.set_timing_token_generation(timing_token_generation);
+ }
+
+ // Extract and set logprobs if present
+ json logprobs_json = extract_logprobs_from_json(res);
+ if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+ std::string logprobs_str = logprobs_json.dump();
+ reply.set_logprobs(logprobs_str);
+ }
+
+ writer->Write(reply);
+ }
+ } else {
+ std::string completion_text = res_json.value("content", "");
+
+ backend::Reply reply;
+ reply.set_message(completion_text);
+ int32_t tokens_predicted = res_json.value("tokens_predicted", 0);
+ reply.set_tokens(tokens_predicted);
+ int32_t tokens_evaluated = res_json.value("tokens_evaluated", 0);
+ reply.set_prompt_tokens(tokens_evaluated);
+
+ if (res_json.contains("timings")) {
+ double timing_prompt_processing = res_json.at("timings").value("prompt_ms", 0.0);
+ reply.set_timing_prompt_processing(timing_prompt_processing);
+ double timing_token_generation = res_json.at("timings").value("predicted_ms", 0.0);
+ reply.set_timing_token_generation(timing_token_generation);
+ }
+
+ // Extract and set logprobs if present
+ json logprobs_json = extract_logprobs_from_json(res_json);
+ if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+ std::string logprobs_str = logprobs_json.dump();
+ reply.set_logprobs(logprobs_str);
+ }
+
+ writer->Write(reply);
+ }
+ }
+
+ // Check if context was cancelled during processing
+ if (context->IsCancelled()) {
+ return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+ }
+
+ return grpc::Status::OK;
+ }
+
+ grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) override {
+ if (params_base.model.path.empty()) {
+ return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
+ }
+ json data = parse_options(true, request, params_base, ctx_server.get_llama_context());
+
+ data["stream"] = false;
+ //Raise error if embeddings is set to true
+ if (params_base.embedding) {
+ return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in Predict mode");
+ }
+ std::cout << "[PREDICT] Received result: " << data.dump(2) << std::endl;
+ auto completion_id = gen_chatcmplid();
+ auto rd = ctx_server.get_response_reader();
+ try {
+ std::vector tasks;
+
+ std::string prompt_str;
+ std::vector files; // Declare files early so it's accessible in both branches
+ // Handle chat templates when UseTokenizerTemplate is enabled and Messages are provided
+ if (request->usetokenizertemplate() && request->messages_size() > 0 && ctx_server.impl->chat_templates != nullptr) {
+ // Convert proto Messages to JSON format compatible with oaicompat_chat_params_parse
+ json body_json;
+ json messages_json = json::array();
+
+ // Find the last user message index to attach images/audio to
+ int last_user_msg_idx = -1;
+ for (int i = request->messages_size() - 1; i >= 0; i--) {
+ if (request->messages(i).role() == "user") {
+ last_user_msg_idx = i;
+ break;
+ }
+ }
+
+ SRV_INF("[CONTENT DEBUG] Predict: Processing %d messages\n", request->messages_size());
+ for (int i = 0; i < request->messages_size(); i++) {
+ const auto& msg = request->messages(i);
+ json msg_json;
+ msg_json["role"] = msg.role();
+
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d: role=%s, content_empty=%d, content_length=%zu\n",
+ i, msg.role().c_str(), msg.content().empty() ? 1 : 0, msg.content().size());
+ if (!msg.content().empty()) {
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d content (first 200 chars): %s\n",
+ i, msg.content().substr(0, std::min(200, msg.content().size())).c_str());
+ }
+
+ bool is_last_user_msg = (i == last_user_msg_idx);
+ bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);
+
+ // Handle content - can be string, null, or array
+ // For multimodal content, we'll embed images/audio from separate fields
+ if (!msg.content().empty()) {
+ // Try to parse content as JSON to see if it's already an array
+ json content_val;
+ try {
+ content_val = json::parse(msg.content());
+ // Handle null values - convert to empty string to avoid template errors
+ if (content_val.is_null()) {
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d parsed JSON is null, converting to empty string\n", i);
+ content_val = "";
+ }
+ } catch (const json::parse_error&) {
+ // Not JSON, treat as plain string
+ content_val = msg.content();
+ }
+
+ // If content is an object (e.g., from tool call failures), convert to string
+ if (content_val.is_object()) {
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d content is object, converting to string\n", i);
+ content_val = content_val.dump();
+ }
+
+ // If content is a string and this is the last user message with images/audio, combine them
+ if (content_val.is_string() && is_last_user_msg && has_images_or_audio) {
+ json content_array = json::array();
+ // Add text first
+ content_array.push_back({{"type", "text"}, {"text", content_val.get()}});
+ // Add images
+ if (request->images_size() > 0) {
+ for (int j = 0; j < request->images_size(); j++) {
+ json image_chunk;
+ image_chunk["type"] = "image_url";
+ json image_url;
+ image_url["url"] = "data:image/jpeg;base64," + request->images(j);
+ image_chunk["image_url"] = image_url;
+ content_array.push_back(image_chunk);
+ }
+ }
+ // Add audios
+ if (request->audios_size() > 0) {
+ for (int j = 0; j < request->audios_size(); j++) {
+ json audio_chunk;
+ audio_chunk["type"] = "input_audio";
+ json input_audio;
+ input_audio["data"] = request->audios(j);
+ input_audio["format"] = "wav"; // default, could be made configurable
+ audio_chunk["input_audio"] = input_audio;
+ content_array.push_back(audio_chunk);
+ }
+ }
+ msg_json["content"] = content_array;
+ } else {
+ // Use content as-is (already array or not last user message)
+ // Ensure null values are converted to empty string
+ if (content_val.is_null()) {
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d content_val was null, setting to empty string\n", i);
+ msg_json["content"] = "";
+ } else {
+ msg_json["content"] = content_val;
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d content set, type=%s\n",
+ i, content_val.is_string() ? "string" :
+ content_val.is_array() ? "array" :
+ content_val.is_object() ? "object" : "other");
+ }
+ }
+ } else if (is_last_user_msg && has_images_or_audio) {
+ // If no content but this is the last user message with images/audio, create content array
+ json content_array = json::array();
+ if (request->images_size() > 0) {
+ for (int j = 0; j < request->images_size(); j++) {
+ json image_chunk;
+ image_chunk["type"] = "image_url";
+ json image_url;
+ image_url["url"] = "data:image/jpeg;base64," + request->images(j);
+ image_chunk["image_url"] = image_url;
+ content_array.push_back(image_chunk);
+ }
+ }
+ if (request->audios_size() > 0) {
+ for (int j = 0; j < request->audios_size(); j++) {
+ json audio_chunk;
+ audio_chunk["type"] = "input_audio";
+ json input_audio;
+ input_audio["data"] = request->audios(j);
+ input_audio["format"] = "wav"; // default, could be made configurable
+ audio_chunk["input_audio"] = input_audio;
+ content_array.push_back(audio_chunk);
+ }
+ }
+ msg_json["content"] = content_array;
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d created content array with media\n", i);
+ } else if (!msg.tool_calls().empty()) {
+ // Tool call messages may have null content, but templates expect string
+ // IMPORTANT: Set to space " " instead of empty string "", because llama.cpp's
+ // common_chat_msgs_to_json_oaicompat converts empty strings to null (line 312),
+ // which causes template errors when accessing message.content[:tool_start_length]
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d has tool_calls, setting content to space (not empty string)\n", i);
+ msg_json["content"] = " ";
+ } else if (msg.role() == "tool") {
+ // Tool role messages must have content field set, even if empty
+ // Jinja templates expect content to be a string, not null or object
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d is tool role, content_empty=%d\n", i, msg.content().empty() ? 1 : 0);
+ if (msg.content().empty()) {
+ msg_json["content"] = "";
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): empty content, set to empty string\n", i);
+ } else {
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): content exists: %s\n",
+ i, msg.content().substr(0, std::min(200, msg.content().size())).c_str());
+ // Content exists, parse and ensure it's a string
+ json content_val;
+ try {
+ content_val = json::parse(msg.content());
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): parsed JSON, type=%s\n",
+ i, content_val.is_null() ? "null" :
+ content_val.is_object() ? "object" :
+ content_val.is_string() ? "string" :
+ content_val.is_array() ? "array" : "other");
+ // Handle null values - Jinja templates expect content to be a string, not null
+ if (content_val.is_null()) {
+ msg_json["content"] = "";
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): null content, converted to empty string\n", i);
+ } else if (content_val.is_object()) {
+ // If content is an object (e.g., from tool call failures/errors), convert to string
+ msg_json["content"] = content_val.dump();
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): object content, converted to string: %s\n",
+ i, content_val.dump().substr(0, std::min(200, content_val.dump().size())).c_str());
+ } else if (content_val.is_string()) {
+ msg_json["content"] = content_val.get();
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): string content, using as-is\n", i);
+ } else {
+ // For arrays or other types, convert to string
+ msg_json["content"] = content_val.dump();
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): %s content, converted to string\n",
+ i, content_val.is_array() ? "array" : "other type");
+ }
+ } catch (const json::parse_error&) {
+ // Not JSON, treat as plain string
+ msg_json["content"] = msg.content();
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): not JSON, using as string\n", i);
+ }
+ }
+ } else {
+ // Ensure all messages have content set (fallback for any unhandled cases)
+ // Jinja templates expect content to be present, default to empty string if not set
+ if (!msg_json.contains("content")) {
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d (role=%s): no content field, adding empty string\n",
+ i, msg.role().c_str());
+ msg_json["content"] = "";
+ }
+ }
+
+ // Add optional fields for OpenAI-compatible message format
+ if (!msg.name().empty()) {
+ msg_json["name"] = msg.name();
+ }
+ if (!msg.tool_call_id().empty()) {
+ msg_json["tool_call_id"] = msg.tool_call_id();
+ }
+ if (!msg.reasoning_content().empty()) {
+ msg_json["reasoning_content"] = msg.reasoning_content();
+ }
+ if (!msg.tool_calls().empty()) {
+ // Parse tool_calls JSON string and add to message
+ try {
+ json tool_calls = json::parse(msg.tool_calls());
+ msg_json["tool_calls"] = tool_calls;
+ SRV_INF("[TOOL CALLS DEBUG] Predict: Message %d has tool_calls: %s\n", i, tool_calls.dump().c_str());
+ // IMPORTANT: If message has tool_calls but content is empty or not set,
+ // set content to space " " instead of empty string "", because llama.cpp's
+ // common_chat_msgs_to_json_oaicompat converts empty strings to null (line 312),
+ // which causes template errors when accessing message.content[:tool_start_length]
+ if (!msg_json.contains("content") || (msg_json.contains("content") && msg_json["content"].is_string() && msg_json["content"].get().empty())) {
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d has tool_calls but empty content, setting to space\n", i);
+ msg_json["content"] = " ";
+ }
+ // Log each tool call with name and arguments
+ if (tool_calls.is_array()) {
+ for (size_t tc_idx = 0; tc_idx < tool_calls.size(); tc_idx++) {
+ const auto& tc = tool_calls[tc_idx];
+ std::string tool_name = "unknown";
+ std::string tool_args = "{}";
+ if (tc.contains("function")) {
+ const auto& func = tc["function"];
+ if (func.contains("name")) {
+ tool_name = func["name"].get();
+ }
+ if (func.contains("arguments")) {
+ tool_args = func["arguments"].is_string() ?
+ func["arguments"].get() :
+ func["arguments"].dump();
+ }
+ } else if (tc.contains("name")) {
+ tool_name = tc["name"].get();
+ if (tc.contains("arguments")) {
+ tool_args = tc["arguments"].is_string() ?
+ tc["arguments"].get() :
+ tc["arguments"].dump();
+ }
+ }
+ SRV_INF("[TOOL CALLS DEBUG] Predict: Message %d, tool_call %zu: name=%s, arguments=%s\n",
+ i, tc_idx, tool_name.c_str(), tool_args.c_str());
+ }
+ }
+ } catch (const json::parse_error& e) {
+ SRV_WRN("Failed to parse tool_calls JSON: %s\n", e.what());
+ }
+ }
+
+ // Debug: Log final content state before adding to array
+ if (msg_json.contains("content")) {
+ if (msg_json["content"].is_null()) {
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d FINAL STATE: content is NULL - THIS WILL CAUSE ERROR!\n", i);
+ } else {
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d FINAL STATE: content type=%s, has_value=%d\n",
+ i, msg_json["content"].is_string() ? "string" :
+ msg_json["content"].is_array() ? "array" :
+ msg_json["content"].is_object() ? "object" : "other",
+ msg_json["content"].is_null() ? 0 : 1);
+ }
+ } else {
+ SRV_INF("[CONTENT DEBUG] Predict: Message %d FINAL STATE: NO CONTENT FIELD - THIS WILL CAUSE ERROR!\n", i);
+ }
+
+ messages_json.push_back(msg_json);
+ }
+
+ // Final safety check: Ensure no message has null content (Jinja templates require strings)
+ SRV_INF("[CONTENT DEBUG] Predict: Running final safety check on %zu messages\n", messages_json.size());
+ for (size_t idx = 0; idx < messages_json.size(); idx++) {
+ auto& msg = messages_json[idx];
+ std::string role_str = msg.contains("role") ? msg["role"].get() : "unknown";
+ if (msg.contains("content") && msg["content"].is_null()) {
+ SRV_INF("[CONTENT DEBUG] Predict: Safety check found message %zu (role=%s) with NULL content, converting to empty string\n", idx, role_str.c_str());
+ msg["content"] = "";
+ } else if (!msg.contains("content")) {
+ SRV_INF("[CONTENT DEBUG] Predict: Safety check found message %zu (role=%s) without content field, adding empty string\n", idx, role_str.c_str());
+ msg["content"] = "";
+ } else {
+ SRV_INF("[CONTENT DEBUG] Predict: Safety check message %zu (role=%s): content OK, type=%s\n",
+ idx, role_str.c_str(),
+ msg["content"].is_string() ? "string" :
+ msg["content"].is_array() ? "array" :
+ msg["content"].is_object() ? "object" : "other");
+ }
+ }
+
+ // Debug: Count tool messages
+ int tool_msg_count = 0;
+ for (const auto& msg : messages_json) {
+ if (msg.contains("role") && msg["role"] == "tool") {
+ tool_msg_count++;
+ }
+ }
+ SRV_DBG("[TOOLS DEBUG] Predict: Built %d tool messages out of %zu total messages\n", tool_msg_count, messages_json.size());
+
+ // Debug: Print full conversation (messages)
+ SRV_DBG("[CONVERSATION DEBUG] Predict: Full messages array:\n%s\n", messages_json.dump(2).c_str());
+
+ body_json["messages"] = messages_json;
+ body_json["stream"] = false;
+
+ // Check if grammar is provided from Go layer (NoGrammar=false)
+ // If grammar is provided, we must use it and NOT let template generate grammar from tools
+ // oaicompat_chat_params_parse throws an error if both grammar and tools are provided
+ bool has_grammar_from_go = data.contains("grammar") &&
+ data["grammar"].is_string() &&
+ !data["grammar"].get().empty();
+
+ SRV_INF("[TOOLS DEBUG] Predict: has_grammar_from_go=%d, data.contains(\"tools\")=%d, data.contains(\"grammar\")=%d\n",
+ has_grammar_from_go ? 1 : 0,
+ data.contains("tools") ? 1 : 0,
+ data.contains("grammar") ? 1 : 0);
+ if (data.contains("grammar")) {
+ SRV_INF("[TOOLS DEBUG] Predict: grammar type=%s, empty=%d\n",
+ data["grammar"].is_string() ? "string" : "other",
+ data["grammar"].is_string() && data["grammar"].get().empty() ? 1 : 0);
+ }
+
+ // Copy other relevant fields from data that oaicompat_chat_params_parse expects
+ // Tools and tool_choice are only passed when NoGrammar is true (grammar not provided)
+ // When grammar is provided from Go layer, we use it instead of template-generated grammar
+ if (!has_grammar_from_go) {
+ // NoGrammar=true: pass tools and let template generate grammar
+ if (data.contains("tools")) {
+ body_json["tools"] = data["tools"];
+ std::string tools_str = data["tools"].dump();
+ SRV_INF("Using tools from data (NoGrammar=true): %s\n", tools_str.c_str());
+ // Debug: Log tools count and details before template processing
+ if (data["tools"].is_array()) {
+ SRV_INF("[TOOLS DEBUG] Predict: Passing %zu tools to oaicompat_chat_params_parse\n", data["tools"].size());
+ for (size_t t_idx = 0; t_idx < data["tools"].size(); t_idx++) {
+ const auto& tool = data["tools"][t_idx];
+ std::string tool_name = "unknown";
+ std::string tool_desc = "";
+ if (tool.contains("function")) {
+ const auto& func = tool["function"];
+ if (func.contains("name")) {
+ tool_name = func["name"].get();
+ }
+ if (func.contains("description")) {
+ tool_desc = func["description"].is_string() ?
+ func["description"].get() : "";
+ }
+ } else if (tool.contains("name")) {
+ tool_name = tool["name"].get();
+ if (tool.contains("description")) {
+ tool_desc = tool["description"].is_string() ?
+ tool["description"].get() : "";
+ }
+ }
+ SRV_INF("[TOOLS DEBUG] Predict: Tool %zu: name=%s, description=%s\n",
+ t_idx, tool_name.c_str(), tool_desc.substr(0, 100).c_str());
+ }
+ }
+ } else {
+ SRV_WRN("%s", "No tools found in data - tool calls will not work without tools field\n");
+ SRV_DBG("[TOOLS DEBUG] Predict: No tools in data, tool_choice=%s\n", data.contains("tool_choice") ? data["tool_choice"].dump().c_str() : "not set");
+ }
+ if (data.contains("tool_choice")) {
+ // tool_choice can be a string or object, but oaicompat_chat_params_parse expects a string
+ // Convert object tool_choice to "required" (since a specific function is requested)
+ if (data["tool_choice"].is_string()) {
+ body_json["tool_choice"] = data["tool_choice"].get();
+ } else if (data["tool_choice"].is_object()) {
+ // Object tool_choice means a specific function is requested, use "required"
+ body_json["tool_choice"] = "required";
+ std::string tool_choice_obj_str = data["tool_choice"].dump();
+ SRV_INF("Converted object tool_choice to 'required': %s\n", tool_choice_obj_str.c_str());
+ } else {
+ // Fallback: convert to string
+ body_json["tool_choice"] = data["tool_choice"].dump();
+ }
+ std::string tool_choice_str = body_json["tool_choice"].get();
+ SRV_INF("Using tool_choice: %s\n", tool_choice_str.c_str());
+ } else {
+ // Default to "auto" if not specified
+ body_json["tool_choice"] = "auto";
+ }
+ } else {
+ // Grammar is provided from Go layer (NoGrammar=false) - use it, don't pass tools
+ SRV_INF("%s", "Grammar provided from Go layer - using it instead of template-generated grammar\n");
+ // Grammar will be copied from data after parsing (it's already in data)
+ }
+
+ if (data.contains("json_schema")) {
+ body_json["json_schema"] = data["json_schema"];
+ }
+ // If grammar is provided from Go layer, copy it to body_json so it's preserved
+ // (though oaicompat_chat_params_parse may not use it if tools are present)
+ if (has_grammar_from_go) {
+ body_json["grammar"] = data["grammar"];
+ }
+ if (data.contains("response_format")) {
+ body_json["response_format"] = data["response_format"];
+ }
+ if (data.contains("chat_template_kwargs")) {
+ body_json["chat_template_kwargs"] = data["chat_template_kwargs"];
+ }
+ // Pass parallel_tool_calls if present (used by oaicompat_chat_params_parse)
+ if (data.contains("parallel_tool_calls")) {
+ body_json["parallel_tool_calls"] = data["parallel_tool_calls"];
+ }
+ // Pass add_generation_prompt if present (used by oaicompat_chat_params_parse)
+ if (data.contains("add_generation_prompt")) {
+ body_json["add_generation_prompt"] = data["add_generation_prompt"];
+ }
+
+ // Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
+ SRV_DBG("[CONVERSATION DEBUG] Predict: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
+
+ // Use the same approach as server.cpp: call oaicompat_chat_params_parse
+ // This handles all template application, grammar merging, etc. automatically
+ // Files extracted from multimodal content in messages will be added to the files vector
+ // Create parser options with current chat_templates to ensure tmpls is not null
+ oaicompat_parser_options parser_opt = ctx_server.impl->oai_parser_opt;
+ parser_opt.tmpls = ctx_server.impl->chat_templates.get(); // Ensure tmpls is set to current chat_templates
+ // Update allow_image and allow_audio based on current mctx state
+ parser_opt.allow_image = ctx_server.impl->mctx ? mtmd_support_vision(ctx_server.impl->mctx) : false;
+ parser_opt.allow_audio = ctx_server.impl->mctx ? mtmd_support_audio(ctx_server.impl->mctx) : false;
+
+ // Debug: Log tools before template processing
+ if (body_json.contains("tools")) {
+ SRV_DBG("[TOOLS DEBUG] Predict: Before oaicompat_chat_params_parse - tools count: %zu\n",
+ body_json["tools"].is_array() ? body_json["tools"].size() : 0);
+ }
+
+ // Debug: Verify messages content before template processing
+ // Also ensure ALL messages have content set to string (not null) - templates expect strings
+ if (body_json.contains("messages") && body_json["messages"].is_array()) {
+ SRV_INF("[CONTENT DEBUG] Predict: Before oaicompat_chat_params_parse - checking %zu messages\n", body_json["messages"].size());
+ for (size_t idx = 0; idx < body_json["messages"].size(); idx++) {
+ auto& msg = body_json["messages"][idx];
+ std::string role_str = msg.contains("role") ? msg["role"].get() : "unknown";
+ if (msg.contains("content")) {
+ if (msg["content"].is_null()) {
+ SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=%s) has NULL content - FIXING!\n", idx, role_str.c_str());
+ msg["content"] = ""; // Fix null content
+ } else if (role_str == "tool" && msg["content"].is_array()) {
+ // Tool messages must have string content, not array
+ // oaicompat_chat_params_parse expects tool messages to have string content
+ SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=tool) has array content, converting to string\n", idx);
+ msg["content"] = msg["content"].dump();
+ } else if (!msg["content"].is_string() && !msg["content"].is_array()) {
+ // If content is object or other non-string type, convert to string for templates
+ SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=%s) content is not string/array, converting\n", idx, role_str.c_str());
+ if (msg["content"].is_object()) {
+ msg["content"] = msg["content"].dump();
+ } else {
+ msg["content"] = "";
+ }
+ } else {
+ SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=%s): content type=%s\n",
+ idx, role_str.c_str(),
+ msg["content"].is_string() ? "string" :
+ msg["content"].is_array() ? "array" :
+ msg["content"].is_object() ? "object" : "other");
+ }
+ } else {
+ SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=%s) MISSING content field - ADDING!\n", idx, role_str.c_str());
+ msg["content"] = ""; // Add missing content
+ }
+ }
+ }
+
+ json parsed_data = oaicompat_chat_params_parse(body_json, parser_opt, files);
+
+ // Debug: Log tools after template processing
+ if (parsed_data.contains("tools")) {
+ SRV_DBG("[TOOLS DEBUG] Predict: After oaicompat_chat_params_parse - tools count: %zu\n",
+ parsed_data["tools"].is_array() ? parsed_data["tools"].size() : 0);
+ } else {
+ SRV_DBG("%s", "[TOOLS DEBUG] Predict: After oaicompat_chat_params_parse - no tools in parsed_data\n");
+ }
+
+ // Extract the prompt from parsed data
+ prompt_str = parsed_data.at("prompt").get();
+
+ // Preserve grammar from Go layer if it was provided (NoGrammar=false)
+ // Otherwise, use grammar from parsed_data (template-generated when NoGrammar=true)
+ json preserved_grammar;
+ if (has_grammar_from_go && data.contains("grammar")) {
+ preserved_grammar = data["grammar"];
+ }
+
+ // Merge all fields from parsed_data into data (grammar, grammar_triggers, preserved_tokens, parse_tool_calls, etc.)
+ // This ensures all template-generated fields are included
+ // parse_tool_calls is set by oaicompat_chat_params_parse when tools are present
+ for (const auto& item : parsed_data.items()) {
+ if (item.key() != "prompt") { // Don't overwrite prompt_str, we already extracted it
+ // If grammar was provided from Go layer, preserve it instead of template-generated grammar
+ if (item.key() == "grammar" && has_grammar_from_go && !preserved_grammar.is_null()) {
+ data["grammar"] = preserved_grammar;
+ } else {
+ data[item.key()] = item.value();
+ }
+ }
+ }
+
+ // Debug: Log parse_tool_calls if present (set by oaicompat_chat_params_parse when tools are present)
+ if (data.contains("parse_tool_calls")) {
+ SRV_DBG("[TOOLS DEBUG] Predict: parse_tool_calls=%s\n", data["parse_tool_calls"].get() ? "true" : "false");
+ }
+ } else {
+ // Use prompt directly from data
+ if (data.contains("prompt") && data["prompt"].is_string()) {
+ prompt_str = data["prompt"].get();
+ } else {
+ prompt_str = request->prompt();
+ }
+ }
+
+ const auto type = SERVER_TASK_TYPE_COMPLETION;
+ // TODO: this log can become very long, put it behind a flag or think about a more compact format
+ //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str());
+
+ // If not using chat templates, extract files from image_data/audio_data fields
+ // (If using chat templates, files were already extracted by oaicompat_chat_params_parse)
+ if (!request->usetokenizertemplate() || request->messages_size() == 0 || ctx_server.impl->chat_templates == nullptr) {
+ const auto &images_data = data.find("image_data");
+ if (images_data != data.end() && images_data->is_array())
+ {
+ std::cout << "[PREDICT] Processing " << images_data->size() << " images" << std::endl;
+ for (const auto &img : *images_data)
+ {
+ std::cout << "[PREDICT] Processing image" << std::endl;
+ auto decoded_data = base64_decode(img["data"].get());
+ files.push_back(decoded_data);
+ }
+ }
+
+ const auto &audio_data = data.find("audio_data");
+ if (audio_data != data.end() && audio_data->is_array())
+ {
+ for (const auto &audio : *audio_data)
+ {
+ auto decoded_data = base64_decode(audio["data"].get());
+ files.push_back(decoded_data);
+ }
+ }
+ }
+
+ // process files
+ const bool has_mtmd = ctx_server.impl->mctx != nullptr;
+
+ // process prompt
+ std::vector inputs;
+ if (has_mtmd) {
+ // multimodal
+ inputs.push_back(process_mtmd_prompt(ctx_server.impl->mctx, prompt_str, files));
+ } else {
+ // Everything else, including multimodal completions.
+ inputs = tokenize_input_prompts(ctx_server.impl->vocab, ctx_server.impl->mctx, prompt_str, true, true);
+ }
+
+ tasks.reserve(inputs.size());
+ for (size_t i = 0; i < inputs.size(); i++) {
+ server_task task = server_task(type);
+
+ task.id = rd.queue_tasks.get_new_id();
+ task.index = i;
+
+ task.tokens = std::move(inputs[i]);
+ task.params = server_task::params_from_json_cmpl(
+ ctx_server.impl->vocab,
+ params_base,
+ ctx_server.get_meta().slot_n_ctx,
+ data);
+ task.id_slot = json_value(data, "id_slot", -1);
+
+ // OAI-compat
+ task.params.res_type = TASK_RESPONSE_TYPE_NONE;
+ task.params.oaicompat_cmpl_id = completion_id;
+ // oaicompat_model is already populated by params_from_json_cmpl
+
+ tasks.push_back(std::move(task));
+ }
+
+ rd.post_tasks(std::move(tasks));
+ } catch (const std::exception & e) {
+ return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
+ }
+
+
+ std::cout << "[DEBUG] Waiting for results..." << std::endl;
+
+ // Wait for all results
+ auto all_results = rd.wait_for_all([&context]() { return context->IsCancelled(); });
+
+ if (all_results.is_terminated) {
+ return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+ } else if (all_results.error) {
+ std::cout << "[DEBUG] Error in results: " << all_results.error->to_json().value("message", "") << std::endl;
+ reply->set_message(all_results.error->to_json().value("message", ""));
+ return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error occurred"));
+ } else {
+ std::cout << "[DEBUG] Received " << all_results.results.size() << " results" << std::endl;
+ if (all_results.results.size() == 1) {
+ // single result
+ GGML_ASSERT(dynamic_cast(all_results.results[0].get()) != nullptr);
+ json result_json = all_results.results[0]->to_json();
+ reply->set_message(result_json.value("content", ""));
+
+ int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
+ reply->set_tokens(tokens_predicted);
+ int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
+ reply->set_prompt_tokens(tokens_evaluated);
+
+ if (result_json.contains("timings")) {
+ double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
+ reply->set_timing_prompt_processing(timing_prompt_processing);
+ double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
+ reply->set_timing_token_generation(timing_token_generation);
+ }
+
+ // Extract and set logprobs if present
+ json logprobs_json = extract_logprobs_from_json(result_json);
+ if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+ std::string logprobs_str = logprobs_json.dump();
+ reply->set_logprobs(logprobs_str);
+ }
+
+ } else {
+ // multiple results (multitask)
+ json arr = json::array();
+ json logprobs_arr = json::array();
+ bool has_logprobs = false;
+ for (auto & res : all_results.results) {
+ GGML_ASSERT(dynamic_cast(res.get()) != nullptr);
+ json res_json = res->to_json();
+ arr.push_back(res_json.value("content", ""));
+
+ // Extract logprobs for each result
+ json logprobs_json = extract_logprobs_from_json(res_json);
+ if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+ has_logprobs = true;
+ logprobs_arr.push_back(logprobs_json);
+ } else {
+ logprobs_arr.push_back(json::object());
+ }
+ }
+ reply->set_message(arr);
+
+ // Set logprobs if any result has them
+ if (has_logprobs) {
+ std::string logprobs_str = logprobs_arr.dump();
+ reply->set_logprobs(logprobs_str);
+ }
+ }
+ }
+
+ std::cout << "[DEBUG] Predict request completed successfully" << std::endl;
+
+ // Check if context was cancelled during processing
+ if (context->IsCancelled()) {
+ return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+ }
+
+ return grpc::Status::OK;
+ }
+
+ grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) override {
+ if (params_base.model.path.empty()) {
+ return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
+ }
+ json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
+
+ body["stream"] = false;
+
+ /*
+ if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+ return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Pooling type 'none' is not OAI compatible. Please use a different pooling type");
+ }
+ */
+
+ // for the shape of input/content, see tokenize_input_prompts()
+ json prompt = body.at("embeddings");
+
+
+ auto tokenized_prompts = tokenize_input_prompts(ctx_server.impl->vocab, ctx_server.impl->mctx, prompt, true, true);
+ for (const auto & tokens : tokenized_prompts) {
+ // this check is necessary for models that do not add BOS token to the input
+ if (tokens.empty()) {
+ return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Input content cannot be empty");
+ }
+ }
+
+ int embd_normalize = 2; // default to Euclidean/L2 norm
+ // create and queue the task
+ auto rd = ctx_server.get_response_reader();
+ {
+ std::vector tasks;
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+ server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+
+ task.id = rd.queue_tasks.get_new_id();
+ task.index = i;
+ task.tokens = std::move(tokenized_prompts[i]);
+
+ task.params.res_type = TASK_RESPONSE_TYPE_NONE;
+ task.params.embd_normalize = embd_normalize;
+ tasks.push_back(std::move(task));
+ }
+
+ rd.post_tasks(std::move(tasks));
+ }
+
+ // Wait for all results
+ auto all_results = rd.wait_for_all([&context]() { return context->IsCancelled(); });
+
+ if (all_results.is_terminated) {
+ return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+ } else if (all_results.error) {
+ return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error in receiving results"));
+ }
+
+ // Collect responses
+ json responses = json::array();
+ for (auto & res : all_results.results) {
+ GGML_ASSERT(dynamic_cast(res.get()) != nullptr);
+ responses.push_back(res->to_json());
+ }
+
+ std::cout << "[DEBUG] Responses size: " << responses.size() << std::endl;
+
+ // Process the responses and extract embeddings
+ for (const auto & response_elem : responses) {
+ // Check if the response has an "embedding" field
+ if (response_elem.contains("embedding")) {
+ json embedding_data = json_value(response_elem, "embedding", json::array());
+
+ if (embedding_data.is_array() && !embedding_data.empty()) {
+ for (const auto & embedding_vector : embedding_data) {
+ if (embedding_vector.is_array()) {
+ for (const auto & embedding_value : embedding_vector) {
+ embeddingResult->add_embeddings(embedding_value.get());
+ }
+ }
+ }
+ }
+ } else {
+ // Check if the response itself contains the embedding data directly
+ if (response_elem.is_array()) {
+ for (const auto & embedding_value : response_elem) {
+ embeddingResult->add_embeddings(embedding_value.get());
+ }
+ }
+ }
+ }
+
+
+
+
+ return grpc::Status::OK;
+ }
+
+ grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) override {
+ if (!params_base.embedding || params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
+ return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
+ }
+
+ // Validate request
+ if (request->query().empty()) {
+ return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"query\" must be provided");
+ }
+
+ if (request->documents_size() == 0) {
+ return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
+ }
+
+ // Create and queue the task
+ auto rd = ctx_server.get_response_reader();
+ {
+ std::vector tasks;
+ std::vector documents;
+ for (int i = 0; i < request->documents_size(); i++) {
+ documents.push_back(request->documents(i));
+ }
+
+ tasks.reserve(documents.size());
+ for (size_t i = 0; i < documents.size(); i++) {
+ auto tmp = format_prompt_rerank(ctx_server.impl->model, ctx_server.impl->vocab, ctx_server.impl->mctx, request->query(), documents[i]);
+ server_task task = server_task(SERVER_TASK_TYPE_RERANK);
+ task.id = rd.queue_tasks.get_new_id();
+ task.index = i;
+ task.tokens = std::move(tmp);
+ tasks.push_back(std::move(task));
+ }
+
+ rd.post_tasks(std::move(tasks));
+ }
+
+ // Wait for all results
+ auto all_results = rd.wait_for_all([&context]() { return context->IsCancelled(); });
+
+ if (all_results.is_terminated) {
+ return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+ } else if (all_results.error) {
+ return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error in receiving results"));
+ }
+
+ // Collect responses
+ json responses = json::array();
+ for (auto & res : all_results.results) {
+ GGML_ASSERT(dynamic_cast(res.get()) != nullptr);
+ responses.push_back(res->to_json());
+ }
+ // Sort responses by score in descending order
+ std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) {
+ return a.value("score", 0.0f) > b.value("score", 0.0f);
+ });
+
+ // Crop results by request.top_n if specified
+ int top_n = request->top_n();
+ if (top_n > 0 && top_n < static_cast(responses.size())) {
+ responses = json(responses.begin(), responses.begin() + top_n);
+ }
+ // Set usage information
+ backend::Usage* usage = rerankResult->mutable_usage();
+ int total_tokens = 0;
+ int prompt_tokens = 0;
+
+ // Create document results
+ for (const auto& response : responses) {
+ backend::DocumentResult* doc_result = rerankResult->add_results();
+ doc_result->set_index(response.value("index", 0));
+ doc_result->set_text(request->documents(response.value("index", 0)));
+ doc_result->set_relevance_score(response.value("score", 0.0f));
+
+ // Add tokens evaluated for this document
+ int tokens_evaluated = response.value("tokens_evaluated", 0);
+ total_tokens += tokens_evaluated;
+ prompt_tokens += tokens_evaluated;
+ }
+
+ // Set the total tokens in usage
+ usage->set_total_tokens(total_tokens);
+ usage->set_prompt_tokens(prompt_tokens);
+
+ return grpc::Status::OK;
+ }
+
+ grpc::Status TokenizeString(ServerContext* /*context*/, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
+ if (params_base.model.path.empty()) {
+ return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
+ }
+ json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
+ body["stream"] = false;
+
+ json tokens_response = json::array();
+ if (body.count("prompt") != 0) {
+ const bool add_special = json_value(body, "add_special", false);
+
+ llama_tokens tokens = tokenize_mixed(ctx_server.impl->vocab, body.at("content"), add_special, true);
+
+
+ for (const auto& token : tokens) {
+ std::string piece = common_token_to_piece(ctx_server.get_llama_context(), token);
+ response->add_tokens(token);
+ }
+ }
+
+ return grpc::Status::OK;
+ }
+
+ grpc::Status GetMetrics(ServerContext* /*context*/, const backend::MetricsRequest* /*request*/, backend::MetricsResponse* response) override {
+
+// request slots data using task queue
+ auto rd = ctx_server.get_response_reader();
+ int task_id = rd.queue_tasks.get_new_id();
+ {
+ server_task task(SERVER_TASK_TYPE_METRICS);
+ task.id = task_id;
+ rd.queue_results.add_waiting_task_id(task_id);
+ rd.queue_tasks.post(std::move(task), true); // high-priority task
+ }
+
+ // get the result
+ server_task_result_ptr result = rd.queue_results.recv(task_id);
+ rd.queue_results.remove_waiting_task_id(task_id);
+
+ if (result->is_error()) {
+ // Handle case when no active slot exists
+ response->set_slot_id(0);
+ response->set_prompt_json_for_slot("");
+ response->set_tokens_per_second(0);
+ response->set_tokens_generated(0);
+ response->set_prompt_tokens_processed(0);
+ return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
+ }
+
+ // TODO: get rid of this dynamic_cast
+ auto res_metrics = dynamic_cast(result.get());
+ GGML_ASSERT(res_metrics != nullptr);
+
+ // Populate the response with metrics
+ response->set_slot_id(0);
+ response->set_prompt_json_for_slot("");
+ response->set_tokens_per_second(res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.);
+ response->set_tokens_generated(res_metrics->n_tokens_predicted_total);
+ response->set_prompt_tokens_processed(res_metrics->n_prompt_tokens_processed_total);
+
+
+ return grpc::Status::OK;
+ }
+};
+
+
+int main(int argc, char** argv) {
+ std::string server_address("localhost:50051");
+
+ // Define long and short options
+ struct option long_options[] = {
+ {"addr", required_argument, nullptr, 'a'},
+ {nullptr, 0, nullptr, 0}
+ };
+
+ // Parse command-line arguments
+ int option;
+ int option_index = 0;
+ while ((option = getopt_long(argc, argv, "a:", long_options, &option_index)) != -1) {
+ switch (option) {
+ case 'a':
+ server_address = optarg;
+ break;
+ default:
+ std::cerr << "Usage: " << argv[0] << " [--addr=] or [-a ]" << std::endl;
+ return 1;
+ }
+ }
+
+ server_context ctx_server;
+ BackendServiceImpl service(ctx_server);
+
+ ServerBuilder builder;
+ builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
+ builder.RegisterService(&service);
+ builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
+ builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
+ builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
+ std::unique_ptr server(builder.BuildAndStart());
+ // run the HTTP server in a thread - see comment below
+ std::thread t([&]()
+ {
+ std::cout << "Server listening on " << server_address << std::endl;
+ server->Wait();
+ return 0;
+ });
+
+ // clean up function, to be called before exit
+ auto clean_up = [&server, &ctx_server]() {
+ SRV_INF("%s: cleaning up before exit...\n", __func__);
+ server->Shutdown();
+ ctx_server.terminate();
+ llama_backend_free();
+ };
+
+
+ //);
+ start_llama_server(ctx_server);
+ std::cout << "stopping" << std::endl;
+
+
+ clean_up();
+ t.join();
+
+ return 0;
+}
diff --git a/backend/cpp/llama-cpp/package.sh b/backend/cpp/llama-cpp/package.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b1b7cd9a818a36cd3fb2217886ffb798c7d80c06
--- /dev/null
+++ b/backend/cpp/llama-cpp/package.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+
+# Script to copy the appropriate libraries based on architecture
+# This script is used in the final stage of the Dockerfile
+
+set -e
+
+CURDIR=$(dirname "$(realpath $0)")
+REPO_ROOT="${CURDIR}/../../.."
+
+# Create lib directory
+mkdir -p $CURDIR/package/lib
+
+cp -avrf $CURDIR/llama-cpp-* $CURDIR/package/
+cp -rfv $CURDIR/run.sh $CURDIR/package/
+
+# Detect architecture and copy appropriate libraries
+if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
+ # x86_64 architecture
+ echo "Detected x86_64 architecture, copying x86_64 libraries..."
+ cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
+ cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
+elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
+ # ARM64 architecture
+ echo "Detected ARM64 architecture, copying ARM64 libraries..."
+ cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
+ cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
+else
+ echo "Error: Could not detect architecture"
+ exit 1
+fi
+
+# Package GPU libraries based on BUILD_TYPE
+# The GPU library packaging script will detect BUILD_TYPE and copy appropriate GPU libraries
+GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
+if [ -f "$GPU_LIB_SCRIPT" ]; then
+ echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
+ source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
+ package_gpu_libs
+fi
+
+echo "Packaging completed successfully"
+ls -liah $CURDIR/package/
+ls -liah $CURDIR/package/lib/
\ No newline at end of file
diff --git a/backend/cpp/llama-cpp/prepare.sh b/backend/cpp/llama-cpp/prepare.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f9b7e3dd2651897e458ddfb65eb0a4f6e10ae666
--- /dev/null
+++ b/backend/cpp/llama-cpp/prepare.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+## Patches
+
+## Apply patches from the `patches` directory
+if [ -d "patches" ]; then
+ for patch in $(ls patches); do
+ echo "Applying patch $patch"
+ patch -d llama.cpp/ -p1 < patches/$patch
+ done
+fi
+
+set -e
+
+for file in $(ls llama.cpp/tools/server/); do
+ cp -rfv llama.cpp/tools/server/$file llama.cpp/tools/grpc-server/
+done
+
+cp -r CMakeLists.txt llama.cpp/tools/grpc-server/
+cp -r grpc-server.cpp llama.cpp/tools/grpc-server/
+cp -rfv llama.cpp/vendor/nlohmann/json.hpp llama.cpp/tools/grpc-server/
+cp -rfv llama.cpp/vendor/cpp-httplib/httplib.h llama.cpp/tools/grpc-server/
+
+set +e
+if grep -q "grpc-server" llama.cpp/tools/CMakeLists.txt; then
+ echo "grpc-server already added"
+else
+ echo "add_subdirectory(grpc-server)" >> llama.cpp/tools/CMakeLists.txt
+fi
+set -e
+
diff --git a/backend/cpp/llama-cpp/run.sh b/backend/cpp/llama-cpp/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2f1ff13cf3096fca4b0242dc324c255540bd1fb0
--- /dev/null
+++ b/backend/cpp/llama-cpp/run.sh
@@ -0,0 +1,62 @@
+#!/bin/bash
+set -ex
+
+# Get the absolute current dir where the script is located
+CURDIR=$(dirname "$(realpath $0)")
+
+cd /
+
+echo "CPU info:"
+grep -e "model\sname" /proc/cpuinfo | head -1
+grep -e "flags" /proc/cpuinfo | head -1
+
+BINARY=llama-cpp-fallback
+
+if grep -q -e "\savx\s" /proc/cpuinfo ; then
+ echo "CPU: AVX found OK"
+ if [ -e $CURDIR/llama-cpp-avx ]; then
+ BINARY=llama-cpp-avx
+ fi
+fi
+
+if grep -q -e "\savx2\s" /proc/cpuinfo ; then
+ echo "CPU: AVX2 found OK"
+ if [ -e $CURDIR/llama-cpp-avx2 ]; then
+ BINARY=llama-cpp-avx2
+ fi
+fi
+
+# Check avx 512
+if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
+ echo "CPU: AVX512F found OK"
+ if [ -e $CURDIR/llama-cpp-avx512 ]; then
+ BINARY=llama-cpp-avx512
+ fi
+fi
+
+if [ -n "$LLAMACPP_GRPC_SERVERS" ]; then
+ if [ -e $CURDIR/llama-cpp-grpc ]; then
+ BINARY=llama-cpp-grpc
+ fi
+fi
+
+# Extend ld library path with the dir where this script is located/lib
+if [ "$(uname)" == "Darwin" ]; then
+ export DYLD_LIBRARY_PATH=$CURDIR/lib:$DYLD_LIBRARY_PATH
+ #export DYLD_FALLBACK_LIBRARY_PATH=$CURDIR/lib:$DYLD_FALLBACK_LIBRARY_PATH
+else
+ export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
+fi
+
+# If there is a lib/ld.so, use it
+if [ -f $CURDIR/lib/ld.so ]; then
+ echo "Using lib/ld.so"
+ echo "Using binary: $BINARY"
+ exec $CURDIR/lib/ld.so $CURDIR/$BINARY "$@"
+fi
+
+echo "Using binary: $BINARY"
+exec $CURDIR/$BINARY "$@"
+
+# We should never reach this point, however just in case we do, run fallback
+exec $CURDIR/llama-cpp-fallback "$@"
\ No newline at end of file
diff --git a/backend/go/bark-cpp/Makefile b/backend/go/bark-cpp/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..1bff58c4fad1ad387ea5f23ac89ed770dc518a58
--- /dev/null
+++ b/backend/go/bark-cpp/Makefile
@@ -0,0 +1,51 @@
+INCLUDE_PATH := $(abspath ./)
+LIBRARY_PATH := $(abspath ./)
+
+AR?=ar
+
+CMAKE_ARGS?=-DGGML_NATIVE=OFF
+BUILD_TYPE?=
+GOCMD=go
+# keep standard at C11 and C++11
+CXXFLAGS = -I. -I$(INCLUDE_PATH)/sources/bark.cpp/examples -I$(INCLUDE_PATH)/sources/bark.cpp/encodec.cpp/ggml/include -I$(INCLUDE_PATH)/sources/bark.cpp/spm-headers -I$(INCLUDE_PATH)/sources/bark.cpp -O3 -DNDEBUG -std=c++17 -fPIC
+LDFLAGS = -L$(LIBRARY_PATH) -L$(LIBRARY_PATH)/sources/bark.cpp/build/examples -lbark -lstdc++ -lm
+
+# bark.cpp
+BARKCPP_REPO?=https://github.com/PABannier/bark.cpp.git
+BARKCPP_VERSION?=5d5be84f089ab9ea53b7a793f088d3fbf7247495
+
+# warnings
+CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
+
+## bark.cpp
+sources/bark.cpp:
+ git clone --recursive $(BARKCPP_REPO) sources/bark.cpp && \
+ cd sources/bark.cpp && \
+ git checkout $(BARKCPP_VERSION) && \
+ git submodule update --init --recursive --depth 1 --single-branch
+
+sources/bark.cpp/build/libbark.a: sources/bark.cpp
+ cd sources/bark.cpp && \
+ mkdir -p build && \
+ cd build && \
+ cmake $(CMAKE_ARGS) .. && \
+ cmake --build . --config Release
+
+gobark.o:
+ $(CXX) $(CXXFLAGS) gobark.cpp -o gobark.o -c $(LDFLAGS)
+
+libbark.a: sources/bark.cpp/build/libbark.a gobark.o
+ cp $(INCLUDE_PATH)/sources/bark.cpp/build/libbark.a ./
+ $(AR) rcs libbark.a gobark.o
+
+bark-cpp: libbark.a
+ CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH="$(CURDIR)" LIBRARY_PATH=$(CURDIR) \
+ $(GOCMD) build -v -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o bark-cpp ./
+
+package:
+ bash package.sh
+
+build: bark-cpp package
+
+clean:
+ rm -f gobark.o libbark.a
\ No newline at end of file
diff --git a/backend/go/bark-cpp/gobark.cpp b/backend/go/bark-cpp/gobark.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..fa4bb336f91e12a8368471b60f6c73862904c4c8
--- /dev/null
+++ b/backend/go/bark-cpp/gobark.cpp
@@ -0,0 +1,85 @@
+#include
+#include
+
+#include "bark.h"
+#include "gobark.h"
+#include "common.h"
+#include "ggml.h"
+
+struct bark_context *c;
+
+void bark_print_progress_callback(struct bark_context *bctx, enum bark_encoding_step step, int progress, void *user_data) {
+ if (step == bark_encoding_step::SEMANTIC) {
+ printf("\rGenerating semantic tokens... %d%%", progress);
+ } else if (step == bark_encoding_step::COARSE) {
+ printf("\rGenerating coarse tokens... %d%%", progress);
+ } else if (step == bark_encoding_step::FINE) {
+ printf("\rGenerating fine tokens... %d%%", progress);
+ }
+ fflush(stdout);
+}
+
+int load_model(char *model) {
+ // initialize bark context
+ struct bark_context_params ctx_params = bark_context_default_params();
+ bark_params params;
+
+ params.model_path = model;
+
+ // ctx_params.verbosity = verbosity;
+ ctx_params.progress_callback = bark_print_progress_callback;
+ ctx_params.progress_callback_user_data = nullptr;
+
+ struct bark_context *bctx = bark_load_model(params.model_path.c_str(), ctx_params, params.seed);
+ if (!bctx) {
+ fprintf(stderr, "%s: Could not load model\n", __func__);
+ return 1;
+ }
+
+ c = bctx;
+
+ return 0;
+}
+
+int tts(char *text,int threads, char *dst ) {
+
+ ggml_time_init();
+ const int64_t t_main_start_us = ggml_time_us();
+
+ // generate audio
+ if (!bark_generate_audio(c, text, threads)) {
+ fprintf(stderr, "%s: An error occurred. If the problem persists, feel free to open an issue to report it.\n", __func__);
+ return 1;
+ }
+
+ const float *audio_data = bark_get_audio_data(c);
+ if (audio_data == NULL) {
+ fprintf(stderr, "%s: Could not get audio data\n", __func__);
+ return 1;
+ }
+
+ const int audio_arr_size = bark_get_audio_data_size(c);
+
+ std::vector audio_arr(audio_data, audio_data + audio_arr_size);
+
+ write_wav_on_disk(audio_arr, dst);
+
+ // report timing
+ {
+ const int64_t t_main_end_us = ggml_time_us();
+ const int64_t t_load_us = bark_get_load_time(c);
+ const int64_t t_eval_us = bark_get_eval_time(c);
+
+ printf("\n\n");
+ printf("%s: load time = %8.2f ms\n", __func__, t_load_us / 1000.0f);
+ printf("%s: eval time = %8.2f ms\n", __func__, t_eval_us / 1000.0f);
+ printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
+ }
+
+ return 0;
+}
+
+int unload() {
+ bark_free(c);
+}
+
diff --git a/backend/go/bark-cpp/gobark.go b/backend/go/bark-cpp/gobark.go
new file mode 100644
index 0000000000000000000000000000000000000000..8b01ebe2f821796a727784532aa9c2ba50753789
--- /dev/null
+++ b/backend/go/bark-cpp/gobark.go
@@ -0,0 +1,52 @@
+package main
+
+// #cgo CXXFLAGS: -I${SRCDIR}/sources/bark.cpp/ -I${SRCDIR}/sources/bark.cpp/encodec.cpp -I${SRCDIR}/sources/bark.cpp/encodec.cpp/ggml/include -I${SRCDIR}/sources/bark.cpp/examples -I${SRCDIR}/sources/bark.cpp/spm-headers
+// #cgo LDFLAGS: -L${SRCDIR}/ -L${SRCDIR}/sources/bark.cpp/build/examples -L${SRCDIR}/sources/bark.cpp/build/encodec.cpp/ggml/src/ -L${SRCDIR}/sources/bark.cpp/build/encodec.cpp/ -lbark -lencodec -lcommon -lggml -lgomp
+// #include
+// #include
+import "C"
+
+import (
+ "fmt"
+ "unsafe"
+
+ "github.com/mudler/LocalAI/pkg/grpc/base"
+ pb "github.com/mudler/LocalAI/pkg/grpc/proto"
+)
+
+type Bark struct {
+ base.SingleThread
+ threads int
+}
+
+func (sd *Bark) Load(opts *pb.ModelOptions) error {
+
+ sd.threads = int(opts.Threads)
+
+ modelFile := C.CString(opts.ModelFile)
+ defer C.free(unsafe.Pointer(modelFile))
+
+ ret := C.load_model(modelFile)
+ if ret != 0 {
+ return fmt.Errorf("inference failed")
+ }
+
+ return nil
+}
+
+func (sd *Bark) TTS(opts *pb.TTSRequest) error {
+ t := C.CString(opts.Text)
+ defer C.free(unsafe.Pointer(t))
+
+ dst := C.CString(opts.Dst)
+ defer C.free(unsafe.Pointer(dst))
+
+ threads := C.int(sd.threads)
+
+ ret := C.tts(t, threads, dst)
+ if ret != 0 {
+ return fmt.Errorf("inference failed")
+ }
+
+ return nil
+}
diff --git a/backend/go/bark-cpp/gobark.h b/backend/go/bark-cpp/gobark.h
new file mode 100644
index 0000000000000000000000000000000000000000..06fb965d5db44d0c813ca4985fd77506a7b97c1b
--- /dev/null
+++ b/backend/go/bark-cpp/gobark.h
@@ -0,0 +1,8 @@
+#ifdef __cplusplus
+extern "C" {
+#endif
+int load_model(char *model);
+int tts(char *text,int threads, char *dst );
+#ifdef __cplusplus
+}
+#endif
\ No newline at end of file
diff --git a/backend/go/bark-cpp/main.go b/backend/go/bark-cpp/main.go
new file mode 100644
index 0000000000000000000000000000000000000000..840a687d4b140c0c8e42132d5dca482dd8248380
--- /dev/null
+++ b/backend/go/bark-cpp/main.go
@@ -0,0 +1,20 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+import (
+ "flag"
+
+ grpc "github.com/mudler/LocalAI/pkg/grpc"
+)
+
+var (
+ addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+func main() {
+ flag.Parse()
+
+ if err := grpc.StartServer(*addr, &Bark{}); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/go/bark-cpp/package.sh b/backend/go/bark-cpp/package.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6dce5851f292bdf87c3fdd8f301d3d31cd6fb339
--- /dev/null
+++ b/backend/go/bark-cpp/package.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+
+# Script to copy the appropriate libraries based on architecture
+# This script is used in the final stage of the Dockerfile
+
+set -e
+
+CURDIR=$(dirname "$(realpath $0)")
+
+# Create lib directory
+mkdir -p $CURDIR/package/lib
+cp -avrf $CURDIR/bark-cpp $CURDIR/package/
+cp -rfv $CURDIR/run.sh $CURDIR/package/
+
+# Detect architecture and copy appropriate libraries
+if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
+ # x86_64 architecture
+ echo "Detected x86_64 architecture, copying x86_64 libraries..."
+ cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
+ cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
+elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
+ # ARM64 architecture
+ echo "Detected ARM64 architecture, copying ARM64 libraries..."
+ cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
+ cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
+else
+ echo "Error: Could not detect architecture"
+ exit 1
+fi
+
+echo "Packaging completed successfully"
+ls -liah $CURDIR/package/
+ls -liah $CURDIR/package/lib/
\ No newline at end of file
diff --git a/backend/go/bark-cpp/run.sh b/backend/go/bark-cpp/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..567d3b89ef09d7ad3800397f2babbb5e4071f4d7
--- /dev/null
+++ b/backend/go/bark-cpp/run.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+set -ex
+
+CURDIR=$(dirname "$(realpath $0)")
+export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
+
+# If there is a lib/ld.so, use it
+if [ -f $CURDIR/lib/ld.so ]; then
+ echo "Using lib/ld.so"
+ exec $CURDIR/lib/ld.so $CURDIR/bark-cpp "$@"
+fi
+
+exec $CURDIR/bark-cpp "$@"
\ No newline at end of file
diff --git a/backend/go/huggingface/Makefile b/backend/go/huggingface/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..77b6c82ed2b7982ace959271a5ebc800efdac3cc
--- /dev/null
+++ b/backend/go/huggingface/Makefile
@@ -0,0 +1,9 @@
+GOCMD=go
+
+huggingface:
+ CGO_ENABLED=0 $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o huggingface ./
+
+package:
+ bash package.sh
+
+build: huggingface package
\ No newline at end of file
diff --git a/backend/go/huggingface/langchain.go b/backend/go/huggingface/langchain.go
new file mode 100644
index 0000000000000000000000000000000000000000..a18c6c87648bdeb2ca07cb3b89a5fc1b45cb4f94
--- /dev/null
+++ b/backend/go/huggingface/langchain.go
@@ -0,0 +1,64 @@
+package main
+
+// This is a wrapper to statisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+ "fmt"
+ "os"
+
+ "github.com/mudler/LocalAI/pkg/grpc/base"
+ pb "github.com/mudler/LocalAI/pkg/grpc/proto"
+ "github.com/mudler/LocalAI/pkg/langchain"
+)
+
+type LLM struct {
+ base.Base
+
+ langchain *langchain.HuggingFace
+ model string
+}
+
+func (llm *LLM) Load(opts *pb.ModelOptions) error {
+ var err error
+ hfToken := os.Getenv("HUGGINGFACEHUB_API_TOKEN")
+ if hfToken == "" {
+ return fmt.Errorf("no huggingface token provided")
+ }
+ llm.langchain, err = langchain.NewHuggingFace(opts.Model, hfToken)
+ llm.model = opts.Model
+ return err
+}
+
+func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
+ o := []langchain.PredictOption{
+ langchain.SetModel(llm.model),
+ langchain.SetMaxTokens(int(opts.Tokens)),
+ langchain.SetTemperature(float64(opts.Temperature)),
+ langchain.SetStopWords(opts.StopPrompts),
+ }
+ pred, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...)
+ if err != nil {
+ return "", err
+ }
+ return pred.Completion, nil
+}
+
+func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
+ o := []langchain.PredictOption{
+ langchain.SetModel(llm.model),
+ langchain.SetMaxTokens(int(opts.Tokens)),
+ langchain.SetTemperature(float64(opts.Temperature)),
+ langchain.SetStopWords(opts.StopPrompts),
+ }
+ go func() {
+ res, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...)
+
+ if err != nil {
+ fmt.Println("err: ", err)
+ }
+ results <- res.Completion
+ close(results)
+ }()
+
+ return nil
+}
diff --git a/backend/go/huggingface/main.go b/backend/go/huggingface/main.go
new file mode 100644
index 0000000000000000000000000000000000000000..acf4408799e123d5492a6cd34b1bef5ffae12dfc
--- /dev/null
+++ b/backend/go/huggingface/main.go
@@ -0,0 +1,21 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+
+import (
+ "flag"
+
+ grpc "github.com/mudler/LocalAI/pkg/grpc"
+)
+
+var (
+ addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+func main() {
+ flag.Parse()
+
+ if err := grpc.StartServer(*addr, &LLM{}); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/go/huggingface/package.sh b/backend/go/huggingface/package.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6218a65f690facc749a761b2ce0dbdfa44e2a960
--- /dev/null
+++ b/backend/go/huggingface/package.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+# Script to copy the appropriate libraries based on architecture
+# This script is used in the final stage of the Dockerfile
+
+set -e
+
+CURDIR=$(dirname "$(realpath $0)")
+
+mkdir -p $CURDIR/package
+cp -avrf $CURDIR/huggingface $CURDIR/package/
+cp -rfv $CURDIR/run.sh $CURDIR/package/
\ No newline at end of file
diff --git a/backend/go/huggingface/run.sh b/backend/go/huggingface/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..08972b5d27bddb25e5e7270c489d7116a37bf4e3
--- /dev/null
+++ b/backend/go/huggingface/run.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+set -ex
+
+CURDIR=$(dirname "$(realpath $0)")
+
+exec $CURDIR/huggingface "$@"
\ No newline at end of file
diff --git a/backend/go/llm/llama/llama.go b/backend/go/llm/llama/llama.go
new file mode 100644
index 0000000000000000000000000000000000000000..011023fe7ab9f550454c6f6cbc33565be966acd0
--- /dev/null
+++ b/backend/go/llm/llama/llama.go
@@ -0,0 +1,260 @@
+package main
+
+// This is a wrapper to statisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+ "fmt"
+ "path/filepath"
+
+ "github.com/go-skynet/go-llama.cpp"
+ "github.com/mudler/LocalAI/pkg/grpc/base"
+ pb "github.com/mudler/LocalAI/pkg/grpc/proto"
+)
+
+type LLM struct {
+ base.SingleThread
+
+ llama *llama.LLama
+ draftModel *llama.LLama
+}
+
+func (llm *LLM) Load(opts *pb.ModelOptions) error {
+ ropeFreqBase := float32(10000)
+ ropeFreqScale := float32(1)
+
+ if opts.RopeFreqBase != 0 {
+ ropeFreqBase = opts.RopeFreqBase
+ }
+ if opts.RopeFreqScale != 0 {
+ ropeFreqScale = opts.RopeFreqScale
+ }
+
+ llamaOpts := []llama.ModelOption{
+ llama.WithRopeFreqBase(ropeFreqBase),
+ llama.WithRopeFreqScale(ropeFreqScale),
+ }
+
+ if opts.NoMulMatQ {
+ llamaOpts = append(llamaOpts, llama.SetMulMatQ(false))
+ }
+
+ // Get base path of opts.ModelFile and use the same for lora (assume the same path)
+ basePath := filepath.Dir(opts.ModelFile)
+
+ if opts.LoraAdapter != "" {
+ llamaOpts = append(llamaOpts, llama.SetLoraAdapter(filepath.Join(basePath, opts.LoraAdapter)))
+ }
+
+ if opts.LoraBase != "" {
+ llamaOpts = append(llamaOpts, llama.SetLoraBase(filepath.Join(basePath, opts.LoraBase)))
+ }
+
+ if opts.ContextSize != 0 {
+ llamaOpts = append(llamaOpts, llama.SetContext(int(opts.ContextSize)))
+ }
+ if opts.F16Memory {
+ llamaOpts = append(llamaOpts, llama.EnableF16Memory)
+ }
+ if opts.Embeddings {
+ llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
+ }
+ if opts.Reranking {
+ llamaOpts = append(llamaOpts, llama.EnableReranking)
+ }
+ if opts.NGPULayers != 0 {
+ llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers)))
+ }
+
+ llamaOpts = append(llamaOpts, llama.SetMMap(opts.MMap))
+ llamaOpts = append(llamaOpts, llama.SetMainGPU(opts.MainGPU))
+ llamaOpts = append(llamaOpts, llama.SetTensorSplit(opts.TensorSplit))
+ if opts.NBatch != 0 {
+ llamaOpts = append(llamaOpts, llama.SetNBatch(int(opts.NBatch)))
+ } else {
+ llamaOpts = append(llamaOpts, llama.SetNBatch(512))
+ }
+
+ if opts.NUMA {
+ llamaOpts = append(llamaOpts, llama.EnableNUMA)
+ }
+
+ if opts.LowVRAM {
+ llamaOpts = append(llamaOpts, llama.EnabelLowVRAM)
+ }
+
+ if opts.DraftModel != "" {
+ // https://github.com/ggerganov/llama.cpp/blob/71ca2fad7d6c0ef95ef9944fb3a1a843e481f314/examples/speculative/speculative.cpp#L40
+ llamaOpts = append(llamaOpts, llama.SetPerplexity(true))
+ }
+
+ model, err := llama.New(opts.ModelFile, llamaOpts...)
+
+ if opts.DraftModel != "" {
+ // opts.DraftModel is relative to opts.ModelFile, so we need to get the basepath of opts.ModelFile
+ if !filepath.IsAbs(opts.DraftModel) {
+ dir := filepath.Dir(opts.ModelFile)
+ opts.DraftModel = filepath.Join(dir, opts.DraftModel)
+ }
+
+ draftModel, err := llama.New(opts.DraftModel, llamaOpts...)
+ if err != nil {
+ return err
+ }
+ llm.draftModel = draftModel
+ }
+
+ llm.llama = model
+
+ return err
+}
+
+func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
+ ropeFreqBase := float32(10000)
+ ropeFreqScale := float32(1)
+
+ if opts.RopeFreqBase != 0 {
+ ropeFreqBase = opts.RopeFreqBase
+ }
+ if opts.RopeFreqScale != 0 {
+ ropeFreqScale = opts.RopeFreqScale
+ }
+ predictOptions := []llama.PredictOption{
+ llama.SetTemperature(opts.Temperature),
+ llama.SetTopP(opts.TopP),
+ llama.SetTopK(int(opts.TopK)),
+ llama.SetTokens(int(opts.Tokens)),
+ llama.SetThreads(int(opts.Threads)),
+ llama.WithGrammar(opts.Grammar),
+ llama.SetRopeFreqBase(ropeFreqBase),
+ llama.SetRopeFreqScale(ropeFreqScale),
+ llama.SetNegativePromptScale(opts.NegativePromptScale),
+ llama.SetNegativePrompt(opts.NegativePrompt),
+ }
+
+ if opts.PromptCacheAll {
+ predictOptions = append(predictOptions, llama.EnablePromptCacheAll)
+ }
+
+ if opts.PromptCacheRO {
+ predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
+ }
+
+ // Expected absolute path
+ if opts.PromptCachePath != "" {
+ predictOptions = append(predictOptions, llama.SetPathPromptCache(opts.PromptCachePath))
+ }
+
+ if opts.Mirostat != 0 {
+ predictOptions = append(predictOptions, llama.SetMirostat(int(opts.Mirostat)))
+ }
+
+ if opts.MirostatETA != 0 {
+ predictOptions = append(predictOptions, llama.SetMirostatETA(opts.MirostatETA))
+ }
+
+ if opts.MirostatTAU != 0 {
+ predictOptions = append(predictOptions, llama.SetMirostatTAU(opts.MirostatTAU))
+ }
+
+ if opts.Debug {
+ predictOptions = append(predictOptions, llama.Debug)
+ }
+
+ predictOptions = append(predictOptions, llama.SetStopWords(opts.StopPrompts...))
+
+ if opts.PresencePenalty != 0 {
+ predictOptions = append(predictOptions, llama.SetPenalty(opts.PresencePenalty))
+ }
+
+ if opts.NKeep != 0 {
+ predictOptions = append(predictOptions, llama.SetNKeep(int(opts.NKeep)))
+ }
+
+ if opts.Batch != 0 {
+ predictOptions = append(predictOptions, llama.SetBatch(int(opts.Batch)))
+ }
+
+ if opts.F16KV {
+ predictOptions = append(predictOptions, llama.EnableF16KV)
+ }
+
+ if opts.IgnoreEOS {
+ predictOptions = append(predictOptions, llama.IgnoreEOS)
+ }
+
+ if opts.Seed != 0 {
+ predictOptions = append(predictOptions, llama.SetSeed(int(opts.Seed)))
+ }
+
+ if opts.NDraft != 0 {
+ predictOptions = append(predictOptions, llama.SetNDraft(int(opts.NDraft)))
+ }
+ //predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
+
+ predictOptions = append(predictOptions, llama.SetFrequencyPenalty(opts.FrequencyPenalty))
+ predictOptions = append(predictOptions, llama.SetMlock(opts.MLock))
+ predictOptions = append(predictOptions, llama.SetMemoryMap(opts.MMap))
+ predictOptions = append(predictOptions, llama.SetPredictionMainGPU(opts.MainGPU))
+ predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(opts.TensorSplit))
+ predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(opts.TailFreeSamplingZ))
+ predictOptions = append(predictOptions, llama.SetTypicalP(opts.TypicalP))
+ return predictOptions
+}
+
+func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
+ if llm.draftModel != nil {
+ return llm.llama.SpeculativeSampling(llm.draftModel, opts.Prompt, buildPredictOptions(opts)...)
+ }
+ return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...)
+}
+
+func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
+ predictOptions := buildPredictOptions(opts)
+
+ predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool {
+ results <- token
+ return true
+ }))
+
+ go func() {
+ var err error
+ if llm.draftModel != nil {
+ _, err = llm.llama.SpeculativeSampling(llm.draftModel, opts.Prompt, buildPredictOptions(opts)...)
+ } else {
+ _, err = llm.llama.Predict(opts.Prompt, predictOptions...)
+ }
+
+ if err != nil {
+ fmt.Println("err: ", err)
+ }
+ close(results)
+ }()
+
+ return nil
+}
+
+func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
+ predictOptions := buildPredictOptions(opts)
+
+ if len(opts.EmbeddingTokens) > 0 {
+ tokens := []int{}
+ for _, t := range opts.EmbeddingTokens {
+ tokens = append(tokens, int(t))
+ }
+ return llm.llama.TokenEmbeddings(tokens, predictOptions...)
+ }
+
+ return llm.llama.Embeddings(opts.Embeddings, predictOptions...)
+}
+
+func (llm *LLM) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
+ predictOptions := buildPredictOptions(opts)
+ l, tokens, err := llm.llama.TokenizeString(opts.Prompt, predictOptions...)
+ if err != nil {
+ return pb.TokenizationResponse{}, err
+ }
+ return pb.TokenizationResponse{
+ Length: l,
+ Tokens: tokens,
+ }, nil
+}
diff --git a/backend/go/llm/llama/main.go b/backend/go/llm/llama/main.go
new file mode 100644
index 0000000000000000000000000000000000000000..83dc35ad8b55b5270e432119bd1f867b7be07b33
--- /dev/null
+++ b/backend/go/llm/llama/main.go
@@ -0,0 +1,23 @@
+package main
+
+// GRPC Falcon server
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+
+import (
+ "flag"
+
+ grpc "github.com/mudler/LocalAI/pkg/grpc"
+)
+
+var (
+ addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+func main() {
+ flag.Parse()
+
+ if err := grpc.StartServer(*addr, &LLM{}); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/go/local-store/Makefile b/backend/go/local-store/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..6cde84b00d200f410272a989fe293e4c1f34c658
--- /dev/null
+++ b/backend/go/local-store/Makefile
@@ -0,0 +1,9 @@
+GOCMD=go
+
+local-store:
+ CGO_ENABLED=0 $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o local-store ./
+
+package:
+ bash package.sh
+
+build: local-store package
\ No newline at end of file
diff --git a/backend/go/local-store/debug.go b/backend/go/local-store/debug.go
new file mode 100644
index 0000000000000000000000000000000000000000..0654d295271b319166d5df50a2e0afc621df74e8
--- /dev/null
+++ b/backend/go/local-store/debug.go
@@ -0,0 +1,14 @@
+//go:build debug
+// +build debug
+
+package main
+
+import (
+ "github.com/mudler/xlog"
+)
+
+func assert(cond bool, msg string) {
+ if !cond {
+ xlog.Fatal().Stack().Msg(msg)
+ }
+}
diff --git a/backend/go/local-store/main.go b/backend/go/local-store/main.go
new file mode 100644
index 0000000000000000000000000000000000000000..f06dfa6f511356aaed39cfcf5754cba70b30262b
--- /dev/null
+++ b/backend/go/local-store/main.go
@@ -0,0 +1,25 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each store
+
+import (
+ "flag"
+ "os"
+
+ grpc "github.com/mudler/LocalAI/pkg/grpc"
+ "github.com/mudler/xlog"
+)
+
+var (
+ addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+func main() {
+ xlog.SetLogger(xlog.NewLogger(xlog.LogLevel(os.Getenv("LOCALAI_LOG_LEVEL")), os.Getenv("LOCALAI_LOG_FORMAT")))
+
+ flag.Parse()
+
+ if err := grpc.StartServer(*addr, NewStore()); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/go/local-store/package.sh b/backend/go/local-store/package.sh
new file mode 100644
index 0000000000000000000000000000000000000000..af94e0ee7f060306d2f7ac8cc0f931f0bfabb6fb
--- /dev/null
+++ b/backend/go/local-store/package.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+# Script to copy the appropriate libraries based on architecture
+# This script is used in the final stage of the Dockerfile
+
+set -e
+
+CURDIR=$(dirname "$(realpath $0)")
+
+mkdir -p $CURDIR/package
+cp -avrf $CURDIR/local-store $CURDIR/package/
+cp -rfv $CURDIR/run.sh $CURDIR/package/
\ No newline at end of file
diff --git a/backend/go/local-store/production.go b/backend/go/local-store/production.go
new file mode 100644
index 0000000000000000000000000000000000000000..418b6397283aa68966fa474b8abcf85c37f3a2b2
--- /dev/null
+++ b/backend/go/local-store/production.go
@@ -0,0 +1,7 @@
+//go:build !debug
+// +build !debug
+
+package main
+
+func assert(cond bool, msg string) {
+}
diff --git a/backend/go/local-store/run.sh b/backend/go/local-store/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..479f3b486f079e89ecf3bd3bdfe6a21f48cd22e3
--- /dev/null
+++ b/backend/go/local-store/run.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+set -ex
+
+CURDIR=$(dirname "$(realpath $0)")
+
+exec $CURDIR/local-store "$@"
\ No newline at end of file
diff --git a/backend/go/local-store/store.go b/backend/go/local-store/store.go
new file mode 100644
index 0000000000000000000000000000000000000000..2082684bcb37843480a60d9a5e34762a605199bc
--- /dev/null
+++ b/backend/go/local-store/store.go
@@ -0,0 +1,515 @@
+package main
+
+// This is a wrapper to statisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+ "container/heap"
+ "errors"
+ "fmt"
+ "math"
+ "slices"
+
+ "github.com/mudler/LocalAI/pkg/grpc/base"
+ pb "github.com/mudler/LocalAI/pkg/grpc/proto"
+
+ "github.com/mudler/xlog"
+)
+
+type Store struct {
+ base.SingleThread
+
+ // The sorted keys
+ keys [][]float32
+ // The sorted values
+ values [][]byte
+
+ // If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions
+ // TODO: Should we normalize incoming keys if they are not instead?
+ keysAreNormalized bool
+ // The first key decides the length of the keys
+ keyLen int
+}
+
+// TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because
+// that's theoretically best for memory layout and cache locality, but this isn't optimized yet.
+type Pair struct {
+ Key []float32
+ Value []byte
+}
+
+func NewStore() *Store {
+ return &Store{
+ keys: make([][]float32, 0),
+ values: make([][]byte, 0),
+ keysAreNormalized: true,
+ keyLen: -1,
+ }
+}
+
+func compareSlices(k1, k2 []float32) int {
+ assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
+
+ return slices.Compare(k1, k2)
+}
+
+func hasKey(unsortedSlice [][]float32, target []float32) bool {
+ return slices.ContainsFunc(unsortedSlice, func(k []float32) bool {
+ return compareSlices(k, target) == 0
+ })
+}
+
+func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) {
+ return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int {
+ return compareSlices(k, t)
+ })
+}
+
+func isSortedPairs(kvs []Pair) bool {
+ for i := 1; i < len(kvs); i++ {
+ if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 {
+ return false
+ }
+ }
+
+ return true
+}
+
+func isSortedKeys(keys [][]float32) bool {
+ for i := 1; i < len(keys); i++ {
+ if compareSlices(keys[i-1], keys[i]) > 0 {
+ return false
+ }
+ }
+
+ return true
+}
+
+func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 {
+ ks := make([][]float32, len(keys))
+
+ for i, k := range keys {
+ ks[i] = k.Floats
+ }
+
+ slices.SortFunc(ks, compareSlices)
+
+ assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys)))
+ assert(isSortedKeys(ks), "keys are not sorted")
+
+ return ks
+}
+
+func (s *Store) Load(opts *pb.ModelOptions) error {
+ if opts.Model != "" {
+ return errors.New("not implemented")
+ }
+ return nil
+}
+
+// Sort the incoming kvs and merge them with the existing sorted kvs
+func (s *Store) StoresSet(opts *pb.StoresSetOptions) error {
+ if len(opts.Keys) == 0 {
+ return fmt.Errorf("no keys to add")
+ }
+
+ if len(opts.Keys) != len(opts.Values) {
+ return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values))
+ }
+
+ if s.keyLen == -1 {
+ s.keyLen = len(opts.Keys[0].Floats)
+ } else {
+ if len(opts.Keys[0].Floats) != s.keyLen {
+ return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
+ }
+ }
+
+ kvs := make([]Pair, len(opts.Keys))
+
+ for i, k := range opts.Keys {
+ if s.keysAreNormalized && !isNormalized(k.Floats) {
+ s.keysAreNormalized = false
+ var sample []float32
+ if len(s.keys) > 5 {
+ sample = k.Floats[:5]
+ } else {
+ sample = k.Floats
+ }
+ xlog.Debug("Key is not normalized", "sample", sample)
+ }
+
+ kvs[i] = Pair{
+ Key: k.Floats,
+ Value: opts.Values[i].Bytes,
+ }
+ }
+
+ slices.SortFunc(kvs, func(a, b Pair) int {
+ return compareSlices(a.Key, b.Key)
+ })
+
+ assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys)))
+ assert(isSortedPairs(kvs), "keys are not sorted")
+
+ l := len(kvs) + len(s.keys)
+ merge_ks := make([][]float32, 0, l)
+ merge_vs := make([][]byte, 0, l)
+
+ i, j := 0, 0
+ for {
+ if i+j >= l {
+ break
+ }
+
+ if i >= len(kvs) {
+ merge_ks = append(merge_ks, s.keys[j])
+ merge_vs = append(merge_vs, s.values[j])
+ j++
+ continue
+ }
+
+ if j >= len(s.keys) {
+ merge_ks = append(merge_ks, kvs[i].Key)
+ merge_vs = append(merge_vs, kvs[i].Value)
+ i++
+ continue
+ }
+
+ c := compareSlices(kvs[i].Key, s.keys[j])
+ if c < 0 {
+ merge_ks = append(merge_ks, kvs[i].Key)
+ merge_vs = append(merge_vs, kvs[i].Value)
+ i++
+ } else if c > 0 {
+ merge_ks = append(merge_ks, s.keys[j])
+ merge_vs = append(merge_vs, s.values[j])
+ j++
+ } else {
+ merge_ks = append(merge_ks, kvs[i].Key)
+ merge_vs = append(merge_vs, kvs[i].Value)
+ i++
+ j++
+ }
+ }
+
+ assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l))
+ assert(isSortedKeys(merge_ks), "merge keys are not sorted")
+
+ s.keys = merge_ks
+ s.values = merge_vs
+
+ return nil
+}
+
+func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error {
+ if len(opts.Keys) == 0 {
+ return fmt.Errorf("no keys to delete")
+ }
+
+ if len(opts.Keys) == 0 {
+ return fmt.Errorf("no keys to add")
+ }
+
+ if s.keyLen == -1 {
+ s.keyLen = len(opts.Keys[0].Floats)
+ } else {
+ if len(opts.Keys[0].Floats) != s.keyLen {
+ return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
+ }
+ }
+
+ ks := sortIntoKeySlicese(opts.Keys)
+
+ l := len(s.keys) - len(ks)
+ merge_ks := make([][]float32, 0, l)
+ merge_vs := make([][]byte, 0, l)
+
+ tail_ks := s.keys
+ tail_vs := s.values
+ for _, k := range ks {
+ j, found := findInSortedSlice(tail_ks, k)
+
+ if found {
+ merge_ks = append(merge_ks, tail_ks[:j]...)
+ merge_vs = append(merge_vs, tail_vs[:j]...)
+ tail_ks = tail_ks[j+1:]
+ tail_vs = tail_vs[j+1:]
+ } else {
+ assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k))
+ }
+
+ xlog.Debug("Delete", "found", found, "tailLen", len(tail_ks), "j", j, "mergeKeysLen", len(merge_ks), "mergeValuesLen", len(merge_vs))
+ }
+
+ merge_ks = append(merge_ks, tail_ks...)
+ merge_vs = append(merge_vs, tail_vs...)
+
+ assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys)))
+
+ s.keys = merge_ks
+ s.values = merge_vs
+
+ assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l))
+ assert(isSortedKeys(s.keys), "keys are not sorted")
+ assert(func() bool {
+ for _, k := range ks {
+ if _, found := findInSortedSlice(s.keys, k); found {
+ return false
+ }
+ }
+ return true
+ }(), "Keys to delete still present")
+
+ if len(s.keys) != l {
+ xlog.Debug("Delete: Some keys not found", "keysLen", len(s.keys), "expectedLen", l)
+ }
+
+ return nil
+}
+
+func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) {
+ pbKeys := make([]*pb.StoresKey, 0, len(opts.Keys))
+ pbValues := make([]*pb.StoresValue, 0, len(opts.Keys))
+ ks := sortIntoKeySlicese(opts.Keys)
+
+ if len(s.keys) == 0 {
+ xlog.Debug("Get: No keys in store")
+ }
+
+ if s.keyLen == -1 {
+ s.keyLen = len(opts.Keys[0].Floats)
+ } else {
+ if len(opts.Keys[0].Floats) != s.keyLen {
+ return pb.StoresGetResult{}, fmt.Errorf("Try to get a key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
+ }
+ }
+
+ tail_k := s.keys
+ tail_v := s.values
+ for i, k := range ks {
+ j, found := findInSortedSlice(tail_k, k)
+
+ if found {
+ pbKeys = append(pbKeys, &pb.StoresKey{
+ Floats: k,
+ })
+ pbValues = append(pbValues, &pb.StoresValue{
+ Bytes: tail_v[j],
+ })
+
+ tail_k = tail_k[j+1:]
+ tail_v = tail_v[j+1:]
+ } else {
+ assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: i=%d, %v", i, k))
+ }
+ }
+
+ if len(pbKeys) != len(opts.Keys) {
+ xlog.Debug("Get: Some keys not found", "pbKeysLen", len(pbKeys), "optsKeysLen", len(opts.Keys), "storeKeysLen", len(s.keys))
+ }
+
+ return pb.StoresGetResult{
+ Keys: pbKeys,
+ Values: pbValues,
+ }, nil
+}
+
+func isNormalized(k []float32) bool {
+ var sum float64
+
+ for _, v := range k {
+ v64 := float64(v)
+ sum += v64 * v64
+ }
+
+ s := math.Sqrt(sum)
+
+ return s >= 0.99 && s <= 1.01
+}
+
+// TODO: This we could replace with handwritten SIMD code
+func normalizedCosineSimilarity(k1, k2 []float32) float32 {
+ assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
+
+ var dot float32
+ for i := 0; i < len(k1); i++ {
+ dot += k1[i] * k2[i]
+ }
+
+ assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("dot = %f", dot))
+
+ // 2.0 * (1.0 - dot) would be the Euclidean distance
+ return dot
+}
+
+type PriorityItem struct {
+ Similarity float32
+ Key []float32
+ Value []byte
+}
+
+type PriorityQueue []*PriorityItem
+
+func (pq PriorityQueue) Len() int { return len(pq) }
+
+func (pq PriorityQueue) Less(i, j int) bool {
+ // Inverted because the most similar should be at the top
+ return pq[i].Similarity < pq[j].Similarity
+}
+
+func (pq PriorityQueue) Swap(i, j int) {
+ pq[i], pq[j] = pq[j], pq[i]
+}
+
+func (pq *PriorityQueue) Push(x any) {
+ item := x.(*PriorityItem)
+ *pq = append(*pq, item)
+}
+
+func (pq *PriorityQueue) Pop() any {
+ old := *pq
+ n := len(old)
+ item := old[n-1]
+ *pq = old[0 : n-1]
+ return item
+}
+
+func (s *Store) StoresFindNormalized(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
+ tk := opts.Key.Floats
+ top_ks := make(PriorityQueue, 0, int(opts.TopK))
+ heap.Init(&top_ks)
+
+ for i, k := range s.keys {
+ sim := normalizedCosineSimilarity(tk, k)
+ heap.Push(&top_ks, &PriorityItem{
+ Similarity: sim,
+ Key: k,
+ Value: s.values[i],
+ })
+
+ if top_ks.Len() > int(opts.TopK) {
+ heap.Pop(&top_ks)
+ }
+ }
+
+ similarities := make([]float32, top_ks.Len())
+ pbKeys := make([]*pb.StoresKey, top_ks.Len())
+ pbValues := make([]*pb.StoresValue, top_ks.Len())
+
+ for i := top_ks.Len() - 1; i >= 0; i-- {
+ item := heap.Pop(&top_ks).(*PriorityItem)
+
+ similarities[i] = item.Similarity
+ pbKeys[i] = &pb.StoresKey{
+ Floats: item.Key,
+ }
+ pbValues[i] = &pb.StoresValue{
+ Bytes: item.Value,
+ }
+ }
+
+ return pb.StoresFindResult{
+ Keys: pbKeys,
+ Values: pbValues,
+ Similarities: similarities,
+ }, nil
+}
+
+func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
+ assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
+
+ var dot, mag2 float64
+ for i := 0; i < len(k1); i++ {
+ dot += float64(k1[i] * k2[i])
+ mag2 += float64(k2[i] * k2[i])
+ }
+
+ sim := float32(dot / (mag1 * math.Sqrt(mag2)))
+
+ assert(sim >= -1.01 && sim <= 1.01, fmt.Sprintf("sim = %f", sim))
+
+ return sim
+}
+
+func (s *Store) StoresFindFallback(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
+ tk := opts.Key.Floats
+ top_ks := make(PriorityQueue, 0, int(opts.TopK))
+ heap.Init(&top_ks)
+
+ var mag1 float64
+ for _, v := range tk {
+ mag1 += float64(v * v)
+ }
+ mag1 = math.Sqrt(mag1)
+
+ for i, k := range s.keys {
+ dist := cosineSimilarity(tk, k, mag1)
+ heap.Push(&top_ks, &PriorityItem{
+ Similarity: dist,
+ Key: k,
+ Value: s.values[i],
+ })
+
+ if top_ks.Len() > int(opts.TopK) {
+ heap.Pop(&top_ks)
+ }
+ }
+
+ similarities := make([]float32, top_ks.Len())
+ pbKeys := make([]*pb.StoresKey, top_ks.Len())
+ pbValues := make([]*pb.StoresValue, top_ks.Len())
+
+ for i := top_ks.Len() - 1; i >= 0; i-- {
+ item := heap.Pop(&top_ks).(*PriorityItem)
+
+ similarities[i] = item.Similarity
+ pbKeys[i] = &pb.StoresKey{
+ Floats: item.Key,
+ }
+ pbValues[i] = &pb.StoresValue{
+ Bytes: item.Value,
+ }
+ }
+
+ return pb.StoresFindResult{
+ Keys: pbKeys,
+ Values: pbValues,
+ Similarities: similarities,
+ }, nil
+}
+
+func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
+ tk := opts.Key.Floats
+
+ if len(tk) != s.keyLen {
+ return pb.StoresFindResult{}, fmt.Errorf("Try to find key with length %d when existing length is %d", len(tk), s.keyLen)
+ }
+
+ if opts.TopK < 1 {
+ return pb.StoresFindResult{}, fmt.Errorf("opts.TopK = %d, must be >= 1", opts.TopK)
+ }
+
+ if s.keyLen == -1 {
+ s.keyLen = len(opts.Key.Floats)
+ } else {
+ if len(opts.Key.Floats) != s.keyLen {
+ return pb.StoresFindResult{}, fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Key.Floats), s.keyLen)
+ }
+ }
+
+ if s.keysAreNormalized && isNormalized(tk) {
+ return s.StoresFindNormalized(opts)
+ } else {
+ if s.keysAreNormalized {
+ var sample []float32
+ if len(s.keys) > 5 {
+ sample = tk[:5]
+ } else {
+ sample = tk
+ }
+ xlog.Debug("Trying to compare non-normalized key with normalized keys", "sample", sample)
+ }
+
+ return s.StoresFindFallback(opts)
+ }
+}
diff --git a/backend/go/piper/Makefile b/backend/go/piper/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..020028a5dc7ceab21150da3a62e93a2147c10395
--- /dev/null
+++ b/backend/go/piper/Makefile
@@ -0,0 +1,37 @@
+
+# go-piper version
+PIPER_REPO?=https://github.com/mudler/go-piper
+PIPER_VERSION?=e10ca041a885d4a8f3871d52924b47792d5e5aa0
+
+CURRENT_DIR=$(abspath ./)
+GOCMD=go
+
+PIPER_CGO_CXXFLAGS+=-I$(CURRENT_DIR)/sources/go-piper/piper/src/cpp -I$(CURRENT_DIR)/sources/go-piper/piper/build/fi/include -I$(CURRENT_DIR)/sources/go-piper/piper/build/pi/include -I$(CURRENT_DIR)/sources/go-piper/piper/build/si/include
+PIPER_CGO_LDFLAGS+=-L$(CURRENT_DIR)/sources/go-piper/piper/build/fi/lib -L$(CURRENT_DIR)/sources/go-piper/piper/build/pi/lib -L$(CURRENT_DIR)/sources/go-piper/piper/build/si/lib -lfmt -lspdlog -lucd
+
+## go-piper
+sources/go-piper:
+ mkdir -p sources/go-piper
+ cd sources/go-piper && \
+ git init && \
+ git remote add origin $(PIPER_REPO) && \
+ git fetch origin && \
+ git checkout $(PIPER_VERSION) && \
+ git submodule update --init --recursive --depth 1 --single-branch
+
+sources/go-piper/libpiper_binding.a: sources/go-piper
+ $(MAKE) -C sources/go-piper libpiper_binding.a example/main piper.o
+
+espeak-ng-data: sources/go-piper sources/go-piper/libpiper_binding.a
+ mkdir -p espeak-ng-data
+ @cp -rf sources/go-piper/piper-phonemize/pi/share/espeak-ng-data/. espeak-ng-data
+
+piper: sources/go-piper sources/go-piper/libpiper_binding.a espeak-ng-data
+ $(GOCMD) mod edit -replace github.com/mudler/go-piper=$(CURRENT_DIR)/sources/go-piper
+ CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(CURRENT_DIR)/sources/go-piper \
+ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o piper ./
+
+package:
+ bash package.sh
+
+build: piper package
\ No newline at end of file
diff --git a/backend/go/piper/main.go b/backend/go/piper/main.go
new file mode 100644
index 0000000000000000000000000000000000000000..e02cd91964e4b746b03beaa2a1b4c57d4f305998
--- /dev/null
+++ b/backend/go/piper/main.go
@@ -0,0 +1,21 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+
+import (
+ "flag"
+
+ grpc "github.com/mudler/LocalAI/pkg/grpc"
+)
+
+var (
+ addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+func main() {
+ flag.Parse()
+
+ if err := grpc.StartServer(*addr, &Piper{}); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/go/piper/package.sh b/backend/go/piper/package.sh
new file mode 100644
index 0000000000000000000000000000000000000000..646efc3626ee43e11149e6d9069482d8dce902df
--- /dev/null
+++ b/backend/go/piper/package.sh
@@ -0,0 +1,54 @@
+#!/bin/bash
+
+# Script to copy the appropriate libraries based on architecture
+# This script is used in the final stage of the Dockerfile
+
+set -e
+
+CURDIR=$(dirname "$(realpath $0)")
+
+# Create lib directory
+mkdir -p $CURDIR/package/lib
+
+cp -avrf $CURDIR/piper $CURDIR/package/
+cp -avrf $CURDIR/espeak-ng-data $CURDIR/package/
+cp -rfv $CURDIR/run.sh $CURDIR/package/
+cp -rfLv $CURDIR/sources/go-piper/piper-phonemize/pi/lib/* $CURDIR/package/lib/
+
+# Detect architecture and copy appropriate libraries
+if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
+ # x86_64 architecture
+ echo "Detected x86_64 architecture, copying x86_64 libraries..."
+ cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
+ cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
+ cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
+elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
+ # ARM64 architecture
+ echo "Detected ARM64 architecture, copying ARM64 libraries..."
+ cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
+ cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
+ cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
+else
+ echo "Error: Could not detect architecture"
+ exit 1
+fi
+
+echo "Packaging completed successfully"
+ls -liah $CURDIR/package/
+ls -liah $CURDIR/package/lib/
\ No newline at end of file
diff --git a/backend/go/piper/piper.go b/backend/go/piper/piper.go
new file mode 100644
index 0000000000000000000000000000000000000000..2ec985c9f46b2498bca336ee32df0df2679ef4a9
--- /dev/null
+++ b/backend/go/piper/piper.go
@@ -0,0 +1,49 @@
+package main
+
+// This is a wrapper to statisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+
+ "github.com/mudler/LocalAI/pkg/grpc/base"
+ pb "github.com/mudler/LocalAI/pkg/grpc/proto"
+ piper "github.com/mudler/go-piper"
+)
+
+type Piper struct {
+ base.SingleThread
+ piper *PiperB
+}
+
+func (sd *Piper) Load(opts *pb.ModelOptions) error {
+ if filepath.Ext(opts.ModelFile) != ".onnx" {
+ return fmt.Errorf("unsupported model type %s (should end with .onnx)", opts.ModelFile)
+ }
+ var err error
+ // Note: the Model here is a path to a directory containing the model files
+ sd.piper, err = New(os.Getenv("ESPEAK_NG_DATA"))
+ return err
+}
+
+func (sd *Piper) TTS(opts *pb.TTSRequest) error {
+ return sd.piper.TTS(opts.Text, opts.Model, opts.Dst)
+}
+
+type PiperB struct {
+ assetDir string
+}
+
+func New(assetDir string) (*PiperB, error) {
+ if _, err := os.Stat(assetDir); err != nil {
+ return nil, err
+ }
+ return &PiperB{
+ assetDir: assetDir,
+ }, nil
+}
+
+func (s *PiperB) TTS(text, model, dst string) error {
+ return piper.TextToWav(text, model, s.assetDir, "", dst)
+}
diff --git a/backend/go/piper/run.sh b/backend/go/piper/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fe120ea882b11f0d04eabf674f8dd4d8d35e850d
--- /dev/null
+++ b/backend/go/piper/run.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+set -ex
+
+CURDIR=$(dirname "$(realpath $0)")
+
+export ESPEAK_NG_DATA=$CURDIR/espeak-ng-data
+export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
+
+# If there is a lib/ld.so, use it
+if [ -f $CURDIR/lib/ld.so ]; then
+ echo "Using lib/ld.so"
+ exec $CURDIR/lib/ld.so $CURDIR/piper "$@"
+fi
+
+exec $CURDIR/piper "$@"
\ No newline at end of file
diff --git a/backend/go/silero-vad/Makefile b/backend/go/silero-vad/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..93fd6b4c9ea35a954a69989defce232c1d579c45
--- /dev/null
+++ b/backend/go/silero-vad/Makefile
@@ -0,0 +1,47 @@
+
+CURRENT_DIR=$(abspath ./)
+GOCMD=go
+
+ONNX_VERSION?=1.20.0
+ONNX_ARCH?=x64
+ONNX_OS?=linux
+
+# Detect if we are running on arm64
+ifneq (,$(findstring aarch64,$(shell uname -m)))
+ ONNX_ARCH=aarch64
+endif
+
+ifeq ($(OS),Darwin)
+ ONNX_OS=osx
+ ifneq (,$(findstring aarch64,$(shell uname -m)))
+ ONNX_ARCH=arm64
+ else ifneq (,$(findstring arm64,$(shell uname -m)))
+ ONNX_ARCH=arm64
+ else
+ ONNX_ARCH=x86_64
+ endif
+endif
+
+sources/onnxruntime:
+ mkdir -p sources/onnxruntime
+ curl -L https://github.com/microsoft/onnxruntime/releases/download/v$(ONNX_VERSION)/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz -o sources/onnxruntime/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz
+ cd sources/onnxruntime && tar -xvf onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz && rm onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz
+ cd sources/onnxruntime && mv onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION)/* ./
+
+backend-assets/lib/libonnxruntime.so.1: sources/onnxruntime
+ mkdir -p backend-assets/lib
+ cp -rfLv sources/onnxruntime/lib/* backend-assets/lib/
+ifeq ($(OS),Darwin)
+ mv backend-assets/lib/libonnxruntime.$(ONNX_VERSION).dylib backend-assets/lib/libonnxruntime.dylib
+else
+ mv backend-assets/lib/libonnxruntime.so.$(ONNX_VERSION) backend-assets/lib/libonnxruntime.so.1
+endif
+
+silero-vad: backend-assets/lib/libonnxruntime.so.1
+ CGO_LDFLAGS="$(CGO_LDFLAGS)" CPATH="$(CPATH):$(CURRENT_DIR)/sources/onnxruntime/include/" LIBRARY_PATH=$(CURRENT_DIR)/backend-assets/lib \
+ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o silero-vad ./
+
+package:
+ bash package.sh
+
+build: silero-vad package
\ No newline at end of file
diff --git a/backend/go/silero-vad/main.go b/backend/go/silero-vad/main.go
new file mode 100644
index 0000000000000000000000000000000000000000..28f51e49298fc27462e72c0bcd82c82f7dff9c7b
--- /dev/null
+++ b/backend/go/silero-vad/main.go
@@ -0,0 +1,21 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+
+import (
+ "flag"
+
+ grpc "github.com/mudler/LocalAI/pkg/grpc"
+)
+
+var (
+ addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+func main() {
+ flag.Parse()
+
+ if err := grpc.StartServer(*addr, &VAD{}); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/go/silero-vad/package.sh b/backend/go/silero-vad/package.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1c524000c6b83a81bbc55d72433ea1c161befd5b
--- /dev/null
+++ b/backend/go/silero-vad/package.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+
+# Script to copy the appropriate libraries based on architecture
+# This script is used in the final stage of the Dockerfile
+
+set -e
+
+CURDIR=$(dirname "$(realpath $0)")
+
+# Create lib directory
+mkdir -p $CURDIR/package/lib
+
+cp -avrf $CURDIR/silero-vad $CURDIR/package/
+cp -avrf $CURDIR/run.sh $CURDIR/package/
+cp -rfLv $CURDIR/backend-assets/lib/* $CURDIR/package/lib/
+
+# Detect architecture and copy appropriate libraries
+if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
+ # x86_64 architecture
+ echo "Detected x86_64 architecture, copying x86_64 libraries..."
+ cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
+ cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
+ cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
+elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
+ # ARM64 architecture
+ echo "Detected ARM64 architecture, copying ARM64 libraries..."
+ cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
+ cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
+ cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
+else
+ echo "Error: Could not detect architecture"
+ exit 1
+fi
+
+echo "Packaging completed successfully"
+ls -liah $CURDIR/package/
+ls -liah $CURDIR/package/lib/
\ No newline at end of file
diff --git a/backend/go/silero-vad/run.sh b/backend/go/silero-vad/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..72658908aa48ff27bbc3e2e0c5ef61e682728733
--- /dev/null
+++ b/backend/go/silero-vad/run.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+set -ex
+
+CURDIR=$(dirname "$(realpath $0)")
+
+export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
+
+# If there is a lib/ld.so, use it
+if [ -f $CURDIR/lib/ld.so ]; then
+ echo "Using lib/ld.so"
+ exec $CURDIR/lib/ld.so $CURDIR/silero-vad "$@"
+fi
+
+exec $CURDIR/silero-vad "$@"
\ No newline at end of file
diff --git a/backend/go/silero-vad/vad.go b/backend/go/silero-vad/vad.go
new file mode 100644
index 0000000000000000000000000000000000000000..f3e9f7be8639a1081fc45b9fe2c82fc9ee987c43
--- /dev/null
+++ b/backend/go/silero-vad/vad.go
@@ -0,0 +1,58 @@
+package main
+
+// This is a wrapper to statisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+ "fmt"
+
+ "github.com/mudler/LocalAI/pkg/grpc/base"
+ pb "github.com/mudler/LocalAI/pkg/grpc/proto"
+ "github.com/streamer45/silero-vad-go/speech"
+)
+
+type VAD struct {
+ base.SingleThread
+ detector *speech.Detector
+}
+
+func (vad *VAD) Load(opts *pb.ModelOptions) error {
+ v, err := speech.NewDetector(speech.DetectorConfig{
+ ModelPath: opts.ModelFile,
+ SampleRate: 16000,
+ //WindowSize: 1024,
+ Threshold: 0.5,
+ MinSilenceDurationMs: 100,
+ SpeechPadMs: 30,
+ })
+ if err != nil {
+ return fmt.Errorf("create silero detector: %w", err)
+ }
+
+ vad.detector = v
+ return err
+}
+
+func (vad *VAD) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
+ audio := req.Audio
+
+ if err := vad.detector.Reset(); err != nil {
+ return pb.VADResponse{}, fmt.Errorf("reset: %w", err)
+ }
+
+ segments, err := vad.detector.Detect(audio)
+ if err != nil {
+ return pb.VADResponse{}, fmt.Errorf("detect: %w", err)
+ }
+
+ vadSegments := []*pb.VADSegment{}
+ for _, s := range segments {
+ vadSegments = append(vadSegments, &pb.VADSegment{
+ Start: float32(s.SpeechStartAt),
+ End: float32(s.SpeechEndAt),
+ })
+ }
+
+ return pb.VADResponse{
+ Segments: vadSegments,
+ }, nil
+}
diff --git a/backend/go/stablediffusion-ggml/.gitignore b/backend/go/stablediffusion-ggml/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..2dfc6b056191807393d79f4c337fed5024297309
--- /dev/null
+++ b/backend/go/stablediffusion-ggml/.gitignore
@@ -0,0 +1,6 @@
+package/
+sources/
+.cache/
+build/
+libgosd.so
+stablediffusion-ggml
diff --git a/backend/go/stablediffusion-ggml/CMakeLists.txt b/backend/go/stablediffusion-ggml/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0d1d003e18eb33af4b5352399e84e88a20ce6dfb
--- /dev/null
+++ b/backend/go/stablediffusion-ggml/CMakeLists.txt
@@ -0,0 +1,20 @@
+cmake_minimum_required(VERSION 3.12)
+project(gosd LANGUAGES C CXX)
+set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+
+add_subdirectory(./sources/stablediffusion-ggml.cpp)
+
+add_library(gosd MODULE gosd.cpp)
+target_link_libraries(gosd PRIVATE stable-diffusion ggml)
+
+if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
+ target_link_libraries(gosd PRIVATE stdc++fs)
+endif()
+
+target_include_directories(gosd PUBLIC
+ stable-diffusion.cpp
+ stable-diffusion.cpp/thirdparty
+)
+
+set_property(TARGET gosd PROPERTY CXX_STANDARD 17)
+set_target_properties(gosd PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
diff --git a/backend/go/stablediffusion-ggml/Makefile b/backend/go/stablediffusion-ggml/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..a18b0c82134c59ba9f5216b6dad25e65fd41e6d3
--- /dev/null
+++ b/backend/go/stablediffusion-ggml/Makefile
@@ -0,0 +1,86 @@
+CMAKE_ARGS?=
+BUILD_TYPE?=
+NATIVE?=false
+
+GOCMD?=go
+GO_TAGS?=
+JOBS?=$(shell nproc --ignore=1)
+
+# stablediffusion.cpp (ggml)
+STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
+STABLEDIFFUSION_GGML_VERSION?=7010bb4dff7bd55b03d35ef9772142c21699eba9
+
+CMAKE_ARGS+=-DGGML_MAX_NAME=128
+
+ifeq ($(NATIVE),false)
+ CMAKE_ARGS+=-DGGML_NATIVE=OFF
+endif
+
+# If build type is cublas, then we set -DGGML_CUDA=ON to CMAKE_ARGS automatically
+ifeq ($(BUILD_TYPE),cublas)
+ CMAKE_ARGS+=-DSD_CUDA=ON -DGGML_CUDA=ON
+# If build type is openblas then we set -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
+# to CMAKE_ARGS automatically
+else ifeq ($(BUILD_TYPE),openblas)
+ CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
+# If build type is clblas (openCL) we set -DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
+else ifeq ($(BUILD_TYPE),clblas)
+ CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
+# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
+else ifeq ($(BUILD_TYPE),hipblas)
+ ROCM_HOME ?= /opt/rocm
+ ROCM_PATH ?= /opt/rocm
+ export CXX=$(ROCM_HOME)/llvm/bin/clang++
+ export CC=$(ROCM_HOME)/llvm/bin/clang
+ AMDGPU_TARGETS?=gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201
+ CMAKE_ARGS+=-DSD_HIPBLAS=ON -DGGML_HIPBLAS=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
+else ifeq ($(BUILD_TYPE),vulkan)
+ CMAKE_ARGS+=-DSD_VULKAN=ON -DGGML_VULKAN=ON
+else ifeq ($(OS),Darwin)
+ ifneq ($(BUILD_TYPE),metal)
+ CMAKE_ARGS+=-DSD_METAL=OFF -DGGML_METAL=OFF
+ else
+ CMAKE_ARGS+=-DSD_METAL=ON -DGGML_METAL=ON
+ CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
+ endif
+endif
+
+ifeq ($(BUILD_TYPE),sycl_f16)
+ CMAKE_ARGS+=-DGGML_SYCL=ON \
+ -DCMAKE_C_COMPILER=icx \
+ -DCMAKE_CXX_COMPILER=icpx \
+ -DSD_SYCL=ON \
+ -DGGML_SYCL_F16=ON
+endif
+
+ifeq ($(BUILD_TYPE),sycl_f32)
+ CMAKE_ARGS+=-DGGML_SYCL=ON \
+ -DCMAKE_C_COMPILER=icx \
+ -DCMAKE_CXX_COMPILER=icpx \
+ -DSD_SYCL=ON
+endif
+
+sources/stablediffusion-ggml.cpp:
+ git clone --recursive $(STABLEDIFFUSION_GGML_REPO) sources/stablediffusion-ggml.cpp && \
+ cd sources/stablediffusion-ggml.cpp && \
+ git checkout $(STABLEDIFFUSION_GGML_VERSION) && \
+ git submodule update --init --recursive --depth 1 --single-branch
+
+libgosd.so: sources/stablediffusion-ggml.cpp CMakeLists.txt gosd.cpp gosd.h
+ mkdir -p build && \
+ cd build && \
+ cmake .. $(CMAKE_ARGS) && \
+ cmake --build . --config Release -j$(JOBS) && \
+ cd .. && \
+ mv build/libgosd.so ./
+
+stablediffusion-ggml: main.go gosd.go libgosd.so
+ CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o stablediffusion-ggml ./
+
+package: stablediffusion-ggml
+ bash package.sh
+
+build: package
+
+clean:
+ rm -rf libgosd.so build stablediffusion-ggml package sources
diff --git a/backend/go/stablediffusion-ggml/gosd.cpp b/backend/go/stablediffusion-ggml/gosd.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2d8429c4ae10b8382ce609df8478443017fba57a
--- /dev/null
+++ b/backend/go/stablediffusion-ggml/gosd.cpp
@@ -0,0 +1,1117 @@
+#include "stable-diffusion.h"
+#include
+#include
+#define GGML_MAX_NAME 128
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "gosd.h"
+
+#define STB_IMAGE_IMPLEMENTATION
+#define STB_IMAGE_STATIC
+#include "stb_image.h"
+
+#define STB_IMAGE_WRITE_IMPLEMENTATION
+#define STB_IMAGE_WRITE_STATIC
+#include "stb_image_write.h"
+
+#define STB_IMAGE_RESIZE_IMPLEMENTATION
+#define STB_IMAGE_RESIZE_STATIC
+#include "stb_image_resize.h"
+#include
+#include
+
+// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
+const char* sample_method_str[] = {
+ "euler",
+ "euler_a",
+ "heun",
+ "dpm2",
+ "dpm++2s_a",
+ "dpm++2m",
+ "dpm++2mv2",
+ "ipndm",
+ "ipndm_v",
+ "lcm",
+ "ddim_trailing",
+ "tcd",
+};
+
+static_assert(std::size(sample_method_str) == SAMPLE_METHOD_COUNT, "sample method mismatch");
+
+// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
+const char* schedulers[] = {
+ "discrete",
+ "karras",
+ "exponential",
+ "ays",
+ "gits",
+ "sgm_uniform",
+ "simple",
+ "smoothstep",
+ "kl_optimal",
+ "lcm",
+};
+
+static_assert(std::size(schedulers) == SCHEDULER_COUNT, "schedulers mismatch");
+
+// New enum string arrays
+const char* rng_type_str[] = {
+ "std_default",
+ "cuda",
+ "cpu",
+};
+static_assert(std::size(rng_type_str) == RNG_TYPE_COUNT, "rng type mismatch");
+
+const char* prediction_str[] = {
+ "epsilon",
+ "v",
+ "edm_v",
+ "flow",
+ "flux_flow",
+ "flux2_flow",
+};
+static_assert(std::size(prediction_str) == PREDICTION_COUNT, "prediction mismatch");
+
+const char* lora_apply_mode_str[] = {
+ "auto",
+ "immediately",
+ "at_runtime",
+};
+static_assert(std::size(lora_apply_mode_str) == LORA_APPLY_MODE_COUNT, "lora apply mode mismatch");
+
+constexpr const char* sd_type_str[] = {
+ "f32", // 0
+ "f16", // 1
+ "q4_0", // 2
+ "q4_1", // 3
+ nullptr, // 4
+ nullptr, // 5
+ "q5_0", // 6
+ "q5_1", // 7
+ "q8_0", // 8
+ "q8_1", // 9
+ "q2_k", // 10
+ "q3_k", // 11
+ "q4_k", // 12
+ "q5_k", // 13
+ "q6_k", // 14
+ "q8_k", // 15
+ "iq2_xxs", // 16
+ "iq2_xs", // 17
+ "iq3_xxs", // 18
+ "iq1_s", // 19
+ "iq4_nl", // 20
+ "iq3_s", // 21
+ "iq2_s", // 22
+ "iq4_xs", // 23
+ "i8", // 24
+ "i16", // 25
+ "i32", // 26
+ "i64", // 27
+ "f64", // 28
+ "iq1_m", // 29
+ "bf16", // 30
+ nullptr, nullptr, nullptr, nullptr, // 31-34
+ "tq1_0", // 35
+ "tq2_0", // 36
+ nullptr, nullptr, // 37-38
+ "mxfp4" // 39
+};
+static_assert(std::size(sd_type_str) == SD_TYPE_COUNT, "sd type mismatch");
+
+sd_ctx_params_t ctx_params;
+sd_ctx_t* sd_c;
+// Moved from the context (load time) to generation time params
+scheduler_t scheduler = SCHEDULER_COUNT;
+sample_method_t sample_method = SAMPLE_METHOD_COUNT;
+
+// Storage for embeddings (needs to persist for the lifetime of ctx_params)
+static std::vector embedding_vec;
+// Storage for embedding strings (needs to persist as long as embedding_vec references them)
+static std::vector embedding_strings;
+
+// Storage for LoRAs (needs to persist for the lifetime of generation params)
+static std::vector lora_vec;
+// Storage for LoRA strings (needs to persist as long as lora_vec references them)
+static std::vector lora_strings;
+// Storage for lora_dir path
+static std::string lora_dir_path;
+
+// Build embeddings vector from directory, similar to upstream CLI
+static void build_embedding_vec(const char* embedding_dir) {
+ embedding_vec.clear();
+ embedding_strings.clear();
+
+ if (!embedding_dir || strlen(embedding_dir) == 0) {
+ return;
+ }
+
+ if (!std::filesystem::exists(embedding_dir) || !std::filesystem::is_directory(embedding_dir)) {
+ fprintf(stderr, "Embedding directory does not exist or is not a directory: %s\n", embedding_dir);
+ return;
+ }
+
+ static const std::vector valid_ext = {".pt", ".safetensors", ".gguf"};
+
+ for (const auto& entry : std::filesystem::directory_iterator(embedding_dir)) {
+ if (!entry.is_regular_file()) {
+ continue;
+ }
+
+ auto path = entry.path();
+ std::string ext = path.extension().string();
+
+ bool valid = false;
+ for (const auto& e : valid_ext) {
+ if (ext == e) {
+ valid = true;
+ break;
+ }
+ }
+ if (!valid) {
+ continue;
+ }
+
+ std::string name = path.stem().string();
+ std::string full_path = path.string();
+
+ // Store strings in persistent storage
+ embedding_strings.push_back(name);
+ embedding_strings.push_back(full_path);
+
+ sd_embedding_t item;
+ item.name = embedding_strings[embedding_strings.size() - 2].c_str();
+ item.path = embedding_strings[embedding_strings.size() - 1].c_str();
+
+ embedding_vec.push_back(item);
+ fprintf(stderr, "Found embedding: %s -> %s\n", item.name, item.path);
+ }
+
+ fprintf(stderr, "Loaded %zu embeddings from %s\n", embedding_vec.size(), embedding_dir);
+}
+
+// Discover LoRA files in directory and build a map of name -> path
+static std::map discover_lora_files(const char* lora_dir) {
+ std::map lora_map;
+
+ if (!lora_dir || strlen(lora_dir) == 0) {
+ fprintf(stderr, "LoRA directory not specified\n");
+ return lora_map;
+ }
+
+ if (!std::filesystem::exists(lora_dir) || !std::filesystem::is_directory(lora_dir)) {
+ fprintf(stderr, "LoRA directory does not exist or is not a directory: %s\n", lora_dir);
+ return lora_map;
+ }
+
+ static const std::vector valid_ext = {".safetensors", ".ckpt", ".pt", ".gguf"};
+
+ fprintf(stderr, "Discovering LoRA files in: %s\n", lora_dir);
+
+ for (const auto& entry : std::filesystem::directory_iterator(lora_dir)) {
+ if (!entry.is_regular_file()) {
+ continue;
+ }
+
+ auto path = entry.path();
+ std::string ext = path.extension().string();
+
+ bool valid = false;
+ for (const auto& e : valid_ext) {
+ if (ext == e) {
+ valid = true;
+ break;
+ }
+ }
+ if (!valid) {
+ continue;
+ }
+
+ std::string name = path.stem().string(); // stem() already removes extension
+ std::string full_path = path.string();
+
+ // Store the name (without extension) -> full path mapping
+ // This allows users to specify just the name in
+ lora_map[name] = full_path;
+
+ fprintf(stderr, "Found LoRA file: %s -> %s\n", name.c_str(), full_path.c_str());
+ }
+
+ fprintf(stderr, "Discovered %zu LoRA files in %s\n", lora_map.size(), lora_dir);
+ return lora_map;
+}
+
+// Helper function to check if a path is absolute (matches upstream)
+static bool is_absolute_path(const std::string& p) {
+#ifdef _WIN32
+ // Windows: C:/path or C:\path
+ return p.size() > 1 && std::isalpha(static_cast(p[0])) && p[1] == ':';
+#else
+ // Unix: /path
+ return !p.empty() && p[0] == '/';
+#endif
+}
+
+// Parse LoRAs from prompt string (e.g., "" or "")
+// Returns a vector of LoRA info and the cleaned prompt with LoRA tags removed
+// Matches upstream implementation more closely
+static std::pair, std::string> parse_loras_from_prompt(const std::string& prompt, const char* lora_dir) {
+ std::vector loras;
+ std::string cleaned_prompt = prompt;
+
+ if (!lora_dir || strlen(lora_dir) == 0) {
+ fprintf(stderr, "LoRA directory not set, cannot parse LoRAs from prompt\n");
+ return {loras, cleaned_prompt};
+ }
+
+ // Discover LoRA files for name-based lookup
+ std::map discovered_lora_map = discover_lora_files(lora_dir);
+
+ // Map to accumulate multipliers for the same LoRA (matches upstream)
+ std::map lora_map;
+ std::map high_noise_lora_map;
+
+ static const std::regex re(R"(]+):([^>]+)>)");
+ static const std::vector valid_ext = {".pt", ".safetensors", ".gguf"};
+ std::smatch m;
+
+ std::string tmp = prompt;
+
+ fprintf(stderr, "Parsing LoRAs from prompt: %s\n", prompt.c_str());
+
+ while (std::regex_search(tmp, m, re)) {
+ std::string raw_path = m[1].str();
+ const std::string raw_mul = m[2].str();
+
+ float mul = 0.f;
+ try {
+ mul = std::stof(raw_mul);
+ } catch (...) {
+ tmp = m.suffix().str();
+ cleaned_prompt = std::regex_replace(cleaned_prompt, re, "", std::regex_constants::format_first_only);
+ fprintf(stderr, "Invalid LoRA multiplier '%s', skipping\n", raw_mul.c_str());
+ continue;
+ }
+
+ bool is_high_noise = false;
+ static const std::string prefix = "|high_noise|";
+ if (raw_path.rfind(prefix, 0) == 0) {
+ raw_path.erase(0, prefix.size());
+ is_high_noise = true;
+ }
+
+ std::filesystem::path final_path;
+ if (is_absolute_path(raw_path)) {
+ final_path = raw_path;
+ } else {
+ // Try name-based lookup first
+ auto it = discovered_lora_map.find(raw_path);
+ if (it != discovered_lora_map.end()) {
+ final_path = it->second;
+ } else {
+ // Try case-insensitive lookup
+ bool found = false;
+ for (const auto& pair : discovered_lora_map) {
+ std::string lower_name = raw_path;
+ std::string lower_key = pair.first;
+ std::transform(lower_name.begin(), lower_name.end(), lower_name.begin(), ::tolower);
+ std::transform(lower_key.begin(), lower_key.end(), lower_key.begin(), ::tolower);
+ if (lower_name == lower_key) {
+ final_path = pair.second;
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ // Try as relative path in lora_dir
+ final_path = std::filesystem::path(lora_dir) / raw_path;
+ }
+ }
+ }
+
+ // Try adding extensions if file doesn't exist
+ if (!std::filesystem::exists(final_path)) {
+ bool found = false;
+ for (const auto& ext : valid_ext) {
+ std::filesystem::path try_path = final_path;
+ try_path += ext;
+ if (std::filesystem::exists(try_path)) {
+ final_path = try_path;
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ fprintf(stderr, "WARNING: LoRA file not found: %s\n", final_path.lexically_normal().string().c_str());
+ tmp = m.suffix().str();
+ cleaned_prompt = std::regex_replace(cleaned_prompt, re, "", std::regex_constants::format_first_only);
+ continue;
+ }
+ }
+
+ // Normalize path (matches upstream)
+ const std::string key = final_path.lexically_normal().string();
+
+ // Accumulate multiplier if same LoRA appears multiple times (matches upstream)
+ if (is_high_noise) {
+ high_noise_lora_map[key] += mul;
+ } else {
+ lora_map[key] += mul;
+ }
+
+ fprintf(stderr, "Parsed LoRA: path='%s', multiplier=%.2f, is_high_noise=%s\n",
+ key.c_str(), mul, is_high_noise ? "true" : "false");
+
+ cleaned_prompt = std::regex_replace(cleaned_prompt, re, "", std::regex_constants::format_first_only);
+ tmp = m.suffix().str();
+ }
+
+ // Build final LoRA vector from accumulated maps (matches upstream)
+ // Store all path strings first to ensure they persist
+ for (const auto& kv : lora_map) {
+ lora_strings.push_back(kv.first);
+ }
+ for (const auto& kv : high_noise_lora_map) {
+ lora_strings.push_back(kv.first);
+ }
+
+ // Now build the LoRA vector with pointers to the stored strings
+ size_t string_idx = 0;
+ for (const auto& kv : lora_map) {
+ sd_lora_t item;
+ item.is_high_noise = false;
+ item.path = lora_strings[string_idx].c_str();
+ item.multiplier = kv.second;
+ loras.push_back(item);
+ string_idx++;
+ }
+
+ for (const auto& kv : high_noise_lora_map) {
+ sd_lora_t item;
+ item.is_high_noise = true;
+ item.path = lora_strings[string_idx].c_str();
+ item.multiplier = kv.second;
+ loras.push_back(item);
+ string_idx++;
+ }
+
+ // Clean up extra spaces
+ std::regex space_regex(R"(\s+)");
+ cleaned_prompt = std::regex_replace(cleaned_prompt, space_regex, " ");
+ // Trim leading/trailing spaces
+ size_t first = cleaned_prompt.find_first_not_of(" \t");
+ if (first != std::string::npos) {
+ cleaned_prompt.erase(0, first);
+ }
+ size_t last = cleaned_prompt.find_last_not_of(" \t");
+ if (last != std::string::npos) {
+ cleaned_prompt.erase(last + 1);
+ }
+
+ fprintf(stderr, "Parsed %zu LoRA(s) from prompt. Cleaned prompt: %s\n", loras.size(), cleaned_prompt.c_str());
+
+ return {loras, cleaned_prompt};
+}
+
+// Copied from the upstream CLI
+static void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
+ //SDParams* params = (SDParams*)data;
+ const char* level_str;
+
+ if (!log /*|| (!params->verbose && level <= SD_LOG_DEBUG)*/) {
+ return;
+ }
+
+ switch (level) {
+ case SD_LOG_DEBUG:
+ level_str = "DEBUG";
+ break;
+ case SD_LOG_INFO:
+ level_str = "INFO";
+ break;
+ case SD_LOG_WARN:
+ level_str = "WARN";
+ break;
+ case SD_LOG_ERROR:
+ level_str = "ERROR";
+ break;
+ default: /* Potential future-proofing */
+ level_str = "?????";
+ break;
+ }
+
+ fprintf(stderr, "[%-5s] ", level_str);
+ fputs(log, stderr);
+ fflush(stderr);
+}
+
+int load_model(const char *model, char *model_path, char* options[], int threads, int diff) {
+ fprintf (stderr, "Loading model: %p=%s\n", model, model);
+
+ sd_set_log_callback(sd_log_cb, NULL);
+
+ const char *stableDiffusionModel = "";
+ if (diff == 1 ) {
+ stableDiffusionModel = strdup(model);
+ model = "";
+ }
+
+ // decode options. Options are in form optname:optvale, or if booleans only optname.
+ const char *clip_l_path = "";
+ const char *clip_g_path = "";
+ const char *t5xxl_path = "";
+ const char *vae_path = "";
+ const char *scheduler_str = "";
+ const char *sampler = "";
+ const char *clip_vision_path = "";
+ const char *llm_path = "";
+ const char *llm_vision_path = "";
+ const char *diffusion_model_path = stableDiffusionModel;
+ const char *high_noise_diffusion_model_path = "";
+ const char *taesd_path = "";
+ const char *control_net_path = "";
+ const char *embedding_dir = "";
+ const char *photo_maker_path = "";
+ const char *tensor_type_rules = "";
+ char *lora_dir = model_path;
+
+ bool vae_decode_only = true;
+ int n_threads = threads;
+ enum sd_type_t wtype = SD_TYPE_COUNT;
+ enum rng_type_t rng_type = CUDA_RNG;
+ enum rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
+ enum prediction_t prediction = PREDICTION_COUNT;
+ enum lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;
+ bool offload_params_to_cpu = false;
+ bool keep_clip_on_cpu = false;
+ bool keep_control_net_on_cpu = false;
+ bool keep_vae_on_cpu = false;
+ bool diffusion_flash_attn = false;
+ bool tae_preview_only = false;
+ bool diffusion_conv_direct = false;
+ bool vae_conv_direct = false;
+ bool force_sdxl_vae_conv_scale = false;
+ bool chroma_use_dit_mask = true;
+ bool chroma_use_t5_mask = false;
+ int chroma_t5_mask_pad = 1;
+ float flow_shift = INFINITY;
+
+ fprintf(stderr, "parsing options: %p\n", options);
+
+ // If options is not NULL, parse options
+ for (int i = 0; options[i] != NULL; i++) {
+ const char *optname = strtok(options[i], ":");
+ const char *optval = strtok(NULL, ":");
+ if (optval == NULL) {
+ optval = "true";
+ }
+
+ if (!strcmp(optname, "clip_l_path")) {
+ clip_l_path = strdup(optval);
+ }
+ if (!strcmp(optname, "clip_g_path")) {
+ clip_g_path = strdup(optval);
+ }
+ if (!strcmp(optname, "t5xxl_path")) {
+ t5xxl_path = strdup(optval);
+ }
+ if (!strcmp(optname, "vae_path")) {
+ vae_path = strdup(optval);
+ }
+ if (!strcmp(optname, "scheduler")) {
+ scheduler_str = optval;
+ }
+ if (!strcmp(optname, "sampler")) {
+ sampler = optval;
+ }
+ if (!strcmp(optname, "lora_dir")) {
+ // Path join with model dir
+ if (model_path && strlen(model_path) > 0) {
+ std::filesystem::path model_path_str(model_path);
+ std::filesystem::path lora_path(optval);
+ std::filesystem::path full_lora_path = model_path_str / lora_path;
+ lora_dir = strdup(full_lora_path.string().c_str());
+ lora_dir_path = full_lora_path.string();
+ fprintf(stderr, "LoRA dir resolved to: %s\n", lora_dir);
+ } else {
+ lora_dir = strdup(optval);
+ lora_dir_path = std::string(optval);
+ fprintf(stderr, "No model path provided, using lora dir as-is: %s\n", lora_dir);
+ }
+ // Discover LoRAs immediately when directory is set
+ if (lora_dir && strlen(lora_dir) > 0) {
+ discover_lora_files(lora_dir);
+ }
+ }
+
+ // New parsing
+ if (!strcmp(optname, "clip_vision_path")) clip_vision_path = strdup(optval);
+ if (!strcmp(optname, "llm_path")) llm_path = strdup(optval);
+ if (!strcmp(optname, "llm_vision_path")) llm_vision_path = strdup(optval);
+ if (!strcmp(optname, "diffusion_model_path")) diffusion_model_path = strdup(optval);
+ if (!strcmp(optname, "high_noise_diffusion_model_path")) high_noise_diffusion_model_path = strdup(optval);
+ if (!strcmp(optname, "taesd_path")) taesd_path = strdup(optval);
+ if (!strcmp(optname, "control_net_path")) control_net_path = strdup(optval);
+ if (!strcmp(optname, "embedding_dir")) {
+ // Path join with model dir
+ if (model_path && strlen(model_path) > 0) {
+ std::filesystem::path model_path_str(model_path);
+ std::filesystem::path embedding_path(optval);
+ std::filesystem::path full_embedding_path = model_path_str / embedding_path;
+ embedding_dir = strdup(full_embedding_path.string().c_str());
+ fprintf(stderr, "Embedding dir resolved to: %s\n", embedding_dir);
+ } else {
+ embedding_dir = strdup(optval);
+ fprintf(stderr, "No model path provided, using embedding dir as-is: %s\n", embedding_dir);
+ }
+ }
+ if (!strcmp(optname, "photo_maker_path")) photo_maker_path = strdup(optval);
+ if (!strcmp(optname, "tensor_type_rules")) tensor_type_rules = strdup(optval);
+
+ if (!strcmp(optname, "vae_decode_only")) vae_decode_only = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+ if (!strcmp(optname, "offload_params_to_cpu")) offload_params_to_cpu = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+ if (!strcmp(optname, "keep_clip_on_cpu")) keep_clip_on_cpu = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+ if (!strcmp(optname, "keep_control_net_on_cpu")) keep_control_net_on_cpu = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+ if (!strcmp(optname, "keep_vae_on_cpu")) keep_vae_on_cpu = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+ if (!strcmp(optname, "diffusion_flash_attn")) diffusion_flash_attn = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+ if (!strcmp(optname, "tae_preview_only")) tae_preview_only = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+ if (!strcmp(optname, "diffusion_conv_direct")) diffusion_conv_direct = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+ if (!strcmp(optname, "vae_conv_direct")) vae_conv_direct = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+ if (!strcmp(optname, "force_sdxl_vae_conv_scale")) force_sdxl_vae_conv_scale = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+ if (!strcmp(optname, "chroma_use_dit_mask")) chroma_use_dit_mask = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+ if (!strcmp(optname, "chroma_use_t5_mask")) chroma_use_t5_mask = (strcmp(optval, "true") == 0 || strcmp(optval, "1") == 0);
+
+ if (!strcmp(optname, "n_threads")) n_threads = atoi(optval);
+ if (!strcmp(optname, "chroma_t5_mask_pad")) chroma_t5_mask_pad = atoi(optval);
+
+ if (!strcmp(optname, "flow_shift")) flow_shift = atof(optval);
+
+ if (!strcmp(optname, "rng_type")) {
+ int found = -1;
+ for (int m = 0; m < RNG_TYPE_COUNT; m++) {
+ if (!strcmp(optval, rng_type_str[m])) {
+ found = m;
+ break;
+ }
+ }
+ if (found != -1) {
+ rng_type = (rng_type_t)found;
+ fprintf(stderr, "Found rng_type: %s\n", optval);
+ } else {
+ fprintf(stderr, "Invalid rng_type: %s, using default\n", optval);
+ }
+ }
+ if (!strcmp(optname, "sampler_rng_type")) {
+ int found = -1;
+ for (int m = 0; m < RNG_TYPE_COUNT; m++) {
+ if (!strcmp(optval, rng_type_str[m])) {
+ found = m;
+ break;
+ }
+ }
+ if (found != -1) {
+ sampler_rng_type = (rng_type_t)found;
+ fprintf(stderr, "Found sampler_rng_type: %s\n", optval);
+ } else {
+ fprintf(stderr, "Invalid sampler_rng_type: %s, using default\n", optval);
+ }
+ }
+ if (!strcmp(optname, "prediction")) {
+ int found = -1;
+ for (int m = 0; m < PREDICTION_COUNT; m++) {
+ if (!strcmp(optval, prediction_str[m])) {
+ found = m;
+ break;
+ }
+ }
+ if (found != -1) {
+ prediction = (prediction_t)found;
+ fprintf(stderr, "Found prediction: %s\n", optval);
+ } else {
+ fprintf(stderr, "Invalid prediction: %s, using default\n", optval);
+ }
+ }
+ if (!strcmp(optname, "lora_apply_mode")) {
+ int found = -1;
+ for (int m = 0; m < LORA_APPLY_MODE_COUNT; m++) {
+ if (!strcmp(optval, lora_apply_mode_str[m])) {
+ found = m;
+ break;
+ }
+ }
+ if (found != -1) {
+ lora_apply_mode = (lora_apply_mode_t)found;
+ fprintf(stderr, "Found lora_apply_mode: %s\n", optval);
+ } else {
+ fprintf(stderr, "Invalid lora_apply_mode: %s, using default\n", optval);
+ }
+ }
+ if (!strcmp(optname, "wtype")) {
+ int found = -1;
+ for (int m = 0; m < SD_TYPE_COUNT; m++) {
+ if (sd_type_str[m] && !strcmp(optval, sd_type_str[m])) {
+ found = m;
+ break;
+ }
+ }
+ if (found != -1) {
+ wtype = (sd_type_t)found;
+ fprintf(stderr, "Found wtype: %s\n", optval);
+ } else {
+ fprintf(stderr, "Invalid wtype: %s, using default\n", optval);
+ }
+ }
+ }
+
+ fprintf(stderr, "parsed options\n");
+
+ // Build embeddings vector from directory if provided
+ build_embedding_vec(embedding_dir);
+
+ fprintf (stderr, "Creating context\n");
+ sd_ctx_params_init(&ctx_params);
+ ctx_params.model_path = model;
+ ctx_params.clip_l_path = clip_l_path;
+ ctx_params.clip_g_path = clip_g_path;
+ ctx_params.clip_vision_path = clip_vision_path;
+ ctx_params.t5xxl_path = t5xxl_path;
+ ctx_params.llm_path = llm_path;
+ ctx_params.llm_vision_path = llm_vision_path;
+ ctx_params.diffusion_model_path = diffusion_model_path;
+ ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
+ ctx_params.vae_path = vae_path;
+ ctx_params.taesd_path = taesd_path;
+ ctx_params.control_net_path = control_net_path;
+ if (lora_dir && strlen(lora_dir) > 0) {
+ lora_dir_path = std::string(lora_dir);
+ fprintf(stderr, "LoRA model directory set to: %s\n", lora_dir);
+ // Discover LoRAs at load time for logging
+ discover_lora_files(lora_dir);
+ } else {
+ fprintf(stderr, "WARNING: LoRA model directory not set. LoRAs in prompts will not be loaded.\n");
+ }
+ // Set embeddings array and count
+ ctx_params.embeddings = embedding_vec.empty() ? NULL : embedding_vec.data();
+ ctx_params.embedding_count = static_cast(embedding_vec.size());
+ ctx_params.photo_maker_path = photo_maker_path;
+ ctx_params.tensor_type_rules = tensor_type_rules;
+ ctx_params.vae_decode_only = vae_decode_only;
+ // XXX: Setting to true causes a segfault on the second run
+ ctx_params.free_params_immediately = false;
+ ctx_params.n_threads = n_threads;
+ ctx_params.rng_type = rng_type;
+ ctx_params.keep_clip_on_cpu = keep_clip_on_cpu;
+ if (wtype != SD_TYPE_COUNT) ctx_params.wtype = wtype;
+ if (sampler_rng_type != RNG_TYPE_COUNT) ctx_params.sampler_rng_type = sampler_rng_type;
+ if (prediction != PREDICTION_COUNT) ctx_params.prediction = prediction;
+ if (lora_apply_mode != LORA_APPLY_MODE_COUNT) ctx_params.lora_apply_mode = lora_apply_mode;
+ ctx_params.offload_params_to_cpu = offload_params_to_cpu;
+ ctx_params.keep_control_net_on_cpu = keep_control_net_on_cpu;
+ ctx_params.keep_vae_on_cpu = keep_vae_on_cpu;
+ ctx_params.diffusion_flash_attn = diffusion_flash_attn;
+ ctx_params.tae_preview_only = tae_preview_only;
+ ctx_params.diffusion_conv_direct = diffusion_conv_direct;
+ ctx_params.vae_conv_direct = vae_conv_direct;
+ ctx_params.force_sdxl_vae_conv_scale = force_sdxl_vae_conv_scale;
+ ctx_params.chroma_use_dit_mask = chroma_use_dit_mask;
+ ctx_params.chroma_use_t5_mask = chroma_use_t5_mask;
+ ctx_params.chroma_t5_mask_pad = chroma_t5_mask_pad;
+ ctx_params.flow_shift = flow_shift;
+ sd_ctx_t* sd_ctx = new_sd_ctx(&ctx_params);
+
+ if (sd_ctx == NULL) {
+ fprintf (stderr, "failed loading model (generic error)\n");
+ // TODO: Clean up allocated memory
+ return 1;
+ }
+ fprintf (stderr, "Created context: OK\n");
+
+ int sample_method_found = -1;
+ for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) {
+ if (!strcmp(sampler, sample_method_str[m])) {
+ sample_method_found = m;
+ fprintf(stderr, "Found sampler: %s\n", sampler);
+ }
+ }
+ if (sample_method_found == -1) {
+ sample_method_found = sd_get_default_sample_method(sd_ctx);
+ fprintf(stderr, "Invalid sample method, using default: %s\n", sample_method_str[sample_method_found]);
+ }
+ sample_method = (sample_method_t)sample_method_found;
+
+ for (int d = 0; d < SCHEDULER_COUNT; d++) {
+ if (!strcmp(scheduler_str, schedulers[d])) {
+ scheduler = (scheduler_t)d;
+ fprintf (stderr, "Found scheduler: %s\n", scheduler_str);
+ }
+ }
+ if (scheduler == SCHEDULER_COUNT) {
+ scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
+ fprintf(stderr, "Invalid scheduler, using default: %s\n", schedulers[scheduler]);
+ }
+
+ sd_c = sd_ctx;
+
+ return 0;
+}
+
+void sd_tiling_params_set_enabled(sd_tiling_params_t *params, bool enabled) {
+ params->enabled = enabled;
+}
+
+void sd_tiling_params_set_tile_sizes(sd_tiling_params_t *params, int tile_size_x, int tile_size_y) {
+ params->tile_size_x = tile_size_x;
+ params->tile_size_y = tile_size_y;
+}
+
+void sd_tiling_params_set_rel_sizes(sd_tiling_params_t *params, float rel_size_x, float rel_size_y) {
+ params->rel_size_x = rel_size_x;
+ params->rel_size_y = rel_size_y;
+}
+
+void sd_tiling_params_set_target_overlap(sd_tiling_params_t *params, float target_overlap) {
+ params->target_overlap = target_overlap;
+}
+
+sd_tiling_params_t* sd_img_gen_params_get_vae_tiling_params(sd_img_gen_params_t *params) {
+ return ¶ms->vae_tiling_params;
+}
+
+sd_img_gen_params_t* sd_img_gen_params_new(void) {
+ sd_img_gen_params_t *params = (sd_img_gen_params_t *)std::malloc(sizeof(sd_img_gen_params_t));
+ sd_img_gen_params_init(params);
+ sd_sample_params_init(¶ms->sample_params);
+ sd_cache_params_init(¶ms->cache);
+ params->control_strength = 0.9f;
+ return params;
+}
+
+// Storage for cleaned prompt strings (needs to persist)
+static std::string cleaned_prompt_storage;
+static std::string cleaned_negative_prompt_storage;
+
+void sd_img_gen_params_set_prompts(sd_img_gen_params_t *params, const char *prompt, const char *negative_prompt) {
+ // Clear previous LoRA data
+ lora_vec.clear();
+ lora_strings.clear();
+
+ // Parse LoRAs from prompt
+ std::string prompt_str = prompt ? prompt : "";
+ std::string negative_prompt_str = negative_prompt ? negative_prompt : "";
+
+ // Get lora_dir from ctx_params if available, otherwise use stored path
+ const char* lora_dir_to_use = lora_dir_path.empty() ? nullptr : lora_dir_path.c_str();
+
+ auto [loras, cleaned_prompt] = parse_loras_from_prompt(prompt_str, lora_dir_to_use);
+ lora_vec = loras;
+ cleaned_prompt_storage = cleaned_prompt;
+
+ // Also check negative prompt for LoRAs (though this is less common)
+ auto [neg_loras, cleaned_negative] = parse_loras_from_prompt(negative_prompt_str, lora_dir_to_use);
+ // Merge negative prompt LoRAs (though typically not used)
+ if (!neg_loras.empty()) {
+ fprintf(stderr, "Note: Found %zu LoRAs in negative prompt (may not be supported)\n", neg_loras.size());
+ }
+ cleaned_negative_prompt_storage = cleaned_negative;
+
+ // Set the cleaned prompts
+ params->prompt = cleaned_prompt_storage.c_str();
+ params->negative_prompt = cleaned_negative_prompt_storage.c_str();
+
+ // Set LoRAs in params
+ params->loras = lora_vec.empty() ? nullptr : lora_vec.data();
+ params->lora_count = static_cast(lora_vec.size());
+
+ fprintf(stderr, "Set prompts with %zu LoRAs. Original prompt: %s\n", lora_vec.size(), prompt ? prompt : "(null)");
+ fprintf(stderr, "Cleaned prompt: %s\n", cleaned_prompt_storage.c_str());
+
+ // Debug: Verify LoRAs are set correctly
+ if (params->loras && params->lora_count > 0) {
+ fprintf(stderr, "DEBUG: LoRAs set in params structure:\n");
+ for (uint32_t i = 0; i < params->lora_count; i++) {
+ fprintf(stderr, " params->loras[%u]: path='%s' (ptr=%p), multiplier=%.2f, is_high_noise=%s\n",
+ i,
+ params->loras[i].path ? params->loras[i].path : "(null)",
+ (void*)params->loras[i].path,
+ params->loras[i].multiplier,
+ params->loras[i].is_high_noise ? "true" : "false");
+ }
+ } else {
+ fprintf(stderr, "DEBUG: No LoRAs set in params structure (loras=%p, lora_count=%u)\n",
+ (void*)params->loras, params->lora_count);
+ }
+}
+
+void sd_img_gen_params_set_dimensions(sd_img_gen_params_t *params, int width, int height) {
+ params->width = width;
+ params->height = height;
+}
+
+void sd_img_gen_params_set_seed(sd_img_gen_params_t *params, int64_t seed) {
+ params->seed = seed;
+}
+
+int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char* ref_images[], int ref_images_count) {
+
+ sd_image_t* results;
+
+ std::vector skip_layers = {7, 8, 9};
+
+ fprintf (stderr, "Generating image\n");
+
+ p->sample_params.guidance.txt_cfg = cfg_scale;
+ p->sample_params.guidance.slg.layers = skip_layers.data();
+ p->sample_params.guidance.slg.layer_count = skip_layers.size();
+ p->sample_params.sample_method = sample_method;
+ p->sample_params.sample_steps = steps;
+ p->sample_params.scheduler = scheduler;
+
+ int width = p->width;
+ int height = p->height;
+
+ // Handle input image for img2img
+ bool has_input_image = (src_image != NULL && strlen(src_image) > 0);
+ bool has_mask_image = (mask_image != NULL && strlen(mask_image) > 0);
+
+ uint8_t* input_image_buffer = NULL;
+ uint8_t* mask_image_buffer = NULL;
+ std::vector default_mask_image_vec;
+
+ if (has_input_image) {
+ fprintf(stderr, "Loading input image: %s\n", src_image);
+
+ int c = 0;
+ int img_width = 0;
+ int img_height = 0;
+ input_image_buffer = stbi_load(src_image, &img_width, &img_height, &c, 3);
+ if (input_image_buffer == NULL) {
+ fprintf(stderr, "Failed to load input image from '%s'\n", src_image);
+ return 1;
+ }
+ if (c < 3) {
+ fprintf(stderr, "Input image must have at least 3 channels, got %d\n", c);
+ free(input_image_buffer);
+ return 1;
+ }
+
+ // Resize input image if dimensions don't match
+ if (img_width != width || img_height != height) {
+ fprintf(stderr, "Resizing input image from %dx%d to %dx%d\n", img_width, img_height, width, height);
+
+ uint8_t* resized_image_buffer = (uint8_t*)malloc(height * width * 3);
+ if (resized_image_buffer == NULL) {
+ fprintf(stderr, "Failed to allocate memory for resized image\n");
+ free(input_image_buffer);
+ return 1;
+ }
+
+ stbir_resize(input_image_buffer, img_width, img_height, 0,
+ resized_image_buffer, width, height, 0, STBIR_TYPE_UINT8,
+ 3, STBIR_ALPHA_CHANNEL_NONE, 0,
+ STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
+ STBIR_FILTER_BOX, STBIR_FILTER_BOX,
+ STBIR_COLORSPACE_SRGB, nullptr);
+
+ free(input_image_buffer);
+ input_image_buffer = resized_image_buffer;
+ }
+
+ p->init_image = {(uint32_t)width, (uint32_t)height, 3, input_image_buffer};
+ p->strength = strength;
+ fprintf(stderr, "Using img2img with strength: %.2f\n", strength);
+ } else {
+ // No input image, use empty image for text-to-image
+ p->init_image = {(uint32_t)width, (uint32_t)height, 3, NULL};
+ p->strength = 0.0f;
+ }
+
+ // Handle mask image for inpainting
+ if (has_mask_image) {
+ fprintf(stderr, "Loading mask image: %s\n", mask_image);
+
+ int c = 0;
+ int mask_width = 0;
+ int mask_height = 0;
+ mask_image_buffer = stbi_load(mask_image, &mask_width, &mask_height, &c, 1);
+ if (mask_image_buffer == NULL) {
+ fprintf(stderr, "Failed to load mask image from '%s'\n", mask_image);
+ if (input_image_buffer) free(input_image_buffer);
+ return 1;
+ }
+
+ // Resize mask if dimensions don't match
+ if (mask_width != width || mask_height != height) {
+ fprintf(stderr, "Resizing mask image from %dx%d to %dx%d\n", mask_width, mask_height, width, height);
+
+ uint8_t* resized_mask_buffer = (uint8_t*)malloc(height * width);
+ if (resized_mask_buffer == NULL) {
+ fprintf(stderr, "Failed to allocate memory for resized mask\n");
+ free(mask_image_buffer);
+ if (input_image_buffer) free(input_image_buffer);
+ return 1;
+ }
+
+ stbir_resize(mask_image_buffer, mask_width, mask_height, 0,
+ resized_mask_buffer, width, height, 0, STBIR_TYPE_UINT8,
+ 1, STBIR_ALPHA_CHANNEL_NONE, 0,
+ STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
+ STBIR_FILTER_BOX, STBIR_FILTER_BOX,
+ STBIR_COLORSPACE_SRGB, nullptr);
+
+ free(mask_image_buffer);
+ mask_image_buffer = resized_mask_buffer;
+ }
+
+ p->mask_image = {(uint32_t)width, (uint32_t)height, 1, mask_image_buffer};
+ fprintf(stderr, "Using inpainting with mask\n");
+ } else {
+ // No mask image, create default full mask
+ default_mask_image_vec.resize(width * height, 255);
+ p->mask_image = {(uint32_t)width, (uint32_t)height, 1, default_mask_image_vec.data()};
+ }
+
+ // Handle reference images
+ std::vector ref_images_vec;
+ std::vector ref_image_buffers;
+
+ if (ref_images_count > 0 && ref_images != NULL) {
+ fprintf(stderr, "Loading %d reference images\n", ref_images_count);
+
+ for (int i = 0; i < ref_images_count; i++) {
+ if (ref_images[i] == NULL || strlen(ref_images[i]) == 0) {
+ continue;
+ }
+
+ fprintf(stderr, "Loading reference image %d: %s\n", i + 1, ref_images[i]);
+
+ int c = 0;
+ int ref_width = 0;
+ int ref_height = 0;
+ uint8_t* ref_image_buffer = stbi_load(ref_images[i], &ref_width, &ref_height, &c, 3);
+ if (ref_image_buffer == NULL) {
+ fprintf(stderr, "Failed to load reference image from '%s'\n", ref_images[i]);
+ continue;
+ }
+ if (c < 3) {
+ fprintf(stderr, "Reference image must have at least 3 channels, got %d\n", c);
+ free(ref_image_buffer);
+ continue;
+ }
+
+ // Resize reference image if dimensions don't match
+ if (ref_width != width || ref_height != height) {
+ fprintf(stderr, "Resizing reference image from %dx%d to %dx%d\n", ref_width, ref_height, width, height);
+
+ uint8_t* resized_ref_buffer = (uint8_t*)malloc(height * width * 3);
+ if (resized_ref_buffer == NULL) {
+ fprintf(stderr, "Failed to allocate memory for resized reference image\n");
+ free(ref_image_buffer);
+ continue;
+ }
+
+ stbir_resize(ref_image_buffer, ref_width, ref_height, 0,
+ resized_ref_buffer, width, height, 0, STBIR_TYPE_UINT8,
+ 3, STBIR_ALPHA_CHANNEL_NONE, 0,
+ STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
+ STBIR_FILTER_BOX, STBIR_FILTER_BOX,
+ STBIR_COLORSPACE_SRGB, nullptr);
+
+ free(ref_image_buffer);
+ ref_image_buffer = resized_ref_buffer;
+ }
+
+ ref_image_buffers.push_back(ref_image_buffer);
+ ref_images_vec.push_back({(uint32_t)width, (uint32_t)height, 3, ref_image_buffer});
+ }
+
+ if (!ref_images_vec.empty()) {
+ p->ref_images = ref_images_vec.data();
+ p->ref_images_count = ref_images_vec.size();
+ fprintf(stderr, "Using %zu reference images\n", ref_images_vec.size());
+ }
+ }
+
+ // Log LoRA information
+ if (p->loras && p->lora_count > 0) {
+ fprintf(stderr, "Using %u LoRA(s) in generation:\n", p->lora_count);
+ for (uint32_t i = 0; i < p->lora_count; i++) {
+ fprintf(stderr, " LoRA[%u]: path='%s', multiplier=%.2f, is_high_noise=%s\n",
+ i,
+ p->loras[i].path ? p->loras[i].path : "(null)",
+ p->loras[i].multiplier,
+ p->loras[i].is_high_noise ? "true" : "false");
+ }
+ } else {
+ fprintf(stderr, "No LoRAs specified for this generation\n");
+ }
+
+ fprintf(stderr, "Generating image with params: \nctx\n---\n%s\ngen\n---\n%s\n",
+ sd_ctx_params_to_str(&ctx_params),
+ sd_img_gen_params_to_str(p));
+
+ results = generate_image(sd_c, p);
+
+ std::free(p);
+
+ if (results == NULL) {
+ fprintf (stderr, "NO results\n");
+ if (input_image_buffer) free(input_image_buffer);
+ if (mask_image_buffer) free(mask_image_buffer);
+ for (auto buffer : ref_image_buffers) {
+ if (buffer) free(buffer);
+ }
+ return 1;
+ }
+
+ if (results[0].data == NULL) {
+ fprintf (stderr, "Results with no data\n");
+ if (input_image_buffer) free(input_image_buffer);
+ if (mask_image_buffer) free(mask_image_buffer);
+ for (auto buffer : ref_image_buffers) {
+ if (buffer) free(buffer);
+ }
+ return 1;
+ }
+
+ fprintf (stderr, "Writing PNG\n");
+
+ fprintf (stderr, "DST: %s\n", dst);
+ fprintf (stderr, "Width: %d\n", results[0].width);
+ fprintf (stderr, "Height: %d\n", results[0].height);
+ fprintf (stderr, "Channel: %d\n", results[0].channel);
+ fprintf (stderr, "Data: %p\n", results[0].data);
+
+ int ret = stbi_write_png(dst, results[0].width, results[0].height, results[0].channel,
+ results[0].data, 0, NULL);
+ if (ret)
+ fprintf (stderr, "Saved resulting image to '%s'\n", dst);
+ else
+ fprintf(stderr, "Failed to write image to '%s'\n", dst);
+
+ // Clean up
+ free(results[0].data);
+ results[0].data = NULL;
+ free(results);
+ if (input_image_buffer) free(input_image_buffer);
+ if (mask_image_buffer) free(mask_image_buffer);
+ for (auto buffer : ref_image_buffers) {
+ if (buffer) free(buffer);
+ }
+ fprintf (stderr, "gen_image is done: %s\n", dst);
+ fflush(stderr);
+
+ return !ret;
+}
+
+int unload() {
+ free_sd_ctx(sd_c);
+ return 0;
+}
+
diff --git a/backend/go/stablediffusion-ggml/gosd.go b/backend/go/stablediffusion-ggml/gosd.go
new file mode 100644
index 0000000000000000000000000000000000000000..205f3f2d17c0cddf778c65e4bc28ffef28e15c82
--- /dev/null
+++ b/backend/go/stablediffusion-ggml/gosd.go
@@ -0,0 +1,155 @@
+package main
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "unsafe"
+
+ "github.com/mudler/LocalAI/pkg/grpc/base"
+ pb "github.com/mudler/LocalAI/pkg/grpc/proto"
+ "github.com/mudler/LocalAI/pkg/utils"
+)
+
+type SDGGML struct {
+ base.SingleThread
+ threads int
+ sampleMethod string
+ cfgScale float32
+}
+
+var (
+ LoadModel func(model, model_apth string, options []uintptr, threads int32, diff int) int
+ GenImage func(params uintptr, steps int, dst string, cfgScale float32, srcImage string, strength float32, maskImage string, refImages []uintptr, refImagesCount int) int
+
+ TilingParamsSetEnabled func(params uintptr, enabled bool)
+ TilingParamsSetTileSizes func(params uintptr, tileSizeX int, tileSizeY int)
+ TilingParamsSetRelSizes func(params uintptr, relSizeX float32, relSizeY float32)
+ TilingParamsSetTargetOverlap func(params uintptr, targetOverlap float32)
+
+ ImgGenParamsNew func() uintptr
+ ImgGenParamsSetPrompts func(params uintptr, prompt string, negativePrompt string)
+ ImgGenParamsSetDimensions func(params uintptr, width int, height int)
+ ImgGenParamsSetSeed func(params uintptr, seed int64)
+ ImgGenParamsGetVaeTilingParams func(params uintptr) uintptr
+)
+
+// Copied from Purego internal/strings
+// TODO: We should upstream sending []string
+func hasSuffix(s, suffix string) bool {
+ return len(s) >= len(suffix) && s[len(s)-len(suffix):] == suffix
+}
+
+func CString(name string) *byte {
+ if hasSuffix(name, "\x00") {
+ return &(*(*[]byte)(unsafe.Pointer(&name)))[0]
+ }
+ b := make([]byte, len(name)+1)
+ copy(b, name)
+ return &b[0]
+}
+
+func (sd *SDGGML) Load(opts *pb.ModelOptions) error {
+
+ sd.threads = int(opts.Threads)
+
+ modelPath := opts.ModelPath
+
+ modelFile := opts.ModelFile
+ modelPathC := modelPath
+
+ var diffusionModel int
+
+ var oo []string
+ for _, op := range opts.Options {
+ if op == "diffusion_model" {
+ diffusionModel = 1
+ continue
+ }
+
+ // If it's an option path, we resolve absolute path from the model path
+ if strings.Contains(op, ":") && strings.Contains(op, "path") {
+ data := strings.Split(op, ":")
+ data[1] = filepath.Join(opts.ModelPath, data[1])
+ if err := utils.VerifyPath(data[1], opts.ModelPath); err == nil {
+ oo = append(oo, strings.Join(data, ":"))
+ }
+ } else {
+ oo = append(oo, op)
+ }
+ }
+
+ fmt.Fprintf(os.Stderr, "Options: %+v\n", oo)
+
+ // At the time of writing Purego doesn't recurse into slices and convert Go strings to pointers so we need to do that
+ var keepAlive []any
+ options := make([]uintptr, len(oo), len(oo)+1)
+ for i, op := range oo {
+ bytep := CString(op)
+ options[i] = uintptr(unsafe.Pointer(bytep))
+ keepAlive = append(keepAlive, bytep)
+ }
+
+ sd.cfgScale = opts.CFGScale
+
+ ret := LoadModel(modelFile, modelPathC, options, opts.Threads, diffusionModel)
+ runtime.KeepAlive(keepAlive)
+ fmt.Fprintf(os.Stderr, "LoadModel: %d\n", ret)
+ if ret != 0 {
+ return fmt.Errorf("could not load model")
+ }
+
+ return nil
+}
+
+func (sd *SDGGML) GenerateImage(opts *pb.GenerateImageRequest) error {
+ t := opts.PositivePrompt
+ dst := opts.Dst
+ negative := opts.NegativePrompt
+ srcImage := opts.Src
+
+ var maskImage string
+ if opts.EnableParameters != "" {
+ if strings.Contains(opts.EnableParameters, "mask:") {
+ parts := strings.Split(opts.EnableParameters, "mask:")
+ if len(parts) > 1 {
+ maskPath := strings.TrimSpace(parts[1])
+ if maskPath != "" {
+ maskImage = maskPath
+ }
+ }
+ }
+ }
+
+ // At the time of writing Purego doesn't recurse into slices and convert Go strings to pointers so we need to do that
+ var keepAlive []any
+ refImagesCount := len(opts.RefImages)
+ refImages := make([]uintptr, refImagesCount, refImagesCount+1)
+ for i, ri := range opts.RefImages {
+ bytep := CString(ri)
+ refImages[i] = uintptr(unsafe.Pointer(bytep))
+ keepAlive = append(keepAlive, bytep)
+ }
+
+ // Default strength for img2img (0.75 is a good default)
+ strength := float32(0.75)
+
+ // free'd by GenImage
+ p := ImgGenParamsNew()
+ ImgGenParamsSetPrompts(p, t, negative)
+ ImgGenParamsSetDimensions(p, int(opts.Width), int(opts.Height))
+ ImgGenParamsSetSeed(p, int64(opts.Seed))
+ vaep := ImgGenParamsGetVaeTilingParams(p)
+ TilingParamsSetEnabled(vaep, false)
+
+ ret := GenImage(p, int(opts.Step), dst, sd.cfgScale, srcImage, strength, maskImage, refImages, refImagesCount)
+ runtime.KeepAlive(keepAlive)
+ fmt.Fprintf(os.Stderr, "GenImage: %d\n", ret)
+ if ret != 0 {
+ return fmt.Errorf("inference failed")
+ }
+
+ return nil
+}
diff --git a/backend/go/stablediffusion-ggml/gosd.h b/backend/go/stablediffusion-ggml/gosd.h
new file mode 100644
index 0000000000000000000000000000000000000000..8324a3ead4eabaa669ee65f0fe912e5f502184c0
--- /dev/null
+++ b/backend/go/stablediffusion-ggml/gosd.h
@@ -0,0 +1,23 @@
+#include
+#include "stable-diffusion.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void sd_tiling_params_set_enabled(sd_tiling_params_t *params, bool enabled);
+void sd_tiling_params_set_tile_sizes(sd_tiling_params_t *params, int tile_size_x, int tile_size_y);
+void sd_tiling_params_set_rel_sizes(sd_tiling_params_t *params, float rel_size_x, float rel_size_y);
+void sd_tiling_params_set_target_overlap(sd_tiling_params_t *params, float target_overlap);
+sd_tiling_params_t* sd_img_gen_params_get_vae_tiling_params(sd_img_gen_params_t *params);
+
+sd_img_gen_params_t* sd_img_gen_params_new(void);
+void sd_img_gen_params_set_prompts(sd_img_gen_params_t *params, const char *prompt, const char *negative_prompt);
+void sd_img_gen_params_set_dimensions(sd_img_gen_params_t *params, int width, int height);
+void sd_img_gen_params_set_seed(sd_img_gen_params_t *params, int64_t seed);
+
+int load_model(const char *model, char *model_path, char* options[], int threads, int diffusionModel);
+int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char* ref_images[], int ref_images_count);
+#ifdef __cplusplus
+}
+#endif
diff --git a/backend/go/stablediffusion-ggml/main.go b/backend/go/stablediffusion-ggml/main.go
new file mode 100644
index 0000000000000000000000000000000000000000..4f053fbbef94ac08ac9fd3c60ad94b6bb161e4e2
--- /dev/null
+++ b/backend/go/stablediffusion-ggml/main.go
@@ -0,0 +1,49 @@
+package main
+
+import (
+ "flag"
+
+ "github.com/ebitengine/purego"
+ grpc "github.com/mudler/LocalAI/pkg/grpc"
+)
+
+var (
+ addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+type LibFuncs struct {
+ FuncPtr any
+ Name string
+}
+
+func main() {
+ gosd, err := purego.Dlopen("./libgosd.so", purego.RTLD_NOW|purego.RTLD_GLOBAL)
+ if err != nil {
+ panic(err)
+ }
+
+ libFuncs := []LibFuncs{
+ {&LoadModel, "load_model"},
+ {&GenImage, "gen_image"},
+ {&TilingParamsSetEnabled, "sd_tiling_params_set_enabled"},
+ {&TilingParamsSetTileSizes, "sd_tiling_params_set_tile_sizes"},
+ {&TilingParamsSetRelSizes, "sd_tiling_params_set_rel_sizes"},
+ {&TilingParamsSetTargetOverlap, "sd_tiling_params_set_target_overlap"},
+
+ {&ImgGenParamsNew, "sd_img_gen_params_new"},
+ {&ImgGenParamsSetPrompts, "sd_img_gen_params_set_prompts"},
+ {&ImgGenParamsSetDimensions, "sd_img_gen_params_set_dimensions"},
+ {&ImgGenParamsSetSeed, "sd_img_gen_params_set_seed"},
+ {&ImgGenParamsGetVaeTilingParams, "sd_img_gen_params_get_vae_tiling_params"},
+ }
+
+ for _, lf := range libFuncs {
+ purego.RegisterLibFunc(lf.FuncPtr, gosd, lf.Name)
+ }
+
+ flag.Parse()
+
+ if err := grpc.StartServer(*addr, &SDGGML{}); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/go/stablediffusion-ggml/package.sh b/backend/go/stablediffusion-ggml/package.sh
new file mode 100644
index 0000000000000000000000000000000000000000..34b158c41faa0f25ada8173d97036ba2dcb61c20
--- /dev/null
+++ b/backend/go/stablediffusion-ggml/package.sh
@@ -0,0 +1,65 @@
+#!/bin/bash
+
+# Script to copy the appropriate libraries based on architecture
+# This script is used in the final stage of the Dockerfile
+
+set -e
+
+CURDIR=$(dirname "$(realpath $0)")
+REPO_ROOT="${CURDIR}/../../.."
+
+# Create lib directory
+mkdir -p $CURDIR/package/lib
+
+cp -avf $CURDIR/libgosd.so $CURDIR/package/
+cp -avf $CURDIR/stablediffusion-ggml $CURDIR/package/
+cp -fv $CURDIR/run.sh $CURDIR/package/
+
+# Detect architecture and copy appropriate libraries
+if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
+ # x86_64 architecture
+ echo "Detected x86_64 architecture, copying x86_64 libraries..."
+ cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
+ cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
+ cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
+elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
+ # ARM64 architecture
+ echo "Detected ARM64 architecture, copying ARM64 libraries..."
+ cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
+ cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
+ cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
+elif [ $(uname -s) = "Darwin" ]; then
+ echo "Detected Darwin"
+else
+ echo "Error: Could not detect architecture"
+ exit 1
+fi
+
+# Package GPU libraries based on BUILD_TYPE
+# The GPU library packaging script will detect BUILD_TYPE and copy appropriate GPU libraries
+GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
+if [ -f "$GPU_LIB_SCRIPT" ]; then
+ echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
+ source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
+ package_gpu_libs
+fi
+
+echo "Packaging completed successfully"
+ls -liah $CURDIR/package/
+ls -liah $CURDIR/package/lib/
diff --git a/backend/go/stablediffusion-ggml/run.sh b/backend/go/stablediffusion-ggml/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ab8c576a0843ec0c846e411704b88858c20ddd1f
--- /dev/null
+++ b/backend/go/stablediffusion-ggml/run.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+set -ex
+
+CURDIR=$(dirname "$(realpath $0)")
+
+export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
+
+# If there is a lib/ld.so, use it
+if [ -f $CURDIR/lib/ld.so ]; then
+ echo "Using lib/ld.so"
+ exec $CURDIR/lib/ld.so $CURDIR/stablediffusion-ggml "$@"
+fi
+
+exec $CURDIR/stablediffusion-ggml "$@"
\ No newline at end of file
diff --git a/backend/go/whisper/.gitignore b/backend/go/whisper/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..7c42de3f14990ab041ff8cbc46baadc4b16804ae
--- /dev/null
+++ b/backend/go/whisper/.gitignore
@@ -0,0 +1,7 @@
+.cache/
+sources/
+build/
+package/
+whisper
+*.so
+compile_commands.json
diff --git a/backend/go/whisper/CMakeLists.txt b/backend/go/whisper/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..60cc178f2b23c4c5424cae2f9e67b69a21217fbf
--- /dev/null
+++ b/backend/go/whisper/CMakeLists.txt
@@ -0,0 +1,16 @@
+cmake_minimum_required(VERSION 3.12)
+project(gowhisper LANGUAGES C CXX)
+set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+add_subdirectory(./sources/whisper.cpp)
+
+add_library(gowhisper MODULE gowhisper.cpp)
+target_link_libraries(gowhisper PRIVATE whisper ggml)
+
+if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
+ target_link_libraries(gosd PRIVATE stdc++fs)
+endif()
+
+set_property(TARGET gowhisper PROPERTY CXX_STANDARD 17)
+set_target_properties(gowhisper PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
diff --git a/backend/go/whisper/Makefile b/backend/go/whisper/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..9754c22ab8b8d9b40cf4a0c8206e81c3849fb352
--- /dev/null
+++ b/backend/go/whisper/Makefile
@@ -0,0 +1,122 @@
+CMAKE_ARGS?=
+BUILD_TYPE?=
+NATIVE?=false
+
+GOCMD?=go
+GO_TAGS?=
+JOBS?=$(shell nproc --ignore=1)
+
+# whisper.cpp version
+WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
+WHISPER_CPP_VERSION?=47af2fb70f7e4ee1ba40c8bed513760fdfe7a704
+SO_TARGET?=libgowhisper.so
+
+CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
+
+ifeq ($(NATIVE),false)
+ CMAKE_ARGS+=-DGGML_NATIVE=OFF
+endif
+
+ifeq ($(BUILD_TYPE),cublas)
+ CMAKE_ARGS+=-DGGML_CUDA=ON
+else ifeq ($(BUILD_TYPE),openblas)
+ CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
+else ifeq ($(BUILD_TYPE),clblas)
+ CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
+else ifeq ($(BUILD_TYPE),hipblas)
+ CMAKE_ARGS+=-DGGML_HIPBLAS=ON
+else ifeq ($(BUILD_TYPE),vulkan)
+ CMAKE_ARGS+=-DGGML_VULKAN=ON
+else ifeq ($(OS),Darwin)
+ ifneq ($(BUILD_TYPE),metal)
+ CMAKE_ARGS+=-DGGML_METAL=OFF
+ else
+ CMAKE_ARGS+=-DGGML_METAL=ON
+ CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
+ endif
+endif
+
+ifeq ($(BUILD_TYPE),sycl_f16)
+ CMAKE_ARGS+=-DGGML_SYCL=ON \
+ -DCMAKE_C_COMPILER=icx \
+ -DCMAKE_CXX_COMPILER=icpx \
+ -DGGML_SYCL_F16=ON
+endif
+
+ifeq ($(BUILD_TYPE),sycl_f32)
+ CMAKE_ARGS+=-DGGML_SYCL=ON \
+ -DCMAKE_C_COMPILER=icx \
+ -DCMAKE_CXX_COMPILER=icpx
+endif
+
+sources/whisper.cpp:
+ mkdir -p sources/whisper.cpp
+ cd sources/whisper.cpp && \
+ git init && \
+ git remote add origin $(WHISPER_REPO) && \
+ git fetch origin && \
+ git checkout $(WHISPER_CPP_VERSION) && \
+ git submodule update --init --recursive --depth 1 --single-branch
+
+# Detect OS
+UNAME_S := $(shell uname -s)
+
+# Only build CPU variants on Linux
+ifeq ($(UNAME_S),Linux)
+ VARIANT_TARGETS = libgowhisper-avx.so libgowhisper-avx2.so libgowhisper-avx512.so libgowhisper-fallback.so
+else
+ # On non-Linux (e.g., Darwin), build only fallback variant
+ VARIANT_TARGETS = libgowhisper-fallback.so
+endif
+
+whisper: main.go gowhisper.go $(VARIANT_TARGETS)
+ CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o whisper ./
+
+package: whisper
+ bash package.sh
+
+build: package
+
+clean: purge
+ rm -rf libgowhisper*.so sources/whisper.cpp whisper
+
+purge:
+ rm -rf build*
+
+# Build all variants (Linux only)
+ifeq ($(UNAME_S),Linux)
+libgowhisper-avx.so: sources/whisper.cpp
+ $(MAKE) purge
+ $(info ${GREEN}I whisper build info:avx${RESET})
+ SO_TARGET=libgowhisper-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off" $(MAKE) libgowhisper-custom
+ rm -rfv build*
+
+libgowhisper-avx2.so: sources/whisper.cpp
+ $(MAKE) purge
+ $(info ${GREEN}I whisper build info:avx2${RESET})
+ SO_TARGET=libgowhisper-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) libgowhisper-custom
+ rm -rfv build*
+
+libgowhisper-avx512.so: sources/whisper.cpp
+ $(MAKE) purge
+ $(info ${GREEN}I whisper build info:avx512${RESET})
+ SO_TARGET=libgowhisper-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) libgowhisper-custom
+ rm -rfv build*
+endif
+
+# Build fallback variant (all platforms)
+libgowhisper-fallback.so: sources/whisper.cpp
+ $(MAKE) purge
+ $(info ${GREEN}I whisper build info:fallback${RESET})
+ SO_TARGET=libgowhisper-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off" $(MAKE) libgowhisper-custom
+ rm -rfv build*
+
+libgowhisper-custom: CMakeLists.txt gowhisper.cpp gowhisper.h
+ mkdir -p build-$(SO_TARGET) && \
+ cd build-$(SO_TARGET) && \
+ cmake .. $(CMAKE_ARGS) && \
+ cmake --build . --config Release -j$(JOBS) && \
+ cd .. && \
+ mv build-$(SO_TARGET)/libgowhisper.so ./$(SO_TARGET)
+
+all: whisper package
diff --git a/backend/go/whisper/gowhisper.cpp b/backend/go/whisper/gowhisper.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f1756d780c8e3cc2cdb05e28787eab6c0efe6209
--- /dev/null
+++ b/backend/go/whisper/gowhisper.cpp
@@ -0,0 +1,156 @@
+#include "gowhisper.h"
+#include "ggml-backend.h"
+#include "whisper.h"
+#include
+
+static struct whisper_vad_context *vctx;
+static struct whisper_context *ctx;
+static std::vector flat_segs;
+
+static void ggml_log_cb(enum ggml_log_level level, const char *log,
+ void *data) {
+ const char *level_str;
+
+ if (!log) {
+ return;
+ }
+
+ switch (level) {
+ case GGML_LOG_LEVEL_DEBUG:
+ level_str = "DEBUG";
+ break;
+ case GGML_LOG_LEVEL_INFO:
+ level_str = "INFO";
+ break;
+ case GGML_LOG_LEVEL_WARN:
+ level_str = "WARN";
+ break;
+ case GGML_LOG_LEVEL_ERROR:
+ level_str = "ERROR";
+ break;
+ default: /* Potential future-proofing */
+ level_str = "?????";
+ break;
+ }
+
+ fprintf(stderr, "[%-5s] ", level_str);
+ fputs(log, stderr);
+ fflush(stderr);
+}
+
+int load_model(const char *const model_path) {
+ whisper_log_set(ggml_log_cb, nullptr);
+ ggml_backend_load_all();
+
+ struct whisper_context_params cparams = whisper_context_default_params();
+
+ ctx = whisper_init_from_file_with_params(model_path, cparams);
+ if (ctx == nullptr) {
+ fprintf(stderr, "error: Also failed to init model as transcriber\n");
+ return 1;
+ }
+
+ return 0;
+}
+
+int load_model_vad(const char *const model_path) {
+ whisper_log_set(ggml_log_cb, nullptr);
+ ggml_backend_load_all();
+
+ struct whisper_vad_context_params vcparams =
+ whisper_vad_default_context_params();
+
+ // XXX: Overridden to false in upstream due to performance?
+ // vcparams.use_gpu = true;
+
+ vctx = whisper_vad_init_from_file_with_params(model_path, vcparams);
+ if (vctx == nullptr) {
+ fprintf(stderr, "error: Failed to init model as VAD\n");
+ return 1;
+ }
+
+ return 0;
+}
+
+int vad(float pcmf32[], size_t pcmf32_len, float **segs_out,
+ size_t *segs_out_len) {
+ if (!whisper_vad_detect_speech(vctx, pcmf32, pcmf32_len)) {
+ fprintf(stderr, "error: failed to detect speech\n");
+ return 1;
+ }
+
+ struct whisper_vad_params params = whisper_vad_default_params();
+ struct whisper_vad_segments *segs =
+ whisper_vad_segments_from_probs(vctx, params);
+ size_t segn = whisper_vad_segments_n_segments(segs);
+
+ // fprintf(stderr, "Got segments %zd\n", segn);
+
+ flat_segs.clear();
+
+ for (int i = 0; i < segn; i++) {
+ flat_segs.push_back(whisper_vad_segments_get_segment_t0(segs, i));
+ flat_segs.push_back(whisper_vad_segments_get_segment_t1(segs, i));
+ }
+
+ // fprintf(stderr, "setting out variables: %p=%p -> %p, %p=%zx -> %zx\n",
+ // segs_out, *segs_out, flat_segs.data(), segs_out_len, *segs_out_len,
+ // flat_segs.size());
+ *segs_out = flat_segs.data();
+ *segs_out_len = flat_segs.size();
+
+ // fprintf(stderr, "freeing segs\n");
+ whisper_vad_free_segments(segs);
+
+ // fprintf(stderr, "returning\n");
+ return 0;
+}
+
+int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
+ float pcmf32[], size_t pcmf32_len, size_t *segs_out_len, char *prompt) {
+ whisper_full_params wparams =
+ whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+
+ wparams.n_threads = threads;
+ if (*lang != '\0')
+ wparams.language = lang;
+ else {
+ wparams.language = nullptr;
+ }
+
+ wparams.translate = translate;
+ wparams.debug_mode = true;
+ wparams.print_progress = true;
+ wparams.tdrz_enable = tdrz;
+ wparams.initial_prompt = prompt;
+
+ fprintf(stderr, "info: Enable tdrz: %d\n", tdrz);
+ fprintf(stderr, "info: Initial prompt: \"%s\"\n", prompt);
+
+ if (whisper_full(ctx, wparams, pcmf32, pcmf32_len)) {
+ fprintf(stderr, "error: transcription failed\n");
+ return 1;
+ }
+
+ *segs_out_len = whisper_full_n_segments(ctx);
+
+ return 0;
+}
+
+const char *get_segment_text(int i) {
+ return whisper_full_get_segment_text(ctx, i);
+}
+
+int64_t get_segment_t0(int i) { return whisper_full_get_segment_t0(ctx, i); }
+
+int64_t get_segment_t1(int i) { return whisper_full_get_segment_t1(ctx, i); }
+
+int n_tokens(int i) { return whisper_full_n_tokens(ctx, i); }
+
+int32_t get_token_id(int i, int j) {
+ return whisper_full_get_token_id(ctx, i, j);
+}
+
+bool get_segment_speaker_turn_next(int i) {
+ return whisper_full_get_segment_speaker_turn_next(ctx, i);
+}
diff --git a/backend/go/whisper/gowhisper.go b/backend/go/whisper/gowhisper.go
new file mode 100644
index 0000000000000000000000000000000000000000..047f0ab8878a7fd2e1e825d57200cdda3402b634
--- /dev/null
+++ b/backend/go/whisper/gowhisper.go
@@ -0,0 +1,161 @@
+package main
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "unsafe"
+
+ "github.com/go-audio/wav"
+ "github.com/mudler/LocalAI/pkg/grpc/base"
+ pb "github.com/mudler/LocalAI/pkg/grpc/proto"
+ "github.com/mudler/LocalAI/pkg/utils"
+)
+
+var (
+ CppLoadModel func(modelPath string) int
+ CppLoadModelVAD func(modelPath string) int
+ CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
+ CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer, prompt string) int
+ CppGetSegmentText func(i int) string
+ CppGetSegmentStart func(i int) int64
+ CppGetSegmentEnd func(i int) int64
+ CppNTokens func(i int) int
+ CppGetTokenID func(i int, j int) int
+ CppGetSegmentSpeakerTurnNext func(i int) bool
+)
+
+type Whisper struct {
+ base.SingleThread
+}
+
+func (w *Whisper) Load(opts *pb.ModelOptions) error {
+ vadOnly := false
+
+ for _, oo := range opts.Options {
+ if oo == "vad_only" {
+ vadOnly = true
+ } else {
+ fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
+ }
+ }
+
+ if vadOnly {
+ if ret := CppLoadModelVAD(opts.ModelFile); ret != 0 {
+ return fmt.Errorf("Failed to load Whisper VAD model")
+ }
+
+ return nil
+ }
+
+ if ret := CppLoadModel(opts.ModelFile); ret != 0 {
+ return fmt.Errorf("Failed to load Whisper transcription model")
+ }
+
+ return nil
+}
+
+func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
+ audio := req.Audio
+ // We expect 0xdeadbeef to be overwritten and if we see it in a stack trace we know it wasn't
+ segsPtr, segsLen := uintptr(0xdeadbeef), uintptr(0xdeadbeef)
+ segsPtrPtr, segsLenPtr := unsafe.Pointer(&segsPtr), unsafe.Pointer(&segsLen)
+
+ if ret := CppVAD(audio, uintptr(len(audio)), segsPtrPtr, segsLenPtr); ret != 0 {
+ return pb.VADResponse{}, fmt.Errorf("Failed VAD")
+ }
+
+ // Happens when CPP vector has not had any elements pushed to it
+ if segsPtr == 0 {
+ return pb.VADResponse{
+ Segments: []*pb.VADSegment{},
+ }, nil
+ }
+
+ // unsafeptr warning is caused by segsPtr being on the stack and therefor being subject to stack copying AFAICT
+ // however the stack shouldn't have grown between setting segsPtr and now, also the memory pointed to is allocated by C++
+ segs := unsafe.Slice((*float32)(unsafe.Pointer(segsPtr)), segsLen)
+
+ vadSegments := []*pb.VADSegment{}
+ for i := range len(segs) >> 1 {
+ s := segs[2*i] / 100
+ t := segs[2*i+1] / 100
+ vadSegments = append(vadSegments, &pb.VADSegment{
+ Start: s,
+ End: t,
+ })
+ }
+
+ return pb.VADResponse{
+ Segments: vadSegments,
+ }, nil
+}
+
+func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
+ dir, err := os.MkdirTemp("", "whisper")
+ if err != nil {
+ return pb.TranscriptResult{}, err
+ }
+ defer os.RemoveAll(dir)
+
+ convertedPath := filepath.Join(dir, "converted.wav")
+
+ if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
+ return pb.TranscriptResult{}, err
+ }
+
+ // Open samples
+ fh, err := os.Open(convertedPath)
+ if err != nil {
+ return pb.TranscriptResult{}, err
+ }
+ defer fh.Close()
+
+ // Read samples
+ d := wav.NewDecoder(fh)
+ buf, err := d.FullPCMBuffer()
+ if err != nil {
+ return pb.TranscriptResult{}, err
+ }
+
+ data := buf.AsFloat32Buffer().Data
+ segsLen := uintptr(0xdeadbeef)
+ segsLenPtr := unsafe.Pointer(&segsLen)
+
+ if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt); ret != 0 {
+ return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
+ }
+
+ segments := []*pb.TranscriptSegment{}
+ text := ""
+ for i := range int(segsLen) {
+ s := CppGetSegmentStart(i)
+ t := CppGetSegmentEnd(i)
+ txt := strings.Clone(CppGetSegmentText(i))
+ tokens := make([]int32, CppNTokens(i))
+
+ if opts.Diarize && CppGetSegmentSpeakerTurnNext(i) {
+ txt += " [SPEAKER_TURN]"
+ }
+
+ for j := range tokens {
+ tokens[j] = int32(CppGetTokenID(i, j))
+ }
+ segment := &pb.TranscriptSegment{
+ Id: int32(i),
+ Text: txt,
+ Start: s, End: t,
+ Tokens: tokens,
+ }
+
+ segments = append(segments, segment)
+
+ text += " " + strings.TrimSpace(txt)
+ }
+
+ return pb.TranscriptResult{
+ Segments: segments,
+ Text: strings.TrimSpace(text),
+ }, nil
+}
diff --git a/backend/go/whisper/gowhisper.h b/backend/go/whisper/gowhisper.h
new file mode 100644
index 0000000000000000000000000000000000000000..0e061cf93debb50a2c90e7bea9e0defeb38bb657
--- /dev/null
+++ b/backend/go/whisper/gowhisper.h
@@ -0,0 +1,18 @@
+#include
+#include
+
+extern "C" {
+int load_model(const char *const model_path);
+int load_model_vad(const char *const model_path);
+int vad(float pcmf32[], size_t pcmf32_size, float **segs_out,
+ size_t *segs_out_len);
+int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
+ float pcmf32[], size_t pcmf32_len, size_t *segs_out_len,
+ char *prompt);
+const char *get_segment_text(int i);
+int64_t get_segment_t0(int i);
+int64_t get_segment_t1(int i);
+int n_tokens(int i);
+int32_t get_token_id(int i, int j);
+bool get_segment_speaker_turn_next(int i);
+}
diff --git a/backend/go/whisper/main.go b/backend/go/whisper/main.go
new file mode 100644
index 0000000000000000000000000000000000000000..794c0a2283a8a35038b8421448df809bdc3c088c
--- /dev/null
+++ b/backend/go/whisper/main.go
@@ -0,0 +1,55 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+import (
+ "flag"
+ "os"
+
+ "github.com/ebitengine/purego"
+ grpc "github.com/mudler/LocalAI/pkg/grpc"
+)
+
+var (
+ addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+type LibFuncs struct {
+ FuncPtr any
+ Name string
+}
+
+func main() {
+ // Get library name from environment variable, default to fallback
+ libName := os.Getenv("WHISPER_LIBRARY")
+ if libName == "" {
+ libName = "./libgowhisper-fallback.so"
+ }
+
+ gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
+ if err != nil {
+ panic(err)
+ }
+
+ libFuncs := []LibFuncs{
+ {&CppLoadModel, "load_model"},
+ {&CppLoadModelVAD, "load_model_vad"},
+ {&CppVAD, "vad"},
+ {&CppTranscribe, "transcribe"},
+ {&CppGetSegmentText, "get_segment_text"},
+ {&CppGetSegmentStart, "get_segment_t0"},
+ {&CppGetSegmentEnd, "get_segment_t1"},
+ {&CppNTokens, "n_tokens"},
+ {&CppGetTokenID, "get_token_id"},
+ {&CppGetSegmentSpeakerTurnNext, "get_segment_speaker_turn_next"},
+ }
+
+ for _, lf := range libFuncs {
+ purego.RegisterLibFunc(lf.FuncPtr, gosd, lf.Name)
+ }
+
+ flag.Parse()
+
+ if err := grpc.StartServer(*addr, &Whisper{}); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/go/whisper/package.sh b/backend/go/whisper/package.sh
new file mode 100644
index 0000000000000000000000000000000000000000..dfecdf5c68cb7c2264c8ee201537ad4ac2cb5496
--- /dev/null
+++ b/backend/go/whisper/package.sh
@@ -0,0 +1,65 @@
+#!/bin/bash
+
+# Script to copy the appropriate libraries based on architecture
+# This script is used in the final stage of the Dockerfile
+
+set -e
+
+CURDIR=$(dirname "$(realpath $0)")
+REPO_ROOT="${CURDIR}/../../.."
+
+# Create lib directory
+mkdir -p $CURDIR/package/lib
+
+cp -avf $CURDIR/whisper $CURDIR/package/
+cp -fv $CURDIR/libgowhisper-*.so $CURDIR/package/
+cp -fv $CURDIR/run.sh $CURDIR/package/
+
+# Detect architecture and copy appropriate libraries
+if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
+ # x86_64 architecture
+ echo "Detected x86_64 architecture, copying x86_64 libraries..."
+ cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
+ cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
+ cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
+ cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
+elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
+ # ARM64 architecture
+ echo "Detected ARM64 architecture, copying ARM64 libraries..."
+ cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
+ cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
+ cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
+ cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
+ cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
+elif [ $(uname -s) = "Darwin" ]; then
+ echo "Detected Darwin"
+else
+ echo "Error: Could not detect architecture"
+ exit 1
+fi
+
+# Package GPU libraries based on BUILD_TYPE
+# The GPU library packaging script will detect BUILD_TYPE and copy appropriate GPU libraries
+GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
+if [ -f "$GPU_LIB_SCRIPT" ]; then
+ echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
+ source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
+ package_gpu_libs
+fi
+
+echo "Packaging completed successfully"
+ls -liah $CURDIR/package/
+ls -liah $CURDIR/package/lib/
diff --git a/backend/go/whisper/run.sh b/backend/go/whisper/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1af2c05359306177a80b6806969c893045dd40e4
--- /dev/null
+++ b/backend/go/whisper/run.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+set -ex
+
+# Get the absolute current dir where the script is located
+CURDIR=$(dirname "$(realpath $0)")
+
+cd /
+
+echo "CPU info:"
+if [ "$(uname)" != "Darwin" ]; then
+ grep -e "model\sname" /proc/cpuinfo | head -1
+ grep -e "flags" /proc/cpuinfo | head -1
+fi
+
+LIBRARY="$CURDIR/libgowhisper-fallback.so"
+
+if [ "$(uname)" != "Darwin" ]; then
+ if grep -q -e "\savx\s" /proc/cpuinfo ; then
+ echo "CPU: AVX found OK"
+ if [ -e $CURDIR/libgowhisper-avx.so ]; then
+ LIBRARY="$CURDIR/libgowhisper-avx.so"
+ fi
+ fi
+
+ if grep -q -e "\savx2\s" /proc/cpuinfo ; then
+ echo "CPU: AVX2 found OK"
+ if [ -e $CURDIR/libgowhisper-avx2.so ]; then
+ LIBRARY="$CURDIR/libgowhisper-avx2.so"
+ fi
+ fi
+
+ # Check avx 512
+ if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
+ echo "CPU: AVX512F found OK"
+ if [ -e $CURDIR/libgowhisper-avx512.so ]; then
+ LIBRARY="$CURDIR/libgowhisper-avx512.so"
+ fi
+ fi
+fi
+
+export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
+export WHISPER_LIBRARY=$LIBRARY
+
+# If there is a lib/ld.so, use it
+if [ -f $CURDIR/lib/ld.so ]; then
+ echo "Using lib/ld.so"
+ echo "Using library: $LIBRARY"
+ exec $CURDIR/lib/ld.so $CURDIR/whisper "$@"
+fi
+
+echo "Using library: $LIBRARY"
+exec $CURDIR/whisper "$@"
\ No newline at end of file
diff --git a/backend/index.yaml b/backend/index.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..45c5bb713b62e2ebee42669ca952da2170edb11b
--- /dev/null
+++ b/backend/index.yaml
@@ -0,0 +1,1712 @@
+---
+## metas
+- &llamacpp
+ name: "llama-cpp"
+ alias: "llama-cpp"
+ license: mit
+ icon: https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png
+ description: |
+ LLM inference in C/C++
+ urls:
+ - https://github.com/ggerganov/llama.cpp
+ tags:
+ - text-to-text
+ - LLM
+ - CPU
+ - GPU
+ - Metal
+ - CUDA
+ - HIP
+ capabilities:
+ default: "cpu-llama-cpp"
+ nvidia: "cuda12-llama-cpp"
+ intel: "intel-sycl-f16-llama-cpp"
+ amd: "rocm-llama-cpp"
+ metal: "metal-llama-cpp"
+ vulkan: "vulkan-llama-cpp"
+ nvidia-l4t: "nvidia-l4t-arm64-llama-cpp"
+ nvidia-cuda-13: "cuda13-llama-cpp"
+ nvidia-cuda-12: "cuda12-llama-cpp"
+ nvidia-l4t-cuda-12: "nvidia-l4t-arm64-llama-cpp"
+ nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-llama-cpp"
+- &whispercpp
+ name: "whisper"
+ alias: "whisper"
+ license: mit
+ icon: https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg
+ description: |
+ Port of OpenAI's Whisper model in C/C++
+ urls:
+ - https://github.com/ggml-org/whisper.cpp
+ tags:
+ - audio-transcription
+ - CPU
+ - GPU
+ - CUDA
+ - HIP
+ capabilities:
+ default: "cpu-whisper"
+ nvidia: "cuda12-whisper"
+ intel: "intel-sycl-f16-whisper"
+ metal: "metal-whisper"
+ amd: "rocm-whisper"
+ vulkan: "vulkan-whisper"
+ nvidia-l4t: "nvidia-l4t-arm64-whisper"
+ nvidia-cuda-13: "cuda13-whisper"
+ nvidia-cuda-12: "cuda12-whisper"
+ nvidia-l4t-cuda-12: "nvidia-l4t-arm64-whisper"
+ nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-whisper"
+- &stablediffusionggml
+ name: "stablediffusion-ggml"
+ alias: "stablediffusion-ggml"
+ license: mit
+ icon: https://github.com/leejet/stable-diffusion.cpp/raw/master/assets/cat_with_sd_cpp_42.png
+ description: |
+ Stable Diffusion and Flux in pure C/C++
+ urls:
+ - https://github.com/leejet/stable-diffusion.cpp
+ tags:
+ - image-generation
+ - CPU
+ - GPU
+ - Metal
+ - CUDA
+ - HIP
+ capabilities:
+ default: "cpu-stablediffusion-ggml"
+ nvidia: "cuda12-stablediffusion-ggml"
+ intel: "intel-sycl-f16-stablediffusion-ggml"
+ # amd: "rocm-stablediffusion-ggml"
+ vulkan: "vulkan-stablediffusion-ggml"
+ nvidia-l4t: "nvidia-l4t-arm64-stablediffusion-ggml"
+ metal: "metal-stablediffusion-ggml"
+ nvidia-cuda-13: "cuda13-stablediffusion-ggml"
+ nvidia-cuda-12: "cuda12-stablediffusion-ggml"
+ nvidia-l4t-cuda-12: "nvidia-l4t-arm64-stablediffusion-ggml"
+ nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-stablediffusion-ggml"
+- &rfdetr
+ name: "rfdetr"
+ alias: "rfdetr"
+ license: apache-2.0
+ icon: https://avatars.githubusercontent.com/u/53104118?s=200&v=4
+ description: |
+ RF-DETR is a real-time, transformer-based object detection model architecture developed by Roboflow and released under the Apache 2.0 license.
+ RF-DETR is the first real-time model to exceed 60 AP on the Microsoft COCO benchmark alongside competitive performance at base sizes. It also achieves state-of-the-art performance on RF100-VL, an object detection benchmark that measures model domain adaptability to real world problems. RF-DETR is fastest and most accurate for its size when compared current real-time objection models.
+ RF-DETR is small enough to run on the edge using Inference, making it an ideal model for deployments that need both strong accuracy and real-time performance.
+ urls:
+ - https://github.com/roboflow/rf-detr
+ tags:
+ - object-detection
+ - rfdetr
+ - gpu
+ - cpu
+ capabilities:
+ nvidia: "cuda12-rfdetr"
+ intel: "intel-rfdetr"
+ #amd: "rocm-rfdetr"
+ nvidia-l4t: "nvidia-l4t-arm64-rfdetr"
+ default: "cpu-rfdetr"
+ nvidia-cuda-13: "cuda13-rfdetr"
+ nvidia-cuda-12: "cuda12-rfdetr"
+ nvidia-l4t-cuda-12: "nvidia-l4t-arm64-rfdetr"
+- &vllm
+ name: "vllm"
+ license: apache-2.0
+ urls:
+ - https://github.com/vllm-project/vllm
+ tags:
+ - text-to-text
+ - multimodal
+ - GPTQ
+ - AWQ
+ - AutoRound
+ - INT4
+ - INT8
+ - FP8
+ icon: https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-dark.png
+ description: |
+ vLLM is a fast and easy-to-use library for LLM inference and serving.
+ Originally developed in the Sky Computing Lab at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.
+ vLLM is fast with:
+ State-of-the-art serving throughput
+ Efficient management of attention key and value memory with PagedAttention
+ Continuous batching of incoming requests
+ Fast model execution with CUDA/HIP graph
+ Quantizations: GPTQ, AWQ, AutoRound, INT4, INT8, and FP8
+ Optimized CUDA kernels, including integration with FlashAttention and FlashInfer
+ Speculative decoding
+ Chunked prefill
+ alias: "vllm"
+ capabilities:
+ nvidia: "cuda12-vllm"
+ amd: "rocm-vllm"
+ intel: "intel-vllm"
+ nvidia-cuda-12: "cuda12-vllm"
+- &mlx
+ name: "mlx"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx"
+ icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4
+ urls:
+ - https://github.com/ml-explore/mlx-lm
+ mirrors:
+ - localai/localai-backends:latest-metal-darwin-arm64-mlx
+ license: MIT
+ description: |
+ Run LLMs with MLX
+ tags:
+ - text-to-text
+ - LLM
+ - MLX
+- &mlx-vlm
+ name: "mlx-vlm"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-vlm"
+ icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4
+ urls:
+ - https://github.com/Blaizzy/mlx-vlm
+ mirrors:
+ - localai/localai-backends:latest-metal-darwin-arm64-mlx-vlm
+ license: MIT
+ description: |
+ Run Vision-Language Models with MLX
+ tags:
+ - text-to-text
+ - multimodal
+ - vision-language
+ - LLM
+ - MLX
+- &mlx-audio
+ name: "mlx-audio"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-audio"
+ icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4
+ urls:
+ - https://github.com/Blaizzy/mlx-audio
+ mirrors:
+ - localai/localai-backends:latest-metal-darwin-arm64-mlx-audio
+ license: MIT
+ description: |
+ Run Audio Models with MLX
+ tags:
+ - audio-to-text
+ - audio-generation
+ - text-to-audio
+ - LLM
+ - MLX
+- &rerankers
+ name: "rerankers"
+ alias: "rerankers"
+ capabilities:
+ nvidia: "cuda12-rerankers"
+ intel: "intel-rerankers"
+ amd: "rocm-rerankers"
+- &transformers
+ name: "transformers"
+ icon: https://camo.githubusercontent.com/26569a27b8a30a488dd345024b71dbc05da7ff1b2ba97bb6080c9f1ee0f26cc7/68747470733a2f2f68756767696e67666163652e636f2f64617461736574732f68756767696e67666163652f646f63756d656e746174696f6e2d696d616765732f7265736f6c76652f6d61696e2f7472616e73666f726d6572732f7472616e73666f726d6572735f61735f615f6d6f64656c5f646566696e6974696f6e2e706e67
+ alias: "transformers"
+ license: apache-2.0
+ description: |
+ Transformers acts as the model-definition framework for state-of-the-art machine learning models in text, computer vision, audio, video, and multimodal model, for both inference and training.
+ It centralizes the model definition so that this definition is agreed upon across the ecosystem. transformers is the pivot across frameworks: if a model definition is supported, it will be compatible with the majority of training frameworks (Axolotl, Unsloth, DeepSpeed, FSDP, PyTorch-Lightning, ...), inference engines (vLLM, SGLang, TGI, ...), and adjacent modeling libraries (llama.cpp, mlx, ...) which leverage the model definition from transformers.
+ urls:
+ - https://github.com/huggingface/transformers
+ tags:
+ - text-to-text
+ - multimodal
+ capabilities:
+ nvidia: "cuda12-transformers"
+ intel: "intel-transformers"
+ amd: "rocm-transformers"
+ nvidia-cuda-13: "cuda13-transformers"
+ nvidia-cuda-12: "cuda12-transformers"
+- &diffusers
+ name: "diffusers"
+ icon: https://raw.githubusercontent.com/huggingface/diffusers/main/docs/source/en/imgs/diffusers_library.jpg
+ description: |
+ 🤗 Diffusers is the go-to library for state-of-the-art pretrained diffusion models for generating images, audio, and even 3D structures of molecules. Whether you're looking for a simple inference solution or training your own diffusion models, 🤗 Diffusers is a modular toolbox that supports both.
+ urls:
+ - https://github.com/huggingface/diffusers
+ tags:
+ - image-generation
+ - video-generation
+ - diffusion-models
+ license: apache-2.0
+ alias: "diffusers"
+ capabilities:
+ nvidia: "cuda12-diffusers"
+ intel: "intel-diffusers"
+ amd: "rocm-diffusers"
+ nvidia-l4t: "nvidia-l4t-diffusers"
+ metal: "metal-diffusers"
+ default: "cpu-diffusers"
+ nvidia-cuda-13: "cuda13-diffusers"
+ nvidia-cuda-12: "cuda12-diffusers"
+ nvidia-l4t-cuda-12: "nvidia-l4t-arm64-diffusers"
+ nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-diffusers"
+- &exllama2
+ name: "exllama2"
+ urls:
+ - https://github.com/turboderp-org/exllamav2
+ tags:
+ - text-to-text
+ - LLM
+ - EXL2
+ license: MIT
+ description: |
+ ExLlamaV2 is an inference library for running local LLMs on modern consumer GPUs.
+ alias: "exllama2"
+ capabilities:
+ nvidia: "cuda12-exllama2"
+ intel: "intel-exllama2"
+ nvidia-cuda-12: "cuda12-exllama2"
+- &faster-whisper
+ icon: https://avatars.githubusercontent.com/u/1520500?s=200&v=4
+ description: |
+ faster-whisper is a reimplementation of OpenAI's Whisper model using CTranslate2, which is a fast inference engine for Transformer models.
+ This implementation is up to 4 times faster than openai/whisper for the same accuracy while using less memory. The efficiency can be further improved with 8-bit quantization on both CPU and GPU.
+ urls:
+ - https://github.com/SYSTRAN/faster-whisper
+ tags:
+ - speech-to-text
+ - Whisper
+ license: MIT
+ name: "faster-whisper"
+ capabilities:
+ nvidia: "cuda12-faster-whisper"
+ intel: "intel-faster-whisper"
+ amd: "rocm-faster-whisper"
+ nvidia-cuda-13: "cuda13-faster-whisper"
+ nvidia-cuda-12: "cuda12-faster-whisper"
+- &moonshine
+ description: |
+ Moonshine is a fast, accurate, and efficient speech-to-text transcription model using ONNX Runtime.
+ It provides real-time transcription capabilities with support for multiple model sizes and GPU acceleration.
+ urls:
+ - https://github.com/moonshine-ai/moonshine
+ tags:
+ - speech-to-text
+ - transcription
+ - ONNX
+ license: MIT
+ name: "moonshine"
+ alias: "moonshine"
+ capabilities:
+ nvidia: "cuda12-moonshine"
+ default: "cpu-moonshine"
+ nvidia-cuda-13: "cuda13-moonshine"
+ nvidia-cuda-12: "cuda12-moonshine"
+- &kokoro
+ icon: https://avatars.githubusercontent.com/u/166769057?v=4
+ description: |
+ Kokoro is an open-weight TTS model with 82 million parameters. Despite its lightweight architecture, it delivers comparable quality to larger models while being significantly faster and more cost-efficient. With Apache-licensed weights, Kokoro can be deployed anywhere from production environments to personal projects.
+ urls:
+ - https://huggingface.co/hexgrad/Kokoro-82M
+ - https://github.com/hexgrad/kokoro
+ tags:
+ - text-to-speech
+ - TTS
+ - LLM
+ license: apache-2.0
+ alias: "kokoro"
+ name: "kokoro"
+ capabilities:
+ nvidia: "cuda12-kokoro"
+ intel: "intel-kokoro"
+ amd: "rocm-kokoro"
+ nvidia-l4t: "nvidia-l4t-kokoro"
+ nvidia-cuda-13: "cuda13-kokoro"
+ nvidia-cuda-12: "cuda12-kokoro"
+ nvidia-l4t-cuda-12: "nvidia-l4t-arm64-kokoro"
+- &coqui
+ urls:
+ - https://github.com/idiap/coqui-ai-TTS
+ description: |
+ 🐸 Coqui TTS is a library for advanced Text-to-Speech generation.
+
+ 🚀 Pretrained models in +1100 languages.
+
+ 🛠️ Tools for training new models and fine-tuning existing models in any language.
+
+ 📚 Utilities for dataset analysis and curation.
+ tags:
+ - text-to-speech
+ - TTS
+ license: mpl-2.0
+ name: "coqui"
+ alias: "coqui"
+ capabilities:
+ nvidia: "cuda12-coqui"
+ intel: "intel-coqui"
+ amd: "rocm-coqui"
+ nvidia-cuda-13: "cuda13-coqui"
+ nvidia-cuda-12: "cuda12-coqui"
+ icon: https://avatars.githubusercontent.com/u/1338804?s=200&v=4
+- &bark
+ urls:
+ - https://github.com/suno-ai/bark
+ description: |
+ Bark is a transformer-based text-to-audio model created by Suno. Bark can generate highly realistic, multilingual speech as well as other audio - including music, background noise and simple sound effects. The model can also produce nonverbal communications like laughing, sighing and crying. To support the research community, we are providing access to pretrained model checkpoints, which are ready for inference and available for commercial use.
+ tags:
+ - text-to-speech
+ - TTS
+ license: MIT
+ name: "bark"
+ alias: "bark"
+ capabilities:
+ cuda: "cuda12-bark"
+ intel: "intel-bark"
+ rocm: "rocm-bark"
+ nvidia-cuda-13: "cuda13-bark"
+ nvidia-cuda-12: "cuda12-bark"
+ icon: https://avatars.githubusercontent.com/u/99442120?s=200&v=4
+- &barkcpp
+ urls:
+ - https://github.com/PABannier/bark.cpp
+ description: |
+ With bark.cpp, our goal is to bring real-time realistic multilingual text-to-speech generation to the community.
+
+ Plain C/C++ implementation without dependencies
+ AVX, AVX2 and AVX512 for x86 architectures
+ CPU and GPU compatible backends
+ Mixed F16 / F32 precision
+ 4-bit, 5-bit and 8-bit integer quantization
+ Metal and CUDA backends
+
+ Models supported
+
+ Bark Small
+ Bark Large
+ tags:
+ - text-to-speech
+ - TTS
+ license: MIT
+ icon: https://github.com/PABannier/bark.cpp/raw/main/assets/banner.png
+ name: "bark-cpp"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-bark-cpp"
+ mirrors:
+ - localai/localai-backends:latest-bark-cpp
+ alias: "bark-cpp"
+- &chatterbox
+ urls:
+ - https://github.com/resemble-ai/chatterbox
+ description: |
+ Resemble AI's first production-grade open source TTS model. Licensed under MIT, Chatterbox has been benchmarked against leading closed-source systems like ElevenLabs, and is consistently preferred in side-by-side evaluations.
+ Whether you're working on memes, videos, games, or AI agents, Chatterbox brings your content to life. It's also the first open source TTS model to support emotion exaggeration control, a powerful feature that makes your voices stand out.
+ tags:
+ - text-to-speech
+ - TTS
+ license: MIT
+ icon: https://private-user-images.githubusercontent.com/660224/448166653-bd8c5f03-e91d-4ee5-b680-57355da204d1.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3NTAxOTE0MDAsIm5iZiI6MTc1MDE5MTEwMCwicGF0aCI6Ii82NjAyMjQvNDQ4MTY2NjUzLWJkOGM1ZjAzLWU5MWQtNGVlNS1iNjgwLTU3MzU1ZGEyMDRkMS5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwNjE3JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDYxN1QyMDExNDBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT1hMmI1NGY3OGFiZTlhNGFkNTVlYTY4NTIwMWEzODRiZGE4YzdhNGQ5MGNhNzE3MDYyYTA2NDIxYTkyYzhiODkwJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.mR9kM9xX0TdzPuSpuspCllHYQiq79dFQ2rtuNvjrl6w
+ name: "chatterbox"
+ alias: "chatterbox"
+ capabilities:
+ nvidia: "cuda12-chatterbox"
+ metal: "metal-chatterbox"
+ default: "cpu-chatterbox"
+ nvidia-l4t: "nvidia-l4t-arm64-chatterbox"
+ nvidia-cuda-13: "cuda13-chatterbox"
+ nvidia-cuda-12: "cuda12-chatterbox"
+ nvidia-l4t-cuda-12: "nvidia-l4t-arm64-chatterbox"
+ nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-chatterbox"
+- &vibevoice
+ urls:
+ - https://github.com/microsoft/VibeVoice
+ description: |
+ VibeVoice-Realtime is a real-time text-to-speech model that generates natural-sounding speech.
+ tags:
+ - text-to-speech
+ - TTS
+ license: mit
+ name: "vibevoice"
+ alias: "vibevoice"
+ capabilities:
+ nvidia: "cuda12-vibevoice"
+ intel: "intel-vibevoice"
+ amd: "rocm-vibevoice"
+ nvidia-l4t: "nvidia-l4t-vibevoice"
+ default: "cpu-vibevoice"
+ nvidia-cuda-13: "cuda13-vibevoice"
+ nvidia-cuda-12: "cuda12-vibevoice"
+ nvidia-l4t-cuda-12: "nvidia-l4t-vibevoice"
+ nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-vibevoice"
+ icon: https://avatars.githubusercontent.com/u/6154722?s=200&v=4
+- &pocket-tts
+ urls:
+ - https://github.com/kyutai-labs/pocket-tts
+ description: |
+ Pocket TTS is a lightweight text-to-speech model designed to run efficiently on CPUs.
+ tags:
+ - text-to-speech
+ - TTS
+ license: mit
+ name: "pocket-tts"
+ alias: "pocket-tts"
+ capabilities:
+ nvidia: "cuda12-pocket-tts"
+ intel: "intel-pocket-tts"
+ amd: "rocm-pocket-tts"
+ nvidia-l4t: "nvidia-l4t-pocket-tts"
+ default: "cpu-pocket-tts"
+ nvidia-cuda-13: "cuda13-pocket-tts"
+ nvidia-cuda-12: "cuda12-pocket-tts"
+ nvidia-l4t-cuda-12: "nvidia-l4t-pocket-tts"
+ nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-pocket-tts"
+ icon: https://avatars.githubusercontent.com/u/6154722?s=200&v=4
+- &piper
+ name: "piper"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-piper"
+ icon: https://github.com/OHF-Voice/piper1-gpl/raw/main/etc/logo.png
+ urls:
+ - https://github.com/rhasspy/piper
+ - https://github.com/mudler/go-piper
+ mirrors:
+ - localai/localai-backends:latest-piper
+ license: MIT
+ description: |
+ A fast, local neural text to speech system
+ tags:
+ - text-to-speech
+ - TTS
+- &silero-vad
+ name: "silero-vad"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-silero-vad"
+ icon: https://user-images.githubusercontent.com/12515440/89997349-b3523080-dc94-11ea-9906-ca2e8bc50535.png
+ urls:
+ - https://github.com/snakers4/silero-vad
+ mirrors:
+ - localai/localai-backends:latest-cpu-silero-vad
+ description: |
+ Silero VAD: pre-trained enterprise-grade Voice Activity Detector.
+ Silero VAD is a voice activity detection model that can be used to detect whether a given audio contains speech or not.
+ tags:
+ - voice-activity-detection
+ - VAD
+ - silero-vad
+ - CPU
+- &local-store
+ name: "local-store"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-local-store"
+ mirrors:
+ - localai/localai-backends:latest-cpu-local-store
+ urls:
+ - https://github.com/mudler/LocalAI
+ description: |
+ Local Store is a local-first, self-hosted, and open-source vector database.
+ tags:
+ - vector-database
+ - local-first
+ - open-source
+ - CPU
+ license: MIT
+- &huggingface
+ name: "huggingface"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-huggingface"
+ mirrors:
+ - localai/localai-backends:latest-huggingface
+ icon: https://huggingface.co/front/assets/huggingface_logo-noborder.svg
+ urls:
+ - https://huggingface.co/docs/hub/en/api
+ description: |
+ HuggingFace is a backend which uses the huggingface API to run models.
+ tags:
+ - LLM
+ - huggingface
+ license: MIT
+- &kitten-tts
+ name: "kitten-tts"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-kitten-tts"
+ mirrors:
+ - localai/localai-backends:latest-kitten-tts
+ urls:
+ - https://github.com/KittenML/KittenTTS
+ description: |
+ Kitten TTS is a text-to-speech model that can generate speech from text.
+ tags:
+ - text-to-speech
+ - TTS
+ license: apache-2.0
+- &neutts
+ name: "neutts"
+ urls:
+ - https://github.com/neuphonic/neutts-air
+ description: |
+ NeuTTS Air is the world’s first super-realistic, on-device, TTS speech language model with instant voice cloning. Built off a 0.5B LLM backbone, NeuTTS Air brings natural-sounding speech, real-time performance, built-in security and speaker cloning to your local device - unlocking a new category of embedded voice agents, assistants, toys, and compliance-safe apps.
+ tags:
+ - text-to-speech
+ - TTS
+ license: apache-2.0
+ capabilities:
+ default: "cpu-neutts"
+ nvidia: "cuda12-neutts"
+ amd: "rocm-neutts"
+ nvidia-l4t: "nvidia-l4t-neutts"
+ nvidia-cuda-12: "cuda12-neutts"
+ nvidia-l4t-cuda-12: "nvidia-l4t-arm64-neutts"
+- !!merge <<: *neutts
+ name: "neutts-development"
+ capabilities:
+ default: "cpu-neutts-development"
+ nvidia: "cuda12-neutts-development"
+ amd: "rocm-neutts-development"
+ nvidia-l4t: "nvidia-l4t-neutts-development"
+ nvidia-cuda-12: "cuda12-neutts-development"
+ nvidia-l4t-cuda-12: "nvidia-l4t-arm64-neutts-development"
+- !!merge <<: *llamacpp
+ name: "llama-cpp-development"
+ capabilities:
+ default: "cpu-llama-cpp-development"
+ nvidia: "cuda12-llama-cpp-development"
+ intel: "intel-sycl-f16-llama-cpp-development"
+ amd: "rocm-llama-cpp-development"
+ metal: "metal-llama-cpp-development"
+ vulkan: "vulkan-llama-cpp-development"
+ nvidia-l4t: "nvidia-l4t-arm64-llama-cpp-development"
+ nvidia-cuda-13: "cuda13-llama-cpp-development"
+ nvidia-cuda-12: "cuda12-llama-cpp-development"
+ nvidia-l4t-cuda-12: "nvidia-l4t-arm64-llama-cpp-development"
+ nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-llama-cpp-development"
+- !!merge <<: *neutts
+ name: "cpu-neutts"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-neutts"
+ mirrors:
+ - localai/localai-backends:latest-cpu-neutts
+- !!merge <<: *neutts
+ name: "cuda12-neutts"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-neutts"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-neutts
+- !!merge <<: *neutts
+ name: "rocm-neutts"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-neutts"
+ mirrors:
+ - localai/localai-backends:latest-gpu-rocm-hipblas-neutts
+- !!merge <<: *neutts
+ name: "nvidia-l4t-arm64-neutts"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-neutts"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-arm64-neutts
+- !!merge <<: *neutts
+ name: "cpu-neutts-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-cpu-neutts"
+ mirrors:
+ - localai/localai-backends:master-cpu-neutts
+- !!merge <<: *neutts
+ name: "cuda12-neutts-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-neutts"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-neutts
+- !!merge <<: *neutts
+ name: "rocm-neutts-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-neutts"
+ mirrors:
+ - localai/localai-backends:master-gpu-rocm-hipblas-neutts
+- !!merge <<: *neutts
+ name: "nvidia-l4t-arm64-neutts-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-neutts"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-arm64-neutts
+- !!merge <<: *mlx
+ name: "mlx-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx"
+ mirrors:
+ - localai/localai-backends:master-metal-darwin-arm64-mlx
+- !!merge <<: *mlx-vlm
+ name: "mlx-vlm-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-vlm"
+ mirrors:
+ - localai/localai-backends:master-metal-darwin-arm64-mlx-vlm
+- !!merge <<: *mlx-audio
+ name: "mlx-audio-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-audio"
+ mirrors:
+ - localai/localai-backends:master-metal-darwin-arm64-mlx-audio
+- !!merge <<: *kitten-tts
+ name: "kitten-tts-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-kitten-tts"
+ mirrors:
+ - localai/localai-backends:master-kitten-tts
+- !!merge <<: *huggingface
+ name: "huggingface-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-huggingface"
+ mirrors:
+ - localai/localai-backends:master-huggingface
+- !!merge <<: *local-store
+ name: "local-store-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-cpu-local-store"
+ mirrors:
+ - localai/localai-backends:master-cpu-local-store
+- !!merge <<: *silero-vad
+ name: "silero-vad-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-cpu-silero-vad"
+ mirrors:
+ - localai/localai-backends:master-cpu-silero-vad
+- !!merge <<: *piper
+ name: "piper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-piper"
+ mirrors:
+ - localai/localai-backends:master-piper
+## llama-cpp
+- !!merge <<: *llamacpp
+ name: "nvidia-l4t-arm64-llama-cpp"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-llama-cpp"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-arm64-llama-cpp
+- !!merge <<: *llamacpp
+ name: "nvidia-l4t-arm64-llama-cpp-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-llama-cpp"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-arm64-llama-cpp
+- !!merge <<: *llamacpp
+ name: "cuda13-nvidia-l4t-arm64-llama-cpp"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-llama-cpp"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-llama-cpp
+- !!merge <<: *llamacpp
+ name: "cuda13-nvidia-l4t-arm64-llama-cpp-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-llama-cpp"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-llama-cpp
+- !!merge <<: *llamacpp
+ name: "cpu-llama-cpp"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-llama-cpp"
+ mirrors:
+ - localai/localai-backends:latest-cpu-llama-cpp
+- !!merge <<: *llamacpp
+ name: "cpu-llama-cpp-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-cpu-llama-cpp"
+ mirrors:
+ - localai/localai-backends:master-cpu-llama-cpp
+- !!merge <<: *llamacpp
+ name: "cuda12-llama-cpp"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-llama-cpp"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-llama-cpp
+- !!merge <<: *llamacpp
+ name: "rocm-llama-cpp"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-llama-cpp"
+ mirrors:
+ - localai/localai-backends:latest-gpu-rocm-hipblas-llama-cpp
+- !!merge <<: *llamacpp
+ name: "intel-sycl-f32-llama-cpp"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-llama-cpp"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-sycl-f32-llama-cpp
+- !!merge <<: *llamacpp
+ name: "intel-sycl-f16-llama-cpp"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-llama-cpp"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-sycl-f16-llama-cpp
+- !!merge <<: *llamacpp
+ name: "vulkan-llama-cpp"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-llama-cpp"
+ mirrors:
+ - localai/localai-backends:latest-gpu-vulkan-llama-cpp
+- !!merge <<: *llamacpp
+ name: "vulkan-llama-cpp-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-llama-cpp"
+ mirrors:
+ - localai/localai-backends:master-gpu-vulkan-llama-cpp
+- !!merge <<: *llamacpp
+ name: "metal-llama-cpp"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp"
+ mirrors:
+ - localai/localai-backends:latest-metal-darwin-arm64-llama-cpp
+- !!merge <<: *llamacpp
+ name: "metal-llama-cpp-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-llama-cpp"
+ mirrors:
+ - localai/localai-backends:master-metal-darwin-arm64-llama-cpp
+- !!merge <<: *llamacpp
+ name: "cuda12-llama-cpp-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-llama-cpp"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-llama-cpp
+- !!merge <<: *llamacpp
+ name: "rocm-llama-cpp-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-llama-cpp"
+ mirrors:
+ - localai/localai-backends:master-gpu-rocm-hipblas-llama-cpp
+- !!merge <<: *llamacpp
+ name: "intel-sycl-f32-llama-cpp-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-llama-cpp"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-sycl-f32-llama-cpp
+- !!merge <<: *llamacpp
+ name: "intel-sycl-f16-llama-cpp-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-llama-cpp"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-sycl-f16-llama-cpp
+- !!merge <<: *llamacpp
+ name: "cuda13-llama-cpp"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-llama-cpp"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-13-llama-cpp
+- !!merge <<: *llamacpp
+ name: "cuda13-llama-cpp-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-llama-cpp"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-13-llama-cpp
+## whisper
+- !!merge <<: *whispercpp
+ name: "nvidia-l4t-arm64-whisper"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-whisper"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-arm64-whisper
+- !!merge <<: *whispercpp
+ name: "nvidia-l4t-arm64-whisper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-whisper"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-arm64-whisper
+- !!merge <<: *whispercpp
+ name: "cuda13-nvidia-l4t-arm64-whisper"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-whisper"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-whisper
+- !!merge <<: *whispercpp
+ name: "cuda13-nvidia-l4t-arm64-whisper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-whisper"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-whisper
+- !!merge <<: *whispercpp
+ name: "cpu-whisper"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-whisper"
+ mirrors:
+ - localai/localai-backends:latest-cpu-whisper
+- !!merge <<: *whispercpp
+ name: "metal-whisper"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-whisper"
+ mirrors:
+ - localai/localai-backends:latest-metal-darwin-arm64-whisper
+- !!merge <<: *whispercpp
+ name: "metal-whisper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-whisper"
+ mirrors:
+ - localai/localai-backends:master-metal-darwin-arm64-whisper
+- !!merge <<: *whispercpp
+ name: "cpu-whisper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-cpu-whisper"
+ mirrors:
+ - localai/localai-backends:master-cpu-whisper
+- !!merge <<: *whispercpp
+ name: "cuda12-whisper"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-whisper"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-whisper
+- !!merge <<: *whispercpp
+ name: "rocm-whisper"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-whisper"
+ mirrors:
+ - localai/localai-backends:latest-gpu-rocm-hipblas-whisper
+- !!merge <<: *whispercpp
+ name: "intel-sycl-f32-whisper"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-whisper"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-sycl-f32-whisper
+- !!merge <<: *whispercpp
+ name: "intel-sycl-f16-whisper"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-whisper"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-sycl-f16-whisper
+- !!merge <<: *whispercpp
+ name: "vulkan-whisper"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-whisper"
+ mirrors:
+ - localai/localai-backends:latest-gpu-vulkan-whisper
+- !!merge <<: *whispercpp
+ name: "vulkan-whisper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-whisper"
+ mirrors:
+ - localai/localai-backends:master-gpu-vulkan-whisper
+- !!merge <<: *whispercpp
+ name: "metal-whisper"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-whisper"
+ mirrors:
+ - localai/localai-backends:latest-metal-darwin-arm64-whisper
+- !!merge <<: *whispercpp
+ name: "metal-whisper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-whisper"
+ mirrors:
+ - localai/localai-backends:master-metal-darwin-arm64-whisper
+- !!merge <<: *whispercpp
+ name: "cuda12-whisper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-whisper"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-whisper
+- !!merge <<: *whispercpp
+ name: "rocm-whisper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-whisper"
+ mirrors:
+ - localai/localai-backends:master-gpu-rocm-hipblas-whisper
+- !!merge <<: *whispercpp
+ name: "intel-sycl-f32-whisper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-whisper"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-sycl-f32-whisper
+- !!merge <<: *whispercpp
+ name: "intel-sycl-f16-whisper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-whisper"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-sycl-f16-whisper
+- !!merge <<: *whispercpp
+ name: "cuda13-whisper"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-whisper"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-13-whisper
+- !!merge <<: *whispercpp
+ name: "cuda13-whisper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-whisper"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-13-whisper
+## stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "cpu-stablediffusion-ggml"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:latest-cpu-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "cpu-stablediffusion-ggml-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-cpu-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:master-cpu-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "metal-stablediffusion-ggml"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:latest-metal-darwin-arm64-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "metal-stablediffusion-ggml-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:master-metal-darwin-arm64-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "vulkan-stablediffusion-ggml"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:latest-gpu-vulkan-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "vulkan-stablediffusion-ggml-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:master-gpu-vulkan-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "cuda12-stablediffusion-ggml"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "intel-sycl-f32-stablediffusion-ggml"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-stablediffusion-ggml"
+- !!merge <<: *stablediffusionggml
+ name: "intel-sycl-f16-stablediffusion-ggml"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-sycl-f16-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "cuda12-stablediffusion-ggml-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "intel-sycl-f32-stablediffusion-ggml-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-sycl-f32-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "intel-sycl-f16-stablediffusion-ggml-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-sycl-f16-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "nvidia-l4t-arm64-stablediffusion-ggml-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-arm64-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "nvidia-l4t-arm64-stablediffusion-ggml"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-arm64-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "cuda13-nvidia-l4t-arm64-stablediffusion-ggml"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "cuda13-nvidia-l4t-arm64-stablediffusion-ggml-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "cuda13-stablediffusion-ggml"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-13-stablediffusion-ggml
+- !!merge <<: *stablediffusionggml
+ name: "cuda13-stablediffusion-ggml-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-stablediffusion-ggml"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-13-stablediffusion-ggml
+# vllm
+- !!merge <<: *vllm
+ name: "vllm-development"
+ capabilities:
+ nvidia: "cuda12-vllm-development"
+ amd: "rocm-vllm-development"
+ intel: "intel-vllm-development"
+- !!merge <<: *vllm
+ name: "cuda12-vllm"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-vllm"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-vllm
+- !!merge <<: *vllm
+ name: "rocm-vllm"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-vllm"
+ mirrors:
+ - localai/localai-backends:latest-gpu-rocm-hipblas-vllm
+- !!merge <<: *vllm
+ name: "intel-vllm"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-vllm"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-vllm
+- !!merge <<: *vllm
+ name: "cuda12-vllm-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-vllm"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-vllm
+- !!merge <<: *vllm
+ name: "rocm-vllm-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-vllm"
+ mirrors:
+ - localai/localai-backends:master-gpu-rocm-hipblas-vllm
+- !!merge <<: *vllm
+ name: "intel-vllm-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-vllm"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-vllm
+# rfdetr
+- !!merge <<: *rfdetr
+ name: "rfdetr-development"
+ capabilities:
+ nvidia: "cuda12-rfdetr-development"
+ intel: "intel-rfdetr-development"
+ #amd: "rocm-rfdetr-development"
+ nvidia-l4t: "nvidia-l4t-arm64-rfdetr-development"
+ default: "cpu-rfdetr-development"
+ nvidia-cuda-13: "cuda13-rfdetr-development"
+- !!merge <<: *rfdetr
+ name: "cuda12-rfdetr"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-rfdetr"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-rfdetr
+- !!merge <<: *rfdetr
+ name: "intel-rfdetr"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-rfdetr"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-rfdetr
+# - !!merge <<: *rfdetr
+# name: "rocm-rfdetr"
+# uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-hipblas-rfdetr"
+# mirrors:
+# - localai/localai-backends:latest-gpu-hipblas-rfdetr
+- !!merge <<: *rfdetr
+ name: "nvidia-l4t-arm64-rfdetr"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-rfdetr"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-arm64-rfdetr
+- !!merge <<: *rfdetr
+ name: "nvidia-l4t-arm64-rfdetr-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-rfdetr"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-arm64-rfdetr
+- !!merge <<: *rfdetr
+ name: "cpu-rfdetr"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-rfdetr"
+ mirrors:
+ - localai/localai-backends:latest-cpu-rfdetr
+- !!merge <<: *rfdetr
+ name: "cuda12-rfdetr-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-rfdetr"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-rfdetr
+- !!merge <<: *rfdetr
+ name: "intel-rfdetr-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-rfdetr"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-rfdetr
+# - !!merge <<: *rfdetr
+# name: "rocm-rfdetr-development"
+# uri: "quay.io/go-skynet/local-ai-backends:master-gpu-hipblas-rfdetr"
+# mirrors:
+# - localai/localai-backends:master-gpu-hipblas-rfdetr
+- !!merge <<: *rfdetr
+ name: "cpu-rfdetr-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-cpu-rfdetr"
+ mirrors:
+ - localai/localai-backends:master-cpu-rfdetr
+- !!merge <<: *rfdetr
+ name: "intel-rfdetr"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-rfdetr"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-rfdetr
+- !!merge <<: *rfdetr
+ name: "cuda13-rfdetr"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-rfdetr"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-13-rfdetr
+- !!merge <<: *rfdetr
+ name: "cuda13-rfdetr-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-rfdetr"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-13-rfdetr
+## Rerankers
+- !!merge <<: *rerankers
+ name: "rerankers-development"
+ capabilities:
+ nvidia: "cuda12-rerankers-development"
+ intel: "intel-rerankers-development"
+ amd: "rocm-rerankers-development"
+ nvidia-cuda-13: "cuda13-rerankers-development"
+- !!merge <<: *rerankers
+ name: "cuda12-rerankers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-rerankers"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-rerankers
+- !!merge <<: *rerankers
+ name: "intel-rerankers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-rerankers"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-rerankers
+- !!merge <<: *rerankers
+ name: "rocm-rerankers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-rerankers"
+ mirrors:
+ - localai/localai-backends:latest-gpu-rocm-hipblas-rerankers
+- !!merge <<: *rerankers
+ name: "cuda12-rerankers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-rerankers"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-rerankers
+- !!merge <<: *rerankers
+ name: "rocm-rerankers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-rerankers"
+ mirrors:
+ - localai/localai-backends:master-gpu-rocm-hipblas-rerankers
+- !!merge <<: *rerankers
+ name: "intel-rerankers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-rerankers"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-rerankers
+- !!merge <<: *rerankers
+ name: "cuda13-rerankers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-rerankers"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-13-rerankers
+- !!merge <<: *rerankers
+ name: "cuda13-rerankers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-rerankers"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-13-rerankers
+## Transformers
+- !!merge <<: *transformers
+ name: "transformers-development"
+ capabilities:
+ nvidia: "cuda12-transformers-development"
+ intel: "intel-transformers-development"
+ amd: "rocm-transformers-development"
+ nvidia-cuda-13: "cuda13-transformers-development"
+- !!merge <<: *transformers
+ name: "cuda12-transformers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-transformers"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-transformers
+- !!merge <<: *transformers
+ name: "rocm-transformers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-transformers"
+ mirrors:
+ - localai/localai-backends:latest-gpu-rocm-hipblas-transformers
+- !!merge <<: *transformers
+ name: "intel-transformers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-transformers"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-transformers
+- !!merge <<: *transformers
+ name: "cuda12-transformers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-transformers"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-transformers
+- !!merge <<: *transformers
+ name: "rocm-transformers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-transformers"
+ mirrors:
+ - localai/localai-backends:master-gpu-rocm-hipblas-transformers
+- !!merge <<: *transformers
+ name: "intel-transformers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-transformers"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-transformers
+- !!merge <<: *transformers
+ name: "cuda13-transformers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-transformers"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-13-transformers
+- !!merge <<: *transformers
+ name: "cuda13-transformers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-transformers"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-13-transformers
+## Diffusers
+- !!merge <<: *diffusers
+ name: "diffusers-development"
+ capabilities:
+ nvidia: "cuda12-diffusers-development"
+ intel: "intel-diffusers-development"
+ amd: "rocm-diffusers-development"
+ nvidia-l4t: "nvidia-l4t-diffusers-development"
+ metal: "metal-diffusers-development"
+ default: "cpu-diffusers-development"
+ nvidia-cuda-13: "cuda13-diffusers-development"
+- !!merge <<: *diffusers
+ name: "cpu-diffusers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-diffusers"
+ mirrors:
+ - localai/localai-backends:latest-cpu-diffusers
+- !!merge <<: *diffusers
+ name: "cpu-diffusers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-cpu-diffusers"
+ mirrors:
+ - localai/localai-backends:master-cpu-diffusers
+- !!merge <<: *diffusers
+ name: "nvidia-l4t-diffusers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-diffusers"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-diffusers
+- !!merge <<: *diffusers
+ name: "nvidia-l4t-diffusers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-diffusers"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-diffusers
+- !!merge <<: *diffusers
+ name: "cuda13-nvidia-l4t-arm64-diffusers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-diffusers"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-diffusers
+- !!merge <<: *diffusers
+ name: "cuda13-nvidia-l4t-arm64-diffusers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-diffusers"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-diffusers
+- !!merge <<: *diffusers
+ name: "cuda12-diffusers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-diffusers"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-diffusers
+- !!merge <<: *diffusers
+ name: "rocm-diffusers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-diffusers"
+ mirrors:
+ - localai/localai-backends:latest-gpu-rocm-hipblas-diffusers
+- !!merge <<: *diffusers
+ name: "intel-diffusers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-diffusers"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-diffusers
+- !!merge <<: *diffusers
+ name: "cuda12-diffusers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-diffusers"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-diffusers
+- !!merge <<: *diffusers
+ name: "rocm-diffusers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-diffusers"
+ mirrors:
+ - localai/localai-backends:master-gpu-rocm-hipblas-diffusers
+- !!merge <<: *diffusers
+ name: "intel-diffusers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-diffusers"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-diffusers
+- !!merge <<: *diffusers
+ name: "cuda13-diffusers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-diffusers"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-13-diffusers
+- !!merge <<: *diffusers
+ name: "cuda13-diffusers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-diffusers"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-13-diffusers
+- !!merge <<: *diffusers
+ name: "metal-diffusers"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-diffusers"
+ mirrors:
+ - localai/localai-backends:latest-metal-darwin-arm64-diffusers
+- !!merge <<: *diffusers
+ name: "metal-diffusers-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-diffusers"
+ mirrors:
+ - localai/localai-backends:master-metal-darwin-arm64-diffusers
+ ## exllama2
+- !!merge <<: *exllama2
+ name: "exllama2-development"
+ capabilities:
+ nvidia: "cuda12-exllama2-development"
+ intel: "intel-exllama2-development"
+- !!merge <<: *exllama2
+ name: "cuda12-exllama2"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-exllama2"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-exllama2
+- !!merge <<: *exllama2
+ name: "cuda12-exllama2-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-exllama2"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-exllama2
+## kokoro
+- !!merge <<: *kokoro
+ name: "kokoro-development"
+ capabilities:
+ nvidia: "cuda12-kokoro-development"
+ intel: "intel-kokoro-development"
+ amd: "rocm-kokoro-development"
+ nvidia-l4t: "nvidia-l4t-kokoro-development"
+- !!merge <<: *kokoro
+ name: "cuda12-kokoro-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-kokoro"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-kokoro
+- !!merge <<: *kokoro
+ name: "rocm-kokoro-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-kokoro"
+ mirrors:
+ - localai/localai-backends:master-gpu-rocm-hipblas-kokoro
+- !!merge <<: *kokoro
+ name: "intel-kokoro"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-kokoro"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-kokoro
+- !!merge <<: *kokoro
+ name: "intel-kokoro-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-kokoro"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-kokoro
+- !!merge <<: *kokoro
+ name: "nvidia-l4t-kokoro"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-kokoro"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-kokoro
+- !!merge <<: *kokoro
+ name: "nvidia-l4t-kokoro-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-kokoro"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-kokoro
+- !!merge <<: *kokoro
+ name: "cuda12-kokoro"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-kokoro"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-kokoro
+- !!merge <<: *kokoro
+ name: "rocm-kokoro"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-kokoro"
+ mirrors:
+ - localai/localai-backends:latest-gpu-rocm-hipblas-kokoro
+- !!merge <<: *kokoro
+ name: "cuda13-kokoro"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-kokoro"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-13-kokoro
+- !!merge <<: *kokoro
+ name: "cuda13-kokoro-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-kokoro"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-13-kokoro
+## faster-whisper
+- !!merge <<: *faster-whisper
+ name: "faster-whisper-development"
+ capabilities:
+ nvidia: "cuda12-faster-whisper-development"
+ intel: "intel-faster-whisper-development"
+ amd: "rocm-faster-whisper-development"
+ nvidia-cuda-13: "cuda13-faster-whisper-development"
+- !!merge <<: *faster-whisper
+ name: "cuda12-faster-whisper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-faster-whisper"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-faster-whisper
+- !!merge <<: *faster-whisper
+ name: "rocm-faster-whisper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-faster-whisper"
+ mirrors:
+ - localai/localai-backends:master-gpu-rocm-hipblas-faster-whisper
+- !!merge <<: *faster-whisper
+ name: "intel-faster-whisper"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-faster-whisper"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-faster-whisper
+- !!merge <<: *faster-whisper
+ name: "intel-faster-whisper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-faster-whisper"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-faster-whisper
+- !!merge <<: *faster-whisper
+ name: "cuda13-faster-whisper"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-faster-whisper"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-13-faster-whisper
+- !!merge <<: *faster-whisper
+ name: "cuda13-faster-whisper-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-faster-whisper"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-13-faster-whisper
+## moonshine
+- !!merge <<: *moonshine
+ name: "moonshine-development"
+ capabilities:
+ nvidia: "cuda12-moonshine-development"
+ default: "cpu-moonshine-development"
+ nvidia-cuda-13: "cuda13-moonshine-development"
+ nvidia-cuda-12: "cuda12-moonshine-development"
+- !!merge <<: *moonshine
+ name: "cpu-moonshine"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-moonshine"
+ mirrors:
+ - localai/localai-backends:latest-cpu-moonshine
+- !!merge <<: *moonshine
+ name: "cpu-moonshine-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-cpu-moonshine"
+ mirrors:
+ - localai/localai-backends:master-cpu-moonshine
+- !!merge <<: *moonshine
+ name: "cuda12-moonshine"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-moonshine"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-moonshine
+- !!merge <<: *moonshine
+ name: "cuda12-moonshine-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-moonshine"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-moonshine
+- !!merge <<: *moonshine
+ name: "cuda13-moonshine"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-moonshine"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-13-moonshine
+- !!merge <<: *moonshine
+ name: "cuda13-moonshine-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-moonshine"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-13-moonshine
+## coqui
+
+- !!merge <<: *coqui
+ name: "coqui-development"
+ capabilities:
+ nvidia: "cuda12-coqui-development"
+ intel: "intel-coqui-development"
+ amd: "rocm-coqui-development"
+- !!merge <<: *coqui
+ name: "cuda12-coqui"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-coqui"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-coqui
+- !!merge <<: *coqui
+ name: "cuda12-coqui-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-coqui"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-coqui
+- !!merge <<: *coqui
+ name: "rocm-coqui-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-coqui"
+ mirrors:
+ - localai/localai-backends:master-gpu-rocm-hipblas-coqui
+- !!merge <<: *coqui
+ name: "intel-coqui"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-coqui"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-coqui
+- !!merge <<: *coqui
+ name: "intel-coqui-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-coqui"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-coqui
+- !!merge <<: *coqui
+ name: "rocm-coqui"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-coqui"
+ mirrors:
+ - localai/localai-backends:latest-gpu-rocm-hipblas-coqui
+## bark
+- !!merge <<: *bark
+ name: "bark-development"
+ capabilities:
+ nvidia: "cuda12-bark-development"
+ intel: "intel-bark-development"
+ amd: "rocm-bark-development"
+- !!merge <<: *bark
+ name: "rocm-bark-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-bark"
+ mirrors:
+ - localai/localai-backends:master-gpu-rocm-hipblas-bark
+- !!merge <<: *bark
+ name: "intel-bark"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-bark"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-bark
+- !!merge <<: *bark
+ name: "intel-bark-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-bark"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-bark
+- !!merge <<: *bark
+ name: "cuda12-bark"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-bark"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-bark
+- !!merge <<: *bark
+ name: "rocm-bark"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-bark"
+ mirrors:
+ - localai/localai-backends:latest-gpu-rocm-hipblas-bark
+- !!merge <<: *bark
+ name: "cuda12-bark-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-bark"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-bark
+- !!merge <<: *barkcpp
+ name: "bark-cpp-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-bark-cpp"
+ alias: "bark-cpp"
+## chatterbox
+- !!merge <<: *chatterbox
+ name: "chatterbox-development"
+ capabilities:
+ nvidia: "cuda12-chatterbox-development"
+ metal: "metal-chatterbox-development"
+ default: "cpu-chatterbox-development"
+ nvidia-l4t: "nvidia-l4t-arm64-chatterbox"
+ nvidia-cuda-13: "cuda13-chatterbox-development"
+ nvidia-cuda-12: "cuda12-chatterbox-development"
+ nvidia-l4t-cuda-12: "nvidia-l4t-arm64-chatterbox"
+ nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-chatterbox"
+- !!merge <<: *chatterbox
+ name: "cpu-chatterbox"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-chatterbox"
+ mirrors:
+ - localai/localai-backends:latest-cpu-chatterbox
+- !!merge <<: *chatterbox
+ name: "cpu-chatterbox-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-cpu-chatterbox"
+ mirrors:
+ - localai/localai-backends:master-cpu-chatterbox
+- !!merge <<: *chatterbox
+ name: "nvidia-l4t-arm64-chatterbox"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-chatterbox"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-arm64-chatterbox
+- !!merge <<: *chatterbox
+ name: "nvidia-l4t-arm64-chatterbox-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-chatterbox"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-arm64-chatterbox
+- !!merge <<: *chatterbox
+ name: "metal-chatterbox"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-chatterbox"
+ mirrors:
+ - localai/localai-backends:latest-metal-darwin-arm64-chatterbox
+- !!merge <<: *chatterbox
+ name: "metal-chatterbox-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-chatterbox"
+ mirrors:
+ - localai/localai-backends:master-metal-darwin-arm64-chatterbox
+- !!merge <<: *chatterbox
+ name: "cuda12-chatterbox-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-chatterbox"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-chatterbox
+- !!merge <<: *chatterbox
+ name: "cuda12-chatterbox"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-chatterbox"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-chatterbox
+- !!merge <<: *chatterbox
+ name: "cuda13-chatterbox"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-chatterbox"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-13-chatterbox
+- !!merge <<: *chatterbox
+ name: "cuda13-chatterbox-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-chatterbox"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-13-chatterbox
+- !!merge <<: *chatterbox
+ name: "cuda13-nvidia-l4t-arm64-chatterbox"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-chatterbox"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-chatterbox
+- !!merge <<: *chatterbox
+ name: "cuda13-nvidia-l4t-arm64-chatterbox-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-chatterbox"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-chatterbox
+## vibevoice
+- !!merge <<: *vibevoice
+ name: "vibevoice-development"
+ capabilities:
+ nvidia: "cuda12-vibevoice-development"
+ intel: "intel-vibevoice-development"
+ amd: "rocm-vibevoice-development"
+ nvidia-l4t: "nvidia-l4t-vibevoice-development"
+ default: "cpu-vibevoice-development"
+ nvidia-cuda-13: "cuda13-vibevoice-development"
+ nvidia-cuda-12: "cuda12-vibevoice-development"
+ nvidia-l4t-cuda-12: "nvidia-l4t-vibevoice-development"
+ nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-vibevoice-development"
+- !!merge <<: *vibevoice
+ name: "cpu-vibevoice"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-vibevoice"
+ mirrors:
+ - localai/localai-backends:latest-cpu-vibevoice
+- !!merge <<: *vibevoice
+ name: "cpu-vibevoice-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-cpu-vibevoice"
+ mirrors:
+ - localai/localai-backends:master-cpu-vibevoice
+- !!merge <<: *vibevoice
+ name: "cuda12-vibevoice"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-vibevoice"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-vibevoice
+- !!merge <<: *vibevoice
+ name: "cuda12-vibevoice-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-vibevoice"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-vibevoice
+- !!merge <<: *vibevoice
+ name: "cuda13-vibevoice"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-vibevoice"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-13-vibevoice
+- !!merge <<: *vibevoice
+ name: "cuda13-vibevoice-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-vibevoice"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-13-vibevoice
+- !!merge <<: *vibevoice
+ name: "intel-vibevoice"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-vibevoice"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-vibevoice
+- !!merge <<: *vibevoice
+ name: "intel-vibevoice-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-vibevoice"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-vibevoice
+- !!merge <<: *vibevoice
+ name: "rocm-vibevoice"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-vibevoice"
+ mirrors:
+ - localai/localai-backends:latest-gpu-rocm-hipblas-vibevoice
+- !!merge <<: *vibevoice
+ name: "rocm-vibevoice-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-vibevoice"
+ mirrors:
+ - localai/localai-backends:master-gpu-rocm-hipblas-vibevoice
+- !!merge <<: *vibevoice
+ name: "nvidia-l4t-vibevoice"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-vibevoice"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-vibevoice
+- !!merge <<: *vibevoice
+ name: "nvidia-l4t-vibevoice-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-vibevoice"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-vibevoice
+- !!merge <<: *vibevoice
+ name: "cuda13-nvidia-l4t-arm64-vibevoice"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-vibevoice"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-vibevoice
+- !!merge <<: *vibevoice
+ name: "cuda13-nvidia-l4t-arm64-vibevoice-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-vibevoice"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-vibevoice
+## pocket-tts
+- !!merge <<: *pocket-tts
+ name: "pocket-tts-development"
+ capabilities:
+ nvidia: "cuda12-pocket-tts-development"
+ intel: "intel-pocket-tts-development"
+ amd: "rocm-pocket-tts-development"
+ nvidia-l4t: "nvidia-l4t-pocket-tts-development"
+ default: "cpu-pocket-tts-development"
+ nvidia-cuda-13: "cuda13-pocket-tts-development"
+ nvidia-cuda-12: "cuda12-pocket-tts-development"
+ nvidia-l4t-cuda-12: "nvidia-l4t-pocket-tts-development"
+ nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-pocket-tts-development"
+- !!merge <<: *pocket-tts
+ name: "cpu-pocket-tts"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-pocket-tts"
+ mirrors:
+ - localai/localai-backends:latest-cpu-pocket-tts
+- !!merge <<: *pocket-tts
+ name: "cpu-pocket-tts-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-cpu-pocket-tts"
+ mirrors:
+ - localai/localai-backends:master-cpu-pocket-tts
+- !!merge <<: *pocket-tts
+ name: "cuda12-pocket-tts"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-pocket-tts"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-12-pocket-tts
+- !!merge <<: *pocket-tts
+ name: "cuda12-pocket-tts-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-pocket-tts"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-12-pocket-tts
+- !!merge <<: *pocket-tts
+ name: "cuda13-pocket-tts"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-pocket-tts"
+ mirrors:
+ - localai/localai-backends:latest-gpu-nvidia-cuda-13-pocket-tts
+- !!merge <<: *pocket-tts
+ name: "cuda13-pocket-tts-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-pocket-tts"
+ mirrors:
+ - localai/localai-backends:master-gpu-nvidia-cuda-13-pocket-tts
+- !!merge <<: *pocket-tts
+ name: "intel-pocket-tts"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-pocket-tts"
+ mirrors:
+ - localai/localai-backends:latest-gpu-intel-pocket-tts
+- !!merge <<: *pocket-tts
+ name: "intel-pocket-tts-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-pocket-tts"
+ mirrors:
+ - localai/localai-backends:master-gpu-intel-pocket-tts
+- !!merge <<: *pocket-tts
+ name: "rocm-pocket-tts"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-pocket-tts"
+ mirrors:
+ - localai/localai-backends:latest-gpu-rocm-hipblas-pocket-tts
+- !!merge <<: *pocket-tts
+ name: "rocm-pocket-tts-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-pocket-tts"
+ mirrors:
+ - localai/localai-backends:master-gpu-rocm-hipblas-pocket-tts
+- !!merge <<: *pocket-tts
+ name: "nvidia-l4t-pocket-tts"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-pocket-tts"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-pocket-tts
+- !!merge <<: *pocket-tts
+ name: "nvidia-l4t-pocket-tts-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-pocket-tts"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-pocket-tts
+- !!merge <<: *pocket-tts
+ name: "cuda13-nvidia-l4t-arm64-pocket-tts"
+ uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-pocket-tts"
+ mirrors:
+ - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-pocket-tts
+- !!merge <<: *pocket-tts
+ name: "cuda13-nvidia-l4t-arm64-pocket-tts-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-pocket-tts"
+ mirrors:
+ - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-pocket-tts
diff --git a/backend/python/README.md b/backend/python/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9f894b77b5968d78e2183b9a49646210d2e46c49
--- /dev/null
+++ b/backend/python/README.md
@@ -0,0 +1,190 @@
+# Python Backends for LocalAI
+
+This directory contains Python-based AI backends for LocalAI, providing support for various AI models and hardware acceleration targets.
+
+## Overview
+
+The Python backends use a unified build system based on `libbackend.sh` that provides:
+- **Automatic virtual environment management** with support for both `uv` and `pip`
+- **Hardware-specific dependency installation** (CPU, CUDA, Intel, MLX, etc.)
+- **Portable Python support** for standalone deployments
+- **Consistent backend execution** across different environments
+
+## Available Backends
+
+### Core AI Models
+- **transformers** - Hugging Face Transformers framework (PyTorch-based)
+- **vllm** - High-performance LLM inference engine
+- **mlx** - Apple Silicon optimized ML framework
+- **exllama2** - ExLlama2 quantized models
+
+### Audio & Speech
+- **bark** - Text-to-speech synthesis
+- **coqui** - Coqui TTS models
+- **faster-whisper** - Fast Whisper speech recognition
+- **kitten-tts** - Lightweight TTS
+- **mlx-audio** - Apple Silicon audio processing
+- **chatterbox** - TTS model
+- **kokoro** - TTS models
+
+### Computer Vision
+- **diffusers** - Stable Diffusion and image generation
+- **mlx-vlm** - Vision-language models for Apple Silicon
+- **rfdetr** - Object detection models
+
+### Specialized
+
+- **rerankers** - Text reranking models
+
+## Quick Start
+
+### Prerequisites
+- Python 3.10+ (default: 3.10.18)
+- `uv` package manager (recommended) or `pip`
+- Appropriate hardware drivers for your target (CUDA, Intel, etc.)
+
+### Installation
+
+Each backend can be installed individually:
+
+```bash
+# Navigate to a specific backend
+cd backend/python/transformers
+
+# Install dependencies
+make transformers
+# or
+bash install.sh
+
+# Run the backend
+make run
+# or
+bash run.sh
+```
+
+### Using the Unified Build System
+
+The `libbackend.sh` script provides consistent commands across all backends:
+
+```bash
+# Source the library in your backend script
+source $(dirname $0)/../common/libbackend.sh
+
+# Install requirements (automatically handles hardware detection)
+installRequirements
+
+# Start the backend server
+startBackend $@
+
+# Run tests
+runUnittests
+```
+
+## Hardware Targets
+
+The build system automatically detects and configures for different hardware:
+
+- **CPU** - Standard CPU-only builds
+- **CUDA** - NVIDIA GPU acceleration (supports CUDA 12/13)
+- **Intel** - Intel XPU/GPU optimization
+- **MLX** - Apple Silicon (M1/M2/M3) optimization
+- **HIP** - AMD GPU acceleration
+
+### Target-Specific Requirements
+
+Backends can specify hardware-specific dependencies:
+- `requirements.txt` - Base requirements
+- `requirements-cpu.txt` - CPU-specific packages
+- `requirements-cublas12.txt` - CUDA 12 packages
+- `requirements-cublas13.txt` - CUDA 13 packages
+- `requirements-intel.txt` - Intel-optimized packages
+- `requirements-mps.txt` - Apple Silicon packages
+
+## Configuration Options
+
+### Environment Variables
+
+- `PYTHON_VERSION` - Python version (default: 3.10)
+- `PYTHON_PATCH` - Python patch version (default: 18)
+- `BUILD_TYPE` - Force specific build target
+- `USE_PIP` - Use pip instead of uv (default: false)
+- `PORTABLE_PYTHON` - Enable portable Python builds
+- `LIMIT_TARGETS` - Restrict backend to specific targets
+
+### Example: CUDA 12 Only Backend
+
+```bash
+# In your backend script
+LIMIT_TARGETS="cublas12"
+source $(dirname $0)/../common/libbackend.sh
+```
+
+### Example: Intel-Optimized Backend
+
+```bash
+# In your backend script
+LIMIT_TARGETS="intel"
+source $(dirname $0)/../common/libbackend.sh
+```
+
+## Development
+
+### Adding a New Backend
+
+1. Create a new directory in `backend/python/`
+2. Copy the template structure from `common/template/`
+3. Implement your `backend.py` with the required gRPC interface
+4. Add appropriate requirements files for your target hardware
+5. Use `libbackend.sh` for consistent build and execution
+
+### Testing
+
+```bash
+# Run backend tests
+make test
+# or
+bash test.sh
+```
+
+### Building
+
+```bash
+# Install dependencies
+make
+
+# Clean build artifacts
+make clean
+```
+
+## Architecture
+
+Each backend follows a consistent structure:
+```
+backend-name/
+├── backend.py # Main backend implementation
+├── requirements.txt # Base dependencies
+├── requirements-*.txt # Hardware-specific dependencies
+├── install.sh # Installation script
+├── run.sh # Execution script
+├── test.sh # Test script
+├── Makefile # Build targets
+└── test.py # Unit tests
+```
+
+## Troubleshooting
+
+### Common Issues
+
+1. **Missing dependencies**: Ensure all requirements files are properly configured
+2. **Hardware detection**: Check that `BUILD_TYPE` matches your system
+3. **Python version**: Verify Python 3.10+ is available
+4. **Virtual environment**: Use `ensureVenv` to create/activate environments
+
+## Contributing
+
+When adding new backends or modifying existing ones:
+1. Follow the established directory structure
+2. Use `libbackend.sh` for consistent behavior
+3. Include appropriate requirements files for all target hardware
+4. Add comprehensive tests
+5. Update this README if adding new backend types
diff --git a/backend/python/bark/Makefile b/backend/python/bark/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..da996aabeef01838c36824b6b4b61650056fe7aa
--- /dev/null
+++ b/backend/python/bark/Makefile
@@ -0,0 +1,23 @@
+.PHONY: ttsbark
+ttsbark:
+ bash install.sh
+
+.PHONY: run
+run: ttsbark
+ @echo "Running bark..."
+ bash run.sh
+ @echo "bark run."
+
+.PHONY: test
+test: ttsbark
+ @echo "Testing bark..."
+ bash test.sh
+ @echo "bark tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/bark/README.md b/backend/python/bark/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5b571e47b9d99611db57bbb0f488bee8bfe86045
--- /dev/null
+++ b/backend/python/bark/README.md
@@ -0,0 +1,16 @@
+# Creating a separate environment for ttsbark project
+
+```
+make ttsbark
+```
+
+# Testing the gRPC server
+
+```
+ -m unittest test_ttsbark.py
+```
+
+For example
+```
+/opt/conda/envs/bark/bin/python -m unittest extra/grpc/bark/test_ttsbark.py
+``````
\ No newline at end of file
diff --git a/backend/python/bark/backend.py b/backend/python/bark/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..4997810054e0faa08f9f951f370b7430dfd08945
--- /dev/null
+++ b/backend/python/bark/backend.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+"""
+This is an extra gRPC server of LocalAI for Bark TTS
+"""
+from concurrent import futures
+import time
+import argparse
+import signal
+import sys
+import os
+from scipy.io.wavfile import write as write_wav
+
+import backend_pb2
+import backend_pb2_grpc
+from bark import SAMPLE_RATE, generate_audio, preload_models
+
+import grpc
+
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ BackendServicer is the class that implements the gRPC service
+ """
+ def Health(self, request, context):
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+ def LoadModel(self, request, context):
+ model_name = request.Model
+ try:
+ print("Preparing models, please wait", file=sys.stderr)
+ # download and load all models
+ preload_models()
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ # Implement your logic here for the LoadModel service
+ # Replace this with your desired response
+ return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+ def TTS(self, request, context):
+ model = request.model
+ print(request, file=sys.stderr)
+ try:
+ audio_array = None
+ if model != "":
+ audio_array = generate_audio(request.text, history_prompt=model)
+ else:
+ audio_array = generate_audio(request.text)
+ print("saving to", request.dst, file=sys.stderr)
+ # save audio to disk
+ write_wav(request.dst, SAMPLE_RATE, audio_array)
+ print("saved to", request.dst, file=sys.stderr)
+ print("tts for", file=sys.stderr)
+ print(request, file=sys.stderr)
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ return backend_pb2.Result(success=True)
+
+def serve(address):
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ server.add_insecure_port(address)
+ server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+
+ # Define the signal handler function
+ def signal_handler(sig, frame):
+ print("Received termination signal. Shutting down...")
+ server.stop(0)
+ sys.exit(0)
+
+ # Set the signal handlers for SIGINT and SIGTERM
+ signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, signal_handler)
+
+ try:
+ while True:
+ time.sleep(_ONE_DAY_IN_SECONDS)
+ except KeyboardInterrupt:
+ server.stop(0)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ serve(args.addr)
diff --git a/backend/python/bark/install.sh b/backend/python/bark/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..32befa8e6c034734e54038c3b4b565522028f391
--- /dev/null
+++ b/backend/python/bark/install.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
+# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
+if [ "x${BUILD_PROFILE}" == "xintel" ]; then
+ EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+fi
+
+installRequirements
diff --git a/backend/python/bark/requirements-cpu.txt b/backend/python/bark/requirements-cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..12e376adeb15af660958229c9d690003e9518887
--- /dev/null
+++ b/backend/python/bark/requirements-cpu.txt
@@ -0,0 +1,4 @@
+transformers
+accelerate
+torch==2.4.1
+torchaudio==2.4.1
\ No newline at end of file
diff --git a/backend/python/bark/requirements-cublas12.txt b/backend/python/bark/requirements-cublas12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..537169495d1ea575474553e6c9e7e2ed74a53c86
--- /dev/null
+++ b/backend/python/bark/requirements-cublas12.txt
@@ -0,0 +1,4 @@
+torch==2.4.1
+torchaudio==2.4.1
+transformers
+accelerate
\ No newline at end of file
diff --git a/backend/python/bark/requirements-hipblas.txt b/backend/python/bark/requirements-hipblas.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4e1fef6cfaa6ab4a0d408c619bdf6182b1bcf6f6
--- /dev/null
+++ b/backend/python/bark/requirements-hipblas.txt
@@ -0,0 +1,5 @@
+--extra-index-url https://download.pytorch.org/whl/rocm6.4
+torch==2.8.0+rocm6.4
+torchaudio==2.8.0+rocm6.4
+transformers
+accelerate
\ No newline at end of file
diff --git a/backend/python/bark/requirements-intel.txt b/backend/python/bark/requirements-intel.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ee3c20240a5ffbb2f6988460e75c905ac59794eb
--- /dev/null
+++ b/backend/python/bark/requirements-intel.txt
@@ -0,0 +1,9 @@
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+intel-extension-for-pytorch==2.8.10+xpu
+torch==2.3.1+cxx11.abi
+torchaudio==2.3.1+cxx11.abi
+oneccl_bind_pt==2.3.100+xpu
+optimum[openvino]
+setuptools
+transformers
+accelerate
\ No newline at end of file
diff --git a/backend/python/bark/requirements.txt b/backend/python/bark/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..275e0d8bc3e556536ff28afab589d6dfc48d4773
--- /dev/null
+++ b/backend/python/bark/requirements.txt
@@ -0,0 +1,4 @@
+bark==0.1.5
+grpcio==1.76.0
+protobuf
+certifi
\ No newline at end of file
diff --git a/backend/python/bark/run.sh b/backend/python/bark/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..82b7b09ecc7d20cb3513bdd1b18f3524f7d4cbd7
--- /dev/null
+++ b/backend/python/bark/run.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/bark/test.py b/backend/python/bark/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c9f3cf6b0ad36d54ac29ba34da0c84e37c98d2d
--- /dev/null
+++ b/backend/python/bark/test.py
@@ -0,0 +1,81 @@
+"""
+A test script to test the gRPC service
+"""
+import unittest
+import subprocess
+import time
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service
+ """
+ def setUp(self):
+ """
+ This method sets up the gRPC service by starting the server
+ """
+ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
+ time.sleep(10)
+
+ def tearDown(self) -> None:
+ """
+ This method tears down the gRPC service by terminating the server
+ """
+ self.service.terminate()
+ self.service.wait()
+
+ def test_server_startup(self):
+ """
+ This method tests if the server starts up successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+
+ def test_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="v2/en_speaker_4"))
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_tts(self):
+ """
+ This method tests if the embeddings are generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="v2/en_speaker_4"))
+ self.assertTrue(response.success)
+ tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
+ tts_response = stub.TTS(tts_request)
+ self.assertIsNotNone(tts_response)
+ except Exception as err:
+ print(err)
+ self.fail("TTS service failed")
+ finally:
+ self.tearDown()
\ No newline at end of file
diff --git a/backend/python/bark/test.sh b/backend/python/bark/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eb59f2aaf3f38f78c7ca3dc414ea490ff66776d7
--- /dev/null
+++ b/backend/python/bark/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/chatterbox/Makefile b/backend/python/chatterbox/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..be9330f8eac904e5779547b0da15841985af71e6
--- /dev/null
+++ b/backend/python/chatterbox/Makefile
@@ -0,0 +1,23 @@
+.PHONY: chatterbox
+chatterbox:
+ bash install.sh
+
+.PHONY: run
+run: chatterbox
+ @echo "Running coqui..."
+ bash run.sh
+ @echo "coqui run."
+
+.PHONY: test
+test: chatterbox
+ @echo "Testing coqui..."
+ bash test.sh
+ @echo "coqui tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/chatterbox/backend.py b/backend/python/chatterbox/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..45fd177e24e0164715640b0eeeba3ea5992e4fc2
--- /dev/null
+++ b/backend/python/chatterbox/backend.py
@@ -0,0 +1,257 @@
+#!/usr/bin/env python3
+"""
+This is an extra gRPC server of LocalAI for Chatterbox TTS
+"""
+from concurrent import futures
+import time
+import argparse
+import signal
+import sys
+import os
+import backend_pb2
+import backend_pb2_grpc
+
+import torch
+import torchaudio as ta
+from chatterbox.tts import ChatterboxTTS
+from chatterbox.mtl_tts import ChatterboxMultilingualTTS
+import grpc
+import tempfile
+
+def is_float(s):
+ """Check if a string can be converted to float."""
+ try:
+ float(s)
+ return True
+ except ValueError:
+ return False
+def is_int(s):
+ """Check if a string can be converted to int."""
+ try:
+ int(s)
+ return True
+ except ValueError:
+ return False
+
+def split_text_at_word_boundary(text, max_length=250):
+ """
+ Split text at word boundaries without truncating words.
+ Returns a list of text chunks.
+ """
+ if not text or len(text) <= max_length:
+ return [text]
+
+ chunks = []
+ words = text.split()
+ current_chunk = ""
+
+ for word in words:
+ # Check if adding this word would exceed the limit
+ if len(current_chunk) + len(word) + 1 <= max_length:
+ if current_chunk:
+ current_chunk += " " + word
+ else:
+ current_chunk = word
+ else:
+ # If current chunk is not empty, add it to chunks
+ if current_chunk:
+ chunks.append(current_chunk)
+ current_chunk = word
+ else:
+ # If a single word is longer than max_length, we have to include it anyway
+ chunks.append(word)
+ current_chunk = ""
+
+ # Add the last chunk if it's not empty
+ if current_chunk:
+ chunks.append(current_chunk)
+
+ return chunks
+
+def merge_audio_files(audio_files, output_path, sample_rate):
+ """
+ Merge multiple audio files into a single audio file.
+ """
+ if not audio_files:
+ return
+
+ if len(audio_files) == 1:
+ # If only one file, just copy it
+ import shutil
+ shutil.copy2(audio_files[0], output_path)
+ return
+
+ # Load all audio files
+ waveforms = []
+ for audio_file in audio_files:
+ waveform, sr = ta.load(audio_file)
+ if sr != sample_rate:
+ # Resample if necessary
+ resampler = ta.transforms.Resample(sr, sample_rate)
+ waveform = resampler(waveform)
+ waveforms.append(waveform)
+
+ # Concatenate all waveforms
+ merged_waveform = torch.cat(waveforms, dim=1)
+
+ # Save the merged audio
+ ta.save(output_path, merged_waveform, sample_rate)
+
+ # Clean up temporary files
+ for audio_file in audio_files:
+ if os.path.exists(audio_file):
+ os.remove(audio_file)
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ BackendServicer is the class that implements the gRPC service
+ """
+ def Health(self, request, context):
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+ def LoadModel(self, request, context):
+
+ # Get device
+ # device = "cuda" if request.CUDA else "cpu"
+ if torch.cuda.is_available():
+ print("CUDA is available", file=sys.stderr)
+ device = "cuda"
+ else:
+ print("CUDA is not available", file=sys.stderr)
+ device = "cpu"
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+ if mps_available:
+ device = "mps"
+ if not torch.cuda.is_available() and request.CUDA:
+ return backend_pb2.Result(success=False, message="CUDA is not available")
+
+
+ options = request.Options
+
+ # empty dict
+ self.options = {}
+
+ # The options are a list of strings in this form optname:optvalue
+ # We are storing all the options in a dict so we can use it later when
+ # generating the images
+ for opt in options:
+ if ":" not in opt:
+ continue
+ key, value = opt.split(":")
+ # if value is a number, convert it to the appropriate type
+ if is_float(value):
+ value = float(value)
+ elif is_int(value):
+ value = int(value)
+ elif value.lower() in ["true", "false"]:
+ value = value.lower() == "true"
+ self.options[key] = value
+
+ self.AudioPath = None
+
+ if os.path.isabs(request.AudioPath):
+ self.AudioPath = request.AudioPath
+ elif request.AudioPath and request.ModelFile != "" and not os.path.isabs(request.AudioPath):
+ # get base path of modelFile
+ modelFileBase = os.path.dirname(request.ModelFile)
+ # modify LoraAdapter to be relative to modelFileBase
+ self.AudioPath = os.path.join(modelFileBase, request.AudioPath)
+ try:
+ print("Preparing models, please wait", file=sys.stderr)
+ if "multilingual" in self.options:
+ # remove key from options
+ del self.options["multilingual"]
+ self.model = ChatterboxMultilingualTTS.from_pretrained(device=device)
+ else:
+ self.model = ChatterboxTTS.from_pretrained(device=device)
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ # Implement your logic here for the LoadModel service
+ # Replace this with your desired response
+ return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+ def TTS(self, request, context):
+ try:
+ kwargs = {}
+
+ if "language" in self.options:
+ kwargs["language_id"] = self.options["language"]
+ if self.AudioPath is not None:
+ kwargs["audio_prompt_path"] = self.AudioPath
+
+ # add options to kwargs
+ kwargs.update(self.options)
+
+ # Check if text exceeds 250 characters
+ # (chatterbox does not support long text)
+ # https://github.com/resemble-ai/chatterbox/issues/60
+ # https://github.com/resemble-ai/chatterbox/issues/110
+ if len(request.text) > 250:
+ # Split text at word boundaries
+ text_chunks = split_text_at_word_boundary(request.text, max_length=250)
+ print(f"Splitting text into chunks of 250 characters: {len(text_chunks)}", file=sys.stderr)
+ # Generate audio for each chunk
+ temp_audio_files = []
+ for i, chunk in enumerate(text_chunks):
+ # Generate audio for this chunk
+ wav = self.model.generate(chunk, **kwargs)
+
+ # Create temporary file for this chunk
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
+ temp_file.close()
+ ta.save(temp_file.name, wav, self.model.sr)
+ temp_audio_files.append(temp_file.name)
+
+ # Merge all audio files
+ merge_audio_files(temp_audio_files, request.dst, self.model.sr)
+ else:
+ # Generate audio using ChatterboxTTS for short text
+ wav = self.model.generate(request.text, **kwargs)
+ # Save the generated audio
+ ta.save(request.dst, wav, self.model.sr)
+
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ return backend_pb2.Result(success=True)
+
+def serve(address):
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ server.add_insecure_port(address)
+ server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+
+ # Define the signal handler function
+ def signal_handler(sig, frame):
+ print("Received termination signal. Shutting down...")
+ server.stop(0)
+ sys.exit(0)
+
+ # Set the signal handlers for SIGINT and SIGTERM
+ signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, signal_handler)
+
+ try:
+ while True:
+ time.sleep(_ONE_DAY_IN_SECONDS)
+ except KeyboardInterrupt:
+ server.stop(0)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ serve(args.addr)
diff --git a/backend/python/chatterbox/install.sh b/backend/python/chatterbox/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..04d76fd5b924c11300b9eaf3023d80ecd83d30ce
--- /dev/null
+++ b/backend/python/chatterbox/install.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
+# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
+if [ "x${BUILD_PROFILE}" == "xintel" ]; then
+ EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+fi
+EXTRA_PIP_INSTALL_FLAGS+=" --no-build-isolation"
+
+if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
+ USE_PIP=true
+fi
+
+
+installRequirements
diff --git a/backend/python/chatterbox/requirements-cpu.txt b/backend/python/chatterbox/requirements-cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..df4814ac700b2537c5330fa673b96aa91318b061
--- /dev/null
+++ b/backend/python/chatterbox/requirements-cpu.txt
@@ -0,0 +1,9 @@
+--extra-index-url https://download.pytorch.org/whl/cpu
+accelerate
+torch
+torchaudio
+numpy>=1.24.0,<1.26.0
+transformers
+# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
+chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
+#chatterbox-tts==0.1.4
\ No newline at end of file
diff --git a/backend/python/chatterbox/requirements-cublas12.txt b/backend/python/chatterbox/requirements-cublas12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..70c46d2d5a92bcaa45a06993b3db8058cd1aa3a8
--- /dev/null
+++ b/backend/python/chatterbox/requirements-cublas12.txt
@@ -0,0 +1,7 @@
+torch
+torchaudio
+transformers
+numpy>=1.24.0,<1.26.0
+# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
+chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
+accelerate
diff --git a/backend/python/chatterbox/requirements-cublas13.txt b/backend/python/chatterbox/requirements-cublas13.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4ac324c9db73f121ca318d7a72541f4b4fec7445
--- /dev/null
+++ b/backend/python/chatterbox/requirements-cublas13.txt
@@ -0,0 +1,8 @@
+--extra-index-url https://download.pytorch.org/whl/cu130
+torch
+torchaudio
+transformers
+numpy>=1.24.0,<1.26.0
+# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
+chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
+accelerate
diff --git a/backend/python/chatterbox/requirements-hipblas.txt b/backend/python/chatterbox/requirements-hipblas.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ed30fb8241072ce46a27a13c2f8f5f08f6cdf15a
--- /dev/null
+++ b/backend/python/chatterbox/requirements-hipblas.txt
@@ -0,0 +1,8 @@
+--extra-index-url https://download.pytorch.org/whl/rocm6.4
+torch==2.9.1+rocm6.4
+torchaudio==2.9.1+rocm6.4
+transformers
+numpy>=1.24.0,<1.26.0
+# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
+chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
+accelerate
diff --git a/backend/python/chatterbox/requirements-install.txt b/backend/python/chatterbox/requirements-install.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a9ffcac6a60d6413cb5f4f5b41f2fa3107876662
--- /dev/null
+++ b/backend/python/chatterbox/requirements-install.txt
@@ -0,0 +1,5 @@
+# Build dependencies needed for packages installed from source (e.g., git dependencies)
+# When using --no-build-isolation, these must be installed in the venv first
+wheel
+setuptools
+packaging
diff --git a/backend/python/chatterbox/requirements-intel.txt b/backend/python/chatterbox/requirements-intel.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cb88cbc27093e120160ce2aa0ad21b1e2ab5abb5
--- /dev/null
+++ b/backend/python/chatterbox/requirements-intel.txt
@@ -0,0 +1,12 @@
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+intel-extension-for-pytorch==2.3.110+xpu
+torch==2.3.1+cxx11.abi
+torchaudio==2.3.1+cxx11.abi
+transformers
+numpy>=1.24.0,<1.26.0
+# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
+chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
+accelerate
+oneccl_bind_pt==2.3.100+xpu
+optimum[openvino]
+setuptools
\ No newline at end of file
diff --git a/backend/python/chatterbox/requirements-l4t12.txt b/backend/python/chatterbox/requirements-l4t12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e5cea23925b4f11cccd2bcfc8c78b7b66af6097b
--- /dev/null
+++ b/backend/python/chatterbox/requirements-l4t12.txt
@@ -0,0 +1,7 @@
+--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126/
+torch
+torchaudio
+transformers
+numpy>=1.24.0,<1.26.0
+chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
+accelerate
diff --git a/backend/python/chatterbox/requirements-l4t13.txt b/backend/python/chatterbox/requirements-l4t13.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0f6e3e7de94fbb771b7e8438999bf6e7f6e56ed1
--- /dev/null
+++ b/backend/python/chatterbox/requirements-l4t13.txt
@@ -0,0 +1,7 @@
+--extra-index-url https://download.pytorch.org/whl/cu130
+torch
+torchaudio
+transformers
+numpy>=1.24.0,<1.26.0
+chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
+accelerate
diff --git a/backend/python/chatterbox/requirements.txt b/backend/python/chatterbox/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..55a0867f0b0e053861e5d47a479f59c7f417b8bb
--- /dev/null
+++ b/backend/python/chatterbox/requirements.txt
@@ -0,0 +1,6 @@
+grpcio==1.71.0
+protobuf
+certifi
+packaging
+setuptools
+poetry
\ No newline at end of file
diff --git a/backend/python/chatterbox/run.sh b/backend/python/chatterbox/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..82b7b09ecc7d20cb3513bdd1b18f3524f7d4cbd7
--- /dev/null
+++ b/backend/python/chatterbox/run.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/chatterbox/test.py b/backend/python/chatterbox/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..878345ab64b4bc87dacbdfbeb73528d38d0f893f
--- /dev/null
+++ b/backend/python/chatterbox/test.py
@@ -0,0 +1,82 @@
+"""
+A test script to test the gRPC service
+"""
+import unittest
+import subprocess
+import time
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service
+ """
+ def setUp(self):
+ """
+ This method sets up the gRPC service by starting the server
+ """
+ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
+ time.sleep(30)
+
+ def tearDown(self) -> None:
+ """
+ This method tears down the gRPC service by terminating the server
+ """
+ self.service.terminate()
+ self.service.wait()
+
+ def test_server_startup(self):
+ """
+ This method tests if the server starts up successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+
+ def test_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions())
+ print(response)
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_tts(self):
+ """
+ This method tests if the embeddings are generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions())
+ self.assertTrue(response.success)
+ tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
+ tts_response = stub.TTS(tts_request)
+ self.assertIsNotNone(tts_response)
+ except Exception as err:
+ print(err)
+ self.fail("TTS service failed")
+ finally:
+ self.tearDown()
\ No newline at end of file
diff --git a/backend/python/chatterbox/test.sh b/backend/python/chatterbox/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eb59f2aaf3f38f78c7ca3dc414ea490ff66776d7
--- /dev/null
+++ b/backend/python/chatterbox/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/common/libbackend.sh b/backend/python/common/libbackend.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7956b3c10a5ad3760d2adbc059b9715c06c292a8
--- /dev/null
+++ b/backend/python/common/libbackend.sh
@@ -0,0 +1,535 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+#
+# use the library by adding the following line to a script:
+# source $(dirname $0)/../common/libbackend.sh
+#
+# If you want to limit what targets a backend can be used on, set the variable LIMIT_TARGETS to a
+# space separated list of valid targets BEFORE sourcing the library, for example to only allow a backend
+# to be used on CUDA and CPU backends:
+#
+# LIMIT_TARGETS="cublas cpu"
+# source $(dirname $0)/../common/libbackend.sh
+#
+# You can use any valid BUILD_TYPE or BUILD_PROFILE, if you need to limit a backend to CUDA 12 only:
+#
+# LIMIT_TARGETS="cublas12"
+# source $(dirname $0)/../common/libbackend.sh
+#
+# You can switch between uv (conda-like) and pip installation methods by setting USE_PIP:
+# USE_PIP=true source $(dirname $0)/../common/libbackend.sh
+#
+# ===================== user-configurable defaults =====================
+PYTHON_VERSION="${PYTHON_VERSION:-3.10}" # e.g. 3.10 / 3.11 / 3.12 / 3.13
+PYTHON_PATCH="${PYTHON_PATCH:-18}" # e.g. 18 -> 3.10.18 ; 13 -> 3.11.13
+PY_STANDALONE_TAG="${PY_STANDALONE_TAG:-20250818}" # release tag date
+# Enable/disable bundling of a portable Python build
+PORTABLE_PYTHON="${PORTABLE_PYTHON:-false}"
+
+# If you want to fully pin the filename (including tuned CPU targets), set:
+# PORTABLE_PY_FILENAME="cpython-3.10.18+20250818-x86_64_v3-unknown-linux-gnu-install_only.tar.gz"
+: "${PORTABLE_PY_FILENAME:=}"
+: "${PORTABLE_PY_SHA256:=}" # optional; if set we verify the download
+# =====================================================================
+
+# Default to uv if USE_PIP is not set
+if [ "x${USE_PIP:-}" == "x" ]; then
+ USE_PIP=false
+fi
+
+# ----------------------- helpers -----------------------
+function _is_musl() {
+ # detect musl (Alpine, etc)
+ if command -v ldd >/dev/null 2>&1; then
+ ldd --version 2>&1 | grep -qi musl && return 0
+ fi
+ # busybox-ish fallback
+ if command -v getconf >/dev/null 2>&1; then
+ getconf GNU_LIBC_VERSION >/dev/null 2>&1 || return 0
+ fi
+ return 1
+}
+
+function _triple() {
+ local os="" arch="" libc="gnu"
+ case "$(uname -s)" in
+ Linux*) os="unknown-linux" ;;
+ Darwin*) os="apple-darwin" ;;
+ MINGW*|MSYS*|CYGWIN*) os="pc-windows-msvc" ;; # best-effort for Git Bash
+ *) echo "Unsupported OS $(uname -s)"; exit 1;;
+ esac
+
+ case "$(uname -m)" in
+ x86_64) arch="x86_64" ;;
+ aarch64|arm64) arch="aarch64" ;;
+ armv7l) arch="armv7" ;;
+ i686|i386) arch="i686" ;;
+ ppc64le) arch="ppc64le" ;;
+ s390x) arch="s390x" ;;
+ riscv64) arch="riscv64" ;;
+ *) echo "Unsupported arch $(uname -m)"; exit 1;;
+ esac
+
+ if [[ "$os" == "unknown-linux" ]]; then
+ if _is_musl; then
+ libc="musl"
+ else
+ libc="gnu"
+ fi
+ echo "${arch}-${os}-${libc}"
+ else
+ echo "${arch}-${os}"
+ fi
+}
+
+function _portable_dir() {
+ echo "${EDIR}/python"
+}
+
+function _portable_bin() {
+ # python-build-standalone puts python in ./bin
+ echo "$(_portable_dir)/bin"
+}
+
+function _portable_python() {
+ if [ -x "$(_portable_bin)/python3" ]; then
+ echo "$(_portable_bin)/python3"
+ else
+ echo "$(_portable_bin)/python"
+ fi
+}
+
+
+# macOS loader env for the portable CPython
+_macosPortableEnv() {
+ if [ "$(uname -s)" = "Darwin" ]; then
+ export DYLD_LIBRARY_PATH="$(_portable_dir)/lib${DYLD_LIBRARY_PATH:+:${DYLD_LIBRARY_PATH}}"
+ export DYLD_FALLBACK_LIBRARY_PATH="$(_portable_dir)/lib${DYLD_FALLBACK_LIBRARY_PATH:+:${DYLD_FALLBACK_LIBRARY_PATH}}"
+ fi
+}
+
+# Good hygiene on macOS for downloaded/extracted trees
+_unquarantinePortablePython() {
+ if [ "$(uname -s)" = "Darwin" ]; then
+ command -v xattr >/dev/null 2>&1 && xattr -dr com.apple.quarantine "$(_portable_dir)" || true
+ fi
+}
+
+# ------------------ ### PORTABLE PYTHON ------------------
+function ensurePortablePython() {
+ local pdir="$(_portable_dir)"
+ local pbin="$(_portable_bin)"
+ local pyexe
+
+ if [ -x "${pbin}/python3" ] || [ -x "${pbin}/python" ]; then
+ _macosPortableEnv
+ return 0
+ fi
+
+ mkdir -p "${pdir}"
+ local triple="$(_triple)"
+
+ local full_ver="${PYTHON_VERSION}.${PYTHON_PATCH}"
+ local fn=""
+ if [ -n "${PORTABLE_PY_FILENAME}" ]; then
+ fn="${PORTABLE_PY_FILENAME}"
+ else
+ # generic asset name: cpython-+--install_only.tar.gz
+ fn="cpython-${full_ver}+${PY_STANDALONE_TAG}-${triple}-install_only.tar.gz"
+ fi
+
+ local url="https://github.com/astral-sh/python-build-standalone/releases/download/${PY_STANDALONE_TAG}/${fn}"
+ local tmp="${pdir}/${fn}"
+ echo "Downloading portable Python: ${fn}"
+ # curl with retries; fall back to wget if needed
+ if command -v curl >/dev/null 2>&1; then
+ curl -L --fail --retry 3 --retry-delay 1 -o "${tmp}" "${url}"
+ else
+ wget -O "${tmp}" "${url}"
+ fi
+
+ if [ -n "${PORTABLE_PY_SHA256}" ]; then
+ echo "${PORTABLE_PY_SHA256} ${tmp}" | sha256sum -c -
+ fi
+
+ echo "Extracting ${fn} -> ${pdir}"
+ # always a .tar.gz (we purposely choose install_only)
+ tar -xzf "${tmp}" -C "${pdir}"
+ rm -f "${tmp}"
+
+ # Some archives nest a directory; if so, flatten to ${pdir}
+ # Find the first dir with a 'bin/python*'
+ local inner
+ inner="$(find "${pdir}" -type f -path "*/bin/python*" -maxdepth 3 2>/dev/null | head -n1 || true)"
+ if [ -n "${inner}" ]; then
+ local inner_root
+ inner_root="$(dirname "$(dirname "${inner}")")" # .../bin -> root
+ if [ "${inner_root}" != "${pdir}" ]; then
+ # move contents up one level
+ shopt -s dotglob
+ mv "${inner_root}/"* "${pdir}/"
+ rm -rf "${inner_root}"
+ shopt -u dotglob
+ fi
+ fi
+
+ _unquarantinePortablePython
+ _macosPortableEnv
+ # Make sure it's runnable
+ pyexe="$(_portable_python)"
+ "${pyexe}" -V
+}
+
+# init handles the setup of the library
+function init() {
+ BACKEND_NAME=${PWD##*/}
+ MY_DIR=$(realpath "$(dirname "$0")")
+ BUILD_PROFILE=$(getBuildProfile)
+
+ EDIR=${MY_DIR}
+ if [ "x${ENV_DIR:-}" != "x" ]; then
+ EDIR=${ENV_DIR}
+ fi
+
+ if [ ! -z "${LIMIT_TARGETS:-}" ]; then
+ isValidTarget=$(checkTargets ${LIMIT_TARGETS})
+ if [ ${isValidTarget} != true ]; then
+ echo "${BACKEND_NAME} can only be used on the following targets: ${LIMIT_TARGETS}"
+ exit 0
+ fi
+ fi
+
+ echo "Initializing libbackend for ${BACKEND_NAME}"
+}
+
+
+# getBuildProfile will inspect the system to determine which build profile is appropriate:
+# returns one of the following:
+# - cublas12
+# - cublas13
+# - hipblas
+# - intel
+function getBuildProfile() {
+ if [ x"${BUILD_TYPE:-}" == "xcublas" ] || [ x"${BUILD_TYPE:-}" == "xl4t" ]; then
+ if [ ! -z "${CUDA_MAJOR_VERSION:-}" ]; then
+ echo ${BUILD_TYPE}${CUDA_MAJOR_VERSION}
+ else
+ echo ${BUILD_TYPE}
+ fi
+ return 0
+ fi
+
+ if [ -d "/opt/intel" ]; then
+ echo "intel"
+ return 0
+ fi
+
+ if [ -n "${BUILD_TYPE:-}" ]; then
+ echo ${BUILD_TYPE}
+ return 0
+ fi
+
+ echo "cpu"
+}
+
+
+# Make the venv relocatable:
+# - rewrite venv/bin/python{,3} to relative symlinks into $(_portable_dir)
+# - normalize entrypoint shebangs to /usr/bin/env python3
+# - optionally update pyvenv.cfg to point to the portable Python directory (only at runtime)
+# Usage: _makeVenvPortable [--update-pyvenv-cfg]
+_makeVenvPortable() {
+ local update_pyvenv_cfg=false
+ if [ "${1:-}" = "--update-pyvenv-cfg" ]; then
+ update_pyvenv_cfg=true
+ fi
+
+ local venv_dir="${EDIR}/venv"
+ local vbin="${venv_dir}/bin"
+
+ [ -d "${vbin}" ] || return 0
+
+ # 1) Replace python symlinks with relative ones to ../../python/bin/python3
+ # (venv/bin -> venv -> EDIR -> python/bin)
+ local rel_py='../../python/bin/python3'
+
+ for name in python3 python; do
+ if [ -e "${vbin}/${name}" ] || [ -L "${vbin}/${name}" ]; then
+ rm -f "${vbin}/${name}"
+ fi
+ done
+ ln -s "${rel_py}" "${vbin}/python3"
+ ln -s "python3" "${vbin}/python"
+
+ # 2) Update pyvenv.cfg to point to the portable Python directory (only at runtime)
+ # Use absolute path resolved at runtime so it works when the venv is copied
+ if [ "$update_pyvenv_cfg" = "true" ]; then
+ local pyvenv_cfg="${venv_dir}/pyvenv.cfg"
+ if [ -f "${pyvenv_cfg}" ]; then
+ local portable_dir="$(_portable_dir)"
+ # Resolve to absolute path - this ensures it works when the backend is copied
+ # Only resolve if the directory exists (it should if ensurePortablePython was called)
+ if [ -d "${portable_dir}" ]; then
+ portable_dir="$(cd "${portable_dir}" && pwd)"
+ else
+ # Fallback to relative path if directory doesn't exist yet
+ portable_dir="../python"
+ fi
+ local sed_i=(sed -i)
+ # macOS/BSD sed needs a backup suffix; GNU sed doesn't. Make it portable:
+ if sed --version >/dev/null 2>&1; then
+ sed_i=(sed -i)
+ else
+ sed_i=(sed -i '')
+ fi
+ # Update the home field in pyvenv.cfg
+ # Handle both absolute paths (starting with /) and relative paths
+ if grep -q "^home = " "${pyvenv_cfg}"; then
+ "${sed_i[@]}" "s|^home = .*|home = ${portable_dir}|" "${pyvenv_cfg}"
+ else
+ # If home field doesn't exist, add it
+ echo "home = ${portable_dir}" >> "${pyvenv_cfg}"
+ fi
+ fi
+ fi
+
+ # 3) Rewrite shebangs of entry points to use env, so the venv is relocatable
+ # Only touch text files that start with #! and reference the current venv.
+ local ve_abs="${vbin}/python"
+ local sed_i=(sed -i)
+ # macOS/BSD sed needs a backup suffix; GNU sed doesn't. Make it portable:
+ if sed --version >/dev/null 2>&1; then
+ sed_i=(sed -i)
+ else
+ sed_i=(sed -i '')
+ fi
+
+ for f in "${vbin}"/*; do
+ [ -f "$f" ] || continue
+ # Fast path: check first two bytes (#!)
+ head -c2 "$f" 2>/dev/null | grep -q '^#!' || continue
+ # Only rewrite if the shebang mentions the (absolute) venv python
+ if head -n1 "$f" | grep -Fq "${ve_abs}"; then
+ "${sed_i[@]}" '1s|^#!.*$|#!/usr/bin/env python3|' "$f"
+ chmod +x "$f" 2>/dev/null || true
+ fi
+ done
+}
+
+
+# ensureVenv makes sure that the venv for the backend both exists, and is activated.
+#
+# This function is idempotent, so you can call it as many times as you want and it will
+# always result in an activated virtual environment
+function ensureVenv() {
+ local interpreter=""
+
+ if [ "x${PORTABLE_PYTHON}" == "xtrue" ] || [ -e "$(_portable_python)" ]; then
+ echo "Using portable Python"
+ ensurePortablePython
+ interpreter="$(_portable_python)"
+ else
+ # Prefer system python${PYTHON_VERSION}, else python3, else fall back to bundled
+ if command -v python${PYTHON_VERSION} >/dev/null 2>&1; then
+ interpreter="python${PYTHON_VERSION}"
+ elif command -v python3 >/dev/null 2>&1; then
+ interpreter="python3"
+ else
+ echo "No suitable system Python found, bootstrapping portable build..."
+ ensurePortablePython
+ interpreter="$(_portable_python)"
+ fi
+ fi
+
+ if [ ! -d "${EDIR}/venv" ]; then
+ if [ "x${USE_PIP}" == "xtrue" ]; then
+ "${interpreter}" -m venv --copies "${EDIR}/venv"
+ source "${EDIR}/venv/bin/activate"
+ "${interpreter}" -m pip install --upgrade pip
+ else
+ if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
+ uv venv --python "${interpreter}" "${EDIR}/venv"
+ else
+ uv venv --python "${PYTHON_VERSION}" "${EDIR}/venv"
+ fi
+ fi
+ if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
+ # During install, only update symlinks and shebangs, not pyvenv.cfg
+ _makeVenvPortable
+ fi
+ fi
+
+ # We call it here to make sure that when we source a venv we can still use python as expected
+ if [ -x "$(_portable_python)" ]; then
+ _macosPortableEnv
+ fi
+
+ if [ "x${VIRTUAL_ENV:-}" != "x${EDIR}/venv" ]; then
+ source "${EDIR}/venv/bin/activate"
+ fi
+}
+
+
+function runProtogen() {
+ ensureVenv
+ if [ "x${USE_PIP}" == "xtrue" ]; then
+ pip install grpcio-tools
+ else
+ uv pip install grpcio-tools
+ fi
+ pushd "${EDIR}" >/dev/null
+ # use the venv python (ensures correct interpreter & sys.path)
+ python -m grpc_tools.protoc -I../../ -I./ --python_out=. --grpc_python_out=. backend.proto
+ popd >/dev/null
+}
+
+
+# installRequirements looks for several requirements files and if they exist runs the install for them in order
+#
+# - requirements-install.txt
+# - requirements.txt
+# - requirements-${BUILD_TYPE}.txt
+# - requirements-${BUILD_PROFILE}.txt
+#
+# BUILD_PROFILE is a more specific version of BUILD_TYPE, ex: cuda-12 or cuda-13
+# it can also include some options that we do not have BUILD_TYPES for, ex: intel
+#
+# NOTE: for BUILD_PROFILE==intel, this function does NOT automatically use the Intel python package index.
+# you may want to add the following line to a requirements-intel.txt if you use one:
+#
+# --index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+#
+# If you need to add extra flags into the pip install command you can do so by setting the variable EXTRA_PIP_INSTALL_FLAGS
+# before calling installRequirements. For example:
+#
+# source $(dirname $0)/../common/libbackend.sh
+# EXTRA_PIP_INSTALL_FLAGS="--no-build-isolation"
+# installRequirements
+function installRequirements() {
+ ensureVenv
+ declare -a requirementFiles=(
+ "${EDIR}/requirements-install.txt"
+ "${EDIR}/requirements.txt"
+ "${EDIR}/requirements-${BUILD_TYPE:-}.txt"
+ )
+
+ if [ "x${BUILD_TYPE:-}" != "x${BUILD_PROFILE}" ]; then
+ requirementFiles+=("${EDIR}/requirements-${BUILD_PROFILE}.txt")
+ fi
+ if [ "x${BUILD_TYPE:-}" == "x" ]; then
+ requirementFiles+=("${EDIR}/requirements-cpu.txt")
+ fi
+ requirementFiles+=("${EDIR}/requirements-after.txt")
+ if [ "x${BUILD_TYPE:-}" != "x${BUILD_PROFILE}" ]; then
+ requirementFiles+=("${EDIR}/requirements-${BUILD_PROFILE}-after.txt")
+ fi
+
+ # This is needed to build wheels that e.g. depends on Python.h
+ if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
+ export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
+ fi
+
+ for reqFile in ${requirementFiles[@]}; do
+ if [ -f "${reqFile}" ]; then
+ echo "starting requirements install for ${reqFile}"
+ if [ "x${USE_PIP}" == "xtrue" ]; then
+ pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement "${reqFile}"
+ else
+ uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement "${reqFile}"
+ fi
+ echo "finished requirements install for ${reqFile}"
+ fi
+ done
+
+ runProtogen
+}
+
+# startBackend discovers and runs the backend GRPC server
+#
+# You can specify a specific backend file to execute by setting BACKEND_FILE before calling startBackend.
+# example:
+#
+# source ../common/libbackend.sh
+# BACKEND_FILE="${MY_DIR}/source/backend.py"
+# startBackend $@
+#
+# valid filenames for autodiscovered backend servers are:
+# - server.py
+# - backend.py
+# - ${BACKEND_NAME}.py
+function startBackend() {
+ ensureVenv
+ # Update pyvenv.cfg before running to ensure paths are correct for current location
+ # This is critical when the backend position is dynamic (e.g., copied from container)
+ if [ "x${PORTABLE_PYTHON}" == "xtrue" ] || [ -x "$(_portable_python)" ]; then
+ _makeVenvPortable --update-pyvenv-cfg
+ fi
+
+ # Set up GPU library paths if a lib directory exists
+ # This allows backends to include their own GPU libraries (CUDA, ROCm, etc.)
+ if [ -d "${EDIR}/lib" ]; then
+ export LD_LIBRARY_PATH="${EDIR}/lib:${LD_LIBRARY_PATH:-}"
+ echo "Added ${EDIR}/lib to LD_LIBRARY_PATH for GPU libraries"
+ fi
+
+ if [ ! -z "${BACKEND_FILE:-}" ]; then
+ exec "${EDIR}/venv/bin/python" "${BACKEND_FILE}" "$@"
+ elif [ -e "${MY_DIR}/server.py" ]; then
+ exec "${EDIR}/venv/bin/python" "${MY_DIR}/server.py" "$@"
+ elif [ -e "${MY_DIR}/backend.py" ]; then
+ exec "${EDIR}/venv/bin/python" "${MY_DIR}/backend.py" "$@"
+ elif [ -e "${MY_DIR}/${BACKEND_NAME}.py" ]; then
+ exec "${EDIR}/venv/bin/python" "${MY_DIR}/${BACKEND_NAME}.py" "$@"
+ fi
+}
+
+
+# runUnittests discovers and runs python unittests
+#
+# You can specify a specific test file to use by setting TEST_FILE before calling runUnittests.
+# example:
+#
+# source ../common/libbackend.sh
+# TEST_FILE="${MY_DIR}/source/test.py"
+# runUnittests $@
+#
+# be default a file named test.py in the backends directory will be used
+function runUnittests() {
+ ensureVenv
+ if [ ! -z "${TEST_FILE:-}" ]; then
+ testDir=$(dirname "$(realpath "${TEST_FILE}")")
+ testFile=$(basename "${TEST_FILE}")
+ pushd "${testDir}" >/dev/null
+ python -m unittest "${testFile}"
+ popd >/dev/null
+ elif [ -f "${MY_DIR}/test.py" ]; then
+ pushd "${MY_DIR}" >/dev/null
+ python -m unittest test.py
+ popd >/dev/null
+ else
+ echo "no tests defined for ${BACKEND_NAME}"
+ fi
+}
+
+
+##################################################################################
+# Below here are helper functions not intended to be used outside of the library #
+##################################################################################
+
+# checkTargets determines if the current BUILD_TYPE or BUILD_PROFILE is in a list of valid targets
+function checkTargets() {
+ targets=$@
+ declare -a targets=($targets)
+ for target in ${targets[@]}; do
+ if [ "x${BUILD_TYPE:-}" == "x${target}" ]; then
+ echo true; return 0
+ fi
+ if [ "x${BUILD_PROFILE}" == "x${target}" ]; then
+ echo true; return 0
+ fi
+ done
+ echo false
+}
+
+init
diff --git a/backend/python/common/template/Makefile b/backend/python/common/template/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..f6b9ddc6c888845a9b20c98d3ef8bfae3629a1cd
--- /dev/null
+++ b/backend/python/common/template/Makefile
@@ -0,0 +1,13 @@
+.DEFAULT_GOAL := install
+
+.PHONY: install
+install:
+ bash install.sh
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/common/template/backend.py b/backend/python/common/template/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..7592d3a5ade3f5da69d92d563a0b0cf012283a74
--- /dev/null
+++ b/backend/python/common/template/backend.py
@@ -0,0 +1,4 @@
+#!/usr/bin/env python3
+import grpc
+import backend_pb2
+import backend_pb2_grpc
diff --git a/backend/python/common/template/install.sh b/backend/python/common/template/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..32befa8e6c034734e54038c3b4b565522028f391
--- /dev/null
+++ b/backend/python/common/template/install.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
+# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
+if [ "x${BUILD_PROFILE}" == "xintel" ]; then
+ EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+fi
+
+installRequirements
diff --git a/backend/python/common/template/protogen.sh b/backend/python/common/template/protogen.sh
new file mode 100644
index 0000000000000000000000000000000000000000..cba7791cbce3e87a4d6aae9f8399013cca2a447b
--- /dev/null
+++ b/backend/python/common/template/protogen.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runProtogen
\ No newline at end of file
diff --git a/backend/python/common/template/requirements-hipblas.txt b/backend/python/common/template/requirements-hipblas.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b733ec7b148b6ba310daebd51c1ad3a3527bd50a
--- /dev/null
+++ b/backend/python/common/template/requirements-hipblas.txt
@@ -0,0 +1,2 @@
+--extra-index-url https://download.pytorch.org/whl/rocm6.4
+torch
\ No newline at end of file
diff --git a/backend/python/common/template/requirements-intel.txt b/backend/python/common/template/requirements-intel.txt
new file mode 100644
index 0000000000000000000000000000000000000000..53393f6a284bdf3afa0f501258b2e3357236de11
--- /dev/null
+++ b/backend/python/common/template/requirements-intel.txt
@@ -0,0 +1,5 @@
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+intel-extension-for-pytorch==2.8.10+xpu
+torch==2.8.0
+oneccl_bind_pt==2.8.0+xpu
+optimum[openvino]
\ No newline at end of file
diff --git a/backend/python/common/template/requirements.txt b/backend/python/common/template/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..53bbcb4220436f0fb00d8ddf33c8f19d7a690d00
--- /dev/null
+++ b/backend/python/common/template/requirements.txt
@@ -0,0 +1,3 @@
+grpcio==1.76.0
+protobuf
+grpcio-tools
\ No newline at end of file
diff --git a/backend/python/common/template/run.sh b/backend/python/common/template/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..82b7b09ecc7d20cb3513bdd1b18f3524f7d4cbd7
--- /dev/null
+++ b/backend/python/common/template/run.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/common/template/test.sh b/backend/python/common/template/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eb59f2aaf3f38f78c7ca3dc414ea490ff66776d7
--- /dev/null
+++ b/backend/python/common/template/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/coqui/Makefile b/backend/python/coqui/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..6915b0f9f8961eada07045ca74aaae51ce15bd89
--- /dev/null
+++ b/backend/python/coqui/Makefile
@@ -0,0 +1,23 @@
+.PHONY: coqui
+coqui:
+ bash install.sh
+
+.PHONY: run
+run: coqui
+ @echo "Running coqui..."
+ bash run.sh
+ @echo "coqui run."
+
+.PHONY: test
+test: coqui
+ @echo "Testing coqui..."
+ bash test.sh
+ @echo "coqui tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/coqui/README.md b/backend/python/coqui/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e9c1931bb9cd74ca3e597dec96bd40e33acd96a8
--- /dev/null
+++ b/backend/python/coqui/README.md
@@ -0,0 +1,11 @@
+# Creating a separate environment for ttsbark project
+
+```
+make coqui
+```
+
+# Testing the gRPC server
+
+```
+make test
+```
\ No newline at end of file
diff --git a/backend/python/coqui/backend.py b/backend/python/coqui/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..df115adb503004e4f37cf52642016b0e64be4d17
--- /dev/null
+++ b/backend/python/coqui/backend.py
@@ -0,0 +1,125 @@
+#!/usr/bin/env python3
+"""
+This is an extra gRPC server of LocalAI for Bark TTS
+"""
+from concurrent import futures
+import time
+import argparse
+import signal
+import sys
+import os
+import backend_pb2
+import backend_pb2_grpc
+
+import torch
+from TTS.api import TTS
+
+import grpc
+
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+COQUI_LANGUAGE = os.environ.get('COQUI_LANGUAGE', None)
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ BackendServicer is the class that implements the gRPC service
+ """
+ def Health(self, request, context):
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+ def LoadModel(self, request, context):
+
+ # Get device
+ # device = "cuda" if request.CUDA else "cpu"
+ if torch.cuda.is_available():
+ print("CUDA is available", file=sys.stderr)
+ device = "cuda"
+ else:
+ print("CUDA is not available", file=sys.stderr)
+ device = "cpu"
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+ if mps_available:
+ device = "mps"
+ if not torch.cuda.is_available() and request.CUDA:
+ return backend_pb2.Result(success=False, message="CUDA is not available")
+
+ self.AudioPath = None
+ # List available 🐸TTS models
+ print(TTS().list_models())
+ if os.path.isabs(request.AudioPath):
+ self.AudioPath = request.AudioPath
+ elif request.AudioPath and request.ModelFile != "" and not os.path.isabs(request.AudioPath):
+ # get base path of modelFile
+ modelFileBase = os.path.dirname(request.ModelFile)
+ # modify LoraAdapter to be relative to modelFileBase
+ self.AudioPath = os.path.join(modelFileBase, request.AudioPath)
+
+ try:
+ print("Preparing models, please wait", file=sys.stderr)
+ self.tts = TTS(request.Model).to(device)
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ # Implement your logic here for the LoadModel service
+ # Replace this with your desired response
+ return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+ def TTS(self, request, context):
+ try:
+ # if model is multilingual add language from request or env as fallback
+ lang = request.language or COQUI_LANGUAGE
+ if lang == "":
+ lang = None
+ if self.tts.is_multi_lingual and lang is None:
+ return backend_pb2.Result(success=False, message=f"Model is multi-lingual, but no language was provided")
+
+ # if model is multi-speaker, use speaker_wav or the speaker_id from request.voice
+ if self.tts.is_multi_speaker and self.AudioPath is None and request.voice is None:
+ return backend_pb2.Result(success=False, message=f"Model is multi-speaker, but no speaker was provided")
+
+ if self.tts.is_multi_speaker and request.voice is not None:
+ self.tts.tts_to_file(text=request.text, speaker=request.voice, language=lang, file_path=request.dst)
+ else:
+ self.tts.tts_to_file(text=request.text, speaker_wav=self.AudioPath, language=lang, file_path=request.dst)
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ return backend_pb2.Result(success=True)
+
+def serve(address):
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ server.add_insecure_port(address)
+ server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+
+ # Define the signal handler function
+ def signal_handler(sig, frame):
+ print("Received termination signal. Shutting down...")
+ server.stop(0)
+ sys.exit(0)
+
+ # Set the signal handlers for SIGINT and SIGTERM
+ signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, signal_handler)
+
+ try:
+ while True:
+ time.sleep(_ONE_DAY_IN_SECONDS)
+ except KeyboardInterrupt:
+ server.stop(0)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ serve(args.addr)
diff --git a/backend/python/coqui/install.sh b/backend/python/coqui/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..32befa8e6c034734e54038c3b4b565522028f391
--- /dev/null
+++ b/backend/python/coqui/install.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
+# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
+if [ "x${BUILD_PROFILE}" == "xintel" ]; then
+ EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+fi
+
+installRequirements
diff --git a/backend/python/coqui/requirements-cpu.txt b/backend/python/coqui/requirements-cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..787877bd843965679b688a64f6b8a032531583c3
--- /dev/null
+++ b/backend/python/coqui/requirements-cpu.txt
@@ -0,0 +1,4 @@
+transformers==4.48.3
+accelerate
+torch==2.4.1
+coqui-tts
\ No newline at end of file
diff --git a/backend/python/coqui/requirements-cublas12.txt b/backend/python/coqui/requirements-cublas12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..53ed2ebc760caebddb51d473b087477303ea4f60
--- /dev/null
+++ b/backend/python/coqui/requirements-cublas12.txt
@@ -0,0 +1,5 @@
+torch==2.4.1
+torchaudio==2.4.1
+transformers==4.48.3
+accelerate
+coqui-tts
\ No newline at end of file
diff --git a/backend/python/coqui/requirements-hipblas.txt b/backend/python/coqui/requirements-hipblas.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8e7d034591e35961ee53835a5061cff1e199d49e
--- /dev/null
+++ b/backend/python/coqui/requirements-hipblas.txt
@@ -0,0 +1,6 @@
+--extra-index-url https://download.pytorch.org/whl/rocm6.4
+torch==2.8.0+rocm6.4
+torchaudio==2.8.0+rocm6.4
+transformers==4.48.3
+accelerate
+coqui-tts
\ No newline at end of file
diff --git a/backend/python/coqui/requirements-intel.txt b/backend/python/coqui/requirements-intel.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c45ce1660e68ec7c22921530bb739e02e01b29ec
--- /dev/null
+++ b/backend/python/coqui/requirements-intel.txt
@@ -0,0 +1,10 @@
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+intel-extension-for-pytorch==2.3.110+xpu
+torch==2.3.1+cxx11.abi
+torchaudio==2.3.1+cxx11.abi
+oneccl_bind_pt==2.3.100+xpu
+optimum[openvino]
+setuptools
+transformers==4.48.3
+accelerate
+coqui-tts
\ No newline at end of file
diff --git a/backend/python/coqui/requirements.txt b/backend/python/coqui/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..95dc03174857be9b250a3c279ba0583ed98572d0
--- /dev/null
+++ b/backend/python/coqui/requirements.txt
@@ -0,0 +1,4 @@
+grpcio==1.76.0
+protobuf
+certifi
+packaging==24.1
\ No newline at end of file
diff --git a/backend/python/coqui/run.sh b/backend/python/coqui/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..82b7b09ecc7d20cb3513bdd1b18f3524f7d4cbd7
--- /dev/null
+++ b/backend/python/coqui/run.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/coqui/test.py b/backend/python/coqui/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0b1a0bdd1240e1edfc9218c0c2ba032e5cf6300
--- /dev/null
+++ b/backend/python/coqui/test.py
@@ -0,0 +1,82 @@
+"""
+A test script to test the gRPC service
+"""
+import unittest
+import subprocess
+import time
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service
+ """
+ def setUp(self):
+ """
+ This method sets up the gRPC service by starting the server
+ """
+ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
+ time.sleep(30)
+
+ def tearDown(self) -> None:
+ """
+ This method tears down the gRPC service by terminating the server
+ """
+ self.service.terminate()
+ self.service.wait()
+
+ def test_server_startup(self):
+ """
+ This method tests if the server starts up successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+
+ def test_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits"))
+ print(response)
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_tts(self):
+ """
+ This method tests if the embeddings are generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits"))
+ self.assertTrue(response.success)
+ tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
+ tts_response = stub.TTS(tts_request)
+ self.assertIsNotNone(tts_response)
+ except Exception as err:
+ print(err)
+ self.fail("TTS service failed")
+ finally:
+ self.tearDown()
\ No newline at end of file
diff --git a/backend/python/coqui/test.sh b/backend/python/coqui/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eb59f2aaf3f38f78c7ca3dc414ea490ff66776d7
--- /dev/null
+++ b/backend/python/coqui/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/diffusers/Makefile b/backend/python/diffusers/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..f9ded4a1cff737f0a17a286942d522bb82257ef9
--- /dev/null
+++ b/backend/python/diffusers/Makefile
@@ -0,0 +1,33 @@
+export CONDA_ENV_PATH = "diffusers.yml"
+
+ifeq ($(BUILD_TYPE), hipblas)
+export CONDA_ENV_PATH = "diffusers-rocm.yml"
+endif
+
+# Intel GPU are supposed to have dependencies installed in the main python
+# environment, so we skip conda installation for SYCL builds.
+# https://github.com/intel/intel-extension-for-pytorch/issues/538
+ifneq (,$(findstring sycl,$(BUILD_TYPE)))
+export SKIP_CONDA=1
+endif
+
+.PHONY: diffusers
+diffusers:
+ bash install.sh
+
+.PHONY: run
+run: diffusers
+ @echo "Running diffusers..."
+ bash run.sh
+ @echo "Diffusers run."
+
+test: diffusers
+ bash test.sh
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/diffusers/README.md b/backend/python/diffusers/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..91fff31276942f431468079922da41e6701c9a03
--- /dev/null
+++ b/backend/python/diffusers/README.md
@@ -0,0 +1,136 @@
+# LocalAI Diffusers Backend
+
+This backend provides gRPC access to Hugging Face diffusers pipelines with dynamic pipeline loading.
+
+## Creating a separate environment for the diffusers project
+
+```
+make diffusers
+```
+
+## Dynamic Pipeline Loader
+
+The diffusers backend includes a dynamic pipeline loader (`diffusers_dynamic_loader.py`) that automatically discovers and loads diffusers pipelines at runtime. This eliminates the need for per-pipeline conditional statements - new pipelines added to diffusers become available automatically without code changes.
+
+### How It Works
+
+1. **Pipeline Discovery**: On first use, the loader scans the `diffusers` package to find all classes that inherit from `DiffusionPipeline`.
+
+2. **Registry Caching**: Discovery results are cached for the lifetime of the process to avoid repeated scanning.
+
+3. **Task Aliases**: The loader automatically derives task aliases from class names (e.g., "text-to-image", "image-to-image", "inpainting") without hardcoding.
+
+4. **Multiple Resolution Methods**: Pipelines can be resolved by:
+ - Exact class name (e.g., `StableDiffusionPipeline`)
+ - Task alias (e.g., `text-to-image`, `img2img`)
+ - Model ID (uses HuggingFace Hub to infer pipeline type)
+
+### Usage Examples
+
+```python
+from diffusers_dynamic_loader import (
+ load_diffusers_pipeline,
+ get_available_pipelines,
+ get_available_tasks,
+ resolve_pipeline_class,
+ discover_diffusers_classes,
+ get_available_classes,
+)
+
+# List all available pipelines
+pipelines = get_available_pipelines()
+print(f"Available pipelines: {pipelines[:10]}...")
+
+# List all task aliases
+tasks = get_available_tasks()
+print(f"Available tasks: {tasks}")
+
+# Resolve a pipeline class by name
+cls = resolve_pipeline_class(class_name="StableDiffusionPipeline")
+
+# Resolve by task alias
+cls = resolve_pipeline_class(task="stable-diffusion")
+
+# Load and instantiate a pipeline
+pipe = load_diffusers_pipeline(
+ class_name="StableDiffusionPipeline",
+ model_id="runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16
+)
+
+# Load from single file
+pipe = load_diffusers_pipeline(
+ class_name="StableDiffusionPipeline",
+ model_id="/path/to/model.safetensors",
+ from_single_file=True,
+ torch_dtype=torch.float16
+)
+
+# Discover other diffusers classes (schedulers, models, etc.)
+schedulers = discover_diffusers_classes("SchedulerMixin")
+print(f"Available schedulers: {list(schedulers.keys())[:5]}...")
+
+# Get list of available scheduler classes
+scheduler_list = get_available_classes("SchedulerMixin")
+```
+
+### Generic Class Discovery
+
+The dynamic loader can discover not just pipelines but any class type from diffusers:
+
+```python
+# Discover all scheduler classes
+schedulers = discover_diffusers_classes("SchedulerMixin")
+
+# Discover all model classes
+models = discover_diffusers_classes("ModelMixin")
+
+# Get a sorted list of available classes
+scheduler_names = get_available_classes("SchedulerMixin")
+```
+
+### Special Pipeline Handling
+
+Most pipelines are loaded dynamically through `load_diffusers_pipeline()`. Only pipelines requiring truly custom initialization logic are handled explicitly:
+
+- `FluxTransformer2DModel`: Requires quantization and custom transformer loading (cannot use dynamic loader)
+- `WanPipeline` / `WanImageToVideoPipeline`: Uses dynamic loader with special VAE (float32 dtype)
+- `SanaPipeline`: Uses dynamic loader with post-load dtype conversion for VAE/text encoder
+- `StableVideoDiffusionPipeline`: Uses dynamic loader with CPU offload handling
+- `VideoDiffusionPipeline`: Alias for DiffusionPipeline with video flags
+
+All other pipelines (StableDiffusionPipeline, StableDiffusionXLPipeline, FluxPipeline, etc.) are loaded purely through the dynamic loader.
+
+### Error Handling
+
+When a pipeline cannot be resolved, the loader provides helpful error messages listing available pipelines and tasks:
+
+```
+ValueError: Unknown pipeline class 'NonExistentPipeline'.
+Available pipelines: AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline, ...
+```
+
+## Environment Variables
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `COMPEL` | `0` | Enable Compel for prompt weighting |
+| `XPU` | `0` | Enable Intel XPU support |
+| `CLIPSKIP` | `1` | Enable CLIP skip support |
+| `SAFETENSORS` | `1` | Use safetensors format |
+| `CHUNK_SIZE` | `8` | Decode chunk size for video |
+| `FPS` | `7` | Video frames per second |
+| `DISABLE_CPU_OFFLOAD` | `0` | Disable CPU offload |
+| `FRAMES` | `64` | Number of video frames |
+| `BFL_REPO` | `ChuckMcSneed/FLUX.1-dev` | Flux base repo |
+| `PYTHON_GRPC_MAX_WORKERS` | `1` | Max gRPC workers |
+
+## Running Tests
+
+```bash
+./test.sh
+```
+
+The test suite includes:
+- Unit tests for the dynamic loader (`test_dynamic_loader.py`)
+- Integration tests for the gRPC backend (`test.py`)
\ No newline at end of file
diff --git a/backend/python/diffusers/backend.py b/backend/python/diffusers/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..f26a94b57b5fa2bc7c89ffcc549bcf1e5669d78e
--- /dev/null
+++ b/backend/python/diffusers/backend.py
@@ -0,0 +1,837 @@
+#!/usr/bin/env python3
+"""
+LocalAI Diffusers Backend
+
+This backend provides gRPC access to diffusers pipelines with dynamic pipeline loading.
+New pipelines added to diffusers become available automatically without code changes.
+"""
+from concurrent import futures
+import traceback
+import argparse
+from collections import defaultdict
+from enum import Enum
+import signal
+import sys
+import time
+import os
+
+from PIL import Image
+import torch
+
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+# Import dynamic loader for pipeline discovery
+from diffusers_dynamic_loader import (
+ get_pipeline_registry,
+ resolve_pipeline_class,
+ get_available_pipelines,
+ load_diffusers_pipeline,
+)
+
+# Import specific items still needed for special cases and safety checker
+from diffusers import DiffusionPipeline, ControlNetModel
+from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKLWan
+from diffusers.pipelines.stable_diffusion import safety_checker
+from diffusers.utils import load_image, export_to_video
+from compel import Compel, ReturnedEmbeddingsType
+from optimum.quanto import freeze, qfloat8, quantize
+from transformers import T5EncoderModel
+from safetensors.torch import load_file
+
+# Import LTX-2 specific utilities
+try:
+ from diffusers.pipelines.ltx2.export_utils import encode_video as ltx2_encode_video
+ LTX2_AVAILABLE = True
+except ImportError:
+ LTX2_AVAILABLE = False
+ ltx2_encode_video = None
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+COMPEL = os.environ.get("COMPEL", "0") == "1"
+XPU = os.environ.get("XPU", "0") == "1"
+CLIPSKIP = os.environ.get("CLIPSKIP", "1") == "1"
+SAFETENSORS = os.environ.get("SAFETENSORS", "1") == "1"
+CHUNK_SIZE = os.environ.get("CHUNK_SIZE", "8")
+FPS = os.environ.get("FPS", "7")
+DISABLE_CPU_OFFLOAD = os.environ.get("DISABLE_CPU_OFFLOAD", "0") == "1"
+FRAMES = os.environ.get("FRAMES", "64")
+
+if XPU:
+ print(torch.xpu.get_device_name(0))
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+
+# https://github.com/CompVis/stable-diffusion/issues/239#issuecomment-1627615287
+def sc(self, clip_input, images): return images, [False for i in images]
+
+
+# edit the StableDiffusionSafetyChecker class so that, when called, it just returns the images and an array of True values
+safety_checker.StableDiffusionSafetyChecker.forward = sc
+
+from diffusers.schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ DPMSolverSinglestepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ HeunDiscreteScheduler,
+ KDPM2AncestralDiscreteScheduler,
+ KDPM2DiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UniPCMultistepScheduler,
+)
+
+def is_float(s):
+ """Check if a string can be converted to float."""
+ try:
+ float(s)
+ return True
+ except ValueError:
+ return False
+def is_int(s):
+ """Check if a string can be converted to int."""
+ try:
+ int(s)
+ return True
+ except ValueError:
+ return False
+
+
+# The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
+# Credits to https://github.com/neggles
+# See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
+class DiffusionScheduler(str, Enum):
+ ddim = "ddim" # DDIM
+ pndm = "pndm" # PNDM
+ heun = "heun" # Heun
+ unipc = "unipc" # UniPC
+ euler = "euler" # Euler
+ euler_a = "euler_a" # Euler a
+
+ lms = "lms" # LMS
+ k_lms = "k_lms" # LMS Karras
+
+ dpm_2 = "dpm_2" # DPM2
+ k_dpm_2 = "k_dpm_2" # DPM2 Karras
+
+ dpm_2_a = "dpm_2_a" # DPM2 a
+ k_dpm_2_a = "k_dpm_2_a" # DPM2 a Karras
+
+ dpmpp_2m = "dpmpp_2m" # DPM++ 2M
+ k_dpmpp_2m = "k_dpmpp_2m" # DPM++ 2M Karras
+
+ dpmpp_sde = "dpmpp_sde" # DPM++ SDE
+ k_dpmpp_sde = "k_dpmpp_sde" # DPM++ SDE Karras
+
+ dpmpp_2m_sde = "dpmpp_2m_sde" # DPM++ 2M SDE
+ k_dpmpp_2m_sde = "k_dpmpp_2m_sde" # DPM++ 2M SDE Karras
+
+
+def get_scheduler(name: str, config: dict = {}):
+ is_karras = name.startswith("k_")
+ if is_karras:
+ # strip the k_ prefix and add the karras sigma flag to config
+ name = name.lstrip("k_")
+ config["use_karras_sigmas"] = True
+
+ if name == DiffusionScheduler.ddim:
+ sched_class = DDIMScheduler
+ elif name == DiffusionScheduler.pndm:
+ sched_class = PNDMScheduler
+ elif name == DiffusionScheduler.heun:
+ sched_class = HeunDiscreteScheduler
+ elif name == DiffusionScheduler.unipc:
+ sched_class = UniPCMultistepScheduler
+ elif name == DiffusionScheduler.euler:
+ sched_class = EulerDiscreteScheduler
+ elif name == DiffusionScheduler.euler_a:
+ sched_class = EulerAncestralDiscreteScheduler
+ elif name == DiffusionScheduler.lms:
+ sched_class = LMSDiscreteScheduler
+ elif name == DiffusionScheduler.dpm_2:
+ # Equivalent to DPM2 in K-Diffusion
+ sched_class = KDPM2DiscreteScheduler
+ elif name == DiffusionScheduler.dpm_2_a:
+ # Equivalent to `DPM2 a`` in K-Diffusion
+ sched_class = KDPM2AncestralDiscreteScheduler
+ elif name == DiffusionScheduler.dpmpp_2m:
+ # Equivalent to `DPM++ 2M` in K-Diffusion
+ sched_class = DPMSolverMultistepScheduler
+ config["algorithm_type"] = "dpmsolver++"
+ config["solver_order"] = 2
+ elif name == DiffusionScheduler.dpmpp_sde:
+ # Equivalent to `DPM++ SDE` in K-Diffusion
+ sched_class = DPMSolverSinglestepScheduler
+ elif name == DiffusionScheduler.dpmpp_2m_sde:
+ # Equivalent to `DPM++ 2M SDE` in K-Diffusion
+ sched_class = DPMSolverMultistepScheduler
+ config["algorithm_type"] = "sde-dpmsolver++"
+ else:
+ raise ValueError(f"Invalid scheduler '{'k_' if is_karras else ''}{name}'")
+
+ return sched_class.from_config(config)
+
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+
+ def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant):
+ """
+ Load a diffusers pipeline dynamically using the dynamic loader.
+
+ This method uses load_diffusers_pipeline() for most pipelines, falling back
+ to explicit handling only for pipelines requiring custom initialization
+ (e.g., quantization, special VAE handling).
+
+ Args:
+ request: The gRPC request containing pipeline configuration
+ modelFile: Path to the model file (for single file loading)
+ fromSingleFile: Whether to use from_single_file() vs from_pretrained()
+ torchType: The torch dtype to use
+ variant: Model variant (e.g., "fp16")
+
+ Returns:
+ The loaded pipeline instance
+ """
+ pipeline_type = request.PipelineType
+
+ # Handle IMG2IMG request flag with default pipeline
+ if request.IMG2IMG and pipeline_type == "":
+ pipeline_type = "StableDiffusionImg2ImgPipeline"
+
+ # ================================================================
+ # Special cases requiring custom initialization logic
+ # Only handle pipelines that truly need custom code (quantization,
+ # special VAE handling, etc.). All other pipelines use dynamic loading.
+ # ================================================================
+
+ # FluxTransformer2DModel - requires quantization and custom transformer loading
+ if pipeline_type == "FluxTransformer2DModel":
+ dtype = torch.bfloat16
+ bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev")
+
+ transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype)
+ quantize(transformer, weights=qfloat8)
+ freeze(transformer)
+ text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
+ quantize(text_encoder_2, weights=qfloat8)
+ freeze(text_encoder_2)
+
+ pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
+ pipe.transformer = transformer
+ pipe.text_encoder_2 = text_encoder_2
+
+ if request.LowVRAM:
+ pipe.enable_model_cpu_offload()
+ return pipe
+
+ # WanPipeline - requires special VAE with float32 dtype
+ if pipeline_type == "WanPipeline":
+ vae = AutoencoderKLWan.from_pretrained(
+ request.Model,
+ subfolder="vae",
+ torch_dtype=torch.float32
+ )
+ pipe = load_diffusers_pipeline(
+ class_name="WanPipeline",
+ model_id=request.Model,
+ vae=vae,
+ torch_dtype=torchType
+ )
+ self.txt2vid = True
+ return pipe
+
+ # WanImageToVideoPipeline - requires special VAE with float32 dtype
+ if pipeline_type == "WanImageToVideoPipeline":
+ vae = AutoencoderKLWan.from_pretrained(
+ request.Model,
+ subfolder="vae",
+ torch_dtype=torch.float32
+ )
+ pipe = load_diffusers_pipeline(
+ class_name="WanImageToVideoPipeline",
+ model_id=request.Model,
+ vae=vae,
+ torch_dtype=torchType
+ )
+ self.img2vid = True
+ return pipe
+
+ # SanaPipeline - requires special VAE and text encoder dtype conversion
+ if pipeline_type == "SanaPipeline":
+ pipe = load_diffusers_pipeline(
+ class_name="SanaPipeline",
+ model_id=request.Model,
+ variant="bf16",
+ torch_dtype=torch.bfloat16
+ )
+ pipe.vae.to(torch.bfloat16)
+ pipe.text_encoder.to(torch.bfloat16)
+ return pipe
+
+ # VideoDiffusionPipeline - alias for DiffusionPipeline with txt2vid flag
+ if pipeline_type == "VideoDiffusionPipeline":
+ self.txt2vid = True
+ pipe = load_diffusers_pipeline(
+ class_name="DiffusionPipeline",
+ model_id=request.Model,
+ torch_dtype=torchType
+ )
+ return pipe
+
+ # StableVideoDiffusionPipeline - needs img2vid flag and CPU offload
+ if pipeline_type == "StableVideoDiffusionPipeline":
+ self.img2vid = True
+ pipe = load_diffusers_pipeline(
+ class_name="StableVideoDiffusionPipeline",
+ model_id=request.Model,
+ torch_dtype=torchType,
+ variant=variant
+ )
+ if not DISABLE_CPU_OFFLOAD:
+ pipe.enable_model_cpu_offload()
+ return pipe
+
+ # LTX2ImageToVideoPipeline - needs img2vid flag, CPU offload, and special handling
+ if pipeline_type == "LTX2ImageToVideoPipeline":
+ self.img2vid = True
+ self.ltx2_pipeline = True
+ pipe = load_diffusers_pipeline(
+ class_name="LTX2ImageToVideoPipeline",
+ model_id=request.Model,
+ torch_dtype=torchType,
+ variant=variant
+ )
+ if not DISABLE_CPU_OFFLOAD:
+ pipe.enable_model_cpu_offload()
+ return pipe
+
+ # ================================================================
+ # Dynamic pipeline loading - the default path for most pipelines
+ # Uses the dynamic loader to instantiate any pipeline by class name
+ # ================================================================
+
+ # Build kwargs for dynamic loading
+ load_kwargs = {"torch_dtype": torchType}
+
+ # Add variant if not loading from single file
+ if not fromSingleFile and variant:
+ load_kwargs["variant"] = variant
+
+ # Add use_safetensors for from_pretrained
+ if not fromSingleFile:
+ load_kwargs["use_safetensors"] = SAFETENSORS
+
+ # Determine pipeline class name - default to AutoPipelineForText2Image
+ effective_pipeline_type = pipeline_type if pipeline_type else "AutoPipelineForText2Image"
+
+ # Use dynamic loader for all pipelines
+ try:
+ pipe = load_diffusers_pipeline(
+ class_name=effective_pipeline_type,
+ model_id=modelFile if fromSingleFile else request.Model,
+ from_single_file=fromSingleFile,
+ **load_kwargs
+ )
+ except Exception as e:
+ # Provide helpful error with available pipelines
+ available = get_available_pipelines()
+ raise ValueError(
+ f"Failed to load pipeline '{effective_pipeline_type}': {e}\n"
+ f"Available pipelines: {', '.join(available[:30])}..."
+ ) from e
+
+ # Apply LowVRAM optimization if supported and requested
+ if request.LowVRAM and hasattr(pipe, 'enable_model_cpu_offload'):
+ pipe.enable_model_cpu_offload()
+
+ return pipe
+
+ def Health(self, request, context):
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+ def LoadModel(self, request, context):
+ try:
+ print(f"Loading model {request.Model}...", file=sys.stderr)
+ print(f"Request {request}", file=sys.stderr)
+ torchType = torch.float32
+ variant = None
+
+ if request.F16Memory:
+ torchType = torch.float16
+ variant = "fp16"
+
+ options = request.Options
+
+ # empty dict
+ self.options = {}
+
+ # The options are a list of strings in this form optname:optvalue
+ # We are storing all the options in a dict so we can use it later when
+ # generating the images
+ for opt in options:
+ if ":" not in opt:
+ continue
+ key, value = opt.split(":")
+ # if value is a number, convert it to the appropriate type
+ if is_float(value):
+ value = float(value)
+ elif is_int(value):
+ value = int(value)
+ elif value.lower() in ["true", "false"]:
+ value = value.lower() == "true"
+ self.options[key] = value
+
+ # From options, extract if present "torch_dtype" and set it to the appropriate type
+ if "torch_dtype" in self.options:
+ if self.options["torch_dtype"] == "fp16":
+ torchType = torch.float16
+ elif self.options["torch_dtype"] == "bf16":
+ torchType = torch.bfloat16
+ elif self.options["torch_dtype"] == "fp32":
+ torchType = torch.float32
+ # remove it from options
+ del self.options["torch_dtype"]
+
+ print(f"Options: {self.options}", file=sys.stderr)
+
+ local = False
+ modelFile = request.Model
+
+ self.cfg_scale = 7
+ self.PipelineType = request.PipelineType
+
+ if request.CFGScale != 0:
+ self.cfg_scale = request.CFGScale
+
+ clipmodel = "Lykon/dreamshaper-8"
+ if request.CLIPModel != "":
+ clipmodel = request.CLIPModel
+ clipsubfolder = "text_encoder"
+ if request.CLIPSubfolder != "":
+ clipsubfolder = request.CLIPSubfolder
+
+ # Check if ModelFile exists
+ if request.ModelFile != "":
+ if os.path.exists(request.ModelFile):
+ local = True
+ modelFile = request.ModelFile
+
+ fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
+ self.img2vid = False
+ self.txt2vid = False
+ self.ltx2_pipeline = False
+
+ # Load pipeline using dynamic loader
+ # Special cases that require custom initialization are handled first
+ self.pipe = self._load_pipeline(
+ request=request,
+ modelFile=modelFile,
+ fromSingleFile=fromSingleFile,
+ torchType=torchType,
+ variant=variant
+ )
+
+ if CLIPSKIP and request.CLIPSkip != 0:
+ self.clip_skip = request.CLIPSkip
+ else:
+ self.clip_skip = 0
+
+ # torch_dtype needs to be customized. float16 for GPU, float32 for CPU
+ # TODO: this needs to be customized
+ if request.SchedulerType != "":
+ self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
+
+ if COMPEL:
+ self.compel = Compel(
+ tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
+ text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+ requires_pooled=[False, True]
+ )
+
+ if request.ControlNet:
+ self.controlnet = ControlNetModel.from_pretrained(
+ request.ControlNet, torch_dtype=torchType, variant=variant
+ )
+ self.pipe.controlnet = self.controlnet
+ else:
+ self.controlnet = None
+
+ if request.LoraAdapter and not os.path.isabs(request.LoraAdapter):
+ # modify LoraAdapter to be relative to modelFileBase
+ request.LoraAdapter = os.path.join(request.ModelPath, request.LoraAdapter)
+
+ device = "cpu" if not request.CUDA else "cuda"
+ if XPU:
+ device = "xpu"
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+ if mps_available:
+ device = "mps"
+ self.device = device
+ if request.LoraAdapter:
+ # Check if its a local file and not a directory ( we load lora differently for a safetensor file )
+ if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter):
+ self.pipe.load_lora_weights(request.LoraAdapter)
+ else:
+ self.pipe.unet.load_attn_procs(request.LoraAdapter)
+ if len(request.LoraAdapters) > 0:
+ i = 0
+ adapters_name = []
+ adapters_weights = []
+ for adapter in request.LoraAdapters:
+ if not os.path.isabs(adapter):
+ adapter = os.path.join(request.ModelPath, adapter)
+ self.pipe.load_lora_weights(adapter, adapter_name=f"adapter_{i}")
+ adapters_name.append(f"adapter_{i}")
+ i += 1
+
+ for adapters_weight in request.LoraScales:
+ adapters_weights.append(adapters_weight)
+
+ self.pipe.set_adapters(adapters_name, adapter_weights=adapters_weights)
+
+ if device != "cpu":
+ self.pipe.to(device)
+ if self.controlnet:
+ self.controlnet.to(device)
+
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ # Implement your logic here for the LoadModel service
+ # Replace this with your desired response
+ return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+ # https://github.com/huggingface/diffusers/issues/3064
+ def load_lora_weights(self, checkpoint_path, multiplier, device, dtype):
+ LORA_PREFIX_UNET = "lora_unet"
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
+ # load LoRA weight from .safetensors
+ state_dict = load_file(checkpoint_path, device=device)
+
+ updates = defaultdict(dict)
+ for key, value in state_dict.items():
+ # it is suggested to print out the key, it usually will be something like below
+ # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
+
+ layer, elem = key.split('.', 1)
+ updates[layer][elem] = value
+
+ # directly update weight in diffusers model
+ for layer, elems in updates.items():
+
+ if "text" in layer:
+ layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+ curr_layer = self.pipe.text_encoder
+ else:
+ layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
+ curr_layer = self.pipe.unet
+
+ # find the target layer
+ temp_name = layer_infos.pop(0)
+ while len(layer_infos) > -1:
+ try:
+ curr_layer = curr_layer.__getattr__(temp_name)
+ if len(layer_infos) > 0:
+ temp_name = layer_infos.pop(0)
+ elif len(layer_infos) == 0:
+ break
+ except Exception:
+ if len(temp_name) > 0:
+ temp_name += "_" + layer_infos.pop(0)
+ else:
+ temp_name = layer_infos.pop(0)
+
+ # get elements for this layer
+ weight_up = elems['lora_up.weight'].to(dtype)
+ weight_down = elems['lora_down.weight'].to(dtype)
+ alpha = elems['alpha'] if 'alpha' in elems else None
+ if alpha:
+ alpha = alpha.item() / weight_up.shape[1]
+ else:
+ alpha = 1.0
+
+ # update weight
+ if len(weight_up.shape) == 4:
+ curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ else:
+ curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
+
+ def GenerateImage(self, request, context):
+
+ prompt = request.positive_prompt
+
+ steps = 1
+
+ if request.step != 0:
+ steps = request.step
+
+ # create a dictionary of values for the parameters
+ options = {
+ "num_inference_steps": steps,
+ }
+
+ if hasattr(request, 'negative_prompt') and request.negative_prompt != "":
+ options["negative_prompt"] = request.negative_prompt
+
+ # Handle image source: prioritize RefImages over request.src
+ image_src = None
+ if hasattr(request, 'ref_images') and request.ref_images and len(request.ref_images) > 0:
+ # Use the first reference image if available
+ image_src = request.ref_images[0]
+ print(f"Using reference image: {image_src}", file=sys.stderr)
+ elif request.src != "":
+ # Fall back to request.src if no ref_images
+ image_src = request.src
+ print(f"Using source image: {image_src}", file=sys.stderr)
+ else:
+ print("No image source provided", file=sys.stderr)
+
+ if image_src and not self.controlnet and not self.img2vid:
+ image = Image.open(image_src)
+ options["image"] = image
+ elif self.controlnet and image_src:
+ pose_image = load_image(image_src)
+ options["image"] = pose_image
+
+ if CLIPSKIP and self.clip_skip != 0:
+ options["clip_skip"] = self.clip_skip
+
+ kwargs = {}
+
+ # populate kwargs from self.options.
+ kwargs.update(self.options)
+
+ # Set seed
+ if request.seed > 0:
+ kwargs["generator"] = torch.Generator(device=self.device).manual_seed(
+ request.seed
+ )
+
+ if self.PipelineType == "FluxPipeline":
+ kwargs["max_sequence_length"] = 256
+
+ if request.width:
+ kwargs["width"] = request.width
+
+ if request.height:
+ kwargs["height"] = request.height
+
+ if self.PipelineType == "FluxTransformer2DModel":
+ kwargs["output_type"] = "pil"
+ kwargs["generator"] = torch.Generator("cpu").manual_seed(0)
+
+ if self.img2vid:
+ # Load the conditioning image
+ if image_src:
+ image = load_image(image_src)
+ else:
+ # Fallback to request.src for img2vid if no ref_images
+ image = load_image(request.src)
+ image = image.resize((1024, 576))
+
+ generator = torch.manual_seed(request.seed)
+ frames = self.pipe(image, guidance_scale=self.cfg_scale, decode_chunk_size=CHUNK_SIZE, generator=generator).frames[0]
+ export_to_video(frames, request.dst, fps=FPS)
+ return backend_pb2.Result(message="Media generated successfully", success=True)
+
+ if self.txt2vid:
+ video_frames = self.pipe(prompt, guidance_scale=self.cfg_scale, num_inference_steps=steps, num_frames=int(FRAMES)).frames
+ export_to_video(video_frames, request.dst)
+ return backend_pb2.Result(message="Media generated successfully", success=True)
+
+ print(f"Generating image with {kwargs=}", file=sys.stderr)
+ image = {}
+ if COMPEL:
+ conditioning, pooled = self.compel.build_conditioning_tensor(prompt)
+ kwargs["prompt_embeds"] = conditioning
+ kwargs["pooled_prompt_embeds"] = pooled
+ # pass the kwargs dictionary to the self.pipe method
+ image = self.pipe(
+ guidance_scale=self.cfg_scale,
+ **kwargs
+ ).images[0]
+ else:
+ # pass the kwargs dictionary to the self.pipe method
+ image = self.pipe(
+ prompt,
+ guidance_scale=self.cfg_scale,
+ **kwargs
+ ).images[0]
+
+ # save the result
+ image.save(request.dst)
+
+ return backend_pb2.Result(message="Media generated", success=True)
+
+ def GenerateVideo(self, request, context):
+ try:
+ prompt = request.prompt
+ if not prompt:
+ return backend_pb2.Result(success=False, message="No prompt provided for video generation")
+
+ # Set default values from request or use defaults
+ num_frames = request.num_frames if request.num_frames > 0 else 81
+ fps = request.fps if request.fps > 0 else 16
+ cfg_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0
+ num_inference_steps = request.step if request.step > 0 else 40
+
+ # Prepare generation parameters
+ kwargs = {
+ "prompt": prompt,
+ "negative_prompt": request.negative_prompt if request.negative_prompt else "",
+ "height": request.height if request.height > 0 else 720,
+ "width": request.width if request.width > 0 else 1280,
+ "num_frames": num_frames,
+ "guidance_scale": cfg_scale,
+ "num_inference_steps": num_inference_steps,
+ }
+
+ # Add custom options from self.options (including guidance_scale_2 if specified)
+ kwargs.update(self.options)
+
+ # Set seed if provided
+ if request.seed > 0:
+ kwargs["generator"] = torch.Generator(device=self.device).manual_seed(request.seed)
+
+ # Handle start and end images for video generation
+ if request.start_image:
+ kwargs["start_image"] = load_image(request.start_image)
+ if request.end_image:
+ kwargs["end_image"] = load_image(request.end_image)
+
+ print(f"Generating video with {kwargs=}", file=sys.stderr)
+
+ # Generate video frames based on pipeline type
+ if self.ltx2_pipeline or self.PipelineType == "LTX2ImageToVideoPipeline":
+ # LTX-2 image-to-video generation with audio
+ if not LTX2_AVAILABLE:
+ return backend_pb2.Result(success=False, message="LTX-2 pipeline requires diffusers.pipelines.ltx2.export_utils")
+
+ # LTX-2 uses 'image' parameter instead of 'start_image'
+ if request.start_image:
+ image = load_image(request.start_image)
+ kwargs["image"] = image
+ # Remove start_image if it was added
+ kwargs.pop("start_image", None)
+
+ # LTX-2 uses 'frame_rate' instead of 'fps'
+ frame_rate = float(fps)
+ kwargs["frame_rate"] = frame_rate
+
+ # LTX-2 requires output_type="np" and return_dict=False
+ kwargs["output_type"] = "np"
+ kwargs["return_dict"] = False
+
+ # Generate video and audio
+ video, audio = self.pipe(**kwargs)
+
+ # Convert video to uint8 format
+ video = (video * 255).round().astype("uint8")
+ video = torch.from_numpy(video)
+
+ # Use LTX-2's encode_video function which handles audio
+ ltx2_encode_video(
+ video[0],
+ fps=frame_rate,
+ audio=audio[0].float().cpu(),
+ audio_sample_rate=self.pipe.vocoder.config.output_sampling_rate,
+ output_path=request.dst,
+ )
+
+ return backend_pb2.Result(message="Video generated successfully", success=True)
+ elif self.PipelineType == "WanPipeline":
+ # WAN2.2 text-to-video generation
+ output = self.pipe(**kwargs)
+ frames = output.frames[0] # WAN2.2 returns frames in this format
+ elif self.PipelineType == "WanImageToVideoPipeline":
+ # WAN2.2 image-to-video generation
+ if request.start_image:
+ # Load and resize the input image according to WAN2.2 requirements
+ image = load_image(request.start_image)
+ # Use request dimensions or defaults, but respect WAN2.2 constraints
+ request_height = request.height if request.height > 0 else 480
+ request_width = request.width if request.width > 0 else 832
+ max_area = request_height * request_width
+ aspect_ratio = image.height / image.width
+ mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1]
+ height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value
+ width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value
+ image = image.resize((width, height))
+ kwargs["image"] = image
+ kwargs["height"] = height
+ kwargs["width"] = width
+
+ output = self.pipe(**kwargs)
+ frames = output.frames[0]
+ elif self.img2vid:
+ # Generic image-to-video generation
+ if request.start_image:
+ image = load_image(request.start_image)
+ image = image.resize((request.width if request.width > 0 else 1024,
+ request.height if request.height > 0 else 576))
+ kwargs["image"] = image
+
+ output = self.pipe(**kwargs)
+ frames = output.frames[0]
+ elif self.txt2vid:
+ # Generic text-to-video generation
+ output = self.pipe(**kwargs)
+ frames = output.frames[0]
+ else:
+ return backend_pb2.Result(success=False, message=f"Pipeline {self.PipelineType} does not support video generation")
+
+ # Export video (for non-LTX-2 pipelines)
+ export_to_video(frames, request.dst, fps=fps)
+
+ return backend_pb2.Result(message="Video generated successfully", success=True)
+
+ except Exception as err:
+ print(f"Error generating video: {err}", file=sys.stderr)
+ traceback.print_exc()
+ return backend_pb2.Result(success=False, message=f"Error generating video: {err}")
+
+
+def serve(address):
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ server.add_insecure_port(address)
+ server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+
+ # Define the signal handler function
+ def signal_handler(sig, frame):
+ print("Received termination signal. Shutting down...")
+ server.stop(0)
+ sys.exit(0)
+
+ # Set the signal handlers for SIGINT and SIGTERM
+ signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, signal_handler)
+
+ try:
+ while True:
+ time.sleep(_ONE_DAY_IN_SECONDS)
+ except KeyboardInterrupt:
+ server.stop(0)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ serve(args.addr)
diff --git a/backend/python/diffusers/diffusers_dynamic_loader.py b/backend/python/diffusers/diffusers_dynamic_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..e47c7c2cf08b7bbb22357339c86dc699ff25edb2
--- /dev/null
+++ b/backend/python/diffusers/diffusers_dynamic_loader.py
@@ -0,0 +1,538 @@
+"""
+Dynamic Diffusers Pipeline Loader
+
+This module provides dynamic discovery and loading of diffusers pipelines at runtime,
+eliminating the need for per-pipeline conditional statements. New pipelines added to
+diffusers become available automatically without code changes.
+
+The module also supports discovering other diffusers classes like schedulers, models,
+and other components, making it a generic solution for dynamic class loading.
+
+Usage:
+ from diffusers_dynamic_loader import load_diffusers_pipeline, get_available_pipelines
+
+ # Load by class name
+ pipe = load_diffusers_pipeline(class_name="StableDiffusionPipeline", model_id="...", torch_dtype=torch.float16)
+
+ # Load by task alias
+ pipe = load_diffusers_pipeline(task="text-to-image", model_id="...", torch_dtype=torch.float16)
+
+ # Load using model_id (infers from HuggingFace Hub if possible)
+ pipe = load_diffusers_pipeline(model_id="runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+
+ # Get list of available pipelines
+ available = get_available_pipelines()
+
+ # Discover other diffusers classes (schedulers, models, etc.)
+ schedulers = discover_diffusers_classes("SchedulerMixin")
+ models = discover_diffusers_classes("ModelMixin")
+"""
+
+import importlib
+import re
+import sys
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+
+# Global cache for discovered pipelines - computed once per process
+_pipeline_registry: Optional[Dict[str, Type]] = None
+_task_aliases: Optional[Dict[str, List[str]]] = None
+
+# Global cache for other discovered class types
+_class_registries: Dict[str, Dict[str, Type]] = {}
+
+
+def _camel_to_kebab(name: str) -> str:
+ """
+ Convert CamelCase to kebab-case.
+
+ Examples:
+ StableDiffusionPipeline -> stable-diffusion-pipeline
+ StableDiffusionXLImg2ImgPipeline -> stable-diffusion-xl-img-2-img-pipeline
+ """
+ # Insert hyphen before uppercase letters (but not at the start)
+ s1 = re.sub('(.)([A-Z][a-z]+)', r'\1-\2', name)
+ # Insert hyphen before uppercase letters following lowercase letters or numbers
+ s2 = re.sub('([a-z0-9])([A-Z])', r'\1-\2', s1)
+ return s2.lower()
+
+
+def _extract_task_keywords(class_name: str) -> List[str]:
+ """
+ Extract task-related keywords from a pipeline class name.
+
+ This function derives useful task aliases from the class name without
+ hardcoding per-pipeline branches.
+
+ Returns a list of potential task aliases for this pipeline.
+ """
+ aliases = []
+ name_lower = class_name.lower()
+
+ # Direct task mappings based on common patterns in class names
+ task_patterns = {
+ 'text2image': ['text-to-image', 'txt2img', 'text2image'],
+ 'texttoimage': ['text-to-image', 'txt2img', 'text2image'],
+ 'txt2img': ['text-to-image', 'txt2img', 'text2image'],
+ 'img2img': ['image-to-image', 'img2img', 'image2image'],
+ 'image2image': ['image-to-image', 'img2img', 'image2image'],
+ 'imagetoimage': ['image-to-image', 'img2img', 'image2image'],
+ 'img2video': ['image-to-video', 'img2vid', 'img2video'],
+ 'imagetovideo': ['image-to-video', 'img2vid', 'img2video'],
+ 'text2video': ['text-to-video', 'txt2vid', 'text2video'],
+ 'texttovideo': ['text-to-video', 'txt2vid', 'text2video'],
+ 'inpaint': ['inpainting', 'inpaint'],
+ 'depth2img': ['depth-to-image', 'depth2img'],
+ 'depthtoimage': ['depth-to-image', 'depth2img'],
+ 'controlnet': ['controlnet', 'control-net'],
+ 'upscale': ['upscaling', 'upscale', 'super-resolution'],
+ 'superresolution': ['upscaling', 'upscale', 'super-resolution'],
+ }
+
+ # Check for each pattern in the class name
+ for pattern, task_aliases in task_patterns.items():
+ if pattern in name_lower:
+ aliases.extend(task_aliases)
+
+ # Also detect general pipeline types from the class name structure
+ # E.g., StableDiffusionPipeline -> stable-diffusion, flux -> flux
+ # Remove "Pipeline" suffix and convert to kebab case
+ if class_name.endswith('Pipeline'):
+ base_name = class_name[:-8] # Remove "Pipeline"
+ kebab_name = _camel_to_kebab(base_name)
+ aliases.append(kebab_name)
+
+ # Extract model family name (e.g., "stable-diffusion" from "stable-diffusion-xl-img-2-img")
+ parts = kebab_name.split('-')
+ if len(parts) >= 2:
+ # Try the first two words as a family name
+ family = '-'.join(parts[:2])
+ if family not in aliases:
+ aliases.append(family)
+
+ # If no specific task pattern matched but class contains "Pipeline", add "text-to-image" as default
+ # since most diffusion pipelines support text-to-image generation
+ if 'text-to-image' not in aliases and 'image-to-image' not in aliases:
+ # Only add for pipelines that seem to be generation pipelines (not schedulers, etc.)
+ if 'pipeline' in name_lower and not any(x in name_lower for x in ['scheduler', 'processor', 'encoder']):
+ # Don't automatically add - let it be explicit
+ pass
+
+ return list(set(aliases)) # Remove duplicates
+
+
+def discover_diffusers_classes(
+ base_class_name: str,
+ include_base: bool = True
+) -> Dict[str, Type]:
+ """
+ Discover all subclasses of a given base class from diffusers.
+
+ This function provides a generic way to discover any type of diffusers class,
+ not just pipelines. It can be used to discover schedulers, models, processors,
+ and other components.
+
+ Args:
+ base_class_name: Name of the base class to search for subclasses
+ (e.g., "DiffusionPipeline", "SchedulerMixin", "ModelMixin")
+ include_base: Whether to include the base class itself in results
+
+ Returns:
+ Dict mapping class names to class objects
+
+ Examples:
+ # Discover all pipeline classes
+ pipelines = discover_diffusers_classes("DiffusionPipeline")
+
+ # Discover all scheduler classes
+ schedulers = discover_diffusers_classes("SchedulerMixin")
+
+ # Discover all model classes
+ models = discover_diffusers_classes("ModelMixin")
+
+ # Discover AutoPipeline classes
+ auto_pipelines = discover_diffusers_classes("AutoPipelineForText2Image")
+ """
+ global _class_registries
+
+ # Check cache first
+ if base_class_name in _class_registries:
+ return _class_registries[base_class_name]
+
+ import diffusers
+
+ # Try to get the base class from diffusers
+ base_class = None
+ try:
+ base_class = getattr(diffusers, base_class_name)
+ except AttributeError:
+ # Try to find in submodules
+ for submodule in ['schedulers', 'models', 'pipelines']:
+ try:
+ module = importlib.import_module(f'diffusers.{submodule}')
+ if hasattr(module, base_class_name):
+ base_class = getattr(module, base_class_name)
+ break
+ except (ImportError, ModuleNotFoundError):
+ continue
+
+ if base_class is None:
+ raise ValueError(f"Could not find base class '{base_class_name}' in diffusers")
+
+ registry: Dict[str, Type] = {}
+
+ # Include base class if requested
+ if include_base:
+ registry[base_class_name] = base_class
+
+ # Scan diffusers module for subclasses
+ for attr_name in dir(diffusers):
+ try:
+ attr = getattr(diffusers, attr_name)
+ if (isinstance(attr, type) and
+ issubclass(attr, base_class) and
+ (include_base or attr is not base_class)):
+ registry[attr_name] = attr
+ except (ImportError, AttributeError, TypeError, RuntimeError, ModuleNotFoundError):
+ continue
+
+ # Cache the results
+ _class_registries[base_class_name] = registry
+ return registry
+
+
+def get_available_classes(base_class_name: str) -> List[str]:
+ """
+ Get a sorted list of all discovered class names for a given base class.
+
+ Args:
+ base_class_name: Name of the base class (e.g., "SchedulerMixin")
+
+ Returns:
+ Sorted list of discovered class names
+ """
+ return sorted(discover_diffusers_classes(base_class_name).keys())
+
+
+def _discover_pipelines() -> Tuple[Dict[str, Type], Dict[str, List[str]]]:
+ """
+ Discover all subclasses of DiffusionPipeline from diffusers.
+
+ This function uses the generic discover_diffusers_classes() internally
+ and adds pipeline-specific task alias generation. It also includes
+ AutoPipeline classes which are special utility classes for automatic
+ pipeline selection.
+
+ Returns:
+ A tuple of (pipeline_registry, task_aliases) where:
+ - pipeline_registry: Dict mapping class names to class objects
+ - task_aliases: Dict mapping task aliases to lists of class names
+ """
+ # Use the generic discovery function
+ pipeline_registry = discover_diffusers_classes("DiffusionPipeline", include_base=True)
+
+ # Also add AutoPipeline classes - these are special utility classes that are
+ # NOT subclasses of DiffusionPipeline but are commonly used
+ import diffusers
+ auto_pipeline_classes = [
+ "AutoPipelineForText2Image",
+ "AutoPipelineForImage2Image",
+ "AutoPipelineForInpainting",
+ ]
+ for cls_name in auto_pipeline_classes:
+ try:
+ cls = getattr(diffusers, cls_name)
+ if cls is not None:
+ pipeline_registry[cls_name] = cls
+ except AttributeError:
+ # Class not available in this version of diffusers
+ pass
+
+ # Generate task aliases for pipelines
+ task_aliases: Dict[str, List[str]] = {}
+ for attr_name in pipeline_registry:
+ if attr_name == "DiffusionPipeline":
+ continue # Skip base class for alias generation
+
+ aliases = _extract_task_keywords(attr_name)
+ for alias in aliases:
+ if alias not in task_aliases:
+ task_aliases[alias] = []
+ if attr_name not in task_aliases[alias]:
+ task_aliases[alias].append(attr_name)
+
+ return pipeline_registry, task_aliases
+
+
+def get_pipeline_registry() -> Dict[str, Type]:
+ """
+ Get the cached pipeline registry.
+
+ Returns a dictionary mapping pipeline class names to their class objects.
+ The registry is built on first access and cached for subsequent calls.
+ """
+ global _pipeline_registry, _task_aliases
+ if _pipeline_registry is None:
+ _pipeline_registry, _task_aliases = _discover_pipelines()
+ return _pipeline_registry
+
+
+def get_task_aliases() -> Dict[str, List[str]]:
+ """
+ Get the cached task aliases dictionary.
+
+ Returns a dictionary mapping task aliases (e.g., "text-to-image") to
+ lists of pipeline class names that support that task.
+ """
+ global _pipeline_registry, _task_aliases
+ if _task_aliases is None:
+ _pipeline_registry, _task_aliases = _discover_pipelines()
+ return _task_aliases
+
+
+def get_available_pipelines() -> List[str]:
+ """
+ Get a sorted list of all discovered pipeline class names.
+
+ Returns:
+ List of pipeline class names available for loading.
+ """
+ return sorted(get_pipeline_registry().keys())
+
+
+def get_available_tasks() -> List[str]:
+ """
+ Get a sorted list of all available task aliases.
+
+ Returns:
+ List of task aliases (e.g., ["text-to-image", "image-to-image", ...])
+ """
+ return sorted(get_task_aliases().keys())
+
+
+def resolve_pipeline_class(
+ class_name: Optional[str] = None,
+ task: Optional[str] = None,
+ model_id: Optional[str] = None
+) -> Type:
+ """
+ Resolve a pipeline class from class_name, task, or model_id.
+
+ Priority:
+ 1. If class_name is provided, look it up directly
+ 2. If task is provided, resolve through task aliases
+ 3. If model_id is provided, try to infer from HuggingFace Hub
+
+ Args:
+ class_name: Exact pipeline class name (e.g., "StableDiffusionPipeline")
+ task: Task alias (e.g., "text-to-image", "img2img")
+ model_id: HuggingFace model ID (e.g., "runwayml/stable-diffusion-v1-5")
+
+ Returns:
+ The resolved pipeline class.
+
+ Raises:
+ ValueError: If no pipeline could be resolved.
+ """
+ registry = get_pipeline_registry()
+ aliases = get_task_aliases()
+
+ # 1. Direct class name lookup
+ if class_name:
+ if class_name in registry:
+ return registry[class_name]
+ # Try case-insensitive match
+ for name, cls in registry.items():
+ if name.lower() == class_name.lower():
+ return cls
+ raise ValueError(
+ f"Unknown pipeline class '{class_name}'. "
+ f"Available pipelines: {', '.join(sorted(registry.keys())[:20])}..."
+ )
+
+ # 2. Task alias lookup
+ if task:
+ task_lower = task.lower().replace('_', '-')
+ if task_lower in aliases:
+ # Return the first matching pipeline for this task
+ matching_classes = aliases[task_lower]
+ if matching_classes:
+ return registry[matching_classes[0]]
+
+ # Try partial matching
+ for alias, classes in aliases.items():
+ if task_lower in alias or alias in task_lower:
+ if classes:
+ return registry[classes[0]]
+
+ raise ValueError(
+ f"Unknown task '{task}'. "
+ f"Available tasks: {', '.join(sorted(aliases.keys())[:20])}..."
+ )
+
+ # 3. Try to infer from HuggingFace Hub
+ if model_id:
+ try:
+ from huggingface_hub import model_info
+ info = model_info(model_id)
+
+ # Check pipeline_tag
+ if hasattr(info, 'pipeline_tag') and info.pipeline_tag:
+ tag = info.pipeline_tag.lower().replace('_', '-')
+ if tag in aliases:
+ matching_classes = aliases[tag]
+ if matching_classes:
+ return registry[matching_classes[0]]
+
+ # Check model card for hints
+ if hasattr(info, 'cardData') and info.cardData:
+ card = info.cardData
+ if 'pipeline_tag' in card:
+ tag = card['pipeline_tag'].lower().replace('_', '-')
+ if tag in aliases:
+ matching_classes = aliases[tag]
+ if matching_classes:
+ return registry[matching_classes[0]]
+
+ except ImportError:
+ # huggingface_hub not available
+ pass
+ except (KeyError, AttributeError, ValueError, OSError):
+ # Model info lookup failed - common cases:
+ # - KeyError: Missing keys in model card
+ # - AttributeError: Missing attributes on model info
+ # - ValueError: Invalid model data
+ # - OSError: Network or file access issues
+ pass
+
+ # Fallback: use DiffusionPipeline.from_pretrained which auto-detects
+ # DiffusionPipeline is always added to registry in _discover_pipelines (line 132)
+ # but use .get() with import fallback for extra safety
+ from diffusers import DiffusionPipeline
+ return registry.get('DiffusionPipeline', DiffusionPipeline)
+
+ raise ValueError(
+ "Must provide at least one of: class_name, task, or model_id. "
+ f"Available pipelines: {', '.join(sorted(registry.keys())[:20])}... "
+ f"Available tasks: {', '.join(sorted(aliases.keys())[:20])}..."
+ )
+
+
+def load_diffusers_pipeline(
+ class_name: Optional[str] = None,
+ task: Optional[str] = None,
+ model_id: Optional[str] = None,
+ from_single_file: bool = False,
+ **kwargs
+) -> Any:
+ """
+ Load a diffusers pipeline dynamically.
+
+ This function resolves the appropriate pipeline class based on the provided
+ parameters and instantiates it with the given kwargs.
+
+ Args:
+ class_name: Exact pipeline class name (e.g., "StableDiffusionPipeline")
+ task: Task alias (e.g., "text-to-image", "img2img")
+ model_id: HuggingFace model ID or local path
+ from_single_file: If True, use from_single_file() instead of from_pretrained()
+ **kwargs: Additional arguments passed to from_pretrained() or from_single_file()
+
+ Returns:
+ An instantiated pipeline object.
+
+ Raises:
+ ValueError: If no pipeline could be resolved.
+ Exception: If pipeline loading fails.
+
+ Examples:
+ # Load by class name
+ pipe = load_diffusers_pipeline(
+ class_name="StableDiffusionPipeline",
+ model_id="runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16
+ )
+
+ # Load by task
+ pipe = load_diffusers_pipeline(
+ task="text-to-image",
+ model_id="runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16
+ )
+
+ # Load from single file
+ pipe = load_diffusers_pipeline(
+ class_name="StableDiffusionPipeline",
+ model_id="/path/to/model.safetensors",
+ from_single_file=True,
+ torch_dtype=torch.float16
+ )
+ """
+ # Resolve the pipeline class
+ pipeline_class = resolve_pipeline_class(
+ class_name=class_name,
+ task=task,
+ model_id=model_id
+ )
+
+ # If no model_id provided but we have a class, we can't load
+ if model_id is None:
+ raise ValueError("model_id is required to load a pipeline")
+
+ # Load the pipeline
+ try:
+ if from_single_file:
+ # Check if the class has from_single_file method
+ if hasattr(pipeline_class, 'from_single_file'):
+ return pipeline_class.from_single_file(model_id, **kwargs)
+ else:
+ raise ValueError(
+ f"Pipeline class {pipeline_class.__name__} does not support from_single_file(). "
+ f"Use from_pretrained() instead."
+ )
+ else:
+ return pipeline_class.from_pretrained(model_id, **kwargs)
+
+ except Exception as e:
+ # Provide helpful error message
+ available = get_available_pipelines()
+ raise RuntimeError(
+ f"Failed to load pipeline '{pipeline_class.__name__}' from '{model_id}': {e}\n"
+ f"Available pipelines: {', '.join(available[:20])}..."
+ ) from e
+
+
+def get_pipeline_info(class_name: str) -> Dict[str, Any]:
+ """
+ Get information about a specific pipeline class.
+
+ Args:
+ class_name: The pipeline class name
+
+ Returns:
+ Dictionary with pipeline information including:
+ - name: Class name
+ - aliases: List of task aliases
+ - supports_single_file: Whether from_single_file() is available
+ - docstring: Class docstring (if available)
+ """
+ registry = get_pipeline_registry()
+ aliases = get_task_aliases()
+
+ if class_name not in registry:
+ raise ValueError(f"Unknown pipeline: {class_name}")
+
+ cls = registry[class_name]
+
+ # Find all aliases for this pipeline
+ pipeline_aliases = []
+ for alias, classes in aliases.items():
+ if class_name in classes:
+ pipeline_aliases.append(alias)
+
+ return {
+ 'name': class_name,
+ 'aliases': pipeline_aliases,
+ 'supports_single_file': hasattr(cls, 'from_single_file'),
+ 'docstring': cls.__doc__[:200] if cls.__doc__ else None
+ }
diff --git a/backend/python/diffusers/install.sh b/backend/python/diffusers/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..83703b1b2853d8d6422d0a7c3771e83a9e8f6bfe
--- /dev/null
+++ b/backend/python/diffusers/install.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
+# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
+if [ "x${BUILD_PROFILE}" == "xintel" ]; then
+ EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+fi
+
+if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
+ USE_PIP=true
+fi
+
+# Use python 3.12 for l4t
+if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
+ PYTHON_VERSION="3.12"
+ PYTHON_PATCH="12"
+ PY_STANDALONE_TAG="20251120"
+fi
+
+installRequirements
diff --git a/backend/python/diffusers/requirements-cpu.txt b/backend/python/diffusers/requirements-cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fceda06d2f03db19d13f5e743682c94a11e9b254
--- /dev/null
+++ b/backend/python/diffusers/requirements-cpu.txt
@@ -0,0 +1,12 @@
+--extra-index-url https://download.pytorch.org/whl/cpu
+git+https://github.com/huggingface/diffusers
+opencv-python
+transformers
+torchvision==0.22.1
+accelerate
+compel
+peft
+sentencepiece
+torch==2.7.1
+optimum-quanto
+ftfy
\ No newline at end of file
diff --git a/backend/python/diffusers/requirements-cublas12.txt b/backend/python/diffusers/requirements-cublas12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..632e9421f99cc58c369f1367ce62ea757af8da78
--- /dev/null
+++ b/backend/python/diffusers/requirements-cublas12.txt
@@ -0,0 +1,12 @@
+--extra-index-url https://download.pytorch.org/whl/cu121
+git+https://github.com/huggingface/diffusers
+opencv-python
+transformers
+torchvision
+accelerate
+compel
+peft
+sentencepiece
+torch
+ftfy
+optimum-quanto
diff --git a/backend/python/diffusers/requirements-cublas13.txt b/backend/python/diffusers/requirements-cublas13.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4867a85cd4053a8b06e2b21e3c3203e42d9be030
--- /dev/null
+++ b/backend/python/diffusers/requirements-cublas13.txt
@@ -0,0 +1,12 @@
+--extra-index-url https://download.pytorch.org/whl/cu130
+git+https://github.com/huggingface/diffusers
+opencv-python
+transformers
+torchvision
+accelerate
+compel
+peft
+sentencepiece
+torch
+ftfy
+optimum-quanto
diff --git a/backend/python/diffusers/requirements-hipblas.txt b/backend/python/diffusers/requirements-hipblas.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b1f8b3e048c5627a636a440635820f0df279e947
--- /dev/null
+++ b/backend/python/diffusers/requirements-hipblas.txt
@@ -0,0 +1,12 @@
+--extra-index-url https://download.pytorch.org/whl/rocm6.4
+torch==2.8.0+rocm6.4
+torchvision==0.23.0+rocm6.4
+git+https://github.com/huggingface/diffusers
+opencv-python
+transformers
+accelerate
+compel
+peft
+sentencepiece
+optimum-quanto
+ftfy
\ No newline at end of file
diff --git a/backend/python/diffusers/requirements-intel.txt b/backend/python/diffusers/requirements-intel.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fec4d9df73999f39e75b0d051f13c3fe79563503
--- /dev/null
+++ b/backend/python/diffusers/requirements-intel.txt
@@ -0,0 +1,16 @@
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+intel-extension-for-pytorch==2.3.110+xpu
+torch==2.5.1+cxx11.abi
+torchvision==0.20.1+cxx11.abi
+oneccl_bind_pt==2.8.0+xpu
+optimum[openvino]
+setuptools
+git+https://github.com/huggingface/diffusers
+opencv-python
+transformers
+accelerate
+compel
+peft
+sentencepiece
+optimum-quanto
+ftfy
\ No newline at end of file
diff --git a/backend/python/diffusers/requirements-l4t12.txt b/backend/python/diffusers/requirements-l4t12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9f77a9d090142605a0c880f04a22911f9079814f
--- /dev/null
+++ b/backend/python/diffusers/requirements-l4t12.txt
@@ -0,0 +1,12 @@
+--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/
+torch
+git+https://github.com/huggingface/diffusers
+transformers
+accelerate
+compel
+peft
+optimum-quanto
+numpy<2
+sentencepiece
+torchvision
+ftfy
diff --git a/backend/python/diffusers/requirements-l4t13.txt b/backend/python/diffusers/requirements-l4t13.txt
new file mode 100644
index 0000000000000000000000000000000000000000..560858e354f44a8512280f3e6da0e927e927756d
--- /dev/null
+++ b/backend/python/diffusers/requirements-l4t13.txt
@@ -0,0 +1,13 @@
+--extra-index-url https://download.pytorch.org/whl/cu130
+torch
+git+https://github.com/huggingface/diffusers
+transformers
+accelerate
+compel
+peft
+optimum-quanto
+numpy<2
+sentencepiece
+torchvision
+ftfy
+chardet
diff --git a/backend/python/diffusers/requirements-mps.txt b/backend/python/diffusers/requirements-mps.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b7c2413bffa1a43a534bd29b81c2b726ea2be07
--- /dev/null
+++ b/backend/python/diffusers/requirements-mps.txt
@@ -0,0 +1,11 @@
+torch==2.7.1
+torchvision==0.22.1
+git+https://github.com/huggingface/diffusers
+opencv-python
+transformers
+accelerate
+compel
+peft
+sentencepiece
+optimum-quanto
+ftfy
\ No newline at end of file
diff --git a/backend/python/diffusers/requirements.txt b/backend/python/diffusers/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5621046cddb01db80786ec7e3b226329f5c5eaaa
--- /dev/null
+++ b/backend/python/diffusers/requirements.txt
@@ -0,0 +1,5 @@
+setuptools
+grpcio==1.76.0
+pillow
+protobuf
+certifi
diff --git a/backend/python/diffusers/run.sh b/backend/python/diffusers/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..74367c99f332d369849f9029fa18a05c40b24dda
--- /dev/null
+++ b/backend/python/diffusers/run.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+if [ -d "/opt/intel" ]; then
+ # Assumes we are using the Intel oneAPI container image
+ # https://github.com/intel/intel-extension-for-pytorch/issues/538
+ export XPU=1
+fi
+
+export PYTORCH_ENABLE_MPS_FALLBACK=1
+
+startBackend $@
diff --git a/backend/python/diffusers/test.py b/backend/python/diffusers/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..5befeca0a99aec30b4cdb8c064ea993cf4e2e5de
--- /dev/null
+++ b/backend/python/diffusers/test.py
@@ -0,0 +1,314 @@
+"""
+A test script to test the gRPC service and dynamic loader
+"""
+import unittest
+import subprocess
+import time
+from unittest.mock import patch, MagicMock
+
+# Import dynamic loader for testing (these don't need gRPC)
+import diffusers_dynamic_loader as loader
+from diffusers import DiffusionPipeline, StableDiffusionPipeline
+
+# Try to import gRPC modules - may not be available during unit testing
+try:
+ import grpc
+ import backend_pb2
+ import backend_pb2_grpc
+ GRPC_AVAILABLE = True
+except ImportError:
+ GRPC_AVAILABLE = False
+
+
+@unittest.skipUnless(GRPC_AVAILABLE, "gRPC modules not available")
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service
+ """
+ def setUp(self):
+ """
+ This method sets up the gRPC service by starting the server
+ """
+ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
+
+ def tearDown(self) -> None:
+ """
+ This method tears down the gRPC service by terminating the server
+ """
+ self.service.kill()
+ self.service.wait()
+
+ def test_server_startup(self):
+ """
+ This method tests if the server starts up successfully
+ """
+ time.sleep(20)
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+
+ def test_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ time.sleep(20)
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="Lykon/dreamshaper-8"))
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test(self):
+ """
+ This method tests if the backend can generate images
+ """
+ time.sleep(20)
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="Lykon/dreamshaper-8"))
+ print(response.message)
+ self.assertTrue(response.success)
+ image_req = backend_pb2.GenerateImageRequest(positive_prompt="cat", width=16,height=16, dst="test.jpg")
+ re = stub.GenerateImage(image_req)
+ self.assertTrue(re.success)
+ except Exception as err:
+ print(err)
+ self.fail("Image gen service failed")
+ finally:
+ self.tearDown()
+
+
+class TestDiffusersDynamicLoader(unittest.TestCase):
+ """Test cases for the diffusers dynamic loader functionality."""
+
+ @classmethod
+ def setUpClass(cls):
+ """Set up test fixtures - clear caches to ensure fresh discovery."""
+ # Reset the caches to ensure fresh discovery
+ loader._pipeline_registry = None
+ loader._task_aliases = None
+
+ def test_camel_to_kebab_conversion(self):
+ """Test CamelCase to kebab-case conversion."""
+ test_cases = [
+ ("StableDiffusionPipeline", "stable-diffusion-pipeline"),
+ ("StableDiffusionXLPipeline", "stable-diffusion-xl-pipeline"),
+ ("FluxPipeline", "flux-pipeline"),
+ ("DiffusionPipeline", "diffusion-pipeline"),
+ ]
+ for input_val, expected in test_cases:
+ with self.subTest(input=input_val):
+ result = loader._camel_to_kebab(input_val)
+ self.assertEqual(result, expected)
+
+ def test_extract_task_keywords(self):
+ """Test task keyword extraction from class names."""
+ # Test text-to-image detection
+ aliases = loader._extract_task_keywords("StableDiffusionPipeline")
+ self.assertIn("stable-diffusion", aliases)
+
+ # Test img2img detection
+ aliases = loader._extract_task_keywords("StableDiffusionImg2ImgPipeline")
+ self.assertIn("image-to-image", aliases)
+ self.assertIn("img2img", aliases)
+
+ # Test inpainting detection
+ aliases = loader._extract_task_keywords("StableDiffusionInpaintPipeline")
+ self.assertIn("inpainting", aliases)
+ self.assertIn("inpaint", aliases)
+
+ # Test depth2img detection
+ aliases = loader._extract_task_keywords("StableDiffusionDepth2ImgPipeline")
+ self.assertIn("depth-to-image", aliases)
+
+ def test_discover_pipelines_finds_known_classes(self):
+ """Test that pipeline discovery finds at least one known pipeline class."""
+ registry = loader.get_pipeline_registry()
+
+ # Check that the registry is not empty
+ self.assertGreater(len(registry), 0, "Pipeline registry should not be empty")
+
+ # Check for known pipeline classes
+ known_pipelines = [
+ "StableDiffusionPipeline",
+ "DiffusionPipeline",
+ ]
+
+ for pipeline_name in known_pipelines:
+ with self.subTest(pipeline=pipeline_name):
+ self.assertIn(
+ pipeline_name,
+ registry,
+ f"Expected to find {pipeline_name} in registry"
+ )
+
+ def test_discover_pipelines_caches_results(self):
+ """Test that pipeline discovery results are cached."""
+ # Get registry twice
+ registry1 = loader.get_pipeline_registry()
+ registry2 = loader.get_pipeline_registry()
+
+ # Should be the same object (cached)
+ self.assertIs(registry1, registry2, "Registry should be cached")
+
+ def test_get_available_pipelines(self):
+ """Test getting list of available pipelines."""
+ available = loader.get_available_pipelines()
+
+ # Should return a list
+ self.assertIsInstance(available, list)
+
+ # Should contain known pipelines
+ self.assertIn("StableDiffusionPipeline", available)
+ self.assertIn("DiffusionPipeline", available)
+
+ # Should be sorted
+ self.assertEqual(available, sorted(available))
+
+ def test_get_available_tasks(self):
+ """Test getting list of available task aliases."""
+ tasks = loader.get_available_tasks()
+
+ # Should return a list
+ self.assertIsInstance(tasks, list)
+
+ # Should be sorted
+ self.assertEqual(tasks, sorted(tasks))
+
+ def test_resolve_pipeline_class_by_name(self):
+ """Test resolving pipeline class by exact name."""
+ cls = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline")
+ self.assertEqual(cls, StableDiffusionPipeline)
+
+ def test_resolve_pipeline_class_by_name_case_insensitive(self):
+ """Test that class name resolution is case-insensitive."""
+ cls1 = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline")
+ cls2 = loader.resolve_pipeline_class(class_name="stablediffusionpipeline")
+ self.assertEqual(cls1, cls2)
+
+ def test_resolve_pipeline_class_by_task(self):
+ """Test resolving pipeline class by task alias."""
+ # Get the registry to find available tasks
+ aliases = loader.get_task_aliases()
+
+ # Test with a common task that should be available
+ if "stable-diffusion" in aliases:
+ cls = loader.resolve_pipeline_class(task="stable-diffusion")
+ self.assertIsNotNone(cls)
+
+ def test_resolve_pipeline_class_unknown_name_raises(self):
+ """Test that resolving unknown class name raises ValueError with helpful message."""
+ with self.assertRaises(ValueError) as ctx:
+ loader.resolve_pipeline_class(class_name="NonExistentPipeline")
+
+ # Check that error message includes available pipelines
+ error_msg = str(ctx.exception)
+ self.assertIn("Unknown pipeline class", error_msg)
+ self.assertIn("Available pipelines", error_msg)
+
+ def test_resolve_pipeline_class_unknown_task_raises(self):
+ """Test that resolving unknown task raises ValueError with helpful message."""
+ with self.assertRaises(ValueError) as ctx:
+ loader.resolve_pipeline_class(task="nonexistent-task-xyz")
+
+ # Check that error message includes available tasks
+ error_msg = str(ctx.exception)
+ self.assertIn("Unknown task", error_msg)
+ self.assertIn("Available tasks", error_msg)
+
+ def test_resolve_pipeline_class_no_params_raises(self):
+ """Test that calling with no parameters raises helpful ValueError."""
+ with self.assertRaises(ValueError) as ctx:
+ loader.resolve_pipeline_class()
+
+ error_msg = str(ctx.exception)
+ self.assertIn("Must provide at least one of", error_msg)
+
+ def test_get_pipeline_info(self):
+ """Test getting pipeline information."""
+ info = loader.get_pipeline_info("StableDiffusionPipeline")
+
+ self.assertEqual(info['name'], "StableDiffusionPipeline")
+ self.assertIsInstance(info['aliases'], list)
+ self.assertIsInstance(info['supports_single_file'], bool)
+
+ def test_get_pipeline_info_unknown_raises(self):
+ """Test that getting info for unknown pipeline raises ValueError."""
+ with self.assertRaises(ValueError) as ctx:
+ loader.get_pipeline_info("NonExistentPipeline")
+
+ self.assertIn("Unknown pipeline", str(ctx.exception))
+
+ def test_discover_diffusers_classes_pipelines(self):
+ """Test generic class discovery for DiffusionPipeline."""
+ classes = loader.discover_diffusers_classes("DiffusionPipeline")
+
+ # Should return a dict
+ self.assertIsInstance(classes, dict)
+
+ # Should contain known pipeline classes
+ self.assertIn("DiffusionPipeline", classes)
+ self.assertIn("StableDiffusionPipeline", classes)
+
+ def test_discover_diffusers_classes_caches_results(self):
+ """Test that class discovery results are cached."""
+ classes1 = loader.discover_diffusers_classes("DiffusionPipeline")
+ classes2 = loader.discover_diffusers_classes("DiffusionPipeline")
+
+ # Should be the same object (cached)
+ self.assertIs(classes1, classes2)
+
+ def test_discover_diffusers_classes_exclude_base(self):
+ """Test discovering classes without base class."""
+ classes = loader.discover_diffusers_classes("DiffusionPipeline", include_base=False)
+
+ # Should still contain subclasses
+ self.assertIn("StableDiffusionPipeline", classes)
+
+ def test_get_available_classes(self):
+ """Test getting list of available classes for a base class."""
+ classes = loader.get_available_classes("DiffusionPipeline")
+
+ # Should return a sorted list
+ self.assertIsInstance(classes, list)
+ self.assertEqual(classes, sorted(classes))
+
+ # Should contain known classes
+ self.assertIn("StableDiffusionPipeline", classes)
+
+
+class TestDiffusersDynamicLoaderWithMocks(unittest.TestCase):
+ """Test cases using mocks to test edge cases."""
+
+ def test_load_pipeline_requires_model_id(self):
+ """Test that load_diffusers_pipeline requires model_id."""
+ with self.assertRaises(ValueError) as ctx:
+ loader.load_diffusers_pipeline(class_name="StableDiffusionPipeline")
+
+ self.assertIn("model_id is required", str(ctx.exception))
+
+ def test_resolve_with_model_id_uses_diffusion_pipeline_fallback(self):
+ """Test that resolving with only model_id falls back to DiffusionPipeline."""
+ # When model_id is provided, if hub lookup is not successful,
+ # should fall back to DiffusionPipeline.
+ # This tests the fallback behavior - the actual hub lookup may succeed
+ # or fail depending on network, but the fallback path should work.
+ cls = loader.resolve_pipeline_class(model_id="some/nonexistent/model")
+ self.assertEqual(cls, DiffusionPipeline)
diff --git a/backend/python/diffusers/test.sh b/backend/python/diffusers/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eb59f2aaf3f38f78c7ca3dc414ea490ff66776d7
--- /dev/null
+++ b/backend/python/diffusers/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/exllama2/.gitignore b/backend/python/exllama2/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..1d3a06547c706bd0f75b130edecf4832295bba6b
--- /dev/null
+++ b/backend/python/exllama2/.gitignore
@@ -0,0 +1 @@
+source
\ No newline at end of file
diff --git a/backend/python/exllama2/Makefile b/backend/python/exllama2/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..15ba38d120f36f4c380dc71583eadcc238d45dd6
--- /dev/null
+++ b/backend/python/exllama2/Makefile
@@ -0,0 +1,17 @@
+.PHONY: exllama2
+exllama2:
+ bash install.sh
+
+.PHONY: run
+run: exllama2
+ @echo "Running exllama2..."
+ bash run.sh
+ @echo "exllama2 run."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ $(RM) -r venv source __pycache__
\ No newline at end of file
diff --git a/backend/python/exllama2/backend.py b/backend/python/exllama2/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..7aacea360bc0072ca0bda67449da4f211d642eec
--- /dev/null
+++ b/backend/python/exllama2/backend.py
@@ -0,0 +1,143 @@
+#!/usr/bin/env python3
+import grpc
+from concurrent import futures
+import time
+import backend_pb2
+import backend_pb2_grpc
+import argparse
+import signal
+import sys
+import os
+import glob
+
+from pathlib import Path
+import torch
+import torch.nn.functional as F
+from torch import version as torch_version
+
+
+from exllamav2.generator import (
+ ExLlamaV2BaseGenerator,
+ ExLlamaV2Sampler
+)
+
+
+from exllamav2 import (
+ ExLlamaV2,
+ ExLlamaV2Config,
+ ExLlamaV2Cache,
+ ExLlamaV2Cache_8bit,
+ ExLlamaV2Tokenizer,
+ model_init,
+)
+
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ def Health(self, request, context):
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+ def LoadModel(self, request, context):
+ try:
+ model_directory = request.ModelFile
+
+ config = ExLlamaV2Config()
+ config.model_dir = model_directory
+ config.prepare()
+
+ model = ExLlamaV2(config)
+
+ cache = ExLlamaV2Cache(model, lazy=True)
+ model.load_autosplit(cache)
+
+ tokenizer = ExLlamaV2Tokenizer(config)
+
+ # Initialize generator
+
+ generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
+
+ self.generator = generator
+
+ generator.warmup()
+ self.model = model
+ self.tokenizer = tokenizer
+ self.cache = cache
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+ def Predict(self, request, context):
+
+ penalty = 1.15
+ if request.Penalty != 0.0:
+ penalty = request.Penalty
+
+ settings = ExLlamaV2Sampler.Settings()
+ settings.temperature = request.Temperature
+ settings.top_k = request.TopK
+ settings.top_p = request.TopP
+ settings.token_repetition_penalty = penalty
+ settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
+ tokens = 512
+
+ if request.Tokens != 0:
+ tokens = request.Tokens
+ output = self.generator.generate_simple(
+ request.Prompt, settings, tokens)
+
+ # Remove prompt from response if present
+ if request.Prompt in output:
+ output = output.replace(request.Prompt, "")
+
+ return backend_pb2.Result(message=bytes(output, encoding='utf-8'))
+
+ def PredictStream(self, request, context):
+ # Implement PredictStream RPC
+ # for reply in some_data_generator():
+ # yield reply
+ # Not implemented yet
+ return self.Predict(request, context)
+
+
+def serve(address):
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ server.add_insecure_port(address)
+ server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+
+ # Define the signal handler function
+ def signal_handler(sig, frame):
+ print("Received termination signal. Shutting down...")
+ server.stop(0)
+ sys.exit(0)
+
+ # Set the signal handlers for SIGINT and SIGTERM
+ signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, signal_handler)
+
+ try:
+ while True:
+ time.sleep(_ONE_DAY_IN_SECONDS)
+ except KeyboardInterrupt:
+ server.stop(0)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ serve(args.addr)
diff --git a/backend/python/exllama2/install.sh b/backend/python/exllama2/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6cbc28a171a8ed125b918770758cb435a3e6fac3
--- /dev/null
+++ b/backend/python/exllama2/install.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+set -e
+
+LIMIT_TARGETS="cublas"
+EXTRA_PIP_INSTALL_FLAGS="--no-build-isolation"
+EXLLAMA2_VERSION=c0ddebaaaf8ffd1b3529c2bb654e650bce2f790f
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+installRequirements
+
+git clone https://github.com/turboderp/exllamav2 $MY_DIR/source
+pushd ${MY_DIR}/source && git checkout -b build ${EXLLAMA2_VERSION} && popd
+
+# This installs exllamav2 in JIT mode so it will compile the appropriate torch extension at runtime
+EXLLAMA_NOCOMPILE= uv pip install ${EXTRA_PIP_INSTALL_FLAGS} ${MY_DIR}/source/
diff --git a/backend/python/exllama2/requirements-cpu.txt b/backend/python/exllama2/requirements-cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2021fc201f7e35dabdcf545e96705f6bd96b0511
--- /dev/null
+++ b/backend/python/exllama2/requirements-cpu.txt
@@ -0,0 +1,3 @@
+transformers
+accelerate
+torch==2.4.1
\ No newline at end of file
diff --git a/backend/python/exllama2/requirements-cublas12.txt b/backend/python/exllama2/requirements-cublas12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..93e62c5ab27d98558669a5e2c63be4f2f6ad984c
--- /dev/null
+++ b/backend/python/exllama2/requirements-cublas12.txt
@@ -0,0 +1,3 @@
+torch==2.4.1
+transformers
+accelerate
\ No newline at end of file
diff --git a/backend/python/exllama2/requirements-install.txt b/backend/python/exllama2/requirements-install.txt
new file mode 100644
index 0000000000000000000000000000000000000000..322799ff60f47339453383cb0ea9ab7178bb51c9
--- /dev/null
+++ b/backend/python/exllama2/requirements-install.txt
@@ -0,0 +1,4 @@
+# This is here to trigger the install script to add --no-build-isolation to the uv pip install commands
+# exllama2 does not specify it's build requirements per PEP517, so we need to provide some things ourselves
+wheel
+setuptools
\ No newline at end of file
diff --git a/backend/python/exllama2/requirements.txt b/backend/python/exllama2/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3044ff0e23bbb35a97d9e809e7dc10590aeb8134
--- /dev/null
+++ b/backend/python/exllama2/requirements.txt
@@ -0,0 +1,5 @@
+grpcio==1.76.0
+protobuf
+certifi
+wheel
+setuptools
\ No newline at end of file
diff --git a/backend/python/exllama2/run.sh b/backend/python/exllama2/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..91c79aadeb277fffe800254e42ea553217f8ca13
--- /dev/null
+++ b/backend/python/exllama2/run.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+LIMIT_TARGETS="cublas"
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/exllama2/test.sh b/backend/python/exllama2/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eb59f2aaf3f38f78c7ca3dc414ea490ff66776d7
--- /dev/null
+++ b/backend/python/exllama2/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/faster-whisper/Makefile b/backend/python/faster-whisper/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..f6b9ddc6c888845a9b20c98d3ef8bfae3629a1cd
--- /dev/null
+++ b/backend/python/faster-whisper/Makefile
@@ -0,0 +1,13 @@
+.DEFAULT_GOAL := install
+
+.PHONY: install
+install:
+ bash install.sh
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/faster-whisper/backend.py b/backend/python/faster-whisper/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..808f29238207cb4f110ff103e2bdcc3ceb5ac5f2
--- /dev/null
+++ b/backend/python/faster-whisper/backend.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+"""
+This is an extra gRPC server of LocalAI for Bark TTS
+"""
+from concurrent import futures
+import time
+import argparse
+import signal
+import sys
+import os
+import backend_pb2
+import backend_pb2_grpc
+import torch
+from faster_whisper import WhisperModel
+
+import grpc
+
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+COQUI_LANGUAGE = os.environ.get('COQUI_LANGUAGE', None)
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ BackendServicer is the class that implements the gRPC service
+ """
+ def Health(self, request, context):
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+ def LoadModel(self, request, context):
+ device = "cpu"
+ # Get device
+ # device = "cuda" if request.CUDA else "cpu"
+ if request.CUDA:
+ device = "cuda"
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+ if mps_available:
+ device = "mps"
+ try:
+ print("Preparing models, please wait", file=sys.stderr)
+ self.model = WhisperModel(request.Model, device=device, compute_type="float16")
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ # Implement your logic here for the LoadModel service
+ # Replace this with your desired response
+ return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+ def AudioTranscription(self, request, context):
+ resultSegments = []
+ text = ""
+ try:
+ segments, info = self.model.transcribe(request.dst, beam_size=5, condition_on_previous_text=False)
+ id = 0
+ for segment in segments:
+ print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
+ resultSegments.append(backend_pb2.TranscriptSegment(id=id, start=segment.start, end=segment.end, text=segment.text))
+ text += segment.text
+ id += 1
+ except Exception as err:
+ print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
+
+ return backend_pb2.TranscriptResult(segments=resultSegments, text=text)
+
+def serve(address):
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ server.add_insecure_port(address)
+ server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+
+ # Define the signal handler function
+ def signal_handler(sig, frame):
+ print("Received termination signal. Shutting down...")
+ server.stop(0)
+ sys.exit(0)
+
+ # Set the signal handlers for SIGINT and SIGTERM
+ signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, signal_handler)
+
+ try:
+ while True:
+ time.sleep(_ONE_DAY_IN_SECONDS)
+ except KeyboardInterrupt:
+ server.stop(0)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ serve(args.addr)
diff --git a/backend/python/faster-whisper/install.sh b/backend/python/faster-whisper/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..32befa8e6c034734e54038c3b4b565522028f391
--- /dev/null
+++ b/backend/python/faster-whisper/install.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
+# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
+if [ "x${BUILD_PROFILE}" == "xintel" ]; then
+ EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+fi
+
+installRequirements
diff --git a/backend/python/faster-whisper/protogen.sh b/backend/python/faster-whisper/protogen.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d608379c16061622dd9d43f0059584f1a9716d9d
--- /dev/null
+++ b/backend/python/faster-whisper/protogen.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto
\ No newline at end of file
diff --git a/backend/python/faster-whisper/requirements-cpu.txt b/backend/python/faster-whisper/requirements-cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3e03f3adffdd50efa14932002dee5852cd8e901b
--- /dev/null
+++ b/backend/python/faster-whisper/requirements-cpu.txt
@@ -0,0 +1,8 @@
+faster-whisper
+opencv-python
+accelerate
+compel
+peft
+sentencepiece
+torch==2.4.1
+optimum-quanto
\ No newline at end of file
diff --git a/backend/python/faster-whisper/requirements-cublas12.txt b/backend/python/faster-whisper/requirements-cublas12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8f46fa4a748525ec247c603e29e67feced640a5d
--- /dev/null
+++ b/backend/python/faster-whisper/requirements-cublas12.txt
@@ -0,0 +1,8 @@
+torch==2.4.1
+faster-whisper
+opencv-python
+accelerate
+compel
+peft
+sentencepiece
+optimum-quanto
\ No newline at end of file
diff --git a/backend/python/faster-whisper/requirements-cublas13.txt b/backend/python/faster-whisper/requirements-cublas13.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3c797fce3a06806f41a50d5d5556c88ee42e41cf
--- /dev/null
+++ b/backend/python/faster-whisper/requirements-cublas13.txt
@@ -0,0 +1,9 @@
+--extra-index-url https://download.pytorch.org/whl/cu130
+torch==2.9.1
+faster-whisper
+opencv-python
+accelerate
+compel
+peft
+sentencepiece
+optimum-quanto
\ No newline at end of file
diff --git a/backend/python/faster-whisper/requirements-hipblas.txt b/backend/python/faster-whisper/requirements-hipblas.txt
new file mode 100644
index 0000000000000000000000000000000000000000..da9c9123c0d7e3c1b3e01cab65deca8d33e8fcd1
--- /dev/null
+++ b/backend/python/faster-whisper/requirements-hipblas.txt
@@ -0,0 +1,3 @@
+--extra-index-url https://download.pytorch.org/whl/rocm6.4
+torch
+faster-whisper
\ No newline at end of file
diff --git a/backend/python/faster-whisper/requirements-intel.txt b/backend/python/faster-whisper/requirements-intel.txt
new file mode 100644
index 0000000000000000000000000000000000000000..417aa0b470b7360962f728250119a82edf60cf76
--- /dev/null
+++ b/backend/python/faster-whisper/requirements-intel.txt
@@ -0,0 +1,6 @@
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+intel-extension-for-pytorch==2.3.110+xpu
+torch==2.3.1+cxx11.abi
+oneccl_bind_pt==2.3.100+xpu
+optimum[openvino]
+faster-whisper
\ No newline at end of file
diff --git a/backend/python/faster-whisper/requirements.txt b/backend/python/faster-whisper/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e4d843df20c4c292880f011305f8d6998561ab66
--- /dev/null
+++ b/backend/python/faster-whisper/requirements.txt
@@ -0,0 +1,3 @@
+grpcio==1.71.0
+protobuf
+grpcio-tools
\ No newline at end of file
diff --git a/backend/python/faster-whisper/run.sh b/backend/python/faster-whisper/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..82b7b09ecc7d20cb3513bdd1b18f3524f7d4cbd7
--- /dev/null
+++ b/backend/python/faster-whisper/run.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/faster-whisper/test.sh b/backend/python/faster-whisper/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eb59f2aaf3f38f78c7ca3dc414ea490ff66776d7
--- /dev/null
+++ b/backend/python/faster-whisper/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/kitten-tts/Makefile b/backend/python/kitten-tts/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..021a9679bfd261f74781b3fdff5c075729127c82
--- /dev/null
+++ b/backend/python/kitten-tts/Makefile
@@ -0,0 +1,23 @@
+.PHONY: kitten-tts
+kitten-tts:
+ bash install.sh
+
+.PHONY: run
+run: kitten-tts
+ @echo "Running kitten-tts..."
+ bash run.sh
+ @echo "kitten-tts run."
+
+.PHONY: test
+test: kitten-tts
+ @echo "Testing kitten-tts..."
+ bash test.sh
+ @echo "kitten-tts tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/kitten-tts/backend.py b/backend/python/kitten-tts/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..b31023c8cac6f6fa59678d5af1ee47cc5bdc527d
--- /dev/null
+++ b/backend/python/kitten-tts/backend.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+"""
+This is an extra gRPC server of LocalAI for Kitten TTS
+"""
+from concurrent import futures
+import time
+import argparse
+import signal
+import sys
+import os
+import backend_pb2
+import backend_pb2_grpc
+
+import torch
+from kittentts import KittenTTS
+import soundfile as sf
+
+import grpc
+
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+KITTEN_LANGUAGE = os.environ.get('KITTEN_LANGUAGE', None)
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ BackendServicer is the class that implements the gRPC service
+ """
+ def Health(self, request, context):
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+ def LoadModel(self, request, context):
+
+ self.AudioPath = None
+ # List available KittenTTS models
+ print("Available KittenTTS voices: expr-voice-2-m, expr-voice-2-f, expr-voice-3-m, expr-voice-3-f, expr-voice-4-m, expr-voice-4-f, expr-voice-5-m, expr-voice-5-f")
+ if os.path.isabs(request.AudioPath):
+ self.AudioPath = request.AudioPath
+ elif request.AudioPath and request.ModelFile != "" and not os.path.isabs(request.AudioPath):
+ # get base path of modelFile
+ modelFileBase = os.path.dirname(request.ModelFile)
+ # modify LoraAdapter to be relative to modelFileBase
+ self.AudioPath = os.path.join(modelFileBase, request.AudioPath)
+
+ try:
+ print("Preparing KittenTTS model, please wait", file=sys.stderr)
+ # Use the model name from request.Model, defaulting to "KittenML/kitten-tts-nano-0.1" if not specified
+ model_name = request.Model if request.Model else "KittenML/kitten-tts-nano-0.1"
+ self.tts = KittenTTS(model_name)
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ # Implement your logic here for the LoadModel service
+ # Replace this with your desired response
+ return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+ def TTS(self, request, context):
+ try:
+ # KittenTTS doesn't use language parameter like TTS, so we ignore it
+ # For multi-speaker models, use voice parameter
+ voice = request.voice if request.voice else "expr-voice-2-f"
+
+ # Generate audio using KittenTTS
+ audio = self.tts.generate(request.text, voice=voice)
+
+ # Save the audio using soundfile
+ sf.write(request.dst, audio, 24000)
+
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ return backend_pb2.Result(success=True)
+
+def serve(address):
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ server.add_insecure_port(address)
+ server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+
+ # Define the signal handler function
+ def signal_handler(sig, frame):
+ print("Received termination signal. Shutting down...")
+ server.stop(0)
+ sys.exit(0)
+
+ # Set the signal handlers for SIGINT and SIGTERM
+ signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, signal_handler)
+
+ try:
+ while True:
+ time.sleep(_ONE_DAY_IN_SECONDS)
+ except KeyboardInterrupt:
+ server.stop(0)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ serve(args.addr)
diff --git a/backend/python/kitten-tts/install.sh b/backend/python/kitten-tts/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..32befa8e6c034734e54038c3b4b565522028f391
--- /dev/null
+++ b/backend/python/kitten-tts/install.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
+# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
+if [ "x${BUILD_PROFILE}" == "xintel" ]; then
+ EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+fi
+
+installRequirements
diff --git a/backend/python/kitten-tts/requirements.txt b/backend/python/kitten-tts/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..23439f8e57841ac37dc9ee6ec7f112c3b9bc91e0
--- /dev/null
+++ b/backend/python/kitten-tts/requirements.txt
@@ -0,0 +1,5 @@
+grpcio==1.71.0
+protobuf
+certifi
+packaging==24.1
+https://github.com/KittenML/KittenTTS/releases/download/0.1/kittentts-0.1.0-py3-none-any.whl
\ No newline at end of file
diff --git a/backend/python/kitten-tts/run.sh b/backend/python/kitten-tts/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..82b7b09ecc7d20cb3513bdd1b18f3524f7d4cbd7
--- /dev/null
+++ b/backend/python/kitten-tts/run.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/kitten-tts/test.py b/backend/python/kitten-tts/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0b1a0bdd1240e1edfc9218c0c2ba032e5cf6300
--- /dev/null
+++ b/backend/python/kitten-tts/test.py
@@ -0,0 +1,82 @@
+"""
+A test script to test the gRPC service
+"""
+import unittest
+import subprocess
+import time
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service
+ """
+ def setUp(self):
+ """
+ This method sets up the gRPC service by starting the server
+ """
+ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
+ time.sleep(30)
+
+ def tearDown(self) -> None:
+ """
+ This method tears down the gRPC service by terminating the server
+ """
+ self.service.terminate()
+ self.service.wait()
+
+ def test_server_startup(self):
+ """
+ This method tests if the server starts up successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+
+ def test_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits"))
+ print(response)
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_tts(self):
+ """
+ This method tests if the embeddings are generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits"))
+ self.assertTrue(response.success)
+ tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
+ tts_response = stub.TTS(tts_request)
+ self.assertIsNotNone(tts_response)
+ except Exception as err:
+ print(err)
+ self.fail("TTS service failed")
+ finally:
+ self.tearDown()
\ No newline at end of file
diff --git a/backend/python/kitten-tts/test.sh b/backend/python/kitten-tts/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eb59f2aaf3f38f78c7ca3dc414ea490ff66776d7
--- /dev/null
+++ b/backend/python/kitten-tts/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/kokoro/Makefile b/backend/python/kokoro/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..7e1b238228b1dcbf08edef8a4ae9777fd1a88d22
--- /dev/null
+++ b/backend/python/kokoro/Makefile
@@ -0,0 +1,23 @@
+.PHONY: kokoro
+kokoro:
+ bash install.sh
+
+.PHONY: run
+run: kokoro
+ @echo "Running kokoro..."
+ bash run.sh
+ @echo "kokoro run."
+
+.PHONY: test
+test: kokoro
+ @echo "Testing kokoro..."
+ bash test.sh
+ @echo "kokoro tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/kokoro/README.md b/backend/python/kokoro/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a890dc6d51dbd240f5f3bf4e3ca986aff341099b
--- /dev/null
+++ b/backend/python/kokoro/README.md
@@ -0,0 +1,23 @@
+# Kokoro TTS Backend for LocalAI
+
+This is a gRPC server backend for LocalAI that uses the Kokoro TTS pipeline.
+
+## Creating a separate environment for kokoro project
+
+```bash
+make kokoro
+```
+
+## Testing the gRPC server
+
+```bash
+make test
+```
+
+## Features
+
+- Lightweight TTS model with 82 million parameters
+- Apache-licensed weights
+- Fast and cost-efficient
+- Multi-language support
+- Multiple voice options
\ No newline at end of file
diff --git a/backend/python/kokoro/backend.py b/backend/python/kokoro/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..43d22238f5c5e78e02c2f680564527682f67af06
--- /dev/null
+++ b/backend/python/kokoro/backend.py
@@ -0,0 +1,116 @@
+#!/usr/bin/env python3
+"""
+This is an extra gRPC server of LocalAI for Kokoro TTS
+"""
+from concurrent import futures
+import time
+import argparse
+import signal
+import sys
+import os
+import backend_pb2
+import backend_pb2_grpc
+
+import torch
+from kokoro import KPipeline
+import soundfile as sf
+
+import grpc
+
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+KOKORO_LANG_CODE = os.environ.get('KOKORO_LANG_CODE', 'a')
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ BackendServicer is the class that implements the gRPC service
+ """
+ def Health(self, request, context):
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+ def LoadModel(self, request, context):
+ try:
+ print("Preparing Kokoro TTS pipeline, please wait", file=sys.stderr)
+ # empty dict
+ self.options = {}
+ options = request.Options
+ # The options are a list of strings in this form optname:optvalue
+ # We are storing all the options in a dict so we can use it later when
+ # generating the images
+ for opt in options:
+ if ":" not in opt:
+ continue
+ key, value = opt.split(":")
+ self.options[key] = value
+
+ # Initialize Kokoro pipeline with language code
+ lang_code = self.options.get("lang_code", KOKORO_LANG_CODE)
+ self.pipeline = KPipeline(lang_code=lang_code)
+ print(f"Kokoro TTS pipeline loaded with language code: {lang_code}", file=sys.stderr)
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+
+ return backend_pb2.Result(message="Kokoro TTS pipeline loaded successfully", success=True)
+
+ def TTS(self, request, context):
+ try:
+ # Get voice from request, default to 'af_heart' if not specified
+ voice = request.voice if request.voice else 'af_heart'
+
+ # Generate audio using Kokoro pipeline
+ generator = self.pipeline(request.text, voice=voice)
+
+ speechs = []
+ # Get all the audio segment
+ for i, (gs, ps, audio) in enumerate(generator):
+ speechs.append(audio)
+ print(f"Generated audio segment {i}: gs={gs}, ps={ps}", file=sys.stderr)
+ # Merges the audio segments and writes them to the destination
+ speech = torch.cat(speechs, dim=0)
+ sf.write(request.dst, speech, 24000)
+
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+
+ return backend_pb2.Result(success=True)
+
+def serve(address):
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ server.add_insecure_port(address)
+ server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+
+ # Define the signal handler function
+ def signal_handler(sig, frame):
+ print("Received termination signal. Shutting down...")
+ server.stop(0)
+ sys.exit(0)
+
+ # Set the signal handlers for SIGINT and SIGTERM
+ signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, signal_handler)
+
+ try:
+ while True:
+ time.sleep(_ONE_DAY_IN_SECONDS)
+ except KeyboardInterrupt:
+ server.stop(0)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ serve(args.addr)
diff --git a/backend/python/kokoro/install.sh b/backend/python/kokoro/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d3b88ea684d35465d7bf47e789ff8332d3f1f12c
--- /dev/null
+++ b/backend/python/kokoro/install.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
+# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
+if [ "x${BUILD_PROFILE}" == "xintel" ]; then
+ EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+fi
+
+if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
+ USE_PIP=true
+fi
+
+installRequirements
diff --git a/backend/python/kokoro/requirements-cpu.txt b/backend/python/kokoro/requirements-cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1a1abb2f2d56ebc76a8a5f668f9a274b54d05137
--- /dev/null
+++ b/backend/python/kokoro/requirements-cpu.txt
@@ -0,0 +1,6 @@
+--extra-index-url https://download.pytorch.org/whl/cpu
+transformers
+accelerate
+torch
+kokoro
+soundfile
\ No newline at end of file
diff --git a/backend/python/kokoro/requirements-cublas12.txt b/backend/python/kokoro/requirements-cublas12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2da8b72c0fdf3961519c9fb50a92ac4918a475ad
--- /dev/null
+++ b/backend/python/kokoro/requirements-cublas12.txt
@@ -0,0 +1,6 @@
+torch==2.7.1
+torchaudio==2.7.1
+transformers
+accelerate
+kokoro
+soundfile
\ No newline at end of file
diff --git a/backend/python/kokoro/requirements-cublas13.txt b/backend/python/kokoro/requirements-cublas13.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0835ac729bb86d069a7c31530058c5c418d4b888
--- /dev/null
+++ b/backend/python/kokoro/requirements-cublas13.txt
@@ -0,0 +1,7 @@
+--extra-index-url https://download.pytorch.org/whl/cu130
+torch==2.9.1
+torchaudio==2.9.1
+transformers
+accelerate
+kokoro
+soundfile
\ No newline at end of file
diff --git a/backend/python/kokoro/requirements-hipblas.txt b/backend/python/kokoro/requirements-hipblas.txt
new file mode 100644
index 0000000000000000000000000000000000000000..74262df5c3ce754d639c49de0e8c809d2959997d
--- /dev/null
+++ b/backend/python/kokoro/requirements-hipblas.txt
@@ -0,0 +1,7 @@
+--extra-index-url https://download.pytorch.org/whl/rocm6.4
+torch==2.8.0+rocm6.4
+torchaudio==2.8.0+rocm6.4
+transformers
+accelerate
+kokoro
+soundfile
\ No newline at end of file
diff --git a/backend/python/kokoro/requirements-intel.txt b/backend/python/kokoro/requirements-intel.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c497efd83178fac0a1583fc84449b79da8250439
--- /dev/null
+++ b/backend/python/kokoro/requirements-intel.txt
@@ -0,0 +1,11 @@
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+intel-extension-for-pytorch==2.8.10+xpu
+torch==2.5.1+cxx11.abi
+oneccl_bind_pt==2.8.0+xpu
+torchaudio==2.5.1+cxx11.abi
+optimum[openvino]
+setuptools
+transformers==4.48.3
+accelerate
+kokoro
+soundfile
\ No newline at end of file
diff --git a/backend/python/kokoro/requirements-l4t12.txt b/backend/python/kokoro/requirements-l4t12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c03f853de215bb158de69a4f6cafe97f8d20a927
--- /dev/null
+++ b/backend/python/kokoro/requirements-l4t12.txt
@@ -0,0 +1,7 @@
+--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126/
+torch
+torchaudio
+transformers
+accelerate
+kokoro
+soundfile
\ No newline at end of file
diff --git a/backend/python/kokoro/requirements.txt b/backend/python/kokoro/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1a45c4bd4845533883ba7dc1d2b00963c0b05efc
--- /dev/null
+++ b/backend/python/kokoro/requirements.txt
@@ -0,0 +1,6 @@
+grpcio==1.71.0
+protobuf
+certifi
+packaging==24.1
+pip
+chardet
\ No newline at end of file
diff --git a/backend/python/kokoro/run.sh b/backend/python/kokoro/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..82b7b09ecc7d20cb3513bdd1b18f3524f7d4cbd7
--- /dev/null
+++ b/backend/python/kokoro/run.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/kokoro/test.py b/backend/python/kokoro/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fe65a1148a7caa75d91b07f93365525c9119754
--- /dev/null
+++ b/backend/python/kokoro/test.py
@@ -0,0 +1,87 @@
+"""
+A test script to test the gRPC service
+"""
+import unittest
+import subprocess
+import time
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service
+ """
+ def setUp(self):
+ """
+ This method sets up the gRPC service by starting the server
+ """
+ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
+ time.sleep(30)
+
+ def tearDown(self) -> None:
+ """
+ This method tears down the gRPC service by terminating the server
+ """
+ self.service.terminate()
+ self.service.wait()
+
+ def test_server_startup(self):
+ """
+ This method tests if the server starts up successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+
+ def test_load_model(self):
+ """
+ This method tests if the Kokoro pipeline is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(language="a"))
+ print(response)
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Kokoro TTS pipeline loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_tts(self):
+ """
+ This method tests if the TTS generation works successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(language="a"))
+ self.assertTrue(response.success)
+ tts_request = backend_pb2.TTSRequest(
+ text="Kokoro is an open-weight TTS model with 82 million parameters.",
+ voice="af_heart",
+ dst="test_output.wav"
+ )
+ tts_response = stub.TTS(tts_request)
+ self.assertIsNotNone(tts_response)
+ self.assertTrue(tts_response.success)
+ except Exception as err:
+ print(err)
+ self.fail("TTS service failed")
+ finally:
+ self.tearDown()
\ No newline at end of file
diff --git a/backend/python/kokoro/test.sh b/backend/python/kokoro/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eb59f2aaf3f38f78c7ca3dc414ea490ff66776d7
--- /dev/null
+++ b/backend/python/kokoro/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/mlx-audio/Makefile b/backend/python/mlx-audio/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..bb7aabe3a3c2c9b8663437ddcc306d2441ee763a
--- /dev/null
+++ b/backend/python/mlx-audio/Makefile
@@ -0,0 +1,23 @@
+.PHONY: mlx-audio
+mlx-audio:
+ bash install.sh
+
+.PHONY: run
+run: mlx-audio
+ @echo "Running mlx-audio..."
+ bash run.sh
+ @echo "mlx run."
+
+.PHONY: test
+test: mlx-audio
+ @echo "Testing mlx-audio..."
+ bash test.sh
+ @echo "mlx tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/mlx-audio/backend.py b/backend/python/mlx-audio/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..da37d2c37e71377488551d3781f48597fa5ee483
--- /dev/null
+++ b/backend/python/mlx-audio/backend.py
@@ -0,0 +1,465 @@
+#!/usr/bin/env python3
+import asyncio
+from concurrent import futures
+import argparse
+import signal
+import sys
+import os
+import shutil
+import glob
+from typing import List
+import time
+import tempfile
+
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+from mlx_audio.tts.utils import load_model
+import soundfile as sf
+import numpy as np
+import uuid
+
+def is_float(s):
+ """Check if a string can be converted to float."""
+ try:
+ float(s)
+ return True
+ except ValueError:
+ return False
+def is_int(s):
+ """Check if a string can be converted to int."""
+ try:
+ int(s)
+ return True
+ except ValueError:
+ return False
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ A gRPC servicer that implements the Backend service defined in backend.proto.
+ This backend provides TTS (Text-to-Speech) functionality using MLX-Audio.
+ """
+
+ def Health(self, request, context):
+ """
+ Returns a health check message.
+
+ Args:
+ request: The health check request.
+ context: The gRPC context.
+
+ Returns:
+ backend_pb2.Reply: The health check reply.
+ """
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+ async def LoadModel(self, request, context):
+ """
+ Loads a TTS model using MLX-Audio.
+
+ Args:
+ request: The load model request.
+ context: The gRPC context.
+
+ Returns:
+ backend_pb2.Result: The load model result.
+ """
+ try:
+ print(f"Loading MLX-Audio TTS model: {request.Model}", file=sys.stderr)
+ print(f"Request: {request}", file=sys.stderr)
+
+ # Parse options like in the kokoro backend
+ options = request.Options
+ self.options = {}
+
+ # The options are a list of strings in this form optname:optvalue
+ # We store all the options in a dict for later use
+ for opt in options:
+ if ":" not in opt:
+ continue
+ key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
+
+ # Convert numeric values to appropriate types
+ if is_float(value):
+ value = float(value)
+ elif is_int(value):
+ value = int(value)
+ elif value.lower() in ["true", "false"]:
+ value = value.lower() == "true"
+
+ self.options[key] = value
+
+ print(f"Options: {self.options}", file=sys.stderr)
+
+ # Load the model using MLX-Audio's load_model function
+ try:
+ self.tts_model = load_model(request.Model)
+ self.model_path = request.Model
+ print(f"TTS model loaded successfully from {request.Model}", file=sys.stderr)
+ except Exception as model_err:
+ print(f"Error loading TTS model: {model_err}", file=sys.stderr)
+ return backend_pb2.Result(success=False, message=f"Failed to load model: {model_err}")
+
+ except Exception as err:
+ print(f"Error loading MLX-Audio TTS model {err=}, {type(err)=}", file=sys.stderr)
+ return backend_pb2.Result(success=False, message=f"Error loading MLX-Audio TTS model: {err}")
+
+ print("MLX-Audio TTS model loaded successfully", file=sys.stderr)
+ return backend_pb2.Result(message="MLX-Audio TTS model loaded successfully", success=True)
+
+ def TTS(self, request, context):
+ """
+ Generates TTS audio from text using MLX-Audio.
+
+ Args:
+ request: A TTSRequest object containing text, model, destination, voice, and language.
+ context: A grpc.ServicerContext object that provides information about the RPC.
+
+ Returns:
+ A Result object indicating success or failure.
+ """
+ try:
+ # Check if model is loaded
+ if not hasattr(self, 'tts_model') or self.tts_model is None:
+ return backend_pb2.Result(success=False, message="TTS model not loaded. Please call LoadModel first.")
+
+ print(f"Generating TTS with MLX-Audio - text: {request.text[:50]}..., voice: {request.voice}, language: {request.language}", file=sys.stderr)
+
+ # Handle speed parameter based on model type
+ speed_value = self._handle_speed_parameter(request, self.model_path)
+
+ # Map language names to codes if needed
+ lang_code = self._map_language_code(request.language, request.voice)
+
+ # Prepare generation parameters
+ gen_params = {
+ "text": request.text,
+ "speed": speed_value,
+ "verbose": False,
+ }
+
+ # Add model-specific parameters
+ if request.voice and request.voice.strip():
+ gen_params["voice"] = request.voice
+
+ # Check if model supports language codes (primarily Kokoro)
+ if "kokoro" in self.model_path.lower():
+ gen_params["lang_code"] = lang_code
+
+ # Add pitch and gender for Spark models
+ if "spark" in self.model_path.lower():
+ gen_params["pitch"] = 1.0 # Default to moderate
+ gen_params["gender"] = "female" # Default to female
+
+ print(f"Generation parameters: {gen_params}", file=sys.stderr)
+
+ # Generate audio using the loaded model
+ try:
+ results = self.tts_model.generate(**gen_params)
+ except Exception as gen_err:
+ print(f"Error during TTS generation: {gen_err}", file=sys.stderr)
+ return backend_pb2.Result(success=False, message=f"TTS generation failed: {gen_err}")
+
+ # Process the generated audio segments
+ audio_arrays = []
+ for segment in results:
+ audio_arrays.append(segment.audio)
+
+ # If no segments, return error
+ if not audio_arrays:
+ print("No audio segments generated", file=sys.stderr)
+ return backend_pb2.Result(success=False, message="No audio generated")
+
+ # Concatenate all segments
+ cat_audio = np.concatenate(audio_arrays, axis=0)
+
+ # Generate output filename and path
+ if request.dst:
+ output_path = request.dst
+ else:
+ unique_id = str(uuid.uuid4())
+ filename = f"tts_{unique_id}.wav"
+ output_path = filename
+
+ # Write the audio as a WAV
+ try:
+ sf.write(output_path, cat_audio, 24000)
+ print(f"Successfully wrote audio file to {output_path}", file=sys.stderr)
+
+ # Verify the file exists and has content
+ if not os.path.exists(output_path):
+ print(f"File was not created at {output_path}", file=sys.stderr)
+ return backend_pb2.Result(success=False, message="Failed to create audio file")
+
+ file_size = os.path.getsize(output_path)
+ if file_size == 0:
+ print("File was created but is empty", file=sys.stderr)
+ return backend_pb2.Result(success=False, message="Generated audio file is empty")
+
+ print(f"Audio file size: {file_size} bytes", file=sys.stderr)
+
+ except Exception as write_err:
+ print(f"Error writing audio file: {write_err}", file=sys.stderr)
+ return backend_pb2.Result(success=False, message=f"Failed to save audio: {write_err}")
+
+ return backend_pb2.Result(success=True, message=f"TTS audio generated successfully: {output_path}")
+
+ except Exception as e:
+ print(f"Error in MLX-Audio TTS: {e}", file=sys.stderr)
+ return backend_pb2.Result(success=False, message=f"TTS generation failed: {str(e)}")
+
+ async def Predict(self, request, context):
+ """
+ Generates TTS audio based on the given prompt using MLX-Audio TTS.
+ This is a fallback method for compatibility with the Predict endpoint.
+
+ Args:
+ request: The predict request.
+ context: The gRPC context.
+
+ Returns:
+ backend_pb2.Reply: The predict result.
+ """
+ try:
+ # Check if model is loaded
+ if not hasattr(self, 'tts_model') or self.tts_model is None:
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
+ context.set_details("TTS model not loaded. Please call LoadModel first.")
+ return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
+
+ # For TTS, we expect the prompt to contain the text to synthesize
+ if not request.Prompt:
+ context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
+ context.set_details("Prompt is required for TTS generation")
+ return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
+
+ # Handle speed parameter based on model type
+ speed_value = self._handle_speed_parameter(request, self.model_path)
+
+ # Map language names to codes if needed
+ lang_code = self._map_language_code(None, None) # Use defaults for Predict
+
+ # Prepare generation parameters
+ gen_params = {
+ "text": request.Prompt,
+ "speed": speed_value,
+ "verbose": False,
+ }
+
+ # Add model-specific parameters
+ if hasattr(self, 'options') and 'voice' in self.options:
+ gen_params["voice"] = self.options['voice']
+
+ # Check if model supports language codes (primarily Kokoro)
+ if "kokoro" in self.model_path.lower():
+ gen_params["lang_code"] = lang_code
+
+ print(f"Generating TTS with MLX-Audio - text: {request.Prompt[:50]}..., params: {gen_params}", file=sys.stderr)
+
+ # Generate audio using the loaded model
+ try:
+ results = self.tts_model.generate(**gen_params)
+ except Exception as gen_err:
+ print(f"Error during TTS generation: {gen_err}", file=sys.stderr)
+ context.set_code(grpc.StatusCode.INTERNAL)
+ context.set_details(f"TTS generation failed: {gen_err}")
+ return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
+
+ # Process the generated audio segments
+ audio_arrays = []
+ for segment in results:
+ audio_arrays.append(segment.audio)
+
+ # If no segments, return error
+ if not audio_arrays:
+ print("No audio segments generated", file=sys.stderr)
+ return backend_pb2.Reply(message=bytes("No audio generated", encoding='utf-8'))
+
+ # Concatenate all segments
+ cat_audio = np.concatenate(audio_arrays, axis=0)
+ duration = len(cat_audio) / 24000 # Assuming 24kHz sample rate
+
+ # Return success message with audio information
+ response = f"TTS audio generated successfully. Duration: {duration:.2f}s, Sample rate: 24000Hz"
+ return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))
+
+ except Exception as e:
+ print(f"Error in MLX-Audio TTS Predict: {e}", file=sys.stderr)
+ context.set_code(grpc.StatusCode.INTERNAL)
+ context.set_details(f"TTS generation failed: {str(e)}")
+ return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
+
+ def _handle_speed_parameter(self, request, model_path):
+ """
+ Handle speed parameter based on model type.
+
+ Args:
+ request: The TTSRequest object.
+ model_path: The model path to determine model type.
+
+ Returns:
+ float: The processed speed value.
+ """
+ # Get speed from options if available
+ speed = 1.0
+ if hasattr(self, 'options') and 'speed' in self.options:
+ speed = self.options['speed']
+
+ # Handle speed parameter based on model type
+ if "spark" in model_path.lower():
+ # Spark actually expects float values that map to speed descriptions
+ speed_map = {
+ "very_low": 0.0,
+ "low": 0.5,
+ "moderate": 1.0,
+ "high": 1.5,
+ "very_high": 2.0,
+ }
+ if isinstance(speed, str) and speed in speed_map:
+ speed_value = speed_map[speed]
+ else:
+ # Try to use as float, default to 1.0 (moderate) if invalid
+ try:
+ speed_value = float(speed)
+ if speed_value not in [0.0, 0.5, 1.0, 1.5, 2.0]:
+ speed_value = 1.0 # Default to moderate
+ except:
+ speed_value = 1.0 # Default to moderate
+ else:
+ # Other models use float speed values
+ try:
+ speed_value = float(speed)
+ if speed_value < 0.5 or speed_value > 2.0:
+ speed_value = 1.0 # Default to 1.0 if out of range
+ except ValueError:
+ speed_value = 1.0 # Default to 1.0 if invalid
+
+ return speed_value
+
+ def _map_language_code(self, language, voice):
+ """
+ Map language names to codes if needed.
+
+ Args:
+ language: The language parameter from the request.
+ voice: The voice parameter from the request.
+
+ Returns:
+ str: The language code.
+ """
+ if not language:
+ # Default to voice[0] if not found
+ return voice[0] if voice else "a"
+
+ # Map language names to codes if needed
+ language_map = {
+ "american_english": "a",
+ "british_english": "b",
+ "spanish": "e",
+ "french": "f",
+ "hindi": "h",
+ "italian": "i",
+ "portuguese": "p",
+ "japanese": "j",
+ "mandarin_chinese": "z",
+ # Also accept direct language codes
+ "a": "a", "b": "b", "e": "e", "f": "f", "h": "h", "i": "i", "p": "p", "j": "j", "z": "z",
+ }
+
+ return language_map.get(language.lower(), language)
+
+ def _build_generation_params(self, request, default_speed=1.0):
+ """
+ Build generation parameters from request attributes and options for MLX-Audio TTS.
+
+ Args:
+ request: The gRPC request.
+ default_speed: Default speed if not specified.
+
+ Returns:
+ dict: Generation parameters for MLX-Audio
+ """
+ # Initialize generation parameters for MLX-Audio TTS
+ generation_params = {
+ 'speed': default_speed,
+ 'voice': 'af_heart', # Default voice
+ 'lang_code': 'a', # Default language code
+ }
+
+ # Extract parameters from request attributes
+ if hasattr(request, 'Temperature') and request.Temperature > 0:
+ # Temperature could be mapped to speed variation
+ generation_params['speed'] = 1.0 + (request.Temperature - 0.5) * 0.5
+
+ # Override with options if available
+ if hasattr(self, 'options'):
+ # Speed from options
+ if 'speed' in self.options:
+ generation_params['speed'] = self.options['speed']
+
+ # Voice from options
+ if 'voice' in self.options:
+ generation_params['voice'] = self.options['voice']
+
+ # Language code from options
+ if 'lang_code' in self.options:
+ generation_params['lang_code'] = self.options['lang_code']
+
+ # Model-specific parameters
+ param_option_mapping = {
+ 'temp': 'speed',
+ 'temperature': 'speed',
+ 'top_p': 'speed', # Map top_p to speed variation
+ }
+
+ for option_key, param_key in param_option_mapping.items():
+ if option_key in self.options:
+ if param_key == 'speed':
+ # Ensure speed is within reasonable bounds
+ speed_val = float(self.options[option_key])
+ if 0.5 <= speed_val <= 2.0:
+ generation_params[param_key] = speed_val
+
+ return generation_params
+
+async def serve(address):
+ # Start asyncio gRPC server
+ server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ # Add the servicer to the server
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ # Bind the server to the address
+ server.add_insecure_port(address)
+
+ # Gracefully shutdown the server on SIGTERM or SIGINT
+ loop = asyncio.get_event_loop()
+ for sig in (signal.SIGINT, signal.SIGTERM):
+ loop.add_signal_handler(
+ sig, lambda: asyncio.ensure_future(server.stop(5))
+ )
+
+ # Start the server
+ await server.start()
+ print("MLX-Audio TTS Server started. Listening on: " + address, file=sys.stderr)
+ # Wait for the server to be terminated
+ await server.wait_for_termination()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the MLX-Audio TTS gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ asyncio.run(serve(args.addr))
diff --git a/backend/python/mlx-audio/install.sh b/backend/python/mlx-audio/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b8ee4855249062294a6d518538c00ccefb00dd46
--- /dev/null
+++ b/backend/python/mlx-audio/install.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+set -e
+
+USE_PIP=true
+
+backend_dir=$(dirname $0)
+
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+installRequirements
diff --git a/backend/python/mlx-audio/requirements-mps.txt b/backend/python/mlx-audio/requirements-mps.txt
new file mode 100644
index 0000000000000000000000000000000000000000..31df2a190181daaf9b80af306e7199ca5b61d5c0
--- /dev/null
+++ b/backend/python/mlx-audio/requirements-mps.txt
@@ -0,0 +1 @@
+git+https://github.com/Blaizzy/mlx-audio
\ No newline at end of file
diff --git a/backend/python/mlx-audio/requirements.txt b/backend/python/mlx-audio/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5f47f0cfd87eecba37f59197534d961e4c90c7ed
--- /dev/null
+++ b/backend/python/mlx-audio/requirements.txt
@@ -0,0 +1,7 @@
+grpcio==1.71.0
+protobuf
+certifi
+setuptools
+mlx-audio
+soundfile
+numpy
\ No newline at end of file
diff --git a/backend/python/mlx-audio/run.sh b/backend/python/mlx-audio/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fc88f97da712f14faef73f9e8b96589dd8ecc2ad
--- /dev/null
+++ b/backend/python/mlx-audio/run.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+backend_dir=$(dirname $0)
+
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/mlx-audio/test.py b/backend/python/mlx-audio/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..792cb06480fbef445f98e4a9d2cd28f49213a3f2
--- /dev/null
+++ b/backend/python/mlx-audio/test.py
@@ -0,0 +1,142 @@
+import unittest
+import subprocess
+import time
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+import unittest
+import subprocess
+import time
+import grpc
+import backend_pb2_grpc
+import backend_pb2
+
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service.
+
+ This class contains methods to test the startup and shutdown of the gRPC service.
+ """
+ def setUp(self):
+ self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"])
+ time.sleep(10)
+
+ def tearDown(self) -> None:
+ self.service.terminate()
+ self.service.wait()
+
+ def test_server_startup(self):
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+ def test_load_model(self):
+ """
+ This method tests if the TTS model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit"))
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "MLX-Audio TTS model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_tts_generation(self):
+ """
+ This method tests if TTS audio is generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit"))
+ self.assertTrue(response.success)
+
+ # Test TTS generation
+ tts_req = backend_pb2.TTSRequest(
+ text="Hello, this is a test of the MLX-Audio TTS system.",
+ model="mlx-community/Kokoro-82M-4bit",
+ voice="af_heart",
+ language="a"
+ )
+ tts_resp = stub.TTS(tts_req)
+ self.assertTrue(tts_resp.success)
+ self.assertIn("TTS audio generated successfully", tts_resp.message)
+ except Exception as err:
+ print(err)
+ self.fail("TTS service failed")
+ finally:
+ self.tearDown()
+
+ def test_tts_with_options(self):
+ """
+ This method tests if TTS works with various options and parameters
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(
+ Model="mlx-community/Kokoro-82M-4bit",
+ Options=["voice:af_soft", "speed:1.2", "lang_code:b"]
+ ))
+ self.assertTrue(response.success)
+
+ # Test TTS generation with different voice and language
+ tts_req = backend_pb2.TTSRequest(
+ text="Hello, this is a test with British English accent.",
+ model="mlx-community/Kokoro-82M-4bit",
+ voice="af_soft",
+ language="b"
+ )
+ tts_resp = stub.TTS(tts_req)
+ self.assertTrue(tts_resp.success)
+ self.assertIn("TTS audio generated successfully", tts_resp.message)
+ except Exception as err:
+ print(err)
+ self.fail("TTS with options service failed")
+ finally:
+ self.tearDown()
+
+
+ def test_tts_multilingual(self):
+ """
+ This method tests if TTS works with different languages
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit"))
+ self.assertTrue(response.success)
+
+ # Test Spanish TTS
+ tts_req = backend_pb2.TTSRequest(
+ text="Hola, esto es una prueba del sistema TTS MLX-Audio.",
+ model="mlx-community/Kokoro-82M-4bit",
+ voice="af_heart",
+ language="e"
+ )
+ tts_resp = stub.TTS(tts_req)
+ self.assertTrue(tts_resp.success)
+ self.assertIn("TTS audio generated successfully", tts_resp.message)
+ except Exception as err:
+ print(err)
+ self.fail("Multilingual TTS service failed")
+ finally:
+ self.tearDown()
\ No newline at end of file
diff --git a/backend/python/mlx-audio/test.sh b/backend/python/mlx-audio/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f31ae54e47dc7f5a10f630fa1d7b5c8ea56f0c9e
--- /dev/null
+++ b/backend/python/mlx-audio/test.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/mlx-vlm/Makefile b/backend/python/mlx-vlm/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..804031aa970dcb86de40d9d68458ca2a4e176c5f
--- /dev/null
+++ b/backend/python/mlx-vlm/Makefile
@@ -0,0 +1,23 @@
+.PHONY: mlx-vlm
+mlx-vlm:
+ bash install.sh
+
+.PHONY: run
+run: mlx-vlm
+ @echo "Running mlx-vlm..."
+ bash run.sh
+ @echo "mlx run."
+
+.PHONY: test
+test: mlx-vlm
+ @echo "Testing mlx-vlm..."
+ bash test.sh
+ @echo "mlx tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/mlx-vlm/backend.py b/backend/python/mlx-vlm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c5f8b1896c51627495247023e65c5332f9bc67e
--- /dev/null
+++ b/backend/python/mlx-vlm/backend.py
@@ -0,0 +1,475 @@
+#!/usr/bin/env python3
+import asyncio
+from concurrent import futures
+import argparse
+import signal
+import sys
+import os
+from typing import List
+import time
+
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+from mlx_vlm import load, generate, stream_generate
+from mlx_vlm.prompt_utils import apply_chat_template
+from mlx_vlm.utils import load_config, load_image
+import mlx.core as mx
+import base64
+import io
+from PIL import Image
+import tempfile
+
+def is_float(s):
+ """Check if a string can be converted to float."""
+ try:
+ float(s)
+ return True
+ except ValueError:
+ return False
+def is_int(s):
+ """Check if a string can be converted to int."""
+ try:
+ int(s)
+ return True
+ except ValueError:
+ return False
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ A gRPC servicer that implements the Backend service defined in backend.proto.
+ """
+
+ def Health(self, request, context):
+ """
+ Returns a health check message.
+
+ Args:
+ request: The health check request.
+ context: The gRPC context.
+
+ Returns:
+ backend_pb2.Reply: The health check reply.
+ """
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+ async def LoadModel(self, request, context):
+ """
+ Loads a multimodal vision-language model using MLX-VLM.
+
+ Args:
+ request: The load model request.
+ context: The gRPC context.
+
+ Returns:
+ backend_pb2.Result: The load model result.
+ """
+ try:
+ print(f"Loading MLX-VLM model: {request.Model}", file=sys.stderr)
+ print(f"Request: {request}", file=sys.stderr)
+
+ # Parse options like in the diffusers backend
+ options = request.Options
+ self.options = {}
+
+ # The options are a list of strings in this form optname:optvalue
+ # We store all the options in a dict for later use
+ for opt in options:
+ if ":" not in opt:
+ continue
+ key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
+
+ if is_float(value):
+ value = float(value)
+ elif is_int(value):
+ value = int(value)
+ elif value.lower() in ["true", "false"]:
+ value = value.lower() == "true"
+
+ self.options[key] = value
+
+ print(f"Options: {self.options}", file=sys.stderr)
+
+ # Load model and processor using MLX-VLM
+ # mlx-vlm load function returns (model, processor) instead of (model, tokenizer)
+ self.model, self.processor = load(request.Model)
+
+ # Load model config for chat template support
+ self.config = load_config(request.Model)
+
+ except Exception as err:
+ print(f"Error loading MLX-VLM model {err=}, {type(err)=}", file=sys.stderr)
+ return backend_pb2.Result(success=False, message=f"Error loading MLX-VLM model: {err}")
+
+ print("MLX-VLM model loaded successfully", file=sys.stderr)
+ return backend_pb2.Result(message="MLX-VLM model loaded successfully", success=True)
+
+ async def Predict(self, request, context):
+ """
+ Generates text based on the given prompt and sampling parameters using MLX-VLM with multimodal support.
+
+ Args:
+ request: The predict request.
+ context: The gRPC context.
+
+ Returns:
+ backend_pb2.Reply: The predict result.
+ """
+ temp_files = []
+ try:
+ # Process images and audios from request
+ image_paths = []
+ audio_paths = []
+
+ # Process images
+ if request.Images:
+ for img_data in request.Images:
+ img_path = self.load_image_from_base64(img_data)
+ if img_path:
+ image_paths.append(img_path)
+ temp_files.append(img_path)
+
+ # Process audios
+ if request.Audios:
+ for audio_data in request.Audios:
+ audio_path = self.load_audio_from_base64(audio_data)
+ if audio_path:
+ audio_paths.append(audio_path)
+ temp_files.append(audio_path)
+
+ # Prepare the prompt with multimodal information
+ prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths))
+
+ # Build generation parameters using request attributes and options
+ max_tokens, generation_params = self._build_generation_params(request)
+
+ print(f"Generating text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr)
+ print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr)
+
+ # Generate text using MLX-VLM with multimodal inputs
+ response = generate(
+ model=self.model,
+ processor=self.processor,
+ prompt=prompt,
+ image=image_paths if image_paths else None,
+ audio=audio_paths if audio_paths else None,
+ max_tokens=max_tokens,
+ temperature=generation_params.get('temp', 0.6),
+ top_p=generation_params.get('top_p', 1.0),
+ verbose=False
+ )
+
+ return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))
+
+ except Exception as e:
+ print(f"Error in MLX-VLM Predict: {e}", file=sys.stderr)
+ context.set_code(grpc.StatusCode.INTERNAL)
+ context.set_details(f"Generation failed: {str(e)}")
+ return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
+ finally:
+ # Clean up temporary files
+ self.cleanup_temp_files(temp_files)
+
+ def Embedding(self, request, context):
+ """
+ A gRPC method that calculates embeddings for a given sentence.
+
+ Note: MLX-VLM doesn't support embeddings directly. This method returns an error.
+
+ Args:
+ request: An EmbeddingRequest object that contains the request parameters.
+ context: A grpc.ServicerContext object that provides information about the RPC.
+
+ Returns:
+ An EmbeddingResult object that contains the calculated embeddings.
+ """
+ print("Embeddings not supported in MLX-VLM backend", file=sys.stderr)
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details("Embeddings are not supported in the MLX-VLM backend.")
+ return backend_pb2.EmbeddingResult()
+
+ async def PredictStream(self, request, context):
+ """
+ Generates text based on the given prompt and sampling parameters, and streams the results using MLX-VLM with multimodal support.
+
+ Args:
+ request: The predict stream request.
+ context: The gRPC context.
+
+ Yields:
+ backend_pb2.Reply: Streaming predict results.
+ """
+ temp_files = []
+ try:
+ # Process images and audios from request
+ image_paths = []
+ audio_paths = []
+
+ # Process images
+ if request.Images:
+ for img_data in request.Images:
+ img_path = self.load_image_from_base64(img_data)
+ if img_path:
+ image_paths.append(img_path)
+ temp_files.append(img_path)
+
+ # Process audios
+ if request.Audios:
+ for audio_data in request.Audios:
+ audio_path = self.load_audio_from_base64(audio_data)
+ if audio_path:
+ audio_paths.append(audio_path)
+ temp_files.append(audio_path)
+
+ # Prepare the prompt with multimodal information
+ prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths))
+
+ # Build generation parameters using request attributes and options
+ max_tokens, generation_params = self._build_generation_params(request, default_max_tokens=512)
+
+ print(f"Streaming text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr)
+ print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr)
+
+ # Stream text generation using MLX-VLM with multimodal inputs
+ for response in stream_generate(
+ model=self.model,
+ processor=self.processor,
+ prompt=prompt,
+ image=image_paths if image_paths else None,
+ audio=audio_paths if audio_paths else None,
+ max_tokens=max_tokens,
+ temperature=generation_params.get('temp', 0.6),
+ top_p=generation_params.get('top_p', 1.0),
+ ):
+ yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
+
+ except Exception as e:
+ print(f"Error in MLX-VLM PredictStream: {e}", file=sys.stderr)
+ context.set_code(grpc.StatusCode.INTERNAL)
+ context.set_details(f"Streaming generation failed: {str(e)}")
+ yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))
+ finally:
+ # Clean up temporary files
+ self.cleanup_temp_files(temp_files)
+
+ def _prepare_prompt(self, request, num_images=0, num_audios=0):
+ """
+ Prepare the prompt for MLX-VLM generation, handling chat templates and multimodal inputs.
+
+ Args:
+ request: The gRPC request containing prompt and message information.
+ num_images: Number of images in the request.
+ num_audios: Number of audio files in the request.
+
+ Returns:
+ str: The prepared prompt.
+ """
+ # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
+ if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
+ # Convert gRPC messages to the format expected by apply_chat_template
+ messages = []
+ for msg in request.Messages:
+ messages.append({"role": msg.role, "content": msg.content})
+
+ # Use mlx-vlm's apply_chat_template which handles multimodal inputs
+ prompt = apply_chat_template(
+ self.processor,
+ self.config,
+ messages,
+ num_images=num_images,
+ num_audios=num_audios
+ )
+ return prompt
+ elif request.Prompt:
+ # If we have a direct prompt but also have images/audio, we need to format it properly
+ if num_images > 0 or num_audios > 0:
+ # Create a simple message structure for multimodal prompt
+ messages = [{"role": "user", "content": request.Prompt}]
+ prompt = apply_chat_template(
+ self.processor,
+ self.config,
+ messages,
+ num_images=num_images,
+ num_audios=num_audios
+ )
+ return prompt
+ else:
+ return request.Prompt
+ else:
+ # Fallback to empty prompt with multimodal template if we have media
+ if num_images > 0 or num_audios > 0:
+ messages = [{"role": "user", "content": ""}]
+ prompt = apply_chat_template(
+ self.processor,
+ self.config,
+ messages,
+ num_images=num_images,
+ num_audios=num_audios
+ )
+ return prompt
+ else:
+ return ""
+
+
+
+
+
+ def _build_generation_params(self, request, default_max_tokens=200):
+ """
+ Build generation parameters from request attributes and options for MLX-VLM.
+
+ Args:
+ request: The gRPC request.
+ default_max_tokens: Default max_tokens if not specified.
+
+ Returns:
+ tuple: (max_tokens, generation_params dict)
+ """
+ # Extract max_tokens
+ max_tokens = getattr(request, 'Tokens', default_max_tokens)
+ if max_tokens == 0:
+ max_tokens = default_max_tokens
+
+ # Extract generation parameters from request attributes
+ temp = getattr(request, 'Temperature', 0.0)
+ if temp == 0.0:
+ temp = 0.6 # Default temperature
+
+ top_p = getattr(request, 'TopP', 0.0)
+ if top_p == 0.0:
+ top_p = 1.0 # Default top_p
+
+ # Initialize generation parameters for MLX-VLM
+ generation_params = {
+ 'temp': temp,
+ 'top_p': top_p,
+ }
+
+ # Add seed if specified
+ seed = getattr(request, 'Seed', 0)
+ if seed != 0:
+ mx.random.seed(seed)
+
+ # Override with options if available
+ if hasattr(self, 'options'):
+ # Max tokens from options
+ if 'max_tokens' in self.options:
+ max_tokens = self.options['max_tokens']
+
+ # Generation parameters from options
+ param_option_mapping = {
+ 'temp': 'temp',
+ 'temperature': 'temp', # alias
+ 'top_p': 'top_p',
+ }
+
+ for option_key, param_key in param_option_mapping.items():
+ if option_key in self.options:
+ generation_params[param_key] = self.options[option_key]
+
+ # Handle seed from options
+ if 'seed' in self.options:
+ mx.random.seed(self.options['seed'])
+
+ return max_tokens, generation_params
+
+ def load_image_from_base64(self, image_data: str):
+ """
+ Load an image from base64 encoded data.
+
+ Args:
+ image_data (str): Base64 encoded image data.
+
+ Returns:
+ PIL.Image or str: The loaded image or path to the image.
+ """
+ try:
+ decoded_data = base64.b64decode(image_data)
+ image = Image.open(io.BytesIO(decoded_data))
+
+ # Save to temporary file for mlx-vlm
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
+ image.save(tmp_file.name, format='JPEG')
+ return tmp_file.name
+
+ except Exception as e:
+ print(f"Error loading image from base64: {e}", file=sys.stderr)
+ return None
+
+ def load_audio_from_base64(self, audio_data: str):
+ """
+ Load audio from base64 encoded data.
+
+ Args:
+ audio_data (str): Base64 encoded audio data.
+
+ Returns:
+ str: Path to the loaded audio file.
+ """
+ try:
+ decoded_data = base64.b64decode(audio_data)
+
+ # Save to temporary file for mlx-vlm
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
+ tmp_file.write(decoded_data)
+ return tmp_file.name
+
+ except Exception as e:
+ print(f"Error loading audio from base64: {e}", file=sys.stderr)
+ return None
+
+ def cleanup_temp_files(self, file_paths: List[str]):
+ """
+ Clean up temporary files.
+
+ Args:
+ file_paths (List[str]): List of file paths to clean up.
+ """
+ for file_path in file_paths:
+ try:
+ if file_path and os.path.exists(file_path):
+ os.remove(file_path)
+ except Exception as e:
+ print(f"Error removing temporary file {file_path}: {e}", file=sys.stderr)
+
+async def serve(address):
+ # Start asyncio gRPC server
+ server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ # Add the servicer to the server
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ # Bind the server to the address
+ server.add_insecure_port(address)
+
+ # Gracefully shutdown the server on SIGTERM or SIGINT
+ loop = asyncio.get_event_loop()
+ for sig in (signal.SIGINT, signal.SIGTERM):
+ loop.add_signal_handler(
+ sig, lambda: asyncio.ensure_future(server.stop(5))
+ )
+
+ # Start the server
+ await server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+ # Wait for the server to be terminated
+ await server.wait_for_termination()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ asyncio.run(serve(args.addr))
diff --git a/backend/python/mlx-vlm/install.sh b/backend/python/mlx-vlm/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b8ee4855249062294a6d518538c00ccefb00dd46
--- /dev/null
+++ b/backend/python/mlx-vlm/install.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+set -e
+
+USE_PIP=true
+
+backend_dir=$(dirname $0)
+
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+installRequirements
diff --git a/backend/python/mlx-vlm/requirements-mps.txt b/backend/python/mlx-vlm/requirements-mps.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8737f6091c70f987fd28bd5748ee7dcf198e8320
--- /dev/null
+++ b/backend/python/mlx-vlm/requirements-mps.txt
@@ -0,0 +1 @@
+git+https://github.com/Blaizzy/mlx-vlm
\ No newline at end of file
diff --git a/backend/python/mlx-vlm/requirements.txt b/backend/python/mlx-vlm/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f1771cc4adb4b4be9ddfb26acb959beb8278f178
--- /dev/null
+++ b/backend/python/mlx-vlm/requirements.txt
@@ -0,0 +1,4 @@
+grpcio==1.71.0
+protobuf
+certifi
+setuptools
\ No newline at end of file
diff --git a/backend/python/mlx-vlm/run.sh b/backend/python/mlx-vlm/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fc88f97da712f14faef73f9e8b96589dd8ecc2ad
--- /dev/null
+++ b/backend/python/mlx-vlm/run.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+backend_dir=$(dirname $0)
+
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/mlx-vlm/test.py b/backend/python/mlx-vlm/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..827aa71a3e33132b75d77a2c192a4000699b7042
--- /dev/null
+++ b/backend/python/mlx-vlm/test.py
@@ -0,0 +1,146 @@
+import unittest
+import subprocess
+import time
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+import unittest
+import subprocess
+import time
+import grpc
+import backend_pb2_grpc
+import backend_pb2
+
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service.
+
+ This class contains methods to test the startup and shutdown of the gRPC service.
+ """
+ def setUp(self):
+ self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"])
+ time.sleep(10)
+
+ def tearDown(self) -> None:
+ self.service.terminate()
+ self.service.wait()
+
+ def test_server_startup(self):
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+ def test_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_text(self):
+ """
+ This method tests if the embeddings are generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
+ self.assertTrue(response.success)
+ req = backend_pb2.PredictOptions(Prompt="The capital of France is")
+ resp = stub.Predict(req)
+ self.assertIsNotNone(resp.message)
+ except Exception as err:
+ print(err)
+ self.fail("text service failed")
+ finally:
+ self.tearDown()
+
+ def test_sampling_params(self):
+ """
+ This method tests if all sampling parameters are correctly processed
+ NOTE: this does NOT test for correctness, just that we received a compatible response
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
+ self.assertTrue(response.success)
+
+ req = backend_pb2.PredictOptions(
+ Prompt="The capital of France is",
+ TopP=0.8,
+ Tokens=50,
+ Temperature=0.7,
+ TopK=40,
+ PresencePenalty=0.1,
+ FrequencyPenalty=0.2,
+ RepetitionPenalty=1.1,
+ MinP=0.05,
+ Seed=42,
+ StopPrompts=["\n"],
+ StopTokenIds=[50256],
+ BadWords=["badword"],
+ IncludeStopStrInOutput=True,
+ IgnoreEOS=True,
+ MinTokens=5,
+ Logprobs=5,
+ PromptLogprobs=5,
+ SkipSpecialTokens=True,
+ SpacesBetweenSpecialTokens=True,
+ TruncatePromptTokens=10,
+ GuidedDecoding=True,
+ N=2,
+ )
+ resp = stub.Predict(req)
+ self.assertIsNotNone(resp.message)
+ self.assertIsNotNone(resp.logprobs)
+ except Exception as err:
+ print(err)
+ self.fail("sampling params service failed")
+ finally:
+ self.tearDown()
+
+
+ def test_embedding(self):
+ """
+ This method tests if the embeddings are generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="intfloat/e5-mistral-7b-instruct"))
+ self.assertTrue(response.success)
+ embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
+ embedding_response = stub.Embedding(embedding_request)
+ self.assertIsNotNone(embedding_response.embeddings)
+ # assert that is a list of floats
+ self.assertIsInstance(embedding_response.embeddings, list)
+ # assert that the list is not empty
+ self.assertTrue(len(embedding_response.embeddings) > 0)
+ except Exception as err:
+ print(err)
+ self.fail("Embedding service failed")
+ finally:
+ self.tearDown()
\ No newline at end of file
diff --git a/backend/python/mlx-vlm/test.sh b/backend/python/mlx-vlm/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f31ae54e47dc7f5a10f630fa1d7b5c8ea56f0c9e
--- /dev/null
+++ b/backend/python/mlx-vlm/test.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/mlx/Makefile b/backend/python/mlx/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..06f3bf614854433c5750e698492df87436317ba4
--- /dev/null
+++ b/backend/python/mlx/Makefile
@@ -0,0 +1,23 @@
+.PHONY: mlx
+mlx:
+ bash install.sh
+
+.PHONY: run
+run:
+ @echo "Running mlx..."
+ bash run.sh
+ @echo "mlx run."
+
+.PHONY: test
+test:
+ @echo "Testing mlx..."
+ bash test.sh
+ @echo "mlx tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/mlx/backend.py b/backend/python/mlx/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaa0d6f347f8e50fde0b37d50ce29441f0c50a22
--- /dev/null
+++ b/backend/python/mlx/backend.py
@@ -0,0 +1,450 @@
+#!/usr/bin/env python3
+import asyncio
+from concurrent import futures
+import argparse
+import signal
+import sys
+import os
+from typing import List
+import time
+
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+from mlx_lm import load, generate, stream_generate
+from mlx_lm.sample_utils import make_sampler
+from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
+import mlx.core as mx
+import base64
+import io
+
+from mlx_cache import ThreadSafeLRUPromptCache
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+def is_float(s):
+ """Check if a string can be converted to float."""
+ try:
+ float(s)
+ return True
+ except ValueError:
+ return False
+def is_int(s):
+ """Check if a string can be converted to int."""
+ try:
+ int(s)
+ return True
+ except ValueError:
+ return False
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ A gRPC servicer that implements the Backend service defined in backend.proto.
+ """
+
+ def Health(self, request, context):
+ """
+ Returns a health check message.
+
+ Args:
+ request: The health check request.
+ context: The gRPC context.
+
+ Returns:
+ backend_pb2.Reply: The health check reply.
+ """
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+ async def LoadModel(self, request, context):
+ """
+ Loads a language model using MLX.
+
+ Args:
+ request: The load model request.
+ context: The gRPC context.
+
+ Returns:
+ backend_pb2.Result: The load model result.
+ """
+ try:
+ print(f"Loading MLX model: {request.Model}", file=sys.stderr)
+ print(f"Request: {request}", file=sys.stderr)
+
+ # Parse options like in the diffusers backend
+ options = request.Options
+ self.options = {}
+
+ # The options are a list of strings in this form optname:optvalue
+ # We store all the options in a dict for later use
+ for opt in options:
+ if ":" not in opt:
+ continue
+ key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
+
+ # Convert numeric values to appropriate types
+ if is_float(value):
+ value = float(value)
+ elif is_int(value):
+ value = int(value)
+ elif value.lower() in ["true", "false"]:
+ value = value.lower() == "true"
+
+ self.options[key] = value
+
+ print(f"Options: {self.options}", file=sys.stderr)
+
+ # Build tokenizer config for MLX using options
+ tokenizer_config = {}
+
+ # Handle trust_remote_code from request or options
+ if request.TrustRemoteCode or self.options.get("trust_remote_code", False):
+ tokenizer_config["trust_remote_code"] = True
+
+ # Handle EOS token from options
+ if "eos_token" in self.options:
+ tokenizer_config["eos_token"] = self.options["eos_token"]
+
+ # Handle other tokenizer config options
+ for key in ["pad_token", "bos_token", "unk_token", "sep_token", "cls_token", "mask_token"]:
+ if key in self.options:
+ tokenizer_config[key] = self.options[key]
+
+ # Load model and tokenizer using MLX
+ if tokenizer_config:
+ print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr)
+ self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config)
+ else:
+ self.model, self.tokenizer = load(request.Model)
+
+ # Initialize thread-safe LRU prompt cache for efficient generation
+ max_cache_entries = self.options.get("max_cache_entries", 10)
+ self.max_kv_size = self.options.get("max_kv_size", None)
+ self.model_key = request.Model
+ self.lru_cache = ThreadSafeLRUPromptCache(
+ max_size=max_cache_entries,
+ can_trim_fn=can_trim_prompt_cache,
+ trim_fn=trim_prompt_cache,
+ )
+
+ except Exception as err:
+ print(f"Error loading MLX model {err=}, {type(err)=}", file=sys.stderr)
+ return backend_pb2.Result(success=False, message=f"Error loading MLX model: {err}")
+
+ print("MLX model loaded successfully", file=sys.stderr)
+ return backend_pb2.Result(message="MLX model loaded successfully", success=True)
+
+ async def Predict(self, request, context):
+ """
+ Generates text based on the given prompt and sampling parameters using MLX.
+
+ Uses thread-safe LRU prompt cache for efficient prefix reuse across requests.
+
+ Args:
+ request: The predict request.
+ context: The gRPC context.
+
+ Returns:
+ backend_pb2.Reply: The predict result.
+ """
+ prompt_cache = None
+ cache_key = None
+
+ try:
+ # Prepare the prompt and tokenize for cache key
+ prompt_text = self._prepare_prompt(request)
+ cache_key = self._get_tokens_from_prompt(prompt_text)
+
+ # Fetch nearest cache (exact, shorter prefix, or create new)
+ prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
+ self.model_key, cache_key
+ )
+ if prompt_cache is None:
+ prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
+ remaining_tokens = cache_key
+
+ # Build generation parameters using request attributes and options
+ max_tokens, sampler_params = self._build_generation_params(request)
+
+ print(f"Generating text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)
+
+ # Create sampler with parameters
+ sampler = make_sampler(**sampler_params)
+
+ # Use stream_generate to track generated tokens for cache key
+ generated_text = []
+ for response in stream_generate(
+ self.model,
+ self.tokenizer,
+ prompt=remaining_tokens if remaining_tokens else cache_key,
+ max_tokens=max_tokens,
+ sampler=sampler,
+ prompt_cache=prompt_cache,
+ ):
+ generated_text.append(response.text)
+ cache_key.append(response.token)
+
+ # Insert completed cache
+ self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
+
+ return backend_pb2.Reply(message=bytes(''.join(generated_text), encoding='utf-8'))
+
+ except Exception as e:
+ print(f"Error in MLX Predict: {e}", file=sys.stderr)
+ context.set_code(grpc.StatusCode.INTERNAL)
+ context.set_details(f"Generation failed: {str(e)}")
+ return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
+
+ def Embedding(self, request, context):
+ """
+ A gRPC method that calculates embeddings for a given sentence.
+
+ Note: MLX-LM doesn't support embeddings directly. This method returns an error.
+
+ Args:
+ request: An EmbeddingRequest object that contains the request parameters.
+ context: A grpc.ServicerContext object that provides information about the RPC.
+
+ Returns:
+ An EmbeddingResult object that contains the calculated embeddings.
+ """
+ print("Embeddings not supported in MLX backend", file=sys.stderr)
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details("Embeddings are not supported in the MLX backend.")
+ return backend_pb2.EmbeddingResult()
+
+ async def PredictStream(self, request, context):
+ """
+ Generates text based on the given prompt and sampling parameters, and streams the results using MLX.
+
+ Uses thread-safe LRU prompt cache for efficient prefix reuse across requests.
+
+ Args:
+ request: The predict stream request.
+ context: The gRPC context.
+
+ Yields:
+ backend_pb2.Reply: Streaming predict results.
+ """
+ prompt_cache = None
+ cache_key = None
+
+ try:
+ # Prepare the prompt and tokenize for cache key
+ prompt_text = self._prepare_prompt(request)
+ cache_key = self._get_tokens_from_prompt(prompt_text)
+
+ # Fetch nearest cache (exact, shorter prefix, or create new)
+ prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
+ self.model_key, cache_key
+ )
+ if prompt_cache is None:
+ prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
+ remaining_tokens = cache_key
+
+ # Build generation parameters using request attributes and options
+ max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)
+
+ print(f"Streaming text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)
+
+ # Create sampler with parameters
+ sampler = make_sampler(**sampler_params)
+
+ # Stream text generation using MLX with proper parameters
+ for response in stream_generate(
+ self.model,
+ self.tokenizer,
+ prompt=remaining_tokens if remaining_tokens else cache_key,
+ max_tokens=max_tokens,
+ sampler=sampler,
+ prompt_cache=prompt_cache,
+ ):
+ cache_key.append(response.token)
+ yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
+
+ except Exception as e:
+ print(f"Error in MLX PredictStream: {e}", file=sys.stderr)
+ context.set_code(grpc.StatusCode.INTERNAL)
+ context.set_details(f"Streaming generation failed: {str(e)}")
+ yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))
+
+ finally:
+ # Always insert cache, even on interruption
+ if prompt_cache is not None and cache_key is not None:
+ try:
+ self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
+ except Exception as e:
+ print(f"Error inserting cache: {e}", file=sys.stderr)
+
+ def _prepare_prompt(self, request):
+ """
+ Prepare the prompt for MLX generation, handling chat templates if needed.
+
+ Args:
+ request: The gRPC request containing prompt and message information.
+
+ Returns:
+ str: The prepared prompt.
+ """
+ # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
+ if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
+ # Convert gRPC messages to the format expected by apply_chat_template
+ messages = []
+ for msg in request.Messages:
+ messages.append({"role": msg.role, "content": msg.content})
+
+ prompt = self.tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True
+ )
+ return prompt
+ else:
+ return request.Prompt
+
+ def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
+ """
+ Tokenize prompt text for cache key generation.
+
+ Args:
+ prompt_text: The prompt string to tokenize.
+
+ Returns:
+ List[int]: List of token IDs.
+ """
+ tokens = self.tokenizer.encode(prompt_text)
+ if hasattr(tokens, 'tolist'):
+ return tokens.tolist()
+ return list(tokens)
+
+
+
+
+
+ def _build_generation_params(self, request, default_max_tokens=200):
+ """
+ Build generation parameters from request attributes and options.
+
+ Args:
+ request: The gRPC request.
+ default_max_tokens: Default max_tokens if not specified.
+
+ Returns:
+ tuple: (max_tokens, sampler_params dict)
+ """
+ # Extract max_tokens
+ max_tokens = getattr(request, 'Tokens', default_max_tokens)
+ if max_tokens == 0:
+ max_tokens = default_max_tokens
+
+ # Extract sampler parameters from request attributes
+ temp = getattr(request, 'Temperature', 0.0)
+ if temp == 0.0:
+ temp = 0.6 # Default temperature
+
+ top_p = getattr(request, 'TopP', 0.0)
+ if top_p == 0.0:
+ top_p = 1.0 # Default top_p
+
+ min_p = getattr(request, 'MinP', 0.0)
+ # min_p default of 0.0 means disabled (no filtering)
+
+ top_k = getattr(request, 'TopK', 0)
+ # top_k default of 0 means disabled (no filtering)
+
+ # Initialize sampler parameters
+ sampler_params = {
+ 'temp': temp,
+ 'top_p': top_p,
+ 'min_p': min_p,
+ 'top_k': top_k,
+ 'xtc_threshold': 0.0,
+ 'xtc_probability': 0.0,
+ }
+
+ # Add seed if specified
+ seed = getattr(request, 'Seed', 0)
+ if seed != 0:
+ mx.random.seed(seed)
+
+ # Override with options if available
+ if hasattr(self, 'options'):
+ # Max tokens from options
+ if 'max_tokens' in self.options:
+ max_tokens = self.options['max_tokens']
+
+ # Sampler parameters from options
+ sampler_option_mapping = {
+ 'temp': 'temp',
+ 'temperature': 'temp', # alias
+ 'top_p': 'top_p',
+ 'min_p': 'min_p',
+ 'top_k': 'top_k',
+ 'xtc_threshold': 'xtc_threshold',
+ 'xtc_probability': 'xtc_probability',
+ }
+
+ for option_key, param_key in sampler_option_mapping.items():
+ if option_key in self.options:
+ sampler_params[param_key] = self.options[option_key]
+
+ # Handle seed from options
+ if 'seed' in self.options:
+ mx.random.seed(self.options['seed'])
+
+ # Special tokens for XTC sampling (if tokenizer has eos_token_ids)
+ xtc_special_tokens = []
+ if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
+ xtc_special_tokens = list(self.tokenizer.eos_token_ids)
+ elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
+ xtc_special_tokens = [self.tokenizer.eos_token_id]
+
+ # Add newline token if available
+ try:
+ newline_tokens = self.tokenizer.encode("\n")
+ xtc_special_tokens.extend(newline_tokens)
+ except:
+ pass # Skip if encoding fails
+
+ sampler_params['xtc_special_tokens'] = xtc_special_tokens
+
+ return max_tokens, sampler_params
+
+async def serve(address):
+ # Start asyncio gRPC server
+ server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ # Add the servicer to the server
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ # Bind the server to the address
+ server.add_insecure_port(address)
+
+ # Gracefully shutdown the server on SIGTERM or SIGINT
+ loop = asyncio.get_event_loop()
+ for sig in (signal.SIGINT, signal.SIGTERM):
+ loop.add_signal_handler(
+ sig, lambda: asyncio.ensure_future(server.stop(5))
+ )
+
+ # Start the server
+ await server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+ # Wait for the server to be terminated
+ await server.wait_for_termination()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ asyncio.run(serve(args.addr))
diff --git a/backend/python/mlx/install.sh b/backend/python/mlx/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..253ee0c13f1b0c4508a5c934cc80c9b15040bf38
--- /dev/null
+++ b/backend/python/mlx/install.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+set -e
+
+USE_PIP=true
+PYTHON_VERSION=""
+
+backend_dir=$(dirname $0)
+
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+installRequirements
diff --git a/backend/python/mlx/mlx_cache.py b/backend/python/mlx/mlx_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ec2bb9baabbfe04d3825708e03fc17b9bbd3645
--- /dev/null
+++ b/backend/python/mlx/mlx_cache.py
@@ -0,0 +1,266 @@
+"""
+Thread-safe LRU prompt cache for MLX-based backends.
+
+Ported from mlx_lm/server.py (MIT License, Copyright 2023-2024 Apple Inc.)
+with thread-safety additions for LocalAI's gRPC backend.
+
+Usage:
+ from mlx_cache import ThreadSafeLRUPromptCache
+
+ # In LoadModel:
+ self.lru_cache = ThreadSafeLRUPromptCache(max_size=10)
+
+ # In Predict/PredictStream:
+ prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(model_key, tokens)
+ # ... generate ...
+ self.lru_cache.insert_cache(model_key, tokens, prompt_cache)
+"""
+import copy
+import threading
+from collections import deque
+from dataclasses import dataclass
+from typing import Any, List, Optional, Tuple
+
+
+@dataclass
+class CacheEntry:
+ """A cache entry with reference counting."""
+ prompt_cache: List[Any]
+ count: int
+
+
+@dataclass
+class SearchResult:
+ """Result of searching the cache trie."""
+ model: Any
+ exact: Optional[List[int]]
+ shorter: Optional[List[int]]
+ longer: Optional[List[int]]
+ common_prefix: int
+
+
+class ThreadSafeLRUPromptCache:
+ """
+ Thread-safe LRU cache with prefix matching for prompt KV caches.
+
+ This cache stores KV caches keyed by token sequences and supports:
+ - Exact match: Return the cache for the exact token sequence
+ - Shorter prefix match: Return a cache for a prefix of the tokens
+ - Longer prefix match: If a longer sequence is cached and can be trimmed
+ - LRU eviction: When max_size is exceeded, evict least recently used
+
+ Thread safety is provided via a threading.Lock that protects all
+ cache operations.
+
+ Args:
+ max_size: Maximum number of cache entries (default: 10)
+ can_trim_fn: Optional function to check if a cache can be trimmed
+ trim_fn: Optional function to trim a cache
+ """
+
+ def __init__(
+ self,
+ max_size: int = 10,
+ can_trim_fn: Optional[Any] = None,
+ trim_fn: Optional[Any] = None,
+ ):
+ self.max_size = max_size
+ self._cache = {}
+ self._lru = deque()
+ self._lock = threading.Lock()
+
+ # Optional trim functions (for longer prefix reuse)
+ self._can_trim_fn = can_trim_fn
+ self._trim_fn = trim_fn
+
+ def _search(self, model, tokens: List[int]) -> SearchResult:
+ """
+ Search the cache for a prompt cache. Return exact or close match.
+
+ The cache is organized as a trie where each node is keyed by a token.
+ This allows efficient prefix matching.
+ """
+ if model not in self._cache:
+ return SearchResult(model, None, None, None, 0)
+
+ current = self._cache[model]
+ last_cache_index = -1
+ index = 0
+
+ # Traverse the trie following the token sequence
+ while index < len(tokens) and tokens[index] in current:
+ current = current[tokens[index]]
+ if "cache" in current:
+ last_cache_index = index
+ index += 1
+
+ # Exact match - no need to search for longer or shorter caches
+ if last_cache_index == len(tokens) - 1:
+ return SearchResult(model, tuple(tokens), None, None, 0)
+
+ # Find the shorter cache (a prefix that has a cache)
+ # Note: Uses > 0 (not >= 0) to match upstream mlx_lm/server.py behavior.
+ # Single-token prefixes are not matched, which allows longer cached
+ # sequences to be preferred for trimming. This is acceptable because
+ # real prompts with chat templates are always many tokens.
+ shorter = None
+ if last_cache_index > 0:
+ shorter = tuple(tokens[: last_cache_index + 1])
+
+ # Check for caches that are longer than our token sequence
+ longer = None
+ common_prefix = index
+ if index > 0 and last_cache_index <= 0:
+ best = None
+ stack = [(current, [])]
+ while stack:
+ current, extra = stack.pop()
+ if "cache" in current:
+ if best is None or len(extra) < len(best):
+ best = extra
+ else:
+ for tok in current:
+ stack.append((current[tok], extra + [tok]))
+ if best is not None:
+ longer = tuple(tokens[:index] + best)
+
+ return SearchResult(model, None, shorter, longer, common_prefix)
+
+ def _get(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
+ """Get a cache entry by traversing the trie."""
+ current = self._cache[model]
+ for tok in tokens:
+ current = current[tok]
+ return current["cache"]
+
+ def _delete(self, model, tokens: Tuple[int, ...]) -> None:
+ """Delete a cache entry and clean up empty trie nodes."""
+ path = [self._cache[model]]
+ for tok in tokens:
+ path.append(path[-1][tok])
+ del path[-1]["cache"]
+
+ # Clean up empty nodes bottom-up
+ for i in reversed(range(len(tokens))):
+ d_prev, d, t = path[i], path[i + 1], tokens[i]
+ if len(d) > 0:
+ break
+ del d_prev[t]
+
+ def _extract(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
+ """
+ Extract a cache entry for exclusive use.
+
+ If the entry has count > 1, deep copy and decrement.
+ If count == 1, remove from cache entirely.
+ """
+ cache_entry = self._get(model, tokens)
+ if cache_entry.count == 1:
+ self._delete(model, tokens)
+ self._lru.remove((model, tokens))
+ return cache_entry
+
+ cache_entry.count -= 1
+ return CacheEntry(
+ copy.deepcopy(cache_entry.prompt_cache),
+ 1,
+ )
+
+ def fetch_nearest_cache(
+ self, model, tokens: List[int]
+ ) -> Tuple[Optional[List[Any]], List[int]]:
+ """
+ Fetch the nearest cache for the given token sequence.
+
+ Thread-safe. Returns (cache, remaining_tokens) where:
+ - cache: The KV cache to use (or None if no cache found)
+ - remaining_tokens: Tokens that still need to be processed
+
+ Args:
+ model: Model identifier (used to namespace caches)
+ tokens: The full token sequence for the prompt
+
+ Returns:
+ Tuple of (prompt_cache, remaining_tokens)
+ """
+ with self._lock:
+ tokens_tuple = tuple(tokens)
+ result = self._search(model, tokens)
+
+ # Exact match - extract and return
+ if result.exact is not None:
+ cache_entry = self._extract(result.model, result.exact)
+ return cache_entry.prompt_cache, []
+
+ # Shorter prefix match - extract and return remaining
+ if result.shorter is not None:
+ cache_entry = self._extract(result.model, result.shorter)
+ prefix_len = len(result.shorter)
+ return cache_entry.prompt_cache, list(tokens[prefix_len:])
+
+ # Longer prefix match - try to trim if possible
+ if result.longer is not None and self._can_trim_fn is not None:
+ cache_entry = self._get(result.model, result.longer)
+ if self._can_trim_fn(cache_entry.prompt_cache):
+ # Deep copy and trim
+ trimmed_cache = copy.deepcopy(cache_entry.prompt_cache)
+ prefix = min(len(tokens) - 1, result.common_prefix)
+ num_to_trim = len(result.longer) - prefix
+ if self._trim_fn is not None:
+ self._trim_fn(trimmed_cache, num_to_trim)
+ return trimmed_cache, list(tokens[prefix:])
+
+ # No match found
+ return None, list(tokens)
+
+ def insert_cache(
+ self, model, tokens: List[int], prompt_cache: List[Any]
+ ) -> None:
+ """
+ Insert a cache entry after generation completes.
+
+ Thread-safe. Handles LRU eviction if max_size is exceeded.
+
+ Args:
+ model: Model identifier (used to namespace caches)
+ tokens: The full token sequence (prompt + generated)
+ prompt_cache: The KV cache to store
+ """
+ with self._lock:
+ tokens_tuple = tuple(tokens)
+
+ if model not in self._cache:
+ self._cache[model] = {}
+ current = self._cache[model]
+
+ # Build trie path
+ for tok in tokens_tuple:
+ if tok not in current:
+ current[tok] = {}
+ current = current[tok]
+
+ # Update or create entry
+ if "cache" in current:
+ current["cache"].count += 1
+ self._lru.remove((model, tokens_tuple))
+ else:
+ current["cache"] = CacheEntry(prompt_cache, 1)
+
+ # Update LRU order
+ self._lru.append((model, tokens_tuple))
+
+ # Evict if over capacity
+ if len(self._lru) > self.max_size:
+ evict_model, evict_tokens = self._lru.popleft()
+ self._delete(evict_model, evict_tokens)
+
+ def clear(self) -> None:
+ """Clear all cache entries. Thread-safe."""
+ with self._lock:
+ self._cache.clear()
+ self._lru.clear()
+
+ def __len__(self) -> int:
+ """Return the number of cache entries. Thread-safe."""
+ with self._lock:
+ return len(self._lru)
diff --git a/backend/python/mlx/requirements-mps.txt b/backend/python/mlx/requirements-mps.txt
new file mode 100644
index 0000000000000000000000000000000000000000..22737f5fdda7499b1d5377df1ab1aedff88c4100
--- /dev/null
+++ b/backend/python/mlx/requirements-mps.txt
@@ -0,0 +1 @@
+mlx-lm
\ No newline at end of file
diff --git a/backend/python/mlx/requirements.txt b/backend/python/mlx/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f1771cc4adb4b4be9ddfb26acb959beb8278f178
--- /dev/null
+++ b/backend/python/mlx/requirements.txt
@@ -0,0 +1,4 @@
+grpcio==1.71.0
+protobuf
+certifi
+setuptools
\ No newline at end of file
diff --git a/backend/python/mlx/run.sh b/backend/python/mlx/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fc88f97da712f14faef73f9e8b96589dd8ecc2ad
--- /dev/null
+++ b/backend/python/mlx/run.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+backend_dir=$(dirname $0)
+
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/mlx/test.py b/backend/python/mlx/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..53d7bc7ec1b4d9bf5ffc6503a2913b46dc3fac8d
--- /dev/null
+++ b/backend/python/mlx/test.py
@@ -0,0 +1,234 @@
+import unittest
+import subprocess
+import time
+
+import grpc
+import backend_pb2
+import backend_pb2_grpc
+
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service.
+
+ This class contains methods to test the startup and shutdown of the gRPC service.
+ """
+ def setUp(self):
+ self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"])
+ time.sleep(10)
+
+ def tearDown(self) -> None:
+ self.service.terminate()
+ self.service.wait()
+
+ def test_server_startup(self):
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+ def test_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "MLX model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_text(self):
+ """
+ This method tests if the embeddings are generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
+ self.assertTrue(response.success)
+ req = backend_pb2.PredictOptions(Prompt="The capital of France is")
+ resp = stub.Predict(req)
+ self.assertIsNotNone(resp.message)
+ except Exception as err:
+ print(err)
+ self.fail("text service failed")
+ finally:
+ self.tearDown()
+
+ def test_sampling_params(self):
+ """
+ This method tests if all sampling parameters are correctly processed
+ NOTE: this does NOT test for correctness, just that we received a compatible response
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
+ self.assertTrue(response.success)
+
+ req = backend_pb2.PredictOptions(
+ Prompt="The capital of France is",
+ TopP=0.8,
+ Tokens=50,
+ Temperature=0.7,
+ TopK=40,
+ PresencePenalty=0.1,
+ FrequencyPenalty=0.2,
+ MinP=0.05,
+ Seed=42,
+ StopPrompts=["\n"],
+ IgnoreEOS=True,
+ )
+ resp = stub.Predict(req)
+ self.assertIsNotNone(resp.message)
+ except Exception as err:
+ print(err)
+ self.fail("sampling params service failed")
+ finally:
+ self.tearDown()
+
+
+ def test_embedding(self):
+ """
+ This method tests if the embeddings are generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="intfloat/e5-mistral-7b-instruct"))
+ self.assertTrue(response.success)
+ embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
+ embedding_response = stub.Embedding(embedding_request)
+ self.assertIsNotNone(embedding_response.embeddings)
+ # assert that is a list of floats
+ self.assertIsInstance(embedding_response.embeddings, list)
+ # assert that the list is not empty
+ self.assertTrue(len(embedding_response.embeddings) > 0)
+ except Exception as err:
+ print(err)
+ self.fail("Embedding service failed")
+ finally:
+ self.tearDown()
+
+ def test_concurrent_requests(self):
+ """
+ This method tests that concurrent requests don't corrupt each other's cache state.
+ This is a regression test for the race condition in the original implementation.
+ """
+ import concurrent.futures
+
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
+ self.assertTrue(response.success)
+
+ def make_request(prompt):
+ req = backend_pb2.PredictOptions(Prompt=prompt, Tokens=20)
+ return stub.Predict(req)
+
+ # Run 5 concurrent requests with different prompts
+ prompts = [
+ "The capital of France is",
+ "The capital of Germany is",
+ "The capital of Italy is",
+ "The capital of Spain is",
+ "The capital of Portugal is",
+ ]
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
+ futures = [executor.submit(make_request, p) for p in prompts]
+ results = [f.result() for f in concurrent.futures.as_completed(futures)]
+
+ # All results should be non-empty
+ messages = [r.message for r in results]
+ self.assertTrue(all(len(m) > 0 for m in messages), "All requests should return non-empty responses")
+ print(f"Concurrent test passed: {len(messages)} responses received")
+
+ except Exception as err:
+ print(err)
+ self.fail("Concurrent requests test failed")
+ finally:
+ self.tearDown()
+
+ def test_cache_reuse(self):
+ """
+ This method tests that repeated prompts reuse cached KV states.
+ The second request should benefit from the cached prompt processing.
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
+ self.assertTrue(response.success)
+
+ prompt = "The quick brown fox jumps over the lazy dog. "
+
+ # First request - populates cache
+ req1 = backend_pb2.PredictOptions(Prompt=prompt, Tokens=10)
+ resp1 = stub.Predict(req1)
+ self.assertIsNotNone(resp1.message)
+
+ # Second request with same prompt - should reuse cache
+ req2 = backend_pb2.PredictOptions(Prompt=prompt, Tokens=10)
+ resp2 = stub.Predict(req2)
+ self.assertIsNotNone(resp2.message)
+
+ print(f"Cache reuse test passed: first={len(resp1.message)} bytes, second={len(resp2.message)} bytes")
+
+ except Exception as err:
+ print(err)
+ self.fail("Cache reuse test failed")
+ finally:
+ self.tearDown()
+
+ def test_prefix_cache_reuse(self):
+ """
+ This method tests that prompts sharing a common prefix benefit from cached KV states.
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
+ self.assertTrue(response.success)
+
+ # First request with base prompt
+ prompt_base = "Once upon a time in a land far away, "
+ req1 = backend_pb2.PredictOptions(Prompt=prompt_base, Tokens=10)
+ resp1 = stub.Predict(req1)
+ self.assertIsNotNone(resp1.message)
+
+ # Second request with extended prompt (same prefix)
+ prompt_extended = prompt_base + "there lived a brave knight who "
+ req2 = backend_pb2.PredictOptions(Prompt=prompt_extended, Tokens=10)
+ resp2 = stub.Predict(req2)
+ self.assertIsNotNone(resp2.message)
+
+ print(f"Prefix cache test passed: base={len(resp1.message)} bytes, extended={len(resp2.message)} bytes")
+
+ except Exception as err:
+ print(err)
+ self.fail("Prefix cache reuse test failed")
+ finally:
+ self.tearDown()
+
+
+# Unit tests for ThreadSafeLRUPromptCache are in test_mlx_cache.py
\ No newline at end of file
diff --git a/backend/python/mlx/test.sh b/backend/python/mlx/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f31ae54e47dc7f5a10f630fa1d7b5c8ea56f0c9e
--- /dev/null
+++ b/backend/python/mlx/test.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/mlx/test_mlx_cache.py b/backend/python/mlx/test_mlx_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..c888782e9ddf2db95f2c35fe91567843275acae4
--- /dev/null
+++ b/backend/python/mlx/test_mlx_cache.py
@@ -0,0 +1,480 @@
+"""
+Comprehensive unit tests for ThreadSafeLRUPromptCache.
+
+Tests all cache operation modes:
+- Exact match
+- Shorter prefix match
+- Longer prefix match (with trimming)
+- No match
+- LRU eviction
+- Reference counting
+- Multi-model namespacing
+- Thread safety with data integrity verification
+"""
+import unittest
+import concurrent.futures
+import threading
+import copy
+from mlx_cache import ThreadSafeLRUPromptCache
+
+
+class TestCacheExactMatch(unittest.TestCase):
+ """Tests for exact match cache behavior."""
+
+ def setUp(self):
+ self.cache = ThreadSafeLRUPromptCache(max_size=10)
+
+ def test_exact_match_returns_cache_and_empty_remaining(self):
+ """Exact match should return the cache with no remaining tokens."""
+ tokens = [1, 2, 3, 4, 5]
+ mock_cache = ["kv_cache_data"]
+
+ self.cache.insert_cache("model1", tokens, mock_cache)
+ result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens)
+
+ self.assertEqual(result_cache, mock_cache)
+ self.assertEqual(remaining, [])
+
+ def test_exact_match_extracts_and_removes_from_cache(self):
+ """Fetching exact match with count=1 should remove entry from cache."""
+ tokens = [1, 2, 3]
+ self.cache.insert_cache("model1", tokens, ["cache"])
+
+ self.assertEqual(len(self.cache), 1)
+
+ # First fetch extracts the entry
+ self.cache.fetch_nearest_cache("model1", tokens)
+
+ # Cache should now be empty
+ self.assertEqual(len(self.cache), 0)
+
+ # Second fetch should return None (no match)
+ result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens)
+ self.assertIsNone(result_cache)
+ self.assertEqual(remaining, tokens)
+
+
+class TestCacheShorterPrefix(unittest.TestCase):
+ """Tests for shorter prefix match behavior."""
+
+ def setUp(self):
+ self.cache = ThreadSafeLRUPromptCache(max_size=10)
+
+ def test_shorter_prefix_returns_cache_with_remaining_tokens(self):
+ """When cached prefix is shorter, return cache and remaining suffix."""
+ short_tokens = [1, 2, 3]
+ long_tokens = [1, 2, 3, 4, 5, 6]
+ mock_cache = ["prefix_cache"]
+
+ self.cache.insert_cache("model1", short_tokens, mock_cache)
+ result_cache, remaining = self.cache.fetch_nearest_cache("model1", long_tokens)
+
+ self.assertEqual(result_cache, mock_cache)
+ self.assertEqual(remaining, [4, 5, 6])
+
+ def test_shorter_prefix_correct_remaining_calculation(self):
+ """Verify remaining tokens are calculated correctly for various prefix lengths."""
+ # Note: Single-token prefixes ([1] -> [1,2,3]) are deliberately not matched
+ # to allow longer cached sequences to be preferred for trimming.
+ # This matches upstream mlx_lm/server.py behavior.
+ test_cases = [
+ # (cached_tokens, requested_tokens, expected_remaining)
+ ([1, 2], [1, 2, 3, 4, 5], [3, 4, 5]),
+ ([10, 20, 30, 40], [10, 20, 30, 40, 50], [50]),
+ ]
+
+ for cached, requested, expected_remaining in test_cases:
+ with self.subTest(cached=cached, requested=requested):
+ cache = ThreadSafeLRUPromptCache(max_size=10)
+ cache.insert_cache("model", cached, ["cache"])
+ result_cache, remaining = cache.fetch_nearest_cache("model", requested)
+
+ self.assertIsNotNone(result_cache)
+ self.assertEqual(remaining, expected_remaining)
+
+ def test_single_token_prefix_not_matched(self):
+ """Single-token prefixes are not matched (by design, matches upstream).
+
+ This allows longer cached sequences to be preferred for trimming,
+ which provides better KV cache reuse. Single-token caches are rare
+ in practice since real prompts with chat templates are many tokens.
+ """
+ cache = ThreadSafeLRUPromptCache(max_size=10)
+ cache.insert_cache("model", [1], ["cache"])
+
+ result_cache, remaining = cache.fetch_nearest_cache("model", [1, 2, 3])
+
+ # Single-token prefix is NOT matched
+ self.assertIsNone(result_cache)
+ self.assertEqual(remaining, [1, 2, 3])
+
+
+class TestCacheLongerPrefix(unittest.TestCase):
+ """Tests for longer prefix match behavior (trimming)."""
+
+ def setUp(self):
+ # Track trim calls for verification
+ self.trim_calls = []
+
+ def mock_can_trim(cache):
+ return True
+
+ def mock_trim(cache, num_to_trim):
+ self.trim_calls.append(num_to_trim)
+ # Simulate trimming by modifying the cache
+ cache.append(f"trimmed_{num_to_trim}")
+
+ self.cache = ThreadSafeLRUPromptCache(
+ max_size=10,
+ can_trim_fn=mock_can_trim,
+ trim_fn=mock_trim,
+ )
+
+ def test_longer_prefix_triggers_trim(self):
+ """When cached sequence is longer, should trim to match requested prefix."""
+ long_tokens = [1, 2, 3, 4, 5]
+ short_tokens = [1, 2, 3]
+
+ self.cache.insert_cache("model1", long_tokens, ["original_cache"])
+ result_cache, remaining = self.cache.fetch_nearest_cache("model1", short_tokens)
+
+ # Should have called trim
+ self.assertTrue(len(self.trim_calls) > 0, "trim_fn should have been called")
+ # Result should be a trimmed copy, not the original
+ self.assertIn("trimmed_", str(result_cache))
+
+ def test_longer_prefix_without_trim_fn_returns_no_match(self):
+ """Without trim functions, longer prefix should not match."""
+ cache_no_trim = ThreadSafeLRUPromptCache(max_size=10)
+
+ long_tokens = [1, 2, 3, 4, 5]
+ short_tokens = [1, 2, 3]
+
+ cache_no_trim.insert_cache("model1", long_tokens, ["cache"])
+ result_cache, remaining = cache_no_trim.fetch_nearest_cache("model1", short_tokens)
+
+ # Without trim_fn, should return no match
+ self.assertIsNone(result_cache)
+ self.assertEqual(remaining, short_tokens)
+
+ def test_longer_prefix_can_trim_false_returns_no_match(self):
+ """When can_trim_fn returns False, should not attempt trim."""
+ cache = ThreadSafeLRUPromptCache(
+ max_size=10,
+ can_trim_fn=lambda c: False,
+ trim_fn=lambda c, n: None,
+ )
+
+ cache.insert_cache("model1", [1, 2, 3, 4, 5], ["cache"])
+ result_cache, remaining = cache.fetch_nearest_cache("model1", [1, 2, 3])
+
+ self.assertIsNone(result_cache)
+ self.assertEqual(remaining, [1, 2, 3])
+
+
+class TestCacheNoMatch(unittest.TestCase):
+ """Tests for no match behavior."""
+
+ def setUp(self):
+ self.cache = ThreadSafeLRUPromptCache(max_size=10)
+
+ def test_empty_cache_returns_none(self):
+ """Empty cache should return None and all tokens as remaining."""
+ tokens = [1, 2, 3]
+ result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens)
+
+ self.assertIsNone(result_cache)
+ self.assertEqual(remaining, tokens)
+
+ def test_different_prefix_returns_none(self):
+ """Tokens with different prefix should not match."""
+ self.cache.insert_cache("model1", [1, 2, 3], ["cache"])
+
+ # Completely different tokens
+ result_cache, remaining = self.cache.fetch_nearest_cache("model1", [4, 5, 6])
+
+ self.assertIsNone(result_cache)
+ self.assertEqual(remaining, [4, 5, 6])
+
+ def test_partial_prefix_mismatch_returns_none(self):
+ """Tokens that diverge mid-sequence should not match."""
+ self.cache.insert_cache("model1", [1, 2, 3], ["cache"])
+
+ # Same start but diverges
+ result_cache, remaining = self.cache.fetch_nearest_cache("model1", [1, 2, 99])
+
+ self.assertIsNone(result_cache)
+ self.assertEqual(remaining, [1, 2, 99])
+
+ def test_wrong_model_returns_none(self):
+ """Different model key should not match."""
+ self.cache.insert_cache("model1", [1, 2, 3], ["cache"])
+
+ result_cache, remaining = self.cache.fetch_nearest_cache("model2", [1, 2, 3])
+
+ self.assertIsNone(result_cache)
+ self.assertEqual(remaining, [1, 2, 3])
+
+
+class TestCacheLRUEviction(unittest.TestCase):
+ """Tests for LRU eviction behavior."""
+
+ def setUp(self):
+ self.cache = ThreadSafeLRUPromptCache(max_size=3)
+
+ def test_evicts_oldest_when_full(self):
+ """Should evict least recently used entry when capacity exceeded."""
+ self.cache.insert_cache("model", [1], ["cache1"])
+ self.cache.insert_cache("model", [2], ["cache2"])
+ self.cache.insert_cache("model", [3], ["cache3"])
+
+ self.assertEqual(len(self.cache), 3)
+
+ # Insert 4th entry - should evict [1]
+ self.cache.insert_cache("model", [4], ["cache4"])
+
+ self.assertEqual(len(self.cache), 3)
+
+ # [1] should be evicted
+ result, _ = self.cache.fetch_nearest_cache("model", [1])
+ self.assertIsNone(result)
+
+ # [2], [3], [4] should still exist
+ for tokens in [[2], [3], [4]]:
+ # Re-insert since fetch extracts
+ self.cache.insert_cache("model", tokens, [f"cache{tokens[0]}"])
+
+ result2, _ = self.cache.fetch_nearest_cache("model", [2])
+ self.assertIsNotNone(result2)
+
+ def test_access_updates_lru_order(self):
+ """Accessing an entry should move it to most recently used."""
+ self.cache.insert_cache("model", [1], ["cache1"])
+ self.cache.insert_cache("model", [2], ["cache2"])
+ self.cache.insert_cache("model", [3], ["cache3"])
+
+ # Access [1] to make it most recently used
+ cache1, _ = self.cache.fetch_nearest_cache("model", [1])
+ # Re-insert it (simulating normal usage pattern)
+ self.cache.insert_cache("model", [1], cache1)
+
+ # Now insert two more entries - should evict [2] then [3], not [1]
+ self.cache.insert_cache("model", [4], ["cache4"])
+ self.cache.insert_cache("model", [5], ["cache5"])
+
+ # [1] should still exist (was accessed, so not evicted)
+ result1, _ = self.cache.fetch_nearest_cache("model", [1])
+ self.assertIsNotNone(result1)
+
+ # [2] should be evicted (was oldest after [1] was accessed)
+ result2, _ = self.cache.fetch_nearest_cache("model", [2])
+ self.assertIsNone(result2)
+
+
+class TestCacheReferenceCount(unittest.TestCase):
+ """Tests for reference counting behavior."""
+
+ def setUp(self):
+ self.cache = ThreadSafeLRUPromptCache(max_size=10)
+
+ def test_multiple_inserts_increment_count(self):
+ """Inserting same tokens multiple times should increment count."""
+ tokens = [1, 2, 3]
+
+ self.cache.insert_cache("model", tokens, ["cache"])
+ self.cache.insert_cache("model", tokens, ["cache"])
+ self.cache.insert_cache("model", tokens, ["cache"])
+
+ # Should still be one entry (with count=3 internally)
+ self.assertEqual(len(self.cache), 1)
+
+ # First two fetches should return copies (count decremented)
+ result1, _ = self.cache.fetch_nearest_cache("model", tokens)
+ self.assertIsNotNone(result1)
+
+ result2, _ = self.cache.fetch_nearest_cache("model", tokens)
+ self.assertIsNotNone(result2)
+
+ # Third fetch extracts the last reference
+ result3, _ = self.cache.fetch_nearest_cache("model", tokens)
+ self.assertIsNotNone(result3)
+
+ # Fourth fetch should return None (entry fully extracted)
+ result4, _ = self.cache.fetch_nearest_cache("model", tokens)
+ self.assertIsNone(result4)
+
+ def test_extract_with_high_count_returns_deep_copy(self):
+ """When count > 1, extract should return a deep copy."""
+ tokens = [1, 2, 3]
+ original_cache = [{"nested": "data"}]
+
+ self.cache.insert_cache("model", tokens, original_cache)
+ self.cache.insert_cache("model", tokens, original_cache) # count=2
+
+ result1, _ = self.cache.fetch_nearest_cache("model", tokens)
+
+ # Modify the returned cache
+ result1[0]["nested"] = "modified"
+
+ # Second fetch should get unmodified copy
+ result2, _ = self.cache.fetch_nearest_cache("model", tokens)
+ self.assertEqual(result2[0]["nested"], "data")
+
+
+class TestCacheMultiModel(unittest.TestCase):
+ """Tests for multi-model namespacing."""
+
+ def setUp(self):
+ self.cache = ThreadSafeLRUPromptCache(max_size=10)
+
+ def test_same_tokens_different_models_are_separate(self):
+ """Same token sequence under different models should be independent."""
+ tokens = [1, 2, 3]
+
+ self.cache.insert_cache("model_a", tokens, ["cache_a"])
+ self.cache.insert_cache("model_b", tokens, ["cache_b"])
+
+ self.assertEqual(len(self.cache), 2)
+
+ result_a, _ = self.cache.fetch_nearest_cache("model_a", tokens)
+ result_b, _ = self.cache.fetch_nearest_cache("model_b", tokens)
+
+ self.assertEqual(result_a, ["cache_a"])
+ self.assertEqual(result_b, ["cache_b"])
+
+ def test_eviction_across_models(self):
+ """LRU eviction should work across different models."""
+ cache = ThreadSafeLRUPromptCache(max_size=3)
+
+ cache.insert_cache("model_a", [1], ["a1"])
+ cache.insert_cache("model_b", [1], ["b1"])
+ cache.insert_cache("model_a", [2], ["a2"])
+
+ self.assertEqual(len(cache), 3)
+
+ # Insert 4th - should evict model_a:[1] (oldest)
+ cache.insert_cache("model_b", [2], ["b2"])
+
+ result, _ = cache.fetch_nearest_cache("model_a", [1])
+ self.assertIsNone(result)
+
+
+class TestCacheThreadSafety(unittest.TestCase):
+ """Tests for thread safety with data integrity verification."""
+
+ def test_concurrent_inserts_no_data_loss(self):
+ """Concurrent inserts should not lose data."""
+ cache = ThreadSafeLRUPromptCache(max_size=100)
+ num_threads = 10
+ inserts_per_thread = 20
+
+ def insert_entries(thread_id):
+ for i in range(inserts_per_thread):
+ tokens = [thread_id, i]
+ cache.insert_cache("model", tokens, [f"cache_{thread_id}_{i}"])
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
+ futures = [executor.submit(insert_entries, tid) for tid in range(num_threads)]
+ concurrent.futures.wait(futures)
+
+ # Verify expected number of entries (may be less due to LRU eviction with max_size=100)
+ # But should be exactly 100 since we inserted exactly 200 and max_size is 100
+ self.assertEqual(len(cache), 100)
+
+ def test_concurrent_fetch_and_insert_no_corruption(self):
+ """Concurrent fetches and inserts should not corrupt data."""
+ cache = ThreadSafeLRUPromptCache(max_size=50)
+ errors = []
+ lock = threading.Lock()
+
+ # Pre-populate with known data
+ for i in range(20):
+ cache.insert_cache("model", [i], [f"original_{i}"])
+
+ def fetch_and_verify(thread_id):
+ try:
+ for _ in range(50):
+ token_id = thread_id % 20
+ result, remaining = cache.fetch_nearest_cache("model", [token_id])
+
+ if result is not None:
+ # Verify data integrity
+ expected_prefix = f"original_{token_id}"
+ if not str(result[0]).startswith("original_"):
+ with lock:
+ errors.append(f"Corrupted data: {result}")
+
+ # Re-insert to keep cache populated
+ cache.insert_cache("model", [token_id], result)
+
+ except Exception as e:
+ with lock:
+ errors.append(str(e))
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+ futures = [executor.submit(fetch_and_verify, tid) for tid in range(10)]
+ concurrent.futures.wait(futures)
+
+ self.assertEqual(errors, [], f"Thread safety errors: {errors}")
+
+ def test_concurrent_operations_maintain_cache_bounds(self):
+ """Cache size should never exceed max_size under concurrent operations."""
+ max_size = 10
+ cache = ThreadSafeLRUPromptCache(max_size=max_size)
+ size_violations = []
+ lock = threading.Lock()
+
+ def random_operations(thread_id):
+ import random
+ for i in range(100):
+ tokens = [random.randint(0, 50)]
+ if random.random() < 0.7:
+ cache.insert_cache("model", tokens, [f"cache_{thread_id}_{i}"])
+ else:
+ cache.fetch_nearest_cache("model", tokens)
+
+ current_size = len(cache)
+ if current_size > max_size:
+ with lock:
+ size_violations.append(current_size)
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+ futures = [executor.submit(random_operations, tid) for tid in range(10)]
+ concurrent.futures.wait(futures)
+
+ self.assertEqual(size_violations, [], f"Size exceeded max: {size_violations}")
+ self.assertLessEqual(len(cache), max_size)
+
+
+class TestCacheClear(unittest.TestCase):
+ """Tests for cache clear operation."""
+
+ def setUp(self):
+ self.cache = ThreadSafeLRUPromptCache(max_size=10)
+
+ def test_clear_removes_all_entries(self):
+ """Clear should remove all entries."""
+ self.cache.insert_cache("model1", [1, 2], ["cache1"])
+ self.cache.insert_cache("model2", [3, 4], ["cache2"])
+ self.cache.insert_cache("model1", [5, 6], ["cache3"])
+
+ self.assertEqual(len(self.cache), 3)
+
+ self.cache.clear()
+
+ self.assertEqual(len(self.cache), 0)
+
+ def test_clear_allows_new_inserts(self):
+ """After clear, new inserts should work normally."""
+ self.cache.insert_cache("model", [1], ["cache1"])
+ self.cache.clear()
+ self.cache.insert_cache("model", [2], ["cache2"])
+
+ self.assertEqual(len(self.cache), 1)
+
+ result, _ = self.cache.fetch_nearest_cache("model", [2])
+ self.assertEqual(result, ["cache2"])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/backend/python/moonshine/Makefile b/backend/python/moonshine/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..71050097c44fd9a1cbb283785ea815bb86a38a42
--- /dev/null
+++ b/backend/python/moonshine/Makefile
@@ -0,0 +1,16 @@
+.DEFAULT_GOAL := install
+
+.PHONY: install
+install:
+ bash install.sh
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
+
+test: install
+ bash test.sh
\ No newline at end of file
diff --git a/backend/python/moonshine/backend.py b/backend/python/moonshine/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc9e2965be3f061e333b7b243eccda00ef252b9b
--- /dev/null
+++ b/backend/python/moonshine/backend.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+"""
+This is an extra gRPC server of LocalAI for Moonshine transcription
+"""
+from concurrent import futures
+import time
+import argparse
+import signal
+import sys
+import os
+import backend_pb2
+import backend_pb2_grpc
+import moonshine_onnx
+
+import grpc
+
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ BackendServicer is the class that implements the gRPC service
+ """
+ def Health(self, request, context):
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+ def LoadModel(self, request, context):
+ try:
+ print("Preparing models, please wait", file=sys.stderr)
+ # Store the model name for use in transcription
+ # Model name format: e.g., "moonshine/tiny"
+ self.model_name = request.Model
+ print(f"Model name set to: {self.model_name}", file=sys.stderr)
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+ def AudioTranscription(self, request, context):
+ resultSegments = []
+ text = ""
+ try:
+ # moonshine_onnx.transcribe returns a list of strings
+ transcriptions = moonshine_onnx.transcribe(request.dst, self.model_name)
+
+ # Combine all transcriptions into a single text
+ if isinstance(transcriptions, list):
+ text = " ".join(transcriptions)
+ # Create segments for each transcription in the list
+ for id, trans in enumerate(transcriptions):
+ # Since moonshine doesn't provide timing info, we'll create a single segment
+ # with id and text, using approximate timing
+ resultSegments.append(backend_pb2.TranscriptSegment(
+ id=id,
+ start=0,
+ end=0,
+ text=trans
+ ))
+ else:
+ # Handle case where it's not a list (shouldn't happen, but be safe)
+ text = str(transcriptions)
+ resultSegments.append(backend_pb2.TranscriptSegment(
+ id=0,
+ start=0,
+ end=0,
+ text=text
+ ))
+ except Exception as err:
+ print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
+ return backend_pb2.TranscriptResult(segments=[], text="")
+
+ return backend_pb2.TranscriptResult(segments=resultSegments, text=text)
+
+def serve(address):
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ server.add_insecure_port(address)
+ server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+
+ # Define the signal handler function
+ def signal_handler(sig, frame):
+ print("Received termination signal. Shutting down...")
+ server.stop(0)
+ sys.exit(0)
+
+ # Set the signal handlers for SIGINT and SIGTERM
+ signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, signal_handler)
+
+ try:
+ while True:
+ time.sleep(_ONE_DAY_IN_SECONDS)
+ except KeyboardInterrupt:
+ server.stop(0)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ serve(args.addr)
+
diff --git a/backend/python/moonshine/install.sh b/backend/python/moonshine/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4abc9cf583c07bd53aad3e1e9b877525aa7cee59
--- /dev/null
+++ b/backend/python/moonshine/install.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+installRequirements
+
diff --git a/backend/python/moonshine/protogen.sh b/backend/python/moonshine/protogen.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1dc00c768268886313c738e65d7e7bcfe720d76c
--- /dev/null
+++ b/backend/python/moonshine/protogen.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto
+
diff --git a/backend/python/moonshine/requirements.txt b/backend/python/moonshine/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..240f166cf2753ffcd7f6c8badd1045b885f7026c
--- /dev/null
+++ b/backend/python/moonshine/requirements.txt
@@ -0,0 +1,4 @@
+grpcio==1.71.0
+protobuf
+grpcio-tools
+useful-moonshine-onnx@git+https://git@github.com/moonshine-ai/moonshine.git#subdirectory=moonshine-onnx
\ No newline at end of file
diff --git a/backend/python/moonshine/run.sh b/backend/python/moonshine/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8b3809e4a55ba2848f0f37f3d49ce8ecf92dc7de
--- /dev/null
+++ b/backend/python/moonshine/run.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
+
diff --git a/backend/python/moonshine/test.py b/backend/python/moonshine/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d69a7798d9ef81388bd33abea3f2f361eba7d80b
--- /dev/null
+++ b/backend/python/moonshine/test.py
@@ -0,0 +1,139 @@
+"""
+A test script to test the gRPC service for Moonshine transcription
+"""
+import unittest
+import subprocess
+import time
+import os
+import tempfile
+import shutil
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service
+ """
+ def setUp(self):
+ """
+ This method sets up the gRPC service by starting the server
+ """
+ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
+ time.sleep(10)
+
+ def tearDown(self) -> None:
+ """
+ This method tears down the gRPC service by terminating the server
+ """
+ self.service.terminate()
+ self.service.wait()
+
+ def test_server_startup(self):
+ """
+ This method tests if the server starts up successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+
+ def test_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="moonshine/tiny"))
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_audio_transcription(self):
+ """
+ This method tests if audio transcription works successfully
+ """
+ # Create a temporary directory for the audio file
+ temp_dir = tempfile.mkdtemp()
+ audio_file = os.path.join(temp_dir, 'audio.wav')
+
+ try:
+ # Download the audio file to the temporary directory
+ print(f"Downloading audio file to {audio_file}...")
+ url = "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav"
+ result = subprocess.run(
+ ["wget", "-q", url, "-O", audio_file],
+ capture_output=True,
+ text=True
+ )
+ if result.returncode != 0:
+ self.fail(f"Failed to download audio file: {result.stderr}")
+
+ # Verify the file was downloaded
+ if not os.path.exists(audio_file):
+ self.fail(f"Audio file was not downloaded to {audio_file}")
+
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ # Load the model first
+ load_response = stub.LoadModel(backend_pb2.ModelOptions(Model="moonshine/tiny"))
+ self.assertTrue(load_response.success)
+
+ # Perform transcription
+ transcript_request = backend_pb2.TranscriptRequest(dst=audio_file)
+ transcript_response = stub.AudioTranscription(transcript_request)
+
+ # Print the transcribed text for debugging
+ print(f"Transcribed text: {transcript_response.text}")
+ print(f"Number of segments: {len(transcript_response.segments)}")
+
+ # Verify response structure
+ self.assertIsNotNone(transcript_response)
+ self.assertIsNotNone(transcript_response.text)
+ # Protobuf repeated fields return a sequence, not a list
+ self.assertIsNotNone(transcript_response.segments)
+ # Check if segments is iterable (has length)
+ self.assertGreaterEqual(len(transcript_response.segments), 0)
+
+ # Verify the transcription contains the expected text
+ expected_text = "This is the micro machine man presenting the most midget miniature"
+ self.assertIn(
+ expected_text.lower(),
+ transcript_response.text.lower(),
+ f"Expected text '{expected_text}' not found in transcription: '{transcript_response.text}'"
+ )
+
+ # If we got segments, verify they have the expected structure
+ if len(transcript_response.segments) > 0:
+ segment = transcript_response.segments[0]
+ self.assertIsNotNone(segment.text)
+ self.assertIsInstance(segment.id, int)
+ else:
+ # Even if no segments, we should have text
+ self.assertIsNotNone(transcript_response.text)
+ self.assertGreater(len(transcript_response.text), 0)
+ except Exception as err:
+ print(err)
+ self.fail("AudioTranscription service failed")
+ finally:
+ self.tearDown()
+ # Clean up the temporary directory
+ if os.path.exists(temp_dir):
+ shutil.rmtree(temp_dir)
+
diff --git a/backend/python/moonshine/test.sh b/backend/python/moonshine/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f6a66da3e58d8a2c11f5fd3252e8ed4d7dcaea2b
--- /dev/null
+++ b/backend/python/moonshine/test.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
+
diff --git a/backend/python/neutts/Makefile b/backend/python/neutts/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..7d50ed07be297f1c3c59ff2034a0c53b58346929
--- /dev/null
+++ b/backend/python/neutts/Makefile
@@ -0,0 +1,23 @@
+.PHONY: neutts
+neutts:
+ bash install.sh
+
+.PHONY: run
+run: neutts
+ @echo "Running neutts..."
+ bash run.sh
+ @echo "neutts run."
+
+.PHONY: test
+test: neutts
+ @echo "Testing neutts..."
+ bash test.sh
+ @echo "neutts tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/neutts/backend.py b/backend/python/neutts/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..e765436d104f2f2d62a2937928a25f9fee13d1de
--- /dev/null
+++ b/backend/python/neutts/backend.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python3
+"""
+This is an extra gRPC server of LocalAI for NeuTTSAir
+"""
+from concurrent import futures
+import time
+import argparse
+import signal
+import sys
+import os
+import backend_pb2
+import backend_pb2_grpc
+import torch
+from neuttsair.neutts import NeuTTSAir
+import soundfile as sf
+
+import grpc
+
+def is_float(s):
+ """Check if a string can be converted to float."""
+ try:
+ float(s)
+ return True
+ except ValueError:
+ return False
+def is_int(s):
+ """Check if a string can be converted to int."""
+ try:
+ int(s)
+ return True
+ except ValueError:
+ return False
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ BackendServicer is the class that implements the gRPC service
+ """
+ def Health(self, request, context):
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+ def LoadModel(self, request, context):
+
+ # Get device
+ # device = "cuda" if request.CUDA else "cpu"
+ if torch.cuda.is_available():
+ print("CUDA is available", file=sys.stderr)
+ device = "cuda"
+ else:
+ print("CUDA is not available", file=sys.stderr)
+ device = "cpu"
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+ if mps_available:
+ device = "mps"
+ if not torch.cuda.is_available() and request.CUDA:
+ return backend_pb2.Result(success=False, message="CUDA is not available")
+
+
+ options = request.Options
+
+ # empty dict
+ self.options = {}
+ self.ref_text = None
+
+ # The options are a list of strings in this form optname:optvalue
+ # We are storing all the options in a dict so we can use it later when
+ # generating the images
+ for opt in options:
+ if ":" not in opt:
+ continue
+ key, value = opt.split(":")
+ # if value is a number, convert it to the appropriate type
+ if is_float(value):
+ value = float(value)
+ elif is_int(value):
+ value = int(value)
+ elif value.lower() in ["true", "false"]:
+ value = value.lower() == "true"
+ self.options[key] = value
+
+ codec_repo = "neuphonic/neucodec"
+ if "codec_repo" in self.options:
+ codec_repo = self.options["codec_repo"]
+ del self.options["codec_repo"]
+ if "ref_text" in self.options:
+ self.ref_text = self.options["ref_text"]
+ del self.options["ref_text"]
+
+ self.AudioPath = None
+
+ if os.path.isabs(request.AudioPath):
+ self.AudioPath = request.AudioPath
+ elif request.AudioPath and request.ModelFile != "" and not os.path.isabs(request.AudioPath):
+ # get base path of modelFile
+ modelFileBase = os.path.dirname(request.ModelFile)
+ # modify LoraAdapter to be relative to modelFileBase
+ self.AudioPath = os.path.join(modelFileBase, request.AudioPath)
+ try:
+ print("Preparing models, please wait", file=sys.stderr)
+ self.model = NeuTTSAir(backbone_repo=request.Model, backbone_device=device, codec_repo=codec_repo, codec_device=device)
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ # Implement your logic here for the LoadModel service
+ # Replace this with your desired response
+ return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+ def TTS(self, request, context):
+ try:
+ kwargs = {}
+
+ # add options to kwargs
+ kwargs.update(self.options)
+
+ ref_codes = self.model.encode_reference(self.AudioPath)
+
+ wav = self.model.infer(request.text, ref_codes, self.ref_text)
+
+ sf.write(request.dst, wav, 24000)
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ return backend_pb2.Result(success=True)
+
+def serve(address):
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ server.add_insecure_port(address)
+ server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+
+ # Define the signal handler function
+ def signal_handler(sig, frame):
+ print("Received termination signal. Shutting down...")
+ server.stop(0)
+ sys.exit(0)
+
+ # Set the signal handlers for SIGINT and SIGTERM
+ signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, signal_handler)
+
+ try:
+ while True:
+ time.sleep(_ONE_DAY_IN_SECONDS)
+ except KeyboardInterrupt:
+ server.stop(0)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ serve(args.addr)
diff --git a/backend/python/neutts/install.sh b/backend/python/neutts/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..381788605c33a63d82702f525ded80d5f831866a
--- /dev/null
+++ b/backend/python/neutts/install.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
+# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
+if [ "x${BUILD_PROFILE}" == "xintel" ]; then
+ EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+fi
+
+if [ "x${BUILD_TYPE}" == "xcublas" ] || [ "x${BUILD_TYPE}" == "xl4t" ]; then
+ export CMAKE_ARGS="-DGGML_CUDA=on"
+fi
+
+if [ "x${BUILD_TYPE}" == "xhipblas" ]; then
+ export CMAKE_ARGS="-DGGML_HIPBLAS=on"
+fi
+
+EXTRA_PIP_INSTALL_FLAGS+=" --no-build-isolation"
+
+
+if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
+ USE_PIP=true
+fi
+
+
+git clone https://github.com/neuphonic/neutts-air neutts-air
+
+cp -rfv neutts-air/neuttsair ./
+
+installRequirements
diff --git a/backend/python/neutts/requirements-after.txt b/backend/python/neutts/requirements-after.txt
new file mode 100644
index 0000000000000000000000000000000000000000..dfa969a39d32518a6c97c9bfcf7383f522da01eb
--- /dev/null
+++ b/backend/python/neutts/requirements-after.txt
@@ -0,0 +1,2 @@
+datasets==4.1.1
+torchtune==0.6.1
\ No newline at end of file
diff --git a/backend/python/neutts/requirements-cpu.txt b/backend/python/neutts/requirements-cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d6f972df9d958bf43f19569b0b26e7a512204216
--- /dev/null
+++ b/backend/python/neutts/requirements-cpu.txt
@@ -0,0 +1,10 @@
+--extra-index-url https://download.pytorch.org/whl/cpu
+accelerate
+torch==2.8.0
+transformers==4.56.1
+librosa==0.11.0
+neucodec>=0.0.4
+phonemizer==3.3.0
+soundfile==0.13.1
+resemble-perth==1.0.1
+llama-cpp-python
\ No newline at end of file
diff --git a/backend/python/neutts/requirements-cublas12.txt b/backend/python/neutts/requirements-cublas12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..13afd3b86d648d303e6298ea47de4ddc96739304
--- /dev/null
+++ b/backend/python/neutts/requirements-cublas12.txt
@@ -0,0 +1,8 @@
+librosa==0.11.0
+neucodec>=0.0.4
+phonemizer==3.3.0
+soundfile==0.13.1
+torch==2.8.0
+transformers==4.56.1
+resemble-perth==1.0.1
+accelerate
\ No newline at end of file
diff --git a/backend/python/neutts/requirements-hipblas.txt b/backend/python/neutts/requirements-hipblas.txt
new file mode 100644
index 0000000000000000000000000000000000000000..72d11e0598178cfe4cbbd8114282989e62ab2918
--- /dev/null
+++ b/backend/python/neutts/requirements-hipblas.txt
@@ -0,0 +1,10 @@
+--extra-index-url https://download.pytorch.org/whl/rocm6.4
+torch==2.8.0+rocm6.4
+transformers==4.56.1
+accelerate
+librosa==0.11.0
+neucodec>=0.0.4
+phonemizer==3.3.0
+soundfile==0.13.1
+resemble-perth==1.0.1
+llama-cpp-python
\ No newline at end of file
diff --git a/backend/python/neutts/requirements-l4t12.txt b/backend/python/neutts/requirements-l4t12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7932d192eb3203373ce8f053040d5f9695422f3a
--- /dev/null
+++ b/backend/python/neutts/requirements-l4t12.txt
@@ -0,0 +1,10 @@
+--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126/
+torch
+transformers
+accelerate
+librosa==0.11.0
+neucodec>=0.0.4
+phonemizer==3.3.0
+soundfile==0.13.1
+resemble-perth==1.0.1
+llama-cpp-python
\ No newline at end of file
diff --git a/backend/python/neutts/requirements.txt b/backend/python/neutts/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9262a3934510bacb5cf004158b12092d5d2d5beb
--- /dev/null
+++ b/backend/python/neutts/requirements.txt
@@ -0,0 +1,7 @@
+grpcio==1.71.0
+protobuf
+certifi
+packaging
+setuptools
+numpy==2.2.6
+scikit_build_core
\ No newline at end of file
diff --git a/backend/python/neutts/run.sh b/backend/python/neutts/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7d9d321eded78cc6d2f2b21932bc4ba966ca43e8
--- /dev/null
+++ b/backend/python/neutts/run.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/neutts/test.py b/backend/python/neutts/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..878345ab64b4bc87dacbdfbeb73528d38d0f893f
--- /dev/null
+++ b/backend/python/neutts/test.py
@@ -0,0 +1,82 @@
+"""
+A test script to test the gRPC service
+"""
+import unittest
+import subprocess
+import time
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service
+ """
+ def setUp(self):
+ """
+ This method sets up the gRPC service by starting the server
+ """
+ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
+ time.sleep(30)
+
+ def tearDown(self) -> None:
+ """
+ This method tears down the gRPC service by terminating the server
+ """
+ self.service.terminate()
+ self.service.wait()
+
+ def test_server_startup(self):
+ """
+ This method tests if the server starts up successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+
+ def test_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions())
+ print(response)
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_tts(self):
+ """
+ This method tests if the embeddings are generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions())
+ self.assertTrue(response.success)
+ tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
+ tts_response = stub.TTS(tts_request)
+ self.assertIsNotNone(tts_response)
+ except Exception as err:
+ print(err)
+ self.fail("TTS service failed")
+ finally:
+ self.tearDown()
\ No newline at end of file
diff --git a/backend/python/neutts/test.sh b/backend/python/neutts/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eb59f2aaf3f38f78c7ca3dc414ea490ff66776d7
--- /dev/null
+++ b/backend/python/neutts/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/pocket-tts/Makefile b/backend/python/pocket-tts/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..3366bb4874ce766a401cff2107a555482ec2b37e
--- /dev/null
+++ b/backend/python/pocket-tts/Makefile
@@ -0,0 +1,23 @@
+.PHONY: pocket-tts
+pocket-tts:
+ bash install.sh
+
+.PHONY: run
+run: pocket-tts
+ @echo "Running pocket-tts..."
+ bash run.sh
+ @echo "pocket-tts run."
+
+.PHONY: test
+test: pocket-tts
+ @echo "Testing pocket-tts..."
+ bash test.sh
+ @echo "pocket-tts tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
diff --git a/backend/python/pocket-tts/backend.py b/backend/python/pocket-tts/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..b02cf481a55c9f0a95089e6cf58f867e5dd61c9c
--- /dev/null
+++ b/backend/python/pocket-tts/backend.py
@@ -0,0 +1,255 @@
+#!/usr/bin/env python3
+"""
+This is an extra gRPC server of LocalAI for Pocket TTS
+"""
+from concurrent import futures
+import time
+import argparse
+import signal
+import sys
+import os
+import traceback
+import scipy.io.wavfile
+import backend_pb2
+import backend_pb2_grpc
+import torch
+from pocket_tts import TTSModel
+
+import grpc
+
+def is_float(s):
+ """Check if a string can be converted to float."""
+ try:
+ float(s)
+ return True
+ except ValueError:
+ return False
+
+def is_int(s):
+ """Check if a string can be converted to int."""
+ try:
+ int(s)
+ return True
+ except ValueError:
+ return False
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ BackendServicer is the class that implements the gRPC service
+ """
+ def Health(self, request, context):
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+ def LoadModel(self, request, context):
+ # Get device
+ if torch.cuda.is_available():
+ print("CUDA is available", file=sys.stderr)
+ device = "cuda"
+ else:
+ print("CUDA is not available", file=sys.stderr)
+ device = "cpu"
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+ if mps_available:
+ device = "mps"
+ if not torch.cuda.is_available() and request.CUDA:
+ return backend_pb2.Result(success=False, message="CUDA is not available")
+
+ # Normalize potential 'mpx' typo to 'mps'
+ if device == "mpx":
+ print("Note: device 'mpx' detected, treating it as 'mps'.", file=sys.stderr)
+ device = "mps"
+
+ # Validate mps availability if requested
+ if device == "mps" and not torch.backends.mps.is_available():
+ print("Warning: MPS not available. Falling back to CPU.", file=sys.stderr)
+ device = "cpu"
+
+ self.device = device
+
+ options = request.Options
+
+ # empty dict
+ self.options = {}
+
+ # The options are a list of strings in this form optname:optvalue
+ # We are storing all the options in a dict so we can use it later when
+ # generating the audio
+ for opt in options:
+ if ":" not in opt:
+ continue
+ key, value = opt.split(":", 1) # Split only on first colon
+ # if value is a number, convert it to the appropriate type
+ if is_float(value):
+ value = float(value)
+ elif is_int(value):
+ value = int(value)
+ elif value.lower() in ["true", "false"]:
+ value = value.lower() == "true"
+ self.options[key] = value
+
+ # Default voice for caching
+ self.default_voice_url = self.options.get("default_voice", None)
+ self._voice_cache = {}
+
+ try:
+ print("Loading Pocket TTS model", file=sys.stderr)
+ self.tts_model = TTSModel.load_model()
+ print(f"Model loaded successfully. Sample rate: {self.tts_model.sample_rate}", file=sys.stderr)
+
+ # Pre-load default voice if specified
+ if self.default_voice_url:
+ try:
+ print(f"Pre-loading default voice: {self.default_voice_url}", file=sys.stderr)
+ voice_state = self.tts_model.get_state_for_audio_prompt(self.default_voice_url)
+ self._voice_cache[self.default_voice_url] = voice_state
+ print("Default voice loaded successfully", file=sys.stderr)
+ except Exception as e:
+ print(f"Warning: Failed to pre-load default voice: {e}", file=sys.stderr)
+
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+
+ return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+ def _get_voice_state(self, voice_input):
+ """
+ Get voice state from cache or load it.
+ voice_input can be:
+ - HuggingFace URL (e.g., hf://kyutai/tts-voices/alba-mackenna/casual.wav)
+ - Local file path
+ - None (use default)
+ """
+ # Use default if no voice specified
+ if not voice_input:
+ voice_input = self.default_voice_url
+
+ if not voice_input:
+ return None
+
+ # Check cache first
+ if voice_input in self._voice_cache:
+ return self._voice_cache[voice_input]
+
+ # Load voice state
+ try:
+ print(f"Loading voice from: {voice_input}", file=sys.stderr)
+ voice_state = self.tts_model.get_state_for_audio_prompt(voice_input)
+ self._voice_cache[voice_input] = voice_state
+ return voice_state
+ except Exception as e:
+ print(f"Error loading voice from {voice_input}: {e}", file=sys.stderr)
+ return None
+
+ def TTS(self, request, context):
+ try:
+ # Determine voice input
+ # Priority: request.voice > AudioPath (from ModelOptions) > default
+ voice_input = None
+
+ if request.voice:
+ voice_input = request.voice
+ elif hasattr(request, 'AudioPath') and request.AudioPath:
+ # Use AudioPath as voice file
+ if os.path.isabs(request.AudioPath):
+ voice_input = request.AudioPath
+ elif hasattr(request, 'ModelFile') and request.ModelFile:
+ model_file_base = os.path.dirname(request.ModelFile)
+ voice_input = os.path.join(model_file_base, request.AudioPath)
+ elif hasattr(request, 'ModelPath') and request.ModelPath:
+ voice_input = os.path.join(request.ModelPath, request.AudioPath)
+ else:
+ voice_input = request.AudioPath
+
+ # Get voice state
+ voice_state = self._get_voice_state(voice_input)
+ if voice_state is None:
+ return backend_pb2.Result(
+ success=False,
+ message=f"Voice not found or failed to load: {voice_input}. Please provide a valid voice URL or file path."
+ )
+
+ # Prepare text
+ text = request.text.strip()
+
+ if not text:
+ return backend_pb2.Result(
+ success=False,
+ message="Text is empty"
+ )
+
+ print(f"Generating audio for text: {text[:50]}...", file=sys.stderr)
+
+ # Generate audio
+ audio = self.tts_model.generate_audio(voice_state, text)
+
+ # Audio is a 1D torch tensor containing PCM data
+ if audio is None or audio.numel() == 0:
+ return backend_pb2.Result(
+ success=False,
+ message="No audio generated"
+ )
+
+ # Save audio to file
+ output_path = request.dst
+ if not output_path:
+ output_path = "/tmp/pocket-tts-output.wav"
+
+ # Ensure output directory exists
+ output_dir = os.path.dirname(output_path)
+ if output_dir and not os.path.exists(output_dir):
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Convert torch tensor to numpy and save
+ audio_numpy = audio.numpy()
+ scipy.io.wavfile.write(output_path, self.tts_model.sample_rate, audio_numpy)
+ print(f"Saved audio to {output_path}", file=sys.stderr)
+
+ except Exception as err:
+ print(f"Error in TTS: {err}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+
+ return backend_pb2.Result(success=True)
+
+def serve(address):
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ server.add_insecure_port(address)
+ server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+
+ # Define the signal handler function
+ def signal_handler(sig, frame):
+ print("Received termination signal. Shutting down...")
+ server.stop(0)
+ sys.exit(0)
+
+ # Set the signal handlers for SIGINT and SIGTERM
+ signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, signal_handler)
+
+ try:
+ while True:
+ time.sleep(_ONE_DAY_IN_SECONDS)
+ except KeyboardInterrupt:
+ server.stop(0)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ serve(args.addr)
diff --git a/backend/python/pocket-tts/install.sh b/backend/python/pocket-tts/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6058b3d545ad8ff93f776e23a362bc761e9e9e47
--- /dev/null
+++ b/backend/python/pocket-tts/install.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
+# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
+if [ "x${BUILD_PROFILE}" == "xintel" ]; then
+ EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+fi
+
+# Use python 3.12 for l4t
+if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
+ PYTHON_VERSION="3.12"
+ PYTHON_PATCH="12"
+ PY_STANDALONE_TAG="20251120"
+fi
+
+if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
+ USE_PIP=true
+fi
+
+installRequirements
diff --git a/backend/python/pocket-tts/protogen.sh b/backend/python/pocket-tts/protogen.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1ad37dee164bf2aaf7371a196b411c7ae843527d
--- /dev/null
+++ b/backend/python/pocket-tts/protogen.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto
diff --git a/backend/python/pocket-tts/requirements-cpu.txt b/backend/python/pocket-tts/requirements-cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d14153bc5aafbd7dabf2d8a89c649d64a2b34ca8
--- /dev/null
+++ b/backend/python/pocket-tts/requirements-cpu.txt
@@ -0,0 +1,4 @@
+--extra-index-url https://download.pytorch.org/whl/cpu
+pocket-tts
+scipy
+torch
diff --git a/backend/python/pocket-tts/requirements-cublas12.txt b/backend/python/pocket-tts/requirements-cublas12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f43f5094b9f457cdca418316fdcb29906c2ce1a8
--- /dev/null
+++ b/backend/python/pocket-tts/requirements-cublas12.txt
@@ -0,0 +1,4 @@
+--extra-index-url https://download.pytorch.org/whl/cu121
+pocket-tts
+scipy
+torch
diff --git a/backend/python/pocket-tts/requirements-cublas13.txt b/backend/python/pocket-tts/requirements-cublas13.txt
new file mode 100644
index 0000000000000000000000000000000000000000..26e07545fdc7cde317dde107910e9d8c18f3f160
--- /dev/null
+++ b/backend/python/pocket-tts/requirements-cublas13.txt
@@ -0,0 +1,4 @@
+--extra-index-url https://download.pytorch.org/whl/cu130
+pocket-tts
+scipy
+torch
diff --git a/backend/python/pocket-tts/requirements-hipblas.txt b/backend/python/pocket-tts/requirements-hipblas.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b6f9d2fb6a0abffedebd22feaf3f5e55920e8ef6
--- /dev/null
+++ b/backend/python/pocket-tts/requirements-hipblas.txt
@@ -0,0 +1,4 @@
+--extra-index-url https://download.pytorch.org/whl/rocm6.3
+pocket-tts
+scipy
+torch==2.7.1+rocm6.3
diff --git a/backend/python/pocket-tts/requirements-intel.txt b/backend/python/pocket-tts/requirements-intel.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3bb61cb7311d20081861dc2b1b9ba2cece3a5d48
--- /dev/null
+++ b/backend/python/pocket-tts/requirements-intel.txt
@@ -0,0 +1,4 @@
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pocket-tts
+scipy
+torch==2.5.1+cxx11.abi
diff --git a/backend/python/pocket-tts/requirements-l4t12.txt b/backend/python/pocket-tts/requirements-l4t12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..39131ac17b3632dd78f1451dd177c012de0db9cd
--- /dev/null
+++ b/backend/python/pocket-tts/requirements-l4t12.txt
@@ -0,0 +1,4 @@
+--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/
+pocket-tts
+scipy
+torch
diff --git a/backend/python/pocket-tts/requirements-l4t13.txt b/backend/python/pocket-tts/requirements-l4t13.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d6503f7c118d6f300fdc7ec11db064e8ea36acd4
--- /dev/null
+++ b/backend/python/pocket-tts/requirements-l4t13.txt
@@ -0,0 +1,4 @@
+--extra-index-url https://download.pytorch.org/whl/cu130
+pocket-tts
+scipy
+torch
\ No newline at end of file
diff --git a/backend/python/pocket-tts/requirements-mps.txt b/backend/python/pocket-tts/requirements-mps.txt
new file mode 100644
index 0000000000000000000000000000000000000000..235eaffd54fb6d83ad9df1b3fee2198875947377
--- /dev/null
+++ b/backend/python/pocket-tts/requirements-mps.txt
@@ -0,0 +1,4 @@
+pocket-tts
+scipy
+torch==2.7.1
+torchvision==0.22.1
diff --git a/backend/python/pocket-tts/requirements.txt b/backend/python/pocket-tts/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9e532186b2c8d2061852e7ce9ebd7f5536dc9763
--- /dev/null
+++ b/backend/python/pocket-tts/requirements.txt
@@ -0,0 +1,4 @@
+grpcio==1.71.0
+protobuf
+certifi
+packaging==24.1
diff --git a/backend/python/pocket-tts/run.sh b/backend/python/pocket-tts/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eae121f37b0bf655d6b5dce60647099e666ea01a
--- /dev/null
+++ b/backend/python/pocket-tts/run.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
diff --git a/backend/python/pocket-tts/test.py b/backend/python/pocket-tts/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..34efa1080d00f26fdac37a8abce16501422c6245
--- /dev/null
+++ b/backend/python/pocket-tts/test.py
@@ -0,0 +1,141 @@
+"""
+A test script to test the gRPC service
+"""
+import unittest
+import subprocess
+import time
+import os
+import tempfile
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service
+ """
+ def setUp(self):
+ """
+ This method sets up the gRPC service by starting the server
+ """
+ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
+ time.sleep(30)
+
+ def tearDown(self) -> None:
+ """
+ This method tears down the gRPC service by terminating the server
+ """
+ self.service.terminate()
+ self.service.wait()
+
+ def test_server_startup(self):
+ """
+ This method tests if the server starts up successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+
+ def test_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions())
+ print(response)
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_tts_with_hf_voice(self):
+ """
+ This method tests TTS generation with HuggingFace voice URL
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ # Load model
+ response = stub.LoadModel(backend_pb2.ModelOptions())
+ self.assertTrue(response.success)
+
+ # Create temporary output file
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
+ output_path = tmp_file.name
+
+ # Test TTS with HuggingFace voice URL
+ tts_request = backend_pb2.TTSRequest(
+ text="Hello world, this is a test.",
+ dst=output_path,
+ voice="azelma"
+ )
+ tts_response = stub.TTS(tts_request)
+ self.assertTrue(tts_response.success)
+
+ # Verify output file exists and is not empty
+ self.assertTrue(os.path.exists(output_path))
+ self.assertGreater(os.path.getsize(output_path), 0)
+
+ # Cleanup
+ os.unlink(output_path)
+ except Exception as err:
+ print(err)
+ self.fail("TTS service failed")
+ finally:
+ self.tearDown()
+
+ def test_tts_with_default_voice(self):
+ """
+ This method tests TTS generation with default voice (via AudioPath in LoadModel)
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ # Load model with default voice
+ load_request = backend_pb2.ModelOptions(
+ Options=["default_voice:azelma"]
+ )
+ response = stub.LoadModel(load_request)
+ self.assertTrue(response.success)
+
+ # Create temporary output file
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
+ output_path = tmp_file.name
+
+ # Test TTS without specifying voice (should use default)
+ tts_request = backend_pb2.TTSRequest(
+ text="Hello world, this is a test.",
+ dst=output_path
+ )
+ tts_response = stub.TTS(tts_request)
+ self.assertTrue(tts_response.success)
+
+ # Verify output file exists and is not empty
+ self.assertTrue(os.path.exists(output_path))
+ self.assertGreater(os.path.getsize(output_path), 0)
+
+ # Cleanup
+ os.unlink(output_path)
+ except Exception as err:
+ print(err)
+ self.fail("TTS service with default voice failed")
+ finally:
+ self.tearDown()
diff --git a/backend/python/pocket-tts/test.sh b/backend/python/pocket-tts/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eb59f2aaf3f38f78c7ca3dc414ea490ff66776d7
--- /dev/null
+++ b/backend/python/pocket-tts/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/rerankers/Makefile b/backend/python/rerankers/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..c9a1d30104b4d88e59b243c037d5cd33ba60bfd7
--- /dev/null
+++ b/backend/python/rerankers/Makefile
@@ -0,0 +1,24 @@
+.PHONY: rerankers
+rerankers:
+ bash install.sh
+
+.PHONY: run
+run: rerankers
+ @echo "Running rerankers..."
+ bash run.sh
+ @echo "rerankers run."
+
+# It is not working well by using command line. It only6 works with IDE like VSCode.
+.PHONY: test
+test: rerankers
+ @echo "Testing rerankers..."
+ bash test.sh
+ @echo "rerankers tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/rerankers/README.md b/backend/python/rerankers/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9e73ba0accc1b157eabd440a6c11ffe8435e78b2
--- /dev/null
+++ b/backend/python/rerankers/README.md
@@ -0,0 +1,5 @@
+# Creating a separate environment for the reranker project
+
+```
+make reranker
+```
\ No newline at end of file
diff --git a/backend/python/rerankers/backend.py b/backend/python/rerankers/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ce2636d7a13a87c858b54e421240b7f370fbf82
--- /dev/null
+++ b/backend/python/rerankers/backend.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+"""
+Extra gRPC server for Rerankers models.
+"""
+from concurrent import futures
+
+import argparse
+import signal
+import sys
+import os
+
+import time
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+from rerankers import Reranker
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ A gRPC servicer for the backend service.
+
+ This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding.
+ """
+ def Health(self, request, context):
+ """
+ A gRPC method that returns the health status of the backend service.
+
+ Args:
+ request: A HealthRequest object that contains the request parameters.
+ context: A grpc.ServicerContext object that provides information about the RPC.
+
+ Returns:
+ A Reply object that contains the health status of the backend service.
+ """
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+ def LoadModel(self, request, context):
+ """
+ A gRPC method that loads a model into memory.
+
+ Args:
+ request: A LoadModelRequest object that contains the request parameters.
+ context: A grpc.ServicerContext object that provides information about the RPC.
+
+ Returns:
+ A Result object that contains the result of the LoadModel operation.
+ """
+ model_name = request.Model
+ try:
+ kwargs = {}
+ if request.Type != "":
+ kwargs['model_type'] = request.Type
+ if request.PipelineType != "": # Reuse the PipelineType field for language
+ kwargs['lang'] = request.PipelineType
+ self.model_name = model_name
+ self.model = Reranker(model_name, **kwargs)
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+
+ # Implement your logic here for the LoadModel service
+ # Replace this with your desired response
+ return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+ def Rerank(self, request, context):
+ documents = []
+ for idx, doc in enumerate(request.documents):
+ documents.append(doc)
+ ranked_results=self.model.rank(query=request.query, docs=documents, doc_ids=list(range(len(request.documents))))
+ # Prepare results to return
+ cropped_results = ranked_results.top_k(request.top_n) if request.top_n > 0 else ranked_results
+ results = [
+ backend_pb2.DocumentResult(
+ index=res.doc_id,
+ text=res.text,
+ relevance_score=res.score
+ ) for res in (cropped_results)
+ ]
+
+ # Calculate the usage and total tokens
+ # TODO: Implement the usage calculation with reranker
+ total_tokens = sum(len(doc.split()) for doc in request.documents) + len(request.query.split())
+ prompt_tokens = len(request.query.split())
+ usage = backend_pb2.Usage(total_tokens=total_tokens, prompt_tokens=prompt_tokens)
+ return backend_pb2.RerankResult(usage=usage, results=results)
+
+def serve(address):
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ server.add_insecure_port(address)
+ server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+
+ # Define the signal handler function
+ def signal_handler(sig, frame):
+ print("Received termination signal. Shutting down...")
+ server.stop(0)
+ sys.exit(0)
+
+ # Set the signal handlers for SIGINT and SIGTERM
+ signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, signal_handler)
+
+ try:
+ while True:
+ time.sleep(_ONE_DAY_IN_SECONDS)
+ except KeyboardInterrupt:
+ server.stop(0)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ serve(args.addr)
diff --git a/backend/python/rerankers/install.sh b/backend/python/rerankers/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4cd3f65111fb66244667671773d38deca0dc06d8
--- /dev/null
+++ b/backend/python/rerankers/install.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
+# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
+if [ "x${BUILD_PROFILE}" == "xintel" ]; then
+ EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+fi
+
+installRequirements
diff --git a/backend/python/rerankers/requirements-cpu.txt b/backend/python/rerankers/requirements-cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e27a4726379799ccc30c48d82d72f0b39937c061
--- /dev/null
+++ b/backend/python/rerankers/requirements-cpu.txt
@@ -0,0 +1,4 @@
+transformers
+accelerate
+torch==2.4.1
+rerankers[transformers]
\ No newline at end of file
diff --git a/backend/python/rerankers/requirements-cublas12.txt b/backend/python/rerankers/requirements-cublas12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e27a4726379799ccc30c48d82d72f0b39937c061
--- /dev/null
+++ b/backend/python/rerankers/requirements-cublas12.txt
@@ -0,0 +1,4 @@
+transformers
+accelerate
+torch==2.4.1
+rerankers[transformers]
\ No newline at end of file
diff --git a/backend/python/rerankers/requirements-cublas13.txt b/backend/python/rerankers/requirements-cublas13.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b565a9cc154a9b4afe05d6945d8e535556bae482
--- /dev/null
+++ b/backend/python/rerankers/requirements-cublas13.txt
@@ -0,0 +1,5 @@
+--extra-index-url https://download.pytorch.org/whl/cu130
+transformers
+accelerate
+torch==2.9.1
+rerankers[transformers]
\ No newline at end of file
diff --git a/backend/python/rerankers/requirements-hipblas.txt b/backend/python/rerankers/requirements-hipblas.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7a72b3d0650fe3e837358be9851ea09a74b78727
--- /dev/null
+++ b/backend/python/rerankers/requirements-hipblas.txt
@@ -0,0 +1,5 @@
+--extra-index-url https://download.pytorch.org/whl/rocm6.4
+transformers
+accelerate
+torch==2.8.0+rocm6.4
+rerankers[transformers]
\ No newline at end of file
diff --git a/backend/python/rerankers/requirements-intel.txt b/backend/python/rerankers/requirements-intel.txt
new file mode 100644
index 0000000000000000000000000000000000000000..820dd84224a754e051e81c88ac13f3237dd95e3b
--- /dev/null
+++ b/backend/python/rerankers/requirements-intel.txt
@@ -0,0 +1,9 @@
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+intel-extension-for-pytorch==2.3.110+xpu
+transformers
+accelerate
+torch==2.3.1+cxx11.abi
+oneccl_bind_pt==2.8.0+xpu
+rerankers[transformers]
+optimum[openvino]
+setuptools
\ No newline at end of file
diff --git a/backend/python/rerankers/requirements.txt b/backend/python/rerankers/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cf77f433c70b02e3cab4c74237bc827db1e8f292
--- /dev/null
+++ b/backend/python/rerankers/requirements.txt
@@ -0,0 +1,3 @@
+grpcio==1.76.0
+protobuf
+certifi
\ No newline at end of file
diff --git a/backend/python/rerankers/run.sh b/backend/python/rerankers/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4d2769c5a35359202dc0eb8210fdbf172ea823e4
--- /dev/null
+++ b/backend/python/rerankers/run.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/rerankers/test.py b/backend/python/rerankers/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5890fc25d2407f3965a007a2d915e06887c0d6d
--- /dev/null
+++ b/backend/python/rerankers/test.py
@@ -0,0 +1,146 @@
+"""
+A test script to test the gRPC service
+"""
+import unittest
+import subprocess
+import time
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service
+ """
+ def setUp(self):
+ """
+ This method sets up the gRPC service by starting the server
+ """
+ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
+ time.sleep(10)
+
+ def tearDown(self) -> None:
+ """
+ This method tears down the gRPC service by terminating the server
+ """
+ self.service.kill()
+ self.service.wait()
+
+ def test_server_startup(self):
+ """
+ This method tests if the server starts up successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+
+ def test_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_rerank(self):
+ """
+ This method tests if the embeddings are generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ request = backend_pb2.RerankRequest(
+ query="I love you",
+ documents=["I hate you", "I really like you"],
+ top_n=2
+ )
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
+ self.assertTrue(response.success)
+
+ rerank_response = stub.Rerank(request)
+ print(rerank_response.results[0])
+ self.assertIsNotNone(rerank_response.results)
+ self.assertEqual(len(rerank_response.results), 2)
+ self.assertEqual(rerank_response.results[0].text, "I really like you")
+ self.assertEqual(rerank_response.results[1].text, "I hate you")
+ except Exception as err:
+ print(err)
+ self.fail("Reranker service failed")
+ finally:
+ self.tearDown()
+
+ def test_rerank_omit_top_n(self):
+ """
+ This method tests if the embeddings are generated successfully even top_n is omitted
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ request = backend_pb2.RerankRequest(
+ query="I love you",
+ documents=["I hate you", "I really like you"],
+ top_n=0 #
+ )
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
+ self.assertTrue(response.success)
+
+ rerank_response = stub.Rerank(request)
+ print(rerank_response.results[0])
+ self.assertIsNotNone(rerank_response.results)
+ self.assertEqual(len(rerank_response.results), 2)
+ self.assertEqual(rerank_response.results[0].text, "I really like you")
+ self.assertEqual(rerank_response.results[1].text, "I hate you")
+ except Exception as err:
+ print(err)
+ self.fail("Reranker service failed")
+ finally:
+ self.tearDown()
+
+ def test_rerank_crop(self):
+ """
+ This method tests top_n cropping
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ request = backend_pb2.RerankRequest(
+ query="I love you",
+ documents=["I hate you", "I really like you", "I hate ignoring top_n"],
+ top_n=2
+ )
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
+ self.assertTrue(response.success)
+
+ rerank_response = stub.Rerank(request)
+ print(rerank_response.results[0])
+ self.assertIsNotNone(rerank_response.results)
+ self.assertEqual(len(rerank_response.results), 2)
+ self.assertEqual(rerank_response.results[0].text, "I really like you")
+ self.assertEqual(rerank_response.results[1].text, "I hate you")
+ except Exception as err:
+ print(err)
+ self.fail("Reranker service failed")
+ finally:
+ self.tearDown()
diff --git a/backend/python/rerankers/test.sh b/backend/python/rerankers/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eb59f2aaf3f38f78c7ca3dc414ea490ff66776d7
--- /dev/null
+++ b/backend/python/rerankers/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/rfdetr/Makefile b/backend/python/rfdetr/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..f6b9ddc6c888845a9b20c98d3ef8bfae3629a1cd
--- /dev/null
+++ b/backend/python/rfdetr/Makefile
@@ -0,0 +1,13 @@
+.DEFAULT_GOAL := install
+
+.PHONY: install
+install:
+ bash install.sh
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/rfdetr/backend.py b/backend/python/rfdetr/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..57f68647f3254cd55d05d7e8d896c229e219c9d2
--- /dev/null
+++ b/backend/python/rfdetr/backend.py
@@ -0,0 +1,174 @@
+#!/usr/bin/env python3
+"""
+gRPC server for RFDETR object detection models.
+"""
+from concurrent import futures
+
+import argparse
+import signal
+import sys
+import os
+import time
+import base64
+import backend_pb2
+import backend_pb2_grpc
+import grpc
+
+import requests
+
+import supervision as sv
+from inference import get_model
+from PIL import Image
+from io import BytesIO
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ A gRPC servicer for the RFDETR backend service.
+
+ This class implements the gRPC methods for object detection using RFDETR models.
+ """
+
+ def __init__(self):
+ self.model = None
+ self.model_name = None
+
+ def Health(self, request, context):
+ """
+ A gRPC method that returns the health status of the backend service.
+
+ Args:
+ request: A HealthMessage object that contains the request parameters.
+ context: A grpc.ServicerContext object that provides information about the RPC.
+
+ Returns:
+ A Reply object that contains the health status of the backend service.
+ """
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+ def LoadModel(self, request, context):
+ """
+ A gRPC method that loads a RFDETR model into memory.
+
+ Args:
+ request: A ModelOptions object that contains the model parameters.
+ context: A grpc.ServicerContext object that provides information about the RPC.
+
+ Returns:
+ A Result object that contains the result of the LoadModel operation.
+ """
+ model_name = request.Model
+ try:
+ # Load the RFDETR model
+ self.model = get_model(model_name)
+ self.model_name = model_name
+ print(f'Loaded RFDETR model: {model_name}')
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Failed to load model: {err}")
+
+ return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+ def Detect(self, request, context):
+ """
+ A gRPC method that performs object detection on an image.
+
+ Args:
+ request: A DetectOptions object that contains the image source.
+ context: A grpc.ServicerContext object that provides information about the RPC.
+
+ Returns:
+ A DetectResponse object that contains the detection results.
+ """
+ if self.model is None:
+ print(f"Model is None")
+ return backend_pb2.DetectResponse()
+ print(f"Model is not None")
+ try:
+ print(f"Decoding image")
+ # Decode the base64 image
+ print(f"Image data: {request.src}")
+
+ image_data = base64.b64decode(request.src)
+ image = Image.open(BytesIO(image_data))
+
+ # Perform inference
+ predictions = self.model.infer(image, confidence=0.5)[0]
+
+ # Convert to proto format
+ proto_detections = []
+ for i in range(len(predictions.predictions)):
+ pred = predictions.predictions[i]
+ print(f"Prediction: {pred}")
+ proto_detection = backend_pb2.Detection(
+ x=float(pred.x),
+ y=float(pred.y),
+ width=float(pred.width),
+ height=float(pred.height),
+ confidence=float(pred.confidence),
+ class_name=pred.class_name
+ )
+ proto_detections.append(proto_detection)
+
+ return backend_pb2.DetectResponse(Detections=proto_detections)
+ except Exception as err:
+ print(f"Detection error: {err}")
+ return backend_pb2.DetectResponse()
+
+ def Status(self, request, context):
+ """
+ A gRPC method that returns the status of the backend service.
+
+ Args:
+ request: A HealthMessage object that contains the request parameters.
+ context: A grpc.ServicerContext object that provides information about the RPC.
+
+ Returns:
+ A StatusResponse object that contains the status information.
+ """
+ state = backend_pb2.StatusResponse.READY if self.model is not None else backend_pb2.StatusResponse.UNINITIALIZED
+ return backend_pb2.StatusResponse(state=state)
+
+def serve(address):
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ server.add_insecure_port(address)
+ server.start()
+ print("[RFDETR] Server started. Listening on: " + address, file=sys.stderr)
+
+ # Define the signal handler function
+ def signal_handler(sig, frame):
+ print("[RFDETR] Received termination signal. Shutting down...")
+ server.stop(0)
+ sys.exit(0)
+
+ # Set the signal handlers for SIGINT and SIGTERM
+ signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, signal_handler)
+
+ try:
+ while True:
+ time.sleep(_ONE_DAY_IN_SECONDS)
+ except KeyboardInterrupt:
+ server.stop(0)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the RFDETR gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+ print(f"[RFDETR] startup: {args}", file=sys.stderr)
+ serve(args.addr)
+
+
+
diff --git a/backend/python/rfdetr/install.sh b/backend/python/rfdetr/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..32befa8e6c034734e54038c3b4b565522028f391
--- /dev/null
+++ b/backend/python/rfdetr/install.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
+# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
+if [ "x${BUILD_PROFILE}" == "xintel" ]; then
+ EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+fi
+
+installRequirements
diff --git a/backend/python/rfdetr/requirements-cpu.txt b/backend/python/rfdetr/requirements-cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d0d1f4afaa94360b23573d777c546f2661194d2f
--- /dev/null
+++ b/backend/python/rfdetr/requirements-cpu.txt
@@ -0,0 +1,7 @@
+rfdetr
+opencv-python
+accelerate
+peft
+inference
+torch==2.7.1
+optimum-quanto
\ No newline at end of file
diff --git a/backend/python/rfdetr/requirements-cublas12.txt b/backend/python/rfdetr/requirements-cublas12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..36eaa47bb372e8703ad84f12ec00ce22f844b348
--- /dev/null
+++ b/backend/python/rfdetr/requirements-cublas12.txt
@@ -0,0 +1,7 @@
+torch==2.7.1
+rfdetr
+opencv-python
+accelerate
+inference
+peft
+optimum-quanto
\ No newline at end of file
diff --git a/backend/python/rfdetr/requirements-cublas13.txt b/backend/python/rfdetr/requirements-cublas13.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d75a2013c24ddc7e6c753a9fed12dab0fde355aa
--- /dev/null
+++ b/backend/python/rfdetr/requirements-cublas13.txt
@@ -0,0 +1,8 @@
+--extra-index-url https://download.pytorch.org/whl/cu130
+torch==2.9.1
+rfdetr
+opencv-python
+accelerate
+inference
+peft
+optimum-quanto
\ No newline at end of file
diff --git a/backend/python/rfdetr/requirements-hipblas.txt b/backend/python/rfdetr/requirements-hipblas.txt
new file mode 100644
index 0000000000000000000000000000000000000000..884cfba7be4628621d3ba3c60418e093c68b8249
--- /dev/null
+++ b/backend/python/rfdetr/requirements-hipblas.txt
@@ -0,0 +1,9 @@
+--extra-index-url https://download.pytorch.org/whl/rocm6.4
+torch==2.8.0+rocm6.4
+torchvision==0.23.0+rocm6.4
+rfdetr
+opencv-python
+accelerate
+inference
+peft
+optimum-quanto
\ No newline at end of file
diff --git a/backend/python/rfdetr/requirements-intel.txt b/backend/python/rfdetr/requirements-intel.txt
new file mode 100644
index 0000000000000000000000000000000000000000..55fcbb318d9924c34444976341415065fb1ec2f6
--- /dev/null
+++ b/backend/python/rfdetr/requirements-intel.txt
@@ -0,0 +1,13 @@
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+intel-extension-for-pytorch==2.3.110+xpu
+torch==2.3.1+cxx11.abi
+torchvision==0.18.1+cxx11.abi
+oneccl_bind_pt==2.3.100+xpu
+optimum[openvino]
+setuptools
+rfdetr
+inference
+opencv-python
+accelerate
+peft
+optimum-quanto
\ No newline at end of file
diff --git a/backend/python/rfdetr/requirements.txt b/backend/python/rfdetr/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..44b40efd0b1b514235f2362ee2b22233ba4a749c
--- /dev/null
+++ b/backend/python/rfdetr/requirements.txt
@@ -0,0 +1,3 @@
+grpcio==1.71.0
+protobuf
+grpcio-tools
diff --git a/backend/python/rfdetr/run.sh b/backend/python/rfdetr/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..82b7b09ecc7d20cb3513bdd1b18f3524f7d4cbd7
--- /dev/null
+++ b/backend/python/rfdetr/run.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/rfdetr/test.sh b/backend/python/rfdetr/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eb59f2aaf3f38f78c7ca3dc414ea490ff66776d7
--- /dev/null
+++ b/backend/python/rfdetr/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/transformers/Makefile b/backend/python/transformers/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..6897baf0c9b49c420436e89d5e4ecc46dc21195c
--- /dev/null
+++ b/backend/python/transformers/Makefile
@@ -0,0 +1,24 @@
+.PHONY: transformers
+transformers:
+ bash install.sh
+
+.PHONY: run
+run: transformers
+ @echo "Running transformers..."
+ bash run.sh
+ @echo "transformers run."
+
+# It is not working well by using command line. It only6 works with IDE like VSCode.
+.PHONY: test
+test: transformers
+ @echo "Testing transformers..."
+ bash test.sh
+ @echo "transformers tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/transformers/README.md b/backend/python/transformers/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0e72e598338fd82c752f1c7e980604efdca65396
--- /dev/null
+++ b/backend/python/transformers/README.md
@@ -0,0 +1,5 @@
+# Creating a separate environment for the transformers project
+
+```
+make transformers
+```
\ No newline at end of file
diff --git a/backend/python/transformers/backend.py b/backend/python/transformers/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..05713b917d2a346a4cd3ae6769815344fa2fa803
--- /dev/null
+++ b/backend/python/transformers/backend.py
@@ -0,0 +1,679 @@
+#!/usr/bin/env python3
+"""
+Extra gRPC server for HuggingFace AutoModel models.
+"""
+from concurrent import futures
+
+import argparse
+import signal
+import sys
+import os
+from threading import Thread
+import asyncio
+
+import time
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+import torch
+import torch.cuda
+
+
+XPU=os.environ.get("XPU", "0") == "1"
+from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM
+from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration
+from scipy.io import wavfile
+import outetts
+from sentence_transformers import SentenceTransformer
+
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+
+def mean_pooling(model_output, attention_mask):
+ """
+ Mean pooling to get sentence embeddings. See:
+ https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1
+ """
+ token_embeddings = model_output[0]
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) # Sum columns
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+ return sum_embeddings / sum_mask
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ A gRPC servicer for the backend service.
+
+ This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding.
+ """
+ def Health(self, request, context):
+ """
+ A gRPC method that returns the health status of the backend service.
+
+ Args:
+ request: A HealthRequest object that contains the request parameters.
+ context: A grpc.ServicerContext object that provides information about the RPC.
+
+ Returns:
+ A Reply object that contains the health status of the backend service.
+ """
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+ def LoadModel(self, request, context):
+ """
+ A gRPC method that loads a model into memory.
+
+ Args:
+ request: A LoadModelRequest object that contains the request parameters.
+ context: A grpc.ServicerContext object that provides information about the RPC.
+
+ Returns:
+ A Result object that contains the result of the LoadModel operation.
+ """
+
+ model_name = request.Model
+
+ # Check to see if the Model exists in the filesystem already.
+ if os.path.exists(request.ModelFile):
+ model_name = request.ModelFile
+
+ compute = torch.float16
+ if request.F16Memory == True:
+ compute=torch.bfloat16
+
+ self.CUDA = torch.cuda.is_available()
+ self.OV=False
+ self.OuteTTS=False
+ self.DiaTTS=False
+ self.SentenceTransformer = False
+
+ device_map="cpu"
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+ if mps_available:
+ device_map = "mps"
+ quantization = None
+ autoTokenizer = True
+
+ # Parse options from request.Options
+ self.options = {}
+ options = request.Options
+
+ # The options are a list of strings in this form optname:optvalue
+ # We are storing all the options in a dict so we can use it later when generating
+ # Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"]
+ for opt in options:
+ if ":" not in opt:
+ continue
+ key, value = opt.split(":", 1)
+ # if value is a number, convert it to the appropriate type
+ try:
+ if "." in value:
+ value = float(value)
+ else:
+ value = int(value)
+ except ValueError:
+ # Keep as string if conversion fails
+ pass
+ self.options[key] = value
+
+ print(f"Parsed options: {self.options}", file=sys.stderr)
+
+ if self.CUDA:
+ from transformers import BitsAndBytesConfig, AutoModelForCausalLM
+ if request.MainGPU:
+ device_map=request.MainGPU
+ else:
+ device_map="cuda:0"
+ if request.Quantization == "bnb_4bit":
+ quantization = BitsAndBytesConfig(
+ load_in_4bit = True,
+ bnb_4bit_compute_dtype = compute,
+ bnb_4bit_quant_type = "nf4",
+ bnb_4bit_use_double_quant = True,
+ load_in_8bit = False,
+ )
+ elif request.Quantization == "bnb_8bit":
+ quantization = BitsAndBytesConfig(
+ load_in_4bit=False,
+ bnb_4bit_compute_dtype = None,
+ load_in_8bit=True,
+ )
+
+ try:
+ if request.Type == "AutoModelForCausalLM":
+ if XPU:
+ import intel_extension_for_pytorch as ipex
+ from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
+
+ device_map="xpu"
+ compute=torch.float16
+ if request.Quantization == "xpu_4bit":
+ xpu_4bit = True
+ xpu_8bit = False
+ elif request.Quantization == "xpu_8bit":
+ xpu_4bit = False
+ xpu_8bit = True
+ else:
+ xpu_4bit = False
+ xpu_8bit = False
+ self.model = AutoModelForCausalLM.from_pretrained(model_name,
+ trust_remote_code=request.TrustRemoteCode,
+ use_safetensors=True,
+ device_map=device_map,
+ load_in_4bit=xpu_4bit,
+ load_in_8bit=xpu_8bit,
+ torch_dtype=compute)
+ else:
+ self.model = AutoModelForCausalLM.from_pretrained(model_name,
+ trust_remote_code=request.TrustRemoteCode,
+ use_safetensors=True,
+ quantization_config=quantization,
+ device_map=device_map,
+ torch_dtype=compute)
+ elif request.Type == "OVModelForCausalLM":
+ from optimum.intel.openvino import OVModelForCausalLM
+ from openvino.runtime import Core
+
+ if request.MainGPU:
+ device_map=request.MainGPU
+ else:
+ device_map="AUTO"
+ devices = Core().available_devices
+ if "GPU" in " ".join(devices):
+ device_map="AUTO:GPU"
+ # While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected.
+ # https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html
+ if "CPU" or "NPU" in device_map:
+ if "-CPU" or "-NPU" not in device_map:
+ ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}
+ else:
+ ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}
+ self.model = OVModelForCausalLM.from_pretrained(model_name,
+ compile=True,
+ trust_remote_code=request.TrustRemoteCode,
+ ov_config=ovconfig,
+ device=device_map)
+ self.OV = True
+ elif request.Type == "OVModelForFeatureExtraction":
+ from optimum.intel.openvino import OVModelForFeatureExtraction
+ from openvino.runtime import Core
+
+ if request.MainGPU:
+ device_map=request.MainGPU
+ else:
+ device_map="AUTO"
+ devices = Core().available_devices
+ if "GPU" in " ".join(devices):
+ device_map="AUTO:GPU"
+ # While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected.
+ # https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html
+ if "CPU" or "NPU" in device_map:
+ if "-CPU" or "-NPU" not in device_map:
+ ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}
+ else:
+ ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}
+ self.model = OVModelForFeatureExtraction.from_pretrained(model_name,
+ compile=True,
+ trust_remote_code=request.TrustRemoteCode,
+ ov_config=ovconfig,
+ export=True,
+ device=device_map)
+ self.OV = True
+ elif request.Type == "MusicgenForConditionalGeneration":
+ autoTokenizer = False
+ self.processor = AutoProcessor.from_pretrained(model_name)
+ self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
+ elif request.Type == "DiaForConditionalGeneration":
+ autoTokenizer = False
+ print("DiaForConditionalGeneration", file=sys.stderr)
+ self.processor = AutoProcessor.from_pretrained(model_name)
+ self.model = DiaForConditionalGeneration.from_pretrained(model_name)
+ if self.CUDA:
+ self.model = self.model.to("cuda")
+ self.processor = self.processor.to("cuda")
+ print("DiaForConditionalGeneration loaded", file=sys.stderr)
+ self.DiaTTS = True
+ elif request.Type == "OuteTTS":
+ autoTokenizer = False
+ options = request.Options
+ MODELNAME = "OuteAI/OuteTTS-0.3-1B"
+ TOKENIZER = "OuteAI/OuteTTS-0.3-1B"
+ VERSION = "0.3"
+ SPEAKER = "en_male_1"
+ for opt in options:
+ if opt.startswith("tokenizer:"):
+ TOKENIZER = opt.split(":")[1]
+ break
+ if opt.startswith("version:"):
+ VERSION = opt.split(":")[1]
+ break
+ if opt.startswith("speaker:"):
+ SPEAKER = opt.split(":")[1]
+ break
+
+ if model_name != "":
+ MODELNAME = model_name
+
+ # Configure the model
+ model_config = outetts.HFModelConfig_v2(
+ model_path=MODELNAME,
+ tokenizer_path=TOKENIZER
+ )
+ # Initialize the interface
+ self.interface = outetts.InterfaceHF(model_version=VERSION, cfg=model_config)
+ self.OuteTTS = True
+
+ self.interface.print_default_speakers()
+ if request.AudioPath:
+ if os.path.isabs(request.AudioPath):
+ self.AudioPath = request.AudioPath
+ else:
+ self.AudioPath = os.path.join(request.ModelPath, request.AudioPath)
+ self.speaker = self.interface.create_speaker(audio_path=self.AudioPath)
+ else:
+ self.speaker = self.interface.load_default_speaker(name=SPEAKER)
+ elif request.Type == "SentenceTransformer":
+ autoTokenizer = False
+ self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
+ self.SentenceTransformer = True
+ elif request.Type == "Mamba":
+ autoTokenizer = False
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+ self.model = MambaForCausalLM.from_pretrained(model_name)
+ else:
+ print("Automodel", file=sys.stderr)
+ self.model = AutoModel.from_pretrained(model_name,
+ trust_remote_code=request.TrustRemoteCode,
+ use_safetensors=True,
+ quantization_config=quantization,
+ device_map=device_map,
+ torch_dtype=compute)
+ if request.ContextSize > 0:
+ self.max_tokens = request.ContextSize
+ elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
+ self.max_tokens = self.model.config.max_position_embeddings
+ else:
+ self.max_tokens = self.options.get("max_new_tokens", 512)
+
+ if autoTokenizer:
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
+ self.XPU = False
+
+ if XPU and self.OV == False:
+ self.XPU = True
+ try:
+ print("Optimizing model", model_name, "to XPU.", file=sys.stderr)
+ self.model = ipex.optimize_transformers(self.model, inplace=True, dtype=torch.float16, device="xpu")
+ except Exception as err:
+ print("Not using XPU:", err, file=sys.stderr)
+
+ except Exception as err:
+ print("Error:", err, file=sys.stderr)
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ # Implement your logic here for the LoadModel service
+ # Replace this with your desired response
+ return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+ def Embedding(self, request, context):
+ """
+ A gRPC method that calculates embeddings for a given sentence.
+
+ Args:
+ request: An EmbeddingRequest object that contains the request parameters.
+ context: A grpc.ServicerContext object that provides information about the RPC.
+
+ Returns:
+ An EmbeddingResult object that contains the calculated embeddings.
+ """
+
+ set_seed(request.Seed)
+ # Tokenize input
+ max_length = 512
+ if request.Tokens != 0:
+ max_length = request.Tokens
+
+ embeds = None
+
+ if self.SentenceTransformer:
+ print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
+ embeds = self.model.encode(request.Embeddings)
+ else:
+ encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
+
+ # Create word embeddings
+ if self.CUDA:
+ encoded_input = encoded_input.to("cuda")
+
+ with torch.no_grad():
+ model_output = self.model(**encoded_input)
+
+ # Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
+ sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+ embeds = sentence_embeddings[0]
+ return backend_pb2.EmbeddingResult(embeddings=embeds)
+
+ async def _predict(self, request, context, streaming=False):
+ set_seed(request.Seed)
+ if request.TopP < 0 or request.TopP > 1:
+ request.TopP = 1
+
+ if request.TopK <= 0:
+ request.TopK = 50
+
+ if request.Temperature > 0 :
+ sample=True
+ else:
+ sample=False
+ request.TopP == None
+ request.TopK == None
+ request.Temperature == None
+
+ prompt = request.Prompt
+ if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
+ prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
+
+ inputs = self.tokenizer(prompt, return_tensors="pt")
+
+ if request.Tokens > 0:
+ max_tokens = request.Tokens
+ else:
+ max_tokens = self.max_tokens - inputs["input_ids"].size()[inputs["input_ids"].dim()-1]
+
+ if self.CUDA:
+ inputs = inputs.to("cuda")
+ if XPU and self.OV == False:
+ inputs = inputs.to("xpu")
+ streaming = False
+
+ criteria=[]
+ if request.StopPrompts:
+ criteria = StoppingCriteriaList(
+ [
+ StopStringCriteria(tokenizer=self.tokenizer, stop_strings=request.StopPrompts),
+ ]
+ )
+
+ if streaming:
+ streamer=TextIteratorStreamer(self.tokenizer,
+ skip_prompt=True,
+ skip_special_tokens=True)
+ config=dict(inputs,
+ max_new_tokens=max_tokens,
+ temperature=request.Temperature,
+ top_p=request.TopP,
+ top_k=request.TopK,
+ do_sample=sample,
+ attention_mask=inputs["attention_mask"],
+ eos_token_id=self.tokenizer.eos_token_id,
+ pad_token_id=self.tokenizer.eos_token_id,
+ streamer=streamer,
+ stopping_criteria=criteria,
+ use_cache=True,
+ )
+ thread=Thread(target=self.model.generate, kwargs=config)
+ thread.start()
+ generated_text = ""
+ try:
+ for new_text in streamer:
+ generated_text += new_text
+ yield backend_pb2.Reply(message=bytes(new_text, encoding='utf-8'))
+ finally:
+ thread.join()
+ else:
+ if XPU and self.OV == False:
+ outputs = self.model.generate(inputs["input_ids"],
+ max_new_tokens=max_tokens,
+ temperature=request.Temperature,
+ top_p=request.TopP,
+ top_k=request.TopK,
+ do_sample=sample,
+ pad_token=self.tokenizer.eos_token_id)
+ else:
+ outputs = self.model.generate(**inputs,
+ max_new_tokens=max_tokens,
+ temperature=request.Temperature,
+ top_p=request.TopP,
+ top_k=request.TopK,
+ do_sample=sample,
+ eos_token_id=self.tokenizer.eos_token_id,
+ pad_token_id=self.tokenizer.eos_token_id,
+ stopping_criteria=criteria,
+ use_cache=True,
+ )
+ generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
+
+ if streaming:
+ return
+
+ yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
+
+ async def Predict(self, request, context):
+ """
+ Generates text based on the given prompt and sampling parameters.
+
+ Args:
+ request: The predict request.
+ context: The gRPC context.
+
+ Returns:
+ backend_pb2.Reply: The predict result.
+ """
+ gen = self._predict(request, context, streaming=False)
+ res = await gen.__anext__()
+ return res
+
+ async def PredictStream(self, request, context):
+ """
+ Generates text based on the given prompt and sampling parameters, and streams the results.
+
+ Args:
+ request: The predict stream request.
+ context: The gRPC context.
+
+ Returns:
+ backend_pb2.Result: The predict stream result.
+ """
+ iterations = self._predict(request, context, streaming=True)
+ try:
+ async for iteration in iterations:
+ yield iteration
+ finally:
+ await iterations.aclose()
+
+ def SoundGeneration(self, request, context):
+ model_name = request.model
+ try:
+ if self.processor is None:
+ if model_name == "":
+ return backend_pb2.Result(success=False, message="request.model is required")
+ self.processor = AutoProcessor.from_pretrained(model_name)
+ if self.model is None:
+ if model_name == "":
+ return backend_pb2.Result(success=False, message="request.model is required")
+ self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
+ inputs = None
+ if request.text == "":
+ inputs = self.model.get_unconditional_inputs(num_samples=1)
+ elif request.HasField('src'):
+ # TODO SECURITY CODE GOES HERE LOL
+ # WHO KNOWS IF THIS WORKS???
+ sample_rate, wsamples = wavfile.read('path_to_your_file.wav')
+
+ if request.HasField('src_divisor'):
+ wsamples = wsamples[: len(wsamples) // request.src_divisor]
+
+ inputs = self.processor(
+ audio=wsamples,
+ sampling_rate=sample_rate,
+ text=[request.text],
+ padding=True,
+ return_tensors="pt",
+ )
+ else:
+ inputs = self.processor(
+ text=[request.text],
+ padding=True,
+ return_tensors="pt",
+ )
+
+ if request.HasField('duration'):
+ tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
+ guidance = self.options.get("guidance_scale", 3.0)
+ if request.HasField('temperature'):
+ guidance = request.temperature
+ dosample = self.options.get("do_sample", True)
+ if request.HasField('sample'):
+ dosample = request.sample
+ audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens)
+ print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
+ sampling_rate = self.model.config.audio_encoder.sampling_rate
+ wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
+ print("[transformers-musicgen] SoundGeneration saved to", request.dst, file=sys.stderr)
+ print("[transformers-musicgen] SoundGeneration for", file=sys.stderr)
+ print("[transformers-musicgen] SoundGeneration requested tokens", tokens, file=sys.stderr)
+ print(request, file=sys.stderr)
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ return backend_pb2.Result(success=True)
+
+
+ def CallDiaTTS(self, request, context):
+ """
+ Generates dialogue audio using the Dia model.
+
+ Args:
+ request: A TTSRequest containing text dialogue and generation parameters
+ context: The gRPC context
+
+ Returns:
+ A Result object indicating success or failure
+ """
+ try:
+ print("[DiaTTS] generating dialogue audio", file=sys.stderr)
+
+ # Prepare text input - expect dialogue format like [S1] ... [S2] ...
+ text = [request.text]
+
+ # Process the input
+ inputs = self.processor(text=text, padding=True, return_tensors="pt")
+
+ # Generate audio with parameters from options or defaults
+ generation_params = {
+ **inputs,
+ "max_new_tokens": self.max_tokens,
+ "guidance_scale": self.options.get("guidance_scale", 3.0),
+ "temperature": self.options.get("temperature", 1.8),
+ "top_p": self.options.get("top_p", 0.90),
+ "top_k": self.options.get("top_k", 45)
+ }
+
+ outputs = self.model.generate(**generation_params)
+
+ # Decode and save audio
+ outputs = self.processor.batch_decode(outputs)
+ self.processor.save_audio(outputs, request.dst)
+
+ print("[DiaTTS] Generated dialogue audio", file=sys.stderr)
+ print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr)
+ print("[DiaTTS] Dialogue generation done", file=sys.stderr)
+
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ return backend_pb2.Result(success=True)
+
+
+ def CallOuteTTS(self, request, context):
+ try:
+ print("[OuteTTS] generating TTS", file=sys.stderr)
+ gen_cfg = outetts.GenerationConfig(
+ text="Speech synthesis is the artificial production of human speech.",
+ temperature=self.options.get("temperature", 0.1),
+ repetition_penalty=self.options.get("repetition_penalty", 1.1),
+ max_length=self.max_tokens,
+ speaker=self.speaker,
+ # voice_characteristics="upbeat enthusiasm, friendliness, clarity, professionalism, and trustworthiness"
+ )
+ output = self.interface.generate(config=gen_cfg)
+ print("[OuteTTS] Generated TTS", file=sys.stderr)
+ output.save(request.dst)
+ print("[OuteTTS] TTS done", file=sys.stderr)
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ return backend_pb2.Result(success=True)
+
+# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
+ def TTS(self, request, context):
+ if self.OuteTTS:
+ return self.CallOuteTTS(request, context)
+
+ if self.DiaTTS:
+ print("DiaTTS", file=sys.stderr)
+ return self.CallDiaTTS(request, context)
+
+ model_name = request.model
+ try:
+ if self.processor is None:
+ if model_name == "":
+ return backend_pb2.Result(success=False, message="request.model is required")
+ self.processor = AutoProcessor.from_pretrained(model_name)
+ if self.model is None:
+ if model_name == "":
+ return backend_pb2.Result(success=False, message="request.model is required")
+ self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
+ inputs = self.processor(
+ text=[request.text],
+ padding=True,
+ return_tensors="pt",
+ )
+ tokens = self.max_tokens # No good place to set the "length" in TTS, so use 10s as a sane default
+ audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
+ print("[transformers-musicgen] TTS generated!", file=sys.stderr)
+ sampling_rate = self.model.config.audio_encoder.sampling_rate
+ wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
+ print("[transformers-musicgen] TTS saved to", request.dst, file=sys.stderr)
+ print("[transformers-musicgen] TTS for", file=sys.stderr)
+ print(request, file=sys.stderr)
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ return backend_pb2.Result(success=True)
+
+async def serve(address):
+ # Start asyncio gRPC server
+ server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ # Add the servicer to the server
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ # Bind the server to the address
+ server.add_insecure_port(address)
+
+ # Gracefully shutdown the server on SIGTERM or SIGINT
+ loop = asyncio.get_event_loop()
+ for sig in (signal.SIGINT, signal.SIGTERM):
+ loop.add_signal_handler(
+ sig, lambda: asyncio.ensure_future(server.stop(5))
+ )
+
+ # Start the server
+ await server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+ # Wait for the server to be terminated
+ await server.wait_for_termination()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ asyncio.run(serve(args.addr))
diff --git a/backend/python/transformers/install.sh b/backend/python/transformers/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..32befa8e6c034734e54038c3b4b565522028f391
--- /dev/null
+++ b/backend/python/transformers/install.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
+# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
+if [ "x${BUILD_PROFILE}" == "xintel" ]; then
+ EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+fi
+
+installRequirements
diff --git a/backend/python/transformers/requirements-cpu.txt b/backend/python/transformers/requirements-cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f06a276179ada86a913d5d40031277dbd83a68fb
--- /dev/null
+++ b/backend/python/transformers/requirements-cpu.txt
@@ -0,0 +1,9 @@
+torch==2.7.1
+llvmlite==0.43.0
+numba==0.60.0
+accelerate
+transformers
+bitsandbytes
+outetts
+sentence-transformers==5.2.0
+protobuf==6.33.4
\ No newline at end of file
diff --git a/backend/python/transformers/requirements-cublas12.txt b/backend/python/transformers/requirements-cublas12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8497cf7e7031471b771cb13242305b6802516403
--- /dev/null
+++ b/backend/python/transformers/requirements-cublas12.txt
@@ -0,0 +1,9 @@
+torch==2.7.1
+accelerate
+llvmlite==0.43.0
+numba==0.60.0
+transformers
+bitsandbytes
+outetts
+sentence-transformers==5.2.0
+protobuf==6.33.4
\ No newline at end of file
diff --git a/backend/python/transformers/requirements-cublas13.txt b/backend/python/transformers/requirements-cublas13.txt
new file mode 100644
index 0000000000000000000000000000000000000000..536978f8d2d6f32b6a63a7d33baa42d26eb375f5
--- /dev/null
+++ b/backend/python/transformers/requirements-cublas13.txt
@@ -0,0 +1,9 @@
+--extra-index-url https://download.pytorch.org/whl/cu130
+torch==2.9.0
+llvmlite==0.43.0
+numba==0.60.0
+transformers
+bitsandbytes
+outetts
+sentence-transformers==5.2.0
+protobuf==6.33.4
\ No newline at end of file
diff --git a/backend/python/transformers/requirements-hipblas.txt b/backend/python/transformers/requirements-hipblas.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0576c6acf1084789da29b218999cc2b3652de15b
--- /dev/null
+++ b/backend/python/transformers/requirements-hipblas.txt
@@ -0,0 +1,11 @@
+--extra-index-url https://download.pytorch.org/whl/rocm6.4
+torch==2.8.0+rocm6.4
+accelerate
+transformers
+llvmlite==0.43.0
+numba==0.60.0
+bitsandbytes
+outetts
+bitsandbytes
+sentence-transformers==5.2.0
+protobuf==6.33.4
\ No newline at end of file
diff --git a/backend/python/transformers/requirements-intel.txt b/backend/python/transformers/requirements-intel.txt
new file mode 100644
index 0000000000000000000000000000000000000000..836861246562276e3c66fa84529ebca393d2176a
--- /dev/null
+++ b/backend/python/transformers/requirements-intel.txt
@@ -0,0 +1,13 @@
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+intel-extension-for-pytorch==2.3.110+xpu
+torch==2.5.1+cxx11.abi
+oneccl_bind_pt==2.8.0+xpu
+optimum[openvino]
+llvmlite==0.43.0
+numba==0.60.0
+transformers
+intel-extension-for-transformers
+bitsandbytes
+outetts
+sentence-transformers==5.2.0
+protobuf==6.33.4
\ No newline at end of file
diff --git a/backend/python/transformers/requirements.txt b/backend/python/transformers/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5006c007f3f8e9df4c24328765625fe6e75c8373
--- /dev/null
+++ b/backend/python/transformers/requirements.txt
@@ -0,0 +1,6 @@
+grpcio==1.76.0
+protobuf==6.33.4
+certifi
+setuptools
+scipy==1.15.1
+numpy>=2.0.0
\ No newline at end of file
diff --git a/backend/python/transformers/run.sh b/backend/python/transformers/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0b7cb77690793ddfa5c91d4a2d86b81d61ce5660
--- /dev/null
+++ b/backend/python/transformers/run.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+if [ -d "/opt/intel" ]; then
+ # Assumes we are using the Intel oneAPI container image
+ # https://github.com/intel/intel-extension-for-pytorch/issues/538
+ export XPU=1
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/transformers/test.py b/backend/python/transformers/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..14efa6a7d8abb86489596af76a6f7806afbfdbc4
--- /dev/null
+++ b/backend/python/transformers/test.py
@@ -0,0 +1,173 @@
+"""
+A test script to test the gRPC service
+"""
+import unittest
+import subprocess
+import time
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service
+ """
+ def setUp(self):
+ """
+ This method sets up the gRPC service by starting the server
+ """
+ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
+ time.sleep(10)
+
+ def tearDown(self) -> None:
+ """
+ This method tears down the gRPC service by terminating the server
+ """
+ self.service.kill()
+ self.service.wait()
+
+ def test_server_startup(self):
+ """
+ This method tests if the server starts up successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+
+ def test_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-cased"))
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_embedding(self):
+ """
+ This method tests if the embeddings are generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-cased"))
+ print(response.message)
+ self.assertTrue(response.success)
+ embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
+ embedding_response = stub.Embedding(embedding_request)
+ self.assertIsNotNone(embedding_response.embeddings)
+ except Exception as err:
+ print(err)
+ self.fail("Embedding service failed")
+ finally:
+ self.tearDown()
+
+ def test_audio_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/musicgen-small",Type="MusicgenForConditionalGeneration"))
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_tts(self):
+ """
+ This method tests if TTS is generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/musicgen-small",Type="MusicgenForConditionalGeneration"))
+ self.assertTrue(response.success)
+ tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
+ tts_response = stub.TTS(tts_request)
+ self.assertIsNotNone(tts_response)
+ except Exception as err:
+ print(err)
+ self.fail("TTS service failed")
+ finally:
+ self.tearDown()
+
+ def test_sound_generation(self):
+ """
+ This method tests if SoundGeneration is generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/musicgen-small",Type="MusicgenForConditionalGeneration"))
+ self.assertTrue(response.success)
+ sg_request = backend_pb2.SoundGenerationRequest(text="80s TV news production music hit for tonight's biggest story")
+ sg_response = stub.SoundGeneration(sg_request)
+ self.assertIsNotNone(sg_response)
+ except Exception as err:
+ print(err)
+ self.fail("SoundGeneration service failed")
+ finally:
+ self.tearDown()
+
+ def test_embed_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-nli-mean-tokens",Type="SentenceTransformer"))
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_sentencetransformers_embedding(self):
+ """
+ This method tests if the embeddings are generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-nli-mean-tokens",Type="SentenceTransformer"))
+ self.assertTrue(response.success)
+ embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
+ embedding_response = stub.Embedding(embedding_request)
+ self.assertIsNotNone(embedding_response.embeddings)
+ except Exception as err:
+ print(err)
+ self.fail("Embedding service failed")
+ finally:
+ self.tearDown()
\ No newline at end of file
diff --git a/backend/python/transformers/test.sh b/backend/python/transformers/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eb59f2aaf3f38f78c7ca3dc414ea490ff66776d7
--- /dev/null
+++ b/backend/python/transformers/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/vibevoice/Makefile b/backend/python/vibevoice/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..2fd2297be20257b8d8eaf66686064dba76eae755
--- /dev/null
+++ b/backend/python/vibevoice/Makefile
@@ -0,0 +1,23 @@
+.PHONY: vibevoice
+vibevoice:
+ bash install.sh
+
+.PHONY: run
+run: vibevoice
+ @echo "Running vibevoice..."
+ bash run.sh
+ @echo "vibevoice run."
+
+.PHONY: test
+test: vibevoice
+ @echo "Testing vibevoice..."
+ bash test.sh
+ @echo "vibevoice tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/vibevoice/backend.py b/backend/python/vibevoice/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..418940bcb8170cf4e70874a0e1db6a8e5c8ac36d
--- /dev/null
+++ b/backend/python/vibevoice/backend.py
@@ -0,0 +1,485 @@
+#!/usr/bin/env python3
+"""
+This is an extra gRPC server of LocalAI for VibeVoice
+"""
+from concurrent import futures
+import time
+import argparse
+import signal
+import sys
+import os
+import copy
+import traceback
+from pathlib import Path
+import backend_pb2
+import backend_pb2_grpc
+import torch
+from vibevoice.modular.modeling_vibevoice_streaming_inference import VibeVoiceStreamingForConditionalGenerationInference
+from vibevoice.processor.vibevoice_streaming_processor import VibeVoiceStreamingProcessor
+
+import grpc
+
+def is_float(s):
+ """Check if a string can be converted to float."""
+ try:
+ float(s)
+ return True
+ except ValueError:
+ return False
+def is_int(s):
+ """Check if a string can be converted to int."""
+ try:
+ int(s)
+ return True
+ except ValueError:
+ return False
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ BackendServicer is the class that implements the gRPC service
+ """
+ def Health(self, request, context):
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+ def LoadModel(self, request, context):
+ # Get device
+ if torch.cuda.is_available():
+ print("CUDA is available", file=sys.stderr)
+ device = "cuda"
+ else:
+ print("CUDA is not available", file=sys.stderr)
+ device = "cpu"
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+ if mps_available:
+ device = "mps"
+ if not torch.cuda.is_available() and request.CUDA:
+ return backend_pb2.Result(success=False, message="CUDA is not available")
+
+ # Normalize potential 'mpx' typo to 'mps'
+ if device == "mpx":
+ print("Note: device 'mpx' detected, treating it as 'mps'.", file=sys.stderr)
+ device = "mps"
+
+ # Validate mps availability if requested
+ if device == "mps" and not torch.backends.mps.is_available():
+ print("Warning: MPS not available. Falling back to CPU.", file=sys.stderr)
+ device = "cpu"
+
+ self.device = device
+ self._torch_device = torch.device(device)
+
+ options = request.Options
+
+ # empty dict
+ self.options = {}
+
+ # The options are a list of strings in this form optname:optvalue
+ # We are storing all the options in a dict so we can use it later when
+ # generating the audio
+ for opt in options:
+ if ":" not in opt:
+ continue
+ key, value = opt.split(":", 1) # Split only on first colon
+ # if value is a number, convert it to the appropriate type
+ if is_float(value):
+ value = float(value)
+ elif is_int(value):
+ value = int(value)
+ elif value.lower() in ["true", "false"]:
+ value = value.lower() == "true"
+ self.options[key] = value
+
+ # Get model path from request
+ model_path = request.Model
+ if not model_path:
+ model_path = "microsoft/VibeVoice-Realtime-0.5B"
+
+ # Get inference steps from options, default to 5
+ self.inference_steps = self.options.get("inference_steps", 5)
+ if not isinstance(self.inference_steps, int) or self.inference_steps <= 0:
+ self.inference_steps = 5
+
+ # Get cfg_scale from options, default to 1.5
+ self.cfg_scale = self.options.get("cfg_scale", 1.5)
+ if not isinstance(self.cfg_scale, (int, float)) or self.cfg_scale <= 0:
+ self.cfg_scale = 1.5
+
+ # Determine voices directory
+ # Priority order:
+ # 1. voices_dir option (explicitly set by user - highest priority)
+ # 2. Relative to ModelFile if provided
+ # 3. Relative to ModelPath (models directory) if provided
+ # 4. Backend directory
+ # 5. Absolute path from AudioPath if provided
+ voices_dir = None
+
+ # First check if voices_dir is explicitly set in options
+ if "voices_dir" in self.options:
+ voices_dir_option = self.options["voices_dir"]
+ if isinstance(voices_dir_option, str) and voices_dir_option.strip():
+ voices_dir = voices_dir_option.strip()
+ # If relative path, try to resolve it relative to ModelPath or ModelFile
+ if not os.path.isabs(voices_dir):
+ if hasattr(request, 'ModelPath') and request.ModelPath:
+ voices_dir = os.path.join(request.ModelPath, voices_dir)
+ elif request.ModelFile:
+ model_file_base = os.path.dirname(request.ModelFile)
+ voices_dir = os.path.join(model_file_base, voices_dir)
+ # If still relative, make it absolute from current working directory
+ if not os.path.isabs(voices_dir):
+ voices_dir = os.path.abspath(voices_dir)
+ # Check if the directory exists
+ if not os.path.exists(voices_dir):
+ print(f"Warning: voices_dir option specified but directory does not exist: {voices_dir}", file=sys.stderr)
+ voices_dir = None
+
+ # If not set via option, try relative to ModelFile if provided
+ if not voices_dir and request.ModelFile:
+ model_file_base = os.path.dirname(request.ModelFile)
+ voices_dir = os.path.join(model_file_base, "voices", "streaming_model")
+ if not os.path.exists(voices_dir):
+ voices_dir = None
+
+ # If not found, try relative to ModelPath (models directory)
+ if not voices_dir and hasattr(request, 'ModelPath') and request.ModelPath:
+ voices_dir = os.path.join(request.ModelPath, "voices", "streaming_model")
+ if not os.path.exists(voices_dir):
+ voices_dir = None
+
+ # If not found, try relative to backend directory
+ if not voices_dir:
+ backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ voices_dir = os.path.join(backend_dir, "vibevoice", "voices", "streaming_model")
+ if not os.path.exists(voices_dir):
+ # Try absolute path from AudioPath if provided
+ if request.AudioPath and os.path.isabs(request.AudioPath):
+ voices_dir = os.path.dirname(request.AudioPath)
+ else:
+ voices_dir = None
+
+ self.voices_dir = voices_dir
+ self.voice_presets = {}
+ self._voice_cache = {}
+ self.default_voice_key = None
+
+ # Load voice presets if directory exists
+ if self.voices_dir and os.path.exists(self.voices_dir):
+ self._load_voice_presets()
+ else:
+ print(f"Warning: Voices directory not found. Voice presets will not be available.", file=sys.stderr)
+
+ try:
+ print(f"Loading processor & model from {model_path}", file=sys.stderr)
+ self.processor = VibeVoiceStreamingProcessor.from_pretrained(model_path)
+
+ # Decide dtype & attention implementation
+ if self.device == "mps":
+ load_dtype = torch.float32 # MPS requires float32
+ device_map = None
+ attn_impl_primary = "sdpa" # flash_attention_2 not supported on MPS
+ elif self.device == "cuda":
+ load_dtype = torch.bfloat16
+ device_map = "cuda"
+ attn_impl_primary = "flash_attention_2"
+ else: # cpu
+ load_dtype = torch.float32
+ device_map = "cpu"
+ attn_impl_primary = "sdpa"
+
+ print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}", file=sys.stderr)
+
+ # Load model with device-specific logic
+ try:
+ if self.device == "mps":
+ self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
+ model_path,
+ torch_dtype=load_dtype,
+ attn_implementation=attn_impl_primary,
+ device_map=None, # load then move
+ )
+ self.model.to("mps")
+ elif self.device == "cuda":
+ self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
+ model_path,
+ torch_dtype=load_dtype,
+ device_map="cuda",
+ attn_implementation=attn_impl_primary,
+ )
+ else: # cpu
+ self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
+ model_path,
+ torch_dtype=load_dtype,
+ device_map="cpu",
+ attn_implementation=attn_impl_primary,
+ )
+ except Exception as e:
+ if attn_impl_primary == 'flash_attention_2':
+ print(f"[ERROR] : {type(e).__name__}: {e}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.", file=sys.stderr)
+ self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
+ model_path,
+ torch_dtype=load_dtype,
+ device_map=(self.device if self.device in ("cuda", "cpu") else None),
+ attn_implementation='sdpa'
+ )
+ if self.device == "mps":
+ self.model.to("mps")
+ else:
+ raise e
+
+ self.model.eval()
+ self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
+
+ # Set default voice key
+ if self.voice_presets:
+ # Try to get default from environment or use first available
+ preset_name = os.environ.get("VOICE_PRESET")
+ self.default_voice_key = self._determine_voice_key(preset_name)
+ print(f"Default voice preset: {self.default_voice_key}", file=sys.stderr)
+ else:
+ print("Warning: No voice presets available. Voice selection will not work.", file=sys.stderr)
+
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+
+ return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+ def _load_voice_presets(self):
+ """Load voice presets from the voices directory."""
+ if not self.voices_dir or not os.path.exists(self.voices_dir):
+ self.voice_presets = {}
+ return
+
+ self.voice_presets = {}
+
+ # Get all .pt files in the voices directory
+ pt_files = [f for f in os.listdir(self.voices_dir)
+ if f.lower().endswith('.pt') and os.path.isfile(os.path.join(self.voices_dir, f))]
+
+ # Create dictionary with filename (without extension) as key
+ for pt_file in pt_files:
+ # Remove .pt extension to get the name
+ name = os.path.splitext(pt_file)[0]
+ # Create full path
+ full_path = os.path.join(self.voices_dir, pt_file)
+ self.voice_presets[name] = full_path
+
+ # Sort the voice presets alphabetically by name
+ self.voice_presets = dict(sorted(self.voice_presets.items()))
+
+ print(f"Found {len(self.voice_presets)} voice files in {self.voices_dir}", file=sys.stderr)
+ if self.voice_presets:
+ print(f"Available voices: {', '.join(self.voice_presets.keys())}", file=sys.stderr)
+
+ def _determine_voice_key(self, name):
+ """Determine voice key from name or use default."""
+ if name and name in self.voice_presets:
+ return name
+
+ # Try default key
+ default_key = "en-WHTest_man"
+ if default_key in self.voice_presets:
+ return default_key
+
+ # Use first available
+ if self.voice_presets:
+ first_key = next(iter(self.voice_presets))
+ print(f"Using fallback voice preset: {first_key}", file=sys.stderr)
+ return first_key
+
+ return None
+
+ def _get_voice_path(self, speaker_name):
+ """Get voice file path for a given speaker name."""
+ if not self.voice_presets:
+ return None
+
+ # First try exact match
+ if speaker_name and speaker_name in self.voice_presets:
+ return self.voice_presets[speaker_name]
+
+ # Try partial matching (case insensitive)
+ if speaker_name:
+ speaker_lower = speaker_name.lower()
+ for preset_name, path in self.voice_presets.items():
+ if preset_name.lower() in speaker_lower or speaker_lower in preset_name.lower():
+ return path
+
+ # Default to first voice if no match found
+ if self.default_voice_key and self.default_voice_key in self.voice_presets:
+ return self.voice_presets[self.default_voice_key]
+ elif self.voice_presets:
+ default_voice = list(self.voice_presets.values())[0]
+ print(f"Warning: No voice preset found for '{speaker_name}', using default voice: {default_voice}", file=sys.stderr)
+ return default_voice
+
+ return None
+
+ def _ensure_voice_cached(self, voice_path):
+ """Load and cache voice preset."""
+ if not voice_path or not os.path.exists(voice_path):
+ return None
+
+ # Use path as cache key
+ if voice_path not in self._voice_cache:
+ print(f"Loading prefilled prompt from {voice_path}", file=sys.stderr)
+ prefilled_outputs = torch.load(
+ voice_path,
+ map_location=self._torch_device,
+ weights_only=False,
+ )
+ self._voice_cache[voice_path] = prefilled_outputs
+
+ return self._voice_cache[voice_path]
+
+ def TTS(self, request, context):
+ try:
+ # Get voice selection
+ # Priority: request.voice > AudioPath > default
+ voice_path = None
+ voice_key = None
+
+ if request.voice:
+ # Try to get voice by name
+ voice_path = self._get_voice_path(request.voice)
+ if voice_path:
+ voice_key = request.voice
+ elif request.AudioPath:
+ # Use AudioPath as voice file
+ if os.path.isabs(request.AudioPath):
+ voice_path = request.AudioPath
+ elif request.ModelFile:
+ model_file_base = os.path.dirname(request.ModelFile)
+ voice_path = os.path.join(model_file_base, request.AudioPath)
+ elif hasattr(request, 'ModelPath') and request.ModelPath:
+ voice_path = os.path.join(request.ModelPath, request.AudioPath)
+ else:
+ voice_path = request.AudioPath
+ elif self.default_voice_key:
+ voice_path = self._get_voice_path(self.default_voice_key)
+ voice_key = self.default_voice_key
+
+ if not voice_path or not os.path.exists(voice_path):
+ return backend_pb2.Result(
+ success=False,
+ message=f"Voice file not found: {voice_path}. Please provide a valid voice preset or AudioPath."
+ )
+
+ # Load voice preset
+ prefilled_outputs = self._ensure_voice_cached(voice_path)
+ if prefilled_outputs is None:
+ return backend_pb2.Result(
+ success=False,
+ message=f"Failed to load voice preset from {voice_path}"
+ )
+
+ # Get generation parameters from options
+ cfg_scale = self.options.get("cfg_scale", self.cfg_scale)
+ inference_steps = self.options.get("inference_steps", self.inference_steps)
+ do_sample = self.options.get("do_sample", False)
+ temperature = self.options.get("temperature", 0.9)
+ top_p = self.options.get("top_p", 0.9)
+
+ # Update inference steps if needed
+ if inference_steps != self.inference_steps:
+ self.model.set_ddpm_inference_steps(num_steps=inference_steps)
+ self.inference_steps = inference_steps
+
+ # Prepare text
+ text = request.text.strip().replace("'", "'").replace('"', '"').replace('"', '"')
+
+ # Prepare inputs
+ inputs = self.processor.process_input_with_cached_prompt(
+ text=text,
+ cached_prompt=prefilled_outputs,
+ padding=True,
+ return_tensors="pt",
+ return_attention_mask=True,
+ )
+
+ # Move tensors to target device
+ target_device = self._torch_device
+ for k, v in inputs.items():
+ if torch.is_tensor(v):
+ inputs[k] = v.to(target_device)
+
+ print(f"Generating audio with cfg_scale: {cfg_scale}, inference_steps: {inference_steps}", file=sys.stderr)
+
+ # Generate audio
+ outputs = self.model.generate(
+ **inputs,
+ max_new_tokens=None,
+ cfg_scale=cfg_scale,
+ tokenizer=self.processor.tokenizer,
+ generation_config={
+ 'do_sample': do_sample,
+ 'temperature': temperature if do_sample else 1.0,
+ 'top_p': top_p if do_sample else 1.0,
+ },
+ verbose=False,
+ all_prefilled_outputs=copy.deepcopy(prefilled_outputs) if prefilled_outputs is not None else None,
+ )
+
+ # Save output
+ if outputs.speech_outputs and outputs.speech_outputs[0] is not None:
+ self.processor.save_audio(
+ outputs.speech_outputs[0], # First (and only) batch item
+ output_path=request.dst,
+ )
+ print(f"Saved output to {request.dst}", file=sys.stderr)
+ else:
+ return backend_pb2.Result(
+ success=False,
+ message="No audio output generated"
+ )
+
+ except Exception as err:
+ print(f"Error in TTS: {err}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+
+ return backend_pb2.Result(success=True)
+
+def serve(address):
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ server.add_insecure_port(address)
+ server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+
+ # Define the signal handler function
+ def signal_handler(sig, frame):
+ print("Received termination signal. Shutting down...")
+ server.stop(0)
+ sys.exit(0)
+
+ # Set the signal handlers for SIGINT and SIGTERM
+ signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, signal_handler)
+
+ try:
+ while True:
+ time.sleep(_ONE_DAY_IN_SECONDS)
+ except KeyboardInterrupt:
+ server.stop(0)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ serve(args.addr)
diff --git a/backend/python/vibevoice/install.sh b/backend/python/vibevoice/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3f669c6fe67f09dd55d4e154ef9f360ab8c768a5
--- /dev/null
+++ b/backend/python/vibevoice/install.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
+# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
+if [ "x${BUILD_PROFILE}" == "xintel" ]; then
+ EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+fi
+
+# Use python 3.12 for l4t
+if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
+ PYTHON_VERSION="3.12"
+ PYTHON_PATCH="12"
+ PY_STANDALONE_TAG="20251120"
+fi
+
+if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
+ USE_PIP=true
+fi
+
+installRequirements
+
+git clone https://github.com/microsoft/VibeVoice.git
+cd VibeVoice/
+
+if [ "x${USE_PIP}" == "xtrue" ]; then
+ pip install ${EXTRA_PIP_INSTALL_FLAGS:-} .
+else
+ uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} .
+fi
\ No newline at end of file
diff --git a/backend/python/vibevoice/requirements-cpu.txt b/backend/python/vibevoice/requirements-cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..607db4ae3ffe8f52eef07d08136940d5f5b381bd
--- /dev/null
+++ b/backend/python/vibevoice/requirements-cpu.txt
@@ -0,0 +1,22 @@
+--extra-index-url https://download.pytorch.org/whl/cpu
+git+https://github.com/huggingface/diffusers
+opencv-python
+transformers==4.51.3
+torchvision==0.22.1
+accelerate
+compel
+peft
+sentencepiece
+torch==2.7.1
+optimum-quanto
+ftfy
+llvmlite>=0.40.0
+numba>=0.57.0
+tqdm
+numpy
+scipy
+librosa
+ml-collections
+absl-py
+gradio
+av
\ No newline at end of file
diff --git a/backend/python/vibevoice/requirements-cublas12.txt b/backend/python/vibevoice/requirements-cublas12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..267a0313e407f5598462c18fbe32c3d2e1088a3e
--- /dev/null
+++ b/backend/python/vibevoice/requirements-cublas12.txt
@@ -0,0 +1,22 @@
+--extra-index-url https://download.pytorch.org/whl/cu121
+git+https://github.com/huggingface/diffusers
+opencv-python
+transformers==4.51.3
+torchvision
+accelerate
+compel
+peft
+sentencepiece
+torch
+ftfy
+optimum-quanto
+llvmlite>=0.40.0
+numba>=0.57.0
+tqdm
+numpy
+scipy
+librosa
+ml-collections
+absl-py
+gradio
+av
\ No newline at end of file
diff --git a/backend/python/vibevoice/requirements-cublas13.txt b/backend/python/vibevoice/requirements-cublas13.txt
new file mode 100644
index 0000000000000000000000000000000000000000..372be740b24ba7dea111df67ecc52008efa0140a
--- /dev/null
+++ b/backend/python/vibevoice/requirements-cublas13.txt
@@ -0,0 +1,22 @@
+--extra-index-url https://download.pytorch.org/whl/cu130
+git+https://github.com/huggingface/diffusers
+opencv-python
+transformers==4.51.3
+torchvision
+accelerate
+compel
+peft
+sentencepiece
+torch
+ftfy
+optimum-quanto
+llvmlite>=0.40.0
+numba>=0.57.0
+tqdm
+numpy
+scipy
+librosa
+ml-collections
+absl-py
+gradio
+av
\ No newline at end of file
diff --git a/backend/python/vibevoice/requirements-hipblas.txt b/backend/python/vibevoice/requirements-hipblas.txt
new file mode 100644
index 0000000000000000000000000000000000000000..291096c3f7557de46fb70e56176837496ff25d67
--- /dev/null
+++ b/backend/python/vibevoice/requirements-hipblas.txt
@@ -0,0 +1,22 @@
+--extra-index-url https://download.pytorch.org/whl/rocm6.3
+torch==2.7.1+rocm6.3
+torchvision==0.22.1+rocm6.3
+git+https://github.com/huggingface/diffusers
+opencv-python
+transformers==4.51.3
+accelerate
+compel
+peft
+sentencepiece
+optimum-quanto
+ftfy
+llvmlite>=0.40.0
+numba>=0.57.0
+tqdm
+numpy
+scipy
+librosa
+ml-collections
+absl-py
+gradio
+av
\ No newline at end of file
diff --git a/backend/python/vibevoice/requirements-intel.txt b/backend/python/vibevoice/requirements-intel.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e040ef6b56aac604d98d0ccfae218b41cf0e35e0
--- /dev/null
+++ b/backend/python/vibevoice/requirements-intel.txt
@@ -0,0 +1,26 @@
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+intel-extension-for-pytorch==2.3.110+xpu
+torch==2.5.1+cxx11.abi
+torchvision==0.20.1+cxx11.abi
+oneccl_bind_pt==2.8.0+xpu
+optimum[openvino]
+setuptools
+git+https://github.com/huggingface/diffusers
+opencv-python
+transformers==4.51.3
+accelerate
+compel
+peft
+sentencepiece
+optimum-quanto
+ftfy
+llvmlite>=0.40.0
+numba>=0.57.0
+tqdm
+numpy
+scipy
+librosa
+ml-collections
+absl-py
+gradio
+av
\ No newline at end of file
diff --git a/backend/python/vibevoice/requirements-l4t12.txt b/backend/python/vibevoice/requirements-l4t12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4e033c0f6cb981b6fa9e83325000c02b608e5e7c
--- /dev/null
+++ b/backend/python/vibevoice/requirements-l4t12.txt
@@ -0,0 +1,22 @@
+--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/
+torch
+git+https://github.com/huggingface/diffusers
+transformers==4.51.3
+accelerate
+compel
+peft
+optimum-quanto
+numpy<2
+sentencepiece
+torchvision
+ftfy
+llvmlite>=0.40.0
+numba>=0.57.0
+tqdm
+numpy
+scipy
+librosa
+ml-collections
+absl-py
+gradio
+av
\ No newline at end of file
diff --git a/backend/python/vibevoice/requirements-l4t13.txt b/backend/python/vibevoice/requirements-l4t13.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ca4848d50f1c11f87b3416eeb507772fe4a9b6f2
--- /dev/null
+++ b/backend/python/vibevoice/requirements-l4t13.txt
@@ -0,0 +1,22 @@
+--extra-index-url https://download.pytorch.org/whl/cu130
+torch
+git+https://github.com/huggingface/diffusers
+transformers==4.51.3
+accelerate
+compel
+peft
+optimum-quanto
+numpy<2
+sentencepiece
+torchvision
+ftfy
+llvmlite>=0.40.0
+numba>=0.57.0
+tqdm
+numpy
+scipy
+librosa
+ml-collections
+absl-py
+gradio
+av
\ No newline at end of file
diff --git a/backend/python/vibevoice/requirements-mps.txt b/backend/python/vibevoice/requirements-mps.txt
new file mode 100644
index 0000000000000000000000000000000000000000..11757190ecf568d5130ddcecb02fd76ceae3b42d
--- /dev/null
+++ b/backend/python/vibevoice/requirements-mps.txt
@@ -0,0 +1,21 @@
+torch==2.7.1
+torchvision==0.22.1
+git+https://github.com/huggingface/diffusers
+opencv-python
+transformers==4.51.3
+accelerate
+compel
+peft
+sentencepiece
+optimum-quanto
+ftfy
+llvmlite>=0.40.0
+numba>=0.57.0
+tqdm
+numpy
+scipy
+librosa
+ml-collections
+absl-py
+gradio
+av
\ No newline at end of file
diff --git a/backend/python/vibevoice/requirements.txt b/backend/python/vibevoice/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9e532186b2c8d2061852e7ce9ebd7f5536dc9763
--- /dev/null
+++ b/backend/python/vibevoice/requirements.txt
@@ -0,0 +1,4 @@
+grpcio==1.71.0
+protobuf
+certifi
+packaging==24.1
diff --git a/backend/python/vibevoice/run.sh b/backend/python/vibevoice/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..82b7b09ecc7d20cb3513bdd1b18f3524f7d4cbd7
--- /dev/null
+++ b/backend/python/vibevoice/run.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/vibevoice/test.py b/backend/python/vibevoice/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0b1a0bdd1240e1edfc9218c0c2ba032e5cf6300
--- /dev/null
+++ b/backend/python/vibevoice/test.py
@@ -0,0 +1,82 @@
+"""
+A test script to test the gRPC service
+"""
+import unittest
+import subprocess
+import time
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service
+ """
+ def setUp(self):
+ """
+ This method sets up the gRPC service by starting the server
+ """
+ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
+ time.sleep(30)
+
+ def tearDown(self) -> None:
+ """
+ This method tears down the gRPC service by terminating the server
+ """
+ self.service.terminate()
+ self.service.wait()
+
+ def test_server_startup(self):
+ """
+ This method tests if the server starts up successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+
+ def test_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits"))
+ print(response)
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_tts(self):
+ """
+ This method tests if the embeddings are generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits"))
+ self.assertTrue(response.success)
+ tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
+ tts_response = stub.TTS(tts_request)
+ self.assertIsNotNone(tts_response)
+ except Exception as err:
+ print(err)
+ self.fail("TTS service failed")
+ finally:
+ self.tearDown()
\ No newline at end of file
diff --git a/backend/python/vibevoice/test.sh b/backend/python/vibevoice/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eb59f2aaf3f38f78c7ca3dc414ea490ff66776d7
--- /dev/null
+++ b/backend/python/vibevoice/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/vllm/Makefile b/backend/python/vllm/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..c7c1b6869c029b045daf35501bc6b30a718c0014
--- /dev/null
+++ b/backend/python/vllm/Makefile
@@ -0,0 +1,23 @@
+.PHONY: vllm
+vllm:
+ bash install.sh
+
+.PHONY: run
+run: vllm
+ @echo "Running vllm..."
+ bash run.sh
+ @echo "vllm run."
+
+.PHONY: test
+test: vllm
+ @echo "Testing vllm..."
+ bash test.sh
+ @echo "vllm tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/vllm/README.md b/backend/python/vllm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..dc933d2a0b8b268e9b6b828d52adb83b6e4bde62
--- /dev/null
+++ b/backend/python/vllm/README.md
@@ -0,0 +1,5 @@
+# Creating a separate environment for the vllm project
+
+```
+make vllm
+```
\ No newline at end of file
diff --git a/backend/python/vllm/backend.py b/backend/python/vllm/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..56698a54e5f579ec1dfe00b9d4189573ab44cb55
--- /dev/null
+++ b/backend/python/vllm/backend.py
@@ -0,0 +1,367 @@
+#!/usr/bin/env python3
+import asyncio
+from concurrent import futures
+import argparse
+import signal
+import sys
+import os
+from typing import List
+from PIL import Image
+
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.sampling_params import SamplingParams
+from vllm.utils import random_uuid
+from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.multimodal.utils import fetch_image
+from vllm.assets.video import VideoAsset
+import base64
+import io
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """
+ A gRPC servicer that implements the Backend service defined in backend.proto.
+ """
+ def generate(self,prompt, max_new_tokens):
+ """
+ Generates text based on the given prompt and maximum number of new tokens.
+
+ Args:
+ prompt (str): The prompt to generate text from.
+ max_new_tokens (int): The maximum number of new tokens to generate.
+
+ Returns:
+ str: The generated text.
+ """
+ self.generator.end_beam_search()
+
+ # Tokenizing the input
+ ids = self.generator.tokenizer.encode(prompt)
+
+ self.generator.gen_begin_reuse(ids)
+ initial_len = self.generator.sequence[0].shape[0]
+ has_leading_space = False
+ decoded_text = ''
+ for i in range(max_new_tokens):
+ token = self.generator.gen_single_token()
+ if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
+ has_leading_space = True
+
+ decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
+ if has_leading_space:
+ decoded_text = ' ' + decoded_text
+
+ if token.item() == self.generator.tokenizer.eos_token_id:
+ break
+ return decoded_text
+
+ def Health(self, request, context):
+ """
+ Returns a health check message.
+
+ Args:
+ request: The health check request.
+ context: The gRPC context.
+
+ Returns:
+ backend_pb2.Reply: The health check reply.
+ """
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+ async def LoadModel(self, request, context):
+ """
+ Loads a language model.
+
+ Args:
+ request: The load model request.
+ context: The gRPC context.
+
+ Returns:
+ backend_pb2.Result: The load model result.
+ """
+ engine_args = AsyncEngineArgs(
+ model=request.Model,
+ )
+
+ if request.Quantization != "":
+ engine_args.quantization = request.Quantization
+ if request.LoadFormat != "":
+ engine_args.load_format = request.LoadFormat
+ if request.GPUMemoryUtilization != 0:
+ engine_args.gpu_memory_utilization = request.GPUMemoryUtilization
+ if request.TrustRemoteCode:
+ engine_args.trust_remote_code = request.TrustRemoteCode
+ if request.EnforceEager:
+ engine_args.enforce_eager = request.EnforceEager
+ if request.TensorParallelSize:
+ engine_args.tensor_parallel_size = request.TensorParallelSize
+ if request.SwapSpace != 0:
+ engine_args.swap_space = request.SwapSpace
+ if request.MaxModelLen != 0:
+ engine_args.max_model_len = request.MaxModelLen
+ if request.DisableLogStatus:
+ engine_args.disable_log_status = request.DisableLogStatus
+ if request.DType != "":
+ engine_args.dtype = request.DType
+ if request.LimitImagePerPrompt != 0 or request.LimitVideoPerPrompt != 0 or request.LimitAudioPerPrompt != 0:
+ # limit-mm-per-prompt defaults to 1 per modality, based on vLLM docs
+ engine_args.limit_mm_per_prompt = {
+ "image": max(request.LimitImagePerPrompt, 1),
+ "video": max(request.LimitVideoPerPrompt, 1),
+ "audio": max(request.LimitAudioPerPrompt, 1)
+ }
+
+ try:
+ self.llm = AsyncLLMEngine.from_engine_args(engine_args)
+ except Exception as err:
+ print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+
+ try:
+ engine_model_config = await self.llm.get_model_config()
+ self.tokenizer = get_tokenizer(
+ engine_model_config.tokenizer,
+ tokenizer_mode=engine_model_config.tokenizer_mode,
+ trust_remote_code=engine_model_config.trust_remote_code,
+ truncation_side="left",
+ )
+ except Exception as err:
+ return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+ print("Model loaded successfully", file=sys.stderr)
+ return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+ async def Predict(self, request, context):
+ """
+ Generates text based on the given prompt and sampling parameters.
+
+ Args:
+ request: The predict request.
+ context: The gRPC context.
+
+ Returns:
+ backend_pb2.Reply: The predict result.
+ """
+ gen = self._predict(request, context, streaming=False)
+ res = await gen.__anext__()
+ return res
+
+ def Embedding(self, request, context):
+ """
+ A gRPC method that calculates embeddings for a given sentence.
+
+ Args:
+ request: An EmbeddingRequest object that contains the request parameters.
+ context: A grpc.ServicerContext object that provides information about the RPC.
+
+ Returns:
+ An EmbeddingResult object that contains the calculated embeddings.
+ """
+ print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
+ outputs = self.model.encode(request.Embeddings)
+ # Check if we have one result at least
+ if len(outputs) == 0:
+ context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
+ context.set_details("No embeddings were calculated.")
+ return backend_pb2.EmbeddingResult()
+ return backend_pb2.EmbeddingResult(embeddings=outputs[0].outputs.embedding)
+
+ async def PredictStream(self, request, context):
+ """
+ Generates text based on the given prompt and sampling parameters, and streams the results.
+
+ Args:
+ request: The predict stream request.
+ context: The gRPC context.
+
+ Returns:
+ backend_pb2.Result: The predict stream result.
+ """
+ iterations = self._predict(request, context, streaming=True)
+ try:
+ async for iteration in iterations:
+ yield iteration
+ finally:
+ await iterations.aclose()
+
+ async def _predict(self, request, context, streaming=False):
+ # Build the sampling parameters
+ # NOTE: this must stay in sync with the vllm backend
+ request_to_sampling_params = {
+ "N": "n",
+ "PresencePenalty": "presence_penalty",
+ "FrequencyPenalty": "frequency_penalty",
+ "RepetitionPenalty": "repetition_penalty",
+ "Temperature": "temperature",
+ "TopP": "top_p",
+ "TopK": "top_k",
+ "MinP": "min_p",
+ "Seed": "seed",
+ "StopPrompts": "stop",
+ "StopTokenIds": "stop_token_ids",
+ "BadWords": "bad_words",
+ "IncludeStopStrInOutput": "include_stop_str_in_output",
+ "IgnoreEOS": "ignore_eos",
+ "Tokens": "max_tokens",
+ "MinTokens": "min_tokens",
+ "Logprobs": "logprobs",
+ "PromptLogprobs": "prompt_logprobs",
+ "SkipSpecialTokens": "skip_special_tokens",
+ "SpacesBetweenSpecialTokens": "spaces_between_special_tokens",
+ "TruncatePromptTokens": "truncate_prompt_tokens",
+ "GuidedDecoding": "guided_decoding",
+ }
+
+ sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
+
+ for request_field, param_field in request_to_sampling_params.items():
+ if hasattr(request, request_field):
+ value = getattr(request, request_field)
+ if value not in (None, 0, [], False):
+ setattr(sampling_params, param_field, value)
+
+ # Extract image paths and process images
+ prompt = request.Prompt
+
+ image_paths = request.Images
+ image_data = [self.load_image(img_path) for img_path in image_paths]
+
+ videos_path = request.Videos
+ video_data = [self.load_video(video_path) for video_path in videos_path]
+
+ # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
+ if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
+ prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
+
+ # Generate text using the LLM engine
+ request_id = random_uuid()
+ print(f"Generating text with request_id: {request_id}", file=sys.stderr)
+ multi_modal_data = {}
+ if image_data:
+ multi_modal_data["image"] = image_data
+ if video_data:
+ multi_modal_data["video"] = video_data
+ outputs = self.llm.generate(
+ {
+ "prompt": prompt,
+ "multi_modal_data": multi_modal_data if multi_modal_data else None,
+ },
+ sampling_params=sampling_params,
+ request_id=request_id,
+ )
+
+ # Stream the results
+ generated_text = ""
+ try:
+ async for request_output in outputs:
+ iteration_text = request_output.outputs[0].text
+
+ if streaming:
+ # Remove text already sent as vllm concatenates the text from previous yields
+ delta_iteration_text = iteration_text.removeprefix(generated_text)
+ # Send the partial result
+ yield backend_pb2.Reply(message=bytes(delta_iteration_text, encoding='utf-8'))
+
+ # Keep track of text generated
+ generated_text = iteration_text
+ finally:
+ await outputs.aclose()
+
+ # If streaming, we already sent everything
+ if streaming:
+ return
+
+ # Remove the image files from /tmp folder
+ for img_path in image_paths:
+ try:
+ os.remove(img_path)
+ except Exception as e:
+ print(f"Error removing image file: {img_path}, {e}", file=sys.stderr)
+
+ # Sending the final generated text
+ yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
+
+ def load_image(self, image_path: str):
+ """
+ Load an image from the given file path or base64 encoded data.
+
+ Args:
+ image_path (str): The path to the image file or base64 encoded data.
+
+ Returns:
+ Image: The loaded image.
+ """
+ try:
+
+ image_data = base64.b64decode(image_path)
+ image = Image.open(io.BytesIO(image_data))
+ return image
+ except Exception as e:
+ print(f"Error loading image {image_path}: {e}", file=sys.stderr)
+ return None
+
+ def load_video(self, video_path: str):
+ """
+ Load a video from the given file path.
+
+ Args:
+ video_path (str): The path to the image file.
+
+ Returns:
+ Video: The loaded video.
+ """
+ try:
+ timestamp = str(int(time.time() * 1000)) # Generate timestamp
+ p = f"/tmp/vl-{timestamp}.data" # Use timestamp in filename
+ with open(p, "wb") as f:
+ f.write(base64.b64decode(video_path))
+ video = VideoAsset(name=p).np_ndarrays
+ os.remove(p)
+ return video
+ except Exception as e:
+ print(f"Error loading video {video_path}: {e}", file=sys.stderr)
+ return None
+
+async def serve(address):
+ # Start asyncio gRPC server
+ server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
+ ])
+ # Add the servicer to the server
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ # Bind the server to the address
+ server.add_insecure_port(address)
+
+ # Gracefully shutdown the server on SIGTERM or SIGINT
+ loop = asyncio.get_event_loop()
+ for sig in (signal.SIGINT, signal.SIGTERM):
+ loop.add_signal_handler(
+ sig, lambda: asyncio.ensure_future(server.stop(5))
+ )
+
+ # Start the server
+ await server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+ # Wait for the server to be terminated
+ await server.wait_for_termination()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the gRPC server.")
+ parser.add_argument(
+ "--addr", default="localhost:50051", help="The address to bind the server to."
+ )
+ args = parser.parse_args()
+
+ asyncio.run(serve(args.addr))
diff --git a/backend/python/vllm/install.sh b/backend/python/vllm/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7dcd29db4a92d036ef748c4a2cee7bd1c5f4f3c0
--- /dev/null
+++ b/backend/python/vllm/install.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+set -e
+
+EXTRA_PIP_INSTALL_FLAGS="--no-build-isolation"
+
+# Avoid to overcommit the CPU during build
+# https://github.com/vllm-project/vllm/issues/20079
+# https://docs.vllm.ai/en/v0.8.3/serving/env_vars.html
+# https://docs.redhat.com/it/documentation/red_hat_ai_inference_server/3.0/html/vllm_server_arguments/environment_variables-server-arguments
+export NVCC_THREADS=2
+export MAX_JOBS=1
+
+backend_dir=$(dirname $0)
+
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
+# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
+if [ "x${BUILD_PROFILE}" == "xintel" ]; then
+ EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+fi
+
+# We don't embed this into the images as it is a large dependency and not always needed.
+# Besides, the speed inference are not actually usable in the current state for production use-cases.
+if [ "x${BUILD_TYPE}" == "x" ] && [ "x${FROM_SOURCE:-}" == "xtrue" ]; then
+ ensureVenv
+ # https://docs.vllm.ai/en/v0.6.1/getting_started/cpu-installation.html
+ if [ ! -d vllm ]; then
+ git clone https://github.com/vllm-project/vllm
+ fi
+ pushd vllm
+ uv pip install wheel packaging ninja "setuptools>=49.4.0" numpy typing-extensions pillow setuptools-scm grpcio==1.68.1 protobuf bitsandbytes
+ uv pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
+ VLLM_TARGET_DEVICE=cpu python setup.py install
+ popd
+ rm -rf vllm
+ else
+ installRequirements
+fi
diff --git a/backend/python/vllm/requirements-after.txt b/backend/python/vllm/requirements-after.txt
new file mode 100644
index 0000000000000000000000000000000000000000..76f11f154037e179df28e0240d2c862c183d1995
--- /dev/null
+++ b/backend/python/vllm/requirements-after.txt
@@ -0,0 +1 @@
+vllm
\ No newline at end of file
diff --git a/backend/python/vllm/requirements-cpu.txt b/backend/python/vllm/requirements-cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..16c7cbac50c010c97af761b71b79c18f2c92d343
--- /dev/null
+++ b/backend/python/vllm/requirements-cpu.txt
@@ -0,0 +1,3 @@
+accelerate
+torch==2.7.0
+transformers
\ No newline at end of file
diff --git a/backend/python/vllm/requirements-cublas12-after.txt b/backend/python/vllm/requirements-cublas12-after.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9251ba608461ea0aac981ccdc10b58411ac1dd87
--- /dev/null
+++ b/backend/python/vllm/requirements-cublas12-after.txt
@@ -0,0 +1 @@
+https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
diff --git a/backend/python/vllm/requirements-cublas12.txt b/backend/python/vllm/requirements-cublas12.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8bd72ae125fd83f5c06f9876247307b35f8be53f
--- /dev/null
+++ b/backend/python/vllm/requirements-cublas12.txt
@@ -0,0 +1,4 @@
+accelerate
+torch==2.7.0
+transformers
+bitsandbytes
\ No newline at end of file
diff --git a/backend/python/vllm/requirements-hipblas.txt b/backend/python/vllm/requirements-hipblas.txt
new file mode 100644
index 0000000000000000000000000000000000000000..db732bc864ef015ca299e5f7b4fab9873ae00b95
--- /dev/null
+++ b/backend/python/vllm/requirements-hipblas.txt
@@ -0,0 +1,5 @@
+--extra-index-url https://download.pytorch.org/whl/nightly/rocm6.4
+accelerate
+torch
+transformers
+bitsandbytes
\ No newline at end of file
diff --git a/backend/python/vllm/requirements-install.txt b/backend/python/vllm/requirements-install.txt
new file mode 100644
index 0000000000000000000000000000000000000000..69d263f0b3edb33cc212021c98fb23ba253ab005
--- /dev/null
+++ b/backend/python/vllm/requirements-install.txt
@@ -0,0 +1,6 @@
+# mabma does not specify it's build dependencies per PEP517, so we need to disable build isolation
+# this also means that we need to install the basic build dependencies into the venv ourselves
+# https://github.com/Dao-AILab/causal-conv1d/issues/24
+packaging
+setuptools
+wheel
\ No newline at end of file
diff --git a/backend/python/vllm/requirements-intel.txt b/backend/python/vllm/requirements-intel.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a5a176f2f3b4324b0617dfb259274b142f4b981f
--- /dev/null
+++ b/backend/python/vllm/requirements-intel.txt
@@ -0,0 +1,10 @@
+--extra-index-url https://download.pytorch.org/whl/xpu
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+intel-extension-for-pytorch==2.7.10+xpu
+accelerate
+torch==2.7.0+xpu
+transformers
+optimum[openvino]
+setuptools
+bitsandbytes
+oneccl_bind_pt==2.7.0+xpu
\ No newline at end of file
diff --git a/backend/python/vllm/requirements.txt b/backend/python/vllm/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e278be72d44be8694fc81fad87a768ba337c7d13
--- /dev/null
+++ b/backend/python/vllm/requirements.txt
@@ -0,0 +1,4 @@
+grpcio==1.76.0
+protobuf
+certifi
+setuptools
\ No newline at end of file
diff --git a/backend/python/vllm/run.sh b/backend/python/vllm/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fc88f97da712f14faef73f9e8b96589dd8ecc2ad
--- /dev/null
+++ b/backend/python/vllm/run.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+backend_dir=$(dirname $0)
+
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
\ No newline at end of file
diff --git a/backend/python/vllm/test.py b/backend/python/vllm/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..827aa71a3e33132b75d77a2c192a4000699b7042
--- /dev/null
+++ b/backend/python/vllm/test.py
@@ -0,0 +1,146 @@
+import unittest
+import subprocess
+import time
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+import unittest
+import subprocess
+import time
+import grpc
+import backend_pb2_grpc
+import backend_pb2
+
+class TestBackendServicer(unittest.TestCase):
+ """
+ TestBackendServicer is the class that tests the gRPC service.
+
+ This class contains methods to test the startup and shutdown of the gRPC service.
+ """
+ def setUp(self):
+ self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"])
+ time.sleep(10)
+
+ def tearDown(self) -> None:
+ self.service.terminate()
+ self.service.wait()
+
+ def test_server_startup(self):
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b'OK')
+ except Exception as err:
+ print(err)
+ self.fail("Server failed to start")
+ finally:
+ self.tearDown()
+ def test_load_model(self):
+ """
+ This method tests if the model is loaded successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
+ self.assertTrue(response.success)
+ self.assertEqual(response.message, "Model loaded successfully")
+ except Exception as err:
+ print(err)
+ self.fail("LoadModel service failed")
+ finally:
+ self.tearDown()
+
+ def test_text(self):
+ """
+ This method tests if the embeddings are generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
+ self.assertTrue(response.success)
+ req = backend_pb2.PredictOptions(Prompt="The capital of France is")
+ resp = stub.Predict(req)
+ self.assertIsNotNone(resp.message)
+ except Exception as err:
+ print(err)
+ self.fail("text service failed")
+ finally:
+ self.tearDown()
+
+ def test_sampling_params(self):
+ """
+ This method tests if all sampling parameters are correctly processed
+ NOTE: this does NOT test for correctness, just that we received a compatible response
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
+ self.assertTrue(response.success)
+
+ req = backend_pb2.PredictOptions(
+ Prompt="The capital of France is",
+ TopP=0.8,
+ Tokens=50,
+ Temperature=0.7,
+ TopK=40,
+ PresencePenalty=0.1,
+ FrequencyPenalty=0.2,
+ RepetitionPenalty=1.1,
+ MinP=0.05,
+ Seed=42,
+ StopPrompts=["\n"],
+ StopTokenIds=[50256],
+ BadWords=["badword"],
+ IncludeStopStrInOutput=True,
+ IgnoreEOS=True,
+ MinTokens=5,
+ Logprobs=5,
+ PromptLogprobs=5,
+ SkipSpecialTokens=True,
+ SpacesBetweenSpecialTokens=True,
+ TruncatePromptTokens=10,
+ GuidedDecoding=True,
+ N=2,
+ )
+ resp = stub.Predict(req)
+ self.assertIsNotNone(resp.message)
+ self.assertIsNotNone(resp.logprobs)
+ except Exception as err:
+ print(err)
+ self.fail("sampling params service failed")
+ finally:
+ self.tearDown()
+
+
+ def test_embedding(self):
+ """
+ This method tests if the embeddings are generated successfully
+ """
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(backend_pb2.ModelOptions(Model="intfloat/e5-mistral-7b-instruct"))
+ self.assertTrue(response.success)
+ embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
+ embedding_response = stub.Embedding(embedding_request)
+ self.assertIsNotNone(embedding_response.embeddings)
+ # assert that is a list of floats
+ self.assertIsInstance(embedding_response.embeddings, list)
+ # assert that the list is not empty
+ self.assertTrue(len(embedding_response.embeddings) > 0)
+ except Exception as err:
+ print(err)
+ self.fail("Embedding service failed")
+ finally:
+ self.tearDown()
\ No newline at end of file
diff --git a/backend/python/vllm/test.sh b/backend/python/vllm/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f31ae54e47dc7f5a10f630fa1d7b5c8ea56f0c9e
--- /dev/null
+++ b/backend/python/vllm/test.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/cmd/launcher/icon.go b/cmd/launcher/icon.go
new file mode 100644
index 0000000000000000000000000000000000000000..514f7ac5a6fc63cd3527f82cc4d940a5471eb951
--- /dev/null
+++ b/cmd/launcher/icon.go
@@ -0,0 +1,16 @@
+package main
+
+import (
+ _ "embed"
+
+ "fyne.io/fyne/v2"
+)
+
+//go:embed logo.png
+var logoData []byte
+
+// resourceIconPng is the LocalAI logo icon
+var resourceIconPng = &fyne.StaticResource{
+ StaticName: "logo.png",
+ StaticContent: logoData,
+}
diff --git a/cmd/launcher/internal/launcher.go b/cmd/launcher/internal/launcher.go
new file mode 100644
index 0000000000000000000000000000000000000000..0b5592fcc74cf966f94edcb942894b3818406c31
--- /dev/null
+++ b/cmd/launcher/internal/launcher.go
@@ -0,0 +1,866 @@
+package launcher
+
+import (
+ "bufio"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "net/url"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+
+ "fyne.io/fyne/v2"
+ "fyne.io/fyne/v2/container"
+ "fyne.io/fyne/v2/dialog"
+ "fyne.io/fyne/v2/widget"
+)
+
+// Config represents the launcher configuration
+type Config struct {
+ ModelsPath string `json:"models_path"`
+ BackendsPath string `json:"backends_path"`
+ Address string `json:"address"`
+ AutoStart bool `json:"auto_start"`
+ StartOnBoot bool `json:"start_on_boot"`
+ LogLevel string `json:"log_level"`
+ EnvironmentVars map[string]string `json:"environment_vars"`
+ ShowWelcome *bool `json:"show_welcome"`
+}
+
+// Launcher represents the main launcher application
+type Launcher struct {
+ // Core components
+ releaseManager *ReleaseManager
+ config *Config
+ ui *LauncherUI
+ systray *SystrayManager
+ ctx context.Context
+ window fyne.Window
+ app fyne.App
+
+ // Process management
+ localaiCmd *exec.Cmd
+ isRunning bool
+ logBuffer *strings.Builder
+ logMutex sync.RWMutex
+ statusChannel chan string
+
+ // Logging
+ logFile *os.File
+ logPath string
+
+ // UI state
+ lastUpdateCheck time.Time
+}
+
+// NewLauncher creates a new launcher instance
+func NewLauncher(ui *LauncherUI, window fyne.Window, app fyne.App) *Launcher {
+ return &Launcher{
+ releaseManager: NewReleaseManager(),
+ config: &Config{},
+ logBuffer: &strings.Builder{},
+ statusChannel: make(chan string, 100),
+ ctx: context.Background(),
+ ui: ui,
+ window: window,
+ app: app,
+ }
+}
+
+// setupLogging sets up log file for LocalAI process output
+func (l *Launcher) setupLogging() error {
+ // Create logs directory in data folder
+ dataPath := l.GetDataPath()
+ logsDir := filepath.Join(dataPath, "logs")
+ if err := os.MkdirAll(logsDir, 0755); err != nil {
+ return fmt.Errorf("failed to create logs directory: %w", err)
+ }
+
+ // Create log file with timestamp
+ timestamp := time.Now().Format("2006-01-02_15-04-05")
+ l.logPath = filepath.Join(logsDir, fmt.Sprintf("localai_%s.log", timestamp))
+
+ logFile, err := os.Create(l.logPath)
+ if err != nil {
+ return fmt.Errorf("failed to create log file: %w", err)
+ }
+
+ l.logFile = logFile
+ return nil
+}
+
+// Initialize sets up the launcher
+func (l *Launcher) Initialize() error {
+ if l.app == nil {
+ return fmt.Errorf("app is nil")
+ }
+ log.Printf("Initializing launcher...")
+
+ // Setup logging
+ if err := l.setupLogging(); err != nil {
+ return fmt.Errorf("failed to setup logging: %w", err)
+ }
+
+ // Load configuration
+ log.Printf("Loading configuration...")
+ if err := l.loadConfig(); err != nil {
+ return fmt.Errorf("failed to load config: %w", err)
+ }
+ log.Printf("Configuration loaded, current state: ModelsPath=%s, BackendsPath=%s, Address=%s, LogLevel=%s",
+ l.config.ModelsPath, l.config.BackendsPath, l.config.Address, l.config.LogLevel)
+
+ // Clean up any partial downloads
+ log.Printf("Cleaning up partial downloads...")
+ if err := l.releaseManager.CleanupPartialDownloads(); err != nil {
+ log.Printf("Warning: failed to cleanup partial downloads: %v", err)
+ }
+
+ if l.config.StartOnBoot {
+ l.StartLocalAI()
+ }
+ // Set default paths if not configured (only if not already loaded from config)
+ if l.config.ModelsPath == "" {
+ homeDir, _ := os.UserHomeDir()
+ l.config.ModelsPath = filepath.Join(homeDir, ".localai", "models")
+ log.Printf("Setting default ModelsPath: %s", l.config.ModelsPath)
+ }
+ if l.config.BackendsPath == "" {
+ homeDir, _ := os.UserHomeDir()
+ l.config.BackendsPath = filepath.Join(homeDir, ".localai", "backends")
+ log.Printf("Setting default BackendsPath: %s", l.config.BackendsPath)
+ }
+ if l.config.Address == "" {
+ l.config.Address = "127.0.0.1:8080"
+ log.Printf("Setting default Address: %s", l.config.Address)
+ }
+ if l.config.LogLevel == "" {
+ l.config.LogLevel = "info"
+ log.Printf("Setting default LogLevel: %s", l.config.LogLevel)
+ }
+ if l.config.EnvironmentVars == nil {
+ l.config.EnvironmentVars = make(map[string]string)
+ log.Printf("Initializing empty EnvironmentVars map")
+ }
+
+ // Set default welcome window preference
+ if l.config.ShowWelcome == nil {
+ true := true
+ l.config.ShowWelcome = &true
+ log.Printf("Setting default ShowWelcome: true")
+ }
+
+ // Create directories
+ os.MkdirAll(l.config.ModelsPath, 0755)
+ os.MkdirAll(l.config.BackendsPath, 0755)
+
+ // Save the configuration with default values
+ if err := l.saveConfig(); err != nil {
+ log.Printf("Warning: failed to save default configuration: %v", err)
+ }
+
+ // System tray is now handled in main.go using Fyne's built-in approach
+
+ // Check if LocalAI is installed
+ if !l.releaseManager.IsLocalAIInstalled() {
+ log.Printf("No LocalAI installation found")
+ fyne.Do(func() {
+ l.updateStatus("No LocalAI installation found")
+ if l.ui != nil {
+ // Show dialog offering to download LocalAI
+ l.showDownloadLocalAIDialog()
+ }
+ })
+ }
+
+ // Check for updates periodically
+ go l.periodicUpdateCheck()
+
+ return nil
+}
+
+// StartLocalAI starts the LocalAI server
+func (l *Launcher) StartLocalAI() error {
+ if l.isRunning {
+ return fmt.Errorf("LocalAI is already running")
+ }
+
+ // Verify binary integrity before starting
+ if err := l.releaseManager.VerifyInstalledBinary(); err != nil {
+ // Binary is corrupted, remove it and offer to reinstall
+ binaryPath := l.releaseManager.GetBinaryPath()
+ if removeErr := os.Remove(binaryPath); removeErr != nil {
+ log.Printf("Failed to remove corrupted binary: %v", removeErr)
+ }
+ return fmt.Errorf("LocalAI binary is corrupted: %v. Please reinstall LocalAI", err)
+ }
+
+ binaryPath := l.releaseManager.GetBinaryPath()
+ if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
+ return fmt.Errorf("LocalAI binary not found. Please download a release first")
+ }
+
+ // Build command arguments
+ args := []string{
+ "run",
+ "--models-path", l.config.ModelsPath,
+ "--backends-path", l.config.BackendsPath,
+ "--address", l.config.Address,
+ "--log-level", l.config.LogLevel,
+ }
+
+ l.localaiCmd = exec.CommandContext(l.ctx, binaryPath, args...)
+
+ // Apply environment variables
+ if len(l.config.EnvironmentVars) > 0 {
+ env := os.Environ()
+ for key, value := range l.config.EnvironmentVars {
+ env = append(env, fmt.Sprintf("%s=%s", key, value))
+ }
+ l.localaiCmd.Env = env
+ }
+
+ // Setup logging
+ stdout, err := l.localaiCmd.StdoutPipe()
+ if err != nil {
+ return fmt.Errorf("failed to create stdout pipe: %w", err)
+ }
+
+ stderr, err := l.localaiCmd.StderrPipe()
+ if err != nil {
+ return fmt.Errorf("failed to create stderr pipe: %w", err)
+ }
+
+ // Start the process
+ if err := l.localaiCmd.Start(); err != nil {
+ return fmt.Errorf("failed to start LocalAI: %w", err)
+ }
+
+ l.isRunning = true
+
+ fyne.Do(func() {
+ l.updateStatus("LocalAI is starting...")
+ l.updateRunningState(true)
+ })
+
+ // Start log monitoring
+ go l.monitorLogs(stdout, "STDOUT")
+ go l.monitorLogs(stderr, "STDERR")
+
+ // Monitor process with startup timeout
+ go func() {
+ // Wait for process to start or fail
+ err := l.localaiCmd.Wait()
+ l.isRunning = false
+ fyne.Do(func() {
+ l.updateRunningState(false)
+ if err != nil {
+ l.updateStatus(fmt.Sprintf("LocalAI stopped with error: %v", err))
+ } else {
+ l.updateStatus("LocalAI stopped")
+ }
+ })
+ }()
+
+ // Add startup timeout detection
+ go func() {
+ time.Sleep(10 * time.Second) // Wait 10 seconds for startup
+ if l.isRunning {
+ // Check if process is still alive
+ if l.localaiCmd.Process != nil {
+ if err := l.localaiCmd.Process.Signal(syscall.Signal(0)); err != nil {
+ // Process is dead, mark as not running
+ l.isRunning = false
+ fyne.Do(func() {
+ l.updateRunningState(false)
+ l.updateStatus("LocalAI failed to start properly")
+ })
+ }
+ }
+ }
+ }()
+
+ return nil
+}
+
+// StopLocalAI stops the LocalAI server
+func (l *Launcher) StopLocalAI() error {
+ if !l.isRunning || l.localaiCmd == nil {
+ return fmt.Errorf("LocalAI is not running")
+ }
+
+ // Gracefully terminate the process
+ if err := l.localaiCmd.Process.Signal(os.Interrupt); err != nil {
+ // If graceful termination fails, force kill
+ if killErr := l.localaiCmd.Process.Kill(); killErr != nil {
+ return fmt.Errorf("failed to kill LocalAI process: %w", killErr)
+ }
+ }
+
+ l.isRunning = false
+ fyne.Do(func() {
+ l.updateRunningState(false)
+ l.updateStatus("LocalAI stopped")
+ })
+ return nil
+}
+
+// IsRunning returns whether LocalAI is currently running
+func (l *Launcher) IsRunning() bool {
+ return l.isRunning
+}
+
+// Shutdown performs cleanup when the application is closing
+func (l *Launcher) Shutdown() error {
+ log.Printf("Launcher shutting down, stopping LocalAI...")
+
+ // Stop LocalAI if it's running
+ if l.isRunning {
+ if err := l.StopLocalAI(); err != nil {
+ log.Printf("Error stopping LocalAI during shutdown: %v", err)
+ }
+ }
+
+ // Close log file if open
+ if l.logFile != nil {
+ if err := l.logFile.Close(); err != nil {
+ log.Printf("Error closing log file: %v", err)
+ }
+ l.logFile = nil
+ }
+
+ log.Printf("Launcher shutdown complete")
+ return nil
+}
+
+// GetLogs returns the current log buffer
+func (l *Launcher) GetLogs() string {
+ l.logMutex.RLock()
+ defer l.logMutex.RUnlock()
+ return l.logBuffer.String()
+}
+
+// GetRecentLogs returns the most recent logs (last 50 lines) for better error display
+func (l *Launcher) GetRecentLogs() string {
+ l.logMutex.RLock()
+ defer l.logMutex.RUnlock()
+
+ content := l.logBuffer.String()
+ lines := strings.Split(content, "\n")
+
+ // Get last 50 lines
+ if len(lines) > 50 {
+ lines = lines[len(lines)-50:]
+ }
+
+ return strings.Join(lines, "\n")
+}
+
+// GetConfig returns the current configuration
+func (l *Launcher) GetConfig() *Config {
+ return l.config
+}
+
+// SetConfig updates the configuration
+func (l *Launcher) SetConfig(config *Config) error {
+ l.config = config
+ return l.saveConfig()
+}
+
+func (l *Launcher) GetUI() *LauncherUI {
+ return l.ui
+}
+
+func (l *Launcher) SetSystray(systray *SystrayManager) {
+ l.systray = systray
+}
+
+// GetReleaseManager returns the release manager
+func (l *Launcher) GetReleaseManager() *ReleaseManager {
+ return l.releaseManager
+}
+
+// GetWebUIURL returns the URL for the WebUI
+func (l *Launcher) GetWebUIURL() string {
+ address := l.config.Address
+ if strings.HasPrefix(address, ":") {
+ address = "localhost" + address
+ }
+ if !strings.HasPrefix(address, "http") {
+ address = "http://" + address
+ }
+ return address
+}
+
+// GetDataPath returns the path where LocalAI data and logs are stored
+func (l *Launcher) GetDataPath() string {
+ // LocalAI typically stores data in the current working directory or a models directory
+ // First check if models path is configured
+ if l.config != nil && l.config.ModelsPath != "" {
+ // Return the parent directory of models path
+ return filepath.Dir(l.config.ModelsPath)
+ }
+
+ // Fallback to home directory LocalAI folder
+ homeDir, err := os.UserHomeDir()
+ if err != nil {
+ return "."
+ }
+ return filepath.Join(homeDir, ".localai")
+}
+
+// CheckForUpdates checks if there are any available updates
+func (l *Launcher) CheckForUpdates() (bool, string, error) {
+ log.Printf("CheckForUpdates: checking for available updates...")
+ available, version, err := l.releaseManager.IsUpdateAvailable()
+ if err != nil {
+ log.Printf("CheckForUpdates: error occurred: %v", err)
+ return false, "", err
+ }
+ log.Printf("CheckForUpdates: result - available=%v, version=%s", available, version)
+ l.lastUpdateCheck = time.Now()
+ return available, version, nil
+}
+
+// DownloadUpdate downloads the latest version
+func (l *Launcher) DownloadUpdate(version string, progressCallback func(float64)) error {
+ return l.releaseManager.DownloadRelease(version, progressCallback)
+}
+
+// GetCurrentVersion returns the current installed version
+func (l *Launcher) GetCurrentVersion() string {
+ return l.releaseManager.GetInstalledVersion()
+}
+
+// GetCurrentStatus returns the current status
+func (l *Launcher) GetCurrentStatus() string {
+ select {
+ case status := <-l.statusChannel:
+ return status
+ default:
+ if l.isRunning {
+ return "LocalAI is running"
+ }
+ return "Ready"
+ }
+}
+
+// GetLastStatus returns the last known status without consuming from channel
+func (l *Launcher) GetLastStatus() string {
+ if l.isRunning {
+ return "LocalAI is running"
+ }
+
+ // Check if LocalAI is installed
+ if !l.releaseManager.IsLocalAIInstalled() {
+ return "LocalAI not installed"
+ }
+
+ return "Ready"
+}
+
+func (l *Launcher) githubReleaseNotesURL(version string) (*url.URL, error) {
+ // Construct GitHub release URL
+ releaseURL := fmt.Sprintf("https://github.com/%s/%s/releases/tag/%s",
+ l.releaseManager.GitHubOwner,
+ l.releaseManager.GitHubRepo,
+ version)
+
+ // Convert string to *url.URL
+ return url.Parse(releaseURL)
+}
+
+// showDownloadLocalAIDialog shows a dialog offering to download LocalAI
+func (l *Launcher) showDownloadLocalAIDialog() {
+ if l.app == nil {
+ log.Printf("Cannot show download dialog: app is nil")
+ return
+ }
+
+ fyne.DoAndWait(func() {
+ // Create a standalone window for the download dialog
+ dialogWindow := l.app.NewWindow("LocalAI Installation Required")
+ dialogWindow.Resize(fyne.NewSize(500, 350))
+ dialogWindow.CenterOnScreen()
+ dialogWindow.SetCloseIntercept(func() {
+ dialogWindow.Close()
+ })
+
+ // Create the dialog content
+ titleLabel := widget.NewLabel("LocalAI Not Found")
+ titleLabel.TextStyle = fyne.TextStyle{Bold: true}
+ titleLabel.Alignment = fyne.TextAlignCenter
+
+ messageLabel := widget.NewLabel("LocalAI is not installed on your system.\n\nWould you like to download and install the latest version?")
+ messageLabel.Wrapping = fyne.TextWrapWord
+ messageLabel.Alignment = fyne.TextAlignCenter
+
+ // Buttons
+ downloadButton := widget.NewButton("Download & Install", func() {
+ dialogWindow.Close()
+ l.downloadAndInstallLocalAI()
+ if l.systray != nil {
+ l.systray.recreateMenu()
+ }
+ })
+ downloadButton.Importance = widget.HighImportance
+
+ // Release notes button
+ releaseNotesButton := widget.NewButton("View Release Notes", func() {
+ // Get latest release info and open release notes
+ go func() {
+ release, err := l.releaseManager.GetLatestRelease()
+ if err != nil {
+ log.Printf("Failed to get latest release info: %v", err)
+ return
+ }
+
+ releaseNotesURL, err := l.githubReleaseNotesURL(release.Version)
+ if err != nil {
+ log.Printf("Failed to parse URL: %v", err)
+ return
+ }
+
+ l.app.OpenURL(releaseNotesURL)
+ }()
+ })
+
+ skipButton := widget.NewButton("Skip for Now", func() {
+ dialogWindow.Close()
+ })
+
+ // Layout - put release notes button above the main action buttons
+ actionButtons := container.NewHBox(skipButton, downloadButton)
+ content := container.NewVBox(
+ titleLabel,
+ widget.NewSeparator(),
+ messageLabel,
+ widget.NewSeparator(),
+ releaseNotesButton,
+ widget.NewSeparator(),
+ actionButtons,
+ )
+
+ dialogWindow.SetContent(content)
+ dialogWindow.Show()
+ })
+}
+
+// downloadAndInstallLocalAI downloads and installs the latest LocalAI version
+func (l *Launcher) downloadAndInstallLocalAI() {
+ if l.app == nil {
+ log.Printf("Cannot download LocalAI: app is nil")
+ return
+ }
+
+ // First check what the latest version is
+ go func() {
+ log.Printf("Checking for latest LocalAI version...")
+ available, version, err := l.CheckForUpdates()
+ if err != nil {
+ log.Printf("Failed to check for updates: %v", err)
+ l.showDownloadError("Failed to check for latest version", err.Error())
+ return
+ }
+
+ if !available {
+ log.Printf("No updates available, but LocalAI is not installed")
+ l.showDownloadError("No Version Available", "Could not determine the latest LocalAI version. Please check your internet connection and try again.")
+ return
+ }
+
+ log.Printf("Latest version available: %s", version)
+ // Show progress window with the specific version
+ l.showDownloadProgress(version, fmt.Sprintf("Downloading LocalAI %s...", version))
+ }()
+}
+
+// showDownloadError shows an error dialog for download failures
+func (l *Launcher) showDownloadError(title, message string) {
+ fyne.DoAndWait(func() {
+ // Create error window
+ errorWindow := l.app.NewWindow("Download Error")
+ errorWindow.Resize(fyne.NewSize(400, 200))
+ errorWindow.CenterOnScreen()
+ errorWindow.SetCloseIntercept(func() {
+ errorWindow.Close()
+ })
+
+ // Error content
+ titleLabel := widget.NewLabel(title)
+ titleLabel.TextStyle = fyne.TextStyle{Bold: true}
+ titleLabel.Alignment = fyne.TextAlignCenter
+
+ messageLabel := widget.NewLabel(message)
+ messageLabel.Wrapping = fyne.TextWrapWord
+ messageLabel.Alignment = fyne.TextAlignCenter
+
+ // Close button
+ closeButton := widget.NewButton("Close", func() {
+ errorWindow.Close()
+ })
+
+ // Layout
+ content := container.NewVBox(
+ titleLabel,
+ widget.NewSeparator(),
+ messageLabel,
+ widget.NewSeparator(),
+ closeButton,
+ )
+
+ errorWindow.SetContent(content)
+ errorWindow.Show()
+ })
+}
+
+// showDownloadProgress shows a standalone progress window for downloading LocalAI
+func (l *Launcher) showDownloadProgress(version, title string) {
+ fyne.DoAndWait(func() {
+ // Create progress window
+ progressWindow := l.app.NewWindow("Downloading LocalAI")
+ progressWindow.Resize(fyne.NewSize(400, 250))
+ progressWindow.CenterOnScreen()
+ progressWindow.SetCloseIntercept(func() {
+ progressWindow.Close()
+ })
+
+ // Progress bar
+ progressBar := widget.NewProgressBar()
+ progressBar.SetValue(0)
+
+ // Status label
+ statusLabel := widget.NewLabel("Preparing download...")
+
+ // Release notes button
+ releaseNotesButton := widget.NewButton("View Release Notes", func() {
+ releaseNotesURL, err := l.githubReleaseNotesURL(version)
+ if err != nil {
+ log.Printf("Failed to parse URL: %v", err)
+ return
+ }
+
+ l.app.OpenURL(releaseNotesURL)
+ })
+
+ // Progress container
+ progressContainer := container.NewVBox(
+ widget.NewLabel(title),
+ progressBar,
+ statusLabel,
+ widget.NewSeparator(),
+ releaseNotesButton,
+ )
+
+ progressWindow.SetContent(progressContainer)
+ progressWindow.Show()
+
+ // Start download in background
+ go func() {
+ err := l.DownloadUpdate(version, func(progress float64) {
+ // Update progress bar
+ fyne.Do(func() {
+ progressBar.SetValue(progress)
+ percentage := int(progress * 100)
+ statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
+ })
+ })
+
+ // Handle completion
+ fyne.Do(func() {
+ if err != nil {
+ statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
+ // Show error dialog
+ dialog.ShowError(err, progressWindow)
+ } else {
+ statusLabel.SetText("Download completed successfully!")
+ progressBar.SetValue(1.0)
+
+ // Show success dialog
+ dialog.ShowConfirm("Installation Complete",
+ "LocalAI has been downloaded and installed successfully. You can now start LocalAI from the launcher.",
+ func(close bool) {
+ progressWindow.Close()
+ // Update status and refresh systray menu
+ l.updateStatus("LocalAI installed successfully")
+
+ if l.systray != nil {
+ l.systray.recreateMenu()
+ }
+ }, progressWindow)
+ }
+ })
+ }()
+ })
+}
+
+// monitorLogs monitors the output of LocalAI and adds it to the log buffer
+func (l *Launcher) monitorLogs(reader io.Reader, prefix string) {
+ scanner := bufio.NewScanner(reader)
+ for scanner.Scan() {
+ line := scanner.Text()
+ timestamp := time.Now().Format("15:04:05")
+ logLine := fmt.Sprintf("[%s] %s: %s\n", timestamp, prefix, line)
+
+ l.logMutex.Lock()
+ l.logBuffer.WriteString(logLine)
+ // Keep log buffer size reasonable
+ if l.logBuffer.Len() > 100000 { // 100KB
+ content := l.logBuffer.String()
+ // Keep last 50KB
+ if len(content) > 50000 {
+ l.logBuffer.Reset()
+ l.logBuffer.WriteString(content[len(content)-50000:])
+ }
+ }
+ l.logMutex.Unlock()
+
+ // Write to log file if available
+ if l.logFile != nil {
+ if _, err := l.logFile.WriteString(logLine); err != nil {
+ log.Printf("Failed to write to log file: %v", err)
+ }
+ }
+
+ fyne.Do(func() {
+ // Notify UI of new log content
+ if l.ui != nil {
+ l.ui.OnLogUpdate(logLine)
+ }
+
+ // Check for startup completion
+ if strings.Contains(line, "API server listening") {
+ l.updateStatus("LocalAI is running")
+ }
+ })
+ }
+}
+
+// updateStatus updates the status and notifies UI
+func (l *Launcher) updateStatus(status string) {
+ select {
+ case l.statusChannel <- status:
+ default:
+ // Channel full, skip
+ }
+
+ if l.ui != nil {
+ l.ui.UpdateStatus(status)
+ }
+
+ if l.systray != nil {
+ l.systray.UpdateStatus(status)
+ }
+}
+
+// updateRunningState updates the running state in UI and systray
+func (l *Launcher) updateRunningState(isRunning bool) {
+ if l.ui != nil {
+ l.ui.UpdateRunningState(isRunning)
+ }
+
+ if l.systray != nil {
+ l.systray.UpdateRunningState(isRunning)
+ }
+}
+
+// periodicUpdateCheck checks for updates periodically
+func (l *Launcher) periodicUpdateCheck() {
+ ticker := time.NewTicker(1 * time.Hour)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ available, version, err := l.CheckForUpdates()
+ if err == nil && available {
+ fyne.Do(func() {
+ l.updateStatus(fmt.Sprintf("Update available: %s", version))
+ if l.systray != nil {
+ l.systray.NotifyUpdateAvailable(version)
+ }
+ if l.ui != nil {
+ l.ui.NotifyUpdateAvailable(version)
+ }
+ })
+ }
+ case <-l.ctx.Done():
+ return
+ }
+ }
+}
+
+// loadConfig loads configuration from file
+func (l *Launcher) loadConfig() error {
+ homeDir, err := os.UserHomeDir()
+ if err != nil {
+ return fmt.Errorf("failed to get home directory: %w", err)
+ }
+
+ configPath := filepath.Join(homeDir, ".localai", "launcher.json")
+ log.Printf("Loading config from: %s", configPath)
+
+ if _, err := os.Stat(configPath); os.IsNotExist(err) {
+ log.Printf("Config file not found, creating default config")
+ // Create default config
+ return l.saveConfig()
+ }
+
+ // Load existing config
+ configData, err := os.ReadFile(configPath)
+ if err != nil {
+ return fmt.Errorf("failed to read config file: %w", err)
+ }
+
+ log.Printf("Config file content: %s", string(configData))
+
+ log.Printf("loadConfig: about to unmarshal JSON data")
+ if err := json.Unmarshal(configData, l.config); err != nil {
+ return fmt.Errorf("failed to parse config file: %w", err)
+ }
+ log.Printf("loadConfig: JSON unmarshaled successfully")
+
+ log.Printf("Loaded config: ModelsPath=%s, BackendsPath=%s, Address=%s, LogLevel=%s",
+ l.config.ModelsPath, l.config.BackendsPath, l.config.Address, l.config.LogLevel)
+ log.Printf("Environment vars: %v", l.config.EnvironmentVars)
+
+ return nil
+}
+
+// saveConfig saves configuration to file
+func (l *Launcher) saveConfig() error {
+ homeDir, err := os.UserHomeDir()
+ if err != nil {
+ return fmt.Errorf("failed to get home directory: %w", err)
+ }
+
+ configDir := filepath.Join(homeDir, ".localai")
+ if err := os.MkdirAll(configDir, 0755); err != nil {
+ return fmt.Errorf("failed to create config directory: %w", err)
+ }
+
+ // Marshal config to JSON
+ log.Printf("saveConfig: marshaling config with EnvironmentVars: %v", l.config.EnvironmentVars)
+ configData, err := json.MarshalIndent(l.config, "", " ")
+ if err != nil {
+ return fmt.Errorf("failed to marshal config: %w", err)
+ }
+ log.Printf("saveConfig: JSON marshaled successfully, length: %d", len(configData))
+
+ configPath := filepath.Join(configDir, "launcher.json")
+ log.Printf("Saving config to: %s", configPath)
+ log.Printf("Config content: %s", string(configData))
+
+ if err := os.WriteFile(configPath, configData, 0644); err != nil {
+ return fmt.Errorf("failed to write config file: %w", err)
+ }
+
+ log.Printf("Config saved successfully")
+ return nil
+}
diff --git a/cmd/launcher/internal/launcher_suite_test.go b/cmd/launcher/internal/launcher_suite_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..3648197b3afd507e2f886775b9978b21920ce575
--- /dev/null
+++ b/cmd/launcher/internal/launcher_suite_test.go
@@ -0,0 +1,13 @@
+package launcher_test
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestLauncher(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "Launcher Suite")
+}
diff --git a/cmd/launcher/internal/launcher_test.go b/cmd/launcher/internal/launcher_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..15a2a24eeed4892dc9147c9be9f2a2956a2d1c5c
--- /dev/null
+++ b/cmd/launcher/internal/launcher_test.go
@@ -0,0 +1,213 @@
+package launcher_test
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+
+ "fyne.io/fyne/v2/app"
+
+ launcher "github.com/mudler/LocalAI/cmd/launcher/internal"
+)
+
+var _ = Describe("Launcher", func() {
+ var (
+ launcherInstance *launcher.Launcher
+ tempDir string
+ )
+
+ BeforeEach(func() {
+ var err error
+ tempDir, err = os.MkdirTemp("", "launcher-test-*")
+ Expect(err).ToNot(HaveOccurred())
+
+ ui := launcher.NewLauncherUI()
+ app := app.NewWithID("com.localai.launcher")
+
+ launcherInstance = launcher.NewLauncher(ui, nil, app)
+ })
+
+ AfterEach(func() {
+ os.RemoveAll(tempDir)
+ })
+
+ Describe("NewLauncher", func() {
+ It("should create a launcher with default configuration", func() {
+ Expect(launcherInstance.GetConfig()).ToNot(BeNil())
+ })
+ })
+
+ Describe("Initialize", func() {
+ It("should set default paths when not configured", func() {
+ err := launcherInstance.Initialize()
+ Expect(err).ToNot(HaveOccurred())
+
+ config := launcherInstance.GetConfig()
+ Expect(config.ModelsPath).ToNot(BeEmpty())
+ Expect(config.BackendsPath).ToNot(BeEmpty())
+ })
+
+ It("should set default ShowWelcome to true", func() {
+ err := launcherInstance.Initialize()
+ Expect(err).ToNot(HaveOccurred())
+
+ config := launcherInstance.GetConfig()
+ Expect(config.ShowWelcome).To(BeTrue())
+ Expect(config.Address).To(Equal("127.0.0.1:8080"))
+ Expect(config.LogLevel).To(Equal("info"))
+ })
+
+ It("should create models and backends directories", func() {
+ // Set custom paths for testing
+ config := launcherInstance.GetConfig()
+ config.ModelsPath = filepath.Join(tempDir, "models")
+ config.BackendsPath = filepath.Join(tempDir, "backends")
+ launcherInstance.SetConfig(config)
+
+ err := launcherInstance.Initialize()
+ Expect(err).ToNot(HaveOccurred())
+
+ // Check if directories were created
+ _, err = os.Stat(config.ModelsPath)
+ Expect(err).ToNot(HaveOccurred())
+
+ _, err = os.Stat(config.BackendsPath)
+ Expect(err).ToNot(HaveOccurred())
+ })
+ })
+
+ Describe("Configuration", func() {
+ It("should get and set configuration", func() {
+ config := launcherInstance.GetConfig()
+ config.ModelsPath = "/test/models"
+ config.BackendsPath = "/test/backends"
+ config.Address = ":9090"
+ config.LogLevel = "debug"
+
+ err := launcherInstance.SetConfig(config)
+ Expect(err).ToNot(HaveOccurred())
+
+ retrievedConfig := launcherInstance.GetConfig()
+ Expect(retrievedConfig.ModelsPath).To(Equal("/test/models"))
+ Expect(retrievedConfig.BackendsPath).To(Equal("/test/backends"))
+ Expect(retrievedConfig.Address).To(Equal(":9090"))
+ Expect(retrievedConfig.LogLevel).To(Equal("debug"))
+ })
+ })
+
+ Describe("WebUI URL", func() {
+ It("should return correct WebUI URL for localhost", func() {
+ config := launcherInstance.GetConfig()
+ config.Address = ":8080"
+ launcherInstance.SetConfig(config)
+
+ url := launcherInstance.GetWebUIURL()
+ Expect(url).To(Equal("http://localhost:8080"))
+ })
+
+ It("should return correct WebUI URL for full address", func() {
+ config := launcherInstance.GetConfig()
+ config.Address = "127.0.0.1:8080"
+ launcherInstance.SetConfig(config)
+
+ url := launcherInstance.GetWebUIURL()
+ Expect(url).To(Equal("http://127.0.0.1:8080"))
+ })
+
+ It("should handle http prefix correctly", func() {
+ config := launcherInstance.GetConfig()
+ config.Address = "http://localhost:8080"
+ launcherInstance.SetConfig(config)
+
+ url := launcherInstance.GetWebUIURL()
+ Expect(url).To(Equal("http://localhost:8080"))
+ })
+ })
+
+ Describe("Process Management", func() {
+ It("should not be running initially", func() {
+ Expect(launcherInstance.IsRunning()).To(BeFalse())
+ })
+
+ It("should handle start when binary doesn't exist", func() {
+ err := launcherInstance.StartLocalAI()
+ Expect(err).To(HaveOccurred())
+ // Could be either "not found" or "permission denied" depending on test environment
+ errMsg := err.Error()
+ hasExpectedError := strings.Contains(errMsg, "LocalAI binary") ||
+ strings.Contains(errMsg, "permission denied")
+ Expect(hasExpectedError).To(BeTrue(), "Expected error about binary not found or permission denied, got: %s", errMsg)
+ })
+
+ It("should handle stop when not running", func() {
+ err := launcherInstance.StopLocalAI()
+ Expect(err).To(HaveOccurred())
+ Expect(err.Error()).To(ContainSubstring("LocalAI is not running"))
+ })
+ })
+
+ Describe("Logs", func() {
+ It("should return empty logs initially", func() {
+ logs := launcherInstance.GetLogs()
+ Expect(logs).To(BeEmpty())
+ })
+ })
+
+ Describe("Version Management", func() {
+ It("should return empty version when no binary installed", func() {
+ version := launcherInstance.GetCurrentVersion()
+ Expect(version).To(BeEmpty()) // No binary installed in test environment
+ })
+
+ It("should handle update checks", func() {
+ // This test would require mocking HTTP responses
+ // For now, we'll just test that the method doesn't panic
+ _, _, err := launcherInstance.CheckForUpdates()
+ // We expect either success or a network error, not a panic
+ if err != nil {
+ // Network error is acceptable in tests
+ Expect(err.Error()).To(ContainSubstring("failed to fetch"))
+ }
+ })
+ })
+})
+
+var _ = Describe("Config", func() {
+ It("should have proper JSON tags", func() {
+ config := &launcher.Config{
+ ModelsPath: "/test/models",
+ BackendsPath: "/test/backends",
+ Address: ":8080",
+ AutoStart: true,
+ LogLevel: "info",
+ EnvironmentVars: map[string]string{"TEST": "value"},
+ }
+
+ Expect(config.ModelsPath).To(Equal("/test/models"))
+ Expect(config.BackendsPath).To(Equal("/test/backends"))
+ Expect(config.Address).To(Equal(":8080"))
+ Expect(config.AutoStart).To(BeTrue())
+ Expect(config.LogLevel).To(Equal("info"))
+ Expect(config.EnvironmentVars).To(HaveKeyWithValue("TEST", "value"))
+ })
+
+ It("should initialize environment variables map", func() {
+ config := &launcher.Config{}
+ Expect(config.EnvironmentVars).To(BeNil())
+
+ ui := launcher.NewLauncherUI()
+ app := app.NewWithID("com.localai.launcher")
+
+ launcher := launcher.NewLauncher(ui, nil, app)
+
+ err := launcher.Initialize()
+ Expect(err).ToNot(HaveOccurred())
+
+ retrievedConfig := launcher.GetConfig()
+ Expect(retrievedConfig.EnvironmentVars).ToNot(BeNil())
+ Expect(retrievedConfig.EnvironmentVars).To(BeEmpty())
+ })
+})
diff --git a/cmd/launcher/internal/release_manager.go b/cmd/launcher/internal/release_manager.go
new file mode 100644
index 0000000000000000000000000000000000000000..6c0055ee3caf6a165cdfd1d9accf33de17427a8e
--- /dev/null
+++ b/cmd/launcher/internal/release_manager.go
@@ -0,0 +1,559 @@
+package launcher
+
+import (
+ "bufio"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "time"
+
+ "github.com/mudler/LocalAI/internal"
+)
+
+// Release represents a LocalAI release
+type Release struct {
+ Version string `json:"tag_name"`
+ Name string `json:"name"`
+ Body string `json:"body"`
+ PublishedAt time.Time `json:"published_at"`
+ Assets []Asset `json:"assets"`
+}
+
+// Asset represents a release asset
+type Asset struct {
+ Name string `json:"name"`
+ BrowserDownloadURL string `json:"browser_download_url"`
+ Size int64 `json:"size"`
+}
+
+// ReleaseManager handles LocalAI release management
+type ReleaseManager struct {
+ // GitHubOwner is the GitHub repository owner
+ GitHubOwner string
+ // GitHubRepo is the GitHub repository name
+ GitHubRepo string
+ // BinaryPath is where the LocalAI binary is stored locally
+ BinaryPath string
+ // CurrentVersion is the currently installed version
+ CurrentVersion string
+ // ChecksumsPath is where checksums are stored
+ ChecksumsPath string
+ // MetadataPath is where version metadata is stored
+ MetadataPath string
+ // HTTPClient is the HTTP client used for downloads
+ HTTPClient *http.Client
+}
+
+// NewReleaseManager creates a new release manager
+func NewReleaseManager() *ReleaseManager {
+ homeDir, _ := os.UserHomeDir()
+ binaryPath := filepath.Join(homeDir, ".localai", "bin")
+ checksumsPath := filepath.Join(homeDir, ".localai", "checksums")
+ metadataPath := filepath.Join(homeDir, ".localai", "metadata")
+
+ return &ReleaseManager{
+ GitHubOwner: "mudler",
+ GitHubRepo: "LocalAI",
+ BinaryPath: binaryPath,
+ CurrentVersion: internal.PrintableVersion(),
+ ChecksumsPath: checksumsPath,
+ MetadataPath: metadataPath,
+ HTTPClient: &http.Client{
+ Timeout: 30 * time.Second,
+ },
+ }
+}
+
+// GetLatestRelease fetches the latest release information from GitHub
+func (rm *ReleaseManager) GetLatestRelease() (*Release, error) {
+ url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", rm.GitHubOwner, rm.GitHubRepo)
+
+ resp, err := rm.HTTPClient.Get(url)
+ if err != nil {
+ return nil, fmt.Errorf("failed to fetch latest release: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to fetch latest release: status %d", resp.StatusCode)
+ }
+
+ // Parse the JSON response properly
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
+
+ release := &Release{}
+ if err := json.Unmarshal(body, release); err != nil {
+ return nil, fmt.Errorf("failed to parse JSON response: %w", err)
+ }
+
+ // Validate the release data
+ if release.Version == "" {
+ return nil, fmt.Errorf("no version found in release data")
+ }
+
+ return release, nil
+}
+
+// DownloadRelease downloads a specific version of LocalAI
+func (rm *ReleaseManager) DownloadRelease(version string, progressCallback func(float64)) error {
+ // Ensure the binary directory exists
+ if err := os.MkdirAll(rm.BinaryPath, 0755); err != nil {
+ return fmt.Errorf("failed to create binary directory: %w", err)
+ }
+
+ // Determine the binary name based on OS and architecture
+ binaryName := rm.GetBinaryName(version)
+ localPath := filepath.Join(rm.BinaryPath, "local-ai")
+
+ // Download the binary
+ downloadURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s",
+ rm.GitHubOwner, rm.GitHubRepo, version, binaryName)
+
+ if err := rm.downloadFile(downloadURL, localPath, progressCallback); err != nil {
+ return fmt.Errorf("failed to download binary: %w", err)
+ }
+
+ // Download and verify checksums
+ checksumURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/LocalAI-%s-checksums.txt",
+ rm.GitHubOwner, rm.GitHubRepo, version, version)
+
+ checksumPath := filepath.Join(rm.BinaryPath, "checksums.txt")
+ manualChecksumPath := filepath.Join(rm.ChecksumsPath, fmt.Sprintf("checksums-%s.txt", version))
+
+ // First, check if there's already a checksum file (either manually placed or previously downloaded)
+ // and honor that, skipping download entirely in such case
+ var downloadErr error
+ if _, err := os.Stat(manualChecksumPath); err == nil {
+ log.Printf("Using existing checksums from: %s", manualChecksumPath)
+ checksumPath = manualChecksumPath
+ } else if _, err := os.Stat(checksumPath); err == nil {
+ log.Printf("Using existing checksums from: %s", checksumPath)
+ } else {
+ // No existing checksum file found, try to download
+ downloadErr = rm.downloadFile(checksumURL, checksumPath, nil)
+
+ if downloadErr != nil {
+ log.Printf("Warning: failed to download checksums: %v", downloadErr)
+ log.Printf("Warning: Checksum verification will be skipped. For security, you can manually place checksums at: %s", manualChecksumPath)
+ log.Printf("Download checksums from: %s", checksumURL)
+ // Continue without verification - log warning but don't fail
+ }
+ }
+
+ // Verify the checksum if we have a checksum file
+ if _, err := os.Stat(checksumPath); err == nil {
+ if err := rm.VerifyChecksum(localPath, checksumPath, binaryName); err != nil {
+ return fmt.Errorf("checksum verification failed: %w", err)
+ }
+ log.Printf("Checksum verification successful")
+
+ // Save checksums persistently for future verification
+ if downloadErr == nil {
+ if err := rm.saveChecksums(version, checksumPath, binaryName); err != nil {
+ log.Printf("Warning: failed to save checksums: %v", err)
+ }
+ }
+ } else {
+ log.Printf("Warning: Proceeding without checksum verification")
+ }
+
+ // Make the binary executable
+ if err := os.Chmod(localPath, 0755); err != nil {
+ return fmt.Errorf("failed to make binary executable: %w", err)
+ }
+
+ return nil
+}
+
+// GetBinaryName returns the appropriate binary name for the current platform
+func (rm *ReleaseManager) GetBinaryName(version string) string {
+ versionStr := strings.TrimPrefix(version, "v")
+ os := runtime.GOOS
+ arch := runtime.GOARCH
+
+ // Map Go arch names to the release naming convention
+ switch arch {
+ case "amd64":
+ arch = "amd64"
+ case "arm64":
+ arch = "arm64"
+ default:
+ arch = "amd64" // fallback
+ }
+
+ return fmt.Sprintf("local-ai-v%s-%s-%s", versionStr, os, arch)
+}
+
+// downloadFile downloads a file from a URL to a local path with optional progress callback
+func (rm *ReleaseManager) downloadFile(url, filepath string, progressCallback func(float64)) error {
+ return rm.downloadFileWithRetry(url, filepath, progressCallback, 3)
+}
+
+// downloadFileWithRetry downloads a file from a URL with retry logic
+func (rm *ReleaseManager) downloadFileWithRetry(url, filepath string, progressCallback func(float64), maxRetries int) error {
+ var lastErr error
+
+ for attempt := 1; attempt <= maxRetries; attempt++ {
+ if attempt > 1 {
+ log.Printf("Retrying download (attempt %d/%d): %s", attempt, maxRetries, url)
+ time.Sleep(time.Duration(attempt) * time.Second)
+ }
+
+ resp, err := rm.HTTPClient.Get(url)
+ if err != nil {
+ lastErr = err
+ continue
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ resp.Body.Close()
+ lastErr = fmt.Errorf("bad status: %s", resp.Status)
+ continue
+ }
+
+ out, err := os.Create(filepath)
+ if err != nil {
+ resp.Body.Close()
+ return err
+ }
+
+ // Create a progress reader if callback is provided
+ var reader io.Reader = resp.Body
+ if progressCallback != nil && resp.ContentLength > 0 {
+ reader = &progressReader{
+ Reader: resp.Body,
+ Total: resp.ContentLength,
+ Callback: progressCallback,
+ }
+ }
+
+ _, err = io.Copy(out, reader)
+ resp.Body.Close()
+ out.Close()
+
+ if err != nil {
+ lastErr = err
+ os.Remove(filepath)
+ continue
+ }
+
+ return nil
+ }
+
+ return fmt.Errorf("failed after %d attempts: %w", maxRetries, lastErr)
+}
+
+// saveChecksums saves checksums persistently for future verification
+func (rm *ReleaseManager) saveChecksums(version, checksumPath, binaryName string) error {
+ // Ensure checksums directory exists
+ if err := os.MkdirAll(rm.ChecksumsPath, 0755); err != nil {
+ return fmt.Errorf("failed to create checksums directory: %w", err)
+ }
+
+ // Read the downloaded checksums file
+ checksumData, err := os.ReadFile(checksumPath)
+ if err != nil {
+ return fmt.Errorf("failed to read checksums file: %w", err)
+ }
+
+ // Save to persistent location with version info
+ persistentPath := filepath.Join(rm.ChecksumsPath, fmt.Sprintf("checksums-%s.txt", version))
+ if err := os.WriteFile(persistentPath, checksumData, 0644); err != nil {
+ return fmt.Errorf("failed to write persistent checksums: %w", err)
+ }
+
+ // Also save a "latest" checksums file for the current version
+ latestPath := filepath.Join(rm.ChecksumsPath, "checksums-latest.txt")
+ if err := os.WriteFile(latestPath, checksumData, 0644); err != nil {
+ return fmt.Errorf("failed to write latest checksums: %w", err)
+ }
+
+ // Save version metadata
+ if err := rm.saveVersionMetadata(version); err != nil {
+ log.Printf("Warning: failed to save version metadata: %v", err)
+ }
+
+ log.Printf("Checksums saved for version %s", version)
+ return nil
+}
+
+// saveVersionMetadata saves the installed version information
+func (rm *ReleaseManager) saveVersionMetadata(version string) error {
+ // Ensure metadata directory exists
+ if err := os.MkdirAll(rm.MetadataPath, 0755); err != nil {
+ return fmt.Errorf("failed to create metadata directory: %w", err)
+ }
+
+ // Create metadata structure
+ metadata := struct {
+ Version string `json:"version"`
+ InstalledAt time.Time `json:"installed_at"`
+ BinaryPath string `json:"binary_path"`
+ }{
+ Version: version,
+ InstalledAt: time.Now(),
+ BinaryPath: rm.GetBinaryPath(),
+ }
+
+ // Marshal to JSON
+ metadataData, err := json.MarshalIndent(metadata, "", " ")
+ if err != nil {
+ return fmt.Errorf("failed to marshal metadata: %w", err)
+ }
+
+ // Save metadata file
+ metadataPath := filepath.Join(rm.MetadataPath, "installed-version.json")
+ if err := os.WriteFile(metadataPath, metadataData, 0644); err != nil {
+ return fmt.Errorf("failed to write metadata file: %w", err)
+ }
+
+ log.Printf("Version metadata saved: %s", version)
+ return nil
+}
+
+// progressReader wraps an io.Reader to provide download progress
+type progressReader struct {
+ io.Reader
+ Total int64
+ Current int64
+ Callback func(float64)
+}
+
+func (pr *progressReader) Read(p []byte) (int, error) {
+ n, err := pr.Reader.Read(p)
+ pr.Current += int64(n)
+ if pr.Callback != nil {
+ progress := float64(pr.Current) / float64(pr.Total)
+ pr.Callback(progress)
+ }
+ return n, err
+}
+
+// VerifyChecksum verifies the downloaded file against the provided checksums
+func (rm *ReleaseManager) VerifyChecksum(filePath, checksumPath, binaryName string) error {
+ // Calculate the SHA256 of the downloaded file
+ file, err := os.Open(filePath)
+ if err != nil {
+ return fmt.Errorf("failed to open file for checksum: %w", err)
+ }
+ defer file.Close()
+
+ hasher := sha256.New()
+ if _, err := io.Copy(hasher, file); err != nil {
+ return fmt.Errorf("failed to calculate checksum: %w", err)
+ }
+
+ calculatedHash := hex.EncodeToString(hasher.Sum(nil))
+
+ // Read the checksums file
+ checksumFile, err := os.Open(checksumPath)
+ if err != nil {
+ return fmt.Errorf("failed to open checksums file: %w", err)
+ }
+ defer checksumFile.Close()
+
+ scanner := bufio.NewScanner(checksumFile)
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if strings.Contains(line, binaryName) {
+ parts := strings.Fields(line)
+ if len(parts) >= 2 {
+ expectedHash := parts[0]
+ if calculatedHash == expectedHash {
+ return nil // Checksum verified
+ }
+ return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedHash, calculatedHash)
+ }
+ }
+ }
+
+ return fmt.Errorf("checksum not found for %s", binaryName)
+}
+
+// GetInstalledVersion returns the currently installed version
+func (rm *ReleaseManager) GetInstalledVersion() string {
+
+ // Fallback: Check if the LocalAI binary exists and try to get its version
+ binaryPath := rm.GetBinaryPath()
+ if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
+ return "" // No version installed
+ }
+
+ // try to get version from metadata
+ if version := rm.loadVersionMetadata(); version != "" {
+ return version
+ }
+
+ // Try to run the binary to get the version (fallback method)
+ version, err := exec.Command(binaryPath, "--version").Output()
+ if err != nil {
+ // If binary exists but --version fails, try to determine from filename or other means
+ log.Printf("Binary exists but --version failed: %v", err)
+ return ""
+ }
+
+ stringVersion := strings.TrimSpace(string(version))
+ stringVersion = strings.TrimRight(stringVersion, "\n")
+
+ return stringVersion
+}
+
+// loadVersionMetadata loads the installed version from metadata file
+func (rm *ReleaseManager) loadVersionMetadata() string {
+ metadataPath := filepath.Join(rm.MetadataPath, "installed-version.json")
+
+ // Check if metadata file exists
+ if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
+ return ""
+ }
+
+ // Read metadata file
+ metadataData, err := os.ReadFile(metadataPath)
+ if err != nil {
+ log.Printf("Failed to read metadata file: %v", err)
+ return ""
+ }
+
+ // Parse metadata
+ var metadata struct {
+ Version string `json:"version"`
+ InstalledAt time.Time `json:"installed_at"`
+ BinaryPath string `json:"binary_path"`
+ }
+
+ if err := json.Unmarshal(metadataData, &metadata); err != nil {
+ log.Printf("Failed to parse metadata file: %v", err)
+ return ""
+ }
+
+ // Verify that the binary path in metadata matches current binary path
+ if metadata.BinaryPath != rm.GetBinaryPath() {
+ log.Printf("Binary path mismatch in metadata, ignoring")
+ return ""
+ }
+
+ log.Printf("Loaded version from metadata: %s (installed at %s)", metadata.Version, metadata.InstalledAt.Format("2006-01-02 15:04:05"))
+ return metadata.Version
+}
+
+// GetBinaryPath returns the path to the LocalAI binary
+func (rm *ReleaseManager) GetBinaryPath() string {
+ return filepath.Join(rm.BinaryPath, "local-ai")
+}
+
+// IsUpdateAvailable checks if an update is available
+func (rm *ReleaseManager) IsUpdateAvailable() (bool, string, error) {
+ log.Printf("IsUpdateAvailable: checking for updates...")
+
+ latest, err := rm.GetLatestRelease()
+ if err != nil {
+ log.Printf("IsUpdateAvailable: failed to get latest release: %v", err)
+ return false, "", err
+ }
+ log.Printf("IsUpdateAvailable: latest release version: %s", latest.Version)
+
+ current := rm.GetInstalledVersion()
+ log.Printf("IsUpdateAvailable: current installed version: %s", current)
+
+ if current == "" {
+ // No version installed, offer to download latest
+ log.Printf("IsUpdateAvailable: no version installed, offering latest: %s", latest.Version)
+ return true, latest.Version, nil
+ }
+
+ updateAvailable := latest.Version != current
+ log.Printf("IsUpdateAvailable: update available: %v (latest: %s, current: %s)", updateAvailable, latest.Version, current)
+ return updateAvailable, latest.Version, nil
+}
+
+// IsLocalAIInstalled checks if LocalAI binary exists and is valid
+func (rm *ReleaseManager) IsLocalAIInstalled() bool {
+ binaryPath := rm.GetBinaryPath()
+ if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
+ return false
+ }
+
+ // Verify the binary integrity
+ if err := rm.VerifyInstalledBinary(); err != nil {
+ log.Printf("Binary integrity check failed: %v", err)
+ // Remove corrupted binary
+ if removeErr := os.Remove(binaryPath); removeErr != nil {
+ log.Printf("Failed to remove corrupted binary: %v", removeErr)
+ }
+ return false
+ }
+
+ return true
+}
+
+// VerifyInstalledBinary verifies the installed binary against saved checksums
+func (rm *ReleaseManager) VerifyInstalledBinary() error {
+ binaryPath := rm.GetBinaryPath()
+
+ // Check if we have saved checksums
+ latestChecksumsPath := filepath.Join(rm.ChecksumsPath, "checksums-latest.txt")
+ if _, err := os.Stat(latestChecksumsPath); os.IsNotExist(err) {
+ return fmt.Errorf("no saved checksums found")
+ }
+
+ // Get the binary name for the current version from metadata
+ currentVersion := rm.loadVersionMetadata()
+ if currentVersion == "" {
+ return fmt.Errorf("cannot determine current version from metadata")
+ }
+
+ binaryName := rm.GetBinaryName(currentVersion)
+
+ // Verify against saved checksums
+ return rm.VerifyChecksum(binaryPath, latestChecksumsPath, binaryName)
+}
+
+// CleanupPartialDownloads removes any partial or corrupted downloads
+func (rm *ReleaseManager) CleanupPartialDownloads() error {
+ binaryPath := rm.GetBinaryPath()
+
+ // Check if binary exists but is corrupted
+ if _, err := os.Stat(binaryPath); err == nil {
+ // Binary exists, verify it
+ if verifyErr := rm.VerifyInstalledBinary(); verifyErr != nil {
+ log.Printf("Found corrupted binary, removing: %v", verifyErr)
+ if removeErr := os.Remove(binaryPath); removeErr != nil {
+ log.Printf("Failed to remove corrupted binary: %v", removeErr)
+ }
+ // Clear metadata since binary is corrupted
+ rm.clearVersionMetadata()
+ }
+ }
+
+ // Clean up any temporary checksum files
+ tempChecksumsPath := filepath.Join(rm.BinaryPath, "checksums.txt")
+ if _, err := os.Stat(tempChecksumsPath); err == nil {
+ if removeErr := os.Remove(tempChecksumsPath); removeErr != nil {
+ log.Printf("Failed to remove temporary checksums: %v", removeErr)
+ }
+ }
+
+ return nil
+}
+
+// clearVersionMetadata clears the version metadata (used when binary is corrupted or removed)
+func (rm *ReleaseManager) clearVersionMetadata() {
+ metadataPath := filepath.Join(rm.MetadataPath, "installed-version.json")
+ if err := os.Remove(metadataPath); err != nil && !os.IsNotExist(err) {
+ log.Printf("Failed to clear version metadata: %v", err)
+ } else {
+ log.Printf("Version metadata cleared")
+ }
+}
diff --git a/cmd/launcher/internal/release_manager_test.go b/cmd/launcher/internal/release_manager_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..f6de6aa5abdf95c476d605246dd2d1de5ea39724
--- /dev/null
+++ b/cmd/launcher/internal/release_manager_test.go
@@ -0,0 +1,181 @@
+package launcher_test
+
+import (
+ "os"
+ "path/filepath"
+ "runtime"
+ "time"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+
+ launcher "github.com/mudler/LocalAI/cmd/launcher/internal"
+)
+
+var _ = Describe("ReleaseManager", func() {
+ var (
+ rm *launcher.ReleaseManager
+ tempDir string
+ )
+
+ BeforeEach(func() {
+ var err error
+ tempDir, err = os.MkdirTemp("", "launcher-test-*")
+ Expect(err).ToNot(HaveOccurred())
+
+ rm = launcher.NewReleaseManager()
+ // Override binary path for testing
+ rm.BinaryPath = tempDir
+ })
+
+ AfterEach(func() {
+ os.RemoveAll(tempDir)
+ })
+
+ Describe("NewReleaseManager", func() {
+ It("should create a release manager with correct defaults", func() {
+ newRM := launcher.NewReleaseManager()
+ Expect(newRM.GitHubOwner).To(Equal("mudler"))
+ Expect(newRM.GitHubRepo).To(Equal("LocalAI"))
+ Expect(newRM.BinaryPath).To(ContainSubstring(".localai"))
+ Expect(newRM.HTTPClient).ToNot(BeNil())
+ Expect(newRM.HTTPClient.Timeout).To(Equal(30 * time.Second))
+ })
+ })
+
+ Describe("GetBinaryName", func() {
+ It("should return correct binary name for current platform", func() {
+ binaryName := rm.GetBinaryName("v3.4.0")
+ expectedOS := runtime.GOOS
+ expectedArch := runtime.GOARCH
+
+ expected := "local-ai-v3.4.0-" + expectedOS + "-" + expectedArch
+ Expect(binaryName).To(Equal(expected))
+ })
+
+ It("should handle version with and without 'v' prefix", func() {
+ withV := rm.GetBinaryName("v3.4.0")
+ withoutV := rm.GetBinaryName("3.4.0")
+
+ // Both should produce the same result
+ Expect(withV).To(Equal(withoutV))
+ })
+ })
+
+ Describe("GetBinaryPath", func() {
+ It("should return the correct binary path", func() {
+ path := rm.GetBinaryPath()
+ expected := filepath.Join(tempDir, "local-ai")
+ Expect(path).To(Equal(expected))
+ })
+ })
+
+ Describe("GetInstalledVersion", func() {
+ It("should return empty when no binary exists", func() {
+ version := rm.GetInstalledVersion()
+ Expect(version).To(BeEmpty()) // No binary installed in test
+ })
+
+ It("should return empty version when binary exists but no metadata", func() {
+ // Create a fake binary for testing
+ err := os.MkdirAll(rm.BinaryPath, 0755)
+ Expect(err).ToNot(HaveOccurred())
+
+ binaryPath := rm.GetBinaryPath()
+ err = os.WriteFile(binaryPath, []byte("fake binary"), 0755)
+ Expect(err).ToNot(HaveOccurred())
+
+ version := rm.GetInstalledVersion()
+ Expect(version).To(BeEmpty())
+ })
+ })
+
+ Context("with mocked responses", func() {
+ // Note: In a real implementation, we'd mock HTTP responses
+ // For now, we'll test the structure and error handling
+
+ Describe("GetLatestRelease", func() {
+ It("should handle network errors gracefully", func() {
+ // This test would require mocking HTTP client
+ // For demonstration, we're just testing the method exists
+ _, err := rm.GetLatestRelease()
+ // We expect either success or a network error, not a panic
+ // In a real test, we'd mock the HTTP response
+ if err != nil {
+ Expect(err.Error()).To(ContainSubstring("failed to fetch"))
+ }
+ })
+ })
+
+ Describe("DownloadRelease", func() {
+ It("should create binary directory if it doesn't exist", func() {
+ // Remove the temp directory to test creation
+ os.RemoveAll(tempDir)
+
+ // This will fail due to network, but should create the directory
+ rm.DownloadRelease("v3.4.0", nil)
+
+ // Check if directory was created
+ _, err := os.Stat(tempDir)
+ Expect(err).ToNot(HaveOccurred())
+ })
+ })
+ })
+
+ Describe("VerifyChecksum functionality", func() {
+ var (
+ testFile string
+ checksumFile string
+ )
+
+ BeforeEach(func() {
+ testFile = filepath.Join(tempDir, "test-binary")
+ checksumFile = filepath.Join(tempDir, "checksums.txt")
+ })
+
+ It("should verify checksums correctly", func() {
+ // Create a test file with known content
+ testContent := []byte("test content for checksum")
+ err := os.WriteFile(testFile, testContent, 0644)
+ Expect(err).ToNot(HaveOccurred())
+
+ // Calculate expected SHA256
+ // This is a simplified test - in practice we'd use the actual checksum
+ checksumContent := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 test-binary\n"
+ err = os.WriteFile(checksumFile, []byte(checksumContent), 0644)
+ Expect(err).ToNot(HaveOccurred())
+
+ // Test checksum verification
+ // Note: This will fail because our content doesn't match the empty string hash
+ // In a real test, we'd calculate the actual hash
+ err = rm.VerifyChecksum(testFile, checksumFile, "test-binary")
+ // We expect this to fail since we're using a dummy checksum
+ Expect(err).To(HaveOccurred())
+ Expect(err.Error()).To(ContainSubstring("checksum mismatch"))
+ })
+
+ It("should handle missing checksum file", func() {
+ // Create test file but no checksum file
+ err := os.WriteFile(testFile, []byte("test"), 0644)
+ Expect(err).ToNot(HaveOccurred())
+
+ err = rm.VerifyChecksum(testFile, checksumFile, "test-binary")
+ Expect(err).To(HaveOccurred())
+ Expect(err.Error()).To(ContainSubstring("failed to open checksums file"))
+ })
+
+ It("should handle missing binary in checksums", func() {
+ // Create files but checksum doesn't contain our binary
+ err := os.WriteFile(testFile, []byte("test"), 0644)
+ Expect(err).ToNot(HaveOccurred())
+
+ checksumContent := "hash other-binary\n"
+ err = os.WriteFile(checksumFile, []byte(checksumContent), 0644)
+ Expect(err).ToNot(HaveOccurred())
+
+ err = rm.VerifyChecksum(testFile, checksumFile, "test-binary")
+ Expect(err).To(HaveOccurred())
+ Expect(err.Error()).To(ContainSubstring("checksum not found"))
+ })
+ })
+})
diff --git a/cmd/launcher/internal/systray_manager.go b/cmd/launcher/internal/systray_manager.go
new file mode 100644
index 0000000000000000000000000000000000000000..4881fce889212fc5da4b6918ece115bf31df1efb
--- /dev/null
+++ b/cmd/launcher/internal/systray_manager.go
@@ -0,0 +1,523 @@
+package launcher
+
+import (
+ "fmt"
+ "log"
+ "net/url"
+
+ "fyne.io/fyne/v2"
+ "fyne.io/fyne/v2/container"
+ "fyne.io/fyne/v2/dialog"
+ "fyne.io/fyne/v2/driver/desktop"
+ "fyne.io/fyne/v2/widget"
+)
+
+// SystrayManager manages the system tray functionality
+type SystrayManager struct {
+ launcher *Launcher
+ window fyne.Window
+ app fyne.App
+ desk desktop.App
+
+ // Menu items that need dynamic updates
+ startStopItem *fyne.MenuItem
+ hasUpdateAvailable bool
+ latestVersion string
+ icon *fyne.StaticResource
+}
+
+// NewSystrayManager creates a new systray manager
+func NewSystrayManager(launcher *Launcher, window fyne.Window, desktop desktop.App, app fyne.App, icon *fyne.StaticResource) *SystrayManager {
+ sm := &SystrayManager{
+ launcher: launcher,
+ window: window,
+ app: app,
+ desk: desktop,
+ icon: icon,
+ }
+ sm.setupMenu(desktop)
+ return sm
+}
+
+// setupMenu sets up the system tray menu
+func (sm *SystrayManager) setupMenu(desk desktop.App) {
+ sm.desk = desk
+
+ // Create the start/stop toggle item
+ sm.startStopItem = fyne.NewMenuItem("Start LocalAI", func() {
+ sm.toggleLocalAI()
+ })
+
+ desk.SetSystemTrayIcon(sm.icon)
+
+ // Initialize the menu state using recreateMenu
+ sm.recreateMenu()
+}
+
+// toggleLocalAI starts or stops LocalAI based on current state
+func (sm *SystrayManager) toggleLocalAI() {
+ if sm.launcher.IsRunning() {
+ go func() {
+ if err := sm.launcher.StopLocalAI(); err != nil {
+ log.Printf("Failed to stop LocalAI: %v", err)
+ sm.showErrorDialog("Failed to Stop LocalAI", err.Error())
+ }
+ }()
+ } else {
+ go func() {
+ if err := sm.launcher.StartLocalAI(); err != nil {
+ log.Printf("Failed to start LocalAI: %v", err)
+ sm.showStartupErrorDialog(err)
+ }
+ }()
+ }
+}
+
+// openWebUI opens the LocalAI WebUI in the default browser
+func (sm *SystrayManager) openWebUI() {
+ if !sm.launcher.IsRunning() {
+ return // LocalAI is not running
+ }
+
+ webURL := sm.launcher.GetWebUIURL()
+ if parsedURL, err := url.Parse(webURL); err == nil {
+ sm.app.OpenURL(parsedURL)
+ }
+}
+
+// openDocumentation opens the LocalAI documentation
+func (sm *SystrayManager) openDocumentation() {
+ if parsedURL, err := url.Parse("https://localai.io"); err == nil {
+ sm.app.OpenURL(parsedURL)
+ }
+}
+
+// updateStartStopItem updates the start/stop menu item based on current state
+func (sm *SystrayManager) updateStartStopItem() {
+ // Since Fyne menu items can't change text dynamically, we recreate the menu
+ sm.recreateMenu()
+}
+
+// recreateMenu recreates the entire menu with updated state
+func (sm *SystrayManager) recreateMenu() {
+ if sm.desk == nil {
+ return
+ }
+
+ // Determine the action based on LocalAI installation and running state
+ var actionItem *fyne.MenuItem
+ if !sm.launcher.GetReleaseManager().IsLocalAIInstalled() {
+ // LocalAI not installed - show install option
+ actionItem = fyne.NewMenuItem("📥 Install Latest Version", func() {
+ sm.launcher.showDownloadLocalAIDialog()
+ })
+ } else if sm.launcher.IsRunning() {
+ // LocalAI is running - show stop option
+ actionItem = fyne.NewMenuItem("🛑 Stop LocalAI", func() {
+ sm.toggleLocalAI()
+ })
+ } else {
+ // LocalAI is installed but not running - show start option
+ actionItem = fyne.NewMenuItem("▶️ Start LocalAI", func() {
+ sm.toggleLocalAI()
+ })
+ }
+
+ menuItems := []*fyne.MenuItem{}
+
+ // Add status at the top (clickable for details)
+ status := sm.launcher.GetLastStatus()
+ statusText := sm.truncateText(status, 30)
+ statusItem := fyne.NewMenuItem("📊 Status: "+statusText, func() {
+ sm.showStatusDetails(status, "")
+ })
+ menuItems = append(menuItems, statusItem)
+
+ // Only show version if LocalAI is installed
+ if sm.launcher.GetReleaseManager().IsLocalAIInstalled() {
+ version := sm.launcher.GetCurrentVersion()
+ versionText := sm.truncateText(version, 25)
+ versionItem := fyne.NewMenuItem("🔧 Version: "+versionText, func() {
+ sm.showStatusDetails(status, version)
+ })
+ menuItems = append(menuItems, versionItem)
+ }
+
+ menuItems = append(menuItems, fyne.NewMenuItemSeparator())
+
+ // Add update notification if available
+ if sm.hasUpdateAvailable {
+ updateItem := fyne.NewMenuItem("🔔 New version available ("+sm.latestVersion+")", func() {
+ sm.downloadUpdate()
+ })
+ menuItems = append(menuItems, updateItem)
+ menuItems = append(menuItems, fyne.NewMenuItemSeparator())
+ }
+
+ // Core actions
+ menuItems = append(menuItems,
+ actionItem,
+ )
+
+ // Only show WebUI option if LocalAI is installed
+ if sm.launcher.GetReleaseManager().IsLocalAIInstalled() && sm.launcher.IsRunning() {
+ menuItems = append(menuItems,
+ fyne.NewMenuItem("Open WebUI", func() {
+ sm.openWebUI()
+ }),
+ )
+ }
+
+ menuItems = append(menuItems,
+ fyne.NewMenuItemSeparator(),
+ fyne.NewMenuItem("Check for Updates", func() {
+ sm.checkForUpdates()
+ }),
+ fyne.NewMenuItemSeparator(),
+ fyne.NewMenuItem("Settings", func() {
+ sm.showSettings()
+ }),
+ fyne.NewMenuItem("Show Welcome Window", func() {
+ sm.showWelcomeWindow()
+ }),
+ fyne.NewMenuItem("Open Data Folder", func() {
+ sm.openDataFolder()
+ }),
+ fyne.NewMenuItemSeparator(),
+ fyne.NewMenuItem("Documentation", func() {
+ sm.openDocumentation()
+ }),
+ fyne.NewMenuItemSeparator(),
+ fyne.NewMenuItem("Quit", func() {
+ // Perform cleanup before quitting
+ if err := sm.launcher.Shutdown(); err != nil {
+ log.Printf("Error during shutdown: %v", err)
+ }
+ sm.app.Quit()
+ }),
+ )
+
+ menu := fyne.NewMenu("LocalAI", menuItems...)
+ sm.desk.SetSystemTrayMenu(menu)
+}
+
+// UpdateRunningState updates the systray based on running state
+func (sm *SystrayManager) UpdateRunningState(isRunning bool) {
+ sm.updateStartStopItem()
+}
+
+// UpdateStatus updates the systray menu to reflect status changes
+func (sm *SystrayManager) UpdateStatus(status string) {
+ sm.recreateMenu()
+}
+
+// checkForUpdates checks for available updates
+func (sm *SystrayManager) checkForUpdates() {
+ go func() {
+ log.Printf("Checking for updates...")
+ available, version, err := sm.launcher.CheckForUpdates()
+ if err != nil {
+ log.Printf("Failed to check for updates: %v", err)
+ return
+ }
+
+ log.Printf("Update check result: available=%v, version=%s", available, version)
+ if available {
+ sm.hasUpdateAvailable = true
+ sm.latestVersion = version
+ sm.recreateMenu()
+ }
+ }()
+}
+
+// downloadUpdate downloads the latest update
+func (sm *SystrayManager) downloadUpdate() {
+ if !sm.hasUpdateAvailable {
+ return
+ }
+
+ // Show progress window
+ sm.showDownloadProgress(sm.latestVersion)
+}
+
+// showSettings shows the settings window
+func (sm *SystrayManager) showSettings() {
+ sm.window.Show()
+ sm.window.RequestFocus()
+}
+
+// showWelcomeWindow shows the welcome window
+func (sm *SystrayManager) showWelcomeWindow() {
+ if sm.launcher.GetUI() != nil {
+ sm.launcher.GetUI().ShowWelcomeWindow()
+ }
+}
+
+// openDataFolder opens the data folder in file manager
+func (sm *SystrayManager) openDataFolder() {
+ dataPath := sm.launcher.GetDataPath()
+ if parsedURL, err := url.Parse("file://" + dataPath); err == nil {
+ sm.app.OpenURL(parsedURL)
+ }
+}
+
+// NotifyUpdateAvailable sets update notification in systray
+func (sm *SystrayManager) NotifyUpdateAvailable(version string) {
+ sm.hasUpdateAvailable = true
+ sm.latestVersion = version
+ sm.recreateMenu()
+}
+
+// truncateText truncates text to specified length and adds ellipsis if needed
+func (sm *SystrayManager) truncateText(text string, maxLength int) string {
+ if len(text) <= maxLength {
+ return text
+ }
+ return text[:maxLength-3] + "..."
+}
+
+// showStatusDetails shows a detailed status window with full information
+func (sm *SystrayManager) showStatusDetails(status, version string) {
+ fyne.DoAndWait(func() {
+ // Create status details window
+ statusWindow := sm.app.NewWindow("LocalAI Status Details")
+ statusWindow.Resize(fyne.NewSize(500, 400))
+ statusWindow.CenterOnScreen()
+
+ // Status information
+ statusLabel := widget.NewLabel("Current Status:")
+ statusValue := widget.NewLabel(status)
+ statusValue.Wrapping = fyne.TextWrapWord
+
+ // Version information (only show if version exists)
+ var versionContainer fyne.CanvasObject
+ if version != "" {
+ versionLabel := widget.NewLabel("Installed Version:")
+ versionValue := widget.NewLabel(version)
+ versionValue.Wrapping = fyne.TextWrapWord
+ versionContainer = container.NewVBox(versionLabel, versionValue)
+ }
+
+ // Running state
+ runningLabel := widget.NewLabel("Running State:")
+ runningValue := widget.NewLabel("")
+ if sm.launcher.IsRunning() {
+ runningValue.SetText("🟢 Running")
+ } else {
+ runningValue.SetText("🔴 Stopped")
+ }
+
+ // WebUI URL
+ webuiLabel := widget.NewLabel("WebUI URL:")
+ webuiValue := widget.NewLabel(sm.launcher.GetWebUIURL())
+ webuiValue.Wrapping = fyne.TextWrapWord
+
+ // Recent logs (last 20 lines)
+ logsLabel := widget.NewLabel("Recent Logs:")
+ logsText := widget.NewMultiLineEntry()
+ logsText.SetText(sm.launcher.GetRecentLogs())
+ logsText.Wrapping = fyne.TextWrapWord
+ logsText.Disable() // Make it read-only
+
+ // Buttons
+ closeButton := widget.NewButton("Close", func() {
+ statusWindow.Close()
+ })
+
+ refreshButton := widget.NewButton("Refresh", func() {
+ // Refresh the status information
+ statusValue.SetText(sm.launcher.GetLastStatus())
+
+ // Note: Version refresh is not implemented for simplicity
+ // The version will be updated when the status details window is reopened
+
+ if sm.launcher.IsRunning() {
+ runningValue.SetText("🟢 Running")
+ } else {
+ runningValue.SetText("🔴 Stopped")
+ }
+ logsText.SetText(sm.launcher.GetRecentLogs())
+ })
+
+ openWebUIButton := widget.NewButton("Open WebUI", func() {
+ sm.openWebUI()
+ })
+
+ // Layout
+ buttons := container.NewHBox(closeButton, refreshButton, openWebUIButton)
+
+ // Build info container dynamically
+ infoItems := []fyne.CanvasObject{
+ statusLabel, statusValue,
+ widget.NewSeparator(),
+ }
+
+ // Add version section if it exists
+ if versionContainer != nil {
+ infoItems = append(infoItems, versionContainer, widget.NewSeparator())
+ }
+
+ infoItems = append(infoItems,
+ runningLabel, runningValue,
+ widget.NewSeparator(),
+ webuiLabel, webuiValue,
+ )
+
+ infoContainer := container.NewVBox(infoItems...)
+
+ content := container.NewVBox(
+ infoContainer,
+ widget.NewSeparator(),
+ logsLabel,
+ logsText,
+ widget.NewSeparator(),
+ buttons,
+ )
+
+ statusWindow.SetContent(content)
+ statusWindow.Show()
+ })
+}
+
+// showErrorDialog shows a simple error dialog
+func (sm *SystrayManager) showErrorDialog(title, message string) {
+ fyne.DoAndWait(func() {
+ dialog.ShowError(fmt.Errorf("%s", message), sm.window)
+ })
+}
+
+// showStartupErrorDialog shows a detailed error dialog with process logs
+func (sm *SystrayManager) showStartupErrorDialog(err error) {
+ fyne.DoAndWait(func() {
+ // Get the recent process logs (more useful for debugging)
+ logs := sm.launcher.GetRecentLogs()
+
+ // Create error window
+ errorWindow := sm.app.NewWindow("LocalAI Startup Failed")
+ errorWindow.Resize(fyne.NewSize(600, 500))
+ errorWindow.CenterOnScreen()
+
+ // Error message
+ errorLabel := widget.NewLabel(fmt.Sprintf("Failed to start LocalAI:\n%s", err.Error()))
+ errorLabel.Wrapping = fyne.TextWrapWord
+
+ // Logs display
+ logsLabel := widget.NewLabel("Process Logs:")
+ logsText := widget.NewMultiLineEntry()
+ logsText.SetText(logs)
+ logsText.Wrapping = fyne.TextWrapWord
+ logsText.Disable() // Make it read-only
+
+ // Buttons
+ closeButton := widget.NewButton("Close", func() {
+ errorWindow.Close()
+ })
+
+ retryButton := widget.NewButton("Retry", func() {
+ errorWindow.Close()
+ // Try to start again
+ go func() {
+ if retryErr := sm.launcher.StartLocalAI(); retryErr != nil {
+ sm.showStartupErrorDialog(retryErr)
+ }
+ }()
+ })
+
+ openLogsButton := widget.NewButton("Open Logs Folder", func() {
+ sm.openDataFolder()
+ })
+
+ // Layout
+ buttons := container.NewHBox(closeButton, retryButton, openLogsButton)
+ content := container.NewVBox(
+ errorLabel,
+ widget.NewSeparator(),
+ logsLabel,
+ logsText,
+ widget.NewSeparator(),
+ buttons,
+ )
+
+ errorWindow.SetContent(content)
+ errorWindow.Show()
+ })
+}
+
+// showDownloadProgress shows a progress window for downloading updates
+func (sm *SystrayManager) showDownloadProgress(version string) {
+ // Create a new window for download progress
+ progressWindow := sm.app.NewWindow("Downloading LocalAI Update")
+ progressWindow.Resize(fyne.NewSize(400, 250))
+ progressWindow.CenterOnScreen()
+
+ // Progress bar
+ progressBar := widget.NewProgressBar()
+ progressBar.SetValue(0)
+
+ // Status label
+ statusLabel := widget.NewLabel("Preparing download...")
+
+ // Release notes button
+ releaseNotesButton := widget.NewButton("View Release Notes", func() {
+ releaseNotesURL, err := sm.launcher.githubReleaseNotesURL(version)
+ if err != nil {
+ log.Printf("Failed to parse URL: %v", err)
+ return
+ }
+
+ sm.app.OpenURL(releaseNotesURL)
+ })
+
+ // Progress container
+ progressContainer := container.NewVBox(
+ widget.NewLabel(fmt.Sprintf("Downloading LocalAI version %s", version)),
+ progressBar,
+ statusLabel,
+ widget.NewSeparator(),
+ releaseNotesButton,
+ )
+
+ progressWindow.SetContent(progressContainer)
+ progressWindow.Show()
+
+ // Start download in background
+ go func() {
+ err := sm.launcher.DownloadUpdate(version, func(progress float64) {
+ // Update progress bar
+ fyne.Do(func() {
+ progressBar.SetValue(progress)
+ percentage := int(progress * 100)
+ statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
+ })
+ })
+
+ // Handle completion
+ fyne.Do(func() {
+ if err != nil {
+ statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
+ // Show error dialog
+ dialog.ShowError(err, progressWindow)
+ } else {
+ statusLabel.SetText("Download completed successfully!")
+ progressBar.SetValue(1.0)
+
+ // Show restart dialog
+ dialog.ShowConfirm("Update Downloaded",
+ "LocalAI has been updated successfully. Please restart the launcher to use the new version.",
+ func(restart bool) {
+ if restart {
+ sm.app.Quit()
+ }
+ progressWindow.Close()
+ }, progressWindow)
+ }
+ })
+
+ // Update systray menu
+ if err == nil {
+ sm.hasUpdateAvailable = false
+ sm.latestVersion = ""
+ sm.recreateMenu()
+ }
+ }()
+}
diff --git a/cmd/launcher/internal/ui.go b/cmd/launcher/internal/ui.go
new file mode 100644
index 0000000000000000000000000000000000000000..7efd781d9b8a4352853a29282e033c0b296ffce9
--- /dev/null
+++ b/cmd/launcher/internal/ui.go
@@ -0,0 +1,795 @@
+package launcher
+
+import (
+ "fmt"
+ "log"
+ "net/url"
+
+ "fyne.io/fyne/v2"
+ "fyne.io/fyne/v2/container"
+ "fyne.io/fyne/v2/dialog"
+ "fyne.io/fyne/v2/widget"
+)
+
+// EnvVar represents an environment variable
+type EnvVar struct {
+ Key string
+ Value string
+}
+
+// LauncherUI handles the user interface
+type LauncherUI struct {
+ // Status display
+ statusLabel *widget.Label
+ versionLabel *widget.Label
+
+ // Control buttons
+ startStopButton *widget.Button
+ webUIButton *widget.Button
+ updateButton *widget.Button
+ downloadButton *widget.Button
+
+ // Configuration
+ modelsPathEntry *widget.Entry
+ backendsPathEntry *widget.Entry
+ addressEntry *widget.Entry
+ logLevelSelect *widget.Select
+ startOnBootCheck *widget.Check
+
+ // Environment Variables
+ envVarsData []EnvVar
+ newEnvKeyEntry *widget.Entry
+ newEnvValueEntry *widget.Entry
+ updateEnvironmentDisplay func()
+
+ // Logs
+ logText *widget.Entry
+
+ // Progress
+ progressBar *widget.ProgressBar
+
+ // Update management
+ latestVersion string
+
+ // Reference to launcher
+ launcher *Launcher
+}
+
+// NewLauncherUI creates a new UI instance
+func NewLauncherUI() *LauncherUI {
+ return &LauncherUI{
+ statusLabel: widget.NewLabel("Initializing..."),
+ versionLabel: widget.NewLabel("Version: Unknown"),
+ startStopButton: widget.NewButton("Start LocalAI", nil),
+ webUIButton: widget.NewButton("Open WebUI", nil),
+ updateButton: widget.NewButton("Check for Updates", nil),
+ modelsPathEntry: widget.NewEntry(),
+ backendsPathEntry: widget.NewEntry(),
+ addressEntry: widget.NewEntry(),
+ logLevelSelect: widget.NewSelect([]string{"error", "warn", "info", "debug", "trace"}, nil),
+ startOnBootCheck: widget.NewCheck("Start LocalAI on system boot", nil),
+ logText: widget.NewMultiLineEntry(),
+ progressBar: widget.NewProgressBar(),
+ envVarsData: []EnvVar{}, // Initialize the environment variables slice
+ }
+}
+
+// CreateMainUI creates the main UI layout
+func (ui *LauncherUI) CreateMainUI(launcher *Launcher) *fyne.Container {
+ ui.launcher = launcher
+ ui.setupBindings()
+
+ // Main tab with status and controls
+ // Configuration is now the main content
+ configTab := ui.createConfigTab()
+
+ // Create a simple container instead of tabs since we only have settings
+ tabs := container.NewVBox(
+ widget.NewCard("LocalAI Launcher Settings", "", configTab),
+ )
+
+ return tabs
+}
+
+// createConfigTab creates the configuration tab
+func (ui *LauncherUI) createConfigTab() *fyne.Container {
+ // Path configuration
+ pathsCard := widget.NewCard("Paths", "", container.NewGridWithColumns(2,
+ widget.NewLabel("Models Path:"),
+ ui.modelsPathEntry,
+ widget.NewLabel("Backends Path:"),
+ ui.backendsPathEntry,
+ ))
+
+ // Server configuration
+ serverCard := widget.NewCard("Server", "", container.NewVBox(
+ container.NewGridWithColumns(2,
+ widget.NewLabel("Address:"),
+ ui.addressEntry,
+ widget.NewLabel("Log Level:"),
+ ui.logLevelSelect,
+ ),
+ ui.startOnBootCheck,
+ ))
+
+ // Save button
+ saveButton := widget.NewButton("Save Configuration", func() {
+ ui.saveConfiguration()
+ })
+
+ // Environment Variables section
+ envCard := ui.createEnvironmentSection()
+
+ return container.NewVBox(
+ pathsCard,
+ serverCard,
+ envCard,
+ saveButton,
+ )
+}
+
+// createEnvironmentSection creates the environment variables section for the config tab
+func (ui *LauncherUI) createEnvironmentSection() *fyne.Container {
+ // Initialize environment variables widgets
+ ui.newEnvKeyEntry = widget.NewEntry()
+ ui.newEnvKeyEntry.SetPlaceHolder("Environment Variable Name")
+
+ ui.newEnvValueEntry = widget.NewEntry()
+ ui.newEnvValueEntry.SetPlaceHolder("Environment Variable Value")
+
+ // Add button
+ addButton := widget.NewButton("Add Environment Variable", func() {
+ ui.addEnvironmentVariable()
+ })
+
+ // Environment variables list with delete buttons
+ ui.envVarsData = []EnvVar{}
+
+ // Create container for environment variables
+ envVarsContainer := container.NewVBox()
+
+ // Update function to rebuild the environment variables display
+ ui.updateEnvironmentDisplay = func() {
+ envVarsContainer.Objects = nil
+ for i, envVar := range ui.envVarsData {
+ index := i // Capture index for closure
+
+ // Create row with label and delete button
+ envLabel := widget.NewLabel(fmt.Sprintf("%s = %s", envVar.Key, envVar.Value))
+ deleteBtn := widget.NewButton("Delete", func() {
+ ui.confirmDeleteEnvironmentVariable(index)
+ })
+ deleteBtn.Importance = widget.DangerImportance
+
+ row := container.NewBorder(nil, nil, nil, deleteBtn, envLabel)
+ envVarsContainer.Add(row)
+ }
+ envVarsContainer.Refresh()
+ }
+
+ // Create a scrollable container for the environment variables
+ envScroll := container.NewScroll(envVarsContainer)
+ envScroll.SetMinSize(fyne.NewSize(400, 150))
+
+ // Input section for adding new environment variables
+ inputSection := container.NewVBox(
+ container.NewGridWithColumns(2,
+ ui.newEnvKeyEntry,
+ ui.newEnvValueEntry,
+ ),
+ addButton,
+ )
+
+ // Environment variables card
+ envCard := widget.NewCard("Environment Variables", "", container.NewVBox(
+ inputSection,
+ widget.NewSeparator(),
+ envScroll,
+ ))
+
+ return container.NewVBox(envCard)
+}
+
+// addEnvironmentVariable adds a new environment variable
+func (ui *LauncherUI) addEnvironmentVariable() {
+ key := ui.newEnvKeyEntry.Text
+ value := ui.newEnvValueEntry.Text
+
+ log.Printf("addEnvironmentVariable: attempting to add %s=%s", key, value)
+ log.Printf("addEnvironmentVariable: current ui.envVarsData has %d items: %v", len(ui.envVarsData), ui.envVarsData)
+
+ if key == "" {
+ log.Printf("addEnvironmentVariable: key is empty, showing error")
+ dialog.ShowError(fmt.Errorf("environment variable name cannot be empty"), ui.launcher.window)
+ return
+ }
+
+ // Check if key already exists
+ for _, envVar := range ui.envVarsData {
+ if envVar.Key == key {
+ log.Printf("addEnvironmentVariable: key %s already exists, showing error", key)
+ dialog.ShowError(fmt.Errorf("environment variable '%s' already exists", key), ui.launcher.window)
+ return
+ }
+ }
+
+ log.Printf("addEnvironmentVariable: adding new env var %s=%s", key, value)
+ ui.envVarsData = append(ui.envVarsData, EnvVar{Key: key, Value: value})
+ log.Printf("addEnvironmentVariable: after adding, ui.envVarsData has %d items: %v", len(ui.envVarsData), ui.envVarsData)
+
+ fyne.Do(func() {
+ if ui.updateEnvironmentDisplay != nil {
+ ui.updateEnvironmentDisplay()
+ }
+ // Clear input fields
+ ui.newEnvKeyEntry.SetText("")
+ ui.newEnvValueEntry.SetText("")
+ })
+
+ log.Printf("addEnvironmentVariable: calling saveEnvironmentVariables")
+ // Save to configuration
+ ui.saveEnvironmentVariables()
+}
+
+// removeEnvironmentVariable removes an environment variable by index
+func (ui *LauncherUI) removeEnvironmentVariable(index int) {
+ if index >= 0 && index < len(ui.envVarsData) {
+ ui.envVarsData = append(ui.envVarsData[:index], ui.envVarsData[index+1:]...)
+ fyne.Do(func() {
+ if ui.updateEnvironmentDisplay != nil {
+ ui.updateEnvironmentDisplay()
+ }
+ })
+ ui.saveEnvironmentVariables()
+ }
+}
+
+// saveEnvironmentVariables saves environment variables to the configuration
+func (ui *LauncherUI) saveEnvironmentVariables() {
+ if ui.launcher == nil {
+ log.Printf("saveEnvironmentVariables: launcher is nil")
+ return
+ }
+
+ config := ui.launcher.GetConfig()
+ log.Printf("saveEnvironmentVariables: before - Environment vars: %v", config.EnvironmentVars)
+
+ config.EnvironmentVars = make(map[string]string)
+ for _, envVar := range ui.envVarsData {
+ config.EnvironmentVars[envVar.Key] = envVar.Value
+ log.Printf("saveEnvironmentVariables: adding %s=%s", envVar.Key, envVar.Value)
+ }
+
+ log.Printf("saveEnvironmentVariables: after - Environment vars: %v", config.EnvironmentVars)
+ log.Printf("saveEnvironmentVariables: calling SetConfig with %d environment variables", len(config.EnvironmentVars))
+
+ err := ui.launcher.SetConfig(config)
+ if err != nil {
+ log.Printf("saveEnvironmentVariables: failed to save config: %v", err)
+ } else {
+ log.Printf("saveEnvironmentVariables: config saved successfully")
+ }
+}
+
+// confirmDeleteEnvironmentVariable shows confirmation dialog for deleting an environment variable
+func (ui *LauncherUI) confirmDeleteEnvironmentVariable(index int) {
+ if index >= 0 && index < len(ui.envVarsData) {
+ envVar := ui.envVarsData[index]
+ dialog.ShowConfirm("Remove Environment Variable",
+ fmt.Sprintf("Remove environment variable '%s'?", envVar.Key),
+ func(remove bool) {
+ if remove {
+ ui.removeEnvironmentVariable(index)
+ }
+ }, ui.launcher.window)
+ }
+}
+
+// setupBindings sets up event handlers for UI elements
+func (ui *LauncherUI) setupBindings() {
+ // Start/Stop button
+ ui.startStopButton.OnTapped = func() {
+ if ui.launcher.IsRunning() {
+ ui.stopLocalAI()
+ } else {
+ ui.startLocalAI()
+ }
+ }
+
+ // WebUI button
+ ui.webUIButton.OnTapped = func() {
+ ui.openWebUI()
+ }
+ ui.webUIButton.Disable() // Disabled until LocalAI is running
+
+ // Update button
+ ui.updateButton.OnTapped = func() {
+ ui.checkForUpdates()
+ }
+
+ // Log level selection
+ ui.logLevelSelect.OnChanged = func(selected string) {
+ if ui.launcher != nil {
+ config := ui.launcher.GetConfig()
+ config.LogLevel = selected
+ ui.launcher.SetConfig(config)
+ }
+ }
+}
+
+// startLocalAI starts the LocalAI service
+func (ui *LauncherUI) startLocalAI() {
+ fyne.Do(func() {
+ ui.startStopButton.Disable()
+ })
+ ui.UpdateStatus("Starting LocalAI...")
+
+ go func() {
+ err := ui.launcher.StartLocalAI()
+ if err != nil {
+ ui.UpdateStatus("Failed to start: " + err.Error())
+ fyne.DoAndWait(func() {
+ dialog.ShowError(err, ui.launcher.window)
+ })
+ } else {
+ fyne.Do(func() {
+ ui.startStopButton.SetText("Stop LocalAI")
+ ui.webUIButton.Enable()
+ })
+ }
+ fyne.Do(func() {
+ ui.startStopButton.Enable()
+ })
+ }()
+}
+
+// stopLocalAI stops the LocalAI service
+func (ui *LauncherUI) stopLocalAI() {
+ fyne.Do(func() {
+ ui.startStopButton.Disable()
+ })
+ ui.UpdateStatus("Stopping LocalAI...")
+
+ go func() {
+ err := ui.launcher.StopLocalAI()
+ if err != nil {
+ fyne.DoAndWait(func() {
+ dialog.ShowError(err, ui.launcher.window)
+ })
+ } else {
+ fyne.Do(func() {
+ ui.startStopButton.SetText("Start LocalAI")
+ ui.webUIButton.Disable()
+ })
+ }
+ fyne.Do(func() {
+ ui.startStopButton.Enable()
+ })
+ }()
+}
+
+// openWebUI opens the LocalAI WebUI in the default browser
+func (ui *LauncherUI) openWebUI() {
+ webURL := ui.launcher.GetWebUIURL()
+ parsedURL, err := url.Parse(webURL)
+ if err != nil {
+ dialog.ShowError(err, ui.launcher.window)
+ return
+ }
+
+ // Open URL in default browser
+ fyne.CurrentApp().OpenURL(parsedURL)
+}
+
+// saveConfiguration saves the current configuration
+func (ui *LauncherUI) saveConfiguration() {
+ log.Printf("saveConfiguration: starting to save configuration")
+
+ config := ui.launcher.GetConfig()
+ log.Printf("saveConfiguration: current config Environment vars: %v", config.EnvironmentVars)
+ log.Printf("saveConfiguration: ui.envVarsData has %d items: %v", len(ui.envVarsData), ui.envVarsData)
+
+ config.ModelsPath = ui.modelsPathEntry.Text
+ config.BackendsPath = ui.backendsPathEntry.Text
+ config.Address = ui.addressEntry.Text
+ config.LogLevel = ui.logLevelSelect.Selected
+ config.StartOnBoot = ui.startOnBootCheck.Checked
+
+ // Ensure environment variables are included in the configuration
+ config.EnvironmentVars = make(map[string]string)
+ for _, envVar := range ui.envVarsData {
+ config.EnvironmentVars[envVar.Key] = envVar.Value
+ log.Printf("saveConfiguration: adding env var %s=%s", envVar.Key, envVar.Value)
+ }
+
+ log.Printf("saveConfiguration: final config Environment vars: %v", config.EnvironmentVars)
+
+ err := ui.launcher.SetConfig(config)
+ if err != nil {
+ log.Printf("saveConfiguration: failed to save config: %v", err)
+ dialog.ShowError(err, ui.launcher.window)
+ } else {
+ log.Printf("saveConfiguration: config saved successfully")
+ dialog.ShowInformation("Configuration", "Configuration saved successfully", ui.launcher.window)
+ }
+}
+
+// checkForUpdates checks for available updates
+func (ui *LauncherUI) checkForUpdates() {
+ fyne.Do(func() {
+ ui.updateButton.Disable()
+ })
+ ui.UpdateStatus("Checking for updates...")
+
+ go func() {
+ available, version, err := ui.launcher.CheckForUpdates()
+ if err != nil {
+ ui.UpdateStatus("Failed to check updates: " + err.Error())
+ fyne.DoAndWait(func() {
+ dialog.ShowError(err, ui.launcher.window)
+ })
+ } else if available {
+ ui.latestVersion = version // Store the latest version
+ ui.UpdateStatus("Update available: " + version)
+ fyne.Do(func() {
+ if ui.downloadButton != nil {
+ ui.downloadButton.Enable()
+ }
+ })
+ ui.NotifyUpdateAvailable(version)
+ } else {
+ ui.UpdateStatus("No updates available")
+ fyne.DoAndWait(func() {
+ dialog.ShowInformation("Updates", "You are running the latest version", ui.launcher.window)
+ })
+ }
+ fyne.Do(func() {
+ ui.updateButton.Enable()
+ })
+ }()
+}
+
+// downloadUpdate downloads the latest update
+func (ui *LauncherUI) downloadUpdate() {
+ // Use stored version or check for updates
+ version := ui.latestVersion
+ if version == "" {
+ _, v, err := ui.launcher.CheckForUpdates()
+ if err != nil {
+ dialog.ShowError(err, ui.launcher.window)
+ return
+ }
+ version = v
+ ui.latestVersion = version
+ }
+
+ if version == "" {
+ dialog.ShowError(fmt.Errorf("no version information available"), ui.launcher.window)
+ return
+ }
+
+ // Disable buttons during download
+ if ui.downloadButton != nil {
+ fyne.Do(func() {
+ ui.downloadButton.Disable()
+ })
+ }
+
+ fyne.Do(func() {
+ ui.progressBar.Show()
+ ui.progressBar.SetValue(0)
+ })
+ ui.UpdateStatus("Downloading update " + version + "...")
+
+ go func() {
+ err := ui.launcher.DownloadUpdate(version, func(progress float64) {
+ // Update progress bar
+ fyne.Do(func() {
+ ui.progressBar.SetValue(progress)
+ })
+ // Update status with percentage
+ percentage := int(progress * 100)
+ ui.UpdateStatus(fmt.Sprintf("Downloading update %s... %d%%", version, percentage))
+ })
+
+ fyne.Do(func() {
+ ui.progressBar.Hide()
+ })
+
+ // Re-enable buttons after download
+ if ui.downloadButton != nil {
+ fyne.Do(func() {
+ ui.downloadButton.Enable()
+ })
+ }
+
+ if err != nil {
+ fyne.DoAndWait(func() {
+ ui.UpdateStatus("Failed to download update: " + err.Error())
+ dialog.ShowError(err, ui.launcher.window)
+ })
+ } else {
+ fyne.DoAndWait(func() {
+ ui.UpdateStatus("Update downloaded successfully")
+ dialog.ShowInformation("Update", "Update downloaded successfully. Please restart the launcher to use the new version.", ui.launcher.window)
+ })
+ }
+ }()
+}
+
+// UpdateStatus updates the status label
+func (ui *LauncherUI) UpdateStatus(status string) {
+ if ui.statusLabel != nil {
+ fyne.Do(func() {
+ ui.statusLabel.SetText(status)
+ })
+ }
+}
+
+// OnLogUpdate handles new log content
+func (ui *LauncherUI) OnLogUpdate(logLine string) {
+ if ui.logText != nil {
+ fyne.Do(func() {
+ currentText := ui.logText.Text
+ ui.logText.SetText(currentText + logLine)
+
+ // Auto-scroll to bottom (simplified)
+ ui.logText.CursorRow = len(ui.logText.Text)
+ })
+ }
+}
+
+// NotifyUpdateAvailable shows an update notification
+func (ui *LauncherUI) NotifyUpdateAvailable(version string) {
+ if ui.launcher != nil && ui.launcher.window != nil {
+ fyne.DoAndWait(func() {
+ dialog.ShowConfirm("Update Available",
+ "A new version ("+version+") is available. Would you like to download it?",
+ func(confirmed bool) {
+ if confirmed {
+ ui.downloadUpdate()
+ }
+ }, ui.launcher.window)
+ })
+ }
+}
+
+// LoadConfiguration loads the current configuration into UI elements
+func (ui *LauncherUI) LoadConfiguration() {
+ if ui.launcher == nil {
+ log.Printf("UI LoadConfiguration: launcher is nil")
+ return
+ }
+
+ config := ui.launcher.GetConfig()
+ log.Printf("UI LoadConfiguration: loading config - ModelsPath=%s, BackendsPath=%s, Address=%s, LogLevel=%s",
+ config.ModelsPath, config.BackendsPath, config.Address, config.LogLevel)
+ log.Printf("UI LoadConfiguration: Environment vars: %v", config.EnvironmentVars)
+
+ ui.modelsPathEntry.SetText(config.ModelsPath)
+ ui.backendsPathEntry.SetText(config.BackendsPath)
+ ui.addressEntry.SetText(config.Address)
+ ui.logLevelSelect.SetSelected(config.LogLevel)
+ ui.startOnBootCheck.SetChecked(config.StartOnBoot)
+
+ // Load environment variables
+ ui.envVarsData = []EnvVar{}
+ for key, value := range config.EnvironmentVars {
+ ui.envVarsData = append(ui.envVarsData, EnvVar{Key: key, Value: value})
+ }
+ if ui.updateEnvironmentDisplay != nil {
+ fyne.Do(func() {
+ ui.updateEnvironmentDisplay()
+ })
+ }
+
+ // Update version display
+ version := ui.launcher.GetCurrentVersion()
+ ui.versionLabel.SetText("Version: " + version)
+
+ log.Printf("UI LoadConfiguration: configuration loaded successfully")
+}
+
+// showDownloadProgress shows a progress window for downloading LocalAI
+func (ui *LauncherUI) showDownloadProgress(version, title string) {
+ fyne.DoAndWait(func() {
+ // Create progress window using the launcher's app
+ progressWindow := ui.launcher.app.NewWindow("Downloading LocalAI")
+ progressWindow.Resize(fyne.NewSize(400, 250))
+ progressWindow.CenterOnScreen()
+
+ // Progress bar
+ progressBar := widget.NewProgressBar()
+ progressBar.SetValue(0)
+
+ // Status label
+ statusLabel := widget.NewLabel("Preparing download...")
+
+ // Release notes button
+ releaseNotesButton := widget.NewButton("View Release Notes", func() {
+ releaseNotesURL, err := ui.launcher.githubReleaseNotesURL(version)
+ if err != nil {
+ log.Printf("Failed to parse URL: %v", err)
+ return
+ }
+
+ ui.launcher.app.OpenURL(releaseNotesURL)
+ })
+
+ // Progress container
+ progressContainer := container.NewVBox(
+ widget.NewLabel(title),
+ progressBar,
+ statusLabel,
+ widget.NewSeparator(),
+ releaseNotesButton,
+ )
+
+ progressWindow.SetContent(progressContainer)
+ progressWindow.Show()
+
+ // Start download in background
+ go func() {
+ err := ui.launcher.DownloadUpdate(version, func(progress float64) {
+ // Update progress bar
+ fyne.Do(func() {
+ progressBar.SetValue(progress)
+ percentage := int(progress * 100)
+ statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
+ })
+ })
+
+ // Handle completion
+ fyne.Do(func() {
+ if err != nil {
+ statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
+ // Show error dialog
+ dialog.ShowError(err, progressWindow)
+ } else {
+ statusLabel.SetText("Download completed successfully!")
+ progressBar.SetValue(1.0)
+
+ // Show success dialog
+ dialog.ShowConfirm("Installation Complete",
+ "LocalAI has been downloaded and installed successfully. You can now start LocalAI from the launcher.",
+ func(close bool) {
+ progressWindow.Close()
+ // Update status
+ ui.UpdateStatus("LocalAI installed successfully")
+ }, progressWindow)
+ }
+ })
+ }()
+ })
+}
+
+// UpdateRunningState updates UI based on LocalAI running state
+func (ui *LauncherUI) UpdateRunningState(isRunning bool) {
+ fyne.Do(func() {
+ if isRunning {
+ ui.startStopButton.SetText("Stop LocalAI")
+ ui.webUIButton.Enable()
+ } else {
+ ui.startStopButton.SetText("Start LocalAI")
+ ui.webUIButton.Disable()
+ }
+ })
+}
+
+// ShowWelcomeWindow displays the welcome window with helpful information
+func (ui *LauncherUI) ShowWelcomeWindow() {
+ if ui.launcher == nil || ui.launcher.window == nil {
+ log.Printf("Cannot show welcome window: launcher or window is nil")
+ return
+ }
+
+ fyne.DoAndWait(func() {
+ // Create welcome window
+ welcomeWindow := ui.launcher.app.NewWindow("Welcome to LocalAI Launcher")
+ welcomeWindow.Resize(fyne.NewSize(600, 500))
+ welcomeWindow.CenterOnScreen()
+ welcomeWindow.SetCloseIntercept(func() {
+ welcomeWindow.Close()
+ })
+
+ // Title
+ titleLabel := widget.NewLabel("Welcome to LocalAI Launcher!")
+ titleLabel.TextStyle = fyne.TextStyle{Bold: true}
+ titleLabel.Alignment = fyne.TextAlignCenter
+
+ // Welcome message
+ welcomeText := `LocalAI Launcher makes it easy to run LocalAI on your system.
+
+What you can do:
+• Start and stop LocalAI server
+• Configure models and backends paths
+• Set environment variables
+• Check for updates automatically
+• Access LocalAI WebUI when running
+
+Getting Started:
+1. Configure your models and backends paths
+2. Click "Start LocalAI" to begin
+3. Use "Open WebUI" to access the interface
+4. Check the system tray for quick access`
+
+ welcomeLabel := widget.NewLabel(welcomeText)
+ welcomeLabel.Wrapping = fyne.TextWrapWord
+
+ // Useful links section
+ linksTitle := widget.NewLabel("Useful Links:")
+ linksTitle.TextStyle = fyne.TextStyle{Bold: true}
+
+ // Create link buttons
+ docsButton := widget.NewButton("📚 Documentation", func() {
+ ui.openURL("https://localai.io/docs/")
+ })
+
+ githubButton := widget.NewButton("🐙 GitHub Repository", func() {
+ ui.openURL("https://github.com/mudler/LocalAI")
+ })
+
+ modelsButton := widget.NewButton("🤖 Model Gallery", func() {
+ ui.openURL("https://localai.io/models/")
+ })
+
+ communityButton := widget.NewButton("💬 Community", func() {
+ ui.openURL("https://discord.gg/XgwjKptP7Z")
+ })
+
+ // Checkbox to disable welcome window
+ dontShowAgainCheck := widget.NewCheck("Don't show this welcome window again", func(checked bool) {
+ if ui.launcher != nil {
+ config := ui.launcher.GetConfig()
+ v := !checked
+ config.ShowWelcome = &v
+ ui.launcher.SetConfig(config)
+ }
+ })
+
+ config := ui.launcher.GetConfig()
+ if config.ShowWelcome != nil {
+ dontShowAgainCheck.SetChecked(*config.ShowWelcome)
+ }
+
+ // Close button
+ closeButton := widget.NewButton("Get Started", func() {
+ welcomeWindow.Close()
+ })
+ closeButton.Importance = widget.HighImportance
+
+ // Layout
+ linksContainer := container.NewVBox(
+ linksTitle,
+ docsButton,
+ githubButton,
+ modelsButton,
+ communityButton,
+ )
+
+ content := container.NewVBox(
+ titleLabel,
+ widget.NewSeparator(),
+ welcomeLabel,
+ widget.NewSeparator(),
+ linksContainer,
+ widget.NewSeparator(),
+ dontShowAgainCheck,
+ widget.NewSeparator(),
+ closeButton,
+ )
+
+ welcomeWindow.SetContent(content)
+ welcomeWindow.Show()
+ })
+}
+
+// openURL opens a URL in the default browser
+func (ui *LauncherUI) openURL(urlString string) {
+ parsedURL, err := url.Parse(urlString)
+ if err != nil {
+ log.Printf("Failed to parse URL %s: %v", urlString, err)
+ return
+ }
+ fyne.CurrentApp().OpenURL(parsedURL)
+}
diff --git a/cmd/launcher/logo.png b/cmd/launcher/logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..94035377eae8c2f90843d261f5e285940167f693
Binary files /dev/null and b/cmd/launcher/logo.png differ
diff --git a/cmd/launcher/main.go b/cmd/launcher/main.go
new file mode 100644
index 0000000000000000000000000000000000000000..220fbb612d3698b6e10c667cae7fda653928faf3
--- /dev/null
+++ b/cmd/launcher/main.go
@@ -0,0 +1,72 @@
+package main
+
+import (
+ "log"
+
+ "fyne.io/fyne/v2"
+ "fyne.io/fyne/v2/app"
+ "fyne.io/fyne/v2/driver/desktop"
+ coreLauncher "github.com/mudler/LocalAI/cmd/launcher/internal"
+ "github.com/mudler/LocalAI/pkg/signals"
+)
+
+func main() {
+ // Create the application with unique ID
+ myApp := app.NewWithID("com.localai.launcher")
+ myApp.SetIcon(resourceIconPng)
+ myWindow := myApp.NewWindow("LocalAI Launcher")
+ myWindow.Resize(fyne.NewSize(800, 600))
+
+ // Create the launcher UI
+ ui := coreLauncher.NewLauncherUI()
+
+ // Initialize the launcher with UI context
+ launcher := coreLauncher.NewLauncher(ui, myWindow, myApp)
+
+ // Setup the UI
+ content := ui.CreateMainUI(launcher)
+ myWindow.SetContent(content)
+
+ // Setup window close behavior - minimize to tray instead of closing
+ myWindow.SetCloseIntercept(func() {
+ myWindow.Hide()
+ })
+
+ // Setup system tray using Fyne's built-in approach``
+ if desk, ok := myApp.(desktop.App); ok {
+ // Create a dynamic systray manager
+ systray := coreLauncher.NewSystrayManager(launcher, myWindow, desk, myApp, resourceIconPng)
+ launcher.SetSystray(systray)
+ }
+
+ // Setup signal handling for graceful shutdown
+ signals.RegisterGracefulTerminationHandler(func() {
+ // Perform cleanup
+ if err := launcher.Shutdown(); err != nil {
+ log.Printf("Error during shutdown: %v", err)
+ }
+ })
+
+ // Initialize the launcher state
+ go func() {
+ if err := launcher.Initialize(); err != nil {
+ log.Printf("Failed to initialize launcher: %v", err)
+ if launcher.GetUI() != nil {
+ launcher.GetUI().UpdateStatus("Failed to initialize: " + err.Error())
+ }
+ } else {
+ // Load configuration into UI
+ launcher.GetUI().LoadConfiguration()
+ launcher.GetUI().UpdateStatus("Ready")
+
+ // Show welcome window if configured to do so
+ config := launcher.GetConfig()
+ if *config.ShowWelcome {
+ launcher.GetUI().ShowWelcomeWindow()
+ }
+ }
+ }()
+
+ // Run the application in background (window only shown when "Settings" is clicked)
+ myApp.Run()
+}
diff --git a/cmd/local-ai/main.go b/cmd/local-ai/main.go
new file mode 100644
index 0000000000000000000000000000000000000000..9d5cb3fba05731ae76f1d7d88a335a2595bbc808
--- /dev/null
+++ b/cmd/local-ai/main.go
@@ -0,0 +1,82 @@
+package main
+
+import (
+ "os"
+ "path/filepath"
+
+ "github.com/alecthomas/kong"
+ "github.com/joho/godotenv"
+ "github.com/mudler/LocalAI/core/cli"
+ "github.com/mudler/LocalAI/internal"
+ "github.com/mudler/xlog"
+
+ _ "github.com/mudler/LocalAI/swagger"
+)
+
+func main() {
+ var err error
+
+ // Initialize xlog at a level of INFO, we will set the desired level after we parse the CLI options
+ xlog.SetLogger(xlog.NewLogger(xlog.LogLevel("info"), "text"))
+
+ // handle loading environment variables from .env files
+ envFiles := []string{".env", "localai.env"}
+ homeDir, err := os.UserHomeDir()
+ if err == nil {
+ envFiles = append(envFiles, filepath.Join(homeDir, "localai.env"), filepath.Join(homeDir, ".config/localai.env"))
+ }
+ envFiles = append(envFiles, "/etc/localai.env")
+
+ for _, envFile := range envFiles {
+ if _, err := os.Stat(envFile); err == nil {
+ xlog.Debug("env file found, loading environment variables from file", "envFile", envFile)
+ err = godotenv.Load(envFile)
+ if err != nil {
+ xlog.Error("failed to load environment variables from file", "error", err, "envFile", envFile)
+ continue
+ }
+ }
+ }
+
+ // Actually parse the CLI options
+ ctx := kong.Parse(&cli.CLI,
+ kong.Description(
+ ` LocalAI is a drop-in replacement OpenAI API for running LLM, GPT and genAI models locally on CPU, GPUs with consumer grade hardware.
+
+For a list of all available models run local-ai models list
+
+Copyright: Ettore Di Giacinto
+
+Version: ${version}
+`,
+ ),
+ kong.UsageOnError(),
+ kong.Vars{
+ "basepath": kong.ExpandPath("."),
+ "galleries": `[{"name":"localai", "url":"github:mudler/LocalAI/gallery/index.yaml@master"}]`,
+ "backends": `[{"name":"localai", "url":"github:mudler/LocalAI/backend/index.yaml@master"}]`,
+ "version": internal.PrintableVersion(),
+ },
+ )
+
+ // Configure the logging level before we run the application
+ // This is here to preserve the existing --debug flag functionality
+ logLevel := "info"
+ if cli.CLI.Debug && cli.CLI.LogLevel == nil {
+ logLevel = "debug"
+ cli.CLI.LogLevel = &logLevel
+ }
+
+ if cli.CLI.LogLevel == nil {
+ cli.CLI.LogLevel = &logLevel
+ }
+
+ // Set xlog logger with the desired level and text format
+ xlog.SetLogger(xlog.NewLogger(xlog.LogLevel(*cli.CLI.LogLevel), *cli.CLI.LogFormat))
+
+ // Run the thing!
+ err = ctx.Run(&cli.CLI.Context)
+ if err != nil {
+ xlog.Fatal("Error running the application", "error", err)
+ }
+}
diff --git a/configuration/.keep b/configuration/.keep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/core/application/agent_jobs.go b/core/application/agent_jobs.go
new file mode 100644
index 0000000000000000000000000000000000000000..0ed5d928331a01269aac2887d745fd16b6e0a1ed
--- /dev/null
+++ b/core/application/agent_jobs.go
@@ -0,0 +1,42 @@
+package application
+
+import (
+ "time"
+
+ "github.com/mudler/LocalAI/core/services"
+ "github.com/mudler/xlog"
+)
+
+// RestartAgentJobService restarts the agent job service with current ApplicationConfig settings
+func (a *Application) RestartAgentJobService() error {
+ a.agentJobMutex.Lock()
+ defer a.agentJobMutex.Unlock()
+
+ // Stop existing service if running
+ if a.agentJobService != nil {
+ if err := a.agentJobService.Stop(); err != nil {
+ xlog.Warn("Error stopping agent job service", "error", err)
+ }
+ // Wait a bit for shutdown to complete
+ time.Sleep(200 * time.Millisecond)
+ }
+
+ // Create new service instance
+ agentJobService := services.NewAgentJobService(
+ a.ApplicationConfig(),
+ a.ModelLoader(),
+ a.ModelConfigLoader(),
+ a.TemplatesEvaluator(),
+ )
+
+ // Start the service
+ err := agentJobService.Start(a.ApplicationConfig().Context)
+ if err != nil {
+ xlog.Error("Failed to start agent job service", "error", err)
+ return err
+ }
+
+ a.agentJobService = agentJobService
+ xlog.Info("Agent job service restarted")
+ return nil
+}
diff --git a/core/application/application.go b/core/application/application.go
new file mode 100644
index 0000000000000000000000000000000000000000..38a9d2cf9dfbd1cfbda8f212b023e1c54f62fb27
--- /dev/null
+++ b/core/application/application.go
@@ -0,0 +1,92 @@
+package application
+
+import (
+ "context"
+ "sync"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/services"
+ "github.com/mudler/LocalAI/core/templates"
+ "github.com/mudler/LocalAI/pkg/model"
+)
+
+type Application struct {
+ backendLoader *config.ModelConfigLoader
+ modelLoader *model.ModelLoader
+ applicationConfig *config.ApplicationConfig
+ startupConfig *config.ApplicationConfig // Stores original config from env vars (before file loading)
+ templatesEvaluator *templates.Evaluator
+ galleryService *services.GalleryService
+ agentJobService *services.AgentJobService
+ watchdogMutex sync.Mutex
+ watchdogStop chan bool
+ p2pMutex sync.Mutex
+ p2pCtx context.Context
+ p2pCancel context.CancelFunc
+ agentJobMutex sync.Mutex
+}
+
+func newApplication(appConfig *config.ApplicationConfig) *Application {
+ return &Application{
+ backendLoader: config.NewModelConfigLoader(appConfig.SystemState.Model.ModelsPath),
+ modelLoader: model.NewModelLoader(appConfig.SystemState),
+ applicationConfig: appConfig,
+ templatesEvaluator: templates.NewEvaluator(appConfig.SystemState.Model.ModelsPath),
+ }
+}
+
+func (a *Application) ModelConfigLoader() *config.ModelConfigLoader {
+ return a.backendLoader
+}
+
+func (a *Application) ModelLoader() *model.ModelLoader {
+ return a.modelLoader
+}
+
+func (a *Application) ApplicationConfig() *config.ApplicationConfig {
+ return a.applicationConfig
+}
+
+func (a *Application) TemplatesEvaluator() *templates.Evaluator {
+ return a.templatesEvaluator
+}
+
+func (a *Application) GalleryService() *services.GalleryService {
+ return a.galleryService
+}
+
+func (a *Application) AgentJobService() *services.AgentJobService {
+ return a.agentJobService
+}
+
+// StartupConfig returns the original startup configuration (from env vars, before file loading)
+func (a *Application) StartupConfig() *config.ApplicationConfig {
+ return a.startupConfig
+}
+
+func (a *Application) start() error {
+ galleryService := services.NewGalleryService(a.ApplicationConfig(), a.ModelLoader())
+ err := galleryService.Start(a.ApplicationConfig().Context, a.ModelConfigLoader(), a.ApplicationConfig().SystemState)
+ if err != nil {
+ return err
+ }
+
+ a.galleryService = galleryService
+
+ // Initialize agent job service
+ agentJobService := services.NewAgentJobService(
+ a.ApplicationConfig(),
+ a.ModelLoader(),
+ a.ModelConfigLoader(),
+ a.TemplatesEvaluator(),
+ )
+
+ err = agentJobService.Start(a.ApplicationConfig().Context)
+ if err != nil {
+ return err
+ }
+
+ a.agentJobService = agentJobService
+
+ return nil
+}
diff --git a/core/application/config_file_watcher.go b/core/application/config_file_watcher.go
new file mode 100644
index 0000000000000000000000000000000000000000..90b78485d8de0196465c9d2a26ec8a3eddb370f9
--- /dev/null
+++ b/core/application/config_file_watcher.go
@@ -0,0 +1,363 @@
+package application
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "path"
+ "path/filepath"
+ "time"
+
+ "dario.cat/mergo"
+ "github.com/fsnotify/fsnotify"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/xlog"
+)
+
+type fileHandler func(fileContent []byte, appConfig *config.ApplicationConfig) error
+
+type configFileHandler struct {
+ handlers map[string]fileHandler
+
+ watcher *fsnotify.Watcher
+
+ appConfig *config.ApplicationConfig
+}
+
+// TODO: This should be a singleton eventually so other parts of the code can register config file handlers,
+// then we can export it to other packages
+func newConfigFileHandler(appConfig *config.ApplicationConfig) configFileHandler {
+ c := configFileHandler{
+ handlers: make(map[string]fileHandler),
+ appConfig: appConfig,
+ }
+ err := c.Register("api_keys.json", readApiKeysJson(*appConfig), true)
+ if err != nil {
+ xlog.Error("unable to register config file handler", "error", err, "file", "api_keys.json")
+ }
+ err = c.Register("external_backends.json", readExternalBackendsJson(*appConfig), true)
+ if err != nil {
+ xlog.Error("unable to register config file handler", "error", err, "file", "external_backends.json")
+ }
+ err = c.Register("runtime_settings.json", readRuntimeSettingsJson(*appConfig), true)
+ if err != nil {
+ xlog.Error("unable to register config file handler", "error", err, "file", "runtime_settings.json")
+ }
+ // Note: agent_tasks.json and agent_jobs.json are handled by AgentJobService directly
+ // The service watches and reloads these files internally
+ return c
+}
+
+func (c *configFileHandler) Register(filename string, handler fileHandler, runNow bool) error {
+ _, ok := c.handlers[filename]
+ if ok {
+ return fmt.Errorf("handler already registered for file %s", filename)
+ }
+ c.handlers[filename] = handler
+ if runNow {
+ c.callHandler(filename, handler)
+ }
+ return nil
+}
+
+func (c *configFileHandler) callHandler(filename string, handler fileHandler) {
+ rootedFilePath := filepath.Join(c.appConfig.DynamicConfigsDir, filepath.Clean(filename))
+ xlog.Debug("reading file for dynamic config update", "filename", rootedFilePath)
+ fileContent, err := os.ReadFile(rootedFilePath)
+ if err != nil && !os.IsNotExist(err) {
+ xlog.Error("could not read file", "error", err, "filename", rootedFilePath)
+ }
+
+ if err = handler(fileContent, c.appConfig); err != nil {
+ xlog.Error("WatchConfigDirectory goroutine failed to update options", "error", err)
+ }
+}
+
+func (c *configFileHandler) Watch() error {
+ configWatcher, err := fsnotify.NewWatcher()
+ c.watcher = configWatcher
+ if err != nil {
+ return err
+ }
+
+ if c.appConfig.DynamicConfigsDirPollInterval > 0 {
+ xlog.Debug("Poll interval set, falling back to polling for configuration changes")
+ ticker := time.NewTicker(c.appConfig.DynamicConfigsDirPollInterval)
+ go func() {
+ for {
+ <-ticker.C
+ for file, handler := range c.handlers {
+ xlog.Debug("polling config file", "file", file)
+ c.callHandler(file, handler)
+ }
+ }
+ }()
+ }
+
+ // Start listening for events.
+ go func() {
+ for {
+ select {
+ case event, ok := <-c.watcher.Events:
+ if !ok {
+ return
+ }
+ if event.Has(fsnotify.Write | fsnotify.Create | fsnotify.Remove) {
+ handler, ok := c.handlers[path.Base(event.Name)]
+ if !ok {
+ continue
+ }
+
+ c.callHandler(filepath.Base(event.Name), handler)
+ }
+ case err, ok := <-c.watcher.Errors:
+ xlog.Error("config watcher error received", "error", err)
+ if !ok {
+ return
+ }
+ }
+ }
+ }()
+
+ // Add a path.
+ err = c.watcher.Add(c.appConfig.DynamicConfigsDir)
+ if err != nil {
+ return fmt.Errorf("unable to create a watcher on the configuration directory: %+v", err)
+ }
+
+ return nil
+}
+
+// TODO: When we institute graceful shutdown, this should be called
+func (c *configFileHandler) Stop() error {
+ return c.watcher.Close()
+}
+
+func readApiKeysJson(startupAppConfig config.ApplicationConfig) fileHandler {
+ handler := func(fileContent []byte, appConfig *config.ApplicationConfig) error {
+ xlog.Debug("processing api keys runtime update", "numKeys", len(startupAppConfig.ApiKeys))
+
+ if len(fileContent) > 0 {
+ // Parse JSON content from the file
+ var fileKeys []string
+ err := json.Unmarshal(fileContent, &fileKeys)
+ if err != nil {
+ return err
+ }
+
+ xlog.Debug("discovered API keys from api keys dynamic config file", "numKeys", len(fileKeys))
+
+ appConfig.ApiKeys = append(startupAppConfig.ApiKeys, fileKeys...)
+ } else {
+ xlog.Debug("no API keys discovered from dynamic config file")
+ appConfig.ApiKeys = startupAppConfig.ApiKeys
+ }
+ xlog.Debug("total api keys after processing", "numKeys", len(appConfig.ApiKeys))
+ return nil
+ }
+
+ return handler
+}
+
+func readExternalBackendsJson(startupAppConfig config.ApplicationConfig) fileHandler {
+ handler := func(fileContent []byte, appConfig *config.ApplicationConfig) error {
+ xlog.Debug("processing external_backends.json")
+
+ if len(fileContent) > 0 {
+ // Parse JSON content from the file
+ var fileBackends map[string]string
+ err := json.Unmarshal(fileContent, &fileBackends)
+ if err != nil {
+ return err
+ }
+ appConfig.ExternalGRPCBackends = startupAppConfig.ExternalGRPCBackends
+ err = mergo.Merge(&appConfig.ExternalGRPCBackends, &fileBackends)
+ if err != nil {
+ return err
+ }
+ } else {
+ appConfig.ExternalGRPCBackends = startupAppConfig.ExternalGRPCBackends
+ }
+ xlog.Debug("external backends loaded from external_backends.json")
+ return nil
+ }
+ return handler
+}
+
+func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHandler {
+ handler := func(fileContent []byte, appConfig *config.ApplicationConfig) error {
+ xlog.Debug("processing runtime_settings.json")
+
+ // Determine if settings came from env vars by comparing with startup config
+ // startupAppConfig contains the original values set from env vars at startup.
+ // If current values match startup values, they came from env vars (or defaults).
+ // We apply file settings only if current values match startup values (meaning not from env vars).
+ envWatchdogIdle := appConfig.WatchDogIdle == startupAppConfig.WatchDogIdle
+ envWatchdogBusy := appConfig.WatchDogBusy == startupAppConfig.WatchDogBusy
+ envWatchdogIdleTimeout := appConfig.WatchDogIdleTimeout == startupAppConfig.WatchDogIdleTimeout
+ envWatchdogBusyTimeout := appConfig.WatchDogBusyTimeout == startupAppConfig.WatchDogBusyTimeout
+ envSingleBackend := appConfig.SingleBackend == startupAppConfig.SingleBackend
+ envMaxActiveBackends := appConfig.MaxActiveBackends == startupAppConfig.MaxActiveBackends
+ envParallelRequests := appConfig.ParallelBackendRequests == startupAppConfig.ParallelBackendRequests
+ envMemoryReclaimerEnabled := appConfig.MemoryReclaimerEnabled == startupAppConfig.MemoryReclaimerEnabled
+ envMemoryReclaimerThreshold := appConfig.MemoryReclaimerThreshold == startupAppConfig.MemoryReclaimerThreshold
+ envThreads := appConfig.Threads == startupAppConfig.Threads
+ envContextSize := appConfig.ContextSize == startupAppConfig.ContextSize
+ envF16 := appConfig.F16 == startupAppConfig.F16
+ envDebug := appConfig.Debug == startupAppConfig.Debug
+ envCORS := appConfig.CORS == startupAppConfig.CORS
+ envCSRF := appConfig.CSRF == startupAppConfig.CSRF
+ envCORSAllowOrigins := appConfig.CORSAllowOrigins == startupAppConfig.CORSAllowOrigins
+ envP2PToken := appConfig.P2PToken == startupAppConfig.P2PToken
+ envP2PNetworkID := appConfig.P2PNetworkID == startupAppConfig.P2PNetworkID
+ envFederated := appConfig.Federated == startupAppConfig.Federated
+ envAutoloadGalleries := appConfig.AutoloadGalleries == startupAppConfig.AutoloadGalleries
+ envAutoloadBackendGalleries := appConfig.AutoloadBackendGalleries == startupAppConfig.AutoloadBackendGalleries
+ envAgentJobRetentionDays := appConfig.AgentJobRetentionDays == startupAppConfig.AgentJobRetentionDays
+ envForceEvictionWhenBusy := appConfig.ForceEvictionWhenBusy == startupAppConfig.ForceEvictionWhenBusy
+ envLRUEvictionMaxRetries := appConfig.LRUEvictionMaxRetries == startupAppConfig.LRUEvictionMaxRetries
+ envLRUEvictionRetryInterval := appConfig.LRUEvictionRetryInterval == startupAppConfig.LRUEvictionRetryInterval
+
+ if len(fileContent) > 0 {
+ var settings config.RuntimeSettings
+ err := json.Unmarshal(fileContent, &settings)
+ if err != nil {
+ return err
+ }
+
+ // Apply file settings only if they don't match startup values (i.e., not from env vars)
+ if settings.WatchdogIdleEnabled != nil && !envWatchdogIdle {
+ appConfig.WatchDogIdle = *settings.WatchdogIdleEnabled
+ if appConfig.WatchDogIdle {
+ appConfig.WatchDog = true
+ }
+ }
+ if settings.WatchdogBusyEnabled != nil && !envWatchdogBusy {
+ appConfig.WatchDogBusy = *settings.WatchdogBusyEnabled
+ if appConfig.WatchDogBusy {
+ appConfig.WatchDog = true
+ }
+ }
+ if settings.WatchdogIdleTimeout != nil && !envWatchdogIdleTimeout {
+ dur, err := time.ParseDuration(*settings.WatchdogIdleTimeout)
+ if err == nil {
+ appConfig.WatchDogIdleTimeout = dur
+ } else {
+ xlog.Warn("invalid watchdog idle timeout in runtime_settings.json", "error", err, "timeout", *settings.WatchdogIdleTimeout)
+ }
+ }
+ if settings.WatchdogBusyTimeout != nil && !envWatchdogBusyTimeout {
+ dur, err := time.ParseDuration(*settings.WatchdogBusyTimeout)
+ if err == nil {
+ appConfig.WatchDogBusyTimeout = dur
+ } else {
+ xlog.Warn("invalid watchdog busy timeout in runtime_settings.json", "error", err, "timeout", *settings.WatchdogBusyTimeout)
+ }
+ }
+ // Handle MaxActiveBackends (new) and SingleBackend (deprecated)
+ if settings.MaxActiveBackends != nil && !envMaxActiveBackends {
+ appConfig.MaxActiveBackends = *settings.MaxActiveBackends
+ // For backward compatibility, also set SingleBackend if MaxActiveBackends == 1
+ appConfig.SingleBackend = (*settings.MaxActiveBackends == 1)
+ } else if settings.SingleBackend != nil && !envSingleBackend {
+ // Legacy: SingleBackend maps to MaxActiveBackends = 1
+ appConfig.SingleBackend = *settings.SingleBackend
+ if *settings.SingleBackend {
+ appConfig.MaxActiveBackends = 1
+ } else {
+ appConfig.MaxActiveBackends = 0
+ }
+ }
+ if settings.ParallelBackendRequests != nil && !envParallelRequests {
+ appConfig.ParallelBackendRequests = *settings.ParallelBackendRequests
+ }
+ if settings.MemoryReclaimerEnabled != nil && !envMemoryReclaimerEnabled {
+ appConfig.MemoryReclaimerEnabled = *settings.MemoryReclaimerEnabled
+ if appConfig.MemoryReclaimerEnabled {
+ appConfig.WatchDog = true // Memory reclaimer requires watchdog
+ }
+ }
+ if settings.MemoryReclaimerThreshold != nil && !envMemoryReclaimerThreshold {
+ appConfig.MemoryReclaimerThreshold = *settings.MemoryReclaimerThreshold
+ }
+ if settings.ForceEvictionWhenBusy != nil && !envForceEvictionWhenBusy {
+ appConfig.ForceEvictionWhenBusy = *settings.ForceEvictionWhenBusy
+ }
+ if settings.LRUEvictionMaxRetries != nil && !envLRUEvictionMaxRetries {
+ appConfig.LRUEvictionMaxRetries = *settings.LRUEvictionMaxRetries
+ }
+ if settings.LRUEvictionRetryInterval != nil && !envLRUEvictionRetryInterval {
+ dur, err := time.ParseDuration(*settings.LRUEvictionRetryInterval)
+ if err == nil {
+ appConfig.LRUEvictionRetryInterval = dur
+ } else {
+ xlog.Warn("invalid LRU eviction retry interval in runtime_settings.json", "error", err, "interval", *settings.LRUEvictionRetryInterval)
+ }
+ }
+ if settings.Threads != nil && !envThreads {
+ appConfig.Threads = *settings.Threads
+ }
+ if settings.ContextSize != nil && !envContextSize {
+ appConfig.ContextSize = *settings.ContextSize
+ }
+ if settings.F16 != nil && !envF16 {
+ appConfig.F16 = *settings.F16
+ }
+ if settings.Debug != nil && !envDebug {
+ appConfig.Debug = *settings.Debug
+ }
+ if settings.CORS != nil && !envCORS {
+ appConfig.CORS = *settings.CORS
+ }
+ if settings.CSRF != nil && !envCSRF {
+ appConfig.CSRF = *settings.CSRF
+ }
+ if settings.CORSAllowOrigins != nil && !envCORSAllowOrigins {
+ appConfig.CORSAllowOrigins = *settings.CORSAllowOrigins
+ }
+ if settings.P2PToken != nil && !envP2PToken {
+ appConfig.P2PToken = *settings.P2PToken
+ }
+ if settings.P2PNetworkID != nil && !envP2PNetworkID {
+ appConfig.P2PNetworkID = *settings.P2PNetworkID
+ }
+ if settings.Federated != nil && !envFederated {
+ appConfig.Federated = *settings.Federated
+ }
+ if settings.Galleries != nil {
+ appConfig.Galleries = *settings.Galleries
+ }
+ if settings.BackendGalleries != nil {
+ appConfig.BackendGalleries = *settings.BackendGalleries
+ }
+ if settings.AutoloadGalleries != nil && !envAutoloadGalleries {
+ appConfig.AutoloadGalleries = *settings.AutoloadGalleries
+ }
+ if settings.AutoloadBackendGalleries != nil && !envAutoloadBackendGalleries {
+ appConfig.AutoloadBackendGalleries = *settings.AutoloadBackendGalleries
+ }
+ if settings.ApiKeys != nil {
+ // API keys from env vars (startup) should be kept, runtime settings keys replace all runtime keys
+ // If runtime_settings.json specifies ApiKeys (even if empty), it replaces all runtime keys
+ // Start with env keys, then add runtime_settings.json keys (which may be empty to clear them)
+ envKeys := startupAppConfig.ApiKeys
+ runtimeKeys := *settings.ApiKeys
+ // Replace all runtime keys with what's in runtime_settings.json
+ appConfig.ApiKeys = append(envKeys, runtimeKeys...)
+ }
+ if settings.AgentJobRetentionDays != nil && !envAgentJobRetentionDays {
+ appConfig.AgentJobRetentionDays = *settings.AgentJobRetentionDays
+ }
+
+ // If watchdog is enabled via file but not via env, ensure WatchDog flag is set
+ if !envWatchdogIdle && !envWatchdogBusy {
+ if settings.WatchdogEnabled != nil && *settings.WatchdogEnabled {
+ appConfig.WatchDog = true
+ }
+ }
+ }
+ xlog.Debug("runtime settings loaded from runtime_settings.json")
+ return nil
+ }
+ return handler
+}
diff --git a/core/application/p2p.go b/core/application/p2p.go
new file mode 100644
index 0000000000000000000000000000000000000000..99527e841260a71114414d12587a24de325f7775
--- /dev/null
+++ b/core/application/p2p.go
@@ -0,0 +1,239 @@
+package application
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "slices"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/p2p"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/core/services"
+
+ "github.com/mudler/edgevpn/pkg/node"
+ "github.com/mudler/xlog"
+)
+
+func (a *Application) StopP2P() error {
+ if a.p2pCancel != nil {
+ a.p2pCancel()
+ a.p2pCancel = nil
+ a.p2pCtx = nil
+ // Wait a bit for shutdown to complete
+ time.Sleep(200 * time.Millisecond)
+ }
+ return nil
+}
+
+func (a *Application) StartP2P() error {
+ // we need a p2p token
+ if a.applicationConfig.P2PToken == "" {
+ return fmt.Errorf("P2P token is not set")
+ }
+
+ networkID := a.applicationConfig.P2PNetworkID
+
+ ctx, cancel := context.WithCancel(a.ApplicationConfig().Context)
+ a.p2pCtx = ctx
+ a.p2pCancel = cancel
+
+ var n *node.Node
+ // Here we are avoiding creating multiple nodes:
+ // - if the federated mode is enabled, we create a federated node and expose a service
+ // - exposing a service creates a node with specific options, and we don't want to create another node
+
+ // If the federated mode is enabled, we expose a service to the local instance running
+ // at r.Address
+ if a.applicationConfig.Federated {
+ _, port, err := net.SplitHostPort(a.applicationConfig.APIAddress)
+ if err != nil {
+ return err
+ }
+
+ // Here a new node is created and started
+ // and a service is exposed by the node
+ node, err := p2p.ExposeService(ctx, "localhost", port, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.FederatedID))
+ if err != nil {
+ return err
+ }
+
+ if err := p2p.ServiceDiscoverer(ctx, node, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.FederatedID), nil, false); err != nil {
+ return err
+ }
+
+ n = node
+ // start node sync in the background
+ if err := a.p2pSync(ctx, node); err != nil {
+ return err
+ }
+ }
+
+ // If a node wasn't created previously, create it
+ if n == nil {
+ node, err := p2p.NewNode(a.applicationConfig.P2PToken)
+ if err != nil {
+ return err
+ }
+ err = node.Start(ctx)
+ if err != nil {
+ return fmt.Errorf("starting new node: %w", err)
+ }
+ n = node
+ }
+
+ // Attach a ServiceDiscoverer to the p2p node
+ xlog.Info("Starting P2P server discovery...")
+ if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.WorkerID), func(serviceID string, node schema.NodeData) {
+ var tunnelAddresses []string
+ for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.WorkerID)) {
+ if v.IsOnline() {
+ tunnelAddresses = append(tunnelAddresses, v.TunnelAddress)
+ } else {
+ xlog.Info("Node is offline", "node", v.ID)
+ }
+ }
+ if a.applicationConfig.TunnelCallback != nil {
+ a.applicationConfig.TunnelCallback(tunnelAddresses)
+ }
+ }, true); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// RestartP2P restarts the P2P stack with current ApplicationConfig settings
+// Note: This method signals that P2P should be restarted, but the actual restart
+// is handled by the caller to avoid import cycles
+func (a *Application) RestartP2P() error {
+ a.p2pMutex.Lock()
+ defer a.p2pMutex.Unlock()
+
+ // Stop existing P2P if running
+ if a.p2pCancel != nil {
+ a.p2pCancel()
+ a.p2pCancel = nil
+ a.p2pCtx = nil
+ // Wait a bit for shutdown to complete
+ time.Sleep(200 * time.Millisecond)
+ }
+
+ appConfig := a.ApplicationConfig()
+
+ // Start P2P if token is set
+ if appConfig.P2PToken == "" {
+ return fmt.Errorf("P2P token is not set")
+ }
+
+ // Create new context for P2P
+ ctx, cancel := context.WithCancel(appConfig.Context)
+ a.p2pCtx = ctx
+ a.p2pCancel = cancel
+
+ // Get API address from config
+ address := appConfig.APIAddress
+ if address == "" {
+ address = "127.0.0.1:8080" // default
+ }
+
+ // Start P2P stack in a goroutine
+ go func() {
+ if err := a.StartP2P(); err != nil {
+ xlog.Error("Failed to start P2P stack", "error", err)
+ cancel() // Cancel context on error
+ }
+ }()
+ xlog.Info("P2P stack restarted with new settings")
+
+ return nil
+}
+
+func syncState(ctx context.Context, n *node.Node, app *Application) error {
+ xlog.Debug("[p2p-sync] Syncing state")
+
+ whatWeHave := []string{}
+ for _, model := range app.ModelConfigLoader().GetAllModelsConfigs() {
+ whatWeHave = append(whatWeHave, model.Name)
+ }
+
+ ledger, _ := n.Ledger()
+ currentData := ledger.CurrentData()
+ xlog.Debug("[p2p-sync] Current data", "data", currentData)
+ data, exists := ledger.GetKey("shared_state", "models")
+ if !exists {
+ ledger.AnnounceUpdate(ctx, time.Minute, "shared_state", "models", whatWeHave)
+ xlog.Debug("No models found in the ledger, announced our models", "models", whatWeHave)
+ }
+
+ models := []string{}
+ if err := data.Unmarshal(&models); err != nil {
+ xlog.Warn("error unmarshalling models", "error", err)
+ return nil
+ }
+
+ xlog.Debug("[p2p-sync] Models comparison", "ourModels", whatWeHave, "ledgerModels", models)
+
+ // Sync with our state
+ whatIsNotThere := []string{}
+ for _, model := range whatWeHave {
+ if !slices.Contains(models, model) {
+ whatIsNotThere = append(whatIsNotThere, model)
+ }
+ }
+ if len(whatIsNotThere) > 0 {
+ xlog.Debug("[p2p-sync] Announcing our models", "models", append(models, whatIsNotThere...))
+ ledger.AnnounceUpdate(
+ ctx,
+ 1*time.Minute,
+ "shared_state",
+ "models",
+ append(models, whatIsNotThere...),
+ )
+ }
+
+ // Check if we have a model that is not in our state, otherwise install it
+ for _, model := range models {
+ if slices.Contains(whatWeHave, model) {
+ xlog.Debug("[p2p-sync] Model is already present in this instance", "model", model)
+ continue
+ }
+
+ // we install model
+ xlog.Info("[p2p-sync] Installing model which is not present in this instance", "model", model)
+
+ uuid, err := uuid.NewUUID()
+ if err != nil {
+ xlog.Error("error generating UUID", "error", err)
+ continue
+ }
+
+ app.GalleryService().ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
+ ID: uuid.String(),
+ GalleryElementName: model,
+ Galleries: app.ApplicationConfig().Galleries,
+ BackendGalleries: app.ApplicationConfig().BackendGalleries,
+ }
+ }
+
+ return nil
+}
+
+func (a *Application) p2pSync(ctx context.Context, n *node.Node) error {
+ go func() {
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-time.After(1 * time.Minute):
+ if err := syncState(ctx, n, a); err != nil {
+ xlog.Error("error syncing state", "error", err)
+ }
+ }
+
+ }
+ }()
+ return nil
+}
diff --git a/core/application/startup.go b/core/application/startup.go
new file mode 100644
index 0000000000000000000000000000000000000000..68e24f196fba5d496ed178ccd0e4ede34d86ee45
--- /dev/null
+++ b/core/application/startup.go
@@ -0,0 +1,376 @@
+package application
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "time"
+
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/services"
+ coreStartup "github.com/mudler/LocalAI/core/startup"
+ "github.com/mudler/LocalAI/internal"
+
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/xsysinfo"
+ "github.com/mudler/xlog"
+)
+
+func New(opts ...config.AppOption) (*Application, error) {
+ options := config.NewApplicationConfig(opts...)
+
+ // Store a copy of the startup config (from env vars, before file loading)
+ // This is used to determine if settings came from env vars vs file
+ startupConfigCopy := *options
+ application := newApplication(options)
+ application.startupConfig = &startupConfigCopy
+
+ xlog.Info("Starting LocalAI", "threads", options.Threads, "modelsPath", options.SystemState.Model.ModelsPath)
+ xlog.Info("LocalAI version", "version", internal.PrintableVersion())
+
+ if err := application.start(); err != nil {
+ return nil, err
+ }
+
+ caps, err := xsysinfo.CPUCapabilities()
+ if err == nil {
+ xlog.Debug("CPU capabilities", "capabilities", caps)
+
+ }
+ gpus, err := xsysinfo.GPUs()
+ if err == nil {
+ xlog.Debug("GPU count", "count", len(gpus))
+ for _, gpu := range gpus {
+ xlog.Debug("GPU", "gpu", gpu.String())
+ }
+ }
+
+ // Make sure directories exists
+ if options.SystemState.Model.ModelsPath == "" {
+ return nil, fmt.Errorf("models path cannot be empty")
+ }
+
+ err = os.MkdirAll(options.SystemState.Model.ModelsPath, 0750)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create ModelPath: %q", err)
+ }
+ if options.GeneratedContentDir != "" {
+ err := os.MkdirAll(options.GeneratedContentDir, 0750)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create ImageDir: %q", err)
+ }
+ }
+ if options.UploadDir != "" {
+ err := os.MkdirAll(options.UploadDir, 0750)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create UploadDir: %q", err)
+ }
+ }
+
+ if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
+ xlog.Error("error installing models", "error", err)
+ }
+
+ for _, backend := range options.ExternalBackends {
+ if err := services.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
+ xlog.Error("error installing external backend", "error", err)
+ }
+ }
+
+ configLoaderOpts := options.ToConfigLoaderOptions()
+
+ if err := application.ModelConfigLoader().LoadModelConfigsFromPath(options.SystemState.Model.ModelsPath, configLoaderOpts...); err != nil {
+ xlog.Error("error loading config files", "error", err)
+ }
+
+ if err := gallery.RegisterBackends(options.SystemState, application.ModelLoader()); err != nil {
+ xlog.Error("error registering external backends", "error", err)
+ }
+
+ if options.ConfigFile != "" {
+ if err := application.ModelConfigLoader().LoadMultipleModelConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
+ xlog.Error("error loading config file", "error", err)
+ }
+ }
+
+ if err := application.ModelConfigLoader().Preload(options.SystemState.Model.ModelsPath); err != nil {
+ xlog.Error("error downloading models", "error", err)
+ }
+
+ if options.PreloadJSONModels != "" {
+ if err := services.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil {
+ return nil, err
+ }
+ }
+
+ if options.PreloadModelsFromPath != "" {
+ if err := services.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil {
+ return nil, err
+ }
+ }
+
+ if options.Debug {
+ for _, v := range application.ModelConfigLoader().GetAllModelsConfigs() {
+ xlog.Debug("Model", "name", v.Name, "config", v)
+ }
+ }
+
+ // Load runtime settings from file if DynamicConfigsDir is set
+ // This applies file settings with env var precedence (env vars take priority)
+ // Note: startupConfigCopy was already created above, so it has the original env var values
+ if options.DynamicConfigsDir != "" {
+ loadRuntimeSettingsFromFile(options)
+ }
+
+ // turn off any process that was started by GRPC if the context is canceled
+ go func() {
+ <-options.Context.Done()
+ xlog.Debug("Context canceled, shutting down")
+ err := application.ModelLoader().StopAllGRPC()
+ if err != nil {
+ xlog.Error("error while stopping all grpc backends", "error", err)
+ }
+ }()
+
+ // Initialize watchdog with current settings (after loading from file)
+ initializeWatchdog(application, options)
+
+ if options.LoadToMemory != nil && !options.SingleBackend {
+ for _, m := range options.LoadToMemory {
+ cfg, err := application.ModelConfigLoader().LoadModelConfigFileByNameDefaultOptions(m, options)
+ if err != nil {
+ return nil, err
+ }
+
+ xlog.Debug("Auto loading model into memory from file", "model", m, "file", cfg.Model)
+
+ o := backend.ModelOptions(*cfg, options)
+
+ var backendErr error
+ _, backendErr = application.ModelLoader().Load(o...)
+ if backendErr != nil {
+ return nil, err
+ }
+ }
+ }
+
+ // Watch the configuration directory
+ startWatcher(options)
+
+ xlog.Info("core/startup process completed!")
+ return application, nil
+}
+
+func startWatcher(options *config.ApplicationConfig) {
+ if options.DynamicConfigsDir == "" {
+ // No need to start the watcher if the directory is not set
+ return
+ }
+
+ if _, err := os.Stat(options.DynamicConfigsDir); err != nil {
+ if os.IsNotExist(err) {
+ // We try to create the directory if it does not exist and was specified
+ if err := os.MkdirAll(options.DynamicConfigsDir, 0700); err != nil {
+ xlog.Error("failed creating DynamicConfigsDir", "error", err)
+ }
+ } else {
+ // something else happened, we log the error and don't start the watcher
+ xlog.Error("failed to read DynamicConfigsDir, watcher will not be started", "error", err)
+ return
+ }
+ }
+
+ configHandler := newConfigFileHandler(options)
+ if err := configHandler.Watch(); err != nil {
+ xlog.Error("failed creating watcher", "error", err)
+ }
+}
+
+// loadRuntimeSettingsFromFile loads settings from runtime_settings.json with env var precedence
+// This function is called at startup, before env vars are applied via AppOptions.
+// Since env vars are applied via AppOptions in run.go, we need to check if they're set.
+// We do this by checking if the current options values differ from defaults, which would
+// indicate they were set from env vars. However, a simpler approach is to just apply
+// file settings here, and let the AppOptions (which are applied after this) override them.
+// But actually, this is called AFTER AppOptions are applied in New(), so we need to check env vars.
+// The cleanest solution: Store original values before applying file, or check if values match
+// what would be set from env vars. For now, we'll apply file settings and they'll be
+// overridden by AppOptions if env vars were set (but AppOptions are already applied).
+// Actually, this function is called in New() before AppOptions are fully processed for watchdog.
+// Let's check the call order: New() -> loadRuntimeSettingsFromFile() -> initializeWatchdog()
+// But AppOptions are applied in NewApplicationConfig() which is called first.
+// So at this point, options already has values from env vars. We should compare against
+// defaults to see if env vars were set. But we don't have defaults stored.
+// Simplest: Just apply file settings. If env vars were set, they're already in options.
+// The file watcher handler will handle runtime changes properly by comparing with startupAppConfig.
+func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
+ settingsFile := filepath.Join(options.DynamicConfigsDir, "runtime_settings.json")
+ fileContent, err := os.ReadFile(settingsFile)
+ if err != nil {
+ if os.IsNotExist(err) {
+ xlog.Debug("runtime_settings.json not found, using defaults")
+ return
+ }
+ xlog.Warn("failed to read runtime_settings.json", "error", err)
+ return
+ }
+
+ var settings config.RuntimeSettings
+
+ if err := json.Unmarshal(fileContent, &settings); err != nil {
+ xlog.Warn("failed to parse runtime_settings.json", "error", err)
+ return
+ }
+
+ // At this point, options already has values from env vars (via AppOptions in run.go).
+ // To avoid env var duplication, we determine if env vars were set by checking if
+ // current values differ from defaults. Defaults are: false for bools, 0 for durations.
+ // If current value is at default, it likely wasn't set from env var, so we can apply file.
+ // If current value is non-default, it was likely set from env var, so we preserve it.
+ // Note: This means env vars explicitly setting to false/0 won't be distinguishable from defaults,
+ // but that's an acceptable limitation to avoid env var duplication.
+
+ if settings.WatchdogIdleEnabled != nil {
+ // Only apply if current value is default (false), suggesting it wasn't set from env var
+ if !options.WatchDogIdle {
+ options.WatchDogIdle = *settings.WatchdogIdleEnabled
+ if options.WatchDogIdle {
+ options.WatchDog = true
+ }
+ }
+ }
+ if settings.WatchdogBusyEnabled != nil {
+ if !options.WatchDogBusy {
+ options.WatchDogBusy = *settings.WatchdogBusyEnabled
+ if options.WatchDogBusy {
+ options.WatchDog = true
+ }
+ }
+ }
+ if settings.WatchdogIdleTimeout != nil {
+ // Only apply if current value is default (0), suggesting it wasn't set from env var
+ if options.WatchDogIdleTimeout == 0 {
+ dur, err := time.ParseDuration(*settings.WatchdogIdleTimeout)
+ if err == nil {
+ options.WatchDogIdleTimeout = dur
+ } else {
+ xlog.Warn("invalid watchdog idle timeout in runtime_settings.json", "error", err, "timeout", *settings.WatchdogIdleTimeout)
+ }
+ }
+ }
+ if settings.WatchdogBusyTimeout != nil {
+ if options.WatchDogBusyTimeout == 0 {
+ dur, err := time.ParseDuration(*settings.WatchdogBusyTimeout)
+ if err == nil {
+ options.WatchDogBusyTimeout = dur
+ } else {
+ xlog.Warn("invalid watchdog busy timeout in runtime_settings.json", "error", err, "timeout", *settings.WatchdogBusyTimeout)
+ }
+ }
+ }
+ if settings.WatchdogInterval != nil {
+ if options.WatchDogInterval == 0 {
+ dur, err := time.ParseDuration(*settings.WatchdogInterval)
+ if err == nil {
+ options.WatchDogInterval = dur
+ } else {
+ xlog.Warn("invalid watchdog interval in runtime_settings.json", "error", err, "interval", *settings.WatchdogInterval)
+ options.WatchDogInterval = model.DefaultWatchdogInterval
+ }
+ }
+ }
+ // Handle MaxActiveBackends (new) and SingleBackend (deprecated)
+ if settings.MaxActiveBackends != nil {
+ // Only apply if current value is default (0), suggesting it wasn't set from env var
+ if options.MaxActiveBackends == 0 {
+ options.MaxActiveBackends = *settings.MaxActiveBackends
+ // For backward compatibility, also set SingleBackend if MaxActiveBackends == 1
+ options.SingleBackend = (*settings.MaxActiveBackends == 1)
+ }
+ } else if settings.SingleBackend != nil {
+ // Legacy: SingleBackend maps to MaxActiveBackends = 1
+ if !options.SingleBackend {
+ options.SingleBackend = *settings.SingleBackend
+ if *settings.SingleBackend {
+ options.MaxActiveBackends = 1
+ }
+ }
+ }
+ if settings.ParallelBackendRequests != nil {
+ if !options.ParallelBackendRequests {
+ options.ParallelBackendRequests = *settings.ParallelBackendRequests
+ }
+ }
+ if settings.MemoryReclaimerEnabled != nil {
+ // Only apply if current value is default (false), suggesting it wasn't set from env var
+ if !options.MemoryReclaimerEnabled {
+ options.MemoryReclaimerEnabled = *settings.MemoryReclaimerEnabled
+ if options.MemoryReclaimerEnabled {
+ options.WatchDog = true // Memory reclaimer requires watchdog
+ }
+ }
+ }
+ if settings.MemoryReclaimerThreshold != nil {
+ // Only apply if current value is default (0), suggesting it wasn't set from env var
+ if options.MemoryReclaimerThreshold == 0 {
+ options.MemoryReclaimerThreshold = *settings.MemoryReclaimerThreshold
+ }
+ }
+ if settings.AgentJobRetentionDays != nil {
+ // Only apply if current value is default (0), suggesting it wasn't set from env var
+ if options.AgentJobRetentionDays == 0 {
+ options.AgentJobRetentionDays = *settings.AgentJobRetentionDays
+ }
+ }
+ if !options.WatchDogIdle && !options.WatchDogBusy {
+ if settings.WatchdogEnabled != nil && *settings.WatchdogEnabled {
+ options.WatchDog = true
+ }
+ }
+
+ xlog.Debug("Runtime settings loaded from runtime_settings.json")
+}
+
+// initializeWatchdog initializes the watchdog with current ApplicationConfig settings
+func initializeWatchdog(application *Application, options *config.ApplicationConfig) {
+ // Get effective max active backends (considers both MaxActiveBackends and deprecated SingleBackend)
+ lruLimit := options.GetEffectiveMaxActiveBackends()
+
+ // Create watchdog if enabled OR if LRU limit is set OR if memory reclaimer is enabled
+ if options.WatchDog || lruLimit > 0 || options.MemoryReclaimerEnabled {
+ wd := model.NewWatchDog(
+ model.WithProcessManager(application.ModelLoader()),
+ model.WithBusyTimeout(options.WatchDogBusyTimeout),
+ model.WithIdleTimeout(options.WatchDogIdleTimeout),
+ model.WithWatchdogInterval(options.WatchDogInterval),
+ model.WithBusyCheck(options.WatchDogBusy),
+ model.WithIdleCheck(options.WatchDogIdle),
+ model.WithLRULimit(lruLimit),
+ model.WithMemoryReclaimer(options.MemoryReclaimerEnabled, options.MemoryReclaimerThreshold),
+ model.WithForceEvictionWhenBusy(options.ForceEvictionWhenBusy),
+ )
+ application.ModelLoader().SetWatchDog(wd)
+
+ // Initialize ModelLoader LRU eviction retry settings
+ application.ModelLoader().SetLRUEvictionRetrySettings(
+ options.LRUEvictionMaxRetries,
+ options.LRUEvictionRetryInterval,
+ )
+
+ // Start watchdog goroutine if any periodic checks are enabled
+ // LRU eviction doesn't need the Run() loop - it's triggered on model load
+ // But memory reclaimer needs the Run() loop for periodic checking
+ if options.WatchDogBusy || options.WatchDogIdle || options.MemoryReclaimerEnabled {
+ go wd.Run()
+ }
+
+ go func() {
+ <-options.Context.Done()
+ xlog.Debug("Context canceled, shutting down")
+ wd.Shutdown()
+ }()
+ }
+}
diff --git a/core/application/watchdog.go b/core/application/watchdog.go
new file mode 100644
index 0000000000000000000000000000000000000000..054205fef39f7a4a323faa5402647a35f9381a92
--- /dev/null
+++ b/core/application/watchdog.go
@@ -0,0 +1,101 @@
+package application
+
+import (
+ "time"
+
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/xlog"
+)
+
+func (a *Application) StopWatchdog() error {
+ if a.watchdogStop != nil {
+ close(a.watchdogStop)
+ a.watchdogStop = nil
+ }
+ return nil
+}
+
+// startWatchdog starts the watchdog with current ApplicationConfig settings
+// This is an internal method that assumes the caller holds the watchdogMutex
+func (a *Application) startWatchdog() error {
+ appConfig := a.ApplicationConfig()
+
+ // Get effective max active backends (considers both MaxActiveBackends and deprecated SingleBackend)
+ lruLimit := appConfig.GetEffectiveMaxActiveBackends()
+
+ // Create watchdog if enabled OR if LRU limit is set OR if memory reclaimer is enabled
+ // LRU eviction requires watchdog infrastructure even without busy/idle checks
+ if appConfig.WatchDog || lruLimit > 0 || appConfig.MemoryReclaimerEnabled {
+ wd := model.NewWatchDog(
+ model.WithProcessManager(a.modelLoader),
+ model.WithBusyTimeout(appConfig.WatchDogBusyTimeout),
+ model.WithIdleTimeout(appConfig.WatchDogIdleTimeout),
+ model.WithWatchdogInterval(appConfig.WatchDogInterval),
+ model.WithBusyCheck(appConfig.WatchDogBusy),
+ model.WithIdleCheck(appConfig.WatchDogIdle),
+ model.WithLRULimit(lruLimit),
+ model.WithMemoryReclaimer(appConfig.MemoryReclaimerEnabled, appConfig.MemoryReclaimerThreshold),
+ model.WithForceEvictionWhenBusy(appConfig.ForceEvictionWhenBusy),
+ )
+ a.modelLoader.SetWatchDog(wd)
+
+ // Create new stop channel
+ a.watchdogStop = make(chan bool, 1)
+
+ // Start watchdog goroutine if any periodic checks are enabled
+ // LRU eviction doesn't need the Run() loop - it's triggered on model load
+ // But memory reclaimer needs the Run() loop for periodic checking
+ if appConfig.WatchDogBusy || appConfig.WatchDogIdle || appConfig.MemoryReclaimerEnabled {
+ go wd.Run()
+ }
+
+ // Setup shutdown handler
+ go func() {
+ select {
+ case <-a.watchdogStop:
+ xlog.Debug("Watchdog stop signal received")
+ wd.Shutdown()
+ case <-appConfig.Context.Done():
+ xlog.Debug("Context canceled, shutting down watchdog")
+ wd.Shutdown()
+ }
+ }()
+
+ xlog.Info("Watchdog started with new settings", "lruLimit", lruLimit, "busyCheck", appConfig.WatchDogBusy, "idleCheck", appConfig.WatchDogIdle, "memoryReclaimer", appConfig.MemoryReclaimerEnabled, "memoryThreshold", appConfig.MemoryReclaimerThreshold, "interval", appConfig.WatchDogInterval)
+ } else {
+ xlog.Info("Watchdog disabled")
+ }
+
+ return nil
+}
+
+// StartWatchdog starts the watchdog with current ApplicationConfig settings
+func (a *Application) StartWatchdog() error {
+ a.watchdogMutex.Lock()
+ defer a.watchdogMutex.Unlock()
+
+ return a.startWatchdog()
+}
+
+// RestartWatchdog restarts the watchdog with current ApplicationConfig settings
+func (a *Application) RestartWatchdog() error {
+ a.watchdogMutex.Lock()
+ defer a.watchdogMutex.Unlock()
+
+ // Shutdown existing watchdog if running
+ if a.watchdogStop != nil {
+ close(a.watchdogStop)
+ a.watchdogStop = nil
+ }
+
+ // Shutdown existing watchdog if running
+ currentWD := a.modelLoader.GetWatchDog()
+ if currentWD != nil {
+ currentWD.Shutdown()
+ // Wait a bit for shutdown to complete
+ time.Sleep(100 * time.Millisecond)
+ }
+
+ // Start watchdog with new settings
+ return a.startWatchdog()
+}
diff --git a/core/backend/backend_suite_test.go b/core/backend/backend_suite_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..541c91f6be7915b1cdd2834b4b192947e2a438dd
--- /dev/null
+++ b/core/backend/backend_suite_test.go
@@ -0,0 +1,13 @@
+package backend_test
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestBackend(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "Backend test suite")
+}
diff --git a/core/backend/detection.go b/core/backend/detection.go
new file mode 100644
index 0000000000000000000000000000000000000000..1b199182414595f5c16a69d355b11a7d8bb30390
--- /dev/null
+++ b/core/backend/detection.go
@@ -0,0 +1,33 @@
+package backend
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/pkg/grpc/proto"
+ "github.com/mudler/LocalAI/pkg/model"
+)
+
+func Detection(
+ sourceFile string,
+ loader *model.ModelLoader,
+ appConfig *config.ApplicationConfig,
+ modelConfig config.ModelConfig,
+) (*proto.DetectResponse, error) {
+ opts := ModelOptions(modelConfig, appConfig)
+ detectionModel, err := loader.Load(opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ if detectionModel == nil {
+ return nil, fmt.Errorf("could not load detection model")
+ }
+
+ res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{
+ Src: sourceFile,
+ })
+
+ return res, err
+}
diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go
new file mode 100644
index 0000000000000000000000000000000000000000..2383023c0dc17e6cb33f6fc88554fe188391db85
--- /dev/null
+++ b/core/backend/embeddings.go
@@ -0,0 +1,71 @@
+package backend
+
+import (
+ "fmt"
+
+ "github.com/mudler/LocalAI/core/config"
+
+ "github.com/mudler/LocalAI/pkg/grpc"
+ model "github.com/mudler/LocalAI/pkg/model"
+)
+
+func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
+
+ opts := ModelOptions(modelConfig, appConfig)
+
+ inferenceModel, err := loader.Load(opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ var fn func() ([]float32, error)
+ switch model := inferenceModel.(type) {
+ case grpc.Backend:
+ fn = func() ([]float32, error) {
+ predictOptions := gRPCPredictOpts(modelConfig, loader.ModelPath)
+ if len(tokens) > 0 {
+ embeds := []int32{}
+
+ for _, t := range tokens {
+ embeds = append(embeds, int32(t))
+ }
+ predictOptions.EmbeddingTokens = embeds
+
+ res, err := model.Embeddings(appConfig.Context, predictOptions)
+ if err != nil {
+ return nil, err
+ }
+
+ return res.Embeddings, nil
+ }
+ predictOptions.Embeddings = s
+
+ res, err := model.Embeddings(appConfig.Context, predictOptions)
+ if err != nil {
+ return nil, err
+ }
+
+ return res.Embeddings, nil
+ }
+ default:
+ fn = func() ([]float32, error) {
+ return nil, fmt.Errorf("embeddings not supported by the backend")
+ }
+ }
+
+ return func() ([]float32, error) {
+ embeds, err := fn()
+ if err != nil {
+ return embeds, err
+ }
+ // Remove trailing 0s
+ for i := len(embeds) - 1; i >= 0; i-- {
+ if embeds[i] == 0.0 {
+ embeds = embeds[:i]
+ } else {
+ break
+ }
+ }
+ return embeds, nil
+ }, nil
+}
diff --git a/core/backend/image.go b/core/backend/image.go
new file mode 100644
index 0000000000000000000000000000000000000000..651293cf5e1bc9049167714a85a7374467fffa8a
--- /dev/null
+++ b/core/backend/image.go
@@ -0,0 +1,44 @@
+package backend
+
+import (
+ "github.com/mudler/LocalAI/core/config"
+
+ "github.com/mudler/LocalAI/pkg/grpc/proto"
+ model "github.com/mudler/LocalAI/pkg/model"
+)
+
+func ImageGeneration(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
+
+ opts := ModelOptions(modelConfig, appConfig)
+ inferenceModel, err := loader.Load(
+ opts...,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ fn := func() error {
+ _, err := inferenceModel.GenerateImage(
+ appConfig.Context,
+ &proto.GenerateImageRequest{
+ Height: int32(height),
+ Width: int32(width),
+ Step: int32(step),
+ Seed: int32(seed),
+ CLIPSkip: int32(modelConfig.Diffusers.ClipSkip),
+ PositivePrompt: positive_prompt,
+ NegativePrompt: negative_prompt,
+ Dst: dst,
+ Src: src,
+ EnableParameters: modelConfig.Diffusers.EnableParameters,
+ RefImages: refImages,
+ })
+ return err
+ }
+
+ return fn, nil
+}
+
+// ImageGenerationFunc is a test-friendly indirection to call image generation logic.
+// Tests can override this variable to provide a stub implementation.
+var ImageGenerationFunc = ImageGeneration
diff --git a/core/backend/llm.go b/core/backend/llm.go
new file mode 100644
index 0000000000000000000000000000000000000000..06b9d2d4480c3ffe7b78fe75afd84699fbfcb729
--- /dev/null
+++ b/core/backend/llm.go
@@ -0,0 +1,265 @@
+package backend
+
+import (
+ "context"
+ "encoding/json"
+ "regexp"
+ "slices"
+ "strings"
+ "sync"
+ "unicode/utf8"
+
+ "github.com/mudler/xlog"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/core/services"
+
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/pkg/grpc/proto"
+ model "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/utils"
+)
+
+type LLMResponse struct {
+ Response string // should this be []byte?
+ Usage TokenUsage
+ AudioOutput string
+ Logprobs *schema.Logprobs // Logprobs from the backend response
+}
+
+type TokenUsage struct {
+ Prompt int
+ Completion int
+ TimingPromptProcessing float64
+ TimingTokenGeneration float64
+}
+
+func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (LLMResponse, error), error) {
+ modelFile := c.Model
+
+ // Check if the modelFile exists, if it doesn't try to load it from the gallery
+ if o.AutoloadGalleries { // experimental
+ modelNames, err := services.ListModels(cl, loader, nil, services.SKIP_ALWAYS)
+ if err != nil {
+ return nil, err
+ }
+ if !slices.Contains(modelNames, c.Name) {
+ utils.ResetDownloadTimers()
+ // if we failed to load the model, we try to download it
+ err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
+ if err != nil {
+ xlog.Error("failed to install model from gallery", "error", err, "model", modelFile)
+ //return nil, err
+ }
+ }
+ }
+
+ opts := ModelOptions(*c, o)
+ inferenceModel, err := loader.Load(opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ var protoMessages []*proto.Message
+ // if we are using the tokenizer template, we need to convert the messages to proto messages
+ // unless the prompt has already been tokenized (non-chat endpoints + functions)
+ if c.TemplateConfig.UseTokenizerTemplate && len(messages) > 0 {
+ protoMessages = messages.ToProto()
+ }
+
+ // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
+ fn := func() (LLMResponse, error) {
+ opts := gRPCPredictOpts(*c, loader.ModelPath)
+ opts.Prompt = s
+ opts.Messages = protoMessages
+ opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate
+ opts.Images = images
+ opts.Videos = videos
+ opts.Audios = audios
+ opts.Tools = tools
+ opts.ToolChoice = toolChoice
+ if logprobs != nil {
+ opts.Logprobs = int32(*logprobs)
+ }
+ if topLogprobs != nil {
+ opts.TopLogprobs = int32(*topLogprobs)
+ }
+ if len(logitBias) > 0 {
+ // Serialize logit_bias map to JSON string for proto
+ logitBiasJSON, err := json.Marshal(logitBias)
+ if err == nil {
+ opts.LogitBias = string(logitBiasJSON)
+ }
+ }
+
+ tokenUsage := TokenUsage{}
+
+ // check the per-model feature flag for usage, since tokenCallback may have a cost.
+ // Defaults to off as for now it is still experimental
+ if c.FeatureFlag.Enabled("usage") {
+ userTokenCallback := tokenCallback
+ if userTokenCallback == nil {
+ userTokenCallback = func(token string, usage TokenUsage) bool {
+ return true
+ }
+ }
+
+ promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
+ if pErr == nil && promptInfo.Length > 0 {
+ tokenUsage.Prompt = int(promptInfo.Length)
+ }
+
+ tokenCallback = func(token string, usage TokenUsage) bool {
+ tokenUsage.Completion++
+ return userTokenCallback(token, tokenUsage)
+ }
+ }
+
+ if tokenCallback != nil {
+
+ if c.TemplateConfig.ReplyPrefix != "" {
+ tokenCallback(c.TemplateConfig.ReplyPrefix, tokenUsage)
+ }
+
+ ss := ""
+ var logprobs *schema.Logprobs
+
+ var partialRune []byte
+ err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {
+ msg := reply.Message
+ partialRune = append(partialRune, msg...)
+
+ tokenUsage.Prompt = int(reply.PromptTokens)
+ tokenUsage.Completion = int(reply.Tokens)
+ tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
+ tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
+
+ // Parse logprobs from reply if present (collect from last chunk that has them)
+ if len(reply.Logprobs) > 0 {
+ var parsedLogprobs schema.Logprobs
+ if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
+ logprobs = &parsedLogprobs
+ }
+ }
+
+ // Process complete runes and accumulate them
+ var completeRunes []byte
+ for len(partialRune) > 0 {
+ r, size := utf8.DecodeRune(partialRune)
+ if r == utf8.RuneError {
+ // incomplete rune, wait for more bytes
+ break
+ }
+ completeRunes = append(completeRunes, partialRune[:size]...)
+ partialRune = partialRune[size:]
+ }
+
+ // If we have complete runes, send them as a single token
+ if len(completeRunes) > 0 {
+ tokenCallback(string(completeRunes), tokenUsage)
+ ss += string(completeRunes)
+ }
+
+ if len(msg) == 0 {
+ tokenCallback("", tokenUsage)
+ }
+ })
+ return LLMResponse{
+ Response: ss,
+ Usage: tokenUsage,
+ Logprobs: logprobs,
+ }, err
+ } else {
+ // TODO: Is the chicken bit the only way to get here? is that acceptable?
+ reply, err := inferenceModel.Predict(ctx, opts)
+ if err != nil {
+ return LLMResponse{}, err
+ }
+ if tokenUsage.Prompt == 0 {
+ tokenUsage.Prompt = int(reply.PromptTokens)
+ }
+ if tokenUsage.Completion == 0 {
+ tokenUsage.Completion = int(reply.Tokens)
+ }
+
+ tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
+ tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
+
+ response := string(reply.Message)
+ if c.TemplateConfig.ReplyPrefix != "" {
+ response = c.TemplateConfig.ReplyPrefix + response
+ }
+
+ // Parse logprobs from reply if present
+ var logprobs *schema.Logprobs
+ if len(reply.Logprobs) > 0 {
+ var parsedLogprobs schema.Logprobs
+ if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
+ logprobs = &parsedLogprobs
+ }
+ }
+
+ return LLMResponse{
+ Response: response,
+ Usage: tokenUsage,
+ Logprobs: logprobs,
+ }, err
+ }
+ }
+
+ return fn, nil
+}
+
+var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
+var mu sync.Mutex = sync.Mutex{}
+
+func Finetune(config config.ModelConfig, input, prediction string) string {
+ if config.Echo {
+ prediction = input + prediction
+ }
+
+ for _, c := range config.Cutstrings {
+ mu.Lock()
+ reg, ok := cutstrings[c]
+ if !ok {
+ r, err := regexp.Compile(c)
+ if err != nil {
+ xlog.Fatal("failed to compile regex", "error", err)
+ }
+ cutstrings[c] = r
+ reg = cutstrings[c]
+ }
+ mu.Unlock()
+ prediction = reg.ReplaceAllString(prediction, "")
+ }
+
+ // extract results from the response which can be for instance inside XML tags
+ var predResult string
+ for _, r := range config.ExtractRegex {
+ mu.Lock()
+ reg, ok := cutstrings[r]
+ if !ok {
+ regex, err := regexp.Compile(r)
+ if err != nil {
+ xlog.Fatal("failed to compile regex", "error", err)
+ }
+ cutstrings[r] = regex
+ reg = regex
+ }
+ mu.Unlock()
+ predResult += reg.FindString(prediction)
+ }
+ if predResult != "" {
+ prediction = predResult
+ }
+
+ for _, c := range config.TrimSpace {
+ prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
+ }
+
+ for _, c := range config.TrimSuffix {
+ prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c))
+ }
+ return prediction
+}
diff --git a/core/backend/llm_test.go b/core/backend/llm_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..ea68a93156ef01fc8520d97901d691fde79f17ef
--- /dev/null
+++ b/core/backend/llm_test.go
@@ -0,0 +1,109 @@
+package backend_test
+
+import (
+ . "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/schema"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("LLM tests", func() {
+ Context("Finetune LLM output", func() {
+ var (
+ testConfig config.ModelConfig
+ input string
+ prediction string
+ result string
+ )
+
+ BeforeEach(func() {
+ testConfig = config.ModelConfig{
+ PredictionOptions: schema.PredictionOptions{
+ Echo: false,
+ },
+ LLMConfig: config.LLMConfig{
+ Cutstrings: []string{`<.*?>`}, // Example regex for removing XML tags
+ ExtractRegex: []string{`(.*?) `}, // Example regex to extract from tags
+ TrimSpace: []string{" ", "\n"},
+ TrimSuffix: []string{".", "!"},
+ },
+ }
+ })
+
+ Context("when echo is enabled", func() {
+ BeforeEach(func() {
+ testConfig.Echo = true
+ input = "Hello"
+ prediction = "World"
+ })
+
+ It("should prepend input to prediction", func() {
+ result = Finetune(testConfig, input, prediction)
+ Expect(result).To(Equal("HelloWorld"))
+ })
+ })
+
+ Context("when echo is disabled", func() {
+ BeforeEach(func() {
+ testConfig.Echo = false
+ input = "Hello"
+ prediction = "World"
+ })
+
+ It("should not modify the prediction with input", func() {
+ result = Finetune(testConfig, input, prediction)
+ Expect(result).To(Equal("World"))
+ })
+ })
+
+ Context("when cutstrings regex is applied", func() {
+ BeforeEach(func() {
+ input = ""
+ prediction = "Hello
World"
+ })
+
+ It("should remove substrings matching cutstrings regex", func() {
+ result = Finetune(testConfig, input, prediction)
+ Expect(result).To(Equal("Hello World"))
+ })
+ })
+
+ Context("when extract regex is applied", func() {
+ BeforeEach(func() {
+ input = ""
+ prediction = "42 "
+ })
+
+ It("should extract substrings matching the extract regex", func() {
+ result = Finetune(testConfig, input, prediction)
+ Expect(result).To(Equal("42"))
+ })
+ })
+
+ Context("when trimming spaces", func() {
+ BeforeEach(func() {
+ input = ""
+ prediction = " Hello World "
+ })
+
+ It("should trim spaces from the prediction", func() {
+ result = Finetune(testConfig, input, prediction)
+ Expect(result).To(Equal("Hello World"))
+ })
+ })
+
+ Context("when trimming suffixes", func() {
+ BeforeEach(func() {
+ input = ""
+ prediction = "Hello World."
+ })
+
+ It("should trim suffixes from the prediction", func() {
+ result = Finetune(testConfig, input, prediction)
+ Expect(result).To(Equal("Hello World"))
+ })
+ })
+ })
+})
diff --git a/core/backend/options.go b/core/backend/options.go
new file mode 100644
index 0000000000000000000000000000000000000000..f3d5a4ccd4025537d62c2eae160b64e5acc5f584
--- /dev/null
+++ b/core/backend/options.go
@@ -0,0 +1,259 @@
+package backend
+
+import (
+ "math/rand"
+ "os"
+ "path/filepath"
+
+ "github.com/mudler/LocalAI/core/config"
+ pb "github.com/mudler/LocalAI/pkg/grpc/proto"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/xlog"
+)
+
+func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option {
+ name := c.Name
+ if name == "" {
+ name = c.Model
+ }
+
+ defOpts := []model.Option{
+ model.WithBackendString(c.Backend),
+ model.WithModel(c.Model),
+ model.WithContext(so.Context),
+ model.WithModelID(name),
+ }
+
+ threads := 1
+
+ if c.Threads != nil {
+ threads = *c.Threads
+ }
+
+ if so.Threads != 0 {
+ threads = so.Threads
+ }
+
+ c.Threads = &threads
+
+ grpcOpts := grpcModelOpts(c, so.SystemState.Model.ModelsPath)
+ defOpts = append(defOpts, model.WithLoadGRPCLoadModelOpts(grpcOpts))
+
+ if so.ParallelBackendRequests {
+ defOpts = append(defOpts, model.EnableParallelRequests)
+ }
+
+ if c.GRPC.Attempts != 0 {
+ defOpts = append(defOpts, model.WithGRPCAttempts(c.GRPC.Attempts))
+ }
+
+ if c.GRPC.AttemptsSleepTime != 0 {
+ defOpts = append(defOpts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
+ }
+
+ for k, v := range so.ExternalGRPCBackends {
+ defOpts = append(defOpts, model.WithExternalBackend(k, v))
+ }
+
+ return append(defOpts, opts...)
+}
+
+func getSeed(c config.ModelConfig) int32 {
+ var seed int32 = config.RAND_SEED
+
+ if c.Seed != nil {
+ seed = int32(*c.Seed)
+ }
+
+ if seed == config.RAND_SEED {
+ seed = rand.Int31()
+ }
+
+ return seed
+}
+
+func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
+ b := 512
+ if c.Batch != 0 {
+ b = c.Batch
+ }
+
+ flashAttention := "auto"
+
+ if c.FlashAttention != nil {
+ flashAttention = *c.FlashAttention
+ }
+
+ f16 := false
+ if c.F16 != nil {
+ f16 = *c.F16
+ }
+
+ embeddings := false
+ if c.Embeddings != nil {
+ embeddings = *c.Embeddings
+ }
+
+ lowVRAM := false
+ if c.LowVRAM != nil {
+ lowVRAM = *c.LowVRAM
+ }
+
+ reranking := false
+ if c.Reranking != nil {
+ reranking = *c.Reranking
+ }
+
+ mmap := false
+ if c.MMap != nil {
+ mmap = *c.MMap
+ }
+
+ ctxSize := 4096
+ if c.ContextSize != nil {
+ ctxSize = *c.ContextSize
+ }
+
+ mmlock := false
+ if c.MMlock != nil {
+ mmlock = *c.MMlock
+ }
+
+ nGPULayers := 9999999
+ if c.NGPULayers != nil {
+ nGPULayers = *c.NGPULayers
+ }
+
+ triggers := make([]*pb.GrammarTrigger, 0)
+ for _, t := range c.FunctionsConfig.GrammarConfig.GrammarTriggers {
+ triggers = append(triggers, &pb.GrammarTrigger{
+ Word: t.Word,
+ })
+ }
+
+ opts := &pb.ModelOptions{
+ CUDA: c.CUDA || c.Diffusers.CUDA,
+ SchedulerType: c.Diffusers.SchedulerType,
+ GrammarTriggers: triggers,
+ PipelineType: c.Diffusers.PipelineType,
+ CFGScale: c.CFGScale,
+ LoraAdapter: c.LoraAdapter,
+ LoraScale: c.LoraScale,
+ LoraAdapters: c.LoraAdapters,
+ LoraScales: c.LoraScales,
+ F16Memory: f16,
+ LoraBase: c.LoraBase,
+ IMG2IMG: c.Diffusers.IMG2IMG,
+ CLIPModel: c.Diffusers.ClipModel,
+ CLIPSubfolder: c.Diffusers.ClipSubFolder,
+ Options: c.Options,
+ Overrides: c.Overrides,
+ CLIPSkip: int32(c.Diffusers.ClipSkip),
+ ControlNet: c.Diffusers.ControlNet,
+ ContextSize: int32(ctxSize),
+ Seed: getSeed(c),
+ NBatch: int32(b),
+ NoMulMatQ: c.NoMulMatQ,
+ DraftModel: c.DraftModel,
+ AudioPath: c.AudioPath,
+ Quantization: c.Quantization,
+ LoadFormat: c.LoadFormat,
+ GPUMemoryUtilization: c.GPUMemoryUtilization,
+ TrustRemoteCode: c.TrustRemoteCode,
+ EnforceEager: c.EnforceEager,
+ SwapSpace: int32(c.SwapSpace),
+ MaxModelLen: int32(c.MaxModelLen),
+ TensorParallelSize: int32(c.TensorParallelSize),
+ DisableLogStatus: c.DisableLogStatus,
+ DType: c.DType,
+ // LimitMMPerPrompt vLLM
+ LimitImagePerPrompt: int32(c.LimitMMPerPrompt.LimitImagePerPrompt),
+ LimitVideoPerPrompt: int32(c.LimitMMPerPrompt.LimitVideoPerPrompt),
+ LimitAudioPerPrompt: int32(c.LimitMMPerPrompt.LimitAudioPerPrompt),
+ FlashAttention: flashAttention,
+ CacheTypeKey: c.CacheTypeK,
+ CacheTypeValue: c.CacheTypeV,
+ NoKVOffload: c.NoKVOffloading,
+ YarnExtFactor: c.YarnExtFactor,
+ YarnAttnFactor: c.YarnAttnFactor,
+ YarnBetaFast: c.YarnBetaFast,
+ YarnBetaSlow: c.YarnBetaSlow,
+ NGQA: c.NGQA,
+ RMSNormEps: c.RMSNormEps,
+ MLock: mmlock,
+ RopeFreqBase: c.RopeFreqBase,
+ RopeScaling: c.RopeScaling,
+ Type: c.ModelType,
+ RopeFreqScale: c.RopeFreqScale,
+ NUMA: c.NUMA,
+ Embeddings: embeddings,
+ Reranking: reranking,
+ LowVRAM: lowVRAM,
+ NGPULayers: int32(nGPULayers),
+ MMap: mmap,
+ MainGPU: c.MainGPU,
+ Threads: int32(*c.Threads),
+ TensorSplit: c.TensorSplit,
+ // RWKV
+ Tokenizer: c.Tokenizer,
+ }
+
+ if c.MMProj != "" {
+ opts.MMProj = filepath.Join(modelPath, c.MMProj)
+ }
+
+ return opts
+}
+
+func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions {
+ promptCachePath := ""
+ if c.PromptCachePath != "" {
+ p := filepath.Join(modelPath, c.PromptCachePath)
+ err := os.MkdirAll(filepath.Dir(p), 0750)
+ if err == nil {
+ promptCachePath = p
+ } else {
+ xlog.Error("error creating prompt cache folder", "error", err, "promptCachePath", promptCachePath)
+ }
+ }
+
+ pbOpts := &pb.PredictOptions{
+ Temperature: float32(*c.Temperature),
+ TopP: float32(*c.TopP),
+ NDraft: c.NDraft,
+ TopK: int32(*c.TopK),
+ Tokens: int32(*c.Maxtokens),
+ Threads: int32(*c.Threads),
+ PromptCacheAll: c.PromptCacheAll,
+ PromptCacheRO: c.PromptCacheRO,
+ PromptCachePath: promptCachePath,
+ F16KV: *c.F16,
+ DebugMode: *c.Debug,
+ Grammar: c.Grammar,
+ NegativePromptScale: c.NegativePromptScale,
+ RopeFreqBase: c.RopeFreqBase,
+ RopeFreqScale: c.RopeFreqScale,
+ NegativePrompt: c.NegativePrompt,
+ Mirostat: int32(*c.LLMConfig.Mirostat),
+ MirostatETA: float32(*c.LLMConfig.MirostatETA),
+ MirostatTAU: float32(*c.LLMConfig.MirostatTAU),
+ Debug: *c.Debug,
+ StopPrompts: c.StopWords,
+ Repeat: int32(c.RepeatLastN),
+ FrequencyPenalty: float32(c.FrequencyPenalty),
+ PresencePenalty: float32(c.PresencePenalty),
+ Penalty: float32(c.RepeatPenalty),
+ NKeep: int32(c.Keep),
+ Batch: int32(c.Batch),
+ IgnoreEOS: c.IgnoreEOS,
+ Seed: getSeed(c),
+ MLock: *c.MMlock,
+ MMap: *c.MMap,
+ MainGPU: c.MainGPU,
+ TensorSplit: c.TensorSplit,
+ TailFreeSamplingZ: float32(*c.TFZ),
+ TypicalP: float32(*c.TypicalP),
+ }
+ // Logprobs and TopLogprobs are set by the caller if provided
+ return pbOpts
+}
diff --git a/core/backend/rerank.go b/core/backend/rerank.go
new file mode 100644
index 0000000000000000000000000000000000000000..bcfad7382fc8c9aede155e9b793a81b656e7e81f
--- /dev/null
+++ b/core/backend/rerank.go
@@ -0,0 +1,26 @@
+package backend
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/pkg/grpc/proto"
+ model "github.com/mudler/LocalAI/pkg/model"
+)
+
+func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) {
+ opts := ModelOptions(modelConfig, appConfig)
+ rerankModel, err := loader.Load(opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ if rerankModel == nil {
+ return nil, fmt.Errorf("could not load rerank model")
+ }
+
+ res, err := rerankModel.Rerank(context.Background(), request)
+
+ return res, err
+}
diff --git a/core/backend/soundgeneration.go b/core/backend/soundgeneration.go
new file mode 100644
index 0000000000000000000000000000000000000000..ca78b2db973cdcec079944e465b47a6c7aded4fe
--- /dev/null
+++ b/core/backend/soundgeneration.go
@@ -0,0 +1,66 @@
+package backend
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/pkg/grpc/proto"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/utils"
+)
+
+func SoundGeneration(
+ text string,
+ duration *float32,
+ temperature *float32,
+ doSample *bool,
+ sourceFile *string,
+ sourceDivisor *int32,
+ loader *model.ModelLoader,
+ appConfig *config.ApplicationConfig,
+ modelConfig config.ModelConfig,
+) (string, *proto.Result, error) {
+
+ opts := ModelOptions(modelConfig, appConfig)
+ soundGenModel, err := loader.Load(opts...)
+ if err != nil {
+ return "", nil, err
+ }
+
+ if soundGenModel == nil {
+ return "", nil, fmt.Errorf("could not load sound generation model")
+ }
+
+ if err := os.MkdirAll(appConfig.GeneratedContentDir, 0750); err != nil {
+ return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
+ }
+
+ audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio")
+ if err := os.MkdirAll(audioDir, 0750); err != nil {
+ return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
+ }
+
+ fileName := utils.GenerateUniqueFileName(audioDir, "sound_generation", ".wav")
+ filePath := filepath.Join(audioDir, fileName)
+
+ res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
+ Text: text,
+ Model: modelConfig.Model,
+ Dst: filePath,
+ Sample: doSample,
+ Duration: duration,
+ Temperature: temperature,
+ Src: sourceFile,
+ SrcDivisor: sourceDivisor,
+ })
+
+ // return RPC error if any
+ if !res.Success {
+ return "", nil, fmt.Errorf("error during sound generation: %s", res.Message)
+ }
+
+ return filePath, res, err
+}
diff --git a/core/backend/stores.go b/core/backend/stores.go
new file mode 100644
index 0000000000000000000000000000000000000000..78257180e93c8d4b30a59cd082cae2c015fb3549
--- /dev/null
+++ b/core/backend/stores.go
@@ -0,0 +1,20 @@
+package backend
+
+import (
+ "github.com/mudler/LocalAI/core/config"
+
+ "github.com/mudler/LocalAI/pkg/grpc"
+ "github.com/mudler/LocalAI/pkg/model"
+)
+
+func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string, backend string) (grpc.Backend, error) {
+ if backend == "" {
+ backend = model.LocalStoreBackend
+ }
+ sc := []model.Option{
+ model.WithBackendString(backend),
+ model.WithModel(storeName),
+ }
+
+ return sl.Load(sc...)
+}
diff --git a/core/backend/token_metrics.go b/core/backend/token_metrics.go
new file mode 100644
index 0000000000000000000000000000000000000000..c81f57cab50f60eb708df34649b2a6de3be539dd
--- /dev/null
+++ b/core/backend/token_metrics.go
@@ -0,0 +1,31 @@
+package backend
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/pkg/grpc/proto"
+ model "github.com/mudler/LocalAI/pkg/model"
+)
+
+func TokenMetrics(
+ modelFile string,
+ loader *model.ModelLoader,
+ appConfig *config.ApplicationConfig,
+ modelConfig config.ModelConfig) (*proto.MetricsResponse, error) {
+
+ opts := ModelOptions(modelConfig, appConfig, model.WithModel(modelFile))
+ model, err := loader.Load(opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ if model == nil {
+ return nil, fmt.Errorf("could not loadmodel model")
+ }
+
+ res, err := model.GetTokenMetrics(context.Background(), &proto.MetricsRequest{})
+
+ return res, err
+}
diff --git a/core/backend/tokenize.go b/core/backend/tokenize.go
new file mode 100644
index 0000000000000000000000000000000000000000..5803e44beadd262700a9e0a611402c8a1329baf9
--- /dev/null
+++ b/core/backend/tokenize.go
@@ -0,0 +1,38 @@
+package backend
+
+import (
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/grpc"
+ "github.com/mudler/LocalAI/pkg/model"
+)
+
+func ModelTokenize(s string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) {
+
+ var inferenceModel grpc.Backend
+ var err error
+
+ opts := ModelOptions(modelConfig, appConfig)
+ inferenceModel, err = loader.Load(opts...)
+ if err != nil {
+ return schema.TokenizeResponse{}, err
+ }
+
+ predictOptions := gRPCPredictOpts(modelConfig, loader.ModelPath)
+ predictOptions.Prompt = s
+
+ // tokenize the string
+ resp, err := inferenceModel.TokenizeString(appConfig.Context, predictOptions)
+ if err != nil {
+ return schema.TokenizeResponse{}, err
+ }
+
+ if resp.Tokens == nil {
+ resp.Tokens = make([]int32, 0)
+ }
+
+ return schema.TokenizeResponse{
+ Tokens: resp.Tokens,
+ }, nil
+
+}
diff --git a/core/backend/transcript.go b/core/backend/transcript.go
new file mode 100644
index 0000000000000000000000000000000000000000..66e6878139a95fb5b9b9dc9adb563a5592cac64c
--- /dev/null
+++ b/core/backend/transcript.go
@@ -0,0 +1,61 @@
+package backend
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/schema"
+
+ "github.com/mudler/LocalAI/pkg/grpc/proto"
+ "github.com/mudler/LocalAI/pkg/model"
+)
+
+func ModelTranscription(audio, language string, translate bool, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
+
+ if modelConfig.Backend == "" {
+ modelConfig.Backend = model.WhisperBackend
+ }
+
+ opts := ModelOptions(modelConfig, appConfig)
+
+ transcriptionModel, err := ml.Load(opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ if transcriptionModel == nil {
+ return nil, fmt.Errorf("could not load transcription model")
+ }
+
+ r, err := transcriptionModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
+ Dst: audio,
+ Language: language,
+ Translate: translate,
+ Diarize: diarize,
+ Threads: uint32(*modelConfig.Threads),
+ Prompt: prompt,
+ })
+ if err != nil {
+ return nil, err
+ }
+ tr := &schema.TranscriptionResult{
+ Text: r.Text,
+ }
+ for _, s := range r.Segments {
+ var tks []int
+ for _, t := range s.Tokens {
+ tks = append(tks, int(t))
+ }
+ tr.Segments = append(tr.Segments,
+ schema.TranscriptionSegment{
+ Text: s.Text,
+ Id: int(s.Id),
+ Start: time.Duration(s.Start),
+ End: time.Duration(s.End),
+ Tokens: tks,
+ })
+ }
+ return tr, err
+}
diff --git a/core/backend/tts.go b/core/backend/tts.go
new file mode 100644
index 0000000000000000000000000000000000000000..9c75cb37a1719b37b3a3c785cfcca468151382bc
--- /dev/null
+++ b/core/backend/tts.go
@@ -0,0 +1,76 @@
+package backend
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+
+ "github.com/mudler/LocalAI/core/config"
+
+ "github.com/mudler/LocalAI/pkg/grpc/proto"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/utils"
+)
+
+func ModelTTS(
+ text,
+ voice,
+ language string,
+ loader *model.ModelLoader,
+ appConfig *config.ApplicationConfig,
+ modelConfig config.ModelConfig,
+) (string, *proto.Result, error) {
+ opts := ModelOptions(modelConfig, appConfig)
+ ttsModel, err := loader.Load(opts...)
+ if err != nil {
+ return "", nil, err
+ }
+
+ if ttsModel == nil {
+ return "", nil, fmt.Errorf("could not load tts model %q", modelConfig.Model)
+ }
+
+ audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio")
+ if err := os.MkdirAll(audioDir, 0750); err != nil {
+ return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
+ }
+
+ fileName := utils.GenerateUniqueFileName(audioDir, "tts", ".wav")
+ filePath := filepath.Join(audioDir, fileName)
+
+ // We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect.
+ // This should be addressed in a follow up PR soon.
+ // Copying it over nearly verbatim, as TTS backends are not functional without this.
+ modelPath := ""
+ // Checking first that it exists and is not outside ModelPath
+ // TODO: we should actually first check if the modelFile is looking like
+ // a FS path
+ mp := filepath.Join(loader.ModelPath, modelConfig.Model)
+ if _, err := os.Stat(mp); err == nil {
+ if err := utils.VerifyPath(mp, appConfig.SystemState.Model.ModelsPath); err != nil {
+ return "", nil, err
+ }
+ modelPath = mp
+ } else {
+ modelPath = modelConfig.Model // skip this step if it fails?????
+ }
+
+ res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
+ Text: text,
+ Model: modelPath,
+ Voice: voice,
+ Dst: filePath,
+ Language: &language,
+ })
+ if err != nil {
+ return "", nil, err
+ }
+
+ // return RPC error if any
+ if !res.Success {
+ return "", nil, fmt.Errorf("error during TTS: %s", res.Message)
+ }
+
+ return filePath, res, err
+}
diff --git a/core/backend/vad.go b/core/backend/vad.go
new file mode 100644
index 0000000000000000000000000000000000000000..37859931dc1b52e3f4df9eab8954658d2abfa675
--- /dev/null
+++ b/core/backend/vad.go
@@ -0,0 +1,39 @@
+package backend
+
+import (
+ "context"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/grpc/proto"
+ "github.com/mudler/LocalAI/pkg/model"
+)
+
+func VAD(request *schema.VADRequest,
+ ctx context.Context,
+ ml *model.ModelLoader,
+ appConfig *config.ApplicationConfig,
+ modelConfig config.ModelConfig) (*schema.VADResponse, error) {
+ opts := ModelOptions(modelConfig, appConfig)
+ vadModel, err := ml.Load(opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ req := proto.VADRequest{
+ Audio: request.Audio,
+ }
+ resp, err := vadModel.VAD(ctx, &req)
+ if err != nil {
+ return nil, err
+ }
+
+ segments := []schema.VADSegment{}
+ for _, s := range resp.Segments {
+ segments = append(segments, schema.VADSegment{Start: s.Start, End: s.End})
+ }
+
+ return &schema.VADResponse{
+ Segments: segments,
+ }, nil
+}
diff --git a/core/backend/video.go b/core/backend/video.go
new file mode 100644
index 0000000000000000000000000000000000000000..666a7625226a9b44afd66e0b5cad9f2e67c3c343
--- /dev/null
+++ b/core/backend/video.go
@@ -0,0 +1,41 @@
+package backend
+
+import (
+ "github.com/mudler/LocalAI/core/config"
+
+ "github.com/mudler/LocalAI/pkg/grpc/proto"
+ model "github.com/mudler/LocalAI/pkg/model"
+)
+
+func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, endImage, dst string, numFrames, fps, seed int32, cfgScale float32, step int32, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {
+
+ opts := ModelOptions(modelConfig, appConfig)
+ inferenceModel, err := loader.Load(
+ opts...,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ fn := func() error {
+ _, err := inferenceModel.GenerateVideo(
+ appConfig.Context,
+ &proto.GenerateVideoRequest{
+ Height: height,
+ Width: width,
+ Prompt: prompt,
+ NegativePrompt: negativePrompt,
+ StartImage: startImage,
+ EndImage: endImage,
+ NumFrames: numFrames,
+ Fps: fps,
+ Seed: seed,
+ CfgScale: cfgScale,
+ Step: step,
+ Dst: dst,
+ })
+ return err
+ }
+
+ return fn, nil
+}
diff --git a/core/cli/backends.go b/core/cli/backends.go
new file mode 100644
index 0000000000000000000000000000000000000000..9877d746a4a036d3c8860f5e216890876b7c0227
--- /dev/null
+++ b/core/cli/backends.go
@@ -0,0 +1,134 @@
+package cli
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+
+ cliContext "github.com/mudler/LocalAI/core/cli/context"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/services"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/system"
+
+ "github.com/mudler/xlog"
+ "github.com/schollz/progressbar/v3"
+)
+
+type BackendsCMDFlags struct {
+ BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
+ BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"`
+ BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
+}
+
+type BackendsList struct {
+ BackendsCMDFlags `embed:""`
+}
+
+type BackendsInstall struct {
+ BackendArgs string `arg:"" optional:"" name:"backend" help:"Backend configuration URL to load"`
+ Name string `arg:"" optional:"" name:"name" help:"Name of the backend"`
+ Alias string `arg:"" optional:"" name:"alias" help:"Alias of the backend"`
+
+ BackendsCMDFlags `embed:""`
+}
+
+type BackendsUninstall struct {
+ BackendArgs []string `arg:"" name:"backends" help:"Backend names to uninstall"`
+
+ BackendsCMDFlags `embed:""`
+}
+
+type BackendsCMD struct {
+ List BackendsList `cmd:"" help:"List the backends available in your galleries" default:"withargs"`
+ Install BackendsInstall `cmd:"" help:"Install a backend from the gallery"`
+ Uninstall BackendsUninstall `cmd:"" help:"Uninstall a backend"`
+}
+
+func (bl *BackendsList) Run(ctx *cliContext.Context) error {
+ var galleries []config.Gallery
+ if err := json.Unmarshal([]byte(bl.BackendGalleries), &galleries); err != nil {
+ xlog.Error("unable to load galleries", "error", err)
+ }
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendSystemPath(bl.BackendsSystemPath),
+ system.WithBackendPath(bl.BackendsPath),
+ )
+ if err != nil {
+ return err
+ }
+
+ backends, err := gallery.AvailableBackends(galleries, systemState)
+ if err != nil {
+ return err
+ }
+ for _, backend := range backends {
+ if backend.Installed {
+ fmt.Printf(" * %s@%s (installed)\n", backend.Gallery.Name, backend.Name)
+ } else {
+ fmt.Printf(" - %s@%s\n", backend.Gallery.Name, backend.Name)
+ }
+ }
+ return nil
+}
+
+func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
+ var galleries []config.Gallery
+ if err := json.Unmarshal([]byte(bi.BackendGalleries), &galleries); err != nil {
+ xlog.Error("unable to load galleries", "error", err)
+ }
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendSystemPath(bi.BackendsSystemPath),
+ system.WithBackendPath(bi.BackendsPath),
+ )
+ if err != nil {
+ return err
+ }
+
+ progressBar := progressbar.NewOptions(
+ 1000,
+ progressbar.OptionSetDescription(fmt.Sprintf("downloading backend %s", bi.BackendArgs)),
+ progressbar.OptionShowBytes(false),
+ progressbar.OptionClearOnFinish(),
+ )
+ progressCallback := func(fileName string, current string, total string, percentage float64) {
+ v := int(percentage * 10)
+ err := progressBar.Set(v)
+ if err != nil {
+ xlog.Error("error while updating progress bar", "error", err, "filename", fileName, "value", v)
+ }
+ }
+
+ modelLoader := model.NewModelLoader(systemState)
+ err = services.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (bu *BackendsUninstall) Run(ctx *cliContext.Context) error {
+ for _, backendName := range bu.BackendArgs {
+ xlog.Info("uninstalling backend", "backend", backendName)
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendSystemPath(bu.BackendsSystemPath),
+ system.WithBackendPath(bu.BackendsPath),
+ )
+ if err != nil {
+ return err
+ }
+
+ err = gallery.DeleteBackendFromSystem(systemState, backendName)
+ if err != nil {
+ return err
+ }
+
+ fmt.Printf("Backend %s uninstalled successfully\n", backendName)
+ }
+ return nil
+}
diff --git a/core/cli/cli.go b/core/cli/cli.go
new file mode 100644
index 0000000000000000000000000000000000000000..fc850de945ce171f29fd9d9f62ca7b658a0287fa
--- /dev/null
+++ b/core/cli/cli.go
@@ -0,0 +1,21 @@
+package cli
+
+import (
+ cliContext "github.com/mudler/LocalAI/core/cli/context"
+ "github.com/mudler/LocalAI/core/cli/worker"
+)
+
+var CLI struct {
+ cliContext.Context `embed:""`
+
+ Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"`
+ Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"`
+ Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"`
+ Backends BackendsCMD `cmd:"" help:"Manage LocalAI backends and definitions"`
+ TTS TTSCMD `cmd:"" help:"Convert text to speech"`
+ SoundGeneration SoundGenerationCMD `cmd:"" help:"Generates audio files from text or audio"`
+ Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
+ Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
+ Util UtilCMD `cmd:"" help:"Utility commands"`
+ Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
+}
diff --git a/core/cli/context/context.go b/core/cli/context/context.go
new file mode 100644
index 0000000000000000000000000000000000000000..2da238d8b8a5c466db489afe9d71da587c5b4681
--- /dev/null
+++ b/core/cli/context/context.go
@@ -0,0 +1,7 @@
+package cliContext
+
+type Context struct {
+ Debug bool `env:"LOCALAI_DEBUG,DEBUG" default:"false" hidden:"" help:"DEPRECATED, use --log-level=debug instead. Enable debug logging"`
+ LogLevel *string `env:"LOCALAI_LOG_LEVEL" enum:"error,warn,info,debug,trace" help:"Set the level of logs to output [${enum}]"`
+ LogFormat *string `env:"LOCALAI_LOG_FORMAT" default:"default" enum:"default,text,json" help:"Set the format of logs to output [${enum}]"`
+}
diff --git a/core/cli/explorer.go b/core/cli/explorer.go
new file mode 100644
index 0000000000000000000000000000000000000000..d520dac212f41ce6096e0b0d2d21d9dd35a25750
--- /dev/null
+++ b/core/cli/explorer.go
@@ -0,0 +1,59 @@
+package cli
+
+import (
+ "context"
+ "time"
+
+ cliContext "github.com/mudler/LocalAI/core/cli/context"
+ "github.com/mudler/LocalAI/core/explorer"
+ "github.com/mudler/LocalAI/core/http"
+ "github.com/mudler/LocalAI/pkg/signals"
+ "github.com/mudler/xlog"
+)
+
+type ExplorerCMD struct {
+ Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
+ PoolDatabase string `env:"LOCALAI_POOL_DATABASE,POOL_DATABASE" default:"explorer.json" help:"Path to the pool database" group:"api"`
+ ConnectionTimeout string `env:"LOCALAI_CONNECTION_TIMEOUT,CONNECTION_TIMEOUT" default:"2m" help:"Connection timeout for the explorer" group:"api"`
+ ConnectionErrorThreshold int `env:"LOCALAI_CONNECTION_ERROR_THRESHOLD,CONNECTION_ERROR_THRESHOLD" default:"3" help:"Connection failure threshold for the explorer" group:"api"`
+
+ WithSync bool `env:"LOCALAI_WITH_SYNC,WITH_SYNC" default:"false" help:"Enable sync with the network" group:"api"`
+ OnlySync bool `env:"LOCALAI_ONLY_SYNC,ONLY_SYNC" default:"false" help:"Only sync with the network" group:"api"`
+}
+
+func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {
+
+ db, err := explorer.NewDatabase(e.PoolDatabase)
+ if err != nil {
+ return err
+ }
+
+ dur, err := time.ParseDuration(e.ConnectionTimeout)
+ if err != nil {
+ return err
+ }
+
+ if e.WithSync {
+ ds := explorer.NewDiscoveryServer(db, dur, e.ConnectionErrorThreshold)
+ go ds.Start(context.Background(), true)
+ }
+
+ if e.OnlySync {
+ ds := explorer.NewDiscoveryServer(db, dur, e.ConnectionErrorThreshold)
+ ctx := context.Background()
+
+ return ds.Start(ctx, false)
+ }
+
+ appHTTP := http.Explorer(db)
+
+ signals.RegisterGracefulTerminationHandler(func() {
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ if err := appHTTP.Shutdown(ctx); err != nil {
+ xlog.Error("error during shutdown", "error", err)
+ }
+ })
+
+ return appHTTP.Start(e.Address)
+}
diff --git a/core/cli/federated.go b/core/cli/federated.go
new file mode 100644
index 0000000000000000000000000000000000000000..ceea5a9e43b4e28e5f69dbcdd7ef3c8f9f07bc84
--- /dev/null
+++ b/core/cli/federated.go
@@ -0,0 +1,30 @@
+package cli
+
+import (
+ "context"
+
+ cliContext "github.com/mudler/LocalAI/core/cli/context"
+ "github.com/mudler/LocalAI/core/p2p"
+ "github.com/mudler/LocalAI/pkg/signals"
+)
+
+type FederatedCLI struct {
+ Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
+ Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2ptoken" help:"Token for P2P mode (optional)" group:"p2p"`
+ RandomWorker bool `env:"LOCALAI_RANDOM_WORKER,RANDOM_WORKER" default:"false" help:"Select a random worker from the pool" group:"p2p"`
+ Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances." group:"p2p"`
+ TargetWorker string `env:"LOCALAI_TARGET_WORKER,TARGET_WORKER" help:"Target worker to run the federated server on" group:"p2p"`
+}
+
+func (f *FederatedCLI) Run(ctx *cliContext.Context) error {
+
+ fs := p2p.NewFederatedServer(f.Address, p2p.NetworkID(f.Peer2PeerNetworkID, p2p.FederatedID), f.Peer2PeerToken, !f.RandomWorker, f.TargetWorker)
+
+ c, cancel := context.WithCancel(context.Background())
+
+ signals.RegisterGracefulTerminationHandler(func() {
+ cancel()
+ })
+
+ return fs.Start(c)
+}
diff --git a/core/cli/models.go b/core/cli/models.go
new file mode 100644
index 0000000000000000000000000000000000000000..3006922c8c87894662b6d5c5ba600e07c4691f3c
--- /dev/null
+++ b/core/cli/models.go
@@ -0,0 +1,144 @@
+package cli
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+
+ cliContext "github.com/mudler/LocalAI/core/cli/context"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/services"
+
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/startup"
+ "github.com/mudler/LocalAI/pkg/downloader"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/system"
+ "github.com/mudler/xlog"
+ "github.com/schollz/progressbar/v3"
+)
+
+type ModelsCMDFlags struct {
+ Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
+ BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
+ ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
+ BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"`
+}
+
+type ModelsList struct {
+ ModelsCMDFlags `embed:""`
+}
+
+type ModelsInstall struct {
+ DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
+ AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES" help:"If true, automatically loads backend galleries" group:"backends" default:"true"`
+ ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"`
+
+ ModelsCMDFlags `embed:""`
+}
+
+type ModelsCMD struct {
+ List ModelsList `cmd:"" help:"List the models available in your galleries" default:"withargs"`
+ Install ModelsInstall `cmd:"" help:"Install a model from the gallery"`
+}
+
+func (ml *ModelsList) Run(ctx *cliContext.Context) error {
+ var galleries []config.Gallery
+ if err := json.Unmarshal([]byte(ml.Galleries), &galleries); err != nil {
+ xlog.Error("unable to load galleries", "error", err)
+ }
+
+ systemState, err := system.GetSystemState(
+ system.WithModelPath(ml.ModelsPath),
+ system.WithBackendPath(ml.BackendsPath),
+ )
+ if err != nil {
+ return err
+ }
+ models, err := gallery.AvailableGalleryModels(galleries, systemState)
+ if err != nil {
+ return err
+ }
+ for _, model := range models {
+ if model.Installed {
+ fmt.Printf(" * %s@%s (installed)\n", model.Gallery.Name, model.Name)
+ } else {
+ fmt.Printf(" - %s@%s\n", model.Gallery.Name, model.Name)
+ }
+ }
+ return nil
+}
+
+func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
+
+ systemState, err := system.GetSystemState(
+ system.WithModelPath(mi.ModelsPath),
+ system.WithBackendPath(mi.BackendsPath),
+ )
+ if err != nil {
+ return err
+ }
+
+ galleryService := services.NewGalleryService(&config.ApplicationConfig{
+ SystemState: systemState,
+ }, model.NewModelLoader(systemState))
+ err = galleryService.Start(context.Background(), config.NewModelConfigLoader(mi.ModelsPath), systemState)
+ if err != nil {
+ return err
+ }
+
+ var galleries []config.Gallery
+ if err := json.Unmarshal([]byte(mi.Galleries), &galleries); err != nil {
+ xlog.Error("unable to load galleries", "error", err)
+ }
+
+ var backendGalleries []config.Gallery
+ if err := json.Unmarshal([]byte(mi.BackendGalleries), &backendGalleries); err != nil {
+ xlog.Error("unable to load backend galleries", "error", err)
+ }
+
+ for _, modelName := range mi.ModelArgs {
+
+ progressBar := progressbar.NewOptions(
+ 1000,
+ progressbar.OptionSetDescription(fmt.Sprintf("downloading model %s", modelName)),
+ progressbar.OptionShowBytes(false),
+ progressbar.OptionClearOnFinish(),
+ )
+ progressCallback := func(fileName string, current string, total string, percentage float64) {
+ v := int(percentage * 10)
+ err := progressBar.Set(v)
+ if err != nil {
+ xlog.Error("error while updating progress bar", "error", err, "filename", fileName, "value", v)
+ }
+ }
+ //startup.InstallModels()
+ models, err := gallery.AvailableGalleryModels(galleries, systemState)
+ if err != nil {
+ return err
+ }
+
+ modelURI := downloader.URI(modelName)
+
+ if !modelURI.LooksLikeOCI() {
+ model := gallery.FindGalleryElement(models, modelName)
+ if model == nil {
+ xlog.Error("model not found", "model", modelName)
+ return err
+ }
+
+ err = gallery.SafetyScanGalleryModel(model)
+ if err != nil && !errors.Is(err, downloader.ErrNonHuggingFaceFile) {
+ return err
+ }
+ }
+
+ modelLoader := model.NewModelLoader(systemState)
+ err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/core/cli/run.go b/core/cli/run.go
new file mode 100644
index 0000000000000000000000000000000000000000..517052b9c52ab3ec38fce5576368c7fa909b67f0
--- /dev/null
+++ b/core/cli/run.go
@@ -0,0 +1,298 @@
+package cli
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/mudler/LocalAI/core/application"
+ cliContext "github.com/mudler/LocalAI/core/cli/context"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http"
+ "github.com/mudler/LocalAI/core/p2p"
+ "github.com/mudler/LocalAI/internal"
+ "github.com/mudler/LocalAI/pkg/signals"
+ "github.com/mudler/LocalAI/pkg/system"
+ "github.com/mudler/xlog"
+)
+
+type RunCMD struct {
+ ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"`
+
+ ExternalBackends []string `env:"LOCALAI_EXTERNAL_BACKENDS,EXTERNAL_BACKENDS" help:"A list of external backends to load from gallery on boot" group:"backends"`
+ BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
+ BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
+ ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
+ GeneratedContentPath string `env:"LOCALAI_GENERATED_CONTENT_PATH,GENERATED_CONTENT_PATH" type:"path" default:"/tmp/generated/content" help:"Location for generated content (e.g. images, audio, videos)" group:"storage"`
+ UploadPath string `env:"LOCALAI_UPLOAD_PATH,UPLOAD_PATH" type:"path" default:"/tmp/localai/upload" help:"Path to store uploads from files api" group:"storage"`
+ LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"`
+ LocalaiConfigDirPollInterval time.Duration `env:"LOCALAI_CONFIG_DIR_POLL_INTERVAL" help:"Typically the config path picks up changes automatically, but if your system has broken fsnotify events, set this to an interval to poll the LocalAI Config Dir (example: 1m)" group:"storage"`
+ // The alias on this option is there to preserve functionality with the old `--config-file` parameter
+ ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"`
+ BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
+ Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
+ AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"`
+ AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"`
+ PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
+ Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"`
+ PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"`
+
+ F16 bool `name:"f16" env:"LOCALAI_F16,F16" help:"Enable GPU acceleration" group:"performance"`
+ Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"`
+ ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" help:"Default context size for models" group:"performance"`
+
+ Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
+ CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
+ CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
+ CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"`
+ UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
+ APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
+ DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disables the web user interface. When set to true, the server will only expose API endpoints without serving the web interface" group:"api"`
+ DisableRuntimeSettings bool `env:"LOCALAI_DISABLE_RUNTIME_SETTINGS,DISABLE_RUNTIME_SETTINGS" default:"false" help:"Disables the runtime settings. When set to true, the server will not load the runtime settings from the runtime_settings.json file" group:"api"`
+ DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
+ OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
+ UseSubtleKeyComparison bool `env:"LOCALAI_SUBTLE_KEY_COMPARISON" default:"false" help:"If true, API Key validation comparisons will be performed using constant-time comparisons rather than simple equality. This trades off performance on each request for resiliancy against timing attacks." group:"hardening"`
+ DisableApiKeyRequirementForHttpGet bool `env:"LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET" default:"false" help:"If true, a valid API key is not required to issue GET requests to portions of the web ui. This should only be enabled in secure testing environments" group:"hardening"`
+ DisableMetricsEndpoint bool `env:"LOCALAI_DISABLE_METRICS_ENDPOINT,DISABLE_METRICS_ENDPOINT" default:"false" help:"Disable the /metrics endpoint" group:"api"`
+ HttpGetExemptedEndpoints []string `env:"LOCALAI_HTTP_GET_EXEMPTED_ENDPOINTS" default:"^/$,^/browse/?$,^/talk/?$,^/p2p/?$,^/chat/?$,^/image/?$,^/text2image/?$,^/tts/?$,^/static/.*$,^/swagger.*$" help:"If LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET is overriden to true, this is the list of endpoints to exempt. Only adjust this in case of a security incident or as a result of a personal security posture review" group:"hardening"`
+ Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"`
+ Peer2PeerDHTInterval int `env:"LOCALAI_P2P_DHT_INTERVAL,P2P_DHT_INTERVAL" default:"360" name:"p2p-dht-interval" help:"Interval for DHT refresh (used during token generation)" group:"p2p"`
+ Peer2PeerOTPInterval int `env:"LOCALAI_P2P_OTP_INTERVAL,P2P_OTP_INTERVAL" default:"9000" name:"p2p-otp-interval" help:"Interval for OTP refresh (used during token generation)" group:"p2p"`
+ Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2ptoken" help:"Token for P2P mode (optional)" group:"p2p"`
+ Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"`
+ ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"`
+ SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time (deprecated: use --max-active-backends=1 instead)" group:"backends"`
+ MaxActiveBackends int `env:"LOCALAI_MAX_ACTIVE_BACKENDS,MAX_ACTIVE_BACKENDS" default:"0" help:"Maximum number of backends to keep loaded at once (0 = unlimited, 1 = single backend mode). Least recently used backends are evicted when limit is reached" group:"backends"`
+ PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"`
+ ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"`
+ EnableWatchdogIdle bool `env:"LOCALAI_WATCHDOG_IDLE,WATCHDOG_IDLE" default:"false" help:"Enable watchdog for stopping backends that are idle longer than the watchdog-idle-timeout" group:"backends"`
+ WatchdogIdleTimeout string `env:"LOCALAI_WATCHDOG_IDLE_TIMEOUT,WATCHDOG_IDLE_TIMEOUT" default:"15m" help:"Threshold beyond which an idle backend should be stopped" group:"backends"`
+ EnableWatchdogBusy bool `env:"LOCALAI_WATCHDOG_BUSY,WATCHDOG_BUSY" default:"false" help:"Enable watchdog for stopping backends that are busy longer than the watchdog-busy-timeout" group:"backends"`
+ WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"`
+ EnableMemoryReclaimer bool `env:"LOCALAI_MEMORY_RECLAIMER,MEMORY_RECLAIMER,LOCALAI_GPU_RECLAIMER,GPU_RECLAIMER" default:"false" help:"Enable memory threshold monitoring to auto-evict backends when memory usage exceeds threshold (uses GPU VRAM if available, otherwise RAM)" group:"backends"`
+ MemoryReclaimerThreshold float64 `env:"LOCALAI_MEMORY_RECLAIMER_THRESHOLD,MEMORY_RECLAIMER_THRESHOLD,LOCALAI_GPU_RECLAIMER_THRESHOLD,GPU_RECLAIMER_THRESHOLD" default:"0.95" help:"Memory usage threshold (0.0-1.0) that triggers backend eviction (default 0.95 = 95%%)" group:"backends"`
+ ForceEvictionWhenBusy bool `env:"LOCALAI_FORCE_EVICTION_WHEN_BUSY,FORCE_EVICTION_WHEN_BUSY" default:"false" help:"Force eviction even when models have active API calls (default: false for safety)" group:"backends"`
+ LRUEvictionMaxRetries int `env:"LOCALAI_LRU_EVICTION_MAX_RETRIES,LRU_EVICTION_MAX_RETRIES" default:"30" help:"Maximum number of retries when waiting for busy models to become idle before eviction (default: 30)" group:"backends"`
+ LRUEvictionRetryInterval string `env:"LOCALAI_LRU_EVICTION_RETRY_INTERVAL,LRU_EVICTION_RETRY_INTERVAL" default:"1s" help:"Interval between retries when waiting for busy models to become idle (e.g., 1s, 2s) (default: 1s)" group:"backends"`
+ Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"`
+ DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
+ MachineTag string `env:"LOCALAI_MACHINE_TAG,MACHINE_TAG" help:"Add Machine-Tag header to each response which is useful to track the machine in the P2P network" group:"api"`
+ LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
+ EnableTracing bool `env:"LOCALAI_ENABLE_TRACING,ENABLE_TRACING" help:"Enable API tracing" group:"api"`
+ TracingMaxItems int `env:"LOCALAI_TRACING_MAX_ITEMS" default:"1024" help:"Maximum number of traces to keep" group:"api"`
+ AgentJobRetentionDays int `env:"LOCALAI_AGENT_JOB_RETENTION_DAYS,AGENT_JOB_RETENTION_DAYS" default:"30" help:"Number of days to keep agent job history (default: 30)" group:"api"`
+
+ Version bool
+}
+
+func (r *RunCMD) Run(ctx *cliContext.Context) error {
+ if r.Version {
+ fmt.Println(internal.Version)
+ return nil
+ }
+
+ os.MkdirAll(r.BackendsPath, 0750)
+ os.MkdirAll(r.ModelsPath, 0750)
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendSystemPath(r.BackendsSystemPath),
+ system.WithModelPath(r.ModelsPath),
+ system.WithBackendPath(r.BackendsPath),
+ )
+ if err != nil {
+ return err
+ }
+
+ opts := []config.AppOption{
+ config.WithContext(context.Background()),
+ config.WithConfigFile(r.ModelsConfigFile),
+ config.WithJSONStringPreload(r.PreloadModels),
+ config.WithYAMLConfigPreload(r.PreloadModelsConfig),
+ config.WithSystemState(systemState),
+ config.WithContextSize(r.ContextSize),
+ config.WithDebug(ctx.Debug || (ctx.LogLevel != nil && *ctx.LogLevel == "debug")),
+ config.WithGeneratedContentDir(r.GeneratedContentPath),
+ config.WithUploadDir(r.UploadPath),
+ config.WithDynamicConfigDir(r.LocalaiConfigDir),
+ config.WithDynamicConfigDirPollInterval(r.LocalaiConfigDirPollInterval),
+ config.WithF16(r.F16),
+ config.WithStringGalleries(r.Galleries),
+ config.WithBackendGalleries(r.BackendGalleries),
+ config.WithCors(r.CORS),
+ config.WithCorsAllowOrigins(r.CORSAllowOrigins),
+ config.WithCsrf(r.CSRF),
+ config.WithThreads(r.Threads),
+ config.WithUploadLimitMB(r.UploadLimit),
+ config.WithApiKeys(r.APIKeys),
+ config.WithModelsURL(append(r.Models, r.ModelArgs...)...),
+ config.WithExternalBackends(r.ExternalBackends...),
+ config.WithOpaqueErrors(r.OpaqueErrors),
+ config.WithEnforcedPredownloadScans(!r.DisablePredownloadScan),
+ config.WithSubtleKeyComparison(r.UseSubtleKeyComparison),
+ config.WithDisableApiKeyRequirementForHttpGet(r.DisableApiKeyRequirementForHttpGet),
+ config.WithHttpGetExemptedEndpoints(r.HttpGetExemptedEndpoints),
+ config.WithP2PNetworkID(r.Peer2PeerNetworkID),
+ config.WithLoadToMemory(r.LoadToMemory),
+ config.WithMachineTag(r.MachineTag),
+ config.WithAPIAddress(r.Address),
+ config.WithAgentJobRetentionDays(r.AgentJobRetentionDays),
+ config.WithTunnelCallback(func(tunnels []string) {
+ tunnelEnvVar := strings.Join(tunnels, ",")
+ // TODO: this is very specific to llama.cpp, we should have a more generic way to set the environment variable
+ os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar)
+ xlog.Debug("setting LLAMACPP_GRPC_SERVERS", "value", tunnelEnvVar)
+ }),
+ }
+
+ if r.DisableMetricsEndpoint {
+ opts = append(opts, config.DisableMetricsEndpoint)
+ }
+
+ if r.DisableRuntimeSettings {
+ opts = append(opts, config.DisableRuntimeSettings)
+ }
+
+ if r.EnableTracing {
+ opts = append(opts, config.EnableTracing)
+ }
+
+ if r.EnableTracing {
+ opts = append(opts, config.EnableTracing)
+ }
+ opts = append(opts, config.WithTracingMaxItems(r.TracingMaxItems))
+
+ token := ""
+ if r.Peer2Peer || r.Peer2PeerToken != "" {
+ xlog.Info("P2P mode enabled")
+ token = r.Peer2PeerToken
+ if token == "" {
+ // IF no token is provided, and p2p is enabled,
+ // we generate one and wait for the user to pick up the token (this is for interactive)
+ xlog.Info("No token provided, generating one")
+ token = p2p.GenerateToken(r.Peer2PeerDHTInterval, r.Peer2PeerOTPInterval)
+ xlog.Info("Generated Token:")
+ fmt.Println(token)
+
+ xlog.Info("To use the token, you can run the following command in another node or terminal:")
+ fmt.Printf("export TOKEN=\"%s\"\nlocal-ai worker p2p-llama-cpp-rpc\n", token)
+ }
+ opts = append(opts, config.WithP2PToken(token))
+ }
+
+ if r.Federated {
+ opts = append(opts, config.EnableFederated)
+ }
+
+ idleWatchDog := r.EnableWatchdogIdle
+ busyWatchDog := r.EnableWatchdogBusy
+
+ if r.DisableWebUI {
+ opts = append(opts, config.DisableWebUI)
+ }
+
+ if r.DisableGalleryEndpoint {
+ opts = append(opts, config.DisableGalleryEndpoint)
+ }
+
+ if idleWatchDog || busyWatchDog {
+ opts = append(opts, config.EnableWatchDog)
+ if idleWatchDog {
+ opts = append(opts, config.EnableWatchDogIdleCheck)
+ dur, err := time.ParseDuration(r.WatchdogIdleTimeout)
+ if err != nil {
+ return err
+ }
+ opts = append(opts, config.SetWatchDogIdleTimeout(dur))
+ }
+ if busyWatchDog {
+ opts = append(opts, config.EnableWatchDogBusyCheck)
+ dur, err := time.ParseDuration(r.WatchdogBusyTimeout)
+ if err != nil {
+ return err
+ }
+ opts = append(opts, config.SetWatchDogBusyTimeout(dur))
+ }
+ }
+
+ // Handle memory reclaimer (uses GPU VRAM if available, otherwise RAM)
+ if r.EnableMemoryReclaimer {
+ opts = append(opts, config.WithMemoryReclaimer(true, r.MemoryReclaimerThreshold))
+ }
+
+ if r.ParallelRequests {
+ opts = append(opts, config.EnableParallelBackendRequests)
+ }
+
+ // Handle max active backends (LRU eviction)
+ // MaxActiveBackends takes precedence over SingleActiveBackend
+ if r.MaxActiveBackends > 0 {
+ opts = append(opts, config.SetMaxActiveBackends(r.MaxActiveBackends))
+ } else if r.SingleActiveBackend {
+ // Backward compatibility: --single-active-backend is equivalent to --max-active-backends=1
+ opts = append(opts, config.EnableSingleBackend)
+ }
+
+ // Handle LRU eviction settings
+ if r.ForceEvictionWhenBusy {
+ opts = append(opts, config.WithForceEvictionWhenBusy(true))
+ }
+ if r.LRUEvictionMaxRetries > 0 {
+ opts = append(opts, config.WithLRUEvictionMaxRetries(r.LRUEvictionMaxRetries))
+ }
+ if r.LRUEvictionRetryInterval != "" {
+ dur, err := time.ParseDuration(r.LRUEvictionRetryInterval)
+ if err != nil {
+ return fmt.Errorf("invalid LRU eviction retry interval: %w", err)
+ }
+ opts = append(opts, config.WithLRUEvictionRetryInterval(dur))
+ }
+
+ // split ":" to get backend name and the uri
+ for _, v := range r.ExternalGRPCBackends {
+ backend := v[:strings.IndexByte(v, ':')]
+ uri := v[strings.IndexByte(v, ':')+1:]
+ opts = append(opts, config.WithExternalBackend(backend, uri))
+ }
+
+ if r.AutoloadGalleries {
+ opts = append(opts, config.EnableGalleriesAutoload)
+ }
+
+ if r.AutoloadBackendGalleries {
+ opts = append(opts, config.EnableBackendGalleriesAutoload)
+ }
+
+ if r.PreloadBackendOnly {
+ _, err := application.New(opts...)
+ return err
+ }
+
+ app, err := application.New(opts...)
+ if err != nil {
+ return fmt.Errorf("failed basic startup tasks with error %s", err.Error())
+ }
+
+ appHTTP, err := http.API(app)
+ if err != nil {
+ xlog.Error("error during HTTP App construction", "error", err)
+ return err
+ }
+
+ xlog.Info("LocalAI is started and running", "address", r.Address)
+
+ if token != "" {
+ if err := app.StartP2P(); err != nil {
+ return err
+ }
+ }
+
+ signals.RegisterGracefulTerminationHandler(func() {
+ if err := app.ModelLoader().StopAllGRPC(); err != nil {
+ xlog.Error("error while stopping all grpc backends", "error", err)
+ }
+ })
+
+ return appHTTP.Start(r.Address)
+}
diff --git a/core/cli/soundgeneration.go b/core/cli/soundgeneration.go
new file mode 100644
index 0000000000000000000000000000000000000000..5ddf96444fd504f94280a02ec22399663da44ae1
--- /dev/null
+++ b/core/cli/soundgeneration.go
@@ -0,0 +1,117 @@
+package cli
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+
+ "github.com/mudler/LocalAI/core/backend"
+ cliContext "github.com/mudler/LocalAI/core/cli/context"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/system"
+ "github.com/mudler/xlog"
+)
+
+type SoundGenerationCMD struct {
+ Text []string `arg:""`
+
+ Backend string `short:"b" required:"" help:"Backend to run the SoundGeneration model"`
+ Model string `short:"m" required:"" help:"Model name to run the SoundGeneration"`
+ Duration string `short:"d" help:"If specified, the length of audio to generate in seconds"`
+ Temperature string `short:"t" help:"If specified, the temperature of the generation"`
+ InputFile string `short:"i" help:"If specified, the input file to condition generation upon"`
+ InputFileSampleDivisor string `short:"f" help:"If InputFile and this divisor is specified, the first portion of the sample file will be used"`
+ DoSample bool `short:"s" default:"true" help:"Enables sampling from the model. Better quality at the cost of speed. Defaults to enabled."`
+ OutputFile string `short:"o" type:"path" help:"The path to write the output wav file"`
+ ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
+ ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"`
+}
+
+func parseToFloat32Ptr(input string) *float32 {
+ f, err := strconv.ParseFloat(input, 32)
+ if err != nil {
+ return nil
+ }
+ f2 := float32(f)
+ return &f2
+}
+
+func parseToInt32Ptr(input string) *int32 {
+ i, err := strconv.ParseInt(input, 10, 32)
+ if err != nil {
+ return nil
+ }
+ i2 := int32(i)
+ return &i2
+}
+
+func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
+ outputFile := t.OutputFile
+ outputDir := os.TempDir()
+ if outputFile != "" {
+ outputDir = filepath.Dir(outputFile)
+ }
+ text := strings.Join(t.Text, " ")
+
+ systemState, err := system.GetSystemState(
+ system.WithModelPath(t.ModelsPath),
+ )
+ if err != nil {
+ return err
+ }
+
+ externalBackends := make(map[string]string)
+ // split ":" to get backend name and the uri
+ for _, v := range t.ExternalGRPCBackends {
+ backend := v[:strings.IndexByte(v, ':')]
+ uri := v[strings.IndexByte(v, ':')+1:]
+ externalBackends[backend] = uri
+ fmt.Printf("TMP externalBackends[%q]=%q\n\n", backend, uri)
+ }
+
+ opts := &config.ApplicationConfig{
+ SystemState: systemState,
+ Context: context.Background(),
+ GeneratedContentDir: outputDir,
+ ExternalGRPCBackends: externalBackends,
+ }
+ ml := model.NewModelLoader(systemState)
+
+ defer func() {
+ err := ml.StopAllGRPC()
+ if err != nil {
+ xlog.Error("unable to stop all grpc processes", "error", err)
+ }
+ }()
+
+ options := config.ModelConfig{}
+ options.SetDefaults()
+ options.Backend = t.Backend
+ options.Model = t.Model
+
+ var inputFile *string
+ if t.InputFile != "" {
+ inputFile = &t.InputFile
+ }
+
+ filePath, _, err := backend.SoundGeneration(text,
+ parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
+ inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options)
+
+ if err != nil {
+ return err
+ }
+ if outputFile != "" {
+ if err := os.Rename(filePath, outputFile); err != nil {
+ return err
+ }
+ fmt.Printf("Generate file %s\n", outputFile)
+ } else {
+ fmt.Printf("Generate file %s\n", filePath)
+ }
+ return nil
+}
diff --git a/core/cli/transcript.go b/core/cli/transcript.go
new file mode 100644
index 0000000000000000000000000000000000000000..07da1989388aec601b2904a062f404a28aa38002
--- /dev/null
+++ b/core/cli/transcript.go
@@ -0,0 +1,69 @@
+package cli
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/mudler/LocalAI/core/backend"
+ cliContext "github.com/mudler/LocalAI/core/cli/context"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/system"
+ "github.com/mudler/xlog"
+)
+
+type TranscriptCMD struct {
+ Filename string `arg:""`
+
+ Backend string `short:"b" default:"whisper" help:"Backend to run the transcription model"`
+ Model string `short:"m" required:"" help:"Model name to run the TTS"`
+ Language string `short:"l" help:"Language of the audio file"`
+ Translate bool `short:"c" help:"Translate the transcription to english"`
+ Diarize bool `short:"d" help:"Mark speaker turns"`
+ Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
+ ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
+ Prompt string `short:"p" help:"Previous transcribed text or words that hint at what the model should expect"`
+}
+
+func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
+ systemState, err := system.GetSystemState(
+ system.WithModelPath(t.ModelsPath),
+ )
+ if err != nil {
+ return err
+ }
+ opts := &config.ApplicationConfig{
+ SystemState: systemState,
+ Context: context.Background(),
+ }
+
+ cl := config.NewModelConfigLoader(t.ModelsPath)
+ ml := model.NewModelLoader(systemState)
+ if err := cl.LoadModelConfigsFromPath(t.ModelsPath); err != nil {
+ return err
+ }
+
+ c, exists := cl.GetModelConfig(t.Model)
+ if !exists {
+ return errors.New("model not found")
+ }
+
+ c.Threads = &t.Threads
+
+ defer func() {
+ err := ml.StopAllGRPC()
+ if err != nil {
+ xlog.Error("unable to stop all grpc processes", "error", err)
+ }
+ }()
+
+ tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, t.Diarize, t.Prompt, ml, c, opts)
+ if err != nil {
+ return err
+ }
+ for _, segment := range tr.Segments {
+ fmt.Println(segment.Start.String(), "-", segment.Text)
+ }
+ return nil
+}
diff --git a/core/cli/tts.go b/core/cli/tts.go
new file mode 100644
index 0000000000000000000000000000000000000000..72d4ee24b84b9fa5868459f7697338f2b54f9046
--- /dev/null
+++ b/core/cli/tts.go
@@ -0,0 +1,78 @@
+package cli
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/mudler/LocalAI/core/backend"
+ cliContext "github.com/mudler/LocalAI/core/cli/context"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/system"
+ "github.com/mudler/xlog"
+)
+
+type TTSCMD struct {
+ Text []string `arg:""`
+
+ Backend string `short:"b" default:"piper" help:"Backend to run the TTS model"`
+ Model string `short:"m" required:"" help:"Model name to run the TTS"`
+ Voice string `short:"v" help:"Voice name to run the TTS"`
+ Language string `short:"l" help:"Language to use with the TTS"`
+ OutputFile string `short:"o" type:"path" help:"The path to write the output wav file"`
+ ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
+}
+
+func (t *TTSCMD) Run(ctx *cliContext.Context) error {
+ outputFile := t.OutputFile
+ outputDir := os.TempDir()
+ if outputFile != "" {
+ outputDir = filepath.Dir(outputFile)
+ }
+
+ text := strings.Join(t.Text, " ")
+
+ systemState, err := system.GetSystemState(
+ system.WithModelPath(t.ModelsPath),
+ )
+ if err != nil {
+ return err
+ }
+
+ opts := &config.ApplicationConfig{
+ SystemState: systemState,
+ Context: context.Background(),
+ GeneratedContentDir: outputDir,
+ }
+
+ ml := model.NewModelLoader(systemState)
+
+ defer func() {
+ err := ml.StopAllGRPC()
+ if err != nil {
+ xlog.Error("unable to stop all grpc processes", "error", err)
+ }
+ }()
+
+ options := config.ModelConfig{}
+ options.SetDefaults()
+ options.Backend = t.Backend
+ options.Model = t.Model
+
+ filePath, _, err := backend.ModelTTS(text, t.Voice, t.Language, ml, opts, options)
+ if err != nil {
+ return err
+ }
+ if outputFile != "" {
+ if err := os.Rename(filePath, outputFile); err != nil {
+ return err
+ }
+ fmt.Printf("Generate file %s\n", outputFile)
+ } else {
+ fmt.Printf("Generate file %s\n", filePath)
+ }
+ return nil
+}
diff --git a/core/cli/util.go b/core/cli/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..b002e254e78951b436d4ccf2c6197d099dc49a95
--- /dev/null
+++ b/core/cli/util.go
@@ -0,0 +1,175 @@
+package cli
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/mholt/archiver/v3"
+ "github.com/mudler/xlog"
+
+ gguf "github.com/gpustack/gguf-parser-go"
+ cliContext "github.com/mudler/LocalAI/core/cli/context"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/pkg/downloader"
+ "github.com/mudler/LocalAI/pkg/oci"
+ "github.com/mudler/LocalAI/pkg/system"
+)
+
+type UtilCMD struct {
+ GGUFInfo GGUFInfoCMD `cmd:"" name:"gguf-info" help:"Get information about a GGUF file"`
+ CreateOCIImage CreateOCIImageCMD `cmd:"" name:"create-oci-image" help:"Create an OCI image from a file or a directory"`
+ HFScan HFScanCMD `cmd:"" name:"hf-scan" help:"Checks installed models for known security issues. WARNING: this is a best-effort feature and may not catch everything!"`
+ UsecaseHeuristic UsecaseHeuristicCMD `cmd:"" name:"usecase-heuristic" help:"Checks a specific model config and prints what usecase LocalAI will offer for it."`
+}
+
+type GGUFInfoCMD struct {
+ Args []string `arg:"" optional:"" name:"args" help:"Arguments to pass to the utility command"`
+ Header bool `optional:"" default:"false" name:"header" help:"Show header information"`
+}
+
+type HFScanCMD struct {
+ ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
+ Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
+ ToScan []string `arg:""`
+}
+
+type UsecaseHeuristicCMD struct {
+ ConfigName string `name:"The config file to check"`
+ ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
+}
+
+type CreateOCIImageCMD struct {
+ Input []string `arg:"" help:"Input file or directory to create an OCI image from"`
+ Output string `default:"image.tar" help:"Output OCI image name"`
+ ImageName string `default:"localai" help:"Image name"`
+ Platform string `default:"linux/amd64" help:"Platform of the image"`
+}
+
+func (u *CreateOCIImageCMD) Run(ctx *cliContext.Context) error {
+ xlog.Info("Creating OCI image from input")
+
+ dir, err := os.MkdirTemp("", "localai")
+ if err != nil {
+ return err
+ }
+ defer os.RemoveAll(dir)
+ err = archiver.Archive(u.Input, filepath.Join(dir, "archive.tar"))
+ if err != nil {
+ return err
+ }
+ xlog.Info("Creating OCI image", "output", u.Output, "input", u.Input)
+
+ platform := strings.Split(u.Platform, "/")
+ if len(platform) != 2 {
+ return fmt.Errorf("invalid platform: %s", u.Platform)
+ }
+
+ return oci.CreateTar(filepath.Join(dir, "archive.tar"), u.Output, u.ImageName, platform[1], platform[0])
+}
+
+func (u *GGUFInfoCMD) Run(ctx *cliContext.Context) error {
+ if len(u.Args) == 0 {
+ return fmt.Errorf("no GGUF file provided")
+ }
+ // We try to guess only if we don't have a template defined already
+ f, err := gguf.ParseGGUFFile(u.Args[0])
+ if err != nil {
+ // Only valid for gguf files
+ xlog.Error("guessDefaultsFromFile: not a GGUF file")
+ return err
+ }
+
+ xlog.Info("GGUF file loaded", "file", u.Args[0], "eosTokenID", f.Tokenizer().EOSTokenID, "bosTokenID", f.Tokenizer().BOSTokenID, "modelName", f.Metadata().Name, "architecture", f.Architecture().Architecture)
+
+ xlog.Info("Tokenizer", "tokenizer", fmt.Sprintf("%+v", f.Tokenizer()))
+ xlog.Info("Architecture", "architecture", fmt.Sprintf("%+v", f.Architecture()))
+
+ v, exists := f.Header.MetadataKV.Get("tokenizer.chat_template")
+ if exists {
+ xlog.Info("chat_template", "template", v.ValueString())
+ }
+
+ if u.Header {
+ for _, metadata := range f.Header.MetadataKV {
+ xlog.Info("metadata", "key", metadata.Key, "value", metadata.Value)
+ }
+ // log.Info().Any("header", fmt.Sprintf("%+v", f.Header)).Msg("Header")
+ }
+
+ return nil
+}
+
+func (hfscmd *HFScanCMD) Run(ctx *cliContext.Context) error {
+
+ systemState, err := system.GetSystemState(
+ system.WithModelPath(hfscmd.ModelsPath),
+ )
+ if err != nil {
+ return err
+ }
+
+ xlog.Info("LocalAI Security Scanner - This is BEST EFFORT functionality! Currently limited to huggingface models!")
+ if len(hfscmd.ToScan) == 0 {
+ xlog.Info("Checking all installed models against galleries")
+ var galleries []config.Gallery
+ if err := json.Unmarshal([]byte(hfscmd.Galleries), &galleries); err != nil {
+ xlog.Error("unable to load galleries", "error", err)
+ }
+
+ err := gallery.SafetyScanGalleryModels(galleries, systemState)
+ if err == nil {
+ xlog.Info("No security warnings were detected for your installed models. Please note that this is a BEST EFFORT tool, and all issues may not be detected.")
+ } else {
+ xlog.Error("! WARNING ! A known-vulnerable model is installed!", "error", err)
+ }
+ return err
+ } else {
+ var errs error = nil
+ for _, uri := range hfscmd.ToScan {
+ xlog.Info("scanning specific uri", "uri", uri)
+ scanResults, err := downloader.HuggingFaceScan(downloader.URI(uri))
+ if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) {
+ xlog.Error("! WARNING ! A known-vulnerable model is included in this repo!", "error", err, "clamAV", scanResults.ClamAVInfectedFiles, "pickles", scanResults.DangerousPickles)
+ errs = errors.Join(errs, err)
+ }
+ }
+ if errs != nil {
+ return errs
+ }
+ xlog.Info("No security warnings were detected for your installed models. Please note that this is a BEST EFFORT tool, and all issues may not be detected.")
+ return nil
+ }
+}
+
+func (uhcmd *UsecaseHeuristicCMD) Run(ctx *cliContext.Context) error {
+ if len(uhcmd.ConfigName) == 0 {
+ xlog.Error("ConfigName is a required parameter")
+ return fmt.Errorf("config name is a required parameter")
+ }
+ if len(uhcmd.ModelsPath) == 0 {
+ xlog.Error("ModelsPath is a required parameter")
+ return fmt.Errorf("model path is a required parameter")
+ }
+ bcl := config.NewModelConfigLoader(uhcmd.ModelsPath)
+ err := bcl.ReadModelConfig(uhcmd.ConfigName)
+ if err != nil {
+ xlog.Error("error while loading backend", "error", err, "ConfigName", uhcmd.ConfigName)
+ return err
+ }
+ bc, exists := bcl.GetModelConfig(uhcmd.ConfigName)
+ if !exists {
+ xlog.Error("ConfigName not found", "ConfigName", uhcmd.ConfigName)
+ }
+ for name, uc := range config.GetAllModelConfigUsecases() {
+ if bc.HasUsecases(uc) {
+ xlog.Info("Usecase", "usecase", name)
+ }
+ }
+ xlog.Info("---")
+ return nil
+}
diff --git a/core/cli/worker/worker.go b/core/cli/worker/worker.go
new file mode 100644
index 0000000000000000000000000000000000000000..0a636c3bfacbdec488e40349b850decb7cff4145
--- /dev/null
+++ b/core/cli/worker/worker.go
@@ -0,0 +1,13 @@
+package worker
+
+type WorkerFlags struct {
+ BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
+ BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
+ BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
+ ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
+}
+
+type Worker struct {
+ P2P P2P `cmd:"" name:"p2p-llama-cpp-rpc" help:"Starts a LocalAI llama.cpp worker in P2P mode (requires a token)"`
+ LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"`
+}
diff --git a/core/cli/worker/worker_llamacpp.go b/core/cli/worker/worker_llamacpp.go
new file mode 100644
index 0000000000000000000000000000000000000000..4f8e8e115566d83ecb12da15eb4e70e007ad1052
--- /dev/null
+++ b/core/cli/worker/worker_llamacpp.go
@@ -0,0 +1,92 @@
+package worker
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "syscall"
+
+ cliContext "github.com/mudler/LocalAI/core/cli/context"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/system"
+ "github.com/mudler/xlog"
+)
+
+type LLamaCPP struct {
+ WorkerFlags `embed:""`
+}
+
+const (
+ llamaCPPRPCBinaryName = "llama-cpp-rpc-server"
+ llamaCPPGalleryName = "llama-cpp"
+)
+
+func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (string, error) {
+ backends, err := gallery.ListSystemBackends(systemState)
+ if err != nil {
+ xlog.Warn("Failed listing system backends", "error", err)
+ return "", err
+ }
+ xlog.Debug("System backends", "backends", backends)
+
+ backend, ok := backends.Get(llamaCPPGalleryName)
+ if !ok {
+ ml := model.NewModelLoader(systemState)
+ var gals []config.Gallery
+ if err := json.Unmarshal([]byte(galleries), &gals); err != nil {
+ xlog.Error("failed loading galleries", "error", err)
+ return "", err
+ }
+ err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true)
+ if err != nil {
+ xlog.Error("llama-cpp backend not found, failed to install it", "error", err)
+ return "", err
+ }
+ }
+ backendPath := filepath.Dir(backend.RunFile)
+
+ if backendPath == "" {
+ return "", errors.New("llama-cpp backend not found, install it first")
+ }
+
+ grpcProcess := filepath.Join(
+ backendPath,
+ llamaCPPRPCBinaryName,
+ )
+
+ return grpcProcess, nil
+}
+
+func (r *LLamaCPP) Run(ctx *cliContext.Context) error {
+
+ if len(os.Args) < 4 {
+ return fmt.Errorf("usage: local-ai worker llama-cpp-rpc -- ")
+ }
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(r.BackendsPath),
+ system.WithBackendSystemPath(r.BackendsSystemPath),
+ )
+ if err != nil {
+ return err
+ }
+ grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState)
+ if err != nil {
+ return err
+ }
+
+ args := strings.Split(r.ExtraLLamaCPPArgs, " ")
+
+ args = append([]string{grpcProcess}, args...)
+
+ return syscall.Exec(
+ grpcProcess,
+ args,
+ os.Environ())
+}
diff --git a/core/cli/worker/worker_p2p.go b/core/cli/worker/worker_p2p.go
new file mode 100644
index 0000000000000000000000000000000000000000..868357ccffd53bca5fe3fc14725f279626d852eb
--- /dev/null
+++ b/core/cli/worker/worker_p2p.go
@@ -0,0 +1,120 @@
+package worker
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "os/exec"
+ "strings"
+ "time"
+
+ cliContext "github.com/mudler/LocalAI/core/cli/context"
+ "github.com/mudler/LocalAI/core/p2p"
+ "github.com/mudler/LocalAI/pkg/signals"
+ "github.com/mudler/LocalAI/pkg/system"
+ "github.com/mudler/xlog"
+ "github.com/phayes/freeport"
+)
+
+type P2P struct {
+ WorkerFlags `embed:""`
+ Token string `env:"LOCALAI_TOKEN,LOCALAI_P2P_TOKEN,TOKEN" help:"P2P token to use"`
+ NoRunner bool `env:"LOCALAI_NO_RUNNER,NO_RUNNER" help:"Do not start the llama-cpp-rpc-server"`
+ RunnerAddress string `env:"LOCALAI_RUNNER_ADDRESS,RUNNER_ADDRESS" help:"Address of the llama-cpp-rpc-server"`
+ RunnerPort string `env:"LOCALAI_RUNNER_PORT,RUNNER_PORT" help:"Port of the llama-cpp-rpc-server"`
+ Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"`
+}
+
+func (r *P2P) Run(ctx *cliContext.Context) error {
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(r.BackendsPath),
+ system.WithBackendSystemPath(r.BackendsSystemPath),
+ )
+ if err != nil {
+ return err
+ }
+
+ // Check if the token is set
+ // as we always need it.
+ if r.Token == "" {
+ return fmt.Errorf("Token is required")
+ }
+
+ port, err := freeport.GetFreePort()
+ if err != nil {
+ return err
+ }
+
+ address := "127.0.0.1"
+
+ c, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ if r.NoRunner {
+ // Let override which port and address to bind if the user
+ // configure the llama-cpp service on its own
+ p := fmt.Sprint(port)
+ if r.RunnerAddress != "" {
+ address = r.RunnerAddress
+ }
+ if r.RunnerPort != "" {
+ p = r.RunnerPort
+ }
+
+ _, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
+ if err != nil {
+ return err
+ }
+ xlog.Info("You need to start llama-cpp-rpc-server", "address", address, "port", p)
+ } else {
+ // Start llama.cpp directly from the version we have pre-packaged
+ go func() {
+ for {
+ xlog.Info("Starting llama-cpp-rpc-server", "address", address, "port", port)
+
+ grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState)
+ if err != nil {
+ xlog.Error("Failed to find llama-cpp-rpc-server", "error", err)
+ return
+ }
+
+ var extraArgs []string
+
+ if r.ExtraLLamaCPPArgs != "" {
+ extraArgs = strings.Split(r.ExtraLLamaCPPArgs, " ")
+ }
+ args := append([]string{"--host", address, "--port", fmt.Sprint(port)}, extraArgs...)
+ xlog.Debug("Starting llama-cpp-rpc-server", "address", address, "port", port, "args", args, "argCount", len(args))
+
+ cmd := exec.Command(
+ grpcProcess, args...,
+ )
+
+ cmd.Env = os.Environ()
+
+ cmd.Stderr = os.Stdout
+ cmd.Stdout = os.Stdout
+
+ if err := cmd.Start(); err != nil {
+ xlog.Error("Failed to start llama-cpp-rpc-server", "error", err, "grpcProcess", grpcProcess, "args", args)
+ }
+
+ cmd.Wait()
+ }
+ }()
+
+ _, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
+ if err != nil {
+ return err
+ }
+ }
+
+ signals.RegisterGracefulTerminationHandler(func() {
+ cancel()
+ })
+
+ for {
+ time.Sleep(1 * time.Second)
+ }
+}
diff --git a/core/clients/store.go b/core/clients/store.go
new file mode 100644
index 0000000000000000000000000000000000000000..f737ee4212e95768b1919e2915961c517bc9b607
--- /dev/null
+++ b/core/clients/store.go
@@ -0,0 +1,151 @@
+package clients
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+)
+
+// Define a struct to hold the store API client
+type StoreClient struct {
+ BaseURL string
+ Client *http.Client
+}
+
+type SetRequest struct {
+ Keys [][]float32 `json:"keys"`
+ Values []string `json:"values"`
+}
+
+type GetRequest struct {
+ Keys [][]float32 `json:"keys"`
+}
+
+type GetResponse struct {
+ Keys [][]float32 `json:"keys"`
+ Values []string `json:"values"`
+}
+
+type DeleteRequest struct {
+ Keys [][]float32 `json:"keys"`
+}
+
+type FindRequest struct {
+ TopK int `json:"topk"`
+ Key []float32 `json:"key"`
+}
+
+type FindResponse struct {
+ Keys [][]float32 `json:"keys"`
+ Values []string `json:"values"`
+ Similarities []float32 `json:"similarities"`
+}
+
+// Constructor for StoreClient
+func NewStoreClient(baseUrl string) *StoreClient {
+ return &StoreClient{
+ BaseURL: baseUrl,
+ Client: &http.Client{},
+ }
+}
+
+// Implement Set method
+func (c *StoreClient) Set(req SetRequest) error {
+ return c.doRequest("stores/set", req)
+}
+
+// Implement Get method
+func (c *StoreClient) Get(req GetRequest) (*GetResponse, error) {
+ body, err := c.doRequestWithResponse("stores/get", req)
+ if err != nil {
+ return nil, err
+ }
+
+ var resp GetResponse
+ err = json.Unmarshal(body, &resp)
+ if err != nil {
+ return nil, err
+ }
+
+ return &resp, nil
+}
+
+// Implement Delete method
+func (c *StoreClient) Delete(req DeleteRequest) error {
+ return c.doRequest("stores/delete", req)
+}
+
+// Implement Find method
+func (c *StoreClient) Find(req FindRequest) (*FindResponse, error) {
+ body, err := c.doRequestWithResponse("stores/find", req)
+ if err != nil {
+ return nil, err
+ }
+
+ var resp FindResponse
+ err = json.Unmarshal(body, &resp)
+ if err != nil {
+ return nil, err
+ }
+
+ return &resp, nil
+}
+
+// Helper function to perform a request without expecting a response body
+func (c *StoreClient) doRequest(path string, data interface{}) error {
+ jsonData, err := json.Marshal(data)
+ if err != nil {
+ return err
+ }
+
+ req, err := http.NewRequest("POST", c.BaseURL+"/"+path, bytes.NewBuffer(jsonData))
+ if err != nil {
+ return err
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ resp, err := c.Client.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("API request to %s failed with status code %d", path, resp.StatusCode)
+ }
+
+ return nil
+}
+
+// Helper function to perform a request and parse the response body
+func (c *StoreClient) doRequestWithResponse(path string, data interface{}) ([]byte, error) {
+ jsonData, err := json.Marshal(data)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequest("POST", c.BaseURL+"/"+path, bytes.NewBuffer(jsonData))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ resp, err := c.Client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("API request to %s failed with status code %d", path, resp.StatusCode)
+ }
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ return body, nil
+}
diff --git a/core/config/application_config.go b/core/config/application_config.go
new file mode 100644
index 0000000000000000000000000000000000000000..26b603f53aed93b679876a5df378f3f60752043d
--- /dev/null
+++ b/core/config/application_config.go
@@ -0,0 +1,781 @@
+package config
+
+import (
+ "context"
+ "encoding/json"
+ "regexp"
+ "time"
+
+ "github.com/mudler/LocalAI/pkg/system"
+ "github.com/mudler/LocalAI/pkg/xsysinfo"
+ "github.com/mudler/xlog"
+)
+
+type ApplicationConfig struct {
+ Context context.Context
+ ConfigFile string
+ SystemState *system.SystemState
+ ExternalBackends []string
+ UploadLimitMB, Threads, ContextSize int
+ F16 bool
+ Debug bool
+ EnableTracing bool
+ TracingMaxItems int
+ GeneratedContentDir string
+
+ UploadDir string
+
+ DynamicConfigsDir string
+ DynamicConfigsDirPollInterval time.Duration
+ CORS bool
+ CSRF bool
+ PreloadJSONModels string
+ PreloadModelsFromPath string
+ CORSAllowOrigins string
+ ApiKeys []string
+ P2PToken string
+ P2PNetworkID string
+ Federated bool
+
+ DisableWebUI bool
+ EnforcePredownloadScans bool
+ OpaqueErrors bool
+ UseSubtleKeyComparison bool
+ DisableApiKeyRequirementForHttpGet bool
+ DisableMetrics bool
+ HttpGetExemptedEndpoints []*regexp.Regexp
+ DisableGalleryEndpoint bool
+ LoadToMemory []string
+
+ Galleries []Gallery
+ BackendGalleries []Gallery
+
+ ExternalGRPCBackends map[string]string
+
+ AutoloadGalleries, AutoloadBackendGalleries bool
+
+ SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
+ MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
+ ParallelBackendRequests bool
+
+ WatchDogIdle bool
+ WatchDogBusy bool
+ WatchDog bool
+
+ // Memory Reclaimer settings (works with GPU if available, otherwise RAM)
+ MemoryReclaimerEnabled bool // Enable memory threshold monitoring
+ MemoryReclaimerThreshold float64 // Threshold 0.0-1.0 (e.g., 0.95 = 95%)
+
+ // Eviction settings
+ ForceEvictionWhenBusy bool // Force eviction even when models have active API calls (default: false for safety)
+ LRUEvictionMaxRetries int // Maximum number of retries when waiting for busy models to become idle (default: 30)
+ LRUEvictionRetryInterval time.Duration // Interval between retries when waiting for busy models (default: 1s)
+
+ ModelsURL []string
+
+ WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
+ WatchDogInterval time.Duration // Interval between watchdog checks
+
+ MachineTag string
+
+ APIAddress string
+
+ TunnelCallback func(tunnels []string)
+
+ DisableRuntimeSettings bool
+
+ AgentJobRetentionDays int // Default: 30 days
+
+ PathWithoutAuth []string
+}
+
+type AppOption func(*ApplicationConfig)
+
+func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
+ opt := &ApplicationConfig{
+ Context: context.Background(),
+ UploadLimitMB: 15,
+ Debug: true,
+ AgentJobRetentionDays: 30, // Default: 30 days
+ LRUEvictionMaxRetries: 30, // Default: 30 retries
+ LRUEvictionRetryInterval: 1 * time.Second, // Default: 1 second
+ TracingMaxItems: 1024,
+ PathWithoutAuth: []string{
+ "/static/",
+ "/generated-audio/",
+ "/generated-images/",
+ "/generated-videos/",
+ "/favicon.svg",
+ "/readyz",
+ "/healthz",
+ },
+ }
+ for _, oo := range o {
+ oo(opt)
+ }
+ return opt
+}
+
+func WithModelsURL(urls ...string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.ModelsURL = urls
+ }
+}
+
+func WithSystemState(state *system.SystemState) AppOption {
+ return func(o *ApplicationConfig) {
+ o.SystemState = state
+ }
+}
+
+func WithExternalBackends(backends ...string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.ExternalBackends = backends
+ }
+}
+
+func WithMachineTag(tag string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.MachineTag = tag
+ }
+}
+
+func WithCors(b bool) AppOption {
+ return func(o *ApplicationConfig) {
+ o.CORS = b
+ }
+}
+
+func WithP2PNetworkID(s string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.P2PNetworkID = s
+ }
+}
+
+func WithCsrf(b bool) AppOption {
+ return func(o *ApplicationConfig) {
+ o.CSRF = b
+ }
+}
+
+func WithP2PToken(s string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.P2PToken = s
+ }
+}
+
+var EnableWatchDog = func(o *ApplicationConfig) {
+ o.WatchDog = true
+}
+
+var EnableTracing = func(o *ApplicationConfig) {
+ o.EnableTracing = true
+}
+
+var EnableWatchDogIdleCheck = func(o *ApplicationConfig) {
+ o.WatchDog = true
+ o.WatchDogIdle = true
+}
+
+var DisableGalleryEndpoint = func(o *ApplicationConfig) {
+ o.DisableGalleryEndpoint = true
+}
+
+var EnableWatchDogBusyCheck = func(o *ApplicationConfig) {
+ o.WatchDog = true
+ o.WatchDogBusy = true
+}
+
+var DisableWebUI = func(o *ApplicationConfig) {
+ o.DisableWebUI = true
+}
+
+var DisableRuntimeSettings = func(o *ApplicationConfig) {
+ o.DisableRuntimeSettings = true
+}
+
+func SetWatchDogBusyTimeout(t time.Duration) AppOption {
+ return func(o *ApplicationConfig) {
+ o.WatchDogBusyTimeout = t
+ }
+}
+
+func SetWatchDogIdleTimeout(t time.Duration) AppOption {
+ return func(o *ApplicationConfig) {
+ o.WatchDogIdleTimeout = t
+ }
+}
+
+// EnableMemoryReclaimer enables memory threshold monitoring.
+// When enabled, the watchdog will evict backends if memory usage exceeds the threshold.
+// Works with GPU VRAM if available, otherwise uses system RAM.
+var EnableMemoryReclaimer = func(o *ApplicationConfig) {
+ o.MemoryReclaimerEnabled = true
+ o.WatchDog = true // Memory reclaimer requires watchdog infrastructure
+}
+
+// SetMemoryReclaimerThreshold sets the memory usage threshold (0.0-1.0).
+// When memory usage exceeds this threshold, backends will be evicted using LRU strategy.
+func SetMemoryReclaimerThreshold(threshold float64) AppOption {
+ return func(o *ApplicationConfig) {
+ if threshold > 0 && threshold <= 1.0 {
+ o.MemoryReclaimerThreshold = threshold
+ o.MemoryReclaimerEnabled = true
+ o.WatchDog = true // Memory reclaimer requires watchdog infrastructure
+ }
+ }
+}
+
+// WithMemoryReclaimer configures the memory reclaimer with the given settings
+func WithMemoryReclaimer(enabled bool, threshold float64) AppOption {
+ return func(o *ApplicationConfig) {
+ o.MemoryReclaimerEnabled = enabled
+ if threshold > 0 && threshold <= 1.0 {
+ o.MemoryReclaimerThreshold = threshold
+ }
+ if enabled {
+ o.WatchDog = true // Memory reclaimer requires watchdog infrastructure
+ }
+ }
+}
+
+// EnableSingleBackend is deprecated: use SetMaxActiveBackends(1) instead.
+// This is kept for backward compatibility.
+var EnableSingleBackend = func(o *ApplicationConfig) {
+ o.SingleBackend = true
+ o.MaxActiveBackends = 1
+}
+
+// SetMaxActiveBackends sets the maximum number of active backends.
+// 0 = unlimited, 1 = single backend mode (replaces EnableSingleBackend)
+func SetMaxActiveBackends(n int) AppOption {
+ return func(o *ApplicationConfig) {
+ o.MaxActiveBackends = n
+ // For backward compatibility, also set SingleBackend if n == 1
+ if n == 1 {
+ o.SingleBackend = true
+ }
+ }
+}
+
+// GetEffectiveMaxActiveBackends returns the effective max active backends limit.
+// It considers both MaxActiveBackends and the deprecated SingleBackend setting.
+// If MaxActiveBackends is set (> 0), it takes precedence.
+// If SingleBackend is true and MaxActiveBackends is 0, returns 1.
+// Otherwise returns 0 (unlimited).
+func (o *ApplicationConfig) GetEffectiveMaxActiveBackends() int {
+ if o.MaxActiveBackends > 0 {
+ return o.MaxActiveBackends
+ }
+ if o.SingleBackend {
+ return 1
+ }
+ return 0
+}
+
+// WithForceEvictionWhenBusy sets whether to force eviction even when models have active API calls
+func WithForceEvictionWhenBusy(enabled bool) AppOption {
+ return func(o *ApplicationConfig) {
+ o.ForceEvictionWhenBusy = enabled
+ }
+}
+
+// WithLRUEvictionMaxRetries sets the maximum number of retries when waiting for busy models to become idle
+func WithLRUEvictionMaxRetries(maxRetries int) AppOption {
+ return func(o *ApplicationConfig) {
+ if maxRetries > 0 {
+ o.LRUEvictionMaxRetries = maxRetries
+ }
+ }
+}
+
+// WithLRUEvictionRetryInterval sets the interval between retries when waiting for busy models
+func WithLRUEvictionRetryInterval(interval time.Duration) AppOption {
+ return func(o *ApplicationConfig) {
+ if interval > 0 {
+ o.LRUEvictionRetryInterval = interval
+ }
+ }
+}
+
+var EnableParallelBackendRequests = func(o *ApplicationConfig) {
+ o.ParallelBackendRequests = true
+}
+
+var EnableGalleriesAutoload = func(o *ApplicationConfig) {
+ o.AutoloadGalleries = true
+}
+
+var EnableBackendGalleriesAutoload = func(o *ApplicationConfig) {
+ o.AutoloadBackendGalleries = true
+}
+
+var EnableFederated = func(o *ApplicationConfig) {
+ o.Federated = true
+}
+
+func WithExternalBackend(name string, uri string) AppOption {
+ return func(o *ApplicationConfig) {
+ if o.ExternalGRPCBackends == nil {
+ o.ExternalGRPCBackends = make(map[string]string)
+ }
+ o.ExternalGRPCBackends[name] = uri
+ }
+}
+
+func WithCorsAllowOrigins(b string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.CORSAllowOrigins = b
+ }
+}
+
+func WithStringGalleries(galls string) AppOption {
+ return func(o *ApplicationConfig) {
+ if galls == "" {
+ o.Galleries = []Gallery{}
+ return
+ }
+ var galleries []Gallery
+ if err := json.Unmarshal([]byte(galls), &galleries); err != nil {
+ xlog.Error("failed loading galleries", "error", err)
+ }
+ o.Galleries = append(o.Galleries, galleries...)
+ }
+}
+
+func WithBackendGalleries(galls string) AppOption {
+ return func(o *ApplicationConfig) {
+ if galls == "" {
+ o.BackendGalleries = []Gallery{}
+ return
+ }
+ var galleries []Gallery
+ if err := json.Unmarshal([]byte(galls), &galleries); err != nil {
+ xlog.Error("failed loading galleries", "error", err)
+ }
+ o.BackendGalleries = append(o.BackendGalleries, galleries...)
+ }
+}
+
+func WithGalleries(galleries []Gallery) AppOption {
+ return func(o *ApplicationConfig) {
+ o.Galleries = append(o.Galleries, galleries...)
+ }
+}
+
+func WithContext(ctx context.Context) AppOption {
+ return func(o *ApplicationConfig) {
+ o.Context = ctx
+ }
+}
+
+func WithYAMLConfigPreload(configFile string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.PreloadModelsFromPath = configFile
+ }
+}
+
+func WithJSONStringPreload(configFile string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.PreloadJSONModels = configFile
+ }
+}
+func WithConfigFile(configFile string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.ConfigFile = configFile
+ }
+}
+
+func WithUploadLimitMB(limit int) AppOption {
+ return func(o *ApplicationConfig) {
+ o.UploadLimitMB = limit
+ }
+}
+
+func WithThreads(threads int) AppOption {
+ return func(o *ApplicationConfig) {
+ if threads == 0 { // 0 is not allowed
+ threads = xsysinfo.CPUPhysicalCores()
+ }
+ o.Threads = threads
+ }
+}
+
+func WithContextSize(ctxSize int) AppOption {
+ return func(o *ApplicationConfig) {
+ o.ContextSize = ctxSize
+ }
+}
+
+func WithTunnelCallback(callback func(tunnels []string)) AppOption {
+ return func(o *ApplicationConfig) {
+ o.TunnelCallback = callback
+ }
+}
+
+func WithF16(f16 bool) AppOption {
+ return func(o *ApplicationConfig) {
+ o.F16 = f16
+ }
+}
+
+func WithDebug(debug bool) AppOption {
+ return func(o *ApplicationConfig) {
+ o.Debug = debug
+ }
+}
+
+func WithTracingMaxItems(items int) AppOption {
+ return func(o *ApplicationConfig) {
+ o.TracingMaxItems = items
+ }
+}
+
+func WithGeneratedContentDir(generatedContentDir string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.GeneratedContentDir = generatedContentDir
+ }
+}
+
+func WithUploadDir(uploadDir string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.UploadDir = uploadDir
+ }
+}
+
+func WithDynamicConfigDir(dynamicConfigsDir string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.DynamicConfigsDir = dynamicConfigsDir
+ }
+}
+
+func WithDynamicConfigDirPollInterval(interval time.Duration) AppOption {
+ return func(o *ApplicationConfig) {
+ o.DynamicConfigsDirPollInterval = interval
+ }
+}
+
+func WithApiKeys(apiKeys []string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.ApiKeys = apiKeys
+ }
+}
+
+func WithAgentJobRetentionDays(days int) AppOption {
+ return func(o *ApplicationConfig) {
+ o.AgentJobRetentionDays = days
+ }
+}
+
+func WithEnforcedPredownloadScans(enforced bool) AppOption {
+ return func(o *ApplicationConfig) {
+ o.EnforcePredownloadScans = enforced
+ }
+}
+
+func WithOpaqueErrors(opaque bool) AppOption {
+ return func(o *ApplicationConfig) {
+ o.OpaqueErrors = opaque
+ }
+}
+
+func WithLoadToMemory(models []string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.LoadToMemory = models
+ }
+}
+
+func WithSubtleKeyComparison(subtle bool) AppOption {
+ return func(o *ApplicationConfig) {
+ o.UseSubtleKeyComparison = subtle
+ }
+}
+
+func WithDisableApiKeyRequirementForHttpGet(required bool) AppOption {
+ return func(o *ApplicationConfig) {
+ o.DisableApiKeyRequirementForHttpGet = required
+ }
+}
+
+func WithAPIAddress(address string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.APIAddress = address
+ }
+}
+
+var DisableMetricsEndpoint AppOption = func(o *ApplicationConfig) {
+ o.DisableMetrics = true
+}
+
+func WithHttpGetExemptedEndpoints(endpoints []string) AppOption {
+ return func(o *ApplicationConfig) {
+ o.HttpGetExemptedEndpoints = []*regexp.Regexp{}
+ for _, epr := range endpoints {
+ r, err := regexp.Compile(epr)
+ if err == nil && r != nil {
+ o.HttpGetExemptedEndpoints = append(o.HttpGetExemptedEndpoints, r)
+ } else {
+ xlog.Warn("Error while compiling HTTP Get Exemption regex, skipping this entry.", "error", err, "regex", epr)
+ }
+ }
+ }
+}
+
+// ToConfigLoaderOptions returns a slice of ConfigLoader Option.
+// Some options defined at the application level are going to be passed as defaults for
+// all the configuration for the models.
+// This includes for instance the context size or the number of threads.
+// If a model doesn't set configs directly to the config model file
+// it will use the defaults defined here.
+func (o *ApplicationConfig) ToConfigLoaderOptions() []ConfigLoaderOption {
+ return []ConfigLoaderOption{
+ LoadOptionContextSize(o.ContextSize),
+ LoadOptionDebug(o.Debug),
+ LoadOptionF16(o.F16),
+ LoadOptionThreads(o.Threads),
+ ModelPath(o.SystemState.Model.ModelsPath),
+ }
+}
+
+// ToRuntimeSettings converts ApplicationConfig to RuntimeSettings for API responses and JSON serialization.
+// This provides a single source of truth - ApplicationConfig holds the live values,
+// and this method creates a RuntimeSettings snapshot for external consumption.
+func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
+ // Create local copies for pointer fields
+ watchdogEnabled := o.WatchDog
+ watchdogIdle := o.WatchDogIdle
+ watchdogBusy := o.WatchDogBusy
+ singleBackend := o.SingleBackend
+ maxActiveBackends := o.MaxActiveBackends
+ parallelBackendRequests := o.ParallelBackendRequests
+ memoryReclaimerEnabled := o.MemoryReclaimerEnabled
+ memoryReclaimerThreshold := o.MemoryReclaimerThreshold
+ forceEvictionWhenBusy := o.ForceEvictionWhenBusy
+ lruEvictionMaxRetries := o.LRUEvictionMaxRetries
+ threads := o.Threads
+ contextSize := o.ContextSize
+ f16 := o.F16
+ debug := o.Debug
+ tracingMaxItems := o.TracingMaxItems
+ enableTracing := o.EnableTracing
+ cors := o.CORS
+ csrf := o.CSRF
+ corsAllowOrigins := o.CORSAllowOrigins
+ p2pToken := o.P2PToken
+ p2pNetworkID := o.P2PNetworkID
+ federated := o.Federated
+ galleries := o.Galleries
+ backendGalleries := o.BackendGalleries
+ autoloadGalleries := o.AutoloadGalleries
+ autoloadBackendGalleries := o.AutoloadBackendGalleries
+ apiKeys := o.ApiKeys
+ agentJobRetentionDays := o.AgentJobRetentionDays
+
+ // Format timeouts as strings
+ var idleTimeout, busyTimeout, watchdogInterval string
+ if o.WatchDogIdleTimeout > 0 {
+ idleTimeout = o.WatchDogIdleTimeout.String()
+ } else {
+ idleTimeout = "15m" // default
+ }
+ if o.WatchDogBusyTimeout > 0 {
+ busyTimeout = o.WatchDogBusyTimeout.String()
+ } else {
+ busyTimeout = "5m" // default
+ }
+ if o.WatchDogInterval > 0 {
+ watchdogInterval = o.WatchDogInterval.String()
+ } else {
+ watchdogInterval = "2s" // default
+ }
+ var lruEvictionRetryInterval string
+ if o.LRUEvictionRetryInterval > 0 {
+ lruEvictionRetryInterval = o.LRUEvictionRetryInterval.String()
+ } else {
+ lruEvictionRetryInterval = "1s" // default
+ }
+
+ return RuntimeSettings{
+ WatchdogEnabled: &watchdogEnabled,
+ WatchdogIdleEnabled: &watchdogIdle,
+ WatchdogBusyEnabled: &watchdogBusy,
+ WatchdogIdleTimeout: &idleTimeout,
+ WatchdogBusyTimeout: &busyTimeout,
+ WatchdogInterval: &watchdogInterval,
+ SingleBackend: &singleBackend,
+ MaxActiveBackends: &maxActiveBackends,
+ ParallelBackendRequests: ¶llelBackendRequests,
+ MemoryReclaimerEnabled: &memoryReclaimerEnabled,
+ MemoryReclaimerThreshold: &memoryReclaimerThreshold,
+ ForceEvictionWhenBusy: &forceEvictionWhenBusy,
+ LRUEvictionMaxRetries: &lruEvictionMaxRetries,
+ LRUEvictionRetryInterval: &lruEvictionRetryInterval,
+ Threads: &threads,
+ ContextSize: &contextSize,
+ F16: &f16,
+ Debug: &debug,
+ TracingMaxItems: &tracingMaxItems,
+ EnableTracing: &enableTracing,
+ CORS: &cors,
+ CSRF: &csrf,
+ CORSAllowOrigins: &corsAllowOrigins,
+ P2PToken: &p2pToken,
+ P2PNetworkID: &p2pNetworkID,
+ Federated: &federated,
+ Galleries: &galleries,
+ BackendGalleries: &backendGalleries,
+ AutoloadGalleries: &autoloadGalleries,
+ AutoloadBackendGalleries: &autoloadBackendGalleries,
+ ApiKeys: &apiKeys,
+ AgentJobRetentionDays: &agentJobRetentionDays,
+ }
+}
+
+// ApplyRuntimeSettings applies RuntimeSettings to ApplicationConfig.
+// Only non-nil fields in RuntimeSettings are applied.
+// Returns true if watchdog-related settings changed (requiring restart).
+func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (requireRestart bool) {
+ if settings == nil {
+ return false
+ }
+
+ if settings.WatchdogEnabled != nil {
+ o.WatchDog = *settings.WatchdogEnabled
+ requireRestart = true
+ }
+ if settings.WatchdogIdleEnabled != nil {
+ o.WatchDogIdle = *settings.WatchdogIdleEnabled
+ if o.WatchDogIdle {
+ o.WatchDog = true
+ }
+ requireRestart = true
+ }
+ if settings.WatchdogBusyEnabled != nil {
+ o.WatchDogBusy = *settings.WatchdogBusyEnabled
+ if o.WatchDogBusy {
+ o.WatchDog = true
+ }
+ requireRestart = true
+ }
+ if settings.WatchdogIdleTimeout != nil {
+ if dur, err := time.ParseDuration(*settings.WatchdogIdleTimeout); err == nil {
+ o.WatchDogIdleTimeout = dur
+ requireRestart = true
+ }
+ }
+ if settings.WatchdogBusyTimeout != nil {
+ if dur, err := time.ParseDuration(*settings.WatchdogBusyTimeout); err == nil {
+ o.WatchDogBusyTimeout = dur
+ requireRestart = true
+ }
+ }
+ if settings.WatchdogInterval != nil {
+ if dur, err := time.ParseDuration(*settings.WatchdogInterval); err == nil {
+ o.WatchDogInterval = dur
+ requireRestart = true
+ }
+ }
+ if settings.MaxActiveBackends != nil {
+ o.MaxActiveBackends = *settings.MaxActiveBackends
+ o.SingleBackend = (*settings.MaxActiveBackends == 1)
+ requireRestart = true
+ } else if settings.SingleBackend != nil {
+ o.SingleBackend = *settings.SingleBackend
+ if *settings.SingleBackend {
+ o.MaxActiveBackends = 1
+ } else {
+ o.MaxActiveBackends = 0
+ }
+ requireRestart = true
+ }
+ if settings.ParallelBackendRequests != nil {
+ o.ParallelBackendRequests = *settings.ParallelBackendRequests
+ }
+ if settings.MemoryReclaimerEnabled != nil {
+ o.MemoryReclaimerEnabled = *settings.MemoryReclaimerEnabled
+ if *settings.MemoryReclaimerEnabled {
+ o.WatchDog = true
+ }
+ requireRestart = true
+ }
+ if settings.MemoryReclaimerThreshold != nil {
+ if *settings.MemoryReclaimerThreshold > 0 && *settings.MemoryReclaimerThreshold <= 1.0 {
+ o.MemoryReclaimerThreshold = *settings.MemoryReclaimerThreshold
+ requireRestart = true
+ }
+ }
+ if settings.ForceEvictionWhenBusy != nil {
+ o.ForceEvictionWhenBusy = *settings.ForceEvictionWhenBusy
+ // This setting doesn't require restart, can be updated dynamically
+ }
+ if settings.LRUEvictionMaxRetries != nil {
+ o.LRUEvictionMaxRetries = *settings.LRUEvictionMaxRetries
+ // This setting doesn't require restart, can be updated dynamically
+ }
+ if settings.LRUEvictionRetryInterval != nil {
+ if dur, err := time.ParseDuration(*settings.LRUEvictionRetryInterval); err == nil {
+ o.LRUEvictionRetryInterval = dur
+ // This setting doesn't require restart, can be updated dynamically
+ }
+ }
+ if settings.Threads != nil {
+ o.Threads = *settings.Threads
+ }
+ if settings.ContextSize != nil {
+ o.ContextSize = *settings.ContextSize
+ }
+ if settings.F16 != nil {
+ o.F16 = *settings.F16
+ }
+ if settings.Debug != nil {
+ o.Debug = *settings.Debug
+ }
+ if settings.EnableTracing != nil {
+ o.EnableTracing = *settings.EnableTracing
+ }
+ if settings.TracingMaxItems != nil {
+ o.TracingMaxItems = *settings.TracingMaxItems
+ }
+ if settings.CORS != nil {
+ o.CORS = *settings.CORS
+ }
+ if settings.CSRF != nil {
+ o.CSRF = *settings.CSRF
+ }
+ if settings.CORSAllowOrigins != nil {
+ o.CORSAllowOrigins = *settings.CORSAllowOrigins
+ }
+ if settings.P2PToken != nil {
+ o.P2PToken = *settings.P2PToken
+ }
+ if settings.P2PNetworkID != nil {
+ o.P2PNetworkID = *settings.P2PNetworkID
+ }
+ if settings.Federated != nil {
+ o.Federated = *settings.Federated
+ }
+ if settings.Galleries != nil {
+ o.Galleries = *settings.Galleries
+ }
+ if settings.BackendGalleries != nil {
+ o.BackendGalleries = *settings.BackendGalleries
+ }
+ if settings.AutoloadGalleries != nil {
+ o.AutoloadGalleries = *settings.AutoloadGalleries
+ }
+ if settings.AutoloadBackendGalleries != nil {
+ o.AutoloadBackendGalleries = *settings.AutoloadBackendGalleries
+ }
+ if settings.AgentJobRetentionDays != nil {
+ o.AgentJobRetentionDays = *settings.AgentJobRetentionDays
+ }
+ // Note: ApiKeys requires special handling (merging with startup keys) - handled in caller
+
+ return requireRestart
+}
+
+// func WithMetrics(meter *metrics.Metrics) AppOption {
+// return func(o *StartupOptions) {
+// o.Metrics = meter
+// }
+// }
diff --git a/core/config/application_config_test.go b/core/config/application_config_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..c6d4fbecd6bc2de8b335847aa7e99ac785815c0d
--- /dev/null
+++ b/core/config/application_config_test.go
@@ -0,0 +1,577 @@
+package config
+
+import (
+ "time"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
+ Describe("ToRuntimeSettings", func() {
+ It("should convert all fields correctly", func() {
+ appConfig := &ApplicationConfig{
+ WatchDog: true,
+ WatchDogIdle: true,
+ WatchDogBusy: true,
+ WatchDogIdleTimeout: 20 * time.Minute,
+ WatchDogBusyTimeout: 10 * time.Minute,
+ SingleBackend: false,
+ MaxActiveBackends: 5,
+ ParallelBackendRequests: true,
+ MemoryReclaimerEnabled: true,
+ MemoryReclaimerThreshold: 0.85,
+ Threads: 8,
+ ContextSize: 4096,
+ F16: true,
+ Debug: true,
+ CORS: true,
+ CSRF: true,
+ CORSAllowOrigins: "https://example.com",
+ P2PToken: "test-token",
+ P2PNetworkID: "test-network",
+ Federated: true,
+ Galleries: []Gallery{{Name: "test-gallery", URL: "https://example.com"}},
+ BackendGalleries: []Gallery{{Name: "backend-gallery", URL: "https://example.com/backend"}},
+ AutoloadGalleries: true,
+ AutoloadBackendGalleries: true,
+ ApiKeys: []string{"key1", "key2"},
+ AgentJobRetentionDays: 30,
+ }
+
+ rs := appConfig.ToRuntimeSettings()
+
+ Expect(rs.WatchdogEnabled).ToNot(BeNil())
+ Expect(*rs.WatchdogEnabled).To(BeTrue())
+
+ Expect(rs.WatchdogIdleEnabled).ToNot(BeNil())
+ Expect(*rs.WatchdogIdleEnabled).To(BeTrue())
+
+ Expect(rs.WatchdogBusyEnabled).ToNot(BeNil())
+ Expect(*rs.WatchdogBusyEnabled).To(BeTrue())
+
+ Expect(rs.WatchdogIdleTimeout).ToNot(BeNil())
+ Expect(*rs.WatchdogIdleTimeout).To(Equal("20m0s"))
+
+ Expect(rs.WatchdogBusyTimeout).ToNot(BeNil())
+ Expect(*rs.WatchdogBusyTimeout).To(Equal("10m0s"))
+
+ Expect(rs.SingleBackend).ToNot(BeNil())
+ Expect(*rs.SingleBackend).To(BeFalse())
+
+ Expect(rs.MaxActiveBackends).ToNot(BeNil())
+ Expect(*rs.MaxActiveBackends).To(Equal(5))
+
+ Expect(rs.ParallelBackendRequests).ToNot(BeNil())
+ Expect(*rs.ParallelBackendRequests).To(BeTrue())
+
+ Expect(rs.MemoryReclaimerEnabled).ToNot(BeNil())
+ Expect(*rs.MemoryReclaimerEnabled).To(BeTrue())
+
+ Expect(rs.MemoryReclaimerThreshold).ToNot(BeNil())
+ Expect(*rs.MemoryReclaimerThreshold).To(Equal(0.85))
+
+ Expect(rs.Threads).ToNot(BeNil())
+ Expect(*rs.Threads).To(Equal(8))
+
+ Expect(rs.ContextSize).ToNot(BeNil())
+ Expect(*rs.ContextSize).To(Equal(4096))
+
+ Expect(rs.F16).ToNot(BeNil())
+ Expect(*rs.F16).To(BeTrue())
+
+ Expect(rs.Debug).ToNot(BeNil())
+ Expect(*rs.Debug).To(BeTrue())
+
+ Expect(rs.CORS).ToNot(BeNil())
+ Expect(*rs.CORS).To(BeTrue())
+
+ Expect(rs.CSRF).ToNot(BeNil())
+ Expect(*rs.CSRF).To(BeTrue())
+
+ Expect(rs.CORSAllowOrigins).ToNot(BeNil())
+ Expect(*rs.CORSAllowOrigins).To(Equal("https://example.com"))
+
+ Expect(rs.P2PToken).ToNot(BeNil())
+ Expect(*rs.P2PToken).To(Equal("test-token"))
+
+ Expect(rs.P2PNetworkID).ToNot(BeNil())
+ Expect(*rs.P2PNetworkID).To(Equal("test-network"))
+
+ Expect(rs.Federated).ToNot(BeNil())
+ Expect(*rs.Federated).To(BeTrue())
+
+ Expect(rs.Galleries).ToNot(BeNil())
+ Expect(*rs.Galleries).To(HaveLen(1))
+ Expect((*rs.Galleries)[0].Name).To(Equal("test-gallery"))
+
+ Expect(rs.BackendGalleries).ToNot(BeNil())
+ Expect(*rs.BackendGalleries).To(HaveLen(1))
+ Expect((*rs.BackendGalleries)[0].Name).To(Equal("backend-gallery"))
+
+ Expect(rs.AutoloadGalleries).ToNot(BeNil())
+ Expect(*rs.AutoloadGalleries).To(BeTrue())
+
+ Expect(rs.AutoloadBackendGalleries).ToNot(BeNil())
+ Expect(*rs.AutoloadBackendGalleries).To(BeTrue())
+
+ Expect(rs.ApiKeys).ToNot(BeNil())
+ Expect(*rs.ApiKeys).To(HaveLen(2))
+ Expect(*rs.ApiKeys).To(ContainElements("key1", "key2"))
+
+ Expect(rs.AgentJobRetentionDays).ToNot(BeNil())
+ Expect(*rs.AgentJobRetentionDays).To(Equal(30))
+ })
+
+ It("should use default timeouts when not set", func() {
+ appConfig := &ApplicationConfig{}
+
+ rs := appConfig.ToRuntimeSettings()
+
+ Expect(rs.WatchdogIdleTimeout).ToNot(BeNil())
+ Expect(*rs.WatchdogIdleTimeout).To(Equal("15m"))
+
+ Expect(rs.WatchdogBusyTimeout).ToNot(BeNil())
+ Expect(*rs.WatchdogBusyTimeout).To(Equal("5m"))
+ })
+ })
+
+ Describe("ApplyRuntimeSettings", func() {
+ It("should return false when settings is nil", func() {
+ appConfig := &ApplicationConfig{}
+ changed := appConfig.ApplyRuntimeSettings(nil)
+ Expect(changed).To(BeFalse())
+ })
+
+ It("should only apply non-nil fields", func() {
+ appConfig := &ApplicationConfig{
+ WatchDog: false,
+ Threads: 4,
+ ContextSize: 2048,
+ }
+
+ watchdogEnabled := true
+ rs := &RuntimeSettings{
+ WatchdogEnabled: &watchdogEnabled,
+ // Leave other fields nil
+ }
+
+ changed := appConfig.ApplyRuntimeSettings(rs)
+
+ Expect(changed).To(BeTrue())
+ Expect(appConfig.WatchDog).To(BeTrue())
+ // Unchanged fields should remain
+ Expect(appConfig.Threads).To(Equal(4))
+ Expect(appConfig.ContextSize).To(Equal(2048))
+ })
+
+ It("should apply watchdog settings and return changed=true", func() {
+ appConfig := &ApplicationConfig{}
+
+ watchdogEnabled := true
+ watchdogIdle := true
+ watchdogBusy := true
+ idleTimeout := "30m"
+ busyTimeout := "15m"
+
+ rs := &RuntimeSettings{
+ WatchdogEnabled: &watchdogEnabled,
+ WatchdogIdleEnabled: &watchdogIdle,
+ WatchdogBusyEnabled: &watchdogBusy,
+ WatchdogIdleTimeout: &idleTimeout,
+ WatchdogBusyTimeout: &busyTimeout,
+ }
+
+ changed := appConfig.ApplyRuntimeSettings(rs)
+
+ Expect(changed).To(BeTrue())
+ Expect(appConfig.WatchDog).To(BeTrue())
+ Expect(appConfig.WatchDogIdle).To(BeTrue())
+ Expect(appConfig.WatchDogBusy).To(BeTrue())
+ Expect(appConfig.WatchDogIdleTimeout).To(Equal(30 * time.Minute))
+ Expect(appConfig.WatchDogBusyTimeout).To(Equal(15 * time.Minute))
+ })
+
+ It("should enable watchdog when idle is enabled", func() {
+ appConfig := &ApplicationConfig{WatchDog: false}
+
+ watchdogIdle := true
+ rs := &RuntimeSettings{
+ WatchdogIdleEnabled: &watchdogIdle,
+ }
+
+ appConfig.ApplyRuntimeSettings(rs)
+
+ Expect(appConfig.WatchDog).To(BeTrue())
+ Expect(appConfig.WatchDogIdle).To(BeTrue())
+ })
+
+ It("should enable watchdog when busy is enabled", func() {
+ appConfig := &ApplicationConfig{WatchDog: false}
+
+ watchdogBusy := true
+ rs := &RuntimeSettings{
+ WatchdogBusyEnabled: &watchdogBusy,
+ }
+
+ appConfig.ApplyRuntimeSettings(rs)
+
+ Expect(appConfig.WatchDog).To(BeTrue())
+ Expect(appConfig.WatchDogBusy).To(BeTrue())
+ })
+
+ It("should handle MaxActiveBackends and update SingleBackend accordingly", func() {
+ appConfig := &ApplicationConfig{}
+
+ maxBackends := 1
+ rs := &RuntimeSettings{
+ MaxActiveBackends: &maxBackends,
+ }
+
+ appConfig.ApplyRuntimeSettings(rs)
+
+ Expect(appConfig.MaxActiveBackends).To(Equal(1))
+ Expect(appConfig.SingleBackend).To(BeTrue())
+
+ // Test with multiple backends
+ maxBackends = 5
+ rs = &RuntimeSettings{
+ MaxActiveBackends: &maxBackends,
+ }
+
+ appConfig.ApplyRuntimeSettings(rs)
+
+ Expect(appConfig.MaxActiveBackends).To(Equal(5))
+ Expect(appConfig.SingleBackend).To(BeFalse())
+ })
+
+ It("should handle SingleBackend and update MaxActiveBackends accordingly", func() {
+ appConfig := &ApplicationConfig{}
+
+ singleBackend := true
+ rs := &RuntimeSettings{
+ SingleBackend: &singleBackend,
+ }
+
+ appConfig.ApplyRuntimeSettings(rs)
+
+ Expect(appConfig.SingleBackend).To(BeTrue())
+ Expect(appConfig.MaxActiveBackends).To(Equal(1))
+
+ // Test disabling single backend
+ singleBackend = false
+ rs = &RuntimeSettings{
+ SingleBackend: &singleBackend,
+ }
+
+ appConfig.ApplyRuntimeSettings(rs)
+
+ Expect(appConfig.SingleBackend).To(BeFalse())
+ Expect(appConfig.MaxActiveBackends).To(Equal(0))
+ })
+
+ It("should enable watchdog when memory reclaimer is enabled", func() {
+ appConfig := &ApplicationConfig{WatchDog: false}
+
+ memoryEnabled := true
+ threshold := 0.90
+ rs := &RuntimeSettings{
+ MemoryReclaimerEnabled: &memoryEnabled,
+ MemoryReclaimerThreshold: &threshold,
+ }
+
+ changed := appConfig.ApplyRuntimeSettings(rs)
+
+ Expect(changed).To(BeTrue())
+ Expect(appConfig.WatchDog).To(BeTrue())
+ Expect(appConfig.MemoryReclaimerEnabled).To(BeTrue())
+ Expect(appConfig.MemoryReclaimerThreshold).To(Equal(0.90))
+ })
+
+ It("should reject invalid memory threshold values", func() {
+ appConfig := &ApplicationConfig{MemoryReclaimerThreshold: 0.50}
+
+ // Test threshold > 1.0
+ invalidThreshold := 1.5
+ rs := &RuntimeSettings{
+ MemoryReclaimerThreshold: &invalidThreshold,
+ }
+ appConfig.ApplyRuntimeSettings(rs)
+ Expect(appConfig.MemoryReclaimerThreshold).To(Equal(0.50)) // Should remain unchanged
+
+ // Test threshold <= 0
+ invalidThreshold = 0.0
+ rs = &RuntimeSettings{
+ MemoryReclaimerThreshold: &invalidThreshold,
+ }
+ appConfig.ApplyRuntimeSettings(rs)
+ Expect(appConfig.MemoryReclaimerThreshold).To(Equal(0.50)) // Should remain unchanged
+
+ // Test negative threshold
+ invalidThreshold = -0.5
+ rs = &RuntimeSettings{
+ MemoryReclaimerThreshold: &invalidThreshold,
+ }
+ appConfig.ApplyRuntimeSettings(rs)
+ Expect(appConfig.MemoryReclaimerThreshold).To(Equal(0.50)) // Should remain unchanged
+ })
+
+ It("should accept valid memory threshold at boundary", func() {
+ appConfig := &ApplicationConfig{}
+
+ // Test threshold = 1.0 (maximum valid)
+ threshold := 1.0
+ rs := &RuntimeSettings{
+ MemoryReclaimerThreshold: &threshold,
+ }
+ appConfig.ApplyRuntimeSettings(rs)
+ Expect(appConfig.MemoryReclaimerThreshold).To(Equal(1.0))
+
+ // Test threshold just above 0
+ threshold = 0.01
+ rs = &RuntimeSettings{
+ MemoryReclaimerThreshold: &threshold,
+ }
+ appConfig.ApplyRuntimeSettings(rs)
+ Expect(appConfig.MemoryReclaimerThreshold).To(Equal(0.01))
+ })
+
+ It("should apply performance settings without triggering watchdog change", func() {
+ appConfig := &ApplicationConfig{}
+
+ threads := 16
+ contextSize := 8192
+ f16 := true
+ debug := true
+
+ rs := &RuntimeSettings{
+ Threads: &threads,
+ ContextSize: &contextSize,
+ F16: &f16,
+ Debug: &debug,
+ }
+
+ changed := appConfig.ApplyRuntimeSettings(rs)
+
+ // These settings don't require watchdog restart
+ Expect(changed).To(BeFalse())
+ Expect(appConfig.Threads).To(Equal(16))
+ Expect(appConfig.ContextSize).To(Equal(8192))
+ Expect(appConfig.F16).To(BeTrue())
+ Expect(appConfig.Debug).To(BeTrue())
+ })
+
+ It("should apply CORS and security settings", func() {
+ appConfig := &ApplicationConfig{}
+
+ cors := true
+ csrf := true
+ origins := "https://example.com,https://other.com"
+
+ rs := &RuntimeSettings{
+ CORS: &cors,
+ CSRF: &csrf,
+ CORSAllowOrigins: &origins,
+ }
+
+ appConfig.ApplyRuntimeSettings(rs)
+
+ Expect(appConfig.CORS).To(BeTrue())
+ Expect(appConfig.CSRF).To(BeTrue())
+ Expect(appConfig.CORSAllowOrigins).To(Equal("https://example.com,https://other.com"))
+ })
+
+ It("should apply P2P settings", func() {
+ appConfig := &ApplicationConfig{}
+
+ token := "p2p-test-token"
+ networkID := "p2p-test-network"
+ federated := true
+
+ rs := &RuntimeSettings{
+ P2PToken: &token,
+ P2PNetworkID: &networkID,
+ Federated: &federated,
+ }
+
+ appConfig.ApplyRuntimeSettings(rs)
+
+ Expect(appConfig.P2PToken).To(Equal("p2p-test-token"))
+ Expect(appConfig.P2PNetworkID).To(Equal("p2p-test-network"))
+ Expect(appConfig.Federated).To(BeTrue())
+ })
+
+ It("should apply gallery settings", func() {
+ appConfig := &ApplicationConfig{}
+
+ galleries := []Gallery{
+ {Name: "gallery1", URL: "https://gallery1.com"},
+ {Name: "gallery2", URL: "https://gallery2.com"},
+ }
+ backendGalleries := []Gallery{
+ {Name: "backend-gallery", URL: "https://backend.com"},
+ }
+ autoload := true
+ autoloadBackend := true
+
+ rs := &RuntimeSettings{
+ Galleries: &galleries,
+ BackendGalleries: &backendGalleries,
+ AutoloadGalleries: &autoload,
+ AutoloadBackendGalleries: &autoloadBackend,
+ }
+
+ appConfig.ApplyRuntimeSettings(rs)
+
+ Expect(appConfig.Galleries).To(HaveLen(2))
+ Expect(appConfig.Galleries[0].Name).To(Equal("gallery1"))
+ Expect(appConfig.BackendGalleries).To(HaveLen(1))
+ Expect(appConfig.AutoloadGalleries).To(BeTrue())
+ Expect(appConfig.AutoloadBackendGalleries).To(BeTrue())
+ })
+
+ It("should apply agent settings", func() {
+ appConfig := &ApplicationConfig{}
+
+ retentionDays := 14
+
+ rs := &RuntimeSettings{
+ AgentJobRetentionDays: &retentionDays,
+ }
+
+ appConfig.ApplyRuntimeSettings(rs)
+
+ Expect(appConfig.AgentJobRetentionDays).To(Equal(14))
+ })
+ })
+
+ Describe("Round-trip conversion", func() {
+ It("should maintain values through ToRuntimeSettings -> ApplyRuntimeSettings", func() {
+ original := &ApplicationConfig{
+ WatchDog: true,
+ WatchDogIdle: true,
+ WatchDogBusy: false,
+ WatchDogIdleTimeout: 25 * time.Minute,
+ WatchDogBusyTimeout: 12 * time.Minute,
+ SingleBackend: false,
+ MaxActiveBackends: 3,
+ ParallelBackendRequests: true,
+ MemoryReclaimerEnabled: true,
+ MemoryReclaimerThreshold: 0.92,
+ Threads: 12,
+ ContextSize: 6144,
+ F16: true,
+ Debug: false,
+ CORS: true,
+ CSRF: false,
+ CORSAllowOrigins: "https://test.com",
+ P2PToken: "round-trip-token",
+ P2PNetworkID: "round-trip-network",
+ Federated: true,
+ AutoloadGalleries: true,
+ AutoloadBackendGalleries: false,
+ AgentJobRetentionDays: 60,
+ }
+
+ // Convert to RuntimeSettings
+ rs := original.ToRuntimeSettings()
+
+ // Apply to a new ApplicationConfig
+ target := &ApplicationConfig{}
+ target.ApplyRuntimeSettings(&rs)
+
+ // Verify all values match
+ Expect(target.WatchDog).To(Equal(original.WatchDog))
+ Expect(target.WatchDogIdle).To(Equal(original.WatchDogIdle))
+ Expect(target.WatchDogBusy).To(Equal(original.WatchDogBusy))
+ Expect(target.WatchDogIdleTimeout).To(Equal(original.WatchDogIdleTimeout))
+ Expect(target.WatchDogBusyTimeout).To(Equal(original.WatchDogBusyTimeout))
+ Expect(target.MaxActiveBackends).To(Equal(original.MaxActiveBackends))
+ Expect(target.ParallelBackendRequests).To(Equal(original.ParallelBackendRequests))
+ Expect(target.MemoryReclaimerEnabled).To(Equal(original.MemoryReclaimerEnabled))
+ Expect(target.MemoryReclaimerThreshold).To(Equal(original.MemoryReclaimerThreshold))
+ Expect(target.Threads).To(Equal(original.Threads))
+ Expect(target.ContextSize).To(Equal(original.ContextSize))
+ Expect(target.F16).To(Equal(original.F16))
+ Expect(target.Debug).To(Equal(original.Debug))
+ Expect(target.CORS).To(Equal(original.CORS))
+ Expect(target.CSRF).To(Equal(original.CSRF))
+ Expect(target.CORSAllowOrigins).To(Equal(original.CORSAllowOrigins))
+ Expect(target.P2PToken).To(Equal(original.P2PToken))
+ Expect(target.P2PNetworkID).To(Equal(original.P2PNetworkID))
+ Expect(target.Federated).To(Equal(original.Federated))
+ Expect(target.AutoloadGalleries).To(Equal(original.AutoloadGalleries))
+ Expect(target.AutoloadBackendGalleries).To(Equal(original.AutoloadBackendGalleries))
+ Expect(target.AgentJobRetentionDays).To(Equal(original.AgentJobRetentionDays))
+ })
+
+ It("should handle empty galleries correctly in round-trip", func() {
+ original := &ApplicationConfig{
+ Galleries: []Gallery{},
+ BackendGalleries: []Gallery{},
+ ApiKeys: []string{},
+ }
+
+ rs := original.ToRuntimeSettings()
+ target := &ApplicationConfig{}
+ target.ApplyRuntimeSettings(&rs)
+
+ Expect(target.Galleries).To(BeEmpty())
+ Expect(target.BackendGalleries).To(BeEmpty())
+ })
+ })
+
+ Describe("Edge cases", func() {
+ It("should handle invalid timeout string in ApplyRuntimeSettings", func() {
+ appConfig := &ApplicationConfig{
+ WatchDogIdleTimeout: 10 * time.Minute,
+ }
+
+ invalidTimeout := "not-a-duration"
+ rs := &RuntimeSettings{
+ WatchdogIdleTimeout: &invalidTimeout,
+ }
+
+ appConfig.ApplyRuntimeSettings(rs)
+
+ // Should remain unchanged due to parse error
+ Expect(appConfig.WatchDogIdleTimeout).To(Equal(10 * time.Minute))
+ })
+
+ It("should handle zero values in ApplicationConfig", func() {
+ appConfig := &ApplicationConfig{
+ // All zero values
+ }
+
+ rs := appConfig.ToRuntimeSettings()
+
+ // Should still have non-nil pointers with zero/default values
+ Expect(rs.WatchdogEnabled).ToNot(BeNil())
+ Expect(*rs.WatchdogEnabled).To(BeFalse())
+
+ Expect(rs.Threads).ToNot(BeNil())
+ Expect(*rs.Threads).To(Equal(0))
+
+ Expect(rs.MemoryReclaimerThreshold).ToNot(BeNil())
+ Expect(*rs.MemoryReclaimerThreshold).To(Equal(0.0))
+ })
+
+ It("should prefer MaxActiveBackends over SingleBackend when both are set", func() {
+ appConfig := &ApplicationConfig{}
+
+ maxBackends := 3
+ singleBackend := true
+
+ rs := &RuntimeSettings{
+ MaxActiveBackends: &maxBackends,
+ SingleBackend: &singleBackend,
+ }
+
+ appConfig.ApplyRuntimeSettings(rs)
+
+ // MaxActiveBackends should take precedence
+ Expect(appConfig.MaxActiveBackends).To(Equal(3))
+ Expect(appConfig.SingleBackend).To(BeFalse()) // 3 != 1, so single backend is false
+ })
+ })
+})
diff --git a/core/config/config_suite_test.go b/core/config/config_suite_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..56052cbb7663f0f8679d0e282ad12b6c17436db9
--- /dev/null
+++ b/core/config/config_suite_test.go
@@ -0,0 +1,13 @@
+package config_test
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestConfig(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "Config test suite")
+}
diff --git a/core/config/gallery.go b/core/config/gallery.go
new file mode 100644
index 0000000000000000000000000000000000000000..002100be5fb281dcb5202a9cf7c798addd0c1505
--- /dev/null
+++ b/core/config/gallery.go
@@ -0,0 +1,6 @@
+package config
+
+type Gallery struct {
+ URL string `json:"url" yaml:"url"`
+ Name string `json:"name" yaml:"name"`
+}
diff --git a/core/config/gguf.go b/core/config/gguf.go
new file mode 100644
index 0000000000000000000000000000000000000000..f63acd35f3c9f29cb9e21268053310edef20b2f0
--- /dev/null
+++ b/core/config/gguf.go
@@ -0,0 +1,86 @@
+package config
+
+import (
+ "github.com/mudler/LocalAI/pkg/xsysinfo"
+ "github.com/mudler/xlog"
+
+ gguf "github.com/gpustack/gguf-parser-go"
+)
+
+const (
+ defaultContextSize = 1024
+ defaultNGPULayers = 99999999
+)
+
+func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
+
+ if defaultCtx == 0 && cfg.ContextSize == nil {
+ ctxSize := f.EstimateLLaMACppRun().ContextSize
+ if ctxSize > 0 {
+ cSize := int(ctxSize)
+ cfg.ContextSize = &cSize
+ } else {
+ defaultCtx = defaultContextSize
+ cfg.ContextSize = &defaultCtx
+ }
+ }
+
+ // GPU options
+ if cfg.Options == nil {
+ if xsysinfo.HasGPU("nvidia") || xsysinfo.HasGPU("amd") {
+ cfg.Options = []string{"gpu"}
+ }
+ }
+
+ // vram estimation
+ vram, err := xsysinfo.TotalAvailableVRAM()
+ if err != nil {
+ xlog.Error("guessDefaultsFromFile(TotalAvailableVRAM)", "error", err)
+ } else if vram > 0 {
+ estimate, err := xsysinfo.EstimateGGUFVRAMUsage(f, vram)
+ if err != nil {
+ xlog.Error("guessDefaultsFromFile(EstimateGGUFVRAMUsage)", "error", err)
+ } else {
+ if estimate.IsFullOffload {
+ xlog.Warn("guessDefaultsFromFile: full offload is recommended")
+ }
+
+ if estimate.EstimatedVRAM > vram {
+ xlog.Warn("guessDefaultsFromFile: estimated VRAM usage is greater than available VRAM")
+ }
+
+ if cfg.NGPULayers == nil && estimate.EstimatedLayers > 0 {
+ xlog.Debug("guessDefaultsFromFile: layers estimated", "layers", estimate.EstimatedLayers)
+ cfg.NGPULayers = &estimate.EstimatedLayers
+ }
+ }
+ }
+
+ if cfg.NGPULayers == nil {
+ // we assume we want to offload all layers
+ defaultHigh := defaultNGPULayers
+ cfg.NGPULayers = &defaultHigh
+ }
+
+ xlog.Debug("guessDefaultsFromFile: NGPULayers set", "NGPULayers", cfg.NGPULayers)
+
+ // template estimations
+ if cfg.HasTemplate() {
+ // nothing to guess here
+ xlog.Debug("guessDefaultsFromFile: template already set", "name", cfg.Name)
+ return
+ }
+
+ xlog.Debug("Model file loaded", "file", cfg.ModelFileName(), "eosTokenID", f.Tokenizer().EOSTokenID, "bosTokenID", f.Tokenizer().BOSTokenID, "modelName", f.Metadata().Name, "architecture", f.Architecture().Architecture)
+
+ // guess the name
+ if cfg.Name == "" {
+ cfg.Name = f.Metadata().Name
+ }
+
+ // Instruct to use template from llama.cpp
+ cfg.TemplateConfig.UseTokenizerTemplate = true
+ cfg.FunctionsConfig.GrammarConfig.NoGrammar = true
+ cfg.Options = append(cfg.Options, "use_jinja:true")
+ cfg.KnownUsecaseStrings = append(cfg.KnownUsecaseStrings, "FLAG_CHAT")
+}
diff --git a/core/config/guesser.go b/core/config/guesser.go
new file mode 100644
index 0000000000000000000000000000000000000000..e4ca5b1415f978de4920232746675e996016f5f9
--- /dev/null
+++ b/core/config/guesser.go
@@ -0,0 +1,46 @@
+package config
+
+import (
+ "os"
+ "path/filepath"
+
+ gguf "github.com/gpustack/gguf-parser-go"
+ "github.com/mudler/xlog"
+)
+
+func guessDefaultsFromFile(cfg *ModelConfig, modelPath string, defaultCtx int) {
+ if os.Getenv("LOCALAI_DISABLE_GUESSING") == "true" {
+ xlog.Debug("guessDefaultsFromFile: guessing disabled with LOCALAI_DISABLE_GUESSING")
+ return
+ }
+
+ if modelPath == "" {
+ xlog.Debug("guessDefaultsFromFile: modelPath is empty")
+ return
+ }
+
+ // We try to guess only if we don't have a template defined already
+ guessPath := filepath.Join(modelPath, cfg.ModelFileName())
+
+ defer func() {
+ if r := recover(); r != nil {
+ xlog.Error("guessDefaultsFromFile: panic while parsing gguf file")
+ }
+ }()
+
+ defer func() {
+ if cfg.ContextSize == nil {
+ if defaultCtx == 0 {
+ defaultCtx = defaultContextSize
+ }
+ cfg.ContextSize = &defaultCtx
+ }
+ }()
+
+ // try to parse the gguf file
+ f, err := gguf.ParseGGUFFile(guessPath)
+ if err == nil {
+ guessGGUFFromFile(cfg, f, defaultCtx)
+ return
+ }
+}
diff --git a/core/config/model_config.go b/core/config/model_config.go
new file mode 100644
index 0000000000000000000000000000000000000000..9010c84e60c353d311e3599a8b0f3fa17f617561
--- /dev/null
+++ b/core/config/model_config.go
@@ -0,0 +1,722 @@
+package config
+
+import (
+ "fmt"
+ "os"
+ "regexp"
+ "slices"
+ "strings"
+
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/downloader"
+ "github.com/mudler/LocalAI/pkg/functions"
+ "github.com/mudler/cogito"
+ "gopkg.in/yaml.v3"
+)
+
+const (
+ RAND_SEED = -1
+)
+
+// @Description TTS configuration
+type TTSConfig struct {
+
+ // Voice wav path or id
+ Voice string `yaml:"voice,omitempty" json:"voice,omitempty"`
+
+ AudioPath string `yaml:"audio_path,omitempty" json:"audio_path,omitempty"`
+}
+
+// @Description ModelConfig represents a model configuration
+type ModelConfig struct {
+ modelConfigFile string `yaml:"-" json:"-"`
+ schema.PredictionOptions `yaml:"parameters,omitempty" json:"parameters,omitempty"`
+ Name string `yaml:"name,omitempty" json:"name,omitempty"`
+
+ F16 *bool `yaml:"f16,omitempty" json:"f16,omitempty"`
+ Threads *int `yaml:"threads,omitempty" json:"threads,omitempty"`
+ Debug *bool `yaml:"debug,omitempty" json:"debug,omitempty"`
+ Roles map[string]string `yaml:"roles,omitempty" json:"roles,omitempty"`
+ Embeddings *bool `yaml:"embeddings,omitempty" json:"embeddings,omitempty"`
+ Backend string `yaml:"backend,omitempty" json:"backend,omitempty"`
+ TemplateConfig TemplateConfig `yaml:"template,omitempty" json:"template,omitempty"`
+ KnownUsecaseStrings []string `yaml:"known_usecases,omitempty" json:"known_usecases,omitempty"`
+ KnownUsecases *ModelConfigUsecase `yaml:"-" json:"-"`
+ Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"`
+
+ PromptStrings, InputStrings []string `yaml:"-" json:"-"`
+ InputToken [][]int `yaml:"-" json:"-"`
+ functionCallString, functionCallNameString string `yaml:"-" json:"-"`
+ ResponseFormat string `yaml:"-" json:"-"`
+ ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"`
+
+ FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
+
+ FeatureFlag FeatureFlag `yaml:"feature_flags,omitempty" json:"feature_flags,omitempty"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
+ // LLM configs (GPT4ALL, Llama.cpp, ...)
+ LLMConfig `yaml:",inline" json:",inline"`
+
+ // Diffusers
+ Diffusers Diffusers `yaml:"diffusers,omitempty" json:"diffusers,omitempty"`
+ Step int `yaml:"step,omitempty" json:"step,omitempty"`
+
+ // GRPC Options
+ GRPC GRPC `yaml:"grpc,omitempty" json:"grpc,omitempty"`
+
+ // TTS specifics
+ TTSConfig `yaml:"tts,omitempty" json:"tts,omitempty"`
+
+ // CUDA
+ // Explicitly enable CUDA or not (some backends might need it)
+ CUDA bool `yaml:"cuda,omitempty" json:"cuda,omitempty"`
+
+ DownloadFiles []File `yaml:"download_files,omitempty" json:"download_files,omitempty"`
+
+ Description string `yaml:"description,omitempty" json:"description,omitempty"`
+ Usage string `yaml:"usage,omitempty" json:"usage,omitempty"`
+
+ Options []string `yaml:"options,omitempty" json:"options,omitempty"`
+ Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`
+
+ MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"`
+ Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"`
+}
+
+// @Description MCP configuration
+type MCPConfig struct {
+ Servers string `yaml:"remote,omitempty" json:"remote,omitempty"`
+ Stdio string `yaml:"stdio,omitempty" json:"stdio,omitempty"`
+}
+
+// @Description Agent configuration
+type AgentConfig struct {
+ MaxAttempts int `yaml:"max_attempts,omitempty" json:"max_attempts,omitempty"`
+ MaxIterations int `yaml:"max_iterations,omitempty" json:"max_iterations,omitempty"`
+ EnableReasoning bool `yaml:"enable_reasoning,omitempty" json:"enable_reasoning,omitempty"`
+ EnablePlanning bool `yaml:"enable_planning,omitempty" json:"enable_planning,omitempty"`
+ EnableMCPPrompts bool `yaml:"enable_mcp_prompts,omitempty" json:"enable_mcp_prompts,omitempty"`
+ EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator,omitempty" json:"enable_plan_re_evaluator,omitempty"`
+}
+
+func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
+ var remote MCPGenericConfig[MCPRemoteServers]
+ var stdio MCPGenericConfig[MCPSTDIOServers]
+
+ if err := yaml.Unmarshal([]byte(c.Servers), &remote); err != nil {
+ return remote, stdio, err
+ }
+
+ if err := yaml.Unmarshal([]byte(c.Stdio), &stdio); err != nil {
+ return remote, stdio, err
+ }
+ return remote, stdio, nil
+}
+
+// @Description MCP generic configuration
+type MCPGenericConfig[T any] struct {
+ Servers T `yaml:"mcpServers,omitempty" json:"mcpServers,omitempty"`
+}
+type MCPRemoteServers map[string]MCPRemoteServer
+type MCPSTDIOServers map[string]MCPSTDIOServer
+
+// @Description MCP remote server configuration
+type MCPRemoteServer struct {
+ URL string `json:"url,omitempty"`
+ Token string `json:"token,omitempty"`
+}
+
+// @Description MCP STDIO server configuration
+type MCPSTDIOServer struct {
+ Args []string `json:"args,omitempty"`
+ Env map[string]string `json:"env,omitempty"`
+ Command string `json:"command,omitempty"`
+}
+
+// @Description Pipeline defines other models to use for audio-to-audio
+type Pipeline struct {
+ TTS string `yaml:"tts,omitempty" json:"tts,omitempty"`
+ LLM string `yaml:"llm,omitempty" json:"llm,omitempty"`
+ Transcription string `yaml:"transcription,omitempty" json:"transcription,omitempty"`
+ VAD string `yaml:"vad,omitempty" json:"vad,omitempty"`
+}
+
+// @Description File configuration for model downloads
+type File struct {
+ Filename string `yaml:"filename,omitempty" json:"filename,omitempty"`
+ SHA256 string `yaml:"sha256,omitempty" json:"sha256,omitempty"`
+ URI downloader.URI `yaml:"uri,omitempty" json:"uri,omitempty"`
+}
+
+type FeatureFlag map[string]*bool
+
+func (ff FeatureFlag) Enabled(s string) bool {
+ if v, exists := ff[s]; exists && v != nil {
+ return *v
+ }
+ return false
+}
+
+// @Description GRPC configuration
+type GRPC struct {
+ Attempts int `yaml:"attempts,omitempty" json:"attempts,omitempty"`
+ AttemptsSleepTime int `yaml:"attempts_sleep_time,omitempty" json:"attempts_sleep_time,omitempty"`
+}
+
+// @Description Diffusers configuration
+type Diffusers struct {
+ CUDA bool `yaml:"cuda,omitempty" json:"cuda,omitempty"`
+ PipelineType string `yaml:"pipeline_type,omitempty" json:"pipeline_type,omitempty"`
+ SchedulerType string `yaml:"scheduler_type,omitempty" json:"scheduler_type,omitempty"`
+ EnableParameters string `yaml:"enable_parameters,omitempty" json:"enable_parameters,omitempty"` // A list of comma separated parameters to specify
+ IMG2IMG bool `yaml:"img2img,omitempty" json:"img2img,omitempty"` // Image to Image Diffuser
+ ClipSkip int `yaml:"clip_skip,omitempty" json:"clip_skip,omitempty"` // Skip every N frames
+ ClipModel string `yaml:"clip_model,omitempty" json:"clip_model,omitempty"` // Clip model to use
+ ClipSubFolder string `yaml:"clip_subfolder,omitempty" json:"clip_subfolder,omitempty"` // Subfolder to use for clip model
+ ControlNet string `yaml:"control_net,omitempty" json:"control_net,omitempty"`
+}
+
+// @Description LLMConfig is a struct that holds the configuration that are generic for most of the LLM backends.
+type LLMConfig struct {
+ SystemPrompt string `yaml:"system_prompt,omitempty" json:"system_prompt,omitempty"`
+ TensorSplit string `yaml:"tensor_split,omitempty" json:"tensor_split,omitempty"`
+ MainGPU string `yaml:"main_gpu,omitempty" json:"main_gpu,omitempty"`
+ RMSNormEps float32 `yaml:"rms_norm_eps,omitempty" json:"rms_norm_eps,omitempty"`
+ NGQA int32 `yaml:"ngqa,omitempty" json:"ngqa,omitempty"`
+ PromptCachePath string `yaml:"prompt_cache_path,omitempty" json:"prompt_cache_path,omitempty"`
+ PromptCacheAll bool `yaml:"prompt_cache_all,omitempty" json:"prompt_cache_all,omitempty"`
+ PromptCacheRO bool `yaml:"prompt_cache_ro,omitempty" json:"prompt_cache_ro,omitempty"`
+ MirostatETA *float64 `yaml:"mirostat_eta,omitempty" json:"mirostat_eta,omitempty"`
+ MirostatTAU *float64 `yaml:"mirostat_tau,omitempty" json:"mirostat_tau,omitempty"`
+ Mirostat *int `yaml:"mirostat,omitempty" json:"mirostat,omitempty"`
+ NGPULayers *int `yaml:"gpu_layers,omitempty" json:"gpu_layers,omitempty"`
+ MMap *bool `yaml:"mmap,omitempty" json:"mmap,omitempty"`
+ MMlock *bool `yaml:"mmlock,omitempty" json:"mmlock,omitempty"`
+ LowVRAM *bool `yaml:"low_vram,omitempty" json:"low_vram,omitempty"`
+ Reranking *bool `yaml:"reranking,omitempty" json:"reranking,omitempty"`
+ Grammar string `yaml:"grammar,omitempty" json:"grammar,omitempty"`
+ StopWords []string `yaml:"stopwords,omitempty" json:"stopwords,omitempty"`
+ Cutstrings []string `yaml:"cutstrings,omitempty" json:"cutstrings,omitempty"`
+ ExtractRegex []string `yaml:"extract_regex,omitempty" json:"extract_regex,omitempty"`
+ TrimSpace []string `yaml:"trimspace,omitempty" json:"trimspace,omitempty"`
+ TrimSuffix []string `yaml:"trimsuffix,omitempty" json:"trimsuffix,omitempty"`
+
+ ContextSize *int `yaml:"context_size,omitempty" json:"context_size,omitempty"`
+ NUMA bool `yaml:"numa,omitempty" json:"numa,omitempty"`
+ LoraAdapter string `yaml:"lora_adapter,omitempty" json:"lora_adapter,omitempty"`
+ LoraBase string `yaml:"lora_base,omitempty" json:"lora_base,omitempty"`
+ LoraAdapters []string `yaml:"lora_adapters,omitempty" json:"lora_adapters,omitempty"`
+ LoraScales []float32 `yaml:"lora_scales,omitempty" json:"lora_scales,omitempty"`
+ LoraScale float32 `yaml:"lora_scale,omitempty" json:"lora_scale,omitempty"`
+ NoMulMatQ bool `yaml:"no_mulmatq,omitempty" json:"no_mulmatq,omitempty"`
+ DraftModel string `yaml:"draft_model,omitempty" json:"draft_model,omitempty"`
+ NDraft int32 `yaml:"n_draft,omitempty" json:"n_draft,omitempty"`
+ Quantization string `yaml:"quantization,omitempty" json:"quantization,omitempty"`
+ LoadFormat string `yaml:"load_format,omitempty" json:"load_format,omitempty"`
+ GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization,omitempty" json:"gpu_memory_utilization,omitempty"` // vLLM
+ TrustRemoteCode bool `yaml:"trust_remote_code,omitempty" json:"trust_remote_code,omitempty"` // vLLM
+ EnforceEager bool `yaml:"enforce_eager,omitempty" json:"enforce_eager,omitempty"` // vLLM
+ SwapSpace int `yaml:"swap_space,omitempty" json:"swap_space,omitempty"` // vLLM
+ MaxModelLen int `yaml:"max_model_len,omitempty" json:"max_model_len,omitempty"` // vLLM
+ TensorParallelSize int `yaml:"tensor_parallel_size,omitempty" json:"tensor_parallel_size,omitempty"` // vLLM
+ DisableLogStatus bool `yaml:"disable_log_stats,omitempty" json:"disable_log_stats,omitempty"` // vLLM
+ DType string `yaml:"dtype,omitempty" json:"dtype,omitempty"` // vLLM
+ LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt,omitempty" json:"limit_mm_per_prompt,omitempty"` // vLLM
+ MMProj string `yaml:"mmproj,omitempty" json:"mmproj,omitempty"`
+
+ FlashAttention *string `yaml:"flash_attention,omitempty" json:"flash_attention,omitempty"`
+ NoKVOffloading bool `yaml:"no_kv_offloading,omitempty" json:"no_kv_offloading,omitempty"`
+ CacheTypeK string `yaml:"cache_type_k,omitempty" json:"cache_type_k,omitempty"`
+ CacheTypeV string `yaml:"cache_type_v,omitempty" json:"cache_type_v,omitempty"`
+
+ RopeScaling string `yaml:"rope_scaling,omitempty" json:"rope_scaling,omitempty"`
+ ModelType string `yaml:"type,omitempty" json:"type,omitempty"`
+
+ YarnExtFactor float32 `yaml:"yarn_ext_factor,omitempty" json:"yarn_ext_factor,omitempty"`
+ YarnAttnFactor float32 `yaml:"yarn_attn_factor,omitempty" json:"yarn_attn_factor,omitempty"`
+ YarnBetaFast float32 `yaml:"yarn_beta_fast,omitempty" json:"yarn_beta_fast,omitempty"`
+ YarnBetaSlow float32 `yaml:"yarn_beta_slow,omitempty" json:"yarn_beta_slow,omitempty"`
+
+ CFGScale float32 `yaml:"cfg_scale,omitempty" json:"cfg_scale,omitempty"` // Classifier-Free Guidance Scale
+}
+
+// @Description LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
+type LimitMMPerPrompt struct {
+ LimitImagePerPrompt int `yaml:"image,omitempty" json:"image,omitempty"`
+ LimitVideoPerPrompt int `yaml:"video,omitempty" json:"video,omitempty"`
+ LimitAudioPerPrompt int `yaml:"audio,omitempty" json:"audio,omitempty"`
+}
+
+// @Description TemplateConfig is a struct that holds the configuration of the templating system
+type TemplateConfig struct {
+ // Chat is the template used in the chat completion endpoint
+ Chat string `yaml:"chat,omitempty" json:"chat,omitempty"`
+
+ // ChatMessage is the template used for chat messages
+ ChatMessage string `yaml:"chat_message,omitempty" json:"chat_message,omitempty"`
+
+ // Completion is the template used for completion requests
+ Completion string `yaml:"completion,omitempty" json:"completion,omitempty"`
+
+ // Edit is the template used for edit completion requests
+ Edit string `yaml:"edit,omitempty" json:"edit,omitempty"`
+
+ // Functions is the template used when tools are present in the client requests
+ Functions string `yaml:"function,omitempty" json:"function,omitempty"`
+
+ // UseTokenizerTemplate is a flag that indicates if the tokenizer template should be used.
+ // Note: this is mostly consumed for backends such as vllm and transformers
+ // that can use the tokenizers specified in the JSON config files of the models
+ UseTokenizerTemplate bool `yaml:"use_tokenizer_template,omitempty" json:"use_tokenizer_template,omitempty"`
+
+ // JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
+ // It defaults to \n
+ JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character,omitempty" json:"join_chat_messages_by_character,omitempty"`
+
+ Multimodal string `yaml:"multimodal,omitempty" json:"multimodal,omitempty"`
+
+ ReplyPrefix string `yaml:"reply_prefix,omitempty" json:"reply_prefix,omitempty"`
+}
+
+func (c *ModelConfig) syncKnownUsecasesFromString() {
+ c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
+ // Make sure the usecases are valid, we rewrite with what we identified
+ c.KnownUsecaseStrings = []string{}
+ for k, usecase := range GetAllModelConfigUsecases() {
+ if c.HasUsecases(usecase) {
+ c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k)
+ }
+ }
+}
+
+func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {
+ type BCAlias ModelConfig
+ var aux BCAlias
+ if err := value.Decode(&aux); err != nil {
+ return err
+ }
+
+ mc := ModelConfig(aux)
+ *c = mc
+ c.syncKnownUsecasesFromString()
+ return nil
+}
+
+func (c *ModelConfig) SetFunctionCallString(s string) {
+ c.functionCallString = s
+}
+
+func (c *ModelConfig) SetFunctionCallNameString(s string) {
+ c.functionCallNameString = s
+}
+
+func (c *ModelConfig) ShouldUseFunctions() bool {
+ return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
+}
+
+func (c *ModelConfig) ShouldCallSpecificFunction() bool {
+ return len(c.functionCallNameString) > 0
+}
+
+// MMProjFileName returns the filename of the MMProj file
+// If the MMProj is a URL, it will return the MD5 of the URL which is the filename
+func (c *ModelConfig) MMProjFileName() string {
+ uri := downloader.URI(c.MMProj)
+ if uri.LooksLikeURL() {
+ f, _ := uri.FilenameFromUrl()
+ return f
+ }
+
+ return c.MMProj
+}
+
+func (c *ModelConfig) IsMMProjURL() bool {
+ uri := downloader.URI(c.MMProj)
+ return uri.LooksLikeURL()
+}
+
+func (c *ModelConfig) IsModelURL() bool {
+ uri := downloader.URI(c.Model)
+ return uri.LooksLikeURL()
+}
+
+// ModelFileName returns the filename of the model
+// If the model is a URL, it will return the MD5 of the URL which is the filename
+func (c *ModelConfig) ModelFileName() string {
+ uri := downloader.URI(c.Model)
+ if uri.LooksLikeURL() {
+ f, _ := uri.FilenameFromUrl()
+ return f
+ }
+
+ return c.Model
+}
+
+func (c *ModelConfig) FunctionToCall() string {
+ if c.functionCallNameString != "" &&
+ c.functionCallNameString != "none" && c.functionCallNameString != "auto" {
+ return c.functionCallNameString
+ }
+
+ return c.functionCallString
+}
+
+func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
+ lo := &LoadOptions{}
+ lo.Apply(opts...)
+
+ ctx := lo.ctxSize
+ threads := lo.threads
+ f16 := lo.f16
+ debug := lo.debug
+ // https://github.com/ggerganov/llama.cpp/blob/75cd4c77292034ecec587ecb401366f57338f7c0/common/sampling.h#L22
+ defaultTopP := 0.95
+ defaultTopK := 40
+ defaultTemp := 0.9
+ // https://github.com/mudler/LocalAI/issues/2780
+ defaultMirostat := 0
+ defaultMirostatTAU := 5.0
+ defaultMirostatETA := 0.1
+ defaultTypicalP := 1.0
+ defaultTFZ := 1.0
+ defaultZero := 0
+
+ trueV := true
+ falseV := false
+
+ if cfg.Seed == nil {
+ // random number generator seed
+ defaultSeed := RAND_SEED
+ cfg.Seed = &defaultSeed
+ }
+
+ if cfg.TopK == nil {
+ cfg.TopK = &defaultTopK
+ }
+
+ if cfg.TypicalP == nil {
+ cfg.TypicalP = &defaultTypicalP
+ }
+
+ if cfg.TFZ == nil {
+ cfg.TFZ = &defaultTFZ
+ }
+
+ if cfg.MMap == nil {
+ // MMap is enabled by default
+
+ // Only exception is for Intel GPUs
+ if os.Getenv("XPU") != "" {
+ cfg.MMap = &falseV
+ } else {
+ cfg.MMap = &trueV
+ }
+ }
+
+ if cfg.MMlock == nil {
+ // MMlock is disabled by default
+ cfg.MMlock = &falseV
+ }
+
+ if cfg.TopP == nil {
+ cfg.TopP = &defaultTopP
+ }
+ if cfg.Temperature == nil {
+ cfg.Temperature = &defaultTemp
+ }
+
+ if cfg.Maxtokens == nil {
+ cfg.Maxtokens = &defaultZero
+ }
+
+ if cfg.Mirostat == nil {
+ cfg.Mirostat = &defaultMirostat
+ }
+
+ if cfg.MirostatETA == nil {
+ cfg.MirostatETA = &defaultMirostatETA
+ }
+
+ if cfg.MirostatTAU == nil {
+ cfg.MirostatTAU = &defaultMirostatTAU
+ }
+
+ if cfg.LowVRAM == nil {
+ cfg.LowVRAM = &falseV
+ }
+
+ if cfg.Embeddings == nil {
+ cfg.Embeddings = &falseV
+ }
+
+ if cfg.Reranking == nil {
+ cfg.Reranking = &falseV
+ }
+
+ if threads == 0 {
+ // Threads can't be 0
+ threads = 4
+ }
+
+ if cfg.Threads == nil {
+ cfg.Threads = &threads
+ }
+
+ if cfg.F16 == nil {
+ cfg.F16 = &f16
+ }
+
+ if cfg.Debug == nil {
+ cfg.Debug = &falseV
+ }
+
+ if debug {
+ cfg.Debug = &trueV
+ }
+
+ guessDefaultsFromFile(cfg, lo.modelPath, ctx)
+ cfg.syncKnownUsecasesFromString()
+}
+
+func (c *ModelConfig) Validate() (bool, error) {
+ downloadedFileNames := []string{}
+ for _, f := range c.DownloadFiles {
+ downloadedFileNames = append(downloadedFileNames, f.Filename)
+ }
+ validationTargets := []string{c.Backend, c.Model, c.MMProj}
+ validationTargets = append(validationTargets, downloadedFileNames...)
+ // Simple validation to make sure the model can be correctly loaded
+ for _, n := range validationTargets {
+ if n == "" {
+ continue
+ }
+ if strings.HasPrefix(n, string(os.PathSeparator)) ||
+ strings.Contains(n, "..") {
+ return false, fmt.Errorf("invalid file path: %s", n)
+ }
+ }
+
+ if c.Backend != "" {
+ // a regex that checks that is a string name with no special characters, except '-' and '_'
+ re := regexp.MustCompile(`^[a-zA-Z0-9-_]+$`)
+ if !re.MatchString(c.Backend) {
+ return false, fmt.Errorf("invalid backend name: %s", c.Backend)
+ }
+ }
+
+ // Validate MCP configuration if present
+ if c.MCP.Servers != "" || c.MCP.Stdio != "" {
+ if _, _, err := c.MCP.MCPConfigFromYAML(); err != nil {
+ return false, fmt.Errorf("invalid MCP configuration: %w", err)
+ }
+ }
+
+ return true, nil
+}
+
+func (c *ModelConfig) HasTemplate() bool {
+ return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != "" || c.TemplateConfig.UseTokenizerTemplate
+}
+
+func (c *ModelConfig) GetModelConfigFile() string {
+ return c.modelConfigFile
+}
+
+type ModelConfigUsecase int
+
+const (
+ FLAG_ANY ModelConfigUsecase = 0b000000000000
+ FLAG_CHAT ModelConfigUsecase = 0b000000000001
+ FLAG_COMPLETION ModelConfigUsecase = 0b000000000010
+ FLAG_EDIT ModelConfigUsecase = 0b000000000100
+ FLAG_EMBEDDINGS ModelConfigUsecase = 0b000000001000
+ FLAG_RERANK ModelConfigUsecase = 0b000000010000
+ FLAG_IMAGE ModelConfigUsecase = 0b000000100000
+ FLAG_TRANSCRIPT ModelConfigUsecase = 0b000001000000
+ FLAG_TTS ModelConfigUsecase = 0b000010000000
+ FLAG_SOUND_GENERATION ModelConfigUsecase = 0b000100000000
+ FLAG_TOKENIZE ModelConfigUsecase = 0b001000000000
+ FLAG_VAD ModelConfigUsecase = 0b010000000000
+ FLAG_VIDEO ModelConfigUsecase = 0b100000000000
+ FLAG_DETECTION ModelConfigUsecase = 0b1000000000000
+
+ // Common Subsets
+ FLAG_LLM ModelConfigUsecase = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
+)
+
+func GetAllModelConfigUsecases() map[string]ModelConfigUsecase {
+ return map[string]ModelConfigUsecase{
+ // Note: FLAG_ANY is intentionally excluded from this map
+ // because it's 0 and would always match in HasUsecases checks
+ "FLAG_CHAT": FLAG_CHAT,
+ "FLAG_COMPLETION": FLAG_COMPLETION,
+ "FLAG_EDIT": FLAG_EDIT,
+ "FLAG_EMBEDDINGS": FLAG_EMBEDDINGS,
+ "FLAG_RERANK": FLAG_RERANK,
+ "FLAG_IMAGE": FLAG_IMAGE,
+ "FLAG_TRANSCRIPT": FLAG_TRANSCRIPT,
+ "FLAG_TTS": FLAG_TTS,
+ "FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
+ "FLAG_TOKENIZE": FLAG_TOKENIZE,
+ "FLAG_VAD": FLAG_VAD,
+ "FLAG_LLM": FLAG_LLM,
+ "FLAG_VIDEO": FLAG_VIDEO,
+ "FLAG_DETECTION": FLAG_DETECTION,
+ }
+}
+
+func stringToFlag(s string) string {
+ return "FLAG_" + strings.ToUpper(s)
+}
+
+func GetUsecasesFromYAML(input []string) *ModelConfigUsecase {
+ if len(input) == 0 {
+ return nil
+ }
+ result := FLAG_ANY
+ flags := GetAllModelConfigUsecases()
+ for _, str := range input {
+ for _, flag := range []string{stringToFlag(str), str} {
+ f, exists := flags[flag]
+ if exists {
+ result |= f
+ }
+ }
+ }
+ return &result
+}
+
+// HasUsecases examines a ModelConfig and determines which endpoints have a chance of success.
+func (c *ModelConfig) HasUsecases(u ModelConfigUsecase) bool {
+ if (c.KnownUsecases != nil) && ((u & *c.KnownUsecases) == u) {
+ return true
+ }
+ return c.GuessUsecases(u)
+}
+
+// GuessUsecases is a **heuristic based** function, as the backend in question may not be loaded yet, and the config may not record what it's useful at.
+// In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half.
+// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
+func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
+ if (u & FLAG_CHAT) == FLAG_CHAT {
+ if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate {
+ return false
+ }
+ }
+ if (u & FLAG_COMPLETION) == FLAG_COMPLETION {
+ if c.TemplateConfig.Completion == "" {
+ return false
+ }
+ }
+ if (u & FLAG_EDIT) == FLAG_EDIT {
+ if c.TemplateConfig.Edit == "" {
+ return false
+ }
+ }
+ if (u & FLAG_EMBEDDINGS) == FLAG_EMBEDDINGS {
+ if c.Embeddings == nil || !*c.Embeddings {
+ return false
+ }
+ }
+ if (u & FLAG_IMAGE) == FLAG_IMAGE {
+ imageBackends := []string{"diffusers", "stablediffusion", "stablediffusion-ggml"}
+ if !slices.Contains(imageBackends, c.Backend) {
+ return false
+ }
+
+ if c.Backend == "diffusers" && c.Diffusers.PipelineType == "" {
+ return false
+ }
+
+ }
+ if (u & FLAG_VIDEO) == FLAG_VIDEO {
+ videoBackends := []string{"diffusers", "stablediffusion"}
+ if !slices.Contains(videoBackends, c.Backend) {
+ return false
+ }
+
+ if c.Backend == "diffusers" && c.Diffusers.PipelineType == "" {
+ return false
+ }
+
+ }
+ if (u & FLAG_RERANK) == FLAG_RERANK {
+ if c.Backend != "rerankers" {
+ return false
+ }
+ }
+ if (u & FLAG_TRANSCRIPT) == FLAG_TRANSCRIPT {
+ if c.Backend != "whisper" {
+ return false
+ }
+ }
+ if (u & FLAG_TTS) == FLAG_TTS {
+ ttsBackends := []string{"bark-cpp", "piper", "transformers-musicgen", "kokoro"}
+ if !slices.Contains(ttsBackends, c.Backend) {
+ return false
+ }
+ }
+
+ if (u & FLAG_DETECTION) == FLAG_DETECTION {
+ if c.Backend != "rfdetr" {
+ return false
+ }
+ }
+
+ if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION {
+ if c.Backend != "transformers-musicgen" {
+ return false
+ }
+ }
+
+ if (u & FLAG_TOKENIZE) == FLAG_TOKENIZE {
+ tokenizeCapableBackends := []string{"llama.cpp", "rwkv"}
+ if !slices.Contains(tokenizeCapableBackends, c.Backend) {
+ return false
+ }
+ }
+
+ if (u & FLAG_VAD) == FLAG_VAD {
+ if c.Backend != "silero-vad" {
+ return false
+ }
+ }
+
+ return true
+}
+
+// BuildCogitoOptions generates cogito options from the model configuration
+// It accepts a context, MCP sessions, and optional callback functions for status, reasoning, tool calls, and tool results
+func (c *ModelConfig) BuildCogitoOptions() []cogito.Option {
+ cogitoOpts := []cogito.Option{
+ cogito.WithIterations(3), // default to 3 iterations
+ cogito.WithMaxAttempts(3), // default to 3 attempts
+ cogito.WithForceReasoning(),
+ }
+
+ // Apply agent configuration options
+ if c.Agent.EnableReasoning {
+ cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
+ }
+
+ if c.Agent.EnablePlanning {
+ cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
+ }
+
+ if c.Agent.EnableMCPPrompts {
+ cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
+ }
+
+ if c.Agent.EnablePlanReEvaluator {
+ cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
+ }
+
+ if c.Agent.MaxIterations != 0 {
+ cogitoOpts = append(cogitoOpts, cogito.WithIterations(c.Agent.MaxIterations))
+ }
+
+ if c.Agent.MaxAttempts != 0 {
+ cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(c.Agent.MaxAttempts))
+ }
+
+ return cogitoOpts
+}
diff --git a/core/config/model_config_filter.go b/core/config/model_config_filter.go
new file mode 100644
index 0000000000000000000000000000000000000000..cb7cc0bfd45d6113c4702511a5ded3687fab53ad
--- /dev/null
+++ b/core/config/model_config_filter.go
@@ -0,0 +1,35 @@
+package config
+
+import "regexp"
+
+type ModelConfigFilterFn func(string, *ModelConfig) bool
+
+func NoFilterFn(_ string, _ *ModelConfig) bool { return true }
+
+func BuildNameFilterFn(filter string) (ModelConfigFilterFn, error) {
+ if filter == "" {
+ return NoFilterFn, nil
+ }
+ rxp, err := regexp.Compile(filter)
+ if err != nil {
+ return nil, err
+ }
+ return func(name string, config *ModelConfig) bool {
+ if config != nil {
+ return rxp.MatchString(config.Name)
+ }
+ return rxp.MatchString(name)
+ }, nil
+}
+
+func BuildUsecaseFilterFn(usecases ModelConfigUsecase) ModelConfigFilterFn {
+ if usecases == FLAG_ANY {
+ return NoFilterFn
+ }
+ return func(name string, config *ModelConfig) bool {
+ if config == nil {
+ return false // TODO: Potentially make this a param, for now, no known usecase to include
+ }
+ return config.HasUsecases(usecases)
+ }
+}
diff --git a/core/config/model_config_loader.go b/core/config/model_config_loader.go
new file mode 100644
index 0000000000000000000000000000000000000000..1a8c64230560da7a9d23b84de2852e515776b810
--- /dev/null
+++ b/core/config/model_config_loader.go
@@ -0,0 +1,380 @@
+package config
+
+import (
+ "errors"
+ "fmt"
+ "io/fs"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+ "sync"
+
+ "github.com/charmbracelet/glamour"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/downloader"
+ "github.com/mudler/LocalAI/pkg/utils"
+ "github.com/mudler/xlog"
+ "gopkg.in/yaml.v3"
+)
+
+type ModelConfigLoader struct {
+ configs map[string]ModelConfig
+ modelPath string
+ sync.Mutex
+}
+
+func NewModelConfigLoader(modelPath string) *ModelConfigLoader {
+ return &ModelConfigLoader{
+ configs: make(map[string]ModelConfig),
+ modelPath: modelPath,
+ }
+}
+
+type LoadOptions struct {
+ modelPath string
+ debug bool
+ threads, ctxSize int
+ f16 bool
+}
+
+func LoadOptionDebug(debug bool) ConfigLoaderOption {
+ return func(o *LoadOptions) {
+ o.debug = debug
+ }
+}
+
+func LoadOptionThreads(threads int) ConfigLoaderOption {
+ return func(o *LoadOptions) {
+ o.threads = threads
+ }
+}
+
+func LoadOptionContextSize(ctxSize int) ConfigLoaderOption {
+ return func(o *LoadOptions) {
+ o.ctxSize = ctxSize
+ }
+}
+
+func ModelPath(modelPath string) ConfigLoaderOption {
+ return func(o *LoadOptions) {
+ o.modelPath = modelPath
+ }
+}
+
+func LoadOptionF16(f16 bool) ConfigLoaderOption {
+ return func(o *LoadOptions) {
+ o.f16 = f16
+ }
+}
+
+type ConfigLoaderOption func(*LoadOptions)
+
+func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) {
+ for _, l := range options {
+ l(lo)
+ }
+}
+
+// TODO: either in the next PR or the next commit, I want to merge these down into a single function that looks at the first few characters of the file to determine if we need to deserialize to []BackendConfig or BackendConfig
+func readMultipleModelConfigsFromFile(file string, opts ...ConfigLoaderOption) ([]*ModelConfig, error) {
+ c := &[]*ModelConfig{}
+ f, err := os.ReadFile(file)
+ if err != nil {
+ return nil, fmt.Errorf("readMultipleModelConfigsFromFile cannot read config file %q: %w", file, err)
+ }
+ if err := yaml.Unmarshal(f, c); err != nil {
+ return nil, fmt.Errorf("readMultipleModelConfigsFromFile cannot unmarshal config file %q: %w", file, err)
+ }
+
+ for _, cc := range *c {
+ cc.modelConfigFile = file
+ cc.SetDefaults(opts...)
+ }
+
+ return *c, nil
+}
+
+func readModelConfigFromFile(file string, opts ...ConfigLoaderOption) (*ModelConfig, error) {
+ lo := &LoadOptions{}
+ lo.Apply(opts...)
+
+ c := &ModelConfig{}
+ f, err := os.ReadFile(file)
+ if err != nil {
+ return nil, fmt.Errorf("readModelConfigFromFile cannot read config file %q: %w", file, err)
+ }
+ if err := yaml.Unmarshal(f, c); err != nil {
+ return nil, fmt.Errorf("readModelConfigFromFile cannot unmarshal config file %q: %w", file, err)
+ }
+
+ c.SetDefaults(opts...)
+
+ c.modelConfigFile = file
+ return c, nil
+}
+
+// Load a config file for a model
+func (bcl *ModelConfigLoader) LoadModelConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*ModelConfig, error) {
+
+ // Load a config file if present after the model name
+ cfg := &ModelConfig{
+ PredictionOptions: schema.PredictionOptions{
+ BasicModelRequest: schema.BasicModelRequest{
+ Model: modelName,
+ },
+ },
+ }
+
+ cfgExisting, exists := bcl.GetModelConfig(modelName)
+ if exists {
+ cfg = &cfgExisting
+ } else {
+ // Try loading a model config file
+ modelConfig := filepath.Join(modelPath, modelName+".yaml")
+ if _, err := os.Stat(modelConfig); err == nil {
+ if err := bcl.ReadModelConfig(
+ modelConfig, opts...,
+ ); err != nil {
+ return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
+ }
+ cfgExisting, exists = bcl.GetModelConfig(modelName)
+ if exists {
+ cfg = &cfgExisting
+ }
+ }
+ }
+
+ cfg.SetDefaults(append(opts, ModelPath(modelPath))...)
+
+ return cfg, nil
+}
+
+func (bcl *ModelConfigLoader) LoadModelConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*ModelConfig, error) {
+ return bcl.LoadModelConfigFileByName(modelName, appConfig.SystemState.Model.ModelsPath,
+ LoadOptionDebug(appConfig.Debug),
+ LoadOptionThreads(appConfig.Threads),
+ LoadOptionContextSize(appConfig.ContextSize),
+ LoadOptionF16(appConfig.F16),
+ ModelPath(appConfig.SystemState.Model.ModelsPath))
+}
+
+// This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile
+func (bcl *ModelConfigLoader) LoadMultipleModelConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
+ bcl.Lock()
+ defer bcl.Unlock()
+ c, err := readMultipleModelConfigsFromFile(file, opts...)
+ if err != nil {
+ return fmt.Errorf("cannot load config file: %w", err)
+ }
+
+ for _, cc := range c {
+ if valid, err := cc.Validate(); valid {
+ bcl.configs[cc.Name] = *cc
+ } else {
+ xlog.Warn("skipping invalid model config", "name", cc.Name, "error", err)
+ }
+ }
+ return nil
+}
+
+func (bcl *ModelConfigLoader) ReadModelConfig(file string, opts ...ConfigLoaderOption) error {
+ bcl.Lock()
+ defer bcl.Unlock()
+ c, err := readModelConfigFromFile(file, opts...)
+ if err != nil {
+ return fmt.Errorf("ReadModelConfig cannot read config file %q: %w", file, err)
+ }
+
+ if valid, err := c.Validate(); valid {
+ bcl.configs[c.Name] = *c
+ } else {
+ if err != nil {
+ return fmt.Errorf("config is not valid: %w", err)
+ }
+ return fmt.Errorf("config is not valid")
+ }
+
+ return nil
+}
+
+func (bcl *ModelConfigLoader) GetModelConfig(m string) (ModelConfig, bool) {
+ bcl.Lock()
+ defer bcl.Unlock()
+ v, exists := bcl.configs[m]
+ return v, exists
+}
+
+func (bcl *ModelConfigLoader) GetAllModelsConfigs() []ModelConfig {
+ bcl.Lock()
+ defer bcl.Unlock()
+ var res []ModelConfig
+ for _, v := range bcl.configs {
+ res = append(res, v)
+ }
+
+ sort.SliceStable(res, func(i, j int) bool {
+ return res[i].Name < res[j].Name
+ })
+
+ return res
+}
+
+func (bcl *ModelConfigLoader) GetModelConfigsByFilter(filter ModelConfigFilterFn) []ModelConfig {
+ bcl.Lock()
+ defer bcl.Unlock()
+ var res []ModelConfig
+
+ if filter == nil {
+ filter = NoFilterFn
+ }
+
+ for n, v := range bcl.configs {
+ if filter(n, &v) {
+ res = append(res, v)
+ }
+ }
+
+ // TODO: I don't think this one needs to Sort on name... but we'll see what breaks.
+
+ return res
+}
+
+func (bcl *ModelConfigLoader) RemoveModelConfig(m string) {
+ bcl.Lock()
+ defer bcl.Unlock()
+ delete(bcl.configs, m)
+}
+
+// Preload prepare models if they are not local but url or huggingface repositories
+func (bcl *ModelConfigLoader) Preload(modelPath string) error {
+ bcl.Lock()
+ defer bcl.Unlock()
+
+ status := func(fileName, current, total string, percent float64) {
+ utils.DisplayDownloadFunction(fileName, current, total, percent)
+ }
+
+ xlog.Info("Preloading models", "path", modelPath)
+
+ renderMode := "dark"
+ if os.Getenv("COLOR") != "" {
+ renderMode = os.Getenv("COLOR")
+ }
+
+ glamText := func(t string) {
+ out, err := glamour.Render(t, renderMode)
+ if err == nil && os.Getenv("NO_COLOR") == "" {
+ fmt.Println(out)
+ } else {
+ fmt.Println(t)
+ }
+ }
+
+ for i, config := range bcl.configs {
+
+ // Download files and verify their SHA
+ for i, file := range config.DownloadFiles {
+ xlog.Debug("Checking file exists and matches SHA", "filename", file.Filename)
+
+ if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
+ return err
+ }
+ // Create file path
+ filePath := filepath.Join(modelPath, file.Filename)
+
+ if err := file.URI.DownloadFile(filePath, file.SHA256, i, len(config.DownloadFiles), status); err != nil {
+ return err
+ }
+ }
+
+ // If the model is an URL, expand it, and download the file
+ if config.IsModelURL() {
+ modelFileName := config.ModelFileName()
+ uri := downloader.URI(config.Model)
+ if uri.ResolveURL() != config.Model {
+ // check if file exists
+ if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
+ err := uri.DownloadFile(filepath.Join(modelPath, modelFileName), "", 0, 0, status)
+ if err != nil {
+ return err
+ }
+ }
+
+ cc := bcl.configs[i]
+ c := &cc
+ c.PredictionOptions.Model = modelFileName
+ bcl.configs[i] = *c
+ }
+ }
+
+ if config.IsMMProjURL() {
+ modelFileName := config.MMProjFileName()
+ uri := downloader.URI(config.MMProj)
+ // check if file exists
+ if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
+ err := uri.DownloadFile(filepath.Join(modelPath, modelFileName), "", 0, 0, status)
+ if err != nil {
+ return err
+ }
+ }
+
+ cc := bcl.configs[i]
+ c := &cc
+ c.MMProj = modelFileName
+ bcl.configs[i] = *c
+ }
+
+ if bcl.configs[i].Name != "" {
+ glamText(fmt.Sprintf("**Model name**: _%s_", bcl.configs[i].Name))
+ }
+ if bcl.configs[i].Description != "" {
+ //glamText("**Description**")
+ glamText(bcl.configs[i].Description)
+ }
+ if bcl.configs[i].Usage != "" {
+ //glamText("**Usage**")
+ glamText(bcl.configs[i].Usage)
+ }
+ }
+ return nil
+}
+
+// LoadModelConfigsFromPath reads all the configurations of the models from a path
+// (non-recursive)
+func (bcl *ModelConfigLoader) LoadModelConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
+ bcl.Lock()
+ defer bcl.Unlock()
+
+ entries, err := os.ReadDir(path)
+ if err != nil {
+ return fmt.Errorf("LoadModelConfigsFromPath cannot read directory '%s': %w", path, err)
+ }
+ files := make([]fs.FileInfo, 0, len(entries))
+ for _, entry := range entries {
+ info, err := entry.Info()
+ if err != nil {
+ return err
+ }
+ files = append(files, info)
+ }
+ for _, file := range files {
+ // Skip templates, YAML and .keep files
+ if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") ||
+ strings.HasPrefix(file.Name(), ".") {
+ continue
+ }
+ c, err := readModelConfigFromFile(filepath.Join(path, file.Name()), opts...)
+ if err != nil {
+ xlog.Error("LoadModelConfigsFromPath cannot read config file", "error", err, "File Name", file.Name())
+ continue
+ }
+ if valid, validationErr := c.Validate(); valid {
+ bcl.configs[c.Name] = *c
+ } else {
+ xlog.Error("config is not valid", "error", validationErr, "Name", c.Name)
+ }
+ }
+
+ return nil
+}
diff --git a/core/config/model_config_test.go b/core/config/model_config_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..a086d95f6bd39b7337752c5d5e6e6624dbe40521
--- /dev/null
+++ b/core/config/model_config_test.go
@@ -0,0 +1,228 @@
+package config
+
+import (
+ "io"
+ "net/http"
+ "os"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("Test cases for config related functions", func() {
+ Context("Test Read configuration functions", func() {
+ It("Test Validate", func() {
+ tmp, err := os.CreateTemp("", "config.yaml")
+ Expect(err).To(BeNil())
+ defer os.Remove(tmp.Name())
+ _, err = tmp.WriteString(
+ `backend: "../foo-bar"
+name: "foo"
+parameters:
+ model: "foo-bar"
+known_usecases:
+- chat
+- COMPLETION
+`)
+ Expect(err).ToNot(HaveOccurred())
+ config, err := readModelConfigFromFile(tmp.Name())
+ Expect(err).To(BeNil())
+ Expect(config).ToNot(BeNil())
+ valid, err := config.Validate()
+ Expect(err).To(HaveOccurred())
+ Expect(valid).To(BeFalse())
+ Expect(config.KnownUsecases).ToNot(BeNil())
+ })
+ It("Test Validate", func() {
+ tmp, err := os.CreateTemp("", "config.yaml")
+ Expect(err).To(BeNil())
+ defer os.Remove(tmp.Name())
+ _, err = tmp.WriteString(
+ `name: bar-baz
+backend: "foo-bar"
+parameters:
+ model: "foo-bar"`)
+ Expect(err).ToNot(HaveOccurred())
+ config, err := readModelConfigFromFile(tmp.Name())
+ Expect(err).To(BeNil())
+ Expect(config).ToNot(BeNil())
+ // two configs in config.yaml
+ Expect(config.Name).To(Equal("bar-baz"))
+ valid, err := config.Validate()
+ Expect(err).To(BeNil())
+ Expect(valid).To(BeTrue())
+
+ // download https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml
+ httpClient := http.Client{}
+ resp, err := httpClient.Get("https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml")
+ Expect(err).To(BeNil())
+ defer resp.Body.Close()
+ tmp, err = os.CreateTemp("", "config.yaml")
+ Expect(err).To(BeNil())
+ defer os.Remove(tmp.Name())
+ _, err = io.Copy(tmp, resp.Body)
+ Expect(err).To(BeNil())
+ config, err = readModelConfigFromFile(tmp.Name())
+ Expect(err).To(BeNil())
+ Expect(config).ToNot(BeNil())
+ // two configs in config.yaml
+ Expect(config.Name).To(Equal("hermes-2-pro-mistral"))
+ valid, err = config.Validate()
+ Expect(err).To(BeNil())
+ Expect(valid).To(BeTrue())
+ })
+ })
+ It("Properly handles backend usecase matching", func() {
+
+ a := ModelConfig{
+ Name: "a",
+ }
+ Expect(a.HasUsecases(FLAG_ANY)).To(BeTrue()) // FLAG_ANY just means the config _exists_ essentially.
+
+ b := ModelConfig{
+ Name: "b",
+ Backend: "stablediffusion",
+ }
+ Expect(b.HasUsecases(FLAG_ANY)).To(BeTrue())
+ Expect(b.HasUsecases(FLAG_IMAGE)).To(BeTrue())
+ Expect(b.HasUsecases(FLAG_CHAT)).To(BeFalse())
+
+ c := ModelConfig{
+ Name: "c",
+ Backend: "llama-cpp",
+ TemplateConfig: TemplateConfig{
+ Chat: "chat",
+ },
+ }
+ Expect(c.HasUsecases(FLAG_ANY)).To(BeTrue())
+ Expect(c.HasUsecases(FLAG_IMAGE)).To(BeFalse())
+ Expect(c.HasUsecases(FLAG_COMPLETION)).To(BeFalse())
+ Expect(c.HasUsecases(FLAG_CHAT)).To(BeTrue())
+
+ d := ModelConfig{
+ Name: "d",
+ Backend: "llama-cpp",
+ TemplateConfig: TemplateConfig{
+ Chat: "chat",
+ Completion: "completion",
+ },
+ }
+ Expect(d.HasUsecases(FLAG_ANY)).To(BeTrue())
+ Expect(d.HasUsecases(FLAG_IMAGE)).To(BeFalse())
+ Expect(d.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
+ Expect(d.HasUsecases(FLAG_CHAT)).To(BeTrue())
+
+ trueValue := true
+ e := ModelConfig{
+ Name: "e",
+ Backend: "llama-cpp",
+ TemplateConfig: TemplateConfig{
+ Completion: "completion",
+ },
+ Embeddings: &trueValue,
+ }
+
+ Expect(e.HasUsecases(FLAG_ANY)).To(BeTrue())
+ Expect(e.HasUsecases(FLAG_IMAGE)).To(BeFalse())
+ Expect(e.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
+ Expect(e.HasUsecases(FLAG_CHAT)).To(BeFalse())
+ Expect(e.HasUsecases(FLAG_EMBEDDINGS)).To(BeTrue())
+
+ f := ModelConfig{
+ Name: "f",
+ Backend: "piper",
+ }
+ Expect(f.HasUsecases(FLAG_ANY)).To(BeTrue())
+ Expect(f.HasUsecases(FLAG_TTS)).To(BeTrue())
+ Expect(f.HasUsecases(FLAG_CHAT)).To(BeFalse())
+
+ g := ModelConfig{
+ Name: "g",
+ Backend: "whisper",
+ }
+ Expect(g.HasUsecases(FLAG_ANY)).To(BeTrue())
+ Expect(g.HasUsecases(FLAG_TRANSCRIPT)).To(BeTrue())
+ Expect(g.HasUsecases(FLAG_TTS)).To(BeFalse())
+
+ h := ModelConfig{
+ Name: "h",
+ Backend: "transformers-musicgen",
+ }
+ Expect(h.HasUsecases(FLAG_ANY)).To(BeTrue())
+ Expect(h.HasUsecases(FLAG_TRANSCRIPT)).To(BeFalse())
+ Expect(h.HasUsecases(FLAG_TTS)).To(BeTrue())
+ Expect(h.HasUsecases(FLAG_SOUND_GENERATION)).To(BeTrue())
+
+ knownUsecases := FLAG_CHAT | FLAG_COMPLETION
+ i := ModelConfig{
+ Name: "i",
+ Backend: "whisper",
+ // Earlier test checks parsing, this just needs to set final values
+ KnownUsecases: &knownUsecases,
+ }
+ Expect(i.HasUsecases(FLAG_ANY)).To(BeTrue())
+ Expect(i.HasUsecases(FLAG_TRANSCRIPT)).To(BeTrue())
+ Expect(i.HasUsecases(FLAG_TTS)).To(BeFalse())
+ Expect(i.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
+ Expect(i.HasUsecases(FLAG_CHAT)).To(BeTrue())
+ })
+ It("Test Validate with invalid MCP config", func() {
+ tmp, err := os.CreateTemp("", "config.yaml")
+ Expect(err).To(BeNil())
+ defer os.Remove(tmp.Name())
+ _, err = tmp.WriteString(
+ `name: test-mcp
+backend: "llama-cpp"
+mcp:
+ stdio: |
+ {
+ "mcpServers": {
+ "ddg": {
+ "command": "/docker/docker",
+ "args": ["run", "-i"]
+ }
+ "weather": {
+ "command": "/docker/docker",
+ "args": ["run", "-i"]
+ }
+ }
+ }`)
+ Expect(err).ToNot(HaveOccurred())
+ config, err := readModelConfigFromFile(tmp.Name())
+ Expect(err).To(BeNil())
+ Expect(config).ToNot(BeNil())
+ valid, err := config.Validate()
+ Expect(err).To(HaveOccurred())
+ Expect(valid).To(BeFalse())
+ Expect(err.Error()).To(ContainSubstring("invalid MCP configuration"))
+ })
+ It("Test Validate with valid MCP config", func() {
+ tmp, err := os.CreateTemp("", "config.yaml")
+ Expect(err).To(BeNil())
+ defer os.Remove(tmp.Name())
+ _, err = tmp.WriteString(
+ `name: test-mcp-valid
+backend: "llama-cpp"
+mcp:
+ stdio: |
+ {
+ "mcpServers": {
+ "ddg": {
+ "command": "/docker/docker",
+ "args": ["run", "-i"]
+ },
+ "weather": {
+ "command": "/docker/docker",
+ "args": ["run", "-i"]
+ }
+ }
+ }`)
+ Expect(err).ToNot(HaveOccurred())
+ config, err := readModelConfigFromFile(tmp.Name())
+ Expect(err).To(BeNil())
+ Expect(config).ToNot(BeNil())
+ valid, err := config.Validate()
+ Expect(err).To(BeNil())
+ Expect(valid).To(BeTrue())
+ })
+})
diff --git a/core/config/model_test.go b/core/config/model_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..f127f8f568304874c99ea68e77a638bc24217431
--- /dev/null
+++ b/core/config/model_test.go
@@ -0,0 +1,113 @@
+package config
+
+import (
+ "os"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("Test cases for config related functions", func() {
+
+ var (
+ configFile string
+ )
+
+ Context("Test Read configuration functions", func() {
+ configFile = os.Getenv("CONFIG_FILE")
+ It("Test readConfigFile", func() {
+ config, err := readMultipleModelConfigsFromFile(configFile)
+ Expect(err).To(BeNil())
+ Expect(config).ToNot(BeNil())
+ // two configs in config.yaml
+ Expect(config[0].Name).To(Equal("list1"))
+ Expect(config[1].Name).To(Equal("list2"))
+ })
+
+ It("Test LoadConfigs", func() {
+
+ bcl := NewModelConfigLoader(os.Getenv("MODELS_PATH"))
+ err := bcl.LoadModelConfigsFromPath(os.Getenv("MODELS_PATH"))
+
+ Expect(err).To(BeNil())
+ configs := bcl.GetAllModelsConfigs()
+ loadedModelNames := []string{}
+ for _, v := range configs {
+ loadedModelNames = append(loadedModelNames, v.Name)
+ }
+ Expect(configs).ToNot(BeNil())
+
+ Expect(loadedModelNames).To(ContainElements("code-search-ada-code-001"))
+
+ // config should includes text-embedding-ada-002 models's api.config
+ Expect(loadedModelNames).To(ContainElements("text-embedding-ada-002"))
+
+ // config should includes rwkv_test models's api.config
+ Expect(loadedModelNames).To(ContainElements("rwkv_test"))
+
+ // config should includes whisper-1 models's api.config
+ Expect(loadedModelNames).To(ContainElements("whisper-1"))
+ })
+
+ It("Test new loadconfig", func() {
+
+ bcl := NewModelConfigLoader(os.Getenv("MODELS_PATH"))
+ err := bcl.LoadModelConfigsFromPath(os.Getenv("MODELS_PATH"))
+ Expect(err).To(BeNil())
+ configs := bcl.GetAllModelsConfigs()
+ loadedModelNames := []string{}
+ for _, v := range configs {
+ loadedModelNames = append(loadedModelNames, v.Name)
+ }
+ Expect(configs).ToNot(BeNil())
+ totalModels := len(loadedModelNames)
+
+ Expect(loadedModelNames).To(ContainElements("code-search-ada-code-001"))
+
+ // config should includes text-embedding-ada-002 models's api.config
+ Expect(loadedModelNames).To(ContainElements("text-embedding-ada-002"))
+
+ // config should includes rwkv_test models's api.config
+ Expect(loadedModelNames).To(ContainElements("rwkv_test"))
+
+ // config should includes whisper-1 models's api.config
+ Expect(loadedModelNames).To(ContainElements("whisper-1"))
+
+ // create a temp directory and store a temporary model
+ tmpdir, err := os.MkdirTemp("", "test")
+ Expect(err).ToNot(HaveOccurred())
+ defer os.RemoveAll(tmpdir)
+
+ // create a temporary model
+ model := `name: "test-model"
+description: "test model"
+options:
+- foo
+- bar
+- baz
+`
+ modelFile := tmpdir + "/test-model.yaml"
+ err = os.WriteFile(modelFile, []byte(model), 0644)
+ Expect(err).ToNot(HaveOccurred())
+
+ err = bcl.LoadModelConfigsFromPath(tmpdir)
+ Expect(err).ToNot(HaveOccurred())
+
+ configs = bcl.GetAllModelsConfigs()
+ Expect(len(configs)).ToNot(Equal(totalModels))
+
+ loadedModelNames = []string{}
+ var testModel ModelConfig
+ for _, v := range configs {
+ loadedModelNames = append(loadedModelNames, v.Name)
+ if v.Name == "test-model" {
+ testModel = v
+ }
+ }
+ Expect(loadedModelNames).To(ContainElements("test-model"))
+ Expect(testModel.Description).To(Equal("test model"))
+ Expect(testModel.Options).To(ContainElements("foo", "bar", "baz"))
+
+ })
+ })
+})
diff --git a/core/config/runtime_settings.go b/core/config/runtime_settings.go
new file mode 100644
index 0000000000000000000000000000000000000000..1a7f6db8175c84af66408579da6aaa93eaf8aff4
--- /dev/null
+++ b/core/config/runtime_settings.go
@@ -0,0 +1,63 @@
+package config
+
+// RuntimeSettings represents runtime configuration that can be changed dynamically.
+// This struct is used for:
+// - API responses (GET /api/settings)
+// - API requests (POST /api/settings)
+// - Persisting to runtime_settings.json
+// - Loading from runtime_settings.json on startup
+//
+// All fields are pointers to distinguish between "not set" and "set to zero/false value".
+type RuntimeSettings struct {
+ // Watchdog settings
+ WatchdogEnabled *bool `json:"watchdog_enabled,omitempty"`
+ WatchdogIdleEnabled *bool `json:"watchdog_idle_enabled,omitempty"`
+ WatchdogBusyEnabled *bool `json:"watchdog_busy_enabled,omitempty"`
+ WatchdogIdleTimeout *string `json:"watchdog_idle_timeout,omitempty"`
+ WatchdogBusyTimeout *string `json:"watchdog_busy_timeout,omitempty"`
+ WatchdogInterval *string `json:"watchdog_interval,omitempty"` // Interval between watchdog checks (e.g., 2s, 30s)
+
+ // Backend management
+ SingleBackend *bool `json:"single_backend,omitempty"` // Deprecated: use MaxActiveBackends = 1 instead
+ MaxActiveBackends *int `json:"max_active_backends,omitempty"` // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
+ ParallelBackendRequests *bool `json:"parallel_backend_requests,omitempty"`
+
+ // Memory Reclaimer settings (works with GPU if available, otherwise RAM)
+ MemoryReclaimerEnabled *bool `json:"memory_reclaimer_enabled,omitempty"` // Enable memory threshold monitoring
+ MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%)
+
+ // Eviction settings
+ ForceEvictionWhenBusy *bool `json:"force_eviction_when_busy,omitempty"` // Force eviction even when models have active API calls (default: false for safety)
+ LRUEvictionMaxRetries *int `json:"lru_eviction_max_retries,omitempty"` // Maximum number of retries when waiting for busy models to become idle (default: 30)
+ LRUEvictionRetryInterval *string `json:"lru_eviction_retry_interval,omitempty"` // Interval between retries when waiting for busy models (e.g., 1s, 2s) (default: 1s)
+
+ // Performance settings
+ Threads *int `json:"threads,omitempty"`
+ ContextSize *int `json:"context_size,omitempty"`
+ F16 *bool `json:"f16,omitempty"`
+ Debug *bool `json:"debug,omitempty"`
+ EnableTracing *bool `json:"enable_tracing,omitempty"`
+ TracingMaxItems *int `json:"tracing_max_items,omitempty"`
+
+ // Security/CORS settings
+ CORS *bool `json:"cors,omitempty"`
+ CSRF *bool `json:"csrf,omitempty"`
+ CORSAllowOrigins *string `json:"cors_allow_origins,omitempty"`
+
+ // P2P settings
+ P2PToken *string `json:"p2p_token,omitempty"`
+ P2PNetworkID *string `json:"p2p_network_id,omitempty"`
+ Federated *bool `json:"federated,omitempty"`
+
+ // Gallery settings
+ Galleries *[]Gallery `json:"galleries,omitempty"`
+ BackendGalleries *[]Gallery `json:"backend_galleries,omitempty"`
+ AutoloadGalleries *bool `json:"autoload_galleries,omitempty"`
+ AutoloadBackendGalleries *bool `json:"autoload_backend_galleries,omitempty"`
+
+ // API keys - No omitempty as we need to save empty arrays to clear keys
+ ApiKeys *[]string `json:"api_keys"`
+
+ // Agent settings
+ AgentJobRetentionDays *int `json:"agent_job_retention_days,omitempty"`
+}
diff --git a/core/dependencies_manager/manager.go b/core/dependencies_manager/manager.go
new file mode 100644
index 0000000000000000000000000000000000000000..8434f721071c20fe29042dbe60312c3e9c2ea09d
--- /dev/null
+++ b/core/dependencies_manager/manager.go
@@ -0,0 +1,47 @@
+package main
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+
+ "github.com/mudler/LocalAI/pkg/downloader"
+ "github.com/mudler/LocalAI/pkg/utils"
+ "gopkg.in/yaml.v3"
+)
+
+type Asset struct {
+ FileName string `yaml:"filename"`
+ URL string `yaml:"url"`
+ SHA string `yaml:"sha"`
+}
+
+func main() {
+
+ // read the YAML file which contains a list of assets
+ // and download them in the asset path
+ assets := []Asset{}
+
+ assetFile := os.Args[1]
+ destPath := os.Args[2]
+
+ // read the YAML file
+ f, err := os.ReadFile(assetFile)
+ if err != nil {
+ panic(err)
+ }
+ // unmarshal the YAML data into a struct
+ if err := yaml.Unmarshal(f, &assets); err != nil {
+ panic(err)
+ }
+
+ // download the assets
+ for _, asset := range assets {
+ uri := downloader.URI(asset.URL)
+ if err := uri.DownloadFile(filepath.Join(destPath, asset.FileName), asset.SHA, 1, 1, utils.DisplayDownloadFunction); err != nil {
+ panic(err)
+ }
+ }
+
+ fmt.Println("Finished downloading assets")
+}
diff --git a/core/explorer/database.go b/core/explorer/database.go
new file mode 100644
index 0000000000000000000000000000000000000000..e24de0aad26b174f14659703a4b697f8c0d20284
--- /dev/null
+++ b/core/explorer/database.go
@@ -0,0 +1,125 @@
+package explorer
+
+// A simple JSON database for storing and retrieving p2p network tokens and a name and description.
+
+import (
+ "encoding/json"
+ "os"
+ "sort"
+ "sync"
+
+ "github.com/gofrs/flock"
+)
+
+// Database is a simple JSON database for storing and retrieving p2p network tokens and a name and description.
+type Database struct {
+ path string
+ data map[string]TokenData
+ flock *flock.Flock
+ sync.Mutex
+}
+
+// TokenData is a p2p network token with a name and description.
+type TokenData struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Clusters []ClusterData
+ Failures int
+}
+
+type ClusterData struct {
+ Workers []string
+ Type string
+ NetworkID string
+}
+
+// NewDatabase creates a new Database with the given path.
+func NewDatabase(path string) (*Database, error) {
+ fileLock := flock.New(path + ".lock")
+ db := &Database{
+ data: make(map[string]TokenData),
+ path: path,
+ flock: fileLock,
+ }
+ return db, db.load()
+}
+
+// Get retrieves a Token from the Database by its token.
+func (db *Database) Get(token string) (TokenData, bool) {
+ db.flock.Lock() // we are making sure that the file is not being written to
+ defer db.flock.Unlock()
+ db.Lock() // we are making sure that is safe if called by another instance in the same process
+ defer db.Unlock()
+ db.load()
+ t, ok := db.data[token]
+ return t, ok
+}
+
+// Set stores a Token in the Database by its token.
+func (db *Database) Set(token string, t TokenData) error {
+ db.flock.Lock()
+ defer db.flock.Unlock()
+ db.Lock()
+ defer db.Unlock()
+ db.load()
+ db.data[token] = t
+
+ return db.save()
+}
+
+// Delete removes a Token from the Database by its token.
+func (db *Database) Delete(token string) error {
+ db.flock.Lock()
+ defer db.flock.Unlock()
+ db.Lock()
+ defer db.Unlock()
+ db.load()
+ delete(db.data, token)
+ return db.save()
+}
+
+func (db *Database) TokenList() []string {
+ db.flock.Lock()
+ defer db.flock.Unlock()
+ db.Lock()
+ defer db.Unlock()
+ db.load()
+ tokens := []string{}
+ for k := range db.data {
+ tokens = append(tokens, k)
+ }
+
+ sort.Slice(tokens, func(i, j int) bool {
+ // sort by token
+ return tokens[i] < tokens[j]
+ })
+
+ return tokens
+}
+
+// load reads the Database from disk.
+func (db *Database) load() error {
+ if _, err := os.Stat(db.path); os.IsNotExist(err) {
+ return nil
+ }
+
+ // Read the file from disk
+ // Unmarshal the JSON into db.data
+ f, err := os.ReadFile(db.path)
+ if err != nil {
+ return err
+ }
+ return json.Unmarshal(f, &db.data)
+}
+
+// Save writes the Database to disk.
+func (db *Database) save() error {
+ // Marshal db.data into JSON
+ // Write the JSON to the file
+ f, err := os.Create(db.path)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ return json.NewEncoder(f).Encode(db.data)
+}
diff --git a/core/explorer/database_test.go b/core/explorer/database_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..7f2cbd268a36b8071aab173f1f5ec5606fca9773
--- /dev/null
+++ b/core/explorer/database_test.go
@@ -0,0 +1,92 @@
+package explorer_test
+
+import (
+ "os"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+
+ "github.com/mudler/LocalAI/core/explorer"
+)
+
+var _ = Describe("Database", func() {
+ var (
+ dbPath string
+ db *explorer.Database
+ err error
+ )
+
+ BeforeEach(func() {
+ // Create a temporary file path for the database
+ dbPath = "test_db.json"
+ db, err = explorer.NewDatabase(dbPath)
+ Expect(err).To(BeNil())
+ })
+
+ AfterEach(func() {
+ // Clean up the temporary database file
+ os.Remove(dbPath)
+ })
+
+ Context("when managing tokens", func() {
+ It("should add and retrieve a token", func() {
+ token := "token123"
+ t := explorer.TokenData{Name: "TokenName", Description: "A test token"}
+
+ err = db.Set(token, t)
+ Expect(err).To(BeNil())
+
+ retrievedToken, exists := db.Get(token)
+ Expect(exists).To(BeTrue())
+ Expect(retrievedToken).To(Equal(t))
+ })
+
+ It("should delete a token", func() {
+ token := "token123"
+ t := explorer.TokenData{Name: "TokenName", Description: "A test token"}
+
+ err = db.Set(token, t)
+ Expect(err).To(BeNil())
+
+ err = db.Delete(token)
+ Expect(err).To(BeNil())
+
+ _, exists := db.Get(token)
+ Expect(exists).To(BeFalse())
+ })
+
+ It("should persist data to disk", func() {
+ token := "token123"
+ t := explorer.TokenData{Name: "TokenName", Description: "A test token"}
+
+ err = db.Set(token, t)
+ Expect(err).To(BeNil())
+
+ // Recreate the database object to simulate reloading from disk
+ db, err = explorer.NewDatabase(dbPath)
+ Expect(err).To(BeNil())
+
+ retrievedToken, exists := db.Get(token)
+ Expect(exists).To(BeTrue())
+ Expect(retrievedToken).To(Equal(t))
+
+ // Check the token list
+ tokenList := db.TokenList()
+ Expect(tokenList).To(ContainElement(token))
+ })
+ })
+
+ Context("when loading an empty or non-existent file", func() {
+ It("should start with an empty database", func() {
+ dbPath = "empty_db.json"
+ db, err = explorer.NewDatabase(dbPath)
+ Expect(err).To(BeNil())
+
+ _, exists := db.Get("nonexistent")
+ Expect(exists).To(BeFalse())
+
+ // Clean up
+ os.Remove(dbPath)
+ })
+ })
+})
diff --git a/core/explorer/discovery.go b/core/explorer/discovery.go
new file mode 100644
index 0000000000000000000000000000000000000000..989e784d32b616516aed4c9cc4f694e27001c610
--- /dev/null
+++ b/core/explorer/discovery.go
@@ -0,0 +1,214 @@
+package explorer
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/mudler/xlog"
+
+ "github.com/mudler/LocalAI/core/p2p"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/edgevpn/pkg/blockchain"
+)
+
+type DiscoveryServer struct {
+ sync.Mutex
+ database *Database
+ connectionTime time.Duration
+ errorThreshold int
+}
+
+// NewDiscoveryServer creates a new DiscoveryServer with the given Database.
+// it keeps the db state in sync with the network state
+func NewDiscoveryServer(db *Database, dur time.Duration, failureThreshold int) *DiscoveryServer {
+ if dur == 0 {
+ dur = 50 * time.Second
+ }
+ if failureThreshold == 0 {
+ failureThreshold = 3
+ }
+ return &DiscoveryServer{
+ database: db,
+ connectionTime: dur,
+ errorThreshold: failureThreshold,
+ }
+}
+
+type Network struct {
+ Clusters []ClusterData
+}
+
+func (s *DiscoveryServer) runBackground() {
+ if len(s.database.TokenList()) == 0 {
+ time.Sleep(5 * time.Second) // avoid busy loop
+ return
+ }
+
+ for _, token := range s.database.TokenList() {
+ c, cancel := context.WithTimeout(context.Background(), s.connectionTime)
+ defer cancel()
+
+ // Connect to the network
+ // Get the number of nodes
+ // save it in the current state (mutex)
+ // do not do in parallel
+ n, err := p2p.NewNode(token)
+ if err != nil {
+ xlog.Error("Failed to create node", "error", err)
+ s.failedToken(token)
+ continue
+ }
+
+ err = n.Start(c)
+ if err != nil {
+ xlog.Error("Failed to start node", "error", err)
+ s.failedToken(token)
+ continue
+ }
+
+ ledger, err := n.Ledger()
+ if err != nil {
+ xlog.Error("Failed to start ledger", "error", err)
+ s.failedToken(token)
+ continue
+ }
+
+ networkData := make(chan ClusterData)
+
+ // get the network data - it takes the whole timeout
+ // as we might not be connected to the network yet,
+ // and few attempts would have to be made before bailing out
+ go s.retrieveNetworkData(c, ledger, networkData)
+
+ hasWorkers := false
+ ledgerK := []ClusterData{}
+ for key := range networkData {
+ ledgerK = append(ledgerK, key)
+ if len(key.Workers) > 0 {
+ hasWorkers = true
+ }
+ }
+
+ xlog.Debug("Network clusters", "network", token, "count", len(ledgerK))
+ if len(ledgerK) != 0 {
+ for _, k := range ledgerK {
+ xlog.Debug("Clusterdata", "network", token, "cluster", k)
+ }
+ }
+
+ if hasWorkers {
+ s.Lock()
+ data, _ := s.database.Get(token)
+ (&data).Clusters = ledgerK
+ (&data).Failures = 0
+ s.database.Set(token, data)
+ s.Unlock()
+ } else {
+ s.failedToken(token)
+ }
+ }
+
+ s.deleteFailedConnections()
+}
+
+func (s *DiscoveryServer) failedToken(token string) {
+ s.Lock()
+ defer s.Unlock()
+ data, _ := s.database.Get(token)
+ (&data).Failures++
+ s.database.Set(token, data)
+}
+
+func (s *DiscoveryServer) deleteFailedConnections() {
+ s.Lock()
+ defer s.Unlock()
+ for _, t := range s.database.TokenList() {
+ data, _ := s.database.Get(t)
+ if data.Failures > s.errorThreshold {
+ xlog.Info("Token has been removed from the database", "token", t)
+ s.database.Delete(t)
+ }
+ }
+}
+
+func (s *DiscoveryServer) retrieveNetworkData(c context.Context, ledger *blockchain.Ledger, networkData chan ClusterData) {
+ clusters := map[string]ClusterData{}
+
+ defer func() {
+ for _, n := range clusters {
+ networkData <- n
+ }
+ close(networkData)
+ }()
+
+ for {
+ select {
+ case <-c.Done():
+ return
+ default:
+ time.Sleep(5 * time.Second)
+
+ data := ledger.LastBlock().Storage
+ LEDGER:
+ for d := range data {
+ toScanForWorkers := false
+ cd := ClusterData{}
+ isWorkerCluster := d == p2p.WorkerID || (strings.Contains(d, "_") && strings.Contains(d, p2p.WorkerID))
+ isFederatedCluster := d == p2p.FederatedID || (strings.Contains(d, "_") && strings.Contains(d, p2p.FederatedID))
+ switch {
+ case isWorkerCluster:
+ toScanForWorkers = true
+ cd.Type = "worker"
+ case isFederatedCluster:
+ toScanForWorkers = true
+ cd.Type = "federated"
+ }
+
+ if strings.Contains(d, "_") {
+ cd.NetworkID = strings.Split(d, "_")[0]
+ }
+
+ if !toScanForWorkers {
+ continue LEDGER
+ }
+
+ atLeastOneWorker := false
+ DATA:
+ for _, v := range data[d] {
+ nd := &schema.NodeData{}
+ if err := v.Unmarshal(nd); err != nil {
+ continue DATA
+ }
+
+ if nd.IsOnline() {
+ atLeastOneWorker = true
+ (&cd).Workers = append(cd.Workers, nd.ID)
+ }
+ }
+
+ if atLeastOneWorker {
+ clusters[d] = cd
+ }
+ }
+ }
+ }
+}
+
+// Start the discovery server. This is meant to be run in to a goroutine.
+func (s *DiscoveryServer) Start(ctx context.Context, keepRunning bool) error {
+ for {
+ select {
+ case <-ctx.Done():
+ return fmt.Errorf("context cancelled")
+ default:
+ // Collect data
+ s.runBackground()
+ if !keepRunning {
+ return nil
+ }
+ }
+ }
+}
diff --git a/core/explorer/explorer_suite_test.go b/core/explorer/explorer_suite_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..fc718d5f8dfaa2281a88bc336ab885603d6bf2cd
--- /dev/null
+++ b/core/explorer/explorer_suite_test.go
@@ -0,0 +1,13 @@
+package explorer_test
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestExplorer(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "Explorer test suite")
+}
diff --git a/core/gallery/backend_types.go b/core/gallery/backend_types.go
new file mode 100644
index 0000000000000000000000000000000000000000..0fb6e7f2461289ebc6a5755bae3e10a423fcb304
--- /dev/null
+++ b/core/gallery/backend_types.go
@@ -0,0 +1,107 @@
+package gallery
+
+import (
+ "fmt"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/pkg/system"
+ "github.com/mudler/xlog"
+)
+
+// BackendMetadata represents the metadata stored in a JSON file for each installed backend
+type BackendMetadata struct {
+ // Alias is an optional alternative name for the backend
+ Alias string `json:"alias,omitempty"`
+ // MetaBackendFor points to the concrete backend if this is a meta backend
+ MetaBackendFor string `json:"meta_backend_for,omitempty"`
+ // Name is the original name from the gallery
+ Name string `json:"name,omitempty"`
+ // GalleryURL is the URL of the gallery this backend came from
+ GalleryURL string `json:"gallery_url,omitempty"`
+ // InstalledAt is the timestamp when the backend was installed
+ InstalledAt string `json:"installed_at,omitempty"`
+}
+
+type GalleryBackend struct {
+ Metadata `json:",inline" yaml:",inline"`
+ Alias string `json:"alias,omitempty" yaml:"alias,omitempty"`
+ URI string `json:"uri,omitempty" yaml:"uri,omitempty"`
+ Mirrors []string `json:"mirrors,omitempty" yaml:"mirrors,omitempty"`
+ CapabilitiesMap map[string]string `json:"capabilities,omitempty" yaml:"capabilities,omitempty"`
+}
+
+func (backend *GalleryBackend) FindBestBackendFromMeta(systemState *system.SystemState, backends GalleryElements[*GalleryBackend]) *GalleryBackend {
+ if systemState == nil {
+ return nil
+ }
+
+ realBackend := backend.CapabilitiesMap[systemState.Capability(backend.CapabilitiesMap)]
+ if realBackend == "" {
+ xlog.Debug("No backend found for reported capability", "backend", backend.Name, "reportedCapability", systemState.Capability(backend.CapabilitiesMap))
+ return nil
+ }
+
+ xlog.Debug("Found backend for reported capability", "backend", backend.Name, "reportedCapability", systemState.Capability(backend.CapabilitiesMap))
+ return backends.FindByName(realBackend)
+}
+
+func (m *GalleryBackend) GetInstalled() bool {
+ return m.Installed
+}
+
+func (m *GalleryBackend) GetLicense() string {
+ return m.License
+}
+
+type GalleryBackends []*GalleryBackend
+
+func (m *GalleryBackend) SetGallery(gallery config.Gallery) {
+ m.Gallery = gallery
+}
+
+func (m *GalleryBackend) IsMeta() bool {
+ return len(m.CapabilitiesMap) > 0 && m.URI == ""
+}
+
+// IsCompatibleWith checks if the backend is compatible with the current system capability.
+// For meta backends, it checks if any of the capabilities in the map match the system capability.
+// For concrete backends, it delegates to SystemState.IsBackendCompatible.
+func (m *GalleryBackend) IsCompatibleWith(systemState *system.SystemState) bool {
+ if systemState == nil {
+ return true
+ }
+
+ // Meta backends are compatible if the system capability matches one of the keys
+ if m.IsMeta() {
+ capability := systemState.Capability(m.CapabilitiesMap)
+ _, exists := m.CapabilitiesMap[capability]
+ return exists
+ }
+
+ // For concrete backends, delegate to the system package
+ return systemState.IsBackendCompatible(m.Name, m.URI)
+}
+
+func (m *GalleryBackend) SetInstalled(installed bool) {
+ m.Installed = installed
+}
+
+func (m *GalleryBackend) GetName() string {
+ return m.Name
+}
+
+func (m *GalleryBackend) GetGallery() config.Gallery {
+ return m.Gallery
+}
+
+func (m *GalleryBackend) GetDescription() string {
+ return m.Description
+}
+
+func (m *GalleryBackend) GetTags() []string {
+ return m.Tags
+}
+
+func (m GalleryBackend) ID() string {
+ return fmt.Sprintf("%s@%s", m.Gallery.Name, m.Name)
+}
diff --git a/core/gallery/backends.go b/core/gallery/backends.go
new file mode 100644
index 0000000000000000000000000000000000000000..acef1318d4451d4d13c0d5424ca097779e5c3153
--- /dev/null
+++ b/core/gallery/backends.go
@@ -0,0 +1,450 @@
+// Package gallery provides installation and registration utilities for LocalAI backends,
+// including meta-backend resolution based on system capabilities.
+package gallery
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/pkg/downloader"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/system"
+ "github.com/mudler/xlog"
+ cp "github.com/otiai10/copy"
+)
+
+const (
+ metadataFile = "metadata.json"
+ runFile = "run.sh"
+)
+
+// backendCandidate represents an installed concrete backend option for a given alias
+type backendCandidate struct {
+ name string
+ runFile string
+}
+
+// readBackendMetadata reads the metadata JSON file for a backend
+func readBackendMetadata(backendPath string) (*BackendMetadata, error) {
+ metadataPath := filepath.Join(backendPath, metadataFile)
+
+ // If metadata file doesn't exist, return nil (for backward compatibility)
+ if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
+ return nil, nil
+ }
+
+ data, err := os.ReadFile(metadataPath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read metadata file %q: %v", metadataPath, err)
+ }
+
+ var metadata BackendMetadata
+ if err := json.Unmarshal(data, &metadata); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal metadata file %q: %v", metadataPath, err)
+ }
+
+ return &metadata, nil
+}
+
+// writeBackendMetadata writes the metadata JSON file for a backend
+func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
+ metadataPath := filepath.Join(backendPath, metadataFile)
+
+ data, err := json.MarshalIndent(metadata, "", " ")
+ if err != nil {
+ return fmt.Errorf("failed to marshal metadata: %v", err)
+ }
+
+ if err := os.WriteFile(metadataPath, data, 0644); err != nil {
+ return fmt.Errorf("failed to write metadata file %q: %v", metadataPath, err)
+ }
+
+ return nil
+}
+
+// InstallBackendFromGallery installs a backend from the gallery.
+func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
+ if !force {
+ // check if we already have the backend installed
+ backends, err := ListSystemBackends(systemState)
+ if err != nil {
+ return err
+ }
+ if backends.Exists(name) {
+ return nil
+ }
+ }
+
+ if name == "" {
+ return fmt.Errorf("backend name is empty")
+ }
+
+ xlog.Debug("Installing backend from gallery", "galleries", galleries, "name", name)
+
+ backends, err := AvailableBackends(galleries, systemState)
+ if err != nil {
+ return err
+ }
+
+ backend := FindGalleryElement(backends, name)
+ if backend == nil {
+ return fmt.Errorf("no backend found with name %q", name)
+ }
+
+ if backend.IsMeta() {
+ xlog.Debug("Backend is a meta backend", "systemState", systemState, "name", name)
+
+ // Then, let's try to find the best backend based on the capabilities map
+ bestBackend := backend.FindBestBackendFromMeta(systemState, backends)
+ if bestBackend == nil {
+ return fmt.Errorf("no backend found with capabilities %q", backend.CapabilitiesMap)
+ }
+
+ xlog.Debug("Installing backend from meta backend", "name", name, "bestBackend", bestBackend.Name)
+
+ // Then, let's install the best backend
+ if err := InstallBackend(ctx, systemState, modelLoader, bestBackend, downloadStatus); err != nil {
+ return err
+ }
+
+ // we need now to create a path for the meta backend, with the alias to the installed ones so it can be used to remove it
+ metaBackendPath := filepath.Join(systemState.Backend.BackendsPath, name)
+ if err := os.MkdirAll(metaBackendPath, 0750); err != nil {
+ return fmt.Errorf("failed to create meta backend path %q: %v", metaBackendPath, err)
+ }
+
+ // Create metadata for the meta backend
+ metaMetadata := &BackendMetadata{
+ MetaBackendFor: bestBackend.Name,
+ Name: name,
+ GalleryURL: backend.Gallery.URL,
+ InstalledAt: time.Now().Format(time.RFC3339),
+ }
+
+ if err := writeBackendMetadata(metaBackendPath, metaMetadata); err != nil {
+ return fmt.Errorf("failed to write metadata for meta backend %q: %v", name, err)
+ }
+
+ return nil
+ }
+
+ return InstallBackend(ctx, systemState, modelLoader, backend, downloadStatus)
+}
+
+func InstallBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
+ // Create base path if it doesn't exist
+ err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
+ if err != nil {
+ return fmt.Errorf("failed to create base path: %v", err)
+ }
+
+ if config.IsMeta() {
+ return fmt.Errorf("meta backends cannot be installed directly")
+ }
+
+ name := config.Name
+ backendPath := filepath.Join(systemState.Backend.BackendsPath, name)
+ err = os.MkdirAll(backendPath, 0750)
+ if err != nil {
+ return fmt.Errorf("failed to create base path: %v", err)
+ }
+
+ uri := downloader.URI(config.URI)
+ // Check if it is a directory
+ if uri.LooksLikeDir() {
+ // It is a directory, we just copy it over in the backend folder
+ if err := cp.Copy(config.URI, backendPath); err != nil {
+ return fmt.Errorf("failed copying: %w", err)
+ }
+ } else {
+ xlog.Debug("Downloading backend", "uri", config.URI, "backendPath", backendPath)
+ if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
+ success := false
+ // Try to download from mirrors
+ for _, mirror := range config.Mirrors {
+ // Check for cancellation before trying next mirror
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+ if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
+ success = true
+ xlog.Debug("Downloaded backend", "uri", config.URI, "backendPath", backendPath)
+ break
+ }
+ }
+
+ if !success {
+ xlog.Error("Failed to download backend", "uri", config.URI, "backendPath", backendPath, "error", err)
+ return fmt.Errorf("failed to download backend %q: %v", config.URI, err)
+ }
+ } else {
+ xlog.Debug("Downloaded backend", "uri", config.URI, "backendPath", backendPath)
+ }
+ }
+
+ // sanity check - check if runfile is present
+ runFile := filepath.Join(backendPath, runFile)
+ if _, err := os.Stat(runFile); os.IsNotExist(err) {
+ xlog.Error("Run file not found", "runFile", runFile)
+ return fmt.Errorf("not a valid backend: run file not found %q", runFile)
+ }
+
+ // Create metadata for the backend
+ metadata := &BackendMetadata{
+ Name: name,
+ GalleryURL: config.Gallery.URL,
+ InstalledAt: time.Now().Format(time.RFC3339),
+ }
+
+ if config.Alias != "" {
+ metadata.Alias = config.Alias
+ }
+
+ if err := writeBackendMetadata(backendPath, metadata); err != nil {
+ return fmt.Errorf("failed to write metadata for backend %q: %v", name, err)
+ }
+
+ return RegisterBackends(systemState, modelLoader)
+}
+
+func DeleteBackendFromSystem(systemState *system.SystemState, name string) error {
+ backends, err := ListSystemBackends(systemState)
+ if err != nil {
+ return err
+ }
+
+ backend, ok := backends.Get(name)
+ if !ok {
+ return fmt.Errorf("backend %q not found", name)
+ }
+
+ if backend.IsSystem {
+ return fmt.Errorf("system backend %q cannot be deleted", name)
+ }
+
+ backendDirectory := filepath.Join(systemState.Backend.BackendsPath, name)
+
+ // check if the backend dir exists
+ if _, err := os.Stat(backendDirectory); os.IsNotExist(err) {
+ // if doesn't exist, it might be an alias, so we need to check if we have a matching alias in
+ // all the backends in the basePath
+ backends, err := os.ReadDir(systemState.Backend.BackendsPath)
+ if err != nil {
+ return err
+ }
+ foundBackend := false
+
+ for _, backend := range backends {
+ if backend.IsDir() {
+ metadata, err := readBackendMetadata(filepath.Join(systemState.Backend.BackendsPath, backend.Name()))
+ if err != nil {
+ return err
+ }
+ if metadata != nil && metadata.Alias == name {
+ backendDirectory = filepath.Join(systemState.Backend.BackendsPath, backend.Name())
+ foundBackend = true
+ break
+ }
+ }
+ }
+
+ // If no backend found, return successfully (idempotent behavior)
+ if !foundBackend {
+ return fmt.Errorf("no backend found with name %q", name)
+ }
+ }
+
+ // If it's a meta backend, delete also associated backend
+ metadata, err := readBackendMetadata(backendDirectory)
+ if err != nil {
+ return err
+ }
+
+ if metadata != nil && metadata.MetaBackendFor != "" {
+ metaBackendDirectory := filepath.Join(systemState.Backend.BackendsPath, metadata.MetaBackendFor)
+ xlog.Debug("Deleting meta backend", "backendDirectory", metaBackendDirectory)
+ if _, err := os.Stat(metaBackendDirectory); os.IsNotExist(err) {
+ return fmt.Errorf("meta backend %q not found", metadata.MetaBackendFor)
+ }
+ os.RemoveAll(metaBackendDirectory)
+ }
+
+ return os.RemoveAll(backendDirectory)
+}
+
+type SystemBackend struct {
+ Name string
+ RunFile string
+ IsMeta bool
+ IsSystem bool
+ Metadata *BackendMetadata
+}
+
+type SystemBackends map[string]SystemBackend
+
+func (b SystemBackends) Exists(name string) bool {
+ _, ok := b[name]
+ return ok
+}
+
+func (b SystemBackends) Get(name string) (SystemBackend, bool) {
+ backend, ok := b[name]
+ return backend, ok
+}
+
+func (b SystemBackends) GetAll() []SystemBackend {
+ backends := make([]SystemBackend, 0)
+ for _, backend := range b {
+ backends = append(backends, backend)
+ }
+ return backends
+}
+
+func ListSystemBackends(systemState *system.SystemState) (SystemBackends, error) {
+ // Gather backends from system and user paths, then resolve alias conflicts by capability.
+ backends := make(SystemBackends)
+
+ // System-provided backends
+ if systemBackends, err := os.ReadDir(systemState.Backend.BackendsSystemPath); err == nil {
+ for _, systemBackend := range systemBackends {
+ if systemBackend.IsDir() {
+ run := filepath.Join(systemState.Backend.BackendsSystemPath, systemBackend.Name(), runFile)
+ if _, err := os.Stat(run); err == nil {
+ backends[systemBackend.Name()] = SystemBackend{
+ Name: systemBackend.Name(),
+ RunFile: run,
+ IsMeta: false,
+ IsSystem: true,
+ Metadata: nil,
+ }
+ }
+ }
+ }
+ } else if !errors.Is(err, os.ErrNotExist) {
+ xlog.Warn("Failed to read system backends, proceeding with user-managed backends", "error", err)
+ } else if errors.Is(err, os.ErrNotExist) {
+ xlog.Debug("No system backends found")
+ }
+
+ // User-managed backends and alias collection
+ entries, err := os.ReadDir(systemState.Backend.BackendsPath)
+ if err != nil {
+ return nil, err
+ }
+
+ aliasGroups := make(map[string][]backendCandidate)
+ metaMap := make(map[string]*BackendMetadata)
+
+ for _, e := range entries {
+ if !e.IsDir() {
+ continue
+ }
+ dir := e.Name()
+ run := filepath.Join(systemState.Backend.BackendsPath, dir, runFile)
+
+ var metadata *BackendMetadata
+ metadataPath := filepath.Join(systemState.Backend.BackendsPath, dir, metadataFile)
+ if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
+ metadata = &BackendMetadata{Name: dir}
+ } else {
+ m, rerr := readBackendMetadata(filepath.Join(systemState.Backend.BackendsPath, dir))
+ if rerr != nil {
+ return nil, rerr
+ }
+ if m == nil {
+ metadata = &BackendMetadata{Name: dir}
+ } else {
+ metadata = m
+ }
+ }
+
+ metaMap[dir] = metadata
+
+ // Concrete backend entry
+ if _, err := os.Stat(run); err == nil {
+ backends[dir] = SystemBackend{
+ Name: dir,
+ RunFile: run,
+ IsMeta: false,
+ Metadata: metadata,
+ }
+ }
+
+ // Alias candidates
+ if metadata.Alias != "" {
+ aliasGroups[metadata.Alias] = append(aliasGroups[metadata.Alias], backendCandidate{name: dir, runFile: run})
+ }
+
+ // Meta backends indirection
+ if metadata.MetaBackendFor != "" {
+ backends[metadata.Name] = SystemBackend{
+ Name: metadata.Name,
+ RunFile: filepath.Join(systemState.Backend.BackendsPath, metadata.MetaBackendFor, runFile),
+ IsMeta: true,
+ Metadata: metadata,
+ }
+ }
+ }
+
+ // Resolve aliases using system capability preferences
+ tokens := systemState.BackendPreferenceTokens()
+ for alias, cands := range aliasGroups {
+ chosen := backendCandidate{}
+ // Try preference tokens
+ for _, t := range tokens {
+ for _, c := range cands {
+ if strings.Contains(strings.ToLower(c.name), t) && c.runFile != "" {
+ chosen = c
+ break
+ }
+ }
+ if chosen.runFile != "" {
+ break
+ }
+ }
+ // Fallback: first runnable
+ if chosen.runFile == "" {
+ for _, c := range cands {
+ if c.runFile != "" {
+ chosen = c
+ break
+ }
+ }
+ }
+ if chosen.runFile == "" {
+ continue
+ }
+ md := metaMap[chosen.name]
+ backends[alias] = SystemBackend{
+ Name: alias,
+ RunFile: chosen.runFile,
+ IsMeta: false,
+ Metadata: md,
+ }
+ }
+
+ return backends, nil
+}
+
+func RegisterBackends(systemState *system.SystemState, modelLoader *model.ModelLoader) error {
+ backends, err := ListSystemBackends(systemState)
+ if err != nil {
+ return err
+ }
+
+ for _, backend := range backends {
+ xlog.Debug("Registering backend", "name", backend.Name, "runFile", backend.RunFile)
+ modelLoader.SetExternalBackend(backend.Name, backend.RunFile)
+ }
+
+ return nil
+}
diff --git a/core/gallery/backends_test.go b/core/gallery/backends_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..96ffe0fe521e5016d4a25391f279e6b557b00a54
--- /dev/null
+++ b/core/gallery/backends_test.go
@@ -0,0 +1,1027 @@
+package gallery
+
+import (
+ "context"
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "runtime"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/system"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+ "gopkg.in/yaml.v2"
+)
+
+const (
+ testImage = "quay.io/mudler/tests:localai-backend-test"
+)
+
+var _ = Describe("Runtime capability-based backend selection", func() {
+ var tempDir string
+
+ BeforeEach(func() {
+ var err error
+ tempDir, err = os.MkdirTemp("", "gallery-caps-*")
+ Expect(err).NotTo(HaveOccurred())
+ })
+
+ AfterEach(func() {
+ os.RemoveAll(tempDir)
+ })
+
+ It("ListSystemBackends prefers optimal alias candidate", func() {
+ // Arrange two installed backends sharing the same alias
+ must := func(err error) { Expect(err).NotTo(HaveOccurred()) }
+
+ cpuDir := filepath.Join(tempDir, "cpu-llama-cpp")
+ must(os.MkdirAll(cpuDir, 0o750))
+ cpuMeta := &BackendMetadata{Alias: "llama-cpp", Name: "cpu-llama-cpp"}
+ b, _ := json.Marshal(cpuMeta)
+ must(os.WriteFile(filepath.Join(cpuDir, "metadata.json"), b, 0o644))
+ must(os.WriteFile(filepath.Join(cpuDir, "run.sh"), []byte(""), 0o755))
+
+ cudaDir := filepath.Join(tempDir, "cuda12-llama-cpp")
+ must(os.MkdirAll(cudaDir, 0o750))
+ cudaMeta := &BackendMetadata{Alias: "llama-cpp", Name: "cuda12-llama-cpp"}
+ b, _ = json.Marshal(cudaMeta)
+ must(os.WriteFile(filepath.Join(cudaDir, "metadata.json"), b, 0o644))
+ must(os.WriteFile(filepath.Join(cudaDir, "run.sh"), []byte(""), 0o755))
+
+ // Default system: alias should point to CPU
+ sysDefault, err := system.GetSystemState(
+ system.WithBackendPath(tempDir),
+ )
+ must(err)
+ sysDefault.GPUVendor = "" // force default selection
+ backs, err := ListSystemBackends(sysDefault)
+ must(err)
+ aliasBack, ok := backs.Get("llama-cpp")
+ Expect(ok).To(BeTrue())
+ Expect(aliasBack.RunFile).To(Equal(filepath.Join(cpuDir, "run.sh")))
+ // concrete entries remain
+ _, ok = backs.Get("cpu-llama-cpp")
+ Expect(ok).To(BeTrue())
+ _, ok = backs.Get("cuda12-llama-cpp")
+ Expect(ok).To(BeTrue())
+
+ // NVIDIA system: alias should point to CUDA
+ // Force capability to nvidia to make the test deterministic on platforms like darwin/arm64 (which default to metal)
+ os.Setenv("LOCALAI_FORCE_META_BACKEND_CAPABILITY", "nvidia")
+ defer os.Unsetenv("LOCALAI_FORCE_META_BACKEND_CAPABILITY")
+
+ sysNvidia, err := system.GetSystemState(
+ system.WithBackendPath(tempDir),
+ )
+ must(err)
+ sysNvidia.GPUVendor = "nvidia"
+ sysNvidia.VRAM = 8 * 1024 * 1024 * 1024
+ backs, err = ListSystemBackends(sysNvidia)
+ must(err)
+ aliasBack, ok = backs.Get("llama-cpp")
+ Expect(ok).To(BeTrue())
+ Expect(aliasBack.RunFile).To(Equal(filepath.Join(cudaDir, "run.sh")))
+ })
+})
+
+var _ = Describe("Gallery Backends", func() {
+ var (
+ tempDir string
+ galleries []config.Gallery
+ ml *model.ModelLoader
+ systemState *system.SystemState
+ )
+
+ BeforeEach(func() {
+ var err error
+ tempDir, err = os.MkdirTemp("", "gallery-test-*")
+ Expect(err).NotTo(HaveOccurred())
+
+ // Setup test galleries
+ galleries = []config.Gallery{
+ {
+ Name: "test-gallery",
+ URL: "https://gist.githubusercontent.com/mudler/71d5376bc2aa168873fa519fa9f4bd56/raw/0557f9c640c159fa8e4eab29e8d98df6a3d6e80f/backend-gallery.yaml",
+ },
+ }
+ systemState, err = system.GetSystemState(system.WithBackendPath(tempDir))
+ Expect(err).NotTo(HaveOccurred())
+ ml = model.NewModelLoader(systemState)
+ })
+
+ AfterEach(func() {
+ os.RemoveAll(tempDir)
+ })
+
+ Describe("InstallBackendFromGallery", func() {
+ It("should return error when backend is not found", func() {
+ err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "non-existent", nil, true)
+ Expect(err).To(HaveOccurred())
+ Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
+ })
+
+ It("should install backend from gallery", func() {
+ err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "test-backend", nil, true)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
+ })
+ })
+
+ Describe("Meta Backends", func() {
+ It("should identify meta backends correctly", func() {
+ metaBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "meta-backend",
+ },
+ CapabilitiesMap: map[string]string{
+ "nvidia": "nvidia-backend",
+ "amd": "amd-backend",
+ "intel": "intel-backend",
+ },
+ }
+
+ Expect(metaBackend.IsMeta()).To(BeTrue())
+
+ regularBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "regular-backend",
+ },
+ URI: testImage,
+ }
+
+ Expect(regularBackend.IsMeta()).To(BeFalse())
+
+ emptyMetaBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "empty-meta-backend",
+ },
+ CapabilitiesMap: map[string]string{},
+ }
+
+ Expect(emptyMetaBackend.IsMeta()).To(BeFalse())
+
+ nilMetaBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "nil-meta-backend",
+ },
+ CapabilitiesMap: nil,
+ }
+
+ Expect(nilMetaBackend.IsMeta()).To(BeFalse())
+ })
+
+ It("should check IsCompatibleWith correctly for meta backends", func() {
+ metaBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "meta-backend",
+ },
+ CapabilitiesMap: map[string]string{
+ "nvidia": "nvidia-backend",
+ "amd": "amd-backend",
+ "default": "default-backend",
+ },
+ }
+
+ // Test with nil state - should be compatible
+ Expect(metaBackend.IsCompatibleWith(nil)).To(BeTrue())
+
+ // Test with NVIDIA system - should be compatible (has nvidia key)
+ nvidiaState := &system.SystemState{GPUVendor: "nvidia", VRAM: 8 * 1024 * 1024 * 1024}
+ Expect(metaBackend.IsCompatibleWith(nvidiaState)).To(BeTrue())
+
+ // Test with default (no GPU) - should be compatible (has default key)
+ defaultState := &system.SystemState{}
+ Expect(metaBackend.IsCompatibleWith(defaultState)).To(BeTrue())
+ })
+
+ Describe("IsCompatibleWith for concrete backends", func() {
+ Context("CPU backends", func() {
+ It("should be compatible on all systems", func() {
+ cpuBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "cpu-llama-cpp",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-cpu-llama-cpp",
+ }
+ Expect(cpuBackend.IsCompatibleWith(&system.SystemState{})).To(BeTrue())
+ Expect(cpuBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Nvidia, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
+ Expect(cpuBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.AMD, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
+ })
+ })
+
+ Context("Darwin/Metal backends", func() {
+ When("running on darwin", func() {
+ BeforeEach(func() {
+ if runtime.GOOS != "darwin" {
+ Skip("Skipping darwin-specific tests on non-darwin system")
+ }
+ })
+
+ It("should be compatible for MLX backend", func() {
+ mlxBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "mlx",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx",
+ }
+ Expect(mlxBackend.IsCompatibleWith(&system.SystemState{})).To(BeTrue())
+ })
+
+ It("should be compatible for metal-llama-cpp backend", func() {
+ metalBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "metal-llama-cpp",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp",
+ }
+ Expect(metalBackend.IsCompatibleWith(&system.SystemState{})).To(BeTrue())
+ })
+ })
+
+ When("running on non-darwin", func() {
+ BeforeEach(func() {
+ if runtime.GOOS == "darwin" {
+ Skip("Skipping non-darwin-specific tests on darwin system")
+ }
+ })
+
+ It("should NOT be compatible for MLX backend", func() {
+ mlxBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "mlx",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx",
+ }
+ Expect(mlxBackend.IsCompatibleWith(&system.SystemState{})).To(BeFalse())
+ })
+
+ It("should NOT be compatible for metal-llama-cpp backend", func() {
+ metalBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "metal-llama-cpp",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp",
+ }
+ Expect(metalBackend.IsCompatibleWith(&system.SystemState{})).To(BeFalse())
+ })
+ })
+ })
+
+ Context("NVIDIA/CUDA backends", func() {
+ When("running on non-darwin", func() {
+ BeforeEach(func() {
+ if runtime.GOOS == "darwin" {
+ Skip("Skipping CUDA tests on darwin system")
+ }
+ })
+
+ It("should NOT be compatible without nvidia GPU", func() {
+ cudaBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "cuda12-llama-cpp",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-llama-cpp",
+ }
+ Expect(cudaBackend.IsCompatibleWith(&system.SystemState{})).To(BeFalse())
+ Expect(cudaBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.AMD, VRAM: 8 * 1024 * 1024 * 1024})).To(BeFalse())
+ })
+
+ It("should be compatible with nvidia GPU", func() {
+ cudaBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "cuda12-llama-cpp",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-llama-cpp",
+ }
+ Expect(cudaBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Nvidia, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
+ })
+
+ It("should be compatible with cuda13 backend on nvidia GPU", func() {
+ cuda13Backend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "cuda13-llama-cpp",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-llama-cpp",
+ }
+ Expect(cuda13Backend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Nvidia, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
+ })
+ })
+ })
+
+ Context("AMD/ROCm backends", func() {
+ When("running on non-darwin", func() {
+ BeforeEach(func() {
+ if runtime.GOOS == "darwin" {
+ Skip("Skipping AMD/ROCm tests on darwin system")
+ }
+ })
+
+ It("should NOT be compatible without AMD GPU", func() {
+ rocmBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "rocm-llama-cpp",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-llama-cpp",
+ }
+ Expect(rocmBackend.IsCompatibleWith(&system.SystemState{})).To(BeFalse())
+ Expect(rocmBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Nvidia, VRAM: 8 * 1024 * 1024 * 1024})).To(BeFalse())
+ })
+
+ It("should be compatible with AMD GPU", func() {
+ rocmBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "rocm-llama-cpp",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-llama-cpp",
+ }
+ Expect(rocmBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.AMD, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
+ })
+
+ It("should be compatible with hipblas backend on AMD GPU", func() {
+ hipBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "hip-llama-cpp",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-hip-llama-cpp",
+ }
+ Expect(hipBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.AMD, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
+ })
+ })
+ })
+
+ Context("Intel/SYCL backends", func() {
+ When("running on non-darwin", func() {
+ BeforeEach(func() {
+ if runtime.GOOS == "darwin" {
+ Skip("Skipping Intel/SYCL tests on darwin system")
+ }
+ })
+
+ It("should NOT be compatible without Intel GPU", func() {
+ intelBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "intel-sycl-f16-llama-cpp",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-llama-cpp",
+ }
+ Expect(intelBackend.IsCompatibleWith(&system.SystemState{})).To(BeFalse())
+ Expect(intelBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Nvidia, VRAM: 8 * 1024 * 1024 * 1024})).To(BeFalse())
+ })
+
+ It("should be compatible with Intel GPU", func() {
+ intelBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "intel-sycl-f16-llama-cpp",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-llama-cpp",
+ }
+ Expect(intelBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Intel, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
+ })
+
+ It("should be compatible with intel-sycl-f32 backend on Intel GPU", func() {
+ intelF32Backend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "intel-sycl-f32-llama-cpp",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-llama-cpp",
+ }
+ Expect(intelF32Backend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Intel, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
+ })
+
+ It("should be compatible with intel-transformers backend on Intel GPU", func() {
+ intelTransformersBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "intel-transformers",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-intel-transformers",
+ }
+ Expect(intelTransformersBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Intel, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue())
+ })
+ })
+ })
+
+ Context("Vulkan backends", func() {
+ It("should be compatible on CPU-only systems", func() {
+ // Vulkan backends don't have a specific GPU vendor requirement in the current logic
+ // They are compatible if no other GPU-specific pattern matches
+ vulkanBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "vulkan-llama-cpp",
+ },
+ URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-llama-cpp",
+ }
+ // Vulkan doesn't have vendor-specific filtering in current implementation
+ Expect(vulkanBackend.IsCompatibleWith(&system.SystemState{})).To(BeTrue())
+ })
+ })
+ })
+
+ It("should find best backend from meta based on system capabilities", func() {
+
+ metaBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "meta-backend",
+ },
+ CapabilitiesMap: map[string]string{
+ "nvidia": "nvidia-backend",
+ "amd": "amd-backend",
+ "intel": "intel-backend",
+ "metal": "metal-backend",
+ "default": "default-backend",
+ },
+ }
+
+ nvidiaBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "nvidia-backend",
+ },
+ URI: testImage,
+ }
+
+ amdBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "amd-backend",
+ },
+ URI: testImage,
+ }
+
+ metalBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "metal-backend",
+ },
+ URI: testImage,
+ }
+
+ defaultBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "default-backend",
+ },
+ URI: testImage,
+ }
+
+ backends := GalleryElements[*GalleryBackend]{nvidiaBackend, amdBackend, metalBackend, defaultBackend}
+
+ if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+ metal := &system.SystemState{}
+ bestBackend := metaBackend.FindBestBackendFromMeta(metal, backends)
+ Expect(bestBackend).To(Equal(metalBackend))
+
+ } else {
+ // Test with NVIDIA system state
+ nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia", VRAM: 1000000000000}
+ bestBackend := metaBackend.FindBestBackendFromMeta(nvidiaSystemState, backends)
+ Expect(bestBackend).To(Equal(nvidiaBackend))
+
+ // Test with AMD system state
+ amdSystemState := &system.SystemState{GPUVendor: "amd", VRAM: 1000000000000}
+ bestBackend = metaBackend.FindBestBackendFromMeta(amdSystemState, backends)
+ Expect(bestBackend).To(Equal(amdBackend))
+
+ // Test with default system state (not enough VRAM)
+ defaultSystemState := &system.SystemState{GPUVendor: "amd"}
+ bestBackend = metaBackend.FindBestBackendFromMeta(defaultSystemState, backends)
+ Expect(bestBackend).To(Equal(defaultBackend))
+
+ // Test with default system state
+ defaultSystemState = &system.SystemState{GPUVendor: "default"}
+ bestBackend = metaBackend.FindBestBackendFromMeta(defaultSystemState, backends)
+ Expect(bestBackend).To(Equal(defaultBackend))
+
+ backends = GalleryElements[*GalleryBackend]{nvidiaBackend, amdBackend, metalBackend}
+ // Test with unsupported GPU vendor
+ unsupportedSystemState := &system.SystemState{GPUVendor: "unsupported"}
+ bestBackend = metaBackend.FindBestBackendFromMeta(unsupportedSystemState, backends)
+ Expect(bestBackend).To(BeNil())
+ }
+ })
+
+ It("should handle meta backend deletion correctly", func() {
+ if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+ Skip("Skipping test on darwin/arm64")
+ }
+
+ metaBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "meta-backend",
+ },
+ CapabilitiesMap: map[string]string{
+ "nvidia": "nvidia-backend",
+ "amd": "amd-backend",
+ "intel": "intel-backend",
+ },
+ }
+
+ nvidiaBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "nvidia-backend",
+ },
+ URI: testImage,
+ }
+
+ amdBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "amd-backend",
+ },
+ URI: testImage,
+ }
+
+ gallery := config.Gallery{
+ Name: "test-gallery",
+ URL: "file://" + filepath.Join(tempDir, "backend-gallery.yaml"),
+ }
+
+ galleryBackend := GalleryBackends{amdBackend, nvidiaBackend, metaBackend}
+
+ dat, err := yaml.Marshal(galleryBackend)
+ Expect(err).NotTo(HaveOccurred())
+ err = os.WriteFile(filepath.Join(tempDir, "backend-gallery.yaml"), dat, 0644)
+ Expect(err).NotTo(HaveOccurred())
+
+ // Test with NVIDIA system state
+ nvidiaSystemState := &system.SystemState{
+ GPUVendor: "nvidia",
+ VRAM: 1000000000000,
+ Backend: system.Backend{BackendsPath: tempDir},
+ }
+ err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
+ Expect(err).NotTo(HaveOccurred())
+
+ metaBackendPath := filepath.Join(tempDir, "meta-backend")
+ Expect(metaBackendPath).To(BeADirectory())
+
+ concreteBackendPath := filepath.Join(tempDir, "nvidia-backend")
+ Expect(concreteBackendPath).To(BeADirectory())
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(tempDir),
+ )
+ Expect(err).NotTo(HaveOccurred())
+
+ allBackends, err := ListSystemBackends(systemState)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(allBackends).To(HaveKey("meta-backend"))
+ Expect(allBackends).To(HaveKey("nvidia-backend"))
+
+ // Delete meta backend by name
+ err = DeleteBackendFromSystem(systemState, "meta-backend")
+ Expect(err).NotTo(HaveOccurred())
+
+ // Verify meta backend directory is deleted
+ Expect(metaBackendPath).NotTo(BeADirectory())
+
+ // Verify concrete backend directory is deleted
+ Expect(concreteBackendPath).NotTo(BeADirectory())
+ })
+
+ It("should handle meta backend deletion correctly with aliases", func() {
+ if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+ Skip("Skipping test on darwin/arm64")
+ }
+ metaBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "meta-backend",
+ },
+ Alias: "backend-alias",
+ CapabilitiesMap: map[string]string{
+ "nvidia": "nvidia-backend",
+ "amd": "amd-backend",
+ "intel": "intel-backend",
+ },
+ }
+
+ nvidiaBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "nvidia-backend",
+ },
+ Alias: "backend-alias",
+ URI: testImage,
+ }
+
+ amdBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "amd-backend",
+ },
+ Alias: "backend-alias",
+ URI: testImage,
+ }
+
+ gallery := config.Gallery{
+ Name: "test-gallery",
+ URL: "file://" + filepath.Join(tempDir, "backend-gallery.yaml"),
+ }
+
+ galleryBackend := GalleryBackends{amdBackend, nvidiaBackend, metaBackend}
+
+ dat, err := yaml.Marshal(galleryBackend)
+ Expect(err).NotTo(HaveOccurred())
+ err = os.WriteFile(filepath.Join(tempDir, "backend-gallery.yaml"), dat, 0644)
+ Expect(err).NotTo(HaveOccurred())
+
+ // Test with NVIDIA system state
+ nvidiaSystemState := &system.SystemState{
+ GPUVendor: "nvidia",
+ VRAM: 1000000000000,
+ Backend: system.Backend{BackendsPath: tempDir},
+ }
+ err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
+ Expect(err).NotTo(HaveOccurred())
+
+ metaBackendPath := filepath.Join(tempDir, "meta-backend")
+ Expect(metaBackendPath).To(BeADirectory())
+
+ concreteBackendPath := filepath.Join(tempDir, "nvidia-backend")
+ Expect(concreteBackendPath).To(BeADirectory())
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(tempDir),
+ )
+ Expect(err).NotTo(HaveOccurred())
+
+ allBackends, err := ListSystemBackends(systemState)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(allBackends).To(HaveKey("meta-backend"))
+ Expect(allBackends).To(HaveKey("nvidia-backend"))
+ mback, exists := allBackends.Get("meta-backend")
+ Expect(exists).To(BeTrue())
+ Expect(mback.IsMeta).To(BeTrue())
+ Expect(mback.Metadata.MetaBackendFor).To(Equal("nvidia-backend"))
+
+ // Delete meta backend by name
+ err = DeleteBackendFromSystem(systemState, "meta-backend")
+ Expect(err).NotTo(HaveOccurred())
+
+ // Verify meta backend directory is deleted
+ Expect(metaBackendPath).NotTo(BeADirectory())
+
+ // Verify concrete backend directory is deleted
+ Expect(concreteBackendPath).NotTo(BeADirectory())
+ })
+
+ It("should handle meta backend deletion correctly with aliases pointing to the same backend", func() {
+ if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+ Skip("Skipping test on darwin/arm64")
+ }
+ metaBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "meta-backend",
+ },
+ Alias: "meta-backend",
+ CapabilitiesMap: map[string]string{
+ "nvidia": "nvidia-backend",
+ "amd": "amd-backend",
+ "intel": "intel-backend",
+ },
+ }
+
+ nvidiaBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "nvidia-backend",
+ },
+ Alias: "meta-backend",
+ URI: testImage,
+ }
+
+ amdBackend := &GalleryBackend{
+ Metadata: Metadata{
+ Name: "amd-backend",
+ },
+ Alias: "meta-backend",
+ URI: testImage,
+ }
+
+ gallery := config.Gallery{
+ Name: "test-gallery",
+ URL: "file://" + filepath.Join(tempDir, "backend-gallery.yaml"),
+ }
+
+ galleryBackend := GalleryBackends{amdBackend, nvidiaBackend, metaBackend}
+
+ dat, err := yaml.Marshal(galleryBackend)
+ Expect(err).NotTo(HaveOccurred())
+ err = os.WriteFile(filepath.Join(tempDir, "backend-gallery.yaml"), dat, 0644)
+ Expect(err).NotTo(HaveOccurred())
+
+ // Test with NVIDIA system state
+ nvidiaSystemState := &system.SystemState{
+ GPUVendor: "nvidia",
+ VRAM: 1000000000000,
+ Backend: system.Backend{BackendsPath: tempDir},
+ }
+ err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
+ Expect(err).NotTo(HaveOccurred())
+
+ metaBackendPath := filepath.Join(tempDir, "meta-backend")
+ Expect(metaBackendPath).To(BeADirectory())
+
+ concreteBackendPath := filepath.Join(tempDir, "nvidia-backend")
+ Expect(concreteBackendPath).To(BeADirectory())
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(tempDir),
+ )
+ Expect(err).NotTo(HaveOccurred())
+
+ allBackends, err := ListSystemBackends(systemState)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(allBackends).To(HaveKey("meta-backend"))
+ Expect(allBackends).To(HaveKey("nvidia-backend"))
+ mback, exists := allBackends.Get("meta-backend")
+ Expect(exists).To(BeTrue())
+ Expect(mback.RunFile).To(Equal(filepath.Join(tempDir, "nvidia-backend", "run.sh")))
+
+ // Delete meta backend by name
+ err = DeleteBackendFromSystem(systemState, "meta-backend")
+ Expect(err).NotTo(HaveOccurred())
+
+ // Verify meta backend directory is deleted
+ Expect(metaBackendPath).NotTo(BeADirectory())
+
+ // Verify concrete backend directory is deleted
+ Expect(concreteBackendPath).NotTo(BeADirectory())
+ })
+
+ It("should list meta backends correctly in system backends", func() {
+ // Create a meta backend directory with metadata
+ metaBackendPath := filepath.Join(tempDir, "meta-backend")
+ err := os.MkdirAll(metaBackendPath, 0750)
+ Expect(err).NotTo(HaveOccurred())
+
+ // Create metadata file pointing to concrete backend
+ metadata := &BackendMetadata{
+ MetaBackendFor: "concrete-backend",
+ Name: "meta-backend",
+ InstalledAt: "2023-01-01T00:00:00Z",
+ }
+ metadataData, err := json.Marshal(metadata)
+ Expect(err).NotTo(HaveOccurred())
+ err = os.WriteFile(filepath.Join(metaBackendPath, "metadata.json"), metadataData, 0644)
+ Expect(err).NotTo(HaveOccurred())
+
+ // Create the concrete backend directory with run.sh
+ concreteBackendPath := filepath.Join(tempDir, "concrete-backend")
+ err = os.MkdirAll(concreteBackendPath, 0750)
+ Expect(err).NotTo(HaveOccurred())
+ err = os.WriteFile(filepath.Join(concreteBackendPath, "metadata.json"), []byte("{}"), 0755)
+ Expect(err).NotTo(HaveOccurred())
+ err = os.WriteFile(filepath.Join(concreteBackendPath, "run.sh"), []byte(""), 0755)
+ Expect(err).NotTo(HaveOccurred())
+
+ // List system backends
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(tempDir),
+ )
+ Expect(err).NotTo(HaveOccurred())
+
+ backends, err := ListSystemBackends(systemState)
+ Expect(err).NotTo(HaveOccurred())
+
+ metaBackend, exists := backends.Get("meta-backend")
+ concreteBackendRunFile := filepath.Join(tempDir, "concrete-backend", "run.sh")
+
+ // Should include both the meta backend name and concrete backend name
+ Expect(exists).To(BeTrue())
+ Expect(backends.Exists("concrete-backend")).To(BeTrue())
+
+ // meta-backend should be empty
+ Expect(metaBackend.IsMeta).To(BeTrue())
+ Expect(metaBackend.RunFile).To(Equal(concreteBackendRunFile))
+ // concrete-backend should point to its own run.sh
+ concreteBackend, exists := backends.Get("concrete-backend")
+ Expect(exists).To(BeTrue())
+ Expect(concreteBackend.RunFile).To(Equal(concreteBackendRunFile))
+ })
+ })
+
+ Describe("InstallBackend", func() {
+ It("should create base path if it doesn't exist", func() {
+ newPath := filepath.Join(tempDir, "new-path")
+ backend := GalleryBackend{
+ Metadata: Metadata{
+ Name: "test-backend",
+ },
+ URI: "test-uri",
+ }
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(newPath),
+ )
+ Expect(err).NotTo(HaveOccurred())
+ err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
+ Expect(newPath).To(BeADirectory())
+ Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
+ })
+
+ It("should overwrite existing backend", func() {
+ if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+ Skip("Skipping test on darwin/arm64")
+ }
+ newPath := filepath.Join(tempDir, "test-backend")
+
+ // Create a dummy backend directory
+ err := os.MkdirAll(newPath, 0750)
+ Expect(err).NotTo(HaveOccurred())
+
+ err = os.WriteFile(filepath.Join(newPath, "metadata.json"), []byte("foo"), 0644)
+ Expect(err).NotTo(HaveOccurred())
+ err = os.WriteFile(filepath.Join(newPath, "run.sh"), []byte(""), 0644)
+ Expect(err).NotTo(HaveOccurred())
+
+ backend := GalleryBackend{
+ Metadata: Metadata{
+ Name: "test-backend",
+ },
+ URI: "quay.io/mudler/tests:localai-backend-test",
+ Alias: "test-alias",
+ }
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(tempDir),
+ )
+ Expect(err).NotTo(HaveOccurred())
+ err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
+ dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
+ Expect(err).ToNot(HaveOccurred())
+ Expect(string(dat)).ToNot(Equal("foo"))
+ })
+
+ It("should overwrite existing backend", func() {
+ if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+ Skip("Skipping test on darwin/arm64")
+ }
+ newPath := filepath.Join(tempDir, "test-backend")
+
+ // Create a dummy backend directory
+ err := os.MkdirAll(newPath, 0750)
+ Expect(err).NotTo(HaveOccurred())
+
+ backend := GalleryBackend{
+ Metadata: Metadata{
+ Name: "test-backend",
+ },
+ URI: "quay.io/mudler/tests:localai-backend-test",
+ Alias: "test-alias",
+ }
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(tempDir),
+ )
+ Expect(err).NotTo(HaveOccurred())
+
+ Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())
+
+ err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
+ })
+
+ It("should create alias file when specified", func() {
+ if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+ Skip("Skipping test on darwin/arm64")
+ }
+ backend := GalleryBackend{
+ Metadata: Metadata{
+ Name: "test-backend",
+ },
+ URI: "quay.io/mudler/tests:localai-backend-test",
+ Alias: "test-alias",
+ }
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(tempDir),
+ )
+ Expect(err).NotTo(HaveOccurred())
+ err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
+
+ // Read and verify metadata
+ metadataData, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
+ Expect(err).ToNot(HaveOccurred())
+ var metadata BackendMetadata
+ err = json.Unmarshal(metadataData, &metadata)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(metadata.Alias).To(Equal("test-alias"))
+ Expect(metadata.Name).To(Equal("test-backend"))
+
+ Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
+
+ // Check that the alias was recognized
+ backends, err := ListSystemBackends(systemState)
+ Expect(err).ToNot(HaveOccurred())
+ aliasBackend, exists := backends.Get("test-alias")
+ Expect(exists).To(BeTrue())
+ Expect(aliasBackend.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh")))
+ testB, exists := backends.Get("test-backend")
+ Expect(exists).To(BeTrue())
+ Expect(testB.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh")))
+ })
+ })
+
+ Describe("DeleteBackendFromSystem", func() {
+ It("should delete backend directory", func() {
+ backendName := "test-backend"
+ backendPath := filepath.Join(tempDir, backendName)
+
+ // Create a dummy backend directory
+ err := os.MkdirAll(backendPath, 0750)
+ Expect(err).NotTo(HaveOccurred())
+
+ err = os.WriteFile(filepath.Join(backendPath, "metadata.json"), []byte("{}"), 0644)
+ Expect(err).NotTo(HaveOccurred())
+ err = os.WriteFile(filepath.Join(backendPath, "run.sh"), []byte(""), 0644)
+ Expect(err).NotTo(HaveOccurred())
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(tempDir),
+ )
+ Expect(err).NotTo(HaveOccurred())
+ err = DeleteBackendFromSystem(systemState, backendName)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(backendPath).NotTo(BeADirectory())
+ })
+
+ It("should not error when backend doesn't exist", func() {
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(tempDir),
+ )
+ Expect(err).NotTo(HaveOccurred())
+ err = DeleteBackendFromSystem(systemState, "non-existent")
+ Expect(err).To(HaveOccurred())
+ })
+ })
+
+ Describe("ListSystemBackends", func() {
+ It("should list backends without aliases", func() {
+ // Create some dummy backend directories
+ backendNames := []string{"backend1", "backend2", "backend3"}
+ for _, name := range backendNames {
+ err := os.MkdirAll(filepath.Join(tempDir, name), 0750)
+ Expect(err).NotTo(HaveOccurred())
+ err = os.WriteFile(filepath.Join(tempDir, name, "metadata.json"), []byte("{}"), 0755)
+ Expect(err).NotTo(HaveOccurred())
+ err = os.WriteFile(filepath.Join(tempDir, name, "run.sh"), []byte(""), 0755)
+ Expect(err).NotTo(HaveOccurred())
+ }
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(tempDir),
+ )
+ Expect(err).NotTo(HaveOccurred())
+ backends, err := ListSystemBackends(systemState)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(backends).To(HaveLen(len(backendNames)))
+
+ for _, name := range backendNames {
+ backend, exists := backends.Get(name)
+ Expect(exists).To(BeTrue())
+ Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, name, "run.sh")))
+ }
+ })
+
+ It("should handle backends with aliases", func() {
+ backendName := "backend1"
+ alias := "alias1"
+ backendPath := filepath.Join(tempDir, backendName)
+
+ // Create backend directory
+ err := os.MkdirAll(backendPath, 0750)
+ Expect(err).NotTo(HaveOccurred())
+
+ // Create metadata file with alias
+ metadata := &BackendMetadata{
+ Alias: alias,
+ Name: backendName,
+ InstalledAt: "2023-01-01T00:00:00Z",
+ }
+ metadataData, err := json.Marshal(metadata)
+ Expect(err).NotTo(HaveOccurred())
+ err = os.WriteFile(filepath.Join(backendPath, "metadata.json"), metadataData, 0644)
+ Expect(err).NotTo(HaveOccurred())
+ err = os.WriteFile(filepath.Join(backendPath, "run.sh"), []byte(""), 0755)
+ Expect(err).NotTo(HaveOccurred())
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(tempDir),
+ )
+ Expect(err).NotTo(HaveOccurred())
+ backends, err := ListSystemBackends(systemState)
+ Expect(err).NotTo(HaveOccurred())
+ backend, exists := backends.Get(alias)
+ Expect(exists).To(BeTrue())
+ Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, backendName, "run.sh")))
+ })
+
+ It("should return error when base path doesn't exist", func() {
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath("foobardir"),
+ )
+ Expect(err).NotTo(HaveOccurred())
+ _, err = ListSystemBackends(systemState)
+ Expect(err).To(HaveOccurred())
+ })
+ })
+})
diff --git a/core/gallery/gallery.go b/core/gallery/gallery.go
new file mode 100644
index 0000000000000000000000000000000000000000..6add8cfa73f81654822d7e12a8dcd4de5330de66
--- /dev/null
+++ b/core/gallery/gallery.go
@@ -0,0 +1,345 @@
+package gallery
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+ "time"
+
+ "github.com/lithammer/fuzzysearch/fuzzy"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/pkg/downloader"
+ "github.com/mudler/LocalAI/pkg/system"
+ "github.com/mudler/LocalAI/pkg/xsync"
+ "github.com/mudler/xlog"
+
+ "gopkg.in/yaml.v2"
+)
+
+func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
+ var config T
+ uri := downloader.URI(url)
+ err := uri.ReadWithCallback(basePath, func(url string, d []byte) error {
+ return yaml.Unmarshal(d, &config)
+ })
+ if err != nil {
+ xlog.Error("failed to get gallery config for url", "error", err, "url", url)
+ return config, err
+ }
+ return config, nil
+}
+
+func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) {
+ var config T
+ uri := downloader.URI(url)
+ err := uri.ReadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error {
+ return yaml.Unmarshal(d, &config)
+ })
+ if err != nil {
+ xlog.Error("failed to get gallery config for url", "error", err, "url", url)
+ return config, err
+ }
+ return config, nil
+}
+
+func ReadConfigFile[T any](filePath string) (*T, error) {
+ // Read the YAML file
+ yamlFile, err := os.ReadFile(filePath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read YAML file: %v", err)
+ }
+
+ // Unmarshal YAML data into a Config struct
+ var config T
+ err = yaml.Unmarshal(yamlFile, &config)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal YAML: %v", err)
+ }
+
+ return &config, nil
+}
+
+type GalleryElement interface {
+ SetGallery(gallery config.Gallery)
+ SetInstalled(installed bool)
+ GetName() string
+ GetDescription() string
+ GetTags() []string
+ GetInstalled() bool
+ GetLicense() string
+ GetGallery() config.Gallery
+}
+
+type GalleryElements[T GalleryElement] []T
+
+func (gm GalleryElements[T]) Search(term string) GalleryElements[T] {
+ var filteredModels GalleryElements[T]
+ term = strings.ToLower(term)
+ for _, m := range gm {
+ if fuzzy.Match(term, strings.ToLower(m.GetName())) ||
+ fuzzy.Match(term, strings.ToLower(m.GetGallery().Name)) ||
+ strings.Contains(strings.ToLower(m.GetName()), term) ||
+ strings.Contains(strings.ToLower(m.GetDescription()), term) ||
+ strings.Contains(strings.ToLower(m.GetGallery().Name), term) ||
+ strings.Contains(strings.ToLower(strings.Join(m.GetTags(), ",")), term) {
+ filteredModels = append(filteredModels, m)
+ }
+ }
+
+ return filteredModels
+}
+
+func (gm GalleryElements[T]) SortByName(sortOrder string) GalleryElements[T] {
+ sort.Slice(gm, func(i, j int) bool {
+ if sortOrder == "asc" {
+ return strings.ToLower(gm[i].GetName()) < strings.ToLower(gm[j].GetName())
+ } else {
+ return strings.ToLower(gm[i].GetName()) > strings.ToLower(gm[j].GetName())
+ }
+ })
+ return gm
+}
+
+func (gm GalleryElements[T]) SortByRepository(sortOrder string) GalleryElements[T] {
+ sort.Slice(gm, func(i, j int) bool {
+ if sortOrder == "asc" {
+ return strings.ToLower(gm[i].GetGallery().Name) < strings.ToLower(gm[j].GetGallery().Name)
+ } else {
+ return strings.ToLower(gm[i].GetGallery().Name) > strings.ToLower(gm[j].GetGallery().Name)
+ }
+ })
+ return gm
+}
+
+func (gm GalleryElements[T]) SortByLicense(sortOrder string) GalleryElements[T] {
+ sort.Slice(gm, func(i, j int) bool {
+ licenseI := gm[i].GetLicense()
+ licenseJ := gm[j].GetLicense()
+ var result bool
+ if licenseI == "" && licenseJ != "" {
+ return sortOrder == "desc"
+ } else if licenseI != "" && licenseJ == "" {
+ return sortOrder == "asc"
+ } else if licenseI == "" && licenseJ == "" {
+ return false
+ } else {
+ result = strings.ToLower(licenseI) < strings.ToLower(licenseJ)
+ }
+ if sortOrder == "desc" {
+ return !result
+ } else {
+ return result
+ }
+ })
+ return gm
+}
+
+func (gm GalleryElements[T]) SortByInstalled(sortOrder string) GalleryElements[T] {
+ sort.Slice(gm, func(i, j int) bool {
+ var result bool
+ // Sort by installed status: installed items first (true > false)
+ if gm[i].GetInstalled() != gm[j].GetInstalled() {
+ result = gm[i].GetInstalled()
+ } else {
+ result = strings.ToLower(gm[i].GetName()) < strings.ToLower(gm[j].GetName())
+ }
+ if sortOrder == "desc" {
+ return !result
+ } else {
+ return result
+ }
+ })
+ return gm
+}
+
+func (gm GalleryElements[T]) FindByName(name string) T {
+ for _, m := range gm {
+ if strings.EqualFold(m.GetName(), name) {
+ return m
+ }
+ }
+ var zero T
+ return zero
+}
+
+func (gm GalleryElements[T]) Paginate(pageNum int, itemsNum int) GalleryElements[T] {
+ start := (pageNum - 1) * itemsNum
+ end := start + itemsNum
+ if start > len(gm) {
+ start = len(gm)
+ }
+ if end > len(gm) {
+ end = len(gm)
+ }
+ return gm[start:end]
+}
+
+func FindGalleryElement[T GalleryElement](models []T, name string) T {
+ var model T
+ name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
+
+ if !strings.Contains(name, "@") {
+ for _, m := range models {
+ if strings.EqualFold(strings.ToLower(m.GetName()), strings.ToLower(name)) {
+ model = m
+ break
+ }
+ }
+
+ } else {
+ for _, m := range models {
+ if strings.EqualFold(strings.ToLower(name), strings.ToLower(fmt.Sprintf("%s@%s", m.GetGallery().Name, m.GetName()))) {
+ model = m
+ break
+ }
+ }
+ }
+
+ return model
+}
+
+// List available models
+// Models galleries are a list of yaml files that are hosted on a remote server (for example github).
+// Each yaml file contains a list of models that can be downloaded and optionally overrides to define a new model setting.
+func AvailableGalleryModels(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryModel], error) {
+ var models []*GalleryModel
+
+ // Get models from galleries
+ for _, gallery := range galleries {
+ galleryModels, err := getGalleryElements(gallery, systemState.Model.ModelsPath, func(model *GalleryModel) bool {
+ if _, err := os.Stat(filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil {
+ return true
+ }
+ return false
+ })
+ if err != nil {
+ return nil, err
+ }
+ models = append(models, galleryModels...)
+ }
+
+ return models, nil
+}
+
+// List available backends
+func AvailableBackends(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryBackend], error) {
+ return availableBackendsWithFilter(galleries, systemState, true)
+}
+
+// AvailableBackendsUnfiltered returns all available backends without filtering by system capability.
+func AvailableBackendsUnfiltered(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryBackend], error) {
+ return availableBackendsWithFilter(galleries, systemState, false)
+}
+
+// availableBackendsWithFilter is a helper function that lists available backends with optional filtering.
+func availableBackendsWithFilter(galleries []config.Gallery, systemState *system.SystemState, filterByCapability bool) (GalleryElements[*GalleryBackend], error) {
+ var backends []*GalleryBackend
+
+ systemBackends, err := ListSystemBackends(systemState)
+ if err != nil {
+ return nil, err
+ }
+
+ // Get backends from galleries
+ for _, gallery := range galleries {
+ galleryBackends, err := getGalleryElements(gallery, systemState.Backend.BackendsPath, func(backend *GalleryBackend) bool {
+ return systemBackends.Exists(backend.GetName())
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Filter backends by system capability if requested
+ if filterByCapability {
+ for _, backend := range galleryBackends {
+ if backend.IsCompatibleWith(systemState) {
+ backends = append(backends, backend)
+ }
+ }
+ } else {
+ backends = append(backends, galleryBackends...)
+ }
+ }
+
+ return backends, nil
+}
+
+func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) {
+ var refFile string
+ uri := downloader.URI(url)
+ err := uri.ReadWithCallback(basePath, func(url string, d []byte) error {
+ refFile = string(d)
+ if len(refFile) == 0 {
+ return fmt.Errorf("invalid reference file at url %s: %s", url, d)
+ }
+ cutPoint := strings.LastIndex(url, "/")
+ refFile = url[:cutPoint+1] + refFile
+ return nil
+ })
+ return refFile, err
+}
+
+type galleryCacheEntry struct {
+ yamlEntry []byte
+ lastUpdated time.Time
+}
+
+func (entry galleryCacheEntry) hasExpired() bool {
+ return entry.lastUpdated.Before(time.Now().Add(-1 * time.Hour))
+}
+
+var galleryCache = xsync.NewSyncedMap[string, galleryCacheEntry]()
+
+func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath string, isInstalledCallback func(T) bool) ([]T, error) {
+ var models []T = []T{}
+
+ if strings.HasSuffix(gallery.URL, ".ref") {
+ var err error
+ gallery.URL, err = findGalleryURLFromReferenceURL(gallery.URL, basePath)
+ if err != nil {
+ return models, err
+ }
+ }
+
+ cacheKey := fmt.Sprintf("%s-%s", gallery.Name, gallery.URL)
+ if galleryCache.Exists(cacheKey) {
+ entry := galleryCache.Get(cacheKey)
+ // refresh if last updated is more than 1 hour ago
+ if !entry.hasExpired() {
+ err := yaml.Unmarshal(entry.yamlEntry, &models)
+ if err != nil {
+ return models, err
+ }
+ } else {
+ galleryCache.Delete(cacheKey)
+ }
+ }
+
+ uri := downloader.URI(gallery.URL)
+
+ if len(models) == 0 {
+ err := uri.ReadWithCallback(basePath, func(url string, d []byte) error {
+ galleryCache.Set(cacheKey, galleryCacheEntry{
+ yamlEntry: d,
+ lastUpdated: time.Now(),
+ })
+ return yaml.Unmarshal(d, &models)
+ })
+ if err != nil {
+ if yamlErr, ok := err.(*yaml.TypeError); ok {
+ xlog.Debug("YAML errors", "errors", strings.Join(yamlErr.Errors, "\n"), "models", models)
+ }
+ return models, fmt.Errorf("failed to read gallery elements: %w", err)
+ }
+ }
+
+ // Add gallery to models
+ for _, model := range models {
+ model.SetGallery(gallery)
+ model.SetInstalled(isInstalledCallback(model))
+ }
+ return models, nil
+}
diff --git a/core/gallery/gallery_suite_test.go b/core/gallery/gallery_suite_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..44256bc27e97051f7ff47dbefc9f80da95b2b2aa
--- /dev/null
+++ b/core/gallery/gallery_suite_test.go
@@ -0,0 +1,13 @@
+package gallery_test
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestGallery(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "Gallery test suite")
+}
diff --git a/core/gallery/gallery_test.go b/core/gallery/gallery_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..3ba65f2d9a25e83bdf57f13fc5f42dff6fff7c50
--- /dev/null
+++ b/core/gallery/gallery_test.go
@@ -0,0 +1,465 @@
+package gallery_test
+
+import (
+ "os"
+ "path/filepath"
+
+ "github.com/mudler/LocalAI/core/config"
+ . "github.com/mudler/LocalAI/core/gallery"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+ "gopkg.in/yaml.v2"
+)
+
+var _ = Describe("Gallery", func() {
+ var tempDir string
+
+ BeforeEach(func() {
+ var err error
+ tempDir, err = os.MkdirTemp("", "gallery-test-*")
+ Expect(err).NotTo(HaveOccurred())
+ })
+
+ AfterEach(func() {
+ os.RemoveAll(tempDir)
+ })
+
+ Describe("ReadConfigFile", func() {
+ It("should read and unmarshal a valid YAML file", func() {
+ testConfig := map[string]interface{}{
+ "name": "test-model",
+ "description": "A test model",
+ "license": "MIT",
+ }
+ yamlData, err := yaml.Marshal(testConfig)
+ Expect(err).NotTo(HaveOccurred())
+
+ filePath := filepath.Join(tempDir, "test.yaml")
+ err = os.WriteFile(filePath, yamlData, 0644)
+ Expect(err).NotTo(HaveOccurred())
+
+ var result map[string]interface{}
+ config, err := ReadConfigFile[map[string]interface{}](filePath)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(config).NotTo(BeNil())
+ result = *config
+ Expect(result["name"]).To(Equal("test-model"))
+ Expect(result["description"]).To(Equal("A test model"))
+ Expect(result["license"]).To(Equal("MIT"))
+ })
+
+ It("should return error when file does not exist", func() {
+ _, err := ReadConfigFile[map[string]interface{}]("nonexistent.yaml")
+ Expect(err).To(HaveOccurred())
+ })
+
+ It("should return error when YAML is invalid", func() {
+ filePath := filepath.Join(tempDir, "invalid.yaml")
+ err := os.WriteFile(filePath, []byte("invalid: yaml: content: [unclosed"), 0644)
+ Expect(err).NotTo(HaveOccurred())
+
+ _, err = ReadConfigFile[map[string]interface{}](filePath)
+ Expect(err).To(HaveOccurred())
+ })
+ })
+
+ Describe("GalleryElements Search", func() {
+ var elements GalleryElements[*GalleryModel]
+
+ BeforeEach(func() {
+ elements = GalleryElements[*GalleryModel]{
+ {
+ Metadata: Metadata{
+ Name: "bert-embeddings",
+ Description: "BERT model for embeddings",
+ Tags: []string{"embeddings", "bert", "nlp"},
+ License: "Apache-2.0",
+ Gallery: config.Gallery{
+ Name: "huggingface",
+ },
+ },
+ },
+ {
+ Metadata: Metadata{
+ Name: "gpt-2",
+ Description: "GPT-2 language model",
+ Tags: []string{"gpt", "language-model"},
+ License: "MIT",
+ Gallery: config.Gallery{
+ Name: "openai",
+ },
+ },
+ },
+ {
+ Metadata: Metadata{
+ Name: "llama-7b",
+ Description: "LLaMA 7B model",
+ Tags: []string{"llama", "llm"},
+ License: "LLaMA",
+ Gallery: config.Gallery{
+ Name: "meta",
+ },
+ },
+ },
+ }
+ })
+
+ It("should find elements by exact name match", func() {
+ results := elements.Search("bert-embeddings")
+ Expect(results).To(HaveLen(1))
+ Expect(results[0].GetName()).To(Equal("bert-embeddings"))
+ })
+
+ It("should find elements by partial name match", func() {
+ results := elements.Search("bert")
+ Expect(results).To(HaveLen(1))
+ Expect(results[0].GetName()).To(Equal("bert-embeddings"))
+ })
+
+ It("should find elements by description", func() {
+ results := elements.Search("embeddings")
+ Expect(results).To(HaveLen(1))
+ Expect(results[0].GetName()).To(Equal("bert-embeddings"))
+ })
+
+ It("should find elements by gallery name", func() {
+ results := elements.Search("huggingface")
+ Expect(results).To(HaveLen(1))
+ Expect(results[0].GetGallery().Name).To(Equal("huggingface"))
+ })
+
+ It("should find elements by tags", func() {
+ results := elements.Search("nlp")
+ Expect(results).To(HaveLen(1))
+ Expect(results[0].GetName()).To(Equal("bert-embeddings"))
+ })
+
+ It("should be case insensitive", func() {
+ results := elements.Search("BERT")
+ Expect(results).To(HaveLen(1))
+ Expect(results[0].GetName()).To(Equal("bert-embeddings"))
+ })
+
+ It("should find multiple elements", func() {
+ results := elements.Search("gpt")
+ Expect(results).To(HaveLen(1))
+ Expect(results[0].GetName()).To(Equal("gpt-2"))
+ })
+
+ It("should return empty results for no matches", func() {
+ results := elements.Search("nonexistent")
+ Expect(results).To(HaveLen(0))
+ })
+
+ It("should use fuzzy matching", func() {
+ results := elements.Search("bert-emb")
+ Expect(results).To(HaveLen(1))
+ Expect(results[0].GetName()).To(Equal("bert-embeddings"))
+ })
+ })
+
+ Describe("GalleryElements SortByName", func() {
+ var elements GalleryElements[*GalleryModel]
+
+ BeforeEach(func() {
+ elements = GalleryElements[*GalleryModel]{
+ {Metadata: Metadata{Name: "zebra"}},
+ {Metadata: Metadata{Name: "alpha"}},
+ {Metadata: Metadata{Name: "beta"}},
+ }
+ })
+
+ It("should sort ascending", func() {
+ sorted := elements.SortByName("asc")
+ Expect(sorted).To(HaveLen(3))
+ Expect(sorted[0].GetName()).To(Equal("alpha"))
+ Expect(sorted[1].GetName()).To(Equal("beta"))
+ Expect(sorted[2].GetName()).To(Equal("zebra"))
+ })
+
+ It("should sort descending", func() {
+ sorted := elements.SortByName("desc")
+ Expect(sorted).To(HaveLen(3))
+ Expect(sorted[0].GetName()).To(Equal("zebra"))
+ Expect(sorted[1].GetName()).To(Equal("beta"))
+ Expect(sorted[2].GetName()).To(Equal("alpha"))
+ })
+
+ It("should be case insensitive", func() {
+ elements = GalleryElements[*GalleryModel]{
+ {Metadata: Metadata{Name: "Zebra"}},
+ {Metadata: Metadata{Name: "alpha"}},
+ {Metadata: Metadata{Name: "Beta"}},
+ }
+ sorted := elements.SortByName("asc")
+ Expect(sorted[0].GetName()).To(Equal("alpha"))
+ Expect(sorted[1].GetName()).To(Equal("Beta"))
+ Expect(sorted[2].GetName()).To(Equal("Zebra"))
+ })
+ })
+
+ Describe("GalleryElements SortByRepository", func() {
+ var elements GalleryElements[*GalleryModel]
+
+ BeforeEach(func() {
+ elements = GalleryElements[*GalleryModel]{
+ {
+ Metadata: Metadata{
+ Gallery: config.Gallery{Name: "zebra-repo"},
+ },
+ },
+ {
+ Metadata: Metadata{
+ Gallery: config.Gallery{Name: "alpha-repo"},
+ },
+ },
+ {
+ Metadata: Metadata{
+ Gallery: config.Gallery{Name: "beta-repo"},
+ },
+ },
+ }
+ })
+
+ It("should sort ascending", func() {
+ sorted := elements.SortByRepository("asc")
+ Expect(sorted).To(HaveLen(3))
+ Expect(sorted[0].GetGallery().Name).To(Equal("alpha-repo"))
+ Expect(sorted[1].GetGallery().Name).To(Equal("beta-repo"))
+ Expect(sorted[2].GetGallery().Name).To(Equal("zebra-repo"))
+ })
+
+ It("should sort descending", func() {
+ sorted := elements.SortByRepository("desc")
+ Expect(sorted).To(HaveLen(3))
+ Expect(sorted[0].GetGallery().Name).To(Equal("zebra-repo"))
+ Expect(sorted[1].GetGallery().Name).To(Equal("beta-repo"))
+ Expect(sorted[2].GetGallery().Name).To(Equal("alpha-repo"))
+ })
+ })
+
+ Describe("GalleryElements SortByLicense", func() {
+ var elements GalleryElements[*GalleryModel]
+
+ BeforeEach(func() {
+ elements = GalleryElements[*GalleryModel]{
+ {Metadata: Metadata{License: "MIT"}},
+ {Metadata: Metadata{License: "Apache-2.0"}},
+ {Metadata: Metadata{License: ""}},
+ {Metadata: Metadata{License: "GPL-3.0"}},
+ }
+ })
+
+ It("should sort ascending with empty licenses at end", func() {
+ sorted := elements.SortByLicense("asc")
+ Expect(sorted).To(HaveLen(4))
+ Expect(sorted[0].GetLicense()).To(Equal("Apache-2.0"))
+ Expect(sorted[1].GetLicense()).To(Equal("GPL-3.0"))
+ Expect(sorted[2].GetLicense()).To(Equal("MIT"))
+ Expect(sorted[3].GetLicense()).To(Equal(""))
+ })
+
+ It("should sort descending with empty licenses at beginning", func() {
+ sorted := elements.SortByLicense("desc")
+ Expect(sorted).To(HaveLen(4))
+ Expect(sorted[0].GetLicense()).To(Equal(""))
+ Expect(sorted[1].GetLicense()).To(Equal("MIT"))
+ Expect(sorted[2].GetLicense()).To(Equal("GPL-3.0"))
+ Expect(sorted[3].GetLicense()).To(Equal("Apache-2.0"))
+ })
+
+ It("should handle all empty licenses", func() {
+ elements = GalleryElements[*GalleryModel]{
+ {Metadata: Metadata{License: ""}},
+ {Metadata: Metadata{License: ""}},
+ }
+ sorted := elements.SortByLicense("asc")
+ Expect(sorted).To(HaveLen(2))
+ })
+ })
+
+ Describe("GalleryElements SortByInstalled", func() {
+ var elements GalleryElements[*GalleryModel]
+
+ BeforeEach(func() {
+ elements = GalleryElements[*GalleryModel]{
+ {Metadata: Metadata{Name: "installed-2", Installed: true}},
+ {Metadata: Metadata{Name: "not-installed-1", Installed: false}},
+ {Metadata: Metadata{Name: "installed-1", Installed: true}},
+ {Metadata: Metadata{Name: "not-installed-2", Installed: false}},
+ }
+ })
+
+ It("should sort ascending with installed first, then by name", func() {
+ sorted := elements.SortByInstalled("asc")
+ Expect(sorted).To(HaveLen(4))
+ Expect(sorted[0].GetInstalled()).To(BeTrue())
+ Expect(sorted[0].GetName()).To(Equal("installed-1"))
+ Expect(sorted[1].GetInstalled()).To(BeTrue())
+ Expect(sorted[1].GetName()).To(Equal("installed-2"))
+ Expect(sorted[2].GetInstalled()).To(BeFalse())
+ Expect(sorted[2].GetName()).To(Equal("not-installed-1"))
+ Expect(sorted[3].GetInstalled()).To(BeFalse())
+ Expect(sorted[3].GetName()).To(Equal("not-installed-2"))
+ })
+
+ It("should sort descending with not-installed first, then by name", func() {
+ sorted := elements.SortByInstalled("desc")
+ Expect(sorted).To(HaveLen(4))
+ Expect(sorted[0].GetInstalled()).To(BeFalse())
+ Expect(sorted[0].GetName()).To(Equal("not-installed-2"))
+ Expect(sorted[1].GetInstalled()).To(BeFalse())
+ Expect(sorted[1].GetName()).To(Equal("not-installed-1"))
+ Expect(sorted[2].GetInstalled()).To(BeTrue())
+ Expect(sorted[2].GetName()).To(Equal("installed-2"))
+ Expect(sorted[3].GetInstalled()).To(BeTrue())
+ Expect(sorted[3].GetName()).To(Equal("installed-1"))
+ })
+ })
+
+ Describe("GalleryElements FindByName", func() {
+ var elements GalleryElements[*GalleryModel]
+
+ BeforeEach(func() {
+ elements = GalleryElements[*GalleryModel]{
+ {Metadata: Metadata{Name: "bert-embeddings"}},
+ {Metadata: Metadata{Name: "gpt-2"}},
+ {Metadata: Metadata{Name: "llama-7b"}},
+ }
+ })
+
+ It("should find element by exact name", func() {
+ result := elements.FindByName("bert-embeddings")
+ Expect(result).NotTo(BeNil())
+ Expect(result.GetName()).To(Equal("bert-embeddings"))
+ })
+
+ It("should be case insensitive", func() {
+ result := elements.FindByName("BERT-EMBEDDINGS")
+ Expect(result).NotTo(BeNil())
+ Expect(result.GetName()).To(Equal("bert-embeddings"))
+ })
+
+ It("should return zero value when not found", func() {
+ result := elements.FindByName("nonexistent")
+ Expect(result).To(BeNil())
+ })
+ })
+
+ Describe("GalleryElements Paginate", func() {
+ var elements GalleryElements[*GalleryModel]
+
+ BeforeEach(func() {
+ elements = GalleryElements[*GalleryModel]{
+ {Metadata: Metadata{Name: "model-1"}},
+ {Metadata: Metadata{Name: "model-2"}},
+ {Metadata: Metadata{Name: "model-3"}},
+ {Metadata: Metadata{Name: "model-4"}},
+ {Metadata: Metadata{Name: "model-5"}},
+ }
+ })
+
+ It("should return first page", func() {
+ page := elements.Paginate(1, 2)
+ Expect(page).To(HaveLen(2))
+ Expect(page[0].GetName()).To(Equal("model-1"))
+ Expect(page[1].GetName()).To(Equal("model-2"))
+ })
+
+ It("should return second page", func() {
+ page := elements.Paginate(2, 2)
+ Expect(page).To(HaveLen(2))
+ Expect(page[0].GetName()).To(Equal("model-3"))
+ Expect(page[1].GetName()).To(Equal("model-4"))
+ })
+
+ It("should return partial last page", func() {
+ page := elements.Paginate(3, 2)
+ Expect(page).To(HaveLen(1))
+ Expect(page[0].GetName()).To(Equal("model-5"))
+ })
+
+ It("should handle page beyond range", func() {
+ page := elements.Paginate(10, 2)
+ Expect(page).To(HaveLen(0))
+ })
+
+ It("should handle empty elements", func() {
+ empty := GalleryElements[*GalleryModel]{}
+ page := empty.Paginate(1, 10)
+ Expect(page).To(HaveLen(0))
+ })
+ })
+
+ Describe("FindGalleryElement", func() {
+ var models []*GalleryModel
+
+ BeforeEach(func() {
+ models = []*GalleryModel{
+ {
+ Metadata: Metadata{
+ Name: "bert-embeddings",
+ Gallery: config.Gallery{
+ Name: "huggingface",
+ },
+ },
+ },
+ {
+ Metadata: Metadata{
+ Name: "gpt-2",
+ Gallery: config.Gallery{
+ Name: "openai",
+ },
+ },
+ },
+ }
+ })
+
+ It("should find element by name without @ notation", func() {
+ result := FindGalleryElement(models, "bert-embeddings")
+ Expect(result).NotTo(BeNil())
+ Expect(result.GetName()).To(Equal("bert-embeddings"))
+ })
+
+ It("should find element by name with @ notation", func() {
+ result := FindGalleryElement(models, "huggingface@bert-embeddings")
+ Expect(result).NotTo(BeNil())
+ Expect(result.GetName()).To(Equal("bert-embeddings"))
+ Expect(result.GetGallery().Name).To(Equal("huggingface"))
+ })
+
+ It("should be case insensitive", func() {
+ result := FindGalleryElement(models, "BERT-EMBEDDINGS")
+ Expect(result).NotTo(BeNil())
+ Expect(result.GetName()).To(Equal("bert-embeddings"))
+ })
+
+ It("should handle path separators in name", func() {
+ // Path separators are replaced with __, so bert/embeddings becomes bert__embeddings
+ // This test verifies the replacement happens, but won't match unless model name has __
+ modelsWithPath := []*GalleryModel{
+ {
+ Metadata: Metadata{
+ Name: "bert__embeddings",
+ Gallery: config.Gallery{
+ Name: "huggingface",
+ },
+ },
+ },
+ }
+ result := FindGalleryElement(modelsWithPath, "bert/embeddings")
+ Expect(result).NotTo(BeNil())
+ Expect(result.GetName()).To(Equal("bert__embeddings"))
+ })
+
+ It("should return zero value when not found", func() {
+ result := FindGalleryElement(models, "nonexistent")
+ Expect(result).To(BeNil())
+ })
+
+ It("should return zero value when gallery@name not found", func() {
+ result := FindGalleryElement(models, "nonexistent@model")
+ Expect(result).To(BeNil())
+ })
+ })
+})
diff --git a/core/gallery/importers/diffuser.go b/core/gallery/importers/diffuser.go
new file mode 100644
index 0000000000000000000000000000000000000000..c702da3d3025f58597ade2794babf94c9f74d994
--- /dev/null
+++ b/core/gallery/importers/diffuser.go
@@ -0,0 +1,121 @@
+package importers
+
+import (
+ "encoding/json"
+ "path/filepath"
+ "strings"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/schema"
+ "gopkg.in/yaml.v3"
+)
+
+var _ Importer = &DiffuserImporter{}
+
+type DiffuserImporter struct{}
+
+func (i *DiffuserImporter) Match(details Details) bool {
+ preferences, err := details.Preferences.MarshalJSON()
+ if err != nil {
+ return false
+ }
+ preferencesMap := make(map[string]any)
+ err = json.Unmarshal(preferences, &preferencesMap)
+ if err != nil {
+ return false
+ }
+
+ b, ok := preferencesMap["backend"].(string)
+ if ok && b == "diffusers" {
+ return true
+ }
+
+ if details.HuggingFace != nil {
+ for _, file := range details.HuggingFace.Files {
+ if strings.Contains(file.Path, "model_index.json") ||
+ strings.Contains(file.Path, "scheduler/scheduler_config.json") {
+ return true
+ }
+ }
+ }
+
+ return false
+}
+
+func (i *DiffuserImporter) Import(details Details) (gallery.ModelConfig, error) {
+ preferences, err := details.Preferences.MarshalJSON()
+ if err != nil {
+ return gallery.ModelConfig{}, err
+ }
+ preferencesMap := make(map[string]any)
+ err = json.Unmarshal(preferences, &preferencesMap)
+ if err != nil {
+ return gallery.ModelConfig{}, err
+ }
+
+ name, ok := preferencesMap["name"].(string)
+ if !ok {
+ name = filepath.Base(details.URI)
+ }
+
+ description, ok := preferencesMap["description"].(string)
+ if !ok {
+ description = "Imported from " + details.URI
+ }
+
+ backend := "diffusers"
+ b, ok := preferencesMap["backend"].(string)
+ if ok {
+ backend = b
+ }
+
+ pipelineType, ok := preferencesMap["pipeline_type"].(string)
+ if !ok {
+ pipelineType = "StableDiffusionPipeline"
+ }
+
+ schedulerType, ok := preferencesMap["scheduler_type"].(string)
+ if !ok {
+ schedulerType = ""
+ }
+
+ enableParameters, ok := preferencesMap["enable_parameters"].(string)
+ if !ok {
+ enableParameters = "negative_prompt,num_inference_steps"
+ }
+
+ cuda := false
+ if cudaVal, ok := preferencesMap["cuda"].(bool); ok {
+ cuda = cudaVal
+ }
+
+ modelConfig := config.ModelConfig{
+ Name: name,
+ Description: description,
+ KnownUsecaseStrings: []string{"image"},
+ Backend: backend,
+ PredictionOptions: schema.PredictionOptions{
+ BasicModelRequest: schema.BasicModelRequest{
+ Model: details.URI,
+ },
+ },
+ Diffusers: config.Diffusers{
+ PipelineType: pipelineType,
+ SchedulerType: schedulerType,
+ EnableParameters: enableParameters,
+ CUDA: cuda,
+ },
+ }
+
+ data, err := yaml.Marshal(modelConfig)
+ if err != nil {
+ return gallery.ModelConfig{}, err
+ }
+
+ return gallery.ModelConfig{
+ Name: name,
+ Description: description,
+ ConfigFile: string(data),
+ }, nil
+}
diff --git a/core/gallery/importers/diffuser_test.go b/core/gallery/importers/diffuser_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..38765e88bade9668535ce6782b4256d9dcdcb0ca
--- /dev/null
+++ b/core/gallery/importers/diffuser_test.go
@@ -0,0 +1,246 @@
+package importers_test
+
+import (
+ "encoding/json"
+
+ "github.com/mudler/LocalAI/core/gallery/importers"
+ . "github.com/mudler/LocalAI/core/gallery/importers"
+ hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("DiffuserImporter", func() {
+ var importer *DiffuserImporter
+
+ BeforeEach(func() {
+ importer = &DiffuserImporter{}
+ })
+
+ Context("Match", func() {
+ It("should match when backend preference is diffusers", func() {
+ preferences := json.RawMessage(`{"backend": "diffusers"}`)
+ details := Details{
+ URI: "https://example.com/model",
+ Preferences: preferences,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeTrue())
+ })
+
+ It("should match when HuggingFace details contain model_index.json", func() {
+ hfDetails := &hfapi.ModelDetails{
+ Files: []hfapi.ModelFile{
+ {Path: "model_index.json"},
+ },
+ }
+ details := Details{
+ URI: "https://huggingface.co/test/model",
+ HuggingFace: hfDetails,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeTrue())
+ })
+
+ It("should match when HuggingFace details contain scheduler config", func() {
+ hfDetails := &hfapi.ModelDetails{
+ Files: []hfapi.ModelFile{
+ {Path: "scheduler/scheduler_config.json"},
+ },
+ }
+ details := Details{
+ URI: "https://huggingface.co/test/model",
+ HuggingFace: hfDetails,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeTrue())
+ })
+
+ It("should not match when URI has no diffuser files and no backend preference", func() {
+ details := Details{
+ URI: "https://example.com/model.bin",
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeFalse())
+ })
+
+ It("should not match when backend preference is different", func() {
+ preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
+ details := Details{
+ URI: "https://example.com/model",
+ Preferences: preferences,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeFalse())
+ })
+
+ It("should return false when JSON preferences are invalid", func() {
+ preferences := json.RawMessage(`invalid json`)
+ details := Details{
+ URI: "https://example.com/model",
+ Preferences: preferences,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeFalse())
+ })
+ })
+
+ Context("Import", func() {
+ It("should import model config with default name and description", func() {
+ details := Details{
+ URI: "https://huggingface.co/test/my-diffuser-model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("my-diffuser-model"))
+ Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-diffuser-model"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: diffusers"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-diffuser-model"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("pipeline_type: StableDiffusionPipeline"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("enable_parameters: negative_prompt,num_inference_steps"))
+ })
+
+ It("should import model config with custom name and description from preferences", func() {
+ preferences := json.RawMessage(`{"name": "custom-diffuser", "description": "Custom diffuser model"}`)
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ Preferences: preferences,
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("custom-diffuser"))
+ Expect(modelConfig.Description).To(Equal("Custom diffuser model"))
+ })
+
+ It("should use custom pipeline_type from preferences", func() {
+ preferences := json.RawMessage(`{"pipeline_type": "StableDiffusion3Pipeline"}`)
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ Preferences: preferences,
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("pipeline_type: StableDiffusion3Pipeline"))
+ })
+
+ It("should use default pipeline_type when not specified", func() {
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("pipeline_type: StableDiffusionPipeline"))
+ })
+
+ It("should use custom scheduler_type from preferences", func() {
+ preferences := json.RawMessage(`{"scheduler_type": "k_dpmpp_2m"}`)
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ Preferences: preferences,
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("scheduler_type: k_dpmpp_2m"))
+ })
+
+ It("should use cuda setting from preferences", func() {
+ preferences := json.RawMessage(`{"cuda": true}`)
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ Preferences: preferences,
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("cuda: true"))
+ })
+
+ It("should use custom enable_parameters from preferences", func() {
+ preferences := json.RawMessage(`{"enable_parameters": "num_inference_steps,guidance_scale"}`)
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ Preferences: preferences,
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("enable_parameters: num_inference_steps,guidance_scale"))
+ })
+
+ It("should use custom backend from preferences", func() {
+ preferences := json.RawMessage(`{"backend": "diffusers"}`)
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ Preferences: preferences,
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: diffusers"))
+ })
+
+ It("should handle invalid JSON preferences", func() {
+ preferences := json.RawMessage(`invalid json`)
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ Preferences: preferences,
+ }
+
+ _, err := importer.Import(details)
+ Expect(err).To(HaveOccurred())
+ })
+
+ It("should extract filename correctly from URI with path", func() {
+ details := importers.Details{
+ URI: "https://huggingface.co/test/path/to/model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("model"))
+ })
+
+ It("should include known_usecases as image in config", func() {
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("- image"))
+ })
+
+ It("should include diffusers configuration in config", func() {
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("diffusers:"))
+ })
+ })
+})
diff --git a/core/gallery/importers/importers.go b/core/gallery/importers/importers.go
new file mode 100644
index 0000000000000000000000000000000000000000..a5fb96b68b3113118d6a47846a3652e382c10d32
--- /dev/null
+++ b/core/gallery/importers/importers.go
@@ -0,0 +1,121 @@
+package importers
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "strings"
+
+ "github.com/mudler/xlog"
+ "gopkg.in/yaml.v3"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/pkg/downloader"
+ hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
+)
+
+var defaultImporters = []Importer{
+ &LlamaCPPImporter{},
+ &MLXImporter{},
+ &VLLMImporter{},
+ &TransformersImporter{},
+ &DiffuserImporter{},
+}
+
+type Details struct {
+ HuggingFace *hfapi.ModelDetails
+ URI string
+ Preferences json.RawMessage
+}
+
+type Importer interface {
+ Match(details Details) bool
+ Import(details Details) (gallery.ModelConfig, error)
+}
+
+func hasYAMLExtension(uri string) bool {
+ return strings.HasSuffix(uri, ".yaml") || strings.HasSuffix(uri, ".yml")
+}
+
+func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.ModelConfig, error) {
+ var err error
+ var modelConfig gallery.ModelConfig
+
+ hf := hfapi.NewClient()
+
+ hfrepoID := strings.ReplaceAll(uri, "huggingface://", "")
+ hfrepoID = strings.ReplaceAll(hfrepoID, "hf://", "")
+ hfrepoID = strings.ReplaceAll(hfrepoID, "https://huggingface.co/", "")
+
+ hfDetails, err := hf.GetModelDetails(hfrepoID)
+ if err != nil {
+ // maybe not a HF repository
+ // TODO: maybe we can check if the URI is a valid HF repository
+ xlog.Debug("Failed to get model details, maybe not a HF repository", "uri", uri, "hfrepoID", hfrepoID)
+ } else {
+ xlog.Debug("Got model details", "uri", uri)
+ xlog.Debug("Model details", "details", hfDetails)
+ }
+
+ // handle local config files ("/my-model.yaml" or "file://my-model.yaml")
+ localURI := uri
+ if strings.HasPrefix(uri, downloader.LocalPrefix) {
+ localURI = strings.TrimPrefix(uri, downloader.LocalPrefix)
+ }
+
+ // if a file exists or it's an url that ends with .yaml or .yml, read the config file directly
+ if _, e := os.Stat(localURI); hasYAMLExtension(localURI) && (e == nil || downloader.URI(localURI).LooksLikeURL()) {
+ var modelYAML []byte
+ if downloader.URI(localURI).LooksLikeURL() {
+ err := downloader.URI(localURI).ReadWithCallback(localURI, func(url string, i []byte) error {
+ modelYAML = i
+ return nil
+ })
+ if err != nil {
+ xlog.Error("error reading model definition", "error", err, "filepath", localURI)
+ return gallery.ModelConfig{}, err
+ }
+ } else {
+ modelYAML, err = os.ReadFile(localURI)
+ if err != nil {
+ xlog.Error("error reading model definition", "error", err, "filepath", localURI)
+ return gallery.ModelConfig{}, err
+ }
+ }
+
+ var modelConfig config.ModelConfig
+ if e := yaml.Unmarshal(modelYAML, &modelConfig); e != nil {
+ return gallery.ModelConfig{}, e
+ }
+
+ configFile, err := yaml.Marshal(modelConfig)
+ return gallery.ModelConfig{
+ Description: modelConfig.Description,
+ Name: modelConfig.Name,
+ ConfigFile: string(configFile),
+ }, err
+ }
+
+ details := Details{
+ HuggingFace: hfDetails,
+ URI: uri,
+ Preferences: preferences,
+ }
+
+ importerMatched := false
+ for _, importer := range defaultImporters {
+ if importer.Match(details) {
+ importerMatched = true
+ modelConfig, err = importer.Import(details)
+ if err != nil {
+ continue
+ }
+ break
+ }
+ }
+ if !importerMatched {
+ return gallery.ModelConfig{}, fmt.Errorf("no importer matched for %s", uri)
+ }
+ return modelConfig, nil
+}
diff --git a/core/gallery/importers/importers_suite_test.go b/core/gallery/importers/importers_suite_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..a65b8163ad5619431c3e0a0248f7120b0ce16422
--- /dev/null
+++ b/core/gallery/importers/importers_suite_test.go
@@ -0,0 +1,13 @@
+package importers_test
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestImporters(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "Importers test suite")
+}
diff --git a/core/gallery/importers/importers_test.go b/core/gallery/importers/importers_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..c009e51daf2efea08305547aa929ca642720a394
--- /dev/null
+++ b/core/gallery/importers/importers_test.go
@@ -0,0 +1,352 @@
+package importers_test
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+
+ "github.com/mudler/LocalAI/core/gallery/importers"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("DiscoverModelConfig", func() {
+
+ Context("With only a repository URI", func() {
+ It("should discover and import using LlamaCPPImporter", func() {
+ uri := "https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"
+ preferences := json.RawMessage(`{}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
+ Expect(modelConfig.Name).To(Equal("LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF/resolve/main/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[0].SHA256).To(Equal("4e7b7fe1d54b881f1ef90799219dc6cc285d29db24f559c8998d1addb35713d4"), fmt.Sprintf("Model config: %+v", modelConfig))
+ })
+
+ It("should discover and import using LlamaCPPImporter", func() {
+ uri := "https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"
+ preferences := json.RawMessage(`{}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
+ Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[1].Filename).To(Equal("llama-cpp/mmproj/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
+ })
+
+ It("should discover and import using LlamaCPPImporter", func() {
+ uri := "https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"
+ preferences := json.RawMessage(`{ "quantizations": "Q8_0", "mmproj_quantizations": "f16" }`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
+ Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[1].Filename).To(Equal("llama-cpp/mmproj/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
+ })
+ })
+
+ Context("with .gguf URI", func() {
+ It("should discover and import using LlamaCPPImporter", func() {
+ uri := "https://example.com/my-model.gguf"
+ preferences := json.RawMessage(`{}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("my-model.gguf"))
+ Expect(modelConfig.Description).To(Equal("Imported from https://example.com/my-model.gguf"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
+ })
+
+ It("should use custom preferences when provided", func() {
+ uri := "https://example.com/my-model.gguf"
+ preferences := json.RawMessage(`{"name": "custom-name", "description": "Custom description"}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("custom-name"))
+ Expect(modelConfig.Description).To(Equal("Custom description"))
+ })
+ })
+
+ Context("with mlx-community URI", func() {
+ It("should discover and import using MLXImporter", func() {
+ uri := "https://huggingface.co/mlx-community/test-model"
+ preferences := json.RawMessage(`{}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("test-model"))
+ Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mlx-community/test-model"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
+ })
+
+ It("should use custom preferences when provided", func() {
+ uri := "https://huggingface.co/mlx-community/test-model"
+ preferences := json.RawMessage(`{"name": "custom-mlx", "description": "Custom MLX description"}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("custom-mlx"))
+ Expect(modelConfig.Description).To(Equal("Custom MLX description"))
+ })
+ })
+
+ Context("with backend preference", func() {
+ It("should use llama-cpp backend when specified", func() {
+ uri := "https://example.com/model"
+ preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
+ })
+
+ It("should use mlx backend when specified", func() {
+ uri := "https://example.com/model"
+ preferences := json.RawMessage(`{"backend": "mlx"}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
+ })
+
+ It("should use mlx-vlm backend when specified", func() {
+ uri := "https://example.com/model"
+ preferences := json.RawMessage(`{"backend": "mlx-vlm"}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx-vlm"))
+ })
+ })
+
+ Context("with HuggingFace URI formats", func() {
+ It("should handle huggingface:// prefix", func() {
+ uri := "huggingface://mlx-community/test-model"
+ preferences := json.RawMessage(`{}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("test-model"))
+ })
+
+ It("should handle hf:// prefix", func() {
+ uri := "hf://mlx-community/test-model"
+ preferences := json.RawMessage(`{}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("test-model"))
+ })
+
+ It("should handle https://huggingface.co/ prefix", func() {
+ uri := "https://huggingface.co/mlx-community/test-model"
+ preferences := json.RawMessage(`{}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("test-model"))
+ })
+ })
+
+ Context("with invalid or non-matching URI", func() {
+ It("should return error when no importer matches", func() {
+ uri := "https://example.com/unknown-model.bin"
+ preferences := json.RawMessage(`{}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ // When no importer matches, the function returns empty config and error
+ // The exact behavior depends on implementation, but typically an error is returned
+ Expect(modelConfig.Name).To(BeEmpty())
+ Expect(err).To(HaveOccurred())
+ })
+ })
+
+ Context("with invalid JSON preferences", func() {
+ It("should return error when JSON is invalid even if URI matches", func() {
+ uri := "https://example.com/model.gguf"
+ preferences := json.RawMessage(`invalid json`)
+
+ // Even though Match() returns true for .gguf extension,
+ // Import() will fail when trying to unmarshal invalid JSON preferences
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).To(HaveOccurred())
+ Expect(modelConfig.Name).To(BeEmpty())
+ })
+ })
+
+ Context("with local YAML config files", func() {
+ var tempDir string
+
+ BeforeEach(func() {
+ var err error
+ tempDir, err = os.MkdirTemp("", "importers-test-*")
+ Expect(err).ToNot(HaveOccurred())
+ })
+
+ AfterEach(func() {
+ os.RemoveAll(tempDir)
+ })
+
+ It("should read local YAML file with file:// prefix", func() {
+ yamlContent := `name: test-model
+backend: llama-cpp
+description: Test model from local YAML
+parameters:
+ model: /path/to/model.gguf
+ temperature: 0.7
+`
+ yamlFile := filepath.Join(tempDir, "test-model.yaml")
+ err := os.WriteFile(yamlFile, []byte(yamlContent), 0644)
+ Expect(err).ToNot(HaveOccurred())
+
+ uri := "file://" + yamlFile
+ preferences := json.RawMessage(`{}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("test-model"))
+ Expect(modelConfig.Description).To(Equal("Test model from local YAML"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("name: test-model"))
+ })
+
+ It("should read local YAML file without file:// prefix (direct path)", func() {
+ yamlContent := `name: direct-path-model
+backend: mlx
+description: Test model from direct path
+parameters:
+ model: /path/to/model.safetensors
+`
+ yamlFile := filepath.Join(tempDir, "direct-model.yaml")
+ err := os.WriteFile(yamlFile, []byte(yamlContent), 0644)
+ Expect(err).ToNot(HaveOccurred())
+
+ uri := yamlFile
+ preferences := json.RawMessage(`{}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("direct-path-model"))
+ Expect(modelConfig.Description).To(Equal("Test model from direct path"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
+ })
+
+ It("should read local YAML file with .yml extension", func() {
+ yamlContent := `name: yml-extension-model
+backend: transformers
+description: Test model with .yml extension
+parameters:
+ model: /path/to/model
+`
+ yamlFile := filepath.Join(tempDir, "test-model.yml")
+ err := os.WriteFile(yamlFile, []byte(yamlContent), 0644)
+ Expect(err).ToNot(HaveOccurred())
+
+ uri := "file://" + yamlFile
+ preferences := json.RawMessage(`{}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("yml-extension-model"))
+ Expect(modelConfig.Description).To(Equal("Test model with .yml extension"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
+ })
+
+ It("should ignore preferences when reading YAML files directly", func() {
+ yamlContent := `name: yaml-model
+backend: llama-cpp
+description: Original description
+parameters:
+ model: /path/to/model.gguf
+`
+ yamlFile := filepath.Join(tempDir, "prefs-test.yaml")
+ err := os.WriteFile(yamlFile, []byte(yamlContent), 0644)
+ Expect(err).ToNot(HaveOccurred())
+
+ uri := "file://" + yamlFile
+ // Preferences should be ignored when reading YAML directly
+ preferences := json.RawMessage(`{"name": "custom-name", "description": "Custom description", "backend": "mlx"}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).ToNot(HaveOccurred())
+ // Should use values from YAML file, not preferences
+ Expect(modelConfig.Name).To(Equal("yaml-model"))
+ Expect(modelConfig.Description).To(Equal("Original description"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
+ })
+
+ It("should return error when local YAML file doesn't exist", func() {
+ nonExistentFile := filepath.Join(tempDir, "nonexistent.yaml")
+ uri := "file://" + nonExistentFile
+ preferences := json.RawMessage(`{}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).To(HaveOccurred())
+ Expect(modelConfig.Name).To(BeEmpty())
+ })
+
+ It("should return error when YAML file is invalid/malformed", func() {
+ invalidYaml := `name: invalid-model
+backend: llama-cpp
+invalid: yaml: content: [unclosed bracket
+`
+ yamlFile := filepath.Join(tempDir, "invalid.yaml")
+ err := os.WriteFile(yamlFile, []byte(invalidYaml), 0644)
+ Expect(err).ToNot(HaveOccurred())
+
+ uri := "file://" + yamlFile
+ preferences := json.RawMessage(`{}`)
+
+ modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
+
+ Expect(err).To(HaveOccurred())
+ Expect(modelConfig.Name).To(BeEmpty())
+ })
+ })
+})
diff --git a/core/gallery/importers/llama-cpp.go b/core/gallery/importers/llama-cpp.go
new file mode 100644
index 0000000000000000000000000000000000000000..ae9ec042d7b83977493c21098f97892246f0070d
--- /dev/null
+++ b/core/gallery/importers/llama-cpp.go
@@ -0,0 +1,260 @@
+package importers
+
+import (
+ "encoding/json"
+ "path/filepath"
+ "slices"
+ "strings"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/downloader"
+ "github.com/mudler/LocalAI/pkg/functions"
+ "github.com/mudler/xlog"
+ "go.yaml.in/yaml/v2"
+)
+
+var _ Importer = &LlamaCPPImporter{}
+
+type LlamaCPPImporter struct{}
+
+func (i *LlamaCPPImporter) Match(details Details) bool {
+ preferences, err := details.Preferences.MarshalJSON()
+ if err != nil {
+ xlog.Error("failed to marshal preferences", "error", err)
+ return false
+ }
+
+ preferencesMap := make(map[string]any)
+
+ if len(preferences) > 0 {
+ err = json.Unmarshal(preferences, &preferencesMap)
+ if err != nil {
+ xlog.Error("failed to unmarshal preferences", "error", err)
+ return false
+ }
+ }
+
+ uri := downloader.URI(details.URI)
+
+ if preferencesMap["backend"] == "llama-cpp" {
+ return true
+ }
+
+ if strings.HasSuffix(details.URI, ".gguf") {
+ return true
+ }
+
+ if uri.LooksLikeOCI() {
+ return true
+ }
+
+ if details.HuggingFace != nil {
+ for _, file := range details.HuggingFace.Files {
+ if strings.HasSuffix(file.Path, ".gguf") {
+ return true
+ }
+ }
+ }
+
+ return false
+}
+
+func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) {
+
+ xlog.Debug("llama.cpp importer matched", "uri", details.URI)
+
+ preferences, err := details.Preferences.MarshalJSON()
+ if err != nil {
+ return gallery.ModelConfig{}, err
+ }
+ preferencesMap := make(map[string]any)
+ if len(preferences) > 0 {
+ err = json.Unmarshal(preferences, &preferencesMap)
+ if err != nil {
+ return gallery.ModelConfig{}, err
+ }
+ }
+
+ name, ok := preferencesMap["name"].(string)
+ if !ok {
+ name = filepath.Base(details.URI)
+ }
+
+ description, ok := preferencesMap["description"].(string)
+ if !ok {
+ description = "Imported from " + details.URI
+ }
+
+ preferedQuantizations, _ := preferencesMap["quantizations"].(string)
+ quants := []string{"q4_k_m"}
+ if preferedQuantizations != "" {
+ quants = strings.Split(preferedQuantizations, ",")
+ }
+
+ mmprojQuants, _ := preferencesMap["mmproj_quantizations"].(string)
+ mmprojQuantsList := []string{"fp16"}
+ if mmprojQuants != "" {
+ mmprojQuantsList = strings.Split(mmprojQuants, ",")
+ }
+
+ embeddings, _ := preferencesMap["embeddings"].(string)
+
+ modelConfig := config.ModelConfig{
+ Name: name,
+ Description: description,
+ KnownUsecaseStrings: []string{"chat"},
+ Options: []string{"use_jinja:true"},
+ Backend: "llama-cpp",
+ TemplateConfig: config.TemplateConfig{
+ UseTokenizerTemplate: true,
+ },
+ FunctionsConfig: functions.FunctionsConfig{
+ GrammarConfig: functions.GrammarConfig{
+ NoGrammar: true,
+ },
+ },
+ }
+
+ if embeddings != "" && strings.ToLower(embeddings) == "true" || strings.ToLower(embeddings) == "yes" {
+ trueV := true
+ modelConfig.Embeddings = &trueV
+ }
+
+ cfg := gallery.ModelConfig{
+ Name: name,
+ Description: description,
+ }
+
+ uri := downloader.URI(details.URI)
+
+ switch {
+ case uri.LooksLikeOCI():
+ ociName := strings.TrimPrefix(string(uri), downloader.OCIPrefix)
+ ociName = strings.TrimPrefix(ociName, downloader.OllamaPrefix)
+ ociName = strings.ReplaceAll(ociName, "/", "__")
+ ociName = strings.ReplaceAll(ociName, ":", "__")
+ cfg.Files = append(cfg.Files, gallery.File{
+ URI: details.URI,
+ Filename: ociName,
+ })
+ modelConfig.PredictionOptions = schema.PredictionOptions{
+ BasicModelRequest: schema.BasicModelRequest{
+ Model: ociName,
+ },
+ }
+ case uri.LooksLikeURL() && strings.HasSuffix(details.URI, ".gguf"):
+ // Extract filename from URL
+ fileName, e := uri.FilenameFromUrl()
+ if e != nil {
+ return gallery.ModelConfig{}, e
+ }
+
+ cfg.Files = append(cfg.Files, gallery.File{
+ URI: details.URI,
+ Filename: fileName,
+ })
+ modelConfig.PredictionOptions = schema.PredictionOptions{
+ BasicModelRequest: schema.BasicModelRequest{
+ Model: fileName,
+ },
+ }
+ case strings.HasSuffix(details.URI, ".gguf"):
+ cfg.Files = append(cfg.Files, gallery.File{
+ URI: details.URI,
+ Filename: filepath.Base(details.URI),
+ })
+ modelConfig.PredictionOptions = schema.PredictionOptions{
+ BasicModelRequest: schema.BasicModelRequest{
+ Model: filepath.Base(details.URI),
+ },
+ }
+ case details.HuggingFace != nil:
+ // We want to:
+ // Get first the chosen quants that match filenames
+ // OR the first mmproj/gguf file found
+ var lastMMProjFile *gallery.File
+ var lastGGUFFile *gallery.File
+ foundPreferedQuant := false
+ foundPreferedMMprojQuant := false
+
+ for _, file := range details.HuggingFace.Files {
+ // Get the mmproj prefered quants
+ if strings.Contains(strings.ToLower(file.Path), "mmproj") {
+ lastMMProjFile = &gallery.File{
+ URI: file.URL,
+ Filename: filepath.Join("llama-cpp", "mmproj", filepath.Base(file.Path)),
+ SHA256: file.SHA256,
+ }
+ if slices.ContainsFunc(mmprojQuantsList, func(quant string) bool {
+ return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant))
+ }) {
+ cfg.Files = append(cfg.Files, *lastMMProjFile)
+ foundPreferedMMprojQuant = true
+ }
+ } else if strings.HasSuffix(strings.ToLower(file.Path), "gguf") {
+ lastGGUFFile = &gallery.File{
+ URI: file.URL,
+ Filename: filepath.Join("llama-cpp", "models", filepath.Base(file.Path)),
+ SHA256: file.SHA256,
+ }
+ // get the files of the prefered quants
+ if slices.ContainsFunc(quants, func(quant string) bool {
+ return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant))
+ }) {
+ foundPreferedQuant = true
+ cfg.Files = append(cfg.Files, *lastGGUFFile)
+ }
+ }
+ }
+
+ // Make sure to add at least one file if not already present (which is the latest one)
+ if lastMMProjFile != nil && !foundPreferedMMprojQuant {
+ if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool {
+ return f.Filename == lastMMProjFile.Filename
+ }) {
+ cfg.Files = append(cfg.Files, *lastMMProjFile)
+ }
+ }
+
+ if lastGGUFFile != nil && !foundPreferedQuant {
+ if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool {
+ return f.Filename == lastGGUFFile.Filename
+ }) {
+ cfg.Files = append(cfg.Files, *lastGGUFFile)
+ }
+ }
+
+ // Find first mmproj file and configure it in the config file
+ for _, file := range cfg.Files {
+ if !strings.Contains(strings.ToLower(file.Filename), "mmproj") {
+ continue
+ }
+ modelConfig.MMProj = file.Filename
+ break
+ }
+
+ // Find first non-mmproj file and configure it in the config file
+ for _, file := range cfg.Files {
+ if strings.Contains(strings.ToLower(file.Filename), "mmproj") {
+ continue
+ }
+ modelConfig.PredictionOptions = schema.PredictionOptions{
+ BasicModelRequest: schema.BasicModelRequest{
+ Model: file.Filename,
+ },
+ }
+ break
+ }
+ }
+
+ data, err := yaml.Marshal(modelConfig)
+ if err != nil {
+ return gallery.ModelConfig{}, err
+ }
+
+ cfg.ConfigFile = string(data)
+
+ return cfg, nil
+}
diff --git a/core/gallery/importers/llama-cpp_test.go b/core/gallery/importers/llama-cpp_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..a9fe17335c1df3a1375f2adc4f0dedb74f0d39ba
--- /dev/null
+++ b/core/gallery/importers/llama-cpp_test.go
@@ -0,0 +1,132 @@
+package importers_test
+
+import (
+ "encoding/json"
+ "fmt"
+
+ "github.com/mudler/LocalAI/core/gallery/importers"
+ . "github.com/mudler/LocalAI/core/gallery/importers"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("LlamaCPPImporter", func() {
+ var importer *LlamaCPPImporter
+
+ BeforeEach(func() {
+ importer = &LlamaCPPImporter{}
+ })
+
+ Context("Match", func() {
+ It("should match when URI ends with .gguf", func() {
+ details := Details{
+ URI: "https://example.com/model.gguf",
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeTrue())
+ })
+
+ It("should match when backend preference is llama-cpp", func() {
+ preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
+ details := Details{
+ URI: "https://example.com/model",
+ Preferences: preferences,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeTrue())
+ })
+
+ It("should not match when URI does not end with .gguf and no backend preference", func() {
+ details := Details{
+ URI: "https://example.com/model.bin",
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeFalse())
+ })
+
+ It("should not match when backend preference is different", func() {
+ preferences := json.RawMessage(`{"backend": "mlx"}`)
+ details := Details{
+ URI: "https://example.com/model",
+ Preferences: preferences,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeFalse())
+ })
+
+ It("should return false when JSON preferences are invalid", func() {
+ preferences := json.RawMessage(`invalid json`)
+ details := Details{
+ URI: "https://example.com/model.gguf",
+ Preferences: preferences,
+ }
+
+ // Invalid JSON causes Match to return false early
+ result := importer.Match(details)
+ Expect(result).To(BeFalse())
+ })
+ })
+
+ Context("Import", func() {
+ It("should import model config with default name and description", func() {
+ details := Details{
+ URI: "https://example.com/my-model.gguf",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("my-model.gguf"))
+ Expect(modelConfig.Description).To(Equal("Imported from https://example.com/my-model.gguf"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
+ Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ })
+
+ It("should import model config with custom name and description from preferences", func() {
+ preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
+ details := Details{
+ URI: "https://example.com/my-model.gguf",
+ Preferences: preferences,
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("custom-model"))
+ Expect(modelConfig.Description).To(Equal("Custom description"))
+ Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ })
+
+ It("should handle invalid JSON preferences", func() {
+ preferences := json.RawMessage(`invalid json`)
+ details := Details{
+ URI: "https://example.com/my-model.gguf",
+ Preferences: preferences,
+ }
+
+ _, err := importer.Import(details)
+ Expect(err).To(HaveOccurred())
+ })
+
+ It("should extract filename correctly from URI with path", func() {
+ details := importers.Details{
+ URI: "https://example.com/path/to/model.gguf",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/path/to/model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ Expect(modelConfig.Files[0].Filename).To(Equal("model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
+ })
+ })
+})
diff --git a/core/gallery/importers/mlx.go b/core/gallery/importers/mlx.go
new file mode 100644
index 0000000000000000000000000000000000000000..faa28846f4ea25ffb47a42b18f8b7838f0bf7d23
--- /dev/null
+++ b/core/gallery/importers/mlx.go
@@ -0,0 +1,94 @@
+package importers
+
+import (
+ "encoding/json"
+ "path/filepath"
+ "strings"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/schema"
+ "go.yaml.in/yaml/v2"
+)
+
+var _ Importer = &MLXImporter{}
+
+type MLXImporter struct{}
+
+func (i *MLXImporter) Match(details Details) bool {
+ preferences, err := details.Preferences.MarshalJSON()
+ if err != nil {
+ return false
+ }
+ preferencesMap := make(map[string]any)
+ err = json.Unmarshal(preferences, &preferencesMap)
+ if err != nil {
+ return false
+ }
+
+ b, ok := preferencesMap["backend"].(string)
+ if ok && b == "mlx" || b == "mlx-vlm" {
+ return true
+ }
+
+ // All https://huggingface.co/mlx-community/*
+ if strings.Contains(details.URI, "mlx-community/") {
+ return true
+ }
+
+ return false
+}
+
+func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) {
+ preferences, err := details.Preferences.MarshalJSON()
+ if err != nil {
+ return gallery.ModelConfig{}, err
+ }
+ preferencesMap := make(map[string]any)
+ err = json.Unmarshal(preferences, &preferencesMap)
+ if err != nil {
+ return gallery.ModelConfig{}, err
+ }
+
+ name, ok := preferencesMap["name"].(string)
+ if !ok {
+ name = filepath.Base(details.URI)
+ }
+
+ description, ok := preferencesMap["description"].(string)
+ if !ok {
+ description = "Imported from " + details.URI
+ }
+
+ backend := "mlx"
+ b, ok := preferencesMap["backend"].(string)
+ if ok {
+ backend = b
+ }
+
+ modelConfig := config.ModelConfig{
+ Name: name,
+ Description: description,
+ KnownUsecaseStrings: []string{"chat"},
+ Backend: backend,
+ PredictionOptions: schema.PredictionOptions{
+ BasicModelRequest: schema.BasicModelRequest{
+ Model: details.URI,
+ },
+ },
+ TemplateConfig: config.TemplateConfig{
+ UseTokenizerTemplate: true,
+ },
+ }
+
+ data, err := yaml.Marshal(modelConfig)
+ if err != nil {
+ return gallery.ModelConfig{}, err
+ }
+
+ return gallery.ModelConfig{
+ Name: name,
+ Description: description,
+ ConfigFile: string(data),
+ }, nil
+}
diff --git a/core/gallery/importers/mlx_test.go b/core/gallery/importers/mlx_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..82e02aff0b4466834a3f921d7517b19f5db6cf62
--- /dev/null
+++ b/core/gallery/importers/mlx_test.go
@@ -0,0 +1,147 @@
+package importers_test
+
+import (
+ "encoding/json"
+
+ "github.com/mudler/LocalAI/core/gallery/importers"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("MLXImporter", func() {
+ var importer *importers.MLXImporter
+
+ BeforeEach(func() {
+ importer = &importers.MLXImporter{}
+ })
+
+ Context("Match", func() {
+ It("should match when URI contains mlx-community/", func() {
+ details := importers.Details{
+ URI: "https://huggingface.co/mlx-community/test-model",
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeTrue())
+ })
+
+ It("should match when backend preference is mlx", func() {
+ preferences := json.RawMessage(`{"backend": "mlx"}`)
+ details := importers.Details{
+ URI: "https://example.com/model",
+ Preferences: preferences,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeTrue())
+ })
+
+ It("should match when backend preference is mlx-vlm", func() {
+ preferences := json.RawMessage(`{"backend": "mlx-vlm"}`)
+ details := importers.Details{
+ URI: "https://example.com/model",
+ Preferences: preferences,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeTrue())
+ })
+
+ It("should not match when URI does not contain mlx-community/ and no backend preference", func() {
+ details := importers.Details{
+ URI: "https://huggingface.co/other-org/test-model",
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeFalse())
+ })
+
+ It("should not match when backend preference is different", func() {
+ preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
+ details := importers.Details{
+ URI: "https://example.com/model",
+ Preferences: preferences,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeFalse())
+ })
+
+ It("should return false when JSON preferences are invalid", func() {
+ preferences := json.RawMessage(`invalid json`)
+ details := importers.Details{
+ URI: "https://huggingface.co/mlx-community/test-model",
+ Preferences: preferences,
+ }
+
+ // Invalid JSON causes Match to return false early
+ result := importer.Match(details)
+ Expect(result).To(BeFalse())
+ })
+ })
+
+ Context("Import", func() {
+ It("should import model config with default name and description", func() {
+ details := importers.Details{
+ URI: "https://huggingface.co/mlx-community/test-model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("test-model"))
+ Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mlx-community/test-model"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/mlx-community/test-model"))
+ })
+
+ It("should import model config with custom name and description from preferences", func() {
+ preferences := json.RawMessage(`{"name": "custom-mlx-model", "description": "Custom MLX description"}`)
+ details := importers.Details{
+ URI: "https://huggingface.co/mlx-community/test-model",
+ Preferences: preferences,
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("custom-mlx-model"))
+ Expect(modelConfig.Description).To(Equal("Custom MLX description"))
+ })
+
+ It("should use custom backend from preferences", func() {
+ preferences := json.RawMessage(`{"backend": "mlx-vlm"}`)
+ details := importers.Details{
+ URI: "https://huggingface.co/mlx-community/test-model",
+ Preferences: preferences,
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx-vlm"))
+ })
+
+ It("should handle invalid JSON preferences", func() {
+ preferences := json.RawMessage(`invalid json`)
+ details := importers.Details{
+ URI: "https://huggingface.co/mlx-community/test-model",
+ Preferences: preferences,
+ }
+
+ _, err := importer.Import(details)
+ Expect(err).To(HaveOccurred())
+ })
+
+ It("should extract filename correctly from URI with path", func() {
+ details := importers.Details{
+ URI: "https://huggingface.co/mlx-community/path/to/model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("model"))
+ })
+ })
+})
diff --git a/core/gallery/importers/transformers.go b/core/gallery/importers/transformers.go
new file mode 100644
index 0000000000000000000000000000000000000000..cd09c366d8ac3d607ef513acf7d3eae0fe60a0b0
--- /dev/null
+++ b/core/gallery/importers/transformers.go
@@ -0,0 +1,110 @@
+package importers
+
+import (
+ "encoding/json"
+ "path/filepath"
+ "strings"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/schema"
+ "go.yaml.in/yaml/v2"
+)
+
+var _ Importer = &TransformersImporter{}
+
+type TransformersImporter struct{}
+
+func (i *TransformersImporter) Match(details Details) bool {
+ preferences, err := details.Preferences.MarshalJSON()
+ if err != nil {
+ return false
+ }
+ preferencesMap := make(map[string]any)
+ err = json.Unmarshal(preferences, &preferencesMap)
+ if err != nil {
+ return false
+ }
+
+ b, ok := preferencesMap["backend"].(string)
+ if ok && b == "transformers" {
+ return true
+ }
+
+ if details.HuggingFace != nil {
+ for _, file := range details.HuggingFace.Files {
+ if strings.Contains(file.Path, "tokenizer.json") ||
+ strings.Contains(file.Path, "tokenizer_config.json") {
+ return true
+ }
+ }
+ }
+
+ return false
+}
+
+func (i *TransformersImporter) Import(details Details) (gallery.ModelConfig, error) {
+ preferences, err := details.Preferences.MarshalJSON()
+ if err != nil {
+ return gallery.ModelConfig{}, err
+ }
+ preferencesMap := make(map[string]any)
+ err = json.Unmarshal(preferences, &preferencesMap)
+ if err != nil {
+ return gallery.ModelConfig{}, err
+ }
+
+ name, ok := preferencesMap["name"].(string)
+ if !ok {
+ name = filepath.Base(details.URI)
+ }
+
+ description, ok := preferencesMap["description"].(string)
+ if !ok {
+ description = "Imported from " + details.URI
+ }
+
+ backend := "transformers"
+ b, ok := preferencesMap["backend"].(string)
+ if ok {
+ backend = b
+ }
+
+ modelType, ok := preferencesMap["type"].(string)
+ if !ok {
+ modelType = "AutoModelForCausalLM"
+ }
+
+ quantization, ok := preferencesMap["quantization"].(string)
+ if !ok {
+ quantization = ""
+ }
+
+ modelConfig := config.ModelConfig{
+ Name: name,
+ Description: description,
+ KnownUsecaseStrings: []string{"chat"},
+ Backend: backend,
+ PredictionOptions: schema.PredictionOptions{
+ BasicModelRequest: schema.BasicModelRequest{
+ Model: details.URI,
+ },
+ },
+ TemplateConfig: config.TemplateConfig{
+ UseTokenizerTemplate: true,
+ },
+ }
+ modelConfig.ModelType = modelType
+ modelConfig.Quantization = quantization
+
+ data, err := yaml.Marshal(modelConfig)
+ if err != nil {
+ return gallery.ModelConfig{}, err
+ }
+
+ return gallery.ModelConfig{
+ Name: name,
+ Description: description,
+ ConfigFile: string(data),
+ }, nil
+}
diff --git a/core/gallery/importers/transformers_test.go b/core/gallery/importers/transformers_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..a909e75c1726d50974d1ddb236da6178bf29afe1
--- /dev/null
+++ b/core/gallery/importers/transformers_test.go
@@ -0,0 +1,219 @@
+package importers_test
+
+import (
+ "encoding/json"
+
+ "github.com/mudler/LocalAI/core/gallery/importers"
+ . "github.com/mudler/LocalAI/core/gallery/importers"
+ hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("TransformersImporter", func() {
+ var importer *TransformersImporter
+
+ BeforeEach(func() {
+ importer = &TransformersImporter{}
+ })
+
+ Context("Match", func() {
+ It("should match when backend preference is transformers", func() {
+ preferences := json.RawMessage(`{"backend": "transformers"}`)
+ details := Details{
+ URI: "https://example.com/model",
+ Preferences: preferences,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeTrue())
+ })
+
+ It("should match when HuggingFace details contain tokenizer.json", func() {
+ hfDetails := &hfapi.ModelDetails{
+ Files: []hfapi.ModelFile{
+ {Path: "tokenizer.json"},
+ },
+ }
+ details := Details{
+ URI: "https://huggingface.co/test/model",
+ HuggingFace: hfDetails,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeTrue())
+ })
+
+ It("should match when HuggingFace details contain tokenizer_config.json", func() {
+ hfDetails := &hfapi.ModelDetails{
+ Files: []hfapi.ModelFile{
+ {Path: "tokenizer_config.json"},
+ },
+ }
+ details := Details{
+ URI: "https://huggingface.co/test/model",
+ HuggingFace: hfDetails,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeTrue())
+ })
+
+ It("should not match when URI has no tokenizer files and no backend preference", func() {
+ details := Details{
+ URI: "https://example.com/model.bin",
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeFalse())
+ })
+
+ It("should not match when backend preference is different", func() {
+ preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
+ details := Details{
+ URI: "https://example.com/model",
+ Preferences: preferences,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeFalse())
+ })
+
+ It("should return false when JSON preferences are invalid", func() {
+ preferences := json.RawMessage(`invalid json`)
+ details := Details{
+ URI: "https://example.com/model",
+ Preferences: preferences,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeFalse())
+ })
+ })
+
+ Context("Import", func() {
+ It("should import model config with default name and description", func() {
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("my-model"))
+ Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM"))
+ })
+
+ It("should import model config with custom name and description from preferences", func() {
+ preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ Preferences: preferences,
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("custom-model"))
+ Expect(modelConfig.Description).To(Equal("Custom description"))
+ })
+
+ It("should use custom model type from preferences", func() {
+ preferences := json.RawMessage(`{"type": "SentenceTransformer"}`)
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ Preferences: preferences,
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("type: SentenceTransformer"))
+ })
+
+ It("should use default model type when not specified", func() {
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM"))
+ })
+
+ It("should use custom backend from preferences", func() {
+ preferences := json.RawMessage(`{"backend": "transformers"}`)
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ Preferences: preferences,
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
+ })
+
+ It("should use quantization from preferences", func() {
+ preferences := json.RawMessage(`{"quantization": "int8"}`)
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ Preferences: preferences,
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("quantization: int8"))
+ })
+
+ It("should handle invalid JSON preferences", func() {
+ preferences := json.RawMessage(`invalid json`)
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ Preferences: preferences,
+ }
+
+ _, err := importer.Import(details)
+ Expect(err).To(HaveOccurred())
+ })
+
+ It("should extract filename correctly from URI with path", func() {
+ details := importers.Details{
+ URI: "https://huggingface.co/test/path/to/model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("model"))
+ })
+
+ It("should include use_tokenizer_template in config", func() {
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("use_tokenizer_template: true"))
+ })
+
+ It("should include known_usecases in config", func() {
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat"))
+ })
+ })
+})
diff --git a/core/gallery/importers/vllm.go b/core/gallery/importers/vllm.go
new file mode 100644
index 0000000000000000000000000000000000000000..be544662a5185857491f969d3cba7ddf2252fafa
--- /dev/null
+++ b/core/gallery/importers/vllm.go
@@ -0,0 +1,98 @@
+package importers
+
+import (
+ "encoding/json"
+ "path/filepath"
+ "strings"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/schema"
+ "go.yaml.in/yaml/v2"
+)
+
+var _ Importer = &VLLMImporter{}
+
+type VLLMImporter struct{}
+
+func (i *VLLMImporter) Match(details Details) bool {
+ preferences, err := details.Preferences.MarshalJSON()
+ if err != nil {
+ return false
+ }
+ preferencesMap := make(map[string]any)
+ err = json.Unmarshal(preferences, &preferencesMap)
+ if err != nil {
+ return false
+ }
+
+ b, ok := preferencesMap["backend"].(string)
+ if ok && b == "vllm" {
+ return true
+ }
+
+ if details.HuggingFace != nil {
+ for _, file := range details.HuggingFace.Files {
+ if strings.Contains(file.Path, "tokenizer.json") ||
+ strings.Contains(file.Path, "tokenizer_config.json") {
+ return true
+ }
+ }
+ }
+
+ return false
+}
+
+func (i *VLLMImporter) Import(details Details) (gallery.ModelConfig, error) {
+ preferences, err := details.Preferences.MarshalJSON()
+ if err != nil {
+ return gallery.ModelConfig{}, err
+ }
+ preferencesMap := make(map[string]any)
+ err = json.Unmarshal(preferences, &preferencesMap)
+ if err != nil {
+ return gallery.ModelConfig{}, err
+ }
+
+ name, ok := preferencesMap["name"].(string)
+ if !ok {
+ name = filepath.Base(details.URI)
+ }
+
+ description, ok := preferencesMap["description"].(string)
+ if !ok {
+ description = "Imported from " + details.URI
+ }
+
+ backend := "vllm"
+ b, ok := preferencesMap["backend"].(string)
+ if ok {
+ backend = b
+ }
+
+ modelConfig := config.ModelConfig{
+ Name: name,
+ Description: description,
+ KnownUsecaseStrings: []string{"chat"},
+ Backend: backend,
+ PredictionOptions: schema.PredictionOptions{
+ BasicModelRequest: schema.BasicModelRequest{
+ Model: details.URI,
+ },
+ },
+ TemplateConfig: config.TemplateConfig{
+ UseTokenizerTemplate: true,
+ },
+ }
+
+ data, err := yaml.Marshal(modelConfig)
+ if err != nil {
+ return gallery.ModelConfig{}, err
+ }
+
+ return gallery.ModelConfig{
+ Name: name,
+ Description: description,
+ ConfigFile: string(data),
+ }, nil
+}
diff --git a/core/gallery/importers/vllm_test.go b/core/gallery/importers/vllm_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..b6eb5c953968f1dded84769dbf7c8714fb869d0c
--- /dev/null
+++ b/core/gallery/importers/vllm_test.go
@@ -0,0 +1,181 @@
+package importers_test
+
+import (
+ "encoding/json"
+
+ "github.com/mudler/LocalAI/core/gallery/importers"
+ . "github.com/mudler/LocalAI/core/gallery/importers"
+ hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("VLLMImporter", func() {
+ var importer *VLLMImporter
+
+ BeforeEach(func() {
+ importer = &VLLMImporter{}
+ })
+
+ Context("Match", func() {
+ It("should match when backend preference is vllm", func() {
+ preferences := json.RawMessage(`{"backend": "vllm"}`)
+ details := Details{
+ URI: "https://example.com/model",
+ Preferences: preferences,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeTrue())
+ })
+
+ It("should match when HuggingFace details contain tokenizer.json", func() {
+ hfDetails := &hfapi.ModelDetails{
+ Files: []hfapi.ModelFile{
+ {Path: "tokenizer.json"},
+ },
+ }
+ details := Details{
+ URI: "https://huggingface.co/test/model",
+ HuggingFace: hfDetails,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeTrue())
+ })
+
+ It("should match when HuggingFace details contain tokenizer_config.json", func() {
+ hfDetails := &hfapi.ModelDetails{
+ Files: []hfapi.ModelFile{
+ {Path: "tokenizer_config.json"},
+ },
+ }
+ details := Details{
+ URI: "https://huggingface.co/test/model",
+ HuggingFace: hfDetails,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeTrue())
+ })
+
+ It("should not match when URI has no tokenizer files and no backend preference", func() {
+ details := Details{
+ URI: "https://example.com/model.bin",
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeFalse())
+ })
+
+ It("should not match when backend preference is different", func() {
+ preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
+ details := Details{
+ URI: "https://example.com/model",
+ Preferences: preferences,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeFalse())
+ })
+
+ It("should return false when JSON preferences are invalid", func() {
+ preferences := json.RawMessage(`invalid json`)
+ details := Details{
+ URI: "https://example.com/model",
+ Preferences: preferences,
+ }
+
+ result := importer.Match(details)
+ Expect(result).To(BeFalse())
+ })
+ })
+
+ Context("Import", func() {
+ It("should import model config with default name and description", func() {
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("my-model"))
+ Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vllm"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model"))
+ })
+
+ It("should import model config with custom name and description from preferences", func() {
+ preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ Preferences: preferences,
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("custom-model"))
+ Expect(modelConfig.Description).To(Equal("Custom description"))
+ })
+
+ It("should use custom backend from preferences", func() {
+ preferences := json.RawMessage(`{"backend": "vllm"}`)
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ Preferences: preferences,
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vllm"))
+ })
+
+ It("should handle invalid JSON preferences", func() {
+ preferences := json.RawMessage(`invalid json`)
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ Preferences: preferences,
+ }
+
+ _, err := importer.Import(details)
+ Expect(err).To(HaveOccurred())
+ })
+
+ It("should extract filename correctly from URI with path", func() {
+ details := importers.Details{
+ URI: "https://huggingface.co/test/path/to/model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.Name).To(Equal("model"))
+ })
+
+ It("should include use_tokenizer_template in config", func() {
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("use_tokenizer_template: true"))
+ })
+
+ It("should include known_usecases in config", func() {
+ details := Details{
+ URI: "https://huggingface.co/test/my-model",
+ }
+
+ modelConfig, err := importer.Import(details)
+
+ Expect(err).ToNot(HaveOccurred())
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
+ Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat"))
+ })
+ })
+})
diff --git a/core/gallery/metadata_type.go b/core/gallery/metadata_type.go
new file mode 100644
index 0000000000000000000000000000000000000000..f0059eab628d5e5faaceb6ece087fcf1f3472f68
--- /dev/null
+++ b/core/gallery/metadata_type.go
@@ -0,0 +1,19 @@
+package gallery
+
+import "github.com/mudler/LocalAI/core/config"
+
+type Metadata struct {
+ URL string `json:"url,omitempty" yaml:"url,omitempty"`
+ Name string `json:"name,omitempty" yaml:"name,omitempty"`
+ Description string `json:"description,omitempty" yaml:"description,omitempty"`
+ License string `json:"license,omitempty" yaml:"license,omitempty"`
+ URLs []string `json:"urls,omitempty" yaml:"urls,omitempty"`
+ Icon string `json:"icon,omitempty" yaml:"icon,omitempty"`
+ Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"`
+ // AdditionalFiles are used to add additional files to the model
+ AdditionalFiles []File `json:"files,omitempty" yaml:"files,omitempty"`
+ // Gallery is a reference to the gallery which contains the model
+ Gallery config.Gallery `json:"gallery,omitempty" yaml:"gallery,omitempty"`
+ // Installed is used to indicate if the model is installed or not
+ Installed bool `json:"installed,omitempty" yaml:"installed,omitempty"`
+}
diff --git a/core/gallery/models.go b/core/gallery/models.go
new file mode 100644
index 0000000000000000000000000000000000000000..133d0d0e63b893cd0e47d184d33871bbddf9b5a0
--- /dev/null
+++ b/core/gallery/models.go
@@ -0,0 +1,448 @@
+package gallery
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "slices"
+ "strings"
+
+ "dario.cat/mergo"
+ lconfig "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/pkg/downloader"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/system"
+ "github.com/mudler/LocalAI/pkg/utils"
+
+ "github.com/mudler/xlog"
+ "gopkg.in/yaml.v3"
+)
+
+/*
+
+description: |
+ foo
+license: ""
+
+urls:
+-
+-
+
+name: "bar"
+
+config_file: |
+ # Note, name will be injected. or generated by the alias wanted by the user
+ threads: 14
+
+files:
+ - filename: ""
+ sha: ""
+ uri: ""
+
+prompt_templates:
+ - name: ""
+ content: ""
+
+*/
+// ModelConfig is the model configuration which contains all the model details
+// This configuration is read from the gallery endpoint and is used to download and install the model
+// It is the internal structure, separated from the request
+type ModelConfig struct {
+ Description string `yaml:"description"`
+ Icon string `yaml:"icon"`
+ License string `yaml:"license"`
+ URLs []string `yaml:"urls"`
+ Name string `yaml:"name"`
+ ConfigFile string `yaml:"config_file"`
+ Files []File `yaml:"files"`
+ PromptTemplates []PromptTemplate `yaml:"prompt_templates"`
+}
+
+type File struct {
+ Filename string `yaml:"filename" json:"filename"`
+ SHA256 string `yaml:"sha256" json:"sha256"`
+ URI string `yaml:"uri" json:"uri"`
+}
+
+type PromptTemplate struct {
+ Name string `yaml:"name"`
+ Content string `yaml:"content"`
+}
+
+// Installs a model from the gallery
+func InstallModelFromGallery(
+ ctx context.Context,
+ modelGalleries, backendGalleries []lconfig.Gallery,
+ systemState *system.SystemState,
+ modelLoader *model.ModelLoader,
+ name string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool) error {
+
+ applyModel := func(model *GalleryModel) error {
+ name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
+
+ var config ModelConfig
+
+ if len(model.URL) > 0 {
+ var err error
+ config, err = GetGalleryConfigFromURLWithContext[ModelConfig](ctx, model.URL, systemState.Model.ModelsPath)
+ if err != nil {
+ return err
+ }
+ config.Description = model.Description
+ config.License = model.License
+ } else if len(model.ConfigFile) > 0 {
+ // TODO: is this worse than using the override method with a blank cfg yaml?
+ reYamlConfig, err := yaml.Marshal(model.ConfigFile)
+ if err != nil {
+ return err
+ }
+ config = ModelConfig{
+ ConfigFile: string(reYamlConfig),
+ Description: model.Description,
+ License: model.License,
+ URLs: model.URLs,
+ Name: model.Name,
+ Files: make([]File, 0), // Real values get added below, must be blank
+ // Prompt Template Skipped for now - I expect in this mode that they will be delivered as files.
+ }
+ } else {
+ return fmt.Errorf("invalid gallery model %+v", model)
+ }
+
+ installName := model.Name
+ if req.Name != "" {
+ installName = req.Name
+ }
+
+ // Copy the model configuration from the request schema
+ config.URLs = append(config.URLs, model.URLs...)
+ config.Icon = model.Icon
+ config.Files = append(config.Files, req.AdditionalFiles...)
+ config.Files = append(config.Files, model.AdditionalFiles...)
+
+ // TODO model.Overrides could be merged with user overrides (not defined yet)
+ if req.Overrides != nil {
+ if err := mergo.Merge(&model.Overrides, req.Overrides, mergo.WithOverride); err != nil {
+ return err
+ }
+ }
+
+ installedModel, err := InstallModel(ctx, systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
+ if err != nil {
+ return err
+ }
+ xlog.Debug("Installed model", "model", installedModel.Name)
+ if automaticallyInstallBackend && installedModel.Backend != "" {
+ xlog.Debug("Installing backend", "backend", installedModel.Backend)
+
+ if err := InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ }
+
+ models, err := AvailableGalleryModels(modelGalleries, systemState)
+ if err != nil {
+ return err
+ }
+
+ model := FindGalleryElement(models, name)
+ if model == nil {
+ return fmt.Errorf("no model found with name %q", name)
+ }
+
+ return applyModel(model)
+}
+
+func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
+ basePath := systemState.Model.ModelsPath
+ // Create base path if it doesn't exist
+ err := os.MkdirAll(basePath, 0750)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create base path: %v", err)
+ }
+
+ if len(configOverrides) > 0 {
+ xlog.Debug("Config overrides", "overrides", configOverrides)
+ }
+
+ // Download files and verify their SHA
+ for i, file := range config.Files {
+ // Check for cancellation before each file
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ }
+
+ xlog.Debug("Checking file exists and matches SHA", "filename", file.Filename)
+
+ if err := utils.VerifyPath(file.Filename, basePath); err != nil {
+ return nil, err
+ }
+
+ // Create file path
+ filePath := filepath.Join(basePath, file.Filename)
+
+ if enforceScan {
+ scanResults, err := downloader.HuggingFaceScan(downloader.URI(file.URI))
+ if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) {
+ xlog.Error("Contains unsafe file(s)!", "model", config.Name, "clamAV", scanResults.ClamAVInfectedFiles, "pickles", scanResults.DangerousPickles)
+ return nil, err
+ }
+ }
+ uri := downloader.URI(file.URI)
+ if err := uri.DownloadFileWithContext(ctx, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
+ return nil, err
+ }
+ }
+
+ // Write prompt template contents to separate files
+ for _, template := range config.PromptTemplates {
+ if err := utils.VerifyPath(template.Name+".tmpl", basePath); err != nil {
+ return nil, err
+ }
+ // Create file path
+ filePath := filepath.Join(basePath, template.Name+".tmpl")
+
+ // Create parent directory
+ err := os.MkdirAll(filepath.Dir(filePath), 0750)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create parent directory for prompt template %q: %v", template.Name, err)
+ }
+ // Create and write file content
+ err = os.WriteFile(filePath, []byte(template.Content), 0600)
+ if err != nil {
+ return nil, fmt.Errorf("failed to write prompt template %q: %v", template.Name, err)
+ }
+
+ xlog.Debug("Prompt template written", "template", template.Name)
+ }
+
+ name := config.Name
+ if nameOverride != "" {
+ name = nameOverride
+ }
+
+ if err := utils.VerifyPath(name+".yaml", basePath); err != nil {
+ return nil, err
+ }
+
+ modelConfig := lconfig.ModelConfig{}
+
+ // write config file
+ if len(configOverrides) != 0 || len(config.ConfigFile) != 0 {
+ configFilePath := filepath.Join(basePath, name+".yaml")
+
+ // Read and update config file as map[string]interface{}
+ configMap := make(map[string]interface{})
+ err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal config YAML: %v", err)
+ }
+
+ configMap["name"] = name
+
+ if configOverrides != nil {
+ if err := mergo.Merge(&configMap, configOverrides, mergo.WithOverride); err != nil {
+ return nil, err
+ }
+ }
+
+ // Write updated config file
+ updatedConfigYAML, err := yaml.Marshal(configMap)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal updated config YAML: %v", err)
+ }
+
+ err = yaml.Unmarshal(updatedConfigYAML, &modelConfig)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal updated config YAML: %v", err)
+ }
+
+ if valid, err := modelConfig.Validate(); !valid {
+ return nil, fmt.Errorf("failed to validate updated config YAML: %v", err)
+ }
+
+ err = os.WriteFile(configFilePath, updatedConfigYAML, 0600)
+ if err != nil {
+ return nil, fmt.Errorf("failed to write updated config file: %v", err)
+ }
+
+ xlog.Debug("Written config file", "file", configFilePath)
+ }
+
+ // Save the model gallery file for further reference
+ modelFile := filepath.Join(basePath, galleryFileName(name))
+ data, err := yaml.Marshal(config)
+ if err != nil {
+ return nil, err
+ }
+
+ xlog.Debug("Written gallery file", "file", modelFile)
+
+ return &modelConfig, os.WriteFile(modelFile, data, 0600)
+}
+
+func galleryFileName(name string) string {
+ return "._gallery_" + name + ".yaml"
+}
+
+func GetLocalModelConfiguration(basePath string, name string) (*ModelConfig, error) {
+ name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
+ galleryFile := filepath.Join(basePath, galleryFileName(name))
+ return ReadConfigFile[ModelConfig](galleryFile)
+}
+
+func listModelFiles(systemState *system.SystemState, name string) ([]string, error) {
+
+ configFile := filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", name))
+ if err := utils.VerifyPath(configFile, systemState.Model.ModelsPath); err != nil {
+ return nil, fmt.Errorf("failed to verify path %s: %w", configFile, err)
+ }
+
+ // os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
+ name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
+
+ galleryFile := filepath.Join(systemState.Model.ModelsPath, galleryFileName(name))
+ if err := utils.VerifyPath(galleryFile, systemState.Model.ModelsPath); err != nil {
+ return nil, fmt.Errorf("failed to verify path %s: %w", galleryFile, err)
+ }
+
+ additionalFiles := []string{}
+ allFiles := []string{}
+
+ // Galleryname is the name of the model in this case
+ dat, err := os.ReadFile(configFile)
+ if err == nil {
+ modelConfig := &lconfig.ModelConfig{}
+
+ err = yaml.Unmarshal(dat, &modelConfig)
+ if err != nil {
+ return nil, err
+ }
+ if modelConfig.Model != "" {
+ additionalFiles = append(additionalFiles, modelConfig.ModelFileName())
+ }
+
+ if modelConfig.MMProj != "" {
+ additionalFiles = append(additionalFiles, modelConfig.MMProjFileName())
+ }
+ }
+
+ // read the model config
+ galleryconfig, err := ReadConfigFile[ModelConfig](galleryFile)
+ if err == nil && galleryconfig != nil {
+ for _, f := range galleryconfig.Files {
+ fullPath := filepath.Join(systemState.Model.ModelsPath, f.Filename)
+ if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil {
+ return allFiles, fmt.Errorf("failed to verify path %s: %w", fullPath, err)
+ }
+ allFiles = append(allFiles, fullPath)
+ }
+ } else {
+ xlog.Error("failed to read gallery file", "error", err, "file", configFile)
+ }
+
+ for _, f := range additionalFiles {
+ fullPath := filepath.Join(filepath.Join(systemState.Model.ModelsPath, f))
+ if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil {
+ return allFiles, fmt.Errorf("failed to verify path %s: %w", fullPath, err)
+ }
+ allFiles = append(allFiles, fullPath)
+ }
+
+ allFiles = append(allFiles, galleryFile)
+
+ // skip duplicates
+ allFiles = utils.Unique(allFiles)
+
+ return allFiles, nil
+}
+
+func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
+ configFile := filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", name))
+
+ filesToRemove, err := listModelFiles(systemState, name)
+ if err != nil {
+ return err
+ }
+
+ allOtherFiles := []string{}
+ // Get all files of all other models
+ fi, err := os.ReadDir(systemState.Model.ModelsPath)
+ if err != nil {
+ return err
+ }
+ for _, f := range fi {
+ if f.IsDir() {
+ continue
+ }
+ if strings.HasPrefix(f.Name(), "._gallery_") {
+ continue
+ }
+ if !strings.HasSuffix(f.Name(), ".yaml") && !strings.HasSuffix(f.Name(), ".yml") {
+ continue
+ }
+ if f.Name() == fmt.Sprintf("%s.yaml", name) || f.Name() == fmt.Sprintf("%s.yml", name) {
+ continue
+ }
+
+ name := strings.TrimSuffix(f.Name(), ".yaml")
+ name = strings.TrimSuffix(name, ".yml")
+
+ xlog.Debug("Checking file", "file", f.Name())
+ files, err := listModelFiles(systemState, name)
+ if err != nil {
+ xlog.Debug("failed to list files for model", "error", err, "model", f.Name())
+ continue
+ }
+ allOtherFiles = append(allOtherFiles, files...)
+ }
+
+ xlog.Debug("Files to remove", "files", filesToRemove)
+ xlog.Debug("All other files", "files", allOtherFiles)
+
+ // Removing files
+ for _, f := range filesToRemove {
+ if slices.Contains(allOtherFiles, f) {
+ xlog.Debug("Skipping file because it is part of another model", "file", f)
+ continue
+ }
+ if e := os.Remove(f); e != nil {
+ xlog.Error("failed to remove file", "error", e, "file", f)
+ }
+ }
+
+ return os.Remove(configFile)
+}
+
+// This is ***NEVER*** going to be perfect or finished.
+// This is a BEST EFFORT function to surface known-vulnerable models to users.
+func SafetyScanGalleryModels(galleries []lconfig.Gallery, systemState *system.SystemState) error {
+ galleryModels, err := AvailableGalleryModels(galleries, systemState)
+ if err != nil {
+ return err
+ }
+ for _, gM := range galleryModels {
+ if gM.Installed {
+ err = errors.Join(err, SafetyScanGalleryModel(gM))
+ }
+ }
+ return err
+}
+
+func SafetyScanGalleryModel(galleryModel *GalleryModel) error {
+ for _, file := range galleryModel.AdditionalFiles {
+ scanResults, err := downloader.HuggingFaceScan(downloader.URI(file.URI))
+ if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) {
+ xlog.Error("Contains unsafe file(s)!", "model", galleryModel.Name, "clamAV", scanResults.ClamAVInfectedFiles, "pickles", scanResults.DangerousPickles)
+ return err
+ }
+ }
+ return nil
+}
diff --git a/core/gallery/models_test.go b/core/gallery/models_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..c672435996b2796f2b98a46cdd41a0dfdc4c3818
--- /dev/null
+++ b/core/gallery/models_test.go
@@ -0,0 +1,300 @@
+package gallery_test
+
+import (
+ "context"
+ "errors"
+ "os"
+ "path/filepath"
+
+ "github.com/mudler/LocalAI/core/config"
+ . "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/pkg/system"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+ "gopkg.in/yaml.v3"
+)
+
+const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b87640e8644b09c2aee6e3b/raw/f0e8c26bb72edc16d9fbafbfd6638072126ff225/bert-embeddings-gallery.yaml`
+
+var _ = Describe("Model test", func() {
+
+ BeforeEach(func() {
+ if os.Getenv("FIXTURES") == "" {
+ Skip("FIXTURES env var not set, skipping model tests")
+ }
+ })
+
+ Context("Downloading", func() {
+ It("applies model correctly", func() {
+ tempdir, err := os.MkdirTemp("", "test")
+ Expect(err).ToNot(HaveOccurred())
+ defer os.RemoveAll(tempdir)
+ c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+ systemState, err := system.GetSystemState(
+ system.WithModelPath(tempdir),
+ )
+ Expect(err).ToNot(HaveOccurred())
+ _, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
+ Expect(err).ToNot(HaveOccurred())
+
+ for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
+ _, err = os.Stat(filepath.Join(tempdir, f))
+ Expect(err).ToNot(HaveOccurred())
+ }
+
+ content := map[string]interface{}{}
+
+ dat, err := os.ReadFile(filepath.Join(tempdir, "cerebras.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ err = yaml.Unmarshal(dat, content)
+ Expect(err).ToNot(HaveOccurred())
+
+ Expect(content["context_size"]).To(Equal(1024))
+ })
+
+ It("applies model from gallery correctly", func() {
+ tempdir, err := os.MkdirTemp("", "test")
+ Expect(err).ToNot(HaveOccurred())
+ defer os.RemoveAll(tempdir)
+
+ gallery := []GalleryModel{{
+ Metadata: Metadata{
+ Name: "bert",
+ URL: bertEmbeddingsURL,
+ },
+ }}
+ out, err := yaml.Marshal(gallery)
+ Expect(err).ToNot(HaveOccurred())
+ galleryFilePath := filepath.Join(tempdir, "gallery_simple.yaml")
+ err = os.WriteFile(galleryFilePath, out, 0600)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(filepath.IsAbs(galleryFilePath)).To(BeTrue(), galleryFilePath)
+ galleries := []config.Gallery{
+ {
+ Name: "test",
+ URL: "file://" + galleryFilePath,
+ },
+ }
+ systemState, err := system.GetSystemState(
+ system.WithModelPath(tempdir),
+ )
+ Expect(err).ToNot(HaveOccurred())
+
+ models, err := AvailableGalleryModels(galleries, systemState)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(models)).To(Equal(1))
+ Expect(models[0].Name).To(Equal("bert"))
+ Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
+ Expect(models[0].Installed).To(BeFalse())
+
+ err = InstallModelFromGallery(context.TODO(), galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
+ Expect(err).ToNot(HaveOccurred())
+
+ dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ content := map[string]interface{}{}
+ err = yaml.Unmarshal(dat, &content)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
+
+ models, err = AvailableGalleryModels(galleries, systemState)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(models)).To(Equal(1))
+ Expect(models[0].Installed).To(BeTrue())
+
+ // delete
+ err = DeleteModelFromSystem(systemState, "bert")
+ Expect(err).ToNot(HaveOccurred())
+
+ models, err = AvailableGalleryModels(galleries, systemState)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(models)).To(Equal(1))
+ Expect(models[0].Installed).To(BeFalse())
+
+ _, err = os.Stat(filepath.Join(tempdir, "bert.yaml"))
+ Expect(err).To(HaveOccurred())
+ Expect(errors.Is(err, os.ErrNotExist)).To(BeTrue())
+ })
+
+ It("renames model correctly", func() {
+ tempdir, err := os.MkdirTemp("", "test")
+ Expect(err).ToNot(HaveOccurred())
+ defer os.RemoveAll(tempdir)
+ c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ systemState, err := system.GetSystemState(
+ system.WithModelPath(tempdir),
+ )
+ Expect(err).ToNot(HaveOccurred())
+ _, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
+ Expect(err).ToNot(HaveOccurred())
+
+ for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
+ _, err = os.Stat(filepath.Join(tempdir, f))
+ Expect(err).ToNot(HaveOccurred())
+ }
+ })
+
+ It("overrides parameters", func() {
+ tempdir, err := os.MkdirTemp("", "test")
+ Expect(err).ToNot(HaveOccurred())
+ defer os.RemoveAll(tempdir)
+ c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ systemState, err := system.GetSystemState(
+ system.WithModelPath(tempdir),
+ )
+ Expect(err).ToNot(HaveOccurred())
+ _, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
+ Expect(err).ToNot(HaveOccurred())
+
+ for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
+ _, err = os.Stat(filepath.Join(tempdir, f))
+ Expect(err).ToNot(HaveOccurred())
+ }
+
+ content := map[string]interface{}{}
+
+ dat, err := os.ReadFile(filepath.Join(tempdir, "foo.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ err = yaml.Unmarshal(dat, content)
+ Expect(err).ToNot(HaveOccurred())
+
+ Expect(content["backend"]).To(Equal("foo"))
+ })
+
+ It("catches path traversals", func() {
+ tempdir, err := os.MkdirTemp("", "test")
+ Expect(err).ToNot(HaveOccurred())
+ defer os.RemoveAll(tempdir)
+ c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ systemState, err := system.GetSystemState(
+ system.WithModelPath(tempdir),
+ )
+ Expect(err).ToNot(HaveOccurred())
+ _, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
+ Expect(err).To(HaveOccurred())
+ })
+
+ It("handles nil configOverrides without panic", func() {
+ tempdir, err := os.MkdirTemp("", "test")
+ Expect(err).ToNot(HaveOccurred())
+ defer os.RemoveAll(tempdir)
+ c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ systemState, err := system.GetSystemState(
+ system.WithModelPath(tempdir),
+ )
+ Expect(err).ToNot(HaveOccurred())
+ _, err = InstallModel(context.TODO(), systemState, "test-model", c, nil, func(string, string, string, float64) {}, true)
+ Expect(err).ToNot(HaveOccurred())
+
+ for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "test-model.yaml"} {
+ _, err = os.Stat(filepath.Join(tempdir, f))
+ Expect(err).ToNot(HaveOccurred())
+ }
+ })
+
+ It("does not delete shared model files when one config is deleted", func() {
+ tempdir, err := os.MkdirTemp("", "test")
+ Expect(err).ToNot(HaveOccurred())
+ defer os.RemoveAll(tempdir)
+
+ systemState, err := system.GetSystemState(
+ system.WithModelPath(tempdir),
+ )
+ Expect(err).ToNot(HaveOccurred())
+
+ // Create a shared model file
+ sharedModelFile := filepath.Join(tempdir, "shared_model.bin")
+ err = os.WriteFile(sharedModelFile, []byte("fake model content"), 0600)
+ Expect(err).ToNot(HaveOccurred())
+
+ // Create first model configuration
+ config1 := `name: model1
+model: shared_model.bin`
+ err = os.WriteFile(filepath.Join(tempdir, "model1.yaml"), []byte(config1), 0600)
+ Expect(err).ToNot(HaveOccurred())
+
+ // Create first model's gallery file
+ galleryConfig1 := ModelConfig{
+ Name: "model1",
+ Files: []File{
+ {Filename: "shared_model.bin"},
+ },
+ }
+ galleryData1, err := yaml.Marshal(galleryConfig1)
+ Expect(err).ToNot(HaveOccurred())
+ err = os.WriteFile(filepath.Join(tempdir, "._gallery_model1.yaml"), galleryData1, 0600)
+ Expect(err).ToNot(HaveOccurred())
+
+ // Create second model configuration sharing the same model file
+ config2 := `name: model2
+model: shared_model.bin`
+ err = os.WriteFile(filepath.Join(tempdir, "model2.yaml"), []byte(config2), 0600)
+ Expect(err).ToNot(HaveOccurred())
+
+ // Create second model's gallery file
+ galleryConfig2 := ModelConfig{
+ Name: "model2",
+ Files: []File{
+ {Filename: "shared_model.bin"},
+ },
+ }
+ galleryData2, err := yaml.Marshal(galleryConfig2)
+ Expect(err).ToNot(HaveOccurred())
+ err = os.WriteFile(filepath.Join(tempdir, "._gallery_model2.yaml"), galleryData2, 0600)
+ Expect(err).ToNot(HaveOccurred())
+
+ // Verify both configurations exist
+ _, err = os.Stat(filepath.Join(tempdir, "model1.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+ _, err = os.Stat(filepath.Join(tempdir, "model2.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ // Verify the shared model file exists
+ _, err = os.Stat(sharedModelFile)
+ Expect(err).ToNot(HaveOccurred())
+
+ // Delete the first model
+ err = DeleteModelFromSystem(systemState, "model1")
+ Expect(err).ToNot(HaveOccurred())
+
+ // Verify the first configuration is deleted
+ _, err = os.Stat(filepath.Join(tempdir, "model1.yaml"))
+ Expect(err).To(HaveOccurred())
+ Expect(errors.Is(err, os.ErrNotExist)).To(BeTrue())
+
+ // Verify the shared model file still exists (not deleted because model2 still uses it)
+ _, err = os.Stat(sharedModelFile)
+ Expect(err).ToNot(HaveOccurred(), "shared model file should not be deleted when used by other configs")
+
+ // Verify the second configuration still exists
+ _, err = os.Stat(filepath.Join(tempdir, "model2.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ // Now delete the second model
+ err = DeleteModelFromSystem(systemState, "model2")
+ Expect(err).ToNot(HaveOccurred())
+
+ // Verify the second configuration is deleted
+ _, err = os.Stat(filepath.Join(tempdir, "model2.yaml"))
+ Expect(err).To(HaveOccurred())
+ Expect(errors.Is(err, os.ErrNotExist)).To(BeTrue())
+
+ // Verify the shared model file is now deleted (no more references)
+ _, err = os.Stat(sharedModelFile)
+ Expect(err).To(HaveOccurred(), "shared model file should be deleted when no configs reference it")
+ Expect(errors.Is(err, os.ErrNotExist)).To(BeTrue())
+ })
+ })
+})
diff --git a/core/gallery/models_types.go b/core/gallery/models_types.go
new file mode 100644
index 0000000000000000000000000000000000000000..000aa2b266d44082b3f808494afeb481d3f69d76
--- /dev/null
+++ b/core/gallery/models_types.go
@@ -0,0 +1,54 @@
+package gallery
+
+import (
+ "fmt"
+
+ "github.com/mudler/LocalAI/core/config"
+)
+
+// GalleryModel is the struct used to represent a model in the gallery returned by the endpoint.
+// It is used to install the model by resolving the URL and downloading the files.
+// The other fields are used to override the configuration of the model.
+type GalleryModel struct {
+ Metadata `json:",inline" yaml:",inline"`
+ // config_file is read in the situation where URL is blank - and therefore this is a base config.
+ ConfigFile map[string]interface{} `json:"config_file,omitempty" yaml:"config_file,omitempty"`
+ // Overrides are used to override the configuration of the model located at URL
+ Overrides map[string]interface{} `json:"overrides,omitempty" yaml:"overrides,omitempty"`
+}
+
+func (m *GalleryModel) GetInstalled() bool {
+ return m.Installed
+}
+
+func (m *GalleryModel) GetLicense() string {
+ return m.License
+}
+
+func (m *GalleryModel) SetGallery(gallery config.Gallery) {
+ m.Gallery = gallery
+}
+
+func (m *GalleryModel) SetInstalled(installed bool) {
+ m.Installed = installed
+}
+
+func (m *GalleryModel) GetName() string {
+ return m.Name
+}
+
+func (m *GalleryModel) GetGallery() config.Gallery {
+ return m.Gallery
+}
+
+func (m GalleryModel) ID() string {
+ return fmt.Sprintf("%s@%s", m.Gallery.Name, m.Name)
+}
+
+func (m *GalleryModel) GetTags() []string {
+ return m.Tags
+}
+
+func (m *GalleryModel) GetDescription() string {
+ return m.Description
+}
diff --git a/core/gallery/request_test.go b/core/gallery/request_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..fb1b20d163cf3b906f4d6debf7bc09144404352f
--- /dev/null
+++ b/core/gallery/request_test.go
@@ -0,0 +1,22 @@
+package gallery_test
+
+import (
+ . "github.com/mudler/LocalAI/core/gallery"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("Gallery API tests", func() {
+ Context("requests", func() {
+ It("parses github with a branch", func() {
+ req := GalleryModel{
+ Metadata: Metadata{
+ URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main",
+ },
+ }
+ e, err := GetGalleryConfigFromURL[ModelConfig](req.URL, "")
+ Expect(err).ToNot(HaveOccurred())
+ Expect(e.Name).To(Equal("gpt4all-j"))
+ })
+ })
+})
diff --git a/core/http/app.go b/core/http/app.go
new file mode 100644
index 0000000000000000000000000000000000000000..328a9d8e9a18a394e6c04bc6557dee3c7cd33424
--- /dev/null
+++ b/core/http/app.go
@@ -0,0 +1,223 @@
+package http
+
+import (
+ "embed"
+ "errors"
+ "fmt"
+ "io/fs"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v4/middleware"
+
+ "github.com/mudler/LocalAI/core/http/endpoints/localai"
+ httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/http/routes"
+
+ "github.com/mudler/LocalAI/core/application"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/core/services"
+
+ "github.com/mudler/xlog"
+)
+
+// Embed a directory
+//
+//go:embed static/*
+var embedDirStatic embed.FS
+
+// @title LocalAI API
+// @version 2.0.0
+// @description The LocalAI Rest API.
+// @termsOfService
+// @contact.name LocalAI
+// @contact.url https://localai.io
+// @license.name MIT
+// @license.url https://raw.githubusercontent.com/mudler/LocalAI/master/LICENSE
+// @BasePath /
+// @securityDefinitions.apikey BearerAuth
+// @in header
+// @name Authorization
+
+func API(application *application.Application) (*echo.Echo, error) {
+ e := echo.New()
+
+ // Set body limit
+ if application.ApplicationConfig().UploadLimitMB > 0 {
+ e.Use(middleware.BodyLimit(fmt.Sprintf("%dM", application.ApplicationConfig().UploadLimitMB)))
+ }
+
+ // Set error handler
+ if !application.ApplicationConfig().OpaqueErrors {
+ e.HTTPErrorHandler = func(err error, c echo.Context) {
+ code := http.StatusInternalServerError
+ var he *echo.HTTPError
+ if errors.As(err, &he) {
+ code = he.Code
+ }
+
+ // Handle 404 errors with HTML rendering when appropriate
+ if code == http.StatusNotFound {
+ notFoundHandler(c)
+ return
+ }
+
+ // Send custom error page
+ c.JSON(code, schema.ErrorResponse{
+ Error: &schema.APIError{Message: err.Error(), Code: code},
+ })
+ }
+ } else {
+ e.HTTPErrorHandler = func(err error, c echo.Context) {
+ code := http.StatusInternalServerError
+ var he *echo.HTTPError
+ if errors.As(err, &he) {
+ code = he.Code
+ }
+ c.NoContent(code)
+ }
+ }
+
+ // Set renderer
+ e.Renderer = renderEngine()
+
+ // Hide banner
+ e.HideBanner = true
+ e.HidePort = true
+
+ // Middleware - StripPathPrefix must be registered early as it uses Rewrite which runs before routing
+ e.Pre(httpMiddleware.StripPathPrefix())
+
+ e.Pre(middleware.RemoveTrailingSlash())
+
+ if application.ApplicationConfig().MachineTag != "" {
+ e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ c.Response().Header().Set("Machine-Tag", application.ApplicationConfig().MachineTag)
+ return next(c)
+ }
+ })
+ }
+
+ // Custom logger middleware using xlog
+ e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ req := c.Request()
+ res := c.Response()
+ err := next(c)
+ xlog.Info("HTTP request", "method", req.Method, "path", req.URL.Path, "status", res.Status)
+ return err
+ }
+ })
+
+ // Recover middleware
+ if !application.ApplicationConfig().Debug {
+ e.Use(middleware.Recover())
+ }
+
+ // Metrics middleware
+ if !application.ApplicationConfig().DisableMetrics {
+ metricsService, err := services.NewLocalAIMetricsService()
+ if err != nil {
+ return nil, err
+ }
+
+ if metricsService != nil {
+ e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
+ e.Server.RegisterOnShutdown(func() {
+ metricsService.Shutdown()
+ })
+ }
+ }
+
+ // Health Checks should always be exempt from auth, so register these first
+ routes.HealthRoutes(e)
+
+ // Get key auth middleware
+ keyAuthMiddleware, err := httpMiddleware.GetKeyAuthConfig(application.ApplicationConfig())
+ if err != nil {
+ return nil, fmt.Errorf("failed to create key auth config: %w", err)
+ }
+
+ // Favicon handler
+ e.GET("/favicon.svg", func(c echo.Context) error {
+ data, err := embedDirStatic.ReadFile("static/favicon.svg")
+ if err != nil {
+ return c.NoContent(http.StatusNotFound)
+ }
+ c.Response().Header().Set("Content-Type", "image/svg+xml")
+ return c.Blob(http.StatusOK, "image/svg+xml", data)
+ })
+
+ // Static files - use fs.Sub to create a filesystem rooted at "static"
+ staticFS, err := fs.Sub(embedDirStatic, "static")
+ if err != nil {
+ return nil, fmt.Errorf("failed to create static filesystem: %w", err)
+ }
+ e.StaticFS("/static", staticFS)
+
+ // Generated content directories
+ if application.ApplicationConfig().GeneratedContentDir != "" {
+ os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750)
+ audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio")
+ imagePath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "images")
+ videoPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "videos")
+
+ os.MkdirAll(audioPath, 0750)
+ os.MkdirAll(imagePath, 0750)
+ os.MkdirAll(videoPath, 0750)
+
+ e.Static("/generated-audio", audioPath)
+ e.Static("/generated-images", imagePath)
+ e.Static("/generated-videos", videoPath)
+ }
+
+ // Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Skipper property of the KeyAuth Configuration
+ e.Use(keyAuthMiddleware)
+
+ // CORS middleware
+ if application.ApplicationConfig().CORS {
+ corsConfig := middleware.CORSConfig{}
+ if application.ApplicationConfig().CORSAllowOrigins != "" {
+ corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",")
+ }
+ e.Use(middleware.CORSWithConfig(corsConfig))
+ }
+
+ // CSRF middleware
+ if application.ApplicationConfig().CSRF {
+ xlog.Debug("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
+ e.Use(middleware.CSRF())
+ }
+
+ requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
+
+ routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
+
+ // Create opcache for tracking UI operations (used by both UI and LocalAI routes)
+ var opcache *services.OpCache
+ if !application.ApplicationConfig().DisableWebUI {
+ opcache = services.NewOpCache(application.GalleryService())
+ }
+
+ routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application)
+ routes.RegisterOpenAIRoutes(e, requestExtractor, application)
+ routes.RegisterAnthropicRoutes(e, requestExtractor, application)
+ if !application.ApplicationConfig().DisableWebUI {
+ routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application)
+ routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
+ }
+ routes.RegisterJINARoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
+
+ // Note: 404 handling is done via HTTPErrorHandler above, no need for catch-all route
+
+ // Log startup message
+ e.Server.RegisterOnShutdown(func() {
+ xlog.Info("LocalAI API server shutting down")
+ })
+
+ return e, nil
+}
diff --git a/core/http/app_test.go b/core/http/app_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..1b41c8124b1c2f72199bd107ebfd137b239ddf47
--- /dev/null
+++ b/core/http/app_test.go
@@ -0,0 +1,1457 @@
+package http_test
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "runtime"
+ "time"
+
+ "github.com/mudler/LocalAI/core/application"
+ "github.com/mudler/LocalAI/core/config"
+ . "github.com/mudler/LocalAI/core/http"
+ "github.com/mudler/LocalAI/core/schema"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/pkg/downloader"
+ "github.com/mudler/LocalAI/pkg/system"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+ "gopkg.in/yaml.v3"
+
+ "github.com/mudler/xlog"
+ openaigo "github.com/otiai10/openaigo"
+ "github.com/sashabaranov/go-openai"
+ "github.com/sashabaranov/go-openai/jsonschema"
+)
+
+const apiKey = "joshua"
+const bearerKey = "Bearer " + apiKey
+
+const testPrompt = `### System:
+You are an AI assistant that follows instruction extremely well. Help as much as you can.
+
+### Instruction:
+
+Say hello.
+
+### Response:`
+
+type modelApplyRequest struct {
+ ID string `json:"id"`
+ URL string `json:"url"`
+ ConfigURL string `json:"config_url"`
+ Name string `json:"name"`
+ Overrides map[string]interface{} `json:"overrides"`
+}
+
+func getModelStatus(url string) (response map[string]interface{}) {
+ // Create the HTTP request
+ req, err := http.NewRequest("GET", url, nil)
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", bearerKey)
+ if err != nil {
+ fmt.Println("Error creating request:", err)
+ return
+ }
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ fmt.Println("Error sending request:", err)
+ return
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ fmt.Println("Error reading response body:", err)
+ return
+ }
+
+ // Unmarshal the response into a map[string]interface{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ fmt.Println("Error unmarshaling JSON response:", err)
+ return
+ }
+ return
+}
+
+func getModels(url string) ([]gallery.GalleryModel, error) {
+ response := []gallery.GalleryModel{}
+ uri := downloader.URI(url)
+ // TODO: No tests currently seem to exercise file:// urls. Fix?
+ err := uri.ReadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error {
+ // Unmarshal YAML data into a struct
+ return json.Unmarshal(i, &response)
+ })
+ return response, err
+}
+
+func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
+
+ //url := "http://localhost:AI/models/apply"
+
+ // Create the request payload
+
+ payload, err := json.Marshal(request)
+ if err != nil {
+ fmt.Println("Error marshaling JSON:", err)
+ return
+ }
+
+ // Create the HTTP request
+ req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
+ if err != nil {
+ fmt.Println("Error creating request:", err)
+ return
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", bearerKey)
+
+ // Make the request
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ fmt.Println("Error making request:", err)
+ return
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ fmt.Println("Error reading response body:", err)
+ return
+ }
+
+ // Unmarshal the response into a map[string]interface{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ fmt.Println("Error unmarshaling JSON response:", err)
+ return
+ }
+ return
+}
+
+func postRequestJSON[B any](url string, bodyJson *B) error {
+ payload, err := json.Marshal(bodyJson)
+ if err != nil {
+ return err
+ }
+
+ GinkgoWriter.Printf("POST %s: %s\n", url, string(payload))
+
+ req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
+ if err != nil {
+ return err
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", bearerKey)
+
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ return err
+ }
+
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return err
+ }
+
+ if resp.StatusCode < 200 || resp.StatusCode >= 400 {
+ return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
+ }
+
+ return nil
+}
+
+func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *B2) error {
+ payload, err := json.Marshal(reqJson)
+ if err != nil {
+ return err
+ }
+
+ GinkgoWriter.Printf("POST %s: %s\n", url, string(payload))
+
+ req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
+ if err != nil {
+ return err
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", bearerKey)
+
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return err
+ }
+
+ if resp.StatusCode < 200 || resp.StatusCode >= 400 {
+ return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
+ }
+
+ return json.Unmarshal(body, respJson)
+}
+
+func putRequestJSON[B any](url string, bodyJson *B) error {
+ payload, err := json.Marshal(bodyJson)
+ if err != nil {
+ return err
+ }
+
+ GinkgoWriter.Printf("PUT %s: %s\n", url, string(payload))
+
+ req, err := http.NewRequest("PUT", url, bytes.NewBuffer(payload))
+ if err != nil {
+ return err
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", bearerKey)
+
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return err
+ }
+
+ if resp.StatusCode < 200 || resp.StatusCode >= 400 {
+ return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
+ }
+
+ return nil
+}
+
+func postInvalidRequest(url string) (error, int) {
+
+ req, err := http.NewRequest("POST", url, bytes.NewBufferString("invalid request"))
+ if err != nil {
+ return err, -1
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ return err, -1
+ }
+
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return err, -1
+ }
+
+ if resp.StatusCode < 200 || resp.StatusCode >= 400 {
+ return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)), resp.StatusCode
+ }
+
+ return nil, resp.StatusCode
+}
+
+func getRequest(url string, header http.Header) (error, int, []byte) {
+
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ return err, -1, nil
+ }
+
+ req.Header = header
+
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ return err, -1, nil
+ }
+
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return err, -1, nil
+ }
+
+ return nil, resp.StatusCode, body
+}
+
+const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b87640e8644b09c2aee6e3b/raw/f0e8c26bb72edc16d9fbafbfd6638072126ff225/bert-embeddings-gallery.yaml`
+
+var _ = Describe("API test", func() {
+
+ var app *echo.Echo
+ var client *openai.Client
+ var client2 *openaigo.Client
+ var c context.Context
+ var cancel context.CancelFunc
+ var tmpdir string
+ var modelDir string
+
+ commonOpts := []config.AppOption{
+ config.WithDebug(true),
+ }
+
+ Context("API with ephemeral models", func() {
+
+ BeforeEach(func(sc SpecContext) {
+ var err error
+ tmpdir, err = os.MkdirTemp("", "")
+ Expect(err).ToNot(HaveOccurred())
+
+ backendPath := os.Getenv("BACKENDS_PATH")
+
+ modelDir = filepath.Join(tmpdir, "models")
+ err = os.Mkdir(modelDir, 0750)
+ Expect(err).ToNot(HaveOccurred())
+
+ c, cancel = context.WithCancel(context.Background())
+
+ g := []gallery.GalleryModel{
+ {
+ Metadata: gallery.Metadata{
+ Name: "bert",
+ URL: bertEmbeddingsURL,
+ },
+ },
+ {
+ Metadata: gallery.Metadata{
+ Name: "bert2",
+ URL: bertEmbeddingsURL,
+ AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: bertEmbeddingsURL}},
+ },
+ Overrides: map[string]interface{}{"foo": "bar"},
+ },
+ }
+ out, err := yaml.Marshal(g)
+ Expect(err).ToNot(HaveOccurred())
+ err = os.WriteFile(filepath.Join(modelDir, "gallery_simple.yaml"), out, 0600)
+ Expect(err).ToNot(HaveOccurred())
+
+ galleries := []config.Gallery{
+ {
+ Name: "test",
+ URL: "file://" + filepath.Join(modelDir, "gallery_simple.yaml"),
+ },
+ }
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(backendPath),
+ system.WithModelPath(modelDir),
+ )
+ Expect(err).ToNot(HaveOccurred())
+
+ application, err := application.New(
+ append(commonOpts,
+ config.WithContext(c),
+ config.WithSystemState(systemState),
+ config.WithGalleries(galleries),
+ config.WithApiKeys([]string{apiKey}),
+ )...)
+ Expect(err).ToNot(HaveOccurred())
+
+ app, err = API(application)
+ Expect(err).ToNot(HaveOccurred())
+
+ go func() {
+ if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
+ xlog.Error("server error", "error", err)
+ }
+ }()
+
+ defaultConfig := openai.DefaultConfig(apiKey)
+ defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
+
+ client2 = openaigo.NewClient("")
+ client2.BaseURL = defaultConfig.BaseURL
+
+ // Wait for API to be ready
+ client = openai.NewClientWithConfig(defaultConfig)
+ Eventually(func() error {
+ _, err := client.ListModels(context.TODO())
+ return err
+ }, "2m").ShouldNot(HaveOccurred())
+ })
+
+ AfterEach(func(sc SpecContext) {
+ cancel()
+ if app != nil {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ err := app.Shutdown(ctx)
+ Expect(err).ToNot(HaveOccurred())
+ }
+ err := os.RemoveAll(tmpdir)
+ Expect(err).ToNot(HaveOccurred())
+ _, err = os.ReadDir(tmpdir)
+ Expect(err).To(HaveOccurred())
+ })
+
+ Context("Auth Tests", func() {
+ It("Should fail if the api key is missing", func() {
+ err, sc := postInvalidRequest("http://127.0.0.1:9090/models/available")
+ Expect(err).ToNot(BeNil())
+ Expect(sc).To(Equal(401))
+ })
+ })
+
+ Context("URL routing Tests", func() {
+ It("Should support reverse-proxy when unauthenticated", func() {
+
+ err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{
+ "X-Forwarded-Proto": {"https"},
+ "X-Forwarded-Host": {"example.org"},
+ "X-Forwarded-Prefix": {"/myprefix/"},
+ })
+ Expect(err).To(BeNil(), "error")
+ Expect(sc).To(Equal(401), "status code")
+ Expect(string(body)).To(ContainSubstring(` `), "body")
+ })
+
+ It("Should support reverse-proxy when authenticated", func() {
+
+ err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{
+ "Authorization": {bearerKey},
+ "X-Forwarded-Proto": {"https"},
+ "X-Forwarded-Host": {"example.org"},
+ "X-Forwarded-Prefix": {"/myprefix/"},
+ })
+ Expect(err).To(BeNil(), "error")
+ Expect(sc).To(Equal(200), "status code")
+ Expect(string(body)).To(ContainSubstring(` `), "body")
+ })
+ })
+
+ Context("Applying models", func() {
+
+ It("applies models from a gallery", func() {
+ models, err := getModels("http://127.0.0.1:9090/models/available")
+ Expect(err).To(BeNil())
+ Expect(len(models)).To(Equal(2), fmt.Sprint(models))
+ Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
+ Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models))
+
+ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
+ ID: "test@bert2",
+ })
+
+ Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
+
+ uuid := response["uuid"].(string)
+ resp := map[string]interface{}{}
+ Eventually(func() bool {
+ response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
+ fmt.Println(response)
+ resp = response
+ return response["processed"].(bool)
+ }, "360s", "10s").Should(Equal(true))
+ Expect(resp["message"]).ToNot(ContainSubstring("error"))
+
+ dat, err := os.ReadFile(filepath.Join(modelDir, "bert2.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ _, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ content := map[string]interface{}{}
+ err = yaml.Unmarshal(dat, &content)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
+ Expect(content["foo"]).To(Equal("bar"))
+
+ models, err = getModels("http://127.0.0.1:9090/models/available")
+ Expect(err).To(BeNil())
+ Expect(len(models)).To(Equal(2), fmt.Sprint(models))
+ Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2")))
+ Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2")))
+ for _, m := range models {
+ if m.Name == "bert2" {
+ Expect(m.Installed).To(BeTrue())
+ } else {
+ Expect(m.Installed).To(BeFalse())
+ }
+ }
+ })
+ It("overrides models", func() {
+
+ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
+ URL: bertEmbeddingsURL,
+ Name: "bert",
+ Overrides: map[string]interface{}{
+ "backend": "llama",
+ },
+ })
+
+ Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
+
+ uuid := response["uuid"].(string)
+
+ Eventually(func() bool {
+ response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
+ return response["processed"].(bool)
+ }, "360s", "10s").Should(Equal(true))
+
+ dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ content := map[string]interface{}{}
+ err = yaml.Unmarshal(dat, &content)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(content["backend"]).To(Equal("llama"))
+ })
+ It("apply models without overrides", func() {
+ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
+ URL: bertEmbeddingsURL,
+ Name: "bert",
+ Overrides: map[string]interface{}{},
+ })
+
+ Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
+
+ uuid := response["uuid"].(string)
+
+ Eventually(func() bool {
+ response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
+ return response["processed"].(bool)
+ }, "360s", "10s").Should(Equal(true))
+
+ dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ content := map[string]interface{}{}
+ err = yaml.Unmarshal(dat, &content)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this"))
+ })
+
+ })
+
+ Context("Importing models from URI", func() {
+ var testYamlFile string
+
+ BeforeEach(func() {
+ // Create a test YAML config file
+ yamlContent := `name: test-import-model
+backend: llama-cpp
+description: Test model imported from file URI
+parameters:
+ model: path/to/model.gguf
+ temperature: 0.7
+`
+ testYamlFile = filepath.Join(tmpdir, "test-import.yaml")
+ err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644)
+ Expect(err).ToNot(HaveOccurred())
+ })
+
+ AfterEach(func() {
+ err := os.Remove(testYamlFile)
+ Expect(err).ToNot(HaveOccurred())
+ })
+
+ It("should import model from file:// URI pointing to local YAML config", func() {
+ importReq := schema.ImportModelRequest{
+ URI: "file://" + testYamlFile,
+ Preferences: json.RawMessage(`{}`),
+ }
+
+ var response schema.GalleryResponse
+ err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(response.ID).ToNot(BeEmpty())
+
+ uuid := response.ID
+ resp := map[string]interface{}{}
+ Eventually(func() bool {
+ response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
+ resp = response
+ return response["processed"].(bool)
+ }, "360s", "10s").Should(Equal(true))
+
+ // Check that the model was imported successfully
+ Expect(resp["message"]).ToNot(ContainSubstring("error"))
+ Expect(resp["error"]).To(BeNil())
+
+ // Verify the model config file was created
+ dat, err := os.ReadFile(filepath.Join(modelDir, "test-import-model.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ content := map[string]interface{}{}
+ err = yaml.Unmarshal(dat, &content)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(content["name"]).To(Equal("test-import-model"))
+ Expect(content["backend"]).To(Equal("llama-cpp"))
+ })
+
+ It("should return error when file:// URI points to non-existent file", func() {
+ nonExistentFile := filepath.Join(tmpdir, "nonexistent.yaml")
+ importReq := schema.ImportModelRequest{
+ URI: "file://" + nonExistentFile,
+ Preferences: json.RawMessage(`{}`),
+ }
+
+ var response schema.GalleryResponse
+ err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
+ // The endpoint should return an error immediately
+ Expect(err).To(HaveOccurred())
+ Expect(err.Error()).To(ContainSubstring("failed to discover model config"))
+ })
+ })
+
+ Context("Importing models from URI can't point to absolute paths", func() {
+ var testYamlFile string
+
+ BeforeEach(func() {
+ // Create a test YAML config file
+ yamlContent := `name: test-import-model
+backend: llama-cpp
+description: Test model imported from file URI
+parameters:
+ model: /path/to/model.gguf
+ temperature: 0.7
+`
+ testYamlFile = filepath.Join(tmpdir, "test-import.yaml")
+ err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644)
+ Expect(err).ToNot(HaveOccurred())
+ })
+
+ AfterEach(func() {
+ err := os.Remove(testYamlFile)
+ Expect(err).ToNot(HaveOccurred())
+ })
+
+ It("should fail to import model from file:// URI pointing to local YAML config", func() {
+ importReq := schema.ImportModelRequest{
+ URI: "file://" + testYamlFile,
+ Preferences: json.RawMessage(`{}`),
+ }
+
+ var response schema.GalleryResponse
+ err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(response.ID).ToNot(BeEmpty())
+
+ uuid := response.ID
+ resp := map[string]interface{}{}
+ Eventually(func() bool {
+ response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
+ resp = response
+ return response["processed"].(bool)
+ }, "360s", "10s").Should(Equal(true))
+
+ // Check that the model was imported successfully
+ Expect(resp["message"]).To(ContainSubstring("error"))
+ Expect(resp["error"]).ToNot(BeNil())
+ })
+ })
+ })
+
+ Context("Model gallery", func() {
+ BeforeEach(func() {
+ var err error
+ tmpdir, err = os.MkdirTemp("", "")
+
+ backendPath := os.Getenv("BACKENDS_PATH")
+
+ Expect(err).ToNot(HaveOccurred())
+ modelDir = filepath.Join(tmpdir, "models")
+ backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
+ err = os.Mkdir(backendAssetsDir, 0750)
+ Expect(err).ToNot(HaveOccurred())
+
+ c, cancel = context.WithCancel(context.Background())
+
+ galleries := []config.Gallery{
+ {
+ Name: "localai",
+ URL: "https://raw.githubusercontent.com/mudler/LocalAI/refs/heads/master/gallery/index.yaml",
+ },
+ }
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(backendPath),
+ system.WithModelPath(modelDir),
+ )
+ Expect(err).ToNot(HaveOccurred())
+
+ application, err := application.New(
+ append(commonOpts,
+ config.WithContext(c),
+ config.WithGeneratedContentDir(tmpdir),
+ config.WithSystemState(systemState),
+ config.WithGalleries(galleries),
+ )...,
+ )
+ Expect(err).ToNot(HaveOccurred())
+ app, err = API(application)
+ Expect(err).ToNot(HaveOccurred())
+
+ go func() {
+ if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
+ xlog.Error("server error", "error", err)
+ }
+ }()
+
+ defaultConfig := openai.DefaultConfig("")
+ defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
+
+ client2 = openaigo.NewClient("")
+ client2.BaseURL = defaultConfig.BaseURL
+
+ // Wait for API to be ready
+ client = openai.NewClientWithConfig(defaultConfig)
+ Eventually(func() error {
+ _, err := client.ListModels(context.TODO())
+ return err
+ }, "2m").ShouldNot(HaveOccurred())
+ })
+
+ AfterEach(func() {
+ cancel()
+ if app != nil {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ err := app.Shutdown(ctx)
+ Expect(err).ToNot(HaveOccurred())
+ }
+ err := os.RemoveAll(tmpdir)
+ Expect(err).ToNot(HaveOccurred())
+ _, err = os.ReadDir(tmpdir)
+ Expect(err).To(HaveOccurred())
+ })
+
+ It("runs gguf models (chat)", Label("llama-gguf"), func() {
+ if runtime.GOOS != "linux" {
+ Skip("test supported only on linux")
+ }
+
+ modelName := "qwen3-1.7b"
+ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
+ ID: "localai@" + modelName,
+ })
+
+ Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
+
+ uuid := response["uuid"].(string)
+
+ Eventually(func() bool {
+ response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
+ return response["processed"].(bool)
+ }, "900s", "10s").Should(Equal(true))
+
+ By("testing chat")
+ resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: modelName, Messages: []openai.ChatCompletionMessage{
+ {
+ Role: "user",
+ Content: "How much is 2+2?",
+ },
+ }})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(resp.Choices)).To(Equal(1))
+ Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("4"), ContainSubstring("four")))
+
+ By("testing functions")
+ resp2, err := client.CreateChatCompletion(
+ context.TODO(),
+ openai.ChatCompletionRequest{
+ Model: modelName,
+ Messages: []openai.ChatCompletionMessage{
+ {
+ Role: "user",
+ Content: "What is the weather like in San Francisco (celsius)?",
+ },
+ },
+ Functions: []openai.FunctionDefinition{
+ openai.FunctionDefinition{
+ Name: "get_current_weather",
+ Description: "Get the current weather",
+ Parameters: jsonschema.Definition{
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "location": {
+ Type: jsonschema.String,
+ Description: "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": {
+ Type: jsonschema.String,
+ Enum: []string{"celcius", "fahrenheit"},
+ },
+ },
+ Required: []string{"location"},
+ },
+ },
+ },
+ })
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(resp2.Choices)).To(Equal(1))
+ Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil())
+ Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name)
+
+ var res map[string]string
+ err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(res["location"]).To(ContainSubstring("San Francisco"), fmt.Sprint(res))
+ Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
+ Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
+ })
+
+ It("installs and is capable to run tts", Label("tts"), func() {
+ if runtime.GOOS != "linux" {
+ Skip("test supported only on linux")
+ }
+
+ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
+ ID: "localai@voice-en-us-kathleen-low",
+ })
+
+ Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
+
+ uuid := response["uuid"].(string)
+
+ Eventually(func() bool {
+ response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
+ fmt.Println(response)
+ return response["processed"].(bool)
+ }, "360s", "10s").Should(Equal(true))
+
+ // An HTTP Post to the /tts endpoint should return a wav audio file
+ resp, err := http.Post("http://127.0.0.1:9090/tts", "application/json", bytes.NewBuffer([]byte(`{"input": "Hello world", "model": "voice-en-us-kathleen-low"}`)))
+ Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
+ dat, err := io.ReadAll(resp.Body)
+ Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
+
+ Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat)))
+ Expect(resp.Header.Get("Content-Type")).To(Or(Equal("audio/x-wav"), Equal("audio/vnd.wave")))
+ })
+ It("installs and is capable to generate images", Label("stablediffusion"), func() {
+ if runtime.GOOS != "linux" {
+ Skip("test supported only on linux")
+ }
+
+ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
+ ID: "localai@sd-1.5-ggml",
+ Name: "stablediffusion",
+ })
+
+ Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
+
+ uuid := response["uuid"].(string)
+
+ Eventually(func() bool {
+ response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
+ fmt.Println(response)
+ return response["processed"].(bool)
+ }, "1200s", "10s").Should(Equal(true))
+
+ resp, err := http.Post(
+ "http://127.0.0.1:9090/v1/images/generations",
+ "application/json",
+ bytes.NewBuffer([]byte(`{
+ "prompt": "a lovely cat",
+ "step": 1, "seed":9000,
+ "size": "256x256", "n":2}`)))
+ // The response should contain an URL
+ Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
+ dat, err := io.ReadAll(resp.Body)
+ Expect(err).ToNot(HaveOccurred(), "error reading /image/generations response")
+
+ imgUrlResp := &schema.OpenAIResponse{}
+ err = json.Unmarshal(dat, imgUrlResp)
+ Expect(err).ToNot(HaveOccurred(), fmt.Sprint(dat))
+ Expect(imgUrlResp.Data).ToNot(Or(BeNil(), BeZero()))
+ imgUrl := imgUrlResp.Data[0].URL
+ Expect(imgUrl).To(ContainSubstring("http://127.0.0.1:9090/"), imgUrl)
+ Expect(imgUrl).To(ContainSubstring(".png"), imgUrl)
+
+ imgResp, err := http.Get(imgUrl)
+ Expect(err).To(BeNil())
+ Expect(imgResp).ToNot(BeNil())
+ Expect(imgResp.StatusCode).To(Equal(200))
+ Expect(imgResp.ContentLength).To(BeNumerically(">", 0))
+ imgData := make([]byte, 512)
+ count, err := io.ReadFull(imgResp.Body, imgData)
+ Expect(err).To(Or(BeNil(), MatchError(io.EOF)))
+ Expect(count).To(BeNumerically(">", 0))
+ Expect(count).To(BeNumerically("<=", 512))
+ Expect(http.DetectContentType(imgData)).To(Equal("image/png"))
+ })
+ })
+
+ Context("API query", func() {
+ BeforeEach(func() {
+ modelPath := os.Getenv("MODELS_PATH")
+ backendPath := os.Getenv("BACKENDS_PATH")
+ c, cancel = context.WithCancel(context.Background())
+
+ var err error
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(backendPath),
+ system.WithModelPath(modelPath),
+ )
+ Expect(err).ToNot(HaveOccurred())
+
+ application, err := application.New(
+ append(commonOpts,
+ config.WithExternalBackend("transformers", os.Getenv("HUGGINGFACE_GRPC")),
+ config.WithContext(c),
+ config.WithSystemState(systemState),
+ )...)
+ Expect(err).ToNot(HaveOccurred())
+ app, err = API(application)
+ Expect(err).ToNot(HaveOccurred())
+ go func() {
+ if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
+ xlog.Error("server error", "error", err)
+ }
+ }()
+
+ defaultConfig := openai.DefaultConfig("")
+ defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
+
+ client2 = openaigo.NewClient("")
+ client2.BaseURL = defaultConfig.BaseURL
+
+ // Wait for API to be ready
+ client = openai.NewClientWithConfig(defaultConfig)
+ Eventually(func() error {
+ _, err := client.ListModels(context.TODO())
+ return err
+ }, "2m").ShouldNot(HaveOccurred())
+ })
+ AfterEach(func() {
+ cancel()
+ if app != nil {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ err := app.Shutdown(ctx)
+ Expect(err).ToNot(HaveOccurred())
+ }
+ })
+ It("returns the models list", func() {
+ models, err := client.ListModels(context.TODO())
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(models.Models)).To(Equal(7)) // If "config.yaml" should be included, this should be 8?
+ })
+ It("can generate completions via ggml", func() {
+ if runtime.GOOS != "linux" {
+ Skip("test supported only on linux")
+ }
+ resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel.ggml", Prompt: testPrompt})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(resp.Choices)).To(Equal(1))
+ Expect(resp.Choices[0].Text).ToNot(BeEmpty())
+ })
+
+ It("can generate chat completions via ggml", func() {
+ if runtime.GOOS != "linux" {
+ Skip("test supported only on linux")
+ }
+ resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "testmodel.ggml", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: testPrompt}}})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(resp.Choices)).To(Equal(1))
+ Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
+ })
+
+ It("returns logprobs in chat completions when requested", func() {
+ if runtime.GOOS != "linux" {
+ Skip("test only on linux")
+ }
+ topLogprobsVal := 3
+ response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
+ Model: "testmodel.ggml",
+ LogProbs: true,
+ TopLogProbs: topLogprobsVal,
+ Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
+ Expect(err).ToNot(HaveOccurred())
+
+ Expect(len(response.Choices)).To(Equal(1))
+ Expect(response.Choices[0].Message).ToNot(BeNil())
+ Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
+
+ // Verify logprobs are present and have correct structure
+ Expect(response.Choices[0].LogProbs).ToNot(BeNil())
+ Expect(response.Choices[0].LogProbs.Content).ToNot(BeEmpty())
+
+ Expect(len(response.Choices[0].LogProbs.Content)).To(BeNumerically(">", 1))
+
+ foundatLeastToken := ""
+ foundAtLeastBytes := []byte{}
+ foundAtLeastTopLogprobBytes := []byte{}
+ foundatLeastTopLogprob := ""
+ // Verify logprobs content structure matches OpenAI format
+ for _, logprobContent := range response.Choices[0].LogProbs.Content {
+ // Bytes can be empty for certain tokens (special tokens, etc.), so we don't require it
+ if len(logprobContent.Bytes) > 0 {
+ foundAtLeastBytes = logprobContent.Bytes
+ }
+ if len(logprobContent.Token) > 0 {
+ foundatLeastToken = logprobContent.Token
+ }
+ Expect(logprobContent.LogProb).To(BeNumerically("<=", 0)) // Logprobs are always <= 0
+ Expect(len(logprobContent.TopLogProbs)).To(BeNumerically(">", 1))
+
+ // If top_logprobs is requested, verify top_logprobs array respects the limit
+ if len(logprobContent.TopLogProbs) > 0 {
+ // Should respect top_logprobs limit (3 in this test)
+ Expect(len(logprobContent.TopLogProbs)).To(BeNumerically("<=", topLogprobsVal))
+ for _, topLogprob := range logprobContent.TopLogProbs {
+ if len(topLogprob.Bytes) > 0 {
+ foundAtLeastTopLogprobBytes = topLogprob.Bytes
+ }
+ if len(topLogprob.Token) > 0 {
+ foundatLeastTopLogprob = topLogprob.Token
+ }
+ Expect(topLogprob.LogProb).To(BeNumerically("<=", 0))
+ }
+ }
+ }
+
+ Expect(foundAtLeastBytes).ToNot(BeEmpty())
+ Expect(foundAtLeastTopLogprobBytes).ToNot(BeEmpty())
+ Expect(foundatLeastToken).ToNot(BeEmpty())
+ Expect(foundatLeastTopLogprob).ToNot(BeEmpty())
+ })
+
+ It("applies logit_bias to chat completions when requested", func() {
+ if runtime.GOOS != "linux" {
+ Skip("test only on linux")
+ }
+ // logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
+ // According to OpenAI API: modifies the likelihood of specified tokens appearing in the completion
+ logitBias := map[string]int{
+ "15043": 1, // Bias token ID 15043 (example token ID) with bias value 1
+ }
+ response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
+ Model: "testmodel.ggml",
+ Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}},
+ LogitBias: logitBias,
+ })
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(response.Choices)).To(Equal(1))
+ Expect(response.Choices[0].Message).ToNot(BeNil())
+ Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
+ // If logit_bias is applied, the response should be generated successfully
+ // We can't easily verify the bias effect without knowing the actual token IDs for the model,
+ // but the fact that the request succeeds confirms the API accepts and processes logit_bias
+ })
+
+ It("returns errors", func() {
+ _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt})
+ Expect(err).To(HaveOccurred())
+ Expect(err.Error()).To(ContainSubstring("error, status code: 500, status: 500 Internal Server Error, message: could not load model - all backends returned error:"))
+ })
+
+ It("shows the external backend", func() {
+ // Only run on linux
+ if runtime.GOOS != "linux" {
+ Skip("test supported only on linux")
+ }
+ // do an http request to the /system endpoint
+ resp, err := http.Get("http://127.0.0.1:9090/system")
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp.StatusCode).To(Equal(200))
+ dat, err := io.ReadAll(resp.Body)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(string(dat)).To(ContainSubstring("huggingface"))
+ Expect(string(dat)).To(ContainSubstring("llama-cpp"))
+ })
+
+ It("transcribes audio", func() {
+ if runtime.GOOS != "linux" {
+ Skip("test supported only on linux")
+ }
+ resp, err := client.CreateTranscription(
+ context.Background(),
+ openai.AudioRequest{
+ Model: openai.Whisper1,
+ FilePath: filepath.Join(os.Getenv("TEST_DIR"), "audio.wav"),
+ },
+ )
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp.Text).To(ContainSubstring("This is the Micro Machine Man presenting"))
+ })
+
+ It("calculate embeddings", func() {
+ if runtime.GOOS != "linux" {
+ Skip("test supported only on linux")
+ }
+ embeddingModel := openai.AdaEmbeddingV2
+ resp, err := client.CreateEmbeddings(
+ context.Background(),
+ openai.EmbeddingRequest{
+ Model: embeddingModel,
+ Input: []string{"sun", "cat"},
+ },
+ )
+ Expect(err).ToNot(HaveOccurred(), err)
+ Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 4096))
+ Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 4096))
+
+ sunEmbedding := resp.Data[0].Embedding
+ resp2, err := client.CreateEmbeddings(
+ context.Background(),
+ openai.EmbeddingRequest{
+ Model: embeddingModel,
+ Input: []string{"sun"},
+ },
+ )
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
+ Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
+
+ resp3, err := client.CreateEmbeddings(
+ context.Background(),
+ openai.EmbeddingRequest{
+ Model: embeddingModel,
+ Input: []string{"cat"},
+ },
+ )
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp3.Data[0].Embedding).To(Equal(resp.Data[1].Embedding))
+ Expect(resp3.Data[0].Embedding).ToNot(Equal(sunEmbedding))
+ })
+
+ Context("External gRPC calls", func() {
+ It("calculate embeddings with sentencetransformers", func() {
+ if runtime.GOOS != "linux" {
+ Skip("test supported only on linux")
+ }
+ resp, err := client.CreateEmbeddings(
+ context.Background(),
+ openai.EmbeddingRequest{
+ Model: openai.AdaCodeSearchCode,
+ Input: []string{"sun", "cat"},
+ },
+ )
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
+ Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
+
+ sunEmbedding := resp.Data[0].Embedding
+ resp2, err := client.CreateEmbeddings(
+ context.Background(),
+ openai.EmbeddingRequest{
+ Model: openai.AdaCodeSearchCode,
+ Input: []string{"sun"},
+ },
+ )
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
+ Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
+ })
+ })
+
+ // See tests/integration/stores_test
+ Context("Stores", Label("stores"), func() {
+
+ BeforeEach(func() {
+ // Only run on linux
+ if runtime.GOOS != "linux" {
+ Skip("test supported only on linux")
+ }
+ })
+
+ It("sets, gets, finds and deletes entries", func() {
+ ks := [][]float32{
+ {0.1, 0.2, 0.3},
+ {0.4, 0.5, 0.6},
+ {0.7, 0.8, 0.9},
+ }
+ vs := []string{
+ "test1",
+ "test2",
+ "test3",
+ }
+ setBody := schema.StoresSet{
+ Keys: ks,
+ Values: vs,
+ }
+
+ url := "http://127.0.0.1:9090/stores/"
+ err := postRequestJSON(url+"set", &setBody)
+ Expect(err).ToNot(HaveOccurred())
+
+ getBody := schema.StoresGet{
+ Keys: ks,
+ }
+ var getRespBody schema.StoresGetResponse
+ err = postRequestResponseJSON(url+"get", &getBody, &getRespBody)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(getRespBody.Keys)).To(Equal(len(ks)))
+
+ for i, v := range getRespBody.Keys {
+ if v[0] == 0.1 {
+ Expect(getRespBody.Values[i]).To(Equal("test1"))
+ } else if v[0] == 0.4 {
+ Expect(getRespBody.Values[i]).To(Equal("test2"))
+ } else {
+ Expect(getRespBody.Values[i]).To(Equal("test3"))
+ }
+ }
+
+ deleteBody := schema.StoresDelete{
+ Keys: [][]float32{
+ {0.1, 0.2, 0.3},
+ },
+ }
+ err = postRequestJSON(url+"delete", &deleteBody)
+ Expect(err).ToNot(HaveOccurred())
+
+ findBody := schema.StoresFind{
+ Key: []float32{0.1, 0.3, 0.7},
+ Topk: 10,
+ }
+
+ var findRespBody schema.StoresFindResponse
+ err = postRequestResponseJSON(url+"find", &findBody, &findRespBody)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(findRespBody.Keys)).To(Equal(2))
+
+ for i, v := range findRespBody.Keys {
+ if v[0] == 0.4 {
+ Expect(findRespBody.Values[i]).To(Equal("test2"))
+ } else {
+ Expect(findRespBody.Values[i]).To(Equal("test3"))
+ }
+
+ Expect(findRespBody.Similarities[i]).To(BeNumerically(">=", -1))
+ Expect(findRespBody.Similarities[i]).To(BeNumerically("<=", 1))
+ }
+ })
+
+ Context("Agent Jobs", Label("agent-jobs"), func() {
+ It("creates and manages tasks", func() {
+ // Create a task
+ taskBody := map[string]interface{}{
+ "name": "Test Task",
+ "description": "Test Description",
+ "model": "testmodel.ggml",
+ "prompt": "Hello {{.name}}",
+ "enabled": true,
+ }
+
+ var createResp map[string]interface{}
+ err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(createResp["id"]).ToNot(BeEmpty())
+ taskID := createResp["id"].(string)
+
+ // Get the task
+ var task schema.Task
+ resp, err := http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp.StatusCode).To(Equal(200))
+ body, _ := io.ReadAll(resp.Body)
+ json.Unmarshal(body, &task)
+ Expect(task.Name).To(Equal("Test Task"))
+
+ // List tasks
+ resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks")
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp.StatusCode).To(Equal(200))
+ var tasks []schema.Task
+ body, _ = io.ReadAll(resp.Body)
+ json.Unmarshal(body, &tasks)
+ Expect(len(tasks)).To(BeNumerically(">=", 1))
+
+ // Update task
+ taskBody["name"] = "Updated Task"
+ err = putRequestJSON("http://127.0.0.1:9090/api/agent/tasks/"+taskID, &taskBody)
+ Expect(err).ToNot(HaveOccurred())
+
+ // Verify update
+ resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID)
+ Expect(err).ToNot(HaveOccurred())
+ body, _ = io.ReadAll(resp.Body)
+ json.Unmarshal(body, &task)
+ Expect(task.Name).To(Equal("Updated Task"))
+
+ // Delete task
+ req, _ := http.NewRequest("DELETE", "http://127.0.0.1:9090/api/agent/tasks/"+taskID, nil)
+ req.Header.Set("Authorization", bearerKey)
+ resp, err = http.DefaultClient.Do(req)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp.StatusCode).To(Equal(200))
+ })
+
+ It("executes and monitors jobs", func() {
+ // Create a task first
+ taskBody := map[string]interface{}{
+ "name": "Job Test Task",
+ "model": "testmodel.ggml",
+ "prompt": "Say hello",
+ "enabled": true,
+ }
+
+ var createResp map[string]interface{}
+ err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
+ Expect(err).ToNot(HaveOccurred())
+ taskID := createResp["id"].(string)
+
+ // Execute a job
+ jobBody := map[string]interface{}{
+ "task_id": taskID,
+ "parameters": map[string]string{},
+ }
+
+ var jobResp schema.JobExecutionResponse
+ err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/jobs/execute", &jobBody, &jobResp)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(jobResp.JobID).ToNot(BeEmpty())
+ jobID := jobResp.JobID
+
+ // Get job status
+ var job schema.Job
+ resp, err := http.Get("http://127.0.0.1:9090/api/agent/jobs/" + jobID)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp.StatusCode).To(Equal(200))
+ body, _ := io.ReadAll(resp.Body)
+ json.Unmarshal(body, &job)
+ Expect(job.ID).To(Equal(jobID))
+ Expect(job.TaskID).To(Equal(taskID))
+
+ // List jobs
+ resp, err = http.Get("http://127.0.0.1:9090/api/agent/jobs")
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp.StatusCode).To(Equal(200))
+ var jobs []schema.Job
+ body, _ = io.ReadAll(resp.Body)
+ json.Unmarshal(body, &jobs)
+ Expect(len(jobs)).To(BeNumerically(">=", 1))
+
+ // Cancel job (if still pending/running)
+ if job.Status == schema.JobStatusPending || job.Status == schema.JobStatusRunning {
+ req, _ := http.NewRequest("POST", "http://127.0.0.1:9090/api/agent/jobs/"+jobID+"/cancel", nil)
+ req.Header.Set("Authorization", bearerKey)
+ resp, err = http.DefaultClient.Do(req)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp.StatusCode).To(Equal(200))
+ }
+ })
+
+ It("executes task by name", func() {
+ // Create a task with a specific name
+ taskBody := map[string]interface{}{
+ "name": "Named Task",
+ "model": "testmodel.ggml",
+ "prompt": "Hello",
+ "enabled": true,
+ }
+
+ var createResp map[string]interface{}
+ err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp)
+ Expect(err).ToNot(HaveOccurred())
+
+ // Execute by name
+ paramsBody := map[string]string{"param1": "value1"}
+ var jobResp schema.JobExecutionResponse
+ err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks/Named Task/execute", ¶msBody, &jobResp)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(jobResp.JobID).ToNot(BeEmpty())
+ })
+ })
+ })
+ })
+
+ Context("Config file", func() {
+ BeforeEach(func() {
+ if runtime.GOOS != "linux" {
+ Skip("run this test only on linux")
+ }
+ modelPath := os.Getenv("MODELS_PATH")
+ backendPath := os.Getenv("BACKENDS_PATH")
+ c, cancel = context.WithCancel(context.Background())
+
+ var err error
+
+ systemState, err := system.GetSystemState(
+ system.WithBackendPath(backendPath),
+ system.WithModelPath(modelPath),
+ )
+ Expect(err).ToNot(HaveOccurred())
+
+ application, err := application.New(
+ append(commonOpts,
+ config.WithContext(c),
+ config.WithSystemState(systemState),
+ config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
+ )
+ Expect(err).ToNot(HaveOccurred())
+ app, err = API(application)
+ Expect(err).ToNot(HaveOccurred())
+
+ go func() {
+ if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
+ xlog.Error("server error", "error", err)
+ }
+ }()
+
+ defaultConfig := openai.DefaultConfig("")
+ defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
+ client2 = openaigo.NewClient("")
+ client2.BaseURL = defaultConfig.BaseURL
+ // Wait for API to be ready
+ client = openai.NewClientWithConfig(defaultConfig)
+ Eventually(func() error {
+ _, err := client.ListModels(context.TODO())
+ return err
+ }, "2m").ShouldNot(HaveOccurred())
+ })
+ AfterEach(func() {
+ cancel()
+ if app != nil {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ err := app.Shutdown(ctx)
+ Expect(err).ToNot(HaveOccurred())
+ }
+ })
+ It("can generate chat completions from config file (list1)", func() {
+ resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(resp.Choices)).To(Equal(1))
+ Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
+ })
+ It("can generate chat completions from config file (list2)", func() {
+ resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(resp.Choices)).To(Equal(1))
+ Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
+ })
+ It("can generate edit completions from config file", func() {
+ request := openaigo.EditCreateRequestBody{
+ Model: "list2",
+ Instruction: "foo",
+ Input: "bar",
+ }
+ resp, err := client2.CreateEdit(context.Background(), request)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(resp.Choices)).To(Equal(1))
+ Expect(resp.Choices[0].Text).ToNot(BeEmpty())
+ })
+
+ })
+})
diff --git a/core/http/endpoints/anthropic/messages.go b/core/http/endpoints/anthropic/messages.go
new file mode 100644
index 0000000000000000000000000000000000000000..389d604665918e9922f26231e7d6cb84d1f8879a
--- /dev/null
+++ b/core/http/endpoints/anthropic/messages.go
@@ -0,0 +1,537 @@
+package anthropic
+
+import (
+ "encoding/json"
+ "fmt"
+
+ "github.com/google/uuid"
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/core/templates"
+ "github.com/mudler/LocalAI/pkg/functions"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/xlog"
+)
+
+// MessagesEndpoint is the Anthropic Messages API endpoint
+// https://docs.anthropic.com/claude/reference/messages_post
+// @Summary Generate a message response for the given messages and model.
+// @Param request body schema.AnthropicRequest true "query params"
+// @Success 200 {object} schema.AnthropicResponse "Response"
+// @Router /v1/messages [post]
+func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ id := uuid.New().String()
+
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.AnthropicRequest)
+ if !ok || input.Model == "" {
+ return sendAnthropicError(c, 400, "invalid_request_error", "model is required")
+ }
+
+ cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || cfg == nil {
+ return sendAnthropicError(c, 400, "invalid_request_error", "model configuration not found")
+ }
+
+ if input.MaxTokens <= 0 {
+ return sendAnthropicError(c, 400, "invalid_request_error", "max_tokens is required and must be greater than 0")
+ }
+
+ xlog.Debug("Anthropic Messages endpoint configuration read", "config", cfg)
+
+ // Convert Anthropic messages to OpenAI format for internal processing
+ openAIMessages := convertAnthropicToOpenAIMessages(input)
+
+ // Convert Anthropic tools to internal Functions format
+ funcs, shouldUseFn := convertAnthropicTools(input, cfg)
+
+ // Create an OpenAI-compatible request for internal processing
+ openAIReq := &schema.OpenAIRequest{
+ PredictionOptions: schema.PredictionOptions{
+ BasicModelRequest: schema.BasicModelRequest{Model: input.Model},
+ Temperature: input.Temperature,
+ TopK: input.TopK,
+ TopP: input.TopP,
+ Maxtokens: &input.MaxTokens,
+ },
+ Messages: openAIMessages,
+ Stream: input.Stream,
+ Context: input.Context,
+ Cancel: input.Cancel,
+ }
+
+ // Set stop sequences
+ if len(input.StopSequences) > 0 {
+ openAIReq.Stop = input.StopSequences
+ }
+
+ // Merge config settings
+ if input.Temperature != nil {
+ cfg.Temperature = input.Temperature
+ }
+ if input.TopK != nil {
+ cfg.TopK = input.TopK
+ }
+ if input.TopP != nil {
+ cfg.TopP = input.TopP
+ }
+ cfg.Maxtokens = &input.MaxTokens
+ if len(input.StopSequences) > 0 {
+ cfg.StopWords = append(cfg.StopWords, input.StopSequences...)
+ }
+
+ // Template the prompt with tools if available
+ predInput := evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn)
+ xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput)
+
+ if input.Stream {
+ return handleAnthropicStream(c, id, input, cfg, ml, predInput, openAIReq, funcs, shouldUseFn)
+ }
+
+ return handleAnthropicNonStream(c, id, input, cfg, ml, predInput, openAIReq, funcs, shouldUseFn)
+ }
+}
+
+func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool) error {
+ images := []string{}
+ for _, m := range openAIReq.Messages {
+ images = append(images, m.StringImages...)
+ }
+
+ predFunc, err := backend.ModelInference(
+ input.Context, predInput, openAIReq.Messages, images, nil, nil, ml, cfg, nil, nil, nil, "", "", nil, nil, nil)
+ if err != nil {
+ xlog.Error("Anthropic model inference failed", "error", err)
+ return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err))
+ }
+
+ prediction, err := predFunc()
+ if err != nil {
+ xlog.Error("Anthropic prediction failed", "error", err)
+ return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("prediction failed: %v", err))
+ }
+
+ result := backend.Finetune(*cfg, predInput, prediction.Response)
+
+ // Check if the result contains tool calls
+ toolCalls := functions.ParseFunctionCall(result, cfg.FunctionsConfig)
+
+ var contentBlocks []schema.AnthropicContentBlock
+ var stopReason string
+
+ if shouldUseFn && len(toolCalls) > 0 {
+ // Model wants to use tools
+ stopReason = "tool_use"
+ for _, tc := range toolCalls {
+ // Parse arguments as JSON
+ var inputArgs map[string]interface{}
+ if err := json.Unmarshal([]byte(tc.Arguments), &inputArgs); err != nil {
+ xlog.Warn("Failed to parse tool call arguments as JSON", "error", err, "args", tc.Arguments)
+ inputArgs = map[string]interface{}{"raw": tc.Arguments}
+ }
+
+ contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{
+ Type: "tool_use",
+ ID: fmt.Sprintf("toolu_%s_%d", id, len(contentBlocks)),
+ Name: tc.Name,
+ Input: inputArgs,
+ })
+ }
+
+ // Add any text content before the tool calls
+ textContent := functions.ParseTextContent(result, cfg.FunctionsConfig)
+ if textContent != "" {
+ // Prepend text block
+ contentBlocks = append([]schema.AnthropicContentBlock{{Type: "text", Text: textContent}}, contentBlocks...)
+ }
+ } else {
+ // Normal text response
+ stopReason = "end_turn"
+ contentBlocks = []schema.AnthropicContentBlock{
+ {Type: "text", Text: result},
+ }
+ }
+
+ resp := &schema.AnthropicResponse{
+ ID: fmt.Sprintf("msg_%s", id),
+ Type: "message",
+ Role: "assistant",
+ Model: input.Model,
+ StopReason: &stopReason,
+ Content: contentBlocks,
+ Usage: schema.AnthropicUsage{
+ InputTokens: prediction.Usage.Prompt,
+ OutputTokens: prediction.Usage.Completion,
+ },
+ }
+
+ if respData, err := json.Marshal(resp); err == nil {
+ xlog.Debug("Anthropic Response", "response", string(respData))
+ }
+
+ return c.JSON(200, resp)
+}
+
+func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool) error {
+ c.Response().Header().Set("Content-Type", "text/event-stream")
+ c.Response().Header().Set("Cache-Control", "no-cache")
+ c.Response().Header().Set("Connection", "keep-alive")
+
+ // Create OpenAI messages for inference
+ openAIMessages := openAIReq.Messages
+
+ images := []string{}
+ for _, m := range openAIMessages {
+ images = append(images, m.StringImages...)
+ }
+
+ // Send message_start event
+ messageStart := schema.AnthropicStreamEvent{
+ Type: "message_start",
+ Message: &schema.AnthropicStreamMessage{
+ ID: fmt.Sprintf("msg_%s", id),
+ Type: "message",
+ Role: "assistant",
+ Content: []schema.AnthropicContentBlock{},
+ Model: input.Model,
+ Usage: schema.AnthropicUsage{InputTokens: 0, OutputTokens: 0},
+ },
+ }
+ sendAnthropicSSE(c, messageStart)
+
+ // Track accumulated content for tool call detection
+ accumulatedContent := ""
+ currentBlockIndex := 0
+ inToolCall := false
+ toolCallsEmitted := 0
+
+ // Send initial content_block_start event
+ contentBlockStart := schema.AnthropicStreamEvent{
+ Type: "content_block_start",
+ Index: currentBlockIndex,
+ ContentBlock: &schema.AnthropicContentBlock{Type: "text", Text: ""},
+ }
+ sendAnthropicSSE(c, contentBlockStart)
+
+ // Stream content deltas
+ tokenCallback := func(token string, usage backend.TokenUsage) bool {
+ accumulatedContent += token
+
+ // If we're using functions, try to detect tool calls incrementally
+ if shouldUseFn {
+ cleanedResult := functions.CleanupLLMResult(accumulatedContent, cfg.FunctionsConfig)
+
+ // Try parsing for tool calls
+ toolCalls := functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
+
+ // If we detected new tool calls and haven't emitted them yet
+ if len(toolCalls) > toolCallsEmitted {
+ // Stop the current text block if we were in one
+ if !inToolCall && currentBlockIndex == 0 {
+ sendAnthropicSSE(c, schema.AnthropicStreamEvent{
+ Type: "content_block_stop",
+ Index: currentBlockIndex,
+ })
+ currentBlockIndex++
+ inToolCall = true
+ }
+
+ // Emit new tool calls
+ for i := toolCallsEmitted; i < len(toolCalls); i++ {
+ tc := toolCalls[i]
+
+ // Send content_block_start for tool_use
+ sendAnthropicSSE(c, schema.AnthropicStreamEvent{
+ Type: "content_block_start",
+ Index: currentBlockIndex,
+ ContentBlock: &schema.AnthropicContentBlock{
+ Type: "tool_use",
+ ID: fmt.Sprintf("toolu_%s_%d", id, i),
+ Name: tc.Name,
+ },
+ })
+
+ // Send input_json_delta with the arguments
+ sendAnthropicSSE(c, schema.AnthropicStreamEvent{
+ Type: "content_block_delta",
+ Index: currentBlockIndex,
+ Delta: &schema.AnthropicStreamDelta{
+ Type: "input_json_delta",
+ PartialJSON: tc.Arguments,
+ },
+ })
+
+ // Send content_block_stop
+ sendAnthropicSSE(c, schema.AnthropicStreamEvent{
+ Type: "content_block_stop",
+ Index: currentBlockIndex,
+ })
+
+ currentBlockIndex++
+ }
+ toolCallsEmitted = len(toolCalls)
+ return true
+ }
+ }
+
+ // Send regular text delta if not in tool call mode
+ if !inToolCall {
+ delta := schema.AnthropicStreamEvent{
+ Type: "content_block_delta",
+ Index: 0,
+ Delta: &schema.AnthropicStreamDelta{
+ Type: "text_delta",
+ Text: token,
+ },
+ }
+ sendAnthropicSSE(c, delta)
+ }
+ return true
+ }
+
+ predFunc, err := backend.ModelInference(
+ input.Context, predInput, openAIMessages, images, nil, nil, ml, cfg, nil, nil, tokenCallback, "", "", nil, nil, nil)
+ if err != nil {
+ xlog.Error("Anthropic stream model inference failed", "error", err)
+ return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err))
+ }
+
+ prediction, err := predFunc()
+ if err != nil {
+ xlog.Error("Anthropic stream prediction failed", "error", err)
+ return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("prediction failed: %v", err))
+ }
+
+ // Send content_block_stop event for last block if we didn't close it yet
+ if !inToolCall {
+ contentBlockStop := schema.AnthropicStreamEvent{
+ Type: "content_block_stop",
+ Index: 0,
+ }
+ sendAnthropicSSE(c, contentBlockStop)
+ }
+
+ // Determine stop reason
+ stopReason := "end_turn"
+ if toolCallsEmitted > 0 {
+ stopReason = "tool_use"
+ }
+
+ // Send message_delta event with stop_reason
+ messageDelta := schema.AnthropicStreamEvent{
+ Type: "message_delta",
+ Delta: &schema.AnthropicStreamDelta{
+ StopReason: &stopReason,
+ },
+ Usage: &schema.AnthropicUsage{
+ OutputTokens: prediction.Usage.Completion,
+ },
+ }
+ sendAnthropicSSE(c, messageDelta)
+
+ // Send message_stop event
+ messageStop := schema.AnthropicStreamEvent{
+ Type: "message_stop",
+ }
+ sendAnthropicSSE(c, messageStop)
+
+ return nil
+}
+
+func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) {
+ data, err := json.Marshal(event)
+ if err != nil {
+ xlog.Error("Failed to marshal SSE event", "error", err)
+ return
+ }
+ fmt.Fprintf(c.Response().Writer, "event: %s\ndata: %s\n\n", event.Type, string(data))
+ c.Response().Flush()
+}
+
+func sendAnthropicError(c echo.Context, statusCode int, errorType, message string) error {
+ resp := schema.AnthropicErrorResponse{
+ Type: "error",
+ Error: schema.AnthropicError{
+ Type: errorType,
+ Message: message,
+ },
+ }
+ return c.JSON(statusCode, resp)
+}
+
+func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.Message {
+ var messages []schema.Message
+
+ // Add system message if present
+ if input.System != "" {
+ messages = append(messages, schema.Message{
+ Role: "system",
+ StringContent: input.System,
+ Content: input.System,
+ })
+ }
+
+ // Convert Anthropic messages to OpenAI format
+ for _, msg := range input.Messages {
+ openAIMsg := schema.Message{
+ Role: msg.Role,
+ }
+
+ // Handle content (can be string or array of content blocks)
+ switch content := msg.Content.(type) {
+ case string:
+ openAIMsg.StringContent = content
+ openAIMsg.Content = content
+ case []interface{}:
+ // Handle array of content blocks
+ var textContent string
+ var stringImages []string
+ var toolCalls []schema.ToolCall
+ toolCallIndex := 0
+
+ for _, block := range content {
+ if blockMap, ok := block.(map[string]interface{}); ok {
+ blockType, _ := blockMap["type"].(string)
+ switch blockType {
+ case "text":
+ if text, ok := blockMap["text"].(string); ok {
+ textContent += text
+ }
+ case "image":
+ // Handle image content
+ if source, ok := blockMap["source"].(map[string]interface{}); ok {
+ if sourceType, ok := source["type"].(string); ok && sourceType == "base64" {
+ if data, ok := source["data"].(string); ok {
+ mediaType, _ := source["media_type"].(string)
+ // Format as data URI
+ dataURI := fmt.Sprintf("data:%s;base64,%s", mediaType, data)
+ stringImages = append(stringImages, dataURI)
+ }
+ }
+ }
+ case "tool_use":
+ // Convert tool_use to ToolCall format
+ toolID, _ := blockMap["id"].(string)
+ toolName, _ := blockMap["name"].(string)
+ toolInput := blockMap["input"]
+
+ // Serialize input to JSON string
+ inputJSON, err := json.Marshal(toolInput)
+ if err != nil {
+ xlog.Warn("Failed to marshal tool input", "error", err)
+ inputJSON = []byte("{}")
+ }
+
+ toolCalls = append(toolCalls, schema.ToolCall{
+ Index: toolCallIndex,
+ ID: toolID,
+ Type: "function",
+ FunctionCall: schema.FunctionCall{
+ Name: toolName,
+ Arguments: string(inputJSON),
+ },
+ })
+ toolCallIndex++
+ case "tool_result":
+ // Convert tool_result to a message with role "tool"
+ // This is handled by creating a separate message after this block
+ // For now, we'll add it as text content
+ toolUseID, _ := blockMap["tool_use_id"].(string)
+ isError := false
+ if isErrorPtr, ok := blockMap["is_error"].(*bool); ok && isErrorPtr != nil {
+ isError = *isErrorPtr
+ }
+
+ var resultText string
+ if resultContent, ok := blockMap["content"]; ok {
+ switch rc := resultContent.(type) {
+ case string:
+ resultText = rc
+ case []interface{}:
+ // Array of content blocks
+ for _, cb := range rc {
+ if cbMap, ok := cb.(map[string]interface{}); ok {
+ if cbMap["type"] == "text" {
+ if text, ok := cbMap["text"].(string); ok {
+ resultText += text
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // Add tool result as a tool role message
+ // We need to handle this differently - create a new message
+ if msg.Role == "user" {
+ // Store tool result info for creating separate message
+ prefix := ""
+ if isError {
+ prefix = "Error: "
+ }
+ textContent += fmt.Sprintf("\n[Tool Result for %s]: %s%s", toolUseID, prefix, resultText)
+ }
+ }
+ }
+ }
+ openAIMsg.StringContent = textContent
+ openAIMsg.Content = textContent
+ openAIMsg.StringImages = stringImages
+
+ // Add tool calls if present
+ if len(toolCalls) > 0 {
+ openAIMsg.ToolCalls = toolCalls
+ }
+ }
+
+ messages = append(messages, openAIMsg)
+ }
+
+ return messages
+}
+
+// convertAnthropicTools converts Anthropic tools to internal Functions format
+func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConfig) (functions.Functions, bool) {
+ if len(input.Tools) == 0 {
+ return nil, false
+ }
+
+ var funcs functions.Functions
+ for _, tool := range input.Tools {
+ f := functions.Function{
+ Name: tool.Name,
+ Description: tool.Description,
+ Parameters: tool.InputSchema,
+ }
+ funcs = append(funcs, f)
+ }
+
+ // Handle tool_choice
+ if input.ToolChoice != nil {
+ switch tc := input.ToolChoice.(type) {
+ case string:
+ // "auto", "any", or "none"
+ if tc == "any" {
+ // Force the model to use one of the tools
+ cfg.SetFunctionCallString("required")
+ } else if tc == "none" {
+ // Don't use tools
+ return nil, false
+ }
+ // "auto" is the default - let model decide
+ case map[string]interface{}:
+ // Specific tool selection: {"type": "tool", "name": "tool_name"}
+ if tcType, ok := tc["type"].(string); ok && tcType == "tool" {
+ if name, ok := tc["name"].(string); ok {
+ // Force specific tool
+ cfg.SetFunctionCallString(name)
+ }
+ }
+ }
+ }
+
+ return funcs, len(funcs) > 0 && cfg.ShouldUseFunctions()
+}
diff --git a/core/http/endpoints/elevenlabs/soundgeneration.go b/core/http/endpoints/elevenlabs/soundgeneration.go
new file mode 100644
index 0000000000000000000000000000000000000000..d292b81cd5b8f0101d82cbac7ba1478f613646e4
--- /dev/null
+++ b/core/http/endpoints/elevenlabs/soundgeneration.go
@@ -0,0 +1,43 @@
+package elevenlabs
+
+import (
+ "path/filepath"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/xlog"
+)
+
+// SoundGenerationEndpoint is the ElevenLabs SoundGeneration endpoint https://elevenlabs.io/docs/api-reference/sound-generation
+// @Summary Generates audio from the input text.
+// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
+// @Success 200 {string} binary "Response"
+// @Router /v1/sound-generation [post]
+func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
+ if !ok || input.ModelID == "" {
+ return echo.ErrBadRequest
+ }
+
+ cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || cfg == nil {
+ return echo.ErrBadRequest
+ }
+
+ xlog.Debug("Sound Generation Request about to be sent to backend", "modelFile", "modelFile", "backend", cfg.Backend)
+
+ // TODO: Support uploading files?
+ filePath, _, err := backend.SoundGeneration(input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
+ if err != nil {
+ return err
+ }
+ return c.Attachment(filePath, filepath.Base(filePath))
+
+ }
+}
diff --git a/core/http/endpoints/elevenlabs/tts.go b/core/http/endpoints/elevenlabs/tts.go
new file mode 100644
index 0000000000000000000000000000000000000000..658eb56baa19df7cc6a696d0eeb2f753bb593581
--- /dev/null
+++ b/core/http/endpoints/elevenlabs/tts.go
@@ -0,0 +1,44 @@
+package elevenlabs
+
+import (
+ "path/filepath"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/xlog"
+)
+
+// TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech
+// @Summary Generates audio from the input text.
+// @Param voice-id path string true "Account ID"
+// @Param request body schema.TTSRequest true "query params"
+// @Success 200 {string} binary "Response"
+// @Router /v1/text-to-speech/{voice-id} [post]
+func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+
+ voiceID := c.Param("voice-id")
+
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
+ if !ok || input.ModelID == "" {
+ return echo.ErrBadRequest
+ }
+
+ cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || cfg == nil {
+ return echo.ErrBadRequest
+ }
+
+ xlog.Debug("elevenlabs TTS request received", "modelName", input.ModelID)
+
+ filePath, _, err := backend.ModelTTS(input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
+ if err != nil {
+ return err
+ }
+ return c.Attachment(filePath, filepath.Base(filePath))
+ }
+}
diff --git a/core/http/endpoints/explorer/dashboard.go b/core/http/endpoints/explorer/dashboard.go
new file mode 100644
index 0000000000000000000000000000000000000000..3c1e0ae913377e013da9f45476256be5537ee03d
--- /dev/null
+++ b/core/http/endpoints/explorer/dashboard.go
@@ -0,0 +1,108 @@
+package explorer
+
+import (
+ "encoding/base64"
+ "net/http"
+ "sort"
+ "strings"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/explorer"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/internal"
+)
+
+func Dashboard() echo.HandlerFunc {
+ return func(c echo.Context) error {
+ summary := map[string]interface{}{
+ "Title": "LocalAI API - " + internal.PrintableVersion(),
+ "Version": internal.PrintableVersion(),
+ "BaseURL": middleware.BaseURL(c),
+ }
+
+ contentType := c.Request().Header.Get("Content-Type")
+ accept := c.Request().Header.Get("Accept")
+ if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "html")) {
+ // The client expects a JSON response
+ return c.JSON(http.StatusOK, summary)
+ } else {
+ // Render index
+ return c.Render(http.StatusOK, "views/explorer", summary)
+ }
+ }
+}
+
+type AddNetworkRequest struct {
+ Token string `json:"token"`
+ Name string `json:"name"`
+ Description string `json:"description"`
+}
+
+type Network struct {
+ explorer.TokenData
+ Token string `json:"token"`
+}
+
+func ShowNetworks(db *explorer.Database) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ results := []Network{}
+ for _, token := range db.TokenList() {
+ networkData, exists := db.Get(token) // get the token data
+ hasWorkers := false
+ for _, cluster := range networkData.Clusters {
+ if len(cluster.Workers) > 0 {
+ hasWorkers = true
+ break
+ }
+ }
+ if exists && hasWorkers {
+ results = append(results, Network{TokenData: networkData, Token: token})
+ }
+ }
+
+ // order by number of clusters
+ sort.Slice(results, func(i, j int) bool {
+ return len(results[i].Clusters) > len(results[j].Clusters)
+ })
+
+ return c.JSON(http.StatusOK, results)
+ }
+}
+
+func AddNetwork(db *explorer.Database) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ request := new(AddNetworkRequest)
+ if err := c.Bind(request); err != nil {
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"})
+ }
+
+ if request.Token == "" {
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"})
+ }
+
+ if request.Name == "" {
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"})
+ }
+
+ if request.Description == "" {
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"})
+ }
+
+ // TODO: check if token is valid, otherwise reject
+ // try to decode the token from base64
+ _, err := base64.StdEncoding.DecodeString(request.Token)
+ if err != nil {
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"})
+ }
+
+ if _, exists := db.Get(request.Token); exists {
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"})
+ }
+ err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description})
+ if err != nil {
+ return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"})
+ }
+
+ return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"})
+ }
+}
diff --git a/core/http/endpoints/jina/rerank.go b/core/http/endpoints/jina/rerank.go
new file mode 100644
index 0000000000000000000000000000000000000000..330fb94a4396e353c9297c1e5d91f5beb8d8d544
--- /dev/null
+++ b/core/http/endpoints/jina/rerank.go
@@ -0,0 +1,76 @@
+package jina
+
+import (
+ "net/http"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/grpc/proto"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/xlog"
+)
+
+// JINARerankEndpoint acts like the Jina reranker endpoint (https://jina.ai/reranker/)
+// @Summary Reranks a list of phrases by relevance to a given text query.
+// @Param request body schema.JINARerankRequest true "query params"
+// @Success 200 {object} schema.JINARerankResponse "Response"
+// @Router /v1/rerank [post]
+func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
+ if !ok || input.Model == "" {
+ return echo.ErrBadRequest
+ }
+
+ cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || cfg == nil {
+ return echo.ErrBadRequest
+ }
+
+ xlog.Debug("JINA Rerank Request received", "model", input.Model)
+ var requestTopN int32
+ docs := int32(len(input.Documents))
+ if input.TopN == nil { // omit top_n to get all
+ requestTopN = docs
+ } else {
+ requestTopN = int32(*input.TopN)
+ if requestTopN < 1 {
+ return c.JSON(http.StatusUnprocessableEntity, "top_n - should be greater than or equal to 1")
+ }
+ if requestTopN > docs { // make it more obvious for backends
+ requestTopN = docs
+ }
+ }
+ request := &proto.RerankRequest{
+ Query: input.Query,
+ TopN: requestTopN,
+ Documents: input.Documents,
+ }
+
+ results, err := backend.Rerank(request, ml, appConfig, *cfg)
+ if err != nil {
+ return err
+ }
+
+ response := &schema.JINARerankResponse{
+ Model: input.Model,
+ }
+
+ for _, r := range results.Results {
+ response.Results = append(response.Results, schema.JINADocumentResult{
+ Index: int(r.Index),
+ Document: schema.JINAText{Text: r.Text},
+ RelevanceScore: float64(r.RelevanceScore),
+ })
+ }
+
+ response.Usage.TotalTokens = int(results.Usage.TotalTokens)
+ response.Usage.PromptTokens = int(results.Usage.PromptTokens)
+
+ return c.JSON(http.StatusOK, response)
+ }
+}
diff --git a/core/http/endpoints/localai/agent_jobs.go b/core/http/endpoints/localai/agent_jobs.go
new file mode 100644
index 0000000000000000000000000000000000000000..c46a0208a10f38334fab8c28e9711015c0ee837d
--- /dev/null
+++ b/core/http/endpoints/localai/agent_jobs.go
@@ -0,0 +1,349 @@
+package localai
+
+import (
+ "fmt"
+ "net/http"
+ "strconv"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/application"
+ "github.com/mudler/LocalAI/core/schema"
+)
+
+// CreateTaskEndpoint creates a new agent task
+// @Summary Create a new agent task
+// @Description Create a new reusable agent task with prompt template and configuration
+// @Tags agent-jobs
+// @Accept json
+// @Produce json
+// @Param task body schema.Task true "Task definition"
+// @Success 201 {object} map[string]string "Task created"
+// @Failure 400 {object} map[string]string "Invalid request"
+// @Failure 500 {object} map[string]string "Internal server error"
+// @Router /api/agent/tasks [post]
+func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ var task schema.Task
+ if err := c.Bind(&task); err != nil {
+ return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request body: " + err.Error()})
+ }
+
+ id, err := app.AgentJobService().CreateTask(task)
+ if err != nil {
+ return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
+ }
+
+ return c.JSON(http.StatusCreated, map[string]string{"id": id})
+ }
+}
+
+// UpdateTaskEndpoint updates an existing task
+// @Summary Update an agent task
+// @Description Update an existing agent task
+// @Tags agent-jobs
+// @Accept json
+// @Produce json
+// @Param id path string true "Task ID"
+// @Param task body schema.Task true "Updated task definition"
+// @Success 200 {object} map[string]string "Task updated"
+// @Failure 400 {object} map[string]string "Invalid request"
+// @Failure 404 {object} map[string]string "Task not found"
+// @Router /api/agent/tasks/{id} [put]
+func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ id := c.Param("id")
+ var task schema.Task
+ if err := c.Bind(&task); err != nil {
+ return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request body: " + err.Error()})
+ }
+
+ if err := app.AgentJobService().UpdateTask(id, task); err != nil {
+ if err.Error() == "task not found: "+id {
+ return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
+ }
+ return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
+ }
+
+ return c.JSON(http.StatusOK, map[string]string{"message": "Task updated"})
+ }
+}
+
+// DeleteTaskEndpoint deletes a task
+// @Summary Delete an agent task
+// @Description Delete an agent task by ID
+// @Tags agent-jobs
+// @Produce json
+// @Param id path string true "Task ID"
+// @Success 200 {object} map[string]string "Task deleted"
+// @Failure 404 {object} map[string]string "Task not found"
+// @Router /api/agent/tasks/{id} [delete]
+func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ id := c.Param("id")
+ if err := app.AgentJobService().DeleteTask(id); err != nil {
+ if err.Error() == "task not found: "+id {
+ return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
+ }
+ return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
+ }
+
+ return c.JSON(http.StatusOK, map[string]string{"message": "Task deleted"})
+ }
+}
+
+// ListTasksEndpoint lists all tasks
+// @Summary List all agent tasks
+// @Description Get a list of all agent tasks
+// @Tags agent-jobs
+// @Produce json
+// @Success 200 {array} schema.Task "List of tasks"
+// @Router /api/agent/tasks [get]
+func ListTasksEndpoint(app *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ tasks := app.AgentJobService().ListTasks()
+ return c.JSON(http.StatusOK, tasks)
+ }
+}
+
+// GetTaskEndpoint gets a task by ID
+// @Summary Get an agent task
+// @Description Get an agent task by ID
+// @Tags agent-jobs
+// @Produce json
+// @Param id path string true "Task ID"
+// @Success 200 {object} schema.Task "Task details"
+// @Failure 404 {object} map[string]string "Task not found"
+// @Router /api/agent/tasks/{id} [get]
+func GetTaskEndpoint(app *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ id := c.Param("id")
+ task, err := app.AgentJobService().GetTask(id)
+ if err != nil {
+ return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
+ }
+
+ return c.JSON(http.StatusOK, task)
+ }
+}
+
+// ExecuteJobEndpoint executes a job
+// @Summary Execute an agent job
+// @Description Create and execute a new agent job
+// @Tags agent-jobs
+// @Accept json
+// @Produce json
+// @Param request body schema.JobExecutionRequest true "Job execution request"
+// @Success 201 {object} schema.JobExecutionResponse "Job created"
+// @Failure 400 {object} map[string]string "Invalid request"
+// @Router /api/agent/jobs/execute [post]
+func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ var req schema.JobExecutionRequest
+ if err := c.Bind(&req); err != nil {
+ return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request body: " + err.Error()})
+ }
+
+ if req.Parameters == nil {
+ req.Parameters = make(map[string]string)
+ }
+
+ // Build multimedia struct from request
+ var multimedia *schema.MultimediaAttachment
+ if len(req.Images) > 0 || len(req.Videos) > 0 || len(req.Audios) > 0 || len(req.Files) > 0 {
+ multimedia = &schema.MultimediaAttachment{
+ Images: req.Images,
+ Videos: req.Videos,
+ Audios: req.Audios,
+ Files: req.Files,
+ }
+ }
+
+ jobID, err := app.AgentJobService().ExecuteJob(req.TaskID, req.Parameters, "api", multimedia)
+ if err != nil {
+ return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
+ }
+
+ baseURL := c.Scheme() + "://" + c.Request().Host
+ return c.JSON(http.StatusCreated, schema.JobExecutionResponse{
+ JobID: jobID,
+ Status: "pending",
+ URL: baseURL + "/api/agent/jobs/" + jobID,
+ })
+ }
+}
+
+// GetJobEndpoint gets a job by ID
+// @Summary Get an agent job
+// @Description Get an agent job by ID
+// @Tags agent-jobs
+// @Produce json
+// @Param id path string true "Job ID"
+// @Success 200 {object} schema.Job "Job details"
+// @Failure 404 {object} map[string]string "Job not found"
+// @Router /api/agent/jobs/{id} [get]
+func GetJobEndpoint(app *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ id := c.Param("id")
+ job, err := app.AgentJobService().GetJob(id)
+ if err != nil {
+ return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
+ }
+
+ return c.JSON(http.StatusOK, job)
+ }
+}
+
+// ListJobsEndpoint lists jobs with optional filtering
+// @Summary List agent jobs
+// @Description Get a list of agent jobs, optionally filtered by task_id and status
+// @Tags agent-jobs
+// @Produce json
+// @Param task_id query string false "Filter by task ID"
+// @Param status query string false "Filter by status (pending, running, completed, failed, cancelled)"
+// @Param limit query int false "Limit number of results"
+// @Success 200 {array} schema.Job "List of jobs"
+// @Router /api/agent/jobs [get]
+func ListJobsEndpoint(app *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ var taskID *string
+ var status *schema.JobStatus
+ limit := 0
+
+ if taskIDParam := c.QueryParam("task_id"); taskIDParam != "" {
+ taskID = &taskIDParam
+ }
+
+ if statusParam := c.QueryParam("status"); statusParam != "" {
+ s := schema.JobStatus(statusParam)
+ status = &s
+ }
+
+ if limitParam := c.QueryParam("limit"); limitParam != "" {
+ if l, err := strconv.Atoi(limitParam); err == nil {
+ limit = l
+ }
+ }
+
+ jobs := app.AgentJobService().ListJobs(taskID, status, limit)
+ return c.JSON(http.StatusOK, jobs)
+ }
+}
+
+// CancelJobEndpoint cancels a running job
+// @Summary Cancel an agent job
+// @Description Cancel a running or pending agent job
+// @Tags agent-jobs
+// @Produce json
+// @Param id path string true "Job ID"
+// @Success 200 {object} map[string]string "Job cancelled"
+// @Failure 400 {object} map[string]string "Job cannot be cancelled"
+// @Failure 404 {object} map[string]string "Job not found"
+// @Router /api/agent/jobs/{id}/cancel [post]
+func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ id := c.Param("id")
+ if err := app.AgentJobService().CancelJob(id); err != nil {
+ if err.Error() == "job not found: "+id {
+ return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
+ }
+ return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
+ }
+
+ return c.JSON(http.StatusOK, map[string]string{"message": "Job cancelled"})
+ }
+}
+
+// DeleteJobEndpoint deletes a job
+// @Summary Delete an agent job
+// @Description Delete an agent job by ID
+// @Tags agent-jobs
+// @Produce json
+// @Param id path string true "Job ID"
+// @Success 200 {object} map[string]string "Job deleted"
+// @Failure 404 {object} map[string]string "Job not found"
+// @Router /api/agent/jobs/{id} [delete]
+func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ id := c.Param("id")
+ if err := app.AgentJobService().DeleteJob(id); err != nil {
+ if err.Error() == "job not found: "+id {
+ return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
+ }
+ return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
+ }
+
+ return c.JSON(http.StatusOK, map[string]string{"message": "Job deleted"})
+ }
+}
+
+// ExecuteTaskByNameEndpoint executes a task by name
+// @Summary Execute a task by name
+// @Description Execute an agent task by its name (convenience endpoint). Parameters can be provided in the request body as a JSON object with string values.
+// @Tags agent-jobs
+// @Accept json
+// @Produce json
+// @Param name path string true "Task name"
+// @Param request body map[string]string false "Template parameters (JSON object with string values)"
+// @Success 201 {object} schema.JobExecutionResponse "Job created"
+// @Failure 400 {object} map[string]string "Invalid request"
+// @Failure 404 {object} map[string]string "Task not found"
+// @Router /api/agent/tasks/{name}/execute [post]
+func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ name := c.Param("name")
+ var params map[string]string
+
+ // Try to bind parameters from request body
+ // If body is empty or invalid, use empty params
+ if c.Request().ContentLength > 0 {
+ if err := c.Bind(¶ms); err != nil {
+ // If binding fails, try to read as raw JSON
+ body := make(map[string]interface{})
+ if err := c.Bind(&body); err == nil {
+ // Convert interface{} values to strings
+ params = make(map[string]string)
+ for k, v := range body {
+ if str, ok := v.(string); ok {
+ params[k] = str
+ } else {
+ // Convert non-string values to string
+ params[k] = fmt.Sprintf("%v", v)
+ }
+ }
+ } else {
+ // If all binding fails, use empty params
+ params = make(map[string]string)
+ }
+ }
+ } else {
+ // No body provided, use empty params
+ params = make(map[string]string)
+ }
+
+ // Find task by name
+ tasks := app.AgentJobService().ListTasks()
+ var task *schema.Task
+ for _, t := range tasks {
+ if t.Name == name {
+ task = &t
+ break
+ }
+ }
+
+ if task == nil {
+ return c.JSON(http.StatusNotFound, map[string]string{"error": "Task not found: " + name})
+ }
+
+ jobID, err := app.AgentJobService().ExecuteJob(task.ID, params, "api", nil)
+ if err != nil {
+ return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
+ }
+
+ baseURL := c.Scheme() + "://" + c.Request().Host
+ return c.JSON(http.StatusCreated, schema.JobExecutionResponse{
+ JobID: jobID,
+ Status: "pending",
+ URL: baseURL + "/api/agent/jobs/" + jobID,
+ })
+ }
+}
diff --git a/core/http/endpoints/localai/backend.go b/core/http/endpoints/localai/backend.go
new file mode 100644
index 0000000000000000000000000000000000000000..f804f1b35c73c626733256aed01be5b166c83b92
--- /dev/null
+++ b/core/http/endpoints/localai/backend.go
@@ -0,0 +1,155 @@
+package localai
+
+import (
+ "encoding/json"
+ "fmt"
+
+ "github.com/google/uuid"
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/core/services"
+ "github.com/mudler/LocalAI/pkg/system"
+ "github.com/mudler/xlog"
+)
+
+type BackendEndpointService struct {
+ galleries []config.Gallery
+ backendPath string
+ backendSystemPath string
+ backendApplier *services.GalleryService
+}
+
+type GalleryBackend struct {
+ ID string `json:"id"`
+}
+
+func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *services.GalleryService) BackendEndpointService {
+ return BackendEndpointService{
+ galleries: galleries,
+ backendPath: systemState.Backend.BackendsPath,
+ backendSystemPath: systemState.Backend.BackendsSystemPath,
+ backendApplier: backendApplier,
+ }
+}
+
+// GetOpStatusEndpoint returns the job status
+// @Summary Returns the job status
+// @Success 200 {object} services.GalleryOpStatus "Response"
+// @Router /backends/jobs/{uuid} [get]
+func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
+ return func(c echo.Context) error {
+ status := mgs.backendApplier.GetStatus(c.Param("uuid"))
+ if status == nil {
+ return fmt.Errorf("could not find any status for ID")
+ }
+ return c.JSON(200, status)
+ }
+}
+
+// GetAllStatusEndpoint returns all the jobs status progress
+// @Summary Returns all the jobs status progress
+// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
+// @Router /backends/jobs [get]
+func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
+ return func(c echo.Context) error {
+ return c.JSON(200, mgs.backendApplier.GetAllStatus())
+ }
+}
+
+// ApplyBackendEndpoint installs a new backend to a LocalAI instance
+// @Summary Install backends to LocalAI.
+// @Param request body GalleryBackend true "query params"
+// @Success 200 {object} schema.BackendResponse "Response"
+// @Router /backends/apply [post]
+func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input := new(GalleryBackend)
+ // Get input data from the request body
+ if err := c.Bind(input); err != nil {
+ return err
+ }
+
+ uuid, err := uuid.NewUUID()
+ if err != nil {
+ return err
+ }
+ mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{
+ ID: uuid.String(),
+ GalleryElementName: input.ID,
+ Galleries: mgs.galleries,
+ }
+
+ return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
+ }
+}
+
+// DeleteBackendEndpoint lets delete backends from a LocalAI instance
+// @Summary delete backends from LocalAI.
+// @Param name path string true "Backend name"
+// @Success 200 {object} schema.BackendResponse "Response"
+// @Router /backends/delete/{name} [post]
+func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
+ return func(c echo.Context) error {
+ backendName := c.Param("name")
+
+ mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{
+ Delete: true,
+ GalleryElementName: backendName,
+ Galleries: mgs.galleries,
+ }
+
+ uuid, err := uuid.NewUUID()
+ if err != nil {
+ return err
+ }
+
+ return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
+ }
+}
+
+// ListBackendsEndpoint list the available backends configured in LocalAI
+// @Summary List all Backends
+// @Success 200 {object} []gallery.GalleryBackend "Response"
+// @Router /backends [get]
+func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ backends, err := gallery.ListSystemBackends(systemState)
+ if err != nil {
+ return err
+ }
+ return c.JSON(200, backends.GetAll())
+ }
+}
+
+// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
+// @Summary List all Galleries
+// @Success 200 {object} []config.Gallery "Response"
+// @Router /backends/galleries [get]
+// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
+func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFunc {
+ return func(c echo.Context) error {
+ xlog.Debug("Listing backend galleries", "galleries", mgs.galleries)
+ dat, err := json.Marshal(mgs.galleries)
+ if err != nil {
+ return err
+ }
+ return c.Blob(200, "application/json", dat)
+ }
+}
+
+// ListAvailableBackendsEndpoint list the available backends in the galleries configured in LocalAI
+// @Summary List all available Backends
+// @Success 200 {object} []gallery.GalleryBackend "Response"
+// @Router /backends/available [get]
+func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ backends, err := gallery.AvailableBackends(mgs.galleries, systemState)
+ if err != nil {
+ return err
+ }
+ return c.JSON(200, backends)
+ }
+}
diff --git a/core/http/endpoints/localai/backend_monitor.go b/core/http/endpoints/localai/backend_monitor.go
new file mode 100644
index 0000000000000000000000000000000000000000..18016c5792208d426805d78b7acbe0cd00b0bd87
--- /dev/null
+++ b/core/http/endpoints/localai/backend_monitor.go
@@ -0,0 +1,45 @@
+package localai
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/core/services"
+)
+
+// BackendMonitorEndpoint returns the status of the specified backend
+// @Summary Backend monitor endpoint
+// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
+// @Success 200 {object} proto.StatusResponse "Response"
+// @Router /backend/monitor [get]
+func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
+ return func(c echo.Context) error {
+
+ input := new(schema.BackendMonitorRequest)
+ // Get input data from the request body
+ if err := c.Bind(input); err != nil {
+ return err
+ }
+
+ resp, err := bm.CheckAndSample(input.Model)
+ if err != nil {
+ return err
+ }
+ return c.JSON(200, resp)
+ }
+}
+
+// BackendShutdownEndpoint shuts down the specified backend
+// @Summary Backend monitor endpoint
+// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
+// @Router /backend/shutdown [post]
+func BackendShutdownEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input := new(schema.BackendMonitorRequest)
+ // Get input data from the request body
+ if err := c.Bind(input); err != nil {
+ return err
+ }
+
+ return bm.ShutdownModel(input.Model)
+ }
+}
diff --git a/core/http/endpoints/localai/detection.go b/core/http/endpoints/localai/detection.go
new file mode 100644
index 0000000000000000000000000000000000000000..77a0c72565269179017026f3ada72d6e8062d196
--- /dev/null
+++ b/core/http/endpoints/localai/detection.go
@@ -0,0 +1,59 @@
+package localai
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/utils"
+ "github.com/mudler/xlog"
+)
+
+// DetectionEndpoint is the LocalAI Detection endpoint https://localai.io/docs/api-reference/detection
+// @Summary Detects objects in the input image.
+// @Param request body schema.DetectionRequest true "query params"
+// @Success 200 {object} schema.DetectionResponse "Response"
+// @Router /v1/detection [post]
+func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
+ if !ok || input.Model == "" {
+ return echo.ErrBadRequest
+ }
+
+ cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || cfg == nil {
+ return echo.ErrBadRequest
+ }
+
+ xlog.Debug("Detection", "image", input.Image, "modelFile", "modelFile", "backend", cfg.Backend)
+
+ image, err := utils.GetContentURIAsBase64(input.Image)
+ if err != nil {
+ return err
+ }
+
+ res, err := backend.Detection(image, ml, appConfig, *cfg)
+ if err != nil {
+ return err
+ }
+
+ response := schema.DetectionResponse{
+ Detections: make([]schema.Detection, len(res.Detections)),
+ }
+ for i, detection := range res.Detections {
+ response.Detections[i] = schema.Detection{
+ X: detection.X,
+ Y: detection.Y,
+ Width: detection.Width,
+ Height: detection.Height,
+ ClassName: detection.ClassName,
+ }
+ }
+
+ return c.JSON(200, response)
+ }
+}
diff --git a/core/http/endpoints/localai/edit_model.go b/core/http/endpoints/localai/edit_model.go
new file mode 100644
index 0000000000000000000000000000000000000000..f84b4d21bd0064cc5aefa7077f51fdcd2e13acab
--- /dev/null
+++ b/core/http/endpoints/localai/edit_model.go
@@ -0,0 +1,223 @@
+package localai
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ httpUtils "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/internal"
+ "github.com/mudler/LocalAI/pkg/utils"
+
+ "gopkg.in/yaml.v3"
+)
+
+// GetEditModelPage renders the edit model page with current configuration
+func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ modelName := c.Param("name")
+ if modelName == "" {
+ response := ModelResponse{
+ Success: false,
+ Error: "Model name is required",
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+
+ modelConfig, exists := cl.GetModelConfig(modelName)
+ if !exists {
+ response := ModelResponse{
+ Success: false,
+ Error: "Model configuration not found",
+ }
+ return c.JSON(http.StatusNotFound, response)
+ }
+
+ modelConfigFile := modelConfig.GetModelConfigFile()
+ if modelConfigFile == "" {
+ response := ModelResponse{
+ Success: false,
+ Error: "Model configuration file not found",
+ }
+ return c.JSON(http.StatusNotFound, response)
+ }
+ configData, err := os.ReadFile(modelConfigFile)
+ if err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to read configuration file: " + err.Error(),
+ }
+ return c.JSON(http.StatusInternalServerError, response)
+ }
+
+ // Render the edit page with the current configuration
+ templateData := struct {
+ Title string
+ ModelName string
+ Config *config.ModelConfig
+ ConfigJSON string
+ ConfigYAML string
+ BaseURL string
+ Version string
+ }{
+ Title: "LocalAI - Edit Model " + modelName,
+ ModelName: modelName,
+ Config: &modelConfig,
+ ConfigYAML: string(configData),
+ BaseURL: httpUtils.BaseURL(c),
+ Version: internal.PrintableVersion(),
+ }
+
+ return c.Render(http.StatusOK, "views/model-editor", templateData)
+ }
+}
+
+// EditModelEndpoint handles updating existing model configurations
+func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ modelName := c.Param("name")
+ if modelName == "" {
+ response := ModelResponse{
+ Success: false,
+ Error: "Model name is required",
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+
+ modelConfig, exists := cl.GetModelConfig(modelName)
+ if !exists {
+ response := ModelResponse{
+ Success: false,
+ Error: "Existing model configuration not found",
+ }
+ return c.JSON(http.StatusNotFound, response)
+ }
+
+ // Get the raw body
+ body, err := io.ReadAll(c.Request().Body)
+ if err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to read request body: " + err.Error(),
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+ if len(body) == 0 {
+ response := ModelResponse{
+ Success: false,
+ Error: "Request body is empty",
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+
+ // Check content to see if it's a valid model config
+ var req config.ModelConfig
+
+ // Parse YAML
+ if err := yaml.Unmarshal(body, &req); err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to parse YAML: " + err.Error(),
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+
+ // Validate required fields
+ if req.Name == "" {
+ response := ModelResponse{
+ Success: false,
+ Error: "Name is required",
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+
+ // Validate the configuration
+ if valid, _ := req.Validate(); !valid {
+ response := ModelResponse{
+ Success: false,
+ Error: "Validation failed",
+ Details: []string{"Configuration validation failed. Please check your YAML syntax and required fields."},
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+
+ // Load the existing configuration
+ configPath := modelConfig.GetModelConfigFile()
+ if err := utils.VerifyPath(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Model configuration not trusted: " + err.Error(),
+ }
+ return c.JSON(http.StatusNotFound, response)
+ }
+
+ // Write new content to file
+ if err := os.WriteFile(configPath, body, 0644); err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to write configuration file: " + err.Error(),
+ }
+ return c.JSON(http.StatusInternalServerError, response)
+ }
+
+ // Reload configurations
+ if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to reload configurations: " + err.Error(),
+ }
+ return c.JSON(http.StatusInternalServerError, response)
+ }
+
+ // Preload the model
+ if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to preload model: " + err.Error(),
+ }
+ return c.JSON(http.StatusInternalServerError, response)
+ }
+
+ // Return success response
+ response := ModelResponse{
+ Success: true,
+ Message: fmt.Sprintf("Model '%s' updated successfully", modelName),
+ Filename: configPath,
+ Config: req,
+ }
+ return c.JSON(200, response)
+ }
+}
+
+// ReloadModelsEndpoint handles reloading model configurations from disk
+func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ // Reload configurations
+ if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to reload configurations: " + err.Error(),
+ }
+ return c.JSON(http.StatusInternalServerError, response)
+ }
+
+ // Preload the models
+ if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to preload models: " + err.Error(),
+ }
+ return c.JSON(http.StatusInternalServerError, response)
+ }
+
+ // Return success response
+ response := ModelResponse{
+ Success: true,
+ Message: "Model configurations reloaded successfully",
+ }
+ return c.JSON(http.StatusOK, response)
+ }
+}
diff --git a/core/http/endpoints/localai/edit_model_test.go b/core/http/endpoints/localai/edit_model_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..b354dbc2b2493f0e390507e7010e20eaaa7acd90
--- /dev/null
+++ b/core/http/endpoints/localai/edit_model_test.go
@@ -0,0 +1,84 @@
+package localai_test
+
+import (
+ "bytes"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ . "github.com/mudler/LocalAI/core/http/endpoints/localai"
+ "github.com/mudler/LocalAI/pkg/system"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+// testRenderer is a simple renderer for tests that returns JSON
+type testRenderer struct{}
+
+func (t *testRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
+ // For tests, just return the data as JSON
+ return json.NewEncoder(w).Encode(data)
+}
+
+var _ = Describe("Edit Model test", func() {
+
+ var tempDir string
+ BeforeEach(func() {
+ var err error
+ tempDir, err = os.MkdirTemp("", "localai-test")
+ Expect(err).ToNot(HaveOccurred())
+ })
+ AfterEach(func() {
+ os.RemoveAll(tempDir)
+ })
+
+ Context("Edit Model endpoint", func() {
+ It("should edit a model", func() {
+ systemState, err := system.GetSystemState(
+ system.WithModelPath(filepath.Join(tempDir)),
+ )
+ Expect(err).ToNot(HaveOccurred())
+
+ applicationConfig := config.NewApplicationConfig(
+ config.WithSystemState(systemState),
+ )
+ //modelLoader := model.NewModelLoader(systemState, true)
+ modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath)
+
+ // Define Echo app and register all routes upfront
+ app := echo.New()
+ // Set up a simple renderer for the test
+ app.Renderer = &testRenderer{}
+ app.POST("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
+ app.GET("/edit-model/:name", GetEditModelPage(modelConfigLoader, applicationConfig))
+
+ requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`)
+
+ req := httptest.NewRequest("POST", "/import-model", requestBody)
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ app.ServeHTTP(rec, req)
+
+ body, err := io.ReadAll(rec.Body)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(string(body)).To(ContainSubstring("Model configuration created successfully"))
+ Expect(rec.Code).To(Equal(http.StatusOK))
+
+ req = httptest.NewRequest("GET", "/edit-model/foo", nil)
+ rec = httptest.NewRecorder()
+ app.ServeHTTP(rec, req)
+
+ body, err = io.ReadAll(rec.Body)
+ Expect(err).ToNot(HaveOccurred())
+ // The response contains the model configuration with backend field
+ Expect(string(body)).To(ContainSubstring(`"backend":"foo"`))
+ Expect(string(body)).To(ContainSubstring(`"name":"foo"`))
+ Expect(rec.Code).To(Equal(http.StatusOK))
+ })
+ })
+})
diff --git a/core/http/endpoints/localai/gallery.go b/core/http/endpoints/localai/gallery.go
new file mode 100644
index 0000000000000000000000000000000000000000..4c55630fc52d5ebe0585882c502cc53da994c5a2
--- /dev/null
+++ b/core/http/endpoints/localai/gallery.go
@@ -0,0 +1,160 @@
+package localai
+
+import (
+ "encoding/json"
+ "fmt"
+
+ "github.com/google/uuid"
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/core/services"
+ "github.com/mudler/LocalAI/pkg/system"
+ "github.com/mudler/xlog"
+)
+
+type ModelGalleryEndpointService struct {
+ galleries []config.Gallery
+ backendGalleries []config.Gallery
+ modelPath string
+ galleryApplier *services.GalleryService
+}
+
+type GalleryModel struct {
+ ID string `json:"id"`
+ gallery.GalleryModel
+}
+
+func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
+ return ModelGalleryEndpointService{
+ galleries: galleries,
+ backendGalleries: backendGalleries,
+ modelPath: systemState.Model.ModelsPath,
+ galleryApplier: galleryApplier,
+ }
+}
+
+// GetOpStatusEndpoint returns the job status
+// @Summary Returns the job status
+// @Success 200 {object} services.GalleryOpStatus "Response"
+// @Router /models/jobs/{uuid} [get]
+func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
+ return func(c echo.Context) error {
+ status := mgs.galleryApplier.GetStatus(c.Param("uuid"))
+ if status == nil {
+ return fmt.Errorf("could not find any status for ID")
+ }
+ return c.JSON(200, status)
+ }
+}
+
+// GetAllStatusEndpoint returns all the jobs status progress
+// @Summary Returns all the jobs status progress
+// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
+// @Router /models/jobs [get]
+func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
+ return func(c echo.Context) error {
+ return c.JSON(200, mgs.galleryApplier.GetAllStatus())
+ }
+}
+
+// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery
+// @Summary Install models to LocalAI.
+// @Param request body GalleryModel true "query params"
+// @Success 200 {object} schema.GalleryResponse "Response"
+// @Router /models/apply [post]
+func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input := new(GalleryModel)
+ // Get input data from the request body
+ if err := c.Bind(input); err != nil {
+ return err
+ }
+
+ uuid, err := uuid.NewUUID()
+ if err != nil {
+ return err
+ }
+ mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
+ Req: input.GalleryModel,
+ ID: uuid.String(),
+ GalleryElementName: input.ID,
+ Galleries: mgs.galleries,
+ BackendGalleries: mgs.backendGalleries,
+ }
+
+ return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())})
+ }
+}
+
+// DeleteModelGalleryEndpoint lets delete models from a LocalAI instance
+// @Summary delete models to LocalAI.
+// @Param name path string true "Model name"
+// @Success 200 {object} schema.GalleryResponse "Response"
+// @Router /models/delete/{name} [post]
+func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.HandlerFunc {
+ return func(c echo.Context) error {
+ modelName := c.Param("name")
+
+ mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
+ Delete: true,
+ GalleryElementName: modelName,
+ }
+
+ uuid, err := uuid.NewUUID()
+ if err != nil {
+ return err
+ }
+
+ return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())})
+ }
+}
+
+// ListModelFromGalleryEndpoint list the available models for installation from the active galleries
+// @Summary List installable models.
+// @Success 200 {object} []gallery.GalleryModel "Response"
+// @Router /models/available [get]
+func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) echo.HandlerFunc {
+ return func(c echo.Context) error {
+
+ models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
+ if err != nil {
+ xlog.Error("could not list models from galleries", "error", err)
+ return err
+ }
+
+ xlog.Debug("Available models from galleries", "modelCount", len(models), "galleryCount", len(mgs.galleries))
+
+ m := []gallery.Metadata{}
+
+ for _, mm := range models {
+ m = append(m, mm.Metadata)
+ }
+
+ xlog.Debug("Models", "models", m)
+
+ dat, err := json.Marshal(m)
+ if err != nil {
+ return fmt.Errorf("could not marshal models: %w", err)
+ }
+ return c.Blob(200, "application/json", dat)
+ }
+}
+
+// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
+// @Summary List all Galleries
+// @Success 200 {object} []config.Gallery "Response"
+// @Router /models/galleries [get]
+// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
+func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() echo.HandlerFunc {
+ return func(c echo.Context) error {
+ xlog.Debug("Listing model galleries", "galleries", mgs.galleries)
+ dat, err := json.Marshal(mgs.galleries)
+ if err != nil {
+ return err
+ }
+ return c.Blob(200, "application/json", dat)
+ }
+}
diff --git a/core/http/endpoints/localai/get_token_metrics.go b/core/http/endpoints/localai/get_token_metrics.go
new file mode 100644
index 0000000000000000000000000000000000000000..69c408e50b7619ca40fc07c86af7207575c98607
--- /dev/null
+++ b/core/http/endpoints/localai/get_token_metrics.go
@@ -0,0 +1,57 @@
+package localai
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/xlog"
+
+ "github.com/mudler/LocalAI/pkg/model"
+)
+
+// TODO: This is not yet in use. Needs middleware rework, since it is not referenced.
+
+// TokenMetricsEndpoint is an endpoint to get TokensProcessed Per Second for Active SlotID
+//
+// @Summary Get TokenMetrics for Active Slot.
+// @Accept json
+// @Produce audio/x-wav
+// @Success 200 {string} binary "generated audio/wav file"
+// @Router /v1/tokenMetrics [get]
+// @Router /tokenMetrics [get]
+func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+
+ input := new(schema.TokenMetricsRequest)
+
+ // Get input data from the request body
+ if err := c.Bind(input); err != nil {
+ return err
+ }
+
+ modelFile, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
+ if !ok || modelFile != "" {
+ modelFile = input.Model
+ xlog.Warn("Model not found in context", "model", input.Model)
+ }
+
+ cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(modelFile, appConfig)
+
+ if err != nil {
+ xlog.Error("Error loading model config", "error", err)
+ modelFile = input.Model
+ xlog.Warn("Model not found in context", "model", input.Model)
+ } else {
+ modelFile = cfg.Model
+ }
+ xlog.Debug("Token Metrics for model", "model", modelFile)
+
+ response, err := backend.TokenMetrics(modelFile, ml, appConfig, *cfg)
+ if err != nil {
+ return err
+ }
+ return c.JSON(200, response)
+ }
+}
diff --git a/core/http/endpoints/localai/import_model.go b/core/http/endpoints/localai/import_model.go
new file mode 100644
index 0000000000000000000000000000000000000000..9d8926c0a228b178b55e47aaf92f9a70f8d52bed
--- /dev/null
+++ b/core/http/endpoints/localai/import_model.go
@@ -0,0 +1,212 @@
+package localai
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/google/uuid"
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/gallery/importers"
+ httpUtils "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/core/services"
+ "github.com/mudler/LocalAI/pkg/utils"
+
+ "gopkg.in/yaml.v3"
+)
+
+// ImportModelURIEndpoint handles creating new model configurations from a URI
+func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) echo.HandlerFunc {
+ return func(c echo.Context) error {
+
+ input := new(schema.ImportModelRequest)
+
+ if err := c.Bind(input); err != nil {
+ return err
+ }
+
+ modelConfig, err := importers.DiscoverModelConfig(input.URI, input.Preferences)
+ if err != nil {
+ return fmt.Errorf("failed to discover model config: %w", err)
+ }
+
+ uuid, err := uuid.NewUUID()
+ if err != nil {
+ return err
+ }
+
+ // Determine gallery ID for tracking - use model name if available, otherwise use URI
+ galleryID := input.URI
+ if modelConfig.Name != "" {
+ galleryID = modelConfig.Name
+ }
+
+ // Register operation in opcache if available (for UI progress tracking)
+ if opcache != nil {
+ opcache.Set(galleryID, uuid.String())
+ }
+
+ galleryService.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
+ Req: gallery.GalleryModel{
+ Overrides: map[string]interface{}{},
+ },
+ ID: uuid.String(),
+ GalleryElementName: galleryID,
+ GalleryElement: &modelConfig,
+ BackendGalleries: appConfig.BackendGalleries,
+ }
+
+ return c.JSON(200, schema.GalleryResponse{
+ ID: uuid.String(),
+ StatusURL: fmt.Sprintf("%smodels/jobs/%s", httpUtils.BaseURL(c), uuid.String()),
+ })
+ }
+}
+
+// ImportModelEndpoint handles creating new model configurations
+func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ // Get the raw body
+ body, err := io.ReadAll(c.Request().Body)
+ if err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to read request body: " + err.Error(),
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+ if len(body) == 0 {
+ response := ModelResponse{
+ Success: false,
+ Error: "Request body is empty",
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+
+ // Check content type to determine how to parse
+ contentType := c.Request().Header.Get("Content-Type")
+ var modelConfig config.ModelConfig
+
+ if strings.Contains(contentType, "application/json") {
+ // Parse JSON
+ if err := json.Unmarshal(body, &modelConfig); err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to parse JSON: " + err.Error(),
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+ } else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
+ // Parse YAML
+ if err := yaml.Unmarshal(body, &modelConfig); err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to parse YAML: " + err.Error(),
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+ } else {
+ // Try to auto-detect format
+ if len(body) > 0 && strings.TrimSpace(string(body))[0] == '{' {
+ // Looks like JSON
+ if err := json.Unmarshal(body, &modelConfig); err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to parse JSON: " + err.Error(),
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+ } else {
+ // Assume YAML
+ if err := yaml.Unmarshal(body, &modelConfig); err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to parse YAML: " + err.Error(),
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+ }
+ }
+
+ // Validate required fields
+ if modelConfig.Name == "" {
+ response := ModelResponse{
+ Success: false,
+ Error: "Name is required",
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+
+ // Set defaults
+ modelConfig.SetDefaults(appConfig.ToConfigLoaderOptions()...)
+
+ // Validate the configuration
+ if valid, _ := modelConfig.Validate(); !valid {
+ response := ModelResponse{
+ Success: false,
+ Error: "Invalid configuration",
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+
+ // Create the configuration file
+ configPath := filepath.Join(appConfig.SystemState.Model.ModelsPath, modelConfig.Name+".yaml")
+ if err := utils.VerifyPath(modelConfig.Name+".yaml", appConfig.SystemState.Model.ModelsPath); err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Model path not trusted: " + err.Error(),
+ }
+ return c.JSON(http.StatusBadRequest, response)
+ }
+
+ // Marshal to YAML for storage
+ yamlData, err := yaml.Marshal(&modelConfig)
+ if err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to marshal configuration: " + err.Error(),
+ }
+ return c.JSON(http.StatusInternalServerError, response)
+ }
+
+ // Write the file
+ if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to write configuration file: " + err.Error(),
+ }
+ return c.JSON(http.StatusInternalServerError, response)
+ }
+ // Reload configurations
+ if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to reload configurations: " + err.Error(),
+ }
+ return c.JSON(http.StatusInternalServerError, response)
+ }
+
+ // Preload the model
+ if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil {
+ response := ModelResponse{
+ Success: false,
+ Error: "Failed to preload model: " + err.Error(),
+ }
+ return c.JSON(http.StatusInternalServerError, response)
+ }
+ // Return success response
+ response := ModelResponse{
+ Success: true,
+ Message: "Model configuration created successfully",
+ Filename: filepath.Base(configPath),
+ }
+ return c.JSON(200, response)
+ }
+}
diff --git a/core/http/endpoints/localai/localai_suite_test.go b/core/http/endpoints/localai/localai_suite_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..ea415bf70008d5a3f0f9f5e8198010bc083118ab
--- /dev/null
+++ b/core/http/endpoints/localai/localai_suite_test.go
@@ -0,0 +1,13 @@
+package localai_test
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestLocalAIEndpoints(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "LocalAI Endpoints test suite")
+}
diff --git a/core/http/endpoints/localai/mcp.go b/core/http/endpoints/localai/mcp.go
new file mode 100644
index 0000000000000000000000000000000000000000..721f97a69e81f388d0515b63c28e1f3ded86e266
--- /dev/null
+++ b/core/http/endpoints/localai/mcp.go
@@ -0,0 +1,330 @@
+package localai
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/core/templates"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/cogito"
+ "github.com/mudler/xlog"
+)
+
+// MCP SSE Event Types
+type MCPReasoningEvent struct {
+ Type string `json:"type"`
+ Content string `json:"content"`
+}
+
+type MCPToolCallEvent struct {
+ Type string `json:"type"`
+ Name string `json:"name"`
+ Arguments map[string]interface{} `json:"arguments"`
+ Reasoning string `json:"reasoning"`
+}
+
+type MCPToolResultEvent struct {
+ Type string `json:"type"`
+ Name string `json:"name"`
+ Result string `json:"result"`
+}
+
+type MCPStatusEvent struct {
+ Type string `json:"type"`
+ Message string `json:"message"`
+}
+
+type MCPAssistantEvent struct {
+ Type string `json:"type"`
+ Content string `json:"content"`
+}
+
+type MCPErrorEvent struct {
+ Type string `json:"type"`
+ Message string `json:"message"`
+}
+
+// MCPEndpoint is the endpoint for MCP chat completions. Supports SSE mode, but it is not compatible with the OpenAI apis.
+// @Summary Stream MCP chat completions with reasoning, tool calls, and results
+// @Param request body schema.OpenAIRequest true "query params"
+// @Success 200 {object} schema.OpenAIResponse "Response"
+// @Router /v1/mcp/chat/completions [post]
+func MCPEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ ctx := c.Request().Context()
+ created := int(time.Now().Unix())
+
+ // Handle Correlation
+ id := c.Request().Header.Get("X-Correlation-ID")
+ if id == "" {
+ id = fmt.Sprintf("mcp-%d", time.Now().UnixNano())
+ }
+
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+ if !ok || input.Model == "" {
+ return echo.ErrBadRequest
+ }
+
+ config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || config == nil {
+ return echo.ErrBadRequest
+ }
+
+ if config.MCP.Servers == "" && config.MCP.Stdio == "" {
+ return fmt.Errorf("no MCP servers configured")
+ }
+
+ // Get MCP config from model config
+ remote, stdio, err := config.MCP.MCPConfigFromYAML()
+ if err != nil {
+ return fmt.Errorf("failed to get MCP config: %w", err)
+ }
+
+ // Check if we have tools in cache, or we have to have an initial connection
+ sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
+ if err != nil {
+ return fmt.Errorf("failed to get MCP sessions: %w", err)
+ }
+
+ if len(sessions) == 0 {
+ return fmt.Errorf("no working MCP servers found")
+ }
+
+ // Build fragment from messages
+ fragment := cogito.NewEmptyFragment()
+ for _, message := range input.Messages {
+ fragment = fragment.AddMessage(message.Role, message.StringContent)
+ }
+
+ _, port, err := net.SplitHostPort(appConfig.APIAddress)
+ if err != nil {
+ return err
+ }
+ apiKey := ""
+ if len(appConfig.ApiKeys) > 0 {
+ apiKey = appConfig.ApiKeys[0]
+ }
+
+ ctxWithCancellation, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ // TODO: instead of connecting to the API, we should just wire this internally
+ // and act like completion.go.
+ // We can do this as cogito expects an interface and we can create one that
+ // we satisfy to just call internally ComputeChoices
+ defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port)
+
+ // Build cogito options using the consolidated method
+ cogitoOpts := config.BuildCogitoOptions()
+ cogitoOpts = append(
+ cogitoOpts,
+ cogito.WithContext(ctxWithCancellation),
+ cogito.WithMCPs(sessions...),
+ )
+ // Check if streaming is requested
+ toStream := input.Stream
+
+ if !toStream {
+ // Non-streaming mode: execute synchronously and return JSON response
+ cogitoOpts = append(
+ cogitoOpts,
+ cogito.WithStatusCallback(func(s string) {
+ xlog.Debug("[model agent] Status", "model", config.Name, "status", s)
+ }),
+ cogito.WithReasoningCallback(func(s string) {
+ xlog.Debug("[model agent] Reasoning", "model", config.Name, "reasoning", s)
+ }),
+ cogito.WithToolCallBack(func(t *cogito.ToolChoice, state *cogito.SessionState) cogito.ToolCallDecision {
+ xlog.Debug("[model agent] Tool call", "model", config.Name, "tool", t.Name, "reasoning", t.Reasoning, "arguments", t.Arguments)
+ return cogito.ToolCallDecision{
+ Approved: true,
+ }
+ }),
+ cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
+ xlog.Debug("[model agent] Tool call result", "model", config.Name, "tool", t.Name, "result", t.Result, "tool_arguments", t.ToolArguments)
+ }),
+ )
+
+ f, err := cogito.ExecuteTools(
+ defaultLLM, fragment,
+ cogitoOpts...,
+ )
+ if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
+ return err
+ }
+
+ f, err = defaultLLM.Ask(ctxWithCancellation, f)
+ if err != nil {
+ return err
+ }
+
+ resp := &schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Choices: []schema.Choice{{Message: &schema.Message{Role: "assistant", Content: &f.LastMessage().Content}}},
+ Object: "chat.completion",
+ }
+
+ jsonResult, _ := json.Marshal(resp)
+ xlog.Debug("Response", "response", string(jsonResult))
+
+ // Return the prediction in the response body
+ return c.JSON(200, resp)
+ }
+
+ // Streaming mode: use SSE
+ // Set up SSE headers
+ c.Response().Header().Set("Content-Type", "text/event-stream")
+ c.Response().Header().Set("Cache-Control", "no-cache")
+ c.Response().Header().Set("Connection", "keep-alive")
+ c.Response().Header().Set("X-Correlation-ID", id)
+
+ // Create channel for streaming events
+ events := make(chan interface{})
+ ended := make(chan error, 1)
+
+ // Set up callbacks for streaming
+ statusCallback := func(s string) {
+ events <- MCPStatusEvent{
+ Type: "status",
+ Message: s,
+ }
+ }
+
+ reasoningCallback := func(s string) {
+ events <- MCPReasoningEvent{
+ Type: "reasoning",
+ Content: s,
+ }
+ }
+
+ toolCallCallback := func(t *cogito.ToolChoice, state *cogito.SessionState) cogito.ToolCallDecision {
+ events <- MCPToolCallEvent{
+ Type: "tool_call",
+ Name: t.Name,
+ Arguments: t.Arguments,
+ Reasoning: t.Reasoning,
+ }
+ return cogito.ToolCallDecision{
+ Approved: true,
+ }
+ }
+
+ toolCallResultCallback := func(t cogito.ToolStatus) {
+ events <- MCPToolResultEvent{
+ Type: "tool_result",
+ Name: t.Name,
+ Result: t.Result,
+ }
+ }
+
+ cogitoOpts = append(cogitoOpts,
+ cogito.WithStatusCallback(statusCallback),
+ cogito.WithReasoningCallback(reasoningCallback),
+ cogito.WithToolCallBack(toolCallCallback),
+ cogito.WithToolCallResultCallback(toolCallResultCallback),
+ )
+
+ // Execute tools in a goroutine
+ go func() {
+ defer close(events)
+
+ f, err := cogito.ExecuteTools(
+ defaultLLM, fragment,
+ cogitoOpts...,
+ )
+ if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
+ events <- MCPErrorEvent{
+ Type: "error",
+ Message: fmt.Sprintf("Failed to execute tools: %v", err),
+ }
+ ended <- err
+ return
+ }
+
+ // Get final response
+ f, err = defaultLLM.Ask(ctxWithCancellation, f)
+ if err != nil {
+ events <- MCPErrorEvent{
+ Type: "error",
+ Message: fmt.Sprintf("Failed to get response: %v", err),
+ }
+ ended <- err
+ return
+ }
+
+ // Stream final assistant response
+ content := f.LastMessage().Content
+ events <- MCPAssistantEvent{
+ Type: "assistant",
+ Content: content,
+ }
+
+ ended <- nil
+ }()
+
+ // Stream events to client
+ LOOP:
+ for {
+ select {
+ case <-ctx.Done():
+ // Context was cancelled (client disconnected or request cancelled)
+ xlog.Debug("Request context cancelled, stopping stream")
+ cancel()
+ break LOOP
+ case event := <-events:
+ if event == nil {
+ // Channel closed
+ break LOOP
+ }
+ eventData, err := json.Marshal(event)
+ if err != nil {
+ xlog.Debug("Failed to marshal event", "error", err)
+ continue
+ }
+ xlog.Debug("Sending event", "event", string(eventData))
+ _, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(eventData))
+ if err != nil {
+ xlog.Debug("Sending event failed", "error", err)
+ cancel()
+ return err
+ }
+ c.Response().Flush()
+ case err := <-ended:
+ if err == nil {
+ // Send done signal
+ fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
+ c.Response().Flush()
+ break LOOP
+ }
+ xlog.Error("Stream ended with error", "error", err)
+ errorEvent := MCPErrorEvent{
+ Type: "error",
+ Message: err.Error(),
+ }
+ errorData, marshalErr := json.Marshal(errorEvent)
+ if marshalErr != nil {
+ fmt.Fprintf(c.Response().Writer, "data: {\"type\":\"error\",\"message\":\"Internal error\"}\n\n")
+ } else {
+ fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
+ }
+ fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
+ c.Response().Flush()
+ return nil
+ }
+ }
+
+ xlog.Debug("Stream ended")
+ return nil
+ }
+}
diff --git a/core/http/endpoints/localai/metrics.go b/core/http/endpoints/localai/metrics.go
new file mode 100644
index 0000000000000000000000000000000000000000..a5f08a7f6444e901fbd8782c663a7509b47469e1
--- /dev/null
+++ b/core/http/endpoints/localai/metrics.go
@@ -0,0 +1,47 @@
+package localai
+
+import (
+ "time"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/services"
+ "github.com/prometheus/client_golang/prometheus/promhttp"
+)
+
+// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI
+// @Summary Prometheus metrics endpoint
+// @Param request body config.Gallery true "Gallery details"
+// @Router /metrics [get]
+func LocalAIMetricsEndpoint() echo.HandlerFunc {
+ return echo.WrapHandler(promhttp.Handler())
+}
+
+type apiMiddlewareConfig struct {
+ Filter func(c echo.Context) bool
+ metricsService *services.LocalAIMetricsService
+}
+
+func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) echo.MiddlewareFunc {
+ cfg := apiMiddlewareConfig{
+ metricsService: metrics,
+ Filter: func(c echo.Context) bool {
+ return c.Path() == "/metrics"
+ },
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ if cfg.Filter != nil && cfg.Filter(c) {
+ return next(c)
+ }
+ path := c.Path()
+ method := c.Request().Method
+
+ start := time.Now()
+ err := next(c)
+ elapsed := float64(time.Since(start)) / float64(time.Second)
+ cfg.metricsService.ObserveAPICall(method, path, elapsed)
+ return err
+ }
+ }
+}
diff --git a/core/http/endpoints/localai/p2p.go b/core/http/endpoints/localai/p2p.go
new file mode 100644
index 0000000000000000000000000000000000000000..afd7d048dc83e76091767252b55b3496ea68fc6d
--- /dev/null
+++ b/core/http/endpoints/localai/p2p.go
@@ -0,0 +1,30 @@
+package localai
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/p2p"
+ "github.com/mudler/LocalAI/core/schema"
+)
+
+// ShowP2PNodes returns the P2P Nodes
+// @Summary Returns available P2P nodes
+// @Success 200 {object} []schema.P2PNodesResponse "Response"
+// @Router /api/p2p [get]
+func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ // Render index
+ return func(c echo.Context) error {
+ return c.JSON(200, schema.P2PNodesResponse{
+ Nodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)),
+ FederatedNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)),
+ })
+ }
+}
+
+// ShowP2PToken returns the P2P token
+// @Summary Show the P2P token
+// @Success 200 {string} string "Response"
+// @Router /api/p2p/token [get]
+func ShowP2PToken(appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error { return c.String(200, appConfig.P2PToken) }
+}
diff --git a/core/http/endpoints/localai/settings.go b/core/http/endpoints/localai/settings.go
new file mode 100644
index 0000000000000000000000000000000000000000..93746baaa37379c67165c7a69b1e4446a28e2e24
--- /dev/null
+++ b/core/http/endpoints/localai/settings.go
@@ -0,0 +1,214 @@
+package localai
+
+import (
+ "encoding/json"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "time"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/application"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/p2p"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/xlog"
+)
+
+// GetSettingsEndpoint returns current settings with precedence (env > file > defaults)
+func GetSettingsEndpoint(app *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ appConfig := app.ApplicationConfig()
+ settings := appConfig.ToRuntimeSettings()
+ return c.JSON(http.StatusOK, settings)
+ }
+}
+
+// UpdateSettingsEndpoint updates settings, saves to file, and applies immediately
+func UpdateSettingsEndpoint(app *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ appConfig := app.ApplicationConfig()
+ startupConfig := app.StartupConfig()
+
+ if startupConfig == nil {
+ startupConfig = appConfig
+ }
+
+ body, err := io.ReadAll(c.Request().Body)
+ if err != nil {
+ return c.JSON(http.StatusBadRequest, schema.SettingsResponse{
+ Success: false,
+ Error: "Failed to read request body: " + err.Error(),
+ })
+ }
+
+ var settings config.RuntimeSettings
+ if err := json.Unmarshal(body, &settings); err != nil {
+ return c.JSON(http.StatusBadRequest, schema.SettingsResponse{
+ Success: false,
+ Error: "Failed to parse JSON: " + err.Error(),
+ })
+ }
+
+ // Validate timeouts if provided
+ if settings.WatchdogIdleTimeout != nil {
+ if _, err := time.ParseDuration(*settings.WatchdogIdleTimeout); err != nil {
+ return c.JSON(http.StatusBadRequest, schema.SettingsResponse{
+ Success: false,
+ Error: "Invalid watchdog_idle_timeout format: " + err.Error(),
+ })
+ }
+ }
+ if settings.WatchdogBusyTimeout != nil {
+ if _, err := time.ParseDuration(*settings.WatchdogBusyTimeout); err != nil {
+ return c.JSON(http.StatusBadRequest, schema.SettingsResponse{
+ Success: false,
+ Error: "Invalid watchdog_busy_timeout format: " + err.Error(),
+ })
+ }
+ }
+ if settings.WatchdogInterval != nil {
+ if _, err := time.ParseDuration(*settings.WatchdogInterval); err != nil {
+ return c.JSON(http.StatusBadRequest, schema.SettingsResponse{
+ Success: false,
+ Error: "Invalid watchdog_interval format: " + err.Error(),
+ })
+ }
+ }
+ if settings.LRUEvictionRetryInterval != nil {
+ if _, err := time.ParseDuration(*settings.LRUEvictionRetryInterval); err != nil {
+ return c.JSON(http.StatusBadRequest, schema.SettingsResponse{
+ Success: false,
+ Error: "Invalid lru_eviction_retry_interval format: " + err.Error(),
+ })
+ }
+ }
+
+ // Save to file
+ if appConfig.DynamicConfigsDir == "" {
+ return c.JSON(http.StatusBadRequest, schema.SettingsResponse{
+ Success: false,
+ Error: "DynamicConfigsDir is not set",
+ })
+ }
+
+ settingsFile := filepath.Join(appConfig.DynamicConfigsDir, "runtime_settings.json")
+ settingsJSON, err := json.MarshalIndent(settings, "", " ")
+ if err != nil {
+ return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{
+ Success: false,
+ Error: "Failed to marshal settings: " + err.Error(),
+ })
+ }
+
+ if err := os.WriteFile(settingsFile, settingsJSON, 0600); err != nil {
+ return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{
+ Success: false,
+ Error: "Failed to write settings file: " + err.Error(),
+ })
+ }
+
+ // Apply settings using centralized method
+ watchdogChanged := appConfig.ApplyRuntimeSettings(&settings)
+
+ // Handle API keys specially (merge with startup keys)
+ if settings.ApiKeys != nil {
+ envKeys := startupConfig.ApiKeys
+ runtimeKeys := *settings.ApiKeys
+ appConfig.ApiKeys = append(envKeys, runtimeKeys...)
+ }
+
+ // Update watchdog dynamically for settings that don't require restart
+ if settings.ForceEvictionWhenBusy != nil {
+ currentWD := app.ModelLoader().GetWatchDog()
+ if currentWD != nil {
+ currentWD.SetForceEvictionWhenBusy(*settings.ForceEvictionWhenBusy)
+ xlog.Info("Updated watchdog force eviction when busy setting", "forceEvictionWhenBusy", *settings.ForceEvictionWhenBusy)
+ }
+ }
+
+ // Update ModelLoader LRU eviction retry settings dynamically
+ maxRetries := appConfig.LRUEvictionMaxRetries
+ retryInterval := appConfig.LRUEvictionRetryInterval
+ if settings.LRUEvictionMaxRetries != nil {
+ maxRetries = *settings.LRUEvictionMaxRetries
+ }
+ if settings.LRUEvictionRetryInterval != nil {
+ if dur, err := time.ParseDuration(*settings.LRUEvictionRetryInterval); err == nil {
+ retryInterval = dur
+ }
+ }
+ if settings.LRUEvictionMaxRetries != nil || settings.LRUEvictionRetryInterval != nil {
+ app.ModelLoader().SetLRUEvictionRetrySettings(maxRetries, retryInterval)
+ xlog.Info("Updated LRU eviction retry settings", "maxRetries", maxRetries, "retryInterval", retryInterval)
+ }
+
+ // Check if agent job retention changed
+ agentJobChanged := settings.AgentJobRetentionDays != nil
+
+ // Restart watchdog if settings changed
+ if watchdogChanged {
+ if settings.WatchdogEnabled != nil && !*settings.WatchdogEnabled {
+ if err := app.StopWatchdog(); err != nil {
+ xlog.Error("Failed to stop watchdog", "error", err)
+ return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{
+ Success: false,
+ Error: "Settings saved but failed to stop watchdog: " + err.Error(),
+ })
+ }
+ } else {
+ if err := app.RestartWatchdog(); err != nil {
+ xlog.Error("Failed to restart watchdog", "error", err)
+ return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{
+ Success: false,
+ Error: "Settings saved but failed to restart watchdog: " + err.Error(),
+ })
+ }
+ }
+ }
+
+ // Restart agent job service if retention days changed
+ if agentJobChanged {
+ if err := app.RestartAgentJobService(); err != nil {
+ xlog.Error("Failed to restart agent job service", "error", err)
+ return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{
+ Success: false,
+ Error: "Settings saved but failed to restart agent job service: " + err.Error(),
+ })
+ }
+ }
+
+ // Restart P2P if P2P settings changed
+ p2pChanged := settings.P2PToken != nil || settings.P2PNetworkID != nil || settings.Federated != nil
+ if p2pChanged {
+ if settings.P2PToken != nil && *settings.P2PToken == "" {
+ if err := app.StopP2P(); err != nil {
+ xlog.Error("Failed to stop P2P", "error", err)
+ return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{
+ Success: false,
+ Error: "Settings saved but failed to stop P2P: " + err.Error(),
+ })
+ }
+ } else {
+ if settings.P2PToken != nil && *settings.P2PToken == "0" {
+ token := p2p.GenerateToken(60, 60)
+ settings.P2PToken = &token
+ appConfig.P2PToken = token
+ }
+ if err := app.RestartP2P(); err != nil {
+ xlog.Error("Failed to restart P2P", "error", err)
+ return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{
+ Success: false,
+ Error: "Settings saved but failed to restart P2P: " + err.Error(),
+ })
+ }
+ }
+ }
+
+ return c.JSON(http.StatusOK, schema.SettingsResponse{
+ Success: true,
+ Message: "Settings updated successfully",
+ })
+ }
+}
diff --git a/core/http/endpoints/localai/stores.go b/core/http/endpoints/localai/stores.go
new file mode 100644
index 0000000000000000000000000000000000000000..8074da9e0749359b9e0d66ddf774d80698038881
--- /dev/null
+++ b/core/http/endpoints/localai/stores.go
@@ -0,0 +1,121 @@
+package localai
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/store"
+)
+
+func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input := new(schema.StoresSet)
+
+ if err := c.Bind(input); err != nil {
+ return err
+ }
+
+ sb, err := backend.StoreBackend(sl, appConfig, input.Store, input.Backend)
+ if err != nil {
+ return err
+ }
+
+ vals := make([][]byte, len(input.Values))
+ for i, v := range input.Values {
+ vals[i] = []byte(v)
+ }
+
+ err = store.SetCols(c.Request().Context(), sb, input.Keys, vals)
+ if err != nil {
+ return err
+ }
+
+ return c.NoContent(200)
+ }
+}
+
+func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input := new(schema.StoresDelete)
+
+ if err := c.Bind(input); err != nil {
+ return err
+ }
+
+ sb, err := backend.StoreBackend(sl, appConfig, input.Store, input.Backend)
+ if err != nil {
+ return err
+ }
+
+ if err := store.DeleteCols(c.Request().Context(), sb, input.Keys); err != nil {
+ return err
+ }
+
+ return c.NoContent(200)
+ }
+}
+
+func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input := new(schema.StoresGet)
+
+ if err := c.Bind(input); err != nil {
+ return err
+ }
+
+ sb, err := backend.StoreBackend(sl, appConfig, input.Store, input.Backend)
+ if err != nil {
+ return err
+ }
+
+ keys, vals, err := store.GetCols(c.Request().Context(), sb, input.Keys)
+ if err != nil {
+ return err
+ }
+
+ res := schema.StoresGetResponse{
+ Keys: keys,
+ Values: make([]string, len(vals)),
+ }
+
+ for i, v := range vals {
+ res.Values[i] = string(v)
+ }
+
+ return c.JSON(200, res)
+ }
+}
+
+func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input := new(schema.StoresFind)
+
+ if err := c.Bind(input); err != nil {
+ return err
+ }
+
+ sb, err := backend.StoreBackend(sl, appConfig, input.Store, input.Backend)
+ if err != nil {
+ return err
+ }
+
+ keys, vals, similarities, err := store.Find(c.Request().Context(), sb, input.Key, input.Topk)
+ if err != nil {
+ return err
+ }
+
+ res := schema.StoresFindResponse{
+ Keys: keys,
+ Values: make([]string, len(vals)),
+ Similarities: similarities,
+ }
+
+ for i, v := range vals {
+ res.Values[i] = string(v)
+ }
+
+ return c.JSON(200, res)
+ }
+}
diff --git a/core/http/endpoints/localai/system.go b/core/http/endpoints/localai/system.go
new file mode 100644
index 0000000000000000000000000000000000000000..a3831e18483a811d33ce2f4a9daaaf68b907ec31
--- /dev/null
+++ b/core/http/endpoints/localai/system.go
@@ -0,0 +1,36 @@
+package localai
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/model"
+)
+
+// SystemInformations returns the system informations
+// @Summary Show the LocalAI instance information
+// @Success 200 {object} schema.SystemInformationResponse "Response"
+// @Router /system [get]
+func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ availableBackends := []string{}
+ loadedModels := ml.ListLoadedModels()
+ for b := range appConfig.ExternalGRPCBackends {
+ availableBackends = append(availableBackends, b)
+ }
+ for b := range ml.GetAllExternalBackends(nil) {
+ availableBackends = append(availableBackends, b)
+ }
+
+ sysmodels := []schema.SysInfoModel{}
+ for _, m := range loadedModels {
+ sysmodels = append(sysmodels, schema.SysInfoModel{ID: m.ID})
+ }
+ return c.JSON(200,
+ schema.SystemInformationResponse{
+ Backends: availableBackends,
+ Models: sysmodels,
+ },
+ )
+ }
+}
diff --git a/core/http/endpoints/localai/tokenize.go b/core/http/endpoints/localai/tokenize.go
new file mode 100644
index 0000000000000000000000000000000000000000..23eec48c75457d13b11ac0bb2861ee0165382419
--- /dev/null
+++ b/core/http/endpoints/localai/tokenize.go
@@ -0,0 +1,35 @@
+package localai
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/model"
+)
+
+// TokenizeEndpoint exposes a REST API to tokenize the content
+// @Summary Tokenize the input.
+// @Param request body schema.TokenizeRequest true "Request"
+// @Success 200 {object} schema.TokenizeResponse "Response"
+// @Router /v1/tokenize [post]
+func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
+ if !ok || input.Model == "" {
+ return echo.ErrBadRequest
+ }
+
+ cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || cfg == nil {
+ return echo.ErrBadRequest
+ }
+
+ tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
+ if err != nil {
+ return err
+ }
+ return c.JSON(200, tokenResponse)
+ }
+}
diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go
new file mode 100644
index 0000000000000000000000000000000000000000..9dd588ad7cb776e9447c24d449b83a5cba3221e4
--- /dev/null
+++ b/core/http/endpoints/localai/tts.go
@@ -0,0 +1,66 @@
+package localai
+
+import (
+ "path/filepath"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/pkg/model"
+
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/xlog"
+
+ "github.com/mudler/LocalAI/pkg/utils"
+)
+
+// TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech
+//
+// @Summary Generates audio from the input text.
+// @Accept json
+// @Produce audio/x-wav
+// @Param request body schema.TTSRequest true "query params"
+// @Success 200 {string} binary "generated audio/wav file"
+// @Router /v1/audio/speech [post]
+// @Router /tts [post]
+func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
+ if !ok || input.Model == "" {
+ return echo.ErrBadRequest
+ }
+
+ cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || cfg == nil {
+ return echo.ErrBadRequest
+ }
+
+ xlog.Debug("LocalAI TTS Request received", "model", input.Model)
+
+ if cfg.Backend == "" && input.Backend != "" {
+ cfg.Backend = input.Backend
+ }
+
+ if input.Language != "" {
+ cfg.Language = input.Language
+ }
+
+ if input.Voice != "" {
+ cfg.Voice = input.Voice
+ }
+
+ filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
+ if err != nil {
+ return err
+ }
+
+ // Convert generated file to target format
+ filePath, err = utils.AudioConvert(filePath, input.Format)
+ if err != nil {
+ return err
+ }
+
+ return c.Attachment(filePath, filepath.Base(filePath))
+ }
+}
diff --git a/core/http/endpoints/localai/types.go b/core/http/endpoints/localai/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..32a5490893506538b08be030f2f804306c0c7727
--- /dev/null
+++ b/core/http/endpoints/localai/types.go
@@ -0,0 +1,11 @@
+package localai
+
+// ModelResponse represents the common response structure for model operations
+type ModelResponse struct {
+ Success bool `json:"success"`
+ Message string `json:"message"`
+ Filename string `json:"filename,omitempty"`
+ Config interface{} `json:"config,omitempty"`
+ Error string `json:"error,omitempty"`
+ Details []string `json:"details,omitempty"`
+}
diff --git a/core/http/endpoints/localai/vad.go b/core/http/endpoints/localai/vad.go
new file mode 100644
index 0000000000000000000000000000000000000000..155574c8510211211bb9f4c20c7ec677ed0865b3
--- /dev/null
+++ b/core/http/endpoints/localai/vad.go
@@ -0,0 +1,41 @@
+package localai
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/xlog"
+)
+
+// VADEndpoint is Voice-Activation-Detection endpoint
+// @Summary Detect voice fragments in an audio stream
+// @Accept json
+// @Param request body schema.VADRequest true "query params"
+// @Success 200 {object} proto.VADResponse "Response"
+// @Router /vad [post]
+func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
+ if !ok || input.Model == "" {
+ return echo.ErrBadRequest
+ }
+
+ cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || cfg == nil {
+ return echo.ErrBadRequest
+ }
+
+ xlog.Debug("LocalAI VAD Request received", "model", input.Model)
+
+ resp, err := backend.VAD(input, c.Request().Context(), ml, appConfig, *cfg)
+
+ if err != nil {
+ return err
+ }
+
+ return c.JSON(200, resp)
+ }
+}
diff --git a/core/http/endpoints/localai/video.go b/core/http/endpoints/localai/video.go
new file mode 100644
index 0000000000000000000000000000000000000000..4ff343af0f8689f1e97e59910258151686a8e850
--- /dev/null
+++ b/core/http/endpoints/localai/video.go
@@ -0,0 +1,225 @@
+package localai
+
+import (
+ "bufio"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+
+ "github.com/mudler/LocalAI/core/backend"
+
+ model "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/xlog"
+)
+
+func downloadFile(url string) (string, error) {
+ // Get the data
+ resp, err := http.Get(url)
+ if err != nil {
+ return "", err
+ }
+ defer resp.Body.Close()
+
+ // Create the file
+ out, err := os.CreateTemp("", "video")
+ if err != nil {
+ return "", err
+ }
+ defer out.Close()
+
+ // Write the body to file
+ _, err = io.Copy(out, resp.Body)
+ return out.Name(), err
+}
+
+//
+
+/*
+*
+
+ curl http://localhost:8080/v1/images/generations \
+ -H "Content-Type: application/json" \
+ -d '{
+ "prompt": "A cute baby sea otter",
+ "n": 1,
+ "size": "512x512"
+ }'
+
+*
+*/
+// VideoEndpoint
+// @Summary Creates a video given a prompt.
+// @Param request body schema.VideoRequest true "query params"
+// @Success 200 {object} schema.OpenAIResponse "Response"
+// @Router /video [post]
+func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
+ if !ok || input.Model == "" {
+ xlog.Error("Video Endpoint - Invalid Input")
+ return echo.ErrBadRequest
+ }
+
+ config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || config == nil {
+ xlog.Error("Video Endpoint - Invalid Config")
+ return echo.ErrBadRequest
+ }
+
+ src := ""
+ if input.StartImage != "" {
+
+ var fileData []byte
+ var err error
+ // check if input.File is an URL, if so download it and save it
+ // to a temporary file
+ if strings.HasPrefix(input.StartImage, "http://") || strings.HasPrefix(input.StartImage, "https://") {
+ out, err := downloadFile(input.StartImage)
+ if err != nil {
+ return fmt.Errorf("failed downloading file:%w", err)
+ }
+ defer os.RemoveAll(out)
+
+ fileData, err = os.ReadFile(out)
+ if err != nil {
+ return fmt.Errorf("failed reading file:%w", err)
+ }
+
+ } else {
+ // base 64 decode the file and write it somewhere
+ // that we will cleanup
+ fileData, err = base64.StdEncoding.DecodeString(input.StartImage)
+ if err != nil {
+ return err
+ }
+ }
+
+ // Create a temporary file
+ outputFile, err := os.CreateTemp(appConfig.GeneratedContentDir, "b64")
+ if err != nil {
+ return err
+ }
+ // write the base64 result
+ writer := bufio.NewWriter(outputFile)
+ _, err = writer.Write(fileData)
+ if err != nil {
+ outputFile.Close()
+ return err
+ }
+ outputFile.Close()
+ src = outputFile.Name()
+ defer os.RemoveAll(src)
+ }
+
+ xlog.Debug("Parameter Config", "config", config)
+
+ switch config.Backend {
+ case "stablediffusion":
+ config.Backend = model.StableDiffusionGGMLBackend
+ case "":
+ config.Backend = model.StableDiffusionGGMLBackend
+ }
+
+ width := input.Width
+ height := input.Height
+
+ if width == 0 {
+ width = 512
+ }
+ if height == 0 {
+ height = 512
+ }
+
+ b64JSON := input.ResponseFormat == "b64_json"
+
+ tempDir := ""
+ if !b64JSON {
+ tempDir = filepath.Join(appConfig.GeneratedContentDir, "videos")
+ }
+ // Create a temporary file
+ outputFile, err := os.CreateTemp(tempDir, "b64")
+ if err != nil {
+ return err
+ }
+ outputFile.Close()
+
+ // TODO: use mime type to determine the extension
+ output := outputFile.Name() + ".mp4"
+
+ // Rename the temporary file
+ err = os.Rename(outputFile.Name(), output)
+ if err != nil {
+ return err
+ }
+
+ baseURL := middleware.BaseURL(c)
+
+ fn, err := backend.VideoGeneration(
+ height,
+ width,
+ input.Prompt,
+ input.NegativePrompt,
+ src,
+ input.EndImage,
+ output,
+ input.NumFrames,
+ input.FPS,
+ input.Seed,
+ input.CFGScale,
+ input.Step,
+ ml,
+ *config,
+ appConfig,
+ )
+ if err != nil {
+ return err
+ }
+ if err := fn(); err != nil {
+ return err
+ }
+
+ item := &schema.Item{}
+
+ if b64JSON {
+ defer os.RemoveAll(output)
+ data, err := os.ReadFile(output)
+ if err != nil {
+ return err
+ }
+ item.B64JSON = base64.StdEncoding.EncodeToString(data)
+ } else {
+ base := filepath.Base(output)
+ item.URL, err = url.JoinPath(baseURL, "generated-videos", base)
+ if err != nil {
+ return err
+ }
+ }
+
+ id := uuid.New().String()
+ created := int(time.Now().Unix())
+ resp := &schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Data: []schema.Item{*item},
+ }
+
+ jsonResult, _ := json.Marshal(resp)
+ xlog.Debug("Response", "response", string(jsonResult))
+
+ // Return the prediction in the response body
+ return c.JSON(200, resp)
+ }
+}
diff --git a/core/http/endpoints/localai/welcome.go b/core/http/endpoints/localai/welcome.go
new file mode 100644
index 0000000000000000000000000000000000000000..ce197ba05e739760dad11bf242cd502eef100b12
--- /dev/null
+++ b/core/http/endpoints/localai/welcome.go
@@ -0,0 +1,77 @@
+package localai
+
+import (
+ "strings"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/services"
+ "github.com/mudler/LocalAI/internal"
+ "github.com/mudler/LocalAI/pkg/model"
+)
+
+func WelcomeEndpoint(appConfig *config.ApplicationConfig,
+ cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ modelConfigs := cl.GetAllModelsConfigs()
+ galleryConfigs := map[string]*gallery.ModelConfig{}
+
+ installedBackends, err := gallery.ListSystemBackends(appConfig.SystemState)
+ if err != nil {
+ return err
+ }
+
+ for _, m := range modelConfigs {
+ cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
+ if err != nil {
+ continue
+ }
+ galleryConfigs[m.Name] = cfg
+ }
+
+ loadedModels := ml.ListLoadedModels()
+ loadedModelsMap := map[string]bool{}
+ for _, m := range loadedModels {
+ loadedModelsMap[m.ID] = true
+ }
+
+ modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
+
+ // Get model statuses to display in the UI the operation in progress
+ processingModels, taskTypes := opcache.GetStatus()
+
+ summary := map[string]interface{}{
+ "Title": "LocalAI API - " + internal.PrintableVersion(),
+ "Version": internal.PrintableVersion(),
+ "BaseURL": middleware.BaseURL(c),
+ "Models": modelsWithoutConfig,
+ "ModelsConfig": modelConfigs,
+ "GalleryConfig": galleryConfigs,
+ "ApplicationConfig": appConfig,
+ "ProcessingModels": processingModels,
+ "TaskTypes": taskTypes,
+ "LoadedModels": loadedModelsMap,
+ "InstalledBackends": installedBackends,
+ "DisableRuntimeSettings": appConfig.DisableRuntimeSettings,
+ }
+
+ contentType := c.Request().Header.Get("Content-Type")
+ accept := c.Request().Header.Get("Accept")
+ // Default to HTML if Accept header is empty (browser behavior)
+ // Only return JSON if explicitly requested or Content-Type is application/json
+ if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "text/html")) {
+ // The client expects a JSON response
+ return c.JSON(200, summary)
+ } else {
+ // Check if this is the manage route
+ templateName := "views/index"
+ if strings.HasSuffix(c.Request().URL.Path, "/manage") || c.Request().URL.Path == "/manage" {
+ templateName = "views/manage"
+ }
+ // Render appropriate template
+ return c.Render(200, templateName, summary)
+ }
+ }
+}
diff --git a/core/http/endpoints/mcp/tools.go b/core/http/endpoints/mcp/tools.go
new file mode 100644
index 0000000000000000000000000000000000000000..7954e85b609a71bba8201ecab0b0fc5e0b5b8eb4
--- /dev/null
+++ b/core/http/endpoints/mcp/tools.go
@@ -0,0 +1,120 @@
+package mcp
+
+import (
+ "context"
+ "net/http"
+ "os"
+ "os/exec"
+ "sync"
+ "time"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/pkg/signals"
+
+ "github.com/modelcontextprotocol/go-sdk/mcp"
+ "github.com/mudler/xlog"
+)
+
+type sessionCache struct {
+ mu sync.Mutex
+ cache map[string][]*mcp.ClientSession
+}
+
+var (
+ cache = sessionCache{
+ cache: make(map[string][]*mcp.ClientSession),
+ }
+
+ client = mcp.NewClient(&mcp.Implementation{Name: "LocalAI", Version: "v1.0.0"}, nil)
+)
+
+func SessionsFromMCPConfig(
+ name string,
+ remote config.MCPGenericConfig[config.MCPRemoteServers],
+ stdio config.MCPGenericConfig[config.MCPSTDIOServers],
+) ([]*mcp.ClientSession, error) {
+ cache.mu.Lock()
+ defer cache.mu.Unlock()
+
+ sessions, exists := cache.cache[name]
+ if exists {
+ return sessions, nil
+ }
+
+ allSessions := []*mcp.ClientSession{}
+
+ ctx, cancel := context.WithCancel(context.Background())
+
+ // Get the list of all the tools that the Agent will be esposed to
+ for _, server := range remote.Servers {
+ xlog.Debug("[MCP remote server] Configuration", "server", server)
+ // Create HTTP client with custom roundtripper for bearer token injection
+ httpClient := &http.Client{
+ Timeout: 360 * time.Second,
+ Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport),
+ }
+
+ transport := &mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: httpClient}
+ mcpSession, err := client.Connect(ctx, transport, nil)
+ if err != nil {
+ xlog.Error("Failed to connect to MCP server", "error", err, "url", server.URL)
+ continue
+ }
+ xlog.Debug("[MCP remote server] Connected to MCP server", "url", server.URL)
+ cache.cache[name] = append(cache.cache[name], mcpSession)
+ allSessions = append(allSessions, mcpSession)
+ }
+
+ for _, server := range stdio.Servers {
+ xlog.Debug("[MCP stdio server] Configuration", "server", server)
+ command := exec.Command(server.Command, server.Args...)
+ command.Env = os.Environ()
+ for key, value := range server.Env {
+ command.Env = append(command.Env, key+"="+value)
+ }
+ transport := &mcp.CommandTransport{Command: command}
+ mcpSession, err := client.Connect(ctx, transport, nil)
+ if err != nil {
+ xlog.Error("Failed to start MCP server", "error", err, "command", command)
+ continue
+ }
+ xlog.Debug("[MCP stdio server] Connected to MCP server", "command", command)
+ cache.cache[name] = append(cache.cache[name], mcpSession)
+ allSessions = append(allSessions, mcpSession)
+ }
+
+ signals.RegisterGracefulTerminationHandler(func() {
+ for _, session := range allSessions {
+ session.Close()
+ }
+ cancel()
+ })
+
+ return allSessions, nil
+}
+
+// bearerTokenRoundTripper is a custom roundtripper that injects a bearer token
+// into HTTP requests
+type bearerTokenRoundTripper struct {
+ token string
+ base http.RoundTripper
+}
+
+// RoundTrip implements the http.RoundTripper interface
+func (rt *bearerTokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ if rt.token != "" {
+ req.Header.Set("Authorization", "Bearer "+rt.token)
+ }
+ return rt.base.RoundTrip(req)
+}
+
+// newBearerTokenRoundTripper creates a new roundtripper that injects the given token
+func newBearerTokenRoundTripper(token string, base http.RoundTripper) http.RoundTripper {
+ if base == nil {
+ base = http.DefaultTransport
+ }
+ return &bearerTokenRoundTripper{
+ token: token,
+ base: base,
+ }
+}
diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go
new file mode 100644
index 0000000000000000000000000000000000000000..4ece68d5c0a85927f752d259d231fb2c4b11503f
--- /dev/null
+++ b/core/http/endpoints/openai/chat.go
@@ -0,0 +1,861 @@
+package openai
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/functions"
+
+ "github.com/mudler/LocalAI/core/templates"
+ "github.com/mudler/LocalAI/pkg/model"
+
+ "github.com/mudler/xlog"
+)
+
+// ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create
+// @Summary Generate a chat completions for a given prompt and model.
+// @Param request body schema.OpenAIRequest true "query params"
+// @Success 200 {object} schema.OpenAIResponse "Response"
+// @Router /v1/chat/completions [post]
+func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) echo.HandlerFunc {
+ var id, textContentToReturn string
+ var created int
+
+ process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
+ initialMessage := schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
+ }
+ responses <- initialMessage
+
+ // Track accumulated content for reasoning extraction
+ accumulatedContent := ""
+ lastEmittedReasoning := ""
+ lastEmittedCleanedContent := ""
+
+ _, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
+ accumulatedContent += s
+ // Extract reasoning from accumulated content
+ currentReasoning, cleanedContent := functions.ExtractReasoning(accumulatedContent)
+
+ // Calculate new reasoning delta (what we haven't emitted yet)
+ var reasoningDelta *string
+ if currentReasoning != lastEmittedReasoning {
+ // Extract only the new part
+ if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) {
+ newReasoning := currentReasoning[len(lastEmittedReasoning):]
+ reasoningDelta = &newReasoning
+ lastEmittedReasoning = currentReasoning
+ } else if currentReasoning != "" {
+ // If reasoning changed in a non-append way, emit the full current reasoning
+ reasoningDelta = ¤tReasoning
+ lastEmittedReasoning = currentReasoning
+ }
+ }
+
+ // Calculate content delta from cleaned content
+ var deltaContent string
+ if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) {
+ deltaContent = cleanedContent[len(lastEmittedCleanedContent):]
+ lastEmittedCleanedContent = cleanedContent
+ } else if cleanedContent != lastEmittedCleanedContent {
+ // If cleaned content changed but not in a simple append, extract delta from cleaned content
+ // This handles cases where thinking tags are removed mid-stream
+ if lastEmittedCleanedContent == "" {
+ deltaContent = cleanedContent
+ lastEmittedCleanedContent = cleanedContent
+ } else {
+ // Content changed in non-append way, use the new cleaned content
+ deltaContent = cleanedContent
+ lastEmittedCleanedContent = cleanedContent
+ }
+ }
+ // Only emit content if there's actual content (not just thinking tags)
+ // If deltaContent is empty, we still emit the response but with empty content
+
+ usage := schema.OpenAIUsage{
+ PromptTokens: tokenUsage.Prompt,
+ CompletionTokens: tokenUsage.Completion,
+ TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
+ }
+ if extraUsage {
+ usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
+ usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
+ }
+
+ delta := &schema.Message{}
+ // Only include content if there's actual content (not just thinking tags)
+ if deltaContent != "" {
+ delta.Content = &deltaContent
+ }
+ if reasoningDelta != nil && *reasoningDelta != "" {
+ delta.Reasoning = reasoningDelta
+ }
+
+ resp := schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
+ Object: "chat.completion.chunk",
+ Usage: usage,
+ }
+
+ responses <- resp
+ return true
+ })
+ close(responses)
+ return err
+ }
+ processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
+ result := ""
+ lastEmittedCount := 0
+ _, tokenUsage, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
+ result += s
+ // Try incremental XML parsing for streaming support using iterative parser
+ // This allows emitting partial tool calls as they're being generated
+ cleanedResult := functions.CleanupLLMResult(result, config.FunctionsConfig)
+
+ // Determine XML format from config
+ var xmlFormat *functions.XMLToolCallFormat
+ if config.FunctionsConfig.XMLFormat != nil {
+ xmlFormat = config.FunctionsConfig.XMLFormat
+ } else if config.FunctionsConfig.XMLFormatPreset != "" {
+ xmlFormat = functions.GetXMLFormatPreset(config.FunctionsConfig.XMLFormatPreset)
+ }
+
+ // Use iterative parser for streaming (partial parsing enabled)
+ // Try XML parsing first
+ partialResults, parseErr := functions.ParseXMLIterative(cleanedResult, xmlFormat, true)
+ if parseErr == nil && len(partialResults) > 0 {
+ // Emit new XML tool calls that weren't emitted before
+ if len(partialResults) > lastEmittedCount {
+ for i := lastEmittedCount; i < len(partialResults); i++ {
+ toolCall := partialResults[i]
+ initialMessage := schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: req.Model,
+ Choices: []schema.Choice{{
+ Delta: &schema.Message{
+ Role: "assistant",
+ ToolCalls: []schema.ToolCall{
+ {
+ Index: i,
+ ID: id,
+ Type: "function",
+ FunctionCall: schema.FunctionCall{
+ Name: toolCall.Name,
+ },
+ },
+ },
+ },
+ Index: 0,
+ FinishReason: nil,
+ }},
+ Object: "chat.completion.chunk",
+ }
+ select {
+ case responses <- initialMessage:
+ default:
+ }
+ }
+ lastEmittedCount = len(partialResults)
+ }
+ } else {
+ // Try JSON tool call parsing for streaming
+ // Check if the result looks like JSON tool calls
+ jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true)
+ if jsonErr == nil && len(jsonResults) > 0 {
+ // Check if these are tool calls (have "name" and optionally "arguments")
+ for _, jsonObj := range jsonResults {
+ if name, ok := jsonObj["name"].(string); ok && name != "" {
+ // This looks like a tool call
+ args := "{}"
+ if argsVal, ok := jsonObj["arguments"]; ok {
+ if argsStr, ok := argsVal.(string); ok {
+ args = argsStr
+ } else {
+ argsBytes, _ := json.Marshal(argsVal)
+ args = string(argsBytes)
+ }
+ }
+ // Emit tool call
+ initialMessage := schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: req.Model,
+ Choices: []schema.Choice{{
+ Delta: &schema.Message{
+ Role: "assistant",
+ ToolCalls: []schema.ToolCall{
+ {
+ Index: lastEmittedCount,
+ ID: id,
+ Type: "function",
+ FunctionCall: schema.FunctionCall{
+ Name: name,
+ Arguments: args,
+ },
+ },
+ },
+ },
+ Index: 0,
+ FinishReason: nil,
+ }},
+ Object: "chat.completion.chunk",
+ }
+ select {
+ case responses <- initialMessage:
+ default:
+ }
+ lastEmittedCount++
+ }
+ }
+ }
+ }
+ return true
+ })
+ if err != nil {
+ return err
+ }
+ // Extract reasoning before processing tool calls
+ reasoning, cleanedResult := functions.ExtractReasoning(result)
+ result = cleanedResult
+
+ textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
+ result = functions.CleanupLLMResult(result, config.FunctionsConfig)
+ functionResults := functions.ParseFunctionCall(result, config.FunctionsConfig)
+ xlog.Debug("Text content to return", "text", textContentToReturn)
+ noActionToRun := len(functionResults) > 0 && functionResults[0].Name == noAction || len(functionResults) == 0
+
+ switch {
+ case noActionToRun:
+ initialMessage := schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
+ Object: "chat.completion.chunk",
+ }
+ responses <- initialMessage
+
+ result, err := handleQuestion(config, cl, req, ml, startupOptions, functionResults, result, prompt)
+ if err != nil {
+ xlog.Error("error handling question", "error", err)
+ return err
+ }
+ usage := schema.OpenAIUsage{
+ PromptTokens: tokenUsage.Prompt,
+ CompletionTokens: tokenUsage.Completion,
+ TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
+ }
+ if extraUsage {
+ usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
+ usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
+ }
+
+ var deltaReasoning *string
+ if reasoning != "" {
+ deltaReasoning = &reasoning
+ }
+ delta := &schema.Message{Content: &result}
+ if deltaReasoning != nil {
+ delta.Reasoning = deltaReasoning
+ }
+
+ resp := schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
+ Object: "chat.completion.chunk",
+ Usage: usage,
+ }
+
+ responses <- resp
+
+ default:
+ for i, ss := range functionResults {
+ name, args := ss.Name, ss.Arguments
+
+ initialMessage := schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Choices: []schema.Choice{{
+ Delta: &schema.Message{
+ Role: "assistant",
+ ToolCalls: []schema.ToolCall{
+ {
+ Index: i,
+ ID: id,
+ Type: "function",
+ FunctionCall: schema.FunctionCall{
+ Name: name,
+ },
+ },
+ },
+ },
+ Index: 0,
+ FinishReason: nil,
+ }},
+ Object: "chat.completion.chunk",
+ }
+ responses <- initialMessage
+
+ responses <- schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Choices: []schema.Choice{{
+ Delta: &schema.Message{
+ Role: "assistant",
+ Content: &textContentToReturn,
+ ToolCalls: []schema.ToolCall{
+ {
+ Index: i,
+ ID: id,
+ Type: "function",
+ FunctionCall: schema.FunctionCall{
+ Arguments: args,
+ },
+ },
+ },
+ },
+ Index: 0,
+ FinishReason: nil,
+ }},
+ Object: "chat.completion.chunk",
+ }
+ }
+ }
+
+ close(responses)
+ return err
+ }
+
+ return func(c echo.Context) error {
+ textContentToReturn = ""
+ id = uuid.New().String()
+ created = int(time.Now().Unix())
+
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+ if !ok || input.Model == "" {
+ return echo.ErrBadRequest
+ }
+
+ extraUsage := c.Request().Header.Get("Extra-Usage") != ""
+
+ config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || config == nil {
+ return echo.ErrBadRequest
+ }
+
+ xlog.Debug("Chat endpoint configuration read", "config", config)
+
+ funcs := input.Functions
+ shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
+ strictMode := false
+
+ for _, f := range input.Functions {
+ if f.Strict {
+ strictMode = true
+ break
+ }
+ }
+
+ // Allow the user to set custom actions via config file
+ // to be "embedded" in each model
+ noActionName := "answer"
+ noActionDescription := "use this action to answer without performing any action"
+
+ if config.FunctionsConfig.NoActionFunctionName != "" {
+ noActionName = config.FunctionsConfig.NoActionFunctionName
+ }
+ if config.FunctionsConfig.NoActionDescriptionName != "" {
+ noActionDescription = config.FunctionsConfig.NoActionDescriptionName
+ }
+
+ // If we are using a response format, we need to generate a grammar for it
+ if config.ResponseFormatMap != nil {
+ d := schema.ChatCompletionResponseFormat{}
+ dat, err := json.Marshal(config.ResponseFormatMap)
+ if err != nil {
+ return err
+ }
+ err = json.Unmarshal(dat, &d)
+ if err != nil {
+ return err
+ }
+
+ switch d.Type {
+ case "json_object":
+ input.Grammar = functions.JSONBNF
+ case "json_schema":
+ d := schema.JsonSchemaRequest{}
+ dat, err := json.Marshal(config.ResponseFormatMap)
+ if err != nil {
+ return err
+ }
+ err = json.Unmarshal(dat, &d)
+ if err != nil {
+ return err
+ }
+ fs := &functions.JSONFunctionStructure{
+ AnyOf: []functions.Item{d.JsonSchema.Schema},
+ }
+ g, err := fs.Grammar(config.FunctionsConfig.GrammarOptions()...)
+ if err == nil {
+ input.Grammar = g
+ } else {
+ xlog.Error("Failed generating grammar", "error", err)
+ }
+ }
+ }
+
+ config.Grammar = input.Grammar
+
+ if shouldUseFn {
+ xlog.Debug("Response needs to process functions")
+ }
+
+ switch {
+ // Generates grammar with internal's LocalAI engine
+ case (!config.FunctionsConfig.GrammarConfig.NoGrammar || strictMode) && shouldUseFn:
+ noActionGrammar := functions.Function{
+ Name: noActionName,
+ Description: noActionDescription,
+ Parameters: map[string]interface{}{
+ "properties": map[string]interface{}{
+ "message": map[string]interface{}{
+ "type": "string",
+ "description": "The message to reply the user with",
+ }},
+ },
+ }
+
+ // Append the no action function
+ if !config.FunctionsConfig.DisableNoAction && !strictMode {
+ funcs = append(funcs, noActionGrammar)
+ }
+
+ // Force picking one of the functions by the request
+ if config.FunctionToCall() != "" {
+ funcs = funcs.Select(config.FunctionToCall())
+ }
+
+ // Update input grammar or json_schema based on use_llama_grammar option
+ jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey)
+ g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarOptions()...)
+ if err == nil {
+ config.Grammar = g
+ } else {
+ xlog.Error("Failed generating grammar", "error", err)
+ }
+ case input.JSONFunctionGrammarObject != nil:
+ g, err := input.JSONFunctionGrammarObject.Grammar(config.FunctionsConfig.GrammarOptions()...)
+ if err == nil {
+ config.Grammar = g
+ } else {
+ xlog.Error("Failed generating grammar", "error", err)
+ }
+
+ default:
+ // Force picking one of the functions by the request
+ if config.FunctionToCall() != "" {
+ funcs = funcs.Select(config.FunctionToCall())
+ }
+ }
+
+ // process functions if we have any defined or if we have a function call string
+
+ // functions are not supported in stream mode (yet?)
+ toStream := input.Stream
+
+ xlog.Debug("Parameters", "config", config)
+
+ var predInput string
+
+ // If we are using the tokenizer template, we don't need to process the messages
+ // unless we are processing functions
+ if !config.TemplateConfig.UseTokenizerTemplate {
+ predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn)
+
+ xlog.Debug("Prompt (after templating)", "prompt", predInput)
+ if config.Grammar != "" {
+ xlog.Debug("Grammar", "grammar", config.Grammar)
+ }
+ }
+
+ switch {
+ case toStream:
+
+ xlog.Debug("Stream request received")
+ c.Response().Header().Set("Content-Type", "text/event-stream")
+ c.Response().Header().Set("Cache-Control", "no-cache")
+ c.Response().Header().Set("Connection", "keep-alive")
+ c.Response().Header().Set("X-Correlation-ID", id)
+
+ responses := make(chan schema.OpenAIResponse)
+ ended := make(chan error, 1)
+
+ go func() {
+ if !shouldUseFn {
+ ended <- process(predInput, input, config, ml, responses, extraUsage)
+ } else {
+ ended <- processTools(noActionName, predInput, input, config, ml, responses, extraUsage)
+ }
+ }()
+
+ usage := &schema.OpenAIUsage{}
+ toolsCalled := false
+
+ LOOP:
+ for {
+ select {
+ case <-input.Context.Done():
+ // Context was cancelled (client disconnected or request cancelled)
+ xlog.Debug("Request context cancelled, stopping stream")
+ input.Cancel()
+ break LOOP
+ case ev := <-responses:
+ if len(ev.Choices) == 0 {
+ xlog.Debug("No choices in the response, skipping")
+ continue
+ }
+ usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
+ if len(ev.Choices[0].Delta.ToolCalls) > 0 {
+ toolsCalled = true
+ }
+ respData, err := json.Marshal(ev)
+ if err != nil {
+ xlog.Debug("Failed to marshal response", "error", err)
+ input.Cancel()
+ continue
+ }
+ xlog.Debug("Sending chunk", "chunk", string(respData))
+ _, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
+ if err != nil {
+ xlog.Debug("Sending chunk failed", "error", err)
+ input.Cancel()
+ return err
+ }
+ c.Response().Flush()
+ case err := <-ended:
+ if err == nil {
+ break LOOP
+ }
+ xlog.Error("Stream ended with error", "error", err)
+
+ stopReason := FinishReasonStop
+ resp := &schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Choices: []schema.Choice{
+ {
+ FinishReason: &stopReason,
+ Index: 0,
+ Delta: &schema.Message{Content: "Internal error: " + err.Error()},
+ }},
+ Object: "chat.completion.chunk",
+ Usage: *usage,
+ }
+ respData, marshalErr := json.Marshal(resp)
+ if marshalErr != nil {
+ xlog.Error("Failed to marshal error response", "error", marshalErr)
+ // Send a simple error message as fallback
+ fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
+ } else {
+ fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
+ }
+ fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
+ c.Response().Flush()
+
+ return nil
+ }
+ }
+
+ finishReason := FinishReasonStop
+ if toolsCalled && len(input.Tools) > 0 {
+ finishReason = FinishReasonToolCalls
+ } else if toolsCalled {
+ finishReason = FinishReasonFunctionCall
+ }
+
+ resp := &schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Choices: []schema.Choice{
+ {
+ FinishReason: &finishReason,
+ Index: 0,
+ Delta: &schema.Message{},
+ }},
+ Object: "chat.completion.chunk",
+ Usage: *usage,
+ }
+ respData, _ := json.Marshal(resp)
+
+ fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
+ fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
+ c.Response().Flush()
+ xlog.Debug("Stream ended")
+ return nil
+
+ // no streaming mode
+ default:
+
+ tokenCallback := func(s string, c *[]schema.Choice) {
+ // Extract reasoning from the response
+ reasoning, cleanedS := functions.ExtractReasoning(s)
+ s = cleanedS
+
+ if !shouldUseFn {
+ // no function is called, just reply and use stop as finish reason
+ stopReason := FinishReasonStop
+ message := &schema.Message{Role: "assistant", Content: &s}
+ if reasoning != "" {
+ message.Reasoning = &reasoning
+ }
+ *c = append(*c, schema.Choice{FinishReason: &stopReason, Index: 0, Message: message})
+ return
+ }
+
+ textContentToReturn = functions.ParseTextContent(s, config.FunctionsConfig)
+ s = functions.CleanupLLMResult(s, config.FunctionsConfig)
+ results := functions.ParseFunctionCall(s, config.FunctionsConfig)
+ xlog.Debug("Text content to return", "text", textContentToReturn)
+ noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
+
+ switch {
+ case noActionsToRun:
+ result, err := handleQuestion(config, cl, input, ml, startupOptions, results, s, predInput)
+ if err != nil {
+ xlog.Error("error handling question", "error", err)
+ return
+ }
+
+ stopReason := FinishReasonStop
+ message := &schema.Message{Role: "assistant", Content: &result}
+ if reasoning != "" {
+ message.Reasoning = &reasoning
+ }
+ *c = append(*c, schema.Choice{
+ FinishReason: &stopReason,
+ Message: message})
+ default:
+ toolCallsReason := FinishReasonToolCalls
+ toolChoice := schema.Choice{
+ FinishReason: &toolCallsReason,
+ Message: &schema.Message{
+ Role: "assistant",
+ },
+ }
+ if reasoning != "" {
+ toolChoice.Message.Reasoning = &reasoning
+ }
+
+ for _, ss := range results {
+ name, args := ss.Name, ss.Arguments
+ if len(input.Tools) > 0 {
+ // If we are using tools, we condense the function calls into
+ // a single response choice with all the tools
+ toolChoice.Message.Content = textContentToReturn
+ toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
+ schema.ToolCall{
+ ID: id,
+ Type: "function",
+ FunctionCall: schema.FunctionCall{
+ Name: name,
+ Arguments: args,
+ },
+ },
+ )
+ } else {
+ // otherwise we return more choices directly (deprecated)
+ functionCallReason := FinishReasonFunctionCall
+ message := &schema.Message{
+ Role: "assistant",
+ Content: &textContentToReturn,
+ FunctionCall: map[string]interface{}{
+ "name": name,
+ "arguments": args,
+ },
+ }
+ if reasoning != "" {
+ message.Reasoning = &reasoning
+ }
+ *c = append(*c, schema.Choice{
+ FinishReason: &functionCallReason,
+ Message: message,
+ })
+ }
+ }
+
+ if len(input.Tools) > 0 {
+ // we need to append our result if we are using tools
+ *c = append(*c, toolChoice)
+ }
+ }
+
+ }
+
+ // Echo properly supports context cancellation via c.Request().Context()
+ // No workaround needed!
+
+ result, tokenUsage, err := ComputeChoices(
+ input,
+ predInput,
+ config,
+ cl,
+ startupOptions,
+ ml,
+ tokenCallback,
+ nil,
+ )
+ if err != nil {
+ return err
+ }
+ usage := schema.OpenAIUsage{
+ PromptTokens: tokenUsage.Prompt,
+ CompletionTokens: tokenUsage.Completion,
+ TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
+ }
+ if extraUsage {
+ usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
+ usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
+ }
+
+ resp := &schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Choices: result,
+ Object: "chat.completion",
+ Usage: usage,
+ }
+ respData, _ := json.Marshal(resp)
+ xlog.Debug("Response", "response", string(respData))
+
+ // Return the prediction in the response body
+ return c.JSON(200, resp)
+ }
+ }
+}
+
+func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
+
+ if len(funcResults) == 0 && result != "" {
+ xlog.Debug("nothing function results but we had a message from the LLM")
+
+ return result, nil
+ }
+
+ xlog.Debug("nothing to do, computing a reply")
+ arg := ""
+ if len(funcResults) > 0 {
+ arg = funcResults[0].Arguments
+ }
+ // If there is a message that the LLM already sends as part of the JSON reply, use it
+ arguments := map[string]interface{}{}
+ if err := json.Unmarshal([]byte(arg), &arguments); err != nil {
+ xlog.Debug("handleQuestion: function result did not contain a valid JSON object")
+ }
+ m, exists := arguments["message"]
+ if exists {
+ switch message := m.(type) {
+ case string:
+ if message != "" {
+ xlog.Debug("Reply received from LLM", "message", message)
+ message = backend.Finetune(*config, prompt, message)
+ xlog.Debug("Reply received from LLM(finetuned)", "message", message)
+
+ return message, nil
+ }
+ }
+ }
+
+ xlog.Debug("No action received from LLM, without a message, computing a reply")
+ // Otherwise ask the LLM to understand the JSON output and the context, and return a message
+ // Note: This costs (in term of CPU/GPU) another computation
+ config.Grammar = ""
+ images := []string{}
+ for _, m := range input.Messages {
+ images = append(images, m.StringImages...)
+ }
+ videos := []string{}
+ for _, m := range input.Messages {
+ videos = append(videos, m.StringVideos...)
+ }
+ audios := []string{}
+ for _, m := range input.Messages {
+ audios = append(audios, m.StringAudios...)
+ }
+
+ // Serialize tools and tool_choice to JSON strings
+ toolsJSON := ""
+ if len(input.Tools) > 0 {
+ toolsBytes, err := json.Marshal(input.Tools)
+ if err == nil {
+ toolsJSON = string(toolsBytes)
+ }
+ }
+ toolChoiceJSON := ""
+ if input.ToolsChoice != nil {
+ toolChoiceBytes, err := json.Marshal(input.ToolsChoice)
+ if err == nil {
+ toolChoiceJSON = string(toolChoiceBytes)
+ }
+ }
+
+ // Extract logprobs from request
+ // According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position
+ var logprobs *int
+ var topLogprobs *int
+ if input.Logprobs.IsEnabled() {
+ // If logprobs is enabled, use top_logprobs if provided, otherwise default to 1
+ if input.TopLogprobs != nil {
+ topLogprobs = input.TopLogprobs
+ // For backend compatibility, set logprobs to the top_logprobs value
+ logprobs = input.TopLogprobs
+ } else {
+ // Default to 1 if logprobs is true but top_logprobs not specified
+ val := 1
+ logprobs = &val
+ topLogprobs = &val
+ }
+ }
+
+ // Extract logit_bias from request
+ // According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
+ var logitBias map[string]float64
+ if len(input.LogitBias) > 0 {
+ logitBias = input.LogitBias
+ }
+
+ predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias)
+ if err != nil {
+ xlog.Error("model inference failed", "error", err)
+ return "", err
+ }
+
+ prediction, err := predFunc()
+ if err != nil {
+ xlog.Error("prediction failed", "error", err)
+ return "", err
+ }
+ return backend.Finetune(*config, prompt, prediction.Response), nil
+}
diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go
new file mode 100644
index 0000000000000000000000000000000000000000..25935120d44d25a4a64aeef53fdb156d75a4fd88
--- /dev/null
+++ b/core/http/endpoints/openai/completion.go
@@ -0,0 +1,258 @@
+package openai
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "time"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+
+ "github.com/google/uuid"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/core/templates"
+ "github.com/mudler/LocalAI/pkg/functions"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/xlog"
+)
+
+// CompletionEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/completions
+// @Summary Generate completions for a given prompt and model.
+// @Param request body schema.OpenAIRequest true "query params"
+// @Success 200 {object} schema.OpenAIResponse "Response"
+// @Router /v1/completions [post]
+func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
+ tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
+ created := int(time.Now().Unix())
+
+ usage := schema.OpenAIUsage{
+ PromptTokens: tokenUsage.Prompt,
+ CompletionTokens: tokenUsage.Completion,
+ TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
+ }
+ if extraUsage {
+ usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
+ usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
+ }
+ resp := schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Choices: []schema.Choice{
+ {
+ Index: 0,
+ Text: s,
+ FinishReason: nil,
+ },
+ },
+ Object: "text_completion",
+ Usage: usage,
+ }
+ xlog.Debug("Sending goroutine", "text", s)
+
+ responses <- resp
+ return true
+ }
+ _, _, err := ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
+ close(responses)
+ return err
+ }
+
+ return func(c echo.Context) error {
+
+ created := int(time.Now().Unix())
+
+ // Handle Correlation
+ id := c.Request().Header.Get("X-Correlation-ID")
+ if id == "" {
+ id = uuid.New().String()
+ }
+ extraUsage := c.Request().Header.Get("Extra-Usage") != ""
+
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+ if !ok || input.Model == "" {
+ return echo.ErrBadRequest
+ }
+
+ config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || config == nil {
+ return echo.ErrBadRequest
+ }
+
+ if config.ResponseFormatMap != nil {
+ d := schema.ChatCompletionResponseFormat{}
+ dat, _ := json.Marshal(config.ResponseFormatMap)
+ _ = json.Unmarshal(dat, &d)
+ if d.Type == "json_object" {
+ input.Grammar = functions.JSONBNF
+ }
+ }
+
+ config.Grammar = input.Grammar
+
+ xlog.Debug("Parameter Config", "config", config)
+
+ if input.Stream {
+ xlog.Debug("Stream request received")
+ c.Response().Header().Set("Content-Type", "text/event-stream")
+ c.Response().Header().Set("Cache-Control", "no-cache")
+ c.Response().Header().Set("Connection", "keep-alive")
+
+ if len(config.PromptStrings) > 1 {
+ return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
+ }
+
+ predInput := config.PromptStrings[0]
+
+ templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
+ Input: predInput,
+ SystemPrompt: config.SystemPrompt,
+ ReasoningEffort: input.ReasoningEffort,
+ Metadata: input.Metadata,
+ })
+ if err == nil {
+ predInput = templatedInput
+ xlog.Debug("Template found, input modified", "input", predInput)
+ }
+
+ responses := make(chan schema.OpenAIResponse)
+
+ ended := make(chan error)
+ go func() {
+ ended <- process(id, predInput, input, config, ml, responses, extraUsage)
+ }()
+
+ LOOP:
+ for {
+ select {
+ case ev := <-responses:
+ if len(ev.Choices) == 0 {
+ xlog.Debug("No choices in the response, skipping")
+ continue
+ }
+ respData, err := json.Marshal(ev)
+ if err != nil {
+ xlog.Debug("Failed to marshal response", "error", err)
+ continue
+ }
+
+ xlog.Debug("Sending chunk", "chunk", string(respData))
+ _, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
+ if err != nil {
+ return err
+ }
+ c.Response().Flush()
+ case err := <-ended:
+ if err == nil {
+ break LOOP
+ }
+ xlog.Error("Stream ended with error", "error", err)
+
+ stopReason := FinishReasonStop
+ errorResp := schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: input.Model,
+ Choices: []schema.Choice{
+ {
+ Index: 0,
+ FinishReason: &stopReason,
+ Text: "Internal error: " + err.Error(),
+ },
+ },
+ Object: "text_completion",
+ }
+ errorData, marshalErr := json.Marshal(errorResp)
+ if marshalErr != nil {
+ xlog.Error("Failed to marshal error response", "error", marshalErr)
+ // Send a simple error message as fallback
+ fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
+ } else {
+ fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
+ }
+ c.Response().Flush()
+ return nil
+ }
+ }
+
+ stopReason := FinishReasonStop
+ resp := &schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Choices: []schema.Choice{
+ {
+ Index: 0,
+ FinishReason: &stopReason,
+ },
+ },
+ Object: "text_completion",
+ }
+ respData, _ := json.Marshal(resp)
+
+ fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
+ fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
+ c.Response().Flush()
+ return nil
+ }
+
+ var result []schema.Choice
+
+ totalTokenUsage := backend.TokenUsage{}
+
+ for k, i := range config.PromptStrings {
+ templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
+ SystemPrompt: config.SystemPrompt,
+ Input: i,
+ ReasoningEffort: input.ReasoningEffort,
+ Metadata: input.Metadata,
+ })
+ if err == nil {
+ i = templatedInput
+ xlog.Debug("Template found, input modified", "input", i)
+ }
+
+ r, tokenUsage, err := ComputeChoices(
+ input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) {
+ stopReason := FinishReasonStop
+ *c = append(*c, schema.Choice{Text: s, FinishReason: &stopReason, Index: k})
+ }, nil)
+ if err != nil {
+ return err
+ }
+
+ totalTokenUsage.TimingTokenGeneration += tokenUsage.TimingTokenGeneration
+ totalTokenUsage.TimingPromptProcessing += tokenUsage.TimingPromptProcessing
+
+ result = append(result, r...)
+ }
+ usage := schema.OpenAIUsage{
+ PromptTokens: totalTokenUsage.Prompt,
+ CompletionTokens: totalTokenUsage.Completion,
+ TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
+ }
+ if extraUsage {
+ usage.TimingTokenGeneration = totalTokenUsage.TimingTokenGeneration
+ usage.TimingPromptProcessing = totalTokenUsage.TimingPromptProcessing
+ }
+
+ resp := &schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Choices: result,
+ Object: "text_completion",
+ Usage: usage,
+ }
+
+ jsonResult, _ := json.Marshal(resp)
+ xlog.Debug("Response", "response", string(jsonResult))
+
+ // Return the prediction in the response body
+ return c.JSON(200, resp)
+ }
+}
diff --git a/core/http/endpoints/openai/constants.go b/core/http/endpoints/openai/constants.go
new file mode 100644
index 0000000000000000000000000000000000000000..bc7dae10bccb37a8304fc4d45f770d1c23e8eb99
--- /dev/null
+++ b/core/http/endpoints/openai/constants.go
@@ -0,0 +1,8 @@
+package openai
+
+// Finish reason constants for OpenAI API responses
+const (
+ FinishReasonStop = "stop"
+ FinishReasonToolCalls = "tool_calls"
+ FinishReasonFunctionCall = "function_call"
+)
diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go
new file mode 100644
index 0000000000000000000000000000000000000000..1b824df952523f7821a2a693c5ed4afb5a4c0ad6
--- /dev/null
+++ b/core/http/endpoints/openai/edit.go
@@ -0,0 +1,103 @@
+package openai
+
+import (
+ "encoding/json"
+ "time"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+
+ "github.com/google/uuid"
+ "github.com/mudler/LocalAI/core/schema"
+
+ "github.com/mudler/LocalAI/core/templates"
+ "github.com/mudler/LocalAI/pkg/model"
+
+ "github.com/mudler/xlog"
+)
+
+// EditEndpoint is the OpenAI edit API endpoint
+// @Summary OpenAI edit endpoint
+// @Param request body schema.OpenAIRequest true "query params"
+// @Success 200 {object} schema.OpenAIResponse "Response"
+// @Router /v1/edits [post]
+func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+
+ return func(c echo.Context) error {
+
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+ if !ok || input.Model == "" {
+ return echo.ErrBadRequest
+ }
+ // Opt-in extra usage flag
+ extraUsage := c.Request().Header.Get("Extra-Usage") != ""
+
+ config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || config == nil {
+ return echo.ErrBadRequest
+ }
+
+ xlog.Debug("Edit Endpoint Input", "input", input)
+ xlog.Debug("Edit Endpoint Config", "config", *config)
+
+ var result []schema.Choice
+ totalTokenUsage := backend.TokenUsage{}
+
+ for _, i := range config.InputStrings {
+ templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.EditPromptTemplate, *config, templates.PromptTemplateData{
+ Input: i,
+ Instruction: input.Instruction,
+ SystemPrompt: config.SystemPrompt,
+ ReasoningEffort: input.ReasoningEffort,
+ Metadata: input.Metadata,
+ })
+ if err == nil {
+ i = templatedInput
+ xlog.Debug("Template found, input modified", "input", i)
+ }
+
+ r, tokenUsage, err := ComputeChoices(input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) {
+ *c = append(*c, schema.Choice{Text: s})
+ }, nil)
+ if err != nil {
+ return err
+ }
+
+ totalTokenUsage.Prompt += tokenUsage.Prompt
+ totalTokenUsage.Completion += tokenUsage.Completion
+
+ totalTokenUsage.TimingTokenGeneration += tokenUsage.TimingTokenGeneration
+ totalTokenUsage.TimingPromptProcessing += tokenUsage.TimingPromptProcessing
+
+ result = append(result, r...)
+ }
+ usage := schema.OpenAIUsage{
+ PromptTokens: totalTokenUsage.Prompt,
+ CompletionTokens: totalTokenUsage.Completion,
+ TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
+ }
+ if extraUsage {
+ usage.TimingTokenGeneration = totalTokenUsage.TimingTokenGeneration
+ usage.TimingPromptProcessing = totalTokenUsage.TimingPromptProcessing
+ }
+
+ id := uuid.New().String()
+ created := int(time.Now().Unix())
+ resp := &schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Choices: result,
+ Object: "edit",
+ Usage: usage,
+ }
+
+ jsonResult, _ := json.Marshal(resp)
+ xlog.Debug("Response", "response", string(jsonResult))
+
+ // Return the prediction in the response body
+ return c.JSON(200, resp)
+ }
+}
diff --git a/core/http/endpoints/openai/embeddings.go b/core/http/endpoints/openai/embeddings.go
new file mode 100644
index 0000000000000000000000000000000000000000..b88f3eb03795e70e5d94692df438ebbed808c0a5
--- /dev/null
+++ b/core/http/endpoints/openai/embeddings.go
@@ -0,0 +1,83 @@
+package openai
+
+import (
+ "encoding/json"
+ "time"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/pkg/model"
+
+ "github.com/google/uuid"
+ "github.com/mudler/LocalAI/core/schema"
+
+ "github.com/mudler/xlog"
+)
+
+// EmbeddingsEndpoint is the OpenAI Embeddings API endpoint https://platform.openai.com/docs/api-reference/embeddings
+// @Summary Get a vector representation of a given input that can be easily consumed by machine learning models and algorithms.
+// @Param request body schema.OpenAIRequest true "query params"
+// @Success 200 {object} schema.OpenAIResponse "Response"
+// @Router /v1/embeddings [post]
+func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+ if !ok || input.Model == "" {
+ return echo.ErrBadRequest
+ }
+
+ config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || config == nil {
+ return echo.ErrBadRequest
+ }
+
+ xlog.Debug("Parameter Config", "config", config)
+ items := []schema.Item{}
+
+ for i, s := range config.InputToken {
+ // get the model function to call for the result
+ embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig)
+ if err != nil {
+ return err
+ }
+
+ embeddings, err := embedFn()
+ if err != nil {
+ return err
+ }
+ items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
+ }
+
+ for i, s := range config.InputStrings {
+ // get the model function to call for the result
+ embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig)
+ if err != nil {
+ return err
+ }
+
+ embeddings, err := embedFn()
+ if err != nil {
+ return err
+ }
+ items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
+ }
+
+ id := uuid.New().String()
+ created := int(time.Now().Unix())
+ resp := &schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Data: items,
+ Object: "list",
+ }
+
+ jsonResult, _ := json.Marshal(resp)
+ xlog.Debug("Response", "response", string(jsonResult))
+
+ // Return the prediction in the response body
+ return c.JSON(200, resp)
+ }
+}
diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go
new file mode 100644
index 0000000000000000000000000000000000000000..3575fee2b1676a69d5a506a70984bc5a3c5cadcd
--- /dev/null
+++ b/core/http/endpoints/openai/image.go
@@ -0,0 +1,297 @@
+package openai
+
+import (
+ "bufio"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+
+ "github.com/mudler/LocalAI/core/backend"
+
+ model "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/xlog"
+)
+
+func downloadFile(url string) (string, error) {
+ // Get the data
+ resp, err := http.Get(url)
+ if err != nil {
+ return "", err
+ }
+ defer resp.Body.Close()
+
+ // Create the file
+ out, err := os.CreateTemp("", "image")
+ if err != nil {
+ return "", err
+ }
+ defer out.Close()
+
+ // Write the body to file
+ _, err = io.Copy(out, resp.Body)
+ return out.Name(), err
+}
+
+//
+
+/*
+*
+
+ curl http://localhost:8080/v1/images/generations \
+ -H "Content-Type: application/json" \
+ -d '{
+ "prompt": "A cute baby sea otter",
+ "n": 1,
+ "size": "512x512"
+ }'
+
+*
+*/
+// ImageEndpoint is the OpenAI Image generation API endpoint https://platform.openai.com/docs/api-reference/images/create
+// @Summary Creates an image given a prompt.
+// @Param request body schema.OpenAIRequest true "query params"
+// @Success 200 {object} schema.OpenAIResponse "Response"
+// @Router /v1/images/generations [post]
+func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+ if !ok || input.Model == "" {
+ xlog.Error("Image Endpoint - Invalid Input")
+ return echo.ErrBadRequest
+ }
+
+ config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || config == nil {
+ xlog.Error("Image Endpoint - Invalid Config")
+ return echo.ErrBadRequest
+ }
+
+ // Process input images (for img2img/inpainting)
+ src := ""
+ if input.File != "" {
+ src = processImageFile(input.File, appConfig.GeneratedContentDir)
+ if src != "" {
+ defer os.RemoveAll(src)
+ }
+ }
+
+ // Process multiple input images
+ var inputImages []string
+ if len(input.Files) > 0 {
+ for _, file := range input.Files {
+ processedFile := processImageFile(file, appConfig.GeneratedContentDir)
+ if processedFile != "" {
+ inputImages = append(inputImages, processedFile)
+ defer os.RemoveAll(processedFile)
+ }
+ }
+ }
+
+ // Process reference images
+ var refImages []string
+ if len(input.RefImages) > 0 {
+ for _, file := range input.RefImages {
+ processedFile := processImageFile(file, appConfig.GeneratedContentDir)
+ if processedFile != "" {
+ refImages = append(refImages, processedFile)
+ defer os.RemoveAll(processedFile)
+ }
+ }
+ }
+
+ xlog.Debug("Parameter Config", "config", config)
+
+ switch config.Backend {
+ case "stablediffusion":
+ config.Backend = model.StableDiffusionGGMLBackend
+ case "":
+ config.Backend = model.StableDiffusionGGMLBackend
+ }
+
+ if !strings.Contains(input.Size, "x") {
+ input.Size = "512x512"
+ xlog.Warn("Invalid size, using default 512x512")
+ }
+
+ sizeParts := strings.Split(input.Size, "x")
+ if len(sizeParts) != 2 {
+ return fmt.Errorf("invalid value for 'size'")
+ }
+ width, err := strconv.Atoi(sizeParts[0])
+ if err != nil {
+ return fmt.Errorf("invalid value for 'size'")
+ }
+ height, err := strconv.Atoi(sizeParts[1])
+ if err != nil {
+ return fmt.Errorf("invalid value for 'size'")
+ }
+
+ b64JSON := config.ResponseFormat == "b64_json"
+
+ // src and clip_skip
+ var result []schema.Item
+ for _, i := range config.PromptStrings {
+ n := input.N
+ if input.N == 0 {
+ n = 1
+ }
+ for j := 0; j < n; j++ {
+ prompts := strings.Split(i, "|")
+ positive_prompt := prompts[0]
+ negative_prompt := ""
+ if len(prompts) > 1 {
+ negative_prompt = prompts[1]
+ }
+
+ step := config.Step
+ if step == 0 {
+ step = 15
+ }
+
+ if input.Step != 0 {
+ step = input.Step
+ }
+
+ tempDir := ""
+ if !b64JSON {
+ tempDir = filepath.Join(appConfig.GeneratedContentDir, "images")
+ }
+ // Create a temporary file
+ outputFile, err := os.CreateTemp(tempDir, "b64")
+ if err != nil {
+ return err
+ }
+ outputFile.Close()
+
+ output := outputFile.Name() + ".png"
+ // Rename the temporary file
+ err = os.Rename(outputFile.Name(), output)
+ if err != nil {
+ return err
+ }
+
+ baseURL := middleware.BaseURL(c)
+
+ // Use the first input image as src if available, otherwise use the original src
+ inputSrc := src
+ if len(inputImages) > 0 {
+ inputSrc = inputImages[0]
+ }
+
+ fn, err := backend.ImageGeneration(height, width, step, *config.Seed, positive_prompt, negative_prompt, inputSrc, output, ml, *config, appConfig, refImages)
+ if err != nil {
+ return err
+ }
+ if err := fn(); err != nil {
+ return err
+ }
+
+ item := &schema.Item{}
+
+ if b64JSON {
+ defer os.RemoveAll(output)
+ data, err := os.ReadFile(output)
+ if err != nil {
+ return err
+ }
+ item.B64JSON = base64.StdEncoding.EncodeToString(data)
+ } else {
+ base := filepath.Base(output)
+ item.URL, err = url.JoinPath(baseURL, "generated-images", base)
+ if err != nil {
+ return err
+ }
+ }
+
+ result = append(result, *item)
+ }
+ }
+
+ id := uuid.New().String()
+ created := int(time.Now().Unix())
+ resp := &schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Data: result,
+ Usage: schema.OpenAIUsage{
+ PromptTokens: 0,
+ CompletionTokens: 0,
+ TotalTokens: 0,
+ InputTokens: 0,
+ OutputTokens: 0,
+ InputTokensDetails: &schema.InputTokensDetails{
+ TextTokens: 0,
+ ImageTokens: 0,
+ },
+ },
+ }
+
+ jsonResult, _ := json.Marshal(resp)
+ xlog.Debug("Response", "response", string(jsonResult))
+
+ // Return the prediction in the response body
+ return c.JSON(200, resp)
+ }
+}
+
+// processImageFile handles a single image file (URL or base64) and returns the path to the temporary file
+func processImageFile(file string, generatedContentDir string) string {
+ fileData := []byte{}
+ var err error
+
+ // check if file is an URL, if so download it and save it to a temporary file
+ if strings.HasPrefix(file, "http://") || strings.HasPrefix(file, "https://") {
+ out, err := downloadFile(file)
+ if err != nil {
+ xlog.Error("Failed downloading file", "error", err, "file", file)
+ return ""
+ }
+ defer os.RemoveAll(out)
+
+ fileData, err = os.ReadFile(out)
+ if err != nil {
+ xlog.Error("Failed reading downloaded file", "error", err, "file", out)
+ return ""
+ }
+ } else {
+ // base 64 decode the file and write it somewhere that we will cleanup
+ fileData, err = base64.StdEncoding.DecodeString(file)
+ if err != nil {
+ xlog.Error("Failed decoding base64 file", "error", err)
+ return ""
+ }
+ }
+
+ // Create a temporary file
+ outputFile, err := os.CreateTemp(generatedContentDir, "b64")
+ if err != nil {
+ xlog.Error("Failed creating temporary file", "error", err)
+ return ""
+ }
+
+ // write the base64 result
+ writer := bufio.NewWriter(outputFile)
+ _, err = writer.Write(fileData)
+ if err != nil {
+ outputFile.Close()
+ xlog.Error("Failed writing to temporary file", "error", err)
+ return ""
+ }
+ outputFile.Close()
+
+ return outputFile.Name()
+}
diff --git a/core/http/endpoints/openai/inference.go b/core/http/endpoints/openai/inference.go
new file mode 100644
index 0000000000000000000000000000000000000000..37b14c98bcfadfa34eb8cc0efed369b29e5b649b
--- /dev/null
+++ b/core/http/endpoints/openai/inference.go
@@ -0,0 +1,115 @@
+package openai
+
+import (
+ "encoding/json"
+
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+
+ "github.com/mudler/LocalAI/core/schema"
+ model "github.com/mudler/LocalAI/pkg/model"
+)
+
+func ComputeChoices(
+ req *schema.OpenAIRequest,
+ predInput string,
+ config *config.ModelConfig,
+ bcl *config.ModelConfigLoader,
+ o *config.ApplicationConfig,
+ loader *model.ModelLoader,
+ cb func(string, *[]schema.Choice),
+ tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
+ n := req.N // number of completions to return
+ result := []schema.Choice{}
+
+ if n == 0 {
+ n = 1
+ }
+
+ images := []string{}
+ for _, m := range req.Messages {
+ images = append(images, m.StringImages...)
+ }
+ videos := []string{}
+ for _, m := range req.Messages {
+ videos = append(videos, m.StringVideos...)
+ }
+ audios := []string{}
+ for _, m := range req.Messages {
+ audios = append(audios, m.StringAudios...)
+ }
+
+ // Serialize tools and tool_choice to JSON strings
+ toolsJSON := ""
+ if len(req.Tools) > 0 {
+ toolsBytes, err := json.Marshal(req.Tools)
+ if err == nil {
+ toolsJSON = string(toolsBytes)
+ }
+ }
+ toolChoiceJSON := ""
+ if req.ToolsChoice != nil {
+ toolChoiceBytes, err := json.Marshal(req.ToolsChoice)
+ if err == nil {
+ toolChoiceJSON = string(toolChoiceBytes)
+ }
+ }
+
+ // Extract logprobs from request
+ // According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position
+ var logprobs *int
+ var topLogprobs *int
+ if req.Logprobs.IsEnabled() {
+ // If logprobs is enabled, use top_logprobs if provided, otherwise default to 1
+ if req.TopLogprobs != nil {
+ topLogprobs = req.TopLogprobs
+ // For backend compatibility, set logprobs to the top_logprobs value
+ logprobs = req.TopLogprobs
+ } else {
+ // Default to 1 if logprobs is true but top_logprobs not specified
+ val := 1
+ logprobs = &val
+ topLogprobs = &val
+ }
+ }
+
+ // Extract logit_bias from request
+ // According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
+ var logitBias map[string]float64
+ if len(req.LogitBias) > 0 {
+ logitBias = req.LogitBias
+ }
+
+ // get the model function to call for the result
+ predFunc, err := backend.ModelInference(
+ req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias)
+ if err != nil {
+ return result, backend.TokenUsage{}, err
+ }
+
+ tokenUsage := backend.TokenUsage{}
+
+ for i := 0; i < n; i++ {
+ prediction, err := predFunc()
+ if err != nil {
+ return result, backend.TokenUsage{}, err
+ }
+
+ tokenUsage.Prompt += prediction.Usage.Prompt
+ tokenUsage.Completion += prediction.Usage.Completion
+ tokenUsage.TimingPromptProcessing += prediction.Usage.TimingPromptProcessing
+ tokenUsage.TimingTokenGeneration += prediction.Usage.TimingTokenGeneration
+
+ finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
+ cb(finetunedResponse, &result)
+
+ // Add logprobs to the last choice if present
+ if prediction.Logprobs != nil && len(result) > 0 {
+ result[len(result)-1].Logprobs = prediction.Logprobs
+ }
+
+ //result = append(result, Choice{Text: prediction})
+
+ }
+ return result, tokenUsage, err
+}
diff --git a/core/http/endpoints/openai/inpainting.go b/core/http/endpoints/openai/inpainting.go
new file mode 100644
index 0000000000000000000000000000000000000000..a27ffea54dc99c46167691443a51ce7ba4a4e316
--- /dev/null
+++ b/core/http/endpoints/openai/inpainting.go
@@ -0,0 +1,279 @@
+package openai
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strconv"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/xlog"
+
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ model "github.com/mudler/LocalAI/pkg/model"
+)
+
+// InpaintingEndpoint handles POST /v1/images/inpainting
+//
+// Swagger / OpenAPI docstring (swaggo):
+// @Summary Image inpainting
+// @Description Perform image inpainting. Accepts multipart/form-data with `image` and `mask` files.
+// @Tags images
+// @Accept multipart/form-data
+// @Produce application/json
+// @Param model formData string true "Model identifier"
+// @Param prompt formData string true "Text prompt guiding the generation"
+// @Param steps formData int false "Number of inference steps (default 25)"
+// @Param image formData file true "Original image file"
+// @Param mask formData file true "Mask image file (white = area to inpaint)"
+// @Success 200 {object} schema.OpenAIResponse
+// @Failure 400 {object} map[string]string
+// @Failure 500 {object} map[string]string
+// @Router /v1/images/inpainting [post]
+func InpaintingEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ // Parse basic form values
+ modelName := c.FormValue("model")
+ prompt := c.FormValue("prompt")
+ stepsStr := c.FormValue("steps")
+
+ if modelName == "" || prompt == "" {
+ xlog.Error("Inpainting Endpoint - missing model or prompt")
+ return echo.ErrBadRequest
+ }
+
+ // steps default
+ steps := 25
+ if stepsStr != "" {
+ if v, err := strconv.Atoi(stepsStr); err == nil {
+ steps = v
+ }
+ }
+
+ // Get uploaded files
+ imageFile, err := c.FormFile("image")
+ if err != nil {
+ xlog.Error("Inpainting Endpoint - missing image file", "error", err)
+ return echo.NewHTTPError(http.StatusBadRequest, "missing image file")
+ }
+ maskFile, err := c.FormFile("mask")
+ if err != nil {
+ xlog.Error("Inpainting Endpoint - missing mask file", "error", err)
+ return echo.NewHTTPError(http.StatusBadRequest, "missing mask file")
+ }
+
+ // Read files into memory (small files expected)
+ imgSrc, err := imageFile.Open()
+ if err != nil {
+ return err
+ }
+ defer imgSrc.Close()
+ imgBytes, err := io.ReadAll(imgSrc)
+ if err != nil {
+ return err
+ }
+
+ maskSrc, err := maskFile.Open()
+ if err != nil {
+ return err
+ }
+ defer maskSrc.Close()
+ maskBytes, err := io.ReadAll(maskSrc)
+ if err != nil {
+ return err
+ }
+
+ // Create JSON with base64 fields expected by backend
+ b64Image := base64.StdEncoding.EncodeToString(imgBytes)
+ b64Mask := base64.StdEncoding.EncodeToString(maskBytes)
+
+ // get model config from context (middleware set it)
+ cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || cfg == nil {
+ xlog.Error("Inpainting Endpoint - model config not found in context")
+ return echo.ErrBadRequest
+ }
+
+ // Use the GeneratedContentDir so the generated PNG is placed where the
+ // HTTP static handler serves `/generated-images`.
+ tmpDir := appConfig.GeneratedContentDir
+ // Ensure the directory exists
+ if err := os.MkdirAll(tmpDir, 0750); err != nil {
+ xlog.Error("Inpainting Endpoint - failed to create generated content dir", "error", err, "dir", tmpDir)
+ return echo.NewHTTPError(http.StatusInternalServerError, "failed to prepare storage")
+ }
+ id := uuid.New().String()
+ jsonPath := filepath.Join(tmpDir, fmt.Sprintf("inpaint_%s.json", id))
+ jsonFile := map[string]string{
+ "image": b64Image,
+ "mask_image": b64Mask,
+ }
+ jf, err := os.CreateTemp(tmpDir, "inpaint_")
+ if err != nil {
+ return err
+ }
+ // setup cleanup on error; if everything succeeds we set success = true
+ success := false
+ var dst string
+ var origRef string
+ var maskRef string
+ defer func() {
+ if !success {
+ // Best-effort cleanup; log any failures
+ if jf != nil {
+ if cerr := jf.Close(); cerr != nil {
+ xlog.Warn("Inpainting Endpoint - failed to close temp json file in cleanup", "error", cerr)
+ }
+ if name := jf.Name(); name != "" {
+ if rerr := os.Remove(name); rerr != nil && !os.IsNotExist(rerr) {
+ xlog.Warn("Inpainting Endpoint - failed to remove temp json file in cleanup", "error", rerr, "file", name)
+ }
+ }
+ }
+ if jsonPath != "" {
+ if rerr := os.Remove(jsonPath); rerr != nil && !os.IsNotExist(rerr) {
+ xlog.Warn("Inpainting Endpoint - failed to remove json file in cleanup", "error", rerr, "file", jsonPath)
+ }
+ }
+ if dst != "" {
+ if rerr := os.Remove(dst); rerr != nil && !os.IsNotExist(rerr) {
+ xlog.Warn("Inpainting Endpoint - failed to remove dst file in cleanup", "error", rerr, "file", dst)
+ }
+ }
+ if origRef != "" {
+ if rerr := os.Remove(origRef); rerr != nil && !os.IsNotExist(rerr) {
+ xlog.Warn("Inpainting Endpoint - failed to remove orig ref file in cleanup", "error", rerr, "file", origRef)
+ }
+ }
+ if maskRef != "" {
+ if rerr := os.Remove(maskRef); rerr != nil && !os.IsNotExist(rerr) {
+ xlog.Warn("Inpainting Endpoint - failed to remove mask ref file in cleanup", "error", rerr, "file", maskRef)
+ }
+ }
+ }
+ }()
+
+ // write original image and mask to disk as ref images so backends that
+ // accept reference image files can use them (maintainer request).
+ origTmp, err := os.CreateTemp(tmpDir, "refimg_")
+ if err != nil {
+ return err
+ }
+ if _, err := origTmp.Write(imgBytes); err != nil {
+ _ = origTmp.Close()
+ _ = os.Remove(origTmp.Name())
+ return err
+ }
+ if cerr := origTmp.Close(); cerr != nil {
+ xlog.Warn("Inpainting Endpoint - failed to close orig temp file", "error", cerr)
+ }
+ origRef = origTmp.Name()
+
+ maskTmp, err := os.CreateTemp(tmpDir, "refmask_")
+ if err != nil {
+ // cleanup origTmp on error
+ _ = os.Remove(origRef)
+ return err
+ }
+ if _, err := maskTmp.Write(maskBytes); err != nil {
+ _ = maskTmp.Close()
+ _ = os.Remove(maskTmp.Name())
+ _ = os.Remove(origRef)
+ return err
+ }
+ if cerr := maskTmp.Close(); cerr != nil {
+ xlog.Warn("Inpainting Endpoint - failed to close mask temp file", "error", cerr)
+ }
+ maskRef = maskTmp.Name()
+ // write JSON
+ enc := json.NewEncoder(jf)
+ if err := enc.Encode(jsonFile); err != nil {
+ if cerr := jf.Close(); cerr != nil {
+ xlog.Warn("Inpainting Endpoint - failed to close temp json file after encode error", "error", cerr)
+ }
+ return err
+ }
+ if cerr := jf.Close(); cerr != nil {
+ xlog.Warn("Inpainting Endpoint - failed to close temp json file", "error", cerr)
+ }
+ // rename to desired name
+ if err := os.Rename(jf.Name(), jsonPath); err != nil {
+ return err
+ }
+ // prepare dst
+ outTmp, err := os.CreateTemp(tmpDir, "out_")
+ if err != nil {
+ return err
+ }
+ if cerr := outTmp.Close(); cerr != nil {
+ xlog.Warn("Inpainting Endpoint - failed to close out temp file", "error", cerr)
+ }
+ dst = outTmp.Name() + ".png"
+ if err := os.Rename(outTmp.Name(), dst); err != nil {
+ return err
+ }
+
+ // Determine width/height default
+ width := 512
+ height := 512
+
+ // Call backend image generation via indirection so tests can stub it
+ // Note: ImageGenerationFunc will call into the loaded model's GenerateImage which expects src JSON
+ // Also pass ref images (orig + mask) so backends that support ref images can use them.
+ refImages := []string{origRef, maskRef}
+ fn, err := backend.ImageGenerationFunc(height, width, steps, 0, prompt, "", jsonPath, dst, ml, *cfg, appConfig, refImages)
+ if err != nil {
+ return err
+ }
+
+ // Execute generation function (blocking)
+ if err := fn(); err != nil {
+ return err
+ }
+
+ // On success, build response URL using BaseURL middleware helper and
+ // the same `generated-images` prefix used by the server static mount.
+ baseURL := middleware.BaseURL(c)
+
+ // Build response using url.JoinPath for correct URL escaping
+ imgPath, err := url.JoinPath(baseURL, "generated-images", filepath.Base(dst))
+ if err != nil {
+ return err
+ }
+
+ created := int(time.Now().Unix())
+ resp := &schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Data: []schema.Item{{
+ URL: imgPath,
+ }},
+ Usage: schema.OpenAIUsage{
+ PromptTokens: 0,
+ CompletionTokens: 0,
+ TotalTokens: 0,
+ InputTokens: 0,
+ OutputTokens: 0,
+ InputTokensDetails: &schema.InputTokensDetails{
+ TextTokens: 0,
+ ImageTokens: 0,
+ },
+ },
+ }
+
+ // mark success so defer cleanup will not remove output files
+ success = true
+
+ return c.JSON(http.StatusOK, resp)
+ }
+}
diff --git a/core/http/endpoints/openai/inpainting_test.go b/core/http/endpoints/openai/inpainting_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..de4678d347e8db95cc3ba1a11c23df6a220aedf7
--- /dev/null
+++ b/core/http/endpoints/openai/inpainting_test.go
@@ -0,0 +1,107 @@
+package openai
+
+import (
+ "bytes"
+ "mime/multipart"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ model "github.com/mudler/LocalAI/pkg/model"
+ "github.com/stretchr/testify/require"
+)
+
+func makeMultipartRequest(t *testing.T, fields map[string]string, files map[string][]byte) (*http.Request, string) {
+ b := &bytes.Buffer{}
+ w := multipart.NewWriter(b)
+ for k, v := range fields {
+ _ = w.WriteField(k, v)
+ }
+ for fname, content := range files {
+ fw, err := w.CreateFormFile(fname, fname+".png")
+ require.NoError(t, err)
+ _, err = fw.Write(content)
+ require.NoError(t, err)
+ }
+ require.NoError(t, w.Close())
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", b)
+ req.Header.Set("Content-Type", w.FormDataContentType())
+ return req, w.FormDataContentType()
+}
+
+func TestInpainting_MissingFiles(t *testing.T) {
+ e := echo.New()
+ // handler requires cl, ml, appConfig but this test verifies missing files early
+ h := InpaintingEndpoint(nil, nil, config.NewApplicationConfig())
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := h(c)
+ require.Error(t, err)
+}
+
+func TestInpainting_HappyPath(t *testing.T) {
+ // Setup temp generated content dir
+ tmpDir, err := os.MkdirTemp("", "gencontent")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ appConf := config.NewApplicationConfig(config.WithGeneratedContentDir(tmpDir))
+
+ // stub the backend.ImageGenerationFunc
+ orig := backend.ImageGenerationFunc
+ backend.ImageGenerationFunc = func(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
+ fn := func() error {
+ // write a fake png file to dst
+ return os.WriteFile(dst, []byte("PNGDATA"), 0644)
+ }
+ return fn, nil
+ }
+ defer func() { backend.ImageGenerationFunc = orig }()
+
+ // prepare multipart request with image and mask
+ fields := map[string]string{"model": "dreamshaper-8-inpainting", "prompt": "A test"}
+ files := map[string][]byte{"image": []byte("IMAGEDATA"), "mask": []byte("MASKDATA")}
+ reqBuf, _ := makeMultipartRequest(t, fields, files)
+
+ rec := httptest.NewRecorder()
+ e := echo.New()
+ c := e.NewContext(reqBuf, rec)
+
+ // set a minimal model config in context as handler expects
+ c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG, &config.ModelConfig{Backend: "diffusers"})
+
+ h := InpaintingEndpoint(nil, nil, appConf)
+
+ // call handler
+ err = h(c)
+ require.NoError(t, err)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ // verify response body contains generated-images path
+ body := rec.Body.String()
+ require.Contains(t, body, "generated-images")
+
+ // confirm the file was created in tmpDir
+ // parse out filename from response (naive search)
+ // find "generated-images/" and extract until closing quote or brace
+ idx := bytes.Index(rec.Body.Bytes(), []byte("generated-images/"))
+ require.True(t, idx >= 0)
+ rest := rec.Body.Bytes()[idx:]
+ end := bytes.IndexAny(rest, "\",}\n")
+ if end == -1 {
+ end = len(rest)
+ }
+ fname := string(rest[len("generated-images/"):end])
+ // ensure file exists
+ _, err = os.Stat(filepath.Join(tmpDir, fname))
+ require.NoError(t, err)
+}
diff --git a/core/http/endpoints/openai/list.go b/core/http/endpoints/openai/list.go
new file mode 100644
index 0000000000000000000000000000000000000000..47501dd934f835c60071ed1f67978b3f24b8736e
--- /dev/null
+++ b/core/http/endpoints/openai/list.go
@@ -0,0 +1,50 @@
+package openai
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/core/services"
+ model "github.com/mudler/LocalAI/pkg/model"
+)
+
+// ListModelsEndpoint is the OpenAI Models API endpoint https://platform.openai.com/docs/api-reference/models
+// @Summary List and describe the various models available in the API.
+// @Success 200 {object} schema.ModelsDataResponse "Response"
+// @Router /v1/models [get]
+func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ // If blank, no filter is applied.
+ filter := c.QueryParam("filter")
+
+ // By default, exclude any loose files that are already referenced by a configuration file.
+ var policy services.LooseFilePolicy
+ excludeConfigured := c.QueryParam("excludeConfigured")
+ if excludeConfigured == "" || excludeConfigured == "true" {
+ policy = services.SKIP_IF_CONFIGURED
+ } else {
+ policy = services.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user?
+ }
+
+ filterFn, err := config.BuildNameFilterFn(filter)
+ if err != nil {
+ return err
+ }
+
+ modelNames, err := services.ListModels(bcl, ml, filterFn, policy)
+ if err != nil {
+ return err
+ }
+
+ // Map from a slice of names to a slice of OpenAIModel response objects
+ dataModels := []schema.OpenAIModel{}
+ for _, m := range modelNames {
+ dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
+ }
+
+ return c.JSON(200, schema.ModelsDataResponse{
+ Object: "list",
+ Data: dataModels,
+ })
+ }
+}
diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go
new file mode 100644
index 0000000000000000000000000000000000000000..517fa004520161c5ccf409f716fecd8138b2d958
--- /dev/null
+++ b/core/http/endpoints/openai/realtime.go
@@ -0,0 +1,1307 @@
+package openai
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "os"
+ "strings"
+ "sync"
+ "time"
+
+ "net/http"
+
+ "github.com/go-audio/audio"
+ "github.com/gorilla/websocket"
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/application"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/endpoints/openai/types"
+ "github.com/mudler/LocalAI/core/templates"
+ laudio "github.com/mudler/LocalAI/pkg/audio"
+ "github.com/mudler/LocalAI/pkg/functions"
+ "github.com/mudler/LocalAI/pkg/grpc/proto"
+ model "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/sound"
+
+ "google.golang.org/grpc"
+
+ "github.com/mudler/xlog"
+)
+
+const (
+ localSampleRate = 16000
+ remoteSampleRate = 24000
+ vadModel = "silero-vad-ggml"
+)
+
+// A model can be "emulated" that is: transcribe audio to text -> feed text to the LLM -> generate audio as result
+// If the model support instead audio-to-audio, we will use the specific gRPC calls instead
+
+// Session represents a single WebSocket connection and its state
+type Session struct {
+ ID string
+ TranscriptionOnly bool
+ Model string
+ Voice string
+ TurnDetection *types.ServerTurnDetection `json:"turn_detection"` // "server_vad" or "none"
+ InputAudioTranscription *types.InputAudioTranscription
+ Functions functions.Functions
+ Conversations map[string]*Conversation
+ InputAudioBuffer []byte
+ AudioBufferLock sync.Mutex
+ Instructions string
+ DefaultConversationID string
+ ModelInterface Model
+}
+
+func (s *Session) FromClient(session *types.ClientSession) {
+}
+
+func (s *Session) ToServer() types.ServerSession {
+ return types.ServerSession{
+ ID: s.ID,
+ Object: func() string {
+ if s.TranscriptionOnly {
+ return "realtime.transcription_session"
+ } else {
+ return "realtime.session"
+ }
+ }(),
+ Model: s.Model,
+ Modalities: []types.Modality{types.ModalityText, types.ModalityAudio},
+ Instructions: s.Instructions,
+ Voice: s.Voice,
+ InputAudioFormat: types.AudioFormatPcm16,
+ OutputAudioFormat: types.AudioFormatPcm16,
+ TurnDetection: s.TurnDetection,
+ InputAudioTranscription: s.InputAudioTranscription,
+ // TODO: Should be constructed from Functions?
+ Tools: []types.Tool{},
+ // TODO: ToolChoice
+ // TODO: Temperature
+ // TODO: MaxOutputTokens
+ // TODO: InputAudioNoiseReduction
+ }
+}
+
+// TODO: Update to tools?
+// FunctionCall represents a function call initiated by the model
+type FunctionCall struct {
+ Name string `json:"name"`
+ Arguments map[string]interface{} `json:"arguments"`
+}
+
+// Conversation represents a conversation with a list of items
+type Conversation struct {
+ ID string
+ Items []*types.MessageItem
+ Lock sync.Mutex
+}
+
+func (c *Conversation) ToServer() types.Conversation {
+ return types.Conversation{
+ ID: c.ID,
+ Object: "realtime.conversation",
+ }
+}
+
+// Item represents a message, function_call, or function_call_output
+type Item struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Type string `json:"type"` // "message", "function_call", "function_call_output"
+ Status string `json:"status"`
+ Role string `json:"role"`
+ Content []ConversationContent `json:"content,omitempty"`
+ FunctionCall *FunctionCall `json:"function_call,omitempty"`
+}
+
+// ConversationContent represents the content of an item
+type ConversationContent struct {
+ Type string `json:"type"` // "input_text", "input_audio", "text", "audio", etc.
+ Audio string `json:"audio,omitempty"`
+ Text string `json:"text,omitempty"`
+ // Additional fields as needed
+}
+
+// Define the structures for incoming messages
+type IncomingMessage struct {
+ Type types.ClientEventType `json:"type"`
+ Session json.RawMessage `json:"session,omitempty"`
+ Item json.RawMessage `json:"item,omitempty"`
+ Audio string `json:"audio,omitempty"`
+ Response json.RawMessage `json:"response,omitempty"`
+ Error *ErrorMessage `json:"error,omitempty"`
+ // Other fields as needed
+}
+
+// ErrorMessage represents an error message sent to the client
+type ErrorMessage struct {
+ Type string `json:"type"`
+ Code string `json:"code"`
+ Message string `json:"message"`
+ Param string `json:"param,omitempty"`
+ EventID string `json:"event_id,omitempty"`
+}
+
+// Define a structure for outgoing messages
+type OutgoingMessage struct {
+ Type string `json:"type"`
+ Session *Session `json:"session,omitempty"`
+ Conversation *Conversation `json:"conversation,omitempty"`
+ Item *Item `json:"item,omitempty"`
+ Content string `json:"content,omitempty"`
+ Audio string `json:"audio,omitempty"`
+ Error *ErrorMessage `json:"error,omitempty"`
+}
+
+// Map to store sessions (in-memory)
+var sessions = make(map[string]*Session)
+var sessionLock sync.Mutex
+
+// TODO: implement interface as we start to define usages
+type Model interface {
+ VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error)
+ Transcribe(ctx context.Context, in *proto.TranscriptRequest, opts ...grpc.CallOption) (*proto.TranscriptResult, error)
+ Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error)
+ PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error
+}
+
+var upgrader = websocket.Upgrader{
+ CheckOrigin: func(r *http.Request) bool {
+ return true // Allow all origins
+ },
+}
+
+// TODO: Implement ephemeral keys to allow these endpoints to be used
+func RealtimeSessions(application *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ return c.NoContent(501)
+ }
+}
+
+func RealtimeTranscriptionSession(application *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ return c.NoContent(501)
+ }
+}
+
+func Realtime(application *application.Application) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil)
+ if err != nil {
+ return err
+ }
+ defer ws.Close()
+
+ // Extract query parameters from Echo context before passing to websocket handler
+ model := c.QueryParam("model")
+ if model == "" {
+ model = "gpt-4o"
+ }
+ intent := c.QueryParam("intent")
+
+ registerRealtime(application, model, intent)(ws)
+ return nil
+ }
+}
+
+func registerRealtime(application *application.Application, model, intent string) func(c *websocket.Conn) {
+ return func(c *websocket.Conn) {
+
+ evaluator := application.TemplatesEvaluator()
+ xlog.Debug("WebSocket connection established", "address", c.RemoteAddr().String())
+ if intent != "transcription" {
+ sendNotImplemented(c, "Only transcription mode is supported which requires the intent=transcription parameter")
+ }
+
+ xlog.Debug("Realtime params", "model", model, "intent", intent)
+
+ sessionID := generateSessionID()
+ session := &Session{
+ ID: sessionID,
+ TranscriptionOnly: true,
+ Model: model, // default model
+ Voice: "alloy", // default voice
+ TurnDetection: &types.ServerTurnDetection{
+ Type: types.ServerTurnDetectionTypeServerVad,
+ TurnDetectionParams: types.TurnDetectionParams{
+ // TODO: Need some way to pass this to the backend
+ Threshold: 0.5,
+ // TODO: This is ignored and the amount of padding is random at present
+ PrefixPaddingMs: 30,
+ SilenceDurationMs: 500,
+ CreateResponse: func() *bool { t := true; return &t }(),
+ },
+ },
+ InputAudioTranscription: &types.InputAudioTranscription{
+ Model: "whisper-1",
+ },
+ Conversations: make(map[string]*Conversation),
+ }
+
+ // Create a default conversation
+ conversationID := generateConversationID()
+ conversation := &Conversation{
+ ID: conversationID,
+ Items: []*types.MessageItem{},
+ }
+ session.Conversations[conversationID] = conversation
+ session.DefaultConversationID = conversationID
+
+ // TODO: The API has no way to configure the VAD model or other models that make up a pipeline to fake any-to-any
+ // So possibly we could have a way to configure a composite model that can be used in situations where any-to-any is expected
+ pipeline := config.Pipeline{
+ VAD: vadModel,
+ Transcription: session.InputAudioTranscription.Model,
+ }
+
+ m, cfg, err := newTranscriptionOnlyModel(
+ &pipeline,
+ application.ModelConfigLoader(),
+ application.ModelLoader(),
+ application.ApplicationConfig(),
+ )
+ if err != nil {
+ xlog.Error("failed to load model", "error", err)
+ sendError(c, "model_load_error", "Failed to load model", "", "")
+ return
+ }
+ session.ModelInterface = m
+
+ // Store the session
+ sessionLock.Lock()
+ sessions[sessionID] = session
+ sessionLock.Unlock()
+
+ sendEvent(c, types.TranscriptionSessionCreatedEvent{
+ ServerEventBase: types.ServerEventBase{
+ EventID: "event_TODO",
+ Type: types.ServerEventTypeTranscriptionSessionCreated,
+ },
+ Session: session.ToServer(),
+ })
+
+ var (
+ // mt int
+ msg []byte
+ wg sync.WaitGroup
+ done = make(chan struct{})
+ )
+
+ vadServerStarted := true
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ conversation := session.Conversations[session.DefaultConversationID]
+ handleVAD(cfg, evaluator, session, conversation, c, done)
+ }()
+
+ for {
+ if _, msg, err = c.ReadMessage(); err != nil {
+ xlog.Error("read error", "error", err)
+ break
+ }
+
+ // Parse the incoming message
+ var incomingMsg IncomingMessage
+ if err := json.Unmarshal(msg, &incomingMsg); err != nil {
+ xlog.Error("invalid json", "error", err)
+ sendError(c, "invalid_json", "Invalid JSON format", "", "")
+ continue
+ }
+
+ var sessionUpdate types.ClientSession
+ switch incomingMsg.Type {
+ case types.ClientEventTypeTranscriptionSessionUpdate:
+ xlog.Debug("recv", "message", string(msg))
+
+ if err := json.Unmarshal(incomingMsg.Session, &sessionUpdate); err != nil {
+ xlog.Error("failed to unmarshal 'transcription_session.update'", "error", err)
+ sendError(c, "invalid_session_update", "Invalid session update format", "", "")
+ continue
+ }
+ if err := updateTransSession(
+ session,
+ &sessionUpdate,
+ application.ModelConfigLoader(),
+ application.ModelLoader(),
+ application.ApplicationConfig(),
+ ); err != nil {
+ xlog.Error("failed to update session", "error", err)
+ sendError(c, "session_update_error", "Failed to update session", "", "")
+ continue
+ }
+
+ sendEvent(c, types.SessionUpdatedEvent{
+ ServerEventBase: types.ServerEventBase{
+ EventID: "event_TODO",
+ Type: types.ServerEventTypeTranscriptionSessionUpdated,
+ },
+ Session: session.ToServer(),
+ })
+
+ case types.ClientEventTypeSessionUpdate:
+ xlog.Debug("recv", "message", string(msg))
+
+ // Update session configurations
+ if err := json.Unmarshal(incomingMsg.Session, &sessionUpdate); err != nil {
+ xlog.Error("failed to unmarshal 'session.update'", "error", err)
+ sendError(c, "invalid_session_update", "Invalid session update format", "", "")
+ continue
+ }
+ if err := updateSession(
+ session,
+ &sessionUpdate,
+ application.ModelConfigLoader(),
+ application.ModelLoader(),
+ application.ApplicationConfig(),
+ ); err != nil {
+ xlog.Error("failed to update session", "error", err)
+ sendError(c, "session_update_error", "Failed to update session", "", "")
+ continue
+ }
+
+ sendEvent(c, types.SessionUpdatedEvent{
+ ServerEventBase: types.ServerEventBase{
+ EventID: "event_TODO",
+ Type: types.ServerEventTypeSessionUpdated,
+ },
+ Session: session.ToServer(),
+ })
+
+ if session.TurnDetection.Type == types.ServerTurnDetectionTypeServerVad && !vadServerStarted {
+ xlog.Debug("Starting VAD goroutine...")
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ conversation := session.Conversations[session.DefaultConversationID]
+ handleVAD(cfg, evaluator, session, conversation, c, done)
+ }()
+ vadServerStarted = true
+ } else if session.TurnDetection.Type != types.ServerTurnDetectionTypeServerVad && vadServerStarted {
+ xlog.Debug("Stopping VAD goroutine...")
+
+ wg.Add(-1)
+ go func() {
+ done <- struct{}{}
+ }()
+ vadServerStarted = false
+ }
+ case types.ClientEventTypeInputAudioBufferAppend:
+ // Handle 'input_audio_buffer.append'
+ if incomingMsg.Audio == "" {
+ xlog.Error("Audio data is missing in 'input_audio_buffer.append'")
+ sendError(c, "missing_audio_data", "Audio data is missing", "", "")
+ continue
+ }
+
+ // Decode base64 audio data
+ decodedAudio, err := base64.StdEncoding.DecodeString(incomingMsg.Audio)
+ if err != nil {
+ xlog.Error("failed to decode audio data", "error", err)
+ sendError(c, "invalid_audio_data", "Failed to decode audio data", "", "")
+ continue
+ }
+
+ // Append to InputAudioBuffer
+ session.AudioBufferLock.Lock()
+ session.InputAudioBuffer = append(session.InputAudioBuffer, decodedAudio...)
+ session.AudioBufferLock.Unlock()
+
+ case types.ClientEventTypeInputAudioBufferCommit:
+ xlog.Debug("recv", "message", string(msg))
+
+ // TODO: Trigger transcription.
+ // TODO: Ignore this if VAD enabled or interrupt VAD?
+
+ if session.TranscriptionOnly {
+ continue
+ }
+
+ // Commit the audio buffer to the conversation as a new item
+ item := &types.MessageItem{
+ ID: generateItemID(),
+ Type: "message",
+ Status: "completed",
+ Role: "user",
+ Content: []types.MessageContentPart{
+ {
+ Type: "input_audio",
+ Audio: base64.StdEncoding.EncodeToString(session.InputAudioBuffer),
+ },
+ },
+ }
+
+ // Add item to conversation
+ conversation.Lock.Lock()
+ conversation.Items = append(conversation.Items, item)
+ conversation.Lock.Unlock()
+
+ // Reset InputAudioBuffer
+ session.AudioBufferLock.Lock()
+ session.InputAudioBuffer = nil
+ session.AudioBufferLock.Unlock()
+
+ // Send item.created event
+ sendEvent(c, types.ConversationItemCreatedEvent{
+ ServerEventBase: types.ServerEventBase{
+ EventID: "event_TODO",
+ Type: "conversation.item.created",
+ },
+ Item: types.ResponseMessageItem{
+ Object: "realtime.item",
+ MessageItem: *item,
+ },
+ })
+
+ case types.ClientEventTypeConversationItemCreate:
+ xlog.Debug("recv", "message", string(msg))
+
+ // Handle creating new conversation items
+ var item types.ConversationItemCreateEvent
+ if err := json.Unmarshal(incomingMsg.Item, &item); err != nil {
+ xlog.Error("failed to unmarshal 'conversation.item.create'", "error", err)
+ sendError(c, "invalid_item", "Invalid item format", "", "")
+ continue
+ }
+
+ sendNotImplemented(c, "conversation.item.create")
+
+ // Generate item ID and set status
+ // item.ID = generateItemID()
+ // item.Object = "realtime.item"
+ // item.Status = "completed"
+ //
+ // // Add item to conversation
+ // conversation.Lock.Lock()
+ // conversation.Items = append(conversation.Items, &item)
+ // conversation.Lock.Unlock()
+ //
+ // // Send item.created event
+ // sendEvent(c, OutgoingMessage{
+ // Type: "conversation.item.created",
+ // Item: &item,
+ // })
+
+ case types.ClientEventTypeConversationItemDelete:
+ sendError(c, "not_implemented", "Deleting items not implemented", "", "event_TODO")
+
+ case types.ClientEventTypeResponseCreate:
+ // Handle generating a response
+ var responseCreate types.ResponseCreateEvent
+ if len(incomingMsg.Response) > 0 {
+ if err := json.Unmarshal(incomingMsg.Response, &responseCreate); err != nil {
+ xlog.Error("failed to unmarshal 'response.create' response object", "error", err)
+ sendError(c, "invalid_response_create", "Invalid response create format", "", "")
+ continue
+ }
+ }
+
+ // Update session functions if provided
+ if len(responseCreate.Response.Tools) > 0 {
+ // TODO: Tools -> Functions
+ }
+
+ sendNotImplemented(c, "response.create")
+
+ // TODO: Generate a response based on the conversation history
+ // wg.Add(1)
+ // go func() {
+ // defer wg.Done()
+ // generateResponse(cfg, evaluator, session, conversation, responseCreate, c, mt)
+ // }()
+
+ case types.ClientEventTypeResponseCancel:
+ xlog.Debug("recv", "message", string(msg))
+
+ // Handle cancellation of ongoing responses
+ // Implement cancellation logic as needed
+ sendNotImplemented(c, "response.cancel")
+
+ default:
+ xlog.Error("unknown message type", "type", incomingMsg.Type)
+ sendError(c, "unknown_message_type", fmt.Sprintf("Unknown message type: %s", incomingMsg.Type), "", "")
+ }
+ }
+
+ // Close the done channel to signal goroutines to exit
+ close(done)
+ wg.Wait()
+
+ // Remove the session from the sessions map
+ sessionLock.Lock()
+ delete(sessions, sessionID)
+ sessionLock.Unlock()
+ }
+}
+
+// Helper function to send events to the client
+func sendEvent(c *websocket.Conn, event types.ServerEvent) {
+ eventBytes, err := json.Marshal(event)
+ if err != nil {
+ xlog.Error("failed to marshal event", "error", err)
+ return
+ }
+ if err = c.WriteMessage(websocket.TextMessage, eventBytes); err != nil {
+ xlog.Error("write error", "error", err)
+ }
+}
+
+// Helper function to send errors to the client
+func sendError(c *websocket.Conn, code, message, param, eventID string) {
+ errorEvent := types.ErrorEvent{
+ ServerEventBase: types.ServerEventBase{
+ Type: types.ServerEventTypeError,
+ EventID: eventID,
+ },
+ Error: types.Error{
+ Type: "invalid_request_error",
+ Code: code,
+ Message: message,
+ EventID: eventID,
+ },
+ }
+
+ sendEvent(c, errorEvent)
+}
+
+func sendNotImplemented(c *websocket.Conn, message string) {
+ sendError(c, "not_implemented", message, "", "event_TODO")
+}
+
+func updateTransSession(session *Session, update *types.ClientSession, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
+ sessionLock.Lock()
+ defer sessionLock.Unlock()
+
+ trUpd := update.InputAudioTranscription
+ trCur := session.InputAudioTranscription
+
+ if trUpd != nil && trUpd.Model != "" && trUpd.Model != trCur.Model {
+ pipeline := config.Pipeline{
+ VAD: vadModel,
+ Transcription: trUpd.Model,
+ }
+
+ m, _, err := newTranscriptionOnlyModel(&pipeline, cl, ml, appConfig)
+ if err != nil {
+ return err
+ }
+
+ session.ModelInterface = m
+ }
+
+ if trUpd != nil {
+ trCur.Language = trUpd.Language
+ trCur.Prompt = trUpd.Prompt
+ }
+
+ if update.TurnDetection != nil && update.TurnDetection.Type != "" {
+ session.TurnDetection.Type = types.ServerTurnDetectionType(update.TurnDetection.Type)
+ session.TurnDetection.TurnDetectionParams = update.TurnDetection.TurnDetectionParams
+ }
+
+ return nil
+}
+
+// Function to update session configurations
+func updateSession(session *Session, update *types.ClientSession, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
+ sessionLock.Lock()
+ defer sessionLock.Unlock()
+
+ if update.Model != "" {
+ pipeline := config.Pipeline{
+ LLM: update.Model,
+ // TODO: Setup pipeline by configuring STT and TTS models
+ }
+ m, err := newModel(&pipeline, cl, ml, appConfig)
+ if err != nil {
+ return err
+ }
+ session.ModelInterface = m
+ session.Model = update.Model
+ }
+
+ if update.Voice != "" {
+ session.Voice = update.Voice
+ }
+ if update.TurnDetection != nil && update.TurnDetection.Type != "" {
+ session.TurnDetection.Type = types.ServerTurnDetectionType(update.TurnDetection.Type)
+ session.TurnDetection.TurnDetectionParams = update.TurnDetection.TurnDetectionParams
+ }
+ // TODO: We should actually check if the field was present in the JSON; empty string means clear the settings
+ if update.Instructions != "" {
+ session.Instructions = update.Instructions
+ }
+ if update.Tools != nil {
+ return fmt.Errorf("Haven't implemented tools")
+ }
+
+ session.InputAudioTranscription = update.InputAudioTranscription
+
+ return nil
+}
+
+// handleVAD is a goroutine that listens for audio data from the client,
+// runs VAD on the audio data, and commits utterances to the conversation
+func handleVAD(cfg *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
+ vadContext, cancel := context.WithCancel(context.Background())
+ go func() {
+ <-done
+ cancel()
+ }()
+
+ silenceThreshold := float64(session.TurnDetection.SilenceDurationMs) / 1000
+ speechStarted := false
+ startTime := time.Now()
+
+ ticker := time.NewTicker(300 * time.Millisecond)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-done:
+ return
+ case <-ticker.C:
+ session.AudioBufferLock.Lock()
+ allAudio := make([]byte, len(session.InputAudioBuffer))
+ copy(allAudio, session.InputAudioBuffer)
+ session.AudioBufferLock.Unlock()
+
+ aints := sound.BytesToInt16sLE(allAudio)
+ if len(aints) == 0 || len(aints) < int(silenceThreshold)*remoteSampleRate {
+ continue
+ }
+
+ // Resample from 24kHz to 16kHz
+ aints = sound.ResampleInt16(aints, remoteSampleRate, localSampleRate)
+
+ segments, err := runVAD(vadContext, session, aints)
+ if err != nil {
+ if err.Error() == "unexpected speech end" {
+ xlog.Debug("VAD cancelled")
+ continue
+ }
+ xlog.Error("failed to process audio", "error", err)
+ sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
+ continue
+ }
+
+ audioLength := float64(len(aints)) / localSampleRate
+
+ // TODO: When resetting the buffer we should retain a small postfix
+ // TODO: The OpenAI documentation seems to suggest that only the client decides when to clear the buffer
+ if len(segments) == 0 && audioLength > silenceThreshold {
+ session.AudioBufferLock.Lock()
+ session.InputAudioBuffer = nil
+ session.AudioBufferLock.Unlock()
+ xlog.Debug("Detected silence for a while, clearing audio buffer")
+
+ sendEvent(c, types.InputAudioBufferClearedEvent{
+ ServerEventBase: types.ServerEventBase{
+ EventID: "event_TODO",
+ Type: types.ServerEventTypeInputAudioBufferCleared,
+ },
+ })
+
+ continue
+ } else if len(segments) == 0 {
+ continue
+ }
+
+ if !speechStarted {
+ sendEvent(c, types.InputAudioBufferSpeechStartedEvent{
+ ServerEventBase: types.ServerEventBase{
+ EventID: "event_TODO",
+ Type: types.ServerEventTypeInputAudioBufferSpeechStarted,
+ },
+ AudioStartMs: time.Now().Sub(startTime).Milliseconds(),
+ })
+ speechStarted = true
+ }
+
+ // Segment still in progress when audio ended
+ segEndTime := segments[len(segments)-1].GetEnd()
+ if segEndTime == 0 {
+ continue
+ }
+
+ if float32(audioLength)-segEndTime > float32(silenceThreshold) {
+ xlog.Debug("Detected end of speech segment")
+ session.AudioBufferLock.Lock()
+ session.InputAudioBuffer = nil
+ session.AudioBufferLock.Unlock()
+
+ sendEvent(c, types.InputAudioBufferSpeechStoppedEvent{
+ ServerEventBase: types.ServerEventBase{
+ EventID: "event_TODO",
+ Type: types.ServerEventTypeInputAudioBufferSpeechStopped,
+ },
+ AudioEndMs: time.Now().Sub(startTime).Milliseconds(),
+ })
+ speechStarted = false
+
+ sendEvent(c, types.InputAudioBufferCommittedEvent{
+ ServerEventBase: types.ServerEventBase{
+ EventID: "event_TODO",
+ Type: types.ServerEventTypeInputAudioBufferCommitted,
+ },
+ ItemID: generateItemID(),
+ PreviousItemID: "TODO",
+ })
+
+ abytes := sound.Int16toBytesLE(aints)
+ // TODO: Remove prefix silence that is is over TurnDetectionParams.PrefixPaddingMs
+ go commitUtterance(vadContext, abytes, cfg, evaluator, session, conv, c)
+ }
+ }
+ }
+}
+
+func commitUtterance(ctx context.Context, utt []byte, cfg *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) {
+ if len(utt) == 0 {
+ return
+ }
+
+ // TODO: If we have a real any-to-any model then transcription is optional
+
+ f, err := os.CreateTemp("", "realtime-audio-chunk-*.wav")
+ if err != nil {
+ xlog.Error("failed to create temp file", "error", err)
+ return
+ }
+ defer f.Close()
+ defer os.Remove(f.Name())
+ xlog.Debug("Writing to file", "file", f.Name())
+
+ hdr := laudio.NewWAVHeader(uint32(len(utt)))
+ if err := hdr.Write(f); err != nil {
+ xlog.Error("Failed to write WAV header", "error", err)
+ return
+ }
+
+ if _, err := f.Write(utt); err != nil {
+ xlog.Error("Failed to write audio data", "error", err)
+ return
+ }
+
+ f.Sync()
+
+ if session.InputAudioTranscription != nil {
+ tr, err := session.ModelInterface.Transcribe(ctx, &proto.TranscriptRequest{
+ Dst: f.Name(),
+ Language: session.InputAudioTranscription.Language,
+ Translate: false,
+ Threads: uint32(*cfg.Threads),
+ Prompt: session.InputAudioTranscription.Prompt,
+ })
+ if err != nil {
+ sendError(c, "transcription_failed", err.Error(), "", "event_TODO")
+ }
+
+ sendEvent(c, types.ResponseAudioTranscriptDoneEvent{
+ ServerEventBase: types.ServerEventBase{
+ Type: types.ServerEventTypeResponseAudioTranscriptDone,
+ EventID: "event_TODO",
+ },
+
+ ItemID: generateItemID(),
+ ResponseID: "resp_TODO",
+ OutputIndex: 0,
+ ContentIndex: 0,
+ Transcript: tr.GetText(),
+ })
+ // TODO: Update the prompt with transcription result?
+ }
+
+ if !session.TranscriptionOnly {
+ sendNotImplemented(c, "Commiting items to the conversation not implemented")
+ }
+
+ // TODO: Commit the audio and/or transcribed text to the conversation
+ // Commit logic: create item, broadcast item.created, etc.
+ // item := &Item{
+ // ID: generateItemID(),
+ // Object: "realtime.item",
+ // Type: "message",
+ // Status: "completed",
+ // Role: "user",
+ // Content: []ConversationContent{
+ // {
+ // Type: "input_audio",
+ // Audio: base64.StdEncoding.EncodeToString(utt),
+ // },
+ // },
+ // }
+ // conv.Lock.Lock()
+ // conv.Items = append(conv.Items, item)
+ // conv.Lock.Unlock()
+ //
+ //
+ // sendEvent(c, OutgoingMessage{
+ // Type: "conversation.item.created",
+ // Item: item,
+ // })
+ //
+ //
+ // // trigger the response generation
+ // generateResponse(cfg, evaluator, session, conv, ResponseCreate{}, c, websocket.TextMessage)
+}
+
+func runVAD(ctx context.Context, session *Session, adata []int16) ([]*proto.VADSegment, error) {
+ soundIntBuffer := &audio.IntBuffer{
+ Format: &audio.Format{SampleRate: localSampleRate, NumChannels: 1},
+ SourceBitDepth: 16,
+ Data: sound.ConvertInt16ToInt(adata),
+ }
+
+ float32Data := soundIntBuffer.AsFloat32Buffer().Data
+
+ resp, err := session.ModelInterface.VAD(ctx, &proto.VADRequest{
+ Audio: float32Data,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // If resp.Segments is empty => no speech
+ return resp.Segments, nil
+}
+
+// TODO: Below needed for normal mode instead of transcription only
+// Function to generate a response based on the conversation
+// func generateResponse(config *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) {
+//
+// log.Debug().Msg("Generating realtime response...")
+//
+// // Compile the conversation history
+// conversation.Lock.Lock()
+// var conversationHistory []schema.Message
+// var latestUserAudio string
+// for _, item := range conversation.Items {
+// for _, content := range item.Content {
+// switch content.Type {
+// case "input_text", "text":
+// conversationHistory = append(conversationHistory, schema.Message{
+// Role: string(item.Role),
+// StringContent: content.Text,
+// Content: content.Text,
+// })
+// case "input_audio":
+// // We do not to turn to text here the audio result.
+// // When generating it later on from the LLM,
+// // we will also generate text and return it and store it in the conversation
+// // Here we just want to get the user audio if there is any as a new input for the conversation.
+// if item.Role == "user" {
+// latestUserAudio = content.Audio
+// }
+// }
+// }
+// }
+//
+// conversation.Lock.Unlock()
+//
+// var generatedText string
+// var generatedAudio []byte
+// var functionCall *FunctionCall
+// var err error
+//
+// if latestUserAudio != "" {
+// // Process the latest user audio input
+// decodedAudio, err := base64.StdEncoding.DecodeString(latestUserAudio)
+// if err != nil {
+// log.Error().Msgf("failed to decode latest user audio: %s", err.Error())
+// sendError(c, "invalid_audio_data", "Failed to decode audio data", "", "")
+// return
+// }
+//
+// // Process the audio input and generate a response
+// generatedText, generatedAudio, functionCall, err = processAudioResponse(session, decodedAudio)
+// if err != nil {
+// log.Error().Msgf("failed to process audio response: %s", err.Error())
+// sendError(c, "processing_error", "Failed to generate audio response", "", "")
+// return
+// }
+// } else {
+//
+// if session.Instructions != "" {
+// conversationHistory = append([]schema.Message{{
+// Role: "system",
+// StringContent: session.Instructions,
+// Content: session.Instructions,
+// }}, conversationHistory...)
+// }
+//
+// funcs := session.Functions
+// shouldUseFn := len(funcs) > 0 && config.ShouldUseFunctions()
+//
+// // Allow the user to set custom actions via config file
+// // to be "embedded" in each model
+// noActionName := "answer"
+// noActionDescription := "use this action to answer without performing any action"
+//
+// if config.FunctionsConfig.NoActionFunctionName != "" {
+// noActionName = config.FunctionsConfig.NoActionFunctionName
+// }
+// if config.FunctionsConfig.NoActionDescriptionName != "" {
+// noActionDescription = config.FunctionsConfig.NoActionDescriptionName
+// }
+//
+// if (!config.FunctionsConfig.GrammarConfig.NoGrammar) && shouldUseFn {
+// noActionGrammar := functions.Function{
+// Name: noActionName,
+// Description: noActionDescription,
+// Parameters: map[string]interface{}{
+// "properties": map[string]interface{}{
+// "message": map[string]interface{}{
+// "type": "string",
+// "description": "The message to reply the user with",
+// }},
+// },
+// }
+//
+// // Append the no action function
+// if !config.FunctionsConfig.DisableNoAction {
+// funcs = append(funcs, noActionGrammar)
+// }
+//
+// // Update input grammar
+// jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey)
+// g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarOptions()...)
+// if err == nil {
+// config.Grammar = g
+// }
+// }
+//
+// // Generate a response based on text conversation history
+// prompt := evaluator.TemplateMessages(conversationHistory, config, funcs, shouldUseFn)
+//
+// generatedText, functionCall, err = processTextResponse(config, session, prompt)
+// if err != nil {
+// log.Error().Msgf("failed to process text response: %s", err.Error())
+// sendError(c, "processing_error", "Failed to generate text response", "", "")
+// return
+// }
+// log.Debug().Any("text", generatedText).Msg("Generated text response")
+// }
+//
+// if functionCall != nil {
+// // The model wants to call a function
+// // Create a function_call item and send it to the client
+// item := &Item{
+// ID: generateItemID(),
+// Object: "realtime.item",
+// Type: "function_call",
+// Status: "completed",
+// Role: "assistant",
+// FunctionCall: functionCall,
+// }
+//
+// // Add item to conversation
+// conversation.Lock.Lock()
+// conversation.Items = append(conversation.Items, item)
+// conversation.Lock.Unlock()
+//
+// // Send item.created event
+// sendEvent(c, OutgoingMessage{
+// Type: "conversation.item.created",
+// Item: item,
+// })
+//
+// // Optionally, you can generate a message to the user indicating the function call
+// // For now, we'll assume the client handles the function call and may trigger another response
+//
+// } else {
+// // Send response.stream messages
+// if generatedAudio != nil {
+// // If generatedAudio is available, send it as audio
+// encodedAudio := base64.StdEncoding.EncodeToString(generatedAudio)
+// outgoingMsg := OutgoingMessage{
+// Type: "response.stream",
+// Audio: encodedAudio,
+// }
+// sendEvent(c, outgoingMsg)
+// } else {
+// // Send text response (could be streamed in chunks)
+// chunks := splitResponseIntoChunks(generatedText)
+// for _, chunk := range chunks {
+// outgoingMsg := OutgoingMessage{
+// Type: "response.stream",
+// Content: chunk,
+// }
+// sendEvent(c, outgoingMsg)
+// }
+// }
+//
+// // Send response.done message
+// sendEvent(c, OutgoingMessage{
+// Type: "response.done",
+// })
+//
+// // Add the assistant's response to the conversation
+// content := []ConversationContent{}
+// if generatedAudio != nil {
+// content = append(content, ConversationContent{
+// Type: "audio",
+// Audio: base64.StdEncoding.EncodeToString(generatedAudio),
+// })
+// // Optionally include a text transcript
+// if generatedText != "" {
+// content = append(content, ConversationContent{
+// Type: "text",
+// Text: generatedText,
+// })
+// }
+// } else {
+// content = append(content, ConversationContent{
+// Type: "text",
+// Text: generatedText,
+// })
+// }
+//
+// item := &Item{
+// ID: generateItemID(),
+// Object: "realtime.item",
+// Type: "message",
+// Status: "completed",
+// Role: "assistant",
+// Content: content,
+// }
+//
+// // Add item to conversation
+// conversation.Lock.Lock()
+// conversation.Items = append(conversation.Items, item)
+// conversation.Lock.Unlock()
+//
+// // Send item.created event
+// sendEvent(c, OutgoingMessage{
+// Type: "conversation.item.created",
+// Item: item,
+// })
+//
+// log.Debug().Any("item", item).Msg("Realtime response sent")
+// }
+// }
+
+// Function to process text response and detect function calls
+func processTextResponse(config *config.ModelConfig, session *Session, prompt string) (string, *FunctionCall, error) {
+
+ // Placeholder implementation
+ // Replace this with actual model inference logic using session.Model and prompt
+ // For example, the model might return a special token or JSON indicating a function call
+
+ /*
+ predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil, "", "", nil, nil, nil)
+
+ result, tokenUsage, err := ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
+ if !shouldUseFn {
+ // no function is called, just reply and use stop as finish reason
+ stopReason := FinishReasonStop
+ *c = append(*c, schema.Choice{FinishReason: &stopReason, Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
+ return
+ }
+
+ textContentToReturn = functions.ParseTextContent(s, config.FunctionsConfig)
+ s = functions.CleanupLLMResult(s, config.FunctionsConfig)
+ results := functions.ParseFunctionCall(s, config.FunctionsConfig)
+ xlog.Debug("Text content to return", "text", textContentToReturn)
+ noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
+
+ switch {
+ case noActionsToRun:
+ result, err := handleQuestion(config, input, ml, startupOptions, results, s, predInput)
+ if err != nil {
+ xlog.Error("error handling question", "error", err)
+ return
+ }
+ *c = append(*c, schema.Choice{
+ Message: &schema.Message{Role: "assistant", Content: &result}})
+ default:
+ toolChoice := schema.Choice{
+ Message: &schema.Message{
+ Role: "assistant",
+ },
+ }
+
+ if len(input.Tools) > 0 {
+ toolCallsReason := FinishReasonToolCalls
+ toolChoice.FinishReason = &toolCallsReason
+ }
+
+ for _, ss := range results {
+ name, args := ss.Name, ss.Arguments
+ if len(input.Tools) > 0 {
+ // If we are using tools, we condense the function calls into
+ // a single response choice with all the tools
+ toolChoice.Message.Content = textContentToReturn
+ toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
+ schema.ToolCall{
+ ID: id,
+ Type: "function",
+ FunctionCall: schema.FunctionCall{
+ Name: name,
+ Arguments: args,
+ },
+ },
+ )
+ } else {
+ // otherwise we return more choices directly
+ functionCallReason := FinishReasonFunctionCall
+ *c = append(*c, schema.Choice{
+ FinishReason: &functionCallReason,
+ Message: &schema.Message{
+ Role: "assistant",
+ Content: &textContentToReturn,
+ FunctionCall: map[string]interface{}{
+ "name": name,
+ "arguments": args,
+ },
+ },
+ })
+ }
+ }
+
+ if len(input.Tools) > 0 {
+ // we need to append our result if we are using tools
+ *c = append(*c, toolChoice)
+ }
+ }
+
+ }, nil)
+ if err != nil {
+ return err
+ }
+
+ resp := &schema.OpenAIResponse{
+ ID: id,
+ Created: created,
+ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
+ Choices: result,
+ Object: "chat.completion",
+ Usage: schema.OpenAIUsage{
+ PromptTokens: tokenUsage.Prompt,
+ CompletionTokens: tokenUsage.Completion,
+ TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
+ },
+ }
+ respData, _ := json.Marshal(resp)
+ xlog.Debug("Response", "response", string(respData))
+
+ // Return the prediction in the response body
+ return c.JSON(resp)
+
+ */
+
+ // TODO: use session.ModelInterface...
+ // Simulate a function call
+ if strings.Contains(prompt, "weather") {
+ functionCall := &FunctionCall{
+ Name: "get_weather",
+ Arguments: map[string]interface{}{
+ "location": "New York",
+ "scale": "celsius",
+ },
+ }
+ return "", functionCall, nil
+ }
+
+ // Otherwise, return a normal text response
+ return "This is a generated response based on the conversation.", nil, nil
+}
+
+// Function to process audio response and detect function calls
+func processAudioResponse(session *Session, audioData []byte) (string, []byte, *FunctionCall, error) {
+ // TODO: Do the below or use an any-to-any model like Qwen Omni
+ // Implement the actual model inference logic using session.Model and audioData
+ // For example:
+ // 1. Transcribe the audio to text
+ // 2. Generate a response based on the transcribed text
+ // 3. Check if the model wants to call a function
+ // 4. Convert the response text to speech (audio)
+ //
+ // Placeholder implementation:
+
+ // TODO: template eventual messages, like chat.go
+ reply, err := session.ModelInterface.Predict(context.Background(), &proto.PredictOptions{
+ Prompt: "What's the weather in New York?",
+ })
+
+ if err != nil {
+ return "", nil, nil, err
+ }
+
+ generatedAudio := reply.Audio
+
+ transcribedText := "What's the weather in New York?"
+ var functionCall *FunctionCall
+
+ // Simulate a function call
+ if strings.Contains(transcribedText, "weather") {
+ functionCall = &FunctionCall{
+ Name: "get_weather",
+ Arguments: map[string]interface{}{
+ "location": "New York",
+ "scale": "celsius",
+ },
+ }
+ return "", nil, functionCall, nil
+ }
+
+ // Generate a response
+ generatedText := "This is a response to your speech input."
+
+ return generatedText, generatedAudio, nil, nil
+}
+
+// Function to split the response into chunks (for streaming)
+func splitResponseIntoChunks(response string) []string {
+ // Split the response into chunks of fixed size
+ chunkSize := 50 // characters per chunk
+ var chunks []string
+ for len(response) > 0 {
+ if len(response) > chunkSize {
+ chunks = append(chunks, response[:chunkSize])
+ response = response[chunkSize:]
+ } else {
+ chunks = append(chunks, response)
+ break
+ }
+ }
+ return chunks
+}
+
+// Helper functions to generate unique IDs
+func generateSessionID() string {
+ // Generate a unique session ID
+ // Implement as needed
+ return "sess_" + generateUniqueID()
+}
+
+func generateConversationID() string {
+ // Generate a unique conversation ID
+ // Implement as needed
+ return "conv_" + generateUniqueID()
+}
+
+func generateItemID() string {
+ // Generate a unique item ID
+ // Implement as needed
+ return "item_" + generateUniqueID()
+}
+
+func generateUniqueID() string {
+ // Generate a unique ID string
+ // For simplicity, use a counter or UUID
+ // Implement as needed
+ return "unique_id"
+}
+
+// Structures for 'response.create' messages
+type ResponseCreate struct {
+ Modalities []string `json:"modalities,omitempty"`
+ Instructions string `json:"instructions,omitempty"`
+ Functions functions.Functions `json:"functions,omitempty"`
+ // Other fields as needed
+}
diff --git a/core/http/endpoints/openai/realtime_model.go b/core/http/endpoints/openai/realtime_model.go
new file mode 100644
index 0000000000000000000000000000000000000000..ac52627a8995d3e7cd28f857088404c844e31128
--- /dev/null
+++ b/core/http/endpoints/openai/realtime_model.go
@@ -0,0 +1,258 @@
+package openai
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ grpcClient "github.com/mudler/LocalAI/pkg/grpc"
+ "github.com/mudler/LocalAI/pkg/grpc/proto"
+ model "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/xlog"
+ "google.golang.org/grpc"
+)
+
+var (
+ _ Model = new(wrappedModel)
+ _ Model = new(anyToAnyModel)
+)
+
+// wrappedModel represent a model which does not support Any-to-Any operations
+// This means that we will fake an Any-to-Any model by overriding some of the gRPC client methods
+// which are for Any-To-Any models, but instead we will call a pipeline (for e.g STT->LLM->TTS)
+type wrappedModel struct {
+ TTSConfig *config.ModelConfig
+ TranscriptionConfig *config.ModelConfig
+ LLMConfig *config.ModelConfig
+ TTSClient grpcClient.Backend
+ TranscriptionClient grpcClient.Backend
+ LLMClient grpcClient.Backend
+
+ VADConfig *config.ModelConfig
+ VADClient grpcClient.Backend
+}
+
+// anyToAnyModel represent a model which supports Any-to-Any operations
+// We have to wrap this out as well because we want to load two models one for VAD and one for the actual model.
+// In the future there could be models that accept continous audio input only so this design will be useful for that
+type anyToAnyModel struct {
+ LLMConfig *config.ModelConfig
+ LLMClient grpcClient.Backend
+
+ VADConfig *config.ModelConfig
+ VADClient grpcClient.Backend
+}
+
+type transcriptOnlyModel struct {
+ TranscriptionConfig *config.ModelConfig
+ TranscriptionClient grpcClient.Backend
+ VADConfig *config.ModelConfig
+ VADClient grpcClient.Backend
+}
+
+func (m *transcriptOnlyModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) {
+ return m.VADClient.VAD(ctx, in)
+}
+
+func (m *transcriptOnlyModel) Transcribe(ctx context.Context, in *proto.TranscriptRequest, opts ...grpc.CallOption) (*proto.TranscriptResult, error) {
+ return m.TranscriptionClient.AudioTranscription(ctx, in, opts...)
+}
+
+func (m *transcriptOnlyModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) {
+ return nil, fmt.Errorf("predict operation not supported in transcript-only mode")
+}
+
+func (m *transcriptOnlyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error {
+ return fmt.Errorf("predict stream operation not supported in transcript-only mode")
+}
+
+func (m *wrappedModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) {
+ return m.VADClient.VAD(ctx, in)
+}
+
+func (m *anyToAnyModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) {
+ return m.VADClient.VAD(ctx, in)
+}
+
+func (m *wrappedModel) Transcribe(ctx context.Context, in *proto.TranscriptRequest, opts ...grpc.CallOption) (*proto.TranscriptResult, error) {
+ return m.TranscriptionClient.AudioTranscription(ctx, in, opts...)
+}
+
+func (m *anyToAnyModel) Transcribe(ctx context.Context, in *proto.TranscriptRequest, opts ...grpc.CallOption) (*proto.TranscriptResult, error) {
+ // TODO: Can any-to-any models transcribe?
+ return m.LLMClient.AudioTranscription(ctx, in, opts...)
+}
+
+func (m *wrappedModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) {
+ // TODO: Convert with pipeline (audio to text, text to llm, result to tts, and return it)
+ // sound.BufferAsWAV(audioData, "audio.wav")
+
+ return m.LLMClient.Predict(ctx, in)
+}
+
+func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error {
+ // TODO: Convert with pipeline (audio to text, text to llm, result to tts, and return it)
+
+ return m.LLMClient.PredictStream(ctx, in, f)
+}
+
+func (m *anyToAnyModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) {
+ return m.LLMClient.Predict(ctx, in)
+}
+
+func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error {
+ return m.LLMClient.PredictStream(ctx, in, f)
+}
+
+func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) {
+ cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath)
+ if err != nil {
+
+ return nil, nil, fmt.Errorf("failed to load backend config: %w", err)
+ }
+
+ if valid, _ := cfgVAD.Validate(); !valid {
+ return nil, nil, fmt.Errorf("failed to validate config: %w", err)
+ }
+
+ opts := backend.ModelOptions(*cfgVAD, appConfig)
+ VADClient, err := ml.Load(opts...)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load tts model: %w", err)
+ }
+
+ cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath)
+ if err != nil {
+
+ return nil, nil, fmt.Errorf("failed to load backend config: %w", err)
+ }
+
+ if valid, _ := cfgSST.Validate(); !valid {
+ return nil, nil, fmt.Errorf("failed to validate config: %w", err)
+ }
+
+ opts = backend.ModelOptions(*cfgSST, appConfig)
+ transcriptionClient, err := ml.Load(opts...)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load SST model: %w", err)
+ }
+
+ return &transcriptOnlyModel{
+ VADConfig: cfgVAD,
+ VADClient: VADClient,
+ TranscriptionConfig: cfgSST,
+ TranscriptionClient: transcriptionClient,
+ }, cfgSST, nil
+}
+
+// returns and loads either a wrapped model or a model that support audio-to-audio
+func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, error) {
+
+ cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath)
+ if err != nil {
+
+ return nil, fmt.Errorf("failed to load backend config: %w", err)
+ }
+
+ if valid, _ := cfgVAD.Validate(); !valid {
+ return nil, fmt.Errorf("failed to validate config: %w", err)
+ }
+
+ opts := backend.ModelOptions(*cfgVAD, appConfig)
+ VADClient, err := ml.Load(opts...)
+ if err != nil {
+ return nil, fmt.Errorf("failed to load tts model: %w", err)
+ }
+
+ // TODO: Do we always need a transcription model? It can be disabled. Note that any-to-any instruction following models don't transcribe as such, so if transcription is required it is a separate process
+ cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath)
+ if err != nil {
+
+ return nil, fmt.Errorf("failed to load backend config: %w", err)
+ }
+
+ if valid, _ := cfgSST.Validate(); !valid {
+ return nil, fmt.Errorf("failed to validate config: %w", err)
+ }
+
+ opts = backend.ModelOptions(*cfgSST, appConfig)
+ transcriptionClient, err := ml.Load(opts...)
+ if err != nil {
+ return nil, fmt.Errorf("failed to load SST model: %w", err)
+ }
+
+ // TODO: Decide when we have a real any-to-any model
+ if false {
+
+ cfgAnyToAny, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath)
+ if err != nil {
+
+ return nil, fmt.Errorf("failed to load backend config: %w", err)
+ }
+
+ if valid, _ := cfgAnyToAny.Validate(); !valid {
+ return nil, fmt.Errorf("failed to validate config: %w", err)
+ }
+
+ opts := backend.ModelOptions(*cfgAnyToAny, appConfig)
+ anyToAnyClient, err := ml.Load(opts...)
+ if err != nil {
+ return nil, fmt.Errorf("failed to load tts model: %w", err)
+ }
+
+ return &anyToAnyModel{
+ LLMConfig: cfgAnyToAny,
+ LLMClient: anyToAnyClient,
+ VADConfig: cfgVAD,
+ VADClient: VADClient,
+ }, nil
+ }
+
+ xlog.Debug("Loading a wrapped model")
+
+ // Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
+ cfgLLM, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath)
+ if err != nil {
+
+ return nil, fmt.Errorf("failed to load backend config: %w", err)
+ }
+
+ if valid, _ := cfgLLM.Validate(); !valid {
+ return nil, fmt.Errorf("failed to validate config: %w", err)
+ }
+
+ cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath)
+ if err != nil {
+
+ return nil, fmt.Errorf("failed to load backend config: %w", err)
+ }
+
+ if valid, _ := cfgTTS.Validate(); !valid {
+ return nil, fmt.Errorf("failed to validate config: %w", err)
+ }
+
+ opts = backend.ModelOptions(*cfgTTS, appConfig)
+ ttsClient, err := ml.Load(opts...)
+ if err != nil {
+ return nil, fmt.Errorf("failed to load tts model: %w", err)
+ }
+
+ opts = backend.ModelOptions(*cfgLLM, appConfig)
+ llmClient, err := ml.Load(opts...)
+ if err != nil {
+ return nil, fmt.Errorf("failed to load LLM model: %w", err)
+ }
+
+ return &wrappedModel{
+ TTSConfig: cfgTTS,
+ TranscriptionConfig: cfgSST,
+ LLMConfig: cfgLLM,
+ TTSClient: ttsClient,
+ TranscriptionClient: transcriptionClient,
+ LLMClient: llmClient,
+
+ VADConfig: cfgVAD,
+ VADClient: VADClient,
+ }, nil
+}
diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go
new file mode 100644
index 0000000000000000000000000000000000000000..2c5f98d5cbc0a5882ff378e40f917c85a655d949
--- /dev/null
+++ b/core/http/endpoints/openai/transcription.go
@@ -0,0 +1,82 @@
+package openai
+
+import (
+ "io"
+ "net/http"
+ "os"
+ "path"
+ "path/filepath"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/backend"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ model "github.com/mudler/LocalAI/pkg/model"
+
+ "github.com/mudler/xlog"
+)
+
+// TranscriptEndpoint is the OpenAI Whisper API endpoint https://platform.openai.com/docs/api-reference/audio/create
+// @Summary Transcribes audio into the input language.
+// @accept multipart/form-data
+// @Param model formData string true "model"
+// @Param file formData file true "file"
+// @Success 200 {object} map[string]string "Response"
+// @Router /v1/audio/transcriptions [post]
+func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+ if !ok || input.Model == "" {
+ return echo.ErrBadRequest
+ }
+
+ config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || config == nil {
+ return echo.ErrBadRequest
+ }
+
+ diarize := c.FormValue("diarize") != "false"
+ prompt := c.FormValue("prompt")
+
+ // retrieve the file data from the request
+ file, err := c.FormFile("file")
+ if err != nil {
+ return err
+ }
+ f, err := file.Open()
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+
+ dir, err := os.MkdirTemp("", "whisper")
+
+ if err != nil {
+ return err
+ }
+ defer os.RemoveAll(dir)
+
+ dst := filepath.Join(dir, path.Base(file.Filename))
+ dstFile, err := os.Create(dst)
+ if err != nil {
+ return err
+ }
+
+ if _, err := io.Copy(dstFile, f); err != nil {
+ xlog.Debug("Audio file copying error", "filename", file.Filename, "dst", dst, "error", err)
+ return err
+ }
+
+ xlog.Debug("Audio file copied", "dst", dst)
+
+ tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, diarize, prompt, ml, *config, appConfig)
+ if err != nil {
+ return err
+ }
+
+ xlog.Debug("Transcribed", "transcription", tr)
+ // TODO: handle different outputs here
+ return c.JSON(http.StatusOK, tr)
+ }
+}
diff --git a/core/http/endpoints/openai/types/realtime.go b/core/http/endpoints/openai/types/realtime.go
new file mode 100644
index 0000000000000000000000000000000000000000..a79d05d9cb83ad8c07c65af044dbef1d65df39fe
--- /dev/null
+++ b/core/http/endpoints/openai/types/realtime.go
@@ -0,0 +1,1188 @@
+package types
+
+// Most of this file was coppied from https://github.com/WqyJh/go-openai-realtime
+// Copyright (c) 2024 Qiying Wang MIT License
+
+import (
+ "encoding/json"
+ "fmt"
+ "math"
+)
+
+const (
+ // Inf is the maximum value for an IntOrInf.
+ Inf IntOrInf = math.MaxInt
+)
+
+// IntOrInf is a type that can be either an int or "inf".
+type IntOrInf int
+
+// IsInf returns true if the value is "inf".
+func (m IntOrInf) IsInf() bool {
+ return m == Inf
+}
+
+// MarshalJSON marshals the IntOrInf to JSON.
+func (m IntOrInf) MarshalJSON() ([]byte, error) {
+ if m == Inf {
+ return []byte("\"inf\""), nil
+ }
+ return json.Marshal(int(m))
+}
+
+// UnmarshalJSON unmarshals the IntOrInf from JSON.
+func (m *IntOrInf) UnmarshalJSON(data []byte) error {
+ if string(data) == "\"inf\"" {
+ *m = Inf
+ return nil
+ }
+ if len(data) == 0 {
+ return nil
+ }
+ return json.Unmarshal(data, (*int)(m))
+}
+
+type AudioFormat string
+
+const (
+ AudioFormatPcm16 AudioFormat = "pcm16"
+ AudioFormatG711Ulaw AudioFormat = "g711_ulaw"
+ AudioFormatG711Alaw AudioFormat = "g711_alaw"
+)
+
+type Modality string
+
+const (
+ ModalityText Modality = "text"
+ ModalityAudio Modality = "audio"
+)
+
+type ClientTurnDetectionType string
+
+const (
+ ClientTurnDetectionTypeServerVad ClientTurnDetectionType = "server_vad"
+)
+
+type ServerTurnDetectionType string
+
+const (
+ ServerTurnDetectionTypeNone ServerTurnDetectionType = "none"
+ ServerTurnDetectionTypeServerVad ServerTurnDetectionType = "server_vad"
+)
+
+type TurnDetectionType string
+
+const (
+ // TurnDetectionTypeNone means turn detection is disabled.
+ // This can only be used in ServerSession, not in ClientSession.
+ // If you want to disable turn detection, you should send SessionUpdateEvent with TurnDetection set to nil.
+ TurnDetectionTypeNone TurnDetectionType = "none"
+ // TurnDetectionTypeServerVad use server-side VAD to detect turn.
+ // This is default value for newly created session.
+ TurnDetectionTypeServerVad TurnDetectionType = "server_vad"
+)
+
+type TurnDetectionParams struct {
+ // Activation threshold for VAD.
+ Threshold float64 `json:"threshold,omitempty"`
+ // Audio included before speech starts (in milliseconds).
+ PrefixPaddingMs int `json:"prefix_padding_ms,omitempty"`
+ // Duration of silence to detect speech stop (in milliseconds).
+ SilenceDurationMs int `json:"silence_duration_ms,omitempty"`
+ // Whether or not to automatically generate a response when VAD is enabled. true by default.
+ CreateResponse *bool `json:"create_response,omitempty"`
+}
+
+type ClientTurnDetection struct {
+ // Type of turn detection, only "server_vad" is currently supported.
+ Type ClientTurnDetectionType `json:"type"`
+
+ TurnDetectionParams
+}
+
+type ServerTurnDetection struct {
+ // The type of turn detection ("server_vad" or "none").
+ Type ServerTurnDetectionType `json:"type"`
+
+ TurnDetectionParams
+}
+
+type ToolType string
+
+const (
+ ToolTypeFunction ToolType = "function"
+)
+
+type ToolChoiceInterface interface {
+ ToolChoice()
+}
+
+type ToolChoiceString string
+
+func (ToolChoiceString) ToolChoice() {}
+
+const (
+ ToolChoiceAuto ToolChoiceString = "auto"
+ ToolChoiceNone ToolChoiceString = "none"
+ ToolChoiceRequired ToolChoiceString = "required"
+)
+
+type ToolChoice struct {
+ Type ToolType `json:"type"`
+ Function ToolFunction `json:"function,omitempty"`
+}
+
+func (t ToolChoice) ToolChoice() {}
+
+type ToolFunction struct {
+ Name string `json:"name"`
+}
+
+type MessageRole string
+
+const (
+ MessageRoleSystem MessageRole = "system"
+ MessageRoleAssistant MessageRole = "assistant"
+ MessageRoleUser MessageRole = "user"
+)
+
+type InputAudioTranscription struct {
+ // The model used for transcription.
+ Model string `json:"model"`
+ Language string `json:"language,omitempty"`
+ Prompt string `json:"prompt,omitempty"`
+}
+
+type Tool struct {
+ Type ToolType `json:"type"`
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Parameters any `json:"parameters"`
+}
+
+type MessageItemType string
+
+const (
+ MessageItemTypeMessage MessageItemType = "message"
+ MessageItemTypeFunctionCall MessageItemType = "function_call"
+ MessageItemTypeFunctionCallOutput MessageItemType = "function_call_output"
+)
+
+type MessageContentType string
+
+const (
+ MessageContentTypeText MessageContentType = "text"
+ MessageContentTypeAudio MessageContentType = "audio"
+ MessageContentTypeTranscript MessageContentType = "transcript"
+ MessageContentTypeInputText MessageContentType = "input_text"
+ MessageContentTypeInputAudio MessageContentType = "input_audio"
+)
+
+type MessageContentPart struct {
+ // The content type.
+ Type MessageContentType `json:"type"`
+ // The text content. Validated if type is text.
+ Text string `json:"text,omitempty"`
+ // Base64-encoded audio data. Validated if type is audio.
+ Audio string `json:"audio,omitempty"`
+ // The transcript of the audio. Validated if type is transcript.
+ Transcript string `json:"transcript,omitempty"`
+}
+
+type MessageItem struct {
+ // The unique ID of the item.
+ ID string `json:"id,omitempty"`
+ // The type of the item ("message", "function_call", "function_call_output").
+ Type MessageItemType `json:"type"`
+ // The final status of the item.
+ Status ItemStatus `json:"status,omitempty"`
+ // The role associated with the item.
+ Role MessageRole `json:"role,omitempty"`
+ // The content of the item.
+ Content []MessageContentPart `json:"content,omitempty"`
+ // The ID of the function call, if the item is a function call.
+ CallID string `json:"call_id,omitempty"`
+ // The name of the function, if the item is a function call.
+ Name string `json:"name,omitempty"`
+ // The arguments of the function, if the item is a function call.
+ Arguments string `json:"arguments,omitempty"`
+ // The output of the function, if the item is a function call output.
+ Output string `json:"output,omitempty"`
+}
+
+type ResponseMessageItem struct {
+ MessageItem
+ // The object type, must be "realtime.item".
+ Object string `json:"object,omitempty"`
+}
+
+type Error struct {
+ // The type of error (e.g., "invalid_request_error", "server_error").
+ Message string `json:"message,omitempty"`
+ // Error code, if any.
+ Type string `json:"type,omitempty"`
+ // A human-readable error message.
+ Code string `json:"code,omitempty"`
+ // Parameter related to the error, if any.
+ Param string `json:"param,omitempty"`
+ // The event_id of the client event that caused the error, if applicable.
+ EventID string `json:"event_id,omitempty"`
+}
+
+// ServerToolChoice is a type that can be used to choose a tool response from the server.
+type ServerToolChoice struct {
+ String ToolChoiceString
+ Function ToolChoice
+}
+
+// UnmarshalJSON is a custom unmarshaler for ServerToolChoice.
+func (m *ServerToolChoice) UnmarshalJSON(data []byte) error {
+ err := json.Unmarshal(data, &m.Function)
+ if err != nil {
+ if data[0] == '"' {
+ data = data[1:]
+ }
+ if data[len(data)-1] == '"' {
+ data = data[:len(data)-1]
+ }
+ m.String = ToolChoiceString(data)
+ m.Function = ToolChoice{}
+ return nil
+ }
+ return nil
+}
+
+// IsFunction returns true if the tool choice is a function call.
+func (m *ServerToolChoice) IsFunction() bool {
+ return m.Function.Type == ToolTypeFunction
+}
+
+// Get returns the ToolChoiceInterface based on the type of tool choice.
+func (m ServerToolChoice) Get() ToolChoiceInterface {
+ if m.IsFunction() {
+ return m.Function
+ }
+ return m.String
+}
+
+type ServerSession struct {
+ // The unique ID of the session.
+ ID string `json:"id"`
+ // The object type, must be "realtime.session".
+ Object string `json:"object"`
+ // The default model used for this session.
+ Model string `json:"model"`
+ // The set of modalities the model can respond with.
+ Modalities []Modality `json:"modalities,omitempty"`
+ // The default system instructions.
+ Instructions string `json:"instructions,omitempty"`
+ // The voice the model uses to respond - one of alloy, echo, or shimmer.
+ Voice string `json:"voice,omitempty"`
+ // The format of input audio.
+ InputAudioFormat AudioFormat `json:"input_audio_format,omitempty"`
+ // The format of output audio.
+ OutputAudioFormat AudioFormat `json:"output_audio_format,omitempty"`
+ // Configuration for input audio transcription.
+ InputAudioTranscription *InputAudioTranscription `json:"input_audio_transcription,omitempty"`
+ // Configuration for turn detection.
+ TurnDetection *ServerTurnDetection `json:"turn_detection,omitempty"`
+ // Tools (functions) available to the model.
+ Tools []Tool `json:"tools,omitempty"`
+ // How the model chooses tools.
+ ToolChoice ServerToolChoice `json:"tool_choice,omitempty"`
+ // Sampling temperature.
+ Temperature *float32 `json:"temperature,omitempty"`
+ // Maximum number of output tokens.
+ MaxOutputTokens IntOrInf `json:"max_response_output_tokens,omitempty"`
+}
+
+type ItemStatus string
+
+const (
+ ItemStatusInProgress ItemStatus = "in_progress"
+ ItemStatusCompleted ItemStatus = "completed"
+ ItemStatusIncomplete ItemStatus = "incomplete"
+)
+
+type Conversation struct {
+ // The unique ID of the conversation.
+ ID string `json:"id"`
+ // The object type, must be "realtime.conversation".
+ Object string `json:"object"`
+}
+
+type ResponseStatus string
+
+const (
+ ResponseStatusInProgress ResponseStatus = "in_progress"
+ ResponseStatusCompleted ResponseStatus = "completed"
+ ResponseStatusCancelled ResponseStatus = "cancelled"
+ ResponseStatusIncomplete ResponseStatus = "incomplete"
+ ResponseStatusFailed ResponseStatus = "failed"
+)
+
+type CachedTokensDetails struct {
+ TextTokens int `json:"text_tokens"`
+ AudioTokens int `json:"audio_tokens"`
+}
+
+type InputTokenDetails struct {
+ CachedTokens int `json:"cached_tokens"`
+ TextTokens int `json:"text_tokens"`
+ AudioTokens int `json:"audio_tokens"`
+ CachedTokensDetails CachedTokensDetails `json:"cached_tokens_details,omitempty"`
+}
+
+type OutputTokenDetails struct {
+ TextTokens int `json:"text_tokens"`
+ AudioTokens int `json:"audio_tokens"`
+}
+
+type Usage struct {
+ TotalTokens int `json:"total_tokens"`
+ InputTokens int `json:"input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+ // Input token details.
+ InputTokenDetails InputTokenDetails `json:"input_token_details,omitempty"`
+ // Output token details.
+ OutputTokenDetails OutputTokenDetails `json:"output_token_details,omitempty"`
+}
+
+type Response struct {
+ // The unique ID of the response.
+ ID string `json:"id"`
+ // The object type, must be "realtime.response".
+ Object string `json:"object"`
+ // The status of the response.
+ Status ResponseStatus `json:"status"`
+ // Additional details about the status.
+ StatusDetails any `json:"status_details,omitempty"`
+ // The list of output items generated by the response.
+ Output []ResponseMessageItem `json:"output"`
+ // Usage statistics for the response.
+ Usage *Usage `json:"usage,omitempty"`
+}
+
+type RateLimit struct {
+ // The name of the rate limit ("requests", "tokens", "input_tokens", "output_tokens").
+ Name string `json:"name"`
+ // The maximum allowed value for the rate limit.
+ Limit int `json:"limit"`
+ // The remaining value before the limit is reached.
+ Remaining int `json:"remaining"`
+ // Seconds until the rate limit resets.
+ ResetSeconds float64 `json:"reset_seconds"`
+}
+
+// ClientEventType is the type of client event. See https://platform.openai.com/docs/guides/realtime/client-events
+type ClientEventType string
+
+const (
+ ClientEventTypeSessionUpdate ClientEventType = "session.update"
+ ClientEventTypeTranscriptionSessionUpdate ClientEventType = "transcription_session.update"
+ ClientEventTypeInputAudioBufferAppend ClientEventType = "input_audio_buffer.append"
+ ClientEventTypeInputAudioBufferCommit ClientEventType = "input_audio_buffer.commit"
+ ClientEventTypeInputAudioBufferClear ClientEventType = "input_audio_buffer.clear"
+ ClientEventTypeConversationItemCreate ClientEventType = "conversation.item.create"
+ ClientEventTypeConversationItemTruncate ClientEventType = "conversation.item.truncate"
+ ClientEventTypeConversationItemDelete ClientEventType = "conversation.item.delete"
+ ClientEventTypeResponseCreate ClientEventType = "response.create"
+ ClientEventTypeResponseCancel ClientEventType = "response.cancel"
+)
+
+// ClientEvent is the interface for client event.
+type ClientEvent interface {
+ ClientEventType() ClientEventType
+}
+
+// EventBase is the base struct for all client events.
+type EventBase struct {
+ // Optional client-generated ID used to identify this event.
+ EventID string `json:"event_id,omitempty"`
+}
+
+type ClientSession struct {
+ Model string `json:"model,omitempty"`
+ // The set of modalities the model can respond with. To disable audio, set this to ["text"].
+ Modalities []Modality `json:"modalities,omitempty"`
+ // The default system instructions prepended to model calls.
+ Instructions string `json:"instructions,omitempty"`
+ // The voice the model uses to respond - one of alloy, echo, or shimmer. Cannot be changed once the model has responded with audio at least once.
+ Voice string `json:"voice,omitempty"`
+ // The format of input audio. Options are "pcm16", "g711_ulaw", or "g711_alaw".
+ InputAudioFormat AudioFormat `json:"input_audio_format,omitempty"`
+ // The format of output audio. Options are "pcm16", "g711_ulaw", or "g711_alaw".
+ OutputAudioFormat AudioFormat `json:"output_audio_format,omitempty"`
+ // Configuration for input audio transcription. Can be set to `nil` to turn off.
+ InputAudioTranscription *InputAudioTranscription `json:"input_audio_transcription,omitempty"`
+ // Configuration for turn detection. Can be set to `nil` to turn off.
+ TurnDetection *ClientTurnDetection `json:"turn_detection"`
+ // Tools (functions) available to the model.
+ Tools []Tool `json:"tools,omitempty"`
+ // How the model chooses tools. Options are "auto", "none", "required", or specify a function.
+ ToolChoice ToolChoiceInterface `json:"tool_choice,omitempty"`
+ // Sampling temperature for the model.
+ Temperature *float32 `json:"temperature,omitempty"`
+ // Maximum number of output tokens for a single assistant response, inclusive of tool calls. Provide an integer between 1 and 4096 to limit output tokens, or "inf" for the maximum available tokens for a given model. Defaults to "inf".
+ MaxOutputTokens IntOrInf `json:"max_response_output_tokens,omitempty"`
+}
+
+type CreateSessionRequest struct {
+ ClientSession
+
+ // The Realtime model used for this session.
+ Model string `json:"model,omitempty"`
+}
+
+type ClientSecret struct {
+ // Ephemeral key usable in client environments to authenticate connections to the Realtime API. Use this in client-side environments rather than a standard API token, which should only be used server-side.
+ Value string `json:"value"`
+ // Timestamp for when the token expires. Currently, all tokens expire after one minute.
+ ExpiresAt int64 `json:"expires_at"`
+}
+
+type CreateSessionResponse struct {
+ ServerSession
+
+ // Ephemeral key returned by the API.
+ ClientSecret ClientSecret `json:"client_secret"`
+}
+
+// SessionUpdateEvent is the event for session update.
+// Send this event to update the session’s default configuration.
+// See https://platform.openai.com/docs/api-reference/realtime-client-events/session/update
+type SessionUpdateEvent struct {
+ EventBase
+ // Session configuration to update.
+ Session ClientSession `json:"session"`
+}
+
+func (m SessionUpdateEvent) ClientEventType() ClientEventType {
+ return ClientEventTypeSessionUpdate
+}
+
+func (m SessionUpdateEvent) MarshalJSON() ([]byte, error) {
+ type sessionUpdateEvent SessionUpdateEvent
+ v := struct {
+ *sessionUpdateEvent
+ Type ClientEventType `json:"type"`
+ }{
+ sessionUpdateEvent: (*sessionUpdateEvent)(&m),
+ Type: m.ClientEventType(),
+ }
+ return json.Marshal(v)
+}
+
+// InputAudioBufferAppendEvent is the event for input audio buffer append.
+// Send this event to append audio bytes to the input audio buffer.
+// See https://platform.openai.com/docs/api-reference/realtime-client-events/input_audio_buffer/append
+type InputAudioBufferAppendEvent struct {
+ EventBase
+ Audio string `json:"audio"` // Base64-encoded audio bytes.
+}
+
+func (m InputAudioBufferAppendEvent) ClientEventType() ClientEventType {
+ return ClientEventTypeInputAudioBufferAppend
+}
+
+func (m InputAudioBufferAppendEvent) MarshalJSON() ([]byte, error) {
+ type inputAudioBufferAppendEvent InputAudioBufferAppendEvent
+ v := struct {
+ *inputAudioBufferAppendEvent
+ Type ClientEventType `json:"type"`
+ }{
+ inputAudioBufferAppendEvent: (*inputAudioBufferAppendEvent)(&m),
+ Type: m.ClientEventType(),
+ }
+ return json.Marshal(v)
+}
+
+// InputAudioBufferCommitEvent is the event for input audio buffer commit.
+// Send this event to commit audio bytes to a user message.
+// See https://platform.openai.com/docs/api-reference/realtime-client-events/input_audio_buffer/commit
+type InputAudioBufferCommitEvent struct {
+ EventBase
+}
+
+func (m InputAudioBufferCommitEvent) ClientEventType() ClientEventType {
+ return ClientEventTypeInputAudioBufferCommit
+}
+
+func (m InputAudioBufferCommitEvent) MarshalJSON() ([]byte, error) {
+ type inputAudioBufferCommitEvent InputAudioBufferCommitEvent
+ v := struct {
+ *inputAudioBufferCommitEvent
+ Type ClientEventType `json:"type"`
+ }{
+ inputAudioBufferCommitEvent: (*inputAudioBufferCommitEvent)(&m),
+ Type: m.ClientEventType(),
+ }
+ return json.Marshal(v)
+}
+
+// InputAudioBufferClearEvent is the event for input audio buffer clear.
+// Send this event to clear the audio bytes in the buffer.
+// See https://platform.openai.com/docs/api-reference/realtime-client-events/input_audio_buffer/clear
+type InputAudioBufferClearEvent struct {
+ EventBase
+}
+
+func (m InputAudioBufferClearEvent) ClientEventType() ClientEventType {
+ return ClientEventTypeInputAudioBufferClear
+}
+
+func (m InputAudioBufferClearEvent) MarshalJSON() ([]byte, error) {
+ type inputAudioBufferClearEvent InputAudioBufferClearEvent
+ v := struct {
+ *inputAudioBufferClearEvent
+ Type ClientEventType `json:"type"`
+ }{
+ inputAudioBufferClearEvent: (*inputAudioBufferClearEvent)(&m),
+ Type: m.ClientEventType(),
+ }
+ return json.Marshal(v)
+}
+
+// ConversationItemCreateEvent is the event for conversation item create.
+// Send this event when adding an item to the conversation.
+// See https://platform.openai.com/docs/api-reference/realtime-client-events/conversation/item/create
+type ConversationItemCreateEvent struct {
+ EventBase
+ // The ID of the preceding item after which the new item will be inserted.
+ PreviousItemID string `json:"previous_item_id,omitempty"`
+ // The item to add to the conversation.
+ Item MessageItem `json:"item"`
+}
+
+func (m ConversationItemCreateEvent) ClientEventType() ClientEventType {
+ return ClientEventTypeConversationItemCreate
+}
+
+func (m ConversationItemCreateEvent) MarshalJSON() ([]byte, error) {
+ type conversationItemCreateEvent ConversationItemCreateEvent
+ v := struct {
+ *conversationItemCreateEvent
+ Type ClientEventType `json:"type"`
+ }{
+ conversationItemCreateEvent: (*conversationItemCreateEvent)(&m),
+ Type: m.ClientEventType(),
+ }
+ return json.Marshal(v)
+}
+
+// ConversationItemTruncateEvent is the event for conversation item truncate.
+// Send this event when you want to truncate a previous assistant message’s audio.
+// See https://platform.openai.com/docs/api-reference/realtime-client-events/conversation/item/truncate
+type ConversationItemTruncateEvent struct {
+ EventBase
+ // The ID of the assistant message item to truncate.
+ ItemID string `json:"item_id"`
+ // The index of the content part to truncate.
+ ContentIndex int `json:"content_index"`
+ // Inclusive duration up to which audio is truncated, in milliseconds.
+ AudioEndMs int `json:"audio_end_ms"`
+}
+
+func (m ConversationItemTruncateEvent) ClientEventType() ClientEventType {
+ return ClientEventTypeConversationItemTruncate
+}
+
+func (m ConversationItemTruncateEvent) MarshalJSON() ([]byte, error) {
+ type conversationItemTruncateEvent ConversationItemTruncateEvent
+ v := struct {
+ *conversationItemTruncateEvent
+ Type ClientEventType `json:"type"`
+ }{
+ conversationItemTruncateEvent: (*conversationItemTruncateEvent)(&m),
+ Type: m.ClientEventType(),
+ }
+ return json.Marshal(v)
+}
+
+// ConversationItemDeleteEvent is the event for conversation item delete.
+// Send this event when you want to remove any item from the conversation history.
+// See https://platform.openai.com/docs/api-reference/realtime-client-events/conversation/item/delete
+type ConversationItemDeleteEvent struct {
+ EventBase
+ // The ID of the item to delete.
+ ItemID string `json:"item_id"`
+}
+
+func (m ConversationItemDeleteEvent) ClientEventType() ClientEventType {
+ return ClientEventTypeConversationItemDelete
+}
+
+func (m ConversationItemDeleteEvent) MarshalJSON() ([]byte, error) {
+ type conversationItemDeleteEvent ConversationItemDeleteEvent
+ v := struct {
+ *conversationItemDeleteEvent
+ Type ClientEventType `json:"type"`
+ }{
+ conversationItemDeleteEvent: (*conversationItemDeleteEvent)(&m),
+ Type: m.ClientEventType(),
+ }
+ return json.Marshal(v)
+}
+
+type ResponseCreateParams struct {
+ // The modalities for the response.
+ Modalities []Modality `json:"modalities,omitempty"`
+ // Instructions for the model.
+ Instructions string `json:"instructions,omitempty"`
+ // The voice the model uses to respond - one of alloy, echo, or shimmer.
+ Voice string `json:"voice,omitempty"`
+ // The format of output audio.
+ OutputAudioFormat AudioFormat `json:"output_audio_format,omitempty"`
+ // Tools (functions) available to the model.
+ Tools []Tool `json:"tools,omitempty"`
+ // How the model chooses tools.
+ ToolChoice ToolChoiceInterface `json:"tool_choice,omitempty"`
+ // Sampling temperature.
+ Temperature *float32 `json:"temperature,omitempty"`
+ // Maximum number of output tokens for a single assistant response, inclusive of tool calls. Provide an integer between 1 and 4096 to limit output tokens, or "inf" for the maximum available tokens for a given model. Defaults to "inf".
+ MaxOutputTokens IntOrInf `json:"max_output_tokens,omitempty"`
+}
+
+// ResponseCreateEvent is the event for response create.
+// Send this event to trigger a response generation.
+// See https://platform.openai.com/docs/api-reference/realtime-client-events/response/create
+type ResponseCreateEvent struct {
+ EventBase
+ // Configuration for the response.
+ Response ResponseCreateParams `json:"response"`
+}
+
+func (m ResponseCreateEvent) ClientEventType() ClientEventType {
+ return ClientEventTypeResponseCreate
+}
+
+func (m ResponseCreateEvent) MarshalJSON() ([]byte, error) {
+ type responseCreateEvent ResponseCreateEvent
+ v := struct {
+ *responseCreateEvent
+ Type ClientEventType `json:"type"`
+ }{
+ responseCreateEvent: (*responseCreateEvent)(&m),
+ Type: m.ClientEventType(),
+ }
+ return json.Marshal(v)
+}
+
+// ResponseCancelEvent is the event for response cancel.
+// Send this event to cancel an in-progress response.
+// See https://platform.openai.com/docs/api-reference/realtime-client-events/response/cancel
+type ResponseCancelEvent struct {
+ EventBase
+ // A specific response ID to cancel - if not provided, will cancel an in-progress response in the default conversation.
+ ResponseID string `json:"response_id,omitempty"`
+}
+
+func (m ResponseCancelEvent) ClientEventType() ClientEventType {
+ return ClientEventTypeResponseCancel
+}
+
+func (m ResponseCancelEvent) MarshalJSON() ([]byte, error) {
+ type responseCancelEvent ResponseCancelEvent
+ v := struct {
+ *responseCancelEvent
+ Type ClientEventType `json:"type"`
+ }{
+ responseCancelEvent: (*responseCancelEvent)(&m),
+ Type: m.ClientEventType(),
+ }
+ return json.Marshal(v)
+}
+
+// MarshalClientEvent marshals the client event to JSON.
+func MarshalClientEvent(event ClientEvent) ([]byte, error) {
+ return json.Marshal(event)
+}
+
+type ServerEventType string
+
+const (
+ ServerEventTypeError ServerEventType = "error"
+ ServerEventTypeSessionCreated ServerEventType = "session.created"
+ ServerEventTypeSessionUpdated ServerEventType = "session.updated"
+ ServerEventTypeTranscriptionSessionCreated ServerEventType = "transcription_session.created"
+ ServerEventTypeTranscriptionSessionUpdated ServerEventType = "transcription_session.updated"
+ ServerEventTypeConversationCreated ServerEventType = "conversation.created"
+ ServerEventTypeInputAudioBufferCommitted ServerEventType = "input_audio_buffer.committed"
+ ServerEventTypeInputAudioBufferCleared ServerEventType = "input_audio_buffer.cleared"
+ ServerEventTypeInputAudioBufferSpeechStarted ServerEventType = "input_audio_buffer.speech_started"
+ ServerEventTypeInputAudioBufferSpeechStopped ServerEventType = "input_audio_buffer.speech_stopped"
+ ServerEventTypeConversationItemCreated ServerEventType = "conversation.item.created"
+ ServerEventTypeConversationItemInputAudioTranscriptionCompleted ServerEventType = "conversation.item.input_audio_transcription.completed"
+ ServerEventTypeConversationItemInputAudioTranscriptionFailed ServerEventType = "conversation.item.input_audio_transcription.failed"
+ ServerEventTypeConversationItemTruncated ServerEventType = "conversation.item.truncated"
+ ServerEventTypeConversationItemDeleted ServerEventType = "conversation.item.deleted"
+ ServerEventTypeResponseCreated ServerEventType = "response.created"
+ ServerEventTypeResponseDone ServerEventType = "response.done"
+ ServerEventTypeResponseOutputItemAdded ServerEventType = "response.output_item.added"
+ ServerEventTypeResponseOutputItemDone ServerEventType = "response.output_item.done"
+ ServerEventTypeResponseContentPartAdded ServerEventType = "response.content_part.added"
+ ServerEventTypeResponseContentPartDone ServerEventType = "response.content_part.done"
+ ServerEventTypeResponseTextDelta ServerEventType = "response.text.delta"
+ ServerEventTypeResponseTextDone ServerEventType = "response.text.done"
+ ServerEventTypeResponseAudioTranscriptDelta ServerEventType = "response.audio_transcript.delta"
+ ServerEventTypeResponseAudioTranscriptDone ServerEventType = "response.audio_transcript.done"
+ ServerEventTypeResponseAudioDelta ServerEventType = "response.audio.delta"
+ ServerEventTypeResponseAudioDone ServerEventType = "response.audio.done"
+ ServerEventTypeResponseFunctionCallArgumentsDelta ServerEventType = "response.function_call_arguments.delta"
+ ServerEventTypeResponseFunctionCallArgumentsDone ServerEventType = "response.function_call_arguments.done"
+ ServerEventTypeRateLimitsUpdated ServerEventType = "rate_limits.updated"
+)
+
+// ServerEvent is the interface for server events.
+type ServerEvent interface {
+ ServerEventType() ServerEventType
+}
+
+// ServerEventBase is the base struct for all server events.
+type ServerEventBase struct {
+ // The unique ID of the server event.
+ EventID string `json:"event_id,omitempty"`
+ // The type of the server event.
+ Type ServerEventType `json:"type"`
+}
+
+func (m ServerEventBase) ServerEventType() ServerEventType {
+ return m.Type
+}
+
+// ErrorEvent is the event for error.
+// Returned when an error occurs.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/error
+type ErrorEvent struct {
+ ServerEventBase
+ // Details of the error.
+ Error Error `json:"error"`
+}
+
+// SessionCreatedEvent is the event for session created.
+// Returned when a session is created. Emitted automatically when a new connection is established.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/session/created
+type SessionCreatedEvent struct {
+ ServerEventBase
+ // The session resource.
+ Session ServerSession `json:"session"`
+}
+
+// TranscriptionSessionCreatedEvent is the event for session created.
+// Returned when a transcription session is created.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/session/created
+type TranscriptionSessionCreatedEvent struct {
+ ServerEventBase
+ // The transcription session resource.
+ Session ServerSession `json:"session"`
+}
+
+// SessionUpdatedEvent is the event for session updated.
+// Returned when a session is updated.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/session/updated
+type SessionUpdatedEvent struct {
+ ServerEventBase
+ // The updated session resource.
+ Session ServerSession `json:"session"`
+}
+
+// ConversationCreatedEvent is the event for conversation created.
+// Returned when a conversation is created. Emitted right after session creation.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/conversation/created
+type ConversationCreatedEvent struct {
+ ServerEventBase
+ // The conversation resource.
+ Conversation Conversation `json:"conversation"`
+}
+
+// InputAudioBufferCommittedEvent is the event for input audio buffer committed.
+// Returned when an input audio buffer is committed, either by the client or automatically in server VAD mode.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/input_audio_buffer/committed
+type InputAudioBufferCommittedEvent struct {
+ ServerEventBase
+ // The ID of the preceding item after which the new item will be inserted.
+ PreviousItemID string `json:"previous_item_id,omitempty"`
+ // The ID of the user message item that will be created.
+ ItemID string `json:"item_id"`
+}
+
+// InputAudioBufferClearedEvent is the event for input audio buffer cleared.
+// Returned when the input audio buffer is cleared by the client.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/input_audio_buffer/cleared
+type InputAudioBufferClearedEvent struct {
+ ServerEventBase
+}
+
+// InputAudioBufferSpeechStartedEvent is the event for input audio buffer speech started.
+// Returned in server turn detection mode when speech is detected.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/input_audio_buffer/speech_started
+type InputAudioBufferSpeechStartedEvent struct {
+ ServerEventBase
+ // Milliseconds since the session started when speech was detected.
+ AudioStartMs int64 `json:"audio_start_ms"`
+ // The ID of the user message item that will be created when speech stops.
+ ItemID string `json:"item_id"`
+}
+
+// InputAudioBufferSpeechStoppedEvent is the event for input audio buffer speech stopped.
+// Returned in server turn detection mode when speech stops.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/input_audio_buffer/speech_stopped
+type InputAudioBufferSpeechStoppedEvent struct {
+ ServerEventBase
+ // Milliseconds since the session started when speech stopped.
+ AudioEndMs int64 `json:"audio_end_ms"`
+ // The ID of the user message item that will be created.
+ ItemID string `json:"item_id"`
+}
+
+type ConversationItemCreatedEvent struct {
+ ServerEventBase
+ PreviousItemID string `json:"previous_item_id,omitempty"`
+ Item ResponseMessageItem `json:"item"`
+}
+
+type ConversationItemInputAudioTranscriptionCompletedEvent struct {
+ ServerEventBase
+ ItemID string `json:"item_id"`
+ ContentIndex int `json:"content_index"`
+ Transcript string `json:"transcript"`
+}
+
+type ConversationItemInputAudioTranscriptionFailedEvent struct {
+ ServerEventBase
+ ItemID string `json:"item_id"`
+ ContentIndex int `json:"content_index"`
+ Error Error `json:"error"`
+}
+
+type ConversationItemTruncatedEvent struct {
+ ServerEventBase
+ ItemID string `json:"item_id"` // The ID of the assistant message item that was truncated.
+ ContentIndex int `json:"content_index"` // The index of the content part that was truncated.
+ AudioEndMs int `json:"audio_end_ms"` // The duration up to which the audio was truncated, in milliseconds.
+}
+
+type ConversationItemDeletedEvent struct {
+ ServerEventBase
+ ItemID string `json:"item_id"` // The ID of the item that was deleted.
+}
+
+// ResponseCreatedEvent is the event for response created.
+// Returned when a new Response is created. The first event of response creation, where the response is in an initial state of "in_progress".
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/created
+type ResponseCreatedEvent struct {
+ ServerEventBase
+ // The response resource.
+ Response Response `json:"response"`
+}
+
+// ResponseDoneEvent is the event for response done.
+// Returned when a Response is done streaming. Always emitted, no matter the final state.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/done
+type ResponseDoneEvent struct {
+ ServerEventBase
+ // The response resource.
+ Response Response `json:"response"`
+}
+
+// ResponseOutputItemAddedEvent is the event for response output item added.
+// Returned when a new Item is created during response generation.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/output_item/added
+type ResponseOutputItemAddedEvent struct {
+ ServerEventBase
+ // The ID of the response to which the item belongs.
+ ResponseID string `json:"response_id"`
+ // The index of the output item in the response.
+ OutputIndex int `json:"output_index"`
+ // The item that was added.
+ Item ResponseMessageItem `json:"item"`
+}
+
+// ResponseOutputItemDoneEvent is the event for response output item done.
+// Returned when an Item is done streaming. Also emitted when a Response is interrupted, incomplete, or cancelled.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/output_item/done
+type ResponseOutputItemDoneEvent struct {
+ ServerEventBase
+ // The ID of the response to which the item belongs.
+ ResponseID string `json:"response_id"`
+ // The index of the output item in the response.
+ OutputIndex int `json:"output_index"`
+ // The completed item.
+ Item ResponseMessageItem `json:"item"`
+}
+
+// ResponseContentPartAddedEvent is the event for response content part added.
+// Returned when a new content part is added to an assistant message item during response generation.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/content_part/added
+type ResponseContentPartAddedEvent struct {
+ ServerEventBase
+ ResponseID string `json:"response_id"`
+ ItemID string `json:"item_id"`
+ OutputIndex int `json:"output_index"`
+ ContentIndex int `json:"content_index"`
+ Part MessageContentPart `json:"part"`
+}
+
+// ResponseContentPartDoneEvent is the event for response content part done.
+// Returned when a content part is done streaming in an assistant message item. Also emitted when a Response is interrupted, incomplete, or cancelled.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/content_part/done
+type ResponseContentPartDoneEvent struct {
+ ServerEventBase
+ // The ID of the response.
+ ResponseID string `json:"response_id"`
+ // The ID of the item to which the content part was added.
+ ItemID string `json:"item_id"`
+ // The index of the output item in the response.
+ OutputIndex int `json:"output_index"`
+ // The index of the content part in the item's content array.
+ ContentIndex int `json:"content_index"`
+ // The content part that was added.
+ Part MessageContentPart `json:"part"`
+}
+
+// ResponseTextDeltaEvent is the event for response text delta.
+// Returned when the text value of a "text" content part is updated.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/text/delta
+type ResponseTextDeltaEvent struct {
+ ServerEventBase
+ ResponseID string `json:"response_id"`
+ ItemID string `json:"item_id"`
+ OutputIndex int `json:"output_index"`
+ ContentIndex int `json:"content_index"`
+ Delta string `json:"delta"`
+}
+
+// ResponseTextDoneEvent is the event for response text done.
+// Returned when the text value of a "text" content part is done streaming. Also emitted when a Response is interrupted, incomplete, or cancelled.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/text/done
+type ResponseTextDoneEvent struct {
+ ServerEventBase
+ ResponseID string `json:"response_id"`
+ ItemID string `json:"item_id"`
+ OutputIndex int `json:"output_index"`
+ ContentIndex int `json:"content_index"`
+ Text string `json:"text"`
+}
+
+// ResponseAudioTranscriptDeltaEvent is the event for response audio transcript delta.
+// Returned when the model-generated transcription of audio output is updated.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/audio_transcript/delta
+type ResponseAudioTranscriptDeltaEvent struct {
+ ServerEventBase
+ // The ID of the response.
+ ResponseID string `json:"response_id"`
+ // The ID of the item.
+ ItemID string `json:"item_id"`
+ // The index of the output item in the response.
+ OutputIndex int `json:"output_index"`
+ // The index of the content part in the item's content array.
+ ContentIndex int `json:"content_index"`
+ // The transcript delta.
+ Delta string `json:"delta"`
+}
+
+// ResponseAudioTranscriptDoneEvent is the event for response audio transcript done.
+// Returned when the model-generated transcription of audio output is done streaming. Also emitted when a Response is interrupted, incomplete, or cancelled.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/audio_transcript/done
+type ResponseAudioTranscriptDoneEvent struct {
+ ServerEventBase
+ // The ID of the response.
+ ResponseID string `json:"response_id"`
+ // The ID of the item.
+ ItemID string `json:"item_id"`
+ // The index of the output item in the response.
+ OutputIndex int `json:"output_index"`
+ // The index of the content part in the item's content array.
+ ContentIndex int `json:"content_index"`
+ // The final transcript of the audio.
+ Transcript string `json:"transcript"`
+}
+
+// ResponseAudioDeltaEvent is the event for response audio delta.
+// Returned when the model-generated audio is updated.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/audio/delta
+type ResponseAudioDeltaEvent struct {
+ ServerEventBase
+ // The ID of the response.
+ ResponseID string `json:"response_id"`
+ // The ID of the item.
+ ItemID string `json:"item_id"`
+ // The index of the output item in the response.
+ OutputIndex int `json:"output_index"`
+ // The index of the content part in the item's content array.
+ ContentIndex int `json:"content_index"`
+ // Base64-encoded audio data delta.
+ Delta string `json:"delta"`
+}
+
+// ResponseAudioDoneEvent is the event for response audio done.
+// Returned when the model-generated audio is done. Also emitted when a Response is interrupted, incomplete, or cancelled.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/audio/done
+type ResponseAudioDoneEvent struct {
+ ServerEventBase
+ // The ID of the response.
+ ResponseID string `json:"response_id"`
+ // The ID of the item.
+ ItemID string `json:"item_id"`
+ // The index of the output item in the response.
+ OutputIndex int `json:"output_index"`
+ // The index of the content part in the item's content array.
+ ContentIndex int `json:"content_index"`
+}
+
+// ResponseFunctionCallArgumentsDeltaEvent is the event for response function call arguments delta.
+// Returned when the model-generated function call arguments are updated.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/function_call_arguments/delta
+type ResponseFunctionCallArgumentsDeltaEvent struct {
+ ServerEventBase
+ // The ID of the response.
+ ResponseID string `json:"response_id"`
+ // The ID of the item.
+ ItemID string `json:"item_id"`
+ // The index of the output item in the response.
+ OutputIndex int `json:"output_index"`
+ // The ID of the function call.
+ CallID string `json:"call_id"`
+ // The arguments delta as a JSON string.
+ Delta string `json:"delta"`
+}
+
+// ResponseFunctionCallArgumentsDoneEvent is the event for response function call arguments done.
+// Returned when the model-generated function call arguments are done streaming. Also emitted when a Response is interrupted, incomplete, or cancelled.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/function_call_arguments/done
+type ResponseFunctionCallArgumentsDoneEvent struct {
+ ServerEventBase
+ // The ID of the response.
+ ResponseID string `json:"response_id"`
+ // The ID of the item.
+ ItemID string `json:"item_id"`
+ // The index of the output item in the response.
+ OutputIndex int `json:"output_index"`
+ // The ID of the function call.
+ CallID string `json:"call_id"`
+ // The final arguments as a JSON string.
+ Arguments string `json:"arguments"`
+ // The name of the function. Not shown in API reference but present in the actual event.
+ Name string `json:"name"`
+}
+
+// RateLimitsUpdatedEvent is the event for rate limits updated.
+// Emitted after every "response.done" event to indicate the updated rate limits.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/rate_limits/updated
+type RateLimitsUpdatedEvent struct {
+ ServerEventBase
+ // List of rate limit information.
+ RateLimits []RateLimit `json:"rate_limits"`
+}
+
+type ServerEventInterface interface {
+ ErrorEvent |
+ SessionCreatedEvent |
+ SessionUpdatedEvent |
+ ConversationCreatedEvent |
+ InputAudioBufferCommittedEvent |
+ InputAudioBufferClearedEvent |
+ InputAudioBufferSpeechStartedEvent |
+ InputAudioBufferSpeechStoppedEvent |
+ ConversationItemCreatedEvent |
+ ConversationItemInputAudioTranscriptionCompletedEvent |
+ ConversationItemInputAudioTranscriptionFailedEvent |
+ ConversationItemTruncatedEvent |
+ ConversationItemDeletedEvent |
+ ResponseCreatedEvent |
+ ResponseDoneEvent |
+ ResponseOutputItemAddedEvent |
+ ResponseOutputItemDoneEvent |
+ ResponseContentPartAddedEvent |
+ ResponseContentPartDoneEvent |
+ ResponseTextDeltaEvent |
+ ResponseTextDoneEvent |
+ ResponseAudioTranscriptDeltaEvent |
+ ResponseAudioTranscriptDoneEvent |
+ ResponseAudioDeltaEvent |
+ ResponseAudioDoneEvent |
+ ResponseFunctionCallArgumentsDeltaEvent |
+ ResponseFunctionCallArgumentsDoneEvent |
+ RateLimitsUpdatedEvent
+}
+
+func unmarshalServerEvent[T ServerEventInterface](data []byte) (T, error) {
+ var t T
+ err := json.Unmarshal(data, &t)
+ if err != nil {
+ return t, err
+ }
+ return t, nil
+}
+
+// UnmarshalServerEvent unmarshals the server event from the given JSON data.
+func UnmarshalServerEvent(data []byte) (ServerEvent, error) { //nolint:funlen,cyclop // TODO: optimize
+ var eventType struct {
+ Type ServerEventType `json:"type"`
+ }
+ err := json.Unmarshal(data, &eventType)
+ if err != nil {
+ return nil, err
+ }
+ switch eventType.Type {
+ case ServerEventTypeError:
+ return unmarshalServerEvent[ErrorEvent](data)
+ case ServerEventTypeSessionCreated:
+ return unmarshalServerEvent[SessionCreatedEvent](data)
+ case ServerEventTypeSessionUpdated:
+ return unmarshalServerEvent[SessionUpdatedEvent](data)
+ case ServerEventTypeConversationCreated:
+ return unmarshalServerEvent[ConversationCreatedEvent](data)
+ case ServerEventTypeInputAudioBufferCommitted:
+ return unmarshalServerEvent[InputAudioBufferCommittedEvent](data)
+ case ServerEventTypeInputAudioBufferCleared:
+ return unmarshalServerEvent[InputAudioBufferClearedEvent](data)
+ case ServerEventTypeInputAudioBufferSpeechStarted:
+ return unmarshalServerEvent[InputAudioBufferSpeechStartedEvent](data)
+ case ServerEventTypeInputAudioBufferSpeechStopped:
+ return unmarshalServerEvent[InputAudioBufferSpeechStoppedEvent](data)
+ case ServerEventTypeConversationItemCreated:
+ return unmarshalServerEvent[ConversationItemCreatedEvent](data)
+ case ServerEventTypeConversationItemInputAudioTranscriptionCompleted:
+ return unmarshalServerEvent[ConversationItemInputAudioTranscriptionCompletedEvent](data)
+ case ServerEventTypeConversationItemInputAudioTranscriptionFailed:
+ return unmarshalServerEvent[ConversationItemInputAudioTranscriptionFailedEvent](data)
+ case ServerEventTypeConversationItemTruncated:
+ return unmarshalServerEvent[ConversationItemTruncatedEvent](data)
+ case ServerEventTypeConversationItemDeleted:
+ return unmarshalServerEvent[ConversationItemDeletedEvent](data)
+ case ServerEventTypeResponseCreated:
+ return unmarshalServerEvent[ResponseCreatedEvent](data)
+ case ServerEventTypeResponseDone:
+ return unmarshalServerEvent[ResponseDoneEvent](data)
+ case ServerEventTypeResponseOutputItemAdded:
+ return unmarshalServerEvent[ResponseOutputItemAddedEvent](data)
+ case ServerEventTypeResponseOutputItemDone:
+ return unmarshalServerEvent[ResponseOutputItemDoneEvent](data)
+ case ServerEventTypeResponseContentPartAdded:
+ return unmarshalServerEvent[ResponseContentPartAddedEvent](data)
+ case ServerEventTypeResponseContentPartDone:
+ return unmarshalServerEvent[ResponseContentPartDoneEvent](data)
+ case ServerEventTypeResponseTextDelta:
+ return unmarshalServerEvent[ResponseTextDeltaEvent](data)
+ case ServerEventTypeResponseTextDone:
+ return unmarshalServerEvent[ResponseTextDoneEvent](data)
+ case ServerEventTypeResponseAudioTranscriptDelta:
+ return unmarshalServerEvent[ResponseAudioTranscriptDeltaEvent](data)
+ case ServerEventTypeResponseAudioTranscriptDone:
+ return unmarshalServerEvent[ResponseAudioTranscriptDoneEvent](data)
+ case ServerEventTypeResponseAudioDelta:
+ return unmarshalServerEvent[ResponseAudioDeltaEvent](data)
+ case ServerEventTypeResponseAudioDone:
+ return unmarshalServerEvent[ResponseAudioDoneEvent](data)
+ case ServerEventTypeResponseFunctionCallArgumentsDelta:
+ return unmarshalServerEvent[ResponseFunctionCallArgumentsDeltaEvent](data)
+ case ServerEventTypeResponseFunctionCallArgumentsDone:
+ return unmarshalServerEvent[ResponseFunctionCallArgumentsDoneEvent](data)
+ case ServerEventTypeRateLimitsUpdated:
+ return unmarshalServerEvent[RateLimitsUpdatedEvent](data)
+ default:
+ // This should never happen.
+ return nil, fmt.Errorf("unknown server event type: %s", eventType.Type)
+ }
+}
diff --git a/core/http/endpoints/openai/video.go b/core/http/endpoints/openai/video.go
new file mode 100644
index 0000000000000000000000000000000000000000..12c06ffe61ac4a2a6878e0842c4849b9c5bd4606
--- /dev/null
+++ b/core/http/endpoints/openai/video.go
@@ -0,0 +1,140 @@
+package openai
+
+import (
+ "encoding/json"
+ "fmt"
+ "strconv"
+ "strings"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/endpoints/localai"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ model "github.com/mudler/LocalAI/pkg/model"
+)
+
+func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+ if !ok || input == nil {
+ return echo.ErrBadRequest
+ }
+ var raw map[string]interface{}
+ body := make([]byte, 0)
+ if c.Request().Body != nil {
+ c.Request().Body.Read(body)
+ }
+ if len(body) > 0 {
+ _ = json.Unmarshal(body, &raw)
+ }
+ // Build VideoRequest using shared mapper
+ vr := MapOpenAIToVideo(input, raw)
+ // Place VideoRequest into context so localai.VideoEndpoint can consume it
+ c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
+ // Delegate to existing localai handler
+ return localai.VideoEndpoint(cl, ml, appConfig)(c)
+ }
+}
+
+// VideoEndpoint godoc
+// @Summary Generate a video from an OpenAI-compatible request
+// @Description Accepts an OpenAI-style request and delegates to the LocalAI video generator
+// @Tags openai
+// @Accept json
+// @Produce json
+// @Param request body schema.OpenAIRequest true "OpenAI-style request"
+// @Success 200 {object} map[string]interface{}
+// @Failure 400 {object} map[string]interface{}
+// @Router /v1/videos [post]
+
+func MapOpenAIToVideo(input *schema.OpenAIRequest, raw map[string]interface{}) *schema.VideoRequest {
+ vr := &schema.VideoRequest{}
+ if input == nil {
+ return vr
+ }
+
+ if input.Model != "" {
+ vr.Model = input.Model
+ }
+
+ // Prompt mapping
+ switch p := input.Prompt.(type) {
+ case string:
+ vr.Prompt = p
+ case []interface{}:
+ if len(p) > 0 {
+ if s, ok := p[0].(string); ok {
+ vr.Prompt = s
+ }
+ }
+ }
+
+ // Size
+ size := input.Size
+ if size == "" && raw != nil {
+ if v, ok := raw["size"].(string); ok {
+ size = v
+ }
+ }
+ if size != "" {
+ parts := strings.SplitN(size, "x", 2)
+ if len(parts) == 2 {
+ if wi, err := strconv.Atoi(parts[0]); err == nil {
+ vr.Width = int32(wi)
+ }
+ if hi, err := strconv.Atoi(parts[1]); err == nil {
+ vr.Height = int32(hi)
+ }
+ }
+ }
+
+ // seconds -> num frames
+ secondsStr := ""
+ if raw != nil {
+ if v, ok := raw["seconds"].(string); ok {
+ secondsStr = v
+ } else if v, ok := raw["seconds"].(float64); ok {
+ secondsStr = fmt.Sprintf("%v", int(v))
+ }
+ }
+ fps := int32(30)
+ if raw != nil {
+ if rawFPS, ok := raw["fps"]; ok {
+ switch rf := rawFPS.(type) {
+ case float64:
+ fps = int32(rf)
+ case string:
+ if fi, err := strconv.Atoi(rf); err == nil {
+ fps = int32(fi)
+ }
+ }
+ }
+ }
+ if secondsStr != "" {
+ if secF, err := strconv.Atoi(secondsStr); err == nil {
+ vr.FPS = fps
+ vr.NumFrames = int32(secF) * fps
+ }
+ }
+
+ // input_reference
+ if raw != nil {
+ if v, ok := raw["input_reference"].(string); ok {
+ vr.StartImage = v
+ }
+ }
+
+ // response format
+ if input.ResponseFormat != nil {
+ if rf, ok := input.ResponseFormat.(string); ok {
+ vr.ResponseFormat = rf
+ }
+ }
+
+ if input.Step != 0 {
+ vr.Step = int32(input.Step)
+ }
+
+ return vr
+}
diff --git a/core/http/explorer.go b/core/http/explorer.go
new file mode 100644
index 0000000000000000000000000000000000000000..67c190561bf5c1327f65972b155cac073071502a
--- /dev/null
+++ b/core/http/explorer.go
@@ -0,0 +1,50 @@
+package http
+
+import (
+ "io/fs"
+ "net/http"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/explorer"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/http/routes"
+ "github.com/mudler/xlog"
+)
+
+func Explorer(db *explorer.Database) *echo.Echo {
+ e := echo.New()
+
+ // Set renderer
+ e.Renderer = renderEngine()
+
+ // Hide banner
+ e.HideBanner = true
+
+ e.Pre(middleware.StripPathPrefix())
+ routes.RegisterExplorerRoutes(e, db)
+
+ // Favicon handler
+ e.GET("/favicon.svg", func(c echo.Context) error {
+ data, err := embedDirStatic.ReadFile("static/favicon.svg")
+ if err != nil {
+ return c.NoContent(http.StatusNotFound)
+ }
+ c.Response().Header().Set("Content-Type", "image/svg+xml")
+ return c.Blob(http.StatusOK, "image/svg+xml", data)
+ })
+
+ // Static files - use fs.Sub to create a filesystem rooted at "static"
+ staticFS, err := fs.Sub(embedDirStatic, "static")
+ if err != nil {
+ // Log error but continue - static files might not work
+ xlog.Error("failed to create static filesystem", "error", err)
+ } else {
+ e.StaticFS("/static", staticFS)
+ }
+
+ // Define a custom 404 handler
+ // Note: keep this at the bottom!
+ e.GET("/*", notFoundHandler)
+
+ return e
+}
diff --git a/core/http/http_suite_test.go b/core/http/http_suite_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..94467437f926ee14cb54c1ea894acbce5d1a1f3b
--- /dev/null
+++ b/core/http/http_suite_test.go
@@ -0,0 +1,13 @@
+package http_test
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestLocalAI(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "LocalAI HTTP test suite")
+}
diff --git a/core/http/middleware/auth.go b/core/http/middleware/auth.go
new file mode 100644
index 0000000000000000000000000000000000000000..4dde8f73260a2ab99609642c9839162eee085609
--- /dev/null
+++ b/core/http/middleware/auth.go
@@ -0,0 +1,179 @@
+package middleware
+
+import (
+ "crypto/subtle"
+ "errors"
+ "net/http"
+ "strings"
+
+ "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v4/middleware"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/schema"
+)
+
+var ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key")
+
+// GetKeyAuthConfig returns Echo's KeyAuth middleware configuration
+func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (echo.MiddlewareFunc, error) {
+ // Create validator function
+ validator := getApiKeyValidationFunction(applicationConfig)
+
+ // Create error handler
+ errorHandler := getApiKeyErrorHandler(applicationConfig)
+
+ // Create Next function (skip middleware for certain requests)
+ skipper := getApiKeyRequiredFilterFunction(applicationConfig)
+
+ // Wrap it with our custom key lookup that checks multiple sources
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ if len(applicationConfig.ApiKeys) == 0 {
+ return next(c)
+ }
+
+ // Skip if skipper says so
+ if skipper != nil && skipper(c) {
+ return next(c)
+ }
+
+ // Try to extract key from multiple sources
+ key, err := extractKeyFromMultipleSources(c)
+ if err != nil {
+ return errorHandler(err, c)
+ }
+
+ // Validate the key
+ valid, err := validator(key, c)
+ if err != nil || !valid {
+ return errorHandler(ErrMissingOrMalformedAPIKey, c)
+ }
+
+ // Store key in context for later use
+ c.Set("api_key", key)
+
+ return next(c)
+ }
+ }, nil
+}
+
+// extractKeyFromMultipleSources checks multiple sources for the API key
+// in order: Authorization header, x-api-key header, xi-api-key header, token cookie
+func extractKeyFromMultipleSources(c echo.Context) (string, error) {
+ // Check Authorization header first
+ auth := c.Request().Header.Get("Authorization")
+ if auth != "" {
+ // Check for Bearer scheme
+ if strings.HasPrefix(auth, "Bearer ") {
+ return strings.TrimPrefix(auth, "Bearer "), nil
+ }
+ // If no Bearer prefix, return as-is (for backward compatibility)
+ return auth, nil
+ }
+
+ // Check x-api-key header
+ if key := c.Request().Header.Get("x-api-key"); key != "" {
+ return key, nil
+ }
+
+ // Check xi-api-key header
+ if key := c.Request().Header.Get("xi-api-key"); key != "" {
+ return key, nil
+ }
+
+ // Check token cookie
+ cookie, err := c.Cookie("token")
+ if err == nil && cookie != nil && cookie.Value != "" {
+ return cookie.Value, nil
+ }
+
+ return "", ErrMissingOrMalformedAPIKey
+}
+
+func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(error, echo.Context) error {
+ return func(err error, c echo.Context) error {
+ if errors.Is(err, ErrMissingOrMalformedAPIKey) {
+ if len(applicationConfig.ApiKeys) == 0 {
+ return nil // if no keys are set up, any error we get here is not an error.
+ }
+ c.Response().Header().Set("WWW-Authenticate", "Bearer")
+ if applicationConfig.OpaqueErrors {
+ return c.NoContent(http.StatusUnauthorized)
+ }
+
+ // Check if the request content type is JSON
+ contentType := c.Request().Header.Get("Content-Type")
+ if strings.Contains(contentType, "application/json") {
+ return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
+ Error: &schema.APIError{
+ Message: "An authentication key is required",
+ Code: 401,
+ Type: "invalid_request_error",
+ },
+ })
+ }
+
+ return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{
+ "BaseURL": BaseURL(c),
+ })
+ }
+ if applicationConfig.OpaqueErrors {
+ return c.NoContent(http.StatusInternalServerError)
+ }
+ return err
+ }
+}
+
+func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(string, echo.Context) (bool, error) {
+ if applicationConfig.UseSubtleKeyComparison {
+ return func(key string, c echo.Context) (bool, error) {
+ if len(applicationConfig.ApiKeys) == 0 {
+ return true, nil // If no keys are setup, accept everything
+ }
+ for _, validKey := range applicationConfig.ApiKeys {
+ if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 {
+ return true, nil
+ }
+ }
+ return false, ErrMissingOrMalformedAPIKey
+ }
+ }
+
+ return func(key string, c echo.Context) (bool, error) {
+ if len(applicationConfig.ApiKeys) == 0 {
+ return true, nil // If no keys are setup, accept everything
+ }
+ for _, validKey := range applicationConfig.ApiKeys {
+ if key == validKey {
+ return true, nil
+ }
+ }
+ return false, ErrMissingOrMalformedAPIKey
+ }
+}
+
+func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) middleware.Skipper {
+ return func(c echo.Context) bool {
+ path := c.Request().URL.Path
+
+ for _, p := range applicationConfig.PathWithoutAuth {
+ if strings.HasPrefix(path, p) {
+ return true
+ }
+ }
+
+ // Handle GET request exemptions if enabled
+ if applicationConfig.DisableApiKeyRequirementForHttpGet {
+ if c.Request().Method != http.MethodGet {
+ return false
+ }
+ for _, rx := range applicationConfig.HttpGetExemptedEndpoints {
+ if rx.MatchString(c.Path()) {
+ return true
+ }
+ }
+ }
+
+ return false
+ }
+}
diff --git a/core/http/middleware/baseurl.go b/core/http/middleware/baseurl.go
new file mode 100644
index 0000000000000000000000000000000000000000..78a59289a81fd40ba997299f7a7bee703de7ebd0
--- /dev/null
+++ b/core/http/middleware/baseurl.go
@@ -0,0 +1,48 @@
+package middleware
+
+import (
+ "strings"
+
+ "github.com/labstack/echo/v4"
+)
+
+// BaseURL returns the base URL for the given HTTP request context.
+// It takes into account that the app may be exposed by a reverse-proxy under a different protocol, host and path.
+// The returned URL is guaranteed to end with `/`.
+// The method should be used in conjunction with the StripPathPrefix middleware.
+func BaseURL(c echo.Context) string {
+ path := c.Path()
+ origPath := c.Request().URL.Path
+
+ // Check if StripPathPrefix middleware stored the original path
+ if storedPath, ok := c.Get("_original_path").(string); ok && storedPath != "" {
+ origPath = storedPath
+ }
+
+ // Check X-Forwarded-Proto for scheme
+ scheme := "http"
+ if c.Request().Header.Get("X-Forwarded-Proto") == "https" {
+ scheme = "https"
+ } else if c.Request().TLS != nil {
+ scheme = "https"
+ }
+
+ // Check X-Forwarded-Host for host
+ host := c.Request().Host
+ if forwardedHost := c.Request().Header.Get("X-Forwarded-Host"); forwardedHost != "" {
+ host = forwardedHost
+ }
+
+ if path != origPath && strings.HasSuffix(origPath, path) && len(path) > 0 {
+ prefixLen := len(origPath) - len(path)
+ if prefixLen > 0 && prefixLen <= len(origPath) {
+ pathPrefix := origPath[:prefixLen]
+ if !strings.HasSuffix(pathPrefix, "/") {
+ pathPrefix += "/"
+ }
+ return scheme + "://" + host + pathPrefix
+ }
+ }
+
+ return scheme + "://" + host + "/"
+}
diff --git a/core/http/middleware/baseurl_test.go b/core/http/middleware/baseurl_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..b0770b8eae41c07fb18c3986f9b697598b7132c3
--- /dev/null
+++ b/core/http/middleware/baseurl_test.go
@@ -0,0 +1,58 @@
+package middleware
+
+import (
+ "net/http/httptest"
+
+ "github.com/labstack/echo/v4"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("BaseURL", func() {
+ Context("without prefix", func() {
+ It("should return base URL without prefix", func() {
+ app := echo.New()
+ actualURL := ""
+
+ // Register route - use the actual request path so routing works
+ routePath := "/hello/world"
+ app.GET(routePath, func(c echo.Context) error {
+ actualURL = BaseURL(c)
+ return nil
+ })
+
+ req := httptest.NewRequest("GET", "/hello/world", nil)
+ rec := httptest.NewRecorder()
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(200), "response status code")
+ Expect(actualURL).To(Equal("http://example.com/"), "base URL")
+ })
+ })
+
+ Context("with prefix", func() {
+ It("should return base URL with prefix", func() {
+ app := echo.New()
+ actualURL := ""
+
+ // Register route with the stripped path (after middleware removes prefix)
+ routePath := "/hello/world"
+ app.GET(routePath, func(c echo.Context) error {
+ // Simulate what StripPathPrefix middleware does - store original path
+ c.Set("_original_path", "/myprefix/hello/world")
+ // Modify the request path to simulate prefix stripping
+ c.Request().URL.Path = "/hello/world"
+ actualURL = BaseURL(c)
+ return nil
+ })
+
+ // Make request with stripped path (middleware would have already processed it)
+ req := httptest.NewRequest("GET", "/hello/world", nil)
+ rec := httptest.NewRecorder()
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(200), "response status code")
+ Expect(actualURL).To(Equal("http://example.com/myprefix/"), "base URL")
+ })
+ })
+})
diff --git a/core/http/middleware/middleware_suite_test.go b/core/http/middleware/middleware_suite_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..0f40add2539dd8a0a5f856233deee1ecd82b7c6c
--- /dev/null
+++ b/core/http/middleware/middleware_suite_test.go
@@ -0,0 +1,13 @@
+package middleware_test
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestMiddleware(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "Middleware test suite")
+}
diff --git a/core/http/middleware/request.go b/core/http/middleware/request.go
new file mode 100644
index 0000000000000000000000000000000000000000..76d7fee643787e78b18ef02eb2c5af205ecd7318
--- /dev/null
+++ b/core/http/middleware/request.go
@@ -0,0 +1,486 @@
+package middleware
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strconv"
+ "strings"
+
+ "github.com/google/uuid"
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/core/services"
+ "github.com/mudler/LocalAI/core/templates"
+ "github.com/mudler/LocalAI/pkg/functions"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/utils"
+ "github.com/mudler/xlog"
+)
+
+type correlationIDKeyType string
+
+// CorrelationIDKey to track request across process boundary
+const CorrelationIDKey correlationIDKeyType = "correlationID"
+
+type RequestExtractor struct {
+ modelConfigLoader *config.ModelConfigLoader
+ modelLoader *model.ModelLoader
+ applicationConfig *config.ApplicationConfig
+}
+
+func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
+ return &RequestExtractor{
+ modelConfigLoader: modelConfigLoader,
+ modelLoader: modelLoader,
+ applicationConfig: applicationConfig,
+ }
+}
+
+const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
+const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
+const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
+
+// TODO: Refactor to not return error if unchanged
+func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) {
+ model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
+ if ok && model != "" {
+ return
+ }
+ model = c.Param("model")
+
+ if model == "" {
+ model = c.QueryParam("model")
+ }
+
+ // Check FormValue for multipart/form-data requests (e.g., /v1/images/inpainting)
+ if model == "" {
+ model = c.FormValue("model")
+ }
+
+ if model == "" {
+ // Set model from bearer token, if available
+ auth := c.Request().Header.Get("Authorization")
+ bearer := strings.TrimPrefix(auth, "Bearer ")
+ if bearer != "" && bearer != auth {
+ exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
+ if err == nil && exists {
+ model = bearer
+ }
+ }
+ }
+
+ c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
+}
+
+func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc {
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ re.setModelNameFromRequest(c)
+ localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
+ if !ok || localModelName == "" {
+ c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
+ xlog.Debug("context local model name not found, setting to default", "defaultModelName", defaultModelName)
+ }
+ return next(c)
+ }
+ }
+}
+
+func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc {
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ re.setModelNameFromRequest(c)
+ localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
+ if localModelName != "" { // Don't overwrite existing values
+ return next(c)
+ }
+
+ modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
+ if err != nil {
+ xlog.Error("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()", "error", err)
+ return next(c)
+ }
+
+ if len(modelNames) == 0 {
+ xlog.Warn("SetDefaultModelNameToFirstAvailable used with no matching models installed")
+ // This is non-fatal - making it so was breaking the case of direct installation of raw models
+ // return errors.New("this endpoint requires at least one model to be installed")
+ return next(c)
+ }
+
+ c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
+ xlog.Debug("context local model name not found, setting to the first model", "first model name", modelNames[0])
+ return next(c)
+ }
+ }
+}
+
+// TODO: If context and cancel above belong on all methods, move that part of above into here!
+// Otherwise, it's in its own method below for now
+func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc {
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input := initializer()
+ if input == nil {
+ return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body")
+ }
+ if err := c.Bind(input); err != nil {
+ return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err))
+ }
+
+ // If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
+ if input.ModelName(nil) == "" {
+ localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
+ if ok && localModelName != "" {
+ xlog.Debug("overriding empty model name in request body with value found earlier in middleware chain", "context localModelName", localModelName)
+ input.ModelName(&localModelName)
+ }
+ }
+
+ cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
+
+ if err != nil {
+ xlog.Warn("Model Configuration File not found", "model", input.ModelName(nil), "error", err)
+ } else if cfg.Model == "" && input.ModelName(nil) != "" {
+ xlog.Debug("config does not include model, using input", "input.ModelName", input.ModelName(nil))
+ cfg.Model = input.ModelName(nil)
+ }
+
+ c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
+ c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
+
+ return next(c)
+ }
+ }
+}
+
+func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error {
+ input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+ if !ok || input.Model == "" {
+ return echo.ErrBadRequest
+ }
+
+ cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || cfg == nil {
+ return echo.ErrBadRequest
+ }
+
+ // Extract or generate the correlation ID
+ correlationID := c.Request().Header.Get("X-Correlation-ID")
+ if correlationID == "" {
+ correlationID = uuid.New().String()
+ }
+ c.Response().Header().Set("X-Correlation-ID", correlationID)
+
+ // Use the request context directly - Echo properly supports context cancellation!
+ // No need for workarounds like handleConnectionCancellation
+ reqCtx := c.Request().Context()
+ c1, cancel := context.WithCancel(re.applicationConfig.Context)
+
+ // Cancel when request context is cancelled (client disconnects)
+ go func() {
+ select {
+ case <-reqCtx.Done():
+ cancel()
+ case <-c1.Done():
+ // Already cancelled
+ }
+ }()
+
+ // Add the correlation ID to the new context
+ ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
+
+ input.Context = ctxWithCorrelationID
+ input.Cancel = cancel
+
+ err := mergeOpenAIRequestAndModelConfig(cfg, input)
+ if err != nil {
+ return err
+ }
+
+ if cfg.Model == "" {
+ xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model)
+ cfg.Model = input.Model
+ }
+
+ c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
+ c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
+
+ return nil
+}
+
+func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
+ if input.Echo {
+ config.Echo = input.Echo
+ }
+ if input.TopK != nil {
+ config.TopK = input.TopK
+ }
+ if input.TopP != nil {
+ config.TopP = input.TopP
+ }
+
+ if input.Backend != "" {
+ config.Backend = input.Backend
+ }
+
+ if input.ClipSkip != 0 {
+ config.Diffusers.ClipSkip = input.ClipSkip
+ }
+
+ if input.NegativePromptScale != 0 {
+ config.NegativePromptScale = input.NegativePromptScale
+ }
+
+ if input.NegativePrompt != "" {
+ config.NegativePrompt = input.NegativePrompt
+ }
+
+ if input.RopeFreqBase != 0 {
+ config.RopeFreqBase = input.RopeFreqBase
+ }
+
+ if input.RopeFreqScale != 0 {
+ config.RopeFreqScale = input.RopeFreqScale
+ }
+
+ if input.Grammar != "" {
+ config.Grammar = input.Grammar
+ }
+
+ if input.Temperature != nil {
+ config.Temperature = input.Temperature
+ }
+
+ if input.Maxtokens != nil {
+ config.Maxtokens = input.Maxtokens
+ }
+
+ if input.ResponseFormat != nil {
+ switch responseFormat := input.ResponseFormat.(type) {
+ case string:
+ config.ResponseFormat = responseFormat
+ case map[string]interface{}:
+ config.ResponseFormatMap = responseFormat
+ }
+ }
+
+ switch stop := input.Stop.(type) {
+ case string:
+ if stop != "" {
+ config.StopWords = append(config.StopWords, stop)
+ }
+ case []interface{}:
+ for _, pp := range stop {
+ if s, ok := pp.(string); ok {
+ config.StopWords = append(config.StopWords, s)
+ }
+ }
+ }
+
+ if len(input.Tools) > 0 {
+ for _, tool := range input.Tools {
+ input.Functions = append(input.Functions, tool.Function)
+ }
+ }
+
+ if input.ToolsChoice != nil {
+ var toolChoice functions.Tool
+
+ switch content := input.ToolsChoice.(type) {
+ case string:
+ _ = json.Unmarshal([]byte(content), &toolChoice)
+ case map[string]interface{}:
+ dat, _ := json.Marshal(content)
+ _ = json.Unmarshal(dat, &toolChoice)
+ }
+ input.FunctionCall = map[string]interface{}{
+ "name": toolChoice.Function.Name,
+ }
+ }
+
+ // Decode each request's message content
+ imgIndex, vidIndex, audioIndex := 0, 0, 0
+ for i, m := range input.Messages {
+ nrOfImgsInMessage := 0
+ nrOfVideosInMessage := 0
+ nrOfAudiosInMessage := 0
+
+ switch content := m.Content.(type) {
+ case string:
+ input.Messages[i].StringContent = content
+ case []interface{}:
+ dat, _ := json.Marshal(content)
+ c := []schema.Content{}
+ json.Unmarshal(dat, &c)
+
+ textContent := ""
+ // we will template this at the end
+
+ CONTENT:
+ for _, pp := range c {
+ switch pp.Type {
+ case "text":
+ textContent += pp.Text
+ //input.Messages[i].StringContent = pp.Text
+ case "video", "video_url":
+ // Decode content as base64 either if it's an URL or base64 text
+ base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
+ if err != nil {
+ xlog.Error("Failed encoding video", "error", err)
+ continue CONTENT
+ }
+ input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
+ vidIndex++
+ nrOfVideosInMessage++
+ case "audio_url", "audio":
+ // Decode content as base64 either if it's an URL or base64 text
+ base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
+ if err != nil {
+ xlog.Error("Failed encoding audio", "error", err)
+ continue CONTENT
+ }
+ input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
+ audioIndex++
+ nrOfAudiosInMessage++
+ case "input_audio":
+ // TODO: make sure that we only return base64 stuff
+ input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data)
+ audioIndex++
+ nrOfAudiosInMessage++
+ case "image_url", "image":
+ // Decode content as base64 either if it's an URL or base64 text
+ base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
+ if err != nil {
+ xlog.Error("Failed encoding image", "error", err)
+ continue CONTENT
+ }
+
+ input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
+
+ imgIndex++
+ nrOfImgsInMessage++
+ }
+ }
+
+ input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
+ TotalImages: imgIndex,
+ TotalVideos: vidIndex,
+ TotalAudios: audioIndex,
+ ImagesInMessage: nrOfImgsInMessage,
+ VideosInMessage: nrOfVideosInMessage,
+ AudiosInMessage: nrOfAudiosInMessage,
+ }, textContent)
+ }
+ }
+
+ if input.RepeatPenalty != 0 {
+ config.RepeatPenalty = input.RepeatPenalty
+ }
+
+ if input.FrequencyPenalty != 0 {
+ config.FrequencyPenalty = input.FrequencyPenalty
+ }
+
+ if input.PresencePenalty != 0 {
+ config.PresencePenalty = input.PresencePenalty
+ }
+
+ if input.Keep != 0 {
+ config.Keep = input.Keep
+ }
+
+ if input.Batch != 0 {
+ config.Batch = input.Batch
+ }
+
+ if input.IgnoreEOS {
+ config.IgnoreEOS = input.IgnoreEOS
+ }
+
+ if input.Seed != nil {
+ config.Seed = input.Seed
+ }
+
+ if input.TypicalP != nil {
+ config.TypicalP = input.TypicalP
+ }
+
+ xlog.Debug("input.Input", "input", fmt.Sprintf("%+v", input.Input))
+
+ switch inputs := input.Input.(type) {
+ case string:
+ if inputs != "" {
+ config.InputStrings = append(config.InputStrings, inputs)
+ }
+ case []any:
+ for _, pp := range inputs {
+ switch i := pp.(type) {
+ case string:
+ config.InputStrings = append(config.InputStrings, i)
+ case []any:
+ tokens := []int{}
+ inputStrings := []string{}
+ for _, ii := range i {
+ switch ii := ii.(type) {
+ case int:
+ tokens = append(tokens, ii)
+ case float64:
+ tokens = append(tokens, int(ii))
+ case string:
+ inputStrings = append(inputStrings, ii)
+ default:
+ xlog.Error("Unknown input type", "type", fmt.Sprintf("%T", ii))
+ }
+ }
+ config.InputToken = append(config.InputToken, tokens)
+ config.InputStrings = append(config.InputStrings, inputStrings...)
+ }
+ }
+ }
+
+ // Can be either a string or an object
+ switch fnc := input.FunctionCall.(type) {
+ case string:
+ if fnc != "" {
+ config.SetFunctionCallString(fnc)
+ }
+ case map[string]interface{}:
+ var name string
+ n, exists := fnc["name"]
+ if exists {
+ nn, e := n.(string)
+ if e {
+ name = nn
+ }
+ }
+ config.SetFunctionCallNameString(name)
+ }
+
+ switch p := input.Prompt.(type) {
+ case string:
+ config.PromptStrings = append(config.PromptStrings, p)
+ case []interface{}:
+ for _, pp := range p {
+ if s, ok := pp.(string); ok {
+ config.PromptStrings = append(config.PromptStrings, s)
+ }
+ }
+ }
+
+ // If a quality was defined as number, convert it to step
+ if input.Quality != "" {
+ q, err := strconv.Atoi(input.Quality)
+ if err == nil {
+ config.Step = q
+ }
+ }
+
+ if valid, _ := config.Validate(); valid {
+ return nil
+ }
+ return fmt.Errorf("unable to validate configuration after merging")
+}
diff --git a/core/http/middleware/strippathprefix.go b/core/http/middleware/strippathprefix.go
new file mode 100644
index 0000000000000000000000000000000000000000..451ccfe667ca6173565e8c19a1864fa82e6acabc
--- /dev/null
+++ b/core/http/middleware/strippathprefix.go
@@ -0,0 +1,57 @@
+package middleware
+
+import (
+ "strings"
+
+ "github.com/labstack/echo/v4"
+)
+
+// StripPathPrefix returns middleware that strips a path prefix from the request path.
+// The path prefix is obtained from the X-Forwarded-Prefix HTTP request header.
+// This must be registered as Pre middleware (using e.Pre()) to modify the path before routing.
+func StripPathPrefix() echo.MiddlewareFunc {
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ prefixes := c.Request().Header.Values("X-Forwarded-Prefix")
+ originalPath := c.Request().URL.Path
+
+ for _, prefix := range prefixes {
+ if prefix != "" {
+ normalizedPrefix := prefix
+ if !strings.HasSuffix(prefix, "/") {
+ normalizedPrefix = prefix + "/"
+ }
+
+ if strings.HasPrefix(originalPath, normalizedPrefix) {
+ // Update the request path by stripping the normalized prefix
+ newPath := originalPath[len(normalizedPrefix):]
+ if newPath == "" {
+ newPath = "/"
+ }
+ // Ensure path starts with / for proper routing
+ if !strings.HasPrefix(newPath, "/") {
+ newPath = "/" + newPath
+ }
+ // Update the URL path - Echo's router uses URL.Path for routing
+ c.Request().URL.Path = newPath
+ c.Request().URL.RawPath = ""
+ // Update RequestURI to match the new path (needed for proper routing)
+ if c.Request().URL.RawQuery != "" {
+ c.Request().RequestURI = newPath + "?" + c.Request().URL.RawQuery
+ } else {
+ c.Request().RequestURI = newPath
+ }
+ // Store original path for BaseURL utility
+ c.Set("_original_path", originalPath)
+ break
+ } else if originalPath == prefix || originalPath == prefix+"/" {
+ // Redirect to prefix with trailing slash (use 302 to match test expectations)
+ return c.Redirect(302, normalizedPrefix)
+ }
+ }
+ }
+
+ return next(c)
+ }
+ }
+}
diff --git a/core/http/middleware/strippathprefix_test.go b/core/http/middleware/strippathprefix_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..32c1c5d4af6adae2baff0ac57db7f2ff29048fc3
--- /dev/null
+++ b/core/http/middleware/strippathprefix_test.go
@@ -0,0 +1,134 @@
+package middleware
+
+import (
+ "net/http/httptest"
+
+ "github.com/labstack/echo/v4"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("StripPathPrefix", func() {
+ var app *echo.Echo
+ var actualPath string
+ var appInitialized bool
+
+ BeforeEach(func() {
+ actualPath = ""
+ if !appInitialized {
+ app = echo.New()
+ app.Pre(StripPathPrefix())
+
+ app.GET("/hello/world", func(c echo.Context) error {
+ actualPath = c.Request().URL.Path
+ return nil
+ })
+
+ app.GET("/", func(c echo.Context) error {
+ actualPath = c.Request().URL.Path
+ return nil
+ })
+ appInitialized = true
+ }
+ })
+
+ Context("without prefix", func() {
+ It("should not modify path when no header is present", func() {
+ req := httptest.NewRequest("GET", "/hello/world", nil)
+ rec := httptest.NewRecorder()
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(200), "response status code")
+ Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
+ })
+
+ It("should not modify root path when no header is present", func() {
+ req := httptest.NewRequest("GET", "/", nil)
+ rec := httptest.NewRecorder()
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(200), "response status code")
+ Expect(actualPath).To(Equal("/"), "rewritten path")
+ })
+
+ It("should not modify path when header does not match", func() {
+ req := httptest.NewRequest("GET", "/hello/world", nil)
+ req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
+ rec := httptest.NewRecorder()
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(200), "response status code")
+ Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
+ })
+ })
+
+ Context("with prefix", func() {
+ It("should return 404 when prefix does not match header", func() {
+ req := httptest.NewRequest("GET", "/prefix/hello/world", nil)
+ req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
+ rec := httptest.NewRecorder()
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(404), "response status code")
+ })
+
+ It("should strip matching prefix from path", func() {
+ req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
+ req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/"}
+ rec := httptest.NewRecorder()
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(200), "response status code")
+ Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
+ })
+
+ It("should strip prefix when it matches the first header value", func() {
+ req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
+ req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/", "/otherprefix/"}
+ rec := httptest.NewRecorder()
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(200), "response status code")
+ Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
+ })
+
+ It("should strip prefix when it matches the second header value", func() {
+ req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
+ req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/", "/myprefix/"}
+ rec := httptest.NewRecorder()
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(200), "response status code")
+ Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
+ })
+
+ It("should strip prefix when header does not end with slash", func() {
+ req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
+ req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
+ rec := httptest.NewRecorder()
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(200), "response status code")
+ Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
+ })
+
+ It("should return 404 when prefix does not match header without trailing slash", func() {
+ req := httptest.NewRequest("GET", "/myprefix-suffix/hello/world", nil)
+ req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
+ rec := httptest.NewRecorder()
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(404), "response status code")
+ })
+
+ It("should redirect when prefix does not end with a slash", func() {
+ req := httptest.NewRequest("GET", "/myprefix", nil)
+ req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
+ rec := httptest.NewRecorder()
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(302), "response status code")
+ Expect(rec.Header().Get("Location")).To(Equal("/myprefix/"), "redirect location")
+ })
+ })
+})
diff --git a/core/http/middleware/trace.go b/core/http/middleware/trace.go
new file mode 100644
index 0000000000000000000000000000000000000000..aa63ba349f37265c398a948d7e84dd9267ac750f
--- /dev/null
+++ b/core/http/middleware/trace.go
@@ -0,0 +1,156 @@
+package middleware
+
+import (
+ "bytes"
+ "github.com/emirpasic/gods/v2/queues/circularbuffer"
+ "io"
+ "net/http"
+ "sort"
+ "sync"
+ "time"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/application"
+ "github.com/mudler/xlog"
+)
+
+type APIExchangeRequest struct {
+ Method string `json:"method"`
+ Path string `json:"path"`
+ Headers *http.Header `json:"headers"`
+ Body *[]byte `json:"body"`
+}
+
+type APIExchangeResponse struct {
+ Status int `json:"status"`
+ Headers *http.Header `json:"headers"`
+ Body *[]byte `json:"body"`
+}
+
+type APIExchange struct {
+ Timestamp time.Time `json:"timestamp"`
+ Request APIExchangeRequest `json:"request"`
+ Response APIExchangeResponse `json:"response"`
+}
+
+var traceBuffer *circularbuffer.Queue[APIExchange]
+var mu sync.Mutex
+var logChan = make(chan APIExchange, 100)
+
+type bodyWriter struct {
+ http.ResponseWriter
+ body *bytes.Buffer
+}
+
+func (w *bodyWriter) Write(b []byte) (int, error) {
+ w.body.Write(b)
+ return w.ResponseWriter.Write(b)
+}
+
+func (w *bodyWriter) Flush() {
+ if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
+ flusher.Flush()
+ }
+}
+
+// TraceMiddleware intercepts and logs JSON API requests and responses
+func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
+ if app.ApplicationConfig().EnableTracing && traceBuffer == nil {
+ traceBuffer = circularbuffer.New[APIExchange](app.ApplicationConfig().TracingMaxItems)
+
+ go func() {
+ for exchange := range logChan {
+ mu.Lock()
+ traceBuffer.Enqueue(exchange)
+ mu.Unlock()
+ }
+ }()
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ if !app.ApplicationConfig().EnableTracing {
+ return next(c)
+ }
+
+ if c.Request().Header.Get("Content-Type") != "application/json" {
+ return next(c)
+ }
+
+ body, err := io.ReadAll(c.Request().Body)
+ if err != nil {
+ xlog.Error("Failed to read request body")
+ return err
+ }
+
+ // Restore the body for downstream handlers
+ c.Request().Body = io.NopCloser(bytes.NewBuffer(body))
+
+ startTime := time.Now()
+
+ // Wrap response writer to capture body
+ resBody := new(bytes.Buffer)
+ mw := &bodyWriter{
+ ResponseWriter: c.Response().Writer,
+ body: resBody,
+ }
+ c.Response().Writer = mw
+
+ err = next(c)
+ if err != nil {
+ c.Response().Writer = mw.ResponseWriter // Restore original writer if error
+ return err
+ }
+
+ // Create exchange log
+ requestHeaders := c.Request().Header.Clone()
+ requestBody := make([]byte, len(body))
+ copy(requestBody, body)
+ responseHeaders := c.Response().Header().Clone()
+ responseBody := make([]byte, resBody.Len())
+ copy(responseBody, resBody.Bytes())
+ exchange := APIExchange{
+ Timestamp: startTime,
+ Request: APIExchangeRequest{
+ Method: c.Request().Method,
+ Path: c.Path(),
+ Headers: &requestHeaders,
+ Body: &requestBody,
+ },
+ Response: APIExchangeResponse{
+ Status: c.Response().Status,
+ Headers: &responseHeaders,
+ Body: &responseBody,
+ },
+ }
+
+ select {
+ case logChan <- exchange:
+ default:
+ xlog.Warn("Trace channel full, dropping trace")
+ }
+
+ return nil
+ }
+ }
+}
+
+// GetTraces returns a copy of the logged API exchanges for display
+func GetTraces() []APIExchange {
+ mu.Lock()
+ traces := traceBuffer.Values()
+ mu.Unlock()
+
+ sort.Slice(traces, func(i, j int) bool {
+ return traces[i].Timestamp.Before(traces[j].Timestamp)
+ })
+
+ return traces
+}
+
+// ClearTraces clears the in-memory logs
+func ClearTraces() {
+ mu.Lock()
+ traceBuffer.Clear()
+ mu.Unlock()
+}
diff --git a/core/http/openai_mapping_test.go b/core/http/openai_mapping_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..a1dc44b393a3400ac2ef296e3239fd7e7952abd4
--- /dev/null
+++ b/core/http/openai_mapping_test.go
@@ -0,0 +1,75 @@
+package http_test
+
+import (
+ "encoding/json"
+
+ openai "github.com/mudler/LocalAI/core/http/endpoints/openai"
+ "github.com/mudler/LocalAI/core/schema"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("MapOpenAIToVideo", func() {
+ It("maps size and seconds correctly", func() {
+ cases := []struct {
+ name string
+ input *schema.OpenAIRequest
+ raw map[string]interface{}
+ expectsW int32
+ expectsH int32
+ expectsF int32
+ expectsN int32
+ }{
+ {
+ name: "size in input",
+ input: &schema.OpenAIRequest{
+ PredictionOptions: schema.PredictionOptions{
+ BasicModelRequest: schema.BasicModelRequest{Model: "m"},
+ },
+ Size: "256x128",
+ },
+ expectsW: 256,
+ expectsH: 128,
+ },
+ {
+ name: "size in raw and seconds as string",
+ input: &schema.OpenAIRequest{PredictionOptions: schema.PredictionOptions{BasicModelRequest: schema.BasicModelRequest{Model: "m"}}},
+ raw: map[string]interface{}{"size": "720x480", "seconds": "2"},
+ expectsW: 720,
+ expectsH: 480,
+ expectsF: 30,
+ expectsN: 60,
+ },
+ {
+ name: "seconds as number and fps override",
+ input: &schema.OpenAIRequest{PredictionOptions: schema.PredictionOptions{BasicModelRequest: schema.BasicModelRequest{Model: "m"}}},
+ raw: map[string]interface{}{"seconds": 3.0, "fps": 24.0},
+ expectsF: 24,
+ expectsN: 72,
+ },
+ }
+
+ for _, c := range cases {
+ By(c.name)
+ vr := openai.MapOpenAIToVideo(c.input, c.raw)
+ if c.expectsW != 0 {
+ Expect(vr.Width).To(Equal(c.expectsW))
+ }
+ if c.expectsH != 0 {
+ Expect(vr.Height).To(Equal(c.expectsH))
+ }
+ if c.expectsF != 0 {
+ Expect(vr.FPS).To(Equal(c.expectsF))
+ }
+ if c.expectsN != 0 {
+ Expect(vr.NumFrames).To(Equal(c.expectsN))
+ }
+
+ b, err := json.Marshal(vr)
+ Expect(err).ToNot(HaveOccurred())
+ _ = b
+ }
+ })
+})
+
diff --git a/core/http/openai_videos_test.go b/core/http/openai_videos_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..60faada8f5ca64bcd820264b823ae1a3f1d9dea6
--- /dev/null
+++ b/core/http/openai_videos_test.go
@@ -0,0 +1,168 @@
+package http_test
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "time"
+
+ "github.com/mudler/LocalAI/core/application"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/pkg/system"
+ "github.com/mudler/LocalAI/pkg/grpc"
+ pb "github.com/mudler/LocalAI/pkg/grpc/proto"
+ "fmt"
+ . "github.com/mudler/LocalAI/core/http"
+ "github.com/labstack/echo/v4"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+const testAPIKey = "joshua"
+
+type fakeAI struct{}
+
+func (f *fakeAI) Busy() bool { return false }
+func (f *fakeAI) Lock() {}
+func (f *fakeAI) Unlock() {}
+func (f *fakeAI) Locking() bool { return false }
+func (f *fakeAI) Predict(*pb.PredictOptions) (string, error) { return "", nil }
+func (f *fakeAI) PredictStream(*pb.PredictOptions, chan string) error {
+ return nil
+}
+func (f *fakeAI) Load(*pb.ModelOptions) error { return nil }
+func (f *fakeAI) Embeddings(*pb.PredictOptions) ([]float32, error) { return nil, nil }
+func (f *fakeAI) GenerateImage(*pb.GenerateImageRequest) error { return nil }
+func (f *fakeAI) GenerateVideo(*pb.GenerateVideoRequest) error { return nil }
+func (f *fakeAI) Detect(*pb.DetectOptions) (pb.DetectResponse, error) { return pb.DetectResponse{}, nil }
+func (f *fakeAI) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) {
+ return pb.TranscriptResult{}, nil
+}
+func (f *fakeAI) TTS(*pb.TTSRequest) error { return nil }
+func (f *fakeAI) SoundGeneration(*pb.SoundGenerationRequest) error { return nil }
+func (f *fakeAI) TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) {
+ return pb.TokenizationResponse{}, nil
+}
+func (f *fakeAI) Status() (pb.StatusResponse, error) { return pb.StatusResponse{}, nil }
+func (f *fakeAI) StoresSet(*pb.StoresSetOptions) error { return nil }
+func (f *fakeAI) StoresDelete(*pb.StoresDeleteOptions) error { return nil }
+func (f *fakeAI) StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error) {
+ return pb.StoresGetResult{}, nil
+}
+func (f *fakeAI) StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error) {
+ return pb.StoresFindResult{}, nil
+}
+func (f *fakeAI) VAD(*pb.VADRequest) (pb.VADResponse, error) { return pb.VADResponse{}, nil }
+
+var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
+ var tmpdir string
+ var appServer *application.Application
+ var app *echo.Echo
+ var ctx context.Context
+ var cancel context.CancelFunc
+
+ BeforeEach(func() {
+ var err error
+ tmpdir, err = os.MkdirTemp("", "")
+ Expect(err).ToNot(HaveOccurred())
+
+ modelDir := filepath.Join(tmpdir, "models")
+ err = os.Mkdir(modelDir, 0750)
+ Expect(err).ToNot(HaveOccurred())
+
+ ctx, cancel = context.WithCancel(context.Background())
+
+ systemState, err := system.GetSystemState(
+ system.WithModelPath(modelDir),
+ )
+ Expect(err).ToNot(HaveOccurred())
+
+ grpc.Provide("embedded://fake", &fakeAI{})
+
+ appServer, err = application.New(
+ config.WithContext(ctx),
+ config.WithSystemState(systemState),
+ config.WithApiKeys([]string{testAPIKey}),
+ config.WithGeneratedContentDir(tmpdir),
+ config.WithExternalBackend("fake", "embedded://fake"),
+ )
+ Expect(err).ToNot(HaveOccurred())
+ })
+
+ AfterEach(func() {
+ cancel()
+ if app != nil {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ _ = app.Shutdown(ctx)
+ }
+ _ = os.RemoveAll(tmpdir)
+ })
+
+ It("accepts OpenAI-style video create and delegates to backend", func() {
+ var err error
+ app, err = API(appServer)
+ Expect(err).ToNot(HaveOccurred())
+ go func() {
+ if err := app.Start("127.0.0.1:9091"); err != nil && err != http.ErrServerClosed {
+ // Log error if needed
+ }
+ }()
+
+ // wait for server
+ client := &http.Client{Timeout: 5 * time.Second}
+ Eventually(func() error {
+ req, _ := http.NewRequest("GET", "http://127.0.0.1:9091/v1/models", nil)
+ req.Header.Set("Authorization", "Bearer "+testAPIKey)
+ resp, err := client.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode >= 400 {
+ return fmt.Errorf("bad status: %d", resp.StatusCode)
+ }
+ return nil
+ }, "30s", "500ms").Should(Succeed())
+
+ body := map[string]interface{}{
+ "model": "fake-model",
+ "backend": "fake",
+ "prompt": "a test video",
+ "size": "256x256",
+ "seconds": "1",
+ }
+ payload, err := json.Marshal(body)
+ Expect(err).ToNot(HaveOccurred())
+
+ req, err := http.NewRequest("POST", "http://127.0.0.1:9091/v1/videos", bytes.NewBuffer(payload))
+ Expect(err).ToNot(HaveOccurred())
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+testAPIKey)
+
+ resp, err := client.Do(req)
+ Expect(err).ToNot(HaveOccurred())
+ defer resp.Body.Close()
+ Expect(resp.StatusCode).To(Equal(200))
+
+ dat, err := io.ReadAll(resp.Body)
+ Expect(err).ToNot(HaveOccurred())
+
+ var out map[string]interface{}
+ err = json.Unmarshal(dat, &out)
+ Expect(err).ToNot(HaveOccurred())
+ data, ok := out["data"].([]interface{})
+ Expect(ok).To(BeTrue())
+ Expect(len(data)).To(BeNumerically(">", 0))
+ first := data[0].(map[string]interface{})
+ url, ok := first["url"].(string)
+ Expect(ok).To(BeTrue())
+ Expect(url).To(ContainSubstring("/generated-videos/"))
+ Expect(url).To(ContainSubstring(".mp4"))
+ })
+})
diff --git a/core/http/render.go b/core/http/render.go
new file mode 100644
index 0000000000000000000000000000000000000000..569c779877205a8dc73186fa2a43bf732a2ebb85
--- /dev/null
+++ b/core/http/render.go
@@ -0,0 +1,89 @@
+package http
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "io"
+ "io/fs"
+ "net/http"
+ "strings"
+
+ "github.com/Masterminds/sprig/v3"
+ "github.com/labstack/echo/v4"
+ "github.com/microcosm-cc/bluemonday"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/russross/blackfriday"
+)
+
+//go:embed views/*
+var viewsfs embed.FS
+
+// TemplateRenderer is a custom template renderer for Echo
+type TemplateRenderer struct {
+ templates *template.Template
+}
+
+// Render renders a template document
+func (t *TemplateRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
+ return t.templates.ExecuteTemplate(w, name, data)
+}
+
+func notFoundHandler(c echo.Context) error {
+ // Check if the request accepts JSON
+ contentType := c.Request().Header.Get("Content-Type")
+ accept := c.Request().Header.Get("Accept")
+ if strings.Contains(contentType, "application/json") || !strings.Contains(accept, "text/html") {
+ // The client expects a JSON response
+ return c.JSON(http.StatusNotFound, schema.ErrorResponse{
+ Error: &schema.APIError{Message: "Resource not found", Code: http.StatusNotFound},
+ })
+ } else {
+ // The client expects an HTML response
+ return c.Render(http.StatusNotFound, "views/404", map[string]interface{}{
+ "BaseURL": middleware.BaseURL(c),
+ })
+ }
+}
+
+func renderEngine() *TemplateRenderer {
+ // Parse all templates from embedded filesystem
+ tmpl := template.New("").Funcs(sprig.FuncMap())
+ tmpl = tmpl.Funcs(template.FuncMap{
+ "MDToHTML": markDowner,
+ })
+
+ // Recursively walk through embedded filesystem and parse all HTML templates
+ err := fs.WalkDir(viewsfs, "views", func(path string, d fs.DirEntry, err error) error {
+ if err != nil {
+ return err
+ }
+ if !d.IsDir() && strings.HasSuffix(path, ".html") {
+ data, err := viewsfs.ReadFile(path)
+ if err == nil {
+ // Remove .html extension to get template name (e.g., "views/index.html" -> "views/index")
+ templateName := strings.TrimSuffix(path, ".html")
+ _, err := tmpl.New(templateName).Parse(string(data))
+ if err != nil {
+ // If parsing fails, try parsing without explicit name (for templates with {{define}})
+ tmpl.Parse(string(data))
+ }
+ }
+ }
+ return nil
+ })
+ if err != nil {
+ // Log error but continue - templates might still work
+ fmt.Printf("Error walking views directory: %v\n", err)
+ }
+
+ return &TemplateRenderer{
+ templates: tmpl,
+ }
+}
+
+func markDowner(args ...interface{}) template.HTML {
+ s := blackfriday.MarkdownCommon([]byte(fmt.Sprintf("%s", args...)))
+ return template.HTML(bluemonday.UGCPolicy().Sanitize(string(s)))
+}
diff --git a/core/http/routes/anthropic.go b/core/http/routes/anthropic.go
new file mode 100644
index 0000000000000000000000000000000000000000..9f7050c27edf70a67db4a2eeb07038efc2d67d41
--- /dev/null
+++ b/core/http/routes/anthropic.go
@@ -0,0 +1,108 @@
+package routes
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+
+ "github.com/google/uuid"
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/application"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/endpoints/anthropic"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/xlog"
+)
+
+func RegisterAnthropicRoutes(app *echo.Echo,
+ re *middleware.RequestExtractor,
+ application *application.Application) {
+
+ // Anthropic Messages API endpoint
+ messagesHandler := anthropic.MessagesEndpoint(
+ application.ModelConfigLoader(),
+ application.ModelLoader(),
+ application.TemplatesEvaluator(),
+ application.ApplicationConfig(),
+ )
+
+ messagesMiddleware := []echo.MiddlewareFunc{
+ middleware.TraceMiddleware(application),
+ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
+ re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.AnthropicRequest) }),
+ setAnthropicRequestContext(application.ApplicationConfig()),
+ }
+
+ // Main Anthropic endpoint
+ app.POST("/v1/messages", messagesHandler, messagesMiddleware...)
+
+ // Also support without version prefix for compatibility
+ app.POST("/messages", messagesHandler, messagesMiddleware...)
+}
+
+// setAnthropicRequestContext sets up the context and cancel function for Anthropic requests
+func setAnthropicRequestContext(appConfig *config.ApplicationConfig) echo.MiddlewareFunc {
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.AnthropicRequest)
+ if !ok || input.Model == "" {
+ return echo.NewHTTPError(http.StatusBadRequest, "model is required")
+ }
+
+ cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+ if !ok || cfg == nil {
+ return echo.NewHTTPError(http.StatusBadRequest, "model configuration not found")
+ }
+
+ // Extract or generate the correlation ID
+ // Anthropic uses x-request-id header
+ correlationID := c.Request().Header.Get("x-request-id")
+ if correlationID == "" {
+ correlationID = uuid.New().String()
+ }
+ c.Response().Header().Set("x-request-id", correlationID)
+
+ // Set up context with cancellation
+ reqCtx := c.Request().Context()
+ c1, cancel := context.WithCancel(appConfig.Context)
+
+ // Cancel when request context is cancelled (client disconnects)
+ go func() {
+ select {
+ case <-reqCtx.Done():
+ cancel()
+ case <-c1.Done():
+ // Already cancelled
+ }
+ }()
+
+ // Add the correlation ID to the new context
+ ctxWithCorrelationID := context.WithValue(c1, middleware.CorrelationIDKey, correlationID)
+
+ input.Context = ctxWithCorrelationID
+ input.Cancel = cancel
+
+ if cfg.Model == "" {
+ xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model)
+ cfg.Model = input.Model
+ }
+
+ c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
+ c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
+
+ // Log the Anthropic API version if provided
+ anthropicVersion := c.Request().Header.Get("anthropic-version")
+ if anthropicVersion != "" {
+ xlog.Debug("Anthropic API version", "version", anthropicVersion)
+ }
+
+ // Validate max_tokens is provided
+ if input.MaxTokens <= 0 {
+ return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("max_tokens is required and must be greater than 0"))
+ }
+
+ return next(c)
+ }
+ }
+}
diff --git a/core/http/routes/elevenlabs.go b/core/http/routes/elevenlabs.go
new file mode 100644
index 0000000000000000000000000000000000000000..90f73eec6417de00df4b65b4a6ee20366df12530
--- /dev/null
+++ b/core/http/routes/elevenlabs.go
@@ -0,0 +1,31 @@
+package routes
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/endpoints/elevenlabs"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/model"
+)
+
+func RegisterElevenLabsRoutes(app *echo.Echo,
+ re *middleware.RequestExtractor,
+ cl *config.ModelConfigLoader,
+ ml *model.ModelLoader,
+ appConfig *config.ApplicationConfig) {
+
+ // Elevenlabs
+ ttsHandler := elevenlabs.TTSEndpoint(cl, ml, appConfig)
+ app.POST("/v1/text-to-speech/:voice-id",
+ ttsHandler,
+ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
+ re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }))
+
+ soundGenHandler := elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig)
+ app.POST("/v1/sound-generation",
+ soundGenHandler,
+ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_SOUND_GENERATION)),
+ re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }))
+
+}
diff --git a/core/http/routes/explorer.go b/core/http/routes/explorer.go
new file mode 100644
index 0000000000000000000000000000000000000000..670bf67c42fddcdc0efc7fd8be102c044be1479b
--- /dev/null
+++ b/core/http/routes/explorer.go
@@ -0,0 +1,13 @@
+package routes
+
+import (
+ "github.com/labstack/echo/v4"
+ coreExplorer "github.com/mudler/LocalAI/core/explorer"
+ "github.com/mudler/LocalAI/core/http/endpoints/explorer"
+)
+
+func RegisterExplorerRoutes(app *echo.Echo, db *coreExplorer.Database) {
+ app.GET("/", explorer.Dashboard())
+ app.POST("/network/add", explorer.AddNetwork(db))
+ app.GET("/networks", explorer.ShowNetworks(db))
+}
diff --git a/core/http/routes/health.go b/core/http/routes/health.go
new file mode 100644
index 0000000000000000000000000000000000000000..5b03953733d882dc6101e244647b61ba09100ab9
--- /dev/null
+++ b/core/http/routes/health.go
@@ -0,0 +1,15 @@
+package routes
+
+import (
+ "github.com/labstack/echo/v4"
+)
+
+func HealthRoutes(app *echo.Echo) {
+ // Service health checks
+ ok := func(c echo.Context) error {
+ return c.NoContent(200)
+ }
+
+ app.GET("/healthz", ok)
+ app.GET("/readyz", ok)
+}
diff --git a/core/http/routes/jina.go b/core/http/routes/jina.go
new file mode 100644
index 0000000000000000000000000000000000000000..b4fafbc57f5001dc63755ef3a671c7471c2ec935
--- /dev/null
+++ b/core/http/routes/jina.go
@@ -0,0 +1,25 @@
+package routes
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/endpoints/jina"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+
+ "github.com/mudler/LocalAI/pkg/model"
+)
+
+func RegisterJINARoutes(app *echo.Echo,
+ re *middleware.RequestExtractor,
+ cl *config.ModelConfigLoader,
+ ml *model.ModelLoader,
+ appConfig *config.ApplicationConfig) {
+
+ // POST endpoint to mimic the reranking
+ rerankHandler := jina.JINARerankEndpoint(cl, ml, appConfig)
+ app.POST("/v1/rerank",
+ rerankHandler,
+ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_RERANK)),
+ re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }))
+}
diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go
new file mode 100644
index 0000000000000000000000000000000000000000..f70a44b2109c7f8c7be6670191f4cd3379ecd5b3
--- /dev/null
+++ b/core/http/routes/localai.go
@@ -0,0 +1,178 @@
+package routes
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/application"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/endpoints/localai"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/core/services"
+ "github.com/mudler/LocalAI/core/templates"
+ "github.com/mudler/LocalAI/internal"
+ "github.com/mudler/LocalAI/pkg/model"
+ echoswagger "github.com/swaggo/echo-swagger"
+)
+
+func RegisterLocalAIRoutes(router *echo.Echo,
+ requestExtractor *middleware.RequestExtractor,
+ cl *config.ModelConfigLoader,
+ ml *model.ModelLoader,
+ appConfig *config.ApplicationConfig,
+ galleryService *services.GalleryService,
+ opcache *services.OpCache,
+ evaluator *templates.Evaluator,
+ app *application.Application) {
+
+ router.GET("/swagger/*", echoswagger.WrapHandler) // default
+
+ // LocalAI API endpoints
+ if !appConfig.DisableGalleryEndpoint {
+ // Import model page
+ router.GET("/import-model", func(c echo.Context) error {
+ return c.Render(200, "views/model-editor", map[string]interface{}{
+ "Title": "LocalAI - Import Model",
+ "BaseURL": middleware.BaseURL(c),
+ "Version": internal.PrintableVersion(),
+ })
+ })
+
+ // Edit model page
+ router.GET("/models/edit/:name", localai.GetEditModelPage(cl, appConfig))
+ modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.SystemState, galleryService)
+ router.POST("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
+ router.POST("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
+
+ router.GET("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint(appConfig.SystemState))
+ router.GET("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
+ router.GET("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
+ router.GET("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
+
+ backendGalleryEndpointService := localai.CreateBackendEndpointService(
+ appConfig.BackendGalleries,
+ appConfig.SystemState,
+ galleryService)
+ router.POST("/backends/apply", backendGalleryEndpointService.ApplyBackendEndpoint())
+ router.POST("/backends/delete/:name", backendGalleryEndpointService.DeleteBackendEndpoint())
+ router.GET("/backends", backendGalleryEndpointService.ListBackendsEndpoint(appConfig.SystemState))
+ router.GET("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState))
+ router.GET("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint())
+ router.GET("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
+ // Custom model import endpoint
+ router.POST("/models/import", localai.ImportModelEndpoint(cl, appConfig))
+
+ // URI model import endpoint
+ router.POST("/models/import-uri", localai.ImportModelURIEndpoint(cl, appConfig, galleryService, opcache))
+
+ // Custom model edit endpoint
+ router.POST("/models/edit/:name", localai.EditModelEndpoint(cl, appConfig))
+
+ // Reload models endpoint
+ router.POST("/models/reload", localai.ReloadModelsEndpoint(cl, appConfig))
+ }
+
+ detectionHandler := localai.DetectionEndpoint(cl, ml, appConfig)
+ router.POST("/v1/detection",
+ detectionHandler,
+ requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_DETECTION)),
+ requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) }))
+
+ ttsHandler := localai.TTSEndpoint(cl, ml, appConfig)
+ router.POST("/tts",
+ ttsHandler,
+ requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
+ requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }))
+
+ vadHandler := localai.VADEndpoint(cl, ml, appConfig)
+ router.POST("/vad",
+ vadHandler,
+ requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)),
+ requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }))
+ router.POST("/v1/vad",
+ vadHandler,
+ requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)),
+ requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }))
+
+ // Stores
+ router.POST("/stores/set", localai.StoresSetEndpoint(ml, appConfig))
+ router.POST("/stores/delete", localai.StoresDeleteEndpoint(ml, appConfig))
+ router.POST("/stores/get", localai.StoresGetEndpoint(ml, appConfig))
+ router.POST("/stores/find", localai.StoresFindEndpoint(ml, appConfig))
+
+ if !appConfig.DisableMetrics {
+ router.GET("/metrics", localai.LocalAIMetricsEndpoint())
+ }
+
+ videoHandler := localai.VideoEndpoint(cl, ml, appConfig)
+ router.POST("/video",
+ videoHandler,
+ requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)),
+ requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VideoRequest) }))
+
+ // Backend Statistics Module
+ // TODO: Should these use standard middlewares? Refactor later, they are extremely simple.
+ backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
+ router.GET("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
+ router.POST("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
+ // The v1/* urls are exactly the same as above - makes local e2e testing easier if they are registered.
+ router.GET("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
+ router.POST("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
+
+ // p2p
+ router.GET("/api/p2p", localai.ShowP2PNodes(appConfig))
+ router.GET("/api/p2p/token", localai.ShowP2PToken(appConfig))
+
+ router.GET("/version", func(c echo.Context) error {
+ return c.JSON(200, struct {
+ Version string `json:"version"`
+ }{Version: internal.PrintableVersion()})
+ })
+
+ router.GET("/system", localai.SystemInformations(ml, appConfig))
+
+ // misc
+ tokenizeHandler := localai.TokenizeEndpoint(cl, ml, appConfig)
+ router.POST("/v1/tokenize",
+ tokenizeHandler,
+ requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TOKENIZE)),
+ requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }))
+
+ // MCP endpoint - supports both streaming and non-streaming modes
+ // Note: streaming mode is NOT compatible with the OpenAI apis. We have a set which streams more states.
+ if evaluator != nil {
+ mcpStreamHandler := localai.MCPEndpoint(cl, ml, evaluator, appConfig)
+ mcpStreamMiddleware := []echo.MiddlewareFunc{
+ requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
+ requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+ func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ if err := requestExtractor.SetOpenAIRequest(c); err != nil {
+ return err
+ }
+ return next(c)
+ }
+ },
+ }
+ router.POST("/v1/mcp/chat/completions", mcpStreamHandler, mcpStreamMiddleware...)
+ router.POST("/mcp/v1/chat/completions", mcpStreamHandler, mcpStreamMiddleware...)
+ router.POST("/mcp/chat/completions", mcpStreamHandler, mcpStreamMiddleware...)
+ }
+
+ // Agent job routes
+ if app != nil && app.AgentJobService() != nil {
+ router.POST("/api/agent/tasks", localai.CreateTaskEndpoint(app))
+ router.PUT("/api/agent/tasks/:id", localai.UpdateTaskEndpoint(app))
+ router.DELETE("/api/agent/tasks/:id", localai.DeleteTaskEndpoint(app))
+ router.GET("/api/agent/tasks", localai.ListTasksEndpoint(app))
+ router.GET("/api/agent/tasks/:id", localai.GetTaskEndpoint(app))
+
+ router.POST("/api/agent/jobs/execute", localai.ExecuteJobEndpoint(app))
+ router.GET("/api/agent/jobs/:id", localai.GetJobEndpoint(app))
+ router.GET("/api/agent/jobs", localai.ListJobsEndpoint(app))
+ router.POST("/api/agent/jobs/:id/cancel", localai.CancelJobEndpoint(app))
+ router.DELETE("/api/agent/jobs/:id", localai.DeleteJobEndpoint(app))
+
+ router.POST("/api/agent/tasks/:name/execute", localai.ExecuteTaskByNameEndpoint(app))
+ }
+
+}
diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go
new file mode 100644
index 0000000000000000000000000000000000000000..2d62859f317fcd25920dfa3761d6c5d79833a5e6
--- /dev/null
+++ b/core/http/routes/openai.go
@@ -0,0 +1,179 @@
+package routes
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/application"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/endpoints/localai"
+ "github.com/mudler/LocalAI/core/http/endpoints/openai"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/schema"
+)
+
+func RegisterOpenAIRoutes(app *echo.Echo,
+ re *middleware.RequestExtractor,
+ application *application.Application) {
+ // openAI compatible API endpoint
+ traceMiddleware := middleware.TraceMiddleware(application)
+
+ // realtime
+ // TODO: Modify/disable the API key middleware for this endpoint to allow ephemeral keys created by sessions
+ app.GET("/v1/realtime", openai.Realtime(application))
+ app.POST("/v1/realtime/sessions", openai.RealtimeTranscriptionSession(application), traceMiddleware)
+ app.POST("/v1/realtime/transcription_session", openai.RealtimeTranscriptionSession(application), traceMiddleware)
+
+ // chat
+ chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
+ chatMiddleware := []echo.MiddlewareFunc{
+ traceMiddleware,
+ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
+ re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+ func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ if err := re.SetOpenAIRequest(c); err != nil {
+ return err
+ }
+ return next(c)
+ }
+ },
+ }
+ app.POST("/v1/chat/completions", chatHandler, chatMiddleware...)
+ app.POST("/chat/completions", chatHandler, chatMiddleware...)
+
+ // edit
+ editHandler := openai.EditEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
+ editMiddleware := []echo.MiddlewareFunc{
+ traceMiddleware,
+ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EDIT)),
+ re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
+ re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+ func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ if err := re.SetOpenAIRequest(c); err != nil {
+ return err
+ }
+ return next(c)
+ }
+ },
+ }
+ app.POST("/v1/edits", editHandler, editMiddleware...)
+ app.POST("/edits", editHandler, editMiddleware...)
+
+ // completion
+ completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
+ completionMiddleware := []echo.MiddlewareFunc{
+ traceMiddleware,
+ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
+ re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
+ re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+ func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ if err := re.SetOpenAIRequest(c); err != nil {
+ return err
+ }
+ return next(c)
+ }
+ },
+ }
+ app.POST("/v1/completions", completionHandler, completionMiddleware...)
+ app.POST("/completions", completionHandler, completionMiddleware...)
+ app.POST("/v1/engines/:model/completions", completionHandler, completionMiddleware...)
+
+ // embeddings
+ embeddingHandler := openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
+ embeddingMiddleware := []echo.MiddlewareFunc{
+ traceMiddleware,
+ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
+ re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
+ re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+ func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ if err := re.SetOpenAIRequest(c); err != nil {
+ return err
+ }
+ return next(c)
+ }
+ },
+ }
+ app.POST("/v1/embeddings", embeddingHandler, embeddingMiddleware...)
+ app.POST("/embeddings", embeddingHandler, embeddingMiddleware...)
+ app.POST("/v1/engines/:model/embeddings", embeddingHandler, embeddingMiddleware...)
+
+ audioHandler := openai.TranscriptEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
+ audioMiddleware := []echo.MiddlewareFunc{
+ traceMiddleware,
+ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)),
+ re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+ func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ if err := re.SetOpenAIRequest(c); err != nil {
+ return err
+ }
+ return next(c)
+ }
+ },
+ }
+ // audio
+ app.POST("/v1/audio/transcriptions", audioHandler, audioMiddleware...)
+ app.POST("/audio/transcriptions", audioHandler, audioMiddleware...)
+
+ audioSpeechHandler := localai.TTSEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
+ audioSpeechMiddleware := []echo.MiddlewareFunc{
+ traceMiddleware,
+ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
+ re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
+ }
+
+ app.POST("/v1/audio/speech", audioSpeechHandler, audioSpeechMiddleware...)
+ app.POST("/audio/speech", audioSpeechHandler, audioSpeechMiddleware...)
+
+ // images
+ imageHandler := openai.ImageEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
+ imageMiddleware := []echo.MiddlewareFunc{
+ traceMiddleware,
+ // Default: use the first available image generation model
+ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_IMAGE)),
+ re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+ func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ if err := re.SetOpenAIRequest(c); err != nil {
+ return err
+ }
+ return next(c)
+ }
+ },
+ }
+
+ app.POST("/v1/images/generations", imageHandler, imageMiddleware...)
+ app.POST("/images/generations", imageHandler, imageMiddleware...)
+
+ // inpainting endpoint (image + mask) - reuse same middleware config as images
+ inpaintingHandler := openai.InpaintingEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
+ app.POST("/v1/images/inpainting", inpaintingHandler, imageMiddleware...)
+ app.POST("/images/inpainting", inpaintingHandler, imageMiddleware...)
+
+ // videos (OpenAI-compatible endpoints mapped to LocalAI video handler)
+ videoHandler := openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
+ videoMiddleware := []echo.MiddlewareFunc{
+ traceMiddleware,
+ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)),
+ re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+ func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ if err := re.SetOpenAIRequest(c); err != nil {
+ return err
+ }
+ return next(c)
+ }
+ },
+ }
+
+ // OpenAI-style create video endpoint
+ app.POST("/v1/videos", videoHandler, videoMiddleware...)
+ app.POST("/v1/videos/generations", videoHandler, videoMiddleware...)
+ app.POST("/videos", videoHandler, videoMiddleware...)
+
+ // List models
+ app.GET("/v1/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
+ app.GET("/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()))
+}
diff --git a/core/http/routes/ui.go b/core/http/routes/ui.go
new file mode 100644
index 0000000000000000000000000000000000000000..bfe93224a26de7752abe8d02121246296ecf7137
--- /dev/null
+++ b/core/http/routes/ui.go
@@ -0,0 +1,390 @@
+package routes
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/http/endpoints/localai"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/services"
+ "github.com/mudler/LocalAI/internal"
+ "github.com/mudler/LocalAI/pkg/model"
+)
+
+func RegisterUIRoutes(app *echo.Echo,
+ cl *config.ModelConfigLoader,
+ ml *model.ModelLoader,
+ appConfig *config.ApplicationConfig,
+ galleryService *services.GalleryService) {
+
+ // keeps the state of ops that are started from the UI
+ var processingOps = services.NewOpCache(galleryService)
+
+ app.GET("/", localai.WelcomeEndpoint(appConfig, cl, ml, processingOps))
+ app.GET("/manage", localai.WelcomeEndpoint(appConfig, cl, ml, processingOps))
+
+ if !appConfig.DisableRuntimeSettings {
+ // Settings page
+ app.GET("/settings", func(c echo.Context) error {
+ summary := map[string]interface{}{
+ "Title": "LocalAI - Settings",
+ "BaseURL": middleware.BaseURL(c),
+ }
+ return c.Render(200, "views/settings", summary)
+ })
+ }
+
+ // Agent Jobs pages
+ app.GET("/agent-jobs", func(c echo.Context) error {
+ modelConfigs := cl.GetAllModelsConfigs()
+ summary := map[string]interface{}{
+ "Title": "LocalAI - Agent Jobs",
+ "BaseURL": middleware.BaseURL(c),
+ "Version": internal.PrintableVersion(),
+ "ModelsConfig": modelConfigs,
+ }
+ return c.Render(200, "views/agent-jobs", summary)
+ })
+
+ app.GET("/agent-jobs/tasks/new", func(c echo.Context) error {
+ modelConfigs := cl.GetAllModelsConfigs()
+ summary := map[string]interface{}{
+ "Title": "LocalAI - Create Task",
+ "BaseURL": middleware.BaseURL(c),
+ "Version": internal.PrintableVersion(),
+ "ModelsConfig": modelConfigs,
+ }
+ return c.Render(200, "views/agent-task-details", summary)
+ })
+
+ // More specific route must come first
+ app.GET("/agent-jobs/tasks/:id/edit", func(c echo.Context) error {
+ modelConfigs := cl.GetAllModelsConfigs()
+ summary := map[string]interface{}{
+ "Title": "LocalAI - Edit Task",
+ "BaseURL": middleware.BaseURL(c),
+ "Version": internal.PrintableVersion(),
+ "ModelsConfig": modelConfigs,
+ }
+ return c.Render(200, "views/agent-task-details", summary)
+ })
+
+ // Task details page (less specific, comes after edit route)
+ app.GET("/agent-jobs/tasks/:id", func(c echo.Context) error {
+ summary := map[string]interface{}{
+ "Title": "LocalAI - Task Details",
+ "BaseURL": middleware.BaseURL(c),
+ "Version": internal.PrintableVersion(),
+ }
+ return c.Render(200, "views/agent-task-details", summary)
+ })
+
+ app.GET("/agent-jobs/jobs/:id", func(c echo.Context) error {
+ summary := map[string]interface{}{
+ "Title": "LocalAI - Job Details",
+ "BaseURL": middleware.BaseURL(c),
+ "Version": internal.PrintableVersion(),
+ }
+ return c.Render(200, "views/agent-job-details", summary)
+ })
+
+ // P2P
+ app.GET("/p2p", func(c echo.Context) error {
+ summary := map[string]interface{}{
+ "Title": "LocalAI - P2P dashboard",
+ "BaseURL": middleware.BaseURL(c),
+ "Version": internal.PrintableVersion(),
+ //"Nodes": p2p.GetAvailableNodes(""),
+ //"FederatedNodes": p2p.GetAvailableNodes(p2p.FederatedID),
+
+ "P2PToken": appConfig.P2PToken,
+ "NetworkID": appConfig.P2PNetworkID,
+ }
+
+ // Render index
+ return c.Render(200, "views/p2p", summary)
+ })
+
+ // Note: P2P UI fragment routes (/p2p/ui/*) were removed
+ // P2P nodes are now fetched via JSON API at /api/p2p/workers and /api/p2p/federation
+
+ // End P2P
+
+ if !appConfig.DisableGalleryEndpoint {
+ registerGalleryRoutes(app, cl, appConfig, galleryService, processingOps)
+ registerBackendGalleryRoutes(app, appConfig, galleryService, processingOps)
+ }
+
+ app.GET("/talk", func(c echo.Context) error {
+ modelConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
+
+ if len(modelConfigs) == 0 {
+ // If no model is available redirect to the index which suggests how to install models
+ return c.Redirect(302, middleware.BaseURL(c))
+ }
+
+ summary := map[string]interface{}{
+ "Title": "LocalAI - Talk",
+ "BaseURL": middleware.BaseURL(c),
+ "ModelsConfig": modelConfigs,
+ "Model": modelConfigs[0],
+
+ "Version": internal.PrintableVersion(),
+ }
+
+ // Render index
+ return c.Render(200, "views/talk", summary)
+ })
+
+ app.GET("/chat", func(c echo.Context) error {
+ modelConfigs := cl.GetAllModelsConfigs()
+ modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
+
+ if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
+ // If no model is available redirect to the index which suggests how to install models
+ return c.Redirect(302, middleware.BaseURL(c))
+ }
+ modelThatCanBeUsed := ""
+ galleryConfigs := map[string]*gallery.ModelConfig{}
+
+ for _, m := range modelConfigs {
+ cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
+ if err != nil {
+ continue
+ }
+ galleryConfigs[m.Name] = cfg
+ }
+
+ title := "LocalAI - Chat"
+ var modelContextSize *int
+
+ for _, b := range modelConfigs {
+ if b.HasUsecases(config.FLAG_CHAT) {
+ modelThatCanBeUsed = b.Name
+ title = "LocalAI - Chat with " + modelThatCanBeUsed
+ if b.LLMConfig.ContextSize != nil {
+ modelContextSize = b.LLMConfig.ContextSize
+ }
+ break
+ }
+ }
+
+ summary := map[string]interface{}{
+ "Title": title,
+ "BaseURL": middleware.BaseURL(c),
+ "ModelsWithoutConfig": modelsWithoutConfig,
+ "GalleryConfig": galleryConfigs,
+ "ModelsConfig": modelConfigs,
+ "Model": modelThatCanBeUsed,
+ "ContextSize": modelContextSize,
+ "Version": internal.PrintableVersion(),
+ }
+
+ // Render index
+ return c.Render(200, "views/chat", summary)
+ })
+
+ // Show the Chat page
+ app.GET("/chat/:model", func(c echo.Context) error {
+ modelConfigs := cl.GetAllModelsConfigs()
+ modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
+
+ galleryConfigs := map[string]*gallery.ModelConfig{}
+ modelName := c.Param("model")
+ var modelContextSize *int
+
+ for _, m := range modelConfigs {
+ cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
+ if err != nil {
+ continue
+ }
+ galleryConfigs[m.Name] = cfg
+ if m.Name == modelName && m.LLMConfig.ContextSize != nil {
+ modelContextSize = m.LLMConfig.ContextSize
+ }
+ }
+
+ summary := map[string]interface{}{
+ "Title": "LocalAI - Chat with " + modelName,
+ "BaseURL": middleware.BaseURL(c),
+ "ModelsConfig": modelConfigs,
+ "GalleryConfig": galleryConfigs,
+ "ModelsWithoutConfig": modelsWithoutConfig,
+ "Model": modelName,
+ "ContextSize": modelContextSize,
+ "Version": internal.PrintableVersion(),
+ }
+
+ // Render index
+ return c.Render(200, "views/chat", summary)
+ })
+
+ app.GET("/image/:model", func(c echo.Context) error {
+ modelConfigs := cl.GetAllModelsConfigs()
+ modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
+
+ summary := map[string]interface{}{
+ "Title": "LocalAI - Generate images with " + c.Param("model"),
+ "BaseURL": middleware.BaseURL(c),
+ "ModelsConfig": modelConfigs,
+ "ModelsWithoutConfig": modelsWithoutConfig,
+ "Model": c.Param("model"),
+ "Version": internal.PrintableVersion(),
+ }
+
+ // Render index
+ return c.Render(200, "views/image", summary)
+ })
+
+ app.GET("/image", func(c echo.Context) error {
+ modelConfigs := cl.GetAllModelsConfigs()
+ modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
+
+ if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
+ // If no model is available redirect to the index which suggests how to install models
+ return c.Redirect(302, middleware.BaseURL(c))
+ }
+
+ modelThatCanBeUsed := ""
+ title := "LocalAI - Generate images"
+
+ for _, b := range modelConfigs {
+ if b.HasUsecases(config.FLAG_IMAGE) {
+ modelThatCanBeUsed = b.Name
+ title = "LocalAI - Generate images with " + modelThatCanBeUsed
+ break
+ }
+ }
+
+ summary := map[string]interface{}{
+ "Title": title,
+ "BaseURL": middleware.BaseURL(c),
+ "ModelsConfig": modelConfigs,
+ "ModelsWithoutConfig": modelsWithoutConfig,
+ "Model": modelThatCanBeUsed,
+ "Version": internal.PrintableVersion(),
+ }
+
+ // Render index
+ return c.Render(200, "views/image", summary)
+ })
+
+ app.GET("/tts/:model", func(c echo.Context) error {
+ modelConfigs := cl.GetAllModelsConfigs()
+ modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
+
+ summary := map[string]interface{}{
+ "Title": "LocalAI - Generate images with " + c.Param("model"),
+ "BaseURL": middleware.BaseURL(c),
+ "ModelsConfig": modelConfigs,
+ "ModelsWithoutConfig": modelsWithoutConfig,
+ "Model": c.Param("model"),
+ "Version": internal.PrintableVersion(),
+ }
+
+ // Render index
+ return c.Render(200, "views/tts", summary)
+ })
+
+ app.GET("/tts", func(c echo.Context) error {
+ modelConfigs := cl.GetAllModelsConfigs()
+ modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
+
+ if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
+ // If no model is available redirect to the index which suggests how to install models
+ return c.Redirect(302, middleware.BaseURL(c))
+ }
+
+ modelThatCanBeUsed := ""
+ title := "LocalAI - Generate audio"
+
+ for _, b := range modelConfigs {
+ if b.HasUsecases(config.FLAG_TTS) {
+ modelThatCanBeUsed = b.Name
+ title = "LocalAI - Generate audio with " + modelThatCanBeUsed
+ break
+ }
+ }
+ summary := map[string]interface{}{
+ "Title": title,
+ "BaseURL": middleware.BaseURL(c),
+ "ModelsConfig": modelConfigs,
+ "ModelsWithoutConfig": modelsWithoutConfig,
+ "Model": modelThatCanBeUsed,
+ "Version": internal.PrintableVersion(),
+ }
+
+ // Render index
+ return c.Render(200, "views/tts", summary)
+ })
+
+ app.GET("/video/:model", func(c echo.Context) error {
+ modelConfigs := cl.GetAllModelsConfigs()
+ modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
+
+ summary := map[string]interface{}{
+ "Title": "LocalAI - Generate videos with " + c.Param("model"),
+ "BaseURL": middleware.BaseURL(c),
+ "ModelsConfig": modelConfigs,
+ "ModelsWithoutConfig": modelsWithoutConfig,
+ "Model": c.Param("model"),
+ "Version": internal.PrintableVersion(),
+ }
+
+ // Render index
+ return c.Render(200, "views/video", summary)
+ })
+
+ app.GET("/video", func(c echo.Context) error {
+ modelConfigs := cl.GetAllModelsConfigs()
+ modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
+
+ if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
+ // If no model is available redirect to the index which suggests how to install models
+ return c.Redirect(302, middleware.BaseURL(c))
+ }
+
+ modelThatCanBeUsed := ""
+ title := "LocalAI - Generate videos"
+
+ for _, b := range modelConfigs {
+ if b.HasUsecases(config.FLAG_VIDEO) {
+ modelThatCanBeUsed = b.Name
+ title = "LocalAI - Generate videos with " + modelThatCanBeUsed
+ break
+ }
+ }
+
+ summary := map[string]interface{}{
+ "Title": title,
+ "BaseURL": middleware.BaseURL(c),
+ "ModelsConfig": modelConfigs,
+ "ModelsWithoutConfig": modelsWithoutConfig,
+ "Model": modelThatCanBeUsed,
+ "Version": internal.PrintableVersion(),
+ }
+
+ // Render index
+ return c.Render(200, "views/video", summary)
+ })
+
+ // Traces UI
+ app.GET("/traces", func(c echo.Context) error {
+ summary := map[string]interface{}{
+ "Title": "LocalAI - Traces",
+ "BaseURL": middleware.BaseURL(c),
+ "Version": internal.PrintableVersion(),
+ }
+ return c.Render(200, "views/traces", summary)
+ })
+
+ app.GET("/api/traces", func(c echo.Context) error {
+ return c.JSON(200, middleware.GetTraces())
+ })
+
+ app.POST("/api/traces/clear", func(c echo.Context) error {
+ middleware.ClearTraces()
+ return c.NoContent(204)
+ })
+
+}
diff --git a/core/http/routes/ui_api.go b/core/http/routes/ui_api.go
new file mode 100644
index 0000000000000000000000000000000000000000..31dd66e1ff25ae1ef5f15dc9c574947ffdae575e
--- /dev/null
+++ b/core/http/routes/ui_api.go
@@ -0,0 +1,978 @@
+package routes
+
+import (
+ "context"
+ "fmt"
+ "math"
+ "net/http"
+ "net/url"
+ "sort"
+ "strconv"
+ "strings"
+
+ "github.com/google/uuid"
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/application"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/http/endpoints/localai"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/p2p"
+ "github.com/mudler/LocalAI/core/services"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/xsysinfo"
+ "github.com/mudler/xlog"
+)
+
+const (
+ nameSortFieldName = "name"
+ repositorySortFieldName = "repository"
+ licenseSortFieldName = "license"
+ statusSortFieldName = "status"
+ ascSortOrder = "asc"
+)
+
+// RegisterUIAPIRoutes registers JSON API routes for the web UI
+func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache, applicationInstance *application.Application) {
+
+ // Operations API - Get all current operations (models + backends)
+ app.GET("/api/operations", func(c echo.Context) error {
+ processingData, taskTypes := opcache.GetStatus()
+
+ operations := []map[string]interface{}{}
+ for galleryID, jobID := range processingData {
+ taskType := "installation"
+ if tt, ok := taskTypes[galleryID]; ok {
+ taskType = tt
+ }
+
+ status := galleryService.GetStatus(jobID)
+ progress := 0
+ isDeletion := false
+ isQueued := false
+ isCancelled := false
+ isCancellable := false
+ message := ""
+
+ if status != nil {
+ // Skip completed operations (unless cancelled and not yet cleaned up)
+ if status.Processed && !status.Cancelled {
+ continue
+ }
+ // Skip cancelled operations that are processed (they're done, no need to show)
+ if status.Processed && status.Cancelled {
+ continue
+ }
+
+ progress = int(status.Progress)
+ isDeletion = status.Deletion
+ isCancelled = status.Cancelled
+ isCancellable = status.Cancellable
+ message = status.Message
+ if isDeletion {
+ taskType = "deletion"
+ }
+ if isCancelled {
+ taskType = "cancelled"
+ }
+ } else {
+ // Job is queued but hasn't started
+ isQueued = true
+ isCancellable = true
+ message = "Operation queued"
+ }
+
+ // Determine if it's a model or backend
+ // First check if it was explicitly marked as a backend operation
+ isBackend := opcache.IsBackendOp(galleryID)
+ // If not explicitly marked, check if it matches a known backend from the gallery
+ if !isBackend {
+ backends, _ := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
+ for _, b := range backends {
+ backendID := fmt.Sprintf("%s@%s", b.Gallery.Name, b.Name)
+ if backendID == galleryID || b.Name == galleryID {
+ isBackend = true
+ break
+ }
+ }
+ }
+
+ // Extract display name (remove repo prefix if exists)
+ displayName := galleryID
+ if strings.Contains(galleryID, "@") {
+ parts := strings.Split(galleryID, "@")
+ if len(parts) > 1 {
+ displayName = parts[1]
+ }
+ }
+
+ operations = append(operations, map[string]interface{}{
+ "id": galleryID,
+ "name": displayName,
+ "fullName": galleryID,
+ "jobID": jobID,
+ "progress": progress,
+ "taskType": taskType,
+ "isDeletion": isDeletion,
+ "isBackend": isBackend,
+ "isQueued": isQueued,
+ "isCancelled": isCancelled,
+ "cancellable": isCancellable,
+ "message": message,
+ })
+ }
+
+ // Sort operations by progress (ascending), then by ID for stable display order
+ sort.Slice(operations, func(i, j int) bool {
+ progressI := operations[i]["progress"].(int)
+ progressJ := operations[j]["progress"].(int)
+
+ // Primary sort by progress
+ if progressI != progressJ {
+ return progressI < progressJ
+ }
+
+ // Secondary sort by ID for stability when progress is the same
+ return operations[i]["id"].(string) < operations[j]["id"].(string)
+ })
+
+ return c.JSON(200, map[string]interface{}{
+ "operations": operations,
+ })
+ })
+
+ // Cancel operation endpoint
+ app.POST("/api/operations/:jobID/cancel", func(c echo.Context) error {
+ jobID := c.Param("jobID")
+ xlog.Debug("API request to cancel operation", "jobID", jobID)
+
+ err := galleryService.CancelOperation(jobID)
+ if err != nil {
+ xlog.Error("Failed to cancel operation", "error", err, "jobID", jobID)
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{
+ "error": err.Error(),
+ })
+ }
+
+ // Clean up opcache for cancelled operation
+ opcache.DeleteUUID(jobID)
+
+ return c.JSON(200, map[string]interface{}{
+ "success": true,
+ "message": "Operation cancelled",
+ })
+ })
+
+ // Model Gallery APIs
+ app.GET("/api/models", func(c echo.Context) error {
+ term := c.QueryParam("term")
+ page := c.QueryParam("page")
+ if page == "" {
+ page = "1"
+ }
+ items := c.QueryParam("items")
+ if items == "" {
+ items = "21"
+ }
+
+ models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
+ if err != nil {
+ xlog.Error("could not list models from galleries", "error", err)
+ return c.JSON(http.StatusInternalServerError, map[string]interface{}{
+ "error": err.Error(),
+ })
+ }
+
+ // Get all available tags
+ allTags := map[string]struct{}{}
+ tags := []string{}
+ for _, m := range models {
+ for _, t := range m.Tags {
+ allTags[t] = struct{}{}
+ }
+ }
+ for t := range allTags {
+ tags = append(tags, t)
+ }
+ sort.Strings(tags)
+
+ if term != "" {
+ models = gallery.GalleryElements[*gallery.GalleryModel](models).Search(term)
+ }
+
+ // Get model statuses
+ processingModelsData, taskTypes := opcache.GetStatus()
+
+ // Apply sorting if requested
+ sortBy := c.QueryParam("sort")
+ sortOrder := c.QueryParam("order")
+ if sortOrder == "" {
+ sortOrder = ascSortOrder
+ }
+
+ switch sortBy {
+ case nameSortFieldName:
+ models = gallery.GalleryElements[*gallery.GalleryModel](models).SortByName(sortOrder)
+ case repositorySortFieldName:
+ models = gallery.GalleryElements[*gallery.GalleryModel](models).SortByRepository(sortOrder)
+ case licenseSortFieldName:
+ models = gallery.GalleryElements[*gallery.GalleryModel](models).SortByLicense(sortOrder)
+ case statusSortFieldName:
+ models = gallery.GalleryElements[*gallery.GalleryModel](models).SortByInstalled(sortOrder)
+ }
+
+ pageNum, err := strconv.Atoi(page)
+ if err != nil || pageNum < 1 {
+ pageNum = 1
+ }
+
+ itemsNum, err := strconv.Atoi(items)
+ if err != nil || itemsNum < 1 {
+ itemsNum = 21
+ }
+
+ totalPages := int(math.Ceil(float64(len(models)) / float64(itemsNum)))
+ totalModels := len(models)
+
+ if pageNum > 0 {
+ models = models.Paginate(pageNum, itemsNum)
+ }
+
+ // Convert models to JSON-friendly format and deduplicate by ID
+ modelsJSON := make([]map[string]interface{}, 0, len(models))
+ seenIDs := make(map[string]bool)
+
+ for _, m := range models {
+ modelID := m.ID()
+
+ // Skip duplicate IDs to prevent Alpine.js x-for errors
+ if seenIDs[modelID] {
+ xlog.Debug("Skipping duplicate model ID", "modelID", modelID)
+ continue
+ }
+ seenIDs[modelID] = true
+
+ currentlyProcessing := opcache.Exists(modelID)
+ jobID := ""
+ isDeletionOp := false
+ if currentlyProcessing {
+ jobID = opcache.Get(modelID)
+ status := galleryService.GetStatus(jobID)
+ if status != nil && status.Deletion {
+ isDeletionOp = true
+ }
+ }
+
+ _, trustRemoteCodeExists := m.Overrides["trust_remote_code"]
+
+ modelsJSON = append(modelsJSON, map[string]interface{}{
+ "id": modelID,
+ "name": m.Name,
+ "description": m.Description,
+ "icon": m.Icon,
+ "license": m.License,
+ "urls": m.URLs,
+ "tags": m.Tags,
+ "gallery": m.Gallery.Name,
+ "installed": m.Installed,
+ "processing": currentlyProcessing,
+ "jobID": jobID,
+ "isDeletion": isDeletionOp,
+ "trustRemoteCode": trustRemoteCodeExists,
+ "additionalFiles": m.AdditionalFiles,
+ })
+ }
+
+ prevPage := pageNum - 1
+ nextPage := pageNum + 1
+ if prevPage < 1 {
+ prevPage = 1
+ }
+ if nextPage > totalPages {
+ nextPage = totalPages
+ }
+
+ // Calculate installed models count (models with configs + models without configs)
+ modelConfigs := cl.GetAllModelsConfigs()
+ modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
+ installedModelsCount := len(modelConfigs) + len(modelsWithoutConfig)
+
+ return c.JSON(200, map[string]interface{}{
+ "models": modelsJSON,
+ "repositories": appConfig.Galleries,
+ "allTags": tags,
+ "processingModels": processingModelsData,
+ "taskTypes": taskTypes,
+ "availableModels": totalModels,
+ "installedModels": installedModelsCount,
+ "currentPage": pageNum,
+ "totalPages": totalPages,
+ "prevPage": prevPage,
+ "nextPage": nextPage,
+ })
+ })
+
+ app.POST("/api/models/install/:id", func(c echo.Context) error {
+ galleryID := c.Param("id")
+ // URL decode the gallery ID (e.g., "localai%40model" -> "localai@model")
+ galleryID, err := url.QueryUnescape(galleryID)
+ if err != nil {
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{
+ "error": "invalid model ID",
+ })
+ }
+ xlog.Debug("API job submitted to install", "galleryID", galleryID)
+
+ id, err := uuid.NewUUID()
+ if err != nil {
+ return c.JSON(http.StatusInternalServerError, map[string]interface{}{
+ "error": err.Error(),
+ })
+ }
+
+ uid := id.String()
+ opcache.Set(galleryID, uid)
+
+ ctx, cancelFunc := context.WithCancel(context.Background())
+ op := services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
+ ID: uid,
+ GalleryElementName: galleryID,
+ Galleries: appConfig.Galleries,
+ BackendGalleries: appConfig.BackendGalleries,
+ Context: ctx,
+ CancelFunc: cancelFunc,
+ }
+ // Store cancellation function immediately so queued operations can be cancelled
+ galleryService.StoreCancellation(uid, cancelFunc)
+ go func() {
+ galleryService.ModelGalleryChannel <- op
+ }()
+
+ return c.JSON(200, map[string]interface{}{
+ "jobID": uid,
+ "message": "Installation started",
+ })
+ })
+
+ app.POST("/api/models/delete/:id", func(c echo.Context) error {
+ galleryID := c.Param("id")
+ // URL decode the gallery ID
+ galleryID, err := url.QueryUnescape(galleryID)
+ if err != nil {
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{
+ "error": "invalid model ID",
+ })
+ }
+ xlog.Debug("API job submitted to delete", "galleryID", galleryID)
+
+ var galleryName = galleryID
+ if strings.Contains(galleryID, "@") {
+ galleryName = strings.Split(galleryID, "@")[1]
+ }
+
+ id, err := uuid.NewUUID()
+ if err != nil {
+ return c.JSON(http.StatusInternalServerError, map[string]interface{}{
+ "error": err.Error(),
+ })
+ }
+
+ uid := id.String()
+
+ opcache.Set(galleryID, uid)
+
+ ctx, cancelFunc := context.WithCancel(context.Background())
+ op := services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
+ ID: uid,
+ Delete: true,
+ GalleryElementName: galleryName,
+ Galleries: appConfig.Galleries,
+ BackendGalleries: appConfig.BackendGalleries,
+ Context: ctx,
+ CancelFunc: cancelFunc,
+ }
+ // Store cancellation function immediately so queued operations can be cancelled
+ galleryService.StoreCancellation(uid, cancelFunc)
+ go func() {
+ galleryService.ModelGalleryChannel <- op
+ cl.RemoveModelConfig(galleryName)
+ }()
+
+ return c.JSON(200, map[string]interface{}{
+ "jobID": uid,
+ "message": "Deletion started",
+ })
+ })
+
+ app.POST("/api/models/config/:id", func(c echo.Context) error {
+ galleryID := c.Param("id")
+ // URL decode the gallery ID
+ galleryID, err := url.QueryUnescape(galleryID)
+ if err != nil {
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{
+ "error": "invalid model ID",
+ })
+ }
+ xlog.Debug("API job submitted to get config", "galleryID", galleryID)
+
+ models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
+ if err != nil {
+ return c.JSON(http.StatusInternalServerError, map[string]interface{}{
+ "error": err.Error(),
+ })
+ }
+
+ model := gallery.FindGalleryElement(models, galleryID)
+ if model == nil {
+ return c.JSON(http.StatusNotFound, map[string]interface{}{
+ "error": "model not found",
+ })
+ }
+
+ config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](model.URL, appConfig.SystemState.Model.ModelsPath)
+ if err != nil {
+ return c.JSON(http.StatusInternalServerError, map[string]interface{}{
+ "error": err.Error(),
+ })
+ }
+
+ _, err = gallery.InstallModel(context.Background(), appConfig.SystemState, model.Name, &config, model.Overrides, nil, false)
+ if err != nil {
+ return c.JSON(http.StatusInternalServerError, map[string]interface{}{
+ "error": err.Error(),
+ })
+ }
+
+ return c.JSON(200, map[string]interface{}{
+ "message": "Configuration file saved",
+ })
+ })
+
+ app.GET("/api/models/job/:uid", func(c echo.Context) error {
+ jobUID := c.Param("uid")
+
+ status := galleryService.GetStatus(jobUID)
+ if status == nil {
+ // Job is queued but hasn't started processing yet
+ return c.JSON(200, map[string]interface{}{
+ "progress": 0,
+ "message": "Operation queued",
+ "galleryElementName": "",
+ "processed": false,
+ "deletion": false,
+ "queued": true,
+ })
+ }
+
+ response := map[string]interface{}{
+ "progress": status.Progress,
+ "message": status.Message,
+ "galleryElementName": status.GalleryElementName,
+ "processed": status.Processed,
+ "deletion": status.Deletion,
+ "queued": false,
+ }
+
+ if status.Error != nil {
+ response["error"] = status.Error.Error()
+ }
+
+ if status.Progress == 100 && status.Processed && status.Message == "completed" {
+ opcache.DeleteUUID(jobUID)
+ response["completed"] = true
+ }
+
+ return c.JSON(200, response)
+ })
+
+ // Backend Gallery APIs
+ app.GET("/api/backends", func(c echo.Context) error {
+ term := c.QueryParam("term")
+ page := c.QueryParam("page")
+ if page == "" {
+ page = "1"
+ }
+ items := c.QueryParam("items")
+ if items == "" {
+ items = "21"
+ }
+
+ backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
+ if err != nil {
+ xlog.Error("could not list backends from galleries", "error", err)
+ return c.JSON(http.StatusInternalServerError, map[string]interface{}{
+ "error": err.Error(),
+ })
+ }
+
+ // Get all available tags
+ allTags := map[string]struct{}{}
+ tags := []string{}
+ for _, b := range backends {
+ for _, t := range b.Tags {
+ allTags[t] = struct{}{}
+ }
+ }
+ for t := range allTags {
+ tags = append(tags, t)
+ }
+ sort.Strings(tags)
+
+ if term != "" {
+ backends = gallery.GalleryElements[*gallery.GalleryBackend](backends).Search(term)
+ }
+
+ // Get backend statuses
+ processingBackendsData, taskTypes := opcache.GetStatus()
+
+ // Apply sorting if requested
+ sortBy := c.QueryParam("sort")
+ sortOrder := c.QueryParam("order")
+ if sortOrder == "" {
+ sortOrder = ascSortOrder
+ }
+
+ switch sortBy {
+ case nameSortFieldName:
+ backends = gallery.GalleryElements[*gallery.GalleryBackend](backends).SortByName(sortOrder)
+ case repositorySortFieldName:
+ backends = gallery.GalleryElements[*gallery.GalleryBackend](backends).SortByRepository(sortOrder)
+ case licenseSortFieldName:
+ backends = gallery.GalleryElements[*gallery.GalleryBackend](backends).SortByLicense(sortOrder)
+ case statusSortFieldName:
+ backends = gallery.GalleryElements[*gallery.GalleryBackend](backends).SortByInstalled(sortOrder)
+ }
+
+ pageNum, err := strconv.Atoi(page)
+ if err != nil || pageNum < 1 {
+ pageNum = 1
+ }
+
+ itemsNum, err := strconv.Atoi(items)
+ if err != nil || itemsNum < 1 {
+ itemsNum = 21
+ }
+
+ totalPages := int(math.Ceil(float64(len(backends)) / float64(itemsNum)))
+ totalBackends := len(backends)
+
+ if pageNum > 0 {
+ backends = backends.Paginate(pageNum, itemsNum)
+ }
+
+ // Convert backends to JSON-friendly format and deduplicate by ID
+ backendsJSON := make([]map[string]interface{}, 0, len(backends))
+ seenBackendIDs := make(map[string]bool)
+
+ for _, b := range backends {
+ backendID := b.ID()
+
+ // Skip duplicate IDs to prevent Alpine.js x-for errors
+ if seenBackendIDs[backendID] {
+ xlog.Debug("Skipping duplicate backend ID", "backendID", backendID)
+ continue
+ }
+ seenBackendIDs[backendID] = true
+
+ currentlyProcessing := opcache.Exists(backendID)
+ jobID := ""
+ isDeletionOp := false
+ if currentlyProcessing {
+ jobID = opcache.Get(backendID)
+ status := galleryService.GetStatus(jobID)
+ if status != nil && status.Deletion {
+ isDeletionOp = true
+ }
+ }
+
+ backendsJSON = append(backendsJSON, map[string]interface{}{
+ "id": backendID,
+ "name": b.Name,
+ "description": b.Description,
+ "icon": b.Icon,
+ "license": b.License,
+ "urls": b.URLs,
+ "tags": b.Tags,
+ "gallery": b.Gallery.Name,
+ "installed": b.Installed,
+ "processing": currentlyProcessing,
+ "jobID": jobID,
+ "isDeletion": isDeletionOp,
+ })
+ }
+
+ prevPage := pageNum - 1
+ nextPage := pageNum + 1
+ if prevPage < 1 {
+ prevPage = 1
+ }
+ if nextPage > totalPages {
+ nextPage = totalPages
+ }
+
+ // Calculate installed backends count
+ installedBackends, err := gallery.ListSystemBackends(appConfig.SystemState)
+ installedBackendsCount := 0
+ if err == nil {
+ installedBackendsCount = len(installedBackends)
+ }
+
+ // Get the detected system capability
+ detectedCapability := ""
+ if appConfig.SystemState != nil {
+ detectedCapability = appConfig.SystemState.DetectedCapability()
+ }
+
+ return c.JSON(200, map[string]interface{}{
+ "backends": backendsJSON,
+ "repositories": appConfig.BackendGalleries,
+ "allTags": tags,
+ "processingBackends": processingBackendsData,
+ "taskTypes": taskTypes,
+ "availableBackends": totalBackends,
+ "installedBackends": installedBackendsCount,
+ "currentPage": pageNum,
+ "totalPages": totalPages,
+ "prevPage": prevPage,
+ "nextPage": nextPage,
+ "systemCapability": detectedCapability,
+ })
+ })
+
+ app.POST("/api/backends/install/:id", func(c echo.Context) error {
+ backendID := c.Param("id")
+ // URL decode the backend ID
+ backendID, err := url.QueryUnescape(backendID)
+ if err != nil {
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{
+ "error": "invalid backend ID",
+ })
+ }
+ xlog.Debug("API job submitted to install backend", "backendID", backendID)
+
+ id, err := uuid.NewUUID()
+ if err != nil {
+ return c.JSON(http.StatusInternalServerError, map[string]interface{}{
+ "error": err.Error(),
+ })
+ }
+
+ uid := id.String()
+ opcache.SetBackend(backendID, uid)
+
+ ctx, cancelFunc := context.WithCancel(context.Background())
+ op := services.GalleryOp[gallery.GalleryBackend, any]{
+ ID: uid,
+ GalleryElementName: backendID,
+ Galleries: appConfig.BackendGalleries,
+ Context: ctx,
+ CancelFunc: cancelFunc,
+ }
+ // Store cancellation function immediately so queued operations can be cancelled
+ galleryService.StoreCancellation(uid, cancelFunc)
+ go func() {
+ galleryService.BackendGalleryChannel <- op
+ }()
+
+ return c.JSON(200, map[string]interface{}{
+ "jobID": uid,
+ "message": "Backend installation started",
+ })
+ })
+
+ // Install backend from external source (OCI image, URL, or path)
+ app.POST("/api/backends/install-external", func(c echo.Context) error {
+ // Request body structure
+ type ExternalBackendRequest struct {
+ URI string `json:"uri"`
+ Name string `json:"name"`
+ Alias string `json:"alias"`
+ }
+
+ var req ExternalBackendRequest
+ if err := c.Bind(&req); err != nil {
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{
+ "error": "invalid request body",
+ })
+ }
+
+ // Validate required fields
+ if req.URI == "" {
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{
+ "error": "uri is required",
+ })
+ }
+
+ xlog.Debug("API job submitted to install external backend", "uri", req.URI, "name", req.Name, "alias", req.Alias)
+
+ id, err := uuid.NewUUID()
+ if err != nil {
+ return c.JSON(http.StatusInternalServerError, map[string]interface{}{
+ "error": err.Error(),
+ })
+ }
+
+ uid := id.String()
+
+ // Use URI as the key for opcache, or name if provided
+ cacheKey := req.URI
+ if req.Name != "" {
+ cacheKey = req.Name
+ }
+ opcache.SetBackend(cacheKey, uid)
+
+ ctx, cancelFunc := context.WithCancel(context.Background())
+ op := services.GalleryOp[gallery.GalleryBackend, any]{
+ ID: uid,
+ GalleryElementName: req.Name, // May be empty, will be derived during installation
+ Galleries: appConfig.BackendGalleries,
+ Context: ctx,
+ CancelFunc: cancelFunc,
+ ExternalURI: req.URI,
+ ExternalName: req.Name,
+ ExternalAlias: req.Alias,
+ }
+ // Store cancellation function immediately so queued operations can be cancelled
+ galleryService.StoreCancellation(uid, cancelFunc)
+ go func() {
+ galleryService.BackendGalleryChannel <- op
+ }()
+
+ return c.JSON(200, map[string]interface{}{
+ "jobID": uid,
+ "message": "External backend installation started",
+ })
+ })
+
+ app.POST("/api/backends/delete/:id", func(c echo.Context) error {
+ backendID := c.Param("id")
+ // URL decode the backend ID
+ backendID, err := url.QueryUnescape(backendID)
+ if err != nil {
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{
+ "error": "invalid backend ID",
+ })
+ }
+ xlog.Debug("API job submitted to delete backend", "backendID", backendID)
+
+ var backendName = backendID
+ if strings.Contains(backendID, "@") {
+ backendName = strings.Split(backendID, "@")[1]
+ }
+
+ id, err := uuid.NewUUID()
+ if err != nil {
+ return c.JSON(http.StatusInternalServerError, map[string]interface{}{
+ "error": err.Error(),
+ })
+ }
+
+ uid := id.String()
+
+ opcache.SetBackend(backendID, uid)
+
+ ctx, cancelFunc := context.WithCancel(context.Background())
+ op := services.GalleryOp[gallery.GalleryBackend, any]{
+ ID: uid,
+ Delete: true,
+ GalleryElementName: backendName,
+ Galleries: appConfig.BackendGalleries,
+ Context: ctx,
+ CancelFunc: cancelFunc,
+ }
+ // Store cancellation function immediately so queued operations can be cancelled
+ galleryService.StoreCancellation(uid, cancelFunc)
+ go func() {
+ galleryService.BackendGalleryChannel <- op
+ }()
+
+ return c.JSON(200, map[string]interface{}{
+ "jobID": uid,
+ "message": "Backend deletion started",
+ })
+ })
+
+ app.GET("/api/backends/job/:uid", func(c echo.Context) error {
+ jobUID := c.Param("uid")
+
+ status := galleryService.GetStatus(jobUID)
+ if status == nil {
+ // Job is queued but hasn't started processing yet
+ return c.JSON(200, map[string]interface{}{
+ "progress": 0,
+ "message": "Operation queued",
+ "galleryElementName": "",
+ "processed": false,
+ "deletion": false,
+ "queued": true,
+ })
+ }
+
+ response := map[string]interface{}{
+ "progress": status.Progress,
+ "message": status.Message,
+ "galleryElementName": status.GalleryElementName,
+ "processed": status.Processed,
+ "deletion": status.Deletion,
+ "queued": false,
+ }
+
+ if status.Error != nil {
+ response["error"] = status.Error.Error()
+ }
+
+ if status.Progress == 100 && status.Processed && status.Message == "completed" {
+ opcache.DeleteUUID(jobUID)
+ response["completed"] = true
+ }
+
+ return c.JSON(200, response)
+ })
+
+ // System Backend Deletion API (for installed backends on index page)
+ app.POST("/api/backends/system/delete/:name", func(c echo.Context) error {
+ backendName := c.Param("name")
+ // URL decode the backend name
+ backendName, err := url.QueryUnescape(backendName)
+ if err != nil {
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{
+ "error": "invalid backend name",
+ })
+ }
+ xlog.Debug("API request to delete system backend", "backendName", backendName)
+
+ // Use the gallery package to delete the backend
+ if err := gallery.DeleteBackendFromSystem(appConfig.SystemState, backendName); err != nil {
+ xlog.Error("Failed to delete backend", "error", err, "backendName", backendName)
+ return c.JSON(http.StatusInternalServerError, map[string]interface{}{
+ "error": err.Error(),
+ })
+ }
+
+ return c.JSON(200, map[string]interface{}{
+ "success": true,
+ "message": "Backend deleted successfully",
+ })
+ })
+
+ // P2P APIs
+ app.GET("/api/p2p/workers", func(c echo.Context) error {
+ nodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))
+
+ nodesJSON := make([]map[string]interface{}, 0, len(nodes))
+ for _, n := range nodes {
+ nodesJSON = append(nodesJSON, map[string]interface{}{
+ "name": n.Name,
+ "id": n.ID,
+ "tunnelAddress": n.TunnelAddress,
+ "serviceID": n.ServiceID,
+ "lastSeen": n.LastSeen,
+ "isOnline": n.IsOnline(),
+ })
+ }
+
+ return c.JSON(200, map[string]interface{}{
+ "nodes": nodesJSON,
+ })
+ })
+
+ app.GET("/api/p2p/federation", func(c echo.Context) error {
+ nodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))
+
+ nodesJSON := make([]map[string]interface{}, 0, len(nodes))
+ for _, n := range nodes {
+ nodesJSON = append(nodesJSON, map[string]interface{}{
+ "name": n.Name,
+ "id": n.ID,
+ "tunnelAddress": n.TunnelAddress,
+ "serviceID": n.ServiceID,
+ "lastSeen": n.LastSeen,
+ "isOnline": n.IsOnline(),
+ })
+ }
+
+ return c.JSON(200, map[string]interface{}{
+ "nodes": nodesJSON,
+ })
+ })
+
+ app.GET("/api/p2p/stats", func(c echo.Context) error {
+ workerNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))
+ federatedNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))
+
+ workersOnline := 0
+ for _, n := range workerNodes {
+ if n.IsOnline() {
+ workersOnline++
+ }
+ }
+
+ federatedOnline := 0
+ for _, n := range federatedNodes {
+ if n.IsOnline() {
+ federatedOnline++
+ }
+ }
+
+ return c.JSON(200, map[string]interface{}{
+ "workers": map[string]interface{}{
+ "online": workersOnline,
+ "total": len(workerNodes),
+ },
+ "federated": map[string]interface{}{
+ "online": federatedOnline,
+ "total": len(federatedNodes),
+ },
+ })
+ })
+
+ // Resources API endpoint - unified memory info (GPU if available, otherwise RAM)
+ app.GET("/api/resources", func(c echo.Context) error {
+ resourceInfo := xsysinfo.GetResourceInfo()
+
+ // Format watchdog interval
+ watchdogInterval := "2s" // default
+ if appConfig.WatchDogInterval > 0 {
+ watchdogInterval = appConfig.WatchDogInterval.String()
+ }
+
+ response := map[string]interface{}{
+ "type": resourceInfo.Type, // "gpu" or "ram"
+ "available": resourceInfo.Available,
+ "gpus": resourceInfo.GPUs,
+ "ram": resourceInfo.RAM,
+ "aggregate": resourceInfo.Aggregate,
+ "reclaimer_enabled": appConfig.MemoryReclaimerEnabled,
+ "reclaimer_threshold": appConfig.MemoryReclaimerThreshold,
+ "watchdog_interval": watchdogInterval,
+ }
+
+ return c.JSON(200, response)
+ })
+
+ if !appConfig.DisableRuntimeSettings {
+ // Settings API
+ app.GET("/api/settings", localai.GetSettingsEndpoint(applicationInstance))
+ app.POST("/api/settings", localai.UpdateSettingsEndpoint(applicationInstance))
+ }
+
+ // Logs API
+ app.GET("/api/traces", func(c echo.Context) error {
+ if !appConfig.EnableTracing {
+ return c.JSON(503, map[string]any{
+ "error": "Tracing disabled",
+ })
+ }
+ traces := middleware.GetTraces()
+ return c.JSON(200, map[string]interface{}{
+ "traces": traces,
+ })
+ })
+
+ app.POST("/api/traces/clear", func(c echo.Context) error {
+ middleware.ClearTraces()
+ return c.JSON(200, map[string]interface{}{
+ "message": "Traces cleared",
+ })
+ })
+}
diff --git a/core/http/routes/ui_api_backends_test.go b/core/http/routes/ui_api_backends_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..b611d403e6e1352c76934d7a4264f3dac03cebfd
--- /dev/null
+++ b/core/http/routes/ui_api_backends_test.go
@@ -0,0 +1,210 @@
+package routes_test
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/core/http/routes"
+ "github.com/mudler/LocalAI/core/services"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/system"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestRoutes(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "Routes Suite")
+}
+
+var _ = Describe("Backend API Routes", func() {
+ var (
+ app *echo.Echo
+ tempDir string
+ appConfig *config.ApplicationConfig
+ galleryService *services.GalleryService
+ modelLoader *model.ModelLoader
+ systemState *system.SystemState
+ configLoader *config.ModelConfigLoader
+ )
+
+ BeforeEach(func() {
+ var err error
+ tempDir, err = os.MkdirTemp("", "backend-routes-test-*")
+ Expect(err).NotTo(HaveOccurred())
+
+ systemState, err = system.GetSystemState(
+ system.WithBackendPath(filepath.Join(tempDir, "backends")),
+ )
+ Expect(err).NotTo(HaveOccurred())
+ systemState.Model.ModelsPath = filepath.Join(tempDir, "models")
+
+ // Create directories
+ err = os.MkdirAll(systemState.Backend.BackendsPath, 0750)
+ Expect(err).NotTo(HaveOccurred())
+ err = os.MkdirAll(systemState.Model.ModelsPath, 0750)
+ Expect(err).NotTo(HaveOccurred())
+
+ modelLoader = model.NewModelLoader(systemState)
+ configLoader = config.NewModelConfigLoader(tempDir)
+
+ appConfig = config.NewApplicationConfig(
+ config.WithContext(context.Background()),
+ )
+ appConfig.SystemState = systemState
+ appConfig.BackendGalleries = []config.Gallery{}
+
+ galleryService = services.NewGalleryService(appConfig, modelLoader)
+ // Start the gallery service
+ err = galleryService.Start(context.Background(), configLoader, systemState)
+ Expect(err).NotTo(HaveOccurred())
+
+ app = echo.New()
+
+ // Register the API routes for backends
+ opcache := services.NewOpCache(galleryService)
+ routes.RegisterUIAPIRoutes(app, configLoader, modelLoader, appConfig, galleryService, opcache, nil)
+ })
+
+ AfterEach(func() {
+ os.RemoveAll(tempDir)
+ })
+
+ Describe("POST /api/backends/install-external", func() {
+ It("should return error when URI is missing", func() {
+ reqBody := map[string]string{
+ "name": "test-backend",
+ }
+ jsonBody, err := json.Marshal(reqBody)
+ Expect(err).NotTo(HaveOccurred())
+
+ req := httptest.NewRequest(http.MethodPost, "/api/backends/install-external", bytes.NewBuffer(jsonBody))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(http.StatusBadRequest))
+
+ var response map[string]interface{}
+ err = json.Unmarshal(rec.Body.Bytes(), &response)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(response["error"]).To(Equal("uri is required"))
+ })
+
+ It("should accept valid request and return job ID", func() {
+ reqBody := map[string]string{
+ "uri": "oci://quay.io/example/backend:latest",
+ "name": "test-backend",
+ "alias": "test-alias",
+ }
+ jsonBody, err := json.Marshal(reqBody)
+ Expect(err).NotTo(HaveOccurred())
+
+ req := httptest.NewRequest(http.MethodPost, "/api/backends/install-external", bytes.NewBuffer(jsonBody))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(http.StatusOK))
+
+ var response map[string]interface{}
+ err = json.Unmarshal(rec.Body.Bytes(), &response)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(response["jobID"]).NotTo(BeEmpty())
+ Expect(response["message"]).To(Equal("External backend installation started"))
+ })
+
+ It("should accept request with only URI", func() {
+ reqBody := map[string]string{
+ "uri": "/path/to/local/backend",
+ }
+ jsonBody, err := json.Marshal(reqBody)
+ Expect(err).NotTo(HaveOccurred())
+
+ req := httptest.NewRequest(http.MethodPost, "/api/backends/install-external", bytes.NewBuffer(jsonBody))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(http.StatusOK))
+
+ var response map[string]interface{}
+ err = json.Unmarshal(rec.Body.Bytes(), &response)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(response["jobID"]).NotTo(BeEmpty())
+ })
+
+ It("should return error for invalid JSON body", func() {
+ req := httptest.NewRequest(http.MethodPost, "/api/backends/install-external", bytes.NewBufferString("invalid json"))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(http.StatusBadRequest))
+ })
+ })
+
+ Describe("GET /api/backends/job/:uid", func() {
+ It("should return queued status for unknown job", func() {
+ req := httptest.NewRequest(http.MethodGet, "/api/backends/job/unknown-job-id", nil)
+ rec := httptest.NewRecorder()
+
+ app.ServeHTTP(rec, req)
+
+ Expect(rec.Code).To(Equal(http.StatusOK))
+
+ var response map[string]interface{}
+ err := json.Unmarshal(rec.Body.Bytes(), &response)
+ Expect(err).NotTo(HaveOccurred())
+ Expect(response["queued"]).To(Equal(true))
+ Expect(response["processed"]).To(Equal(false))
+ })
+ })
+})
+
+// Helper function to make POST request
+func postRequest(url string, body interface{}) (*http.Response, error) {
+ jsonBody, err := json.Marshal(body)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ client := &http.Client{}
+ return client.Do(req)
+}
+
+// Helper function to read response body
+func readResponseBody(resp *http.Response) (map[string]interface{}, error) {
+ defer resp.Body.Close()
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ var result map[string]interface{}
+ err = json.Unmarshal(body, &result)
+ return result, err
+}
+
+// Avoid unused import errors
+var _ = gallery.GalleryModel{}
diff --git a/core/http/routes/ui_backend_gallery.go b/core/http/routes/ui_backend_gallery.go
new file mode 100644
index 0000000000000000000000000000000000000000..8f0a31351236609411d5478cd488e6c09d509593
--- /dev/null
+++ b/core/http/routes/ui_backend_gallery.go
@@ -0,0 +1,24 @@
+package routes
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/services"
+ "github.com/mudler/LocalAI/internal"
+)
+
+func registerBackendGalleryRoutes(app *echo.Echo, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
+ // Show the Backends page (all backends are loaded client-side via Alpine.js)
+ app.GET("/browse/backends", func(c echo.Context) error {
+ summary := map[string]interface{}{
+ "Title": "LocalAI - Backends",
+ "BaseURL": middleware.BaseURL(c),
+ "Version": internal.PrintableVersion(),
+ "Repositories": appConfig.BackendGalleries,
+ }
+
+ // Render index - backends are now loaded via Alpine.js from /api/backends
+ return c.Render(200, "views/backends", summary)
+ })
+}
diff --git a/core/http/routes/ui_gallery.go b/core/http/routes/ui_gallery.go
new file mode 100644
index 0000000000000000000000000000000000000000..dfd39fe764780ef53b527c85a7f2ebc11c4c6019
--- /dev/null
+++ b/core/http/routes/ui_gallery.go
@@ -0,0 +1,23 @@
+package routes
+
+import (
+ "github.com/labstack/echo/v4"
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/http/middleware"
+ "github.com/mudler/LocalAI/core/services"
+ "github.com/mudler/LocalAI/internal"
+)
+
+func registerGalleryRoutes(app *echo.Echo, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
+ app.GET("/browse", func(c echo.Context) error {
+ summary := map[string]interface{}{
+ "Title": "LocalAI - Models",
+ "BaseURL": middleware.BaseURL(c),
+ "Version": internal.PrintableVersion(),
+ "Repositories": appConfig.Galleries,
+ }
+
+ // Render index - models are now loaded via Alpine.js from /api/models
+ return c.Render(200, "views/models", summary)
+ })
+}
diff --git a/core/http/static/animations.css b/core/http/static/animations.css
new file mode 100644
index 0000000000000000000000000000000000000000..c0d85eea5a0b3f20541f7cb08091a3322b19b812
--- /dev/null
+++ b/core/http/static/animations.css
@@ -0,0 +1,247 @@
+/* LocalAI Animation System */
+/* Purposeful animations with performance optimization */
+
+/* Animation Keyframes */
+@keyframes fadeIn {
+ from {
+ opacity: 0;
+ }
+ to {
+ opacity: 1;
+ }
+}
+
+@keyframes fadeInUp {
+ from {
+ opacity: 0;
+ transform: translateY(20px);
+ }
+ to {
+ opacity: 1;
+ transform: translateY(0);
+ }
+}
+
+@keyframes fadeInDown {
+ from {
+ opacity: 0;
+ transform: translateY(-20px);
+ }
+ to {
+ opacity: 1;
+ transform: translateY(0);
+ }
+}
+
+@keyframes cardReveal {
+ from {
+ opacity: 0;
+ transform: translateY(20px);
+ }
+ to {
+ opacity: 1;
+ transform: translateY(0);
+ }
+}
+
+@keyframes slideInRight {
+ from {
+ opacity: 0;
+ transform: translateX(-20px);
+ }
+ to {
+ opacity: 1;
+ transform: translateX(0);
+ }
+}
+
+@keyframes slideInLeft {
+ from {
+ opacity: 0;
+ transform: translateX(20px);
+ }
+ to {
+ opacity: 1;
+ transform: translateX(0);
+ }
+}
+
+@keyframes pulse {
+ 0%, 100% {
+ opacity: 1;
+ }
+ 50% {
+ opacity: 0.5;
+ }
+}
+
+@keyframes glow {
+ 0%, 100% {
+ box-shadow: 0 0 8px rgba(56, 189, 248, 0.15);
+ }
+ 50% {
+ box-shadow: 0 0 12px rgba(56, 189, 248, 0.25);
+ }
+}
+
+@keyframes scaleIn {
+ from {
+ opacity: 0;
+ transform: scale(0.95);
+ }
+ to {
+ opacity: 1;
+ transform: scale(1);
+ }
+}
+
+/* P2P/Network Specific Animations */
+@keyframes rotateCircleNodes {
+ 0% { transform: rotate(0deg); }
+ 100% { transform: rotate(360deg); }
+}
+
+@keyframes shakeFlask {
+ 0%, 10% { transform: rotate(0deg); }
+ 20% { transform: rotate(-10deg); }
+ 30% { transform: rotate(10deg); }
+ 40% { transform: rotate(-8deg); }
+ 50% { transform: rotate(8deg); }
+ 60% { transform: rotate(-5deg); }
+ 70% { transform: rotate(5deg); }
+ 80% { transform: rotate(-2deg); }
+ 90% { transform: rotate(2deg); }
+ 100% { transform: rotate(0deg); }
+}
+
+@keyframes nodeGlow {
+ 0% { left: -100%; }
+ 50% { left: 100%; }
+ 100% { left: 100%; }
+}
+
+/* Animation Utility Classes */
+.fade-in {
+ animation: fadeIn var(--duration-fast) var(--ease-out);
+}
+
+/* Transition Utility Classes */
+.transition-default {
+ transition: all var(--duration-fast) var(--ease-default);
+}
+
+.transition-color {
+ transition: color var(--duration-fast) var(--ease-default);
+}
+
+.transition-background {
+ transition: background-color var(--duration-fast) var(--ease-default);
+}
+
+.fade-in-up {
+ animation: fadeInUp var(--duration-normal) var(--ease-out) backwards;
+}
+
+.fade-in-down {
+ animation: fadeInDown var(--duration-normal) var(--ease-out) backwards;
+}
+
+.slide-in-right {
+ animation: slideInRight var(--duration-normal) var(--ease-out) backwards;
+}
+
+.slide-in-left {
+ animation: slideInLeft var(--duration-normal) var(--ease-out) backwards;
+}
+
+.scale-in {
+ animation: scaleIn var(--duration-normal) var(--ease-out) backwards;
+}
+
+/* Staggered Card Animations */
+.card-animate {
+ animation: cardReveal var(--duration-normal) var(--ease-out) backwards;
+}
+
+.card-animate:nth-child(1) { animation-delay: 0ms; }
+.card-animate:nth-child(2) { animation-delay: 50ms; }
+.card-animate:nth-child(3) { animation-delay: 100ms; }
+.card-animate:nth-child(4) { animation-delay: 150ms; }
+.card-animate:nth-child(5) { animation-delay: 200ms; }
+.card-animate:nth-child(6) { animation-delay: 250ms; }
+.card-animate:nth-child(7) { animation-delay: 300ms; }
+.card-animate:nth-child(8) { animation-delay: 350ms; }
+.card-animate:nth-child(9) { animation-delay: 400ms; }
+.card-animate:nth-child(10) { animation-delay: 450ms; }
+.card-animate:nth-child(11) { animation-delay: 500ms; }
+.card-animate:nth-child(12) { animation-delay: 550ms; }
+
+/* Hero Text Animation */
+.hero-title {
+ animation: fadeInUp var(--duration-normal) var(--ease-out) backwards;
+ animation-delay: 50ms;
+}
+
+.hero-subtitle {
+ animation: fadeInUp var(--duration-normal) var(--ease-out) backwards;
+ animation-delay: 100ms;
+}
+
+/* Navigation Animation */
+.nav-fade-in {
+ animation: fadeIn var(--duration-normal) var(--ease-out) backwards;
+ animation-delay: 0ms;
+}
+
+/* Loading States - Minimal */
+.pulse-animation {
+ animation: pulse 1.5s var(--ease-in-out) infinite;
+}
+
+.glow-animation {
+ animation: glow 1.5s var(--ease-in-out) infinite;
+}
+
+/* Reduced Motion Support */
+@media (prefers-reduced-motion: reduce) {
+ *,
+ *::before,
+ *::after {
+ animation-duration: 0.01ms !important;
+ animation-iteration-count: 1 !important;
+ transition-duration: 0.01ms !important;
+ scroll-behavior: auto !important;
+ }
+
+ .card-animate,
+ .fade-in-up,
+ .fade-in-down,
+ .slide-in-right,
+ .slide-in-left,
+ .scale-in,
+ .hero-title,
+ .hero-subtitle {
+ animation: none !important;
+ }
+}
+
+/* Performance Optimization */
+.card-animate,
+.fade-in-up,
+.fade-in-down,
+.slide-in-right,
+.slide-in-left,
+.scale-in {
+ will-change: transform, opacity;
+}
+
+/* After animation completes, remove will-change */
+.card-animate.animation-complete,
+.fade-in-up.animation-complete,
+.fade-in-down.animation-complete,
+.slide-in-right.animation-complete,
+.slide-in-left.animation-complete,
+.scale-in.animation-complete {
+ will-change: auto;
+}
+
diff --git a/core/http/static/assets/KFOlCnqEu92Fr1MmEU9fBBc9.ttf b/core/http/static/assets/KFOlCnqEu92Fr1MmEU9fBBc9.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..4f515e2adcfbfa3a34a3673ca3f3be6e61a32182
Binary files /dev/null and b/core/http/static/assets/KFOlCnqEu92Fr1MmEU9fBBc9.ttf differ
diff --git a/core/http/static/assets/KFOlCnqEu92Fr1MmEU9vAw.ttf b/core/http/static/assets/KFOlCnqEu92Fr1MmEU9vAw.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..6366a58648c1364b56c2d153d40b5504ae12d07e
--- /dev/null
+++ b/core/http/static/assets/KFOlCnqEu92Fr1MmEU9vAw.ttf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ecf88da1f85fa75dfce5aa0d9dd2973dd40e5702ce351d4de3ccfe58206044ce
+size 129768
diff --git a/core/http/static/assets/KFOlCnqEu92Fr1MmSU5fBBc9.ttf b/core/http/static/assets/KFOlCnqEu92Fr1MmSU5fBBc9.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..0ddede80e25591c1729c1bb9ea34ca168436bddf
Binary files /dev/null and b/core/http/static/assets/KFOlCnqEu92Fr1MmSU5fBBc9.ttf differ
diff --git a/core/http/static/assets/KFOlCnqEu92Fr1MmWUlfBBc9.ttf b/core/http/static/assets/KFOlCnqEu92Fr1MmWUlfBBc9.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..59830da823fa3776d4ae2eb24e1a7aa56424fa3a
Binary files /dev/null and b/core/http/static/assets/KFOlCnqEu92Fr1MmWUlfBBc9.ttf differ
diff --git a/core/http/static/assets/KFOlCnqEu92Fr1MmYUtfBBc9.ttf b/core/http/static/assets/KFOlCnqEu92Fr1MmYUtfBBc9.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..bc31a871fe11f57034a4fa340ea76ab46458639e
Binary files /dev/null and b/core/http/static/assets/KFOlCnqEu92Fr1MmYUtfBBc9.ttf differ
diff --git a/core/http/static/assets/KFOmCnqEu92Fr1Me5Q.ttf b/core/http/static/assets/KFOmCnqEu92Fr1Me5Q.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..8d10adcd1c995360d655646dbfaaee7f3c51569d
--- /dev/null
+++ b/core/http/static/assets/KFOmCnqEu92Fr1Me5Q.ttf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7277cfb805def6410f317129b8e1f78bdd47d1a4e24c233077d06e88a36e57ae
+size 129584
diff --git a/core/http/static/assets/KFOmCnqEu92Fr1Mu4mxP.ttf b/core/http/static/assets/KFOmCnqEu92Fr1Mu4mxP.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..d0e63254ce6733cc10961f6daafcaf4bf7222149
Binary files /dev/null and b/core/http/static/assets/KFOmCnqEu92Fr1Mu4mxP.ttf differ
diff --git a/core/http/static/assets/UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuFuYMZg.ttf b/core/http/static/assets/UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuFuYMZg.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..3e39f556c0a7bcf358beb8f85e966ef517dda0f6
--- /dev/null
+++ b/core/http/static/assets/UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuFuYMZg.ttf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5ee848665d6d9cec30648d49919e4fba35489ef648c8cbdaff181044d6d28ca8
+size 309760
diff --git a/core/http/static/assets/UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuGKYMZg.ttf b/core/http/static/assets/UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuGKYMZg.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..468a555db198b895f277f55feaaff4a683c67135
--- /dev/null
+++ b/core/http/static/assets/UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuGKYMZg.ttf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:702d9ba4c20991a732b767801ff996a93990a7d1a3a6954e521224de714c4b7c
+size 309404
diff --git a/core/http/static/assets/UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuLyfMZg.ttf b/core/http/static/assets/UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuLyfMZg.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..adb966eb75cbc9b6d2b1cf32f44969649525a898
--- /dev/null
+++ b/core/http/static/assets/UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuLyfMZg.ttf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02c6d2ce3eb535653060cf6105c31551ba740750a7fd8a3e084d8864d82b888d
+size 303412
diff --git a/core/http/static/assets/alpine.js b/core/http/static/assets/alpine.js
new file mode 100644
index 0000000000000000000000000000000000000000..fd09b4961dff02ae32a50272d9d6d11a5ab6cd0a
--- /dev/null
+++ b/core/http/static/assets/alpine.js
@@ -0,0 +1,5 @@
+(()=>{var rt=!1,nt=!1,U=[],it=-1;function qt(e){On(e)}function On(e){U.includes(e)||U.push(e),Cn()}function Ee(e){let t=U.indexOf(e);t!==-1&&t>it&&U.splice(t,1)}function Cn(){!nt&&!rt&&(rt=!0,queueMicrotask(Tn))}function Tn(){rt=!1,nt=!0;for(let e=0;ee.effect(t,{scheduler:r=>{ot?qt(r):r()}}),st=e.raw}function at(e){D=e}function Gt(e){let t=()=>{};return[n=>{let i=D(n);return e._x_effects||(e._x_effects=new Set,e._x_runEffects=()=>{e._x_effects.forEach(o=>o())}),e._x_effects.add(i),t=()=>{i!==void 0&&(e._x_effects.delete(i),L(i))},i},()=>{t()}]}function Se(e,t){let r=!0,n,i=D(()=>{let o=e();JSON.stringify(o),r?n=o:queueMicrotask(()=>{t(o,n),n=o}),r=!1});return()=>L(i)}var Jt=[],Yt=[],Xt=[];function Zt(e){Xt.push(e)}function ee(e,t){typeof t=="function"?(e._x_cleanups||(e._x_cleanups=[]),e._x_cleanups.push(t)):(t=e,Yt.push(t))}function Ae(e){Jt.push(e)}function Oe(e,t,r){e._x_attributeCleanups||(e._x_attributeCleanups={}),e._x_attributeCleanups[t]||(e._x_attributeCleanups[t]=[]),e._x_attributeCleanups[t].push(r)}function ct(e,t){e._x_attributeCleanups&&Object.entries(e._x_attributeCleanups).forEach(([r,n])=>{(t===void 0||t.includes(r))&&(n.forEach(i=>i()),delete e._x_attributeCleanups[r])})}function Qt(e){if(e._x_cleanups)for(;e._x_cleanups.length;)e._x_cleanups.pop()()}var lt=new MutationObserver(pt),ut=!1;function le(){lt.observe(document,{subtree:!0,childList:!0,attributes:!0,attributeOldValue:!0}),ut=!0}function ft(){Rn(),lt.disconnect(),ut=!1}var ce=[];function Rn(){let e=lt.takeRecords();ce.push(()=>e.length>0&&pt(e));let t=ce.length;queueMicrotask(()=>{if(ce.length===t)for(;ce.length>0;)ce.shift()()})}function _(e){if(!ut)return e();ft();let t=e();return le(),t}var dt=!1,ve=[];function er(){dt=!0}function tr(){dt=!1,pt(ve),ve=[]}function pt(e){if(dt){ve=ve.concat(e);return}let t=new Set,r=new Set,n=new Map,i=new Map;for(let o=0;os.nodeType===1&&t.add(s)),e[o].removedNodes.forEach(s=>s.nodeType===1&&r.add(s))),e[o].type==="attributes")){let s=e[o].target,a=e[o].attributeName,c=e[o].oldValue,l=()=>{n.has(s)||n.set(s,[]),n.get(s).push({name:a,value:s.getAttribute(a)})},u=()=>{i.has(s)||i.set(s,[]),i.get(s).push(a)};s.hasAttribute(a)&&c===null?l():s.hasAttribute(a)?(u(),l()):u()}i.forEach((o,s)=>{ct(s,o)}),n.forEach((o,s)=>{Jt.forEach(a=>a(s,o))});for(let o of r)t.has(o)||Yt.forEach(s=>s(o));t.forEach(o=>{o._x_ignoreSelf=!0,o._x_ignore=!0});for(let o of t)r.has(o)||o.isConnected&&(delete o._x_ignoreSelf,delete o._x_ignore,Xt.forEach(s=>s(o)),o._x_ignore=!0,o._x_ignoreSelf=!0);t.forEach(o=>{delete o._x_ignoreSelf,delete o._x_ignore}),t=null,r=null,n=null,i=null}function Ce(e){return F(j(e))}function P(e,t,r){return e._x_dataStack=[t,...j(r||e)],()=>{e._x_dataStack=e._x_dataStack.filter(n=>n!==t)}}function j(e){return e._x_dataStack?e._x_dataStack:typeof ShadowRoot=="function"&&e instanceof ShadowRoot?j(e.host):e.parentNode?j(e.parentNode):[]}function F(e){return new Proxy({objects:e},Mn)}var Mn={ownKeys({objects:e}){return Array.from(new Set(e.flatMap(t=>Object.keys(t))))},has({objects:e},t){return t==Symbol.unscopables?!1:e.some(r=>Object.prototype.hasOwnProperty.call(r,t)||Reflect.has(r,t))},get({objects:e},t,r){return t=="toJSON"?Nn:Reflect.get(e.find(n=>Reflect.has(n,t))||{},t,r)},set({objects:e},t,r,n){let i=e.find(s=>Object.prototype.hasOwnProperty.call(s,t))||e[e.length-1],o=Object.getOwnPropertyDescriptor(i,t);return o?.set&&o?.get?Reflect.set(i,t,r,n):Reflect.set(i,t,r)}};function Nn(){return Reflect.ownKeys(this).reduce((t,r)=>(t[r]=Reflect.get(this,r),t),{})}function Te(e){let t=n=>typeof n=="object"&&!Array.isArray(n)&&n!==null,r=(n,i="")=>{Object.entries(Object.getOwnPropertyDescriptors(n)).forEach(([o,{value:s,enumerable:a}])=>{if(a===!1||s===void 0||typeof s=="object"&&s!==null&&s.__v_skip)return;let c=i===""?o:`${i}.${o}`;typeof s=="object"&&s!==null&&s._x_interceptor?n[o]=s.initialize(e,c,o):t(s)&&s!==n&&!(s instanceof Element)&&r(s,c)})};return r(e)}function Re(e,t=()=>{}){let r={initialValue:void 0,_x_interceptor:!0,initialize(n,i,o){return e(this.initialValue,()=>Dn(n,i),s=>mt(n,i,s),i,o)}};return t(r),n=>{if(typeof n=="object"&&n!==null&&n._x_interceptor){let i=r.initialize.bind(r);r.initialize=(o,s,a)=>{let c=n.initialize(o,s,a);return r.initialValue=c,i(o,s,a)}}else r.initialValue=n;return r}}function Dn(e,t){return t.split(".").reduce((r,n)=>r[n],e)}function mt(e,t,r){if(typeof t=="string"&&(t=t.split(".")),t.length===1)e[t[0]]=r;else{if(t.length===0)throw error;return e[t[0]]||(e[t[0]]={}),mt(e[t[0]],t.slice(1),r)}}var rr={};function y(e,t){rr[e]=t}function ue(e,t){return Object.entries(rr).forEach(([r,n])=>{let i=null;function o(){if(i)return i;{let[s,a]=_t(t);return i={interceptor:Re,...s},ee(t,a),i}}Object.defineProperty(e,`$${r}`,{get(){return n(t,o())},enumerable:!1})}),e}function nr(e,t,r,...n){try{return r(...n)}catch(i){te(i,e,t)}}function te(e,t,r=void 0){e=Object.assign(e??{message:"No error message given."},{el:t,expression:r}),console.warn(`Alpine Expression Error: ${e.message}
+
+${r?'Expression: "'+r+`"
+
+`:""}`,t),setTimeout(()=>{throw e},0)}var Me=!0;function De(e){let t=Me;Me=!1;let r=e();return Me=t,r}function M(e,t,r={}){let n;return x(e,t)(i=>n=i,r),n}function x(...e){return ir(...e)}var ir=gt;function or(e){ir=e}function gt(e,t){let r={};ue(r,e);let n=[r,...j(e)],i=typeof t=="function"?Pn(n,t):kn(n,t,e);return nr.bind(null,e,t,i)}function Pn(e,t){return(r=()=>{},{scope:n={},params:i=[]}={})=>{let o=t.apply(F([n,...e]),i);Ne(r,o)}}var ht={};function In(e,t){if(ht[e])return ht[e];let r=Object.getPrototypeOf(async function(){}).constructor,n=/^[\n\s]*if.*\(.*\)/.test(e.trim())||/^(let|const)\s/.test(e.trim())?`(async()=>{ ${e} })()`:e,o=(()=>{try{let s=new r(["__self","scope"],`with (scope) { __self.result = ${n} }; __self.finished = true; return __self.result;`);return Object.defineProperty(s,"name",{value:`[Alpine] ${e}`}),s}catch(s){return te(s,t,e),Promise.resolve()}})();return ht[e]=o,o}function kn(e,t,r){let n=In(t,r);return(i=()=>{},{scope:o={},params:s=[]}={})=>{n.result=void 0,n.finished=!1;let a=F([o,...e]);if(typeof n=="function"){let c=n(n,a).catch(l=>te(l,r,t));n.finished?(Ne(i,n.result,a,s,r),n.result=void 0):c.then(l=>{Ne(i,l,a,s,r)}).catch(l=>te(l,r,t)).finally(()=>n.result=void 0)}}}function Ne(e,t,r,n,i){if(Me&&typeof t=="function"){let o=t.apply(r,n);o instanceof Promise?o.then(s=>Ne(e,s,r,n)).catch(s=>te(s,i,t)):e(o)}else typeof t=="object"&&t instanceof Promise?t.then(o=>e(o)):e(t)}var bt="x-";function C(e=""){return bt+e}function sr(e){bt=e}var Pe={};function d(e,t){return Pe[e]=t,{before(r){if(!Pe[r]){console.warn(String.raw`Cannot find directive \`${r}\`. \`${e}\` will use the default order of execution`);return}let n=W.indexOf(r);W.splice(n>=0?n:W.indexOf("DEFAULT"),0,e)}}}function ar(e){return Object.keys(Pe).includes(e)}function de(e,t,r){if(t=Array.from(t),e._x_virtualDirectives){let o=Object.entries(e._x_virtualDirectives).map(([a,c])=>({name:a,value:c})),s=wt(o);o=o.map(a=>s.find(c=>c.name===a.name)?{name:`x-bind:${a.name}`,value:`"${a.value}"`}:a),t=t.concat(o)}let n={};return t.map(ur((o,s)=>n[o]=s)).filter(dr).map($n(n,r)).sort(jn).map(o=>Ln(e,o))}function wt(e){return Array.from(e).map(ur()).filter(t=>!dr(t))}var xt=!1,fe=new Map,cr=Symbol();function lr(e){xt=!0;let t=Symbol();cr=t,fe.set(t,[]);let r=()=>{for(;fe.get(t).length;)fe.get(t).shift()();fe.delete(t)},n=()=>{xt=!1,r()};e(r),n()}function _t(e){let t=[],r=a=>t.push(a),[n,i]=Gt(e);return t.push(i),[{Alpine:B,effect:n,cleanup:r,evaluateLater:x.bind(x,e),evaluate:M.bind(M,e)},()=>t.forEach(a=>a())]}function Ln(e,t){let r=()=>{},n=Pe[t.type]||r,[i,o]=_t(e);Oe(e,t.original,o);let s=()=>{e._x_ignore||e._x_ignoreSelf||(n.inline&&n.inline(e,t,i),n=n.bind(n,e,t,i),xt?fe.get(cr).push(n):n())};return s.runCleanups=o,s}var Ie=(e,t)=>({name:r,value:n})=>(r.startsWith(e)&&(r=r.replace(e,t)),{name:r,value:n}),ke=e=>e;function ur(e=()=>{}){return({name:t,value:r})=>{let{name:n,value:i}=fr.reduce((o,s)=>s(o),{name:t,value:r});return n!==t&&e(n,t),{name:n,value:i}}}var fr=[];function re(e){fr.push(e)}function dr({name:e}){return pr().test(e)}var pr=()=>new RegExp(`^${bt}([^:^.]+)\\b`);function $n(e,t){return({name:r,value:n})=>{let i=r.match(pr()),o=r.match(/:([a-zA-Z0-9\-_:]+)/),s=r.match(/\.[^.\]]+(?=[^\]]*$)/g)||[],a=t||e[r]||r;return{type:i?i[1]:null,value:o?o[1]:null,modifiers:s.map(c=>c.replace(".","")),expression:n,original:a}}}var yt="DEFAULT",W=["ignore","ref","data","id","anchor","bind","init","for","model","modelable","transition","show","if",yt,"teleport"];function jn(e,t){let r=W.indexOf(e.type)===-1?yt:e.type,n=W.indexOf(t.type)===-1?yt:t.type;return W.indexOf(r)-W.indexOf(n)}function G(e,t,r={}){e.dispatchEvent(new CustomEvent(t,{detail:r,bubbles:!0,composed:!0,cancelable:!0}))}function T(e,t){if(typeof ShadowRoot=="function"&&e instanceof ShadowRoot){Array.from(e.children).forEach(i=>T(i,t));return}let r=!1;if(t(e,()=>r=!0),r)return;let n=e.firstElementChild;for(;n;)T(n,t,!1),n=n.nextElementSibling}function E(e,...t){console.warn(`Alpine Warning: ${e}`,...t)}var mr=!1;function _r(){mr&&E("Alpine has already been initialized on this page. Calling Alpine.start() more than once can cause problems."),mr=!0,document.body||E("Unable to initialize. Trying to load Alpine before `` is available. Did you forget to add `defer` in Alpine's `
+
+
+