"""Image-captioning inference module.

Pipeline: a pretrained ResNet-50 encoder produces a single image embedding,
which a Transformer decoder attends to while generating caption tokens.
Also contains robust loaders for a pickled ``Vocabulary`` and for model
checkpoints that may have been pickled under a different module path.
"""

import os
import torch
from PIL import Image
import torchvision.transforms as transforms
import nltk
import pickle
import warnings
import logging
import math

warnings.filterwarnings("ignore")

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Make sure NLTK tokenizer is available
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    # Try to download to a directory where we have write permissions
    try:
        # Try user home directory first
        nltk.download('punkt', download_dir=os.path.expanduser('~/.nltk_data'))
        logger.info("Downloaded NLTK punkt to user home directory")
    except Exception:
        # Was a bare ``except:`` -- narrowed so SystemExit/KeyboardInterrupt
        # still propagate. Then try current directory.
        try:
            os.makedirs('./nltk_data', exist_ok=True)
            nltk.download('punkt', download_dir='./nltk_data')
            logger.info("Downloaded NLTK punkt to current directory")
        except Exception as e:
            logger.error(f"Failed to download NLTK punkt: {str(e)}")
            # Continue anyway, as we might have the data elsewhere


# Vocabulary class for loading the vocabulary
class Vocabulary:
    """Bidirectional word <-> index mapping used by the caption decoder."""

    def __init__(self):
        self.word2idx = {}  # word -> integer id
        self.idx2word = {}  # integer id -> word
        self.idx = 0        # next id to assign

    def add_word(self, word):
        """Register *word* with the next free index (no-op if already known)."""
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __len__(self):
        return len(self.word2idx)

    def tokenize(self, text):
        """Tokenize text into a list of lower-cased tokens."""
        tokens = nltk.tokenize.word_tokenize(str(text).lower())
        return tokens

    @classmethod
    def load(cls, path):
        """Load vocabulary from pickle file.

        Tries, in order: (1) a custom unpickler that remaps any
        ``Vocabulary`` reference onto this module's class, (2) manual
        extraction of the mapping data from whatever object the pickle
        holds, (3) an external fix-up helper.  Raises ``RuntimeError``
        if every strategy fails.
        """
        try:
            # Strategy 1: Use a custom unpickler with more comprehensive handling
            class CustomUnpickler(pickle.Unpickler):
                def find_class(self, module, name):
                    # The pickle may reference Vocabulary under any module
                    # path (commonly '__main__'); always resolve it to the
                    # class defined in this module.
                    if name == 'Vocabulary':
                        return Vocabulary
                    # Default behavior
                    return super().find_class(module, name)

            with open(path, 'rb') as f:
                return CustomUnpickler(f).load()
        except Exception as e:
            logger.error(f"First loading method failed: {str(e)}")

        try:
            # Strategy 2: Manual recreation of vocabulary object from raw pickle data
            with open(path, 'rb') as f:
                raw_data = pickle.load(f)

            # If it's a dict-like object, we can try to extract the vocabulary data
            if hasattr(raw_data, 'word2idx') and hasattr(raw_data, 'idx2word'):
                # Create a new Vocabulary instance
                vocab = Vocabulary()
                vocab.word2idx = raw_data.word2idx
                vocab.idx2word = raw_data.idx2word
                vocab.idx = raw_data.idx
                return vocab
            else:
                # Create a fresh vocabulary directly from the dictionary data
                vocab = Vocabulary()
                # Try to extract word mappings from whatever structure the pickle has
                if isinstance(raw_data, dict):
                    if 'word2idx' in raw_data and 'idx2word' in raw_data:
                        vocab.word2idx = raw_data['word2idx']
                        vocab.idx2word = raw_data['idx2word']
                        vocab.idx = len(vocab.word2idx)
                        return vocab
                raise ValueError("Could not extract vocabulary data from pickle file")
        except Exception as e:
            logger.error(f"Second loading method failed: {str(e)}")
            # BUGFIX: capture the message now -- Python 3 unbinds ``e`` at
            # the end of an ``except ... as e`` block, so the nested handler
            # below would otherwise leave ``e`` undefined for the final
            # raise (NameError instead of the intended RuntimeError).
            second_error = str(e)

            # Try to use fix_vocab_pickle as a last resort
            try:
                from app.fix_vocab_pickle import fix_vocab_pickle
                fixed_path = path + "_fixed.pkl"
                vocab = fix_vocab_pickle(path, fixed_path)
                if vocab:
                    logger.info(f"Vocabulary fixed and saved to {fixed_path}")
                    return vocab
            except Exception as fix_error:
                logger.error(f"Vocabulary fixing failed: {str(fix_error)}")

            raise RuntimeError(
                f"All vocabulary loading methods failed. Original error: {second_error}"
            )


