package executor

import (
	"bytes"
	"context"
	"fmt"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
	"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

// stopChunkWithoutUsageTTL is how long a trace ID recorded by
// rememberStopWithoutUsage stays in stopChunkWithoutUsage before it is
// removed automatically.
const stopChunkWithoutUsageTTL = 10 * time.Minute

// usageReporter collects the identifying fields of a single upstream request
// and publishes at most one usage.Record for it.
type usageReporter struct {
	provider    string    // provider name passed to newUsageReporter
	model       string    // model name passed to newUsageReporter
	authID      string    // auth.ID, empty when no auth was supplied
	authIndex   string    // result of auth.EnsureIndex(), empty when no auth was supplied
	apiKey      string    // "apiKey" value read from the gin context, if any
	source      string    // value computed by resolveUsageSource
	requestedAt time.Time // time the reporter was created
	once        sync.Once // guards the single PublishRecord call
}

// newUsageReporter builds a reporter for one request. auth may be nil, in
// which case authID and authIndex stay empty.
func newUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *usageReporter {
	apiKey := apiKeyFromContext(ctx)
	reporter := &usageReporter{
		provider:    provider,
		model:       model,
		requestedAt: time.Now(),
		apiKey:      apiKey,
		source:      resolveUsageSource(auth, apiKey),
	}
	if auth != nil {
		reporter.authID = auth.ID
		reporter.authIndex = auth.EnsureIndex()
	}
	return reporter
}

// buildRecord assembles the usage.Record for this reporter with the given
// token detail and outcome.
func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
	return usage.Record{
		Provider:    r.provider,
		Model:       r.model,
		Source:      r.source,
		APIKey:      r.apiKey,
		AuthID:      r.authID,
		AuthIndex:   r.authIndex,
		RequestedAt: r.requestedAt,
		Failed:      failed,
		Detail:      detail,
	}
}

// publish emits a non-failed usage record carrying detail. An all-zero detail
// is ignored so that a later call can still publish.
func (r *usageReporter) publish(ctx context.Context, detail usage.Detail) {
	r.publishWithOutcome(ctx, detail, false)
}

// publishFailure emits a failed usage record with an empty detail.
func (r *usageReporter) publishFailure(ctx context.Context) {
	r.publishWithOutcome(ctx, usage.Detail{}, true)
}

// trackFailure publishes a failure record when *errPtr is non-nil. It is a
// no-op for a nil reporter or a nil errPtr.
func (r *usageReporter) trackFailure(ctx context.Context, errPtr *error) {
	if r == nil || errPtr == nil {
		return
	}
	if *errPtr != nil {
		r.publishFailure(ctx)
	}
}

// publishWithOutcome publishes one record for this reporter. When
// detail.TotalTokens is zero it is filled with the sum of input, output and
// reasoning tokens. A non-failed call whose detail is entirely zero is
// dropped without consuming the once guard.
func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Detail, failed bool) {
	if r == nil {
		return
	}
	if detail.TotalTokens == 0 {
		if total := detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens; total > 0 {
			detail.TotalTokens = total
		}
	}
	noTokens := detail.InputTokens == 0 &&
		detail.OutputTokens == 0 &&
		detail.ReasoningTokens == 0 &&
		detail.CachedTokens == 0 &&
		detail.TotalTokens == 0
	if noTokens && !failed {
		return
	}
	r.once.Do(func() {
		usage.PublishRecord(ctx, r.buildRecord(detail, failed))
	})
}

// ensurePublished guarantees that a usage record is emitted exactly once.
// It is safe to call multiple times; only the first call wins due to once.Do.
// This is used to ensure request counting even when upstream responses do not
// include any usage fields (tokens), especially for streaming paths.
func (r *usageReporter) ensurePublished(ctx context.Context) {
	if r == nil {
		return
	}
	r.once.Do(func() {
		usage.PublishRecord(ctx, r.buildRecord(usage.Detail{}, false))
	})
}

// apiKeyFromContext returns the "apiKey" value stored on the *gin.Context
// that ctx carries under the "gin" key, rendered as a string. It returns ""
// when ctx is nil, carries no gin context, or the gin context has no "apiKey".
func apiKeyFromContext(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	ginCtx, ok := ctx.Value("gin").(*gin.Context)
	if !ok || ginCtx == nil {
		return ""
	}
	v, exists := ginCtx.Get("apiKey")
	if !exists {
		return ""
	}
	switch value := v.(type) {
	case string:
		return value
	case fmt.Stringer:
		return value.String()
	default:
		return fmt.Sprintf("%v", value)
	}
}

