| | """ |
| | 通用工具函数模块 |
| | """ |
| |
|
| | import base64 |
| | import json |
| | import logging |
| | import re |
| | from pathlib import Path |
| | from typing import Any, Dict, List, Optional, Tuple |
| |
|
| | import requests |
| |
|
| | from app.config.config import Settings |
| | from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS |
| |
|
| | helper_logger = logging.getLogger("app.utils") |
| |
|
| | PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent |
| | VERSION_FILE_PATH = PROJECT_ROOT / "VERSION" |
| |
|
| |
|
| | def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]: |
| | """ |
| | 从 base64 字符串中提取 MIME 类型和数据 |
| | |
| | Args: |
| | base64_string: 可能包含 MIME 类型信息的 base64 字符串 |
| | |
| | Returns: |
| | tuple: (mime_type, encoded_data) |
| | """ |
| | |
| | if base64_string.startswith("data:"): |
| | |
| | pattern = DATA_URL_PATTERN |
| | match = re.match(pattern, base64_string) |
| | if match: |
| | mime_type = ( |
| | "image/jpeg" if match.group(1) == "image/jpg" else match.group(1) |
| | ) |
| | encoded_data = match.group(2) |
| | return mime_type, encoded_data |
| |
|
| | |
| | return None, base64_string |
| |
|
| |
|
| | def convert_image_to_base64(url: str) -> str: |
| | """ |
| | 将图片URL转换为base64编码 |
| | |
| | Args: |
| | url: 图片URL |
| | |
| | Returns: |
| | str: base64编码的图片数据 |
| | |
| | Raises: |
| | Exception: 如果获取图片失败 |
| | """ |
| | response = requests.get(url) |
| | if response.status_code == 200: |
| | |
| | img_data = base64.b64encode(response.content).decode("utf-8") |
| | return img_data |
| | else: |
| | raise Exception(f"Failed to fetch image: {response.status_code}") |
| |
|
| |
|
| | def format_json_response(data: Dict[str, Any], indent: int = 2) -> str: |
| | """ |
| | 格式化JSON响应 |
| | |
| | Args: |
| | data: 要格式化的数据 |
| | indent: 缩进空格数 |
| | |
| | Returns: |
| | str: 格式化后的JSON字符串 |
| | """ |
| | return json.dumps(data, indent=indent, ensure_ascii=False) |
| |
|
| |
|
| | def parse_prompt_parameters( |
| | prompt: str, default_ratio: str = "1:1" |
| | ) -> Tuple[str, int, str]: |
| | """ |
| | 从prompt中解析参数 |
| | |
| | 支持的格式: |
| | - {n:数量} 例如: {n:2} 生成2张图片 |
| | - {ratio:比例} 例如: {ratio:16:9} 使用16:9比例 |
| | |
| | Args: |
| | prompt: 提示文本 |
| | default_ratio: 默认比例 |
| | |
| | Returns: |
| | tuple: (清理后的提示文本, 图片数量, 比例) |
| | """ |
| | |
| | n = 1 |
| | aspect_ratio = default_ratio |
| |
|
| | |
| | n_match = re.search(r"{n:(\d+)}", prompt) |
| | if n_match: |
| | n = int(n_match.group(1)) |
| | if n < 1 or n > 4: |
| | raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.") |
| | prompt = prompt.replace(n_match.group(0), "").strip() |
| |
|
| | |
| | ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt) |
| | if ratio_match: |
| | aspect_ratio = ratio_match.group(1) |
| | if aspect_ratio not in VALID_IMAGE_RATIOS: |
| | raise ValueError( |
| | f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}" |
| | ) |
| | prompt = prompt.replace(ratio_match.group(0), "").strip() |
| |
|
| | return prompt, n, aspect_ratio |
| |
|
| |
|
| | def extract_image_urls_from_markdown(text: str) -> List[str]: |
| | """ |
| | 从Markdown文本中提取图片URL |
| | |
| | Args: |
| | text: Markdown文本 |
| | |
| | Returns: |
| | List[str]: 图片URL列表 |
| | """ |
| | pattern = IMAGE_URL_PATTERN |
| | matches = re.findall(pattern, text) |
| | return [match[1] for match in matches] |
| |
|
| |
|
| | def is_valid_api_key(key: str) -> bool: |
| | """ |
| | 检查API密钥格式是否有效 |
| | |
| | Args: |
| | key: API密钥 |
| | |
| | Returns: |
| | bool: 如果密钥格式有效则返回True |
| | """ |
| | |
| | if key.startswith("AIza"): |
| | return len(key) >= 30 |
| |
|
| | |
| | if key.startswith("sk-"): |
| | return len(key) >= 30 |
| |
|
| | return False |
| |
|
| |
|
| | def redact_key_for_logging(key: str) -> str: |
| | """ |
| | Redacts API key for secure logging by showing only first and last 6 characters. |
| | |
| | Args: |
| | key: API key to redact |
| | |
| | Returns: |
| | str: Redacted key in format "first6...last6" or descriptive placeholder for edge cases |
| | """ |
| | if not key: |
| | return key |
| |
|
| | if len(key) <= 12: |
| | return f"{key[:3]}...{key[-3:]}" |
| | else: |
| | return f"{key[:6]}...{key[-6:]}" |
| |
|
| |
|
| | def get_current_version(default_version: str = "0.0.0") -> str: |
| | """Reads the current version from the VERSION file.""" |
| | version_file = VERSION_FILE_PATH |
| | try: |
| | with version_file.open("r", encoding="utf-8") as f: |
| | version = f.read().strip() |
| | if not version: |
| | helper_logger.warning( |
| | f"VERSION file ('{version_file}') is empty. Using default version '{default_version}'." |
| | ) |
| | return default_version |
| | return version |
| | except FileNotFoundError: |
| | helper_logger.warning( |
| | f"VERSION file not found at '{version_file}'. Using default version '{default_version}'." |
| | ) |
| | return default_version |
| | except IOError as e: |
| | helper_logger.error( |
| | f"Error reading VERSION file ('{version_file}'): {e}. Using default version '{default_version}'." |
| | ) |
| | return default_version |
| |
|
| |
|
| | def is_image_upload_configured(settings: Settings) -> bool: |
| | """Return True only if a valid upload provider is selected and all required settings for that provider are present.""" |
| |
|
| | provider = (getattr(settings, "UPLOAD_PROVIDER", "") or "").strip().lower() |
| | if provider == "smms": |
| | return bool(getattr(settings, "SMMS_SECRET_TOKEN", "")) |
| | if provider == "picgo": |
| | return bool(getattr(settings, "PICGO_API_KEY", "")) |
| | if provider == "cloudflare_imgbed": |
| | return all( |
| | [ |
| | getattr(settings, "CLOUDFLARE_IMGBED_URL", ""), |
| | getattr(settings, "CLOUDFLARE_IMGBED_AUTH_CODE", ""), |
| | ] |
| | ) |
| | return False |
| |
|