// Commit 2e72a7f (StarrySkyWorld): add multi-part concurrent download with
// single-stream fallback.
package services
import (
"archive/tar"
"archive/zip"
"bufio"
"compress/gzip"
"context"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"fastfileviewer/backend/internal/models"
"gorm.io/gorm"
)
// taskEvent is the payload pushed to Subscribe channels whenever a task's
// status, progress, or log output changes.
type taskEvent struct {
	TaskID   uint              // ID of the task the event belongs to
	Status   models.TaskStatus // task status at the time of the event
	Progress int               // completion percentage (0-100)
	LogLine  string            // single new log line; empty when only progress changed
	Error    string            // error message; non-empty only on failure
}
// TaskService runs queued download/extract tasks on background worker
// goroutines and streams per-task progress events to subscribers.
type TaskService struct {
	db          *gorm.DB     // task persistence
	fs          *FSService   // resolves user-supplied paths (see ResolvePath call sites)
	queue       chan uint    // buffered queue of task IDs awaiting a worker
	workers     int          // number of worker goroutines started by Start
	mu          sync.RWMutex // guards subscribers
	subscribers map[uint]map[chan taskEvent]struct{} // taskID -> set of subscriber channels
}
const (
	// minMultiPartSize is the smallest file (in bytes) worth splitting into
	// concurrent range requests; smaller files use a single stream.
	minMultiPartSize = 16 * 1024 * 1024
	// maxDownloadParts caps the number of concurrent range-request workers.
	maxDownloadParts = 8
	// partRetryLimit is how many attempts each part gets before the whole
	// multi-part download is declared failed.
	partRetryLimit = 3
)
// NewTaskService builds a TaskService backed by the given database and file
// service. A workers value below 1 is bumped to 1 so queued tasks always
// drain once Start is called.
func NewTaskService(database *gorm.DB, fs *FSService, workers int) *TaskService {
	workerCount := workers
	if workerCount < 1 {
		workerCount = 1
	}
	svc := &TaskService{
		db:          database,
		fs:          fs,
		queue:       make(chan uint, 256),
		workers:     workerCount,
		subscribers: make(map[uint]map[chan taskEvent]struct{}),
	}
	return svc
}
// Start launches the configured number of worker goroutines and returns
// immediately; the workers exit when ctx is cancelled.
func (s *TaskService) Start(ctx context.Context) {
	for n := s.workers; n > 0; n-- {
		go s.worker(ctx)
	}
}
// CreateDownloadTask persists a queued download task for url -> targetPath
// and hands its ID to the worker queue. Returns the stored task.
// NOTE(review): the enqueue blocks if the 256-slot queue is full.
func (s *TaskService) CreateDownloadTask(url, targetPath string) (*models.Task, error) {
	t := models.Task{
		Type:       models.TaskTypeDownload,
		Status:     models.TaskStatusQueued,
		Source:     url,
		TargetPath: targetPath,
		Progress:   0,
	}
	if err := s.db.Create(&t).Error; err != nil {
		return nil, err
	}
	s.queue <- t.ID
	return &t, nil
}
// CreateExtractTask persists a queued extraction task for sourceArchive ->
// targetPath and hands its ID to the worker queue. Returns the stored task.
// NOTE(review): the enqueue blocks if the 256-slot queue is full.
func (s *TaskService) CreateExtractTask(sourceArchive, targetPath string) (*models.Task, error) {
	t := models.Task{
		Type:       models.TaskTypeExtract,
		Status:     models.TaskStatusQueued,
		Source:     sourceArchive,
		TargetPath: targetPath,
		Progress:   0,
	}
	if err := s.db.Create(&t).Error; err != nil {
		return nil, err
	}
	s.queue <- t.ID
	return &t, nil
}
// ListTasks returns the most recent tasks (newest first). A limit outside
// (0, 100] is clamped to the default page size of 50.
func (s *TaskService) ListTasks(limit int) ([]models.Task, error) {
	if limit <= 0 || limit > 100 {
		limit = 50
	}
	var out []models.Task
	queryErr := s.db.Order("id desc").Limit(limit).Find(&out).Error
	return out, queryErr
}
// GetTask loads a single task by primary key, returning the database error
// (e.g. record not found) unchanged.
func (s *TaskService) GetTask(id uint) (*models.Task, error) {
	task := new(models.Task)
	if err := s.db.First(task, id).Error; err != nil {
		return nil, err
	}
	return task, nil
}
// Subscribe registers a listener for events of the given task. It returns a
// buffered event channel and a cancel function that unregisters and closes
// it. The cancel function is idempotent: calling it more than once is safe.
func (s *TaskService) Subscribe(taskID uint) (chan taskEvent, func()) {
	ch := make(chan taskEvent, 16)
	s.mu.Lock()
	if s.subscribers[taskID] == nil {
		s.subscribers[taskID] = map[chan taskEvent]struct{}{}
	}
	s.subscribers[taskID][ch] = struct{}{}
	s.mu.Unlock()
	cancel := func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		subs, ok := s.subscribers[taskID]
		if !ok {
			return
		}
		// Guard against double-cancel: closing an already-closed channel
		// panics, so only close while ch is still registered.
		if _, registered := subs[ch]; !registered {
			return
		}
		delete(subs, ch)
		close(ch)
		if len(subs) == 0 {
			delete(s.subscribers, taskID)
		}
	}
	return ch, cancel
}
// worker processes task IDs from the queue until ctx is cancelled.
func (s *TaskService) worker(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case taskID := <-s.queue:
			// runTask persists failures on the task row itself, so the
			// returned error needs no extra handling here.
			_ = s.runTask(taskID)
		}
	}
}
// runTask loads the task, marks it running, dispatches by type, and records
// the terminal status (success at 100% or failed with the error message).
func (s *TaskService) runTask(id uint) error {
	task, err := s.GetTask(id)
	if err != nil {
		return err
	}
	s.setStatus(task, models.TaskStatusRunning, 1, "task started", "")
	var runErr error
	switch task.Type {
	case models.TaskTypeDownload:
		runErr = s.handleDownload(task)
	case models.TaskTypeExtract:
		runErr = s.handleExtract(task)
	default:
		runErr = errors.New("unsupported task type")
	}
	if runErr != nil {
		s.setStatus(task, models.TaskStatusFailed, task.Progress, "", runErr.Error())
		return runErr
	}
	s.setStatus(task, models.TaskStatusSuccess, 100, "task completed", "")
	return nil
}
// handleDownload resolves the target path and downloads task.Source into it.
// When a HEAD probe reports range support and a size of at least
// minMultiPartSize, a concurrent multi-part download is attempted first; on
// any multi-part failure the partial file is removed and a plain
// single-stream download runs as a fallback.
func (s *TaskService) handleDownload(task *models.Task) error {
	target, err := s.fs.ResolvePath(task.TargetPath)
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
		return err
	}
	s.appendLog(task, "starting download: "+task.Source)
	size, supportsRange, probeErr := probeDownload(task.Source)
	if probeErr == nil && supportsRange && size >= minMultiPartSize {
		// Aim for ~8 MiB per part, clamped to [2, maxDownloadParts].
		partCount := int(size / (8 * 1024 * 1024))
		switch {
		case partCount < 2:
			partCount = 2
		case partCount > maxDownloadParts:
			partCount = maxDownloadParts
		}
		s.appendLog(task, fmt.Sprintf("range download enabled: parts=%d, size=%d bytes", partCount, size))
		mpErr := s.multiPartDownload(task, task.Source, target, size, partCount)
		if mpErr == nil {
			s.appendLog(task, "download complete")
			return nil
		}
		s.appendLog(task, "multi-part failed, fallback to single stream: "+mpErr.Error())
		_ = os.Remove(target)
	}
	if err := s.singlePartDownload(task, task.Source, target); err != nil {
		return err
	}
	s.appendLog(task, "download complete")
	return nil
}
// probeDownload issues a HEAD request to url and reports the declared
// content length (-1 when unknown) and whether the server advertises byte
// ranges via Accept-Ranges. Any HTTP status outside [200, 400) is an error.
func probeDownload(url string) (size int64, supportsRange bool, err error) {
	req, err := http.NewRequest(http.MethodHead, url, nil)
	if err != nil {
		return 0, false, err
	}
	// Bound the probe so an unresponsive server cannot stall a worker
	// indefinitely; the actual download requests are issued separately and
	// deliberately carry no overall timeout.
	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Do(req) // #nosec G107 user-provided URLs are expected in this app design
	if err != nil {
		return 0, false, err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 400 {
		return 0, false, fmt.Errorf("probe failed: %s", resp.Status)
	}
	size = resp.ContentLength // -1 when the server does not report a length
	supportsRange = strings.Contains(strings.ToLower(resp.Header.Get("Accept-Ranges")), "bytes")
	return size, supportsRange, nil
}
// singlePartDownload streams url into target with one GET request, updating
// task progress (capped at 99%) roughly every 500ms when the server reports
// a Content-Length.
func (s *TaskService) singlePartDownload(task *models.Task, url, target string) error {
	resp, err := http.Get(url) // #nosec G107 user-provided URLs are expected in this app design
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return fmt.Errorf("download failed: %s", resp.Status)
	}
	out, err := os.Create(target)
	if err != nil {
		return err
	}
	// Backstop close for early returns; the success path closes explicitly
	// below so write-back errors are reported instead of silently dropped.
	defer out.Close()
	total := resp.ContentLength // -1 when unknown; progress reporting is skipped then
	var copied int64
	buf := make([]byte, 256*1024)
	lastTick := time.Now()
	for {
		n, readErr := resp.Body.Read(buf)
		if n > 0 {
			if _, err := out.Write(buf[:n]); err != nil {
				return err
			}
			copied += int64(n)
			if total > 0 && time.Since(lastTick) > 500*time.Millisecond {
				progress := int((copied * 100) / total)
				if progress > 99 {
					progress = 99 // 100% is reserved for overall task completion
				}
				s.setProgress(task, progress)
				lastTick = time.Now()
			}
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return readErr
		}
	}
	// Detect truncated responses when the server declared a length.
	if total > 0 && copied != total {
		return fmt.Errorf("download incomplete: expected %d got %d bytes", total, copied)
	}
	// Surface close/flush errors (e.g. disk full) instead of ignoring them.
	if err := out.Close(); err != nil {
		return err
	}
	return nil
}
// multiPartDownload fetches url into target using `parts` concurrent range
// requests. The file is pre-sized with Truncate and each part writes its own
// region via WriteAt, so parts never overlap. Aggregate progress is reported
// every 500ms (capped at 99%). Each part retries up to partRetryLimit times
// with linear backoff; any part exhausting its retries fails the download.
// The caller guarantees total >= minMultiPartSize, so total is never zero.
func (s *TaskService) multiPartDownload(task *models.Task, url, target string, total int64, parts int) error {
	file, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
	if err != nil {
		return err
	}
	// Backstop close for error paths; the success path closes explicitly
	// below so a failed flush is reported rather than swallowed.
	defer file.Close()
	if err := file.Truncate(total); err != nil {
		return err
	}
	var downloaded atomic.Int64
	stopProgress := make(chan struct{})
	go func() {
		ticker := time.NewTicker(500 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-stopProgress:
				return
			case <-ticker.C:
				progress := int((downloaded.Load() * 100) / total)
				if progress > 99 {
					progress = 99 // 100% is reserved for overall task completion
				}
				s.setProgress(task, progress)
			}
		}
	}()
	var wg sync.WaitGroup
	errCh := make(chan error, parts)
	chunkSize := total / int64(parts)
	for i := 0; i < parts; i++ {
		start := int64(i) * chunkSize
		end := start + chunkSize - 1
		if i == parts-1 {
			end = total - 1 // last part absorbs the integer-division remainder
		}
		wg.Add(1)
		go func(partID int, rangeStart, rangeEnd int64) {
			defer wg.Done()
			var lastErr error
			for attempt := 1; attempt <= partRetryLimit; attempt++ {
				err := downloadRange(url, file, rangeStart, rangeEnd, &downloaded)
				if err == nil {
					return
				}
				lastErr = err
				time.Sleep(time.Duration(attempt) * 400 * time.Millisecond)
			}
			errCh <- fmt.Errorf("part %d failed: %w", partID, lastErr)
		}(i+1, start, end)
	}
	wg.Wait()
	close(stopProgress)
	close(errCh)
	for err := range errCh {
		if err != nil {
			return err
		}
	}
	// Report close errors so a download whose final flush failed is not
	// recorded as a success.
	if err := file.Close(); err != nil {
		return err
	}
	s.setProgress(task, 99)
	return nil
}
// downloadRange fetches bytes [start, end] of url and writes them into out
// at the matching offsets, adding every written byte count to downloaded.
// A non-206 response is an error (a 200 would mean the server ignored the
// Range header and the body must not be written at an offset).
func downloadRange(url string, out *os.File, start, end int64, downloaded *atomic.Int64) error {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return err
	}
	req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
	resp, err := http.DefaultClient.Do(req) // #nosec G107 user-provided URLs are expected in this app design
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusPartialContent {
		return fmt.Errorf("server does not honor range request: %s", resp.Status)
	}
	expected := end - start + 1
	// Cap the read at the requested span: a misbehaving server that sends
	// more bytes than asked must never overwrite the adjacent part's region
	// of the shared file.
	reader := io.LimitReader(resp.Body, expected)
	buf := make([]byte, 256*1024)
	offset := start
	var written int64
	for {
		n, readErr := reader.Read(buf)
		if n > 0 {
			if _, err := out.WriteAt(buf[:n], offset); err != nil {
				return err
			}
			offset += int64(n)
			written += int64(n)
			downloaded.Add(int64(n))
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return readErr
		}
	}
	if written != expected {
		return fmt.Errorf("range incomplete: expected %d got %d", expected, written)
	}
	return nil
}
// handleExtract resolves the archive and destination paths, ensures the
// destination directory exists, and dispatches to the extractor matching the
// archive's (lowercased) file extension.
func (s *TaskService) handleExtract(task *models.Task) error {
	archivePath, err := s.fs.ResolvePath(task.Source)
	if err != nil {
		return err
	}
	destDir, err := s.fs.ResolvePath(task.TargetPath)
	if err != nil {
		return err
	}
	if err := os.MkdirAll(destDir, 0o755); err != nil {
		return err
	}
	name := strings.ToLower(archivePath)
	if strings.HasSuffix(name, ".zip") {
		return s.extractZip(task, archivePath, destDir)
	}
	if strings.HasSuffix(name, ".tar.gz") || strings.HasSuffix(name, ".tgz") {
		return s.extractTarGz(task, archivePath, destDir)
	}
	return errors.New("unsupported archive format")
}
// extractZip unpacks the zip archive at src into target, rejecting entries
// whose joined path would escape target (zip slip) and reporting per-entry
// progress capped at 99%.
func (s *TaskService) extractZip(task *models.Task, src, target string) error {
	r, err := zip.OpenReader(src)
	if err != nil {
		return err
	}
	defer r.Close()
	total := len(r.File)
	if total == 0 {
		return nil
	}
	for i, f := range r.File {
		dest := filepath.Join(target, f.Name)
		if !isSubPath(target, dest) {
			return fmt.Errorf("zip slip detected: %s", f.Name)
		}
		if f.FileInfo().IsDir() {
			if err := os.MkdirAll(dest, 0o755); err != nil {
				return err
			}
		} else if err := extractZipEntry(f, dest); err != nil {
			return err
		}
		progress := int((int64(i+1) * 100) / int64(total))
		if progress > 99 {
			progress = 99 // 100% is reserved for overall task completion
		}
		s.setProgress(task, progress)
	}
	s.appendLog(task, "zip extract complete")
	return nil
}