// resolveUsageSource picks the string that identifies where usage came from.
// The first non-blank candidate wins, in this order:
//   - auth.ID when the provider is "gemini-cli";
//   - metadata "project_id", then "project", when the provider is "vertex";
//   - the value returned by auth.AccountInfo();
//   - metadata "email";
//   - attribute "api_key";
//   - ctxAPIKey.
//
// Every candidate is trimmed before it is tested, so a whitespace-only value
// never short-circuits the remaining fallbacks. It returns "" when nothing
// matches.
func resolveUsageSource(auth *cliproxyauth.Auth, ctxAPIKey string) string {
	if auth != nil {
		provider := strings.TrimSpace(auth.Provider)
		if strings.EqualFold(provider, "gemini-cli") {
			if id := strings.TrimSpace(auth.ID); id != "" {
				return id
			}
		}
		if strings.EqualFold(provider, "vertex") {
			// Indexing a nil map is safe and yields the zero value.
			for _, key := range []string{"project_id", "project"} {
				if raw, ok := auth.Metadata[key].(string); ok {
					if trimmed := strings.TrimSpace(raw); trimmed != "" {
						return trimmed
					}
				}
			}
		}
		if _, value := auth.AccountInfo(); strings.TrimSpace(value) != "" {
			return strings.TrimSpace(value)
		}
		if email, ok := auth.Metadata["email"].(string); ok {
			if trimmed := strings.TrimSpace(email); trimmed != "" {
				return trimmed
			}
		}
		if auth.Attributes != nil {
			if key := strings.TrimSpace(auth.Attributes["api_key"]); key != "" {
				return key
			}
		}
	}
	return strings.TrimSpace(ctxAPIKey)
}

// firstExistingNode returns the first of paths that exists under root, or a
// non-existing result when none does.
func firstExistingNode(root gjson.Result, paths ...string) gjson.Result {
	var node gjson.Result
	for _, path := range paths {
		if node = root.Get(path); node.Exists() {
			return node
		}
	}
	return node
}

// parseCodexUsage reads token counts from "response.usage" in data. The
// boolean is false when that node is absent.
func parseCodexUsage(data []byte) (usage.Detail, bool) {
	usageNode := gjson.ParseBytes(data).Get("response.usage")
	if !usageNode.Exists() {
		return usage.Detail{}, false
	}
	detail := usage.Detail{
		InputTokens:  usageNode.Get("input_tokens").Int(),
		OutputTokens: usageNode.Get("output_tokens").Int(),
		TotalTokens:  usageNode.Get("total_tokens").Int(),
	}
	if cached := usageNode.Get("input_tokens_details.cached_tokens"); cached.Exists() {
		detail.CachedTokens = cached.Int()
	}
	if reasoning := usageNode.Get("output_tokens_details.reasoning_tokens"); reasoning.Exists() {
		detail.ReasoningTokens = reasoning.Int()
	}
	return detail, true
}

// openAIUsageDetail converts an OpenAI-style "usage" node (prompt_tokens,
// completion_tokens, total_tokens and the optional *_details counters) into a
// usage.Detail.
func openAIUsageDetail(node gjson.Result) usage.Detail {
	detail := usage.Detail{
		InputTokens:  node.Get("prompt_tokens").Int(),
		OutputTokens: node.Get("completion_tokens").Int(),
		TotalTokens:  node.Get("total_tokens").Int(),
	}
	if cached := node.Get("prompt_tokens_details.cached_tokens"); cached.Exists() {
		detail.CachedTokens = cached.Int()
	}
	if reasoning := node.Get("completion_tokens_details.reasoning_tokens"); reasoning.Exists() {
		detail.ReasoningTokens = reasoning.Int()
	}
	return detail
}

// parseOpenAIUsage reads the top-level "usage" node of a complete response.
// It returns a zero Detail when the node is absent.
func parseOpenAIUsage(data []byte) usage.Detail {
	usageNode := gjson.ParseBytes(data).Get("usage")
	if !usageNode.Exists() {
		return usage.Detail{}
	}
	return openAIUsageDetail(usageNode)
}

// parseOpenAIStreamUsage reads the "usage" node from one stream line. The
// boolean is false when the line carries no valid JSON object or no "usage".
func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
	payload := jsonPayload(line)
	if len(payload) == 0 || !gjson.ValidBytes(payload) {
		return usage.Detail{}, false
	}
	usageNode := gjson.GetBytes(payload, "usage")
	if !usageNode.Exists() {
		return usage.Detail{}, false
	}
	return openAIUsageDetail(usageNode), true
}

