| package vertex |
|
|
| import ( |
| "crypto/rsa" |
| "crypto/x509" |
| "encoding/base64" |
| "encoding/json" |
| "encoding/pem" |
| "fmt" |
| "strings" |
| ) |
|
|
| |
| |
| |
| func NormalizeServiceAccountJSON(raw []byte) ([]byte, error) { |
| if len(raw) == 0 { |
| return raw, nil |
| } |
| var payload map[string]any |
| if err := json.Unmarshal(raw, &payload); err != nil { |
| return raw, err |
| } |
| normalized, err := NormalizeServiceAccountMap(payload) |
| if err != nil { |
| return raw, err |
| } |
| out, err := json.Marshal(normalized) |
| if err != nil { |
| return raw, err |
| } |
| return out, nil |
| } |
|
|
| |
| |
| func NormalizeServiceAccountMap(sa map[string]any) (map[string]any, error) { |
| if sa == nil { |
| return nil, fmt.Errorf("service account payload is empty") |
| } |
| pk, _ := sa["private_key"].(string) |
| if strings.TrimSpace(pk) == "" { |
| return nil, fmt.Errorf("service account missing private_key") |
| } |
| normalized, err := sanitizePrivateKey(pk) |
| if err != nil { |
| return nil, err |
| } |
| clone := make(map[string]any, len(sa)) |
| for k, v := range sa { |
| clone[k] = v |
| } |
| clone["private_key"] = normalized |
| return clone, nil |
| } |
|
|
| func sanitizePrivateKey(raw string) (string, error) { |
| pk := strings.ReplaceAll(raw, "\r\n", "\n") |
| pk = strings.ReplaceAll(pk, "\r", "\n") |
| pk = stripANSIEscape(pk) |
| pk = strings.ToValidUTF8(pk, "") |
| pk = strings.TrimSpace(pk) |
|
|
| normalized := pk |
| if block, _ := pem.Decode([]byte(pk)); block == nil { |
| |
| if reconstructed, err := rebuildPEM(pk); err == nil { |
| normalized = reconstructed |
| } else { |
| return "", fmt.Errorf("private_key is not valid pem: %w", err) |
| } |
| } |
|
|
| block, _ := pem.Decode([]byte(normalized)) |
| if block == nil { |
| return "", fmt.Errorf("private_key pem decode failed") |
| } |
|
|
| rsaBlock, err := ensureRSAPrivateKey(block) |
| if err != nil { |
| return "", err |
| } |
| return string(pem.EncodeToMemory(rsaBlock)), nil |
| } |
|
|
// ensureRSAPrivateKey checks that block holds an RSA private key. It returns
// that key as a PKCS#1 "RSA PRIVATE KEY" block, handling each input type as
// follows:
//   - "RSA PRIVATE KEY": the bytes must parse as PKCS#1. The original block
//     pointer is returned unchanged, including any headers it carries.
//   - "PRIVATE KEY": the bytes must parse as PKCS#8 and hold an
//     *rsa.PrivateKey. The key is re-marshalled as PKCS#1 into a new block.
//   - Any other type: the bytes are tried as PKCS#1, then as PKCS#8 RSA. On
//     success the key is re-marshalled as PKCS#1 into a new block.
//
// An error is returned for a nil block, unparsable bytes, a non-RSA key, or
// an unrecognized encoding.
func ensureRSAPrivateKey(block *pem.Block) (*pem.Block, error) {
	if block == nil {
		return nil, fmt.Errorf("pem block is nil")
	}

	toPKCS1 := func(key *rsa.PrivateKey) *pem.Block {
		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}
	}

	switch block.Type {
	case "RSA PRIVATE KEY":
		if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
			return nil, fmt.Errorf("private_key invalid rsa: %w", err)
		}
		return block, nil
	case "PRIVATE KEY":
		parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
		if err != nil {
			return nil, fmt.Errorf("private_key invalid pkcs8: %w", err)
		}
		rsaKey, isRSA := parsed.(*rsa.PrivateKey)
		if !isRSA {
			return nil, fmt.Errorf("private_key is not an RSA key")
		}
		return toPKCS1(rsaKey), nil
	}

	// Unknown or mislabelled block type: sniff the DER encoding instead.
	if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
		return toPKCS1(rsaKey), nil
	}
	if parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
		if rsaKey, isRSA := parsed.(*rsa.PrivateKey); isRSA {
			return toPKCS1(rsaKey), nil
		}
	}
	return nil, fmt.Errorf("private_key uses unsupported format")
}
|
|
| func rebuildPEM(raw string) (string, error) { |
| kind := "PRIVATE KEY" |
| if strings.Contains(raw, "RSA PRIVATE KEY") { |
| kind = "RSA PRIVATE KEY" |
| } |
| header := "-----BEGIN " + kind + "-----" |
| footer := "-----END " + kind + "-----" |
| start := strings.Index(raw, header) |
| end := strings.Index(raw, footer) |
| if start < 0 || end <= start { |
| return "", fmt.Errorf("missing pem markers") |
| } |
| body := raw[start+len(header) : end] |
| payload := filterBase64(body) |
| if payload == "" { |
| return "", fmt.Errorf("private_key base64 payload empty") |
| } |
| der, err := base64.StdEncoding.DecodeString(payload) |
| if err != nil { |
| return "", fmt.Errorf("private_key base64 decode failed: %w", err) |
| } |
| block := &pem.Block{Type: kind, Bytes: der} |
| return string(pem.EncodeToMemory(block)), nil |
| } |
|
|
// filterBase64 returns s with every character outside the standard base64
// alphabet (A-Z, a-z, 0-9, '+', '/', '=') removed. Whitespace, line breaks,
// URL-safe '-'/'_' and any other characters (including invalid UTF-8) are
// discarded.
func filterBase64(s string) string {
	return strings.Map(func(r rune) rune {
		switch {
		case 'A' <= r && r <= 'Z', 'a' <= r && r <= 'z', '0' <= r && r <= '9':
			return r
		case r == '+', r == '/', r == '=':
			return r
		}
		return -1 // a negative result makes strings.Map drop the rune
	}, s)
}
|
|
// stripANSIEscape removes ANSI/VT terminal escape sequences from s. Such
// sequences end up in keys copied from, or pasted into, a terminal. Examples
// are colour codes (ESC [ 3 1 m) and the bracketed-paste markers
// ESC [ 2 0 0 ~ / ESC [ 2 0 1 ~ that wrap pasted text.
//
// Handled forms:
//   - CSI: ESC '[' followed by parameter/intermediate bytes and ending at the
//     first final byte in 0x40–0x7E ('@'..'~'). The whole sequence is dropped.
//   - OSC: ESC ']' terminated by BEL (0x07) or ST (ESC '\'). The whole
//     sequence is dropped.
//   - Any other ESC: only the ESC itself is dropped, and the following
//     character is kept.
//
// An unterminated CSI or OSC sequence consumes the rest of the input. s is
// processed as runes, so invalid UTF-8 bytes come back as U+FFFD.
func stripANSIEscape(s string) string {
	const esc = 0x1b
	in := []rune(s)
	out := make([]rune, 0, len(in))
	for i := 0; i < len(in); i++ {
		r := in[i]
		if r != esc {
			out = append(out, r)
			continue
		}
		if i+1 >= len(in) {
			continue // trailing lone ESC
		}
		switch in[i+1] {
		case ']': // OSC: skip to BEL or ESC '\'
			i += 2
			for i < len(in) {
				if in[i] == 0x07 {
					break
				}
				if in[i] == esc && i+1 < len(in) && in[i+1] == '\\' {
					i++
					break
				}
				i++
			}
		case '[': // CSI: skip to the final byte
			// Parameter (0x30–0x3F) and intermediate (0x20–0x2F) bytes are
			// skipped, and the sequence ends at the first byte in 0x40–0x7E.
			// Stopping only at letters would miss terminators such as '~'
			// (bracketed paste) and eat the following real text, e.g. the
			// 'B' of "-----BEGIN".
			i += 2
			for i < len(in) && (in[i] < 0x40 || in[i] > 0x7E) {
				i++
			}
		}
		// The loop's i++ steps past the terminator. For any other ESC nothing
		// was skipped, so the character after ESC is kept on the next pass.
	}
	return string(out)
}
|
|