fix: import `flash_attn_varlen_func` from `flash_attn` instead of `transformers.modeling_flash_attention_utils`
#4
by
wincentIsMe
- opened
When I load this model with the transformers library:
```python
from transformers import AutoTokenizer, AutoProcessor, AutoModelForCausalLM

model_path = "lmms-lab/LLaVA-OneVision-1.5-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype="auto", device_map="auto", trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
```
the following error occurs:

```
ImportError: cannot import name 'flash_attn_varlen_func' from 'transformers.modeling_flash_attention_utils'
```
This happens because recent versions of the transformers library no longer expose the flash_attn_varlen_func API from the transformers.modeling_flash_attention_utils module. The fix is to import flash_attn_varlen_func directly from the flash_attn package instead.
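A minimal sketch of the change, assuming the flash-attn package is installed: instead of hard-coding either import path, the remote modeling code can try the old transformers location first and fall back to flash_attn. The `import_from_first` helper below is hypothetical (not part of the repo's actual fix), shown only to illustrate the fallback pattern.

```python
import importlib


def import_from_first(attr, *module_names):
    """Return `attr` from the first module in `module_names` that provides it.

    Hypothetical helper: tries each candidate module in order and returns the
    requested attribute from the first one that has it.
    """
    for name in module_names:
        try:
            return getattr(importlib.import_module(name), attr)
        except (ImportError, AttributeError):
            continue
    raise ImportError(f"{attr!r} not found in any of {module_names}")


# In the remote modeling code, the broken import would then become
# (assuming flash-attn is installed so the fallback can succeed):
#
# flash_attn_varlen_func = import_from_first(
#     "flash_attn_varlen_func",
#     "transformers.modeling_flash_attention_utils",  # old location
#     "flash_attn",                                   # current home of the API
# )
```

The simplest one-line version of the fix is just replacing the failing import with `from flash_attn import flash_attn_varlen_func`, which is what the GitHub repo does.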
In fact, this bug has already been fixed in the LLaVA-OneVision-1.5 GitHub repo (fix_issue#31).
However, the fix has not yet been synchronized to the Hugging Face repository.