|
|
import base64 |
|
|
import hashlib |
|
|
import os |
|
|
import re |
|
|
import shutil |
|
|
import threading |
|
|
import time |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
from typing import Optional |
|
|
from collections import OrderedDict |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
try: |
|
|
import boto3 |
|
|
from botocore.exceptions import BotoCoreError, ClientError |
|
|
except ImportError: |
|
|
boto3 = None |
|
|
BotoCoreError = ClientError = Exception |
|
|
|
|
|
from config import ( |
|
|
IMAGES_DIR, |
|
|
logger, |
|
|
SAVE_QUALITY, |
|
|
MODELS_PATH, |
|
|
BOS_ACCESS_KEY, |
|
|
BOS_SECRET_KEY, |
|
|
BOS_ENDPOINT, |
|
|
BOS_BUCKET_NAME, |
|
|
BOS_IMAGE_DIR, |
|
|
BOS_UPLOAD_ENABLED, |
|
|
BOS_DOWNLOAD_TARGETS, |
|
|
HUGGINGFACE_REPO_ID, |
|
|
HUGGINGFACE_SYNC_ENABLED, |
|
|
HUGGINGFACE_REVISION, |
|
|
HUGGINGFACE_ALLOW_PATTERNS, |
|
|
HUGGINGFACE_IGNORE_PATTERNS, |
|
|
) |
|
|
|
|
|
# ---------------------------------------------------------------------------
# Module-level BOS client / download / upload-cache state
# ---------------------------------------------------------------------------

# Lazily-created S3-compatible client; stays None when BOS is disabled,
# misconfigured, or boto3 is missing (see _get_bos_client).
_BOS_CLIENT = None
# True once client initialization has been *attempted* (success or failure),
# so a known-bad configuration is never retried.
_BOS_CLIENT_INITIALIZED = False
# Guards the one-time client initialization above.
_BOS_CLIENT_LOCK = threading.Lock()
# Serializes ensure_bos_resources() and guards _BOS_BACKGROUND_FUTURES.
_BOS_DOWNLOAD_LOCK = threading.Lock()
# Set once all blocking download targets have synced successfully.
_BOS_DOWNLOAD_COMPLETED = False
# Lazily-created executor for background (non-blocking) download targets.
_BOS_BACKGROUND_EXECUTOR = None
# In-flight background download futures; entries removed by done-callbacks.
_BOS_BACKGROUND_FUTURES = []
# Absolute images directory; uploads are restricted to files under it.
_IMAGES_DIR_ABS = os.path.abspath(os.path.expanduser(IMAGES_DIR))
# LRU map of (abs_path, object_key) -> (mtime_ns, size) for files already
# uploaded, used by upload_file_to_bos to skip unchanged files.
_BOS_UPLOAD_CACHE = OrderedDict()
# Guards all reads/writes of _BOS_UPLOAD_CACHE.
_BOS_UPLOAD_CACHE_LOCK = threading.Lock()
# Cache capacity; oldest entries are evicted beyond this.
_BOS_UPLOAD_CACHE_MAX = 2048
|
|
|
|
|
|
|
|
def _decode_bos_credential(raw_value: str) -> str: |
|
|
"""将Base64编码的凭证解码为明文,若解码失败则返回原值""" |
|
|
if not raw_value: |
|
|
return "" |
|
|
|
|
|
value = raw_value.strip() |
|
|
if not value: |
|
|
return "" |
|
|
|
|
|
try: |
|
|
padding = len(value) % 4 |
|
|
if padding: |
|
|
value += "=" * (4 - padding) |
|
|
decoded = base64.b64decode(value).decode("utf-8").strip() |
|
|
if decoded: |
|
|
return decoded |
|
|
except Exception: |
|
|
pass |
|
|
return value |
|
|
|
|
|
|
|
|
def _is_path_under_images_dir(file_path: str) -> bool:
    """Return True when *file_path* resolves to a location inside IMAGES_DIR."""
    try:
        candidate = os.path.abspath(file_path)
        common = os.path.commonpath([_IMAGES_DIR_ABS, candidate])
        return common == _IMAGES_DIR_ABS
    except ValueError:
        # commonpath raises for empty input or mixed absolute/relative
        # (on Windows: mixed drives) — treat as "not under the dir".
        return False
