# picpocket/utils.py
# Author: chawin.chen — fix (commit 017b111)
import base64
import hashlib
import os
import re
import shutil
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional
from collections import OrderedDict
import cv2
import numpy as np
from PIL import Image
try:
import boto3
from botocore.exceptions import BotoCoreError, ClientError
except ImportError:
boto3 = None
BotoCoreError = ClientError = Exception
from config import (
IMAGES_DIR,
logger,
SAVE_QUALITY,
MODELS_PATH,
BOS_ACCESS_KEY,
BOS_SECRET_KEY,
BOS_ENDPOINT,
BOS_BUCKET_NAME,
BOS_IMAGE_DIR,
BOS_UPLOAD_ENABLED,
BOS_DOWNLOAD_TARGETS,
HUGGINGFACE_REPO_ID,
HUGGINGFACE_SYNC_ENABLED,
HUGGINGFACE_REVISION,
HUGGINGFACE_ALLOW_PATTERNS,
HUGGINGFACE_IGNORE_PATTERNS,
)
_BOS_CLIENT = None
_BOS_CLIENT_INITIALIZED = False
_BOS_CLIENT_LOCK = threading.Lock()
_BOS_DOWNLOAD_LOCK = threading.Lock()
_BOS_DOWNLOAD_COMPLETED = False
_BOS_BACKGROUND_EXECUTOR = None
_BOS_BACKGROUND_FUTURES = []
_IMAGES_DIR_ABS = os.path.abspath(os.path.expanduser(IMAGES_DIR))
_BOS_UPLOAD_CACHE = OrderedDict()
_BOS_UPLOAD_CACHE_LOCK = threading.Lock()
_BOS_UPLOAD_CACHE_MAX = 2048
def _decode_bos_credential(raw_value: str) -> str:
    """Decode a Base64-encoded credential to plain text.

    Falls back to returning the (stripped, possibly padded) input when it is
    not valid Base64 or decodes to an empty string.
    """
    if not raw_value:
        return ""
    candidate = raw_value.strip()
    if not candidate:
        return ""
    # Pad to a multiple of four so lenient, unpadded inputs still decode.
    remainder = len(candidate) % 4
    padded = (candidate + "=" * (4 - remainder)) if remainder else candidate
    try:
        plain = base64.b64decode(padded).decode("utf-8").strip()
        if plain:
            return plain
    except Exception:
        pass
    # Matches the original behavior: the padded form is what gets returned.
    return padded
def _is_path_under_images_dir(file_path: str) -> bool:
    """Return True when *file_path* resolves to a location inside IMAGES_DIR."""
    candidate = os.path.abspath(file_path)
    try:
        common = os.path.commonpath([_IMAGES_DIR_ABS, candidate])
    except ValueError:
        # commonpath raises for mixed drives (Windows) or mixed abs/rel paths.
        return False
    return common == _IMAGES_DIR_ABS
def _get_bos_client():
    """Lazily build and cache the module-wide BOS (S3-compatible) client.

    Returns the cached client, or None when uploading is disabled, the
    credentials/endpoint are incomplete, boto3 is missing, or construction
    failed. The outcome (including None) is cached after the first attempt,
    so a failed initialization is never retried.
    """
    global _BOS_CLIENT, _BOS_CLIENT_INITIALIZED
    # Fast path: initialization was already attempted (successfully or not).
    if _BOS_CLIENT_INITIALIZED:
        return _BOS_CLIENT
    with _BOS_CLIENT_LOCK:
        # Double-checked locking: another thread may have finished while we waited.
        if _BOS_CLIENT_INITIALIZED:
            return _BOS_CLIENT
        if not BOS_UPLOAD_ENABLED:
            _BOS_CLIENT_INITIALIZED = True
            _BOS_CLIENT = None
            return None
        # Credentials may be stored Base64-encoded; decode or pass through.
        access_key = _decode_bos_credential(BOS_ACCESS_KEY)
        secret_key = _decode_bos_credential(BOS_SECRET_KEY)
        if not all([access_key, secret_key, BOS_ENDPOINT, BOS_BUCKET_NAME]):
            logger.warning("BOS 上传未配置完整,跳过初始化")
            _BOS_CLIENT_INITIALIZED = True
            _BOS_CLIENT = None
            return None
        if boto3 is None:
            logger.warning("未安装 boto3,BOS 上传功能不可用")
            _BOS_CLIENT_INITIALIZED = True
            _BOS_CLIENT = None
            return None
        try:
            # BOS speaks the S3 protocol, so a plain boto3 S3 client works.
            _BOS_CLIENT = boto3.client(
                "s3",
                aws_access_key_id=access_key,
                aws_secret_access_key=secret_key,
                endpoint_url=BOS_ENDPOINT,
            )
            logger.info("BOS 客户端初始化成功")
        except Exception as e:
            logger.warning(f"初始化 BOS 客户端失败,将跳过上传: {e}")
            _BOS_CLIENT = None
        finally:
            # Mark done even on failure so we never retry on every call.
            _BOS_CLIENT_INITIALIZED = True
        return _BOS_CLIENT
