|
|
|
|
|
|
|
|
import inspect |
|
|
from types import FunctionType, MethodType |
|
|
from typing import List, Union |
|
|
|
|
|
from peft import PeftModel |
|
|
from torch.nn import Module |
|
|
|
|
|
from swift.utils import get_logger |
|
|
|
|
|
# Module-level logger, obtained from swift.utils.get_logger.
logger = get_logger() |
|
|
|
|
|
|
|
|
def can_return_loss(model: Module) -> bool:
    """Tell whether the model's ``forward`` opts in to returning a loss.

    For a ``PeftModel`` the signature of ``model.model.forward`` is inspected;
    for any other model, ``model.forward`` is inspected.

    Args:
        model: The model whose ``forward`` signature is examined.

    Returns:
        ``True`` only if ``forward`` declares a ``return_loss`` parameter whose
        default value is the literal ``True``; ``False`` otherwise.
    """
    forward = model.model.forward if isinstance(model, PeftModel) else model.forward
    return_loss_param = inspect.signature(forward).parameters.get('return_loss')
    # Identity check against True: truthy defaults such as 1 do not count.
    return return_loss_param is not None and return_loss_param.default is True
|
|
|
|
|
|
|
|
def find_labels(model: Module) -> List[str]:
    """Collect the label-like argument names accepted by the model's ``forward``.

    For a ``PeftModel`` the signature of ``model.model.forward`` is inspected;
    for any other model, ``model.forward`` is inspected.

    Args:
        model: The model whose ``forward`` signature is examined.

    Returns:
        The ``forward`` parameter names, in signature order, that contain the
        substring ``'label'``. When the class name of ``model`` contains
        ``'QuestionAnswering'``, ``'start_positions'`` and ``'end_positions'``
        are included as well.
    """
    class_name = model.__class__.__name__
    forward = model.model.forward if isinstance(model, PeftModel) else model.forward
    param_names = inspect.signature(forward).parameters

    # Question-answering models are supervised through span positions too.
    if 'QuestionAnswering' in class_name:
        extra_names = ('start_positions', 'end_positions')
    else:
        extra_names = ()
    return [name for name in param_names if 'label' in name or name in extra_names]
|
|
|
|
|
|
|
|
def get_function(method_or_function: Union[MethodType, FunctionType]) -> FunctionType:
    """Return the plain function behind a bound method.

    Args:
        method_or_function: Either a bound method or an ordinary function.

    Returns:
        The ``__func__`` of a bound method; any other object is returned
        unchanged.
    """
    if not isinstance(method_or_function, MethodType):
        return method_or_function
    return method_or_function.__func__
|
|
|
|
|
|
|
|
def is_instance_of_ms_model(model: Module) -> bool:
    """Check by name whether ``model`` derives from a modelscope ``Model`` class.

    The check walks the MRO and compares class and module names instead of
    importing modelscope, which avoids a circular dependency.

    Args:
        model: The object whose class hierarchy is examined.

    Returns:
        ``True`` if any class in the MRO is named ``'Model'`` and lives in a
        module whose name starts with ``'modelscope'``; ``False`` otherwise.
    """
    return any(
        klass.__name__ == 'Model' and klass.__module__.startswith('modelscope')
        for klass in model.__class__.__mro__
    )
|
|
|