|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
package gemini |
|
|
|
|
|
import ( |
|
|
"bytes" |
|
|
"fmt" |
|
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" |
|
|
log "github.com/sirupsen/logrus" |
|
|
"github.com/tidwall/gjson" |
|
|
"github.com/tidwall/sjson" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []byte { |
|
|
rawJSON := bytes.Clone(inputRawJSON) |
|
|
template := "" |
|
|
template = `{"project":"","request":{},"model":""}` |
|
|
template, _ = sjson.SetRaw(template, "request", string(rawJSON)) |
|
|
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String()) |
|
|
template, _ = sjson.Delete(template, "request.model") |
|
|
|
|
|
template, errFixCLIToolResponse := fixCLIToolResponse(template) |
|
|
if errFixCLIToolResponse != nil { |
|
|
return []byte{} |
|
|
} |
|
|
|
|
|
systemInstructionResult := gjson.Get(template, "request.system_instruction") |
|
|
if systemInstructionResult.Exists() { |
|
|
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw) |
|
|
template, _ = sjson.Delete(template, "request.system_instruction") |
|
|
} |
|
|
rawJSON = []byte(template) |
|
|
|
|
|
|
|
|
contents := gjson.GetBytes(rawJSON, "request.contents") |
|
|
if contents.Exists() { |
|
|
prevRole := "" |
|
|
idx := 0 |
|
|
contents.ForEach(func(_ gjson.Result, value gjson.Result) bool { |
|
|
role := value.Get("role").String() |
|
|
valid := role == "user" || role == "model" |
|
|
if role == "" || !valid { |
|
|
var newRole string |
|
|
if prevRole == "" { |
|
|
newRole = "user" |
|
|
} else if prevRole == "user" { |
|
|
newRole = "model" |
|
|
} else { |
|
|
newRole = "user" |
|
|
} |
|
|
path := fmt.Sprintf("request.contents.%d.role", idx) |
|
|
rawJSON, _ = sjson.SetBytes(rawJSON, path, newRole) |
|
|
role = newRole |
|
|
} |
|
|
prevRole = role |
|
|
idx++ |
|
|
return true |
|
|
}) |
|
|
} |
|
|
|
|
|
toolsResult := gjson.GetBytes(rawJSON, "request.tools") |
|
|
if toolsResult.Exists() && toolsResult.IsArray() { |
|
|
toolResults := toolsResult.Array() |
|
|
for i := 0; i < len(toolResults); i++ { |
|
|
functionDeclarationsResult := gjson.GetBytes(rawJSON, fmt.Sprintf("request.tools.%d.function_declarations", i)) |
|
|
if functionDeclarationsResult.Exists() && functionDeclarationsResult.IsArray() { |
|
|
functionDeclarationsResults := functionDeclarationsResult.Array() |
|
|
for j := 0; j < len(functionDeclarationsResults); j++ { |
|
|
parametersResult := gjson.GetBytes(rawJSON, fmt.Sprintf("request.tools.%d.function_declarations.%d.parameters", i, j)) |
|
|
if parametersResult.Exists() { |
|
|
strJson, _ := util.RenameKey(string(rawJSON), fmt.Sprintf("request.tools.%d.function_declarations.%d.parameters", i, j), fmt.Sprintf("request.tools.%d.function_declarations.%d.parametersJsonSchema", i, j)) |
|
|
rawJSON = []byte(strJson) |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(key, content gjson.Result) bool { |
|
|
if content.Get("role").String() == "model" { |
|
|
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool { |
|
|
if part.Get("functionCall").Exists() { |
|
|
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator") |
|
|
} else if part.Get("thoughtSignature").Exists() { |
|
|
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator") |
|
|
} |
|
|
return true |
|
|
}) |
|
|
} |
|
|
return true |
|
|
}) |
|
|
|
|
|
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings") |
|
|
} |
|
|
|
|
|
|
|
|
// FunctionCallGroup records a model turn that issued function calls which are
// still waiting for their functionResponse parts. ResponsesNeeded is the
// number of functionCall parts in that turn, i.e. how many functionResponse
// parts must be gathered to close the group.
type FunctionCallGroup struct{ ResponsesNeeded int }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func fixCLIToolResponse(input string) (string, error) { |
|
|
|
|
|
parsed := gjson.Parse(input) |
|
|
|
|
|
|
|
|
contents := parsed.Get("request.contents") |
|
|
if !contents.Exists() { |
|
|
|
|
|
return input, fmt.Errorf("contents not found in input") |
|
|
} |
|
|
|
|
|
|
|
|
contentsWrapper := `{"contents":[]}` |
|
|
var pendingGroups []*FunctionCallGroup |
|
|
var collectedResponses []gjson.Result |
|
|
|
|
|
|
|
|
|
|
|
contents.ForEach(func(key, value gjson.Result) bool { |
|
|
role := value.Get("role").String() |
|
|
parts := value.Get("parts") |
|
|
|
|
|
|
|
|
var responsePartsInThisContent []gjson.Result |
|
|
parts.ForEach(func(_, part gjson.Result) bool { |
|
|
if part.Get("functionResponse").Exists() { |
|
|
responsePartsInThisContent = append(responsePartsInThisContent, part) |
|
|
} |
|
|
return true |
|
|
}) |
|
|
|
|
|
|
|
|
if len(responsePartsInThisContent) > 0 { |
|
|
collectedResponses = append(collectedResponses, responsePartsInThisContent...) |
|
|
|
|
|
|
|
|
for i := len(pendingGroups) - 1; i >= 0; i-- { |
|
|
group := pendingGroups[i] |
|
|
if len(collectedResponses) >= group.ResponsesNeeded { |
|
|
|
|
|
groupResponses := collectedResponses[:group.ResponsesNeeded] |
|
|
collectedResponses = collectedResponses[group.ResponsesNeeded:] |
|
|
|
|
|
|
|
|
functionResponseContent := `{"parts":[],"role":"function"}` |
|
|
for _, response := range groupResponses { |
|
|
if !response.IsObject() { |
|
|
log.Warnf("failed to parse function response") |
|
|
continue |
|
|
} |
|
|
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw) |
|
|
} |
|
|
|
|
|
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 { |
|
|
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent) |
|
|
} |
|
|
|
|
|
|
|
|
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...) |
|
|
break |
|
|
} |
|
|
} |
|
|
|
|
|
return true |
|
|
} |
|
|
|
|
|
|
|
|
if role == "model" { |
|
|
functionCallsCount := 0 |
|
|
parts.ForEach(func(_, part gjson.Result) bool { |
|
|
if part.Get("functionCall").Exists() { |
|
|
functionCallsCount++ |
|
|
} |
|
|
return true |
|
|
}) |
|
|
|
|
|
if functionCallsCount > 0 { |
|
|
|
|
|
if !value.IsObject() { |
|
|
log.Warnf("failed to parse model content") |
|
|
return true |
|
|
} |
|
|
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw) |
|
|
|
|
|
|
|
|
group := &FunctionCallGroup{ |
|
|
ResponsesNeeded: functionCallsCount, |
|
|
} |
|
|
pendingGroups = append(pendingGroups, group) |
|
|
} else { |
|
|
|
|
|
if !value.IsObject() { |
|
|
log.Warnf("failed to parse content") |
|
|
return true |
|
|
} |
|
|
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw) |
|
|
} |
|
|
} else { |
|
|
|
|
|
if !value.IsObject() { |
|
|
log.Warnf("failed to parse content") |
|
|
return true |
|
|
} |
|
|
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw) |
|
|
} |
|
|
|
|
|
return true |
|
|
}) |
|
|
|
|
|
|
|
|
for _, group := range pendingGroups { |
|
|
if len(collectedResponses) >= group.ResponsesNeeded { |
|
|
groupResponses := collectedResponses[:group.ResponsesNeeded] |
|
|
collectedResponses = collectedResponses[group.ResponsesNeeded:] |
|
|
|
|
|
functionResponseContent := `{"parts":[],"role":"function"}` |
|
|
for _, response := range groupResponses { |
|
|
if !response.IsObject() { |
|
|
log.Warnf("failed to parse function response") |
|
|
continue |
|
|
} |
|
|
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw) |
|
|
} |
|
|
|
|
|
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 { |
|
|
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
result := input |
|
|
result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw) |
|
|
|
|
|
return result, nil |
|
|
} |
|
|
|