def _normalize_bos_prefix(prefix: Optional[str]) -> str:
    """Normalize a BOS object prefix to the canonical 'segment/' form.

    Empty/whitespace/slash-only input yields '' (meaning the bucket root).
    """
    trimmed = (prefix or "").strip().strip("/")
    if not trimmed:
        return ""
    # strip("/") guarantees no trailing slash remains, so always append one.
    return trimmed + "/"
def _directory_has_files(path: str) -> bool:
    """Return True if any regular file exists anywhere under *path*."""
    try:
        # os.walk swallows most I/O errors itself; any() stops at the first hit.
        return any(file_names for _dirpath, _dirnames, file_names in os.walk(path))
    except Exception:
        return False
def download_bos_directory(prefix: str, destination_dir: str, *, force_download: bool = False) -> bool:
    """
    Sync a prefixed "directory" of BOS objects into a local directory.
    :param prefix: BOS object prefix, e.g. 'models/' or '20220620/models'
    :param destination_dir: local destination directory
    :param force_download: re-download even when local files already exist
    :return: whether the directory content is confirmed usable
    """
    client = _get_bos_client()
    if client is None:
        logger.warning("BOS 客户端不可用,无法下载资源(prefix=%s)", prefix)
        return False
    dest_dir = os.path.abspath(os.path.expanduser(destination_dir))
    try:
        os.makedirs(dest_dir, exist_ok=True)
    except Exception as exc:
        logger.error("创建本地目录失败: %s (%s)", dest_dir, exc)
        return False
    normalized_prefix = _normalize_bos_prefix(prefix)
    # When not forcing and the directory already has files, skip entirely
    # to avoid re-downloading on every startup.
    if not force_download and _directory_has_files(dest_dir):
        logger.info("本地目录已存在文件,跳过下载: %s -> %s", normalized_prefix or "<root>", dest_dir)
        return True
    paginate_kwargs = {"Bucket": BOS_BUCKET_NAME}
    if normalized_prefix:
        paginate_kwargs["Prefix"] = normalized_prefix if normalized_prefix.endswith("/") else f"{normalized_prefix}/"
    found_any = False
    downloaded = 0
    skipped = 0
    try:
        paginator = client.get_paginator("list_objects_v2")
        for page in paginator.paginate(**paginate_kwargs):
            for obj in page.get("Contents", []):
                key = obj.get("Key")
                if not key:
                    continue
                if normalized_prefix:
                    prefix_with_slash = normalized_prefix if normalized_prefix.endswith("/") else f"{normalized_prefix}/"
                    if not key.startswith(prefix_with_slash):
                        continue
                    # Mirror the object's path relative to the prefix.
                    relative_key = key[len(prefix_with_slash):]
                else:
                    relative_key = key
                # Skip "directory placeholder" objects (keys ending in '/').
                if not relative_key or relative_key.endswith("/"):
                    continue
                found_any = True
                target_path = os.path.join(dest_dir, relative_key)
                target_dir = os.path.dirname(target_path)
                os.makedirs(target_dir, exist_ok=True)
                expected_size = obj.get("Size")
                # A byte-for-byte size match is treated as "already downloaded".
                if (
                    not force_download
                    and os.path.exists(target_path)
                    and expected_size is not None
                    and expected_size == os.path.getsize(target_path)
                ):
                    skipped += 1
                    logger.info("文件已存在且大小一致,跳过下载: %s", relative_key)
                    continue
                # Download to a temp name, then atomically swap into place so
                # readers never observe a partially written file.
                tmp_path = f"{target_path}.download"
                try:
                    size_mb = (expected_size or 0) / (1024 * 1024)
                    logger.info("开始下载: %s (%.2f MB)", relative_key, size_mb)
                    client.download_file(Bucket=BOS_BUCKET_NAME, Key=key, Filename=tmp_path)
                    os.replace(tmp_path, target_path)
                    downloaded += 1
                    logger.info("下载完成: %s", relative_key)
                except Exception as exc:
                    # Best-effort: a single failed object does not abort the sync.
                    logger.warning("下载失败: %s (%s)", key, exc)
                    try:
                        if os.path.exists(tmp_path):
                            os.remove(tmp_path)
                    except Exception:
                        pass
    except Exception as exc:
        logger.warning("遍历 BOS 目录失败: %s", exc)
        return False
    if not found_any:
        logger.warning("在 BOS 桶 %s 中未找到前缀 '%s' 的内容", BOS_BUCKET_NAME, normalized_prefix or "<root>")
        return False
    logger.info(
        "BOS 同步完成 prefix=%s -> %s 下载=%d 跳过=%d",
        normalized_prefix or "<root>",
        dest_dir,
        downloaded,
        skipped,
    )
    # True only when at least one object was downloaded or verified-skipped.
    return downloaded > 0 or skipped > 0
