|
|
package amp |
|
|
|
|
|
import ( |
|
|
"bytes" |
|
|
"io" |
|
|
"net/http/httputil" |
|
|
"strings" |
|
|
"time" |
|
|
|
|
|
"github.com/gin-gonic/gin" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" |
|
|
log "github.com/sirupsen/logrus" |
|
|
"github.com/tidwall/gjson" |
|
|
"github.com/tidwall/sjson" |
|
|
) |
|
|
|
|
|
|
|
|
// AmpRouteType names the routing decision that is recorded for an Amp request.
type AmpRouteType string

// Routing decisions understood by logAmpRouting.
const (
	// RouteTypeLocalProvider marks a request served by a local provider
	// without any model mapping.
	RouteTypeLocalProvider AmpRouteType = "LOCAL_PROVIDER"

	// RouteTypeModelMapping marks a request whose model was rewritten to a
	// mapped model before being served locally.
	RouteTypeModelMapping AmpRouteType = "MODEL_MAPPING"

	// RouteTypeAmpCredits marks a request forwarded to ampcode.com, which
	// consumes Amp credits.
	RouteTypeAmpCredits AmpRouteType = "AMP_CREDITS"

	// RouteTypeNoProvider marks a request for which neither a local provider
	// nor an upstream proxy was available.
	RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
)
|
|
|
|
|
|
|
|
// MappedModelContextKey is the gin context key under which the mapped model
// name is stored once a model mapping has been applied to a request.
const MappedModelContextKey = "mapped_model"
|
|
|
|
|
|
|
|
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) { |
|
|
fields := log.Fields{ |
|
|
"component": "amp-routing", |
|
|
"route_type": string(routeType), |
|
|
"requested_model": requestedModel, |
|
|
"path": path, |
|
|
"timestamp": time.Now().Format(time.RFC3339), |
|
|
} |
|
|
|
|
|
if resolvedModel != "" && resolvedModel != requestedModel { |
|
|
fields["resolved_model"] = resolvedModel |
|
|
} |
|
|
if provider != "" { |
|
|
fields["provider"] = provider |
|
|
} |
|
|
|
|
|
switch routeType { |
|
|
case RouteTypeLocalProvider: |
|
|
fields["cost"] = "free" |
|
|
fields["source"] = "local_oauth" |
|
|
log.WithFields(fields).Debugf("amp using local provider for model: %s", requestedModel) |
|
|
|
|
|
case RouteTypeModelMapping: |
|
|
fields["cost"] = "free" |
|
|
fields["source"] = "local_oauth" |
|
|
fields["mapping"] = requestedModel + " -> " + resolvedModel |
|
|
|
|
|
|
|
|
case RouteTypeAmpCredits: |
|
|
fields["cost"] = "amp_credits" |
|
|
fields["source"] = "ampcode.com" |
|
|
fields["model_id"] = requestedModel |
|
|
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local provider, add to config: ampcode.model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel) |
|
|
|
|
|
case RouteTypeNoProvider: |
|
|
fields["cost"] = "none" |
|
|
fields["source"] = "error" |
|
|
fields["model_id"] = requestedModel |
|
|
log.WithFields(fields).Warnf("no provider available for model_id: %s", requestedModel) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
type FallbackHandler struct { |
|
|
getProxy func() *httputil.ReverseProxy |
|
|
modelMapper ModelMapper |
|
|
forceModelMappings func() bool |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler { |
|
|
return &FallbackHandler{ |
|
|
getProxy: getProxy, |
|
|
forceModelMappings: func() bool { return false }, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler { |
|
|
if forceModelMappings == nil { |
|
|
forceModelMappings = func() bool { return false } |
|
|
} |
|
|
return &FallbackHandler{ |
|
|
getProxy: getProxy, |
|
|
modelMapper: mapper, |
|
|
forceModelMappings: forceModelMappings, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) { |
|
|
fh.modelMapper = mapper |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc { |
|
|
return func(c *gin.Context) { |
|
|
requestPath := c.Request.URL.Path |
|
|
|
|
|
|
|
|
bodyBytes, err := io.ReadAll(c.Request.Body) |
|
|
if err != nil { |
|
|
log.Errorf("amp fallback: failed to read request body: %v", err) |
|
|
handler(c) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) |
|
|
|
|
|
|
|
|
modelName := extractModelFromRequest(bodyBytes, c) |
|
|
if modelName == "" { |
|
|
|
|
|
handler(c) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
normalizedModel, thinkingMetadata := util.NormalizeThinkingModel(modelName) |
|
|
thinkingSuffix := "" |
|
|
if thinkingMetadata != nil && strings.HasPrefix(modelName, normalizedModel) { |
|
|
thinkingSuffix = modelName[len(normalizedModel):] |
|
|
} |
|
|
|
|
|
resolveMappedModel := func() (string, []string) { |
|
|
if fh.modelMapper == nil { |
|
|
return "", nil |
|
|
} |
|
|
|
|
|
mappedModel := fh.modelMapper.MapModel(modelName) |
|
|
if mappedModel == "" { |
|
|
mappedModel = fh.modelMapper.MapModel(normalizedModel) |
|
|
} |
|
|
mappedModel = strings.TrimSpace(mappedModel) |
|
|
if mappedModel == "" { |
|
|
return "", nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if thinkingSuffix != "" { |
|
|
_, mappedThinkingMetadata := util.NormalizeThinkingModel(mappedModel) |
|
|
if mappedThinkingMetadata == nil { |
|
|
mappedModel += thinkingSuffix |
|
|
} |
|
|
} |
|
|
|
|
|
mappedBaseModel, _ := util.NormalizeThinkingModel(mappedModel) |
|
|
mappedProviders := util.GetProviderName(mappedBaseModel) |
|
|
if len(mappedProviders) == 0 { |
|
|
return "", nil |
|
|
} |
|
|
|
|
|
return mappedModel, mappedProviders |
|
|
} |
|
|
|
|
|
|
|
|
resolvedModel := normalizedModel |
|
|
usedMapping := false |
|
|
var providers []string |
|
|
|
|
|
|
|
|
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings() |
|
|
|
|
|
if forceMappings { |
|
|
|
|
|
|
|
|
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" { |
|
|
|
|
|
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) |
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) |
|
|
|
|
|
c.Set(MappedModelContextKey, mappedModel) |
|
|
resolvedModel = mappedModel |
|
|
usedMapping = true |
|
|
providers = mappedProviders |
|
|
} |
|
|
|
|
|
|
|
|
if !usedMapping { |
|
|
providers = util.GetProviderName(normalizedModel) |
|
|
} |
|
|
} else { |
|
|
|
|
|
providers = util.GetProviderName(normalizedModel) |
|
|
|
|
|
if len(providers) == 0 { |
|
|
|
|
|
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" { |
|
|
|
|
|
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) |
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) |
|
|
|
|
|
c.Set(MappedModelContextKey, mappedModel) |
|
|
resolvedModel = mappedModel |
|
|
usedMapping = true |
|
|
providers = mappedProviders |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if len(providers) == 0 { |
|
|
proxy := fh.getProxy() |
|
|
if proxy != nil { |
|
|
|
|
|
logAmpRouting(RouteTypeAmpCredits, modelName, "", "", requestPath) |
|
|
|
|
|
|
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) |
|
|
|
|
|
|
|
|
proxy.ServeHTTP(c.Writer, c.Request) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath) |
|
|
} |
|
|
|
|
|
|
|
|
providerName := "" |
|
|
if len(providers) > 0 { |
|
|
providerName = providers[0] |
|
|
} |
|
|
|
|
|
if usedMapping { |
|
|
|
|
|
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel) |
|
|
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath) |
|
|
rewriter := NewResponseRewriter(c.Writer, modelName) |
|
|
c.Writer = rewriter |
|
|
|
|
|
filterAntropicBetaHeader(c) |
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) |
|
|
handler(c) |
|
|
rewriter.Flush() |
|
|
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName) |
|
|
} else if len(providers) > 0 { |
|
|
|
|
|
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath) |
|
|
|
|
|
filterAntropicBetaHeader(c) |
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) |
|
|
handler(c) |
|
|
} else { |
|
|
|
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) |
|
|
handler(c) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func filterAntropicBetaHeader(c *gin.Context) { |
|
|
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" { |
|
|
if filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07"); filtered != "" { |
|
|
c.Request.Header.Set("Anthropic-Beta", filtered) |
|
|
} else { |
|
|
c.Request.Header.Del("Anthropic-Beta") |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func rewriteModelInRequest(body []byte, newModel string) []byte { |
|
|
if !gjson.GetBytes(body, "model").Exists() { |
|
|
return body |
|
|
} |
|
|
result, err := sjson.SetBytes(body, "model", newModel) |
|
|
if err != nil { |
|
|
log.Warnf("amp model mapping: failed to rewrite model in request body: %v", err) |
|
|
return body |
|
|
} |
|
|
return result |
|
|
} |
|
|
|
|
|
|
|
|
func extractModelFromRequest(body []byte, c *gin.Context) string { |
|
|
|
|
|
|
|
|
if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String { |
|
|
return result.String() |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if action := c.Param("action"); action != "" { |
|
|
|
|
|
parts := strings.Split(action, ":") |
|
|
if len(parts) > 0 && parts[0] != "" { |
|
|
return parts[0] |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if path := c.Param("path"); path != "" { |
|
|
|
|
|
if idx := strings.Index(path, "/models/"); idx >= 0 { |
|
|
modelPart := path[idx+8:] |
|
|
|
|
|
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 { |
|
|
return modelPart[:colonIdx] |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
return "" |
|
|
} |
|
|
|