# GitHub Action
# deploy: worker release from GitHub
# 8ff1b66
"""
FastText model training.
FastText is preferred over Word2Vec because:
- Better handling of typos and misspellings (common in reviews)
- Can generate vectors for out-of-vocabulary words
- Uses character n-grams internally
"""
import logging
from pathlib import Path
from gensim.models import FastText
from .config import MODELS_DIR, SETTINGS
logger = logging.getLogger(__name__)
class FastTextTrainer:
    """
    Trains FastText word embeddings on a review corpus.

    FastText is used (rather than Word2Vec) because its character
    n-grams tolerate typos/misspellings and can produce vectors for
    out-of-vocabulary words.
    """

    def __init__(
        self,
        vector_size: int | None = None,
        window: int | None = None,
        min_count: int | None = None,
        epochs: int | None = None,
        workers: int | None = None,
    ):
        """
        Initialize trainer with hyperparameters.

        Any argument left as ``None`` falls back to the corresponding
        ``SETTINGS["fasttext_*"]`` value.

        Args:
            vector_size: Dimensionality of word vectors
            window: Context window size
            min_count: Minimum word frequency
            epochs: Number of training iterations
            workers: Number of worker threads
        """
        # Explicit `is None` checks (not `x or default`) so that a caller
        # who deliberately passes 0 is not silently overridden by the
        # SETTINGS default.
        self.vector_size = (
            vector_size if vector_size is not None else SETTINGS["fasttext_vector_size"]
        )
        self.window = window if window is not None else SETTINGS["fasttext_window"]
        self.min_count = (
            min_count if min_count is not None else SETTINGS["fasttext_min_count"]
        )
        self.epochs = epochs if epochs is not None else SETTINGS["fasttext_epochs"]
        self.workers = workers if workers is not None else SETTINGS["fasttext_workers"]
        # Populated by train() or load(); None until then.
        self.model: FastText | None = None

    def train(self, sentences: list[list[str]]) -> FastText:
        """
        Train FastText model on tokenized sentences.

        Args:
            sentences: List of tokenized documents (output from preprocessor)

        Returns:
            Trained FastText model
        """
        # Lazy %-style args: formatting only happens if the record is emitted.
        logger.info(
            "Training FastText model: vector_size=%s, window=%s, "
            "min_count=%s, epochs=%s",
            self.vector_size,
            self.window,
            self.min_count,
            self.epochs,
        )
        logger.info("Training on %d documents", len(sentences))
        self.model = FastText(
            sentences=sentences,
            vector_size=self.vector_size,
            window=self.window,
            min_count=self.min_count,
            epochs=self.epochs,
            workers=self.workers,
            sg=1,  # Skip-gram (better for semantic similarity)
            min_n=3,  # Minimum character n-gram length
            max_n=6,  # Maximum character n-gram length
        )
        logger.info("Training complete. Vocabulary size: %d", len(self.model.wv))
        return self.model

    def save(self, path: Path | str | None = None) -> Path:
        """
        Save trained model.

        Args:
            path: Save path (default: models/fasttext.model)

        Returns:
            Path where model was saved

        Raises:
            ValueError: If no model has been trained or loaded yet.
        """
        if self.model is None:
            raise ValueError("No model to save. Train first.")
        path = Path(path) if path else MODELS_DIR / "fasttext.model"
        # gensim's save() expects a string path.
        self.model.save(str(path))
        logger.info("Saved model to %s", path)
        return path

    def load(self, path: Path | str | None = None) -> FastText:
        """
        Load model from file.

        Args:
            path: Model path (default: models/fasttext.model)

        Returns:
            Loaded FastText model

        Raises:
            FileNotFoundError: If no model file exists at the path.
        """
        path = Path(path) if path else MODELS_DIR / "fasttext.model"
        if not path.exists():
            raise FileNotFoundError(f"Model not found at {path}")
        self.model = FastText.load(str(path))
        logger.info(
            "Loaded model from %s. Vocabulary size: %d", path, len(self.model.wv)
        )
        return self.model

    def get_similar(
        self,
        word: str,
        topn: int = 10,
    ) -> list[tuple[str, float]]:
        """
        Get most similar words to a given word.

        Args:
            word: Query word
            topn: Number of results

        Returns:
            List of (word, similarity) tuples; empty list if the word
            cannot be matched.

        Raises:
            ValueError: If no model has been trained or loaded yet.
        """
        if self.model is None:
            raise ValueError("No model loaded. Train or load first.")
        # Normalize word (space to underscore for phrases)
        word_normalized = word.lower().replace(" ", "_")
        try:
            return self.model.wv.most_similar(word_normalized, topn=topn)
        except KeyError:
            # FastText can usually back off to character n-grams, but a query
            # sharing no n-grams with the vocabulary still raises KeyError.
            logger.warning("Word '%s' not in vocabulary", word)
            return []

    def get_similarity(self, word1: str, word2: str) -> float:
        """
        Get similarity between two words.

        Args:
            word1: First word
            word2: Second word

        Returns:
            Cosine similarity (-1 to 1); 0.0 if either word cannot be matched.

        Raises:
            ValueError: If no model has been trained or loaded yet.
        """
        if self.model is None:
            raise ValueError("No model loaded. Train or load first.")
        w1 = word1.lower().replace(" ", "_")
        w2 = word2.lower().replace(" ", "_")
        try:
            return float(self.model.wv.similarity(w1, w2))
        except KeyError as e:
            logger.warning("Word not in vocabulary: %s", e)
            return 0.0

    def word_in_vocab(self, word: str) -> bool:
        """Check if word is in vocabulary (False when no model is loaded)."""
        if self.model is None:
            return False
        word_normalized = word.lower().replace(" ", "_")
        return word_normalized in self.model.wv

    def get_vocab_words(self) -> list[str]:
        """Get all words in vocabulary (empty list when no model is loaded)."""
        if self.model is None:
            return []
        return list(self.model.wv.key_to_index.keys())