def _get_background_executor() -> ThreadPoolExecutor:
    """Return the shared executor used for background BOS downloads.

    Bug fix: a ``global`` declaration was missing, so the assignment below
    made ``_BOS_BACKGROUND_EXECUTOR`` function-local and the ``is None``
    check raised UnboundLocalError on every call.
    """
    global _BOS_BACKGROUND_EXECUTOR
    if _BOS_BACKGROUND_EXECUTOR is None:
        # Two workers suffice: background syncs are I/O-bound and infrequent.
        _BOS_BACKGROUND_EXECUTOR = ThreadPoolExecutor(max_workers=2, thread_name_prefix="bos-bg")
    return _BOS_BACKGROUND_EXECUTOR
def ensure_huggingface_models(force_download: bool = False) -> bool:
    """Ensure the HuggingFace model repo is synced into the local MODELS_PATH.

    Returns True when syncing is disabled/unconfigured or completed
    successfully; False when huggingface-hub is missing, the target directory
    cannot be created, or the download fails.
    :param force_download: re-download everything instead of resuming
    """
    if not HUGGINGFACE_SYNC_ENABLED:
        logger.info("HuggingFace 模型同步开关已关闭,跳过同步流程")
        return True
    repo_id = (HUGGINGFACE_REPO_ID or "").strip()
    if not repo_id:
        logger.info("未配置 HuggingFace 仓库,跳过模型下载")
        return True
    try:
        # Imported lazily so the dependency stays optional.
        from huggingface_hub import snapshot_download
    except ImportError:
        logger.error("未安装 huggingface-hub,无法下载 HuggingFace 模型")
        return False
    try:
        os.makedirs(MODELS_PATH, exist_ok=True)
    except Exception as exc:
        logger.error("创建模型目录失败: %s (%s)", MODELS_PATH, exc)
        return False
    # NOTE(review): `local_dir_use_symlinks` and `resume_download` are
    # deprecated in recent huggingface_hub releases — confirm the pinned
    # version still accepts them.
    download_kwargs = {
        "repo_id": repo_id,
        "local_dir": MODELS_PATH,
        "local_dir_use_symlinks": False,
    }
    revision = (HUGGINGFACE_REVISION or "").strip()
    if revision:
        download_kwargs["revision"] = revision
    if HUGGINGFACE_ALLOW_PATTERNS:
        download_kwargs["allow_patterns"] = HUGGINGFACE_ALLOW_PATTERNS
    if HUGGINGFACE_IGNORE_PATTERNS:
        download_kwargs["ignore_patterns"] = HUGGINGFACE_IGNORE_PATTERNS
    if force_download:
        download_kwargs["force_download"] = True
        download_kwargs["resume_download"] = False
    else:
        download_kwargs["resume_download"] = True
    try:
        logger.info(
            "开始同步 HuggingFace 模型: repo=%s revision=%s -> %s",
            repo_id,
            revision or "<default>",
            MODELS_PATH,
        )
        snapshot_path = snapshot_download(**download_kwargs)
        logger.info(
            "HuggingFace 模型同步完成: %s -> %s",
            repo_id,
            snapshot_path,
        )
        return True
    except Exception as exc:
        logger.error("HuggingFace 模型下载失败: %s", exc)
        return False
