Qwen3-8B-Block-FT

Block-Attention fine-tuned Qwen/Qwen3-8B for efficient RAG inference.

Overview

This model is fine-tuned with the Block-Attention mechanism introduced in the paper "Block-Attention for Efficient Prefilling". During the prefill phase, Block-Attention splits the input context into blocks and encodes each document block independently of the others. This lets the KV cache of a document be reused across different queries over the same documents, a key optimization for RAG serving.

Training data (controlled variable): This model was fine-tuned on an 8K-sample subset of the Tulu3-Block-FT-RAG dataset. For comparison, the companion Llama-3.2-1B model (hxia7/Llama-3.2-1B-block-FT) was trained on the full 80K samples.

Evaluation Results

On Unseen TriviaQA Validation Set (100 clean samples)

Questions and evidence passages come from the TriviaQA RC validation split, which is excluded from the training data. Substr-EM counts a response as correct if the gold answer appears as a substring of the model's response.
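A minimal sketch of a Substr-EM check follows. The exact normalization behind the table is not stated here, so this assumes a simple one (lowercase, strip ASCII punctuation, collapse whitespace). It also assumes that a response passes if any gold alias matches, since TriviaQA lists several aliases per answer. normalize, substr_em and substr_em_rate are hypothetical helpers.

```python
import string


def normalize(text: str) -> str:
    """Assumed normalization: lowercase, drop ASCII punctuation, collapse whitespace."""
    text = text.lower().translate(str.maketrans("", "", string.punctuation))
    return " ".join(text.split())


def substr_em(response: str, gold_answers: list[str]) -> bool:
    """True if any non-empty normalized gold answer occurs inside the normalized response."""
    resp = normalize(response)
    golds = (normalize(ans) for ans in gold_answers)
    return any(g and g in resp for g in golds)


def substr_em_rate(responses: list[str], golds: list[list[str]]) -> float:
    """Fraction of responses that pass substr_em."""
    hits = sum(substr_em(r, g) for r, g in zip(responses, golds))
    return hits / len(responses)


print(substr_em("The Eiffel Tower is in Paris, France.", ["Paris"]))
print(substr_em_rate(["Paris!", "I am not sure."], [["Paris"], ["London"]]))
```

Substring matching is lenient: a verbose answer that mentions the gold string anywhere still counts, and a short gold answer can match inside a longer word.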

| Model | Substr-EM | F1 |
|---|---|---|
| meta-llama/Llama-3.2-1B (base) | 56.00% | 12.51% |
| meta-llama/Llama-3.2-1B-Instruct | 86.00% | 23.62% |
| hxia7/Llama-3.2-1B-block-FT (full-attention) | 87.00% | 26.59% |
| hxia7/Llama-3.2-1B-block-FT (block-attention) | 88.00% | 27.53% |
| hxia7/Qwen3-8B-block-FT (full-attention) | 91.00% | 25.18% |
| hxia7/Qwen3-8B-block-FT (block-attention) | 90.00% | 23.71% |

Key observations:

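  • Block-attention and full-attention inference give comparable results (90% vs 91% Substr-EM, 23.71% vs 25.18% F1), suggesting that the block-attention structure largely preserves answer quality on this test set.
  • Despite training on only 8K samples (vs 80K for the Llama model), Qwen3-8B achieves the highest Substr-EM (91% with full attention), which points to the benefit of a larger base model. Its F1 is lower than that of the Llama block-FT model (25.18% vs 27.53%), so the advantage holds for Substr-EM only.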
  • The evidence passages from TriviaQA differ from the Contriever-retrieved passages used in training, making this a meaningful out-of-distribution test.
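To put the 1-point gap in perspective: on 100 questions, the standard error of an accuracy near 90% is sqrt(0.9 × 0.1 / 100) = 0.03, about 3 points. A 1-point difference is therefore well within sampling noise. The sketch below does that arithmetic. It treats the questions as independent draws and ignores that both modes were scored on the same questions. A paired test such as McNemar's would be the stricter comparison. binomial_se is a hypothetical helper.

```python
import math


def binomial_se(accuracy: float, n: int) -> float:
    """Standard error of an accuracy estimated from n independent questions."""
    return math.sqrt(accuracy * (1 - accuracy) / n)


n_questions = 100
for name, acc in [("Qwen3-8B full-attention", 0.91), ("Qwen3-8B block-attention", 0.90)]:
    se = binomial_se(acc, n_questions)
    print(f"{name}: {acc:.0%} Substr-EM, standard error about {se:.1%}")
```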

Block-Attention Mechanism

In Block-Attention, the context is split into N blocks:

  • Blocks 1..N-1 (document blocks): use local attention; each block attends only to its own tokens (causally).
  • Block N (query block): uses global attention; it attends to all previous blocks and causally to its own tokens.

This isolation allows document blocks' KV states to be computed once and reused across multiple queries.
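The masking rule and its consequence can be checked with a toy model. This is a minimal sketch, assuming a single attention head with identity Q/K/V projections and no positional encoding, so it says nothing about where a reused block sits in a new prompt. block_attention_mask builds the boolean mask from per-block token counts. The demo then shows that swapping out the second document leaves the first document's outputs unchanged, while the query block sees the change. Both function names are hypothetical and are not the training code's API.

```python
import torch


def block_attention_mask(block_lens: list[int]) -> torch.Tensor:
    """Boolean [L, L] mask (True = may attend) for the Block-Attention layout.

    Every block except the last is local: causal attention within the block only.
    The last (query) block is global: causal attention over the whole sequence.
    """
    total = sum(block_lens)
    causal = torch.tril(torch.ones(total, total, dtype=torch.bool))
    mask = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for n in block_lens[:-1]:
        end = start + n
        mask[start:end, start:end] = causal[start:end, start:end]
        start = end
    mask[start:, :] = causal[start:, :]  # query-block rows see every earlier token
    return mask


def masked_attention(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Toy single-head self-attention with identity Q/K/V projections."""
    scores = x @ x.T / x.shape[-1] ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ x


# Two document blocks (3 and 2 tokens) followed by a 2-token query block.
torch.manual_seed(0)
lens = [3, 2, 2]
mask = block_attention_mask(lens)
x = torch.randn(sum(lens), 4)
x_other_doc = x.clone()
x_other_doc[3:5] = torch.randn(2, 4)  # swap in a different second document

out_a = masked_attention(x, mask)
out_b = masked_attention(x_other_doc, mask)
print(torch.allclose(out_a[:3], out_b[:3]))  # first document block: unaffected
print(torch.allclose(out_a[5:], out_b[5:]))  # query block: affected by the swap
```

In a multi-layer model the same argument applies layer by layer: a document block's hidden states, and therefore its KV states, depend only on that block's tokens.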

Training Details

  • Base Model: Qwen/Qwen3-8B
  • Training Data: Tulu3-Block-FT-RAG (8K subset)
  • Epochs: 1
  • Learning Rate: 2e-6
  • Optimizer: AdamW (fused)
  • Precision: BF16
  • DeepSpeed: ZeRO Stage 2 with CPU optimizer offload
  • Loss Reduction: sum (over non-masked tokens)
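The settings above could be expressed roughly as follows, assuming training ran through the Hugging Face Trainer's DeepSpeed integration. That integration is an assumption, and this is not the authors' actual config file. Batch size, schedule, and warmup are not stated in this card. With CPU optimizer offload, DeepSpeed generally expects its own CPU Adam optimizer, so this sketch does not reproduce how the fused AdamW setting was actually wired in.

```python
# Hypothetical DeepSpeed config mirroring the settings listed above (not the authors' file).
ds_config = {
    "bf16": {"enabled": True},  # BF16 mixed precision
    "zero_optimization": {
        "stage": 2,  # ZeRO Stage 2: shard optimizer states and gradients
        "offload_optimizer": {"device": "cpu", "pin_memory": True},  # CPU optimizer offload
    },
    # "auto" lets the Hugging Face Trainer integration fill these from TrainingArguments.
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

# Matching Trainer-side hyperparameters, as keyword arguments for
# transformers.TrainingArguments (output_dir and batch size still need to be supplied).
training_kwargs = {
    "learning_rate": 2e-6,
    "num_train_epochs": 1,
    "bf16": True,
    "optim": "adamw_torch_fused",
    "deepspeed": ds_config,
}
```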

During training, each sample produces two variants:

  1. Full-attention version (standard causal mask)
  2. Block-attention version (with [Block-Attention] prefix token and 4D block mask)

Both variants contribute to the loss, teaching the model to handle both inference modes.
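A minimal sketch of how the two variants might be combined under sum reduction follows. It assumes standard next-token shifting and that label -100 marks positions excluded from the loss. In a real step, logits_full would come from the model run on the plain sample with a causal mask. logits_block would come from the model run on the [Block-Attention]-prefixed sample with the 4D block mask. Each variant keeps its own labels because the prefix (assumed) changes the sequence length. summed_lm_loss and two_variant_loss are hypothetical names, not the training code's API.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed label value for prompt/padding positions excluded from the loss


def summed_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy summed over non-masked target tokens.

    logits: [batch, seq, vocab]; labels: [batch, seq] (IGNORE_INDEX = not scored).
    """
    shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))  # position t predicts t+1
    shift_labels = labels[:, 1:].reshape(-1)
    return F.cross_entropy(
        shift_logits.float(), shift_labels, ignore_index=IGNORE_INDEX, reduction="sum"
    )


def two_variant_loss(
    logits_full: torch.Tensor,
    labels_full: torch.Tensor,
    logits_block: torch.Tensor,
    labels_block: torch.Tensor,
) -> torch.Tensor:
    """Total loss for one sample: full-attention variant plus block-attention variant."""
    return summed_lm_loss(logits_full, labels_full) + summed_lm_loss(logits_block, labels_block)
```

Summing, rather than averaging, weights every answer token equally across the batch regardless of how many tokens each sample contributes.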

Inference

Block-Attention Inference (recommended for RAG)

Important: Block-Attention uses a 4D attention mask of shape [1, 1, seq_len, seq_len] during prefill. model.generate() is built around standard 2D attention masks, so inference requires a manual prefill followed by autoregressive decoding:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
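# build_attention_mask and convert_attention_mask_to_model_required come from the
# training repository's src/data/block.py, which must be on PYTHONPATH.
from src.data.block import build_attention_mask, convert_attention_mask_to_model_required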

model = AutoModelForCausalLM.from_pretrained("hxia7/Qwen3-8B-block-FT", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("hxia7/Qwen3-8B-block-FT")

blocks = [
    "\nYou are an intelligent AI assistant. Please answer questions based on the user's instructions. Below are some reference documents that may help you in answering the user's question.\n\n",
    "- Title: Document 1\nContent of document 1...\n",
    "- Title: Document 2\nContent of document 2...\n",
    "\n\nPlease write a high-quality answer for the given question using only the provided search documents.\nQuestion: What is X?\n\n\n",
]

@torch.no_grad()
def block_generate(model, tokenizer, blocks, max_new_tokens=128):
    # Tokenize each block separately so the block boundaries (token counts) are known.
    block_token_counts = []
    all_ids = []
    for b in blocks:
        ids = tokenizer.encode(b, add_special_tokens=False)
        all_ids.extend(ids)
        block_token_counts.append(len(ids))

    input_ids = torch.tensor([all_ids], dtype=torch.int64, device=model.device)
    total_len = len(all_ids)

    # Build the block mask: every block except the last attends only to itself;
    # the last (query) block attends to all tokens before it.
    helper = torch.tril(torch.ones(total_len + 64, total_len + 64, dtype=torch.bool))
    attn_mask = build_attention_mask(
        local_attention_block_tokens=torch.tensor(block_token_counts[:-1], dtype=torch.long),
        global_attention_block_tokens=torch.tensor(block_token_counts[-1], dtype=torch.long),
        lower_triangular_matrix=helper,
    )
    attn_mask = convert_attention_mask_to_model_required(attn_mask)
    attn_mask = attn_mask.unsqueeze(0).unsqueeze(0).to(model.device)  # -> [1, 1, seq_len, seq_len]

    # Prefill: a single forward pass over the whole prompt with the 4D block mask.
    outputs = model(input_ids=input_ids, attention_mask=attn_mask, use_cache=True)
    past_kv = outputs.past_key_values
    next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)

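    # Qwen3 checkpoints can stop on more than one special token, so collect every
    # configured EOS id instead of relying on tokenizer.eos_token_id alone.
    eos = model.generation_config.eos_token_id
    stop_ids = set(eos if isinstance(eos, (list, tuple)) else [eos]) | {tokenizer.eos_token_id}
    stop_ids.discard(None)

    # Greedy decoding: each new token attends to every cached position,
    # so no explicit attention mask is passed here.
    generated = []
    for _ in range(max_new_tokens - 1):
        if next_token.item() in stop_ids:
            break
        generated.append(next_token.item())
        outputs = model(input_ids=next_token, past_key_values=past_kv, use_cache=True)
        past_kv = outputs.past_key_values
        next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)

    if next_token.item() not in stop_ids:
        generated.append(next_token.item())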

    return tokenizer.decode(generated, skip_special_tokens=True).strip()

answer = block_generate(model, tokenizer, blocks)
print(answer)
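If the training repository's src.data.block module is not available, the mask can be built directly. The sketch below is a hypothetical stand-in for build_attention_mask plus convert_attention_mask_to_model_required, and it may not match those helpers' exact output. It assumes the model consumes an additive float mask: 0 where attention is allowed, and the dtype's most negative finite value where it is blocked. That is how Hugging Face's eager and SDPA attention paths typically treat a custom 4D mask; FlashAttention-2 does not accept arbitrary 4D masks.

```python
import torch


def block_attention_4d_mask(
    block_token_counts: list[int], dtype: torch.dtype = torch.bfloat16
) -> torch.Tensor:
    """Additive [1, 1, L, L] Block-Attention mask (hypothetical helper).

    Document-block rows attend causally within their own block;
    rows of the last (query) block attend causally to every token.
    """
    counts = torch.tensor(block_token_counts)
    total = int(counts.sum())
    # Block index of every token, e.g. counts [2, 1, 2] -> [0, 0, 1, 2, 2].
    block_id = torch.repeat_interleave(torch.arange(len(block_token_counts)), counts)
    causal = torch.tril(torch.ones(total, total, dtype=torch.bool))
    same_block = block_id[:, None] == block_id[None, :]
    is_query_row = (block_id == len(block_token_counts) - 1)[:, None]
    allowed = causal & (same_block | is_query_row)
    # 0 where attention is allowed, the most negative finite value where it is blocked.
    additive = torch.zeros(total, total, dtype=dtype).masked_fill(~allowed, torch.finfo(dtype).min)
    return additive[None, None, :, :]


mask = block_attention_4d_mask([2, 1, 2])
print(mask.shape)
print((mask[0, 0] == 0).int())
```

Inside block_generate, the mask-building lines could then be replaced by attn_mask = block_attention_4d_mask(block_token_counts).to(model.device).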

Full-Attention Inference (standard)

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hxia7/Qwen3-8B-block-FT", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("hxia7/Qwen3-8B-block-FT")

prompt = "Your full RAG prompt here..."
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=3968).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False, pad_token_id=tokenizer.eos_token_id)
answer = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(answer)

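References

  • Block-Attention for Efficient Prefilling (the paper introducing the Block-Attention mechanism used by this model)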

Model size: 8B params · Tensor type: BF16 · Format: Safetensors