|
|
package model_test |
|
|
|
|
|
import ( |
|
|
"errors" |
|
|
"os" |
|
|
"path/filepath" |
|
|
"sync" |
|
|
"sync/atomic" |
|
|
"time" |
|
|
|
|
|
"github.com/mudler/LocalAI/pkg/model" |
|
|
"github.com/mudler/LocalAI/pkg/system" |
|
|
. "github.com/onsi/ginkgo/v2" |
|
|
. "github.com/onsi/gomega" |
|
|
) |
|
|
|
|
|
var _ = Describe("ModelLoader", func() { |
|
|
var ( |
|
|
modelLoader *model.ModelLoader |
|
|
modelPath string |
|
|
mockModel *model.Model |
|
|
) |
|
|
|
|
|
BeforeEach(func() { |
|
|
|
|
|
modelPath = "/tmp/test_model_path" |
|
|
os.Mkdir(modelPath, 0755) |
|
|
|
|
|
systemState, err := system.GetSystemState( |
|
|
system.WithModelPath(modelPath), |
|
|
) |
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
modelLoader = model.NewModelLoader(systemState) |
|
|
}) |
|
|
|
|
|
AfterEach(func() { |
|
|
|
|
|
os.RemoveAll(modelPath) |
|
|
}) |
|
|
|
|
|
Context("NewModelLoader", func() { |
|
|
It("should create a new ModelLoader with an empty model map", func() { |
|
|
Expect(modelLoader).ToNot(BeNil()) |
|
|
Expect(modelLoader.ModelPath).To(Equal(modelPath)) |
|
|
Expect(modelLoader.ListLoadedModels()).To(BeEmpty()) |
|
|
}) |
|
|
}) |
|
|
|
|
|
Context("ExistsInModelPath", func() { |
|
|
It("should return true if a file exists in the model path", func() { |
|
|
testFile := filepath.Join(modelPath, "test.model") |
|
|
os.Create(testFile) |
|
|
Expect(modelLoader.ExistsInModelPath("test.model")).To(BeTrue()) |
|
|
}) |
|
|
|
|
|
It("should return false if a file does not exist in the model path", func() { |
|
|
Expect(modelLoader.ExistsInModelPath("nonexistent.model")).To(BeFalse()) |
|
|
}) |
|
|
}) |
|
|
|
|
|
Context("ListFilesInModelPath", func() { |
|
|
It("should list all valid model files in the model path", func() { |
|
|
os.Create(filepath.Join(modelPath, "test.model")) |
|
|
os.Create(filepath.Join(modelPath, "README.md")) |
|
|
|
|
|
files, err := modelLoader.ListFilesInModelPath() |
|
|
Expect(err).To(BeNil()) |
|
|
Expect(files).To(ContainElement("test.model")) |
|
|
Expect(files).ToNot(ContainElement("README.md")) |
|
|
}) |
|
|
}) |
|
|
|
|
|
Context("LoadModel", func() { |
|
|
It("should load a model and keep it in memory", func() { |
|
|
mockModel = model.NewModel("foo", "test.model", nil) |
|
|
|
|
|
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) { |
|
|
return mockModel, nil |
|
|
} |
|
|
|
|
|
model, err := modelLoader.LoadModel("foo", "test.model", mockLoader) |
|
|
Expect(err).To(BeNil()) |
|
|
Expect(model).To(Equal(mockModel)) |
|
|
Expect(modelLoader.CheckIsLoaded("foo")).To(Equal(mockModel)) |
|
|
}) |
|
|
|
|
|
It("should return an error if loading the model fails", func() { |
|
|
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) { |
|
|
return nil, errors.New("failed to load model") |
|
|
} |
|
|
|
|
|
model, err := modelLoader.LoadModel("foo", "test.model", mockLoader) |
|
|
Expect(err).To(HaveOccurred()) |
|
|
Expect(model).To(BeNil()) |
|
|
}) |
|
|
}) |
|
|
|
|
|
Context("ShutdownModel", func() { |
|
|
It("should shutdown a loaded model", func() { |
|
|
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) { |
|
|
return model.NewModel("foo", "test.model", nil), nil |
|
|
} |
|
|
|
|
|
_, err := modelLoader.LoadModel("foo", "test.model", mockLoader) |
|
|
Expect(err).To(BeNil()) |
|
|
|
|
|
err = modelLoader.ShutdownModel("foo") |
|
|
Expect(err).To(BeNil()) |
|
|
Expect(modelLoader.CheckIsLoaded("foo")).To(BeNil()) |
|
|
}) |
|
|
}) |
|
|
|
|
|
Context("Concurrent Loading", func() { |
|
|
It("should handle concurrent requests for the same model", func() { |
|
|
var loadCount int32 |
|
|
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) { |
|
|
atomic.AddInt32(&loadCount, 1) |
|
|
time.Sleep(100 * time.Millisecond) |
|
|
return model.NewModel(modelID, modelName, nil), nil |
|
|
} |
|
|
|
|
|
var wg sync.WaitGroup |
|
|
results := make([]*model.Model, 5) |
|
|
errs := make([]error, 5) |
|
|
|
|
|
|
|
|
for i := 0; i < 5; i++ { |
|
|
wg.Add(1) |
|
|
go func(idx int) { |
|
|
defer wg.Done() |
|
|
results[idx], errs[idx] = modelLoader.LoadModel("concurrent-model", "test.model", mockLoader) |
|
|
}(i) |
|
|
} |
|
|
|
|
|
wg.Wait() |
|
|
|
|
|
|
|
|
for i := 0; i < 5; i++ { |
|
|
Expect(errs[i]).To(BeNil()) |
|
|
Expect(results[i]).ToNot(BeNil()) |
|
|
} |
|
|
|
|
|
|
|
|
Expect(atomic.LoadInt32(&loadCount)).To(Equal(int32(1))) |
|
|
|
|
|
|
|
|
for i := 1; i < 5; i++ { |
|
|
Expect(results[i]).To(Equal(results[0])) |
|
|
} |
|
|
}) |
|
|
|
|
|
It("should handle concurrent requests for different models", func() { |
|
|
var loadCount int32 |
|
|
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) { |
|
|
atomic.AddInt32(&loadCount, 1) |
|
|
time.Sleep(50 * time.Millisecond) |
|
|
return model.NewModel(modelID, modelName, nil), nil |
|
|
} |
|
|
|
|
|
var wg sync.WaitGroup |
|
|
modelCount := 3 |
|
|
|
|
|
|
|
|
for i := 0; i < modelCount; i++ { |
|
|
wg.Add(1) |
|
|
go func(idx int) { |
|
|
defer wg.Done() |
|
|
modelID := "model-" + string(rune('A'+idx)) |
|
|
_, err := modelLoader.LoadModel(modelID, "test.model", mockLoader) |
|
|
Expect(err).To(BeNil()) |
|
|
}(i) |
|
|
} |
|
|
|
|
|
wg.Wait() |
|
|
|
|
|
|
|
|
Expect(atomic.LoadInt32(&loadCount)).To(Equal(int32(modelCount))) |
|
|
|
|
|
|
|
|
Expect(modelLoader.CheckIsLoaded("model-A")).ToNot(BeNil()) |
|
|
Expect(modelLoader.CheckIsLoaded("model-B")).ToNot(BeNil()) |
|
|
Expect(modelLoader.CheckIsLoaded("model-C")).ToNot(BeNil()) |
|
|
}) |
|
|
|
|
|
It("should track loading count correctly", func() { |
|
|
loadStarted := make(chan struct{}) |
|
|
loadComplete := make(chan struct{}) |
|
|
|
|
|
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) { |
|
|
close(loadStarted) |
|
|
<-loadComplete |
|
|
return model.NewModel(modelID, modelName, nil), nil |
|
|
} |
|
|
|
|
|
|
|
|
go func() { |
|
|
modelLoader.LoadModel("slow-model", "test.model", mockLoader) |
|
|
}() |
|
|
|
|
|
|
|
|
<-loadStarted |
|
|
|
|
|
|
|
|
Expect(modelLoader.GetLoadingCount()).To(Equal(1)) |
|
|
|
|
|
|
|
|
close(loadComplete) |
|
|
|
|
|
|
|
|
time.Sleep(50 * time.Millisecond) |
|
|
|
|
|
|
|
|
Expect(modelLoader.GetLoadingCount()).To(Equal(0)) |
|
|
}) |
|
|
|
|
|
It("should retry loading if first attempt fails", func() { |
|
|
var attemptCount int32 |
|
|
mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) { |
|
|
count := atomic.AddInt32(&attemptCount, 1) |
|
|
if count == 1 { |
|
|
return nil, errors.New("first attempt fails") |
|
|
} |
|
|
return model.NewModel(modelID, modelName, nil), nil |
|
|
} |
|
|
|
|
|
|
|
|
var wg sync.WaitGroup |
|
|
wg.Add(2) |
|
|
|
|
|
var err1, err2 error |
|
|
var m1, m2 *model.Model |
|
|
|
|
|
go func() { |
|
|
defer wg.Done() |
|
|
m1, err1 = modelLoader.LoadModel("retry-model", "test.model", mockLoader) |
|
|
}() |
|
|
|
|
|
|
|
|
time.Sleep(10 * time.Millisecond) |
|
|
|
|
|
go func() { |
|
|
defer wg.Done() |
|
|
m2, err2 = modelLoader.LoadModel("retry-model", "test.model", mockLoader) |
|
|
}() |
|
|
|
|
|
wg.Wait() |
|
|
|
|
|
|
|
|
successCount := 0 |
|
|
if err1 == nil && m1 != nil { |
|
|
successCount++ |
|
|
} |
|
|
if err2 == nil && m2 != nil { |
|
|
successCount++ |
|
|
} |
|
|
Expect(successCount).To(BeNumerically(">=", 1)) |
|
|
}) |
|
|
}) |
|
|
|
|
|
Context("GetLoadingCount", func() { |
|
|
It("should return 0 when nothing is loading", func() { |
|
|
Expect(modelLoader.GetLoadingCount()).To(Equal(0)) |
|
|
}) |
|
|
}) |
|
|
|
|
|
Context("LRU Eviction Retry Settings", func() { |
|
|
It("should allow updating retry settings", func() { |
|
|
modelLoader.SetLRUEvictionRetrySettings(50, 2*time.Second) |
|
|
|
|
|
|
|
|
Expect(modelLoader).ToNot(BeNil()) |
|
|
}) |
|
|
}) |
|
|
}) |
|
|
|