// File size: 2,844 Bytes
// 619f93d

package errgroup
import (
"context"
"fmt"
"sync"
"sync/atomic"
"github.com/avast/retry-go"
)
// token is the zero-size value passed through the Group's signalling
// channels; only its presence in a channel matters, never its content.
type token struct{}
// Group runs tasks in goroutines that share one cancellable context, wrapping
// each task in retry.Do. Create it with NewGroupWithContext or
// NewOrderedGroupWithContext; the zero value has no context and is not usable.
type Group struct {
	// success counts goroutines that have finished (see done). It is accessed
	// with the 64-bit sync/atomic functions, so it is kept as the first field:
	// that is the position sync/atomic documents as 64-bit aligned on 32-bit
	// platforms, independent of the sizes of the other fields.
	success uint64

	// cancel cancels ctx, recording the given error as the cancellation cause.
	cancel func(error)
	// ctx is the context handed to every task and lifecycle hook.
	ctx context.Context
	// opts are the retry options applied to every task.
	opts []retry.Option

	// wg tracks the goroutines started by the group.
	wg sync.WaitGroup
	// sem bounds the number of concurrently running goroutines; nil means
	// there is no bound.
	sem chan token
	// startChan is non-nil only for ordered groups, where it serialises the
	// start of tasks so that Lifecycle.Before hooks run one at a time.
	startChan chan token
}
// NewGroupWithContext creates a Group and returns it together with the
// context its tasks run under.
//
// The returned context is derived from ctx and is cancelled, with the task's
// error as the cause, when a task started through the group fails. limit caps
// the number of concurrently running tasks; a value <= 0 means no cap.
// retryOpts are applied to every task via retry.Do, with retry.Context for the
// derived context appended after them.
func NewGroupWithContext(ctx context.Context, limit int, retryOpts ...retry.Option) (*Group, context.Context) {
	ctx, cancel := context.WithCancelCause(ctx)

	// Build a fresh slice instead of appending to retryOpts directly. When a
	// caller spreads a slice that has spare capacity (opts...), append would
	// write into the caller's backing array, and several groups built from
	// the same slice would overwrite each other's context option.
	opts := make([]retry.Option, 0, len(retryOpts)+1)
	opts = append(opts, retryOpts...)
	opts = append(opts, retry.Context(ctx))

	g := &Group{cancel: cancel, ctx: ctx, opts: opts}
	return g.SetLimit(limit), ctx
}
// NewOrderedGroupWithContext builds a Group the same way as
// NewGroupWithContext and additionally equips it with a one-slot start
// channel. That channel makes Lifecycle.Before ordered and thread-safe.
func NewOrderedGroupWithContext(ctx context.Context, limit int, retryOpts ...retry.Option) (*Group, context.Context) {
	ordered, groupCtx := NewGroupWithContext(ctx, limit, retryOpts...)
	ordered.startChan = make(chan token, 1)
	return ordered, groupCtx
}
// done performs the bookkeeping for a goroutine that is finishing: it frees
// the goroutine's concurrency slot (when a limit is set), bumps the finished
// counter and finally signals the WaitGroup.
//
// The counter is incremented for every finished goroutine, whether or not its
// task returned an error.
func (g *Group) done() {
	if g.sem != nil {
		<-g.sem
	}
	// Count before signalling the WaitGroup: once wg.Done lets Wait return,
	// a caller reading Success must already see this goroutine included.
	atomic.AddUint64(&g.success, 1)
	g.wg.Done()
}
// Wait blocks until every goroutine started through the group has finished,
// then reports the cancellation cause of the group's context. It does not
// cancel the context itself.
func (g *Group) Wait() error {
	g.wg.Wait()
	cause := context.Cause(g.ctx)
	return cause
}
// Go starts do in the group as a task without Before or After hooks; it is
// shorthand for GoWithLifecycle with only the Do field set.
func (g *Group) Go(do func(ctx context.Context) error) {
	task := Lifecycle{Do: do}
	g.GoWithLifecycle(task)
}
// Lifecycle describes a task together with optional hooks that run around it.
type Lifecycle struct {
	// Before is ordered and thread-safe in an OrderedGroup.
	// It is called only once.
	Before func(ctx context.Context) (err error)
	// Do is not called if Before returns an error.
	Do func(ctx context.Context) (err error)
	// After is called once at the very end.
	After func(err error)
}
// GoWithLifecycle starts lifecycle in a new goroutine.
//
// Before the goroutine is launched the call reserves, in this order, the
// ordering slot (ordered groups only) and a concurrency slot (limited groups
// only). It blocks while a slot is unavailable and returns without starting
// anything if the group's context is done first.
//
// Inside the goroutine Before runs once; if it succeeds the ordering slot is
// handed back and Do is executed through retry.Do with the group's options.
// After, when set, receives the resulting error. A non-nil error cancels the
// group's context with that error as the cause, unless the context is already
// done.
func (g *Group) GoWithLifecycle(lifecycle Lifecycle) {
	// acquire puts a token into slot, giving up when the context is done.
	// A nil slot means the corresponding feature is disabled.
	acquire := func(slot chan token) bool {
		if slot == nil {
			return true
		}
		select {
		case <-g.ctx.Done():
			return false
		case slot <- token{}:
			return true
		}
	}
	if !acquire(g.startChan) || !acquire(g.sem) {
		return
	}

	g.wg.Add(1)
	go func() {
		defer g.done()

		taskErr := func() error {
			if before := lifecycle.Before; before != nil {
				if beforeErr := before(g.ctx); beforeErr != nil {
					// The ordering slot is intentionally left occupied here,
					// exactly as when Before fails and Do is skipped.
					return beforeErr
				}
			}
			// Before is finished: let the next ordered task start.
			if g.startChan != nil {
				<-g.startChan
			}
			attempt := func() error { return lifecycle.Do(g.ctx) }
			return retry.Do(attempt, g.opts...)
		}()

		if after := lifecycle.After; after != nil {
			after(taskErr)
		}
		if taskErr == nil {
			return
		}

		select {
		case <-g.ctx.Done():
			// Already cancelled; nothing left to record.
		default:
			g.cancel(taskErr)
		}
	}()
}
// TryGo starts f in a new goroutine only if that does not require waiting for
// a concurrency slot. It reports whether the goroutine was started.
//
// f is executed through retry.Do with the group's options; a resulting error
// cancels the group's context with that error as the cause.
func (g *Group) TryGo(f func(ctx context.Context) error) bool {
	if slots := g.sem; slots != nil {
		select {
		case slots <- token{}:
			// Slot reserved; fall through and launch.
		default:
			return false
		}
	}

	g.wg.Add(1)
	go func() {
		defer g.done()

		attempt := func() error { return f(g.ctx) }
		err := retry.Do(attempt, g.opts...)
		if err == nil {
			return
		}
		g.cancel(err)
	}()
	return true
}
// SetLimit sets the maximum number of goroutines the group runs at once and
// returns g so the call can be chained. A value of n <= 0 removes the limit.
//
// It panics if any concurrency slot is currently held.
func (g *Group) SetLimit(n int) *Group {
	if active := len(g.sem); active != 0 {
		panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", active))
	}

	var slots chan token // stays nil when no limit is requested
	if n > 0 {
		slots = make(chan token, n)
	}
	g.sem = slots
	return g
}
// Success returns the current value of the group's finished-goroutine
// counter, read atomically.
func (g *Group) Success() uint64 {
	finished := atomic.LoadUint64(&g.success)
	return finished
}
// Err reports the cancellation cause of the group's context without waiting
// for any goroutine to finish.
func (g *Group) Err() error {
	cause := context.Cause(g.ctx)
	return cause
}
|