# image-captioning-api / app/image_captioning_service.py
# (commit 9a8df65 — "SOS error again")
import os
import torch
from PIL import Image
import torchvision.transforms as transforms
import nltk
import pickle
import warnings
import logging
import math
warnings.filterwarnings("ignore")
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Make sure NLTK tokenizer is available
try:
nltk.data.find('tokenizers/punkt')
except LookupError:
# Try to download to a directory where we have write permissions
try:
# Try user home directory first
nltk.download('punkt', download_dir=os.path.expanduser('~/.nltk_data'))
logger.info("Downloaded NLTK punkt to user home directory")
except:
# Then try current directory
try:
os.makedirs('./nltk_data', exist_ok=True)
nltk.download('punkt', download_dir='./nltk_data')
logger.info("Downloaded NLTK punkt to current directory")
except Exception as e:
logger.error(f"Failed to download NLTK punkt: {str(e)}")
# Continue anyway, as we might have the data elsewhere
# Vocabulary class for loading the vocabulary
class Vocabulary:
    """Bidirectional word <-> index mapping for the captioning model.

    Mirrors the Vocabulary class used at training time so that pickled
    vocabularies can be re-materialized by this service.
    """

    def __init__(self):
        self.word2idx = {}  # token string -> integer id
        self.idx2word = {}  # integer id -> token string
        self.idx = 0        # next id to assign

    def add_word(self, word):
        """Register *word* if unseen, assigning it the next free index."""
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __len__(self):
        return len(self.word2idx)

    def tokenize(self, text):
        """Tokenize *text* into a list of lowercase tokens."""
        return nltk.tokenize.word_tokenize(str(text).lower())

    @classmethod
    def load(cls, path):
        """Load a vocabulary from a pickle file, trying several strategies.

        1. A custom unpickler that remaps any class named ``Vocabulary``
           (whatever module it was pickled under, commonly ``__main__``)
           onto this class.
        2. A plain ``pickle.load`` followed by manual extraction of the
           ``word2idx`` / ``idx2word`` mappings from the loaded object/dict.
        3. The project's ``fix_vocab_pickle`` repair helper, as a last resort.

        Raises:
            RuntimeError: if every strategy fails.
        """
        log = logging.getLogger(__name__)

        class _VocabUnpickler(pickle.Unpickler):
            def find_class(self, module, name):
                # Remap the Vocabulary class onto this module's definition.
                if name == 'Vocabulary':
                    return Vocabulary
                return super().find_class(module, name)

        # Strategy 1: custom unpickler with class remapping.
        try:
            with open(path, 'rb') as f:
                obj = _VocabUnpickler(f).load()
            if hasattr(obj, 'word2idx') and hasattr(obj, 'idx2word'):
                return obj
            # A bare dict (or anything else) falls through to strategy 2,
            # which knows how to rebuild a Vocabulary from raw mappings.
            raise TypeError(
                f"Pickle contained {type(obj).__name__}, not a vocabulary object"
            )
        except Exception as err:
            # Keep the error in an ordinary variable: `except ... as e` names
            # are deleted when the handler ends, so the original code's final
            # `raise ... {str(e)}` was itself a NameError.
            first_error = err
            log.error(f"First loading method failed: {first_error}")

        # Strategy 2: plain load + manual reconstruction of the mappings.
        try:
            with open(path, 'rb') as f:
                raw_data = pickle.load(f)
            if hasattr(raw_data, 'word2idx') and hasattr(raw_data, 'idx2word'):
                vocab = cls()
                vocab.word2idx = raw_data.word2idx
                vocab.idx2word = raw_data.idx2word
                vocab.idx = raw_data.idx
                return vocab
            if isinstance(raw_data, dict) and 'word2idx' in raw_data and 'idx2word' in raw_data:
                vocab = cls()
                vocab.word2idx = raw_data['word2idx']
                vocab.idx2word = raw_data['idx2word']
                vocab.idx = len(vocab.word2idx)
                return vocab
            raise ValueError("Could not extract vocabulary data from pickle file")
        except Exception as err:
            log.error(f"Second loading method failed: {err}")

        # Strategy 3: run the project's pickle repair tool as a last resort.
        try:
            from app.fix_vocab_pickle import fix_vocab_pickle
            fixed_path = path + "_fixed.pkl"
            vocab = fix_vocab_pickle(path, fixed_path)
            if vocab:
                log.info(f"Vocabulary fixed and saved to {fixed_path}")
                return vocab
        except Exception as err:
            log.error(f"Vocabulary fixing failed: {err}")

        raise RuntimeError(
            f"All vocabulary loading methods failed. Original error: {first_error}"
        )
