package services

import (
	"archive/tar"
	"archive/zip"
	"compress/gzip"
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"fastfileviewer/backend/internal/models"

	"gorm.io/gorm"
)

// taskEvent is one update pushed to live subscribers of a task. LogLine and
// Error are empty when the event carries no new log line or failure message.
type taskEvent struct {
	TaskID   uint
	Status   models.TaskStatus
	Progress int
	LogLine  string
	Error    string
}

// TaskService persists tasks through GORM, executes them on a fixed pool of
// worker goroutines fed by an in-memory queue of task IDs, and broadcasts
// status/progress/log updates to channels registered with Subscribe.
type TaskService struct {
	db      *gorm.DB
	fs      *FSService
	queue   chan uint
	workers int

	mu          sync.RWMutex // guards subscribers
	subscribers map[uint]map[chan taskEvent]struct{}
}

const (
	// minMultiPartSize is the smallest probed size for which a parallel
	// ranged download is attempted.
	minMultiPartSize = 16 * 1024 * 1024
	// maxDownloadParts caps the number of concurrent range requests.
	maxDownloadParts = 8
	// partRetryLimit is the number of attempts made for each range part.
	partRetryLimit = 3
)

// NewTaskService builds a TaskService with a 256-slot queue. workers values
// below 1 are raised to 1. Call Start to launch the workers.
func NewTaskService(database *gorm.DB, fs *FSService, workers int) *TaskService {
	if workers < 1 {
		workers = 1
	}
	return &TaskService{
		db:          database,
		fs:          fs,
		queue:       make(chan uint, 256),
		workers:     workers,
		subscribers: map[uint]map[chan taskEvent]struct{}{},
	}
}

// Start launches the worker goroutines; they exit when ctx is cancelled
// (a task already running is not interrupted).
func (s *TaskService) Start(ctx context.Context) {
	for i := 0; i < s.workers; i++ {
		go s.worker(ctx)
	}
}

// CreateDownloadTask records a queued download of url into targetPath and
// hands it to the worker pool. It blocks while the queue is full.
func (s *TaskService) CreateDownloadTask(url, targetPath string) (*models.Task, error) {
	return s.enqueueNewTask(&models.Task{
		Type:       models.TaskTypeDownload,
		Status:     models.TaskStatusQueued,
		Source:     url,
		TargetPath: targetPath,
		Progress:   0,
	})
}

// CreateExtractTask records a queued extraction of sourceArchive into the
// directory targetPath and hands it to the worker pool. It blocks while the
// queue is full.
func (s *TaskService) CreateExtractTask(sourceArchive, targetPath string) (*models.Task, error) {
	return s.enqueueNewTask(&models.Task{
		Type:       models.TaskTypeExtract,
		Status:     models.TaskStatusQueued,
		Source:     sourceArchive,
		TargetPath: targetPath,
		Progress:   0,
	})
}

// enqueueNewTask inserts task and, on success, pushes its ID to the workers.
func (s *TaskService) enqueueNewTask(task *models.Task) (*models.Task, error) {
	if err := s.db.Create(task).Error; err != nil {
		return nil, err
	}
	// Blocks the caller when the queue buffer is full.
	s.queue <- task.ID
	return task, nil
}

// ListTasks returns up to limit tasks, newest (highest ID) first. A limit
// outside (0, 100] is replaced by 50.
func (s *TaskService) ListTasks(limit int) ([]models.Task, error) {
	if limit <= 0 || limit > 100 {
		limit = 50
	}
	var tasks []models.Task
	err := s.db.Order("id desc").Limit(limit).Find(&tasks).Error
	return tasks, err
}

// GetTask loads a single task by primary key, returning GORM's error (such as
// gorm.ErrRecordNotFound) when it cannot be read.
func (s *TaskService) GetTask(id uint) (*models.Task, error) {
	var task models.Task
	if err := s.db.First(&task, id).Error; err != nil {
		return nil, err
	}
	return &task, nil
}

