# LexaLCM_Pre0 / lcm/utils/card_utils.py (initial commit 3d79eb3)
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#
import dataclasses
import logging
from pathlib import Path
from typing import Dict, Optional, Union
import yaml
from fairseq2.assets import (
AssetNotFoundError,
InProcAssetMetadataProvider,
default_asset_store,
)
from fairseq2.assets.card import AssetCard
from fairseq2.checkpoint import FileCheckpointManager
from fairseq2.gang import FakeGang
from fairseq2.models import get_model_family
from fairseq2.typing import DataType, Device
from lcm.models.abstract_lcm import AbstractLCModel, AbstractLCModelConfig
from lcm.utils.model_type_registry import lcm_model_type_registry
# Module-level logger keyed on the module name, per the stdlib logging
# convention (the original used __file__, which keys it on the file path
# and breaks hierarchical logger configuration).
logger = logging.getLogger(__name__)
def create_model_card(
    checkpoint_path: Path,
    model_config: Union[Dict, AbstractLCModelConfig, None],
    model_type: str,  # TODO: take this parameter from the config
    model_name: str = "on_the_fly_lcm",
    model_arch: Optional[str] = None,
    **additional_card_kwargs,
) -> AssetCard:
    """
    Create an LCModel asset card on the fly for a given checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint to evaluate.
        model_config: Model parameters, either as a dict or as an
            ``AbstractLCModelConfig`` instance; ``None`` leaves the card
            without an explicit config (the family default is used).
        model_type: Model family name registered with fairseq2.
        model_name: Name under which the card is registered in the store.
        model_arch: Optional architecture name; omitted from the card when
            ``None`` so the default arch applies.
        **additional_card_kwargs: Extra key/value pairs stored on the card.

    Returns:
        The ``AssetCard`` retrieved back from the default asset store.
    """
    if isinstance(model_config, AbstractLCModelConfig):
        # Cards store plain dicts, so flatten the dataclass config.
        model_config = dataclasses.asdict(model_config)

    model_card_info = {
        "name": model_name,
        "model_family": model_type,
        "checkpoint": "file://" + checkpoint_path.as_posix(),
        **additional_card_kwargs,
    }
    if model_config is not None:
        model_card_info["model_config"] = model_config
    if model_arch is not None:
        model_card_info["model_arch"] = model_arch

    # NOTE: each call appends a new in-process metadata provider; the asset
    # store keeps it alive for the lifetime of the process.
    default_asset_store.metadata_providers.append(
        InProcAssetMetadataProvider([model_card_info])
    )
    return default_asset_store.retrieve_card(model_name)
def load_model_with_overrides(
    model_dir: Path,
    step: Optional[int] = None,
    model_type: Optional[str] = None,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
    model_filename: str = "model.pt",
):
    """
    Load a model from a training directory, with per-checkpoint overrides.

    First tries the per-checkpoint ``model_card.yaml`` next to the weights
    (newer checkpoints); if absent or invalid, falls back to the global
    training config ``config_logs/all_config.yaml`` (older checkpoints).

    Args:
        model_dir: Root training directory of the model.
        step: Optional training step; when given, the checkpoint is looked
            up under ``checkpoints/step_<step>/``.
        model_type: Model family; required for the legacy-config fallback.
        device: Optional device to load the model onto.
        dtype: Optional data type for the model parameters.
        model_filename: Name of the checkpoint file.

    Raises:
        ValueError: if neither a model card nor a global config is found.
    """
    if step is None:
        checkpoint_path = model_dir / model_filename
    else:
        checkpoint_path = model_dir / f"checkpoints/step_{step}" / model_filename

    # Newer checkpoints ship a model card right next to the weights.
    card_path = checkpoint_path.parent / "model_card.yaml"
    if card_path.exists():
        try:
            return load_model_from_card(
                card_path.as_posix(), device=device, dtype=dtype
            )
        except Exception as exc:
            logger.warning(
                f"Model card {card_path} exists but is not valid ({exc}). "
                "Try global config instead."
            )

    # Older checkpoints only carry the global training config.
    legacy_config_path = model_dir / "config_logs/all_config.yaml"
    if not legacy_config_path.exists():
        raise ValueError(f"{model_dir} is not a valid model directory")

    assert model_type, f"Need explicit model_type for checkpoint {checkpoint_path}"
    with open(legacy_config_path, "r") as f:
        full_config = yaml.full_load(f)

    temporary_card = create_model_card(
        checkpoint_path=checkpoint_path,
        model_config=full_config["trainer"]["model_config_or_name"],
        model_type=model_type,
        model_arch=f"toy_{model_type}",
    )
    loader_fn = lcm_model_type_registry.get_model_loader(model_type=model_type)
    return loader_fn(temporary_card, device=device, dtype=dtype)  # type: ignore
