| package net |
|
|
| import ( |
| "fmt" |
| "io" |
| "math" |
| "mime/multipart" |
| "net/http" |
| "net/textproto" |
| "strings" |
| "time" |
|
|
| "github.com/alist-org/alist/v3/pkg/http_range" |
| log "github.com/sirupsen/logrus" |
| ) |
|
|
| |
| |
| |
// scanETag parses a single entity-tag at the start of s, after trimming
// surrounding whitespace with textproto.TrimString. Both the strong form
// "text" and the weak form W/"text" are accepted (RFC 7232, section 2.3).
//
// On success it returns the tag exactly as written (including any W/ prefix
// and both double quotes) and the unparsed remainder of the trimmed string.
// If s does not start with a well-formed entity-tag, both results are empty.
func scanETag(s string) (etag string, remain string) {
	s = textproto.TrimString(s)

	offset := 0
	if strings.HasPrefix(s, "W/") {
		offset = len("W/")
	}
	quoted := s[offset:]
	if len(quoted) < 2 || quoted[0] != '"' {
		return "", ""
	}

	for j := 1; j < len(quoted); j++ {
		ch := quoted[j]
		if ch == '"' {
			end := offset + j + 1
			return s[:end], s[end:]
		}
		// Permitted tag characters are %x21, %x23-7E and obs-text (%x80-FF);
		// anything else makes the tag malformed.
		if ch != 0x21 && (ch < 0x23 || ch > 0x7E) && ch < 0x80 {
			return "", ""
		}
	}
	// The opening quote was never closed.
	return "", ""
}
|
|
| |
| |
// etagStrongMatch reports whether a and b are identical strong entity-tags
// (RFC 7232, section 2.3.2). The tag a must be non-empty and must begin with
// a double quote, which rules out weak W/"..." tags.
func etagStrongMatch(a, b string) bool {
	if len(a) == 0 || a[0] != '"' {
		return false
	}
	return a == b
}
|
|
| |
| |
// etagWeakMatch reports whether a and b are equal once any leading "W/"
// weakness marker is removed from each (weak comparison, RFC 7232,
// section 2.3.2).
func etagWeakMatch(a, b string) bool {
	const weakPrefix = "W/"
	opaqueA := strings.TrimPrefix(a, weakPrefix)
	opaqueB := strings.TrimPrefix(b, weakPrefix)
	return opaqueA == opaqueB
}
|
|
| |
| |
// condResult is the outcome of evaluating one conditional request header.
type condResult int

const (
	// condNone: the header was absent or could not be evaluated.
	condNone condResult = 0
	// condTrue: the condition holds.
	condTrue condResult = 1
	// condFalse: the condition does not hold.
	condFalse condResult = 2
)
|
|
| func checkIfMatch(w http.ResponseWriter, r *http.Request) condResult { |
| im := r.Header.Get("If-Match") |
| if im == "" { |
| return condNone |
| } |
| for { |
| im = textproto.TrimString(im) |
| if len(im) == 0 { |
| break |
| } |
| if im[0] == ',' { |
| im = im[1:] |
| continue |
| } |
| if im[0] == '*' { |
| return condTrue |
| } |
| etag, remain := scanETag(im) |
| if etag == "" { |
| break |
| } |
| if etagStrongMatch(etag, w.Header().Get("Etag")) { |
| return condTrue |
| } |
| im = remain |
| } |
|
|
| return condFalse |
| } |
|
|
| func checkIfUnmodifiedSince(r *http.Request, modtime time.Time) condResult { |
| ius := r.Header.Get("If-Unmodified-Since") |
| if ius == "" || isZeroTime(modtime) { |
| return condNone |
| } |
| t, err := http.ParseTime(ius) |
| if err != nil { |
| return condNone |
| } |
|
|
| |
| |
| modtime = modtime.Truncate(time.Second) |
| if ret := modtime.Compare(t); ret <= 0 { |
| return condTrue |
| } |
| return condFalse |
| } |
|
|
| func checkIfNoneMatch(w http.ResponseWriter, r *http.Request) condResult { |
| inm := r.Header.Get("If-None-Match") |
| if inm == "" { |
| return condNone |
| } |
| buf := inm |
| for { |
| buf = textproto.TrimString(buf) |
| if len(buf) == 0 { |
| break |
| } |
| if buf[0] == ',' { |
| buf = buf[1:] |
| continue |
| } |
| if buf[0] == '*' { |
| return condFalse |
| } |
| etag, remain := scanETag(buf) |
| if etag == "" { |
| break |
| } |
| if etagWeakMatch(etag, w.Header().Get("Etag")) { |
| return condFalse |
| } |
| buf = remain |
| } |
| return condTrue |
| } |
|
|
| func checkIfModifiedSince(r *http.Request, modtime time.Time) condResult { |
| if r.Method != "GET" && r.Method != "HEAD" { |
| return condNone |
| } |
| ims := r.Header.Get("If-Modified-Since") |
| if ims == "" || isZeroTime(modtime) { |
| return condNone |
| } |
| t, err := http.ParseTime(ims) |
| if err != nil { |
| return condNone |
| } |
| |
| |
| modtime = modtime.Truncate(time.Second) |
| if ret := modtime.Compare(t); ret <= 0 { |
| return condFalse |
| } |
| return condTrue |
| } |
|
|
| func checkIfRange(w http.ResponseWriter, r *http.Request, modtime time.Time) condResult { |
| if r.Method != "GET" && r.Method != "HEAD" { |
| return condNone |
| } |
| ir := r.Header.Get("If-Range") |
| if ir == "" { |
| return condNone |
| } |
| etag, _ := scanETag(ir) |
| if etag != "" { |
| if etagStrongMatch(etag, w.Header().Get("Etag")) { |
| return condTrue |
| } |
| return condFalse |
| } |
| |
| |
| if modtime.IsZero() { |
| return condFalse |
| } |
| t, err := http.ParseTime(ir) |
| if err != nil { |
| return condFalse |
| } |
| if t.Unix() == modtime.Unix() { |
| return condTrue |
| } |
| return condFalse |
| } |
|
|
// unixEpochTime is the instant of the Unix epoch (1970-01-01T00:00:00Z).
var unixEpochTime = time.Unix(0, 0)

