Spaces:
Sleeping
Sleeping
| package services | |
| import ( | |
| "archive/tar" | |
| "archive/zip" | |
| "bufio" | |
| "compress/gzip" | |
| "context" | |
| "errors" | |
| "fmt" | |
| "io" | |
| "net/http" | |
| "os" | |
| "path/filepath" | |
| "strings" | |
| "sync" | |
| "sync/atomic" | |
| "time" | |
| "fastfileviewer/backend/internal/models" | |
| "gorm.io/gorm" | |
| ) | |
// taskEvent is a single status/progress notification fanned out to the
// subscribers of one task (see Subscribe/publish).
type taskEvent struct {
	TaskID   uint              // ID of the task this event belongs to
	Status   models.TaskStatus // task status at the time the event was emitted
	Progress int               // completion percentage; capped at 99 until final success
	LogLine  string            // optional log line appended together with this event
	Error    string            // optional error message, set when the task fails
}
// TaskService persists download/extract tasks, runs them on a pool of
// background workers, and streams progress events to per-task subscribers.
type TaskService struct {
	db      *gorm.DB   // task persistence
	fs      *FSService // resolves user-supplied paths (see ResolvePath calls)
	queue   chan uint  // buffered queue of task IDs awaiting a worker
	workers int        // number of worker goroutines launched by Start
	// mu guards subscribers; publish takes the read lock, Subscribe/cancel
	// take the write lock.
	mu          sync.RWMutex
	subscribers map[uint]map[chan taskEvent]struct{} // taskID -> set of subscriber channels
}
const (
	// minMultiPartSize is the smallest payload (16 MiB) for which a
	// parallel ranged download is attempted; smaller files use one stream.
	minMultiPartSize = 16 * 1024 * 1024
	// maxDownloadParts caps the number of concurrent range requests.
	maxDownloadParts = 8
	// partRetryLimit is how many attempts each range part gets before the
	// whole multi-part download is abandoned.
	partRetryLimit = 3
)
| func NewTaskService(database *gorm.DB, fs *FSService, workers int) *TaskService { | |
| if workers < 1 { | |
| workers = 1 | |
| } | |
| return &TaskService{ | |
| db: database, | |
| fs: fs, | |
| queue: make(chan uint, 256), | |
| workers: workers, | |
| subscribers: map[uint]map[chan taskEvent]struct{}{}, | |
| } | |
| } | |
| func (s *TaskService) Start(ctx context.Context) { | |
| for i := 0; i < s.workers; i++ { | |
| go s.worker(ctx) | |
| } | |
| } | |
| func (s *TaskService) CreateDownloadTask(url, targetPath string) (*models.Task, error) { | |
| task := &models.Task{ | |
| Type: models.TaskTypeDownload, | |
| Status: models.TaskStatusQueued, | |
| Source: url, | |
| TargetPath: targetPath, | |
| Progress: 0, | |
| } | |
| if err := s.db.Create(task).Error; err != nil { | |
| return nil, err | |
| } | |
| s.queue <- task.ID | |
| return task, nil | |
| } | |
| func (s *TaskService) CreateExtractTask(sourceArchive, targetPath string) (*models.Task, error) { | |
| task := &models.Task{ | |
| Type: models.TaskTypeExtract, | |
| Status: models.TaskStatusQueued, | |
| Source: sourceArchive, | |
| TargetPath: targetPath, | |
| Progress: 0, | |
| } | |
| if err := s.db.Create(task).Error; err != nil { | |
| return nil, err | |
| } | |
| s.queue <- task.ID | |
| return task, nil | |
| } | |
| func (s *TaskService) ListTasks(limit int) ([]models.Task, error) { | |
| if limit <= 0 || limit > 100 { | |
| limit = 50 | |
| } | |
| var tasks []models.Task | |
| err := s.db.Order("id desc").Limit(limit).Find(&tasks).Error | |
| return tasks, err | |
| } | |
| func (s *TaskService) GetTask(id uint) (*models.Task, error) { | |
| var task models.Task | |
| if err := s.db.First(&task, id).Error; err != nil { | |
| return nil, err | |
| } | |
| return &task, nil | |
| } | |
| func (s *TaskService) Subscribe(taskID uint) (chan taskEvent, func()) { | |
| ch := make(chan taskEvent, 16) | |
| s.mu.Lock() | |
| if s.subscribers[taskID] == nil { | |
| s.subscribers[taskID] = map[chan taskEvent]struct{}{} | |
| } | |
| s.subscribers[taskID][ch] = struct{}{} | |
| s.mu.Unlock() | |
| cancel := func() { | |
| s.mu.Lock() | |
| defer s.mu.Unlock() | |
| delete(s.subscribers[taskID], ch) | |
| close(ch) | |
| if len(s.subscribers[taskID]) == 0 { | |
| delete(s.subscribers, taskID) | |
| } | |
| } | |
| return ch, cancel | |
| } | |
| func (s *TaskService) worker(ctx context.Context) { | |
| for { | |
| select { | |
| case <-ctx.Done(): | |
| return | |
| case id := <-s.queue: | |
| _ = s.runTask(id) | |
| } | |
| } | |
| } | |
| func (s *TaskService) runTask(id uint) error { | |
| task, err := s.GetTask(id) | |
| if err != nil { | |
| return err | |
| } | |
| s.setStatus(task, models.TaskStatusRunning, 1, "task started", "") | |
| switch task.Type { | |
| case models.TaskTypeDownload: | |
| err = s.handleDownload(task) | |
| case models.TaskTypeExtract: | |
| err = s.handleExtract(task) | |
| default: | |
| err = errors.New("unsupported task type") | |
| } | |
| if err != nil { | |
| s.setStatus(task, models.TaskStatusFailed, task.Progress, "", err.Error()) | |
| return err | |
| } | |
| s.setStatus(task, models.TaskStatusSuccess, 100, "task completed", "") | |
| return nil | |
| } | |
| func (s *TaskService) handleDownload(task *models.Task) error { | |
| target, err := s.fs.ResolvePath(task.TargetPath) | |
| if err != nil { | |
| return err | |
| } | |
| if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { | |
| return err | |
| } | |
| s.appendLog(task, "starting download: "+task.Source) | |
| size, supportsRange, err := probeDownload(task.Source) | |
| if err == nil && supportsRange && size >= minMultiPartSize { | |
| parts := int(size / (8 * 1024 * 1024)) | |
| if parts < 2 { | |
| parts = 2 | |
| } | |
| if parts > maxDownloadParts { | |
| parts = maxDownloadParts | |
| } | |
| s.appendLog(task, fmt.Sprintf("range download enabled: parts=%d, size=%d bytes", parts, size)) | |
| if err := s.multiPartDownload(task, task.Source, target, size, parts); err == nil { | |
| s.appendLog(task, "download complete") | |
| return nil | |
| } else { | |
| s.appendLog(task, "multi-part failed, fallback to single stream: "+err.Error()) | |
| _ = os.Remove(target) | |
| } | |
| } | |
| if err := s.singlePartDownload(task, task.Source, target); err != nil { | |
| return err | |
| } | |
| s.appendLog(task, "download complete") | |
| return nil | |
| } | |
| func probeDownload(url string) (size int64, supportsRange bool, err error) { | |
| req, err := http.NewRequest(http.MethodHead, url, nil) | |
| if err != nil { | |
| return 0, false, err | |
| } | |
| resp, err := http.DefaultClient.Do(req) // #nosec G107 user-provided URLs are expected in this app design | |
| if err != nil { | |
| return 0, false, err | |
| } | |
| defer resp.Body.Close() | |
| if resp.StatusCode < 200 || resp.StatusCode >= 400 { | |
| return 0, false, fmt.Errorf("probe failed: %s", resp.Status) | |
| } | |
| size = resp.ContentLength | |
| supportsRange = strings.Contains(strings.ToLower(resp.Header.Get("Accept-Ranges")), "bytes") | |
| return size, supportsRange, nil | |
| } | |
| func (s *TaskService) singlePartDownload(task *models.Task, url, target string) error { | |
| resp, err := http.Get(url) // #nosec G107 user-provided URLs are expected in this app design | |
| if err != nil { | |
| return err | |
| } | |
| defer resp.Body.Close() | |
| if resp.StatusCode < 200 || resp.StatusCode >= 300 { | |
| return fmt.Errorf("download failed: %s", resp.Status) | |
| } | |
| out, err := os.Create(target) | |
| if err != nil { | |
| return err | |
| } | |
| defer out.Close() | |
| total := resp.ContentLength | |
| var copied int64 | |
| reader := bufio.NewReader(resp.Body) | |
| buf := make([]byte, 256*1024) | |
| lastTick := time.Now() | |
| for { | |
| n, readErr := reader.Read(buf) | |
| if n > 0 { | |
| if _, err := out.Write(buf[:n]); err != nil { | |
| return err | |
| } | |
| copied += int64(n) | |
| if total > 0 && time.Since(lastTick) > 500*time.Millisecond { | |
| progress := int((copied * 100) / total) | |
| if progress > 99 { | |
| progress = 99 | |
| } | |
| s.setProgress(task, progress) | |
| lastTick = time.Now() | |
| } | |
| } | |
| if readErr == io.EOF { | |
| break | |
| } | |
| if readErr != nil { | |
| return readErr | |
| } | |
| } | |
| return nil | |
| } | |
| func (s *TaskService) multiPartDownload(task *models.Task, url, target string, total int64, parts int) error { | |
| file, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) | |
| if err != nil { | |
| return err | |
| } | |
| defer file.Close() | |
| if err := file.Truncate(total); err != nil { | |
| return err | |
| } | |
| var downloaded atomic.Int64 | |
| stopProgress := make(chan struct{}) | |
| go func() { | |
| ticker := time.NewTicker(500 * time.Millisecond) | |
| defer ticker.Stop() | |
| for { | |
| select { | |
| case <-stopProgress: | |
| return | |
| case <-ticker.C: | |
| progress := int((downloaded.Load() * 100) / total) | |
| if progress > 99 { | |
| progress = 99 | |
| } | |
| s.setProgress(task, progress) | |
| } | |
| } | |
| }() | |
| var wg sync.WaitGroup | |
| errCh := make(chan error, parts) | |
| chunkSize := total / int64(parts) | |
| for i := 0; i < parts; i++ { | |
| start := int64(i) * chunkSize | |
| end := start + chunkSize - 1 | |
| if i == parts-1 { | |
| end = total - 1 | |
| } | |
| wg.Add(1) | |
| go func(partID int, rangeStart, rangeEnd int64) { | |
| defer wg.Done() | |
| var lastErr error | |
| for attempt := 1; attempt <= partRetryLimit; attempt++ { | |
| err := downloadRange(url, file, rangeStart, rangeEnd, &downloaded) | |
| if err == nil { | |
| return | |
| } | |
| lastErr = err | |
| time.Sleep(time.Duration(attempt) * 400 * time.Millisecond) | |
| } | |
| errCh <- fmt.Errorf("part %d failed: %w", partID, lastErr) | |
| }(i+1, start, end) | |
| } | |
| wg.Wait() | |
| close(stopProgress) | |
| close(errCh) | |
| for err := range errCh { | |
| if err != nil { | |
| return err | |
| } | |
| } | |
| s.setProgress(task, 99) | |
| return nil | |
| } | |
| func downloadRange(url string, out *os.File, start, end int64, downloaded *atomic.Int64) error { | |
| req, err := http.NewRequest(http.MethodGet, url, nil) | |
| if err != nil { | |
| return err | |
| } | |
| req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) | |
| resp, err := http.DefaultClient.Do(req) // #nosec G107 user-provided URLs are expected in this app design | |
| if err != nil { | |
| return err | |
| } | |
| defer resp.Body.Close() | |
| if resp.StatusCode != http.StatusPartialContent { | |
| return fmt.Errorf("server does not honor range request: %s", resp.Status) | |
| } | |
| reader := bufio.NewReader(resp.Body) | |
| buf := make([]byte, 256*1024) | |
| offset := start | |
| expected := end - start + 1 | |
| var written int64 | |
| for { | |
| n, readErr := reader.Read(buf) | |
| if n > 0 { | |
| if _, err := out.WriteAt(buf[:n], offset); err != nil { | |
| return err | |
| } | |
| offset += int64(n) | |
| written += int64(n) | |
| downloaded.Add(int64(n)) | |
| } | |
| if readErr == io.EOF { | |
| break | |
| } | |
| if readErr != nil { | |
| return readErr | |
| } | |
| } | |
| if written != expected { | |
| return fmt.Errorf("range incomplete: expected %d got %d", expected, written) | |
| } | |
| return nil | |
| } | |
| func (s *TaskService) handleExtract(task *models.Task) error { | |
| src, err := s.fs.ResolvePath(task.Source) | |
| if err != nil { | |
| return err | |
| } | |
| target, err := s.fs.ResolvePath(task.TargetPath) | |
| if err != nil { | |
| return err | |
| } | |
| if err := os.MkdirAll(target, 0o755); err != nil { | |
| return err | |
| } | |
| lower := strings.ToLower(src) | |
| switch { | |
| case strings.HasSuffix(lower, ".zip"): | |
| return s.extractZip(task, src, target) | |
| case strings.HasSuffix(lower, ".tar.gz"), strings.HasSuffix(lower, ".tgz"): | |
| return s.extractTarGz(task, src, target) | |
| default: | |
| return errors.New("unsupported archive format") | |
| } | |
| } | |
| func (s *TaskService) extractZip(task *models.Task, src, target string) error { | |
| r, err := zip.OpenReader(src) | |
| if err != nil { | |
| return err | |
| } | |
| defer r.Close() | |
| total := len(r.File) | |
| if total == 0 { | |
| return nil | |
| } | |
| for i, f := range r.File { | |
| dest := filepath.Join(target, f.Name) | |
| if !isSubPath(target, dest) { | |
| return fmt.Errorf("zip slip detected: %s", f.Name) | |
| } | |
| if f.FileInfo().IsDir() { | |
| if err := os.MkdirAll(dest, 0o755); err != nil { | |
| return err | |
| } | |
| } else { | |
| if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil { | |
| return err | |
| } | |
| in, err := f.Open() | |
| if err != nil { | |
| return err | |
| } | |
| out, err := os.Create(dest) | |
| if err != nil { | |
| in.Close() | |
| return err | |
| } | |
| _, copyErr := io.Copy(out, in) | |
| in.Close() | |
| out.Close() | |
| if copyErr != nil { | |
| return copyErr | |
| } | |
| } | |
| progress := int((int64(i+1) * 100) / int64(total)) | |
| if progress > 99 { | |
| progress = 99 | |
| } | |
| s.setProgress(task, progress) | |
| } | |
| s.appendLog(task, "zip extract complete") | |
| return nil | |
| } | |
| func (s *TaskService) extractTarGz(task *models.Task, src, target string) error { | |
| f, err := os.Open(src) | |
| if err != nil { | |
| return err | |
| } | |
| defer f.Close() | |
| gzr, err := gzip.NewReader(f) | |
| if err != nil { | |
| return err | |
| } | |
| defer gzr.Close() | |
| tr := tar.NewReader(gzr) | |
| for { | |
| hdr, err := tr.Next() | |
| if err == io.EOF { | |
| break | |
| } | |
| if err != nil { | |
| return err | |
| } | |
| dest := filepath.Join(target, hdr.Name) | |
| if !isSubPath(target, dest) { | |
| return fmt.Errorf("tar path escape detected: %s", hdr.Name) | |
| } | |
| switch hdr.Typeflag { | |
| case tar.TypeDir: | |
| if err := os.MkdirAll(dest, 0o755); err != nil { | |
| return err | |
| } | |
| case tar.TypeReg: | |
| if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil { | |
| return err | |
| } | |
| out, err := os.Create(dest) | |
| if err != nil { | |
| return err | |
| } | |
| if _, err := io.Copy(out, tr); err != nil { | |
| out.Close() | |
| return err | |
| } | |
| out.Close() | |
| } | |
| } | |
| s.setProgress(task, 99) | |
| s.appendLog(task, "tar.gz extract complete") | |
| return nil | |
| } | |
// isSubPath reports whether maybeChild, after cleaning, equals root or lives
// underneath it. Used to reject archive entries that would escape the
// extraction directory.
func isSubPath(root, maybeChild string) bool {
	cleanRoot := filepath.Clean(root)
	cleanChild := filepath.Clean(maybeChild)
	if cleanChild == cleanRoot {
		return true
	}
	// Compare against root plus a trailing separator so a sibling like
	// "/data2" is not mistaken for a child of "/data".
	prefix := cleanRoot + string(os.PathSeparator)
	return strings.HasPrefix(cleanChild, prefix)
}
| func (s *TaskService) setProgress(task *models.Task, progress int) { | |
| task.Progress = progress | |
| s.db.Model(task).Updates(map[string]interface{}{ | |
| "progress": progress, | |
| }) | |
| s.publish(task.ID, taskEvent{ | |
| TaskID: task.ID, | |
| Status: task.Status, | |
| Progress: progress, | |
| }) | |
| } | |
| func (s *TaskService) appendLog(task *models.Task, line string) { | |
| task.Logs = strings.TrimSpace(task.Logs + "\n" + line) | |
| s.db.Model(task).Update("logs", task.Logs) | |
| s.publish(task.ID, taskEvent{ | |
| TaskID: task.ID, | |
| Status: task.Status, | |
| Progress: task.Progress, | |
| LogLine: line, | |
| }) | |
| } | |
| func (s *TaskService) setStatus(task *models.Task, status models.TaskStatus, progress int, logLine, errorMsg string) { | |
| task.Status = status | |
| task.Progress = progress | |
| task.Error = errorMsg | |
| updates := map[string]interface{}{ | |
| "status": status, | |
| "progress": progress, | |
| "error": errorMsg, | |
| } | |
| if logLine != "" { | |
| task.Logs = strings.TrimSpace(task.Logs + "\n" + logLine) | |
| updates["logs"] = task.Logs | |
| } | |
| s.db.Model(task).Updates(updates) | |
| s.publish(task.ID, taskEvent{ | |
| TaskID: task.ID, | |
| Status: status, | |
| Progress: progress, | |
| LogLine: logLine, | |
| Error: errorMsg, | |
| }) | |
| } | |
| func (s *TaskService) publish(taskID uint, event taskEvent) { | |
| s.mu.RLock() | |
| defer s.mu.RUnlock() | |
| for ch := range s.subscribers[taskID] { | |
| select { | |
| case ch <- event: | |
| default: | |
| } | |
| } | |
| } | |