Spaces:
Sleeping
Sleeping
| from safetensors.torch import load_file | |
| import os | |
| import torch.nn as nn | |
class ExtraModules(nn.Module):
    """Container that registers named sub-modules and exposes them by name.

    Every module passed to the constructor or to ``add_modules`` is registered
    through ``nn.Module.add_module`` and can be read back either as an
    attribute (``extra.con_encoder``) or via ``get_module``.
    """

    def __init__(self, **modules_dict):
        super().__init__()
        # Alias the child registry nn.Module already maintains instead of
        # keeping a second, independent dict. A separate copy went stale
        # whenever a child was reassigned (``extra.name = new``) or deleted
        # (``del extra.name``) through normal attribute syntax.
        self._modules_dict = self._modules
        self.add_modules(**modules_dict)

    def add_modules(self, **modules_dict):
        """Register each keyword argument as a named sub-module."""
        for name, module in modules_dict.items():
            self.add_module(name, module)

    def __getattr__(self, name):
        """Resolve registered sub-modules by name, otherwise defer to nn.Module."""
        if '_modules_dict' in self.__dict__ and name in self._modules_dict:
            return self._modules_dict[name]
        return super().__getattr__(name)

    def get_module(self, name):
        """Return the sub-module registered as ``name``, or None if absent."""
        return self._modules_dict.get(name, None)

    def get_unwrap_dict(self, accelerator):
        """Return ``{name: accelerator.unwrap_model(module)}`` for every sub-module."""
        unwrap = accelerator.unwrap_model
        return {name: unwrap(module) for name, module in self._modules_dict.items()}

    def forward(self):
        # Container only: forward does nothing and returns None.
        pass

    def load_pretrained(self, dir_path, suffix=".safetensors", strict=True, verbose=True):
        """Load weights for every registered sub-module from ``dir_path``.

        The module registered as ``name`` is loaded from
        ``os.path.join(dir_path, name + suffix)``, e.g.
        ``con_encoder.safetensors``. Loading uses ``load_file`` followed by
        ``module.load_state_dict(state, strict=strict)``.

        Args:
            dir_path: directory containing the checkpoint files.
            suffix: file-name suffix appended to each module name.
            strict: forwarded to ``load_state_dict``.
            verbose: print a line for each loaded or absent checkpoint. With
                ``strict=False`` this also prints any missing or unexpected
                state-dict keys. Load failures are always printed.

        Returns:
            A dict with three keys:

            - ``"loaded"``: names loaded successfully.
            - ``"missing"``: names with no checkpoint file *or* whose load
              failed (kept this way for backward compatibility).
            - ``"failed"``: the subset of ``"missing"`` whose checkpoint file
              existed but could not be loaded.
        """
        loaded, missing, failed = [], [], []
        for name, module in self._modules_dict.items():
            path = os.path.join(dir_path, f"{name}{suffix}")
            # isfile rather than exists: a directory with that name is not a checkpoint.
            if not os.path.isfile(path):
                if verbose:
                    print(f"[ExtraModules] No checkpoint found for '{name}' at {path}")
                missing.append(name)
                continue
            try:
                state = load_file(path)
                result = module.load_state_dict(state, strict=strict)
            except Exception as e:
                # Best effort: one bad checkpoint must not abort the others, but
                # the failure is reported on stdout and in the return value.
                print(f"[ExtraModules] Failed to load '{name}': {e}")
                missing.append(name)
                failed.append(name)
                continue
            loaded.append(name)
            if verbose:
                print(f"[ExtraModules] Loaded pretrained weights for '{name}' from {path}")
                # With strict=False, key mismatches are returned instead of
                # raised; surface them rather than silently dropping them.
                absent_keys = list(getattr(result, "missing_keys", ()))
                extra_keys = list(getattr(result, "unexpected_keys", ()))
                if absent_keys or extra_keys:
                    print(
                        f"[ExtraModules] '{name}': missing keys {absent_keys}, "
                        f"unexpected keys {extra_keys}"
                    )
        return {"loaded": loaded, "missing": missing, "failed": failed}
class ExtraItems:
    """Attribute-style bag of plain (non-module) objects.

    Items registered through the constructor, ``add_item`` or ``add_items``
    are readable as attributes (``items.foo``). Reading a plain name that was
    never registered yields ``None`` instead of raising.
    """

    def __init__(self, **object_dict):
        super().__init__()
        # Backing store for the plain objects exposed as attributes.
        self._objects_dict = {}
        self.add_items(**object_dict)

    def add_item(self, name, obj):
        """Register (or overwrite) a single item under ``name``."""
        self._objects_dict[name] = obj

    def add_items(self, **items_dict):
        """Register several items at once from keyword arguments."""
        for name, obj in items_dict.items():
            self.add_item(name, obj)

    def __getattr__(self, name):
        """Return the item registered as ``name``; None for unknown plain names.

        This is only called when normal attribute lookup fails. Unknown
        dunder names raise AttributeError instead of returning None, because
        ``copy`` and ``pickle`` probe for hooks such as ``__setstate__`` and
        ``__deepcopy__``. If those probes received None, the library would
        try to call None and raise TypeError.
        """
        objects = self.__dict__.get('_objects_dict')
        if objects is not None and name in objects:
            return objects[name]
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return None