# NOTE: hosting-page banner ("Spaces: Sleeping") removed from the scraped copy;
# the actual module starts with the imports below.
| import os | |
| import torch | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| import nltk | |
| import pickle | |
| import warnings | |
| import logging | |
| import math | |
# NOTE(review): this suppresses ALL warnings (including deprecations) process-wide —
# consider narrowing the filter to the specific categories that are noisy.
warnings.filterwarnings("ignore")
# Configure logging
logging.basicConfig(level=logging.INFO)
# Module-level logger used by every class/function in this file.
logger = logging.getLogger(__name__)
# Make sure NLTK tokenizer is available.
# Download is best-effort: failures are logged, not raised, because the
# tokenizer data may already exist in another NLTK search path.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    # Try to download to a directory where we have write permissions
    try:
        # Try user home directory first
        nltk.download('punkt', download_dir=os.path.expanduser('~/.nltk_data'))
        logger.info("Downloaded NLTK punkt to user home directory")
    # BUGFIX: was a bare `except:`, which also swallowed SystemExit and
    # KeyboardInterrupt; narrowed to Exception.
    except Exception:
        # Then try current directory
        try:
            os.makedirs('./nltk_data', exist_ok=True)
            nltk.download('punkt', download_dir='./nltk_data')
            logger.info("Downloaded NLTK punkt to current directory")
        except Exception as e:
            logger.error(f"Failed to download NLTK punkt: {str(e)}")
            # Continue anyway, as we might have the data elsewhere
| # Vocabulary class for loading the vocabulary | |
# Vocabulary class for loading the vocabulary
class Vocabulary:
    """Bidirectional word <-> integer-id mapping used to encode/decode captions.

    Includes pickle-tolerant loading so vocabularies serialized from other
    modules (e.g. ``__main__`` during training) can still be restored here.
    """

    def __init__(self):
        self.word2idx = {}  # word -> integer id
        self.idx2word = {}  # integer id -> word
        self.idx = 0        # next id to assign

    def add_word(self, word):
        """Register `word` with the next free id (no-op if already present)."""
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __len__(self):
        return len(self.word2idx)

    def tokenize(self, text):
        """Tokenize text into a list of lowercase tokens (requires NLTK punkt data)."""
        tokens = nltk.tokenize.word_tokenize(str(text).lower())
        return tokens

    # BUGFIX: `load` was a plain instance method, so the existing call site
    # `Vocabulary.load(vocab_path)` passed the path as `cls` and raised
    # TypeError; it must be a classmethod (an alternate constructor).
    @classmethod
    def load(cls, path):
        """Load vocabulary from pickle file.

        Tries three strategies in order:
          1. a custom unpickler that redirects any pickled ``Vocabulary``
             class reference (whatever module it came from) to this class;
          2. a plain unpickle followed by manual extraction of the
             word2idx/idx2word mappings (object attributes or dict keys);
          3. the external ``fix_vocab_pickle`` repair helper.

        Raises RuntimeError if all three fail.
        """
        try:
            # Strategy 1: custom unpickler. Any class named 'Vocabulary' is
            # resolved to the class defined in this module, regardless of the
            # module path recorded in the pickle.
            class CustomUnpickler(pickle.Unpickler):
                def find_class(self, module, name):
                    if name == 'Vocabulary':
                        return Vocabulary
                    # Default behavior for everything else.
                    return super().find_class(module, name)
            with open(path, 'rb') as f:
                return CustomUnpickler(f).load()
        except Exception as first_err:
            logger.error(f"First loading method failed: {str(first_err)}")
            try:
                # Strategy 2: manual recreation of the vocabulary object
                # from whatever the raw pickle contains.
                with open(path, 'rb') as f:
                    raw_data = pickle.load(f)
                if hasattr(raw_data, 'word2idx') and hasattr(raw_data, 'idx2word'):
                    # Object with the expected attributes: copy them over.
                    vocab = Vocabulary()
                    vocab.word2idx = raw_data.word2idx
                    vocab.idx2word = raw_data.idx2word
                    vocab.idx = raw_data.idx
                    return vocab
                # Otherwise try to extract the mappings from a plain dict.
                vocab = Vocabulary()
                if isinstance(raw_data, dict):
                    if 'word2idx' in raw_data and 'idx2word' in raw_data:
                        vocab.word2idx = raw_data['word2idx']
                        vocab.idx2word = raw_data['idx2word']
                        vocab.idx = len(vocab.word2idx)
                        return vocab
                raise ValueError("Could not extract vocabulary data from pickle file")
            except Exception as second_err:
                logger.error(f"Second loading method failed: {str(second_err)}")
                # Strategy 3: use fix_vocab_pickle as a last resort.
                try:
                    from app.fix_vocab_pickle import fix_vocab_pickle
                    fixed_path = path + "_fixed.pkl"
                    vocab = fix_vocab_pickle(path, fixed_path)
                    if vocab:
                        logger.info(f"Vocabulary fixed and saved to {fixed_path}")
                        return vocab
                except Exception as fix_err:
                    logger.error(f"Vocabulary fixing failed: {str(fix_err)}")
                # BUGFIX: the original re-raised with `str(e)`, but Python
                # unbinds the name bound by `except ... as e` when the handler
                # exits, so that line raised NameError instead of the intended
                # RuntimeError. Report the first (original) failure instead.
                raise RuntimeError(f"All vocabulary loading methods failed. Original error: {str(first_err)}")
