Spaces:
Running
Running
| """ | |
| platform_utils.py | |
| 统一的平台检测与 Hub 操作封装。 | |
| 支持平台: | |
| - HuggingFace Space | |
| - ModelScope 创空间(魔搭) | |
| - 本地开发环境 | |
| 对外暴露: | |
| PLATFORM : Literal['hf', 'ms', 'local'] | |
| is_cloud() : bool | |
| get_host_port() : tuple[str, int] | |
| download_file() : 下载单个文件,返回本地路径 | |
| upload_bytes() : 上传 bytes 到 OSS(用于计数器持久化) | |
| read_bytes() : 从 OSS 读取文件内容,返回 bytes | None | |
| get_counter_cfg() : 返回 CounterConfig(platform / available) | |
| 环境变量约定: | |
| HuggingFace Space(由 HF 自动注入) | |
| SPACE_ID Space 唯一标识,存在即代表在 HF 环境 | |
| SPACE_AUTHOR_NAME 作者名 | |
| 用户手动配置(HF Secrets): | |
| HF_TOKEN HF 访问令牌(仅用于 download_file,非计数器) | |
| ModelScope 创空间(由魔搭自动注入) | |
| MODELSCOPE_ENVIRONMENT 存在即代表在魔搭环境(值通常为 "studio") | |
| STUDIO_ID 创空间 ID(备用检测) | |
| 魔搭平台数据文件说明: | |
| 数据文件(CSV / parquet / safetensors)直接放在创空间 studio repo | |
| 中,容器启动时会自动同步到工作目录,download_file() 在 MS 平台 | |
| 直接返回本地路径,无需额外配置 Model repo。 | |
| 阿里云 OSS(计数器唯一后端,HF 与 MS 共享同一数据) | |
| OSS_ACCESS_KEY_ID RAM 子账号 AccessKey ID | |
| OSS_ACCESS_KEY_SECRET RAM 子账号 AccessKey Secret | |
| OSS_ENDPOINT Bucket 所在地域节点 | |
| 例: oss-cn-hangzhou.aliyuncs.com | |
| (无需加 https://,代码自动拼接) | |
| OSS_BUCKET_NAME Bucket 名称 | |
| OSS_COUNTER_DIR 计数文件在 Bucket 中的前缀目录(可选) | |
| 默认 "danbooru_counter" | |
| 最终路径: {OSS_COUNTER_DIR}/count.json | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import time | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import oss2 | |
| from typing import Literal, Optional | |
| # 平台检测 | |
| def _detect_platform() -> Literal['hf', 'ms', 'local']: | |
| if os.environ.get('SPACE_ID'): | |
| return 'hf' | |
| if os.environ.get('MODELSCOPE_ENVIRONMENT') or os.environ.get('STUDIO_ID'): | |
| return 'ms' | |
| return 'local' | |
| PLATFORM: Literal['hf', 'ms', 'local'] = _detect_platform() | |
| def is_cloud() -> bool: | |
| """是否运行在任意云端平台。""" | |
| return PLATFORM in ('hf', 'ms') | |
| def get_host_port() -> tuple[str, int]: | |
| """ | |
| 返回 NiceGUI 应使用的 (host, port)。 | |
| HF 和魔搭创空间都使用 0.0.0.0:7860;本地使用 127.0.0.1:1111。 | |
| """ | |
| if is_cloud(): | |
| return '0.0.0.0', 7860 | |
| return '127.0.0.1', 11111 | |
| def nsfw_allowed() -> bool: | |
| """ | |
| 返回当前平台是否允许用户开启 NSFW 显示。 | |
| 魔搭(MS)平台禁用 NSFW,其余平台默认允许。 | |
| 如需在任意平台强制禁用,可设置环境变量 DISABLE_NSFW=1。 | |
| """ | |
| if os.environ.get('DISABLE_NSFW', '0') == '1': | |
| return False | |
| return PLATFORM != 'ms' | |
| # 阿里云 OSS | |
| def _get_oss_bucket(): | |
| """ | |
| 从环境变量读取 OSS 配置,返回 oss2.Bucket 对象。 | |
| 若环境变量不完整或 oss2 未安装则返回 None。 | |
| """ | |
| ak = os.environ.get('OSS_ACCESS_KEY_ID') | |
| sk = os.environ.get('OSS_ACCESS_KEY_SECRET') | |
| ep = os.environ.get('OSS_ENDPOINT') | |
| bkt = os.environ.get('OSS_BUCKET_NAME') | |
| if not all([ak, sk, ep, bkt]): | |
| return None | |
| try: | |
| import oss2 | |
| auth = oss2.Auth(ak, sk) | |
| endpoint = ep if ep.startswith('http') else f'https://{ep}' | |
| return oss2.Bucket(auth, endpoint, bkt) | |
| except ImportError: | |
| print('[PlatformUtils] oss2 未安装,OSS 计数器不可用。请 pip install oss2。') | |
| return None | |
| def _oss_key(filename: str) -> str: | |
| """将 filename 拼上可选的前缀目录,得到 OSS Object Key。""" | |
| prefix = os.environ.get('OSS_COUNTER_DIR', 'danbooru_counter').rstrip('/') | |
| return f'{prefix}/{filename}' | |
| def _oss_available() -> bool: | |
| """检测 OSS 四项环境变量是否均已设置且 oss2 可导入。""" | |
| return _get_oss_bucket() is not None | |
| # 计数器配置 | |
| class CounterConfig: | |
| platform: Literal['oss', 'local'] | |
| def available(self) -> bool: | |
| if self.platform == 'oss': | |
| return _oss_available() | |
| return False | |
| def get_counter_cfg() -> CounterConfig: | |
| """ | |
| 读取计数器配置。 | |
| 配置了 OSS 环境变量则使用 OSS,否则退化为本地模式(无持久化)。 | |
| """ | |
| if _oss_available(): | |
| return CounterConfig(platform='oss') | |
| return CounterConfig(platform='local') | |
| # 计数器读写(OSS) | |
| def read_bytes(filename: str, cfg: CounterConfig) -> Optional[bytes]: | |
| """ | |
| 从 OSS 读取文件内容,返回 bytes。 | |
| 文件不存在返回 None;网络或权限异常向上抛出。 | |
| """ | |
| if not cfg.available: | |
| return None | |
| bucket = _get_oss_bucket() | |
| key = _oss_key(filename) | |
| try: | |
| import oss2 | |
| result = bucket.get_object(key) | |
| return result.read() | |
| except oss2.exceptions.NoSuchKey: | |
| return None | |
| except Exception as e: | |
| print(f'[PlatformUtils] OSS 读取失败 ({key}): {e}') | |
| raise | |
| def upload_bytes( | |
| content: bytes, | |
| filename: str, | |
| cfg: CounterConfig, | |
| commit_message: str = 'Update', | |
| *, | |
| retries: int = 3, | |
| retry_delay: float = 1.0, | |
| ) -> bool: | |
| """ | |
| 将 bytes 写入 OSS 的 filename 路径。 | |
| 返回 True 表示成功,False 表示全部重试均失败。 | |
| commit_message 参数保留以兼容 counter.py 的调用签名,OSS 不使用。 | |
| """ | |
| if not cfg.available: | |
| return False | |
| bucket = _get_oss_bucket() | |
| key = _oss_key(filename) | |
| for attempt in range(retries): | |
| try: | |
| bucket.put_object(key, content) | |
| return True | |
| except Exception as e: | |
| print(f'[PlatformUtils] OSS 上传失败(第 {attempt + 1} 次)({key}): {e}') | |
| if attempt < retries - 1: | |
| time.sleep(retry_delay) | |
| return False | |
| # 文件下载(引擎数据文件,与计数器无关) | |
| # 魔搭创空间工作目录,studio repo 的文件会被同步到此处 | |
| _MS_WORKDIR = Path('/home/user/app') | |
| # HF Storage Bucket 挂载检测 | |
| # HF Storage Buckets 挂载到 Space 时,会映射到容器内的一个本地路径 | |
| # (通常为 /data),文件可直接以本地路径读取,无需 hf_hub_download。 | |
| _HF_BUCKET_MOUNT = Path('/data') | |
| def get_hf_bucket_path(relative: str) -> Optional[Path]: | |
| """ | |
| 如果 HF Storage Bucket 已挂载且目标文件存在,返回本地绝对路径。 | |
| 否则返回 None(调用方应 fallback 到 hf_hub_download)。 | |
| """ | |
| candidate = _HF_BUCKET_MOUNT / relative | |
| if candidate.exists(): | |
| return candidate | |
| return None | |
| def download_file( | |
| filename: str, | |
| *, | |
| # HF 专用参数 | |
| hf_repo_id: Optional[str] = None, | |
| hf_repo_type: str = 'space', | |
| hf_token: Optional[str] = None, | |
| # MS 专用参数(保留签名兼容性,魔搭平台已不再使用) | |
| ms_repo_id: Optional[str] = None, | |
| ms_token: Optional[str] = None, | |
| ms_cache_dir: str = '/tmp/ms_cache', | |
| ) -> str: | |
| """ | |
| 下载单个引擎数据文件,返回本地绝对路径字符串。 | |
| HF 平台: | |
| 优先从挂载的 Storage Bucket(/data)读取本地文件(零延迟)。 | |
| 若 Bucket 未挂载或文件不存在,回退到从 Space repo 下载 | |
| (hf_repo_id 默认读取环境变量 SPACE_ID)。 | |
| MS 平台: | |
| 文件已随 studio repo 部署到容器本地,直接返回工作目录下的路径, | |
| 无需配置任何额外 repo。若文件不存在则抛出 FileNotFoundError。 | |
| 本地: | |
| 直接返回原始路径。 | |
| """ | |
| if PLATFORM == 'hf': | |
| # 优先从挂载的 Storage Bucket 读取(本地路径,零延迟) | |
| bucket_path = get_hf_bucket_path(filename) | |
| if bucket_path is not None: | |
| print(f'[PlatformUtils] 从 Storage Bucket 读取: {bucket_path}') | |
| return str(bucket_path) | |
| # Bucket 未挂载或文件不存在,回退到从 Space repo 下载 | |
| from huggingface_hub import hf_hub_download | |
| repo_id = hf_repo_id or os.environ.get('SPACE_ID') | |
| if not repo_id: | |
| raise RuntimeError('[PlatformUtils] HF 平台未找到 SPACE_ID,无法下载文件。') | |
| return hf_hub_download( | |
| repo_id=repo_id, | |
| repo_type=hf_repo_type, | |
| filename=filename, | |
| token=hf_token or os.environ.get('HF_TOKEN'), | |
| ) | |
| if PLATFORM == 'ms': | |
| local_path = _MS_WORKDIR / filename | |
| if not local_path.is_file(): | |
| raise FileNotFoundError( | |
| f'[PlatformUtils] 魔搭平台本地文件不存在: {local_path}\n' | |
| f'请确认已将 {filename} 提交到创空间 studio repo 中。' | |
| ) | |
| print(f'[PlatformUtils] MS 本地文件: {local_path}') | |
| return str(local_path) | |
| # 本地环境:直接返回原始路径(由调用方保证文件存在) | |
| return filename | |
| # 模型路径解析 | |
| LOCAL_MODEL_PATH = 'my_model_bge_m3' | |
| HF_MODEL_ID = 'BAAI/bge-m3' | |
| MS_MODEL_ID = 'BAAI/bge-m3' # 魔搭上同名,走国内节点 | |
| def resolve_model_path(prefer_local: Optional[str] = None) -> str: | |
| """ | |
| 按优先级解析模型路径: | |
| 1. 本地目录(prefer_local 或 LOCAL_MODEL_PATH) | |
| 2. 当前平台的 Hub Model ID(首次会自动下载缓存) | |
| 返回可直接传给 SentenceTransformer 的路径或 model_id 字符串。 | |
| """ | |
| local = prefer_local or LOCAL_MODEL_PATH | |
| if os.path.exists(local): | |
| print(f'[PlatformUtils] 使用本地模型: {local}') | |
| return local | |
| if PLATFORM == 'ms': | |
| print(f'[PlatformUtils] 魔搭环境,使用 ModelScope Hub 模型: {MS_MODEL_ID}') | |
| try: | |
| from modelscope import snapshot_download | |
| cached = snapshot_download(MS_MODEL_ID, cache_dir='/tmp/ms_model') | |
| print(f'[PlatformUtils] 模型已缓存至: {cached}') | |
| return cached | |
| except Exception as e: | |
| print(f'[PlatformUtils] ModelScope snapshot_download 失败,回退到 HF ID: {e}') | |
| print(f'[PlatformUtils] 使用 HuggingFace Hub 模型: {HF_MODEL_ID}') | |
| return HF_MODEL_ID |