// claudeUsageDetail converts a Claude-style "usage" node into a usage.Detail.
// CachedTokens comes from cache_read_input_tokens and falls back to
// cache_creation_input_tokens when the read count is zero. TotalTokens is
// input plus output.
func claudeUsageDetail(node gjson.Result) usage.Detail {
	detail := usage.Detail{
		InputTokens:  node.Get("input_tokens").Int(),
		OutputTokens: node.Get("output_tokens").Int(),
		CachedTokens: node.Get("cache_read_input_tokens").Int(),
	}
	if detail.CachedTokens == 0 {
		// fall back to creation tokens when read tokens are absent
		detail.CachedTokens = node.Get("cache_creation_input_tokens").Int()
	}
	detail.TotalTokens = detail.InputTokens + detail.OutputTokens
	return detail
}

// parseClaudeUsage reads the top-level "usage" node of a complete response.
// It returns a zero Detail when the node is absent.
func parseClaudeUsage(data []byte) usage.Detail {
	usageNode := gjson.ParseBytes(data).Get("usage")
	if !usageNode.Exists() {
		return usage.Detail{}
	}
	return claudeUsageDetail(usageNode)
}

// parseClaudeStreamUsage reads the "usage" node from one stream line. The
// boolean is false when the line carries no valid JSON object or no "usage".
func parseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
	payload := jsonPayload(line)
	if len(payload) == 0 || !gjson.ValidBytes(payload) {
		return usage.Detail{}, false
	}
	usageNode := gjson.GetBytes(payload, "usage")
	if !usageNode.Exists() {
		return usage.Detail{}, false
	}
	return claudeUsageDetail(usageNode), true
}

// parseGeminiFamilyUsageDetail converts a Gemini-style usage metadata node
// into a usage.Detail. When totalTokenCount is zero the total is computed as
// input + output + reasoning tokens.
func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail {
	detail := usage.Detail{
		InputTokens:     node.Get("promptTokenCount").Int(),
		OutputTokens:    node.Get("candidatesTokenCount").Int(),
		ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
		TotalTokens:     node.Get("totalTokenCount").Int(),
		CachedTokens:    node.Get("cachedContentTokenCount").Int(),
	}
	if detail.TotalTokens == 0 {
		detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
	}
	return detail
}

// parseGeminiCLIUsage reads usage from "response.usageMetadata", falling back
// to "response.usage_metadata". It returns a zero Detail when neither exists.
func parseGeminiCLIUsage(data []byte) usage.Detail {
	node := firstExistingNode(gjson.ParseBytes(data),
		"response.usageMetadata",
		"response.usage_metadata",
	)
	if !node.Exists() {
		return usage.Detail{}
	}
	return parseGeminiFamilyUsageDetail(node)
}

// parseGeminiUsage reads usage from "usageMetadata", falling back to
// "usage_metadata". It returns a zero Detail when neither exists.
func parseGeminiUsage(data []byte) usage.Detail {
	node := firstExistingNode(gjson.ParseBytes(data),
		"usageMetadata",
		"usage_metadata",
	)
	if !node.Exists() {
		return usage.Detail{}
	}
	return parseGeminiFamilyUsageDetail(node)
}

// parseGeminiStreamUsage reads usage from one stream line, looking at
// "usageMetadata" and then "usage_metadata". The boolean is false when the
// line carries no valid JSON object or neither key.
func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
	payload := jsonPayload(line)
	if len(payload) == 0 || !gjson.ValidBytes(payload) {
		return usage.Detail{}, false
	}
	node := firstExistingNode(gjson.ParseBytes(payload),
		"usageMetadata",
		"usage_metadata",
	)
	if !node.Exists() {
		return usage.Detail{}, false
	}
	return parseGeminiFamilyUsageDetail(node), true
}

// parseGeminiCLIStreamUsage reads usage from one stream line. It looks at
// "response.usageMetadata", then "response.usage_metadata" (the same fallback
// parseGeminiCLIUsage uses), and finally the top-level "usage_metadata" that
// earlier versions of this function accepted. The boolean is false when the
// line carries no valid JSON object or none of those keys.
func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
	payload := jsonPayload(line)
	if len(payload) == 0 || !gjson.ValidBytes(payload) {
		return usage.Detail{}, false
	}
	node := firstExistingNode(gjson.ParseBytes(payload),
		"response.usageMetadata",
		"response.usage_metadata",
		"usage_metadata",
	)
	if !node.Exists() {
		return usage.Detail{}, false
	}
	return parseGeminiFamilyUsageDetail(node), true
}

