"""
Model Download Script for Carsa Translation API

This script downloads the required Helsinki-NLP translation models,
with error handling and automatic retries for failed downloads.
"""
|
|
| from transformers import pipeline |
| import torch |
| import logging |
| import time |
|
|
| |
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


# Language key -> Hugging Face model id. Every entry is an English-source
# "opus-mt-en-*" model; insertion order is the order models are downloaded.
MODELS = dict(
    twi="Helsinki-NLP/opus-mt-en-tw",
    ga="Helsinki-NLP/opus-mt-en-gaa",
    ewe="Helsinki-NLP/opus-mt-en-ee",
    hausa="Helsinki-NLP/opus-mt-en-ha",
)
|
|
def download_model(lang, model_name, max_retries=3):
    """
    Download a single model, retrying with a linearly growing delay on failure.

    The model is fetched by constructing a ``transformers`` translation
    pipeline on the CPU (``device=-1``); the pipeline object is discarded,
    only the download side effect is wanted.

    Args:
        lang (str): Language key (e.g. ``"twi"``).
        model_name (str): Hugging Face model name.
        max_retries (int): Maximum number of download attempts, including the
            first one. Values below 1 mean no attempt is made.

    Returns:
        bool: True if the pipeline was built successfully, False otherwise.
    """
    # Suffix used in the pipeline task name: "twi" is shortened to "tw",
    # every other key is used as-is. Computed once, it does not vary per attempt.
    task = f"translation_en_to_{'tw' if lang == 'twi' else lang}"

    for attempt in range(1, max_retries + 1):
        logger.info(f"Downloading model for '{lang}' (attempt {attempt}/{max_retries})...")
        try:
            pipeline(
                task,
                model=model_name,
                device=-1,
            )
        except Exception as e:  # any network/HTTP/loading error is worth a retry
            logger.error(f"❌ Failed to download model for '{lang}' (attempt {attempt}): {str(e)}")
            if attempt < max_retries:
                # Linear backoff: 5s, 10s, 15s, ...
                wait_time = attempt * 5
                logger.info(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
        else:
            logger.info(f"✅ Successfully downloaded model for '{lang}'")
            return True

    logger.error(f"Failed to download model for '{lang}' after {max_retries} attempts")
    return False
|
|
def main():
    """
    Download every model listed in ``MODELS`` and report the outcome.

    Returns:
        int: Process exit status, 0 if every model was downloaded and 1 if
        at least one download failed.
    """
    logger.info("Starting model download process...")

    # Informational only: download_model always passes device=-1 (CPU),
    # so this value does not influence the download.
    device_name = "GPU" if torch.cuda.is_available() else "CPU"
    logger.info(f"Detected device: {device_name} (models are loaded on CPU for download)")

    failed = []
    for lang, model_name in MODELS.items():
        logger.info(f"Processing model for language: {lang}")
        if not download_model(lang, model_name):
            # download_model has already logged the per-attempt failures;
            # just remember the language for the final summary.
            failed.append(lang)

    total_count = len(MODELS)
    success_count = total_count - len(failed)
    logger.info(f"Download process complete: {success_count}/{total_count} models downloaded successfully")

    if not failed:
        logger.info("🎉 All models downloaded successfully! You can now run the FastAPI application.")
        return 0

    logger.warning(
        f"⚠️ Only {success_count}/{total_count} models were downloaded "
        f"(failed: {', '.join(failed)}). Some features may not work."
    )
    return 1


if __name__ == "__main__":
    # Propagate the status to the shell so setup scripts / CI notice failures.
    raise SystemExit(main())
|
|
|
|