Spaces:
Running
Running
| package main | |
| import ( | |
| "encoding/json" | |
| "fmt" | |
| "log" | |
| "net/http" | |
| "os" | |
| "strings" | |
| "time" | |
| "microgpt-go/pkg/model" | |
| ) | |
// ChatMessage is a single message in an OpenAI-style chat exchange.
type ChatMessage struct {
	Role    string `json:"role"`    // "user", "assistant", etc.; anything but "assistant" is rendered as "User" in the prompt
	Content string `json:"content"` // message text
}
// ChatCompletionRequest mirrors the OpenAI /v1/chat/completions request body.
// Missing or non-positive sampling fields are replaced with defaults by
// handleChat (temperature 0.5, top_p 0.9, max_tokens 128).
type ChatCompletionRequest struct {
	Model       string        `json:"model"`       // accepted but not used for routing; only one model is served
	Messages    []ChatMessage `json:"messages"`    // conversation history, oldest first
	Temperature float64       `json:"temperature"` // softmax temperature for sampling
	MaxTokens   int           `json:"max_tokens"`  // cap on generated tokens
	TopP        float64       `json:"top_p"`       // nucleus-sampling threshold
	Stream      bool          `json:"stream"`      // accepted but ignored; responses are never streamed
}
// ChatCompletionResponse mirrors the OpenAI /v1/chat/completions response
// body. Choices and Usage use anonymous structs so the wire format is
// declared in exactly one place.
type ChatCompletionResponse struct {
	ID      string `json:"id"`      // "chatcmpl-<unix-seconds>"
	Object  string `json:"object"`  // always "chat.completion"
	Created int64  `json:"created"` // unix seconds
	Model   string `json:"model"`   // always "microgpt"
	Choices []struct {
		Message      ChatMessage `json:"message"`
		Index        int         `json:"index"`
		FinishReason string      `json:"finish_reason"`
	} `json:"choices"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}
// Package-level model runtime. Populated once by initModel before the server
// starts and treated as read-only afterwards.
var (
	gpt       func(tokenID, posID int, keys, values [][][]*model.Value) []*model.Value // forward pass for one token; fills the caller-owned KV cache and returns logits
	tokenizer model.TokenizerRuntime         // encodes/decodes text <-> token IDs
	config    model.TrainingCheckpointConfig // architecture hyperparameters (NLayer, NEmbd, NHead, BlockSize)
	state     map[string][][]*model.Value    // model weights imported from the checkpoint
)
| func initModel() { | |
| ckptPath := os.Getenv("MODEL_PATH") | |
| if ckptPath == "" { | |
| ckptPath = "models/latest_checkpoint.json" | |
| } | |
| log.Printf("Loading model from %s...", ckptPath) | |
| ckpt, err := model.LoadCheckpoint(ckptPath) | |
| if err != nil { | |
| log.Fatalf("Failed to load checkpoint: %v", err) | |
| } | |
| tokenizer, err = model.TokenizerFromCheckpoint(ckpt) | |
| if err != nil { | |
| log.Fatalf("Failed to load tokenizer: %v", err) | |
| } | |
| state = model.ImportState(ckpt.State) | |
| config = ckpt.Config | |
| gpt = model.BuildGPT(state, config.NLayer, config.NEmbd, config.NHead) | |
| log.Println("Model loaded successfully.") | |
| } | |
| func handleChat(w http.ResponseWriter, r *http.Request) { | |
| if r.Method != http.MethodPost { | |
| http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |
| return | |
| } | |
| var req ChatCompletionRequest | |
| if err := json.NewDecoder(r.Body).Decode(&req); err != nil { | |
| http.Error(w, "Invalid request body", http.StatusBadRequest) | |
| return | |
| } | |
| if req.Temperature <= 0 { | |
| req.Temperature = 0.5 | |
| } | |
| if req.TopP <= 0 { | |
| req.TopP = 0.9 | |
| } | |
| if req.MaxTokens <= 0 { | |
| req.MaxTokens = 128 | |
| } | |
| // Simple prompt construction from messages | |
| var promptBuilder strings.Builder | |
| for _, msg := range req.Messages { | |
| role := "User" | |
| if msg.Role == "assistant" { | |
| role = "Assistant" | |
| } | |
| fmt.Fprintf(&promptBuilder, "%s: %s\n", role, msg.Content) | |
| } | |
| promptBuilder.WriteString("Assistant: ") | |
| promptText := promptBuilder.String() | |
| promptTokens := tokenizer.EncodeDoc(promptText) | |
| if len(promptTokens) > config.BlockSize-1 { | |
| promptTokens = promptTokens[len(promptTokens)-(config.BlockSize-1):] | |
| } | |
| keys := make([][][]*model.Value, config.NLayer) | |
| values := make([][][]*model.Value, config.NLayer) | |
| tokenID := tokenizer.BosID | |
| pos := 0 | |
| // Process prompt tokens (pre-fill KV cache) | |
| for _, nextID := range promptTokens { | |
| if pos >= config.BlockSize { | |
| break | |
| } | |
| _ = gpt(tokenID, pos, keys, values) | |
| tokenID = nextID | |
| pos++ | |
| } | |
| // Generate response | |
| completionTokens := 0 | |
| outTokens := make([]int, 0, req.MaxTokens) | |
| recent := make([]int, 0, 64) | |
| stopSeqs := []string{"\nUser:", "\nAssistant:"} | |
| for pos < config.BlockSize && completionTokens < req.MaxTokens { | |
| logits := gpt(tokenID, pos, keys, values) | |
| recentSet := map[int]bool{} | |
| for _, id := range recent { | |
| recentSet[id] = true | |
| } | |
| weights := model.NextTokenWeights(logits, req.Temperature, 40, req.TopP, recentSet, 1.1) | |
| tokenID = model.SampleWeighted(weights) | |
| if tokenID == tokenizer.BosID { | |
| break | |
| } | |
| outTokens = append(outTokens, tokenID) | |
| recent = append(recent, tokenID) | |
| if len(recent) > 64 { | |
| recent = recent[len(recent)-64:] | |
| } | |
| completionTokens++ | |
| pos++ | |
| // Check for stop sequences in decoded text | |
| fullText := tokenizer.DecodeTokens(outTokens) | |
| stopFound := false | |
| for _, stop := range stopSeqs { | |
| if strings.Contains(fullText, stop) { | |
| stopFound = true | |
| break | |
| } | |
| } | |
| if stopFound { | |
| break | |
| } | |
| } | |
| responseText := strings.TrimSpace(tokenizer.DecodeTokens(outTokens)) | |
| // Clean up any trailing stop sequence markers | |
| for _, stop := range stopSeqs { | |
| if idx := strings.Index(responseText, strings.TrimSpace(stop)); idx >= 0 { | |
| responseText = responseText[:idx] | |
| } | |
| } | |
| resp := ChatCompletionResponse{ | |
| ID: fmt.Sprintf("chatcmpl-%d", time.Now().Unix()), | |
| Object: "chat.completion", | |
| Created: time.Now().Unix(), | |
| Model: "microgpt", | |
| Choices: []struct { | |
| Message ChatMessage `json:"message"` | |
| Index int `json:"index"` | |
| FinishReason string `json:"finish_reason"` | |
| }{ | |
| { | |
| Message: ChatMessage{ | |
| Role: "assistant", | |
| Content: strings.TrimSpace(responseText), | |
| }, | |
| Index: 0, | |
| FinishReason: "stop", | |
| }, | |
| }, | |
| } | |
| resp.Usage.PromptTokens = len(promptTokens) | |
| resp.Usage.CompletionTokens = completionTokens | |
| resp.Usage.TotalTokens = resp.Usage.PromptTokens + resp.Usage.CompletionTokens | |
| w.Header().Set("Content-Type", "application/json") | |
| json.NewEncoder(w).Encode(resp) | |
| } | |
| func handleModels(w http.ResponseWriter, r *http.Request) { | |
| resp := struct { | |
| Object string `json:"object"` | |
| Data []struct { | |
| ID string `json:"id"` | |
| Object string `json:"object"` | |
| Created int64 `json:"created"` | |
| OwnedBy string `json:"owned_by"` | |
| } `json:"data"` | |
| }{ | |
| Object: "list", | |
| Data: []struct { | |
| ID string `json:"id"` | |
| Object string `json:"object"` | |
| Created int64 `json:"created"` | |
| OwnedBy string `json:"owned_by"` | |
| }{ | |
| { | |
| ID: "microgpt", | |
| Object: "model", | |
| Created: time.Now().Unix(), | |
| OwnedBy: "microgpt", | |
| }, | |
| }, | |
| } | |
| w.Header().Set("Content-Type", "application/json") | |
| json.NewEncoder(w).Encode(resp) | |
| } | |
| func handleRoot(w http.ResponseWriter, r *http.Request) { | |
| if r.URL.Path != "/" { | |
| http.NotFound(w, r) | |
| return | |
| } | |
| w.Header().Set("Content-Type", "text/plain") | |
| fmt.Fprintf(w, "MicroGPT API is running.\n\nEndpoints:\n- POST /v1/chat/completions\n- GET /v1/models\n") | |
| } | |
| func main() { | |
| initModel() | |
| http.HandleFunc("/", handleRoot) | |
| http.HandleFunc("/v1/chat/completions", handleChat) | |
| http.HandleFunc("/v1/models", handleModels) | |
| port := os.Getenv("PORT") | |
| if port == "" { | |
| port = "7860" // Standard port for HF Spaces | |
| } | |
| log.Printf("Starting OpenAI-compatible server on port %s...", port) | |
| if err := http.ListenAndServe(":"+port, nil); err != nil { | |
| log.Fatalf("Failed to start server: %v", err) | |
| } | |
| } | |