Spaces:
Running
Running
| """ | |
| Prepare and upload VEFX-Reward model to HuggingFace Hub. | |
| This script: | |
| 1. Loads the Qwen3-VL-4B base model | |
| 2. Manually merges LoRA weights into the base model | |
| 3. Loads non-LoRA weights (rm_head, merger, special token embeddings) | |
| 4. Saves and uploads the complete model to HuggingFace | |
| Prerequisites: | |
| pip install huggingface_hub safetensors | |
| huggingface-cli login | |
| Usage: | |
| python scripts/prepare_and_upload.py \ | |
| --checkpoint_dir /path/to/training/logs/v4/ord_4B_lora_2stage_promptv2_res399k \ | |
| --checkpoint_step 1050 \ | |
| --hf_repo VEFX-Reward/VEFX-Reward-4B \ | |
| --output_dir ./merged_model | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import safetensors.torch as st | |
| import torch | |
| from transformers import AutoProcessor, AutoTokenizer | |
| import sys | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| from vefx_reward.model import Qwen3VLRewardModelBT | |
# Per-dimension reward marker tokens. They are added to the tokenizer when a
# checkpoint does not ship its own tokenizer directory, and their ids are
# passed to the reward model as ``special_token_ids``.
SPECIAL_TOKENS = [f"<|{dim}_reward|>" for dim in ("VQ", "MQ", "TA", "IF", "RQ", "EE")]
def _strip_peft_prefix(key):
    """Return ``key`` without a leading ``base_model.model.`` PEFT wrapper prefix."""
    prefix = "base_model.model."
    return key[len(prefix):] if key.startswith(prefix) else key


def _find_checkpoint(checkpoint_dir, checkpoint_step):
    """Return the ``checkpoint-<step>`` directory to merge.

    Args:
        checkpoint_dir: Training output directory containing ``checkpoint-*`` subdirectories.
        checkpoint_step: Step to use, or -1 for the highest-numbered checkpoint.

    Raises:
        FileNotFoundError: If the requested checkpoint (or, for -1, any checkpoint) is missing.
    """
    import glob

    if checkpoint_step != -1:
        ckpt_path = os.path.join(checkpoint_dir, f"checkpoint-{checkpoint_step}")
        if not os.path.isdir(ckpt_path):
            raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_path}")
        return ckpt_path

    # Only consider directories whose suffix is a step number, so stray entries
    # (e.g. "checkpoint-best" or plain files) cannot break the numeric ordering.
    candidates = []
    for path in glob.glob(os.path.join(checkpoint_dir, "checkpoint-*")):
        suffix = os.path.basename(path).rsplit("-", 1)[-1]
        if suffix.isdigit() and os.path.isdir(path):
            candidates.append((int(suffix), path))
    if not candidates:
        raise FileNotFoundError(f"No checkpoint-<step> directories found in {checkpoint_dir}")
    return max(candidates)[1]


def _merge_adapter_weights(model_state, adapter_weights, lora_cfg):
    """Fold PEFT LoRA adapter tensors into ``model_state`` (dict entries are replaced in place).

    Recognised adapter tensors (after stripping the ``base_model.model.`` prefix):
      * ``<module>.base_layer.weight`` -- replaces ``<module>.weight`` outright
        (used for the resized lm_head / embed_tokens);
      * ``<module>.lora_A.weight`` / ``<module>.lora_B.weight`` --
        ``W += (B @ A) * scaling``;
      * ``<module>.lora_embedding_A`` / ``<module>.lora_embedding_B`` --
        ``W += (B @ A).T * scaling``.

    ``scaling`` follows PEFT: ``lora_alpha / r``, or ``lora_alpha / sqrt(r)`` when
    ``use_rslora`` is set. ``r`` is taken from each module's A matrix (its first
    dimension), so per-module ``rank_pattern`` ranks are respected.

    Anything that cannot be applied (unrecognised tensors, modules missing from
    ``model_state``, A/B halves without a partner) is reported as a warning
    instead of being dropped silently.

    Args:
        model_state: Parameter-name -> tensor mapping from ``model.state_dict()``.
        adapter_weights: Tensors loaded from ``adapter_model.safetensors``.
        lora_cfg: Parsed ``adapter_config.json``.

    Raises:
        NotImplementedError: For DoRA adapters or a non-empty ``alpha_pattern``,
            which this manual merge does not reproduce.
    """
    import math

    if lora_cfg.get("use_dora"):
        raise NotImplementedError("DoRA adapters are not supported by the manual LoRA merge")
    if lora_cfg.get("alpha_pattern"):
        raise NotImplementedError("Per-module alpha_pattern is not supported by the manual LoRA merge")

    lora_alpha = lora_cfg["lora_alpha"]
    use_rslora = bool(lora_cfg.get("use_rslora", False))

    def scaling_for(rank):
        # Mirrors PEFT's LoraLayer scaling rule.
        return lora_alpha / math.sqrt(rank) if use_rslora else lora_alpha / rank

    # Categorize adapter keys by their PEFT suffix.
    base_layers, lora_as, lora_bs, emb_as, emb_bs = {}, {}, {}, {}, {}
    buckets = (
        (".base_layer.weight", base_layers),
        (".lora_A.weight", lora_as),
        (".lora_B.weight", lora_bs),
        (".lora_embedding_A", emb_as),
        (".lora_embedding_B", emb_bs),
    )
    unrecognized = []
    for raw_key, tensor in adapter_weights.items():
        key = _strip_peft_prefix(raw_key)
        for suffix, bucket in buckets:
            if key.endswith(suffix):
                bucket[key[: -len(suffix)]] = tensor
                break
        else:
            unrecognized.append(key)

    skipped = []

    # Replace base layer weights (for resized lm_head / embed_tokens).
    for module, weight in base_layers.items():
        key = module + ".weight"
        if key in model_state:
            model_state[key] = weight.to(model_state[key].dtype)
            print(f"  Replaced base layer: {key}")
        else:
            skipped.append(key)

    def merge_pairs(a_dict, b_dict, transpose):
        # Add scaled B @ A (optionally transposed) to each matching weight; return the merge count.
        count = 0
        for module, a in a_dict.items():
            key = module + ".weight"
            b = b_dict.get(module)
            if b is None or key not in model_state:
                skipped.append(key)
                continue
            delta = b.float() @ a.float()
            if transpose:
                delta = delta.T
            delta = delta * scaling_for(a.shape[0])
            target = model_state[key]
            # Accumulate in float32, then cast back to the parameter's own dtype.
            model_state[key] = (target.float() + delta).to(target.dtype)
            count += 1
            if transpose:
                print(f"  Merged embedding LoRA: {key}")
        # B halves with no A partner can never be merged.
        skipped.extend(m + ".weight" for m in b_dict if m not in a_dict)
        return count

    # Merge LoRA: W_merged = W + B @ A * scaling
    merged_count = merge_pairs(lora_as, lora_bs, transpose=False)
    print(f"  Merged {merged_count} LoRA modules")
    # Merge embedding LoRA: lora_embedding_A is (r, num_embeddings), so (B @ A)
    # must be transposed to match the (num_embeddings, dim) embedding weight.
    merge_pairs(emb_as, emb_bs, transpose=True)

    if unrecognized:
        print(f"  WARNING: ignored {len(unrecognized)} unrecognized adapter tensors, e.g. {unrecognized[:5]}")
    if skipped:
        print(f"  WARNING: could not apply {len(skipped)} adapter modules, e.g. {skipped[:5]}")


def _load_non_lora_weights(model_state, non_lora_path):
    """Overwrite ``model_state`` entries with tensors from ``non_lora_path`` (rm_head, merger, ...).

    A missing file or keys that do not exist in ``model_state`` are reported as
    warnings so an un-trained head is not saved without notice.
    """
    if not os.path.exists(non_lora_path):
        print(f"WARNING: {non_lora_path} not found; non-LoRA weights (rm_head, merger, "
              "special embeddings) keep the values of the freshly loaded model.")
        return
    print("Loading non-LoRA weights...")
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered .pth file cannot execute arbitrary code while being loaded.
    non_lora_weights = torch.load(non_lora_path, map_location="cpu", weights_only=True)
    unmatched = []
    for raw_key, tensor in non_lora_weights.items():
        key = _strip_peft_prefix(raw_key)
        if key in model_state:
            model_state[key] = tensor.to(model_state[key].dtype)
            print(f"  Loaded: {key}")
        else:
            unmatched.append(key)
    if unmatched:
        print(f"  WARNING: {len(unmatched)} non-LoRA tensors have no matching model key, e.g. {unmatched[:5]}")


def main():
    """Merge a VEFX-Reward LoRA checkpoint into its base model, save it, and optionally upload it.

    Reads ``model_config.json`` from ``--checkpoint_dir``, selects a
    ``checkpoint-<step>`` subdirectory, merges its adapter and non-LoRA weights
    into a freshly loaded ``Qwen3VLRewardModelBT``, writes the model, processor
    and ``vefx_config.json`` to ``--output_dir`` and, with ``--upload``, pushes
    that directory to ``--hf_repo``.
    """
    parser = argparse.ArgumentParser(description="Prepare and upload VEFX-Reward model")
    parser.add_argument("--checkpoint_dir", required=True,
                        help="Training output directory containing model_config.json and checkpoint-*/")
    parser.add_argument("--checkpoint_step", type=int, default=-1,
                        help="Checkpoint step to use (-1 = latest)")
    parser.add_argument("--hf_repo", default="VEFX-Reward/VEFX-Reward-4B",
                        help="HuggingFace repo ID to upload to")
    parser.add_argument("--output_dir", default="./merged_model",
                        help="Local directory to save merged model before upload")
    parser.add_argument("--upload", action="store_true",
                        help="Actually upload to HuggingFace (otherwise just save locally)")
    args = parser.parse_args()

    # 1. Load training config
    config_path = os.path.join(args.checkpoint_dir, "model_config.json")
    with open(config_path, encoding="utf-8") as f:
        config_dict = json.load(f)
    model_config = config_dict["model_config"]
    data_config = config_dict["data_config"]
    base_model_path = model_config["model_name_or_path"]
    output_dim = model_config["output_dim"]
    use_ordinal = model_config["use_ordinal"]
    num_classes = model_config["num_classes"]
    reward_token = model_config["reward_token"]
    print(f"Base model: {base_model_path}")
    print(f"Output dim: {output_dim}, Ordinal: {use_ordinal}, Num classes: {num_classes}")

    # 2. Find checkpoint
    ckpt_path = _find_checkpoint(args.checkpoint_dir, args.checkpoint_step)
    print(f"Using checkpoint: {ckpt_path}")

    # 3. Load processor from base model with checkpoint's tokenizer
    processor = AutoProcessor.from_pretrained(base_model_path, padding_side="right")
    ckpt_tokenizer_path = os.path.join(ckpt_path, "tokenizer")
    if os.path.isdir(ckpt_tokenizer_path):
        # The checkpoint tokenizer is loaded independently of the processor kwargs
        # above, so request right padding explicitly to keep the two consistent.
        processor.tokenizer = AutoTokenizer.from_pretrained(ckpt_tokenizer_path, padding_side="right")
    else:
        processor.tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})
    special_token_ids = processor.tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    # Unknown tokens come back as unk_token_id (or None); using those ids would
    # silently point the reward model at the wrong embeddings.
    unk_id = processor.tokenizer.unk_token_id
    missing = [tok for tok, tid in zip(SPECIAL_TOKENS, special_token_ids)
               if tid is None or (unk_id is not None and tid == unk_id)]
    if missing:
        raise ValueError(f"Tokenizer is missing reward special tokens: {missing}")
    print(f"Tokenizer vocab size: {len(processor.tokenizer)}")

    # 4. Load base model with reward head
    print("Loading base model...")
    model = Qwen3VLRewardModelBT.from_pretrained(
        base_model_path,
        torch_dtype=torch.bfloat16,
        output_dim=output_dim,
        reward_token=reward_token,
        special_token_ids=special_token_ids,
        use_ordinal=use_ordinal,
        num_classes=num_classes,
        use_cache=True,
        device_map="cpu",
    )
    model.resize_token_embeddings(len(processor.tokenizer))
    print(f"Model embeddings resized to {len(processor.tokenizer)}")

    # 5. Manual LoRA merge (bypasses PEFT tie_word_embeddings issues)
    print("Loading and merging adapter weights...")
    adapter_weights = st.load_file(os.path.join(ckpt_path, "adapter_model.safetensors"), device="cpu")
    with open(os.path.join(ckpt_path, "adapter_config.json"), encoding="utf-8") as f:
        lora_cfg = json.load(f)
    model_state = model.state_dict()
    _merge_adapter_weights(model_state, adapter_weights, lora_cfg)

    # 6. Load non-LoRA weights (rm_head, merger, special embeddings)
    _load_non_lora_weights(model_state, os.path.join(ckpt_path, "non_lora_state_dict.pth"))

    # NOTE(review): entries not replaced above still alias the live parameters; if
    # lm_head and embed_tokens are tied and only one of them was replaced, confirm
    # that load_state_dict's copy order yields the intended shared weight.
    model.load_state_dict(model_state)
    print(f"All weights loaded. rm_head shape: {model.rm_head.weight.shape}")

    # 7. Save merged model
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Saving merged model to {args.output_dir}...")
    model.save_pretrained(args.output_dir, safe_serialization=True)
    processor.save_pretrained(args.output_dir)

    # Save VEFX-specific config
    vefx_config = {
        "output_dim": output_dim,
        "use_ordinal": use_ordinal,
        "num_classes": num_classes,
        "reward_token": reward_token,
        "fps": data_config.get("fps", 4.0),
        "max_frame_pixels": data_config.get("max_frame_pixels", 399360),
        "eval_dim": data_config.get("eval_dim", ["IF", "RQ", "EE"]),
        "prompt_template_type": data_config.get("prompt_template_type", "editreward_v2_special"),
    }
    with open(os.path.join(args.output_dir, "vefx_config.json"), "w", encoding="utf-8") as f:
        json.dump(vefx_config, f, indent=2)
    print("Saved vefx_config.json")

    # 8. Upload to HuggingFace
    if args.upload:
        from huggingface_hub import HfApi
        api = HfApi()
        print(f"Uploading to {args.hf_repo}...")
        api.upload_folder(
            folder_path=args.output_dir,
            repo_id=args.hf_repo,
            repo_type="model",
        )
        print(f"Upload complete: https://huggingface.co/{args.hf_repo}")
    else:
        print(f"\nModel saved to {args.output_dir}")
        print("To upload, run again with --upload flag, or manually:")
        print(f"  huggingface-cli upload {args.hf_repo} {args.output_dir} .")
if __name__ == "__main__":
    # Entry point when executed as a script; importing this module runs nothing.
    main()