File size: 7,732 Bytes
0f07ba7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
package model_test

import (
	"errors"
	"os"
	"path/filepath"
	"sync"
	"sync/atomic"
	"time"

	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/system"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("ModelLoader", func() {
	// Shared per-spec state:
	//   - modelLoader: the loader under test, assigned in BeforeEach.
	//   - modelPath: directory handed to the loader; set up in BeforeEach and
	//     removed in AfterEach.
	//   - mockModel: model instance assigned by the LoadModel spec and returned
	//     by its mock loader.
	var (
		modelLoader *model.ModelLoader
		modelPath   string
		mockModel   *model.Model
	)

	BeforeEach(func() {
		// Setup the model loader with a test directory
		modelPath = "/tmp/test_model_path"
		os.Mkdir(modelPath, 0755)

		systemState, err := system.GetSystemState(
			system.WithModelPath(modelPath),
		)
		Expect(err).ToNot(HaveOccurred())
		modelLoader = model.NewModelLoader(systemState)
	})

	AfterEach(func() {
		// Best-effort removal of the spec's model directory; the result of
		// RemoveAll is deliberately discarded.
		_ = os.RemoveAll(modelPath)
	})

	Context("NewModelLoader", func() {
		It("should create a new ModelLoader with an empty model map", func() {
			// The loader must exist before any of its fields are inspected.
			Expect(modelLoader).ToNot(BeNil())

			// It must report the directory it was configured with.
			configuredPath := modelLoader.ModelPath
			Expect(configuredPath).To(Equal(modelPath))

			// No spec has loaded anything yet, so the listing must be empty.
			Expect(modelLoader.ListLoadedModels()).To(BeEmpty())
		})
	})

	Context("ExistsInModelPath", func() {
		It("should return true if a file exists in the model path", func() {
			testFile := filepath.Join(modelPath, "test.model")

			// Create the fixture explicitly: fail the spec if creation fails,
			// and close the handle so it is not leaked for the rest of the run.
			f, err := os.Create(testFile)
			Expect(err).ToNot(HaveOccurred())
			Expect(f.Close()).To(Succeed())

			Expect(modelLoader.ExistsInModelPath("test.model")).To(BeTrue())
		})

		It("should return false if a file does not exist in the model path", func() {
			Expect(modelLoader.ExistsInModelPath("nonexistent.model")).To(BeFalse())
		})
	})

	Context("ListFilesInModelPath", func() {
		It("should list all valid model files in the model path", func() {
			// One file the spec expects to be listed and one it expects to be
			// filtered out. Each is created with its error checked and its
			// handle closed so no descriptor is leaked.
			for _, name := range []string{"test.model", "README.md"} {
				f, err := os.Create(filepath.Join(modelPath, name))
				Expect(err).ToNot(HaveOccurred())
				Expect(f.Close()).To(Succeed())
			}

			files, err := modelLoader.ListFilesInModelPath()
			Expect(err).ToNot(HaveOccurred())
			Expect(files).To(ContainElement("test.model"))
			Expect(files).ToNot(ContainElement("README.md"))
		})
	})

	Context("LoadModel", func() {
		It("should load a model and keep it in memory", func() {
			mockModel = model.NewModel("foo", "test.model", nil)

			// Loader stub that always hands back the shared mockModel.
			mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
				return mockModel, nil
			}

			// The result is named "loaded" rather than "model" so the imported
			// model package is not shadowed inside the spec.
			loaded, err := modelLoader.LoadModel("foo", "test.model", mockLoader)
			Expect(err).ToNot(HaveOccurred())
			Expect(loaded).To(Equal(mockModel))
			Expect(modelLoader.CheckIsLoaded("foo")).To(Equal(mockModel))
		})

		It("should return an error if loading the model fails", func() {
			// Loader stub that always fails.
			mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
				return nil, errors.New("failed to load model")
			}

			loaded, err := modelLoader.LoadModel("foo", "test.model", mockLoader)
			Expect(err).To(HaveOccurred())
			Expect(loaded).To(BeNil())
		})
	})

	Context("ShutdownModel", func() {
		It("should shutdown a loaded model", func() {
			// Loader stub that always yields a new "foo" model.
			newFoo := func(modelID, modelName, modelFile string) (*model.Model, error) {
				return model.NewModel("foo", "test.model", nil), nil
			}

			// Arrange: get "foo" into the loader.
			_, loadErr := modelLoader.LoadModel("foo", "test.model", newFoo)
			Expect(loadErr).To(BeNil())

			// Act and assert: shutting it down succeeds and the loader no
			// longer reports it as loaded.
			shutdownErr := modelLoader.ShutdownModel("foo")
			Expect(shutdownErr).To(BeNil())
			Expect(modelLoader.CheckIsLoaded("foo")).To(BeNil())
		})
	})

	Context("Concurrent Loading", func() {
		It("should handle concurrent requests for the same model", func() {
			const requests = 5

			// Counts loader invocations; the sleep keeps the load in flight
			// long enough for the other requests to arrive.
			var loadCount int32
			mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
				atomic.AddInt32(&loadCount, 1)
				time.Sleep(100 * time.Millisecond) // Simulate loading time
				return model.NewModel(modelID, modelName, nil), nil
			}

			// Each goroutine writes only its own slot, so the slices need no
			// lock; wg.Wait orders those writes before the assertions below.
			results := make([]*model.Model, requests)
			errs := make([]error, requests)

			var wg sync.WaitGroup
			for i := 0; i < requests; i++ {
				wg.Add(1)
				go func(idx int) {
					defer wg.Done()
					results[idx], errs[idx] = modelLoader.LoadModel("concurrent-model", "test.model", mockLoader)
				}(i)
			}
			wg.Wait()

			// All requests should succeed
			for i := 0; i < requests; i++ {
				Expect(errs[i]).ToNot(HaveOccurred())
				Expect(results[i]).ToNot(BeNil())
			}

			// The loader should only have been called once
			Expect(atomic.LoadInt32(&loadCount)).To(Equal(int32(1)))

			// All results should be the same model instance. BeIdenticalTo
			// compares the pointers themselves; Equal would also accept two
			// distinct models that merely look alike.
			for i := 1; i < requests; i++ {
				Expect(results[i]).To(BeIdenticalTo(results[0]))
			}
		})

		It("should handle concurrent requests for different models", func() {
			modelIDs := []string{"model-A", "model-B", "model-C"}

			// Counts loader invocations across all goroutines.
			var loadCount int32
			mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
				atomic.AddInt32(&loadCount, 1)
				time.Sleep(50 * time.Millisecond) // Simulate loading time
				return model.NewModel(modelID, modelName, nil), nil
			}

			// Errors are collected per slot and asserted on the spec goroutine:
			// a failing Expect inside a bare goroutine would panic outside of
			// Ginkgo's control instead of failing the spec.
			errs := make([]error, len(modelIDs))

			var wg sync.WaitGroup
			for i, id := range modelIDs {
				wg.Add(1)
				go func(idx int, modelID string) {
					defer wg.Done()
					_, errs[idx] = modelLoader.LoadModel(modelID, "test.model", mockLoader)
				}(i, id)
			}
			wg.Wait()

			for i, err := range errs {
				Expect(err).ToNot(HaveOccurred(), "loading %s", modelIDs[i])
			}

			// Each model should be loaded exactly once
			Expect(atomic.LoadInt32(&loadCount)).To(Equal(int32(len(modelIDs))))

			// All models should be loaded
			for _, id := range modelIDs {
				Expect(modelLoader.CheckIsLoaded(id)).ToNot(BeNil())
			}
		})

		It("should track loading count correctly", func() {
			loadStarted := make(chan struct{})
			loadComplete := make(chan struct{})

			mockLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
				close(loadStarted)
				<-loadComplete // Wait until we're told to complete
				return model.NewModel(modelID, modelName, nil), nil
			}

			// Start loading in background
			go func() {
				modelLoader.LoadModel("slow-model", "test.model", mockLoader)
			}()

			// Wait for loading to start
			<-loadStarted

			// Loading count should be 1
			Expect(modelLoader.GetLoadingCount()).To(Equal(1))

			// Complete the loading
			close(loadComplete)

			// Wait a bit for cleanup
			time.Sleep(50 * time.Millisecond)

			// Loading count should be back to 0
			Expect(modelLoader.GetLoadingCount()).To(Equal(0))
		})

		It("should retry loading if first attempt fails", func() {
			// Loader stub: the very first invocation errors out, every later
			// one succeeds.
			var calls int32
			flakyLoader := func(modelID, modelName, modelFile string) (*model.Model, error) {
				if atomic.AddInt32(&calls, 1) == 1 {
					return nil, errors.New("first attempt fails")
				}
				return model.NewModel(modelID, modelName, nil), nil
			}

			// Result of one LoadModel call; each goroutine owns one slot.
			type outcome struct {
				loaded *model.Model
				err    error
			}
			var (
				outcomes [2]outcome
				pending  sync.WaitGroup
			)
			launch := func(slot int) {
				go func() {
					defer pending.Done()
					outcomes[slot].loaded, outcomes[slot].err = modelLoader.LoadModel("retry-model", "test.model", flakyLoader)
				}()
			}

			pending.Add(2)
			launch(0)
			// Give first goroutine a head start
			time.Sleep(10 * time.Millisecond)
			launch(1)
			pending.Wait()

			// At least one request must have come back with a model and no error.
			succeeded := 0
			for _, o := range outcomes {
				if o.err == nil && o.loaded != nil {
					succeeded++
				}
			}
			Expect(succeeded).To(BeNumerically(">=", 1))
		})
	})

	Context("GetLoadingCount", func() {
		It("should return 0 when nothing is loading", func() {
			// This spec never calls LoadModel, so no load is in flight.
			idle := modelLoader.GetLoadingCount()
			Expect(idle).To(Equal(0))
		})
	})

	Context("LRU Eviction Retry Settings", func() {
		It("should allow updating retry settings", func() {
			// The settings are not read back in this spec, so it pins the one
			// thing it can observe: applying them does not panic.
			Expect(func() {
				modelLoader.SetLRUEvictionRetrySettings(50, 2*time.Second)
			}).ToNot(Panic())
			Expect(modelLoader).ToNot(BeNil())
		})
	})
})