def ensure_bos_resources(force_download: bool = False, include_background: bool = False) -> bool:
    """
    Sync the model/data resources required at startup, per BOS_DOWNLOAD_TARGETS.
    :param force_download: force re-sync of every resource
    :param include_background: run background-flagged targets as blocking tasks too
    :return: whether the blocking resources are ready
    """
    global _BOS_DOWNLOAD_COMPLETED, _BOS_BACKGROUND_FUTURES
    # The lock serializes entire sync rounds; the background-done callbacks
    # re-acquire it later from executor threads.
    with _BOS_DOWNLOAD_LOCK:
        if _BOS_DOWNLOAD_COMPLETED and not force_download and not include_background:
            return True
        targets = BOS_DOWNLOAD_TARGETS or []
        if not targets:
            logger.info("未配置 BOS 下载目标,跳过资源同步")
            _BOS_DOWNLOAD_COMPLETED = True
            return True
        download_jobs = []
        background_jobs = []
        # Validate each configured target and split into blocking vs background.
        for target in targets:
            if not isinstance(target, dict):
                logger.warning("无效的 BOS 下载配置项: %r", target)
                continue
            prefix = target.get("bos_prefix")
            destination = target.get("destination")
            description = target.get("description") or prefix or "<unnamed>"
            background_flag = bool(target.get("background"))
            if not prefix or not destination:
                logger.warning("缺少必要字段,无法处理 BOS 下载配置: %r", target)
                continue
            job = {
                "description": description,
                "prefix": prefix,
                "destination": destination,
            }
            if background_flag and not include_background:
                background_jobs.append(job)
            else:
                download_jobs.append(job)
        results = []
        if download_jobs:
            # One worker per job, capped at the CPU count.
            max_workers = min(len(download_jobs), max(os.cpu_count() or 1, 1))
            if max_workers <= 0:
                max_workers = 1
            with ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="bos-sync") as executor:
                future_to_job = {}
                for job in download_jobs:
                    logger.info(
                        "准备同步 BOS 资源: %s (prefix=%s -> %s)",
                        job["description"],
                        job["prefix"],
                        job["destination"],
                    )
                    future = executor.submit(
                        download_bos_directory,
                        job["prefix"],
                        job["destination"],
                        force_download=force_download,
                    )
                    future_to_job[future] = job
                # Collect results as each job finishes, in completion order.
                for future in as_completed(future_to_job):
                    job = future_to_job[future]
                    description = job["description"]
                    try:
                        success = future.result()
                    except Exception as exc:
                        logger.warning("BOS 资源同步异常: %s (%s)", description, exc)
                        success = False
                    if success:
                        logger.info("BOS 资源已就绪: %s", description)
                    else:
                        logger.warning("BOS 资源同步失败: %s", description)
                    results.append(success)
        all_ready = all(results) if results else True
        if all_ready:
            _BOS_DOWNLOAD_COMPLETED = True
        if background_jobs:
            executor = _get_background_executor()
            def _make_callback(description: str):
                # Factory binds `description` per job, avoiding the classic
                # late-binding-closure bug inside the loop below.
                def _background_done(fut):
                    try:
                        success = fut.result()
                        if success:
                            logger.info("后台 BOS 资源已就绪: %s", description)
                        else:
                            logger.warning("后台 BOS 资源同步失败: %s", description)
                    except Exception as exc:
                        logger.warning("后台 BOS 资源同步异常: %s (%s)", description, exc)
                    finally:
                        # Runs in an executor thread, after this function has
                        # released _BOS_DOWNLOAD_LOCK.
                        with _BOS_DOWNLOAD_LOCK:
                            if fut in _BOS_BACKGROUND_FUTURES:
                                _BOS_BACKGROUND_FUTURES.remove(fut)
                return _background_done
            for job in background_jobs:
                logger.info(
                    "后台同步 BOS 资源: %s (prefix=%s -> %s)",
                    job["description"],
                    job["prefix"],
                    job["destination"],
                )
                future = executor.submit(
                    download_bos_directory,
                    job["prefix"],
                    job["destination"],
                    force_download=force_download,
                )
                future.add_done_callback(_make_callback(job["description"]))
                _BOS_BACKGROUND_FUTURES.append(future)
        return all_ready
