Flux Attention: Context-Aware Hybrid Attention for Efficient LLMs Inference

Paper | GitHub | Project Page

Flux Attention is a context-aware framework that dynamically optimizes attention computation at the layer level. By integrating a lightweight Layer Router into frozen pretrained LLMs, it adaptively routes each layer to Full Attention (FA) or Sparse Attention (SA) based on the input context. This layer-wise routing preserves high-fidelity information retrieval while ensuring contiguous memory access, translating theoretical computational reductions into practical wall-clock speedups.
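The routing decision itself is simple to picture: a per-layer score from the router is thresholded to choose FA or SA. The sketch below is purely illustrative — the scoring function, threshold, and labels are hypothetical stand-ins for Flux Attention's learned Layer Router, not its actual implementation.

```python
# Illustrative layer-wise routing between Full Attention (FA) and Sparse
# Attention (SA). The scores and threshold here are hypothetical; in Flux
# Attention the routing signal comes from a learned, lightweight router.

def route_layers(context_scores, threshold=0.5):
    """Map each layer's context score to an attention mode.

    context_scores: one float in [0, 1] per transformer layer (stand-in
    for the learned router's output).
    """
    return ["FA" if s >= threshold else "SA" for s in context_scores]

# Toy example: layers whose context needs high-fidelity retrieval keep
# full attention, the rest fall back to sparse attention.
scores = [0.9, 0.2, 0.7, 0.1]
print(route_layers(scores))  # ['FA', 'SA', 'FA', 'SA']
```

Because the choice is made once per layer (not per token or per head), each layer runs a single contiguous kernel, which is what lets the FLOP savings show up as wall-clock speedups.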

This repository contains the checkpoint for Llama-3.1-8B-Instruct optimized with Flux Attention.

Highlights

  • High Training Efficiency: Requires only 12 hours of training on 8x A800 GPUs for 8B-scale models.
  • Long-Sequence Performance: Preserves high-fidelity information retrieval, matching backbone models and significantly surpassing baseline methods.
  • Inference Acceleration: Achieves substantial wall-clock speedups (up to 2.8x in the prefill and 2.0x in the decode stages) on long-context tasks.

⚡ Quick Start (Inference)

To use Flux Attention, you need to clone the official repository and install the requirements (including custom CUDA kernels for Block-Sparse-Attention).
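For intuition on what those kernels compute: block-sparse attention restricts each query block to a subset of key blocks selected by a block mask. The NumPy sketch below illustrates the idea only — it is not the repository's CUDA kernel, and the mask shown is an arbitrary example.

```python
import numpy as np

# Conceptual block-sparse attention: each query block attends only to the
# key blocks enabled in `block_mask`. Illustrative NumPy code, not the
# optimized Block-Sparse-Attention CUDA kernel.

def block_sparse_attention(q, k, v, block_mask, block_size):
    n, d = q.shape
    nb = n // block_size
    out = np.zeros_like(v)
    for i in range(nb):
        qi = q[i * block_size:(i + 1) * block_size]        # query block i
        cols = [j for j in range(nb) if block_mask[i, j]]   # allowed key blocks
        idx = np.concatenate(
            [np.arange(j * block_size, (j + 1) * block_size) for j in cols]
        )
        scores = qi @ k[idx].T / np.sqrt(d)
        w = np.exp(scores - scores.max(axis=-1, keepdims=True))
        w /= w.sum(axis=-1, keepdims=True)                  # softmax over kept keys
        out[i * block_size:(i + 1) * block_size] = w @ v[idx]
    return out

# With a dense (all-True) mask this reduces to ordinary full attention.
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 4)) for _ in range(3))
mask = np.ones((4, 4), dtype=bool)
out = block_sparse_attention(q, k, v, mask, block_size=2)
print(out.shape)  # (8, 4)
```

Skipping disabled key blocks entirely is what turns sparsity into contiguous memory access rather than scattered gathers.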

import torch
import json
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_sparse_model(model_path):
    """
    Dynamically loads the correct sparse architecture based on config.
    """
    config_path = f"{model_path}/config.json"
    with open(config_path, "r") as f:
        config_data = json.load(f)

    arch = config_data.get("architectures", [])
    if not arch:
        raise ValueError("No architecture found in config.json")

    arch_name = arch[0]
    print(f"🚀 Detected architecture: {arch_name}")

    # Register custom architectures
    if "PawLlama" in arch_name:
        from fluxattn.training.eval.modeling_flash_llama import (
            PawLlamaForCausalLM, PawLlamaConfig
        )
        AutoModelForCausalLM.register(PawLlamaConfig, PawLlamaForCausalLM)
        model_cls = PawLlamaForCausalLM
        
    elif "PawQwen" in arch_name:
        from fluxattn.training.eval.modeling_flash_qwen import (
            PawQwen3ForCausalLM, PawQwen3Config
        )
        AutoModelForCausalLM.register(PawQwen3Config, PawQwen3ForCausalLM)
        model_cls = PawQwen3ForCausalLM
    else:
        raise ValueError(f"Unsupported architecture: {arch_name}")

    # Load model
    model = model_cls.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    return model

# --- Execution ---
model_path = "QQTang1223/Flux-Attention-Llama-3.1-8B-Instruct" 
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

print("Loading Flux Attention Model...")
model = load_sparse_model(model_path)
model.eval()

# Generate
input_text = "Explain quantum mechanics in one sentence."
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

print("Generating...")
outputs = model.generate(**inputs, max_new_tokens=100)
print("\nOutput:\n" + tokenizer.decode(outputs[0], skip_special_tokens=True))

Citation

@misc{qiu2026fluxattentioncontextawarehybrid,
      title={Flux Attention: Context-Aware Hybrid Attention for Efficient LLMs Inference}, 
      author={Quantong Qiu and Zhiyi Hong and Yi Yang and Haitian Wang and Kebin Liu and Qingqing Dang and Juntao Li and Min Zhang},
      year={2026},
      eprint={2604.07394},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2604.07394}, 
}