| package openai |
|
|
| import ( |
| "bytes" |
| "mime/multipart" |
| "net/http" |
| "net/http/httptest" |
| "os" |
| "path/filepath" |
| "testing" |
|
|
| "github.com/labstack/echo/v4" |
| "github.com/mudler/LocalAI/core/backend" |
| "github.com/mudler/LocalAI/core/config" |
| "github.com/mudler/LocalAI/core/http/middleware" |
| model "github.com/mudler/LocalAI/pkg/model" |
| "github.com/stretchr/testify/require" |
| ) |
|
|
| func makeMultipartRequest(t *testing.T, fields map[string]string, files map[string][]byte) (*http.Request, string) { |
| b := &bytes.Buffer{} |
| w := multipart.NewWriter(b) |
| for k, v := range fields { |
| _ = w.WriteField(k, v) |
| } |
| for fname, content := range files { |
| fw, err := w.CreateFormFile(fname, fname+".png") |
| require.NoError(t, err) |
| _, err = fw.Write(content) |
| require.NoError(t, err) |
| } |
| require.NoError(t, w.Close()) |
| req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", b) |
| req.Header.Set("Content-Type", w.FormDataContentType()) |
| return req, w.FormDataContentType() |
| } |
|
|
| func TestInpainting_MissingFiles(t *testing.T) { |
| e := echo.New() |
| |
| h := InpaintingEndpoint(nil, nil, config.NewApplicationConfig()) |
|
|
| req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", nil) |
| rec := httptest.NewRecorder() |
| c := e.NewContext(req, rec) |
|
|
| err := h(c) |
| require.Error(t, err) |
| } |
|
|
| func TestInpainting_HappyPath(t *testing.T) { |
| |
| tmpDir, err := os.MkdirTemp("", "gencontent") |
| require.NoError(t, err) |
| defer os.RemoveAll(tmpDir) |
|
|
| appConf := config.NewApplicationConfig(config.WithGeneratedContentDir(tmpDir)) |
|
|
| |
| orig := backend.ImageGenerationFunc |
| backend.ImageGenerationFunc = func(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) { |
| fn := func() error { |
| |
| return os.WriteFile(dst, []byte("PNGDATA"), 0644) |
| } |
| return fn, nil |
| } |
| defer func() { backend.ImageGenerationFunc = orig }() |
|
|
| |
| fields := map[string]string{"model": "dreamshaper-8-inpainting", "prompt": "A test"} |
| files := map[string][]byte{"image": []byte("IMAGEDATA"), "mask": []byte("MASKDATA")} |
| reqBuf, _ := makeMultipartRequest(t, fields, files) |
|
|
| rec := httptest.NewRecorder() |
| e := echo.New() |
| c := e.NewContext(reqBuf, rec) |
|
|
| |
| c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG, &config.ModelConfig{Backend: "diffusers"}) |
|
|
| h := InpaintingEndpoint(nil, nil, appConf) |
|
|
| |
| err = h(c) |
| require.NoError(t, err) |
| require.Equal(t, http.StatusOK, rec.Code) |
|
|
| |
| body := rec.Body.String() |
| require.Contains(t, body, "generated-images") |
|
|
| |
| |
| |
| idx := bytes.Index(rec.Body.Bytes(), []byte("generated-images/")) |
| require.True(t, idx >= 0) |
| rest := rec.Body.Bytes()[idx:] |
| end := bytes.IndexAny(rest, "\",}\n") |
| if end == -1 { |
| end = len(rest) |
| } |
| fname := string(rest[len("generated-images/"):end]) |
| |
| _, err = os.Stat(filepath.Join(tmpDir, fname)) |
| require.NoError(t, err) |
| } |
|
|