Spaces:
Paused
Paused
| import base64 | |
| import hashlib | |
| import os | |
| import re | |
| import shutil | |
| import threading | |
| import time | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from typing import Optional | |
| from collections import OrderedDict | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| try: | |
| import boto3 | |
| from botocore.exceptions import BotoCoreError, ClientError | |
| except ImportError: | |
| boto3 = None | |
| BotoCoreError = ClientError = Exception | |
| from config import ( | |
| IMAGES_DIR, | |
| logger, | |
| SAVE_QUALITY, | |
| MODELS_PATH, | |
| BOS_ACCESS_KEY, | |
| BOS_SECRET_KEY, | |
| BOS_ENDPOINT, | |
| BOS_BUCKET_NAME, | |
| BOS_IMAGE_DIR, | |
| BOS_UPLOAD_ENABLED, | |
| BOS_DOWNLOAD_TARGETS, | |
| HUGGINGFACE_REPO_ID, | |
| HUGGINGFACE_SYNC_ENABLED, | |
| HUGGINGFACE_REVISION, | |
| HUGGINGFACE_ALLOW_PATTERNS, | |
| HUGGINGFACE_IGNORE_PATTERNS, | |
| ) | |
# --- BOS client singleton state (initialization guarded by _BOS_CLIENT_LOCK) ---
_BOS_CLIENT = None  # cached boto3 S3 client, or None when unavailable
_BOS_CLIENT_INITIALIZED = False  # True once client initialization has been attempted
_BOS_CLIENT_LOCK = threading.Lock()
# --- one-shot resource download state (guarded by _BOS_DOWNLOAD_LOCK) ---
_BOS_DOWNLOAD_LOCK = threading.Lock()
_BOS_DOWNLOAD_COMPLETED = False  # set once all blocking BOS targets synced
# Lazily created executor for background sync jobs, plus their pending futures.
_BOS_BACKGROUND_EXECUTOR = None
_BOS_BACKGROUND_FUTURES = []
# Absolute images directory; uploads are restricted to files under this path.
_IMAGES_DIR_ABS = os.path.abspath(os.path.expanduser(IMAGES_DIR))
# LRU cache of completed uploads: (path, object_key) -> (mtime_ns, size).
_BOS_UPLOAD_CACHE = OrderedDict()
_BOS_UPLOAD_CACHE_LOCK = threading.Lock()
_BOS_UPLOAD_CACHE_MAX = 2048
| def _decode_bos_credential(raw_value: str) -> str: | |
| """将Base64编码的凭证解码为明文,若解码失败则返回原值""" | |
| if not raw_value: | |
| return "" | |
| value = raw_value.strip() | |
| if not value: | |
| return "" | |
| try: | |
| padding = len(value) % 4 | |
| if padding: | |
| value += "=" * (4 - padding) | |
| decoded = base64.b64decode(value).decode("utf-8").strip() | |
| if decoded: | |
| return decoded | |
| except Exception: | |
| pass | |
| return value | |
def _is_path_under_images_dir(file_path: str) -> bool:
    """Return True when *file_path* resolves to a location inside IMAGES_DIR."""
    candidate = os.path.abspath(file_path)
    try:
        common = os.path.commonpath([_IMAGES_DIR_ABS, candidate])
    except ValueError:
        # Raised for mixed drives (Windows) or mixed absolute/relative inputs.
        return False
    return common == _IMAGES_DIR_ABS
def _get_bos_client():
    """Lazily create and cache the boto3 S3 client for BOS.

    Uses double-checked locking so initialization is attempted at most once
    per process. Returns the shared client, or None when uploads are
    disabled, the configuration is incomplete, boto3 is missing, or
    initialization failed.
    """
    global _BOS_CLIENT, _BOS_CLIENT_INITIALIZED
    # Fast path: initialization already attempted; return the cached result.
    if _BOS_CLIENT_INITIALIZED:
        return _BOS_CLIENT
    with _BOS_CLIENT_LOCK:
        # Re-check under the lock in case another thread initialized first.
        if _BOS_CLIENT_INITIALIZED:
            return _BOS_CLIENT
        if not BOS_UPLOAD_ENABLED:
            _BOS_CLIENT_INITIALIZED = True
            _BOS_CLIENT = None
            return None
        access_key = _decode_bos_credential(BOS_ACCESS_KEY)
        secret_key = _decode_bos_credential(BOS_SECRET_KEY)
        if not all([access_key, secret_key, BOS_ENDPOINT, BOS_BUCKET_NAME]):
            logger.warning("BOS 上传未配置完整,跳过初始化")
            _BOS_CLIENT_INITIALIZED = True
            _BOS_CLIENT = None
            return None
        if boto3 is None:
            logger.warning("未安装 boto3,BOS 上传功能不可用")
            _BOS_CLIENT_INITIALIZED = True
            _BOS_CLIENT = None
            return None
        try:
            _BOS_CLIENT = boto3.client(
                "s3",
                aws_access_key_id=access_key,
                aws_secret_access_key=secret_key,
                endpoint_url=BOS_ENDPOINT,
            )
            logger.info("BOS 客户端初始化成功")
        except Exception as e:
            logger.warning(f"初始化 BOS 客户端失败,将跳过上传: {e}")
            _BOS_CLIENT = None
        finally:
            # Mark as initialized even on failure so every call doesn't retry.
            _BOS_CLIENT_INITIALIZED = True
    return _BOS_CLIENT
