Spaces:
Running
Running
| """ | |
| FastText model training. | |
| FastText is preferred over Word2Vec because: | |
| - Better handling of typos and misspellings (common in reviews) | |
| - Can generate vectors for out-of-vocabulary words | |
| - Uses character n-grams internally | |
| """ | |
# Standard library
import logging
from pathlib import Path

# Third-party
from gensim.models import FastText

# Local
from .config import MODELS_DIR, SETTINGS

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
class FastTextTrainer:
    """
    Trains FastText word embeddings on a review corpus.

    FastText is used (rather than plain Word2Vec) because its character
    n-grams tolerate typos/misspellings and can produce vectors for
    out-of-vocabulary words.
    """

    def __init__(
        self,
        vector_size: int | None = None,
        window: int | None = None,
        min_count: int | None = None,
        epochs: int | None = None,
        workers: int | None = None,
    ):
        """
        Initialize trainer with hyperparameters.

        Any argument left as ``None`` falls back to the corresponding
        value in ``SETTINGS``.

        Args:
            vector_size: Dimensionality of word vectors.
            window: Context window size.
            min_count: Minimum word frequency to keep a word.
            epochs: Number of training iterations.
            workers: Number of worker threads.
        """
        # Explicit `is not None` checks (rather than `x or default`) so an
        # explicitly passed falsy value is never silently replaced by the
        # SETTINGS default.
        self.vector_size = (
            vector_size if vector_size is not None else SETTINGS["fasttext_vector_size"]
        )
        self.window = window if window is not None else SETTINGS["fasttext_window"]
        self.min_count = (
            min_count if min_count is not None else SETTINGS["fasttext_min_count"]
        )
        self.epochs = epochs if epochs is not None else SETTINGS["fasttext_epochs"]
        self.workers = workers if workers is not None else SETTINGS["fasttext_workers"]
        # Populated by train() or load(); None until then.
        self.model: FastText | None = None

    @staticmethod
    def _normalize(word: str) -> str:
        """Lowercase and join multi-word phrases with underscores."""
        return word.lower().replace(" ", "_")

    def train(self, sentences: list[list[str]]) -> FastText:
        """
        Train a FastText model on tokenized sentences.

        Args:
            sentences: List of tokenized documents (output from preprocessor).

        Returns:
            Trained FastText model.

        Raises:
            ValueError: If `sentences` is empty (fail fast instead of a
                cryptic error from inside gensim).
        """
        if not sentences:
            raise ValueError("Cannot train on an empty corpus.")
        # Lazy %-style args: formatting only happens if the record is emitted.
        logger.info(
            "Training FastText model: vector_size=%s, window=%s, min_count=%s, epochs=%s",
            self.vector_size,
            self.window,
            self.min_count,
            self.epochs,
        )
        logger.info("Training on %s documents", len(sentences))
        self.model = FastText(
            sentences=sentences,
            vector_size=self.vector_size,
            window=self.window,
            min_count=self.min_count,
            epochs=self.epochs,
            workers=self.workers,
            sg=1,  # Skip-gram (better for semantic similarity)
            min_n=3,  # Minimum character n-gram length
            max_n=6,  # Maximum character n-gram length
        )
        vocab_size = len(self.model.wv)
        logger.info("Training complete. Vocabulary size: %s", vocab_size)
        return self.model

    def save(self, path: Path | str | None = None) -> Path:
        """
        Save the trained model.

        Args:
            path: Save path (default: models/fasttext.model).

        Returns:
            Path where the model was saved.

        Raises:
            ValueError: If no model has been trained or loaded yet.
        """
        if self.model is None:
            raise ValueError("No model to save. Train first.")
        path = Path(path) if path else MODELS_DIR / "fasttext.model"
        # Ensure the target directory exists; gensim does not create it.
        path.parent.mkdir(parents=True, exist_ok=True)
        self.model.save(str(path))
        logger.info("Saved model to %s", path)
        return path

    def load(self, path: Path | str | None = None) -> FastText:
        """
        Load a model from file.

        Args:
            path: Model path (default: models/fasttext.model).

        Returns:
            Loaded FastText model.

        Raises:
            FileNotFoundError: If no model file exists at `path`.
        """
        path = Path(path) if path else MODELS_DIR / "fasttext.model"
        if not path.exists():
            raise FileNotFoundError(f"Model not found at {path}")
        self.model = FastText.load(str(path))
        vocab_size = len(self.model.wv)
        logger.info("Loaded model from %s. Vocabulary size: %s", path, vocab_size)
        return self.model

    def get_similar(
        self,
        word: str,
        topn: int = 10,
    ) -> list[tuple[str, float]]:
        """
        Get the most similar words to a given word.

        Args:
            word: Query word (multi-word phrases allowed; spaces are
                normalized to underscores to match the training vocab).
            topn: Number of results.

        Returns:
            List of (word, similarity) tuples; empty list if the word's
            n-grams are entirely unknown to the model.

        Raises:
            ValueError: If no model has been trained or loaded yet.
        """
        if self.model is None:
            raise ValueError("No model loaded. Train or load first.")
        word_normalized = self._normalize(word)
        try:
            return self.model.wv.most_similar(word_normalized, topn=topn)
        except KeyError:
            # FastText can usually infer OOV vectors from n-grams, but
            # gensim raises KeyError when none of the n-grams are known.
            logger.warning("Word '%s' not in vocabulary", word)
            return []

    def get_similarity(self, word1: str, word2: str) -> float:
        """
        Get the similarity between two words.

        Args:
            word1: First word.
            word2: Second word.

        Returns:
            Cosine similarity (-1 to 1); 0.0 if either word is unknown.

        Raises:
            ValueError: If no model has been trained or loaded yet.
        """
        if self.model is None:
            raise ValueError("No model loaded. Train or load first.")
        w1 = self._normalize(word1)
        w2 = self._normalize(word2)
        try:
            return float(self.model.wv.similarity(w1, w2))
        except KeyError as e:
            logger.warning("Word not in vocabulary: %s", e)
            return 0.0

    def word_in_vocab(self, word: str) -> bool:
        """Check whether a (normalized) word is in the model's vocabulary."""
        if self.model is None:
            return False
        return self._normalize(word) in self.model.wv

    def get_vocab_words(self) -> list[str]:
        """Get all words in the vocabulary (empty if no model is loaded)."""
        if self.model is None:
            return []
        # Iterating the mapping yields its keys; no .keys() call needed.
        return list(self.model.wv.key_to_index)