def upload_file_to_bos(file_path: str, object_name: str | None = None) -> bool:
    """
    Synchronously upload the given file to BOS; never raises on failure.
    :param file_path: local file path
    :param object_name: BOS object name (optional; derived from the file name otherwise)
    :return: whether the upload succeeded (or was skipped as already uploaded)
    """
    if not BOS_UPLOAD_ENABLED:
        return False
    start_time = time.perf_counter()
    expanded_path = os.path.abspath(os.path.expanduser(file_path))
    if not os.path.isfile(expanded_path):
        return False
    if not _is_path_under_images_dir(expanded_path):
        # Only upload files inside IMAGES_DIR, to avoid syncing temp files to BOS
        return False
    try:
        file_stat = os.stat(expanded_path)
    except OSError:
        return False
    if _get_bos_client() is None:
        return False
    # Derive the object key
    if object_name:
        object_key = object_name.strip("/ ")
    else:
        base_name = os.path.basename(expanded_path)
        if BOS_IMAGE_DIR:
            object_key = "/".join(
                part.strip("/ ") for part in (BOS_IMAGE_DIR, base_name) if part
            )
        else:
            object_key = base_name
    # An (mtime_ns, size) signature detects content changes without hashing.
    mtime_ns = getattr(file_stat, "st_mtime_ns", int(file_stat.st_mtime * 1_000_000_000))
    cache_signature = (mtime_ns, file_stat.st_size)
    cache_key = (expanded_path, object_key)
    with _BOS_UPLOAD_CACHE_LOCK:
        cached_signature = _BOS_UPLOAD_CACHE.get(cache_key)
        if cached_signature is not None:
            # Refresh LRU order even when the signature no longer matches.
            _BOS_UPLOAD_CACHE.move_to_end(cache_key)
            if cached_signature == cache_signature:
                elapsed_ms = (time.perf_counter() - start_time) * 1000
                logger.info("文件已同步至 BOS(跳过重复上传,耗时 %.1f ms): %s", elapsed_ms, object_key)
                return True
    def _do_upload(mode_label: str) -> bool:
        # Perform the actual transfer and record success in the LRU cache.
        client_inner = _get_bos_client()
        if client_inner is None:
            return False
        upload_start = time.perf_counter()
        try:
            client_inner.upload_file(expanded_path, BOS_BUCKET_NAME, object_key)
            elapsed_ms = (time.perf_counter() - upload_start) * 1000
            logger.info("文件已同步至 BOS(%s,耗时 %.1f ms): %s", mode_label, elapsed_ms, object_key)
            with _BOS_UPLOAD_CACHE_LOCK:
                _BOS_UPLOAD_CACHE[cache_key] = cache_signature
                _BOS_UPLOAD_CACHE.move_to_end(cache_key)
                # Evict oldest entries to keep the cache bounded.
                while len(_BOS_UPLOAD_CACHE) > _BOS_UPLOAD_CACHE_MAX:
                    _BOS_UPLOAD_CACHE.popitem(last=False)
            return True
        except (ClientError, BotoCoreError, Exception) as exc:
            logger.warning("上传到 BOS 失败(%s,%s): %s", object_key, mode_label, exc)
            return False
    return _do_upload("同步")
def delete_file_from_bos(file_path: str | None = None,
                         object_name: str | None = None) -> bool:
    """
    Delete the given object from BOS; never raises on failure.
    :param file_path: local file path (optional, used to derive the object name)
    :param object_name: BOS object name (optional, takes precedence)
    :return: whether the deletion succeeded
    """
    if not BOS_UPLOAD_ENABLED:
        return False
    client = _get_bos_client()
    if client is None:
        return False
    # Prefer the explicit object name, falling back to the file's base name.
    candidate = object_name.strip("/ ") if object_name else ""
    if not candidate and file_path:
        resolved = os.path.abspath(os.path.expanduser(file_path))
        candidate = os.path.basename(resolved).strip()
    if not candidate:
        return False
    if BOS_IMAGE_DIR:
        object_key = "/".join(
            segment.strip("/ ")
            for segment in (BOS_IMAGE_DIR, candidate)
            if segment
        )
    else:
        object_key = candidate
    try:
        client.delete_object(Bucket=BOS_BUCKET_NAME, Key=object_key)
    except (ClientError, BotoCoreError, Exception) as e:
        logger.warning(f"删除 BOS 文件失败({object_key}): {e}")
        return False
    logger.info(f"已从 BOS 删除文件: {object_key}")
    return True
def image_to_base64(image: np.ndarray) -> str:
    """Encode an OpenCV image as a WebP data-URI base64 string ('' when empty)."""
    if image is None or image.size == 0:
        return ""
    # Quality 90 keeps the payload small while staying visually lossless-ish.
    _, encoded = cv2.imencode(".webp", image, [cv2.IMWRITE_WEBP_QUALITY, 90])
    payload = base64.b64encode(encoded).decode("utf-8")
    return "data:image/webp;base64," + payload
def save_base64_to_unique_file(
base64_string: str, output_dir: str = "output_images"
) -> str | None:
"""
将带有MIME类型前缀的Base64字符串解码并保存到本地。
文件名格式为: {md5_hash}_{timestamp}.{extension}
"""
os.makedirs(output_dir, exist_ok=True)
try:
match = re.match(r"data:(image/\w+);base64,(.+)", base64_string)
if match:
mime_type = match.group(1)
base64_data = match.group(2)
else:
mime_type = "image/jpeg"
base64_data = base64_string
extension_map = {
"image/jpeg": "jpg",
"image/png": "png",
"image/gif": "gif",
"image/webp": "webp",
}
file_extension = extension_map.get(mime_type, "webp")
decoded_data = base64.b64decode(base64_data)
except (ValueError, TypeError, base64.binascii.Error) as e:
logger.error(f"Base64 decoding failed: {e}")
return None
md5_hash = hashlib.md5(base64_data.encode("utf-8")).hexdigest()
filename = f"{md5_hash}.{file_extension}"
file_path = os.path.join(output_dir, filename)
try:
with open(file_path, "wb") as f:
f.write(decoded_data)
return file_path
except IOError as e:
logger.error(f"File writing failed: {e}")
return None
def human_readable_size(size_bytes):
    """Render a byte count as a short human-friendly string (B/KB/MB/GB/TB)."""
    value = size_bytes
    for unit in ("B", "KB", "MB", "GB"):
        if value < 1024:
            return f"{value:.1f} {unit}"
        value /= 1024
    # Anything past GB is reported in terabytes.
    return f"{value:.1f} TB"
