"""Pre-download the models this project needs and store them in a local cache.

Running this module as a script downloads the T5 tokenizer/model named by
``T5_MODEL_NAME`` and saves them under ``CACHE_BASE_DIR``, then downloads the
spaCy ``en_core_web_sm`` model. The process exits with status 1 if either
step raises.
"""
import os
import sys
import logging

import spacy
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Import CACHE_BASE_DIR from config if possible, or use a safe default
try:
    # Make the directory two levels above this file importable so that
    # ``app.config`` can be resolved when the script is run directly.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from app.config import CACHE_BASE_DIR, T5_MODEL_NAME
except ImportError:
    # Fallback to a safe location if config can't be imported
    CACHE_BASE_DIR = os.getenv("CACHE_BASE_DIR", "/tmp/opportunity_t5_model_cache")
    T5_MODEL_NAME = "Ayush472/T5QuestionGenerator"
    logger.info(f"Could not import from config, using fallback cache directory: {CACHE_BASE_DIR}")

# Paths for optimized models
OPTIMIZED_TOKENIZER_PATH = os.path.join(CACHE_BASE_DIR, "transformers", "optimized_tokenizer")
OPTIMIZED_MODEL_PATH = os.path.join(CACHE_BASE_DIR, "transformers", "optimized_model")


def _read_hf_token():
    """Return the ``HF_TOKEN`` environment value, or ``None`` when it is blank.

    An unset, empty, or whitespace-only variable is normalised to ``None`` so
    that a blank string is never handed to ``from_pretrained`` as an explicit
    credential. Surrounding whitespace (e.g. a trailing newline from a secrets
    file) is stripped from real tokens.
    """
    return (os.environ.get("HF_TOKEN") or "").strip() or None


def _cache_t5_model():
    """Download the T5 tokenizer and model and save them to the cache paths."""
    logger.info(f"Downloading T5 model and tokenizer: {T5_MODEL_NAME}")
    os.makedirs(OPTIMIZED_TOKENIZER_PATH, exist_ok=True)
    os.makedirs(OPTIMIZED_MODEL_PATH, exist_ok=True)

    hf_token = _read_hf_token()
    tokenizer = AutoTokenizer.from_pretrained(T5_MODEL_NAME, token=hf_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(T5_MODEL_NAME, token=hf_token)

    tokenizer.save_pretrained(OPTIMIZED_TOKENIZER_PATH)
    model.save_pretrained(OPTIMIZED_MODEL_PATH)
    logger.info("T5 model and tokenizer downloaded and saved successfully.")


def _download_spacy_model():
    """Download the spaCy ``en_core_web_sm`` model via spaCy's download command."""
    logger.info("Downloading spaCy model: en_core_web_sm")
    spacy.cli.download("en_core_web_sm")
    logger.info("spaCy model 'en_core_web_sm' downloaded successfully.")


def download_and_cache_models():
    """Download and cache the T5 model/tokenizer and the spaCy model.

    Each step is attempted in order. If a step raises an ``Exception`` the
    error is logged with its traceback and the process exits with status 1,
    so later steps do not run.
    """
    logger.info("Starting model download and caching process...")

    # --- T5 Model and Tokenizer ---
    try:
        _cache_t5_model()
    except Exception as e:
        logger.error(f"Error downloading T5 model: {e}", exc_info=True)
        sys.exit(1)  # Exit with error

    # --- spaCy Model ---
    try:
        _download_spacy_model()
    except Exception as e:
        logger.error(f"Error downloading spaCy model: {e}", exc_info=True)
        sys.exit(1)  # Exit with error

    logger.info("Model download and caching process completed successfully!")


if __name__ == "__main__":
    download_and_cache_models()