Llama-3.2-1B-Block-FT

A Block-Attention fine-tune of meta-llama/Llama-3.2-1B for efficient RAG inference.

Overview

This model is fine-tuned with the Block-Attention mechanism introduced in the paper Block-Attention for Efficient Prefilling. Block-Attention divides the input context into independent blocks during the prefill phase. This enables KV cache reuse across different queries over the same documents, a key optimization for RAG serving.

Training data (control variable): This model was fine-tuned on the full Tulu3-Block-FT-RAG dataset (80K samples). A companion Qwen3-8B model uses an 8K-sample subset for comparison.

Evaluation Results

On Training Data (100 samples from Tulu3-Block-FT-RAG train set)

Note: These samples overlap with the training data, so absolute scores are inflated. The block- vs. full-attention comparison remains valid because both modes use the same weights and the same samples.

| Model | EM Score | F1 Score |
|---|---|---|
| meta-llama/Llama-3.2-1B (base) | 0.00% | 0.54% |
| meta-llama/Llama-3.2-1B-Instruct | 2.00% | 50.62% |
| hxia7/Llama-3.2-1B-block-FT (full-attention) | 14.00% | 70.96% |
| hxia7/Llama-3.2-1B-block-FT (block-attention) | 17.00% | 69.68% |

On Unseen TriviaQA Validation Set (100 clean samples)

Questions and evidence passages are drawn from the TriviaQA RC validation split and are excluded from the training data. Substr-EM checks whether the correct answer appears as a substring of the model's response.

| Model | Substr-EM | F1 Score |
|---|---|---|
| meta-llama/Llama-3.2-1B (base) | 56.00% | 12.51% |
| meta-llama/Llama-3.2-1B-Instruct | 86.00% | 23.62% |
| hxia7/Llama-3.2-1B-block-FT (full-attention) | 87.00% | 26.59% |
| hxia7/Llama-3.2-1B-block-FT (block-attention) | 88.00% | 27.53% |
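The metrics in the two tables above can be sketched as follows. This is an illustration only: the card does not specify how answers are normalized, so SQuAD-style normalization (lowercasing, stripping punctuation and the articles a/an/the) is assumed, and the helper names `normalize`, `exact_match`, `substr_em`, and `token_f1` are hypothetical.

```python
import re
import string
from collections import Counter

_PUNCT = set(string.punctuation)

def normalize(text: str) -> str:
    # Assumed SQuAD-style normalization: lowercase, drop punctuation,
    # drop the articles a/an/the, collapse whitespace.
    text = text.lower()
    text = "".join(ch for ch in text if ch not in _PUNCT)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(prediction: str, answers: list[str]) -> float:
    # 1.0 if the normalized prediction equals any normalized gold answer.
    pred = normalize(prediction)
    return float(any(pred == normalize(a) for a in answers))

def substr_em(prediction: str, answers: list[str]) -> float:
    # 1.0 if any normalized gold answer occurs as a substring of the normalized prediction.
    pred = normalize(prediction)
    return float(any(normalize(a) in pred for a in answers))

def token_f1(prediction: str, answers: list[str]) -> float:
    # Best token-overlap F1 against any of the gold answers.
    pred_toks = normalize(prediction).split()
    best = 0.0
    for a in answers:
        gold_toks = normalize(a).split()
        overlap = sum((Counter(pred_toks) & Counter(gold_toks)).values())
        if overlap == 0:
            continue
        precision = overlap / len(pred_toks)
        recall = overlap / len(gold_toks)
        best = max(best, 2 * precision * recall / (precision + recall))
    return best
```

Under these definitions, a verbose but correct answer such as "The capital of France is Paris." scores 1.0 on Substr-EM and 0.0 on EM, while its F1 is diluted by the extra tokens. That is one plausible reason the Substr-EM column on TriviaQA is much higher than the F1 column.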

Key observations:

  • Block-attention and full-attention inference produce comparable results on both evaluation sets (within 3 percentage points on every metric), indicating that the block-attention structure preserves answer quality.
  • On unseen data, block-FT outperforms the Instruct baseline on both Substr-EM and F1 (+2 and +3.91 percentage points in block-attention mode; +1 and +2.97 in full-attention mode), suggesting that RAG fine-tuning improves answer extraction even on out-of-distribution data.
  • The evidence passages from TriviaQA differ from the Contriever-retrieved passages used in training, making this a meaningful out-of-distribution test.
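The gains quoted above are absolute percentage-point differences. The sketch below recomputes them from the TriviaQA table and also reports the relative change, which can help when comparing against other baselines; the `deltas` helper is hypothetical and the scores are copied from the table.

```python
# Scores copied from the TriviaQA table (percent).
instruct = {"substr_em": 86.00, "f1": 23.62}
block_ft = {
    "full-attention": {"substr_em": 87.00, "f1": 26.59},
    "block-attention": {"substr_em": 88.00, "f1": 27.53},
}

def deltas(model: dict, baseline: dict) -> dict:
    # For each metric: (absolute change in percentage points, relative change in percent).
    out = {}
    for metric, value in model.items():
        base = baseline[metric]
        abs_pp = round(value - base, 2)
        rel_pct = round(100 * (value - base) / base, 1)
        out[metric] = (abs_pp, rel_pct)
    return out

for mode, scores in block_ft.items():
    print(mode, deltas(scores, instruct))
```

Reporting both forms avoids ambiguity: a +3.91-point F1 gain over a 23.62% baseline is a relative improvement of roughly 16.6%.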

Block-Attention Mechanism

In Block-Attention, the context is split into N blocks:

  • Blocks 1..N-1 (document blocks): use local attention; each block attends only to its own tokens (causally).
  • Block N (query block): uses global attention; it attends to all tokens in the previous blocks as well as to its own tokens (causally).

This isolation allows document blocks' KV states to be computed once and reused across multiple queries.
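The masking rule above can be sketched in plain PyTorch. This is not the `build_attention_mask` helper used in the inference code below; it is a hypothetical, self-contained illustration that assumes the last block is the query block and every earlier block is a document block. The conversion to an additive 4D mask (0 where attention is allowed, the dtype's most negative value where it is blocked) is likewise an assumption about the format the model expects.

```python
import torch

def block_attention_mask(block_token_counts: list[int]) -> torch.Tensor:
    """Boolean [seq_len, seq_len] mask; True means the query row may attend to the key column.

    Assumption: the last block is the global (query) block, all earlier blocks are local.
    """
    seq_len = sum(block_token_counts)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    start = 0
    for i, n in enumerate(block_token_counts):
        end = start + n
        if i < len(block_token_counts) - 1:
            # Local (document) block: causal attention restricted to its own tokens.
            mask[start:end, start:end] = causal[start:end, start:end]
        else:
            # Global (query) block: causal attention over every preceding token and itself.
            mask[start:end, :end] = causal[start:end, :end]
        start = end
    return mask

def to_additive_4d(mask: torch.Tensor, dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # 0.0 where attention is allowed, dtype's most negative value where it is blocked,
    # reshaped to [1, 1, seq_len, seq_len] (batch and head dimensions added).
    additive = torch.zeros(mask.shape, dtype=dtype).masked_fill(~mask, torch.finfo(dtype).min)
    return additive[None, None, :, :]

m = block_attention_mask([2, 3, 2])  # two document blocks, then a 2-token query block
print(m.int())
```

Printing `m.int()` shows two disjoint lower-triangular squares for the document blocks and full causal rows for the query block; the zeros to the left of each document square are what make that block's KV states independent of the other documents.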

Training Details

  • Base Model: meta-llama/Llama-3.2-1B
  • Training Data: Tulu3-Block-FT-RAG (80K samples, full dataset)
  • Epochs: 1
  • Learning Rate: 2e-5
  • Optimizer: AdamW (fused)
  • Precision: BF16
  • DeepSpeed: ZeRO Stage 2
  • Loss Reduction: sum (over non-masked tokens)
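One hypothetical way to express the settings above with the Hugging Face Trainer and DeepSpeed is sketched below. The card does not publish its launch configuration, so everything beyond the listed values (1 epoch, learning rate 2e-5, fused AdamW, BF16, ZeRO Stage 2), including the `"auto"` batch-size entries and the `ds_config.json` file name, is an assumption.

```python
import json

# Hypothetical DeepSpeed config mirroring the listed settings (not the author's file).
ds_config = {
    "bf16": {"enabled": True},          # Precision: BF16
    "zero_optimization": {"stage": 2},  # DeepSpeed ZeRO Stage 2
    # "auto" lets the Hugging Face Trainer integration fill these from its own arguments.
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    # No "optimizer" block: the Trainer then creates its own optimizer from `optim` below.
}
ds_json = json.dumps(ds_config, indent=2)  # would be saved as ds_config.json

# Trainer-side settings, as keyword arguments for transformers.TrainingArguments.
training_args = {
    "num_train_epochs": 1,
    "learning_rate": 2e-5,
    "optim": "adamw_torch_fused",  # fused AdamW
    "bf16": True,
    "deepspeed": "ds_config.json",
}
```

Keeping the optimizer on the Trainer side is a design choice in this sketch: it lets the fused PyTorch AdamW be selected by name while DeepSpeed handles only ZeRO partitioning and BF16.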

During training, each sample produces two variants:

  1. Full-attention version (standard causal mask)
  2. Block-attention version (with [Block-Attention] prefix token and 4D block mask)

Both variants contribute to the loss, teaching the model to handle both inference modes.
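A minimal sketch of the combined objective follows. It assumes the common convention that prompt and context positions are labelled -100 and excluded from the loss, and that each variant is scored with a standard shifted causal-LM cross-entropy summed over the non-masked tokens (the "sum" reduction listed under Training Details). The function names are hypothetical; the card does not show its training loop.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumption: masked (non-target) label positions are set to -100

def summed_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Causal-LM cross-entropy summed over non-masked target tokens.

    logits: [batch, seq_len, vocab]; labels: [batch, seq_len].
    """
    # Position t predicts token t + 1, so drop the last logit and the first label.
    shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
    shift_labels = labels[:, 1:].reshape(-1)
    return F.cross_entropy(shift_logits, shift_labels,
                           ignore_index=IGNORE_INDEX, reduction="sum")

def two_variant_loss(full_logits, full_labels, block_logits, block_labels) -> torch.Tensor:
    # Full-attention variant plus block-attention variant of the same sample.
    # The two variants may differ in length (e.g. the block variant's extra prefix token).
    return summed_lm_loss(full_logits, full_labels) + summed_lm_loss(block_logits, block_labels)
```

Because both terms share the same weights, every gradient step pushes the model to answer well both with and without the block mask.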

Inference

Block-Attention Inference (recommended for RAG)

Important: Block-Attention uses a 4D attention mask of shape [1, 1, seq_len, seq_len] during prefill. Because model.generate() expects a standard 2D padding mask, inference requires a manual prefill followed by an autoregressive (greedy) decode loop:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Mask helpers from the Block-Attention codebase (not part of transformers)
from src.data.block import build_attention_mask, convert_attention_mask_to_model_required

model = AutoModelForCausalLM.from_pretrained("hxia7/Llama-3.2-1B-block-FT", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("hxia7/Llama-3.2-1B-block-FT")

blocks = [
    "<|start_header_id|>system\nYou are an AI assistant. Below are reference documents.\n\n",
    "- Title: Document 1\nContent of document 1...\n",
    "- Title: Document 2\nContent of document 2...\n",
    "Answer the question using the documents.\nQuestion: What is X?\n\n",
]

@torch.no_grad()
def block_generate(model, tokenizer, blocks, max_new_tokens=128):
    # Tokenize each block separately and record its length so the mask
    # builder knows where every block starts and ends.
    block_token_counts = []
    all_ids = []
    for b in blocks:
        ids = tokenizer.encode(b, add_special_tokens=False)
        all_ids.extend(ids)
        block_token_counts.append(len(ids))

    input_ids = torch.tensor([all_ids], dtype=torch.int64, device=model.device)
    total_len = len(all_ids)

    # Build the block mask: every block except the last is local (attends only to
    # itself); the last block is global (attends to everything before it).
    helper = torch.tril(torch.ones(total_len + 64, total_len + 64, dtype=torch.bool))
    attn_mask = build_attention_mask(
        local_attention_block_tokens=torch.tensor(block_token_counts[:-1], dtype=torch.long),
        global_attention_block_tokens=torch.tensor(block_token_counts[-1], dtype=torch.long),
        lower_triangular_matrix=helper,
    )
    attn_mask = convert_attention_mask_to_model_required(attn_mask)
    # Add batch and head dimensions: [1, 1, seq_len, seq_len]
    attn_mask = attn_mask.unsqueeze(0).unsqueeze(0).to(model.device)

    # Prefill: one forward pass over the whole context with the 4D block mask.
    outputs = model(input_ids=input_ids, attention_mask=attn_mask, use_cache=True)
    past_kv = outputs.past_key_values
    next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)

    # Greedy decode: new tokens attend to the whole KV cache (ordinary causal
    # attention), so no explicit mask is passed here.
    generated = []
    for _ in range(max_new_tokens - 1):
        if next_token.item() == tokenizer.eos_token_id:
            break
        generated.append(next_token.item())
        # Feed only the newest token; earlier keys/values come from past_kv.
        outputs = model(input_ids=next_token, past_key_values=past_kv, use_cache=True)
        past_kv = outputs.past_key_values
        next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)

    # Keep the final token unless it is EOS.
    if next_token.item() != tokenizer.eos_token_id:
        generated.append(next_token.item())

    return tokenizer.decode(generated, skip_special_tokens=True).strip()

answer = block_generate(model, tokenizer, blocks)
print(answer)

Full-Attention Inference (standard)

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hxia7/Llama-3.2-1B-block-FT", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("hxia7/Llama-3.2-1B-block-FT")

prompt = "Your full RAG prompt here..."
# Truncate the prompt to 3968 tokens, leaving room for 128 new tokens within 4096
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=3968).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False, pad_token_id=tokenizer.eos_token_id)
answer = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(answer)