# Encoder: Pretrained ResNet
class EncoderCNN(torch.nn.Module):
    """ResNet-50 backbone (final FC removed) + projection to ``embed_dim``."""

    def __init__(self, embed_dim):
        super(EncoderCNN, self).__init__()
        # Load pretrained ResNet
        import torchvision.models as models

        # Try different approaches to load ResNet50
        resnet = None

        # Option 1: Try to load the locally saved model
        try:
            logger.info("Trying to load locally saved ResNet50 model...")
            resnet = models.resnet50(pretrained=False)
            local_model_path = "app/models/resnet50.pth"
            if os.path.exists(local_model_path):
                # map_location='cpu' so GPU-saved weights load on CPU-only
                # hosts; the module is moved to the target device later.
                resnet.load_state_dict(torch.load(local_model_path, map_location='cpu'))
                logger.info("Successfully loaded ResNet50 from local file")
            else:
                logger.warning(f"Local ResNet50 model not found at {local_model_path}")
                # Fall back to pretrained model
                resnet = None
        except Exception as e:
            logger.warning(f"Error loading local ResNet50 model: {str(e)}")
            resnet = None

        # Option 2: Try loading with pretrained weights
        if resnet is None:
            try:
                logger.info("Trying to load ResNet50 with pretrained weights...")
                # Set cache directory
                os.makedirs('/tmp/torch_cache', exist_ok=True)
                os.environ['TORCH_HOME'] = '/tmp/torch_cache'
                resnet = models.resnet50(pretrained=True)
                logger.info("Successfully loaded pretrained ResNet50 model")
            except Exception as e:
                logger.warning(f"Error loading pretrained ResNet50: {str(e)}")
                resnet = None

        # Option 3: Fall back to model without pretrained weights
        if resnet is None:
            logger.info("Falling back to ResNet50 without pretrained weights...")
            resnet = models.resnet50(pretrained=False)
            logger.warning("Using ResNet50 WITHOUT pretrained weights - captions may be less accurate")

        # Remove the final FC layer; keep everything up to the avg-pool.
        modules = list(resnet.children())[:-1]
        self.resnet = torch.nn.Sequential(*modules)
        # Project to embedding dimension
        self.fc = torch.nn.Linear(resnet.fc.in_features, embed_dim)
        self.bn = torch.nn.BatchNorm1d(embed_dim)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, images):
        """Encode a batch of images -> (batch, embed_dim) features."""
        with torch.no_grad():  # No gradients for pretrained model
            features = self.resnet(images)
        features = features.reshape(features.size(0), -1)
        features = self.fc(features)
        features = self.bn(features)
        features = self.dropout(features)
        return features


# Positional Encoding for Transformer
class PositionalEncoding(torch.nn.Module):
    """Standard sinusoidal positional encoding (Vaswani et al., 2017)."""

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        # Create positional encoding table of shape (1, max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        # Register buffer (not model parameter)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add positional encodings to ``x`` (batch, seq_len, d_model)."""
        x = x + self.pe[:, :x.size(1), :].to(x.device)
        return x


# Custom Transformer Decoder
class TransformerDecoder(torch.nn.Module):
    """Token embedding + positional encoding + stacked transformer decoder."""

    def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers, dropout=0.1):
        super(TransformerDecoder, self).__init__()
        # NOTE: the original stashed ``self.math = math`` via a local import;
        # the module-level ``math`` import makes that redundant.

        # Embedding layer
        self.embedding = torch.nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim)

        # Transformer decoder layers
        decoder_layer = torch.nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_decoder = torch.nn.TransformerDecoder(
            decoder_layer,
            num_layers=num_layers
        )

        # Output layer
        self.fc = torch.nn.Linear(embed_dim, vocab_size)
        self.dropout = torch.nn.Dropout(dropout)

    def generate_square_subsequent_mask(self, sz):
        """Create a (sz, sz) causal mask: -inf above the diagonal, 0 elsewhere."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, tgt, memory):
        """Decode ``tgt`` token ids attending over ``memory`` -> vocab logits."""
        # Create causal mask so positions cannot attend to future tokens
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

        # Embed tokens (scaled by sqrt(d_model)) and add positional encoding
        tgt = self.embedding(tgt) * math.sqrt(self.embedding.embedding_dim)
        tgt = self.positional_encoding(tgt)
        tgt = self.dropout(tgt)

        # Pass through transformer decoder
        output = self.transformer_decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask
        )

        # Project to vocabulary
        output = self.fc(output)
        return output


# Complete Image Captioning Model
class ImageCaptioningModel(torch.nn.Module):
    """EncoderCNN + TransformerDecoder end-to-end captioning model."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_heads, num_layers):
        super(ImageCaptioningModel, self).__init__()
        # Image encoder
        self.encoder = EncoderCNN(embed_dim)
        # Caption decoder
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            ff_dim=hidden_dim,
            num_layers=num_layers
        )

    def forward(self, images, captions):
        """Teacher-forced forward pass: images + captions -> vocab logits."""
        # Encode images
        img_features = self.encoder(images)
        # Reshape for transformer (batch_size, seq_len, embed_dim).
        # seq_len=1 since we have a single "token" representing the image.
        img_features = img_features.unsqueeze(1)
        # Decode captions, excluding the last token (typically the
        # end-of-sequence marker) from the decoder input.
        outputs = self.decoder(captions[:, :-1], img_features)
        return outputs

    def generate_caption(self, image, vocab, max_length=20):
        """Greedily generate a caption string for a single (C, H, W) image."""
        with torch.no_grad():
            # Encode image -> (1, 1, embed_dim) memory for the decoder
            img_features = self.encoder(image.unsqueeze(0))
            img_features = img_features.unsqueeze(1)

            # Start with the start-of-sequence token.
            # NOTE(review): the '' vocab keys below look like stripped
            # '<start>'/'<end>' tokens -- confirm against the vocabulary
            # actually used in training before changing them.
            current_ids = torch.tensor([[vocab.word2idx['']]], dtype=torch.long).to(image.device)

            # Generate words one by one
            result_caption = []
            for i in range(max_length):
                # Predict next word
                outputs = self.decoder(current_ids, img_features)
                # Get the most likely next word
                _, predicted = outputs[:, -1, :].max(1)
                # Add predicted word to the sequence
                result_caption.append(predicted.item())
                # Break on the end-of-sequence token
                if predicted.item() == vocab.word2idx['']:
                    break
                # Add to current sequence for next iteration
                current_ids = torch.cat([current_ids, predicted.unsqueeze(0)], dim=1)

            # Convert word indices to words
            words = [vocab.idx2word[idx] for idx in result_caption]
            # Remove trailing end token if present
            if words and words[-1] == '':
                words = words[:-1]
            return ' '.join(words)