// parseAntigravityUsage reads usage from "response.usageMetadata", then
// "usageMetadata", then "usage_metadata". It returns a zero Detail when none
// exists.
func parseAntigravityUsage(data []byte) usage.Detail {
	node := firstExistingNode(gjson.ParseBytes(data),
		"response.usageMetadata",
		"usageMetadata",
		"usage_metadata",
	)
	if !node.Exists() {
		return usage.Detail{}
	}
	return parseGeminiFamilyUsageDetail(node)
}

// parseAntigravityStreamUsage is the per-line variant of
// parseAntigravityUsage. The boolean is false when the line carries no valid
// JSON object or none of the usage keys.
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
	payload := jsonPayload(line)
	if len(payload) == 0 || !gjson.ValidBytes(payload) {
		return usage.Detail{}, false
	}
	node := firstExistingNode(gjson.ParseBytes(payload),
		"response.usageMetadata",
		"usageMetadata",
		"usage_metadata",
	)
	if !node.Exists() {
		return usage.Detail{}, false
	}
	return parseGeminiFamilyUsageDetail(node), true
}

// stopChunkWithoutUsage holds the trace IDs of terminal chunks that arrived
// without usage metadata; see FilterSSEUsageMetadata.
var stopChunkWithoutUsage sync.Map

// rememberStopWithoutUsage records traceID in stopChunkWithoutUsage and
// schedules its removal after stopChunkWithoutUsageTTL.
func rememberStopWithoutUsage(traceID string) {
	stopChunkWithoutUsage.Store(traceID, struct{}{})
	time.AfterFunc(stopChunkWithoutUsageTTL, func() {
		stopChunkWithoutUsage.Delete(traceID)
	})
}

// FilterSSEUsageMetadata hides usageMetadata on non-terminal SSE events by
// renaming it to cpaUsageMetadata (see StripUsageMetadataFromJSON). A chunk is
// terminal when it carries a non-blank finishReason; terminal chunks are left
// untouched. This function is shared between aistudio and antigravity
// executors.
//
// Two cases pass a "data:" line through unchanged in addition to terminal
// chunks that already carry usage:
//   - a terminal chunk with a traceId but no usage metadata; its traceId is
//     remembered;
//   - the next chunk with a remembered traceId that does carry usage metadata;
//     the traceId is then forgotten.
//
// When the payload contains no "data:" line at all it is treated as a single
// raw JSON document. The input slice is returned as-is when nothing changed.
func FilterSSEUsageMetadata(payload []byte) []byte {
	if len(payload) == 0 {
		return payload
	}

	dataPrefix := []byte("data:")
	lines := bytes.Split(payload, []byte("\n"))
	modified, foundData := false, false

	for idx, line := range lines {
		if !bytes.HasPrefix(bytes.TrimSpace(line), dataPrefix) {
			continue
		}
		foundData = true
		dataIdx := bytes.Index(line, dataPrefix)
		if dataIdx < 0 {
			continue
		}
		rawJSON := bytes.TrimSpace(line[dataIdx+len(dataPrefix):])
		traceID := gjson.GetBytes(rawJSON, "traceId").String()

		if traceID != "" {
			if isStopChunkWithoutUsage(rawJSON) {
				rememberStopWithoutUsage(traceID)
				continue
			}
			if _, pending := stopChunkWithoutUsage.Load(traceID); pending && hasUsageMetadata(rawJSON) {
				stopChunkWithoutUsage.Delete(traceID)
				continue
			}
		}

		cleaned, changed := StripUsageMetadataFromJSON(rawJSON)
		if !changed {
			continue
		}

		// Preserve whatever preceded "data:" on the original line.
		rebuilt := make([]byte, 0, dataIdx+len(dataPrefix)+1+len(cleaned))
		rebuilt = append(rebuilt, line[:dataIdx]...)
		rebuilt = append(rebuilt, dataPrefix...)
		if len(cleaned) > 0 {
			rebuilt = append(rebuilt, ' ')
			rebuilt = append(rebuilt, cleaned...)
		}
		lines[idx] = rebuilt
		modified = true
	}

	if modified {
		return bytes.Join(lines, []byte("\n"))
	}
	if foundData {
		return payload
	}

	// Handle payloads that are raw JSON without SSE data: prefix.
	cleaned, changed := StripUsageMetadataFromJSON(bytes.TrimSpace(payload))
	if !changed {
		return payload
	}
	return cleaned
}

