Yousif22's picture
Upload folder using huggingface_hub
fe63ba2 verified
# utils.py
import re
import string
from typing import Optional
def preprocess(text: str, model_type: str = "naive_bayes") -> str:
    """
    Clean raw text before it is handed to a model.

    Two cleaning profiles exist. The "bert" profile is light: it keeps case
    and punctuation and only drops URLs/e-mail addresses and collapses
    repeated punctuation and whitespace. Every other ``model_type`` value
    gets the Naive Bayes profile: lowercase, drop URLs/e-mail addresses,
    replace symbols with spaces, and substitute placeholder tokens for
    percentages and dollar amounts.

    Args:
        text (str): Input text to preprocess.
        model_type (str): "bert" (case-insensitive) selects the light
            profile; anything else selects the Naive Bayes profile.

    Returns:
        str: Preprocessed text, or "" when ``text`` is empty or not a string.
    """
    if not text or not isinstance(text, str):
        return ""

    url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    # The TLD class is [A-Za-z]; a '|' inside a character class is a literal
    # pipe, not an alternation, so it must not appear here.
    email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'

    text = text.strip()

    if isinstance(model_type, str) and model_type.lower() == "bert":
        # BERT-specific preprocessing (less aggressive): case and ordinary
        # punctuation are left alone.
        text = re.sub(url_pattern, '', text)
        text = re.sub(email_pattern, '', text)
        # Collapse punctuation runs: 3+ dots become an ellipsis, repeated
        # '!' or '?' become a single mark.
        text = re.sub(r'[.]{3,}', '...', text)
        text = re.sub(r'[!]{2,}', '!', text)
        text = re.sub(r'[?]{2,}', '?', text)
        # Whitespace is collapsed last so the gaps left behind by removed
        # URLs and e-mail addresses are cleaned up too.
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    # Naive Bayes preprocessing (more aggressive cleaning)
    text = text.lower()
    text = re.sub(url_pattern, '', text)
    text = re.sub(email_pattern, '', text)
    # Replace every symbol except '$', '%', '.' and '-' with a space.
    text = re.sub(r'[^\w\s$%.-]', ' ', text)
    # No trailing \b after '%': '%' is a non-word character, so a boundary
    # there would only exist when a letter/digit follows, and "5%" before a
    # space or at end of text would never match.
    text = re.sub(r'\b\d+(?:\.\d+)?%', 'PERCENTAGE', text)
    # The text is already lowercase, so the magnitude suffix is k/m/b.
    text = re.sub(r'\$\d+\.?\d*[kmb]?\b', 'DOLLAR_AMOUNT', text)
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
def clean_financial_text(text: str) -> str:
    """
    Lowercase financial text and expand common abbreviations.

    Abbreviations are expanded only when they appear as whole words, so
    ordinary words that merely contain one (e.g. "federal" contains "fed")
    are left untouched.

    Args:
        text (str): Financial text to clean.

    Returns:
        str: Lowercased text with known abbreviations replaced by their full
        form, or "" when ``text`` is empty or not a string.
    """
    if not text or not isinstance(text, str):
        return ""

    # Common financial abbreviations and the phrases they expand to.
    financial_terms = {
        'q1': 'first quarter',
        'q2': 'second quarter',
        'q3': 'third quarter',
        'q4': 'fourth quarter',
        'yoy': 'year over year',
        'qoq': 'quarter over quarter',
        'ipo': 'initial public offering',
        'ceo': 'chief executive officer',
        'cfo': 'chief financial officer',
        'fed': 'federal reserve',
        'gdp': 'gross domestic product',
        'etf': 'exchange traded fund',
    }

    # One alternation over all abbreviations, anchored on word boundaries.
    # A single pass also guarantees that text produced by one expansion is
    # never re-scanned by another.
    abbreviation_pattern = re.compile(
        r'\b(?:' + '|'.join(re.escape(abbrev) for abbrev in financial_terms) + r')\b'
    )
    return abbreviation_pattern.sub(
        lambda match: financial_terms[match.group(0)],
        text.lower(),
    )
