Llama-3.1-8B-Block-FT

meta-llama/Llama-3.1-8B fine-tuned with Block-Attention for efficient RAG inference.

Overview

This model is fine-tuned with the Block-Attention mechanism introduced in "Block-Attention for Efficient Prefilling". Block-Attention divides the input context into independent blocks during the prefill phase, so the KV cache computed for a document can be reused across different queries over the same documents, a key optimization for RAG serving.

Training Data Control Variable: This model was fine-tuned on an 8K-sample subset of the Tulu3-Block-FT-RAG dataset, matching the data volume used for the companion Qwen3-8B model. The companion Llama-3.2-1B model uses the full 80K samples.

Evaluation Results

Comparison with Other Block-FT Models (on Unseen TriviaQA, 100 clean samples)

Questions and evidence passages come from the TriviaQA RC validation split, which is excluded from the training data. Substr-EM checks whether the correct answer appears as a substring of the model's response.

| Model | Params | Train Data | Substr-EM | F1 Score |
|---|---|---|---|---|
| meta-llama/Llama-3.2-1B (base) | 1B | - | 56.00% | 12.51% |
| meta-llama/Llama-3.2-1B-Instruct | 1B | - | 86.00% | 23.62% |
| hxia7/Llama-3.2-1B-block-FT (full) | 1B | 80K | 87.00% | 26.59% |
| hxia7/Llama-3.2-1B-block-FT (block) | 1B | 80K | 88.00% | 27.53% |
| hxia7/Qwen3-8B-block-FT (full) | 8B | 8K | 91.00% | 25.18% |
| hxia7/Qwen3-8B-block-FT (block) | 8B | 8K | 90.00% | 23.71% |
| hxia7/Llama-3.1-8B-block-FT | 8B | 8K | TBD | TBD |

Evaluation results for this model will be added once GPU resources are available.
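
As a rough illustration of the two metrics in the table, the sketch below computes a substring exact match and a token-level F1. The normalization step (lowercasing, stripping punctuation and English articles) and the function names are assumptions in the style of common QA evaluation scripts; they are not the evaluation code used for this card.

```python
import re
import string
from collections import Counter


def normalize(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace (assumed scheme)."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def substr_em(response: str, answers: list[str]) -> bool:
    """True if any normalized gold answer occurs as a substring of the normalized response."""
    norm_response = normalize(response)
    return any(normalize(a) in norm_response for a in answers if normalize(a))


def token_f1(response: str, answer: str) -> float:
    """Token-overlap F1 between a response and one gold answer."""
    pred = normalize(response).split()
    gold = normalize(answer).split()
    overlap = sum((Counter(pred) & Counter(gold)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(gold)
    return 2 * precision * recall / (precision + recall)
```

Under this scheme, a long response that contains the answer still counts as a Substr-EM hit, while its F1 drops with every extra token because precision falls. This is one plausible reason the F1 column sits well below the Substr-EM column.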

Block-Attention Mechanism

In Block-Attention, the context is split into N blocks:

  • Blocks 1..N-1 (document blocks): use local attention; each block attends only to its own tokens
  • Block N (query block): uses global attention; it attends to all previous blocks as well as itself

This isolation allows document blocks' KV states to be computed once and reused across multiple queries.
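
To make the pattern concrete, here is a minimal sketch that builds a boolean block mask from per-block token counts. It assumes attention inside each block stays causal, and the function name `build_block_mask` is hypothetical; the mask builder in the training code may differ in details.

```python
import torch


def build_block_mask(block_token_counts: list[int]) -> torch.Tensor:
    """Return a [seq_len, seq_len] bool mask; True means query row i may attend to key column j."""
    total = sum(block_token_counts)
    causal = torch.tril(torch.ones(total, total, dtype=torch.bool))
    mask = torch.zeros(total, total, dtype=torch.bool)

    start = 0
    for count in block_token_counts[:-1]:
        end = start + count
        # Document block: causal attention restricted to its own tokens (assumption).
        mask[start:end, start:end] = causal[start:end, start:end]
        start = end

    # Query block (last): causal attention over every earlier token and itself.
    mask[start:, :] = causal[start:, :]
    return mask


# Two document blocks of 2 tokens each, followed by a 1-token query block.
mask = build_block_mask([2, 2, 1])
print(mask.int())
```

Because no row of a document block has a nonzero entry outside that block's own columns, its keys and values do not depend on which other documents or query appear in the prompt. That independence is what makes the cached states reusable.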

Training Details

  • Base Model: meta-llama/Llama-3.1-8B
  • Training Data: Tulu3-Block-FT-RAG (8K subset)
  • Epochs: 1
  • Learning Rate: 2e-6
  • Optimizer: AdamW (fused)
  • Precision: BF16
  • DeepSpeed: ZeRO Stage 2 with CPU optimizer offload
  • Loss Reduction: sum (over non-masked tokens)
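
The "sum" loss reduction listed above can be illustrated with a small sketch. It assumes the common Hugging Face convention that masked label positions are set to -100; the helper name `summed_token_loss` is hypothetical and not taken from the training code.

```python
import torch
import torch.nn.functional as F


def summed_token_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy summed (not averaged) over positions whose label is not -100."""
    # Shift so that the logits at position t are scored against the token at t + 1.
    shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
    shift_labels = labels[:, 1:].reshape(-1)
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")


# Toy example: uniform predictions over an 8-token vocabulary.
logits = torch.zeros(1, 4, 8)
# Prompt positions are masked with -100 (assumed convention); two answer tokens are kept.
labels = torch.tensor([[-100, -100, 3, 5]])
loss = summed_token_loss(logits, labels)
```

With a sum reduction, a sample's contribution to the gradient scales with its number of unmasked tokens, so longer answers weigh more than they would under a mean reduction.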

During training, each sample produces two variants:

  1. Full-attention version (standard causal mask)
  2. Block-attention version (with [Block-Attention] prefix token and 4D block mask)

Both variants contribute to the loss, teaching the model to handle both inference modes.

Inference

Block-Attention Inference (recommended for RAG)

Important: Block-Attention uses a 4D attention mask of shape [1, 1, seq_len, seq_len] during prefill. model.generate() only accepts 2D masks, so inference requires a manual prefill followed by an autoregressive decode loop:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
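# Mask helpers from the training code (module src/data/block.py); they are not part of transformers.
from src.data.block import build_attention_mask, convert_attention_mask_to_model_required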

model = AutoModelForCausalLM.from_pretrained("hxia7/Llama-3.1-8B-block-FT", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("hxia7/Llama-3.1-8B-block-FT")

blocks = [
    "<|start_header_id|>system\nYou are an AI assistant. Below are reference documents.\n\n",
    "- Title: Document 1\nContent of document 1...\n",
    "- Title: Document 2\nContent of document 2...\n",
    "Answer the question using the documents.\nQuestion: What is X?\n\n",
]

@torch.no_grad()
def block_generate(model, tokenizer, blocks, max_new_tokens=128):
    # Tokenize each block separately so the per-block token counts are known.
    block_token_counts = []
    all_ids = []
    for b in blocks:
        ids = tokenizer.encode(b, add_special_tokens=False)
        all_ids.extend(ids)
        block_token_counts.append(len(ids))

    input_ids = torch.tensor([all_ids], dtype=torch.int64, device=model.device)
    total_len = len(all_ids)

    # Build the block mask: every block except the last is local, the last (query) block is global.
    helper = torch.tril(torch.ones(total_len + 64, total_len + 64, dtype=torch.bool))
    attn_mask = build_attention_mask(
        local_attention_block_tokens=torch.tensor(block_token_counts[:-1], dtype=torch.long),
        global_attention_block_tokens=torch.tensor(block_token_counts[-1], dtype=torch.long),
        lower_triangular_matrix=helper,
    )
    attn_mask = convert_attention_mask_to_model_required(attn_mask)
    # Add batch and head dimensions: [seq_len, seq_len] -> [1, 1, seq_len, seq_len].
    attn_mask = attn_mask.unsqueeze(0).unsqueeze(0).to(model.device)

    # Prefill: a single forward pass over the whole prompt with the 4D block mask.
    outputs = model(input_ids=input_ids, attention_mask=attn_mask, use_cache=True)
    past_kv = outputs.past_key_values
    next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)

    # Greedy decode: feed one token at a time and reuse the KV cache.
    generated = []
    for _ in range(max_new_tokens - 1):
        if next_token.item() == tokenizer.eos_token_id:
            break
        generated.append(next_token.item())
        outputs = model(input_ids=next_token, past_key_values=past_kv, use_cache=True)
        past_kv = outputs.past_key_values
        next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)

    # Keep the final token unless generation stopped on EOS.
    if next_token.item() != tokenizer.eos_token_id:
        generated.append(next_token.item())

    return tokenizer.decode(generated, skip_special_tokens=True).strip()

answer = block_generate(model, tokenizer, blocks)
print(answer)
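
`convert_attention_mask_to_model_required` is imported from the training code and is not shown in this card. As a hedged sketch of what a conversion of this kind typically involves, the snippet below turns a boolean allow-mask into an additive float mask: 0 where attention is allowed and the dtype's most negative value where it is blocked. The function name `to_additive_4d_mask` is hypothetical, and the real helper may behave differently.

```python
import torch


def to_additive_4d_mask(allowed: torch.Tensor, dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Map a [seq_len, seq_len] bool mask to a [1, 1, seq_len, seq_len] additive mask."""
    additive = torch.zeros(allowed.shape, dtype=dtype)
    # Blocked positions get a large negative bias so softmax assigns them near-zero weight.
    additive = additive.masked_fill(~allowed, torch.finfo(dtype).min)
    # Add batch and head dimensions.
    return additive[None, None, :, :]


# Plain causal mask over 3 tokens as a toy input.
allowed = torch.tril(torch.ones(3, 3, dtype=torch.bool))
mask_4d = to_additive_4d_mask(allowed)
```

Unlike the inference snippet above, which adds the two leading dimensions after calling the helper, this sketch adds them inside the function.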

Full-Attention Inference (standard)

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hxia7/Llama-3.1-8B-block-FT", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("hxia7/Llama-3.1-8B-block-FT")

prompt = "Your full RAG prompt here..."
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=3968).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False, pad_token_id=tokenizer.eos_token_id)
answer = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(answer)
