package contexts import ( "context" "testing" "github.com/looplj/axonhub/internal/ent" ) func TestWithThread(t *testing.T) { ctx := context.Background() thread := &ent.Thread{ ID: 1, ThreadID: "thread-123", } // Test storing thread entity newCtx := WithThread(ctx, thread) if newCtx == ctx { t.Error("WithThread should return a new context") } // Test retrieving thread entity retrievedThread, ok := GetThread(newCtx) if !ok { t.Error("GetThread should return true for existing thread") } if retrievedThread == nil { t.Error("GetThread should return non-nil thread") } if retrievedThread.ID != thread.ID { t.Errorf("expected ID %d, got %d", thread.ID, retrievedThread.ID) } if retrievedThread.ThreadID != thread.ThreadID { t.Errorf("expected ThreadID %s, got %s", thread.ThreadID, retrievedThread.ThreadID) } } func TestGetThread(t *testing.T) { ctx := context.Background() // Test retrieving thread from empty context thread, ok := GetThread(ctx) if ok { t.Error("GetThread should return false for empty context") } if thread != nil { t.Error("GetThread should return nil for empty context") } // Test retrieving thread from context with other values ctxWithOtherValue := context.WithValue(ctx, "other_key", "other_value") thread, ok = GetThread(ctxWithOtherValue) if ok { t.Error("GetThread should return false for context without thread") } if thread != nil { t.Error("GetThread should return nil for context without thread") } } func TestThreadWithMultipleValues(t *testing.T) { ctx := context.Background() // Test storing thread along with other values ctx = WithAPIKey(ctx, &ent.APIKey{ID: 1, Key: "test-key"}) ctx = WithUser(ctx, &ent.User{ID: 123, Email: "test@example.com"}) ctx = WithThread(ctx, &ent.Thread{ID: 1, ThreadID: "thread-123"}) ctx = WithProjectID(ctx, 456) // Test retrieving all values apiKey, ok := GetAPIKey(ctx) if !ok || apiKey.ID != 1 { t.Error("API key should be stored and retrievable") } user, ok := GetUser(ctx) if !ok || user.ID != 123 { t.Error("User should be stored and retrievable") } thread, ok := GetThread(ctx) if !ok || thread.ID != 1 { t.Error("Thread should be stored and retrievable") } projectID, ok := GetProjectID(ctx) if !ok || projectID != 456 { t.Error("Project ID should be stored and retrievable") } } func TestThreadOverwrite(t *testing.T) { ctx := context.Background() // Test overwriting existing thread ctx = WithThread(ctx, &ent.Thread{ID: 1, ThreadID: "thread-1"}) ctx = WithThread(ctx, &ent.Thread{ID: 2, ThreadID: "thread-2"}) thread, ok := GetThread(ctx) if !ok { t.Error("Thread should exist") } if thread.ID != 2 || thread.ThreadID != "thread-2" { t.Error("Thread should be the overwritten value") } }