|
|
|
|
|
""" |
|
|
Create an initialized AbstractModel checkpoint from an SFT model and save it to an output directory.
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import torch |
|
|
import os |
|
|
from pathlib import Path |
|
|
from abstract_model import AbstractModel |
|
|
|
|
|
def main(argv=None) -> None:
    """Build an ``AbstractModel`` from an SFT model and save it to a directory.

    Command-line options:
        --sft-model  Path to the SFT model (required).
        --output     Output directory for the initialized model (required).

    Args:
        argv: Optional list of command-line arguments, without the program
            name. When ``None`` (the default) ``argparse`` reads
            ``sys.argv[1:]``, so a bare ``main()`` call behaves as before.

    Raises:
        SystemExit: Raised by ``argparse`` (exit status 2) when a required
            option is missing or an unknown option is supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--sft-model', required=True, help='Path to SFT model')
    parser.add_argument('--output', required=True, help='Output directory for initialized model')
    # Forwarding argv lets callers drive the script programmatically.
    args = parser.parse_args(argv)

    print(f"Loading SFT model from: {args.sft_model}")
    # Use the GPU when torch reports one is available; otherwise use the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = AbstractModel(args.sft_model, device=device)

    print(f"Saving initialized model to: {args.output}")
    # Create the target directory (and any missing parents) before saving;
    # exist_ok=True makes re-running against an existing directory harmless.
    os.makedirs(args.output, exist_ok=True)
    model.save_to_directory(args.output)
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|