---
base_model:
- Qwen/Qwen2-VL-7B-Instruct
datasets:
- JosephZ/vg150_train_sgg_prompt
library_name: transformers
license: apache-2.0
metrics:
- recall
tags:
- image
- scene-graph
- scene-graph-generation
pipeline_tag: image-text-to-text
---

# Model Description

An end-to-end multimodal LLM for Scene Graph Generation (SGG), introduced in [Compile Scene Graphs with Reinforcement Learning](https://huggingface.co/papers/2504.13617).

# R1-SGG: Compile Scene Graphs with Reinforcement Learning

## **Structured Visual Reasoning with Multimodal LLMs and Reinforcement Learning**  
[![Paper](https://img.shields.io/badge/arXiv-2504.13617-b31b1b.svg)](https://arxiv.org/abs/2504.13617)  [![License](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](LICENSE) [![Hugging Face](https://img.shields.io/badge/HuggingFace-Demo-orange?logo=huggingface)](https://huggingface.co/spaces/JosephZ/R1-SGG) 
---

## 🚀 Update
- ✅ ![Hugging Face](https://img.shields.io/badge/HuggingFace-Model-orange?logo=huggingface) Released models: [R1-SGG-7B](https://huggingface.co/JosephZ/R1-SGG-7B), [R1-SGG-Zero-7B](https://huggingface.co/JosephZ/R1-SGG-Zero-7B)
- ✅ Support [PSG](https://github.com/Jingkang50/OpenPSG) dataset (bbox format only, not Panoptic)
- ✅ Updated loss implementation
- ✅ Always use `custom_per_device_train_batch_size` instead of `per_device_train_batch_size` for faster sampling under gradient accumulation
- ⚠️ Current loss implementation might still be affected by gradient accumulation: [trl issue #3021](https://github.com/huggingface/trl/issues/3021)
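
The gradient-accumulation caveat above comes down to a simple averaging mismatch. As an illustration only (not the repo's actual loss code): averaging per-microbatch token means is not the same as taking one token-level mean over the whole accumulated batch when microbatches contain different numbers of tokens.

```python
# Illustrative sketch of the gradient-accumulation averaging mismatch.
# Not the repo's loss implementation.

def per_microbatch_mean(losses_per_batch):
    # Each microbatch contributes equally, regardless of its token count.
    means = [sum(toks) / len(toks) for toks in losses_per_batch]
    return sum(means) / len(means)

def global_token_mean(losses_per_batch):
    # Every token contributes equally across the whole accumulated batch.
    all_toks = [t for toks in losses_per_batch for t in toks]
    return sum(all_toks) / len(all_toks)

# Two microbatches: 2 tokens vs 6 tokens, with different per-token losses.
batches = [[1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
print(per_microbatch_mean(batches))  # 0.5
print(global_token_mean(batches))    # 0.25
```

With equal-length microbatches the two quantities coincide; with variable-length generations (the usual case for LLMs) they diverge, which is the heart of the trl issue linked above.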

---

## 🛠️ Setup Environment
```bash
bash install.sh
```
Main dependencies:
- `torch==2.5.0` or `2.5.1` (cu124 build, optional)
- `transformers` (with Qwen2-VL / Qwen2.5-VL support)
- `trl`
- `vLLM`

---

## 📚 Dataset
Load preprocessed datasets via:
```python
from datasets import load_dataset

db_train = load_dataset("JosephZ/vg150_train_sgg_prompt")["train"]
db_val = load_dataset("JosephZ/vg150_val_sgg_prompt")["train"]
```
or for PSG:
```python
db_train = load_dataset("JosephZ/psg_train_sg")["train"]  # keys: image_id, image, objects, relationships
db_val = load_dataset("JosephZ/psg_test_sg")["train"]
```
We transformed VG150 into HuggingFace Datasets format with keys:
- `image_id`
- `image`
- `prompt_open`
- `prompt_close`
- `objects`
- `relationships`

---

## 🔥 Supported Models
- [x] Qwen/Qwen2-VL-2B-Instruct
- [x] Qwen/Qwen2-VL-7B-Instruct
- [x] Qwen/Qwen2.5-VL-3B-Instruct
- [x] Qwen/Qwen2.5-VL-7B-Instruct

---

## 🏋️‍♂️ Training

### Training with Supervised Fine-Tuning (SFT)

For **SLURM users**:
```bash
sbatch scripts/sft/7B_sgg.sh 
```

For **local machines**:
```bash
bash scripts/sft_local/7B_sgg.sh
```
⏱️ Approximate training time:
- 2B models: ~4 hours (4×A100 SXM4 GPUs)
- 7B models: ~10 hours (4×A100 SXM4 GPUs)

---

### Training with Reinforcement Learning (GRPO)
**Update (11/05/2025)**: to use "Hard Recall", pass:
```bash
--reward_funcs format_reward edge_hard_reward
```
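
The core idea behind an edge-level "hard" reward can be sketched as exact-match recall over triples. This is an illustrative reimplementation, not the repo's `edge_hard_reward` function:

```python
def edge_hard_recall(pred_triples, gt_triples):
    """Fraction of ground-truth (subject, predicate, object) triples that
    appear verbatim in the prediction -- a 'hard' (exact-match) recall."""
    if not gt_triples:
        return 0.0
    gt = set(gt_triples)
    return len(gt & set(pred_triples)) / len(gt)

pred = [("man", "riding", "horse"), ("man", "wearing", "hat")]
gt = [("man", "riding", "horse"), ("horse", "on", "grass")]
print(edge_hard_recall(pred, gt))  # 0.5
```

A "soft" variant would instead give partial credit for, e.g., synonymous labels or overlapping boxes; the hard version rewards only exact matches.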

For **A100 GPUs**:
```bash
sbatch scripts/grpo/train_a100_2B.sh
```
(12 hours on 16×A100 GPUs)

For **GH200 GPUs**:
```bash
sbatch scripts/grpo/train_gh200.sh
```
(16 hours on 16×GH200 GPUs)

For clusters with many RTX_3090/4090 GPUs:
```bash
sbatch scripts/grpo/train_fused.sh
```
- Training 7B models on 24GB cards is possible with Zero3, but slow due to communication bottlenecks.
- (Fun fact: training with 120×RTX_4090 is crazy but severely limited by communication latency.)

💡 **Recommended learning rate**: `6e-7`.
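
For intuition, GRPO turns the rewards of a group of completions sampled for the same prompt into group-relative advantages. A minimal sketch of one common variant (population std; trl's implementation differs in details such as the std estimator and an epsilon term):

```python
import statistics

def group_relative_advantages(rewards):
    """GRPO-style advantages: normalize each completion's reward by the
    mean and std of its own group (completions for the same prompt)."""
    mu = statistics.mean(rewards)
    sigma = statistics.pstdev(rewards)
    if sigma == 0:
        # All completions scored identically: no learning signal.
        return [0.0 for _ in rewards]
    return [(r - mu) / sigma for r in rewards]

# Four completions for one prompt, with scalar rewards.
print(group_relative_advantages([1.0, 0.0, 0.5, 0.5]))
```

Because advantages are computed within the group, GRPO needs no learned value function; completions are only compared against their siblings.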

---

## 🧪 Inference and Evaluation

### Inference with SFT-trained models:
```bash
bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR
```
For models trained **with predefined categories**, add `true`:
```bash
bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR true
```

### Inference with GRPO-trained models:
For GRPO models, keep the category flag (`false` or `true`, as above) and append a final `true`:
```bash
bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR false/true true
```

### Evaluation:
```bash
DATASET_TYPE=vg # or psg
python src/sgg_gather_preds.py $DATASET_TYPE $OUTPUT_DIR sgg_pred_results.json
python src/vg150_eval.py $DATASET sgg_pred_results.json
```
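
The evaluation script computes recall; the core of SGG-style Recall@K can be sketched as follows (illustrative only; `src/vg150_eval.py` implements the real metric with box matching etc.):

```python
def recall_at_k(pred_scored_triples, gt_triples, k=50):
    """Recall@K sketch: keep the K highest-scoring predicted triples and
    measure what fraction of ground-truth triples they cover."""
    if not gt_triples:
        return 0.0
    top_k = {t for t, _ in sorted(pred_scored_triples, key=lambda x: -x[1])[:k]}
    return len(top_k & set(gt_triples)) / len(gt_triples)

preds = [(("man", "riding", "horse"), 0.9), (("man", "on", "horse"), 0.4)]
gt = [("man", "riding", "horse")]
print(recall_at_k(preds, gt, k=1))  # 1.0
```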

---

## 🤝 Acknowledgement
The `GRPOTrainer` used in this project is based on [trl's GRPOTrainer](https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py), extended to support multimodal inputs.

---

## 📖 Citation
If you find this work helpful, please cite:
```bibtex
@article{chen2025compile,
  title={Compile Scene Graphs with Reinforcement Learning},
  author={Chen, Zuyao and Wu, Jinlin and Lei, Zhen and Pollefeys, Marc and Chen, Chang Wen},
  journal={arXiv preprint arXiv:2504.13617},
  year={2025}
}
```

---

# ✨ Happy Compiling!