| package net |
|
|
| import ( |
| "context" |
| "errors" |
| "fmt" |
| "io" |
| "net/http" |
| "strconv" |
| "strings" |
| "sync" |
| "time" |
|
|
| "github.com/OpenListTeam/OpenList/v4/internal/conf" |
| "github.com/OpenListTeam/OpenList/v4/internal/errs" |
| "github.com/OpenListTeam/OpenList/v4/internal/model" |
| "github.com/OpenListTeam/OpenList/v4/pkg/utils" |
| "github.com/rclone/rclone/lib/mmap" |
|
|
| "github.com/OpenListTeam/OpenList/v4/pkg/http_range" |
| "github.com/aws/aws-sdk-go/aws/awsutil" |
| log "github.com/sirupsen/logrus" |
| ) |
|
|
| |
| |
| const DefaultDownloadPartSize = utils.MB * 8 |
|
|
| |
| |
| const DefaultDownloadConcurrency = 2 |
|
|
| |
| const DefaultPartBodyMaxRetries = 3 |
|
|
| var DefaultConcurrencyLimit *ConcurrencyLimit |
|
|
| type Downloader struct { |
| PartSize int |
|
|
| |
| PartBodyMaxRetries int |
|
|
| |
| |
| |
| |
| Concurrency int |
|
|
| |
| HttpClient HttpRequestFunc |
|
|
| *ConcurrencyLimit |
| } |
| type HttpRequestFunc func(ctx context.Context, params *HttpRequestParams) (*http.Response, error) |
|
|
| func NewDownloader(options ...func(*Downloader)) *Downloader { |
| d := &Downloader{ |
| PartBodyMaxRetries: DefaultPartBodyMaxRetries, |
| ConcurrencyLimit: DefaultConcurrencyLimit, |
| } |
| for _, option := range options { |
| option(d) |
| } |
| return d |
| } |
|
|
| |
| |
| |
| |
| func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readCloser io.ReadCloser, err error) { |
|
|
| var finalP HttpRequestParams |
| awsutil.Copy(&finalP, p) |
| if finalP.Range.Length < 0 || finalP.Range.Start+finalP.Range.Length > finalP.Size { |
| finalP.Range.Length = finalP.Size - finalP.Range.Start |
| } |
| impl := downloader{params: &finalP, cfg: d, ctx: ctx} |
|
|
| |
| |
| if impl.cfg.Concurrency == 0 { |
| impl.cfg.Concurrency = DefaultDownloadConcurrency |
| } |
| if impl.cfg.PartSize == 0 { |
| impl.cfg.PartSize = DefaultDownloadPartSize |
| } |
| if conf.MaxBufferLimit > 0 && impl.cfg.PartSize > conf.MaxBufferLimit { |
| impl.cfg.PartSize = conf.MaxBufferLimit |
| } |
| if impl.cfg.HttpClient == nil { |
| impl.cfg.HttpClient = DefaultHttpRequestFunc |
| } |
|
|
| return impl.download() |
| } |
|
|
| |
| type downloader struct { |
| ctx context.Context |
| cancel context.CancelCauseFunc |
| cfg Downloader |
|
|
| params *HttpRequestParams |
| chunkChannel chan chunk |
|
|
| |
| m sync.Mutex |
|
|
| nextChunk int |
| bufs []*Buf |
| written int64 |
| err error |
|
|
| concurrency int |
| maxPart int |
| pos int64 |
| maxPos int64 |
| m2 sync.Mutex |
| readingID int |
| } |
|
|
| type ConcurrencyLimit struct { |
| _m sync.Mutex |
| Limit int |
| } |
|
|
| var ErrExceedMaxConcurrency = HttpStatusCodeError(http.StatusTooManyRequests) |
|
|
| func (l *ConcurrencyLimit) sub() error { |
| l._m.Lock() |
| defer l._m.Unlock() |
| if l.Limit-1 < 0 { |
| return ErrExceedMaxConcurrency |
| } |
| l.Limit-- |
| |
| return nil |
| } |
| func (l *ConcurrencyLimit) add() { |
| l._m.Lock() |
| defer l._m.Unlock() |
| l.Limit++ |
| |
| } |
|
|
| |
| func (d *downloader) concurrencyCheck() error { |
| if d.cfg.ConcurrencyLimit != nil { |
| return d.cfg.ConcurrencyLimit.sub() |
| } |
| return nil |
| } |
| func (d *downloader) concurrencyFinish() { |
| if d.cfg.ConcurrencyLimit != nil { |
| d.cfg.ConcurrencyLimit.add() |
| } |
| } |
|
|
| |
| func (d *downloader) download() (io.ReadCloser, error) { |
| if err := d.concurrencyCheck(); err != nil { |
| return nil, err |
| } |
|
|
| maxPart := 1 |
| if d.params.Range.Length > int64(d.cfg.PartSize) { |
| maxPart = int((d.params.Range.Length + int64(d.cfg.PartSize) - 1) / int64(d.cfg.PartSize)) |
| } |
| if maxPart < d.cfg.Concurrency { |
| d.cfg.Concurrency = maxPart |
| } |
| log.Debugf("cfgConcurrency:%d", d.cfg.Concurrency) |
|
|
| if maxPart == 1 { |
| resp, err := d.cfg.HttpClient(d.ctx, d.params) |
| if err != nil { |
| d.concurrencyFinish() |
| return nil, err |
| } |
| closeFunc := resp.Body.Close |
| resp.Body = utils.NewReadCloser(resp.Body, func() error { |
| d.m.Lock() |
| defer d.m.Unlock() |
| var err error |
| if closeFunc != nil { |
| d.concurrencyFinish() |
| err = closeFunc() |
| closeFunc = nil |
| } |
| return err |
| }) |
| return resp.Body, nil |
| } |
| d.ctx, d.cancel = context.WithCancelCause(d.ctx) |
|
|
| |
| d.chunkChannel = make(chan chunk, d.cfg.Concurrency) |
|
|
| d.maxPart = maxPart |
| d.pos = d.params.Range.Start |
| d.maxPos = d.params.Range.Start + d.params.Range.Length |
| d.concurrency = d.cfg.Concurrency |
| _ = d.sendChunkTask(true) |
|
|
| var rc io.ReadCloser = NewMultiReadCloser(d.bufs[0], d.interrupt, d.finishBuf) |
|
|
| |
| return rc, d.err |
| } |
|
|
| func (d *downloader) sendChunkTask(newConcurrency bool) error { |
| d.m.Lock() |
| defer d.m.Unlock() |
| isNewBuf := d.concurrency > 0 |
| if newConcurrency { |
| if d.concurrency <= 0 { |
| return nil |
| } |
| if d.nextChunk > 0 { |
| if err := d.concurrencyCheck(); err != nil { |
| return err |
| } |
| } |
| d.concurrency-- |
| go d.downloadPart() |
| } |
|
|
| var buf *Buf |
| if isNewBuf { |
| buf = NewBuf(d.ctx, d.cfg.PartSize) |
| d.bufs = append(d.bufs, buf) |
| } else { |
| buf = d.getBuf(d.nextChunk) |
| } |
|
|
| if d.pos < d.maxPos { |
| finalSize := int64(d.cfg.PartSize) |
| switch d.nextChunk { |
| case 0: |
| |
| firstSize := d.params.Range.Length % finalSize |
| if firstSize > 0 { |
| minSize := finalSize / 2 |
| if firstSize < minSize { |
| finalSize = minSize |
| } else { |
| finalSize = firstSize |
| } |
| } |
| case 1: |
| firstSize := d.params.Range.Length % finalSize |
| minSize := finalSize / 2 |
| if firstSize > 0 && firstSize < minSize { |
| finalSize += firstSize - minSize |
| } |
| } |
| err := buf.Reset(int(finalSize)) |
| if err != nil { |
| return err |
| } |
| ch := chunk{ |
| start: d.pos, |
| size: finalSize, |
| id: d.nextChunk, |
| buf: buf, |
|
|
| newConcurrency: newConcurrency, |
| } |
| d.pos += finalSize |
| d.nextChunk++ |
| d.chunkChannel <- ch |
| return nil |
| } |
| return nil |
| } |
|
|
| |
| func (d *downloader) interrupt() error { |
| d.m.Lock() |
| defer d.m.Unlock() |
| err := d.err |
| if err == nil && d.written != d.params.Range.Length { |
| log.Debugf("Downloader interrupt before finish") |
| err := fmt.Errorf("interrupted") |
| d.err = err |
| } |
| close(d.chunkChannel) |
| if d.bufs != nil { |
| d.cancel(err) |
| for _, buf := range d.bufs { |
| buf.Close() |
| } |
| d.bufs = nil |
| if d.concurrency > 0 { |
| d.concurrency = -d.concurrency |
| } |
| log.Debugf("maxConcurrency:%d", d.cfg.Concurrency+d.concurrency) |
| } |
| return err |
| } |
| func (d *downloader) getBuf(id int) (b *Buf) { |
| return d.bufs[id%len(d.bufs)] |
| } |
| func (d *downloader) finishBuf(id int) (isLast bool, nextBuf *Buf) { |
| id++ |
| if id >= d.maxPart { |
| return true, nil |
| } |
|
|
| _ = d.sendChunkTask(false) |
|
|
| d.readingID = id |
| return false, d.getBuf(id) |
| } |
|
|
| |
| |
| func (d *downloader) downloadPart() { |
| defer d.concurrencyFinish() |
| for { |
| select { |
| case <-d.ctx.Done(): |
| return |
| case c, ok := <-d.chunkChannel: |
| if !ok { |
| return |
| } |
| if d.getErr() != nil { |
| |
| |
| return |
| } |
| if err := d.downloadChunk(&c); err != nil { |
| if err == errCancelConcurrency { |
| return |
| } |
| if err == context.Canceled { |
| if e := context.Cause(d.ctx); e != nil { |
| err = e |
| } |
| } |
| d.setErr(err) |
| d.cancel(err) |
| return |
| } |
| } |
| } |
| } |
|
|
| |
| func (d *downloader) downloadChunk(ch *chunk) error { |
| log.Debugf("start chunk_%d, %+v", ch.id, ch) |
| params := d.getParamsFromChunk(ch) |
| var n int64 |
| var err error |
| for retry := 0; retry <= d.cfg.PartBodyMaxRetries; retry++ { |
| if d.getErr() != nil { |
| return nil |
| } |
| n, err = d.tryDownloadChunk(params, ch) |
| if err == nil { |
| d.incrWritten(n) |
| log.Debugf("chunk_%d downloaded", ch.id) |
| break |
| } |
| if d.getErr() != nil { |
| return nil |
| } |
| if utils.IsCanceled(d.ctx) { |
| return d.ctx.Err() |
| } |
| |
| |
| |
| if e, ok := err.(*errNeedRetry); ok { |
| err = e.Unwrap() |
| if n > 0 { |
| |
| |
| d.incrWritten(n) |
| ch.start += n |
| ch.size -= n |
| params.Range.Start = ch.start |
| params.Range.Length = ch.size |
| } |
| log.Warnf("err chunk_%d, object part download error %s, retrying attempt %d. %v", |
| ch.id, params.URL, retry, err) |
| } else if err == errInfiniteRetry { |
| retry-- |
| continue |
| } else { |
| break |
| } |
| } |
|
|
| return err |
| } |
|
|
// Control-flow sentinels returned by tryDownloadChunk.
var (
	// errCancelConcurrency tells the worker to exit after it has handed its
	// chunk back to the queue.
	errCancelConcurrency = errors.New("cancel concurrency")

	// errInfiniteRetry asks downloadChunk to retry without consuming a retry.
	errInfiniteRetry = errors.New("infinite retry")
)
|
|
| func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int64, error) { |
| resp, err := d.cfg.HttpClient(d.ctx, params) |
| if err != nil { |
| statusCode, ok := errs.UnwrapOrSelf(err).(HttpStatusCodeError) |
| if !ok { |
| return 0, err |
| } |
| if statusCode == http.StatusRequestedRangeNotSatisfiable { |
| return 0, err |
| } |
| if ch.id == 0 { |
| switch statusCode { |
| default: |
| return 0, err |
| case http.StatusTooManyRequests: |
| case http.StatusBadGateway: |
| case http.StatusServiceUnavailable: |
| case http.StatusGatewayTimeout: |
| } |
| <-time.After(time.Millisecond * 200) |
| return 0, &errNeedRetry{err: err} |
| } |
|
|
| |
| |
| log.Debugf("err chunk_%d, try downloading:%v", ch.id, err) |
|
|
| d.m.Lock() |
| isCancelConcurrency := ch.newConcurrency |
| if d.concurrency > 0 { |
| |
| d.concurrency = -d.concurrency |
| isCancelConcurrency = true |
| } |
| if isCancelConcurrency { |
| d.concurrency-- |
| d.chunkChannel <- *ch |
| d.m.Unlock() |
| return 0, errCancelConcurrency |
| } |
| d.m.Unlock() |
| if ch.id != d.readingID { |
| d.m2.Lock() |
| defer d.m2.Unlock() |
| <-time.After(time.Millisecond * 200) |
| } |
| return 0, errInfiniteRetry |
| } |
| defer resp.Body.Close() |
| |
| if ch.id == 0 { |
| err = d.checkTotalBytes(resp) |
| if err != nil { |
| return 0, err |
| } |
| } |
| _ = d.sendChunkTask(true) |
| n, err := utils.CopyWithBuffer(ch.buf, resp.Body) |
|
|
| if err != nil { |
| return n, &errNeedRetry{err: err} |
| } |
| if n != ch.size { |
| err = fmt.Errorf("chunk download size incorrect, expected=%d, got=%d", ch.size, n) |
| return n, &errNeedRetry{err: err} |
| } |
|
|
| return n, nil |
| } |
| func (d *downloader) getParamsFromChunk(ch *chunk) *HttpRequestParams { |
| var params HttpRequestParams |
| awsutil.Copy(¶ms, d.params) |
|
|
| |
| params.Range = http_range.Range{Start: ch.start, Length: ch.size} |
| return ¶ms |
| } |
|
|
| func (d *downloader) checkTotalBytes(resp *http.Response) error { |
| var err error |
| totalBytes := int64(-1) |
| contentRange := resp.Header.Get("Content-Range") |
| if len(contentRange) == 0 { |
| |
| |
| if resp.ContentLength > 0 { |
| totalBytes = resp.ContentLength |
| } |
| } else { |
| parts := strings.Split(contentRange, "/") |
|
|
| total := int64(-1) |
|
|
| |
| |
| |
| totalStr := parts[len(parts)-1] |
| if totalStr != "*" { |
| total, err = strconv.ParseInt(totalStr, 10, 64) |
| if err != nil { |
| err = fmt.Errorf("failed extracting file size") |
| } |
| } else { |
| err = fmt.Errorf("file size unknown") |
| } |
|
|
| totalBytes = total |
| } |
| if totalBytes != d.params.Size && err == nil { |
| err = fmt.Errorf("expect file size=%d unmatch remote report size=%d, need refresh cache", d.params.Size, totalBytes) |
| } |
| if err != nil { |
| |
| d.setErr(err) |
| d.cancel(err) |
| } |
| return err |
|
|
| } |
|
|
| func (d *downloader) incrWritten(n int64) { |
| d.m.Lock() |
| defer d.m.Unlock() |
|
|
| d.written += n |
| } |
|
|
| |
| func (d *downloader) getErr() error { |
| d.m.Lock() |
| defer d.m.Unlock() |
|
|
| return d.err |
| } |
|
|
| |
| func (d *downloader) setErr(e error) { |
| d.m.Lock() |
| defer d.m.Unlock() |
|
|
| d.err = e |
| } |
|
|
| |
| |
| |
| |
| type chunk struct { |
| start int64 |
| size int64 |
| buf *Buf |
| id int |
|
|
| newConcurrency bool |
| } |
|
|
| func DefaultHttpRequestFunc(ctx context.Context, params *HttpRequestParams) (*http.Response, error) { |
| header := http_range.ApplyRangeToHttpHeader(params.Range, params.HeaderRef) |
| return RequestHttp(ctx, "GET", header, params.URL) |
| } |
|
|
| func GetRangeReaderHttpRequestFunc(rangeReader model.RangeReaderIF) HttpRequestFunc { |
| return func(ctx context.Context, params *HttpRequestParams) (*http.Response, error) { |
| rc, err := rangeReader.RangeRead(ctx, params.Range) |
| if err != nil { |
| return nil, err |
| } |
|
|
| return &http.Response{ |
| StatusCode: http.StatusPartialContent, |
| Status: http.StatusText(http.StatusPartialContent), |
| Body: rc, |
| Header: http.Header{ |
| "Content-Range": {params.Range.ContentRange(params.Size)}, |
| }, |
| ContentLength: params.Range.Length, |
| }, nil |
| } |
| } |
|
|
| type HttpRequestParams struct { |
| URL string |
| |
| Range http_range.Range |
| HeaderRef http.Header |
| |
| Size int64 |
| } |
// errNeedRetry marks the wrapped error as one that is worth retrying.
type errNeedRetry struct {
	err error // underlying failure
}