# Encoder: Pretrained ResNet
class EncoderCNN(torch.nn.Module):
    """ResNet-50 image encoder projecting pooled CNN features to `embed_dim`.

    Weight sources are tried in order: a locally saved checkpoint,
    torchvision's pretrained weights, and finally random initialization
    (with a warning, since caption quality will suffer).
    """

    def __init__(self, embed_dim):
        super(EncoderCNN, self).__init__()
        import torchvision.models as models

        resnet = None

        # Option 1: locally saved model weights (offline-friendly).
        try:
            logger.info("Trying to load locally saved ResNet50 model...")
            resnet = models.resnet50(pretrained=False)
            local_model_path = "app/models/resnet50.pth"
            if os.path.exists(local_model_path):
                # map_location='cpu' lets a GPU-saved checkpoint load on
                # CPU-only hosts; the module is moved to the target device
                # later by the caller.
                resnet.load_state_dict(
                    torch.load(local_model_path, map_location='cpu')
                )
                logger.info("Successfully loaded ResNet50 from local file")
            else:
                logger.warning(f"Local ResNet50 model not found at {local_model_path}")
                # Fall back to pretrained download.
                resnet = None
        except Exception as e:
            logger.warning(f"Error loading local ResNet50 model: {str(e)}")
            resnet = None

        # Option 2: download pretrained weights via torchvision.
        if resnet is None:
            try:
                logger.info("Trying to load ResNet50 with pretrained weights...")
                # Point the torch hub cache at a writable location.
                os.makedirs('/tmp/torch_cache', exist_ok=True)
                os.environ['TORCH_HOME'] = '/tmp/torch_cache'
                resnet = models.resnet50(pretrained=True)
                logger.info("Successfully loaded pretrained ResNet50 model")
            except Exception as e:
                logger.warning(f"Error loading pretrained ResNet50: {str(e)}")
                resnet = None

        # Option 3: random initialization as a last resort.
        if resnet is None:
            logger.info("Falling back to ResNet50 without pretrained weights...")
            resnet = models.resnet50(pretrained=False)
            logger.warning("Using ResNet50 WITHOUT pretrained weights - captions may be less accurate")

        # Drop the final FC classification head; keep up to global pooling.
        modules = list(resnet.children())[:-1]
        self.resnet = torch.nn.Sequential(*modules)
        # Project pooled backbone features down to the embedding dimension.
        self.fc = torch.nn.Linear(resnet.fc.in_features, embed_dim)
        self.bn = torch.nn.BatchNorm1d(embed_dim)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, images):
        """Encode a batch of images into (batch, embed_dim) feature vectors."""
        # The pretrained backbone is frozen: no gradients flow through it.
        with torch.no_grad():
            features = self.resnet(images)
        features = features.reshape(features.size(0), -1)
        features = self.fc(features)
        features = self.bn(features)
        features = self.dropout(features)
        return features
# Positional Encoding for Transformer
class PositionalEncoding(torch.nn.Module):
    """Sinusoidal positional encodings added to token embeddings.

    Precomputes a (1, max_len, d_model) table of sin/cos values at
    construction time and adds the appropriate prefix in `forward`.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        # Even feature dims carry sine, odd dims carry cosine.
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        # Buffer (not a parameter): saved with the module but never trained.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add positional encodings to `x` of shape (batch, seq, d_model)."""
        return x + self.pe[:, :x.size(1), :].to(x.device)
