# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#
import dataclasses
import logging
from pathlib import Path
from typing import Dict, Optional, Union
import yaml
from fairseq2.assets import (
AssetNotFoundError,
InProcAssetMetadataProvider,
default_asset_store,
)
from fairseq2.assets.card import AssetCard
from fairseq2.checkpoint import FileCheckpointManager
from fairseq2.gang import FakeGang
from fairseq2.models import get_model_family
from fairseq2.typing import DataType, Device
from lcm.models.abstract_lcm import AbstractLCModel, AbstractLCModelConfig
from lcm.utils.model_type_registry import lcm_model_type_registry
logger = logging.getLogger(__name__)
def create_model_card(
checkpoint_path: Path,
model_config: Union[Dict, AbstractLCModelConfig, None],
model_type: str, # TODO: take this parameter from the config
    model_name: str = "on_the_fly_lcm",
model_arch: Optional[str] = None,
**additional_card_kwargs,
) -> AssetCard:
"""
Create an LCModel card given the checkpoint path and model args
Args:
- `checkpoint_path`: Path to the checkpoint to evaluate
- `model_config`: model parmeters
the default arch
"""
# Create a fairseq2 model card on the fly.
# assert (
# checkpoint_path.is_file()
# ), f"Couldn't find the checkpoint at {checkpoint_path}"
if isinstance(model_config, AbstractLCModelConfig):
model_config = dataclasses.asdict(model_config)
model_card_info = {
"name": model_name,
"model_family": model_type,
"checkpoint": "file://" + checkpoint_path.as_posix(),
**additional_card_kwargs,
}
if model_config is not None:
model_card_info["model_config"] = model_config
if model_arch is not None:
model_card_info["model_arch"] = model_arch
default_asset_store.metadata_providers.append(
InProcAssetMetadataProvider([model_card_info])
)
return default_asset_store.retrieve_card(model_name)
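
# A minimal usage sketch for `create_model_card` (illustrative only: the
# checkpoint path, config values, and "base_lcm" model type below are
# hypothetical placeholders, not values shipped with this repo):
#
#     card = create_model_card(
#         checkpoint_path=Path("/tmp/my_run/checkpoints/step_1000/model.pt"),
#         model_config={"model_dim": 1024},  # or an AbstractLCModelConfig instance
#         model_type="base_lcm",
#     )
#
# The card is registered in `default_asset_store`, so it can later be retrieved
# again by name (here the default, "on_the_fly_lcm").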
def load_model_with_overrides(
model_dir: Path,
step: Optional[int] = None,
model_type: Optional[str] = None,
device: Optional[Device] = None,
dtype: Optional[DataType] = None,
model_filename: str = "model.pt",
):
if step is not None:
checkpoint_path = model_dir / f"checkpoints/step_{step}" / model_filename
else:
checkpoint_path = model_dir / model_filename
# New checkpoint
config_path = checkpoint_path.parent / "model_card.yaml"
if config_path.exists():
try:
return load_model_from_card(
config_path.as_posix(), device=device, dtype=dtype
)
except Exception as exc:
            logger.warning(
                f"Model card {config_path} exists but is not valid ({exc}). "
                "Trying the global config instead."
            )
# Old checkpoint
config_path = model_dir / "config_logs/all_config.yaml"
if config_path.exists():
assert model_type, f"Need explicit model_type for checkpoint {checkpoint_path}"
with open(config_path, "r") as f:
config = yaml.full_load(f)
model_config = config["trainer"]["model_config_or_name"]
temporary_card = create_model_card(
checkpoint_path=checkpoint_path,
model_config=model_config,
model_type=model_type,
model_arch=f"toy_{model_type}",
)
loader_fn = lcm_model_type_registry.get_model_loader(model_type=model_type)
return loader_fn(temporary_card, device=device, dtype=dtype) # type: ignore
else:
raise ValueError(f"{model_dir} is not a valid model directory")
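
# Illustrative call (paths and model type are hypothetical): load step 1000 of
# a training run, falling back to the legacy `config_logs/all_config.yaml`
# layout if no per-step `model_card.yaml` is found next to the checkpoint:
#
#     model = load_model_with_overrides(
#         Path("/tmp/my_run"),
#         step=1000,
#         model_type="base_lcm",  # only required for old-style checkpoints
#     )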
def create_model_card_from_training_folder(
folder: Union[str, Path],
card_name: str,
step_nr: Optional[int] = None,
) -> AssetCard:
"""
Extract the model config and the last checkpoint path using the checkpoint manager.
Create and return a model card
"""
folder_path = Path(folder)
assert folder_path.exists(), f"Model directory {folder} does not exist."
cp_dir = folder_path / "checkpoints"
gang = FakeGang()
checkpoint_manager = FileCheckpointManager(cp_dir, gang)
if step_nr is None:
step_numbers = checkpoint_manager.get_step_numbers()
if not step_numbers:
raise ValueError(
f"In {cp_dir}, no step number with model checkpoints was detected!"
)
step_nr = step_numbers[-1]
logger.info(f"Automatically setting step number as {step_nr}")
metadata = checkpoint_manager.load_metadata(step_nr)
assert metadata is not None, "The checkpoint does not have metadata."
training_config = metadata["config"]
model_config = training_config.model_config_or_name
cp_fn = checkpoint_manager._checkpoint_dir / f"step_{step_nr}" / "model.pt"
    assert cp_fn.exists(), (
        f"Could not find a model checkpoint for step {step_nr} at {cp_fn}."
    )
# TODO: deal with the fine-tuning case, where model_config is a string
if isinstance(model_config, str):
parent_card = default_asset_store.retrieve_card(model_config)
model_config = parent_card._metadata["model_config_or_name"]
model_type = parent_card._metadata["model_family"]
else:
model_type = model_config.model_type
card = create_model_card(
checkpoint_path=cp_fn.absolute(),
model_config=model_config,
model_type=model_type,
model_arch=f"toy_{model_type}", # TODO: get rid of the toy architecture when FS2 allows it
model_name=card_name,
)
return card
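
# Sketch (hypothetical folder): build a card from a training directory, letting
# the checkpoint manager pick the latest available step:
#
#     card = create_model_card_from_training_folder("/tmp/my_run", card_name="my_lcm")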
def save_model_card(card: AssetCard, path: Union[str, Path]) -> None:
"""Save a model card as YAML."""
card_data = card._metadata # TODO: use the exposed attribute when available
with open(path, "w", encoding="utf-8") as outfile:
yaml.dump(card_data, outfile, default_flow_style=False)
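
# Example (hypothetical path): persist a card next to its checkpoint so that
# `load_model_with_overrides` / `load_model_from_card` can pick it up later:
#
#     save_model_card(card, "/tmp/my_run/checkpoints/step_1000/model_card.yaml")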
def load_model_from_card(
model_name: str,
device: Optional[Device] = None,
dtype: Optional[DataType] = None,
) -> AbstractLCModel:
"""
    Load LC model from the given asset card or path.
The parameter `model_name` can be interpreted in multiple ways:
- as the name of the model card
- as the path to the yaml file of the model card
- as the path to the training directory of the model
- as the path to the model checkpoint (within a training directory, because we need to find the config)
"""
try:
card = default_asset_store.retrieve_card(model_name)
except AssetNotFoundError as err:
path = Path(model_name)
# If the card is not found, try looking it up by interpreting model_name as a path to the yaml card.
if path.exists() and path.suffix == ".yaml":
with open(path, "r", encoding="utf-8") as f:
card_data = yaml.full_load(f)
model_name = card_data["name"]
card = AssetCard(card_data)
# If the card is not found, try interpreting model_name as the model training directory
elif (path / "checkpoints").exists():
card = create_model_card_from_training_folder(
path, card_name="temporary_card"
)
# If the card is not found, try interpreting model_name as the path to the checkpoint within a training directory
elif (
path.suffix == ".pt"
and path.parent.name.startswith("step_")
and path.parent.parent.name == "checkpoints"
):
training_dir = path.parent.parent.parent
step_nr = int(path.parent.name[5:])
card = create_model_card_from_training_folder(
training_dir, card_name="temporary_card", step_nr=step_nr
)
else:
raise err
logger.info(f"Card loaded: {card}")
model_type = get_model_family(card)
loader = lcm_model_type_registry.get_model_loader(model_type=model_type)
model = loader(card, device=device, dtype=dtype)
return model
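
# All four interpretations of `model_name` end up in the same loader; for
# example (all paths hypothetical):
#
#     model = load_model_from_card("my_registered_card")  # asset card name
#     model = load_model_from_card("/tmp/my_run/model_card.yaml")  # card YAML file
#     model = load_model_from_card("/tmp/my_run")  # training directory
#     model = load_model_from_card("/tmp/my_run/checkpoints/step_1000/model.pt")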