# carsa_api / download_models.py
# Author: athmontech
# Initial commit: Carsa AI Backend for Hugging Face Spaces
# Commit: d01de5d
"""
Model Download Script for Carsa Translation API
This script downloads the required Helsinki-NLP translation models
with proper error handling and resume capability.
"""
from transformers import pipeline
import torch
import logging
import time
# Configure logging: timestamped, leveled lines so long downloads can be monitored.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Define the models to download.
# Maps an internal language key to its Helsinki-NLP Opus-MT checkpoint
# (English -> target language). NOTE(review): the keys and the model-name
# suffixes differ for some entries (e.g. "ga" -> "gaa", "ewe" -> "ee");
# the keys are presumably the API-facing language codes — confirm against callers.
MODELS = {
    "twi": "Helsinki-NLP/opus-mt-en-tw",
    "ga": "Helsinki-NLP/opus-mt-en-gaa",
    "ewe": "Helsinki-NLP/opus-mt-en-ee",
    "hausa": "Helsinki-NLP/opus-mt-en-ha"
}
def download_model(lang, model_name, max_retries=3):
    """
    Download a single translation model with retry logic.

    Args:
        lang (str): Language key used in MODELS (e.g. "twi").
        model_name (str): Hugging Face model name
            (e.g. "Helsinki-NLP/opus-mt-en-tw").
        max_retries (int): Maximum number of download attempts.

    Returns:
        bool: True if the model was downloaded and loaded successfully,
        False once all attempts are exhausted.
    """
    # Derive the target-language code from the checkpoint name itself
    # (e.g. "Helsinki-NLP/opus-mt-en-gaa" -> "gaa") so the pipeline task
    # always matches the model, instead of special-casing "twi" and
    # silently mismatching "ga"/"ewe" (whose model codes are "gaa"/"ee").
    target_code = model_name.rsplit("-", 1)[-1]
    task = f"translation_en_to_{target_code}"

    for attempt in range(max_retries):
        try:
            logger.info(f"Downloading model for '{lang}' (attempt {attempt + 1}/{max_retries})...")
            # Download and load the model. The Hugging Face cache keeps
            # partial files, so an interrupted download resumes automatically.
            pipeline(
                task,
                model=model_name,
                device=-1  # Use CPU for downloading to avoid GPU memory issues
            )
            logger.info(f"✅ Successfully downloaded model for '{lang}'")
            return True
        except Exception as e:
            # Broad catch is intentional: any failure (network, disk, Hub)
            # should trigger a retry rather than abort the whole script.
            logger.error(f"❌ Failed to download model for '{lang}' (attempt {attempt + 1}): {str(e)}")
            if attempt < max_retries - 1:
                wait_time = (attempt + 1) * 5  # Linear backoff: 5s, 10s, 15s, ...
                logger.info(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)

    logger.error(f"Failed to download model for '{lang}' after {max_retries} attempts")
    return False
def main():
    """Download every model listed in MODELS and report overall success.

    Logs a per-language result, then a summary line; warns if any model
    failed so the operator knows some API features will be unavailable.
    """
    logger.info("Starting model download process...")

    # Report which device is available; downloads themselves run on CPU
    # (download_model forces device=-1).
    device_name = "GPU" if torch.cuda.is_available() else "CPU"
    logger.info(f"Using device: {device_name}")

    total_count = len(MODELS)
    success_count = 0
    for lang, model_name in MODELS.items():
        logger.info(f"Processing model for language: {lang}")
        if download_model(lang, model_name):
            success_count += 1
        else:
            logger.error(f"Failed to download model for {lang}")

    logger.info(f"Download process complete: {success_count}/{total_count} models downloaded successfully")
    if success_count == total_count:
        # Note: original source had this emoji mojibake-encoded ("πŸŽ‰"); repaired to 🎉.
        logger.info("🎉 All models downloaded successfully! You can now run the FastAPI application.")
    else:
        logger.warning(f"⚠️ Only {success_count}/{total_count} models were downloaded. Some features may not work.")
if __name__ == "__main__":
main()