# dikdimon's picture
# Upload extensions using SD-Hub extension
# 3dabe4a verified
'''
This code is from the following repository: https://github.com/LeapLabTHU/Agent-Attention
@article{han2023agent,
title={Agent Attention: On the Integration of Softmax and Linear Attention},
author={Han, Dongchen and Ye, Tianzhu and Han, Yizeng and Xia, Zhuofan and Song, Shiji and Huang, Gao},
journal={arXiv preprint arXiv:2312.08874},
year={2023}
}
'''
import torch
def isinstance_str(x: object, cls_name: str):
    """
    Report whether any class in the type hierarchy of x is *named* cls_name.

    The comparison is done purely on ``__name__``, so the class object itself
    never has to be imported. Handy when patching third-party modules.
    """
    ancestry = type(x).__mro__ if x.__class__ is type(x) else x.__class__.__mro__
    return any(klass.__name__ == cls_name for klass in ancestry)
def init_generator(device: torch.device, fallback: torch.Generator = None) -> torch.Generator:
    """
    Forks the default random generator of the given device.

    A new ``torch.Generator`` is created and seeded with a copy of the
    device's current default RNG state, so draws from the returned generator
    do not advance the default generator.

    Args:
        device: Device whose default RNG state should be copied. Only the
            "cpu" and "cuda" device types are handled directly.
        fallback: Generator returned as-is when ``device`` is neither CPU nor
            CUDA. When omitted, a fork of the CPU generator is used instead.

    Returns:
        A generator on ``device`` for CPU/CUDA devices; otherwise ``fallback``
        or, if none was given, a CPU generator.
    """
    if device.type == "cpu":
        return torch.Generator(device="cpu").set_state(torch.get_rng_state())

    if device.type == "cuda":
        # Pass the device explicitly: without an argument, get_rng_state()
        # reads the *current* CUDA device, which may differ from `device`
        # (e.g. forking cuda:1 while cuda:0 is current).
        return torch.Generator(device=device).set_state(torch.cuda.get_rng_state(device))

    # Unsupported device type (e.g. mps, meta): prefer the caller's generator,
    # otherwise fall back to a fork of the CPU generator.
    if fallback is not None:
        return fallback
    return init_generator(torch.device("cpu"))