metadata
license: apache-2.0
language:
  - en
library_name: transformers
tags:
  - t5
  - text-classification
  - rag
  - question-answering
  - retrieval-augmented-generation
base_model: google-t5/t5-large
pipeline_tag: text-classification
datasets:
  - musique
  - hotpotqa
  - 2wikimultihop
  - natural_questions
  - trivia_qa
  - squad
metrics:
  - accuracy

Adaptive-RAG Fine-Tuned Models

This directory contains fine-tuned models for the Adaptive-RAG framework (Jeong et al., 2024).

Relationship with Adaptive-RAG Repository

This directory is the model output directory for the Adaptive-RAG project hosted at:

https://github.com/1571859588/Adaptive-RAG

The original repository provides:

  • The Adaptive-RAG framework code
  • Training scripts for the query complexity classifier
  • Inference and evaluation pipelines
  • Pre-trained base models

This fine-tuned-models/ directory stores the trained classifier models that result from running the training scripts from the Adaptive-RAG repository.

Directory Structure

fine-tuned-models/
└── classifier/
    └── t5-large/          # Fine-tuned T5-large query complexity classifier
        ├── config.json
        ├── generation_config.json
        ├── model.safetensors
        ├── tokenizer.json
        ├── tokenizer_config.json
        ├── special_tokens_map.json
        ├── spiece.model
        └── logs.log
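
Before loading the model, it can help to confirm that a local copy of the directory is complete. The following is a minimal sketch, not part of the Adaptive-RAG repository: `EXPECTED_FILES` and `missing_files` are hypothetical names, and `logs.log` is left out on the assumption that it is a training artifact that loading does not need.

```python
from pathlib import Path

# File names taken from the directory listing above (logs.log deliberately omitted).
EXPECTED_FILES = [
    "config.json",
    "generation_config.json",
    "model.safetensors",
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "spiece.model",
]


def missing_files(model_dir, expected=EXPECTED_FILES):
    """Return the expected file names that are not present in model_dir."""
    root = Path(model_dir)
    return [name for name in expected if not (root / name).is_file()]
```

An empty list means every listed file was found; any names returned point to an incomplete download or copy.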

Classifier Model Details

Base Model

  • Model: T5-large
  • Base model path: /mnt/public/sichuan_a/nyt/models/Adaptive-RAG/base-models/classifier/t5-large

Training Configuration

| Parameter | Value |
|---|---|
| Learning rate | 3e-5 |
| Batch size | 16 (per device) |
| Max sequence length | 384 |
| Doc stride | 128 |
| Training epochs | 25 |
| Optimizer | AdamW |
| Mixed precision | FP16 (via accelerate) |

Training Data

The classifier was trained on silver-labeled query-complexity data pooled from the six QA datasets listed below.

Data location:

/mnt/public/sichuan_a/nyt/MAEDA/huada-docqa-demo/huada-docqa-demo/experiments/regression_test/MAEDA-DATE26/baselines/Adaptive-RAG/classifier/data/musique_hotpot_wiki2_nq_tqa_sqd/flan_t5_xxl/binary_silver/train.json

Datasets included:

  • Musique (multi-hop QA)
  • HotpotQA (multi-hop QA)
  • 2WikiMultihopQA (multi-hop QA)
  • Natural Questions (single-hop QA)
  • TriviaQA (single-hop QA)
  • SQuAD (single-hop QA)

Classification labels:

  • A (zero): Query requires no retrieval
  • B (single): Query requires single-step retrieval
  • C (multi): Query requires multi-step retrieval

Training Code

Training Script Location

The training script is located in the Adaptive-RAG repository:

/mnt/public/sichuan_a/nyt/MAEDA/huada-docqa-demo/huada-docqa-demo/experiments/regression_test/MAEDA-DATE26/baselines/Adaptive-RAG/classifier/run_classifier.py

Shell Script

The shell script used to launch training:

/mnt/public/sichuan_a/nyt/MAEDA/huada-docqa-demo/huada-docqa-demo/experiments/regression_test/MAEDA-DATE26/baselines/Adaptive-RAG/classifier/run/run_maeda_train.sh

Training Command

cd /mnt/public/sichuan_a/nyt/MAEDA/huada-docqa-demo/huada-docqa-demo/experiments/regression_test/MAEDA-DATE26/baselines/Adaptive-RAG
source /mnt/public/sichuan_a/nyt/uv_envs/.venv_adaptive_rag/bin/activate

python classifier/run_classifier.py \
    --model_name_or_path /mnt/public/sichuan_a/nyt/models/Adaptive-RAG/base-models/classifier/t5-large \
    --train_file classifier/data/musique_hotpot_wiki2_nq_tqa_sqd/flan_t5_xxl/binary_silver/train.json \
    --question_column question \
    --answer_column answer \
    --learning_rate 3e-5 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --per_device_train_batch_size 16 \
    --output_dir /mnt/public/sichuan_a/nyt/models/Adaptive-RAG/fine-tuned-models/classifier/t5-large \
    --overwrite_cache \
    --train_column train \
    --do_train \
    --num_train_epochs 25

Python Environment

  • Virtual environment: /mnt/public/sichuan_a/nyt/uv_envs/.venv_adaptive_rag
  • Python version: 3.10
  • PyTorch: 2.6.0+cu124
  • Key dependencies: transformers, accelerate, datasets, nltk
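
To compare another machine against the environment listed above, the installed package versions can be read with the standard library. This is a small sketch: `PACKAGES` and `installed_versions` are hypothetical names, and the package list simply mirrors the dependencies named in this section.

```python
import sys
from importlib import metadata

# PyTorch plus the key dependencies named above.
PACKAGES = ["torch", "transformers", "accelerate", "datasets", "nltk"]


def installed_versions(packages=PACKAGES):
    """Return {package: version string, or None if the package is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("python", sys.version.split()[0])
    for name, version in installed_versions().items():
        print(name, version or "not installed")
```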

How to Use the Fine-Tuned Classifier

Loading the Model

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_path = "/mnt/public/sichuan_a/nyt/models/Adaptive-RAG/fine-tuned-models/classifier/t5-large"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

Classifying a Query

import torch

def classify_query(query, model, tokenizer, device="cuda"):
    """Classify query complexity into A/B/C and map the label to a retrieval strategy."""
    prompt = f"Classify the complexity of the following query: {query}"
    inputs = tokenizer(prompt, return_tensors="pt", max_length=384, truncation=True).to(device)
    model = model.to(device).eval()  # the model must be on the same device as the inputs

    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=10)

    label = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # Map the label to a strategy; unrecognized outputs fall back to single-step retrieval
    label_to_strategy = {"A": "zero", "B": "single", "C": "multi"}
    strategy = label_to_strategy.get(label, "single")

    return label, strategy
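
Because the label comes from free-form generation, the decoded string may not always be exactly "A", "B", or "C". The source does not say whether this happens in practice, so the following is only a defensive sketch: `normalize_label` is a hypothetical helper that strips whitespace and surrounding punctuation and falls back to "B" (single-step), matching the fallback used above.

```python
VALID_LABELS = ("A", "B", "C")
LABEL_TO_STRATEGY = {"A": "zero", "B": "single", "C": "multi"}


def normalize_label(decoded, default="B"):
    """Map raw decoded text to A/B/C, returning `default` for anything else."""
    # Drop whitespace, then surrounding punctuation such as "(c)." or "B:".
    text = decoded.strip().strip(".,:;()[]\"'").upper()
    return text if text in VALID_LABELS else default


print(normalize_label(" a "))                        # A
print(normalize_label("(c)."))                       # C
print(LABEL_TO_STRATEGY[normalize_label("unsure")])  # single
```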

Example Usage in Adaptive-RAG Pipeline

The classifier is used in the Adaptive-RAG inference pipeline to dynamically select retrieval strategies:

# load_classifier, classify_queries, retrieve, and multi_step_retrieve are not defined
# in this README; they stand for the helpers of the inference pipeline.

# Load classifier
classifier_model, classifier_tokenizer = load_classifier(
    "/mnt/public/sichuan_a/nyt/models/Adaptive-RAG/fine-tuned-models/classifier/t5-large",
    device="cuda"
)

# Classify queries
labels = classify_queries(queries, classifier_model, classifier_tokenizer, device="cuda")

# Select retrieval strategy based on label
for query, label in zip(queries, labels):
    if label == "A":
        # No retrieval
        contexts = []
    elif label == "C":
        # Multi-step retrieval
        contexts = multi_step_retrieve(query, steps=3, top_k=5)
    else:
        # Single-step retrieval; "B" and any unrecognized label land here,
        # so contexts is always defined
        contexts = retrieve(query, top_k=10)
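
The label-based branching can also be written as a lookup table, which keeps the fallback behavior in one place. The sketch below is self-contained and uses stand-in retrievers: `make_router`, `route`, `fake_retrieve`, and `fake_multi_step` are hypothetical names for illustration, not functions from the Adaptive-RAG code.

```python
from typing import Callable, Dict, List

Retriever = Callable[[str], List[str]]


def no_retrieval(query: str) -> List[str]:
    """Label A: answer without retrieved context."""
    return []


def make_router(single_step: Retriever, multi_step: Retriever) -> Dict[str, Retriever]:
    """Build the label -> retriever table for labels A, B, and C."""
    return {"A": no_retrieval, "B": single_step, "C": multi_step}


def route(query: str, label: str, router: Dict[str, Retriever], fallback: str = "B") -> List[str]:
    """Run the retriever for `label`, using the `fallback` label if it is unknown."""
    handler = router.get(label, router[fallback])
    return handler(query)


# Stand-in retrievers so the sketch runs without a retrieval backend.
def fake_retrieve(query: str) -> List[str]:
    return [f"doc for {query}"]


def fake_multi_step(query: str) -> List[str]:
    return [f"hop {i} for {query}" for i in range(1, 4)]


router = make_router(fake_retrieve, fake_multi_step)
print(route("who wrote Hamlet?", "A", router))  # []
print(route("who wrote Hamlet?", "B", router))  # ['doc for who wrote Hamlet?']
```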

Training Notes

Fixes Applied for This Training

  1. Removed unsupported Trainer arguments: The original run_classifier.py did not support --save_strategy, --save_total_limit, and --logging_steps. These were removed from the shell script.

  2. Fixed accelerate compatibility: The code used accelerator.use_fp16 which is deprecated in newer accelerate versions. Fixed by replacing with:

    use_fp16 = getattr(accelerator, 'mixed_precision', 'no') != 'no'
    
  3. Added output directory creation: Added mkdir -p in the shell script to ensure the output directory exists before training.
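
The compatibility fix in item 2 can be exercised without installing accelerate. In this sketch, `mixed_precision_enabled` is a hypothetical wrapper around the same `getattr` expression, and `SimpleNamespace` objects stand in for an `Accelerator`. Note that the expression is true for any mode other than 'no', so it also reports bf16 as enabled.

```python
from types import SimpleNamespace


def mixed_precision_enabled(accelerator) -> bool:
    """True when the object reports a mixed-precision mode other than 'no'."""
    return getattr(accelerator, "mixed_precision", "no") != "no"


# Stand-ins for an Accelerator configured in different ways.
print(mixed_precision_enabled(SimpleNamespace(mixed_precision="fp16")))  # True
print(mixed_precision_enabled(SimpleNamespace(mixed_precision="no")))    # False
print(mixed_precision_enabled(SimpleNamespace()))                        # False (attribute missing)
```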

Training Statistics

  • Total steps: 5975
  • Training time: ~23 minutes
  • Final loss: See logs.log for details
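
The step count can be cross-checked against the training configuration with simple arithmetic. The sketch below assumes a single training process, no gradient accumulation, and a kept final partial batch; none of these is stated in this README, so the example counts are only rough bounds.

```python
total_steps = 5975       # from the statistics above
epochs = 25              # --num_train_epochs
per_device_batch = 16    # --per_device_train_batch_size
minutes = 23             # approximate training time

steps_per_epoch, remainder = divmod(total_steps, epochs)
print(steps_per_epoch, remainder)        # 239 0

# Under the stated assumptions, 239 batches of at most 16 examples bound the training set size.
max_examples = steps_per_epoch * per_device_batch
min_examples = (steps_per_epoch - 1) * per_device_batch + 1
print(min_examples, max_examples)        # 3809 3824

steps_per_second = total_steps / (minutes * 60)
print(round(steps_per_second, 1))        # 4.3
```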

References

  • Adaptive-RAG Paper: Jeong, S., et al. "Adaptive-RAG: Learning to Adapt Retrieval-Augmented Large Language Models through Question Complexity." NAACL 2024.
  • GitHub Repository: https://github.com/1571859588/Adaptive-RAG
  • MAEDA Integration: See /mnt/public/sichuan_a/nyt/MAEDA/huada-docqa-demo/huada-docqa-demo/experiments/regression_test/MAEDA-DATE26/baselines/Adaptive-RAG/README_MAEDA.md for details on using this model with the MAEDA benchmark.