VEFX-Code / scripts /prepare_and_upload.py
VEFX-Reward's picture
Add VEFX-Bench reference code
f666f1f verified
"""
Prepare and upload VEFX-Reward model to HuggingFace Hub.
This script:
1. Loads the Qwen3-VL-4B base model
2. Manually merges LoRA weights into the base model
3. Loads non-LoRA weights (rm_head, merger, special token embeddings)
4. Saves the complete model locally and, when --upload is given, uploads it to the HuggingFace Hub
Prerequisites:
pip install huggingface_hub safetensors
huggingface-cli login
Usage:
python scripts/prepare_and_upload.py \
--checkpoint_dir /path/to/training/logs/v4/ord_4B_lora_2stage_promptv2_res399k \
--checkpoint_step 1050 \
--hf_repo VEFX-Reward/VEFX-Reward-4B \
--output_dir ./merged_model \
--upload
Omit --upload to only save the merged model to --output_dir.
"""
import argparse
import json
import os
import safetensors.torch as st
import torch
from transformers import AutoProcessor, AutoTokenizer
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from vefx_reward.model import Qwen3VLRewardModelBT
# Per-dimension reward tokens, in the order VQ, MQ, TA, IF, RQ, EE.
# main() registers them as additional special tokens when the checkpoint ships no
# tokenizer of its own, and always looks up their ids to pass to the reward model.
SPECIAL_TOKENS = [f"<|{dim}_reward|>" for dim in ("VQ", "MQ", "TA", "IF", "RQ", "EE")]
def _parse_args():
    """Parse the command-line arguments of the export/upload script."""
    parser = argparse.ArgumentParser(description="Prepare and upload VEFX-Reward model")
    parser.add_argument("--checkpoint_dir", required=True,
                        help="Training output directory containing model_config.json and checkpoint-*/")
    parser.add_argument("--checkpoint_step", type=int, default=-1,
                        help="Checkpoint step to use (-1 = latest)")
    parser.add_argument("--hf_repo", default="VEFX-Reward/VEFX-Reward-4B",
                        help="HuggingFace repo ID to upload to")
    parser.add_argument("--output_dir", default="./merged_model",
                        help="Local directory to save merged model before upload")
    parser.add_argument("--upload", action="store_true",
                        help="Actually upload to HuggingFace (otherwise just save locally)")
    return parser.parse_args()


def _find_checkpoint(checkpoint_dir, checkpoint_step):
    """Return the path of the ``checkpoint-<step>`` directory to export.

    Args:
        checkpoint_dir: Training output directory holding ``checkpoint-<step>/`` subdirectories.
        checkpoint_step: Step to export, or -1 for the highest-numbered checkpoint.

    Raises:
        FileNotFoundError: If the requested checkpoint (or any checkpoint, for -1) is missing.
    """
    import glob as globmod

    if checkpoint_step != -1:
        ckpt_path = os.path.join(checkpoint_dir, f"checkpoint-{checkpoint_step}")
        if not os.path.isdir(ckpt_path):
            raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_path}")
        return ckpt_path

    # Only directories with an integer step suffix qualify; entries like
    # "checkpoint-final" or archive files would otherwise break the int() sort key.
    candidates = []
    for path in globmod.glob(os.path.join(checkpoint_dir, "checkpoint-*")):
        suffix = os.path.basename(path).rsplit("-", 1)[-1]
        if os.path.isdir(path) and suffix.isdigit():
            candidates.append((int(suffix), path))
    if not candidates:
        raise FileNotFoundError(f"No checkpoint-<step> directories found in {checkpoint_dir}")
    return max(candidates)[1]


def _lora_scaling(lora_cfg, rank):
    """Return the LoRA scaling factor for an adapter of the given rank.

    ``lora_alpha / rank``, or ``lora_alpha / sqrt(rank)`` when ``use_rslora`` is set
    in the adapter config (rank-stabilized LoRA).
    """
    alpha = lora_cfg["lora_alpha"]
    if lora_cfg.get("use_rslora", False):
        return alpha / (rank ** 0.5)
    return alpha / rank


def _apply_lora_deltas(model_state, lora_As, lora_Bs, lora_cfg, transpose):
    """Add ``(B @ A) * scaling`` onto matching ``<module>.weight`` entries of ``model_state``.

    Args:
        model_state: Mutable state dict to update in place.
        lora_As / lora_Bs: Module path -> LoRA A / B tensor.
        lora_cfg: Parsed ``adapter_config.json``.
        transpose: Transpose the delta before adding it (embedding LoRA, fan_in_fan_out).

    Returns:
        ``(merged_keys, unmatched_keys)``. A key is unmatched when its A or B half is
        absent or the target weight does not exist in ``model_state``.
    """
    merged, unmatched = [], []
    for mod in sorted(lora_As.keys() | lora_Bs.keys()):
        key = mod + ".weight"
        A, B = lora_As.get(mod), lora_Bs.get(mod)
        if A is None or B is None or key not in model_state:
            unmatched.append(key)
            continue
        # Rank is read from A's leading dim so PEFT rank_pattern overrides are honored;
        # without overrides it equals adapter_config["r"].
        delta = (B.float() @ A.float()) * _lora_scaling(lora_cfg, A.shape[0])
        if transpose:
            delta = delta.T
        target = model_state[key]
        # Keep the target's dtype (consistent with the base-layer / non-LoRA loading).
        model_state[key] = (target.float() + delta).to(target.dtype)
        merged.append(key)
    return merged, unmatched


