DanbooruSearch / platform_utils.py
SAkizuki's picture
Auto-sync from GitHub Actions
a4bb90e verified
"""
platform_utils.py
统一的平台检测与 Hub 操作封装。
支持平台:
- HuggingFace Space
- ModelScope 创空间(魔搭)
- 本地开发环境
对外暴露:
PLATFORM : Literal['hf', 'ms', 'local']
is_cloud() : bool
get_host_port() : tuple[str, int]
download_file() : 下载单个文件,返回本地路径
upload_bytes() : 上传 bytes 到 OSS(用于计数器持久化)
read_bytes() : 从 OSS 读取文件内容,返回 bytes | None
get_counter_cfg() : 返回 CounterConfig(platform / available)
环境变量约定:
HuggingFace Space(由 HF 自动注入)
SPACE_ID Space 唯一标识,存在即代表在 HF 环境
SPACE_AUTHOR_NAME 作者名
用户手动配置(HF Secrets):
HF_TOKEN HF 访问令牌(仅用于 download_file,非计数器)
ModelScope 创空间(由魔搭自动注入)
MODELSCOPE_ENVIRONMENT 存在即代表在魔搭环境(值通常为 "studio")
STUDIO_ID 创空间 ID(备用检测)
魔搭平台数据文件说明:
数据文件(CSV / parquet / safetensors)直接放在创空间 studio repo
中,容器启动时会自动同步到工作目录,download_file() 在 MS 平台
直接返回本地路径,无需额外配置 Model repo。
阿里云 OSS(计数器唯一后端,HF 与 MS 共享同一数据)
OSS_ACCESS_KEY_ID RAM 子账号 AccessKey ID
OSS_ACCESS_KEY_SECRET RAM 子账号 AccessKey Secret
OSS_ENDPOINT Bucket 所在地域节点
例: oss-cn-hangzhou.aliyuncs.com
(无需加 https://,代码自动拼接)
OSS_BUCKET_NAME Bucket 名称
OSS_COUNTER_DIR 计数文件在 Bucket 中的前缀目录(可选)
默认 "danbooru_counter"
最终路径: {OSS_COUNTER_DIR}/count.json
"""
from __future__ import annotations
import os
import time
from dataclasses import dataclass
from pathlib import Path
import oss2
from typing import Literal, Optional
# 平台检测
def _detect_platform() -> Literal['hf', 'ms', 'local']:
if os.environ.get('SPACE_ID'):
return 'hf'
if os.environ.get('MODELSCOPE_ENVIRONMENT') or os.environ.get('STUDIO_ID'):
return 'ms'
return 'local'
PLATFORM: Literal['hf', 'ms', 'local'] = _detect_platform()
def is_cloud() -> bool:
"""是否运行在任意云端平台。"""
return PLATFORM in ('hf', 'ms')
def get_host_port() -> tuple[str, int]:
"""
返回 NiceGUI 应使用的 (host, port)。
HF 和魔搭创空间都使用 0.0.0.0:7860;本地使用 127.0.0.1:1111。
"""
if is_cloud():
return '0.0.0.0', 7860
return '127.0.0.1', 11111
def nsfw_allowed() -> bool:
"""
返回当前平台是否允许用户开启 NSFW 显示。
魔搭(MS)平台禁用 NSFW,其余平台默认允许。
如需在任意平台强制禁用,可设置环境变量 DISABLE_NSFW=1。
"""
if os.environ.get('DISABLE_NSFW', '0') == '1':
return False
return PLATFORM != 'ms'
# 阿里云 OSS
def _get_oss_bucket():
"""
从环境变量读取 OSS 配置,返回 oss2.Bucket 对象。
若环境变量不完整或 oss2 未安装则返回 None。
"""
ak = os.environ.get('OSS_ACCESS_KEY_ID')
sk = os.environ.get('OSS_ACCESS_KEY_SECRET')
ep = os.environ.get('OSS_ENDPOINT')
bkt = os.environ.get('OSS_BUCKET_NAME')
if not all([ak, sk, ep, bkt]):
return None
try:
import oss2
auth = oss2.Auth(ak, sk)
endpoint = ep if ep.startswith('http') else f'https://{ep}'
return oss2.Bucket(auth, endpoint, bkt)
except ImportError:
print('[PlatformUtils] oss2 未安装,OSS 计数器不可用。请 pip install oss2。')
return None
def _oss_key(filename: str) -> str:
"""将 filename 拼上可选的前缀目录,得到 OSS Object Key。"""
prefix = os.environ.get('OSS_COUNTER_DIR', 'danbooru_counter').rstrip('/')
return f'{prefix}/{filename}'
def _oss_available() -> bool:
"""检测 OSS 四项环境变量是否均已设置且 oss2 可导入。"""
return _get_oss_bucket() is not None
# 计数器配置
@dataclass
class CounterConfig:
platform: Literal['oss', 'local']
@property
def available(self) -> bool:
if self.platform == 'oss':
return _oss_available()
return False
def get_counter_cfg() -> CounterConfig:
"""
读取计数器配置。
配置了 OSS 环境变量则使用 OSS,否则退化为本地模式(无持久化)。
"""
if _oss_available():
return CounterConfig(platform='oss')
return CounterConfig(platform='local')
# 计数器读写(OSS)
def read_bytes(filename: str, cfg: CounterConfig) -> Optional[bytes]:
"""
从 OSS 读取文件内容,返回 bytes。
文件不存在返回 None;网络或权限异常向上抛出。
"""
if not cfg.available:
return None
bucket = _get_oss_bucket()
key = _oss_key(filename)
try:
import oss2
result = bucket.get_object(key)
return result.read()
except oss2.exceptions.NoSuchKey:
return None
except Exception as e:
print(f'[PlatformUtils] OSS 读取失败 ({key}): {e}')
raise
def upload_bytes(
content: bytes,
filename: str,
cfg: CounterConfig,
commit_message: str = 'Update',
*,
retries: int = 3,
retry_delay: float = 1.0,
) -> bool:
"""
将 bytes 写入 OSS 的 filename 路径。
返回 True 表示成功,False 表示全部重试均失败。
commit_message 参数保留以兼容 counter.py 的调用签名,OSS 不使用。
"""
if not cfg.available:
return False
bucket = _get_oss_bucket()
key = _oss_key(filename)
for attempt in range(retries):
try:
bucket.put_object(key, content)
return True
except Exception as e:
print(f'[PlatformUtils] OSS 上传失败(第 {attempt + 1} 次)({key}): {e}')
if attempt < retries - 1:
time.sleep(retry_delay)
return False
# 文件下载(引擎数据文件,与计数器无关)
# 魔搭创空间工作目录,studio repo 的文件会被同步到此处
_MS_WORKDIR = Path('/home/user/app')
# HF Storage Bucket 挂载检测
# HF Storage Buckets 挂载到 Space 时,会映射到容器内的一个本地路径
# (通常为 /data),文件可直接以本地路径读取,无需 hf_hub_download。
_HF_BUCKET_MOUNT = Path('/data')
def get_hf_bucket_path(relative: str) -> Optional[Path]:
"""
如果 HF Storage Bucket 已挂载且目标文件存在,返回本地绝对路径。
否则返回 None(调用方应 fallback 到 hf_hub_download)。
"""
candidate = _HF_BUCKET_MOUNT / relative
if candidate.exists():
return candidate
return None
def download_file(
filename: str,
*,
# HF 专用参数
hf_repo_id: Optional[str] = None,
hf_repo_type: str = 'space',
hf_token: Optional[str] = None,
# MS 专用参数(保留签名兼容性,魔搭平台已不再使用)
ms_repo_id: Optional[str] = None,
ms_token: Optional[str] = None,
ms_cache_dir: str = '/tmp/ms_cache',
) -> str:
"""
下载单个引擎数据文件,返回本地绝对路径字符串。
HF 平台:
优先从挂载的 Storage Bucket(/data)读取本地文件(零延迟)。
若 Bucket 未挂载或文件不存在,回退到从 Space repo 下载
(hf_repo_id 默认读取环境变量 SPACE_ID)。
MS 平台:
文件已随 studio repo 部署到容器本地,直接返回工作目录下的路径,
无需配置任何额外 repo。若文件不存在则抛出 FileNotFoundError。
本地:
直接返回原始路径。
"""
if PLATFORM == 'hf':
# 优先从挂载的 Storage Bucket 读取(本地路径,零延迟)
bucket_path = get_hf_bucket_path(filename)
if bucket_path is not None:
print(f'[PlatformUtils] 从 Storage Bucket 读取: {bucket_path}')
return str(bucket_path)
# Bucket 未挂载或文件不存在,回退到从 Space repo 下载
from huggingface_hub import hf_hub_download
repo_id = hf_repo_id or os.environ.get('SPACE_ID')
if not repo_id:
raise RuntimeError('[PlatformUtils] HF 平台未找到 SPACE_ID,无法下载文件。')
return hf_hub_download(
repo_id=repo_id,
repo_type=hf_repo_type,
filename=filename,
token=hf_token or os.environ.get('HF_TOKEN'),
)
if PLATFORM == 'ms':
local_path = _MS_WORKDIR / filename
if not local_path.is_file():
raise FileNotFoundError(
f'[PlatformUtils] 魔搭平台本地文件不存在: {local_path}\n'
f'请确认已将 {filename} 提交到创空间 studio repo 中。'
)
print(f'[PlatformUtils] MS 本地文件: {local_path}')
return str(local_path)
# 本地环境:直接返回原始路径(由调用方保证文件存在)
return filename
# 模型路径解析
LOCAL_MODEL_PATH = 'my_model_bge_m3'
HF_MODEL_ID = 'BAAI/bge-m3'
MS_MODEL_ID = 'BAAI/bge-m3' # 魔搭上同名,走国内节点
def resolve_model_path(prefer_local: Optional[str] = None) -> str:
"""
按优先级解析模型路径:
1. 本地目录(prefer_local 或 LOCAL_MODEL_PATH)
2. 当前平台的 Hub Model ID(首次会自动下载缓存)
返回可直接传给 SentenceTransformer 的路径或 model_id 字符串。
"""
local = prefer_local or LOCAL_MODEL_PATH
if os.path.exists(local):
print(f'[PlatformUtils] 使用本地模型: {local}')
return local
if PLATFORM == 'ms':
print(f'[PlatformUtils] 魔搭环境,使用 ModelScope Hub 模型: {MS_MODEL_ID}')
try:
from modelscope import snapshot_download
cached = snapshot_download(MS_MODEL_ID, cache_dir='/tmp/ms_model')
print(f'[PlatformUtils] 模型已缓存至: {cached}')
return cached
except Exception as e:
print(f'[PlatformUtils] ModelScope snapshot_download 失败,回退到 HF ID: {e}')
print(f'[PlatformUtils] 使用 HuggingFace Hub 模型: {HF_MODEL_ID}')
return HF_MODEL_ID