File size: 800 Bytes
b1b2e62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#!/usr/bin/env python3
"""
Create initialized Abstract model checkpoint.
"""

import argparse
import torch
import os
from pathlib import Path
from abstract_model import AbstractModel

def main(argv=None):
    """Build an AbstractModel from an SFT checkpoint and save it to a directory.

    Parses ``--sft-model`` and ``--output``, constructs ``AbstractModel`` on
    CUDA when ``torch.cuda.is_available()`` reports a GPU (otherwise on CPU),
    and writes the result with ``save_to_directory``.

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. When ``None`` (the default), ``argparse`` reads
            ``sys.argv[1:]``, exactly as before; passing a list allows the
            function to be driven programmatically or from tests.

    Raises:
        SystemExit: From ``argparse`` when required arguments are missing or
            ``--help`` is requested.
        OSError: If the output directory cannot be created.
    """
    parser = argparse.ArgumentParser(
        description="Create initialized Abstract model checkpoint."
    )
    parser.add_argument('--sft-model', required=True, help='Path to SFT model')
    parser.add_argument('--output', required=True, help='Output directory for initialized model')
    args = parser.parse_args(argv)

    # Create the output directory before building the model. Constructing
    # AbstractModel presumably loads the SFT weights (possibly onto a GPU),
    # which can be slow; an unwritable or invalid --output should fail before
    # that work instead of after it.
    Path(args.output).mkdir(parents=True, exist_ok=True)

    print(f"Loading SFT model from: {args.sft_model}")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = AbstractModel(args.sft_model, device=device)

    print(f"Saving initialized model to: {args.output}")
    model.save_to_directory(args.output)


if __name__ == "__main__":
    main()