def _merge_adapter(model_state, ckpt_path):
    """Merge the PEFT LoRA adapter stored in ``ckpt_path`` into ``model_state`` in place.

    Reads ``adapter_model.safetensors`` and ``adapter_config.json``, then:
    replaces saved ``base_layer`` weights (for resized lm_head / embed_tokens), merges
    linear LoRA pairs as ``W + B @ A * scaling`` and embedding LoRA pairs as
    ``W + (B @ A).T * scaling``.

    Raises:
        NotImplementedError: For DoRA adapters or a non-empty ``alpha_pattern``, which this
            manual merge does not reproduce.
        RuntimeError: If any LoRA module cannot be merged (unpaired A/B or target weight
            missing) — uploading would otherwise silently drop fine-tuned weights.
    """
    adapter_weights = st.load_file(os.path.join(ckpt_path, "adapter_model.safetensors"), device="cpu")
    with open(os.path.join(ckpt_path, "adapter_config.json"), encoding="utf-8") as f:
        lora_cfg = json.load(f)
    if lora_cfg.get("use_dora"):
        raise NotImplementedError("DoRA adapters (use_dora=True) are not supported by the manual merge")
    if lora_cfg.get("alpha_pattern"):
        raise NotImplementedError("Per-module alpha_pattern is not supported by the manual merge")
    fan_in_fan_out = bool(lora_cfg.get("fan_in_fan_out", False))

    # Categorize adapter keys; stripping PEFT's "base_model.model." wrapper prefix so
    # module paths line up with the plain model's state-dict keys.
    base_layers, lora_As, lora_Bs, emb_As, emb_Bs = {}, {}, {}, {}, {}
    for k, v in adapter_weights.items():
        ck = k.replace("base_model.model.", "")
        if ".base_layer.weight" in ck:
            base_layers[ck.replace(".base_layer.weight", "")] = v
        elif ".lora_A.weight" in ck:
            lora_As[ck.replace(".lora_A.weight", "")] = v
        elif ".lora_B.weight" in ck:
            lora_Bs[ck.replace(".lora_B.weight", "")] = v
        elif ".lora_embedding_A" in ck:
            emb_As[ck.replace(".lora_embedding_A", "")] = v
        elif ".lora_embedding_B" in ck:
            emb_Bs[ck.replace(".lora_embedding_B", "")] = v

    # Replace base layer weights first so LoRA deltas land on top of them.
    for mod, w in base_layers.items():
        key = mod + ".weight"
        if key in model_state:
            model_state[key] = w.to(model_state[key].dtype)
            print(f"  Replaced base layer: {key}")
        else:
            print(f"  WARNING: base layer {key} not found in model, skipped")

    merged, unmatched = _apply_lora_deltas(model_state, lora_As, lora_Bs, lora_cfg,
                                           transpose=fan_in_fan_out)
    print(f"  Merged {len(merged)} LoRA modules")

    # PEFT's embedding delta is always transpose(B @ A).
    emb_merged, emb_unmatched = _apply_lora_deltas(model_state, emb_As, emb_Bs, lora_cfg,
                                                   transpose=True)
    for key in emb_merged:
        print(f"  Merged embedding LoRA: {key}")

    unmatched += emb_unmatched
    if unmatched:
        raise RuntimeError(
            f"{len(unmatched)} LoRA module(s) could not be merged (unpaired A/B or target "
            f"weight missing from the model): {', '.join(unmatched[:10])}"
            + (" ..." if len(unmatched) > 10 else "")
        )


def _load_non_lora_weights(model_state, ckpt_path):
    """Overwrite ``model_state`` entries with ``non_lora_state_dict.pth`` (rm_head, merger, ...).

    Keys not present in the model are reported and skipped. If the file is absent, the
    affected weights keep whatever the base-model load produced, and a warning is printed.
    """
    non_lora_path = os.path.join(ckpt_path, "non_lora_state_dict.pth")
    if not os.path.exists(non_lora_path):
        print(f"WARNING: {non_lora_path} not found; non-LoRA weights (e.g. rm_head) "
              "are left as loaded from the base model")
        return
    print("Loading non-LoRA weights...")
    # NOTE: torch.load unpickles the file; only use checkpoints you produced or trust.
    non_lora_weights = torch.load(non_lora_path, map_location="cpu")
    for k, v in non_lora_weights.items():
        ck = k.replace("base_model.model.", "")
        if ck in model_state:
            model_state[ck] = v.to(model_state[ck].dtype)
            print(f"  Loaded: {ck}")
        else:
            print(f"  WARNING: non-LoRA key {ck} not found in model, skipped")


