"""
Model Download Script.
Downloads and caches the Wav2Vec2 model for VoiceAuth API.
"""
import argparse
import os
import sys
from pathlib import Path
# Add the directory above this script's folder to sys.path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
def download_model(
model_name: str = "facebook/wav2vec2-base",
output_dir: str | None = None,
force: bool = False,
) -> None:
"""
Download and cache the Wav2Vec2 model.
Args:
model_name: HuggingFace model name or path
output_dir: Optional local directory to save model
force: Force re-download even if cached
"""
print("\n" + "=" * 60)
print("VoiceAuth - Model Download")
print("=" * 60 + "\n")
print(f"Model: {model_name}")
if output_dir:
print(f"Output: {output_dir}")
print("\nDownloading model components...")
print("-" * 40)
try:
# Import here to avoid slow imports if just checking args
from transformers import Wav2Vec2ForSequenceClassification
from transformers import Wav2Vec2Processor
# Download processor
print("\n[1/2] Downloading Wav2Vec2Processor...")
processor = Wav2Vec2Processor.from_pretrained(
model_name,
force_download=force,
)
print(" [OK] Processor downloaded")
# Download model
print("\n[2/2] Downloading Wav2Vec2ForSequenceClassification...")
model = Wav2Vec2ForSequenceClassification.from_pretrained(
model_name,
num_labels=2,
label2id={"HUMAN": 0, "AI_GENERATED": 1},
id2label={0: "HUMAN", 1: "AI_GENERATED"},
force_download=force,
)
print(" [OK] Model downloaded")
# Save to local directory if specified
if output_dir:
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
print(f"\nSaving to {output_path}...")
processor.save_pretrained(output_path)
model.save_pretrained(output_path)
print("[OK] Model saved locally")
print("\n" + "=" * 60)
print("Download Complete!")
print("=" * 60)
# Show cache location
cache_dir = os.environ.get(
"HF_HOME",
os.path.expanduser("~/.cache/huggingface"),
)
print(f"\nCache location: {cache_dir}")
if output_dir:
print(f"Local copy: {output_dir}")
print("\nYou can now start the API with:")
print(" uvicorn app.main:app --reload")
print()
except Exception as e:
print(f"\n[ERROR] Error downloading model: {e}")
sys.exit(1)
def main() -> None:
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Download Wav2Vec2 model for VoiceAuth API",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python download_model.py
python download_model.py --model facebook/wav2vec2-large-xlsr-53
python download_model.py --output ./models
python download_model.py --force
""",
)
parser.add_argument(
"--model",
type=str,
default="facebook/wav2vec2-base",
help="HuggingFace model name (default: facebook/wav2vec2-base)",
)
parser.add_argument(
"--output",
type=str,
default=None,
help="Optional local directory to save model",
)
parser.add_argument(
"--force",
action="store_true",
help="Force re-download even if cached",
)
args = parser.parse_args()
download_model(
model_name=args.model,
output_dir=args.output,
force=args.force,
)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()