|
|
package executor |
|
|
|
|
|
import ( |
|
|
"bufio" |
|
|
"bytes" |
|
|
"context" |
|
|
"encoding/base64" |
|
|
"encoding/binary" |
|
|
"encoding/json" |
|
|
"fmt" |
|
|
"io" |
|
|
"net/http" |
|
|
"os" |
|
|
"path/filepath" |
|
|
"strings" |
|
|
"sync" |
|
|
"time" |
|
|
|
|
|
"github.com/google/uuid" |
|
|
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" |
|
|
kiroclaude "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude" |
|
|
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common" |
|
|
kiroopenai "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" |
|
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" |
|
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" |
|
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" |
|
|
log "github.com/sirupsen/logrus" |
|
|
|
|
|
) |
|
|
|
|
|
// HTTP header values sent on every upstream Kiro request.
const (
	// kiroContentType is the Content-Type header value of upstream requests.
	kiroContentType = "application/x-amz-json-1.0"
	// kiroAcceptStream is the Accept header value of upstream requests.
	kiroAcceptStream = "*/*"
)

// Event-stream size bounds.
// NOTE(review): their consumers are not visible in this part of the file;
// presumably they bound the frames read by the event-stream parser — confirm.
const (
	// minEventStreamFrameSize is 16 (bytes, presumably).
	minEventStreamFrameSize = 16
	// maxEventStreamMsgSize is 10 << 20, i.e. 10485760.
	maxEventStreamMsgSize = 10 << 20
)

// Stream error categories, expressed as plain strings.
// NOTE(review): usage is not visible here; presumably tags on stream errors.
const (
	ErrStreamFatal     = "fatal"
	ErrStreamMalformed = "malformed"
)

// Client fingerprints used when the auth is not IDC.
const (
	// kiroUserAgent is sent as the User-Agent header.
	kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0"
	// kiroFullUserAgent is sent as the X-Amz-User-Agent header.
	kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"
)

// Client fingerprints used when the auth method is "idc".
const (
	// kiroIDEUserAgent is sent as the User-Agent header.
	kiroIDEUserAgent = "aws-sdk-js/1.0.18 ua/2.1 os/darwin#25.0.0 lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"
	// kiroIDEAmzUserAgent is sent as the X-Amz-User-Agent header.
	kiroIDEAmzUserAgent = "aws-sdk-js/1.0.18 KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"
	// kiroIDEAgentModeSpec is sent as the x-amzn-kiro-agent-mode header.
	kiroIDEAgentModeSpec = "spec"
)
|
|
|
|
|
|
|
|
|
|
|
// usageUpdateCharThreshold is a character-count threshold (5000).
// NOTE(review): its consumer is not visible in this part of the file;
// presumably it gates periodic usage updates while streaming — confirm.
var usageUpdateCharThreshold = 5000