// extractZipEntry writes one non-directory zip entry to dest, preserving the
// archived permission bits and reporting close (flush) errors.
// NOTE(review): the output size is not capped, so a hostile archive can act
// as a decompression bomb — confirm archives only come from trusted users.
func extractZipEntry(f *zip.File, dest string) error {
	if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
		return err
	}
	in, err := f.Open()
	if err != nil {
		return err
	}
	defer in.Close()
	mode := f.FileInfo().Mode().Perm()
	if mode == 0 {
		mode = 0o644 // some archivers store no permission bits at all
	}
	out, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
	if err != nil {
		return err
	}
	if _, err := io.Copy(out, in); err != nil {
		out.Close()
		return err
	}
	// Close errors matter: a failed flush means the file on disk is short.
	return out.Close()
}
// extractTarGz unpacks the gzip-compressed tarball at src into target,
// rejecting entries whose joined path would escape target. Only directories
// and regular files are materialized; other entry types (symlinks, hard
// links, devices) are skipped silently — presumably to keep links from
// escaping the sandbox; confirm this is intended.
// NOTE(review): the output size is not capped, so a hostile archive can act
// as a decompression bomb — confirm archives only come from trusted users.
func (s *TaskService) extractTarGz(task *models.Task, src, target string) error {
	f, err := os.Open(src)
	if err != nil {
		return err
	}
	defer f.Close()
	gzr, err := gzip.NewReader(f)
	if err != nil {
		return err
	}
	defer gzr.Close()
	tr := tar.NewReader(gzr)
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
		dest := filepath.Join(target, hdr.Name)
		if !isSubPath(target, dest) {
			return fmt.Errorf("tar path escape detected: %s", hdr.Name)
		}
		switch hdr.Typeflag {
		case tar.TypeDir:
			if err := os.MkdirAll(dest, 0o755); err != nil {
				return err
			}
		case tar.TypeReg:
			if err := writeTarFile(tr, dest, hdr.Mode); err != nil {
				return err
			}
		}
	}
	// Tar streams do not expose an entry count up front, so progress jumps
	// straight to 99% when extraction finishes.
	s.setProgress(task, 99)
	s.appendLog(task, "tar.gz extract complete")
	return nil
}