// hasFinishReason reports whether jsonBytes carries a non-blank finishReason
// at "candidates.0.finishReason" (aistudio) or, when that key is absent, at
// "response.candidates.0.finishReason" (antigravity). Callers are expected to
// have validated jsonBytes.
func hasFinishReason(jsonBytes []byte) bool {
	finishReason := gjson.GetBytes(jsonBytes, "candidates.0.finishReason")
	if !finishReason.Exists() {
		finishReason = gjson.GetBytes(jsonBytes, "response.candidates.0.finishReason")
	}
	return finishReason.Exists() && strings.TrimSpace(finishReason.String()) != ""
}

// renameJSONKey moves the value at path from to path to. It returns the
// edited document and true on success. When from does not exist, or either
// sjson edit fails, it returns doc unchanged and false.
func renameJSONKey(doc []byte, from, to string) ([]byte, bool) {
	node := gjson.GetBytes(doc, from)
	if !node.Exists() {
		return doc, false
	}
	withCopy, err := sjson.SetRawBytes(doc, to, []byte(node.Raw))
	if err != nil {
		return doc, false
	}
	renamed, err := sjson.DeleteBytes(withCopy, from)
	if err != nil {
		return doc, false
	}
	return renamed, true
}

// StripUsageMetadataFromJSON hides usageMetadata on non-terminal chunks by
// renaming it to cpaUsageMetadata. A chunk with a non-blank finishReason is
// terminal and is returned untouched.
// It handles both formats:
//   - Aistudio: candidates.0.finishReason, usageMetadata
//   - Antigravity: response.candidates.0.finishReason, response.usageMetadata
//
// The boolean reports whether the document was changed. When it is false the
// original rawJSON is returned, including when rawJSON is empty, invalid, or
// an edit could not be applied. When it is true the result is the trimmed,
// edited document.
func StripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
	jsonBytes := bytes.TrimSpace(rawJSON)
	if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
		return rawJSON, false
	}

	// Terminal chunk: keep as-is.
	if hasFinishReason(jsonBytes) {
		return rawJSON, false
	}

	// Rename usageMetadata to cpaUsageMetadata at both possible locations.
	cleaned, changed := jsonBytes, false
	for _, rename := range [][2]string{
		{"usageMetadata", "cpaUsageMetadata"},
		{"response.usageMetadata", "response.cpaUsageMetadata"},
	} {
		if next, ok := renameJSONKey(cleaned, rename[0], rename[1]); ok {
			cleaned, changed = next, true
		}
	}
	if !changed {
		return rawJSON, false
	}
	return cleaned, true
}

// hasUsageMetadata reports whether jsonBytes is valid JSON containing
// "usageMetadata" or "response.usageMetadata".
func hasUsageMetadata(jsonBytes []byte) bool {
	if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
		return false
	}
	return gjson.GetBytes(jsonBytes, "usageMetadata").Exists() ||
		gjson.GetBytes(jsonBytes, "response.usageMetadata").Exists()
}

// isStopChunkWithoutUsage reports whether jsonBytes is a valid terminal chunk
// (non-blank finishReason) that carries no usage metadata.
func isStopChunkWithoutUsage(jsonBytes []byte) bool {
	if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
		return false
	}
	return hasFinishReason(jsonBytes) && !hasUsageMetadata(jsonBytes)
}

// jsonPayload extracts the JSON object from one stream line. It strips an
// optional "data:" prefix and surrounding whitespace, and returns nil for
// blank lines, "[DONE]", "event:" lines, and anything that does not start
// with '{'. The result aliases line; it is not copied.
func jsonPayload(line []byte) []byte {
	trimmed := bytes.TrimSpace(line)
	if len(trimmed) == 0 {
		return nil
	}
	if string(trimmed) == "[DONE]" {
		return nil
	}
	if bytes.HasPrefix(trimmed, []byte("event:")) {
		return nil
	}
	if bytes.HasPrefix(trimmed, []byte("data:")) {
		trimmed = bytes.TrimSpace(trimmed[len("data:"):])
	}
	if len(trimmed) == 0 || trimmed[0] != '{' {
		return nil
	}
	return trimmed
}