---
license: apache-2.0
language:
- en
- zh
- de
- ru
base_model:
- meta-llama/Llama-3.1-8B
tags:
- translation
- reasoning
- test-time
---

Reward model for Plan2Align, used for test-time preference alignment in translation on the `zh->en`, `zh->de`, and `zh->ru` language pairs.

```bibtex
@article{wang2025plan2align,
  title={Plan2Align: Predictive Planning Based Test-Time Preference Alignment in Paragraph-Level Machine Translation},
  author={Wang, Kuang-Da and Chen, Teng-Ruei and Hung, Yu Heng and Ding, Shuoyang and Wu, Yueh-Hua and Wang, Yu-Chiang Frank and Yang, Chao-Han Huck and Peng, Wen-Chih and Hsieh, Ping-Chun},
  journal={arXiv preprint arXiv:2502.20795},
  year={2025}
}
```

## Using the Reward Model

```python
import torch
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from safetensors.torch import load_file

# Load the reward model (a causal LM with a value head on top).
RM = AutoModelForCausalLMWithValueHead.from_pretrained(
    "ray24724919/plan2align_rm",
    torch_dtype=torch.bfloat16,  # choose a dtype suited to your hardware
)
RM.eval()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
RM.gradient_checkpointing_enable()  # optional: enable if memory is tight

# Load the trained value-head weights and strip the "v_head." prefix
# so the keys match the value head's own state dict.
value_head_weights = load_file("path-to-valuehead-safetensors")
new_state_dict = {
    key.replace("v_head.", "") if key.startswith("v_head.") else key: value
    for key, value in value_head_weights.items()
}
RM.v_head.load_state_dict(new_state_dict)
```

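The key-renaming step above can be sketched in isolation. The toy `value_head_weights` dict below is a hypothetical stand-in for the real safetensors contents, just to show how the `v_head.` prefix is stripped:

```python
# Hypothetical stand-in for the loaded safetensors state dict.
value_head_weights = {
    "v_head.summary.weight": [0.1, 0.2],
    "v_head.summary.bias": [0.0],
    "other.weight": [1.0],  # keys without the prefix pass through unchanged
}

# Strip the "v_head." prefix so keys match the value head's own state dict.
new_state_dict = {
    key.replace("v_head.", "") if key.startswith("v_head.") else key: value
    for key, value in value_head_weights.items()
}

print(sorted(new_state_dict))
```

After renaming, `v_head.summary.weight` becomes `summary.weight`, which is the key name `RM.v_head.load_state_dict` expects.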
## Reward Function

```python
def reward(language, text, response, device="cuda:0"):
    message = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {language}, Chinese:\n{text}\n### {language}:"},
        {"role": "assistant", "content": response},
    ]
    tokenized_inputs = tokenizer.apply_chat_template(
        message,
        add_generation_prompt=False,
        return_tensors="pt",
    ).to(device)

    inputs = {
        "input_ids": tokenized_inputs,
        "attention_mask": torch.ones_like(tokenized_inputs, device=device),
    }

    with torch.no_grad():
        # return_value=True exposes the value head's outputs;
        # the third element of the output tuple holds the values.
        outputs = RM(**inputs, return_value=True)
        rewards = outputs[2]

    # The value at the final token position is the scalar reward.
    final_reward = rewards[:, -1].item()

    return final_reward
```

## System Prompt for Translation Reward Modeling

```python
messages = [
    {"role": "system", "content": "You are a helpful translator and only output the result."},
    {"role": "user", "content": f"### Translate this from Chinese to {language}, Chinese:\n{source}\n### {language}:"},
    {"role": "assistant", "content": translation},
]
```
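As a quick sanity check, the prompt construction can be exercised without loading any model. The `build_messages` helper below is not part of this repository, just an illustrative wrapper around the message template above, with sample inputs:

```python
def build_messages(language, source, translation):
    """Assemble the chat messages used for reward scoring (hypothetical helper)."""
    return [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {language}, Chinese:\n{source}\n### {language}:"},
        {"role": "assistant", "content": translation},
    ]

# Sample values for illustration only.
messages = build_messages("English", "你好，世界", "Hello, world")
print(messages[1]["content"])
```

The resulting list is what `tokenizer.apply_chat_template` consumes in the reward function, with the candidate translation placed in the assistant turn.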