Carol0110 committed on
Commit
bd09f55
·
verified ·
1 Parent(s): e741b9b

Upload 3 files

Browse files
Files changed (2) hide show
  1. README_zh.md +29 -0
  2. unirm.py +321 -0
README_zh.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # UniRM:用于多模态内容审核的多头标量奖励模型
2
+
3
+ **UniRM** 是一个用于多模态内容审核的**多头标量奖励模型**,能够提供**可解释的、属性级别的评分信号**。
4
+ 该模型被设计用于支持 **UniMod** 中的开放式推理策略优化,尤其适用于**缺乏确定性标签的响应生成阶段**。
5
+
6
+ UniRM 将奖励信号解耦为多个维度,使模型能够区分**表达质量**与**安全边界**(隐私、偏见、有害性、合法性),从而实现更透明的诊断与更稳定的训练。
7
+
8
+ ---
9
+
10
+ ## 演示视频
11
+
12
+ > UniRM 演示视频:
13
+
14
+ <video controls preload="metadata" style="width:100%; max-width:900px; border-radius:12px;">
15
+ <source src="static/videos/unirm.mp4" type="video/mp4">
16
+ </video>
17
+
18
+ ---
19
+
20
+ ## 快速开始(Gradio)
21
+
22
+ 下面给出一个最小化的 Gradio 示例,用于加载 **UniRM**,并对 *(输入指令、模型回复、可选图像)* 进行**多头奖励评分**。
23
+
24
+ ```bash
25
+ git clone https://github.com/TideDra/lmm-r1.git
26
+ cd lmm-r1
27
+ pip install -e .[vllm]
28
+ pip install flash_attn --no-build-isolation
29
+ python unirm.py --model_path {PATH_TO_UNIRM}
unirm.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import argparse
4
+ import gradio as gr
5
+ from PIL import Image
6
+ import torch
7
+ import torch.nn as nn
8
+ import json, os
9
+ from typing import Optional, Dict
10
+ from openrlhf.models.lmm_kits.qwen2_5_vl.patch import Qwen2_5_VLPatch
11
+ from openrlhf.models.lmm_kits.base.data_processor import MMInputs
12
+ from openrlhf.models.lmm_kits.qwen2_5_vl.data_processor import Qwen2_5_VLDataProcessor
13
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
14
+
15
+
16
class UniRM(nn.Module):
    """Multi-head scalar reward model on top of a Qwen2.5-VL backbone.

    Each head scores one moderation attribute (e.g. style, privacy, bias,
    toxicity, legality) from the hidden state of the last attended token,
    yielding interpretable per-attribute reward signals.
    """

    def __init__(self, base_model, head_names, head_activation="sigmoid"):
        """Wrap *base_model* and attach one scalar value head per name.

        Args:
            base_model: causal VLM exposing ``config``, ``parameters()``,
                ``get_inputs_embeds`` / ``get_position_ids`` and an inner
                ``.model`` transformer (patched Qwen2.5-VL in this project).
            head_names: iterable of head identifiers; one reward head each.
            head_activation: "sigmoid" | "tanh" | "relu", applied to every
                head's scalar output.

        Raises:
            ValueError: if *head_activation* is not one of the above.
        """
        super().__init__()
        self.config = base_model.config
        self.base_model = base_model
        # The LM head is unused for reward scoring; drop it to save memory.
        if hasattr(base_model, "lm_head"):
            try:
                del base_model.lm_head
                print("[UniRM] Removed lm_head from base_model.")
            except Exception:
                # Some architectures forbid attribute deletion; freeze the
                # weights and null the reference instead.
                for p in base_model.lm_head.parameters():
                    p.requires_grad = False
                self.base_model.lm_head = None
                print("[UniRM] Froze lm_head parameters instead of deletion.")

        if hasattr(base_model, "model") and hasattr(base_model.model, "lm_head"):
            del base_model.model.lm_head
            print("[UniRM] Removed nested lm_head from base_model.model.")

        # Record the head layout on the config so checkpoints self-describe.
        self.config.mgrm_heads = head_names
        self.config.mgrm_head_activation = head_activation
        self.config.model_type = getattr(self.config, "model_type", "mgrm_vlm")

        hidden_size = base_model.config.hidden_size
        # Match the backbone's parameter dtype so no per-forward casts occur.
        dtype = next(base_model.parameters()).dtype

        if head_activation == "sigmoid":
            activation = nn.Sigmoid()
        elif head_activation == "tanh":
            activation = nn.Tanh()
        elif head_activation == "relu":
            activation = nn.ReLU()
        else:
            raise ValueError(f"Unsupported activation type: {head_activation}")

        self.value_heads = nn.ModuleDict({
            name: nn.Sequential(
                nn.Linear(hidden_size, 1, bias=False, dtype=dtype),
                activation
            )
            for name in head_names
        })

        print(f"[UniRM] ✅ Initialized Multi-Head Reward Model with heads: {head_names}")
        print(f"[UniRM] 🔧 Activation: {head_activation} | Hidden size: {hidden_size}")

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        # Quoted so the annotation is lazy and the module imports even when
        # the project-local MMInputs type is unavailable.
        visual_inputs: Optional["MMInputs"] = None,
        return_output=False,
    ) -> Dict[str, torch.Tensor]:
        """Score a batch; returns ``{head_name: (batch,) reward tensor}``.

        When *return_output* is True, returns ``(rewards, backbone_outputs)``.
        ``attention_mask`` is required to locate the last non-padding token.
        """
        if visual_inputs is None:
            # Text-only batch: stand-in object with empty kwargs dicts.
            class _Empty:
                emb_inputs = {}
                forward_inputs = {}
            visual_inputs = _Empty()

        inputs_embeds = self.base_model.get_inputs_embeds(input_ids, **visual_inputs.emb_inputs)
        position_ids = self.base_model.get_position_ids(input_ids, attention_mask=attention_mask, **visual_inputs.emb_inputs)
        outputs = self.base_model.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
            use_cache=False,
            **visual_inputs.forward_inputs,
        )

        hidden = outputs["hidden_states"][-1]
        # Index of the last attended (mask == 1) token in each sequence:
        # flip the mask, argmax finds the first 1 from the right.
        eos_idx = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1, keepdim=True)
        eos_hidden = hidden.gather(
            dim=1,
            index=eos_idx.unsqueeze(-1).expand(-1, -1, hidden.size(-1))
        ).squeeze(1)

        rewards = {name: head(eos_hidden).squeeze(-1) for name, head in self.value_heads.items()}
        return (rewards, outputs) if return_output else rewards

    def save_pretrained(self, save_directory, **kwargs):
        """Save backbone, value heads and a self-describing config.json.

        Bug fix: the original referenced ``self._base_model`` which is never
        set (the attribute is ``base_model``), so every save raised
        AttributeError.
        """
        os.makedirs(save_directory, exist_ok=True)

        base_model_dir = os.path.join(save_directory, "base_model")
        os.makedirs(base_model_dir, exist_ok=True)

        if hasattr(self.base_model, "save_pretrained"):
            self.base_model.save_pretrained(base_model_dir, **kwargs)
        else:
            torch.save(self.base_model.state_dict(), os.path.join(base_model_dir, "base_model.pt"))

        value_head_path = os.path.join(save_directory, "value_heads.pt")
        torch.save({k: v.cpu() for k, v in self.value_heads.state_dict().items()}, value_head_path)

        cfg = self.config.to_dict() if hasattr(self.config, "to_dict") else dict(self.config)
        cfg.update({
            "model_type": "mgrm_vlm",
            "head_names": list(self.value_heads.keys()),
            "attn_implementation": "eager"
        })
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump(cfg, f, indent=2)

        print(f"✅ UniRM saved to {save_directory}")

    @staticmethod
    def is_backend_compatible() -> bool:
        """Hook queried by the training backend; always compatible."""
        return True

    @classmethod
    def from_pretrained(cls, load_directory, torch_dtype=torch.bfloat16):
        """Restore a UniRM saved by :meth:`save_pretrained`.

        Loads the backbone from ``<dir>/base_model`` (or the path recorded in
        config.json) and the head weights from ``<dir>/value_heads.pt``.

        Fixes vs. original: the target device is no longer hard-coded to
        "cuda" (falls back to CPU when CUDA is unavailable), and the config
        file handle is closed deterministically.
        """
        from openrlhf.models.lmm_kits.qwen2_5_vl.patch import Qwen2_5_VLPatch
        Qwen2_5_VLPatch._load_all_patches()

        cfg_path = os.path.join(load_directory, "config.json")
        if os.path.exists(cfg_path):
            with open(cfg_path) as f:
                mgrm_cfg = json.load(f)
        else:
            mgrm_cfg = {}

        base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            os.path.join(load_directory, mgrm_cfg.get("base_model_path", "base_model")),
            torch_dtype=torch_dtype,
            attn_implementation="eager",
        )

        head_names = mgrm_cfg.get("mgrm_heads", mgrm_cfg.get("head_names", []))
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = cls(base_model, head_names).to(device)

        vh_path = os.path.join(load_directory, "value_heads.pt")
        if os.path.exists(vh_path):
            print(f"💡 Loading structured value_heads from {vh_path} ...")
            vh_state = torch.load(vh_path, map_location=device)
            for name, head in model.value_heads.items():
                if name in vh_state:
                    try:
                        head.load_state_dict(vh_state[name], strict=False)
                        print(f"  ✅ Loaded head: {name}")
                    except Exception as e:
                        print(f"  ⚠️ Failed to load head {name}: {e}")
                else:
                    print(f"  ⚠️ Missing head in saved file: {name}")
        else:
            print("[Info] No value_heads.pt found.")

        print(f"✅ UniRM fully loaded from {load_directory}")
        return model

    @classmethod
    def _from_config(cls, config, **kwargs):
        """Build an uninitialized UniRM from a config alone (no weights)."""
        from transformers import Qwen2_5_VLForConditionalGeneration
        config._attn_implementation_internal = "eager"
        base_model = Qwen2_5_VLForConditionalGeneration(config)
        head_names = getattr(config, "mgrm_heads", [])
        return cls(base_model, head_names)
170
+
171
+
172
# Prompt template sent to the reward model. The scoring {dimension}, the user
# {prompt} and the candidate {response} are interpolated via str.format()
# before tokenization; .strip() removes the blank first/last lines of the
# triple-quoted literal. Do not rename the placeholders — callers pass them
# as keyword arguments.
EVALUATION_PROMPT_TEMPLATE = """
You are an expert evaluator for multimodal generation models.
Your task is to assess the quality of the response based on the given prompt.

### Dimension:
{dimension}

### Prompt:
{prompt}

### Response:
{response}
""".strip()
185
+
186
+
187
+
188
class UniRMProxy:
    """Loads a UniRM checkpoint plus its processor and scores
    (prompt, response, optional image) triples for the Gradio demo.
    """

    def __init__(self, args):
        """Load model, tokenizer and data processor from ``args.model_path``.

        Falls back to wrapping a raw base checkpoint with freshly
        initialized heads when the structured UniRM load fails; the failure
        is now logged instead of being swallowed silently.
        """
        Qwen2_5_VLPatch._load_all_patches()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if args.bf16 else torch.float32
        self.head_names = args.head_names.split(",")

        try:
            self.model = UniRM.from_pretrained(
                args.model_path,
                torch_dtype=self.dtype,
            )
        except Exception as e:
            # Surface the original failure so the fallback path is visible.
            print(f"[UniRMProxy] ⚠️ UniRM.from_pretrained failed ({e}); "
                  "wrapping base checkpoint with fresh heads.")
            base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                args.model_path,
                torch_dtype=self.dtype,
                device_map="auto",
            )
            self.model = UniRM(base_model, self.head_names)

        self.model = self.model.to(self.device).eval()
        # Keep the heads in the backbone's dtype to avoid per-forward casts.
        self.model.value_heads = self.model.value_heads.to(
            next(self.model.base_model.parameters()).dtype
        )

        self.tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
        base_processor = AutoProcessor.from_pretrained(
            os.path.join(args.model_path, "base_model"),
            trust_remote_code=True,
        )
        self.processor = Qwen2_5_VLDataProcessor(
            processor=base_processor,
            processor_kwargs=args.processor_kwargs,
        )

        self.max_length = args.max_len
        self.batch_size = args.batch_size

    @torch.no_grad()
    def score(self, prompt, response, image, dimension):
        """Score one (prompt, response[, image]) triple on every head.

        Returns ``(pretty_text, rows)`` where rows is ``[[head, score], ...]``
        for the Gradio dataframe.

        Bug fix: PIL images were saved to ``NamedTemporaryFile(delete=False)``
        and never removed, leaking one temp PNG per scoring call; the file is
        now deleted once the processor has consumed it.
        """
        if not prompt or not response:
            return "❌ Prompt and Response are required.", []

        messages = [{
            "role": "user",
            "content": []
        }]

        tmp_path = None
        if image is not None:
            if isinstance(image, Image.Image):
                tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
                tmp.close()  # close the handle before PIL writes (Windows-safe)
                image.save(tmp.name)
                tmp_path = tmp.name
                img_payload = tmp_path
            else:
                img_payload = image
            messages[0]["content"].append({"type": "image", "image": img_payload})
        text = EVALUATION_PROMPT_TEMPLATE.format(
            dimension=dimension,
            prompt=prompt,
            response=response,
        )
        messages[0]["content"].append({"type": "text", "text": text})

        try:
            mm_inputs = self.processor(
                [json.dumps(messages, indent=2)],
                max_length=self.max_length,
                padding=True,
                device=self.device,
                return_tensors="pt",
            )

            outputs, _ = self.model(
                input_ids=mm_inputs.extra_info["input_ids"],
                attention_mask=mm_inputs.extra_info["attention_mask"],
                visual_inputs=mm_inputs,
                return_output=True,
            )
        finally:
            # The processor has loaded the image by now; reclaim the temp file.
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

        rows = []
        pretty = []
        for h in self.head_names:
            v = outputs[h].detach().cpu().float().item()
            rows.append([h, v])
            pretty.append(f"{h}: {v:.6f}")

        return "\n".join(pretty), rows
277
+
278
+
279
def main():
    """Parse CLI flags, build the scoring proxy, and serve the Gradio UI."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--head_names", type=str, default="style,privacy,bias,toxicity,legality")
    parser.add_argument("--max_len", type=int, default=1024)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument(
        "--processor_kwargs",
        type=json.loads,
        default={"min_pixels": 4 * 28 * 28, "max_pixels": 16384 * 28 * 28},
    )
    parser.add_argument("--share", action="store_true")
    cli_args = parser.parse_args()

    # Model + processor are loaded once, up front, and shared by all requests.
    scorer = UniRMProxy(cli_args)

    with gr.Blocks(title="UniRM Scoring") as app:
        gr.Markdown("# 🧠 UniRM – Multimodal Reward Model")

        with gr.Row():
            # Left column: the (prompt, response, dimension, image) inputs.
            with gr.Column():
                prompt_box = gr.Textbox(label="Prompt", lines=4)
                response_box = gr.Textbox(label="Response", lines=6)
                dimension_box = gr.Textbox(label="Dimension", value="general")
                image_box = gr.Image(type="pil", label="Image (optional)")
                score_btn = gr.Button("Score")

            # Right column: per-head scores as text and as a table.
            with gr.Column():
                scores_text = gr.Textbox(label="Scores", lines=6)
                scores_table = gr.Dataframe(headers=["Head", "Score"], interactive=False)

        # Note: score() takes (prompt, response, image, dimension) in that order.
        score_btn.click(
            scorer.score,
            inputs=[prompt_box, response_box, image_box, dimension_box],
            outputs=[scores_text, scores_table],
        )

    app.launch(share=cli_args.share)
318
+
319
+
320
# Script entry point: only launch the demo when run directly, not on import.
if __name__ == "__main__":
    main()