Spaces:
Running
Running
| import os | |
| import sys | |
| import logging | |
| import spacy | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
# Configure root logging once for this script: every record at INFO level or
# above goes to stdout with a timestamp, logger name and level prefix.
_stdout_handler = logging.StreamHandler(stream=sys.stdout)
logging.basicConfig(
    handlers=[_stdout_handler],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by the rest of this file.
logger = logging.getLogger(__name__)
# Resolve CACHE_BASE_DIR and T5_MODEL_NAME from the app config when it can be
# imported; otherwise fall back to environment variables with safe defaults.
try:
    # Put the parent of this script's directory on sys.path so that
    # ``app.config`` can be imported when this file is run as a standalone script.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from app.config import CACHE_BASE_DIR, T5_MODEL_NAME
except ImportError as import_err:
    # ImportError covers both a missing/broken app.config module and a config
    # that lacks one of the two names; the reason is logged so a genuinely
    # broken config is not silently mistaken for an absent one.
    CACHE_BASE_DIR = os.getenv("CACHE_BASE_DIR", "/tmp/opportunity_t5_model_cache")
    # Overridable via the environment, consistent with CACHE_BASE_DIR above.
    T5_MODEL_NAME = os.getenv("T5_MODEL_NAME", "Ayush472/T5QuestionGenerator")
    logger.info(
        "Could not import from config (%s), using fallback cache directory: %s",
        import_err,
        CACHE_BASE_DIR,
    )

# Target directories for the saved tokenizer and model, both under
# <CACHE_BASE_DIR>/transformers/.
OPTIMIZED_TOKENIZER_PATH = os.path.join(CACHE_BASE_DIR, "transformers", "optimized_tokenizer")
OPTIMIZED_MODEL_PATH = os.path.join(CACHE_BASE_DIR, "transformers", "optimized_model")
def _download_t5(hf_token):
    """Fetch the T5 tokenizer and model from the Hub and save them to disk.

    Args:
        hf_token: Hugging Face access token, or None to pass no explicit token.

    The tokenizer is written to OPTIMIZED_TOKENIZER_PATH and the model to
    OPTIMIZED_MODEL_PATH. Any exception propagates to the caller.
    """
    logger.info("Downloading T5 model and tokenizer: %s", T5_MODEL_NAME)
    os.makedirs(OPTIMIZED_TOKENIZER_PATH, exist_ok=True)
    os.makedirs(OPTIMIZED_MODEL_PATH, exist_ok=True)
    tokenizer = AutoTokenizer.from_pretrained(T5_MODEL_NAME, token=hf_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(T5_MODEL_NAME, token=hf_token)
    tokenizer.save_pretrained(OPTIMIZED_TOKENIZER_PATH)
    model.save_pretrained(OPTIMIZED_MODEL_PATH)
    logger.info("T5 model and tokenizer downloaded and saved successfully.")


def _download_spacy(model_name="en_core_web_sm"):
    """Download a spaCy pipeline package via ``spacy.cli.download``.

    Raises:
        RuntimeError: if spaCy's download command exits with an error.
    """
    from spacy.cli import download as spacy_download

    logger.info("Downloading spaCy model: %s", model_name)
    try:
        spacy_download(model_name)
    except SystemExit as exc:
        # spaCy's CLI signals failure by calling sys.exit(). SystemExit is not
        # an Exception subclass, so without this conversion the caller's
        # ``except Exception`` would never see (or log) the failure.
        raise RuntimeError(
            f"spacy.cli.download({model_name!r}) exited with status {exc.code}"
        ) from exc
    logger.info("spaCy model '%s' downloaded successfully.", model_name)


def download_and_cache_models():
    """Download every model the app needs and cache it on local disk.

    Steps:
      1. The T5 tokenizer and model named by ``T5_MODEL_NAME`` are loaded with
         ``from_pretrained`` (using ``HF_TOKEN`` from the environment when it
         is non-empty). They are then saved under ``OPTIMIZED_TOKENIZER_PATH``
         and ``OPTIMIZED_MODEL_PATH``.
      2. The spaCy ``en_core_web_sm`` pipeline is fetched via
         ``spacy.cli.download``.

    On any failure the error is logged with its traceback and the process
    exits with status 1 via ``sys.exit(1)``.
    """
    logger.info("Starting model download and caching process...")

    # --- T5 Model and Tokenizer ---
    try:
        # Treat an empty HF_TOKEN (e.g. a declared but unset secret) the same
        # as an unset one instead of passing a blank string as the token.
        _download_t5(os.environ.get("HF_TOKEN") or None)
    except Exception as e:
        logger.exception("Error downloading T5 model: %s", e)
        sys.exit(1)  # Exit with error

    # --- spaCy Model ---
    try:
        _download_spacy()
    except Exception as e:
        logger.exception("Error downloading spaCy model: %s", e)
        sys.exit(1)  # Exit with error

    logger.info("Model download and caching process completed successfully!")
def _main():
    """Script entry point: download and cache every model the app uses."""
    download_and_cache_models()


# Only trigger the (network-heavy) download when run directly, not on import.
if __name__ == "__main__":
    _main()