| package contexts
|
|
|
| import (
|
| "context"
|
| "testing"
|
|
|
| "github.com/looplj/axonhub/internal/ent"
|
| )
|
|
|
| func TestWithThread(t *testing.T) {
|
| ctx := context.Background()
|
| thread := &ent.Thread{
|
| ID: 1,
|
| ThreadID: "thread-123",
|
| }
|
|
|
|
|
| newCtx := WithThread(ctx, thread)
|
| if newCtx == ctx {
|
| t.Error("WithThread should return a new context")
|
| }
|
|
|
|
|
| retrievedThread, ok := GetThread(newCtx)
|
| if !ok {
|
| t.Error("GetThread should return true for existing thread")
|
| }
|
|
|
| if retrievedThread == nil {
|
| t.Error("GetThread should return non-nil thread")
|
| }
|
|
|
| if retrievedThread.ID != thread.ID {
|
| t.Errorf("expected ID %d, got %d", thread.ID, retrievedThread.ID)
|
| }
|
|
|
| if retrievedThread.ThreadID != thread.ThreadID {
|
| t.Errorf("expected ThreadID %s, got %s", thread.ThreadID, retrievedThread.ThreadID)
|
| }
|
| }
|
|
|
| func TestGetThread(t *testing.T) {
|
| ctx := context.Background()
|
|
|
|
|
| thread, ok := GetThread(ctx)
|
| if ok {
|
| t.Error("GetThread should return false for empty context")
|
| }
|
|
|
| if thread != nil {
|
| t.Error("GetThread should return nil for empty context")
|
| }
|
|
|
|
|
| ctxWithOtherValue := context.WithValue(ctx, "other_key", "other_value")
|
|
|
| thread, ok = GetThread(ctxWithOtherValue)
|
| if ok {
|
| t.Error("GetThread should return false for context without thread")
|
| }
|
|
|
| if thread != nil {
|
| t.Error("GetThread should return nil for context without thread")
|
| }
|
| }
|
|
|
| func TestThreadWithMultipleValues(t *testing.T) {
|
| ctx := context.Background()
|
|
|
|
|
| ctx = WithAPIKey(ctx, &ent.APIKey{ID: 1, Key: "test-key"})
|
| ctx = WithUser(ctx, &ent.User{ID: 123, Email: "test@example.com"})
|
| ctx = WithThread(ctx, &ent.Thread{ID: 1, ThreadID: "thread-123"})
|
| ctx = WithProjectID(ctx, 456)
|
|
|
|
|
| apiKey, ok := GetAPIKey(ctx)
|
| if !ok || apiKey.ID != 1 {
|
| t.Error("API key should be stored and retrievable")
|
| }
|
|
|
| user, ok := GetUser(ctx)
|
| if !ok || user.ID != 123 {
|
| t.Error("User should be stored and retrievable")
|
| }
|
|
|
| thread, ok := GetThread(ctx)
|
| if !ok || thread.ID != 1 {
|
| t.Error("Thread should be stored and retrievable")
|
| }
|
|
|
| projectID, ok := GetProjectID(ctx)
|
| if !ok || projectID != 456 {
|
| t.Error("Project ID should be stored and retrievable")
|
| }
|
| }
|
|
|
| func TestThreadOverwrite(t *testing.T) {
|
| ctx := context.Background()
|
|
|
|
|
| ctx = WithThread(ctx, &ent.Thread{ID: 1, ThreadID: "thread-1"})
|
| ctx = WithThread(ctx, &ent.Thread{ID: 2, ThreadID: "thread-2"})
|
|
|
| thread, ok := GetThread(ctx)
|
| if !ok {
|
| t.Error("Thread should exist")
|
| }
|
|
|
| if thread.ID != 2 || thread.ThreadID != "thread-2" {
|
| t.Error("Thread should be the overwritten value")
|
| }
|
| }
|
|
|