| # Encoder: Pretrained ResNet | |
# Encoder: Pretrained ResNet
class EncoderCNN(torch.nn.Module):
    """Image encoder: a ResNet-50 backbone (classifier head removed) followed
    by a linear projection into the caption embedding space.

    Backbone weights are located via a three-stage fallback: bundled local
    checkpoint -> torchvision pretrained download -> random initialization.
    """

    def __init__(self, embed_dim):
        super(EncoderCNN, self).__init__()
        import torchvision.models as models
        # Resolve the backbone through the fallback chain.
        backbone = self._from_local_checkpoint(models)
        if backbone is None:
            backbone = self._from_torchvision_pretrained(models)
        if backbone is None:
            # Option 3: last resort — randomly initialized weights.
            logger.info("Falling back to ResNet50 without pretrained weights...")
            backbone = models.resnet50(pretrained=False)
            logger.warning("Using ResNet50 WITHOUT pretrained weights - captions may be less accurate")
        # Keep every layer except the final FC classification head.
        self.resnet = torch.nn.Sequential(*list(backbone.children())[:-1])
        # Project pooled CNN features to the decoder's embedding dimension.
        self.fc = torch.nn.Linear(backbone.fc.in_features, embed_dim)
        self.bn = torch.nn.BatchNorm1d(embed_dim)
        self.dropout = torch.nn.Dropout(0.5)

    @staticmethod
    def _from_local_checkpoint(models):
        """Option 1: restore weights from a bundled file; returns None on any failure."""
        try:
            logger.info("Trying to load locally saved ResNet50 model...")
            net = models.resnet50(pretrained=False)
            local_model_path = "app/models/resnet50.pth"
            if not os.path.exists(local_model_path):
                logger.warning(f"Local ResNet50 model not found at {local_model_path}")
                return None
            net.load_state_dict(torch.load(local_model_path))
            logger.info("Successfully loaded ResNet50 from local file")
            return net
        except Exception as e:
            logger.warning(f"Error loading local ResNet50 model: {str(e)}")
            return None

    @staticmethod
    def _from_torchvision_pretrained(models):
        """Option 2: fetch torchvision's pretrained weights; returns None on any failure."""
        try:
            logger.info("Trying to load ResNet50 with pretrained weights...")
            # Point the torch hub cache at a writable location.
            os.makedirs('/tmp/torch_cache', exist_ok=True)
            os.environ['TORCH_HOME'] = '/tmp/torch_cache'
            net = models.resnet50(pretrained=True)
            logger.info("Successfully loaded pretrained ResNet50 model")
            return net
        except Exception as e:
            logger.warning(f"Error loading pretrained ResNet50: {str(e)}")
            return None

    def forward(self, images):
        """Encode a batch of images into (batch, embed_dim) feature vectors."""
        # The frozen backbone never receives gradients.
        with torch.no_grad():
            feats = self.resnet(images)
        flat = feats.reshape(feats.size(0), -1)
        return self.dropout(self.bn(self.fc(flat)))
