|
|
package http_test |
|
|
|
|
|
import ( |
|
|
"bytes" |
|
|
"context" |
|
|
"encoding/json" |
|
|
"fmt" |
|
|
"io" |
|
|
"net/http" |
|
|
"os" |
|
|
"path/filepath" |
|
|
"runtime" |
|
|
"time" |
|
|
|
|
|
"github.com/mudler/LocalAI/core/application" |
|
|
"github.com/mudler/LocalAI/core/config" |
|
|
. "github.com/mudler/LocalAI/core/http" |
|
|
"github.com/mudler/LocalAI/core/schema" |
|
|
|
|
|
"github.com/labstack/echo/v4" |
|
|
"github.com/mudler/LocalAI/core/gallery" |
|
|
"github.com/mudler/LocalAI/pkg/downloader" |
|
|
"github.com/mudler/LocalAI/pkg/system" |
|
|
. "github.com/onsi/ginkgo/v2" |
|
|
. "github.com/onsi/gomega" |
|
|
"gopkg.in/yaml.v3" |
|
|
|
|
|
"github.com/mudler/xlog" |
|
|
openaigo "github.com/otiai10/openaigo" |
|
|
"github.com/sashabaranov/go-openai" |
|
|
"github.com/sashabaranov/go-openai/jsonschema" |
|
|
) |
|
|
|
|
|
// Shared credentials for the specs. apiKey is the key handed to
// config.WithApiKeys when a server is started with authentication enabled;
// bearerKey is the matching Authorization header value sent by the helpers.
const (
	apiKey    = "joshua"
	bearerKey = "Bearer " + apiKey
)
|
|
|
|
|
// testPrompt is a fixed instruction-style prompt ("### System:",
// "### Instruction:", "### Response:" sections asking the model to say hello)
// sent by the completion specs.
// NOTE(review): the blank lines inside this raw string literal are part of the
// prompt text itself; editing the layout changes what the model receives.
const testPrompt = `### System:


You are an AI assistant that follows instruction extremely well. Help as much as you can.





### Instruction:





Say hello.





### Response:`
|
|
|
|
|
// modelApplyRequest is the JSON body the specs POST to /models/apply.
// Field declaration order is kept as-is because encoding/json emits keys in
// field order.
type modelApplyRequest struct {
	// ID selects a gallery entry, e.g. "test@bert2" or "localai@<model>".
	ID string `json:"id"`
	// URL points directly at a model definition to install.
	URL string `json:"url"`
	// ConfigURL is serialized as "config_url"; not exercised by the visible specs.
	ConfigURL string `json:"config_url"`
	// Name is the name the installed model should be registered under.
	Name string `json:"name"`
	// Overrides are extra config keys written into the installed model YAML.
	Overrides map[string]interface{} `json:"overrides"`
}
|
|
|
|
|
func getModelStatus(url string) (response map[string]interface{}) { |
|
|
|
|
|
req, err := http.NewRequest("GET", url, nil) |
|
|
req.Header.Set("Content-Type", "application/json") |
|
|
req.Header.Set("Authorization", bearerKey) |
|
|
if err != nil { |
|
|
fmt.Println("Error creating request:", err) |
|
|
return |
|
|
} |
|
|
client := &http.Client{} |
|
|
resp, err := client.Do(req) |
|
|
if err != nil { |
|
|
fmt.Println("Error sending request:", err) |
|
|
return |
|
|
} |
|
|
defer resp.Body.Close() |
|
|
|
|
|
body, err := io.ReadAll(resp.Body) |
|
|
if err != nil { |
|
|
fmt.Println("Error reading response body:", err) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
err = json.Unmarshal(body, &response) |
|
|
if err != nil { |
|
|
fmt.Println("Error unmarshaling JSON response:", err) |
|
|
return |
|
|
} |
|
|
return |
|
|
} |
|
|
|
|
|
func getModels(url string) ([]gallery.GalleryModel, error) { |
|
|
response := []gallery.GalleryModel{} |
|
|
uri := downloader.URI(url) |
|
|
|
|
|
err := uri.ReadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error { |
|
|
|
|
|
return json.Unmarshal(i, &response) |
|
|
}) |
|
|
return response, err |
|
|
} |
|
|
|
|
|
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
payload, err := json.Marshal(request) |
|
|
if err != nil { |
|
|
fmt.Println("Error marshaling JSON:", err) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload)) |
|
|
if err != nil { |
|
|
fmt.Println("Error creating request:", err) |
|
|
return |
|
|
} |
|
|
req.Header.Set("Content-Type", "application/json") |
|
|
req.Header.Set("Authorization", bearerKey) |
|
|
|
|
|
|
|
|
client := &http.Client{} |
|
|
resp, err := client.Do(req) |
|
|
if err != nil { |
|
|
fmt.Println("Error making request:", err) |
|
|
return |
|
|
} |
|
|
defer resp.Body.Close() |
|
|
|
|
|
body, err := io.ReadAll(resp.Body) |
|
|
if err != nil { |
|
|
fmt.Println("Error reading response body:", err) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
err = json.Unmarshal(body, &response) |
|
|
if err != nil { |
|
|
fmt.Println("Error unmarshaling JSON response:", err) |
|
|
return |
|
|
} |
|
|
return |
|
|
} |
|
|
|
|
|
func postRequestJSON[B any](url string, bodyJson *B) error { |
|
|
payload, err := json.Marshal(bodyJson) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
|
|
|
GinkgoWriter.Printf("POST %s: %s\n", url, string(payload)) |
|
|
|
|
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload)) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
|
|
|
req.Header.Set("Content-Type", "application/json") |
|
|
req.Header.Set("Authorization", bearerKey) |
|
|
|
|
|
client := &http.Client{} |
|
|
resp, err := client.Do(req) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
|
|
|
defer resp.Body.Close() |
|
|
|
|
|
body, err := io.ReadAll(resp.Body) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 400 { |
|
|
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)) |
|
|
} |
|
|
|
|
|
return nil |
|
|
} |
|
|
|
|
|
func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *B2) error { |
|
|
payload, err := json.Marshal(reqJson) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
|
|
|
GinkgoWriter.Printf("POST %s: %s\n", url, string(payload)) |
|
|
|
|
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload)) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
|
|
|
req.Header.Set("Content-Type", "application/json") |
|
|
req.Header.Set("Authorization", bearerKey) |
|
|
|
|
|
client := &http.Client{} |
|
|
resp, err := client.Do(req) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
defer resp.Body.Close() |
|
|
|
|
|
body, err := io.ReadAll(resp.Body) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 400 { |
|
|
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)) |
|
|
} |
|
|
|
|
|
return json.Unmarshal(body, respJson) |
|
|
} |
|
|
|
|
|
func putRequestJSON[B any](url string, bodyJson *B) error { |
|
|
payload, err := json.Marshal(bodyJson) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
|
|
|
GinkgoWriter.Printf("PUT %s: %s\n", url, string(payload)) |
|
|
|
|
|
req, err := http.NewRequest("PUT", url, bytes.NewBuffer(payload)) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
|
|
|
req.Header.Set("Content-Type", "application/json") |
|
|
req.Header.Set("Authorization", bearerKey) |
|
|
|
|
|
client := &http.Client{} |
|
|
resp, err := client.Do(req) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
defer resp.Body.Close() |
|
|
|
|
|
body, err := io.ReadAll(resp.Body) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 400 { |
|
|
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)) |
|
|
} |
|
|
|
|
|
return nil |
|
|
} |
|
|
|
|
|
// postInvalidRequest POSTs the literal body "invalid request" to url with a
// JSON content type and no Authorization header. The auth spec relies on the
// missing header to provoke a 401.
//
// It returns an error together with the HTTP status code. The code is -1 when
// no response could be obtained or read. A status below 200 or at least 400
// is reported as an error that includes the response body.
// NOTE(review): the (error, int) result order is non-idiomatic but kept for
// existing callers.
func postInvalidRequest(url string) (error, int) {
	httpReq, err := http.NewRequest(http.MethodPost, url, bytes.NewBufferString("invalid request"))
	if err != nil {
		return err, -1
	}
	httpReq.Header.Set("Content-Type", "application/json")

	httpResp, err := new(http.Client).Do(httpReq)
	if err != nil {
		return err, -1
	}
	defer httpResp.Body.Close()

	raw, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return err, -1
	}

	code := httpResp.StatusCode
	if code < 200 || code >= 400 {
		return fmt.Errorf("unexpected status code: %d, body: %s", code, string(raw)), code
	}
	return nil, code
}
|
|
|
|
|
// getRequest performs a GET against url using exactly the supplied headers.
// The request's header map is replaced, not merged.
//
// It returns any error, the HTTP status code, and the raw response body. On
// failure the code is -1 and the body is nil. Unlike the POST helpers, it
// does not treat any status code as an error.
// NOTE(review): the (error, int, []byte) result order is non-idiomatic but
// kept for existing callers.
func getRequest(url string, header http.Header) (error, int, []byte) {
	httpReq, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return err, -1, nil
	}
	httpReq.Header = header

	httpResp, err := new(http.Client).Do(httpReq)
	if err != nil {
		return err, -1, nil
	}
	defer httpResp.Body.Close()

	raw, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return err, -1, nil
	}
	return nil, httpResp.StatusCode, raw
}
|
|
|
|
|
// bertEmbeddingsURL is a raw gist URL, pinned to a specific revision, that
// serves a gallery model definition ("bert-embeddings-gallery.yaml"). The
// ephemeral-model specs install models from it.
const bertEmbeddingsURL = "https://gist.githubusercontent.com/mudler/0a080b166b87640e8644b09c2aee6e3b/raw/f0e8c26bb72edc16d9fbafbfd6638072126ff225/bert-embeddings-gallery.yaml"
|
|
|
|
|
var _ = Describe("API test", func() { |
|
|
|
|
|
var app *echo.Echo |
|
|
var client *openai.Client |
|
|
var client2 *openaigo.Client |
|
|
var c context.Context |
|
|
var cancel context.CancelFunc |
|
|
var tmpdir string |
|
|
var modelDir string |
|
|
|
|
|
commonOpts := []config.AppOption{ |
|
|
config.WithDebug(true), |
|
|
} |
|
|
|
|
|
Context("API with ephemeral models", func() { |
|
|
|
|
|
BeforeEach(func(sc SpecContext) { |
|
|
var err error |
|
|
tmpdir, err = os.MkdirTemp("", "") |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
backendPath := os.Getenv("BACKENDS_PATH") |
|
|
|
|
|
modelDir = filepath.Join(tmpdir, "models") |
|
|
err = os.Mkdir(modelDir, 0750) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
c, cancel = context.WithCancel(context.Background()) |
|
|
|
|
|
g := []gallery.GalleryModel{ |
|
|
{ |
|
|
Metadata: gallery.Metadata{ |
|
|
Name: "bert", |
|
|
URL: bertEmbeddingsURL, |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
Metadata: gallery.Metadata{ |
|
|
Name: "bert2", |
|
|
URL: bertEmbeddingsURL, |
|
|
AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: bertEmbeddingsURL}}, |
|
|
}, |
|
|
Overrides: map[string]interface{}{"foo": "bar"}, |
|
|
}, |
|
|
} |
|
|
out, err := yaml.Marshal(g) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
err = os.WriteFile(filepath.Join(modelDir, "gallery_simple.yaml"), out, 0600) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
galleries := []config.Gallery{ |
|
|
{ |
|
|
Name: "test", |
|
|
URL: "file://" + filepath.Join(modelDir, "gallery_simple.yaml"), |
|
|
}, |
|
|
} |
|
|
|
|
|
systemState, err := system.GetSystemState( |
|
|
system.WithBackendPath(backendPath), |
|
|
system.WithModelPath(modelDir), |
|
|
) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
application, err := application.New( |
|
|
append(commonOpts, |
|
|
config.WithContext(c), |
|
|
config.WithSystemState(systemState), |
|
|
config.WithGalleries(galleries), |
|
|
config.WithApiKeys([]string{apiKey}), |
|
|
)...) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
app, err = API(application) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
go func() { |
|
|
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { |
|
|
xlog.Error("server error", "error", err) |
|
|
} |
|
|
}() |
|
|
|
|
|
defaultConfig := openai.DefaultConfig(apiKey) |
|
|
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" |
|
|
|
|
|
client2 = openaigo.NewClient("") |
|
|
client2.BaseURL = defaultConfig.BaseURL |
|
|
|
|
|
|
|
|
client = openai.NewClientWithConfig(defaultConfig) |
|
|
Eventually(func() error { |
|
|
_, err := client.ListModels(context.TODO()) |
|
|
return err |
|
|
}, "2m").ShouldNot(HaveOccurred()) |
|
|
}) |
|
|
|
|
|
AfterEach(func(sc SpecContext) { |
|
|
cancel() |
|
|
if app != nil { |
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) |
|
|
defer cancel() |
|
|
err := app.Shutdown(ctx) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
} |
|
|
err := os.RemoveAll(tmpdir) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
_, err = os.ReadDir(tmpdir) |
|
|
Expect(err).To(HaveOccurred()) |
|
|
}) |
|
|
|
|
|
Context("Auth Tests", func() { |
|
|
It("Should fail if the api key is missing", func() { |
|
|
err, sc := postInvalidRequest("http://127.0.0.1:9090/models/available") |
|
|
Expect(err).ToNot(BeNil()) |
|
|
Expect(sc).To(Equal(401)) |
|
|
}) |
|
|
}) |
|
|
|
|
|
Context("URL routing Tests", func() { |
|
|
It("Should support reverse-proxy when unauthenticated", func() { |
|
|
|
|
|
err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{ |
|
|
"X-Forwarded-Proto": {"https"}, |
|
|
"X-Forwarded-Host": {"example.org"}, |
|
|
"X-Forwarded-Prefix": {"/myprefix/"}, |
|
|
}) |
|
|
Expect(err).To(BeNil(), "error") |
|
|
Expect(sc).To(Equal(401), "status code") |
|
|
Expect(string(body)).To(ContainSubstring(`<base href="https://example.org/myprefix/" />`), "body") |
|
|
}) |
|
|
|
|
|
It("Should support reverse-proxy when authenticated", func() { |
|
|
|
|
|
err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{ |
|
|
"Authorization": {bearerKey}, |
|
|
"X-Forwarded-Proto": {"https"}, |
|
|
"X-Forwarded-Host": {"example.org"}, |
|
|
"X-Forwarded-Prefix": {"/myprefix/"}, |
|
|
}) |
|
|
Expect(err).To(BeNil(), "error") |
|
|
Expect(sc).To(Equal(200), "status code") |
|
|
Expect(string(body)).To(ContainSubstring(`<base href="https://example.org/myprefix/" />`), "body") |
|
|
}) |
|
|
}) |
|
|
|
|
|
Context("Applying models", func() { |
|
|
|
|
|
It("applies models from a gallery", func() { |
|
|
models, err := getModels("http://127.0.0.1:9090/models/available") |
|
|
Expect(err).To(BeNil()) |
|
|
Expect(len(models)).To(Equal(2), fmt.Sprint(models)) |
|
|
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models)) |
|
|
Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models)) |
|
|
|
|
|
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ |
|
|
ID: "test@bert2", |
|
|
}) |
|
|
|
|
|
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) |
|
|
|
|
|
uuid := response["uuid"].(string) |
|
|
resp := map[string]interface{}{} |
|
|
Eventually(func() bool { |
|
|
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) |
|
|
fmt.Println(response) |
|
|
resp = response |
|
|
return response["processed"].(bool) |
|
|
}, "360s", "10s").Should(Equal(true)) |
|
|
Expect(resp["message"]).ToNot(ContainSubstring("error")) |
|
|
|
|
|
dat, err := os.ReadFile(filepath.Join(modelDir, "bert2.yaml")) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
_, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml")) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
content := map[string]interface{}{} |
|
|
err = yaml.Unmarshal(dat, &content) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this")) |
|
|
Expect(content["foo"]).To(Equal("bar")) |
|
|
|
|
|
models, err = getModels("http://127.0.0.1:9090/models/available") |
|
|
Expect(err).To(BeNil()) |
|
|
Expect(len(models)).To(Equal(2), fmt.Sprint(models)) |
|
|
Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2"))) |
|
|
Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2"))) |
|
|
for _, m := range models { |
|
|
if m.Name == "bert2" { |
|
|
Expect(m.Installed).To(BeTrue()) |
|
|
} else { |
|
|
Expect(m.Installed).To(BeFalse()) |
|
|
} |
|
|
} |
|
|
}) |
|
|
It("overrides models", func() { |
|
|
|
|
|
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ |
|
|
URL: bertEmbeddingsURL, |
|
|
Name: "bert", |
|
|
Overrides: map[string]interface{}{ |
|
|
"backend": "llama", |
|
|
}, |
|
|
}) |
|
|
|
|
|
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) |
|
|
|
|
|
uuid := response["uuid"].(string) |
|
|
|
|
|
Eventually(func() bool { |
|
|
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) |
|
|
return response["processed"].(bool) |
|
|
}, "360s", "10s").Should(Equal(true)) |
|
|
|
|
|
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml")) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
content := map[string]interface{}{} |
|
|
err = yaml.Unmarshal(dat, &content) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
Expect(content["backend"]).To(Equal("llama")) |
|
|
}) |
|
|
It("apply models without overrides", func() { |
|
|
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ |
|
|
URL: bertEmbeddingsURL, |
|
|
Name: "bert", |
|
|
Overrides: map[string]interface{}{}, |
|
|
}) |
|
|
|
|
|
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) |
|
|
|
|
|
uuid := response["uuid"].(string) |
|
|
|
|
|
Eventually(func() bool { |
|
|
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) |
|
|
return response["processed"].(bool) |
|
|
}, "360s", "10s").Should(Equal(true)) |
|
|
|
|
|
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml")) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
content := map[string]interface{}{} |
|
|
err = yaml.Unmarshal(dat, &content) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this")) |
|
|
}) |
|
|
|
|
|
}) |
|
|
|
|
|
Context("Importing models from URI", func() { |
|
|
var testYamlFile string |
|
|
|
|
|
BeforeEach(func() { |
|
|
|
|
|
yamlContent := `name: test-import-model |
|
|
backend: llama-cpp |
|
|
description: Test model imported from file URI |
|
|
parameters: |
|
|
model: path/to/model.gguf |
|
|
temperature: 0.7 |
|
|
` |
|
|
testYamlFile = filepath.Join(tmpdir, "test-import.yaml") |
|
|
err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
}) |
|
|
|
|
|
AfterEach(func() { |
|
|
err := os.Remove(testYamlFile) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
}) |
|
|
|
|
|
It("should import model from file:// URI pointing to local YAML config", func() { |
|
|
importReq := schema.ImportModelRequest{ |
|
|
URI: "file://" + testYamlFile, |
|
|
Preferences: json.RawMessage(`{}`), |
|
|
} |
|
|
|
|
|
var response schema.GalleryResponse |
|
|
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
Expect(response.ID).ToNot(BeEmpty()) |
|
|
|
|
|
uuid := response.ID |
|
|
resp := map[string]interface{}{} |
|
|
Eventually(func() bool { |
|
|
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) |
|
|
resp = response |
|
|
return response["processed"].(bool) |
|
|
}, "360s", "10s").Should(Equal(true)) |
|
|
|
|
|
|
|
|
Expect(resp["message"]).ToNot(ContainSubstring("error")) |
|
|
Expect(resp["error"]).To(BeNil()) |
|
|
|
|
|
|
|
|
dat, err := os.ReadFile(filepath.Join(modelDir, "test-import-model.yaml")) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
content := map[string]interface{}{} |
|
|
err = yaml.Unmarshal(dat, &content) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
Expect(content["name"]).To(Equal("test-import-model")) |
|
|
Expect(content["backend"]).To(Equal("llama-cpp")) |
|
|
}) |
|
|
|
|
|
It("should return error when file:// URI points to non-existent file", func() { |
|
|
nonExistentFile := filepath.Join(tmpdir, "nonexistent.yaml") |
|
|
importReq := schema.ImportModelRequest{ |
|
|
URI: "file://" + nonExistentFile, |
|
|
Preferences: json.RawMessage(`{}`), |
|
|
} |
|
|
|
|
|
var response schema.GalleryResponse |
|
|
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) |
|
|
|
|
|
Expect(err).To(HaveOccurred()) |
|
|
Expect(err.Error()).To(ContainSubstring("failed to discover model config")) |
|
|
}) |
|
|
}) |
|
|
|
|
|
Context("Importing models from URI can't point to absolute paths", func() { |
|
|
var testYamlFile string |
|
|
|
|
|
BeforeEach(func() { |
|
|
|
|
|
yamlContent := `name: test-import-model |
|
|
backend: llama-cpp |
|
|
description: Test model imported from file URI |
|
|
parameters: |
|
|
model: /path/to/model.gguf |
|
|
temperature: 0.7 |
|
|
` |
|
|
testYamlFile = filepath.Join(tmpdir, "test-import.yaml") |
|
|
err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
}) |
|
|
|
|
|
AfterEach(func() { |
|
|
err := os.Remove(testYamlFile) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
}) |
|
|
|
|
|
It("should fail to import model from file:// URI pointing to local YAML config", func() { |
|
|
importReq := schema.ImportModelRequest{ |
|
|
URI: "file://" + testYamlFile, |
|
|
Preferences: json.RawMessage(`{}`), |
|
|
} |
|
|
|
|
|
var response schema.GalleryResponse |
|
|
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
Expect(response.ID).ToNot(BeEmpty()) |
|
|
|
|
|
uuid := response.ID |
|
|
resp := map[string]interface{}{} |
|
|
Eventually(func() bool { |
|
|
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) |
|
|
resp = response |
|
|
return response["processed"].(bool) |
|
|
}, "360s", "10s").Should(Equal(true)) |
|
|
|
|
|
|
|
|
Expect(resp["message"]).To(ContainSubstring("error")) |
|
|
Expect(resp["error"]).ToNot(BeNil()) |
|
|
}) |
|
|
}) |
|
|
}) |
|
|
|
|
|
Context("Model gallery", func() { |
|
|
BeforeEach(func() { |
|
|
var err error |
|
|
tmpdir, err = os.MkdirTemp("", "") |
|
|
|
|
|
backendPath := os.Getenv("BACKENDS_PATH") |
|
|
|
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
modelDir = filepath.Join(tmpdir, "models") |
|
|
backendAssetsDir := filepath.Join(tmpdir, "backend-assets") |
|
|
err = os.Mkdir(backendAssetsDir, 0750) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
c, cancel = context.WithCancel(context.Background()) |
|
|
|
|
|
galleries := []config.Gallery{ |
|
|
{ |
|
|
Name: "localai", |
|
|
URL: "https://raw.githubusercontent.com/mudler/LocalAI/refs/heads/master/gallery/index.yaml", |
|
|
}, |
|
|
} |
|
|
|
|
|
systemState, err := system.GetSystemState( |
|
|
system.WithBackendPath(backendPath), |
|
|
system.WithModelPath(modelDir), |
|
|
) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
application, err := application.New( |
|
|
append(commonOpts, |
|
|
config.WithContext(c), |
|
|
config.WithGeneratedContentDir(tmpdir), |
|
|
config.WithSystemState(systemState), |
|
|
config.WithGalleries(galleries), |
|
|
)..., |
|
|
) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
app, err = API(application) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
go func() { |
|
|
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { |
|
|
xlog.Error("server error", "error", err) |
|
|
} |
|
|
}() |
|
|
|
|
|
defaultConfig := openai.DefaultConfig("") |
|
|
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" |
|
|
|
|
|
client2 = openaigo.NewClient("") |
|
|
client2.BaseURL = defaultConfig.BaseURL |
|
|
|
|
|
|
|
|
client = openai.NewClientWithConfig(defaultConfig) |
|
|
Eventually(func() error { |
|
|
_, err := client.ListModels(context.TODO()) |
|
|
return err |
|
|
}, "2m").ShouldNot(HaveOccurred()) |
|
|
}) |
|
|
|
|
|
AfterEach(func() { |
|
|
cancel() |
|
|
if app != nil { |
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) |
|
|
defer cancel() |
|
|
err := app.Shutdown(ctx) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
} |
|
|
err := os.RemoveAll(tmpdir) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
_, err = os.ReadDir(tmpdir) |
|
|
Expect(err).To(HaveOccurred()) |
|
|
}) |
|
|
|
|
|
It("runs gguf models (chat)", Label("llama-gguf"), func() { |
|
|
if runtime.GOOS != "linux" { |
|
|
Skip("test supported only on linux") |
|
|
} |
|
|
|
|
|
modelName := "qwen3-1.7b" |
|
|
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ |
|
|
ID: "localai@" + modelName, |
|
|
}) |
|
|
|
|
|
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) |
|
|
|
|
|
uuid := response["uuid"].(string) |
|
|
|
|
|
Eventually(func() bool { |
|
|
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) |
|
|
return response["processed"].(bool) |
|
|
}, "900s", "10s").Should(Equal(true)) |
|
|
|
|
|
By("testing chat") |
|
|
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: modelName, Messages: []openai.ChatCompletionMessage{ |
|
|
{ |
|
|
Role: "user", |
|
|
Content: "How much is 2+2?", |
|
|
}, |
|
|
}}) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
Expect(len(resp.Choices)).To(Equal(1)) |
|
|
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("4"), ContainSubstring("four"))) |
|
|
|
|
|
By("testing functions") |
|
|
resp2, err := client.CreateChatCompletion( |
|
|
context.TODO(), |
|
|
openai.ChatCompletionRequest{ |
|
|
Model: modelName, |
|
|
Messages: []openai.ChatCompletionMessage{ |
|
|
{ |
|
|
Role: "user", |
|
|
Content: "What is the weather like in San Francisco (celsius)?", |
|
|
}, |
|
|
}, |
|
|
Functions: []openai.FunctionDefinition{ |
|
|
openai.FunctionDefinition{ |
|
|
Name: "get_current_weather", |
|
|
Description: "Get the current weather", |
|
|
Parameters: jsonschema.Definition{ |
|
|
Type: jsonschema.Object, |
|
|
Properties: map[string]jsonschema.Definition{ |
|
|
"location": { |
|
|
Type: jsonschema.String, |
|
|
Description: "The city and state, e.g. San Francisco, CA", |
|
|
}, |
|
|
"unit": { |
|
|
Type: jsonschema.String, |
|
|
Enum: []string{"celcius", "fahrenheit"}, |
|
|
}, |
|
|
}, |
|
|
Required: []string{"location"}, |
|
|
}, |
|
|
}, |
|
|
}, |
|
|
}) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
Expect(len(resp2.Choices)).To(Equal(1)) |
|
|
Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil()) |
|
|
Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name) |
|
|
|
|
|
var res map[string]string |
|
|
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
Expect(res["location"]).To(ContainSubstring("San Francisco"), fmt.Sprint(res)) |
|
|
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res)) |
|
|
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason)) |
|
|
}) |
|
|
|
|
|
It("installs and is capable to run tts", Label("tts"), func() { |
|
|
if runtime.GOOS != "linux" { |
|
|
Skip("test supported only on linux") |
|
|
} |
|
|
|
|
|
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ |
|
|
ID: "localai@voice-en-us-kathleen-low", |
|
|
}) |
|
|
|
|
|
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) |
|
|
|
|
|
uuid := response["uuid"].(string) |
|
|
|
|
|
Eventually(func() bool { |
|
|
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) |
|
|
fmt.Println(response) |
|
|
return response["processed"].(bool) |
|
|
}, "360s", "10s").Should(Equal(true)) |
|
|
|
|
|
|
|
|
resp, err := http.Post("http://127.0.0.1:9090/tts", "application/json", bytes.NewBuffer([]byte(`{"input": "Hello world", "model": "voice-en-us-kathleen-low"}`))) |
|
|
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp)) |
|
|
dat, err := io.ReadAll(resp.Body) |
|
|
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp)) |
|
|
|
|
|
Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat))) |
|
|
Expect(resp.Header.Get("Content-Type")).To(Or(Equal("audio/x-wav"), Equal("audio/vnd.wave"))) |
|
|
}) |
|
|
It("installs and is capable to generate images", Label("stablediffusion"), func() { |
|
|
if runtime.GOOS != "linux" { |
|
|
Skip("test supported only on linux") |
|
|
} |
|
|
|
|
|
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ |
|
|
ID: "localai@sd-1.5-ggml", |
|
|
Name: "stablediffusion", |
|
|
}) |
|
|
|
|
|
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) |
|
|
|
|
|
uuid := response["uuid"].(string) |
|
|
|
|
|
Eventually(func() bool { |
|
|
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) |
|
|
fmt.Println(response) |
|
|
return response["processed"].(bool) |
|
|
}, "1200s", "10s").Should(Equal(true)) |
|
|
|
|
|
resp, err := http.Post( |
|
|
"http://127.0.0.1:9090/v1/images/generations", |
|
|
"application/json", |
|
|
bytes.NewBuffer([]byte(`{ |
|
|
"prompt": "a lovely cat", |
|
|
"step": 1, "seed":9000, |
|
|
"size": "256x256", "n":2}`))) |
|
|
|
|
|
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp)) |
|
|
dat, err := io.ReadAll(resp.Body) |
|
|
Expect(err).ToNot(HaveOccurred(), "error reading /image/generations response") |
|
|
|
|
|
imgUrlResp := &schema.OpenAIResponse{} |
|
|
err = json.Unmarshal(dat, imgUrlResp) |
|
|
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(dat)) |
|
|
Expect(imgUrlResp.Data).ToNot(Or(BeNil(), BeZero())) |
|
|
imgUrl := imgUrlResp.Data[0].URL |
|
|
Expect(imgUrl).To(ContainSubstring("http://127.0.0.1:9090/"), imgUrl) |
|
|
Expect(imgUrl).To(ContainSubstring(".png"), imgUrl) |
|
|
|
|
|
imgResp, err := http.Get(imgUrl) |
|
|
Expect(err).To(BeNil()) |
|
|
Expect(imgResp).ToNot(BeNil()) |
|
|
Expect(imgResp.StatusCode).To(Equal(200)) |
|
|
Expect(imgResp.ContentLength).To(BeNumerically(">", 0)) |
|
|
imgData := make([]byte, 512) |
|
|
count, err := io.ReadFull(imgResp.Body, imgData) |
|
|
Expect(err).To(Or(BeNil(), MatchError(io.EOF))) |
|
|
Expect(count).To(BeNumerically(">", 0)) |
|
|
Expect(count).To(BeNumerically("<=", 512)) |
|
|
Expect(http.DetectContentType(imgData)).To(Equal("image/png")) |
|
|
}) |
|
|
}) |
|
|
|
|
|
Context("API query", func() {
	// getJSON performs a plain (unauthenticated) GET against url, asserts a
	// 200 response and decodes the JSON body into out. The body is read before
	// the status check so a failure message shows what the server returned,
	// and it is always closed so the client connection is released even when
	// an assertion aborts the spec.
	getJSON := func(url string, out interface{}) {
		resp, err := http.Get(url)
		ExpectWithOffset(1, err).ToNot(HaveOccurred())
		defer resp.Body.Close()

		body, err := io.ReadAll(resp.Body)
		ExpectWithOffset(1, err).ToNot(HaveOccurred())
		ExpectWithOffset(1, resp.StatusCode).To(Equal(http.StatusOK), string(body))

		err = json.Unmarshal(body, out)
		ExpectWithOffset(1, err).ToNot(HaveOccurred(), string(body))
	}

	// sendAuthorized issues a body-less request carrying the test bearer
	// token, asserts a 200 response, and closes the response body.
	sendAuthorized := func(method, url string) {
		req, err := http.NewRequest(method, url, nil)
		ExpectWithOffset(1, err).ToNot(HaveOccurred())
		req.Header.Set("Authorization", bearerKey)

		resp, err := http.DefaultClient.Do(req)
		ExpectWithOffset(1, err).ToNot(HaveOccurred())
		defer resp.Body.Close()
		ExpectWithOffset(1, resp.StatusCode).To(Equal(http.StatusOK))
	}

	BeforeEach(func() {
		modelPath := os.Getenv("MODELS_PATH")
		backendPath := os.Getenv("BACKENDS_PATH")
		c, cancel = context.WithCancel(context.Background())

		systemState, err := system.GetSystemState(
			system.WithBackendPath(backendPath),
			system.WithModelPath(modelPath),
		)
		Expect(err).ToNot(HaveOccurred())

		// Named localAI (not "application") so the imported application
		// package is not shadowed inside this closure.
		localAI, err := application.New(
			append(commonOpts,
				config.WithExternalBackend("transformers", os.Getenv("HUGGINGFACE_GRPC")),
				config.WithContext(c),
				config.WithSystemState(systemState),
			)...)
		Expect(err).ToNot(HaveOccurred())

		app, err = API(localAI)
		Expect(err).ToNot(HaveOccurred())

		go func() {
			if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
				xlog.Error("server error", "error", err)
			}
		}()

		defaultConfig := openai.DefaultConfig("")
		defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"

		client2 = openaigo.NewClient("")
		client2.BaseURL = defaultConfig.BaseURL

		client = openai.NewClientWithConfig(defaultConfig)

		// Wait until the server answers before running any spec.
		Eventually(func() error {
			_, err := client.ListModels(context.TODO())
			return err
		}, "2m").ShouldNot(HaveOccurred())
	})

	AfterEach(func() {
		cancel()
		if app != nil {
			shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer shutdownCancel()
			err := app.Shutdown(shutdownCtx)
			Expect(err).ToNot(HaveOccurred())
		}
	})

	It("returns the models list", func() {
		models, err := client.ListModels(context.TODO())
		Expect(err).ToNot(HaveOccurred())
		Expect(models.Models).To(HaveLen(7))
	})

	It("can generate completions via ggml", func() {
		if runtime.GOOS != "linux" {
			Skip("test supported only on linux")
		}
		resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel.ggml", Prompt: testPrompt})
		Expect(err).ToNot(HaveOccurred())
		Expect(resp.Choices).To(HaveLen(1))
		Expect(resp.Choices[0].Text).ToNot(BeEmpty())
	})

	It("can generate chat completions via ggml", func() {
		if runtime.GOOS != "linux" {
			Skip("test supported only on linux")
		}
		resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "testmodel.ggml", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
		Expect(err).ToNot(HaveOccurred())
		Expect(resp.Choices).To(HaveLen(1))
		Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
	})

	It("returns logprobs in chat completions when requested", func() {
		if runtime.GOOS != "linux" {
			Skip("test only on linux")
		}
		topLogprobsVal := 3
		response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
			Model:       "testmodel.ggml",
			LogProbs:    true,
			TopLogProbs: topLogprobsVal,
			Messages:    []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
		Expect(err).ToNot(HaveOccurred())

		Expect(response.Choices).To(HaveLen(1))
		choice := response.Choices[0]
		Expect(choice.Message.Content).ToNot(BeEmpty())

		Expect(choice.LogProbs).ToNot(BeNil())
		// More than one generated token is expected for this prompt.
		Expect(len(choice.LogProbs.Content)).To(BeNumerically(">", 1))

		// Track the most recent non-empty token/bytes seen, both for the
		// chosen tokens and for their top alternatives.
		var lastToken, lastTopToken string
		var lastBytes, lastTopBytes []byte

		for _, lp := range choice.LogProbs.Content {
			if len(lp.Bytes) > 0 {
				lastBytes = lp.Bytes
			}
			if len(lp.Token) > 0 {
				lastToken = lp.Token
			}
			Expect(lp.LogProb).To(BeNumerically("<=", 0))

			// Alternatives must be present and capped by the requested count.
			Expect(len(lp.TopLogProbs)).To(BeNumerically(">", 1))
			Expect(len(lp.TopLogProbs)).To(BeNumerically("<=", topLogprobsVal))

			for _, top := range lp.TopLogProbs {
				if len(top.Bytes) > 0 {
					lastTopBytes = top.Bytes
				}
				if len(top.Token) > 0 {
					lastTopToken = top.Token
				}
				Expect(top.LogProb).To(BeNumerically("<=", 0))
			}
		}

		Expect(lastBytes).ToNot(BeEmpty())
		Expect(lastTopBytes).ToNot(BeEmpty())
		Expect(lastToken).ToNot(BeEmpty())
		Expect(lastTopToken).ToNot(BeEmpty())
	})

	It("applies logit_bias to chat completions when requested", func() {
		if runtime.GOOS != "linux" {
			Skip("test only on linux")
		}
		logitBias := map[string]int{
			"15043": 1,
		}
		response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
			Model:     "testmodel.ggml",
			Messages:  []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}},
			LogitBias: logitBias,
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(response.Choices).To(HaveLen(1))
		Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
	})

	It("returns errors", func() {
		_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt})
		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(ContainSubstring("error, status code: 500, status: 500 Internal Server Error, message: could not load model - all backends returned error:"))
	})

	It("shows the external backend", func() {
		if runtime.GOOS != "linux" {
			Skip("test supported only on linux")
		}
		resp, err := http.Get("http://127.0.0.1:9090/system")
		Expect(err).ToNot(HaveOccurred())
		defer resp.Body.Close()
		Expect(resp.StatusCode).To(Equal(200))

		dat, err := io.ReadAll(resp.Body)
		Expect(err).ToNot(HaveOccurred())
		Expect(string(dat)).To(ContainSubstring("huggingface"))
		Expect(string(dat)).To(ContainSubstring("llama-cpp"))
	})

	It("transcribes audio", func() {
		if runtime.GOOS != "linux" {
			Skip("test supported only on linux")
		}
		resp, err := client.CreateTranscription(
			context.Background(),
			openai.AudioRequest{
				Model:    openai.Whisper1,
				FilePath: filepath.Join(os.Getenv("TEST_DIR"), "audio.wav"),
			},
		)
		Expect(err).ToNot(HaveOccurred())
		Expect(resp.Text).To(ContainSubstring("This is the Micro Machine Man presenting"))
	})

	It("calculate embeddings", func() {
		if runtime.GOOS != "linux" {
			Skip("test supported only on linux")
		}
		embeddingModel := openai.AdaEmbeddingV2

		// embed requests embeddings for input and asserts one vector per input
		// was returned, so later indexing cannot panic.
		embed := func(input ...string) openai.EmbeddingResponse {
			res, err := client.CreateEmbeddings(
				context.Background(),
				openai.EmbeddingRequest{
					Model: embeddingModel,
					Input: input,
				},
			)
			ExpectWithOffset(1, err).ToNot(HaveOccurred())
			ExpectWithOffset(1, res.Data).To(HaveLen(len(input)))
			return res
		}

		resp := embed("sun", "cat")
		Expect(resp.Data[0].Embedding).To(HaveLen(4096))
		Expect(resp.Data[1].Embedding).To(HaveLen(4096))
		sunEmbedding := resp.Data[0].Embedding
		catEmbedding := resp.Data[1].Embedding

		// Embeddings must be deterministic and independent of batching.
		resp2 := embed("sun")
		Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
		Expect(resp2.Data[0].Embedding).ToNot(Equal(catEmbedding))

		resp3 := embed("cat")
		Expect(resp3.Data[0].Embedding).To(Equal(catEmbedding))
		Expect(resp3.Data[0].Embedding).ToNot(Equal(sunEmbedding))
	})

	Context("External gRPC calls", func() {
		It("calculate embeddings with sentencetransformers", func() {
			if runtime.GOOS != "linux" {
				Skip("test supported only on linux")
			}

			// embed requests embeddings for input and asserts one vector per
			// input was returned, so later indexing cannot panic.
			embed := func(input ...string) openai.EmbeddingResponse {
				res, err := client.CreateEmbeddings(
					context.Background(),
					openai.EmbeddingRequest{
						Model: openai.AdaCodeSearchCode,
						Input: input,
					},
				)
				ExpectWithOffset(1, err).ToNot(HaveOccurred())
				ExpectWithOffset(1, res.Data).To(HaveLen(len(input)))
				return res
			}

			resp := embed("sun", "cat")
			Expect(resp.Data[0].Embedding).To(HaveLen(384))
			Expect(resp.Data[1].Embedding).To(HaveLen(384))
			sunEmbedding := resp.Data[0].Embedding

			resp2 := embed("sun")
			Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
			Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
		})
	})

	Context("Stores", Label("stores"), func() {
		BeforeEach(func() {
			if runtime.GOOS != "linux" {
				Skip("test supported only on linux")
			}
		})

		It("sets, gets, finds and deletes entries", func() {
			ks := [][]float32{
				{0.1, 0.2, 0.3},
				{0.4, 0.5, 0.6},
				{0.7, 0.8, 0.9},
			}
			vs := []string{"test1", "test2", "test3"}
			storesURL := "http://127.0.0.1:9090/stores/"

			// Store three key/value pairs.
			err := postRequestJSON(storesURL+"set", &schema.StoresSet{Keys: ks, Values: vs})
			Expect(err).ToNot(HaveOccurred())

			// Read them all back; results may come back in any order, so each
			// value is matched against its key's first component.
			var getRespBody schema.StoresGetResponse
			err = postRequestResponseJSON(storesURL+"get", &schema.StoresGet{Keys: ks}, &getRespBody)
			Expect(err).ToNot(HaveOccurred())
			Expect(getRespBody.Keys).To(HaveLen(len(ks)))
			Expect(getRespBody.Values).To(HaveLen(len(ks)))
			for i, k := range getRespBody.Keys {
				switch k[0] {
				case 0.1:
					Expect(getRespBody.Values[i]).To(Equal("test1"))
				case 0.4:
					Expect(getRespBody.Values[i]).To(Equal("test2"))
				default:
					Expect(getRespBody.Values[i]).To(Equal("test3"))
				}
			}

			// Remove the first entry.
			deleteBody := schema.StoresDelete{
				Keys: [][]float32{
					{0.1, 0.2, 0.3},
				},
			}
			err = postRequestJSON(storesURL+"delete", &deleteBody)
			Expect(err).ToNot(HaveOccurred())

			// A nearest-neighbour search must now return only the two
			// remaining entries.
			findBody := schema.StoresFind{
				Key:  []float32{0.1, 0.3, 0.7},
				Topk: 10,
			}
			var findRespBody schema.StoresFindResponse
			err = postRequestResponseJSON(storesURL+"find", &findBody, &findRespBody)
			Expect(err).ToNot(HaveOccurred())
			Expect(findRespBody.Keys).To(HaveLen(2))
			Expect(findRespBody.Values).To(HaveLen(2))
			Expect(findRespBody.Similarities).To(HaveLen(2))

			for i, k := range findRespBody.Keys {
				switch k[0] {
				case 0.4:
					Expect(findRespBody.Values[i]).To(Equal("test2"))
				default:
					Expect(findRespBody.Values[i]).To(Equal("test3"))
				}
				Expect(findRespBody.Similarities[i]).To(BeNumerically(">=", -1))
				Expect(findRespBody.Similarities[i]).To(BeNumerically("<=", 1))
			}
		})

		Context("Agent Jobs", Label("agent-jobs"), func() {
			const agentAPI = "http://127.0.0.1:9090/api/agent"

			It("creates and manages tasks", func() {
				taskBody := map[string]interface{}{
					"name":        "Test Task",
					"description": "Test Description",
					"model":       "testmodel.ggml",
					"prompt":      "Hello {{.name}}",
					"enabled":     true,
				}

				// Create the task.
				var createResp map[string]interface{}
				err := postRequestResponseJSON(agentAPI+"/tasks", &taskBody, &createResp)
				Expect(err).ToNot(HaveOccurred())
				taskID, ok := createResp["id"].(string)
				Expect(ok).To(BeTrue(), fmt.Sprint(createResp))
				Expect(taskID).ToNot(BeEmpty())

				// Fetch it back by ID.
				var task schema.Task
				getJSON(agentAPI+"/tasks/"+taskID, &task)
				Expect(task.Name).To(Equal("Test Task"))

				// The listing must contain at least the task just created.
				var tasks []schema.Task
				getJSON(agentAPI+"/tasks", &tasks)
				Expect(len(tasks)).To(BeNumerically(">=", 1))

				// Rename it and confirm the change was persisted.
				taskBody["name"] = "Updated Task"
				err = putRequestJSON(agentAPI+"/tasks/"+taskID, &taskBody)
				Expect(err).ToNot(HaveOccurred())

				var updated schema.Task
				getJSON(agentAPI+"/tasks/"+taskID, &updated)
				Expect(updated.Name).To(Equal("Updated Task"))

				// Finally delete it.
				sendAuthorized("DELETE", agentAPI+"/tasks/"+taskID)
			})

			It("executes and monitors jobs", func() {
				taskBody := map[string]interface{}{
					"name":    "Job Test Task",
					"model":   "testmodel.ggml",
					"prompt":  "Say hello",
					"enabled": true,
				}

				var createResp map[string]interface{}
				err := postRequestResponseJSON(agentAPI+"/tasks", &taskBody, &createResp)
				Expect(err).ToNot(HaveOccurred())
				taskID, ok := createResp["id"].(string)
				Expect(ok).To(BeTrue(), fmt.Sprint(createResp))

				// Start a job for the task.
				jobBody := map[string]interface{}{
					"task_id":    taskID,
					"parameters": map[string]string{},
				}
				var jobResp schema.JobExecutionResponse
				err = postRequestResponseJSON(agentAPI+"/jobs/execute", &jobBody, &jobResp)
				Expect(err).ToNot(HaveOccurred())
				Expect(jobResp.JobID).ToNot(BeEmpty())
				jobID := jobResp.JobID

				// The job is retrievable by ID and linked to its task.
				var job schema.Job
				getJSON(agentAPI+"/jobs/"+jobID, &job)
				Expect(job.ID).To(Equal(jobID))
				Expect(job.TaskID).To(Equal(taskID))

				// The job listing is non-empty.
				var jobs []schema.Job
				getJSON(agentAPI+"/jobs", &jobs)
				Expect(len(jobs)).To(BeNumerically(">=", 1))

				// Cancel only if the job had not finished when it was fetched.
				if job.Status == schema.JobStatusPending || job.Status == schema.JobStatusRunning {
					sendAuthorized("POST", agentAPI+"/jobs/"+jobID+"/cancel")
				}
			})

			It("executes task by name", func() {
				taskBody := map[string]interface{}{
					"name":    "Named Task",
					"model":   "testmodel.ggml",
					"prompt":  "Hello",
					"enabled": true,
				}

				var createResp map[string]interface{}
				err := postRequestResponseJSON(agentAPI+"/tasks", &taskBody, &createResp)
				Expect(err).ToNot(HaveOccurred())

				// Execute the task by its name rather than its ID.
				paramsBody := map[string]string{"param1": "value1"}
				var jobResp schema.JobExecutionResponse
				err = postRequestResponseJSON(agentAPI+"/tasks/Named Task/execute", &paramsBody, &jobResp)
				Expect(err).ToNot(HaveOccurred())
				Expect(jobResp.JobID).ToNot(BeEmpty())
			})
		})
	})
})
|
|
|
|
|
Context("Config file", func() { |
|
|
BeforeEach(func() { |
|
|
if runtime.GOOS != "linux" { |
|
|
Skip("run this test only on linux") |
|
|
} |
|
|
modelPath := os.Getenv("MODELS_PATH") |
|
|
backendPath := os.Getenv("BACKENDS_PATH") |
|
|
c, cancel = context.WithCancel(context.Background()) |
|
|
|
|
|
var err error |
|
|
|
|
|
systemState, err := system.GetSystemState( |
|
|
system.WithBackendPath(backendPath), |
|
|
system.WithModelPath(modelPath), |
|
|
) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
application, err := application.New( |
|
|
append(commonOpts, |
|
|
config.WithContext(c), |
|
|
config.WithSystemState(systemState), |
|
|
config.WithConfigFile(os.Getenv("CONFIG_FILE")))..., |
|
|
) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
app, err = API(application) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
go func() { |
|
|
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { |
|
|
xlog.Error("server error", "error", err) |
|
|
} |
|
|
}() |
|
|
|
|
|
defaultConfig := openai.DefaultConfig("") |
|
|
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" |
|
|
client2 = openaigo.NewClient("") |
|
|
client2.BaseURL = defaultConfig.BaseURL |
|
|
|
|
|
client = openai.NewClientWithConfig(defaultConfig) |
|
|
Eventually(func() error { |
|
|
_, err := client.ListModels(context.TODO()) |
|
|
return err |
|
|
}, "2m").ShouldNot(HaveOccurred()) |
|
|
}) |
|
|
AfterEach(func() { |
|
|
cancel() |
|
|
if app != nil { |
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) |
|
|
defer cancel() |
|
|
err := app.Shutdown(ctx) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
} |
|
|
}) |
|
|
It("can generate chat completions from config file (list1)", func() { |
|
|
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
Expect(len(resp.Choices)).To(Equal(1)) |
|
|
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) |
|
|
}) |
|
|
It("can generate chat completions from config file (list2)", func() { |
|
|
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
Expect(len(resp.Choices)).To(Equal(1)) |
|
|
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) |
|
|
}) |
|
|
It("can generate edit completions from config file", func() { |
|
|
request := openaigo.EditCreateRequestBody{ |
|
|
Model: "list2", |
|
|
Instruction: "foo", |
|
|
Input: "bar", |
|
|
} |
|
|
resp, err := client2.CreateEdit(context.Background(), request) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
Expect(len(resp.Choices)).To(Equal(1)) |
|
|
Expect(resp.Choices[0].Text).ToNot(BeEmpty()) |
|
|
}) |
|
|
|
|
|
}) |
|
|
}) |
|
|
|