|
|
|
|
|
|
|
|
def _get_bos_client():
    """Return the shared S3-compatible BOS client, creating it on first use.

    Initialization is attempted at most once (double-checked locking). On any
    failure — uploads disabled, incomplete credentials, boto3 missing, or a
    client construction error — the result is cached as None and returned on
    every subsequent call without retrying.
    """
    global _BOS_CLIENT, _BOS_CLIENT_INITIALIZED
    # Fast path: initialization already attempted (result may be None).
    if _BOS_CLIENT_INITIALIZED:
        return _BOS_CLIENT

    with _BOS_CLIENT_LOCK:
        # Re-check under the lock: another thread may have initialized first.
        if _BOS_CLIENT_INITIALIZED:
            return _BOS_CLIENT

        if not BOS_UPLOAD_ENABLED:
            _BOS_CLIENT_INITIALIZED = True
            _BOS_CLIENT = None
            return None
        # Credentials may be stored base64-encoded; decode best-effort.
        access_key = _decode_bos_credential(BOS_ACCESS_KEY)
        secret_key = _decode_bos_credential(BOS_SECRET_KEY)
        if not all([access_key, secret_key, BOS_ENDPOINT, BOS_BUCKET_NAME]):
            logger.warning("BOS 上传未配置完整,跳过初始化")
            _BOS_CLIENT_INITIALIZED = True
            _BOS_CLIENT = None
            return None

        if boto3 is None:
            logger.warning("未安装 boto3,BOS 上传功能不可用")
            _BOS_CLIENT_INITIALIZED = True
            _BOS_CLIENT = None
            return None

        try:
            # BOS exposes an S3-compatible API, so a plain boto3 S3 client
            # pointed at the BOS endpoint works.
            _BOS_CLIENT = boto3.client(
                "s3",
                aws_access_key_id=access_key,
                aws_secret_access_key=secret_key,
                endpoint_url=BOS_ENDPOINT,
            )
            logger.info("BOS 客户端初始化成功")
        except Exception as e:
            logger.warning(f"初始化 BOS 客户端失败,将跳过上传: {e}")
            _BOS_CLIENT = None
        finally:
            # Mark as attempted even on failure so we never retry.
            _BOS_CLIENT_INITIALIZED = True

    return _BOS_CLIENT
|
|
|
|
|
|
|
|
def _normalize_bos_prefix(prefix: Optional[str]) -> str: |
|
|
value = (prefix or "").strip() |
|
|
if not value: |
|
|
return "" |
|
|
value = value.strip("/") |
|
|
if not value: |
|
|
return "" |
|
|
return f"{value}/" if not value.endswith("/") else value |
|
|
|
|
|
|
|
|
def _directory_has_files(path: str) -> bool: |
|
|
try: |
|
|
for _root, _dirs, files in os.walk(path): |
|
|
if files: |
|
|
return True |
|
|
except Exception: |
|
|
return False |
|
|
return False |
|
|
|
|
|
|
|
|
def download_bos_directory(prefix: str, destination_dir: str, *, force_download: bool = False) -> bool:
    """
    Sync the objects under a BOS prefix down to a local directory.

    :param prefix: BOS object prefix, e.g. 'models/' or '20220620/models'
    :param destination_dir: local target directory
    :param force_download: re-download even when local files already exist
    :return: whether the directory contents are available locally
    """
    client = _get_bos_client()
    if client is None:
        logger.warning("BOS 客户端不可用,无法下载资源(prefix=%s)", prefix)
        return False

    dest_dir = os.path.abspath(os.path.expanduser(destination_dir))
    try:
        os.makedirs(dest_dir, exist_ok=True)
    except Exception as exc:
        logger.error("创建本地目录失败: %s (%s)", dest_dir, exc)
        return False

    normalized_prefix = _normalize_bos_prefix(prefix)

    # Cheap short-circuit: ANY pre-existing file counts as "already synced"
    # unless the caller forces a refresh.
    if not force_download and _directory_has_files(dest_dir):
        logger.info("本地目录已存在文件,跳过下载: %s -> %s", normalized_prefix or "<root>", dest_dir)
        return True

    paginate_kwargs = {"Bucket": BOS_BUCKET_NAME}
    if normalized_prefix:
        paginate_kwargs["Prefix"] = normalized_prefix if normalized_prefix.endswith("/") else f"{normalized_prefix}/"

    found_any = False
    downloaded = 0
    skipped = 0

    try:
        paginator = client.get_paginator("list_objects_v2")
        for page in paginator.paginate(**paginate_kwargs):
            for obj in page.get("Contents", []):
                key = obj.get("Key")
                if not key:
                    continue
                if normalized_prefix:
                    prefix_with_slash = normalized_prefix if normalized_prefix.endswith("/") else f"{normalized_prefix}/"
                    if not key.startswith(prefix_with_slash):
                        continue
                    relative_key = key[len(prefix_with_slash):]
                else:
                    relative_key = key

                # Skip "directory placeholder" keys and the prefix itself.
                if not relative_key or relative_key.endswith("/"):
                    continue
                found_any = True

                target_path = os.path.join(dest_dir, relative_key)
                target_dir = os.path.dirname(target_path)
                os.makedirs(target_dir, exist_ok=True)

                # Matching byte size is used as a cheap "already downloaded"
                # check (no checksum comparison).
                expected_size = obj.get("Size")
                if (
                    not force_download
                    and os.path.exists(target_path)
                    and expected_size is not None
                    and expected_size == os.path.getsize(target_path)
                ):
                    skipped += 1
                    logger.info("文件已存在且大小一致,跳过下载: %s", relative_key)
                    continue

                # Download to a temp name first so a partial download never
                # masquerades as a complete file; os.replace is atomic.
                tmp_path = f"{target_path}.download"
                try:
                    size_mb = (expected_size or 0) / (1024 * 1024)
                    logger.info("开始下载: %s (%.2f MB)", relative_key, size_mb)
                    client.download_file(Bucket=BOS_BUCKET_NAME, Key=key, Filename=tmp_path)
                    os.replace(tmp_path, target_path)
                    downloaded += 1
                    logger.info("下载完成: %s", relative_key)
                except Exception as exc:
                    # Per-file failure: log, clean up the partial temp file,
                    # and keep going with the remaining objects.
                    logger.warning("下载失败: %s (%s)", key, exc)
                    try:
                        if os.path.exists(tmp_path):
                            os.remove(tmp_path)
                    except Exception:
                        pass
    except Exception as exc:
        logger.warning("遍历 BOS 目录失败: %s", exc)
        return False

    if not found_any:
        logger.warning("在 BOS 桶 %s 中未找到前缀 '%s' 的内容", BOS_BUCKET_NAME, normalized_prefix or "<root>")
        return False

    logger.info(
        "BOS 同步完成 prefix=%s -> %s 下载=%d 跳过=%d",
        normalized_prefix or "<root>",
        dest_dir,
        downloaded,
        skipped,
    )
    return downloaded > 0 or skipped > 0
