| package testutil
|
|
|
| import (
|
| "context"
|
| "strings"
|
| "testing"
|
|
|
| "google.golang.org/genai"
|
| )
|
|
|
|
|
| type TestHelper struct {
|
| Config *Config
|
| Client *genai.Client
|
| }
|
|
|
|
|
| func NewTestHelper(t *testing.T, name string) *TestHelper {
|
| config := DefaultConfigWithPrefix(name)
|
| if err := config.ValidateConfig(); err != nil {
|
| t.Skipf("Skipping test due to configuration error: %v", err)
|
| }
|
|
|
| client, err := config.NewClient()
|
| if err != nil {
|
| t.Skipf("Skipping test due to client creation error: %v", err)
|
| }
|
|
|
| return &TestHelper{
|
| Config: config,
|
| Client: client,
|
| }
|
| }
|
|
|
|
|
| func (h *TestHelper) AssertNoError(t *testing.T, err error, msg ...interface{}) {
|
| t.Helper()
|
| if err != nil {
|
| t.Fatalf("Unexpected error: %v - %v", err, msg)
|
| }
|
| }
|
|
|
|
|
| func (h *TestHelper) LogResponse(t *testing.T, response interface{}, description string) {
|
| t.Helper()
|
| t.Logf("%s: %+v", description, response)
|
| }
|
|
|
|
|
| func (h *TestHelper) PrintHeaders(t *testing.T) {
|
| t.Helper()
|
| t.Logf("Using headers: %+v", h.Config.GetHeaders())
|
| }
|
|
|
|
|
| func (h *TestHelper) CreateTestContext() context.Context {
|
| ctx := context.Background()
|
| return h.Config.WithHeaders(ctx)
|
| }
|
|
|
|
|
| func (h *TestHelper) RunWithHeaders(t *testing.T, testFunc func(ctx context.Context) error) {
|
| t.Helper()
|
| ctx := h.CreateTestContext()
|
| if err := testFunc(ctx); err != nil {
|
| h.AssertNoError(t, err)
|
| }
|
| }
|
|
|
|
|
| func (h *TestHelper) ValidateChatResponse(t *testing.T, response *genai.GenerateContentResponse, description string) {
|
| t.Helper()
|
| if response == nil {
|
| t.Fatalf("Response is nil for %s", description)
|
| }
|
| if len(response.Candidates) == 0 {
|
| t.Fatalf("No candidates in response for %s", description)
|
| }
|
|
|
|
|
| candidate := response.Candidates[0]
|
| if candidate.Content == nil || len(candidate.Content.Parts) == 0 {
|
| t.Fatalf("Empty content in response for %s", description)
|
| }
|
|
|
| t.Logf("%s - Response validated successfully: %d candidates", description, len(response.Candidates))
|
| }
|
|
|
|
|
| func (h *TestHelper) GetModel() string {
|
| return h.Config.GetModel()
|
| }
|
|
|
|
|
| func (h *TestHelper) GetModelWithFallback(fallback string) string {
|
| return h.Config.GetModelWithFallback(fallback)
|
| }
|
|
|
|
|
| func (h *TestHelper) SetModel(model string) {
|
| h.Config.SetModel(model)
|
| }
|
|
|
|
|
| func (h *TestHelper) GetHTTPOptions() *genai.HTTPOptions {
|
| return h.Config.GetHTTPOptions()
|
| }
|
|
|
|
|
| func (h *TestHelper) MergeHTTPOptions(config *genai.GenerateContentConfig) *genai.GenerateContentConfig {
|
| if config == nil {
|
| config = &genai.GenerateContentConfig{}
|
| }
|
| if config.HTTPOptions == nil {
|
| config.HTTPOptions = h.GetHTTPOptions()
|
| } else {
|
|
|
| helperHeaders := h.GetHTTPOptions().Headers
|
| if config.HTTPOptions.Headers == nil {
|
| config.HTTPOptions.Headers = helperHeaders
|
| } else {
|
| for k, v := range helperHeaders {
|
| config.HTTPOptions.Headers[k] = v
|
| }
|
| }
|
| }
|
| return config
|
| }
|
|
|
|
|
| func (h *TestHelper) GenerateContentWithHeaders(ctx context.Context, model string, contents []*genai.Content, config *genai.GenerateContentConfig) (*genai.GenerateContentResponse, error) {
|
| config = h.MergeHTTPOptions(config)
|
| return h.Client.Models.GenerateContent(ctx, model, contents, config)
|
| }
|
|
|
|
|
| func (h *TestHelper) CreateChatWithHeaders(ctx context.Context, model string, config *genai.GenerateContentConfig, history []*genai.Content) (*genai.Chat, error) {
|
| config = h.MergeHTTPOptions(config)
|
| return h.Client.Chats.Create(ctx, model, config, history)
|
| }
|
|
|
|
|
| func CreateTestHelperWithNewTrace(t *testing.T, existingConfig *Config) *TestHelper {
|
| t.Helper()
|
|
|
|
|
| newConfig := &Config{
|
| APIKey: existingConfig.APIKey,
|
| BaseURL: existingConfig.BaseURL,
|
| Timeout: existingConfig.Timeout,
|
| MaxRetries: existingConfig.MaxRetries,
|
| Model: existingConfig.Model,
|
| DisableTrace: existingConfig.DisableTrace,
|
| DisableThread: existingConfig.DisableThread,
|
| ThreadID: existingConfig.ThreadID,
|
| }
|
|
|
|
|
| if !existingConfig.DisableTrace {
|
|
|
| prefix := "trace"
|
| if existingConfig.TraceID != "" {
|
|
|
| if idx := strings.Index(existingConfig.TraceID, "-"); idx > 0 {
|
| prefix = existingConfig.TraceID[:idx]
|
| }
|
| }
|
| newConfig.TraceID = getRandomTraceIDWithPrefix(prefix)
|
| }
|
|
|
| client, err := newConfig.NewClient()
|
| if err != nil {
|
| t.Skipf("Skipping test due to client creation error: %v", err)
|
| }
|
|
|
| return &TestHelper{
|
| Config: newConfig,
|
| Client: client,
|
| }
|
| }
|
|
|
|
|
// ContainsCaseInsensitive reports whether substring occurs within text once
// both have been lower-cased with strings.ToLower.
func ContainsCaseInsensitive(text, substring string) bool {
	haystack := strings.ToLower(text)
	needle := strings.ToLower(substring)
	return strings.Contains(haystack, needle)
}
|
|
|
|
|
// ContainsAnyCaseInsensitive reports whether text contains at least one of
// substrings, comparing lower-cased forms of both sides.
//
// It returns false when no substrings are given.
func ContainsAnyCaseInsensitive(text string, substrings ...string) bool {
	// Lower-case the haystack once rather than once per candidate.
	lowered := strings.ToLower(text)
	for _, substring := range substrings {
		if strings.Contains(lowered, strings.ToLower(substring)) {
			return true
		}
	}
	return false
}
|
|
|
|
|
| func ExtractTextFromResponse(response *genai.GenerateContentResponse) string {
|
| if response == nil || len(response.Candidates) == 0 {
|
| return ""
|
| }
|
|
|
| candidate := response.Candidates[0]
|
| if candidate.Content == nil || len(candidate.Content.Parts) == 0 {
|
| return ""
|
| }
|
|
|
| var result strings.Builder
|
| for _, part := range candidate.Content.Parts {
|
| if part != nil && part.Text != "" {
|
| result.WriteString(part.Text)
|
| }
|
| }
|
|
|
| return result.String()
|
| }
|
|
|