// Subscribe registers a channel (buffer 16) that receives events for taskID.
// Events are dropped, not queued, when the buffer is full. The returned
// cancel func unregisters and closes the channel; calling it more than once
// is safe.
func (s *TaskService) Subscribe(taskID uint) (chan taskEvent, func()) {
	ch := make(chan taskEvent, 16)

	s.mu.Lock()
	if s.subscribers[taskID] == nil {
		s.subscribers[taskID] = map[chan taskEvent]struct{}{}
	}
	s.subscribers[taskID][ch] = struct{}{}
	s.mu.Unlock()

	var once sync.Once
	cancel := func() {
		// Without once, a second call would close ch again and panic.
		once.Do(func() {
			s.mu.Lock()
			defer s.mu.Unlock()
			// Closing under the write lock means publish, which sends while
			// holding the read lock, can never send on a closed channel.
			delete(s.subscribers[taskID], ch)
			close(ch)
			if len(s.subscribers[taskID]) == 0 {
				delete(s.subscribers, taskID)
			}
		})
	}
	return ch, cancel
}

// worker runs queued tasks one at a time until ctx is cancelled.
func (s *TaskService) worker(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case id := <-s.queue:
			// Handler failures are persisted on the task by runTask; a task
			// that cannot be loaded is simply skipped.
			_ = s.runTask(id)
		}
	}
}

// runTask loads task id, marks it running, dispatches on its type and
// records the final success or failure status.
func (s *TaskService) runTask(id uint) error {
	task, err := s.GetTask(id)
	if err != nil {
		return err
	}
	s.setStatus(task, models.TaskStatusRunning, 1, "task started", "")

	switch task.Type {
	case models.TaskTypeDownload:
		err = s.handleDownload(task)
	case models.TaskTypeExtract:
		err = s.handleExtract(task)
	default:
		err = errors.New("unsupported task type")
	}

	if err != nil {
		s.setStatus(task, models.TaskStatusFailed, task.Progress, "", err.Error())
		return err
	}
	s.setStatus(task, models.TaskStatusSuccess, 100, "task completed", "")
	return nil
}

// handleDownload fetches task.Source into the resolved task.TargetPath. When a
// HEAD probe reports byte-range support and a size of at least
// minMultiPartSize, the file is fetched in parallel parts; otherwise, or if
// that fails, a single streamed GET is used.
func (s *TaskService) handleDownload(task *models.Task) error {
	target, err := s.fs.ResolvePath(task.TargetPath)
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
		return err
	}
	s.appendLog(task, "starting download: "+task.Source)

	size, supportsRange, err := probeDownload(task.Source)
	if err == nil && supportsRange && size >= minMultiPartSize {
		// Aim for roughly 8 MiB per part, clamped to [2, maxDownloadParts].
		parts := int(size / (8 * 1024 * 1024))
		if parts < 2 {
			parts = 2
		}
		if parts > maxDownloadParts {
			parts = maxDownloadParts
		}
		s.appendLog(task, fmt.Sprintf("range download enabled: parts=%d, size=%d bytes", parts, size))

		mpErr := s.multiPartDownload(task, task.Source, target, size, parts)
		if mpErr == nil {
			s.appendLog(task, "download complete")
			return nil
		}
		s.appendLog(task, "multi-part failed, fallback to single stream: "+mpErr.Error())
		_ = os.Remove(target) // best effort; the single-stream path truncates it anyway
	}

	if err := s.singlePartDownload(task, task.Source, target); err != nil {
		return err
	}
	s.appendLog(task, "download complete")
	return nil
}

