Spaces:
Build error
Build error
| package vertex | |
| import ( | |
| "bytes" | |
| "encoding/base64" | |
| "encoding/json" | |
| "fmt" | |
| "io" | |
| "net/http" | |
| "one-api/model" | |
| "regexp" | |
| "strings" | |
| "github.com/gin-gonic/gin" | |
| "one-api/constant" | |
| "one-api/dto" | |
| "one-api/relay/channel" | |
| vertexcore "one-api/relay/channel/vertex" | |
| relaycommon "one-api/relay/common" | |
| "one-api/service" | |
| ) | |
| // ============================ | |
| // Request / Response structures | |
| // ============================ | |
| type requestPayload struct { | |
| Instances []map[string]any `json:"instances"` | |
| Parameters map[string]any `json:"parameters,omitempty"` | |
| } | |
| type submitResponse struct { | |
| Name string `json:"name"` | |
| } | |
| type operationVideo struct { | |
| MimeType string `json:"mimeType"` | |
| BytesBase64Encoded string `json:"bytesBase64Encoded"` | |
| Encoding string `json:"encoding"` | |
| } | |
| type operationResponse struct { | |
| Name string `json:"name"` | |
| Done bool `json:"done"` | |
| Response struct { | |
| Type string `json:"@type"` | |
| RaiMediaFilteredCount int `json:"raiMediaFilteredCount"` | |
| Videos []operationVideo `json:"videos"` | |
| BytesBase64Encoded string `json:"bytesBase64Encoded"` | |
| Encoding string `json:"encoding"` | |
| Video string `json:"video"` | |
| } `json:"response"` | |
| Error struct { | |
| Message string `json:"message"` | |
| } `json:"error"` | |
| } | |
| // ============================ | |
| // Adaptor implementation | |
| // ============================ | |
| type TaskAdaptor struct { | |
| ChannelType int | |
| apiKey string | |
| baseURL string | |
| } | |
| func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { | |
| a.ChannelType = info.ChannelType | |
| a.baseURL = info.ChannelBaseUrl | |
| a.apiKey = info.ApiKey | |
| } | |
| // ValidateRequestAndSetAction parses body, validates fields and sets default action. | |
| func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { | |
| // Use the standard validation method for TaskSubmitReq | |
| return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate) | |
| } | |
| // BuildRequestURL constructs the upstream URL. | |
| func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { | |
| adc := &vertexcore.Credentials{} | |
| if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { | |
| return "", fmt.Errorf("failed to decode credentials: %w", err) | |
| } | |
| modelName := info.OriginModelName | |
| if modelName == "" { | |
| modelName = "veo-3.0-generate-001" | |
| } | |
| region := vertexcore.GetModelRegion(info.ApiVersion, modelName) | |
| if strings.TrimSpace(region) == "" { | |
| region = "global" | |
| } | |
| if region == "global" { | |
| return fmt.Sprintf( | |
| "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning", | |
| adc.ProjectID, | |
| modelName, | |
| ), nil | |
| } | |
| return fmt.Sprintf( | |
| "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning", | |
| region, | |
| adc.ProjectID, | |
| region, | |
| modelName, | |
| ), nil | |
| } | |
| // BuildRequestHeader sets required headers. | |
| func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { | |
| req.Header.Set("Content-Type", "application/json") | |
| req.Header.Set("Accept", "application/json") | |
| adc := &vertexcore.Credentials{} | |
| if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { | |
| return fmt.Errorf("failed to decode credentials: %w", err) | |
| } | |
| token, err := vertexcore.AcquireAccessToken(*adc, "") | |
| if err != nil { | |
| return fmt.Errorf("failed to acquire access token: %w", err) | |
| } | |
| req.Header.Set("Authorization", "Bearer "+token) | |
| req.Header.Set("x-goog-user-project", adc.ProjectID) | |
| return nil | |
| } | |
| // BuildRequestBody converts request into Vertex specific format. | |
| func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { | |
| v, ok := c.Get("task_request") | |
| if !ok { | |
| return nil, fmt.Errorf("request not found in context") | |
| } | |
| req := v.(relaycommon.TaskSubmitReq) | |
| body := requestPayload{ | |
| Instances: []map[string]any{{"prompt": req.Prompt}}, | |
| Parameters: map[string]any{}, | |
| } | |
| if req.Metadata != nil { | |
| if v, ok := req.Metadata["storageUri"]; ok { | |
| body.Parameters["storageUri"] = v | |
| } | |
| if v, ok := req.Metadata["sampleCount"]; ok { | |
| body.Parameters["sampleCount"] = v | |
| } | |
| } | |
| if _, ok := body.Parameters["sampleCount"]; !ok { | |
| body.Parameters["sampleCount"] = 1 | |
| } | |
| data, err := json.Marshal(body) | |
| if err != nil { | |
| return nil, err | |
| } | |
| return bytes.NewReader(data), nil | |
| } | |
| // DoRequest delegates to common helper. | |
| func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { | |
| return channel.DoTaskApiRequest(a, c, info, requestBody) | |
| } | |
| // DoResponse handles upstream response, returns taskID etc. | |
| func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { | |
| responseBody, err := io.ReadAll(resp.Body) | |
| if err != nil { | |
| return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | |
| } | |
| _ = resp.Body.Close() | |
| var s submitResponse | |
| if err := json.Unmarshal(responseBody, &s); err != nil { | |
| return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) | |
| } | |
| if strings.TrimSpace(s.Name) == "" { | |
| return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError) | |
| } | |
| localID := encodeLocalTaskID(s.Name) | |
| c.JSON(http.StatusOK, gin.H{"task_id": localID}) | |
| return localID, responseBody, nil | |
| } | |
| func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001"} } | |
| func (a *TaskAdaptor) GetChannelName() string { return "vertex" } | |
| // FetchTask fetch task status | |
| func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { | |
| taskID, ok := body["task_id"].(string) | |
| if !ok { | |
| return nil, fmt.Errorf("invalid task_id") | |
| } | |
| upstreamName, err := decodeLocalTaskID(taskID) | |
| if err != nil { | |
| return nil, fmt.Errorf("decode task_id failed: %w", err) | |
| } | |
| region := extractRegionFromOperationName(upstreamName) | |
| if region == "" { | |
| region = "us-central1" | |
| } | |
| project := extractProjectFromOperationName(upstreamName) | |
| modelName := extractModelFromOperationName(upstreamName) | |
| if project == "" || modelName == "" { | |
| return nil, fmt.Errorf("cannot extract project/model from operation name") | |
| } | |
| var url string | |
| if region == "global" { | |
| url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName) | |
| } else { | |
| url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName) | |
| } | |
| payload := map[string]string{"operationName": upstreamName} | |
| data, err := json.Marshal(payload) | |
| if err != nil { | |
| return nil, err | |
| } | |
| adc := &vertexcore.Credentials{} | |
| if err := json.Unmarshal([]byte(key), adc); err != nil { | |
| return nil, fmt.Errorf("failed to decode credentials: %w", err) | |
| } | |
| token, err := vertexcore.AcquireAccessToken(*adc, "") | |
| if err != nil { | |
| return nil, fmt.Errorf("failed to acquire access token: %w", err) | |
| } | |
| req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data)) | |
| if err != nil { | |
| return nil, err | |
| } | |
| req.Header.Set("Content-Type", "application/json") | |
| req.Header.Set("Accept", "application/json") | |
| req.Header.Set("Authorization", "Bearer "+token) | |
| req.Header.Set("x-goog-user-project", adc.ProjectID) | |
| return service.GetHttpClient().Do(req) | |
| } | |
| func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { | |
| var op operationResponse | |
| if err := json.Unmarshal(respBody, &op); err != nil { | |
| return nil, fmt.Errorf("unmarshal operation response failed: %w", err) | |
| } | |
| ti := &relaycommon.TaskInfo{} | |
| if op.Error.Message != "" { | |
| ti.Status = model.TaskStatusFailure | |
| ti.Reason = op.Error.Message | |
| ti.Progress = "100%" | |
| return ti, nil | |
| } | |
| if !op.Done { | |
| ti.Status = model.TaskStatusInProgress | |
| ti.Progress = "50%" | |
| return ti, nil | |
| } | |
| ti.Status = model.TaskStatusSuccess | |
| ti.Progress = "100%" | |
| if len(op.Response.Videos) > 0 { | |
| v0 := op.Response.Videos[0] | |
| if v0.BytesBase64Encoded != "" { | |
| mime := strings.TrimSpace(v0.MimeType) | |
| if mime == "" { | |
| enc := strings.TrimSpace(v0.Encoding) | |
| if enc == "" { | |
| enc = "mp4" | |
| } | |
| if strings.Contains(enc, "/") { | |
| mime = enc | |
| } else { | |
| mime = "video/" + enc | |
| } | |
| } | |
| ti.Url = "data:" + mime + ";base64," + v0.BytesBase64Encoded | |
| return ti, nil | |
| } | |
| } | |
| if op.Response.BytesBase64Encoded != "" { | |
| enc := strings.TrimSpace(op.Response.Encoding) | |
| if enc == "" { | |
| enc = "mp4" | |
| } | |
| mime := enc | |
| if !strings.Contains(enc, "/") { | |
| mime = "video/" + enc | |
| } | |
| ti.Url = "data:" + mime + ";base64," + op.Response.BytesBase64Encoded | |
| return ti, nil | |
| } | |
| if op.Response.Video != "" { // some variants use `video` as base64 | |
| enc := strings.TrimSpace(op.Response.Encoding) | |
| if enc == "" { | |
| enc = "mp4" | |
| } | |
| mime := enc | |
| if !strings.Contains(enc, "/") { | |
| mime = "video/" + enc | |
| } | |
| ti.Url = "data:" + mime + ";base64," + op.Response.Video | |
| return ti, nil | |
| } | |
| return ti, nil | |
| } | |
| // ============================ | |
| // helpers | |
| // ============================ | |
| func encodeLocalTaskID(name string) string { | |
| return base64.RawURLEncoding.EncodeToString([]byte(name)) | |
| } | |
| func decodeLocalTaskID(local string) (string, error) { | |
| b, err := base64.RawURLEncoding.DecodeString(local) | |
| if err != nil { | |
| return "", err | |
| } | |
| return string(b), nil | |
| } | |
| var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`) | |
| func extractRegionFromOperationName(name string) string { | |
| m := regionRe.FindStringSubmatch(name) | |
| if len(m) == 2 { | |
| return m[1] | |
| } | |
| return "" | |
| } | |
| var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`) | |
| func extractModelFromOperationName(name string) string { | |
| m := modelRe.FindStringSubmatch(name) | |
| if len(m) == 2 { | |
| return m[1] | |
| } | |
| idx := strings.Index(name, "models/") | |
| if idx >= 0 { | |
| s := name[idx+len("models/"):] | |
| if p := strings.Index(s, "/operations/"); p > 0 { | |
| return s[:p] | |
| } | |
| } | |
| return "" | |
| } | |
| var projectRe = regexp.MustCompile(`projects/([^/]+)/locations/`) | |
| func extractProjectFromOperationName(name string) string { | |
| m := projectRe.FindStringSubmatch(name) | |
| if len(m) == 2 { | |
| return m[1] | |
| } | |
| return "" | |
| } | |