Spaces:
Sleeping
Sleeping
| import base64 | |
| import json | |
| import time | |
| import urllib.request | |
| import urllib.parse | |
| from typing import Generator, Dict, Any | |
| from .base import BaseModel | |
class BaiduOCRModel(BaseModel):
    """
    Baidu OCR model for recognizing text in images.

    Calls Baidu's ``accurate_basic`` OCR endpoint and exposes it through the
    same interface as the other models (``analyze_image``, ``analyze_text``,
    ``get_model_identifier``). Only OCR is supported; text analysis is not.
    """

    # Seconds to wait on each HTTP call to Baidu. Without a timeout,
    # urlopen() can block forever on a stalled connection.
    request_timeout: float = 30.0

    def __init__(self, api_key: str, secret_key: str = None, temperature: float = 0.7, system_prompt: str = None):
        """
        Initialize the Baidu OCR model.

        Args:
            api_key: Baidu API Key, or "API_KEY:SECRET_KEY" when secret_key is omitted.
            secret_key: Baidu Secret Key (may instead be passed inside api_key,
                separated by a colon).
            temperature: Not used by OCR; kept for BaseModel compatibility.
            system_prompt: Not used by OCR; kept for BaseModel compatibility.

        Raises:
            ValueError: If secret_key is not given and api_key is not a string of
                the form "API_KEY:SECRET_KEY" with both parts non-empty.
        """
        super().__init__(api_key, temperature, system_prompt)
        # Two accepted forms: separate arguments, or one colon-separated string.
        if secret_key:
            self.api_key = api_key
            self.secret_key = secret_key
        else:
            parts = api_key.split(':') if isinstance(api_key, str) else []
            # Reject "key:" / ":secret" too; an empty part would otherwise only
            # surface later as an opaque token-request failure.
            if len(parts) != 2 or not all(parts):
                raise ValueError("百度OCR API密钥必须是 'API_KEY:SECRET_KEY' 格式或单独传递secret_key参数")
            self.api_key, self.secret_key = parts
        # Baidu API endpoints
        self.token_url = "https://aip.baidubce.com/oauth/2.0/token"
        self.ocr_url = "https://aip.baidubce.com/rest/2.0/ocr/v1/accurate_basic"
        # Cached access_token; its expiry is measured on the time.monotonic() clock
        # so wall-clock adjustments cannot make a stale token look fresh.
        self._access_token = None
        self._token_expires = 0

    def _post_form(self, url: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        POST ``params`` as an urlencoded form to ``url`` and return the decoded JSON object.

        Raises:
            urllib.error.URLError / OSError: On network failures or timeout.
            ValueError: If the body is not valid JSON or is not a JSON object.
        """
        data = urllib.parse.urlencode(params).encode('utf-8')
        request = urllib.request.Request(url, data=data)
        request.add_header('Content-Type', 'application/x-www-form-urlencoded')
        with urllib.request.urlopen(request, timeout=self.request_timeout) as response:
            result = json.loads(response.read().decode('utf-8'))
        if not isinstance(result, dict):
            raise ValueError(f"意外的响应格式: {result!r}")
        return result

    def get_access_token(self) -> str:
        """
        Return a Baidu API access_token, reusing the cached one while it is fresh.

        A cached token is reused until 5 minutes before its recorded expiry;
        after that a new one is requested with the client-credentials grant.

        Raises:
            Exception: "请求access_token失败: ..." if the HTTP request or JSON
                decoding fails, or "获取access_token失败: ..." if the response
                contains no access_token.
        """
        # Refresh 5 minutes early so a token does not expire mid-request.
        if self._access_token and time.monotonic() < self._token_expires - 300:
            return self._access_token

        params = {
            'grant_type': 'client_credentials',
            'client_id': self.api_key,
            'client_secret': self.secret_key
        }
        try:
            result = self._post_form(self.token_url, params)
        except Exception as e:
            raise Exception(f"请求access_token失败: {e}") from e

        # Checked outside the try so this error is not re-wrapped as a request failure.
        if 'access_token' not in result:
            raise Exception(f"获取access_token失败: {result.get('error_description', '未知错误')}")

        self._access_token = result['access_token']
        # Fall back to 30 days when expires_in is absent.
        self._token_expires = time.monotonic() + result.get('expires_in', 2592000)
        return self._access_token

    def ocr_image(self, image_data: str) -> str:
        """
        Run OCR on an image.

        Args:
            image_data: Base64-encoded image data. A leading
                "data:<mime>;base64," header, if present, is stripped first.

        Returns:
            str: The recognized text, one detected line per output line.

        Raises:
            Exception: Any failure (token retrieval, network, Baidu API error)
                is raised exactly once with the prefix "OCR识别失败: ".
        """
        # One try around the whole flow so every failure gets the prefix exactly once.
        try:
            access_token = self.get_access_token()
            if isinstance(image_data, str) and image_data.startswith('data:'):
                _, sep, payload = image_data.partition(',')
                if sep:
                    image_data = payload
            params = {
                'image': image_data,
                'language_type': 'auto_detect',  # auto-detect the language
                'detect_direction': 'true',      # detect image orientation
                'probability': 'false'           # omit confidences (smaller response)
            }
            query = urllib.parse.urlencode({'access_token': access_token})
            result = self._post_form(f"{self.ocr_url}?{query}", params)
            if 'error_code' in result:
                raise Exception(f"百度OCR API错误: {result.get('error_msg', '未知错误')}")
            words_result = result.get('words_result', [])
            return '\n'.join(item['words'] for item in words_result)
        except Exception as e:
            raise Exception(f"OCR识别失败: {e}") from e

    def extract_full_text(self, image_data: str) -> str:
        """
        Extract the full text of an image (interface compatible with Mathpix).

        Args:
            image_data: Base64-encoded image data.

        Returns:
            str: The extracted text, as returned by ocr_image().
        """
        return self.ocr_image(image_data)

    def analyze_image(self, image_data: str, proxies: dict = None) -> Generator[Dict[str, Any], None, None]:
        """
        Run OCR on an image and yield the result (a generator for interface consistency).

        Args:
            image_data: Base64-encoded image data.
            proxies: Proxy configuration (unused).

        Yields:
            dict: A single response with 'status' ('completed' or 'error'),
                'content' (recognized text or error message) and 'model'.
        """
        try:
            text = self.ocr_image(image_data)
        except Exception as e:
            message = str(e)
            # ocr_image() already prefixes its errors; add the prefix only if absent
            # so users never see "OCR识别失败: OCR识别失败: ...".
            if not message.startswith('OCR识别失败'):
                message = f'OCR识别失败: {message}'
            yield {
                'status': 'error',
                'content': message,
                'model': 'baidu-ocr'
            }
            return
        yield {
            'status': 'completed',
            'content': text,
            'model': 'baidu-ocr'
        }

    def analyze_text(self, text: str, proxies: dict = None) -> Generator[Dict[str, Any], None, None]:
        """
        Analyze text (not supported by the OCR model).

        Args:
            text: Input text (ignored).
            proxies: Proxy configuration (unused).

        Yields:
            dict: A single error response.
        """
        yield {
            'status': 'error',
            'content': 'OCR模型不支持文本分析功能',
            'model': 'baidu-ocr'
        }

    def get_model_identifier(self) -> str:
        """Return the model identifier."""
        return "baidu-ocr"