|
|
package wsrelay |
|
|
|
|
|
import ( |
|
|
"bytes" |
|
|
"context" |
|
|
"errors" |
|
|
"fmt" |
|
|
"net/http" |
|
|
"time" |
|
|
|
|
|
"github.com/google/uuid" |
|
|
) |
|
|
|
|
|
|
|
|
// HTTPRequest is an outbound HTTP request that NonStream and Stream forward
// to a provider through the relay.
type HTTPRequest struct {
	// Method is the HTTP method, e.g. "GET" or "POST".
	Method string
	// URL is the target URL; encodeRequest forwards it verbatim.
	URL string
	// Headers holds the request headers; encodeRequest copies each value slice.
	Headers http.Header
	// Body is the raw request body; encodeRequest sends it as a string.
	Body []byte
}
|
|
|
|
|
|
|
|
// HTTPResponse is a complete HTTP response received back through the relay.
type HTTPResponse struct {
	// Status is the HTTP status code reported by the provider.
	Status int
	// Headers holds the response headers.
	Headers http.Header
	// Body is the full response body.
	Body []byte
}
|
|
|
|
|
|
|
|
// StreamEvent is one item delivered on the channel returned by Stream.
//
// Type names the relay message that produced the event. Events reporting a
// relay-level failure (closed connection, cancelled context) leave Type
// empty and set only Err.
type StreamEvent struct {
	// Type is the relay message type behind this event, or "" for relay-level errors.
	Type string
	// Payload carries chunk bytes for stream chunks, or the full body for HTTP responses.
	Payload []byte
	// Status is the upstream HTTP status for stream-start and HTTP-response events.
	Status int
	// Headers are the upstream HTTP headers for stream-start and HTTP-response events.
	Headers http.Header
	// Err is set on error events.
	Err error
}
|
|
|
|
|
|
|
|
func (m *Manager) NonStream(ctx context.Context, provider string, req *HTTPRequest) (*HTTPResponse, error) { |
|
|
if req == nil { |
|
|
return nil, fmt.Errorf("wsrelay: request is nil") |
|
|
} |
|
|
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)} |
|
|
respCh, err := m.Send(ctx, provider, msg) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
var ( |
|
|
streamMode bool |
|
|
streamResp *HTTPResponse |
|
|
streamBody bytes.Buffer |
|
|
) |
|
|
for { |
|
|
select { |
|
|
case <-ctx.Done(): |
|
|
return nil, ctx.Err() |
|
|
case msg, ok := <-respCh: |
|
|
if !ok { |
|
|
if streamMode { |
|
|
if streamResp == nil { |
|
|
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)} |
|
|
} else if streamResp.Headers == nil { |
|
|
streamResp.Headers = make(http.Header) |
|
|
} |
|
|
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...) |
|
|
return streamResp, nil |
|
|
} |
|
|
return nil, errors.New("wsrelay: connection closed during response") |
|
|
} |
|
|
switch msg.Type { |
|
|
case MessageTypeHTTPResp: |
|
|
resp := decodeResponse(msg.Payload) |
|
|
if streamMode && streamBody.Len() > 0 && len(resp.Body) == 0 { |
|
|
resp.Body = append(resp.Body[:0], streamBody.Bytes()...) |
|
|
} |
|
|
return resp, nil |
|
|
case MessageTypeError: |
|
|
return nil, decodeError(msg.Payload) |
|
|
case MessageTypeStreamStart, MessageTypeStreamChunk: |
|
|
if msg.Type == MessageTypeStreamStart { |
|
|
streamMode = true |
|
|
streamResp = decodeResponse(msg.Payload) |
|
|
if streamResp.Headers == nil { |
|
|
streamResp.Headers = make(http.Header) |
|
|
} |
|
|
streamBody.Reset() |
|
|
continue |
|
|
} |
|
|
if !streamMode { |
|
|
streamMode = true |
|
|
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)} |
|
|
} |
|
|
chunk := decodeChunk(msg.Payload) |
|
|
if len(chunk) > 0 { |
|
|
streamBody.Write(chunk) |
|
|
} |
|
|
case MessageTypeStreamEnd: |
|
|
if !streamMode { |
|
|
return &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}, nil |
|
|
} |
|
|
if streamResp == nil { |
|
|
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)} |
|
|
} else if streamResp.Headers == nil { |
|
|
streamResp.Headers = make(http.Header) |
|
|
} |
|
|
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...) |
|
|
return streamResp, nil |
|
|
default: |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func (m *Manager) Stream(ctx context.Context, provider string, req *HTTPRequest) (<-chan StreamEvent, error) { |
|
|
if req == nil { |
|
|
return nil, fmt.Errorf("wsrelay: request is nil") |
|
|
} |
|
|
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)} |
|
|
respCh, err := m.Send(ctx, provider, msg) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
out := make(chan StreamEvent) |
|
|
go func() { |
|
|
defer close(out) |
|
|
for { |
|
|
select { |
|
|
case <-ctx.Done(): |
|
|
out <- StreamEvent{Err: ctx.Err()} |
|
|
return |
|
|
case msg, ok := <-respCh: |
|
|
if !ok { |
|
|
out <- StreamEvent{Err: errors.New("wsrelay: stream closed")} |
|
|
return |
|
|
} |
|
|
switch msg.Type { |
|
|
case MessageTypeStreamStart: |
|
|
resp := decodeResponse(msg.Payload) |
|
|
out <- StreamEvent{Type: MessageTypeStreamStart, Status: resp.Status, Headers: resp.Headers} |
|
|
case MessageTypeStreamChunk: |
|
|
chunk := decodeChunk(msg.Payload) |
|
|
out <- StreamEvent{Type: MessageTypeStreamChunk, Payload: chunk} |
|
|
case MessageTypeStreamEnd: |
|
|
out <- StreamEvent{Type: MessageTypeStreamEnd} |
|
|
return |
|
|
case MessageTypeError: |
|
|
out <- StreamEvent{Type: MessageTypeError, Err: decodeError(msg.Payload)} |
|
|
return |
|
|
case MessageTypeHTTPResp: |
|
|
resp := decodeResponse(msg.Payload) |
|
|
out <- StreamEvent{Type: MessageTypeHTTPResp, Status: resp.Status, Headers: resp.Headers, Payload: resp.Body} |
|
|
return |
|
|
default: |
|
|
} |
|
|
} |
|
|
} |
|
|
}() |
|
|
return out, nil |
|
|
} |
|
|
|
|
|
func encodeRequest(req *HTTPRequest) map[string]any { |
|
|
headers := make(map[string]any, len(req.Headers)) |
|
|
for key, values := range req.Headers { |
|
|
copyValues := make([]string, len(values)) |
|
|
copy(copyValues, values) |
|
|
headers[key] = copyValues |
|
|
} |
|
|
return map[string]any{ |
|
|
"method": req.Method, |
|
|
"url": req.URL, |
|
|
"headers": headers, |
|
|
"body": string(req.Body), |
|
|
"sent_at": time.Now().UTC().Format(time.RFC3339Nano), |
|
|
} |
|
|
} |
|
|
|
|
|
func decodeResponse(payload map[string]any) *HTTPResponse { |
|
|
if payload == nil { |
|
|
return &HTTPResponse{Status: http.StatusBadGateway, Headers: make(http.Header)} |
|
|
} |
|
|
resp := &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)} |
|
|
if status, ok := payload["status"].(float64); ok { |
|
|
resp.Status = int(status) |
|
|
} |
|
|
if headers, ok := payload["headers"].(map[string]any); ok { |
|
|
for key, raw := range headers { |
|
|
switch v := raw.(type) { |
|
|
case []any: |
|
|
for _, item := range v { |
|
|
if str, ok := item.(string); ok { |
|
|
resp.Headers.Add(key, str) |
|
|
} |
|
|
} |
|
|
case []string: |
|
|
for _, str := range v { |
|
|
resp.Headers.Add(key, str) |
|
|
} |
|
|
case string: |
|
|
resp.Headers.Set(key, v) |
|
|
} |
|
|
} |
|
|
} |
|
|
if body, ok := payload["body"].(string); ok { |
|
|
resp.Body = []byte(body) |
|
|
} |
|
|
return resp |
|
|
} |
|
|
|
|
|
// decodeChunk extracts the bytes of a stream chunk from a relay payload.
// It returns nil when payload is nil or has no string "data" entry.
func decodeChunk(payload map[string]any) []byte {
	// Indexing a nil map is safe and simply yields no value.
	data, isStr := payload["data"].(string)
	if !isStr {
		return nil
	}
	return []byte(data)
}
|
|
|
|
|
// decodeError turns a relay error payload into an error of the form
// "<message> (status=<status>)".
//
// The message is the string "error" entry, or "wsrelay: upstream error" when
// that entry is missing or empty. The status is the float64 "status" entry
// truncated to int, or 0 when absent. A nil payload yields
// "wsrelay: unknown error".
func decodeError(payload map[string]any) error {
	if payload == nil {
		return errors.New("wsrelay: unknown error")
	}

	text := "wsrelay: upstream error"
	if s, _ := payload["error"].(string); s != "" {
		text = s
	}

	var code int
	if f, isNum := payload["status"].(float64); isNum {
		code = int(f)
	}
	return fmt.Errorf("%s (status=%d)", text, code)
}
|
|
|