|
|
|
|
|
import re
|
|
|
import nltk
|
|
|
import torch
|
|
|
import string
|
|
|
import logging
|
|
|
import unicodedata
|
|
|
from config import app_config
|
|
|
from typing import Dict, List, Tuple, Union, Optional, Any
|
|
|
from tokenizer import TokenizerWrapper, get_tokenizer
|
|
|
from nltk.tokenize import word_tokenize
|
|
|
from nltk.stem import WordNetLemmatizer
|
|
|
|
|
|
logger = logging.getLogger(__name__)


# Fetch the NLTK corpora this module depends on: 'punkt' (word_tokenize),
# 'wordnet' (WordNetLemmatizer), 'averaged_perceptron_tagger'.
# FIX: the original ran two download passes ('punkt' fetched twice, once
# noisily and once quietly, with two different warning messages); this is
# consolidated into one guarded, quiet pass. Failures are non-fatal so the
# module still imports offline.
if hasattr(nltk, "download"):
    try:
        for _corpus in ("punkt", "wordnet", "averaged_perceptron_tagger"):
            nltk.download(_corpus, quiet=True)
    except Exception as e:
        logger.warning(f"NLTK data download failed: {e}")
else:
    logger.warning("NLTK.download not available; skipping corpus downloads")
|
|
|
|
|
|
def get_tokenizer_wrapper():
    """Return a tokenizer wrapper for 'bert-base-uncased', or None on failure.

    Any exception raised while constructing the tokenizer is logged and
    swallowed so callers can fall back gracefully.
    """
    try:
        return get_tokenizer("bert-base-uncased")
    except Exception as e:
        logger.error(f"Error getting tokenizer: {e}")
        return None
|
|
|
|
|
|
def get_lemmatizer():
    """Build a WordNetLemmatizer instance, or None if initialization fails."""
    try:
        lemmatizer = WordNetLemmatizer()
    except Exception as e:
        logger.error(f"Error initializing lemmatizer: {e}")
        return None
    return lemmatizer
|
|
|
|
|
|
def basic_tokenize(text: str):
    """Tokenize text with NLTK's word_tokenize, falling back to a plain
    whitespace split if NLTK tokenization raises (e.g. missing corpora)."""
    try:
        tokens = word_tokenize(text)
    except Exception as e:
        logger.error(f"Basic tokenization failed: {e}")
        tokens = text.split()
    return tokens
|
|
|
|
|
|
def basic_stem(word: str):
    """Lemmatize a single word; return it unchanged when the lemmatizer is
    unavailable or lemmatization raises."""
    lemmatizer = get_lemmatizer()
    # Guard clause: no lemmatizer means nothing to do.
    if not lemmatizer:
        return word
    try:
        return lemmatizer.lemmatize(word)
    except Exception as e:
        logger.error(f"Lemmatization error: {e}")
        return word