|
|
|
|
|
|
|
|
def _get_background_executor() -> ThreadPoolExecutor:
    """Return the shared two-worker executor for background BOS downloads.

    NOTE(review): this lazy creation is not lock-protected; two threads
    racing through the None check could each construct an executor —
    confirm callers are effectively serialized (ensure_bos_resources holds
    _BOS_DOWNLOAD_LOCK when calling this) or accept the benign extra pool.
    """
    global _BOS_BACKGROUND_EXECUTOR
    if _BOS_BACKGROUND_EXECUTOR is None:
        _BOS_BACKGROUND_EXECUTOR = ThreadPoolExecutor(max_workers=2, thread_name_prefix="bos-bg")
    return _BOS_BACKGROUND_EXECUTOR
|
|
|
|
|
|
|
|
def ensure_huggingface_models(force_download: bool = False) -> bool:
    """Ensure the configured HuggingFace repo is synced into MODELS_PATH.

    Returns True when the sync succeeded OR was deliberately skipped
    (feature disabled, or no repo configured); False only on real failure.

    :param force_download: re-download everything, ignoring local copies
    """
    if not HUGGINGFACE_SYNC_ENABLED:
        logger.info("HuggingFace 模型同步开关已关闭,跳过同步流程")
        return True

    repo_id = (HUGGINGFACE_REPO_ID or "").strip()
    if not repo_id:
        logger.info("未配置 HuggingFace 仓库,跳过模型下载")
        return True

    # Imported lazily so huggingface-hub stays an optional dependency.
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        logger.error("未安装 huggingface-hub,无法下载 HuggingFace 模型")
        return False

    try:
        os.makedirs(MODELS_PATH, exist_ok=True)
    except Exception as exc:
        logger.error("创建模型目录失败: %s (%s)", MODELS_PATH, exc)
        return False

    download_kwargs = {
        "repo_id": repo_id,
        "local_dir": MODELS_PATH,
        # Materialize real files instead of symlinks into the HF cache.
        "local_dir_use_symlinks": False,
    }

    # Optional knobs are only passed when configured, so snapshot_download's
    # own defaults apply otherwise.
    revision = (HUGGINGFACE_REVISION or "").strip()
    if revision:
        download_kwargs["revision"] = revision

    if HUGGINGFACE_ALLOW_PATTERNS:
        download_kwargs["allow_patterns"] = HUGGINGFACE_ALLOW_PATTERNS

    if HUGGINGFACE_IGNORE_PATTERNS:
        download_kwargs["ignore_patterns"] = HUGGINGFACE_IGNORE_PATTERNS

    if force_download:
        download_kwargs["force_download"] = True
        download_kwargs["resume_download"] = False
    else:
        download_kwargs["resume_download"] = True

    try:
        logger.info(
            "开始同步 HuggingFace 模型: repo=%s revision=%s -> %s",
            repo_id,
            revision or "<default>",
            MODELS_PATH,
        )
        snapshot_path = snapshot_download(**download_kwargs)
        logger.info(
            "HuggingFace 模型同步完成: %s -> %s",
            repo_id,
            snapshot_path,
        )
        return True
    except Exception as exc:
        logger.error("HuggingFace 模型下载失败: %s", exc)
        return False
