File size: 3,972 Bytes
0f07ba7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package startup

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/gallery/importers"
	"github.com/mudler/LocalAI/core/services"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/system"
	"github.com/mudler/LocalAI/pkg/utils"
	"github.com/mudler/xlog"
)

// YAML_EXTENSION is the ".yaml" file-name extension.
//
// The ALL_CAPS name does not follow Go naming conventions, but the constant is
// exported, so it is kept as-is to avoid breaking existing references.
const YAML_EXTENSION = ".yaml"

// InstallModels preloads the given models, collecting every failure instead of
// letting one bad entry hide the others.
//
// Each entry in models is first looked up by name in galleries. When found, it
// is installed from the gallery (the model is downloaded if it is not already
// present in the model path). Entries that are not found in any gallery are
// handed to the importer. A model config is discovered for the entry, and an
// import operation is queued on galleryService. The function then waits for
// that operation to be processed.
//
// downloadStatus is forwarded to gallery installs; installModel substitutes a
// default reporter when it is nil. Cancelling ctx aborts a pending queue or
// wait for an import operation.
//
// The returned error joins the failures of all processed models. It also
// returns early, joined with any earlier failures, when an entry needs the
// importer but galleryService is nil, or when ctx is cancelled while waiting
// on an import.
func InstallModels(ctx context.Context, galleryService *services.GalleryService, galleries, backendGalleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, enforceScan, autoloadBackendGalleries bool, downloadStatus func(string, string, string, float64), models ...string) error {
	if ctx == nil {
		// Tolerate callers that pass a nil context: the import path below selects on ctx.Done().
		ctx = context.Background()
	}

	// err accumulates the failures of all models.
	var err error
	for _, url := range models {
		installErr, found := installModel(ctx, galleries, backendGalleries, url, systemState, modelLoader, downloadStatus, enforceScan, autoloadBackendGalleries)
		if found {
			if installErr != nil {
				xlog.Error("[startup] failed installing model", "error", installErr, "model", url)
				err = errors.Join(err, installErr)
			}
			continue
		}

		if installErr != nil {
			// The galleries could not be listed. Fall back to the importer as
			// before, but keep the reason visible instead of dropping it.
			xlog.Debug("[startup] could not search the galleries", "error", installErr, "model", url)
		}
		xlog.Debug("[startup] model not found in the gallery", "model", url)

		if galleryService == nil {
			return errors.Join(err, errors.New("cannot start autoimporter, not sure how to handle this uri"))
		}

		// TODO: we should just use the discoverModelConfig here and default to this.
		modelConfig, discoverErr := importers.DiscoverModelConfig(url, json.RawMessage{})
		if discoverErr != nil {
			xlog.Error("[startup] failed to discover model config", "error", discoverErr, "model", url)
			err = errors.Join(err, fmt.Errorf("failed to discover model config: %w", discoverErr))
			continue
		}

		// Named opUUID so the uuid package is not shadowed.
		opUUID, uuidErr := uuid.NewUUID()
		if uuidErr != nil {
			err = errors.Join(err, fmt.Errorf("failed to generate UUID: %w", uuidErr))
			continue
		}
		opID := opUUID.String()

		op := services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
			Req: gallery.GalleryModel{
				Overrides: map[string]any{},
			},
			ID:                 opID,
			GalleryElementName: modelConfig.Name,
			GalleryElement:     &modelConfig,
			BackendGalleries:   backendGalleries,
		}

		// Queue the import, but do not block forever if the caller gives up.
		select {
		case galleryService.ModelGalleryChannel <- op:
		case <-ctx.Done():
			return errors.Join(err, fmt.Errorf("queueing import of %q: %w", url, ctx.Err()))
		}

		// Poll once per second until the gallery service marks the operation processed.
		var status *services.GalleryOpStatus
		for {
			status = galleryService.GetStatus(opID)
			if status != nil && status.Processed {
				break
			}
			select {
			case <-ctx.Done():
				return errors.Join(err, fmt.Errorf("waiting for import of %q: %w", url, ctx.Err()))
			case <-time.After(time.Second):
			}
		}

		if status.Error != nil {
			xlog.Error("[startup] failed to import model", "error", status.Error, "model", modelConfig.Name, "url", url)
			// Record and move on, matching how gallery install failures are handled.
			err = errors.Join(err, status.Error)
			continue
		}

		xlog.Info("[startup] imported model", "model", modelConfig.Name, "url", url)
	}
	return err
}

// installModel installs modelName from the configured galleries.
//
// The boolean result reports whether modelName was found in the galleries:
//   - (err, false): either the galleries could not be listed (err != nil), or
//     the model is not in any gallery (err == nil). The caller may try other
//     sources.
//   - (err, true): the model was found, and err is the outcome of the install.
//
// A nil downloadStatus is replaced by utils.DisplayDownloadFunction.
func installModel(ctx context.Context, galleries, backendGalleries []config.Gallery, modelName string, systemState *system.SystemState, modelLoader *model.ModelLoader, downloadStatus func(string, string, string, float64), enforceScan, autoloadBackendGalleries bool) (error, bool) {
	models, err := gallery.AvailableGalleryModels(galleries, systemState)
	if err != nil {
		return err, false
	}

	// Named galleryModel so the imported model package is not shadowed.
	galleryModel := gallery.FindGalleryElement(models, modelName)
	if galleryModel == nil {
		// Not an error: the model simply is not in any configured gallery.
		return nil, false
	}

	if downloadStatus == nil {
		downloadStatus = utils.DisplayDownloadFunction
	}

	xlog.Info("installing model", "model", modelName, "license", galleryModel.License)
	err = gallery.InstallModelFromGallery(ctx, galleries, backendGalleries, systemState, modelLoader, modelName, gallery.GalleryModel{}, downloadStatus, enforceScan, autoloadBackendGalleries)
	return err, true
}