File size: 6,286 Bytes
f1b4581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import base64
import json
import time
import urllib.request
import urllib.parse
from typing import Generator, Dict, Any
from .base import BaseModel

class BaiduOCRModel(BaseModel):
    """
    Baidu OCR model used to recognise text in images.

    The model exchanges an API key / secret key pair for an OAuth access
    token, caches that token on the instance, and posts Base64 image data to
    the Baidu "accurate_basic" OCR endpoint.
    """

    # Seconds to wait on each HTTP call. Without an explicit timeout urllib
    # blocks indefinitely on a stalled connection. Override on the class or
    # on an instance if a different limit is needed.
    request_timeout = 30

    # The cached token is refreshed this many seconds before it expires.
    token_refresh_margin = 300

    def __init__(self, api_key: str, secret_key: str = None, temperature: float = 0.7, system_prompt: str = None):
        """
        Initialise the Baidu OCR model.

        Args:
            api_key: Baidu API key, or the combined 'API_KEY:SECRET_KEY' form
                when secret_key is not given.
            secret_key: Baidu secret key (may instead be embedded in api_key,
                separated by a colon).
            temperature: Unused by OCR; accepted for BaseModel compatibility.
            system_prompt: Unused by OCR; accepted for BaseModel compatibility.

        Raises:
            ValueError: If no secret key is supplied and api_key is not a
                string of exactly two non-empty parts separated by a colon.
        """
        super().__init__(api_key, temperature, system_prompt)

        # Two accepted forms: separate arguments, or colon-joined in api_key.
        if secret_key:
            self.api_key = api_key
            self.secret_key = secret_key
        else:
            # Guard against None / non-string input and empty halves such as
            # "key:" so every malformed value surfaces as the same ValueError.
            parts = api_key.split(':') if isinstance(api_key, str) else []
            if len(parts) != 2 or not all(parts):
                raise ValueError("百度OCR API密钥必须是 'API_KEY:SECRET_KEY' 格式或单独传递secret_key参数")
            self.api_key, self.secret_key = parts

        # Baidu API endpoints.
        self.token_url = "https://aip.baidubce.com/oauth/2.0/token"
        self.ocr_url = "https://aip.baidubce.com/rest/2.0/ocr/v1/accurate_basic"

        # Cached access token and the epoch time at which it expires.
        self._access_token = None
        self._token_expires = 0

    def _post_form(self, url: str, fields: Dict[str, Any]) -> Dict[str, Any]:
        """
        POST form-encoded fields to url and return the decoded JSON object.

        Args:
            url: Target URL.
            fields: Form fields to send in the request body.

        Returns:
            dict: The parsed JSON response.

        Raises:
            ValueError: If the response is not valid JSON or is not a JSON
                object.
            Exception: Any error raised by urllib for the request itself
                (connection failure, HTTP error status, timeout).
        """
        body = urllib.parse.urlencode(fields).encode('utf-8')
        request = urllib.request.Request(url, data=body)
        request.add_header('Content-Type', 'application/x-www-form-urlencoded')

        with urllib.request.urlopen(request, timeout=self.request_timeout) as response:
            payload = json.loads(response.read().decode('utf-8'))

        # Callers use dict lookups on the result; reject anything else here
        # so the failure is reported clearly instead of as an AttributeError.
        if not isinstance(payload, dict):
            raise ValueError("响应格式无效: 期望JSON对象")
        return payload

    def get_access_token(self) -> str:
        """
        Return a Baidu API access token, requesting a new one when needed.

        A cached token is reused until token_refresh_margin seconds before
        its expiry time.

        Returns:
            str: The access token.

        Raises:
            Exception: If the token request fails, or if the response does
                not contain an access token.
        """
        if self._access_token and time.time() < self._token_expires - self.token_refresh_margin:
            return self._access_token

        credentials = {
            'grant_type': 'client_credentials',
            'client_id': self.api_key,
            'client_secret': self.secret_key
        }

        # Only the network call is wrapped, so that a rejection reported by
        # the API below is not mislabelled as a request failure.
        try:
            result = self._post_form(self.token_url, credentials)
        except Exception as e:
            raise Exception(f"请求access_token失败: {str(e)}") from e

        if 'access_token' not in result:
            raise Exception(f"获取access_token失败: {result.get('error_description', '未知错误')}")

        self._access_token = result['access_token']
        # Fall back to 30 days (2592000 s) when the response omits expires_in.
        self._token_expires = time.time() + result.get('expires_in', 2592000)
        return self._access_token

    def ocr_image(self, image_data: str) -> str:
        """
        Run OCR on an image.

        Args:
            image_data: Base64-encoded image data. A leading data-URI header
                ("data:image/...;base64,") is removed before sending.

        Returns:
            str: The recognised text, one recognised line per output line.

        Raises:
            Exception: If the access token cannot be obtained, the request
                fails, or the API response carries an error code.
        """
        access_token = self.get_access_token()

        # The Base64 alphabet contains no comma, so splitting on the first
        # comma can only ever remove a data-URI header.
        if isinstance(image_data, str) and image_data.startswith('data:') and ',' in image_data:
            image_data = image_data.split(',', 1)[1]

        params = {
            'image': image_data,
            'language_type': 'auto_detect',  # let the service detect the language
            'detect_direction': 'true',      # detect image orientation
            'probability': 'false'           # omit confidence scores (smaller response)
        }

        query = urllib.parse.urlencode({'access_token': access_token})
        url = f"{self.ocr_url}?{query}"

        try:
            result = self._post_form(url, params)

            if 'error_code' in result:
                raise Exception(f"百度OCR API错误: {result.get('error_msg', '未知错误')}")

            # Collect the text of every recognised line.
            words_result = result.get('words_result', [])
            return '\n'.join(item['words'] for item in words_result)

        except Exception as e:
            # Every failure of the OCR step carries the same prefix; the
            # original error is kept as the cause for debugging.
            raise Exception(f"OCR识别失败: {str(e)}") from e

    def extract_full_text(self, image_data: str) -> str:
        """
        Extract the full text of an image (Mathpix-compatible interface).

        Args:
            image_data: Base64-encoded image data.

        Returns:
            str: The extracted text.
        """
        return self.ocr_image(image_data)

    def analyze_image(self, image_data: str, proxies: dict = None) -> Generator[Dict[str, Any], None, None]:
        """
        Analyse an image and yield the OCR result.

        Implemented as a generator so it matches the streaming interface of
        the other models; it yields exactly one response.

        Args:
            image_data: Base64-encoded image data.
            proxies: Proxy configuration (unused).

        Yields:
            dict: A response with 'status' ('completed' or 'error'),
                'content' and 'model' keys.
        """
        try:
            text = self.ocr_image(image_data)
        except Exception as e:
            message = str(e)
            # ocr_image already prefixes its own errors; only add the prefix
            # for errors that arrive without it (e.g. token failures).
            if not message.startswith('OCR识别失败'):
                message = f'OCR识别失败: {message}'
            yield {
                'status': 'error',
                'content': message,
                'model': 'baidu-ocr'
            }
            return

        yield {
            'status': 'completed',
            'content': text,
            'model': 'baidu-ocr'
        }

    def analyze_text(self, text: str, proxies: dict = None) -> Generator[Dict[str, Any], None, None]:
        """
        Analyse text (not supported by the OCR model).

        Args:
            text: Input text (ignored).
            proxies: Proxy configuration (unused).

        Yields:
            dict: A single error response.
        """
        yield {
            'status': 'error',
            'content': 'OCR模型不支持文本分析功能',
            'model': 'baidu-ocr'
        }

    def get_model_identifier(self) -> str:
        """Return the model identifier."""
        return "baidu-ocr"