# Custom Transformer Decoder
class TransformerDecoder(torch.nn.Module):
    """Transformer decoder generating caption tokens conditioned on encoded
    image features (passed to `forward` as the memory sequence)."""

    def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers, dropout=0.1):
        super(TransformerDecoder, self).__init__()
        # Token embedding + sinusoidal positional encoding.
        self.embedding = torch.nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim)

        # Stack of standard transformer decoder layers (batch-first).
        decoder_layer = torch.nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_decoder = torch.nn.TransformerDecoder(
            decoder_layer,
            num_layers=num_layers,
        )

        # Project decoder states back onto the vocabulary.
        self.fc = torch.nn.Linear(embed_dim, vocab_size)
        self.dropout = torch.nn.Dropout(dropout)

    def generate_square_subsequent_mask(self, sz):
        """Causal mask: position i may attend to positions <= i only.

        Returns a (sz, sz) float tensor with 0.0 on allowed positions and
        -inf on future positions.
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, tgt, memory):
        """Decode target token ids against the encoded image memory.

        Args:
            tgt: (batch, seq) token ids.
            memory: (batch, mem_seq, embed_dim) encoder output.

        Returns:
            (batch, seq, vocab_size) logits.
        """
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        # Scale embeddings by sqrt(d_model) as in "Attention Is All You Need".
        # Uses the module-level `math` import directly; the original stored
        # `math` on the instance (`self.math`), which was redundant.
        tgt = self.embedding(tgt) * math.sqrt(self.embedding.embedding_dim)
        tgt = self.positional_encoding(tgt)
        tgt = self.dropout(tgt)
        output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask)
        # Project to vocabulary logits.
        return self.fc(output)
# Complete Image Captioning Model
class ImageCaptioningModel(torch.nn.Module):
    """End-to-end captioning model: CNN image encoder + transformer decoder."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_heads, num_layers):
        super(ImageCaptioningModel, self).__init__()
        # (Removed the unused `self.math = math` dead code from the original.)
        # Image encoder producing one embed_dim feature vector per image.
        self.encoder = EncoderCNN(embed_dim)
        # Caption decoder conditioned on the encoded image features.
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            ff_dim=hidden_dim,
            num_layers=num_layers,
        )

    def forward(self, images, captions):
        """Teacher-forced forward pass for training.

        Args:
            images: batch of preprocessed image tensors.
            captions: (batch, seq) token ids including <SOS>/<EOS>.

        Returns:
            (batch, seq-1, vocab_size) logits predicting each next token.
        """
        img_features = self.encoder(images)
        # Treat the single image embedding as a length-1 memory sequence.
        img_features = img_features.unsqueeze(1)
        # Feed all tokens but the last; the decoder predicts each next token.
        return self.decoder(captions[:, :-1], img_features)

    def generate_caption(self, image, vocab, max_length=20):
        """Greedy-decode a caption for a single image tensor.

        Args:
            image: one preprocessed image tensor (no batch dimension).
            vocab: Vocabulary providing word2idx / idx2word and the
                '<SOS>' / '<EOS>' special tokens.
            max_length: maximum number of tokens to generate.

        Returns:
            The caption as a space-joined string, without <SOS>/<EOS>.
        """
        with torch.no_grad():
            # Encode the image into a length-1 memory sequence.
            img_features = self.encoder(image.unsqueeze(0)).unsqueeze(1)
            # Seed the decoder with the <SOS> token.
            current_ids = torch.tensor(
                [[vocab.word2idx['<SOS>']]], dtype=torch.long
            ).to(image.device)
            result_caption = []
            for _ in range(max_length):
                outputs = self.decoder(current_ids, img_features)
                # Greedy choice: most likely token at the last position.
                _, predicted = outputs[:, -1, :].max(1)
                result_caption.append(predicted.item())
                # Stop once the model emits <EOS>.
                if predicted.item() == vocab.word2idx['<EOS>']:
                    break
                # Extend the running sequence for the next step.
                current_ids = torch.cat([current_ids, predicted.unsqueeze(0)], dim=1)
            words = [vocab.idx2word[idx] for idx in result_caption]
            # Drop the trailing <EOS> token if generation stopped on it.
            if words and words[-1] == '<EOS>':
                words = words[:-1]
            return ' '.join(words)