def delete_file(file_path: str):
    """Remove *file_path* from disk, logging the outcome instead of raising."""
    try:
        os.remove(file_path)
    except Exception as error:
        logger.error(f"Failed to delete file {file_path}: {error}")
    else:
        logger.info(f"Deleted file: {file_path}")
def move_file_to_archive(file_path: str):
    """Move *file_path* into IMAGES_DIR, logging instead of raising on failure."""
    try:
        if not os.path.exists(IMAGES_DIR):
            os.makedirs(IMAGES_DIR)
        destination = os.path.join(IMAGES_DIR, os.path.basename(file_path))
        shutil.move(file_path, destination)
        logger.info(f"Moved file to archive: {destination}")
    except Exception as error:
        logger.error(f"Failed to move file {file_path} to archive: {error}")
def save_image_high_quality(
    image: np.ndarray,
    output_path: str,
    quality: int = SAVE_QUALITY,
    *,
    upload_to_bos: bool = True,
) -> bool:
    """
    Save an image as high-quality WebP without further compression.
    :param image: image array
    :param output_path: destination path
    :param quality: WebP quality (0-100)
    :param upload_to_bos: sync the file to BOS after writing
    :return: whether the save succeeded
    """
    try:
        ok, payload = cv2.imencode(".webp", image, [cv2.IMWRITE_WEBP_QUALITY, quality])
        if not ok:
            logger.error(f"Image encoding failed: {output_path}")
            return False
        with open(output_path, "wb") as fh:
            fh.write(payload)
        logger.info(f"High quality image saved successfully: {output_path}, quality: {quality}, size: {len(payload) / 1024:.2f} KB")
        if upload_to_bos:
            # Best-effort sync; upload_file_to_bos never raises.
            upload_file_to_bos(output_path)
        return True
    except Exception as e:
        logger.error(f"Failed to save image: {output_path}, error: {e}")
        return False
def convert_numpy_types(obj):
    """
    Recursively convert numpy values inside *obj* to native Python types.

    Generalized beyond the original float32/float64/int32/int64 handling:
    any numpy floating/integer scalar width (including float16, int8/16),
    numpy booleans, and numpy arrays (via ``tolist``) are converted, and
    tuples are recursed into (preserving tuple-ness). Anything else is
    returned unchanged.
    """
    # np.floating / np.integer are the abstract bases for every scalar width.
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.ndarray):
        # tolist() yields nested Python lists of numpy-free scalars.
        return convert_numpy_types(obj.tolist())
    if isinstance(obj, dict):
        return {k: convert_numpy_types(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(i) for i in obj]
    if isinstance(obj, tuple):
        return tuple(convert_numpy_types(i) for i in obj)
    return obj
def compress_image_by_quality(image: np.ndarray, quality: int, output_format: str = 'webp') -> tuple[bytes, dict]:
    """
    Compress an image by encoding quality.
    :param image: input image
    :param quality: compression quality (10-100)
    :param output_format: output format ('jpg', 'png', 'webp')
    :return: (compressed image bytes, compression info)
    """
    try:
        height, width = image.shape[:2]
        fmt = output_format.lower()
        if fmt == 'png':
            # PNG has no quality knob; map quality onto its 0-9 compression level.
            level = max(0, min(9, int((100 - quality) / 10)))
            ok, encoded = cv2.imencode(".png", image, [cv2.IMWRITE_PNG_COMPRESSION, level])
        elif fmt == 'webp':
            ok, encoded = cv2.imencode(".webp", image, [cv2.IMWRITE_WEBP_QUALITY, quality])
        else:
            ok, encoded = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality])
        if not ok:
            raise Exception("图像编码失败")
        payload = encoded.tobytes()
        info = {
            'original_dimensions': f"{width} × {height}",
            'compressed_dimensions': f"{width} × {height}",
            'quality': quality,
            'format': output_format.upper(),
            'size': len(payload)
        }
        return payload, info
    except Exception as e:
        logger.error(f"Failed to compress image by quality: {e}")
        raise