|
|
|
|
|
|
|
|
def ensure_bos_resources(force_download: bool = False, include_background: bool = False) -> bool:
    """
    Sync the model/data resources listed in BOS_DOWNLOAD_TARGETS.

    Blocking targets are downloaded in parallel and awaited; targets marked
    ``background`` are submitted to a shared executor and not awaited
    (unless ``include_background`` promotes them to blocking jobs).

    :param force_download: force re-syncing every resource
    :param include_background: run background-flagged targets as blocking jobs
    :return: whether all blocking resources are ready
    """
    global _BOS_DOWNLOAD_COMPLETED, _BOS_BACKGROUND_FUTURES

    # The whole sync runs under this lock, so concurrent callers serialize.
    with _BOS_DOWNLOAD_LOCK:
        # Fast path: a previous full sync succeeded and nothing forces a redo.
        if _BOS_DOWNLOAD_COMPLETED and not force_download and not include_background:
            return True

        targets = BOS_DOWNLOAD_TARGETS or []
        if not targets:
            logger.info("未配置 BOS 下载目标,跳过资源同步")
            _BOS_DOWNLOAD_COMPLETED = True
            return True

        # Partition configured targets into blocking vs. background jobs,
        # skipping malformed entries.
        download_jobs = []
        background_jobs = []
        for target in targets:
            if not isinstance(target, dict):
                logger.warning("无效的 BOS 下载配置项: %r", target)
                continue

            prefix = target.get("bos_prefix")
            destination = target.get("destination")
            description = target.get("description") or prefix or "<unnamed>"
            background_flag = bool(target.get("background"))

            if not prefix or not destination:
                logger.warning("缺少必要字段,无法处理 BOS 下载配置: %r", target)
                continue

            job = {
                "description": description,
                "prefix": prefix,
                "destination": destination,
            }

            if background_flag and not include_background:
                background_jobs.append(job)
            else:
                download_jobs.append(job)

        results = []
        if download_jobs:
            # One worker per job, capped at the CPU count.
            max_workers = min(len(download_jobs), max(os.cpu_count() or 1, 1))
            if max_workers <= 0:
                max_workers = 1

            with ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="bos-sync") as executor:
                future_to_job = {}
                for job in download_jobs:
                    logger.info(
                        "准备同步 BOS 资源: %s (prefix=%s -> %s)",
                        job["description"],
                        job["prefix"],
                        job["destination"],
                    )
                    future = executor.submit(
                        download_bos_directory,
                        job["prefix"],
                        job["destination"],
                        force_download=force_download,
                    )
                    future_to_job[future] = job

                # Collect results as jobs finish; one failure does not stop
                # the others.
                for future in as_completed(future_to_job):
                    job = future_to_job[future]
                    description = job["description"]
                    try:
                        success = future.result()
                    except Exception as exc:
                        logger.warning("BOS 资源同步异常: %s (%s)", description, exc)
                        success = False

                    if success:
                        logger.info("BOS 资源已就绪: %s", description)
                    else:
                        logger.warning("BOS 资源同步失败: %s", description)
                    results.append(success)

        # No blocking jobs at all counts as "ready".
        all_ready = all(results) if results else True
        if all_ready:
            _BOS_DOWNLOAD_COMPLETED = True

        if background_jobs:
            executor = _get_background_executor()

            def _make_callback(description: str):
                # Factory binds description per job, avoiding the
                # late-binding-closure pitfall in the loop below.
                def _background_done(fut):
                    try:
                        success = fut.result()
                        if success:
                            logger.info("后台 BOS 资源已就绪: %s", description)
                        else:
                            logger.warning("后台 BOS 资源同步失败: %s", description)
                    except Exception as exc:
                        logger.warning("后台 BOS 资源同步异常: %s (%s)", description, exc)
                    finally:
                        # NOTE(review): if the future is already finished when
                        # add_done_callback is called, this runs in the
                        # submitting thread, which still holds the
                        # non-reentrant _BOS_DOWNLOAD_LOCK — that would
                        # deadlock. Confirm downloads cannot complete that
                        # fast in practice.
                        with _BOS_DOWNLOAD_LOCK:
                            if fut in _BOS_BACKGROUND_FUTURES:
                                _BOS_BACKGROUND_FUTURES.remove(fut)
                return _background_done

            for job in background_jobs:
                logger.info(
                    "后台同步 BOS 资源: %s (prefix=%s -> %s)",
                    job["description"],
                    job["prefix"],
                    job["destination"],
                )
                future = executor.submit(
                    download_bos_directory,
                    job["prefix"],
                    job["destination"],
                    force_download=force_download,
                )
                future.add_done_callback(_make_callback(job["description"]))
                _BOS_BACKGROUND_FUTURES.append(future)

        return all_ready