def load_image(image_path, transform=None):
    """Load an image from disk and preprocess it into a normalized tensor.

    Args:
        image_path: path to the image file.
        transform: optional torchvision transform; defaults to the standard
            224x224 resize + ImageNet normalization used by the encoder.

    Returns:
        The transformed image tensor.
    """
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # ImageNet statistics — must match the encoder's pretraining.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    # Context manager closes the underlying file handle promptly (the
    # original left it open until garbage collection).
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    return transform(image)
def _load_checkpoint(model_path, device):
    """Load a training checkpoint, preferring a class-remapping unpickler and
    falling back to standard ``torch.load``."""
    log = logging.getLogger(__name__)

    class _CheckpointUnpickler(pickle.Unpickler):
        # Remap classes pickled from __main__ (or any module, for Vocabulary)
        # onto the definitions in this module so old checkpoints keep loading.
        def find_class(self, module, name):
            if name == 'Vocabulary':
                return Vocabulary
            if module == '__main__':
                local_classes = {
                    'ImageCaptioningModel': ImageCaptioningModel,
                    'EncoderCNN': EncoderCNN,
                    'TransformerDecoder': TransformerDecoder,
                    'PositionalEncoding': PositionalEncoding,
                }
                if name in local_classes:
                    return local_classes[name]
            return super().find_class(module, name)

    try:
        log.info("Trying custom model loader...")
        with open(model_path, 'rb') as f:
            checkpoint = _CheckpointUnpickler(f).load()
        log.info("Successfully loaded model using custom unpickler")
        return checkpoint
    except Exception as e:
        log.warning(f"Custom loader failed: {str(e)}")
        log.info("Falling back to standard torch.load...")
        return torch.load(model_path, map_location=device)


def generate_caption(
    image_path,
    model_path,
    vocab_path,
    max_length=20,
    device=None
):
    """Generate a natural-language caption for the image at *image_path*.

    Args:
        image_path: path to the input image file.
        model_path: path to the trained model checkpoint.
        vocab_path: path to the pickled vocabulary.
        max_length: maximum number of caption tokens to generate.
        device: torch device to run on; auto-selects CUDA when available.

    Returns:
        The generated caption string.

    Raises:
        FileNotFoundError: if any of the three input paths does not exist.
        Exception: model-loading / image-processing / generation errors are
            logged and re-raised.
    """
    # Same logger object as the module-level `logger` (both use __name__).
    log = logging.getLogger(__name__)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info(f"Using device: {device}")

    # Fail fast with a distinct message for each missing input file.
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at {image_path}")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at {model_path}")
    if not os.path.exists(vocab_path):
        raise FileNotFoundError(f"Vocabulary not found at {vocab_path}")

    # Point torch's download cache at a writable temporary directory.
    try:
        os.makedirs('/tmp/torch_cache', exist_ok=True)
        os.environ['TORCH_HOME'] = '/tmp/torch_cache'
        log.info("Set TORCH_HOME to /tmp/torch_cache")
    except Exception as e:
        log.warning(f"Could not set up temporary torch cache: {e}")

    # Load vocabulary.
    log.info(f"Loading vocabulary from {vocab_path}")
    vocab = Vocabulary.load(vocab_path)
    log.info(f"Loaded vocabulary with {len(vocab)} words")

    # Hyperparameters — must match those used during training.
    embed_dim = 512
    hidden_dim = 2048
    num_layers = 6
    num_heads = 8

    log.info("Initializing model")
    model = ImageCaptioningModel(
        vocab_size=len(vocab),
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        num_layers=num_layers,
    ).to(device)

    # Load model weights.
    log.info(f"Loading model weights from {model_path}")
    try:
        checkpoint = _load_checkpoint(model_path, device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        log.info("Model loaded successfully")
    except Exception as e:
        log.error(f"Error loading model: {str(e)}")
        raise

    # Load and preprocess the image.
    log.info(f"Loading and processing image from {image_path}")
    try:
        image = load_image(image_path).to(device)
        log.info("Image processed successfully")
    except Exception as e:
        log.error(f"Error processing image: {str(e)}")
        raise

    # Generate the caption.
    log.info("Generating caption")
    try:
        caption = model.generate_caption(image, vocab, max_length=max_length)
        log.info(f"Generated caption: {caption}")
        return caption
    except Exception as e:
        log.error(f"Error generating caption: {str(e)}")
        raise