// writeTarFile writes one regular tar entry to dest, preserving the archived
// permission bits and reporting close (flush) errors.
func writeTarFile(tr *tar.Reader, dest string, rawMode int64) error {
	if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
		return err
	}
	mode := os.FileMode(rawMode).Perm()
	if mode == 0 {
		mode = 0o644 // some archivers store no permission bits at all
	}
	out, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
	if err != nil {
		return err
	}
	if _, err := io.Copy(out, tr); err != nil {
		out.Close()
		return err
	}
	// Close errors matter: a failed flush means the file on disk is short.
	return out.Close()
}
// isSubPath reports whether maybeChild, after cleaning, is root itself or a
// descendant of root. It is the guard that keeps archive entries from being
// written outside the extraction directory.
func isSubPath(root, maybeChild string) bool {
	cleanRoot := filepath.Clean(root)
	cleanChild := filepath.Clean(maybeChild)
	if cleanChild == cleanRoot {
		return true
	}
	// Require a full path-component boundary so e.g. "/a/bc" is not treated
	// as being inside "/a/b".
	boundary := cleanRoot + string(os.PathSeparator)
	return strings.HasPrefix(cleanChild, boundary)
}
// setProgress persists the task's progress percentage and broadcasts it to
// subscribers. Database errors are ignored (best-effort, matching the other
// event writers in this file).
func (s *TaskService) setProgress(task *models.Task, progress int) {
	task.Progress = progress
	s.db.Model(task).Updates(map[string]interface{}{"progress": progress})
	evt := taskEvent{
		TaskID:   task.ID,
		Status:   task.Status,
		Progress: progress,
	}
	s.publish(task.ID, evt)
}
// appendLog appends one line to the task's stored log, persists it, and
// broadcasts the new line (with the task's current status and progress).
func (s *TaskService) appendLog(task *models.Task, line string) {
	combined := task.Logs + "\n" + line
	task.Logs = strings.TrimSpace(combined)
	s.db.Model(task).Update("logs", task.Logs)
	evt := taskEvent{
		TaskID:   task.ID,
		Status:   task.Status,
		Progress: task.Progress,
		LogLine:  line,
	}
	s.publish(task.ID, evt)
}
// setStatus records a status/progress/error transition on the task, appends
// logLine to the stored log when non-empty, persists everything in one
// update, and broadcasts the change to subscribers.
func (s *TaskService) setStatus(task *models.Task, status models.TaskStatus, progress int, logLine, errorMsg string) {
	task.Status = status
	task.Progress = progress
	task.Error = errorMsg
	changes := map[string]interface{}{
		"status":   status,
		"progress": progress,
		"error":    errorMsg,
	}
	if logLine != "" {
		task.Logs = strings.TrimSpace(task.Logs + "\n" + logLine)
		changes["logs"] = task.Logs
	}
	s.db.Model(task).Updates(changes)
	evt := taskEvent{
		TaskID:   task.ID,
		Status:   status,
		Progress: progress,
		LogLine:  logLine,
		Error:    errorMsg,
	}
	s.publish(task.ID, evt)
}
// publish fans an event out to every subscriber of the task. Sends are
// non-blocking: a subscriber whose buffer is full simply misses the event
// rather than stalling the worker goroutine.
func (s *TaskService) publish(taskID uint, event taskEvent) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	for subscriber := range s.subscribers[taskID] {
		select {
		case subscriber <- event:
		default: // drop for slow consumers
		}
	}
}