|
|
|
|
|
|
|
|
def upload_file_to_bos(file_path: str, object_name: str | None = None) -> bool:
    """
    Synchronously upload a file to BOS; failures are logged, never raised.

    Only files under IMAGES_DIR are eligible. An LRU cache keyed on
    (absolute path, object key) storing the file's (mtime_ns, size) skips
    re-uploading files whose content has not changed.

    :param file_path: local file path
    :param object_name: BOS object name (optional; derived from the file's
        basename plus BOS_IMAGE_DIR when omitted)
    :return: whether the file is on BOS (True also for cache hits)
    """
    if not BOS_UPLOAD_ENABLED:
        return False

    start_time = time.perf_counter()
    expanded_path = os.path.abspath(os.path.expanduser(file_path))
    if not os.path.isfile(expanded_path):
        return False

    # Safety: refuse to upload anything outside the images directory.
    if not _is_path_under_images_dir(expanded_path):

        return False

    try:
        file_stat = os.stat(expanded_path)
    except OSError:
        return False

    if _get_bos_client() is None:
        return False

    # Resolve the destination object key.
    if object_name:
        object_key = object_name.strip("/ ")
    else:
        base_name = os.path.basename(expanded_path)
        if BOS_IMAGE_DIR:
            object_key = "/".join(
                part.strip("/ ") for part in (BOS_IMAGE_DIR, base_name) if part
            )
        else:
            object_key = base_name

    # (mtime_ns, size) is a cheap content fingerprint for dedup purposes.
    mtime_ns = getattr(file_stat, "st_mtime_ns", int(file_stat.st_mtime * 1_000_000_000))
    cache_signature = (mtime_ns, file_stat.st_size)
    cache_key = (expanded_path, object_key)

    with _BOS_UPLOAD_CACHE_LOCK:
        cached_signature = _BOS_UPLOAD_CACHE.get(cache_key)
        if cached_signature is not None:
            # Touch the entry so LRU eviction keeps hot files.
            _BOS_UPLOAD_CACHE.move_to_end(cache_key)

    # Unchanged since last successful upload: report success without I/O.
    if cached_signature == cache_signature:
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        logger.info("文件已同步至 BOS(跳过重复上传,耗时 %.1f ms): %s", elapsed_ms, object_key)
        return True

    def _do_upload(mode_label: str) -> bool:
        # Re-fetch the (cached) client; returns None when BOS is unusable.
        client_inner = _get_bos_client()
        if client_inner is None:
            return False
        upload_start = time.perf_counter()
        try:
            client_inner.upload_file(expanded_path, BOS_BUCKET_NAME, object_key)
            elapsed_ms = (time.perf_counter() - upload_start) * 1000
            logger.info("文件已同步至 BOS(%s,耗时 %.1f ms): %s", mode_label, elapsed_ms, object_key)
            # Record the successful upload and evict the oldest entries
            # beyond the cache cap.
            with _BOS_UPLOAD_CACHE_LOCK:
                _BOS_UPLOAD_CACHE[cache_key] = cache_signature
                _BOS_UPLOAD_CACHE.move_to_end(cache_key)
                while len(_BOS_UPLOAD_CACHE) > _BOS_UPLOAD_CACHE_MAX:
                    _BOS_UPLOAD_CACHE.popitem(last=False)
            return True
        except (ClientError, BotoCoreError, Exception) as exc:
            logger.warning("上传到 BOS 失败(%s,%s): %s", object_key, mode_label, exc)
            return False

    return _do_upload("同步")
|
|
|
|
|
|
|
|
def delete_file_from_bos(file_path: str | None = None,
                         object_name: str | None = None) -> bool:
    """
    Delete the given object from BOS; failures are logged, never raised.

    :param file_path: local file path (optional, used to derive the object
        name from its basename)
    :param object_name: BOS object name (optional, takes precedence)
    :return: whether the object was deleted
    """
    if not BOS_UPLOAD_ENABLED:
        return False

    client = _get_bos_client()
    if client is None:
        return False

    # Prefer the explicit object name; fall back to the file's basename.
    key_candidate = object_name.strip("/ ") if object_name else ""

    if not key_candidate and file_path:
        base_name = os.path.basename(
            os.path.abspath(os.path.expanduser(file_path)))
        key_candidate = base_name.strip()

    if not key_candidate:
        return False

    # Mirror upload_file_to_bos: objects live under BOS_IMAGE_DIR when set.
    if BOS_IMAGE_DIR:
        object_key = "/".join(
            part.strip("/ ") for part in (BOS_IMAGE_DIR, key_candidate) if part
        )
    else:
        object_key = key_candidate

    try:
        client.delete_object(Bucket=BOS_BUCKET_NAME, Key=object_key)
        logger.info(f"已从 BOS 删除文件: {object_key}")
        return True
    except Exception as e:
        # Exception already subsumes ClientError/BotoCoreError (the original
        # tuple was redundant); keep the net broad so a delete failure can
        # never crash the caller.
        logger.warning(f"删除 BOS 文件失败({object_key}): {e}")
        return False
|
|
|
|
|
|
|
|
def image_to_base64(image: np.ndarray) -> str:
    """Encode an OpenCV image as a WebP data-URI base64 string.

    Returns "" for a missing or empty image.
    """
    if image is None or image.size == 0:
        return ""
    # Quality 90 WebP; the imencode success flag is intentionally ignored,
    # matching the existing behavior.
    encoded = cv2.imencode(".webp", image, [cv2.IMWRITE_WEBP_QUALITY, 90])[1]
    payload = base64.b64encode(encoded).decode("utf-8")
    return f"data:image/webp;base64,{payload}"
|
|
|
|
|
|
|
|
def save_base64_to_unique_file( |
|
|
base64_string: str, output_dir: str = "output_images" |
|
|
) -> str | None: |
|
|
""" |
|
|
将带有MIME类型前缀的Base64字符串解码并保存到本地。 |
|
|
文件名格式为: {md5_hash}_{timestamp}.{extension} |
|
|
""" |
|
|
os.makedirs(output_dir, exist_ok=True) |
|
|
|
|
|
try: |
|
|
match = re.match(r"data:(image/\w+);base64,(.+)", base64_string) |
|
|
if match: |
|
|
mime_type = match.group(1) |
|
|
base64_data = match.group(2) |
|
|
else: |
|
|
mime_type = "image/jpeg" |
|
|
base64_data = base64_string |
|
|
|
|
|
extension_map = { |
|
|
"image/jpeg": "jpg", |
|
|
"image/png": "png", |
|
|
"image/gif": "gif", |
|
|
"image/webp": "webp", |
|
|
} |
|
|
file_extension = extension_map.get(mime_type, "webp") |
|
|
|
|
|
decoded_data = base64.b64decode(base64_data) |
|
|
|
|
|
except (ValueError, TypeError, base64.binascii.Error) as e: |
|
|
logger.error(f"Base64 decoding failed: {e}") |
|
|
return None |
|
|
|
|
|
md5_hash = hashlib.md5(base64_data.encode("utf-8")).hexdigest() |
|
|
filename = f"{md5_hash}.{file_extension}" |
|
|
file_path = os.path.join(output_dir, filename) |
|
|
|
|
|
try: |
|
|
with open(file_path, "wb") as f: |
|
|
f.write(decoded_data) |
|
|
return file_path |
|
|
except IOError as e: |
|
|
logger.error(f"File writing failed: {e}") |
|
|
return None |
|
|
|
|
|
|
|
|
def human_readable_size(size_bytes):
    """Format a byte count as a short human-readable string (B/KB/MB/GB/TB)."""
    value = size_bytes
    for unit in ("B", "KB", "MB", "GB"):
        if value < 1024:
            return f"{value:.1f} {unit}"
        value = value / 1024
    # Anything past GB is reported as TB.
    return f"{value:.1f} TB"