| def _normalize_bos_prefix(prefix: Optional[str]) -> str: | |
| value = (prefix or "").strip() | |
| if not value: | |
| return "" | |
| value = value.strip("/") | |
| if not value: | |
| return "" | |
| return f"{value}/" if not value.endswith("/") else value | |
| def _directory_has_files(path: str) -> bool: | |
| try: | |
| for _root, _dirs, files in os.walk(path): | |
| if files: | |
| return True | |
| except Exception: | |
| return False | |
| return False | |
def download_bos_directory(prefix: str, destination_dir: str, *, force_download: bool = False) -> bool:
    """
    Sync a prefixed "directory" of BOS objects down to a local directory.

    :param prefix: BOS object prefix, e.g. 'models/' or '20220620/models'
    :param destination_dir: local destination directory
    :param force_download: re-download even when local files already exist
    :return: whether the directory contents are available locally
    """
    client = _get_bos_client()
    if client is None:
        logger.warning("BOS 客户端不可用,无法下载资源(prefix=%s)", prefix)
        return False
    dest_dir = os.path.abspath(os.path.expanduser(destination_dir))
    try:
        os.makedirs(dest_dir, exist_ok=True)
    except Exception as exc:
        logger.error("创建本地目录失败: %s (%s)", dest_dir, exc)
        return False
    normalized_prefix = _normalize_bos_prefix(prefix)
    # When not forcing and the destination already has files, skip entirely to
    # avoid re-downloading on every startup.
    if not force_download and _directory_has_files(dest_dir):
        logger.info("本地目录已存在文件,跳过下载: %s -> %s", normalized_prefix or "<root>", dest_dir)
        return True
    paginate_kwargs = {"Bucket": BOS_BUCKET_NAME}
    if normalized_prefix:
        paginate_kwargs["Prefix"] = normalized_prefix if normalized_prefix.endswith("/") else f"{normalized_prefix}/"
    found_any = False
    downloaded = 0
    skipped = 0
    try:
        paginator = client.get_paginator("list_objects_v2")
        for page in paginator.paginate(**paginate_kwargs):
            for obj in page.get("Contents", []):
                key = obj.get("Key")
                if not key:
                    continue
                if normalized_prefix:
                    prefix_with_slash = normalized_prefix if normalized_prefix.endswith("/") else f"{normalized_prefix}/"
                    if not key.startswith(prefix_with_slash):
                        continue
                    relative_key = key[len(prefix_with_slash):]
                else:
                    relative_key = key
                # Skip "directory placeholder" keys and the prefix entry itself.
                if not relative_key or relative_key.endswith("/"):
                    continue
                found_any = True
                target_path = os.path.join(dest_dir, relative_key)
                target_dir = os.path.dirname(target_path)
                os.makedirs(target_dir, exist_ok=True)
                expected_size = obj.get("Size")
                # An existing local file with a matching size counts as current.
                if (
                    not force_download
                    and os.path.exists(target_path)
                    and expected_size is not None
                    and expected_size == os.path.getsize(target_path)
                ):
                    skipped += 1
                    logger.info("文件已存在且大小一致,跳过下载: %s", relative_key)
                    continue
                # Download to a temp file, then atomically replace the target so a
                # failed transfer never leaves a truncated file behind.
                tmp_path = f"{target_path}.download"
                try:
                    size_mb = (expected_size or 0) / (1024 * 1024)
                    logger.info("开始下载: %s (%.2f MB)", relative_key, size_mb)
                    client.download_file(Bucket=BOS_BUCKET_NAME, Key=key, Filename=tmp_path)
                    os.replace(tmp_path, target_path)
                    downloaded += 1
                    logger.info("下载完成: %s", relative_key)
                except Exception as exc:
                    # Per-object failures are logged and the loop continues.
                    logger.warning("下载失败: %s (%s)", key, exc)
                    try:
                        if os.path.exists(tmp_path):
                            os.remove(tmp_path)
                    except Exception:
                        pass
    except Exception as exc:
        logger.warning("遍历 BOS 目录失败: %s", exc)
        return False
    if not found_any:
        logger.warning("在 BOS 桶 %s 中未找到前缀 '%s' 的内容", BOS_BUCKET_NAME, normalized_prefix or "<root>")
        return False
    logger.info(
        "BOS 同步完成 prefix=%s -> %s 下载=%d 跳过=%d",
        normalized_prefix or "<root>",
        dest_dir,
        downloaded,
        skipped,
    )
    # True only when at least one object was downloaded or verified present.
    return downloaded > 0 or skipped > 0
