File size: 800 Bytes
b1b2e62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
#!/usr/bin/env python3
"""
Create initialized Abstract model checkpoint.
"""
import argparse
import torch
import os
from pathlib import Path
from abstract_model import AbstractModel
def main(argv=None):
    """Load an SFT model via AbstractModel and save it as an initialized checkpoint.

    Args:
        argv: Optional list of command-line arguments to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            script behavior; passing a list lets callers and tests invoke
            ``main`` without touching ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--sft-model', required=True, help='Path to SFT model')
    parser.add_argument('--output', required=True, help='Output directory for initialized model')
    parser.add_argument(
        '--device',
        default=None,
        help="Device passed to AbstractModel (e.g. 'cpu' or 'cuda'); "
             "defaults to 'cuda' when available, otherwise 'cpu'",
    )
    args = parser.parse_args(argv)

    print(f"Loading SFT model from: {args.sft_model}")
    # An explicit --device wins; otherwise keep the original auto-selection.
    device = args.device or ('cuda' if torch.cuda.is_available() else 'cpu')
    model = AbstractModel(args.sft_model, device=device)

    print(f"Saving initialized model to: {args.output}")
    # Create the output directory (and any missing parents) before saving;
    # an already-existing directory is accepted.
    Path(args.output).mkdir(parents=True, exist_ok=True)
    model.save_to_directory(args.output)
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
|