// usageUpdateTimeInterval is a 15-second interval.
// NOTE(review): presumably the time-based counterpart of
// usageUpdateCharThreshold — confirm against the stream handler.
var usageUpdateTimeInterval = 15 * time.Second
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// kiroEndpointConfig describes one upstream endpoint that a Kiro request can
// be sent to.
type kiroEndpointConfig struct {
	// URL is the address the request is POSTed to. Origin is the origin
	// value handed to the payload builder for this endpoint.
	URL, Origin string
	// AmzTarget is sent as the X-Amz-Target header. Name is the label used
	// in log messages and when matching a preferred endpoint.
	AmzTarget, Name string
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
var kiroEndpointConfigs = []kiroEndpointConfig{ |
|
|
{ |
|
|
URL: "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse", |
|
|
Origin: "AI_EDITOR", |
|
|
AmzTarget: "AmazonCodeWhispererStreamingService.GenerateAssistantResponse", |
|
|
Name: "CodeWhisperer", |
|
|
}, |
|
|
{ |
|
|
URL: "https://q.us-east-1.amazonaws.com/", |
|
|
Origin: "CLI", |
|
|
AmzTarget: "AmazonQDeveloperStreamingService.SendMessage", |
|
|
Name: "AmazonQ", |
|
|
}, |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { |
|
|
if auth == nil { |
|
|
return kiroEndpointConfigs |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if auth.Metadata != nil { |
|
|
authMethod, _ := auth.Metadata["auth_method"].(string) |
|
|
if authMethod == "idc" { |
|
|
log.Debugf("kiro: IDC auth, using CodeWhisperer endpoint") |
|
|
return kiroEndpointConfigs |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
var preference string |
|
|
if auth.Metadata != nil { |
|
|
if p, ok := auth.Metadata["preferred_endpoint"].(string); ok { |
|
|
preference = p |
|
|
} |
|
|
} |
|
|
|
|
|
if preference == "" && auth.Attributes != nil { |
|
|
preference = auth.Attributes["preferred_endpoint"] |
|
|
} |
|
|
|
|
|
if preference == "" { |
|
|
return kiroEndpointConfigs |
|
|
} |
|
|
|
|
|
preference = strings.ToLower(strings.TrimSpace(preference)) |
|
|
|
|
|
|
|
|
var sorted []kiroEndpointConfig |
|
|
var remaining []kiroEndpointConfig |
|
|
|
|
|
for _, cfg := range kiroEndpointConfigs { |
|
|
name := strings.ToLower(cfg.Name) |
|
|
|
|
|
|
|
|
|
|
|
isMatch := false |
|
|
if (preference == "codewhisperer" || preference == "ide") && name == "codewhisperer" { |
|
|
isMatch = true |
|
|
} else if (preference == "amazonq" || preference == "q" || preference == "cli") && name == "amazonq" { |
|
|
isMatch = true |
|
|
} |
|
|
|
|
|
if isMatch { |
|
|
sorted = append(sorted, cfg) |
|
|
} else { |
|
|
remaining = append(remaining, cfg) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if len(sorted) == 0 { |
|
|
return kiroEndpointConfigs |
|
|
} |
|
|
|
|
|
|
|
|
return append(sorted, remaining...) |
|
|
} |
|
|
|
|
|
|
|
|
type KiroExecutor struct { |
|
|
cfg *config.Config |
|
|
refreshMu sync.Mutex |
|
|
} |
|
|
|
|
|
|
|
|
func isIDCAuth(auth *cliproxyauth.Auth) bool { |
|
|
if auth == nil || auth.Metadata == nil { |
|
|
return false |
|
|
} |
|
|
authMethod, _ := auth.Metadata["auth_method"].(string) |
|
|
return authMethod == "idc" |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, sourceFormat sdktranslator.Format, headers http.Header) ([]byte, bool) { |
|
|
switch sourceFormat.String() { |
|
|
case "openai": |
|
|
log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String()) |
|
|
return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil) |
|
|
default: |
|
|
|
|
|
log.Debugf("kiro: using Claude payload builder for source format: %s", sourceFormat.String()) |
|
|
return kiroclaude.BuildKiroPayload(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func NewKiroExecutor(cfg *config.Config) *KiroExecutor { |
|
|
return &KiroExecutor{cfg: cfg} |
|
|
} |
|
|
|
|
|
|
|
|
func (e *KiroExecutor) Identifier() string { return "kiro" } |
|
|
|
|
|
|
|
|
func (e *KiroExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil } |
|
|
|
|
|
|
|
|
|
|
|
func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { |
|
|
accessToken, profileArn := kiroCredentials(auth) |
|
|
if accessToken == "" { |
|
|
return resp, fmt.Errorf("kiro: access token not found in auth") |
|
|
} |
|
|
|
|
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) |
|
|
defer reporter.trackFailure(ctx, &err) |
|
|
|
|
|
|
|
|
if e.isTokenExpired(accessToken) { |
|
|
log.Infof("kiro: access token expired, attempting refresh before request") |
|
|
refreshedAuth, refreshErr := e.Refresh(ctx, auth) |
|
|
if refreshErr != nil { |
|
|
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) |
|
|
} else if refreshedAuth != nil { |
|
|
auth = refreshedAuth |
|
|
|
|
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { |
|
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) |
|
|
} |
|
|
accessToken, profileArn = kiroCredentials(auth) |
|
|
log.Infof("kiro: token refreshed successfully before request") |
|
|
} |
|
|
} |
|
|
|
|
|
from := opts.SourceFormat |
|
|
to := sdktranslator.FromString("kiro") |
|
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) |
|
|
|
|
|
kiroModelID := e.mapModelToKiro(req.Model) |
|
|
|
|
|
|
|
|
isAgentic, isChatOnly := determineAgenticMode(req.Model) |
|
|
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) |
|
|
|
|
|
|
|
|
|
|
|
resp, err = e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly) |
|
|
return resp, err |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from, to sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (cliproxyexecutor.Response, error) { |
|
|
var resp cliproxyexecutor.Response |
|
|
maxRetries := 2 |
|
|
endpointConfigs := getKiroEndpointConfigs(auth) |
|
|
var last429Err error |
|
|
|
|
|
for endpointIdx := 0; endpointIdx < len(endpointConfigs); endpointIdx++ { |
|
|
endpointConfig := endpointConfigs[endpointIdx] |
|
|
url := endpointConfig.URL |
|
|
|
|
|
currentOrigin = endpointConfig.Origin |
|
|
|
|
|
|
|
|
|
|
|
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) |
|
|
|
|
|
log.Debugf("kiro: trying endpoint %d/%d: %s (Name: %s, Origin: %s)", |
|
|
endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin) |
|
|
|
|
|
for attempt := 0; attempt <= maxRetries; attempt++ { |
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) |
|
|
if err != nil { |
|
|
return resp, err |
|
|
} |
|
|
|
|
|
httpReq.Header.Set("Content-Type", kiroContentType) |
|
|
httpReq.Header.Set("Accept", kiroAcceptStream) |
|
|
|
|
|
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isIDCAuth(auth) { |
|
|
httpReq.Header.Set("User-Agent", kiroIDEUserAgent) |
|
|
httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent) |
|
|
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec) |
|
|
log.Debugf("kiro: using Kiro IDE headers for IDC auth") |
|
|
} else { |
|
|
httpReq.Header.Set("User-Agent", kiroUserAgent) |
|
|
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) |
|
|
} |
|
|
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") |
|
|
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) |
|
|
|
|
|
|
|
|
httpReq.Header.Set("Authorization", "Bearer "+accessToken) |
|
|
|
|
|
var attrs map[string]string |
|
|
if auth != nil { |
|
|
attrs = auth.Attributes |
|
|
} |
|
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs) |
|
|
|
|
|
var authID, authLabel, authType, authValue string |
|
|
if auth != nil { |
|
|
authID = auth.ID |
|
|
authLabel = auth.Label |
|
|
authType, authValue = auth.AccountInfo() |
|
|
} |
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ |
|
|
URL: url, |
|
|
Method: http.MethodPost, |
|
|
Headers: httpReq.Header.Clone(), |
|
|
Body: kiroPayload, |
|
|
Provider: e.Identifier(), |
|
|
AuthID: authID, |
|
|
AuthLabel: authLabel, |
|
|
AuthType: authType, |
|
|
AuthValue: authValue, |
|
|
}) |
|
|
|
|
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 120*time.Second) |
|
|
httpResp, err := httpClient.Do(httpReq) |
|
|
if err != nil { |
|
|
recordAPIResponseError(ctx, e.cfg, err) |
|
|
return resp, err |
|
|
} |
|
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) |
|
|
|
|
|
|
|
|
|
|
|
if httpResp.StatusCode == 429 { |
|
|
respBody, _ := io.ReadAll(httpResp.Body) |
|
|
_ = httpResp.Body.Close() |
|
|
appendAPIResponseChunk(ctx, e.cfg, respBody) |
|
|
|
|
|
|
|
|
last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)} |
|
|
|
|
|
log.Warnf("kiro: %s endpoint quota exhausted (429), will try next endpoint, body: %s", |
|
|
endpointConfig.Name, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) |
|
|
|
|
|
|
|
|
break |
|
|
} |
|
|
|
|
|
|
|
|
if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 { |
|
|
respBody, _ := io.ReadAll(httpResp.Body) |
|
|
_ = httpResp.Body.Close() |
|
|
appendAPIResponseChunk(ctx, e.cfg, respBody) |
|
|
|
|
|
if attempt < maxRetries { |
|
|
|
|
|
backoff := time.Duration(1<<attempt) * time.Second |
|
|
if backoff > 30*time.Second { |
|
|
backoff = 30 * time.Second |
|
|
} |
|
|
log.Warnf("kiro: server error %d, retrying in %v (attempt %d/%d)", httpResp.StatusCode, backoff, attempt+1, maxRetries) |
|
|
time.Sleep(backoff) |
|
|
continue |
|
|
} |
|
|
log.Errorf("kiro: server error %d after %d retries", httpResp.StatusCode, maxRetries) |
|
|
return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if httpResp.StatusCode == 401 { |
|
|
respBody, _ := io.ReadAll(httpResp.Body) |
|
|
_ = httpResp.Body.Close() |
|
|
appendAPIResponseChunk(ctx, e.cfg, respBody) |
|
|
|
|
|
if attempt < maxRetries { |
|
|
log.Warnf("kiro: received 401 error, attempting token refresh and retry (attempt %d/%d)", attempt+1, maxRetries+1) |
|
|
|
|
|
refreshedAuth, refreshErr := e.Refresh(ctx, auth) |
|
|
if refreshErr != nil { |
|
|
log.Errorf("kiro: token refresh failed: %v", refreshErr) |
|
|
return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} |
|
|
} |
|
|
|
|
|
if refreshedAuth != nil { |
|
|
auth = refreshedAuth |
|
|
|
|
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { |
|
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) |
|
|
|
|
|
} |
|
|
accessToken, profileArn = kiroCredentials(auth) |
|
|
|
|
|
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) |
|
|
log.Infof("kiro: token refreshed successfully, retrying request") |
|
|
continue |
|
|
} |
|
|
} |
|
|
|
|
|
log.Warnf("kiro request error, status: 401, body: %s", summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) |
|
|
return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} |
|
|
} |
|
|
|
|
|
|
|
|
if httpResp.StatusCode == 402 { |
|
|
respBody, _ := io.ReadAll(httpResp.Body) |
|
|
_ = httpResp.Body.Close() |
|
|
appendAPIResponseChunk(ctx, e.cfg, respBody) |
|
|
|
|
|
log.Warnf("kiro: received 402 (monthly limit). Upstream body: %s", string(respBody)) |
|
|
|
|
|
|
|
|
return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if httpResp.StatusCode == 403 { |
|
|
respBody, _ := io.ReadAll(httpResp.Body) |
|
|
_ = httpResp.Body.Close() |
|
|
appendAPIResponseChunk(ctx, e.cfg, respBody) |
|
|
|
|
|
|
|
|
log.Warnf("kiro: received 403 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) |
|
|
|
|
|
respBodyStr := string(respBody) |
|
|
|
|
|
|
|
|
if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") { |
|
|
log.Errorf("kiro: account is suspended, cannot proceed") |
|
|
return resp, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)} |
|
|
} |
|
|
|
|
|
|
|
|
isTokenRelated := strings.Contains(respBodyStr, "token") || |
|
|
strings.Contains(respBodyStr, "expired") || |
|
|
strings.Contains(respBodyStr, "invalid") || |
|
|
strings.Contains(respBodyStr, "unauthorized") |
|
|
|
|
|
if isTokenRelated && attempt < maxRetries { |
|
|
log.Warnf("kiro: 403 appears token-related, attempting token refresh") |
|
|
refreshedAuth, refreshErr := e.Refresh(ctx, auth) |
|
|
if refreshErr != nil { |
|
|
log.Errorf("kiro: token refresh failed: %v", refreshErr) |
|
|
|
|
|
return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} |
|
|
} |
|
|
if refreshedAuth != nil { |
|
|
auth = refreshedAuth |
|
|
|
|
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { |
|
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) |
|
|
|
|
|
} |
|
|
accessToken, profileArn = kiroCredentials(auth) |
|
|
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) |
|
|
log.Infof("kiro: token refreshed for 403, retrying request") |
|
|
continue |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
log.Warnf("kiro: 403 error, returning immediately (no endpoint switch)") |
|
|
return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} |
|
|
} |
|
|
|
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { |
|
|
b, _ := io.ReadAll(httpResp.Body) |
|
|
appendAPIResponseChunk(ctx, e.cfg, b) |
|
|
log.Debugf("kiro request error, status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) |
|
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)} |
|
|
if errClose := httpResp.Body.Close(); errClose != nil { |
|
|
log.Errorf("response body close error: %v", errClose) |
|
|
} |
|
|
return resp, err |
|
|
} |
|
|
|
|
|
defer func() { |
|
|
if errClose := httpResp.Body.Close(); errClose != nil { |
|
|
log.Errorf("response body close error: %v", errClose) |
|
|
} |
|
|
}() |
|
|
|
|
|
content, toolUses, usageInfo, stopReason, err := e.parseEventStream(httpResp.Body) |
|
|
if err != nil { |
|
|
recordAPIResponseError(ctx, e.cfg, err) |
|
|
return resp, err |
|
|
} |
|
|
|
|
|
|
|
|
if usageInfo.TotalTokens == 0 { |
|
|
if enc, encErr := getTokenizer(req.Model); encErr == nil { |
|
|
if inp, countErr := countOpenAIChatTokens(enc, opts.OriginalRequest); countErr == nil { |
|
|
usageInfo.InputTokens = inp |
|
|
} |
|
|
} |
|
|
if len(content) > 0 { |
|
|
|
|
|
if enc, encErr := getTokenizer(req.Model); encErr == nil { |
|
|
if tokenCount, countErr := enc.Count(content); countErr == nil { |
|
|
usageInfo.OutputTokens = int64(tokenCount) |
|
|
} |
|
|
} |
|
|
|
|
|
if usageInfo.OutputTokens == 0 { |
|
|
usageInfo.OutputTokens = int64(len(content) / 4) |
|
|
if usageInfo.OutputTokens == 0 { |
|
|
usageInfo.OutputTokens = 1 |
|
|
} |
|
|
} |
|
|
} |
|
|
usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens |
|
|
} |
|
|
|
|
|
appendAPIResponseChunk(ctx, e.cfg, []byte(content)) |
|
|
reporter.publish(ctx, usageInfo) |
|
|
|
|
|
|
|
|
|
|
|
kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason) |
|
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil) |
|
|
resp = cliproxyexecutor.Response{Payload: []byte(out)} |
|
|
return resp, nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
if last429Err != nil { |
|
|
return resp, last429Err |
|
|
} |
|
|
return resp, fmt.Errorf("kiro: all endpoints exhausted") |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { |
|
|
accessToken, profileArn := kiroCredentials(auth) |
|
|
if accessToken == "" { |
|
|
return nil, fmt.Errorf("kiro: access token not found in auth") |
|
|
} |
|
|
|
|
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) |
|
|
defer reporter.trackFailure(ctx, &err) |
|
|
|
|
|
|
|
|
if e.isTokenExpired(accessToken) { |
|
|
log.Infof("kiro: access token expired, attempting refresh before stream request") |
|
|
refreshedAuth, refreshErr := e.Refresh(ctx, auth) |
|
|
if refreshErr != nil { |
|
|
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) |
|
|
} else if refreshedAuth != nil { |
|
|
auth = refreshedAuth |
|
|
|
|
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { |
|
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) |
|
|
} |
|
|
accessToken, profileArn = kiroCredentials(auth) |
|
|
log.Infof("kiro: token refreshed successfully before stream request") |
|
|
} |
|
|
} |
|
|
|
|
|
from := opts.SourceFormat |
|
|
to := sdktranslator.FromString("kiro") |
|
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) |
|
|
|
|
|
kiroModelID := e.mapModelToKiro(req.Model) |
|
|
|
|
|
|
|
|
isAgentic, isChatOnly := determineAgenticMode(req.Model) |
|
|
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) |
|
|
|
|
|
|
|
|
|
|
|
return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly) |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (<-chan cliproxyexecutor.StreamChunk, error) { |
|
|
maxRetries := 2 |
|
|
endpointConfigs := getKiroEndpointConfigs(auth) |
|
|
var last429Err error |
|
|
|
|
|
for endpointIdx := 0; endpointIdx < len(endpointConfigs); endpointIdx++ { |
|
|
endpointConfig := endpointConfigs[endpointIdx] |
|
|
url := endpointConfig.URL |
|
|
|
|
|
currentOrigin = endpointConfig.Origin |
|
|
|
|
|
|
|
|
|
|
|
kiroPayload, thinkingEnabled := buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) |
|
|
|
|
|
log.Debugf("kiro: stream trying endpoint %d/%d: %s (Name: %s, Origin: %s)", |
|
|
endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin) |
|
|
|
|
|
for attempt := 0; attempt <= maxRetries; attempt++ { |
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
|
|
|
httpReq.Header.Set("Content-Type", kiroContentType) |
|
|
httpReq.Header.Set("Accept", kiroAcceptStream) |
|
|
|
|
|
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isIDCAuth(auth) { |
|
|
httpReq.Header.Set("User-Agent", kiroIDEUserAgent) |
|
|
httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent) |
|
|
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec) |
|
|
log.Debugf("kiro: using Kiro IDE headers for IDC auth") |
|
|
} else { |
|
|
httpReq.Header.Set("User-Agent", kiroUserAgent) |
|
|
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) |
|
|
} |
|
|
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") |
|
|
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) |
|
|
|
|
|
|
|
|
httpReq.Header.Set("Authorization", "Bearer "+accessToken) |
|
|
|
|
|
var attrs map[string]string |
|
|
if auth != nil { |
|
|
attrs = auth.Attributes |
|
|
} |
|
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs) |
|
|
|
|
|
var authID, authLabel, authType, authValue string |
|
|
if auth != nil { |
|
|
authID = auth.ID |
|
|
authLabel = auth.Label |
|
|
authType, authValue = auth.AccountInfo() |
|
|
} |
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ |
|
|
URL: url, |
|
|
Method: http.MethodPost, |
|
|
Headers: httpReq.Header.Clone(), |
|
|
Body: kiroPayload, |
|
|
Provider: e.Identifier(), |
|
|
AuthID: authID, |
|
|
AuthLabel: authLabel, |
|
|
AuthType: authType, |
|
|
AuthValue: authValue, |
|
|
}) |
|
|
|
|
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) |
|
|
httpResp, err := httpClient.Do(httpReq) |
|
|
if err != nil { |
|
|
recordAPIResponseError(ctx, e.cfg, err) |
|
|
return nil, err |
|
|
} |
|
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) |
|
|
|
|
|
|
|
|
|
|
|
if httpResp.StatusCode == 429 { |
|
|
respBody, _ := io.ReadAll(httpResp.Body) |
|
|
_ = httpResp.Body.Close() |
|
|
appendAPIResponseChunk(ctx, e.cfg, respBody) |
|
|
|
|
|
|
|
|
last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)} |
|
|
|
|
|
log.Warnf("kiro: stream %s endpoint quota exhausted (429), will try next endpoint, body: %s", |
|
|
endpointConfig.Name, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) |
|
|
|
|
|
|
|
|
break |
|
|
} |
|
|
|
|
|
|
|
|
if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 { |
|
|
respBody, _ := io.ReadAll(httpResp.Body) |
|
|
_ = httpResp.Body.Close() |
|
|
appendAPIResponseChunk(ctx, e.cfg, respBody) |
|
|
|
|
|
if attempt < maxRetries { |
|
|
|
|
|
backoff := time.Duration(1<<attempt) * time.Second |
|
|
if backoff > 30*time.Second { |
|
|
backoff = 30 * time.Second |
|
|
} |
|
|
log.Warnf("kiro: stream server error %d, retrying in %v (attempt %d/%d)", httpResp.StatusCode, backoff, attempt+1, maxRetries) |
|
|
time.Sleep(backoff) |
|
|
continue |
|
|
} |
|
|
log.Errorf("kiro: stream server error %d after %d retries", httpResp.StatusCode, maxRetries) |
|
|
return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if httpResp.StatusCode == 400 { |
|
|
respBody, _ := io.ReadAll(httpResp.Body) |
|
|
_ = httpResp.Body.Close() |
|
|
appendAPIResponseChunk(ctx, e.cfg, respBody) |
|
|
|
|
|
log.Warnf("kiro: received 400 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) |
|
|
|
|
|
|
|
|
return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if httpResp.StatusCode == 401 { |
|
|
respBody, _ := io.ReadAll(httpResp.Body) |
|
|
_ = httpResp.Body.Close() |
|
|
appendAPIResponseChunk(ctx, e.cfg, respBody) |
|
|
|
|
|
if attempt < maxRetries { |
|
|
log.Warnf("kiro: stream received 401 error, attempting token refresh and retry (attempt %d/%d)", attempt+1, maxRetries+1) |
|
|
|
|
|
refreshedAuth, refreshErr := e.Refresh(ctx, auth) |
|
|
if refreshErr != nil { |
|
|
log.Errorf("kiro: token refresh failed: %v", refreshErr) |
|
|
return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} |
|
|
} |
|
|
|
|
|
if refreshedAuth != nil { |
|
|
auth = refreshedAuth |
|
|
|
|
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { |
|
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) |
|
|
|
|
|
} |
|
|
accessToken, profileArn = kiroCredentials(auth) |
|
|
|
|
|
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) |
|
|
log.Infof("kiro: token refreshed successfully, retrying stream request") |
|
|
continue |
|
|
} |
|
|
} |
|
|
|
|
|
log.Warnf("kiro stream error, status: 401, body: %s", string(respBody)) |
|
|
return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} |
|
|
} |
|
|
|
|
|
|
|
|
if httpResp.StatusCode == 402 { |
|
|
respBody, _ := io.ReadAll(httpResp.Body) |
|
|
_ = httpResp.Body.Close() |
|
|
appendAPIResponseChunk(ctx, e.cfg, respBody) |
|
|
|
|
|
log.Warnf("kiro: stream received 402 (monthly limit). Upstream body: %s", string(respBody)) |
|
|
|
|
|
|
|
|
return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if httpResp.StatusCode == 403 { |
|
|
respBody, _ := io.ReadAll(httpResp.Body) |
|
|
_ = httpResp.Body.Close() |
|
|
appendAPIResponseChunk(ctx, e.cfg, respBody) |
|
|
|
|
|
|
|
|
log.Warnf("kiro: stream received 403 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, string(respBody)) |
|
|
|
|
|
respBodyStr := string(respBody) |
|
|
|
|
|
|
|
|
if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") { |
|
|
log.Errorf("kiro: account is suspended, cannot proceed") |
|
|
return nil, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)} |
|
|
} |
|
|
|
|
|
|
|
|
isTokenRelated := strings.Contains(respBodyStr, "token") || |
|
|
strings.Contains(respBodyStr, "expired") || |
|
|
strings.Contains(respBodyStr, "invalid") || |
|
|
strings.Contains(respBodyStr, "unauthorized") |
|
|
|
|
|
if isTokenRelated && attempt < maxRetries { |
|
|
log.Warnf("kiro: 403 appears token-related, attempting token refresh") |
|
|
refreshedAuth, refreshErr := e.Refresh(ctx, auth) |
|
|
if refreshErr != nil { |
|
|
log.Errorf("kiro: token refresh failed: %v", refreshErr) |
|
|
|
|
|
return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} |
|
|
} |
|
|
if refreshedAuth != nil { |
|
|
auth = refreshedAuth |
|
|
|
|
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { |
|
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) |
|
|
|
|
|
} |
|
|
accessToken, profileArn = kiroCredentials(auth) |
|
|
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) |
|
|
log.Infof("kiro: token refreshed for 403, retrying stream request") |
|
|
continue |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
log.Warnf("kiro: 403 error, returning immediately (no endpoint switch)") |
|
|
return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} |
|
|
} |
|
|
|
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { |
|
|
b, _ := io.ReadAll(httpResp.Body) |
|
|
appendAPIResponseChunk(ctx, e.cfg, b) |
|
|
log.Debugf("kiro stream error, status: %d, body: %s", httpResp.StatusCode, string(b)) |
|
|
if errClose := httpResp.Body.Close(); errClose != nil { |
|
|
log.Errorf("response body close error: %v", errClose) |
|
|
} |
|
|
return nil, statusErr{code: httpResp.StatusCode, msg: string(b)} |
|
|
} |
|
|
|
|
|
out := make(chan cliproxyexecutor.StreamChunk) |
|
|
|
|
|
go func(resp *http.Response, thinkingEnabled bool) { |
|
|
defer close(out) |
|
|
defer func() { |
|
|
if r := recover(); r != nil { |
|
|
log.Errorf("kiro: panic in stream handler: %v", r) |
|
|
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("internal error: %v", r)} |
|
|
} |
|
|
}() |
|
|
defer func() { |
|
|
if errClose := resp.Body.Close(); errClose != nil { |
|
|
log.Errorf("response body close error: %v", errClose) |
|
|
} |
|
|
}() |
|
|
|
|
|
|
|
|
|
|
|
log.Debugf("kiro: stream thinkingEnabled = %v (always true for Kiro)", thinkingEnabled) |
|
|
|
|
|
e.streamToChannel(ctx, resp.Body, out, from, req.Model, opts.OriginalRequest, body, reporter, thinkingEnabled) |
|
|
}(httpResp, thinkingEnabled) |
|
|
|
|
|
return out, nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
if last429Err != nil { |
|
|
return nil, last429Err |
|
|
} |
|
|
return nil, fmt.Errorf("kiro: stream all endpoints exhausted") |
|
|
} |
|
|
|
|
|
|
|
|
func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) { |
|
|
if auth == nil { |
|
|
return "", "" |
|
|
} |
|
|
|
|
|
|
|
|
if auth.Metadata != nil { |
|
|
if token, ok := auth.Metadata["access_token"].(string); ok { |
|
|
accessToken = token |
|
|
} |
|
|
if arn, ok := auth.Metadata["profile_arn"].(string); ok { |
|
|
profileArn = arn |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if accessToken == "" && auth.Attributes != nil { |
|
|
accessToken = auth.Attributes["access_token"] |
|
|
profileArn = auth.Attributes["profile_arn"] |
|
|
} |
|
|
|
|
|
|
|
|
if accessToken == "" && auth.Metadata != nil { |
|
|
if token, ok := auth.Metadata["accessToken"].(string); ok { |
|
|
accessToken = token |
|
|
} |
|
|
if arn, ok := auth.Metadata["profileArn"].(string); ok { |
|
|
profileArn = arn |
|
|
} |
|
|
} |
|
|
|
|
|
return accessToken, profileArn |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func findRealThinkingEndTag(content string, alreadyInCodeBlock, alreadyInInlineCode bool) int { |
|
|
searchStart := 0 |
|
|
for { |
|
|
endIdx := strings.Index(content[searchStart:], kirocommon.ThinkingEndTag) |
|
|
if endIdx < 0 { |
|
|
return -1 |
|
|
} |
|
|
endIdx += searchStart |
|
|
|
|
|
textBeforeEnd := content[:endIdx] |
|
|
textAfterEnd := content[endIdx+len(kirocommon.ThinkingEndTag):] |
|
|
|
|
|
|
|
|
|
|
|
backtickCount := strings.Count(textBeforeEnd, "`") |
|
|
effectiveInInlineCode := alreadyInInlineCode |
|
|
if backtickCount%2 == 1 { |
|
|
effectiveInInlineCode = !effectiveInInlineCode |
|
|
} |
|
|
if effectiveInInlineCode { |
|
|
log.Debugf("kiro: found </thinking> inside inline code at pos %d, skipping", endIdx) |
|
|
searchStart = endIdx + len(kirocommon.ThinkingEndTag) |
|
|
continue |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
fenceCount := strings.Count(textBeforeEnd, "```") |
|
|
altFenceCount := strings.Count(textBeforeEnd, "~~~") |
|
|
effectiveInCodeBlock := alreadyInCodeBlock |
|
|
if fenceCount%2 == 1 || altFenceCount%2 == 1 { |
|
|
effectiveInCodeBlock = !effectiveInCodeBlock |
|
|
} |
|
|
if effectiveInCodeBlock { |
|
|
log.Debugf("kiro: found </thinking> inside code block at pos %d, skipping", endIdx) |
|
|
searchStart = endIdx + len(kirocommon.ThinkingEndTag) |
|
|
continue |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
charBeforeTag := byte(0) |
|
|
if endIdx > 0 { |
|
|
charBeforeTag = content[endIdx-1] |
|
|
} |
|
|
charAfterTag := byte(0) |
|
|
if len(textAfterEnd) > 0 { |
|
|
charAfterTag = textAfterEnd[0] |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
isPrecededByNewlineOrSentenceEnd := charBeforeTag == '\n' || charBeforeTag == '.' || |
|
|
charBeforeTag == '!' || charBeforeTag == '?' || charBeforeTag == 0 |
|
|
isFollowedByNewlineOrEnd := charAfterTag == '\n' || charAfterTag == 0 |
|
|
|
|
|
|
|
|
if isPrecededByNewlineOrSentenceEnd && isFollowedByNewlineOrEnd { |
|
|
log.Debugf("kiro: found properly formatted </thinking> at pos %d", endIdx) |
|
|
return endIdx |
|
|
} |
|
|
|
|
|
|
|
|
lastNewlineIdx := strings.LastIndex(textBeforeEnd, "\n") |
|
|
lineBeforeTag := textBeforeEnd |
|
|
if lastNewlineIdx >= 0 { |
|
|
lineBeforeTag = textBeforeEnd[lastNewlineIdx+1:] |
|
|
} |
|
|
lineBeforeTagLower := strings.ToLower(lineBeforeTag) |
|
|
|
|
|
|
|
|
discussionPatterns := []string{ |
|
|
"标签", "返回", "输出", "包含", "使用", "解析", "转换", "生成", |
|
|
"tag", "return", "output", "contain", "use", "parse", "emit", "convert", "generate", |
|
|
"<thinking>", |
|
|
"`</thinking>`", |
|
|
} |
|
|
isDiscussion := false |
|
|
for _, pattern := range discussionPatterns { |
|
|
if strings.Contains(lineBeforeTagLower, pattern) { |
|
|
isDiscussion = true |
|
|
break |
|
|
} |
|
|
} |
|
|
if isDiscussion { |
|
|
log.Debugf("kiro: found </thinking> after discussion text at pos %d, skipping", endIdx) |
|
|
searchStart = endIdx + len(kirocommon.ThinkingEndTag) |
|
|
continue |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if len(textAfterEnd) > 0 && charAfterTag != '\n' && charAfterTag != 0 { |
|
|
|
|
|
nextNewline := strings.Index(textAfterEnd, "\n") |
|
|
var textOnSameLine string |
|
|
if nextNewline >= 0 { |
|
|
textOnSameLine = textAfterEnd[:nextNewline] |
|
|
} else { |
|
|
textOnSameLine = textAfterEnd |
|
|
} |
|
|
|
|
|
if strings.TrimSpace(textOnSameLine) != "" { |
|
|
log.Debugf("kiro: found </thinking> with text after on same line at pos %d, skipping", endIdx) |
|
|
searchStart = endIdx + len(kirocommon.ThinkingEndTag) |
|
|
continue |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if strings.Contains(textAfterEnd, kirocommon.ThinkingStartTag) { |
|
|
nextStartIdx := strings.Index(textAfterEnd, kirocommon.ThinkingStartTag) |
|
|
textBeforeNextStart := textAfterEnd[:nextStartIdx] |
|
|
nextBacktickCount := strings.Count(textBeforeNextStart, "`") |
|
|
nextFenceCount := strings.Count(textBeforeNextStart, "```") |
|
|
nextAltFenceCount := strings.Count(textBeforeNextStart, "~~~") |
|
|
|
|
|
|
|
|
if nextBacktickCount%2 == 0 && nextFenceCount%2 == 0 && nextAltFenceCount%2 == 0 { |
|
|
log.Debugf("kiro: found </thinking> followed by <thinking> at pos %d, likely discussion text, skipping", endIdx) |
|
|
searchStart = endIdx + len(kirocommon.ThinkingEndTag) |
|
|
continue |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
return endIdx |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// determineAgenticMode classifies a model name by its suffix.
//
// isAgentic is true when model ends in "-agentic"; isChatOnly is true when it
// ends in "-chat". The comparison is case-sensitive.
func determineAgenticMode(model string) (isAgentic, isChatOnly bool) {
	return strings.HasSuffix(model, "-agentic"), strings.HasSuffix(model, "-chat")
}
|
|
|
|
|
|
|
|
|
|
|
func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string { |
|
|
if auth != nil && auth.Metadata != nil { |
|
|
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { |
|
|
return "" |
|
|
} |
|
|
} |
|
|
return profileArn |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string { |
|
|
if auth != nil && auth.Metadata != nil { |
|
|
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") { |
|
|
|
|
|
return "" |
|
|
} |
|
|
} |
|
|
|
|
|
if profileArn == "" { |
|
|
log.Warnf("kiro: profile ARN not found in auth, API calls may fail") |
|
|
} |
|
|
return profileArn |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (e *KiroExecutor) mapModelToKiro(model string) string { |
|
|
modelMap := map[string]string{ |
|
|
|
|
|
"amazonq-auto": "auto", |
|
|
"amazonq-claude-opus-4-5": "claude-opus-4.5", |
|
|
"amazonq-claude-sonnet-4-5": "claude-sonnet-4.5", |
|
|
"amazonq-claude-sonnet-4-5-20250929": "claude-sonnet-4.5", |
|
|
"amazonq-claude-sonnet-4": "claude-sonnet-4", |
|
|
"amazonq-claude-sonnet-4-20250514": "claude-sonnet-4", |
|
|
"amazonq-claude-haiku-4-5": "claude-haiku-4.5", |
|
|
|
|
|
"kiro-claude-opus-4-5": "claude-opus-4.5", |
|
|
"kiro-claude-sonnet-4-5": "claude-sonnet-4.5", |
|
|
"kiro-claude-sonnet-4-5-20250929": "claude-sonnet-4.5", |
|
|
"kiro-claude-sonnet-4": "claude-sonnet-4", |
|
|
"kiro-claude-sonnet-4-20250514": "claude-sonnet-4", |
|
|
"kiro-claude-haiku-4-5": "claude-haiku-4.5", |
|
|
"kiro-auto": "auto", |
|
|
|
|
|
"claude-opus-4-5": "claude-opus-4.5", |
|
|
"claude-opus-4.5": "claude-opus-4.5", |
|
|
"claude-haiku-4-5": "claude-haiku-4.5", |
|
|
"claude-haiku-4.5": "claude-haiku-4.5", |
|
|
"claude-sonnet-4-5": "claude-sonnet-4.5", |
|
|
"claude-sonnet-4-5-20250929": "claude-sonnet-4.5", |
|
|
"claude-sonnet-4.5": "claude-sonnet-4.5", |
|
|
"claude-sonnet-4": "claude-sonnet-4", |
|
|
"claude-sonnet-4-20250514": "claude-sonnet-4", |
|
|
"auto": "auto", |
|
|
|
|
|
"claude-opus-4.5-agentic": "claude-opus-4.5", |
|
|
"claude-sonnet-4.5-agentic": "claude-sonnet-4.5", |
|
|
"claude-sonnet-4-agentic": "claude-sonnet-4", |
|
|
"claude-haiku-4.5-agentic": "claude-haiku-4.5", |
|
|
"kiro-claude-opus-4-5-agentic": "claude-opus-4.5", |
|
|
"kiro-claude-sonnet-4-5-agentic": "claude-sonnet-4.5", |
|
|
"kiro-claude-sonnet-4-agentic": "claude-sonnet-4", |
|
|
"kiro-claude-haiku-4-5-agentic": "claude-haiku-4.5", |
|
|
} |
|
|
if kiroID, ok := modelMap[model]; ok { |
|
|
return kiroID |
|
|
} |
|
|
|
|
|
|
|
|
modelLower := strings.ToLower(model) |
|
|
|
|
|
|
|
|
if strings.Contains(modelLower, "haiku") { |
|
|
log.Debugf("kiro: unknown Haiku model '%s', mapping to claude-haiku-4.5", model) |
|
|
return "claude-haiku-4.5" |
|
|
} |
|
|
|
|
|
|
|
|
if strings.Contains(modelLower, "sonnet") { |
|
|
|
|
|
if strings.Contains(modelLower, "3-7") || strings.Contains(modelLower, "3.7") { |
|
|
log.Debugf("kiro: unknown Sonnet 3.7 model '%s', mapping to claude-3-7-sonnet-20250219", model) |
|
|
return "claude-3-7-sonnet-20250219" |
|
|
} |
|
|
if strings.Contains(modelLower, "4-5") || strings.Contains(modelLower, "4.5") { |
|
|
log.Debugf("kiro: unknown Sonnet 4.5 model '%s', mapping to claude-sonnet-4.5", model) |
|
|
return "claude-sonnet-4.5" |
|
|
} |
|
|
|
|
|
log.Debugf("kiro: unknown Sonnet model '%s', mapping to claude-sonnet-4", model) |
|
|
return "claude-sonnet-4" |
|
|
} |
|
|
|
|
|
|
|
|
if strings.Contains(modelLower, "opus") { |
|
|
log.Debugf("kiro: unknown Opus model '%s', mapping to claude-opus-4.5", model) |
|
|
return "claude-opus-4.5" |
|
|
} |
|
|
|
|
|
|
|
|
log.Warnf("kiro: unknown model '%s', falling back to claude-sonnet-4.5", model) |
|
|
return "claude-sonnet-4.5" |
|
|
} |
|
|
|
|
|
|
|
|
// EventStreamError describes a failure while decoding an event-stream frame.
type EventStreamError struct {
	// Type is the error category; this file uses ErrStreamFatal ("fatal") and
	// ErrStreamMalformed ("malformed").
	Type string
	// Message is a human-readable description of what went wrong.
	Message string
	// Cause is the underlying error, or nil when there is none.
	Cause error
}

// Error implements the error interface. The text has the form
// "event stream <Type>: <Message>", with ": <Cause>" appended when Cause is set.
func (e *EventStreamError) Error() string {
	if e.Cause != nil {
		return fmt.Sprintf("event stream %s: %s: %v", e.Type, e.Message, e.Cause)
	}
	return fmt.Sprintf("event stream %s: %s", e.Type, e.Message)
}

// Unwrap returns Cause so that errors.Is and errors.As can inspect the
// underlying error (for example io.ErrUnexpectedEOF on a truncated stream).
func (e *EventStreamError) Unwrap() error {
	return e.Cause
}
|
|
|
|
|
|
|
|
// eventStreamMessage is one decoded event-stream frame, reduced to the parts
// the parsers in this file consume.
type eventStreamMessage struct {
	// EventType is the event name taken from the frame headers; it is empty
	// when no event type could be extracted.
	EventType string

	// Payload holds the raw frame payload bytes; it is nil for a frame that
	// carries no payload.
	Payload []byte
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroclaude.KiroToolUse, usage.Detail, string, error) { |
|
|
var content strings.Builder |
|
|
var toolUses []kiroclaude.KiroToolUse |
|
|
var usageInfo usage.Detail |
|
|
var stopReason string |
|
|
reader := bufio.NewReader(body) |
|
|
|
|
|
|
|
|
processedIDs := make(map[string]bool) |
|
|
var currentToolUse *kiroclaude.ToolUseState |
|
|
|
|
|
|
|
|
var upstreamContextPercentage float64 |
|
|
|
|
|
for { |
|
|
msg, eventErr := e.readEventStreamMessage(reader) |
|
|
if eventErr != nil { |
|
|
log.Errorf("kiro: parseEventStream error: %v", eventErr) |
|
|
return content.String(), toolUses, usageInfo, stopReason, eventErr |
|
|
} |
|
|
if msg == nil { |
|
|
|
|
|
break |
|
|
} |
|
|
|
|
|
eventType := msg.EventType |
|
|
payload := msg.Payload |
|
|
if len(payload) == 0 { |
|
|
continue |
|
|
} |
|
|
|
|
|
var event map[string]interface{} |
|
|
if err := json.Unmarshal(payload, &event); err != nil { |
|
|
log.Debugf("kiro: skipping malformed event: %v", err) |
|
|
continue |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if errType, hasErrType := event["_type"].(string); hasErrType { |
|
|
|
|
|
errMsg := "" |
|
|
if msg, ok := event["message"].(string); ok { |
|
|
errMsg = msg |
|
|
} |
|
|
log.Errorf("kiro: received AWS error in event stream: type=%s, message=%s", errType, errMsg) |
|
|
return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s - %s", errType, errMsg) |
|
|
} |
|
|
if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") { |
|
|
|
|
|
errMsg := "" |
|
|
if msg, ok := event["message"].(string); ok { |
|
|
errMsg = msg |
|
|
} else if errObj, ok := event["error"].(map[string]interface{}); ok { |
|
|
if msg, ok := errObj["message"].(string); ok { |
|
|
errMsg = msg |
|
|
} |
|
|
} |
|
|
log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg) |
|
|
return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s", errMsg) |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { |
|
|
stopReason = sr |
|
|
log.Debugf("kiro: parseEventStream found stop_reason (top-level): %s", stopReason) |
|
|
} |
|
|
if sr := kirocommon.GetString(event, "stopReason"); sr != "" { |
|
|
stopReason = sr |
|
|
log.Debugf("kiro: parseEventStream found stopReason (top-level): %s", stopReason) |
|
|
} |
|
|
|
|
|
|
|
|
switch eventType { |
|
|
case "followupPromptEvent": |
|
|
|
|
|
log.Debugf("kiro: parseEventStream ignoring followupPrompt event") |
|
|
continue |
|
|
|
|
|
case "assistantResponseEvent": |
|
|
if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok { |
|
|
if contentText, ok := assistantResp["content"].(string); ok { |
|
|
content.WriteString(contentText) |
|
|
} |
|
|
|
|
|
if sr := kirocommon.GetString(assistantResp, "stop_reason"); sr != "" { |
|
|
stopReason = sr |
|
|
log.Debugf("kiro: parseEventStream found stop_reason in assistantResponseEvent: %s", stopReason) |
|
|
} |
|
|
if sr := kirocommon.GetString(assistantResp, "stopReason"); sr != "" { |
|
|
stopReason = sr |
|
|
log.Debugf("kiro: parseEventStream found stopReason in assistantResponseEvent: %s", stopReason) |
|
|
} |
|
|
|
|
|
if toolUsesRaw, ok := assistantResp["toolUses"].([]interface{}); ok { |
|
|
for _, tuRaw := range toolUsesRaw { |
|
|
if tu, ok := tuRaw.(map[string]interface{}); ok { |
|
|
toolUseID := kirocommon.GetStringValue(tu, "toolUseId") |
|
|
|
|
|
if processedIDs[toolUseID] { |
|
|
log.Debugf("kiro: skipping duplicate tool use from assistantResponse: %s", toolUseID) |
|
|
continue |
|
|
} |
|
|
processedIDs[toolUseID] = true |
|
|
|
|
|
toolUse := kiroclaude.KiroToolUse{ |
|
|
ToolUseID: toolUseID, |
|
|
Name: kirocommon.GetStringValue(tu, "name"), |
|
|
} |
|
|
if input, ok := tu["input"].(map[string]interface{}); ok { |
|
|
toolUse.Input = input |
|
|
} |
|
|
toolUses = append(toolUses, toolUse) |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
if contentText, ok := event["content"].(string); ok { |
|
|
content.WriteString(contentText) |
|
|
} |
|
|
|
|
|
if toolUsesRaw, ok := event["toolUses"].([]interface{}); ok { |
|
|
for _, tuRaw := range toolUsesRaw { |
|
|
if tu, ok := tuRaw.(map[string]interface{}); ok { |
|
|
toolUseID := kirocommon.GetStringValue(tu, "toolUseId") |
|
|
|
|
|
if processedIDs[toolUseID] { |
|
|
log.Debugf("kiro: skipping duplicate direct tool use: %s", toolUseID) |
|
|
continue |
|
|
} |
|
|
processedIDs[toolUseID] = true |
|
|
|
|
|
toolUse := kiroclaude.KiroToolUse{ |
|
|
ToolUseID: toolUseID, |
|
|
Name: kirocommon.GetStringValue(tu, "name"), |
|
|
} |
|
|
if input, ok := tu["input"].(map[string]interface{}); ok { |
|
|
toolUse.Input = input |
|
|
} |
|
|
toolUses = append(toolUses, toolUse) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
case "toolUseEvent": |
|
|
|
|
|
completedToolUses, newState := kiroclaude.ProcessToolUseEvent(event, currentToolUse, processedIDs) |
|
|
currentToolUse = newState |
|
|
toolUses = append(toolUses, completedToolUses...) |
|
|
|
|
|
case "supplementaryWebLinksEvent": |
|
|
if inputTokens, ok := event["inputTokens"].(float64); ok { |
|
|
usageInfo.InputTokens = int64(inputTokens) |
|
|
} |
|
|
if outputTokens, ok := event["outputTokens"].(float64); ok { |
|
|
usageInfo.OutputTokens = int64(outputTokens) |
|
|
} |
|
|
|
|
|
case "messageStopEvent", "message_stop": |
|
|
|
|
|
if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { |
|
|
stopReason = sr |
|
|
log.Debugf("kiro: parseEventStream found stop_reason in messageStopEvent: %s", stopReason) |
|
|
} |
|
|
if sr := kirocommon.GetString(event, "stopReason"); sr != "" { |
|
|
stopReason = sr |
|
|
log.Debugf("kiro: parseEventStream found stopReason in messageStopEvent: %s", stopReason) |
|
|
} |
|
|
|
|
|
case "messageMetadataEvent", "metadataEvent": |
|
|
|
|
|
|
|
|
var metadata map[string]interface{} |
|
|
if m, ok := event["messageMetadataEvent"].(map[string]interface{}); ok { |
|
|
metadata = m |
|
|
} else if m, ok := event["metadataEvent"].(map[string]interface{}); ok { |
|
|
metadata = m |
|
|
} else { |
|
|
metadata = event |
|
|
} |
|
|
|
|
|
|
|
|
if tokenUsage, ok := metadata["tokenUsage"].(map[string]interface{}); ok { |
|
|
|
|
|
if outputTokens, ok := tokenUsage["outputTokens"].(float64); ok { |
|
|
usageInfo.OutputTokens = int64(outputTokens) |
|
|
log.Infof("kiro: parseEventStream found precise outputTokens in tokenUsage: %d", usageInfo.OutputTokens) |
|
|
} |
|
|
|
|
|
if totalTokens, ok := tokenUsage["totalTokens"].(float64); ok { |
|
|
usageInfo.TotalTokens = int64(totalTokens) |
|
|
log.Infof("kiro: parseEventStream found precise totalTokens in tokenUsage: %d", usageInfo.TotalTokens) |
|
|
} |
|
|
|
|
|
if uncachedInputTokens, ok := tokenUsage["uncachedInputTokens"].(float64); ok { |
|
|
usageInfo.InputTokens = int64(uncachedInputTokens) |
|
|
log.Infof("kiro: parseEventStream found uncachedInputTokens in tokenUsage: %d", usageInfo.InputTokens) |
|
|
} |
|
|
|
|
|
if cacheReadTokens, ok := tokenUsage["cacheReadInputTokens"].(float64); ok { |
|
|
|
|
|
if usageInfo.InputTokens > 0 { |
|
|
usageInfo.InputTokens += int64(cacheReadTokens) |
|
|
} else { |
|
|
usageInfo.InputTokens = int64(cacheReadTokens) |
|
|
} |
|
|
log.Debugf("kiro: parseEventStream found cacheReadInputTokens in tokenUsage: %d", int64(cacheReadTokens)) |
|
|
} |
|
|
|
|
|
if ctxPct, ok := tokenUsage["contextUsagePercentage"].(float64); ok { |
|
|
upstreamContextPercentage = ctxPct |
|
|
log.Debugf("kiro: parseEventStream found contextUsagePercentage in tokenUsage: %.2f%%", ctxPct) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if usageInfo.InputTokens == 0 { |
|
|
if inputTokens, ok := metadata["inputTokens"].(float64); ok { |
|
|
usageInfo.InputTokens = int64(inputTokens) |
|
|
log.Debugf("kiro: parseEventStream found inputTokens in messageMetadataEvent: %d", usageInfo.InputTokens) |
|
|
} |
|
|
} |
|
|
if usageInfo.OutputTokens == 0 { |
|
|
if outputTokens, ok := metadata["outputTokens"].(float64); ok { |
|
|
usageInfo.OutputTokens = int64(outputTokens) |
|
|
log.Debugf("kiro: parseEventStream found outputTokens in messageMetadataEvent: %d", usageInfo.OutputTokens) |
|
|
} |
|
|
} |
|
|
if usageInfo.TotalTokens == 0 { |
|
|
if totalTokens, ok := metadata["totalTokens"].(float64); ok { |
|
|
usageInfo.TotalTokens = int64(totalTokens) |
|
|
log.Debugf("kiro: parseEventStream found totalTokens in messageMetadataEvent: %d", usageInfo.TotalTokens) |
|
|
} |
|
|
} |
|
|
|
|
|
case "usageEvent", "usage": |
|
|
|
|
|
if inputTokens, ok := event["inputTokens"].(float64); ok { |
|
|
usageInfo.InputTokens = int64(inputTokens) |
|
|
log.Debugf("kiro: parseEventStream found inputTokens in usageEvent: %d", usageInfo.InputTokens) |
|
|
} |
|
|
if outputTokens, ok := event["outputTokens"].(float64); ok { |
|
|
usageInfo.OutputTokens = int64(outputTokens) |
|
|
log.Debugf("kiro: parseEventStream found outputTokens in usageEvent: %d", usageInfo.OutputTokens) |
|
|
} |
|
|
if totalTokens, ok := event["totalTokens"].(float64); ok { |
|
|
usageInfo.TotalTokens = int64(totalTokens) |
|
|
log.Debugf("kiro: parseEventStream found totalTokens in usageEvent: %d", usageInfo.TotalTokens) |
|
|
} |
|
|
|
|
|
if usageObj, ok := event["usage"].(map[string]interface{}); ok { |
|
|
if inputTokens, ok := usageObj["input_tokens"].(float64); ok { |
|
|
usageInfo.InputTokens = int64(inputTokens) |
|
|
} else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { |
|
|
usageInfo.InputTokens = int64(inputTokens) |
|
|
} |
|
|
if outputTokens, ok := usageObj["output_tokens"].(float64); ok { |
|
|
usageInfo.OutputTokens = int64(outputTokens) |
|
|
} else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { |
|
|
usageInfo.OutputTokens = int64(outputTokens) |
|
|
} |
|
|
if totalTokens, ok := usageObj["total_tokens"].(float64); ok { |
|
|
usageInfo.TotalTokens = int64(totalTokens) |
|
|
} |
|
|
log.Debugf("kiro: parseEventStream found usage object: input=%d, output=%d, total=%d", |
|
|
usageInfo.InputTokens, usageInfo.OutputTokens, usageInfo.TotalTokens) |
|
|
} |
|
|
|
|
|
case "metricsEvent": |
|
|
|
|
|
if metrics, ok := event["metricsEvent"].(map[string]interface{}); ok { |
|
|
if inputTokens, ok := metrics["inputTokens"].(float64); ok { |
|
|
usageInfo.InputTokens = int64(inputTokens) |
|
|
} |
|
|
if outputTokens, ok := metrics["outputTokens"].(float64); ok { |
|
|
usageInfo.OutputTokens = int64(outputTokens) |
|
|
} |
|
|
log.Debugf("kiro: parseEventStream found metricsEvent: input=%d, output=%d", |
|
|
usageInfo.InputTokens, usageInfo.OutputTokens) |
|
|
} |
|
|
|
|
|
case "meteringEvent": |
|
|
|
|
|
|
|
|
if metering, ok := event["meteringEvent"].(map[string]interface{}); ok { |
|
|
unit := "" |
|
|
if u, ok := metering["unit"].(string); ok { |
|
|
unit = u |
|
|
} |
|
|
usageVal := 0.0 |
|
|
if u, ok := metering["usage"].(float64); ok { |
|
|
usageVal = u |
|
|
} |
|
|
log.Infof("kiro: parseEventStream received meteringEvent: usage=%.2f %s", usageVal, unit) |
|
|
|
|
|
|
|
|
} else { |
|
|
|
|
|
unit := "" |
|
|
if u, ok := event["unit"].(string); ok { |
|
|
unit = u |
|
|
} |
|
|
usageVal := 0.0 |
|
|
if u, ok := event["usage"].(float64); ok { |
|
|
usageVal = u |
|
|
} |
|
|
if unit != "" || usageVal > 0 { |
|
|
log.Infof("kiro: parseEventStream received meteringEvent (direct): usage=%.2f %s", usageVal, unit) |
|
|
} |
|
|
} |
|
|
|
|
|
case "error", "exception", "internalServerException", "invalidStateEvent": |
|
|
|
|
|
errMsg := "" |
|
|
errType := eventType |
|
|
|
|
|
|
|
|
if msg, ok := event["message"].(string); ok { |
|
|
errMsg = msg |
|
|
} else if errObj, ok := event[eventType].(map[string]interface{}); ok { |
|
|
if msg, ok := errObj["message"].(string); ok { |
|
|
errMsg = msg |
|
|
} |
|
|
if t, ok := errObj["type"].(string); ok { |
|
|
errType = t |
|
|
} |
|
|
} else if errObj, ok := event["error"].(map[string]interface{}); ok { |
|
|
if msg, ok := errObj["message"].(string); ok { |
|
|
errMsg = msg |
|
|
} |
|
|
if t, ok := errObj["type"].(string); ok { |
|
|
errType = t |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if reason, ok := event["reason"].(string); ok { |
|
|
errMsg = fmt.Sprintf("%s (reason: %s)", errMsg, reason) |
|
|
} |
|
|
|
|
|
log.Errorf("kiro: parseEventStream received error event: type=%s, message=%s", errType, errMsg) |
|
|
|
|
|
|
|
|
if eventType == "invalidStateEvent" { |
|
|
log.Warnf("kiro: invalidStateEvent received, continuing stream processing") |
|
|
continue |
|
|
} |
|
|
|
|
|
|
|
|
if errMsg != "" { |
|
|
return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error (%s): %s", errType, errMsg) |
|
|
} |
|
|
|
|
|
default: |
|
|
|
|
|
if ctxPct, ok := event["contextUsagePercentage"].(float64); ok { |
|
|
upstreamContextPercentage = ctxPct |
|
|
log.Debugf("kiro: parseEventStream received context usage: %.2f%%", upstreamContextPercentage) |
|
|
} |
|
|
|
|
|
log.Debugf("kiro: parseEventStream unknown event type: %s, payload: %s", eventType, string(payload)) |
|
|
} |
|
|
|
|
|
|
|
|
if usageInfo.InputTokens == 0 { |
|
|
if inputTokens, ok := event["inputTokens"].(float64); ok { |
|
|
usageInfo.InputTokens = int64(inputTokens) |
|
|
log.Debugf("kiro: parseEventStream found direct inputTokens: %d", usageInfo.InputTokens) |
|
|
} |
|
|
} |
|
|
if usageInfo.OutputTokens == 0 { |
|
|
if outputTokens, ok := event["outputTokens"].(float64); ok { |
|
|
usageInfo.OutputTokens = int64(outputTokens) |
|
|
log.Debugf("kiro: parseEventStream found direct outputTokens: %d", usageInfo.OutputTokens) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if usageInfo.InputTokens == 0 || usageInfo.OutputTokens == 0 { |
|
|
if usageObj, ok := event["usage"].(map[string]interface{}); ok { |
|
|
if usageInfo.InputTokens == 0 { |
|
|
if inputTokens, ok := usageObj["input_tokens"].(float64); ok { |
|
|
usageInfo.InputTokens = int64(inputTokens) |
|
|
} else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { |
|
|
usageInfo.InputTokens = int64(inputTokens) |
|
|
} |
|
|
} |
|
|
if usageInfo.OutputTokens == 0 { |
|
|
if outputTokens, ok := usageObj["output_tokens"].(float64); ok { |
|
|
usageInfo.OutputTokens = int64(outputTokens) |
|
|
} else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { |
|
|
usageInfo.OutputTokens = int64(outputTokens) |
|
|
} |
|
|
} |
|
|
if usageInfo.TotalTokens == 0 { |
|
|
if totalTokens, ok := usageObj["total_tokens"].(float64); ok { |
|
|
usageInfo.TotalTokens = int64(totalTokens) |
|
|
} |
|
|
} |
|
|
log.Debugf("kiro: parseEventStream found usage object (fallback): input=%d, output=%d, total=%d", |
|
|
usageInfo.InputTokens, usageInfo.OutputTokens, usageInfo.TotalTokens) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok { |
|
|
if inputTokens, ok := usageEvent["inputTokens"].(float64); ok { |
|
|
usageInfo.InputTokens = int64(inputTokens) |
|
|
} |
|
|
if outputTokens, ok := usageEvent["outputTokens"].(float64); ok { |
|
|
usageInfo.OutputTokens = int64(outputTokens) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
contentStr := content.String() |
|
|
cleanedContent, embeddedToolUses := kiroclaude.ParseEmbeddedToolCalls(contentStr, processedIDs) |
|
|
toolUses = append(toolUses, embeddedToolUses...) |
|
|
|
|
|
|
|
|
toolUses = kiroclaude.DeduplicateToolUses(toolUses) |
|
|
|
|
|
|
|
|
|
|
|
if stopReason == "" { |
|
|
if len(toolUses) > 0 { |
|
|
stopReason = "tool_use" |
|
|
log.Debugf("kiro: parseEventStream using fallback stop_reason: tool_use (detected %d tool uses)", len(toolUses)) |
|
|
} else { |
|
|
stopReason = "end_turn" |
|
|
log.Debugf("kiro: parseEventStream using fallback stop_reason: end_turn") |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if stopReason == "max_tokens" { |
|
|
log.Warnf("kiro: response truncated due to max_tokens limit") |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if upstreamContextPercentage > 0 { |
|
|
calculatedInputTokens := int64(upstreamContextPercentage * 200000 / 100) |
|
|
if calculatedInputTokens > 0 { |
|
|
localEstimate := usageInfo.InputTokens |
|
|
usageInfo.InputTokens = calculatedInputTokens |
|
|
usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens |
|
|
log.Infof("kiro: parseEventStream using contextUsagePercentage (%.2f%%) to calculate input tokens: %d (local estimate was: %d)", |
|
|
upstreamContextPercentage, calculatedInputTokens, localEstimate) |
|
|
} |
|
|
} |
|
|
|
|
|
return cleanedContent, toolUses, usageInfo, stopReason, nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (e *KiroExecutor) readEventStreamMessage(reader *bufio.Reader) (*eventStreamMessage, *EventStreamError) { |
|
|
|
|
|
prelude := make([]byte, 12) |
|
|
_, err := io.ReadFull(reader, prelude) |
|
|
if err == io.EOF { |
|
|
return nil, nil |
|
|
} |
|
|
if err != nil { |
|
|
return nil, &EventStreamError{ |
|
|
Type: ErrStreamFatal, |
|
|
Message: "failed to read prelude", |
|
|
Cause: err, |
|
|
} |
|
|
} |
|
|
|
|
|
totalLength := binary.BigEndian.Uint32(prelude[0:4]) |
|
|
headersLength := binary.BigEndian.Uint32(prelude[4:8]) |
|
|
|
|
|
|
|
|
|
|
|
if totalLength < minEventStreamFrameSize { |
|
|
return nil, &EventStreamError{ |
|
|
Type: ErrStreamMalformed, |
|
|
Message: fmt.Sprintf("invalid message length: %d (minimum is %d)", totalLength, minEventStreamFrameSize), |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if totalLength > maxEventStreamMsgSize { |
|
|
return nil, &EventStreamError{ |
|
|
Type: ErrStreamMalformed, |
|
|
Message: fmt.Sprintf("message too large: %d bytes (maximum is %d)", totalLength, maxEventStreamMsgSize), |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if headersLength > totalLength-16 { |
|
|
return nil, &EventStreamError{ |
|
|
Type: ErrStreamMalformed, |
|
|
Message: fmt.Sprintf("headers length %d exceeds message bounds (total: %d)", headersLength, totalLength), |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
remaining := make([]byte, totalLength-12) |
|
|
_, err = io.ReadFull(reader, remaining) |
|
|
if err != nil { |
|
|
return nil, &EventStreamError{ |
|
|
Type: ErrStreamFatal, |
|
|
Message: "failed to read message body", |
|
|
Cause: err, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
var eventType string |
|
|
if headersLength > 0 && headersLength <= uint32(len(remaining)) { |
|
|
eventType = e.extractEventTypeFromBytes(remaining[:headersLength]) |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
payloadStart := headersLength |
|
|
payloadEnd := uint32(len(remaining)) - 4 |
|
|
|
|
|
|
|
|
if payloadStart >= payloadEnd { |
|
|
|
|
|
return &eventStreamMessage{ |
|
|
EventType: eventType, |
|
|
Payload: nil, |
|
|
}, nil |
|
|
} |
|
|
|
|
|
payload := remaining[payloadStart:payloadEnd] |
|
|
|
|
|
return &eventStreamMessage{ |
|
|
EventType: eventType, |
|
|
Payload: payload, |
|
|
}, nil |
|
|
} |
|
|
|
|
|
// skipEventStreamHeaderValue advances past one header value of the given type
// that starts at headers[offset].
//
// It returns the offset just after the value and true on success. It returns
// false when the value would run past the end of headers or when valueType is
// not handled here (which includes type 7). On failure the returned offset is
// the input offset, except for a type 6 value whose 2-byte length was read but
// whose data is truncated: then it is offset+2.
//
// Sizes by type: 0 and 1 occupy no bytes; 2 is 1 byte; 3 is 2 bytes; 4 is
// 4 bytes; 5 and 8 are 8 bytes; 9 is 16 bytes; 6 is a big-endian uint16 length
// followed by that many bytes.
func skipEventStreamHeaderValue(headers []byte, offset int, valueType byte) (int, bool) {
	var width int
	switch valueType {
	case 0, 1:
		// The value is encoded in the type byte itself.
		return offset, true
	case 2:
		width = 1
	case 3:
		width = 2
	case 4:
		width = 4
	case 5, 8:
		width = 8
	case 9:
		width = 16
	case 6:
		// Length-prefixed value.
		if offset+2 > len(headers) {
			return offset, false
		}
		dataStart := offset + 2
		dataEnd := dataStart + int(binary.BigEndian.Uint16(headers[offset:dataStart]))
		if dataEnd > len(headers) {
			return dataStart, false
		}
		return dataEnd, true
	default:
		return offset, false
	}

	// Fixed-width value.
	if offset+width > len(headers) {
		return offset, false
	}
	return offset + width, true
}
|
|
|
|
|
|
|
|
func (e *KiroExecutor) extractEventTypeFromBytes(headers []byte) string { |
|
|
offset := 0 |
|
|
for offset < len(headers) { |
|
|
nameLen := int(headers[offset]) |
|
|
offset++ |
|
|
if offset+nameLen > len(headers) { |
|
|
break |
|
|
} |
|
|
name := string(headers[offset : offset+nameLen]) |
|
|
offset += nameLen |
|
|
|
|
|
if offset >= len(headers) { |
|
|
break |
|
|
} |
|
|
valueType := headers[offset] |
|
|
offset++ |
|
|
|
|
|
if valueType == 7 { |
|
|
if offset+2 > len(headers) { |
|
|
break |
|
|
} |
|
|
valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2])) |
|
|
offset += 2 |
|
|
if offset+valueLen > len(headers) { |
|
|
break |
|
|
} |
|
|
value := string(headers[offset : offset+valueLen]) |
|
|
offset += valueLen |
|
|
|
|
|
if name == ":event-type" { |
|
|
return value |
|
|
} |
|
|
continue |
|
|
} |
|
|
|
|
|
nextOffset, ok := skipEventStreamHeaderValue(headers, offset, valueType) |
|
|
if !ok { |
|
|
break |
|
|
} |
|
|
offset = nextOffset |
|
|
} |
|
|
return "" |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter, thinkingEnabled bool) { |
|
|
reader := bufio.NewReaderSize(body, 20*1024*1024) |
|
|
var totalUsage usage.Detail |
|
|
var hasToolUses bool |
|
|
var upstreamStopReason string |
|
|
|
|
|
|
|
|
processedIDs := make(map[string]bool) |
|
|
var currentToolUse *kiroclaude.ToolUseState |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
var accumulatedContent strings.Builder |
|
|
accumulatedContent.Grow(4096) |
|
|
|
|
|
|
|
|
|
|
|
var lastUsageUpdateLen int |
|
|
var lastUsageUpdateTime = time.Now() |
|
|
var lastReportedOutputTokens int64 |
|
|
|
|
|
|
|
|
var upstreamCreditUsage float64 |
|
|
var upstreamContextPercentage float64 |
|
|
var hasUpstreamUsage bool |
|
|
|
|
|
|
|
|
|
|
|
var translatorParam any |
|
|
|
|
|
|
|
|
inThinkBlock := false |
|
|
isThinkingBlockOpen := false |
|
|
thinkingBlockIndex := -1 |
|
|
var accumulatedThinkingContent strings.Builder |
|
|
|
|
|
|
|
|
var pendingContent strings.Builder |
|
|
|
|
|
|
|
|
|
|
|
if enc, err := getTokenizer(model); err == nil { |
|
|
var inputTokens int64 |
|
|
var countMethod string |
|
|
|
|
|
|
|
|
if inp, err := countClaudeChatTokens(enc, claudeBody); err == nil && inp > 0 { |
|
|
inputTokens = inp |
|
|
countMethod = "claude" |
|
|
} else if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 { |
|
|
|
|
|
inputTokens = inp |
|
|
countMethod = "openai" |
|
|
} else { |
|
|
|
|
|
inputTokens = int64(len(claudeBody) / 4) |
|
|
if inputTokens == 0 && len(claudeBody) > 0 { |
|
|
inputTokens = 1 |
|
|
} |
|
|
countMethod = "estimate" |
|
|
} |
|
|
|
|
|
totalUsage.InputTokens = inputTokens |
|
|
log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (method: %s, claude body: %d bytes, original req: %d bytes)", |
|
|
totalUsage.InputTokens, countMethod, len(claudeBody), len(originalReq)) |
|
|
} |
|
|
|
|
|
contentBlockIndex := -1 |
|
|
messageStartSent := false |
|
|
isTextBlockOpen := false |
|
|
var outputLen int |
|
|
|
|
|
|
|
|
defer func() { |
|
|
reporter.publish(ctx, totalUsage) |
|
|
}() |
|
|
|
|
|
for { |
|
|
select { |
|
|
case <-ctx.Done(): |
|
|
return |
|
|
default: |
|
|
} |
|
|
|
|
|
msg, eventErr := e.readEventStreamMessage(reader) |
|
|
if eventErr != nil { |
|
|
|
|
|
log.Errorf("kiro: streamToChannel error: %v", eventErr) |
|
|
|
|
|
|
|
|
out <- cliproxyexecutor.StreamChunk{Err: eventErr} |
|
|
return |
|
|
} |
|
|
if msg == nil { |
|
|
|
|
|
|
|
|
if currentToolUse != nil && !processedIDs[currentToolUse.ToolUseID] { |
|
|
log.Warnf("kiro: flushing incomplete tool use at EOF: %s (ID: %s)", currentToolUse.Name, currentToolUse.ToolUseID) |
|
|
fullInput := currentToolUse.InputBuffer.String() |
|
|
repairedJSON := kiroclaude.RepairJSON(fullInput) |
|
|
var finalInput map[string]interface{} |
|
|
if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil { |
|
|
log.Warnf("kiro: failed to parse incomplete tool input at EOF: %v", err) |
|
|
finalInput = make(map[string]interface{}) |
|
|
} |
|
|
|
|
|
processedIDs[currentToolUse.ToolUseID] = true |
|
|
contentBlockIndex++ |
|
|
|
|
|
|
|
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.ToolUseID, currentToolUse.Name) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
inputBytes, _ := json.Marshal(finalInput) |
|
|
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex) |
|
|
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) |
|
|
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
|
|
|
hasToolUses = true |
|
|
currentToolUse = nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
break |
|
|
} |
|
|
|
|
|
eventType := msg.EventType |
|
|
payload := msg.Payload |
|
|
if len(payload) == 0 { |
|
|
continue |
|
|
} |
|
|
appendAPIResponseChunk(ctx, e.cfg, payload) |
|
|
|
|
|
var event map[string]interface{} |
|
|
if err := json.Unmarshal(payload, &event); err != nil { |
|
|
log.Warnf("kiro: failed to unmarshal event payload: %v, raw: %s", err, string(payload)) |
|
|
continue |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if errType, hasErrType := event["_type"].(string); hasErrType { |
|
|
|
|
|
errMsg := "" |
|
|
if msg, ok := event["message"].(string); ok { |
|
|
errMsg = msg |
|
|
} |
|
|
log.Errorf("kiro: received AWS error in stream: type=%s, message=%s", errType, errMsg) |
|
|
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("kiro API error: %s - %s", errType, errMsg)} |
|
|
return |
|
|
} |
|
|
if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") { |
|
|
|
|
|
errMsg := "" |
|
|
if msg, ok := event["message"].(string); ok { |
|
|
errMsg = msg |
|
|
} else if errObj, ok := event["error"].(map[string]interface{}); ok { |
|
|
if msg, ok := errObj["message"].(string); ok { |
|
|
errMsg = msg |
|
|
} |
|
|
} |
|
|
log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg) |
|
|
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("kiro API error: %s", errMsg)} |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { |
|
|
upstreamStopReason = sr |
|
|
log.Debugf("kiro: streamToChannel found stop_reason (top-level): %s", upstreamStopReason) |
|
|
} |
|
|
if sr := kirocommon.GetString(event, "stopReason"); sr != "" { |
|
|
upstreamStopReason = sr |
|
|
log.Debugf("kiro: streamToChannel found stopReason (top-level): %s", upstreamStopReason) |
|
|
} |
|
|
|
|
|
|
|
|
if !messageStartSent { |
|
|
msgStart := kiroclaude.BuildClaudeMessageStartEvent(model, totalUsage.InputTokens) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
messageStartSent = true |
|
|
} |
|
|
|
|
|
switch eventType { |
|
|
case "followupPromptEvent": |
|
|
|
|
|
log.Debugf("kiro: streamToChannel ignoring followupPrompt event") |
|
|
continue |
|
|
|
|
|
case "messageStopEvent", "message_stop": |
|
|
|
|
|
if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { |
|
|
upstreamStopReason = sr |
|
|
log.Debugf("kiro: streamToChannel found stop_reason in messageStopEvent: %s", upstreamStopReason) |
|
|
} |
|
|
if sr := kirocommon.GetString(event, "stopReason"); sr != "" { |
|
|
upstreamStopReason = sr |
|
|
log.Debugf("kiro: streamToChannel found stopReason in messageStopEvent: %s", upstreamStopReason) |
|
|
} |
|
|
|
|
|
case "meteringEvent": |
|
|
|
|
|
|
|
|
if metering, ok := event["meteringEvent"].(map[string]interface{}); ok { |
|
|
unit := "" |
|
|
if u, ok := metering["unit"].(string); ok { |
|
|
unit = u |
|
|
} |
|
|
usageVal := 0.0 |
|
|
if u, ok := metering["usage"].(float64); ok { |
|
|
usageVal = u |
|
|
} |
|
|
upstreamCreditUsage = usageVal |
|
|
hasUpstreamUsage = true |
|
|
log.Infof("kiro: streamToChannel received meteringEvent: usage=%.4f %s", usageVal, unit) |
|
|
} else { |
|
|
|
|
|
if unit, ok := event["unit"].(string); ok { |
|
|
if usage, ok := event["usage"].(float64); ok { |
|
|
upstreamCreditUsage = usage |
|
|
hasUpstreamUsage = true |
|
|
log.Infof("kiro: streamToChannel received meteringEvent (direct): usage=%.4f %s", usage, unit) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
case "error", "exception", "internalServerException": |
|
|
|
|
|
errMsg := "" |
|
|
errType := eventType |
|
|
|
|
|
|
|
|
if msg, ok := event["message"].(string); ok { |
|
|
errMsg = msg |
|
|
} else if errObj, ok := event[eventType].(map[string]interface{}); ok { |
|
|
if msg, ok := errObj["message"].(string); ok { |
|
|
errMsg = msg |
|
|
} |
|
|
if t, ok := errObj["type"].(string); ok { |
|
|
errType = t |
|
|
} |
|
|
} else if errObj, ok := event["error"].(map[string]interface{}); ok { |
|
|
if msg, ok := errObj["message"].(string); ok { |
|
|
errMsg = msg |
|
|
} |
|
|
} |
|
|
|
|
|
log.Errorf("kiro: streamToChannel received error event: type=%s, message=%s", errType, errMsg) |
|
|
|
|
|
|
|
|
if errMsg != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{ |
|
|
Err: fmt.Errorf("kiro API error (%s): %s", errType, errMsg), |
|
|
} |
|
|
return |
|
|
} |
|
|
|
|
|
case "invalidStateEvent": |
|
|
|
|
|
errMsg := "" |
|
|
if msg, ok := event["message"].(string); ok { |
|
|
errMsg = msg |
|
|
} else if stateEvent, ok := event["invalidStateEvent"].(map[string]interface{}); ok { |
|
|
if msg, ok := stateEvent["message"].(string); ok { |
|
|
errMsg = msg |
|
|
} |
|
|
} |
|
|
log.Warnf("kiro: streamToChannel received invalidStateEvent: %s, continuing", errMsg) |
|
|
continue |
|
|
|
|
|
default: |
|
|
|
|
|
|
|
|
if unit, ok := event["unit"].(string); ok && unit == "credit" { |
|
|
if usage, ok := event["usage"].(float64); ok { |
|
|
upstreamCreditUsage = usage |
|
|
hasUpstreamUsage = true |
|
|
log.Debugf("kiro: received upstream credit usage: %.4f", upstreamCreditUsage) |
|
|
} |
|
|
} |
|
|
|
|
|
if ctxPct, ok := event["contextUsagePercentage"].(float64); ok { |
|
|
upstreamContextPercentage = ctxPct |
|
|
log.Debugf("kiro: received upstream context usage: %.2f%%", upstreamContextPercentage) |
|
|
} |
|
|
|
|
|
|
|
|
if inputTokens, ok := event["inputTokens"].(float64); ok { |
|
|
totalUsage.InputTokens = int64(inputTokens) |
|
|
hasUpstreamUsage = true |
|
|
log.Debugf("kiro: streamToChannel found inputTokens in event %s: %d", eventType, totalUsage.InputTokens) |
|
|
} |
|
|
if outputTokens, ok := event["outputTokens"].(float64); ok { |
|
|
totalUsage.OutputTokens = int64(outputTokens) |
|
|
hasUpstreamUsage = true |
|
|
log.Debugf("kiro: streamToChannel found outputTokens in event %s: %d", eventType, totalUsage.OutputTokens) |
|
|
} |
|
|
if totalTokens, ok := event["totalTokens"].(float64); ok { |
|
|
totalUsage.TotalTokens = int64(totalTokens) |
|
|
log.Debugf("kiro: streamToChannel found totalTokens in event %s: %d", eventType, totalUsage.TotalTokens) |
|
|
} |
|
|
|
|
|
|
|
|
if usageObj, ok := event["usage"].(map[string]interface{}); ok { |
|
|
if inputTokens, ok := usageObj["input_tokens"].(float64); ok { |
|
|
totalUsage.InputTokens = int64(inputTokens) |
|
|
hasUpstreamUsage = true |
|
|
} else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { |
|
|
totalUsage.InputTokens = int64(inputTokens) |
|
|
hasUpstreamUsage = true |
|
|
} |
|
|
if outputTokens, ok := usageObj["output_tokens"].(float64); ok { |
|
|
totalUsage.OutputTokens = int64(outputTokens) |
|
|
hasUpstreamUsage = true |
|
|
} else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { |
|
|
totalUsage.OutputTokens = int64(outputTokens) |
|
|
hasUpstreamUsage = true |
|
|
} |
|
|
if totalTokens, ok := usageObj["total_tokens"].(float64); ok { |
|
|
totalUsage.TotalTokens = int64(totalTokens) |
|
|
} |
|
|
log.Debugf("kiro: streamToChannel found usage object in event %s: input=%d, output=%d, total=%d", |
|
|
eventType, totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) |
|
|
} |
|
|
|
|
|
|
|
|
if eventType != "" { |
|
|
log.Debugf("kiro: streamToChannel unknown event type: %s, payload: %s", eventType, string(payload)) |
|
|
} |
|
|
|
|
|
case "assistantResponseEvent": |
|
|
var contentDelta string |
|
|
var toolUses []map[string]interface{} |
|
|
|
|
|
if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok { |
|
|
if c, ok := assistantResp["content"].(string); ok { |
|
|
contentDelta = c |
|
|
} |
|
|
|
|
|
if sr := kirocommon.GetString(assistantResp, "stop_reason"); sr != "" { |
|
|
upstreamStopReason = sr |
|
|
log.Debugf("kiro: streamToChannel found stop_reason in assistantResponseEvent: %s", upstreamStopReason) |
|
|
} |
|
|
if sr := kirocommon.GetString(assistantResp, "stopReason"); sr != "" { |
|
|
upstreamStopReason = sr |
|
|
log.Debugf("kiro: streamToChannel found stopReason in assistantResponseEvent: %s", upstreamStopReason) |
|
|
} |
|
|
|
|
|
if tus, ok := assistantResp["toolUses"].([]interface{}); ok { |
|
|
for _, tuRaw := range tus { |
|
|
if tu, ok := tuRaw.(map[string]interface{}); ok { |
|
|
toolUses = append(toolUses, tu) |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
if contentDelta == "" { |
|
|
if c, ok := event["content"].(string); ok { |
|
|
contentDelta = c |
|
|
} |
|
|
} |
|
|
|
|
|
if tus, ok := event["toolUses"].([]interface{}); ok { |
|
|
for _, tuRaw := range tus { |
|
|
if tu, ok := tuRaw.(map[string]interface{}); ok { |
|
|
toolUses = append(toolUses, tu) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if contentDelta != "" { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
outputLen += len(contentDelta) |
|
|
|
|
|
accumulatedContent.WriteString(contentDelta) |
|
|
|
|
|
|
|
|
|
|
|
shouldSendUsageUpdate := false |
|
|
if accumulatedContent.Len()-lastUsageUpdateLen >= usageUpdateCharThreshold { |
|
|
shouldSendUsageUpdate = true |
|
|
} else if time.Since(lastUsageUpdateTime) >= usageUpdateTimeInterval && accumulatedContent.Len() > lastUsageUpdateLen { |
|
|
shouldSendUsageUpdate = true |
|
|
} |
|
|
|
|
|
if shouldSendUsageUpdate { |
|
|
|
|
|
var currentOutputTokens int64 |
|
|
if enc, encErr := getTokenizer(model); encErr == nil { |
|
|
if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil { |
|
|
currentOutputTokens = int64(tokenCount) |
|
|
} |
|
|
} |
|
|
|
|
|
if currentOutputTokens == 0 { |
|
|
currentOutputTokens = int64(accumulatedContent.Len() / 4) |
|
|
if currentOutputTokens == 0 { |
|
|
currentOutputTokens = 1 |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if currentOutputTokens > lastReportedOutputTokens+10 { |
|
|
|
|
|
|
|
|
pingEvent := kiroclaude.BuildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
|
|
|
lastReportedOutputTokens = currentOutputTokens |
|
|
log.Debugf("kiro: sent real-time usage update - input: %d, output: %d (accumulated: %d chars)", |
|
|
totalUsage.InputTokens, currentOutputTokens, accumulatedContent.Len()) |
|
|
} |
|
|
|
|
|
lastUsageUpdateLen = accumulatedContent.Len() |
|
|
lastUsageUpdateTime = time.Now() |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
pendingContent.WriteString(contentDelta) |
|
|
processContent := pendingContent.String() |
|
|
pendingContent.Reset() |
|
|
|
|
|
|
|
|
for len(processContent) > 0 { |
|
|
if inThinkBlock { |
|
|
|
|
|
endIdx := strings.Index(processContent, kirocommon.ThinkingEndTag) |
|
|
if endIdx >= 0 { |
|
|
|
|
|
thinkingText := processContent[:endIdx] |
|
|
if thinkingText != "" { |
|
|
|
|
|
if !isThinkingBlockOpen { |
|
|
contentBlockIndex++ |
|
|
thinkingBlockIndex = contentBlockIndex |
|
|
isThinkingBlockOpen = true |
|
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
accumulatedThinkingContent.WriteString(thinkingText) |
|
|
} |
|
|
|
|
|
if isThinkingBlockOpen { |
|
|
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
isThinkingBlockOpen = false |
|
|
} |
|
|
inThinkBlock = false |
|
|
processContent = processContent[endIdx+len(kirocommon.ThinkingEndTag):] |
|
|
log.Debugf("kiro: closed thinking block, remaining content: %d chars", len(processContent)) |
|
|
} else { |
|
|
|
|
|
partialMatch := false |
|
|
for i := 1; i < len(kirocommon.ThinkingEndTag) && i <= len(processContent); i++ { |
|
|
if strings.HasSuffix(processContent, kirocommon.ThinkingEndTag[:i]) { |
|
|
|
|
|
pendingContent.WriteString(processContent[len(processContent)-i:]) |
|
|
processContent = processContent[:len(processContent)-i] |
|
|
partialMatch = true |
|
|
break |
|
|
} |
|
|
} |
|
|
if !partialMatch || len(processContent) > 0 { |
|
|
|
|
|
if processContent != "" { |
|
|
if !isThinkingBlockOpen { |
|
|
contentBlockIndex++ |
|
|
thinkingBlockIndex = contentBlockIndex |
|
|
isThinkingBlockOpen = true |
|
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
} |
|
|
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(processContent, thinkingBlockIndex) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
accumulatedThinkingContent.WriteString(processContent) |
|
|
} |
|
|
} |
|
|
processContent = "" |
|
|
} |
|
|
} else { |
|
|
|
|
|
startIdx := strings.Index(processContent, kirocommon.ThinkingStartTag) |
|
|
if startIdx >= 0 { |
|
|
|
|
|
textBefore := processContent[:startIdx] |
|
|
if textBefore != "" { |
|
|
|
|
|
if isThinkingBlockOpen { |
|
|
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
isThinkingBlockOpen = false |
|
|
} |
|
|
|
|
|
if !isTextBlockOpen { |
|
|
contentBlockIndex++ |
|
|
isTextBlockOpen = true |
|
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
claudeEvent := kiroclaude.BuildClaudeStreamEvent(textBefore, contentBlockIndex) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
if isTextBlockOpen { |
|
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
isTextBlockOpen = false |
|
|
} |
|
|
inThinkBlock = true |
|
|
processContent = processContent[startIdx+len(kirocommon.ThinkingStartTag):] |
|
|
log.Debugf("kiro: entered thinking block") |
|
|
} else { |
|
|
|
|
|
partialMatch := false |
|
|
for i := 1; i < len(kirocommon.ThinkingStartTag) && i <= len(processContent); i++ { |
|
|
if strings.HasSuffix(processContent, kirocommon.ThinkingStartTag[:i]) { |
|
|
|
|
|
pendingContent.WriteString(processContent[len(processContent)-i:]) |
|
|
processContent = processContent[:len(processContent)-i] |
|
|
partialMatch = true |
|
|
break |
|
|
} |
|
|
} |
|
|
if !partialMatch || len(processContent) > 0 { |
|
|
|
|
|
if processContent != "" { |
|
|
if !isTextBlockOpen { |
|
|
contentBlockIndex++ |
|
|
isTextBlockOpen = true |
|
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
} |
|
|
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processContent, contentBlockIndex) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
processContent = "" |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for _, tu := range toolUses { |
|
|
toolUseID := kirocommon.GetString(tu, "toolUseId") |
|
|
toolName := kirocommon.GetString(tu, "name") |
|
|
|
|
|
|
|
|
if processedIDs[toolUseID] { |
|
|
log.Debugf("kiro: skipping duplicate tool use in stream: %s", toolUseID) |
|
|
continue |
|
|
} |
|
|
processedIDs[toolUseID] = true |
|
|
|
|
|
hasToolUses = true |
|
|
|
|
|
if isTextBlockOpen && contentBlockIndex >= 0 { |
|
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
isTextBlockOpen = false |
|
|
} |
|
|
|
|
|
|
|
|
contentBlockIndex++ |
|
|
|
|
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if input, ok := tu["input"].(map[string]interface{}); ok { |
|
|
inputJSON, err := json.Marshal(input) |
|
|
if err != nil { |
|
|
log.Debugf("kiro: failed to marshal tool input: %v", err) |
|
|
|
|
|
} else { |
|
|
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) |
|
|
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) |
|
|
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
case "reasoningContentEvent": |
|
|
|
|
|
|
|
|
|
|
|
var thinkingText string |
|
|
var signature string |
|
|
|
|
|
if re, ok := event["reasoningContentEvent"].(map[string]interface{}); ok { |
|
|
if text, ok := re["text"].(string); ok { |
|
|
thinkingText = text |
|
|
} |
|
|
if sig, ok := re["signature"].(string); ok { |
|
|
signature = sig |
|
|
if len(sig) > 20 { |
|
|
log.Debugf("kiro: reasoningContentEvent has signature: %s...", sig[:20]) |
|
|
} else { |
|
|
log.Debugf("kiro: reasoningContentEvent has signature: %s", sig) |
|
|
} |
|
|
} |
|
|
} else { |
|
|
|
|
|
if text, ok := event["text"].(string); ok { |
|
|
thinkingText = text |
|
|
} |
|
|
if sig, ok := event["signature"].(string); ok { |
|
|
signature = sig |
|
|
} |
|
|
} |
|
|
|
|
|
if thinkingText != "" { |
|
|
|
|
|
if isTextBlockOpen && contentBlockIndex >= 0 { |
|
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
isTextBlockOpen = false |
|
|
} |
|
|
|
|
|
|
|
|
if !isThinkingBlockOpen { |
|
|
contentBlockIndex++ |
|
|
thinkingBlockIndex = contentBlockIndex |
|
|
isThinkingBlockOpen = true |
|
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
accumulatedThinkingContent.WriteString(thinkingText) |
|
|
log.Debugf("kiro: received reasoningContentEvent, text length: %d, has signature: %v", len(thinkingText), signature != "") |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
_ = signature |
|
|
|
|
|
case "toolUseEvent": |
|
|
|
|
|
completedToolUses, newState := kiroclaude.ProcessToolUseEvent(event, currentToolUse, processedIDs) |
|
|
currentToolUse = newState |
|
|
|
|
|
|
|
|
for _, tu := range completedToolUses { |
|
|
hasToolUses = true |
|
|
|
|
|
|
|
|
if isTextBlockOpen && contentBlockIndex >= 0 { |
|
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
isTextBlockOpen = false |
|
|
} |
|
|
|
|
|
contentBlockIndex++ |
|
|
|
|
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
|
|
|
if tu.Input != nil { |
|
|
inputJSON, err := json.Marshal(tu.Input) |
|
|
if err != nil { |
|
|
log.Debugf("kiro: failed to marshal tool input in toolUseEvent: %v", err) |
|
|
} else { |
|
|
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) |
|
|
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) |
|
|
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
case "supplementaryWebLinksEvent": |
|
|
if inputTokens, ok := event["inputTokens"].(float64); ok { |
|
|
totalUsage.InputTokens = int64(inputTokens) |
|
|
} |
|
|
if outputTokens, ok := event["outputTokens"].(float64); ok { |
|
|
totalUsage.OutputTokens = int64(outputTokens) |
|
|
} |
|
|
|
|
|
case "messageMetadataEvent", "metadataEvent": |
|
|
|
|
|
|
|
|
var metadata map[string]interface{} |
|
|
if m, ok := event["messageMetadataEvent"].(map[string]interface{}); ok { |
|
|
metadata = m |
|
|
} else if m, ok := event["metadataEvent"].(map[string]interface{}); ok { |
|
|
metadata = m |
|
|
} else { |
|
|
metadata = event |
|
|
} |
|
|
|
|
|
|
|
|
if tokenUsage, ok := metadata["tokenUsage"].(map[string]interface{}); ok { |
|
|
|
|
|
if outputTokens, ok := tokenUsage["outputTokens"].(float64); ok { |
|
|
totalUsage.OutputTokens = int64(outputTokens) |
|
|
hasUpstreamUsage = true |
|
|
log.Infof("kiro: streamToChannel found precise outputTokens in tokenUsage: %d", totalUsage.OutputTokens) |
|
|
} |
|
|
|
|
|
if totalTokens, ok := tokenUsage["totalTokens"].(float64); ok { |
|
|
totalUsage.TotalTokens = int64(totalTokens) |
|
|
log.Infof("kiro: streamToChannel found precise totalTokens in tokenUsage: %d", totalUsage.TotalTokens) |
|
|
} |
|
|
|
|
|
if uncachedInputTokens, ok := tokenUsage["uncachedInputTokens"].(float64); ok { |
|
|
totalUsage.InputTokens = int64(uncachedInputTokens) |
|
|
hasUpstreamUsage = true |
|
|
log.Infof("kiro: streamToChannel found uncachedInputTokens in tokenUsage: %d", totalUsage.InputTokens) |
|
|
} |
|
|
|
|
|
if cacheReadTokens, ok := tokenUsage["cacheReadInputTokens"].(float64); ok { |
|
|
|
|
|
if totalUsage.InputTokens > 0 { |
|
|
totalUsage.InputTokens += int64(cacheReadTokens) |
|
|
} else { |
|
|
totalUsage.InputTokens = int64(cacheReadTokens) |
|
|
} |
|
|
hasUpstreamUsage = true |
|
|
log.Debugf("kiro: streamToChannel found cacheReadInputTokens in tokenUsage: %d", int64(cacheReadTokens)) |
|
|
} |
|
|
|
|
|
if ctxPct, ok := tokenUsage["contextUsagePercentage"].(float64); ok { |
|
|
upstreamContextPercentage = ctxPct |
|
|
log.Debugf("kiro: streamToChannel found contextUsagePercentage in tokenUsage: %.2f%%", ctxPct) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if totalUsage.InputTokens == 0 { |
|
|
if inputTokens, ok := metadata["inputTokens"].(float64); ok { |
|
|
totalUsage.InputTokens = int64(inputTokens) |
|
|
hasUpstreamUsage = true |
|
|
log.Debugf("kiro: streamToChannel found inputTokens in messageMetadataEvent: %d", totalUsage.InputTokens) |
|
|
} |
|
|
} |
|
|
if totalUsage.OutputTokens == 0 { |
|
|
if outputTokens, ok := metadata["outputTokens"].(float64); ok { |
|
|
totalUsage.OutputTokens = int64(outputTokens) |
|
|
hasUpstreamUsage = true |
|
|
log.Debugf("kiro: streamToChannel found outputTokens in messageMetadataEvent: %d", totalUsage.OutputTokens) |
|
|
} |
|
|
} |
|
|
if totalUsage.TotalTokens == 0 { |
|
|
if totalTokens, ok := metadata["totalTokens"].(float64); ok { |
|
|
totalUsage.TotalTokens = int64(totalTokens) |
|
|
log.Debugf("kiro: streamToChannel found totalTokens in messageMetadataEvent: %d", totalUsage.TotalTokens) |
|
|
} |
|
|
} |
|
|
|
|
|
case "usageEvent", "usage": |
|
|
|
|
|
if inputTokens, ok := event["inputTokens"].(float64); ok { |
|
|
totalUsage.InputTokens = int64(inputTokens) |
|
|
log.Debugf("kiro: streamToChannel found inputTokens in usageEvent: %d", totalUsage.InputTokens) |
|
|
} |
|
|
if outputTokens, ok := event["outputTokens"].(float64); ok { |
|
|
totalUsage.OutputTokens = int64(outputTokens) |
|
|
log.Debugf("kiro: streamToChannel found outputTokens in usageEvent: %d", totalUsage.OutputTokens) |
|
|
} |
|
|
if totalTokens, ok := event["totalTokens"].(float64); ok { |
|
|
totalUsage.TotalTokens = int64(totalTokens) |
|
|
log.Debugf("kiro: streamToChannel found totalTokens in usageEvent: %d", totalUsage.TotalTokens) |
|
|
} |
|
|
|
|
|
if usageObj, ok := event["usage"].(map[string]interface{}); ok { |
|
|
if inputTokens, ok := usageObj["input_tokens"].(float64); ok { |
|
|
totalUsage.InputTokens = int64(inputTokens) |
|
|
} else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { |
|
|
totalUsage.InputTokens = int64(inputTokens) |
|
|
} |
|
|
if outputTokens, ok := usageObj["output_tokens"].(float64); ok { |
|
|
totalUsage.OutputTokens = int64(outputTokens) |
|
|
} else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { |
|
|
totalUsage.OutputTokens = int64(outputTokens) |
|
|
} |
|
|
if totalTokens, ok := usageObj["total_tokens"].(float64); ok { |
|
|
totalUsage.TotalTokens = int64(totalTokens) |
|
|
} |
|
|
log.Debugf("kiro: streamToChannel found usage object: input=%d, output=%d, total=%d", |
|
|
totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) |
|
|
} |
|
|
|
|
|
case "metricsEvent": |
|
|
|
|
|
if metrics, ok := event["metricsEvent"].(map[string]interface{}); ok { |
|
|
if inputTokens, ok := metrics["inputTokens"].(float64); ok { |
|
|
totalUsage.InputTokens = int64(inputTokens) |
|
|
} |
|
|
if outputTokens, ok := metrics["outputTokens"].(float64); ok { |
|
|
totalUsage.OutputTokens = int64(outputTokens) |
|
|
} |
|
|
log.Debugf("kiro: streamToChannel found metricsEvent: input=%d, output=%d", |
|
|
totalUsage.InputTokens, totalUsage.OutputTokens) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok { |
|
|
if inputTokens, ok := usageEvent["inputTokens"].(float64); ok { |
|
|
totalUsage.InputTokens = int64(inputTokens) |
|
|
} |
|
|
if outputTokens, ok := usageEvent["outputTokens"].(float64); ok { |
|
|
totalUsage.OutputTokens = int64(outputTokens) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if totalUsage.InputTokens == 0 { |
|
|
if inputTokens, ok := event["inputTokens"].(float64); ok { |
|
|
totalUsage.InputTokens = int64(inputTokens) |
|
|
log.Debugf("kiro: streamToChannel found direct inputTokens: %d", totalUsage.InputTokens) |
|
|
} |
|
|
} |
|
|
if totalUsage.OutputTokens == 0 { |
|
|
if outputTokens, ok := event["outputTokens"].(float64); ok { |
|
|
totalUsage.OutputTokens = int64(outputTokens) |
|
|
log.Debugf("kiro: streamToChannel found direct outputTokens: %d", totalUsage.OutputTokens) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if totalUsage.InputTokens == 0 || totalUsage.OutputTokens == 0 { |
|
|
if usageObj, ok := event["usage"].(map[string]interface{}); ok { |
|
|
if totalUsage.InputTokens == 0 { |
|
|
if inputTokens, ok := usageObj["input_tokens"].(float64); ok { |
|
|
totalUsage.InputTokens = int64(inputTokens) |
|
|
} else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { |
|
|
totalUsage.InputTokens = int64(inputTokens) |
|
|
} |
|
|
} |
|
|
if totalUsage.OutputTokens == 0 { |
|
|
if outputTokens, ok := usageObj["output_tokens"].(float64); ok { |
|
|
totalUsage.OutputTokens = int64(outputTokens) |
|
|
} else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { |
|
|
totalUsage.OutputTokens = int64(outputTokens) |
|
|
} |
|
|
} |
|
|
if totalUsage.TotalTokens == 0 { |
|
|
if totalTokens, ok := usageObj["total_tokens"].(float64); ok { |
|
|
totalUsage.TotalTokens = int64(totalTokens) |
|
|
} |
|
|
} |
|
|
log.Debugf("kiro: streamToChannel found usage object (fallback): input=%d, output=%d, total=%d", |
|
|
totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if isTextBlockOpen && contentBlockIndex >= 0 { |
|
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if totalUsage.OutputTokens == 0 && accumulatedContent.Len() > 0 { |
|
|
|
|
|
if enc, err := getTokenizer(model); err == nil { |
|
|
if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil { |
|
|
totalUsage.OutputTokens = int64(tokenCount) |
|
|
log.Debugf("kiro: streamToChannel calculated output tokens using tiktoken: %d", totalUsage.OutputTokens) |
|
|
} else { |
|
|
|
|
|
totalUsage.OutputTokens = int64(accumulatedContent.Len() / 4) |
|
|
if totalUsage.OutputTokens == 0 { |
|
|
totalUsage.OutputTokens = 1 |
|
|
} |
|
|
log.Debugf("kiro: streamToChannel tiktoken count failed, estimated from chars: %d", totalUsage.OutputTokens) |
|
|
} |
|
|
} else { |
|
|
|
|
|
totalUsage.OutputTokens = int64(accumulatedContent.Len() / 4) |
|
|
if totalUsage.OutputTokens == 0 { |
|
|
totalUsage.OutputTokens = 1 |
|
|
} |
|
|
log.Debugf("kiro: streamToChannel estimated output tokens from chars: %d (content len: %d)", totalUsage.OutputTokens, accumulatedContent.Len()) |
|
|
} |
|
|
} else if totalUsage.OutputTokens == 0 && outputLen > 0 { |
|
|
|
|
|
totalUsage.OutputTokens = int64(outputLen / 4) |
|
|
if totalUsage.OutputTokens == 0 { |
|
|
totalUsage.OutputTokens = 1 |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if upstreamContextPercentage > 0 { |
|
|
|
|
|
|
|
|
calculatedInputTokens := int64(upstreamContextPercentage * 200000 / 100) |
|
|
|
|
|
|
|
|
|
|
|
if calculatedInputTokens > 0 { |
|
|
localEstimate := totalUsage.InputTokens |
|
|
totalUsage.InputTokens = calculatedInputTokens |
|
|
log.Debugf("kiro: using contextUsagePercentage (%.2f%%) to calculate input tokens: %d (local estimate was: %d)", |
|
|
upstreamContextPercentage, calculatedInputTokens, localEstimate) |
|
|
} |
|
|
} |
|
|
|
|
|
totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens |
|
|
|
|
|
|
|
|
if hasUpstreamUsage { |
|
|
log.Debugf("kiro: upstream usage - credits: %.4f, context: %.2f%%, final tokens - input: %d, output: %d, total: %d", |
|
|
upstreamCreditUsage, upstreamContextPercentage, |
|
|
totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) |
|
|
} |
|
|
|
|
|
|
|
|
stopReason := upstreamStopReason |
|
|
if stopReason == "" { |
|
|
if hasToolUses { |
|
|
stopReason = "tool_use" |
|
|
log.Debugf("kiro: streamToChannel using fallback stop_reason: tool_use") |
|
|
} else { |
|
|
stopReason = "end_turn" |
|
|
log.Debugf("kiro: streamToChannel using fallback stop_reason: end_turn") |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if stopReason == "max_tokens" { |
|
|
log.Warnf("kiro: response truncated due to max_tokens limit (streamToChannel)") |
|
|
} |
|
|
|
|
|
|
|
|
msgDelta := kiroclaude.BuildClaudeMessageDeltaEvent(stopReason, totalUsage) |
|
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
msgStop := kiroclaude.BuildClaudeMessageStopOnlyEvent() |
|
|
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) |
|
|
for _, chunk := range sseData { |
|
|
if chunk != "" { |
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} |
|
|
} |
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (e *KiroExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { |
|
|
|
|
|
enc, err := getTokenizer(req.Model) |
|
|
if err != nil { |
|
|
log.Warnf("kiro: CountTokens failed to get tokenizer: %v, falling back to estimate", err) |
|
|
|
|
|
estimatedTokens := len(req.Payload) / 4 |
|
|
if estimatedTokens == 0 && len(req.Payload) > 0 { |
|
|
estimatedTokens = 1 |
|
|
} |
|
|
return cliproxyexecutor.Response{ |
|
|
Payload: []byte(fmt.Sprintf(`{"count":%d}`, estimatedTokens)), |
|
|
}, nil |
|
|
} |
|
|
|
|
|
|
|
|
var totalTokens int64 |
|
|
|
|
|
|
|
|
if tokens, countErr := countOpenAIChatTokens(enc, req.Payload); countErr == nil && tokens > 0 { |
|
|
totalTokens = tokens |
|
|
log.Debugf("kiro: CountTokens counted %d tokens using OpenAI chat format", totalTokens) |
|
|
} else { |
|
|
|
|
|
if tokenCount, countErr := enc.Count(string(req.Payload)); countErr == nil { |
|
|
totalTokens = int64(tokenCount) |
|
|
log.Debugf("kiro: CountTokens counted %d tokens from raw payload", totalTokens) |
|
|
} else { |
|
|
|
|
|
totalTokens = int64(len(req.Payload) / 4) |
|
|
if totalTokens == 0 && len(req.Payload) > 0 { |
|
|
totalTokens = 1 |
|
|
} |
|
|
log.Debugf("kiro: CountTokens estimated %d tokens from payload size", totalTokens) |
|
|
} |
|
|
} |
|
|
|
|
|
return cliproxyexecutor.Response{ |
|
|
Payload: []byte(fmt.Sprintf(`{"count":%d}`, totalTokens)), |
|
|
}, nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { |
|
|
|
|
|
e.refreshMu.Lock() |
|
|
defer e.refreshMu.Unlock() |
|
|
|
|
|
var authID string |
|
|
if auth != nil { |
|
|
authID = auth.ID |
|
|
} else { |
|
|
authID = "<nil>" |
|
|
} |
|
|
log.Debugf("kiro executor: refresh called for auth %s", authID) |
|
|
if auth == nil { |
|
|
return nil, fmt.Errorf("kiro executor: auth is nil") |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if auth.Metadata != nil { |
|
|
if lastRefresh, ok := auth.Metadata["last_refresh"].(string); ok { |
|
|
if refreshTime, err := time.Parse(time.RFC3339, lastRefresh); err == nil { |
|
|
|
|
|
if time.Since(refreshTime) < 30*time.Second { |
|
|
log.Debugf("kiro executor: token was recently refreshed by another goroutine, skipping") |
|
|
return auth, nil |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
if expiresAt, ok := auth.Metadata["expires_at"].(string); ok { |
|
|
if expTime, err := time.Parse(time.RFC3339, expiresAt); err == nil { |
|
|
|
|
|
if time.Until(expTime) > 5*time.Minute { |
|
|
log.Debugf("kiro executor: token is still valid (expires in %v), skipping refresh", time.Until(expTime)) |
|
|
|
|
|
|
|
|
updated := auth.Clone() |
|
|
|
|
|
nextRefresh := expTime.Add(-5 * time.Minute) |
|
|
minNextRefresh := time.Now().Add(30 * time.Second) |
|
|
if nextRefresh.Before(minNextRefresh) { |
|
|
nextRefresh = minNextRefresh |
|
|
} |
|
|
updated.NextRefreshAfter = nextRefresh |
|
|
log.Debugf("kiro executor: setting NextRefreshAfter to %v (in %v)", nextRefresh.Format(time.RFC3339), time.Until(nextRefresh)) |
|
|
return updated, nil |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
var refreshToken string |
|
|
var clientID, clientSecret string |
|
|
var authMethod string |
|
|
var region, startURL string |
|
|
|
|
|
if auth.Metadata != nil { |
|
|
if rt, ok := auth.Metadata["refresh_token"].(string); ok { |
|
|
refreshToken = rt |
|
|
} |
|
|
if cid, ok := auth.Metadata["client_id"].(string); ok { |
|
|
clientID = cid |
|
|
} |
|
|
if cs, ok := auth.Metadata["client_secret"].(string); ok { |
|
|
clientSecret = cs |
|
|
} |
|
|
if am, ok := auth.Metadata["auth_method"].(string); ok { |
|
|
authMethod = am |
|
|
} |
|
|
if r, ok := auth.Metadata["region"].(string); ok { |
|
|
region = r |
|
|
} |
|
|
if su, ok := auth.Metadata["start_url"].(string); ok { |
|
|
startURL = su |
|
|
} |
|
|
} |
|
|
|
|
|
if refreshToken == "" { |
|
|
return nil, fmt.Errorf("kiro executor: refresh token not found") |
|
|
} |
|
|
|
|
|
var tokenData *kiroauth.KiroTokenData |
|
|
var err error |
|
|
|
|
|
ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) |
|
|
|
|
|
|
|
|
switch { |
|
|
case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "": |
|
|
|
|
|
log.Debugf("kiro executor: using SSO OIDC refresh for IDC (region=%s)", region) |
|
|
tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL) |
|
|
case clientID != "" && clientSecret != "" && authMethod == "builder-id": |
|
|
|
|
|
log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID") |
|
|
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) |
|
|
default: |
|
|
|
|
|
log.Debugf("kiro executor: using Kiro OAuth refresh endpoint") |
|
|
oauth := kiroauth.NewKiroOAuth(e.cfg) |
|
|
tokenData, err = oauth.RefreshToken(ctx, refreshToken) |
|
|
} |
|
|
|
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("kiro executor: token refresh failed: %w", err) |
|
|
} |
|
|
|
|
|
updated := auth.Clone() |
|
|
now := time.Now() |
|
|
updated.UpdatedAt = now |
|
|
updated.LastRefreshedAt = now |
|
|
|
|
|
if updated.Metadata == nil { |
|
|
updated.Metadata = make(map[string]any) |
|
|
} |
|
|
updated.Metadata["access_token"] = tokenData.AccessToken |
|
|
updated.Metadata["refresh_token"] = tokenData.RefreshToken |
|
|
updated.Metadata["expires_at"] = tokenData.ExpiresAt |
|
|
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) |
|
|
if tokenData.ProfileArn != "" { |
|
|
updated.Metadata["profile_arn"] = tokenData.ProfileArn |
|
|
} |
|
|
if tokenData.AuthMethod != "" { |
|
|
updated.Metadata["auth_method"] = tokenData.AuthMethod |
|
|
} |
|
|
if tokenData.Provider != "" { |
|
|
updated.Metadata["provider"] = tokenData.Provider |
|
|
} |
|
|
|
|
|
if tokenData.ClientID != "" { |
|
|
updated.Metadata["client_id"] = tokenData.ClientID |
|
|
} |
|
|
if tokenData.ClientSecret != "" { |
|
|
updated.Metadata["client_secret"] = tokenData.ClientSecret |
|
|
} |
|
|
|
|
|
if updated.Attributes == nil { |
|
|
updated.Attributes = make(map[string]string) |
|
|
} |
|
|
updated.Attributes["access_token"] = tokenData.AccessToken |
|
|
if tokenData.ProfileArn != "" { |
|
|
updated.Attributes["profile_arn"] = tokenData.ProfileArn |
|
|
} |
|
|
|
|
|
|
|
|
if expiresAt, parseErr := time.Parse(time.RFC3339, tokenData.ExpiresAt); parseErr == nil { |
|
|
updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute) |
|
|
} |
|
|
|
|
|
log.Infof("kiro executor: token refreshed successfully, expires at %s", tokenData.ExpiresAt) |
|
|
return updated, nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error { |
|
|
if auth == nil || auth.Metadata == nil { |
|
|
return fmt.Errorf("kiro executor: cannot persist nil auth or metadata") |
|
|
} |
|
|
|
|
|
|
|
|
var authPath string |
|
|
if auth.Attributes != nil { |
|
|
if p := strings.TrimSpace(auth.Attributes["path"]); p != "" { |
|
|
authPath = p |
|
|
} |
|
|
} |
|
|
if authPath == "" { |
|
|
fileName := strings.TrimSpace(auth.FileName) |
|
|
if fileName == "" { |
|
|
return fmt.Errorf("kiro executor: auth has no file path or filename") |
|
|
} |
|
|
if filepath.IsAbs(fileName) { |
|
|
authPath = fileName |
|
|
} else if e.cfg != nil && e.cfg.AuthDir != "" { |
|
|
authPath = filepath.Join(e.cfg.AuthDir, fileName) |
|
|
} else { |
|
|
return fmt.Errorf("kiro executor: cannot determine auth file path") |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
raw, err := json.Marshal(auth.Metadata) |
|
|
if err != nil { |
|
|
return fmt.Errorf("kiro executor: marshal metadata failed: %w", err) |
|
|
} |
|
|
|
|
|
|
|
|
tmp := authPath + ".tmp" |
|
|
if err := os.WriteFile(tmp, raw, 0o600); err != nil { |
|
|
return fmt.Errorf("kiro executor: write temp auth file failed: %w", err) |
|
|
} |
|
|
if err := os.Rename(tmp, authPath); err != nil { |
|
|
return fmt.Errorf("kiro executor: rename auth file failed: %w", err) |
|
|
} |
|
|
|
|
|
log.Debugf("kiro executor: persisted refreshed auth to %s", authPath) |
|
|
return nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (e *KiroExecutor) isTokenExpired(accessToken string) bool { |
|
|
if accessToken == "" { |
|
|
return true |
|
|
} |
|
|
|
|
|
|
|
|
parts := strings.Split(accessToken, ".") |
|
|
if len(parts) != 3 { |
|
|
|
|
|
return false |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
payload := parts[1] |
|
|
decoded, err := base64.RawURLEncoding.DecodeString(payload) |
|
|
if err != nil { |
|
|
|
|
|
switch len(payload) % 4 { |
|
|
case 2: |
|
|
payload += "==" |
|
|
case 3: |
|
|
payload += "=" |
|
|
} |
|
|
decoded, err = base64.URLEncoding.DecodeString(payload) |
|
|
if err != nil { |
|
|
log.Debugf("kiro: failed to decode JWT payload: %v", err) |
|
|
return false |
|
|
} |
|
|
} |
|
|
|
|
|
var claims struct { |
|
|
Exp int64 `json:"exp"` |
|
|
} |
|
|
if err := json.Unmarshal(decoded, &claims); err != nil { |
|
|
log.Debugf("kiro: failed to parse JWT claims: %v", err) |
|
|
return false |
|
|
} |
|
|
|
|
|
if claims.Exp == 0 { |
|
|
|
|
|
return false |
|
|
} |
|
|
|
|
|
expTime := time.Unix(claims.Exp, 0) |
|
|
now := time.Now() |
|
|
|
|
|
|
|
|
isExpired := now.After(expTime) || expTime.Sub(now) < time.Minute |
|
|
if isExpired { |
|
|
log.Debugf("kiro: token expired at %s (now: %s)", expTime.Format(time.RFC3339), now.Format(time.RFC3339)) |
|
|
} |
|
|
|
|
|
return isExpired |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|