# Uploaded by WildnerveAI ("Upload preprocess.py", commit a9865ad, verified).
# preprocess.py
import re
import nltk
import torch
import string
import logging
import unicodedata
from config import app_config
from typing import Dict, List, Tuple, Union, Optional, Any
from tokenizer import TokenizerWrapper, get_tokenizer
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
logger = logging.getLogger(__name__)

# NLTK data fetched at import time:
#   - "punkt" / "punkt_tab": tokenizer models for word_tokenize (newer NLTK
#     releases load "punkt_tab"; older ones use "punkt").
#   - "wordnet": data for WordNetLemmatizer.
#   - "averaged_perceptron_tagger": not used in this file; presumably needed
#     elsewhere — TODO confirm before removing.
_NLTK_RESOURCES = ("punkt", "punkt_tab", "wordnet", "averaged_perceptron_tagger")


def _download_nltk_resources() -> None:
    """Quietly download each resource in ``_NLTK_RESOURCES`` once, best-effort.

    Each resource is attempted independently, so one failure does not skip the
    others. Failures are logged as warnings and never raised, so importing this
    module keeps working offline.
    """
    if not hasattr(nltk, "download"):
        logger.warning("NLTK.download not available; skipping corpus downloads")
        return
    for resource in _NLTK_RESOURCES:
        try:
            nltk.download(resource, quiet=True)
        except Exception as e:
            logger.warning(f"NLTK download of {resource!r} failed: {e}")


_download_nltk_resources()
def get_tokenizer_wrapper():
    """Return the ``bert-base-uncased`` tokenizer from ``get_tokenizer``.

    Any exception raised while loading it is logged at ERROR level, and
    ``None`` is returned instead.
    """
    try:
        wrapper = get_tokenizer("bert-base-uncased")
    except Exception as err:
        logger.error(f"Error getting tokenizer: {err}")
        wrapper = None
    return wrapper
def get_lemmatizer():
    """Build a fresh ``WordNetLemmatizer``.

    If construction raises, the error is logged and ``None`` is returned.
    """
    lemmatizer = None
    try:
        lemmatizer = WordNetLemmatizer()
    except Exception as err:
        logger.error(f"Error initializing lemmatizer: {err}")
    return lemmatizer
def basic_tokenize(text: str):
    """Tokenize *text* with NLTK's ``word_tokenize``.

    If ``word_tokenize`` raises for any reason, the error is logged and the
    text is split on whitespace instead.
    """
    try:
        tokens = word_tokenize(text)
    except Exception as err:
        logger.error(f"Basic tokenization failed: {err}")
        tokens = text.split()
    return tokens
def basic_stem(word: str):
    """Lemmatize a single *word* with a newly created WordNet lemmatizer.

    The word is returned unchanged when no lemmatizer is available
    (``get_lemmatizer`` returned a falsy value) or when lemmatization raises.
    In the latter case the error is also logged.
    """
    lemmatizer = get_lemmatizer()
    if not lemmatizer:
        return word
    try:
        return lemmatizer.lemmatize(word)
    except Exception as exc:
        logger.error(f"Lemmatization error: {exc}")
        return word