def _get_background_executor() -> ThreadPoolExecutor:
    """Return the lazily-created shared executor for background BOS sync jobs.

    NOTE(review): creation is not lock-guarded, so two threads racing here
    could each build an executor. The only visible caller invokes this while
    holding _BOS_DOWNLOAD_LOCK, which appears to serialize it — confirm.
    """
    global _BOS_BACKGROUND_EXECUTOR
    if _BOS_BACKGROUND_EXECUTOR is None:
        _BOS_BACKGROUND_EXECUTOR = ThreadPoolExecutor(max_workers=2, thread_name_prefix="bos-bg")
    return _BOS_BACKGROUND_EXECUTOR
def ensure_huggingface_models(force_download: bool = False) -> bool:
    """Sync the configured HuggingFace model repository into MODELS_PATH.

    Returns True when syncing is disabled/unconfigured or completed
    successfully; False on any failure (missing dependency, directory
    creation error, or download error).
    """
    if not HUGGINGFACE_SYNC_ENABLED:
        logger.info("HuggingFace 模型同步开关已关闭,跳过同步流程")
        return True
    repo_id = (HUGGINGFACE_REPO_ID or "").strip()
    if not repo_id:
        logger.info("未配置 HuggingFace 仓库,跳过模型下载")
        return True
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        logger.error("未安装 huggingface-hub,无法下载 HuggingFace 模型")
        return False
    try:
        os.makedirs(MODELS_PATH, exist_ok=True)
    except Exception as exc:
        logger.error("创建模型目录失败: %s (%s)", MODELS_PATH, exc)
        return False
    revision = (HUGGINGFACE_REVISION or "").strip()
    # Assemble snapshot_download arguments from the optional config knobs.
    kwargs = {
        "repo_id": repo_id,
        "local_dir": MODELS_PATH,
        "local_dir_use_symlinks": False,
        # Forced downloads start fresh; otherwise resume partial transfers.
        "resume_download": not force_download,
    }
    if revision:
        kwargs["revision"] = revision
    if HUGGINGFACE_ALLOW_PATTERNS:
        kwargs["allow_patterns"] = HUGGINGFACE_ALLOW_PATTERNS
    if HUGGINGFACE_IGNORE_PATTERNS:
        kwargs["ignore_patterns"] = HUGGINGFACE_IGNORE_PATTERNS
    if force_download:
        kwargs["force_download"] = True
    try:
        logger.info(
            "开始同步 HuggingFace 模型: repo=%s revision=%s -> %s",
            repo_id,
            revision or "<default>",
            MODELS_PATH,
        )
        snapshot_path = snapshot_download(**kwargs)
        logger.info(
            "HuggingFace 模型同步完成: %s -> %s",
            repo_id,
            snapshot_path,
        )
        return True
    except Exception as exc:
        logger.error("HuggingFace 模型下载失败: %s", exc)
        return False