// probeDownload issues a HEAD request and reports the Content-Length (-1 when
// unknown) and whether Accept-Ranges advertises byte ranges.
func probeDownload(url string) (size int64, supportsRange bool, err error) {
	req, err := http.NewRequest(http.MethodHead, url, nil)
	if err != nil {
		return 0, false, err
	}
	resp, err := http.DefaultClient.Do(req) // #nosec G107 user-provided URLs are expected in this app design
	if err != nil {
		return 0, false, err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 400 {
		return 0, false, fmt.Errorf("probe failed: %s", resp.Status)
	}
	size = resp.ContentLength
	supportsRange = strings.Contains(strings.ToLower(resp.Header.Get("Accept-Ranges")), "bytes")
	return size, supportsRange, nil
}

// singlePartDownload streams url into target with one GET, reporting progress
// at most every 500ms when the response length is known. If the transfer
// fails after target was created, the partial file is removed.
func (s *TaskService) singlePartDownload(task *models.Task, url, target string) (err error) {
	resp, err := http.Get(url) // #nosec G107 user-provided URLs are expected in this app design
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return fmt.Errorf("download failed: %s", resp.Status)
	}

	out, err := os.Create(target)
	if err != nil {
		return err
	}
	defer func() {
		// A failed Close can mean written data never reached disk.
		if cerr := out.Close(); err == nil {
			err = cerr
		}
		if err != nil {
			// Don't leave a truncated file that looks like a finished download.
			_ = os.Remove(target)
		}
	}()

	total := resp.ContentLength
	var copied int64
	buf := make([]byte, 256*1024)
	lastTick := time.Now()
	for {
		n, readErr := resp.Body.Read(buf)
		if n > 0 {
			if _, werr := out.Write(buf[:n]); werr != nil {
				return werr
			}
			copied += int64(n)
			if total > 0 && time.Since(lastTick) > 500*time.Millisecond {
				s.setProgress(task, cappedPercent(copied, total))
				lastTick = time.Now()
			}
		}
		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return readErr
		}
	}
}

// multiPartDownload pre-sizes target to total bytes and fetches it as `parts`
// concurrent byte ranges, retrying each range up to partRetryLimit times.
// Progress is reported every 500ms; the first part failure is returned.
func (s *TaskService) multiPartDownload(task *models.Task, url, target string, total int64, parts int) (err error) {
	if total <= 0 {
		return fmt.Errorf("multi-part download needs a positive size, got %d", total)
	}
	if parts < 1 {
		parts = 1
	}

	file, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := file.Close(); err == nil {
			err = cerr
		}
	}()
	if err := file.Truncate(total); err != nil {
		return err
	}

	var downloaded atomic.Int64

	// The reporter mutates task and writes through s.db, so it must have
	// fully exited before we return and the caller touches task again.
	stopProgress := make(chan struct{})
	progressDone := make(chan struct{})
	go func() {
		defer close(progressDone)
		ticker := time.NewTicker(500 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-stopProgress:
				return
			case <-ticker.C:
				s.setProgress(task, cappedPercent(downloaded.Load(), total))
			}
		}
	}()

	var wg sync.WaitGroup
	errCh := make(chan error, parts)
	chunkSize := total / int64(parts)
	for i := 0; i < parts; i++ {
		start := int64(i) * chunkSize
		end := start + chunkSize - 1
		if i == parts-1 {
			end = total - 1 // last part absorbs the remainder
		}
		wg.Add(1)
		go func(partID int, rangeStart, rangeEnd int64) {
			defer wg.Done()
			var lastErr error
			for attempt := 1; attempt <= partRetryLimit; attempt++ {
				if lastErr = downloadRange(url, file, rangeStart, rangeEnd, &downloaded); lastErr == nil {
					return
				}
				if attempt < partRetryLimit {
					time.Sleep(time.Duration(attempt) * 400 * time.Millisecond)
				}
			}
			errCh <- fmt.Errorf("part %d failed: %w", partID, lastErr)
		}(i+1, start, end)
	}
	wg.Wait()
	close(stopProgress)
	<-progressDone
	close(errCh)

	if partErr, ok := <-errCh; ok {
		return partErr
	}
	s.setProgress(task, 99)
	return nil
}

