| package api
|
|
|
| import (
|
| "encoding/json"
|
| "errors"
|
| "net/http"
|
|
|
| "github.com/gin-contrib/sse"
|
| "github.com/gin-gonic/gin"
|
|
|
| "github.com/looplj/axonhub/internal/log"
|
| "github.com/looplj/axonhub/internal/server/orchestrator"
|
| "github.com/looplj/axonhub/llm/httpclient"
|
| "github.com/looplj/axonhub/llm/streams"
|
| )
|
|
|
|
|
| type StreamWriter func(c *gin.Context, stream streams.Stream[*httpclient.StreamEvent])
|
|
|
| type ChatCompletionHandlers struct {
|
| ChatCompletionOrchestrator *orchestrator.ChatCompletionOrchestrator
|
| StreamWriter StreamWriter
|
| }
|
|
|
| func NewChatCompletionHandlers(orchestrator *orchestrator.ChatCompletionOrchestrator) *ChatCompletionHandlers {
|
| return &ChatCompletionHandlers{
|
| ChatCompletionOrchestrator: orchestrator,
|
| StreamWriter: WriteSSEStream,
|
| }
|
| }
|
|
|
|
|
| func (handlers *ChatCompletionHandlers) WithStreamWriter(writer StreamWriter) *ChatCompletionHandlers {
|
| return &ChatCompletionHandlers{
|
| ChatCompletionOrchestrator: handlers.ChatCompletionOrchestrator,
|
| StreamWriter: writer,
|
| }
|
| }
|
|
|
| func (handlers *ChatCompletionHandlers) ChatCompletion(c *gin.Context) {
|
| ctx := c.Request.Context()
|
|
|
|
|
| genericReq, err := httpclient.ReadHTTPRequest(c.Request)
|
| if err != nil {
|
| httpErr := handlers.ChatCompletionOrchestrator.Inbound.TransformError(ctx, err)
|
| c.JSON(httpErr.StatusCode, json.RawMessage(httpErr.Body))
|
|
|
| return
|
| }
|
|
|
| if len(genericReq.Body) == 0 {
|
| JSONError(c, http.StatusBadRequest, errors.New("Request body is empty"))
|
| return
|
| }
|
|
|
|
|
|
|
| result, err := handlers.ChatCompletionOrchestrator.Process(ctx, genericReq)
|
| if err != nil {
|
| log.Error(ctx, "Error processing chat completion", log.Cause(err))
|
|
|
| httpErr := handlers.ChatCompletionOrchestrator.Inbound.TransformError(ctx, err)
|
| c.JSON(httpErr.StatusCode, json.RawMessage(httpErr.Body))
|
|
|
| return
|
| }
|
|
|
| if result.ChatCompletion != nil {
|
| resp := result.ChatCompletion
|
|
|
| contentType := "application/json"
|
| if ct := resp.Headers.Get("Content-Type"); ct != "" {
|
| contentType = ct
|
| }
|
|
|
| c.Data(resp.StatusCode, contentType, resp.Body)
|
|
|
| return
|
| }
|
|
|
| if result.ChatCompletionStream != nil {
|
| defer func() {
|
| log.Debug(ctx, "Close chat stream")
|
|
|
| err := result.ChatCompletionStream.Close()
|
| if err != nil {
|
| logger.Error(ctx, "Error closing stream", log.Cause(err))
|
| }
|
| }()
|
|
|
| c.Header("Access-Control-Allow-Origin", "*")
|
|
|
| streamWriter := handlers.StreamWriter
|
| if streamWriter == nil {
|
| streamWriter = WriteSSEStream
|
| }
|
|
|
| streamWriter(c, result.ChatCompletionStream)
|
| }
|
| }
|
|
|
|
|
| func WriteSSEStream(c *gin.Context, stream streams.Stream[*httpclient.StreamEvent]) {
|
| ctx := c.Request.Context()
|
| clientDisconnected := false
|
|
|
| defer func() {
|
| if clientDisconnected {
|
| log.Warn(ctx, "Client disconnected")
|
| }
|
| }()
|
|
|
|
|
| c.Header("Content-Type", sse.ContentType)
|
| c.Header("Cache-Control", "no-cache")
|
| c.Header("Connection", "keep-alive")
|
|
|
| for {
|
| select {
|
| case <-ctx.Done():
|
| clientDisconnected = true
|
|
|
| log.Warn(ctx, "Context done, stopping stream")
|
|
|
| return
|
| default:
|
| if stream.Next() {
|
| cur := stream.Current()
|
| c.SSEvent(cur.Type, cur.Data)
|
| log.Debug(ctx, "write stream event", log.Any("event", cur))
|
| c.Writer.Flush()
|
| } else {
|
| if stream.Err() != nil {
|
| log.Error(ctx, "Error in stream", log.Cause(stream.Err()))
|
| c.SSEvent("error", stream.Err())
|
| }
|
|
|
| c.Writer.Flush()
|
|
|
| return
|
| }
|
| }
|
| }
|
| }
|
|
|