|
|
|
|
|
|
|
|
def delete_file(file_path: str):
    """Delete *file_path* from disk; errors are logged, never raised."""
    try:
        os.remove(file_path)
        logger.info(f"Deleted file: {file_path}")
    except Exception as error:
        logger.error(f"Failed to delete file {file_path}: {error}")
|
|
|
|
|
|
|
|
def move_file_to_archive(file_path: str):
    """Move *file_path* into IMAGES_DIR (created if missing); never raises."""
    try:
        os.makedirs(IMAGES_DIR, exist_ok=True)
        destination = os.path.join(IMAGES_DIR, os.path.basename(file_path))
        shutil.move(file_path, destination)
        logger.info(f"Moved file to archive: {destination}")
    except Exception as error:
        logger.error(f"Failed to move file {file_path} to archive: {error}")
|
|
|
|
|
|
|
|
def save_image_high_quality(
    image: np.ndarray,
    output_path: str,
    quality: int = SAVE_QUALITY,
    *,
    upload_to_bos: bool = True,
) -> bool:
    """
    Save an image as WebP at high quality.

    :param image: image array (OpenCV)
    :param output_path: output file path
    :param quality: WebP quality (0-100); defaults to SAVE_QUALITY
    :param upload_to_bos: sync the file to BOS after writing
    :return: whether the save succeeded
    """
    try:
        success, encoded_img = cv2.imencode(
            ".webp", image, [cv2.IMWRITE_WEBP_QUALITY, quality]
        )
        if not success:
            logger.error(f"Image encoding failed: {output_path}")
            return False

        with open(output_path, "wb") as f:
            f.write(encoded_img)

        logger.info(f"High quality image saved successfully: {output_path}, quality: {quality}, size: {len(encoded_img) / 1024:.2f} KB")
        if upload_to_bos:
            # Best-effort: upload failures are logged inside upload_file_to_bos
            # and do not affect the save result.
            upload_file_to_bos(output_path)
        return True
    except Exception as e:
        logger.error(f"Failed to save image: {output_path}, error: {e}")
        return False