def create_model_card_from_training_folder(
    folder: Union[str, Path],
    card_name: str,
    step_nr: Optional[int] = None,
) -> AssetCard:
    """
    Extract the model config and checkpoint path using the checkpoint manager,
    then create and return a model card.

    Args:
        folder: Training directory; must contain a ``checkpoints/`` subfolder.
        card_name: Name under which the resulting card is registered.
        step_nr: Training step to use; defaults to the latest available step.

    Returns:
        The created ``AssetCard``.

    Raises:
        ValueError: if no checkpoint steps are found under ``checkpoints/``.
    """
    folder_path = Path(folder)
    assert folder_path.exists(), f"Model directory {folder} does not exist."
    cp_dir = folder_path / "checkpoints"
    # A fake (single-rank) gang suffices for read-only checkpoint access.
    gang = FakeGang()
    checkpoint_manager = FileCheckpointManager(cp_dir, gang)

    if step_nr is None:
        step_numbers = checkpoint_manager.get_step_numbers()
        if not step_numbers:
            raise ValueError(
                f"In {cp_dir}, no step number with model checkpoints was detected!"
            )
        step_nr = step_numbers[-1]
        logger.info(f"Automatically setting step number as {step_nr}")

    metadata = checkpoint_manager.load_metadata(step_nr)
    assert metadata is not None, "The checkpoint does not have metadata."
    training_config = metadata["config"]
    model_config = training_config.model_config_or_name

    # NOTE(review): relies on the manager's private _checkpoint_dir attribute.
    cp_fn = checkpoint_manager._checkpoint_dir / f"step_{step_nr}" / "model.pt"
    # Fix: the original `assert cp_fn` was always true (a Path object is
    # truthy); the intent is clearly to verify the checkpoint file exists.
    assert cp_fn.is_file(), (
        f"Checkpoint manager could not extract checkpoint path for step {step_nr}."
    )

    # TODO: deal with the fine-tuning case, where model_config is a string
    if isinstance(model_config, str):
        # Fine-tuning case: resolve config and family from the parent card.
        parent_card = default_asset_store.retrieve_card(model_config)
        model_config = parent_card._metadata["model_config_or_name"]
        model_type = parent_card._metadata["model_family"]
    else:
        model_type = model_config.model_type

    card = create_model_card(
        checkpoint_path=cp_fn.absolute(),
        model_config=model_config,
        model_type=model_type,
        model_arch=f"toy_{model_type}",  # TODO: get rid of the toy architecture when FS2 allows it
        model_name=card_name,
    )
    return card
def save_model_card(card: AssetCard, path: Union[str, Path]) -> None:
    """Serialize the metadata of *card* to a YAML file at *path*."""
    # TODO: use the exposed attribute when available
    metadata = card._metadata
    with open(path, "w", encoding="utf-8") as fout:
        yaml.dump(metadata, fout, default_flow_style=False)
def load_model_from_card(
    model_name: str,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> AbstractLCModel:
    """
    Load an LC model from the given asset card or path.

    The parameter `model_name` can be interpreted in multiple ways:
    - as the name of the model card
    - as the path to the yaml file of the model card
    - as the path to the training directory of the model
    - as the path to the model checkpoint (within a training directory, because we need to find the config)

    Args:
        model_name: Card name or path, per the interpretations above.
        device: Optional device to load the model onto.
        dtype: Optional data type for the model parameters.

    Returns:
        The loaded model instance.

    Raises:
        AssetNotFoundError: if `model_name` matches none of the supported
            interpretations.
    """
    try:
        card = default_asset_store.retrieve_card(model_name)
    except AssetNotFoundError as err:
        path = Path(model_name)
        # If the card is not found, try looking it up by interpreting model_name as a path to the yaml card.
        if path.exists() and path.suffix == ".yaml":
            with open(path, "r", encoding="utf-8") as f:
                card_data = yaml.full_load(f)
            model_name = card_data["name"]
            card = AssetCard(card_data)
        # If the card is not found, try interpreting model_name as the model training directory
        elif (path / "checkpoints").exists():
            card = create_model_card_from_training_folder(
                path, card_name="temporary_card"
            )
        # If the card is not found, try interpreting model_name as the path to the checkpoint within a training directory
        elif (
            path.suffix == ".pt"
            and path.parent.name.startswith("step_")
            and path.parent.parent.name == "checkpoints"
        ):
            # Layout assumed: <training_dir>/checkpoints/step_<N>/<file>.pt
            training_dir = path.parent.parent.parent
            step_nr = int(path.parent.name[5:])  # strip the "step_" prefix
            card = create_model_card_from_training_folder(
                training_dir, card_name="temporary_card", step_nr=step_nr
            )
        else:
            # None of the path-based interpretations applied; re-raise the
            # original asset-store lookup failure.
            raise err
    logger.info(f"Card loaded: {card}")
    # Dispatch to the family-specific loader registered for this model type.
    model_type = get_model_family(card)
    loader = lcm_model_type_registry.get_model_loader(model_type=model_type)
    model = loader(card, device=device, dtype=dtype)
    return model