File size: 2,521 Bytes
55af280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ca7e37
55af280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import sys
import logging
import spacy
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Console logging: INFO and above, timestamped, streamed to stdout so that
# container build logs capture the download progress.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Make the project root importable, then prefer the application's own config;
# fall back to env-var/defaults when running outside the app (e.g. standalone).
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
try:
    sys.path.append(_PROJECT_ROOT)
    from app.config import CACHE_BASE_DIR, T5_MODEL_NAME
except ImportError:
    CACHE_BASE_DIR = os.getenv("CACHE_BASE_DIR", "/tmp/opportunity_t5_model_cache")
    T5_MODEL_NAME = "Ayush472/T5QuestionGenerator"
    logger.info(f"Could not import from config, using fallback cache directory: {CACHE_BASE_DIR}")

# Where the downloaded transformers artifacts are persisted for later reuse.
_TRANSFORMERS_CACHE = os.path.join(CACHE_BASE_DIR, "transformers")
OPTIMIZED_TOKENIZER_PATH = os.path.join(_TRANSFORMERS_CACHE, "optimized_tokenizer")
OPTIMIZED_MODEL_PATH = os.path.join(_TRANSFORMERS_CACHE, "optimized_model")

def download_and_cache_models():
    """Download and cache every model this service needs.

    Fetches the T5 question-generation model/tokenizer from the Hugging Face
    Hub (saving them under the optimized cache paths) and the spaCy
    ``en_core_web_sm`` pipeline.

    Exits the process with status 1 on any download failure so callers
    (e.g. a container build step) see a hard failure rather than a
    partially populated cache.
    """
    logger.info("Starting model download and caching process...")
    _download_t5_model()
    _download_spacy_model()
    logger.info("Model download and caching process completed successfully!")


def _download_t5_model():
    """Download the T5 model and tokenizer, then persist them to the optimized cache paths."""
    try:
        logger.info("Downloading T5 model and tokenizer: %s", T5_MODEL_NAME)

        os.makedirs(OPTIMIZED_TOKENIZER_PATH, exist_ok=True)
        os.makedirs(OPTIMIZED_MODEL_PATH, exist_ok=True)

        # HF_TOKEN is optional; only required for gated/private Hub repos.
        hf_token = os.environ.get("HF_TOKEN")
        tokenizer = AutoTokenizer.from_pretrained(T5_MODEL_NAME, token=hf_token)
        model = AutoModelForSeq2SeqLM.from_pretrained(T5_MODEL_NAME, token=hf_token)

        tokenizer.save_pretrained(OPTIMIZED_TOKENIZER_PATH)
        model.save_pretrained(OPTIMIZED_MODEL_PATH)

        logger.info("T5 model and tokenizer downloaded and saved successfully.")
    except Exception as e:
        logger.error("Error downloading T5 model: %s", e, exc_info=True)
        sys.exit(1)  # Exit with error


def _download_spacy_model():
    """Download the spaCy ``en_core_web_sm`` pipeline into spaCy's own cache."""
    try:
        logger.info("Downloading spaCy model: en_core_web_sm")
        spacy.cli.download("en_core_web_sm")
        logger.info("spaCy model 'en_core_web_sm' downloaded successfully.")
    except Exception as e:
        logger.error("Error downloading spaCy model: %s", e, exc_info=True)
        sys.exit(1)  # Exit with error

# Run the full download/caching flow when executed as a standalone script.
if __name__ == "__main__":
    download_and_cache_models()