Spaces:
Sleeping
Sleeping
| """ | |
| Script to fix the vocabulary pickle file by recreating it with correct module information. | |
| Run this script if you're still experiencing Vocabulary loading issues. | |
| """ | |
| import pickle | |
| import os | |
| import sys | |
| import nltk | |
| import logging | |
# Module-level logging: INFO verbosity with the default handler/format.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Ensure the NLTK "punkt" tokenizer model is available before any tokenize calls.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    # Not on the default search path; try download targets we can likely write to.
    # NOTE: narrowed the original bare `except:` so Ctrl-C / SystemExit still propagate.
    try:
        # User home directory first.
        nltk.download('punkt', download_dir=os.path.expanduser('~/.nltk_data'))
        logger.info("Downloaded NLTK punkt to user home directory")
    except Exception:
        # Fall back to a directory next to the current working directory.
        try:
            os.makedirs('./nltk_data', exist_ok=True)
            nltk.download('punkt', download_dir='./nltk_data')
            logger.info("Downloaded NLTK punkt to current directory")
        except Exception as e:
            logger.error(f"Failed to download NLTK punkt: {str(e)}")
            # Continue anyway, as we might have the data elsewhere
| # Vocabulary class for loading the vocabulary | |
class Vocabulary:
    """Bidirectional word <-> integer-index mapping for the caption model."""

    def __init__(self):
        self.word2idx = {}  # word -> integer id
        self.idx2word = {}  # integer id -> word
        self.idx = 0        # next id to hand out

    def add_word(self, word):
        """Register *word* if unseen, assigning it the next free index."""
        if word in self.word2idx:
            return
        self.word2idx[word] = self.idx
        self.idx2word[self.idx] = word
        self.idx += 1

    def __len__(self):
        """Number of distinct words registered."""
        return len(self.word2idx)

    def tokenize(self, text):
        """Lower-case *text* and split it into NLTK word tokens."""
        return nltk.tokenize.word_tokenize(str(text).lower())
def fix_vocab_pickle(input_path, output_path):
    """Load a vocabulary pickle that may reference a stale module path and
    re-save it bound to this module's Vocabulary class.

    Args:
        input_path: Path to the existing (possibly broken) vocabulary pickle.
        output_path: Where to write the re-pickled vocabulary.

    Returns:
        The loaded Vocabulary on success, or None on any failure.
    """
    try:
        logger.info(f"Attempting to load vocabulary from {input_path}...")

        class FixerUnpickler(pickle.Unpickler):
            """Unpickler that remaps any class named 'Vocabulary' to ours."""

            def find_class(self, module, name):
                # For any class named Vocabulary, use our Vocabulary class.
                if name == 'Vocabulary':
                    return Vocabulary
                # Otherwise try default resolution; narrowed from a bare
                # `except:` so KeyboardInterrupt/SystemExit are not swallowed.
                try:
                    return super().find_class(module, name)
                except Exception:
                    if name == 'Vocabulary':
                        return Vocabulary
                    # Other classes need genuine resolution; re-raise.
                    raise

        with open(input_path, 'rb') as f:
            try:
                vocab = FixerUnpickler(f).load()
                logger.info("Successfully loaded vocabulary!")
            except Exception as e:
                logger.warning(f"Custom unpickler failed: {str(e)}")
                # Second attempt: raw load, then copy the data out by hand.
                f.seek(0)  # Reset file pointer
                try:
                    raw_data = pickle.load(f)
                    logger.info("Loaded raw data, attempting to extract vocabulary...")
                    vocab = Vocabulary()
                    # Accept either an object carrying the expected attributes...
                    if hasattr(raw_data, 'word2idx') and hasattr(raw_data, 'idx2word'):
                        vocab.word2idx = raw_data.word2idx
                        vocab.idx2word = raw_data.idx2word
                        vocab.idx = raw_data.idx if hasattr(raw_data, 'idx') else len(vocab.word2idx)
                    # ...or a plain dict with the same keys.
                    elif isinstance(raw_data, dict) and 'word2idx' in raw_data and 'idx2word' in raw_data:
                        vocab.word2idx = raw_data['word2idx']
                        vocab.idx2word = raw_data['idx2word']
                        vocab.idx = raw_data.get('idx', len(vocab.word2idx))
                    else:
                        logger.error("Could not extract vocabulary data from the pickle file.")
                        logger.error(f"Raw data type: {type(raw_data)}")
                        return None
                except Exception as e:
                    logger.error(f"Raw data extraction failed: {str(e)}")
                    return None

        # Re-save so the pickle now references this module's Vocabulary class.
        logger.info(f"Saving fixed vocabulary to {output_path}...")
        with open(output_path, 'wb') as f:
            pickle.dump(vocab, f, protocol=pickle.HIGHEST_PROTOCOL)
        logger.info(f"Vocabulary successfully fixed and saved to {output_path}")
        logger.info(f"Vocabulary size: {len(vocab)} words")
        logger.info(f"Sample words: {list(vocab.word2idx.keys())[:5]}")
        return vocab
    except Exception as e:
        logger.error(f"An error occurred: {str(e)}")
        return None
if __name__ == "__main__":
    import argparse

    # CLI: input/output pickle paths with sensible defaults.
    arg_parser = argparse.ArgumentParser(description='Fix vocabulary pickle file')
    arg_parser.add_argument(
        '--input', type=str, default='app/models/vocab.pkl',
        help='Path to the input vocabulary pickle file')
    arg_parser.add_argument(
        '--output', type=str, default='app/models/vocab_fixed.pkl',
        help='Path to save the fixed vocabulary pickle file')
    cli_args = arg_parser.parse_args()

    # Run the fix and report where the new file lives.
    fixed_vocab = fix_vocab_pickle(cli_args.input, cli_args.output)
    if fixed_vocab is not None:
        logger.info("\nTo use the fixed vocabulary, update your paths to use the new file.")