| package main |
|
|
| import ( |
| "context" |
| "encoding/json" |
| "fmt" |
| "log" |
| "net/http" |
| "os" |
| "strings" |
| "sync" |
| "time" |
|
|
| copilot "github.com/github/copilot-sdk/go" |
| ) |
|
|
| |
| type Server struct { |
| client *copilot.Client |
| apiToken string |
| mu sync.Mutex |
| } |
|
|
| |
| func NewServer() (*Server, error) { |
| |
| token := os.Getenv("GH_TOKEN") |
| if token == "" { |
| token = os.Getenv("GITHUB_TOKEN") |
| } |
| if token == "" { |
| log.Println("GH_TOKEN/GITHUB_TOKEN environment variable not set.") |
| return nil, fmt.Errorf("missing GitHub token: set GH_TOKEN as a Hugging Face Space Secret") |
| } else { |
| log.Printf("GH_TOKEN is set (%d chars)", len(token)) |
| } |
| _ = os.Setenv("GH_TOKEN", token) |
| _ = os.Setenv("GITHUB_TOKEN", token) |
|
|
| apiToken := os.Getenv("API_TOKEN") |
| if apiToken == "" { |
| log.Println("API_TOKEN environment variable not set.") |
| return nil, fmt.Errorf("missing API token: set API_TOKEN as a Hugging Face Space Secret") |
| } |
|
|
| client := copilot.NewClient(&copilot.ClientOptions{ |
| LogLevel: "error", |
| }) |
|
|
| log.Println("Starting copilot client...") |
| if err := client.Start(context.Background()); err != nil { |
| log.Printf("Failed to start copilot client: %v", err) |
| return nil, fmt.Errorf("copilot auth/start failed: %w", err) |
| } |
|
|
| log.Println("Copilot client started successfully") |
| return &Server{ |
| client: client, |
| apiToken: apiToken, |
| }, nil |
| } |
|
|
| |
| func (s *Server) Close() { |
| s.client.Stop() |
| } |
|
|
| |
| func (s *Server) HandleModels(w http.ResponseWriter, r *http.Request) { |
| if r.Method != http.MethodGet { |
| writeError(w, http.StatusMethodNotAllowed, "Method not allowed", "invalid_request_error") |
| return |
| } |
|
|
| models, err := s.client.ListModels(r.Context()) |
| if err != nil { |
| log.Printf("Error listing models: %v", err) |
| log.Printf("GH_TOKEN is set: %v", os.Getenv("GH_TOKEN") != "") |
| log.Printf("Token prefix looks like classic PAT: %v", strings.HasPrefix(os.Getenv("GH_TOKEN"), "ghp_")) |
| writeError(w, http.StatusInternalServerError, |
| fmt.Sprintf("Failed to list models: %v. Ensure GH_TOKEN/GITHUB_TOKEN is a valid token with Copilot access (for fine-grained PAT, include the Copilot Requests permission).", err), |
| "api_error") |
| return |
| } |
|
|
| response := ModelsResponse{ |
| Object: "list", |
| Data: make([]ModelData, 0, len(models)), |
| } |
|
|
| for _, model := range models { |
| response.Data = append(response.Data, ModelData{ |
| ID: model.ID, |
| Object: "model", |
| Created: currentTimestamp(), |
| OwnedBy: "github-copilot", |
| }) |
| } |
|
|
| writeJSON(w, http.StatusOK, response) |
| } |
|
|
| |
| func (s *Server) HandleChatCompletions(w http.ResponseWriter, r *http.Request) { |
| if r.Method != http.MethodPost { |
| writeError(w, http.StatusMethodNotAllowed, "Method not allowed", "invalid_request_error") |
| return |
| } |
|
|
| var req ChatCompletionRequest |
| if err := json.NewDecoder(r.Body).Decode(&req); err != nil { |
| writeError(w, http.StatusBadRequest, "Invalid request body", "invalid_request_error") |
| return |
| } |
|
|
| if req.Model == "" { |
| writeError(w, http.StatusBadRequest, "Model is required", "invalid_request_error") |
| return |
| } |
|
|
| if len(req.Messages) == 0 { |
| writeError(w, http.StatusBadRequest, "Messages are required", "invalid_request_error") |
| return |
| } |
|
|
| |
| var systemMessageParts []string |
| for _, msg := range req.Messages { |
| if msg.Role == "system" || msg.Role == "developer" { |
| systemMessageParts = append(systemMessageParts, msg.Content) |
| } |
| } |
|
|
| |
| prompt := buildPrompt(req.Messages) |
|
|
| |
| var copilotTools []copilot.Tool |
| log.Printf("[DEBUG] Received %d tools in request", len(req.Tools)) |
| for _, tool := range req.Tools { |
| if tool.Type == "function" { |
| |
| |
| copilotTools = append(copilotTools, copilot.Tool{ |
| Name: tool.Function.Name, |
| Description: tool.Function.Description, |
| Parameters: tool.Function.Parameters, |
| |
| }) |
| } |
| } |
|
|
| |
| sessionConfig := &copilot.SessionConfig{ |
| Model: req.Model, |
| Streaming: req.Stream, |
| Tools: copilotTools, |
| OnPermissionRequest: copilot.PermissionHandler.ApproveAll, |
| |
| InfiniteSessions: &copilot.InfiniteSessionConfig{ |
| Enabled: copilot.Bool(false), |
| }, |
| } |
|
|
| |
| if len(systemMessageParts) > 0 { |
| systemContent := strings.Join(systemMessageParts, "\n\n") |
| log.Printf("[DEBUG] Setting system message (length: %d)", len(systemContent)) |
| sessionConfig.SystemMessage = &copilot.SystemMessageConfig{ |
| Mode: "replace", |
| Content: systemContent, |
| } |
| } |
|
|
| |
| |
| if len(copilotTools) > 0 { |
| toolNames := make([]string, len(copilotTools)) |
| for i, t := range copilotTools { |
| toolNames[i] = t.Name |
| } |
| sessionConfig.AvailableTools = toolNames |
| } |
|
|
| |
| |
| |
| for i, t := range copilotTools { |
| _ = i |
| _ = t |
| |
| |
| |
| } |
|
|
| |
| session, err := s.client.CreateSession(r.Context(), sessionConfig) |
| if err != nil { |
| log.Printf("[ERROR] Creating session failed: %v", err) |
| writeError(w, http.StatusInternalServerError, "Failed to create session", "api_error") |
| return |
| } |
| defer session.Destroy() |
| log.Printf("[DEBUG] Session created successfully") |
|
|
| |
| |
|
|
| if req.Stream { |
| log.Printf("[DEBUG] Starting streaming response") |
| s.handleStreamingResponse(r.Context(), w, session, prompt, req.Model) |
| } else { |
| log.Printf("[DEBUG] Starting non-streaming response") |
| s.handleNonStreamingResponse(r.Context(), w, session, prompt, req.Model) |
| } |
| } |
|
|
| |
| func (s *Server) handleNonStreamingResponse(ctx context.Context, w http.ResponseWriter, session *copilot.Session, prompt, model string) { |
| var contentBuilder strings.Builder |
| var toolCalls []ToolCall |
| var finishReason string = "stop" |
|
|
| done := make(chan bool) |
| var closeOnce sync.Once |
|
|
| session.On(func(event copilot.SessionEvent) { |
| switch event.Type { |
| case copilot.AssistantMessage: |
| |
| if len(event.Data.ToolRequests) > 0 { |
| finishReason = "tool_calls" |
| for _, tr := range event.Data.ToolRequests { |
| argsJSON, _ := json.Marshal(tr.Arguments) |
| toolCalls = append(toolCalls, ToolCall{ |
| ID: tr.ToolCallID, |
| Type: "function", |
| Function: ToolCallFunction{ |
| Name: tr.Name, |
| Arguments: string(argsJSON), |
| }, |
| }) |
| } |
| } |
| |
| if event.Data.Content != nil { |
| contentBuilder.WriteString(*event.Data.Content) |
| } |
|
|
| case copilot.SessionIdle: |
| closeOnce.Do(func() { close(done) }) |
|
|
| case copilot.SessionError: |
| if event.Data.Message != nil { |
| log.Printf("Session error: %s", *event.Data.Message) |
| } |
| closeOnce.Do(func() { close(done) }) |
| } |
| }) |
|
|
| |
| _, err := session.Send(ctx, copilot.MessageOptions{ |
| Prompt: prompt, |
| }) |
| if err != nil { |
| log.Printf("Error sending message: %v", err) |
| writeError(w, http.StatusInternalServerError, "Failed to send message", "api_error") |
| return |
| } |
|
|
| |
| select { |
| case <-done: |
| case <-time.After(5 * time.Minute): |
| log.Printf("Request timed out") |
| writeError(w, http.StatusGatewayTimeout, "Request timed out", "api_error") |
| return |
| } |
|
|
| |
| response := ChatCompletionResponse{ |
| ID: fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), |
| Object: "chat.completion", |
| Created: currentTimestamp(), |
| Model: model, |
| Choices: []Choice{ |
| { |
| Index: 0, |
| Message: &Message{ |
| Role: "assistant", |
| Content: contentBuilder.String(), |
| ToolCalls: toolCalls, |
| }, |
| FinishReason: &finishReason, |
| }, |
| }, |
| } |
|
|
| writeJSON(w, http.StatusOK, response) |
| } |
|
|
| |
| func (s *Server) handleStreamingResponse(ctx context.Context, w http.ResponseWriter, session *copilot.Session, prompt, model string) { |
| |
| w.Header().Set("Content-Type", "text/event-stream") |
| w.Header().Set("Cache-Control", "no-cache") |
| w.Header().Set("Connection", "keep-alive") |
| w.Header().Set("X-Accel-Buffering", "no") |
|
|
| flusher, ok := w.(http.Flusher) |
| if !ok { |
| writeError(w, http.StatusInternalServerError, "Streaming not supported", "api_error") |
| return |
| } |
|
|
| completionID := fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()) |
| done := make(chan bool) |
| var toolCalls []ToolCall |
| var mu sync.Mutex |
|
|
| sendChunk := func(delta Message, finishReason *string) { |
| chunk := ChatCompletionChunk{ |
| ID: completionID, |
| Object: "chat.completion.chunk", |
| Created: currentTimestamp(), |
| Model: model, |
| Choices: []Choice{ |
| { |
| Index: 0, |
| Delta: &delta, |
| FinishReason: finishReason, |
| }, |
| }, |
| } |
| data, _ := json.Marshal(chunk) |
| if finishReason != nil { |
| log.Printf("[DEBUG] SSE chunk (finish=%s): %s", *finishReason, string(data)) |
| } |
| fmt.Fprintf(w, "data: %s\n\n", data) |
| flusher.Flush() |
| } |
|
|
| |
| sendChunk(Message{Role: "assistant"}, nil) |
|
|
| var closeOnce sync.Once |
| session.On(func(event copilot.SessionEvent) { |
| |
| switch event.Type { |
| case copilot.AssistantMessageDelta: |
| |
| if event.Data.DeltaContent != nil { |
| sendChunk(Message{Content: *event.Data.DeltaContent}, nil) |
| } |
|
|
| case copilot.AssistantMessage: |
| log.Printf("[DEBUG] AssistantMessage - ToolRequests: %d, Content length: %d", |
| len(event.Data.ToolRequests), |
| func() int { |
| if event.Data.Content != nil { |
| return len(*event.Data.Content) |
| } |
| return 0 |
| }()) |
| |
| if len(event.Data.ToolRequests) > 0 { |
| log.Printf("[DEBUG] Tool calls found - streaming to client incrementally") |
| mu.Lock() |
| for i, tr := range event.Data.ToolRequests { |
| argsJSON, _ := json.Marshal(tr.Arguments) |
| |
| idx := i |
| |
| toolCalls = append(toolCalls, ToolCall{ |
| Index: &idx, |
| ID: tr.ToolCallID, |
| Type: "function", |
| Function: ToolCallFunction{ |
| Name: tr.Name, |
| Arguments: string(argsJSON), |
| }, |
| }) |
| |
| sendChunk(Message{ToolCalls: []ToolCall{{ |
| Index: &idx, |
| ID: tr.ToolCallID, |
| Type: "function", |
| Function: ToolCallFunction{ |
| Name: tr.Name, |
| }, |
| }}}, nil) |
| |
| sendChunk(Message{ToolCalls: []ToolCall{{ |
| Index: &idx, |
| Function: ToolCallFunction{ |
| Arguments: string(argsJSON), |
| }, |
| }}}, nil) |
| } |
| mu.Unlock() |
| |
| closeOnce.Do(func() { close(done) }) |
| } |
|
|
| case copilot.SessionIdle: |
| log.Printf("[DEBUG] SessionIdle - completing request") |
| closeOnce.Do(func() { close(done) }) |
|
|
| case copilot.SessionError: |
| if event.Data.Message != nil { |
| log.Printf("[DEBUG] SessionError: %s", *event.Data.Message) |
| } |
| closeOnce.Do(func() { close(done) }) |
|
|
| default: |
| |
| } |
| }) |
|
|
| |
| _, err := session.Send(ctx, copilot.MessageOptions{ |
| Prompt: prompt, |
| }) |
| if err != nil { |
| log.Printf("Error sending message: %v", err) |
| return |
| } |
|
|
| |
| select { |
| case <-done: |
| case <-time.After(5 * time.Minute): |
| log.Printf("Streaming request timed out") |
| return |
| } |
|
|
| |
| mu.Lock() |
| if len(toolCalls) > 0 { |
| |
| |
| sendChunk(Message{}, strPtr("tool_calls")) |
| } else { |
| sendChunk(Message{}, strPtr("stop")) |
| } |
| mu.Unlock() |
|
|
| |
| |
| fmt.Fprintf(w, "data: [DONE]\n\n") |
| flusher.Flush() |
| } |
|
|
| |
| func buildPrompt(messages []Message) string { |
| var parts []string |
|
|
| for _, msg := range messages { |
| switch msg.Role { |
|
|
| case "system", "developer": |
| |
| continue |
| case "user": |
| parts = append(parts, fmt.Sprintf("[User]: %s", msg.Content)) |
| case "assistant": |
| if msg.Content != "" { |
| parts = append(parts, fmt.Sprintf("[Assistant]: %s", msg.Content)) |
| } |
| |
| for _, tc := range msg.ToolCalls { |
| parts = append(parts, fmt.Sprintf("[Assistant called tool %s with args: %s]", tc.Function.Name, tc.Function.Arguments)) |
| } |
| case "tool": |
| parts = append(parts, fmt.Sprintf("[Tool result for %s]: %s", msg.ToolCallID, msg.Content)) |
| } |
| } |
|
|
| return strings.Join(parts, "\n\n") |
| } |
|
|
| |
// writeJSON writes data to w as a JSON document with the given HTTP status
// and a Content-Type of application/json.
//
// The value is encoded before anything is sent. If encoding fails, the
// failure is logged and a 500 with a fixed JSON error body is written instead
// of the requested status followed by a broken body. A failure while writing
// the body is logged.
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
	body, err := json.Marshal(data)
	if err != nil {
		log.Printf("Error encoding JSON response: %v", err)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusInternalServerError)
		// NOTE(review): hand-written to mirror the error envelope used by
		// writeError; confirm it matches ErrorResponse's JSON field names.
		// The write error is deliberately ignored: the status line is already
		// committed and there is no further way to report a failure.
		_, _ = w.Write([]byte("{\"error\":{\"message\":\"Failed to encode response\",\"type\":\"api_error\"}}\n"))
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// json.Encoder terminates each value with a newline; append one so the
	// bytes on the wire are the same as when Encode wrote directly to w.
	if _, err := w.Write(append(body, '\n')); err != nil {
		log.Printf("Error writing JSON response: %v", err)
	}
}
|
|
| |
| func writeError(w http.ResponseWriter, status int, message, errType string) { |
| w.Header().Set("Content-Type", "application/json") |
| w.WriteHeader(status) |
| json.NewEncoder(w).Encode(ErrorResponse{ |
| Error: ErrorDetail{ |
| Message: message, |
| Type: errType, |
| }, |
| }) |
| } |
|
|