File size: 2,660 Bytes
6a7089a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package scheduler

import (
	"encoding/json"
	"errors"
	"log/slog"
	"net/http"

	"github.com/pinchtab/pinchtab/internal/web"
)

var (
	errMissingAgentID = errors.New("missing required field 'agentId'")
	errEmptyBatch     = errors.New("batch must contain at least one task")
)

// BatchRequest is the JSON body for POST /tasks/batch.
type BatchRequest struct {
	AgentID     string         `json:"agentId"`
	CallbackURL string         `json:"callbackUrl,omitempty"`
	Tasks       []BatchTaskDef `json:"tasks"`
}

// BatchTaskDef defines a single task inside a batch.
type BatchTaskDef struct {
	Action   string         `json:"action"`
	TabID    string         `json:"tabId,omitempty"`
	Ref      string         `json:"ref,omitempty"`
	Params   map[string]any `json:"params,omitempty"`
	Priority int            `json:"priority,omitempty"`
	Deadline string         `json:"deadline,omitempty"`
}

// BatchResponseItem is the result for each submitted task in the batch.
type BatchResponseItem struct {
	TaskID   string    `json:"taskId"`
	State    TaskState `json:"state"`
	Position int       `json:"position,omitempty"`
	Error    string    `json:"error,omitempty"`
}

func (s *Scheduler) handleBatch(w http.ResponseWriter, r *http.Request) {
	var req BatchRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		web.Error(w, 400, err)
		return
	}

	if req.AgentID == "" {
		web.Error(w, 400, errMissingAgentID)
		return
	}
	if len(req.Tasks) == 0 {
		web.Error(w, 400, errEmptyBatch)
		return
	}
	s.cfgMu.RLock()
	maxBatch := s.cfg.MaxBatchSize
	s.cfgMu.RUnlock()
	if len(req.Tasks) > maxBatch {
		web.ErrorCode(w, 400, "batch_too_large", "batch exceeds maximum size", false, map[string]any{
			"submitted": len(req.Tasks),
			"max":       maxBatch,
		})
		return
	}

	results := make([]BatchResponseItem, 0, len(req.Tasks))
	for _, td := range req.Tasks {
		sr := SubmitRequest{
			AgentID:     req.AgentID,
			Action:      td.Action,
			TabID:       td.TabID,
			Ref:         td.Ref,
			Params:      td.Params,
			Priority:    td.Priority,
			Deadline:    td.Deadline,
			CallbackURL: req.CallbackURL,
		}

		task, err := s.Submit(sr)
		if err != nil {
			item := BatchResponseItem{State: StateRejected, Error: err.Error()}
			if task != nil {
				item.TaskID = task.ID
			}
			results = append(results, item)
			slog.Warn("batch: task rejected", "agent", req.AgentID, "action", td.Action, "err", err)
			continue
		}

		results = append(results, BatchResponseItem{
			TaskID:   task.ID,
			State:    task.GetState(),
			Position: task.Position,
		})
	}

	web.JSON(w, 202, map[string]any{
		"tasks":     results,
		"submitted": len(results),
	})
}