Flux Attention: Context-Aware Hybrid Attention for Efficient LLMs Inference

Paper | Project Page | Code

Flux Attention is a context-aware framework that dynamically optimizes attention computation at the layer level. By integrating a lightweight Layer Router into frozen pretrained LLMs, the proposed method adaptively routes each layer to Full Attention (FA) or Sparse Attention (SA) based on the input context. This layer-wise routing preserves high-fidelity information retrieval while ensuring contiguous memory access, translating theoretical computational reductions into practical wall-clock speedups.
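
To make the layer-level routing idea concrete, here is a minimal sketch in plain Python. It is not the authors' Layer Router: the class `ToyLayerRouter`, the mean-pooling step, the sigmoid probe, and the 0.5 threshold are all illustrative assumptions. The only point it shows is that one decision is made per layer from a summary of the input context, so every head in that layer runs the same attention kernel.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class ToyLayerRouter:
    """Hypothetical per-layer router: a single linear probe over pooled context.

    This is an illustrative stand-in, not the router from the paper.
    """
    weights: List[float]
    bias: float = 0.0
    threshold: float = 0.5  # assumed cut-off between full and sparse attention

    def route(self, pooled_context: List[float]) -> str:
        # Score the pooled context, squash it to (0, 1), and pick a mode.
        logit = sum(w * x for w, x in zip(self.weights, pooled_context)) + self.bias
        prob_full = 1.0 / (1.0 + math.exp(-logit))
        return "FA" if prob_full >= self.threshold else "SA"


def mean_pool(hidden_states: List[List[float]]) -> List[float]:
    """Average token vectors into one context summary vector."""
    n_tokens = len(hidden_states)
    dim = len(hidden_states[0])
    return [sum(tok[d] for tok in hidden_states) / n_tokens for d in range(dim)]


def route_layers(routers: List[ToyLayerRouter],
                 hidden_states: List[List[float]]) -> List[str]:
    """Return one decision ("FA" or "SA") per layer for the given context."""
    pooled = mean_pool(hidden_states)
    return [router.route(pooled) for router in routers]


if __name__ == "__main__":
    # Two tokens with two features each; three layers with made-up weights.
    context = [[1.0, 0.0], [3.0, 2.0]]
    routers = [
        ToyLayerRouter([1.0, 0.0]),
        ToyLayerRouter([-1.0, 0.0]),
        ToyLayerRouter([0.0, 0.5], bias=-1.0),
    ]
    print(route_layers(routers, context))  # ['FA', 'SA', 'SA']
```

Routing whole layers instead of individual heads or tokens means each layer does either a dense pass or a sparse pass over contiguous blocks, which is the property the description above credits for turning reduced computation into wall-clock speedups.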

⚡ Quick Start (Inference)

Here is a minimal example of how to use Flux Attention for text generation. This requires the fluxattn package and its dependencies (such as Block-Sparse-Attention) to be installed as described in the official repository.

import torch
import json
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_sparse_model(model_path):
    """
    Dynamically loads the correct sparse architecture based on config.
    """
    config_path = f"{model_path}/config.json"
    with open(config_path, "r") as f:
        config_data = json.load(f)

    arch = config_data.get("architectures", [])
    if not arch:
        raise ValueError("No architecture found in config.json")

    arch_name = arch[0]
    print(f"🚀 Detected architecture: {arch_name}")

    # Register custom architectures
    if "PawLlama" in arch_name:
        from fluxattn.training.eval.modeling_flash_llama import (
            PawLlamaForCausalLM, PawLlamaConfig
        )
        AutoModelForCausalLM.register(PawLlamaConfig, PawLlamaForCausalLM)
        model_cls = PawLlamaForCausalLM
        
    elif "PawQwen" in arch_name:
        from fluxattn.training.eval.modeling_flash_qwen import (
            PawQwen3ForCausalLM, PawQwen3Config
        )
        AutoModelForCausalLM.register(PawQwen3Config, PawQwen3ForCausalLM)
        model_cls = PawQwen3ForCausalLM
    else:
        raise ValueError(f"Unsupported architecture: {arch_name}")

    # Load model
    model = model_cls.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    return model

# --- Execution ---
model_path = "****" # <--- Replace with your checkpoint path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

print("Loading Flux Attention Model...")
model = load_sparse_model(model_path)
model.eval()

# Generate
input_text = "Explain quantum mechanics in one sentence."
# Place inputs on the model's device (device_map="auto" may not pick "cuda:0")
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

print("Generating...")
outputs = model.generate(**inputs, max_new_tokens=100)
print("\nOutput:\n" + tokenizer.decode(outputs[0], skip_special_tokens=True))
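
Because the example above fails at import time if its dependencies are missing, a quick pre-flight check can save a long model download. The sketch below uses only the standard library. The helper name `missing_packages` is hypothetical, and the list of names covers only the modules the example imports directly; the import name of Block-Sparse-Attention is not stated here, so add it according to the official repository.

```python
import importlib.util
from typing import Iterable, List


def missing_packages(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found in this environment."""
    # find_spec returns None when a top-level module is not installed.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Modules imported directly by the Quick Start example.
    required = ["torch", "transformers", "fluxattn"]
    missing = missing_packages(required)
    if missing:
        print("Missing packages: " + ", ".join(missing))
    else:
        print("All directly imported packages were found.")
```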

Citation

If you find this project useful in your research, please consider citing:

@misc{qiu2026fluxattentioncontextawarehybrid,
      title={Flux Attention: Context-Aware Hybrid Attention for Efficient LLMs Inference}, 
      author={Quantong Qiu and Zhiyi Hong and Yi Yang and Haitian Wang and Kebin Liu and Qingqing Dang and Juntao Li and Min Zhang},
      year={2026},
      eprint={2604.07394},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2604.07394}, 
}