# image-captioning-api / app/download_model.py
# dixisouls — Initial Commit (a0c5c81)
import os
import sys
from huggingface_hub import hf_hub_download
import shutil
import logging
# Module-wide logger for this script's progress and error messages.
logger = logging.getLogger(name=__name__)
# Emit INFO and above. basicConfig only installs a handler when the root
# logger has none yet, so an embedding application's configuration wins.
logging.basicConfig(level=logging.INFO)
def download_models(
    *,
    repo_id="dixisouls/image-captioning-model",
    models_dir="app/models",
):
    """Download the captioning model weights and vocabulary from the Hugging Face Hub.

    Both ``image_captioning_model.pth`` and ``vocab.pkl`` are fetched with
    ``hf_hub_download``, which returns a local (cached) path. Each file is then
    copied into *models_dir*. Afterwards the function makes a best-effort
    attempt to write ``vocab_fixed.pkl`` via
    ``app.fix_vocab_pickle.fix_vocab_pickle``. A failure in that last step is
    only logged as a warning.

    Args:
        repo_id: Hub model repository to download from.
        models_dir: Destination directory. A relative path resolves against
            the current working directory. Created if it does not exist.

    Raises:
        SystemExit: with status 1 if any download or copy step fails.
        OSError: if *models_dir* cannot be created. This step sits outside the
            guarded region, as in the original script.
    """
    logger.info("Downloading model files...")

    # Create directories if they don't exist
    os.makedirs(models_dir, exist_ok=True)

    # Forward slashes keep the logged paths identical to the original
    # messages; Python accepts them on every platform.
    model_dest = f"{models_dir}/image_captioning_model.pth"
    vocab_dest = f"{models_dir}/vocab.pkl"
    fixed_vocab_dest = f"{models_dir}/vocab_fixed.pkl"

    try:
        logger.info("Downloading model from %s...", repo_id)
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename="image_captioning_model.pth",
            repo_type="model",
        )

        logger.info("Downloading vocabulary from %s...", repo_id)
        vocab_path = hf_hub_download(
            repo_id=repo_id,
            filename="vocab.pkl",
            repo_type="model",
        )

        # hf_hub_download hands back a path in its own storage location; copy
        # the files to the fixed locations the app expects.
        shutil.copy(model_path, model_dest)
        shutil.copy(vocab_path, vocab_dest)
        logger.info("Model downloaded successfully to %s", model_dest)
        logger.info("Vocabulary downloaded successfully to %s", vocab_dest)

        # Best-effort: create a fixed vocabulary file if the helper is
        # available. The import is local, so a missing or failing helper only
        # produces a warning instead of aborting the whole download.
        try:
            from app.fix_vocab_pickle import fix_vocab_pickle

            # NOTE(review): vocab.pkl is a pickle fetched from a remote repo.
            # If fix_vocab_pickle unpickles it (presumably it does — confirm),
            # that runs arbitrary code from the file. Only point repo_id at
            # trusted repositories.
            fixed_vocab = fix_vocab_pickle(vocab_dest, fixed_vocab_dest)
            if fixed_vocab:
                logger.info("Created fixed vocabulary file at %s", fixed_vocab_dest)
        except Exception as e:
            logger.warning("Could not create fixed vocabulary file: %s", e)
    except Exception as e:
        # Script boundary: record the full traceback, then fail the process.
        logger.exception("Error downloading model files: %s", e)
        sys.exit(1)
if __name__ == "__main__":
    # Only trigger the download when run as a script, never on import.
    download_models()