# Spaces: Runtime error
# (Hugging Face Spaces status banner captured with the file — not source code)
import hashlib
import logging
import os
from typing import Any, Dict, Optional

from huggingface_hub import HfApi, HfFolder, snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError
logger = logging.getLogger(__name__)


class ModelManager:
    """Download, verify and version a Hugging Face model snapshot on disk.

    The snapshot is materialized into ``model_dir`` (Hub cache kept in
    ``cache_dir``); a ``.version`` marker file inside ``model_dir`` records
    the revision that was last downloaded.
    """

    # Minimum files for the local copy to be considered usable.
    _CORE_FILES = ("model.safetensors", "config.json")
    # The full integrity check additionally requires the tokenizer file.
    _INTEGRITY_FILES = _CORE_FILES + ("tokenizer.json",)

    def __init__(
        self,
        model_name: str = "qwen/qwen-image-2512",
        model_dir: str = "./models",
        revision: str = "main",
        cache_dir: Optional[str] = None
    ):
        """
        Args:
            model_name: Hub repository id of the model.
            model_dir: Local directory the snapshot is copied into.
            revision: Branch, tag or commit sha to download.
            cache_dir: Hub cache directory; defaults to
                ``~/.cache/huggingface/hub``.
        """
        self.model_name = model_name
        self.model_dir = model_dir
        self.revision = revision
        self.cache_dir = cache_dir or os.path.join(
            os.path.expanduser("~"), ".cache", "huggingface", "hub"
        )
        self.api = HfApi()
        # Create the target directory up front so later file checks/writes
        # never fail on a missing parent.
        os.makedirs(self.model_dir, exist_ok=True)

    def download_model(self) -> str:
        """Download the model snapshot into ``model_dir``.

        Returns:
            str: Path of the downloaded snapshot.

        Raises:
            RepositoryNotFoundError: The repo id does not exist on the Hub.
            RevisionNotFoundError: The requested revision does not exist.
        """
        try:
            logger.info(f"开始下载模型: {self.model_name} (版本: {self.revision})")
            # snapshot_download handles caching and resume transparently.
            model_path = snapshot_download(
                repo_id=self.model_name,
                revision=self.revision,
                cache_dir=self.cache_dir,
                local_dir=self.model_dir,
                # NOTE(review): deprecated and ignored by recent
                # huggingface_hub releases (real files are always written);
                # kept for compatibility with older versions.
                local_dir_use_symlinks=False,
                max_workers=4  # cap download threads to limit resource usage
            )
            logger.info(f"模型下载完成: {model_path}")
            return model_path
        except RepositoryNotFoundError:
            logger.error(f"模型仓库不存在: {self.model_name}")
            raise
        except RevisionNotFoundError:
            logger.error(f"模型版本不存在: {self.revision}")
            raise
        except Exception as e:
            logger.error(f"模型下载失败: {str(e)}")
            raise

    def verify_model_integrity(self, model_path: str) -> bool:
        """Verify that the required model files exist under ``model_path``.

        Args:
            model_path: Directory containing the downloaded snapshot.

        Returns:
            bool: True when all required files are present, else False.
        """
        try:
            logger.info(f"验证模型完整性: {model_path}")
            # Presence check only — no checksum validation is performed here.
            missing_files = [
                name for name in self._INTEGRITY_FILES
                if not os.path.exists(os.path.join(model_path, name))
            ]
            if missing_files:
                logger.error(f"缺少必要的模型文件: {missing_files}")
                return False
            logger.info("模型完整性验证通过")
            return True
        except Exception as e:
            logger.error(f"模型完整性验证失败: {str(e)}")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Fetch repository metadata from the Hub.

        Returns:
            Dict[str, Any]: model name, resolved revision sha, timestamps,
            total size in bytes of all repo files, and tags.

        Raises:
            Exception: Propagated from the Hub API on any failure.
        """
        try:
            logger.info(f"获取模型信息: {self.model_name}")
            model_info = self.api.model_info(
                repo_id=self.model_name,
                revision=self.revision
            )
            # BUG FIX: previously only the first sibling's size was reported;
            # sum across every repo file (size may be None when unpopulated).
            total_size = sum(s.size or 0 for s in (model_info.siblings or []))
            return {
                "model_name": model_info.id,
                "revision": model_info.sha,
                "created_at": model_info.created_at,
                "last_modified": model_info.last_modified,
                "size": total_size,
                "tags": model_info.tags
            }
        except Exception as e:
            logger.error(f"获取模型信息失败: {str(e)}")
            raise

    def is_model_available(self) -> bool:
        """Check whether a usable model copy already exists locally.

        Returns:
            bool: True when ``model_dir`` exists and contains the core files.
        """
        try:
            if not os.path.exists(self.model_dir):
                return False
            # Cheaper check than verify_model_integrity: core files only.
            return all(
                os.path.exists(os.path.join(self.model_dir, name))
                for name in self._CORE_FILES
            )
        except Exception as e:
            logger.error(f"检查模型可用性失败: {str(e)}")
            return False

    def calculate_file_hash(self, file_path: str, hash_algorithm: str = "sha256") -> str:
        """Compute the hash digest of a file.

        Args:
            file_path: Path of the file to hash.
            hash_algorithm: Any algorithm accepted by ``hashlib.new``
                (default ``sha256``).

        Returns:
            str: Hex digest of the file contents.

        Raises:
            OSError: If the file cannot be read.
            ValueError: If ``hash_algorithm`` is unknown.
        """
        try:
            hash_func = hashlib.new(hash_algorithm)
            with open(file_path, "rb") as f:
                # Stream in 64 KiB chunks to keep memory usage flat on
                # multi-GB model files.
                for chunk in iter(lambda: f.read(65536), b""):
                    hash_func.update(chunk)
            return hash_func.hexdigest()
        except Exception as e:
            logger.error(f"计算文件哈希值失败: {str(e)}")
            raise

    def get_local_model_version(self) -> Optional[str]:
        """Read the locally recorded model version.

        Returns:
            Optional[str]: Contents of the ``.version`` marker file, or
            None when the marker does not exist or cannot be read.
        """
        try:
            version_file = os.path.join(self.model_dir, ".version")
            if os.path.exists(version_file):
                with open(version_file, "r") as f:
                    return f.read().strip()
            return None
        except Exception as e:
            logger.error(f"获取本地模型版本失败: {str(e)}")
            return None

    def save_model_version(self, version: str) -> None:
        """Persist the model version to the local ``.version`` marker file.

        Args:
            version: Revision identifier (e.g. a commit sha) to record.

        Raises:
            OSError: If the marker file cannot be written.
        """
        try:
            version_file = os.path.join(self.model_dir, ".version")
            with open(version_file, "w") as f:
                f.write(version)
            logger.info(f"模型版本已保存: {version}")
        except Exception as e:
            logger.error(f"保存模型版本失败: {str(e)}")
            raise

    def cleanup_old_models(self, keep_latest: int = 1) -> None:
        """Remove outdated model versions, keeping the most recent ones.

        Args:
            keep_latest: Number of newest versions to keep.

        NOTE(review): intentionally a stub — a real implementation would
        enumerate versioned snapshots under ``model_dir`` and delete all
        but the newest ``keep_latest``.
        """
        try:
            logger.info(f"清理旧模型版本,保留最新 {keep_latest} 个版本")
        except Exception as e:
            logger.error(f"清理旧模型版本失败: {str(e)}")
            raise
| if __name__ == "__main__": | |
| # 测试模型管理功能 | |
| import logging | |
| logging.basicConfig(level=logging.INFO) | |
| model_manager = ModelManager() | |
| if not model_manager.is_model_available(): | |
| model_path = model_manager.download_model() | |
| if model_manager.verify_model_integrity(model_path): | |
| model_info = model_manager.get_model_info() | |
| model_manager.save_model_version(model_info["revision"]) | |
| print("模型下载和验证成功") | |
| else: | |
| print("模型验证失败") | |
| else: | |
| print("模型已存在") | |