File size: 3,924 Bytes
6b408d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
Model download script.

Downloads and caches the Wav2Vec2 model used by the VoiceAuth API, and can
optionally save a local copy (see ``--output``).
"""

import argparse
import os
import sys
from pathlib import Path

# Prepend the project root (the parent of this script's directory) to the
# import search path so project packages resolve when the script is run
# directly. NOTE(review): nothing in this script currently imports project
# modules, so this may be vestigial — confirm before removing.
sys.path.insert(0, os.fspath(Path(__file__).parent.parent))


def download_model(

    model_name: str = "facebook/wav2vec2-base",

    output_dir: str | None = None,

    force: bool = False,

) -> None:
    """

    Download and cache the Wav2Vec2 model.



    Args:

        model_name: HuggingFace model name or path

        output_dir: Optional local directory to save model

        force: Force re-download even if cached

    """
    print("\n" + "=" * 60)
    print("VoiceAuth - Model Download")
    print("=" * 60 + "\n")
    print(f"Model: {model_name}")

    if output_dir:
        print(f"Output: {output_dir}")

    print("\nDownloading model components...")
    print("-" * 40)

    try:
        # Import here to avoid slow imports if just checking args
        from transformers import Wav2Vec2ForSequenceClassification
        from transformers import Wav2Vec2Processor

        # Download processor
        print("\n[1/2] Downloading Wav2Vec2Processor...")
        processor = Wav2Vec2Processor.from_pretrained(
            model_name,
            force_download=force,
        )
        print("      [OK] Processor downloaded")

        # Download model
        print("\n[2/2] Downloading Wav2Vec2ForSequenceClassification...")
        model = Wav2Vec2ForSequenceClassification.from_pretrained(
            model_name,
            num_labels=2,
            label2id={"HUMAN": 0, "AI_GENERATED": 1},
            id2label={0: "HUMAN", 1: "AI_GENERATED"},
            force_download=force,
        )
        print("      [OK] Model downloaded")

        # Save to local directory if specified
        if output_dir:
            output_path = Path(output_dir)
            output_path.mkdir(parents=True, exist_ok=True)

            print(f"\nSaving to {output_path}...")
            processor.save_pretrained(output_path)
            model.save_pretrained(output_path)
            print("[OK] Model saved locally")

        print("\n" + "=" * 60)
        print("Download Complete!")
        print("=" * 60)

        # Show cache location
        cache_dir = os.environ.get(
            "HF_HOME",
            os.path.expanduser("~/.cache/huggingface"),
        )
        print(f"\nCache location: {cache_dir}")

        if output_dir:
            print(f"Local copy: {output_dir}")

        print("\nYou can now start the API with:")
        print("  uvicorn app.main:app --reload")
        print()

    except Exception as e:
        print(f"\n[ERROR] Error downloading model: {e}")
        sys.exit(1)


def main() -> None:
    """Main entry point."""
    parser = argparse.ArgumentParser(
        description="Download Wav2Vec2 model for VoiceAuth API",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""

Examples:

  python download_model.py

  python download_model.py --model facebook/wav2vec2-large-xlsr-53

  python download_model.py --output ./models

  python download_model.py --force

        """,
    )

    parser.add_argument(
        "--model",
        type=str,
        default="facebook/wav2vec2-base",
        help="HuggingFace model name (default: facebook/wav2vec2-base)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Optional local directory to save model",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force re-download even if cached",
    )

    args = parser.parse_args()

    download_model(
        model_name=args.model,
        output_dir=args.output,
        force=args.force,
    )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()