# image-captioning-api/app/fix_vocab_pickle.py
# dixisouls's picture
# nltk error
# 91a5e40
"""
Script to fix the vocabulary pickle file by recreating it with correct module information.
Run this script if you're still experiencing Vocabulary loading issues.
"""
import pickle
import os
import sys
import nltk
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Make sure the NLTK "punkt" tokenizer data is available.
# NOTE(review): recent NLTK releases may load the "punkt_tab" resource for
# word_tokenize instead of "punkt" - confirm against the installed version.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    # Download locations we expect to have write permission for, tried in
    # order: the user's home directory first, then the current directory.
    _punkt_targets = (
        (os.path.expanduser('~/.nltk_data'), "user home directory"),
        ('./nltk_data', "current directory"),
    )
    for _target_dir, _label in _punkt_targets:
        try:
            os.makedirs(_target_dir, exist_ok=True)
            # nltk.download() reports most failures by returning False rather
            # than raising, so the result must be checked explicitly or the
            # next candidate directory would never be tried.
            if not nltk.download('punkt', download_dir=_target_dir):
                raise RuntimeError(f"download into {_target_dir} did not succeed")
        except Exception as e:
            logger.error(f"Failed to download NLTK punkt: {str(e)}")
            continue
        # These directories are not part of NLTK's default search path, so
        # register the one that worked; otherwise the freshly downloaded data
        # would still not be found when tokenizing.
        _search_dir = os.path.abspath(_target_dir)
        if _search_dir not in nltk.data.path:
            nltk.data.path.append(_search_dir)
        logger.info(f"Downloaded NLTK punkt to {_label}")
        break
    # Continue anyway, as we might have the data elsewhere
# Vocabulary class for loading the vocabulary
class Vocabulary:
    """Two-way lookup table between words and integer indices.

    Words receive consecutive indices starting at 0, in the order in which
    they are first added. ``idx`` always holds the next index to hand out.
    """

    def __init__(self):
        self.word2idx, self.idx2word = {}, {}
        self.idx = 0

    def add_word(self, word):
        """Register ``word`` under the next free index; known words are ignored."""
        if word in self.word2idx:
            return
        next_id = self.idx
        self.word2idx[word] = next_id
        self.idx2word[next_id] = word
        self.idx = next_id + 1

    def __len__(self):
        """Return the number of distinct words stored."""
        return len(self.word2idx)

    def tokenize(self, text):
        """Lower-case ``text`` (coerced to ``str``) and split it into NLTK word tokens."""
        lowered = str(text).lower()
        return nltk.tokenize.word_tokenize(lowered)
def _vocabulary_from_raw(raw_data):
    """Turn whatever object the pickle contained into a ``Vocabulary``.

    Accepts a ``Vocabulary`` instance, any object exposing ``word2idx`` and
    ``idx2word`` attributes, or a dict with those two keys. Returns ``None``
    when none of these shapes match.
    """
    has_mappings = hasattr(raw_data, 'word2idx') and hasattr(raw_data, 'idx2word')
    if isinstance(raw_data, Vocabulary) and has_mappings:
        # Already the right class with its data intact; keep it untouched.
        return raw_data

    if has_mappings:
        word2idx = raw_data.word2idx
        idx2word = raw_data.idx2word
        idx = raw_data.idx if hasattr(raw_data, 'idx') else len(word2idx)
    elif isinstance(raw_data, dict) and 'word2idx' in raw_data and 'idx2word' in raw_data:
        word2idx = raw_data['word2idx']
        idx2word = raw_data['idx2word']
        idx = raw_data.get('idx', len(word2idx))
    else:
        return None

    vocab = Vocabulary()
    vocab.word2idx = word2idx
    vocab.idx2word = idx2word
    vocab.idx = idx
    return vocab


def fix_vocab_pickle(input_path, output_path):
    """
    Load the vocabulary pickle file and create a new one with updated module information.

    The input is unpickled with every class named ``Vocabulary`` mapped onto
    this module's ``Vocabulary`` class; if that fails, a plain ``pickle.load``
    is attempted. The loaded object is then normalised into a ``Vocabulary``
    and re-pickled to ``output_path``.

    Args:
        input_path: Path of the existing vocabulary pickle file.
        output_path: Path the re-pickled vocabulary is written to.

    Returns:
        The ``Vocabulary`` that was saved, or ``None`` if loading, conversion
        or saving failed (the failure is logged, not raised).

    WARNING: unpickling can execute arbitrary code. Only run this on pickle
    files from a trusted source.

    NOTE(review): pickle records the class by ``Vocabulary.__module__``. When
    this file is executed as a script that module is ``__main__`` - confirm
    the code that later loads the output can resolve that name.
    """
    try:
        logger.info(f"Attempting to load vocabulary from {input_path}...")

        # Permissive unpickler: the module recorded in the pickle is ignored
        # for any class called "Vocabulary"; everything else resolves normally.
        class FixerUnpickler(pickle.Unpickler):
            def find_class(self, module, name):
                if name == 'Vocabulary':
                    return Vocabulary
                return super().find_class(module, name)

        with open(input_path, 'rb') as f:
            try:
                raw_data = FixerUnpickler(f).load()
                logger.info("Successfully loaded vocabulary!")
            except Exception as e:
                logger.warning(f"Custom unpickler failed: {str(e)}")
                # If that fails, try a raw load and extract the data below
                f.seek(0)  # Reset file pointer
                try:
                    raw_data = pickle.load(f)
                    logger.info("Loaded raw data, attempting to extract vocabulary...")
                except Exception as raw_err:
                    logger.error(f"Raw data extraction failed: {str(raw_err)}")
                    return None

        # Normalise BEFORE saving, whichever loader succeeded. A pickle that
        # holds a plain dict loads fine through the custom unpickler, so doing
        # this only in the fallback path would write the dict out unchanged.
        vocab = _vocabulary_from_raw(raw_data)
        if vocab is None:
            logger.error("Could not extract vocabulary data from the pickle file.")
            logger.error(f"Raw data type: {type(raw_data)}")
            return None

        # Save the vocabulary with the correct module information
        logger.info(f"Saving fixed vocabulary to {output_path}...")
        with open(output_path, 'wb') as f:
            pickle.dump(vocab, f, protocol=pickle.HIGHEST_PROTOCOL)

        logger.info(f"Vocabulary successfully fixed and saved to {output_path}")
        logger.info(f"Vocabulary size: {len(vocab)} words")
        logger.info(f"Sample words: {list(vocab.word2idx.keys())[:5]}")
        return vocab
    except Exception as e:
        logger.error(f"An error occurred: {str(e)}")
        return None
if __name__ == "__main__":
    # Script entry point: read the input/output paths from the command line,
    # run the fix, and tell the user what to do next when it succeeded.
    import argparse

    cli = argparse.ArgumentParser(description='Fix vocabulary pickle file')
    cli.add_argument(
        '--input',
        type=str,
        default='app/models/vocab.pkl',
        help='Path to the input vocabulary pickle file',
    )
    cli.add_argument(
        '--output',
        type=str,
        default='app/models/vocab_fixed.pkl',
        help='Path to save the fixed vocabulary pickle file',
    )
    options = cli.parse_args()

    # fix_vocab_pickle returns None on failure, having already logged why.
    fixed = fix_vocab_pickle(options.input, options.output)
    if fixed is not None:
        logger.info("\nTo use the fixed vocabulary, update your paths to use the new file.")