|
|
|
|
|
|
class Preprocessor:
    """A preprocessor class that performs text normalization, tokenization,
    lemmatization, and converts tokens to token IDs with padding and attention masks.

    Inputs longer than ``max_length`` subword tokens are split into
    overlapping windows via :meth:`process_with_sliding_window`.
    """

    def __init__(self,
                 tokenizer: TokenizerWrapper = None,
                 lowercase: bool = True,
                 remove_special_chars: bool = True,
                 replace_multiple_spaces: bool = True,
                 max_length: int = None,
                 lemmatize: bool = False,
                 stride: int = None):
        """Initialize text preprocessor with options.

        Args:
            tokenizer: Subword tokenizer wrapper; a default TokenizerWrapper
                is constructed when None.
            lowercase: Lowercase text during normalization.
            remove_special_chars: Replace non-word characters with spaces.
            replace_multiple_spaces: Collapse whitespace runs to one space.
            max_length: Maximum sequence length; when None, read from
                app_config.TRANSFORMER_CONFIG's MAX_SEQ_LENGTH (default 512).
            lemmatize: Apply WordNet lemmatization before subword tokenization.
            stride: Sliding-window step for long sequences; defaults to
                max_length // 2.
        """
        # Resolve max_length from app_config when not supplied.
        # TRANSFORMER_CONFIG may be a plain dict or an attribute-style object.
        if max_length is None:
            if isinstance(app_config.TRANSFORMER_CONFIG, dict):
                max_length = app_config.TRANSFORMER_CONFIG.get('MAX_SEQ_LENGTH', 512)
            else:
                max_length = getattr(app_config.TRANSFORMER_CONFIG, 'MAX_SEQ_LENGTH', 512)

        # When a PREPROCESSING section exists in app_config it overrides the
        # constructor flags (constructor values serve as the fallbacks).
        if hasattr(app_config, 'PREPROCESSING'):
            if isinstance(app_config.PREPROCESSING, dict):
                lowercase = app_config.PREPROCESSING.get('LOWERCASE', lowercase)
                remove_special_chars = app_config.PREPROCESSING.get('REMOVE_SPECIAL_CHARACTERS', remove_special_chars)
                replace_multiple_spaces = app_config.PREPROCESSING.get('REPLACE_MULTIPLE_SPACES', replace_multiple_spaces)
            else:
                lowercase = getattr(app_config.PREPROCESSING, 'LOWERCASE', lowercase)
                remove_special_chars = getattr(app_config.PREPROCESSING, 'REMOVE_SPECIAL_CHARACTERS', remove_special_chars)
                replace_multiple_spaces = getattr(app_config.PREPROCESSING, 'REPLACE_MULTIPLE_SPACES', replace_multiple_spaces)

        self.lowercase = lowercase
        self.remove_special_chars = remove_special_chars
        self.replace_multiple_spaces = replace_multiple_spaces
        self.max_length = max_length
        self.lemmatize = lemmatize
        self.stride = stride or (max_length // 2)

        self.lemmatizer = WordNetLemmatizer() if lemmatize else None
        self.tokenizer = tokenizer if tokenizer is not None else TokenizerWrapper()

    def _pad_id(self) -> int:
        """Return the tokenizer's [PAD] token id (used for window padding)."""
        return self.tokenizer.tokenizer.token_to_id("[PAD]")

    def normalize_text(self, text: str) -> str:
        """
        Normalizes the input text by removing punctuation, non-alphabetic
        characters, and extra whitespace (each step controlled by the
        corresponding constructor flag).

        Args:
            text (str): Raw input text.

        Returns:
            str: Normalized text.
        """
        # NFKD unicode normalization decomposes accented/compatibility chars.
        text = unicodedata.normalize('NFKD', text)

        if self.lowercase:
            text = text.lower()

        if self.remove_special_chars:
            # Non-word, non-space characters become spaces (not deleted) so
            # adjacent words are not glued together.
            text = re.sub(r'[^\w\s]', ' ', text)

        if self.replace_multiple_spaces:
            text = re.sub(r'\s+', ' ', text).strip()

        return text

    def tokenize_text(self, text: str) -> List[str]:
        """Tokenizes the normalized input text into words.

        Args:
            text (str): Normalized text.

        Returns:
            List[str]: List of tokens.
        """
        return word_tokenize(text)

    def process_with_sliding_window(self, token_ids: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process long sequences using a sliding window approach.

        Args:
            token_ids (List[int]): List of token IDs.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - Tensor of token IDs with shape (num_windows, max_length)
                - Tensor of attention masks with shape (num_windows, max_length)
        """
        windows = []
        attention_masks = []
        # Hoisted out of the loop: the pad id is loop-invariant.
        pad_id = self._pad_id()

        for start in range(0, len(token_ids), self.stride):
            window = token_ids[start:start + self.max_length]
            # Capture the unpadded length before padding so the mask marks
            # only real tokens with 1.
            real_length = len(window)
            if real_length < self.max_length:
                window = window + [pad_id] * (self.max_length - real_length)

            windows.append(window)
            attention_masks.append(
                [1] * real_length + [0] * (self.max_length - real_length))

        return (torch.tensor(windows, dtype=torch.long),
                torch.tensor(attention_masks, dtype=torch.long))

    def preprocess_text(self, text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply preprocessing steps to text.

        Returns:
            Tuple of (token_ids, attention_mask) tensors, each of shape
            (1, max_length), or (num_windows, max_length) for long inputs.
        """
        if not text or not isinstance(text, str):
            # FIX: the original returned "" here, violating the declared
            # return type and breaking preprocess_batch. Return one
            # all-padding window with a zeroed attention mask instead.
            pad_id = self._pad_id()
            return (torch.full((1, self.max_length), pad_id, dtype=torch.long),
                    torch.zeros((1, self.max_length), dtype=torch.long))

        text = self.normalize_text(text)
        tokens = self.tokenize_text(text)

        if self.lemmatize:
            tokens = [self.lemmatizer.lemmatize(token) for token in tokens]

        # FIX: lowercasing now honours self.lowercase; the original
        # lowercased unconditionally, ignoring the flag.
        if self.lowercase:
            tokens = [token.lower() for token in tokens]

        token_ids = self.tokenizer.tokenize(' '.join(tokens))

        # Sequences longer than one window are delegated to the sliding
        # window path, which returns (num_windows, max_length) tensors.
        if len(token_ids) > self.max_length:
            return self.process_with_sliding_window(token_ids)

        # FIX: the mask is built from the length *before* padding. The
        # original padded token_ids first and then used its new length,
        # producing a mask of max_length + padding_length entries that
        # mismatched the id tensor.
        real_length = len(token_ids)
        padding_length = self.max_length - real_length
        token_ids = token_ids + [self._pad_id()] * padding_length
        attention_mask = [1] * real_length + [0] * padding_length

        return (torch.tensor([token_ids], dtype=torch.long),
                torch.tensor([attention_mask], dtype=torch.long))

    def preprocess_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocess a data record, handling text fields appropriately.

        String values are run through preprocess_text (becoming tensor
        pairs); all other values are passed through unchanged.
        """
        if not isinstance(record, dict):
            logger.warning(f"Expected dict for record preprocessing, got {type(record)}")
            return record

        processed_record = {}

        for key, value in record.items():
            if isinstance(value, str):
                processed_record[key] = self.preprocess_text(value)
            else:
                processed_record[key] = value

        return processed_record

    def preprocess_batch(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Preprocesses a batch of texts.

        Args:
            texts (List[str]): A list of raw text strings.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - Tensor of token IDs with shape (batch_size, max_length)
                - Tensor of attention masks with shape (batch_size, max_length)

        Note: texts that overflow into multiple sliding windows contribute
        one row per window to the batch dimension.
        """
        pairs = [self.preprocess_text(text) for text in texts]
        batch_token_ids = [ids for ids, _ in pairs]
        batch_attention_masks = [mask for _, mask in pairs]
        # FIX: preprocess_text already returns 2-D (num_windows, max_length)
        # tensors; torch.stack would add a spurious middle dimension
        # ((batch, 1, max_length)), so concatenate along dim 0 instead.
        token_ids_tensor = torch.cat(batch_token_ids, dim=0)
        attention_masks_tensor = torch.cat(batch_attention_masks, dim=0)
        return token_ids_tensor, attention_masks_tensor

    def convert_prediction_to_label(self, prediction: int) -> str:
        """Converts a numeric prediction into its corresponding label.

        Args:
            prediction (int): Numeric prediction.

        Returns:
            str: Mapped label ("Unknown" for unmapped values).
        """
        label_mapping = {
            0: "Negative",
            1: "Positive",
        }
        return label_mapping.get(prediction, "Unknown")
|
|
|
|
|
|
class MemoryAugmentedPreprocessor(Preprocessor):
    """Enhanced preprocessor with memory mechanism for long-range dependencies.

    A bank of "memory" tokens carried over from earlier windows is prepended
    to each sliding window, so each window sees up to ``memory_size`` tokens
    of prior context plus ``max_length - memory_size`` fresh tokens.
    """

    def __init__(self, tokenizer: TokenizerWrapper = None,
                 max_length: int = None,
                 stride: int = None,
                 memory_size: int = 64):
        """Initialize the memory-augmented preprocessor.

        Args:
            tokenizer: Subword tokenizer wrapper (default built by base class).
            max_length: Total window length including the memory prefix;
                when None, read from app_config.TRANSFORMER_CONFIG.
            stride: Sliding-window step (base class defaults to max_length // 2).
            memory_size: Maximum number of carried-over memory tokens.
        """
        # Resolve max_length the same way the base class does, because
        # effective_length below needs the concrete value.
        if max_length is None:
            if isinstance(app_config.TRANSFORMER_CONFIG, dict):
                max_length = app_config.TRANSFORMER_CONFIG.get('MAX_SEQ_LENGTH', 512)
            else:
                max_length = getattr(app_config.TRANSFORMER_CONFIG, 'MAX_SEQ_LENGTH', 512)

        super().__init__(tokenizer=tokenizer, max_length=max_length, stride=stride)
        self.memory_size = memory_size
        # Most recent key tokens, capped at memory_size entries.
        self.memory_bank = []
        # Room left for fresh tokens once the memory prefix is prepended.
        self.effective_length = max_length - memory_size

    def update_memory(self, window_tokens: List[int]):
        """Update memory bank with key information from current window,
        keeping only the newest memory_size entries."""
        key_tokens = self.extract_key_tokens(window_tokens)
        self.memory_bank = (self.memory_bank + key_tokens)[-self.memory_size:]

    def extract_key_tokens(self, tokens: List[int]) -> List[int]:
        """Extract important tokens from the window.

        NOTE(review): currently a simple prefix of memory_size tokens; a
        salience-based selection could replace this.
        """
        return tokens[:self.memory_size]

    def select_relevant_memory(self, current_tokens: List[int]) -> List[int]:
        """Select relevant memory tokens for the current window (currently
        the newest memory_size entries, regardless of current_tokens)."""
        return self.memory_bank[-self.memory_size:]

    def process_with_sliding_window(self, token_ids: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process long sequences using sliding window with memory mechanism.

        Returns:
            Tuple of (num_windows, max_length) token-id and attention-mask
            tensors.
        """
        windows = []
        attention_masks = []
        # Hoisted: pad id is loop-invariant.
        pad_id = self.tokenizer.tokenizer.token_to_id("[PAD]")

        for i in range(0, len(token_ids), self.stride):
            # Fresh tokens for this window; memory prefix fills the rest
            # (memory_size + effective_length <= max_length).
            window = token_ids[i:i + self.effective_length]

            if self.memory_bank:
                window = self.select_relevant_memory(window) + window

            # Memory is updated from the (memory + fresh) window, before padding.
            self.update_memory(window)

            # FIX: capture the unpadded length first. The original padded the
            # window and then computed the mask from len(window), so every
            # mask came out all-ones and padding was attended to.
            real_length = len(window)
            if real_length < self.max_length:
                window = window + [pad_id] * (self.max_length - real_length)

            windows.append(window)
            attention_masks.append(
                [1] * real_length + [0] * (self.max_length - real_length))

        return (torch.tensor(windows, dtype=torch.long),
                torch.tensor(attention_masks, dtype=torch.long))
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: run a small multi-line snippet through the
    # memory-augmented preprocessor and report the resulting tensor shapes.
    # Uses a shorter window (256) than the config default so the demo is cheap.
    preprocessor = MemoryAugmentedPreprocessor(
        max_length=256,
        memory_size=64,
        stride=128
    )

    # Sample input; the embedded code is just text to be tokenized.
    long_text = """
    def example_function():
        # This is a long function
        # with multiple lines
        pass
    """

    # Expected shapes: (1, 256) each, or (num_windows, 256) if the text
    # overflowed a single window.
    token_ids, attention_mask = preprocessor.preprocess_text(long_text)
    print(f"Processed shape: {token_ids.shape}, {attention_mask.shape}")
|
|
|
|
|
|
|
|
|
def preprocess_text():
|
|
|
... |