""" Model Download Script. Downloads and caches the Wav2Vec2 model for VoiceAuth API. """ import argparse import os import sys from pathlib import Path # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent.parent)) def download_model( model_name: str = "facebook/wav2vec2-base", output_dir: str | None = None, force: bool = False, ) -> None: """ Download and cache the Wav2Vec2 model. Args: model_name: HuggingFace model name or path output_dir: Optional local directory to save model force: Force re-download even if cached """ print("\n" + "=" * 60) print("VoiceAuth - Model Download") print("=" * 60 + "\n") print(f"Model: {model_name}") if output_dir: print(f"Output: {output_dir}") print("\nDownloading model components...") print("-" * 40) try: # Import here to avoid slow imports if just checking args from transformers import Wav2Vec2ForSequenceClassification from transformers import Wav2Vec2Processor # Download processor print("\n[1/2] Downloading Wav2Vec2Processor...") processor = Wav2Vec2Processor.from_pretrained( model_name, force_download=force, ) print(" [OK] Processor downloaded") # Download model print("\n[2/2] Downloading Wav2Vec2ForSequenceClassification...") model = Wav2Vec2ForSequenceClassification.from_pretrained( model_name, num_labels=2, label2id={"HUMAN": 0, "AI_GENERATED": 1}, id2label={0: "HUMAN", 1: "AI_GENERATED"}, force_download=force, ) print(" [OK] Model downloaded") # Save to local directory if specified if output_dir: output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) print(f"\nSaving to {output_path}...") processor.save_pretrained(output_path) model.save_pretrained(output_path) print("[OK] Model saved locally") print("\n" + "=" * 60) print("Download Complete!") print("=" * 60) # Show cache location cache_dir = os.environ.get( "HF_HOME", os.path.expanduser("~/.cache/huggingface"), ) print(f"\nCache location: {cache_dir}") if output_dir: print(f"Local copy: {output_dir}") print("\nYou can now start the API with:") print(" uvicorn app.main:app --reload") print() except Exception as e: print(f"\n[ERROR] Error downloading model: {e}") sys.exit(1) def main() -> None: """Main entry point.""" parser = argparse.ArgumentParser( description="Download Wav2Vec2 model for VoiceAuth API", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: python download_model.py python download_model.py --model facebook/wav2vec2-large-xlsr-53 python download_model.py --output ./models python download_model.py --force """, ) parser.add_argument( "--model", type=str, default="facebook/wav2vec2-base", help="HuggingFace model name (default: facebook/wav2vec2-base)", ) parser.add_argument( "--output", type=str, default=None, help="Optional local directory to save model", ) parser.add_argument( "--force", action="store_true", help="Force re-download even if cached", ) args = parser.parse_args() download_model( model_name=args.model, output_dir=args.output, force=args.force, ) if __name__ == "__main__": main()