| package httpclient
|
|
|
import (
	"context"
	"errors"
	"io"
	"mime"
	"sync"

	"github.com/tmaxmax/go-sse"

	"github.com/looplj/axonhub/internal/log"
)
|
|
|
|
|
| type decoderRegistry struct {
|
| mu sync.RWMutex
|
| decoders map[string]StreamDecoderFactory
|
| }
|
|
|
|
|
| var globalRegistry = &decoderRegistry{
|
| decoders: make(map[string]StreamDecoderFactory),
|
| }
|
|
|
|
|
| func RegisterDecoder(contentType string, factory StreamDecoderFactory) {
|
| globalRegistry.mu.Lock()
|
| defer globalRegistry.mu.Unlock()
|
|
|
| globalRegistry.decoders[contentType] = factory
|
| }
|
|
|
|
|
| func GetDecoder(contentType string) (StreamDecoderFactory, bool) {
|
| globalRegistry.mu.RLock()
|
| defer globalRegistry.mu.RUnlock()
|
|
|
| factory, exists := globalRegistry.decoders[contentType]
|
|
|
| return factory, exists
|
| }
|
|
|
|
|
| func NewDefaultSSEDecoder(ctx context.Context, rc io.ReadCloser) StreamDecoder {
|
| return &defaultSSEDecoder{
|
| ctx: ctx,
|
|
|
|
|
| sseStream: sse.NewStreamWithConfig(rc, &sse.StreamConfig{
|
| MaxEventSize: 32 * 1024 * 1024,
|
| }),
|
| }
|
| }
|
|
|
|
|
| var _ StreamDecoder = (*defaultSSEDecoder)(nil)
|
|
|
|
|
|
|
|
|
| type defaultSSEDecoder struct {
|
| ctx context.Context
|
| sseStream *sse.Stream
|
| current *StreamEvent
|
| err error
|
|
|
|
|
|
|
| closed bool
|
| closeErr error
|
| }
|
|
|
|
|
| func (s *defaultSSEDecoder) Next() bool {
|
| if s.err != nil {
|
| return false
|
| }
|
|
|
| if s.closed {
|
| return false
|
| }
|
|
|
|
|
| select {
|
| case <-s.ctx.Done():
|
| log.Debug(s.ctx, "SSE stream closed")
|
|
|
| s.err = s.ctx.Err()
|
| _ = s.Close()
|
|
|
| return false
|
| default:
|
| }
|
|
|
|
|
| event, err := s.sseStream.Recv()
|
| if err != nil {
|
| if errors.Is(err, io.EOF) {
|
| log.Debug(s.ctx, "SSE stream closed")
|
| _ = s.Close()
|
|
|
| return false
|
| }
|
|
|
| s.err = err
|
| _ = s.Close()
|
|
|
| return false
|
| }
|
|
|
| log.Debug(s.ctx, "SSE event received", log.Any("event", event))
|
|
|
|
|
| s.current = &StreamEvent{
|
| LastEventID: event.LastEventID,
|
| Type: event.Type,
|
| Data: []byte(event.Data),
|
| }
|
|
|
| return true
|
| }
|
|
|
|
|
| func (s *defaultSSEDecoder) Current() *StreamEvent {
|
| return s.current
|
| }
|
|
|
|
|
| func (s *defaultSSEDecoder) Err() error {
|
| return s.err
|
| }
|
|
|
|
|
| func (s *defaultSSEDecoder) Close() error {
|
|
|
| if s.closed {
|
| return s.closeErr
|
| }
|
|
|
| s.closed = true
|
| if s.sseStream != nil {
|
| s.closeErr = s.sseStream.Close()
|
| log.Debug(s.ctx, "SSE stream closed")
|
| }
|
|
|
| return s.closeErr
|
| }
|
|
|
|
|
| func init() {
|
| RegisterDecoder("text/event-stream", NewDefaultSSEDecoder)
|
| RegisterDecoder("text/event-stream; charset=utf-8", NewDefaultSSEDecoder)
|
| }
|
|
|