def ensure_bos_resources(force_download: bool = False, include_background: bool = False) -> bool:
    """
    Sync the startup model/data resources configured in BOS_DOWNLOAD_TARGETS.

    :param force_download: force re-sync of every resource
    :param include_background: run background-flagged targets as blocking jobs too
    :return: whether all blocking resources are ready
    """
    global _BOS_DOWNLOAD_COMPLETED, _BOS_BACKGROUND_FUTURES
    with _BOS_DOWNLOAD_LOCK:
        # Fast path: a prior sync already completed and nothing more is asked.
        if _BOS_DOWNLOAD_COMPLETED and not force_download and not include_background:
            return True
        targets = BOS_DOWNLOAD_TARGETS or []
        if not targets:
            logger.info("未配置 BOS 下载目标,跳过资源同步")
            _BOS_DOWNLOAD_COMPLETED = True
            return True
        # Validate each configured target and split into blocking vs background.
        download_jobs = []
        background_jobs = []
        for target in targets:
            if not isinstance(target, dict):
                logger.warning("无效的 BOS 下载配置项: %r", target)
                continue
            prefix = target.get("bos_prefix")
            destination = target.get("destination")
            description = target.get("description") or prefix or "<unnamed>"
            background_flag = bool(target.get("background"))
            if not prefix or not destination:
                logger.warning("缺少必要字段,无法处理 BOS 下载配置: %r", target)
                continue
            job = {
                "description": description,
                "prefix": prefix,
                "destination": destination,
            }
            if background_flag and not include_background:
                background_jobs.append(job)
            else:
                download_jobs.append(job)
        results = []
        if download_jobs:
            # Blocking jobs fan out over a transient pool sized by job count/CPUs.
            max_workers = min(len(download_jobs), max(os.cpu_count() or 1, 1))
            if max_workers <= 0:
                max_workers = 1
            with ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="bos-sync") as executor:
                future_to_job = {}
                for job in download_jobs:
                    logger.info(
                        "准备同步 BOS 资源: %s (prefix=%s -> %s)",
                        job["description"],
                        job["prefix"],
                        job["destination"],
                    )
                    future = executor.submit(
                        download_bos_directory,
                        job["prefix"],
                        job["destination"],
                        force_download=force_download,
                    )
                    future_to_job[future] = job
                for future in as_completed(future_to_job):
                    job = future_to_job[future]
                    description = job["description"]
                    try:
                        success = future.result()
                    except Exception as exc:
                        logger.warning("BOS 资源同步异常: %s (%s)", description, exc)
                        success = False
                    if success:
                        logger.info("BOS 资源已就绪: %s", description)
                    else:
                        logger.warning("BOS 资源同步失败: %s", description)
                    results.append(success)
        # With no blocking jobs at all, the blocking set is trivially "ready".
        all_ready = all(results) if results else True
        if all_ready:
            _BOS_DOWNLOAD_COMPLETED = True
        if background_jobs:
            executor = _get_background_executor()

            def _make_callback(description: str):
                # Build a done-callback that logs the outcome and removes the
                # future from the pending list.
                # NOTE(review): the callback re-acquires _BOS_DOWNLOAD_LOCK; if a
                # submitted future is already complete, add_done_callback runs the
                # callback on this thread while the (non-reentrant) lock is held —
                # confirm this cannot deadlock in practice.
                def _background_done(fut):
                    try:
                        success = fut.result()
                        if success:
                            logger.info("后台 BOS 资源已就绪: %s", description)
                        else:
                            logger.warning("后台 BOS 资源同步失败: %s", description)
                    except Exception as exc:
                        logger.warning("后台 BOS 资源同步异常: %s (%s)", description, exc)
                    finally:
                        with _BOS_DOWNLOAD_LOCK:
                            if fut in _BOS_BACKGROUND_FUTURES:
                                _BOS_BACKGROUND_FUTURES.remove(fut)
                return _background_done

            for job in background_jobs:
                logger.info(
                    "后台同步 BOS 资源: %s (prefix=%s -> %s)",
                    job["description"],
                    job["prefix"],
                    job["destination"],
                )
                future = executor.submit(
                    download_bos_directory,
                    job["prefix"],
                    job["destination"],
                    force_download=force_download,
                )
                future.add_done_callback(_make_callback(job["description"]))
                _BOS_BACKGROUND_FUTURES.append(future)
        return all_ready
def upload_file_to_bos(file_path: str, object_name: str | None = None) -> bool:
    """
    Synchronously upload a file to BOS; failures are logged, never raised.

    Only files under IMAGES_DIR are eligible, and an LRU cache keyed by
    (path, object key) with the file's (mtime_ns, size) signature skips
    re-uploading unchanged files.

    :param file_path: local file path
    :param object_name: BOS object name (optional; derived from the filename
        plus BOS_IMAGE_DIR when omitted)
    :return: whether the file was uploaded (or already known uploaded)
    """
    if not BOS_UPLOAD_ENABLED:
        return False
    start_time = time.perf_counter()
    expanded_path = os.path.abspath(os.path.expanduser(file_path))
    if not os.path.isfile(expanded_path):
        return False
    if not _is_path_under_images_dir(expanded_path):
        # Only upload files inside IMAGES_DIR, to avoid syncing temp files to BOS.
        return False
    try:
        file_stat = os.stat(expanded_path)
    except OSError:
        return False
    if _get_bos_client() is None:
        return False
    # Derive the object key.
    if object_name:
        object_key = object_name.strip("/ ")
    else:
        base_name = os.path.basename(expanded_path)
        if BOS_IMAGE_DIR:
            object_key = "/".join(
                part.strip("/ ") for part in (BOS_IMAGE_DIR, base_name) if part
            )
        else:
            object_key = base_name
    # (mtime_ns, size) identifies the file's content version for dedup caching.
    mtime_ns = getattr(file_stat, "st_mtime_ns", int(file_stat.st_mtime * 1_000_000_000))
    cache_signature = (mtime_ns, file_stat.st_size)
    cache_key = (expanded_path, object_key)
    with _BOS_UPLOAD_CACHE_LOCK:
        cached_signature = _BOS_UPLOAD_CACHE.get(cache_key)
        if cached_signature is not None:
            # Touch the entry (LRU), then skip if the file is unchanged.
            _BOS_UPLOAD_CACHE.move_to_end(cache_key)
            if cached_signature == cache_signature:
                elapsed_ms = (time.perf_counter() - start_time) * 1000
                logger.info("文件已同步至 BOS(跳过重复上传,耗时 %.1f ms): %s", elapsed_ms, object_key)
                return True

    def _do_upload(mode_label: str) -> bool:
        # Perform the actual transfer and record success in the LRU cache.
        client_inner = _get_bos_client()
        if client_inner is None:
            return False
        upload_start = time.perf_counter()
        try:
            client_inner.upload_file(expanded_path, BOS_BUCKET_NAME, object_key)
            elapsed_ms = (time.perf_counter() - upload_start) * 1000
            logger.info("文件已同步至 BOS(%s,耗时 %.1f ms): %s", mode_label, elapsed_ms, object_key)
            with _BOS_UPLOAD_CACHE_LOCK:
                _BOS_UPLOAD_CACHE[cache_key] = cache_signature
                _BOS_UPLOAD_CACHE.move_to_end(cache_key)
                # Evict oldest entries beyond the cache bound.
                while len(_BOS_UPLOAD_CACHE) > _BOS_UPLOAD_CACHE_MAX:
                    _BOS_UPLOAD_CACHE.popitem(last=False)
            return True
        except (ClientError, BotoCoreError, Exception) as exc:
            logger.warning("上传到 BOS 失败(%s,%s): %s", object_key, mode_label, exc)
            return False

    return _do_upload("同步")