class Preprocessor:
    """Normalize, tokenize and optionally lemmatize text, then encode it as
    padded token-ID tensors with matching attention masks.

    The encoding methods return a pair ``(token_ids, attention_mask)`` of
    ``torch.long`` tensors shaped ``(num_rows, max_length)``. A mask entry is 1
    for a real token and 0 for padding. Sequences longer than ``max_length``
    are split into overlapping windows, one row per window.
    """

    def __init__(self,
                 tokenizer: "TokenizerWrapper" = None,
                 lowercase: bool = True,
                 remove_special_chars: bool = True,
                 replace_multiple_spaces: bool = True,
                 max_length: int = None,
                 lemmatize: bool = False,
                 stride: int = None):
        """Initialize the preprocessor.

        Args:
            tokenizer: Wrapper used to turn text into token IDs. When ``None``,
                a default ``TokenizerWrapper()`` is created.
            lowercase: Lowercase text and tokens.
            remove_special_chars: Replace characters that are neither word
                characters nor whitespace with a space.
            replace_multiple_spaces: Collapse whitespace runs and strip the ends.
            max_length: Padded row length. When ``None``, it is read from
                ``app_config.TRANSFORMER_CONFIG`` key ``MAX_SEQ_LENGTH``
                (default 512).
            lemmatize: Lemmatize tokens with ``WordNetLemmatizer``.
            stride: Step between sliding-window starts. Defaults to half of
                ``max_length`` (at least 1).
        """
        if max_length is None:
            max_length = self._config_value(
                app_config.TRANSFORMER_CONFIG, 'MAX_SEQ_LENGTH', 512)
        # NOTE(review): whenever app_config has a PREPROCESSING section, any key
        # present there overrides the corresponding constructor argument.
        if hasattr(app_config, 'PREPROCESSING'):
            section = app_config.PREPROCESSING
            lowercase = self._config_value(section, 'LOWERCASE', lowercase)
            remove_special_chars = self._config_value(
                section, 'REMOVE_SPECIAL_CHARACTERS', remove_special_chars)
            replace_multiple_spaces = self._config_value(
                section, 'REPLACE_MULTIPLE_SPACES', replace_multiple_spaces)
        self.lowercase = lowercase
        self.remove_special_chars = remove_special_chars
        self.replace_multiple_spaces = replace_multiple_spaces
        self.max_length = max_length
        self.lemmatize = lemmatize
        # Default to 50% overlap. max(1, ...) keeps the step positive for tiny
        # max_length values, because range() rejects a step of 0.
        self.stride = stride or max(1, max_length // 2)
        self.lemmatizer = WordNetLemmatizer() if lemmatize else None
        self.tokenizer = TokenizerWrapper() if tokenizer is None else tokenizer

    @staticmethod
    def _config_value(section: Any, key: str, default: Any) -> Any:
        """Read *key* from a config section that may be a dict or an attribute object."""
        if isinstance(section, dict):
            return section.get(key, default)
        return getattr(section, key, default)

    def _pad_id(self) -> int:
        """Return the tokenizer's ID for the ``[PAD]`` token."""
        return self.tokenizer.tokenizer.token_to_id("[PAD]")

    def _pad_window(self, window: List[int], pad_id: int) -> Tuple[List[int], List[int]]:
        """Right-pad *window* to ``max_length``.

        Returns the padded IDs and a mask with 1 for each real token and 0 for
        each pad. The mask is computed from the unpadded length.
        """
        padding = max(0, self.max_length - len(window))
        return window + [pad_id] * padding, [1] * len(window) + [0] * padding

    def _to_tensors(self, windows: List[List[int]],
                    masks: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert rows to ``torch.long`` tensors.

        With no rows, returns ``(0, max_length)`` empty tensors so the result
        always has two dimensions.
        """
        if not windows:
            empty = torch.empty((0, self.max_length), dtype=torch.long)
            return empty, empty.clone()
        return (torch.tensor(windows, dtype=torch.long),
                torch.tensor(masks, dtype=torch.long))

    def normalize_text(self, text: str) -> str:
        """Normalize raw text.

        Steps: NFKD Unicode normalization with combining accents removed, then
        (per the instance flags) lowercasing, replacing every character that is
        neither a word character nor whitespace with a space, and collapsing
        whitespace.

        Args:
            text (str): Raw input text.
        Returns:
            str: Normalized text.
        """
        text = unicodedata.normalize('NFKD', text)
        # NFKD splits "ï" into "i" + U+0308. Combining marks are not matched by
        # \w, so without this step the special-character filter below would
        # split "naïve" into "nai ve". Dropping nonspacing marks (category Mn)
        # yields "naive".
        text = ''.join(ch for ch in text if unicodedata.category(ch) != 'Mn')
        if self.lowercase:
            text = text.lower()
        if self.remove_special_chars:
            text = re.sub(r'[^\w\s]', ' ', text)
        if self.replace_multiple_spaces:
            text = re.sub(r'\s+', ' ', text).strip()
        return text

    def tokenize_text(self, text: str) -> List[str]:
        """Split normalized text into word tokens with NLTK's ``word_tokenize``.

        If NLTK's tokenizer data is missing (``LookupError``; the import-time
        download is best-effort), fall back to whitespace splitting, as
        ``basic_tokenize`` does, instead of failing the whole pipeline.

        Args:
            text (str): Normalized text.
        Returns:
            List[str]: List of tokens.
        """
        try:
            return word_tokenize(text)
        except LookupError as err:
            logger.warning(f"word_tokenize unavailable ({err}); falling back to str.split()")
            return text.split()

    def process_with_sliding_window(self, token_ids: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a long ID sequence into overlapping, padded windows.

        Windows start every ``stride`` tokens and hold up to ``max_length`` IDs.
        Generation stops at the first window that reaches the end of the
        sequence; later starts would only repeat a suffix of that window.

        Args:
            token_ids (List[int]): List of token IDs.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: token IDs and attention masks,
            each shaped ``(num_windows, max_length)``.
        """
        pad_id = self._pad_id()
        windows, attention_masks = [], []
        total = len(token_ids)
        for start in range(0, total, self.stride):
            window, mask = self._pad_window(token_ids[start:start + self.max_length], pad_id)
            windows.append(window)
            attention_masks.append(mask)
            if start + self.max_length >= total:
                break
        return self._to_tensors(windows, attention_masks)

    def preprocess_text(self, text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Normalize, tokenize and encode *text*.

        Empty or non-string input yields one all-padding row with an all-zero
        mask, so callers can always unpack ``(ids, mask)``. Texts whose IDs
        exceed ``max_length`` are windowed via ``process_with_sliding_window``.
        Otherwise a single padded row is returned.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: IDs and masks shaped
            ``(num_rows, max_length)``.
        """
        if not text or not isinstance(text, str):
            window, mask = self._pad_window([], self._pad_id())
            return self._to_tensors([window], [mask])
        text = self.normalize_text(text)
        tokens = self.tokenize_text(text)
        if self.lemmatize:
            tokens = [self.lemmatizer.lemmatize(token) for token in tokens]
        # Honour the lowercase option; previously tokens were always lowercased.
        if self.lowercase:
            tokens = [token.lower() for token in tokens]
        # NOTE(review): assumes TokenizerWrapper.tokenize returns a list of
        # token IDs (the result is padded with integer pad IDs) — confirm.
        token_ids = self.tokenizer.tokenize(' '.join(tokens))
        if len(token_ids) > self.max_length:
            return self.process_with_sliding_window(token_ids)
        # The mask is built from the unpadded length, so padding is masked out.
        window, mask = self._pad_window(token_ids, self._pad_id())
        return self._to_tensors([window], [mask])

    def preprocess_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of *record* with every string value encoded by ``preprocess_text``.

        Non-string values are copied unchanged. A non-dict *record* is logged
        and returned as-is.
        """
        if not isinstance(record, dict):
            logger.warning(f"Expected dict for record preprocessing, got {type(record)}")
            return record
        return {key: self.preprocess_text(value) if isinstance(value, str) else value
                for key, value in record.items()}

    def preprocess_batch(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of texts, concatenating their rows along dim 0.

        Args:
            texts (List[str]): A list of raw text strings.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: IDs and masks shaped
            ``(batch_size, max_length)`` when every text fits in one row. A
            text longer than ``max_length`` contributes one row per window.
            An empty batch gives ``(0, max_length)`` tensors.
        """
        if not texts:
            return self._to_tensors([], [])
        encoded = [self.preprocess_text(text) for text in texts]
        token_ids_tensor = torch.cat([ids for ids, _ in encoded], dim=0)
        attention_masks_tensor = torch.cat([mask for _, mask in encoded], dim=0)
        return token_ids_tensor, attention_masks_tensor

    def convert_prediction_to_label(self, prediction: int) -> str:
        """Map a numeric prediction to its label; unknown values map to ``"Unknown"``.

        Args:
            prediction (int): Numeric prediction.
        Returns:
            str: Mapped label.
        """
        label_mapping = {
            0: "Negative",
            1: "Positive",
            # Extend mapping as per your dataset.
        }
        return label_mapping.get(prediction, "Unknown")
class MemoryAugmentedPreprocessor(Preprocessor):
    """Sliding-window preprocessor that prepends a small token memory to each window.

    Each window holds up to ``memory_size`` tokens carried over from earlier
    windows of the same sequence, followed by up to
    ``max_length - memory_size`` new tokens.
    """

    def __init__(self, tokenizer: "TokenizerWrapper" = None,
                 max_length: int = None,
                 stride: int = None,
                 memory_size: int = 64):
        """Initialize the preprocessor and its (empty) memory bank.

        Args:
            tokenizer: Passed to ``Preprocessor.__init__``.
            max_length: Total row length (memory + new tokens). When ``None``,
                the base class reads it from ``app_config``.
            stride: Step between window starts, passed to the base class.
            memory_size: Maximum number of memory tokens kept and prepended.

        Raises:
            ValueError: if ``memory_size`` is negative or leaves no room for
                new tokens (``memory_size >= max_length``).
        """
        # Preprocessor.__init__ resolves max_length from app_config when it is None.
        super().__init__(tokenizer=tokenizer, max_length=max_length, stride=stride)
        if not 0 <= memory_size < self.max_length:
            raise ValueError(
                f"memory_size must be in [0, max_length); got memory_size={memory_size}, "
                f"max_length={self.max_length}")
        self.memory_size = memory_size
        self.memory_bank: List[int] = []
        # Room left for new (non-memory) tokens in each window.
        self.effective_length = self.max_length - memory_size

    def update_memory(self, window_tokens: List[int]):
        """Append this window's key tokens to the memory, keeping only the last ``memory_size``."""
        key_tokens = self.extract_key_tokens(window_tokens)
        self.memory_bank = (self.memory_bank + key_tokens)[-self.memory_size:]

    def extract_key_tokens(self, tokens: List[int]) -> List[int]:
        """Return the tokens to remember from a window: currently its first ``memory_size`` tokens."""
        return tokens[:self.memory_size]

    def select_relevant_memory(self, current_tokens: List[int]) -> List[int]:
        """Return the memory tokens to prepend to a window.

        Currently this is the last ``memory_size`` memory tokens;
        ``current_tokens`` is not used.
        """
        return self.memory_bank[-self.memory_size:]

    def process_with_sliding_window(self, token_ids: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Window *token_ids*, prepending memory carried over from earlier windows.

        The memory bank is reset at the start of each call, so one sequence's
        memory never leaks into the next. Generation stops at the first window
        whose new tokens reach the end of the sequence.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: IDs and attention masks shaped
            ``(num_windows, max_length)``. The mask is 1 for memory and new
            tokens and 0 for padding.
        """
        self.memory_bank = []
        pad_id = self.tokenizer.tokenizer.token_to_id("[PAD]")
        windows = []
        attention_masks = []
        total = len(token_ids)
        for start in range(0, total, self.stride):
            new_tokens = token_ids[start:start + self.effective_length]
            memory_tokens = self.select_relevant_memory(new_tokens) if self.memory_bank else []
            window = memory_tokens + new_tokens
            # Remember key tokens from the *new* content. Passing `window` would
            # just re-store the prepended memory, freezing the bank after the
            # first window.
            self.update_memory(new_tokens)
            # The mask is built from the unpadded length so padding is masked out.
            real_length = len(window)
            padding = max(0, self.max_length - real_length)
            windows.append(window + [pad_id] * padding)
            attention_masks.append([1] * real_length + [0] * padding)
            if start + self.effective_length >= total:
                break
        if not windows:
            empty = torch.empty((0, self.max_length), dtype=torch.long)
            return empty, empty.clone()
        return (torch.tensor(windows, dtype=torch.long),
                torch.tensor(attention_masks, dtype=torch.long))
# Example usage
if __name__ == "__main__":
    # Demo: encode a short code snippet with a memory-augmented preprocessor.
    # Rows are 256 tokens: up to 64 memory tokens + up to 192 new tokens.
    # Window starts are 128 tokens apart (half of max_length).
    demo_preprocessor = MemoryAugmentedPreprocessor(
        max_length=256,
        memory_size=64,
        stride=128,
    )
    sample_text = """
def example_function():
# This is a long function
# with multiple lines
pass
"""
    ids, mask = demo_preprocessor.preprocess_text(sample_text)
    print(f"Processed shape: {ids.shape}, {mask.shape}")
# Module-level placeholder; the real implementation is Preprocessor.preprocess_text.
def preprocess_text():
    """Placeholder stub: performs no preprocessing and returns ``None``.

    NOTE(review): unimplemented; use ``Preprocessor.preprocess_text`` for
    actual text encoding.
    """
    return None