File size: 6,373 Bytes
00cfefb
 
 
 
30bd1c2
00cfefb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30bd1c2
00cfefb
 
 
 
 
30bd1c2
00cfefb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fb66e8
 
30bd1c2
00cfefb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import json
import logging
import os
from urllib.parse import urlencode

import requests
from tenacity import retry, stop_after_delay, wait_fixed, retry_if_result


class MTOpenApiClient:
    """Client for the Meitu MT Lab OpenAPI (``https://openapi.mtlab.meitu.com/v1``).

    Credentials are resolved from JSON files in the ``config`` directory next
    to this module and are sent as ``api_key`` / ``api_secret`` query-string
    parameters on every request.
    """

    # Base URL shared by every endpoint.
    _BASE_URL = "https://openapi.mtlab.meitu.com/v1"

    # Optional ``requests`` timeout (seconds, or a (connect, read) tuple).
    # ``None`` means "no timeout", which matches the requests default.
    timeout = None

    def __init__(self, api_name, api_key=None, cost_attribution=None, timeout=None):
        """
        Initialize MTOpenApiClient with credentials.

        Args:
            api_name (str): API endpoint name (appended to the base URL).
            api_key (str, optional): Direct API key for credential lookup in
                ``config/ak_sk_mapping.json``. Takes precedence over
                ``cost_attribution`` when both are given.
            cost_attribution (str, optional): Cost attribution key for AI flow
                credentials in ``config/ai_flow_ak_sk_mapping.json``.
            timeout (float | tuple | None, optional): Timeout forwarded to every
                ``requests.post`` call. Defaults to ``None`` (no timeout).

        Raises:
            ValueError: If neither lookup key is given, if credentials are
                invalid or missing, or if a configuration file is not valid JSON.
            FileNotFoundError: If a configuration file is not found.
        """
        self.api_name = api_name
        self.timeout = timeout
        # Only the cost-attribution mapping carries a token; default it so the
        # attribute exists whichever lookup path is taken.
        self.token = ""

        # Load credentials based on the provided parameters.
        if api_key is not None:
            self._load_credentials_by_api_key(api_key)
        elif cost_attribution is not None:
            self._load_credentials_by_cost_attribution(cost_attribution)
        else:
            raise ValueError("Either api_key or cost_attribution must be provided")

        self._initialize_urls()

    def _get_config_path(self, filename):
        """Return the path of ``filename`` inside the ``config`` directory beside this module."""
        dir_path = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(dir_path, 'config', filename)

    def _load_json_config(self, filename):
        """Load and parse a JSON configuration file from the ``config`` directory.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file is not valid JSON.
        """
        config_path = self._get_config_path(filename)
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError as e:
            # Chain the cause so the original OS error stays in the traceback.
            raise FileNotFoundError(f"Configuration file not found: {config_path}") from e
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in configuration file {filename}: {e}") from e

    def _validate_credentials(self):
        """Raise ValueError unless both ``api_key`` and ``api_secret`` are set and non-empty."""
        if not getattr(self, 'api_key', None):
            raise ValueError("api_key is missing or empty")
        if not getattr(self, 'api_secret', None):
            raise ValueError("api_secret is missing or empty")

    def _load_credentials_by_api_key(self, api_key):
        """Use ``api_key`` directly and look up its secret in ``ak_sk_mapping.json``."""
        self.api_key = api_key
        data = self._load_json_config('ak_sk_mapping.json')
        self.api_secret = data.get(api_key, "")
        self._validate_credentials()

    def _load_credentials_by_cost_attribution(self, cost_attribution):
        """Look up ``ak``/``sk``/``token`` for ``cost_attribution`` in ``ai_flow_ak_sk_mapping.json``."""
        data = self._load_json_config('ai_flow_ak_sk_mapping.json')
        # ``or {}`` also covers an explicit JSON null for this key.
        credentials = data.get(cost_attribution) or {}
        self.api_key = credentials.get("ak", "")
        self.api_secret = credentials.get("sk", "")
        self.token = credentials.get("token", "")
        self._validate_credentials()

    def _initialize_urls(self):
        """Build the endpoint URL and the async-result query URL, both carrying credentials."""
        # urlencode escapes reserved characters (&, =, +, spaces, ...) that
        # would otherwise corrupt the query string if present in key/secret.
        auth_params = urlencode({"api_key": self.api_key, "api_secret": self.api_secret})

        # Used to fetch results of asynchronous jobs.
        self.query_url = f"{self._BASE_URL}/query?{auth_params}"
        # URL of this client's API endpoint.
        self.url = f"{self._BASE_URL}/{self.api_name}?{auth_params}"

    def fetch_response(self, msg_id):
        """Query the result of an asynchronous job once.

        Args:
            msg_id (str): Job id returned by the submit call.

        Returns:
            dict: Parsed JSON body of the query response.

        Raises:
            requests.HTTPError: If the server returns an HTTP error status.
        """
        headers = {"Content-Type": "application/json"}
        # msg_id travels both as a query parameter (properly encoded by
        # requests, appended after the auth params) and in the JSON body.
        response = requests.post(
            self.query_url,
            params={"msg_id": msg_id},
            json={"msg_id": msg_id},
            headers=headers,
            timeout=self.timeout,
        )
        response.raise_for_status()
        # Decode the body once and reuse it for logging and the return value.
        payload = response.json()
        logging.info("fetch_response: %s", payload)
        return payload

    @retry(
        stop=stop_after_delay(1000),  # Give up after 1000 s of polling; tenacity then raises RetryError
        wait=wait_fixed(1),  # Wait 1 second between polls
        # Poll again only while error_code == 4 (presumably "still processing" — confirm against API docs).
        retry=retry_if_result(lambda res: res.get("error_code") == 4),
    )
    def get_res(self, msg_id):
        """Poll ``fetch_response`` every second while the job reports error_code 4.

        Exceptions raised by ``fetch_response`` are not retried; they propagate.
        """
        return self.fetch_response(msg_id)

    def async_request(self, data: dict, max_retries: int = 20):
        """Submit an asynchronous job and poll until its result is available.

        The asynchronous endpoint only accepts image URLs, not base64-encoded
        images (use ``request`` for base64 input).

        Args:
            data (dict): Request body, including image URL(s) and any other
                required parameters.
            max_retries (int): Maximum number of submit attempts. Defaults to 20.

        Returns:
            dict: The polled result when its ``error_code`` is 0, or
            ``{"error_code": -1, "error_msg": "Max retries exceeded"}`` when no
            attempt yielded a ``msg_id``.

        Raises:
            Exception: If the polled result has a non-zero ``error_code``
                (message taken from ``error_msg``).
            requests.RequestException: If the final attempt fails at the HTTP level.
            tenacity.RetryError: If polling does not finish within ``get_res``'s limit.

        Flow:
            1. POST the request to obtain a ``msg_id``.
            2. Poll with ``msg_id`` until the final response is available.
        """
        headers = {"Content-Type": "application/json"}

        for attempt in range(max_retries):
            try:
                response = requests.post(self.url, json=data, headers=headers, timeout=self.timeout)
                response.raise_for_status()

                submit_payload = response.json()
                msg_id = submit_payload.get("msg_id", "")
                if not msg_id:
                    # Surface the server's reply instead of discarding it silently.
                    logging.warning(
                        "async_request: no msg_id in response (attempt %d/%d): %s",
                        attempt + 1, max_retries, submit_payload,
                    )
                    continue

                # NOTE(review): a RequestException raised while polling is caught
                # below and causes the whole job to be resubmitted under a new
                # msg_id — confirm that duplicate submissions are acceptable.
                result = self.get_res(msg_id=msg_id)
            except requests.RequestException:
                if attempt == max_retries - 1:
                    raise
                continue

            if result.get("error_code", 0) == 0:
                return result
            raise Exception(result.get("error_msg", ""))

        # Every polled result either returns or raises above, so reaching this
        # point means no attempt produced a usable msg_id.
        return {"error_code": -1, "error_msg": "Max retries exceeded"}

    def request(self, data: dict):
        """Send a synchronous request and return the immediate response.

        Unlike ``async_request``, the synchronous endpoint accepts
        base64-encoded image data.

        Args:
            data (dict): Request body; may contain base64-encoded image data.

        Returns:
            dict: The raw JSON response.

        Raises:
            requests.HTTPError: If the server returns an HTTP error status.
        """
        headers = {"Content-Type": "application/json"}
        response = requests.post(self.url, json=data, headers=headers, timeout=self.timeout)
        response.raise_for_status()
        return response.json()