File size: 9,760 Bytes
0f07ba7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
// Package hfapi is a small client for the Hugging Face model API. It can
// search models, list the files of a model repository (including
// sub-folders), look up file checksums and fetch README content.
package hfapi

import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"path/filepath"
	"strings"
)

// Model is one entry of the model list returned by the Hugging Face API.
// SearchModels decodes the JSON response body into a slice of Model; each
// field is filled from the JSON key named in its struct tag and keeps its
// zero value when the response carries no matching key.
//
// NOTE(review): the tags mix camelCase (e.g. "pipelineTag") and snake_case
// (e.g. "library_name") spellings — confirm each one against the payload the
// API actually returns.
type Model struct {
	ModelID        string                 `json:"modelId"`
	Author         string                 `json:"author"`
	Downloads      int                    `json:"downloads"`
	LastModified   string                 `json:"lastModified"`
	PipelineTag    string                 `json:"pipelineTag"`
	Private        bool                   `json:"private"`
	Tags           []string               `json:"tags"`
	CreatedAt      string                 `json:"createdAt"`
	UpdatedAt      string                 `json:"updatedAt"`
	Sha            string                 `json:"sha"`
	Config         map[string]interface{} `json:"config"`
	ModelIndex     string                 `json:"model_index"`
	LibraryName    string                 `json:"library_name"`
	MaskToken      string                 `json:"mask_token"`
	TokenizerClass string                 `json:"tokenizer_class"`
}

// FileInfo is one entry of a repository tree listing as decoded by
// listFilesInPath. Type is compared against "file", "directory" and "folder"
// while walking the tree. LFS is a pointer so that entries without an "lfs"
// object leave it nil; when it is nil (or its Oid is empty) the checksum
// helpers fall back to Oid.
type FileInfo struct {
	Type    string   `json:"type"`
	Oid     string   `json:"oid"`
	Size    int64    `json:"size"`
	Path    string   `json:"path"`
	LFS     *LFSInfo `json:"lfs,omitempty"`
	XetHash string   `json:"xetHash,omitempty"`
}

// LFSInfo is the optional "lfs" (Large File Storage) object attached to a
// FileInfo entry. GetFileSHA and GetModelDetails prefer its Oid over
// FileInfo.Oid whenever it is non-empty and treat it as the file's SHA256
// checksum.
type LFSInfo struct {
	Oid         string `json:"oid"`
	Size        int64  `json:"size"`
	PointerSize int    `json:"pointerSize"`
}

// ModelFile is the flattened, client-side description of a repository file
// that GetModelDetails builds from a FileInfo entry.
//
// SHA256 holds the LFS Oid when the entry has one and otherwise falls back to
// FileInfo.Oid. IsReadme is true when the lower-cased base name of Path
// contains "readme". URL is the "/resolve/main/" link for the file.
//
// NOTE(review): the non-LFS fallback Oid is presumably a git object id rather
// than a SHA256 digest — confirm before using it for integrity checks.
type ModelFile struct {
	Path     string
	Size     int64
	SHA256   string
	IsReadme bool
	URL      string
}

// ModelDetails is the result of GetModelDetails: the repository ID, the
// author (the part of the ID before the first "/"), every listed file, and a
// pointer to a copy of the first file flagged as a README (nil when there is
// none). GetModelDetails leaves ReadmeContent empty; the text can be fetched
// separately with GetReadmeContent.
type ModelDetails struct {
	ModelID       string
	Author        string
	Files         []ModelFile
	ReadmeFile    *ModelFile
	ReadmeContent string
}

// SearchParams holds the query-string values sent by SearchModels. Every
// field is always added to the request, zero values included; Direction and
// Limit are formatted as decimal integers. GetLatest, for example, uses Sort
// "lastModified" together with Direction -1.
type SearchParams struct {
	Sort      string `json:"sort"`
	Direction int    `json:"direction"`
	Limit     int    `json:"limit"`
	Search    string `json:"search"`
}

// Client talks to the Hugging Face API over HTTP. baseURL is the models
// endpoint and is used as-is for searches; the file-listing, model-details
// and README helpers strip its "/api/models" suffix to obtain the site root.
// client is the http.Client every request is sent through.
type Client struct {
	baseURL string
	client  *http.Client
}

// NewClient returns a Client that targets the public Hugging Face models
// endpoint and sends its requests through a freshly allocated http.Client
// with default settings.
func NewClient() *Client {
	httpClient := &http.Client{}
	return &Client{baseURL: "https://huggingface.co/api/models", client: httpClient}
}