def compress_image_by_dimensions(image: np.ndarray, target_width: int, target_height: int,
                                 quality: int = 100, output_format: str = 'jpg') -> tuple[bytes, dict]:
    """
    Compress an image by resizing it to explicit dimensions.
    :param image: input image
    :param target_width: target width in pixels
    :param target_height: target height in pixels
    :param quality: encoding quality
    :param output_format: output format ('jpg', 'png', 'webp')
    :return: (compressed image bytes, compression info)
    """
    try:
        original_height, original_width = image.shape[:2]
        # INTER_AREA gives the best visual quality when shrinking.
        resized = cv2.resize(
            image, (target_width, target_height),
            interpolation=cv2.INTER_AREA
        )
        fmt = output_format.lower()
        if fmt == 'png':
            # PNG has no quality knob; map quality onto its 0-9 compression level.
            level = max(0, min(9, int((100 - quality) / 10)))
            ok, encoded = cv2.imencode(".png", resized, [cv2.IMWRITE_PNG_COMPRESSION, level])
        elif fmt == 'webp':
            ok, encoded = cv2.imencode(".webp", resized, [cv2.IMWRITE_WEBP_QUALITY, quality])
        else:
            ok, encoded = cv2.imencode(".jpg", resized, [cv2.IMWRITE_JPEG_QUALITY, quality])
        if not ok:
            raise Exception("图像编码失败")
        payload = encoded.tobytes()
        info = {
            'original_dimensions': f"{original_width} × {original_height}",
            'compressed_dimensions': f"{target_width} × {target_height}",
            'quality': quality,
            'format': output_format.upper(),
            'size': len(payload)
        }
        return payload, info
    except Exception as e:
        logger.error(f"Failed to compress image by dimensions: {e}")
        raise
def compress_image_by_file_size(image: np.ndarray, target_size_kb: float,
                                output_format: str = 'jpg') -> tuple[bytes, dict]:
    """
    Compress an image to a target file size using a multi-stage binary search
    over scale and quality for precise size control.
    :param image: input image
    :param target_size_kb: target file size (KB)
    :param output_format: output format
    :return: (compressed image bytes, compression info)
    """
    try:
        original_height, original_width = image.shape[:2]
        target_size_bytes = int(target_size_kb * 1024)
        def encode_image(img, quality):
            """Encode the image at *quality* and return its bytes (None on failure)."""
            if output_format.lower() == 'png':
                compression_level = max(0, min(9, int((100 - quality) / 10)))
                success, encoded_img = cv2.imencode(
                    ".png", img, [cv2.IMWRITE_PNG_COMPRESSION, compression_level]
                )
            elif output_format.lower() == 'webp':
                success, encoded_img = cv2.imencode(
                    ".webp", img, [cv2.IMWRITE_WEBP_QUALITY, quality]
                )
            else:
                success, encoded_img = cv2.imencode(
                    ".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, quality]
                )
            if success:
                return encoded_img.tobytes()
            return None
        def find_best_scale_and_quality(target_bytes):
            """Search scale/quality combinations for the closest match to *target_bytes*."""
            best_result = None
            best_diff = float('inf')
            # Try a ladder of progressively smaller scale ratios
            test_scales = [1.0, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3]
            for scale in test_scales:
                # Resize the image for this scale
                if scale < 1.0:
                    new_width = int(original_width * scale)
                    new_height = int(original_height * scale)
                    if new_width < 50 or new_height < 50:  # avoid degenerate sizes
                        continue
                    working_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
                else:
                    working_image = image
                    new_width, new_height = original_width, original_height
                # Binary-search the best quality at this scale
                min_q, max_q = 10, 100
                scale_best_result = None
                scale_best_diff = float('inf')
                for _ in range(20):  # at most 20 quality probes per scale
                    current_quality = (min_q + max_q) // 2
                    compressed_bytes = encode_image(working_image, current_quality)
                    if not compressed_bytes:
                        break
                    current_size = len(compressed_bytes)
                    size_diff = abs(current_size - target_bytes)
                    size_ratio = current_size / target_bytes
                    # Return immediately on an exact-enough match
                    if 0.99 <= size_ratio <= 1.01:  # within 1% error
                        return {
                            'bytes': compressed_bytes,
                            'scale': scale,
                            'width': new_width,
                            'height': new_height,
                            'quality': current_quality,
                            'size': current_size,
                            'ratio': size_ratio
                        }
                    # Track the best result at this scale
                    if size_diff < scale_best_diff:
                        scale_best_diff = size_diff
                        scale_best_result = {
                            'bytes': compressed_bytes,
                            'scale': scale,
                            'width': new_width,
                            'height': new_height,
                            'quality': current_quality,
                            'size': current_size,
                            'ratio': size_ratio
                        }
                    # Narrow the quality search window
                    if current_size > target_bytes:
                        max_q = current_quality - 1
                    else:
                        min_q = current_quality + 1
                    if min_q >= max_q:
                        break
                # Promote this scale's best into the global best
                if scale_best_result and scale_best_diff < best_diff:
                    best_diff = scale_best_diff
                    best_result = scale_best_result
                # Stop scanning scales early once within 5% of the target
                if best_result and 0.95 <= best_result['ratio'] <= 1.05:
                    break
            return best_result
        logger.info(f"Starting multi-stage compression, target size: {target_size_bytes} bytes ({target_size_kb}KB)")
        # Find the best scale/quality combination
        result = find_best_scale_and_quality(target_size_bytes)
        if result:
            error_percent = abs(result['ratio'] - 1) * 100
            logger.info(f"Compression completed: scale ratio {result['scale']:.2f}, quality {result['quality']}%, "
                        f"size {result['size']} bytes, error {error_percent:.2f}%")
            # Always return the closest result regardless of error; only warn
            if error_percent > 10:
                if result['ratio'] < 0.5:  # over-compressed
                    suggested_size = result['size'] / 1024
                    logger.warning(f"Target size {target_size_kb}KB is too small, actually compressed to {suggested_size:.1f}KB, error {error_percent:.1f}%")
                elif result['ratio'] > 2.0:  # target unreachable
                    suggested_size = result['size'] / 1024
                    logger.warning(f"Target size {target_size_kb}KB is too large, minimum can be compressed to {suggested_size:.1f}KB, error {error_percent:.1f}%")
                else:
                    logger.warning(f"Cannot achieve target accuracy, error {error_percent:.1f}%, returning closest result")
            info = {
                'original_dimensions': f"{original_width} × {original_height}",
                'compressed_dimensions': f"{result['width']} × {result['height']}",
                'quality': result['quality'],
                'format': output_format.upper(),
                'size': result['size']
            }
            return result['bytes'], info
        else:
            raise Exception(f"无法将图片压缩到目标大小 {target_size_kb}KB")
    except Exception as e:
        logger.error(f"Failed to compress image by file size: {e}")
        raise