def delete_file_from_bos(file_path: str | None = None,
                         object_name: str | None = None) -> bool:
    """
    Delete the given object from BOS; failures are logged, never raised.

    :param file_path: local file path (optional, used to derive the object name)
    :param object_name: BOS object name (optional, takes precedence)
    :return: whether the object was deleted
    """
    if not BOS_UPLOAD_ENABLED:
        return False
    client = _get_bos_client()
    if client is None:
        return False
    key_candidate = object_name.strip("/ ") if object_name else ""
    if not key_candidate and file_path:
        base_name = os.path.basename(
            os.path.abspath(os.path.expanduser(file_path)))
        key_candidate = base_name.strip()
    if not key_candidate:
        return False
    # Prefix with the configured image directory, mirroring upload_file_to_bos.
    if BOS_IMAGE_DIR:
        object_key = "/".join(
            part.strip("/ ") for part in (BOS_IMAGE_DIR, key_candidate) if part
        )
    else:
        object_key = key_candidate
    try:
        client.delete_object(Bucket=BOS_BUCKET_NAME, Key=object_key)
        logger.info(f"已从 BOS 删除文件: {object_key}")
        return True
    except Exception as e:
        # Fix: the original `except (ClientError, BotoCoreError, Exception)`
        # tuple was redundant — Exception already subsumes the boto errors.
        logger.warning(f"删除 BOS 文件失败({object_key}): {e}")
        return False
def image_to_base64(image: np.ndarray) -> str:
    """Encode an OpenCV image as a WebP base64 data URI.

    :param image: image array; None or empty yields ""
    :return: "data:image/webp;base64,..." string, or "" when encoding fails
    """
    if image is None or image.size == 0:
        return ""
    # Fix: the encode success flag was previously discarded; a failed encode
    # would have produced a data URI wrapping an invalid buffer.
    ok, buffer = cv2.imencode(".webp", image, [cv2.IMWRITE_WEBP_QUALITY, 90])
    if not ok:
        return ""
    img_base64 = base64.b64encode(buffer).decode("utf-8")
    return f"data:image/webp;base64,{img_base64}"
| def save_base64_to_unique_file( | |
| base64_string: str, output_dir: str = "output_images" | |
| ) -> str | None: | |
| """ | |
| 将带有MIME类型前缀的Base64字符串解码并保存到本地。 | |
| 文件名格式为: {md5_hash}_{timestamp}.{extension} | |
| """ | |
| os.makedirs(output_dir, exist_ok=True) | |
| try: | |
| match = re.match(r"data:(image/\w+);base64,(.+)", base64_string) | |
| if match: | |
| mime_type = match.group(1) | |
| base64_data = match.group(2) | |
| else: | |
| mime_type = "image/jpeg" | |
| base64_data = base64_string | |
| extension_map = { | |
| "image/jpeg": "jpg", | |
| "image/png": "png", | |
| "image/gif": "gif", | |
| "image/webp": "webp", | |
| } | |
| file_extension = extension_map.get(mime_type, "webp") | |
| decoded_data = base64.b64decode(base64_data) | |
| except (ValueError, TypeError, base64.binascii.Error) as e: | |
| logger.error(f"Base64 decoding failed: {e}") | |
| return None | |
| md5_hash = hashlib.md5(base64_data.encode("utf-8")).hexdigest() | |
| filename = f"{md5_hash}.{file_extension}" | |
| file_path = os.path.join(output_dir, filename) | |
| try: | |
| with open(file_path, "wb") as f: | |
| f.write(decoded_data) | |
| return file_path | |
| except IOError as e: | |
| logger.error(f"File writing failed: {e}") | |
| return None | |
def human_readable_size(size_bytes):
    """Format a byte count as a human-friendly string (B/KB/MB/GB/TB)."""
    remaining = size_bytes
    for unit in ("B", "KB", "MB", "GB"):
        if remaining < 1024:
            return f"{remaining:.1f} {unit}"
        remaining /= 1024
    # Anything past GB is reported in terabytes.
    return f"{remaining:.1f} TB"