| # Positional Encoding for Transformer | |
# Positional Encoding for Transformer
class PositionalEncoding(torch.nn.Module):
    """Adds the standard fixed sinusoidal position signal to token embeddings.

    Even embedding channels carry sin terms, odd channels cos terms, each at a
    geometrically decreasing frequency (as in "Attention Is All You Need").
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Per-channel angular frequencies: 10000^(-2i/d_model).
        frequencies = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * frequencies)
        table[:, 1::2] = torch.cos(positions * frequencies)
        # Stored as a buffer (moves with the module, not a trainable parameter).
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Return x plus the positional table truncated to x's sequence length."""
        return x + self.pe[:, :x.size(1), :].to(x.device)
| # Custom Transformer Decoder | |
# Custom Transformer Decoder
class TransformerDecoder(torch.nn.Module):
    """Caption decoder: embeds tokens, adds positional encoding, runs a stack
    of torch TransformerDecoder layers against the image memory, and projects
    the result to vocabulary logits.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers, dropout=0.1):
        super(TransformerDecoder, self).__init__()
        import math
        # Kept on the instance so forward() can scale embeddings with sqrt(d).
        self.math = math
        self.embedding = torch.nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim)
        # One decoder layer, replicated num_layers times by TransformerDecoder.
        layer = torch.nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_decoder = torch.nn.TransformerDecoder(layer, num_layers=num_layers)
        # Final projection to vocabulary logits.
        self.fc = torch.nn.Linear(embed_dim, vocab_size)
        self.dropout = torch.nn.Dropout(dropout)

    def generate_square_subsequent_mask(self, sz):
        """Additive causal mask: 0.0 on/below the diagonal, -inf above it,
        so position i may only attend to positions <= i."""
        visible = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        return visible.float().masked_fill(visible == 0, float('-inf')).masked_fill(visible == 1, float(0.0))

    def forward(self, tgt, memory):
        """Decode token ids `tgt` (batch, seq) against `memory`; returns
        (batch, seq, vocab_size) logits."""
        causal = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        # Scale embeddings by sqrt(d_model), then add position information.
        scaled = self.embedding(tgt) * self.math.sqrt(self.embedding.embedding_dim)
        staged = self.dropout(self.positional_encoding(scaled))
        decoded = self.transformer_decoder(staged, memory, tgt_mask=causal)
        return self.fc(decoded)
| # Complete Image Captioning Model | |
# Complete Image Captioning Model
class ImageCaptioningModel(torch.nn.Module):
    """End-to-end captioner: an EncoderCNN produces one image feature "token"
    that a TransformerDecoder attends over to emit caption tokens."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_heads, num_layers):
        super(ImageCaptioningModel, self).__init__()
        # Keep a handle to math (mirrors the decoder's convention).
        self.math = math
        # Image encoder.
        self.encoder = EncoderCNN(embed_dim)
        # Caption decoder.
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            ff_dim=hidden_dim,
            num_layers=num_layers,
        )

    def forward(self, images, captions):
        """Teacher-forced training pass; returns logits for each caption position."""
        # Image features become a length-1 memory sequence: (batch, 1, embed_dim).
        memory = self.encoder(images).unsqueeze(1)
        # Feed every caption token except the last (typically <EOS>).
        return self.decoder(captions[:, :-1], memory)

    def generate_caption(self, image, vocab, max_length=20):
        """Greedy-decode a caption string for one preprocessed image tensor.

        Starts from <SOS>, repeatedly appends the argmax next token, and stops
        at <EOS> or after max_length tokens.
        """
        with torch.no_grad():
            # Encode the single image into a (1, 1, embed_dim) memory.
            memory = self.encoder(image.unsqueeze(0)).unsqueeze(1)
            sequence = torch.tensor([[vocab.word2idx['<SOS>']]], dtype=torch.long).to(image.device)
            produced = []
            step = 0
            while step < max_length:
                logits = self.decoder(sequence, memory)
                # Greedy pick from the final position's distribution.
                _, best = logits[:, -1, :].max(1)
                produced.append(best.item())
                if best.item() == vocab.word2idx['<EOS>']:
                    break
                # Extend the running sequence for the next step.
                sequence = torch.cat([sequence, best.unsqueeze(0)], dim=1)
                step += 1
            # Map ids back to words and drop a trailing <EOS> if present.
            words = [vocab.idx2word[token_id] for token_id in produced]
            if words and words[-1] == '<EOS>':
                words = words[:-1]
            return ' '.join(words)
def load_image(image_path, transform=None):
    """Load an image from disk and preprocess it into a model-ready tensor.

    When `transform` is None, applies the standard ImageNet preprocessing
    (resize to 224x224, tensor conversion, per-channel normalization).
    """
    # Default preprocessing matches what ImageNet-trained ResNets expect.
    preprocess = transform if transform is not None else transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    rgb_image = Image.open(image_path).convert('RGB')
    return preprocess(rgb_image)
