"""Helpers for downloading, verifying and versioning a Hugging Face model."""

import hashlib
import logging
import os
from typing import Any, Dict, Optional

from huggingface_hub import HfApi, HfFolder, snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError

logger = logging.getLogger(__name__)


class ModelManager:
    """Manage a local copy of a model hosted on the Hugging Face Hub.

    The manager downloads a repository snapshot into ``model_dir``, checks
    that expected files are present, queries Hub metadata and records the
    downloaded revision in a ``.version`` file inside ``model_dir``.
    """

    # Files that must exist for ``verify_model_integrity`` to succeed.
    _REQUIRED_FILES = ("model.safetensors", "config.json", "tokenizer.json")
    # Files that must exist for ``is_model_available`` to report True.
    # NOTE(review): this list is shorter than _REQUIRED_FILES (no
    # tokenizer.json); kept as in the original code -- confirm it is intended.
    _AVAILABILITY_FILES = ("model.safetensors", "config.json")
    # Name of the file (inside ``model_dir``) holding the saved revision.
    _VERSION_FILE_NAME = ".version"
    # Read size used when hashing files; model weights can be several GB,
    # so a larger chunk keeps the number of read calls low.
    _HASH_CHUNK_SIZE = 1024 * 1024

    def __init__(
        self,
        model_name: str = "qwen/qwen-image-2512",
        model_dir: str = "./models",
        revision: str = "main",
        cache_dir: Optional[str] = None,
    ):
        """Create the manager and make sure ``model_dir`` exists.

        Args:
            model_name: Hub repository id of the model.
            model_dir: Local directory the snapshot is downloaded into.
            revision: Branch, tag or commit to download.
            cache_dir: Hub cache directory; when falsy, defaults to
                ``~/.cache/huggingface/hub``.
        """
        self.model_name = model_name
        self.model_dir = model_dir
        self.revision = revision
        self.cache_dir = cache_dir or os.path.join(
            os.path.expanduser("~"), ".cache", "huggingface", "hub"
        )
        self.api = HfApi()

        # Make sure the target directory exists before anything is written.
        os.makedirs(self.model_dir, exist_ok=True)

    def download_model(self) -> str:
        """Download the model snapshot into ``model_dir``.

        Returns:
            The path returned by ``snapshot_download``.

        Raises:
            RepositoryNotFoundError: The repository does not exist.
            RevisionNotFoundError: The requested revision does not exist.
            Exception: Any other download failure is logged and re-raised.
        """
        try:
            logger.info("开始下载模型: %s (版本: %s)", self.model_name, self.revision)

            # NOTE(review): ``local_dir_use_symlinks`` may be deprecated in
            # newer huggingface_hub releases -- confirm against the pinned
            # version before removing it.
            model_path = snapshot_download(
                repo_id=self.model_name,
                revision=self.revision,
                cache_dir=self.cache_dir,
                local_dir=self.model_dir,
                local_dir_use_symlinks=False,
                max_workers=4,  # cap download threads to limit resource usage
            )

            logger.info("模型下载完成: %s", model_path)
            return model_path

        except RepositoryNotFoundError:
            logger.error("模型仓库不存在: %s", self.model_name)
            raise
        except RevisionNotFoundError:
            logger.error("模型版本不存在: %s", self.revision)
            raise
        except Exception as exc:
            logger.error("模型下载失败: %s", exc)
            raise

    def verify_model_integrity(self, model_path: str) -> bool:
        """Check that every required model file exists under ``model_path``.

        Only file presence is checked; file contents are not inspected.

        Args:
            model_path: Directory containing the downloaded model.

        Returns:
            True when all required files exist, False otherwise (including
            when the check itself fails with an exception).
        """
        try:
            logger.info("验证模型完整性: %s", model_path)

            missing_files = [
                name
                for name in self._REQUIRED_FILES
                if not os.path.exists(os.path.join(model_path, name))
            ]
            if missing_files:
                logger.error("缺少必要的模型文件: %s", missing_files)
                return False

            logger.info("模型完整性验证通过")
            return True

        except Exception as exc:
            # The error is swallowed, so keep the traceback in the log.
            logger.exception("模型完整性验证失败: %s", exc)
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Fetch metadata about the model from the Hub.

        Returns:
            A dict with the keys ``model_name``, ``revision`` (commit sha),
            ``created_at``, ``last_modified``, ``size`` (total size in bytes
            of all files in the repository, 0 when none are listed) and
            ``tags``.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        try:
            logger.info("获取模型信息: %s", self.model_name)

            # File sizes are only populated when file metadata is requested.
            model_info = self.api.model_info(
                repo_id=self.model_name,
                revision=self.revision,
                files_metadata=True,
            )

            # Sum every file instead of reporting only the first sibling,
            # which is an arbitrary file and not the model size.
            total_size = sum(
                sibling.size or 0 for sibling in (model_info.siblings or [])
            )

            return {
                "model_name": model_info.id,
                "revision": model_info.sha,
                "created_at": model_info.created_at,
                "last_modified": model_info.last_modified,
                "size": total_size,
                "tags": model_info.tags,
            }

        except Exception as exc:
            logger.error("获取模型信息失败: %s", exc)
            raise

    def is_model_available(self) -> bool:
        """Report whether the model files already exist in ``model_dir``.

        Returns:
            True when every expected file exists, False otherwise
            (including when the check itself fails with an exception).
        """
        try:
            # A missing directory implies missing files, so checking the
            # files alone covers the "directory does not exist" case.
            return all(
                os.path.exists(os.path.join(self.model_dir, name))
                for name in self._AVAILABILITY_FILES
            )
        except Exception as exc:
            logger.exception("检查模型可用性失败: %s", exc)
            return False

    def calculate_file_hash(self, file_path: str, hash_algorithm: str = "sha256") -> str:
        """Compute the hex digest of a file.

        Args:
            file_path: Path of the file to hash.
            hash_algorithm: Any algorithm name accepted by ``hashlib.new``;
                defaults to ``sha256``.

        Returns:
            The hexadecimal digest of the file contents.

        Raises:
            Exception: Failures (unknown algorithm, unreadable file, ...)
                are logged and re-raised.
        """
        try:
            hash_func = hashlib.new(hash_algorithm)
            with open(file_path, "rb") as f:
                # Read in chunks so large files never sit fully in memory.
                for chunk in iter(lambda: f.read(self._HASH_CHUNK_SIZE), b""):
                    hash_func.update(chunk)
            return hash_func.hexdigest()
        except Exception as exc:
            logger.error("计算文件哈希值失败: %s", exc)
            raise

    def get_local_model_version(self) -> Optional[str]:
        """Read the locally saved model version.

        Returns:
            The stripped contents of the ``.version`` file in ``model_dir``,
            or None when the file does not exist or cannot be read.
        """
        version_file = os.path.join(self.model_dir, self._VERSION_FILE_NAME)
        try:
            with open(version_file, "r", encoding="utf-8") as f:
                return f.read().strip()
        except FileNotFoundError:
            # No version has been saved yet; this is not an error.
            return None
        except Exception as exc:
            logger.exception("获取本地模型版本失败: %s", exc)
            return None

    def save_model_version(self, version: str) -> None:
        """Write ``version`` to the ``.version`` file in ``model_dir``.

        Args:
            version: Version string to persist.

        Raises:
            Exception: Write failures are logged and re-raised.
        """
        try:
            version_file = os.path.join(self.model_dir, self._VERSION_FILE_NAME)
            with open(version_file, "w", encoding="utf-8") as f:
                f.write(version)
            logger.info("模型版本已保存: %s", version)
        except Exception as exc:
            logger.error("保存模型版本失败: %s", exc)
            raise

    def cleanup_old_models(self, keep_latest: int = 1) -> None:
        """Clean up old model versions (placeholder).

        Currently this only logs the request; no files are removed. A real
        implementation could walk the model directory and delete old
        versions.

        Args:
            keep_latest: Number of most recent versions to keep.
        """
        logger.info("清理旧模型版本,保留最新 %s 个版本", keep_latest)


def main() -> None:
    """Download and verify the default model as a manual smoke test."""
    logging.basicConfig(level=logging.INFO)

    model_manager = ModelManager()

    if model_manager.is_model_available():
        print("模型已存在")
        return

    model_path = model_manager.download_model()
    if not model_manager.verify_model_integrity(model_path):
        print("模型验证失败")
        return

    model_info = model_manager.get_model_info()
    model_manager.save_model_version(model_info["revision"])
    print("模型下载和验证成功")


if __name__ == "__main__":
    main()