def delete_file(file_path: str):
    """Remove *file_path*, logging the outcome; never raises."""
    try:
        os.remove(file_path)
    except Exception as error:
        logger.error(f"Failed to delete file {file_path}: {error}")
    else:
        logger.info(f"Deleted file: {file_path}")
def move_file_to_archive(file_path: str):
    """Move *file_path* into IMAGES_DIR, creating the archive dir if needed.

    Errors are logged, never raised. A file with the same basename already in
    the archive is overwritten by shutil.move's rename semantics.
    """
    try:
        # exist_ok avoids the check-then-create race of the previous
        # `if not os.path.exists(...)` guard.
        os.makedirs(IMAGES_DIR, exist_ok=True)
        filename = os.path.basename(file_path)
        destination = os.path.join(IMAGES_DIR, filename)
        shutil.move(file_path, destination)
        logger.info(f"Moved file to archive: {destination}")
    except Exception as error:
        logger.error(f"Failed to move file {file_path} to archive: {error}")
def save_image_high_quality(
    image: np.ndarray,
    output_path: str,
    quality: int = SAVE_QUALITY,
    *,
    upload_to_bos: bool = True,
) -> bool:
    """
    Persist *image* as WebP at the requested quality.

    :param image: image array
    :param output_path: destination path
    :param quality: WebP quality (0-100)
    :param upload_to_bos: when True, mirror the file to BOS after writing
    :return: whether the image was saved
    """
    try:
        ok, payload = cv2.imencode(
            ".webp", image, [cv2.IMWRITE_WEBP_QUALITY, quality]
        )
        if not ok:
            logger.error(f"Image encoding failed: {output_path}")
            return False
        with open(output_path, "wb") as fh:
            fh.write(payload)
        logger.info(f"High quality image saved successfully: {output_path}, quality: {quality}, size: {len(payload) / 1024:.2f} KB")
        if upload_to_bos:
            upload_file_to_bos(output_path)
        return True
    except Exception as e:
        logger.error(f"Failed to save image: {output_path}, error: {e}")
        return False