// SearchModels issues a GET request against the client's base URL with the
// "sort", "direction", "limit" and "search" query parameters taken from
// params, and decodes the JSON array in the response into a slice of Model.
//
// All four parameters are always sent, even when they hold zero values. An
// error is returned when the request cannot be built or sent, when the
// response status is not 200 OK, or when the body cannot be read or parsed.
func (c *Client) SearchModels(params SearchParams) ([]Model, error) {
	req, err := http.NewRequest("GET", c.baseURL, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	// Start from whatever query the base URL already carries and append the
	// search parameters to it.
	query := req.URL.Query()
	for _, kv := range [][2]string{
		{"sort", params.Sort},
		{"direction", fmt.Sprintf("%d", params.Direction)},
		{"limit", fmt.Sprintf("%d", params.Limit)},
		{"search", params.Search},
	} {
		query.Add(kv[0], kv[1])
	}
	req.URL.RawQuery = query.Encode()

	resp, err := c.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code != http.StatusOK {
		return nil, fmt.Errorf("failed to fetch models. Status code: %d", code)
	}

	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	var found []Model
	if err = json.Unmarshal(payload, &found); err != nil {
		return nil, fmt.Errorf("failed to parse JSON response: %w", err)
	}
	return found, nil
}

// GetLatest searches for models matching searchTerm, returning at most limit
// results as requested from the API. It is a convenience wrapper around
// SearchModels that sorts by "lastModified" with direction -1 (presumably
// descending, i.e. most recently modified first — confirm against the API).
// The search term is passed through unchanged; no model format is implied.
func (c *Client) GetLatest(searchTerm string, limit int) ([]Model, error) {
	var query SearchParams
	query.Search = searchTerm
	query.Limit = limit
	query.Sort = "lastModified"
	query.Direction = -1
	return c.SearchModels(query)
}

// BaseURL returns the models endpoint URL the client is currently configured
// with: the default set by NewClient or the last value passed to SetBaseURL.
func (c *Client) BaseURL() string {
	return c.baseURL
}

// SetBaseURL replaces the models endpoint URL used by subsequent requests
// (useful for pointing the client at a test server). The value is stored
// as-is, without validation. The file-listing, model-details and README
// helpers strip a trailing "/api/models" from it to find the site root; a URL
// without that suffix is used by them unchanged.
func (c *Client) SetBaseURL(url string) {
	c.baseURL = url
}

// listFilesInPath returns every file found under path on the main branch of
// the repository repoID, descending into sub-directories recursively. An
// empty path lists the repository root.
//
// It requests <site>/api/models/<repoID>/tree/main[/<path>], where <site> is
// the client's base URL without its "/api/models" suffix, and decodes the
// JSON array in the response into FileInfo entries. Entries of type "file"
// are returned with the Path reported by the API; entries of type "directory"
// or "folder" trigger a nested listing; any other type is ignored.
//
// An error is returned when the request cannot be built or sent, when the
// status is not 200 OK, when the body cannot be read or parsed, or when a
// nested listing fails — in which case no partial result is returned.
func (c *Client) listFilesInPath(repoID, path string) ([]FileInfo, error) {
	baseURL := strings.TrimSuffix(c.baseURL, "/api/models")
	url := fmt.Sprintf("%s/api/models/%s/tree/main", baseURL, repoID)
	if path != "" {
		url += "/" + path
	}

	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	resp, err := c.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to fetch files. Status code: %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	var items []FileInfo
	if err := json.Unmarshal(body, &items); err != nil {
		return nil, fmt.Errorf("failed to parse JSON response: %w", err)
	}

	var allFiles []FileInfo
	for _, item := range items {
		switch item.Type {
		case "directory", "folder":
			// An entry that already carries its full path from the repository
			// root (i.e. it is prefixed by the directory being listed) is used
			// as-is: joining it with path again would duplicate the leading
			// segments ("a/a/b") and break listings nested more than one
			// level deep. A bare directory name is still joined with path.
			subPath := item.Path
			if path != "" && !strings.HasPrefix(item.Path, path+"/") {
				subPath = path + "/" + item.Path
			}

			subFiles, err := c.listFilesInPath(repoID, subPath)
			if err != nil {
				return nil, fmt.Errorf("failed to list files in subfolder %s: %w", subPath, err)
			}

			allFiles = append(allFiles, subFiles...)
		case "file":
			// Files keep the Path reported by the API. GetModelDetails builds
			// download URLs from it directly, so it is expected to be relative
			// to the repository root already.
			allFiles = append(allFiles, item)
		}
	}

	return allFiles, nil
}

// ListFiles returns every file of the HuggingFace repository repoID,
// including files nested in sub-folders. It delegates to listFilesInPath with
// an empty path, so the listing starts at the repository root, and returns
// that call's result and error unchanged.
func (c *Client) ListFiles(repoID string) ([]FileInfo, error) {
	return c.listFilesInPath(repoID, "")
}

// GetFileSHA returns the checksum recorded for a file of the repository
// repoID. fileName may be either a bare file name ("model.gguf") or a path
// relative to the repository root ("q8/model.gguf"); the first listed file
// whose full path or base name equals fileName is used.
//
// When the entry has LFS metadata with a non-empty Oid, that Oid is returned
// and treated as the SHA256 hash; otherwise the entry's regular Oid is
// returned.
// NOTE(review): the regular Oid is presumably a git object id, not a SHA256
// digest — confirm before using it for integrity checks.
//
// An error is returned when the file listing fails or when no file matches.
func (c *Client) GetFileSHA(repoID, fileName string) (string, error) {
	files, err := c.ListFiles(repoID)
	if err != nil {
		return "", fmt.Errorf("failed to list files while getting SHA: %w", err)
	}

	for _, file := range files {
		// Accept an exact repository path as well as a bare name, so that
		// same-named files in different folders can be told apart.
		if file.Path != fileName && filepath.Base(file.Path) != fileName {
			continue
		}
		if file.LFS != nil && file.LFS.Oid != "" {
			return file.LFS.Oid, nil
		}
		return file.Oid, nil
	}

	return "", fmt.Errorf("file %s not found", fileName)
}

// GetModelDetails lists every file of the repository repoID and converts the
// listing into a ModelDetails value.
//
// Author is the part of repoID before the first "/". Each file receives a
// checksum (the LFS Oid when present and non-empty, the regular Oid
// otherwise), a README flag (true when the lower-cased base name contains
// "readme") and a "/resolve/main/" URL rooted at the client's base URL
// without its "/api/models" suffix. ReadmeFile points to a copy of the first
// file flagged as a README and stays nil when there is none. ReadmeContent
// is not populated here.
//
// An error is returned when the file listing fails.
func (c *Client) GetModelDetails(repoID string) (*ModelDetails, error) {
	listing, err := c.ListFiles(repoID)
	if err != nil {
		return nil, fmt.Errorf("failed to list files: %w", err)
	}

	owner := strings.Split(repoID, "/")[0]
	site := strings.TrimSuffix(c.baseURL, "/api/models")

	out := &ModelDetails{
		ModelID: repoID,
		Author:  owner,
		Files:   make([]ModelFile, 0, len(listing)),
	}

	for _, entry := range listing {
		// Prefer the LFS object id; fall back to the regular one.
		checksum := entry.Oid
		if entry.LFS != nil && entry.LFS.Oid != "" {
			checksum = entry.LFS.Oid
		}

		baseName := strings.ToLower(filepath.Base(entry.Path))
		mf := ModelFile{
			Path:     entry.Path,
			Size:     entry.Size,
			SHA256:   checksum,
			IsReadme: strings.Contains(baseName, "readme"),
			// /resolve/main/ is the download link (handles LFS properly).
			URL: fmt.Sprintf("%s/%s/resolve/main/%s", site, repoID, entry.Path),
		}
		out.Files = append(out.Files, mf)

		// Only the first README encountered is remembered.
		if mf.IsReadme && out.ReadmeFile == nil {
			readme := mf
			out.ReadmeFile = &readme
		}
	}

	return out, nil
}

// GetReadmeContent downloads the raw text of the file at readmePath on the
// main branch of the repository repoID and returns it as a string.
//
// The request goes to <site>/<repoID>/raw/main/<readmePath>, where <site> is
// the client's base URL without its "/api/models" suffix. An error is
// returned when the request cannot be built or sent, when the status is not
// 200 OK, or when the body cannot be read.
func (c *Client) GetReadmeContent(repoID, readmePath string) (string, error) {
	site := strings.TrimSuffix(c.baseURL, "/api/models")
	target := fmt.Sprintf("%s/%s/raw/main/%s", site, repoID, readmePath)

	req, err := http.NewRequest("GET", target, nil)
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}

	resp, err := c.client.Do(req)
	if err != nil {
		return "", fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code != http.StatusOK {
		return "", fmt.Errorf("failed to fetch readme content. Status code: %d", code)
	}

	content, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("failed to read response body: %w", err)
	}
	return string(content), nil
}

