File size: 3,878 Bytes
4942b80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python3
"""
Export fine-tuned Gemma 4 model to various formats.

Usage:
    # Export LoRA adapter to merged model + GGUF
    python scripts/export_model.py --model checkpoints/finetuned/lora_adapter

    # Push the LoRA adapter and a GGUF build to the HuggingFace Hub
    python scripts/export_model.py --model checkpoints/finetuned/lora_adapter \
        --push-to-hub username/gemma4-finetuned

    # Export specific GGUF quantization
    python scripts/export_model.py --model checkpoints/finetuned/lora_adapter \
        --gguf-quant q8_0
"""

import argparse
import os

from unsloth import FastModel


def parse_args():
    """Define the export CLI and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``model``, ``max_seq_length``,
        ``output_dir``, ``no_merged``, ``no_gguf``, ``gguf_quant`` and
        ``push_to_hub``.
    """
    cli = argparse.ArgumentParser(description="Export fine-tuned Gemma 4 model")

    # (flag, add_argument keyword options) pairs. They are registered in this
    # exact order, so ``--help`` lists them in the same sequence.
    options = [
        # Input / location
        ("--model", dict(type=str, required=True,
                         help="Path to fine-tuned LoRA adapter")),
        ("--max-seq-length", dict(type=int, default=2048)),
        ("--output-dir", dict(type=str, default="checkpoints/finetuned",
                              help="Base output directory")),
        # Export options
        ("--no-merged", dict(action="store_true",
                             help="Skip merged 16-bit export")),
        ("--no-gguf", dict(action="store_true",
                           help="Skip GGUF export")),
        ("--gguf-quant", dict(type=str, default="q4_k_m",
                              choices=["q4_k_m", "q8_0", "f16"],
                              help="GGUF quantization method")),
        ("--push-to-hub", dict(type=str, default=None,
                               help="HuggingFace Hub repo to push to (e.g. username/model-name)")),
    ]
    for flag, spec in options:
        cli.add_argument(flag, **spec)

    return cli.parse_args()


def main():
    """Entry point: load the fine-tuned model and run the requested exports.

    Behaviour is driven by the flags returned from ``parse_args()``:

    * a merged 16-bit checkpoint in ``<output-dir>/merged`` unless
      ``--no-merged`` is given;
    * a GGUF export in ``<output-dir>/gguf_<quant>`` unless ``--no-gguf``
      is given;
    * with ``--push-to-hub REPO``: the LoRA adapter is pushed to REPO and,
      unless ``--no-gguf`` is given, a GGUF build with the selected
      quantization is pushed to the same repo.
    """
    args = parse_args()

    print("=" * 60)
    print("Gemma 4 Model Export")
    print("=" * 60)
    print(f"Model:       {args.model}")
    print(f"Output dir:  {args.output_dir}")
    print("=" * 60)

    # Load model
    print("\nLoading model...")
    # NOTE(review): the model is loaded with load_in_4bit=True and every export
    # below is produced from this object — confirm that the precision of the
    # merged 16-bit weights is acceptable for your use.
    model, tokenizer = FastModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_length,
        load_in_4bit=True,
    )

    # Export merged model
    if not args.no_merged:
        merged_path = os.path.join(args.output_dir, "merged")
        print(f"\nExporting merged 16-bit model to {merged_path}...")
        model.save_pretrained_merged(
            merged_path,
            tokenizer,
            save_method="merged_16bit",
        )
        print(f"  Done! Size: {get_dir_size(merged_path)}")

    # Export GGUF
    if not args.no_gguf:
        gguf_path = os.path.join(args.output_dir, f"gguf_{args.gguf_quant}")
        print(f"\nExporting GGUF ({args.gguf_quant}) to {gguf_path}...")
        model.save_pretrained_gguf(
            gguf_path,
            tokenizer,
            quantization_method=args.gguf_quant,
        )
        print(f"  Done! Size: {get_dir_size(gguf_path)}")

    # Push to Hub
    if args.push_to_hub:
        print(f"\nPushing to HuggingFace Hub: {args.push_to_hub}...")

        # Push LoRA adapter
        model.push_to_hub(args.push_to_hub, tokenizer)
        print("  Pushed LoRA adapter")

        # Respect --no-gguf here as well. push_to_hub_gguf takes its own
        # quantization_method, meaning it builds a GGUF itself, so calling it
        # unconditionally would redo the export the user asked to skip.
        if args.no_gguf:
            print("  Skipped GGUF push (--no-gguf)")
        else:
            model.push_to_hub_gguf(
                args.push_to_hub,
                tokenizer,
                quantization_method=args.gguf_quant,
            )
            print(f"  Pushed GGUF ({args.gguf_quant})")

    print("\nExport complete!")


def get_dir_size(path):
    """Return the size of *path* as a human-readable string (e.g. ``"3.0 KB"``).

    For a directory, the sizes of all files beneath it (found via ``os.walk``)
    are summed. For a regular file, its own size is used. A path that does not
    exist yields ``"0.0 B"``.

    The result is best-effort. Entries whose size cannot be read are skipped
    rather than aborting the caller. Examples are broken symlinks and files
    removed during the walk.

    Units step by 1024 from B through GB; anything larger is reported in TB.
    """
    total = 0
    if os.path.isdir(path):
        for dirpath, _, filenames in os.walk(path):
            for name in filenames:
                try:
                    total += os.path.getsize(os.path.join(dirpath, name))
                except OSError:
                    # This value only feeds a progress message. Don't crash
                    # after a successful export over one unreadable entry.
                    continue
    elif os.path.isfile(path):
        total = os.path.getsize(path)

    size = float(total)
    for unit in ("B", "KB", "MB", "GB"):
        if size < 1024:
            return f"{size:.1f} {unit}"
        size /= 1024
    return f"{size:.1f} TB"


if __name__ == "__main__":
    # Run the export CLI only when executed as a script, not when imported.
    main()