|
|
package openai |
|
|
|
|
|
import ( |
|
|
"bufio" |
|
|
"encoding/base64" |
|
|
"encoding/json" |
|
|
"fmt" |
|
|
"io" |
|
|
"net/http" |
|
|
"net/url" |
|
|
"os" |
|
|
"path/filepath" |
|
|
"strconv" |
|
|
"strings" |
|
|
"time" |
|
|
|
|
|
"github.com/google/uuid" |
|
|
"github.com/labstack/echo/v4" |
|
|
"github.com/mudler/LocalAI/core/config" |
|
|
"github.com/mudler/LocalAI/core/http/middleware" |
|
|
"github.com/mudler/LocalAI/core/schema" |
|
|
|
|
|
"github.com/mudler/LocalAI/core/backend" |
|
|
|
|
|
model "github.com/mudler/LocalAI/pkg/model" |
|
|
"github.com/mudler/xlog" |
|
|
) |
|
|
|
|
|
// downloadFile fetches url with an HTTP GET and stores the response body in a
// new temporary file (created in the OS temp dir with the "image" prefix).
//
// It returns the path of that file; the caller owns it and must remove it.
// A non-2xx response is reported as an error. On any error no file is left
// behind and the returned path is "".
//
// NOTE(review): http.Get uses http.DefaultClient, which has no timeout, so a
// stalled remote server blocks the caller indefinitely — consider a client
// with a timeout or passing the request context.
func downloadFile(url string) (string, error) {
	resp, err := http.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// Without this check an HTML/JSON error page would be saved and later
	// handed to the backend as if it were an image.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", fmt.Errorf("downloading %q: unexpected status %s", url, resp.Status)
	}

	out, err := os.CreateTemp("", "image")
	if err != nil {
		return "", err
	}

	_, err = io.Copy(out, resp.Body)
	// A failed Close can mean buffered data never reached the disk.
	if closeErr := out.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		// Callers discard the path on error, so clean up here to avoid a leak.
		os.Remove(out.Name())
		return "", err
	}
	return out.Name(), nil
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { |
|
|
return func(c echo.Context) error { |
|
|
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) |
|
|
if !ok || input.Model == "" { |
|
|
xlog.Error("Image Endpoint - Invalid Input") |
|
|
return echo.ErrBadRequest |
|
|
} |
|
|
|
|
|
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) |
|
|
if !ok || config == nil { |
|
|
xlog.Error("Image Endpoint - Invalid Config") |
|
|
return echo.ErrBadRequest |
|
|
} |
|
|
|
|
|
|
|
|
src := "" |
|
|
if input.File != "" { |
|
|
src = processImageFile(input.File, appConfig.GeneratedContentDir) |
|
|
if src != "" { |
|
|
defer os.RemoveAll(src) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
var inputImages []string |
|
|
if len(input.Files) > 0 { |
|
|
for _, file := range input.Files { |
|
|
processedFile := processImageFile(file, appConfig.GeneratedContentDir) |
|
|
if processedFile != "" { |
|
|
inputImages = append(inputImages, processedFile) |
|
|
defer os.RemoveAll(processedFile) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
var refImages []string |
|
|
if len(input.RefImages) > 0 { |
|
|
for _, file := range input.RefImages { |
|
|
processedFile := processImageFile(file, appConfig.GeneratedContentDir) |
|
|
if processedFile != "" { |
|
|
refImages = append(refImages, processedFile) |
|
|
defer os.RemoveAll(processedFile) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
xlog.Debug("Parameter Config", "config", config) |
|
|
|
|
|
switch config.Backend { |
|
|
case "stablediffusion": |
|
|
config.Backend = model.StableDiffusionGGMLBackend |
|
|
case "": |
|
|
config.Backend = model.StableDiffusionGGMLBackend |
|
|
} |
|
|
|
|
|
if !strings.Contains(input.Size, "x") { |
|
|
input.Size = "512x512" |
|
|
xlog.Warn("Invalid size, using default 512x512") |
|
|
} |
|
|
|
|
|
sizeParts := strings.Split(input.Size, "x") |
|
|
if len(sizeParts) != 2 { |
|
|
return fmt.Errorf("invalid value for 'size'") |
|
|
} |
|
|
width, err := strconv.Atoi(sizeParts[0]) |
|
|
if err != nil { |
|
|
return fmt.Errorf("invalid value for 'size'") |
|
|
} |
|
|
height, err := strconv.Atoi(sizeParts[1]) |
|
|
if err != nil { |
|
|
return fmt.Errorf("invalid value for 'size'") |
|
|
} |
|
|
|
|
|
b64JSON := config.ResponseFormat == "b64_json" |
|
|
|
|
|
|
|
|
var result []schema.Item |
|
|
for _, i := range config.PromptStrings { |
|
|
n := input.N |
|
|
if input.N == 0 { |
|
|
n = 1 |
|
|
} |
|
|
for j := 0; j < n; j++ { |
|
|
prompts := strings.Split(i, "|") |
|
|
positive_prompt := prompts[0] |
|
|
negative_prompt := "" |
|
|
if len(prompts) > 1 { |
|
|
negative_prompt = prompts[1] |
|
|
} |
|
|
|
|
|
step := config.Step |
|
|
if step == 0 { |
|
|
step = 15 |
|
|
} |
|
|
|
|
|
if input.Step != 0 { |
|
|
step = input.Step |
|
|
} |
|
|
|
|
|
tempDir := "" |
|
|
if !b64JSON { |
|
|
tempDir = filepath.Join(appConfig.GeneratedContentDir, "images") |
|
|
} |
|
|
|
|
|
outputFile, err := os.CreateTemp(tempDir, "b64") |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
outputFile.Close() |
|
|
|
|
|
output := outputFile.Name() + ".png" |
|
|
|
|
|
err = os.Rename(outputFile.Name(), output) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
|
|
|
baseURL := middleware.BaseURL(c) |
|
|
|
|
|
|
|
|
inputSrc := src |
|
|
if len(inputImages) > 0 { |
|
|
inputSrc = inputImages[0] |
|
|
} |
|
|
|
|
|
fn, err := backend.ImageGeneration(height, width, step, *config.Seed, positive_prompt, negative_prompt, inputSrc, output, ml, *config, appConfig, refImages) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
if err := fn(); err != nil { |
|
|
return err |
|
|
} |
|
|
|
|
|
item := &schema.Item{} |
|
|
|
|
|
if b64JSON { |
|
|
defer os.RemoveAll(output) |
|
|
data, err := os.ReadFile(output) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
item.B64JSON = base64.StdEncoding.EncodeToString(data) |
|
|
} else { |
|
|
base := filepath.Base(output) |
|
|
item.URL, err = url.JoinPath(baseURL, "generated-images", base) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
} |
|
|
|
|
|
result = append(result, *item) |
|
|
} |
|
|
} |
|
|
|
|
|
id := uuid.New().String() |
|
|
created := int(time.Now().Unix()) |
|
|
resp := &schema.OpenAIResponse{ |
|
|
ID: id, |
|
|
Created: created, |
|
|
Data: result, |
|
|
Usage: schema.OpenAIUsage{ |
|
|
PromptTokens: 0, |
|
|
CompletionTokens: 0, |
|
|
TotalTokens: 0, |
|
|
InputTokens: 0, |
|
|
OutputTokens: 0, |
|
|
InputTokensDetails: &schema.InputTokensDetails{ |
|
|
TextTokens: 0, |
|
|
ImageTokens: 0, |
|
|
}, |
|
|
}, |
|
|
} |
|
|
|
|
|
jsonResult, _ := json.Marshal(resp) |
|
|
xlog.Debug("Response", "response", string(jsonResult)) |
|
|
|
|
|
|
|
|
return c.JSON(200, resp) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func processImageFile(file string, generatedContentDir string) string { |
|
|
fileData := []byte{} |
|
|
var err error |
|
|
|
|
|
|
|
|
if strings.HasPrefix(file, "http://") || strings.HasPrefix(file, "https://") { |
|
|
out, err := downloadFile(file) |
|
|
if err != nil { |
|
|
xlog.Error("Failed downloading file", "error", err, "file", file) |
|
|
return "" |
|
|
} |
|
|
defer os.RemoveAll(out) |
|
|
|
|
|
fileData, err = os.ReadFile(out) |
|
|
if err != nil { |
|
|
xlog.Error("Failed reading downloaded file", "error", err, "file", out) |
|
|
return "" |
|
|
} |
|
|
} else { |
|
|
|
|
|
fileData, err = base64.StdEncoding.DecodeString(file) |
|
|
if err != nil { |
|
|
xlog.Error("Failed decoding base64 file", "error", err) |
|
|
return "" |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
outputFile, err := os.CreateTemp(generatedContentDir, "b64") |
|
|
if err != nil { |
|
|
xlog.Error("Failed creating temporary file", "error", err) |
|
|
return "" |
|
|
} |
|
|
|
|
|
|
|
|
writer := bufio.NewWriter(outputFile) |
|
|
_, err = writer.Write(fileData) |
|
|
if err != nil { |
|
|
outputFile.Close() |
|
|
xlog.Error("Failed writing to temporary file", "error", err) |
|
|
return "" |
|
|
} |
|
|
outputFile.Close() |
|
|
|
|
|
return outputFile.Name() |
|
|
} |
|
|
|