Spaces:
Running
Running
"""
Download TinyLlama Model

This script downloads the TinyLlama model from Hugging Face and prepares it
for fine-tuning on SWIFT MT564 documentation.

Usage:
    python download_tinyllama.py --model_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 --output_dir ./data/models
"""
| import os | |
| import argparse | |
| import logging | |
| from typing import Optional | |
# Configure the root logger once at import time (INFO level, timestamped
# "time - LEVEL - message" records), then bind this module's own logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def parse_args(argv: Optional[list] = None):
    """
    Parse the command-line options of the download script.

    Args:
        argv: Argument strings to parse. When None (the default) argparse
            falls back to the process command line, which is what the script
            entry point relies on. Passing an explicit list lets other code
            and tests drive the parser without touching sys.argv.

    Returns:
        argparse.Namespace with the attributes ``model_name``, ``output_dir``,
        ``use_auth_token``, ``branch`` and ``check_integrity``.
    """
    parser = argparse.ArgumentParser(description="Download TinyLlama model from Hugging Face")
    parser.add_argument(
        "--model_name",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        help="Name of the TinyLlama model on Hugging Face Hub"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./data/models",
        help="Directory to save the downloaded model"
    )
    parser.add_argument(
        "--use_auth_token",
        action="store_true",
        help="Use Hugging Face authentication token for downloading gated models"
    )
    parser.add_argument(
        "--branch",
        type=str,
        default="main",
        help="Branch of the model repository to download from"
    )
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Verify integrity of downloaded files"
    )
    return parser.parse_args(argv)
def download_model(
    model_name: str,
    output_dir: str,
    use_auth_token: bool = False,
    branch: str = "main",
    check_integrity: bool = False
) -> Optional[str]:
    """
    Download model and tokenizer from Hugging Face Hub

    Files are written to ``<output_dir>/<last '/'-separated part of model_name>``.

    Args:
        model_name: Name of the model on Hugging Face Hub
        output_dir: Directory to save the model
        use_auth_token: Whether to authenticate the download. When True, the
            value of the HUGGING_FACE_TOKEN environment variable is sent if it
            is set and non-empty; otherwise ``True`` is forwarded, leaving the
            choice of credential to the Hugging Face libraries.
        branch: Branch (revision) of the model repository
        check_integrity: When True the repository is fetched with
            ``huggingface_hub.snapshot_download``; when False the tokenizer and
            model are loaded through the Transformers auto classes and
            re-saved with ``save_pretrained``.

    Returns:
        Path to the downloaded model or None if download failed
    """
    try:
        # Only pass an auth keyword when authentication was requested, so the
        # anonymous path does not depend on the libraries accepting this name.
        # NOTE(review): newer huggingface_hub / transformers releases call this
        # argument `token`; confirm against the versions pinned for this project.
        if use_auth_token:
            auth_kwargs = {"use_auth_token": os.environ.get("HUGGING_FACE_TOKEN") or True}
        else:
            auth_kwargs = {}

        # Import lazily, and only what the selected path needs, so the script
        # can report a friendly error when a package is missing and the
        # snapshot path does not require torch/transformers at all.
        if check_integrity:
            from huggingface_hub import snapshot_download
        else:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer

        logger.info(f"Downloading model: {model_name}")

        # makedirs creates output_dir as an intermediate directory if needed.
        model_output_dir = os.path.join(output_dir, model_name.split('/')[-1])
        os.makedirs(model_output_dir, exist_ok=True)

        if check_integrity:
            # Option 1: Use snapshot_download for more control
            logger.info("Using snapshot_download with integrity checking")
            snapshot_download(
                repo_id=model_name,
                local_dir=model_output_dir,
                revision=branch,
                **auth_kwargs
            )
        else:
            # Option 2: Use Transformers' download mechanism
            logger.info("Using Transformers' auto classes for downloading")

            # Download and save tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                revision=branch,
                **auth_kwargs
            )
            tokenizer.save_pretrained(model_output_dir)
            logger.info(f"Tokenizer saved to {model_output_dir}")

            # Download and save model. The dtype follows the local hardware:
            # half precision when CUDA is available, full precision otherwise.
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                revision=branch,
                low_cpu_mem_usage=True,
                **auth_kwargs
            )
            model.save_pretrained(model_output_dir)
            logger.info(f"Model saved to {model_output_dir}")

        logger.info(f"Successfully downloaded model to {model_output_dir}")
        return model_output_dir
    except ImportError as e:
        logger.error(f"Required libraries not installed: {e}")
        logger.error("Please install required packages: pip install torch transformers huggingface_hub")
        return None
    except Exception as e:
        # Best-effort contract: report and return None rather than raise, but
        # keep the traceback in the log so failures can be diagnosed.
        logger.exception(f"Error downloading model: {e}")
        return None
def main() -> int:
    """
    Command-line entry point: parse options, download the model, report.

    Returns:
        Process exit status: 0 when download_model returned a path, 1 when it
        reported failure. Previously the function returned None in both cases,
        so a failed download still exited with status 0.
    """
    args = parse_args()

    # Warn early when authentication was requested but the token variable
    # this script advertises is not present in the environment.
    if args.use_auth_token and "HUGGING_FACE_TOKEN" not in os.environ:
        logger.warning("--use_auth_token flag is set but HUGGING_FACE_TOKEN environment variable is not found.")
        logger.warning("You can set it using: export HUGGING_FACE_TOKEN=your_token_here")

    # Download the model
    model_path = download_model(
        model_name=args.model_name,
        output_dir=args.output_dir,
        use_auth_token=args.use_auth_token,
        branch=args.branch,
        check_integrity=args.check_integrity
    )

    if not model_path:
        logger.error("Failed to download the model.")
        return 1

    logger.info(f"Model downloaded successfully to: {model_path}")
    logger.info("You can now use this model for fine-tuning with the train_mt564_model.py script.")
    return 0


if __name__ == "__main__":
    # Propagate the result as the process exit status so shells and CI
    # pipelines can detect a failed download.
    raise SystemExit(main())