---
base_model:
- Qwen/Qwen2-VL-7B-Instruct
datasets:
- JosephZ/vg150_train_sgg_prompt
library_name: transformers
license: apache-2.0
metrics:
- recall
tags:
- image
- scene-graph
- scene-graph-generation
pipeline_tag: image-text-to-text
---
# Model Description
<!-- Provide a quick summary of what the model is/does. -->
An end-to-end multimodal LLM for Scene Graph Generation (SGG), introduced in [Compile Scene Graphs with Reinforcement Learning](https://huggingface.co/papers/2504.13617).
# R1-SGG: Compile Scene Graphs with Reinforcement Learning
## **Structured Visual Reasoning with Multimodal LLMs and Reinforcement Learning**
[Paper](https://arxiv.org/abs/2504.13617) · [License](LICENSE) · [Demo](https://huggingface.co/spaces/JosephZ/R1-SGG)
---
## 🚀 Update
- ✅ [R1-SGG-7B](https://huggingface.co/JosephZ/R1-SGG-7B), [R1-SGG-Zero-7B](https://huggingface.co/JosephZ/R1-SGG-Zero-7B)
- ✅ Support [PSG](https://github.com/Jingkang50/OpenPSG) dataset (bbox format only, not Panoptic)
- ✅ Updated loss implementation
- ✅ Always use `custom_per_device_train_batch_size` instead of `per_device_train_batch_size` for faster sampling under gradient accumulation
- ⚠️ Current loss implementation might still be affected by gradient accumulation: [trl issue #3021](https://github.com/huggingface/trl/issues/3021)
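The gradient-accumulation caveat above comes down to how per-token losses are averaged: averaging within each micro-batch and then across micro-batches is not the same as averaging over all tokens at once when micro-batches contain different numbers of tokens. A minimal numeric sketch (illustrative only, not the trainer's actual code):

```python
# Two micro-batches with different token counts.
losses_a = [2.0, 2.0, 2.0]  # micro-batch A: 3 tokens
losses_b = [1.0]            # micro-batch B: 1 token

# Mean-of-means: what naive gradient accumulation effectively computes.
mean_of_means = (sum(losses_a) / len(losses_a) + sum(losses_b) / len(losses_b)) / 2

# Global token-weighted mean: what a single large batch would compute.
all_losses = losses_a + losses_b
global_mean = sum(all_losses) / len(all_losses)

print(mean_of_means)  # 1.5
print(global_mean)    # 1.75
```

The two values only coincide when every micro-batch holds the same number of tokens, which is rarely true for variable-length generations.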
---
## 🛠️ Setup Environment
```bash
bash install.sh
```
Main dependencies:
```text
- torch == 2.5.0 or 2.5.1 (cu124, optional)
- transformers (supports Qwen2VL, Qwen2.5VL)
- trl
- vLLM
```
---
## 📚 Dataset
Load preprocessed datasets via:
```python
from datasets import load_dataset
db_train = load_dataset("JosephZ/vg150_train_sgg_prompt")["train"]
db_val = load_dataset("JosephZ/vg150_val_sgg_prompt")["train"]  # the val set is published under its repo's "train" split
```
or for PSG:
```python
db_train = load_dataset("JosephZ/psg_train_sg")["train"] # keys: image_id, image, objects, relationships
db_val = load_dataset("JosephZ/psg_test_sg")["train"]
```
We transformed VG150 into HuggingFace Datasets format with keys:
- `image_id`
- `image`
- `prompt_open`
- `prompt_close`
- `objects`
- `relationships`
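As a sketch of how a record might be consumed, the snippet below resolves hypothetical `objects` and `relationships` entries into subject-predicate-object triplets. The field schema here (`id`, `category`, `bbox`, integer subject/object references) is an assumption for illustration; inspect a real sample before relying on it:

```python
# Hypothetical record shape; check an actual dataset sample first.
record = {
    "image_id": "2345678",
    "objects": [
        {"id": 0, "category": "man", "bbox": [10, 20, 110, 220]},
        {"id": 1, "category": "horse", "bbox": [100, 80, 300, 260]},
    ],
    "relationships": [
        {"subject": 0, "predicate": "riding", "object": 1},
    ],
}

# Map object ids to category names, then resolve each relationship.
id2cat = {obj["id"]: obj["category"] for obj in record["objects"]}
triplets = [
    (id2cat[rel["subject"]], rel["predicate"], id2cat[rel["object"]])
    for rel in record["relationships"]
]
print(triplets)  # [('man', 'riding', 'horse')]
```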
---
## 🔥 Supported Models
- [x] Qwen/Qwen2-VL-2B-Instruct
- [x] Qwen/Qwen2-VL-7B-Instruct
- [x] Qwen/Qwen2.5-VL-3B-Instruct
- [x] Qwen/Qwen2.5-VL-7B-Instruct
---
## 🏋️‍♂️ Training
### Training with Supervised Fine-Tuning (SFT)
For **SLURM users**:
```bash
sbatch scripts/sft/7B_sgg.sh
```
For **local machines**:
```bash
bash scripts/sft_local/7B_sgg.sh
```
⏱️ Approximate training time:
- 2B models: ~4 hours (4×A100 SXM4 GPUs)
- 7B models: ~10 hours (4×A100 SXM4 GPUs)
---
### Training with Reinforcement Learning (GRPO)
**Update (11/05/2025)**: to use "Hard Recall", pass:
```bash
--reward_funcs format_reward edge_hard_reward
```
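As an illustration of what a format reward might check (a hypothetical sketch, not the repository's actual `format_reward` implementation), the function below gives reward 1.0 when a completion parses as JSON with `objects` and `relationships` keys, and 0.0 otherwise:

```python
import json

def format_reward_sketch(completion: str) -> float:
    """Hypothetical format check: reward well-formed scene-graph JSON."""
    try:
        graph = json.loads(completion)
    except json.JSONDecodeError:
        return 0.0
    if isinstance(graph, dict) and "objects" in graph and "relationships" in graph:
        return 1.0
    return 0.0

print(format_reward_sketch('{"objects": [], "relationships": []}'))  # 1.0
print(format_reward_sketch("not json"))                              # 0.0
```

A binary format reward like this is typically combined with a content reward (here, `edge_hard_reward`) so the policy is first pushed toward parseable outputs and then toward correct graphs.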
For **A100 GPUs**:
```bash
sbatch scripts/grpo/train_a100_2B.sh
```
(12 hours on 16×A100 GPUs)
For **GH200 GPUs**:
```bash
sbatch scripts/grpo/train_gh200.sh
```
(16 hours on 16×GH200 GPUs)
For clusters with many RTX_3090/4090 GPUs:
```bash
sbatch scripts/grpo/train_fused.sh
```
- Training 7B models on 24GB cards is possible with Zero3, but slow due to communication bottlenecks.
- (Fun fact: training with 120×RTX_4090 is crazy but severely limited by communication latency.)
💡 **Recommended learning rate**: `6e-7`.
---
## 🧪 Inference and Evaluation
### Inference with SFT-trained models:
```bash
bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR
```
For models trained **with predefined categories**, add `true`:
```bash
bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR true
```
### Inference with GRPO-trained models:
```bash
# fourth arg: predefined categories (true/false, as above); fifth arg: GRPO-trained model
bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR false/true true
```
### Evaluation:
```bash
DATASET_TYPE=vg # or psg
python src/sgg_gather_preds.py $DATASET_TYPE $OUTPUT_DIR sgg_pred_results.json
python src/vg150_eval.py $DATASET sgg_pred_results.json
```
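The `recall` metric listed in the model card metadata is conventionally Recall@K over predicted triplets. A simplified sketch of the idea (ignoring the bounding-box IoU matching that a full SGG evaluation such as `vg150_eval.py` also requires):

```python
def recall_at_k(pred_triplets, gt_triplets, k=50):
    """Fraction of ground-truth triplets recovered in the top-k predictions.

    pred_triplets: list of (subject, predicate, object), ranked by score.
    gt_triplets:   list of ground-truth (subject, predicate, object).
    Real SGG evaluation also matches boxes by IoU; omitted here.
    """
    top_k = set(pred_triplets[:k])
    hits = sum(1 for t in gt_triplets if t in top_k)
    return hits / max(len(gt_triplets), 1)

preds = [("man", "riding", "horse"), ("man", "wearing", "hat")]
gts = [("man", "riding", "horse"), ("horse", "on", "grass")]
print(recall_at_k(preds, gts, k=50))  # 0.5
```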
---
## 🤝 Acknowledgement
The `GRPOTrainer` used in this project is based on [trl's GRPOTrainer](https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py), extended to support multimodal inputs.
---
## 📖 Citation
If you find this work helpful, please cite:
```bibtex
@article{chen2025compile,
title={Compile Scene Graphs with Reinforcement Learning},
author={Chen, Zuyao and Wu, Jinlin and Lei, Zhen and Pollefeys, Marc and Chen, Chang Wen},
journal={arXiv preprint arXiv:2504.13617},
year={2025}
}
```
---
# ✨ Happy Compiling!