"""
Main application entry point for Image Captioning API
"""
import os
import sys
import logging

import nltk

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def register_vocabulary_in_main():
    """Register the Vocabulary class in __main__ to help with unpickling.

    Exposes the model-related classes from ``app.image_captioning_service``
    as attributes of the ``__main__`` module.  Presumably the pickled
    artifacts reference these classes as ``__main__.<ClassName>`` -- verify
    against the code that produced them.

    This is best-effort: any failure (including an import error) is logged
    as a warning and swallowed so that start-up can continue.
    """
    try:
        logger.info("Registering Vocabulary class in __main__ module")
        import __main__
        from app.image_captioning_service import (
            Vocabulary,
            ImageCaptioningModel,
            EncoderCNN,
            TransformerDecoder,
            PositionalEncoding,
        )

        # Register classes in main module under their original names.
        exported = {
            'Vocabulary': Vocabulary,
            'ImageCaptioningModel': ImageCaptioningModel,
            'EncoderCNN': EncoderCNN,
            'TransformerDecoder': TransformerDecoder,
            'PositionalEncoding': PositionalEncoding,
        }
        for attr_name, cls in exported.items():
            setattr(__main__, attr_name, cls)

        logger.info("Successfully registered classes in __main__")
    except Exception as e:
        logger.warning(f"Could not register classes in __main__: {e}")


def _add_nltk_search_path(directory):
    """Make sure *directory* is on NLTK's resource search path."""
    # Use an absolute path so a later change of working directory does not
    # invalidate relative entries such as './nltk_data'.
    absolute = os.path.abspath(directory)
    if absolute not in nltk.data.path:
        nltk.data.path.append(absolute)


def setup_nltk():
    """Set up NLTK data directory and ensure punkt tokenizer is available.

    Creates the candidate data directories, registers them on NLTK's search
    path, and looks for the punkt tokenizer.  If it is missing, tries to
    download it into each candidate directory in turn and stops at the first
    success.  Never raises: if every attempt fails, the failure is logged at
    error level and the function returns.
    """
    logger.info("Setting up NLTK...")

    # Create potential NLTK data directories with proper permissions
    nltk_dirs = [
        os.path.expanduser('~/.nltk_data'),
        './nltk_data',
        '/usr/local/share/nltk_data'
    ]

    for directory in nltk_dirs:
        try:
            os.makedirs(directory, exist_ok=True)
        except Exception as e:
            logger.warning(f"Could not create NLTK directory {directory}: {e}")
            continue
        logger.info(f"Created NLTK data directory: {directory}")
        # Without this, data downloaded here on a previous run (or below)
        # would not be visible to nltk.data.find().
        _add_nltk_search_path(directory)

    # Try to find punkt tokenizer
    try:
        nltk.data.find('tokenizers/punkt')
        logger.info("NLTK punkt tokenizer found!")
        return
    except LookupError:
        pass

    # Not found, try to download to different locations
    for directory in nltk_dirs:
        try:
            logger.info(f"Attempting to download punkt tokenizer to {directory}")
            # nltk.download() signals failure through its return value
            # rather than by raising, so the result must be checked.
            downloaded = nltk.download('punkt', download_dir=directory)
        except Exception as e:
            logger.warning(f"Failed to download punkt to {directory}: {e}")
            continue

        if not downloaded:
            logger.warning(
                f"Failed to download punkt to {directory}: "
                "nltk.download reported failure"
            )
            continue

        _add_nltk_search_path(directory)
        logger.info(f"Successfully downloaded punkt tokenizer to {directory}")
        return

    # If we get here, we couldn't download punkt anywhere
    logger.error("Could not download NLTK punkt tokenizer to any location")
    logger.error("The application may not function correctly")


def setup_cache_directories():
    """Create and set up cache directories for PyTorch and other libraries.

    Every step is best-effort: directories that cannot be created or
    chmod-ed are logged as warnings and skipped.  Afterwards ``TORCH_HOME``
    is pointed at the first candidate directory that can be created; any
    pre-existing ``TORCH_HOME`` value is overwritten in that case.
    """
    cache_dirs = [
        '/.cache',
        '/root/.cache',
        '/root/.cache/torch',
        '/home/.cache',
        '/home/.cache/torch',
        '/tmp/.cache',
        '/tmp/.cache/torch'
    ]

    for directory in cache_dirs:
        try:
            os.makedirs(directory, exist_ok=True)
        except Exception as e:
            logger.warning(f"Could not create cache directory {directory}: {e}")
            continue

        # Try to set permissions
        # NOTE(review): 0o777 makes the directory world-writable; presumably
        # intended for containers running as an arbitrary user -- confirm.
        try:
            os.chmod(directory, 0o777)
        except Exception as e:
            logger.warning(f"Could not set permissions for {directory}: {e}")
        else:
            logger.info(f"Created cache directory with permissions: {directory}")

    # Try setting environment variables for torch home
    for cache_dir in ['/home/.cache/torch', '/tmp/.cache/torch', './torch_cache']:
        try:
            os.makedirs(cache_dir, exist_ok=True)
            os.environ['TORCH_HOME'] = cache_dir
        except Exception as e:
            logger.warning(f"Could not use {cache_dir} as TORCH_HOME: {e}")
        else:
            logger.info(f"Set TORCH_HOME to {cache_dir}")
            break

    logger.info(f"TORCH_HOME is set to: {os.environ.get('TORCH_HOME', 'Not set')}")


def ensure_models_exist():
    """Check if model files exist and download if needed.

    Both paths are relative, so they are resolved against the current
    working directory.  If either file is missing, ``download_models`` from
    ``app.download_model`` is called; exceptions from it propagate to the
    caller.
    """
    model_path = "app/models/image_captioning_model.pth"
    vocab_path = "app/models/vocab.pkl"

    if os.path.exists(model_path) and os.path.exists(vocab_path):
        logger.info("Model files found.")
        return

    logger.info("Model files not found. Downloading...")
    # Imported lazily so the downloader is only loaded when it is needed.
    from app.download_model import download_models
    download_models()


if __name__ == "__main__":
    # Setup cache directories
    setup_cache_directories()

    # Setup NLTK
    setup_nltk()

    # Register Vocabulary in main module
    register_vocabulary_in_main()

    # Ensure model files exist
    ensure_models_exist()

    # Run the FastAPI application
    import uvicorn
    from app.api import app

    logger.info("Starting Image Captioning API server...")
    uvicorn.run(app, host="0.0.0.0", port=7860)