"""
FastText model training.

FastText is preferred over Word2Vec because:
- Better handling of typos and misspellings (common in reviews)
- Can generate vectors for out-of-vocabulary words
- Uses character n-grams internally
"""

import logging
from pathlib import Path

from gensim.models import FastText

from .config import MODELS_DIR, SETTINGS

logger = logging.getLogger(__name__)


class FastTextTrainer:
    """
    Trains FastText word embeddings on review corpus.

    The trainer holds the hyperparameters and, once ``train`` or ``load``
    has been called, the resulting model in ``self.model``.
    """

    # File name used under MODELS_DIR when save()/load() get no explicit path.
    _DEFAULT_MODEL_NAME = "fasttext.model"

    def __init__(
        self,
        vector_size: int | None = None,
        window: int | None = None,
        min_count: int | None = None,
        epochs: int | None = None,
        workers: int | None = None,
    ):
        """
        Initialize trainer with hyperparameters.

        Any argument left as ``None`` is read from ``SETTINGS``. Only ``None``
        triggers the fallback, so an explicit ``0`` is passed through as given
        instead of being silently replaced by the configured default.

        Args:
            vector_size: Dimensionality of word vectors
            window: Context window size
            min_count: Minimum word frequency
            epochs: Number of training iterations
            workers: Number of worker threads
        """
        self.vector_size = self._resolve(vector_size, "fasttext_vector_size")
        self.window = self._resolve(window, "fasttext_window")
        self.min_count = self._resolve(min_count, "fasttext_min_count")
        self.epochs = self._resolve(epochs, "fasttext_epochs")
        self.workers = self._resolve(workers, "fasttext_workers")
        self.model: FastText | None = None

    @staticmethod
    def _resolve(value: int | None, setting_key: str) -> int:
        """Return ``value`` unless it is None, else the ``SETTINGS`` entry."""
        # SETTINGS is only consulted when needed, so a missing key cannot
        # fail a call that supplied the value explicitly.
        return SETTINGS[setting_key] if value is None else value

    @staticmethod
    def _normalize(word: str) -> str:
        """Lower-case ``word`` and join phrases with underscores."""
        return word.lower().replace(" ", "_")

    def train(self, sentences: list[list[str]]) -> FastText:
        """
        Train FastText model on tokenized sentences.

        The trained model is stored on ``self.model`` and also returned.

        Args:
            sentences: List of tokenized documents (output from preprocessor)

        Returns:
            Trained FastText model
        """
        logger.info(
            "Training FastText model: "
            "vector_size=%s, window=%s, min_count=%s, epochs=%s",
            self.vector_size,
            self.window,
            self.min_count,
            self.epochs,
        )
        logger.info("Training on %s documents", len(sentences))

        self.model = FastText(
            sentences=sentences,
            vector_size=self.vector_size,
            window=self.window,
            min_count=self.min_count,
            epochs=self.epochs,
            workers=self.workers,
            sg=1,  # Skip-gram (better for semantic similarity)
            min_n=3,  # Minimum character n-gram length
            max_n=6,  # Maximum character n-gram length
        )

        logger.info(
            "Training complete. Vocabulary size: %s", len(self.model.wv)
        )
        return self.model

    def save(self, path: Path | str | None = None) -> Path:
        """
        Save trained model.

        Args:
            path: Save path (default: models/fasttext.model)

        Returns:
            Path where model was saved

        Raises:
            ValueError: If no model has been trained or loaded yet.
        """
        if self.model is None:
            raise ValueError("No model to save. Train first.")

        path = Path(path) if path else MODELS_DIR / self._DEFAULT_MODEL_NAME
        # Create the target directory so a first save into a fresh models
        # directory does not fail on a missing parent.
        path.parent.mkdir(parents=True, exist_ok=True)

        self.model.save(str(path))
        logger.info("Saved model to %s", path)
        return path

    def load(self, path: Path | str | None = None) -> FastText:
        """
        Load model from file.

        The loaded model is stored on ``self.model`` and also returned.

        Args:
            path: Model path (default: models/fasttext.model)

        Returns:
            Loaded FastText model

        Raises:
            FileNotFoundError: If nothing exists at the resolved path.
        """
        path = Path(path) if path else MODELS_DIR / self._DEFAULT_MODEL_NAME
        if not path.exists():
            raise FileNotFoundError(f"Model not found at {path}")

        # NOTE(security): gensim persistence is pickle-based; only load model
        # files that come from a trusted source.
        self.model = FastText.load(str(path))
        logger.info(
            "Loaded model from %s. Vocabulary size: %s",
            path,
            len(self.model.wv),
        )
        return self.model

    def get_similar(
        self,
        word: str,
        topn: int = 10,
    ) -> list[tuple[str, float]]:
        """
        Get most similar words to a given word.

        The query is lower-cased and spaces become underscores before lookup.

        Args:
            word: Query word
            topn: Number of results

        Returns:
            List of (word, similarity) tuples; an empty list if the lookup
            raises ``KeyError``.

        Raises:
            ValueError: If no model has been trained or loaded yet.
        """
        if self.model is None:
            raise ValueError("No model loaded. Train or load first.")

        # Normalize word (space to underscore for phrases)
        query = self._normalize(word)
        try:
            return self.model.wv.most_similar(query, topn=topn)
        except KeyError:
            logger.warning("Word '%s' not in vocabulary", word)
            return []

    def get_similarity(self, word1: str, word2: str) -> float:
        """
        Get similarity between two words.

        Both words are normalized the same way as in ``get_similar``.

        Args:
            word1: First word
            word2: Second word

        Returns:
            Cosine similarity (-1 to 1); ``0.0`` if the lookup raises
            ``KeyError``.

        Raises:
            ValueError: If no model has been trained or loaded yet.
        """
        if self.model is None:
            raise ValueError("No model loaded. Train or load first.")

        first = self._normalize(word1)
        second = self._normalize(word2)
        try:
            return float(self.model.wv.similarity(first, second))
        except KeyError as e:
            logger.warning("Word not in vocabulary: %s", e)
            return 0.0

    def word_in_vocab(self, word: str) -> bool:
        """Check if word is in vocabulary (False when no model is loaded)."""
        if self.model is None:
            return False
        return self._normalize(word) in self.model.wv

    def get_vocab_words(self) -> list[str]:
        """Get all words in vocabulary (empty list when no model is loaded)."""
        if self.model is None:
            return []
        return list(self.model.wv.key_to_index)