zhouyik's picture
Upload folder using huggingface_hub
032e687 verified
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
"""
This file contains components with some default boilerplate logic user may need
in training / testing. They will not work for everyone, but many users may find them useful.
The behavior of functions/classes in this file is subject to change,
since they are meant to represent the "common default behavior" people need in their projects.
"""
import copy
import os
import sys
import functools
import torch
try:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy, ModuleWrapPolicy
except ImportError as e:
print(e, "just skip this")
from ape.checkpoint import DetectionCheckpointer
from detectron2.config import instantiate
from detectron2.utils import comm
from transformers.trainer_pt_utils import get_module_class_from_name
__all__ = [
"create_fsdp_model",
"DefaultPredictor",
]
def create_fsdp_model(model, *, fp16_compression=False, **kwargs):
"""
Create a DistributedDataParallel model if there are >1 processes.
Args:
model: a torch.nn.Module
fp16_compression: add fp16 compression hooks to the ddp object.
See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
""" # noqa
sharding_strategy_dict = {
"NO_SHARD": ShardingStrategy.NO_SHARD,
"SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
"FULL_SHARD": ShardingStrategy.FULL_SHARD,
}
dtype_dict = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
auto_wrap_policy = None
module_name_to_wrap = kwargs.pop("module_name_to_wrap", None)
if module_name_to_wrap is not None:
module_cls_to_wrap = set()
for module_name in module_name_to_wrap:
module_cls = get_module_class_from_name(model, module_name)
if module_cls is None:
raise Exception("Could not find the layer class to wrap in the model.")
else:
module_cls_to_wrap.add(module_cls)
# print("module_cls_to_wrap", module_cls_to_wrap)
# auto_wrap_policy = functools.partial(
# transformer_auto_wrap_policy,
# # Transformer layer class to wrap
# transformer_layer_cls=module_cls_to_wrap,
# )
auto_wrap_policy = ModuleWrapPolicy(module_cls_to_wrap)
else:
# auto_wrap_policy = functools.partial(
# size_based_auto_wrap_policy, min_num_params=int(1e5)
# )
auto_wrap_policy = size_based_auto_wrap_policy
if comm.get_world_size() == 1:
return model
if "device_id" not in kwargs:
kwargs["device_id"] = comm.get_local_rank()
param_dtype = kwargs.pop("param_dtype", None)
reduce_dtype = kwargs.pop("reduce_dtype", None)
buffer_dtype = kwargs.pop("buffer_dtype", None)
if param_dtype is not None:
param_dtype = getattr(torch, param_dtype)
if reduce_dtype is not None:
reduce_dtype = getattr(torch, reduce_dtype)
if buffer_dtype is not None:
buffer_dtype = getattr(torch, buffer_dtype)
# from ape.layers import MultiScaleDeformableAttention
mp_policy = MixedPrecision(
param_dtype=param_dtype,
# Gradient communication precision.
reduce_dtype=reduce_dtype,
# Buffer precision.
buffer_dtype=buffer_dtype,
cast_forward_inputs=True,
# _module_classes_to_ignore=(MultiScaleDeformableAttention,),
)
model = model.to(param_dtype)
fsdp = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mp_policy,
**kwargs,
)
return fsdp
model.model_vision.model_language = FSDP(
model.model_vision.model_language,
# auto_wrap_policy=auto_wrap_policy,
sharding_strategy=ShardingStrategy.NO_SHARD,
mixed_precision=mp_policy,
**kwargs,
)
model.model_vision.backbone = FSDP(
model.model_vision.backbone,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mp_policy,
**kwargs,
)
model.model_vision.transfomer = FSDP(
model.model_vision.transformer,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mp_policy,
**kwargs,
)
# auto_wrap_policy = functools.partial(
# size_based_auto_wrap_policy, min_num_params=int(1e5)
# )
fsdp = FSDP(
model,
# auto_wrap_policy=size_based_auto_wrap_policy,
sharding_strategy=ShardingStrategy.NO_SHARD,
mixed_precision=mp_policy,
**kwargs,
)
# if fp16_compression:
# from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
# ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
return fsdp
class DefaultPredictor:
"""
Create a simple end-to-end predictor with the given config that runs on
single device for a single input image.
Compared to using the model directly, this class does the following additions:
1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
4. Take one input image and produce a single output, instead of a batch.
This is meant for simple demo purposes, so it does the above steps automatically.
This is not meant for benchmarks or running complicated inference logic.
If you'd like to do anything more complicated, please refer to its source code as
examples to build and use the model manually.
Attributes:
metadata (Metadata): the metadata of the underlying dataset, obtained from
cfg.DATASETS.TEST.
Examples:
::
pred = DefaultPredictor(cfg)
inputs = cv2.imread("input.jpg")
outputs = pred(inputs)
"""
def __init__(self, cfg):
self.cfg = copy.deepcopy(cfg) # cfg can be modified by model
self.model = instantiate(cfg.model)
self.model.to(cfg.train.device)
self.model.eval()
checkpointer = DetectionCheckpointer(self.model)
checkpointer.load(cfg.train.init_checkpoint)
self.aug = instantiate(cfg.dataloader.test.mapper.augmentations[0])
if "model_vision" in cfg.model:
self.input_format = cfg.model.model_vision.input_format
else:
self.input_format = cfg.model.input_format
assert self.input_format in ["RGB", "BGR"], self.input_format
def __call__(self, original_image, text_prompt=None, mask_prompt=None):
"""
Args:
original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
Returns:
predictions (dict):
the output of the model for one image only.
See :doc:`/tutorials/models` for details about the format.
"""
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
# Apply pre-processing to image.
if self.input_format == "RGB":
# whether the model expects BGR inputs or RGB
original_image = original_image[:, :, ::-1]
height, width = original_image.shape[:2]
image = self.aug.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
inputs = {"image": image, "height": height, "width": width}
if text_prompt is not None:
inputs["prompt"] = "text"
inputs["text_prompt"] = text_prompt
if mask_prompt is not None:
mask_prompt = self.aug.get_transform(mask_prompt).apply_image(mask_prompt)
inputs["mask_prompt"] = torch.as_tensor(mask_prompt.astype("float32"))
predictions = self.model([inputs])[0]
return predictions