Spaces:
Paused
Paused
| """ | |
| 通用工具函数模块 | |
| """ | |
| import json | |
| import re | |
| import base64 | |
| import requests | |
| from typing import Dict, Any, List, Optional, Tuple | |
| from pathlib import Path | |
| import logging | |
| from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS | |
# Filesystem anchors: the project root is the directory two levels above the
# one containing this module, and the version string lives in a file named
# "VERSION" directly under that root.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
VERSION_FILE_PATH = PROJECT_ROOT.joinpath("VERSION")

# Logger shared by the helper functions in this module.
helper_logger = logging.getLogger("app.utils")
def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
    """
    Split a possibly data-URL-wrapped base64 string into MIME type and payload.

    Args:
        base64_string: Either a ``data:`` URL or a bare base64 payload.

    Returns:
        tuple: ``(mime_type, encoded_data)``. ``mime_type`` is ``None`` when the
        input is not a ``data:`` URL matching ``DATA_URL_PATTERN``; in that case
        the input is returned unchanged as the payload. The non-standard
        ``image/jpg`` type is normalised to ``image/jpeg``.
    """
    # Anything that is not a data URL is treated as the raw payload.
    if not base64_string.startswith('data:'):
        return None, base64_string

    parsed = re.match(DATA_URL_PATTERN, base64_string)
    if parsed is None:
        return None, base64_string

    detected_mime, payload = parsed.group(1), parsed.group(2)
    if detected_mime == "image/jpg":
        detected_mime = "image/jpeg"
    return detected_mime, payload
def convert_image_to_base64(url: str, timeout: Optional[float] = 30.0) -> str:
    """
    Download an image and return its bytes encoded as a base64 string.

    Args:
        url: Image URL to fetch.
        timeout: Seconds to wait on the server, passed straight to
            ``requests.get``. ``None`` disables the limit (the previous,
            unbounded behaviour).

    Returns:
        str: The response body base64-encoded, without any ``data:`` prefix.

    Raises:
        Exception: If the server answers with any status code other than 200.
        Network-level errors raised by ``requests`` (connection failures,
        timeouts) propagate unchanged.
    """
    # Without an explicit timeout, requests can block indefinitely on a
    # host that accepts the connection but never responds.
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise Exception(f"Failed to fetch image: {response.status_code}")
    return base64.b64encode(response.content).decode('utf-8')
def format_json_response(data: Dict[str, Any], indent: int = 2) -> str:
    """
    Serialise ``data`` as pretty-printed JSON.

    Non-ASCII characters are emitted as-is rather than as ``\\uXXXX`` escapes.

    Args:
        data: JSON-serialisable mapping to render.
        indent: Number of spaces per indentation level.

    Returns:
        str: The formatted JSON text.
    """
    rendered = json.dumps(data, ensure_ascii=False, indent=indent)
    return rendered
def parse_prompt_parameters(prompt: str, default_ratio: str = "1:1") -> Tuple[str, int, str]:
    """
    Extract inline generation options from a prompt.

    Recognised tags:
        - ``{n:COUNT}``   e.g. ``{n:2}`` requests 2 images (allowed 1-4).
        - ``{ratio:W:H}`` e.g. ``{ratio:16:9}`` selects a 16:9 aspect ratio;
          must be one of ``VALID_IMAGE_RATIOS``.

    Each recognised tag is removed from the prompt, and the prompt is then
    stripped. A prompt with no tags is returned unmodified (not stripped).

    Args:
        prompt: Raw prompt text, possibly containing tags.
        default_ratio: Ratio returned when no ``{ratio:...}`` tag is present.

    Returns:
        tuple: ``(cleaned_prompt, image_count, aspect_ratio)``.

    Raises:
        ValueError: If the count is outside 1-4 or the ratio is not allowed.
    """
    image_count = 1
    count_tag = re.search(r'{n:(\d+)}', prompt)
    if count_tag is not None:
        image_count = int(count_tag.group(1))
        if not 1 <= image_count <= 4:
            raise ValueError(f"Invalid n value: {image_count}. Must be between 1 and 4.")
        prompt = prompt.replace(count_tag.group(0), '').strip()

    chosen_ratio = default_ratio
    ratio_tag = re.search(r'{ratio:(\d+:\d+)}', prompt)
    if ratio_tag is not None:
        chosen_ratio = ratio_tag.group(1)
        if chosen_ratio not in VALID_IMAGE_RATIOS:
            allowed = ', '.join(VALID_IMAGE_RATIOS)
            raise ValueError(
                f"Invalid ratio: {chosen_ratio}. Must be one of: {allowed}"
            )
        prompt = prompt.replace(ratio_tag.group(0), '').strip()

    return prompt, image_count, chosen_ratio
def extract_image_urls_from_markdown(text: str) -> List[str]:
    """
    Collect image URLs from Markdown text.

    Args:
        text: Markdown source to scan.

    Returns:
        List[str]: For each ``IMAGE_URL_PATTERN`` match, element ``[1]`` of the
        ``re.findall`` result (the second capture group when the pattern has
        two or more groups), in order of appearance.
    """
    found: List[str] = []
    for captured in re.findall(IMAGE_URL_PATTERN, text):
        found.append(captured[1])
    return found
def is_valid_api_key(key: str) -> bool:
    """
    Perform a cheap shape check on an API key.

    Accepts keys that start with a Gemini (``AIza``) or OpenAI (``sk-``)
    prefix and are at least 30 characters long. The key is not verified
    against any service.

    Args:
        key: The API key string.

    Returns:
        bool: ``True`` when the prefix is recognised and the length is >= 30.
    """
    # Both supported providers share the same minimum length.
    if not key.startswith(('AIza', 'sk-')):
        return False
    return len(key) >= 30
def get_current_version(default_version: str = "0.0.0", version_file: Optional[Path] = None) -> str:
    """
    Return the application version stored in a VERSION file.

    Args:
        default_version: Value returned (and logged) when the file is missing,
            empty, unreadable, or not valid UTF-8.
        version_file: File to read. Defaults to ``VERSION_FILE_PATH``.
            Optional, so existing callers are unaffected.

    Returns:
        str: The file contents with surrounding whitespace stripped, or
        ``default_version`` on any of the failures listed above.
    """
    path = VERSION_FILE_PATH if version_file is None else Path(version_file)
    try:
        # 'utf-8-sig' also drops a leading BOM. With plain 'utf-8' the BOM
        # would stay in the result, since str.strip() does not remove U+FEFF.
        version = path.read_text(encoding='utf-8-sig').strip()
    except FileNotFoundError:
        helper_logger.warning(
            "VERSION file not found at '%s'. Using default version '%s'.",
            path, default_version,
        )
        return default_version
    except (OSError, UnicodeDecodeError) as e:
        # OSError subsumes the legacy IOError alias (permissions, directories...).
        # UnicodeDecodeError previously escaped this fallback and crashed callers.
        helper_logger.error(
            "Error reading VERSION file ('%s'): %s. Using default version '%s'.",
            path, e, default_version,
        )
        return default_version

    if not version:
        helper_logger.warning(
            "VERSION file ('%s') is empty. Using default version '%s'.",
            path, default_version,
        )
        return default_version
    return version