# preprocess.py
"""Text preprocessing utilities.

Provides normalization, word tokenization, optional lemmatization, and
conversion of text into padded token-ID / attention-mask tensors, including
a sliding-window scheme (plus a memory-augmented variant) for sequences
longer than the model's maximum length.
"""

import logging
import re
import string
import unicodedata
from typing import Any, Dict, List, Optional, Tuple, Union

import nltk
import torch
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize

from config import app_config
from tokenizer import TokenizerWrapper, get_tokenizer

logger = logging.getLogger(__name__)


def _download_nltk_data() -> None:
    """Best-effort, quiet download of the NLTK resources this module uses.

    Failures are logged but non-fatal so offline environments can still run
    with previously cached corpora.
    """
    if not hasattr(nltk, "download"):
        logger.warning("nltk.download not available; skipping corpus downloads")
        return
    for resource in ("punkt", "wordnet", "averaged_perceptron_tagger"):
        try:
            nltk.download(resource, quiet=True)
        except Exception as e:
            logger.warning(f"NLTK download of {resource!r} failed: {e}")


# Single guarded download pass (the original ran two overlapping passes).
_download_nltk_data()


def get_tokenizer_wrapper() -> Optional[TokenizerWrapper]:
    """Return a tokenizer for ``bert-base-uncased``, or None on failure."""
    try:
        return get_tokenizer("bert-base-uncased")
    except Exception as e:
        logger.error(f"Error getting tokenizer: {e}")
        return None


def get_lemmatizer() -> Optional[WordNetLemmatizer]:
    """Return a WordNet lemmatizer, or None if initialization fails."""
    try:
        return WordNetLemmatizer()
    except Exception as e:
        logger.error(f"Error initializing lemmatizer: {e}")
        return None


def basic_tokenize(text: str) -> List[str]:
    """Tokenize ``text`` with NLTK, falling back to whitespace splitting."""
    try:
        return word_tokenize(text)
    except Exception as e:
        logger.error(f"Basic tokenization failed: {e}")
        return text.split()


def basic_stem(word: str) -> str:
    """Lemmatize a single word, returning it unchanged on any failure."""
    lemmatizer = get_lemmatizer()
    if lemmatizer is None:
        return word
    try:
        return lemmatizer.lemmatize(word)
    except Exception as e:
        logger.error(f"Lemmatization error: {e}")
        return word


