wi-lab committed on
Commit
8676fb4
·
unverified ·
1 Parent(s): ad81aa1

Delete examples/export_to_hf.py

Browse files
Files changed (1) hide show
  1. examples/export_to_hf.py +0 -103
examples/export_to_hf.py DELETED
@@ -1,103 +0,0 @@
#!/usr/bin/env python
"""Export an existing LWM checkpoint to Hugging Face format.

Loads a checkpoint (``config.json`` + ``pytorch_model.bin``), wraps the
backbone in the HF model classes, saves it with ``save_pretrained``, and
copies the modeling source files next to the export so the result can be
loaded with ``trust_remote_code=True``.
"""

import json
import shutil
import sys
from pathlib import Path

# Make the repository root importable when running this script directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

from LWMTemporal.models.lwm import LWMBackbone, LWMConfig
from LWMTemporal.models.hf import LWMHFModel, LWMHFConfig
from LWMTemporal.utils.logging import setup_logging

logger = setup_logging("export_to_hf", log_dir=Path("logs"))


def _load_lwm_config(checkpoint_dir: Path, lwm_model):
    """Return the config stored in *checkpoint_dir*, falling back to the model's own.

    Prefers ``config.json`` next to the checkpoint so the export matches the
    on-disk configuration rather than whatever defaults the model object holds.
    """
    config_path = checkpoint_dir / "config.json"
    if config_path.exists():
        with open(config_path) as f:
            return LWMConfig.from_dict(json.load(f))
    return lwm_model.config


def _infer_max_seq_len(lwm_config, lwm_model) -> None:
    """Fill in ``max_seq_len`` from checkpoint positional embeddings if unset.

    Mutates *lwm_config* in place. The positional-embedding table includes one
    extra slot for the CLS token when ``global_cls`` is enabled, so that slot
    is subtracted before inferring the sequence length.
    """
    if lwm_config.max_seq_len is None and hasattr(lwm_model, "pos_embed"):
        pos_len = int(lwm_model.pos_embed.shape[1])
        cls_tokens = 1 if lwm_config.global_cls else 0
        inferred = max(0, pos_len - cls_tokens)
        if inferred > 0:
            lwm_config.max_seq_len = inferred
            logger.info("Inferred max_seq_len=%d from checkpoint positional embeddings", inferred)


def _copy_modeling_files(hf_export_dir: Path) -> None:
    """Copy the modeling sources into the export directory.

    HF's ``trust_remote_code=True`` loading expects the files to match the
    ``auto_map`` import path, so they are mirrored under
    ``<export>/LWMTemporal/models/`` with ``__init__.py`` files to make the
    copies an importable package.
    """
    base_dir = Path(__file__).parent.parent
    modeling_dir = hf_export_dir / "LWMTemporal" / "models"
    modeling_dir.mkdir(parents=True, exist_ok=True)

    # hf.py is the HF wrapper; lwm.py is its backbone dependency.
    for filename in ("hf.py", "lwm.py"):
        src = base_dir / "LWMTemporal" / "models" / filename
        if src.exists():
            shutil.copy2(src, modeling_dir / filename)
            logger.info("✓ Copied %s", filename)
        else:
            logger.warning("%s not found at %s", filename, src)

    (hf_export_dir / "LWMTemporal" / "__init__.py").touch()
    (modeling_dir / "__init__.py").touch()


def main() -> None:
    """Run the checkpoint-to-HF export end to end."""
    # Path to your existing checkpoint directory (with config.json and pytorch_model.bin)
    checkpoint_dir = Path("checkpoints")
    checkpoint_path = checkpoint_dir / "pytorch_model.bin"  # Or use m18_cp.pth if that's your file

    # Output directory for HF export
    hf_export_dir = Path("models/hf_export")

    logger.info("Loading checkpoint from %s", checkpoint_path)
    lwm_model = LWMBackbone.from_pretrained(checkpoint_path)

    lwm_config = _load_lwm_config(checkpoint_dir, lwm_model)

    # Ensure max_seq_len matches checkpoint positional embeddings.
    _infer_max_seq_len(lwm_config, lwm_model)

    # Convert to HF format: rebuild the wrapper from the config, then load the
    # backbone weights so the exported model is numerically identical.
    logger.info("Converting to Hugging Face format...")
    hf_config = LWMHFConfig(**lwm_config.to_dict())
    hf_model = LWMHFModel(hf_config)
    hf_model.backbone.load_state_dict(lwm_model.state_dict())

    logger.info("Exporting to Hugging Face format at %s", hf_export_dir)
    hf_model.save_pretrained(hf_export_dir)

    _copy_modeling_files(hf_export_dir)

    logger.info("✓ Exported to %s", hf_export_dir)
    logger.info("Files created:")
    for f in sorted(hf_export_dir.glob("*")):
        logger.info("  - %s", f.name)

    # Optional: upload directly to the HF Hub. Uncomment to push automatically:
    # try:
    #     from huggingface_hub import HfApi
    #     api = HfApi()
    #     api.upload_folder(
    #         folder_path=hf_export_dir,
    #         repo_id="wi-lab/lwm-temporal",
    #         repo_type="model",
    #         commit_message="Export existing checkpoint to HF format",
    #     )
    #     logger.info("✓ Uploaded to Hugging Face Hub: wi-lab/lwm-temporal")
    # except ImportError:
    #     logger.warning("huggingface_hub not installed; skipping upload")
    #     logger.info("To upload manually:")
    #     logger.info("  1. git clone https://huggingface.co/wi-lab/lwm-temporal")
    #     logger.info("  2. cp -r %s/* lwm-temporal/", hf_export_dir)
    #     logger.info("  3. cd lwm-temporal && git add . && git commit -m 'Add checkpoint' && git push")


if __name__ == "__main__":
    main()