Spaces:
Running
Running
Amlan-109
feat: Initial commit of LocalAI Amlan Edition with premium branding and personalization
750bbe6 | package http_test | |
| import ( | |
| "bytes" | |
| "context" | |
| "encoding/json" | |
| "fmt" | |
| "io" | |
| "net/http" | |
| "os" | |
| "path/filepath" | |
| "runtime" | |
| "time" | |
| "github.com/mudler/LocalAI/core/application" | |
| "github.com/mudler/LocalAI/core/config" | |
| . "github.com/mudler/LocalAI/core/http" | |
| "github.com/mudler/LocalAI/core/schema" | |
| "github.com/labstack/echo/v4" | |
| "github.com/mudler/LocalAI/core/gallery" | |
| "github.com/mudler/LocalAI/pkg/downloader" | |
| "github.com/mudler/LocalAI/pkg/system" | |
| . "github.com/onsi/ginkgo/v2" | |
| . "github.com/onsi/gomega" | |
| "gopkg.in/yaml.v3" | |
| "github.com/mudler/xlog" | |
| openaigo "github.com/otiai10/openaigo" | |
| "github.com/sashabaranov/go-openai" | |
| "github.com/sashabaranov/go-openai/jsonschema" | |
| ) | |
// Credentials used by the request helpers and specs in this file.
const (
	// apiKey is the raw API key.
	apiKey = "joshua"
	// bearerKey is the Authorization header value built from apiKey.
	bearerKey = "Bearer " + apiKey
)
| const testPrompt = `### System: | |
| You are an AI assistant that follows instruction extremely well. Help as much as you can. | |
| ### Instruction: | |
| Say hello. | |
| ### Response:` | |
// modelApplyRequest is the JSON payload that postModelApplyRequest sends when
// asking the server to apply (install) a model.
type modelApplyRequest struct {
	// ID names a gallery entry; the specs in this file pass values of the
	// form "<gallery>@<model>", e.g. "test@bert2".
	ID string `json:"id"`
	// URL is the location of a model definition to apply.
	URL string `json:"url"`
	// ConfigURL is serialized as "config_url".
	// NOTE(review): not set by any spec visible here; semantics unconfirmed.
	ConfigURL string `json:"config_url"`
	// Name is the name to give the applied model.
	Name string `json:"name"`
	// Overrides holds configuration keys sent along with the request,
	// e.g. {"backend": "llama"}.
	Overrides map[string]any `json:"overrides"`
}
| func getModelStatus(url string) (response map[string]interface{}) { | |
| // Create the HTTP request | |
| req, err := http.NewRequest("GET", url, nil) | |
| req.Header.Set("Content-Type", "application/json") | |
| req.Header.Set("Authorization", bearerKey) | |
| if err != nil { | |
| fmt.Println("Error creating request:", err) | |
| return | |
| } | |
| client := &http.Client{} | |
| resp, err := client.Do(req) | |
| if err != nil { | |
| fmt.Println("Error sending request:", err) | |
| return | |
| } | |
| defer resp.Body.Close() | |
| body, err := io.ReadAll(resp.Body) | |
| if err != nil { | |
| fmt.Println("Error reading response body:", err) | |
| return | |
| } | |
| // Unmarshal the response into a map[string]interface{} | |
| err = json.Unmarshal(body, &response) | |
| if err != nil { | |
| fmt.Println("Error unmarshaling JSON response:", err) | |
| return | |
| } | |
| return | |
| } | |
| func getModels(url string) ([]gallery.GalleryModel, error) { | |
| response := []gallery.GalleryModel{} | |
| uri := downloader.URI(url) | |
| // TODO: No tests currently seem to exercise file:// urls. Fix? | |
| err := uri.ReadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error { | |
| // Unmarshal YAML data into a struct | |
| return json.Unmarshal(i, &response) | |
| }) | |
| return response, err | |
| } | |
| func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) { | |
| //url := "http://localhost:AI/models/apply" | |
| // Create the request payload | |
| payload, err := json.Marshal(request) | |
| if err != nil { | |
| fmt.Println("Error marshaling JSON:", err) | |
| return | |
| } | |
| // Create the HTTP request | |
| req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload)) | |
| if err != nil { | |
| fmt.Println("Error creating request:", err) | |
| return | |
| } | |
| req.Header.Set("Content-Type", "application/json") | |
| req.Header.Set("Authorization", bearerKey) | |
| // Make the request | |
| client := &http.Client{} | |
| resp, err := client.Do(req) | |
| if err != nil { | |
| fmt.Println("Error making request:", err) | |
| return | |
| } | |
| defer resp.Body.Close() | |
| body, err := io.ReadAll(resp.Body) | |
| if err != nil { | |
| fmt.Println("Error reading response body:", err) | |
| return | |
| } | |
| // Unmarshal the response into a map[string]interface{} | |
| err = json.Unmarshal(body, &response) | |
| if err != nil { | |
| fmt.Println("Error unmarshaling JSON response:", err) | |
| return | |
| } | |
| return | |
| } | |
| func postRequestJSON[B any](url string, bodyJson *B) error { | |
| payload, err := json.Marshal(bodyJson) | |
| if err != nil { | |
| return err | |
| } | |
| GinkgoWriter.Printf("POST %s: %s\n", url, string(payload)) | |
| req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload)) | |
| if err != nil { | |
| return err | |
| } | |
| req.Header.Set("Content-Type", "application/json") | |
| req.Header.Set("Authorization", bearerKey) | |
| client := &http.Client{} | |
| resp, err := client.Do(req) | |
| if err != nil { | |
| return err | |
| } | |
| defer resp.Body.Close() | |
| body, err := io.ReadAll(resp.Body) | |
| if err != nil { | |
| return err | |
| } | |
| if resp.StatusCode < 200 || resp.StatusCode >= 400 { | |
| return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)) | |
| } | |
| return nil | |
| } | |
| func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *B2) error { | |
| payload, err := json.Marshal(reqJson) | |
| if err != nil { | |
| return err | |
| } | |
| GinkgoWriter.Printf("POST %s: %s\n", url, string(payload)) | |
| req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload)) | |
| if err != nil { | |
| return err | |
| } | |
| req.Header.Set("Content-Type", "application/json") | |
| req.Header.Set("Authorization", bearerKey) | |
| client := &http.Client{} | |
| resp, err := client.Do(req) | |
| if err != nil { | |
| return err | |
| } | |
| defer resp.Body.Close() | |
| body, err := io.ReadAll(resp.Body) | |
| if err != nil { | |
| return err | |
| } | |
| if resp.StatusCode < 200 || resp.StatusCode >= 400 { | |
| return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)) | |
| } | |
| return json.Unmarshal(body, respJson) | |
| } | |
| func putRequestJSON[B any](url string, bodyJson *B) error { | |
| payload, err := json.Marshal(bodyJson) | |
| if err != nil { | |
| return err | |
| } | |
| GinkgoWriter.Printf("PUT %s: %s\n", url, string(payload)) | |
| req, err := http.NewRequest("PUT", url, bytes.NewBuffer(payload)) | |
| if err != nil { | |
| return err | |
| } | |
| req.Header.Set("Content-Type", "application/json") | |
| req.Header.Set("Authorization", bearerKey) | |
| client := &http.Client{} | |
| resp, err := client.Do(req) | |
| if err != nil { | |
| return err | |
| } | |
| defer resp.Body.Close() | |
| body, err := io.ReadAll(resp.Body) | |
| if err != nil { | |
| return err | |
| } | |
| if resp.StatusCode < 200 || resp.StatusCode >= 400 { | |
| return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)) | |
| } | |
| return nil | |
| } | |
// postInvalidRequest POSTs the literal text "invalid request" to url with a
// JSON Content-Type and without an Authorization header.
//
// It returns an error together with the HTTP status code. The status is -1
// when the request could not be built, sent, or read. Otherwise it is the
// server's status code, and the error is non-nil when that code lies outside
// the 200-399 range.
func postInvalidRequest(url string) (error, int) {
	const noStatus = -1

	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader([]byte("invalid request")))
	if err != nil {
		return err, noStatus
	}
	req.Header.Set("Content-Type", "application/json")

	var httpClient http.Client
	resp, err := httpClient.Do(req)
	if err != nil {
		return err, noStatus
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return err, noStatus
	}

	status := resp.StatusCode
	if status >= 200 && status < 400 {
		return nil, status
	}
	return fmt.Errorf("unexpected status code: %d, body: %s", status, string(raw)), status
}
// getRequest issues a GET to url using header as the complete set of request
// headers and returns, in order, an error, the HTTP status code, and the
// response body.
//
// When the request cannot be built, sent, or read, the error is non-nil, the
// status is -1 and the body is nil. The status code itself is never treated
// as an error here; callers inspect it.
func getRequest(url string, header http.Header) (error, int, []byte) {
	// failed packages the common "transport-level failure" result.
	failed := func(cause error) (error, int, []byte) {
		return cause, -1, nil
	}

	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return failed(err)
	}
	// Replace the header set wholesale rather than merging into it.
	req.Header = header

	var httpClient http.Client
	resp, err := httpClient.Do(req)
	if err != nil {
		return failed(err)
	}
	defer resp.Body.Close()

	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return failed(err)
	}

	return nil, resp.StatusCode, payload
}
// bertEmbeddingsURL is the address of the bert-embeddings gallery YAML that
// the specs use both as a gallery entry URL and as a direct apply URL.
const bertEmbeddingsURL = "https://gist.githubusercontent.com/mudler/0a080b166b87640e8644b09c2aee6e3b/raw/f0e8c26bb72edc16d9fbafbfd6638072126ff225/bert-embeddings-gallery.yaml"
| var _ = Describe("API test", func() { | |
| var app *echo.Echo | |
| var client *openai.Client | |
| var client2 *openaigo.Client | |
| var c context.Context | |
| var cancel context.CancelFunc | |
| var tmpdir string | |
| var modelDir string | |
| commonOpts := []config.AppOption{ | |
| config.WithDebug(true), | |
| } | |
| Context("API with ephemeral models", func() { | |
| BeforeEach(func(sc SpecContext) { | |
| var err error | |
| tmpdir, err = os.MkdirTemp("", "") | |
| Expect(err).ToNot(HaveOccurred()) | |
| backendPath := os.Getenv("BACKENDS_PATH") | |
| modelDir = filepath.Join(tmpdir, "models") | |
| err = os.Mkdir(modelDir, 0750) | |
| Expect(err).ToNot(HaveOccurred()) | |
| c, cancel = context.WithCancel(context.Background()) | |
| g := []gallery.GalleryModel{ | |
| { | |
| Metadata: gallery.Metadata{ | |
| Name: "bert", | |
| URL: bertEmbeddingsURL, | |
| }, | |
| }, | |
| { | |
| Metadata: gallery.Metadata{ | |
| Name: "bert2", | |
| URL: bertEmbeddingsURL, | |
| AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: bertEmbeddingsURL}}, | |
| }, | |
| Overrides: map[string]interface{}{"foo": "bar"}, | |
| }, | |
| } | |
| out, err := yaml.Marshal(g) | |
| Expect(err).ToNot(HaveOccurred()) | |
| err = os.WriteFile(filepath.Join(modelDir, "gallery_simple.yaml"), out, 0600) | |
| Expect(err).ToNot(HaveOccurred()) | |
| galleries := []config.Gallery{ | |
| { | |
| Name: "test", | |
| URL: "file://" + filepath.Join(modelDir, "gallery_simple.yaml"), | |
| }, | |
| } | |
| systemState, err := system.GetSystemState( | |
| system.WithBackendPath(backendPath), | |
| system.WithModelPath(modelDir), | |
| ) | |
| Expect(err).ToNot(HaveOccurred()) | |
| application, err := application.New( | |
| append(commonOpts, | |
| config.WithContext(c), | |
| config.WithSystemState(systemState), | |
| config.WithGalleries(galleries), | |
| config.WithApiKeys([]string{apiKey}), | |
| )...) | |
| Expect(err).ToNot(HaveOccurred()) | |
| app, err = API(application) | |
| Expect(err).ToNot(HaveOccurred()) | |
| go func() { | |
| if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { | |
| xlog.Error("server error", "error", err) | |
| } | |
| }() | |
| defaultConfig := openai.DefaultConfig(apiKey) | |
| defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" | |
| client2 = openaigo.NewClient("") | |
| client2.BaseURL = defaultConfig.BaseURL | |
| // Wait for API to be ready | |
| client = openai.NewClientWithConfig(defaultConfig) | |
| Eventually(func() error { | |
| _, err := client.ListModels(context.TODO()) | |
| return err | |
| }, "2m").ShouldNot(HaveOccurred()) | |
| }) | |
| AfterEach(func(sc SpecContext) { | |
| cancel() | |
| if app != nil { | |
| ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | |
| defer cancel() | |
| err := app.Shutdown(ctx) | |
| Expect(err).ToNot(HaveOccurred()) | |
| } | |
| err := os.RemoveAll(tmpdir) | |
| Expect(err).ToNot(HaveOccurred()) | |
| _, err = os.ReadDir(tmpdir) | |
| Expect(err).To(HaveOccurred()) | |
| }) | |
| Context("Auth Tests", func() { | |
| It("Should fail if the api key is missing", func() { | |
| err, sc := postInvalidRequest("http://127.0.0.1:9090/models/available") | |
| Expect(err).ToNot(BeNil()) | |
| Expect(sc).To(Equal(401)) | |
| }) | |
| }) | |
| Context("URL routing Tests", func() { | |
| It("Should support reverse-proxy when unauthenticated", func() { | |
| err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{ | |
| "X-Forwarded-Proto": {"https"}, | |
| "X-Forwarded-Host": {"example.org"}, | |
| "X-Forwarded-Prefix": {"/myprefix/"}, | |
| }) | |
| Expect(err).To(BeNil(), "error") | |
| Expect(sc).To(Equal(401), "status code") | |
| Expect(string(body)).To(ContainSubstring(`<base href="https://example.org/myprefix/" />`), "body") | |
| }) | |
| It("Should support reverse-proxy when authenticated", func() { | |
| err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{ | |
| "Authorization": {bearerKey}, | |
| "X-Forwarded-Proto": {"https"}, | |
| "X-Forwarded-Host": {"example.org"}, | |
| "X-Forwarded-Prefix": {"/myprefix/"}, | |
| }) | |
| Expect(err).To(BeNil(), "error") | |
| Expect(sc).To(Equal(200), "status code") | |
| Expect(string(body)).To(ContainSubstring(`<base href="https://example.org/myprefix/" />`), "body") | |
| }) | |
| }) | |
| Context("Applying models", func() { | |
| It("applies models from a gallery", func() { | |
| models, err := getModels("http://127.0.0.1:9090/models/available") | |
| Expect(err).To(BeNil()) | |
| Expect(len(models)).To(Equal(2), fmt.Sprint(models)) | |
| Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models)) | |
| Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models)) | |
| response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ | |
| ID: "test@bert2", | |
| }) | |
| Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) | |
| uuid := response["uuid"].(string) | |
| resp := map[string]interface{}{} | |
| Eventually(func() bool { | |
| response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) | |
| fmt.Println(response) | |
| resp = response | |
| return response["processed"].(bool) | |
| }, "360s", "10s").Should(Equal(true)) | |
| Expect(resp["message"]).ToNot(ContainSubstring("error")) | |
| dat, err := os.ReadFile(filepath.Join(modelDir, "bert2.yaml")) | |
| Expect(err).ToNot(HaveOccurred()) | |
| _, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml")) | |
| Expect(err).ToNot(HaveOccurred()) | |
| content := map[string]interface{}{} | |
| err = yaml.Unmarshal(dat, &content) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this")) | |
| Expect(content["foo"]).To(Equal("bar")) | |
| models, err = getModels("http://127.0.0.1:9090/models/available") | |
| Expect(err).To(BeNil()) | |
| Expect(len(models)).To(Equal(2), fmt.Sprint(models)) | |
| Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2"))) | |
| Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2"))) | |
| for _, m := range models { | |
| if m.Name == "bert2" { | |
| Expect(m.Installed).To(BeTrue()) | |
| } else { | |
| Expect(m.Installed).To(BeFalse()) | |
| } | |
| } | |
| }) | |
| It("overrides models", func() { | |
| response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ | |
| URL: bertEmbeddingsURL, | |
| Name: "bert", | |
| Overrides: map[string]interface{}{ | |
| "backend": "llama", | |
| }, | |
| }) | |
| Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) | |
| uuid := response["uuid"].(string) | |
| Eventually(func() bool { | |
| response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) | |
| return response["processed"].(bool) | |
| }, "360s", "10s").Should(Equal(true)) | |
| dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml")) | |
| Expect(err).ToNot(HaveOccurred()) | |
| content := map[string]interface{}{} | |
| err = yaml.Unmarshal(dat, &content) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(content["backend"]).To(Equal("llama")) | |
| }) | |
| It("apply models without overrides", func() { | |
| response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ | |
| URL: bertEmbeddingsURL, | |
| Name: "bert", | |
| Overrides: map[string]interface{}{}, | |
| }) | |
| Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) | |
| uuid := response["uuid"].(string) | |
| Eventually(func() bool { | |
| response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) | |
| return response["processed"].(bool) | |
| }, "360s", "10s").Should(Equal(true)) | |
| dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml")) | |
| Expect(err).ToNot(HaveOccurred()) | |
| content := map[string]interface{}{} | |
| err = yaml.Unmarshal(dat, &content) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this")) | |
| }) | |
| }) | |
| Context("Importing models from URI", func() { | |
| var testYamlFile string | |
| BeforeEach(func() { | |
| // Create a test YAML config file | |
| yamlContent := `name: test-import-model | |
| backend: llama-cpp | |
| description: Test model imported from file URI | |
| parameters: | |
| model: path/to/model.gguf | |
| temperature: 0.7 | |
| ` | |
| testYamlFile = filepath.Join(tmpdir, "test-import.yaml") | |
| err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644) | |
| Expect(err).ToNot(HaveOccurred()) | |
| }) | |
| AfterEach(func() { | |
| err := os.Remove(testYamlFile) | |
| Expect(err).ToNot(HaveOccurred()) | |
| }) | |
| It("should import model from file:// URI pointing to local YAML config", func() { | |
| importReq := schema.ImportModelRequest{ | |
| URI: "file://" + testYamlFile, | |
| Preferences: json.RawMessage(`{}`), | |
| } | |
| var response schema.GalleryResponse | |
| err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(response.ID).ToNot(BeEmpty()) | |
| uuid := response.ID | |
| resp := map[string]interface{}{} | |
| Eventually(func() bool { | |
| response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) | |
| resp = response | |
| return response["processed"].(bool) | |
| }, "360s", "10s").Should(Equal(true)) | |
| // Check that the model was imported successfully | |
| Expect(resp["message"]).ToNot(ContainSubstring("error")) | |
| Expect(resp["error"]).To(BeNil()) | |
| // Verify the model config file was created | |
| dat, err := os.ReadFile(filepath.Join(modelDir, "test-import-model.yaml")) | |
| Expect(err).ToNot(HaveOccurred()) | |
| content := map[string]interface{}{} | |
| err = yaml.Unmarshal(dat, &content) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(content["name"]).To(Equal("test-import-model")) | |
| Expect(content["backend"]).To(Equal("llama-cpp")) | |
| }) | |
| It("should return error when file:// URI points to non-existent file", func() { | |
| nonExistentFile := filepath.Join(tmpdir, "nonexistent.yaml") | |
| importReq := schema.ImportModelRequest{ | |
| URI: "file://" + nonExistentFile, | |
| Preferences: json.RawMessage(`{}`), | |
| } | |
| var response schema.GalleryResponse | |
| err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) | |
| // The endpoint should return an error immediately | |
| Expect(err).To(HaveOccurred()) | |
| Expect(err.Error()).To(ContainSubstring("failed to discover model config")) | |
| }) | |
| }) | |
| Context("Importing models from URI can't point to absolute paths", func() { | |
| var testYamlFile string | |
| BeforeEach(func() { | |
| // Create a test YAML config file | |
| yamlContent := `name: test-import-model | |
| backend: llama-cpp | |
| description: Test model imported from file URI | |
| parameters: | |
| model: /path/to/model.gguf | |
| temperature: 0.7 | |
| ` | |
| testYamlFile = filepath.Join(tmpdir, "test-import.yaml") | |
| err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644) | |
| Expect(err).ToNot(HaveOccurred()) | |
| }) | |
| AfterEach(func() { | |
| err := os.Remove(testYamlFile) | |
| Expect(err).ToNot(HaveOccurred()) | |
| }) | |
| It("should fail to import model from file:// URI pointing to local YAML config", func() { | |
| importReq := schema.ImportModelRequest{ | |
| URI: "file://" + testYamlFile, | |
| Preferences: json.RawMessage(`{}`), | |
| } | |
| var response schema.GalleryResponse | |
| err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(response.ID).ToNot(BeEmpty()) | |
| uuid := response.ID | |
| resp := map[string]interface{}{} | |
| Eventually(func() bool { | |
| response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) | |
| resp = response | |
| return response["processed"].(bool) | |
| }, "360s", "10s").Should(Equal(true)) | |
| // Check that the model was imported successfully | |
| Expect(resp["message"]).To(ContainSubstring("error")) | |
| Expect(resp["error"]).ToNot(BeNil()) | |
| }) | |
| }) | |
| }) | |
| Context("Model gallery", func() { | |
| BeforeEach(func() { | |
| var err error | |
| tmpdir, err = os.MkdirTemp("", "") | |
| backendPath := os.Getenv("BACKENDS_PATH") | |
| Expect(err).ToNot(HaveOccurred()) | |
| modelDir = filepath.Join(tmpdir, "models") | |
| backendAssetsDir := filepath.Join(tmpdir, "backend-assets") | |
| err = os.Mkdir(backendAssetsDir, 0750) | |
| Expect(err).ToNot(HaveOccurred()) | |
| c, cancel = context.WithCancel(context.Background()) | |
| galleries := []config.Gallery{ | |
| { | |
| Name: "localai", | |
| URL: "https://raw.githubusercontent.com/mudler/LocalAI/refs/heads/master/gallery/index.yaml", | |
| }, | |
| } | |
| systemState, err := system.GetSystemState( | |
| system.WithBackendPath(backendPath), | |
| system.WithModelPath(modelDir), | |
| ) | |
| Expect(err).ToNot(HaveOccurred()) | |
| application, err := application.New( | |
| append(commonOpts, | |
| config.WithContext(c), | |
| config.WithGeneratedContentDir(tmpdir), | |
| config.WithSystemState(systemState), | |
| config.WithGalleries(galleries), | |
| )..., | |
| ) | |
| Expect(err).ToNot(HaveOccurred()) | |
| app, err = API(application) | |
| Expect(err).ToNot(HaveOccurred()) | |
| go func() { | |
| if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { | |
| xlog.Error("server error", "error", err) | |
| } | |
| }() | |
| defaultConfig := openai.DefaultConfig("") | |
| defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" | |
| client2 = openaigo.NewClient("") | |
| client2.BaseURL = defaultConfig.BaseURL | |
| // Wait for API to be ready | |
| client = openai.NewClientWithConfig(defaultConfig) | |
| Eventually(func() error { | |
| _, err := client.ListModels(context.TODO()) | |
| return err | |
| }, "2m").ShouldNot(HaveOccurred()) | |
| }) | |
| AfterEach(func() { | |
| cancel() | |
| if app != nil { | |
| ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | |
| defer cancel() | |
| err := app.Shutdown(ctx) | |
| Expect(err).ToNot(HaveOccurred()) | |
| } | |
| err := os.RemoveAll(tmpdir) | |
| Expect(err).ToNot(HaveOccurred()) | |
| _, err = os.ReadDir(tmpdir) | |
| Expect(err).To(HaveOccurred()) | |
| }) | |
| It("runs gguf models (chat)", Label("llama-gguf"), func() { | |
| if runtime.GOOS != "linux" { | |
| Skip("test supported only on linux") | |
| } | |
| modelName := "qwen3-1.7b" | |
| response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ | |
| ID: "localai@" + modelName, | |
| }) | |
| Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) | |
| uuid := response["uuid"].(string) | |
| Eventually(func() bool { | |
| response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) | |
| return response["processed"].(bool) | |
| }, "900s", "10s").Should(Equal(true)) | |
| By("testing chat") | |
| resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: modelName, Messages: []openai.ChatCompletionMessage{ | |
| { | |
| Role: "user", | |
| Content: "How much is 2+2?", | |
| }, | |
| }}) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(len(resp.Choices)).To(Equal(1)) | |
| Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("4"), ContainSubstring("four"))) | |
| By("testing functions") | |
| resp2, err := client.CreateChatCompletion( | |
| context.TODO(), | |
| openai.ChatCompletionRequest{ | |
| Model: modelName, | |
| Messages: []openai.ChatCompletionMessage{ | |
| { | |
| Role: "user", | |
| Content: "What is the weather like in San Francisco (celsius)?", | |
| }, | |
| }, | |
| Functions: []openai.FunctionDefinition{ | |
| openai.FunctionDefinition{ | |
| Name: "get_current_weather", | |
| Description: "Get the current weather", | |
| Parameters: jsonschema.Definition{ | |
| Type: jsonschema.Object, | |
| Properties: map[string]jsonschema.Definition{ | |
| "location": { | |
| Type: jsonschema.String, | |
| Description: "The city and state, e.g. San Francisco, CA", | |
| }, | |
| "unit": { | |
| Type: jsonschema.String, | |
| Enum: []string{"celcius", "fahrenheit"}, | |
| }, | |
| }, | |
| Required: []string{"location"}, | |
| }, | |
| }, | |
| }, | |
| }) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(len(resp2.Choices)).To(Equal(1)) | |
| Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil()) | |
| Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name) | |
| var res map[string]string | |
| err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(res["location"]).To(ContainSubstring("San Francisco"), fmt.Sprint(res)) | |
| Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res)) | |
| Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason)) | |
| }) | |
| It("installs and is capable to run tts", Label("tts"), func() { | |
| if runtime.GOOS != "linux" { | |
| Skip("test supported only on linux") | |
| } | |
| response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ | |
| ID: "localai@voice-en-us-kathleen-low", | |
| }) | |
| Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) | |
| uuid := response["uuid"].(string) | |
| Eventually(func() bool { | |
| response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) | |
| fmt.Println(response) | |
| return response["processed"].(bool) | |
| }, "360s", "10s").Should(Equal(true)) | |
| // An HTTP Post to the /tts endpoint should return a wav audio file | |
| resp, err := http.Post("http://127.0.0.1:9090/tts", "application/json", bytes.NewBuffer([]byte(`{"input": "Hello world", "model": "voice-en-us-kathleen-low"}`))) | |
| Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp)) | |
| dat, err := io.ReadAll(resp.Body) | |
| Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp)) | |
| Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat))) | |
| Expect(resp.Header.Get("Content-Type")).To(Or(Equal("audio/x-wav"), Equal("audio/wav"), Equal("audio/vnd.wave"))) | |
| }) | |
| It("installs and is capable to generate images", Label("stablediffusion"), func() { | |
| if runtime.GOOS != "linux" { | |
| Skip("test supported only on linux") | |
| } | |
| response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ | |
| ID: "localai@sd-1.5-ggml", | |
| Name: "stablediffusion", | |
| }) | |
| Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) | |
| uuid := response["uuid"].(string) | |
| Eventually(func() bool { | |
| response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) | |
| fmt.Println(response) | |
| return response["processed"].(bool) | |
| }, "1200s", "10s").Should(Equal(true)) | |
| resp, err := http.Post( | |
| "http://127.0.0.1:9090/v1/images/generations", | |
| "application/json", | |
| bytes.NewBuffer([]byte(`{ | |
| "prompt": "a lovely cat", | |
| "step": 1, "seed":9000, | |
| "size": "256x256", "n":2}`))) | |
| // The response should contain an URL | |
| Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp)) | |
| dat, err := io.ReadAll(resp.Body) | |
| Expect(err).ToNot(HaveOccurred(), "error reading /image/generations response") | |
| imgUrlResp := &schema.OpenAIResponse{} | |
| err = json.Unmarshal(dat, imgUrlResp) | |
| Expect(err).ToNot(HaveOccurred(), fmt.Sprint(dat)) | |
| Expect(imgUrlResp.Data).ToNot(Or(BeNil(), BeZero())) | |
| imgUrl := imgUrlResp.Data[0].URL | |
| Expect(imgUrl).To(ContainSubstring("http://127.0.0.1:9090/"), imgUrl) | |
| Expect(imgUrl).To(ContainSubstring(".png"), imgUrl) | |
| imgResp, err := http.Get(imgUrl) | |
| Expect(err).To(BeNil()) | |
| Expect(imgResp).ToNot(BeNil()) | |
| Expect(imgResp.StatusCode).To(Equal(200)) | |
| Expect(imgResp.ContentLength).To(BeNumerically(">", 0)) | |
| imgData := make([]byte, 512) | |
| count, err := io.ReadFull(imgResp.Body, imgData) | |
| Expect(err).To(Or(BeNil(), MatchError(io.EOF))) | |
| Expect(count).To(BeNumerically(">", 0)) | |
| Expect(count).To(BeNumerically("<=", 512)) | |
| Expect(http.DetectContentType(imgData)).To(Equal("image/png")) | |
| }) | |
| }) | |
| Context("API query", func() { | |
| BeforeEach(func() { | |
| modelPath := os.Getenv("MODELS_PATH") | |
| backendPath := os.Getenv("BACKENDS_PATH") | |
| c, cancel = context.WithCancel(context.Background()) | |
| var err error | |
| systemState, err := system.GetSystemState( | |
| system.WithBackendPath(backendPath), | |
| system.WithModelPath(modelPath), | |
| ) | |
| Expect(err).ToNot(HaveOccurred()) | |
| application, err := application.New( | |
| append(commonOpts, | |
| config.WithExternalBackend("transformers", os.Getenv("HUGGINGFACE_GRPC")), | |
| config.WithContext(c), | |
| config.WithSystemState(systemState), | |
| )...) | |
| Expect(err).ToNot(HaveOccurred()) | |
| app, err = API(application) | |
| Expect(err).ToNot(HaveOccurred()) | |
| go func() { | |
| if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { | |
| xlog.Error("server error", "error", err) | |
| } | |
| }() | |
| defaultConfig := openai.DefaultConfig("") | |
| defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" | |
| client2 = openaigo.NewClient("") | |
| client2.BaseURL = defaultConfig.BaseURL | |
| // Wait for API to be ready | |
| client = openai.NewClientWithConfig(defaultConfig) | |
| Eventually(func() error { | |
| _, err := client.ListModels(context.TODO()) | |
| return err | |
| }, "2m").ShouldNot(HaveOccurred()) | |
| }) | |
| AfterEach(func() { | |
| cancel() | |
| if app != nil { | |
| ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | |
| defer cancel() | |
| err := app.Shutdown(ctx) | |
| Expect(err).ToNot(HaveOccurred()) | |
| } | |
| }) | |
| It("returns the models list", func() { | |
| models, err := client.ListModels(context.TODO()) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(len(models.Models)).To(Equal(7)) // If "config.yaml" should be included, this should be 8? | |
| }) | |
| It("can generate completions via ggml", func() { | |
| if runtime.GOOS != "linux" { | |
| Skip("test supported only on linux") | |
| } | |
| resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel.ggml", Prompt: testPrompt}) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(len(resp.Choices)).To(Equal(1)) | |
| Expect(resp.Choices[0].Text).ToNot(BeEmpty()) | |
| }) | |
| It("can generate chat completions via ggml", func() { | |
| if runtime.GOOS != "linux" { | |
| Skip("test supported only on linux") | |
| } | |
| resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "testmodel.ggml", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: testPrompt}}}) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(len(resp.Choices)).To(Equal(1)) | |
| Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) | |
| }) | |
| It("returns logprobs in chat completions when requested", func() { | |
| if runtime.GOOS != "linux" { | |
| Skip("test only on linux") | |
| } | |
| topLogprobsVal := 3 | |
| response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{ | |
| Model: "testmodel.ggml", | |
| LogProbs: true, | |
| TopLogProbs: topLogprobsVal, | |
| Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(len(response.Choices)).To(Equal(1)) | |
| Expect(response.Choices[0].Message).ToNot(BeNil()) | |
| Expect(response.Choices[0].Message.Content).ToNot(BeEmpty()) | |
| // Verify logprobs are present and have correct structure | |
| Expect(response.Choices[0].LogProbs).ToNot(BeNil()) | |
| Expect(response.Choices[0].LogProbs.Content).ToNot(BeEmpty()) | |
| Expect(len(response.Choices[0].LogProbs.Content)).To(BeNumerically(">", 1)) | |
| foundatLeastToken := "" | |
| foundAtLeastBytes := []byte{} | |
| foundAtLeastTopLogprobBytes := []byte{} | |
| foundatLeastTopLogprob := "" | |
| // Verify logprobs content structure matches OpenAI format | |
| for _, logprobContent := range response.Choices[0].LogProbs.Content { | |
| // Bytes can be empty for certain tokens (special tokens, etc.), so we don't require it | |
| if len(logprobContent.Bytes) > 0 { | |
| foundAtLeastBytes = logprobContent.Bytes | |
| } | |
| if len(logprobContent.Token) > 0 { | |
| foundatLeastToken = logprobContent.Token | |
| } | |
| Expect(logprobContent.LogProb).To(BeNumerically("<=", 0)) // Logprobs are always <= 0 | |
| Expect(len(logprobContent.TopLogProbs)).To(BeNumerically(">", 1)) | |
| // If top_logprobs is requested, verify top_logprobs array respects the limit | |
| if len(logprobContent.TopLogProbs) > 0 { | |
| // Should respect top_logprobs limit (3 in this test) | |
| Expect(len(logprobContent.TopLogProbs)).To(BeNumerically("<=", topLogprobsVal)) | |
| for _, topLogprob := range logprobContent.TopLogProbs { | |
| if len(topLogprob.Bytes) > 0 { | |
| foundAtLeastTopLogprobBytes = topLogprob.Bytes | |
| } | |
| if len(topLogprob.Token) > 0 { | |
| foundatLeastTopLogprob = topLogprob.Token | |
| } | |
| Expect(topLogprob.LogProb).To(BeNumerically("<=", 0)) | |
| } | |
| } | |
| } | |
| Expect(foundAtLeastBytes).ToNot(BeEmpty()) | |
| Expect(foundAtLeastTopLogprobBytes).ToNot(BeEmpty()) | |
| Expect(foundatLeastToken).ToNot(BeEmpty()) | |
| Expect(foundatLeastTopLogprob).ToNot(BeEmpty()) | |
| }) | |
| It("applies logit_bias to chat completions when requested", func() { | |
| if runtime.GOOS != "linux" { | |
| Skip("test only on linux") | |
| } | |
| // logit_bias is a map of token IDs (as strings) to bias values (-100 to 100) | |
| // According to OpenAI API: modifies the likelihood of specified tokens appearing in the completion | |
| logitBias := map[string]int{ | |
| "15043": 1, // Bias token ID 15043 (example token ID) with bias value 1 | |
| } | |
| response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{ | |
| Model: "testmodel.ggml", | |
| Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}, | |
| LogitBias: logitBias, | |
| }) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(len(response.Choices)).To(Equal(1)) | |
| Expect(response.Choices[0].Message).ToNot(BeNil()) | |
| Expect(response.Choices[0].Message.Content).ToNot(BeEmpty()) | |
| // If logit_bias is applied, the response should be generated successfully | |
| // We can't easily verify the bias effect without knowing the actual token IDs for the model, | |
| // but the fact that the request succeeds confirms the API accepts and processes logit_bias | |
| }) | |
| It("returns errors", func() { | |
| _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt}) | |
| Expect(err).To(HaveOccurred()) | |
| Expect(err.Error()).To(ContainSubstring("error, status code: 500, status: 500 Internal Server Error, message: could not load model - all backends returned error:")) | |
| }) | |
| It("shows the external backend", func() { | |
| // Only run on linux | |
| if runtime.GOOS != "linux" { | |
| Skip("test supported only on linux") | |
| } | |
| // do an http request to the /system endpoint | |
| resp, err := http.Get("http://127.0.0.1:9090/system") | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(resp.StatusCode).To(Equal(200)) | |
| dat, err := io.ReadAll(resp.Body) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(string(dat)).To(ContainSubstring("huggingface")) | |
| Expect(string(dat)).To(ContainSubstring("llama-cpp")) | |
| }) | |
| It("transcribes audio", func() { | |
| if runtime.GOOS != "linux" { | |
| Skip("test supported only on linux") | |
| } | |
| resp, err := client.CreateTranscription( | |
| context.Background(), | |
| openai.AudioRequest{ | |
| Model: openai.Whisper1, | |
| FilePath: filepath.Join(os.Getenv("TEST_DIR"), "audio.wav"), | |
| }, | |
| ) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(resp.Text).To(ContainSubstring("This is the Micro Machine Man presenting")) | |
| }) | |
| It("calculate embeddings", func() { | |
| if runtime.GOOS != "linux" { | |
| Skip("test supported only on linux") | |
| } | |
| embeddingModel := openai.AdaEmbeddingV2 | |
| resp, err := client.CreateEmbeddings( | |
| context.Background(), | |
| openai.EmbeddingRequest{ | |
| Model: embeddingModel, | |
| Input: []string{"sun", "cat"}, | |
| }, | |
| ) | |
| Expect(err).ToNot(HaveOccurred(), err) | |
| Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 4096)) | |
| Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 4096)) | |
| sunEmbedding := resp.Data[0].Embedding | |
| resp2, err := client.CreateEmbeddings( | |
| context.Background(), | |
| openai.EmbeddingRequest{ | |
| Model: embeddingModel, | |
| Input: []string{"sun"}, | |
| }, | |
| ) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) | |
| Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding)) | |
| resp3, err := client.CreateEmbeddings( | |
| context.Background(), | |
| openai.EmbeddingRequest{ | |
| Model: embeddingModel, | |
| Input: []string{"cat"}, | |
| }, | |
| ) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(resp3.Data[0].Embedding).To(Equal(resp.Data[1].Embedding)) | |
| Expect(resp3.Data[0].Embedding).ToNot(Equal(sunEmbedding)) | |
| }) | |
| Context("External gRPC calls", func() { | |
| It("calculate embeddings with sentencetransformers", func() { | |
| if runtime.GOOS != "linux" { | |
| Skip("test supported only on linux") | |
| } | |
| resp, err := client.CreateEmbeddings( | |
| context.Background(), | |
| openai.EmbeddingRequest{ | |
| Model: openai.AdaCodeSearchCode, | |
| Input: []string{"sun", "cat"}, | |
| }, | |
| ) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384)) | |
| Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384)) | |
| sunEmbedding := resp.Data[0].Embedding | |
| resp2, err := client.CreateEmbeddings( | |
| context.Background(), | |
| openai.EmbeddingRequest{ | |
| Model: openai.AdaCodeSearchCode, | |
| Input: []string{"sun"}, | |
| }, | |
| ) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) | |
| Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding)) | |
| }) | |
| }) | |
| // See tests/integration/stores_test | |
| Context("Stores", Label("stores"), func() { | |
| BeforeEach(func() { | |
| // Only run on linux | |
| if runtime.GOOS != "linux" { | |
| Skip("test supported only on linux") | |
| } | |
| }) | |
| It("sets, gets, finds and deletes entries", func() { | |
| ks := [][]float32{ | |
| {0.1, 0.2, 0.3}, | |
| {0.4, 0.5, 0.6}, | |
| {0.7, 0.8, 0.9}, | |
| } | |
| vs := []string{ | |
| "test1", | |
| "test2", | |
| "test3", | |
| } | |
| setBody := schema.StoresSet{ | |
| Keys: ks, | |
| Values: vs, | |
| } | |
| url := "http://127.0.0.1:9090/stores/" | |
| err := postRequestJSON(url+"set", &setBody) | |
| Expect(err).ToNot(HaveOccurred()) | |
| getBody := schema.StoresGet{ | |
| Keys: ks, | |
| } | |
| var getRespBody schema.StoresGetResponse | |
| err = postRequestResponseJSON(url+"get", &getBody, &getRespBody) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(len(getRespBody.Keys)).To(Equal(len(ks))) | |
| for i, v := range getRespBody.Keys { | |
| if v[0] == 0.1 { | |
| Expect(getRespBody.Values[i]).To(Equal("test1")) | |
| } else if v[0] == 0.4 { | |
| Expect(getRespBody.Values[i]).To(Equal("test2")) | |
| } else { | |
| Expect(getRespBody.Values[i]).To(Equal("test3")) | |
| } | |
| } | |
| deleteBody := schema.StoresDelete{ | |
| Keys: [][]float32{ | |
| {0.1, 0.2, 0.3}, | |
| }, | |
| } | |
| err = postRequestJSON(url+"delete", &deleteBody) | |
| Expect(err).ToNot(HaveOccurred()) | |
| findBody := schema.StoresFind{ | |
| Key: []float32{0.1, 0.3, 0.7}, | |
| Topk: 10, | |
| } | |
| var findRespBody schema.StoresFindResponse | |
| err = postRequestResponseJSON(url+"find", &findBody, &findRespBody) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(len(findRespBody.Keys)).To(Equal(2)) | |
| for i, v := range findRespBody.Keys { | |
| if v[0] == 0.4 { | |
| Expect(findRespBody.Values[i]).To(Equal("test2")) | |
| } else { | |
| Expect(findRespBody.Values[i]).To(Equal("test3")) | |
| } | |
| Expect(findRespBody.Similarities[i]).To(BeNumerically(">=", -1)) | |
| Expect(findRespBody.Similarities[i]).To(BeNumerically("<=", 1)) | |
| } | |
| }) | |
| Context("Agent Jobs", Label("agent-jobs"), func() { | |
| It("creates and manages tasks", func() { | |
| // Create a task | |
| taskBody := map[string]interface{}{ | |
| "name": "Test Task", | |
| "description": "Test Description", | |
| "model": "testmodel.ggml", | |
| "prompt": "Hello {{.name}}", | |
| "enabled": true, | |
| } | |
| var createResp map[string]interface{} | |
| err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(createResp["id"]).ToNot(BeEmpty()) | |
| taskID := createResp["id"].(string) | |
| // Get the task | |
| var task schema.Task | |
| resp, err := http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(resp.StatusCode).To(Equal(200)) | |
| body, _ := io.ReadAll(resp.Body) | |
| json.Unmarshal(body, &task) | |
| Expect(task.Name).To(Equal("Test Task")) | |
| // List tasks | |
| resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks") | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(resp.StatusCode).To(Equal(200)) | |
| var tasks []schema.Task | |
| body, _ = io.ReadAll(resp.Body) | |
| json.Unmarshal(body, &tasks) | |
| Expect(len(tasks)).To(BeNumerically(">=", 1)) | |
| // Update task | |
| taskBody["name"] = "Updated Task" | |
| err = putRequestJSON("http://127.0.0.1:9090/api/agent/tasks/"+taskID, &taskBody) | |
| Expect(err).ToNot(HaveOccurred()) | |
| // Verify update | |
| resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID) | |
| Expect(err).ToNot(HaveOccurred()) | |
| body, _ = io.ReadAll(resp.Body) | |
| json.Unmarshal(body, &task) | |
| Expect(task.Name).To(Equal("Updated Task")) | |
| // Delete task | |
| req, _ := http.NewRequest("DELETE", "http://127.0.0.1:9090/api/agent/tasks/"+taskID, nil) | |
| req.Header.Set("Authorization", bearerKey) | |
| resp, err = http.DefaultClient.Do(req) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(resp.StatusCode).To(Equal(200)) | |
| }) | |
| It("executes and monitors jobs", func() { | |
| // Create a task first | |
| taskBody := map[string]interface{}{ | |
| "name": "Job Test Task", | |
| "model": "testmodel.ggml", | |
| "prompt": "Say hello", | |
| "enabled": true, | |
| } | |
| var createResp map[string]interface{} | |
| err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp) | |
| Expect(err).ToNot(HaveOccurred()) | |
| taskID := createResp["id"].(string) | |
| // Execute a job | |
| jobBody := map[string]interface{}{ | |
| "task_id": taskID, | |
| "parameters": map[string]string{}, | |
| } | |
| var jobResp schema.JobExecutionResponse | |
| err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/jobs/execute", &jobBody, &jobResp) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(jobResp.JobID).ToNot(BeEmpty()) | |
| jobID := jobResp.JobID | |
| // Get job status | |
| var job schema.Job | |
| resp, err := http.Get("http://127.0.0.1:9090/api/agent/jobs/" + jobID) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(resp.StatusCode).To(Equal(200)) | |
| body, _ := io.ReadAll(resp.Body) | |
| json.Unmarshal(body, &job) | |
| Expect(job.ID).To(Equal(jobID)) | |
| Expect(job.TaskID).To(Equal(taskID)) | |
| // List jobs | |
| resp, err = http.Get("http://127.0.0.1:9090/api/agent/jobs") | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(resp.StatusCode).To(Equal(200)) | |
| var jobs []schema.Job | |
| body, _ = io.ReadAll(resp.Body) | |
| json.Unmarshal(body, &jobs) | |
| Expect(len(jobs)).To(BeNumerically(">=", 1)) | |
| // Cancel job (if still pending/running) | |
| if job.Status == schema.JobStatusPending || job.Status == schema.JobStatusRunning { | |
| req, _ := http.NewRequest("POST", "http://127.0.0.1:9090/api/agent/jobs/"+jobID+"/cancel", nil) | |
| req.Header.Set("Authorization", bearerKey) | |
| resp, err = http.DefaultClient.Do(req) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(resp.StatusCode).To(Equal(200)) | |
| } | |
| }) | |
| It("executes task by name", func() { | |
| // Create a task with a specific name | |
| taskBody := map[string]interface{}{ | |
| "name": "Named Task", | |
| "model": "testmodel.ggml", | |
| "prompt": "Hello", | |
| "enabled": true, | |
| } | |
| var createResp map[string]interface{} | |
| err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp) | |
| Expect(err).ToNot(HaveOccurred()) | |
| // Execute by name | |
| paramsBody := map[string]string{"param1": "value1"} | |
| var jobResp schema.JobExecutionResponse | |
| err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks/Named Task/execute", ¶msBody, &jobResp) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(jobResp.JobID).ToNot(BeEmpty()) | |
| }) | |
| }) | |
| }) | |
| }) | |
| Context("Config file", func() { | |
| BeforeEach(func() { | |
| if runtime.GOOS != "linux" { | |
| Skip("run this test only on linux") | |
| } | |
| modelPath := os.Getenv("MODELS_PATH") | |
| backendPath := os.Getenv("BACKENDS_PATH") | |
| c, cancel = context.WithCancel(context.Background()) | |
| var err error | |
| systemState, err := system.GetSystemState( | |
| system.WithBackendPath(backendPath), | |
| system.WithModelPath(modelPath), | |
| ) | |
| Expect(err).ToNot(HaveOccurred()) | |
| application, err := application.New( | |
| append(commonOpts, | |
| config.WithContext(c), | |
| config.WithSystemState(systemState), | |
| config.WithConfigFile(os.Getenv("CONFIG_FILE")))..., | |
| ) | |
| Expect(err).ToNot(HaveOccurred()) | |
| app, err = API(application) | |
| Expect(err).ToNot(HaveOccurred()) | |
| go func() { | |
| if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { | |
| xlog.Error("server error", "error", err) | |
| } | |
| }() | |
| defaultConfig := openai.DefaultConfig("") | |
| defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" | |
| client2 = openaigo.NewClient("") | |
| client2.BaseURL = defaultConfig.BaseURL | |
| // Wait for API to be ready | |
| client = openai.NewClientWithConfig(defaultConfig) | |
| Eventually(func() error { | |
| _, err := client.ListModels(context.TODO()) | |
| return err | |
| }, "2m").ShouldNot(HaveOccurred()) | |
| }) | |
| AfterEach(func() { | |
| cancel() | |
| if app != nil { | |
| ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | |
| defer cancel() | |
| err := app.Shutdown(ctx) | |
| Expect(err).ToNot(HaveOccurred()) | |
| } | |
| }) | |
| It("can generate chat completions from config file (list1)", func() { | |
| resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(len(resp.Choices)).To(Equal(1)) | |
| Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) | |
| }) | |
| It("can generate chat completions from config file (list2)", func() { | |
| resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(len(resp.Choices)).To(Equal(1)) | |
| Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) | |
| }) | |
| It("can generate edit completions from config file", func() { | |
| request := openaigo.EditCreateRequestBody{ | |
| Model: "list2", | |
| Instruction: "foo", | |
| Input: "bar", | |
| } | |
| resp, err := client2.CreateEdit(context.Background(), request) | |
| Expect(err).ToNot(HaveOccurred()) | |
| Expect(len(resp.Choices)).To(Equal(1)) | |
| Expect(resp.Choices[0].Text).ToNot(BeEmpty()) | |
| }) | |
| }) | |
| }) | |