File size: 2,249 Bytes
8cbe037
 
 
 
145b730
8cbe037
 
 
 
 
145b730
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cbe037
145b730
 
 
 
 
 
8cbe037
145b730
 
8cbe037
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#!/usr/bin/env python3
"""
Model Conversion Script
=======================
This script downloads GenieData resources and converts PyTorch models to ONNX format.
"""

import os
import sys

def download_genie_data():
    """Download the GenieData resources from HuggingFace unless already present.

    The target directory is read from the ``GENIE_DATA_DIR`` environment
    variable (default ``./GenieData``; an empty value is treated as unset).
    If that directory already exists and is non-empty, nothing is downloaded.

    Otherwise the files matching ``GenieData/*`` in the ``High-Logic/Genie``
    model repository are fetched into the *parent* of the target directory,
    so that the repository's top-level ``GenieData/`` folder lands on the
    target path itself rather than in whatever the current working
    directory happens to be.

    Returns:
        None.
    """
    genie_data_dir = os.environ.get("GENIE_DATA_DIR") or "./GenieData"

    # isdir (not exists) so a stray regular file at this path does not make
    # os.listdir raise NotADirectoryError.
    if os.path.isdir(genie_data_dir) and os.listdir(genie_data_dir):
        print(f"GenieData already exists at {genie_data_dir}")
        return

    # Imported only when a download is actually needed, so the "already
    # present" path works without huggingface_hub being importable.
    from huggingface_hub import snapshot_download

    # The allow pattern selects the repo's "GenieData/" folder, which is
    # recreated beneath local_dir. Pointing local_dir at the parent of the
    # target makes the downloaded folder coincide with the directory that
    # was checked above.
    target_dir = os.path.abspath(genie_data_dir)
    parent_dir, leaf_name = os.path.split(target_dir)
    if leaf_name != "GenieData":
        # The folder name comes from the repository layout and cannot be
        # remapped here; tell the user where the data will really end up.
        print(
            f"Warning: GENIE_DATA_DIR ({genie_data_dir}) does not end in "
            f"'GenieData'; resources will be placed in "
            f"{os.path.join(parent_dir, 'GenieData')}",
            file=sys.stderr,
        )
    os.makedirs(parent_dir, exist_ok=True)

    print("🚀 Starting download Genie-TTS resources from HuggingFace...")
    snapshot_download(
        repo_id="High-Logic/Genie",
        repo_type="model",
        allow_patterns="GenieData/*",
        local_dir=parent_dir,
        # Don't use symlinks in Docker.
        # NOTE(review): newer huggingface_hub releases appear to deprecate
        # this argument -- confirm against the pinned version.
        local_dir_use_symlinks=False,
    )
    print("✅ Genie-TTS resources downloaded successfully.")

def main():
    """Download the GenieData resources, then convert a PyTorch model to ONNX.

    Configuration is read from environment variables:

    * ``GENIE_DATA_DIR`` -- GenieData location (default ``/app/GenieData``).
    * ``CKPT_PATH`` -- checkpoint input (default ``/app/temp/model.ckpt``).
    * ``PTH_PATH`` -- weights input (default ``/app/temp/model.pth``).
    * ``OUTPUT_DIR`` -- ONNX output directory
      (default ``/app/models/liang/onnx``); created if missing.

    Both input paths are validated before anything is downloaded, so a
    misconfigured run fails immediately instead of after a network transfer.

    Raises:
        SystemExit: with status 1 if the CKPT or PTH file does not exist.
    """
    # Must be in place BEFORE genie_tts is imported (per the original note,
    # the package checks for GenieData on import). Only fill in the default,
    # so a location exported by the caller is respected like the other paths.
    if not os.environ.get("GENIE_DATA_DIR"):
        os.environ["GENIE_DATA_DIR"] = "/app/GenieData"

    ckpt_path = os.environ.get("CKPT_PATH", "/app/temp/model.ckpt")
    pth_path = os.environ.get("PTH_PATH", "/app/temp/model.pth")
    output_dir = os.environ.get("OUTPUT_DIR", "/app/models/liang/onnx")

    print(f"CKPT path: {ckpt_path}")
    print(f"PTH path: {pth_path}")
    print(f"Output dir: {output_dir}")

    # Report every missing input at once, on stderr, then stop.
    missing_inputs = [
        (label, path)
        for label, path in (("CKPT", ckpt_path), ("PTH", pth_path))
        if not os.path.exists(path)
    ]
    if missing_inputs:
        for label, path in missing_inputs:
            print(f"Error: {label} file not found: {path}", file=sys.stderr)
        sys.exit(1)

    # Step 1: Download GenieData resources
    print("Step 1: Downloading GenieData resources...")
    download_genie_data()

    # Step 2: Import genie_tts only now, after GenieData is available.
    print("Step 2: Starting ONNX conversion...")
    import genie_tts as genie

    os.makedirs(output_dir, exist_ok=True)

    genie.convert_to_onnx(
        torch_ckpt_path=ckpt_path,
        torch_pth_path=pth_path,
        output_dir=output_dir,
    )

    print("ONNX conversion complete!")


# Run the download + conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()