|
|
package oci |
|
|
|
|
|
import ( |
|
|
"context" |
|
|
"errors" |
|
|
"fmt" |
|
|
"io" |
|
|
"net/http" |
|
|
"os" |
|
|
"runtime" |
|
|
"strconv" |
|
|
"strings" |
|
|
"syscall" |
|
|
"time" |
|
|
|
|
|
"github.com/containerd/containerd/archive" |
|
|
registrytypes "github.com/docker/docker/api/types/registry" |
|
|
"github.com/google/go-containerregistry/pkg/authn" |
|
|
"github.com/google/go-containerregistry/pkg/logs" |
|
|
"github.com/google/go-containerregistry/pkg/name" |
|
|
v1 "github.com/google/go-containerregistry/pkg/v1" |
|
|
"github.com/google/go-containerregistry/pkg/v1/mutate" |
|
|
"github.com/google/go-containerregistry/pkg/v1/remote" |
|
|
"github.com/google/go-containerregistry/pkg/v1/remote/transport" |
|
|
"github.com/google/go-containerregistry/pkg/v1/tarball" |
|
|
"github.com/mudler/LocalAI/pkg/xio" |
|
|
) |
|
|
|
|
|
|
|
|
type staticAuth struct { |
|
|
auth *registrytypes.AuthConfig |
|
|
} |
|
|
|
|
|
func (s staticAuth) Authorization() (*authn.AuthConfig, error) { |
|
|
if s.auth == nil { |
|
|
return nil, nil |
|
|
} |
|
|
return &authn.AuthConfig{ |
|
|
Username: s.auth.Username, |
|
|
Password: s.auth.Password, |
|
|
Auth: s.auth.Auth, |
|
|
IdentityToken: s.auth.IdentityToken, |
|
|
RegistryToken: s.auth.RegistryToken, |
|
|
}, nil |
|
|
} |
|
|
|
|
|
var defaultRetryBackoff = remote.Backoff{ |
|
|
Duration: 1.0 * time.Second, |
|
|
Factor: 3.0, |
|
|
Jitter: 0.1, |
|
|
Steps: 3, |
|
|
} |
|
|
|
|
|
var defaultRetryPredicate = func(err error) bool { |
|
|
if err == nil { |
|
|
return false |
|
|
} |
|
|
|
|
|
if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || strings.Contains(err.Error(), "connection refused") { |
|
|
logs.Warn.Printf("retrying %v", err) |
|
|
return true |
|
|
} |
|
|
return false |
|
|
} |
|
|
|
|
|
type progressWriter struct { |
|
|
written int64 |
|
|
total int64 |
|
|
fileName string |
|
|
downloadStatus func(string, string, string, float64) |
|
|
} |
|
|
|
|
|
func formatBytes(bytes int64) string { |
|
|
const unit = 1024 |
|
|
if bytes < unit { |
|
|
return strconv.FormatInt(bytes, 10) + " B" |
|
|
} |
|
|
div, exp := int64(unit), 0 |
|
|
for n := bytes / unit; n >= unit; n /= unit { |
|
|
div *= unit |
|
|
exp++ |
|
|
} |
|
|
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp]) |
|
|
} |
|
|
|
|
|
func (pw *progressWriter) Write(p []byte) (int, error) { |
|
|
n := len(p) |
|
|
pw.written += int64(n) |
|
|
if pw.total > 0 { |
|
|
percentage := float64(pw.written) / float64(pw.total) * 100 |
|
|
|
|
|
pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) |
|
|
} else { |
|
|
pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0) |
|
|
} |
|
|
|
|
|
return n, nil |
|
|
} |
|
|
|
|
|
|
|
|
func ExtractOCIImage(ctx context.Context, img v1.Image, imageRef string, targetDestination string, downloadStatus func(string, string, string, float64)) error { |
|
|
|
|
|
tmpTarFile, err := os.CreateTemp("", "localai-oci-*.tar") |
|
|
if err != nil { |
|
|
return fmt.Errorf("failed to create temporary tar file: %v", err) |
|
|
} |
|
|
defer os.Remove(tmpTarFile.Name()) |
|
|
defer tmpTarFile.Close() |
|
|
|
|
|
|
|
|
err = DownloadOCIImageTar(ctx, img, imageRef, tmpTarFile.Name(), downloadStatus) |
|
|
if err != nil { |
|
|
return fmt.Errorf("failed to download image tar: %v", err) |
|
|
} |
|
|
|
|
|
|
|
|
err = ExtractOCIImageFromTar(ctx, tmpTarFile.Name(), imageRef, targetDestination, downloadStatus) |
|
|
if err != nil { |
|
|
return fmt.Errorf("failed to extract image tar: %v", err) |
|
|
} |
|
|
|
|
|
return nil |
|
|
} |
|
|
|
|
|
func ParseImageParts(image string) (tag, repository, dstimage string) { |
|
|
tag = "latest" |
|
|
repository = "library" |
|
|
if strings.Contains(image, ":") { |
|
|
parts := strings.Split(image, ":") |
|
|
image = parts[0] |
|
|
tag = parts[1] |
|
|
} |
|
|
if strings.Contains("/", image) { |
|
|
parts := strings.Split(image, "/") |
|
|
repository = parts[0] |
|
|
image = parts[1] |
|
|
} |
|
|
dstimage = image |
|
|
return tag, repository, image |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func GetImage(targetImage, targetPlatform string, auth *registrytypes.AuthConfig, t http.RoundTripper) (v1.Image, error) { |
|
|
var platform *v1.Platform |
|
|
var image v1.Image |
|
|
var err error |
|
|
|
|
|
if targetPlatform != "" { |
|
|
platform, err = v1.ParsePlatform(targetPlatform) |
|
|
if err != nil { |
|
|
return image, err |
|
|
} |
|
|
} else { |
|
|
platform, err = v1.ParsePlatform(fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)) |
|
|
if err != nil { |
|
|
return image, err |
|
|
} |
|
|
} |
|
|
|
|
|
ref, err := name.ParseReference(targetImage) |
|
|
if err != nil { |
|
|
return image, err |
|
|
} |
|
|
|
|
|
if t == nil { |
|
|
t = http.DefaultTransport |
|
|
} |
|
|
|
|
|
tr := transport.NewRetry(t, |
|
|
transport.WithRetryBackoff(defaultRetryBackoff), |
|
|
transport.WithRetryPredicate(defaultRetryPredicate), |
|
|
) |
|
|
|
|
|
opts := []remote.Option{ |
|
|
remote.WithTransport(tr), |
|
|
remote.WithPlatform(*platform), |
|
|
} |
|
|
if auth != nil { |
|
|
opts = append(opts, remote.WithAuth(staticAuth{auth})) |
|
|
} else { |
|
|
opts = append(opts, remote.WithAuthFromKeychain(authn.DefaultKeychain)) |
|
|
} |
|
|
|
|
|
image, err = remote.Image(ref, opts...) |
|
|
|
|
|
return image, err |
|
|
} |
|
|
|
|
|
func GetOCIImageSize(targetImage, targetPlatform string, auth *registrytypes.AuthConfig, t http.RoundTripper) (int64, error) { |
|
|
var size int64 |
|
|
var img v1.Image |
|
|
var err error |
|
|
|
|
|
img, err = GetImage(targetImage, targetPlatform, auth, t) |
|
|
if err != nil { |
|
|
return size, err |
|
|
} |
|
|
layers, _ := img.Layers() |
|
|
for _, layer := range layers { |
|
|
s, _ := layer.Size() |
|
|
size += s |
|
|
} |
|
|
|
|
|
return size, nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func DownloadOCIImageTar(ctx context.Context, img v1.Image, imageRef string, tarFilePath string, downloadStatus func(string, string, string, float64)) error { |
|
|
|
|
|
layers, err := img.Layers() |
|
|
if err != nil { |
|
|
return fmt.Errorf("failed to get layers: %v", err) |
|
|
} |
|
|
|
|
|
|
|
|
var totalCompressedSize int64 |
|
|
for _, layer := range layers { |
|
|
size, err := layer.Size() |
|
|
if err != nil { |
|
|
return fmt.Errorf("failed to get layer size: %v", err) |
|
|
} |
|
|
totalCompressedSize += size |
|
|
} |
|
|
|
|
|
|
|
|
tmpDir, err := os.MkdirTemp("", "localai-oci-layers-*") |
|
|
if err != nil { |
|
|
return fmt.Errorf("failed to create temporary directory: %v", err) |
|
|
} |
|
|
defer os.RemoveAll(tmpDir) |
|
|
|
|
|
|
|
|
var downloadedLayers []v1.Layer |
|
|
var downloadedSize int64 |
|
|
|
|
|
|
|
|
imageName := imageRef |
|
|
for i, layer := range layers { |
|
|
layerSize, err := layer.Size() |
|
|
if err != nil { |
|
|
return fmt.Errorf("failed to get layer size: %v", err) |
|
|
} |
|
|
|
|
|
|
|
|
layerFile := fmt.Sprintf("%s/layer-%d.tar.gz", tmpDir, i) |
|
|
file, err := os.Create(layerFile) |
|
|
if err != nil { |
|
|
return fmt.Errorf("failed to create layer file: %v", err) |
|
|
} |
|
|
|
|
|
|
|
|
var writer io.Writer = file |
|
|
if downloadStatus != nil { |
|
|
writer = io.MultiWriter(file, &progressWriter{ |
|
|
total: totalCompressedSize, |
|
|
fileName: fmt.Sprintf("Downloading %d/%d %s", i+1, len(layers), imageName), |
|
|
downloadStatus: downloadStatus, |
|
|
}) |
|
|
} |
|
|
|
|
|
|
|
|
layerReader, err := layer.Compressed() |
|
|
if err != nil { |
|
|
file.Close() |
|
|
return fmt.Errorf("failed to get compressed layer: %v", err) |
|
|
} |
|
|
|
|
|
_, err = xio.Copy(ctx, writer, layerReader) |
|
|
file.Close() |
|
|
if err != nil { |
|
|
return fmt.Errorf("failed to download layer %d: %v", i, err) |
|
|
} |
|
|
|
|
|
|
|
|
downloadedLayer, err := tarball.LayerFromFile(layerFile) |
|
|
if err != nil { |
|
|
return fmt.Errorf("failed to load downloaded layer: %v", err) |
|
|
} |
|
|
|
|
|
downloadedLayers = append(downloadedLayers, downloadedLayer) |
|
|
downloadedSize += layerSize |
|
|
} |
|
|
|
|
|
|
|
|
localImg, err := mutate.AppendLayers(img, downloadedLayers...) |
|
|
if err != nil { |
|
|
return fmt.Errorf("failed to create local image: %v", err) |
|
|
} |
|
|
|
|
|
|
|
|
tarFile, err := os.Create(tarFilePath) |
|
|
if err != nil { |
|
|
return fmt.Errorf("failed to create tar file: %v", err) |
|
|
} |
|
|
defer tarFile.Close() |
|
|
|
|
|
|
|
|
extractReader := mutate.Extract(localImg) |
|
|
_, err = xio.Copy(ctx, tarFile, extractReader) |
|
|
if err != nil { |
|
|
return fmt.Errorf("failed to extract uncompressed tar: %v", err) |
|
|
} |
|
|
|
|
|
return nil |
|
|
} |
|
|
|
|
|
|
|
|
func ExtractOCIImageFromTar(ctx context.Context, tarFilePath, imageRef, targetDestination string, downloadStatus func(string, string, string, float64)) error { |
|
|
|
|
|
tarFile, err := os.Open(tarFilePath) |
|
|
if err != nil { |
|
|
return fmt.Errorf("failed to open tar file: %v", err) |
|
|
} |
|
|
defer tarFile.Close() |
|
|
|
|
|
|
|
|
fileInfo, err := tarFile.Stat() |
|
|
if err != nil { |
|
|
return fmt.Errorf("failed to get file info: %v", err) |
|
|
} |
|
|
|
|
|
var reader io.Reader = tarFile |
|
|
if downloadStatus != nil { |
|
|
reader = io.TeeReader(tarFile, &progressWriter{ |
|
|
total: fileInfo.Size(), |
|
|
fileName: fmt.Sprintf("Extracting %s", imageRef), |
|
|
downloadStatus: downloadStatus, |
|
|
}) |
|
|
} |
|
|
|
|
|
|
|
|
_, err = archive.Apply(ctx, |
|
|
targetDestination, reader, |
|
|
archive.WithNoSameOwner()) |
|
|
|
|
|
return err |
|
|
} |
|
|
|
|
|
|
|
|
func GetOCIImageUncompressedSize(targetImage, targetPlatform string, auth *registrytypes.AuthConfig, t http.RoundTripper) (int64, error) { |
|
|
var totalSize int64 |
|
|
var img v1.Image |
|
|
var err error |
|
|
|
|
|
img, err = GetImage(targetImage, targetPlatform, auth, t) |
|
|
if err != nil { |
|
|
return totalSize, err |
|
|
} |
|
|
|
|
|
layers, err := img.Layers() |
|
|
if err != nil { |
|
|
return totalSize, err |
|
|
} |
|
|
|
|
|
for _, layer := range layers { |
|
|
|
|
|
size, err := layer.Size() |
|
|
if err != nil { |
|
|
return totalSize, err |
|
|
} |
|
|
totalSize += size |
|
|
} |
|
|
|
|
|
return totalSize, nil |
|
|
} |
|
|
|