| | """
|
| | Model Download Script.
|
| |
|
| | Downloads and caches the Wav2Vec2 model for VoiceAuth API.
|
| | """
|
| |
|
| | import argparse
|
| | import os
|
| | import sys
|
| | from pathlib import Path
|
| |
|
| |
|
# Prepend the parent of this script's directory to the import path
# (presumably the project root, so project-local packages can be imported
# when this file is run directly as a script — confirm against callers).
sys.path[0:0] = [str(Path(__file__).parent.parent)]
|
| |
|
| |
|
| | def download_model(
|
| | model_name: str = "facebook/wav2vec2-base",
|
| | output_dir: str | None = None,
|
| | force: bool = False,
|
| | ) -> None:
|
| | """
|
| | Download and cache the Wav2Vec2 model.
|
| |
|
| | Args:
|
| | model_name: HuggingFace model name or path
|
| | output_dir: Optional local directory to save model
|
| | force: Force re-download even if cached
|
| | """
|
| | print("\n" + "=" * 60)
|
| | print("VoiceAuth - Model Download")
|
| | print("=" * 60 + "\n")
|
| | print(f"Model: {model_name}")
|
| |
|
| | if output_dir:
|
| | print(f"Output: {output_dir}")
|
| |
|
| | print("\nDownloading model components...")
|
| | print("-" * 40)
|
| |
|
| | try:
|
| |
|
| | from transformers import Wav2Vec2ForSequenceClassification
|
| | from transformers import Wav2Vec2Processor
|
| |
|
| |
|
| | print("\n[1/2] Downloading Wav2Vec2Processor...")
|
| | processor = Wav2Vec2Processor.from_pretrained(
|
| | model_name,
|
| | force_download=force,
|
| | )
|
| | print(" [OK] Processor downloaded")
|
| |
|
| |
|
| | print("\n[2/2] Downloading Wav2Vec2ForSequenceClassification...")
|
| | model = Wav2Vec2ForSequenceClassification.from_pretrained(
|
| | model_name,
|
| | num_labels=2,
|
| | label2id={"HUMAN": 0, "AI_GENERATED": 1},
|
| | id2label={0: "HUMAN", 1: "AI_GENERATED"},
|
| | force_download=force,
|
| | )
|
| | print(" [OK] Model downloaded")
|
| |
|
| |
|
| | if output_dir:
|
| | output_path = Path(output_dir)
|
| | output_path.mkdir(parents=True, exist_ok=True)
|
| |
|
| | print(f"\nSaving to {output_path}...")
|
| | processor.save_pretrained(output_path)
|
| | model.save_pretrained(output_path)
|
| | print("[OK] Model saved locally")
|
| |
|
| | print("\n" + "=" * 60)
|
| | print("Download Complete!")
|
| | print("=" * 60)
|
| |
|
| |
|
| | cache_dir = os.environ.get(
|
| | "HF_HOME",
|
| | os.path.expanduser("~/.cache/huggingface"),
|
| | )
|
| | print(f"\nCache location: {cache_dir}")
|
| |
|
| | if output_dir:
|
| | print(f"Local copy: {output_dir}")
|
| |
|
| | print("\nYou can now start the API with:")
|
| | print(" uvicorn app.main:app --reload")
|
| | print()
|
| |
|
| | except Exception as e:
|
| | print(f"\n[ERROR] Error downloading model: {e}")
|
| | sys.exit(1)
|
| |
|
| |
|
def main() -> None:
    """Command-line entry point: parse the CLI flags and run download_model.

    Flags:
        --model   HuggingFace model name (defaults to facebook/wav2vec2-base).
        --output  Optional directory to also save the model into.
        --force   Re-download even when a cached copy exists.
    """
    examples = """
Examples:
python download_model.py
python download_model.py --model facebook/wav2vec2-large-xlsr-53
python download_model.py --output ./models
python download_model.py --force
"""
    cli = argparse.ArgumentParser(
        description="Download Wav2Vec2 model for VoiceAuth API",
        epilog=examples,
        # Raw formatter keeps the hand-laid-out examples exactly as written.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    cli.add_argument(
        "--model",
        type=str,
        default="facebook/wav2vec2-base",
        help="HuggingFace model name (default: facebook/wav2vec2-base)",
    )
    cli.add_argument(
        "--output",
        type=str,
        default=None,
        help="Optional local directory to save model",
    )
    cli.add_argument(
        "--force",
        action="store_true",
        help="Force re-download even if cached",
    )

    opts = cli.parse_args()
    download_model(model_name=opts.model, output_dir=opts.output, force=opts.force)
|
| |
|
| |
|
# Run the CLI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
|
| |
|