def load_image(image_path, transform=None):
    """Load an image from disk and return a normalized (3, 224, 224) tensor."""
    image = Image.open(image_path).convert('RGB')
    if transform is None:
        # Default ImageNet preprocessing (matches the ResNet backbone)
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    image = transform(image)
    return image


def generate_caption(
    image_path,
    model_path,
    vocab_path,
    max_length=20,
    device=None
):
    """Generate a caption for an image.

    Loads the vocabulary and model checkpoint from disk, preprocesses the
    image, and returns the generated caption string.  Raises
    ``FileNotFoundError`` if any of the three paths is missing.
    """
    # Set device
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Check if files exist
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at {image_path}")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at {model_path}")
    if not os.path.exists(vocab_path):
        raise FileNotFoundError(f"Vocabulary not found at {vocab_path}")

    # Setup temporary cache directory for torch if needed
    try:
        os.makedirs('/tmp/torch_cache', exist_ok=True)
        os.environ['TORCH_HOME'] = '/tmp/torch_cache'
        logger.info(f"Set TORCH_HOME to /tmp/torch_cache")
    except Exception as e:
        logger.warning(f"Could not set up temporary torch cache: {e}")

    # Load vocabulary
    logger.info(f"Loading vocabulary from {vocab_path}")
    vocab = Vocabulary.load(vocab_path)
    logger.info(f"Loaded vocabulary with {len(vocab)} words")

    # Load model
    # Hyperparameters - must match those used during training
    embed_dim = 512
    hidden_dim = 2048
    num_layers = 6
    num_heads = 8

    # Initialize model
    logger.info("Initializing model")
    model = ImageCaptioningModel(
        vocab_size=len(vocab),
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        num_layers=num_layers
    ).to(device)

    # Load model weights
    logger.info(f"Loading model weights from {model_path}")
    try:
        # First try our custom loader
        try:
            logger.info("Trying custom model loader...")

            # Custom unpickler that remaps classes pickled under '__main__'
            # onto the classes defined in this module.
            # NOTE(review): plain pickle cannot read torch's zip-based
            # checkpoint format -- this path only works for checkpoints
            # saved with raw pickle; torch.load below is the usual path.
            class CustomUnpickler(pickle.Unpickler):
                def find_class(self, module, name):
                    # If it's looking for the Vocabulary class in __main__
                    if name == 'Vocabulary':
                        # Return our current Vocabulary class
                        return Vocabulary
                    if module == '__main__':
                        if name == 'ImageCaptioningModel':
                            return ImageCaptioningModel
                        if name == 'EncoderCNN':
                            return EncoderCNN
                        if name == 'TransformerDecoder':
                            return TransformerDecoder
                        if name == 'PositionalEncoding':
                            return PositionalEncoding
                    # Use the normal behavior for everything else
                    return super().find_class(module, name)

            # Use a custom loading approach
            with open(model_path, 'rb') as f:
                checkpoint = CustomUnpickler(f).load()
            logger.info("Successfully loaded model using custom unpickler")
        except Exception as e:
            logger.warning(f"Custom loader failed: {str(e)}")
            logger.info("Falling back to standard torch.load...")
            # Fall back to standard loader
            checkpoint = torch.load(model_path, map_location=device)

        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise

    # Load and process image
    logger.info(f"Loading and processing image from {image_path}")
    try:
        image = load_image(image_path)
        image = image.to(device)
        logger.info("Image processed successfully")
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        raise

    # Generate caption
    logger.info("Generating caption")
    try:
        caption = model.generate_caption(image, vocab, max_length=max_length)
        logger.info(f"Generated caption: {caption}")
        return caption
    except Exception as e:
        logger.error(f"Error generating caption: {str(e)}")
        raise