File size: 8,525 Bytes
f666f1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""
Prepare and upload VEFX-Reward model to HuggingFace Hub.

This script:
1. Loads the Qwen3-VL-4B base model
2. Manually merges LoRA weights into the base model
3. Loads non-LoRA weights (rm_head, merger, special token embeddings)
4. Saves the complete model locally and, when --upload is passed, uploads it to the HuggingFace Hub

Prerequisites:
    pip install huggingface_hub safetensors
    huggingface-cli login

Usage (omit --upload to only save the merged model locally):
    python scripts/prepare_and_upload.py \
        --checkpoint_dir /path/to/training/logs/v4/ord_4B_lora_2stage_promptv2_res399k \
        --checkpoint_step 1050 \
        --hf_repo VEFX-Reward/VEFX-Reward-4B \
        --output_dir ./merged_model \
        --upload
"""

import argparse
import json
import os

import safetensors.torch as st
import torch
from transformers import AutoProcessor, AutoTokenizer

import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from vefx_reward.model import Qwen3VLRewardModelBT

# Reward placeholder tokens. main() adds them to the base tokenizer when the
# checkpoint ships no tokenizer of its own, and passes their ids to the reward
# model as ``special_token_ids``. Order matters: ids are looked up in this order.
SPECIAL_TOKENS = [
    "<|VQ_reward|>",
    "<|MQ_reward|>",
    "<|TA_reward|>",
    "<|IF_reward|>",
    "<|RQ_reward|>",
    "<|EE_reward|>",
]


def _find_checkpoint(checkpoint_dir, checkpoint_step):
    """Return the path of the checkpoint directory to merge.

    Args:
        checkpoint_dir: Training output directory holding ``checkpoint-<step>/``.
        checkpoint_step: Step to use, or -1 to pick the highest numbered step.

    Raises:
        FileNotFoundError: If the requested checkpoint directory does not
            exist, or if no ``checkpoint-<step>`` directory is present.
    """
    import glob as globmod
    import re

    if checkpoint_step != -1:
        ckpt_path = os.path.join(checkpoint_dir, f"checkpoint-{checkpoint_step}")
        # Fail before the (slow) base-model load instead of at adapter load time.
        if not os.path.isdir(ckpt_path):
            raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_path}")
        return ckpt_path

    # Escape the directory so glob metacharacters in the path are literal.
    pattern = os.path.join(globmod.escape(checkpoint_dir), "checkpoint-*")
    numbered = []
    for path in globmod.glob(pattern):
        # Ignore entries such as "checkpoint-best" that carry no step number.
        match = re.fullmatch(r"checkpoint-(\d+)", os.path.basename(path))
        if match and os.path.isdir(path):
            numbered.append((int(match.group(1)), path))
    if not numbered:
        raise FileNotFoundError(
            f"No checkpoint-<step> directories found in {checkpoint_dir}"
        )
    return max(numbered)[1]


def _lora_scaling(lora_cfg):
    """Compute the LoRA delta scaling factor from an ``adapter_config.json`` dict.

    Uses ``lora_alpha / r``, or ``lora_alpha / sqrt(r)`` when the config sets
    ``use_rslora``.
    """
    import math

    rank = lora_cfg["r"]
    alpha = lora_cfg["lora_alpha"]
    if lora_cfg.get("use_rslora", False):
        return alpha / math.sqrt(rank)
    return alpha / rank


def main(argv=None):
    """Merge a LoRA checkpoint into the base model, save it, optionally upload.

    Steps: read ``model_config.json`` from the training directory, resolve the
    checkpoint, build the processor/tokenizer, load the base reward model,
    merge adapter weights into its state dict, overlay the non-LoRA weights,
    save the model, processor and ``vefx_config.json`` to ``--output_dir``, and
    upload the folder to ``--hf_repo`` only when ``--upload`` is given.

    Args:
        argv: Optional list of command-line arguments; ``None`` (the default)
            makes argparse read ``sys.argv``.

    Raises:
        FileNotFoundError: If the config file or the checkpoint directory
            cannot be found.
    """
    parser = argparse.ArgumentParser(description="Prepare and upload VEFX-Reward model")
    parser.add_argument("--checkpoint_dir", required=True,
                        help="Training output directory containing model_config.json and checkpoint-*/")
    parser.add_argument("--checkpoint_step", type=int, default=-1,
                        help="Checkpoint step to use (-1 = latest)")
    parser.add_argument("--hf_repo", default="VEFX-Reward/VEFX-Reward-4B",
                        help="HuggingFace repo ID to upload to")
    parser.add_argument("--output_dir", default="./merged_model",
                        help="Local directory to save merged model before upload")
    parser.add_argument("--upload", action="store_true",
                        help="Actually upload to HuggingFace (otherwise just save locally)")
    args = parser.parse_args(argv)

    # 1. Load training config
    config_path = os.path.join(args.checkpoint_dir, "model_config.json")
    with open(config_path, encoding="utf-8") as f:
        config_dict = json.load(f)

    model_config = config_dict["model_config"]
    data_config = config_dict["data_config"]

    base_model_path = model_config["model_name_or_path"]
    output_dim = model_config["output_dim"]
    use_ordinal = model_config["use_ordinal"]
    num_classes = model_config["num_classes"]
    reward_token = model_config["reward_token"]

    print(f"Base model: {base_model_path}")
    print(f"Output dim: {output_dim}, Ordinal: {use_ordinal}, Num classes: {num_classes}")

    # 2. Find checkpoint (validated up front so a bad path fails fast)
    ckpt_path = _find_checkpoint(args.checkpoint_dir, args.checkpoint_step)
    print(f"Using checkpoint: {ckpt_path}")

    # 3. Load processor from base model with checkpoint's tokenizer
    processor = AutoProcessor.from_pretrained(base_model_path, padding_side="right")
    ckpt_tokenizer_path = os.path.join(ckpt_path, "tokenizer")
    if os.path.isdir(ckpt_tokenizer_path):
        processor.tokenizer = AutoTokenizer.from_pretrained(ckpt_tokenizer_path)
    else:
        # No tokenizer saved with the checkpoint: extend the base tokenizer.
        processor.tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})
    special_token_ids = processor.tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    print(f"Tokenizer vocab size: {len(processor.tokenizer)}")

    # 4. Load base model with reward head
    print("Loading base model...")
    model = Qwen3VLRewardModelBT.from_pretrained(
        base_model_path,
        torch_dtype=torch.bfloat16,
        output_dim=output_dim,
        reward_token=reward_token,
        special_token_ids=special_token_ids,
        use_ordinal=use_ordinal,
        num_classes=num_classes,
        use_cache=True,
        device_map="cpu",
    )
    model.resize_token_embeddings(len(processor.tokenizer))
    print(f"Model embeddings resized to {len(processor.tokenizer)}")

    # 5. Manual LoRA merge (bypasses PEFT tie_word_embeddings issues)
    print("Loading and merging adapter weights...")
    adapter_weights = st.load_file(os.path.join(ckpt_path, "adapter_model.safetensors"), device="cpu")

    with open(os.path.join(ckpt_path, "adapter_config.json"), encoding="utf-8") as f:
        lora_cfg = json.load(f)
    scaling = _lora_scaling(lora_cfg)
    if lora_cfg.get("rank_pattern") or lora_cfg.get("alpha_pattern"):
        # A single global scaling is used below; per-module overrides would
        # need per-module scaling and are not handled here.
        print("  WARNING: adapter_config.json sets rank_pattern/alpha_pattern; "
              "per-module overrides are not applied by this manual merge")

    # Categorize adapter keys by role, keyed by the target module name
    base_layers, lora_As, lora_Bs, emb_As, emb_Bs = {}, {}, {}, {}, {}
    for k, v in adapter_weights.items():
        ck = k.replace("base_model.model.", "")
        if ".base_layer.weight" in ck:
            base_layers[ck.replace(".base_layer.weight", "")] = v
        elif ".lora_A.weight" in ck:
            lora_As[ck.replace(".lora_A.weight", "")] = v
        elif ".lora_B.weight" in ck:
            lora_Bs[ck.replace(".lora_B.weight", "")] = v
        elif ".lora_embedding_A" in ck:
            emb_As[ck.replace(".lora_embedding_A", "")] = v
        elif ".lora_embedding_B" in ck:
            emb_Bs[ck.replace(".lora_embedding_B", "")] = v

    model_state = model.state_dict()

    # Replace base layer weights (for resized lm_head / embed_tokens)
    for mod, w in base_layers.items():
        key = mod + ".weight"
        if key in model_state:
            model_state[key] = w.to(model_state[key].dtype)
            print(f"  Replaced base layer: {key}")

    # Merge LoRA: W_merged = W + B @ A * scaling
    merged_count = 0
    skipped = []
    for mod, lora_a in lora_As.items():
        key = mod + ".weight"
        lora_b = lora_Bs.get(mod)
        if lora_b is None or key not in model_state:
            # Record instead of silently dropping, so an incomplete merge is visible.
            skipped.append(mod)
            continue
        target = model_state[key]
        # Accumulate in float32, then cast back to the destination's own dtype.
        delta = (lora_b.float() @ lora_a.float()) * scaling
        model_state[key] = (target.float() + delta).to(target.dtype)
        merged_count += 1
    print(f"  Merged {merged_count} LoRA modules")
    if skipped:
        preview = ", ".join(skipped[:5])
        print(f"  WARNING: skipped {len(skipped)} LoRA modules without a matching "
              f"lora_B or target weight (e.g. {preview})")

    # Merge embedding LoRA
    for mod, emb_a in emb_As.items():
        emb_b = emb_Bs.get(mod)
        key = mod + ".weight"
        if emb_b is None or key not in model_state:
            continue
        target = model_state[key]
        # The product is transposed before being added to the embedding matrix.
        # presumably A is (r, vocab) and B is (dim, r) — confirm against PEFT.
        delta = (emb_b.float() @ emb_a.float()).T * scaling
        model_state[key] = (target.float() + delta).to(target.dtype)
        print(f"  Merged embedding LoRA: {key}")

    # 6. Load non-LoRA weights (rm_head, merger, special embeddings)
    non_lora_path = os.path.join(ckpt_path, "non_lora_state_dict.pth")
    if os.path.exists(non_lora_path):
        print("Loading non-LoRA weights...")
        # NOTE(security): torch.load unpickles the file; only use checkpoints
        # from a trusted source (consider weights_only=True where supported).
        non_lora_weights = torch.load(non_lora_path, map_location="cpu")
        for k, v in non_lora_weights.items():
            ck = k.replace("base_model.model.", "")
            if ck in model_state:
                model_state[ck] = v.to(model_state[ck].dtype)
                print(f"  Loaded: {ck}")

    model.load_state_dict(model_state)
    print(f"All weights loaded. rm_head shape: {model.rm_head.weight.shape}")

    # 7. Save merged model
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Saving merged model to {args.output_dir}...")
    model.save_pretrained(args.output_dir, safe_serialization=True)
    processor.save_pretrained(args.output_dir)

    # Save VEFX-specific config (falls back to defaults for missing data keys)
    vefx_config = {
        "output_dim": output_dim,
        "use_ordinal": use_ordinal,
        "num_classes": num_classes,
        "reward_token": reward_token,
        "fps": data_config.get("fps", 4.0),
        "max_frame_pixels": data_config.get("max_frame_pixels", 399360),
        "eval_dim": data_config.get("eval_dim", ["IF", "RQ", "EE"]),
        "prompt_template_type": data_config.get("prompt_template_type", "editreward_v2_special"),
    }
    with open(os.path.join(args.output_dir, "vefx_config.json"), "w", encoding="utf-8") as f:
        json.dump(vefx_config, f, indent=2)
    print("Saved vefx_config.json")

    # 8. Upload to HuggingFace
    if args.upload:
        from huggingface_hub import HfApi
        api = HfApi()
        print(f"Uploading to {args.hf_repo}...")
        api.upload_folder(
            folder_path=args.output_dir,
            repo_id=args.hf_repo,
            repo_type="model",
        )
        print(f"Upload complete: https://huggingface.co/{args.hf_repo}")
    else:
        print(f"\nModel saved to {args.output_dir}")
        print("To upload, run again with --upload flag, or manually:")
        print(f"  huggingface-cli upload {args.hf_repo} {args.output_dir} .")

# Run the merge/upload pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()