// isZeroTime reports whether t should be treated as an unknown modification
// time: either the zero time.Time or exactly the Unix epoch instant. The
// Last-Modified and If-(Un)Modified-Since helpers in this file skip such
// values.
func isZeroTime(t time.Time) bool {
	if t.IsZero() {
		return true
	}
	return t.Equal(unixEpochTime)
}
|
|
| func setLastModified(w http.ResponseWriter, modtime time.Time) { |
| if !isZeroTime(modtime) { |
| w.Header().Set("Last-Modified", modtime.UTC().Format(http.TimeFormat)) |
| } |
| } |
|
|
// writeNotModified sends a 304 Not Modified status on w. First it removes the
// Content-Type, Content-Length and Content-Encoding headers, which describe a
// body that a 304 response does not carry. It also drops Last-Modified when an
// ETag header is present (RFC 7232, section 4.1 only suggests Last-Modified
// when no ETag is available).
func writeNotModified(w http.ResponseWriter) {
	hdr := w.Header()
	for _, key := range [...]string{"Content-Type", "Content-Length", "Content-Encoding"} {
		delete(hdr, key)
	}
	if etag := hdr.Get("Etag"); etag != "" {
		delete(hdr, "Last-Modified")
	}
	w.WriteHeader(http.StatusNotModified)
}
|
|
| |
| |
| func checkPreconditions(w http.ResponseWriter, r *http.Request, modtime time.Time) (done bool, rangeHeader string) { |
| |
| ch := checkIfMatch(w, r) |
| if ch == condNone { |
| ch = checkIfUnmodifiedSince(r, modtime) |
| } |
| if ch == condFalse { |
| w.WriteHeader(http.StatusPreconditionFailed) |
| return true, "" |
| } |
| switch checkIfNoneMatch(w, r) { |
| case condFalse: |
| if r.Method == "GET" || r.Method == "HEAD" { |
| writeNotModified(w) |
| return true, "" |
| } |
| w.WriteHeader(http.StatusPreconditionFailed) |
| return true, "" |
| case condNone: |
| if checkIfModifiedSince(r, modtime) == condFalse { |
| writeNotModified(w) |
| return true, "" |
| } |
| } |
|
|
| rangeHeader = r.Header.Get("Range") |
| if rangeHeader != "" && checkIfRange(w, r, modtime) == condFalse { |
| rangeHeader = "" |
| } |
| return false, rangeHeader |
| } |
|
|
| func sumRangesSize(ranges []http_range.Range) (size int64) { |
| for _, ra := range ranges { |
| size += ra.Length |
| } |
| return |
| } |
|
|
| |
// countingWriter is an io.Writer that discards everything written to it and
// only keeps a running total of the number of bytes.
type countingWriter int64

// Write adds len(p) to the running total and always succeeds.
func (w *countingWriter) Write(p []byte) (n int, err error) {
	n = len(p)
	*w += countingWriter(n)
	return n, nil
}
|
|
| |
| |
| func rangesMIMESize(ranges []http_range.Range, contentType string, contentSize int64) (encSize int64, err error) { |
| var w countingWriter |
| mw := multipart.NewWriter(&w) |
| for _, ra := range ranges { |
| _, err := mw.CreatePart(ra.MimeHeader(contentType, contentSize)) |
| if err != nil { |
| return 0, err |
| } |
| encSize += ra.Length |
| } |
| err = mw.Close() |
| if err != nil { |
| return 0, err |
| } |
| encSize += int64(w) |
| return encSize, nil |
| } |
|
|
| |
// LimitedReadCloser wraps an io.ReadCloser so that at most remaining bytes
// are read from it. Once that budget is spent, Read returns io.EOF without
// calling the wrapped reader. Close is passed through to the wrapped reader.
type LimitedReadCloser struct {
	rc        io.ReadCloser
	remaining int // bytes still allowed to be read
}

// Read reads into buf. It first shortens buf so that no more than the
// remaining byte budget is requested from the wrapped reader.
func (l *LimitedReadCloser) Read(buf []byte) (int, error) {
	if l.remaining <= 0 {
		return 0, io.EOF
	}
	window := buf
	if len(window) > l.remaining {
		window = window[:l.remaining]
	}
	read, err := l.rc.Read(window)
	l.remaining -= read
	return read, err
}

// Close closes the wrapped reader and returns its error.
func (l *LimitedReadCloser) Close() error {
	return l.rc.Close()
}
|
|
| |
| |
| func GetRangedHttpReader(readCloser io.ReadCloser, offset, length int64) (io.ReadCloser, error) { |
| var length_int int |
| if length > math.MaxInt { |
| return nil, fmt.Errorf("doesnot support length bigger than int32 max ") |
| } |
| length_int = int(length) |
|
|
| if offset > 100*1024*1024 { |
| log.Warnf("offset is more than 100MB, if loading data from internet, high-latency and wasting of bandwidth is expected") |
| } |
|
|
| if _, err := io.Copy(io.Discard, io.LimitReader(readCloser, offset)); err != nil { |
| return nil, err |
| } |
|
|
| |
| return &LimitedReadCloser{readCloser, length_int}, nil |
| } |
|
|