|
|
package hfapi |
|
|
|
|
|
import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"path/filepath"
	"strings"
	"time"
)
|
|
|
|
|
|
|
|
// Model is one entry of the models listing decoded by SearchModels. Each
// field's JSON tag names the response key it is read from; keys absent from
// a response leave the field at its zero value.
//
// The same tags govern encoding, so marshalling a Model emits these keys.
type Model struct {
	ModelID      string `json:"modelId"`
	Author       string `json:"author"`
	Downloads    int    `json:"downloads"`
	LastModified string `json:"lastModified"`

	// PipelineTag is sent by the Hub API under the snake_case key
	// "pipeline_tag" (like "library_name" below). A camelCase "pipelineTag"
	// tag never matches that key and would leave the field always empty.
	PipelineTag string `json:"pipeline_tag"`

	Private        bool                   `json:"private"`
	Tags           []string               `json:"tags"`
	CreatedAt      string                 `json:"createdAt"`
	UpdatedAt      string                 `json:"updatedAt"`
	Sha            string                 `json:"sha"`
	Config         map[string]interface{} `json:"config"`
	ModelIndex     string                 `json:"model_index"`
	LibraryName    string                 `json:"library_name"`
	MaskToken      string                 `json:"mask_token"`
	TokenizerClass string                 `json:"tokenizer_class"`
}
|
|
|
|
|
|
|
|
// FileInfo is one entry returned by the Hub's tree listing endpoint, as
// decoded by listFilesInPath.
type FileInfo struct {
	// Type is the entry kind. listFilesInPath keeps "file" entries, recurses
	// into "directory"/"folder" entries and ignores any other value.
	Type string `json:"type"`
	// Oid is the object id reported for the entry. GetFileSHA and
	// GetModelDetails fall back to it when no LFS oid is present.
	// NOTE(review): presumably a git object id — confirm its hash algorithm.
	Oid string `json:"oid"`
	// Size is the entry size as reported by the API (presumably bytes).
	Size int64 `json:"size"`
	// Path is used as the file's path within the repository when
	// GetModelDetails builds download URLs.
	Path string `json:"path"`
	// LFS holds Git LFS metadata; nil when the response has no "lfs" object.
	LFS *LFSInfo `json:"lfs,omitempty"`
	// XetHash is decoded from the response but not otherwise used in this
	// part of the package.
	XetHash string `json:"xetHash,omitempty"`
}
|
|
|
|
|
|
|
|
// LFSInfo is the Git LFS metadata attached to a FileInfo entry.
type LFSInfo struct {
	// Oid is the LFS object id. GetFileSHA and GetModelDetails prefer it over
	// FileInfo.Oid when it is non-empty.
	// NOTE(review): presumably the SHA-256 of the file content — confirm.
	Oid string `json:"oid"`
	// Size is the size reported for the LFS object (presumably bytes).
	Size int64 `json:"size"`
	// PointerSize is decoded from "pointerSize"; not otherwise used in this
	// part of the package.
	PointerSize int `json:"pointerSize"`
}
|
|
|
|
|
|
|
|
// ModelFile describes one file of a model repository as assembled by
// GetModelDetails.
type ModelFile struct {
	// Path is the file's path within the repository, copied from FileInfo.Path.
	Path string
	// Size is copied from FileInfo.Size.
	Size int64
	// SHA256 holds the LFS oid when one is present, otherwise FileInfo.Oid.
	// NOTE(review): for non-LFS files this is presumably a git object id, not
	// a SHA-256 digest of the content — confirm before verifying downloads.
	SHA256 string
	// IsReadme is true when the file's base name contains "readme",
	// compared case-insensitively.
	IsReadme bool
	// URL is the "<site>/<repoID>/resolve/main/<Path>" download URL.
	URL string
}
|
|
|
|
|
|
|
|
// ModelDetails summarizes a model repository: its id, author and file list,
// as returned by GetModelDetails.
type ModelDetails struct {
	// ModelID is the repository id passed to GetModelDetails.
	ModelID string
	// Author is the part of ModelID before the first "/".
	Author string
	// Files lists the files found on the repository's main branch.
	Files []ModelFile
	// ReadmeFile points to a copy of one Files entry flagged as a README, or
	// is nil when there is none.
	ReadmeFile *ModelFile
	// ReadmeContent is not filled in by GetModelDetails; callers can populate
	// it with GetReadmeContent.
	ReadmeContent string
}
|
|
|
|
|
|
|
|
// SearchParams holds the options SearchModels sends to the models listing
// endpoint. Every field is sent as a query parameter, including zero values;
// the JSON tags are not used for that encoding.
type SearchParams struct {
	// Sort names the field to order by (GetLatest uses "lastModified").
	Sort string `json:"sort"`
	// Direction is the sort direction. GetLatest passes -1, presumably
	// meaning descending — confirm against the API documentation.
	Direction int `json:"direction"`
	// Limit is sent as the "limit" parameter, presumably the maximum number
	// of results.
	Limit int `json:"limit"`
	// Search is sent as the "search" parameter.
	Search string `json:"search"`
}
|
|
|
|
|
|
|
|
// Client talks to the Hugging Face Hub HTTP API. Create one with NewClient;
// the zero value has a nil HTTP client and cannot make requests.
type Client struct {
	// baseURL is the models API endpoint (".../api/models"). Other endpoints
	// are derived from it by trimming that suffix.
	baseURL string
	// client performs every request made by this Client.
	client *http.Client
}

