itssKarthiii's picture
Upload 70 files
6b408d7 verified
"""
Model Download Script.
Downloads and caches the Wav2Vec2 model for VoiceAuth API.
"""
import argparse
import os
import sys
from pathlib import Path
# Put the directory two levels above this file at the front of sys.path so
# that modules located there are found first when imported.
# NOTE(review): no project-local import is visible in this script — confirm
# this path tweak is still required.
sys.path.insert(0, str(Path(__file__).parent.parent))
def download_model(
model_name: str = "facebook/wav2vec2-base",
output_dir: str | None = None,
force: bool = False,
) -> None:
"""
Download and cache the Wav2Vec2 model.
Args:
model_name: HuggingFace model name or path
output_dir: Optional local directory to save model
force: Force re-download even if cached
"""
print("\n" + "=" * 60)
print("VoiceAuth - Model Download")
print("=" * 60 + "\n")
print(f"Model: {model_name}")
if output_dir:
print(f"Output: {output_dir}")
print("\nDownloading model components...")
print("-" * 40)
try:
# Import here to avoid slow imports if just checking args
from transformers import Wav2Vec2ForSequenceClassification
from transformers import Wav2Vec2Processor
# Download processor
print("\n[1/2] Downloading Wav2Vec2Processor...")
processor = Wav2Vec2Processor.from_pretrained(
model_name,
force_download=force,
)
print(" [OK] Processor downloaded")
# Download model
print("\n[2/2] Downloading Wav2Vec2ForSequenceClassification...")
model = Wav2Vec2ForSequenceClassification.from_pretrained(
model_name,
num_labels=2,
label2id={"HUMAN": 0, "AI_GENERATED": 1},
id2label={0: "HUMAN", 1: "AI_GENERATED"},
force_download=force,
)
print(" [OK] Model downloaded")
# Save to local directory if specified
if output_dir:
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
print(f"\nSaving to {output_path}...")
processor.save_pretrained(output_path)
model.save_pretrained(output_path)
print("[OK] Model saved locally")
print("\n" + "=" * 60)
print("Download Complete!")
print("=" * 60)
# Show cache location
cache_dir = os.environ.get(
"HF_HOME",
os.path.expanduser("~/.cache/huggingface"),
)
print(f"\nCache location: {cache_dir}")
if output_dir:
print(f"Local copy: {output_dir}")
print("\nYou can now start the API with:")
print(" uvicorn app.main:app --reload")
print()
except Exception as e:
print(f"\n[ERROR] Error downloading model: {e}")
sys.exit(1)
def main() -> None:
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Download Wav2Vec2 model for VoiceAuth API",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python download_model.py
python download_model.py --model facebook/wav2vec2-large-xlsr-53
python download_model.py --output ./models
python download_model.py --force
""",
)
parser.add_argument(
"--model",
type=str,
default="facebook/wav2vec2-base",
help="HuggingFace model name (default: facebook/wav2vec2-base)",
)
parser.add_argument(
"--output",
type=str,
default=None,
help="Optional local directory to save model",
)
parser.add_argument(
"--force",
action="store_true",
help="Force re-download even if cached",
)
args = parser.parse_args()
download_model(
model_name=args.model,
output_dir=args.output,
force=args.force,
)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()