// downloadRange fetches bytes [start, end] of url and writes them at the same
// offsets in out, adding each written chunk to downloaded. It fails unless the
// server answers 206 with exactly end-start+1 bytes. On failure the bytes this
// attempt added are subtracted again, because the caller retries the whole
// range.
func downloadRange(url string, out *os.File, start, end int64, downloaded *atomic.Int64) (err error) {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return err
	}
	req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
	resp, err := http.DefaultClient.Do(req) // #nosec G107 user-provided URLs are expected in this app design
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusPartialContent {
		return fmt.Errorf("server does not honor range request: %s", resp.Status)
	}

	expected := end - start + 1
	var written int64
	defer func() {
		if err != nil {
			downloaded.Add(-written)
		}
	}()

	buf := make([]byte, 256*1024)
	for {
		n, readErr := resp.Body.Read(buf)
		if n > 0 {
			// Never write outside [start, end]: that region belongs to
			// another part (or lies past the end of the file).
			if written+int64(n) > expected {
				return fmt.Errorf("range overflow: expected %d bytes, server sent more", expected)
			}
			if _, werr := out.WriteAt(buf[:n], start+written); werr != nil {
				return werr
			}
			written += int64(n)
			downloaded.Add(int64(n))
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return readErr
		}
	}
	if written != expected {
		return fmt.Errorf("range incomplete: expected %d got %d", expected, written)
	}
	return nil
}

// handleExtract resolves the archive and destination paths, creates the
// destination directory, and dispatches on the archive's file extension
// (.zip, .tar.gz, .tgz; case-insensitive).
func (s *TaskService) handleExtract(task *models.Task) error {
	src, err := s.fs.ResolvePath(task.Source)
	if err != nil {
		return err
	}
	target, err := s.fs.ResolvePath(task.TargetPath)
	if err != nil {
		return err
	}
	if err := os.MkdirAll(target, 0o755); err != nil {
		return err
	}

	lower := strings.ToLower(src)
	switch {
	case strings.HasSuffix(lower, ".zip"):
		return s.extractZip(task, src, target)
	case strings.HasSuffix(lower, ".tar.gz"), strings.HasSuffix(lower, ".tgz"):
		return s.extractTarGz(task, src, target)
	default:
		return errors.New("unsupported archive format")
	}
}

// extractZip extracts every entry of src under target, rejecting entries whose
// cleaned path would leave target. Progress is the fraction of entries done,
// persisted only when the percentage increases.
func (s *TaskService) extractZip(task *models.Task, src, target string) error {
	r, err := zip.OpenReader(src)
	if err != nil {
		return err
	}
	defer r.Close()

	total := len(r.File)
	if total == 0 {
		return nil
	}
	lastProgress := task.Progress
	for i, f := range r.File {
		dest := filepath.Join(target, f.Name)
		if !isSubPath(target, dest) {
			return fmt.Errorf("zip slip detected: %s", f.Name)
		}
		if f.FileInfo().IsDir() {
			if err := os.MkdirAll(dest, 0o755); err != nil {
				return err
			}
		} else if err := extractZipFile(f, dest); err != nil {
			return err
		}
		// One DB write per entry is wasteful for archives with many small
		// files, so only publish when the percentage actually moves.
		if p := cappedPercent(int64(i+1), int64(total)); p > lastProgress {
			s.setProgress(task, p)
			lastProgress = p
		}
	}
	s.appendLog(task, "zip extract complete")
	return nil
}

// extractZipFile writes one regular zip entry to dest, creating parent
// directories as needed.
func extractZipFile(f *zip.File, dest string) error {
	if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
		return err
	}
	in, err := f.Open()
	if err != nil {
		return err
	}
	defer in.Close()
	return writeExtractedFile(dest, in)
}

// writeExtractedFile creates (or truncates) dest and copies r into it. The
// Close error is returned because it can indicate data lost before disk.
func writeExtractedFile(dest string, r io.Reader) (err error) {
	out, err := os.Create(dest)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := out.Close(); err == nil {
			err = cerr
		}
	}()
	_, err = io.Copy(out, r)
	return err
}

// byteCountingReader counts the bytes read through it (single goroutine use).
type byteCountingReader struct {
	r io.Reader
	n int64
}

func (c *byteCountingReader) Read(p []byte) (int, error) {
	n, err := c.r.Read(p)
	c.n += int64(n)
	return n, err
}