def extract_financial_entities(text: str) -> dict:
    """
    Extract financial entities from text using regular expressions.

    Args:
        text (str): Input text.

    Returns:
        dict: Lists of matched substrings under the keys 'percentages',
        'dollar_amounts', 'stock_symbols', 'quarters' and 'years'. Every
        list is empty when ``text`` is empty or not a string.
    """
    entities = {
        'percentages': [],
        'dollar_amounts': [],
        'stock_symbols': [],
        'quarters': [],
        'years': []
    }
    if not text or not isinstance(text, str):
        return entities

    # Percentages such as "5%" or "12.5%". There is deliberately no \b after
    # '%': '%' is a non-word character, so a boundary there would require a
    # letter/digit to follow and "5%" before a space would never match.
    entities['percentages'] = re.findall(r'\b\d+(?:\.\d+)?%', text)

    # Dollar amounts with optional thousands separators, decimals and an
    # uppercase K/M/B suffix, e.g. "$5", "$1,250", "$150.5B".
    entities['dollar_amounts'] = re.findall(
        r'\$\d+(?:,\d{3})*(?:\.\d+)?[KMB]?\b', text
    )

    # Potential stock symbols: any standalone run of 2-5 uppercase letters.
    # This is a heuristic and also picks up other all-caps words.
    entities['stock_symbols'] = re.findall(r'\b[A-Z]{2,5}\b', text)

    # Quarters written as "Q1".."Q4" or "1Q".."4Q", case-insensitive.
    entities['quarters'] = re.findall(
        r'\bQ[1-4]\b|\b[1-4]Q\b', text, re.IGNORECASE
    )

    # Four-digit years in the range 2000-2099.
    entities['years'] = re.findall(r'\b20\d{2}\b', text)

    return entities
def get_text_stats(text: str) -> dict:
    """
    Get basic statistics about the text.

    Args:
        text (str): Input text.

    Returns:
        dict: 'word_count' (whitespace-separated tokens), 'char_count'
        (length of the raw text), 'sentence_count' (non-blank segments
        between sentence terminators) and 'avg_word_length' (mean token
        length, 0 when there are no tokens).
    """
    if not text:
        return {
            'word_count': 0,
            'char_count': 0,
            'sentence_count': 0,
            'avg_word_length': 0
        }

    words = text.split()
    # Split on runs of '.', '!' or '?', but do not treat a '.' that sits
    # between two digits as a terminator: it is a decimal point ("5.2%") and
    # splitting there would inflate the sentence count.
    sentences = re.split(r'(?:[!?]|\.(?!\d)|(?<!\d)\.)+', text)

    return {
        'word_count': len(words),
        'char_count': len(text),
        'sentence_count': sum(1 for segment in sentences if segment.strip()),
        'avg_word_length': sum(len(word) for word in words) / len(words) if words else 0
    }
def validate_input(text: str, min_length: int = 5, max_length: int = 1000) -> tuple[bool, str]:
    """
    Validate user input.

    Args:
        text (str): Input text to validate.
        min_length (int): Minimum required length, measured after stripping
            surrounding whitespace.
        max_length (int): Maximum allowed length, measured on the raw input
            (surrounding whitespace included).

    Returns:
        tuple: (is_valid, error_message); ``error_message`` is "" when the
        input is valid.
    """
    # Non-string input is rejected here rather than raising on .strip().
    if not isinstance(text, str) or not text.strip():
        return False, "Text cannot be empty"

    stripped = text.strip()
    if len(stripped) < min_length:
        return False, f"Text must be at least {min_length} characters long"
    if len(text) > max_length:
        return False, f"Text cannot exceed {max_length} characters"

    # Require at least one letter or digit. Scanning for an alphanumeric
    # character (instead of matching the whole string against a "symbols
    # only" pattern) also rejects symbol-only input that contains spaces,
    # such as "!!! ???".
    if not any(ch.isalnum() for ch in stripped):
        return False, "Text must contain alphanumeric characters"

    return True, ""