def convert_numpy_types(obj):
    """Recursively convert numpy values inside *obj* to native Python types.

    Generalized from the original float32/float64/int32/int64-only checks:
    any numpy floating/integer/bool scalar is converted, numpy arrays become
    nested Python lists, and tuple elements are converted (tuples stay
    tuples). Dicts and lists are handled as before; unknown types pass
    through unchanged.
    """
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.ndarray):
        # tolist() already yields native Python scalars recursively.
        return obj.tolist()
    if isinstance(obj, dict):
        return {k: convert_numpy_types(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(i) for i in obj]
    if isinstance(obj, tuple):
        return tuple(convert_numpy_types(i) for i in obj)
    return obj
def compress_image_by_quality(image: np.ndarray, quality: int, output_format: str = 'webp') -> tuple[bytes, dict]:
    """
    Compress an image at the requested quality.

    :param image: input image
    :param quality: compression quality (10-100)
    :param output_format: output format ('jpg', 'png', 'webp')
    :return: (compressed image bytes, compression info dict)
    """
    try:
        height, width = image.shape[:2]
        fmt = output_format.lower()
        if fmt == 'png':
            # PNG uses a 0-9 compression level; map the quality value onto it.
            level = max(0, min(9, int((100 - quality) / 10)))
            ok, encoded = cv2.imencode(
                ".png", image, [cv2.IMWRITE_PNG_COMPRESSION, level]
            )
        elif fmt == 'webp':
            ok, encoded = cv2.imencode(
                ".webp", image, [cv2.IMWRITE_WEBP_QUALITY, quality]
            )
        else:
            ok, encoded = cv2.imencode(
                ".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality]
            )
        if not ok:
            raise Exception("图像编码失败")
        payload = encoded.tobytes()
        return payload, {
            'original_dimensions': f"{width} × {height}",
            'compressed_dimensions': f"{width} × {height}",
            'quality': quality,
            'format': output_format.upper(),
            'size': len(payload),
        }
    except Exception as e:
        logger.error(f"Failed to compress image by quality: {e}")
        raise
def compress_image_by_dimensions(image: np.ndarray, target_width: int, target_height: int,
                                 quality: int = 100, output_format: str = 'jpg') -> tuple[bytes, dict]:
    """
    Resize an image to exact dimensions, then encode it.

    :param image: input image
    :param target_width: target width in pixels
    :param target_height: target height in pixels
    :param quality: compression quality
    :param output_format: output format ('jpg', 'png', 'webp')
    :return: (compressed image bytes, compression info dict)
    """
    try:
        source_height, source_width = image.shape[:2]
        # INTER_AREA gives the best quality for downscaling.
        resized = cv2.resize(
            image, (target_width, target_height),
            interpolation=cv2.INTER_AREA
        )
        fmt = output_format.lower()
        if fmt == 'png':
            level = max(0, min(9, int((100 - quality) / 10)))
            ok, encoded = cv2.imencode(
                ".png", resized, [cv2.IMWRITE_PNG_COMPRESSION, level]
            )
        elif fmt == 'webp':
            ok, encoded = cv2.imencode(
                ".webp", resized, [cv2.IMWRITE_WEBP_QUALITY, quality]
            )
        else:
            ok, encoded = cv2.imencode(
                ".jpg", resized, [cv2.IMWRITE_JPEG_QUALITY, quality]
            )
        if not ok:
            raise Exception("图像编码失败")
        payload = encoded.tobytes()
        return payload, {
            'original_dimensions': f"{source_width} × {source_height}",
            'compressed_dimensions': f"{target_width} × {target_height}",
            'quality': quality,
            'format': output_format.upper(),
            'size': len(payload),
        }
    except Exception as e:
        logger.error(f"Failed to compress image by dimensions: {e}")
        raise
def compress_image_by_file_size(image: np.ndarray, target_size_kb: float,
                                output_format: str = 'jpg') -> tuple[bytes, dict]:
    """
    Compress an image toward a target file size with a multi-stage search:
    for each candidate downscale factor, binary-search the quality setting.

    :param image: input image
    :param target_size_kb: target file size in KB
    :param output_format: output format ('jpg', 'png', 'webp')
    :return: (compressed image bytes, compression info dict)
    :raises Exception: when no encode at any scale/quality succeeds
    """
    try:
        original_height, original_width = image.shape[:2]
        target_size_bytes = int(target_size_kb * 1024)

        def encode_image(img, quality):
            """Encode *img* in the requested format; return bytes, or None on failure."""
            if output_format.lower() == 'png':
                # Map quality onto PNG's 0-9 compression level.
                compression_level = max(0, min(9, int((100 - quality) / 10)))
                success, encoded_img = cv2.imencode(
                    ".png", img, [cv2.IMWRITE_PNG_COMPRESSION, compression_level]
                )
            elif output_format.lower() == 'webp':
                success, encoded_img = cv2.imencode(
                    ".webp", img, [cv2.IMWRITE_WEBP_QUALITY, quality]
                )
            else:
                success, encoded_img = cv2.imencode(
                    ".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, quality]
                )
            if success:
                return encoded_img.tobytes()
            return None

        def find_best_scale_and_quality(target_bytes):
            """Probe (scale, quality) combinations; return the best candidate dict or None."""
            best_result = None
            best_diff = float('inf')
            # Candidate downscale factors, from full size down to 30%.
            test_scales = [1.0, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3]
            for scale in test_scales:
                # Resize for this scale; skip results that would drop below 50px.
                if scale < 1.0:
                    new_width = int(original_width * scale)
                    new_height = int(original_height * scale)
                    if new_width < 50 or new_height < 50:
                        continue
                    working_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
                else:
                    working_image = image
                    new_width, new_height = original_width, original_height
                # Binary-search the quality setting at this scale.
                min_q, max_q = 10, 100
                scale_best_result = None
                scale_best_diff = float('inf')
                for _ in range(20):  # at most 20 quality probes per scale
                    current_quality = (min_q + max_q) // 2
                    compressed_bytes = encode_image(working_image, current_quality)
                    if not compressed_bytes:
                        break
                    current_size = len(compressed_bytes)
                    size_diff = abs(current_size - target_bytes)
                    size_ratio = current_size / target_bytes
                    # Within 1% of target: accept immediately.
                    if 0.99 <= size_ratio <= 1.01:
                        return {
                            'bytes': compressed_bytes,
                            'scale': scale,
                            'width': new_width,
                            'height': new_height,
                            'quality': current_quality,
                            'size': current_size,
                            'ratio': size_ratio
                        }
                    # Track the closest result seen at this scale.
                    if size_diff < scale_best_diff:
                        scale_best_diff = size_diff
                        scale_best_result = {
                            'bytes': compressed_bytes,
                            'scale': scale,
                            'width': new_width,
                            'height': new_height,
                            'quality': current_quality,
                            'size': current_size,
                            'ratio': size_ratio
                        }
                    # Narrow the quality window toward the target size.
                    if current_size > target_bytes:
                        max_q = current_quality - 1
                    else:
                        min_q = current_quality + 1
                    if min_q >= max_q:
                        break
                # Promote this scale's best candidate into the global best.
                if scale_best_result and scale_best_diff < best_diff:
                    best_diff = scale_best_diff
                    best_result = scale_best_result
                # Within 5% already: stop trying smaller scales.
                if best_result and 0.95 <= best_result['ratio'] <= 1.05:
                    break
            return best_result

        logger.info(f"Starting multi-stage compression, target size: {target_size_bytes} bytes ({target_size_kb}KB)")
        # Run the search for the best (scale, quality) combination.
        result = find_best_scale_and_quality(target_size_bytes)
        if result:
            error_percent = abs(result['ratio'] - 1) * 100
            logger.info(f"Compression completed: scale ratio {result['scale']:.2f}, quality {result['quality']}%, "
                        f"size {result['size']} bytes, error {error_percent:.2f}%")
            # Always return the closest result regardless of error; just warn.
            if error_percent > 10:
                if result['ratio'] < 0.5:  # over-compressed far below target
                    suggested_size = result['size'] / 1024
                    logger.warning(f"Target size {target_size_kb}KB is too small, actually compressed to {suggested_size:.1f}KB, error {error_percent:.1f}%")
                elif result['ratio'] > 2.0:  # could not shrink to target
                    suggested_size = result['size'] / 1024
                    logger.warning(f"Target size {target_size_kb}KB is too large, minimum can be compressed to {suggested_size:.1f}KB, error {error_percent:.1f}%")
                else:
                    logger.warning(f"Cannot achieve target accuracy, error {error_percent:.1f}%, returning closest result")
            info = {
                'original_dimensions': f"{original_width} × {original_height}",
                'compressed_dimensions': f"{result['width']} × {result['height']}",
                'quality': result['quality'],
                'format': output_format.upper(),
                'size': result['size']
            }
            return result['bytes'], info
        else:
            raise Exception(f"无法将图片压缩到目标大小 {target_size_kb}KB")
    except Exception as e:
        logger.error(f"Failed to compress image by file size: {e}")
        raise
def convert_image_format(image: np.ndarray, target_format: str, quality: int = 100) -> tuple[bytes, dict]:
    """
    Re-encode an image into another format.

    :param image: input image
    :param target_format: target format ('jpg', 'png', 'webp')
    :param quality: quality parameter (ignored for PNG, which uses a fixed level)
    :return: (converted image bytes, format info dict)
    """
    try:
        height, width = image.shape[:2]
        fmt = target_format.lower()
        if fmt == 'png':
            # PNG: fixed default compression level.
            ok, encoded = cv2.imencode(
                ".png", image, [cv2.IMWRITE_PNG_COMPRESSION, 6]
            )
        elif fmt == 'webp':
            ok, encoded = cv2.imencode(
                ".webp", image, [cv2.IMWRITE_WEBP_QUALITY, quality]
            )
        else:
            ok, encoded = cv2.imencode(
                ".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality]
            )
        if not ok:
            raise Exception("图像格式转换失败")
        payload = encoded.tobytes()
        return payload, {
            'original_dimensions': f"{width} × {height}",
            'compressed_dimensions': f"{width} × {height}",
            # PNG is lossless, so report quality as 100 for it.
            'quality': quality if fmt != 'png' else 100,
            'format': target_format.upper(),
            'size': len(payload),
        }
    except Exception as e:
        logger.error(f"Image format conversion failed: {e}")
        raise
def save_image_with_transparency(image: np.ndarray, file_path: str) -> bool:
    """
    Save an image with an alpha channel as PNG (via PIL).

    :param image: OpenCV image array (BGRA preferred; BGR gets an opaque alpha)
    :param file_path: destination path
    :return: whether the image was saved
    """
    if image is None:
        logger.error("Image is empty, cannot save")
        return False
    try:
        # Fix: os.makedirs("") raises FileNotFoundError when file_path has no
        # directory component; only create the parent when one exists.
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        if len(image.shape) == 3 and image.shape[2] == 4:
            # BGRA -> RGBA for PIL.
            rgba_image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
        elif len(image.shape) == 3 and image.shape[2] == 3:
            # BGR input: convert to RGB and attach a fully opaque alpha channel.
            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            rgba_image = np.dstack((rgb_image, np.full(rgb_image.shape[:2], 255, dtype=np.uint8)))
        else:
            logger.error("Image format does not support transparency saving")
            return False
        pil_image = Image.fromarray(rgba_image, 'RGBA')
        pil_image.save(file_path, 'PNG', optimize=True)
        file_size = os.path.getsize(file_path)
        logger.info(f"Transparent PNG image saved: {file_path}, size: {file_size/1024:.1f}KB")
        upload_file_to_bos(file_path)
        return True
    except Exception as e:
        logger.error(f"Failed to save transparent PNG image: {e}")
        return False