class Preprocessor:
    """Normalizes, tokenizes, and optionally lemmatizes text, then converts
    it to padded token-ID and attention-mask tensors.

    Long inputs (> ``max_length`` tokens) are split into overlapping windows
    via :meth:`process_with_sliding_window`.
    """

    @staticmethod
    def _config_value(cfg: Any, key: str, default: Any) -> Any:
        """Read ``key`` from a config that may be a dict or an object."""
        if isinstance(cfg, dict):
            return cfg.get(key, default)
        return getattr(cfg, key, default)

    def __init__(self,
                 tokenizer: TokenizerWrapper = None,
                 lowercase: bool = True,
                 remove_special_chars: bool = True,
                 replace_multiple_spaces: bool = True,
                 max_length: int = None,
                 lemmatize: bool = False,
                 stride: int = None):
        """Initialize text preprocessor with options.

        Args:
            tokenizer: Wrapper around the subword tokenizer; a default
                ``TokenizerWrapper()`` is constructed when None.
            lowercase: Lowercase text during normalization.
            remove_special_chars: Strip non-word, non-space characters.
            replace_multiple_spaces: Collapse runs of whitespace.
            max_length: Maximum token sequence length; falls back to
                ``app_config.TRANSFORMER_CONFIG.MAX_SEQ_LENGTH`` (default 512).
            lemmatize: Apply WordNet lemmatization to word tokens.
            stride: Sliding-window step; defaults to ``max_length // 2``
                (50% overlap).
        """
        if max_length is None:
            max_length = self._config_value(
                app_config.TRANSFORMER_CONFIG, 'MAX_SEQ_LENGTH', 512)

        # Config-file preprocessing options override the constructor defaults.
        if hasattr(app_config, 'PREPROCESSING'):
            cfg = app_config.PREPROCESSING
            lowercase = self._config_value(cfg, 'LOWERCASE', lowercase)
            remove_special_chars = self._config_value(
                cfg, 'REMOVE_SPECIAL_CHARACTERS', remove_special_chars)
            replace_multiple_spaces = self._config_value(
                cfg, 'REPLACE_MULTIPLE_SPACES', replace_multiple_spaces)

        self.lowercase = lowercase
        self.remove_special_chars = remove_special_chars
        self.replace_multiple_spaces = replace_multiple_spaces
        self.max_length = max_length
        self.lemmatize = lemmatize
        self.stride = stride or (max_length // 2)  # Default 50% overlap
        self.lemmatizer = WordNetLemmatizer() if lemmatize else None
        self.tokenizer = tokenizer if tokenizer is not None else TokenizerWrapper()

    def _pad_id(self) -> int:
        """Return the tokenizer's [PAD] token id."""
        return self.tokenizer.tokenizer.token_to_id("[PAD]")

    def normalize_text(self, text: str) -> str:
        """Normalize raw text: unicode NFKD, optional lowercasing, special
        character removal, and whitespace collapsing.

        Args:
            text (str): Raw input text.

        Returns:
            str: Normalized text.
        """
        text = unicodedata.normalize('NFKD', text)
        if self.lowercase:
            text = text.lower()
        if self.remove_special_chars:
            text = re.sub(r'[^\w\s]', ' ', text)
        if self.replace_multiple_spaces:
            text = re.sub(r'\s+', ' ', text).strip()
        return text

    def tokenize_text(self, text: str) -> List[str]:
        """Tokenize normalized text into words.

        Args:
            text (str): Normalized text.

        Returns:
            List[str]: List of word tokens.
        """
        return word_tokenize(text)

    def process_with_sliding_window(self, token_ids: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a long sequence into overlapping fixed-size windows.

        Args:
            token_ids (List[int]): Full list of token IDs.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - Token IDs with shape (num_windows, max_length)
                - Attention masks with shape (num_windows, max_length)
        """
        pad_id = self._pad_id()  # loop-invariant: hoisted out of the loop
        windows: List[List[int]] = []
        attention_masks: List[List[int]] = []

        for start in range(0, len(token_ids), self.stride):
            window = token_ids[start:start + self.max_length]
            valid = len(window)  # real-token count before padding
            if valid < self.max_length:
                window = window + [pad_id] * (self.max_length - valid)
            windows.append(window)
            attention_masks.append([1] * valid + [0] * (self.max_length - valid))

        return (torch.tensor(windows, dtype=torch.long),
                torch.tensor(attention_masks, dtype=torch.long))

    def preprocess_text(self, text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert text into padded token-ID and attention-mask tensors.

        Args:
            text (str): Raw input text.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ID and mask tensors, each of
            shape (1, max_length) for short inputs or
            (num_windows, max_length) for long inputs. Invalid input (empty
            or non-string) yields an all-[PAD] window with an all-zero mask
            (the original returned ``""`` here, violating the declared
            return type and crashing ``preprocess_batch``).
        """
        if not text or not isinstance(text, str):
            pad_id = self._pad_id()
            return (torch.full((1, self.max_length), pad_id, dtype=torch.long),
                    torch.zeros((1, self.max_length), dtype=torch.long))

        text = self.normalize_text(text)
        tokens = self.tokenize_text(text)

        if self.lemmatize:
            tokens = [self.lemmatizer.lemmatize(token) for token in tokens]
        # Fix: only lowercase when configured to (the original lowercased
        # unconditionally, ignoring lowercase=False).
        if self.lowercase:
            tokens = [token.lower() for token in tokens]

        token_ids = self.tokenizer.tokenize(' '.join(tokens))

        # Long sequences go through the sliding-window path.
        if len(token_ids) > self.max_length:
            return self.process_with_sliding_window(token_ids)

        # Short sequences: pad to max_length. The mask must be built from the
        # pre-padding length (the original measured the padded list, producing
        # an over-long mask that marked pad tokens as real).
        valid = len(token_ids)
        padding_length = self.max_length - valid
        token_ids = token_ids + [self._pad_id()] * padding_length
        attention_mask = [1] * valid + [0] * padding_length

        return (torch.tensor([token_ids], dtype=torch.long),
                torch.tensor([attention_mask], dtype=torch.long))

    def preprocess_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocess every string field of a record; other values pass through.

        Args:
            record (Dict[str, Any]): Input record.

        Returns:
            Dict[str, Any]: Record with string fields replaced by tensor pairs.
        """
        if not isinstance(record, dict):
            logger.warning(f"Expected dict for record preprocessing, got {type(record)}")
            return record
        return {key: self.preprocess_text(value) if isinstance(value, str) else value
                for key, value in record.items()}

    def preprocess_batch(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Preprocess a batch of texts.

        Args:
            texts (List[str]): Raw text strings.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - Token IDs with shape (total_windows, max_length)
                - Attention masks with shape (total_windows, max_length)

        Note:
            Each text yields a (num_windows, max_length) tensor (num_windows
            is 1 for short texts), so rows are concatenated along dim 0.
            The original used ``torch.stack``, which produced a wrong 3-D
            shape and raised for ragged window counts.
        """
        pairs = [self.preprocess_text(text) for text in texts]
        token_ids_tensor = torch.cat([ids for ids, _ in pairs], dim=0)
        attention_masks_tensor = torch.cat([mask for _, mask in pairs], dim=0)
        return token_ids_tensor, attention_masks_tensor

    def convert_prediction_to_label(self, prediction: int) -> str:
        """Map a numeric prediction to its label.

        Args:
            prediction (int): Numeric class prediction.

        Returns:
            str: Mapped label, or "Unknown" for unmapped values.
        """
        label_mapping = {
            0: "Negative",
            1: "Positive",
            # Extend mapping as per your dataset.
        }
        return label_mapping.get(prediction, "Unknown")


class MemoryAugmentedPreprocessor(Preprocessor):
    """Enhanced preprocessor with memory mechanism for long-range dependencies.

    Each sliding window reserves ``memory_size`` slots for tokens carried
    over from earlier windows, leaving ``max_length - memory_size`` slots
    for fresh content.
    """

    def __init__(self,
                 tokenizer: TokenizerWrapper = None,
                 max_length: int = None,
                 stride: int = None,
                 memory_size: int = 64):
        if max_length is None:
            max_length = self._config_value(
                app_config.TRANSFORMER_CONFIG, 'MAX_SEQ_LENGTH', 512)

        super().__init__(tokenizer=tokenizer, max_length=max_length, stride=stride)
        self.memory_size = memory_size
        self.memory_bank: List[int] = []
        self.effective_length = max_length - memory_size

    def update_memory(self, window_tokens: List[int]):
        """Update memory bank with key information from current window."""
        # NOTE(review): window_tokens may already start with memory tokens,
        # so extract_key_tokens can re-select the old memory prefix — confirm
        # whether fresh-content tokens were intended here.
        key_tokens = self.extract_key_tokens(window_tokens)
        self.memory_bank = (self.memory_bank + key_tokens)[-self.memory_size:]

    def extract_key_tokens(self, tokens: List[int]) -> List[int]:
        """Extract important tokens from the window (leading tokens)."""
        return tokens[:self.memory_size]

    def select_relevant_memory(self, current_tokens: List[int]) -> List[int]:
        """Select relevant memory tokens for current window."""
        return self.memory_bank[-self.memory_size:]

    def process_with_sliding_window(self, token_ids: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process long sequences using sliding window with memory mechanism.

        Args:
            token_ids (List[int]): Full list of token IDs.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (num_windows, max_length)
            ID and attention-mask tensors.
        """
        pad_id = self._pad_id()  # hoisted: loop-invariant
        windows: List[List[int]] = []
        attention_masks: List[List[int]] = []

        for start in range(0, len(token_ids), self.stride):
            # Fresh content for this window, prefixed with carried memory.
            window = token_ids[start:start + self.effective_length]
            if self.memory_bank:
                window = self.select_relevant_memory(window) + window
            self.update_memory(window)

            # Mask must reflect the pre-padding length (the original measured
            # the padded window, yielding an all-ones mask).
            valid = len(window)
            if valid < self.max_length:
                window = window + [pad_id] * (self.max_length - valid)
            windows.append(window)
            attention_masks.append([1] * valid + [0] * (self.max_length - valid))

        return (torch.tensor(windows, dtype=torch.long),
                torch.tensor(attention_masks, dtype=torch.long))


# Example usage
if __name__ == "__main__":
    preprocessor = MemoryAugmentedPreprocessor(
        max_length=256,
        memory_size=64,
        stride=128,  # 50% overlap
    )

    long_text = """
    def example_function():
        # This is a long function
        # with multiple lines
        pass
    """

    token_ids, attention_mask = preprocessor.preprocess_text(long_text)
    print(f"Processed shape: {token_ids.shape}, {attention_mask.shape}")


# TODO(review): dead stub left over from development; kept so any importer of
# this module-level name keeps working — confirm nothing imports it, then delete.
def preprocess_text():
    ...