|
|
|
|
|
|
|
|
package amp |
|
|
|
|
|
import ( |
|
|
"regexp" |
|
|
"strings" |
|
|
"sync" |
|
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" |
|
|
log "github.com/sirupsen/logrus" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
type ModelMapper interface { |
|
|
|
|
|
|
|
|
MapModel(requestedModel string) string |
|
|
|
|
|
|
|
|
UpdateMappings(mappings []config.AmpModelMapping) |
|
|
} |
|
|
|
|
|
|
|
|
type DefaultModelMapper struct { |
|
|
mu sync.RWMutex |
|
|
mappings map[string]string |
|
|
regexps []regexMapping |
|
|
} |
|
|
|
|
|
|
|
|
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper { |
|
|
m := &DefaultModelMapper{ |
|
|
mappings: make(map[string]string), |
|
|
regexps: nil, |
|
|
} |
|
|
m.UpdateMappings(mappings) |
|
|
return m |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (m *DefaultModelMapper) MapModel(requestedModel string) string { |
|
|
if requestedModel == "" { |
|
|
return "" |
|
|
} |
|
|
|
|
|
m.mu.RLock() |
|
|
defer m.mu.RUnlock() |
|
|
|
|
|
|
|
|
normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel)) |
|
|
|
|
|
|
|
|
targetModel, exists := m.mappings[normalizedRequest] |
|
|
if !exists { |
|
|
|
|
|
base, _ := util.NormalizeThinkingModel(requestedModel) |
|
|
for _, rm := range m.regexps { |
|
|
if rm.re.MatchString(requestedModel) || (base != "" && rm.re.MatchString(base)) { |
|
|
targetModel = rm.to |
|
|
exists = true |
|
|
break |
|
|
} |
|
|
} |
|
|
if !exists { |
|
|
return "" |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
normalizedTarget, _ := util.NormalizeThinkingModel(targetModel) |
|
|
providers := util.GetProviderName(normalizedTarget) |
|
|
if len(providers) == 0 { |
|
|
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel) |
|
|
return "" |
|
|
} |
|
|
|
|
|
|
|
|
return targetModel |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) { |
|
|
m.mu.Lock() |
|
|
defer m.mu.Unlock() |
|
|
|
|
|
|
|
|
m.mappings = make(map[string]string, len(mappings)) |
|
|
m.regexps = make([]regexMapping, 0, len(mappings)) |
|
|
|
|
|
for _, mapping := range mappings { |
|
|
from := strings.TrimSpace(mapping.From) |
|
|
to := strings.TrimSpace(mapping.To) |
|
|
|
|
|
if from == "" || to == "" { |
|
|
log.Warnf("amp model mapping: skipping invalid mapping (from=%q, to=%q)", from, to) |
|
|
continue |
|
|
} |
|
|
|
|
|
if mapping.Regex { |
|
|
|
|
|
pattern := "(?i)" + from |
|
|
re, err := regexp.Compile(pattern) |
|
|
if err != nil { |
|
|
log.Warnf("amp model mapping: invalid regex %q: %v", from, err) |
|
|
continue |
|
|
} |
|
|
m.regexps = append(m.regexps, regexMapping{re: re, to: to}) |
|
|
log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to) |
|
|
} else { |
|
|
|
|
|
normalizedFrom := strings.ToLower(from) |
|
|
m.mappings[normalizedFrom] = to |
|
|
log.Debugf("amp model mapping registered: %s -> %s", from, to) |
|
|
} |
|
|
} |
|
|
|
|
|
if len(m.mappings) > 0 { |
|
|
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings)) |
|
|
} |
|
|
if n := len(m.regexps); n > 0 { |
|
|
log.Infof("amp model mapping: loaded %d regex mapping(s)", n) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func (m *DefaultModelMapper) GetMappings() map[string]string { |
|
|
m.mu.RLock() |
|
|
defer m.mu.RUnlock() |
|
|
|
|
|
result := make(map[string]string, len(m.mappings)) |
|
|
for k, v := range m.mappings { |
|
|
result[k] = v |
|
|
} |
|
|
return result |
|
|
} |
|
|
|
|
|
// regexMapping pairs a compiled source-model pattern with the target model
// that matching requests are mapped to.
type regexMapping struct {
	re *regexp.Regexp // compiled source pattern
	to string         // target model name
}
|
|
|