OpenList / internal /net /request.go
lin7zhi's picture
Upload folder using huggingface_hub
619f93d verified
package net
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/OpenListTeam/OpenList/v4/internal/conf"
"github.com/OpenListTeam/OpenList/v4/internal/errs"
"github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
"github.com/rclone/rclone/lib/mmap"
"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
"github.com/aws/aws-sdk-go/aws/awsutil"
log "github.com/sirupsen/logrus"
)
// DefaultDownloadPartSize is the default range of bytes to get at a time when
// using Download().
const DefaultDownloadPartSize = utils.MB * 8
// DefaultDownloadConcurrency is the default number of goroutines to spin up
// when using Download().
const DefaultDownloadConcurrency = 2
// DefaultPartBodyMaxRetries is the default number of retries to make when a part fails to download.
const DefaultPartBodyMaxRetries = 3
var DefaultConcurrencyLimit *ConcurrencyLimit
type Downloader struct {
PartSize int
// PartBodyMaxRetries is the number of retry attempts to make for failed part downloads.
PartBodyMaxRetries int
// The number of goroutines to spin up in parallel when sending parts.
// If this is set to zero, the DefaultDownloadConcurrency value will be used.
//
// Concurrency of 1 will download the parts sequentially.
Concurrency int
//RequestParam HttpRequestParams
HttpClient HttpRequestFunc
*ConcurrencyLimit
}
type HttpRequestFunc func(ctx context.Context, params *HttpRequestParams) (*http.Response, error)
func NewDownloader(options ...func(*Downloader)) *Downloader {
d := &Downloader{ //允许不设置的选项
PartBodyMaxRetries: DefaultPartBodyMaxRetries,
ConcurrencyLimit: DefaultConcurrencyLimit,
}
for _, option := range options {
option(d)
}
return d
}
// Download The Downloader makes multi-thread http requests to remote URL, each chunk(except last one) has PartSize,
// cache some data, then return Reader with assembled data
// Supports range, do not support unknown FileSize, and will fail if FileSize is incorrect
// memory usage is at about Concurrency*PartSize, use this wisely
func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readCloser io.ReadCloser, err error) {
var finalP HttpRequestParams
awsutil.Copy(&finalP, p)
if finalP.Range.Length < 0 || finalP.Range.Start+finalP.Range.Length > finalP.Size {
finalP.Range.Length = finalP.Size - finalP.Range.Start
}
impl := downloader{params: &finalP, cfg: d, ctx: ctx}
// Ensures we don't need nil checks later on
// 必需的选项
if impl.cfg.Concurrency == 0 {
impl.cfg.Concurrency = DefaultDownloadConcurrency
}
if impl.cfg.PartSize == 0 {
impl.cfg.PartSize = DefaultDownloadPartSize
}
if conf.MaxBufferLimit > 0 && impl.cfg.PartSize > conf.MaxBufferLimit {
impl.cfg.PartSize = conf.MaxBufferLimit
}
if impl.cfg.HttpClient == nil {
impl.cfg.HttpClient = DefaultHttpRequestFunc
}
return impl.download()
}
// downloader is the implementation structure used internally by Downloader.
type downloader struct {
ctx context.Context
cancel context.CancelCauseFunc
cfg Downloader
params *HttpRequestParams //http request params
chunkChannel chan chunk //chunk chanel
//wg sync.WaitGroup
m sync.Mutex
nextChunk int //next chunk id
bufs []*Buf
written int64 //total bytes of file downloaded from remote
err error
concurrency int //剩余的并发数,递减。到0时停止并发
maxPart int //有多少个分片
pos int64
maxPos int64
m2 sync.Mutex
readingID int // 正在被读取的id
}
type ConcurrencyLimit struct {
_m sync.Mutex
Limit int // 需要大于0
}
var ErrExceedMaxConcurrency = HttpStatusCodeError(http.StatusTooManyRequests)
func (l *ConcurrencyLimit) sub() error {
l._m.Lock()
defer l._m.Unlock()
if l.Limit-1 < 0 {
return ErrExceedMaxConcurrency
}
l.Limit--
// log.Debugf("ConcurrencyLimit.sub: %d", l.Limit)
return nil
}
func (l *ConcurrencyLimit) add() {
l._m.Lock()
defer l._m.Unlock()
l.Limit++
// log.Debugf("ConcurrencyLimit.add: %d", l.Limit)
}
// 检测是否超过限制
func (d *downloader) concurrencyCheck() error {
if d.cfg.ConcurrencyLimit != nil {
return d.cfg.ConcurrencyLimit.sub()
}
return nil
}
func (d *downloader) concurrencyFinish() {
if d.cfg.ConcurrencyLimit != nil {
d.cfg.ConcurrencyLimit.add()
}
}
// download performs the implementation of the object download across ranged GETs.
func (d *downloader) download() (io.ReadCloser, error) {
if err := d.concurrencyCheck(); err != nil {
return nil, err
}
maxPart := 1
if d.params.Range.Length > int64(d.cfg.PartSize) {
maxPart = int((d.params.Range.Length + int64(d.cfg.PartSize) - 1) / int64(d.cfg.PartSize))
}
if maxPart < d.cfg.Concurrency {
d.cfg.Concurrency = maxPart
}
log.Debugf("cfgConcurrency:%d", d.cfg.Concurrency)
if maxPart == 1 {
resp, err := d.cfg.HttpClient(d.ctx, d.params)
if err != nil {
d.concurrencyFinish()
return nil, err
}
closeFunc := resp.Body.Close
resp.Body = utils.NewReadCloser(resp.Body, func() error {
d.m.Lock()
defer d.m.Unlock()
var err error
if closeFunc != nil {
d.concurrencyFinish()
err = closeFunc()
closeFunc = nil
}
return err
})
return resp.Body, nil
}
d.ctx, d.cancel = context.WithCancelCause(d.ctx)
// workers
d.chunkChannel = make(chan chunk, d.cfg.Concurrency)
d.maxPart = maxPart
d.pos = d.params.Range.Start
d.maxPos = d.params.Range.Start + d.params.Range.Length
d.concurrency = d.cfg.Concurrency
_ = d.sendChunkTask(true)
var rc io.ReadCloser = NewMultiReadCloser(d.bufs[0], d.interrupt, d.finishBuf)
// Return error
return rc, d.err
}
func (d *downloader) sendChunkTask(newConcurrency bool) error {
d.m.Lock()
defer d.m.Unlock()
isNewBuf := d.concurrency > 0
if newConcurrency {
if d.concurrency <= 0 {
return nil
}
if d.nextChunk > 0 { // 第一个不检查,因为已经检查过了
if err := d.concurrencyCheck(); err != nil {
return err
}
}
d.concurrency--
go d.downloadPart()
}
var buf *Buf
if isNewBuf {
buf = NewBuf(d.ctx, d.cfg.PartSize)
d.bufs = append(d.bufs, buf)
} else {
buf = d.getBuf(d.nextChunk)
}
if d.pos < d.maxPos {
finalSize := int64(d.cfg.PartSize)
switch d.nextChunk {
case 0:
// 最小分片在前面有助视频播放?
firstSize := d.params.Range.Length % finalSize
if firstSize > 0 {
minSize := finalSize / 2
if firstSize < minSize { // 最小分片太小就调整到一半
finalSize = minSize
} else {
finalSize = firstSize
}
}
case 1:
firstSize := d.params.Range.Length % finalSize
minSize := finalSize / 2
if firstSize > 0 && firstSize < minSize {
finalSize += firstSize - minSize
}
}
err := buf.Reset(int(finalSize))
if err != nil {
return err
}
ch := chunk{
start: d.pos,
size: finalSize,
id: d.nextChunk,
buf: buf,
newConcurrency: newConcurrency,
}
d.pos += finalSize
d.nextChunk++
d.chunkChannel <- ch
return nil
}
return nil
}
// when the final reader Close, we interrupt
func (d *downloader) interrupt() error {
d.m.Lock()
defer d.m.Unlock()
err := d.err
if err == nil && d.written != d.params.Range.Length {
log.Debugf("Downloader interrupt before finish")
err := fmt.Errorf("interrupted")
d.err = err
}
close(d.chunkChannel)
if d.bufs != nil {
d.cancel(err)
for _, buf := range d.bufs {
buf.Close()
}
d.bufs = nil
if d.concurrency > 0 {
d.concurrency = -d.concurrency
}
log.Debugf("maxConcurrency:%d", d.cfg.Concurrency+d.concurrency)
}
return err
}
func (d *downloader) getBuf(id int) (b *Buf) {
return d.bufs[id%len(d.bufs)]
}
func (d *downloader) finishBuf(id int) (isLast bool, nextBuf *Buf) {
id++
if id >= d.maxPart {
return true, nil
}
_ = d.sendChunkTask(false)
d.readingID = id
return false, d.getBuf(id)
}
// downloadPart is an individual goroutine worker reading from the ch channel
// and performing Http request on the data with a given byte range.
func (d *downloader) downloadPart() {
defer d.concurrencyFinish()
for {
select {
case <-d.ctx.Done():
return
case c, ok := <-d.chunkChannel:
if !ok {
return
}
if d.getErr() != nil {
// Drain the channel if there is an error, to prevent deadlocking
// of download producer.
return
}
if err := d.downloadChunk(&c); err != nil {
if err == errCancelConcurrency {
return
}
if err == context.Canceled {
if e := context.Cause(d.ctx); e != nil {
err = e
}
}
d.setErr(err)
d.cancel(err)
return
}
}
}
}
// downloadChunk downloads the chunk
func (d *downloader) downloadChunk(ch *chunk) error {
log.Debugf("start chunk_%d, %+v", ch.id, ch)
params := d.getParamsFromChunk(ch)
var n int64
var err error
for retry := 0; retry <= d.cfg.PartBodyMaxRetries; retry++ {
if d.getErr() != nil {
return nil
}
n, err = d.tryDownloadChunk(params, ch)
if err == nil {
d.incrWritten(n)
log.Debugf("chunk_%d downloaded", ch.id)
break
}
if d.getErr() != nil {
return nil
}
if utils.IsCanceled(d.ctx) {
return d.ctx.Err()
}
// Check if the returned error is an errNeedRetry.
// If this occurs we unwrap the err to set the underlying error
// and attempt any remaining retries.
if e, ok := err.(*errNeedRetry); ok {
err = e.Unwrap()
if n > 0 {
// 测试:下载时 断开openlist向云盘发起的下载连接
// 校验:下载完后校验文件哈希值 一致
d.incrWritten(n)
ch.start += n
ch.size -= n
params.Range.Start = ch.start
params.Range.Length = ch.size
}
log.Warnf("err chunk_%d, object part download error %s, retrying attempt %d. %v",
ch.id, params.URL, retry, err)
} else if err == errInfiniteRetry {
retry--
continue
} else {
break
}
}
return err
}
var errCancelConcurrency = errors.New("cancel concurrency")
var errInfiniteRetry = errors.New("infinite retry")
func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int64, error) {
resp, err := d.cfg.HttpClient(d.ctx, params)
if err != nil {
statusCode, ok := errs.UnwrapOrSelf(err).(HttpStatusCodeError)
if !ok {
return 0, err
}
if statusCode == http.StatusRequestedRangeNotSatisfiable {
return 0, err
}
if ch.id == 0 { //第1个任务 有限的重试,超过重试就会结束请求
switch statusCode {
default:
return 0, err
case http.StatusTooManyRequests:
case http.StatusBadGateway:
case http.StatusServiceUnavailable:
case http.StatusGatewayTimeout:
}
<-time.After(time.Millisecond * 200)
return 0, &errNeedRetry{err: err}
}
// 来到这 说明第1个分片下载 连接成功了
// 后续分片下载出错都当超载处理
log.Debugf("err chunk_%d, try downloading:%v", ch.id, err)
d.m.Lock()
isCancelConcurrency := ch.newConcurrency
if d.concurrency > 0 { // 取消剩余的并发任务
// 用于计算实际的并发数
d.concurrency = -d.concurrency
isCancelConcurrency = true
}
if isCancelConcurrency {
d.concurrency--
d.chunkChannel <- *ch
d.m.Unlock()
return 0, errCancelConcurrency
}
d.m.Unlock()
if ch.id != d.readingID { //正在被读取的优先重试
d.m2.Lock()
defer d.m2.Unlock()
<-time.After(time.Millisecond * 200)
}
return 0, errInfiniteRetry
}
defer resp.Body.Close()
//only check file size on the first task
if ch.id == 0 {
err = d.checkTotalBytes(resp)
if err != nil {
return 0, err
}
}
_ = d.sendChunkTask(true)
n, err := utils.CopyWithBuffer(ch.buf, resp.Body)
if err != nil {
return n, &errNeedRetry{err: err}
}
if n != ch.size {
err = fmt.Errorf("chunk download size incorrect, expected=%d, got=%d", ch.size, n)
return n, &errNeedRetry{err: err}
}
return n, nil
}
func (d *downloader) getParamsFromChunk(ch *chunk) *HttpRequestParams {
var params HttpRequestParams
awsutil.Copy(&params, d.params)
// Get the getBuf byte range of data
params.Range = http_range.Range{Start: ch.start, Length: ch.size}
return &params
}
func (d *downloader) checkTotalBytes(resp *http.Response) error {
var err error
totalBytes := int64(-1)
contentRange := resp.Header.Get("Content-Range")
if len(contentRange) == 0 {
// ContentRange is nil when the full file contents is provided, and
// is not chunked. Use ContentLength instead.
if resp.ContentLength > 0 {
totalBytes = resp.ContentLength
}
} else {
parts := strings.Split(contentRange, "/")
total := int64(-1)
// Checking for whether a numbered total exists
// If one does not exist, we will assume the total to be -1, undefined,
// and sequentially download each chunk until hitting a 416 error
totalStr := parts[len(parts)-1]
if totalStr != "*" {
total, err = strconv.ParseInt(totalStr, 10, 64)
if err != nil {
err = fmt.Errorf("failed extracting file size")
}
} else {
err = fmt.Errorf("file size unknown")
}
totalBytes = total
}
if totalBytes != d.params.Size && err == nil {
err = fmt.Errorf("expect file size=%d unmatch remote report size=%d, need refresh cache", d.params.Size, totalBytes)
}
if err != nil {
// _ = d.interrupt()
d.setErr(err)
d.cancel(err)
}
return err
}
func (d *downloader) incrWritten(n int64) {
d.m.Lock()
defer d.m.Unlock()
d.written += n
}
// getErr is a thread-safe getter for the error object
func (d *downloader) getErr() error {
d.m.Lock()
defer d.m.Unlock()
return d.err
}
// setErr is a thread-safe setter for the error object
func (d *downloader) setErr(e error) {
d.m.Lock()
defer d.m.Unlock()
d.err = e
}
// Chunk represents a single chunk of data to write by the worker routine.
// This structure also implements an io.SectionReader style interface for
// io.WriterAt, effectively making it an io.SectionWriter (which does not
// exist).
type chunk struct {
start int64
size int64
buf *Buf
id int
newConcurrency bool
}
func DefaultHttpRequestFunc(ctx context.Context, params *HttpRequestParams) (*http.Response, error) {
header := http_range.ApplyRangeToHttpHeader(params.Range, params.HeaderRef)
return RequestHttp(ctx, "GET", header, params.URL)
}
func GetRangeReaderHttpRequestFunc(rangeReader model.RangeReaderIF) HttpRequestFunc {
return func(ctx context.Context, params *HttpRequestParams) (*http.Response, error) {
rc, err := rangeReader.RangeRead(ctx, params.Range)
if err != nil {
return nil, err
}
return &http.Response{
StatusCode: http.StatusPartialContent,
Status: http.StatusText(http.StatusPartialContent),
Body: rc,
Header: http.Header{
"Content-Range": {params.Range.ContentRange(params.Size)},
},
ContentLength: params.Range.Length,
}, nil
}
}
type HttpRequestParams struct {
URL string
//only want data within this range
Range http_range.Range
HeaderRef http.Header
//total file size
Size int64
}
type errNeedRetry struct {
err error
}
func (e *errNeedRetry) Error() string {
return e.err.Error()
}
func (e *errNeedRetry) Unwrap() error {
return e.err
}
type MultiReadCloser struct {
cfg *cfg
closer closerFunc
finish finishBufFUnc
}
type cfg struct {
rPos int //current reader position, start from 0
curBuf *Buf
}
type closerFunc func() error
type finishBufFUnc func(id int) (isLast bool, buf *Buf)
// NewMultiReadCloser to save memory, we re-use limited Buf, and feed data to Read()
func NewMultiReadCloser(buf *Buf, c closerFunc, fb finishBufFUnc) *MultiReadCloser {
return &MultiReadCloser{closer: c, finish: fb, cfg: &cfg{curBuf: buf}}
}
func (mr MultiReadCloser) Read(p []byte) (n int, err error) {
if mr.cfg.curBuf == nil {
return 0, io.EOF
}
n, err = mr.cfg.curBuf.Read(p)
//log.Debugf("read_%d read current buffer, n=%d ,err=%+v", mr.cfg.rPos, n, err)
if err == io.EOF {
log.Debugf("read_%d finished current buffer", mr.cfg.rPos)
isLast, next := mr.finish(mr.cfg.rPos)
if isLast {
return n, io.EOF
}
mr.cfg.curBuf = next
mr.cfg.rPos++
return n, nil
}
if err == context.Canceled {
if e := context.Cause(mr.cfg.curBuf.ctx); e != nil {
err = e
}
}
return n, err
}
func (mr MultiReadCloser) Close() error {
return mr.closer()
}
type Buf struct {
size int //expected size
ctx context.Context
offR int
offW int
rw sync.Mutex
buf []byte
mmap bool
readSignal chan struct{}
readPending bool
}
// NewBuf is a buffer that can have 1 read & 1 write at the same time.
// when read is faster write, immediately feed data to read after written
func NewBuf(ctx context.Context, maxSize int) *Buf {
br := &Buf{
ctx: ctx,
size: maxSize,
readSignal: make(chan struct{}, 1),
}
if conf.MmapThreshold > 0 && maxSize >= conf.MmapThreshold {
m, err := mmap.Alloc(maxSize)
if err == nil {
br.buf = m
br.mmap = true
return br
}
}
br.buf = make([]byte, maxSize)
return br
}
func (br *Buf) Reset(size int) error {
br.rw.Lock()
defer br.rw.Unlock()
if br.buf == nil {
return io.ErrClosedPipe
}
if size > cap(br.buf) {
return fmt.Errorf("reset size %d exceeds max size %d", size, cap(br.buf))
}
br.size = size
br.offR = 0
br.offW = 0
return nil
}
func (br *Buf) Read(p []byte) (int, error) {
if err := br.ctx.Err(); err != nil {
return 0, err
}
if len(p) == 0 {
return 0, nil
}
if br.offR >= br.size {
return 0, io.EOF
}
for {
br.rw.Lock()
if br.buf == nil {
br.rw.Unlock()
return 0, io.ErrClosedPipe
}
if br.offW < br.offR {
br.rw.Unlock()
return 0, io.ErrUnexpectedEOF
}
if br.offW == br.offR {
br.readPending = true
br.rw.Unlock()
select {
case <-br.ctx.Done():
return 0, br.ctx.Err()
case _, ok := <-br.readSignal:
if !ok {
return 0, io.ErrClosedPipe
}
continue
}
}
n := copy(p, br.buf[br.offR:br.offW])
br.offR += n
br.rw.Unlock()
if n < len(p) && br.offR >= br.size {
return n, io.EOF
}
return n, nil
}
}
func (br *Buf) Write(p []byte) (int, error) {
if err := br.ctx.Err(); err != nil {
return 0, err
}
if len(p) == 0 {
return 0, nil
}
br.rw.Lock()
defer br.rw.Unlock()
if br.buf == nil {
return 0, io.ErrClosedPipe
}
if br.offW >= br.size {
return 0, io.ErrShortWrite
}
n := copy(br.buf[br.offW:], p[:min(br.size-br.offW, len(p))])
br.offW += n
if br.readPending {
br.readPending = false
select {
case br.readSignal <- struct{}{}:
default:
}
}
if n < len(p) {
return n, io.ErrShortWrite
}
return n, nil
}
func (br *Buf) Close() error {
br.rw.Lock()
defer br.rw.Unlock()
var err error
if br.mmap {
err = mmap.Free(br.buf)
br.mmap = false
}
br.buf = nil
close(br.readSignal)
return err
}