# Spaces: Running
# File size: 2,521 Bytes
# 55af280 7ca7e37 55af280
import os
import sys
import logging
import spacy
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Configure root logging: INFO and above, timestamped records written to stdout.
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    handlers=[_stdout_handler],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by everything below.
logger = logging.getLogger(__name__)
# Resolve CACHE_BASE_DIR and T5_MODEL_NAME: prefer the application's config
# module, and fall back to environment/default values when it cannot be imported.
try:
    # Add the parent of this script's directory to sys.path so the `app`
    # package can be located.
    _script_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(os.path.dirname(_script_dir))
    from app.config import CACHE_BASE_DIR, T5_MODEL_NAME
except ImportError:
    # Config is unavailable: use the hard-coded model name and a cache
    # directory taken from the environment (or a /tmp default).
    T5_MODEL_NAME = "Ayush472/T5QuestionGenerator"
    CACHE_BASE_DIR = os.environ.get("CACHE_BASE_DIR", "/tmp/opportunity_t5_model_cache")
    logger.info(f"Could not import from config, using fallback cache directory: {CACHE_BASE_DIR}")
# Destination directories for the saved tokenizer and model; both live under
# the "transformers" subdirectory of the cache root.
_TRANSFORMERS_CACHE_DIR = os.path.join(CACHE_BASE_DIR, "transformers")
OPTIMIZED_TOKENIZER_PATH = os.path.join(_TRANSFORMERS_CACHE_DIR, "optimized_tokenizer")
OPTIMIZED_MODEL_PATH = os.path.join(_TRANSFORMERS_CACHE_DIR, "optimized_model")
def _cache_t5_model():
    """Download the T5 tokenizer and model and save both to the optimized paths."""
    logger.info(f"Downloading T5 model and tokenizer: {T5_MODEL_NAME}")
    for target_dir in (OPTIMIZED_TOKENIZER_PATH, OPTIMIZED_MODEL_PATH):
        os.makedirs(target_dir, exist_ok=True)

    # None when HF_TOKEN is unset; passed through to from_pretrained as-is.
    hf_token = os.environ.get("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(T5_MODEL_NAME, token=hf_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(T5_MODEL_NAME, token=hf_token)

    tokenizer.save_pretrained(OPTIMIZED_TOKENIZER_PATH)
    model.save_pretrained(OPTIMIZED_MODEL_PATH)
    logger.info("T5 model and tokenizer downloaded and saved successfully.")


def _download_spacy_model():
    """Download the spaCy ``en_core_web_sm`` model through spaCy's CLI helper."""
    logger.info("Downloading spaCy model: en_core_web_sm")
    spacy.cli.download("en_core_web_sm")
    logger.info("spaCy model 'en_core_web_sm' downloaded successfully.")


def download_and_cache_models():
    """Download the T5 model/tokenizer and the spaCy model, caching them locally.

    The T5 tokenizer and model named by ``T5_MODEL_NAME`` are saved to
    ``OPTIMIZED_TOKENIZER_PATH`` and ``OPTIMIZED_MODEL_PATH``; afterwards the
    spaCy ``en_core_web_sm`` model is downloaded.

    Raises:
        SystemExit: with status 1 if either stage raises an exception, or with
            spaCy's own status if its download helper exits the interpreter.
    """
    logger.info("Starting model download and caching process...")

    # --- T5 Model and Tokenizer ---
    try:
        _cache_t5_model()
    except Exception as e:
        logger.error(f"Error downloading T5 model: {e}", exc_info=True)
        sys.exit(1)  # Exit with error

    # --- spaCy Model ---
    try:
        _download_spacy_model()
    except SystemExit as e:
        # spaCy's CLI helper can abort via sys.exit() rather than raising a
        # regular exception; SystemExit is not an Exception subclass, so it
        # would otherwise skip the error log below. Log it, then re-raise so
        # the original exit status is preserved.
        if e.code not in (0, None):
            logger.error(f"Error downloading spaCy model: exited with status {e.code}")
        raise
    except Exception as e:
        logger.error(f"Error downloading spaCy model: {e}", exc_info=True)
        sys.exit(1)  # Exit with error

    logger.info("Model download and caching process completed successfully!")
# Script entry point: run the download only when executed directly, not on import.
if __name__ == "__main__":
    download_and_cache_models()
|