Snap-Solver / models /baidu_ocr.py
renxsh
init
f1b4581
import base64
import json
import time
import urllib.request
import urllib.parse
from typing import Generator, Dict, Any
from .base import BaseModel
class BaiduOCRModel(BaseModel):
"""
百度OCR模型,用于图像文字识别
"""
def __init__(self, api_key: str, secret_key: str = None, temperature: float = 0.7, system_prompt: str = None):
"""
初始化百度OCR模型
Args:
api_key: 百度API Key
secret_key: 百度Secret Key(可以在api_key中用冒号分隔传入)
temperature: 不用于OCR但保持BaseModel兼容性
system_prompt: 不用于OCR但保持BaseModel兼容性
Raises:
ValueError: 如果API密钥格式无效
"""
super().__init__(api_key, temperature, system_prompt)
# 支持两种格式:单独传递或在api_key中用冒号分隔
if secret_key:
self.api_key = api_key
self.secret_key = secret_key
else:
try:
self.api_key, self.secret_key = api_key.split(':')
except ValueError:
raise ValueError("百度OCR API密钥必须是 'API_KEY:SECRET_KEY' 格式或单独传递secret_key参数")
# 百度API URLs
self.token_url = "https://aip.baidubce.com/oauth/2.0/token"
self.ocr_url = "https://aip.baidubce.com/rest/2.0/ocr/v1/accurate_basic"
# 缓存access_token
self._access_token = None
self._token_expires = 0
def get_access_token(self) -> str:
"""获取百度API的access_token"""
# 检查是否需要刷新token(提前5分钟刷新)
if self._access_token and time.time() < self._token_expires - 300:
return self._access_token
# 请求新的access_token
params = {
'grant_type': 'client_credentials',
'client_id': self.api_key,
'client_secret': self.secret_key
}
data = urllib.parse.urlencode(params).encode('utf-8')
request = urllib.request.Request(self.token_url, data=data)
request.add_header('Content-Type', 'application/x-www-form-urlencoded')
try:
with urllib.request.urlopen(request) as response:
result = json.loads(response.read().decode('utf-8'))
if 'access_token' in result:
self._access_token = result['access_token']
# 设置过期时间(默认30天,但我们提前刷新)
self._token_expires = time.time() + result.get('expires_in', 2592000)
return self._access_token
else:
raise Exception(f"获取access_token失败: {result.get('error_description', '未知错误')}")
except Exception as e:
raise Exception(f"请求access_token失败: {str(e)}")
def ocr_image(self, image_data: str) -> str:
"""
对图像进行OCR识别
Args:
image_data: Base64编码的图像数据
Returns:
str: 识别出的文字内容
"""
access_token = self.get_access_token()
# 准备请求数据
params = {
'image': image_data,
'language_type': 'auto_detect', # 自动检测语言
'detect_direction': 'true', # 检测图像朝向
'probability': 'false' # 不返回置信度(减少响应大小)
}
data = urllib.parse.urlencode(params).encode('utf-8')
url = f"{self.ocr_url}?access_token={access_token}"
request = urllib.request.Request(url, data=data)
request.add_header('Content-Type', 'application/x-www-form-urlencoded')
try:
with urllib.request.urlopen(request) as response:
result = json.loads(response.read().decode('utf-8'))
if 'error_code' in result:
raise Exception(f"百度OCR API错误: {result.get('error_msg', '未知错误')}")
# 提取识别的文字
words_result = result.get('words_result', [])
text_lines = [item['words'] for item in words_result]
return '\n'.join(text_lines)
except Exception as e:
raise Exception(f"OCR识别失败: {str(e)}")
def extract_full_text(self, image_data: str) -> str:
"""
提取图像中的完整文本(与Mathpix兼容的接口)
Args:
image_data: Base64编码的图像数据
Returns:
str: 提取的文本内容
"""
return self.ocr_image(image_data)
def analyze_image(self, image_data: str, proxies: dict = None) -> Generator[Dict[str, Any], None, None]:
"""
分析图像并返回OCR结果(流式输出以保持接口一致性)
Args:
image_data: Base64编码的图像数据
proxies: 代理配置(未使用)
Yields:
dict: 包含OCR结果的响应
"""
try:
text = self.ocr_image(image_data)
yield {
'status': 'completed',
'content': text,
'model': 'baidu-ocr'
}
except Exception as e:
yield {
'status': 'error',
'content': f'OCR识别失败: {str(e)}',
'model': 'baidu-ocr'
}
def analyze_text(self, text: str, proxies: dict = None) -> Generator[Dict[str, Any], None, None]:
"""
分析文本(OCR模型不支持文本分析)
Args:
text: 输入文本
proxies: 代理配置(未使用)
Yields:
dict: 错误响应
"""
yield {
'status': 'error',
'content': 'OCR模型不支持文本分析功能',
'model': 'baidu-ocr'
}
def get_model_identifier(self) -> str:
"""返回模型标识符"""
return "baidu-ocr"