---
library_name: transformers
license: apache-2.0
tags: []
pipeline_tag: audio-text-to-text
---

# R1-AQA --- Reinforcement Learning Outperforms Supervised Fine-Tuning: A Case Study on Audio Question Answering

<!-- Provide a quick summary of what the model is/does. -->

## Introduction

R1-AQA is an audio question answering (AQA) model based on `Qwen2-Audio-7B-Instruct`, optimized through reinforcement learning with the group relative policy optimization (GRPO) algorithm.
This implementation achieves state-of-the-art performance on the MMAU benchmark with only 38k post-training samples.
For more details, please refer to our [GitHub repository](https://github.com/xiaomi-research/r1-aqa) and [Technical Report](https://arxiv.org/abs/2503.11197).

Our main findings are as follows:

- The GRPO algorithm can be applied directly and effectively to the audio modality, even to `Qwen2-Audio-7B-Instruct` with only 8.2B parameters (see the sketch below).
- With only 38k post-training samples, reinforcement learning outperforms supervised fine-tuning, indicating that RL-based approaches can be effective without large datasets.
- An explicit reasoning process has not shown significant benefits for AQA tasks; how to efficiently leverage *deep thinking* or step-by-step reasoning remains an open question for further research.
- Large audio language models (LALMs) still lag far behind humans in auditory-language reasoning, suggesting that RL-based approaches warrant further exploration.

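
To make the first finding concrete, here is a minimal, illustrative sketch of the group-relative advantage computation at the core of GRPO, assuming a simple exact-match accuracy reward over a group of sampled answers. The function names and toy data below are ours and are not taken from the R1-AQA training code; see the technical report for the actual reward design.

```python
from statistics import mean, pstdev

def exact_match_reward(prediction: str, reference: str) -> float:
    """Illustrative reward: 1.0 if the sampled answer matches the reference option, else 0.0."""
    return 1.0 if prediction.strip().lower() == reference.strip().lower() else 0.0

def group_relative_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    """GRPO normalizes each completion's reward by its group's mean and standard deviation,
    so no separate value (critic) model is required."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# Toy example: four completions sampled for one multiple-choice AQA prompt, reference answer "Woman"
completions = ["Woman", "Man", "Woman", "Child"]
rewards = [exact_match_reward(c, "Woman") for c in completions]  # [1.0, 0.0, 1.0, 0.0]
print(group_relative_advantages(rewards))  # correct answers receive positive advantages
```

These normalized advantages then weight the clipped policy-gradient objective during post-training.
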
Additional Notes:

- The AVQA training set originally consists of approximately 40k samples. However, we use only about 38k because some data sources have become invalid. Other datasets built on YouTube sources, such as AudioSet, face a similar issue. We believe the missing ~2k samples do not have a significant impact on the training results.
- The figure of 8.2B parameters is taken from the *Qwen2-Audio Technical Report*.

### Table: Accuracies (%) on the MMAU benchmark

| Model | Method | Sound (Test-mini) | Sound (Test) | Music (Test-mini) | Music (Test) | Speech (Test-mini) | Speech (Test) | Average (Test-mini) | Average (Test) |
|---|---|---|---|---|---|---|---|---|---|
| - | Human\* | 86.31 | - | 78.22 | - | 82.17 | - | 82.23 | - |
| Gemini Pro 2.0 Flash | Direct Inference\* | 56.46 | 61.73 | 58.68 | 56.53 | 51.65 | 61.53 | 55.60 | 59.93 |
| Audio Flamingo 2 | Direct Inference\* | 61.56 | 65.10 | 73.95 | 72.90 | 30.93 | 40.26 | 55.48 | 59.42 |
| GPT4o + Strong Cap. | Direct Inference\* | 57.35 | 55.83 | 49.70 | 51.73 | 64.86 | 68.66 | 57.30 | 58.74 |
| Llama-3-8B-Instruct + Strong Cap. | Direct Inference\* | 50.75 | 49.10 | 48.93 | 48.93 | 55.25 | 62.70 | 52.10 | 53.57 |
| Qwen2-Audio-7B-Instruct | Direct Inference\* | 54.95 | 45.90 | 50.98 | 53.26 | 42.04 | 45.90 | 49.20 | 52.50 |
| SALAMONN | Direct Inference\* | 41.00 | 40.30 | 34.80 | 33.76 | 25.50 | 24.24 | 33.70 | 32.77 |
| Qwen2-Audio-7B-Instruct | CoTA \[1\] | 60.06 | - | 64.30 | - | 60.70 | - | 61.71 | - |
| Qwen2-Audio-7B-Instruct | Zero-Shot-CoT \[2\] | 61.86 | - | 56.29 | - | 55.26 | - | 57.80 | - |
| **Qwen2-Audio-7B-Instruct** | **Ours 1️⃣** | 69.37 | - | 66.77 | - | 57.36 | - | 64.50 | - |
| **Qwen2-Audio-7B-Instruct** | **Ours 2️⃣** | 68.77 | 69.76 | 64.37 | 61.40 | 63.66 | 62.70 | 65.60 | 64.36 |

#### Notes

1️⃣ The original model, identical to the one released on Hugging Face and described in our technical report.
2️⃣ The model submitted to [EvalAI](https://eval.ai/web/challenges/challenge-page/2391/overview) for evaluation, trained multiple times to achieve balanced results. (**The results on the [leaderboard](https://sakshi113.github.io/mmau_homepage/#leaderboard) contain some typographical errors, and we are currently in communication with the organizers to correct them.**)
\* Data sourced from the [MMAU official website](https://sakshi113.github.io/mmau_homepage/).
\[1\] Xie, Zhifei, et al. "Audio-Reasoner: Improving Reasoning Capability in Large Audio Language Models." arXiv preprint arXiv:2503.02318 (2025).
\[2\] Ma, Ziyang, et al. "Audio-CoT: Exploring Chain-of-Thought Reasoning in Large Audio Language Model." arXiv preprint arXiv:2501.07246 (2025).

## Inference

```python
import torch
import torchaudio
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor

# Load model and processor
model_name = "mispeech/r1-aqa"
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

# Load example audio and resample to 16 kHz, the rate expected by the processor
wav_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"  # from the MMAU dataset
waveform, sampling_rate = torchaudio.load(wav_path)
if sampling_rate != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)(waveform)
audios = [waveform[0].numpy()]

# Build the prompt text
question = "Based on the given audio, identify the source of the speaking voice."
options = ["Man", "Woman", "Child", "Robot"]
prompt = f"{question} Please choose the answer from the following options: {str(options)}. Output the final answer in <answer> </answer>."
message = [
    {"role": "user", "content": [
        {"type": "audio", "audio_url": wav_path},
        {"type": "text", "text": prompt}
    ]}
]
texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)

# Run inference and decode only the newly generated tokens
inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_ids = generated_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

print(response)
```

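Because the prompt asks the model to wrap its choice in `<answer> </answer>` tags, the raw response can be post-processed with a small helper. The snippet below is an illustrative sketch, not part of the official inference code, and assumes the model follows the requested output format.

```python
import re

def extract_answer(text: str) -> str:
    """Return the content of the last <answer>...</answer> pair, falling back to the raw text."""
    matches = re.findall(r"<answer>(.*?)</answer>", text, flags=re.DOTALL)
    return matches[-1].strip() if matches else text.strip()

print(extract_answer(response[0]))  # prints the predicted option, e.g. one of the strings in `options`
```
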
## Citation

```bib
@article{li2025reinforcement,
  title={Reinforcement Learning Outperforms Supervised Fine-Tuning: A Case Study on Audio Question Answering},
  author={Li, Gang and Liu, Jizhong and Dinkel, Heinrich and Niu, Yadong and Zhang, Junbo and Luan, Jian},
  journal={arXiv preprint arXiv:2503.11197},
  year={2025},
  url={https://github.com/xiaomi-research/r1-aqa; https://huggingface.co/mispeech/r1-aqa}
}
```