// Error returns the message of the wrapped error.
func (e *errNeedRetry) Error() string { return e.err.Error() }

// Unwrap exposes the wrapped error.
func (e *errNeedRetry) Unwrap() error { return e.err }
|
|
| type MultiReadCloser struct { |
| cfg *cfg |
| closer closerFunc |
| finish finishBufFUnc |
| } |
|
|
| type cfg struct { |
| rPos int |
| curBuf *Buf |
| } |
|
|
| type closerFunc func() error |
| type finishBufFUnc func(id int) (isLast bool, buf *Buf) |
|
|
| |
| func NewMultiReadCloser(buf *Buf, c closerFunc, fb finishBufFUnc) *MultiReadCloser { |
| return &MultiReadCloser{closer: c, finish: fb, cfg: &cfg{curBuf: buf}} |
| } |
|
|
| func (mr MultiReadCloser) Read(p []byte) (n int, err error) { |
| if mr.cfg.curBuf == nil { |
| return 0, io.EOF |
| } |
| n, err = mr.cfg.curBuf.Read(p) |
| |
| if err == io.EOF { |
| log.Debugf("read_%d finished current buffer", mr.cfg.rPos) |
|
|
| isLast, next := mr.finish(mr.cfg.rPos) |
| if isLast { |
| return n, io.EOF |
| } |
| mr.cfg.curBuf = next |
| mr.cfg.rPos++ |
| return n, nil |
| } |
| if err == context.Canceled { |
| if e := context.Cause(mr.cfg.curBuf.ctx); e != nil { |
| err = e |
| } |
| } |
| return n, err |
| } |
| func (mr MultiReadCloser) Close() error { |
| return mr.closer() |
| } |
|
|
// Buf is a fixed-capacity buffer with one read offset and one write offset.
// Read waits for Write to supply data until size bytes have been consumed.
type Buf struct {
	size int             // number of bytes expected for the current use
	ctx  context.Context // aborts Read and Write once done
	offR int             // read offset into buf
	offW int             // write offset into buf
	rw   sync.Mutex      // guards buf, the offsets and readPending
	buf  []byte          // backing storage; nil after Close
	mmap bool            // buf was obtained from mmap.Alloc and must be freed

	readSignal  chan struct{} // wakes a Read that is waiting for data
	readPending bool          // a Read is waiting on readSignal
}
|
|
| |
| |
| func NewBuf(ctx context.Context, maxSize int) *Buf { |
| br := &Buf{ |
| ctx: ctx, |
| size: maxSize, |
| readSignal: make(chan struct{}, 1), |
| } |
| if conf.MmapThreshold > 0 && maxSize >= conf.MmapThreshold { |
| m, err := mmap.Alloc(maxSize) |
| if err == nil { |
| br.buf = m |
| br.mmap = true |
| return br |
| } |
| } |
| br.buf = make([]byte, maxSize) |
| return br |
| } |
|
|
| func (br *Buf) Reset(size int) error { |
| br.rw.Lock() |
| defer br.rw.Unlock() |
| if br.buf == nil { |
| return io.ErrClosedPipe |
| } |
| if size > cap(br.buf) { |
| return fmt.Errorf("reset size %d exceeds max size %d", size, cap(br.buf)) |
| } |
| br.size = size |
| br.offR = 0 |
| br.offW = 0 |
| return nil |
| } |
|
|
| func (br *Buf) Read(p []byte) (int, error) { |
| if err := br.ctx.Err(); err != nil { |
| return 0, err |
| } |
| if len(p) == 0 { |
| return 0, nil |
| } |
| if br.offR >= br.size { |
| return 0, io.EOF |
| } |
| for { |
| br.rw.Lock() |
| if br.buf == nil { |
| br.rw.Unlock() |
| return 0, io.ErrClosedPipe |
| } |
|
|
| if br.offW < br.offR { |
| br.rw.Unlock() |
| return 0, io.ErrUnexpectedEOF |
| } |
| if br.offW == br.offR { |
| br.readPending = true |
| br.rw.Unlock() |
| select { |
| case <-br.ctx.Done(): |
| return 0, br.ctx.Err() |
| case _, ok := <-br.readSignal: |
| if !ok { |
| return 0, io.ErrClosedPipe |
| } |
| continue |
| } |
| } |
|
|
| n := copy(p, br.buf[br.offR:br.offW]) |
| br.offR += n |
| br.rw.Unlock() |
| if n < len(p) && br.offR >= br.size { |
| return n, io.EOF |
| } |
| return n, nil |
| } |
| } |
|
|
| func (br *Buf) Write(p []byte) (int, error) { |
| if err := br.ctx.Err(); err != nil { |
| return 0, err |
| } |
| if len(p) == 0 { |
| return 0, nil |
| } |
| br.rw.Lock() |
| defer br.rw.Unlock() |
| if br.buf == nil { |
| return 0, io.ErrClosedPipe |
| } |
| if br.offW >= br.size { |
| return 0, io.ErrShortWrite |
| } |
| n := copy(br.buf[br.offW:], p[:min(br.size-br.offW, len(p))]) |
| br.offW += n |
| if br.readPending { |
| br.readPending = false |
| select { |
| case br.readSignal <- struct{}{}: |
| default: |
| } |
| } |
| if n < len(p) { |
| return n, io.ErrShortWrite |
| } |
| return n, nil |
| } |
|
|
| func (br *Buf) Close() error { |
| br.rw.Lock() |
| defer br.rw.Unlock() |
| var err error |
| if br.mmap { |
| err = mmap.Free(br.buf) |
| br.mmap = false |
| } |
| br.buf = nil |
| close(br.readSignal) |
| return err |
| } |
|
|