// extractTarGz extracts directories and regular files from the gzip-compressed
// tar src under target, rejecting entries whose cleaned path would leave
// target. Other entry types (symlinks, hard links, devices, ...) are ignored,
// so no links are ever created. Progress is estimated from how much of the
// compressed file has been consumed.
func (s *TaskService) extractTarGz(task *models.Task, src, target string) error {
	f, err := os.Open(src)
	if err != nil {
		return err
	}
	defer f.Close()

	// A tar stream has no up-front entry count, so progress is measured
	// against the compressed size; 0 (unknown) simply yields no updates.
	var archiveSize int64
	if info, statErr := f.Stat(); statErr == nil {
		archiveSize = info.Size()
	}
	counter := &byteCountingReader{r: f}

	gzr, err := gzip.NewReader(counter)
	if err != nil {
		return err
	}
	defer gzr.Close()

	tr := tar.NewReader(gzr)
	lastProgress := task.Progress
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
		dest := filepath.Join(target, hdr.Name)
		if !isSubPath(target, dest) {
			return fmt.Errorf("tar path escape detected: %s", hdr.Name)
		}
		switch hdr.Typeflag {
		case tar.TypeDir:
			if err := os.MkdirAll(dest, 0o755); err != nil {
				return err
			}
		case tar.TypeReg:
			if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
				return err
			}
			if err := writeExtractedFile(dest, tr); err != nil {
				return err
			}
		}
		if p := cappedPercent(counter.n, archiveSize); p > lastProgress {
			s.setProgress(task, p)
			lastProgress = p
		}
	}
	s.setProgress(task, 99)
	s.appendLog(task, "tar.gz extract complete")
	return nil
}

// isSubPath reports whether maybeChild is root itself or lies beneath it after
// both are cleaned. It uses filepath.Rel so that a root of "/" (or a volume
// root) works; a naive root+separator prefix check fails there.
func isSubPath(root, maybeChild string) bool {
	rel, err := filepath.Rel(filepath.Clean(root), filepath.Clean(maybeChild))
	if err != nil {
		return false
	}
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}

// cappedPercent returns done as a percentage of total, clamped to [0, 99]
// (100 is written only by runTask on success). It returns 0 when total <= 0.
func cappedPercent(done, total int64) int {
	if total <= 0 || done <= 0 {
		return 0
	}
	p := int((done * 100) / total)
	if p > 99 {
		p = 99
	}
	return p
}

// setProgress stores progress on task, persists it (best effort: the DB error
// is ignored) and publishes it to subscribers.
func (s *TaskService) setProgress(task *models.Task, progress int) {
	task.Progress = progress
	s.db.Model(task).Updates(map[string]interface{}{
		"progress": progress,
	})
	s.publish(task.ID, taskEvent{
		TaskID:   task.ID,
		Status:   task.Status,
		Progress: progress,
	})
}

// appendLog appends line to task.Logs, rewrites the whole logs column (best
// effort: the DB error is ignored) and publishes the new line.
func (s *TaskService) appendLog(task *models.Task, line string) {
	task.Logs = strings.TrimSpace(task.Logs + "\n" + line)
	s.db.Model(task).Update("logs", task.Logs)
	s.publish(task.ID, taskEvent{
		TaskID:   task.ID,
		Status:   task.Status,
		Progress: task.Progress,
		LogLine:  line,
	})
}

// setStatus updates status, progress and error (plus logs when logLine is
// non-empty) on task, persists them in one best-effort update, and publishes
// the change.
func (s *TaskService) setStatus(task *models.Task, status models.TaskStatus, progress int, logLine, errorMsg string) {
	task.Status = status
	task.Progress = progress
	task.Error = errorMsg
	updates := map[string]interface{}{
		"status":   status,
		"progress": progress,
		"error":    errorMsg,
	}
	if logLine != "" {
		task.Logs = strings.TrimSpace(task.Logs + "\n" + logLine)
		updates["logs"] = task.Logs
	}
	s.db.Model(task).Updates(updates)
	s.publish(task.ID, taskEvent{
		TaskID:   task.ID,
		Status:   status,
		Progress: progress,
		LogLine:  logLine,
		Error:    errorMsg,
	})
}

// publish delivers event to every subscriber of taskID without blocking; a
// subscriber whose buffer is full misses the event.
func (s *TaskService) publish(taskID uint, event taskEvent) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	for ch := range s.subscribers[taskID] {
		select {
		case ch <- event:
		default:
		}
	}
}