// NewClient returns a Client targeting https://huggingface.co/api/models.
//
// Requests go through a clone of http.DefaultTransport, so proxy, dial and TLS
// settings are kept. A response-header timeout is added so a server that
// accepts the connection but never answers cannot block a caller forever.
// Unlike http.Client.Timeout, it does not limit the time spent reading a
// response body.
func NewClient() *Client {
	const responseHeaderTimeout = 30 * time.Second

	httpClient := &http.Client{}
	// Clone rather than mutate: DefaultTransport is shared process-wide.
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		transport := base.Clone()
		transport.ResponseHeaderTimeout = responseHeaderTimeout
		httpClient.Transport = transport
	}

	return &Client{
		baseURL: "https://huggingface.co/api/models",
		client:  httpClient,
	}
}
|
|
|
|
|
|
|
|
func (c *Client) SearchModels(params SearchParams) ([]Model, error) { |
|
|
req, err := http.NewRequest("GET", c.baseURL, nil) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("failed to create request: %w", err) |
|
|
} |
|
|
|
|
|
|
|
|
q := req.URL.Query() |
|
|
q.Add("sort", params.Sort) |
|
|
q.Add("direction", fmt.Sprintf("%d", params.Direction)) |
|
|
q.Add("limit", fmt.Sprintf("%d", params.Limit)) |
|
|
q.Add("search", params.Search) |
|
|
req.URL.RawQuery = q.Encode() |
|
|
|
|
|
|
|
|
resp, err := c.client.Do(req) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("failed to make request: %w", err) |
|
|
} |
|
|
defer resp.Body.Close() |
|
|
|
|
|
if resp.StatusCode != http.StatusOK { |
|
|
return nil, fmt.Errorf("failed to fetch models. Status code: %d", resp.StatusCode) |
|
|
} |
|
|
|
|
|
|
|
|
body, err := io.ReadAll(resp.Body) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("failed to read response body: %w", err) |
|
|
} |
|
|
|
|
|
|
|
|
var models []Model |
|
|
if err := json.Unmarshal(body, &models); err != nil { |
|
|
return nil, fmt.Errorf("failed to parse JSON response: %w", err) |
|
|
} |
|
|
|
|
|
return models, nil |
|
|
} |
|
|
|
|
|
|
|
|
func (c *Client) GetLatest(searchTerm string, limit int) ([]Model, error) { |
|
|
params := SearchParams{ |
|
|
Sort: "lastModified", |
|
|
Direction: -1, |
|
|
Limit: limit, |
|
|
Search: searchTerm, |
|
|
} |
|
|
|
|
|
return c.SearchModels(params) |
|
|
} |
|
|
|
|
|
|
|
|
func (c *Client) BaseURL() string { |
|
|
return c.baseURL |
|
|
} |
|
|
|
|
|
|
|
|
func (c *Client) SetBaseURL(url string) { |
|
|
c.baseURL = url |
|
|
} |
|
|
|
|
|
|
|
|
func (c *Client) listFilesInPath(repoID, path string) ([]FileInfo, error) { |
|
|
baseURL := strings.TrimSuffix(c.baseURL, "/api/models") |
|
|
var url string |
|
|
if path == "" { |
|
|
url = fmt.Sprintf("%s/api/models/%s/tree/main", baseURL, repoID) |
|
|
} else { |
|
|
url = fmt.Sprintf("%s/api/models/%s/tree/main/%s", baseURL, repoID, path) |
|
|
} |
|
|
|
|
|
req, err := http.NewRequest("GET", url, nil) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("failed to create request: %w", err) |
|
|
} |
|
|
|
|
|
resp, err := c.client.Do(req) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("failed to make request: %w", err) |
|
|
} |
|
|
defer resp.Body.Close() |
|
|
|
|
|
if resp.StatusCode != http.StatusOK { |
|
|
return nil, fmt.Errorf("failed to fetch files. Status code: %d", resp.StatusCode) |
|
|
} |
|
|
|
|
|
body, err := io.ReadAll(resp.Body) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("failed to read response body: %w", err) |
|
|
} |
|
|
|
|
|
var items []FileInfo |
|
|
if err := json.Unmarshal(body, &items); err != nil { |
|
|
return nil, fmt.Errorf("failed to parse JSON response: %w", err) |
|
|
} |
|
|
|
|
|
var allFiles []FileInfo |
|
|
for _, item := range items { |
|
|
switch item.Type { |
|
|
|
|
|
case "directory", "folder": |
|
|
|
|
|
subPath := item.Path |
|
|
if path != "" { |
|
|
subPath = fmt.Sprintf("%s/%s", path, item.Path) |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
subFiles, err := c.listFilesInPath(repoID, subPath) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("failed to list files in subfolder %s: %w", subPath, err) |
|
|
} |
|
|
|
|
|
allFiles = append(allFiles, subFiles...) |
|
|
case "file": |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
allFiles = append(allFiles, item) |
|
|
} |
|
|
} |
|
|
|
|
|
return allFiles, nil |
|
|
} |
|
|
|
|
|
|
|
|
func (c *Client) ListFiles(repoID string) ([]FileInfo, error) { |
|
|
return c.listFilesInPath(repoID, "") |
|
|
} |
|
|
|
|
|
|
|
|
func (c *Client) GetFileSHA(repoID, fileName string) (string, error) { |
|
|
files, err := c.ListFiles(repoID) |
|
|
if err != nil { |
|
|
return "", fmt.Errorf("failed to list files while getting SHA: %w", err) |
|
|
} |
|
|
|
|
|
for _, file := range files { |
|
|
if filepath.Base(file.Path) == fileName { |
|
|
if file.LFS != nil && file.LFS.Oid != "" { |
|
|
|
|
|
return file.LFS.Oid, nil |
|
|
} |
|
|
|
|
|
return file.Oid, nil |
|
|
} |
|
|
} |
|
|
|
|
|
return "", fmt.Errorf("file %s not found", fileName) |
|
|
} |
|
|
|
|
|
|
|
|
func (c *Client) GetModelDetails(repoID string) (*ModelDetails, error) { |
|
|
files, err := c.ListFiles(repoID) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("failed to list files: %w", err) |
|
|
} |
|
|
|
|
|
details := &ModelDetails{ |
|
|
ModelID: repoID, |
|
|
Author: strings.Split(repoID, "/")[0], |
|
|
Files: make([]ModelFile, 0, len(files)), |
|
|
} |
|
|
|
|
|
|
|
|
baseURL := strings.TrimSuffix(c.baseURL, "/api/models") |
|
|
for _, file := range files { |
|
|
fileName := filepath.Base(file.Path) |
|
|
isReadme := strings.Contains(strings.ToLower(fileName), "readme") |
|
|
|
|
|
|
|
|
sha256 := "" |
|
|
if file.LFS != nil && file.LFS.Oid != "" { |
|
|
sha256 = file.LFS.Oid |
|
|
} else { |
|
|
sha256 = file.Oid |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
fileURL := fmt.Sprintf("%s/%s/resolve/main/%s", baseURL, repoID, file.Path) |
|
|
|
|
|
modelFile := ModelFile{ |
|
|
Path: file.Path, |
|
|
Size: file.Size, |
|
|
SHA256: sha256, |
|
|
IsReadme: isReadme, |
|
|
URL: fileURL, |
|
|
} |
|
|
|
|
|
details.Files = append(details.Files, modelFile) |
|
|
|
|
|
|
|
|
if isReadme && details.ReadmeFile == nil { |
|
|
details.ReadmeFile = &modelFile |
|
|
} |
|
|
} |
|
|
|
|
|
return details, nil |
|
|
} |
|
|
|
|
|
|
|
|
func (c *Client) GetReadmeContent(repoID, readmePath string) (string, error) { |
|
|
baseURL := strings.TrimSuffix(c.baseURL, "/api/models") |
|
|
url := fmt.Sprintf("%s/%s/raw/main/%s", baseURL, repoID, readmePath) |
|
|
|
|
|
req, err := http.NewRequest("GET", url, nil) |
|
|
if err != nil { |
|
|
return "", fmt.Errorf("failed to create request: %w", err) |
|
|
} |
|
|
|
|
|
resp, err := c.client.Do(req) |
|
|
if err != nil { |
|
|
return "", fmt.Errorf("failed to make request: %w", err) |
|
|
} |
|
|
defer resp.Body.Close() |
|
|
|
|
|
if resp.StatusCode != http.StatusOK { |
|
|
return "", fmt.Errorf("failed to fetch readme content. Status code: %d", resp.StatusCode) |
|
|
} |
|
|
|
|
|
body, err := io.ReadAll(resp.Body) |
|
|
if err != nil { |
|
|
return "", fmt.Errorf("failed to read response body: %w", err) |
|
|
} |
|
|
|
|
|
return string(body), nil |
|
|
} |
|
|
|
|
|
|
|
|
func FilterFilesByQuantization(files []ModelFile, quantization string) []ModelFile { |
|
|
var filtered []ModelFile |
|
|
for _, file := range files { |
|
|
fileName := filepath.Base(file.Path) |
|
|
if strings.Contains(strings.ToLower(fileName), strings.ToLower(quantization)) { |
|
|
filtered = append(filtered, file) |
|
|
} |
|
|
} |
|
|
return filtered |
|
|
} |
|
|
|
|
|
|
|
|
func FindPreferredModelFile(files []ModelFile, preferences []string) *ModelFile { |
|
|
for _, preference := range preferences { |
|
|
for i := range files { |
|
|
fileName := filepath.Base(files[i].Path) |
|
|
if strings.Contains(strings.ToLower(fileName), strings.ToLower(preference)) { |
|
|
return &files[i] |
|
|
} |
|
|
} |
|
|
} |
|
|
return nil |
|
|
} |
|
|
|