|
|
|
|
|
|
|
|
def convert_numpy_types(obj):
    """Recursively convert numpy values into native Python types.

    Generalized from the original four scalar special cases
    (float32/float64/int32/int64): ``np.generic`` covers every numpy scalar
    (all int/uint widths, float16, bool_, ...) via ``.item()``, and ndarrays
    are converted to nested Python lists. Dicts, lists, and tuples are
    walked recursively; anything else is returned unchanged.

    :param obj: any value, possibly containing numpy scalars/arrays
    :return: an equivalent structure built only from native Python types
    """
    if isinstance(obj, np.generic):
        # Single .item() call handles every numpy scalar type.
        return obj.item()
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, dict):
        return {k: convert_numpy_types(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_numpy_types(i) for i in obj]
    elif isinstance(obj, tuple):
        # Preserve tuple-ness while converting the elements.
        return tuple(convert_numpy_types(i) for i in obj)
    else:
        return obj
|
|
|
|
|
|
|
|
def compress_image_by_quality(image: np.ndarray, quality: int, output_format: str = 'webp') -> tuple[bytes, dict]:
    """
    Compress an image at the given quality without resizing.

    :param image: input image (OpenCV array)
    :param quality: compression quality (10-100)
    :param output_format: output format ('jpg', 'png', 'webp')
    :return: (compressed image bytes, compression info dict)
    """
    try:
        height, width = image.shape[:2]
        fmt = output_format.lower()

        if fmt == 'png':
            # PNG has no quality knob; map quality onto a 0-9 compression level.
            level = max(0, min(9, int((100 - quality) / 10)))
            ext, params = ".png", [cv2.IMWRITE_PNG_COMPRESSION, level]
        elif fmt == 'webp':
            ext, params = ".webp", [cv2.IMWRITE_WEBP_QUALITY, quality]
        else:
            # Anything else falls back to JPEG.
            ext, params = ".jpg", [cv2.IMWRITE_JPEG_QUALITY, quality]

        ok, encoded = cv2.imencode(ext, image, params)
        if not ok:
            raise Exception("图像编码失败")

        payload = encoded.tobytes()
        info = {
            'original_dimensions': f"{width} × {height}",
            'compressed_dimensions': f"{width} × {height}",
            'quality': quality,
            'format': output_format.upper(),
            'size': len(payload),
        }
        return payload, info

    except Exception as e:
        logger.error(f"Failed to compress image by quality: {e}")
        raise
|
|
|
|
|
|
|
|
def compress_image_by_dimensions(image: np.ndarray, target_width: int, target_height: int,
                                 quality: int = 100, output_format: str = 'jpg') -> tuple[bytes, dict]:
    """
    Resize an image to exact target dimensions, then encode it.

    :param image: input image (OpenCV array)
    :param target_width: target width in pixels
    :param target_height: target height in pixels
    :param quality: compression quality
    :param output_format: output format ('jpg', 'png', 'webp')
    :return: (compressed image bytes, compression info dict)
    """
    try:
        source_height, source_width = image.shape[:2]

        # INTER_AREA interpolation, as used throughout this module for resizing.
        scaled = cv2.resize(
            image, (target_width, target_height),
            interpolation=cv2.INTER_AREA
        )

        fmt = output_format.lower()
        if fmt == 'png':
            # PNG has no quality knob; map quality onto a 0-9 compression level.
            level = max(0, min(9, int((100 - quality) / 10)))
            ok, encoded = cv2.imencode(".png", scaled, [cv2.IMWRITE_PNG_COMPRESSION, level])
        elif fmt == 'webp':
            ok, encoded = cv2.imencode(".webp", scaled, [cv2.IMWRITE_WEBP_QUALITY, quality])
        else:
            ok, encoded = cv2.imencode(".jpg", scaled, [cv2.IMWRITE_JPEG_QUALITY, quality])

        if not ok:
            raise Exception("图像编码失败")

        payload = encoded.tobytes()
        info = {
            'original_dimensions': f"{source_width} × {source_height}",
            'compressed_dimensions': f"{target_width} × {target_height}",
            'quality': quality,
            'format': output_format.upper(),
            'size': len(payload),
        }
        return payload, info

    except Exception as e:
        logger.error(f"Failed to compress image by dimensions: {e}")
        raise
|
|
|
|
|
|
|
|
def compress_image_by_file_size(image: np.ndarray, target_size_kb: float,
                                output_format: str = 'jpg') -> tuple[bytes, dict]:
    """
    Compress an image toward a target file size using a multi-stage search:
    for each candidate downscale factor, binary-search the quality setting,
    keeping the closest result overall.

    :param image: input image
    :param target_size_kb: target file size in KB
    :param output_format: output format ('jpg', 'png', 'webp')
    :return: (compressed image bytes, compression info dict)
    :raises Exception: when no encoding attempt produced a usable result
    """
    try:
        original_height, original_width = image.shape[:2]
        target_size_bytes = int(target_size_kb * 1024)

        def encode_image(img, quality):
            """Encode *img* at *quality*; return bytes, or None on failure."""
            if output_format.lower() == 'png':
                # PNG has no quality knob; map quality onto a 0-9 level.
                compression_level = max(0, min(9, int((100 - quality) / 10)))
                success, encoded_img = cv2.imencode(
                    ".png", img, [cv2.IMWRITE_PNG_COMPRESSION, compression_level]
                )
            elif output_format.lower() == 'webp':
                success, encoded_img = cv2.imencode(
                    ".webp", img, [cv2.IMWRITE_WEBP_QUALITY, quality]
                )
            else:
                success, encoded_img = cv2.imencode(
                    ".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, quality]
                )

            if success:
                return encoded_img.tobytes()
            return None

        def find_best_scale_and_quality(target_bytes):
            """Search scale × quality combinations for the closest size match."""
            best_result = None
            best_diff = float('inf')

            # Candidate downscale factors, tried largest-first.
            test_scales = [1.0, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3]

            for scale in test_scales:
                if scale < 1.0:
                    new_width = int(original_width * scale)
                    new_height = int(original_height * scale)
                    # Don't shrink below a usable minimum dimension.
                    if new_width < 50 or new_height < 50:
                        continue
                    working_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
                else:
                    working_image = image
                    new_width, new_height = original_width, original_height

                # Binary-search quality in [10, 100] at this scale
                # (at most 20 iterations).
                min_q, max_q = 10, 100
                scale_best_result = None
                scale_best_diff = float('inf')

                for _ in range(20):
                    current_quality = (min_q + max_q) // 2

                    compressed_bytes = encode_image(working_image, current_quality)
                    if not compressed_bytes:
                        break

                    current_size = len(compressed_bytes)
                    size_diff = abs(current_size - target_bytes)
                    size_ratio = current_size / target_bytes

                    # Within ±1% of the target: good enough, return it.
                    if 0.99 <= size_ratio <= 1.01:
                        return {
                            'bytes': compressed_bytes,
                            'scale': scale,
                            'width': new_width,
                            'height': new_height,
                            'quality': current_quality,
                            'size': current_size,
                            'ratio': size_ratio
                        }

                    # Track the closest attempt at this scale.
                    if size_diff < scale_best_diff:
                        scale_best_diff = size_diff
                        scale_best_result = {
                            'bytes': compressed_bytes,
                            'scale': scale,
                            'width': new_width,
                            'height': new_height,
                            'quality': current_quality,
                            'size': current_size,
                            'ratio': size_ratio
                        }

                    # Narrow the quality window toward the target size.
                    if current_size > target_bytes:
                        max_q = current_quality - 1
                    else:
                        min_q = current_quality + 1

                    if min_q >= max_q:
                        break

                # Keep the best candidate across all scales tried so far.
                if scale_best_result and scale_best_diff < best_diff:
                    best_diff = scale_best_diff
                    best_result = scale_best_result

                # Within ±5% overall: stop trying smaller scales.
                if best_result and 0.95 <= best_result['ratio'] <= 1.05:
                    break

            return best_result

        logger.info(f"Starting multi-stage compression, target size: {target_size_bytes} bytes ({target_size_kb}KB)")

        result = find_best_scale_and_quality(target_size_bytes)

        if result:
            error_percent = abs(result['ratio'] - 1) * 100
            logger.info(f"Compression completed: scale ratio {result['scale']:.2f}, quality {result['quality']}%, "
                        f"size {result['size']} bytes, error {error_percent:.2f}%")

            # Warn (but still return) when the result misses by more than 10%.
            if error_percent > 10:
                if result['ratio'] < 0.5:
                    suggested_size = result['size'] / 1024
                    logger.warning(f"Target size {target_size_kb}KB is too small, actually compressed to {suggested_size:.1f}KB, error {error_percent:.1f}%")
                elif result['ratio'] > 2.0:
                    suggested_size = result['size'] / 1024
                    logger.warning(f"Target size {target_size_kb}KB is too large, minimum can be compressed to {suggested_size:.1f}KB, error {error_percent:.1f}%")
                else:
                    logger.warning(f"Cannot achieve target accuracy, error {error_percent:.1f}%, returning closest result")

            info = {
                'original_dimensions': f"{original_width} × {original_height}",
                'compressed_dimensions': f"{result['width']} × {result['height']}",
                'quality': result['quality'],
                'format': output_format.upper(),
                'size': result['size']
            }

            return result['bytes'], info
        else:
            raise Exception(f"无法将图片压缩到目标大小 {target_size_kb}KB")

    except Exception as e:
        logger.error(f"Failed to compress image by file size: {e}")
        raise
|
|
|
|
|
|
|
|
def convert_image_format(image: np.ndarray, target_format: str, quality: int = 100) -> tuple[bytes, dict]:
    """
    Re-encode an image into the requested format without resizing.

    :param image: input image (OpenCV array)
    :param target_format: target format ('jpg', 'png', 'webp')
    :param quality: quality parameter (not applied to PNG)
    :return: (converted image bytes, format info dict)
    """
    try:
        height, width = image.shape[:2]
        fmt = target_format.lower()

        if fmt == 'png':
            # Fixed mid-level PNG compression; the quality knob does not apply.
            ok, encoded = cv2.imencode(".png", image, [cv2.IMWRITE_PNG_COMPRESSION, 6])
        elif fmt == 'webp':
            ok, encoded = cv2.imencode(".webp", image, [cv2.IMWRITE_WEBP_QUALITY, quality])
        else:
            ok, encoded = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality])

        if not ok:
            raise Exception("图像格式转换失败")

        payload = encoded.tobytes()
        info = {
            'original_dimensions': f"{width} × {height}",
            'compressed_dimensions': f"{width} × {height}",
            # PNG is reported as quality 100 regardless of the input knob.
            'quality': 100 if fmt == 'png' else quality,
            'format': target_format.upper(),
            'size': len(payload),
        }
        return payload, info

    except Exception as e:
        logger.error(f"Image format conversion failed: {e}")
        raise
