// axonhub: internal/pkg/watcher/watcher_redis.go (uploaded by llzai, commit 9853396).

package watcher
import (
"context"
"encoding/json"
"errors"
"sync"
"github.com/redis/go-redis/v9"
"github.com/looplj/axonhub/internal/log"
"github.com/looplj/axonhub/internal/pkg/xredis"
)
// RedisWatcherOptions configures a watcher backed by Redis pub/sub.
type RedisWatcherOptions struct {
	// Channel is the Redis pub/sub channel that values are published to and
	// received from. NewRedisWatcher rejects an empty Channel.
	Channel string

	// Buffer is the capacity of each subscriber channel returned by Watch.
	// NewRedisWatcher replaces values <= 0 with 1.
	Buffer int
}
// redisWatcher publishes values as JSON on a Redis pub/sub channel and fans
// values received on that channel out to local subscribers. The Redis
// subscription is opened when the first subscriber registers and torn down
// when the last one leaves.
type redisWatcher[T any] struct {
	// client is used both to publish (Notify) and to subscribe (startLocked).
	client *redis.Client
	// channel is the Redis pub/sub channel name; it never changes after construction.
	channel string
	// buffer is the capacity of every subscriber channel created by Watch.
	buffer int

	// mu guards every field below it.
	mu sync.Mutex
	// nextID is the ID the next call to Watch will hand out.
	nextID uint64
	// subs holds the live subscriber channels, keyed by subscriber ID.
	subs map[uint64]chan T
	// active counts live subscribers; the Redis subscription exists while it is > 0.
	active int
	// pubsub is the current Redis subscription, or nil while nobody is watching.
	pubsub *redis.PubSub
	// cancel cancels the context used by the goroutine reading from pubsub.
	cancel context.CancelFunc
}
// NewRedisWatcher returns a Notifier that publishes values as JSON on
// opts.Channel through client and delivers values received on that channel to
// every subscriber registered with Watch.
//
// It returns an error when client is nil or opts.Channel is empty. A
// non-positive opts.Buffer is replaced by 1.
func NewRedisWatcher[T any](client *redis.Client, opts RedisWatcherOptions) (Notifier[T], error) {
	switch {
	case client == nil:
		return nil, errors.New("watcher.RedisWatcher: redis client is required")
	case opts.Channel == "":
		return nil, errors.New("watcher.RedisWatcher: channel is required")
	}

	capacity := 1
	if opts.Buffer > 0 {
		capacity = opts.Buffer
	}

	w := &redisWatcher[T]{
		client:  client,
		channel: opts.Channel,
		buffer:  capacity,
		subs:    make(map[uint64]chan T),
	}
	return w, nil
}
// NewRedisWatcherFromConfig creates a Redis client from cfg via
// xredis.NewClient and wraps it with NewRedisWatcher.
//
// Errors from either step are returned unchanged. If NewRedisWatcher rejects
// the options (for example an empty opts.Channel), the client created here is
// closed before returning, because no caller could otherwise reach it to
// release it. On success the client is owned by the returned watcher.
func NewRedisWatcherFromConfig[T any](cfg xredis.Config, opts RedisWatcherOptions) (Notifier[T], error) {
	client, err := xredis.NewClient(cfg)
	if err != nil {
		return nil, err
	}

	w, err := NewRedisWatcher[T](client, opts)
	if err != nil {
		if client != nil {
			// Best effort: the construction error is what the caller needs;
			// a failure to close the unused client adds nothing actionable.
			_ = client.Close()
		}
		return nil, err
	}
	return w, nil
}
// Watch registers a new subscriber. It returns the subscriber's receive
// channel and a function that unregisters it.
//
// The first subscriber causes startLocked to open the Redis subscription. The
// unregister function closes the subscriber's channel, and when the last
// subscriber leaves it calls stopLocked. Calling it more than once is a no-op.
// Values are offered to the channel with a non-blocking send, so a subscriber
// whose buffer is full misses values.
func (w *redisWatcher[T]) Watch() (<-chan T, func()) {
	w.mu.Lock()
	defer w.mu.Unlock()

	subID := w.nextID
	w.nextID++

	out := make(chan T, w.buffer)
	w.subs[subID] = out

	w.active++
	if w.active == 1 {
		// First subscriber: open the Redis subscription.
		w.startLocked()
	}

	unwatch := func() {
		w.mu.Lock()
		defer w.mu.Unlock()

		ch, registered := w.subs[subID]
		if !registered {
			return // already unregistered
		}
		delete(w.subs, subID)
		// Closing under w.mu is safe: the dispatch loop only sends while
		// holding w.mu, so it can never send on this channel after close.
		close(ch)

		w.active--
		if w.active == 0 {
			// Last subscriber gone: tear down the Redis subscription.
			w.stopLocked()
		}
	}
	return out, unwatch
}
// Notify JSON-encodes v and publishes it on the watcher's Redis channel.
// An encoding error is returned as-is and nothing is published. Otherwise the
// error reported by the Publish command (nil on success) is returned.
func (w *redisWatcher[T]) Notify(ctx context.Context, v T) error {
	data, err := json.Marshal(v)
	if err == nil {
		err = w.client.Publish(ctx, w.channel, data).Err()
	}
	return err
}
// startLocked opens the Redis subscription for w.channel and starts the
// goroutine that fans received values out to the current subscribers. It is a
// no-op when a subscription is already open. w.mu must be held by the caller.
//
// It waits, while holding w.mu, for the first reply on the new PubSub (the
// subscription confirmation) before returning. A failure there is logged; the
// receive loop keeps using the same PubSub afterwards.
//
// The goroutine exits when the context cancelled by stopLocked is done or the
// PubSub is closed. After a receive error it pauses before retrying. Payloads
// that fail to decode are logged and skipped. Sends to subscribers are
// non-blocking, so a subscriber with a full buffer misses the value.
func (w *redisWatcher[T]) startLocked() {
	if w.pubsub != nil {
		return
	}

	// receiveRetryDelay is the pause after a failed receive, so that a
	// persistent failure (e.g. Redis unreachable) does not become a busy loop
	// of warnings. It is an untyped nanosecond constant so it converts to
	// time.Duration for context.WithTimeout without importing package time.
	const receiveRetryDelay = 500 * 1000 * 1000 // 500ms

	ctx, cancel := context.WithCancel(context.Background())
	w.cancel = cancel
	w.pubsub = w.client.Subscribe(ctx, w.channel)

	// Wait for the subscription confirmation so the subscription is in place
	// before Watch returns. The error used to be discarded silently.
	if _, err := w.pubsub.Receive(ctx); err != nil {
		log.Warn(context.Background(), "watcher redis watcher subscribe failed",
			log.String("channel", w.channel),
			log.Cause(err))
	}

	go func(ps *redis.PubSub) {
		for {
			msg, err := ps.ReceiveMessage(ctx)
			if err != nil {
				if errors.Is(err, context.Canceled) || errors.Is(err, redis.ErrClosed) || ctx.Err() != nil {
					return
				}
				log.Warn(context.Background(), "watcher redis watcher receive failed",
					log.String("channel", w.channel),
					log.Cause(err))

				// Back off before retrying, waking early if the watcher stops.
				pause, endPause := context.WithTimeout(ctx, receiveRetryDelay)
				<-pause.Done()
				endPause()
				if ctx.Err() != nil {
					return
				}
				continue
			}

			var v T
			if err := json.Unmarshal([]byte(msg.Payload), &v); err != nil {
				log.Warn(context.Background(), "watcher redis watcher decode failed",
					log.String("channel", w.channel),
					log.String("payload", msg.Payload),
					log.Cause(err))
				continue
			}

			w.mu.Lock()
			// stopLocked cancels ctx while holding w.mu. Checking here
			// guarantees that a message read just before this subscription
			// was stopped is not delivered to subscribers registered later,
			// possibly under a newer subscription (which would duplicate it).
			if ctx.Err() != nil {
				w.mu.Unlock()
				return
			}
			for _, sub := range w.subs {
				select {
				case sub <- v:
				default:
					// Subscriber buffer full: drop instead of blocking the loop.
				}
			}
			w.mu.Unlock()
		}
	}(w.pubsub)
}
// stopLocked tears down the current Redis subscription, if any. It cancels the
// receive goroutine's context, closes the PubSub, and clears both fields so a
// later startLocked opens a fresh subscription. w.mu must be held by the caller.
func (w *redisWatcher[T]) stopLocked() {
	if cancelFn := w.cancel; cancelFn != nil {
		w.cancel = nil
		cancelFn()
	}
	if ps := w.pubsub; ps != nil {
		w.pubsub = nil
		// The subscription is being discarded, so a Close error is ignored.
		_ = ps.Close()
	}
}