# Source: Z-Image / models/model_manager.py
# Uploaded by AndrewKapok with huggingface_hub
# Commit message: "Upload models/model_manager.py with huggingface_hub"
# Commit: 70f549a (verified)
import os
import hashlib
import logging
from typing import Optional, Dict, Any
from huggingface_hub import snapshot_download, HfApi, HfFolder
from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError
logger = logging.getLogger(__name__)


class ModelManager:
    """Download, verify and track a Hugging Face model snapshot on local disk.

    The manager wraps ``huggingface_hub.snapshot_download`` and
    ``HfApi.model_info`` and adds a few local helpers: a presence check for
    the expected model files, a file hashing utility and a ``.version``
    marker file stored inside ``model_dir``.

    Attributes are plain instance attributes set by ``__init__``:
    ``model_name``, ``model_dir``, ``revision``, ``cache_dir`` and ``api``.
    """

    # Files that must all be present for ``verify_model_integrity`` to pass.
    REQUIRED_FILES = ("model.safetensors", "config.json", "tokenizer.json")
    # Smaller set used by the quick ``is_model_available`` check.
    AVAILABILITY_FILES = ("model.safetensors", "config.json")
    # Name of the marker file (inside ``model_dir``) holding the saved version.
    VERSION_FILE = ".version"

    def __init__(
        self,
        model_name: str = "qwen/qwen-image-2512",
        model_dir: str = "./models",
        revision: str = "main",
        cache_dir: Optional[str] = None
    ):
        """Store the configuration and make sure ``model_dir`` exists.

        Args:
            model_name: Repository id passed to the Hub as ``repo_id``.
            model_dir: Local directory the snapshot is downloaded into.
            revision: Revision (branch, tag or commit) to download.
            cache_dir: Hub cache directory; when falsy it defaults to
                ``~/.cache/huggingface/hub``.
        """
        self.model_name = model_name
        self.model_dir = model_dir
        self.revision = revision
        self.cache_dir = cache_dir or os.path.join(
            os.path.expanduser("~"), ".cache", "huggingface", "hub"
        )
        self.api = HfApi()
        # Make sure the target directory exists before anything is written.
        os.makedirs(self.model_dir, exist_ok=True)

    def download_model(self) -> str:
        """Download the configured model snapshot into ``model_dir``.

        Returns:
            The path returned by ``snapshot_download``.

        Raises:
            RepositoryNotFoundError: The repository does not exist.
            RevisionNotFoundError: The revision does not exist.
            Exception: Any other download failure, logged and re-raised.
        """
        try:
            logger.info("开始下载模型: %s (版本: %s)", self.model_name, self.revision)
            # NOTE(review): ``local_dir_use_symlinks`` is kept to preserve the
            # original call; it appears to be deprecated in recent
            # huggingface_hub releases — confirm against the pinned version.
            model_path = snapshot_download(
                repo_id=self.model_name,
                revision=self.revision,
                cache_dir=self.cache_dir,
                local_dir=self.model_dir,
                local_dir_use_symlinks=False,
                max_workers=4  # cap download threads to limit resource usage
            )
            logger.info("模型下载完成: %s", model_path)
            return model_path
        except RepositoryNotFoundError:
            logger.error("模型仓库不存在: %s", self.model_name)
            raise
        except RevisionNotFoundError:
            logger.error("模型版本不存在: %s", self.revision)
            raise
        except Exception as e:
            logger.error("模型下载失败: %s", e)
            raise

    def verify_model_integrity(self, model_path: str) -> bool:
        """Check that every file in ``REQUIRED_FILES`` exists in ``model_path``.

        Only file presence is checked; file contents are not inspected.

        Args:
            model_path: Directory expected to contain the model files.

        Returns:
            True when all required files exist; False when any is missing or
            when the check itself raises.
        """
        try:
            logger.info("验证模型完整性: %s", model_path)
            # Kept as a list so the logged representation stays unchanged.
            missing_files = [
                file_name
                for file_name in self.REQUIRED_FILES
                if not os.path.exists(os.path.join(model_path, file_name))
            ]
            if missing_files:
                logger.error("缺少必要的模型文件: %s", missing_files)
                return False
            logger.info("模型完整性验证通过")
            return True
        except Exception as e:
            # Best-effort check: any failure is reported as "not intact".
            logger.error("模型完整性验证失败: %s", e)
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Fetch metadata for the configured model from the Hub.

        File metadata is requested explicitly so that per-file sizes are
        populated, and ``size`` is the sum over all files that report one.

        Returns:
            A dict with the keys ``model_name``, ``revision`` (the commit
            sha), ``created_at``, ``last_modified``, ``size`` and ``tags``.

        Raises:
            Exception: Any failure from the Hub call, logged and re-raised.
        """
        try:
            logger.info("获取模型信息: %s", self.model_name)
            model_info = self.api.model_info(
                repo_id=self.model_name,
                revision=self.revision,
                files_metadata=True
            )
            # Files without a reported size count as 0 bytes.
            total_size = sum(
                (sibling.size or 0) for sibling in (model_info.siblings or [])
            )
            return {
                "model_name": model_info.id,
                "revision": model_info.sha,
                "created_at": model_info.created_at,
                "last_modified": model_info.last_modified,
                "size": total_size,
                "tags": model_info.tags
            }
        except Exception as e:
            logger.error("获取模型信息失败: %s", e)
            raise

    def is_model_available(self) -> bool:
        """Report whether the model already appears to be present locally.

        Returns:
            True when ``model_dir`` exists and contains every file listed in
            ``AVAILABILITY_FILES``; False otherwise or when the check raises.
        """
        try:
            if not os.path.exists(self.model_dir):
                return False
            return all(
                os.path.exists(os.path.join(self.model_dir, file_name))
                for file_name in self.AVAILABILITY_FILES
            )
        except Exception as e:
            logger.error("检查模型可用性失败: %s", e)
            return False

    def calculate_file_hash(
        self,
        file_path: str,
        hash_algorithm: str = "sha256",
        chunk_size: int = 1024 * 1024
    ) -> str:
        """Return the hex digest of a file, reading it in chunks.

        Args:
            file_path: Path of the file to hash.
            hash_algorithm: Algorithm name accepted by ``hashlib.new``.
            chunk_size: Bytes read per iteration; must be positive. The
                digest does not depend on this value.

        Returns:
            The hexadecimal digest string.

        Raises:
            ValueError: ``chunk_size`` is not positive.
            Exception: Any hashing or I/O failure, logged and re-raised.
        """
        if chunk_size <= 0:
            # A zero-sized read would end the loop immediately and silently
            # return the digest of empty input.
            raise ValueError("chunk_size must be a positive integer")
        try:
            hash_func = hashlib.new(hash_algorithm)
            with open(file_path, "rb") as f:
                # Chunked reads keep memory bounded for multi-GB weight files.
                for chunk in iter(lambda: f.read(chunk_size), b""):
                    hash_func.update(chunk)
            return hash_func.hexdigest()
        except Exception as e:
            logger.error("计算文件哈希值失败: %s", e)
            raise

    def get_local_model_version(self) -> Optional[str]:
        """Read the version stored in the ``.version`` file in ``model_dir``.

        Returns:
            The stripped file content, or None when the file does not exist
            or cannot be read.
        """
        try:
            version_file = os.path.join(self.model_dir, self.VERSION_FILE)
            if os.path.exists(version_file):
                with open(version_file, "r", encoding="utf-8") as f:
                    return f.read().strip()
            return None
        except Exception as e:
            logger.error("获取本地模型版本失败: %s", e)
            return None

    def save_model_version(self, version: str) -> None:
        """Write ``version`` to the ``.version`` file in ``model_dir``.

        Args:
            version: Version string to store; any previous content is
                overwritten.

        Raises:
            Exception: Any I/O failure, logged and re-raised.
        """
        try:
            version_file = os.path.join(self.model_dir, self.VERSION_FILE)
            with open(version_file, "w", encoding="utf-8") as f:
                f.write(version)
            logger.info("模型版本已保存: %s", version)
        except Exception as e:
            logger.error("保存模型版本失败: %s", e)
            raise

    def cleanup_old_models(self, keep_latest: int = 1) -> None:
        """Placeholder for removing old model versions.

        Currently this only logs the request; nothing is deleted.

        Args:
            keep_latest: Number of most recent versions that should be kept.
        """
        # TODO: walk the model directory and remove versions beyond
        # ``keep_latest`` once a versioned layout is defined.
        logger.info("清理旧模型版本,保留最新 %s 个版本", keep_latest)
if __name__ == "__main__":
    # Manual smoke test: download the default model unless it is already
    # present, verify it, and record the resolved revision locally.
    import logging

    logging.basicConfig(level=logging.INFO)

    manager = ModelManager()
    if manager.is_model_available():
        print("模型已存在")
    else:
        local_path = manager.download_model()
        verified = manager.verify_model_integrity(local_path)
        if not verified:
            print("模型验证失败")
        else:
            info = manager.get_model_info()
            manager.save_model_version(info["revision"])
            print("模型下载和验证成功")