| package translator |
|
|
| import ( |
| "context" |
| "sync" |
| ) |
|
|
| |
| type Registry struct { |
| mu sync.RWMutex |
| requests map[Format]map[Format]RequestTransform |
| responses map[Format]map[Format]ResponseTransform |
| } |
|
|
| |
| func NewRegistry() *Registry { |
| return &Registry{ |
| requests: make(map[Format]map[Format]RequestTransform), |
| responses: make(map[Format]map[Format]ResponseTransform), |
| } |
| } |
|
|
| |
| func (r *Registry) Register(from, to Format, request RequestTransform, response ResponseTransform) { |
| r.mu.Lock() |
| defer r.mu.Unlock() |
|
|
| if _, ok := r.requests[from]; !ok { |
| r.requests[from] = make(map[Format]RequestTransform) |
| } |
| if request != nil { |
| r.requests[from][to] = request |
| } |
|
|
| if _, ok := r.responses[from]; !ok { |
| r.responses[from] = make(map[Format]ResponseTransform) |
| } |
| r.responses[from][to] = response |
| } |
|
|
| |
| |
| func (r *Registry) TranslateRequest(from, to Format, model string, rawJSON []byte, stream bool) []byte { |
| r.mu.RLock() |
| defer r.mu.RUnlock() |
|
|
| if byTarget, ok := r.requests[from]; ok { |
| if fn, isOk := byTarget[to]; isOk && fn != nil { |
| return fn(model, rawJSON, stream) |
| } |
| } |
| return rawJSON |
| } |
|
|
| |
| func (r *Registry) HasResponseTransformer(from, to Format) bool { |
| r.mu.RLock() |
| defer r.mu.RUnlock() |
|
|
| if byTarget, ok := r.responses[from]; ok { |
| if _, isOk := byTarget[to]; isOk { |
| return true |
| } |
| } |
| return false |
| } |
|
|
| |
| func (r *Registry) TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { |
| r.mu.RLock() |
| defer r.mu.RUnlock() |
|
|
| if byTarget, ok := r.responses[to]; ok { |
| if fn, isOk := byTarget[from]; isOk && fn.Stream != nil { |
| return fn.Stream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param) |
| } |
| } |
| return []string{string(rawJSON)} |
| } |
|
|
| |
| func (r *Registry) TranslateNonStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string { |
| r.mu.RLock() |
| defer r.mu.RUnlock() |
|
|
| if byTarget, ok := r.responses[to]; ok { |
| if fn, isOk := byTarget[from]; isOk && fn.NonStream != nil { |
| return fn.NonStream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param) |
| } |
| } |
| return string(rawJSON) |
| } |
|
|
| |
| func (r *Registry) TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) string { |
| r.mu.RLock() |
| defer r.mu.RUnlock() |
|
|
| if byTarget, ok := r.responses[to]; ok { |
| if fn, isOk := byTarget[from]; isOk && fn.TokenCount != nil { |
| return fn.TokenCount(ctx, count) |
| } |
| } |
| return string(rawJSON) |
| } |
|
|
| var defaultRegistry = NewRegistry() |
|
|
| |
| func Default() *Registry { |
| return defaultRegistry |
| } |
|
|
| |
| func Register(from, to Format, request RequestTransform, response ResponseTransform) { |
| defaultRegistry.Register(from, to, request, response) |
| } |
|
|
| |
| func TranslateRequest(from, to Format, model string, rawJSON []byte, stream bool) []byte { |
| return defaultRegistry.TranslateRequest(from, to, model, rawJSON, stream) |
| } |
|
|
| |
| func HasResponseTransformer(from, to Format) bool { |
| return defaultRegistry.HasResponseTransformer(from, to) |
| } |
|
|
| |
| func TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { |
| return defaultRegistry.TranslateStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param) |
| } |
|
|
| |
| func TranslateNonStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string { |
| return defaultRegistry.TranslateNonStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param) |
| } |
|
|
| |
| func TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) string { |
| return defaultRegistry.TranslateTokenCount(ctx, from, to, count, rawJSON) |
| } |
|
|