def convert_image_format(image: np.ndarray, target_format: str, quality: int = 100) -> tuple[bytes, dict]:
    """
    Convert an image to another encoding format.
    :param image: input image
    :param target_format: target format ('jpg', 'png', 'webp')
    :param quality: quality parameter (ignored for PNG)
    :return: (converted image bytes, format info)
    """
    try:
        height, width = image.shape[:2]
        fmt = target_format.lower()
        if fmt == 'png':
            # PNG has no quality knob; use a fixed default compression level.
            ok, encoded = cv2.imencode(".png", image, [cv2.IMWRITE_PNG_COMPRESSION, 6])
        elif fmt == 'webp':
            ok, encoded = cv2.imencode(".webp", image, [cv2.IMWRITE_WEBP_QUALITY, quality])
        else:
            ok, encoded = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality])
        if not ok:
            raise Exception("图像格式转换失败")
        payload = encoded.tobytes()
        info = {
            'original_dimensions': f"{width} × {height}",
            'compressed_dimensions': f"{width} × {height}",
            'quality': quality if target_format.lower() != 'png' else 100,
            'format': target_format.upper(),
            'size': len(payload)
        }
        return payload, info
    except Exception as e:
        logger.error(f"Image format conversion failed: {e}")
        raise
def save_image_with_transparency(image: np.ndarray, file_path: str) -> bool:
    """
    Save an image carrying an alpha channel as a PNG file.
    :param image: OpenCV image array (BGRA, with alpha channel; plain BGR gets an opaque channel added)
    :param file_path: destination path
    :return: whether the save succeeded
    """
    if image is None:
        logger.error("Image is empty, cannot save")
        return False
    try:
        # Create the parent directory only when the path actually has one:
        # os.makedirs("") raises FileNotFoundError for bare file names, which
        # previously made every save to the current directory fail.
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        if len(image.shape) == 3 and image.shape[2] == 4:
            # BGRA -> RGBA (PIL expects RGBA channel order)
            rgba_image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
        elif len(image.shape) == 3 and image.shape[2] == 3:
            # BGR input has no alpha: convert to RGB and append an opaque channel
            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            rgba_image = np.dstack((rgb_image, np.full(rgb_image.shape[:2], 255, dtype=np.uint8)))
        else:
            logger.error("Image format does not support transparency saving")
            return False
        # PIL writes the RGBA PNG
        pil_image = Image.fromarray(rgba_image, 'RGBA')
        pil_image.save(file_path, 'PNG', optimize=True)
        file_size = os.path.getsize(file_path)
        logger.info(f"Transparent PNG image saved: {file_path}, size: {file_size/1024:.1f}KB")
        # Best-effort sync; upload_file_to_bos never raises.
        upload_file_to_bos(file_path)
        return True
    except Exception as e:
        logger.error(f"Failed to save transparent PNG image: {e}")
        return False