# Flux Attention: Context-Aware Hybrid Attention for Efficient LLMs Inference
Flux Attention is a context-aware framework that dynamically optimizes attention computation at the layer level. By integrating a lightweight Layer Router into frozen pretrained LLMs, the method adaptively routes each layer to Full Attention (FA) or Sparse Attention (SA) based on the input context. This layer-wise routing preserves high-fidelity information retrieval while ensuring contiguous memory access, translating theoretical computational reductions into practical wall-clock speedups.
- Paper: Flux Attention: Context-Aware Hybrid Attention for Efficient LLMs Inference
- GitHub: https://github.com/qqtang-code/FluxAttention
- Project Page: Flux Attention Project Page
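To make the layer-wise routing concrete, the sketch below shows one way such a router could gate a decoder layer between full and sparse attention. This is purely illustrative and not the released `fluxattn` implementation: `LayerRouter`, the mean-pooling, the threshold, and the `full_attention`/`sparse_attention` methods are all assumptions for exposition.

```python
import torch
import torch.nn as nn


class LayerRouter(nn.Module):
    """Illustrative per-layer gate (assumed design, not the fluxattn code)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.score = nn.Linear(hidden_size, 1)  # lightweight scoring head

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Summarize the context into one vector per sequence, then score it.
        pooled = hidden_states.mean(dim=1)        # (batch, hidden)
        return torch.sigmoid(self.score(pooled))  # (batch, 1) routing probability


def route_layer(layer, hidden_states, router, threshold=0.5):
    """Dispatch one decoder layer to FA or SA based on the router's decision."""
    if router(hidden_states).mean() > threshold:
        return layer.full_attention(hidden_states)  # high-fidelity retrieval path
    return layer.sparse_attention(hidden_states)    # contiguous block-sparse path
```

Because the choice is made once per layer rather than per token, the sparse path can run contiguous block-sparse kernels, which is what converts the reduced FLOPs into actual wall-clock gains.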
## Installation
To use this model, install the `fluxattn` package and the Block-Sparse-Attention kernels as described in the official repository. A typical setup might look like the sketch below.
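These commands are an assumption based on the repository names; follow the official README if they differ.

```bash
# Assumed setup; verify against the official repository's instructions.
git clone https://github.com/qqtang-code/FluxAttention
cd FluxAttention
pip install -e .  # assumed editable install of the fluxattn package

# Block-Sparse-Attention provides the sparse kernels and is typically built from source.
git clone https://github.com/mit-han-lab/Block-Sparse-Attention
cd Block-Sparse-Attention
pip install -e .
```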
## ⚡ Quick Start (Inference)
Here is a minimal example of how to use Flux Attention for text generation. This requires registering the custom architecture before loading the model.
```python
import torch
import json
from transformers import AutoTokenizer, AutoModelForCausalLM


def load_sparse_model(model_path):
    """
    Dynamically loads the correct sparse architecture based on config.
    """
    config_path = f"{model_path}/config.json"
    with open(config_path, "r") as f:
        config_data = json.load(f)

    arch = config_data.get("architectures", [])
    if not arch:
        raise ValueError("No architecture found in config.json")
    arch_name = arch[0]
    print(f"🚀 Detected architecture: {arch_name}")

    # Register custom architectures with transformers' auto classes
    if "PawLlama" in arch_name:
        from fluxattn.training.eval.modeling_flash_llama import (
            PawLlamaForCausalLM, PawLlamaConfig
        )
        AutoModelForCausalLM.register(PawLlamaConfig, PawLlamaForCausalLM)
        model_cls = PawLlamaForCausalLM
    elif "PawQwen" in arch_name:
        from fluxattn.training.eval.modeling_flash_qwen import (
            PawQwen3ForCausalLM, PawQwen3Config
        )
        AutoModelForCausalLM.register(PawQwen3Config, PawQwen3ForCausalLM)
        model_cls = PawQwen3ForCausalLM
    else:
        raise ValueError(f"Unsupported architecture: {arch_name}")

    # Load model with the registered class
    model = model_cls.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    return model


# --- Execution ---
model_path = "QQTang1223/Flux-Attention-Qwen3-4B"  # Replace with your checkpoint path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

print("Loading Flux Attention Model...")
model = load_sparse_model(model_path)
model.eval()

# Generate
input_text = "Explain quantum mechanics in one sentence."
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

print("Generating...")
outputs = model.generate(**inputs, max_new_tokens=100)
print("\nOutput:\n" + tokenizer.decode(outputs[0], skip_special_tokens=True))
```
## Citation
If you find this project useful in your research, please consider citing:
```bibtex
@misc{qiu2026fluxattentioncontextawarehybrid,
  title={Flux Attention: Context-Aware Hybrid Attention for Efficient LLMs Inference},
  author={Quantong Qiu and Zhiyi Hong and Yi Yang and Haitian Wang and Kebin Liu and Qingqing Dang and Juntao Li and Min Zhang},
  year={2026},
  eprint={2604.07394},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2604.07394},
}
```