# opportunity_t5_model / scripts / prepare_models.py
# ayushmodi001
# changed model
# 7ca7e37
import os
import sys
import logging
import spacy
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Logging: send INFO-and-above records to stdout, each one stamped with the
# time, the logger name and the severity.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format=_LOG_FORMAT,
    level=logging.INFO,
)
# Module-level logger used by everything below.
logger = logging.getLogger(__name__)
# Resolve CACHE_BASE_DIR and T5_MODEL_NAME: prefer the application's own
# config module, otherwise fall back to built-in defaults.
# The parent of this script's directory is appended to sys.path so that the
# `app` package can be found when the script is executed directly.
_parent_of_script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(_parent_of_script_dir)
try:
    from app.config import CACHE_BASE_DIR, T5_MODEL_NAME
except ImportError:
    # Config is not importable: take the cache directory from the
    # CACHE_BASE_DIR environment variable (with a /tmp default) and use the
    # hard-coded model name.
    CACHE_BASE_DIR = os.environ.get("CACHE_BASE_DIR", "/tmp/opportunity_t5_model_cache")
    T5_MODEL_NAME = "Ayush472/T5QuestionGenerator"
    logger.info(f"Could not import from config, using fallback cache directory: {CACHE_BASE_DIR}")
# Destination directories for the saved tokenizer and model; both live under
# the "transformers" sub-directory of the cache root.
_TRANSFORMERS_CACHE_DIR = os.path.join(CACHE_BASE_DIR, "transformers")
OPTIMIZED_TOKENIZER_PATH = os.path.join(_TRANSFORMERS_CACHE_DIR, "optimized_tokenizer")
OPTIMIZED_MODEL_PATH = os.path.join(_TRANSFORMERS_CACHE_DIR, "optimized_model")
def download_and_cache_models():
    """Download the T5 checkpoint and the spaCy pipeline and store them locally.

    Stage 1 loads the tokenizer and the seq2seq model named by
    ``T5_MODEL_NAME`` (passing the ``HF_TOKEN`` environment variable, if set,
    as the access token) and saves them to ``OPTIMIZED_TOKENIZER_PATH`` and
    ``OPTIMIZED_MODEL_PATH``.

    Stage 2 downloads spaCy's ``en_core_web_sm`` model.

    If either stage raises an ``Exception`` the error is logged together with
    its traceback and the process terminates with exit status 1.
    """
    logger.info("Starting model download and caching process...")

    # --- Stage 1: T5 model and tokenizer ---
    try:
        logger.info(f"Downloading T5 model and tokenizer: {T5_MODEL_NAME}")

        # Create both target directories before anything is fetched.
        for target_dir in (OPTIMIZED_TOKENIZER_PATH, OPTIMIZED_MODEL_PATH):
            os.makedirs(target_dir, exist_ok=True)

        # May be None when HF_TOKEN is not set; it is passed through as-is.
        access_token = os.environ.get("HF_TOKEN")

        t5_tokenizer = AutoTokenizer.from_pretrained(T5_MODEL_NAME, token=access_token)
        t5_model = AutoModelForSeq2SeqLM.from_pretrained(T5_MODEL_NAME, token=access_token)

        t5_tokenizer.save_pretrained(OPTIMIZED_TOKENIZER_PATH)
        t5_model.save_pretrained(OPTIMIZED_MODEL_PATH)
        logger.info("T5 model and tokenizer downloaded and saved successfully.")
    except Exception as err:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception(f"Error downloading T5 model: {err}")
        sys.exit(1)  # Exit with error

    # --- Stage 2: spaCy model ---
    # NOTE(review): spaCy's CLI helpers may terminate via sys.exit() on
    # failure; SystemExit is not an Exception subclass, so such a failure
    # would bypass the handler below without being logged here -- confirm.
    try:
        logger.info("Downloading spaCy model: en_core_web_sm")
        spacy.cli.download("en_core_web_sm")
        logger.info("spaCy model 'en_core_web_sm' downloaded successfully.")
    except Exception as err:
        logger.exception(f"Error downloading spaCy model: {err}")
        sys.exit(1)  # Exit with error

    logger.info("Model download and caching process completed successfully!")
# Script entry point: run the download only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    download_and_cache_models()