|
|
|
|
|
|
|
|
def save_image_with_transparency(image: np.ndarray, file_path: str) -> bool:
    """
    Save an image that carries an alpha channel as a PNG file.

    :param image: OpenCV image array (BGRA with alpha, or plain BGR which
        gets a fully-opaque alpha channel added)
    :param file_path: destination path (parent directories created as needed)
    :return: whether the save succeeded
    """
    if image is None:
        logger.error("Image is empty, cannot save")
        return False

    try:
        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        if len(image.shape) == 3 and image.shape[2] == 4:
            # BGRA -> RGBA (PIL expects RGBA channel order).
            rgba_image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
        elif len(image.shape) == 3 and image.shape[2] == 3:
            # No alpha channel: add a fully-opaque (255) one.
            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            rgba_image = np.dstack((rgb_image, np.full(rgb_image.shape[:2], 255, dtype=np.uint8)))
        else:
            # Grayscale / unexpected channel counts are rejected.
            logger.error("Image format does not support transparency saving")
            return False

        # Round-trip through PIL to write the RGBA data as optimized PNG.
        pil_image = Image.fromarray(rgba_image, 'RGBA')
        pil_image.save(file_path, 'PNG', optimize=True)

        file_size = os.path.getsize(file_path)
        logger.info(f"Transparent PNG image saved: {file_path}, size: {file_size/1024:.1f}KB")
        # Best-effort BOS sync; failures are handled inside.
        upload_file_to_bos(file_path)
        return True

    except Exception as e:
        logger.error(f"Failed to save transparent PNG image: {e}")
        return False
|
|
|