// FilterFilesByQuantization returns the files whose base name contains
// quantization, compared case-insensitively. Only the last path element is
// inspected, so directory names never cause a match. Input order is
// preserved and the input slice is not modified.
//
// The result is nil when nothing matches. An empty quantization matches every
// file, because every string contains the empty string.
func FilterFilesByQuantization(files []ModelFile, quantization string) []ModelFile {
	// Lower-case the search term once instead of once per file.
	needle := strings.ToLower(quantization)

	var filtered []ModelFile
	for _, file := range files {
		name := strings.ToLower(filepath.Base(file.Path))
		if strings.Contains(name, needle) {
			filtered = append(filtered, file)
		}
	}
	return filtered
}

// FindPreferredModelFile returns a pointer to the first file whose base name
// contains one of the given preferences, compared case-insensitively.
// Preferences are tried in order, so an earlier preference wins over a later
// one even when the later one matches a file that appears earlier in files.
//
// The returned pointer refers to an element of the files slice itself, not a
// copy. The result is nil when no file matches any preference, including
// when files or preferences is empty.
func FindPreferredModelFile(files []ModelFile, preferences []string) *ModelFile {
	if len(files) == 0 || len(preferences) == 0 {
		return nil
	}

	// Lower-case every base name once rather than once per preference.
	names := make([]string, len(files))
	for i := range files {
		names[i] = strings.ToLower(filepath.Base(files[i].Path))
	}

	for _, preference := range preferences {
		needle := strings.ToLower(preference)
		for i, name := range names {
			if strings.Contains(name, needle) {
				return &files[i]
			}
		}
	}
	return nil
}