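// An OpenAI-compatible HTTP server for a microgpt-go model: it loads a
// trained checkpoint at startup and exposes chat completion and model
// listing endpoints.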
package main

import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"os"
	"strings"
	"time"

	"microgpt-go/pkg/model"
)
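
// ChatMessage is a single turn in an OpenAI-style conversation.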
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
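
// ChatCompletionRequest is the accepted subset of the OpenAI chat
// completion request body.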
type ChatCompletionRequest struct {
	Model       string        `json:"model"`
	Messages    []ChatMessage `json:"messages"`
	Temperature float64       `json:"temperature"`
	MaxTokens   int           `json:"max_tokens"`
	TopP        float64       `json:"top_p"`
	Stream      bool          `json:"stream"`
}
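
// ChatChoice, ChatUsage, and ChatCompletionResponse mirror the shape of
// an OpenAI chat.completion response body.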
type ChatChoice struct {
	Message      ChatMessage `json:"message"`
	Index        int         `json:"index"`
	FinishReason string      `json:"finish_reason"`
}

type ChatUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type ChatCompletionResponse struct {
	ID      string       `json:"id"`
	Object  string       `json:"object"`
	Created int64        `json:"created"`
	Model   string       `json:"model"`
	Choices []ChatChoice `json:"choices"`
	Usage   ChatUsage    `json:"usage"`
}
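
// Model runtime, initialized once by initModel and shared by all request
// handlers. Note: handlers run concurrently, so this assumes the forward
// function returned by model.BuildGPT is safe for concurrent calls with
// per-request KV caches.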
var (
	gpt       func(tokenID, posID int, keys, values [][][]*model.Value) []*model.Value
	tokenizer model.TokenizerRuntime
	config    model.TrainingCheckpointConfig
	state     map[string][][]*model.Value
)
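
// initModel loads the checkpoint named by the MODEL_PATH environment
// variable (default models/latest_checkpoint.json), restores the
// tokenizer, weights, and config from it, and builds the GPT forward
// function. Any failure is fatal: the server cannot run without a model.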
func initModel() {
	ckptPath := os.Getenv("MODEL_PATH")
	if ckptPath == "" {
		ckptPath = "models/latest_checkpoint.json"
	}
	log.Printf("Loading model from %s...", ckptPath)
	ckpt, err := model.LoadCheckpoint(ckptPath)
	if err != nil {
		log.Fatalf("Failed to load checkpoint: %v", err)
	}
	tokenizer, err = model.TokenizerFromCheckpoint(ckpt)
	if err != nil {
		log.Fatalf("Failed to load tokenizer: %v", err)
	}
	state = model.ImportState(ckpt.State)
	config = ckpt.Config
	gpt = model.BuildGPT(state, config.NLayer, config.NEmbd, config.NHead)
	log.Println("Model loaded successfully.")
}
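
// handleChat implements POST /v1/chat/completions. It accepts an
// OpenAI-style body such as
//
//	{"messages": [{"role": "user", "content": "Hi"}], "max_tokens": 64}
//
// flattens the messages into a "User:/Assistant:" transcript, runs
// autoregressive sampling against the loaded checkpoint, and returns a
// single non-streamed chat.completion object. The stream flag is parsed
// but ignored; streaming is not implemented.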
func handleChat(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req ChatCompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	// Apply defaults for unset or invalid sampling parameters.
	if req.Temperature <= 0 {
		req.Temperature = 0.5
	}
	if req.TopP <= 0 {
		req.TopP = 0.9
	}
	if req.MaxTokens <= 0 {
		req.MaxTokens = 128
	}
	// Build a plain "Role: content" transcript from the messages. The
	// assistant role maps to "Assistant"; user, system, and any other
	// roles are all rendered as "User".
	var promptBuilder strings.Builder
	for _, msg := range req.Messages {
		role := "User"
		if msg.Role == "assistant" {
			role = "Assistant"
		}
		fmt.Fprintf(&promptBuilder, "%s: %s\n", role, msg.Content)
	}
	promptBuilder.WriteString("Assistant: ")
	promptText := promptBuilder.String()
	promptTokens := tokenizer.EncodeDoc(promptText)
	// Keep only the most recent tokens that fit in the context window,
	// reserving one slot for the BOS token fed first.
	if len(promptTokens) > config.BlockSize-1 {
		promptTokens = promptTokens[len(promptTokens)-(config.BlockSize-1):]
	}
	// Per-request KV caches: one slice of cached keys/values per layer.
	keys := make([][][]*model.Value, config.NLayer)
	values := make([][][]*model.Value, config.NLayer)
	tokenID := tokenizer.BosID
	pos := 0
	// Pre-fill the KV cache with the prompt. Each step feeds the current
	// token and holds the next one back, so the last prompt token is
	// still in tokenID when generation starts and its forward pass yields
	// the logits for the first sampled token.
	for _, nextID := range promptTokens {
		if pos >= config.BlockSize {
			break
		}
		_ = gpt(tokenID, pos, keys, values)
		tokenID = nextID
		pos++
	}
	// Sample up to MaxTokens new tokens. finishReason follows the OpenAI
	// convention: "length" when the token budget or context window runs
	// out, "stop" when BOS or a stop sequence ends generation early.
	completionTokens := 0
	finishReason := "length"
	outTokens := make([]int, 0, req.MaxTokens)
	recent := make([]int, 0, 64)
	stopSeqs := []string{"\nUser:", "\nAssistant:"}
	for pos < config.BlockSize && completionTokens < req.MaxTokens {
		logits := gpt(tokenID, pos, keys, values)
		recentSet := map[int]bool{}
		for _, id := range recent {
			recentSet[id] = true
		}
		// Fixed top-k of 40 and repetition penalty of 1.1 over the last
		// 64 tokens; temperature and top-p come from the request.
		weights := model.NextTokenWeights(logits, req.Temperature, 40, req.TopP, recentSet, 1.1)
		tokenID = model.SampleWeighted(weights)
		if tokenID == tokenizer.BosID {
			finishReason = "stop"
			break
		}
		outTokens = append(outTokens, tokenID)
		recent = append(recent, tokenID)
		if len(recent) > 64 {
			recent = recent[len(recent)-64:]
		}
		completionTokens++
		pos++
		// Stop as soon as a stop sequence appears in the decoded text.
		fullText := tokenizer.DecodeTokens(outTokens)
		stopFound := false
		for _, stop := range stopSeqs {
			if strings.Contains(fullText, stop) {
				stopFound = true
				break
			}
		}
		if stopFound {
			finishReason = "stop"
			break
		}
	}
	responseText := strings.TrimSpace(tokenizer.DecodeTokens(outTokens))
	// Truncate at any stop-sequence marker that leaked into the decoded
	// text; the leading newline is trimmed so a marker at the start of a
	// line is still matched.
	for _, stop := range stopSeqs {
		if idx := strings.Index(responseText, strings.TrimSpace(stop)); idx >= 0 {
			responseText = responseText[:idx]
		}
	}
	resp := ChatCompletionResponse{
		ID:      fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   "microgpt",
		Choices: []ChatChoice{
			{
				Message: ChatMessage{
					Role:    "assistant",
					Content: strings.TrimSpace(responseText),
				},
				Index:        0,
				FinishReason: finishReason,
			},
		},
	}
	resp.Usage.PromptTokens = len(promptTokens)
	resp.Usage.CompletionTokens = completionTokens
	resp.Usage.TotalTokens = resp.Usage.PromptTokens + resp.Usage.CompletionTokens
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		log.Printf("Failed to encode chat response: %v", err)
	}
}
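
// handleModels implements GET /v1/models, advertising the single
// "microgpt" model so OpenAI client libraries can discover it.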
func handleModels(w http.ResponseWriter, r *http.Request) {
	type modelInfo struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		Created int64  `json:"created"`
		OwnedBy string `json:"owned_by"`
	}
	resp := struct {
		Object string      `json:"object"`
		Data   []modelInfo `json:"data"`
	}{
		Object: "list",
		Data: []modelInfo{
			{
				ID:      "microgpt",
				Object:  "model",
				Created: time.Now().Unix(),
				OwnedBy: "microgpt",
			},
		},
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		log.Printf("Failed to encode models response: %v", err)
	}
}
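
// handleRoot serves a plain-text index at "/" (so health probes against
// the root no longer 404) and a 404 for any other unregistered path.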
func handleRoot(w http.ResponseWriter, r *http.Request) {
	if r.URL.Path != "/" {
		http.NotFound(w, r)
		return
	}
	w.Header().Set("Content-Type", "text/plain")
	fmt.Fprint(w, "MicroGPT API is running.\n\nEndpoints:\n- POST /v1/chat/completions\n- GET /v1/models\n")
}
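
// main loads the model, registers the HTTP handlers, and listens on
// $PORT (default 7860, the port Hugging Face Spaces expects).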
func main() {
	initModel()
	http.HandleFunc("/", handleRoot)
	http.HandleFunc("/v1/chat/completions", handleChat)
	http.HandleFunc("/v1/models", handleModels)
	port := os.Getenv("PORT")
	if port == "" {
		port = "7860" // Standard port for HF Spaces
	}
	log.Printf("Starting OpenAI-compatible server on port %s...", port)
	if err := http.ListenAndServe(":"+port, nil); err != nil {
		log.Fatalf("Failed to start server: %v", err)
	}
}