// Source: axonhub/internal/server/api/chat.go
// (uploaded by llzai in commit 9853396, "Upload 1793 files").

package api
import (
"encoding/json"
"errors"
"net/http"
"github.com/gin-contrib/sse"
"github.com/gin-gonic/gin"
"github.com/looplj/axonhub/internal/log"
"github.com/looplj/axonhub/internal/server/orchestrator"
"github.com/looplj/axonhub/llm/httpclient"
"github.com/looplj/axonhub/llm/streams"
)
type (
	// StreamWriter renders a stream of events onto the gin response. Making it
	// a pluggable function lets callers choose how streamed completions are
	// written to the client.
	StreamWriter func(gc *gin.Context, events streams.Stream[*httpclient.StreamEvent])

	// ChatCompletionHandlers serves chat completion requests: it delegates
	// request processing to ChatCompletionOrchestrator and writes streaming
	// results through StreamWriter.
	ChatCompletionHandlers struct {
		// ChatCompletionOrchestrator processes the parsed request and produces
		// either a complete response or a stream.
		ChatCompletionOrchestrator *orchestrator.ChatCompletionOrchestrator
		// StreamWriter writes streaming results to the response; ChatCompletion
		// falls back to WriteSSEStream when this is nil.
		StreamWriter StreamWriter
	}
)
// NewChatCompletionHandlers builds handlers backed by the given orchestrator.
// Streaming results are written as Server-Sent Events via WriteSSEStream unless
// a different writer is installed with WithStreamWriter.
func NewChatCompletionHandlers(orch *orchestrator.ChatCompletionOrchestrator) *ChatCompletionHandlers {
	h := new(ChatCompletionHandlers)
	h.ChatCompletionOrchestrator = orch
	h.StreamWriter = WriteSSEStream
	return h
}
// WithStreamWriter returns a copy of handlers that uses writer for streaming
// responses. The receiver itself is left unmodified, so the original handlers
// keep their existing stream writer.
func (handlers *ChatCompletionHandlers) WithStreamWriter(writer StreamWriter) *ChatCompletionHandlers {
	dup := *handlers
	dup.StreamWriter = writer
	return &dup
}
// ChatCompletion handles a chat completion request.
//
// The request flows through these steps:
//   - It is read with httpclient.ReadHTTPRequest. An empty body is rejected
//     with 400 Bad Request.
//   - It is handed to the orchestrator.
//   - A non-streaming result is written back with its own status code, body
//     and Content-Type (defaulting to application/json).
//   - A streaming result is forwarded through the configured StreamWriter
//     (WriteSSEStream when none is set). The stream is closed when the handler
//     returns.
//   - A result carrying neither a response nor a stream is reported as
//     500 Internal Server Error rather than an empty 200.
//
// Errors from reading or processing the request are converted by the inbound
// transformer into the status code and body sent to the client.
func (handlers *ChatCompletionHandlers) ChatCompletion(c *gin.Context) {
	ctx := c.Request.Context()

	genericReq, err := httpclient.ReadHTTPRequest(c.Request)
	if err != nil {
		httpErr := handlers.ChatCompletionOrchestrator.Inbound.TransformError(ctx, err)
		c.JSON(httpErr.StatusCode, json.RawMessage(httpErr.Body))
		return
	}

	if len(genericReq.Body) == 0 {
		JSONError(c, http.StatusBadRequest, errors.New("Request body is empty"))
		return
	}

	result, err := handlers.ChatCompletionOrchestrator.Process(ctx, genericReq)
	if err != nil {
		log.Error(ctx, "Error processing chat completion", log.Cause(err))
		httpErr := handlers.ChatCompletionOrchestrator.Inbound.TransformError(ctx, err)
		c.JSON(httpErr.StatusCode, json.RawMessage(httpErr.Body))
		return
	}

	switch {
	case result.ChatCompletion != nil:
		resp := result.ChatCompletion
		contentType := "application/json"
		if ct := resp.Headers.Get("Content-Type"); ct != "" {
			contentType = ct
		}
		c.Data(resp.StatusCode, contentType, resp.Body)

	case result.ChatCompletionStream != nil:
		stream := result.ChatCompletionStream
		// Runs when ChatCompletion returns, i.e. after the stream writer is done.
		defer func() {
			log.Debug(ctx, "Close chat stream")
			if closeErr := stream.Close(); closeErr != nil {
				log.Error(ctx, "Error closing stream", log.Cause(closeErr))
			}
		}()

		c.Header("Access-Control-Allow-Origin", "*")

		streamWriter := handlers.StreamWriter
		if streamWriter == nil {
			streamWriter = WriteSSEStream
		}
		streamWriter(c, stream)

	default:
		// Without this branch nothing is written and the client would receive
		// an empty 200 response.
		log.Error(ctx, "chat completion produced neither a response nor a stream")
		JSONError(c, http.StatusInternalServerError, errors.New("empty chat completion result"))
	}
}
// WriteSSEStream writes stream events to the client as Server-Sent Events (SSE).
//
// Before the loop, it sets the SSE response headers (Content-Type,
// Cache-Control, Connection).
//
// Each event from stream is then written as an SSE event, with event name
// cur.Type and data cur.Data. The writer is flushed after every event.
//
// The loop stops in one of two ways:
//   - The request context is done. This is logged as a client disconnect.
//   - The stream is exhausted. If the stream ended with an error, a final
//     "error" event carrying the error message is written first.
//
// The stream is not closed here; the caller owns it.
func WriteSSEStream(c *gin.Context, stream streams.Stream[*httpclient.StreamEvent]) {
	ctx := c.Request.Context()

	// Set SSE headers.
	c.Header("Content-Type", sse.ContentType)
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	for {
		// Cancellation is only observed between events; a call to stream.Next
		// that is waiting on upstream data is not interrupted by this check.
		if ctx.Err() != nil {
			log.Warn(ctx, "Context done, stopping stream")
			log.Warn(ctx, "Client disconnected")
			return
		}

		if !stream.Next() {
			if streamErr := stream.Err(); streamErr != nil {
				log.Error(ctx, "Error in stream", log.Cause(streamErr))
				// Send the message explicitly. SSEvent JSON-encodes struct
				// pointers, and most error values have no exported fields, so
				// passing the error itself would produce an empty "{}" payload.
				c.SSEvent("error", gin.H{"error": gin.H{"message": streamErr.Error()}})
			}
			c.Writer.Flush()
			return
		}

		cur := stream.Current()
		c.SSEvent(cur.Type, cur.Data)
		log.Debug(ctx, "write stream event", log.Any("event", cur))
		c.Writer.Flush()
	}
}