Qwen3-VL-8B-Instruct-MRPO

MRPO is a novel reinforcement learning framework that improves medical multimodal reasoning by directly addressing failures in the reasoning process. It reshapes GRPO-style advantages using both answer-level and step-wise process rewards, assigning exponentially larger penalties to earlier invalid steps when the final answer is incorrect, thereby correcting early-stage failures before they cascade while preserving successful trajectories. By redistributing the learning signal according to where reasoning first fails, MRPO induces transferable reasoning that improves both reasoning quality and final answer accuracy across diverse medical VQA benchmarks.

Code: github

Project Page: page

Paper: Breaking Failure Cascades: Step-Aware Reinforcement Learning for Medical Multimodal Reasoning

Quick Start

from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from PIL import Image
import torch

# Load the model (MRPO Qwen3-VL checkpoint; or a local trained checkpoint path)
model_path = "dmis-lab/Qwen3-VL-8B-Instruct-MRPO"
model = Qwen3VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_path)

# Example usage (no system prompt; Qwen3 uses <thinking> tags for reasoning)
image_path = "path/to/medical/image.jpg"
question = "What can you see in this medical image?"

question_text = (
    f"{question} Think step-by-step and enclose your reasoning in "
    "<thinking>...</thinking> tags. Then provide your answer in <answer>...</answer> tags."
)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image_path},
            {"type": "text", "text": question_text},
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
inputs = processor(
    text=[text],
    images=[Image.open(image_path)],
    padding=True,
    padding_side="left",
    return_tensors="pt",
)
inputs = inputs.to(model.device)

# Inference (greedy decoding, matching inference.py)
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=512, do_sample=False)
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)

Citation

@misc{jung2026breakingfailurecascadesstepaware,
      title={Breaking Failure Cascades: Step-Aware Reinforcement Learning for Medical Multimodal Reasoning}, 
      author={Junha Jung and Minbyul Jeong and Suhyeon Lim and Sungwook Jung and Jaehoon Yun and Taeyun Roh and Mujeen Sung and Jaewoo Kang},
      year={2026},
      eprint={2606.31825},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2606.31825}, 
}

License

This model is released under the Apache 2.0 license.

Downloads last month
-
Safetensors
Model size
770k params
Tensor type
BF16
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for dmis-lab/Qwen3-VL-8B-Instruct-MRPO

Finetuned
(342)
this model

Collection including dmis-lab/Qwen3-VL-8B-Instruct-MRPO

Paper for dmis-lab/Qwen3-VL-8B-Instruct-MRPO