File size: 1,827 Bytes
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from huggingface/transformers.
import inspect
from types import FunctionType, MethodType
from typing import List, Union
from peft import PeftModel
from torch.nn import Module
from swift.utils import get_logger
# Module-level logger created through swift.utils.get_logger.
logger = get_logger()
def can_return_loss(model: Module) -> bool:
    """Report whether the model's ``forward`` opts into returning a loss.

    Args:
        model: The model to inspect. For a ``PeftModel`` the signature of
            ``model.model.forward`` is inspected instead of the wrapper's
            own ``forward``.

    Returns:
        ``True`` only when ``forward`` declares a ``return_loss`` parameter
        whose default value is exactly ``True``; ``False`` otherwise.
    """
    forward = model.model.forward if isinstance(model, PeftModel) else model.forward
    return_loss_param = inspect.signature(forward).parameters.get('return_loss')
    # Identity comparison on purpose: a truthy default such as 1 does not count.
    return return_loss_param is not None and return_loss_param.default is True
def find_labels(model: Module) -> List[str]:
    """Collect the label-like argument names accepted by the model's ``forward``.

    Args:
        model: The model to inspect. For a ``PeftModel`` the signature of
            ``model.model.forward`` is inspected instead of the wrapper's
            own ``forward``.

    Returns:
        The ``forward`` parameter names containing ``'label'``, in signature
        order. When the class name of ``model`` contains
        ``'QuestionAnswering'``, ``start_positions`` and ``end_positions``
        are included as well.
    """
    # The class name is taken from the object passed in, not from the model
    # it may wrap.
    is_question_answering = 'QuestionAnswering' in model.__class__.__name__
    inspected = model.model if isinstance(model, PeftModel) else model
    param_names = inspect.signature(inspected.forward).parameters
    qa_label_names = ('start_positions', 'end_positions')
    return [
        name for name in param_names
        if 'label' in name or (is_question_answering and name in qa_label_names)
    ]
def get_function(method_or_function: Union[MethodType, FunctionType]) -> FunctionType:
    """Return the plain function behind a bound method.

    Args:
        method_or_function: A bound method or a function.

    Returns:
        ``method_or_function.__func__`` when the argument is a
        ``MethodType`` instance; otherwise the argument itself, unchanged.
    """
    if not isinstance(method_or_function, MethodType):
        return method_or_function
    return method_or_function.__func__
def is_instance_of_ms_model(model: Module) -> bool:
    """Check whether ``model`` derives from a modelscope ``Model`` class.

    The check compares class and module names along the MRO instead of
    importing modelscope, which avoids a circular dependency problem.

    Args:
        model: The object whose class hierarchy is examined.

    Returns:
        ``True`` if any class in the MRO of ``model.__class__`` is named
        ``'Model'`` and lives in a module whose name starts with
        ``'modelscope'``; ``False`` otherwise.
    """
    return any(
        klass.__name__ == 'Model' and klass.__module__.startswith('modelscope')
        for klass in model.__class__.__mro__
    )
|