def main():
    """Merge a VEFX-Reward LoRA checkpoint into its base model, save it, and optionally upload.

    Reads ``model_config.json`` from ``--checkpoint_dir``, selects a checkpoint, builds
    the processor, loads the base reward model on CPU in bfloat16, merges adapter and
    non-LoRA weights, writes model + processor + ``vefx_config.json`` to ``--output_dir``,
    and uploads that folder to ``--hf_repo`` when ``--upload`` is passed.
    """
    args = _parse_args()

    # 1. Load training config
    config_path = os.path.join(args.checkpoint_dir, "model_config.json")
    with open(config_path, encoding="utf-8") as f:
        config_dict = json.load(f)
    model_config = config_dict["model_config"]
    data_config = config_dict["data_config"]
    base_model_path = model_config["model_name_or_path"]
    output_dim = model_config["output_dim"]
    use_ordinal = model_config["use_ordinal"]
    num_classes = model_config["num_classes"]
    reward_token = model_config["reward_token"]
    print(f"Base model: {base_model_path}")
    print(f"Output dim: {output_dim}, Ordinal: {use_ordinal}, Num classes: {num_classes}")

    # 2. Find checkpoint
    ckpt_path = _find_checkpoint(args.checkpoint_dir, args.checkpoint_step)
    print(f"Using checkpoint: {ckpt_path}")

    # 3. Load processor from base model, preferring the tokenizer saved with the checkpoint
    processor = AutoProcessor.from_pretrained(base_model_path, padding_side="right")
    ckpt_tokenizer_path = os.path.join(ckpt_path, "tokenizer")
    if os.path.isdir(ckpt_tokenizer_path):
        processor.tokenizer = AutoTokenizer.from_pretrained(ckpt_tokenizer_path)
    else:
        processor.tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})
    special_token_ids = processor.tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    # Unknown tokens resolve to the unk id (or None); that would silently misplace rewards.
    unk_id = processor.tokenizer.unk_token_id
    missing_tokens = [tok for tok, tid in zip(SPECIAL_TOKENS, special_token_ids)
                      if tid is None or (unk_id is not None and tid == unk_id)]
    if missing_tokens:
        raise ValueError(f"Tokenizer does not contain reward tokens: {missing_tokens}")
    print(f"Tokenizer vocab size: {len(processor.tokenizer)}")

    # 4. Load base model with reward head
    print("Loading base model...")
    model = Qwen3VLRewardModelBT.from_pretrained(
        base_model_path,
        torch_dtype=torch.bfloat16,
        output_dim=output_dim,
        reward_token=reward_token,
        special_token_ids=special_token_ids,
        use_ordinal=use_ordinal,
        num_classes=num_classes,
        use_cache=True,
        device_map="cpu",
    )
    model.resize_token_embeddings(len(processor.tokenizer))
    print(f"Model embeddings resized to {len(processor.tokenizer)}")

    # 5. Manual LoRA merge (bypasses PEFT tie_word_embeddings issues).
    # NOTE(review): if lm_head is tied to embed_tokens, confirm the merged embedding is
    # what both keys hold after load_state_dict below.
    print("Loading and merging adapter weights...")
    model_state = model.state_dict()
    _merge_adapter(model_state, ckpt_path)

    # 6. Load non-LoRA weights (rm_head, merger, special embeddings)
    _load_non_lora_weights(model_state, ckpt_path)

    model.load_state_dict(model_state)
    print(f"All weights loaded. rm_head shape: {model.rm_head.weight.shape}")

    # 7. Save merged model
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Saving merged model to {args.output_dir}...")
    model.save_pretrained(args.output_dir, safe_serialization=True)
    processor.save_pretrained(args.output_dir)

    # Save VEFX-specific config (inference settings not covered by the HF config)
    vefx_config = {
        "output_dim": output_dim,
        "use_ordinal": use_ordinal,
        "num_classes": num_classes,
        "reward_token": reward_token,
        "fps": data_config.get("fps", 4.0),
        "max_frame_pixels": data_config.get("max_frame_pixels", 399360),
        "eval_dim": data_config.get("eval_dim", ["IF", "RQ", "EE"]),
        "prompt_template_type": data_config.get("prompt_template_type", "editreward_v2_special"),
    }
    with open(os.path.join(args.output_dir, "vefx_config.json"), "w", encoding="utf-8") as f:
        json.dump(vefx_config, f, indent=2)
    print("Saved vefx_config.json")

    # 8. Upload to HuggingFace
    if args.upload:
        from huggingface_hub import HfApi
        api = HfApi()
        print(f"Uploading to {args.hf_repo}...")
        api.upload_folder(
            folder_path=args.output_dir,
            repo_id=args.hf_repo,
            repo_type="model",
        )
        print(f"Upload complete: https://huggingface.co/{args.hf_repo}")
    else:
        print(f"\nModel saved to {args.output_dir}")
        print("To upload, run again with --upload flag, or manually:")
        print(f"  huggingface-cli upload {args.hf_repo} {args.output_dir} .")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()