def generate_caption(
    image_path,
    model_path,
    vocab_path,
    max_length=20,
    device=None
):
    """Generate a caption for an image.

    End-to-end inference pipeline: validates the input paths, loads the
    vocabulary and model checkpoint, preprocesses the image, and runs
    greedy decoding.

    Args:
        image_path: path to the input image file.
        model_path: path to the saved checkpoint (expected to contain a
            'model_state_dict' entry — see the load below).
        vocab_path: path to the pickled Vocabulary.
        max_length: maximum number of tokens to generate.
        device: torch device; auto-selects CUDA/CPU when None.

    Returns:
        The generated caption as a space-joined string.

    Raises:
        FileNotFoundError: if any of the three input paths is missing.
        Exception: re-raises any model-loading, image-processing, or
            decoding failure after logging it.
    """
    # Set device
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    # Check if files exist — fail fast with a clear message before any heavy work.
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at {image_path}")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at {model_path}")
    if not os.path.exists(vocab_path):
        raise FileNotFoundError(f"Vocabulary not found at {vocab_path}")
    # Setup temporary cache directory for torch if needed (best-effort;
    # a failure here only matters if torchvision later needs to download weights).
    try:
        os.makedirs('/tmp/torch_cache', exist_ok=True)
        os.environ['TORCH_HOME'] = '/tmp/torch_cache'
        logger.info(f"Set TORCH_HOME to /tmp/torch_cache")
    except Exception as e:
        logger.warning(f"Could not set up temporary torch cache: {e}")
    # Load vocabulary
    logger.info(f"Loading vocabulary from {vocab_path}")
    vocab = Vocabulary.load(vocab_path)
    logger.info(f"Loaded vocabulary with {len(vocab)} words")
    # Load model
    # Hyperparameters - must match those used during training
    embed_dim = 512
    hidden_dim = 2048
    num_layers = 6
    num_heads = 8
    # Initialize model
    logger.info("Initializing model")
    model = ImageCaptioningModel(
        vocab_size=len(vocab),
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        num_layers=num_layers
    ).to(device)
    # Load model weights
    logger.info(f"Loading model weights from {model_path}")
    try:
        # First try our custom loader
        try:
            logger.info("Trying custom model loader...")
            # Replace this with Python's built-in pickle that we can customize
            # Define a custom unpickler that remaps classes pickled from
            # __main__ (or any module, for Vocabulary) onto this module's
            # definitions.
            # NOTE(review): torch checkpoints saved with torch.save are
            # normally zip archives, which plain pickle cannot read — in that
            # case this path fails and the torch.load fallback below is used.
            class CustomUnpickler(pickle.Unpickler):
                def find_class(self, module, name):
                    # If it's looking for the Vocabulary class in __main__
                    if name == 'Vocabulary':
                        # Return our current Vocabulary class
                        return Vocabulary
                    if module == '__main__':
                        if name == 'ImageCaptioningModel':
                            return ImageCaptioningModel
                        if name == 'EncoderCNN':
                            return EncoderCNN
                        if name == 'TransformerDecoder':
                            return TransformerDecoder
                        if name == 'PositionalEncoding':
                            return PositionalEncoding
                    # Use the normal behavior for everything else
                    return super().find_class(module, name)
            # Use a custom loading approach
            with open(model_path, 'rb') as f:
                checkpoint = CustomUnpickler(f).load()
            logger.info("Successfully loaded model using custom unpickler")
        except Exception as e:
            logger.warning(f"Custom loader failed: {str(e)}")
            logger.info("Falling back to standard torch.load...")
            # Fall back to standard loader
            checkpoint = torch.load(model_path, map_location=device)
        # The checkpoint must be a mapping with a 'model_state_dict' entry.
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise
    # Load and process image
    logger.info(f"Loading and processing image from {image_path}")
    try:
        image = load_image(image_path)
        image = image.to(device)
        logger.info("Image processed successfully")
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        raise
    # Generate caption
    logger.info("Generating caption")
    try:
        caption = model.generate_caption(image, vocab, max_length=max_length)
        logger.info(f"Generated caption: {caption}")
        return caption
    except Exception as e:
        logger.error(f"Error generating caption: {str(e)}")
        raise