Spaces:
Sleeping
Sleeping
def register_vocabulary_in_main():
    """Register the captioning classes on ``__main__`` to help with unpickling.

    Pickle records a class by its module path. Artifacts saved from a script
    that defined these classes at top level therefore refer to
    ``__main__.<Name>``. Presumably that is how the vocabulary / model
    checkpoint were produced (confirm against the training code). Aliasing
    the real classes from ``app.image_captioning_service`` onto ``__main__``
    lets those lookups resolve.

    Best-effort: any failure (e.g. the service module failing to import) is
    logged as a warning and swallowed.
    """
    class_names = (
        'Vocabulary',
        'ImageCaptioningModel',
        'EncoderCNN',
        'TransformerDecoder',
        'PositionalEncoding',
    )
    try:
        logger.info("Registering Vocabulary class in __main__ module")
        import __main__
        import app.image_captioning_service as service

        # Resolve every class before touching __main__, so a missing name
        # leaves __main__ unchanged (all-or-nothing).
        classes = {name: getattr(service, name) for name in class_names}
        for name, cls in classes.items():
            setattr(__main__, name, cls)
        logger.info("Successfully registered classes in __main__")
    except Exception as e:
        logger.warning(f"Could not register classes in __main__: {e}")
def setup_nltk():
    """Set up NLTK data directories and ensure the punkt tokenizer is available.

    Steps:
    1. Create a few candidate NLTK data directories and add each one that
       exists to ``nltk.data.path``.
    2. If ``tokenizers/punkt`` can already be found, return.
    3. Otherwise try to download ``punkt`` into each candidate in turn,
       stopping at the first success.

    Never raises for download or directory failures: they are logged, and
    the application is left to start anyway.
    """
    logger.info("Setting up NLTK...")
    nltk_dirs = [
        os.path.expanduser('~/.nltk_data'),
        './nltk_data',
        '/usr/local/share/nltk_data'
    ]

    def _add_to_search_path(directory):
        # nltk.data.find() only searches nltk.data.path. '~/.nltk_data' and
        # './nltk_data' are not on NLTK's default path, so without this a
        # punkt copy stored there would never be found (and would be
        # re-downloaded on every start).
        if directory not in nltk.data.path:
            nltk.data.path.append(directory)

    for directory in nltk_dirs:
        try:
            os.makedirs(directory, exist_ok=True)
        except OSError as e:
            logger.warning(f"Could not create NLTK directory {directory}: {e}")
            continue
        logger.info(f"Created NLTK data directory: {directory}")
        _add_to_search_path(directory)

    # Try to find an existing punkt tokenizer first.
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        pass
    else:
        logger.info("NLTK punkt tokenizer found!")
        return

    # Not found: try each location until one download succeeds.
    for directory in nltk_dirs:
        logger.info(f"Attempting to download punkt tokenizer to {directory}")
        try:
            downloaded = nltk.download('punkt', download_dir=directory)
        except Exception as e:
            logger.warning(f"Failed to download punkt to {directory}: {e}")
            continue
        # nltk.download() reports most failures by returning False rather
        # than raising (unless raise_on_error=True), so check the result.
        if not downloaded:
            logger.warning(
                f"Failed to download punkt to {directory}: nltk.download returned False"
            )
            continue
        _add_to_search_path(directory)
        logger.info(f"Successfully downloaded punkt tokenizer to {directory}")
        return

    # If we get here, we couldn't download punkt anywhere.
    logger.error("Could not download NLTK punkt tokenizer to any location")
    logger.error("The application may not function correctly")
| """ | |
| Main application entry point for Image Captioning API | |
| """ | |
| import os | |
| import sys | |
| import logging | |
| import nltk | |
# Root logging setup: INFO and above, each record prefixed with timestamp,
# logger name and level. basicConfig is a no-op if the root logger already
# has handlers.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by the setup helpers in this file.
logger = logging.getLogger(__name__)
# Set up writable cache directories (PyTorch's TORCH_HOME and friends)
def setup_cache_directories():
    """Create cache directories for PyTorch and other libraries and set TORCH_HOME.

    Step 1: create each directory in a fixed candidate list and try to
    ``chmod`` it to 0o777. Creation and permission failures are logged and
    skipped.

    Step 2: take the first of ``/home/.cache/torch``, ``/tmp/.cache/torch`` and
    ``./torch_cache`` that exists (or can be created) and is writable by this
    process. Export it as ``TORCH_HOME``, overwriting any existing value. If
    none qualifies, ``TORCH_HOME`` is left as it was.
    """
    cache_dirs = [
        '/.cache',
        '/root/.cache',
        '/root/.cache/torch',
        '/home/.cache',
        '/home/.cache/torch',
        '/tmp/.cache',
        '/tmp/.cache/torch'
    ]
    for directory in cache_dirs:
        try:
            os.makedirs(directory, exist_ok=True)
        except OSError as e:
            logger.warning(f"Could not create cache directory {directory}: {e}")
            continue
        # NOTE(review): 0o777 makes the directory world-writable; presumably
        # intended so a non-root runtime user can write here. Confirm.
        try:
            os.chmod(directory, 0o777)
        except OSError as e:
            logger.warning(f"Could not set permissions for {directory}: {e}")
        else:
            logger.info(f"Created cache directory with permissions: {directory}")

    # Pick the first usable directory for TORCH_HOME.
    for cache_dir in ['/home/.cache/torch', '/tmp/.cache/torch', './torch_cache']:
        try:
            os.makedirs(cache_dir, exist_ok=True)
        except OSError as e:
            logger.warning(f"Could not use {cache_dir} as TORCH_HOME: {e}")
            continue
        # makedirs(exist_ok=True) also succeeds on an existing directory we
        # cannot write to, so check writability explicitly before using it.
        if not os.access(cache_dir, os.W_OK | os.X_OK):
            logger.warning(
                f"Could not use {cache_dir} as TORCH_HOME: directory is not writable"
            )
            continue
        os.environ['TORCH_HOME'] = cache_dir
        logger.info(f"Set TORCH_HOME to {cache_dir}")
        break
    logger.info(f"TORCH_HOME is set to: {os.environ.get('TORCH_HOME', 'Not set')}")
# Download the model artifacts when either one is missing
def ensure_models_exist():
    """Make sure the captioning model weights and vocabulary are on disk.

    Checks ``app/models/image_captioning_model.pth`` and
    ``app/models/vocab.pkl`` (relative to the current working directory).
    If either is missing, calls ``app.download_model.download_models()``.
    Otherwise only logs that the files were found.
    """
    required_files = (
        "app/models/image_captioning_model.pth",
        "app/models/vocab.pkl",
    )
    # all() short-circuits on the first missing file, same as the `or` chain.
    if all(os.path.exists(path) for path in required_files):
        logger.info("Model files found.")
        return

    logger.info("Model files not found. Downloading...")
    from app.download_model import download_models
    download_models()
if __name__ == "__main__":
    # Start-up sequence, in order: cache dirs / TORCH_HOME, NLTK punkt,
    # __main__ class aliases for unpickling, then the model files.
    for _startup_step in (
        setup_cache_directories,
        setup_nltk,
        register_vocabulary_in_main,
        ensure_models_exist,
    ):
        _startup_step()

    # Imported only after the preparation above; presumably so the prepared
    # environment is in place before the app modules load. Confirm.
    import uvicorn
    from app.api import app

    logger.info("Starting Image Captioning API server...")
    uvicorn.run(app, host="0.0.0.0", port=7860)